tint->dawn: Shuffle source tree in preperation of merging repos
docs/ -> docs/tint/
fuzzers/ -> src/tint/fuzzers/
samples/ -> src/tint/cmd/
src/ -> src/tint/
test/ -> test/tint/
BUG=tint:1418,tint:1433
Change-Id: Id2aa79f989aef3245b80ef4aa37a27ff16cd700b
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/80482
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Ryan Harrison <rharrison@chromium.org>
diff --git a/src/tint/resolver/array_accessor_test.cc b/src/tint/resolver/array_accessor_test.cc
new file mode 100644
index 0000000..c8fcb4e
--- /dev/null
+++ b/src/tint/resolver/array_accessor_test.cc
@@ -0,0 +1,312 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/resolver/resolver.h"
+
+#include "gmock/gmock.h"
+#include "src/tint/resolver/resolver_test_helper.h"
+#include "src/tint/sem/reference_type.h"
+
+namespace tint {
+namespace resolver {
+namespace {
+
+using ResolverIndexAccessorTest = ResolverTest;
+
+TEST_F(ResolverIndexAccessorTest, Matrix_Dynamic_F32) {
+ Global("my_var", ty.mat2x3<f32>(), ast::StorageClass::kPrivate);
+ auto* acc = IndexAccessor("my_var", Expr(Source{{12, 34}}, 1.0f));
+ WrapInFunction(acc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: index must be of type 'i32' or 'u32', found: 'f32'");
+}
+
+TEST_F(ResolverIndexAccessorTest, Matrix_Dynamic_Ref) {
+ Global("my_var", ty.mat2x3<f32>(), ast::StorageClass::kPrivate);
+ auto* idx = Var("idx", ty.i32(), Construct(ty.i32()));
+ auto* acc = IndexAccessor("my_var", idx);
+ WrapInFunction(Decl(idx), acc);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverIndexAccessorTest, Matrix_BothDimensions_Dynamic_Ref) {
+ Global("my_var", ty.mat4x4<f32>(), ast::StorageClass::kPrivate);
+ auto* idx = Var("idx", ty.u32(), Expr(3u));
+ auto* idy = Var("idy", ty.u32(), Expr(2u));
+ auto* acc = IndexAccessor(IndexAccessor("my_var", idx), idy);
+ WrapInFunction(Decl(idx), Decl(idy), acc);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverIndexAccessorTest, Matrix_Dynamic) {
+ GlobalConst("my_const", ty.mat2x3<f32>(), Construct(ty.mat2x3<f32>()));
+ auto* idx = Var("idx", ty.i32(), Construct(ty.i32()));
+ auto* acc = IndexAccessor("my_const", Expr(Source{{12, 34}}, idx));
+ WrapInFunction(Decl(idx), acc);
+
+ EXPECT_TRUE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "");
+}
+
+TEST_F(ResolverIndexAccessorTest, Matrix_XDimension_Dynamic) {
+ GlobalConst("my_var", ty.mat4x4<f32>(), Construct(ty.mat4x4<f32>()));
+ auto* idx = Var("idx", ty.u32(), Expr(3u));
+ auto* acc = IndexAccessor("my_var", Expr(Source{{12, 34}}, idx));
+ WrapInFunction(Decl(idx), acc);
+
+ EXPECT_TRUE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "");
+}
+
+TEST_F(ResolverIndexAccessorTest, Matrix_BothDimension_Dynamic) {
+ GlobalConst("my_var", ty.mat4x4<f32>(), Construct(ty.mat4x4<f32>()));
+ auto* idx = Var("idy", ty.u32(), Expr(2u));
+ auto* acc =
+ IndexAccessor(IndexAccessor("my_var", Expr(Source{{12, 34}}, idx)), 1);
+ WrapInFunction(Decl(idx), acc);
+
+ EXPECT_TRUE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "");
+}
+
+TEST_F(ResolverIndexAccessorTest, Matrix) {
+ Global("my_var", ty.mat2x3<f32>(), ast::StorageClass::kPrivate);
+
+ auto* acc = IndexAccessor("my_var", 2);
+ WrapInFunction(acc);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(acc), nullptr);
+ ASSERT_TRUE(TypeOf(acc)->Is<sem::Reference>());
+
+ auto* ref = TypeOf(acc)->As<sem::Reference>();
+ ASSERT_TRUE(ref->StoreType()->Is<sem::Vector>());
+ EXPECT_EQ(ref->StoreType()->As<sem::Vector>()->Width(), 3u);
+}
+
+TEST_F(ResolverIndexAccessorTest, Matrix_BothDimensions) {
+ Global("my_var", ty.mat2x3<f32>(), ast::StorageClass::kPrivate);
+
+ auto* acc = IndexAccessor(IndexAccessor("my_var", 2), 1);
+ WrapInFunction(acc);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(acc), nullptr);
+ ASSERT_TRUE(TypeOf(acc)->Is<sem::Reference>());
+
+ auto* ref = TypeOf(acc)->As<sem::Reference>();
+ EXPECT_TRUE(ref->StoreType()->Is<sem::F32>());
+}
+
+TEST_F(ResolverIndexAccessorTest, Vector_F32) {
+ Global("my_var", ty.vec3<f32>(), ast::StorageClass::kPrivate);
+ auto* acc = IndexAccessor("my_var", Expr(Source{{12, 34}}, 2.0f));
+ WrapInFunction(acc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: index must be of type 'i32' or 'u32', found: 'f32'");
+}
+
+TEST_F(ResolverIndexAccessorTest, Vector_Dynamic_Ref) {
+ Global("my_var", ty.vec3<f32>(), ast::StorageClass::kPrivate);
+ auto* idx = Var("idx", ty.i32(), Expr(2));
+ auto* acc = IndexAccessor("my_var", idx);
+ WrapInFunction(Decl(idx), acc);
+
+ EXPECT_TRUE(r()->Resolve());
+}
+
+TEST_F(ResolverIndexAccessorTest, Vector_Dynamic) {
+ GlobalConst("my_var", ty.vec3<f32>(), Construct(ty.vec3<f32>()));
+ auto* idx = Var("idx", ty.i32(), Expr(2));
+ auto* acc = IndexAccessor("my_var", Expr(Source{{12, 34}}, idx));
+ WrapInFunction(Decl(idx), acc);
+
+ EXPECT_TRUE(r()->Resolve());
+}
+
+TEST_F(ResolverIndexAccessorTest, Vector) {
+ Global("my_var", ty.vec3<f32>(), ast::StorageClass::kPrivate);
+
+ auto* acc = IndexAccessor("my_var", 2);
+ WrapInFunction(acc);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(acc), nullptr);
+ ASSERT_TRUE(TypeOf(acc)->Is<sem::Reference>());
+
+ auto* ref = TypeOf(acc)->As<sem::Reference>();
+ EXPECT_TRUE(ref->StoreType()->Is<sem::F32>());
+}
+
+TEST_F(ResolverIndexAccessorTest, Array) {
+ auto* idx = Expr(2);
+ Global("my_var", ty.array<f32, 3>(), ast::StorageClass::kPrivate);
+
+ auto* acc = IndexAccessor("my_var", idx);
+ WrapInFunction(acc);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(acc), nullptr);
+ ASSERT_TRUE(TypeOf(acc)->Is<sem::Reference>());
+
+ auto* ref = TypeOf(acc)->As<sem::Reference>();
+ EXPECT_TRUE(ref->StoreType()->Is<sem::F32>());
+}
+
+TEST_F(ResolverIndexAccessorTest, Alias_Array) {
+ auto* aary = Alias("myarrty", ty.array<f32, 3>());
+
+ Global("my_var", ty.Of(aary), ast::StorageClass::kPrivate);
+
+ auto* acc = IndexAccessor("my_var", 2);
+ WrapInFunction(acc);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(acc), nullptr);
+ ASSERT_TRUE(TypeOf(acc)->Is<sem::Reference>());
+
+ auto* ref = TypeOf(acc)->As<sem::Reference>();
+ EXPECT_TRUE(ref->StoreType()->Is<sem::F32>());
+}
+
+TEST_F(ResolverIndexAccessorTest, Array_Constant) {
+ GlobalConst("my_var", ty.array<f32, 3>(), array<f32, 3>());
+
+ auto* acc = IndexAccessor("my_var", 2);
+ WrapInFunction(acc);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(acc), nullptr);
+ EXPECT_TRUE(TypeOf(acc)->Is<sem::F32>()) << TypeOf(acc)->type_name();
+}
+
+TEST_F(ResolverIndexAccessorTest, Array_Dynamic_I32) {
+ // let a : array<f32, 3> = 0;
+ // var idx : i32 = 0;
+ // var f : f32 = a[idx];
+ auto* a = Const("a", ty.array<f32, 3>(), array<f32, 3>());
+ auto* idx = Var("idx", ty.i32(), Construct(ty.i32()));
+ auto* f = Var("f", ty.f32(), IndexAccessor("a", Expr(Source{{12, 34}}, idx)));
+ Func("my_func", ast::VariableList{}, ty.void_(),
+ {
+ Decl(a),
+ Decl(idx),
+ Decl(f),
+ },
+ ast::AttributeList{});
+
+ EXPECT_TRUE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "");
+}
+
+TEST_F(ResolverIndexAccessorTest, Array_Literal_F32) {
+ // let a : array<f32, 3>;
+ // var f : f32 = a[2.0f];
+ auto* a = Const("a", ty.array<f32, 3>(), array<f32, 3>());
+ auto* f =
+ Var("a_2", ty.f32(), IndexAccessor("a", Expr(Source{{12, 34}}, 2.0f)));
+ Func("my_func", ast::VariableList{}, ty.void_(),
+ {
+ Decl(a),
+ Decl(f),
+ },
+ ast::AttributeList{});
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: index must be of type 'i32' or 'u32', found: 'f32'");
+}
+
+TEST_F(ResolverIndexAccessorTest, Array_Literal_I32) {
+ // let a : array<f32, 3>;
+ // var f : f32 = a[2];
+ auto* a = Const("a", ty.array<f32, 3>(), array<f32, 3>());
+ auto* f = Var("a_2", ty.f32(), IndexAccessor("a", 2));
+ Func("my_func", ast::VariableList{}, ty.void_(),
+ {
+ Decl(a),
+ Decl(f),
+ },
+ ast::AttributeList{});
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverIndexAccessorTest, EXpr_Deref_FuncGoodParent) {
+ // fn func(p: ptr<function, vec4<f32>>) -> f32 {
+ // let idx: u32 = u32();
+ // let x: f32 = (*p)[idx];
+ // return x;
+ // }
+ auto* p =
+ Param("p", ty.pointer(ty.vec4<f32>(), ast::StorageClass::kFunction));
+ auto* idx = Const("idx", ty.u32(), Construct(ty.u32()));
+ auto* star_p = Deref(p);
+ auto* accessor_expr = IndexAccessor(Source{{12, 34}}, star_p, idx);
+ auto* x = Var("x", ty.f32(), accessor_expr);
+ Func("func", {p}, ty.f32(), {Decl(idx), Decl(x), Return(x)});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverIndexAccessorTest, EXpr_Deref_FuncBadParent) {
+ // fn func(p: ptr<function, vec4<f32>>) -> f32 {
+ // let idx: u32 = u32();
+ // let x: f32 = *p[idx];
+ // return x;
+ // }
+ auto* p =
+ Param("p", ty.pointer(ty.vec4<f32>(), ast::StorageClass::kFunction));
+ auto* idx = Const("idx", ty.u32(), Construct(ty.u32()));
+ auto* accessor_expr = IndexAccessor(Source{{12, 34}}, p, idx);
+ auto* star_p = Deref(accessor_expr);
+ auto* x = Var("x", ty.f32(), star_p);
+ Func("func", {p}, ty.f32(), {Decl(idx), Decl(x), Return(x)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: cannot index type 'ptr<function, vec4<f32>, read_write>'");
+}
+
+TEST_F(ResolverIndexAccessorTest, Exr_Deref_BadParent) {
+ // var param: vec4<f32>
+ // let x: f32 = *(¶m)[0];
+ auto* param = Var("param", ty.vec4<f32>());
+ auto* idx = Var("idx", ty.u32(), Construct(ty.u32()));
+ auto* addressOf_expr = AddressOf(param);
+ auto* accessor_expr = IndexAccessor(Source{{12, 34}}, addressOf_expr, idx);
+ auto* star_p = Deref(accessor_expr);
+ auto* x = Var("x", ty.f32(), star_p);
+ WrapInFunction(param, idx, x);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: cannot index type 'ptr<function, vec4<f32>, read_write>'");
+}
+
+} // namespace
+} // namespace resolver
+} // namespace tint
diff --git a/src/tint/resolver/assignment_validation_test.cc b/src/tint/resolver/assignment_validation_test.cc
new file mode 100644
index 0000000..a69d6a0
--- /dev/null
+++ b/src/tint/resolver/assignment_validation_test.cc
@@ -0,0 +1,403 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/resolver/resolver.h"
+
+#include "gmock/gmock.h"
+#include "src/tint/ast/struct_block_attribute.h"
+#include "src/tint/resolver/resolver_test_helper.h"
+#include "src/tint/sem/storage_texture_type.h"
+
+namespace tint {
+namespace resolver {
+namespace {
+
+using ResolverAssignmentValidationTest = ResolverTest;
+
+TEST_F(ResolverAssignmentValidationTest, ReadOnlyBuffer) {
+ // [[block]] struct S { m : i32 };
+ // @group(0) @binding(0)
+ // var<storage,read> a : S;
+ auto* s = Structure("S", {Member("m", ty.i32())},
+ {create<ast::StructBlockAttribute>()});
+ Global(Source{{12, 34}}, "a", ty.Of(s), ast::StorageClass::kStorage,
+ ast::Access::kRead,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(0),
+ create<ast::GroupAttribute>(0),
+ });
+
+ WrapInFunction(Assign(Source{{56, 78}}, MemberAccessor("a", "m"), 1));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "56:78 error: cannot store into a read-only type 'ref<storage, "
+ "i32, read>'");
+}
+
+TEST_F(ResolverAssignmentValidationTest, AssignIncompatibleTypes) {
+ // {
+ // var a : i32 = 2;
+ // a = 2.3;
+ // }
+
+ auto* var = Var("a", ty.i32(), ast::StorageClass::kNone, Expr(2));
+
+ auto* assign = Assign(Source{{12, 34}}, "a", 2.3f);
+ WrapInFunction(var, assign);
+
+ ASSERT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(), "12:34 error: cannot assign 'f32' to 'i32'");
+}
+
+TEST_F(ResolverAssignmentValidationTest,
+ AssignArraysWithDifferentSizeExpressions_Pass) {
+ // let len = 4u;
+ // {
+ // var a : array<f32, 4>;
+ // var b : array<f32, len>;
+ // a = b;
+ // }
+
+ GlobalConst("len", nullptr, Expr(4u));
+
+ auto* a = Var("a", ty.array(ty.f32(), 4));
+ auto* b = Var("b", ty.array(ty.f32(), "len"));
+
+ auto* assign = Assign(Source{{12, 34}}, "a", "b");
+ WrapInFunction(a, b, assign);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverAssignmentValidationTest,
+ AssignArraysWithDifferentSizeExpressions_Fail) {
+ // let len = 5u;
+ // {
+ // var a : array<f32, 4>;
+ // var b : array<f32, len>;
+ // a = b;
+ // }
+
+ GlobalConst("len", nullptr, Expr(5u));
+
+ auto* a = Var("a", ty.array(ty.f32(), 4));
+ auto* b = Var("b", ty.array(ty.f32(), "len"));
+
+ auto* assign = Assign(Source{{12, 34}}, "a", "b");
+ WrapInFunction(a, b, assign);
+
+ ASSERT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(),
+ "12:34 error: cannot assign 'array<f32, 5>' to 'array<f32, 4>'");
+}
+
+TEST_F(ResolverAssignmentValidationTest,
+ AssignCompatibleTypesInBlockStatement_Pass) {
+ // {
+ // var a : i32 = 2;
+ // a = 2
+ // }
+ auto* var = Var("a", ty.i32(), ast::StorageClass::kNone, Expr(2));
+ WrapInFunction(var, Assign("a", 2));
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverAssignmentValidationTest,
+ AssignIncompatibleTypesInBlockStatement_Fail) {
+ // {
+ // var a : i32 = 2;
+ // a = 2.3;
+ // }
+
+ auto* var = Var("a", ty.i32(), ast::StorageClass::kNone, Expr(2));
+ WrapInFunction(var, Assign(Source{{12, 34}}, "a", 2.3f));
+
+ ASSERT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(), "12:34 error: cannot assign 'f32' to 'i32'");
+}
+
+TEST_F(ResolverAssignmentValidationTest,
+ AssignIncompatibleTypesInNestedBlockStatement_Fail) {
+ // {
+ // {
+ // var a : i32 = 2;
+ // a = 2.3;
+ // }
+ // }
+
+ auto* var = Var("a", ty.i32(), ast::StorageClass::kNone, Expr(2));
+ auto* inner_block = Block(Decl(var), Assign(Source{{12, 34}}, "a", 2.3f));
+ auto* outer_block = Block(inner_block);
+ WrapInFunction(outer_block);
+
+ ASSERT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(), "12:34 error: cannot assign 'f32' to 'i32'");
+}
+
+TEST_F(ResolverAssignmentValidationTest, AssignToScalar_Fail) {
+ // var my_var : i32 = 2;
+ // 1 = my_var;
+
+ auto* var = Var("my_var", ty.i32(), ast::StorageClass::kNone, Expr(2));
+ WrapInFunction(var, Assign(Expr(Source{{12, 34}}, 1), "my_var"));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: cannot assign to value of type 'i32'");
+}
+
+TEST_F(ResolverAssignmentValidationTest, AssignCompatibleTypes_Pass) {
+ // var a : i32 = 2;
+ // a = 2
+ auto* var = Var("a", ty.i32(), ast::StorageClass::kNone, Expr(2));
+ WrapInFunction(var, Assign(Source{{12, 34}}, "a", 2));
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverAssignmentValidationTest,
+ AssignCompatibleTypesThroughAlias_Pass) {
+ // alias myint = i32;
+ // var a : myint = 2;
+ // a = 2
+ auto* myint = Alias("myint", ty.i32());
+ auto* var = Var("a", ty.Of(myint), ast::StorageClass::kNone, Expr(2));
+ WrapInFunction(var, Assign(Source{{12, 34}}, "a", 2));
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverAssignmentValidationTest,
+ AssignCompatibleTypesInferRHSLoad_Pass) {
+ // var a : i32 = 2;
+ // var b : i32 = 3;
+ // a = b;
+ auto* var_a = Var("a", ty.i32(), ast::StorageClass::kNone, Expr(2));
+ auto* var_b = Var("b", ty.i32(), ast::StorageClass::kNone, Expr(3));
+ WrapInFunction(var_a, var_b, Assign(Source{{12, 34}}, "a", "b"));
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverAssignmentValidationTest, AssignThroughPointer_Pass) {
+ // var a : i32;
+ // let b : ptr<function,i32> = &a;
+ // *b = 2;
+ const auto func = ast::StorageClass::kFunction;
+ auto* var_a = Var("a", ty.i32(), func, Expr(2));
+ auto* var_b = Const("b", ty.pointer<int>(func), AddressOf(Expr("a")));
+ WrapInFunction(var_a, var_b, Assign(Source{{12, 34}}, Deref("b"), 2));
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverAssignmentValidationTest, AssignToConstant_Fail) {
+ // {
+ // let a : i32 = 2;
+ // a = 2
+ // }
+ auto* var = Const("a", ty.i32(), Expr(2));
+ WrapInFunction(var, Assign(Expr(Source{{12, 34}}, "a"), 2));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: cannot assign to const\nnote: 'a' is declared here:");
+}
+
+TEST_F(ResolverAssignmentValidationTest, AssignNonConstructible_Handle) {
+ // var a : texture_storage_1d<rgba8unorm, write>;
+ // var b : texture_storage_1d<rgba8unorm, write>;
+ // a = b;
+
+ auto make_type = [&] {
+ return ty.storage_texture(ast::TextureDimension::k1d,
+ ast::TexelFormat::kRgba8Unorm,
+ ast::Access::kWrite);
+ };
+
+ Global("a", make_type(), ast::StorageClass::kNone,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(0),
+ create<ast::GroupAttribute>(0),
+ });
+ Global("b", make_type(), ast::StorageClass::kNone,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(1),
+ create<ast::GroupAttribute>(0),
+ });
+
+ WrapInFunction(Assign(Source{{56, 78}}, "a", "b"));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "56:78 error: storage type of assignment must be constructible");
+}
+
+TEST_F(ResolverAssignmentValidationTest, AssignNonConstructible_Atomic) {
+ // [[block]] struct S { a : atomic<i32>; };
+ // @group(0) @binding(0) var<storage, read_write> v : S;
+ // v.a = v.a;
+
+ auto* s = Structure("S", {Member("a", ty.atomic(ty.i32()))},
+ {create<ast::StructBlockAttribute>()});
+ Global(Source{{12, 34}}, "v", ty.Of(s), ast::StorageClass::kStorage,
+ ast::Access::kReadWrite,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(0),
+ create<ast::GroupAttribute>(0),
+ });
+
+ WrapInFunction(Assign(Source{{56, 78}}, MemberAccessor("v", "a"),
+ MemberAccessor("v", "a")));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "56:78 error: storage type of assignment must be constructible");
+}
+
+TEST_F(ResolverAssignmentValidationTest, AssignNonConstructible_RuntimeArray) {
+ // [[block]] struct S { a : array<f32>; };
+ // @group(0) @binding(0) var<storage, read_write> v : S;
+ // v.a = v.a;
+
+ auto* s = Structure("S", {Member("a", ty.array(ty.f32()))},
+ {create<ast::StructBlockAttribute>()});
+ Global(Source{{12, 34}}, "v", ty.Of(s), ast::StorageClass::kStorage,
+ ast::Access::kReadWrite,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(0),
+ create<ast::GroupAttribute>(0),
+ });
+
+ WrapInFunction(Assign(Source{{56, 78}}, MemberAccessor("v", "a"),
+ MemberAccessor("v", "a")));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "56:78 error: storage type of assignment must be constructible");
+}
+
+TEST_F(ResolverAssignmentValidationTest,
+ AssignToPhony_NonConstructibleStruct_Fail) {
+ // [[block]]
+ // struct S {
+ // arr: array<i32>;
+ // };
+ // @group(0) @binding(0) var<storage, read_write> s : S;
+ // fn f() {
+ // _ = s;
+ // }
+ auto* s = Structure("S", {Member("arr", ty.array<i32>())}, {StructBlock()});
+ Global("s", ty.Of(s), ast::StorageClass::kStorage, GroupAndBinding(0, 0));
+
+ WrapInFunction(Assign(Phony(), Expr(Source{{12, 34}}, "s")));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: cannot assign 'S' to '_'. "
+ "'_' can only be assigned a constructible, pointer, texture or "
+ "sampler type");
+}
+
+TEST_F(ResolverAssignmentValidationTest, AssignToPhony_DynamicArray_Fail) {
+ // [[block]]
+ // struct S {
+ // arr: array<i32>;
+ // };
+ // @group(0) @binding(0) var<storage, read_write> s : S;
+ // fn f() {
+ // _ = s.arr;
+ // }
+ auto* s = Structure("S", {Member("arr", ty.array<i32>())}, {StructBlock()});
+ Global("s", ty.Of(s), ast::StorageClass::kStorage, GroupAndBinding(0, 0));
+
+ WrapInFunction(Assign(Phony(), MemberAccessor(Source{{12, 34}}, "s", "arr")));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: cannot assign 'array<i32>' to '_'. "
+ "'_' can only be assigned a constructible, pointer, texture or sampler "
+ "type");
+}
+
+TEST_F(ResolverAssignmentValidationTest, AssignToPhony_Pass) {
+ // [[block]]
+ // struct S {
+ // i: i32;
+ // arr: array<i32>;
+ // };
+ // [[block]]
+ // struct U {
+ // i: i32;
+ // };
+ // @group(0) @binding(0) var tex texture_2d;
+ // @group(0) @binding(1) var smp sampler;
+ // @group(0) @binding(2) var<uniform> u : U;
+ // @group(0) @binding(3) var<storage, read_write> s : S;
+ // var<workgroup> wg : array<f32, 10>
+ // fn f() {
+ // _ = 1;
+ // _ = 2u;
+ // _ = 3.0;
+ // _ = vec2<bool>();
+ // _ = tex;
+ // _ = smp;
+ // _ = &s;
+ // _ = s.i;
+ // _ = &s.arr;
+ // _ = u;
+ // _ = u.i;
+ // _ = wg;
+ // _ = wg[3];
+ // }
+ auto* S = Structure("S",
+ {
+ Member("i", ty.i32()),
+ Member("arr", ty.array<i32>()),
+ },
+ {StructBlock()});
+ auto* U = Structure("U", {Member("i", ty.i32())}, {StructBlock()});
+ Global("tex", ty.sampled_texture(ast::TextureDimension::k2d, ty.f32()),
+ GroupAndBinding(0, 0));
+ Global("smp", ty.sampler(ast::SamplerKind::kSampler), GroupAndBinding(0, 1));
+ Global("u", ty.Of(U), ast::StorageClass::kUniform, GroupAndBinding(0, 2));
+ Global("s", ty.Of(S), ast::StorageClass::kStorage, GroupAndBinding(0, 3));
+ Global("wg", ty.array<f32, 10>(), ast::StorageClass::kWorkgroup);
+
+ WrapInFunction(Assign(Phony(), 1), //
+ Assign(Phony(), 2), //
+ Assign(Phony(), 3), //
+ Assign(Phony(), vec2<bool>()), //
+ Assign(Phony(), "tex"), //
+ Assign(Phony(), "smp"), //
+ Assign(Phony(), AddressOf("s")), //
+ Assign(Phony(), MemberAccessor("s", "i")), //
+ Assign(Phony(), AddressOf(MemberAccessor("s", "arr"))), //
+ Assign(Phony(), "u"), //
+ Assign(Phony(), MemberAccessor("u", "i")), //
+ Assign(Phony(), "wg"), //
+ Assign(Phony(), IndexAccessor("wg", 3)));
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+} // namespace
+} // namespace resolver
+} // namespace tint
diff --git a/src/tint/resolver/atomics_test.cc b/src/tint/resolver/atomics_test.cc
new file mode 100644
index 0000000..04592b0
--- /dev/null
+++ b/src/tint/resolver/atomics_test.cc
@@ -0,0 +1,74 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ast/struct_block_attribute.h"
+#include "src/tint/resolver/resolver.h"
+#include "src/tint/resolver/resolver_test_helper.h"
+#include "src/tint/sem/atomic_type.h"
+#include "src/tint/sem/reference_type.h"
+
+#include "gmock/gmock.h"
+
+namespace tint {
+namespace resolver {
+namespace {
+
+struct ResolverAtomicTest : public resolver::TestHelper,
+ public testing::Test {};
+
+TEST_F(ResolverAtomicTest, GlobalWorkgroupI32) {
+ auto* g = Global("a", ty.atomic(Source{{12, 34}}, ty.i32()),
+ ast::StorageClass::kWorkgroup);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(TypeOf(g)->Is<sem::Reference>());
+ auto* atomic = TypeOf(g)->UnwrapRef()->As<sem::Atomic>();
+ ASSERT_NE(atomic, nullptr);
+ EXPECT_TRUE(atomic->Type()->Is<sem::I32>());
+}
+
+TEST_F(ResolverAtomicTest, GlobalWorkgroupU32) {
+ auto* g = Global("a", ty.atomic(Source{{12, 34}}, ty.u32()),
+ ast::StorageClass::kWorkgroup);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(TypeOf(g)->Is<sem::Reference>());
+ auto* atomic = TypeOf(g)->UnwrapRef()->As<sem::Atomic>();
+ ASSERT_NE(atomic, nullptr);
+ EXPECT_TRUE(atomic->Type()->Is<sem::U32>());
+}
+
+TEST_F(ResolverAtomicTest, GlobalStorageStruct) {
+ auto* s = Structure("s", {Member("a", ty.atomic(Source{{12, 34}}, ty.i32()))},
+ {create<ast::StructBlockAttribute>()});
+ auto* g = Global("g", ty.Of(s), ast::StorageClass::kStorage,
+ ast::Access::kReadWrite,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(0),
+ create<ast::GroupAttribute>(0),
+ });
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(TypeOf(g)->Is<sem::Reference>());
+ auto* str = TypeOf(g)->UnwrapRef()->As<sem::Struct>();
+ ASSERT_NE(str, nullptr);
+ ASSERT_EQ(str->Members().size(), 1u);
+ auto* atomic = str->Members()[0]->Type()->As<sem::Atomic>();
+ ASSERT_NE(atomic, nullptr);
+ ASSERT_TRUE(atomic->Type()->Is<sem::I32>());
+}
+
+} // namespace
+} // namespace resolver
+} // namespace tint
diff --git a/src/tint/resolver/atomics_validation_test.cc b/src/tint/resolver/atomics_validation_test.cc
new file mode 100644
index 0000000..47da93b
--- /dev/null
+++ b/src/tint/resolver/atomics_validation_test.cc
@@ -0,0 +1,332 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/resolver/resolver.h"
+#include "src/tint/resolver/resolver_test_helper.h"
+#include "src/tint/sem/atomic_type.h"
+#include "src/tint/sem/reference_type.h"
+
+#include "gmock/gmock.h"
+
+namespace tint {
+namespace resolver {
+namespace {
+
+struct ResolverAtomicValidationTest : public resolver::TestHelper,
+ public testing::Test {};
+
+TEST_F(ResolverAtomicValidationTest, StorageClass_WorkGroup) {
+ Global("a", ty.atomic(Source{{12, 34}}, ty.i32()),
+ ast::StorageClass::kWorkgroup);
+
+ EXPECT_TRUE(r()->Resolve());
+}
+
+TEST_F(ResolverAtomicValidationTest, StorageClass_Storage) {
+ auto* s = Structure("s", {Member("a", ty.atomic(Source{{12, 34}}, ty.i32()))},
+ {StructBlock()});
+ Global("g", ty.Of(s), ast::StorageClass::kStorage, ast::Access::kReadWrite,
+ GroupAndBinding(0, 0));
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverAtomicValidationTest, InvalidType) {
+ Global("a", ty.atomic(ty.f32(Source{{12, 34}})),
+ ast::StorageClass::kWorkgroup);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: atomic only supports i32 or u32 types");
+}
+
+TEST_F(ResolverAtomicValidationTest, InvalidStorageClass_Simple) {
+ Global("a", ty.atomic(Source{{12, 34}}, ty.i32()),
+ ast::StorageClass::kPrivate);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: atomic variables must have <storage> or <workgroup> "
+ "storage class");
+}
+
+TEST_F(ResolverAtomicValidationTest, InvalidStorageClass_Array) {
+ Global("a", ty.atomic(Source{{12, 34}}, ty.i32()),
+ ast::StorageClass::kPrivate);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: atomic variables must have <storage> or <workgroup> "
+ "storage class");
+}
+
+TEST_F(ResolverAtomicValidationTest, InvalidStorageClass_Struct) {
+ auto* s =
+ Structure("s", {Member("a", ty.atomic(Source{{12, 34}}, ty.i32()))});
+ Global("g", ty.Of(s), ast::StorageClass::kPrivate);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "error: atomic variables must have <storage> or <workgroup> "
+ "storage class\n"
+ "note: atomic sub-type of 's' is declared here");
+}
+
+TEST_F(ResolverAtomicValidationTest, InvalidStorageClass_StructOfStruct) {
+ // struct Inner { m : atomic<i32>; };
+ // struct Outer { m : array<Inner, 4>; };
+ // var<private> g : Outer;
+
+ auto* Inner =
+ Structure("Inner", {Member("m", ty.atomic(Source{{12, 34}}, ty.i32()))});
+ auto* Outer = Structure("Outer", {Member("m", ty.Of(Inner))});
+ Global("g", ty.Of(Outer), ast::StorageClass::kPrivate);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "error: atomic variables must have <storage> or <workgroup> "
+ "storage class\n"
+ "note: atomic sub-type of 'Outer' is declared here");
+}
+
+TEST_F(ResolverAtomicValidationTest,
+ InvalidStorageClass_StructOfStructOfArray) {
+ // struct Inner { m : array<atomic<i32>, 4>; };
+ // struct Outer { m : array<Inner, 4>; };
+ // var<private> g : Outer;
+
+ auto* Inner =
+ Structure("Inner", {Member(Source{{12, 34}}, "m", ty.atomic(ty.i32()))});
+ auto* Outer = Structure("Outer", {Member("m", ty.Of(Inner))});
+ Global("g", ty.Of(Outer), ast::StorageClass::kPrivate);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "error: atomic variables must have <storage> or <workgroup> "
+ "storage class\n"
+ "12:34 note: atomic sub-type of 'Outer' is declared here");
+}
+
+TEST_F(ResolverAtomicValidationTest, InvalidStorageClass_ArrayOfArray) {
+ // type AtomicArray = array<atomic<i32>, 5>;
+ // var<private> v: array<s, 5>;
+
+ auto* atomic_array = Alias(Source{{12, 34}}, "AtomicArray",
+ ty.atomic(Source{{12, 34}}, ty.i32()));
+ Global(Source{{56, 78}}, "v", ty.Of(atomic_array),
+ ast::StorageClass::kPrivate);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "error: atomic variables must have <storage> or <workgroup> "
+ "storage class");
+}
+
+TEST_F(ResolverAtomicValidationTest, InvalidStorageClass_ArrayOfStruct) {
+ // struct S{
+ // m: atomic<u32>;
+ // };
+ // var<private> v: array<S, 5>;
+
+ auto* s = Structure("S", {Member("m", ty.atomic<u32>())});
+ Global(Source{{56, 78}}, "v", ty.array(ty.Of(s), 5),
+ ast::StorageClass::kPrivate);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "error: atomic variables must have <storage> or <workgroup> "
+ "storage class\n"
+ "note: atomic sub-type of 'array<S, 5>' is declared here");
+}
+
+TEST_F(ResolverAtomicValidationTest, InvalidStorageClass_ArrayOfStructOfArray) {
+ // type AtomicArray = array<atomic<i32>, 5>;
+ // struct S{
+ // m: AtomicArray;
+ // };
+ // var<private> v: array<S, 5>;
+
+ auto* atomic_array = Alias(Source{{12, 34}}, "AtomicArray",
+ ty.atomic(Source{{12, 34}}, ty.i32()));
+ auto* s = Structure("S", {Member("m", ty.Of(atomic_array))});
+ Global(Source{{56, 78}}, "v", ty.array(ty.Of(s), 5),
+ ast::StorageClass::kPrivate);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "error: atomic variables must have <storage> or <workgroup> "
+ "storage class\n"
+ "note: atomic sub-type of 'array<S, 5>' is declared here");
+}
+
+TEST_F(ResolverAtomicValidationTest, InvalidStorageClass_Complex) {
+ // type AtomicArray = array<atomic<i32>, 5>;
+ // struct S6 { x: array<i32, 4>; };
+ // struct S5 { x: S6;
+ // y: AtomicArray;
+ // z: array<atomic<u32>, 8>; };
+ // struct S4 { x: S6;
+ // y: S5;
+ // z: array<atomic<i32>, 4>; };
+ // struct S3 { x: S4; };
+ // struct S2 { x: S3; };
+ // struct S1 { x: S2; };
+ // struct S0 { x: S1; };
+ // var<private> g : S0;
+
+ auto* atomic_array = Alias(Source{{12, 34}}, "AtomicArray",
+ ty.atomic(Source{{12, 34}}, ty.i32()));
+ auto* array_i32_4 = ty.array(ty.i32(), 4);
+ auto* array_atomic_u32_8 = ty.array(ty.atomic(ty.u32()), 8);
+ auto* array_atomic_i32_4 = ty.array(ty.atomic(ty.i32()), 4);
+
+ auto* s6 = Structure("S6", {Member("x", array_i32_4)});
+ auto* s5 = Structure("S5", {Member("x", ty.Of(s6)), //
+ Member("y", ty.Of(atomic_array)), //
+ Member("z", array_atomic_u32_8)}); //
+ auto* s4 = Structure("S4", {Member("x", ty.Of(s6)), //
+ Member("y", ty.Of(s5)), //
+ Member("z", array_atomic_i32_4)}); //
+ auto* s3 = Structure("S3", {Member("x", ty.Of(s4))});
+ auto* s2 = Structure("S2", {Member("x", ty.Of(s3))});
+ auto* s1 = Structure("S1", {Member("x", ty.Of(s2))});
+ auto* s0 = Structure("S0", {Member("x", ty.Of(s1))});
+ Global(Source{{56, 78}}, "g", ty.Of(s0), ast::StorageClass::kPrivate);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "error: atomic variables must have <storage> or <workgroup> "
+ "storage class\n"
+ "note: atomic sub-type of 'S0' is declared here");
+}
+
+TEST_F(ResolverAtomicValidationTest, Struct_AccessMode_Read) {
+ auto* s = Structure("s", {Member("a", ty.atomic(Source{{12, 34}}, ty.i32()))},
+ {StructBlock()});
+ Global(Source{{56, 78}}, "g", ty.Of(s), ast::StorageClass::kStorage,
+ ast::Access::kRead, GroupAndBinding(0, 0));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "error: atomic variables in <storage> storage class must have read_write "
+ "access mode\n"
+ "note: atomic sub-type of 's' is declared here");
+}
+
+TEST_F(ResolverAtomicValidationTest, InvalidAccessMode_Struct) {
+ auto* s = Structure("s", {Member("a", ty.atomic(Source{{12, 34}}, ty.i32()))},
+ {StructBlock()});
+ Global(Source{{56, 78}}, "g", ty.Of(s), ast::StorageClass::kStorage,
+ ast::Access::kRead, GroupAndBinding(0, 0));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "error: atomic variables in <storage> storage class must have read_write "
+ "access mode\n"
+ "note: atomic sub-type of 's' is declared here");
+}
+
+TEST_F(ResolverAtomicValidationTest, InvalidAccessMode_StructOfStruct) {
+ // struct Inner { m : atomic<i32>; };
+ // struct Outer { m : array<Inner, 4>; };
+ // var<storage, read> g : Outer;
+
+ auto* Inner =
+ Structure("Inner", {Member("m", ty.atomic(Source{{12, 34}}, ty.i32()))});
+ auto* Outer =
+ Structure("Outer", {Member("m", ty.Of(Inner))}, {StructBlock()});
+ Global(Source{{56, 78}}, "g", ty.Of(Outer), ast::StorageClass::kStorage,
+ ast::Access::kRead, GroupAndBinding(0, 0));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "error: atomic variables in <storage> storage class must have read_write "
+ "access mode\n"
+ "note: atomic sub-type of 'Outer' is declared here");
+}
+
+TEST_F(ResolverAtomicValidationTest, InvalidAccessMode_StructOfStructOfArray) {
+ // struct Inner { m : array<atomic<i32>, 4>; };
+ // struct Outer { m : array<Inner, 4>; };
+ // var<storage, read> g : Outer;
+
+ auto* Inner =
+ Structure("Inner", {Member(Source{{12, 34}}, "m", ty.atomic(ty.i32()))});
+ auto* Outer =
+ Structure("Outer", {Member("m", ty.Of(Inner))}, {StructBlock()});
+ Global(Source{{56, 78}}, "g", ty.Of(Outer), ast::StorageClass::kStorage,
+ ast::Access::kRead, GroupAndBinding(0, 0));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "error: atomic variables in <storage> storage class must have "
+ "read_write access mode\n"
+ "12:34 note: atomic sub-type of 'Outer' is declared here");
+}
+
+TEST_F(ResolverAtomicValidationTest, InvalidAccessMode_Complex) {
+ // type AtomicArray = array<atomic<i32>, 5>;
+ // struct S6 { x: array<i32, 4>; };
+ // struct S5 { x: S6;
+ // y: AtomicArray;
+ // z: array<atomic<u32>, 8>; };
+ // struct S4 { x: S6;
+ // y: S5;
+ // z: array<atomic<i32>, 4>; };
+ // struct S3 { x: S4; };
+ // struct S2 { x: S3; };
+ // struct S1 { x: S2; };
+ // struct S0 { x: S1; };
+ // var<storage, read> g : S0;
+
+ auto* atomic_array = Alias(Source{{12, 34}}, "AtomicArray",
+ ty.atomic(Source{{12, 34}}, ty.i32()));
+ auto* array_i32_4 = ty.array(ty.i32(), 4);
+ auto* array_atomic_u32_8 = ty.array(ty.atomic(ty.u32()), 8);
+ auto* array_atomic_i32_4 = ty.array(ty.atomic(ty.i32()), 4);
+
+ auto* s6 = Structure("S6", {Member("x", array_i32_4)});
+ auto* s5 = Structure("S5", {Member("x", ty.Of(s6)), //
+ Member("y", ty.Of(atomic_array)), //
+ Member("z", array_atomic_u32_8)}); //
+ auto* s4 = Structure("S4", {Member("x", ty.Of(s6)), //
+ Member("y", ty.Of(s5)), //
+ Member("z", array_atomic_i32_4)}); //
+ auto* s3 = Structure("S3", {Member("x", ty.Of(s4))});
+ auto* s2 = Structure("S2", {Member("x", ty.Of(s3))});
+ auto* s1 = Structure("S1", {Member("x", ty.Of(s2))});
+ auto* s0 = Structure("S0", {Member("x", ty.Of(s1))}, {StructBlock()});
+ Global(Source{{56, 78}}, "g", ty.Of(s0), ast::StorageClass::kStorage,
+ ast::Access::kRead, GroupAndBinding(0, 0));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "error: atomic variables in <storage> storage class must have "
+ "read_write access mode\n"
+ "note: atomic sub-type of 'S0' is declared here");
+}
+
+TEST_F(ResolverAtomicValidationTest, Local) {
+ WrapInFunction(Var("a", ty.atomic(Source{{12, 34}}, ty.i32())));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: function variable must have a constructible type");
+}
+
+} // namespace
+} // namespace resolver
+} // namespace tint
diff --git a/src/tint/resolver/attribute_validation_test.cc b/src/tint/resolver/attribute_validation_test.cc
new file mode 100644
index 0000000..7231596
--- /dev/null
+++ b/src/tint/resolver/attribute_validation_test.cc
@@ -0,0 +1,1404 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ast/disable_validation_attribute.h"
+#include "src/tint/resolver/resolver.h"
+#include "src/tint/resolver/resolver_test_helper.h"
+
+#include "gmock/gmock.h"
+
+namespace tint {
+namespace resolver {
+
+// Helpers and typedefs
+template <typename T>
+using DataType = builder::DataType<T>;
+template <typename T>
+using vec2 = builder::vec2<T>;
+template <typename T>
+using vec3 = builder::vec3<T>;
+template <typename T>
+using vec4 = builder::vec4<T>;
+template <typename T>
+using mat2x2 = builder::mat2x2<T>;
+template <typename T>
+using mat3x3 = builder::mat3x3<T>;
+template <typename T>
+using mat4x4 = builder::mat4x4<T>;
+template <typename T, int ID = 0>
+using alias = builder::alias<T, ID>;
+template <typename T>
+using alias1 = builder::alias1<T>;
+template <typename T>
+using alias2 = builder::alias2<T>;
+template <typename T>
+using alias3 = builder::alias3<T>;
+using f32 = builder::f32;
+using i32 = builder::i32;
+using u32 = builder::u32;
+
+namespace AttributeTests {
+namespace {
+enum class AttributeKind {
+ kAlign,
+ kBinding,
+ kBuiltin,
+ kGroup,
+ kId,
+ kInterpolate,
+ kInvariant,
+ kLocation,
+ kOffset,
+ kSize,
+ kStage,
+ kStride,
+ kStructBlock,
+ kWorkgroup,
+
+ kBindingAndGroup,
+};
+
+static bool IsBindingAttribute(AttributeKind kind) {
+ switch (kind) {
+ case AttributeKind::kBinding:
+ case AttributeKind::kGroup:
+ case AttributeKind::kBindingAndGroup:
+ return true;
+ default:
+ return false;
+ }
+}
+
+struct TestParams {
+ AttributeKind kind;
+ bool should_pass;
+};
+struct TestWithParams : ResolverTestWithParam<TestParams> {};
+
+static ast::AttributeList createAttributes(const Source& source,
+ ProgramBuilder& builder,
+ AttributeKind kind) {
+ switch (kind) {
+ case AttributeKind::kAlign:
+ return {builder.create<ast::StructMemberAlignAttribute>(source, 4u)};
+ case AttributeKind::kBinding:
+ return {builder.create<ast::BindingAttribute>(source, 1u)};
+ case AttributeKind::kBuiltin:
+ return {builder.Builtin(source, ast::Builtin::kPosition)};
+ case AttributeKind::kGroup:
+ return {builder.create<ast::GroupAttribute>(source, 1u)};
+ case AttributeKind::kId:
+ return {builder.create<ast::IdAttribute>(source, 0u)};
+ case AttributeKind::kInterpolate:
+ return {builder.Interpolate(source, ast::InterpolationType::kLinear,
+ ast::InterpolationSampling::kCenter)};
+ case AttributeKind::kInvariant:
+ return {builder.Invariant(source)};
+ case AttributeKind::kLocation:
+ return {builder.Location(source, 1)};
+ case AttributeKind::kOffset:
+ return {builder.create<ast::StructMemberOffsetAttribute>(source, 4u)};
+ case AttributeKind::kSize:
+ return {builder.create<ast::StructMemberSizeAttribute>(source, 16u)};
+ case AttributeKind::kStage:
+ return {builder.Stage(source, ast::PipelineStage::kCompute)};
+ case AttributeKind::kStride:
+ return {builder.create<ast::StrideAttribute>(source, 4u)};
+ case AttributeKind::kStructBlock:
+ return {builder.create<ast::StructBlockAttribute>(source)};
+ case AttributeKind::kWorkgroup:
+ return {builder.create<ast::WorkgroupAttribute>(source, builder.Expr(1))};
+ case AttributeKind::kBindingAndGroup:
+ return {builder.create<ast::BindingAttribute>(source, 1u),
+ builder.create<ast::GroupAttribute>(source, 1u)};
+ }
+ return {};
+}
+
+namespace FunctionInputAndOutputTests {
+using FunctionParameterAttributeTest = TestWithParams;
+TEST_P(FunctionParameterAttributeTest, IsValid) {
+ auto& params = GetParam();
+
+ Func("main",
+ ast::VariableList{Param("a", ty.vec4<f32>(),
+ createAttributes({}, *this, params.kind))},
+ ty.void_(), {});
+
+ if (params.should_pass) {
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ } else {
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "error: attribute is not valid for non-entry point function "
+ "parameters");
+ }
+}
+INSTANTIATE_TEST_SUITE_P(
+ ResolverAttributeValidationTest,
+ FunctionParameterAttributeTest,
+ testing::Values(TestParams{AttributeKind::kAlign, false},
+ TestParams{AttributeKind::kBinding, false},
+ TestParams{AttributeKind::kBuiltin, false},
+ TestParams{AttributeKind::kGroup, false},
+ TestParams{AttributeKind::kId, false},
+ TestParams{AttributeKind::kInterpolate, false},
+ TestParams{AttributeKind::kInvariant, false},
+ TestParams{AttributeKind::kLocation, false},
+ TestParams{AttributeKind::kOffset, false},
+ TestParams{AttributeKind::kSize, false},
+ TestParams{AttributeKind::kStage, false},
+ TestParams{AttributeKind::kStride, false},
+ TestParams{AttributeKind::kStructBlock, false},
+ TestParams{AttributeKind::kWorkgroup, false},
+ TestParams{AttributeKind::kBindingAndGroup, false}));
+
+using FunctionReturnTypeAttributeTest = TestWithParams;
+TEST_P(FunctionReturnTypeAttributeTest, IsValid) {
+ auto& params = GetParam();
+
+ Func("main", ast::VariableList{}, ty.f32(), ast::StatementList{Return(1.f)},
+ {}, createAttributes({}, *this, params.kind));
+
+ if (params.should_pass) {
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ } else {
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "error: attribute is not valid for non-entry point function "
+ "return types");
+ }
+}
+INSTANTIATE_TEST_SUITE_P(
+ ResolverAttributeValidationTest,
+ FunctionReturnTypeAttributeTest,
+ testing::Values(TestParams{AttributeKind::kAlign, false},
+ TestParams{AttributeKind::kBinding, false},
+ TestParams{AttributeKind::kBuiltin, false},
+ TestParams{AttributeKind::kGroup, false},
+ TestParams{AttributeKind::kId, false},
+ TestParams{AttributeKind::kInterpolate, false},
+ TestParams{AttributeKind::kInvariant, false},
+ TestParams{AttributeKind::kLocation, false},
+ TestParams{AttributeKind::kOffset, false},
+ TestParams{AttributeKind::kSize, false},
+ TestParams{AttributeKind::kStage, false},
+ TestParams{AttributeKind::kStride, false},
+ TestParams{AttributeKind::kStructBlock, false},
+ TestParams{AttributeKind::kWorkgroup, false},
+ TestParams{AttributeKind::kBindingAndGroup, false}));
+} // namespace FunctionInputAndOutputTests
+
+namespace EntryPointInputAndOutputTests {
+using ComputeShaderParameterAttributeTest = TestWithParams;
+TEST_P(ComputeShaderParameterAttributeTest, IsValid) {
+ auto& params = GetParam();
+ auto* p = Param("a", ty.vec4<f32>(),
+ createAttributes(Source{{12, 34}}, *this, params.kind));
+ Func("main", ast::VariableList{p}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute), WorkgroupSize(1)});
+
+ if (params.should_pass) {
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ } else {
+ EXPECT_FALSE(r()->Resolve());
+ if (params.kind == AttributeKind::kBuiltin) {
+ EXPECT_EQ(r()->error(),
+ "12:34 error: builtin(position) cannot be used in input of "
+ "compute pipeline stage");
+ } else if (params.kind == AttributeKind::kInterpolate ||
+ params.kind == AttributeKind::kLocation ||
+ params.kind == AttributeKind::kInvariant) {
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: attribute is not valid for compute shader inputs");
+ } else {
+ EXPECT_EQ(r()->error(),
+ "12:34 error: attribute is not valid for function parameters");
+ }
+ }
+}
+INSTANTIATE_TEST_SUITE_P(
+ ResolverAttributeValidationTest,
+ ComputeShaderParameterAttributeTest,
+ testing::Values(TestParams{AttributeKind::kAlign, false},
+ TestParams{AttributeKind::kBinding, false},
+ TestParams{AttributeKind::kBuiltin, false},
+ TestParams{AttributeKind::kGroup, false},
+ TestParams{AttributeKind::kId, false},
+ TestParams{AttributeKind::kInterpolate, false},
+ TestParams{AttributeKind::kInvariant, false},
+ TestParams{AttributeKind::kLocation, false},
+ TestParams{AttributeKind::kOffset, false},
+ TestParams{AttributeKind::kSize, false},
+ TestParams{AttributeKind::kStage, false},
+ TestParams{AttributeKind::kStride, false},
+ TestParams{AttributeKind::kStructBlock, false},
+ TestParams{AttributeKind::kWorkgroup, false},
+ TestParams{AttributeKind::kBindingAndGroup, false}));
+
+using FragmentShaderParameterAttributeTest = TestWithParams;
+TEST_P(FragmentShaderParameterAttributeTest, IsValid) {
+ auto& params = GetParam();
+ auto attrs = createAttributes(Source{{12, 34}}, *this, params.kind);
+ if (params.kind != AttributeKind::kBuiltin &&
+ params.kind != AttributeKind::kLocation) {
+ attrs.push_back(Builtin(Source{{34, 56}}, ast::Builtin::kPosition));
+ }
+ auto* p = Param("a", ty.vec4<f32>(), attrs);
+ Func("frag_main", {p}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kFragment)});
+
+ if (params.should_pass) {
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ } else {
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: attribute is not valid for function parameters");
+ }
+}
+INSTANTIATE_TEST_SUITE_P(
+ ResolverAttributeValidationTest,
+ FragmentShaderParameterAttributeTest,
+ testing::Values(TestParams{AttributeKind::kAlign, false},
+ TestParams{AttributeKind::kBinding, false},
+ TestParams{AttributeKind::kBuiltin, true},
+ TestParams{AttributeKind::kGroup, false},
+ TestParams{AttributeKind::kId, false},
+ // kInterpolate tested separately (requires [[location]])
+ TestParams{AttributeKind::kInvariant, true},
+ TestParams{AttributeKind::kLocation, true},
+ TestParams{AttributeKind::kOffset, false},
+ TestParams{AttributeKind::kSize, false},
+ TestParams{AttributeKind::kStage, false},
+ TestParams{AttributeKind::kStride, false},
+ TestParams{AttributeKind::kStructBlock, false},
+ TestParams{AttributeKind::kWorkgroup, false},
+ TestParams{AttributeKind::kBindingAndGroup, false}));
+
+using VertexShaderParameterAttributeTest = TestWithParams;
+TEST_P(VertexShaderParameterAttributeTest, IsValid) {
+ auto& params = GetParam();
+ auto attrs = createAttributes(Source{{12, 34}}, *this, params.kind);
+ if (params.kind != AttributeKind::kLocation) {
+ attrs.push_back(Location(Source{{34, 56}}, 2));
+ }
+ auto* p = Param("a", ty.vec4<f32>(), attrs);
+ Func("vertex_main", ast::VariableList{p}, ty.vec4<f32>(),
+ {Return(Construct(ty.vec4<f32>()))},
+ {Stage(ast::PipelineStage::kVertex)},
+ {Builtin(ast::Builtin::kPosition)});
+
+ if (params.should_pass) {
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ } else {
+ EXPECT_FALSE(r()->Resolve());
+ if (params.kind == AttributeKind::kBuiltin) {
+ EXPECT_EQ(r()->error(),
+ "12:34 error: builtin(position) cannot be used in input of "
+ "vertex pipeline stage");
+ } else if (params.kind == AttributeKind::kInvariant) {
+ EXPECT_EQ(r()->error(),
+ "12:34 error: invariant attribute must only be applied to a "
+ "position builtin");
+ } else {
+ EXPECT_EQ(r()->error(),
+ "12:34 error: attribute is not valid for function parameters");
+ }
+ }
+}
+INSTANTIATE_TEST_SUITE_P(
+ ResolverAttributeValidationTest,
+ VertexShaderParameterAttributeTest,
+ testing::Values(TestParams{AttributeKind::kAlign, false},
+ TestParams{AttributeKind::kBinding, false},
+ TestParams{AttributeKind::kBuiltin, false},
+ TestParams{AttributeKind::kGroup, false},
+ TestParams{AttributeKind::kId, false},
+ TestParams{AttributeKind::kInterpolate, true},
+ TestParams{AttributeKind::kInvariant, false},
+ TestParams{AttributeKind::kLocation, true},
+ TestParams{AttributeKind::kOffset, false},
+ TestParams{AttributeKind::kSize, false},
+ TestParams{AttributeKind::kStage, false},
+ TestParams{AttributeKind::kStride, false},
+ TestParams{AttributeKind::kStructBlock, false},
+ TestParams{AttributeKind::kWorkgroup, false},
+ TestParams{AttributeKind::kBindingAndGroup, false}));
+
+using ComputeShaderReturnTypeAttributeTest = TestWithParams;
+TEST_P(ComputeShaderReturnTypeAttributeTest, IsValid) {
+ auto& params = GetParam();
+ Func("main", ast::VariableList{}, ty.vec4<f32>(),
+ {Return(Construct(ty.vec4<f32>(), 1.f))},
+ {Stage(ast::PipelineStage::kCompute), WorkgroupSize(1)},
+ createAttributes(Source{{12, 34}}, *this, params.kind));
+
+ if (params.should_pass) {
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ } else {
+ EXPECT_FALSE(r()->Resolve());
+ if (params.kind == AttributeKind::kBuiltin) {
+ EXPECT_EQ(r()->error(),
+ "12:34 error: builtin(position) cannot be used in output of "
+ "compute pipeline stage");
+ } else if (params.kind == AttributeKind::kInterpolate ||
+ params.kind == AttributeKind::kLocation ||
+ params.kind == AttributeKind::kInvariant) {
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: attribute is not valid for compute shader output");
+ } else {
+ EXPECT_EQ(r()->error(),
+ "12:34 error: attribute is not valid for entry point return "
+ "types");
+ }
+ }
+}
+INSTANTIATE_TEST_SUITE_P(
+ ResolverAttributeValidationTest,
+ ComputeShaderReturnTypeAttributeTest,
+ testing::Values(TestParams{AttributeKind::kAlign, false},
+ TestParams{AttributeKind::kBinding, false},
+ TestParams{AttributeKind::kBuiltin, false},
+ TestParams{AttributeKind::kGroup, false},
+ TestParams{AttributeKind::kId, false},
+ TestParams{AttributeKind::kInterpolate, false},
+ TestParams{AttributeKind::kInvariant, false},
+ TestParams{AttributeKind::kLocation, false},
+ TestParams{AttributeKind::kOffset, false},
+ TestParams{AttributeKind::kSize, false},
+ TestParams{AttributeKind::kStage, false},
+ TestParams{AttributeKind::kStride, false},
+ TestParams{AttributeKind::kStructBlock, false},
+ TestParams{AttributeKind::kWorkgroup, false},
+ TestParams{AttributeKind::kBindingAndGroup, false}));
+
+using FragmentShaderReturnTypeAttributeTest = TestWithParams;
+TEST_P(FragmentShaderReturnTypeAttributeTest, IsValid) {
+ auto& params = GetParam();
+ auto attrs = createAttributes(Source{{12, 34}}, *this, params.kind);
+ attrs.push_back(Location(Source{{34, 56}}, 2));
+ Func("frag_main", {}, ty.vec4<f32>(), {Return(Construct(ty.vec4<f32>()))},
+ {Stage(ast::PipelineStage::kFragment)}, attrs);
+
+ if (params.should_pass) {
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ } else {
+ EXPECT_FALSE(r()->Resolve());
+ if (params.kind == AttributeKind::kBuiltin) {
+ EXPECT_EQ(r()->error(),
+ "12:34 error: builtin(position) cannot be used in output of "
+ "fragment pipeline stage");
+ } else if (params.kind == AttributeKind::kInvariant) {
+ EXPECT_EQ(r()->error(),
+ "12:34 error: invariant attribute must only be applied to a "
+ "position builtin");
+ } else if (params.kind == AttributeKind::kLocation) {
+ EXPECT_EQ(r()->error(),
+ "34:56 error: duplicate location attribute\n"
+ "12:34 note: first attribute declared here");
+ } else {
+ EXPECT_EQ(r()->error(),
+ "12:34 error: attribute is not valid for entry point return "
+ "types");
+ }
+ }
+}
+INSTANTIATE_TEST_SUITE_P(
+ ResolverAttributeValidationTest,
+ FragmentShaderReturnTypeAttributeTest,
+ testing::Values(TestParams{AttributeKind::kAlign, false},
+ TestParams{AttributeKind::kBinding, false},
+ TestParams{AttributeKind::kBuiltin, false},
+ TestParams{AttributeKind::kGroup, false},
+ TestParams{AttributeKind::kId, false},
+ TestParams{AttributeKind::kInterpolate, true},
+ TestParams{AttributeKind::kInvariant, false},
+ TestParams{AttributeKind::kLocation, false},
+ TestParams{AttributeKind::kOffset, false},
+ TestParams{AttributeKind::kSize, false},
+ TestParams{AttributeKind::kStage, false},
+ TestParams{AttributeKind::kStride, false},
+ TestParams{AttributeKind::kStructBlock, false},
+ TestParams{AttributeKind::kWorkgroup, false},
+ TestParams{AttributeKind::kBindingAndGroup, false}));
+
+using VertexShaderReturnTypeAttributeTest = TestWithParams;
+TEST_P(VertexShaderReturnTypeAttributeTest, IsValid) {
+ auto& params = GetParam();
+ auto attrs = createAttributes(Source{{12, 34}}, *this, params.kind);
+ // a vertex shader must include the 'position' builtin in its return type
+ if (params.kind != AttributeKind::kBuiltin) {
+ attrs.push_back(Builtin(Source{{34, 56}}, ast::Builtin::kPosition));
+ }
+ Func("vertex_main", ast::VariableList{}, ty.vec4<f32>(),
+ {Return(Construct(ty.vec4<f32>()))},
+ {Stage(ast::PipelineStage::kVertex)}, attrs);
+
+ if (params.should_pass) {
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ } else {
+ EXPECT_FALSE(r()->Resolve());
+ if (params.kind == AttributeKind::kLocation) {
+ EXPECT_EQ(r()->error(),
+ "34:56 error: multiple entry point IO attributes\n"
+ "12:34 note: previously consumed location(1)");
+ } else {
+ EXPECT_EQ(r()->error(),
+ "12:34 error: attribute is not valid for entry point return "
+ "types");
+ }
+ }
+}
+INSTANTIATE_TEST_SUITE_P(
+ ResolverAttributeValidationTest,
+ VertexShaderReturnTypeAttributeTest,
+ testing::Values(TestParams{AttributeKind::kAlign, false},
+ TestParams{AttributeKind::kBinding, false},
+ TestParams{AttributeKind::kBuiltin, true},
+ TestParams{AttributeKind::kGroup, false},
+ TestParams{AttributeKind::kId, false},
+ // kInterpolate tested separately (requires [[location]])
+ TestParams{AttributeKind::kInvariant, true},
+ TestParams{AttributeKind::kLocation, false},
+ TestParams{AttributeKind::kOffset, false},
+ TestParams{AttributeKind::kSize, false},
+ TestParams{AttributeKind::kStage, false},
+ TestParams{AttributeKind::kStride, false},
+ TestParams{AttributeKind::kStructBlock, false},
+ TestParams{AttributeKind::kWorkgroup, false},
+ TestParams{AttributeKind::kBindingAndGroup, false}));
+
+using EntryPointParameterAttributeTest = TestWithParams;
+TEST_F(EntryPointParameterAttributeTest, DuplicateAttribute) {
+ Func("main", ast::VariableList{}, ty.f32(), ast::StatementList{Return(1.f)},
+ {Stage(ast::PipelineStage::kFragment)},
+ {
+ Location(Source{{12, 34}}, 2),
+ Location(Source{{56, 78}}, 3),
+ });
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(56:78 error: duplicate location attribute
+12:34 note: first attribute declared here)");
+}
+
+TEST_F(EntryPointParameterAttributeTest, DuplicateInternalAttribute) {
+ auto* s = Param("s", ty.sampler(ast::SamplerKind::kSampler),
+ ast::AttributeList{
+ create<ast::BindingAttribute>(0),
+ create<ast::GroupAttribute>(0),
+ Disable(ast::DisabledValidation::kBindingPointCollision),
+ Disable(ast::DisabledValidation::kEntryPointParameter),
+ });
+ Func("f", {s}, ty.void_(), {}, {Stage(ast::PipelineStage::kFragment)});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+using EntryPointReturnTypeAttributeTest = ResolverTest;
+TEST_F(EntryPointReturnTypeAttributeTest, DuplicateAttribute) {
+ Func("main", ast::VariableList{}, ty.f32(), ast::StatementList{Return(1.f)},
+ ast::AttributeList{Stage(ast::PipelineStage::kFragment)},
+ ast::AttributeList{
+ Location(Source{{12, 34}}, 2),
+ Location(Source{{56, 78}}, 3),
+ });
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(56:78 error: duplicate location attribute
+12:34 note: first attribute declared here)");
+}
+
+TEST_F(EntryPointReturnTypeAttributeTest, DuplicateInternalAttribute) {
+ Func("f", {}, ty.i32(), {Return(1)}, {Stage(ast::PipelineStage::kFragment)},
+ ast::AttributeList{
+ Disable(ast::DisabledValidation::kBindingPointCollision),
+ Disable(ast::DisabledValidation::kEntryPointParameter),
+ });
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+} // namespace EntryPointInputAndOutputTests
+
+namespace StructAndStructMemberTests {
+using StructAttributeTest = TestWithParams;
+TEST_P(StructAttributeTest, IsValid) {
+ auto& params = GetParam();
+
+ Structure("mystruct", {Member("a", ty.f32())},
+ createAttributes(Source{{12, 34}}, *this, params.kind));
+
+ WrapInFunction();
+
+ if (params.should_pass) {
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ } else {
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: attribute is not valid for struct declarations");
+ }
+}
+INSTANTIATE_TEST_SUITE_P(
+ ResolverAttributeValidationTest,
+ StructAttributeTest,
+ testing::Values(TestParams{AttributeKind::kAlign, false},
+ TestParams{AttributeKind::kBinding, false},
+ TestParams{AttributeKind::kBuiltin, false},
+ TestParams{AttributeKind::kGroup, false},
+ TestParams{AttributeKind::kId, false},
+ TestParams{AttributeKind::kInterpolate, false},
+ TestParams{AttributeKind::kInvariant, false},
+ TestParams{AttributeKind::kLocation, false},
+ TestParams{AttributeKind::kOffset, false},
+ TestParams{AttributeKind::kSize, false},
+ TestParams{AttributeKind::kStage, false},
+ TestParams{AttributeKind::kStride, false},
+ TestParams{AttributeKind::kStructBlock, true},
+ TestParams{AttributeKind::kWorkgroup, false},
+ TestParams{AttributeKind::kBindingAndGroup, false}));
+
+TEST_F(StructAttributeTest, DuplicateAttribute) {
+ Structure("mystruct",
+ {
+ Member("a", ty.i32()),
+ },
+ {
+ create<ast::StructBlockAttribute>(Source{{12, 34}}),
+ create<ast::StructBlockAttribute>(Source{{56, 78}}),
+ });
+ WrapInFunction();
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(56:78 error: duplicate block attribute
+12:34 note: first attribute declared here)");
+}
+using StructMemberAttributeTest = TestWithParams;
+TEST_P(StructMemberAttributeTest, IsValid) {
+ auto& params = GetParam();
+ ast::StructMemberList members;
+ if (params.kind == AttributeKind::kBuiltin) {
+ members.push_back(
+ {Member("a", ty.vec4<f32>(),
+ createAttributes(Source{{12, 34}}, *this, params.kind))});
+ } else {
+ members.push_back(
+ {Member("a", ty.f32(),
+ createAttributes(Source{{12, 34}}, *this, params.kind))});
+ }
+ Structure("mystruct", members);
+ WrapInFunction();
+ if (params.should_pass) {
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ } else {
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: attribute is not valid for structure members");
+ }
+}
+INSTANTIATE_TEST_SUITE_P(
+ ResolverAttributeValidationTest,
+ StructMemberAttributeTest,
+ testing::Values(TestParams{AttributeKind::kAlign, true},
+ TestParams{AttributeKind::kBinding, false},
+ TestParams{AttributeKind::kBuiltin, true},
+ TestParams{AttributeKind::kGroup, false},
+ TestParams{AttributeKind::kId, false},
+ // kInterpolate tested separately (requires [[location]])
+ // kInvariant tested separately (requires position builtin)
+ TestParams{AttributeKind::kLocation, true},
+ TestParams{AttributeKind::kOffset, true},
+ TestParams{AttributeKind::kSize, true},
+ TestParams{AttributeKind::kStage, false},
+ TestParams{AttributeKind::kStride, false},
+ TestParams{AttributeKind::kStructBlock, false},
+ TestParams{AttributeKind::kWorkgroup, false},
+ TestParams{AttributeKind::kBindingAndGroup, false}));
+TEST_F(StructMemberAttributeTest, DuplicateAttribute) {
+ Structure(
+ "mystruct",
+ {
+ Member(
+ "a", ty.i32(),
+ {
+ create<ast::StructMemberAlignAttribute>(Source{{12, 34}}, 4u),
+ create<ast::StructMemberAlignAttribute>(Source{{56, 78}}, 8u),
+ }),
+ });
+ WrapInFunction();
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(56:78 error: duplicate align attribute
+12:34 note: first attribute declared here)");
+}
+TEST_F(StructMemberAttributeTest, InvariantAttributeWithPosition) {
+ Structure("mystruct", {
+ Member("a", ty.vec4<f32>(),
+ {
+ Invariant(),
+ Builtin(ast::Builtin::kPosition),
+ }),
+ });
+ WrapInFunction();
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+TEST_F(StructMemberAttributeTest, InvariantAttributeWithoutPosition) {
+ Structure("mystruct", {
+ Member("a", ty.vec4<f32>(),
+ {
+ Invariant(Source{{12, 34}}),
+ }),
+ });
+ WrapInFunction();
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: invariant attribute must only be applied to a "
+ "position builtin");
+}
+
+} // namespace StructAndStructMemberTests
+
+using ArrayAttributeTest = TestWithParams;
+TEST_P(ArrayAttributeTest, IsValid) {
+ auto& params = GetParam();
+
+ auto* arr = ty.array(ty.f32(), nullptr,
+ createAttributes(Source{{12, 34}}, *this, params.kind));
+ Structure("mystruct",
+ {
+ Member("a", arr),
+ },
+ {create<ast::StructBlockAttribute>()});
+
+ WrapInFunction();
+
+ if (params.should_pass) {
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ } else {
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: attribute is not valid for array types");
+ }
+}
+INSTANTIATE_TEST_SUITE_P(
+ ResolverAttributeValidationTest,
+ ArrayAttributeTest,
+ testing::Values(TestParams{AttributeKind::kAlign, false},
+ TestParams{AttributeKind::kBinding, false},
+ TestParams{AttributeKind::kBuiltin, false},
+ TestParams{AttributeKind::kGroup, false},
+ TestParams{AttributeKind::kId, false},
+ TestParams{AttributeKind::kInterpolate, false},
+ TestParams{AttributeKind::kInvariant, false},
+ TestParams{AttributeKind::kLocation, false},
+ TestParams{AttributeKind::kOffset, false},
+ TestParams{AttributeKind::kSize, false},
+ TestParams{AttributeKind::kStage, false},
+ TestParams{AttributeKind::kStride, true},
+ TestParams{AttributeKind::kStructBlock, false},
+ TestParams{AttributeKind::kWorkgroup, false},
+ TestParams{AttributeKind::kBindingAndGroup, false}));
+
+using VariableAttributeTest = TestWithParams;
+TEST_P(VariableAttributeTest, IsValid) {
+ auto& params = GetParam();
+
+ if (IsBindingAttribute(params.kind)) {
+ Global("a", ty.sampler(ast::SamplerKind::kSampler),
+ ast::StorageClass::kNone, nullptr,
+ createAttributes(Source{{12, 34}}, *this, params.kind));
+ } else {
+ Global("a", ty.f32(), ast::StorageClass::kPrivate, nullptr,
+ createAttributes(Source{{12, 34}}, *this, params.kind));
+ }
+
+ WrapInFunction();
+
+ if (params.should_pass) {
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ } else {
+ EXPECT_FALSE(r()->Resolve());
+ if (!IsBindingAttribute(params.kind)) {
+ EXPECT_EQ(r()->error(),
+ "12:34 error: attribute is not valid for variables");
+ }
+ }
+}
+INSTANTIATE_TEST_SUITE_P(
+ ResolverAttributeValidationTest,
+ VariableAttributeTest,
+ testing::Values(TestParams{AttributeKind::kAlign, false},
+ TestParams{AttributeKind::kBinding, false},
+ TestParams{AttributeKind::kBuiltin, false},
+ TestParams{AttributeKind::kGroup, false},
+ TestParams{AttributeKind::kId, false},
+ TestParams{AttributeKind::kInterpolate, false},
+ TestParams{AttributeKind::kInvariant, false},
+ TestParams{AttributeKind::kLocation, false},
+ TestParams{AttributeKind::kOffset, false},
+ TestParams{AttributeKind::kSize, false},
+ TestParams{AttributeKind::kStage, false},
+ TestParams{AttributeKind::kStride, false},
+ TestParams{AttributeKind::kStructBlock, false},
+ TestParams{AttributeKind::kWorkgroup, false},
+ TestParams{AttributeKind::kBindingAndGroup, true}));
+
+TEST_F(VariableAttributeTest, DuplicateAttribute) {
+ Global("a", ty.sampler(ast::SamplerKind::kSampler),
+ ast::AttributeList{
+ create<ast::BindingAttribute>(Source{{12, 34}}, 2),
+ create<ast::GroupAttribute>(2),
+ create<ast::BindingAttribute>(Source{{56, 78}}, 3),
+ });
+
+ WrapInFunction();
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(56:78 error: duplicate binding attribute
+12:34 note: first attribute declared here)");
+}
+
+TEST_F(VariableAttributeTest, LocalVariable) {
+ auto* v = Var("a", ty.f32(),
+ ast::AttributeList{
+ create<ast::BindingAttribute>(Source{{12, 34}}, 2),
+ });
+
+ WrapInFunction(v);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: attributes are not valid on local variables");
+}
+
+using ConstantAttributeTest = TestWithParams;
+TEST_P(ConstantAttributeTest, IsValid) {
+ auto& params = GetParam();
+
+ GlobalConst("a", ty.f32(), Expr(1.23f),
+ createAttributes(Source{{12, 34}}, *this, params.kind));
+
+ WrapInFunction();
+
+ if (params.should_pass) {
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ } else {
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: attribute is not valid for constants");
+ }
+}
+INSTANTIATE_TEST_SUITE_P(
+ ResolverAttributeValidationTest,
+ ConstantAttributeTest,
+ testing::Values(TestParams{AttributeKind::kAlign, false},
+ TestParams{AttributeKind::kBinding, false},
+ TestParams{AttributeKind::kBuiltin, false},
+ TestParams{AttributeKind::kGroup, false},
+ TestParams{AttributeKind::kId, true},
+ TestParams{AttributeKind::kInterpolate, false},
+ TestParams{AttributeKind::kInvariant, false},
+ TestParams{AttributeKind::kLocation, false},
+ TestParams{AttributeKind::kOffset, false},
+ TestParams{AttributeKind::kSize, false},
+ TestParams{AttributeKind::kStage, false},
+ TestParams{AttributeKind::kStride, false},
+ TestParams{AttributeKind::kStructBlock, false},
+ TestParams{AttributeKind::kWorkgroup, false},
+ TestParams{AttributeKind::kBindingAndGroup, false}));
+
+TEST_F(ConstantAttributeTest, DuplicateAttribute) {
+ GlobalConst("a", ty.f32(), Expr(1.23f),
+ ast::AttributeList{
+ create<ast::IdAttribute>(Source{{12, 34}}, 0),
+ create<ast::IdAttribute>(Source{{56, 78}}, 1),
+ });
+
+ WrapInFunction();
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(56:78 error: duplicate id attribute
+12:34 note: first attribute declared here)");
+}
+
+} // namespace
+} // namespace AttributeTests
+
+namespace ArrayStrideTests {
+namespace {
+
+struct Params {
+ builder::ast_type_func_ptr create_el_type;
+ uint32_t stride;
+ bool should_pass;
+};
+
+template <typename T>
+constexpr Params ParamsFor(uint32_t stride, bool should_pass) {
+ return Params{DataType<T>::AST, stride, should_pass};
+}
+
+struct TestWithParams : ResolverTestWithParam<Params> {};
+
+using ArrayStrideTest = TestWithParams;
+TEST_P(ArrayStrideTest, All) {
+ auto& params = GetParam();
+ auto* el_ty = params.create_el_type(*this);
+
+ std::stringstream ss;
+ ss << "el_ty: " << FriendlyName(el_ty) << ", stride: " << params.stride
+ << ", should_pass: " << params.should_pass;
+ SCOPED_TRACE(ss.str());
+
+ auto* arr = ty.array(Source{{12, 34}}, el_ty, 4, params.stride);
+
+ Global("myarray", arr, ast::StorageClass::kPrivate);
+
+ if (params.should_pass) {
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ } else {
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: arrays decorated with the stride attribute must "
+ "have a stride that is at least the size of the element type, "
+ "and be a multiple of the element type's alignment value.");
+ }
+}
+
+struct SizeAndAlignment {
+ uint32_t size;
+ uint32_t align;
+};
+constexpr SizeAndAlignment default_u32 = {4, 4};
+constexpr SizeAndAlignment default_i32 = {4, 4};
+constexpr SizeAndAlignment default_f32 = {4, 4};
+constexpr SizeAndAlignment default_vec2 = {8, 8};
+constexpr SizeAndAlignment default_vec3 = {12, 16};
+constexpr SizeAndAlignment default_vec4 = {16, 16};
+constexpr SizeAndAlignment default_mat2x2 = {16, 8};
+constexpr SizeAndAlignment default_mat3x3 = {48, 16};
+constexpr SizeAndAlignment default_mat4x4 = {64, 16};
+
+INSTANTIATE_TEST_SUITE_P(
+ ResolverAttributeValidationTest,
+ ArrayStrideTest,
+ testing::Values(
+ // Succeed because stride >= element size (while being multiple of
+ // element alignment)
+ ParamsFor<u32>(default_u32.size, true),
+ ParamsFor<i32>(default_i32.size, true),
+ ParamsFor<f32>(default_f32.size, true),
+ ParamsFor<vec2<f32>>(default_vec2.size, true),
+ // vec3's default size is not a multiple of its alignment
+ // ParamsFor<vec3<f32>, default_vec3.size, true},
+ ParamsFor<vec4<f32>>(default_vec4.size, true),
+ ParamsFor<mat2x2<f32>>(default_mat2x2.size, true),
+ ParamsFor<mat3x3<f32>>(default_mat3x3.size, true),
+ ParamsFor<mat4x4<f32>>(default_mat4x4.size, true),
+
+ // Fail because stride is < element size
+ ParamsFor<u32>(default_u32.size - 1, false),
+ ParamsFor<i32>(default_i32.size - 1, false),
+ ParamsFor<f32>(default_f32.size - 1, false),
+ ParamsFor<vec2<f32>>(default_vec2.size - 1, false),
+ ParamsFor<vec3<f32>>(default_vec3.size - 1, false),
+ ParamsFor<vec4<f32>>(default_vec4.size - 1, false),
+ ParamsFor<mat2x2<f32>>(default_mat2x2.size - 1, false),
+ ParamsFor<mat3x3<f32>>(default_mat3x3.size - 1, false),
+ ParamsFor<mat4x4<f32>>(default_mat4x4.size - 1, false),
+
+ // Succeed because stride equals multiple of element alignment
+ ParamsFor<u32>(default_u32.align * 7, true),
+ ParamsFor<i32>(default_i32.align * 7, true),
+ ParamsFor<f32>(default_f32.align * 7, true),
+ ParamsFor<vec2<f32>>(default_vec2.align * 7, true),
+ ParamsFor<vec3<f32>>(default_vec3.align * 7, true),
+ ParamsFor<vec4<f32>>(default_vec4.align * 7, true),
+ ParamsFor<mat2x2<f32>>(default_mat2x2.align * 7, true),
+ ParamsFor<mat3x3<f32>>(default_mat3x3.align * 7, true),
+ ParamsFor<mat4x4<f32>>(default_mat4x4.align * 7, true),
+
+ // Fail because stride is not multiple of element alignment
+ ParamsFor<u32>((default_u32.align - 1) * 7, false),
+ ParamsFor<i32>((default_i32.align - 1) * 7, false),
+ ParamsFor<f32>((default_f32.align - 1) * 7, false),
+ ParamsFor<vec2<f32>>((default_vec2.align - 1) * 7, false),
+ ParamsFor<vec3<f32>>((default_vec3.align - 1) * 7, false),
+ ParamsFor<vec4<f32>>((default_vec4.align - 1) * 7, false),
+ ParamsFor<mat2x2<f32>>((default_mat2x2.align - 1) * 7, false),
+ ParamsFor<mat3x3<f32>>((default_mat3x3.align - 1) * 7, false),
+ ParamsFor<mat4x4<f32>>((default_mat4x4.align - 1) * 7, false)));
+
+TEST_F(ArrayStrideTest, DuplicateAttribute) {
+ auto* arr = ty.array(Source{{12, 34}}, ty.i32(), 4,
+ {
+ create<ast::StrideAttribute>(Source{{12, 34}}, 4),
+ create<ast::StrideAttribute>(Source{{56, 78}}, 4),
+ });
+
+ Global("myarray", arr, ast::StorageClass::kPrivate);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(56:78 error: duplicate stride attribute
+12:34 note: first attribute declared here)");
+}
+
+} // namespace
+} // namespace ArrayStrideTests
+
+namespace ResourceTests {
+namespace {
+
+using ResourceAttributeTest = ResolverTest;
+TEST_F(ResourceAttributeTest, UniformBufferMissingBinding) {
+ auto* s = Structure("S", {Member("x", ty.i32())},
+ {create<ast::StructBlockAttribute>()});
+ Global(Source{{12, 34}}, "G", ty.Of(s), ast::StorageClass::kUniform);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: resource variables require @group and @binding attributes)");
+}
+
+TEST_F(ResourceAttributeTest, StorageBufferMissingBinding) {
+ auto* s = Structure("S", {Member("x", ty.i32())},
+ {create<ast::StructBlockAttribute>()});
+ Global(Source{{12, 34}}, "G", ty.Of(s), ast::StorageClass::kStorage,
+ ast::Access::kRead);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: resource variables require @group and @binding attributes)");
+}
+
+TEST_F(ResourceAttributeTest, TextureMissingBinding) {
+ Global(Source{{12, 34}}, "G", ty.depth_texture(ast::TextureDimension::k2d),
+ ast::StorageClass::kNone);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: resource variables require @group and @binding attributes)");
+}
+
+TEST_F(ResourceAttributeTest, SamplerMissingBinding) {
+ Global(Source{{12, 34}}, "G", ty.sampler(ast::SamplerKind::kSampler),
+ ast::StorageClass::kNone);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: resource variables require @group and @binding attributes)");
+}
+
+TEST_F(ResourceAttributeTest, BindingPairMissingBinding) {
+ Global(Source{{12, 34}}, "G", ty.sampler(ast::SamplerKind::kSampler),
+ ast::StorageClass::kNone,
+ ast::AttributeList{
+ create<ast::GroupAttribute>(1),
+ });
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: resource variables require @group and @binding attributes)");
+}
+
+TEST_F(ResourceAttributeTest, BindingPairMissingGroup) {
+ Global(Source{{12, 34}}, "G", ty.sampler(ast::SamplerKind::kSampler),
+ ast::StorageClass::kNone,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(1),
+ });
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: resource variables require @group and @binding attributes)");
+}
+
+TEST_F(ResourceAttributeTest, BindingPointUsedTwiceByEntryPoint) {
+ Global(Source{{12, 34}}, "A",
+ ty.sampled_texture(ast::TextureDimension::k2d, ty.f32()),
+ ast::StorageClass::kNone,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(1),
+ create<ast::GroupAttribute>(2),
+ });
+ Global(Source{{56, 78}}, "B",
+ ty.sampled_texture(ast::TextureDimension::k2d, ty.f32()),
+ ast::StorageClass::kNone,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(1),
+ create<ast::GroupAttribute>(2),
+ });
+
+ Func("F", {}, ty.void_(),
+ {
+ Decl(Var("a", ty.vec4<f32>(), ast::StorageClass::kNone,
+ Call("textureLoad", "A", vec2<i32>(1, 2), 0))),
+ Decl(Var("b", ty.vec4<f32>(), ast::StorageClass::kNone,
+ Call("textureLoad", "B", vec2<i32>(1, 2), 0))),
+ },
+ {Stage(ast::PipelineStage::kFragment)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(56:78 error: entry point 'F' references multiple variables that use the same resource binding @group(2), @binding(1)
+12:34 note: first resource binding usage declared here)");
+}
+
+TEST_F(ResourceAttributeTest, BindingPointUsedTwiceByDifferentEntryPoints) {
+ Global(Source{{12, 34}}, "A",
+ ty.sampled_texture(ast::TextureDimension::k2d, ty.f32()),
+ ast::StorageClass::kNone,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(1),
+ create<ast::GroupAttribute>(2),
+ });
+ Global(Source{{56, 78}}, "B",
+ ty.sampled_texture(ast::TextureDimension::k2d, ty.f32()),
+ ast::StorageClass::kNone,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(1),
+ create<ast::GroupAttribute>(2),
+ });
+
+ Func("F_A", {}, ty.void_(),
+ {
+ Decl(Var("a", ty.vec4<f32>(), ast::StorageClass::kNone,
+ Call("textureLoad", "A", vec2<i32>(1, 2), 0))),
+ },
+ {Stage(ast::PipelineStage::kFragment)});
+ Func("F_B", {}, ty.void_(),
+ {
+ Decl(Var("b", ty.vec4<f32>(), ast::StorageClass::kNone,
+ Call("textureLoad", "B", vec2<i32>(1, 2), 0))),
+ },
+ {Stage(ast::PipelineStage::kFragment)});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResourceAttributeTest, BindingPointOnNonResource) {
+ Global(Source{{12, 34}}, "G", ty.f32(), ast::StorageClass::kPrivate,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(1),
+ create<ast::GroupAttribute>(2),
+ });
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: non-resource variables must not have @group or @binding attributes)");
+}
+
+} // namespace
+} // namespace ResourceTests
+
+namespace InvariantAttributeTests {
+namespace {
+using InvariantAttributeTests = ResolverTest;
+TEST_F(InvariantAttributeTests, InvariantWithPosition) {
+ auto* param = Param("p", ty.vec4<f32>(),
+ {Invariant(Source{{12, 34}}),
+ Builtin(Source{{56, 78}}, ast::Builtin::kPosition)});
+ Func("main", ast::VariableList{param}, ty.vec4<f32>(),
+ ast::StatementList{Return(Construct(ty.vec4<f32>()))},
+ ast::AttributeList{Stage(ast::PipelineStage::kFragment)},
+ ast::AttributeList{
+ Location(0),
+ });
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(InvariantAttributeTests, InvariantWithoutPosition) {
+ auto* param =
+ Param("p", ty.vec4<f32>(), {Invariant(Source{{12, 34}}), Location(0)});
+ Func("main", ast::VariableList{param}, ty.vec4<f32>(),
+ ast::StatementList{Return(Construct(ty.vec4<f32>()))},
+ ast::AttributeList{Stage(ast::PipelineStage::kFragment)},
+ ast::AttributeList{
+ Location(0),
+ });
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: invariant attribute must only be applied to a "
+ "position builtin");
+}
+} // namespace
+} // namespace InvariantAttributeTests
+
+namespace WorkgroupAttributeTests {
+namespace {
+
+using WorkgroupAttribute = ResolverTest;
+TEST_F(WorkgroupAttribute, ComputeShaderPass) {
+ Func("main", {}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute),
+ create<ast::WorkgroupAttribute>(Source{{12, 34}}, Expr(1))});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(WorkgroupAttribute, Missing) {
+ Func(Source{{12, 34}}, "main", {}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: a compute shader must include 'workgroup_size' in its "
+ "attributes");
+}
+
+TEST_F(WorkgroupAttribute, NotAnEntryPoint) {
+ Func("main", {}, ty.void_(), {},
+ {create<ast::WorkgroupAttribute>(Source{{12, 34}}, Expr(1))});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: the workgroup_size attribute is only valid for "
+ "compute stages");
+}
+
+TEST_F(WorkgroupAttribute, NotAComputeShader) {
+ Func("main", {}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kFragment),
+ create<ast::WorkgroupAttribute>(Source{{12, 34}}, Expr(1))});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: the workgroup_size attribute is only valid for "
+ "compute stages");
+}
+
+TEST_F(WorkgroupAttribute, DuplicateAttribute) {
+ Func(Source{{12, 34}}, "main", {}, ty.void_(), {},
+ {
+ Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(Source{{12, 34}}, 1, nullptr, nullptr),
+ WorkgroupSize(Source{{56, 78}}, 2, nullptr, nullptr),
+ });
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(56:78 error: duplicate workgroup_size attribute
+12:34 note: first attribute declared here)");
+}
+
+} // namespace
+} // namespace WorkgroupAttributeTests
+
+namespace InterpolateTests {
+namespace {
+
+using InterpolateTest = ResolverTest;
+
+struct Params {
+ ast::InterpolationType type;
+ ast::InterpolationSampling sampling;
+ bool should_pass;
+};
+
+struct TestWithParams : ResolverTestWithParam<Params> {};
+
+using InterpolateParameterTest = TestWithParams;
+TEST_P(InterpolateParameterTest, All) {
+ auto& params = GetParam();
+
+ Func("main",
+ ast::VariableList{Param(
+ "a", ty.f32(),
+ {Location(0),
+ Interpolate(Source{{12, 34}}, params.type, params.sampling)})},
+ ty.void_(), {},
+ ast::AttributeList{Stage(ast::PipelineStage::kFragment)});
+
+ if (params.should_pass) {
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ } else {
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: flat interpolation attribute must not have a "
+ "sampling parameter");
+ }
+}
+
+TEST_P(InterpolateParameterTest, IntegerScalar) {
+ auto& params = GetParam();
+
+ Func("main",
+ ast::VariableList{Param(
+ "a", ty.i32(),
+ {Location(0),
+ Interpolate(Source{{12, 34}}, params.type, params.sampling)})},
+ ty.void_(), {},
+ ast::AttributeList{Stage(ast::PipelineStage::kFragment)});
+
+ if (params.type != ast::InterpolationType::kFlat) {
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: interpolation type must be 'flat' for integral "
+ "user-defined IO types");
+ } else if (params.should_pass) {
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ } else {
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: flat interpolation attribute must not have a "
+ "sampling parameter");
+ }
+}
+
+TEST_P(InterpolateParameterTest, IntegerVector) {
+ auto& params = GetParam();
+
+ Func("main",
+ ast::VariableList{Param(
+ "a", ty.vec4<u32>(),
+ {Location(0),
+ Interpolate(Source{{12, 34}}, params.type, params.sampling)})},
+ ty.void_(), {},
+ ast::AttributeList{Stage(ast::PipelineStage::kFragment)});
+
+ if (params.type != ast::InterpolationType::kFlat) {
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: interpolation type must be 'flat' for integral "
+ "user-defined IO types");
+ } else if (params.should_pass) {
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ } else {
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: flat interpolation attribute must not have a "
+ "sampling parameter");
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ ResolverAttributeValidationTest,
+ InterpolateParameterTest,
+ testing::Values(Params{ast::InterpolationType::kPerspective,
+ ast::InterpolationSampling::kNone, true},
+ Params{ast::InterpolationType::kPerspective,
+ ast::InterpolationSampling::kCenter, true},
+ Params{ast::InterpolationType::kPerspective,
+ ast::InterpolationSampling::kCentroid, true},
+ Params{ast::InterpolationType::kPerspective,
+ ast::InterpolationSampling::kSample, true},
+ Params{ast::InterpolationType::kLinear,
+ ast::InterpolationSampling::kNone, true},
+ Params{ast::InterpolationType::kLinear,
+ ast::InterpolationSampling::kCenter, true},
+ Params{ast::InterpolationType::kLinear,
+ ast::InterpolationSampling::kCentroid, true},
+ Params{ast::InterpolationType::kLinear,
+ ast::InterpolationSampling::kSample, true},
+ // flat interpolation must not have a sampling type
+ Params{ast::InterpolationType::kFlat,
+ ast::InterpolationSampling::kNone, true},
+ Params{ast::InterpolationType::kFlat,
+ ast::InterpolationSampling::kCenter, false},
+ Params{ast::InterpolationType::kFlat,
+ ast::InterpolationSampling::kCentroid, false},
+ Params{ast::InterpolationType::kFlat,
+ ast::InterpolationSampling::kSample, false}));
+
+TEST_F(InterpolateTest, FragmentInput_Integer_MissingFlatInterpolation) {
+ Func("main",
+ ast::VariableList{Param(Source{{12, 34}}, "a", ty.i32(), {Location(0)})},
+ ty.void_(), {},
+ ast::AttributeList{Stage(ast::PipelineStage::kFragment)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: integral user-defined fragment inputs must have a flat interpolation attribute)");
+}
+
+TEST_F(InterpolateTest, VertexOutput_Integer_MissingFlatInterpolation) {
+ auto* s = Structure(
+ "S",
+ {
+ Member("pos", ty.vec4<f32>(), {Builtin(ast::Builtin::kPosition)}),
+ Member(Source{{12, 34}}, "u", ty.u32(), {Location(0)}),
+ },
+ {});
+ Func("main", {}, ty.Of(s), {Return(Construct(ty.Of(s)))},
+ ast::AttributeList{Stage(ast::PipelineStage::kVertex)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: integral user-defined vertex outputs must have a flat interpolation attribute
+note: while analysing entry point 'main')");
+}
+
+TEST_F(InterpolateTest, MissingLocationAttribute_Parameter) {
+ Func("main",
+ ast::VariableList{
+ Param("a", ty.vec4<f32>(),
+ {Builtin(ast::Builtin::kPosition),
+ Interpolate(Source{{12, 34}}, ast::InterpolationType::kFlat,
+ ast::InterpolationSampling::kNone)})},
+ ty.void_(), {},
+ ast::AttributeList{Stage(ast::PipelineStage::kFragment)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: interpolate attribute must only be used with @location)");
+}
+
+TEST_F(InterpolateTest, MissingLocationAttribute_ReturnType) {
+ Func("main", {}, ty.vec4<f32>(), {Return(Construct(ty.vec4<f32>()))},
+ ast::AttributeList{Stage(ast::PipelineStage::kVertex)},
+ {Builtin(ast::Builtin::kPosition),
+ Interpolate(Source{{12, 34}}, ast::InterpolationType::kFlat,
+ ast::InterpolationSampling::kNone)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: interpolate attribute must only be used with @location)");
+}
+
+TEST_F(InterpolateTest, MissingLocationAttribute_Struct) {
+ Structure(
+ "S", {Member("a", ty.f32(),
+ {Interpolate(Source{{12, 34}}, ast::InterpolationType::kFlat,
+ ast::InterpolationSampling::kNone)})});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: interpolate attribute must only be used with @location)");
+}
+
+} // namespace
+} // namespace InterpolateTests
+
+} // namespace resolver
+} // namespace tint
diff --git a/src/tint/resolver/bitcast_validation_test.cc b/src/tint/resolver/bitcast_validation_test.cc
new file mode 100644
index 0000000..d13ffd9
--- /dev/null
+++ b/src/tint/resolver/bitcast_validation_test.cc
@@ -0,0 +1,228 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ast/bitcast_expression.h"
+#include "src/tint/resolver/resolver.h"
+#include "src/tint/resolver/resolver_test_helper.h"
+
+#include "gmock/gmock.h"
+
+namespace tint {
+namespace resolver {
+namespace {
+
+struct Type {
+ template <typename T>
+ static constexpr Type Create() {
+ return Type{builder::DataType<T>::AST, builder::DataType<T>::Sem,
+ builder::DataType<T>::Expr};
+ }
+
+ builder::ast_type_func_ptr ast;
+ builder::sem_type_func_ptr sem;
+ builder::ast_expr_func_ptr expr;
+};
+
+static constexpr Type kNumericScalars[] = {
+ Type::Create<builder::f32>(),
+ Type::Create<builder::i32>(),
+ Type::Create<builder::u32>(),
+};
+static constexpr Type kVec2NumericScalars[] = {
+ Type::Create<builder::vec2<builder::f32>>(),
+ Type::Create<builder::vec2<builder::i32>>(),
+ Type::Create<builder::vec2<builder::u32>>(),
+};
+static constexpr Type kVec3NumericScalars[] = {
+ Type::Create<builder::vec3<builder::f32>>(),
+ Type::Create<builder::vec3<builder::i32>>(),
+ Type::Create<builder::vec3<builder::u32>>(),
+};
+static constexpr Type kVec4NumericScalars[] = {
+ Type::Create<builder::vec4<builder::f32>>(),
+ Type::Create<builder::vec4<builder::i32>>(),
+ Type::Create<builder::vec4<builder::u32>>(),
+};
+static constexpr Type kInvalid[] = {
+ // A non-exhaustive selection of uncastable types
+ Type::Create<bool>(),
+ Type::Create<builder::vec2<bool>>(),
+ Type::Create<builder::vec3<bool>>(),
+ Type::Create<builder::vec4<bool>>(),
+ Type::Create<builder::array<2, builder::i32>>(),
+ Type::Create<builder::array<3, builder::u32>>(),
+ Type::Create<builder::array<4, builder::f32>>(),
+ Type::Create<builder::array<5, bool>>(),
+ Type::Create<builder::mat2x2<builder::f32>>(),
+ Type::Create<builder::mat3x3<builder::f32>>(),
+ Type::Create<builder::mat4x4<builder::f32>>(),
+ Type::Create<builder::ptr<builder::i32>>(),
+ Type::Create<builder::ptr<builder::array<2, builder::i32>>>(),
+ Type::Create<builder::ptr<builder::mat2x2<builder::f32>>>(),
+};
+
+using ResolverBitcastValidationTest =
+ ResolverTestWithParam<std::tuple<Type, Type>>;
+
+////////////////////////////////////////////////////////////////////////////////
+// Valid bitcasts
+////////////////////////////////////////////////////////////////////////////////
+using ResolverBitcastValidationTestPass = ResolverBitcastValidationTest;
+TEST_P(ResolverBitcastValidationTestPass, Test) {
+ auto src = std::get<0>(GetParam());
+ auto dst = std::get<1>(GetParam());
+
+ auto* cast = Bitcast(dst.ast(*this), src.expr(*this, 0));
+ WrapInFunction(cast);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+ EXPECT_EQ(TypeOf(cast), dst.sem(*this));
+}
+INSTANTIATE_TEST_SUITE_P(Scalars,
+ ResolverBitcastValidationTestPass,
+ testing::Combine(testing::ValuesIn(kNumericScalars),
+ testing::ValuesIn(kNumericScalars)));
+INSTANTIATE_TEST_SUITE_P(
+ Vec2,
+ ResolverBitcastValidationTestPass,
+ testing::Combine(testing::ValuesIn(kVec2NumericScalars),
+ testing::ValuesIn(kVec2NumericScalars)));
+INSTANTIATE_TEST_SUITE_P(
+ Vec3,
+ ResolverBitcastValidationTestPass,
+ testing::Combine(testing::ValuesIn(kVec3NumericScalars),
+ testing::ValuesIn(kVec3NumericScalars)));
+INSTANTIATE_TEST_SUITE_P(
+ Vec4,
+ ResolverBitcastValidationTestPass,
+ testing::Combine(testing::ValuesIn(kVec4NumericScalars),
+ testing::ValuesIn(kVec4NumericScalars)));
+
+////////////////////////////////////////////////////////////////////////////////
+// Invalid source type for bitcasts
+////////////////////////////////////////////////////////////////////////////////
+using ResolverBitcastValidationTestInvalidSrcTy = ResolverBitcastValidationTest;
+TEST_P(ResolverBitcastValidationTestInvalidSrcTy, Test) {
+ auto src = std::get<0>(GetParam());
+ auto dst = std::get<1>(GetParam());
+
+ auto* cast = Bitcast(dst.ast(*this), Expr(Source{{12, 34}}, "src"));
+ WrapInFunction(Const("src", nullptr, src.expr(*this, 0)), cast);
+
+ auto expected = "12:34 error: '" + src.sem(*this)->FriendlyName(Symbols()) +
+ "' cannot be bitcast";
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), expected);
+}
+INSTANTIATE_TEST_SUITE_P(Scalars,
+ ResolverBitcastValidationTestInvalidSrcTy,
+ testing::Combine(testing::ValuesIn(kInvalid),
+ testing::ValuesIn(kNumericScalars)));
+INSTANTIATE_TEST_SUITE_P(
+ Vec2,
+ ResolverBitcastValidationTestInvalidSrcTy,
+ testing::Combine(testing::ValuesIn(kInvalid),
+ testing::ValuesIn(kVec2NumericScalars)));
+INSTANTIATE_TEST_SUITE_P(
+ Vec3,
+ ResolverBitcastValidationTestInvalidSrcTy,
+ testing::Combine(testing::ValuesIn(kInvalid),
+ testing::ValuesIn(kVec3NumericScalars)));
+INSTANTIATE_TEST_SUITE_P(
+ Vec4,
+ ResolverBitcastValidationTestInvalidSrcTy,
+ testing::Combine(testing::ValuesIn(kInvalid),
+ testing::ValuesIn(kVec4NumericScalars)));
+
+////////////////////////////////////////////////////////////////////////////////
+// Invalid target type for bitcasts
+////////////////////////////////////////////////////////////////////////////////
+using ResolverBitcastValidationTestInvalidDstTy = ResolverBitcastValidationTest;
+TEST_P(ResolverBitcastValidationTestInvalidDstTy, Test) {
+ auto src = std::get<0>(GetParam());
+ auto dst = std::get<1>(GetParam());
+
+ // Use an alias so we can put a Source on the bitcast type
+ Alias("T", dst.ast(*this));
+ WrapInFunction(
+ Bitcast(ty.type_name(Source{{12, 34}}, "T"), src.expr(*this, 0)));
+
+ auto expected = "12:34 error: cannot bitcast to '" +
+ dst.sem(*this)->FriendlyName(Symbols()) + "'";
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), expected);
+}
+INSTANTIATE_TEST_SUITE_P(Scalars,
+ ResolverBitcastValidationTestInvalidDstTy,
+ testing::Combine(testing::ValuesIn(kNumericScalars),
+ testing::ValuesIn(kInvalid)));
+INSTANTIATE_TEST_SUITE_P(
+ Vec2,
+ ResolverBitcastValidationTestInvalidDstTy,
+ testing::Combine(testing::ValuesIn(kVec2NumericScalars),
+ testing::ValuesIn(kInvalid)));
+INSTANTIATE_TEST_SUITE_P(
+ Vec3,
+ ResolverBitcastValidationTestInvalidDstTy,
+ testing::Combine(testing::ValuesIn(kVec3NumericScalars),
+ testing::ValuesIn(kInvalid)));
+INSTANTIATE_TEST_SUITE_P(
+ Vec4,
+ ResolverBitcastValidationTestInvalidDstTy,
+ testing::Combine(testing::ValuesIn(kVec4NumericScalars),
+ testing::ValuesIn(kInvalid)));
+
+////////////////////////////////////////////////////////////////////////////////
+// Incompatible bitcast, but both src and dst types are valid
+////////////////////////////////////////////////////////////////////////////////
+using ResolverBitcastValidationTestIncompatible = ResolverBitcastValidationTest;
+TEST_P(ResolverBitcastValidationTestIncompatible, Test) {
+ auto src = std::get<0>(GetParam());
+ auto dst = std::get<1>(GetParam());
+
+ WrapInFunction(Bitcast(Source{{12, 34}}, dst.ast(*this), src.expr(*this, 0)));
+
+ auto expected = "12:34 error: cannot bitcast from '" +
+ src.sem(*this)->FriendlyName(Symbols()) + "' to '" +
+ dst.sem(*this)->FriendlyName(Symbols()) + "'";
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), expected);
+}
+INSTANTIATE_TEST_SUITE_P(
+ ScalarToVec2,
+ ResolverBitcastValidationTestIncompatible,
+ testing::Combine(testing::ValuesIn(kNumericScalars),
+ testing::ValuesIn(kVec2NumericScalars)));
+INSTANTIATE_TEST_SUITE_P(
+ Vec2ToVec3,
+ ResolverBitcastValidationTestIncompatible,
+ testing::Combine(testing::ValuesIn(kVec2NumericScalars),
+ testing::ValuesIn(kVec3NumericScalars)));
+INSTANTIATE_TEST_SUITE_P(
+ Vec3ToVec4,
+ ResolverBitcastValidationTestIncompatible,
+ testing::Combine(testing::ValuesIn(kVec3NumericScalars),
+ testing::ValuesIn(kVec4NumericScalars)));
+INSTANTIATE_TEST_SUITE_P(
+ Vec4ToScalar,
+ ResolverBitcastValidationTestIncompatible,
+ testing::Combine(testing::ValuesIn(kVec4NumericScalars),
+ testing::ValuesIn(kNumericScalars)));
+
+} // namespace
+} // namespace resolver
+} // namespace tint
diff --git a/src/tint/resolver/builtin_test.cc b/src/tint/resolver/builtin_test.cc
new file mode 100644
index 0000000..8914783
--- /dev/null
+++ b/src/tint/resolver/builtin_test.cc
@@ -0,0 +1,2061 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/resolver/resolver.h"
+
+#include "gmock/gmock.h"
+#include "src/tint/ast/assignment_statement.h"
+#include "src/tint/ast/bitcast_expression.h"
+#include "src/tint/ast/break_statement.h"
+#include "src/tint/ast/builtin_texture_helper_test.h"
+#include "src/tint/ast/call_statement.h"
+#include "src/tint/ast/continue_statement.h"
+#include "src/tint/ast/if_statement.h"
+#include "src/tint/ast/loop_statement.h"
+#include "src/tint/ast/return_statement.h"
+#include "src/tint/ast/stage_attribute.h"
+#include "src/tint/ast/struct_block_attribute.h"
+#include "src/tint/ast/switch_statement.h"
+#include "src/tint/ast/unary_op_expression.h"
+#include "src/tint/ast/variable_decl_statement.h"
+#include "src/tint/resolver/resolver_test_helper.h"
+#include "src/tint/sem/call.h"
+#include "src/tint/sem/function.h"
+#include "src/tint/sem/member_accessor_expression.h"
+#include "src/tint/sem/sampled_texture_type.h"
+#include "src/tint/sem/statement.h"
+#include "src/tint/sem/variable.h"
+
+using ::testing::ElementsAre;
+using ::testing::HasSubstr;
+
+namespace tint {
+namespace resolver {
+namespace {
+
+using BuiltinType = sem::BuiltinType;
+
+using ResolverBuiltinTest = ResolverTest;
+
+using ResolverBuiltinDerivativeTest = ResolverTestWithParam<std::string>;
+TEST_P(ResolverBuiltinDerivativeTest, Scalar) {
+ auto name = GetParam();
+
+ Global("ident", ty.f32(), ast::StorageClass::kPrivate);
+
+ auto* expr = Call(name, "ident");
+ Func("func", {}, ty.void_(), {Ignore(expr)},
+ {create<ast::StageAttribute>(ast::PipelineStage::kFragment)});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(expr), nullptr);
+ ASSERT_TRUE(TypeOf(expr)->Is<sem::F32>());
+}
+
+TEST_P(ResolverBuiltinDerivativeTest, Vector) {
+ auto name = GetParam();
+ Global("ident", ty.vec4<f32>(), ast::StorageClass::kPrivate);
+
+ auto* expr = Call(name, "ident");
+ Func("func", {}, ty.void_(), {Ignore(expr)},
+ {create<ast::StageAttribute>(ast::PipelineStage::kFragment)});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(expr), nullptr);
+ ASSERT_TRUE(TypeOf(expr)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(expr)->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(TypeOf(expr)->As<sem::Vector>()->Width(), 4u);
+}
+
+TEST_P(ResolverBuiltinDerivativeTest, MissingParam) {
+ auto name = GetParam();
+
+ auto* expr = Call(name);
+ WrapInFunction(expr);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(), "error: no matching call to " + name +
+ "()\n\n"
+ "2 candidate functions:\n " +
+ name + "(f32) -> f32\n " + name +
+ "(vecN<f32>) -> vecN<f32>\n");
+}
+
+INSTANTIATE_TEST_SUITE_P(ResolverTest,
+ ResolverBuiltinDerivativeTest,
+ testing::Values("dpdx",
+ "dpdxCoarse",
+ "dpdxFine",
+ "dpdy",
+ "dpdyCoarse",
+ "dpdyFine",
+ "fwidth",
+ "fwidthCoarse",
+ "fwidthFine"));
+
+using ResolverBuiltinTest_BoolMethod = ResolverTestWithParam<std::string>;
+TEST_P(ResolverBuiltinTest_BoolMethod, Scalar) {
+ auto name = GetParam();
+
+ Global("my_var", ty.bool_(), ast::StorageClass::kPrivate);
+
+ auto* expr = Call(name, "my_var");
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(expr), nullptr);
+ EXPECT_TRUE(TypeOf(expr)->Is<sem::Bool>());
+}
+TEST_P(ResolverBuiltinTest_BoolMethod, Vector) {
+ auto name = GetParam();
+
+ Global("my_var", ty.vec3<bool>(), ast::StorageClass::kPrivate);
+
+ auto* expr = Call(name, "my_var");
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(expr), nullptr);
+ EXPECT_TRUE(TypeOf(expr)->Is<sem::Bool>());
+}
+INSTANTIATE_TEST_SUITE_P(ResolverTest,
+ ResolverBuiltinTest_BoolMethod,
+ testing::Values("any", "all"));
+
+using ResolverBuiltinTest_FloatMethod = ResolverTestWithParam<std::string>;
+TEST_P(ResolverBuiltinTest_FloatMethod, Vector) {
+ auto name = GetParam();
+
+ Global("my_var", ty.vec3<f32>(), ast::StorageClass::kPrivate);
+
+ auto* expr = Call(name, "my_var");
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(expr), nullptr);
+ ASSERT_TRUE(TypeOf(expr)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(expr)->As<sem::Vector>()->type()->Is<sem::Bool>());
+ EXPECT_EQ(TypeOf(expr)->As<sem::Vector>()->Width(), 3u);
+}
+
+TEST_P(ResolverBuiltinTest_FloatMethod, Scalar) {
+ auto name = GetParam();
+
+ Global("my_var", ty.f32(), ast::StorageClass::kPrivate);
+
+ auto* expr = Call(name, "my_var");
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(expr), nullptr);
+ EXPECT_TRUE(TypeOf(expr)->Is<sem::Bool>());
+}
+
+TEST_P(ResolverBuiltinTest_FloatMethod, MissingParam) {
+ auto name = GetParam();
+
+ Global("my_var", ty.f32(), ast::StorageClass::kPrivate);
+
+ auto* expr = Call(name);
+ WrapInFunction(expr);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(), "error: no matching call to " + name +
+ "()\n\n"
+ "2 candidate functions:\n " +
+ name + "(f32) -> bool\n " + name +
+ "(vecN<f32>) -> vecN<bool>\n");
+}
+
+TEST_P(ResolverBuiltinTest_FloatMethod, TooManyParams) {
+ auto name = GetParam();
+
+ Global("my_var", ty.f32(), ast::StorageClass::kPrivate);
+
+ auto* expr = Call(name, "my_var", 1.23f);
+ WrapInFunction(expr);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(), "error: no matching call to " + name +
+ "(f32, f32)\n\n"
+ "2 candidate functions:\n " +
+ name + "(f32) -> bool\n " + name +
+ "(vecN<f32>) -> vecN<bool>\n");
+}
+INSTANTIATE_TEST_SUITE_P(
+ ResolverTest,
+ ResolverBuiltinTest_FloatMethod,
+ testing::Values("isInf", "isNan", "isFinite", "isNormal"));
+
+enum class Texture { kF32, kI32, kU32 };
+inline std::ostream& operator<<(std::ostream& out, Texture data) {
+ if (data == Texture::kF32) {
+ out << "f32";
+ } else if (data == Texture::kI32) {
+ out << "i32";
+ } else {
+ out << "u32";
+ }
+ return out;
+}
+
+struct TextureTestParams {
+ ast::TextureDimension dim;
+ Texture type = Texture::kF32;
+ ast::TexelFormat format = ast::TexelFormat::kR32Float;
+};
+inline std::ostream& operator<<(std::ostream& out, TextureTestParams data) {
+ out << data.dim << "_" << data.type;
+ return out;
+}
+
+class ResolverBuiltinTest_TextureOperation
+ : public ResolverTestWithParam<TextureTestParams> {
+ public:
+ /// Gets an appropriate type for the coords parameter depending the the
+ /// dimensionality of the texture being sampled.
+ /// @param dim dimensionality of the texture being sampled
+ /// @param scalar the scalar type
+ /// @returns a pointer to a type appropriate for the coord param
+ const ast::Type* GetCoordsType(ast::TextureDimension dim,
+ const ast::Type* scalar) {
+ switch (dim) {
+ case ast::TextureDimension::k1d:
+ return scalar;
+ case ast::TextureDimension::k2d:
+ case ast::TextureDimension::k2dArray:
+ return ty.vec(scalar, 2);
+ case ast::TextureDimension::k3d:
+ case ast::TextureDimension::kCube:
+ case ast::TextureDimension::kCubeArray:
+ return ty.vec(scalar, 3);
+ default:
+ [=]() { FAIL() << "Unsupported texture dimension: " << dim; }();
+ }
+ return nullptr;
+ }
+
+ void add_call_param(std::string name,
+ const ast::Type* type,
+ ast::ExpressionList* call_params) {
+ if (type->IsAnyOf<ast::Texture, ast::Sampler>()) {
+ Global(name, type,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(0),
+ create<ast::GroupAttribute>(0),
+ });
+
+ } else {
+ Global(name, type, ast::StorageClass::kPrivate);
+ }
+
+ call_params->push_back(Expr(name));
+ }
+ const ast::Type* subtype(Texture type) {
+ if (type == Texture::kF32) {
+ return ty.f32();
+ }
+ if (type == Texture::kI32) {
+ return ty.i32();
+ }
+ return ty.u32();
+ }
+};
+
+using ResolverBuiltinTest_SampledTextureOperation =
+ ResolverBuiltinTest_TextureOperation;
+TEST_P(ResolverBuiltinTest_SampledTextureOperation, TextureLoadSampled) {
+ auto dim = GetParam().dim;
+ auto type = GetParam().type;
+
+ auto* s = subtype(type);
+ auto* coords_type = GetCoordsType(dim, ty.i32());
+ auto* texture_type = ty.sampled_texture(dim, s);
+
+ ast::ExpressionList call_params;
+
+ add_call_param("texture", texture_type, &call_params);
+ add_call_param("coords", coords_type, &call_params);
+ if (dim == ast::TextureDimension::k2dArray) {
+ add_call_param("array_index", ty.i32(), &call_params);
+ }
+ add_call_param("level", ty.i32(), &call_params);
+
+ auto* expr = Call("textureLoad", call_params);
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(expr), nullptr);
+ ASSERT_TRUE(TypeOf(expr)->Is<sem::Vector>());
+ if (type == Texture::kF32) {
+ EXPECT_TRUE(TypeOf(expr)->As<sem::Vector>()->type()->Is<sem::F32>());
+ } else if (type == Texture::kI32) {
+ EXPECT_TRUE(TypeOf(expr)->As<sem::Vector>()->type()->Is<sem::I32>());
+ } else {
+ EXPECT_TRUE(TypeOf(expr)->As<sem::Vector>()->type()->Is<sem::U32>());
+ }
+ EXPECT_EQ(TypeOf(expr)->As<sem::Vector>()->Width(), 4u);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ ResolverTest,
+ ResolverBuiltinTest_SampledTextureOperation,
+ testing::Values(TextureTestParams{ast::TextureDimension::k1d},
+ TextureTestParams{ast::TextureDimension::k2d},
+ TextureTestParams{ast::TextureDimension::k2dArray},
+ TextureTestParams{ast::TextureDimension::k3d}));
+
+TEST_F(ResolverBuiltinTest, Dot_Vec2) {
+ Global("my_var", ty.vec2<f32>(), ast::StorageClass::kPrivate);
+
+ auto* expr = Call("dot", "my_var", "my_var");
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(expr), nullptr);
+ EXPECT_TRUE(TypeOf(expr)->Is<sem::F32>());
+}
+
+TEST_F(ResolverBuiltinTest, Dot_Vec3) {
+ Global("my_var", ty.vec3<i32>(), ast::StorageClass::kPrivate);
+
+ auto* expr = Call("dot", "my_var", "my_var");
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(expr), nullptr);
+ EXPECT_TRUE(TypeOf(expr)->Is<sem::I32>());
+}
+
+TEST_F(ResolverBuiltinTest, Dot_Vec4) {
+ Global("my_var", ty.vec4<u32>(), ast::StorageClass::kPrivate);
+
+ auto* expr = Call("dot", "my_var", "my_var");
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(expr), nullptr);
+ EXPECT_TRUE(TypeOf(expr)->Is<sem::U32>());
+}
+
+TEST_F(ResolverBuiltinTest, Dot_Error_Scalar) {
+ auto* expr = Call("dot", 1.0f, 1.0f);
+ WrapInFunction(expr);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(),
+ R"(error: no matching call to dot(f32, f32)
+
+1 candidate function:
+ dot(vecN<T>, vecN<T>) -> T where: T is f32, i32 or u32
+)");
+}
+
+TEST_F(ResolverBuiltinTest, Select) {
+ Global("my_var", ty.vec3<f32>(), ast::StorageClass::kPrivate);
+
+ Global("bool_var", ty.vec3<bool>(), ast::StorageClass::kPrivate);
+
+ auto* expr = Call("select", "my_var", "my_var", "bool_var");
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(expr), nullptr);
+ EXPECT_TRUE(TypeOf(expr)->Is<sem::Vector>());
+ EXPECT_EQ(TypeOf(expr)->As<sem::Vector>()->Width(), 3u);
+ EXPECT_TRUE(TypeOf(expr)->As<sem::Vector>()->type()->Is<sem::F32>());
+}
+
+TEST_F(ResolverBuiltinTest, Select_Error_NoParams) {
+ auto* expr = Call("select");
+ WrapInFunction(expr);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(),
+ R"(error: no matching call to select()
+
+3 candidate functions:
+ select(T, T, bool) -> T where: T is f32, i32, u32 or bool
+ select(vecN<T>, vecN<T>, bool) -> vecN<T> where: T is f32, i32, u32 or bool
+ select(vecN<T>, vecN<T>, vecN<bool>) -> vecN<T> where: T is f32, i32, u32 or bool
+)");
+}
+
+TEST_F(ResolverBuiltinTest, Select_Error_SelectorInt) {
+ auto* expr = Call("select", 1, 1, 1);
+ WrapInFunction(expr);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(),
+ R"(error: no matching call to select(i32, i32, i32)
+
+3 candidate functions:
+ select(T, T, bool) -> T where: T is f32, i32, u32 or bool
+ select(vecN<T>, vecN<T>, bool) -> vecN<T> where: T is f32, i32, u32 or bool
+ select(vecN<T>, vecN<T>, vecN<bool>) -> vecN<T> where: T is f32, i32, u32 or bool
+)");
+}
+
+TEST_F(ResolverBuiltinTest, Select_Error_Matrix) {
+ auto* expr = Call(
+ "select", mat2x2<f32>(vec2<f32>(1.0f, 1.0f), vec2<f32>(1.0f, 1.0f)),
+ mat2x2<f32>(vec2<f32>(1.0f, 1.0f), vec2<f32>(1.0f, 1.0f)), Expr(true));
+ WrapInFunction(expr);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(),
+ R"(error: no matching call to select(mat2x2<f32>, mat2x2<f32>, bool)
+
+3 candidate functions:
+ select(T, T, bool) -> T where: T is f32, i32, u32 or bool
+ select(vecN<T>, vecN<T>, bool) -> vecN<T> where: T is f32, i32, u32 or bool
+ select(vecN<T>, vecN<T>, vecN<bool>) -> vecN<T> where: T is f32, i32, u32 or bool
+)");
+}
+
+TEST_F(ResolverBuiltinTest, Select_Error_MismatchTypes) {
+ auto* expr = Call("select", 1.0f, vec2<f32>(2.0f, 3.0f), Expr(true));
+ WrapInFunction(expr);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(),
+ R"(error: no matching call to select(f32, vec2<f32>, bool)
+
+3 candidate functions:
+ select(T, T, bool) -> T where: T is f32, i32, u32 or bool
+ select(vecN<T>, vecN<T>, bool) -> vecN<T> where: T is f32, i32, u32 or bool
+ select(vecN<T>, vecN<T>, vecN<bool>) -> vecN<T> where: T is f32, i32, u32 or bool
+)");
+}
+
+TEST_F(ResolverBuiltinTest, Select_Error_MismatchVectorSize) {
+ auto* expr = Call("select", vec2<f32>(1.0f, 2.0f),
+ vec3<f32>(3.0f, 4.0f, 5.0f), Expr(true));
+ WrapInFunction(expr);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(),
+ R"(error: no matching call to select(vec2<f32>, vec3<f32>, bool)
+
+3 candidate functions:
+ select(T, T, bool) -> T where: T is f32, i32, u32 or bool
+ select(vecN<T>, vecN<T>, bool) -> vecN<T> where: T is f32, i32, u32 or bool
+ select(vecN<T>, vecN<T>, vecN<bool>) -> vecN<T> where: T is f32, i32, u32 or bool
+)");
+}
+
+struct BuiltinData {
+ const char* name;
+ BuiltinType builtin;
+};
+
+inline std::ostream& operator<<(std::ostream& out, BuiltinData data) {
+ out << data.name;
+ return out;
+}
+
+using ResolverBuiltinTest_Barrier = ResolverTestWithParam<BuiltinData>;
+TEST_P(ResolverBuiltinTest_Barrier, InferType) {
+ auto param = GetParam();
+
+ auto* call = Call(param.name);
+ WrapInFunction(CallStmt(call));
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_NE(TypeOf(call), nullptr);
+ EXPECT_TRUE(TypeOf(call)->Is<sem::Void>());
+}
+
+TEST_P(ResolverBuiltinTest_Barrier, Error_TooManyParams) {
+ auto param = GetParam();
+
+ auto* call = Call(param.name, vec4<f32>(1.f, 2.f, 3.f, 4.f), 1.0f);
+ WrapInFunction(CallStmt(call));
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_THAT(r()->error(), HasSubstr("error: no matching call to " +
+ std::string(param.name)));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ ResolverTest,
+ ResolverBuiltinTest_Barrier,
+ testing::Values(BuiltinData{"storageBarrier", BuiltinType::kStorageBarrier},
+ BuiltinData{"workgroupBarrier",
+ BuiltinType::kWorkgroupBarrier}));
+
+using ResolverBuiltinTest_DataPacking = ResolverTestWithParam<BuiltinData>;
+TEST_P(ResolverBuiltinTest_DataPacking, InferType) {
+ auto param = GetParam();
+
+ bool pack4 = param.builtin == BuiltinType::kPack4x8snorm ||
+ param.builtin == BuiltinType::kPack4x8unorm;
+
+ auto* call = pack4 ? Call(param.name, vec4<f32>(1.f, 2.f, 3.f, 4.f))
+ : Call(param.name, vec2<f32>(1.f, 2.f));
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_NE(TypeOf(call), nullptr);
+ EXPECT_TRUE(TypeOf(call)->Is<sem::U32>());
+}
+
+TEST_P(ResolverBuiltinTest_DataPacking, Error_IncorrectParamType) {
+ auto param = GetParam();
+
+ bool pack4 = param.builtin == BuiltinType::kPack4x8snorm ||
+ param.builtin == BuiltinType::kPack4x8unorm;
+
+ auto* call = pack4 ? Call(param.name, vec4<i32>(1, 2, 3, 4))
+ : Call(param.name, vec2<i32>(1, 2));
+ WrapInFunction(call);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_THAT(r()->error(), HasSubstr("error: no matching call to " +
+ std::string(param.name)));
+}
+
+TEST_P(ResolverBuiltinTest_DataPacking, Error_NoParams) {
+ auto param = GetParam();
+
+ auto* call = Call(param.name);
+ WrapInFunction(call);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_THAT(r()->error(), HasSubstr("error: no matching call to " +
+ std::string(param.name)));
+}
+
+TEST_P(ResolverBuiltinTest_DataPacking, Error_TooManyParams) {
+ auto param = GetParam();
+
+ bool pack4 = param.builtin == BuiltinType::kPack4x8snorm ||
+ param.builtin == BuiltinType::kPack4x8unorm;
+
+ auto* call = pack4 ? Call(param.name, vec4<f32>(1.f, 2.f, 3.f, 4.f), 1.0f)
+ : Call(param.name, vec2<f32>(1.f, 2.f), 1.0f);
+ WrapInFunction(call);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_THAT(r()->error(), HasSubstr("error: no matching call to " +
+ std::string(param.name)));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ ResolverTest,
+ ResolverBuiltinTest_DataPacking,
+ testing::Values(BuiltinData{"pack4x8snorm", BuiltinType::kPack4x8snorm},
+ BuiltinData{"pack4x8unorm", BuiltinType::kPack4x8unorm},
+ BuiltinData{"pack2x16snorm", BuiltinType::kPack2x16snorm},
+ BuiltinData{"pack2x16unorm", BuiltinType::kPack2x16unorm},
+ BuiltinData{"pack2x16float", BuiltinType::kPack2x16float}));
+
+using ResolverBuiltinTest_DataUnpacking = ResolverTestWithParam<BuiltinData>;
+TEST_P(ResolverBuiltinTest_DataUnpacking, InferType) {
+ auto param = GetParam();
+
+ bool pack4 = param.builtin == BuiltinType::kUnpack4x8snorm ||
+ param.builtin == BuiltinType::kUnpack4x8unorm;
+
+ auto* call = Call(param.name, 1u);
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_NE(TypeOf(call), nullptr);
+ EXPECT_TRUE(TypeOf(call)->is_float_vector());
+ if (pack4) {
+ EXPECT_EQ(TypeOf(call)->As<sem::Vector>()->Width(), 4u);
+ } else {
+ EXPECT_EQ(TypeOf(call)->As<sem::Vector>()->Width(), 2u);
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ ResolverTest,
+ ResolverBuiltinTest_DataUnpacking,
+ testing::Values(
+ BuiltinData{"unpack4x8snorm", BuiltinType::kUnpack4x8snorm},
+ BuiltinData{"unpack4x8unorm", BuiltinType::kUnpack4x8unorm},
+ BuiltinData{"unpack2x16snorm", BuiltinType::kUnpack2x16snorm},
+ BuiltinData{"unpack2x16unorm", BuiltinType::kUnpack2x16unorm},
+ BuiltinData{"unpack2x16float", BuiltinType::kUnpack2x16float}));
+
+using ResolverBuiltinTest_SingleParam = ResolverTestWithParam<BuiltinData>;
+TEST_P(ResolverBuiltinTest_SingleParam, Scalar) {
+ auto param = GetParam();
+
+ auto* call = Call(param.name, 1.f);
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(call), nullptr);
+ EXPECT_TRUE(TypeOf(call)->is_float_scalar());
+}
+
+TEST_P(ResolverBuiltinTest_SingleParam, Vector) {
+ auto param = GetParam();
+
+ auto* call = Call(param.name, vec3<f32>(1.0f, 1.0f, 3.0f));
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(call), nullptr);
+ EXPECT_TRUE(TypeOf(call)->is_float_vector());
+ EXPECT_EQ(TypeOf(call)->As<sem::Vector>()->Width(), 3u);
+}
+
+TEST_P(ResolverBuiltinTest_SingleParam, Error_NoParams) {
+ auto param = GetParam();
+
+ auto* call = Call(param.name);
+ WrapInFunction(call);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(),
+ "error: no matching call to " + std::string(param.name) +
+ "()\n\n"
+ "2 candidate functions:\n " +
+ std::string(param.name) + "(f32) -> f32\n " +
+ std::string(param.name) + "(vecN<f32>) -> vecN<f32>\n");
+}
+
+TEST_P(ResolverBuiltinTest_SingleParam, Error_TooManyParams) {
+ auto param = GetParam();
+
+ auto* call = Call(param.name, 1, 2, 3);
+ WrapInFunction(call);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(),
+ "error: no matching call to " + std::string(param.name) +
+ "(i32, i32, i32)\n\n"
+ "2 candidate functions:\n " +
+ std::string(param.name) + "(f32) -> f32\n " +
+ std::string(param.name) + "(vecN<f32>) -> vecN<f32>\n");
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ ResolverTest,
+ ResolverBuiltinTest_SingleParam,
+ testing::Values(BuiltinData{"acos", BuiltinType::kAcos},
+ BuiltinData{"asin", BuiltinType::kAsin},
+ BuiltinData{"atan", BuiltinType::kAtan},
+ BuiltinData{"ceil", BuiltinType::kCeil},
+ BuiltinData{"cos", BuiltinType::kCos},
+ BuiltinData{"cosh", BuiltinType::kCosh},
+ BuiltinData{"exp", BuiltinType::kExp},
+ BuiltinData{"exp2", BuiltinType::kExp2},
+ BuiltinData{"floor", BuiltinType::kFloor},
+ BuiltinData{"fract", BuiltinType::kFract},
+ BuiltinData{"inverseSqrt", BuiltinType::kInverseSqrt},
+ BuiltinData{"log", BuiltinType::kLog},
+ BuiltinData{"log2", BuiltinType::kLog2},
+ BuiltinData{"round", BuiltinType::kRound},
+ BuiltinData{"sign", BuiltinType::kSign},
+ BuiltinData{"sin", BuiltinType::kSin},
+ BuiltinData{"sinh", BuiltinType::kSinh},
+ BuiltinData{"sqrt", BuiltinType::kSqrt},
+ BuiltinData{"tan", BuiltinType::kTan},
+ BuiltinData{"tanh", BuiltinType::kTanh},
+ BuiltinData{"trunc", BuiltinType::kTrunc}));
+
+using ResolverBuiltinDataTest = ResolverTest;
+
+TEST_F(ResolverBuiltinDataTest, ArrayLength_Vector) {
+ auto* ary = ty.array<i32>();
+ auto* str =
+ Structure("S", {Member("x", ary)}, {create<ast::StructBlockAttribute>()});
+ Global("a", ty.Of(str), ast::StorageClass::kStorage, ast::Access::kRead,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(0),
+ create<ast::GroupAttribute>(0),
+ });
+
+ auto* call = Call("arrayLength", AddressOf(MemberAccessor("a", "x")));
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(call), nullptr);
+ EXPECT_TRUE(TypeOf(call)->Is<sem::U32>());
+}
+
+TEST_F(ResolverBuiltinDataTest, ArrayLength_Error_ArraySized) {
+ Global("arr", ty.array<int, 4>(), ast::StorageClass::kPrivate);
+ auto* call = Call("arrayLength", AddressOf("arr"));
+ WrapInFunction(call);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(
+ r()->error(),
+ R"(error: no matching call to arrayLength(ptr<private, array<i32, 4>, read_write>)
+
+1 candidate function:
+ arrayLength(ptr<storage, array<T>, A>) -> u32
+)");
+}
+
+TEST_F(ResolverBuiltinDataTest, Normalize_Vector) {
+ auto* call = Call("normalize", vec3<f32>(1.0f, 1.0f, 3.0f));
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(call), nullptr);
+ EXPECT_TRUE(TypeOf(call)->is_float_vector());
+ EXPECT_EQ(TypeOf(call)->As<sem::Vector>()->Width(), 3u);
+}
+
+TEST_F(ResolverBuiltinDataTest, Normalize_Error_NoParams) {
+ auto* call = Call("normalize");
+ WrapInFunction(call);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(), R"(error: no matching call to normalize()
+
+1 candidate function:
+ normalize(vecN<f32>) -> vecN<f32>
+)");
+}
+
+TEST_F(ResolverBuiltinDataTest, FrexpScalar) {
+ auto* call = Call("frexp", 1.0f);
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(call), nullptr);
+ auto* ty = TypeOf(call)->As<sem::Struct>();
+ ASSERT_NE(ty, nullptr);
+ ASSERT_EQ(ty->Members().size(), 2u);
+
+ auto* sig = ty->Members()[0];
+ EXPECT_TRUE(sig->Type()->Is<sem::F32>());
+ EXPECT_EQ(sig->Offset(), 0u);
+ EXPECT_EQ(sig->Size(), 4u);
+ EXPECT_EQ(sig->Align(), 4u);
+ EXPECT_EQ(sig->Name(), Sym("sig"));
+
+ auto* exp = ty->Members()[1];
+ EXPECT_TRUE(exp->Type()->Is<sem::I32>());
+ EXPECT_EQ(exp->Offset(), 4u);
+ EXPECT_EQ(exp->Size(), 4u);
+ EXPECT_EQ(exp->Align(), 4u);
+ EXPECT_EQ(exp->Name(), Sym("exp"));
+
+ EXPECT_EQ(ty->Size(), 8u);
+ EXPECT_EQ(ty->SizeNoPadding(), 8u);
+}
+
+TEST_F(ResolverBuiltinDataTest, FrexpVector) {
+ auto* call = Call("frexp", vec3<f32>());
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(call), nullptr);
+ auto* ty = TypeOf(call)->As<sem::Struct>();
+ ASSERT_NE(ty, nullptr);
+ ASSERT_EQ(ty->Members().size(), 2u);
+
+ auto* sig = ty->Members()[0];
+ ASSERT_TRUE(sig->Type()->Is<sem::Vector>());
+ EXPECT_EQ(sig->Type()->As<sem::Vector>()->Width(), 3u);
+ EXPECT_TRUE(sig->Type()->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(sig->Offset(), 0u);
+ EXPECT_EQ(sig->Size(), 12u);
+ EXPECT_EQ(sig->Align(), 16u);
+ EXPECT_EQ(sig->Name(), Sym("sig"));
+
+ auto* exp = ty->Members()[1];
+ ASSERT_TRUE(exp->Type()->Is<sem::Vector>());
+ EXPECT_EQ(exp->Type()->As<sem::Vector>()->Width(), 3u);
+ EXPECT_TRUE(exp->Type()->As<sem::Vector>()->type()->Is<sem::I32>());
+ EXPECT_EQ(exp->Offset(), 16u);
+ EXPECT_EQ(exp->Size(), 12u);
+ EXPECT_EQ(exp->Align(), 16u);
+ EXPECT_EQ(exp->Name(), Sym("exp"));
+
+ EXPECT_EQ(ty->Size(), 32u);
+ EXPECT_EQ(ty->SizeNoPadding(), 28u);
+}
+
+TEST_F(ResolverBuiltinDataTest, Frexp_Error_FirstParamInt) {
+ Global("v", ty.i32(), ast::StorageClass::kWorkgroup);
+ auto* call = Call("frexp", 1, AddressOf("v"));
+ WrapInFunction(call);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(
+ r()->error(),
+ R"(error: no matching call to frexp(i32, ptr<workgroup, i32, read_write>)
+
+2 candidate functions:
+ frexp(f32) -> __frexp_result
+ frexp(vecN<f32>) -> __frexp_result_vecN
+)");
+}
+
+TEST_F(ResolverBuiltinDataTest, Frexp_Error_SecondParamFloatPtr) {
+ Global("v", ty.f32(), ast::StorageClass::kWorkgroup);
+ auto* call = Call("frexp", 1.0f, AddressOf("v"));
+ WrapInFunction(call);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(
+ r()->error(),
+ R"(error: no matching call to frexp(f32, ptr<workgroup, f32, read_write>)
+
+2 candidate functions:
+ frexp(f32) -> __frexp_result
+ frexp(vecN<f32>) -> __frexp_result_vecN
+)");
+}
+
+TEST_F(ResolverBuiltinDataTest, Frexp_Error_SecondParamNotAPointer) {
+ auto* call = Call("frexp", 1.0f, 1);
+ WrapInFunction(call);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(), R"(error: no matching call to frexp(f32, i32)
+
+2 candidate functions:
+ frexp(f32) -> __frexp_result
+ frexp(vecN<f32>) -> __frexp_result_vecN
+)");
+}
+
+TEST_F(ResolverBuiltinDataTest, Frexp_Error_VectorSizesDontMatch) {
+ Global("v", ty.vec4<i32>(), ast::StorageClass::kWorkgroup);
+ auto* call = Call("frexp", vec2<f32>(1.0f, 2.0f), AddressOf("v"));
+ WrapInFunction(call);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(
+ r()->error(),
+ R"(error: no matching call to frexp(vec2<f32>, ptr<workgroup, vec4<i32>, read_write>)
+
+2 candidate functions:
+ frexp(vecN<f32>) -> __frexp_result_vecN
+ frexp(f32) -> __frexp_result
+)");
+}
+
+TEST_F(ResolverBuiltinDataTest, ModfScalar) {
+ auto* call = Call("modf", 1.0f);
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(call), nullptr);
+ auto* ty = TypeOf(call)->As<sem::Struct>();
+ ASSERT_NE(ty, nullptr);
+ ASSERT_EQ(ty->Members().size(), 2u);
+
+ auto* fract = ty->Members()[0];
+ EXPECT_TRUE(fract->Type()->Is<sem::F32>());
+ EXPECT_EQ(fract->Offset(), 0u);
+ EXPECT_EQ(fract->Size(), 4u);
+ EXPECT_EQ(fract->Align(), 4u);
+ EXPECT_EQ(fract->Name(), Sym("fract"));
+
+ auto* whole = ty->Members()[1];
+ EXPECT_TRUE(whole->Type()->Is<sem::F32>());
+ EXPECT_EQ(whole->Offset(), 4u);
+ EXPECT_EQ(whole->Size(), 4u);
+ EXPECT_EQ(whole->Align(), 4u);
+ EXPECT_EQ(whole->Name(), Sym("whole"));
+
+ EXPECT_EQ(ty->Size(), 8u);
+ EXPECT_EQ(ty->SizeNoPadding(), 8u);
+}
+
+TEST_F(ResolverBuiltinDataTest, ModfVector) {
+ auto* call = Call("modf", vec3<f32>());
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(call), nullptr);
+ auto* ty = TypeOf(call)->As<sem::Struct>();
+ ASSERT_NE(ty, nullptr);
+ ASSERT_EQ(ty->Members().size(), 2u);
+
+ auto* fract = ty->Members()[0];
+ ASSERT_TRUE(fract->Type()->Is<sem::Vector>());
+ EXPECT_EQ(fract->Type()->As<sem::Vector>()->Width(), 3u);
+ EXPECT_TRUE(fract->Type()->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(fract->Offset(), 0u);
+ EXPECT_EQ(fract->Size(), 12u);
+ EXPECT_EQ(fract->Align(), 16u);
+ EXPECT_EQ(fract->Name(), Sym("fract"));
+
+ auto* whole = ty->Members()[1];
+ ASSERT_TRUE(whole->Type()->Is<sem::Vector>());
+ EXPECT_EQ(whole->Type()->As<sem::Vector>()->Width(), 3u);
+ EXPECT_TRUE(whole->Type()->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(whole->Offset(), 16u);
+ EXPECT_EQ(whole->Size(), 12u);
+ EXPECT_EQ(whole->Align(), 16u);
+ EXPECT_EQ(whole->Name(), Sym("whole"));
+
+ EXPECT_EQ(ty->Size(), 32u);
+ EXPECT_EQ(ty->SizeNoPadding(), 28u);
+}
+
+TEST_F(ResolverBuiltinDataTest, Modf_Error_FirstParamInt) {
+ Global("whole", ty.f32(), ast::StorageClass::kWorkgroup);
+ auto* call = Call("modf", 1, AddressOf("whole"));
+ WrapInFunction(call);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(
+ r()->error(),
+ R"(error: no matching call to modf(i32, ptr<workgroup, f32, read_write>)
+
+2 candidate functions:
+ modf(f32) -> __modf_result
+ modf(vecN<f32>) -> __modf_result_vecN
+)");
+}
+
+TEST_F(ResolverBuiltinDataTest, Modf_Error_SecondParamIntPtr) {
+ Global("whole", ty.i32(), ast::StorageClass::kWorkgroup);
+ auto* call = Call("modf", 1.0f, AddressOf("whole"));
+ WrapInFunction(call);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(
+ r()->error(),
+ R"(error: no matching call to modf(f32, ptr<workgroup, i32, read_write>)
+
+2 candidate functions:
+ modf(f32) -> __modf_result
+ modf(vecN<f32>) -> __modf_result_vecN
+)");
+}
+
+TEST_F(ResolverBuiltinDataTest, Modf_Error_SecondParamNotAPointer) {
+ auto* call = Call("modf", 1.0f, 1.0f);
+ WrapInFunction(call);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(), R"(error: no matching call to modf(f32, f32)
+
+2 candidate functions:
+ modf(f32) -> __modf_result
+ modf(vecN<f32>) -> __modf_result_vecN
+)");
+}
+
+TEST_F(ResolverBuiltinDataTest, Modf_Error_VectorSizesDontMatch) {
+ Global("whole", ty.vec4<f32>(), ast::StorageClass::kWorkgroup);
+ auto* call = Call("modf", vec2<f32>(1.0f, 2.0f), AddressOf("whole"));
+ WrapInFunction(call);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(
+ r()->error(),
+ R"(error: no matching call to modf(vec2<f32>, ptr<workgroup, vec4<f32>, read_write>)
+
+2 candidate functions:
+ modf(vecN<f32>) -> __modf_result_vecN
+ modf(f32) -> __modf_result
+)");
+}
+
+using ResolverBuiltinTest_SingleParam_FloatOrInt =
+ ResolverTestWithParam<BuiltinData>;
+TEST_P(ResolverBuiltinTest_SingleParam_FloatOrInt, Float_Scalar) {
+ auto param = GetParam();
+
+ auto* call = Call(param.name, 1.f);
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(call), nullptr);
+ EXPECT_TRUE(TypeOf(call)->is_float_scalar());
+}
+
+TEST_P(ResolverBuiltinTest_SingleParam_FloatOrInt, Float_Vector) {
+ auto param = GetParam();
+
+ auto* call = Call(param.name, vec3<f32>(1.0f, 1.0f, 3.0f));
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(call), nullptr);
+ EXPECT_TRUE(TypeOf(call)->is_float_vector());
+ EXPECT_EQ(TypeOf(call)->As<sem::Vector>()->Width(), 3u);
+}
+
+TEST_P(ResolverBuiltinTest_SingleParam_FloatOrInt, Sint_Scalar) {
+ auto param = GetParam();
+
+ auto* call = Call(param.name, -1);
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(call), nullptr);
+ EXPECT_TRUE(TypeOf(call)->Is<sem::I32>());
+}
+
+TEST_P(ResolverBuiltinTest_SingleParam_FloatOrInt, Sint_Vector) {
+ auto param = GetParam();
+
+ auto* call = Call(param.name, vec3<i32>(1, 1, 3));
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(call), nullptr);
+ EXPECT_TRUE(TypeOf(call)->is_signed_integer_vector());
+ EXPECT_EQ(TypeOf(call)->As<sem::Vector>()->Width(), 3u);
+}
+
+TEST_P(ResolverBuiltinTest_SingleParam_FloatOrInt, Uint_Scalar) {
+ auto param = GetParam();
+
+ auto* call = Call(param.name, 1u);
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(call), nullptr);
+ EXPECT_TRUE(TypeOf(call)->Is<sem::U32>());
+}
+
+TEST_P(ResolverBuiltinTest_SingleParam_FloatOrInt, Uint_Vector) {
+ auto param = GetParam();
+
+ auto* call = Call(param.name, vec3<u32>(1u, 1u, 3u));
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(call), nullptr);
+ EXPECT_TRUE(TypeOf(call)->is_unsigned_integer_vector());
+ EXPECT_EQ(TypeOf(call)->As<sem::Vector>()->Width(), 3u);
+}
+
+TEST_P(ResolverBuiltinTest_SingleParam_FloatOrInt, Error_NoParams) {
+ auto param = GetParam();
+
+ auto* call = Call(param.name);
+ WrapInFunction(call);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(),
+ "error: no matching call to " + std::string(param.name) +
+ "()\n\n"
+ "2 candidate functions:\n " +
+ std::string(param.name) +
+ "(T) -> T where: T is f32, i32 or u32\n " +
+ std::string(param.name) +
+ "(vecN<T>) -> vecN<T> where: T is f32, i32 or u32\n");
+}
+
+INSTANTIATE_TEST_SUITE_P(ResolverTest,
+ ResolverBuiltinTest_SingleParam_FloatOrInt,
+ testing::Values(BuiltinData{"abs",
+ BuiltinType::kAbs}));
+
+TEST_F(ResolverBuiltinTest, Length_Scalar) {
+ auto* call = Call("length", 1.f);
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(call), nullptr);
+ EXPECT_TRUE(TypeOf(call)->is_float_scalar());
+}
+
+TEST_F(ResolverBuiltinTest, Length_FloatVector) {
+ auto* call = Call("length", vec3<f32>(1.0f, 1.0f, 3.0f));
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(call), nullptr);
+ EXPECT_TRUE(TypeOf(call)->is_float_scalar());
+}
+
+using ResolverBuiltinTest_TwoParam = ResolverTestWithParam<BuiltinData>;
+TEST_P(ResolverBuiltinTest_TwoParam, Scalar) {
+ auto param = GetParam();
+
+ auto* call = Call(param.name, 1.f, 1.f);
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(call), nullptr);
+ EXPECT_TRUE(TypeOf(call)->is_float_scalar());
+}
+
+TEST_P(ResolverBuiltinTest_TwoParam, Vector) {
+ auto param = GetParam();
+
+ auto* call = Call(param.name, vec3<f32>(1.0f, 1.0f, 3.0f),
+ vec3<f32>(1.0f, 1.0f, 3.0f));
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(call), nullptr);
+ EXPECT_TRUE(TypeOf(call)->is_float_vector());
+ EXPECT_EQ(TypeOf(call)->As<sem::Vector>()->Width(), 3u);
+}
+
+TEST_P(ResolverBuiltinTest_TwoParam, Error_NoTooManyParams) {
+ auto param = GetParam();
+
+ auto* call = Call(param.name, 1, 2, 3);
+ WrapInFunction(call);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(),
+ "error: no matching call to " + std::string(param.name) +
+ "(i32, i32, i32)\n\n"
+ "2 candidate functions:\n " +
+ std::string(param.name) + "(f32, f32) -> f32\n " +
+ std::string(param.name) +
+ "(vecN<f32>, vecN<f32>) -> vecN<f32>\n");
+}
+
+TEST_P(ResolverBuiltinTest_TwoParam, Error_NoParams) {
+ auto param = GetParam();
+
+ auto* call = Call(param.name);
+ WrapInFunction(call);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(),
+ "error: no matching call to " + std::string(param.name) +
+ "()\n\n"
+ "2 candidate functions:\n " +
+ std::string(param.name) + "(f32, f32) -> f32\n " +
+ std::string(param.name) +
+ "(vecN<f32>, vecN<f32>) -> vecN<f32>\n");
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ ResolverTest,
+ ResolverBuiltinTest_TwoParam,
+ testing::Values(BuiltinData{"atan2", BuiltinType::kAtan2},
+ BuiltinData{"pow", BuiltinType::kPow},
+ BuiltinData{"step", BuiltinType::kStep}));
+
+TEST_F(ResolverBuiltinTest, Distance_Scalar) {
+ auto* call = Call("distance", 1.f, 1.f);
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(call), nullptr);
+ EXPECT_TRUE(TypeOf(call)->is_float_scalar());
+}
+
+TEST_F(ResolverBuiltinTest, Distance_Vector) {
+ auto* call = Call("distance", vec3<f32>(1.0f, 1.0f, 3.0f),
+ vec3<f32>(1.0f, 1.0f, 3.0f));
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(call), nullptr);
+ EXPECT_TRUE(TypeOf(call)->Is<sem::F32>());
+}
+
+TEST_F(ResolverBuiltinTest, Cross) {
+ auto* call =
+ Call("cross", vec3<f32>(1.0f, 2.0f, 3.0f), vec3<f32>(1.0f, 2.0f, 3.0f));
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(call), nullptr);
+ EXPECT_TRUE(TypeOf(call)->is_float_vector());
+ EXPECT_EQ(TypeOf(call)->As<sem::Vector>()->Width(), 3u);
+}
+
+TEST_F(ResolverBuiltinTest, Cross_Error_NoArgs) {
+ auto* call = Call("cross");
+ WrapInFunction(call);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(), R"(error: no matching call to cross()
+
+1 candidate function:
+ cross(vec3<f32>, vec3<f32>) -> vec3<f32>
+)");
+}
+
+TEST_F(ResolverBuiltinTest, Cross_Error_Scalar) {
+ auto* call = Call("cross", 1.0f, 1.0f);
+ WrapInFunction(call);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(), R"(error: no matching call to cross(f32, f32)
+
+1 candidate function:
+ cross(vec3<f32>, vec3<f32>) -> vec3<f32>
+)");
+}
+
+TEST_F(ResolverBuiltinTest, Cross_Error_Vec3Int) {
+ auto* call = Call("cross", vec3<i32>(1, 2, 3), vec3<i32>(1, 2, 3));
+ WrapInFunction(call);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(),
+ R"(error: no matching call to cross(vec3<i32>, vec3<i32>)
+
+1 candidate function:
+ cross(vec3<f32>, vec3<f32>) -> vec3<f32>
+)");
+}
+
+TEST_F(ResolverBuiltinTest, Cross_Error_Vec4) {
+ auto* call = Call("cross", vec4<f32>(1.0f, 2.0f, 3.0f, 4.0f),
+ vec4<f32>(1.0f, 2.0f, 3.0f, 4.0f));
+
+ WrapInFunction(call);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(),
+ R"(error: no matching call to cross(vec4<f32>, vec4<f32>)
+
+1 candidate function:
+ cross(vec3<f32>, vec3<f32>) -> vec3<f32>
+)");
+}
+
+TEST_F(ResolverBuiltinTest, Cross_Error_TooManyParams) {
+ auto* call = Call("cross", vec3<f32>(1.0f, 2.0f, 3.0f),
+ vec3<f32>(1.0f, 2.0f, 3.0f), vec3<f32>(1.0f, 2.0f, 3.0f));
+
+ WrapInFunction(call);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(),
+ R"(error: no matching call to cross(vec3<f32>, vec3<f32>, vec3<f32>)
+
+1 candidate function:
+ cross(vec3<f32>, vec3<f32>) -> vec3<f32>
+)");
+}
+TEST_F(ResolverBuiltinTest, Normalize) {
+ auto* call = Call("normalize", vec3<f32>(1.0f, 1.0f, 3.0f));
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(call), nullptr);
+ EXPECT_TRUE(TypeOf(call)->is_float_vector());
+ EXPECT_EQ(TypeOf(call)->As<sem::Vector>()->Width(), 3u);
+}
+
+TEST_F(ResolverBuiltinTest, Normalize_NoArgs) {
+ auto* call = Call("normalize");
+ WrapInFunction(call);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(), R"(error: no matching call to normalize()
+
+1 candidate function:
+ normalize(vecN<f32>) -> vecN<f32>
+)");
+}
+
+using ResolverBuiltinTest_ThreeParam = ResolverTestWithParam<BuiltinData>;
+TEST_P(ResolverBuiltinTest_ThreeParam, Scalar) {
+ auto param = GetParam();
+
+ auto* call = Call(param.name, 1.f, 1.f, 1.f);
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(call), nullptr);
+ EXPECT_TRUE(TypeOf(call)->is_float_scalar());
+}
+
+TEST_P(ResolverBuiltinTest_ThreeParam, Vector) {
+ auto param = GetParam();
+
+ auto* call = Call(param.name, vec3<f32>(1.0f, 1.0f, 3.0f),
+ vec3<f32>(1.0f, 1.0f, 3.0f), vec3<f32>(1.0f, 1.0f, 3.0f));
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(call), nullptr);
+ EXPECT_TRUE(TypeOf(call)->is_float_vector());
+ EXPECT_EQ(TypeOf(call)->As<sem::Vector>()->Width(), 3u);
+}
+TEST_P(ResolverBuiltinTest_ThreeParam, Error_NoParams) {
+ auto param = GetParam();
+
+ auto* call = Call(param.name);
+ WrapInFunction(call);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_THAT(r()->error(), HasSubstr("error: no matching call to " +
+ std::string(param.name) + "()"));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ ResolverTest,
+ ResolverBuiltinTest_ThreeParam,
+ testing::Values(BuiltinData{"mix", BuiltinType::kMix},
+ BuiltinData{"smoothStep", BuiltinType::kSmoothStep},
+ BuiltinData{"fma", BuiltinType::kFma}));
+
+using ResolverBuiltinTest_ThreeParam_FloatOrInt =
+ ResolverTestWithParam<BuiltinData>;
+TEST_P(ResolverBuiltinTest_ThreeParam_FloatOrInt, Float_Scalar) {
+ auto param = GetParam();
+
+ auto* call = Call(param.name, 1.f, 1.f, 1.f);
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(call), nullptr);
+ EXPECT_TRUE(TypeOf(call)->is_float_scalar());
+}
+
+TEST_P(ResolverBuiltinTest_ThreeParam_FloatOrInt, Float_Vector) {
+ auto param = GetParam();
+
+ auto* call = Call(param.name, vec3<f32>(1.0f, 1.0f, 3.0f),
+ vec3<f32>(1.0f, 1.0f, 3.0f), vec3<f32>(1.0f, 1.0f, 3.0f));
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(call), nullptr);
+ EXPECT_TRUE(TypeOf(call)->is_float_vector());
+ EXPECT_EQ(TypeOf(call)->As<sem::Vector>()->Width(), 3u);
+}
+
+TEST_P(ResolverBuiltinTest_ThreeParam_FloatOrInt, Sint_Scalar) {
+ auto param = GetParam();
+
+ auto* call = Call(param.name, 1, 1, 1);
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(call), nullptr);
+ EXPECT_TRUE(TypeOf(call)->Is<sem::I32>());
+}
+
+TEST_P(ResolverBuiltinTest_ThreeParam_FloatOrInt, Sint_Vector) {
+ auto param = GetParam();
+
+ auto* call = Call(param.name, vec3<i32>(1, 1, 3), vec3<i32>(1, 1, 3),
+ vec3<i32>(1, 1, 3));
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(call), nullptr);
+ EXPECT_TRUE(TypeOf(call)->is_signed_integer_vector());
+ EXPECT_EQ(TypeOf(call)->As<sem::Vector>()->Width(), 3u);
+}
+
+TEST_P(ResolverBuiltinTest_ThreeParam_FloatOrInt, Uint_Scalar) {
+ auto param = GetParam();
+
+ auto* call = Call(param.name, 1u, 1u, 1u);
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(call), nullptr);
+ EXPECT_TRUE(TypeOf(call)->Is<sem::U32>());
+}
+
+TEST_P(ResolverBuiltinTest_ThreeParam_FloatOrInt, Uint_Vector) {
+ auto param = GetParam();
+
+ auto* call = Call(param.name, vec3<u32>(1u, 1u, 3u), vec3<u32>(1u, 1u, 3u),
+ vec3<u32>(1u, 1u, 3u));
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(call), nullptr);
+ EXPECT_TRUE(TypeOf(call)->is_unsigned_integer_vector());
+ EXPECT_EQ(TypeOf(call)->As<sem::Vector>()->Width(), 3u);
+}
+
+TEST_P(ResolverBuiltinTest_ThreeParam_FloatOrInt, Error_NoParams) {
+ auto param = GetParam();
+
+ auto* call = Call(param.name);
+ WrapInFunction(call);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(),
+ "error: no matching call to " + std::string(param.name) +
+ "()\n\n"
+ "2 candidate functions:\n " +
+ std::string(param.name) +
+ "(T, T, T) -> T where: T is f32, i32 or u32\n " +
+ std::string(param.name) +
+ "(vecN<T>, vecN<T>, vecN<T>) -> vecN<T> where: T is f32, i32 "
+ "or u32\n");
+}
+
+INSTANTIATE_TEST_SUITE_P(ResolverTest,
+ ResolverBuiltinTest_ThreeParam_FloatOrInt,
+ testing::Values(BuiltinData{"clamp",
+ BuiltinType::kClamp}));
+
+using ResolverBuiltinTest_Int_SingleParam = ResolverTestWithParam<BuiltinData>;
+TEST_P(ResolverBuiltinTest_Int_SingleParam, Scalar) {
+ auto param = GetParam();
+
+ auto* call = Call(param.name, 1);
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(call), nullptr);
+ EXPECT_TRUE(TypeOf(call)->is_integer_scalar());
+}
+
+TEST_P(ResolverBuiltinTest_Int_SingleParam, Vector) {
+ auto param = GetParam();
+
+ auto* call = Call(param.name, vec3<i32>(1, 1, 3));
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(call), nullptr);
+ EXPECT_TRUE(TypeOf(call)->is_signed_integer_vector());
+ EXPECT_EQ(TypeOf(call)->As<sem::Vector>()->Width(), 3u);
+}
+
+TEST_P(ResolverBuiltinTest_Int_SingleParam, Error_NoParams) {
+ auto param = GetParam();
+
+ auto* call = Call(param.name);
+ WrapInFunction(call);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(), "error: no matching call to " +
+ std::string(param.name) +
+ "()\n\n"
+ "2 candidate functions:\n " +
+ std::string(param.name) +
+ "(T) -> T where: T is i32 or u32\n " +
+ std::string(param.name) +
+ "(vecN<T>) -> vecN<T> where: T is i32 or u32\n");
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ ResolverTest,
+ ResolverBuiltinTest_Int_SingleParam,
+ testing::Values(BuiltinData{"countOneBits", BuiltinType::kCountOneBits},
+ BuiltinData{"reverseBits", BuiltinType::kReverseBits}));
+
+using ResolverBuiltinTest_FloatOrInt_TwoParam =
+ ResolverTestWithParam<BuiltinData>;
+TEST_P(ResolverBuiltinTest_FloatOrInt_TwoParam, Scalar_Signed) {
+ auto param = GetParam();
+
+ auto* call = Call(param.name, 1, 1);
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(call), nullptr);
+ EXPECT_TRUE(TypeOf(call)->Is<sem::I32>());
+}
+
+TEST_P(ResolverBuiltinTest_FloatOrInt_TwoParam, Scalar_Unsigned) {
+ auto param = GetParam();
+
+ auto* call = Call(param.name, 1u, 1u);
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(call), nullptr);
+ EXPECT_TRUE(TypeOf(call)->Is<sem::U32>());
+}
+
+TEST_P(ResolverBuiltinTest_FloatOrInt_TwoParam, Scalar_Float) {
+ auto param = GetParam();
+
+ auto* call = Call(param.name, 1.0f, 1.0f);
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(call), nullptr);
+ EXPECT_TRUE(TypeOf(call)->Is<sem::F32>());
+}
+
+TEST_P(ResolverBuiltinTest_FloatOrInt_TwoParam, Vector_Signed) {
+ auto param = GetParam();
+
+ auto* call = Call(param.name, vec3<i32>(1, 1, 3), vec3<i32>(1, 1, 3));
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(call), nullptr);
+ EXPECT_TRUE(TypeOf(call)->is_signed_integer_vector());
+ EXPECT_EQ(TypeOf(call)->As<sem::Vector>()->Width(), 3u);
+}
+
+TEST_P(ResolverBuiltinTest_FloatOrInt_TwoParam, Vector_Unsigned) {
+ auto param = GetParam();
+
+ auto* call = Call(param.name, vec3<u32>(1u, 1u, 3u), vec3<u32>(1u, 1u, 3u));
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(call), nullptr);
+ EXPECT_TRUE(TypeOf(call)->is_unsigned_integer_vector());
+ EXPECT_EQ(TypeOf(call)->As<sem::Vector>()->Width(), 3u);
+}
+
+TEST_P(ResolverBuiltinTest_FloatOrInt_TwoParam, Vector_Float) {
+ auto param = GetParam();
+
+ auto* call =
+ Call(param.name, vec3<f32>(1.f, 1.f, 3.f), vec3<f32>(1.f, 1.f, 3.f));
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(call), nullptr);
+ EXPECT_TRUE(TypeOf(call)->is_float_vector());
+ EXPECT_EQ(TypeOf(call)->As<sem::Vector>()->Width(), 3u);
+}
+
+TEST_P(ResolverBuiltinTest_FloatOrInt_TwoParam, Error_NoParams) {
+ auto param = GetParam();
+
+ auto* call = Call(param.name);
+ WrapInFunction(call);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(),
+ "error: no matching call to " + std::string(param.name) +
+ "()\n\n"
+ "2 candidate functions:\n " +
+ std::string(param.name) +
+ "(T, T) -> T where: T is f32, i32 or u32\n " +
+ std::string(param.name) +
+ "(vecN<T>, vecN<T>) -> vecN<T> where: T is f32, i32 or u32\n");
+}
+
+INSTANTIATE_TEST_SUITE_P(ResolverTest,
+ ResolverBuiltinTest_FloatOrInt_TwoParam,
+ testing::Values(BuiltinData{"min", BuiltinType::kMin},
+ BuiltinData{"max",
+ BuiltinType::kMax}));
+
+TEST_F(ResolverBuiltinTest, Determinant_2x2) {
+ Global("var", ty.mat2x2<f32>(), ast::StorageClass::kPrivate);
+
+ auto* call = Call("determinant", "var");
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(call), nullptr);
+ EXPECT_TRUE(TypeOf(call)->Is<sem::F32>());
+}
+
+TEST_F(ResolverBuiltinTest, Determinant_3x3) {
+ Global("var", ty.mat3x3<f32>(), ast::StorageClass::kPrivate);
+
+ auto* call = Call("determinant", "var");
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(call), nullptr);
+ EXPECT_TRUE(TypeOf(call)->Is<sem::F32>());
+}
+
+TEST_F(ResolverBuiltinTest, Determinant_4x4) {
+ Global("var", ty.mat4x4<f32>(), ast::StorageClass::kPrivate);
+
+ auto* call = Call("determinant", "var");
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(call), nullptr);
+ EXPECT_TRUE(TypeOf(call)->Is<sem::F32>());
+}
+
+TEST_F(ResolverBuiltinTest, Determinant_NotSquare) {
+ Global("var", ty.mat2x3<f32>(), ast::StorageClass::kPrivate);
+
+ auto* call = Call("determinant", "var");
+ WrapInFunction(call);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(), R"(error: no matching call to determinant(mat2x3<f32>)
+
+1 candidate function:
+ determinant(matNxN<f32>) -> f32
+)");
+}
+
+TEST_F(ResolverBuiltinTest, Determinant_NotMatrix) {
+ Global("var", ty.f32(), ast::StorageClass::kPrivate);
+
+ auto* call = Call("determinant", "var");
+ WrapInFunction(call);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(), R"(error: no matching call to determinant(f32)
+
+1 candidate function:
+ determinant(matNxN<f32>) -> f32
+)");
+}
+
+using ResolverBuiltinTest_Texture =
+ ResolverTestWithParam<ast::builtin::test::TextureOverloadCase>;
+
+INSTANTIATE_TEST_SUITE_P(
+ ResolverTest,
+ ResolverBuiltinTest_Texture,
+ testing::ValuesIn(ast::builtin::test::TextureOverloadCase::ValidCases()));
+
+std::string to_str(const std::string& function,
+ const sem::ParameterList& params) {
+ std::stringstream out;
+ out << function << "(";
+ bool first = true;
+ for (auto* param : params) {
+ if (!first) {
+ out << ", ";
+ }
+ out << sem::str(param->Usage());
+ first = false;
+ }
+ out << ")";
+ return out.str();
+}
+
+const char* expected_texture_overload(
+ ast::builtin::test::ValidTextureOverload overload) {
+ using ValidTextureOverload = ast::builtin::test::ValidTextureOverload;
+ switch (overload) {
+ case ValidTextureOverload::kDimensions1d:
+ case ValidTextureOverload::kDimensions2d:
+ case ValidTextureOverload::kDimensions2dArray:
+ case ValidTextureOverload::kDimensions3d:
+ case ValidTextureOverload::kDimensionsCube:
+ case ValidTextureOverload::kDimensionsCubeArray:
+ case ValidTextureOverload::kDimensionsMultisampled2d:
+ case ValidTextureOverload::kDimensionsDepth2d:
+ case ValidTextureOverload::kDimensionsDepth2dArray:
+ case ValidTextureOverload::kDimensionsDepthCube:
+ case ValidTextureOverload::kDimensionsDepthCubeArray:
+ case ValidTextureOverload::kDimensionsDepthMultisampled2d:
+ case ValidTextureOverload::kDimensionsStorageWO1d:
+ case ValidTextureOverload::kDimensionsStorageWO2d:
+ case ValidTextureOverload::kDimensionsStorageWO2dArray:
+ case ValidTextureOverload::kDimensionsStorageWO3d:
+ return R"(textureDimensions(texture))";
+ case ValidTextureOverload::kGather2dF32:
+ return R"(textureGather(component, texture, sampler, coords))";
+ case ValidTextureOverload::kGather2dOffsetF32:
+ return R"(textureGather(component, texture, sampler, coords, offset))";
+ case ValidTextureOverload::kGather2dArrayF32:
+ return R"(textureGather(component, texture, sampler, coords, array_index))";
+ case ValidTextureOverload::kGather2dArrayOffsetF32:
+ return R"(textureGather(component, texture, sampler, coords, array_index, offset))";
+ case ValidTextureOverload::kGatherCubeF32:
+ return R"(textureGather(component, texture, sampler, coords))";
+ case ValidTextureOverload::kGatherCubeArrayF32:
+ return R"(textureGather(component, texture, sampler, coords, array_index))";
+ case ValidTextureOverload::kGatherDepth2dF32:
+ return R"(textureGather(texture, sampler, coords))";
+ case ValidTextureOverload::kGatherDepth2dOffsetF32:
+ return R"(textureGather(texture, sampler, coords, offset))";
+ case ValidTextureOverload::kGatherDepth2dArrayF32:
+ return R"(textureGather(texture, sampler, coords, array_index))";
+ case ValidTextureOverload::kGatherDepth2dArrayOffsetF32:
+ return R"(textureGather(texture, sampler, coords, array_index, offset))";
+ case ValidTextureOverload::kGatherDepthCubeF32:
+ return R"(textureGather(texture, sampler, coords))";
+ case ValidTextureOverload::kGatherDepthCubeArrayF32:
+ return R"(textureGather(texture, sampler, coords, array_index))";
+ case ValidTextureOverload::kGatherCompareDepth2dF32:
+ return R"(textureGatherCompare(texture, sampler, coords, depth_ref))";
+ case ValidTextureOverload::kGatherCompareDepth2dOffsetF32:
+ return R"(textureGatherCompare(texture, sampler, coords, depth_ref, offset))";
+ case ValidTextureOverload::kGatherCompareDepth2dArrayF32:
+ return R"(textureGatherCompare(texture, sampler, coords, array_index, depth_ref))";
+ case ValidTextureOverload::kGatherCompareDepth2dArrayOffsetF32:
+ return R"(textureGatherCompare(texture, sampler, coords, array_index, depth_ref, offset))";
+ case ValidTextureOverload::kGatherCompareDepthCubeF32:
+ return R"(textureGatherCompare(texture, sampler, coords, depth_ref))";
+ case ValidTextureOverload::kGatherCompareDepthCubeArrayF32:
+ return R"(textureGatherCompare(texture, sampler, coords, array_index, depth_ref))";
+ case ValidTextureOverload::kNumLayers2dArray:
+ case ValidTextureOverload::kNumLayersCubeArray:
+ case ValidTextureOverload::kNumLayersDepth2dArray:
+ case ValidTextureOverload::kNumLayersDepthCubeArray:
+ case ValidTextureOverload::kNumLayersStorageWO2dArray:
+ return R"(textureNumLayers(texture))";
+ case ValidTextureOverload::kNumLevels2d:
+ case ValidTextureOverload::kNumLevels2dArray:
+ case ValidTextureOverload::kNumLevels3d:
+ case ValidTextureOverload::kNumLevelsCube:
+ case ValidTextureOverload::kNumLevelsCubeArray:
+ case ValidTextureOverload::kNumLevelsDepth2d:
+ case ValidTextureOverload::kNumLevelsDepth2dArray:
+ case ValidTextureOverload::kNumLevelsDepthCube:
+ case ValidTextureOverload::kNumLevelsDepthCubeArray:
+ return R"(textureNumLevels(texture))";
+ case ValidTextureOverload::kNumSamplesDepthMultisampled2d:
+ case ValidTextureOverload::kNumSamplesMultisampled2d:
+ return R"(textureNumSamples(texture))";
+ case ValidTextureOverload::kDimensions2dLevel:
+ case ValidTextureOverload::kDimensions2dArrayLevel:
+ case ValidTextureOverload::kDimensions3dLevel:
+ case ValidTextureOverload::kDimensionsCubeLevel:
+ case ValidTextureOverload::kDimensionsCubeArrayLevel:
+ case ValidTextureOverload::kDimensionsDepth2dLevel:
+ case ValidTextureOverload::kDimensionsDepth2dArrayLevel:
+ case ValidTextureOverload::kDimensionsDepthCubeLevel:
+ case ValidTextureOverload::kDimensionsDepthCubeArrayLevel:
+ return R"(textureDimensions(texture, level))";
+ case ValidTextureOverload::kSample1dF32:
+ return R"(textureSample(texture, sampler, coords))";
+ case ValidTextureOverload::kSample2dF32:
+ return R"(textureSample(texture, sampler, coords))";
+ case ValidTextureOverload::kSample2dOffsetF32:
+ return R"(textureSample(texture, sampler, coords, offset))";
+ case ValidTextureOverload::kSample2dArrayF32:
+ return R"(textureSample(texture, sampler, coords, array_index))";
+ case ValidTextureOverload::kSample2dArrayOffsetF32:
+ return R"(textureSample(texture, sampler, coords, array_index, offset))";
+ case ValidTextureOverload::kSample3dF32:
+ return R"(textureSample(texture, sampler, coords))";
+ case ValidTextureOverload::kSample3dOffsetF32:
+ return R"(textureSample(texture, sampler, coords, offset))";
+ case ValidTextureOverload::kSampleCubeF32:
+ return R"(textureSample(texture, sampler, coords))";
+ case ValidTextureOverload::kSampleCubeArrayF32:
+ return R"(textureSample(texture, sampler, coords, array_index))";
+ case ValidTextureOverload::kSampleDepth2dF32:
+ return R"(textureSample(texture, sampler, coords))";
+ case ValidTextureOverload::kSampleDepth2dOffsetF32:
+ return R"(textureSample(texture, sampler, coords, offset))";
+ case ValidTextureOverload::kSampleDepth2dArrayF32:
+ return R"(textureSample(texture, sampler, coords, array_index))";
+ case ValidTextureOverload::kSampleDepth2dArrayOffsetF32:
+ return R"(textureSample(texture, sampler, coords, array_index, offset))";
+ case ValidTextureOverload::kSampleDepthCubeF32:
+ return R"(textureSample(texture, sampler, coords))";
+ case ValidTextureOverload::kSampleDepthCubeArrayF32:
+ return R"(textureSample(texture, sampler, coords, array_index))";
+ case ValidTextureOverload::kSampleBias2dF32:
+ return R"(textureSampleBias(texture, sampler, coords, bias))";
+ case ValidTextureOverload::kSampleBias2dOffsetF32:
+ return R"(textureSampleBias(texture, sampler, coords, bias, offset))";
+ case ValidTextureOverload::kSampleBias2dArrayF32:
+ return R"(textureSampleBias(texture, sampler, coords, array_index, bias))";
+ case ValidTextureOverload::kSampleBias2dArrayOffsetF32:
+ return R"(textureSampleBias(texture, sampler, coords, array_index, bias, offset))";
+ case ValidTextureOverload::kSampleBias3dF32:
+ return R"(textureSampleBias(texture, sampler, coords, bias))";
+ case ValidTextureOverload::kSampleBias3dOffsetF32:
+ return R"(textureSampleBias(texture, sampler, coords, bias, offset))";
+ case ValidTextureOverload::kSampleBiasCubeF32:
+ return R"(textureSampleBias(texture, sampler, coords, bias))";
+ case ValidTextureOverload::kSampleBiasCubeArrayF32:
+ return R"(textureSampleBias(texture, sampler, coords, array_index, bias))";
+ case ValidTextureOverload::kSampleLevel2dF32:
+ return R"(textureSampleLevel(texture, sampler, coords, level))";
+ case ValidTextureOverload::kSampleLevel2dOffsetF32:
+ return R"(textureSampleLevel(texture, sampler, coords, level, offset))";
+ case ValidTextureOverload::kSampleLevel2dArrayF32:
+ return R"(textureSampleLevel(texture, sampler, coords, array_index, level))";
+ case ValidTextureOverload::kSampleLevel2dArrayOffsetF32:
+ return R"(textureSampleLevel(texture, sampler, coords, array_index, level, offset))";
+ case ValidTextureOverload::kSampleLevel3dF32:
+ return R"(textureSampleLevel(texture, sampler, coords, level))";
+ case ValidTextureOverload::kSampleLevel3dOffsetF32:
+ return R"(textureSampleLevel(texture, sampler, coords, level, offset))";
+ case ValidTextureOverload::kSampleLevelCubeF32:
+ return R"(textureSampleLevel(texture, sampler, coords, level))";
+ case ValidTextureOverload::kSampleLevelCubeArrayF32:
+ return R"(textureSampleLevel(texture, sampler, coords, array_index, level))";
+ case ValidTextureOverload::kSampleLevelDepth2dF32:
+ return R"(textureSampleLevel(texture, sampler, coords, level))";
+ case ValidTextureOverload::kSampleLevelDepth2dOffsetF32:
+ return R"(textureSampleLevel(texture, sampler, coords, level, offset))";
+ case ValidTextureOverload::kSampleLevelDepth2dArrayF32:
+ return R"(textureSampleLevel(texture, sampler, coords, array_index, level))";
+ case ValidTextureOverload::kSampleLevelDepth2dArrayOffsetF32:
+ return R"(textureSampleLevel(texture, sampler, coords, array_index, level, offset))";
+ case ValidTextureOverload::kSampleLevelDepthCubeF32:
+ return R"(textureSampleLevel(texture, sampler, coords, level))";
+ case ValidTextureOverload::kSampleLevelDepthCubeArrayF32:
+ return R"(textureSampleLevel(texture, sampler, coords, array_index, level))";
+ case ValidTextureOverload::kSampleGrad2dF32:
+ return R"(textureSampleGrad(texture, sampler, coords, ddx, ddy))";
+ case ValidTextureOverload::kSampleGrad2dOffsetF32:
+ return R"(textureSampleGrad(texture, sampler, coords, ddx, ddy, offset))";
+ case ValidTextureOverload::kSampleGrad2dArrayF32:
+ return R"(textureSampleGrad(texture, sampler, coords, array_index, ddx, ddy))";
+ case ValidTextureOverload::kSampleGrad2dArrayOffsetF32:
+ return R"(textureSampleGrad(texture, sampler, coords, array_index, ddx, ddy, offset))";
+ case ValidTextureOverload::kSampleGrad3dF32:
+ return R"(textureSampleGrad(texture, sampler, coords, ddx, ddy))";
+ case ValidTextureOverload::kSampleGrad3dOffsetF32:
+ return R"(textureSampleGrad(texture, sampler, coords, ddx, ddy, offset))";
+ case ValidTextureOverload::kSampleGradCubeF32:
+ return R"(textureSampleGrad(texture, sampler, coords, ddx, ddy))";
+ case ValidTextureOverload::kSampleGradCubeArrayF32:
+ return R"(textureSampleGrad(texture, sampler, coords, array_index, ddx, ddy))";
+ case ValidTextureOverload::kSampleCompareDepth2dF32:
+ return R"(textureSampleCompare(texture, sampler, coords, depth_ref))";
+ case ValidTextureOverload::kSampleCompareDepth2dOffsetF32:
+ return R"(textureSampleCompare(texture, sampler, coords, depth_ref, offset))";
+ case ValidTextureOverload::kSampleCompareDepth2dArrayF32:
+ return R"(textureSampleCompare(texture, sampler, coords, array_index, depth_ref))";
+ case ValidTextureOverload::kSampleCompareDepth2dArrayOffsetF32:
+ return R"(textureSampleCompare(texture, sampler, coords, array_index, depth_ref, offset))";
+ case ValidTextureOverload::kSampleCompareDepthCubeF32:
+ return R"(textureSampleCompare(texture, sampler, coords, depth_ref))";
+ case ValidTextureOverload::kSampleCompareDepthCubeArrayF32:
+ return R"(textureSampleCompare(texture, sampler, coords, array_index, depth_ref))";
+ case ValidTextureOverload::kSampleCompareLevelDepth2dF32:
+ return R"(textureSampleCompare(texture, sampler, coords, depth_ref))";
+ case ValidTextureOverload::kSampleCompareLevelDepth2dOffsetF32:
+ return R"(textureSampleCompare(texture, sampler, coords, depth_ref, offset))";
+ case ValidTextureOverload::kSampleCompareLevelDepth2dArrayF32:
+ return R"(textureSampleCompare(texture, sampler, coords, array_index, depth_ref))";
+ case ValidTextureOverload::kSampleCompareLevelDepth2dArrayOffsetF32:
+ return R"(textureSampleCompare(texture, sampler, coords, array_index, depth_ref, offset))";
+ case ValidTextureOverload::kSampleCompareLevelDepthCubeF32:
+ return R"(textureSampleCompare(texture, sampler, coords, depth_ref))";
+ case ValidTextureOverload::kSampleCompareLevelDepthCubeArrayF32:
+ return R"(textureSampleCompare(texture, sampler, coords, array_index, depth_ref))";
+ case ValidTextureOverload::kLoad1dLevelF32:
+ case ValidTextureOverload::kLoad1dLevelU32:
+ case ValidTextureOverload::kLoad1dLevelI32:
+ case ValidTextureOverload::kLoad2dLevelF32:
+ case ValidTextureOverload::kLoad2dLevelU32:
+ case ValidTextureOverload::kLoad2dLevelI32:
+ return R"(textureLoad(texture, coords, level))";
+ case ValidTextureOverload::kLoad2dArrayLevelF32:
+ case ValidTextureOverload::kLoad2dArrayLevelU32:
+ case ValidTextureOverload::kLoad2dArrayLevelI32:
+ return R"(textureLoad(texture, coords, array_index, level))";
+ case ValidTextureOverload::kLoad3dLevelF32:
+ case ValidTextureOverload::kLoad3dLevelU32:
+ case ValidTextureOverload::kLoad3dLevelI32:
+ case ValidTextureOverload::kLoadDepth2dLevelF32:
+ return R"(textureLoad(texture, coords, level))";
+ case ValidTextureOverload::kLoadDepthMultisampled2dF32:
+ case ValidTextureOverload::kLoadMultisampled2dF32:
+ case ValidTextureOverload::kLoadMultisampled2dU32:
+ case ValidTextureOverload::kLoadMultisampled2dI32:
+ return R"(textureLoad(texture, coords, sample_index))";
+ case ValidTextureOverload::kLoadDepth2dArrayLevelF32:
+ return R"(textureLoad(texture, coords, array_index, level))";
+ case ValidTextureOverload::kStoreWO1dRgba32float:
+ case ValidTextureOverload::kStoreWO2dRgba32float:
+ case ValidTextureOverload::kStoreWO3dRgba32float:
+ return R"(textureStore(texture, coords, value))";
+ case ValidTextureOverload::kStoreWO2dArrayRgba32float:
+ return R"(textureStore(texture, coords, array_index, value))";
+ }
+ return "<unmatched texture overload>";
+}
+
+TEST_P(ResolverBuiltinTest_Texture, Call) {
+ auto param = GetParam();
+
+ param.BuildTextureVariable(this);
+ param.BuildSamplerVariable(this);
+
+ auto* call = Call(param.function, param.args(this));
+ auto* stmt = CallStmt(call);
+ Func("func", {}, ty.void_(), {stmt}, {Stage(ast::PipelineStage::kFragment)});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ if (std::string(param.function) == "textureDimensions") {
+ switch (param.texture_dimension) {
+ default:
+ FAIL() << "invalid texture dimensions: " << param.texture_dimension;
+ case ast::TextureDimension::k1d:
+ EXPECT_TRUE(TypeOf(call)->Is<sem::I32>());
+ break;
+ case ast::TextureDimension::k2d:
+ case ast::TextureDimension::k2dArray:
+ case ast::TextureDimension::kCube:
+ case ast::TextureDimension::kCubeArray: {
+ auto* vec = As<sem::Vector>(TypeOf(call));
+ ASSERT_NE(vec, nullptr);
+ EXPECT_EQ(vec->Width(), 2u);
+ EXPECT_TRUE(vec->type()->Is<sem::I32>());
+ break;
+ }
+ case ast::TextureDimension::k3d: {
+ auto* vec = As<sem::Vector>(TypeOf(call));
+ ASSERT_NE(vec, nullptr);
+ EXPECT_EQ(vec->Width(), 3u);
+ EXPECT_TRUE(vec->type()->Is<sem::I32>());
+ break;
+ }
+ }
+ } else if (std::string(param.function) == "textureNumLayers") {
+ EXPECT_TRUE(TypeOf(call)->Is<sem::I32>());
+ } else if (std::string(param.function) == "textureNumLevels") {
+ EXPECT_TRUE(TypeOf(call)->Is<sem::I32>());
+ } else if (std::string(param.function) == "textureNumSamples") {
+ EXPECT_TRUE(TypeOf(call)->Is<sem::I32>());
+ } else if (std::string(param.function) == "textureStore") {
+ EXPECT_TRUE(TypeOf(call)->Is<sem::Void>());
+ } else if (std::string(param.function) == "textureGather") {
+ auto* vec = As<sem::Vector>(TypeOf(call));
+ ASSERT_NE(vec, nullptr);
+ EXPECT_EQ(vec->Width(), 4u);
+ switch (param.texture_data_type) {
+ case ast::builtin::test::TextureDataType::kF32:
+ EXPECT_TRUE(vec->type()->Is<sem::F32>());
+ break;
+ case ast::builtin::test::TextureDataType::kU32:
+ EXPECT_TRUE(vec->type()->Is<sem::U32>());
+ break;
+ case ast::builtin::test::TextureDataType::kI32:
+ EXPECT_TRUE(vec->type()->Is<sem::I32>());
+ break;
+ }
+ } else if (std::string(param.function) == "textureGatherCompare") {
+ auto* vec = As<sem::Vector>(TypeOf(call));
+ ASSERT_NE(vec, nullptr);
+ EXPECT_EQ(vec->Width(), 4u);
+ EXPECT_TRUE(vec->type()->Is<sem::F32>());
+ } else {
+ switch (param.texture_kind) {
+ case ast::builtin::test::TextureKind::kRegular:
+ case ast::builtin::test::TextureKind::kMultisampled:
+ case ast::builtin::test::TextureKind::kStorage: {
+ auto* vec = TypeOf(call)->As<sem::Vector>();
+ ASSERT_NE(vec, nullptr);
+ switch (param.texture_data_type) {
+ case ast::builtin::test::TextureDataType::kF32:
+ EXPECT_TRUE(vec->type()->Is<sem::F32>());
+ break;
+ case ast::builtin::test::TextureDataType::kU32:
+ EXPECT_TRUE(vec->type()->Is<sem::U32>());
+ break;
+ case ast::builtin::test::TextureDataType::kI32:
+ EXPECT_TRUE(vec->type()->Is<sem::I32>());
+ break;
+ }
+ break;
+ }
+ case ast::builtin::test::TextureKind::kDepth:
+ case ast::builtin::test::TextureKind::kDepthMultisampled: {
+ EXPECT_TRUE(TypeOf(call)->Is<sem::F32>());
+ break;
+ }
+ }
+ }
+
+ auto* call_sem = Sem().Get(call);
+ ASSERT_NE(call_sem, nullptr);
+ auto* target = call_sem->Target();
+ ASSERT_NE(target, nullptr);
+
+ auto got = resolver::to_str(param.function, target->Parameters());
+ auto* expected = expected_texture_overload(param.overload);
+ EXPECT_EQ(got, expected);
+}
+
+} // namespace
+} // namespace resolver
+} // namespace tint
diff --git a/src/tint/resolver/builtin_validation_test.cc b/src/tint/resolver/builtin_validation_test.cc
new file mode 100644
index 0000000..0ed2f50
--- /dev/null
+++ b/src/tint/resolver/builtin_validation_test.cc
@@ -0,0 +1,402 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ast/builtin_texture_helper_test.h"
+#include "src/tint/resolver/resolver_test_helper.h"
+
+namespace tint {
+namespace resolver {
+namespace {
+
+using ResolverBuiltinValidationTest = ResolverTest;
+
+TEST_F(ResolverBuiltinValidationTest,
+ FunctionTypeMustMatchReturnStatementType_void_fail) {
+ // fn func { return workgroupBarrier(); }
+ Func("func", {}, ty.void_(),
+ {
+ Return(Call(Source{Source::Location{12, 34}}, "workgroupBarrier")),
+ });
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: builtin 'workgroupBarrier' does not return a value");
+}
+
+TEST_F(ResolverBuiltinValidationTest, InvalidPipelineStageDirect) {
+ // @stage(compute) @workgroup_size(1) fn func { return dpdx(1.0); }
+
+ auto* dpdx = create<ast::CallExpression>(Source{{3, 4}}, Expr("dpdx"),
+ ast::ExpressionList{Expr(1.0f)});
+ Func(Source{{1, 2}}, "func", ast::VariableList{}, ty.void_(),
+ {CallStmt(dpdx)},
+ {Stage(ast::PipelineStage::kCompute), WorkgroupSize(1)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "3:4 error: built-in cannot be used by compute pipeline stage");
+}
+
+TEST_F(ResolverBuiltinValidationTest, InvalidPipelineStageIndirect) {
+ // fn f0 { return dpdx(1.0); }
+ // fn f1 { f0(); }
+ // fn f2 { f1(); }
+ // @stage(compute) @workgroup_size(1) fn main { return f2(); }
+
+ auto* dpdx = create<ast::CallExpression>(Source{{3, 4}}, Expr("dpdx"),
+ ast::ExpressionList{Expr(1.0f)});
+ Func(Source{{1, 2}}, "f0", {}, ty.void_(), {CallStmt(dpdx)});
+
+ Func(Source{{3, 4}}, "f1", {}, ty.void_(), {CallStmt(Call("f0"))});
+
+ Func(Source{{5, 6}}, "f2", {}, ty.void_(), {CallStmt(Call("f1"))});
+
+ Func(Source{{7, 8}}, "main", {}, ty.void_(), {CallStmt(Call("f2"))},
+ {Stage(ast::PipelineStage::kCompute), WorkgroupSize(1)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(3:4 error: built-in cannot be used by compute pipeline stage
+1:2 note: called by function 'f0'
+3:4 note: called by function 'f1'
+5:6 note: called by function 'f2'
+7:8 note: called by entry point 'main')");
+}
+
+TEST_F(ResolverBuiltinValidationTest, BuiltinRedeclaredAsFunction) {
+ Func(Source{{12, 34}}, "mix", {}, ty.i32(), {});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: 'mix' is a builtin and cannot be redeclared as a function)");
+}
+
+TEST_F(ResolverBuiltinValidationTest, BuiltinRedeclaredAsGlobalLet) {
+ GlobalConst(Source{{12, 34}}, "mix", ty.i32(), Expr(1));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: 'mix' is a builtin and cannot be redeclared as a module-scope let)");
+}
+
+TEST_F(ResolverBuiltinValidationTest, BuiltinRedeclaredAsGlobalVar) {
+ Global(Source{{12, 34}}, "mix", ty.i32(), Expr(1),
+ ast::StorageClass::kPrivate);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: 'mix' is a builtin and cannot be redeclared as a module-scope var)");
+}
+
+TEST_F(ResolverBuiltinValidationTest, BuiltinRedeclaredAsAlias) {
+ Alias(Source{{12, 34}}, "mix", ty.i32());
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: 'mix' is a builtin and cannot be redeclared as an alias)");
+}
+
+TEST_F(ResolverBuiltinValidationTest, BuiltinRedeclaredAsStruct) {
+ Structure(Source{{12, 34}}, "mix", {Member("m", ty.i32())});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: 'mix' is a builtin and cannot be redeclared as a struct)");
+}
+
+namespace texture_constexpr_args {
+
+using TextureOverloadCase = ast::builtin::test::TextureOverloadCase;
+using ValidTextureOverload = ast::builtin::test::ValidTextureOverload;
+using TextureKind = ast::builtin::test::TextureKind;
+using TextureDataType = ast::builtin::test::TextureDataType;
+using u32 = ProgramBuilder::u32;
+using i32 = ProgramBuilder::i32;
+using f32 = ProgramBuilder::f32;
+
+static std::vector<TextureOverloadCase> TextureCases(
+ std::unordered_set<ValidTextureOverload> overloads) {
+ std::vector<TextureOverloadCase> cases;
+ for (auto c : TextureOverloadCase::ValidCases()) {
+ if (overloads.count(c.overload)) {
+ cases.push_back(c);
+ }
+ }
+ return cases;
+}
+
+enum class Position {
+ kFirst,
+ kLast,
+};
+
+struct Parameter {
+ const char* const name;
+ const Position position;
+ int min;
+ int max;
+};
+
+class Constexpr {
+ public:
+ enum class Kind {
+ kScalar,
+ kVec2,
+ kVec3,
+ kVec3_Scalar_Vec2,
+ kVec3_Vec2_Scalar,
+ kEmptyVec2,
+ kEmptyVec3,
+ };
+
+ Constexpr(int32_t invalid_idx,
+ Kind k,
+ int32_t x = 0,
+ int32_t y = 0,
+ int32_t z = 0)
+ : invalid_index(invalid_idx), kind(k), values{x, y, z} {}
+
+ const ast::Expression* operator()(Source src, ProgramBuilder& b) {
+ switch (kind) {
+ case Kind::kScalar:
+ return b.Expr(src, values[0]);
+ case Kind::kVec2:
+ return b.Construct(src, b.ty.vec2<i32>(), values[0], values[1]);
+ case Kind::kVec3:
+ return b.Construct(src, b.ty.vec3<i32>(), values[0], values[1],
+ values[2]);
+ case Kind::kVec3_Scalar_Vec2:
+ return b.Construct(src, b.ty.vec3<i32>(), values[0],
+ b.vec2<i32>(values[1], values[2]));
+ case Kind::kVec3_Vec2_Scalar:
+ return b.Construct(src, b.ty.vec3<i32>(),
+ b.vec2<i32>(values[0], values[1]), values[2]);
+ case Kind::kEmptyVec2:
+ return b.Construct(src, b.ty.vec2<i32>());
+ case Kind::kEmptyVec3:
+ return b.Construct(src, b.ty.vec3<i32>());
+ }
+ return nullptr;
+ }
+
+ static const constexpr int32_t kValid = -1;
+ const int32_t invalid_index; // Expected error value, or kValid
+ const Kind kind;
+ const std::array<int32_t, 3> values;
+};
+
+static std::ostream& operator<<(std::ostream& out, Parameter param) {
+ return out << param.name;
+}
+
+static std::ostream& operator<<(std::ostream& out, Constexpr expr) {
+ switch (expr.kind) {
+ case Constexpr::Kind::kScalar:
+ return out << expr.values[0];
+ case Constexpr::Kind::kVec2:
+ return out << "vec2(" << expr.values[0] << ", " << expr.values[1] << ")";
+ case Constexpr::Kind::kVec3:
+ return out << "vec3(" << expr.values[0] << ", " << expr.values[1] << ", "
+ << expr.values[2] << ")";
+ case Constexpr::Kind::kVec3_Scalar_Vec2:
+ return out << "vec3(" << expr.values[0] << ", vec2(" << expr.values[1]
+ << ", " << expr.values[2] << "))";
+ case Constexpr::Kind::kVec3_Vec2_Scalar:
+ return out << "vec3(vec2(" << expr.values[0] << ", " << expr.values[1]
+ << "), " << expr.values[2] << ")";
+ case Constexpr::Kind::kEmptyVec2:
+ return out << "vec2()";
+ case Constexpr::Kind::kEmptyVec3:
+ return out << "vec3()";
+ }
+ return out;
+}
+
+using BuiltinTextureConstExprArgValidationTest = ResolverTestWithParam<
+ std::tuple<TextureOverloadCase, Parameter, Constexpr>>;
+
+TEST_P(BuiltinTextureConstExprArgValidationTest, Immediate) {
+ auto& p = GetParam();
+ auto overload = std::get<0>(p);
+ auto param = std::get<1>(p);
+ auto expr = std::get<2>(p);
+
+ overload.BuildTextureVariable(this);
+ overload.BuildSamplerVariable(this);
+
+ auto args = overload.args(this);
+ auto*& arg_to_replace =
+ (param.position == Position::kFirst) ? args.front() : args.back();
+
+ // BuildTextureVariable() uses a Literal for scalars, and a CallExpression for
+ // a vector constructor.
+ bool is_vector = arg_to_replace->Is<ast::CallExpression>();
+
+ // Make the expression to be replaced, reachable. This keeps the resolver
+ // happy.
+ WrapInFunction(arg_to_replace);
+
+ arg_to_replace = expr(Source{{12, 34}}, *this);
+
+ // Call the builtin with the constexpr argument replaced
+ Func("func", {}, ty.void_(), {CallStmt(Call(overload.function, args))},
+ {Stage(ast::PipelineStage::kFragment)});
+
+ if (expr.invalid_index == Constexpr::kValid) {
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ } else {
+ EXPECT_FALSE(r()->Resolve());
+ std::stringstream err;
+ if (is_vector) {
+ err << "12:34 error: each component of the " << param.name
+ << " argument must be at least " << param.min << " and at most "
+ << param.max << ". " << param.name << " component "
+ << expr.invalid_index << " is "
+ << std::to_string(expr.values[expr.invalid_index]);
+ } else {
+ err << "12:34 error: the " << param.name << " argument must be at least "
+ << param.min << " and at most " << param.max << ". " << param.name
+ << " is " << std::to_string(expr.values[expr.invalid_index]);
+ }
+ EXPECT_EQ(r()->error(), err.str());
+ }
+}
+
+TEST_P(BuiltinTextureConstExprArgValidationTest, GlobalConst) {
+ auto& p = GetParam();
+ auto overload = std::get<0>(p);
+ auto param = std::get<1>(p);
+ auto expr = std::get<2>(p);
+
+ // Build the global texture and sampler variables
+ overload.BuildTextureVariable(this);
+ overload.BuildSamplerVariable(this);
+
+ // Build the module-scope let 'G' with the offset value
+ GlobalConst("G", nullptr, expr({}, *this));
+
+ auto args = overload.args(this);
+ auto*& arg_to_replace =
+ (param.position == Position::kFirst) ? args.front() : args.back();
+
+ // Make the expression to be replaced, reachable. This keeps the resolver
+ // happy.
+ WrapInFunction(arg_to_replace);
+
+ arg_to_replace = Expr(Source{{12, 34}}, "G");
+
+ // Call the builtin with the constexpr argument replaced
+ Func("func", {}, ty.void_(), {CallStmt(Call(overload.function, args))},
+ {Stage(ast::PipelineStage::kFragment)});
+
+ EXPECT_FALSE(r()->Resolve());
+ std::stringstream err;
+ err << "12:34 error: the " << param.name
+ << " argument must be a const_expression";
+ EXPECT_EQ(r()->error(), err.str());
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ Offset2D,
+ BuiltinTextureConstExprArgValidationTest,
+ testing::Combine(
+ testing::ValuesIn(TextureCases({
+ ValidTextureOverload::kSample2dOffsetF32,
+ ValidTextureOverload::kSample2dArrayOffsetF32,
+ ValidTextureOverload::kSampleDepth2dOffsetF32,
+ ValidTextureOverload::kSampleDepth2dArrayOffsetF32,
+ ValidTextureOverload::kSampleBias2dOffsetF32,
+ ValidTextureOverload::kSampleBias2dArrayOffsetF32,
+ ValidTextureOverload::kSampleLevel2dOffsetF32,
+ ValidTextureOverload::kSampleLevel2dArrayOffsetF32,
+ ValidTextureOverload::kSampleLevelDepth2dOffsetF32,
+ ValidTextureOverload::kSampleLevelDepth2dArrayOffsetF32,
+ ValidTextureOverload::kSampleGrad2dOffsetF32,
+ ValidTextureOverload::kSampleGrad2dArrayOffsetF32,
+ ValidTextureOverload::kSampleCompareDepth2dOffsetF32,
+ ValidTextureOverload::kSampleCompareDepth2dArrayOffsetF32,
+ ValidTextureOverload::kSampleCompareLevelDepth2dOffsetF32,
+ ValidTextureOverload::kSampleCompareLevelDepth2dArrayOffsetF32,
+ })),
+ testing::Values(Parameter{"offset", Position::kLast, -8, 7}),
+ testing::Values(
+ Constexpr{Constexpr::kValid, Constexpr::Kind::kEmptyVec2},
+ Constexpr{Constexpr::kValid, Constexpr::Kind::kVec2, -1, 1},
+ Constexpr{Constexpr::kValid, Constexpr::Kind::kVec2, 7, -8},
+ Constexpr{0, Constexpr::Kind::kVec2, 8, 0},
+ Constexpr{1, Constexpr::Kind::kVec2, 0, 8},
+ Constexpr{0, Constexpr::Kind::kVec2, -9, 0},
+ Constexpr{1, Constexpr::Kind::kVec2, 0, -9},
+ Constexpr{0, Constexpr::Kind::kVec2, 8, 8},
+ Constexpr{0, Constexpr::Kind::kVec2, -9, -9})));
+
+INSTANTIATE_TEST_SUITE_P(
+ Offset3D,
+ BuiltinTextureConstExprArgValidationTest,
+ testing::Combine(
+ testing::ValuesIn(TextureCases({
+ ValidTextureOverload::kSample3dOffsetF32,
+ ValidTextureOverload::kSampleBias3dOffsetF32,
+ ValidTextureOverload::kSampleLevel3dOffsetF32,
+ ValidTextureOverload::kSampleGrad3dOffsetF32,
+ })),
+ testing::Values(Parameter{"offset", Position::kLast, -8, 7}),
+ testing::Values(
+ Constexpr{Constexpr::kValid, Constexpr::Kind::kEmptyVec3},
+ Constexpr{Constexpr::kValid, Constexpr::Kind::kVec3, 0, 0, 0},
+ Constexpr{Constexpr::kValid, Constexpr::Kind::kVec3, 7, -8, 7},
+ Constexpr{0, Constexpr::Kind::kVec3, 10, 0, 0},
+ Constexpr{1, Constexpr::Kind::kVec3, 0, 10, 0},
+ Constexpr{2, Constexpr::Kind::kVec3, 0, 0, 10},
+ Constexpr{0, Constexpr::Kind::kVec3, 10, 11, 12},
+ Constexpr{0, Constexpr::Kind::kVec3_Scalar_Vec2, 10, 0, 0},
+ Constexpr{1, Constexpr::Kind::kVec3_Scalar_Vec2, 0, 10, 0},
+ Constexpr{2, Constexpr::Kind::kVec3_Scalar_Vec2, 0, 0, 10},
+ Constexpr{0, Constexpr::Kind::kVec3_Scalar_Vec2, 10, 11, 12},
+ Constexpr{0, Constexpr::Kind::kVec3_Vec2_Scalar, 10, 0, 0},
+ Constexpr{1, Constexpr::Kind::kVec3_Vec2_Scalar, 0, 10, 0},
+ Constexpr{2, Constexpr::Kind::kVec3_Vec2_Scalar, 0, 0, 10},
+ Constexpr{0, Constexpr::Kind::kVec3_Vec2_Scalar, 10, 11, 12})));
+
+INSTANTIATE_TEST_SUITE_P(
+ Component,
+ BuiltinTextureConstExprArgValidationTest,
+ testing::Combine(
+ testing::ValuesIn(
+ TextureCases({ValidTextureOverload::kGather2dF32,
+ ValidTextureOverload::kGather2dOffsetF32,
+ ValidTextureOverload::kGather2dArrayF32,
+ ValidTextureOverload::kGather2dArrayOffsetF32,
+ ValidTextureOverload::kGatherCubeF32,
+ ValidTextureOverload::kGatherCubeArrayF32})),
+ testing::Values(Parameter{"component", Position::kFirst, 0, 3}),
+ testing::Values(
+ Constexpr{Constexpr::kValid, Constexpr::Kind::kScalar, 0},
+ Constexpr{Constexpr::kValid, Constexpr::Kind::kScalar, 1},
+ Constexpr{Constexpr::kValid, Constexpr::Kind::kScalar, 2},
+ Constexpr{Constexpr::kValid, Constexpr::Kind::kScalar, 3},
+ Constexpr{0, Constexpr::Kind::kScalar, 4},
+ Constexpr{0, Constexpr::Kind::kScalar, 123},
+ Constexpr{0, Constexpr::Kind::kScalar, -1})));
+
+} // namespace texture_constexpr_args
+
+} // namespace
+} // namespace resolver
+} // namespace tint
diff --git a/src/tint/resolver/builtins_validation_test.cc b/src/tint/resolver/builtins_validation_test.cc
new file mode 100644
index 0000000..7fd5bd6
--- /dev/null
+++ b/src/tint/resolver/builtins_validation_test.cc
@@ -0,0 +1,1292 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ast/call_statement.h"
+#include "src/tint/resolver/resolver_test_helper.h"
+
+namespace tint {
+namespace resolver {
+namespace {
+
+template <typename T>
+using DataType = builder::DataType<T>;
+template <typename T>
+using vec2 = builder::vec2<T>;
+template <typename T>
+using vec3 = builder::vec3<T>;
+template <typename T>
+using vec4 = builder::vec4<T>;
+using f32 = builder::f32;
+using i32 = builder::i32;
+using u32 = builder::u32;
+
+class ResolverBuiltinsValidationTest : public resolver::TestHelper,
+ public testing::Test {};
+namespace StageTest {
+struct Params {
+ builder::ast_type_func_ptr type;
+ ast::Builtin builtin;
+ ast::PipelineStage stage;
+ bool is_valid;
+};
+
+template <typename T>
+constexpr Params ParamsFor(ast::Builtin builtin,
+ ast::PipelineStage stage,
+ bool is_valid) {
+ return Params{DataType<T>::AST, builtin, stage, is_valid};
+}
+static constexpr Params cases[] = {
+ ParamsFor<vec4<f32>>(ast::Builtin::kPosition,
+ ast::PipelineStage::kVertex,
+ false),
+ ParamsFor<vec4<f32>>(ast::Builtin::kPosition,
+ ast::PipelineStage::kFragment,
+ true),
+ ParamsFor<vec4<f32>>(ast::Builtin::kPosition,
+ ast::PipelineStage::kCompute,
+ false),
+
+ ParamsFor<u32>(ast::Builtin::kVertexIndex,
+ ast::PipelineStage::kVertex,
+ true),
+ ParamsFor<u32>(ast::Builtin::kVertexIndex,
+ ast::PipelineStage::kFragment,
+ false),
+ ParamsFor<u32>(ast::Builtin::kVertexIndex,
+ ast::PipelineStage::kCompute,
+ false),
+
+ ParamsFor<u32>(ast::Builtin::kInstanceIndex,
+ ast::PipelineStage::kVertex,
+ true),
+ ParamsFor<u32>(ast::Builtin::kInstanceIndex,
+ ast::PipelineStage::kFragment,
+ false),
+ ParamsFor<u32>(ast::Builtin::kInstanceIndex,
+ ast::PipelineStage::kCompute,
+ false),
+
+ ParamsFor<bool>(ast::Builtin::kFrontFacing,
+ ast::PipelineStage::kVertex,
+ false),
+ ParamsFor<bool>(ast::Builtin::kFrontFacing,
+ ast::PipelineStage::kFragment,
+ true),
+ ParamsFor<bool>(ast::Builtin::kFrontFacing,
+ ast::PipelineStage::kCompute,
+ false),
+
+ ParamsFor<vec3<u32>>(ast::Builtin::kLocalInvocationId,
+ ast::PipelineStage::kVertex,
+ false),
+ ParamsFor<vec3<u32>>(ast::Builtin::kLocalInvocationId,
+ ast::PipelineStage::kFragment,
+ false),
+ ParamsFor<vec3<u32>>(ast::Builtin::kLocalInvocationId,
+ ast::PipelineStage::kCompute,
+ true),
+
+ ParamsFor<u32>(ast::Builtin::kLocalInvocationIndex,
+ ast::PipelineStage::kVertex,
+ false),
+ ParamsFor<u32>(ast::Builtin::kLocalInvocationIndex,
+ ast::PipelineStage::kFragment,
+ false),
+ ParamsFor<u32>(ast::Builtin::kLocalInvocationIndex,
+ ast::PipelineStage::kCompute,
+ true),
+
+ ParamsFor<vec3<u32>>(ast::Builtin::kGlobalInvocationId,
+ ast::PipelineStage::kVertex,
+ false),
+ ParamsFor<vec3<u32>>(ast::Builtin::kGlobalInvocationId,
+ ast::PipelineStage::kFragment,
+ false),
+ ParamsFor<vec3<u32>>(ast::Builtin::kGlobalInvocationId,
+ ast::PipelineStage::kCompute,
+ true),
+
+ ParamsFor<vec3<u32>>(ast::Builtin::kWorkgroupId,
+ ast::PipelineStage::kVertex,
+ false),
+ ParamsFor<vec3<u32>>(ast::Builtin::kWorkgroupId,
+ ast::PipelineStage::kFragment,
+ false),
+ ParamsFor<vec3<u32>>(ast::Builtin::kWorkgroupId,
+ ast::PipelineStage::kCompute,
+ true),
+
+ ParamsFor<vec3<u32>>(ast::Builtin::kNumWorkgroups,
+ ast::PipelineStage::kVertex,
+ false),
+ ParamsFor<vec3<u32>>(ast::Builtin::kNumWorkgroups,
+ ast::PipelineStage::kFragment,
+ false),
+ ParamsFor<vec3<u32>>(ast::Builtin::kNumWorkgroups,
+ ast::PipelineStage::kCompute,
+ true),
+
+ ParamsFor<u32>(ast::Builtin::kSampleIndex,
+ ast::PipelineStage::kVertex,
+ false),
+ ParamsFor<u32>(ast::Builtin::kSampleIndex,
+ ast::PipelineStage::kFragment,
+ true),
+ ParamsFor<u32>(ast::Builtin::kSampleIndex,
+ ast::PipelineStage::kCompute,
+ false),
+
+ ParamsFor<u32>(ast::Builtin::kSampleMask,
+ ast::PipelineStage::kVertex,
+ false),
+ ParamsFor<u32>(ast::Builtin::kSampleMask,
+ ast::PipelineStage::kFragment,
+ true),
+ ParamsFor<u32>(ast::Builtin::kSampleMask,
+ ast::PipelineStage::kCompute,
+ false),
+};
+
+using ResolverBuiltinsStageTest = ResolverTestWithParam<Params>;
+TEST_P(ResolverBuiltinsStageTest, All_input) {
+ const Params& params = GetParam();
+
+ auto* p = Global("p", ty.vec4<f32>(), ast::StorageClass::kPrivate);
+ auto* input =
+ Param("input", params.type(*this),
+ ast::AttributeList{Builtin(Source{{12, 34}}, params.builtin)});
+ switch (params.stage) {
+ case ast::PipelineStage::kVertex:
+ Func("main", {input}, ty.vec4<f32>(), {Return(p)},
+ {Stage(ast::PipelineStage::kVertex)},
+ {Builtin(Source{{12, 34}}, ast::Builtin::kPosition)});
+ break;
+ case ast::PipelineStage::kFragment:
+ Func("main", {input}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kFragment)}, {});
+ break;
+ case ast::PipelineStage::kCompute:
+ Func("main", {input}, ty.void_(), {},
+ ast::AttributeList{Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(1)});
+ break;
+ default:
+ break;
+ }
+
+ if (params.is_valid) {
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ } else {
+ std::stringstream err;
+ err << "12:34 error: builtin(" << params.builtin << ")";
+ err << " cannot be used in input of " << params.stage << " pipeline stage";
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), err.str());
+ }
+}
+INSTANTIATE_TEST_SUITE_P(ResolverBuiltinsValidationTest,
+ ResolverBuiltinsStageTest,
+ testing::ValuesIn(cases));
+
+TEST_F(ResolverBuiltinsValidationTest, FragDepthIsInput_Fail) {
+ // @stage(fragment)
+ // fn fs_main(
+ // @builtin(frag_depth) fd: f32,
+ // ) -> @location(0) f32 { return 1.0; }
+ auto* fd = Param(
+ "fd", ty.f32(),
+ ast::AttributeList{Builtin(Source{{12, 34}}, ast::Builtin::kFragDepth)});
+ Func("fs_main", ast::VariableList{fd}, ty.f32(), {Return(1.0f)},
+ ast::AttributeList{Stage(ast::PipelineStage::kFragment)}, {Location(0)});
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: builtin(frag_depth) cannot be used in input of "
+ "fragment pipeline stage");
+}
+
+TEST_F(ResolverBuiltinsValidationTest, FragDepthIsInputStruct_Fail) {
+ // struct MyInputs {
+ // @builtin(frag_depth) ff: f32;
+ // };
+ // @stage(fragment)
+ // fn fragShader(arg: MyInputs) -> @location(0) f32 { return 1.0; }
+
+ auto* s = Structure(
+ "MyInputs", {Member("frag_depth", ty.f32(),
+ ast::AttributeList{Builtin(
+ Source{{12, 34}}, ast::Builtin::kFragDepth)})});
+
+ Func("fragShader", {Param("arg", ty.Of(s))}, ty.f32(), {Return(1.0f)},
+ {Stage(ast::PipelineStage::kFragment)}, {Location(0)});
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: builtin(frag_depth) cannot be used in input of "
+ "fragment pipeline stage\n"
+ "note: while analysing entry point 'fragShader'");
+}
+
+TEST_F(ResolverBuiltinsValidationTest, StructBuiltinInsideEntryPoint_Ignored) {
+ // struct S {
+ // @builtin(vertex_index) idx: u32;
+ // };
+ // @stage(fragment)
+ // fn fragShader() { var s : S; }
+
+ Structure("S",
+ {Member("idx", ty.u32(), {Builtin(ast::Builtin::kVertexIndex)})});
+
+ Func("fragShader", {}, ty.void_(), {Decl(Var("s", ty.type_name("S")))},
+ {Stage(ast::PipelineStage::kFragment)});
+ EXPECT_TRUE(r()->Resolve());
+}
+
+} // namespace StageTest
+
+TEST_F(ResolverBuiltinsValidationTest, PositionNotF32_Struct_Fail) {
+ // struct MyInputs {
+ // @builtin(kPosition) p: vec4<u32>;
+ // };
+ // @stage(fragment)
+ // fn fragShader(is_front: MyInputs) -> @location(0) f32 { return 1.0; }
+
+ auto* m = Member(
+ "position", ty.vec4<u32>(),
+ ast::AttributeList{Builtin(Source{{12, 34}}, ast::Builtin::kPosition)});
+ auto* s = Structure("MyInputs", {m});
+ Func("fragShader", {Param("arg", ty.Of(s))}, ty.f32(), {Return(1.0f)},
+ {Stage(ast::PipelineStage::kFragment)}, {Location(0)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: store type of builtin(position) must be 'vec4<f32>'");
+}
+
+TEST_F(ResolverBuiltinsValidationTest, PositionNotF32_ReturnType_Fail) {
+ // @stage(vertex)
+ // fn main() -> @builtin(position) f32 { return 1.0; }
+ Func("main", {}, ty.f32(), {Return(1.0f)},
+ {Stage(ast::PipelineStage::kVertex)},
+ {Builtin(Source{{12, 34}}, ast::Builtin::kPosition)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: store type of builtin(position) must be 'vec4<f32>'");
+}
+
+TEST_F(ResolverBuiltinsValidationTest, FragDepthNotF32_Struct_Fail) {
+ // struct MyInputs {
+ // @builtin(kFragDepth) p: i32;
+ // };
+ // @stage(fragment)
+ // fn fragShader(is_front: MyInputs) -> @location(0) f32 { return 1.0; }
+
+ auto* m = Member(
+ "frag_depth", ty.i32(),
+ ast::AttributeList{Builtin(Source{{12, 34}}, ast::Builtin::kFragDepth)});
+ auto* s = Structure("MyInputs", {m});
+ Func("fragShader", {Param("arg", ty.Of(s))}, ty.f32(), {Return(1.0f)},
+ {Stage(ast::PipelineStage::kFragment)}, {Location(0)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: store type of builtin(frag_depth) must be 'f32'");
+}
+
+TEST_F(ResolverBuiltinsValidationTest, SampleMaskNotU32_Struct_Fail) {
+ // struct MyInputs {
+ // @builtin(sample_mask) m: f32;
+ // };
+ // @stage(fragment)
+ // fn fragShader(is_front: MyInputs) -> @location(0) f32 { return 1.0; }
+
+ auto* s = Structure(
+ "MyInputs", {Member("m", ty.f32(),
+ ast::AttributeList{Builtin(
+ Source{{12, 34}}, ast::Builtin::kSampleMask)})});
+ Func("fragShader", {Param("arg", ty.Of(s))}, ty.f32(), {Return(1.0f)},
+ {Stage(ast::PipelineStage::kFragment)}, {Location(0)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: store type of builtin(sample_mask) must be 'u32'");
+}
+
+TEST_F(ResolverBuiltinsValidationTest, SampleMaskNotU32_ReturnType_Fail) {
+ // @stage(fragment)
+ // fn main() -> @builtin(sample_mask) i32 { return 1; }
+ Func("main", {}, ty.i32(), {Return(1)},
+ {Stage(ast::PipelineStage::kFragment)},
+ {Builtin(Source{{12, 34}}, ast::Builtin::kSampleMask)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: store type of builtin(sample_mask) must be 'u32'");
+}
+
+TEST_F(ResolverBuiltinsValidationTest, SampleMaskIsNotU32_Fail) {
+ // @stage(fragment)
+ // fn fs_main(
+ // @builtin(sample_mask) arg: bool
+ // ) -> @location(0) f32 { return 1.0; }
+ auto* arg = Param(
+ "arg", ty.bool_(),
+ ast::AttributeList{Builtin(Source{{12, 34}}, ast::Builtin::kSampleMask)});
+ Func("fs_main", ast::VariableList{arg}, ty.f32(), {Return(1.0f)},
+ ast::AttributeList{Stage(ast::PipelineStage::kFragment)}, {Location(0)});
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: store type of builtin(sample_mask) must be 'u32'");
+}
+
+TEST_F(ResolverBuiltinsValidationTest, SampleIndexIsNotU32_Struct_Fail) {
+ // struct MyInputs {
+ // @builtin(sample_index) m: f32;
+ // };
+ // @stage(fragment)
+ // fn fragShader(is_front: MyInputs) -> @location(0) f32 { return 1.0; }
+
+ auto* s = Structure(
+ "MyInputs", {Member("m", ty.f32(),
+ ast::AttributeList{Builtin(
+ Source{{12, 34}}, ast::Builtin::kSampleIndex)})});
+ Func("fragShader", {Param("arg", ty.Of(s))}, ty.f32(), {Return(1.0f)},
+ {Stage(ast::PipelineStage::kFragment)}, {Location(0)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: store type of builtin(sample_index) must be 'u32'");
+}
+
+TEST_F(ResolverBuiltinsValidationTest, SampleIndexIsNotU32_Fail) {
+ // @stage(fragment)
+ // fn fs_main(
+ // @builtin(sample_index) arg: bool
+ // ) -> @location(0) f32 { return 1.0; }
+ auto* arg = Param("arg", ty.bool_(),
+ ast::AttributeList{
+ Builtin(Source{{12, 34}}, ast::Builtin::kSampleIndex)});
+ Func("fs_main", ast::VariableList{arg}, ty.f32(), {Return(1.0f)},
+ ast::AttributeList{Stage(ast::PipelineStage::kFragment)}, {Location(0)});
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: store type of builtin(sample_index) must be 'u32'");
+}
+
+TEST_F(ResolverBuiltinsValidationTest, PositionIsNotF32_Fail) {
+ // @stage(fragment)
+ // fn fs_main(
+ // @builtin(kPosition) p: vec3<f32>,
+ // ) -> @location(0) f32 { return 1.0; }
+ auto* p = Param(
+ "p", ty.vec3<f32>(),
+ ast::AttributeList{Builtin(Source{{12, 34}}, ast::Builtin::kPosition)});
+ Func("fs_main", ast::VariableList{p}, ty.f32(), {Return(1.0f)},
+ ast::AttributeList{Stage(ast::PipelineStage::kFragment)}, {Location(0)});
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: store type of builtin(position) must be 'vec4<f32>'");
+}
+
+TEST_F(ResolverBuiltinsValidationTest, FragDepthIsNotF32_Fail) {
+ // @stage(fragment)
+ // fn fs_main() -> @builtin(kFragDepth) f32 { var fd: i32; return fd; }
+ auto* fd = Var("fd", ty.i32());
+ Func("fs_main", {}, ty.i32(), {Decl(fd), Return(fd)},
+ ast::AttributeList{Stage(ast::PipelineStage::kFragment)},
+ ast::AttributeList{Builtin(Source{{12, 34}}, ast::Builtin::kFragDepth)});
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: store type of builtin(frag_depth) must be 'f32'");
+}
+
+TEST_F(ResolverBuiltinsValidationTest, VertexIndexIsNotU32_Fail) {
+ // @stage(vertex)
+ // fn main(
+ // @builtin(kVertexIndex) vi : f32,
+ // @builtin(kPosition) p :vec4<f32>
+ // ) -> @builtin(kPosition) vec4<f32> { return vec4<f32>(); }
+ auto* p = Param("p", ty.vec4<f32>(),
+ ast::AttributeList{Builtin(ast::Builtin::kPosition)});
+ auto* vi = Param("vi", ty.f32(),
+ ast::AttributeList{
+ Builtin(Source{{12, 34}}, ast::Builtin::kVertexIndex)});
+ Func("main", ast::VariableList{vi, p}, ty.vec4<f32>(), {Return(Expr("p"))},
+ ast::AttributeList{Stage(ast::PipelineStage::kVertex)},
+ ast::AttributeList{Builtin(ast::Builtin::kPosition)});
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: store type of builtin(vertex_index) must be 'u32'");
+}
+
+TEST_F(ResolverBuiltinsValidationTest, InstanceIndexIsNotU32) {
+ // @stage(vertex)
+ // fn main(
+ // @builtin(kInstanceIndex) ii : f32,
+ // @builtin(kPosition) p :vec4<f32>
+ // ) -> @builtin(kPosition) vec4<f32> { return vec4<f32>(); }
+ auto* p = Param("p", ty.vec4<f32>(),
+ ast::AttributeList{Builtin(ast::Builtin::kPosition)});
+ auto* ii = Param("ii", ty.f32(),
+ ast::AttributeList{Builtin(Source{{12, 34}},
+ ast::Builtin::kInstanceIndex)});
+ Func("main", ast::VariableList{ii, p}, ty.vec4<f32>(), {Return(Expr("p"))},
+ ast::AttributeList{Stage(ast::PipelineStage::kVertex)},
+ ast::AttributeList{Builtin(ast::Builtin::kPosition)});
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: store type of builtin(instance_index) must be 'u32'");
+}
+
+TEST_F(ResolverBuiltinsValidationTest, FragmentBuiltin_Pass) {
+ // @stage(fragment)
+ // fn fs_main(
+ // @builtin(kPosition) p: vec4<f32>,
+ // @builtin(front_facing) ff: bool,
+ // @builtin(sample_index) si: u32,
+ // @builtin(sample_mask) sm : u32
+ // ) -> @builtin(frag_depth) f32 { var fd: f32; return fd; }
+ auto* p = Param("p", ty.vec4<f32>(),
+ ast::AttributeList{Builtin(ast::Builtin::kPosition)});
+ auto* ff = Param("ff", ty.bool_(),
+ ast::AttributeList{Builtin(ast::Builtin::kFrontFacing)});
+ auto* si = Param("si", ty.u32(),
+ ast::AttributeList{Builtin(ast::Builtin::kSampleIndex)});
+ auto* sm = Param("sm", ty.u32(),
+ ast::AttributeList{Builtin(ast::Builtin::kSampleMask)});
+ auto* var_fd = Var("fd", ty.f32());
+ Func("fs_main", ast::VariableList{p, ff, si, sm}, ty.f32(),
+ {Decl(var_fd), Return(var_fd)},
+ ast::AttributeList{Stage(ast::PipelineStage::kFragment)},
+ ast::AttributeList{Builtin(ast::Builtin::kFragDepth)});
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverBuiltinsValidationTest, VertexBuiltin_Pass) {
+ // @stage(vertex)
+ // fn main(
+ // @builtin(vertex_index) vi : u32,
+ // @builtin(instance_index) ii : u32,
+ // ) -> @builtin(position) vec4<f32> { var p :vec4<f32>; return p; }
+ auto* vi = Param("vi", ty.u32(),
+ ast::AttributeList{
+ Builtin(Source{{12, 34}}, ast::Builtin::kVertexIndex)});
+
+ auto* ii = Param("ii", ty.u32(),
+ ast::AttributeList{Builtin(Source{{12, 34}},
+ ast::Builtin::kInstanceIndex)});
+ auto* p = Var("p", ty.vec4<f32>());
+ Func("main", ast::VariableList{vi, ii}, ty.vec4<f32>(),
+ {
+ Decl(p),
+ Return(p),
+ },
+ ast::AttributeList{Stage(ast::PipelineStage::kVertex)},
+ ast::AttributeList{Builtin(ast::Builtin::kPosition)});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverBuiltinsValidationTest, ComputeBuiltin_Pass) {
+ // @stage(compute) @workgroup_size(1)
+ // fn main(
+ // @builtin(local_invocationId) li_id: vec3<u32>,
+ // @builtin(local_invocationIndex) li_index: u32,
+ // @builtin(global_invocationId) gi: vec3<u32>,
+ // @builtin(workgroup_id) wi: vec3<u32>,
+ // @builtin(num_workgroups) nwgs: vec3<u32>,
+ // ) {}
+
+ auto* li_id =
+ Param("li_id", ty.vec3<u32>(),
+ ast::AttributeList{Builtin(ast::Builtin::kLocalInvocationId)});
+ auto* li_index =
+ Param("li_index", ty.u32(),
+ ast::AttributeList{Builtin(ast::Builtin::kLocalInvocationIndex)});
+ auto* gi =
+ Param("gi", ty.vec3<u32>(),
+ ast::AttributeList{Builtin(ast::Builtin::kGlobalInvocationId)});
+ auto* wi = Param("wi", ty.vec3<u32>(),
+ ast::AttributeList{Builtin(ast::Builtin::kWorkgroupId)});
+ auto* nwgs = Param("nwgs", ty.vec3<u32>(),
+ ast::AttributeList{Builtin(ast::Builtin::kNumWorkgroups)});
+
+ Func("main", ast::VariableList{li_id, li_index, gi, wi, nwgs}, ty.void_(), {},
+ ast::AttributeList{
+ Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(Expr(Source{Source::Location{12, 34}}, 2))});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverBuiltinsValidationTest, ComputeBuiltin_WorkGroupIdNotVec3U32) {
+ auto* wi = Param("wi", ty.f32(),
+ ast::AttributeList{
+ Builtin(Source{{12, 34}}, ast::Builtin::kWorkgroupId)});
+ Func("main", ast::VariableList{wi}, ty.void_(), {},
+ ast::AttributeList{
+ Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(Expr(Source{Source::Location{12, 34}}, 2))});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: store type of builtin(workgroup_id) must be "
+ "'vec3<u32>'");
+}
+
+TEST_F(ResolverBuiltinsValidationTest, ComputeBuiltin_NumWorkgroupsNotVec3U32) {
+ auto* nwgs = Param("nwgs", ty.f32(),
+ ast::AttributeList{Builtin(Source{{12, 34}},
+ ast::Builtin::kNumWorkgroups)});
+ Func("main", ast::VariableList{nwgs}, ty.void_(), {},
+ ast::AttributeList{
+ Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(Expr(Source{Source::Location{12, 34}}, 2))});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: store type of builtin(num_workgroups) must be "
+ "'vec3<u32>'");
+}
+
+TEST_F(ResolverBuiltinsValidationTest,
+ ComputeBuiltin_GlobalInvocationNotVec3U32) {
+ auto* gi = Param("gi", ty.vec3<i32>(),
+ ast::AttributeList{Builtin(
+ Source{{12, 34}}, ast::Builtin::kGlobalInvocationId)});
+ Func("main", ast::VariableList{gi}, ty.void_(), {},
+ ast::AttributeList{
+ Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(Expr(Source{Source::Location{12, 34}}, 2))});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: store type of builtin(global_invocation_id) must be "
+ "'vec3<u32>'");
+}
+
+TEST_F(ResolverBuiltinsValidationTest,
+ ComputeBuiltin_LocalInvocationIndexNotU32) {
+ auto* li_index =
+ Param("li_index", ty.vec3<u32>(),
+ ast::AttributeList{Builtin(Source{{12, 34}},
+ ast::Builtin::kLocalInvocationIndex)});
+ Func("main", ast::VariableList{li_index}, ty.void_(), {},
+ ast::AttributeList{
+ Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(Expr(Source{Source::Location{12, 34}}, 2))});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: store type of builtin(local_invocation_index) must be "
+ "'u32'");
+}
+
+TEST_F(ResolverBuiltinsValidationTest,
+ ComputeBuiltin_LocalInvocationNotVec3U32) {
+ auto* li_id = Param("li_id", ty.vec2<u32>(),
+ ast::AttributeList{Builtin(
+ Source{{12, 34}}, ast::Builtin::kLocalInvocationId)});
+ Func("main", ast::VariableList{li_id}, ty.void_(), {},
+ ast::AttributeList{
+ Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(Expr(Source{Source::Location{12, 34}}, 2))});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: store type of builtin(local_invocation_id) must be "
+ "'vec3<u32>'");
+}
+
+TEST_F(ResolverBuiltinsValidationTest, FragmentBuiltinStruct_Pass) {
+ // Struct MyInputs {
+ // @builtin(kPosition) p: vec4<f32>;
+ // @builtin(frag_depth) fd: f32;
+ // @builtin(sample_index) si: u32;
+ // @builtin(sample_mask) sm : u32;;
+ // };
+ // @stage(fragment)
+ // fn fragShader(arg: MyInputs) -> @location(0) f32 { return 1.0; }
+
+ auto* s = Structure(
+ "MyInputs",
+ {Member("position", ty.vec4<f32>(),
+ ast::AttributeList{Builtin(ast::Builtin::kPosition)}),
+ Member("front_facing", ty.bool_(),
+ ast::AttributeList{Builtin(ast::Builtin::kFrontFacing)}),
+ Member("sample_index", ty.u32(),
+ ast::AttributeList{Builtin(ast::Builtin::kSampleIndex)}),
+ Member("sample_mask", ty.u32(),
+ ast::AttributeList{Builtin(ast::Builtin::kSampleMask)})});
+ Func("fragShader", {Param("arg", ty.Of(s))}, ty.f32(), {Return(1.0f)},
+ {Stage(ast::PipelineStage::kFragment)}, {Location(0)});
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverBuiltinsValidationTest, FrontFacingParamIsNotBool_Fail) {
+ // @stage(fragment)
+ // fn fs_main(
+ // @builtin(front_facing) is_front: i32;
+ // ) -> @location(0) f32 { return 1.0; }
+
+ auto* is_front = Param("is_front", ty.i32(),
+ ast::AttributeList{Builtin(
+ Source{{12, 34}}, ast::Builtin::kFrontFacing)});
+ Func("fs_main", ast::VariableList{is_front}, ty.f32(), {Return(1.0f)},
+ ast::AttributeList{Stage(ast::PipelineStage::kFragment)}, {Location(0)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: store type of builtin(front_facing) must be 'bool'");
+}
+
+TEST_F(ResolverBuiltinsValidationTest, FrontFacingMemberIsNotBool_Fail) {
+ // struct MyInputs {
+ // @builtin(front_facing) pos: f32;
+ // };
+ // @stage(fragment)
+ // fn fragShader(is_front: MyInputs) -> @location(0) f32 { return 1.0; }
+
+ auto* s = Structure(
+ "MyInputs", {Member("pos", ty.f32(),
+ ast::AttributeList{Builtin(
+ Source{{12, 34}}, ast::Builtin::kFrontFacing)})});
+ Func("fragShader", {Param("is_front", ty.Of(s))}, ty.f32(), {Return(1.0f)},
+ {Stage(ast::PipelineStage::kFragment)}, {Location(0)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: store type of builtin(front_facing) must be 'bool'");
+}
+
+TEST_F(ResolverBuiltinsValidationTest, Length_Float_Scalar) {
+ auto* builtin = Call("length", 1.0f);
+ WrapInFunction(builtin);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverBuiltinsValidationTest, Length_Float_Vec2) {
+ auto* builtin = Call("length", vec2<f32>(1.0f, 1.0f));
+ WrapInFunction(builtin);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverBuiltinsValidationTest, Length_Float_Vec3) {
+ auto* builtin = Call("length", vec3<f32>(1.0f, 1.0f, 1.0f));
+ WrapInFunction(builtin);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverBuiltinsValidationTest, Length_Float_Vec4) {
+ auto* builtin = Call("length", vec4<f32>(1.0f, 1.0f, 1.0f, 1.0f));
+ WrapInFunction(builtin);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverBuiltinsValidationTest, Distance_Float_Scalar) {
+ auto* builtin = Call("distance", 1.0f, 1.0f);
+ WrapInFunction(builtin);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverBuiltinsValidationTest, Distance_Float_Vec2) {
+ auto* builtin =
+ Call("distance", vec2<f32>(1.0f, 1.0f), vec2<f32>(1.0f, 1.0f));
+ WrapInFunction(builtin);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverBuiltinsValidationTest, Distance_Float_Vec3) {
+ auto* builtin = Call("distance", vec3<f32>(1.0f, 1.0f, 1.0f),
+ vec3<f32>(1.0f, 1.0f, 1.0f));
+ WrapInFunction(builtin);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverBuiltinsValidationTest, Distance_Float_Vec4) {
+ auto* builtin = Call("distance", vec4<f32>(1.0f, 1.0f, 1.0f, 1.0f),
+ vec4<f32>(1.0f, 1.0f, 1.0f, 1.0f));
+ WrapInFunction(builtin);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverBuiltinsValidationTest, Determinant_Mat2x2) {
+ auto* builtin = Call(
+ "determinant", mat2x2<f32>(vec2<f32>(1.0f, 1.0f), vec2<f32>(1.0f, 1.0f)));
+ WrapInFunction(builtin);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverBuiltinsValidationTest, Determinant_Mat3x3) {
+ auto* builtin = Call("determinant", mat3x3<f32>(vec3<f32>(1.0f, 1.0f, 1.0f),
+ vec3<f32>(1.0f, 1.0f, 1.0f),
+ vec3<f32>(1.0f, 1.0f, 1.0f)));
+ WrapInFunction(builtin);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverBuiltinsValidationTest, Determinant_Mat4x4) {
+ auto* builtin =
+ Call("determinant", mat4x4<f32>(vec4<f32>(1.0f, 1.0f, 1.0f, 1.0f),
+ vec4<f32>(1.0f, 1.0f, 1.0f, 1.0f),
+ vec4<f32>(1.0f, 1.0f, 1.0f, 1.0f),
+ vec4<f32>(1.0f, 1.0f, 1.0f, 1.0f)));
+ WrapInFunction(builtin);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverBuiltinsValidationTest, Frexp_Scalar) {
+ auto* builtin = Call("frexp", 1.0f);
+ WrapInFunction(builtin);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ auto* res_ty = TypeOf(builtin)->As<sem::Struct>();
+ ASSERT_TRUE(res_ty != nullptr);
+ auto& members = res_ty->Members();
+ ASSERT_EQ(members.size(), 2u);
+ EXPECT_TRUE(members[0]->Type()->Is<sem::F32>());
+ EXPECT_TRUE(members[1]->Type()->Is<sem::I32>());
+}
+
+TEST_F(ResolverBuiltinsValidationTest, Frexp_Vec2) {
+ auto* builtin = Call("frexp", vec2<f32>(1.0f, 1.0f));
+ WrapInFunction(builtin);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ auto* res_ty = TypeOf(builtin)->As<sem::Struct>();
+ ASSERT_TRUE(res_ty != nullptr);
+ auto& members = res_ty->Members();
+ ASSERT_EQ(members.size(), 2u);
+ ASSERT_TRUE(members[0]->Type()->Is<sem::Vector>());
+ ASSERT_TRUE(members[1]->Type()->Is<sem::Vector>());
+ EXPECT_EQ(members[0]->Type()->As<sem::Vector>()->Width(), 2u);
+ EXPECT_TRUE(members[0]->Type()->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(members[1]->Type()->As<sem::Vector>()->Width(), 2u);
+ EXPECT_TRUE(members[1]->Type()->As<sem::Vector>()->type()->Is<sem::I32>());
+}
+
+TEST_F(ResolverBuiltinsValidationTest, Frexp_Vec3) {
+ auto* builtin = Call("frexp", vec3<f32>(1.0f, 1.0f, 1.0f));
+ WrapInFunction(builtin);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ auto* res_ty = TypeOf(builtin)->As<sem::Struct>();
+ ASSERT_TRUE(res_ty != nullptr);
+ auto& members = res_ty->Members();
+ ASSERT_EQ(members.size(), 2u);
+ ASSERT_TRUE(members[0]->Type()->Is<sem::Vector>());
+ ASSERT_TRUE(members[1]->Type()->Is<sem::Vector>());
+ EXPECT_EQ(members[0]->Type()->As<sem::Vector>()->Width(), 3u);
+ EXPECT_TRUE(members[0]->Type()->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(members[1]->Type()->As<sem::Vector>()->Width(), 3u);
+ EXPECT_TRUE(members[1]->Type()->As<sem::Vector>()->type()->Is<sem::I32>());
+}
+
+TEST_F(ResolverBuiltinsValidationTest, Frexp_Vec4) {
+ auto* builtin = Call("frexp", vec4<f32>(1.0f, 1.0f, 1.0f, 1.0f));
+ WrapInFunction(builtin);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ auto* res_ty = TypeOf(builtin)->As<sem::Struct>();
+ ASSERT_TRUE(res_ty != nullptr);
+ auto& members = res_ty->Members();
+ ASSERT_EQ(members.size(), 2u);
+ ASSERT_TRUE(members[0]->Type()->Is<sem::Vector>());
+ ASSERT_TRUE(members[1]->Type()->Is<sem::Vector>());
+ EXPECT_EQ(members[0]->Type()->As<sem::Vector>()->Width(), 4u);
+ EXPECT_TRUE(members[0]->Type()->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(members[1]->Type()->As<sem::Vector>()->Width(), 4u);
+ EXPECT_TRUE(members[1]->Type()->As<sem::Vector>()->type()->Is<sem::I32>());
+}
+
+TEST_F(ResolverBuiltinsValidationTest, Modf_Scalar) {
+ auto* builtin = Call("modf", 1.0f);
+ WrapInFunction(builtin);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ auto* res_ty = TypeOf(builtin)->As<sem::Struct>();
+ ASSERT_TRUE(res_ty != nullptr);
+ auto& members = res_ty->Members();
+ ASSERT_EQ(members.size(), 2u);
+ EXPECT_TRUE(members[0]->Type()->Is<sem::F32>());
+ EXPECT_TRUE(members[1]->Type()->Is<sem::F32>());
+}
+
+TEST_F(ResolverBuiltinsValidationTest, Modf_Vec2) {
+ auto* builtin = Call("modf", vec2<f32>(1.0f, 1.0f));
+ WrapInFunction(builtin);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ auto* res_ty = TypeOf(builtin)->As<sem::Struct>();
+ ASSERT_TRUE(res_ty != nullptr);
+ auto& members = res_ty->Members();
+ ASSERT_EQ(members.size(), 2u);
+ ASSERT_TRUE(members[0]->Type()->Is<sem::Vector>());
+ ASSERT_TRUE(members[1]->Type()->Is<sem::Vector>());
+ EXPECT_EQ(members[0]->Type()->As<sem::Vector>()->Width(), 2u);
+ EXPECT_TRUE(members[0]->Type()->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(members[1]->Type()->As<sem::Vector>()->Width(), 2u);
+ EXPECT_TRUE(members[1]->Type()->As<sem::Vector>()->type()->Is<sem::F32>());
+}
+
+TEST_F(ResolverBuiltinsValidationTest, Modf_Vec3) {
+ auto* builtin = Call("modf", vec3<f32>(1.0f, 1.0f, 1.0f));
+ WrapInFunction(builtin);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ auto* res_ty = TypeOf(builtin)->As<sem::Struct>();
+ ASSERT_TRUE(res_ty != nullptr);
+ auto& members = res_ty->Members();
+ ASSERT_EQ(members.size(), 2u);
+ ASSERT_TRUE(members[0]->Type()->Is<sem::Vector>());
+ ASSERT_TRUE(members[1]->Type()->Is<sem::Vector>());
+ EXPECT_EQ(members[0]->Type()->As<sem::Vector>()->Width(), 3u);
+ EXPECT_TRUE(members[0]->Type()->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(members[1]->Type()->As<sem::Vector>()->Width(), 3u);
+ EXPECT_TRUE(members[1]->Type()->As<sem::Vector>()->type()->Is<sem::F32>());
+}
+
+TEST_F(ResolverBuiltinsValidationTest, Modf_Vec4) {
+ auto* builtin = Call("modf", vec4<f32>(1.0f, 1.0f, 1.0f, 1.0f));
+ WrapInFunction(builtin);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ auto* res_ty = TypeOf(builtin)->As<sem::Struct>();
+ ASSERT_TRUE(res_ty != nullptr);
+ auto& members = res_ty->Members();
+ ASSERT_EQ(members.size(), 2u);
+ ASSERT_TRUE(members[0]->Type()->Is<sem::Vector>());
+ ASSERT_TRUE(members[1]->Type()->Is<sem::Vector>());
+ EXPECT_EQ(members[0]->Type()->As<sem::Vector>()->Width(), 4u);
+ EXPECT_TRUE(members[0]->Type()->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(members[1]->Type()->As<sem::Vector>()->Width(), 4u);
+ EXPECT_TRUE(members[1]->Type()->As<sem::Vector>()->type()->Is<sem::F32>());
+}
+
+TEST_F(ResolverBuiltinsValidationTest, Cross_Float_Vec3) {
+ auto* builtin =
+ Call("cross", vec3<f32>(1.0f, 1.0f, 1.0f), vec3<f32>(1.0f, 1.0f, 1.0f));
+ WrapInFunction(builtin);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverBuiltinsValidationTest, Dot_Float_Vec2) {
+ auto* builtin = Call("dot", vec2<f32>(1.0f, 1.0f), vec2<f32>(1.0f, 1.0f));
+ WrapInFunction(builtin);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverBuiltinsValidationTest, Dot_Float_Vec3) {
+ auto* builtin =
+ Call("dot", vec3<f32>(1.0f, 1.0f, 1.0f), vec3<f32>(1.0f, 1.0f, 1.0f));
+ WrapInFunction(builtin);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverBuiltinsValidationTest, Dot_Float_Vec4) {
+ auto* builtin = Call("dot", vec4<f32>(1.0f, 1.0f, 1.0f, 1.0f),
+ vec4<f32>(1.0f, 1.0f, 1.0f, 1.0f));
+ WrapInFunction(builtin);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverBuiltinsValidationTest, Select_Float_Scalar) {
+ auto* builtin = Call("select", Expr(1.0f), Expr(1.0f), Expr(true));
+ WrapInFunction(builtin);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverBuiltinsValidationTest, Select_Integer_Scalar) {
+ auto* builtin = Call("select", Expr(1), Expr(1), Expr(true));
+ WrapInFunction(builtin);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverBuiltinsValidationTest, Select_Boolean_Scalar) {
+ auto* builtin = Call("select", Expr(true), Expr(true), Expr(true));
+ WrapInFunction(builtin);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverBuiltinsValidationTest, Select_Float_Vec2) {
+ auto* builtin = Call("select", vec2<f32>(1.0f, 1.0f), vec2<f32>(1.0f, 1.0f),
+ vec2<bool>(true, true));
+ WrapInFunction(builtin);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverBuiltinsValidationTest, Select_Integer_Vec2) {
+ auto* builtin =
+ Call("select", vec2<int>(1, 1), vec2<int>(1, 1), vec2<bool>(true, true));
+ WrapInFunction(builtin);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverBuiltinsValidationTest, Select_Boolean_Vec2) {
+ auto* builtin = Call("select", vec2<bool>(true, true), vec2<bool>(true, true),
+ vec2<bool>(true, true));
+ WrapInFunction(builtin);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+template <typename T>
+class ResolverBuiltinsValidationTestWithParams
+ : public resolver::TestHelper,
+ public testing::TestWithParam<T> {};
+
+using FloatAllMatching =
+ ResolverBuiltinsValidationTestWithParams<std::tuple<std::string, uint32_t>>;
+
+TEST_P(FloatAllMatching, Scalar) {
+ std::string name = std::get<0>(GetParam());
+ uint32_t num_params = std::get<1>(GetParam());
+
+ ast::ExpressionList params;
+ for (uint32_t i = 0; i < num_params; ++i) {
+ params.push_back(Expr(1.0f));
+ }
+ auto* builtin = Call(name, params);
+ Func("func", {}, ty.void_(), {CallStmt(builtin)},
+ {create<ast::StageAttribute>(ast::PipelineStage::kFragment)});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ EXPECT_TRUE(TypeOf(builtin)->Is<sem::F32>());
+}
+
+TEST_P(FloatAllMatching, Vec2) {
+ std::string name = std::get<0>(GetParam());
+ uint32_t num_params = std::get<1>(GetParam());
+
+ ast::ExpressionList params;
+ for (uint32_t i = 0; i < num_params; ++i) {
+ params.push_back(vec2<f32>(1.0f, 1.0f));
+ }
+ auto* builtin = Call(name, params);
+ Func("func", {}, ty.void_(), {CallStmt(builtin)},
+ {create<ast::StageAttribute>(ast::PipelineStage::kFragment)});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ EXPECT_TRUE(TypeOf(builtin)->is_float_vector());
+}
+
+TEST_P(FloatAllMatching, Vec3) {
+ std::string name = std::get<0>(GetParam());
+ uint32_t num_params = std::get<1>(GetParam());
+
+ ast::ExpressionList params;
+ for (uint32_t i = 0; i < num_params; ++i) {
+ params.push_back(vec3<f32>(1.0f, 1.0f, 1.0f));
+ }
+ auto* builtin = Call(name, params);
+ Func("func", {}, ty.void_(), {CallStmt(builtin)},
+ {create<ast::StageAttribute>(ast::PipelineStage::kFragment)});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ EXPECT_TRUE(TypeOf(builtin)->is_float_vector());
+}
+
+TEST_P(FloatAllMatching, Vec4) {
+ std::string name = std::get<0>(GetParam());
+ uint32_t num_params = std::get<1>(GetParam());
+
+ ast::ExpressionList params;
+ for (uint32_t i = 0; i < num_params; ++i) {
+ params.push_back(vec4<f32>(1.0f, 1.0f, 1.0f, 1.0f));
+ }
+ auto* builtin = Call(name, params);
+ Func("func", {}, ty.void_(), {CallStmt(builtin)},
+ {create<ast::StageAttribute>(ast::PipelineStage::kFragment)});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ EXPECT_TRUE(TypeOf(builtin)->is_float_vector());
+}
+
+INSTANTIATE_TEST_SUITE_P(ResolverBuiltinsValidationTest,
+ FloatAllMatching,
+ ::testing::Values(std::make_tuple("abs", 1),
+ std::make_tuple("acos", 1),
+ std::make_tuple("asin", 1),
+ std::make_tuple("atan", 1),
+ std::make_tuple("atan2", 2),
+ std::make_tuple("ceil", 1),
+ std::make_tuple("clamp", 3),
+ std::make_tuple("cos", 1),
+ std::make_tuple("cosh", 1),
+ std::make_tuple("dpdx", 1),
+ std::make_tuple("dpdxCoarse", 1),
+ std::make_tuple("dpdxFine", 1),
+ std::make_tuple("dpdy", 1),
+ std::make_tuple("dpdyCoarse", 1),
+ std::make_tuple("dpdyFine", 1),
+ std::make_tuple("exp", 1),
+ std::make_tuple("exp2", 1),
+ std::make_tuple("floor", 1),
+ std::make_tuple("fma", 3),
+ std::make_tuple("fract", 1),
+ std::make_tuple("fwidth", 1),
+ std::make_tuple("fwidthCoarse", 1),
+ std::make_tuple("fwidthFine", 1),
+ std::make_tuple("inverseSqrt", 1),
+ std::make_tuple("log", 1),
+ std::make_tuple("log2", 1),
+ std::make_tuple("max", 2),
+ std::make_tuple("min", 2),
+ std::make_tuple("mix", 3),
+ std::make_tuple("pow", 2),
+ std::make_tuple("round", 1),
+ std::make_tuple("sign", 1),
+ std::make_tuple("sin", 1),
+ std::make_tuple("sinh", 1),
+ std::make_tuple("smoothStep", 3),
+ std::make_tuple("sqrt", 1),
+ std::make_tuple("step", 2),
+ std::make_tuple("tan", 1),
+ std::make_tuple("tanh", 1),
+ std::make_tuple("trunc", 1)));
+
+using IntegerAllMatching =
+ ResolverBuiltinsValidationTestWithParams<std::tuple<std::string, uint32_t>>;
+
+TEST_P(IntegerAllMatching, ScalarUnsigned) {
+ std::string name = std::get<0>(GetParam());
+ uint32_t num_params = std::get<1>(GetParam());
+
+ ast::ExpressionList params;
+ for (uint32_t i = 0; i < num_params; ++i) {
+ params.push_back(Construct<uint32_t>(1));
+ }
+ auto* builtin = Call(name, params);
+ WrapInFunction(builtin);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ EXPECT_TRUE(TypeOf(builtin)->Is<sem::U32>());
+}
+
+TEST_P(IntegerAllMatching, Vec2Unsigned) {
+ std::string name = std::get<0>(GetParam());
+ uint32_t num_params = std::get<1>(GetParam());
+
+ ast::ExpressionList params;
+ for (uint32_t i = 0; i < num_params; ++i) {
+ params.push_back(vec2<uint32_t>(1u, 1u));
+ }
+ auto* builtin = Call(name, params);
+ WrapInFunction(builtin);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ EXPECT_TRUE(TypeOf(builtin)->is_unsigned_integer_vector());
+}
+
+TEST_P(IntegerAllMatching, Vec3Unsigned) {
+ std::string name = std::get<0>(GetParam());
+ uint32_t num_params = std::get<1>(GetParam());
+
+ ast::ExpressionList params;
+ for (uint32_t i = 0; i < num_params; ++i) {
+ params.push_back(vec3<uint32_t>(1u, 1u, 1u));
+ }
+ auto* builtin = Call(name, params);
+ WrapInFunction(builtin);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ EXPECT_TRUE(TypeOf(builtin)->is_unsigned_integer_vector());
+}
+
+TEST_P(IntegerAllMatching, Vec4Unsigned) {
+ std::string name = std::get<0>(GetParam());
+ uint32_t num_params = std::get<1>(GetParam());
+
+ ast::ExpressionList params;
+ for (uint32_t i = 0; i < num_params; ++i) {
+ params.push_back(vec4<uint32_t>(1u, 1u, 1u, 1u));
+ }
+ auto* builtin = Call(name, params);
+ WrapInFunction(builtin);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ EXPECT_TRUE(TypeOf(builtin)->is_unsigned_integer_vector());
+}
+
+TEST_P(IntegerAllMatching, ScalarSigned) {
+ std::string name = std::get<0>(GetParam());
+ uint32_t num_params = std::get<1>(GetParam());
+
+ ast::ExpressionList params;
+ for (uint32_t i = 0; i < num_params; ++i) {
+ params.push_back(Construct<int32_t>(1));
+ }
+ auto* builtin = Call(name, params);
+ WrapInFunction(builtin);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ EXPECT_TRUE(TypeOf(builtin)->Is<sem::I32>());
+}
+
+TEST_P(IntegerAllMatching, Vec2Signed) {
+ std::string name = std::get<0>(GetParam());
+ uint32_t num_params = std::get<1>(GetParam());
+
+ ast::ExpressionList params;
+ for (uint32_t i = 0; i < num_params; ++i) {
+ params.push_back(vec2<int32_t>(1, 1));
+ }
+ auto* builtin = Call(name, params);
+ WrapInFunction(builtin);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ EXPECT_TRUE(TypeOf(builtin)->is_signed_integer_vector());
+}
+
+TEST_P(IntegerAllMatching, Vec3Signed) {
+ std::string name = std::get<0>(GetParam());
+ uint32_t num_params = std::get<1>(GetParam());
+
+ ast::ExpressionList params;
+ for (uint32_t i = 0; i < num_params; ++i) {
+ params.push_back(vec3<int32_t>(1, 1, 1));
+ }
+ auto* builtin = Call(name, params);
+ WrapInFunction(builtin);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ EXPECT_TRUE(TypeOf(builtin)->is_signed_integer_vector());
+}
+
+TEST_P(IntegerAllMatching, Vec4Signed) {
+ std::string name = std::get<0>(GetParam());
+ uint32_t num_params = std::get<1>(GetParam());
+
+ ast::ExpressionList params;
+ for (uint32_t i = 0; i < num_params; ++i) {
+ params.push_back(vec4<int32_t>(1, 1, 1, 1));
+ }
+ auto* builtin = Call(name, params);
+ WrapInFunction(builtin);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ EXPECT_TRUE(TypeOf(builtin)->is_signed_integer_vector());
+}
+
+INSTANTIATE_TEST_SUITE_P(ResolverBuiltinsValidationTest,
+ IntegerAllMatching,
+ ::testing::Values(std::make_tuple("abs", 1),
+ std::make_tuple("clamp", 3),
+ std::make_tuple("countOneBits", 1),
+ std::make_tuple("max", 2),
+ std::make_tuple("min", 2),
+ std::make_tuple("reverseBits", 1)));
+
+using BooleanVectorInput =
+ ResolverBuiltinsValidationTestWithParams<std::tuple<std::string, uint32_t>>;
+
+TEST_P(BooleanVectorInput, Vec2) {
+ std::string name = std::get<0>(GetParam());
+ uint32_t num_params = std::get<1>(GetParam());
+
+ ast::ExpressionList params;
+ for (uint32_t i = 0; i < num_params; ++i) {
+ params.push_back(vec2<bool>(true, true));
+ }
+ auto* builtin = Call(name, params);
+ WrapInFunction(builtin);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_P(BooleanVectorInput, Vec3) {
+ std::string name = std::get<0>(GetParam());
+ uint32_t num_params = std::get<1>(GetParam());
+
+ ast::ExpressionList params;
+ for (uint32_t i = 0; i < num_params; ++i) {
+ params.push_back(vec3<bool>(true, true, true));
+ }
+ auto* builtin = Call(name, params);
+ WrapInFunction(builtin);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_P(BooleanVectorInput, Vec4) {
+ std::string name = std::get<0>(GetParam());
+ uint32_t num_params = std::get<1>(GetParam());
+
+ ast::ExpressionList params;
+ for (uint32_t i = 0; i < num_params; ++i) {
+ params.push_back(vec4<bool>(true, true, true, true));
+ }
+ auto* builtin = Call(name, params);
+ WrapInFunction(builtin);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+INSTANTIATE_TEST_SUITE_P(ResolverBuiltinsValidationTest,
+ BooleanVectorInput,
+ ::testing::Values(std::make_tuple("all", 1),
+ std::make_tuple("any", 1)));
+
+using DataPacking4x8 = ResolverBuiltinsValidationTestWithParams<std::string>;
+
+TEST_P(DataPacking4x8, Float_Vec4) {
+ auto name = GetParam();
+ auto* builtin = Call(name, vec4<f32>(1.0f, 1.0f, 1.0f, 1.0f));
+ WrapInFunction(builtin);
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+INSTANTIATE_TEST_SUITE_P(ResolverBuiltinsValidationTest,
+ DataPacking4x8,
+ ::testing::Values("pack4x8snorm", "pack4x8unorm"));
+
+using DataPacking2x16 = ResolverBuiltinsValidationTestWithParams<std::string>;
+
+TEST_P(DataPacking2x16, Float_Vec2) {
+ auto name = GetParam();
+ auto* builtin = Call(name, vec2<f32>(1.0f, 1.0f));
+ WrapInFunction(builtin);
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+INSTANTIATE_TEST_SUITE_P(ResolverBuiltinsValidationTest,
+ DataPacking2x16,
+ ::testing::Values("pack2x16snorm",
+ "pack2x16unorm",
+ "pack2x16float"));
+
+} // namespace
+} // namespace resolver
+} // namespace tint
diff --git a/src/tint/resolver/call_test.cc b/src/tint/resolver/call_test.cc
new file mode 100644
index 0000000..038654d
--- /dev/null
+++ b/src/tint/resolver/call_test.cc
@@ -0,0 +1,118 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/resolver/resolver.h"
+
+#include "gmock/gmock.h"
+#include "src/tint/ast/call_statement.h"
+#include "src/tint/resolver/resolver_test_helper.h"
+
+namespace tint {
+namespace resolver {
+namespace {
+// Helpers and typedefs
+template <typename T>
+using DataType = builder::DataType<T>;
+template <int N, typename T>
+using vec = builder::vec<N, T>;
+template <typename T>
+using vec2 = builder::vec2<T>;
+template <typename T>
+using vec3 = builder::vec3<T>;
+template <typename T>
+using vec4 = builder::vec4<T>;
+template <int N, int M, typename T>
+using mat = builder::mat<N, M, T>;
+template <typename T>
+using mat2x2 = builder::mat2x2<T>;
+template <typename T>
+using mat2x3 = builder::mat2x3<T>;
+template <typename T>
+using mat3x2 = builder::mat3x2<T>;
+template <typename T>
+using mat3x3 = builder::mat3x3<T>;
+template <typename T>
+using mat4x4 = builder::mat4x4<T>;
+template <typename T, int ID = 0>
+using alias = builder::alias<T, ID>;
+template <typename T>
+using alias1 = builder::alias1<T>;
+template <typename T>
+using alias2 = builder::alias2<T>;
+template <typename T>
+using alias3 = builder::alias3<T>;
+using f32 = builder::f32;
+using i32 = builder::i32;
+using u32 = builder::u32;
+
+using ResolverCallTest = ResolverTest;
+
+struct Params {
+ builder::ast_expr_func_ptr create_value;
+ builder::ast_type_func_ptr create_type;
+};
+
+template <typename T>
+constexpr Params ParamsFor() {
+ return Params{DataType<T>::Expr, DataType<T>::AST};
+}
+
+static constexpr Params all_param_types[] = {
+ ParamsFor<bool>(), //
+ ParamsFor<u32>(), //
+ ParamsFor<i32>(), //
+ ParamsFor<f32>(), //
+ ParamsFor<vec3<bool>>(), //
+ ParamsFor<vec3<i32>>(), //
+ ParamsFor<vec3<u32>>(), //
+ ParamsFor<vec3<f32>>(), //
+ ParamsFor<mat3x3<f32>>(), //
+ ParamsFor<mat2x3<f32>>(), //
+ ParamsFor<mat3x2<f32>>() //
+};
+
+TEST_F(ResolverCallTest, Valid) {
+ ast::VariableList params;
+ ast::ExpressionList args;
+ for (auto& p : all_param_types) {
+ params.push_back(Param(Sym(), p.create_type(*this)));
+ args.push_back(p.create_value(*this, 0));
+ }
+
+ auto* func = Func("foo", std::move(params), ty.f32(), {Return(1.23f)});
+ auto* call_expr = Call("foo", std::move(args));
+ WrapInFunction(call_expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* call = Sem().Get(call_expr);
+ EXPECT_NE(call, nullptr);
+ EXPECT_EQ(call->Target(), Sem().Get(func));
+}
+
+TEST_F(ResolverCallTest, OutOfOrder) {
+ auto* call_expr = Call("b");
+ Func("a", {}, ty.void_(), {CallStmt(call_expr)});
+ auto* b = Func("b", {}, ty.void_(), {});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* call = Sem().Get(call_expr);
+ EXPECT_NE(call, nullptr);
+ EXPECT_EQ(call->Target(), Sem().Get(b));
+}
+
+} // namespace
+} // namespace resolver
+} // namespace tint
diff --git a/src/tint/resolver/call_validation_test.cc b/src/tint/resolver/call_validation_test.cc
new file mode 100644
index 0000000..c8be5e4
--- /dev/null
+++ b/src/tint/resolver/call_validation_test.cc
@@ -0,0 +1,288 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/resolver/resolver.h"
+
+#include "gmock/gmock.h"
+#include "src/tint/ast/call_statement.h"
+#include "src/tint/resolver/resolver_test_helper.h"
+
+namespace tint {
+namespace resolver {
+namespace {
+
+using ResolverCallValidationTest = ResolverTest;
+
+TEST_F(ResolverCallValidationTest, TooFewArgs) {
+ Func("foo", {Param(Sym(), ty.i32()), Param(Sym(), ty.f32())}, ty.void_(),
+ {Return()});
+ auto* call = Call(Source{{12, 34}}, "foo", 1);
+ WrapInFunction(call);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: too few arguments in call to 'foo', expected 2, got 1");
+}
+
+TEST_F(ResolverCallValidationTest, TooManyArgs) {
+ Func("foo", {Param(Sym(), ty.i32()), Param(Sym(), ty.f32())}, ty.void_(),
+ {Return()});
+ auto* call = Call(Source{{12, 34}}, "foo", 1, 1.0f, 1.0f);
+ WrapInFunction(call);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: too many arguments in call to 'foo', expected 2, got 3");
+}
+
+TEST_F(ResolverCallValidationTest, MismatchedArgs) {
+ Func("foo", {Param(Sym(), ty.i32()), Param(Sym(), ty.f32())}, ty.void_(),
+ {Return()});
+ auto* call = Call("foo", Expr(Source{{12, 34}}, true), 1.0f);
+ WrapInFunction(call);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: type mismatch for argument 1 in call to 'foo', "
+ "expected 'i32', got 'bool'");
+}
+
+TEST_F(ResolverCallValidationTest, UnusedRetval) {
+ // fn func() -> f32 { return 1.0; }
+ // fn main() {func(); return; }
+
+ Func("func", {}, ty.f32(), {Return(Expr(1.0f))}, {});
+
+ Func("main", {}, ty.void_(),
+ {
+ CallStmt(Source{{12, 34}}, Call("func")),
+ Return(),
+ });
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverCallValidationTest, PointerArgument_VariableIdentExpr) {
+ // fn foo(p: ptr<function, i32>) {}
+ // fn main() {
+ // var z: i32 = 1;
+ // foo(&z);
+ // }
+ auto* param = Param("p", ty.pointer<i32>(ast::StorageClass::kFunction));
+ Func("foo", {param}, ty.void_(), {});
+ Func("main", {}, ty.void_(),
+ {
+ Decl(Var("z", ty.i32(), Expr(1))),
+ CallStmt(Call("foo", AddressOf(Source{{12, 34}}, Expr("z")))),
+ });
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverCallValidationTest, PointerArgument_ConstIdentExpr) {
+ // fn foo(p: ptr<function, i32>) {}
+ // fn main() {
+ // let z: i32 = 1;
+ // foo(&z);
+ // }
+ auto* param = Param("p", ty.pointer<i32>(ast::StorageClass::kFunction));
+ Func("foo", {param}, ty.void_(), {});
+ Func("main", {}, ty.void_(),
+ {
+ Decl(Const("z", ty.i32(), Expr(1))),
+ CallStmt(Call("foo", AddressOf(Expr(Source{{12, 34}}, "z")))),
+ });
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: cannot take the address of expression");
+}
+
+TEST_F(ResolverCallValidationTest, PointerArgument_NotIdentExprVar) {
+ // struct S { m: i32; };
+ // fn foo(p: ptr<function, i32>) {}
+ // fn main() {
+ // var v: S;
+ // foo(&v.m);
+ // }
+ auto* S = Structure("S", {Member("m", ty.i32())});
+ auto* param = Param("p", ty.pointer<i32>(ast::StorageClass::kFunction));
+ Func("foo", {param}, ty.void_(), {});
+ Func("main", {}, ty.void_(),
+ {
+ Decl(Var("v", ty.Of(S))),
+ CallStmt(Call(
+ "foo", AddressOf(Source{{12, 34}}, MemberAccessor("v", "m")))),
+ });
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: expected an address-of expression of a variable "
+ "identifier expression or a function parameter");
+}
+
+TEST_F(ResolverCallValidationTest, PointerArgument_AddressOfMemberAccessor) {
+ // struct S { m: i32; };
+ // fn foo(p: ptr<function, i32>) {}
+ // fn main() {
+ // let v: S = S();
+ // foo(&v.m);
+ // }
+ auto* S = Structure("S", {Member("m", ty.i32())});
+ auto* param = Param("p", ty.pointer<i32>(ast::StorageClass::kFunction));
+ Func("foo", {param}, ty.void_(), {});
+ Func("main", {}, ty.void_(),
+ {
+ Decl(Const("v", ty.Of(S), Construct(ty.Of(S)))),
+ CallStmt(Call("foo", AddressOf(Expr(Source{{12, 34}},
+ MemberAccessor("v", "m"))))),
+ });
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: cannot take the address of expression");
+}
+
+TEST_F(ResolverCallValidationTest, PointerArgument_FunctionParam) {
+ // fn foo(p: ptr<function, i32>) {}
+ // fn bar(p: ptr<function, i32>) {
+ // foo(p);
+ // }
+ Func("foo", {Param("p", ty.pointer<i32>(ast::StorageClass::kFunction))},
+ ty.void_(), {});
+ Func("bar", {Param("p", ty.pointer<i32>(ast::StorageClass::kFunction))},
+ ty.void_(), ast::StatementList{CallStmt(Call("foo", Expr("p")))});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverCallValidationTest, PointerArgument_FunctionParamWithMain) {
+ // fn foo(p: ptr<function, i32>) {}
+ // fn bar(p: ptr<function, i32>) {
+ // foo(p);
+ // }
+ // @stage(fragment)
+ // fn main() {
+ // var v: i32;
+ // bar(&v);
+ // }
+ Func("foo", {Param("p", ty.pointer<i32>(ast::StorageClass::kFunction))},
+ ty.void_(), {});
+ Func("bar", {Param("p", ty.pointer<i32>(ast::StorageClass::kFunction))},
+ ty.void_(), ast::StatementList{CallStmt(Call("foo", Expr("p")))});
+ Func("main", ast::VariableList{}, ty.void_(),
+ {
+ Decl(Var("v", ty.i32(), Expr(1))),
+ CallStmt(Call("foo", AddressOf(Expr("v")))),
+ },
+ {
+ Stage(ast::PipelineStage::kFragment),
+ });
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverCallValidationTest, LetPointer) {
+ // fn x(p : ptr<function, i32>) -> i32 {}
+ // @stage(fragment)
+ // fn main() {
+ // var v: i32;
+ // let p: ptr<function, i32> = &v;
+ // var c: i32 = x(p);
+ // }
+ Func("x", {Param("p", ty.pointer<i32>(ast::StorageClass::kFunction))},
+ ty.void_(), {});
+ auto* v = Var("v", ty.i32());
+ auto* p = Const("p", ty.pointer(ty.i32(), ast::StorageClass::kFunction),
+ AddressOf(v));
+ auto* c = Var("c", ty.i32(), ast::StorageClass::kNone,
+ Call("x", Expr(Source{{12, 34}}, p)));
+ Func("main", ast::VariableList{}, ty.void_(),
+ {
+ Decl(v),
+ Decl(p),
+ Decl(c),
+ },
+ {
+ Stage(ast::PipelineStage::kFragment),
+ });
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: expected an address-of expression of a variable "
+ "identifier expression or a function parameter");
+}
+
+TEST_F(ResolverCallValidationTest, LetPointerPrivate) {
+ // let p: ptr<private, i32> = &v;
+ // fn foo(p : ptr<private, i32>) -> i32 {}
+ // var v: i32;
+ // @stage(fragment)
+ // fn main() {
+ // var c: i32 = foo(p);
+ // }
+ Func("foo", {Param("p", ty.pointer<i32>(ast::StorageClass::kPrivate))},
+ ty.void_(), {});
+ auto* v = Global("v", ty.i32(), ast::StorageClass::kPrivate);
+ auto* p = Const("p", ty.pointer(ty.i32(), ast::StorageClass::kPrivate),
+ AddressOf(v));
+ auto* c = Var("c", ty.i32(), ast::StorageClass::kNone,
+ Call("foo", Expr(Source{{12, 34}}, p)));
+ Func("main", ast::VariableList{}, ty.void_(),
+ {
+ Decl(p),
+ Decl(c),
+ },
+ {
+ Stage(ast::PipelineStage::kFragment),
+ });
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: expected an address-of expression of a variable "
+ "identifier expression or a function parameter");
+}
+
+TEST_F(ResolverCallValidationTest, CallVariable) {
+ // var v : i32;
+ // fn f() {
+ // v();
+ // }
+ Global("v", ty.i32(), ast::StorageClass::kPrivate);
+ Func("f", {}, ty.void_(), {CallStmt(Call(Source{{12, 34}}, "v"))});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), R"(error: cannot call variable 'v'
+note: 'v' declared here)");
+}
+
+TEST_F(ResolverCallValidationTest, CallVariableShadowsFunction) {
+ // fn x() {}
+ // fn f() {
+ // var x : i32;
+ // x();
+ // }
+ Func("x", {}, ty.void_(), {});
+ Func("f", {}, ty.void_(),
+ {
+ Decl(Var(Source{{56, 78}}, "x", ty.i32())),
+ CallStmt(Call(Source{{12, 34}}, "x")),
+ });
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), R"(error: cannot call variable 'x'
+56:78 note: 'x' declared here)");
+}
+
+} // namespace
+} // namespace resolver
+} // namespace tint
diff --git a/src/tint/resolver/compound_statement_test.cc b/src/tint/resolver/compound_statement_test.cc
new file mode 100644
index 0000000..bf130ca
--- /dev/null
+++ b/src/tint/resolver/compound_statement_test.cc
@@ -0,0 +1,380 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/resolver/resolver.h"
+
+#include "gmock/gmock.h"
+#include "src/tint/resolver/resolver_test_helper.h"
+#include "src/tint/sem/block_statement.h"
+#include "src/tint/sem/for_loop_statement.h"
+#include "src/tint/sem/if_statement.h"
+#include "src/tint/sem/loop_statement.h"
+#include "src/tint/sem/switch_statement.h"
+
+namespace tint {
+namespace resolver {
+namespace {
+
+using ResolverCompoundStatementTest = ResolverTest;
+
+TEST_F(ResolverCompoundStatementTest, FunctionBlock) {
+ // fn F() {
+ // var x : 32;
+ // }
+ auto* stmt = Decl(Var("x", ty.i32()));
+ auto* f = Func("F", {}, ty.void_(), {stmt});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* s = Sem().Get(stmt);
+ ASSERT_NE(s, nullptr);
+ ASSERT_NE(s->Block(), nullptr);
+ ASSERT_TRUE(s->Block()->Is<sem::FunctionBlockStatement>());
+ EXPECT_EQ(s->Block(), s->FindFirstParent<sem::BlockStatement>());
+ EXPECT_EQ(s->Block(), s->FindFirstParent<sem::FunctionBlockStatement>());
+ EXPECT_EQ(s->Function()->Declaration(), f);
+ EXPECT_EQ(s->Block()->Parent(), nullptr);
+}
+
+TEST_F(ResolverCompoundStatementTest, Block) {
+ // fn F() {
+ // {
+ // var x : 32;
+ // }
+ // }
+ auto* stmt = Decl(Var("x", ty.i32()));
+ auto* block = Block(stmt);
+ auto* f = Func("F", {}, ty.void_(), {block});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ {
+ auto* s = Sem().Get(block);
+ ASSERT_NE(s, nullptr);
+ EXPECT_TRUE(s->Is<sem::BlockStatement>());
+ EXPECT_EQ(s, s->Block());
+ EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::FunctionBlockStatement>());
+ }
+ {
+ auto* s = Sem().Get(stmt);
+ ASSERT_NE(s, nullptr);
+ ASSERT_NE(s->Block(), nullptr);
+ EXPECT_EQ(s->Block(), s->FindFirstParent<sem::BlockStatement>());
+ EXPECT_EQ(s->Block()->Parent(),
+ s->FindFirstParent<sem::FunctionBlockStatement>());
+ ASSERT_TRUE(s->Block()->Parent()->Is<sem::FunctionBlockStatement>());
+ EXPECT_EQ(s->Function()->Declaration(), f);
+ EXPECT_EQ(s->Block()->Parent()->Parent(), nullptr);
+ }
+}
+
+TEST_F(ResolverCompoundStatementTest, Loop) {
+ // fn F() {
+ // loop {
+ // break;
+ // continuing {
+ // stmt;
+ // }
+ // }
+ // }
+ auto* brk = Break();
+ auto* stmt = Ignore(1);
+ auto* loop = Loop(Block(brk), Block(stmt));
+ auto* f = Func("F", {}, ty.void_(), {loop});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ {
+ auto* s = Sem().Get(loop);
+ ASSERT_NE(s, nullptr);
+ EXPECT_TRUE(s->Is<sem::LoopStatement>());
+ EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::FunctionBlockStatement>());
+ EXPECT_EQ(s->Parent(), s->Block());
+ }
+ {
+ auto* s = Sem().Get(brk);
+ ASSERT_NE(s, nullptr);
+ ASSERT_NE(s->Block(), nullptr);
+ EXPECT_EQ(s->Parent(), s->Block());
+ EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::LoopBlockStatement>());
+
+ EXPECT_EQ(s->Parent()->Parent(), s->FindFirstParent<sem::LoopStatement>());
+ EXPECT_TRUE(Is<sem::LoopStatement>(s->Parent()->Parent()));
+
+ EXPECT_EQ(s->Parent()->Parent()->Parent(),
+ s->FindFirstParent<sem::FunctionBlockStatement>());
+ EXPECT_TRUE(
+ Is<sem::FunctionBlockStatement>(s->Parent()->Parent()->Parent()));
+
+ EXPECT_EQ(s->Function()->Declaration(), f);
+
+ EXPECT_EQ(s->Parent()->Parent()->Parent()->Parent(), nullptr);
+ }
+ {
+ auto* s = Sem().Get(stmt);
+ ASSERT_NE(s, nullptr);
+ ASSERT_NE(s->Block(), nullptr);
+ EXPECT_EQ(s->Parent(), s->Block());
+
+ EXPECT_EQ(s->Parent(),
+ s->FindFirstParent<sem::LoopContinuingBlockStatement>());
+ EXPECT_TRUE(Is<sem::LoopContinuingBlockStatement>(s->Parent()));
+
+ EXPECT_EQ(s->Parent()->Parent(),
+ s->FindFirstParent<sem::LoopBlockStatement>());
+ EXPECT_TRUE(Is<sem::LoopBlockStatement>(s->Parent()->Parent()));
+
+ EXPECT_EQ(s->Parent()->Parent()->Parent(),
+ s->FindFirstParent<sem::LoopStatement>());
+ EXPECT_TRUE(Is<sem::LoopStatement>(s->Parent()->Parent()->Parent()));
+
+ EXPECT_EQ(s->Parent()->Parent()->Parent()->Parent(),
+ s->FindFirstParent<sem::FunctionBlockStatement>());
+ EXPECT_TRUE(Is<sem::FunctionBlockStatement>(
+ s->Parent()->Parent()->Parent()->Parent()));
+ EXPECT_EQ(s->Function()->Declaration(), f);
+
+ EXPECT_EQ(s->Parent()->Parent()->Parent()->Parent()->Parent(), nullptr);
+ }
+}
+
+TEST_F(ResolverCompoundStatementTest, ForLoop) {
+ // fn F() {
+ // for (var i : u32; true; i = i + 1u) {
+ // return;
+ // }
+ // }
+ auto* init = Decl(Var("i", ty.u32()));
+ auto* cond = Expr(true);
+ auto* cont = Assign("i", Add("i", 1u));
+ auto* stmt = Return();
+ auto* body = Block(stmt);
+ auto* for_ = For(init, cond, cont, body);
+ auto* f = Func("F", {}, ty.void_(), {for_});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ {
+ auto* s = Sem().Get(for_);
+ ASSERT_NE(s, nullptr);
+ EXPECT_TRUE(s->Is<sem::ForLoopStatement>());
+ EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::FunctionBlockStatement>());
+ EXPECT_EQ(s->Parent(), s->Block());
+ }
+ {
+ auto* s = Sem().Get(init);
+ ASSERT_NE(s, nullptr);
+ EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::ForLoopStatement>());
+ EXPECT_TRUE(Is<sem::ForLoopStatement>(s->Parent()));
+ EXPECT_EQ(s->Block(), s->FindFirstParent<sem::FunctionBlockStatement>());
+ EXPECT_TRUE(Is<sem::FunctionBlockStatement>(s->Parent()->Parent()));
+ }
+ { // Condition expression's statement is the for-loop itself
+ auto* e = Sem().Get(cond);
+ ASSERT_NE(e, nullptr);
+ auto* s = e->Stmt();
+ ASSERT_NE(s, nullptr);
+ ASSERT_TRUE(Is<sem::ForLoopStatement>(s));
+ ASSERT_NE(s->Parent(), nullptr);
+ EXPECT_EQ(s->Parent(), s->Block());
+ EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::FunctionBlockStatement>());
+ EXPECT_TRUE(Is<sem::FunctionBlockStatement>(s->Block()));
+ }
+ {
+ auto* s = Sem().Get(cont);
+ ASSERT_NE(s, nullptr);
+ EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::ForLoopStatement>());
+ EXPECT_TRUE(Is<sem::ForLoopStatement>(s->Parent()));
+ EXPECT_EQ(s->Block(), s->FindFirstParent<sem::FunctionBlockStatement>());
+ EXPECT_TRUE(Is<sem::FunctionBlockStatement>(s->Parent()->Parent()));
+ }
+ {
+ auto* s = Sem().Get(stmt);
+ ASSERT_NE(s, nullptr);
+ ASSERT_NE(s->Block(), nullptr);
+ EXPECT_EQ(s->Parent(), s->Block());
+ EXPECT_EQ(s->Block(), s->FindFirstParent<sem::LoopBlockStatement>());
+ EXPECT_TRUE(Is<sem::ForLoopStatement>(s->Parent()->Parent()));
+ EXPECT_EQ(s->Block()->Parent(),
+ s->FindFirstParent<sem::ForLoopStatement>());
+ ASSERT_TRUE(
+ Is<sem::FunctionBlockStatement>(s->Block()->Parent()->Parent()));
+ EXPECT_EQ(s->Block()->Parent()->Parent(),
+ s->FindFirstParent<sem::FunctionBlockStatement>());
+ EXPECT_EQ(s->Function()->Declaration(), f);
+ EXPECT_EQ(s->Block()->Parent()->Parent()->Parent(), nullptr);
+ }
+}
+
+TEST_F(ResolverCompoundStatementTest, If) {
+ // fn F() {
+ // if (cond_a) {
+ // stat_a;
+ // } else if (cond_b) {
+ // stat_b;
+ // } else {
+ // stat_c;
+ // }
+ // }
+
+ auto* cond_a = Expr(true);
+ auto* stmt_a = Ignore(1);
+ auto* cond_b = Expr(true);
+ auto* stmt_b = Ignore(1);
+ auto* stmt_c = Ignore(1);
+ auto* if_stmt = If(cond_a, Block(stmt_a), Else(cond_b, Block(stmt_b)),
+ Else(nullptr, Block(stmt_c)));
+ WrapInFunction(if_stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ {
+ auto* s = Sem().Get(if_stmt);
+ ASSERT_NE(s, nullptr);
+ EXPECT_TRUE(s->Is<sem::IfStatement>());
+ EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::FunctionBlockStatement>());
+ EXPECT_EQ(s->Parent(), s->Block());
+ }
+ {
+ auto* e = Sem().Get(cond_a);
+ ASSERT_NE(e, nullptr);
+ auto* s = e->Stmt();
+ ASSERT_NE(s, nullptr);
+ EXPECT_TRUE(s->Is<sem::IfStatement>());
+ EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::FunctionBlockStatement>());
+ EXPECT_EQ(s->Parent(), s->Block());
+ }
+ {
+ auto* s = Sem().Get(stmt_a);
+ ASSERT_NE(s, nullptr);
+ EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::BlockStatement>());
+ EXPECT_EQ(s->Parent(), s->Block());
+ EXPECT_EQ(s->Parent()->Parent(), s->FindFirstParent<sem::IfStatement>());
+ EXPECT_EQ(s->Parent()->Parent()->Parent(),
+ s->FindFirstParent<sem::FunctionBlockStatement>());
+ }
+ {
+ auto* e = Sem().Get(cond_b);
+ ASSERT_NE(e, nullptr);
+ auto* s = e->Stmt();
+ ASSERT_NE(s, nullptr);
+ EXPECT_TRUE(s->Is<sem::ElseStatement>());
+ EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::IfStatement>());
+ EXPECT_EQ(s->Parent()->Parent(),
+ s->FindFirstParent<sem::FunctionBlockStatement>());
+ EXPECT_EQ(s->Parent()->Parent(), s->Block());
+ }
+ {
+ auto* s = Sem().Get(stmt_b);
+ ASSERT_NE(s, nullptr);
+ EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::BlockStatement>());
+ EXPECT_EQ(s->Parent(), s->Block());
+ EXPECT_EQ(s->Parent()->Parent(), s->FindFirstParent<sem::ElseStatement>());
+ EXPECT_EQ(s->Parent()->Parent()->Parent(),
+ s->FindFirstParent<sem::IfStatement>());
+ EXPECT_EQ(s->Parent()->Parent()->Parent()->Parent(),
+ s->FindFirstParent<sem::FunctionBlockStatement>());
+ }
+ {
+ auto* s = Sem().Get(stmt_c);
+ ASSERT_NE(s, nullptr);
+ EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::BlockStatement>());
+ EXPECT_EQ(s->Parent(), s->Block());
+ EXPECT_EQ(s->Parent()->Parent(), s->FindFirstParent<sem::ElseStatement>());
+ EXPECT_EQ(s->Parent()->Parent()->Parent(),
+ s->FindFirstParent<sem::IfStatement>());
+ EXPECT_EQ(s->Parent()->Parent()->Parent()->Parent(),
+ s->FindFirstParent<sem::FunctionBlockStatement>());
+ }
+}
+
+TEST_F(ResolverCompoundStatementTest, Switch) {
+ // fn F() {
+ // switch (expr) {
+ // case 1: {
+ // stmt_a;
+ // }
+ // case 2: {
+ // stmt_b;
+ // }
+ // default: {
+ // stmt_c;
+ // }
+ // }
+ // }
+
+ auto* expr = Expr(5);
+ auto* stmt_a = Ignore(1);
+ auto* stmt_b = Ignore(1);
+ auto* stmt_c = Ignore(1);
+ auto* swi = Switch(expr, Case(Expr(1), Block(stmt_a)),
+ Case(Expr(2), Block(stmt_b)), DefaultCase(Block(stmt_c)));
+ WrapInFunction(swi);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ {
+ auto* s = Sem().Get(swi);
+ ASSERT_NE(s, nullptr);
+ EXPECT_TRUE(s->Is<sem::SwitchStatement>());
+ EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::FunctionBlockStatement>());
+ EXPECT_EQ(s->Parent(), s->Block());
+ }
+ {
+ auto* e = Sem().Get(expr);
+ ASSERT_NE(e, nullptr);
+ auto* s = e->Stmt();
+ ASSERT_NE(s, nullptr);
+ EXPECT_TRUE(s->Is<sem::SwitchStatement>());
+ EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::FunctionBlockStatement>());
+ EXPECT_EQ(s->Parent(), s->Block());
+ }
+ {
+ auto* s = Sem().Get(stmt_a);
+ ASSERT_NE(s, nullptr);
+ EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::BlockStatement>());
+ EXPECT_EQ(s->Parent(), s->Block());
+ EXPECT_EQ(s->Parent()->Parent(), s->FindFirstParent<sem::CaseStatement>());
+ EXPECT_EQ(s->Parent()->Parent()->Parent(),
+ s->FindFirstParent<sem::SwitchStatement>());
+ EXPECT_EQ(s->Parent()->Parent()->Parent()->Parent(),
+ s->FindFirstParent<sem::FunctionBlockStatement>());
+ }
+ {
+ auto* s = Sem().Get(stmt_b);
+ ASSERT_NE(s, nullptr);
+ EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::BlockStatement>());
+ EXPECT_EQ(s->Parent(), s->Block());
+ EXPECT_EQ(s->Parent()->Parent(), s->FindFirstParent<sem::CaseStatement>());
+ EXPECT_EQ(s->Parent()->Parent()->Parent(),
+ s->FindFirstParent<sem::SwitchStatement>());
+ EXPECT_EQ(s->Parent()->Parent()->Parent()->Parent(),
+ s->FindFirstParent<sem::FunctionBlockStatement>());
+ }
+ {
+ auto* s = Sem().Get(stmt_c);
+ ASSERT_NE(s, nullptr);
+ EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::BlockStatement>());
+ EXPECT_EQ(s->Parent(), s->Block());
+ EXPECT_EQ(s->Parent()->Parent(), s->FindFirstParent<sem::CaseStatement>());
+ EXPECT_EQ(s->Parent()->Parent()->Parent(),
+ s->FindFirstParent<sem::SwitchStatement>());
+ EXPECT_EQ(s->Parent()->Parent()->Parent()->Parent(),
+ s->FindFirstParent<sem::FunctionBlockStatement>());
+ }
+}
+
+} // namespace
+} // namespace resolver
+} // namespace tint
diff --git a/src/tint/resolver/control_block_validation_test.cc b/src/tint/resolver/control_block_validation_test.cc
new file mode 100644
index 0000000..9406da4
--- /dev/null
+++ b/src/tint/resolver/control_block_validation_test.cc
@@ -0,0 +1,364 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ast/break_statement.h"
+#include "src/tint/ast/continue_statement.h"
+#include "src/tint/ast/fallthrough_statement.h"
+#include "src/tint/ast/switch_statement.h"
+#include "src/tint/resolver/resolver_test_helper.h"
+
+namespace tint {
+namespace {
+
+class ResolverControlBlockValidationTest : public resolver::TestHelper,
+ public testing::Test {};
+
+TEST_F(ResolverControlBlockValidationTest,
+ SwitchSelectorExpressionNoneIntegerType_Fail) {
+ // var a : f32 = 3.14;
+ // switch (a) {
+ // default: {}
+ // }
+ auto* var = Var("a", ty.f32(), Expr(3.14f));
+
+ auto* block = Block(Decl(var), Switch(Expr(Source{{12, 34}}, "a"), //
+ DefaultCase()));
+
+ WrapInFunction(block);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: switch statement selector expression must be of a "
+ "scalar integer type");
+}
+
+TEST_F(ResolverControlBlockValidationTest, SwitchWithoutDefault_Fail) {
+ // var a : i32 = 2;
+ // switch (a) {
+ // case 1: {}
+ // }
+ auto* var = Var("a", ty.i32(), Expr(2));
+
+ auto* block = Block(Decl(var), //
+ Switch(Source{{12, 34}}, "a", //
+ Case(Expr(1))));
+
+ WrapInFunction(block);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: switch statement must have a default clause");
+}
+
+TEST_F(ResolverControlBlockValidationTest, SwitchWithTwoDefault_Fail) {
+ // var a : i32 = 2;
+ // switch (a) {
+ // default: {}
+ // case 1: {}
+ // default: {}
+ // }
+ auto* var = Var("a", ty.i32(), Expr(2));
+
+ auto* block = Block(Decl(var), //
+ Switch("a", //
+ DefaultCase(), //
+ Case(Expr(1)), //
+ DefaultCase(Source{{12, 34}})));
+
+ WrapInFunction(block);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: switch statement must have exactly one default clause");
+}
+
+TEST_F(ResolverControlBlockValidationTest, UnreachableCode_Loop_continue) {
+ // loop {
+ // if (false) { break; }
+ // var z: i32;
+ // continue;
+ // z = 1;
+ // }
+ auto* decl_z = Decl(Var("z", ty.i32()));
+ auto* cont = Continue();
+ auto* assign_z = Assign(Source{{12, 34}}, "z", 1);
+ WrapInFunction(
+ Loop(Block(If(false, Block(Break())), decl_z, cont, assign_z)));
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+ EXPECT_EQ(r()->error(), "12:34 warning: code is unreachable");
+ EXPECT_TRUE(Sem().Get(decl_z)->IsReachable());
+ EXPECT_TRUE(Sem().Get(cont)->IsReachable());
+ EXPECT_FALSE(Sem().Get(assign_z)->IsReachable());
+}
+
+TEST_F(ResolverControlBlockValidationTest,
+ UnreachableCode_Loop_continue_InBlocks) {
+ // loop {
+ // if (false) { break; }
+ // var z: i32;
+ // {{{continue;}}}
+ // z = 1;
+ // }
+ auto* decl_z = Decl(Var("z", ty.i32()));
+ auto* cont = Continue();
+ auto* assign_z = Assign(Source{{12, 34}}, "z", 1);
+ WrapInFunction(Loop(Block(If(false, Block(Break())), decl_z,
+ Block(Block(Block(cont))), assign_z)));
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+ EXPECT_EQ(r()->error(), "12:34 warning: code is unreachable");
+ EXPECT_TRUE(Sem().Get(decl_z)->IsReachable());
+ EXPECT_TRUE(Sem().Get(cont)->IsReachable());
+ EXPECT_FALSE(Sem().Get(assign_z)->IsReachable());
+}
+
+TEST_F(ResolverControlBlockValidationTest, UnreachableCode_ForLoop_continue) {
+ // for (;false;) {
+ // var z: i32;
+ // continue;
+ // z = 1;
+ // }
+ auto* decl_z = Decl(Var("z", ty.i32()));
+ auto* cont = Continue();
+ auto* assign_z = Assign(Source{{12, 34}}, "z", 1);
+ WrapInFunction(For(nullptr, false, nullptr, //
+ Block(decl_z, cont, assign_z)));
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+ EXPECT_EQ(r()->error(), "12:34 warning: code is unreachable");
+ EXPECT_TRUE(Sem().Get(decl_z)->IsReachable());
+ EXPECT_TRUE(Sem().Get(cont)->IsReachable());
+ EXPECT_FALSE(Sem().Get(assign_z)->IsReachable());
+}
+
+TEST_F(ResolverControlBlockValidationTest,
+ UnreachableCode_ForLoop_continue_InBlocks) {
+ // for (;false;) {
+ // var z: i32;
+ // {{{continue;}}}
+ // z = 1;
+ // }
+ auto* decl_z = Decl(Var("z", ty.i32()));
+ auto* cont = Continue();
+ auto* assign_z = Assign(Source{{12, 34}}, "z", 1);
+ WrapInFunction(For(nullptr, false, nullptr,
+ Block(decl_z, Block(Block(Block(cont))), assign_z)));
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+ EXPECT_EQ(r()->error(), "12:34 warning: code is unreachable");
+ EXPECT_TRUE(Sem().Get(decl_z)->IsReachable());
+ EXPECT_TRUE(Sem().Get(cont)->IsReachable());
+ EXPECT_FALSE(Sem().Get(assign_z)->IsReachable());
+}
+
+TEST_F(ResolverControlBlockValidationTest, UnreachableCode_break) {
+ // switch (1) {
+ // case 1: {
+ // var z: i32;
+ // break;
+ // z = 1;
+ // default: {}
+ // }
+ auto* decl_z = Decl(Var("z", ty.i32()));
+ auto* brk = Break();
+ auto* assign_z = Assign(Source{{12, 34}}, "z", 1);
+ WrapInFunction( //
+ Block(Switch(1, //
+ Case(Expr(1), Block(decl_z, brk, assign_z)), //
+ DefaultCase())));
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+ EXPECT_EQ(r()->error(), "12:34 warning: code is unreachable");
+ EXPECT_TRUE(Sem().Get(decl_z)->IsReachable());
+ EXPECT_TRUE(Sem().Get(brk)->IsReachable());
+ EXPECT_FALSE(Sem().Get(assign_z)->IsReachable());
+}
+
+TEST_F(ResolverControlBlockValidationTest, UnreachableCode_break_InBlocks) {
+ // loop {
+ // switch (1) {
+ // case 1: { {{{break;}}} var a : u32 = 2;}
+ // default: {}
+ // }
+ // break;
+ // }
+ auto* decl_z = Decl(Var("z", ty.i32()));
+ auto* brk = Break();
+ auto* assign_z = Assign(Source{{12, 34}}, "z", 1);
+ WrapInFunction(Loop(Block(
+ Switch(1, //
+ Case(Expr(1), Block(decl_z, Block(Block(Block(brk))), assign_z)),
+ DefaultCase()), //
+ Break())));
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+ EXPECT_EQ(r()->error(), "12:34 warning: code is unreachable");
+ EXPECT_TRUE(Sem().Get(decl_z)->IsReachable());
+ EXPECT_TRUE(Sem().Get(brk)->IsReachable());
+ EXPECT_FALSE(Sem().Get(assign_z)->IsReachable());
+}
+
+TEST_F(ResolverControlBlockValidationTest,
+ SwitchConditionTypeMustMatchSelectorType2_Fail) {
+ // var a : u32 = 2;
+ // switch (a) {
+ // case 1: {}
+ // default: {}
+ // }
+ auto* var = Var("a", ty.i32(), Expr(2));
+
+ auto* block = Block(Decl(var), Switch("a", //
+ Case(Source{{12, 34}}, {Expr(1u)}), //
+ DefaultCase()));
+ WrapInFunction(block);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: the case selector values must have the same type as "
+ "the selector expression.");
+}
+
+TEST_F(ResolverControlBlockValidationTest,
+ SwitchConditionTypeMustMatchSelectorType_Fail) {
+ // var a : u32 = 2;
+ // switch (a) {
+ // case -1: {}
+ // default: {}
+ // }
+ auto* var = Var("a", ty.u32(), Expr(2u));
+
+ auto* block = Block(Decl(var), //
+ Switch("a", //
+ Case(Source{{12, 34}}, {Expr(-1)}), //
+ DefaultCase()));
+ WrapInFunction(block);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: the case selector values must have the same type as "
+ "the selector expression.");
+}
+
+TEST_F(ResolverControlBlockValidationTest,
+ NonUniqueCaseSelectorValueUint_Fail) {
+ // var a : u32 = 3;
+ // switch (a) {
+ // case 0u: {}
+ // case 2u, 3u, 2u: {}
+ // default: {}
+ // }
+ auto* var = Var("a", ty.u32(), Expr(3u));
+
+ auto* block = Block(Decl(var), //
+ Switch("a", //
+ Case(Expr(0u)),
+ Case({
+ Expr(Source{{12, 34}}, 2u),
+ Expr(3u),
+ Expr(Source{{56, 78}}, 2u),
+ }),
+ DefaultCase()));
+ WrapInFunction(block);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "56:78 error: duplicate switch case '2'\n"
+ "12:34 note: previous case declared here");
+}
+
+TEST_F(ResolverControlBlockValidationTest,
+ NonUniqueCaseSelectorValueSint_Fail) {
+ // var a : i32 = 2;
+ // switch (a) {
+ // case -10: {}
+ // case 0,1,2,-10: {}
+ // default: {}
+ // }
+ auto* var = Var("a", ty.i32(), Expr(2));
+
+ auto* block = Block(Decl(var), //
+ Switch("a", //
+ Case(Expr(Source{{12, 34}}, -10)),
+ Case({
+ Expr(0),
+ Expr(1),
+ Expr(2),
+ Expr(Source{{56, 78}}, -10),
+ }),
+ DefaultCase()));
+ WrapInFunction(block);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "56:78 error: duplicate switch case '-10'\n"
+ "12:34 note: previous case declared here");
+}
+
+TEST_F(ResolverControlBlockValidationTest,
+ LastClauseLastStatementIsFallthrough_Fail) {
+ // var a : i32 = 2;
+ // switch (a) {
+ // default: { fallthrough; }
+ // }
+ auto* var = Var("a", ty.i32(), Expr(2));
+ auto* fallthrough = create<ast::FallthroughStatement>(Source{{12, 34}});
+ auto* block = Block(Decl(var), //
+ Switch("a", //
+ DefaultCase(Block(fallthrough))));
+ WrapInFunction(block);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: a fallthrough statement must not be used in the last "
+ "switch case");
+}
+
+TEST_F(ResolverControlBlockValidationTest, SwitchCase_Pass) {
+ // var a : i32 = 2;
+ // switch (a) {
+ // default: {}
+ // case 5: {}
+ // }
+ auto* var = Var("a", ty.i32(), Expr(2));
+
+ auto* block = Block(Decl(var), //
+ Switch("a", //
+ DefaultCase(Source{{12, 34}}), //
+ Case(Expr(5))));
+ WrapInFunction(block);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverControlBlockValidationTest, SwitchCaseAlias_Pass) {
+ // type MyInt = u32;
+ // var v: MyInt;
+ // switch(v){
+ // default: {}
+ // }
+
+ auto* my_int = Alias("MyInt", ty.u32());
+ auto* var = Var("a", ty.Of(my_int), Expr(2u));
+ auto* block = Block(Decl(var), //
+ Switch("a", DefaultCase(Source{{12, 34}})));
+
+ WrapInFunction(block);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+} // namespace
+} // namespace tint
diff --git a/src/tint/resolver/dependency_graph.cc b/src/tint/resolver/dependency_graph.cc
new file mode 100644
index 0000000..8d58018
--- /dev/null
+++ b/src/tint/resolver/dependency_graph.cc
@@ -0,0 +1,736 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/resolver/dependency_graph.h"
+
+#include <string>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "src/tint/ast/continue_statement.h"
+#include "src/tint/ast/discard_statement.h"
+#include "src/tint/ast/fallthrough_statement.h"
+#include "src/tint/ast/traverse_expressions.h"
+#include "src/tint/scope_stack.h"
+#include "src/tint/sem/builtin.h"
+#include "src/tint/utils/defer.h"
+#include "src/tint/utils/map.h"
+#include "src/tint/utils/scoped_assignment.h"
+#include "src/tint/utils/unique_vector.h"
+
+#define TINT_DUMP_DEPENDENCY_GRAPH 0
+
+namespace tint {
+namespace resolver {
+namespace {
+
+// Forward declaration
+struct Global;
+
+/// Dependency describes how one global depends on another global
+struct DependencyInfo {
+ /// The source of the symbol that forms the dependency
+ Source source;
+ /// A string describing how the dependency is referenced. e.g. 'calls'
+ const char* action = nullptr;
+};
+
+/// DependencyEdge describes the two Globals used to define a dependency
+/// relationship.
+struct DependencyEdge {
+ /// The Global that depends on #to
+ const Global* from;
+ /// The Global that is depended on by #from
+ const Global* to;
+};
+
+/// DependencyEdgeCmp implements the contracts of std::equal_to<DependencyEdge>
+/// and std::hash<DependencyEdge>.
+struct DependencyEdgeCmp {
+ /// Equality operator
+ bool operator()(const DependencyEdge& lhs, const DependencyEdge& rhs) const {
+ return lhs.from == rhs.from && lhs.to == rhs.to;
+ }
+ /// Hashing operator
+ inline std::size_t operator()(const DependencyEdge& d) const {
+ return utils::Hash(d.from, d.to);
+ }
+};
+
+/// A map of DependencyEdge to DependencyInfo
+using DependencyEdges = std::unordered_map<DependencyEdge,
+ DependencyInfo,
+ DependencyEdgeCmp,
+ DependencyEdgeCmp>;
+
+/// Global describes a module-scope variable, type or function.
+struct Global {
+ explicit Global(const ast::Node* n) : node(n) {}
+
+ /// The declaration ast::Node
+ const ast::Node* node;
+ /// A list of dependencies that this global depends on
+ std::vector<Global*> deps;
+};
+
+/// A map of global name to Global
+using GlobalMap = std::unordered_map<Symbol, Global*>;
+
+/// Raises an ICE that a global ast::Node type was not handled by this system.
+void UnhandledNode(diag::List& diagnostics, const ast::Node* node) {
+ TINT_ICE(Resolver, diagnostics)
+ << "unhandled node type: " << node->TypeInfo().name;
+}
+
+/// Raises an error diagnostic with the given message and source.
+void AddError(diag::List& diagnostics,
+ const std::string& msg,
+ const Source& source) {
+ diagnostics.add_error(diag::System::Resolver, msg, source);
+}
+
+/// Raises a note diagnostic with the given message and source.
+void AddNote(diag::List& diagnostics,
+ const std::string& msg,
+ const Source& source) {
+ diagnostics.add_note(diag::System::Resolver, msg, source);
+}
+
+/// DependencyScanner is used to traverse a module to build the list of
+/// global-to-global dependencies.
+class DependencyScanner {
+ public:
+ /// Constructor
+ /// @param syms the program symbol table
+ /// @param globals_by_name map of global symbol to Global pointer
+ /// @param diagnostics diagnostic messages, appended with any errors found
+ /// @param graph the dependency graph to populate with resolved symbols
+ /// @param edges the map of globals-to-global dependency edges, which will
+ /// be populated by calls to Scan()
+ DependencyScanner(const SymbolTable& syms,
+ const GlobalMap& globals_by_name,
+ diag::List& diagnostics,
+ DependencyGraph& graph,
+ DependencyEdges& edges)
+ : symbols_(syms),
+ globals_(globals_by_name),
+ diagnostics_(diagnostics),
+ graph_(graph),
+ dependency_edges_(edges) {
+ // Register all the globals at global-scope
+ for (auto it : globals_by_name) {
+ scope_stack_.Set(it.first, it.second->node);
+ }
+ }
+
+ /// Walks the global declarations, resolving symbols, and determining the
+ /// dependencies of each global.
+ void Scan(Global* global) {
+ TINT_SCOPED_ASSIGNMENT(current_global_, global);
+ Switch(
+ global->node,
+ [&](const ast::Struct* str) {
+ Declare(str->name, str);
+ for (auto* member : str->members) {
+ TraverseType(member->type);
+ }
+ },
+ [&](const ast::Alias* alias) {
+ Declare(alias->name, alias);
+ TraverseType(alias->type);
+ },
+ [&](const ast::Function* func) {
+ Declare(func->symbol, func);
+ TraverseAttributes(func->attributes);
+ TraverseFunction(func);
+ },
+ [&](const ast::Variable* var) {
+ Declare(var->symbol, var);
+ TraverseType(var->type);
+ if (var->constructor) {
+ TraverseExpression(var->constructor);
+ }
+ },
+ [&](Default) { UnhandledNode(diagnostics_, global->node); });
+ }
+
+ private:
+ /// Traverses the function, performing symbol resolution and determining
+ /// global dependencies.
+ void TraverseFunction(const ast::Function* func) {
+ // Perform symbol resolution on all the parameter types before registering
+ // the parameters themselves. This allows the case of declaring a parameter
+ // with the same identifier as its type.
+ for (auto* param : func->params) {
+ TraverseType(param->type);
+ }
+ // Resolve the return type
+ TraverseType(func->return_type);
+
+ // Push the scope stack for the parameters and function body.
+ scope_stack_.Push();
+ TINT_DEFER(scope_stack_.Pop());
+
+ for (auto* param : func->params) {
+ if (auto* shadows = scope_stack_.Get(param->symbol)) {
+ graph_.shadows.emplace(param, shadows);
+ }
+ Declare(param->symbol, param);
+ }
+ if (func->body) {
+ TraverseStatements(func->body->statements);
+ }
+ }
+
+ /// Traverses the statements, performing symbol resolution and determining
+ /// global dependencies.
+ void TraverseStatements(const ast::StatementList& stmts) {
+ for (auto* s : stmts) {
+ TraverseStatement(s);
+ }
+ }
+
+ /// Traverses the statement, performing symbol resolution and determining
+ /// global dependencies.
+ void TraverseStatement(const ast::Statement* stmt) {
+ if (!stmt) {
+ return;
+ }
+ Switch(
+ stmt, //
+ [&](const ast::AssignmentStatement* a) {
+ TraverseExpression(a->lhs);
+ TraverseExpression(a->rhs);
+ },
+ [&](const ast::BlockStatement* b) {
+ scope_stack_.Push();
+ TINT_DEFER(scope_stack_.Pop());
+ TraverseStatements(b->statements);
+ },
+ [&](const ast::CallStatement* r) { //
+ TraverseExpression(r->expr);
+ },
+ [&](const ast::ForLoopStatement* l) {
+ scope_stack_.Push();
+ TINT_DEFER(scope_stack_.Pop());
+ TraverseStatement(l->initializer);
+ TraverseExpression(l->condition);
+ TraverseStatement(l->continuing);
+ TraverseStatement(l->body);
+ },
+ [&](const ast::LoopStatement* l) {
+ scope_stack_.Push();
+ TINT_DEFER(scope_stack_.Pop());
+ TraverseStatements(l->body->statements);
+ TraverseStatement(l->continuing);
+ },
+ [&](const ast::IfStatement* i) {
+ TraverseExpression(i->condition);
+ TraverseStatement(i->body);
+ for (auto* e : i->else_statements) {
+ TraverseExpression(e->condition);
+ TraverseStatement(e->body);
+ }
+ },
+ [&](const ast::ReturnStatement* r) { //
+ TraverseExpression(r->value);
+ },
+ [&](const ast::SwitchStatement* s) {
+ TraverseExpression(s->condition);
+ for (auto* c : s->body) {
+ for (auto* sel : c->selectors) {
+ TraverseExpression(sel);
+ }
+ TraverseStatement(c->body);
+ }
+ },
+ [&](const ast::VariableDeclStatement* v) {
+ if (auto* shadows = scope_stack_.Get(v->variable->symbol)) {
+ graph_.shadows.emplace(v->variable, shadows);
+ }
+ TraverseType(v->variable->type);
+ TraverseExpression(v->variable->constructor);
+ Declare(v->variable->symbol, v->variable);
+ },
+ [&](Default) {
+ if (!stmt->IsAnyOf<ast::BreakStatement, ast::ContinueStatement,
+ ast::DiscardStatement,
+ ast::FallthroughStatement>()) {
+ UnhandledNode(diagnostics_, stmt);
+ }
+ });
+ }
+
+ /// Adds the symbol definition to the current scope, raising an error if two
+ /// symbols collide within the same scope.
+ void Declare(Symbol symbol, const ast::Node* node) {
+ auto* old = scope_stack_.Set(symbol, node);
+ if (old != nullptr && node != old) {
+ auto name = symbols_.NameFor(symbol);
+ AddError(diagnostics_, "redeclaration of '" + name + "'", node->source);
+ AddNote(diagnostics_, "'" + name + "' previously declared here",
+ old->source);
+ }
+ }
+
+ /// Traverses the expression, performing symbol resolution and determining
+ /// global dependencies.
+ void TraverseExpression(const ast::Expression* root) {
+ if (!root) {
+ return;
+ }
+ ast::TraverseExpressions(
+ root, diagnostics_, [&](const ast::Expression* expr) {
+ Switch(
+ expr,
+ [&](const ast::IdentifierExpression* ident) {
+ AddDependency(ident, ident->symbol, "identifier", "references");
+ },
+ [&](const ast::CallExpression* call) {
+ if (call->target.name) {
+ AddDependency(call->target.name, call->target.name->symbol,
+ "function", "calls");
+ }
+ if (call->target.type) {
+ TraverseType(call->target.type);
+ }
+ },
+ [&](const ast::BitcastExpression* cast) {
+ TraverseType(cast->type);
+ });
+ return ast::TraverseAction::Descend;
+ });
+ }
+
+ /// Traverses the type node, performing symbol resolution and determining
+ /// global dependencies.
+ void TraverseType(const ast::Type* ty) {
+ if (!ty) {
+ return;
+ }
+ Switch(
+ ty, //
+ [&](const ast::Array* arr) {
+ TraverseType(arr->type); //
+ TraverseExpression(arr->count);
+ },
+ [&](const ast::Atomic* atomic) { //
+ TraverseType(atomic->type);
+ },
+ [&](const ast::Matrix* mat) { //
+ TraverseType(mat->type);
+ },
+ [&](const ast::Pointer* ptr) { //
+ TraverseType(ptr->type);
+ },
+ [&](const ast::TypeName* tn) { //
+ AddDependency(tn, tn->name, "type", "references");
+ },
+ [&](const ast::Vector* vec) { //
+ TraverseType(vec->type);
+ },
+ [&](const ast::SampledTexture* tex) { //
+ TraverseType(tex->type);
+ },
+ [&](const ast::MultisampledTexture* tex) { //
+ TraverseType(tex->type);
+ },
+ [&](Default) {
+ if (!ty->IsAnyOf<ast::Void, ast::Bool, ast::I32, ast::U32, ast::F32,
+ ast::DepthTexture, ast::DepthMultisampledTexture,
+ ast::StorageTexture, ast::ExternalTexture,
+ ast::Sampler>()) {
+ UnhandledNode(diagnostics_, ty);
+ }
+ });
+ }
+
+ /// Traverses the attribute list, performing symbol resolution and
+ /// determining global dependencies.
+ void TraverseAttributes(const ast::AttributeList& attrs) {
+ for (auto* attr : attrs) {
+ TraverseAttribute(attr);
+ }
+ }
+
+ /// Traverses the attribute, performing symbol resolution and determining
+ /// global dependencies.
+ void TraverseAttribute(const ast::Attribute* attr) {
+ if (auto* wg = attr->As<ast::WorkgroupAttribute>()) {
+ TraverseExpression(wg->x);
+ TraverseExpression(wg->y);
+ TraverseExpression(wg->z);
+ return;
+ }
+ if (attr->IsAnyOf<
+ ast::BindingAttribute, ast::BuiltinAttribute, ast::GroupAttribute,
+ ast::IdAttribute, ast::InternalAttribute, ast::InterpolateAttribute,
+ ast::InvariantAttribute, ast::LocationAttribute,
+ ast::StageAttribute, ast::StrideAttribute,
+ ast::StructBlockAttribute, ast::StructMemberAlignAttribute,
+ ast::StructMemberOffsetAttribute,
+ ast::StructMemberSizeAttribute>()) {
+ return;
+ }
+
+ UnhandledNode(diagnostics_, attr);
+ }
+
+ /// Adds the dependency from `from` to `to`, erroring if `to` cannot be
+ /// resolved.
+ void AddDependency(const ast::Node* from,
+ Symbol to,
+ const char* use,
+ const char* action) {
+ auto* resolved = scope_stack_.Get(to);
+ if (!resolved) {
+ if (!IsBuiltin(to)) {
+ UnknownSymbol(to, from->source, use);
+ return;
+ }
+ }
+
+ if (auto* global = utils::Lookup(globals_, to);
+ global && global->node == resolved) {
+ if (dependency_edges_
+ .emplace(DependencyEdge{current_global_, global},
+ DependencyInfo{from->source, action})
+ .second) {
+ current_global_->deps.emplace_back(global);
+ }
+ }
+
+ graph_.resolved_symbols.emplace(from, resolved);
+ }
+
+ /// @returns true if `name` is the name of a builtin function
+ bool IsBuiltin(Symbol name) const {
+ return sem::ParseBuiltinType(symbols_.NameFor(name)) !=
+ sem::BuiltinType::kNone;
+ }
+
+ /// Appends an error to the diagnostics that the given symbol cannot be
+ /// resolved.
+ void UnknownSymbol(Symbol name, Source source, const char* use) {
+ AddError(
+ diagnostics_,
+ "unknown " + std::string(use) + ": '" + symbols_.NameFor(name) + "'",
+ source);
+ }
+
+ using VariableMap = std::unordered_map<Symbol, const ast::Variable*>;
+ const SymbolTable& symbols_;
+ const GlobalMap& globals_;
+ diag::List& diagnostics_;
+ DependencyGraph& graph_;
+ DependencyEdges& dependency_edges_;
+
+ ScopeStack<const ast::Node*> scope_stack_;
+ Global* current_global_ = nullptr;
+};
+
+/// The global dependency analysis system
+struct DependencyAnalysis {
+ public:
+ /// Constructor
+ DependencyAnalysis(const SymbolTable& symbols,
+ diag::List& diagnostics,
+ DependencyGraph& graph)
+ : symbols_(symbols), diagnostics_(diagnostics), graph_(graph) {}
+
+ /// Performs global dependency analysis on the module, emitting any errors to
+ /// #diagnostics.
+ /// @returns true if analysis found no errors, otherwise false.
+ bool Run(const ast::Module& module) {
+ // Collect all the named globals from the AST module
+ GatherGlobals(module);
+
+ // Traverse the named globals to build the dependency graph
+ DetermineDependencies();
+
+ // Sort the globals into dependency order
+ SortGlobals();
+
+ // Dump the dependency graph if TINT_DUMP_DEPENDENCY_GRAPH is non-zero
+ DumpDependencyGraph();
+
+ graph_.ordered_globals = std::move(sorted_);
+
+ return !diagnostics_.contains_errors();
+ }
+
+ private:
+ /// @param node the ast::Node of the global declaration
+ /// @returns the symbol of the global declaration node
+ /// @note will raise an ICE if the node is not a type, function or variable
+ /// declaration
+ Symbol SymbolOf(const ast::Node* node) const {
+ return Switch(
+ node, //
+ [&](const ast::TypeDecl* td) { return td->name; },
+ [&](const ast::Function* func) { return func->symbol; },
+ [&](const ast::Variable* var) { return var->symbol; },
+ [&](Default) {
+ UnhandledNode(diagnostics_, node);
+ return Symbol{};
+ });
+ }
+
+ /// @param node the ast::Node of the global declaration
+ /// @returns the name of the global declaration node
+ /// @note will raise an ICE if the node is not a type, function or variable
+ /// declaration
+ std::string NameOf(const ast::Node* node) const {
+ return symbols_.NameFor(SymbolOf(node));
+ }
+
+ /// @param node the ast::Node of the global declaration
+ /// @returns a string representation of the global declaration kind
+ /// @note will raise an ICE if the node is not a type, function or variable
+ /// declaration
+ std::string KindOf(const ast::Node* node) {
+ return Switch(
+ node, //
+ [&](const ast::Struct*) { return "struct"; },
+ [&](const ast::Alias*) { return "alias"; },
+ [&](const ast::Function*) { return "function"; },
+ [&](const ast::Variable* var) { return var->is_const ? "let" : "var"; },
+ [&](Default) {
+ UnhandledNode(diagnostics_, node);
+ return "<error>";
+ });
+ }
+
+ /// Traverses `module`, collecting all the global declarations and populating
+ /// the #globals and #declaration_order fields.
+ void GatherGlobals(const ast::Module& module) {
+ for (auto* node : module.GlobalDeclarations()) {
+ auto* global = allocator_.Create(node);
+ globals_.emplace(SymbolOf(node), global);
+ declaration_order_.emplace_back(global);
+ }
+ }
+
+ /// Walks the global declarations, determining the dependencies of each global
+ /// and adding these to each global's Global::deps field.
+ void DetermineDependencies() {
+ DependencyScanner scanner(symbols_, globals_, diagnostics_, graph_,
+ dependency_edges_);
+ for (auto* global : declaration_order_) {
+ scanner.Scan(global);
+ }
+ }
+
+ /// Performs a depth-first traversal of `root`'s dependencies, calling `enter`
+ /// as the function decends into each dependency and `exit` when bubbling back
+ /// up towards the root.
+ /// @param enter is a function with the signature: `bool(Global*)`. The
+ /// `enter` function returns true if TraverseDependencies() should traverse
+ /// the dependency, otherwise it will be skipped.
+ /// @param exit is a function with the signature: `void(Global*)`. The `exit`
+ /// function is only called if the corresponding `enter` call returned true.
+ template <typename ENTER, typename EXIT>
+ void TraverseDependencies(const Global* root, ENTER&& enter, EXIT&& exit) {
+ // Entry is a single entry in the traversal stack. Entry points to a
+ // dep_idx'th dependency of Entry::global.
+ struct Entry {
+ const Global* global; // The parent global
+ size_t dep_idx; // The dependency index in `global->deps`
+ };
+
+ if (!enter(root)) {
+ return;
+ }
+
+ std::vector<Entry> stack{Entry{root, 0}};
+ while (true) {
+ auto& entry = stack.back();
+ // Have we exhausted the dependencies of entry.global?
+ if (entry.dep_idx < entry.global->deps.size()) {
+ // No, there's more dependencies to traverse.
+ auto& dep = entry.global->deps[entry.dep_idx];
+ // Does the caller want to enter this dependency?
+ if (enter(dep)) { // Yes.
+ stack.push_back(Entry{dep, 0}); // Enter the dependency.
+ } else {
+ entry.dep_idx++; // No. Skip this node.
+ }
+ } else {
+ // Yes. Time to back up.
+ // Exit this global, pop the stack, and if there's another parent node,
+ // increment its dependency index, and loop again.
+ exit(entry.global);
+ stack.pop_back();
+ if (stack.empty()) {
+ return; // All done.
+ }
+ stack.back().dep_idx++;
+ }
+ }
+ }
+
+ /// SortGlobals sorts the globals into dependency order, erroring if cyclic
+ /// dependencies are found. The sorted dependencies are assigned to #sorted.
+ void SortGlobals() {
+ if (diagnostics_.contains_errors()) {
+ return; // This code assumes there are no undeclared identifiers.
+ }
+
+ std::unordered_set<const Global*> visited;
+ for (auto* global : declaration_order_) {
+ utils::UniqueVector<const Global*> stack;
+ TraverseDependencies(
+ global,
+ [&](const Global* g) { // Enter
+ if (!stack.add(g)) {
+ CyclicDependencyFound(g, stack);
+ return false;
+ }
+ if (sorted_.contains(g->node)) {
+ // Visited this global already.
+ // stack was pushed, but exit() will not be called when we return
+ // false, so pop here.
+ stack.pop_back();
+ return false;
+ }
+ return true;
+ },
+ [&](const Global* g) { // Exit. Only called if Enter returned true.
+ sorted_.add(g->node);
+ stack.pop_back();
+ });
+
+ sorted_.add(global->node);
+
+ if (!stack.empty()) {
+ // Each stack.push() must have a corresponding stack.pop_back().
+ TINT_ICE(Resolver, diagnostics_)
+ << "stack not empty after returning from TraverseDependencies()";
+ }
+ }
+ }
+
+ /// DepInfoFor() looks up the global dependency information for the dependency
+ /// of global `from` depending on `to`.
+ /// @note will raise an ICE if the edge is not found.
+ DependencyInfo DepInfoFor(const Global* from, const Global* to) const {
+ auto it = dependency_edges_.find(DependencyEdge{from, to});
+ if (it != dependency_edges_.end()) {
+ return it->second;
+ }
+ TINT_ICE(Resolver, diagnostics_)
+ << "failed to find dependency info for edge: '" << NameOf(from->node)
+ << "' -> '" << NameOf(to->node) << "'";
+ return {};
+ }
+
+ /// CyclicDependencyFound() emits an error diagnostic for a cyclic dependency.
+ /// @param root is the global that starts the cyclic dependency, which must be
+ /// found in `stack`.
+ /// @param stack is the global dependency stack that contains a loop.
+ void CyclicDependencyFound(const Global* root,
+ const std::vector<const Global*>& stack) {
+ std::stringstream msg;
+ msg << "cyclic dependency found: ";
+ constexpr size_t kLoopNotStarted = ~0u;
+ size_t loop_start = kLoopNotStarted;
+ for (size_t i = 0; i < stack.size(); i++) {
+ auto* e = stack[i];
+ if (loop_start == kLoopNotStarted && e == root) {
+ loop_start = i;
+ }
+ if (loop_start != kLoopNotStarted) {
+ msg << "'" << NameOf(e->node) << "' -> ";
+ }
+ }
+ msg << "'" << NameOf(root->node) << "'";
+ AddError(diagnostics_, msg.str(), root->node->source);
+ for (size_t i = loop_start; i < stack.size(); i++) {
+ auto* from = stack[i];
+ auto* to = (i + 1 < stack.size()) ? stack[i + 1] : stack[loop_start];
+ auto info = DepInfoFor(from, to);
+ AddNote(diagnostics_,
+ KindOf(from->node) + " '" + NameOf(from->node) + "' " +
+ info.action + " " + KindOf(to->node) + " '" +
+ NameOf(to->node) + "' here",
+ info.source);
+ }
+ }
+
+ void DumpDependencyGraph() {
+#if TINT_DUMP_DEPENDENCY_GRAPH == 0
+ if ((true)) {
+ return;
+ }
+#endif // TINT_DUMP_DEPENDENCY_GRAPH
+ printf("=========================\n");
+ printf("------ declaration ------ \n");
+ for (auto* global : declaration_order_) {
+ printf("%s\n", NameOf(global->node).c_str());
+ }
+ printf("------ dependencies ------ \n");
+ for (auto* node : sorted_) {
+ auto symbol = SymbolOf(node);
+ auto* global = globals_.at(symbol);
+ printf("%s depends on:\n", symbols_.NameFor(symbol).c_str());
+ for (auto* dep : global->deps) {
+ printf(" %s\n", NameOf(dep->node).c_str());
+ }
+ }
+ printf("=========================\n");
+ }
+
+ /// Program symbols
+ const SymbolTable& symbols_;
+
+ /// Program diagnostics
+ diag::List& diagnostics_;
+
+ /// The resulting dependency graph
+ DependencyGraph& graph_;
+
+ /// Allocator of Globals
+ BlockAllocator<Global> allocator_;
+
+ /// Global map, keyed by name. Populated by GatherGlobals().
+ GlobalMap globals_;
+
+ /// Map of DependencyEdge to DependencyInfo. Populated by
+ /// DetermineDependencies().
+ DependencyEdges dependency_edges_;
+
+ /// Globals in declaration order. Populated by GatherGlobals().
+ std::vector<Global*> declaration_order_;
+
+ /// Globals in sorted dependency order. Populated by SortGlobals().
+ utils::UniqueVector<const ast::Node*> sorted_;
+};
+
+} // namespace
+
+DependencyGraph::DependencyGraph() = default;
+DependencyGraph::DependencyGraph(DependencyGraph&&) = default;
+DependencyGraph::~DependencyGraph() = default;
+
+bool DependencyGraph::Build(const ast::Module& module,
+ const SymbolTable& symbols,
+ diag::List& diagnostics,
+ DependencyGraph& output) {
+ DependencyAnalysis da{symbols, diagnostics, output};
+ return da.Run(module);
+}
+
+} // namespace resolver
+} // namespace tint
diff --git a/src/tint/resolver/dependency_graph.h b/src/tint/resolver/dependency_graph.h
new file mode 100644
index 0000000..a943708
--- /dev/null
+++ b/src/tint/resolver/dependency_graph.h
@@ -0,0 +1,66 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_RESOLVER_DEPENDENCY_GRAPH_H_
+#define SRC_TINT_RESOLVER_DEPENDENCY_GRAPH_H_
+
+#include <unordered_map>
+#include <vector>
+
+#include "src/tint/ast/module.h"
+#include "src/tint/diagnostic/diagnostic.h"
+
+namespace tint {
+namespace resolver {
+
+/// DependencyGraph holds information about module-scope declaration dependency
+/// analysis and symbol resolutions.
+struct DependencyGraph {
+ /// Constructor
+ DependencyGraph();
+ /// Move-constructor
+ DependencyGraph(DependencyGraph&&);
+ /// Destructor
+ ~DependencyGraph();
+
+ /// Build() performs symbol resolution and dependency analysis on `module`,
+ /// populating `output` with the resulting dependency graph.
+ /// @param module the AST module to analyse
+ /// @param symbols the symbol table
+ /// @param diagnostics the diagnostic list to populate with errors / warnings
+ /// @param output the resulting DependencyGraph
+ /// @returns true on success, false on error
+ static bool Build(const ast::Module& module,
+ const SymbolTable& symbols,
+ diag::List& diagnostics,
+ DependencyGraph& output);
+
+ /// All globals in dependency-sorted order.
+ std::vector<const ast::Node*> ordered_globals;
+
+ /// Map of ast::IdentifierExpression or ast::TypeName to a type, function, or
+ /// variable that declares the symbol.
+ std::unordered_map<const ast::Node*, const ast::Node*> resolved_symbols;
+
+ /// Map of ast::Variable to a type, function, or variable that is shadowed by
+ /// the variable key. A declaration (X) shadows another (Y) if X and Y use
+ /// the same symbol, and X is declared in a sub-scope of the scope that
+ /// declares Y.
+ std::unordered_map<const ast::Variable*, const ast::Node*> shadows;
+};
+
+} // namespace resolver
+} // namespace tint
+
+#endif // SRC_TINT_RESOLVER_DEPENDENCY_GRAPH_H_
diff --git a/src/tint/resolver/dependency_graph_test.cc b/src/tint/resolver/dependency_graph_test.cc
new file mode 100644
index 0000000..094c908
--- /dev/null
+++ b/src/tint/resolver/dependency_graph_test.cc
@@ -0,0 +1,1342 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+#include <tuple>
+#include <utility>
+
+#include "gmock/gmock.h"
+#include "src/tint/resolver/dependency_graph.h"
+#include "src/tint/resolver/resolver_test_helper.h"
+
+namespace tint {
+namespace resolver {
+namespace {
+
+using ::testing::ElementsAre;
+
+template <typename T>
+class ResolverDependencyGraphTestWithParam : public ResolverTestWithParam<T> {
+ public:
+ DependencyGraph Build(std::string expected_error = "") {
+ DependencyGraph graph;
+ auto result = DependencyGraph::Build(this->AST(), this->Symbols(),
+ this->Diagnostics(), graph);
+ if (expected_error.empty()) {
+ EXPECT_TRUE(result) << this->Diagnostics().str();
+ } else {
+ EXPECT_FALSE(result);
+ EXPECT_EQ(expected_error, this->Diagnostics().str());
+ }
+ return graph;
+ }
+};
+
+using ResolverDependencyGraphTest =
+ ResolverDependencyGraphTestWithParam<::testing::Test>;
+
+////////////////////////////////////////////////////////////////////////////////
+// Parameterized test helpers
+////////////////////////////////////////////////////////////////////////////////
+
+/// SymbolDeclKind is used by parameterized tests to enumerate the different
+/// kinds of symbol declarations.
+enum class SymbolDeclKind {
+ GlobalVar,
+ GlobalLet,
+ Alias,
+ Struct,
+ Function,
+ Parameter,
+ LocalVar,
+ LocalLet,
+ NestedLocalVar,
+ NestedLocalLet,
+};
+
+static constexpr SymbolDeclKind kAllSymbolDeclKinds[] = {
+ SymbolDeclKind::GlobalVar, SymbolDeclKind::GlobalLet,
+ SymbolDeclKind::Alias, SymbolDeclKind::Struct,
+ SymbolDeclKind::Function, SymbolDeclKind::Parameter,
+ SymbolDeclKind::LocalVar, SymbolDeclKind::LocalLet,
+ SymbolDeclKind::NestedLocalVar, SymbolDeclKind::NestedLocalLet,
+};
+
+static constexpr SymbolDeclKind kTypeDeclKinds[] = {
+ SymbolDeclKind::Alias,
+ SymbolDeclKind::Struct,
+};
+
+static constexpr SymbolDeclKind kValueDeclKinds[] = {
+ SymbolDeclKind::GlobalVar, SymbolDeclKind::GlobalLet,
+ SymbolDeclKind::Parameter, SymbolDeclKind::LocalVar,
+ SymbolDeclKind::LocalLet, SymbolDeclKind::NestedLocalVar,
+ SymbolDeclKind::NestedLocalLet,
+};
+
+static constexpr SymbolDeclKind kGlobalDeclKinds[] = {
+ SymbolDeclKind::GlobalVar, SymbolDeclKind::GlobalLet, SymbolDeclKind::Alias,
+ SymbolDeclKind::Struct, SymbolDeclKind::Function,
+};
+
+static constexpr SymbolDeclKind kLocalDeclKinds[] = {
+ SymbolDeclKind::Parameter, SymbolDeclKind::LocalVar,
+ SymbolDeclKind::LocalLet, SymbolDeclKind::NestedLocalVar,
+ SymbolDeclKind::NestedLocalLet,
+};
+
+static constexpr SymbolDeclKind kGlobalValueDeclKinds[] = {
+ SymbolDeclKind::GlobalVar,
+ SymbolDeclKind::GlobalLet,
+};
+
+static constexpr SymbolDeclKind kFuncDeclKinds[] = {
+ SymbolDeclKind::Function,
+};
+
+/// SymbolUseKind is used by parameterized tests to enumerate the different
+/// kinds of symbol uses.
+enum class SymbolUseKind {
+ GlobalVarType,
+ GlobalVarArrayElemType,
+ GlobalVarArraySizeValue,
+ GlobalVarVectorElemType,
+ GlobalVarMatrixElemType,
+ GlobalVarSampledTexElemType,
+ GlobalVarMultisampledTexElemType,
+ GlobalVarValue,
+ GlobalLetType,
+ GlobalLetArrayElemType,
+ GlobalLetArraySizeValue,
+ GlobalLetVectorElemType,
+ GlobalLetMatrixElemType,
+ GlobalLetValue,
+ AliasType,
+ StructMemberType,
+ CallFunction,
+ ParameterType,
+ LocalVarType,
+ LocalVarArrayElemType,
+ LocalVarArraySizeValue,
+ LocalVarVectorElemType,
+ LocalVarMatrixElemType,
+ LocalVarValue,
+ LocalLetType,
+ LocalLetValue,
+ NestedLocalVarType,
+ NestedLocalVarValue,
+ NestedLocalLetType,
+ NestedLocalLetValue,
+ WorkgroupSizeValue,
+};
+
+static constexpr SymbolUseKind kTypeUseKinds[] = {
+ SymbolUseKind::GlobalVarType,
+ SymbolUseKind::GlobalVarArrayElemType,
+ SymbolUseKind::GlobalVarArraySizeValue,
+ SymbolUseKind::GlobalVarVectorElemType,
+ SymbolUseKind::GlobalVarMatrixElemType,
+ SymbolUseKind::GlobalVarSampledTexElemType,
+ SymbolUseKind::GlobalVarMultisampledTexElemType,
+ SymbolUseKind::GlobalLetType,
+ SymbolUseKind::GlobalLetArrayElemType,
+ SymbolUseKind::GlobalLetArraySizeValue,
+ SymbolUseKind::GlobalLetVectorElemType,
+ SymbolUseKind::GlobalLetMatrixElemType,
+ SymbolUseKind::AliasType,
+ SymbolUseKind::StructMemberType,
+ SymbolUseKind::ParameterType,
+ SymbolUseKind::LocalVarType,
+ SymbolUseKind::LocalVarArrayElemType,
+ SymbolUseKind::LocalVarArraySizeValue,
+ SymbolUseKind::LocalVarVectorElemType,
+ SymbolUseKind::LocalVarMatrixElemType,
+ SymbolUseKind::LocalLetType,
+ SymbolUseKind::NestedLocalVarType,
+ SymbolUseKind::NestedLocalLetType,
+};
+
+static constexpr SymbolUseKind kValueUseKinds[] = {
+ SymbolUseKind::GlobalVarValue, SymbolUseKind::GlobalLetValue,
+ SymbolUseKind::LocalVarValue, SymbolUseKind::LocalLetValue,
+ SymbolUseKind::NestedLocalVarValue, SymbolUseKind::NestedLocalLetValue,
+ SymbolUseKind::WorkgroupSizeValue,
+};
+
+static constexpr SymbolUseKind kFuncUseKinds[] = {
+ SymbolUseKind::CallFunction,
+};
+
+/// @returns the description of the symbol declaration kind.
+/// @note: This differs from the strings used in diagnostic messages.
+std::ostream& operator<<(std::ostream& out, SymbolDeclKind kind) {
+ switch (kind) {
+ case SymbolDeclKind::GlobalVar:
+ return out << "global var";
+ case SymbolDeclKind::GlobalLet:
+ return out << "global let";
+ case SymbolDeclKind::Alias:
+ return out << "alias";
+ case SymbolDeclKind::Struct:
+ return out << "struct";
+ case SymbolDeclKind::Function:
+ return out << "function";
+ case SymbolDeclKind::Parameter:
+ return out << "parameter";
+ case SymbolDeclKind::LocalVar:
+ return out << "local var";
+ case SymbolDeclKind::LocalLet:
+ return out << "local let";
+ case SymbolDeclKind::NestedLocalVar:
+ return out << "nested local var";
+ case SymbolDeclKind::NestedLocalLet:
+ return out << "nested local let";
+ }
+ return out << "<unknown>";
+}
+
+/// @returns the description of the symbol use kind.
+/// @note: This differs from the strings used in diagnostic messages.
+std::ostream& operator<<(std::ostream& out, SymbolUseKind kind) {
+ switch (kind) {
+ case SymbolUseKind::GlobalVarType:
+ return out << "global var type";
+ case SymbolUseKind::GlobalVarValue:
+ return out << "global var value";
+ case SymbolUseKind::GlobalVarArrayElemType:
+ return out << "global var array element type";
+ case SymbolUseKind::GlobalVarArraySizeValue:
+ return out << "global var array size value";
+ case SymbolUseKind::GlobalVarVectorElemType:
+ return out << "global var vector element type";
+ case SymbolUseKind::GlobalVarMatrixElemType:
+ return out << "global var matrix element type";
+ case SymbolUseKind::GlobalVarSampledTexElemType:
+ return out << "global var sampled_texture element type";
+ case SymbolUseKind::GlobalVarMultisampledTexElemType:
+ return out << "global var multisampled_texture element type";
+ case SymbolUseKind::GlobalLetType:
+ return out << "global let type";
+ case SymbolUseKind::GlobalLetValue:
+ return out << "global let value";
+ case SymbolUseKind::GlobalLetArrayElemType:
+ return out << "global let array element type";
+ case SymbolUseKind::GlobalLetArraySizeValue:
+ return out << "global let array size value";
+ case SymbolUseKind::GlobalLetVectorElemType:
+ return out << "global let vector element type";
+ case SymbolUseKind::GlobalLetMatrixElemType:
+ return out << "global let matrix element type";
+ case SymbolUseKind::AliasType:
+ return out << "alias type";
+ case SymbolUseKind::StructMemberType:
+ return out << "struct member type";
+ case SymbolUseKind::CallFunction:
+ return out << "call function";
+ case SymbolUseKind::ParameterType:
+ return out << "parameter type";
+ case SymbolUseKind::LocalVarType:
+ return out << "local var type";
+ case SymbolUseKind::LocalVarArrayElemType:
+ return out << "local var array element type";
+ case SymbolUseKind::LocalVarArraySizeValue:
+ return out << "local var array size value";
+ case SymbolUseKind::LocalVarVectorElemType:
+ return out << "local var vector element type";
+ case SymbolUseKind::LocalVarMatrixElemType:
+ return out << "local var matrix element type";
+ case SymbolUseKind::LocalVarValue:
+ return out << "local var value";
+ case SymbolUseKind::LocalLetType:
+ return out << "local let type";
+ case SymbolUseKind::LocalLetValue:
+ return out << "local let value";
+ case SymbolUseKind::NestedLocalVarType:
+ return out << "nested local var type";
+ case SymbolUseKind::NestedLocalVarValue:
+ return out << "nested local var value";
+ case SymbolUseKind::NestedLocalLetType:
+ return out << "nested local let type";
+ case SymbolUseKind::NestedLocalLetValue:
+ return out << "nested local let value";
+ case SymbolUseKind::WorkgroupSizeValue:
+ return out << "workgroup size value";
+ }
+ return out << "<unknown>";
+}
+
+/// @returns the the diagnostic message name used for the given use
+std::string DiagString(SymbolUseKind kind) {
+ switch (kind) {
+ case SymbolUseKind::GlobalVarType:
+ case SymbolUseKind::GlobalVarArrayElemType:
+ case SymbolUseKind::GlobalVarVectorElemType:
+ case SymbolUseKind::GlobalVarMatrixElemType:
+ case SymbolUseKind::GlobalVarSampledTexElemType:
+ case SymbolUseKind::GlobalVarMultisampledTexElemType:
+ case SymbolUseKind::GlobalLetType:
+ case SymbolUseKind::GlobalLetArrayElemType:
+ case SymbolUseKind::GlobalLetVectorElemType:
+ case SymbolUseKind::GlobalLetMatrixElemType:
+ case SymbolUseKind::AliasType:
+ case SymbolUseKind::StructMemberType:
+ case SymbolUseKind::ParameterType:
+ case SymbolUseKind::LocalVarType:
+ case SymbolUseKind::LocalVarArrayElemType:
+ case SymbolUseKind::LocalVarVectorElemType:
+ case SymbolUseKind::LocalVarMatrixElemType:
+ case SymbolUseKind::LocalLetType:
+ case SymbolUseKind::NestedLocalVarType:
+ case SymbolUseKind::NestedLocalLetType:
+ return "type";
+ case SymbolUseKind::GlobalVarValue:
+ case SymbolUseKind::GlobalVarArraySizeValue:
+ case SymbolUseKind::GlobalLetValue:
+ case SymbolUseKind::GlobalLetArraySizeValue:
+ case SymbolUseKind::LocalVarValue:
+ case SymbolUseKind::LocalVarArraySizeValue:
+ case SymbolUseKind::LocalLetValue:
+ case SymbolUseKind::NestedLocalVarValue:
+ case SymbolUseKind::NestedLocalLetValue:
+ case SymbolUseKind::WorkgroupSizeValue:
+ return "identifier";
+ case SymbolUseKind::CallFunction:
+ return "function";
+ }
+ return "<unknown>";
+}
+
+/// @returns the declaration scope depth for the symbol declaration kind.
+/// Globals are at depth 0, parameters and locals are at depth 1,
+/// nested locals are at depth 2.
+int ScopeDepth(SymbolDeclKind kind) {
+ switch (kind) {
+ case SymbolDeclKind::GlobalVar:
+ case SymbolDeclKind::GlobalLet:
+ case SymbolDeclKind::Alias:
+ case SymbolDeclKind::Struct:
+ case SymbolDeclKind::Function:
+ return 0;
+ case SymbolDeclKind::Parameter:
+ case SymbolDeclKind::LocalVar:
+ case SymbolDeclKind::LocalLet:
+ return 1;
+ case SymbolDeclKind::NestedLocalVar:
+ case SymbolDeclKind::NestedLocalLet:
+ return 2;
+ }
+ return -1;
+}
+
+/// @returns the use depth for the symbol use kind.
+/// Globals are at depth 0, parameters and locals are at depth 1,
+/// nested locals are at depth 2.
+int ScopeDepth(SymbolUseKind kind) {
+ switch (kind) {
+ case SymbolUseKind::GlobalVarType:
+ case SymbolUseKind::GlobalVarValue:
+ case SymbolUseKind::GlobalVarArrayElemType:
+ case SymbolUseKind::GlobalVarArraySizeValue:
+ case SymbolUseKind::GlobalVarVectorElemType:
+ case SymbolUseKind::GlobalVarMatrixElemType:
+ case SymbolUseKind::GlobalVarSampledTexElemType:
+ case SymbolUseKind::GlobalVarMultisampledTexElemType:
+ case SymbolUseKind::GlobalLetType:
+ case SymbolUseKind::GlobalLetValue:
+ case SymbolUseKind::GlobalLetArrayElemType:
+ case SymbolUseKind::GlobalLetArraySizeValue:
+ case SymbolUseKind::GlobalLetVectorElemType:
+ case SymbolUseKind::GlobalLetMatrixElemType:
+ case SymbolUseKind::AliasType:
+ case SymbolUseKind::StructMemberType:
+ case SymbolUseKind::WorkgroupSizeValue:
+ return 0;
+ case SymbolUseKind::CallFunction:
+ case SymbolUseKind::ParameterType:
+ case SymbolUseKind::LocalVarType:
+ case SymbolUseKind::LocalVarArrayElemType:
+ case SymbolUseKind::LocalVarArraySizeValue:
+ case SymbolUseKind::LocalVarVectorElemType:
+ case SymbolUseKind::LocalVarMatrixElemType:
+ case SymbolUseKind::LocalVarValue:
+ case SymbolUseKind::LocalLetType:
+ case SymbolUseKind::LocalLetValue:
+ return 1;
+ case SymbolUseKind::NestedLocalVarType:
+ case SymbolUseKind::NestedLocalVarValue:
+ case SymbolUseKind::NestedLocalLetType:
+ case SymbolUseKind::NestedLocalLetValue:
+ return 2;
+ }
+ return -1;
+}
+
+/// A helper for building programs that exercise symbol declaration tests.
+struct SymbolTestHelper {
+ /// The program builder
+ ProgramBuilder* const builder;
+ /// Parameters to a function that may need to be built
+ std::vector<const ast::Variable*> parameters;
+ /// Shallow function var / let declaration statements
+ std::vector<const ast::Statement*> statements;
+ /// Nested function local var / let declaration statements
+ std::vector<const ast::Statement*> nested_statements;
+ /// Function attributes
+ ast::AttributeList func_attrs;
+
+ /// Constructor
+ /// @param builder the program builder
+ explicit SymbolTestHelper(ProgramBuilder* builder);
+
+ /// Destructor.
+ ~SymbolTestHelper();
+
+ /// Declares a symbol with the given kind
+ /// @param kind the kind of symbol declaration
+ /// @param symbol the symbol to use for the declaration
+ /// @param source the source of the declaration
+ /// @returns the declaration node
+ const ast::Node* Add(SymbolDeclKind kind, Symbol symbol, Source source);
+
+ /// Declares a use of a symbol with the given kind
+ /// @param kind the kind of symbol use
+ /// @param symbol the declaration symbol to use
+ /// @param source the source of the use
+ /// @returns the use node
+ const ast::Node* Add(SymbolUseKind kind, Symbol symbol, Source source);
+
+ /// Builds a function, if any parameter or local declarations have been added
+ void Build();
+};
+
+SymbolTestHelper::SymbolTestHelper(ProgramBuilder* b) : builder(b) {}
+
+SymbolTestHelper::~SymbolTestHelper() {}
+
+const ast::Node* SymbolTestHelper::Add(SymbolDeclKind kind,
+ Symbol symbol,
+ Source source) {
+ auto& b = *builder;
+ switch (kind) {
+ case SymbolDeclKind::GlobalVar:
+ return b.Global(source, symbol, b.ty.i32(), ast::StorageClass::kPrivate);
+ case SymbolDeclKind::GlobalLet:
+ return b.GlobalConst(source, symbol, b.ty.i32(), b.Expr(1));
+ case SymbolDeclKind::Alias:
+ return b.Alias(source, symbol, b.ty.i32());
+ case SymbolDeclKind::Struct:
+ return b.Structure(source, symbol, {b.Member("m", b.ty.i32())});
+ case SymbolDeclKind::Function:
+ return b.Func(source, symbol, {}, b.ty.void_(), {});
+ case SymbolDeclKind::Parameter: {
+ auto* node = b.Param(source, symbol, b.ty.i32());
+ parameters.emplace_back(node);
+ return node;
+ }
+ case SymbolDeclKind::LocalVar: {
+ auto* node = b.Var(source, symbol, b.ty.i32());
+ statements.emplace_back(b.Decl(node));
+ return node;
+ }
+ case SymbolDeclKind::LocalLet: {
+ auto* node = b.Const(source, symbol, b.ty.i32(), b.Expr(1));
+ statements.emplace_back(b.Decl(node));
+ return node;
+ }
+ case SymbolDeclKind::NestedLocalVar: {
+ auto* node = b.Var(source, symbol, b.ty.i32());
+ nested_statements.emplace_back(b.Decl(node));
+ return node;
+ }
+ case SymbolDeclKind::NestedLocalLet: {
+ auto* node = b.Const(source, symbol, b.ty.i32(), b.Expr(1));
+ nested_statements.emplace_back(b.Decl(node));
+ return node;
+ }
+ }
+ return nullptr;
+}
+
+const ast::Node* SymbolTestHelper::Add(SymbolUseKind kind,
+ Symbol symbol,
+ Source source) {
+ auto& b = *builder;
+ switch (kind) {
+ case SymbolUseKind::GlobalVarType: {
+ auto* node = b.ty.type_name(source, symbol);
+ b.Global(b.Sym(), node, ast::StorageClass::kPrivate);
+ return node;
+ }
+ case SymbolUseKind::GlobalVarArrayElemType: {
+ auto* node = b.ty.type_name(source, symbol);
+ b.Global(b.Sym(), b.ty.array(node, 4), ast::StorageClass::kPrivate);
+ return node;
+ }
+ case SymbolUseKind::GlobalVarArraySizeValue: {
+ auto* node = b.Expr(source, symbol);
+ b.Global(b.Sym(), b.ty.array(b.ty.i32(), node),
+ ast::StorageClass::kPrivate);
+ return node;
+ }
+ case SymbolUseKind::GlobalVarVectorElemType: {
+ auto* node = b.ty.type_name(source, symbol);
+ b.Global(b.Sym(), b.ty.vec3(node), ast::StorageClass::kPrivate);
+ return node;
+ }
+ case SymbolUseKind::GlobalVarMatrixElemType: {
+ auto* node = b.ty.type_name(source, symbol);
+ b.Global(b.Sym(), b.ty.mat3x4(node), ast::StorageClass::kPrivate);
+ return node;
+ }
+ case SymbolUseKind::GlobalVarSampledTexElemType: {
+ auto* node = b.ty.type_name(source, symbol);
+ b.Global(b.Sym(), b.ty.sampled_texture(ast::TextureDimension::k2d, node));
+ return node;
+ }
+ case SymbolUseKind::GlobalVarMultisampledTexElemType: {
+ auto* node = b.ty.type_name(source, symbol);
+ b.Global(b.Sym(),
+ b.ty.multisampled_texture(ast::TextureDimension::k2d, node));
+ return node;
+ }
+ case SymbolUseKind::GlobalVarValue: {
+ auto* node = b.Expr(source, symbol);
+ b.Global(b.Sym(), b.ty.i32(), ast::StorageClass::kPrivate, node);
+ return node;
+ }
+ case SymbolUseKind::GlobalLetType: {
+ auto* node = b.ty.type_name(source, symbol);
+ b.GlobalConst(b.Sym(), node, b.Expr(1));
+ return node;
+ }
+ case SymbolUseKind::GlobalLetArrayElemType: {
+ auto* node = b.ty.type_name(source, symbol);
+ b.GlobalConst(b.Sym(), b.ty.array(node, 4), b.Expr(1));
+ return node;
+ }
+ case SymbolUseKind::GlobalLetArraySizeValue: {
+ auto* node = b.Expr(source, symbol);
+ b.GlobalConst(b.Sym(), b.ty.array(b.ty.i32(), node), b.Expr(1));
+ return node;
+ }
+ case SymbolUseKind::GlobalLetVectorElemType: {
+ auto* node = b.ty.type_name(source, symbol);
+ b.GlobalConst(b.Sym(), b.ty.vec3(node), b.Expr(1));
+ return node;
+ }
+ case SymbolUseKind::GlobalLetMatrixElemType: {
+ auto* node = b.ty.type_name(source, symbol);
+ b.GlobalConst(b.Sym(), b.ty.mat3x4(node), b.Expr(1));
+ return node;
+ }
+ case SymbolUseKind::GlobalLetValue: {
+ auto* node = b.Expr(source, symbol);
+ b.GlobalConst(b.Sym(), b.ty.i32(), node);
+ return node;
+ }
+ case SymbolUseKind::AliasType: {
+ auto* node = b.ty.type_name(source, symbol);
+ b.Alias(b.Sym(), node);
+ return node;
+ }
+ case SymbolUseKind::StructMemberType: {
+ auto* node = b.ty.type_name(source, symbol);
+ b.Structure(b.Sym(), {b.Member("m", node)});
+ return node;
+ }
+ case SymbolUseKind::CallFunction: {
+ auto* node = b.Expr(source, symbol);
+ statements.emplace_back(b.CallStmt(b.Call(node)));
+ return node;
+ }
+ case SymbolUseKind::ParameterType: {
+ auto* node = b.ty.type_name(source, symbol);
+ parameters.emplace_back(b.Param(b.Sym(), node));
+ return node;
+ }
+ case SymbolUseKind::LocalVarType: {
+ auto* node = b.ty.type_name(source, symbol);
+ statements.emplace_back(b.Decl(b.Var(b.Sym(), node)));
+ return node;
+ }
+ case SymbolUseKind::LocalVarArrayElemType: {
+ auto* node = b.ty.type_name(source, symbol);
+ statements.emplace_back(
+ b.Decl(b.Var(b.Sym(), b.ty.array(node, 4), b.Expr(1))));
+ return node;
+ }
+ case SymbolUseKind::LocalVarArraySizeValue: {
+ auto* node = b.Expr(source, symbol);
+ statements.emplace_back(
+ b.Decl(b.Var(b.Sym(), b.ty.array(b.ty.i32(), node), b.Expr(1))));
+ return node;
+ }
+ case SymbolUseKind::LocalVarVectorElemType: {
+ auto* node = b.ty.type_name(source, symbol);
+ statements.emplace_back(b.Decl(b.Var(b.Sym(), b.ty.vec3(node))));
+ return node;
+ }
+ case SymbolUseKind::LocalVarMatrixElemType: {
+ auto* node = b.ty.type_name(source, symbol);
+ statements.emplace_back(b.Decl(b.Var(b.Sym(), b.ty.mat3x4(node))));
+ return node;
+ }
+ case SymbolUseKind::LocalVarValue: {
+ auto* node = b.Expr(source, symbol);
+ statements.emplace_back(b.Decl(b.Var(b.Sym(), b.ty.i32(), node)));
+ return node;
+ }
+ case SymbolUseKind::LocalLetType: {
+ auto* node = b.ty.type_name(source, symbol);
+ statements.emplace_back(b.Decl(b.Const(b.Sym(), node, b.Expr(1))));
+ return node;
+ }
+ case SymbolUseKind::LocalLetValue: {
+ auto* node = b.Expr(source, symbol);
+ statements.emplace_back(b.Decl(b.Const(b.Sym(), b.ty.i32(), node)));
+ return node;
+ }
+ case SymbolUseKind::NestedLocalVarType: {
+ auto* node = b.ty.type_name(source, symbol);
+ nested_statements.emplace_back(b.Decl(b.Var(b.Sym(), node)));
+ return node;
+ }
+ case SymbolUseKind::NestedLocalVarValue: {
+ auto* node = b.Expr(source, symbol);
+ nested_statements.emplace_back(b.Decl(b.Var(b.Sym(), b.ty.i32(), node)));
+ return node;
+ }
+ case SymbolUseKind::NestedLocalLetType: {
+ auto* node = b.ty.type_name(source, symbol);
+ nested_statements.emplace_back(b.Decl(b.Const(b.Sym(), node, b.Expr(1))));
+ return node;
+ }
+ case SymbolUseKind::NestedLocalLetValue: {
+ auto* node = b.Expr(source, symbol);
+ nested_statements.emplace_back(
+ b.Decl(b.Const(b.Sym(), b.ty.i32(), node)));
+ return node;
+ }
+ case SymbolUseKind::WorkgroupSizeValue: {
+ auto* node = b.Expr(source, symbol);
+ func_attrs.emplace_back(b.WorkgroupSize(1, node, 2));
+ return node;
+ }
+ }
+ return nullptr;
+}
+
+void SymbolTestHelper::Build() {
+ auto& b = *builder;
+ if (!nested_statements.empty()) {
+ statements.emplace_back(b.Block(nested_statements));
+ nested_statements.clear();
+ }
+ if (!parameters.empty() || !statements.empty() || !func_attrs.empty()) {
+ b.Func("func", parameters, b.ty.void_(), statements, func_attrs);
+ parameters.clear();
+ statements.clear();
+ func_attrs.clear();
+ }
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Used-before-declarated tests
+////////////////////////////////////////////////////////////////////////////////
+namespace used_before_decl_tests {
+
+using ResolverDependencyGraphUsedBeforeDeclTest = ResolverDependencyGraphTest;
+
+TEST_F(ResolverDependencyGraphUsedBeforeDeclTest, FuncCall) {
+ // fn A() { B(); }
+ // fn B() {}
+
+ Func("A", {}, ty.void_(), {CallStmt(Call(Expr(Source{{12, 34}}, "B")))});
+ Func(Source{{56, 78}}, "B", {}, ty.void_(), {Return()});
+
+ Build();
+}
+
+TEST_F(ResolverDependencyGraphUsedBeforeDeclTest, TypeConstructed) {
+ // fn F() {
+ // { _ = T(); }
+ // }
+ // type T = i32;
+
+ Func("F", {}, ty.void_(),
+ {Block(Ignore(Construct(ty.type_name(Source{{12, 34}}, "T"))))});
+ Alias(Source{{56, 78}}, "T", ty.i32());
+
+ Build();
+}
+
+TEST_F(ResolverDependencyGraphUsedBeforeDeclTest, TypeUsedByLocal) {
+ // fn F() {
+ // { var v : T; }
+ // }
+ // type T = i32;
+
+ Func("F", {}, ty.void_(),
+ {Block(Decl(Var("v", ty.type_name(Source{{12, 34}}, "T"))))});
+ Alias(Source{{56, 78}}, "T", ty.i32());
+
+ Build();
+}
+
+TEST_F(ResolverDependencyGraphUsedBeforeDeclTest, TypeUsedByParam) {
+ // fn F(p : T) {}
+ // type T = i32;
+
+ Func("F", {Param("p", ty.type_name(Source{{12, 34}}, "T"))}, ty.void_(), {});
+ Alias(Source{{56, 78}}, "T", ty.i32());
+
+ Build();
+}
+
+TEST_F(ResolverDependencyGraphUsedBeforeDeclTest, TypeUsedAsReturnType) {
+ // fn F() -> T {}
+ // type T = i32;
+
+ Func("F", {}, ty.type_name(Source{{12, 34}}, "T"), {});
+ Alias(Source{{56, 78}}, "T", ty.i32());
+
+ Build();
+}
+
+TEST_F(ResolverDependencyGraphUsedBeforeDeclTest, TypeByStructMember) {
+ // struct S { m : T };
+ // type T = i32;
+
+ Structure("S", {Member("m", ty.type_name(Source{{12, 34}}, "T"))});
+ Alias(Source{{56, 78}}, "T", ty.i32());
+
+ Build();
+}
+
+TEST_F(ResolverDependencyGraphUsedBeforeDeclTest, VarUsed) {
+ // fn F() {
+ // { G = 3.14f; }
+ // }
+ // var G: f32 = 2.1;
+
+ Func("F", ast::VariableList{}, ty.void_(),
+ {Block(Assign(Expr(Source{{12, 34}}, "G"), 3.14f))});
+
+ Global(Source{{56, 78}}, "G", ty.f32(), ast::StorageClass::kPrivate,
+ Expr(2.1f));
+
+ Build();
+}
+
+} // namespace used_before_decl_tests
+
+////////////////////////////////////////////////////////////////////////////////
+// Undeclared symbol tests
+////////////////////////////////////////////////////////////////////////////////
+namespace undeclared_tests {
+
+using ResolverDependencyGraphUndeclaredSymbolTest =
+ ResolverDependencyGraphTestWithParam<SymbolUseKind>;
+
+TEST_P(ResolverDependencyGraphUndeclaredSymbolTest, Test) {
+ const Symbol symbol = Sym("SYMBOL");
+ const auto use_kind = GetParam();
+
+ // Build a use of a non-existent symbol
+ SymbolTestHelper helper(this);
+ helper.Add(use_kind, symbol, Source{{56, 78}});
+ helper.Build();
+
+ Build("56:78 error: unknown " + DiagString(use_kind) + ": 'SYMBOL'");
+}
+
+INSTANTIATE_TEST_SUITE_P(Types,
+ ResolverDependencyGraphUndeclaredSymbolTest,
+ testing::ValuesIn(kTypeUseKinds));
+
+INSTANTIATE_TEST_SUITE_P(Values,
+ ResolverDependencyGraphUndeclaredSymbolTest,
+ testing::ValuesIn(kValueUseKinds));
+
+INSTANTIATE_TEST_SUITE_P(Functions,
+ ResolverDependencyGraphUndeclaredSymbolTest,
+ testing::ValuesIn(kFuncUseKinds));
+
+} // namespace undeclared_tests
+
+////////////////////////////////////////////////////////////////////////////////
+// Self reference by decl
+////////////////////////////////////////////////////////////////////////////////
+namespace undeclared_tests {
+
+using ResolverDependencyGraphDeclSelfUse = ResolverDependencyGraphTest;
+
+TEST_F(ResolverDependencyGraphDeclSelfUse, GlobalVar) {
+ const Symbol symbol = Sym("SYMBOL");
+ Global(symbol, ty.i32(), Mul(Expr(Source{{12, 34}}, symbol), 123));
+ Build(R"(error: cyclic dependency found: 'SYMBOL' -> 'SYMBOL'
+12:34 note: var 'SYMBOL' references var 'SYMBOL' here)");
+}
+
+TEST_F(ResolverDependencyGraphDeclSelfUse, GlobalLet) {
+ const Symbol symbol = Sym("SYMBOL");
+ GlobalConst(symbol, ty.i32(), Mul(Expr(Source{{12, 34}}, symbol), 123));
+ Build(R"(error: cyclic dependency found: 'SYMBOL' -> 'SYMBOL'
+12:34 note: let 'SYMBOL' references let 'SYMBOL' here)");
+}
+
+TEST_F(ResolverDependencyGraphDeclSelfUse, LocalVar) {
+ const Symbol symbol = Sym("SYMBOL");
+ WrapInFunction(
+ Decl(Var(symbol, ty.i32(), Mul(Expr(Source{{12, 34}}, symbol), 123))));
+ Build("12:34 error: unknown identifier: 'SYMBOL'");
+}
+
+TEST_F(ResolverDependencyGraphDeclSelfUse, LocalLet) {
+ const Symbol symbol = Sym("SYMBOL");
+ WrapInFunction(
+ Decl(Const(symbol, ty.i32(), Mul(Expr(Source{{12, 34}}, symbol), 123))));
+ Build("12:34 error: unknown identifier: 'SYMBOL'");
+}
+
+} // namespace undeclared_tests
+
+////////////////////////////////////////////////////////////////////////////////
+// Recursive dependency tests
+////////////////////////////////////////////////////////////////////////////////
+namespace recursive_tests {
+
+using ResolverDependencyGraphCyclicRefTest = ResolverDependencyGraphTest;
+
+TEST_F(ResolverDependencyGraphCyclicRefTest, DirectCall) {
+ // fn main() { main(); }
+
+ Func(Source{{12, 34}}, "main", {}, ty.void_(),
+ {CallStmt(Call(Expr(Source{{56, 78}}, "main")))});
+
+ Build(R"(12:34 error: cyclic dependency found: 'main' -> 'main'
+56:78 note: function 'main' calls function 'main' here)");
+}
+
+TEST_F(ResolverDependencyGraphCyclicRefTest, IndirectCall) {
+ // 1: fn a() { b(); }
+ // 2: fn e() { }
+ // 3: fn d() { e(); b(); }
+ // 4: fn c() { d(); }
+ // 5: fn b() { c(); }
+
+ Func(Source{{1, 1}}, "a", {}, ty.void_(),
+ {CallStmt(Call(Expr(Source{{1, 10}}, "b")))});
+ Func(Source{{2, 1}}, "e", {}, ty.void_(), {});
+ Func(Source{{3, 1}}, "d", {}, ty.void_(),
+ {
+ CallStmt(Call(Expr(Source{{3, 10}}, "e"))),
+ CallStmt(Call(Expr(Source{{3, 10}}, "b"))),
+ });
+ Func(Source{{4, 1}}, "c", {}, ty.void_(),
+ {CallStmt(Call(Expr(Source{{4, 10}}, "d")))});
+ Func(Source{{5, 1}}, "b", {}, ty.void_(),
+ {CallStmt(Call(Expr(Source{{5, 10}}, "c")))});
+
+ Build(R"(5:1 error: cyclic dependency found: 'b' -> 'c' -> 'd' -> 'b'
+5:10 note: function 'b' calls function 'c' here
+4:10 note: function 'c' calls function 'd' here
+3:10 note: function 'd' calls function 'b' here)");
+}
+
+TEST_F(ResolverDependencyGraphCyclicRefTest, Alias_Direct) {
+ // type T = T;
+
+ Alias(Source{{12, 34}}, "T", ty.type_name(Source{{56, 78}}, "T"));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(12:34 error: cyclic dependency found: 'T' -> 'T'
+56:78 note: alias 'T' references alias 'T' here)");
+}
+
+TEST_F(ResolverDependencyGraphCyclicRefTest, Alias_Indirect) {
+ // 1: type Y = Z;
+ // 2: type X = Y;
+ // 3: type Z = X;
+
+ Alias(Source{{1, 1}}, "Y", ty.type_name(Source{{1, 10}}, "Z"));
+ Alias(Source{{2, 1}}, "X", ty.type_name(Source{{2, 10}}, "Y"));
+ Alias(Source{{3, 1}}, "Z", ty.type_name(Source{{3, 10}}, "X"));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(1:1 error: cyclic dependency found: 'Y' -> 'Z' -> 'X' -> 'Y'
+1:10 note: alias 'Y' references alias 'Z' here
+3:10 note: alias 'Z' references alias 'X' here
+2:10 note: alias 'X' references alias 'Y' here)");
+}
+
+TEST_F(ResolverDependencyGraphCyclicRefTest, Struct_Direct) {
+ // struct S {
+ // a: S;
+ // };
+
+ Structure(Source{{12, 34}}, "S",
+ {Member("a", ty.type_name(Source{{56, 78}}, "S"))});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(12:34 error: cyclic dependency found: 'S' -> 'S'
+56:78 note: struct 'S' references struct 'S' here)");
+}
+
+TEST_F(ResolverDependencyGraphCyclicRefTest, Struct_Indirect) {
+ // 1: struct Y { z: Z; };
+ // 2: struct X { y: Y; };
+ // 3: struct Z { x: X; };
+
+ Structure(Source{{1, 1}}, "Y",
+ {Member("z", ty.type_name(Source{{1, 10}}, "Z"))});
+ Structure(Source{{2, 1}}, "X",
+ {Member("y", ty.type_name(Source{{2, 10}}, "Y"))});
+ Structure(Source{{3, 1}}, "Z",
+ {Member("x", ty.type_name(Source{{3, 10}}, "X"))});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(1:1 error: cyclic dependency found: 'Y' -> 'Z' -> 'X' -> 'Y'
+1:10 note: struct 'Y' references struct 'Z' here
+3:10 note: struct 'Z' references struct 'X' here
+2:10 note: struct 'X' references struct 'Y' here)");
+}
+
+TEST_F(ResolverDependencyGraphCyclicRefTest, GlobalVar_Direct) {
+ // var<private> V : i32 = V;
+
+ Global(Source{{12, 34}}, "V", ty.i32(), Expr(Source{{56, 78}}, "V"));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(12:34 error: cyclic dependency found: 'V' -> 'V'
+56:78 note: var 'V' references var 'V' here)");
+}
+
+TEST_F(ResolverDependencyGraphCyclicRefTest, GlobalLet_Direct) {
+ // let V : i32 = V;
+
+ GlobalConst(Source{{12, 34}}, "V", ty.i32(), Expr(Source{{56, 78}}, "V"));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(12:34 error: cyclic dependency found: 'V' -> 'V'
+56:78 note: let 'V' references let 'V' here)");
+}
+
+TEST_F(ResolverDependencyGraphCyclicRefTest, GlobalVar_Indirect) {
+ // 1: var<private> Y : i32 = Z;
+ // 2: var<private> X : i32 = Y;
+ // 3: var<private> Z : i32 = X;
+
+ Global(Source{{1, 1}}, "Y", ty.i32(), Expr(Source{{1, 10}}, "Z"));
+ Global(Source{{2, 1}}, "X", ty.i32(), Expr(Source{{2, 10}}, "Y"));
+ Global(Source{{3, 1}}, "Z", ty.i32(), Expr(Source{{3, 10}}, "X"));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(1:1 error: cyclic dependency found: 'Y' -> 'Z' -> 'X' -> 'Y'
+1:10 note: var 'Y' references var 'Z' here
+3:10 note: var 'Z' references var 'X' here
+2:10 note: var 'X' references var 'Y' here)");
+}
+
+TEST_F(ResolverDependencyGraphCyclicRefTest, GlobalLet_Indirect) {
+ // 1: let Y : i32 = Z;
+ // 2: let X : i32 = Y;
+ // 3: let Z : i32 = X;
+
+ GlobalConst(Source{{1, 1}}, "Y", ty.i32(), Expr(Source{{1, 10}}, "Z"));
+ GlobalConst(Source{{2, 1}}, "X", ty.i32(), Expr(Source{{2, 10}}, "Y"));
+ GlobalConst(Source{{3, 1}}, "Z", ty.i32(), Expr(Source{{3, 10}}, "X"));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(1:1 error: cyclic dependency found: 'Y' -> 'Z' -> 'X' -> 'Y'
+1:10 note: let 'Y' references let 'Z' here
+3:10 note: let 'Z' references let 'X' here
+2:10 note: let 'X' references let 'Y' here)");
+}
+
+TEST_F(ResolverDependencyGraphCyclicRefTest, Mixed_RecursiveDependencies) {
+ // 1: fn F() -> R { return Z; }
+ // 2: type A = S;
+ // 3: struct S { a : A };
+ // 4: var Z = L;
+ // 5: type R = A;
+ // 6: let L : S = Z;
+
+ Func(Source{{1, 1}}, "F", {}, ty.type_name(Source{{1, 5}}, "R"),
+ {Return(Expr(Source{{1, 10}}, "Z"))});
+ Alias(Source{{2, 1}}, "A", ty.type_name(Source{{2, 10}}, "S"));
+ Structure(Source{{3, 1}}, "S",
+ {Member("a", ty.type_name(Source{{3, 10}}, "A"))});
+ Global(Source{{4, 1}}, "Z", nullptr, Expr(Source{{4, 10}}, "L"));
+ Alias(Source{{5, 1}}, "R", ty.type_name(Source{{5, 10}}, "A"));
+ GlobalConst(Source{{6, 1}}, "L", ty.type_name(Source{{5, 5}}, "S"),
+ Expr(Source{{5, 10}}, "Z"));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(2:1 error: cyclic dependency found: 'A' -> 'S' -> 'A'
+2:10 note: alias 'A' references struct 'S' here
+3:10 note: struct 'S' references alias 'A' here
+4:1 error: cyclic dependency found: 'Z' -> 'L' -> 'Z'
+4:10 note: var 'Z' references let 'L' here
+5:10 note: let 'L' references var 'Z' here)");
+}
+
+} // namespace recursive_tests
+
+////////////////////////////////////////////////////////////////////////////////
+// Symbol Redeclaration tests
+////////////////////////////////////////////////////////////////////////////////
+namespace redeclaration_tests {
+
+using ResolverDependencyGraphRedeclarationTest =
+ ResolverDependencyGraphTestWithParam<
+ std::tuple<SymbolDeclKind, SymbolDeclKind>>;
+
+TEST_P(ResolverDependencyGraphRedeclarationTest, Test) {
+ const auto symbol = Sym("SYMBOL");
+
+ auto a_kind = std::get<0>(GetParam());
+ auto b_kind = std::get<1>(GetParam());
+
+ auto a_source = Source{{12, 34}};
+ auto b_source = Source{{56, 78}};
+
+ if (a_kind != SymbolDeclKind::Parameter &&
+ b_kind == SymbolDeclKind::Parameter) {
+ std::swap(a_source, b_source); // Parameters are declared before locals
+ }
+
+ SymbolTestHelper helper(this);
+ helper.Add(a_kind, symbol, a_source);
+ helper.Add(b_kind, symbol, b_source);
+ helper.Build();
+
+ bool error = ScopeDepth(a_kind) == ScopeDepth(b_kind);
+
+ Build(error ? R"(56:78 error: redeclaration of 'SYMBOL'
+12:34 note: 'SYMBOL' previously declared here)"
+ : "");
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ ResolverTest,
+ ResolverDependencyGraphRedeclarationTest,
+ testing::Combine(testing::ValuesIn(kAllSymbolDeclKinds),
+ testing::ValuesIn(kAllSymbolDeclKinds)));
+
+} // namespace redeclaration_tests
+
+////////////////////////////////////////////////////////////////////////////////
+// Ordered global tests
+////////////////////////////////////////////////////////////////////////////////
+namespace ordered_globals {
+
+using ResolverDependencyGraphOrderedGlobalsTest =
+ ResolverDependencyGraphTestWithParam<
+ std::tuple<SymbolDeclKind, SymbolUseKind>>;
+
+TEST_P(ResolverDependencyGraphOrderedGlobalsTest, InOrder) {
+ const Symbol symbol = Sym("SYMBOL");
+ const auto decl_kind = std::get<0>(GetParam());
+ const auto use_kind = std::get<1>(GetParam());
+
+ // Declaration before use
+ SymbolTestHelper helper(this);
+ helper.Add(decl_kind, symbol, Source{{12, 34}});
+ helper.Add(use_kind, symbol, Source{{56, 78}});
+ helper.Build();
+
+ ASSERT_EQ(AST().GlobalDeclarations().size(), 2u);
+
+ auto* decl = AST().GlobalDeclarations()[0];
+ auto* use = AST().GlobalDeclarations()[1];
+ EXPECT_THAT(Build().ordered_globals, ElementsAre(decl, use));
+}
+
+TEST_P(ResolverDependencyGraphOrderedGlobalsTest, OutOfOrder) {
+ const Symbol symbol = Sym("SYMBOL");
+ const auto decl_kind = std::get<0>(GetParam());
+ const auto use_kind = std::get<1>(GetParam());
+
+ // Use before declaration
+ SymbolTestHelper helper(this);
+ helper.Add(use_kind, symbol, Source{{56, 78}});
+ helper.Build(); // If the use is in a function, then ensure this function is
+ // built before the symbol declaration
+ helper.Add(decl_kind, symbol, Source{{12, 34}});
+ helper.Build();
+
+ ASSERT_EQ(AST().GlobalDeclarations().size(), 2u);
+
+ auto* use = AST().GlobalDeclarations()[0];
+ auto* decl = AST().GlobalDeclarations()[1];
+ EXPECT_THAT(Build().ordered_globals, ElementsAre(decl, use));
+}
+
+INSTANTIATE_TEST_SUITE_P(Types,
+ ResolverDependencyGraphOrderedGlobalsTest,
+ testing::Combine(testing::ValuesIn(kTypeDeclKinds),
+ testing::ValuesIn(kTypeUseKinds)));
+
+INSTANTIATE_TEST_SUITE_P(
+ Values,
+ ResolverDependencyGraphOrderedGlobalsTest,
+ testing::Combine(testing::ValuesIn(kGlobalValueDeclKinds),
+ testing::ValuesIn(kValueUseKinds)));
+
+INSTANTIATE_TEST_SUITE_P(Functions,
+ ResolverDependencyGraphOrderedGlobalsTest,
+ testing::Combine(testing::ValuesIn(kFuncDeclKinds),
+ testing::ValuesIn(kFuncUseKinds)));
+} // namespace ordered_globals
+
+////////////////////////////////////////////////////////////////////////////////
+// Resolved symbols tests
+////////////////////////////////////////////////////////////////////////////////
+namespace resolved_symbols {
+
+using ResolverDependencyGraphResolvedSymbolTest =
+ ResolverDependencyGraphTestWithParam<
+ std::tuple<SymbolDeclKind, SymbolUseKind>>;
+
+TEST_P(ResolverDependencyGraphResolvedSymbolTest, Test) {
+ const Symbol symbol = Sym("SYMBOL");
+ const auto decl_kind = std::get<0>(GetParam());
+ const auto use_kind = std::get<1>(GetParam());
+
+ // Build a symbol declaration and a use of that symbol
+ SymbolTestHelper helper(this);
+ auto* decl = helper.Add(decl_kind, symbol, Source{{12, 34}});
+ auto* use = helper.Add(use_kind, symbol, Source{{56, 78}});
+ helper.Build();
+
+ // If the declaration is visible to the use, then we expect the analysis to
+ // succeed.
+ bool expect_pass = ScopeDepth(decl_kind) <= ScopeDepth(use_kind);
+ auto graph =
+ Build(expect_pass ? "" : "56:78 error: unknown identifier: 'SYMBOL'");
+
+ if (expect_pass) {
+ // Check that the use resolves to the declaration
+ auto* resolved_symbol = graph.resolved_symbols[use];
+ EXPECT_EQ(resolved_symbol, decl)
+ << "resolved: "
+ << (resolved_symbol ? resolved_symbol->TypeInfo().name : "<null>")
+ << "\n"
+ << "decl: " << decl->TypeInfo().name;
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(Types,
+ ResolverDependencyGraphResolvedSymbolTest,
+ testing::Combine(testing::ValuesIn(kTypeDeclKinds),
+ testing::ValuesIn(kTypeUseKinds)));
+
+INSTANTIATE_TEST_SUITE_P(Values,
+ ResolverDependencyGraphResolvedSymbolTest,
+ testing::Combine(testing::ValuesIn(kValueDeclKinds),
+ testing::ValuesIn(kValueUseKinds)));
+
+INSTANTIATE_TEST_SUITE_P(Functions,
+ ResolverDependencyGraphResolvedSymbolTest,
+ testing::Combine(testing::ValuesIn(kFuncDeclKinds),
+ testing::ValuesIn(kFuncUseKinds)));
+
+} // namespace resolved_symbols
+
+////////////////////////////////////////////////////////////////////////////////
+// Shadowing tests
+////////////////////////////////////////////////////////////////////////////////
+namespace shadowing {
+
+using ResolverDependencyShadowTest = ResolverDependencyGraphTestWithParam<
+ std::tuple<SymbolDeclKind, SymbolDeclKind>>;
+
+TEST_P(ResolverDependencyShadowTest, Test) {
+ const Symbol symbol = Sym("SYMBOL");
+ const auto outer_kind = std::get<0>(GetParam());
+ const auto inner_kind = std::get<1>(GetParam());
+
+ // Build a symbol declaration and a use of that symbol
+ SymbolTestHelper helper(this);
+ auto* outer = helper.Add(outer_kind, symbol, Source{{12, 34}});
+ helper.Add(inner_kind, symbol, Source{{56, 78}});
+ auto* inner_var = helper.nested_statements.size()
+ ? helper.nested_statements[0]
+ ->As<ast::VariableDeclStatement>()
+ ->variable
+ : helper.statements.size()
+ ? helper.statements[0]
+ ->As<ast::VariableDeclStatement>()
+ ->variable
+ : helper.parameters[0];
+ helper.Build();
+
+ EXPECT_EQ(Build().shadows[inner_var], outer);
+}
+
+INSTANTIATE_TEST_SUITE_P(LocalShadowGlobal,
+ ResolverDependencyShadowTest,
+ testing::Combine(testing::ValuesIn(kGlobalDeclKinds),
+ testing::ValuesIn(kLocalDeclKinds)));
+
+INSTANTIATE_TEST_SUITE_P(
+ NestedLocalShadowLocal,
+ ResolverDependencyShadowTest,
+ testing::Combine(testing::Values(SymbolDeclKind::Parameter,
+ SymbolDeclKind::LocalVar,
+ SymbolDeclKind::LocalLet),
+ testing::Values(SymbolDeclKind::NestedLocalVar,
+ SymbolDeclKind::NestedLocalLet)));
+
+} // namespace shadowing
+
+////////////////////////////////////////////////////////////////////////////////
+// AST traversal tests
+////////////////////////////////////////////////////////////////////////////////
+namespace ast_traversal {
+
+using ResolverDependencyGraphTraversalTest = ResolverDependencyGraphTest;
+
+TEST_F(ResolverDependencyGraphTraversalTest, SymbolsReached) {
+ const auto value_sym = Sym("VALUE");
+ const auto type_sym = Sym("TYPE");
+ const auto func_sym = Sym("FUNC");
+
+ const auto* value_decl =
+ Global(value_sym, ty.i32(), ast::StorageClass::kPrivate);
+ const auto* type_decl = Alias(type_sym, ty.i32());
+ const auto* func_decl = Func(func_sym, {}, ty.void_(), {});
+
+ struct SymbolUse {
+ const ast::Node* decl = nullptr;
+ const ast::Node* use = nullptr;
+ std::string where = nullptr;
+ };
+
+ std::vector<SymbolUse> symbol_uses;
+
+ auto add_use = [&](const ast::Node* decl, auto* use, int line,
+ const char* kind) {
+ symbol_uses.emplace_back(SymbolUse{
+ decl, use,
+ std::string(__FILE__) + ":" + std::to_string(line) + ": " + kind});
+ return use;
+ };
+#define V add_use(value_decl, Expr(value_sym), __LINE__, "V()")
+#define T add_use(type_decl, ty.type_name(type_sym), __LINE__, "T()")
+#define F add_use(func_decl, Expr(func_sym), __LINE__, "F()")
+
+ Alias(Sym(), T);
+ Structure(Sym(), {Member(Sym(), T)});
+ Global(Sym(), T, V);
+ GlobalConst(Sym(), T, V);
+ Func(Sym(), //
+ {Param(Sym(), T)}, //
+ T, // Return type
+ {
+ Decl(Var(Sym(), T, V)), //
+ Decl(Const(Sym(), T, V)), //
+ CallStmt(Call(F, V)), //
+ Block( //
+ Assign(V, V)), //
+ If(V, //
+ Block(Assign(V, V)), //
+ Else(V, //
+ Block(Assign(V, V)))), //
+ Ignore(Bitcast(T, V)), //
+ For(Decl(Var(Sym(), T, V)), //
+ Equal(V, V), //
+ Assign(V, V), //
+ Block( //
+ Assign(V, V))), //
+ Loop(Block(Assign(V, V)), //
+ Block(Assign(V, V))), //
+ Switch(V, //
+ Case(Expr(1), //
+ Block(Assign(V, V))), //
+ Case(Expr(2), //
+ Block(Fallthrough())), //
+ DefaultCase(Block(Assign(V, V)))), //
+ Return(V), //
+ Break(), //
+ Discard(), //
+ }); //
+ // Exercise type traversal
+ Global(Sym(), ty.atomic(T));
+ Global(Sym(), ty.bool_());
+ Global(Sym(), ty.i32());
+ Global(Sym(), ty.u32());
+ Global(Sym(), ty.f32());
+ Global(Sym(), ty.array(T, V, 4));
+ Global(Sym(), ty.vec3(T));
+ Global(Sym(), ty.mat3x2(T));
+ Global(Sym(), ty.pointer(T, ast::StorageClass::kPrivate));
+ Global(Sym(), ty.sampled_texture(ast::TextureDimension::k2d, T));
+ Global(Sym(), ty.depth_texture(ast::TextureDimension::k2d));
+ Global(Sym(), ty.depth_multisampled_texture(ast::TextureDimension::k2d));
+ Global(Sym(), ty.external_texture());
+ Global(Sym(), ty.multisampled_texture(ast::TextureDimension::k2d, T));
+ Global(Sym(), ty.storage_texture(ast::TextureDimension::k2d,
+ ast::TexelFormat::kR32Float,
+ ast::Access::kRead)); //
+ Global(Sym(), ty.sampler(ast::SamplerKind::kSampler));
+ Func(Sym(), {}, ty.void_(), {});
+#undef V
+#undef T
+#undef F
+
+ auto graph = Build();
+ for (auto use : symbol_uses) {
+ auto* resolved_symbol = graph.resolved_symbols[use.use];
+ EXPECT_EQ(resolved_symbol, use.decl) << use.where;
+ }
+}
+
+TEST_F(ResolverDependencyGraphTraversalTest, InferredType) {
+ // Check that the nullptr of the var / let type doesn't make things explode
+ Global("a", nullptr, Expr(1));
+ GlobalConst("b", nullptr, Expr(1));
+ WrapInFunction(Var("c", nullptr, Expr(1)), //
+ Const("d", nullptr, Expr(1)));
+ Build();
+}
+
+// Reproduces an unbalanced stack push / pop bug in
+// DependencyAnalysis::SortGlobals(), found by clusterfuzz.
+// See: crbug.com/chromium/1273451
+TEST_F(ResolverDependencyGraphTraversalTest, chromium_1273451) {
+ Structure("A", {Member("a", ty.i32())});
+ Structure("B", {Member("b", ty.i32())});
+ Func("f", {Param("a", ty.type_name("A"))}, ty.type_name("B"),
+ {
+ Return(Construct(ty.type_name("B"))),
+ });
+ Build();
+}
+
+} // namespace ast_traversal
+
+} // namespace
+} // namespace resolver
+} // namespace tint
diff --git a/src/tint/resolver/entry_point_validation_test.cc b/src/tint/resolver/entry_point_validation_test.cc
new file mode 100644
index 0000000..1f61452
--- /dev/null
+++ b/src/tint/resolver/entry_point_validation_test.cc
@@ -0,0 +1,804 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ast/builtin_attribute.h"
+#include "src/tint/ast/location_attribute.h"
+#include "src/tint/ast/return_statement.h"
+#include "src/tint/ast/stage_attribute.h"
+#include "src/tint/ast/struct_block_attribute.h"
+#include "src/tint/resolver/resolver.h"
+#include "src/tint/resolver/resolver_test_helper.h"
+
+#include "gmock/gmock.h"
+
+namespace tint {
+namespace resolver {
+namespace {
+
+// Helpers and typedefs
+template <typename T>
+using DataType = builder::DataType<T>;
+template <typename T>
+using vec2 = builder::vec2<T>;
+template <typename T>
+using vec3 = builder::vec3<T>;
+template <typename T>
+using vec4 = builder::vec4<T>;
+template <typename T>
+using mat2x2 = builder::mat2x2<T>;
+template <typename T>
+using mat3x3 = builder::mat3x3<T>;
+template <typename T>
+using mat4x4 = builder::mat4x4<T>;
+template <typename T>
+using alias = builder::alias<T>;
+using f32 = builder::f32;
+using i32 = builder::i32;
+using u32 = builder::u32;
+
+class ResolverEntryPointValidationTest : public TestHelper,
+ public testing::Test {};
+
+TEST_F(ResolverEntryPointValidationTest, ReturnTypeAttribute_Location) {
+ // @stage(fragment)
+ // fn main() -> @location(0) f32 { return 1.0; }
+ Func(Source{{12, 34}}, "main", {}, ty.f32(), {Return(1.0f)},
+ {Stage(ast::PipelineStage::kFragment)}, {Location(0)});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverEntryPointValidationTest, ReturnTypeAttribute_Builtin) {
+ // @stage(vertex)
+ // fn main() -> @builtin(position) vec4<f32> { return vec4<f32>(); }
+ Func(Source{{12, 34}}, "main", {}, ty.vec4<f32>(),
+ {Return(Construct(ty.vec4<f32>()))},
+ {Stage(ast::PipelineStage::kVertex)},
+ {Builtin(ast::Builtin::kPosition)});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverEntryPointValidationTest, ReturnTypeAttribute_Missing) {
+ // @stage(vertex)
+ // fn main() -> f32 {
+ // return 1.0;
+ // }
+ Func(Source{{12, 34}}, "main", {}, ty.vec4<f32>(),
+ {Return(Construct(ty.vec4<f32>()))},
+ {Stage(ast::PipelineStage::kVertex)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: missing entry point IO attribute on return type");
+}
+
+TEST_F(ResolverEntryPointValidationTest, ReturnTypeAttribute_Multiple) {
+ // @stage(vertex)
+ // fn main() -> @location(0) @builtin(position) vec4<f32> {
+ // return vec4<f32>();
+ // }
+ Func(Source{{12, 34}}, "main", {}, ty.vec4<f32>(),
+ {Return(Construct(ty.vec4<f32>()))},
+ {Stage(ast::PipelineStage::kVertex)},
+ {Location(Source{{13, 43}}, 0),
+ Builtin(Source{{14, 52}}, ast::Builtin::kPosition)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), R"(14:52 error: multiple entry point IO attributes
+13:43 note: previously consumed location(0))");
+}
+
+TEST_F(ResolverEntryPointValidationTest, ReturnType_Struct_Valid) {
+ // struct Output {
+ // @location(0) a : f32;
+ // @builtin(frag_depth) b : f32;
+ // };
+ // @stage(fragment)
+ // fn main() -> Output {
+ // return Output();
+ // }
+ auto* output = Structure(
+ "Output", {Member("a", ty.f32(), {Location(0)}),
+ Member("b", ty.f32(), {Builtin(ast::Builtin::kFragDepth)})});
+ Func(Source{{12, 34}}, "main", {}, ty.Of(output),
+ {Return(Construct(ty.Of(output)))},
+ {Stage(ast::PipelineStage::kFragment)});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverEntryPointValidationTest,
+ ReturnType_Struct_MemberMultipleAttributes) {
+ // struct Output {
+ // @location(0) @builtin(frag_depth) a : f32;
+ // };
+ // @stage(fragment)
+ // fn main() -> Output {
+ // return Output();
+ // }
+ auto* output = Structure(
+ "Output",
+ {Member("a", ty.f32(),
+ {Location(Source{{13, 43}}, 0),
+ Builtin(Source{{14, 52}}, ast::Builtin::kFragDepth)})});
+ Func(Source{{12, 34}}, "main", {}, ty.Of(output),
+ {Return(Construct(ty.Of(output)))},
+ {Stage(ast::PipelineStage::kFragment)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), R"(14:52 error: multiple entry point IO attributes
+13:43 note: previously consumed location(0)
+12:34 note: while analysing entry point 'main')");
+}
+
+TEST_F(ResolverEntryPointValidationTest,
+ ReturnType_Struct_MemberMissingAttribute) {
+ // struct Output {
+ // @location(0) a : f32;
+ // b : f32;
+ // };
+ // @stage(fragment)
+ // fn main() -> Output {
+ // return Output();
+ // }
+ auto* output = Structure(
+ "Output", {Member(Source{{13, 43}}, "a", ty.f32(), {Location(0)}),
+ Member(Source{{14, 52}}, "b", ty.f32(), {})});
+ Func(Source{{12, 34}}, "main", {}, ty.Of(output),
+ {Return(Construct(ty.Of(output)))},
+ {Stage(ast::PipelineStage::kFragment)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(14:52 error: missing entry point IO attribute
+12:34 note: while analysing entry point 'main')");
+}
+
+TEST_F(ResolverEntryPointValidationTest, ReturnType_Struct_DuplicateBuiltins) {
+ // struct Output {
+ // @builtin(frag_depth) a : f32;
+ // @builtin(frag_depth) b : f32;
+ // };
+ // @stage(fragment)
+ // fn main() -> Output {
+ // return Output();
+ // }
+ auto* output = Structure(
+ "Output", {Member("a", ty.f32(), {Builtin(ast::Builtin::kFragDepth)}),
+ Member("b", ty.f32(), {Builtin(ast::Builtin::kFragDepth)})});
+ Func(Source{{12, 34}}, "main", {}, ty.Of(output),
+ {Return(Construct(ty.Of(output)))},
+ {Stage(ast::PipelineStage::kFragment)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: builtin(frag_depth) attribute appears multiple times as pipeline output
+12:34 note: while analysing entry point 'main')");
+}
+
+TEST_F(ResolverEntryPointValidationTest, ParameterAttribute_Location) {
+ // @stage(fragment)
+ // fn main(@location(0) param : f32) {}
+ auto* param = Param("param", ty.f32(), {Location(0)});
+ Func(Source{{12, 34}}, "main", {param}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kFragment)});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverEntryPointValidationTest, ParameterAttribute_Missing) {
+ // @stage(fragment)
+ // fn main(param : f32) {}
+ auto* param = Param(Source{{13, 43}}, "param", ty.vec4<f32>());
+ Func(Source{{12, 34}}, "main", {param}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kFragment)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "13:43 error: missing entry point IO attribute on parameter");
+}
+
+TEST_F(ResolverEntryPointValidationTest, ParameterAttribute_Multiple) {
+ // @stage(fragment)
+ // fn main(@location(0) @builtin(sample_index) param : u32) {}
+ auto* param = Param("param", ty.u32(),
+ {Location(Source{{13, 43}}, 0),
+ Builtin(Source{{14, 52}}, ast::Builtin::kSampleIndex)});
+ Func(Source{{12, 34}}, "main", {param}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kFragment)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), R"(14:52 error: multiple entry point IO attributes
+13:43 note: previously consumed location(0))");
+}
+
+TEST_F(ResolverEntryPointValidationTest, Parameter_Struct_Valid) {
+ // struct Input {
+ // @location(0) a : f32;
+ // @builtin(sample_index) b : u32;
+ // };
+ // @stage(fragment)
+ // fn main(param : Input) {}
+ auto* input = Structure(
+ "Input", {Member("a", ty.f32(), {Location(0)}),
+ Member("b", ty.u32(), {Builtin(ast::Builtin::kSampleIndex)})});
+ auto* param = Param("param", ty.Of(input));
+ Func(Source{{12, 34}}, "main", {param}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kFragment)});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverEntryPointValidationTest,
+ Parameter_Struct_MemberMultipleAttributes) {
+ // struct Input {
+ // @location(0) @builtin(sample_index) a : u32;
+ // };
+ // @stage(fragment)
+ // fn main(param : Input) {}
+ auto* input = Structure(
+ "Input",
+ {Member("a", ty.u32(),
+ {Location(Source{{13, 43}}, 0),
+ Builtin(Source{{14, 52}}, ast::Builtin::kSampleIndex)})});
+ auto* param = Param("param", ty.Of(input));
+ Func(Source{{12, 34}}, "main", {param}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kFragment)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), R"(14:52 error: multiple entry point IO attributes
+13:43 note: previously consumed location(0)
+12:34 note: while analysing entry point 'main')");
+}
+
+TEST_F(ResolverEntryPointValidationTest,
+ Parameter_Struct_MemberMissingAttribute) {
+ // struct Input {
+ // @location(0) a : f32;
+ // b : f32;
+ // };
+ // @stage(fragment)
+ // fn main(param : Input) {}
+ auto* input = Structure(
+ "Input", {Member(Source{{13, 43}}, "a", ty.f32(), {Location(0)}),
+ Member(Source{{14, 52}}, "b", ty.f32(), {})});
+ auto* param = Param("param", ty.Of(input));
+ Func(Source{{12, 34}}, "main", {param}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kFragment)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), R"(14:52 error: missing entry point IO attribute
+12:34 note: while analysing entry point 'main')");
+}
+
+TEST_F(ResolverEntryPointValidationTest, Parameter_DuplicateBuiltins) {
+ // @stage(fragment)
+ // fn main(@builtin(sample_index) param_a : u32,
+ // @builtin(sample_index) param_b : u32) {}
+ auto* param_a =
+ Param("param_a", ty.u32(), {Builtin(ast::Builtin::kSampleIndex)});
+ auto* param_b =
+ Param("param_b", ty.u32(), {Builtin(ast::Builtin::kSampleIndex)});
+ Func(Source{{12, 34}}, "main", {param_a, param_b}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kFragment)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: builtin(sample_index) attribute appears multiple times as "
+ "pipeline input");
+}
+
+TEST_F(ResolverEntryPointValidationTest, Parameter_Struct_DuplicateBuiltins) {
+ // struct InputA {
+ // @builtin(sample_index) a : u32;
+ // };
+ // struct InputB {
+ // @builtin(sample_index) a : u32;
+ // };
+ // @stage(fragment)
+ // fn main(param_a : InputA, param_b : InputB) {}
+ auto* input_a = Structure(
+ "InputA", {Member("a", ty.u32(), {Builtin(ast::Builtin::kSampleIndex)})});
+ auto* input_b = Structure(
+ "InputB", {Member("a", ty.u32(), {Builtin(ast::Builtin::kSampleIndex)})});
+ auto* param_a = Param("param_a", ty.Of(input_a));
+ auto* param_b = Param("param_b", ty.Of(input_b));
+ Func(Source{{12, 34}}, "main", {param_a, param_b}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kFragment)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: builtin(sample_index) attribute appears multiple times as pipeline input
+12:34 note: while analysing entry point 'main')");
+}
+
+TEST_F(ResolverEntryPointValidationTest, VertexShaderMustReturnPosition) {
+ // @stage(vertex)
+ // fn main() {}
+ Func(Source{{12, 34}}, "main", {}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kVertex)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: a vertex shader must include the 'position' builtin "
+ "in its return type");
+}
+
+namespace TypeValidationTests {
+struct Params {
+ builder::ast_type_func_ptr create_ast_type;
+ bool is_valid;
+};
+
+template <typename T>
+constexpr Params ParamsFor(bool is_valid) {
+ return Params{DataType<T>::AST, is_valid};
+}
+
+using TypeValidationTest = resolver::ResolverTestWithParam<Params>;
+
+static constexpr Params cases[] = {
+ ParamsFor<f32>(true), //
+ ParamsFor<i32>(true), //
+ ParamsFor<u32>(true), //
+ ParamsFor<bool>(false), //
+ ParamsFor<vec2<f32>>(true), //
+ ParamsFor<vec3<f32>>(true), //
+ ParamsFor<vec4<f32>>(true), //
+ ParamsFor<mat2x2<f32>>(false), //
+ ParamsFor<mat3x3<f32>>(false), //
+ ParamsFor<mat4x4<f32>>(false), //
+ ParamsFor<alias<f32>>(true), //
+ ParamsFor<alias<i32>>(true), //
+ ParamsFor<alias<u32>>(true), //
+ ParamsFor<alias<bool>>(false), //
+};
+
+TEST_P(TypeValidationTest, BareInputs) {
+ // @stage(fragment)
+ // fn main(@location(0) @interpolate(flat) a : *) {}
+ auto params = GetParam();
+ auto* a = Param("a", params.create_ast_type(*this), {Location(0), Flat()});
+ Func(Source{{12, 34}}, "main", {a}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kFragment)});
+
+ if (params.is_valid) {
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ } else {
+ EXPECT_FALSE(r()->Resolve());
+ }
+}
+
+TEST_P(TypeValidationTest, StructInputs) {
+ // struct Input {
+ // @location(0) @interpolate(flat) a : *;
+ // };
+ // @stage(fragment)
+ // fn main(a : Input) {}
+ auto params = GetParam();
+ auto* input = Structure("Input", {Member("a", params.create_ast_type(*this),
+ {Location(0), Flat()})});
+ auto* a = Param("a", ty.Of(input), {});
+ Func(Source{{12, 34}}, "main", {a}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kFragment)});
+
+ if (params.is_valid) {
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ } else {
+ EXPECT_FALSE(r()->Resolve());
+ }
+}
+
+TEST_P(TypeValidationTest, BareOutputs) {
+ // @stage(fragment)
+ // fn main() -> @location(0) * {
+ // return *();
+ // }
+ auto params = GetParam();
+ Func(Source{{12, 34}}, "main", {}, params.create_ast_type(*this),
+ {Return(Construct(params.create_ast_type(*this)))},
+ {Stage(ast::PipelineStage::kFragment)}, {Location(0)});
+
+ if (params.is_valid) {
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ } else {
+ EXPECT_FALSE(r()->Resolve());
+ }
+}
+
+TEST_P(TypeValidationTest, StructOutputs) {
+ // struct Output {
+ // @location(0) a : *;
+ // };
+ // @stage(fragment)
+ // fn main() -> Output {
+ // return Output();
+ // }
+ auto params = GetParam();
+ auto* output = Structure(
+ "Output", {Member("a", params.create_ast_type(*this), {Location(0)})});
+ Func(Source{{12, 34}}, "main", {}, ty.Of(output),
+ {Return(Construct(ty.Of(output)))},
+ {Stage(ast::PipelineStage::kFragment)});
+
+ if (params.is_valid) {
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ } else {
+ EXPECT_FALSE(r()->Resolve());
+ }
+}
+INSTANTIATE_TEST_SUITE_P(ResolverEntryPointValidationTest,
+ TypeValidationTest,
+ testing::ValuesIn(cases));
+
+} // namespace TypeValidationTests
+
+namespace LocationAttributeTests {
+namespace {
+using LocationAttributeTests = ResolverTest;
+
+TEST_F(LocationAttributeTests, Pass) {
+ // @stage(fragment)
+ // fn frag_main(@location(0) @interpolate(flat) a: i32) {}
+
+ auto* p = Param(Source{{12, 34}}, "a", ty.i32(), {Location(0), Flat()});
+ Func("frag_main", {p}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kFragment)});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(LocationAttributeTests, BadType_Input_bool) {
+ // @stage(fragment)
+ // fn frag_main(@location(0) a: bool) {}
+
+ auto* p =
+ Param(Source{{12, 34}}, "a", ty.bool_(), {Location(Source{{34, 56}}, 0)});
+ Func("frag_main", {p}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kFragment)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: cannot apply 'location' attribute to declaration of "
+ "type 'bool'\n"
+ "34:56 note: 'location' attribute must only be applied to "
+ "declarations of numeric scalar or numeric vector type");
+}
+
+TEST_F(LocationAttributeTests, BadType_Output_Array) {
+ // @stage(fragment)
+ // fn frag_main()->@location(0) array<f32, 2> { return array<f32, 2>(); }
+
+ Func(Source{{12, 34}}, "frag_main", {}, ty.array<f32, 2>(),
+ {Return(Construct(ty.array<f32, 2>()))},
+ {Stage(ast::PipelineStage::kFragment)}, {Location(Source{{34, 56}}, 0)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: cannot apply 'location' attribute to declaration of "
+ "type 'array<f32, 2>'\n"
+ "34:56 note: 'location' attribute must only be applied to "
+ "declarations of numeric scalar or numeric vector type");
+}
+
+TEST_F(LocationAttributeTests, BadType_Input_Struct) {
+ // struct Input {
+ // a : f32;
+ // };
+ // @stage(fragment)
+ // fn main(@location(0) param : Input) {}
+ auto* input = Structure("Input", {Member("a", ty.f32())});
+ auto* param = Param(Source{{12, 34}}, "param", ty.Of(input),
+ {Location(Source{{13, 43}}, 0)});
+ Func(Source{{12, 34}}, "main", {param}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kFragment)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: cannot apply 'location' attribute to declaration of "
+ "type 'Input'\n"
+ "13:43 note: 'location' attribute must only be applied to "
+ "declarations of numeric scalar or numeric vector type");
+}
+
+TEST_F(LocationAttributeTests, BadType_Input_Struct_NestedStruct) {
+ // struct Inner {
+ // @location(0) b : f32;
+ // };
+ // struct Input {
+ // a : Inner;
+ // };
+ // @stage(fragment)
+ // fn main(param : Input) {}
+ auto* inner = Structure(
+ "Inner", {Member(Source{{13, 43}}, "a", ty.f32(), {Location(0)})});
+ auto* input =
+ Structure("Input", {Member(Source{{14, 52}}, "a", ty.Of(inner))});
+ auto* param = Param("param", ty.Of(input));
+ Func(Source{{12, 34}}, "main", {param}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kFragment)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "14:52 error: nested structures cannot be used for entry point IO\n"
+ "12:34 note: while analysing entry point 'main'");
+}
+
+TEST_F(LocationAttributeTests, BadType_Input_Struct_RuntimeArray) {
+ // [[block]]
+ // struct Input {
+ // @location(0) a : array<f32>;
+ // };
+ // @stage(fragment)
+ // fn main(param : Input) {}
+ auto* input = Structure(
+ "Input",
+ {Member(Source{{13, 43}}, "a", ty.array<float>(), {Location(0)})},
+ {create<ast::StructBlockAttribute>()});
+ auto* param = Param("param", ty.Of(input));
+ Func(Source{{12, 34}}, "main", {param}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kFragment)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "13:43 error: cannot apply 'location' attribute to declaration of "
+ "type 'array<f32>'\n"
+ "note: 'location' attribute must only be applied to declarations "
+ "of numeric scalar or numeric vector type");
+}
+
+TEST_F(LocationAttributeTests, BadMemberType_Input) {
+ // [[block]]
+ // struct S { @location(0) m: array<i32>; };
+ // @stage(fragment)
+ // fn frag_main( a: S) {}
+
+ auto* m = Member(Source{{34, 56}}, "m", ty.array<i32>(),
+ ast::AttributeList{Location(Source{{12, 34}}, 0u)});
+ auto* s = Structure("S", {m}, ast::AttributeList{StructBlock()});
+ auto* p = Param("a", ty.Of(s));
+
+ Func("frag_main", {p}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kFragment)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "34:56 error: cannot apply 'location' attribute to declaration of "
+ "type 'array<i32>'\n"
+ "12:34 note: 'location' attribute must only be applied to "
+ "declarations of numeric scalar or numeric vector type");
+}
+
+TEST_F(LocationAttributeTests, BadMemberType_Output) {
+ // struct S { @location(0) m: atomic<i32>; };
+ // @stage(fragment)
+ // fn frag_main() -> S {}
+ auto* m = Member(Source{{34, 56}}, "m", ty.atomic<i32>(),
+ ast::AttributeList{Location(Source{{12, 34}}, 0u)});
+ auto* s = Structure("S", {m});
+
+ Func("frag_main", {}, ty.Of(s), {Return(Construct(ty.Of(s)))},
+ {Stage(ast::PipelineStage::kFragment)}, {});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "34:56 error: cannot apply 'location' attribute to declaration of "
+ "type 'atomic<i32>'\n"
+ "12:34 note: 'location' attribute must only be applied to "
+ "declarations of numeric scalar or numeric vector type");
+}
+
+TEST_F(LocationAttributeTests, BadMemberType_Unused) {
+ // struct S { @location(0) m: mat3x2<f32>; };
+
+ auto* m = Member(Source{{34, 56}}, "m", ty.mat3x2<f32>(),
+ ast::AttributeList{Location(Source{{12, 34}}, 0u)});
+ Structure("S", {m});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "34:56 error: cannot apply 'location' attribute to declaration of "
+ "type 'mat3x2<f32>'\n"
+ "12:34 note: 'location' attribute must only be applied to "
+ "declarations of numeric scalar or numeric vector type");
+}
+
+TEST_F(LocationAttributeTests, ReturnType_Struct_Valid) {
+ // struct Output {
+ // @location(0) a : f32;
+ // @builtin(frag_depth) b : f32;
+ // };
+ // @stage(fragment)
+ // fn main() -> Output {
+ // return Output();
+ // }
+ auto* output = Structure(
+ "Output", {Member("a", ty.f32(), {Location(0)}),
+ Member("b", ty.f32(), {Builtin(ast::Builtin::kFragDepth)})});
+ Func(Source{{12, 34}}, "main", {}, ty.Of(output),
+ {Return(Construct(ty.Of(output)))},
+ {Stage(ast::PipelineStage::kFragment)});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(LocationAttributeTests, ReturnType_Struct) {
+ // struct Output {
+ // a : f32;
+ // };
+ // @stage(vertex)
+ // fn main() -> @location(0) Output {
+ // return Output();
+ // }
+ auto* output = Structure("Output", {Member("a", ty.f32())});
+ Func(Source{{12, 34}}, "main", {}, ty.Of(output),
+ {Return(Construct(ty.Of(output)))}, {Stage(ast::PipelineStage::kVertex)},
+ {Location(Source{{13, 43}}, 0)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: cannot apply 'location' attribute to declaration of "
+ "type 'Output'\n"
+ "13:43 note: 'location' attribute must only be applied to "
+ "declarations of numeric scalar or numeric vector type");
+}
+
+TEST_F(LocationAttributeTests, ReturnType_Struct_NestedStruct) {
+ // struct Inner {
+ // @location(0) b : f32;
+ // };
+ // struct Output {
+ // a : Inner;
+ // };
+ // @stage(fragment)
+ // fn main() -> Output { return Output(); }
+ auto* inner = Structure(
+ "Inner", {Member(Source{{13, 43}}, "a", ty.f32(), {Location(0)})});
+ auto* output =
+ Structure("Output", {Member(Source{{14, 52}}, "a", ty.Of(inner))});
+ Func(Source{{12, 34}}, "main", {}, ty.Of(output),
+ {Return(Construct(ty.Of(output)))},
+ {Stage(ast::PipelineStage::kFragment)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "14:52 error: nested structures cannot be used for entry point IO\n"
+ "12:34 note: while analysing entry point 'main'");
+}
+
+TEST_F(LocationAttributeTests, ReturnType_Struct_RuntimeArray) {
+ // [[block]]
+ // struct Output {
+ // @location(0) a : array<f32>;
+ // };
+ // @stage(fragment)
+ // fn main() -> Output {
+ // return Output();
+ // }
+ auto* output = Structure("Output",
+ {Member(Source{{13, 43}}, "a", ty.array<float>(),
+ {Location(Source{{12, 34}}, 0)})},
+ {create<ast::StructBlockAttribute>()});
+ Func(Source{{12, 34}}, "main", {}, ty.Of(output),
+ {Return(Construct(ty.Of(output)))},
+ {Stage(ast::PipelineStage::kFragment)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "13:43 error: cannot apply 'location' attribute to declaration of "
+ "type 'array<f32>'\n"
+ "12:34 note: 'location' attribute must only be applied to "
+ "declarations of numeric scalar or numeric vector type");
+}
+
+TEST_F(LocationAttributeTests, ComputeShaderLocation_Input) {
+ Func("main", {}, ty.i32(), {Return(Expr(1))},
+ {Stage(ast::PipelineStage::kCompute),
+ create<ast::WorkgroupAttribute>(Source{{12, 34}}, Expr(1))},
+ ast::AttributeList{Location(Source{{12, 34}}, 1)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: attribute is not valid for compute shader output");
+}
+
+TEST_F(LocationAttributeTests, ComputeShaderLocation_Output) {
+ auto* input = Param("input", ty.i32(),
+ ast::AttributeList{Location(Source{{12, 34}}, 0u)});
+ Func("main", {input}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute),
+ create<ast::WorkgroupAttribute>(Source{{12, 34}}, Expr(1))});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: attribute is not valid for compute shader inputs");
+}
+
+TEST_F(LocationAttributeTests, ComputeShaderLocationStructMember_Output) {
+ auto* m =
+ Member("m", ty.i32(), ast::AttributeList{Location(Source{{12, 34}}, 0u)});
+ auto* s = Structure("S", {m});
+ Func(Source{{56, 78}}, "main", {}, ty.Of(s),
+ ast::StatementList{Return(Expr(Construct(ty.Of(s))))},
+ {Stage(ast::PipelineStage::kCompute),
+ create<ast::WorkgroupAttribute>(Source{{12, 34}}, Expr(1))});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: attribute is not valid for compute shader output\n"
+ "56:78 note: while analysing entry point 'main'");
+}
+
+TEST_F(LocationAttributeTests, ComputeShaderLocationStructMember_Input) {
+ auto* m =
+ Member("m", ty.i32(), ast::AttributeList{Location(Source{{12, 34}}, 0u)});
+ auto* s = Structure("S", {m});
+ auto* input = Param("input", ty.Of(s));
+ Func(Source{{56, 78}}, "main", {input}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute),
+ create<ast::WorkgroupAttribute>(Source{{12, 34}}, Expr(1))});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: attribute is not valid for compute shader inputs\n"
+ "56:78 note: while analysing entry point 'main'");
+}
+
+TEST_F(LocationAttributeTests, Duplicate_input) {
+ // @stage(fragment)
+ // fn main(@location(1) param_a : f32,
+ // @location(1) param_b : f32) {}
+ auto* param_a = Param("param_a", ty.f32(), {Location(1)});
+ auto* param_b = Param("param_b", ty.f32(), {Location(Source{{12, 34}}, 1)});
+ Func(Source{{12, 34}}, "main", {param_a, param_b}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kFragment)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: location(1) attribute appears multiple times");
+}
+
+TEST_F(LocationAttributeTests, Duplicate_struct) {
+ // struct InputA {
+ // @location(1) a : f32;
+ // };
+ // struct InputB {
+ // @location(1) a : f32;
+ // };
+ // @stage(fragment)
+ // fn main(param_a : InputA, param_b : InputB) {}
+ auto* input_a = Structure("InputA", {Member("a", ty.f32(), {Location(1)})});
+ auto* input_b = Structure(
+ "InputB", {Member("a", ty.f32(), {Location(Source{{34, 56}}, 1)})});
+ auto* param_a = Param("param_a", ty.Of(input_a));
+ auto* param_b = Param("param_b", ty.Of(input_b));
+ Func(Source{{12, 34}}, "main", {param_a, param_b}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kFragment)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "34:56 error: location(1) attribute appears multiple times\n"
+ "12:34 note: while analysing entry point 'main'");
+}
+
+} // namespace
+} // namespace LocationAttributeTests
+
+} // namespace
+} // namespace resolver
+} // namespace tint
diff --git a/src/tint/resolver/function_validation_test.cc b/src/tint/resolver/function_validation_test.cc
new file mode 100644
index 0000000..4efd00b
--- /dev/null
+++ b/src/tint/resolver/function_validation_test.cc
@@ -0,0 +1,830 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ast/discard_statement.h"
+#include "src/tint/ast/return_statement.h"
+#include "src/tint/ast/stage_attribute.h"
+#include "src/tint/resolver/resolver.h"
+#include "src/tint/resolver/resolver_test_helper.h"
+
+#include "gmock/gmock.h"
+
+namespace tint {
+namespace {
+
+class ResolverFunctionValidationTest : public resolver::TestHelper,
+ public testing::Test {};
+
+TEST_F(ResolverFunctionValidationTest, DuplicateParameterName) {
+ // fn func_a(common_name : f32) { }
+ // fn func_b(common_name : f32) { }
+ Func("func_a", {Param("common_name", ty.f32())}, ty.void_(), {});
+ Func("func_b", {Param("common_name", ty.f32())}, ty.void_(), {});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverFunctionValidationTest, ParameterMayShadowGlobal) {
+ // var<private> common_name : f32;
+ // fn func(common_name : f32) { }
+ Global("common_name", ty.f32(), ast::StorageClass::kPrivate);
+ Func("func", {Param("common_name", ty.f32())}, ty.void_(), {});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverFunctionValidationTest, LocalConflictsWithParameter) {
+ // fn func(common_name : f32) {
+ // let common_name = 1;
+ // }
+ Func("func", {Param(Source{{12, 34}}, "common_name", ty.f32())}, ty.void_(),
+ {Decl(Const(Source{{56, 78}}, "common_name", nullptr, Expr(1)))});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), R"(56:78 error: redeclaration of 'common_name'
+12:34 note: 'common_name' previously declared here)");
+}
+
+TEST_F(ResolverFunctionValidationTest, NestedLocalMayShadowParameter) {
+ // fn func(common_name : f32) {
+ // {
+ // let common_name = 1;
+ // }
+ // }
+ Func("func", {Param(Source{{12, 34}}, "common_name", ty.f32())}, ty.void_(),
+ {Block(Decl(Const(Source{{56, 78}}, "common_name", nullptr, Expr(1))))});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverFunctionValidationTest,
+ VoidFunctionEndWithoutReturnStatement_Pass) {
+ // fn func { var a:i32 = 2; }
+ auto* var = Var("a", ty.i32(), Expr(2));
+
+ Func(Source{{12, 34}}, "func", ast::VariableList{}, ty.void_(),
+ ast::StatementList{
+ Decl(var),
+ });
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverFunctionValidationTest, FunctionUsingSameVariableName_Pass) {
+ // fn func() -> i32 {
+ // var func:i32 = 0;
+ // return func;
+ // }
+
+ auto* var = Var("func", ty.i32(), Expr(0));
+ Func("func", ast::VariableList{}, ty.i32(),
+ ast::StatementList{
+ Decl(var),
+ Return(Source{{12, 34}}, Expr("func")),
+ },
+ ast::AttributeList{});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverFunctionValidationTest,
+ FunctionNameSameAsFunctionScopeVariableName_Pass) {
+ // fn a() -> void { var b:i32 = 0; }
+ // fn b() -> i32 { return 2; }
+
+ auto* var = Var("b", ty.i32(), Expr(0));
+ Func("a", ast::VariableList{}, ty.void_(),
+ ast::StatementList{
+ Decl(var),
+ },
+ ast::AttributeList{});
+
+ Func(Source{{12, 34}}, "b", ast::VariableList{}, ty.i32(),
+ ast::StatementList{
+ Return(2),
+ },
+ ast::AttributeList{});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverFunctionValidationTest, UnreachableCode_return) {
+ // fn func() -> {
+ // var a : i32;
+ // return;
+ // a = 2;
+ //}
+
+ auto* decl_a = Decl(Var("a", ty.i32()));
+ auto* ret = Return();
+ auto* assign_a = Assign(Source{{12, 34}}, "a", 2);
+
+ Func("func", ast::VariableList{}, ty.void_(), {decl_a, ret, assign_a});
+
+ ASSERT_TRUE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(), "12:34 warning: code is unreachable");
+ EXPECT_TRUE(Sem().Get(decl_a)->IsReachable());
+ EXPECT_TRUE(Sem().Get(ret)->IsReachable());
+ EXPECT_FALSE(Sem().Get(assign_a)->IsReachable());
+}
+
+TEST_F(ResolverFunctionValidationTest, UnreachableCode_return_InBlocks) {
+ // fn func() -> {
+ // var a : i32;
+ // {{{return;}}}
+ // a = 2;
+ //}
+
+ auto* decl_a = Decl(Var("a", ty.i32()));
+ auto* ret = Return();
+ auto* assign_a = Assign(Source{{12, 34}}, "a", 2);
+
+ Func("func", ast::VariableList{}, ty.void_(),
+ {decl_a, Block(Block(Block(ret))), assign_a});
+
+ ASSERT_TRUE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 warning: code is unreachable");
+ EXPECT_TRUE(Sem().Get(decl_a)->IsReachable());
+ EXPECT_TRUE(Sem().Get(ret)->IsReachable());
+ EXPECT_FALSE(Sem().Get(assign_a)->IsReachable());
+}
+
+TEST_F(ResolverFunctionValidationTest, UnreachableCode_discard) {
+ // fn func() -> {
+ // var a : i32;
+ // discard;
+ // a = 2;
+ //}
+
+ auto* decl_a = Decl(Var("a", ty.i32()));
+ auto* discard = Discard();
+ auto* assign_a = Assign(Source{{12, 34}}, "a", 2);
+
+ Func("func", ast::VariableList{}, ty.void_(), {decl_a, discard, assign_a});
+
+ ASSERT_TRUE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 warning: code is unreachable");
+ EXPECT_TRUE(Sem().Get(decl_a)->IsReachable());
+ EXPECT_TRUE(Sem().Get(discard)->IsReachable());
+ EXPECT_FALSE(Sem().Get(assign_a)->IsReachable());
+}
+
+TEST_F(ResolverFunctionValidationTest, UnreachableCode_discard_InBlocks) {
+ // fn func() -> {
+ // var a : i32;
+ // {{{discard;}}}
+ // a = 2;
+ //}
+
+ auto* decl_a = Decl(Var("a", ty.i32()));
+ auto* discard = Discard();
+ auto* assign_a = Assign(Source{{12, 34}}, "a", 2);
+
+ Func("func", ast::VariableList{}, ty.void_(),
+ {decl_a, Block(Block(Block(discard))), assign_a});
+
+ ASSERT_TRUE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 warning: code is unreachable");
+ EXPECT_TRUE(Sem().Get(decl_a)->IsReachable());
+ EXPECT_TRUE(Sem().Get(discard)->IsReachable());
+ EXPECT_FALSE(Sem().Get(assign_a)->IsReachable());
+}
+
+TEST_F(ResolverFunctionValidationTest, FunctionEndWithoutReturnStatement_Fail) {
+ // fn func() -> int { var a:i32 = 2; }
+
+ auto* var = Var("a", ty.i32(), Expr(2));
+
+ Func(Source{{12, 34}}, "func", ast::VariableList{}, ty.i32(),
+ ast::StatementList{
+ Decl(var),
+ },
+ ast::AttributeList{});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: missing return at end of function");
+}
+
+TEST_F(ResolverFunctionValidationTest,
+ VoidFunctionEndWithoutReturnStatementEmptyBody_Pass) {
+ // fn func {}
+
+ Func(Source{{12, 34}}, "func", ast::VariableList{}, ty.void_(),
+ ast::StatementList{});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverFunctionValidationTest,
+ FunctionEndWithoutReturnStatementEmptyBody_Fail) {
+ // fn func() -> int {}
+
+ Func(Source{{12, 34}}, "func", ast::VariableList{}, ty.i32(),
+ ast::StatementList{}, ast::AttributeList{});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: missing return at end of function");
+}
+
+TEST_F(ResolverFunctionValidationTest,
+ FunctionTypeMustMatchReturnStatementType_Pass) {
+ // fn func { return; }
+
+ Func("func", ast::VariableList{}, ty.void_(),
+ ast::StatementList{
+ Return(),
+ });
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverFunctionValidationTest,
+ FunctionTypeMustMatchReturnStatementType_fail) {
+ // fn func { return 2; }
+ Func("func", ast::VariableList{}, ty.void_(),
+ ast::StatementList{
+ Return(Source{{12, 34}}, Expr(2)),
+ },
+ ast::AttributeList{});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: return statement type must match its function return "
+ "type, returned 'i32', expected 'void'");
+}
+
+TEST_F(ResolverFunctionValidationTest,
+ FunctionTypeMustMatchReturnStatementType_void_fail) {
+ // fn v { return; }
+ // fn func { return v(); }
+ Func("v", {}, ty.void_(), {Return()});
+ Func("func", {}, ty.void_(),
+ {
+ Return(Call(Source{{12, 34}}, "v")),
+ });
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: function 'v' does not return a value");
+}
+
+TEST_F(ResolverFunctionValidationTest,
+ FunctionTypeMustMatchReturnStatementTypeMissing_fail) {
+ // fn func() -> f32 { return; }
+ Func("func", ast::VariableList{}, ty.f32(),
+ ast::StatementList{
+ Return(Source{{12, 34}}, nullptr),
+ },
+ ast::AttributeList{});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: return statement type must match its function return "
+ "type, returned 'void', expected 'f32'");
+}
+
+TEST_F(ResolverFunctionValidationTest,
+ FunctionTypeMustMatchReturnStatementTypeF32_pass) {
+ // fn func() -> f32 { return 2.0; }
+ Func("func", ast::VariableList{}, ty.f32(),
+ ast::StatementList{
+ Return(Source{{12, 34}}, Expr(2.f)),
+ },
+ ast::AttributeList{});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverFunctionValidationTest,
+ FunctionTypeMustMatchReturnStatementTypeF32_fail) {
+ // fn func() -> f32 { return 2; }
+ Func("func", ast::VariableList{}, ty.f32(),
+ ast::StatementList{
+ Return(Source{{12, 34}}, Expr(2)),
+ },
+ ast::AttributeList{});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: return statement type must match its function return "
+ "type, returned 'i32', expected 'f32'");
+}
+
+TEST_F(ResolverFunctionValidationTest,
+ FunctionTypeMustMatchReturnStatementTypeF32Alias_pass) {
+ // type myf32 = f32;
+ // fn func() -> myf32 { return 2.0; }
+ auto* myf32 = Alias("myf32", ty.f32());
+ Func("func", ast::VariableList{}, ty.Of(myf32),
+ ast::StatementList{
+ Return(Source{{12, 34}}, Expr(2.f)),
+ },
+ ast::AttributeList{});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverFunctionValidationTest,
+ FunctionTypeMustMatchReturnStatementTypeF32Alias_fail) {
+ // type myf32 = f32;
+ // fn func() -> myf32 { return 2; }
+ auto* myf32 = Alias("myf32", ty.f32());
+ Func("func", ast::VariableList{}, ty.Of(myf32),
+ ast::StatementList{
+ Return(Source{{12, 34}}, Expr(2u)),
+ },
+ ast::AttributeList{});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: return statement type must match its function return "
+ "type, returned 'u32', expected 'f32'");
+}
+
+TEST_F(ResolverFunctionValidationTest, CannotCallEntryPoint) {
+ // @stage(compute) @workgroup_size(1) fn entrypoint() {}
+ // fn func() { return entrypoint(); }
+ Func("entrypoint", ast::VariableList{}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute), WorkgroupSize(1)});
+
+ Func("func", ast::VariableList{}, ty.void_(),
+ {
+ CallStmt(Call(Source{{12, 34}}, "entrypoint")),
+ });
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+
+ R"(12:34 error: entry point functions cannot be the target of a function call)");
+}
+
+TEST_F(ResolverFunctionValidationTest, PipelineStage_MustBeUnique_Fail) {
+ // @stage(fragment)
+ // @stage(vertex)
+ // fn main() { return; }
+ Func(Source{{12, 34}}, "main", ast::VariableList{}, ty.void_(),
+ ast::StatementList{
+ Return(),
+ },
+ ast::AttributeList{
+ Stage(Source{{12, 34}}, ast::PipelineStage::kVertex),
+ Stage(Source{{56, 78}}, ast::PipelineStage::kFragment),
+ });
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(56:78 error: duplicate stage attribute
+12:34 note: first attribute declared here)");
+}
+
+TEST_F(ResolverFunctionValidationTest, NoPipelineEntryPoints) {
+ Func("vtx_func", ast::VariableList{}, ty.void_(),
+ ast::StatementList{
+ Return(),
+ },
+ ast::AttributeList{});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverFunctionValidationTest, FunctionVarInitWithParam) {
+ // fn foo(bar : f32){
+ // var baz : f32 = bar;
+ // }
+
+ auto* bar = Param("bar", ty.f32());
+ auto* baz = Var("baz", ty.f32(), Expr("bar"));
+
+ Func("foo", ast::VariableList{bar}, ty.void_(), ast::StatementList{Decl(baz)},
+ ast::AttributeList{});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverFunctionValidationTest, FunctionConstInitWithParam) {
+ // fn foo(bar : f32){
+ // let baz : f32 = bar;
+ // }
+
+ auto* bar = Param("bar", ty.f32());
+ auto* baz = Const("baz", ty.f32(), Expr("bar"));
+
+ Func("foo", ast::VariableList{bar}, ty.void_(), ast::StatementList{Decl(baz)},
+ ast::AttributeList{});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverFunctionValidationTest, FunctionParamsConst) {
+ Func("foo", {Param(Sym("arg"), ty.i32())}, ty.void_(),
+ {Assign(Expr(Source{{12, 34}}, "arg"), Expr(1)), Return()});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: cannot assign to function parameter\nnote: 'arg' is "
+ "declared here:");
+}
+
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_ConstU32) {
+ // let x = 4u;
+ // let x = 8u;
+ // @stage(compute) @workgroup_size(x, y, 16u)
+ // fn main() {}
+ auto* x = GlobalConst("x", ty.u32(), Expr(4u));
+ auto* y = GlobalConst("y", ty.u32(), Expr(8u));
+ auto* func = Func("main", {}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(Expr("x"), Expr("y"), Expr(16u))});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem_func = Sem().Get(func);
+ auto* sem_x = Sem().Get<sem::GlobalVariable>(x);
+ auto* sem_y = Sem().Get<sem::GlobalVariable>(y);
+
+ ASSERT_NE(sem_func, nullptr);
+ ASSERT_NE(sem_x, nullptr);
+ ASSERT_NE(sem_y, nullptr);
+
+ EXPECT_TRUE(sem_func->DirectlyReferencedGlobals().contains(sem_x));
+ EXPECT_TRUE(sem_func->DirectlyReferencedGlobals().contains(sem_y));
+}
+
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_U32) {
+ // [[stage(compute), workgroup_size(1u, 2u, 3u)]
+ // fn main() {}
+
+ Func("main", {}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(Source{{12, 34}}, Expr(1u), Expr(2u), Expr(3u))});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_MismatchTypeU32) {
+ // [[stage(compute), workgroup_size(1u, 2u, 3)]
+ // fn main() {}
+
+ Func("main", {}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(Expr(1u), Expr(2u), Expr(Source{{12, 34}}, 3))});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: workgroup_size arguments must be of the same type, "
+ "either i32 or u32");
+}
+
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_MismatchTypeI32) {
+ // [[stage(compute), workgroup_size(1, 2u, 3)]
+ // fn main() {}
+
+ Func("main", {}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(Expr(1), Expr(Source{{12, 34}}, 2u), Expr(3))});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: workgroup_size arguments must be of the same type, "
+ "either i32 or u32");
+}
+
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_TypeMismatch) {
+ // let x = 64u;
+ // [[stage(compute), workgroup_size(1, x)]
+ // fn main() {}
+ GlobalConst("x", ty.u32(), Expr(64u));
+ Func("main", {}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(Expr(1), Expr(Source{{12, 34}}, "x"))});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: workgroup_size arguments must be of the same type, "
+ "either i32 or u32");
+}
+
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_TypeMismatch2) {
+ // let x = 64u;
+ // let y = 32;
+ // [[stage(compute), workgroup_size(x, y)]
+ // fn main() {}
+ GlobalConst("x", ty.u32(), Expr(64u));
+ GlobalConst("y", ty.i32(), Expr(32));
+ Func("main", {}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(Expr("x"), Expr(Source{{12, 34}}, "y"))});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: workgroup_size arguments must be of the same type, "
+ "either i32 or u32");
+}
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Mismatch_ConstU32) {
+ // let x = 4u;
+ // let x = 8u;
+ // [[stage(compute), workgroup_size(x, y, 16]
+ // fn main() {}
+ GlobalConst("x", ty.u32(), Expr(4u));
+ GlobalConst("y", ty.u32(), Expr(8u));
+ Func("main", {}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(Expr("x"), Expr("y"), Expr(Source{{12, 34}}, 16))});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: workgroup_size arguments must be of the same type, "
+ "either i32 or u32");
+}
+
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Literal_BadType) {
+ // [[stage(compute), workgroup_size(64.0)]
+ // fn main() {}
+
+ Func("main", {}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(Expr(Source{{12, 34}}, 64.f))});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: workgroup_size argument must be either literal or "
+ "module-scope constant of type i32 or u32");
+}
+
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Literal_Negative) {
+ // [[stage(compute), workgroup_size(-2)]
+ // fn main() {}
+
+ Func("main", {}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(Expr(Source{{12, 34}}, -2))});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: workgroup_size argument must be at least 1");
+}
+
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Literal_Zero) {
+ // [[stage(compute), workgroup_size(0)]
+ // fn main() {}
+
+ Func("main", {}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(Expr(Source{{12, 34}}, 0))});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: workgroup_size argument must be at least 1");
+}
+
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_BadType) {
+ // let x = 64.0;
+ // [[stage(compute), workgroup_size(x)]
+ // fn main() {}
+ GlobalConst("x", ty.f32(), Expr(64.f));
+ Func("main", {}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(Expr(Source{{12, 34}}, "x"))});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: workgroup_size argument must be either literal or "
+ "module-scope constant of type i32 or u32");
+}
+
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_Negative) {
+ // let x = -2;
+ // [[stage(compute), workgroup_size(x)]
+ // fn main() {}
+ GlobalConst("x", ty.i32(), Expr(-2));
+ Func("main", {}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(Expr(Source{{12, 34}}, "x"))});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: workgroup_size argument must be at least 1");
+}
+
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_Zero) {
+ // let x = 0;
+ // [[stage(compute), workgroup_size(x)]
+ // fn main() {}
+ GlobalConst("x", ty.i32(), Expr(0));
+ Func("main", {}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(Expr(Source{{12, 34}}, "x"))});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: workgroup_size argument must be at least 1");
+}
+
+TEST_F(ResolverFunctionValidationTest,
+ WorkgroupSize_Const_NestedZeroValueConstructor) {
+ // let x = i32(i32(i32()));
+ // [[stage(compute), workgroup_size(x)]
+ // fn main() {}
+ GlobalConst("x", ty.i32(),
+ Construct(ty.i32(), Construct(ty.i32(), Construct(ty.i32()))));
+ Func("main", {}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(Expr(Source{{12, 34}}, "x"))});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: workgroup_size argument must be at least 1");
+}
+
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_NonConst) {
+ // var<private> x = 0;
+ // [[stage(compute), workgroup_size(x)]
+ // fn main() {}
+ Global("x", ty.i32(), ast::StorageClass::kPrivate, Expr(64));
+ Func("main", {}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(Expr(Source{{12, 34}}, "x"))});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: workgroup_size argument must be either literal or "
+ "module-scope constant of type i32 or u32");
+}
+
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr) {
+ // [[stage(compute), workgroup_size(i32(1))]
+ // fn main() {}
+ Func("main", {}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), 1))});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: workgroup_size argument must be either a literal or "
+ "a module-scope constant");
+}
+
+TEST_F(ResolverFunctionValidationTest, ReturnIsConstructible_NonPlain) {
+ auto* ret_type =
+ ty.pointer(Source{{12, 34}}, ty.i32(), ast::StorageClass::kFunction);
+ Func("f", {}, ret_type, {});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: function return type must be a constructible type");
+}
+
+TEST_F(ResolverFunctionValidationTest, ReturnIsConstructible_AtomicInt) {
+ auto* ret_type = ty.atomic(Source{{12, 34}}, ty.i32());
+ Func("f", {}, ret_type, {});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: function return type must be a constructible type");
+}
+
+TEST_F(ResolverFunctionValidationTest, ReturnIsConstructible_ArrayOfAtomic) {
+ auto* ret_type = ty.array(Source{{12, 34}}, ty.atomic(ty.i32()), 10);
+ Func("f", {}, ret_type, {});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: function return type must be a constructible type");
+}
+
+TEST_F(ResolverFunctionValidationTest, ReturnIsConstructible_StructOfAtomic) {
+ Structure("S", {Member("m", ty.atomic(ty.i32()))});
+ auto* ret_type = ty.type_name(Source{{12, 34}}, "S");
+ Func("f", {}, ret_type, {});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: function return type must be a constructible type");
+}
+
+TEST_F(ResolverFunctionValidationTest, ReturnIsConstructible_RuntimeArray) {
+ auto* ret_type = ty.array(Source{{12, 34}}, ty.i32());
+ Func("f", {}, ret_type, {});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: function return type must be a constructible type");
+}
+
+TEST_F(ResolverFunctionValidationTest, ParameterStoreType_NonAtomicFree) {
+ Structure("S", {Member("m", ty.atomic(ty.i32()))});
+ auto* ret_type = ty.type_name(Source{{12, 34}}, "S");
+ auto* bar = Param(Source{{12, 34}}, "bar", ret_type);
+ Func("f", ast::VariableList{bar}, ty.void_(), {});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: store type of function parameter must be a "
+ "constructible type");
+}
+
+TEST_F(ResolverFunctionValidationTest, ParameterSotreType_AtomicFree) {
+ Structure("S", {Member("m", ty.i32())});
+ auto* ret_type = ty.type_name(Source{{12, 34}}, "S");
+ auto* bar = Param(Source{{12, 34}}, "bar", ret_type);
+ Func("f", ast::VariableList{bar}, ty.void_(), {});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverFunctionValidationTest, ParametersAtLimit) {
+ ast::VariableList params;
+ for (int i = 0; i < 255; i++) {
+ params.emplace_back(Param("param_" + std::to_string(i), ty.i32()));
+ }
+ Func(Source{{12, 34}}, "f", params, ty.void_(), {});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverFunctionValidationTest, ParametersOverLimit) {
+ ast::VariableList params;
+ for (int i = 0; i < 256; i++) {
+ params.emplace_back(Param("param_" + std::to_string(i), ty.i32()));
+ }
+ Func(Source{{12, 34}}, "f", params, ty.void_(), {});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: functions may declare at most 255 parameters");
+}
+
+TEST_F(ResolverFunctionValidationTest, ParameterVectorNoType) {
+ // fn f(p : vec3) {}
+
+ Func(Source{{12, 34}}, "f",
+ {Param("p", create<ast::Vector>(Source{{12, 34}}, nullptr, 3))},
+ ty.void_(), {});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: missing vector element type");
+}
+
+TEST_F(ResolverFunctionValidationTest, ParameterMatrixNoType) {
+ // fn f(p : vec3) {}
+
+ Func(Source{{12, 34}}, "f",
+ {Param("p", create<ast::Matrix>(Source{{12, 34}}, nullptr, 3, 3))},
+ ty.void_(), {});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: missing matrix element type");
+}
+
+struct TestParams {
+ ast::StorageClass storage_class;
+ bool should_pass;
+};
+
+struct TestWithParams : resolver::ResolverTestWithParam<TestParams> {};
+
+using ResolverFunctionParameterValidationTest = TestWithParams;
+TEST_P(ResolverFunctionParameterValidationTest, StorageClass) {
+ auto& param = GetParam();
+ auto* ptr_type = ty.pointer(Source{{12, 34}}, ty.i32(), param.storage_class);
+ auto* arg = Param(Source{{12, 34}}, "p", ptr_type);
+ Func("f", ast::VariableList{arg}, ty.void_(), {});
+
+ if (param.should_pass) {
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+ } else {
+ std::stringstream ss;
+ ss << param.storage_class;
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: function parameter of pointer type cannot be in '" +
+ ss.str() + "' storage class");
+ }
+}
+INSTANTIATE_TEST_SUITE_P(
+ ResolverTest,
+ ResolverFunctionParameterValidationTest,
+ testing::Values(TestParams{ast::StorageClass::kNone, false},
+ TestParams{ast::StorageClass::kInput, false},
+ TestParams{ast::StorageClass::kOutput, false},
+ TestParams{ast::StorageClass::kUniform, false},
+ TestParams{ast::StorageClass::kWorkgroup, true},
+ TestParams{ast::StorageClass::kUniformConstant, false},
+ TestParams{ast::StorageClass::kStorage, false},
+ TestParams{ast::StorageClass::kPrivate, true},
+ TestParams{ast::StorageClass::kFunction, true}));
+
+} // namespace
+} // namespace tint
diff --git a/src/tint/resolver/host_shareable_validation_test.cc b/src/tint/resolver/host_shareable_validation_test.cc
new file mode 100644
index 0000000..f876bfd
--- /dev/null
+++ b/src/tint/resolver/host_shareable_validation_test.cc
@@ -0,0 +1,151 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/resolver/resolver.h"
+
+#include "gmock/gmock.h"
+#include "src/tint/ast/struct_block_attribute.h"
+#include "src/tint/resolver/resolver_test_helper.h"
+#include "src/tint/sem/struct.h"
+
+namespace tint {
+namespace resolver {
+namespace {
+
+using ResolverHostShareableValidationTest = ResolverTest;
+
+TEST_F(ResolverHostShareableValidationTest, BoolMember) {
+ auto* s = Structure("S", {Member(Source{{12, 34}}, "x", ty.bool_())},
+ {create<ast::StructBlockAttribute>()});
+
+ Global(Source{{56, 78}}, "g", ty.Of(s), ast::StorageClass::kStorage,
+ ast::Access::kRead,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(0),
+ create<ast::GroupAttribute>(0),
+ });
+
+ ASSERT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(
+ r()->error(),
+ R"(56:78 error: Type 'bool' cannot be used in storage class 'storage' as it is non-host-shareable
+12:34 note: while analysing structure member S.x
+56:78 note: while instantiating variable g)");
+}
+
+TEST_F(ResolverHostShareableValidationTest, BoolVectorMember) {
+ auto* s = Structure("S", {Member(Source{{12, 34}}, "x", ty.vec3<bool>())},
+ {create<ast::StructBlockAttribute>()});
+
+ Global(Source{{56, 78}}, "g", ty.Of(s), ast::StorageClass::kStorage,
+ ast::Access::kRead,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(0),
+ create<ast::GroupAttribute>(0),
+ });
+
+ ASSERT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(
+ r()->error(),
+ R"(56:78 error: Type 'vec3<bool>' cannot be used in storage class 'storage' as it is non-host-shareable
+12:34 note: while analysing structure member S.x
+56:78 note: while instantiating variable g)");
+}
+
+TEST_F(ResolverHostShareableValidationTest, Aliases) {
+ auto* a1 = Alias("a1", ty.bool_());
+ auto* s = Structure("S", {Member(Source{{12, 34}}, "x", ty.Of(a1))},
+ {create<ast::StructBlockAttribute>()});
+ auto* a2 = Alias("a2", ty.Of(s));
+ Global(Source{{56, 78}}, "g", ty.Of(a2), ast::StorageClass::kStorage,
+ ast::Access::kRead,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(0),
+ create<ast::GroupAttribute>(0),
+ });
+
+ ASSERT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(
+ r()->error(),
+ R"(56:78 error: Type 'bool' cannot be used in storage class 'storage' as it is non-host-shareable
+12:34 note: while analysing structure member S.x
+56:78 note: while instantiating variable g)");
+}
+
+TEST_F(ResolverHostShareableValidationTest, NestedStructures) {
+ auto* i1 = Structure("I1", {Member(Source{{1, 2}}, "x", ty.bool_())});
+ auto* i2 = Structure("I2", {Member(Source{{3, 4}}, "y", ty.Of(i1))});
+ auto* i3 = Structure("I3", {Member(Source{{5, 6}}, "z", ty.Of(i2))});
+
+ auto* s = Structure("S", {Member(Source{{7, 8}}, "m", ty.Of(i3))},
+ {create<ast::StructBlockAttribute>()});
+
+ Global(Source{{9, 10}}, "g", ty.Of(s), ast::StorageClass::kStorage,
+ ast::Access::kRead,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(0),
+ create<ast::GroupAttribute>(0),
+ });
+
+ ASSERT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(
+ r()->error(),
+ R"(9:10 error: Type 'bool' cannot be used in storage class 'storage' as it is non-host-shareable
+1:2 note: while analysing structure member I1.x
+3:4 note: while analysing structure member I2.y
+5:6 note: while analysing structure member I3.z
+7:8 note: while analysing structure member S.m
+9:10 note: while instantiating variable g)");
+}
+
+TEST_F(ResolverHostShareableValidationTest, NoError) {
+ auto* i1 =
+ Structure("I1", {
+ Member(Source{{1, 1}}, "x1", ty.f32()),
+ Member(Source{{2, 1}}, "y1", ty.vec3<f32>()),
+ Member(Source{{3, 1}}, "z1", ty.array<i32, 4>()),
+ });
+ auto* a1 = Alias("a1", ty.Of(i1));
+ auto* i2 = Structure("I2", {
+ Member(Source{{4, 1}}, "x2", ty.mat2x2<f32>()),
+ Member(Source{{5, 1}}, "y2", ty.Of(i1)),
+ });
+ auto* a2 = Alias("a2", ty.Of(i2));
+ auto* i3 = Structure("I3", {
+ Member(Source{{4, 1}}, "x3", ty.Of(a1)),
+ Member(Source{{5, 1}}, "y3", ty.Of(i2)),
+ Member(Source{{6, 1}}, "z3", ty.Of(a2)),
+ });
+
+ auto* s = Structure("S", {Member(Source{{7, 8}}, "m", ty.Of(i3))},
+ {create<ast::StructBlockAttribute>()});
+
+ Global(Source{{9, 10}}, "g", ty.Of(s), ast::StorageClass::kStorage,
+ ast::Access::kRead,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(0),
+ create<ast::GroupAttribute>(0),
+ });
+ WrapInFunction();
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+} // namespace
+} // namespace resolver
+} // namespace tint
diff --git a/src/tint/resolver/inferred_type_test.cc b/src/tint/resolver/inferred_type_test.cc
new file mode 100644
index 0000000..689f813
--- /dev/null
+++ b/src/tint/resolver/inferred_type_test.cc
@@ -0,0 +1,176 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ast/struct_block_attribute.h"
+#include "src/tint/resolver/resolver.h"
+#include "src/tint/resolver/resolver_test_helper.h"
+
+#include "gmock/gmock.h"
+
+namespace tint {
+namespace resolver {
+namespace {
+
+// Helpers and typedefs
+template <typename T>
+using DataType = builder::DataType<T>;
+template <typename T>
+using vec2 = builder::vec2<T>;
+template <typename T>
+using vec3 = builder::vec3<T>;
+template <typename T>
+using vec4 = builder::vec4<T>;
+template <typename T>
+using mat2x2 = builder::mat2x2<T>;
+template <typename T>
+using mat3x3 = builder::mat3x3<T>;
+template <typename T>
+using mat4x4 = builder::mat4x4<T>;
+template <typename T>
+using alias = builder::alias<T>;
+using f32 = builder::f32;
+using i32 = builder::i32;
+using u32 = builder::u32;
+
+struct ResolverInferredTypeTest : public resolver::TestHelper,
+ public testing::Test {};
+
+struct Params {
+ builder::ast_expr_func_ptr create_value;
+ builder::sem_type_func_ptr create_expected_type;
+};
+
+template <typename T>
+constexpr Params ParamsFor() {
+ return Params{DataType<T>::Expr, DataType<T>::Sem};
+}
+
+Params all_cases[] = {
+ ParamsFor<bool>(), //
+ ParamsFor<u32>(), //
+ ParamsFor<i32>(), //
+ ParamsFor<f32>(), //
+ ParamsFor<vec3<bool>>(), //
+ ParamsFor<vec3<i32>>(), //
+ ParamsFor<vec3<u32>>(), //
+ ParamsFor<vec3<f32>>(), //
+ ParamsFor<mat3x3<f32>>(), //
+ ParamsFor<alias<bool>>(), //
+ ParamsFor<alias<u32>>(), //
+ ParamsFor<alias<i32>>(), //
+ ParamsFor<alias<f32>>(), //
+ ParamsFor<alias<vec3<bool>>>(), //
+ ParamsFor<alias<vec3<i32>>>(), //
+ ParamsFor<alias<vec3<u32>>>(), //
+ ParamsFor<alias<vec3<f32>>>(), //
+ ParamsFor<alias<mat3x3<f32>>>(), //
+};
+
+using ResolverInferredTypeParamTest = ResolverTestWithParam<Params>;
+
+TEST_P(ResolverInferredTypeParamTest, GlobalLet_Pass) {
+ auto& params = GetParam();
+
+ auto* expected_type = params.create_expected_type(*this);
+
+ // let a = <type constructor>;
+ auto* ctor_expr = params.create_value(*this, 0);
+ auto* var = GlobalConst("a", nullptr, ctor_expr);
+ WrapInFunction();
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ EXPECT_EQ(TypeOf(var), expected_type);
+}
+
+TEST_P(ResolverInferredTypeParamTest, GlobalVar_Fail) {
+ auto& params = GetParam();
+
+ // var a = <type constructor>;
+ auto* ctor_expr = params.create_value(*this, 0);
+ Global(Source{{12, 34}}, "a", nullptr, ast::StorageClass::kPrivate,
+ ctor_expr);
+ WrapInFunction();
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: global var declaration must specify a type");
+}
+
+TEST_P(ResolverInferredTypeParamTest, LocalLet_Pass) {
+ auto& params = GetParam();
+
+ auto* expected_type = params.create_expected_type(*this);
+
+ // let a = <type constructor>;
+ auto* ctor_expr = params.create_value(*this, 0);
+ auto* var = Const("a", nullptr, ctor_expr);
+ WrapInFunction(var);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ EXPECT_EQ(TypeOf(var), expected_type);
+}
+
+TEST_P(ResolverInferredTypeParamTest, LocalVar_Pass) {
+ auto& params = GetParam();
+
+ auto* expected_type = params.create_expected_type(*this);
+
+ // var a = <type constructor>;
+ auto* ctor_expr = params.create_value(*this, 0);
+ auto* var = Var("a", nullptr, ast::StorageClass::kFunction, ctor_expr);
+ WrapInFunction(var);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ EXPECT_EQ(TypeOf(var)->UnwrapRef(), expected_type);
+}
+
+INSTANTIATE_TEST_SUITE_P(ResolverTest,
+ ResolverInferredTypeParamTest,
+ testing::ValuesIn(all_cases));
+
+TEST_F(ResolverInferredTypeTest, InferArray_Pass) {
+ auto* type = ty.array(ty.u32(), 10);
+ auto* expected_type =
+ create<sem::Array>(create<sem::U32>(), 10, 4, 4 * 10, 4, 4);
+
+ auto* ctor_expr = Construct(type);
+ auto* var = Var("a", nullptr, ast::StorageClass::kFunction, ctor_expr);
+ WrapInFunction(var);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ EXPECT_EQ(TypeOf(var)->UnwrapRef(), expected_type);
+}
+
+TEST_F(ResolverInferredTypeTest, InferStruct_Pass) {
+ auto* member = Member("x", ty.i32());
+ auto* str = Structure("S", {member}, {create<ast::StructBlockAttribute>()});
+
+ auto* expected_type = create<sem::Struct>(
+ str, str->name,
+ sem::StructMemberList{create<sem::StructMember>(
+ member, member->symbol, create<sem::I32>(), 0, 0, 0, 4)},
+ 0, 4, 4);
+
+ auto* ctor_expr = Construct(ty.Of(str));
+
+ auto* var = Var("a", nullptr, ast::StorageClass::kFunction, ctor_expr);
+ WrapInFunction(var);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ EXPECT_EQ(TypeOf(var)->UnwrapRef(), expected_type);
+}
+
+} // namespace
+} // namespace resolver
+} // namespace tint
diff --git a/src/tint/resolver/is_host_shareable_test.cc b/src/tint/resolver/is_host_shareable_test.cc
new file mode 100644
index 0000000..9c992ab
--- /dev/null
+++ b/src/tint/resolver/is_host_shareable_test.cc
@@ -0,0 +1,115 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/resolver/resolver.h"
+
+#include "gmock/gmock.h"
+#include "src/tint/resolver/resolver_test_helper.h"
+#include "src/tint/sem/atomic_type.h"
+
+namespace tint {
+namespace resolver {
+namespace {
+
+using ResolverIsHostShareable = ResolverTest;
+
+TEST_F(ResolverIsHostShareable, Void) {
+ EXPECT_FALSE(r()->IsHostShareable(create<sem::Void>()));
+}
+
+TEST_F(ResolverIsHostShareable, Bool) {
+ EXPECT_FALSE(r()->IsHostShareable(create<sem::Bool>()));
+}
+
+TEST_F(ResolverIsHostShareable, NumericScalar) {
+ EXPECT_TRUE(r()->IsHostShareable(create<sem::I32>()));
+ EXPECT_TRUE(r()->IsHostShareable(create<sem::U32>()));
+ EXPECT_TRUE(r()->IsHostShareable(create<sem::F32>()));
+}
+
+TEST_F(ResolverIsHostShareable, NumericVector) {
+ EXPECT_TRUE(r()->IsHostShareable(create<sem::Vector>(create<sem::I32>(), 2)));
+ EXPECT_TRUE(r()->IsHostShareable(create<sem::Vector>(create<sem::I32>(), 3)));
+ EXPECT_TRUE(r()->IsHostShareable(create<sem::Vector>(create<sem::I32>(), 4)));
+ EXPECT_TRUE(r()->IsHostShareable(create<sem::Vector>(create<sem::U32>(), 2)));
+ EXPECT_TRUE(r()->IsHostShareable(create<sem::Vector>(create<sem::U32>(), 3)));
+ EXPECT_TRUE(r()->IsHostShareable(create<sem::Vector>(create<sem::U32>(), 4)));
+ EXPECT_TRUE(r()->IsHostShareable(create<sem::Vector>(create<sem::F32>(), 2)));
+ EXPECT_TRUE(r()->IsHostShareable(create<sem::Vector>(create<sem::F32>(), 3)));
+ EXPECT_TRUE(r()->IsHostShareable(create<sem::Vector>(create<sem::F32>(), 4)));
+}
+
+TEST_F(ResolverIsHostShareable, BoolVector) {
+ EXPECT_FALSE(
+ r()->IsHostShareable(create<sem::Vector>(create<sem::Bool>(), 2)));
+ EXPECT_FALSE(
+ r()->IsHostShareable(create<sem::Vector>(create<sem::Bool>(), 3)));
+ EXPECT_FALSE(
+ r()->IsHostShareable(create<sem::Vector>(create<sem::Bool>(), 4)));
+ EXPECT_FALSE(
+ r()->IsHostShareable(create<sem::Vector>(create<sem::Bool>(), 2)));
+ EXPECT_FALSE(
+ r()->IsHostShareable(create<sem::Vector>(create<sem::Bool>(), 3)));
+ EXPECT_FALSE(
+ r()->IsHostShareable(create<sem::Vector>(create<sem::Bool>(), 4)));
+ EXPECT_FALSE(
+ r()->IsHostShareable(create<sem::Vector>(create<sem::Bool>(), 2)));
+ EXPECT_FALSE(
+ r()->IsHostShareable(create<sem::Vector>(create<sem::Bool>(), 3)));
+ EXPECT_FALSE(
+ r()->IsHostShareable(create<sem::Vector>(create<sem::Bool>(), 4)));
+}
+
+TEST_F(ResolverIsHostShareable, Matrix) {
+ auto* vec2 = create<sem::Vector>(create<sem::F32>(), 2);
+ auto* vec3 = create<sem::Vector>(create<sem::F32>(), 3);
+ auto* vec4 = create<sem::Vector>(create<sem::F32>(), 4);
+
+ EXPECT_TRUE(r()->IsHostShareable(create<sem::Matrix>(vec2, 2)));
+ EXPECT_TRUE(r()->IsHostShareable(create<sem::Matrix>(vec2, 3)));
+ EXPECT_TRUE(r()->IsHostShareable(create<sem::Matrix>(vec2, 4)));
+ EXPECT_TRUE(r()->IsHostShareable(create<sem::Matrix>(vec3, 2)));
+ EXPECT_TRUE(r()->IsHostShareable(create<sem::Matrix>(vec3, 3)));
+ EXPECT_TRUE(r()->IsHostShareable(create<sem::Matrix>(vec3, 4)));
+ EXPECT_TRUE(r()->IsHostShareable(create<sem::Matrix>(vec4, 2)));
+ EXPECT_TRUE(r()->IsHostShareable(create<sem::Matrix>(vec4, 3)));
+ EXPECT_TRUE(r()->IsHostShareable(create<sem::Matrix>(vec4, 4)));
+}
+
+TEST_F(ResolverIsHostShareable, Pointer) {
+ auto* ptr = create<sem::Pointer>(
+ create<sem::I32>(), ast::StorageClass::kPrivate, ast::Access::kReadWrite);
+ EXPECT_FALSE(r()->IsHostShareable(ptr));
+}
+
+TEST_F(ResolverIsHostShareable, Atomic) {
+ EXPECT_TRUE(r()->IsHostShareable(create<sem::Atomic>(create<sem::I32>())));
+ EXPECT_TRUE(r()->IsHostShareable(create<sem::Atomic>(create<sem::U32>())));
+}
+
+TEST_F(ResolverIsHostShareable, ArraySizedOfHostShareable) {
+ auto* arr = create<sem::Array>(create<sem::I32>(), 5, 4, 20, 4, 4);
+ EXPECT_TRUE(r()->IsHostShareable(arr));
+}
+
+TEST_F(ResolverIsHostShareable, ArrayUnsizedOfHostShareable) {
+ auto* arr = create<sem::Array>(create<sem::I32>(), 0, 4, 4, 4, 4);
+ EXPECT_TRUE(r()->IsHostShareable(arr));
+}
+
+// Note: Structure tests covered in host_shareable_validation_test.cc
+
+} // namespace
+} // namespace resolver
+} // namespace tint
diff --git a/src/tint/resolver/is_storeable_test.cc b/src/tint/resolver/is_storeable_test.cc
new file mode 100644
index 0000000..cc3323a
--- /dev/null
+++ b/src/tint/resolver/is_storeable_test.cc
@@ -0,0 +1,140 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/resolver/resolver.h"
+
+#include "gmock/gmock.h"
+#include "src/tint/resolver/resolver_test_helper.h"
+#include "src/tint/sem/atomic_type.h"
+
+namespace tint {
+namespace resolver {
+namespace {
+
+using ResolverIsStorableTest = ResolverTest;
+
+TEST_F(ResolverIsStorableTest, Void) {
+ EXPECT_FALSE(r()->IsStorable(create<sem::Void>()));
+}
+
+TEST_F(ResolverIsStorableTest, Scalar) {
+ EXPECT_TRUE(r()->IsStorable(create<sem::Bool>()));
+ EXPECT_TRUE(r()->IsStorable(create<sem::I32>()));
+ EXPECT_TRUE(r()->IsStorable(create<sem::U32>()));
+ EXPECT_TRUE(r()->IsStorable(create<sem::F32>()));
+}
+
+TEST_F(ResolverIsStorableTest, Vector) {
+ EXPECT_TRUE(r()->IsStorable(create<sem::Vector>(create<sem::I32>(), 2)));
+ EXPECT_TRUE(r()->IsStorable(create<sem::Vector>(create<sem::I32>(), 3)));
+ EXPECT_TRUE(r()->IsStorable(create<sem::Vector>(create<sem::I32>(), 4)));
+ EXPECT_TRUE(r()->IsStorable(create<sem::Vector>(create<sem::U32>(), 2)));
+ EXPECT_TRUE(r()->IsStorable(create<sem::Vector>(create<sem::U32>(), 3)));
+ EXPECT_TRUE(r()->IsStorable(create<sem::Vector>(create<sem::U32>(), 4)));
+ EXPECT_TRUE(r()->IsStorable(create<sem::Vector>(create<sem::F32>(), 2)));
+ EXPECT_TRUE(r()->IsStorable(create<sem::Vector>(create<sem::F32>(), 3)));
+ EXPECT_TRUE(r()->IsStorable(create<sem::Vector>(create<sem::F32>(), 4)));
+}
+
+TEST_F(ResolverIsStorableTest, Matrix) {
+ auto* vec2 = create<sem::Vector>(create<sem::F32>(), 2);
+ auto* vec3 = create<sem::Vector>(create<sem::F32>(), 3);
+ auto* vec4 = create<sem::Vector>(create<sem::F32>(), 4);
+ EXPECT_TRUE(r()->IsStorable(create<sem::Matrix>(vec2, 2)));
+ EXPECT_TRUE(r()->IsStorable(create<sem::Matrix>(vec2, 3)));
+ EXPECT_TRUE(r()->IsStorable(create<sem::Matrix>(vec2, 4)));
+ EXPECT_TRUE(r()->IsStorable(create<sem::Matrix>(vec3, 2)));
+ EXPECT_TRUE(r()->IsStorable(create<sem::Matrix>(vec3, 3)));
+ EXPECT_TRUE(r()->IsStorable(create<sem::Matrix>(vec3, 4)));
+ EXPECT_TRUE(r()->IsStorable(create<sem::Matrix>(vec4, 2)));
+ EXPECT_TRUE(r()->IsStorable(create<sem::Matrix>(vec4, 3)));
+ EXPECT_TRUE(r()->IsStorable(create<sem::Matrix>(vec4, 4)));
+}
+
+TEST_F(ResolverIsStorableTest, Pointer) {
+ auto* ptr = create<sem::Pointer>(
+ create<sem::I32>(), ast::StorageClass::kPrivate, ast::Access::kReadWrite);
+ EXPECT_FALSE(r()->IsStorable(ptr));
+}
+
+TEST_F(ResolverIsStorableTest, Atomic) {
+ EXPECT_TRUE(r()->IsStorable(create<sem::Atomic>(create<sem::I32>())));
+ EXPECT_TRUE(r()->IsStorable(create<sem::Atomic>(create<sem::U32>())));
+}
+
+TEST_F(ResolverIsStorableTest, ArraySizedOfStorable) {
+ auto* arr = create<sem::Array>(create<sem::I32>(), 5, 4, 20, 4, 4);
+ EXPECT_TRUE(r()->IsStorable(arr));
+}
+
+TEST_F(ResolverIsStorableTest, ArrayUnsizedOfStorable) {
+ auto* arr = create<sem::Array>(create<sem::I32>(), 0, 4, 4, 4, 4);
+ EXPECT_TRUE(r()->IsStorable(arr));
+}
+
+TEST_F(ResolverIsStorableTest, Struct_AllMembersStorable) {
+ Structure("S", {
+ Member("a", ty.i32()),
+ Member("b", ty.f32()),
+ });
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverIsStorableTest, Struct_SomeMembersNonStorable) {
+ Structure("S", {
+ Member("a", ty.i32()),
+ Member("b", ty.pointer<i32>(ast::StorageClass::kPrivate)),
+ });
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(error: ptr<private, i32, read_write> cannot be used as the type of a structure member)");
+}
+
+TEST_F(ResolverIsStorableTest, Struct_NestedStorable) {
+ auto* storable = Structure("Storable", {
+ Member("a", ty.i32()),
+ Member("b", ty.f32()),
+ });
+ Structure("S", {
+ Member("a", ty.i32()),
+ Member("b", ty.Of(storable)),
+ });
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverIsStorableTest, Struct_NestedNonStorable) {
+ auto* non_storable =
+ Structure("nonstorable",
+ {
+ Member("a", ty.i32()),
+ Member("b", ty.pointer<i32>(ast::StorageClass::kPrivate)),
+ });
+ Structure("S", {
+ Member("a", ty.i32()),
+ Member("b", ty.Of(non_storable)),
+ });
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(error: ptr<private, i32, read_write> cannot be used as the type of a structure member)");
+}
+
+} // namespace
+} // namespace resolver
+} // namespace tint
diff --git a/src/tint/resolver/pipeline_overridable_constant_test.cc b/src/tint/resolver/pipeline_overridable_constant_test.cc
new file mode 100644
index 0000000..9672174
--- /dev/null
+++ b/src/tint/resolver/pipeline_overridable_constant_test.cc
@@ -0,0 +1,108 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/resolver/resolver.h"
+
+#include "src/tint/resolver/resolver_test_helper.h"
+
+namespace tint {
+namespace resolver {
+namespace {
+
+class ResolverPipelineOverridableConstantTest : public ResolverTest {
+ protected:
+ /// Verify that the AST node `var` was resolved to an overridable constant
+ /// with an ID equal to `id`.
+ /// @param var the overridable constant AST node
+ /// @param id the expected constant ID
+ void ExpectConstantId(const ast::Variable* var, uint16_t id) {
+ auto* sem = Sem().Get<sem::GlobalVariable>(var);
+ ASSERT_NE(sem, nullptr);
+ EXPECT_EQ(sem->Declaration(), var);
+ EXPECT_TRUE(sem->IsOverridable());
+ EXPECT_EQ(sem->ConstantId(), id);
+ EXPECT_FALSE(sem->ConstantValue());
+ }
+};
+
+TEST_F(ResolverPipelineOverridableConstantTest, NonOverridable) {
+ auto* a = GlobalConst("a", ty.f32(), Expr(1.f));
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem_a = Sem().Get<sem::GlobalVariable>(a);
+ ASSERT_NE(sem_a, nullptr);
+ EXPECT_EQ(sem_a->Declaration(), a);
+ EXPECT_FALSE(sem_a->IsOverridable());
+ EXPECT_TRUE(sem_a->ConstantValue());
+}
+
+TEST_F(ResolverPipelineOverridableConstantTest, WithId) {
+ auto* a = Override("a", ty.f32(), Expr(1.f), {Id(7u)});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ExpectConstantId(a, 7u);
+}
+
+TEST_F(ResolverPipelineOverridableConstantTest, WithoutId) {
+ auto* a = Override("a", ty.f32(), Expr(1.f));
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ExpectConstantId(a, 0u);
+}
+
+TEST_F(ResolverPipelineOverridableConstantTest, WithAndWithoutIds) {
+ std::vector<ast::Variable*> variables;
+ auto* a = Override("a", ty.f32(), Expr(1.f));
+ auto* b = Override("b", ty.f32(), Expr(1.f));
+ auto* c = Override("c", ty.f32(), Expr(1.f), {Id(2u)});
+ auto* d = Override("d", ty.f32(), Expr(1.f), {Id(4u)});
+ auto* e = Override("e", ty.f32(), Expr(1.f));
+ auto* f = Override("f", ty.f32(), Expr(1.f), {Id(1u)});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ // Verify that constant id allocation order is deterministic.
+ ExpectConstantId(a, 0u);
+ ExpectConstantId(b, 3u);
+ ExpectConstantId(c, 2u);
+ ExpectConstantId(d, 4u);
+ ExpectConstantId(e, 5u);
+ ExpectConstantId(f, 1u);
+}
+
+TEST_F(ResolverPipelineOverridableConstantTest, DuplicateIds) {
+ Override("a", ty.f32(), Expr(1.f), {Id(Source{{12, 34}}, 7u)});
+ Override("b", ty.f32(), Expr(1.f), {Id(Source{{56, 78}}, 7u)});
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(), R"(56:78 error: pipeline constant IDs must be unique
+12:34 note: a pipeline constant with an ID of 7 was previously declared here:)");
+}
+
+TEST_F(ResolverPipelineOverridableConstantTest, IdTooLarge) {
+ Override("a", ty.f32(), Expr(1.f), {Id(Source{{12, 34}}, 65536u)});
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(),
+ "12:34 error: pipeline constant IDs must be between 0 and 65535");
+}
+
+} // namespace
+} // namespace resolver
+} // namespace tint
diff --git a/src/tint/resolver/ptr_ref_test.cc b/src/tint/resolver/ptr_ref_test.cc
new file mode 100644
index 0000000..fa26304
--- /dev/null
+++ b/src/tint/resolver/ptr_ref_test.cc
@@ -0,0 +1,126 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ast/struct_block_attribute.h"
+#include "src/tint/resolver/resolver.h"
+#include "src/tint/resolver/resolver_test_helper.h"
+#include "src/tint/sem/reference_type.h"
+
+#include "gmock/gmock.h"
+
+namespace tint {
+namespace resolver {
+namespace {
+
+struct ResolverPtrRefTest : public resolver::TestHelper,
+ public testing::Test {};
+
+TEST_F(ResolverPtrRefTest, AddressOf) {
+ // var v : i32;
+ // &v
+
+ auto* v = Var("v", ty.i32(), ast::StorageClass::kNone);
+ auto* expr = AddressOf(v);
+
+ WrapInFunction(v, expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_TRUE(TypeOf(expr)->Is<sem::Pointer>());
+ EXPECT_TRUE(TypeOf(expr)->As<sem::Pointer>()->StoreType()->Is<sem::I32>());
+ EXPECT_EQ(TypeOf(expr)->As<sem::Pointer>()->StorageClass(),
+ ast::StorageClass::kFunction);
+}
+
+TEST_F(ResolverPtrRefTest, AddressOfThenDeref) {
+ // var v : i32;
+ // *(&v)
+
+ auto* v = Var("v", ty.i32(), ast::StorageClass::kNone);
+ auto* expr = Deref(AddressOf(v));
+
+ WrapInFunction(v, expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_TRUE(TypeOf(expr)->Is<sem::Reference>());
+ EXPECT_TRUE(TypeOf(expr)->As<sem::Reference>()->StoreType()->Is<sem::I32>());
+}
+
+TEST_F(ResolverPtrRefTest, DefaultPtrStorageClass) {
+ // https://gpuweb.github.io/gpuweb/wgsl/#storage-class
+
+ auto* buf = Structure("S", {Member("m", ty.i32())},
+ {create<ast::StructBlockAttribute>()});
+ auto* function = Var("f", ty.i32());
+ auto* private_ = Global("p", ty.i32(), ast::StorageClass::kPrivate);
+ auto* workgroup = Global("w", ty.i32(), ast::StorageClass::kWorkgroup);
+ auto* uniform = Global("ub", ty.Of(buf), ast::StorageClass::kUniform,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(0),
+ create<ast::GroupAttribute>(0),
+ });
+ auto* storage = Global("sb", ty.Of(buf), ast::StorageClass::kStorage,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(1),
+ create<ast::GroupAttribute>(0),
+ });
+
+ auto* function_ptr =
+ Const("f_ptr", ty.pointer(ty.i32(), ast::StorageClass::kFunction),
+ AddressOf(function));
+ auto* private_ptr =
+ Const("p_ptr", ty.pointer(ty.i32(), ast::StorageClass::kPrivate),
+ AddressOf(private_));
+ auto* workgroup_ptr =
+ Const("w_ptr", ty.pointer(ty.i32(), ast::StorageClass::kWorkgroup),
+ AddressOf(workgroup));
+ auto* uniform_ptr =
+ Const("ub_ptr", ty.pointer(ty.Of(buf), ast::StorageClass::kUniform),
+ AddressOf(uniform));
+ auto* storage_ptr =
+ Const("sb_ptr", ty.pointer(ty.Of(buf), ast::StorageClass::kStorage),
+ AddressOf(storage));
+
+ WrapInFunction(function, function_ptr, private_ptr, workgroup_ptr,
+ uniform_ptr, storage_ptr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_TRUE(TypeOf(function_ptr)->Is<sem::Pointer>())
+ << "function_ptr is " << TypeOf(function_ptr)->TypeInfo().name;
+ ASSERT_TRUE(TypeOf(private_ptr)->Is<sem::Pointer>())
+ << "private_ptr is " << TypeOf(private_ptr)->TypeInfo().name;
+ ASSERT_TRUE(TypeOf(workgroup_ptr)->Is<sem::Pointer>())
+ << "workgroup_ptr is " << TypeOf(workgroup_ptr)->TypeInfo().name;
+ ASSERT_TRUE(TypeOf(uniform_ptr)->Is<sem::Pointer>())
+ << "uniform_ptr is " << TypeOf(uniform_ptr)->TypeInfo().name;
+ ASSERT_TRUE(TypeOf(storage_ptr)->Is<sem::Pointer>())
+ << "storage_ptr is " << TypeOf(storage_ptr)->TypeInfo().name;
+
+ EXPECT_EQ(TypeOf(function_ptr)->As<sem::Pointer>()->Access(),
+ ast::Access::kReadWrite);
+ EXPECT_EQ(TypeOf(private_ptr)->As<sem::Pointer>()->Access(),
+ ast::Access::kReadWrite);
+ EXPECT_EQ(TypeOf(workgroup_ptr)->As<sem::Pointer>()->Access(),
+ ast::Access::kReadWrite);
+ EXPECT_EQ(TypeOf(uniform_ptr)->As<sem::Pointer>()->Access(),
+ ast::Access::kRead);
+ EXPECT_EQ(TypeOf(storage_ptr)->As<sem::Pointer>()->Access(),
+ ast::Access::kRead);
+}
+
+} // namespace
+} // namespace resolver
+} // namespace tint
diff --git a/src/tint/resolver/ptr_ref_validation_test.cc b/src/tint/resolver/ptr_ref_validation_test.cc
new file mode 100644
index 0000000..2bec6ef
--- /dev/null
+++ b/src/tint/resolver/ptr_ref_validation_test.cc
@@ -0,0 +1,176 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ast/bitcast_expression.h"
+#include "src/tint/ast/struct_block_attribute.h"
+#include "src/tint/resolver/resolver.h"
+#include "src/tint/resolver/resolver_test_helper.h"
+#include "src/tint/sem/reference_type.h"
+
+#include "gmock/gmock.h"
+
+namespace tint {
+namespace resolver {
+namespace {
+
+struct ResolverPtrRefValidationTest : public resolver::TestHelper,
+ public testing::Test {};
+
+TEST_F(ResolverPtrRefValidationTest, AddressOfLiteral) {
+ // &1
+
+ auto* expr = AddressOf(Expr(Source{{12, 34}}, 1));
+
+ WrapInFunction(expr);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(), "12:34 error: cannot take the address of expression");
+}
+
+TEST_F(ResolverPtrRefValidationTest, AddressOfLet) {
+ // let l : i32 = 1;
+ // &l
+ auto* l = Const("l", ty.i32(), Expr(1));
+ auto* expr = AddressOf(Expr(Source{{12, 34}}, "l"));
+
+ WrapInFunction(l, expr);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(), "12:34 error: cannot take the address of expression");
+}
+
+TEST_F(ResolverPtrRefValidationTest, AddressOfHandle) {
+ // @group(0) @binding(0) var t: texture_3d<f32>;
+ // &t
+ Global("t", ty.sampled_texture(ast::TextureDimension::k3d, ty.f32()),
+ GroupAndBinding(0u, 0u));
+ auto* expr = AddressOf(Expr(Source{{12, 34}}, "t"));
+ WrapInFunction(expr);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: cannot take the address of expression in handle "
+ "storage class");
+}
+
+TEST_F(ResolverPtrRefValidationTest, AddressOfVectorComponent_MemberAccessor) {
+ // var v : vec4<i32>;
+ // &v.y
+ auto* v = Var("v", ty.vec4<i32>());
+ auto* expr = AddressOf(MemberAccessor(Source{{12, 34}}, "v", "y"));
+
+ WrapInFunction(v, expr);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(),
+ "12:34 error: cannot take the address of a vector component");
+}
+
+TEST_F(ResolverPtrRefValidationTest, AddressOfVectorComponent_IndexAccessor) {
+ // var v : vec4<i32>;
+ // &v[2]
+ auto* v = Var("v", ty.vec4<i32>());
+ auto* expr = AddressOf(IndexAccessor(Source{{12, 34}}, "v", 2));
+
+ WrapInFunction(v, expr);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(),
+ "12:34 error: cannot take the address of a vector component");
+}
+
+TEST_F(ResolverPtrRefValidationTest, IndirectOfAddressOfHandle) {
+ // @group(0) @binding(0) var t: texture_3d<f32>;
+ // *&t
+ Global("t", ty.sampled_texture(ast::TextureDimension::k3d, ty.f32()),
+ GroupAndBinding(0u, 0u));
+ auto* expr = Deref(AddressOf(Expr(Source{{12, 34}}, "t")));
+ WrapInFunction(expr);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: cannot take the address of expression in handle "
+ "storage class");
+}
+
+TEST_F(ResolverPtrRefValidationTest, DerefOfLiteral) {
+ // *1
+
+ auto* expr = Deref(Expr(Source{{12, 34}}, 1));
+
+ WrapInFunction(expr);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(),
+ "12:34 error: cannot dereference expression of type 'i32'");
+}
+
+TEST_F(ResolverPtrRefValidationTest, DerefOfVar) {
+ // var v : i32 = 1;
+ // *1
+ auto* v = Var("v", ty.i32());
+ auto* expr = Deref(Expr(Source{{12, 34}}, "v"));
+
+ WrapInFunction(v, expr);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(),
+ "12:34 error: cannot dereference expression of type 'i32'");
+}
+
+TEST_F(ResolverPtrRefValidationTest, InferredPtrAccessMismatch) {
+ // struct Inner {
+ // arr: array<i32, 4>;
+ // }
+ // [[block]] struct S {
+ // inner: Inner;
+ // }
+ // @group(0) @binding(0) var<storage, read_write> s : S;
+ // fn f() {
+ // let p : pointer<storage, i32> = &s.inner.arr[2];
+ // }
+ auto* inner = Structure("Inner", {Member("arr", ty.array<i32, 4>())});
+ auto* buf = Structure("S", {Member("inner", ty.Of(inner))},
+ {create<ast::StructBlockAttribute>()});
+ auto* storage = Global("s", ty.Of(buf), ast::StorageClass::kStorage,
+ ast::Access::kReadWrite,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(0),
+ create<ast::GroupAttribute>(0),
+ });
+
+ auto* expr =
+ IndexAccessor(MemberAccessor(MemberAccessor(storage, "inner"), "arr"), 4);
+ auto* ptr =
+ Const(Source{{12, 34}}, "p", ty.pointer<i32>(ast::StorageClass::kStorage),
+ AddressOf(expr));
+
+ WrapInFunction(ptr);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: cannot initialize let of type "
+ "'ptr<storage, i32, read>' with value of type "
+ "'ptr<storage, i32, read_write>'");
+}
+
+} // namespace
+} // namespace resolver
+} // namespace tint
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
new file mode 100644
index 0000000..42f8cb1
--- /dev/null
+++ b/src/tint/resolver/resolver.cc
@@ -0,0 +1,2917 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/resolver/resolver.h"
+
+#include <algorithm>
+#include <cmath>
+#include <iomanip>
+#include <limits>
+#include <utility>
+
+#include "src/tint/ast/alias.h"
+#include "src/tint/ast/array.h"
+#include "src/tint/ast/assignment_statement.h"
+#include "src/tint/ast/bitcast_expression.h"
+#include "src/tint/ast/break_statement.h"
+#include "src/tint/ast/call_statement.h"
+#include "src/tint/ast/continue_statement.h"
+#include "src/tint/ast/depth_texture.h"
+#include "src/tint/ast/disable_validation_attribute.h"
+#include "src/tint/ast/discard_statement.h"
+#include "src/tint/ast/fallthrough_statement.h"
+#include "src/tint/ast/for_loop_statement.h"
+#include "src/tint/ast/id_attribute.h"
+#include "src/tint/ast/if_statement.h"
+#include "src/tint/ast/internal_attribute.h"
+#include "src/tint/ast/interpolate_attribute.h"
+#include "src/tint/ast/loop_statement.h"
+#include "src/tint/ast/matrix.h"
+#include "src/tint/ast/pointer.h"
+#include "src/tint/ast/return_statement.h"
+#include "src/tint/ast/sampled_texture.h"
+#include "src/tint/ast/sampler.h"
+#include "src/tint/ast/storage_texture.h"
+#include "src/tint/ast/switch_statement.h"
+#include "src/tint/ast/traverse_expressions.h"
+#include "src/tint/ast/type_name.h"
+#include "src/tint/ast/unary_op_expression.h"
+#include "src/tint/ast/variable_decl_statement.h"
+#include "src/tint/ast/vector.h"
+#include "src/tint/ast/workgroup_attribute.h"
+#include "src/tint/sem/array.h"
+#include "src/tint/sem/atomic_type.h"
+#include "src/tint/sem/call.h"
+#include "src/tint/sem/depth_multisampled_texture_type.h"
+#include "src/tint/sem/depth_texture_type.h"
+#include "src/tint/sem/for_loop_statement.h"
+#include "src/tint/sem/function.h"
+#include "src/tint/sem/if_statement.h"
+#include "src/tint/sem/loop_statement.h"
+#include "src/tint/sem/member_accessor_expression.h"
+#include "src/tint/sem/module.h"
+#include "src/tint/sem/multisampled_texture_type.h"
+#include "src/tint/sem/pointer_type.h"
+#include "src/tint/sem/reference_type.h"
+#include "src/tint/sem/sampled_texture_type.h"
+#include "src/tint/sem/sampler_type.h"
+#include "src/tint/sem/statement.h"
+#include "src/tint/sem/storage_texture_type.h"
+#include "src/tint/sem/struct.h"
+#include "src/tint/sem/switch_statement.h"
+#include "src/tint/sem/type_constructor.h"
+#include "src/tint/sem/type_conversion.h"
+#include "src/tint/sem/variable.h"
+#include "src/tint/utils/defer.h"
+#include "src/tint/utils/math.h"
+#include "src/tint/utils/reverse.h"
+#include "src/tint/utils/scoped_assignment.h"
+#include "src/tint/utils/transform.h"
+
+namespace tint {
+namespace resolver {
+
+Resolver::Resolver(ProgramBuilder* builder)
+ : builder_(builder),
+ diagnostics_(builder->Diagnostics()),
+ builtin_table_(BuiltinTable::Create(*builder)) {}
+
+Resolver::~Resolver() = default;
+
+bool Resolver::Resolve() {
+ if (builder_->Diagnostics().contains_errors()) {
+ return false;
+ }
+
+ if (!DependencyGraph::Build(builder_->AST(), builder_->Symbols(),
+ builder_->Diagnostics(), dependencies_)) {
+ return false;
+ }
+
+ // Create the semantic module
+ builder_->Sem().SetModule(
+ builder_->create<sem::Module>(dependencies_.ordered_globals));
+
+ bool result = ResolveInternal();
+
+ if (!result && !diagnostics_.contains_errors()) {
+ TINT_ICE(Resolver, diagnostics_)
+ << "resolving failed, but no error was raised";
+ return false;
+ }
+
+ return result;
+}
+
+bool Resolver::ResolveInternal() {
+ Mark(&builder_->AST());
+
+ // Process all module-scope declarations in dependency order.
+ for (auto* decl : dependencies_.ordered_globals) {
+ Mark(decl);
+ if (!Switch(
+ decl, //
+ [&](const ast::TypeDecl* td) { //
+ return TypeDecl(td) != nullptr;
+ },
+ [&](const ast::Function* func) {
+ return Function(func) != nullptr;
+ },
+ [&](const ast::Variable* var) {
+ return GlobalVariable(var) != nullptr;
+ },
+ [&](Default) {
+ TINT_UNREACHABLE(Resolver, diagnostics_)
+ << "unhandled global declaration: " << decl->TypeInfo().name;
+ return false;
+ })) {
+ return false;
+ }
+ }
+
+ AllocateOverridableConstantIds();
+
+ SetShadows();
+
+ if (!ValidatePipelineStages()) {
+ return false;
+ }
+
+ bool result = true;
+ for (auto* node : builder_->ASTNodes().Objects()) {
+ if (marked_.count(node) == 0) {
+ TINT_ICE(Resolver, diagnostics_) << "AST node '" << node->TypeInfo().name
+ << "' was not reached by the resolver\n"
+ << "At: " << node->source << "\n"
+ << "Pointer: " << node;
+ result = false;
+ }
+ }
+
+ return result;
+}
+
+sem::Type* Resolver::Type(const ast::Type* ty) {
+ Mark(ty);
+ auto* s = Switch(
+ ty,
+ [&](const ast::Void*) -> sem::Type* {
+ return builder_->create<sem::Void>();
+ },
+ [&](const ast::Bool*) -> sem::Type* {
+ return builder_->create<sem::Bool>();
+ },
+ [&](const ast::I32*) -> sem::Type* {
+ return builder_->create<sem::I32>();
+ },
+ [&](const ast::U32*) -> sem::Type* {
+ return builder_->create<sem::U32>();
+ },
+ [&](const ast::F32*) -> sem::Type* {
+ return builder_->create<sem::F32>();
+ },
+ [&](const ast::Vector* t) -> sem::Type* {
+ if (!t->type) {
+ AddError("missing vector element type", t->source.End());
+ return nullptr;
+ }
+ if (auto* el = Type(t->type)) {
+ if (auto* vector = builder_->create<sem::Vector>(el, t->width)) {
+ if (ValidateVector(vector, t->source)) {
+ return vector;
+ }
+ }
+ }
+ return nullptr;
+ },
+ [&](const ast::Matrix* t) -> sem::Type* {
+ if (!t->type) {
+ AddError("missing matrix element type", t->source.End());
+ return nullptr;
+ }
+ if (auto* el = Type(t->type)) {
+ if (auto* column_type = builder_->create<sem::Vector>(el, t->rows)) {
+ if (auto* matrix =
+ builder_->create<sem::Matrix>(column_type, t->columns)) {
+ if (ValidateMatrix(matrix, t->source)) {
+ return matrix;
+ }
+ }
+ }
+ }
+ return nullptr;
+ },
+ [&](const ast::Array* t) -> sem::Type* { return Array(t); },
+ [&](const ast::Atomic* t) -> sem::Type* {
+ if (auto* el = Type(t->type)) {
+ auto* a = builder_->create<sem::Atomic>(el);
+ if (!ValidateAtomic(t, a)) {
+ return nullptr;
+ }
+ return a;
+ }
+ return nullptr;
+ },
+ [&](const ast::Pointer* t) -> sem::Type* {
+ if (auto* el = Type(t->type)) {
+ auto access = t->access;
+ if (access == ast::kUndefined) {
+ access = DefaultAccessForStorageClass(t->storage_class);
+ }
+ return builder_->create<sem::Pointer>(el, t->storage_class, access);
+ }
+ return nullptr;
+ },
+ [&](const ast::Sampler* t) -> sem::Type* {
+ return builder_->create<sem::Sampler>(t->kind);
+ },
+ [&](const ast::SampledTexture* t) -> sem::Type* {
+ if (auto* el = Type(t->type)) {
+ return builder_->create<sem::SampledTexture>(t->dim, el);
+ }
+ return nullptr;
+ },
+ [&](const ast::MultisampledTexture* t) -> sem::Type* {
+ if (auto* el = Type(t->type)) {
+ return builder_->create<sem::MultisampledTexture>(t->dim, el);
+ }
+ return nullptr;
+ },
+ [&](const ast::DepthTexture* t) -> sem::Type* {
+ return builder_->create<sem::DepthTexture>(t->dim);
+ },
+ [&](const ast::DepthMultisampledTexture* t) -> sem::Type* {
+ return builder_->create<sem::DepthMultisampledTexture>(t->dim);
+ },
+ [&](const ast::StorageTexture* t) -> sem::Type* {
+ if (auto* el = Type(t->type)) {
+ if (!ValidateStorageTexture(t)) {
+ return nullptr;
+ }
+ return builder_->create<sem::StorageTexture>(t->dim, t->format,
+ t->access, el);
+ }
+ return nullptr;
+ },
+ [&](const ast::ExternalTexture*) -> sem::Type* {
+ return builder_->create<sem::ExternalTexture>();
+ },
+ [&](Default) -> sem::Type* {
+ auto* resolved = ResolvedSymbol(ty);
+ return Switch(
+ resolved, //
+ [&](sem::Type* type) { return type; },
+ [&](sem::Variable* var) {
+ auto name =
+ builder_->Symbols().NameFor(var->Declaration()->symbol);
+ AddError("cannot use variable '" + name + "' as type",
+ ty->source);
+ AddNote("'" + name + "' declared here",
+ var->Declaration()->source);
+ return nullptr;
+ },
+ [&](sem::Function* func) {
+ auto name =
+ builder_->Symbols().NameFor(func->Declaration()->symbol);
+ AddError("cannot use function '" + name + "' as type",
+ ty->source);
+ AddNote("'" + name + "' declared here",
+ func->Declaration()->source);
+ return nullptr;
+ },
+ [&](Default) {
+ TINT_UNREACHABLE(Resolver, diagnostics_)
+ << "Unhandled resolved type '"
+ << (resolved ? resolved->TypeInfo().name : "<null>")
+ << "' resolved from ast::Type '" << ty->TypeInfo().name
+ << "'";
+ return nullptr;
+ });
+ });
+
+ if (s) {
+ builder_->Sem().Add(ty, s);
+ }
+ return s;
+}
+
+sem::Variable* Resolver::Variable(const ast::Variable* var,
+ VariableKind kind,
+ uint32_t index /* = 0 */) {
+ const sem::Type* storage_ty = nullptr;
+
+ // If the variable has a declared type, resolve it.
+ if (auto* ty = var->type) {
+ storage_ty = Type(ty);
+ if (!storage_ty) {
+ return nullptr;
+ }
+ }
+
+ const sem::Expression* rhs = nullptr;
+
+ // Does the variable have a constructor?
+ if (var->constructor) {
+ rhs = Expression(var->constructor);
+ if (!rhs) {
+ return nullptr;
+ }
+
+ // If the variable has no declared type, infer it from the RHS
+ if (!storage_ty) {
+ if (!var->is_const && kind == VariableKind::kGlobal) {
+ AddError("global var declaration must specify a type", var->source);
+ return nullptr;
+ }
+
+ storage_ty = rhs->Type()->UnwrapRef(); // Implicit load of RHS
+ }
+ } else if (var->is_const && !var->is_overridable &&
+ kind != VariableKind::kParameter) {
+ AddError("let declaration must have an initializer", var->source);
+ return nullptr;
+ } else if (!var->type) {
+ AddError(
+ (kind == VariableKind::kGlobal)
+ ? "module scope var declaration requires a type and initializer"
+ : "function scope var declaration requires a type or initializer",
+ var->source);
+ return nullptr;
+ }
+
+ if (!storage_ty) {
+ TINT_ICE(Resolver, diagnostics_)
+ << "failed to determine storage type for variable '" +
+ builder_->Symbols().NameFor(var->symbol) + "'\n"
+ << "Source: " << var->source;
+ return nullptr;
+ }
+
+ auto storage_class = var->declared_storage_class;
+ if (storage_class == ast::StorageClass::kNone && !var->is_const) {
+ // No declared storage class. Infer from usage / type.
+ if (kind == VariableKind::kLocal) {
+ storage_class = ast::StorageClass::kFunction;
+ } else if (storage_ty->UnwrapRef()->is_handle()) {
+ // https://gpuweb.github.io/gpuweb/wgsl/#module-scope-variables
+ // If the store type is a texture type or a sampler type, then the
+ // variable declaration must not have a storage class attribute. The
+ // storage class will always be handle.
+ storage_class = ast::StorageClass::kUniformConstant;
+ }
+ }
+
+ if (kind == VariableKind::kLocal && !var->is_const &&
+ storage_class != ast::StorageClass::kFunction &&
+ IsValidationEnabled(var->attributes,
+ ast::DisabledValidation::kIgnoreStorageClass)) {
+ AddError("function variable has a non-function storage class", var->source);
+ return nullptr;
+ }
+
+ auto access = var->declared_access;
+ if (access == ast::Access::kUndefined) {
+ access = DefaultAccessForStorageClass(storage_class);
+ }
+
+ auto* var_ty = storage_ty;
+ if (!var->is_const) {
+ // Variable declaration. Unlike `let`, `var` has storage.
+ // Variables are always of a reference type to the declared storage type.
+ var_ty =
+ builder_->create<sem::Reference>(storage_ty, storage_class, access);
+ }
+
+ if (rhs && !ValidateVariableConstructorOrCast(var, storage_class, storage_ty,
+ rhs->Type())) {
+ return nullptr;
+ }
+
+ if (!ApplyStorageClassUsageToType(
+ storage_class, const_cast<sem::Type*>(var_ty), var->source)) {
+ AddNote(
+ std::string("while instantiating ") +
+ ((kind == VariableKind::kParameter) ? "parameter " : "variable ") +
+ builder_->Symbols().NameFor(var->symbol),
+ var->source);
+ return nullptr;
+ }
+
+ if (kind == VariableKind::kParameter) {
+ if (auto* ptr = var_ty->As<sem::Pointer>()) {
+ // For MSL, we push module-scope variables into the entry point as pointer
+ // parameters, so we also need to handle their store type.
+ if (!ApplyStorageClassUsageToType(
+ ptr->StorageClass(), const_cast<sem::Type*>(ptr->StoreType()),
+ var->source)) {
+ AddNote("while instantiating parameter " +
+ builder_->Symbols().NameFor(var->symbol),
+ var->source);
+ return nullptr;
+ }
+ }
+ }
+
+ switch (kind) {
+ case VariableKind::kGlobal: {
+ sem::BindingPoint binding_point;
+ if (auto bp = var->BindingPoint()) {
+ binding_point = {bp.group->value, bp.binding->value};
+ }
+
+ bool has_const_val = rhs && var->is_const && !var->is_overridable;
+ auto* global = builder_->create<sem::GlobalVariable>(
+ var, var_ty, storage_class, access,
+ has_const_val ? rhs->ConstantValue() : sem::Constant{},
+ binding_point);
+
+ if (var->is_overridable) {
+ global->SetIsOverridable();
+ if (auto* id = ast::GetAttribute<ast::IdAttribute>(var->attributes)) {
+ global->SetConstantId(static_cast<uint16_t>(id->value));
+ }
+ }
+
+ global->SetConstructor(rhs);
+
+ builder_->Sem().Add(var, global);
+ return global;
+ }
+ case VariableKind::kLocal: {
+ auto* local = builder_->create<sem::LocalVariable>(
+ var, var_ty, storage_class, access, current_statement_,
+ (rhs && var->is_const) ? rhs->ConstantValue() : sem::Constant{});
+ builder_->Sem().Add(var, local);
+ local->SetConstructor(rhs);
+ return local;
+ }
+ case VariableKind::kParameter: {
+ auto* param = builder_->create<sem::Parameter>(var, index, var_ty,
+ storage_class, access);
+ builder_->Sem().Add(var, param);
+ return param;
+ }
+ }
+
+ TINT_UNREACHABLE(Resolver, diagnostics_)
+ << "unhandled VariableKind " << static_cast<int>(kind);
+ return nullptr;
+}
+
+ast::Access Resolver::DefaultAccessForStorageClass(
+ ast::StorageClass storage_class) {
+ // https://gpuweb.github.io/gpuweb/wgsl/#storage-class
+ switch (storage_class) {
+ case ast::StorageClass::kStorage:
+ case ast::StorageClass::kUniform:
+ case ast::StorageClass::kUniformConstant:
+ return ast::Access::kRead;
+ default:
+ break;
+ }
+ return ast::Access::kReadWrite;
+}
+
+void Resolver::AllocateOverridableConstantIds() {
+ // The next pipeline constant ID to try to allocate.
+ uint16_t next_constant_id = 0;
+
+ // Allocate constant IDs in global declaration order, so that they are
+ // deterministic.
+ // TODO(crbug.com/tint/1192): If a transform changes the order or removes an
+ // unused constant, the allocation may change on the next Resolver pass.
+ for (auto* decl : builder_->AST().GlobalDeclarations()) {
+ auto* var = decl->As<ast::Variable>();
+ if (!var || !var->is_overridable) {
+ continue;
+ }
+
+ uint16_t constant_id;
+ if (auto* id_attr = ast::GetAttribute<ast::IdAttribute>(var->attributes)) {
+ constant_id = static_cast<uint16_t>(id_attr->value);
+ } else {
+ // No ID was specified, so allocate the next available ID.
+ constant_id = next_constant_id;
+ while (constant_ids_.count(constant_id)) {
+ if (constant_id == UINT16_MAX) {
+ TINT_ICE(Resolver, builder_->Diagnostics())
+ << "no more pipeline constant IDs available";
+ return;
+ }
+ constant_id++;
+ }
+ next_constant_id = constant_id + 1;
+ }
+
+ auto* sem = Sem<sem::GlobalVariable>(var);
+ const_cast<sem::GlobalVariable*>(sem)->SetConstantId(constant_id);
+ }
+}
+
+void Resolver::SetShadows() {
+ for (auto it : dependencies_.shadows) {
+ Switch(
+ Sem(it.first), //
+ [&](sem::LocalVariable* local) { local->SetShadows(Sem(it.second)); },
+ [&](sem::Parameter* param) { param->SetShadows(Sem(it.second)); });
+ }
+} // namespace resolver
+
+sem::GlobalVariable* Resolver::GlobalVariable(const ast::Variable* var) {
+ auto* sem = Variable(var, VariableKind::kGlobal);
+ if (!sem) {
+ return nullptr;
+ }
+
+ auto storage_class = sem->StorageClass();
+ if (!var->is_const && storage_class == ast::StorageClass::kNone) {
+ AddError("global variables must have a storage class", var->source);
+ return nullptr;
+ }
+ if (var->is_const && storage_class != ast::StorageClass::kNone) {
+ AddError("global constants shouldn't have a storage class", var->source);
+ return nullptr;
+ }
+
+ for (auto* attr : var->attributes) {
+ Mark(attr);
+
+ if (auto* id_attr = attr->As<ast::IdAttribute>()) {
+ // Track the constant IDs that are specified in the shader.
+ constant_ids_.emplace(id_attr->value, sem);
+ }
+ }
+
+ if (!ValidateNoDuplicateAttributes(var->attributes)) {
+ return nullptr;
+ }
+
+ if (!ValidateGlobalVariable(sem)) {
+ return nullptr;
+ }
+
+ // TODO(bclayton): Call this at the end of resolve on all uniform and storage
+ // referenced structs
+ if (!ValidateStorageClassLayout(sem)) {
+ return nullptr;
+ }
+
+ return sem->As<sem::GlobalVariable>();
+}
+
+sem::Function* Resolver::Function(const ast::Function* decl) {
+ uint32_t parameter_index = 0;
+ std::unordered_map<Symbol, Source> parameter_names;
+ std::vector<sem::Parameter*> parameters;
+
+ // Resolve all the parameters
+ for (auto* param : decl->params) {
+ Mark(param);
+
+ { // Check the parameter name is unique for the function
+ auto emplaced = parameter_names.emplace(param->symbol, param->source);
+ if (!emplaced.second) {
+ auto name = builder_->Symbols().NameFor(param->symbol);
+ AddError("redefinition of parameter '" + name + "'", param->source);
+ AddNote("previous definition is here", emplaced.first->second);
+ return nullptr;
+ }
+ }
+
+ auto* var = As<sem::Parameter>(
+ Variable(param, VariableKind::kParameter, parameter_index++));
+ if (!var) {
+ return nullptr;
+ }
+
+ for (auto* attr : param->attributes) {
+ Mark(attr);
+ }
+ if (!ValidateNoDuplicateAttributes(param->attributes)) {
+ return nullptr;
+ }
+
+ parameters.emplace_back(var);
+
+ auto* var_ty = const_cast<sem::Type*>(var->Type());
+ if (auto* str = var_ty->As<sem::Struct>()) {
+ switch (decl->PipelineStage()) {
+ case ast::PipelineStage::kVertex:
+ str->AddUsage(sem::PipelineStageUsage::kVertexInput);
+ break;
+ case ast::PipelineStage::kFragment:
+ str->AddUsage(sem::PipelineStageUsage::kFragmentInput);
+ break;
+ case ast::PipelineStage::kCompute:
+ str->AddUsage(sem::PipelineStageUsage::kComputeInput);
+ break;
+ case ast::PipelineStage::kNone:
+ break;
+ }
+ }
+ }
+
+ // Resolve the return type
+ sem::Type* return_type = nullptr;
+ if (auto* ty = decl->return_type) {
+ return_type = Type(ty);
+ if (!return_type) {
+ return nullptr;
+ }
+ } else {
+ return_type = builder_->create<sem::Void>();
+ }
+
+ if (auto* str = return_type->As<sem::Struct>()) {
+ if (!ApplyStorageClassUsageToType(ast::StorageClass::kNone, str,
+ decl->source)) {
+ AddNote("while instantiating return type for " +
+ builder_->Symbols().NameFor(decl->symbol),
+ decl->source);
+ return nullptr;
+ }
+
+ switch (decl->PipelineStage()) {
+ case ast::PipelineStage::kVertex:
+ str->AddUsage(sem::PipelineStageUsage::kVertexOutput);
+ break;
+ case ast::PipelineStage::kFragment:
+ str->AddUsage(sem::PipelineStageUsage::kFragmentOutput);
+ break;
+ case ast::PipelineStage::kCompute:
+ str->AddUsage(sem::PipelineStageUsage::kComputeOutput);
+ break;
+ case ast::PipelineStage::kNone:
+ break;
+ }
+ }
+
+ auto* func = builder_->create<sem::Function>(decl, return_type, parameters);
+ builder_->Sem().Add(decl, func);
+
+ TINT_SCOPED_ASSIGNMENT(current_function_, func);
+
+ if (!WorkgroupSize(decl)) {
+ return nullptr;
+ }
+
+ if (decl->IsEntryPoint()) {
+ entry_points_.emplace_back(func);
+ }
+
+ if (decl->body) {
+ Mark(decl->body);
+ if (current_compound_statement_) {
+ TINT_ICE(Resolver, diagnostics_)
+ << "Resolver::Function() called with a current compound statement";
+ return nullptr;
+ }
+ auto* body = StatementScope(
+ decl->body, builder_->create<sem::FunctionBlockStatement>(func),
+ [&] { return Statements(decl->body->statements); });
+ if (!body) {
+ return nullptr;
+ }
+ func->Behaviors() = body->Behaviors();
+ if (func->Behaviors().Contains(sem::Behavior::kReturn)) {
+ // https://www.w3.org/TR/WGSL/#behaviors-rules
+ // We assign a behavior to each function: it is its body’s behavior
+ // (treating the body as a regular statement), with any "Return" replaced
+ // by "Next".
+ func->Behaviors().Remove(sem::Behavior::kReturn);
+ func->Behaviors().Add(sem::Behavior::kNext);
+ }
+ }
+
+ for (auto* attr : decl->attributes) {
+ Mark(attr);
+ }
+ if (!ValidateNoDuplicateAttributes(decl->attributes)) {
+ return nullptr;
+ }
+
+ for (auto* attr : decl->return_type_attributes) {
+ Mark(attr);
+ }
+ if (!ValidateNoDuplicateAttributes(decl->return_type_attributes)) {
+ return nullptr;
+ }
+
+ if (!ValidateFunction(func)) {
+ return nullptr;
+ }
+
+ // If this is an entry point, mark all transitively called functions as being
+ // used by this entry point.
+ if (decl->IsEntryPoint()) {
+ for (auto* f : func->TransitivelyCalledFunctions()) {
+ const_cast<sem::Function*>(f)->AddAncestorEntryPoint(func);
+ }
+ }
+
+ return func;
+}
+
+bool Resolver::WorkgroupSize(const ast::Function* func) {
+ // Set work-group size defaults.
+ sem::WorkgroupSize ws;
+ for (int i = 0; i < 3; i++) {
+ ws[i].value = 1;
+ ws[i].overridable_const = nullptr;
+ }
+
+ auto* attr = ast::GetAttribute<ast::WorkgroupAttribute>(func->attributes);
+ if (!attr) {
+ return true;
+ }
+
+ auto values = attr->Values();
+ auto any_i32 = false;
+ auto any_u32 = false;
+ for (int i = 0; i < 3; i++) {
+ // Each argument to this attribute can either be a literal, an
+ // identifier for a module-scope constants, or nullptr if not specified.
+
+ auto* expr = values[i];
+ if (!expr) {
+ // Not specified, just use the default.
+ continue;
+ }
+
+ auto* expr_sem = Expression(expr);
+ if (!expr_sem) {
+ return false;
+ }
+
+ constexpr const char* kErrBadType =
+ "workgroup_size argument must be either literal or module-scope "
+ "constant of type i32 or u32";
+ constexpr const char* kErrInconsistentType =
+ "workgroup_size arguments must be of the same type, either i32 "
+ "or u32";
+
+ auto* ty = TypeOf(expr);
+ bool is_i32 = ty->UnwrapRef()->Is<sem::I32>();
+ bool is_u32 = ty->UnwrapRef()->Is<sem::U32>();
+ if (!is_i32 && !is_u32) {
+ AddError(kErrBadType, expr->source);
+ return false;
+ }
+
+ any_i32 = any_i32 || is_i32;
+ any_u32 = any_u32 || is_u32;
+ if (any_i32 && any_u32) {
+ AddError(kErrInconsistentType, expr->source);
+ return false;
+ }
+
+ sem::Constant value;
+
+ if (auto* user = Sem(expr)->As<sem::VariableUser>()) {
+ // We have an variable of a module-scope constant.
+ auto* decl = user->Variable()->Declaration();
+ if (!decl->is_const) {
+ AddError(kErrBadType, expr->source);
+ return false;
+ }
+ // Capture the constant if it is pipeline-overridable.
+ if (decl->is_overridable) {
+ ws[i].overridable_const = decl;
+ }
+
+ if (decl->constructor) {
+ value = Sem(decl->constructor)->ConstantValue();
+ } else {
+ // No constructor means this value must be overriden by the user.
+ ws[i].value = 0;
+ continue;
+ }
+ } else if (expr->Is<ast::LiteralExpression>()) {
+ value = Sem(expr)->ConstantValue();
+ } else {
+ AddError(
+ "workgroup_size argument must be either a literal or a "
+ "module-scope constant",
+ values[i]->source);
+ return false;
+ }
+
+ if (!value) {
+ TINT_ICE(Resolver, diagnostics_)
+ << "could not resolve constant workgroup_size constant value";
+ continue;
+ }
+ // Validate and set the default value for this dimension.
+ if (is_i32 ? value.Elements()[0].i32 < 1 : value.Elements()[0].u32 < 1) {
+ AddError("workgroup_size argument must be at least 1", values[i]->source);
+ return false;
+ }
+
+ ws[i].value = is_i32 ? static_cast<uint32_t>(value.Elements()[0].i32)
+ : value.Elements()[0].u32;
+ }
+
+ current_function_->SetWorkgroupSize(std::move(ws));
+ return true;
+}
+
+bool Resolver::Statements(const ast::StatementList& stmts) {
+ sem::Behaviors behaviors{sem::Behavior::kNext};
+
+ bool reachable = true;
+ for (auto* stmt : stmts) {
+ Mark(stmt);
+ auto* sem = Statement(stmt);
+ if (!sem) {
+ return false;
+ }
+ // s1 s2:(B1∖{Next}) ∪ B2
+ sem->SetIsReachable(reachable);
+ if (reachable) {
+ behaviors = (behaviors - sem::Behavior::kNext) + sem->Behaviors();
+ }
+ reachable = reachable && sem->Behaviors().Contains(sem::Behavior::kNext);
+ }
+
+ current_statement_->Behaviors() = behaviors;
+
+ if (!ValidateStatements(stmts)) {
+ return false;
+ }
+
+ return true;
+}
+
+sem::Statement* Resolver::Statement(const ast::Statement* stmt) {
+ return Switch(
+ stmt,
+ // Compound statements. These create their own sem::CompoundStatement
+ // bindings.
+ [&](const ast::BlockStatement* b) -> sem::Statement* {
+ return BlockStatement(b);
+ },
+ [&](const ast::ForLoopStatement* l) -> sem::Statement* {
+ return ForLoopStatement(l);
+ },
+ [&](const ast::LoopStatement* l) -> sem::Statement* {
+ return LoopStatement(l);
+ },
+ [&](const ast::IfStatement* i) -> sem::Statement* {
+ return IfStatement(i);
+ },
+ [&](const ast::SwitchStatement* s) -> sem::Statement* {
+ return SwitchStatement(s);
+ },
+
+ // Non-Compound statements
+ [&](const ast::AssignmentStatement* a) -> sem::Statement* {
+ return AssignmentStatement(a);
+ },
+ [&](const ast::BreakStatement* b) -> sem::Statement* {
+ return BreakStatement(b);
+ },
+ [&](const ast::CallStatement* c) -> sem::Statement* {
+ return CallStatement(c);
+ },
+ [&](const ast::ContinueStatement* c) -> sem::Statement* {
+ return ContinueStatement(c);
+ },
+ [&](const ast::DiscardStatement* d) -> sem::Statement* {
+ return DiscardStatement(d);
+ },
+ [&](const ast::FallthroughStatement* f) -> sem::Statement* {
+ return FallthroughStatement(f);
+ },
+ [&](const ast::ReturnStatement* r) -> sem::Statement* {
+ return ReturnStatement(r);
+ },
+ [&](const ast::VariableDeclStatement* v) -> sem::Statement* {
+ return VariableDeclStatement(v);
+ },
+
+ // Error cases
+ [&](const ast::CaseStatement*) -> sem::Statement* {
+ AddError("case statement can only be used inside a switch statement",
+ stmt->source);
+ return nullptr;
+ },
+ [&](const ast::ElseStatement*) -> sem::Statement* {
+ TINT_ICE(Resolver, diagnostics_)
+ << "Resolver::Statement() encountered an Else statement. Else "
+ "statements are embedded in If statements, so should never be "
+ "encountered as top-level statements";
+ return nullptr;
+ },
+ [&](Default) -> sem::Statement* {
+ AddError(
+ "unknown statement type: " + std::string(stmt->TypeInfo().name),
+ stmt->source);
+ return nullptr;
+ });
+}
+
+sem::CaseStatement* Resolver::CaseStatement(const ast::CaseStatement* stmt) {
+ auto* sem = builder_->create<sem::CaseStatement>(
+ stmt, current_compound_statement_, current_function_);
+ return StatementScope(stmt, sem, [&] {
+ for (auto* sel : stmt->selectors) {
+ Mark(sel);
+ }
+ Mark(stmt->body);
+ auto* body = BlockStatement(stmt->body);
+ if (!body) {
+ return false;
+ }
+ sem->SetBlock(body);
+ sem->Behaviors() = body->Behaviors();
+ return true;
+ });
+}
+
+sem::IfStatement* Resolver::IfStatement(const ast::IfStatement* stmt) {
+ auto* sem = builder_->create<sem::IfStatement>(
+ stmt, current_compound_statement_, current_function_);
+ return StatementScope(stmt, sem, [&] {
+ auto* cond = Expression(stmt->condition);
+ if (!cond) {
+ return false;
+ }
+ sem->SetCondition(cond);
+ sem->Behaviors() = cond->Behaviors();
+ sem->Behaviors().Remove(sem::Behavior::kNext);
+
+ Mark(stmt->body);
+ auto* body = builder_->create<sem::BlockStatement>(
+ stmt->body, current_compound_statement_, current_function_);
+ if (!StatementScope(stmt->body, body,
+ [&] { return Statements(stmt->body->statements); })) {
+ return false;
+ }
+ sem->Behaviors().Add(body->Behaviors());
+
+ for (auto* else_stmt : stmt->else_statements) {
+ Mark(else_stmt);
+ auto* else_sem = ElseStatement(else_stmt);
+ if (!else_sem) {
+ return false;
+ }
+ sem->Behaviors().Add(else_sem->Behaviors());
+ }
+
+ if (stmt->else_statements.empty() ||
+ stmt->else_statements.back()->condition != nullptr) {
+ // https://www.w3.org/TR/WGSL/#behaviors-rules
+ // if statements without an else branch are treated as if they had an
+ // empty else branch (which adds Next to their behavior)
+ sem->Behaviors().Add(sem::Behavior::kNext);
+ }
+
+ return ValidateIfStatement(sem);
+ });
+}
+
+sem::ElseStatement* Resolver::ElseStatement(const ast::ElseStatement* stmt) {
+ auto* sem = builder_->create<sem::ElseStatement>(
+ stmt, current_compound_statement_->As<sem::IfStatement>(),
+ current_function_);
+ return StatementScope(stmt, sem, [&] {
+ if (auto* cond_expr = stmt->condition) {
+ auto* cond = Expression(cond_expr);
+ if (!cond) {
+ return false;
+ }
+ sem->SetCondition(cond);
+ // https://www.w3.org/TR/WGSL/#behaviors-rules
+ // if statements with else if branches are treated as if they were nested
+ // simple if/else statements
+ sem->Behaviors() = cond->Behaviors();
+ }
+ sem->Behaviors().Remove(sem::Behavior::kNext);
+
+ Mark(stmt->body);
+ auto* body = builder_->create<sem::BlockStatement>(
+ stmt->body, current_compound_statement_, current_function_);
+ if (!StatementScope(stmt->body, body,
+ [&] { return Statements(stmt->body->statements); })) {
+ return false;
+ }
+ sem->Behaviors().Add(body->Behaviors());
+
+ return ValidateElseStatement(sem);
+ });
+}
+
+sem::BlockStatement* Resolver::BlockStatement(const ast::BlockStatement* stmt) {
+ auto* sem = builder_->create<sem::BlockStatement>(
+ stmt->As<ast::BlockStatement>(), current_compound_statement_,
+ current_function_);
+ return StatementScope(stmt, sem,
+ [&] { return Statements(stmt->statements); });
+}
+
+sem::LoopStatement* Resolver::LoopStatement(const ast::LoopStatement* stmt) {
+ auto* sem = builder_->create<sem::LoopStatement>(
+ stmt, current_compound_statement_, current_function_);
+ return StatementScope(stmt, sem, [&] {
+ Mark(stmt->body);
+
+ auto* body = builder_->create<sem::LoopBlockStatement>(
+ stmt->body, current_compound_statement_, current_function_);
+ return StatementScope(stmt->body, body, [&] {
+ if (!Statements(stmt->body->statements)) {
+ return false;
+ }
+ auto& behaviors = sem->Behaviors();
+ behaviors = body->Behaviors();
+
+ if (stmt->continuing) {
+ Mark(stmt->continuing);
+ if (!stmt->continuing->Empty()) {
+ auto* continuing = StatementScope(
+ stmt->continuing,
+ builder_->create<sem::LoopContinuingBlockStatement>(
+ stmt->continuing, current_compound_statement_,
+ current_function_),
+ [&] { return Statements(stmt->continuing->statements); });
+ if (!continuing) {
+ return false;
+ }
+ behaviors.Add(continuing->Behaviors());
+ }
+ }
+
+ if (behaviors.Contains(sem::Behavior::kBreak)) { // Does the loop exit?
+ behaviors.Add(sem::Behavior::kNext);
+ } else {
+ behaviors.Remove(sem::Behavior::kNext);
+ }
+ behaviors.Remove(sem::Behavior::kBreak, sem::Behavior::kContinue);
+
+ return ValidateLoopStatement(sem);
+ });
+ });
+}
+
+sem::ForLoopStatement* Resolver::ForLoopStatement(
+ const ast::ForLoopStatement* stmt) {
+ auto* sem = builder_->create<sem::ForLoopStatement>(
+ stmt, current_compound_statement_, current_function_);
+ return StatementScope(stmt, sem, [&] {
+ auto& behaviors = sem->Behaviors();
+ if (auto* initializer = stmt->initializer) {
+ Mark(initializer);
+ auto* init = Statement(initializer);
+ if (!init) {
+ return false;
+ }
+ behaviors.Add(init->Behaviors());
+ }
+
+ if (auto* cond_expr = stmt->condition) {
+ auto* cond = Expression(cond_expr);
+ if (!cond) {
+ return false;
+ }
+ sem->SetCondition(cond);
+ behaviors.Add(cond->Behaviors());
+ }
+
+ if (auto* continuing = stmt->continuing) {
+ Mark(continuing);
+ auto* cont = Statement(continuing);
+ if (!cont) {
+ return false;
+ }
+ behaviors.Add(cont->Behaviors());
+ }
+
+ Mark(stmt->body);
+
+ auto* body = builder_->create<sem::LoopBlockStatement>(
+ stmt->body, current_compound_statement_, current_function_);
+ if (!StatementScope(stmt->body, body,
+ [&] { return Statements(stmt->body->statements); })) {
+ return false;
+ }
+
+ behaviors.Add(body->Behaviors());
+ if (stmt->condition ||
+ behaviors.Contains(sem::Behavior::kBreak)) { // Does the loop exit?
+ behaviors.Add(sem::Behavior::kNext);
+ } else {
+ behaviors.Remove(sem::Behavior::kNext);
+ }
+ behaviors.Remove(sem::Behavior::kBreak, sem::Behavior::kContinue);
+
+ return ValidateForLoopStatement(sem);
+ });
+}
+
+sem::Expression* Resolver::Expression(const ast::Expression* root) {
+ std::vector<const ast::Expression*> sorted;
+ bool mark_failed = false;
+ if (!ast::TraverseExpressions<ast::TraverseOrder::RightToLeft>(
+ root, diagnostics_, [&](const ast::Expression* expr) {
+ if (!Mark(expr)) {
+ mark_failed = true;
+ return ast::TraverseAction::Stop;
+ }
+ sorted.emplace_back(expr);
+ return ast::TraverseAction::Descend;
+ })) {
+ return nullptr;
+ }
+
+ if (mark_failed) {
+ return nullptr;
+ }
+
+ for (auto* expr : utils::Reverse(sorted)) {
+ auto* sem_expr = Switch(
+ expr,
+ [&](const ast::IndexAccessorExpression* array) -> sem::Expression* {
+ return IndexAccessor(array);
+ },
+ [&](const ast::BinaryExpression* bin_op) -> sem::Expression* {
+ return Binary(bin_op);
+ },
+ [&](const ast::BitcastExpression* bitcast) -> sem::Expression* {
+ return Bitcast(bitcast);
+ },
+ [&](const ast::CallExpression* call) -> sem::Expression* {
+ return Call(call);
+ },
+ [&](const ast::IdentifierExpression* ident) -> sem::Expression* {
+ return Identifier(ident);
+ },
+ [&](const ast::LiteralExpression* literal) -> sem::Expression* {
+ return Literal(literal);
+ },
+ [&](const ast::MemberAccessorExpression* member) -> sem::Expression* {
+ return MemberAccessor(member);
+ },
+ [&](const ast::UnaryOpExpression* unary) -> sem::Expression* {
+ return UnaryOp(unary);
+ },
+ [&](const ast::PhonyExpression*) -> sem::Expression* {
+ return builder_->create<sem::Expression>(
+ expr, builder_->create<sem::Void>(), current_statement_,
+ sem::Constant{}, /* has_side_effects */ false);
+ },
+ [&](Default) {
+ TINT_ICE(Resolver, diagnostics_)
+ << "unhandled expression type: " << expr->TypeInfo().name;
+ return nullptr;
+ });
+ if (!sem_expr) {
+ return nullptr;
+ }
+
+ builder_->Sem().Add(expr, sem_expr);
+ if (expr == root) {
+ return sem_expr;
+ }
+ }
+
+ TINT_ICE(Resolver, diagnostics_) << "Expression() did not find root node";
+ return nullptr;
+}
+
+sem::Expression* Resolver::IndexAccessor(
+ const ast::IndexAccessorExpression* expr) {
+ auto* idx = Sem(expr->index);
+ auto* obj = Sem(expr->object);
+ auto* obj_raw_ty = obj->Type();
+ auto* obj_ty = obj_raw_ty->UnwrapRef();
+ auto* ty = Switch(
+ obj_ty, //
+ [&](const sem::Array* arr) -> const sem::Type* {
+ return arr->ElemType();
+ },
+ [&](const sem::Vector* vec) -> const sem::Type* { //
+ return vec->type();
+ },
+ [&](const sem::Matrix* mat) -> const sem::Type* {
+ return builder_->create<sem::Vector>(mat->type(), mat->rows());
+ },
+ [&](Default) -> const sem::Type* {
+ AddError("cannot index type '" + TypeNameOf(obj_ty) + "'",
+ expr->source);
+ return nullptr;
+ });
+ if (ty == nullptr) {
+ return nullptr;
+ }
+
+ auto* idx_ty = idx->Type()->UnwrapRef();
+ if (!idx_ty->IsAnyOf<sem::I32, sem::U32>()) {
+ AddError("index must be of type 'i32' or 'u32', found: '" +
+ TypeNameOf(idx_ty) + "'",
+ idx->Declaration()->source);
+ return nullptr;
+ }
+
+ // If we're extracting from a reference, we return a reference.
+ if (auto* ref = obj_raw_ty->As<sem::Reference>()) {
+ ty = builder_->create<sem::Reference>(ty, ref->StorageClass(),
+ ref->Access());
+ }
+
+ auto val = EvaluateConstantValue(expr, ty);
+ bool has_side_effects = idx->HasSideEffects() || obj->HasSideEffects();
+ auto* sem = builder_->create<sem::Expression>(expr, ty, current_statement_,
+ val, has_side_effects);
+ sem->Behaviors() = idx->Behaviors() + obj->Behaviors();
+ return sem;
+}
+
+sem::Expression* Resolver::Bitcast(const ast::BitcastExpression* expr) {
+ auto* inner = Sem(expr->expr);
+ auto* ty = Type(expr->type);
+ if (!ty) {
+ return nullptr;
+ }
+
+ auto val = EvaluateConstantValue(expr, ty);
+ auto* sem = builder_->create<sem::Expression>(expr, ty, current_statement_,
+ val, inner->HasSideEffects());
+
+ sem->Behaviors() = inner->Behaviors();
+
+ if (!ValidateBitcast(expr, ty)) {
+ return nullptr;
+ }
+
+ return sem;
+}
+
+sem::Call* Resolver::Call(const ast::CallExpression* expr) {
+ std::vector<const sem::Expression*> args(expr->args.size());
+ std::vector<const sem::Type*> arg_tys(args.size());
+ sem::Behaviors arg_behaviors;
+
+ // The element type of all the arguments. Nullptr if argument types are
+ // different.
+ const sem::Type* arg_el_ty = nullptr;
+
+ for (size_t i = 0; i < expr->args.size(); i++) {
+ auto* arg = Sem(expr->args[i]);
+ if (!arg) {
+ return nullptr;
+ }
+ args[i] = arg;
+ arg_tys[i] = args[i]->Type();
+ arg_behaviors.Add(arg->Behaviors());
+
+ // Determine the common argument element type
+ auto* el_ty = arg_tys[i]->UnwrapRef();
+ if (auto* vec = el_ty->As<sem::Vector>()) {
+ el_ty = vec->type();
+ } else if (auto* mat = el_ty->As<sem::Matrix>()) {
+ el_ty = mat->type();
+ }
+ if (i == 0) {
+ arg_el_ty = el_ty;
+ } else if (arg_el_ty != el_ty) {
+ arg_el_ty = nullptr;
+ }
+ }
+
+ arg_behaviors.Remove(sem::Behavior::kNext);
+
+ auto type_ctor_or_conv = [&](const sem::Type* ty) -> sem::Call* {
+ // The call has resolved to a type constructor or cast.
+ if (args.size() == 1) {
+ auto* target = ty;
+ auto* source = args[0]->Type()->UnwrapRef();
+ if ((source != target) && //
+ ((source->is_scalar() && target->is_scalar()) ||
+ (source->Is<sem::Vector>() && target->Is<sem::Vector>()) ||
+ (source->Is<sem::Matrix>() && target->Is<sem::Matrix>()))) {
+ // Note: Matrix types currently cannot be converted (the element type
+ // must only be f32). We implement this for the day we support other
+ // matrix element types.
+ return TypeConversion(expr, ty, args[0], arg_tys[0]);
+ }
+ }
+ return TypeConstructor(expr, ty, std::move(args), std::move(arg_tys));
+ };
+
+ // Resolve the target of the CallExpression to determine whether this is a
+ // function call, cast or type constructor expression.
+ if (expr->target.type) {
+ const sem::Type* ty = nullptr;
+
+ auto err_cannot_infer_el_ty = [&](std::string name) {
+ AddError(
+ "cannot infer " + name +
+ " element type, as constructor arguments have different types",
+ expr->source);
+ for (size_t i = 0; i < args.size(); i++) {
+ auto* arg = args[i];
+ AddNote("argument " + std::to_string(i) + " has type " +
+ arg->Type()->FriendlyName(builder_->Symbols()),
+ arg->Declaration()->source);
+ }
+ };
+
+ if (!expr->args.empty()) {
+ // vecN() without explicit element type?
+ // Try to infer element type from args
+ if (auto* vec = expr->target.type->As<ast::Vector>()) {
+ if (!vec->type) {
+ if (!arg_el_ty) {
+ err_cannot_infer_el_ty("vector");
+ return nullptr;
+ }
+
+ Mark(vec);
+ auto* v = builder_->create<sem::Vector>(
+ arg_el_ty, static_cast<uint32_t>(vec->width));
+ if (!ValidateVector(v, vec->source)) {
+ return nullptr;
+ }
+ builder_->Sem().Add(vec, v);
+ ty = v;
+ }
+ }
+
+ // matNxM() without explicit element type?
+ // Try to infer element type from args
+ if (auto* mat = expr->target.type->As<ast::Matrix>()) {
+ if (!mat->type) {
+ if (!arg_el_ty) {
+ err_cannot_infer_el_ty("matrix");
+ return nullptr;
+ }
+
+ Mark(mat);
+ auto* column_type =
+ builder_->create<sem::Vector>(arg_el_ty, mat->rows);
+ auto* m = builder_->create<sem::Matrix>(column_type, mat->columns);
+ if (!ValidateMatrix(m, mat->source)) {
+ return nullptr;
+ }
+ builder_->Sem().Add(mat, m);
+ ty = m;
+ }
+ }
+ }
+
+ if (ty == nullptr) {
+ ty = Type(expr->target.type);
+ if (!ty) {
+ return nullptr;
+ }
+ }
+
+ return type_ctor_or_conv(ty);
+ }
+
+ auto* ident = expr->target.name;
+ Mark(ident);
+
+ auto* resolved = ResolvedSymbol(ident);
+ return Switch(
+ resolved, //
+ [&](sem::Type* type) { return type_ctor_or_conv(type); },
+ [&](sem::Function* func) {
+ return FunctionCall(expr, func, std::move(args), arg_behaviors);
+ },
+ [&](sem::Variable* var) {
+ auto name = builder_->Symbols().NameFor(var->Declaration()->symbol);
+ AddError("cannot call variable '" + name + "'", ident->source);
+ AddNote("'" + name + "' declared here", var->Declaration()->source);
+ return nullptr;
+ },
+ [&](Default) -> sem::Call* {
+ auto name = builder_->Symbols().NameFor(ident->symbol);
+ auto builtin_type = sem::ParseBuiltinType(name);
+ if (builtin_type != sem::BuiltinType::kNone) {
+ return BuiltinCall(expr, builtin_type, std::move(args),
+ std::move(arg_tys));
+ }
+
+ TINT_ICE(Resolver, diagnostics_)
+ << expr->source << " unresolved CallExpression target:\n"
+ << "resolved: " << (resolved ? resolved->TypeInfo().name : "<null>")
+ << "\n"
+ << "name: " << builder_->Symbols().NameFor(ident->symbol);
+ return nullptr;
+ });
+}
+
+sem::Call* Resolver::BuiltinCall(const ast::CallExpression* expr,
+ sem::BuiltinType builtin_type,
+ const std::vector<const sem::Expression*> args,
+ const std::vector<const sem::Type*> arg_tys) {
+ auto* builtin =
+ builtin_table_->Lookup(builtin_type, std::move(arg_tys), expr->source);
+ if (!builtin) {
+ return nullptr;
+ }
+
+ if (builtin->IsDeprecated()) {
+ AddWarning("use of deprecated builtin", expr->source);
+ }
+
+ bool has_side_effects = builtin->HasSideEffects() ||
+ std::any_of(args.begin(), args.end(), [](auto* e) {
+ return e->HasSideEffects();
+ });
+ auto* call = builder_->create<sem::Call>(expr, builtin, std::move(args),
+ current_statement_, sem::Constant{},
+ has_side_effects);
+
+ current_function_->AddDirectlyCalledBuiltin(builtin);
+
+ if (IsTextureBuiltin(builtin_type)) {
+ if (!ValidateTextureBuiltinFunction(call)) {
+ return nullptr;
+ }
+ // Collect a texture/sampler pair for this builtin.
+ const auto& signature = builtin->Signature();
+ int texture_index = signature.IndexOf(sem::ParameterUsage::kTexture);
+ if (texture_index == -1) {
+ TINT_ICE(Resolver, diagnostics_)
+ << "texture builtin without texture parameter";
+ }
+
+ auto* texture = args[texture_index]->As<sem::VariableUser>()->Variable();
+ if (!texture->Type()->UnwrapRef()->Is<sem::StorageTexture>()) {
+ int sampler_index = signature.IndexOf(sem::ParameterUsage::kSampler);
+ const sem::Variable* sampler =
+ sampler_index != -1
+ ? args[sampler_index]->As<sem::VariableUser>()->Variable()
+ : nullptr;
+ current_function_->AddTextureSamplerPair(texture, sampler);
+ }
+ }
+
+ if (!ValidateBuiltinCall(call)) {
+ return nullptr;
+ }
+
+ current_function_->AddDirectCall(call);
+
+ return call;
+}
+
+sem::Call* Resolver::FunctionCall(
+ const ast::CallExpression* expr,
+ sem::Function* target,
+ const std::vector<const sem::Expression*> args,
+ sem::Behaviors arg_behaviors) {
+ auto sym = expr->target.name->symbol;
+ auto name = builder_->Symbols().NameFor(sym);
+
+ // TODO(crbug.com/tint/1420): For now, assume all function calls have side
+ // effects.
+ bool has_side_effects = true;
+ auto* call = builder_->create<sem::Call>(expr, target, std::move(args),
+ current_statement_, sem::Constant{},
+ has_side_effects);
+
+ if (current_function_) {
+ // Note: Requires called functions to be resolved first.
+ // This is currently guaranteed as functions must be declared before
+ // use.
+ current_function_->AddTransitivelyCalledFunction(target);
+ current_function_->AddDirectCall(call);
+ for (auto* transitive_call : target->TransitivelyCalledFunctions()) {
+ current_function_->AddTransitivelyCalledFunction(transitive_call);
+ }
+
+ // We inherit any referenced variables from the callee.
+ for (auto* var : target->TransitivelyReferencedGlobals()) {
+ current_function_->AddTransitivelyReferencedGlobal(var);
+ }
+
+ // Map all texture/sampler pairs from the target function to the
+ // current function. These can only be global or parameter
+ // variables. Resolve any parameter variables to the corresponding
+ // argument passed to the current function. Leave global variables
+ // as-is. Then add the mapped pair to the current function's list of
+ // texture/sampler pairs.
+ for (sem::VariablePair pair : target->TextureSamplerPairs()) {
+ const sem::Variable* texture = pair.first;
+ const sem::Variable* sampler = pair.second;
+ if (auto* param = texture->As<sem::Parameter>()) {
+ texture = args[param->Index()]->As<sem::VariableUser>()->Variable();
+ }
+ if (sampler) {
+ if (auto* param = sampler->As<sem::Parameter>()) {
+ sampler = args[param->Index()]->As<sem::VariableUser>()->Variable();
+ }
+ }
+ current_function_->AddTextureSamplerPair(texture, sampler);
+ }
+ }
+
+ target->AddCallSite(call);
+
+ call->Behaviors() = arg_behaviors + target->Behaviors();
+
+ if (!ValidateFunctionCall(call)) {
+ return nullptr;
+ }
+
+ return call;
+}
+
+sem::Call* Resolver::TypeConversion(const ast::CallExpression* expr,
+ const sem::Type* target,
+ const sem::Expression* arg,
+ const sem::Type* source) {
+ // It is not valid to have a type-cast call expression inside a call
+ // statement.
+ if (IsCallStatement(expr)) {
+ AddError("type cast evaluated but not used", expr->source);
+ return nullptr;
+ }
+
+ auto* call_target = utils::GetOrCreate(
+ type_conversions_, TypeConversionSig{target, source},
+ [&]() -> sem::TypeConversion* {
+ // Now that the argument types have been determined, make sure that
+ // they obey the conversion rules laid out in
+ // https://gpuweb.github.io/gpuweb/wgsl/#conversion-expr.
+ bool ok = Switch(
+ target,
+ [&](const sem::Vector* vec_type) {
+ return ValidateVectorConstructorOrCast(expr, vec_type);
+ },
+ [&](const sem::Matrix* mat_type) {
+ // Note: Matrix types currently cannot be converted (the element
+ // type must only be f32). We implement this for the day we
+ // support other matrix element types.
+ return ValidateMatrixConstructorOrCast(expr, mat_type);
+ },
+ [&](const sem::Array* arr_type) {
+ return ValidateArrayConstructorOrCast(expr, arr_type);
+ },
+ [&](const sem::Struct* struct_type) {
+ return ValidateStructureConstructorOrCast(expr, struct_type);
+ },
+ [&](Default) {
+ if (target->is_scalar()) {
+ return ValidateScalarConstructorOrCast(expr, target);
+ }
+ AddError("type is not constructible", expr->source);
+ return false;
+ });
+ if (!ok) {
+ return nullptr;
+ }
+
+ auto* param = builder_->create<sem::Parameter>(
+ nullptr, // declaration
+ 0, // index
+ source->UnwrapRef(), // type
+ ast::StorageClass::kNone, // storage_class
+ ast::Access::kUndefined); // access
+ return builder_->create<sem::TypeConversion>(target, param);
+ });
+
+ if (!call_target) {
+ return nullptr;
+ }
+
+ auto val = EvaluateConstantValue(expr, target);
+ bool has_side_effects = arg->HasSideEffects();
+ return builder_->create<sem::Call>(expr, call_target,
+ std::vector<const sem::Expression*>{arg},
+ current_statement_, val, has_side_effects);
+}
+
+sem::Call* Resolver::TypeConstructor(
+ const ast::CallExpression* expr,
+ const sem::Type* ty,
+ const std::vector<const sem::Expression*> args,
+ const std::vector<const sem::Type*> arg_tys) {
+ // It is not valid to have a type-constructor call expression as a call
+ // statement.
+ if (IsCallStatement(expr)) {
+ AddError("type constructor evaluated but not used", expr->source);
+ return nullptr;
+ }
+
+ auto* call_target = utils::GetOrCreate(
+ type_ctors_, TypeConstructorSig{ty, arg_tys},
+ [&]() -> sem::TypeConstructor* {
+ // Now that the argument types have been determined, make sure that
+ // they obey the constructor type rules laid out in
+ // https://gpuweb.github.io/gpuweb/wgsl/#type-constructor-expr.
+ bool ok = Switch(
+ ty,
+ [&](const sem::Vector* vec_type) {
+ return ValidateVectorConstructorOrCast(expr, vec_type);
+ },
+ [&](const sem::Matrix* mat_type) {
+ return ValidateMatrixConstructorOrCast(expr, mat_type);
+ },
+ [&](const sem::Array* arr_type) {
+ return ValidateArrayConstructorOrCast(expr, arr_type);
+ },
+ [&](const sem::Struct* struct_type) {
+ return ValidateStructureConstructorOrCast(expr, struct_type);
+ },
+ [&](Default) {
+ if (ty->is_scalar()) {
+ return ValidateScalarConstructorOrCast(expr, ty);
+ }
+ AddError("type is not constructible", expr->source);
+ return false;
+ });
+ if (!ok) {
+ return nullptr;
+ }
+
+ return builder_->create<sem::TypeConstructor>(
+ ty, utils::Transform(
+ arg_tys,
+ [&](const sem::Type* t, size_t i) -> const sem::Parameter* {
+ return builder_->create<sem::Parameter>(
+ nullptr, // declaration
+ static_cast<uint32_t>(i), // index
+ t->UnwrapRef(), // type
+ ast::StorageClass::kNone, // storage_class
+ ast::Access::kUndefined); // access
+ }));
+ });
+
+ if (!call_target) {
+ return nullptr;
+ }
+
+ auto val = EvaluateConstantValue(expr, ty);
+ bool has_side_effects = std::any_of(
+ args.begin(), args.end(), [](auto* e) { return e->HasSideEffects(); });
+ return builder_->create<sem::Call>(expr, call_target, std::move(args),
+ current_statement_, val, has_side_effects);
+}
+
+sem::Expression* Resolver::Literal(const ast::LiteralExpression* literal) {
+ auto* ty = TypeOf(literal);
+ if (!ty) {
+ return nullptr;
+ }
+
+ auto val = EvaluateConstantValue(literal, ty);
+ return builder_->create<sem::Expression>(literal, ty, current_statement_, val,
+ /* has_side_effects */ false);
+}
+
+sem::Expression* Resolver::Identifier(const ast::IdentifierExpression* expr) {
+ auto symbol = expr->symbol;
+ auto* resolved = ResolvedSymbol(expr);
+ if (auto* var = As<sem::Variable>(resolved)) {
+ auto* user =
+ builder_->create<sem::VariableUser>(expr, current_statement_, var);
+
+ if (current_statement_) {
+ // If identifier is part of a loop continuing block, make sure it
+ // doesn't refer to a variable that is bypassed by a continue statement
+ // in the loop's body block.
+ if (auto* continuing_block =
+ current_statement_
+ ->FindFirstParent<sem::LoopContinuingBlockStatement>()) {
+ auto* loop_block =
+ continuing_block->FindFirstParent<sem::LoopBlockStatement>();
+ if (loop_block->FirstContinue()) {
+ auto& decls = loop_block->Decls();
+ // If our identifier is in loop_block->decls, make sure its index is
+ // less than first_continue
+ auto iter =
+ std::find_if(decls.begin(), decls.end(),
+ [&symbol](auto* v) { return v->symbol == symbol; });
+ if (iter != decls.end()) {
+ auto var_decl_index =
+ static_cast<size_t>(std::distance(decls.begin(), iter));
+ if (var_decl_index >= loop_block->NumDeclsAtFirstContinue()) {
+ AddError("continue statement bypasses declaration of '" +
+ builder_->Symbols().NameFor(symbol) + "'",
+ loop_block->FirstContinue()->source);
+ AddNote("identifier '" + builder_->Symbols().NameFor(symbol) +
+ "' declared here",
+ (*iter)->source);
+ AddNote("identifier '" + builder_->Symbols().NameFor(symbol) +
+ "' referenced in continuing block here",
+ expr->source);
+ return nullptr;
+ }
+ }
+ }
+ }
+ }
+
+ if (current_function_) {
+ if (auto* global = var->As<sem::GlobalVariable>()) {
+ current_function_->AddDirectlyReferencedGlobal(global);
+ }
+ }
+
+ var->AddUser(user);
+ return user;
+ }
+
+ if (Is<sem::Function>(resolved)) {
+ AddError("missing '(' for function call", expr->source.End());
+ return nullptr;
+ }
+
+ if (IsBuiltin(symbol)) {
+ AddError("missing '(' for builtin call", expr->source.End());
+ return nullptr;
+ }
+
+ if (resolved->Is<sem::Type>()) {
+ AddError("missing '(' for type constructor or cast", expr->source.End());
+ return nullptr;
+ }
+
+ TINT_ICE(Resolver, diagnostics_)
+ << expr->source << " unresolved identifier:\n"
+ << "resolved: " << (resolved ? resolved->TypeInfo().name : "<null>")
+ << "\n"
+ << "name: " << builder_->Symbols().NameFor(symbol);
+ return nullptr;
+}
+
+sem::Expression* Resolver::MemberAccessor(
+ const ast::MemberAccessorExpression* expr) {
+ auto* structure = TypeOf(expr->structure);
+ auto* storage_ty = structure->UnwrapRef();
+
+ const sem::Type* ret = nullptr;
+ std::vector<uint32_t> swizzle;
+
+ if (auto* str = storage_ty->As<sem::Struct>()) {
+ Mark(expr->member);
+ auto symbol = expr->member->symbol;
+
+ const sem::StructMember* member = nullptr;
+ for (auto* m : str->Members()) {
+ if (m->Name() == symbol) {
+ ret = m->Type();
+ member = m;
+ break;
+ }
+ }
+
+ if (ret == nullptr) {
+ AddError(
+ "struct member " + builder_->Symbols().NameFor(symbol) + " not found",
+ expr->source);
+ return nullptr;
+ }
+
+ // If we're extracting from a reference, we return a reference.
+ if (auto* ref = structure->As<sem::Reference>()) {
+ ret = builder_->create<sem::Reference>(ret, ref->StorageClass(),
+ ref->Access());
+ }
+
+ // Structure may be a side-effecting expression (e.g. function call).
+ auto* sem_structure = Sem(expr->structure);
+ bool has_side_effects = sem_structure && sem_structure->HasSideEffects();
+
+ return builder_->create<sem::StructMemberAccess>(
+ expr, ret, current_statement_, member, has_side_effects);
+ }
+
+ if (auto* vec = storage_ty->As<sem::Vector>()) {
+ Mark(expr->member);
+ std::string s = builder_->Symbols().NameFor(expr->member->symbol);
+ auto size = s.size();
+ swizzle.reserve(s.size());
+
+ for (auto c : s) {
+ switch (c) {
+ case 'x':
+ case 'r':
+ swizzle.emplace_back(0);
+ break;
+ case 'y':
+ case 'g':
+ swizzle.emplace_back(1);
+ break;
+ case 'z':
+ case 'b':
+ swizzle.emplace_back(2);
+ break;
+ case 'w':
+ case 'a':
+ swizzle.emplace_back(3);
+ break;
+ default:
+ AddError("invalid vector swizzle character",
+ expr->member->source.Begin() + swizzle.size());
+ return nullptr;
+ }
+
+ if (swizzle.back() >= vec->Width()) {
+ AddError("invalid vector swizzle member", expr->member->source);
+ return nullptr;
+ }
+ }
+
+ if (size < 1 || size > 4) {
+ AddError("invalid vector swizzle size", expr->member->source);
+ return nullptr;
+ }
+
+ // All characters are valid, check if they're being mixed
+ auto is_rgba = [](char c) {
+ return c == 'r' || c == 'g' || c == 'b' || c == 'a';
+ };
+ auto is_xyzw = [](char c) {
+ return c == 'x' || c == 'y' || c == 'z' || c == 'w';
+ };
+ if (!std::all_of(s.begin(), s.end(), is_rgba) &&
+ !std::all_of(s.begin(), s.end(), is_xyzw)) {
+ AddError("invalid mixing of vector swizzle characters rgba with xyzw",
+ expr->member->source);
+ return nullptr;
+ }
+
+ if (size == 1) {
+ // A single element swizzle is just the type of the vector.
+ ret = vec->type();
+ // If we're extracting from a reference, we return a reference.
+ if (auto* ref = structure->As<sem::Reference>()) {
+ ret = builder_->create<sem::Reference>(ret, ref->StorageClass(),
+ ref->Access());
+ }
+ } else {
+ // The vector will have a number of components equal to the length of
+ // the swizzle.
+ ret = builder_->create<sem::Vector>(vec->type(),
+ static_cast<uint32_t>(size));
+ }
+ return builder_->create<sem::Swizzle>(expr, ret, current_statement_,
+ std::move(swizzle));
+ }
+
+ AddError(
+ "invalid member accessor expression. Expected vector or struct, got '" +
+ TypeNameOf(storage_ty) + "'",
+ expr->structure->source);
+ return nullptr;
+}
+
+sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) {
+ using Bool = sem::Bool;
+ using F32 = sem::F32;
+ using I32 = sem::I32;
+ using U32 = sem::U32;
+ using Matrix = sem::Matrix;
+ using Vector = sem::Vector;
+
+ auto* lhs = Sem(expr->lhs);
+ auto* rhs = Sem(expr->rhs);
+
+ auto* lhs_ty = lhs->Type()->UnwrapRef();
+ auto* rhs_ty = rhs->Type()->UnwrapRef();
+
+ auto* lhs_vec = lhs_ty->As<Vector>();
+ auto* lhs_vec_elem_type = lhs_vec ? lhs_vec->type() : nullptr;
+ auto* rhs_vec = rhs_ty->As<Vector>();
+ auto* rhs_vec_elem_type = rhs_vec ? rhs_vec->type() : nullptr;
+
+ const bool matching_vec_elem_types =
+ lhs_vec_elem_type && rhs_vec_elem_type &&
+ (lhs_vec_elem_type == rhs_vec_elem_type) &&
+ (lhs_vec->Width() == rhs_vec->Width());
+
+ const bool matching_types = matching_vec_elem_types || (lhs_ty == rhs_ty);
+
+ auto build = [&](const sem::Type* ty) {
+ auto val = EvaluateConstantValue(expr, ty);
+ bool has_side_effects = lhs->HasSideEffects() || rhs->HasSideEffects();
+ auto* sem = builder_->create<sem::Expression>(expr, ty, current_statement_,
+ val, has_side_effects);
+ sem->Behaviors() = lhs->Behaviors() + rhs->Behaviors();
+ return sem;
+ };
+
+ // Binary logical expressions
+ if (expr->IsLogicalAnd() || expr->IsLogicalOr()) {
+ if (matching_types && lhs_ty->Is<Bool>()) {
+ return build(lhs_ty);
+ }
+ }
+ if (expr->IsOr() || expr->IsAnd()) {
+ if (matching_types && lhs_ty->Is<Bool>()) {
+ return build(lhs_ty);
+ }
+ if (matching_types && lhs_vec_elem_type && lhs_vec_elem_type->Is<Bool>()) {
+ return build(lhs_ty);
+ }
+ }
+
+ // Arithmetic expressions
+ if (expr->IsArithmetic()) {
+ // Binary arithmetic expressions over scalars
+ if (matching_types && lhs_ty->is_numeric_scalar()) {
+ return build(lhs_ty);
+ }
+
+ // Binary arithmetic expressions over vectors
+ if (matching_types && lhs_vec_elem_type &&
+ lhs_vec_elem_type->is_numeric_scalar()) {
+ return build(lhs_ty);
+ }
+
+ // Binary arithmetic expressions with mixed scalar and vector operands
+ if (lhs_vec_elem_type && (lhs_vec_elem_type == rhs_ty)) {
+ if (expr->IsModulo()) {
+ if (rhs_ty->is_integer_scalar()) {
+ return build(lhs_ty);
+ }
+ } else if (rhs_ty->is_numeric_scalar()) {
+ return build(lhs_ty);
+ }
+ }
+ if (rhs_vec_elem_type && (rhs_vec_elem_type == lhs_ty)) {
+ if (expr->IsModulo()) {
+ if (lhs_ty->is_integer_scalar()) {
+ return build(rhs_ty);
+ }
+ } else if (lhs_ty->is_numeric_scalar()) {
+ return build(rhs_ty);
+ }
+ }
+ }
+
+ // Matrix arithmetic
+ auto* lhs_mat = lhs_ty->As<Matrix>();
+ auto* lhs_mat_elem_type = lhs_mat ? lhs_mat->type() : nullptr;
+ auto* rhs_mat = rhs_ty->As<Matrix>();
+ auto* rhs_mat_elem_type = rhs_mat ? rhs_mat->type() : nullptr;
+ // Addition and subtraction of float matrices
+ if ((expr->IsAdd() || expr->IsSubtract()) && lhs_mat_elem_type &&
+ lhs_mat_elem_type->Is<F32>() && rhs_mat_elem_type &&
+ rhs_mat_elem_type->Is<F32>() &&
+ (lhs_mat->columns() == rhs_mat->columns()) &&
+ (lhs_mat->rows() == rhs_mat->rows())) {
+ return build(rhs_ty);
+ }
+ if (expr->IsMultiply()) {
+ // Multiplication of a matrix and a scalar
+ if (lhs_ty->Is<F32>() && rhs_mat_elem_type &&
+ rhs_mat_elem_type->Is<F32>()) {
+ return build(rhs_ty);
+ }
+ if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
+ rhs_ty->Is<F32>()) {
+ return build(lhs_ty);
+ }
+
+ // Vector times matrix
+ if (lhs_vec_elem_type && lhs_vec_elem_type->Is<F32>() &&
+ rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>() &&
+ (lhs_vec->Width() == rhs_mat->rows())) {
+ return build(
+ builder_->create<sem::Vector>(lhs_vec->type(), rhs_mat->columns()));
+ }
+
+ // Matrix times vector
+ if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
+ rhs_vec_elem_type && rhs_vec_elem_type->Is<F32>() &&
+ (lhs_mat->columns() == rhs_vec->Width())) {
+ return build(
+ builder_->create<sem::Vector>(rhs_vec->type(), lhs_mat->rows()));
+ }
+
+ // Matrix times matrix
+ if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
+ rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>() &&
+ (lhs_mat->columns() == rhs_mat->rows())) {
+ return build(builder_->create<sem::Matrix>(
+ builder_->create<sem::Vector>(lhs_mat_elem_type, lhs_mat->rows()),
+ rhs_mat->columns()));
+ }
+ }
+
+ // Comparison expressions
+ if (expr->IsComparison()) {
+ if (matching_types) {
+ // Special case for bools: only == and !=
+ if (lhs_ty->Is<Bool>() && (expr->IsEqual() || expr->IsNotEqual())) {
+ return build(builder_->create<sem::Bool>());
+ }
+
+ // For the rest, we can compare i32, u32, and f32
+ if (lhs_ty->IsAnyOf<I32, U32, F32>()) {
+ return build(builder_->create<sem::Bool>());
+ }
+ }
+
+ // Same for vectors
+ if (matching_vec_elem_types) {
+ if (lhs_vec_elem_type->Is<Bool>() &&
+ (expr->IsEqual() || expr->IsNotEqual())) {
+ return build(builder_->create<sem::Vector>(
+ builder_->create<sem::Bool>(), lhs_vec->Width()));
+ }
+
+ if (lhs_vec_elem_type->is_numeric_scalar()) {
+ return build(builder_->create<sem::Vector>(
+ builder_->create<sem::Bool>(), lhs_vec->Width()));
+ }
+ }
+ }
+
+ // Binary bitwise operations
+ if (expr->IsBitwise()) {
+ if (matching_types && lhs_ty->is_integer_scalar_or_vector()) {
+ return build(lhs_ty);
+ }
+ }
+
+ // Bit shift expressions
+ if (expr->IsBitshift()) {
+ // Type validation rules are the same for left or right shift, despite
+ // differences in computation rules (i.e. right shift can be arithmetic or
+ // logical depending on lhs type).
+
+ if (lhs_ty->IsAnyOf<I32, U32>() && rhs_ty->Is<U32>()) {
+ return build(lhs_ty);
+ }
+
+ if (lhs_vec_elem_type && lhs_vec_elem_type->IsAnyOf<I32, U32>() &&
+ rhs_vec_elem_type && rhs_vec_elem_type->Is<U32>()) {
+ return build(lhs_ty);
+ }
+ }
+
+ AddError("Binary expression operand types are invalid for this operation: " +
+ TypeNameOf(lhs_ty) + " " + FriendlyName(expr->op) + " " +
+ TypeNameOf(rhs_ty),
+ expr->source);
+ return nullptr;
+}
+
+sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) {
+ auto* expr = Sem(unary->expr);
+ auto* expr_ty = expr->Type();
+ if (!expr_ty) {
+ return nullptr;
+ }
+
+ const sem::Type* ty = nullptr;
+
+ switch (unary->op) {
+ case ast::UnaryOp::kNot:
+ // Result type matches the deref'd inner type.
+ ty = expr_ty->UnwrapRef();
+ if (!ty->Is<sem::Bool>() && !ty->is_bool_vector()) {
+ AddError(
+ "cannot logical negate expression of type '" + TypeNameOf(expr_ty),
+ unary->expr->source);
+ return nullptr;
+ }
+ break;
+
+ case ast::UnaryOp::kComplement:
+ // Result type matches the deref'd inner type.
+ ty = expr_ty->UnwrapRef();
+ if (!ty->is_integer_scalar_or_vector()) {
+ AddError("cannot bitwise complement expression of type '" +
+ TypeNameOf(expr_ty),
+ unary->expr->source);
+ return nullptr;
+ }
+ break;
+
+ case ast::UnaryOp::kNegation:
+ // Result type matches the deref'd inner type.
+ ty = expr_ty->UnwrapRef();
+ if (!(ty->IsAnyOf<sem::F32, sem::I32>() ||
+ ty->is_signed_integer_vector() || ty->is_float_vector())) {
+ AddError("cannot negate expression of type '" + TypeNameOf(expr_ty),
+ unary->expr->source);
+ return nullptr;
+ }
+ break;
+
+ case ast::UnaryOp::kAddressOf:
+ if (auto* ref = expr_ty->As<sem::Reference>()) {
+ if (ref->StoreType()->UnwrapRef()->is_handle()) {
+ AddError(
+ "cannot take the address of expression in handle storage class",
+ unary->expr->source);
+ return nullptr;
+ }
+
+ auto* array = unary->expr->As<ast::IndexAccessorExpression>();
+ auto* member = unary->expr->As<ast::MemberAccessorExpression>();
+ if ((array && TypeOf(array->object)->UnwrapRef()->Is<sem::Vector>()) ||
+ (member &&
+ TypeOf(member->structure)->UnwrapRef()->Is<sem::Vector>())) {
+ AddError("cannot take the address of a vector component",
+ unary->expr->source);
+ return nullptr;
+ }
+
+ ty = builder_->create<sem::Pointer>(ref->StoreType(),
+ ref->StorageClass(), ref->Access());
+ } else {
+ AddError("cannot take the address of expression", unary->expr->source);
+ return nullptr;
+ }
+ break;
+
+ case ast::UnaryOp::kIndirection:
+ if (auto* ptr = expr_ty->As<sem::Pointer>()) {
+ ty = builder_->create<sem::Reference>(
+ ptr->StoreType(), ptr->StorageClass(), ptr->Access());
+ } else {
+ AddError("cannot dereference expression of type '" +
+ TypeNameOf(expr_ty) + "'",
+ unary->expr->source);
+ return nullptr;
+ }
+ break;
+ }
+
+ auto val = EvaluateConstantValue(unary, ty);
+ auto* sem = builder_->create<sem::Expression>(unary, ty, current_statement_,
+ val, expr->HasSideEffects());
+ sem->Behaviors() = expr->Behaviors();
+ return sem;
+}
+
+sem::Type* Resolver::TypeDecl(const ast::TypeDecl* named_type) {
+ sem::Type* result = nullptr;
+ if (auto* alias = named_type->As<ast::Alias>()) {
+ result = Alias(alias);
+ } else if (auto* str = named_type->As<ast::Struct>()) {
+ result = Structure(str);
+ } else {
+ TINT_UNREACHABLE(Resolver, diagnostics_) << "Unhandled TypeDecl";
+ }
+
+ if (!result) {
+ return nullptr;
+ }
+
+ builder_->Sem().Add(named_type, result);
+ return result;
+}
+
+sem::Type* Resolver::TypeOf(const ast::Expression* expr) {
+ auto* sem = Sem(expr);
+ return sem ? const_cast<sem::Type*>(sem->Type()) : nullptr;
+}
+
+std::string Resolver::TypeNameOf(const sem::Type* ty) {
+ return RawTypeNameOf(ty->UnwrapRef());
+}
+
+std::string Resolver::RawTypeNameOf(const sem::Type* ty) {
+ return ty->FriendlyName(builder_->Symbols());
+}
+
+sem::Type* Resolver::TypeOf(const ast::LiteralExpression* lit) {
+ return Switch(
+ lit,
+ [&](const ast::SintLiteralExpression*) -> sem::Type* {
+ return builder_->create<sem::I32>();
+ },
+ [&](const ast::UintLiteralExpression*) -> sem::Type* {
+ return builder_->create<sem::U32>();
+ },
+ [&](const ast::FloatLiteralExpression*) -> sem::Type* {
+ return builder_->create<sem::F32>();
+ },
+ [&](const ast::BoolLiteralExpression*) -> sem::Type* {
+ return builder_->create<sem::Bool>();
+ },
+ [&](Default) -> sem::Type* {
+ TINT_UNREACHABLE(Resolver, diagnostics_)
+ << "Unhandled literal type: " << lit->TypeInfo().name;
+ return nullptr;
+ });
+}
+
+sem::Array* Resolver::Array(const ast::Array* arr) {
+ auto source = arr->source;
+
+ auto* elem_type = Type(arr->type);
+ if (!elem_type) {
+ return nullptr;
+ }
+
+ if (!IsPlain(elem_type)) { // Check must come before GetDefaultAlignAndSize()
+ AddError(TypeNameOf(elem_type) +
+ " cannot be used as an element type of an array",
+ source);
+ return nullptr;
+ }
+
+ uint32_t el_align = elem_type->Align();
+ uint32_t el_size = elem_type->Size();
+
+ if (!ValidateNoDuplicateAttributes(arr->attributes)) {
+ return nullptr;
+ }
+
+ // Look for explicit stride via @stride(n) attribute
+ uint32_t explicit_stride = 0;
+ for (auto* attr : arr->attributes) {
+ Mark(attr);
+ if (auto* sd = attr->As<ast::StrideAttribute>()) {
+ explicit_stride = sd->stride;
+ if (!ValidateArrayStrideAttribute(sd, el_size, el_align, source)) {
+ return nullptr;
+ }
+ continue;
+ }
+
+ AddError("attribute is not valid for array types", attr->source);
+ return nullptr;
+ }
+
+ // Calculate implicit stride
+ uint64_t implicit_stride = utils::RoundUp<uint64_t>(el_align, el_size);
+
+ uint64_t stride = explicit_stride ? explicit_stride : implicit_stride;
+
+ // Evaluate the constant array size expression.
+ // sem::Array uses a size of 0 for a runtime-sized array.
+ uint32_t count = 0;
+ if (auto* count_expr = arr->count) {
+ auto* count_sem = Expression(count_expr);
+ if (!count_sem) {
+ return nullptr;
+ }
+
+ auto size_source = count_expr->source;
+
+ auto* ty = count_sem->Type()->UnwrapRef();
+ if (!ty->is_integer_scalar()) {
+ AddError("array size must be integer scalar", size_source);
+ return nullptr;
+ }
+
+ if (auto* ident = count_expr->As<ast::IdentifierExpression>()) {
+ // Make sure the identifier is a non-overridable module-scope constant.
+ auto* var = ResolvedSymbol<sem::GlobalVariable>(ident);
+ if (!var || !var->Declaration()->is_const) {
+ AddError("array size identifier must be a module-scope constant",
+ size_source);
+ return nullptr;
+ }
+ if (var->IsOverridable()) {
+ AddError("array size expression must not be pipeline-overridable",
+ size_source);
+ return nullptr;
+ }
+
+ count_expr = var->Declaration()->constructor;
+ } else if (!count_expr->Is<ast::LiteralExpression>()) {
+ AddError(
+ "array size expression must be either a literal or a module-scope "
+ "constant",
+ size_source);
+ return nullptr;
+ }
+
+ auto count_val = count_sem->ConstantValue();
+ if (!count_val) {
+ TINT_ICE(Resolver, diagnostics_)
+ << "could not resolve array size expression";
+ return nullptr;
+ }
+
+ if (ty->is_signed_integer_scalar() ? count_val.Elements()[0].i32 < 1
+ : count_val.Elements()[0].u32 < 1u) {
+ AddError("array size must be at least 1", size_source);
+ return nullptr;
+ }
+
+ count = count_val.Elements()[0].u32;
+ }
+
+ auto size = std::max<uint64_t>(count, 1) * stride;
+ if (size > std::numeric_limits<uint32_t>::max()) {
+ std::stringstream msg;
+ msg << "array size in bytes must not exceed 0x" << std::hex
+ << std::numeric_limits<uint32_t>::max() << ", but is 0x" << std::hex
+ << size;
+ AddError(msg.str(), arr->source);
+ return nullptr;
+ }
+ if (stride > std::numeric_limits<uint32_t>::max() ||
+ implicit_stride > std::numeric_limits<uint32_t>::max()) {
+ TINT_ICE(Resolver, diagnostics_)
+ << "calculated array stride exceeds uint32";
+ return nullptr;
+ }
+ auto* out = builder_->create<sem::Array>(
+ elem_type, count, el_align, static_cast<uint32_t>(size),
+ static_cast<uint32_t>(stride), static_cast<uint32_t>(implicit_stride));
+
+ if (!ValidateArray(out, source)) {
+ return nullptr;
+ }
+
+ if (elem_type->Is<sem::Atomic>()) {
+ atomic_composite_info_.emplace(out, arr->type->source);
+ } else {
+ auto found = atomic_composite_info_.find(elem_type);
+ if (found != atomic_composite_info_.end()) {
+ atomic_composite_info_.emplace(out, found->second);
+ }
+ }
+
+ return out;
+}
+
+sem::Type* Resolver::Alias(const ast::Alias* alias) {
+ auto* ty = Type(alias->type);
+ if (!ty) {
+ return nullptr;
+ }
+ if (!ValidateAlias(alias)) {
+ return nullptr;
+ }
+ return ty;
+}
+
+sem::Struct* Resolver::Structure(const ast::Struct* str) {
+ if (!ValidateNoDuplicateAttributes(str->attributes)) {
+ return nullptr;
+ }
+ for (auto* attr : str->attributes) {
+ Mark(attr);
+ }
+
+ sem::StructMemberList sem_members;
+ sem_members.reserve(str->members.size());
+
+ // Calculate the effective size and alignment of each field, and the overall
+ // size of the structure.
+ // For size, use the size attribute if provided, otherwise use the default
+ // size for the type.
+ // For alignment, use the alignment attribute if provided, otherwise use the
+ // default alignment for the member type.
+ // Diagnostic errors are raised if a basic rule is violated.
+ // Validation of storage-class rules requires analysing the actual variable
+ // usage of the structure, and so is performed as part of the variable
+ // validation.
+ uint64_t struct_size = 0;
+ uint64_t struct_align = 1;
+ std::unordered_map<Symbol, const ast::StructMember*> member_map;
+
+ for (auto* member : str->members) {
+ Mark(member);
+ auto result = member_map.emplace(member->symbol, member);
+ if (!result.second) {
+ AddError("redefinition of '" +
+ builder_->Symbols().NameFor(member->symbol) + "'",
+ member->source);
+ AddNote("previous definition is here", result.first->second->source);
+ return nullptr;
+ }
+
+ // Resolve member type
+ auto* type = Type(member->type);
+ if (!type) {
+ return nullptr;
+ }
+
+ // Validate member type
+ if (!IsPlain(type)) {
+ AddError(TypeNameOf(type) +
+ " cannot be used as the type of a structure member",
+ member->source);
+ return nullptr;
+ }
+
+ uint64_t offset = struct_size;
+ uint64_t align = type->Align();
+ uint64_t size = type->Size();
+
+ if (!ValidateNoDuplicateAttributes(member->attributes)) {
+ return nullptr;
+ }
+
+ bool has_offset_attr = false;
+ bool has_align_attr = false;
+ bool has_size_attr = false;
+ for (auto* attr : member->attributes) {
+ Mark(attr);
+ if (auto* o = attr->As<ast::StructMemberOffsetAttribute>()) {
+ // Offset attributes are not part of the WGSL spec, but are emitted
+ // by the SPIR-V reader.
+ if (o->offset < struct_size) {
+ AddError("offsets must be in ascending order", o->source);
+ return nullptr;
+ }
+ offset = o->offset;
+ align = 1;
+ has_offset_attr = true;
+ } else if (auto* a = attr->As<ast::StructMemberAlignAttribute>()) {
+ if (a->align <= 0 || !utils::IsPowerOfTwo(a->align)) {
+ AddError("align value must be a positive, power-of-two integer",
+ a->source);
+ return nullptr;
+ }
+ align = a->align;
+ has_align_attr = true;
+ } else if (auto* s = attr->As<ast::StructMemberSizeAttribute>()) {
+ if (s->size < size) {
+ AddError("size must be at least as big as the type's size (" +
+ std::to_string(size) + ")",
+ s->source);
+ return nullptr;
+ }
+ size = s->size;
+ has_size_attr = true;
+ }
+ }
+
+ if (has_offset_attr && (has_align_attr || has_size_attr)) {
+ AddError("offset attributes cannot be used with align or size attributes",
+ member->source);
+ return nullptr;
+ }
+
+ offset = utils::RoundUp(align, offset);
+ if (offset > std::numeric_limits<uint32_t>::max()) {
+ std::stringstream msg;
+ msg << "struct member has byte offset 0x" << std::hex << offset
+ << ", but must not exceed 0x" << std::hex
+ << std::numeric_limits<uint32_t>::max();
+ AddError(msg.str(), member->source);
+ return nullptr;
+ }
+
+ auto* sem_member = builder_->create<sem::StructMember>(
+ member, member->symbol, type, static_cast<uint32_t>(sem_members.size()),
+ static_cast<uint32_t>(offset), static_cast<uint32_t>(align),
+ static_cast<uint32_t>(size));
+ builder_->Sem().Add(member, sem_member);
+ sem_members.emplace_back(sem_member);
+
+ struct_size = offset + size;
+ struct_align = std::max(struct_align, align);
+ }
+
+ uint64_t size_no_padding = struct_size;
+ struct_size = utils::RoundUp(struct_align, struct_size);
+
+ if (struct_size > std::numeric_limits<uint32_t>::max()) {
+ std::stringstream msg;
+ msg << "struct size in bytes must not exceed 0x" << std::hex
+ << std::numeric_limits<uint32_t>::max() << ", but is 0x" << std::hex
+ << struct_size;
+ AddError(msg.str(), str->source);
+ return nullptr;
+ }
+ if (struct_align > std::numeric_limits<uint32_t>::max()) {
+ TINT_ICE(Resolver, diagnostics_)
+ << "calculated struct stride exceeds uint32";
+ return nullptr;
+ }
+
+ auto* out = builder_->create<sem::Struct>(
+ str, str->name, sem_members, static_cast<uint32_t>(struct_align),
+ static_cast<uint32_t>(struct_size),
+ static_cast<uint32_t>(size_no_padding));
+
+ for (size_t i = 0; i < sem_members.size(); i++) {
+ auto* mem_type = sem_members[i]->Type();
+ if (mem_type->Is<sem::Atomic>()) {
+ atomic_composite_info_.emplace(out,
+ sem_members[i]->Declaration()->source);
+ break;
+ } else {
+ auto found = atomic_composite_info_.find(mem_type);
+ if (found != atomic_composite_info_.end()) {
+ atomic_composite_info_.emplace(out, found->second);
+ break;
+ }
+ }
+ }
+
+ if (!ValidateStructure(out)) {
+ return nullptr;
+ }
+
+ return out;
+}
+
+sem::Statement* Resolver::ReturnStatement(const ast::ReturnStatement* stmt) {
+ auto* sem = builder_->create<sem::Statement>(
+ stmt, current_compound_statement_, current_function_);
+ return StatementScope(stmt, sem, [&] {
+ auto& behaviors = current_statement_->Behaviors();
+ behaviors = sem::Behavior::kReturn;
+
+ if (auto* value = stmt->value) {
+ auto* expr = Expression(value);
+ if (!expr) {
+ return false;
+ }
+ behaviors.Add(expr->Behaviors() - sem::Behavior::kNext);
+ }
+
+ // Validate after processing the return value expression so that its type
+ // is available for validation.
+ return ValidateReturn(stmt);
+ });
+}
+
+sem::SwitchStatement* Resolver::SwitchStatement(
+ const ast::SwitchStatement* stmt) {
+ auto* sem = builder_->create<sem::SwitchStatement>(
+ stmt, current_compound_statement_, current_function_);
+ return StatementScope(stmt, sem, [&] {
+ auto& behaviors = sem->Behaviors();
+
+ auto* cond = Expression(stmt->condition);
+ if (!cond) {
+ return false;
+ }
+ behaviors = cond->Behaviors() - sem::Behavior::kNext;
+
+ for (auto* case_stmt : stmt->body) {
+ Mark(case_stmt);
+ auto* c = CaseStatement(case_stmt);
+ if (!c) {
+ return false;
+ }
+ behaviors.Add(c->Behaviors());
+ }
+
+ if (behaviors.Contains(sem::Behavior::kBreak)) {
+ behaviors.Add(sem::Behavior::kNext);
+ }
+ behaviors.Remove(sem::Behavior::kBreak, sem::Behavior::kFallthrough);
+
+ return ValidateSwitch(stmt);
+ });
+}
+
+sem::Statement* Resolver::VariableDeclStatement(
+ const ast::VariableDeclStatement* stmt) {
+ auto* sem = builder_->create<sem::Statement>(
+ stmt, current_compound_statement_, current_function_);
+ return StatementScope(stmt, sem, [&] {
+ Mark(stmt->variable);
+
+ auto* var = Variable(stmt->variable, VariableKind::kLocal);
+ if (!var) {
+ return false;
+ }
+
+ for (auto* attr : stmt->variable->attributes) {
+ Mark(attr);
+ if (!attr->Is<ast::InternalAttribute>()) {
+ AddError("attributes are not valid on local variables", attr->source);
+ return false;
+ }
+ }
+
+ if (current_block_) { // Not all statements are inside a block
+ current_block_->AddDecl(stmt->variable);
+ }
+
+ if (auto* ctor = var->Constructor()) {
+ sem->Behaviors() = ctor->Behaviors();
+ }
+
+ return ValidateVariable(var);
+ });
+}
+
+sem::Statement* Resolver::AssignmentStatement(
+ const ast::AssignmentStatement* stmt) {
+ auto* sem = builder_->create<sem::Statement>(
+ stmt, current_compound_statement_, current_function_);
+ return StatementScope(stmt, sem, [&] {
+ auto* lhs = Expression(stmt->lhs);
+ if (!lhs) {
+ return false;
+ }
+
+ auto* rhs = Expression(stmt->rhs);
+ if (!rhs) {
+ return false;
+ }
+
+ auto& behaviors = sem->Behaviors();
+ behaviors = rhs->Behaviors();
+ if (!stmt->lhs->Is<ast::PhonyExpression>()) {
+ behaviors.Add(lhs->Behaviors());
+ }
+
+ return ValidateAssignment(stmt);
+ });
+}
+
+sem::Statement* Resolver::BreakStatement(const ast::BreakStatement* stmt) {
+ auto* sem = builder_->create<sem::Statement>(
+ stmt, current_compound_statement_, current_function_);
+ return StatementScope(stmt, sem, [&] {
+ sem->Behaviors() = sem::Behavior::kBreak;
+
+ return ValidateBreakStatement(sem);
+ });
+}
+
+sem::Statement* Resolver::CallStatement(const ast::CallStatement* stmt) {
+ auto* sem = builder_->create<sem::Statement>(
+ stmt, current_compound_statement_, current_function_);
+ return StatementScope(stmt, sem, [&] {
+ if (auto* expr = Expression(stmt->expr)) {
+ sem->Behaviors() = expr->Behaviors();
+ return true;
+ }
+ return false;
+ });
+}
+
+sem::Statement* Resolver::ContinueStatement(
+ const ast::ContinueStatement* stmt) {
+ auto* sem = builder_->create<sem::Statement>(
+ stmt, current_compound_statement_, current_function_);
+ return StatementScope(stmt, sem, [&] {
+ sem->Behaviors() = sem::Behavior::kContinue;
+
+ // Set if we've hit the first continue statement in our parent loop
+ if (auto* block = sem->FindFirstParent<sem::LoopBlockStatement>()) {
+ if (!block->FirstContinue()) {
+ const_cast<sem::LoopBlockStatement*>(block)->SetFirstContinue(
+ stmt, block->Decls().size());
+ }
+ }
+
+ return ValidateContinueStatement(sem);
+ });
+}
+
+sem::Statement* Resolver::DiscardStatement(const ast::DiscardStatement* stmt) {
+ auto* sem = builder_->create<sem::Statement>(
+ stmt, current_compound_statement_, current_function_);
+ return StatementScope(stmt, sem, [&] {
+ sem->Behaviors() = sem::Behavior::kDiscard;
+ current_function_->SetHasDiscard();
+
+ return ValidateDiscardStatement(sem);
+ });
+}
+
+sem::Statement* Resolver::FallthroughStatement(
+ const ast::FallthroughStatement* stmt) {
+ auto* sem = builder_->create<sem::Statement>(
+ stmt, current_compound_statement_, current_function_);
+ return StatementScope(stmt, sem, [&] {
+ sem->Behaviors() = sem::Behavior::kFallthrough;
+
+ return ValidateFallthroughStatement(sem);
+ });
+}
+
+bool Resolver::ApplyStorageClassUsageToType(ast::StorageClass sc,
+ sem::Type* ty,
+ const Source& usage) {
+ ty = const_cast<sem::Type*>(ty->UnwrapRef());
+
+ if (auto* str = ty->As<sem::Struct>()) {
+ if (str->StorageClassUsage().count(sc)) {
+ return true; // Already applied
+ }
+
+ str->AddUsage(sc);
+
+ for (auto* member : str->Members()) {
+ if (!ApplyStorageClassUsageToType(sc, member->Type(), usage)) {
+ std::stringstream err;
+ err << "while analysing structure member " << TypeNameOf(str) << "."
+ << builder_->Symbols().NameFor(member->Declaration()->symbol);
+ AddNote(err.str(), member->Declaration()->source);
+ return false;
+ }
+ }
+ return true;
+ }
+
+ if (auto* arr = ty->As<sem::Array>()) {
+ if (arr->IsRuntimeSized() && sc != ast::StorageClass::kStorage) {
+ AddError(
+ "runtime-sized arrays can only be used in the <storage> storage "
+ "class",
+ usage);
+ return false;
+ }
+
+ return ApplyStorageClassUsageToType(
+ sc, const_cast<sem::Type*>(arr->ElemType()), usage);
+ }
+
+ if (ast::IsHostShareable(sc) && !IsHostShareable(ty)) {
+ std::stringstream err;
+ err << "Type '" << TypeNameOf(ty) << "' cannot be used in storage class '"
+ << sc << "' as it is non-host-shareable";
+ AddError(err.str(), usage);
+ return false;
+ }
+
+ return true;
+}
+
+template <typename SEM, typename F>
+SEM* Resolver::StatementScope(const ast::Statement* ast,
+ SEM* sem,
+ F&& callback) {
+ builder_->Sem().Add(ast, sem);
+
+ auto* as_compound =
+ As<sem::CompoundStatement, CastFlags::kDontErrorOnImpossibleCast>(sem);
+ auto* as_block =
+ As<sem::BlockStatement, CastFlags::kDontErrorOnImpossibleCast>(sem);
+
+ TINT_SCOPED_ASSIGNMENT(current_statement_, sem);
+ TINT_SCOPED_ASSIGNMENT(
+ current_compound_statement_,
+ as_compound ? as_compound : current_compound_statement_);
+ TINT_SCOPED_ASSIGNMENT(current_block_, as_block ? as_block : current_block_);
+
+ if (!callback()) {
+ return nullptr;
+ }
+
+ return sem;
+}
+
+std::string Resolver::VectorPretty(uint32_t size,
+ const sem::Type* element_type) {
+ sem::Vector vec_type(element_type, size);
+ return vec_type.FriendlyName(builder_->Symbols());
+}
+
+bool Resolver::Mark(const ast::Node* node) {
+ if (node == nullptr) {
+ TINT_ICE(Resolver, diagnostics_) << "Resolver::Mark() called with nullptr";
+ return false;
+ }
+ if (marked_.emplace(node).second) {
+ return true;
+ }
+ TINT_ICE(Resolver, diagnostics_)
+ << "AST node '" << node->TypeInfo().name
+ << "' was encountered twice in the same AST of a Program\n"
+ << "At: " << node->source << "\n"
+ << "Pointer: " << node;
+ return false;
+}
+
+void Resolver::AddError(const std::string& msg, const Source& source) const {
+ diagnostics_.add_error(diag::System::Resolver, msg, source);
+}
+
+void Resolver::AddWarning(const std::string& msg, const Source& source) const {
+ diagnostics_.add_warning(diag::System::Resolver, msg, source);
+}
+
+void Resolver::AddNote(const std::string& msg, const Source& source) const {
+ diagnostics_.add_note(diag::System::Resolver, msg, source);
+}
+
+// https://gpuweb.github.io/gpuweb/wgsl/#plain-types-section
+bool Resolver::IsPlain(const sem::Type* type) const {
+ return type->is_scalar() ||
+ type->IsAnyOf<sem::Atomic, sem::Vector, sem::Matrix, sem::Array,
+ sem::Struct>();
+}
+
+// https://gpuweb.github.io/gpuweb/wgsl/#fixed-footprint-types
+bool Resolver::IsFixedFootprint(const sem::Type* type) const {
+ return Switch(
+ type, //
+ [&](const sem::Vector*) { return true; }, //
+ [&](const sem::Matrix*) { return true; }, //
+ [&](const sem::Atomic*) { return true; },
+ [&](const sem::Array* arr) {
+ return !arr->IsRuntimeSized() && IsFixedFootprint(arr->ElemType());
+ },
+ [&](const sem::Struct* str) {
+ for (auto* member : str->Members()) {
+ if (!IsFixedFootprint(member->Type())) {
+ return false;
+ }
+ }
+ return true;
+ },
+ [&](Default) { return type->is_scalar(); });
+}
+
+// https://gpuweb.github.io/gpuweb/wgsl.html#storable-types
+bool Resolver::IsStorable(const sem::Type* type) const {
+ return IsPlain(type) || type->IsAnyOf<sem::Texture, sem::Sampler>();
+}
+
+// https://gpuweb.github.io/gpuweb/wgsl.html#host-shareable-types
+bool Resolver::IsHostShareable(const sem::Type* type) const {
+ if (type->IsAnyOf<sem::I32, sem::U32, sem::F32>()) {
+ return true;
+ }
+ return Switch(
+ type, //
+ [&](const sem::Vector* vec) { return IsHostShareable(vec->type()); },
+ [&](const sem::Matrix* mat) { return IsHostShareable(mat->type()); },
+ [&](const sem::Array* arr) { return IsHostShareable(arr->ElemType()); },
+ [&](const sem::Struct* str) {
+ for (auto* member : str->Members()) {
+ if (!IsHostShareable(member->Type())) {
+ return false;
+ }
+ }
+ return true;
+ },
+ [&](const sem::Atomic* atomic) {
+ return IsHostShareable(atomic->Type());
+ });
+}
+
+bool Resolver::IsBuiltin(Symbol symbol) const {
+ std::string name = builder_->Symbols().NameFor(symbol);
+ return sem::ParseBuiltinType(name) != sem::BuiltinType::kNone;
+}
+
+bool Resolver::IsCallStatement(const ast::Expression* expr) const {
+ return current_statement_ &&
+ Is<ast::CallStatement>(current_statement_->Declaration(),
+ [&](auto* stmt) { return stmt->expr == expr; });
+}
+
+const ast::Statement* Resolver::ClosestContinuing(bool stop_at_loop) const {
+ for (const auto* s = current_statement_; s != nullptr; s = s->Parent()) {
+ if (stop_at_loop && s->Is<sem::LoopStatement>()) {
+ break;
+ }
+ if (s->Is<sem::LoopContinuingBlockStatement>()) {
+ return s->Declaration();
+ }
+ if (auto* f = As<sem::ForLoopStatement>(s->Parent())) {
+ if (f->Declaration()->continuing == s->Declaration()) {
+ return s->Declaration();
+ }
+ if (stop_at_loop) {
+ break;
+ }
+ }
+ }
+ return nullptr;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Resolver::TypeConversionSig
+////////////////////////////////////////////////////////////////////////////////
+bool Resolver::TypeConversionSig::operator==(
+ const TypeConversionSig& rhs) const {
+ return target == rhs.target && source == rhs.source;
+}
+std::size_t Resolver::TypeConversionSig::Hasher::operator()(
+ const TypeConversionSig& sig) const {
+ return utils::Hash(sig.target, sig.source);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Resolver::TypeConstructorSig
+////////////////////////////////////////////////////////////////////////////////
+Resolver::TypeConstructorSig::TypeConstructorSig(
+ const sem::Type* ty,
+ const std::vector<const sem::Type*> params)
+ : type(ty), parameters(params) {}
+Resolver::TypeConstructorSig::TypeConstructorSig(const TypeConstructorSig&) =
+ default;
+Resolver::TypeConstructorSig::~TypeConstructorSig() = default;
+
+bool Resolver::TypeConstructorSig::operator==(
+ const TypeConstructorSig& rhs) const {
+ return type == rhs.type && parameters == rhs.parameters;
+}
+std::size_t Resolver::TypeConstructorSig::Hasher::operator()(
+ const TypeConstructorSig& sig) const {
+ return utils::Hash(sig.type, sig.parameters);
+}
+
+} // namespace resolver
+} // namespace tint
diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h
new file mode 100644
index 0000000..fe7e865
--- /dev/null
+++ b/src/tint/resolver/resolver.h
@@ -0,0 +1,545 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_RESOLVER_RESOLVER_H_
+#define SRC_TINT_RESOLVER_RESOLVER_H_
+
+#include <memory>
+#include <set>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "src/tint/builtin_table.h"
+#include "src/tint/program_builder.h"
+#include "src/tint/resolver/dependency_graph.h"
+#include "src/tint/scope_stack.h"
+#include "src/tint/sem/binding_point.h"
+#include "src/tint/sem/block_statement.h"
+#include "src/tint/sem/constant.h"
+#include "src/tint/sem/function.h"
+#include "src/tint/sem/struct.h"
+#include "src/tint/utils/map.h"
+#include "src/tint/utils/unique_vector.h"
+
+namespace tint {
+
+// Forward declarations
+namespace ast {
+class IndexAccessorExpression;
+class BinaryExpression;
+class BitcastExpression;
+class CallExpression;
+class CallStatement;
+class CaseStatement;
+class ForLoopStatement;
+class Function;
+class IdentifierExpression;
+class LoopStatement;
+class MemberAccessorExpression;
+class ReturnStatement;
+class SwitchStatement;
+class UnaryOpExpression;
+class Variable;
+} // namespace ast
+namespace sem {
+class Array;
+class Atomic;
+class BlockStatement;
+class Builtin;
+class CaseStatement;
+class ElseStatement;
+class ForLoopStatement;
+class IfStatement;
+class LoopStatement;
+class Statement;
+class SwitchStatement;
+class TypeConstructor;
+} // namespace sem
+
+namespace resolver {
+
+/// Resolves types for all items in the given tint program
+class Resolver {
+ public:
+ /// Constructor
+ /// @param builder the program builder
+ explicit Resolver(ProgramBuilder* builder);
+
+ /// Destructor
+ ~Resolver();
+
+ /// @returns error messages from the resolver
+ std::string error() const { return diagnostics_.str(); }
+
+ /// @returns true if the resolver was successful
+ bool Resolve();
+
+ /// @param type the given type
+ /// @returns true if the given type is a plain type
+ bool IsPlain(const sem::Type* type) const;
+
+ /// @param type the given type
+ /// @returns true if the given type is a fixed-footprint type
+ bool IsFixedFootprint(const sem::Type* type) const;
+
+ /// @param type the given type
+ /// @returns true if the given type is storable
+ bool IsStorable(const sem::Type* type) const;
+
+ /// @param type the given type
+ /// @returns true if the given type is host-shareable
+ bool IsHostShareable(const sem::Type* type) const;
+
+ private:
+ /// Describes the context in which a variable is declared
+ enum class VariableKind { kParameter, kLocal, kGlobal };
+
+ std::set<std::pair<const sem::Type*, ast::StorageClass>>
+ valid_type_storage_layouts_;
+
+ /// Structure holding semantic information about a block (i.e. scope), such as
+ /// parent block and variables declared in the block.
+ /// Used to validate variable scoping rules.
+ struct BlockInfo {
+ enum class Type { kGeneric, kLoop, kLoopContinuing, kSwitchCase };
+
+ BlockInfo(const ast::BlockStatement* block, Type type, BlockInfo* parent);
+ ~BlockInfo();
+
+ template <typename Pred>
+ BlockInfo* FindFirstParent(Pred&& pred) {
+ BlockInfo* curr = this;
+ while (curr && !pred(curr)) {
+ curr = curr->parent;
+ }
+ return curr;
+ }
+
+ BlockInfo* FindFirstParent(BlockInfo::Type ty) {
+ return FindFirstParent(
+ [ty](auto* block_info) { return block_info->type == ty; });
+ }
+
+ ast::BlockStatement const* const block;
+ const Type type;
+ BlockInfo* const parent;
+ std::vector<const ast::Variable*> decls;
+
+ // first_continue is set to the index of the first variable in decls
+ // declared after the first continue statement in a loop block, if any.
+ constexpr static size_t kNoContinue = size_t(~0);
+ size_t first_continue = kNoContinue;
+ };
+
+ // Structure holding information for a TypeDecl
+ struct TypeDeclInfo {
+ ast::TypeDecl const* const ast;
+ sem::Type* const sem;
+ };
+
+ /// Resolves the program, without creating final the semantic nodes.
+ /// @returns true on success, false on error
+ bool ResolveInternal();
+
+ bool ValidatePipelineStages();
+
+ /// Creates the nodes and adds them to the sem::Info mappings of the
+ /// ProgramBuilder.
+ void CreateSemanticNodes() const;
+
+ /// Retrieves information for the requested import.
+ /// @param src the source of the import
+ /// @param path the import path
+ /// @param name the method name to get information on
+ /// @param params the parameters to the method call
+ /// @param id out parameter for the external call ID. Must not be a nullptr.
+ /// @returns the return type of `name` in `path` or nullptr on error.
+ sem::Type* GetImportData(const Source& src,
+ const std::string& path,
+ const std::string& name,
+ const ast::ExpressionList& params,
+ uint32_t* id);
+
+ //////////////////////////////////////////////////////////////////////////////
+ // AST and Type traversal methods
+ //////////////////////////////////////////////////////////////////////////////
+
+ // Expression resolving methods
+ // Returns the semantic node pointer on success, nullptr on failure.
+ sem::Expression* IndexAccessor(const ast::IndexAccessorExpression*);
+ sem::Expression* Binary(const ast::BinaryExpression*);
+ sem::Expression* Bitcast(const ast::BitcastExpression*);
+ sem::Call* Call(const ast::CallExpression*);
+ sem::Expression* Expression(const ast::Expression*);
+ sem::Function* Function(const ast::Function*);
+ sem::Call* FunctionCall(const ast::CallExpression*,
+ sem::Function* target,
+ const std::vector<const sem::Expression*> args,
+ sem::Behaviors arg_behaviors);
+ sem::Expression* Identifier(const ast::IdentifierExpression*);
+ sem::Call* BuiltinCall(const ast::CallExpression*,
+ sem::BuiltinType,
+ const std::vector<const sem::Expression*> args,
+ const std::vector<const sem::Type*> arg_tys);
+ sem::Expression* Literal(const ast::LiteralExpression*);
+ sem::Expression* MemberAccessor(const ast::MemberAccessorExpression*);
+ sem::Call* TypeConversion(const ast::CallExpression* expr,
+ const sem::Type* ty,
+ const sem::Expression* arg,
+ const sem::Type* arg_ty);
+ sem::Call* TypeConstructor(const ast::CallExpression* expr,
+ const sem::Type* ty,
+ const std::vector<const sem::Expression*> args,
+ const std::vector<const sem::Type*> arg_tys);
+ sem::Expression* UnaryOp(const ast::UnaryOpExpression*);
+
+ // Statement resolving methods
+ // Each return true on success, false on failure.
+ sem::Statement* AssignmentStatement(const ast::AssignmentStatement*);
+ sem::BlockStatement* BlockStatement(const ast::BlockStatement*);
+ sem::Statement* BreakStatement(const ast::BreakStatement*);
+ sem::Statement* CallStatement(const ast::CallStatement*);
+ sem::CaseStatement* CaseStatement(const ast::CaseStatement*);
+ sem::Statement* ContinueStatement(const ast::ContinueStatement*);
+ sem::Statement* DiscardStatement(const ast::DiscardStatement*);
+ sem::ElseStatement* ElseStatement(const ast::ElseStatement*);
+ sem::Statement* FallthroughStatement(const ast::FallthroughStatement*);
+ sem::ForLoopStatement* ForLoopStatement(const ast::ForLoopStatement*);
+ sem::GlobalVariable* GlobalVariable(const ast::Variable*);
+ sem::Statement* Parameter(const ast::Variable*);
+ sem::IfStatement* IfStatement(const ast::IfStatement*);
+ sem::LoopStatement* LoopStatement(const ast::LoopStatement*);
+ sem::Statement* ReturnStatement(const ast::ReturnStatement*);
+ sem::Statement* Statement(const ast::Statement*);
+ sem::SwitchStatement* SwitchStatement(const ast::SwitchStatement* s);
+ sem::Statement* VariableDeclStatement(const ast::VariableDeclStatement*);
+ bool Statements(const ast::StatementList&);
+
+ // AST and Type validation methods
+ // Each return true on success, false on failure.
+ bool ValidateAlias(const ast::Alias*);
+ bool ValidateArray(const sem::Array* arr, const Source& source);
+ bool ValidateArrayStrideAttribute(const ast::StrideAttribute* attr,
+ uint32_t el_size,
+ uint32_t el_align,
+ const Source& source);
+ bool ValidateAtomic(const ast::Atomic* a, const sem::Atomic* s);
+ bool ValidateAtomicVariable(const sem::Variable* var);
+ bool ValidateAssignment(const ast::AssignmentStatement* a);
+ bool ValidateBitcast(const ast::BitcastExpression* cast, const sem::Type* to);
+ bool ValidateBreakStatement(const sem::Statement* stmt);
+ bool ValidateBuiltinAttribute(const ast::BuiltinAttribute* attr,
+ const sem::Type* storage_type,
+ const bool is_input);
+ bool ValidateContinueStatement(const sem::Statement* stmt);
+ bool ValidateDiscardStatement(const sem::Statement* stmt);
+ bool ValidateElseStatement(const sem::ElseStatement* stmt);
+ bool ValidateEntryPoint(const sem::Function* func);
+ bool ValidateForLoopStatement(const sem::ForLoopStatement* stmt);
+ bool ValidateFallthroughStatement(const sem::Statement* stmt);
+ bool ValidateFunction(const sem::Function* func);
+ bool ValidateFunctionCall(const sem::Call* call);
+ bool ValidateGlobalVariable(const sem::Variable* var);
+ bool ValidateIfStatement(const sem::IfStatement* stmt);
+ bool ValidateInterpolateAttribute(const ast::InterpolateAttribute* attr,
+ const sem::Type* storage_type);
+ bool ValidateBuiltinCall(const sem::Call* call);
+ bool ValidateLocationAttribute(const ast::LocationAttribute* location,
+ const sem::Type* type,
+ std::unordered_set<uint32_t>& locations,
+ const Source& source,
+ const bool is_input = false);
+ bool ValidateLoopStatement(const sem::LoopStatement* stmt);
+ bool ValidateMatrix(const sem::Matrix* ty, const Source& source);
+ bool ValidateFunctionParameter(const ast::Function* func,
+ const sem::Variable* var);
+ bool ValidateParameter(const ast::Function* func, const sem::Variable* var);
+ bool ValidateReturn(const ast::ReturnStatement* ret);
+ bool ValidateStatements(const ast::StatementList& stmts);
+ bool ValidateStorageTexture(const ast::StorageTexture* t);
+ bool ValidateStructure(const sem::Struct* str);
+ bool ValidateStructureConstructorOrCast(const ast::CallExpression* ctor,
+ const sem::Struct* struct_type);
+ bool ValidateSwitch(const ast::SwitchStatement* s);
+ bool ValidateVariable(const sem::Variable* var);
+ bool ValidateVariableConstructorOrCast(const ast::Variable* var,
+ ast::StorageClass storage_class,
+ const sem::Type* storage_type,
+ const sem::Type* rhs_type);
+ bool ValidateVector(const sem::Vector* ty, const Source& source);
+ bool ValidateVectorConstructorOrCast(const ast::CallExpression* ctor,
+ const sem::Vector* vec_type);
+ bool ValidateMatrixConstructorOrCast(const ast::CallExpression* ctor,
+ const sem::Matrix* matrix_type);
+ bool ValidateScalarConstructorOrCast(const ast::CallExpression* ctor,
+ const sem::Type* type);
+ bool ValidateArrayConstructorOrCast(const ast::CallExpression* ctor,
+ const sem::Array* arr_type);
+ bool ValidateTextureBuiltinFunction(const sem::Call* call);
+ bool ValidateNoDuplicateAttributes(const ast::AttributeList& attributes);
+ bool ValidateStorageClassLayout(const sem::Type* type,
+ ast::StorageClass sc,
+ Source source);
+ bool ValidateStorageClassLayout(const sem::Variable* var);
+
+ /// @returns true if the attribute list contains a
+ /// ast::DisableValidationAttribute with the validation mode equal to
+ /// `validation`
+ bool IsValidationDisabled(const ast::AttributeList& attributes,
+ ast::DisabledValidation validation) const;
+
+ /// @returns true if the attribute list does not contains a
+ /// ast::DisableValidationAttribute with the validation mode equal to
+ /// `validation`
+ bool IsValidationEnabled(const ast::AttributeList& attributes,
+ ast::DisabledValidation validation) const;
+
+ /// Resolves the WorkgroupSize for the given function, assigning it to
+ /// current_function_
+ bool WorkgroupSize(const ast::Function*);
+
+ /// @returns the sem::Type for the ast::Type `ty`, building it if it
+ /// hasn't been constructed already. If an error is raised, nullptr is
+ /// returned.
+ /// @param ty the ast::Type
+ sem::Type* Type(const ast::Type* ty);
+
+ /// @param named_type the named type to resolve
+ /// @returns the resolved semantic type
+ sem::Type* TypeDecl(const ast::TypeDecl* named_type);
+
+ /// Builds and returns the semantic information for the array `arr`.
+ /// This method does not mark the ast::Array node, nor attach the generated
+ /// semantic information to the AST node.
+ /// @returns the semantic Array information, or nullptr if an error is
+ /// raised.
+ /// @param arr the Array to get semantic information for
+ sem::Array* Array(const ast::Array* arr);
+
+ /// Builds and returns the semantic information for the alias `alias`.
+ /// This method does not mark the ast::Alias node, nor attach the generated
+ /// semantic information to the AST node.
+ /// @returns the aliased type, or nullptr if an error is raised.
+ sem::Type* Alias(const ast::Alias* alias);
+
+ /// Builds and returns the semantic information for the structure `str`.
+ /// This method does not mark the ast::Struct node, nor attach the generated
+ /// semantic information to the AST node.
+ /// @returns the semantic Struct information, or nullptr if an error is
+ /// raised.
+ sem::Struct* Structure(const ast::Struct* str);
+
+ /// @returns the semantic info for the variable `var`. If an error is
+ /// raised, nullptr is returned.
+ /// @note this method does not resolve the attributes as these are
+ /// context-dependent (global, local, parameter)
+ /// @param var the variable to create or return the `VariableInfo` for
+ /// @param kind what kind of variable we are declaring
+ /// @param index the index of the parameter, if this variable is a parameter
+ sem::Variable* Variable(const ast::Variable* var,
+ VariableKind kind,
+ uint32_t index = 0);
+
+ /// Records the storage class usage for the given type, and any transient
+ /// dependencies of the type. Validates that the type can be used for the
+ /// given storage class, erroring if it cannot.
+ /// @param sc the storage class to apply to the type and transitent types
+ /// @param ty the type to apply the storage class on
+ /// @param usage the Source of the root variable declaration that uses the
+ /// given type and storage class. Used for generating sensible error
+ /// messages.
+ /// @returns true on success, false on error
+ bool ApplyStorageClassUsageToType(ast::StorageClass sc,
+ sem::Type* ty,
+ const Source& usage);
+
+ /// @param storage_class the storage class
+ /// @returns the default access control for the given storage class
+ ast::Access DefaultAccessForStorageClass(ast::StorageClass storage_class);
+
+ /// Allocate constant IDs for pipeline-overridable constants.
+ void AllocateOverridableConstantIds();
+
+ /// Set the shadowing information on variable declarations.
+ /// @note this method must only be called after all semantic nodes are built.
+ void SetShadows();
+
+ /// @returns the resolved type of the ast::Expression `expr`
+ /// @param expr the expression
+ sem::Type* TypeOf(const ast::Expression* expr);
+
+ /// @returns the type name of the given semantic type, unwrapping
+ /// references.
+ std::string TypeNameOf(const sem::Type* ty);
+
+ /// @returns the type name of the given semantic type, without unwrapping
+ /// references.
+ std::string RawTypeNameOf(const sem::Type* ty);
+
+ /// @returns the semantic type of the AST literal `lit`
+ /// @param lit the literal
+ sem::Type* TypeOf(const ast::LiteralExpression* lit);
+
+ /// StatementScope() does the following:
+ /// * Creates the AST -> SEM mapping.
+ /// * Assigns `sem` to #current_statement_
+ /// * Assigns `sem` to #current_compound_statement_ if `sem` derives from
+ /// sem::CompoundStatement.
+ /// * Assigns `sem` to #current_block_ if `sem` derives from
+ /// sem::BlockStatement.
+ /// * Then calls `callback`.
+ /// * Before returning #current_statement_, #current_compound_statement_, and
+ /// #current_block_ are restored to their original values.
+ /// @returns `sem` if `callback` returns true, otherwise `nullptr`.
+ template <typename SEM, typename F>
+ SEM* StatementScope(const ast::Statement* ast, SEM* sem, F&& callback);
+
+ /// Returns a human-readable string representation of the vector type name
+ /// with the given parameters.
+ /// @param size the vector dimension
+ /// @param element_type scalar vector sub-element type
+ /// @return pretty string representation
+ std::string VectorPretty(uint32_t size, const sem::Type* element_type);
+
+ /// Mark records that the given AST node has been visited, and asserts that
+ /// the given node has not already been seen. Diamonds in the AST are
+ /// illegal.
+ /// @param node the AST node.
+ /// @returns true on success, false on error
+ bool Mark(const ast::Node* node);
+
+ /// Adds the given error message to the diagnostics
+ void AddError(const std::string& msg, const Source& source) const;
+
+ /// Adds the given warning message to the diagnostics
+ void AddWarning(const std::string& msg, const Source& source) const;
+
+ /// Adds the given note message to the diagnostics
+ void AddNote(const std::string& msg, const Source& source) const;
+
+ //////////////////////////////////////////////////////////////////////////////
+ /// Constant value evaluation methods
+ //////////////////////////////////////////////////////////////////////////////
+ /// Cast `Value` to `target_type`
+ /// @return the casted value
+ sem::Constant ConstantCast(const sem::Constant& value,
+ const sem::Type* target_elem_type);
+
+ sem::Constant EvaluateConstantValue(const ast::Expression* expr,
+ const sem::Type* type);
+ sem::Constant EvaluateConstantValue(const ast::LiteralExpression* literal,
+ const sem::Type* type);
+ sem::Constant EvaluateConstantValue(const ast::CallExpression* call,
+ const sem::Type* type);
+
+ /// Sem is a helper for obtaining the semantic node for the given AST node.
+ template <typename SEM = sem::Info::InferFromAST,
+ typename AST_OR_TYPE = CastableBase>
+ auto* Sem(const AST_OR_TYPE* ast) {
+ using T = sem::Info::GetResultType<SEM, AST_OR_TYPE>;
+ auto* sem = builder_->Sem().Get(ast);
+ if (!sem) {
+ TINT_ICE(Resolver, diagnostics_)
+ << "AST node '" << ast->TypeInfo().name << "' had no semantic info\n"
+ << "At: " << ast->source << "\n"
+ << "Pointer: " << ast;
+ }
+ return const_cast<T*>(As<T>(sem));
+ }
+
+ /// @returns true if the symbol is the name of a builtin function.
+ bool IsBuiltin(Symbol) const;
+
+ /// @returns true if `expr` is the current CallStatement's CallExpression
+ bool IsCallStatement(const ast::Expression* expr) const;
+
+ /// Searches the current statement and up through parents of the current
+ /// statement looking for a loop or for-loop continuing statement.
+ /// @returns the closest continuing statement to the current statement that
+ /// (transitively) owns the current statement.
+ /// @param stop_at_loop if true then the function will return nullptr if a
+ /// loop or for-loop was found before the continuing.
+ const ast::Statement* ClosestContinuing(bool stop_at_loop) const;
+
+ /// @returns the resolved symbol (function, type or variable) for the given
+ /// ast::Identifier or ast::TypeName cast to the given semantic type.
+ template <typename SEM = sem::Node>
+ SEM* ResolvedSymbol(const ast::Node* node) {
+ auto* resolved = utils::Lookup(dependencies_.resolved_symbols, node);
+ return resolved ? const_cast<SEM*>(builder_->Sem().Get<SEM>(resolved))
+ : nullptr;
+ }
+
+ struct TypeConversionSig {
+ const sem::Type* target;
+ const sem::Type* source;
+
+ bool operator==(const TypeConversionSig&) const;
+
+ /// Hasher provides a hash function for the TypeConversionSig
+ struct Hasher {
+ /// @param sig the TypeConversionSig to create a hash for
+ /// @return the hash value
+ std::size_t operator()(const TypeConversionSig& sig) const;
+ };
+ };
+
+ struct TypeConstructorSig {
+ const sem::Type* type;
+ const std::vector<const sem::Type*> parameters;
+
+ TypeConstructorSig(const sem::Type* ty,
+ const std::vector<const sem::Type*> params);
+ TypeConstructorSig(const TypeConstructorSig&);
+ ~TypeConstructorSig();
+ bool operator==(const TypeConstructorSig&) const;
+
+ /// Hasher provides a hash function for the TypeConstructorSig
+ struct Hasher {
+ /// @param sig the TypeConstructorSig to create a hash for
+ /// @return the hash value
+ std::size_t operator()(const TypeConstructorSig& sig) const;
+ };
+ };
+
+ ProgramBuilder* const builder_;
+ diag::List& diagnostics_;
+ std::unique_ptr<BuiltinTable> const builtin_table_;
+ DependencyGraph dependencies_;
+ std::vector<sem::Function*> entry_points_;
+ std::unordered_map<const sem::Type*, const Source&> atomic_composite_info_;
+ std::unordered_set<const ast::Node*> marked_;
+ std::unordered_map<uint32_t, const sem::Variable*> constant_ids_;
+ std::unordered_map<TypeConversionSig,
+ sem::CallTarget*,
+ TypeConversionSig::Hasher>
+ type_conversions_;
+ std::unordered_map<TypeConstructorSig,
+ sem::CallTarget*,
+ TypeConstructorSig::Hasher>
+ type_ctors_;
+
+ sem::Function* current_function_ = nullptr;
+ sem::Statement* current_statement_ = nullptr;
+ sem::CompoundStatement* current_compound_statement_ = nullptr;
+ sem::BlockStatement* current_block_ = nullptr;
+};
+
+} // namespace resolver
+} // namespace tint
+
+#endif // SRC_TINT_RESOLVER_RESOLVER_H_
diff --git a/src/tint/resolver/resolver_behavior_test.cc b/src/tint/resolver/resolver_behavior_test.cc
new file mode 100644
index 0000000..7cc6cb1
--- /dev/null
+++ b/src/tint/resolver/resolver_behavior_test.cc
@@ -0,0 +1,659 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/resolver/resolver.h"
+
+#include "gtest/gtest.h"
+#include "src/tint/resolver/resolver_test_helper.h"
+#include "src/tint/sem/expression.h"
+#include "src/tint/sem/for_loop_statement.h"
+#include "src/tint/sem/if_statement.h"
+
+namespace tint {
+namespace resolver {
+namespace {
+
+class ResolverBehaviorTest : public ResolverTest {
+ protected:
+ void SetUp() override {
+ // Create a function called 'DiscardOrNext' which returns an i32, and has
+ // the behavior of {Discard, Return}, which when called, will have the
+ // behavior {Discard, Next}.
+ Func("DiscardOrNext", {}, ty.i32(),
+ {
+ If(true, Block(Discard())),
+ Return(1),
+ });
+ }
+};
+
+TEST_F(ResolverBehaviorTest, ExprBinaryOp_LHS) {
+ auto* stmt = Decl(Var("lhs", ty.i32(), Add(Call("DiscardOrNext"), 1)));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, ExprBinaryOp_RHS) {
+ auto* stmt = Decl(Var("lhs", ty.i32(), Add(1, Call("DiscardOrNext"))));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, ExprBitcastOp) {
+ auto* stmt = Decl(Var("lhs", ty.u32(), Bitcast<u32>(Call("DiscardOrNext"))));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, ExprIndex_Arr) {
+ Func("ArrayDiscardOrNext", {}, ty.array<i32, 4>(),
+ {
+ If(true, Block(Discard())),
+ Return(Construct(ty.array<i32, 4>())),
+ });
+
+ auto* stmt =
+ Decl(Var("lhs", ty.i32(), IndexAccessor(Call("ArrayDiscardOrNext"), 1)));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, ExprIndex_Idx) {
+ auto* stmt =
+ Decl(Var("lhs", ty.i32(), IndexAccessor("arr", Call("DiscardOrNext"))));
+ WrapInFunction(Decl(Var("arr", ty.array<i32, 4>())), //
+ stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, ExprUnaryOp) {
+ auto* stmt = Decl(Var("lhs", ty.i32(),
+ create<ast::UnaryOpExpression>(
+ ast::UnaryOp::kComplement, Call("DiscardOrNext"))));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, StmtAssign) {
+ auto* stmt = Assign("lhs", "rhs");
+ WrapInFunction(Decl(Var("lhs", ty.i32())), //
+ Decl(Var("rhs", ty.i32())), //
+ stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
+}
+
+TEST_F(ResolverBehaviorTest, StmtAssign_LHSDiscardOrNext) {
+ auto* stmt = Assign(IndexAccessor("lhs", Call("DiscardOrNext")), 1);
+ WrapInFunction(Decl(Var("lhs", ty.array<i32, 4>())), //
+ stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, StmtAssign_RHSDiscardOrNext) {
+ auto* stmt = Assign("lhs", Call("DiscardOrNext"));
+ WrapInFunction(Decl(Var("lhs", ty.i32())), //
+ stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, StmtBlockEmpty) {
+ auto* stmt = Block();
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
+}
+
+TEST_F(ResolverBehaviorTest, StmtBlockSingleStmt) {
+ auto* stmt = Block(Discard());
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard);
+}
+
+TEST_F(ResolverBehaviorTest, StmtCallReturn) {
+ Func("f", {}, ty.void_(), {Return()});
+ auto* stmt = CallStmt(Call("f"));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
+}
+
+TEST_F(ResolverBehaviorTest, StmtCallFuncDiscard) {
+ Func("f", {}, ty.void_(), {Discard()});
+ auto* stmt = CallStmt(Call("f"));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard);
+}
+
+TEST_F(ResolverBehaviorTest, StmtCallFuncMayDiscard) {
+ auto* stmt = For(Decl(Var("v", ty.i32(), Call("DiscardOrNext"))), nullptr,
+ nullptr, Block(Break()));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, StmtBreak) {
+ auto* stmt = Break();
+ WrapInFunction(Loop(Block(stmt)));
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kBreak);
+}
+
+TEST_F(ResolverBehaviorTest, StmtContinue) {
+ auto* stmt = Continue();
+ WrapInFunction(Loop(Block(If(true, Block(Break())), //
+ stmt)));
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kContinue);
+}
+
+TEST_F(ResolverBehaviorTest, StmtDiscard) {
+ auto* stmt = Discard();
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard);
+}
+
+TEST_F(ResolverBehaviorTest, StmtForLoopEmpty_NoExit) {
+ auto* stmt = For(Source{{12, 34}}, nullptr, nullptr, nullptr, Block());
+ WrapInFunction(stmt);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: for-loop does not exit");
+}
+
+TEST_F(ResolverBehaviorTest, StmtForLoopBreak) {
+ auto* stmt = For(nullptr, nullptr, nullptr, Block(Break()));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
+}
+
+TEST_F(ResolverBehaviorTest, StmtForLoopContinue_NoExit) {
+ auto* stmt =
+ For(Source{{12, 34}}, nullptr, nullptr, nullptr, Block(Continue()));
+ WrapInFunction(stmt);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: for-loop does not exit");
+}
+
+TEST_F(ResolverBehaviorTest, StmtForLoopDiscard) {
+ auto* stmt = For(nullptr, nullptr, nullptr, Block(Discard()));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard);
+}
+
+TEST_F(ResolverBehaviorTest, StmtForLoopReturn) {
+ auto* stmt = For(nullptr, nullptr, nullptr, Block(Return()));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kReturn);
+}
+
+TEST_F(ResolverBehaviorTest, StmtForLoopBreak_InitCallFuncMayDiscard) {
+ auto* stmt = For(Decl(Var("v", ty.i32(), Call("DiscardOrNext"))), nullptr,
+ nullptr, Block(Break()));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, StmtForLoopEmpty_InitCallFuncMayDiscard) {
+ auto* stmt = For(Decl(Var("v", ty.i32(), Call("DiscardOrNext"))), nullptr,
+ nullptr, Block());
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard);
+}
+
+TEST_F(ResolverBehaviorTest, StmtForLoopEmpty_CondTrue) {
+ auto* stmt = For(nullptr, true, nullptr, Block());
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behaviors(sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, StmtForLoopEmpty_CondCallFuncMayDiscard) {
+ auto* stmt = For(nullptr, Equal(Call("DiscardOrNext"), 1), nullptr, Block());
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, StmtIfTrue_ThenEmptyBlock) {
+ auto* stmt = If(true, Block());
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
+}
+
+TEST_F(ResolverBehaviorTest, StmtIfTrue_ThenDiscard) {
+ auto* stmt = If(true, Block(Discard()));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, StmtIfTrue_ThenEmptyBlock_ElseDiscard) {
+ auto* stmt = If(true, Block(), Else(Block(Discard())));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, StmtIfTrue_ThenDiscard_ElseDiscard) {
+ auto* stmt = If(true, Block(Discard()), Else(Block(Discard())));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard);
+}
+
+TEST_F(ResolverBehaviorTest, StmtIfCallFuncMayDiscard_ThenEmptyBlock) {
+ auto* stmt = If(Equal(Call("DiscardOrNext"), 1), Block());
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, StmtIfTrue_ThenEmptyBlock_ElseCallFuncMayDiscard) {
+ auto* stmt = If(true, Block(), //
+ Else(Equal(Call("DiscardOrNext"), 1), Block()));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, StmtLetDecl) {
+ auto* stmt = Decl(Const("v", ty.i32(), Expr(1)));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
+}
+
+TEST_F(ResolverBehaviorTest, StmtLetDecl_RHSDiscardOrNext) {
+ auto* stmt = Decl(Const("lhs", ty.i32(), Call("DiscardOrNext")));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, StmtLoopEmpty_NoExit) {
+ auto* stmt = Loop(Source{{12, 34}}, Block());
+ WrapInFunction(stmt);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: loop does not exit");
+}
+
+TEST_F(ResolverBehaviorTest, StmtLoopBreak) {
+ auto* stmt = Loop(Block(Break()));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
+}
+
+TEST_F(ResolverBehaviorTest, StmtLoopContinue_NoExit) {
+ auto* stmt = Loop(Source{{12, 34}}, Block(Continue()));
+ WrapInFunction(stmt);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: loop does not exit");
+}
+
+TEST_F(ResolverBehaviorTest, StmtLoopDiscard) {
+ auto* stmt = Loop(Block(Discard()));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard);
+}
+
+TEST_F(ResolverBehaviorTest, StmtLoopReturn) {
+ auto* stmt = Loop(Block(Return()));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kReturn);
+}
+
+TEST_F(ResolverBehaviorTest, StmtLoopEmpty_ContEmpty_NoExit) {
+ auto* stmt = Loop(Source{{12, 34}}, Block(), Block());
+ WrapInFunction(stmt);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: loop does not exit");
+}
+
+TEST_F(ResolverBehaviorTest, StmtLoopEmpty_ContIfTrueBreak) {
+ auto* stmt = Loop(Block(), Block(If(true, Block(Break()))));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
+}
+
+TEST_F(ResolverBehaviorTest, StmtReturn) {
+ auto* stmt = Return();
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kReturn);
+}
+
+TEST_F(ResolverBehaviorTest, StmtReturn_DiscardOrNext) {
+ auto* stmt = Return(Call("DiscardOrNext"));
+ Func("F", {}, ty.i32(), {stmt});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kReturn, sem::Behavior::kDiscard));
+}
+
+TEST_F(ResolverBehaviorTest, StmtSwitch_CondTrue_DefaultEmpty) {
+ auto* stmt = Switch(1, DefaultCase(Block()));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
+}
+
+TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_DefaultEmpty) {
+ auto* stmt = Switch(1, DefaultCase(Block()));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
+}
+
+TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_DefaultDiscard) {
+ auto* stmt = Switch(1, DefaultCase(Block(Discard())));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard);
+}
+
+TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_DefaultReturn) {
+ auto* stmt = Switch(1, DefaultCase(Block(Return())));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kReturn);
+}
+
+TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Empty_DefaultEmpty) {
+ auto* stmt = Switch(1, Case(Expr(0), Block()), DefaultCase(Block()));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
+}
+
+TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Empty_DefaultDiscard) {
+ auto* stmt = Switch(1, Case(Expr(0), Block()), DefaultCase(Block(Discard())));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kNext, sem::Behavior::kDiscard));
+}
+
+TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Empty_DefaultReturn) {
+ auto* stmt = Switch(1, Case(Expr(0), Block()), DefaultCase(Block(Return())));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kNext, sem::Behavior::kReturn));
+}
+
+TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Discard_DefaultEmpty) {
+ auto* stmt = Switch(1, Case(Expr(0), Block(Discard())), DefaultCase(Block()));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest,
+ StmtSwitch_CondLiteral_Case0Discard_DefaultDiscard) {
+ auto* stmt =
+ Switch(1, Case(Expr(0), Block(Discard())), DefaultCase(Block(Discard())));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard);
+}
+
+TEST_F(ResolverBehaviorTest,
+ StmtSwitch_CondLiteral_Case0Discard_DefaultReturn) {
+ auto* stmt =
+ Switch(1, Case(Expr(0), Block(Discard())), DefaultCase(Block(Return())));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kReturn));
+}
+
+TEST_F(ResolverBehaviorTest,
+ StmtSwitch_CondLiteral_Case0Discard_Case1Return_DefaultEmpty) {
+ auto* stmt = Switch(1, //
+ Case(Expr(0), Block(Discard())), //
+ Case(Expr(1), Block(Return())), //
+ DefaultCase(Block()));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext,
+ sem::Behavior::kReturn));
+}
+
+TEST_F(ResolverBehaviorTest, StmtSwitch_CondCallFuncMayDiscard_DefaultEmpty) {
+ auto* stmt = Switch(Call("DiscardOrNext"), DefaultCase(Block()));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, StmtVarDecl) {
+ auto* stmt = Decl(Var("v", ty.i32()));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
+}
+
+TEST_F(ResolverBehaviorTest, StmtVarDecl_RHSDiscardOrNext) {
+ auto* stmt = Decl(Var("lhs", ty.i32(), Call("DiscardOrNext")));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+} // namespace
+} // namespace resolver
+} // namespace tint
diff --git a/src/tint/resolver/resolver_constants.cc b/src/tint/resolver/resolver_constants.cc
new file mode 100644
index 0000000..a83ae73
--- /dev/null
+++ b/src/tint/resolver/resolver_constants.cc
@@ -0,0 +1,144 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/resolver/resolver.h"
+
+#include "src/tint/sem/constant.h"
+#include "src/tint/sem/type_constructor.h"
+#include "src/tint/utils/map.h"
+
+namespace tint {
+namespace resolver {
+namespace {
+
+using i32 = ProgramBuilder::i32;
+using u32 = ProgramBuilder::u32;
+using f32 = ProgramBuilder::f32;
+
+} // namespace
+
+sem::Constant Resolver::EvaluateConstantValue(const ast::Expression* expr,
+ const sem::Type* type) {
+ if (auto* e = expr->As<ast::LiteralExpression>()) {
+ return EvaluateConstantValue(e, type);
+ }
+ if (auto* e = expr->As<ast::CallExpression>()) {
+ return EvaluateConstantValue(e, type);
+ }
+ return {};
+}
+
+sem::Constant Resolver::EvaluateConstantValue(
+ const ast::LiteralExpression* literal,
+ const sem::Type* type) {
+ if (auto* lit = literal->As<ast::SintLiteralExpression>()) {
+ return {type, {lit->ValueAsI32()}};
+ }
+ if (auto* lit = literal->As<ast::UintLiteralExpression>()) {
+ return {type, {lit->ValueAsU32()}};
+ }
+ if (auto* lit = literal->As<ast::FloatLiteralExpression>()) {
+ return {type, {lit->value}};
+ }
+ if (auto* lit = literal->As<ast::BoolLiteralExpression>()) {
+ return {type, {lit->value}};
+ }
+ TINT_UNREACHABLE(Resolver, builder_->Diagnostics());
+ return {};
+}
+
+sem::Constant Resolver::EvaluateConstantValue(const ast::CallExpression* call,
+ const sem::Type* type) {
+ auto* vec = type->As<sem::Vector>();
+
+ // For now, only fold scalars and vectors
+ if (!type->is_scalar() && !vec) {
+ return {};
+ }
+
+ auto* elem_type = vec ? vec->type() : type;
+ int result_size = vec ? static_cast<int>(vec->Width()) : 1;
+
+ // For zero value init, return 0s
+ if (call->args.empty()) {
+ if (elem_type->Is<sem::I32>()) {
+ return sem::Constant(type, sem::Constant::Scalars(result_size, 0));
+ }
+ if (elem_type->Is<sem::U32>()) {
+ return sem::Constant(type, sem::Constant::Scalars(result_size, 0u));
+ }
+ if (elem_type->Is<sem::F32>()) {
+ return sem::Constant(type, sem::Constant::Scalars(result_size, 0.f));
+ }
+ if (elem_type->Is<sem::Bool>()) {
+ return sem::Constant(type, sem::Constant::Scalars(result_size, false));
+ }
+ }
+
+ // Build value for type_ctor from each child value by casting to
+ // type_ctor's type.
+ sem::Constant::Scalars elems;
+ for (auto* expr : call->args) {
+ auto* arg = builder_->Sem().Get(expr);
+ if (!arg || !arg->ConstantValue()) {
+ return {};
+ }
+ auto cast = ConstantCast(arg->ConstantValue(), elem_type);
+ elems.insert(elems.end(), cast.Elements().begin(), cast.Elements().end());
+ }
+
+ // Splat single-value initializers
+ if (elems.size() == 1) {
+ for (int i = 0; i < result_size - 1; ++i) {
+ elems.emplace_back(elems[0]);
+ }
+ }
+
+ return sem::Constant(type, std::move(elems));
+}
+
+sem::Constant Resolver::ConstantCast(const sem::Constant& value,
+ const sem::Type* target_elem_type) {
+ if (value.ElementType() == target_elem_type) {
+ return value;
+ }
+
+ sem::Constant::Scalars elems;
+ for (size_t i = 0; i < value.Elements().size(); ++i) {
+ if (target_elem_type->Is<sem::I32>()) {
+ elems.emplace_back(
+ value.WithScalarAt(i, [](auto&& s) { return static_cast<i32>(s); }));
+ } else if (target_elem_type->Is<sem::U32>()) {
+ elems.emplace_back(
+ value.WithScalarAt(i, [](auto&& s) { return static_cast<u32>(s); }));
+ } else if (target_elem_type->Is<sem::F32>()) {
+ elems.emplace_back(
+ value.WithScalarAt(i, [](auto&& s) { return static_cast<f32>(s); }));
+ } else if (target_elem_type->Is<sem::Bool>()) {
+ elems.emplace_back(
+ value.WithScalarAt(i, [](auto&& s) { return static_cast<bool>(s); }));
+ }
+ }
+
+ auto* target_type =
+ value.Type()->Is<sem::Vector>()
+ ? builder_->create<sem::Vector>(target_elem_type,
+ static_cast<uint32_t>(elems.size()))
+ : target_elem_type;
+
+ return sem::Constant(target_type, elems);
+}
+
+} // namespace resolver
+} // namespace tint
diff --git a/src/tint/resolver/resolver_constants_test.cc b/src/tint/resolver/resolver_constants_test.cc
new file mode 100644
index 0000000..6d06bef
--- /dev/null
+++ b/src/tint/resolver/resolver_constants_test.cc
@@ -0,0 +1,433 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/resolver/resolver.h"
+
+#include "gtest/gtest.h"
+#include "src/tint/resolver/resolver_test_helper.h"
+#include "src/tint/sem/expression.h"
+
+namespace tint {
+namespace resolver {
+namespace {
+
+using Scalar = sem::Constant::Scalar;
+
+using ResolverConstantsTest = ResolverTest;
+
+TEST_F(ResolverConstantsTest, Scalar_i32) {
+ auto* expr = Expr(99);
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ EXPECT_NE(sem, nullptr);
+ EXPECT_TRUE(sem->Type()->Is<sem::I32>());
+ EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
+ EXPECT_EQ(sem->ConstantValue().ElementType(), sem->Type());
+ ASSERT_EQ(sem->ConstantValue().Elements().size(), 1u);
+ EXPECT_EQ(sem->ConstantValue().Elements()[0].i32, 99);
+}
+
+TEST_F(ResolverConstantsTest, Scalar_u32) {
+ auto* expr = Expr(99u);
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ EXPECT_NE(sem, nullptr);
+ EXPECT_TRUE(sem->Type()->Is<sem::U32>());
+ EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
+ EXPECT_EQ(sem->ConstantValue().ElementType(), sem->Type());
+ ASSERT_EQ(sem->ConstantValue().Elements().size(), 1u);
+ EXPECT_EQ(sem->ConstantValue().Elements()[0].u32, 99u);
+}
+
+TEST_F(ResolverConstantsTest, Scalar_f32) {
+ auto* expr = Expr(9.9f);
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ EXPECT_NE(sem, nullptr);
+ EXPECT_TRUE(sem->Type()->Is<sem::F32>());
+ EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
+ EXPECT_EQ(sem->ConstantValue().ElementType(), sem->Type());
+ ASSERT_EQ(sem->ConstantValue().Elements().size(), 1u);
+ EXPECT_EQ(sem->ConstantValue().Elements()[0].f32, 9.9f);
+}
+
+TEST_F(ResolverConstantsTest, Scalar_bool) {
+ auto* expr = Expr(true);
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ EXPECT_NE(sem, nullptr);
+ EXPECT_TRUE(sem->Type()->Is<sem::Bool>());
+ EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
+ EXPECT_EQ(sem->ConstantValue().ElementType(), sem->Type());
+ ASSERT_EQ(sem->ConstantValue().Elements().size(), 1u);
+ EXPECT_EQ(sem->ConstantValue().Elements()[0].bool_, true);
+}
+
+TEST_F(ResolverConstantsTest, Vec3_ZeroInit_i32) {
+ auto* expr = vec3<i32>();
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ EXPECT_NE(sem, nullptr);
+ ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
+ EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::I32>());
+ EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
+ EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
+ EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::I32>());
+ ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
+ EXPECT_EQ(sem->ConstantValue().Elements()[0].i32, 0);
+ EXPECT_EQ(sem->ConstantValue().Elements()[1].i32, 0);
+ EXPECT_EQ(sem->ConstantValue().Elements()[2].i32, 0);
+}
+
+TEST_F(ResolverConstantsTest, Vec3_ZeroInit_u32) {
+ auto* expr = vec3<u32>();
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ EXPECT_NE(sem, nullptr);
+ ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
+ EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::U32>());
+ EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
+ EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
+ EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::U32>());
+ ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
+ EXPECT_EQ(sem->ConstantValue().Elements()[0].u32, 0u);
+ EXPECT_EQ(sem->ConstantValue().Elements()[1].u32, 0u);
+ EXPECT_EQ(sem->ConstantValue().Elements()[2].u32, 0u);
+}
+
+TEST_F(ResolverConstantsTest, Vec3_ZeroInit_f32) {
+ auto* expr = vec3<f32>();
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ EXPECT_NE(sem, nullptr);
+ ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
+ EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
+ EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
+ EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F32>());
+ ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
+ EXPECT_EQ(sem->ConstantValue().Elements()[0].f32, 0u);
+ EXPECT_EQ(sem->ConstantValue().Elements()[1].f32, 0u);
+ EXPECT_EQ(sem->ConstantValue().Elements()[2].f32, 0u);
+}
+
+TEST_F(ResolverConstantsTest, Vec3_ZeroInit_bool) {
+ auto* expr = vec3<bool>();
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ EXPECT_NE(sem, nullptr);
+ ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
+ EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::Bool>());
+ EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
+ EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
+ EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::Bool>());
+ ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
+ EXPECT_EQ(sem->ConstantValue().Elements()[0].bool_, false);
+ EXPECT_EQ(sem->ConstantValue().Elements()[1].bool_, false);
+ EXPECT_EQ(sem->ConstantValue().Elements()[2].bool_, false);
+}
+
+TEST_F(ResolverConstantsTest, Vec3_Splat_i32) {
+ auto* expr = vec3<i32>(99);
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ EXPECT_NE(sem, nullptr);
+ ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
+ EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::I32>());
+ EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
+ EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
+ EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::I32>());
+ ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
+ EXPECT_EQ(sem->ConstantValue().Elements()[0].i32, 99);
+ EXPECT_EQ(sem->ConstantValue().Elements()[1].i32, 99);
+ EXPECT_EQ(sem->ConstantValue().Elements()[2].i32, 99);
+}
+
+TEST_F(ResolverConstantsTest, Vec3_Splat_u32) {
+ auto* expr = vec3<u32>(99u);
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ EXPECT_NE(sem, nullptr);
+ ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
+ EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::U32>());
+ EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
+ EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
+ EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::U32>());
+ ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
+ EXPECT_EQ(sem->ConstantValue().Elements()[0].u32, 99u);
+ EXPECT_EQ(sem->ConstantValue().Elements()[1].u32, 99u);
+ EXPECT_EQ(sem->ConstantValue().Elements()[2].u32, 99u);
+}
+
+TEST_F(ResolverConstantsTest, Vec3_Splat_f32) {
+ auto* expr = vec3<f32>(9.9f);
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ EXPECT_NE(sem, nullptr);
+ ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
+ EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
+ EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
+ EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F32>());
+ ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
+ EXPECT_EQ(sem->ConstantValue().Elements()[0].f32, 9.9f);
+ EXPECT_EQ(sem->ConstantValue().Elements()[1].f32, 9.9f);
+ EXPECT_EQ(sem->ConstantValue().Elements()[2].f32, 9.9f);
+}
+
+TEST_F(ResolverConstantsTest, Vec3_Splat_bool) {
+ auto* expr = vec3<bool>(true);
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ EXPECT_NE(sem, nullptr);
+ ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
+ EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::Bool>());
+ EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
+ EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
+ EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::Bool>());
+ ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
+ EXPECT_EQ(sem->ConstantValue().Elements()[0].bool_, true);
+ EXPECT_EQ(sem->ConstantValue().Elements()[1].bool_, true);
+ EXPECT_EQ(sem->ConstantValue().Elements()[2].bool_, true);
+}
+
+TEST_F(ResolverConstantsTest, Vec3_FullConstruct_i32) {
+ auto* expr = vec3<i32>(1, 2, 3);
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ EXPECT_NE(sem, nullptr);
+ ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
+ EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::I32>());
+ EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
+ EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
+ EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::I32>());
+ ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
+ EXPECT_EQ(sem->ConstantValue().Elements()[0].i32, 1);
+ EXPECT_EQ(sem->ConstantValue().Elements()[1].i32, 2);
+ EXPECT_EQ(sem->ConstantValue().Elements()[2].i32, 3);
+}
+
+TEST_F(ResolverConstantsTest, Vec3_FullConstruct_u32) {
+ auto* expr = vec3<u32>(1u, 2u, 3u);
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ EXPECT_NE(sem, nullptr);
+ ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
+ EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::U32>());
+ EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
+ EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
+ EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::U32>());
+ ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
+ EXPECT_EQ(sem->ConstantValue().Elements()[0].u32, 1u);
+ EXPECT_EQ(sem->ConstantValue().Elements()[1].u32, 2u);
+ EXPECT_EQ(sem->ConstantValue().Elements()[2].u32, 3u);
+}
+
+TEST_F(ResolverConstantsTest, Vec3_FullConstruct_f32) {
+ auto* expr = vec3<f32>(1.f, 2.f, 3.f);
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ EXPECT_NE(sem, nullptr);
+ ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
+ EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
+ EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
+ EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F32>());
+ ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
+ EXPECT_EQ(sem->ConstantValue().Elements()[0].f32, 1.f);
+ EXPECT_EQ(sem->ConstantValue().Elements()[1].f32, 2.f);
+ EXPECT_EQ(sem->ConstantValue().Elements()[2].f32, 3.f);
+}
+
+TEST_F(ResolverConstantsTest, Vec3_FullConstruct_bool) {
+ auto* expr = vec3<bool>(true, false, true);
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ EXPECT_NE(sem, nullptr);
+ ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
+ EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::Bool>());
+ EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
+ EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
+ EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::Bool>());
+ ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
+ EXPECT_EQ(sem->ConstantValue().Elements()[0].bool_, true);
+ EXPECT_EQ(sem->ConstantValue().Elements()[1].bool_, false);
+ EXPECT_EQ(sem->ConstantValue().Elements()[2].bool_, true);
+}
+
+TEST_F(ResolverConstantsTest, Vec3_MixConstruct_i32) {
+ auto* expr = vec3<i32>(1, vec2<i32>(2, 3));
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ EXPECT_NE(sem, nullptr);
+ ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
+ EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::I32>());
+ EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
+ EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
+ EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::I32>());
+ ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
+ EXPECT_EQ(sem->ConstantValue().Elements()[0].i32, 1);
+ EXPECT_EQ(sem->ConstantValue().Elements()[1].i32, 2);
+ EXPECT_EQ(sem->ConstantValue().Elements()[2].i32, 3);
+}
+
+TEST_F(ResolverConstantsTest, Vec3_MixConstruct_u32) {
+ auto* expr = vec3<u32>(vec2<u32>(1u, 2u), 3u);
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ EXPECT_NE(sem, nullptr);
+ ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
+ EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::U32>());
+ EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
+ EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
+ EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::U32>());
+ ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
+ EXPECT_EQ(sem->ConstantValue().Elements()[0].u32, 1u);
+ EXPECT_EQ(sem->ConstantValue().Elements()[1].u32, 2u);
+ EXPECT_EQ(sem->ConstantValue().Elements()[2].u32, 3u);
+}
+
+TEST_F(ResolverConstantsTest, Vec3_MixConstruct_f32) {
+ auto* expr = vec3<f32>(1.f, vec2<f32>(2.f, 3.f));
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ EXPECT_NE(sem, nullptr);
+ ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
+ EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
+ EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
+ EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F32>());
+ ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
+ EXPECT_EQ(sem->ConstantValue().Elements()[0].f32, 1.f);
+ EXPECT_EQ(sem->ConstantValue().Elements()[1].f32, 2.f);
+ EXPECT_EQ(sem->ConstantValue().Elements()[2].f32, 3.f);
+}
+
+TEST_F(ResolverConstantsTest, Vec3_MixConstruct_bool) {
+ auto* expr = vec3<bool>(vec2<bool>(true, false), true);
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ EXPECT_NE(sem, nullptr);
+ ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
+ EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::Bool>());
+ EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
+ EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
+ EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::Bool>());
+ ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
+ EXPECT_EQ(sem->ConstantValue().Elements()[0].bool_, true);
+ EXPECT_EQ(sem->ConstantValue().Elements()[1].bool_, false);
+ EXPECT_EQ(sem->ConstantValue().Elements()[2].bool_, true);
+}
+
+TEST_F(ResolverConstantsTest, Vec3_Cast_f32_to_32) {
+ auto* expr = vec3<i32>(vec3<f32>(1.1f, 2.2f, 3.3f));
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ EXPECT_NE(sem, nullptr);
+ ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
+ EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::I32>());
+ EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
+ EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
+ EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::I32>());
+ ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
+ EXPECT_EQ(sem->ConstantValue().Elements()[0].i32, 1);
+ EXPECT_EQ(sem->ConstantValue().Elements()[1].i32, 2);
+ EXPECT_EQ(sem->ConstantValue().Elements()[2].i32, 3);
+}
+
+TEST_F(ResolverConstantsTest, Vec3_Cast_u32_to_f32) {
+ auto* expr = vec3<f32>(vec3<u32>(10u, 20u, 30u));
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ EXPECT_NE(sem, nullptr);
+ ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
+ EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
+ EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
+ EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F32>());
+ ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
+ EXPECT_EQ(sem->ConstantValue().Elements()[0].f32, 10.f);
+ EXPECT_EQ(sem->ConstantValue().Elements()[1].f32, 20.f);
+ EXPECT_EQ(sem->ConstantValue().Elements()[2].f32, 30.f);
+}
+
+} // namespace
+} // namespace resolver
+} // namespace tint
diff --git a/src/tint/resolver/resolver_test.cc b/src/tint/resolver/resolver_test.cc
new file mode 100644
index 0000000..ec0b26d
--- /dev/null
+++ b/src/tint/resolver/resolver_test.cc
@@ -0,0 +1,2189 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/resolver/resolver.h"
+
+#include <tuple>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest-spi.h"
+#include "src/tint/ast/assignment_statement.h"
+#include "src/tint/ast/bitcast_expression.h"
+#include "src/tint/ast/break_statement.h"
+#include "src/tint/ast/builtin_texture_helper_test.h"
+#include "src/tint/ast/call_statement.h"
+#include "src/tint/ast/continue_statement.h"
+#include "src/tint/ast/float_literal_expression.h"
+#include "src/tint/ast/id_attribute.h"
+#include "src/tint/ast/if_statement.h"
+#include "src/tint/ast/loop_statement.h"
+#include "src/tint/ast/return_statement.h"
+#include "src/tint/ast/stage_attribute.h"
+#include "src/tint/ast/struct_block_attribute.h"
+#include "src/tint/ast/switch_statement.h"
+#include "src/tint/ast/unary_op_expression.h"
+#include "src/tint/ast/variable_decl_statement.h"
+#include "src/tint/ast/workgroup_attribute.h"
+#include "src/tint/resolver/resolver_test_helper.h"
+#include "src/tint/sem/call.h"
+#include "src/tint/sem/function.h"
+#include "src/tint/sem/member_accessor_expression.h"
+#include "src/tint/sem/module.h"
+#include "src/tint/sem/reference_type.h"
+#include "src/tint/sem/sampled_texture_type.h"
+#include "src/tint/sem/statement.h"
+#include "src/tint/sem/variable.h"
+
+using ::testing::ElementsAre;
+using ::testing::HasSubstr;
+
+namespace tint {
+namespace resolver {
+namespace {
+
+// Helpers and typedefs
+template <typename T>
+using DataType = builder::DataType<T>;
+template <int N, typename T>
+using vec = builder::vec<N, T>;
+template <typename T>
+using vec2 = builder::vec2<T>;
+template <typename T>
+using vec3 = builder::vec3<T>;
+template <typename T>
+using vec4 = builder::vec4<T>;
+template <int N, int M, typename T>
+using mat = builder::mat<N, M, T>;
+template <typename T>
+using mat2x2 = builder::mat2x2<T>;
+template <typename T>
+using mat2x3 = builder::mat2x3<T>;
+template <typename T>
+using mat3x2 = builder::mat3x2<T>;
+template <typename T>
+using mat3x3 = builder::mat3x3<T>;
+template <typename T>
+using mat4x4 = builder::mat4x4<T>;
+template <typename T, int ID = 0>
+using alias = builder::alias<T, ID>;
+template <typename T>
+using alias1 = builder::alias1<T>;
+template <typename T>
+using alias2 = builder::alias2<T>;
+template <typename T>
+using alias3 = builder::alias3<T>;
+using f32 = builder::f32;
+using i32 = builder::i32;
+using u32 = builder::u32;
+using Op = ast::BinaryOp;
+
+TEST_F(ResolverTest, Stmt_Assign) {
+ auto* v = Var("v", ty.f32());
+ auto* lhs = Expr("v");
+ auto* rhs = Expr(2.3f);
+
+ auto* assign = Assign(lhs, rhs);
+ WrapInFunction(v, assign);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(lhs), nullptr);
+ ASSERT_NE(TypeOf(rhs), nullptr);
+
+ EXPECT_TRUE(TypeOf(lhs)->UnwrapRef()->Is<sem::F32>());
+ EXPECT_TRUE(TypeOf(rhs)->Is<sem::F32>());
+ EXPECT_EQ(StmtOf(lhs), assign);
+ EXPECT_EQ(StmtOf(rhs), assign);
+}
+
+TEST_F(ResolverTest, Stmt_Case) {
+ auto* v = Var("v", ty.f32());
+ auto* lhs = Expr("v");
+ auto* rhs = Expr(2.3f);
+
+ auto* assign = Assign(lhs, rhs);
+ auto* block = Block(assign);
+ ast::CaseSelectorList lit;
+ lit.push_back(create<ast::SintLiteralExpression>(3));
+ auto* cse = create<ast::CaseStatement>(lit, block);
+ auto* cond_var = Var("c", ty.i32());
+ auto* sw = Switch(cond_var, cse, DefaultCase());
+ WrapInFunction(v, cond_var, sw);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(lhs), nullptr);
+ ASSERT_NE(TypeOf(rhs), nullptr);
+ EXPECT_TRUE(TypeOf(lhs)->UnwrapRef()->Is<sem::F32>());
+ EXPECT_TRUE(TypeOf(rhs)->Is<sem::F32>());
+ EXPECT_EQ(StmtOf(lhs), assign);
+ EXPECT_EQ(StmtOf(rhs), assign);
+ EXPECT_EQ(BlockOf(assign), block);
+}
+
+TEST_F(ResolverTest, Stmt_Block) {
+ auto* v = Var("v", ty.f32());
+ auto* lhs = Expr("v");
+ auto* rhs = Expr(2.3f);
+
+ auto* assign = Assign(lhs, rhs);
+ auto* block = Block(assign);
+ WrapInFunction(v, block);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(lhs), nullptr);
+ ASSERT_NE(TypeOf(rhs), nullptr);
+ EXPECT_TRUE(TypeOf(lhs)->UnwrapRef()->Is<sem::F32>());
+ EXPECT_TRUE(TypeOf(rhs)->Is<sem::F32>());
+ EXPECT_EQ(StmtOf(lhs), assign);
+ EXPECT_EQ(StmtOf(rhs), assign);
+ EXPECT_EQ(BlockOf(lhs), block);
+ EXPECT_EQ(BlockOf(rhs), block);
+ EXPECT_EQ(BlockOf(assign), block);
+}
+
+TEST_F(ResolverTest, Stmt_If) {
+ auto* v = Var("v", ty.f32());
+ auto* else_lhs = Expr("v");
+ auto* else_rhs = Expr(2.3f);
+
+ auto* else_body = Block(Assign(else_lhs, else_rhs));
+
+ auto* else_cond = Expr(true);
+ auto* else_stmt = create<ast::ElseStatement>(else_cond, else_body);
+
+ auto* lhs = Expr("v");
+ auto* rhs = Expr(2.3f);
+
+ auto* assign = Assign(lhs, rhs);
+ auto* body = Block(assign);
+ auto* cond = Expr(true);
+ auto* stmt =
+ create<ast::IfStatement>(cond, body, ast::ElseStatementList{else_stmt});
+ WrapInFunction(v, stmt);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(stmt->condition), nullptr);
+ ASSERT_NE(TypeOf(else_lhs), nullptr);
+ ASSERT_NE(TypeOf(else_rhs), nullptr);
+ ASSERT_NE(TypeOf(lhs), nullptr);
+ ASSERT_NE(TypeOf(rhs), nullptr);
+ EXPECT_TRUE(TypeOf(stmt->condition)->Is<sem::Bool>());
+ EXPECT_TRUE(TypeOf(else_lhs)->UnwrapRef()->Is<sem::F32>());
+ EXPECT_TRUE(TypeOf(else_rhs)->Is<sem::F32>());
+ EXPECT_TRUE(TypeOf(lhs)->UnwrapRef()->Is<sem::F32>());
+ EXPECT_TRUE(TypeOf(rhs)->Is<sem::F32>());
+ EXPECT_EQ(StmtOf(lhs), assign);
+ EXPECT_EQ(StmtOf(rhs), assign);
+ EXPECT_EQ(StmtOf(cond), stmt);
+ EXPECT_EQ(StmtOf(else_cond), else_stmt);
+ EXPECT_EQ(BlockOf(lhs), body);
+ EXPECT_EQ(BlockOf(rhs), body);
+ EXPECT_EQ(BlockOf(else_lhs), else_body);
+ EXPECT_EQ(BlockOf(else_rhs), else_body);
+}
+
+TEST_F(ResolverTest, Stmt_Loop) {
+ auto* v = Var("v", ty.f32());
+ auto* body_lhs = Expr("v");
+ auto* body_rhs = Expr(2.3f);
+
+ auto* body = Block(Assign(body_lhs, body_rhs), Break());
+ auto* continuing_lhs = Expr("v");
+ auto* continuing_rhs = Expr(2.3f);
+
+ auto* continuing = Block(Assign(continuing_lhs, continuing_rhs));
+ auto* stmt = Loop(body, continuing);
+ WrapInFunction(v, stmt);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(body_lhs), nullptr);
+ ASSERT_NE(TypeOf(body_rhs), nullptr);
+ ASSERT_NE(TypeOf(continuing_lhs), nullptr);
+ ASSERT_NE(TypeOf(continuing_rhs), nullptr);
+ EXPECT_TRUE(TypeOf(body_lhs)->UnwrapRef()->Is<sem::F32>());
+ EXPECT_TRUE(TypeOf(body_rhs)->Is<sem::F32>());
+ EXPECT_TRUE(TypeOf(continuing_lhs)->UnwrapRef()->Is<sem::F32>());
+ EXPECT_TRUE(TypeOf(continuing_rhs)->Is<sem::F32>());
+ EXPECT_EQ(BlockOf(body_lhs), body);
+ EXPECT_EQ(BlockOf(body_rhs), body);
+ EXPECT_EQ(BlockOf(continuing_lhs), continuing);
+ EXPECT_EQ(BlockOf(continuing_rhs), continuing);
+}
+
+TEST_F(ResolverTest, Stmt_Return) {
+ auto* cond = Expr(2);
+
+ auto* ret = Return(cond);
+ Func("test", {}, ty.i32(), {ret}, {});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(cond), nullptr);
+ EXPECT_TRUE(TypeOf(cond)->Is<sem::I32>());
+}
+
+TEST_F(ResolverTest, Stmt_Return_WithoutValue) {
+ auto* ret = Return();
+ WrapInFunction(ret);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverTest, Stmt_Switch) {
+ auto* v = Var("v", ty.f32());
+ auto* lhs = Expr("v");
+ auto* rhs = Expr(2.3f);
+ auto* case_block = Block(Assign(lhs, rhs));
+ auto* stmt = Switch(Expr(2), Case(Expr(3), case_block), DefaultCase());
+ WrapInFunction(v, stmt);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(stmt->condition), nullptr);
+ ASSERT_NE(TypeOf(lhs), nullptr);
+ ASSERT_NE(TypeOf(rhs), nullptr);
+
+ EXPECT_TRUE(TypeOf(stmt->condition)->Is<sem::I32>());
+ EXPECT_TRUE(TypeOf(lhs)->UnwrapRef()->Is<sem::F32>());
+ EXPECT_TRUE(TypeOf(rhs)->Is<sem::F32>());
+ EXPECT_EQ(BlockOf(lhs), case_block);
+ EXPECT_EQ(BlockOf(rhs), case_block);
+}
+
+TEST_F(ResolverTest, Stmt_Call) {
+ ast::VariableList params;
+ Func("my_func", params, ty.void_(), {Return()}, ast::AttributeList{});
+
+ auto* expr = Call("my_func");
+
+ auto* call = CallStmt(expr);
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(expr), nullptr);
+ EXPECT_TRUE(TypeOf(expr)->Is<sem::Void>());
+ EXPECT_EQ(StmtOf(expr), call);
+}
+
+TEST_F(ResolverTest, Stmt_VariableDecl) {
+ auto* var = Var("my_var", ty.i32(), ast::StorageClass::kNone, Expr(2));
+ auto* init = var->constructor;
+
+ auto* decl = Decl(var);
+ WrapInFunction(decl);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(init), nullptr);
+ EXPECT_TRUE(TypeOf(init)->Is<sem::I32>());
+}
+
+TEST_F(ResolverTest, Stmt_VariableDecl_Alias) {
+ auto* my_int = Alias("MyInt", ty.i32());
+ auto* var = Var("my_var", ty.Of(my_int), ast::StorageClass::kNone, Expr(2));
+ auto* init = var->constructor;
+
+ auto* decl = Decl(var);
+ WrapInFunction(decl);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(init), nullptr);
+ EXPECT_TRUE(TypeOf(init)->Is<sem::I32>());
+}
+
+TEST_F(ResolverTest, Stmt_VariableDecl_ModuleScope) {
+ auto* init = Expr(2);
+ Global("my_var", ty.i32(), ast::StorageClass::kPrivate, init);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(init), nullptr);
+ EXPECT_TRUE(TypeOf(init)->Is<sem::I32>());
+ EXPECT_EQ(StmtOf(init), nullptr);
+}
+
+TEST_F(ResolverTest, Stmt_VariableDecl_OuterScopeAfterInnerScope) {
+ // fn func_i32() {
+ // {
+ // var foo : i32 = 2;
+ // var bar : i32 = foo;
+ // }
+ // var foo : f32 = 2.0;
+ // var bar : f32 = foo;
+ // }
+
+ ast::VariableList params;
+
+ // Declare i32 "foo" inside a block
+ auto* foo_i32 = Var("foo", ty.i32(), ast::StorageClass::kNone, Expr(2));
+ auto* foo_i32_init = foo_i32->constructor;
+ auto* foo_i32_decl = Decl(foo_i32);
+
+ // Reference "foo" inside the block
+ auto* bar_i32 = Var("bar", ty.i32(), ast::StorageClass::kNone, Expr("foo"));
+ auto* bar_i32_init = bar_i32->constructor;
+ auto* bar_i32_decl = Decl(bar_i32);
+
+ auto* inner = Block(foo_i32_decl, bar_i32_decl);
+
+ // Declare f32 "foo" at function scope
+ auto* foo_f32 = Var("foo", ty.f32(), ast::StorageClass::kNone, Expr(2.f));
+ auto* foo_f32_init = foo_f32->constructor;
+ auto* foo_f32_decl = Decl(foo_f32);
+
+ // Reference "foo" at function scope
+ auto* bar_f32 = Var("bar", ty.f32(), ast::StorageClass::kNone, Expr("foo"));
+ auto* bar_f32_init = bar_f32->constructor;
+ auto* bar_f32_decl = Decl(bar_f32);
+
+ Func("func", params, ty.void_(), {inner, foo_f32_decl, bar_f32_decl},
+ ast::AttributeList{});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_NE(TypeOf(foo_i32_init), nullptr);
+ EXPECT_TRUE(TypeOf(foo_i32_init)->Is<sem::I32>());
+ ASSERT_NE(TypeOf(foo_f32_init), nullptr);
+ EXPECT_TRUE(TypeOf(foo_f32_init)->Is<sem::F32>());
+ ASSERT_NE(TypeOf(bar_i32_init), nullptr);
+ EXPECT_TRUE(TypeOf(bar_i32_init)->UnwrapRef()->Is<sem::I32>());
+ ASSERT_NE(TypeOf(bar_f32_init), nullptr);
+ EXPECT_TRUE(TypeOf(bar_f32_init)->UnwrapRef()->Is<sem::F32>());
+ EXPECT_EQ(StmtOf(foo_i32_init), foo_i32_decl);
+ EXPECT_EQ(StmtOf(bar_i32_init), bar_i32_decl);
+ EXPECT_EQ(StmtOf(foo_f32_init), foo_f32_decl);
+ EXPECT_EQ(StmtOf(bar_f32_init), bar_f32_decl);
+ EXPECT_TRUE(CheckVarUsers(foo_i32, {bar_i32->constructor}));
+ EXPECT_TRUE(CheckVarUsers(foo_f32, {bar_f32->constructor}));
+ ASSERT_NE(VarOf(bar_i32->constructor), nullptr);
+ EXPECT_EQ(VarOf(bar_i32->constructor)->Declaration(), foo_i32);
+ ASSERT_NE(VarOf(bar_f32->constructor), nullptr);
+ EXPECT_EQ(VarOf(bar_f32->constructor)->Declaration(), foo_f32);
+}
+
+TEST_F(ResolverTest, Stmt_VariableDecl_ModuleScopeAfterFunctionScope) {
+ // fn func_i32() {
+ // var foo : i32 = 2;
+ // }
+ // var foo : f32 = 2.0;
+ // fn func_f32() {
+ // var bar : f32 = foo;
+ // }
+
+ ast::VariableList params;
+
+ // Declare i32 "foo" inside a function
+ auto* fn_i32 = Var("foo", ty.i32(), ast::StorageClass::kNone, Expr(2));
+ auto* fn_i32_init = fn_i32->constructor;
+ auto* fn_i32_decl = Decl(fn_i32);
+ Func("func_i32", params, ty.void_(), {fn_i32_decl}, ast::AttributeList{});
+
+ // Declare f32 "foo" at module scope
+ auto* mod_f32 = Var("foo", ty.f32(), ast::StorageClass::kPrivate, Expr(2.f));
+ auto* mod_init = mod_f32->constructor;
+ AST().AddGlobalVariable(mod_f32);
+
+ // Reference "foo" in another function
+ auto* fn_f32 = Var("bar", ty.f32(), ast::StorageClass::kNone, Expr("foo"));
+ auto* fn_f32_init = fn_f32->constructor;
+ auto* fn_f32_decl = Decl(fn_f32);
+ Func("func_f32", params, ty.void_(), {fn_f32_decl}, ast::AttributeList{});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_NE(TypeOf(mod_init), nullptr);
+ EXPECT_TRUE(TypeOf(mod_init)->Is<sem::F32>());
+ ASSERT_NE(TypeOf(fn_i32_init), nullptr);
+ EXPECT_TRUE(TypeOf(fn_i32_init)->Is<sem::I32>());
+ ASSERT_NE(TypeOf(fn_f32_init), nullptr);
+ EXPECT_TRUE(TypeOf(fn_f32_init)->UnwrapRef()->Is<sem::F32>());
+ EXPECT_EQ(StmtOf(fn_i32_init), fn_i32_decl);
+ EXPECT_EQ(StmtOf(mod_init), nullptr);
+ EXPECT_EQ(StmtOf(fn_f32_init), fn_f32_decl);
+ EXPECT_TRUE(CheckVarUsers(fn_i32, {}));
+ EXPECT_TRUE(CheckVarUsers(mod_f32, {fn_f32->constructor}));
+ ASSERT_NE(VarOf(fn_f32->constructor), nullptr);
+ EXPECT_EQ(VarOf(fn_f32->constructor)->Declaration(), mod_f32);
+}
+
+TEST_F(ResolverTest, ArraySize_UnsignedLiteral) {
+ // var<private> a : array<f32, 10u>;
+ auto* a =
+ Global("a", ty.array(ty.f32(), Expr(10u)), ast::StorageClass::kPrivate);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(a), nullptr);
+ auto* ref = TypeOf(a)->As<sem::Reference>();
+ ASSERT_NE(ref, nullptr);
+ auto* ary = ref->StoreType()->As<sem::Array>();
+ EXPECT_EQ(ary->Count(), 10u);
+}
+
+TEST_F(ResolverTest, ArraySize_SignedLiteral) {
+ // var<private> a : array<f32, 10>;
+ auto* a =
+ Global("a", ty.array(ty.f32(), Expr(10)), ast::StorageClass::kPrivate);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(a), nullptr);
+ auto* ref = TypeOf(a)->As<sem::Reference>();
+ ASSERT_NE(ref, nullptr);
+ auto* ary = ref->StoreType()->As<sem::Array>();
+ EXPECT_EQ(ary->Count(), 10u);
+}
+
+TEST_F(ResolverTest, ArraySize_UnsignedConstant) {
+ // let size = 0u;
+ // var<private> a : array<f32, 10u>;
+ GlobalConst("size", nullptr, Expr(10u));
+ auto* a = Global("a", ty.array(ty.f32(), Expr("size")),
+ ast::StorageClass::kPrivate);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(a), nullptr);
+ auto* ref = TypeOf(a)->As<sem::Reference>();
+ ASSERT_NE(ref, nullptr);
+ auto* ary = ref->StoreType()->As<sem::Array>();
+ EXPECT_EQ(ary->Count(), 10u);
+}
+
+TEST_F(ResolverTest, ArraySize_SignedConstant) {
+ // let size = 0;
+ // var<private> a : array<f32, 10>;
+ GlobalConst("size", nullptr, Expr(10));
+ auto* a = Global("a", ty.array(ty.f32(), Expr("size")),
+ ast::StorageClass::kPrivate);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(a), nullptr);
+ auto* ref = TypeOf(a)->As<sem::Reference>();
+ ASSERT_NE(ref, nullptr);
+ auto* ary = ref->StoreType()->As<sem::Array>();
+ EXPECT_EQ(ary->Count(), 10u);
+}
+
+TEST_F(ResolverTest, Expr_Bitcast) {
+ Global("name", ty.f32(), ast::StorageClass::kPrivate);
+
+ auto* bitcast = create<ast::BitcastExpression>(ty.f32(), Expr("name"));
+ WrapInFunction(bitcast);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(bitcast), nullptr);
+ EXPECT_TRUE(TypeOf(bitcast)->Is<sem::F32>());
+}
+
+TEST_F(ResolverTest, Expr_Call) {
+ ast::VariableList params;
+ Func("my_func", params, ty.f32(), {Return(0.0f)}, ast::AttributeList{});
+
+ auto* call = Call("my_func");
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(call), nullptr);
+ EXPECT_TRUE(TypeOf(call)->Is<sem::F32>());
+}
+
+TEST_F(ResolverTest, Expr_Call_InBinaryOp) {
+ ast::VariableList params;
+ Func("func", params, ty.f32(), {Return(0.0f)}, ast::AttributeList{});
+
+ auto* expr = Add(Call("func"), Call("func"));
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(expr), nullptr);
+ EXPECT_TRUE(TypeOf(expr)->Is<sem::F32>());
+}
+
+TEST_F(ResolverTest, Expr_Call_WithParams) {
+ Func("my_func", {Param(Sym(), ty.f32())}, ty.f32(),
+ {
+ Return(1.2f),
+ });
+
+ auto* param = Expr(2.4f);
+
+ auto* call = Call("my_func", param);
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(param), nullptr);
+ EXPECT_TRUE(TypeOf(param)->Is<sem::F32>());
+}
+
+TEST_F(ResolverTest, Expr_Call_Builtin) {
+ auto* call = Call("round", 2.4f);
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(call), nullptr);
+ EXPECT_TRUE(TypeOf(call)->Is<sem::F32>());
+}
+
+TEST_F(ResolverTest, Expr_Cast) {
+ Global("name", ty.f32(), ast::StorageClass::kPrivate);
+
+ auto* cast = Construct(ty.f32(), "name");
+ WrapInFunction(cast);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(cast), nullptr);
+ EXPECT_TRUE(TypeOf(cast)->Is<sem::F32>());
+}
+
+TEST_F(ResolverTest, Expr_Constructor_Scalar) {
+ auto* s = Expr(1.0f);
+ WrapInFunction(s);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(s), nullptr);
+ EXPECT_TRUE(TypeOf(s)->Is<sem::F32>());
+}
+
+TEST_F(ResolverTest, Expr_Constructor_Type_Vec2) {
+ auto* tc = vec2<f32>(1.0f, 1.0f);
+ WrapInFunction(tc);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(tc), nullptr);
+ ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 2u);
+}
+
+TEST_F(ResolverTest, Expr_Constructor_Type_Vec3) {
+ auto* tc = vec3<f32>(1.0f, 1.0f, 1.0f);
+ WrapInFunction(tc);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(tc), nullptr);
+ ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 3u);
+}
+
+TEST_F(ResolverTest, Expr_Constructor_Type_Vec4) {
+ auto* tc = vec4<f32>(1.0f, 1.0f, 1.0f, 1.0f);
+ WrapInFunction(tc);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(tc), nullptr);
+ ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 4u);
+}
+
+TEST_F(ResolverTest, Expr_Identifier_GlobalVariable) {
+ auto* my_var = Global("my_var", ty.f32(), ast::StorageClass::kPrivate);
+
+ auto* ident = Expr("my_var");
+ WrapInFunction(ident);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(ident), nullptr);
+ ASSERT_TRUE(TypeOf(ident)->Is<sem::Reference>());
+ EXPECT_TRUE(TypeOf(ident)->UnwrapRef()->Is<sem::F32>());
+ EXPECT_TRUE(CheckVarUsers(my_var, {ident}));
+ ASSERT_NE(VarOf(ident), nullptr);
+ EXPECT_EQ(VarOf(ident)->Declaration(), my_var);
+}
+
+TEST_F(ResolverTest, Expr_Identifier_GlobalConstant) {
+ auto* my_var = GlobalConst("my_var", ty.f32(), Construct(ty.f32()));
+
+ auto* ident = Expr("my_var");
+ WrapInFunction(ident);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(ident), nullptr);
+ EXPECT_TRUE(TypeOf(ident)->Is<sem::F32>());
+ EXPECT_TRUE(CheckVarUsers(my_var, {ident}));
+ ASSERT_NE(VarOf(ident), nullptr);
+ EXPECT_EQ(VarOf(ident)->Declaration(), my_var);
+}
+
+TEST_F(ResolverTest, Expr_Identifier_FunctionVariable_Const) {
+ auto* my_var_a = Expr("my_var");
+ auto* var = Const("my_var", ty.f32(), Construct(ty.f32()));
+ auto* decl = Decl(Var("b", ty.f32(), ast::StorageClass::kNone, my_var_a));
+
+ Func("my_func", ast::VariableList{}, ty.void_(),
+ {
+ Decl(var),
+ decl,
+ },
+ ast::AttributeList{});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(my_var_a), nullptr);
+ EXPECT_TRUE(TypeOf(my_var_a)->Is<sem::F32>());
+ EXPECT_EQ(StmtOf(my_var_a), decl);
+ EXPECT_TRUE(CheckVarUsers(var, {my_var_a}));
+ ASSERT_NE(VarOf(my_var_a), nullptr);
+ EXPECT_EQ(VarOf(my_var_a)->Declaration(), var);
+}
+
+TEST_F(ResolverTest, IndexAccessor_Dynamic_Ref_F32) {
+ // var a : array<bool, 10> = 0;
+ // var idx : f32 = f32();
+ // var f : f32 = a[idx];
+ auto* a = Var("a", ty.array<bool, 10>(), array<bool, 10>());
+ auto* idx = Var("idx", ty.f32(), Construct(ty.f32()));
+ auto* f = Var("f", ty.f32(), IndexAccessor("a", Expr(Source{{12, 34}}, idx)));
+ Func("my_func", ast::VariableList{}, ty.void_(),
+ {
+ Decl(a),
+ Decl(idx),
+ Decl(f),
+ },
+ ast::AttributeList{});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: index must be of type 'i32' or 'u32', found: 'f32'");
+}
+
+TEST_F(ResolverTest, Expr_Identifier_FunctionVariable) {
+ auto* my_var_a = Expr("my_var");
+ auto* my_var_b = Expr("my_var");
+ auto* assign = Assign(my_var_a, my_var_b);
+
+ auto* var = Var("my_var", ty.f32());
+
+ Func("my_func", ast::VariableList{}, ty.void_(),
+ {
+ Decl(var),
+ assign,
+ },
+ ast::AttributeList{});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(my_var_a), nullptr);
+ ASSERT_TRUE(TypeOf(my_var_a)->Is<sem::Reference>());
+ EXPECT_TRUE(TypeOf(my_var_a)->UnwrapRef()->Is<sem::F32>());
+ EXPECT_EQ(StmtOf(my_var_a), assign);
+ ASSERT_NE(TypeOf(my_var_b), nullptr);
+ ASSERT_TRUE(TypeOf(my_var_b)->Is<sem::Reference>());
+ EXPECT_TRUE(TypeOf(my_var_b)->UnwrapRef()->Is<sem::F32>());
+ EXPECT_EQ(StmtOf(my_var_b), assign);
+ EXPECT_TRUE(CheckVarUsers(var, {my_var_a, my_var_b}));
+ ASSERT_NE(VarOf(my_var_a), nullptr);
+ EXPECT_EQ(VarOf(my_var_a)->Declaration(), var);
+ ASSERT_NE(VarOf(my_var_b), nullptr);
+ EXPECT_EQ(VarOf(my_var_b)->Declaration(), var);
+}
+
+TEST_F(ResolverTest, Expr_Identifier_Function_Ptr) {
+ auto* v = Expr("v");
+ auto* p = Expr("p");
+ auto* v_decl = Decl(Var("v", ty.f32()));
+ auto* p_decl = Decl(
+ Const("p", ty.pointer<f32>(ast::StorageClass::kFunction), AddressOf(v)));
+ auto* assign = Assign(Deref(p), 1.23f);
+ Func("my_func", ast::VariableList{}, ty.void_(),
+ {
+ v_decl,
+ p_decl,
+ assign,
+ },
+ ast::AttributeList{});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(v), nullptr);
+ ASSERT_TRUE(TypeOf(v)->Is<sem::Reference>());
+ EXPECT_TRUE(TypeOf(v)->UnwrapRef()->Is<sem::F32>());
+ EXPECT_EQ(StmtOf(v), p_decl);
+ ASSERT_NE(TypeOf(p), nullptr);
+ ASSERT_TRUE(TypeOf(p)->Is<sem::Pointer>());
+ EXPECT_TRUE(TypeOf(p)->UnwrapPtr()->Is<sem::F32>());
+ EXPECT_EQ(StmtOf(p), assign);
+}
+
+TEST_F(ResolverTest, Expr_Call_Function) {
+ Func("my_func", ast::VariableList{}, ty.f32(), {Return(0.0f)},
+ ast::AttributeList{});
+
+ auto* call = Call("my_func");
+ WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(call), nullptr);
+ EXPECT_TRUE(TypeOf(call)->Is<sem::F32>());
+}
+
+TEST_F(ResolverTest, Expr_Identifier_Unknown) {
+ auto* a = Expr("a");
+ WrapInFunction(a);
+
+ EXPECT_FALSE(r()->Resolve());
+}
+
+TEST_F(ResolverTest, Function_Parameters) {
+ auto* param_a = Param("a", ty.f32());
+ auto* param_b = Param("b", ty.i32());
+ auto* param_c = Param("c", ty.u32());
+
+ auto* func = Func("my_func",
+ ast::VariableList{
+ param_a,
+ param_b,
+ param_c,
+ },
+ ty.void_(), {});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* func_sem = Sem().Get(func);
+ ASSERT_NE(func_sem, nullptr);
+ EXPECT_EQ(func_sem->Parameters().size(), 3u);
+ EXPECT_TRUE(func_sem->Parameters()[0]->Type()->Is<sem::F32>());
+ EXPECT_TRUE(func_sem->Parameters()[1]->Type()->Is<sem::I32>());
+ EXPECT_TRUE(func_sem->Parameters()[2]->Type()->Is<sem::U32>());
+ EXPECT_EQ(func_sem->Parameters()[0]->Declaration(), param_a);
+ EXPECT_EQ(func_sem->Parameters()[1]->Declaration(), param_b);
+ EXPECT_EQ(func_sem->Parameters()[2]->Declaration(), param_c);
+ EXPECT_TRUE(func_sem->ReturnType()->Is<sem::Void>());
+}
+
+TEST_F(ResolverTest, Function_RegisterInputOutputVariables) {
+ auto* s = Structure("S", {Member("m", ty.u32())},
+ {create<ast::StructBlockAttribute>()});
+
+ auto* sb_var = Global("sb_var", ty.Of(s), ast::StorageClass::kStorage,
+ ast::Access::kReadWrite,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(0),
+ create<ast::GroupAttribute>(0),
+ });
+ auto* wg_var = Global("wg_var", ty.f32(), ast::StorageClass::kWorkgroup);
+ auto* priv_var = Global("priv_var", ty.f32(), ast::StorageClass::kPrivate);
+
+ auto* func = Func("my_func", ast::VariableList{}, ty.void_(),
+ {
+ Assign("wg_var", "wg_var"),
+ Assign("sb_var", "sb_var"),
+ Assign("priv_var", "priv_var"),
+ });
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* func_sem = Sem().Get(func);
+ ASSERT_NE(func_sem, nullptr);
+ EXPECT_EQ(func_sem->Parameters().size(), 0u);
+ EXPECT_TRUE(func_sem->ReturnType()->Is<sem::Void>());
+
+ const auto& vars = func_sem->TransitivelyReferencedGlobals();
+ ASSERT_EQ(vars.size(), 3u);
+ EXPECT_EQ(vars[0]->Declaration(), wg_var);
+ EXPECT_EQ(vars[1]->Declaration(), sb_var);
+ EXPECT_EQ(vars[2]->Declaration(), priv_var);
+}
+
+TEST_F(ResolverTest, Function_RegisterInputOutputVariables_SubFunction) {
+ auto* s = Structure("S", {Member("m", ty.u32())},
+ {create<ast::StructBlockAttribute>()});
+
+ auto* sb_var = Global("sb_var", ty.Of(s), ast::StorageClass::kStorage,
+ ast::Access::kReadWrite,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(0),
+ create<ast::GroupAttribute>(0),
+ });
+ auto* wg_var = Global("wg_var", ty.f32(), ast::StorageClass::kWorkgroup);
+ auto* priv_var = Global("priv_var", ty.f32(), ast::StorageClass::kPrivate);
+
+ Func("my_func", ast::VariableList{}, ty.f32(),
+ {Assign("wg_var", "wg_var"), Assign("sb_var", "sb_var"),
+ Assign("priv_var", "priv_var"), Return(0.0f)},
+ ast::AttributeList{});
+
+ auto* func2 = Func("func", ast::VariableList{}, ty.void_(),
+ {
+ WrapInStatement(Call("my_func")),
+ },
+ ast::AttributeList{});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* func2_sem = Sem().Get(func2);
+ ASSERT_NE(func2_sem, nullptr);
+ EXPECT_EQ(func2_sem->Parameters().size(), 0u);
+
+ const auto& vars = func2_sem->TransitivelyReferencedGlobals();
+ ASSERT_EQ(vars.size(), 3u);
+ EXPECT_EQ(vars[0]->Declaration(), wg_var);
+ EXPECT_EQ(vars[1]->Declaration(), sb_var);
+ EXPECT_EQ(vars[2]->Declaration(), priv_var);
+}
+
+TEST_F(ResolverTest, Function_NotRegisterFunctionVariable) {
+ auto* func = Func("my_func", ast::VariableList{}, ty.void_(),
+ {
+ Decl(Var("var", ty.f32())),
+ Assign("var", 1.f),
+ });
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* func_sem = Sem().Get(func);
+ ASSERT_NE(func_sem, nullptr);
+
+ EXPECT_EQ(func_sem->TransitivelyReferencedGlobals().size(), 0u);
+ EXPECT_TRUE(func_sem->ReturnType()->Is<sem::Void>());
+}
+
+TEST_F(ResolverTest, Function_NotRegisterFunctionConstant) {
+ auto* func = Func("my_func", ast::VariableList{}, ty.void_(),
+ {
+ Decl(Const("var", ty.f32(), Construct(ty.f32()))),
+ });
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* func_sem = Sem().Get(func);
+ ASSERT_NE(func_sem, nullptr);
+
+ EXPECT_EQ(func_sem->TransitivelyReferencedGlobals().size(), 0u);
+ EXPECT_TRUE(func_sem->ReturnType()->Is<sem::Void>());
+}
+
+TEST_F(ResolverTest, Function_NotRegisterFunctionParams) {
+ auto* func = Func("my_func", {Const("var", ty.f32(), Construct(ty.f32()))},
+ ty.void_(), {});
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* func_sem = Sem().Get(func);
+ ASSERT_NE(func_sem, nullptr);
+
+ EXPECT_EQ(func_sem->TransitivelyReferencedGlobals().size(), 0u);
+ EXPECT_TRUE(func_sem->ReturnType()->Is<sem::Void>());
+}
+
+TEST_F(ResolverTest, Function_CallSites) {
+ auto* foo = Func("foo", ast::VariableList{}, ty.void_(), {});
+
+ auto* call_1 = Call("foo");
+ auto* call_2 = Call("foo");
+ auto* bar = Func("bar", ast::VariableList{}, ty.void_(),
+ {
+ CallStmt(call_1),
+ CallStmt(call_2),
+ });
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* foo_sem = Sem().Get(foo);
+ ASSERT_NE(foo_sem, nullptr);
+ ASSERT_EQ(foo_sem->CallSites().size(), 2u);
+ EXPECT_EQ(foo_sem->CallSites()[0]->Declaration(), call_1);
+ EXPECT_EQ(foo_sem->CallSites()[1]->Declaration(), call_2);
+
+ auto* bar_sem = Sem().Get(bar);
+ ASSERT_NE(bar_sem, nullptr);
+ EXPECT_EQ(bar_sem->CallSites().size(), 0u);
+}
+
+TEST_F(ResolverTest, Function_WorkgroupSize_NotSet) {
+ // @stage(compute) @workgroup_size(1)
+ // fn main() {}
+ auto* func = Func("main", ast::VariableList{}, ty.void_(), {}, {});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* func_sem = Sem().Get(func);
+ ASSERT_NE(func_sem, nullptr);
+
+ EXPECT_EQ(func_sem->WorkgroupSize()[0].value, 1u);
+ EXPECT_EQ(func_sem->WorkgroupSize()[1].value, 1u);
+ EXPECT_EQ(func_sem->WorkgroupSize()[2].value, 1u);
+ EXPECT_EQ(func_sem->WorkgroupSize()[0].overridable_const, nullptr);
+ EXPECT_EQ(func_sem->WorkgroupSize()[1].overridable_const, nullptr);
+ EXPECT_EQ(func_sem->WorkgroupSize()[2].overridable_const, nullptr);
+}
+
+TEST_F(ResolverTest, Function_WorkgroupSize_Literals) {
+ // @stage(compute) @workgroup_size(8, 2, 3)
+ // fn main() {}
+ auto* func =
+ Func("main", ast::VariableList{}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute), WorkgroupSize(8, 2, 3)});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* func_sem = Sem().Get(func);
+ ASSERT_NE(func_sem, nullptr);
+
+ EXPECT_EQ(func_sem->WorkgroupSize()[0].value, 8u);
+ EXPECT_EQ(func_sem->WorkgroupSize()[1].value, 2u);
+ EXPECT_EQ(func_sem->WorkgroupSize()[2].value, 3u);
+ EXPECT_EQ(func_sem->WorkgroupSize()[0].overridable_const, nullptr);
+ EXPECT_EQ(func_sem->WorkgroupSize()[1].overridable_const, nullptr);
+ EXPECT_EQ(func_sem->WorkgroupSize()[2].overridable_const, nullptr);
+}
+
+TEST_F(ResolverTest, Function_WorkgroupSize_Consts) {
+ // let width = 16;
+ // let height = 8;
+ // let depth = 2;
+ // @stage(compute) @workgroup_size(width, height, depth)
+ // fn main() {}
+ GlobalConst("width", ty.i32(), Expr(16));
+ GlobalConst("height", ty.i32(), Expr(8));
+ GlobalConst("depth", ty.i32(), Expr(2));
+ auto* func = Func("main", ast::VariableList{}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize("width", "height", "depth")});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* func_sem = Sem().Get(func);
+ ASSERT_NE(func_sem, nullptr);
+
+ EXPECT_EQ(func_sem->WorkgroupSize()[0].value, 16u);
+ EXPECT_EQ(func_sem->WorkgroupSize()[1].value, 8u);
+ EXPECT_EQ(func_sem->WorkgroupSize()[2].value, 2u);
+ EXPECT_EQ(func_sem->WorkgroupSize()[0].overridable_const, nullptr);
+ EXPECT_EQ(func_sem->WorkgroupSize()[1].overridable_const, nullptr);
+ EXPECT_EQ(func_sem->WorkgroupSize()[2].overridable_const, nullptr);
+}
+
+TEST_F(ResolverTest, Function_WorkgroupSize_Consts_NestedInitializer) {
+ // let width = i32(i32(i32(8)));
+ // let height = i32(i32(i32(4)));
+ // @stage(compute) @workgroup_size(width, height)
+ // fn main() {}
+ GlobalConst("width", ty.i32(),
+ Construct(ty.i32(), Construct(ty.i32(), Construct(ty.i32(), 8))));
+ GlobalConst("height", ty.i32(),
+ Construct(ty.i32(), Construct(ty.i32(), Construct(ty.i32(), 4))));
+ auto* func = Func(
+ "main", ast::VariableList{}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute), WorkgroupSize("width", "height")});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* func_sem = Sem().Get(func);
+ ASSERT_NE(func_sem, nullptr);
+
+ EXPECT_EQ(func_sem->WorkgroupSize()[0].value, 8u);
+ EXPECT_EQ(func_sem->WorkgroupSize()[1].value, 4u);
+ EXPECT_EQ(func_sem->WorkgroupSize()[2].value, 1u);
+ EXPECT_EQ(func_sem->WorkgroupSize()[0].overridable_const, nullptr);
+ EXPECT_EQ(func_sem->WorkgroupSize()[1].overridable_const, nullptr);
+ EXPECT_EQ(func_sem->WorkgroupSize()[2].overridable_const, nullptr);
+}
+
+TEST_F(ResolverTest, Function_WorkgroupSize_OverridableConsts) {
+ // @id(0) override width = 16;
+ // @id(1) override height = 8;
+ // @id(2) override depth = 2;
+ // @stage(compute) @workgroup_size(width, height, depth)
+ // fn main() {}
+ auto* width = Override("width", ty.i32(), Expr(16), {Id(0)});
+ auto* height = Override("height", ty.i32(), Expr(8), {Id(1)});
+ auto* depth = Override("depth", ty.i32(), Expr(2), {Id(2)});
+ auto* func = Func("main", ast::VariableList{}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize("width", "height", "depth")});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* func_sem = Sem().Get(func);
+ ASSERT_NE(func_sem, nullptr);
+
+ EXPECT_EQ(func_sem->WorkgroupSize()[0].value, 16u);
+ EXPECT_EQ(func_sem->WorkgroupSize()[1].value, 8u);
+ EXPECT_EQ(func_sem->WorkgroupSize()[2].value, 2u);
+ EXPECT_EQ(func_sem->WorkgroupSize()[0].overridable_const, width);
+ EXPECT_EQ(func_sem->WorkgroupSize()[1].overridable_const, height);
+ EXPECT_EQ(func_sem->WorkgroupSize()[2].overridable_const, depth);
+}
+
+TEST_F(ResolverTest, Function_WorkgroupSize_OverridableConsts_NoInit) {
+ // @id(0) override width : i32;
+ // @id(1) override height : i32;
+ // @id(2) override depth : i32;
+ // @stage(compute) @workgroup_size(width, height, depth)
+ // fn main() {}
+ auto* width = Override("width", ty.i32(), nullptr, {Id(0)});
+ auto* height = Override("height", ty.i32(), nullptr, {Id(1)});
+ auto* depth = Override("depth", ty.i32(), nullptr, {Id(2)});
+ auto* func = Func("main", ast::VariableList{}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize("width", "height", "depth")});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* func_sem = Sem().Get(func);
+ ASSERT_NE(func_sem, nullptr);
+
+ EXPECT_EQ(func_sem->WorkgroupSize()[0].value, 0u);
+ EXPECT_EQ(func_sem->WorkgroupSize()[1].value, 0u);
+ EXPECT_EQ(func_sem->WorkgroupSize()[2].value, 0u);
+ EXPECT_EQ(func_sem->WorkgroupSize()[0].overridable_const, width);
+ EXPECT_EQ(func_sem->WorkgroupSize()[1].overridable_const, height);
+ EXPECT_EQ(func_sem->WorkgroupSize()[2].overridable_const, depth);
+}
+
+TEST_F(ResolverTest, Function_WorkgroupSize_Mixed) {
+ // @id(1) override height = 2;
+ // let depth = 3;
+ // @stage(compute) @workgroup_size(8, height, depth)
+ // fn main() {}
+ auto* height = Override("height", ty.i32(), Expr(2), {Id(0)});
+ GlobalConst("depth", ty.i32(), Expr(3));
+ auto* func = Func("main", ast::VariableList{}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(8, "height", "depth")});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* func_sem = Sem().Get(func);
+ ASSERT_NE(func_sem, nullptr);
+
+ EXPECT_EQ(func_sem->WorkgroupSize()[0].value, 8u);
+ EXPECT_EQ(func_sem->WorkgroupSize()[1].value, 2u);
+ EXPECT_EQ(func_sem->WorkgroupSize()[2].value, 3u);
+ EXPECT_EQ(func_sem->WorkgroupSize()[0].overridable_const, nullptr);
+ EXPECT_EQ(func_sem->WorkgroupSize()[1].overridable_const, height);
+ EXPECT_EQ(func_sem->WorkgroupSize()[2].overridable_const, nullptr);
+}
+
+TEST_F(ResolverTest, Expr_MemberAccessor_Struct) {
+ auto* st = Structure("S", {Member("first_member", ty.i32()),
+ Member("second_member", ty.f32())});
+ Global("my_struct", ty.Of(st), ast::StorageClass::kPrivate);
+
+ auto* mem = MemberAccessor("my_struct", "second_member");
+ WrapInFunction(mem);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(mem), nullptr);
+ ASSERT_TRUE(TypeOf(mem)->Is<sem::Reference>());
+
+ auto* ref = TypeOf(mem)->As<sem::Reference>();
+ EXPECT_TRUE(ref->StoreType()->Is<sem::F32>());
+ auto* sma = Sem().Get(mem)->As<sem::StructMemberAccess>();
+ ASSERT_NE(sma, nullptr);
+ EXPECT_TRUE(sma->Member()->Type()->Is<sem::F32>());
+ EXPECT_EQ(sma->Member()->Index(), 1u);
+ EXPECT_EQ(sma->Member()->Declaration()->symbol,
+ Symbols().Get("second_member"));
+}
+
+TEST_F(ResolverTest, Expr_MemberAccessor_Struct_Alias) {
+ auto* st = Structure("S", {Member("first_member", ty.i32()),
+ Member("second_member", ty.f32())});
+ auto* alias = Alias("alias", ty.Of(st));
+ Global("my_struct", ty.Of(alias), ast::StorageClass::kPrivate);
+
+ auto* mem = MemberAccessor("my_struct", "second_member");
+ WrapInFunction(mem);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(mem), nullptr);
+ ASSERT_TRUE(TypeOf(mem)->Is<sem::Reference>());
+
+ auto* ref = TypeOf(mem)->As<sem::Reference>();
+ EXPECT_TRUE(ref->StoreType()->Is<sem::F32>());
+ auto* sma = Sem().Get(mem)->As<sem::StructMemberAccess>();
+ ASSERT_NE(sma, nullptr);
+ EXPECT_TRUE(sma->Member()->Type()->Is<sem::F32>());
+ EXPECT_EQ(sma->Member()->Index(), 1u);
+}
+
+TEST_F(ResolverTest, Expr_MemberAccessor_VectorSwizzle) {
+ Global("my_vec", ty.vec4<f32>(), ast::StorageClass::kPrivate);
+
+ auto* mem = MemberAccessor("my_vec", "xzyw");
+ WrapInFunction(mem);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(mem), nullptr);
+ ASSERT_TRUE(TypeOf(mem)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(mem)->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(TypeOf(mem)->As<sem::Vector>()->Width(), 4u);
+ ASSERT_TRUE(Sem().Get(mem)->Is<sem::Swizzle>());
+ EXPECT_THAT(Sem().Get(mem)->As<sem::Swizzle>()->Indices(),
+ ElementsAre(0, 2, 1, 3));
+}
+
+TEST_F(ResolverTest, Expr_MemberAccessor_VectorSwizzle_SingleElement) {
+ Global("my_vec", ty.vec3<f32>(), ast::StorageClass::kPrivate);
+
+ auto* mem = MemberAccessor("my_vec", "b");
+ WrapInFunction(mem);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(mem), nullptr);
+ ASSERT_TRUE(TypeOf(mem)->Is<sem::Reference>());
+
+ auto* ref = TypeOf(mem)->As<sem::Reference>();
+ ASSERT_TRUE(ref->StoreType()->Is<sem::F32>());
+ ASSERT_TRUE(Sem().Get(mem)->Is<sem::Swizzle>());
+ EXPECT_THAT(Sem().Get(mem)->As<sem::Swizzle>()->Indices(), ElementsAre(2));
+}
+
+TEST_F(ResolverTest, Expr_Accessor_MultiLevel) {
+ // struct b {
+ // vec4<f32> foo
+ // }
+ // struct A {
+ // array<b, 3> mem
+ // }
+ // var c : A
+ // c.mem[0].foo.yx
+ // -> vec2<f32>
+ //
+ // fn f() {
+ // c.mem[0].foo
+ // }
+ //
+
+ auto* stB = Structure("B", {Member("foo", ty.vec4<f32>())});
+ auto* stA = Structure("A", {Member("mem", ty.array(ty.Of(stB), 3))});
+ Global("c", ty.Of(stA), ast::StorageClass::kPrivate);
+
+ auto* mem = MemberAccessor(
+ MemberAccessor(IndexAccessor(MemberAccessor("c", "mem"), 0), "foo"),
+ "yx");
+ WrapInFunction(mem);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(mem), nullptr);
+ ASSERT_TRUE(TypeOf(mem)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(mem)->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(TypeOf(mem)->As<sem::Vector>()->Width(), 2u);
+ ASSERT_TRUE(Sem().Get(mem)->Is<sem::Swizzle>());
+}
+
+TEST_F(ResolverTest, Expr_MemberAccessor_InBinaryOp) {
+ auto* st = Structure("S", {Member("first_member", ty.f32()),
+ Member("second_member", ty.f32())});
+ Global("my_struct", ty.Of(st), ast::StorageClass::kPrivate);
+
+ auto* expr = Add(MemberAccessor("my_struct", "first_member"),
+ MemberAccessor("my_struct", "second_member"));
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(expr), nullptr);
+ EXPECT_TRUE(TypeOf(expr)->Is<sem::F32>());
+}
+
+namespace ExprBinaryTest {
+
+template <typename T, int ID>
+struct Aliased {
+ using type = alias<T, ID>;
+};
+
+template <int N, typename T, int ID>
+struct Aliased<vec<N, T>, ID> {
+ using type = vec<N, alias<T, ID>>;
+};
+
+template <int N, int M, typename T, int ID>
+struct Aliased<mat<N, M, T>, ID> {
+ using type = mat<N, M, alias<T, ID>>;
+};
+
+struct Params {
+ ast::BinaryOp op;
+ builder::ast_type_func_ptr create_lhs_type;
+ builder::ast_type_func_ptr create_rhs_type;
+ builder::ast_type_func_ptr create_lhs_alias_type;
+ builder::ast_type_func_ptr create_rhs_alias_type;
+ builder::sem_type_func_ptr create_result_type;
+};
+
+template <typename LHS, typename RHS, typename RES>
+constexpr Params ParamsFor(ast::BinaryOp op) {
+ return Params{op,
+ DataType<LHS>::AST,
+ DataType<RHS>::AST,
+ DataType<typename Aliased<LHS, 0>::type>::AST,
+ DataType<typename Aliased<RHS, 1>::type>::AST,
+ DataType<RES>::Sem};
+}
+
+static constexpr ast::BinaryOp all_ops[] = {
+ ast::BinaryOp::kAnd,
+ ast::BinaryOp::kOr,
+ ast::BinaryOp::kXor,
+ ast::BinaryOp::kLogicalAnd,
+ ast::BinaryOp::kLogicalOr,
+ ast::BinaryOp::kEqual,
+ ast::BinaryOp::kNotEqual,
+ ast::BinaryOp::kLessThan,
+ ast::BinaryOp::kGreaterThan,
+ ast::BinaryOp::kLessThanEqual,
+ ast::BinaryOp::kGreaterThanEqual,
+ ast::BinaryOp::kShiftLeft,
+ ast::BinaryOp::kShiftRight,
+ ast::BinaryOp::kAdd,
+ ast::BinaryOp::kSubtract,
+ ast::BinaryOp::kMultiply,
+ ast::BinaryOp::kDivide,
+ ast::BinaryOp::kModulo,
+};
+
+static constexpr builder::ast_type_func_ptr all_create_type_funcs[] = {
+ DataType<bool>::AST, //
+ DataType<u32>::AST, //
+ DataType<i32>::AST, //
+ DataType<f32>::AST, //
+ DataType<vec3<bool>>::AST, //
+ DataType<vec3<i32>>::AST, //
+ DataType<vec3<u32>>::AST, //
+ DataType<vec3<f32>>::AST, //
+ DataType<mat3x3<f32>>::AST, //
+ DataType<mat2x3<f32>>::AST, //
+ DataType<mat3x2<f32>>::AST //
+};
+
+// A list of all valid test cases for 'lhs op rhs', except that for vecN and
+// matNxN, we only test N=3.
+static constexpr Params all_valid_cases[] = {
+ // Logical expressions
+ // https://gpuweb.github.io/gpuweb/wgsl.html#logical-expr
+
+ // Binary logical expressions
+ ParamsFor<bool, bool, bool>(Op::kLogicalAnd),
+ ParamsFor<bool, bool, bool>(Op::kLogicalOr),
+
+ ParamsFor<bool, bool, bool>(Op::kAnd),
+ ParamsFor<bool, bool, bool>(Op::kOr),
+ ParamsFor<vec3<bool>, vec3<bool>, vec3<bool>>(Op::kAnd),
+ ParamsFor<vec3<bool>, vec3<bool>, vec3<bool>>(Op::kOr),
+
+ // Arithmetic expressions
+ // https://gpuweb.github.io/gpuweb/wgsl.html#arithmetic-expr
+
+ // Binary arithmetic expressions over scalars
+ ParamsFor<i32, i32, i32>(Op::kAdd),
+ ParamsFor<i32, i32, i32>(Op::kSubtract),
+ ParamsFor<i32, i32, i32>(Op::kMultiply),
+ ParamsFor<i32, i32, i32>(Op::kDivide),
+ ParamsFor<i32, i32, i32>(Op::kModulo),
+
+ ParamsFor<u32, u32, u32>(Op::kAdd),
+ ParamsFor<u32, u32, u32>(Op::kSubtract),
+ ParamsFor<u32, u32, u32>(Op::kMultiply),
+ ParamsFor<u32, u32, u32>(Op::kDivide),
+ ParamsFor<u32, u32, u32>(Op::kModulo),
+
+ ParamsFor<f32, f32, f32>(Op::kAdd),
+ ParamsFor<f32, f32, f32>(Op::kSubtract),
+ ParamsFor<f32, f32, f32>(Op::kMultiply),
+ ParamsFor<f32, f32, f32>(Op::kDivide),
+ ParamsFor<f32, f32, f32>(Op::kModulo),
+
+ // Binary arithmetic expressions over vectors
+ ParamsFor<vec3<i32>, vec3<i32>, vec3<i32>>(Op::kAdd),
+ ParamsFor<vec3<i32>, vec3<i32>, vec3<i32>>(Op::kSubtract),
+ ParamsFor<vec3<i32>, vec3<i32>, vec3<i32>>(Op::kMultiply),
+ ParamsFor<vec3<i32>, vec3<i32>, vec3<i32>>(Op::kDivide),
+ ParamsFor<vec3<i32>, vec3<i32>, vec3<i32>>(Op::kModulo),
+
+ ParamsFor<vec3<u32>, vec3<u32>, vec3<u32>>(Op::kAdd),
+ ParamsFor<vec3<u32>, vec3<u32>, vec3<u32>>(Op::kSubtract),
+ ParamsFor<vec3<u32>, vec3<u32>, vec3<u32>>(Op::kMultiply),
+ ParamsFor<vec3<u32>, vec3<u32>, vec3<u32>>(Op::kDivide),
+ ParamsFor<vec3<u32>, vec3<u32>, vec3<u32>>(Op::kModulo),
+
+ ParamsFor<vec3<f32>, vec3<f32>, vec3<f32>>(Op::kAdd),
+ ParamsFor<vec3<f32>, vec3<f32>, vec3<f32>>(Op::kSubtract),
+ ParamsFor<vec3<f32>, vec3<f32>, vec3<f32>>(Op::kMultiply),
+ ParamsFor<vec3<f32>, vec3<f32>, vec3<f32>>(Op::kDivide),
+ ParamsFor<vec3<f32>, vec3<f32>, vec3<f32>>(Op::kModulo),
+
+ // Binary arithmetic expressions with mixed scalar and vector operands
+ ParamsFor<vec3<i32>, i32, vec3<i32>>(Op::kAdd),
+ ParamsFor<vec3<i32>, i32, vec3<i32>>(Op::kSubtract),
+ ParamsFor<vec3<i32>, i32, vec3<i32>>(Op::kMultiply),
+ ParamsFor<vec3<i32>, i32, vec3<i32>>(Op::kDivide),
+ ParamsFor<vec3<i32>, i32, vec3<i32>>(Op::kModulo),
+
+ ParamsFor<i32, vec3<i32>, vec3<i32>>(Op::kAdd),
+ ParamsFor<i32, vec3<i32>, vec3<i32>>(Op::kSubtract),
+ ParamsFor<i32, vec3<i32>, vec3<i32>>(Op::kMultiply),
+ ParamsFor<i32, vec3<i32>, vec3<i32>>(Op::kDivide),
+ ParamsFor<i32, vec3<i32>, vec3<i32>>(Op::kModulo),
+
+ ParamsFor<vec3<u32>, u32, vec3<u32>>(Op::kAdd),
+ ParamsFor<vec3<u32>, u32, vec3<u32>>(Op::kSubtract),
+ ParamsFor<vec3<u32>, u32, vec3<u32>>(Op::kMultiply),
+ ParamsFor<vec3<u32>, u32, vec3<u32>>(Op::kDivide),
+ ParamsFor<vec3<u32>, u32, vec3<u32>>(Op::kModulo),
+
+ ParamsFor<u32, vec3<u32>, vec3<u32>>(Op::kAdd),
+ ParamsFor<u32, vec3<u32>, vec3<u32>>(Op::kSubtract),
+ ParamsFor<u32, vec3<u32>, vec3<u32>>(Op::kMultiply),
+ ParamsFor<u32, vec3<u32>, vec3<u32>>(Op::kDivide),
+ ParamsFor<u32, vec3<u32>, vec3<u32>>(Op::kModulo),
+
+ ParamsFor<vec3<f32>, f32, vec3<f32>>(Op::kAdd),
+ ParamsFor<vec3<f32>, f32, vec3<f32>>(Op::kSubtract),
+ ParamsFor<vec3<f32>, f32, vec3<f32>>(Op::kMultiply),
+ ParamsFor<vec3<f32>, f32, vec3<f32>>(Op::kDivide),
+ // NOTE: no kModulo for vec3<f32>, f32
+ // ParamsFor<vec3<f32>, f32, vec3<f32>>(Op::kModulo),
+
+ ParamsFor<f32, vec3<f32>, vec3<f32>>(Op::kAdd),
+ ParamsFor<f32, vec3<f32>, vec3<f32>>(Op::kSubtract),
+ ParamsFor<f32, vec3<f32>, vec3<f32>>(Op::kMultiply),
+ ParamsFor<f32, vec3<f32>, vec3<f32>>(Op::kDivide),
+ // NOTE: no kModulo for f32, vec3<f32>
+ // ParamsFor<f32, vec3<f32>, vec3<f32>>(Op::kModulo),
+
+ // Matrix arithmetic
+ ParamsFor<mat2x3<f32>, f32, mat2x3<f32>>(Op::kMultiply),
+ ParamsFor<mat3x2<f32>, f32, mat3x2<f32>>(Op::kMultiply),
+ ParamsFor<mat3x3<f32>, f32, mat3x3<f32>>(Op::kMultiply),
+
+ ParamsFor<f32, mat2x3<f32>, mat2x3<f32>>(Op::kMultiply),
+ ParamsFor<f32, mat3x2<f32>, mat3x2<f32>>(Op::kMultiply),
+ ParamsFor<f32, mat3x3<f32>, mat3x3<f32>>(Op::kMultiply),
+
+ ParamsFor<vec3<f32>, mat2x3<f32>, vec2<f32>>(Op::kMultiply),
+ ParamsFor<vec2<f32>, mat3x2<f32>, vec3<f32>>(Op::kMultiply),
+ ParamsFor<vec3<f32>, mat3x3<f32>, vec3<f32>>(Op::kMultiply),
+
+ ParamsFor<mat3x2<f32>, vec3<f32>, vec2<f32>>(Op::kMultiply),
+ ParamsFor<mat2x3<f32>, vec2<f32>, vec3<f32>>(Op::kMultiply),
+ ParamsFor<mat3x3<f32>, vec3<f32>, vec3<f32>>(Op::kMultiply),
+
+ ParamsFor<mat2x3<f32>, mat3x2<f32>, mat3x3<f32>>(Op::kMultiply),
+ ParamsFor<mat3x2<f32>, mat2x3<f32>, mat2x2<f32>>(Op::kMultiply),
+ ParamsFor<mat3x2<f32>, mat3x3<f32>, mat3x2<f32>>(Op::kMultiply),
+ ParamsFor<mat3x3<f32>, mat3x3<f32>, mat3x3<f32>>(Op::kMultiply),
+ ParamsFor<mat3x3<f32>, mat2x3<f32>, mat2x3<f32>>(Op::kMultiply),
+
+ ParamsFor<mat2x3<f32>, mat2x3<f32>, mat2x3<f32>>(Op::kAdd),
+ ParamsFor<mat3x2<f32>, mat3x2<f32>, mat3x2<f32>>(Op::kAdd),
+ ParamsFor<mat3x3<f32>, mat3x3<f32>, mat3x3<f32>>(Op::kAdd),
+
+ ParamsFor<mat2x3<f32>, mat2x3<f32>, mat2x3<f32>>(Op::kSubtract),
+ ParamsFor<mat3x2<f32>, mat3x2<f32>, mat3x2<f32>>(Op::kSubtract),
+ ParamsFor<mat3x3<f32>, mat3x3<f32>, mat3x3<f32>>(Op::kSubtract),
+
+ // Comparison expressions
+ // https://gpuweb.github.io/gpuweb/wgsl.html#comparison-expr
+
+ // Comparisons over scalars
+ ParamsFor<bool, bool, bool>(Op::kEqual),
+ ParamsFor<bool, bool, bool>(Op::kNotEqual),
+
+ ParamsFor<i32, i32, bool>(Op::kEqual),
+ ParamsFor<i32, i32, bool>(Op::kNotEqual),
+ ParamsFor<i32, i32, bool>(Op::kLessThan),
+ ParamsFor<i32, i32, bool>(Op::kLessThanEqual),
+ ParamsFor<i32, i32, bool>(Op::kGreaterThan),
+ ParamsFor<i32, i32, bool>(Op::kGreaterThanEqual),
+
+ ParamsFor<u32, u32, bool>(Op::kEqual),
+ ParamsFor<u32, u32, bool>(Op::kNotEqual),
+ ParamsFor<u32, u32, bool>(Op::kLessThan),
+ ParamsFor<u32, u32, bool>(Op::kLessThanEqual),
+ ParamsFor<u32, u32, bool>(Op::kGreaterThan),
+ ParamsFor<u32, u32, bool>(Op::kGreaterThanEqual),
+
+ ParamsFor<f32, f32, bool>(Op::kEqual),
+ ParamsFor<f32, f32, bool>(Op::kNotEqual),
+ ParamsFor<f32, f32, bool>(Op::kLessThan),
+ ParamsFor<f32, f32, bool>(Op::kLessThanEqual),
+ ParamsFor<f32, f32, bool>(Op::kGreaterThan),
+ ParamsFor<f32, f32, bool>(Op::kGreaterThanEqual),
+
+ // Comparisons over vectors
+ ParamsFor<vec3<bool>, vec3<bool>, vec3<bool>>(Op::kEqual),
+ ParamsFor<vec3<bool>, vec3<bool>, vec3<bool>>(Op::kNotEqual),
+
+ ParamsFor<vec3<i32>, vec3<i32>, vec3<bool>>(Op::kEqual),
+ ParamsFor<vec3<i32>, vec3<i32>, vec3<bool>>(Op::kNotEqual),
+ ParamsFor<vec3<i32>, vec3<i32>, vec3<bool>>(Op::kLessThan),
+ ParamsFor<vec3<i32>, vec3<i32>, vec3<bool>>(Op::kLessThanEqual),
+ ParamsFor<vec3<i32>, vec3<i32>, vec3<bool>>(Op::kGreaterThan),
+ ParamsFor<vec3<i32>, vec3<i32>, vec3<bool>>(Op::kGreaterThanEqual),
+
+ ParamsFor<vec3<u32>, vec3<u32>, vec3<bool>>(Op::kEqual),
+ ParamsFor<vec3<u32>, vec3<u32>, vec3<bool>>(Op::kNotEqual),
+ ParamsFor<vec3<u32>, vec3<u32>, vec3<bool>>(Op::kLessThan),
+ ParamsFor<vec3<u32>, vec3<u32>, vec3<bool>>(Op::kLessThanEqual),
+ ParamsFor<vec3<u32>, vec3<u32>, vec3<bool>>(Op::kGreaterThan),
+ ParamsFor<vec3<u32>, vec3<u32>, vec3<bool>>(Op::kGreaterThanEqual),
+
+ ParamsFor<vec3<f32>, vec3<f32>, vec3<bool>>(Op::kEqual),
+ ParamsFor<vec3<f32>, vec3<f32>, vec3<bool>>(Op::kNotEqual),
+ ParamsFor<vec3<f32>, vec3<f32>, vec3<bool>>(Op::kLessThan),
+ ParamsFor<vec3<f32>, vec3<f32>, vec3<bool>>(Op::kLessThanEqual),
+ ParamsFor<vec3<f32>, vec3<f32>, vec3<bool>>(Op::kGreaterThan),
+ ParamsFor<vec3<f32>, vec3<f32>, vec3<bool>>(Op::kGreaterThanEqual),
+
+ // Binary bitwise operations
+ ParamsFor<i32, i32, i32>(Op::kOr),
+ ParamsFor<i32, i32, i32>(Op::kAnd),
+ ParamsFor<i32, i32, i32>(Op::kXor),
+
+ ParamsFor<u32, u32, u32>(Op::kOr),
+ ParamsFor<u32, u32, u32>(Op::kAnd),
+ ParamsFor<u32, u32, u32>(Op::kXor),
+
+ ParamsFor<vec3<i32>, vec3<i32>, vec3<i32>>(Op::kOr),
+ ParamsFor<vec3<i32>, vec3<i32>, vec3<i32>>(Op::kAnd),
+ ParamsFor<vec3<i32>, vec3<i32>, vec3<i32>>(Op::kXor),
+
+ ParamsFor<vec3<u32>, vec3<u32>, vec3<u32>>(Op::kOr),
+ ParamsFor<vec3<u32>, vec3<u32>, vec3<u32>>(Op::kAnd),
+ ParamsFor<vec3<u32>, vec3<u32>, vec3<u32>>(Op::kXor),
+
+ // Bit shift expressions
+ ParamsFor<i32, u32, i32>(Op::kShiftLeft),
+ ParamsFor<vec3<i32>, vec3<u32>, vec3<i32>>(Op::kShiftLeft),
+
+ ParamsFor<u32, u32, u32>(Op::kShiftLeft),
+ ParamsFor<vec3<u32>, vec3<u32>, vec3<u32>>(Op::kShiftLeft),
+
+ ParamsFor<i32, u32, i32>(Op::kShiftRight),
+ ParamsFor<vec3<i32>, vec3<u32>, vec3<i32>>(Op::kShiftRight),
+
+ ParamsFor<u32, u32, u32>(Op::kShiftRight),
+ ParamsFor<vec3<u32>, vec3<u32>, vec3<u32>>(Op::kShiftRight),
+};
+
+using Expr_Binary_Test_Valid = ResolverTestWithParam<Params>;
+TEST_P(Expr_Binary_Test_Valid, All) {
+ auto& params = GetParam();
+
+ auto* lhs_type = params.create_lhs_type(*this);
+ auto* rhs_type = params.create_rhs_type(*this);
+ auto* result_type = params.create_result_type(*this);
+
+ std::stringstream ss;
+ ss << FriendlyName(lhs_type) << " " << params.op << " "
+ << FriendlyName(rhs_type);
+ SCOPED_TRACE(ss.str());
+
+ Global("lhs", lhs_type, ast::StorageClass::kPrivate);
+ Global("rhs", rhs_type, ast::StorageClass::kPrivate);
+
+ auto* expr =
+ create<ast::BinaryExpression>(params.op, Expr("lhs"), Expr("rhs"));
+ WrapInFunction(expr);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_NE(TypeOf(expr), nullptr);
+ ASSERT_TRUE(TypeOf(expr) == result_type);
+}
+INSTANTIATE_TEST_SUITE_P(ResolverTest,
+ Expr_Binary_Test_Valid,
+ testing::ValuesIn(all_valid_cases));
+
+enum class BinaryExprSide { Left, Right, Both };
+using Expr_Binary_Test_WithAlias_Valid =
+ ResolverTestWithParam<std::tuple<Params, BinaryExprSide>>;
+TEST_P(Expr_Binary_Test_WithAlias_Valid, All) {
+ const Params& params = std::get<0>(GetParam());
+ BinaryExprSide side = std::get<1>(GetParam());
+
+ auto* create_lhs_type =
+ (side == BinaryExprSide::Left || side == BinaryExprSide::Both)
+ ? params.create_lhs_alias_type
+ : params.create_lhs_type;
+ auto* create_rhs_type =
+ (side == BinaryExprSide::Right || side == BinaryExprSide::Both)
+ ? params.create_rhs_alias_type
+ : params.create_rhs_type;
+
+ auto* lhs_type = create_lhs_type(*this);
+ auto* rhs_type = create_rhs_type(*this);
+
+ std::stringstream ss;
+ ss << FriendlyName(lhs_type) << " " << params.op << " "
+ << FriendlyName(rhs_type);
+
+ ss << ", After aliasing: " << FriendlyName(lhs_type) << " " << params.op
+ << " " << FriendlyName(rhs_type);
+ SCOPED_TRACE(ss.str());
+
+ Global("lhs", lhs_type, ast::StorageClass::kPrivate);
+ Global("rhs", rhs_type, ast::StorageClass::kPrivate);
+
+ auto* expr =
+ create<ast::BinaryExpression>(params.op, Expr("lhs"), Expr("rhs"));
+ WrapInFunction(expr);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_NE(TypeOf(expr), nullptr);
+ // TODO(amaiorano): Bring this back once we have a way to get the canonical
+ // type
+ // auto* *result_type = params.create_result_type(*this);
+ // ASSERT_TRUE(TypeOf(expr) == result_type);
+}
+INSTANTIATE_TEST_SUITE_P(
+ ResolverTest,
+ Expr_Binary_Test_WithAlias_Valid,
+ testing::Combine(testing::ValuesIn(all_valid_cases),
+ testing::Values(BinaryExprSide::Left,
+ BinaryExprSide::Right,
+ BinaryExprSide::Both)));
+
+// This test works by taking the cartesian product of all possible
+// (type * type * op), and processing only the triplets that are not found in
+// the `all_valid_cases` table.
+using Expr_Binary_Test_Invalid =
+ ResolverTestWithParam<std::tuple<builder::ast_type_func_ptr,
+ builder::ast_type_func_ptr,
+ ast::BinaryOp>>;
+TEST_P(Expr_Binary_Test_Invalid, All) {
+ const builder::ast_type_func_ptr& lhs_create_type_func =
+ std::get<0>(GetParam());
+ const builder::ast_type_func_ptr& rhs_create_type_func =
+ std::get<1>(GetParam());
+ const ast::BinaryOp op = std::get<2>(GetParam());
+
+ // Skip if valid case
+ // TODO(amaiorano): replace linear lookup with O(1) if too slow
+ for (auto& c : all_valid_cases) {
+ if (c.create_lhs_type == lhs_create_type_func &&
+ c.create_rhs_type == rhs_create_type_func && c.op == op) {
+ return;
+ }
+ }
+
+ auto* lhs_type = lhs_create_type_func(*this);
+ auto* rhs_type = rhs_create_type_func(*this);
+
+ std::stringstream ss;
+ ss << FriendlyName(lhs_type) << " " << op << " " << FriendlyName(rhs_type);
+ SCOPED_TRACE(ss.str());
+
+ Global("lhs", lhs_type, ast::StorageClass::kPrivate);
+ Global("rhs", rhs_type, ast::StorageClass::kPrivate);
+
+ auto* expr = create<ast::BinaryExpression>(Source{{12, 34}}, op, Expr("lhs"),
+ Expr("rhs"));
+ WrapInFunction(expr);
+
+ ASSERT_FALSE(r()->Resolve());
+ ASSERT_EQ(r()->error(),
+ "12:34 error: Binary expression operand types are invalid for "
+ "this operation: " +
+ FriendlyName(lhs_type) + " " + ast::FriendlyName(expr->op) +
+ " " + FriendlyName(rhs_type));
+}
+INSTANTIATE_TEST_SUITE_P(
+ ResolverTest,
+ Expr_Binary_Test_Invalid,
+ testing::Combine(testing::ValuesIn(all_create_type_funcs),
+ testing::ValuesIn(all_create_type_funcs),
+ testing::ValuesIn(all_ops)));
+
+using Expr_Binary_Test_Invalid_VectorMatrixMultiply =
+ ResolverTestWithParam<std::tuple<bool, uint32_t, uint32_t, uint32_t>>;
+TEST_P(Expr_Binary_Test_Invalid_VectorMatrixMultiply, All) {
+ bool vec_by_mat = std::get<0>(GetParam());
+ uint32_t vec_size = std::get<1>(GetParam());
+ uint32_t mat_rows = std::get<2>(GetParam());
+ uint32_t mat_cols = std::get<3>(GetParam());
+
+ const ast::Type* lhs_type = nullptr;
+ const ast::Type* rhs_type = nullptr;
+ const sem::Type* result_type = nullptr;
+ bool is_valid_expr;
+
+ if (vec_by_mat) {
+ lhs_type = ty.vec<f32>(vec_size);
+ rhs_type = ty.mat<f32>(mat_cols, mat_rows);
+ result_type = create<sem::Vector>(create<sem::F32>(), mat_cols);
+ is_valid_expr = vec_size == mat_rows;
+ } else {
+ lhs_type = ty.mat<f32>(mat_cols, mat_rows);
+ rhs_type = ty.vec<f32>(vec_size);
+ result_type = create<sem::Vector>(create<sem::F32>(), mat_rows);
+ is_valid_expr = vec_size == mat_cols;
+ }
+
+ Global("lhs", lhs_type, ast::StorageClass::kPrivate);
+ Global("rhs", rhs_type, ast::StorageClass::kPrivate);
+
+ auto* expr = Mul(Source{{12, 34}}, Expr("lhs"), Expr("rhs"));
+ WrapInFunction(expr);
+
+ if (is_valid_expr) {
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(TypeOf(expr) == result_type);
+ } else {
+ ASSERT_FALSE(r()->Resolve());
+ ASSERT_EQ(r()->error(),
+ "12:34 error: Binary expression operand types are invalid for "
+ "this operation: " +
+ FriendlyName(lhs_type) + " " + ast::FriendlyName(expr->op) +
+ " " + FriendlyName(rhs_type));
+ }
+}
+auto all_dimension_values = testing::Values(2u, 3u, 4u);
+INSTANTIATE_TEST_SUITE_P(ResolverTest,
+ Expr_Binary_Test_Invalid_VectorMatrixMultiply,
+ testing::Combine(testing::Values(true, false),
+ all_dimension_values,
+ all_dimension_values,
+ all_dimension_values));
+
+using Expr_Binary_Test_Invalid_MatrixMatrixMultiply =
+ ResolverTestWithParam<std::tuple<uint32_t, uint32_t, uint32_t, uint32_t>>;
+TEST_P(Expr_Binary_Test_Invalid_MatrixMatrixMultiply, All) {
+ uint32_t lhs_mat_rows = std::get<0>(GetParam());
+ uint32_t lhs_mat_cols = std::get<1>(GetParam());
+ uint32_t rhs_mat_rows = std::get<2>(GetParam());
+ uint32_t rhs_mat_cols = std::get<3>(GetParam());
+
+ auto* lhs_type = ty.mat<f32>(lhs_mat_cols, lhs_mat_rows);
+ auto* rhs_type = ty.mat<f32>(rhs_mat_cols, rhs_mat_rows);
+
+ auto* f32 = create<sem::F32>();
+ auto* col = create<sem::Vector>(f32, lhs_mat_rows);
+ auto* result_type = create<sem::Matrix>(col, rhs_mat_cols);
+
+ Global("lhs", lhs_type, ast::StorageClass::kPrivate);
+ Global("rhs", rhs_type, ast::StorageClass::kPrivate);
+
+ auto* expr = Mul(Source{{12, 34}}, Expr("lhs"), Expr("rhs"));
+ WrapInFunction(expr);
+
+ bool is_valid_expr = lhs_mat_cols == rhs_mat_rows;
+ if (is_valid_expr) {
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(TypeOf(expr) == result_type);
+ } else {
+ ASSERT_FALSE(r()->Resolve());
+ ASSERT_EQ(r()->error(),
+ "12:34 error: Binary expression operand types are invalid for "
+ "this operation: " +
+ FriendlyName(lhs_type) + " " + ast::FriendlyName(expr->op) +
+ " " + FriendlyName(rhs_type));
+ }
+}
+INSTANTIATE_TEST_SUITE_P(ResolverTest,
+ Expr_Binary_Test_Invalid_MatrixMatrixMultiply,
+ testing::Combine(all_dimension_values,
+ all_dimension_values,
+ all_dimension_values,
+ all_dimension_values));
+
+} // namespace ExprBinaryTest
+
+using UnaryOpExpressionTest = ResolverTestWithParam<ast::UnaryOp>;
+TEST_P(UnaryOpExpressionTest, Expr_UnaryOp) {
+ auto op = GetParam();
+
+ if (op == ast::UnaryOp::kNot) {
+ Global("ident", ty.vec4<bool>(), ast::StorageClass::kPrivate);
+ } else if (op == ast::UnaryOp::kNegation || op == ast::UnaryOp::kComplement) {
+ Global("ident", ty.vec4<i32>(), ast::StorageClass::kPrivate);
+ } else {
+ Global("ident", ty.vec4<f32>(), ast::StorageClass::kPrivate);
+ }
+ auto* der = create<ast::UnaryOpExpression>(op, Expr("ident"));
+ WrapInFunction(der);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(der), nullptr);
+ ASSERT_TRUE(TypeOf(der)->Is<sem::Vector>());
+ if (op == ast::UnaryOp::kNot) {
+ EXPECT_TRUE(TypeOf(der)->As<sem::Vector>()->type()->Is<sem::Bool>());
+ } else if (op == ast::UnaryOp::kNegation || op == ast::UnaryOp::kComplement) {
+ EXPECT_TRUE(TypeOf(der)->As<sem::Vector>()->type()->Is<sem::I32>());
+ } else {
+ EXPECT_TRUE(TypeOf(der)->As<sem::Vector>()->type()->Is<sem::F32>());
+ }
+ EXPECT_EQ(TypeOf(der)->As<sem::Vector>()->Width(), 4u);
+}
+INSTANTIATE_TEST_SUITE_P(ResolverTest,
+ UnaryOpExpressionTest,
+ testing::Values(ast::UnaryOp::kComplement,
+ ast::UnaryOp::kNegation,
+ ast::UnaryOp::kNot));
+
+TEST_F(ResolverTest, StorageClass_SetsIfMissing) {
+ auto* var = Var("var", ty.i32());
+
+ auto* stmt = Decl(var);
+ Func("func", ast::VariableList{}, ty.void_(), {stmt}, ast::AttributeList{});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ EXPECT_EQ(Sem().Get(var)->StorageClass(), ast::StorageClass::kFunction);
+}
+
+TEST_F(ResolverTest, StorageClass_SetForSampler) {
+ auto* t = ty.sampler(ast::SamplerKind::kSampler);
+ auto* var = Global("var", t,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(0),
+ create<ast::GroupAttribute>(0),
+ });
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ EXPECT_EQ(Sem().Get(var)->StorageClass(),
+ ast::StorageClass::kUniformConstant);
+}
+
+TEST_F(ResolverTest, StorageClass_SetForTexture) {
+ auto* t = ty.sampled_texture(ast::TextureDimension::k1d, ty.f32());
+ auto* var = Global("var", t,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(0),
+ create<ast::GroupAttribute>(0),
+ });
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ EXPECT_EQ(Sem().Get(var)->StorageClass(),
+ ast::StorageClass::kUniformConstant);
+}
+
+TEST_F(ResolverTest, StorageClass_DoesNotSetOnConst) {
+ auto* var = Const("var", ty.i32(), Construct(ty.i32()));
+ auto* stmt = Decl(var);
+ Func("func", ast::VariableList{}, ty.void_(), {stmt}, ast::AttributeList{});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ EXPECT_EQ(Sem().Get(var)->StorageClass(), ast::StorageClass::kNone);
+}
+
+TEST_F(ResolverTest, Access_SetForStorageBuffer) {
+ // [[block]] struct S { x : i32 };
+ // var<storage> g : S;
+ auto* s = Structure("S", {Member(Source{{12, 34}}, "x", ty.i32())},
+ {create<ast::StructBlockAttribute>()});
+ auto* var =
+ Global(Source{{56, 78}}, "g", ty.Of(s), ast::StorageClass::kStorage,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(0),
+ create<ast::GroupAttribute>(0),
+ });
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ EXPECT_EQ(Sem().Get(var)->Access(), ast::Access::kRead);
+}
+
+TEST_F(ResolverTest, BindingPoint_SetForResources) {
+ // @group(1) @binding(2) var s1 : sampler;
+ // @group(3) @binding(4) var s2 : sampler;
+ auto* s1 = Global(Sym(), ty.sampler(ast::SamplerKind::kSampler),
+ ast::AttributeList{create<ast::GroupAttribute>(1),
+ create<ast::BindingAttribute>(2)});
+ auto* s2 = Global(Sym(), ty.sampler(ast::SamplerKind::kSampler),
+ ast::AttributeList{create<ast::GroupAttribute>(3),
+ create<ast::BindingAttribute>(4)});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ EXPECT_EQ(Sem().Get<sem::GlobalVariable>(s1)->BindingPoint(),
+ (sem::BindingPoint{1u, 2u}));
+ EXPECT_EQ(Sem().Get<sem::GlobalVariable>(s2)->BindingPoint(),
+ (sem::BindingPoint{3u, 4u}));
+}
+
+TEST_F(ResolverTest, Function_EntryPoints_StageAttribute) {
+ // fn b() {}
+ // fn c() { b(); }
+ // fn a() { c(); }
+ // fn ep_1() { a(); b(); }
+ // fn ep_2() { c();}
+ //
+ // c -> {ep_1, ep_2}
+ // a -> {ep_1}
+ // b -> {ep_1, ep_2}
+ // ep_1 -> {}
+ // ep_2 -> {}
+
+ Global("first", ty.f32(), ast::StorageClass::kPrivate);
+ Global("second", ty.f32(), ast::StorageClass::kPrivate);
+ Global("call_a", ty.f32(), ast::StorageClass::kPrivate);
+ Global("call_b", ty.f32(), ast::StorageClass::kPrivate);
+ Global("call_c", ty.f32(), ast::StorageClass::kPrivate);
+
+ ast::VariableList params;
+ auto* func_b =
+ Func("b", params, ty.f32(), {Return(0.0f)}, ast::AttributeList{});
+ auto* func_c =
+ Func("c", params, ty.f32(), {Assign("second", Call("b")), Return(0.0f)},
+ ast::AttributeList{});
+
+ auto* func_a =
+ Func("a", params, ty.f32(), {Assign("first", Call("c")), Return(0.0f)},
+ ast::AttributeList{});
+
+ auto* ep_1 = Func("ep_1", params, ty.void_(),
+ {
+ Assign("call_a", Call("a")),
+ Assign("call_b", Call("b")),
+ },
+ ast::AttributeList{Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(1)});
+
+ auto* ep_2 = Func("ep_2", params, ty.void_(),
+ {
+ Assign("call_c", Call("c")),
+ },
+ ast::AttributeList{Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(1)});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* func_b_sem = Sem().Get(func_b);
+ auto* func_a_sem = Sem().Get(func_a);
+ auto* func_c_sem = Sem().Get(func_c);
+ auto* ep_1_sem = Sem().Get(ep_1);
+ auto* ep_2_sem = Sem().Get(ep_2);
+ ASSERT_NE(func_b_sem, nullptr);
+ ASSERT_NE(func_a_sem, nullptr);
+ ASSERT_NE(func_c_sem, nullptr);
+ ASSERT_NE(ep_1_sem, nullptr);
+ ASSERT_NE(ep_2_sem, nullptr);
+
+ EXPECT_EQ(func_b_sem->Parameters().size(), 0u);
+ EXPECT_EQ(func_a_sem->Parameters().size(), 0u);
+ EXPECT_EQ(func_c_sem->Parameters().size(), 0u);
+
+ const auto& b_eps = func_b_sem->AncestorEntryPoints();
+ ASSERT_EQ(2u, b_eps.size());
+ EXPECT_EQ(Symbols().Register("ep_1"), b_eps[0]->Declaration()->symbol);
+ EXPECT_EQ(Symbols().Register("ep_2"), b_eps[1]->Declaration()->symbol);
+
+ const auto& a_eps = func_a_sem->AncestorEntryPoints();
+ ASSERT_EQ(1u, a_eps.size());
+ EXPECT_EQ(Symbols().Register("ep_1"), a_eps[0]->Declaration()->symbol);
+
+ const auto& c_eps = func_c_sem->AncestorEntryPoints();
+ ASSERT_EQ(2u, c_eps.size());
+ EXPECT_EQ(Symbols().Register("ep_1"), c_eps[0]->Declaration()->symbol);
+ EXPECT_EQ(Symbols().Register("ep_2"), c_eps[1]->Declaration()->symbol);
+
+ EXPECT_TRUE(ep_1_sem->AncestorEntryPoints().empty());
+ EXPECT_TRUE(ep_2_sem->AncestorEntryPoints().empty());
+}
+
+// Check for linear-time traversal of functions reachable from entry points.
+// See: crbug.com/tint/245
+TEST_F(ResolverTest, Function_EntryPoints_LinearTime) {
+ // fn lNa() { }
+ // fn lNb() { }
+ // ...
+ // fn l2a() { l3a(); l3b(); }
+ // fn l2b() { l3a(); l3b(); }
+ // fn l1a() { l2a(); l2b(); }
+ // fn l1b() { l2a(); l2b(); }
+ // fn main() { l1a(); l1b(); }
+
+ static constexpr int levels = 64;
+
+ auto fn_a = [](int level) { return "l" + std::to_string(level + 1) + "a"; };
+ auto fn_b = [](int level) { return "l" + std::to_string(level + 1) + "b"; };
+
+ Func(fn_a(levels), {}, ty.void_(), {}, {});
+ Func(fn_b(levels), {}, ty.void_(), {}, {});
+
+ for (int i = levels - 1; i >= 0; i--) {
+ Func(fn_a(i), {}, ty.void_(),
+ {
+ CallStmt(Call(fn_a(i + 1))),
+ CallStmt(Call(fn_b(i + 1))),
+ },
+ {});
+ Func(fn_b(i), {}, ty.void_(),
+ {
+ CallStmt(Call(fn_a(i + 1))),
+ CallStmt(Call(fn_b(i + 1))),
+ },
+ {});
+ }
+
+ Func("main", {}, ty.void_(),
+ {
+ CallStmt(Call(fn_a(0))),
+ CallStmt(Call(fn_b(0))),
+ },
+ {Stage(ast::PipelineStage::kCompute), WorkgroupSize(1)});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+// Test for crbug.com/tint/728
+TEST_F(ResolverTest, ASTNodesAreReached) {
+ Structure("A", {Member("x", ty.array<f32, 4>(4))});
+ Structure("B", {Member("x", ty.array<f32, 4>(4))});
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverTest, ASTNodeNotReached) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b;
+ b.Expr("expr");
+ Resolver(&b).Resolve();
+ },
+ "internal compiler error: AST node 'tint::ast::IdentifierExpression' was "
+ "not reached by the resolver");
+}
+
+TEST_F(ResolverTest, ASTNodeReachedTwice) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b;
+ auto* expr = b.Expr(1);
+ b.Global("a", b.ty.i32(), ast::StorageClass::kPrivate, expr);
+ b.Global("b", b.ty.i32(), ast::StorageClass::kPrivate, expr);
+ Resolver(&b).Resolve();
+ },
+ "internal compiler error: AST node 'tint::ast::SintLiteralExpression' "
+ "was encountered twice in the same AST of a Program");
+}
+
+TEST_F(ResolverTest, UnaryOp_Not) {
+ Global("ident", ty.vec4<f32>(), ast::StorageClass::kPrivate);
+ auto* der = create<ast::UnaryOpExpression>(ast::UnaryOp::kNot,
+ Expr(Source{{12, 34}}, "ident"));
+ WrapInFunction(der);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: cannot logical negate expression of type 'vec4<f32>");
+}
+
+TEST_F(ResolverTest, UnaryOp_Complement) {
+ Global("ident", ty.vec4<f32>(), ast::StorageClass::kPrivate);
+ auto* der = create<ast::UnaryOpExpression>(ast::UnaryOp::kComplement,
+ Expr(Source{{12, 34}}, "ident"));
+ WrapInFunction(der);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: cannot bitwise complement expression of type 'vec4<f32>");
+}
+
+TEST_F(ResolverTest, UnaryOp_Negation) {
+ Global("ident", ty.u32(), ast::StorageClass::kPrivate);
+ auto* der = create<ast::UnaryOpExpression>(ast::UnaryOp::kNegation,
+ Expr(Source{{12, 34}}, "ident"));
+ WrapInFunction(der);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: cannot negate expression of type 'u32");
+}
+
+TEST_F(ResolverTest, TextureSampler_TextureSample) {
+ Global("t", ty.sampled_texture(ast::TextureDimension::k2d, ty.f32()),
+ GroupAndBinding(1, 1));
+ Global("s", ty.sampler(ast::SamplerKind::kSampler), GroupAndBinding(1, 2));
+
+ auto* call = CallStmt(Call("textureSample", "t", "s", vec2<f32>(1.0f, 2.0f)));
+ const ast::Function* f = Func("test_function", {}, ty.void_(), {call},
+ {Stage(ast::PipelineStage::kFragment)});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ const sem::Function* sf = Sem().Get(f);
+ auto pairs = sf->TextureSamplerPairs();
+ ASSERT_EQ(pairs.size(), 1u);
+ EXPECT_TRUE(pairs[0].first != nullptr);
+ EXPECT_TRUE(pairs[0].second != nullptr);
+}
+
+TEST_F(ResolverTest, TextureSampler_TextureSampleInFunction) {
+ Global("t", ty.sampled_texture(ast::TextureDimension::k2d, ty.f32()),
+ GroupAndBinding(1, 1));
+ Global("s", ty.sampler(ast::SamplerKind::kSampler), GroupAndBinding(1, 2));
+
+ auto* inner_call =
+ CallStmt(Call("textureSample", "t", "s", vec2<f32>(1.0f, 2.0f)));
+ const ast::Function* inner_func =
+ Func("inner_func", {}, ty.void_(), {inner_call});
+ auto* outer_call = CallStmt(Call("inner_func"));
+ const ast::Function* outer_func =
+ Func("outer_func", {}, ty.void_(), {outer_call},
+ {Stage(ast::PipelineStage::kFragment)});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto inner_pairs = Sem().Get(inner_func)->TextureSamplerPairs();
+ ASSERT_EQ(inner_pairs.size(), 1u);
+ EXPECT_TRUE(inner_pairs[0].first != nullptr);
+ EXPECT_TRUE(inner_pairs[0].second != nullptr);
+
+ auto outer_pairs = Sem().Get(outer_func)->TextureSamplerPairs();
+ ASSERT_EQ(outer_pairs.size(), 1u);
+ EXPECT_TRUE(outer_pairs[0].first != nullptr);
+ EXPECT_TRUE(outer_pairs[0].second != nullptr);
+}
+
+TEST_F(ResolverTest, TextureSampler_TextureSampleFunctionDiamondSameVariables) {
+ Global("t", ty.sampled_texture(ast::TextureDimension::k2d, ty.f32()),
+ GroupAndBinding(1, 1));
+ Global("s", ty.sampler(ast::SamplerKind::kSampler), GroupAndBinding(1, 2));
+
+ auto* inner_call_1 =
+ CallStmt(Call("textureSample", "t", "s", vec2<f32>(1.0f, 2.0f)));
+ const ast::Function* inner_func_1 =
+ Func("inner_func_1", {}, ty.void_(), {inner_call_1});
+ auto* inner_call_2 =
+ CallStmt(Call("textureSample", "t", "s", vec2<f32>(3.0f, 4.0f)));
+ const ast::Function* inner_func_2 =
+ Func("inner_func_2", {}, ty.void_(), {inner_call_2});
+ auto* outer_call_1 = CallStmt(Call("inner_func_1"));
+ auto* outer_call_2 = CallStmt(Call("inner_func_2"));
+ const ast::Function* outer_func =
+ Func("outer_func", {}, ty.void_(), {outer_call_1, outer_call_2},
+ {Stage(ast::PipelineStage::kFragment)});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto inner_pairs_1 = Sem().Get(inner_func_1)->TextureSamplerPairs();
+ ASSERT_EQ(inner_pairs_1.size(), 1u);
+ EXPECT_TRUE(inner_pairs_1[0].first != nullptr);
+ EXPECT_TRUE(inner_pairs_1[0].second != nullptr);
+
+ auto inner_pairs_2 = Sem().Get(inner_func_2)->TextureSamplerPairs();
+ ASSERT_EQ(inner_pairs_1.size(), 1u);
+ EXPECT_TRUE(inner_pairs_2[0].first != nullptr);
+ EXPECT_TRUE(inner_pairs_2[0].second != nullptr);
+
+ auto outer_pairs = Sem().Get(outer_func)->TextureSamplerPairs();
+ ASSERT_EQ(outer_pairs.size(), 1u);
+ EXPECT_TRUE(outer_pairs[0].first != nullptr);
+ EXPECT_TRUE(outer_pairs[0].second != nullptr);
+}
+
+TEST_F(ResolverTest,
+ TextureSampler_TextureSampleFunctionDiamondDifferentVariables) {
+ Global("t1", ty.sampled_texture(ast::TextureDimension::k2d, ty.f32()),
+ GroupAndBinding(1, 1));
+ Global("t2", ty.sampled_texture(ast::TextureDimension::k2d, ty.f32()),
+ GroupAndBinding(1, 2));
+ Global("s", ty.sampler(ast::SamplerKind::kSampler), GroupAndBinding(1, 3));
+
+ auto* inner_call_1 =
+ CallStmt(Call("textureSample", "t1", "s", vec2<f32>(1.0f, 2.0f)));
+ const ast::Function* inner_func_1 =
+ Func("inner_func_1", {}, ty.void_(), {inner_call_1});
+ auto* inner_call_2 =
+ CallStmt(Call("textureSample", "t2", "s", vec2<f32>(3.0f, 4.0f)));
+ const ast::Function* inner_func_2 =
+ Func("inner_func_2", {}, ty.void_(), {inner_call_2});
+ auto* outer_call_1 = CallStmt(Call("inner_func_1"));
+ auto* outer_call_2 = CallStmt(Call("inner_func_2"));
+ const ast::Function* outer_func =
+ Func("outer_func", {}, ty.void_(), {outer_call_1, outer_call_2},
+ {Stage(ast::PipelineStage::kFragment)});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto inner_pairs_1 = Sem().Get(inner_func_1)->TextureSamplerPairs();
+ ASSERT_EQ(inner_pairs_1.size(), 1u);
+ EXPECT_TRUE(inner_pairs_1[0].first != nullptr);
+ EXPECT_TRUE(inner_pairs_1[0].second != nullptr);
+
+ auto inner_pairs_2 = Sem().Get(inner_func_2)->TextureSamplerPairs();
+ ASSERT_EQ(inner_pairs_2.size(), 1u);
+ EXPECT_TRUE(inner_pairs_2[0].first != nullptr);
+ EXPECT_TRUE(inner_pairs_2[0].second != nullptr);
+
+ auto outer_pairs = Sem().Get(outer_func)->TextureSamplerPairs();
+ ASSERT_EQ(outer_pairs.size(), 2u);
+ EXPECT_TRUE(outer_pairs[0].first == inner_pairs_1[0].first);
+ EXPECT_TRUE(outer_pairs[0].second == inner_pairs_1[0].second);
+ EXPECT_TRUE(outer_pairs[1].first == inner_pairs_2[0].first);
+ EXPECT_TRUE(outer_pairs[1].second == inner_pairs_2[0].second);
+}
+
+TEST_F(ResolverTest, TextureSampler_TextureDimensions) {
+ Global("t", ty.sampled_texture(ast::TextureDimension::k2d, ty.f32()),
+ GroupAndBinding(1, 2));
+
+ auto* call = Call("textureDimensions", "t");
+ const ast::Function* f = WrapInFunction(call);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ const sem::Function* sf = Sem().Get(f);
+ auto pairs = sf->TextureSamplerPairs();
+ ASSERT_EQ(pairs.size(), 1u);
+ EXPECT_TRUE(pairs[0].first != nullptr);
+ EXPECT_TRUE(pairs[0].second == nullptr);
+}
+
+TEST_F(ResolverTest, ModuleDependencyOrderedDeclarations) {
+ auto* f0 = Func("f0", {}, ty.void_(), {});
+ auto* v0 = Global("v0", ty.i32(), ast::StorageClass::kPrivate);
+ auto* a0 = Alias("a0", ty.i32());
+ auto* s0 = Structure("s0", {Member("m", ty.i32())});
+ auto* f1 = Func("f1", {}, ty.void_(), {});
+ auto* v1 = Global("v1", ty.i32(), ast::StorageClass::kPrivate);
+ auto* a1 = Alias("a1", ty.i32());
+ auto* s1 = Structure("s1", {Member("m", ty.i32())});
+ auto* f2 = Func("f2", {}, ty.void_(), {});
+ auto* v2 = Global("v2", ty.i32(), ast::StorageClass::kPrivate);
+ auto* a2 = Alias("a2", ty.i32());
+ auto* s2 = Structure("s2", {Member("m", ty.i32())});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(Sem().Module(), nullptr);
+ EXPECT_THAT(Sem().Module()->DependencyOrderedDeclarations(),
+ ElementsAre(f0, v0, a0, s0, f1, v1, a1, s1, f2, v2, a2, s2));
+}
+
+} // namespace
+} // namespace resolver
+} // namespace tint
diff --git a/src/tint/resolver/resolver_test_helper.cc b/src/tint/resolver/resolver_test_helper.cc
new file mode 100644
index 0000000..aea14cf
--- /dev/null
+++ b/src/tint/resolver/resolver_test_helper.cc
@@ -0,0 +1,27 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/resolver/resolver_test_helper.h"
+
+#include <memory>
+
+namespace tint {
+namespace resolver {
+
+TestHelper::TestHelper() : resolver_(std::make_unique<Resolver>(this)) {}
+
+TestHelper::~TestHelper() = default;
+
+} // namespace resolver
+} // namespace tint
diff --git a/src/tint/resolver/resolver_test_helper.h b/src/tint/resolver/resolver_test_helper.h
new file mode 100644
index 0000000..b128a77
--- /dev/null
+++ b/src/tint/resolver/resolver_test_helper.h
@@ -0,0 +1,489 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_RESOLVER_RESOLVER_TEST_HELPER_H_
+#define SRC_TINT_RESOLVER_RESOLVER_TEST_HELPER_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "src/tint/program_builder.h"
+#include "src/tint/resolver/resolver.h"
+#include "src/tint/sem/expression.h"
+#include "src/tint/sem/statement.h"
+#include "src/tint/sem/variable.h"
+
+namespace tint {
+namespace resolver {
+
+/// Helper class for testing
+class TestHelper : public ProgramBuilder {
+ public:
+ /// Constructor
+ TestHelper();
+
+ /// Destructor
+ ~TestHelper() override;
+
+ /// @return a pointer to the Resolver
+ Resolver* r() const { return resolver_.get(); }
+
+ /// Returns the statement that holds the given expression.
+ /// @param expr the ast::Expression
+ /// @return the ast::Statement of the ast::Expression, or nullptr if the
+ /// expression is not owned by a statement.
+ const ast::Statement* StmtOf(const ast::Expression* expr) {
+ auto* sem_stmt = Sem().Get(expr)->Stmt();
+ return sem_stmt ? sem_stmt->Declaration() : nullptr;
+ }
+
+ /// Returns the BlockStatement that holds the given statement.
+ /// @param stmt the ast::Statement
+ /// @return the ast::BlockStatement that holds the ast::Statement, or nullptr
+ /// if the statement is not owned by a BlockStatement.
+ const ast::BlockStatement* BlockOf(const ast::Statement* stmt) {
+ auto* sem_stmt = Sem().Get(stmt);
+ return sem_stmt ? sem_stmt->Block()->Declaration() : nullptr;
+ }
+
+ /// Returns the BlockStatement that holds the given expression.
+ /// @param expr the ast::Expression
+ /// @return the ast::Statement of the ast::Expression, or nullptr if the
+ /// expression is not indirectly owned by a BlockStatement.
+ const ast::BlockStatement* BlockOf(const ast::Expression* expr) {
+ auto* sem_stmt = Sem().Get(expr)->Stmt();
+ return sem_stmt ? sem_stmt->Block()->Declaration() : nullptr;
+ }
+
+ /// Returns the semantic variable for the given identifier expression.
+ /// @param expr the identifier expression
+ /// @return the resolved sem::Variable of the identifier, or nullptr if
+ /// the expression did not resolve to a variable.
+ const sem::Variable* VarOf(const ast::Expression* expr) {
+ auto* sem_ident = Sem().Get(expr);
+ auto* var_user = sem_ident ? sem_ident->As<sem::VariableUser>() : nullptr;
+ return var_user ? var_user->Variable() : nullptr;
+ }
+
+ /// Checks that all the users of the given variable are as expected
+ /// @param var the variable to check
+ /// @param expected_users the expected users of the variable
+ /// @return true if all users are as expected
+ bool CheckVarUsers(const ast::Variable* var,
+ std::vector<const ast::Expression*>&& expected_users) {
+ auto& var_users = Sem().Get(var)->Users();
+ if (var_users.size() != expected_users.size()) {
+ return false;
+ }
+ for (size_t i = 0; i < var_users.size(); i++) {
+ if (var_users[i]->Declaration() != expected_users[i]) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ /// @param type a type
+ /// @returns the name for `type` that closely resembles how it would be
+ /// declared in WGSL.
+ std::string FriendlyName(const ast::Type* type) {
+ return type->FriendlyName(Symbols());
+ }
+
+ /// @param type a type
+ /// @returns the name for `type` that closely resembles how it would be
+ /// declared in WGSL.
+ std::string FriendlyName(const sem::Type* type) {
+ return type->FriendlyName(Symbols());
+ }
+
+ private:
+ std::unique_ptr<Resolver> resolver_;
+};
+
+class ResolverTest : public TestHelper, public testing::Test {};
+
+template <typename T>
+class ResolverTestWithParam : public TestHelper,
+ public testing::TestWithParam<T> {};
+
+namespace builder {
+
+using i32 = ProgramBuilder::i32;
+using u32 = ProgramBuilder::u32;
+using f32 = ProgramBuilder::f32;
+
+template <int N, typename T>
+struct vec {};
+
+template <typename T>
+using vec2 = vec<2, T>;
+
+template <typename T>
+using vec3 = vec<3, T>;
+
+template <typename T>
+using vec4 = vec<4, T>;
+
+template <int N, int M, typename T>
+struct mat {};
+
+template <typename T>
+using mat2x2 = mat<2, 2, T>;
+
+template <typename T>
+using mat2x3 = mat<2, 3, T>;
+
+template <typename T>
+using mat3x2 = mat<3, 2, T>;
+
+template <typename T>
+using mat3x3 = mat<3, 3, T>;
+
+template <typename T>
+using mat4x4 = mat<4, 4, T>;
+
+template <int N, typename T>
+struct array {};
+
+template <typename TO, int ID = 0>
+struct alias {};
+
+template <typename TO>
+using alias1 = alias<TO, 1>;
+
+template <typename TO>
+using alias2 = alias<TO, 2>;
+
+template <typename TO>
+using alias3 = alias<TO, 3>;
+
+template <typename TO>
+struct ptr {};
+
+using ast_type_func_ptr = const ast::Type* (*)(ProgramBuilder& b);
+using ast_expr_func_ptr = const ast::Expression* (*)(ProgramBuilder& b,
+ int elem_value);
+using sem_type_func_ptr = const sem::Type* (*)(ProgramBuilder& b);
+
+template <typename T>
+struct DataType {};
+
+/// Helper for building bool types and expressions
+template <>
+struct DataType<bool> {
+ /// false as bool is not a composite type
+ static constexpr bool is_composite = false;
+
+ /// @param b the ProgramBuilder
+ /// @return a new AST bool type
+ static inline const ast::Type* AST(ProgramBuilder& b) { return b.ty.bool_(); }
+ /// @param b the ProgramBuilder
+ /// @return the semantic bool type
+ static inline const sem::Type* Sem(ProgramBuilder& b) {
+ return b.create<sem::Bool>();
+ }
+ /// @param b the ProgramBuilder
+ /// @param elem_value the b
+ /// @return a new AST expression of the bool type
+ static inline const ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
+ return b.Expr(elem_value == 0);
+ }
+};
+
+/// Helper for building i32 types and expressions
+template <>
+struct DataType<i32> {
+ /// false as i32 is not a composite type
+ static constexpr bool is_composite = false;
+
+ /// @param b the ProgramBuilder
+ /// @return a new AST i32 type
+ static inline const ast::Type* AST(ProgramBuilder& b) { return b.ty.i32(); }
+ /// @param b the ProgramBuilder
+ /// @return the semantic i32 type
+ static inline const sem::Type* Sem(ProgramBuilder& b) {
+ return b.create<sem::I32>();
+ }
+ /// @param b the ProgramBuilder
+ /// @param elem_value the value i32 will be initialized with
+ /// @return a new AST i32 literal value expression
+ static inline const ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
+ return b.Expr(static_cast<i32>(elem_value));
+ }
+};
+
+/// Helper for building u32 types and expressions
+template <>
+struct DataType<u32> {
+ /// false as u32 is not a composite type
+ static constexpr bool is_composite = false;
+
+ /// @param b the ProgramBuilder
+ /// @return a new AST u32 type
+ static inline const ast::Type* AST(ProgramBuilder& b) { return b.ty.u32(); }
+ /// @param b the ProgramBuilder
+ /// @return the semantic u32 type
+ static inline const sem::Type* Sem(ProgramBuilder& b) {
+ return b.create<sem::U32>();
+ }
+ /// @param b the ProgramBuilder
+ /// @param elem_value the value u32 will be initialized with
+ /// @return a new AST u32 literal value expression
+ static inline const ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
+ return b.Expr(static_cast<u32>(elem_value));
+ }
+};
+
+/// Helper for building f32 types and expressions
+template <>
+struct DataType<f32> {
+ /// false as f32 is not a composite type
+ static constexpr bool is_composite = false;
+
+ /// @param b the ProgramBuilder
+ /// @return a new AST f32 type
+ static inline const ast::Type* AST(ProgramBuilder& b) { return b.ty.f32(); }
+ /// @param b the ProgramBuilder
+ /// @return the semantic f32 type
+ static inline const sem::Type* Sem(ProgramBuilder& b) {
+ return b.create<sem::F32>();
+ }
+ /// @param b the ProgramBuilder
+ /// @param elem_value the value f32 will be initialized with
+ /// @return a new AST f32 literal value expression
+ static inline const ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
+ return b.Expr(static_cast<f32>(elem_value));
+ }
+};
+
+/// Helper for building vector types and expressions
+template <int N, typename T>
+struct DataType<vec<N, T>> {
+ /// true as vectors are a composite type
+ static constexpr bool is_composite = true;
+
+ /// @param b the ProgramBuilder
+ /// @return a new AST vector type
+ static inline const ast::Type* AST(ProgramBuilder& b) {
+ return b.ty.vec(DataType<T>::AST(b), N);
+ }
+ /// @param b the ProgramBuilder
+ /// @return the semantic vector type
+ static inline const sem::Type* Sem(ProgramBuilder& b) {
+ return b.create<sem::Vector>(DataType<T>::Sem(b), N);
+ }
+ /// @param b the ProgramBuilder
+ /// @param elem_value the value each element in the vector will be initialized
+ /// with
+ /// @return a new AST vector value expression
+ static inline const ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
+ return b.Construct(AST(b), ExprArgs(b, elem_value));
+ }
+
+ /// @param b the ProgramBuilder
+ /// @param elem_value the value each element will be initialized with
+ /// @return the list of expressions that are used to construct the vector
+ static inline ast::ExpressionList ExprArgs(ProgramBuilder& b,
+ int elem_value) {
+ ast::ExpressionList args;
+ for (int i = 0; i < N; i++) {
+ args.emplace_back(DataType<T>::Expr(b, elem_value));
+ }
+ return args;
+ }
+};
+
+/// Helper for building matrix types and expressions
+template <int N, int M, typename T>
+struct DataType<mat<N, M, T>> {
+ /// true as matrices are a composite type
+ static constexpr bool is_composite = true;
+
+ /// @param b the ProgramBuilder
+ /// @return a new AST matrix type
+ static inline const ast::Type* AST(ProgramBuilder& b) {
+ return b.ty.mat(DataType<T>::AST(b), N, M);
+ }
+ /// @param b the ProgramBuilder
+ /// @return the semantic matrix type
+ static inline const sem::Type* Sem(ProgramBuilder& b) {
+ auto* column_type = b.create<sem::Vector>(DataType<T>::Sem(b), M);
+ return b.create<sem::Matrix>(column_type, N);
+ }
+ /// @param b the ProgramBuilder
+ /// @param elem_value the value each element in the matrix will be initialized
+ /// with
+ /// @return a new AST matrix value expression
+ static inline const ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
+ return b.Construct(AST(b), ExprArgs(b, elem_value));
+ }
+
+ /// @param b the ProgramBuilder
+ /// @param elem_value the value each element will be initialized with
+ /// @return the list of expressions that are used to construct the matrix
+ static inline ast::ExpressionList ExprArgs(ProgramBuilder& b,
+ int elem_value) {
+ ast::ExpressionList args;
+ for (int i = 0; i < N; i++) {
+ args.emplace_back(DataType<vec<M, T>>::Expr(b, elem_value));
+ }
+ return args;
+ }
+};
+
+/// Helper for building alias types and expressions
+template <typename T, int ID>
+struct DataType<alias<T, ID>> {
+ /// true if the aliased type is a composite type
+ static constexpr bool is_composite = DataType<T>::is_composite;
+
+ /// @param b the ProgramBuilder
+ /// @return a new AST alias type
+ static inline const ast::Type* AST(ProgramBuilder& b) {
+ auto name = b.Symbols().Register("alias_" + std::to_string(ID));
+ if (!b.AST().LookupType(name)) {
+ auto* type = DataType<T>::AST(b);
+ b.AST().AddTypeDecl(b.ty.alias(name, type));
+ }
+ return b.create<ast::TypeName>(name);
+ }
+ /// @param b the ProgramBuilder
+ /// @return the semantic aliased type
+ static inline const sem::Type* Sem(ProgramBuilder& b) {
+ return DataType<T>::Sem(b);
+ }
+
+ /// @param b the ProgramBuilder
+ /// @param elem_value the value nested elements will be initialized with
+ /// @return a new AST expression of the alias type
+ template <bool IS_COMPOSITE = is_composite>
+ static inline traits::EnableIf<!IS_COMPOSITE, const ast::Expression*> Expr(
+ ProgramBuilder& b,
+ int elem_value) {
+ // Cast
+ return b.Construct(AST(b), DataType<T>::Expr(b, elem_value));
+ }
+
+ /// @param b the ProgramBuilder
+ /// @param elem_value the value nested elements will be initialized with
+ /// @return a new AST expression of the alias type
+ template <bool IS_COMPOSITE = is_composite>
+ static inline traits::EnableIf<IS_COMPOSITE, const ast::Expression*> Expr(
+ ProgramBuilder& b,
+ int elem_value) {
+ // Construct
+ return b.Construct(AST(b), DataType<T>::ExprArgs(b, elem_value));
+ }
+};
+
+/// Helper for building pointer types and expressions
+template <typename T>
+struct DataType<ptr<T>> {
+ /// true if the pointer type is a composite type
+ static constexpr bool is_composite = false;
+
+ /// @param b the ProgramBuilder
+ /// @return a new AST alias type
+ static inline const ast::Type* AST(ProgramBuilder& b) {
+ return b.create<ast::Pointer>(DataType<T>::AST(b),
+ ast::StorageClass::kPrivate,
+ ast::Access::kReadWrite);
+ }
+ /// @param b the ProgramBuilder
+ /// @return the semantic aliased type
+ static inline const sem::Type* Sem(ProgramBuilder& b) {
+ return b.create<sem::Pointer>(DataType<T>::Sem(b),
+ ast::StorageClass::kPrivate,
+ ast::Access::kReadWrite);
+ }
+
+ /// @param b the ProgramBuilder
+ /// @return a new AST expression of the alias type
+ static inline const ast::Expression* Expr(ProgramBuilder& b, int /*unused*/) {
+ auto sym = b.Symbols().New("global_for_ptr");
+ b.Global(sym, DataType<T>::AST(b), ast::StorageClass::kPrivate);
+ return b.AddressOf(sym);
+ }
+};
+
+/// Helper for building array types and expressions
+template <int N, typename T>
+struct DataType<array<N, T>> {
+ /// true as arrays are a composite type
+ static constexpr bool is_composite = true;
+
+ /// @param b the ProgramBuilder
+ /// @return a new AST array type
+ static inline const ast::Type* AST(ProgramBuilder& b) {
+ return b.ty.array(DataType<T>::AST(b), N);
+ }
+ /// @param b the ProgramBuilder
+ /// @return the semantic array type
+ static inline const sem::Type* Sem(ProgramBuilder& b) {
+ auto* el = DataType<T>::Sem(b);
+ return b.create<sem::Array>(
+ /* element */ el,
+ /* count */ N,
+ /* align */ el->Align(),
+ /* size */ el->Size(),
+ /* stride */ el->Align(),
+ /* implicit_stride */ el->Align());
+ }
+ /// @param b the ProgramBuilder
+ /// @param elem_value the value each element in the array will be initialized
+ /// with
+ /// @return a new AST array value expression
+ static inline const ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
+ return b.Construct(AST(b), ExprArgs(b, elem_value));
+ }
+
+ /// @param b the ProgramBuilder
+ /// @param elem_value the value each element will be initialized with
+ /// @return the list of expressions that are used to construct the array
+ static inline ast::ExpressionList ExprArgs(ProgramBuilder& b,
+ int elem_value) {
+ ast::ExpressionList args;
+ for (int i = 0; i < N; i++) {
+ args.emplace_back(DataType<T>::Expr(b, elem_value));
+ }
+ return args;
+ }
+};
+
+/// Struct of all creation pointer types
+struct CreatePtrs {
+ /// ast node type create function
+ ast_type_func_ptr ast;
+ /// ast expression type create function
+ ast_expr_func_ptr expr;
+ /// sem type create function
+ sem_type_func_ptr sem;
+};
+
+/// Returns a CreatePtrs struct instance with all creation pointer types for
+/// type `T`
+template <typename T>
+constexpr CreatePtrs CreatePtrsFor() {
+ return {DataType<T>::AST, DataType<T>::Expr, DataType<T>::Sem};
+}
+
+} // namespace builder
+
+} // namespace resolver
+} // namespace tint
+
+#endif // SRC_TINT_RESOLVER_RESOLVER_TEST_HELPER_H_
diff --git a/src/tint/resolver/resolver_validation.cc b/src/tint/resolver/resolver_validation.cc
new file mode 100644
index 0000000..a23079d
--- /dev/null
+++ b/src/tint/resolver/resolver_validation.cc
@@ -0,0 +1,2369 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/resolver/resolver.h"
+
+#include <algorithm>
+#include <limits>
+#include <utility>
+
+#include "src/tint/ast/alias.h"
+#include "src/tint/ast/array.h"
+#include "src/tint/ast/assignment_statement.h"
+#include "src/tint/ast/bitcast_expression.h"
+#include "src/tint/ast/break_statement.h"
+#include "src/tint/ast/call_statement.h"
+#include "src/tint/ast/continue_statement.h"
+#include "src/tint/ast/depth_texture.h"
+#include "src/tint/ast/disable_validation_attribute.h"
+#include "src/tint/ast/discard_statement.h"
+#include "src/tint/ast/fallthrough_statement.h"
+#include "src/tint/ast/for_loop_statement.h"
+#include "src/tint/ast/id_attribute.h"
+#include "src/tint/ast/if_statement.h"
+#include "src/tint/ast/internal_attribute.h"
+#include "src/tint/ast/interpolate_attribute.h"
+#include "src/tint/ast/loop_statement.h"
+#include "src/tint/ast/matrix.h"
+#include "src/tint/ast/pointer.h"
+#include "src/tint/ast/return_statement.h"
+#include "src/tint/ast/sampled_texture.h"
+#include "src/tint/ast/sampler.h"
+#include "src/tint/ast/storage_texture.h"
+#include "src/tint/ast/switch_statement.h"
+#include "src/tint/ast/traverse_expressions.h"
+#include "src/tint/ast/type_name.h"
+#include "src/tint/ast/unary_op_expression.h"
+#include "src/tint/ast/variable_decl_statement.h"
+#include "src/tint/ast/vector.h"
+#include "src/tint/ast/workgroup_attribute.h"
+#include "src/tint/sem/array.h"
+#include "src/tint/sem/atomic_type.h"
+#include "src/tint/sem/call.h"
+#include "src/tint/sem/depth_multisampled_texture_type.h"
+#include "src/tint/sem/depth_texture_type.h"
+#include "src/tint/sem/for_loop_statement.h"
+#include "src/tint/sem/function.h"
+#include "src/tint/sem/if_statement.h"
+#include "src/tint/sem/loop_statement.h"
+#include "src/tint/sem/member_accessor_expression.h"
+#include "src/tint/sem/multisampled_texture_type.h"
+#include "src/tint/sem/pointer_type.h"
+#include "src/tint/sem/reference_type.h"
+#include "src/tint/sem/sampled_texture_type.h"
+#include "src/tint/sem/sampler_type.h"
+#include "src/tint/sem/statement.h"
+#include "src/tint/sem/storage_texture_type.h"
+#include "src/tint/sem/struct.h"
+#include "src/tint/sem/switch_statement.h"
+#include "src/tint/sem/type_constructor.h"
+#include "src/tint/sem/type_conversion.h"
+#include "src/tint/sem/variable.h"
+#include "src/tint/utils/defer.h"
+#include "src/tint/utils/map.h"
+#include "src/tint/utils/math.h"
+#include "src/tint/utils/reverse.h"
+#include "src/tint/utils/scoped_assignment.h"
+#include "src/tint/utils/transform.h"
+
+namespace tint {
+namespace resolver {
+namespace {
+
+bool IsValidStorageTextureDimension(ast::TextureDimension dim) {
+ switch (dim) {
+ case ast::TextureDimension::k1d:
+ case ast::TextureDimension::k2d:
+ case ast::TextureDimension::k2dArray:
+ case ast::TextureDimension::k3d:
+ return true;
+ default:
+ return false;
+ }
+}
+
+bool IsValidStorageTextureTexelFormat(ast::TexelFormat format) {
+ switch (format) {
+ case ast::TexelFormat::kR32Uint:
+ case ast::TexelFormat::kR32Sint:
+ case ast::TexelFormat::kR32Float:
+ case ast::TexelFormat::kRg32Uint:
+ case ast::TexelFormat::kRg32Sint:
+ case ast::TexelFormat::kRg32Float:
+ case ast::TexelFormat::kRgba8Unorm:
+ case ast::TexelFormat::kRgba8Snorm:
+ case ast::TexelFormat::kRgba8Uint:
+ case ast::TexelFormat::kRgba8Sint:
+ case ast::TexelFormat::kRgba16Uint:
+ case ast::TexelFormat::kRgba16Sint:
+ case ast::TexelFormat::kRgba16Float:
+ case ast::TexelFormat::kRgba32Uint:
+ case ast::TexelFormat::kRgba32Sint:
+ case ast::TexelFormat::kRgba32Float:
+ return true;
+ default:
+ return false;
+ }
+}
+
+// Helper to stringify a pipeline IO attribute.
+std::string attr_to_str(const ast::Attribute* attr) {
+ std::stringstream str;
+ if (auto* builtin = attr->As<ast::BuiltinAttribute>()) {
+ str << "builtin(" << builtin->builtin << ")";
+ } else if (auto* location = attr->As<ast::LocationAttribute>()) {
+ str << "location(" << location->value << ")";
+ }
+ return str.str();
+}
+
+template <typename CALLBACK>
+void TraverseCallChain(diag::List& diagnostics,
+ const sem::Function* from,
+ const sem::Function* to,
+ CALLBACK&& callback) {
+ for (auto* f : from->TransitivelyCalledFunctions()) {
+ if (f == to) {
+ callback(f);
+ return;
+ }
+ if (f->TransitivelyCalledFunctions().contains(to)) {
+ TraverseCallChain(diagnostics, f, to, callback);
+ callback(f);
+ return;
+ }
+ }
+ TINT_ICE(Resolver, diagnostics)
+ << "TraverseCallChain() 'from' does not transitively call 'to'";
+}
+
+} // namespace
+
+bool Resolver::ValidateAtomic(const ast::Atomic* a, const sem::Atomic* s) {
+ // https://gpuweb.github.io/gpuweb/wgsl/#atomic-types
+ // T must be either u32 or i32.
+ if (!s->Type()->IsAnyOf<sem::U32, sem::I32>()) {
+ AddError("atomic only supports i32 or u32 types",
+ a->type ? a->type->source : a->source);
+ return false;
+ }
+ return true;
+}
+
+bool Resolver::ValidateStorageTexture(const ast::StorageTexture* t) {
+ switch (t->access) {
+ case ast::Access::kWrite:
+ break;
+ case ast::Access::kUndefined:
+ AddError("storage texture missing access control", t->source);
+ return false;
+ default:
+ AddError("storage textures currently only support 'write' access control",
+ t->source);
+ return false;
+ }
+
+ if (!IsValidStorageTextureDimension(t->dim)) {
+ AddError("cube dimensions for storage textures are not supported",
+ t->source);
+ return false;
+ }
+
+ if (!IsValidStorageTextureTexelFormat(t->format)) {
+ AddError(
+ "image format must be one of the texel formats specified for storage "
+ "textues in https://gpuweb.github.io/gpuweb/wgsl/#texel-formats",
+ t->source);
+ return false;
+ }
+ return true;
+}
+
+bool Resolver::ValidateVariableConstructorOrCast(
+ const ast::Variable* var,
+ ast::StorageClass storage_class,
+ const sem::Type* storage_ty,
+ const sem::Type* rhs_ty) {
+ auto* value_type = rhs_ty->UnwrapRef(); // Implicit load of RHS
+
+ // Value type has to match storage type
+ if (storage_ty != value_type) {
+ std::string decl = var->is_const ? "let" : "var";
+ AddError("cannot initialize " + decl + " of type '" +
+ TypeNameOf(storage_ty) + "' with value of type '" +
+ TypeNameOf(rhs_ty) + "'",
+ var->source);
+ return false;
+ }
+
+ if (!var->is_const) {
+ switch (storage_class) {
+ case ast::StorageClass::kPrivate:
+ case ast::StorageClass::kFunction:
+ break; // Allowed an initializer
+ default:
+ // https://gpuweb.github.io/gpuweb/wgsl/#var-and-let
+ // Optionally has an initializer expression, if the variable is in the
+ // private or function storage classes.
+ AddError("var of storage class '" +
+ std::string(ast::ToString(storage_class)) +
+ "' cannot have an initializer. var initializers are only "
+ "supported for the storage classes "
+ "'private' and 'function'",
+ var->source);
+ return false;
+ }
+ }
+
+ return true;
+}
+
+bool Resolver::ValidateStorageClassLayout(const sem::Type* store_ty,
+ ast::StorageClass sc,
+ Source source) {
+ // https://gpuweb.github.io/gpuweb/wgsl/#storage-class-layout-constraints
+
+ auto is_uniform_struct_or_array = [sc](const sem::Type* ty) {
+ return sc == ast::StorageClass::kUniform &&
+ ty->IsAnyOf<sem::Array, sem::Struct>();
+ };
+
+ auto is_uniform_struct = [sc](const sem::Type* ty) {
+ return sc == ast::StorageClass::kUniform && ty->Is<sem::Struct>();
+ };
+
+ auto required_alignment_of = [&](const sem::Type* ty) {
+ uint32_t actual_align = ty->Align();
+ uint32_t required_align = actual_align;
+ if (is_uniform_struct_or_array(ty)) {
+ required_align = utils::RoundUp(16u, actual_align);
+ }
+ return required_align;
+ };
+
+ auto member_name_of = [this](const sem::StructMember* sm) {
+ return builder_->Symbols().NameFor(sm->Declaration()->symbol);
+ };
+
+ // Cache result of type + storage class pair.
+ if (!valid_type_storage_layouts_.emplace(store_ty, sc).second) {
+ return true;
+ }
+
+ if (!ast::IsHostShareable(sc)) {
+ return true;
+ }
+
+ if (auto* str = store_ty->As<sem::Struct>()) {
+ for (size_t i = 0; i < str->Members().size(); ++i) {
+ auto* const m = str->Members()[i];
+ uint32_t required_align = required_alignment_of(m->Type());
+
+ // Recurse into the member type.
+ if (!ValidateStorageClassLayout(m->Type(), sc,
+ m->Declaration()->type->source)) {
+ AddNote("see layout of struct:\n" + str->Layout(builder_->Symbols()),
+ str->Declaration()->source);
+ return false;
+ }
+
+ // Validate that member is at a valid byte offset
+ if (m->Offset() % required_align != 0) {
+ AddError("the offset of a struct member of type '" +
+ m->Type()->UnwrapRef()->FriendlyName(builder_->Symbols()) +
+ "' in storage class '" + ast::ToString(sc) +
+ "' must be a multiple of " +
+ std::to_string(required_align) + " bytes, but '" +
+ member_name_of(m) + "' is currently at offset " +
+ std::to_string(m->Offset()) +
+ ". Consider setting @align(" +
+ std::to_string(required_align) + ") on this member",
+ m->Declaration()->source);
+
+ AddNote("see layout of struct:\n" + str->Layout(builder_->Symbols()),
+ str->Declaration()->source);
+
+ if (auto* member_str = m->Type()->As<sem::Struct>()) {
+ AddNote("and layout of struct member:\n" +
+ member_str->Layout(builder_->Symbols()),
+ member_str->Declaration()->source);
+ }
+
+ return false;
+ }
+
+ // For uniform buffers, validate that the number of bytes between the
+ // previous member of type struct and the current is a multiple of 16
+ // bytes.
+ auto* const prev_member = (i == 0) ? nullptr : str->Members()[i - 1];
+ if (prev_member && is_uniform_struct(prev_member->Type())) {
+ const uint32_t prev_to_curr_offset =
+ m->Offset() - prev_member->Offset();
+ if (prev_to_curr_offset % 16 != 0) {
+ AddError(
+ "uniform storage requires that the number of bytes between the "
+ "start of the previous member of type struct and the current "
+ "member be a multiple of 16 bytes, but there are currently " +
+ std::to_string(prev_to_curr_offset) + " bytes between '" +
+ member_name_of(prev_member) + "' and '" + member_name_of(m) +
+ "'. Consider setting @align(16) on this member",
+ m->Declaration()->source);
+
+ AddNote("see layout of struct:\n" + str->Layout(builder_->Symbols()),
+ str->Declaration()->source);
+
+ auto* prev_member_str = prev_member->Type()->As<sem::Struct>();
+ AddNote("and layout of previous member struct:\n" +
+ prev_member_str->Layout(builder_->Symbols()),
+ prev_member_str->Declaration()->source);
+ return false;
+ }
+ }
+ }
+ }
+
+ // For uniform buffer array members, validate that array elements are
+ // aligned to 16 bytes
+ if (auto* arr = store_ty->As<sem::Array>()) {
+ // Recurse into the element type.
+ // TODO(crbug.com/tint/1388): Ideally we'd pass the source for nested
+ // element type here, but we can't easily get that from the semantic node.
+ // We should consider recursing through the AST type nodes instead.
+ if (!ValidateStorageClassLayout(arr->ElemType(), sc, source)) {
+ return false;
+ }
+
+ if (sc == ast::StorageClass::kUniform) {
+ // We already validated that this array member is itself aligned to 16
+ // bytes above, so we only need to validate that stride is a multiple
+ // of 16 bytes.
+ if (arr->Stride() % 16 != 0) {
+ // Since WGSL has no stride attribute, try to provide a useful hint
+ // for how the shader author can resolve the issue.
+ std::string hint;
+ if (arr->ElemType()->is_scalar()) {
+ hint =
+ "Consider using a vector or struct as the element type "
+ "instead.";
+ } else if (auto* vec = arr->ElemType()->As<sem::Vector>();
+ vec && vec->type()->Size() == 4) {
+ hint = "Consider using a vec4 instead.";
+ } else if (arr->ElemType()->Is<sem::Struct>()) {
+ hint =
+ "Consider using the @size attribute on the last struct "
+ "member.";
+ } else {
+ hint =
+ "Consider wrapping the element type in a struct and using "
+ "the "
+ "@size attribute.";
+ }
+ AddError(
+ "uniform storage requires that array elements be aligned to 16 "
+ "bytes, but array element alignment is currently " +
+ std::to_string(arr->Stride()) + ". " + hint,
+ source);
+ return false;
+ }
+ }
+ }
+
+ return true;
+}
+
+bool Resolver::ValidateStorageClassLayout(const sem::Variable* var) {
+ if (auto* str = var->Type()->UnwrapRef()->As<sem::Struct>()) {
+ if (!ValidateStorageClassLayout(str, var->StorageClass(),
+ str->Declaration()->source)) {
+ AddNote("see declaration of variable", var->Declaration()->source);
+ return false;
+ }
+ } else {
+ Source source = var->Declaration()->source;
+ if (var->Declaration()->type) {
+ source = var->Declaration()->type->source;
+ }
+ if (!ValidateStorageClassLayout(var->Type()->UnwrapRef(),
+ var->StorageClass(), source)) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+bool Resolver::ValidateGlobalVariable(const sem::Variable* var) {
+ auto* decl = var->Declaration();
+ if (!ValidateNoDuplicateAttributes(decl->attributes)) {
+ return false;
+ }
+
+ for (auto* attr : decl->attributes) {
+ if (decl->is_const) {
+ if (auto* id_attr = attr->As<ast::IdAttribute>()) {
+ uint32_t id = id_attr->value;
+ auto it = constant_ids_.find(id);
+ if (it != constant_ids_.end() && it->second != var) {
+ AddError("pipeline constant IDs must be unique", attr->source);
+ AddNote("a pipeline constant with an ID of " + std::to_string(id) +
+ " was previously declared "
+ "here:",
+ ast::GetAttribute<ast::IdAttribute>(
+ it->second->Declaration()->attributes)
+ ->source);
+ return false;
+ }
+ if (id > 65535) {
+ AddError("pipeline constant IDs must be between 0 and 65535",
+ attr->source);
+ return false;
+ }
+ } else {
+ AddError("attribute is not valid for constants", attr->source);
+ return false;
+ }
+ } else {
+ bool is_shader_io_attribute =
+ attr->IsAnyOf<ast::BuiltinAttribute, ast::InterpolateAttribute,
+ ast::InvariantAttribute, ast::LocationAttribute>();
+ bool has_io_storage_class =
+ var->StorageClass() == ast::StorageClass::kInput ||
+ var->StorageClass() == ast::StorageClass::kOutput;
+ if (!(attr->IsAnyOf<ast::BindingAttribute, ast::GroupAttribute,
+ ast::InternalAttribute>()) &&
+ (!is_shader_io_attribute || !has_io_storage_class)) {
+ AddError("attribute is not valid for variables", attr->source);
+ return false;
+ }
+ }
+ }
+
+ if (var->StorageClass() == ast::StorageClass::kFunction) {
+ AddError(
+ "variables declared at module scope must not be in the function "
+ "storage class",
+ decl->source);
+ return false;
+ }
+
+ auto binding_point = decl->BindingPoint();
+ switch (var->StorageClass()) {
+ case ast::StorageClass::kUniform:
+ case ast::StorageClass::kStorage:
+ case ast::StorageClass::kUniformConstant: {
+ // https://gpuweb.github.io/gpuweb/wgsl/#resource-interface
+ // Each resource variable must be declared with both group and binding
+ // attributes.
+ if (!binding_point) {
+ AddError(
+ "resource variables require @group and @binding "
+ "attributes",
+ decl->source);
+ return false;
+ }
+ break;
+ }
+ default:
+ if (binding_point.binding || binding_point.group) {
+ // https://gpuweb.github.io/gpuweb/wgsl/#attribute-binding
+ // Must only be applied to a resource variable
+ AddError(
+ "non-resource variables must not have @group or @binding "
+ "attributes",
+ decl->source);
+ return false;
+ }
+ }
+
+ // https://gpuweb.github.io/gpuweb/wgsl/#variable-declaration
+ // The access mode always has a default, and except for variables in the
+ // storage storage class, must not be written.
+ if (var->StorageClass() != ast::StorageClass::kStorage &&
+ decl->declared_access != ast::Access::kUndefined) {
+ AddError(
+ "only variables in <storage> storage class may declare an access mode",
+ decl->source);
+ return false;
+ }
+
+ if (!decl->is_const) {
+ if (!ValidateAtomicVariable(var)) {
+ return false;
+ }
+ }
+
+ return ValidateVariable(var);
+}
+
+// https://gpuweb.github.io/gpuweb/wgsl/#atomic-types
+// Atomic types may only be instantiated by variables in the workgroup storage
+// class or by storage buffer variables with a read_write access mode.
+bool Resolver::ValidateAtomicVariable(const sem::Variable* var) {
+ auto sc = var->StorageClass();
+ auto* decl = var->Declaration();
+ auto access = var->Access();
+ auto* type = var->Type()->UnwrapRef();
+ auto source = decl->type ? decl->type->source : decl->source;
+
+ if (type->Is<sem::Atomic>()) {
+ if (sc != ast::StorageClass::kWorkgroup) {
+ AddError(
+ "atomic variables must have <storage> or <workgroup> storage class",
+ source);
+ return false;
+ }
+ } else if (type->IsAnyOf<sem::Struct, sem::Array>()) {
+ auto found = atomic_composite_info_.find(type);
+ if (found != atomic_composite_info_.end()) {
+ if (sc != ast::StorageClass::kStorage &&
+ sc != ast::StorageClass::kWorkgroup) {
+ AddError(
+ "atomic variables must have <storage> or <workgroup> storage class",
+ source);
+ AddNote(
+ "atomic sub-type of '" + TypeNameOf(type) + "' is declared here",
+ found->second);
+ return false;
+ } else if (sc == ast::StorageClass::kStorage &&
+ access != ast::Access::kReadWrite) {
+ AddError(
+ "atomic variables in <storage> storage class must have read_write "
+ "access mode",
+ source);
+ AddNote(
+ "atomic sub-type of '" + TypeNameOf(type) + "' is declared here",
+ found->second);
+ return false;
+ }
+ }
+ }
+
+ return true;
+}
+
+bool Resolver::ValidateVariable(const sem::Variable* var) {
+ auto* decl = var->Declaration();
+ auto* storage_ty = var->Type()->UnwrapRef();
+
+ if (var->Is<sem::GlobalVariable>()) {
+ auto name = builder_->Symbols().NameFor(decl->symbol);
+ if (sem::ParseBuiltinType(name) != sem::BuiltinType::kNone) {
+ auto* kind = var->Declaration()->is_const ? "let" : "var";
+ AddError(
+ "'" + name +
+ "' is a builtin and cannot be redeclared as a module-scope " +
+ kind,
+ decl->source);
+ return false;
+ }
+ }
+
+ if (!decl->is_const && !IsStorable(storage_ty)) {
+ AddError(TypeNameOf(storage_ty) + " cannot be used as the type of a var",
+ decl->source);
+ return false;
+ }
+
+ if (decl->is_const && !var->Is<sem::Parameter>() &&
+ !(storage_ty->IsConstructible() || storage_ty->Is<sem::Pointer>())) {
+ AddError(TypeNameOf(storage_ty) + " cannot be used as the type of a let",
+ decl->source);
+ return false;
+ }
+
+ if (auto* r = storage_ty->As<sem::MultisampledTexture>()) {
+ if (r->dim() != ast::TextureDimension::k2d) {
+ AddError("only 2d multisampled textures are supported", decl->source);
+ return false;
+ }
+
+ if (!r->type()->UnwrapRef()->is_numeric_scalar()) {
+ AddError("texture_multisampled_2d<type>: type must be f32, i32 or u32",
+ decl->source);
+ return false;
+ }
+ }
+
+ if (var->Is<sem::LocalVariable>() && !decl->is_const &&
+ IsValidationEnabled(decl->attributes,
+ ast::DisabledValidation::kIgnoreStorageClass)) {
+ if (!var->Type()->UnwrapRef()->IsConstructible()) {
+ AddError("function variable must have a constructible type",
+ decl->type ? decl->type->source : decl->source);
+ return false;
+ }
+ }
+
+ if (storage_ty->is_handle() &&
+ decl->declared_storage_class != ast::StorageClass::kNone) {
+ // https://gpuweb.github.io/gpuweb/wgsl/#module-scope-variables
+ // If the store type is a texture type or a sampler type, then the
+ // variable declaration must not have a storage class attribute. The
+ // storage class will always be handle.
+ AddError("variables of type '" + TypeNameOf(storage_ty) +
+ "' must not have a storage class",
+ decl->source);
+ return false;
+ }
+
+ if (IsValidationEnabled(decl->attributes,
+ ast::DisabledValidation::kIgnoreStorageClass) &&
+ (decl->declared_storage_class == ast::StorageClass::kInput ||
+ decl->declared_storage_class == ast::StorageClass::kOutput)) {
+ AddError("invalid use of input/output storage class", decl->source);
+ return false;
+ }
+ return true;
+}
+
+bool Resolver::ValidateFunctionParameter(const ast::Function* func,
+ const sem::Variable* var) {
+ if (!ValidateVariable(var)) {
+ return false;
+ }
+
+ auto* decl = var->Declaration();
+
+ for (auto* attr : decl->attributes) {
+ if (!func->IsEntryPoint() && !attr->Is<ast::InternalAttribute>()) {
+ AddError("attribute is not valid for non-entry point function parameters",
+ attr->source);
+ return false;
+ } else if (!attr->IsAnyOf<ast::BuiltinAttribute, ast::InvariantAttribute,
+ ast::LocationAttribute, ast::InterpolateAttribute,
+ ast::InternalAttribute>() &&
+ (IsValidationEnabled(
+ decl->attributes,
+ ast::DisabledValidation::kEntryPointParameter) &&
+ IsValidationEnabled(
+ decl->attributes,
+ ast::DisabledValidation::
+ kIgnoreConstructibleFunctionParameter))) {
+ AddError("attribute is not valid for function parameters", attr->source);
+ return false;
+ }
+ }
+
+ if (auto* ref = var->Type()->As<sem::Pointer>()) {
+ auto sc = ref->StorageClass();
+ if (!(sc == ast::StorageClass::kFunction ||
+ sc == ast::StorageClass::kPrivate ||
+ sc == ast::StorageClass::kWorkgroup) &&
+ IsValidationEnabled(decl->attributes,
+ ast::DisabledValidation::kIgnoreStorageClass)) {
+ std::stringstream ss;
+ ss << "function parameter of pointer type cannot be in '" << sc
+ << "' storage class";
+ AddError(ss.str(), decl->source);
+ return false;
+ }
+ }
+
+ if (IsPlain(var->Type())) {
+ if (!var->Type()->IsConstructible() &&
+ IsValidationEnabled(
+ decl->attributes,
+ ast::DisabledValidation::kIgnoreConstructibleFunctionParameter)) {
+ AddError("store type of function parameter must be a constructible type",
+ decl->source);
+ return false;
+ }
+ } else if (!var->Type()
+ ->IsAnyOf<sem::Texture, sem::Sampler, sem::Pointer>()) {
+ AddError(
+ "store type of function parameter cannot be " + TypeNameOf(var->Type()),
+ decl->source);
+ return false;
+ }
+
+ return true;
+}
+
+bool Resolver::ValidateBuiltinAttribute(const ast::BuiltinAttribute* attr,
+ const sem::Type* storage_ty,
+ const bool is_input) {
+ auto* type = storage_ty->UnwrapRef();
+ const auto stage = current_function_
+ ? current_function_->Declaration()->PipelineStage()
+ : ast::PipelineStage::kNone;
+ std::stringstream stage_name;
+ stage_name << stage;
+ bool is_stage_mismatch = false;
+ bool is_output = !is_input;
+ switch (attr->builtin) {
+ case ast::Builtin::kPosition:
+ if (stage != ast::PipelineStage::kNone &&
+ !((is_input && stage == ast::PipelineStage::kFragment) ||
+ (is_output && stage == ast::PipelineStage::kVertex))) {
+ is_stage_mismatch = true;
+ }
+ if (!(type->is_float_vector() && type->As<sem::Vector>()->Width() == 4)) {
+ AddError("store type of " + attr_to_str(attr) + " must be 'vec4<f32>'",
+ attr->source);
+ return false;
+ }
+ break;
+ case ast::Builtin::kGlobalInvocationId:
+ case ast::Builtin::kLocalInvocationId:
+ case ast::Builtin::kNumWorkgroups:
+ case ast::Builtin::kWorkgroupId:
+ if (stage != ast::PipelineStage::kNone &&
+ !(stage == ast::PipelineStage::kCompute && is_input)) {
+ is_stage_mismatch = true;
+ }
+ if (!(type->is_unsigned_integer_vector() &&
+ type->As<sem::Vector>()->Width() == 3)) {
+ AddError("store type of " + attr_to_str(attr) + " must be 'vec3<u32>'",
+ attr->source);
+ return false;
+ }
+ break;
+ case ast::Builtin::kFragDepth:
+ if (stage != ast::PipelineStage::kNone &&
+ !(stage == ast::PipelineStage::kFragment && !is_input)) {
+ is_stage_mismatch = true;
+ }
+ if (!type->Is<sem::F32>()) {
+ AddError("store type of " + attr_to_str(attr) + " must be 'f32'",
+ attr->source);
+ return false;
+ }
+ break;
+ case ast::Builtin::kFrontFacing:
+ if (stage != ast::PipelineStage::kNone &&
+ !(stage == ast::PipelineStage::kFragment && is_input)) {
+ is_stage_mismatch = true;
+ }
+ if (!type->Is<sem::Bool>()) {
+ AddError("store type of " + attr_to_str(attr) + " must be 'bool'",
+ attr->source);
+ return false;
+ }
+ break;
+ case ast::Builtin::kLocalInvocationIndex:
+ if (stage != ast::PipelineStage::kNone &&
+ !(stage == ast::PipelineStage::kCompute && is_input)) {
+ is_stage_mismatch = true;
+ }
+ if (!type->Is<sem::U32>()) {
+ AddError("store type of " + attr_to_str(attr) + " must be 'u32'",
+ attr->source);
+ return false;
+ }
+ break;
+ case ast::Builtin::kVertexIndex:
+ case ast::Builtin::kInstanceIndex:
+ if (stage != ast::PipelineStage::kNone &&
+ !(stage == ast::PipelineStage::kVertex && is_input)) {
+ is_stage_mismatch = true;
+ }
+ if (!type->Is<sem::U32>()) {
+ AddError("store type of " + attr_to_str(attr) + " must be 'u32'",
+ attr->source);
+ return false;
+ }
+ break;
+ case ast::Builtin::kSampleMask:
+ if (stage != ast::PipelineStage::kNone &&
+ !(stage == ast::PipelineStage::kFragment)) {
+ is_stage_mismatch = true;
+ }
+ if (!type->Is<sem::U32>()) {
+ AddError("store type of " + attr_to_str(attr) + " must be 'u32'",
+ attr->source);
+ return false;
+ }
+ break;
+ case ast::Builtin::kSampleIndex:
+ if (stage != ast::PipelineStage::kNone &&
+ !(stage == ast::PipelineStage::kFragment && is_input)) {
+ is_stage_mismatch = true;
+ }
+ if (!type->Is<sem::U32>()) {
+ AddError("store type of " + attr_to_str(attr) + " must be 'u32'",
+ attr->source);
+ return false;
+ }
+ break;
+ default:
+ break;
+ }
+
+ if (is_stage_mismatch) {
+ AddError(attr_to_str(attr) + " cannot be used in " +
+ (is_input ? "input of " : "output of ") + stage_name.str() +
+ " pipeline stage",
+ attr->source);
+ return false;
+ }
+
+ return true;
+}
+
+bool Resolver::ValidateInterpolateAttribute(
+ const ast::InterpolateAttribute* attr,
+ const sem::Type* storage_ty) {
+ auto* type = storage_ty->UnwrapRef();
+
+ if (type->is_integer_scalar_or_vector() &&
+ attr->type != ast::InterpolationType::kFlat) {
+ AddError(
+ "interpolation type must be 'flat' for integral user-defined IO types",
+ attr->source);
+ return false;
+ }
+
+ if (attr->type == ast::InterpolationType::kFlat &&
+ attr->sampling != ast::InterpolationSampling::kNone) {
+ AddError("flat interpolation attribute must not have a sampling parameter",
+ attr->source);
+ return false;
+ }
+
+ return true;
+}
+
+bool Resolver::ValidateFunction(const sem::Function* func) {
+ auto* decl = func->Declaration();
+
+ auto name = builder_->Symbols().NameFor(decl->symbol);
+ if (sem::ParseBuiltinType(name) != sem::BuiltinType::kNone) {
+ AddError(
+ "'" + name + "' is a builtin and cannot be redeclared as a function",
+ decl->source);
+ return false;
+ }
+
+ auto workgroup_attr_count = 0;
+ for (auto* attr : decl->attributes) {
+ if (attr->Is<ast::WorkgroupAttribute>()) {
+ workgroup_attr_count++;
+ if (decl->PipelineStage() != ast::PipelineStage::kCompute) {
+ AddError(
+ "the workgroup_size attribute is only valid for compute stages",
+ attr->source);
+ return false;
+ }
+ } else if (!attr->IsAnyOf<ast::StageAttribute, ast::InternalAttribute>()) {
+ AddError("attribute is not valid for functions", attr->source);
+ return false;
+ }
+ }
+
+ if (decl->params.size() > 255) {
+ AddError("functions may declare at most 255 parameters", decl->source);
+ return false;
+ }
+
+ for (size_t i = 0; i < decl->params.size(); i++) {
+ if (!ValidateFunctionParameter(decl, func->Parameters()[i])) {
+ return false;
+ }
+ }
+
+ if (!func->ReturnType()->Is<sem::Void>()) {
+ if (!func->ReturnType()->IsConstructible()) {
+ AddError("function return type must be a constructible type",
+ decl->return_type->source);
+ return false;
+ }
+
+ if (decl->body) {
+ sem::Behaviors behaviors{sem::Behavior::kNext};
+ if (auto* last = decl->body->Last()) {
+ behaviors = Sem(last)->Behaviors();
+ }
+ if (behaviors.Contains(sem::Behavior::kNext)) {
+ AddError("missing return at end of function", decl->source);
+ return false;
+ }
+ } else if (IsValidationEnabled(
+ decl->attributes,
+ ast::DisabledValidation::kFunctionHasNoBody)) {
+ TINT_ICE(Resolver, diagnostics_)
+ << "Function " << builder_->Symbols().NameFor(decl->symbol)
+ << " has no body";
+ }
+
+ for (auto* attr : decl->return_type_attributes) {
+ if (!decl->IsEntryPoint()) {
+ AddError(
+ "attribute is not valid for non-entry point function return types",
+ attr->source);
+ return false;
+ }
+ if (!attr->IsAnyOf<ast::BuiltinAttribute, ast::InternalAttribute,
+ ast::LocationAttribute, ast::InterpolateAttribute,
+ ast::InvariantAttribute>() &&
+ (IsValidationEnabled(decl->attributes,
+ ast::DisabledValidation::kEntryPointParameter) &&
+ IsValidationEnabled(decl->attributes,
+ ast::DisabledValidation::
+ kIgnoreConstructibleFunctionParameter))) {
+ AddError("attribute is not valid for entry point return types",
+ attr->source);
+ return false;
+ }
+ }
+ }
+
+ if (decl->IsEntryPoint()) {
+ if (!ValidateEntryPoint(func)) {
+ return false;
+ }
+ }
+
+ // https://www.w3.org/TR/WGSL/#behaviors-rules
+ // a function behavior is always one of {}, {Next}, {Discard}, or
+ // {Next, Discard}.
+ if (func->Behaviors() != sem::Behaviors{} && // NOLINT: bad warning
+ func->Behaviors() != sem::Behavior::kNext &&
+ func->Behaviors() != sem::Behavior::kDiscard &&
+ func->Behaviors() != sem::Behaviors{sem::Behavior::kNext, //
+ sem::Behavior::kDiscard}) {
+ TINT_ICE(Resolver, diagnostics_)
+ << "function '" << name << "' behaviors are: " << func->Behaviors();
+ }
+
+ return true;
+}
+
+bool Resolver::ValidateEntryPoint(const sem::Function* func) {
+ auto* decl = func->Declaration();
+
+ // Use a lambda to validate the entry point attributes for a type.
+ // Persistent state is used to track which builtins and locations have
+ // already been seen, in order to catch conflicts.
+ // TODO(jrprice): This state could be stored in sem::Function instead, and
+ // then passed to sem::Function since it would be useful there too.
+ std::unordered_set<ast::Builtin> builtins;
+ std::unordered_set<uint32_t> locations;
+ enum class ParamOrRetType {
+ kParameter,
+ kReturnType,
+ };
+
+ // Inner lambda that is applied to a type and all of its members.
+ auto validate_entry_point_attributes_inner = [&](const ast::AttributeList&
+ attrs,
+ const sem::Type* ty,
+ Source source,
+ ParamOrRetType param_or_ret,
+ bool is_struct_member) {
+ // Scan attributes for pipeline IO attributes.
+ // Check for overlap with attributes that have been seen previously.
+ const ast::Attribute* pipeline_io_attribute = nullptr;
+ const ast::InterpolateAttribute* interpolate_attribute = nullptr;
+ const ast::InvariantAttribute* invariant_attribute = nullptr;
+ for (auto* attr : attrs) {
+ auto is_invalid_compute_shader_attribute = false;
+ if (auto* builtin = attr->As<ast::BuiltinAttribute>()) {
+ if (pipeline_io_attribute) {
+ AddError("multiple entry point IO attributes", attr->source);
+ AddNote("previously consumed " + attr_to_str(pipeline_io_attribute),
+ pipeline_io_attribute->source);
+ return false;
+ }
+ pipeline_io_attribute = attr;
+
+ if (builtins.count(builtin->builtin)) {
+ AddError(attr_to_str(builtin) +
+ " attribute appears multiple times as pipeline " +
+ (param_or_ret == ParamOrRetType::kParameter ? "input"
+ : "output"),
+ decl->source);
+ return false;
+ }
+
+ if (!ValidateBuiltinAttribute(
+ builtin, ty,
+ /* is_input */ param_or_ret == ParamOrRetType::kParameter)) {
+ return false;
+ }
+ builtins.emplace(builtin->builtin);
+ } else if (auto* location = attr->As<ast::LocationAttribute>()) {
+ if (pipeline_io_attribute) {
+ AddError("multiple entry point IO attributes", attr->source);
+ AddNote("previously consumed " + attr_to_str(pipeline_io_attribute),
+ pipeline_io_attribute->source);
+ return false;
+ }
+ pipeline_io_attribute = attr;
+
+ bool is_input = param_or_ret == ParamOrRetType::kParameter;
+ if (!ValidateLocationAttribute(location, ty, locations, source,
+ is_input)) {
+ return false;
+ }
+ } else if (auto* interpolate = attr->As<ast::InterpolateAttribute>()) {
+ if (decl->PipelineStage() == ast::PipelineStage::kCompute) {
+ is_invalid_compute_shader_attribute = true;
+ } else if (!ValidateInterpolateAttribute(interpolate, ty)) {
+ return false;
+ }
+ interpolate_attribute = interpolate;
+ } else if (auto* invariant = attr->As<ast::InvariantAttribute>()) {
+ if (decl->PipelineStage() == ast::PipelineStage::kCompute) {
+ is_invalid_compute_shader_attribute = true;
+ }
+ invariant_attribute = invariant;
+ }
+ if (is_invalid_compute_shader_attribute) {
+ std::string input_or_output =
+ param_or_ret == ParamOrRetType::kParameter ? "inputs" : "output";
+ AddError("attribute is not valid for compute shader " + input_or_output,
+ attr->source);
+ return false;
+ }
+ }
+
+ if (IsValidationEnabled(attrs,
+ ast::DisabledValidation::kEntryPointParameter)) {
+ if (is_struct_member && ty->Is<sem::Struct>()) {
+ AddError("nested structures cannot be used for entry point IO", source);
+ return false;
+ }
+
+ if (!ty->Is<sem::Struct>() && !pipeline_io_attribute) {
+ std::string err = "missing entry point IO attribute";
+ if (!is_struct_member) {
+ err +=
+ (param_or_ret == ParamOrRetType::kParameter ? " on parameter"
+ : " on return type");
+ }
+ AddError(err, source);
+ return false;
+ }
+
+ if (pipeline_io_attribute &&
+ pipeline_io_attribute->Is<ast::LocationAttribute>()) {
+ if (ty->is_integer_scalar_or_vector() && !interpolate_attribute) {
+ if (decl->PipelineStage() == ast::PipelineStage::kVertex &&
+ param_or_ret == ParamOrRetType::kReturnType) {
+ AddError(
+ "integral user-defined vertex outputs must have a flat "
+ "interpolation attribute",
+ source);
+ return false;
+ }
+ if (decl->PipelineStage() == ast::PipelineStage::kFragment &&
+ param_or_ret == ParamOrRetType::kParameter) {
+ AddError(
+ "integral user-defined fragment inputs must have a flat "
+ "interpolation attribute",
+ source);
+ return false;
+ }
+ }
+ }
+
+ if (interpolate_attribute) {
+ if (!pipeline_io_attribute ||
+ !pipeline_io_attribute->Is<ast::LocationAttribute>()) {
+ AddError("interpolate attribute must only be used with @location",
+ interpolate_attribute->source);
+ return false;
+ }
+ }
+
+ if (invariant_attribute) {
+ bool has_position = false;
+ if (pipeline_io_attribute) {
+ if (auto* builtin =
+ pipeline_io_attribute->As<ast::BuiltinAttribute>()) {
+ has_position = (builtin->builtin == ast::Builtin::kPosition);
+ }
+ }
+ if (!has_position) {
+ AddError(
+ "invariant attribute must only be applied to a position "
+ "builtin",
+ invariant_attribute->source);
+ return false;
+ }
+ }
+ }
+ return true;
+ };
+
+ // Outer lambda for validating the entry point attributes for a type.
+ auto validate_entry_point_attributes = [&](const ast::AttributeList& attrs,
+ const sem::Type* ty, Source source,
+ ParamOrRetType param_or_ret) {
+ if (!validate_entry_point_attributes_inner(attrs, ty, source, param_or_ret,
+ /*is_struct_member*/ false)) {
+ return false;
+ }
+
+ if (auto* str = ty->As<sem::Struct>()) {
+ for (auto* member : str->Members()) {
+ if (!validate_entry_point_attributes_inner(
+ member->Declaration()->attributes, member->Type(),
+ member->Declaration()->source, param_or_ret,
+ /*is_struct_member*/ true)) {
+ AddNote("while analysing entry point '" +
+ builder_->Symbols().NameFor(decl->symbol) + "'",
+ decl->source);
+ return false;
+ }
+ }
+ }
+
+ return true;
+ };
+
+ for (auto* param : func->Parameters()) {
+ auto* param_decl = param->Declaration();
+ if (!validate_entry_point_attributes(param_decl->attributes, param->Type(),
+ param_decl->source,
+ ParamOrRetType::kParameter)) {
+ return false;
+ }
+ }
+
+ // Clear IO sets after parameter validation. Builtin and location attributes
+ // in return types should be validated independently from those used in
+ // parameters.
+ builtins.clear();
+ locations.clear();
+
+ if (!func->ReturnType()->Is<sem::Void>()) {
+ if (!validate_entry_point_attributes(decl->return_type_attributes,
+ func->ReturnType(), decl->source,
+ ParamOrRetType::kReturnType)) {
+ return false;
+ }
+ }
+
+ if (decl->PipelineStage() == ast::PipelineStage::kVertex &&
+ builtins.count(ast::Builtin::kPosition) == 0) {
+ // Check module-scope variables, as the SPIR-V sanitizer generates these.
+ bool found = false;
+ for (auto* global : func->TransitivelyReferencedGlobals()) {
+ if (auto* builtin = ast::GetAttribute<ast::BuiltinAttribute>(
+ global->Declaration()->attributes)) {
+ if (builtin->builtin == ast::Builtin::kPosition) {
+ found = true;
+ break;
+ }
+ }
+ }
+ if (!found) {
+ AddError(
+ "a vertex shader must include the 'position' builtin in its return "
+ "type",
+ decl->source);
+ return false;
+ }
+ }
+
+ if (decl->PipelineStage() == ast::PipelineStage::kCompute) {
+ if (!ast::HasAttribute<ast::WorkgroupAttribute>(decl->attributes)) {
+ AddError(
+ "a compute shader must include 'workgroup_size' in its "
+ "attributes",
+ decl->source);
+ return false;
+ }
+ }
+
+ // Validate there are no resource variable binding collisions
+ std::unordered_map<sem::BindingPoint, const ast::Variable*> binding_points;
+ for (auto* var : func->TransitivelyReferencedGlobals()) {
+ auto* var_decl = var->Declaration();
+ if (!var_decl->BindingPoint()) {
+ continue;
+ }
+ auto bp = var->BindingPoint();
+ auto res = binding_points.emplace(bp, var_decl);
+ if (!res.second &&
+ IsValidationEnabled(decl->attributes,
+ ast::DisabledValidation::kBindingPointCollision) &&
+ IsValidationEnabled(res.first->second->attributes,
+ ast::DisabledValidation::kBindingPointCollision)) {
+ // https://gpuweb.github.io/gpuweb/wgsl/#resource-interface
+ // Bindings must not alias within a shader stage: two different
+ // variables in the resource interface of a given shader must not have
+ // the same group and binding values, when considered as a pair of
+ // values.
+ auto func_name = builder_->Symbols().NameFor(decl->symbol);
+ AddError("entry point '" + func_name +
+ "' references multiple variables that use the "
+ "same resource binding @group(" +
+ std::to_string(bp.group) + "), @binding(" +
+ std::to_string(bp.binding) + ")",
+ var_decl->source);
+ AddNote("first resource binding usage declared here",
+ res.first->second->source);
+ return false;
+ }
+ }
+
+ return true;
+}
+
+bool Resolver::ValidateStatements(const ast::StatementList& stmts) {
+ for (auto* stmt : stmts) {
+ if (!Sem(stmt)->IsReachable()) {
+ /// TODO(https://github.com/gpuweb/gpuweb/issues/2378): This may need to
+ /// become an error.
+ AddWarning("code is unreachable", stmt->source);
+ break;
+ }
+ }
+ return true;
+}
+
+bool Resolver::ValidateBitcast(const ast::BitcastExpression* cast,
+ const sem::Type* to) {
+ auto* from = TypeOf(cast->expr)->UnwrapRef();
+ if (!from->is_numeric_scalar_or_vector()) {
+ AddError("'" + TypeNameOf(from) + "' cannot be bitcast",
+ cast->expr->source);
+ return false;
+ }
+ if (!to->is_numeric_scalar_or_vector()) {
+ AddError("cannot bitcast to '" + TypeNameOf(to) + "'", cast->type->source);
+ return false;
+ }
+
+ auto width = [&](const sem::Type* ty) {
+ if (auto* vec = ty->As<sem::Vector>()) {
+ return vec->Width();
+ }
+ return 1u;
+ };
+
+ if (width(from) != width(to)) {
+ AddError("cannot bitcast from '" + TypeNameOf(from) + "' to '" +
+ TypeNameOf(to) + "'",
+ cast->source);
+ return false;
+ }
+
+ return true;
+}
+
+bool Resolver::ValidateBreakStatement(const sem::Statement* stmt) {
+ if (!stmt->FindFirstParent<sem::LoopBlockStatement, sem::CaseStatement>()) {
+ AddError("break statement must be in a loop or switch case",
+ stmt->Declaration()->source);
+ return false;
+ }
+ if (auto* continuing = ClosestContinuing(/*stop_at_loop*/ true)) {
+ auto fail = [&](const char* note_msg, const Source& note_src) {
+ constexpr const char* kErrorMsg =
+ "break statement in a continuing block must be the single statement "
+ "of an if statement's true or false block, and that if statement "
+ "must be the last statement of the continuing block";
+ AddError(kErrorMsg, stmt->Declaration()->source);
+ AddNote(note_msg, note_src);
+ return false;
+ };
+
+ if (auto* block = stmt->Parent()->As<sem::BlockStatement>()) {
+ auto* block_parent = block->Parent();
+ auto* if_stmt = block_parent->As<sem::IfStatement>();
+ auto* el_stmt = block_parent->As<sem::ElseStatement>();
+ if (el_stmt) {
+ if_stmt = el_stmt->Parent();
+ }
+ if (!if_stmt) {
+ return fail("break statement is not directly in if statement block",
+ stmt->Declaration()->source);
+ }
+ if (block->Declaration()->statements.size() != 1) {
+ return fail("if statement block contains multiple statements",
+ block->Declaration()->source);
+ }
+ for (auto* el : if_stmt->Declaration()->else_statements) {
+ if (el->condition) {
+ return fail("else has condition", el->condition->source);
+ }
+ bool el_contains_break = el_stmt && el == el_stmt->Declaration();
+ if (el_contains_break) {
+ if (auto* true_block = if_stmt->Declaration()->body;
+ !true_block->Empty()) {
+ return fail("non-empty true block", true_block->source);
+ }
+ } else {
+ if (!el->body->Empty()) {
+ return fail("non-empty false block", el->body->source);
+ }
+ }
+ }
+ if (if_stmt->Parent()->Declaration() != continuing) {
+ return fail(
+ "if statement containing break statement is not directly in "
+ "continuing block",
+ if_stmt->Declaration()->source);
+ }
+ if (auto* cont_block = continuing->As<ast::BlockStatement>()) {
+ if (if_stmt->Declaration() != cont_block->Last()) {
+ return fail(
+ "if statement containing break statement is not the last "
+ "statement of the continuing block",
+ if_stmt->Declaration()->source);
+ }
+ }
+ }
+ }
+ return true;
+}
+
+bool Resolver::ValidateContinueStatement(const sem::Statement* stmt) {
+ if (auto* continuing = ClosestContinuing(/*stop_at_loop*/ true)) {
+ AddError("continuing blocks must not contain a continue statement",
+ stmt->Declaration()->source);
+ if (continuing != stmt->Declaration() &&
+ continuing != stmt->Parent()->Declaration()) {
+ AddNote("see continuing block here", continuing->source);
+ }
+ return false;
+ }
+
+ if (!stmt->FindFirstParent<sem::LoopBlockStatement>()) {
+ AddError("continue statement must be in a loop",
+ stmt->Declaration()->source);
+ return false;
+ }
+
+ return true;
+}
+
+bool Resolver::ValidateDiscardStatement(const sem::Statement* stmt) {
+ if (auto* continuing = ClosestContinuing(/*stop_at_loop*/ false)) {
+ AddError("continuing blocks must not contain a discard statement",
+ stmt->Declaration()->source);
+ if (continuing != stmt->Declaration() &&
+ continuing != stmt->Parent()->Declaration()) {
+ AddNote("see continuing block here", continuing->source);
+ }
+ return false;
+ }
+ return true;
+}
+
+bool Resolver::ValidateFallthroughStatement(const sem::Statement* stmt) {
+ if (auto* block = As<sem::BlockStatement>(stmt->Parent())) {
+ if (auto* c = As<sem::CaseStatement>(block->Parent())) {
+ if (block->Declaration()->Last() == stmt->Declaration()) {
+ if (auto* s = As<sem::SwitchStatement>(c->Parent())) {
+ if (c->Declaration() != s->Declaration()->body.back()) {
+ return true;
+ }
+ AddError(
+ "a fallthrough statement must not be used in the last switch "
+ "case",
+ stmt->Declaration()->source);
+ return false;
+ }
+ }
+ }
+ }
+ AddError(
+ "fallthrough must only be used as the last statement of a case block",
+ stmt->Declaration()->source);
+ return false;
+}
+
+bool Resolver::ValidateElseStatement(const sem::ElseStatement* stmt) {
+ if (auto* cond = stmt->Condition()) {
+ auto* cond_ty = cond->Type()->UnwrapRef();
+ if (!cond_ty->Is<sem::Bool>()) {
+ AddError(
+ "else statement condition must be bool, got " + TypeNameOf(cond_ty),
+ stmt->Condition()->Declaration()->source);
+ return false;
+ }
+ }
+ return true;
+}
+
+bool Resolver::ValidateLoopStatement(const sem::LoopStatement* stmt) {
+ if (stmt->Behaviors().Empty()) {
+ AddError("loop does not exit", stmt->Declaration()->source.Begin());
+ return false;
+ }
+ return true;
+}
+
+bool Resolver::ValidateForLoopStatement(const sem::ForLoopStatement* stmt) {
+ if (stmt->Behaviors().Empty()) {
+ AddError("for-loop does not exit", stmt->Declaration()->source.Begin());
+ return false;
+ }
+ if (auto* cond = stmt->Condition()) {
+ auto* cond_ty = cond->Type()->UnwrapRef();
+ if (!cond_ty->Is<sem::Bool>()) {
+ AddError("for-loop condition must be bool, got " + TypeNameOf(cond_ty),
+ stmt->Condition()->Declaration()->source);
+ return false;
+ }
+ }
+ return true;
+}
+
+bool Resolver::ValidateIfStatement(const sem::IfStatement* stmt) {
+ auto* cond_ty = stmt->Condition()->Type()->UnwrapRef();
+ if (!cond_ty->Is<sem::Bool>()) {
+ AddError("if statement condition must be bool, got " + TypeNameOf(cond_ty),
+ stmt->Condition()->Declaration()->source);
+ return false;
+ }
+ return true;
+}
+
+bool Resolver::ValidateBuiltinCall(const sem::Call* call) {
+ if (call->Type()->Is<sem::Void>()) {
+ bool is_call_statement = false;
+ if (auto* call_stmt = As<ast::CallStatement>(call->Stmt()->Declaration())) {
+ if (call_stmt->expr == call->Declaration()) {
+ is_call_statement = true;
+ }
+ }
+ if (!is_call_statement) {
+ // https://gpuweb.github.io/gpuweb/wgsl/#function-call-expr
+ // If the called function does not return a value, a function call
+ // statement should be used instead.
+ auto* ident = call->Declaration()->target.name;
+ auto name = builder_->Symbols().NameFor(ident->symbol);
+ AddError("builtin '" + name + "' does not return a value",
+ call->Declaration()->source);
+ return false;
+ }
+ }
+
+ return true;
+}
+
+bool Resolver::ValidateTextureBuiltinFunction(const sem::Call* call) {
+ auto* builtin = call->Target()->As<sem::Builtin>();
+ if (!builtin) {
+ return false;
+ }
+
+ std::string func_name = builtin->str();
+ auto& signature = builtin->Signature();
+
+ auto check_arg_is_constexpr = [&](sem::ParameterUsage usage, int min,
+ int max) {
+ auto index = signature.IndexOf(usage);
+ if (index < 0) {
+ return true;
+ }
+ std::string name = sem::str(usage);
+ auto* arg = call->Arguments()[index];
+ if (auto values = arg->ConstantValue()) {
+ // Assert that the constant values are of the expected type.
+ if (!values.Type()->IsAnyOf<sem::I32, sem::Vector>() ||
+ !values.ElementType()->Is<sem::I32>()) {
+ TINT_ICE(Resolver, diagnostics_)
+ << "failed to resolve '" + func_name + "' " << name
+ << " parameter type";
+ return false;
+ }
+
+ // Currently const_expr is restricted to literals and type constructors.
+ // Check that that's all we have for the parameter.
+ bool is_const_expr = true;
+ ast::TraverseExpressions(
+ arg->Declaration(), diagnostics_, [&](const ast::Expression* e) {
+ if (e->IsAnyOf<ast::LiteralExpression, ast::CallExpression>()) {
+ return ast::TraverseAction::Descend;
+ }
+ is_const_expr = false;
+ return ast::TraverseAction::Stop;
+ });
+ if (is_const_expr) {
+ auto vector = builtin->Parameters()[index]->Type()->Is<sem::Vector>();
+ for (size_t i = 0; i < values.Elements().size(); i++) {
+ auto value = values.Elements()[i].i32;
+ if (value < min || value > max) {
+ if (vector) {
+ AddError("each component of the " + name +
+ " argument must be at least " + std::to_string(min) +
+ " and at most " + std::to_string(max) + ". " + name +
+ " component " + std::to_string(i) + " is " +
+ std::to_string(value),
+ arg->Declaration()->source);
+ } else {
+ AddError("the " + name + " argument must be at least " +
+ std::to_string(min) + " and at most " +
+ std::to_string(max) + ". " + name + " is " +
+ std::to_string(value),
+ arg->Declaration()->source);
+ }
+ return false;
+ }
+ }
+ return true;
+ }
+ }
+ AddError("the " + name + " argument must be a const_expression",
+ arg->Declaration()->source);
+ return false;
+ };
+
+ return check_arg_is_constexpr(sem::ParameterUsage::kOffset, -8, 7) &&
+ check_arg_is_constexpr(sem::ParameterUsage::kComponent, 0, 3);
+}
+
+bool Resolver::ValidateFunctionCall(const sem::Call* call) {
+ auto* decl = call->Declaration();
+ auto* target = call->Target()->As<sem::Function>();
+ auto sym = decl->target.name->symbol;
+ auto name = builder_->Symbols().NameFor(sym);
+
+ if (target->Declaration()->IsEntryPoint()) {
+ // https://www.w3.org/TR/WGSL/#function-restriction
+ // An entry point must never be the target of a function call.
+ AddError("entry point functions cannot be the target of a function call",
+ decl->source);
+ return false;
+ }
+
+ if (decl->args.size() != target->Parameters().size()) {
+ bool more = decl->args.size() > target->Parameters().size();
+ AddError("too " + (more ? std::string("many") : std::string("few")) +
+ " arguments in call to '" + name + "', expected " +
+ std::to_string(target->Parameters().size()) + ", got " +
+ std::to_string(call->Arguments().size()),
+ decl->source);
+ return false;
+ }
+
+ for (size_t i = 0; i < call->Arguments().size(); ++i) {
+ const sem::Variable* param = target->Parameters()[i];
+ const ast::Expression* arg_expr = decl->args[i];
+ auto* param_type = param->Type();
+ auto* arg_type = TypeOf(arg_expr)->UnwrapRef();
+
+ if (param_type != arg_type) {
+ AddError("type mismatch for argument " + std::to_string(i + 1) +
+ " in call to '" + name + "', expected '" +
+ TypeNameOf(param_type) + "', got '" + TypeNameOf(arg_type) +
+ "'",
+ arg_expr->source);
+ return false;
+ }
+
+ if (param_type->Is<sem::Pointer>()) {
+ auto is_valid = false;
+ if (auto* ident_expr = arg_expr->As<ast::IdentifierExpression>()) {
+ auto* var = ResolvedSymbol<sem::Variable>(ident_expr);
+ if (!var) {
+ TINT_ICE(Resolver, diagnostics_) << "failed to resolve identifier";
+ return false;
+ }
+ if (var->Is<sem::Parameter>()) {
+ is_valid = true;
+ }
+ } else if (auto* unary = arg_expr->As<ast::UnaryOpExpression>()) {
+ if (unary->op == ast::UnaryOp::kAddressOf) {
+ if (auto* ident_unary =
+ unary->expr->As<ast::IdentifierExpression>()) {
+ auto* var = ResolvedSymbol<sem::Variable>(ident_unary);
+ if (!var) {
+ TINT_ICE(Resolver, diagnostics_)
+ << "failed to resolve identifier";
+ return false;
+ }
+ if (var->Declaration()->is_const) {
+ TINT_ICE(Resolver, diagnostics_)
+ << "Resolver::FunctionCall() encountered an address-of "
+ "expression of a constant identifier expression";
+ return false;
+ }
+ is_valid = true;
+ }
+ }
+ }
+
+ if (!is_valid &&
+ IsValidationEnabled(
+ param->Declaration()->attributes,
+ ast::DisabledValidation::kIgnoreInvalidPointerArgument)) {
+ AddError(
+ "expected an address-of expression of a variable identifier "
+ "expression or a function parameter",
+ arg_expr->source);
+ return false;
+ }
+ }
+ }
+
+ if (call->Type()->Is<sem::Void>()) {
+ bool is_call_statement = false;
+ if (auto* call_stmt = As<ast::CallStatement>(call->Stmt()->Declaration())) {
+ if (call_stmt->expr == call->Declaration()) {
+ is_call_statement = true;
+ }
+ }
+ if (!is_call_statement) {
+ // https://gpuweb.github.io/gpuweb/wgsl/#function-call-expr
+ // If the called function does not return a value, a function call
+ // statement should be used instead.
+ AddError("function '" + name + "' does not return a value", decl->source);
+ return false;
+ }
+ }
+
+ if (call->Behaviors().Contains(sem::Behavior::kDiscard)) {
+ if (auto* continuing = ClosestContinuing(/*stop_at_loop*/ false)) {
+ AddError(
+ "cannot call a function that may discard inside a continuing block",
+ call->Declaration()->source);
+ if (continuing != call->Stmt()->Declaration() &&
+ continuing != call->Stmt()->Parent()->Declaration()) {
+ AddNote("see continuing block here", continuing->source);
+ }
+ return false;
+ }
+ }
+
+ return true;
+}
+
+bool Resolver::ValidateStructureConstructorOrCast(
+ const ast::CallExpression* ctor,
+ const sem::Struct* struct_type) {
+ if (!struct_type->IsConstructible()) {
+ AddError("struct constructor has non-constructible type", ctor->source);
+ return false;
+ }
+
+ if (ctor->args.size() > 0) {
+ if (ctor->args.size() != struct_type->Members().size()) {
+ std::string fm =
+ ctor->args.size() < struct_type->Members().size() ? "few" : "many";
+ AddError("struct constructor has too " + fm + " inputs: expected " +
+ std::to_string(struct_type->Members().size()) + ", found " +
+ std::to_string(ctor->args.size()),
+ ctor->source);
+ return false;
+ }
+ for (auto* member : struct_type->Members()) {
+ auto* value = ctor->args[member->Index()];
+ auto* value_ty = TypeOf(value);
+ if (member->Type() != value_ty->UnwrapRef()) {
+ AddError(
+ "type in struct constructor does not match struct member type: "
+ "expected '" +
+ TypeNameOf(member->Type()) + "', found '" +
+ TypeNameOf(value_ty) + "'",
+ value->source);
+ return false;
+ }
+ }
+ }
+ return true;
+}
+
+bool Resolver::ValidateArrayConstructorOrCast(const ast::CallExpression* ctor,
+ const sem::Array* array_type) {
+ auto& values = ctor->args;
+ auto* elem_ty = array_type->ElemType();
+ for (auto* value : values) {
+ auto* value_ty = TypeOf(value)->UnwrapRef();
+ if (value_ty != elem_ty) {
+ AddError(
+ "type in array constructor does not match array type: "
+ "expected '" +
+ TypeNameOf(elem_ty) + "', found '" + TypeNameOf(value_ty) + "'",
+ value->source);
+ return false;
+ }
+ }
+
+ if (array_type->IsRuntimeSized()) {
+ AddError("cannot init a runtime-sized array", ctor->source);
+ return false;
+ } else if (!elem_ty->IsConstructible()) {
+ AddError("array constructor has non-constructible element type",
+ ctor->source);
+ return false;
+ } else if (!values.empty() && (values.size() != array_type->Count())) {
+ std::string fm = values.size() < array_type->Count() ? "few" : "many";
+ AddError("array constructor has too " + fm + " elements: expected " +
+ std::to_string(array_type->Count()) + ", found " +
+ std::to_string(values.size()),
+ ctor->source);
+ return false;
+ } else if (values.size() > array_type->Count()) {
+ AddError("array constructor has too many elements: expected " +
+ std::to_string(array_type->Count()) + ", found " +
+ std::to_string(values.size()),
+ ctor->source);
+ return false;
+ }
+ return true;
+}
+
+bool Resolver::ValidateVectorConstructorOrCast(const ast::CallExpression* ctor,
+ const sem::Vector* vec_type) {
+ auto& values = ctor->args;
+ auto* elem_ty = vec_type->type();
+ size_t value_cardinality_sum = 0;
+ for (auto* value : values) {
+ auto* value_ty = TypeOf(value)->UnwrapRef();
+ if (value_ty->is_scalar()) {
+ if (elem_ty != value_ty) {
+ AddError(
+ "type in vector constructor does not match vector type: "
+ "expected '" +
+ TypeNameOf(elem_ty) + "', found '" + TypeNameOf(value_ty) + "'",
+ value->source);
+ return false;
+ }
+
+ value_cardinality_sum++;
+ } else if (auto* value_vec = value_ty->As<sem::Vector>()) {
+ auto* value_elem_ty = value_vec->type();
+ // A mismatch of vector type parameter T is only an error if multiple
+ // arguments are present. A single argument constructor constitutes a
+ // type conversion expression.
+ if (elem_ty != value_elem_ty && values.size() > 1u) {
+ AddError(
+ "type in vector constructor does not match vector type: "
+ "expected '" +
+ TypeNameOf(elem_ty) + "', found '" + TypeNameOf(value_elem_ty) +
+ "'",
+ value->source);
+ return false;
+ }
+
+ value_cardinality_sum += value_vec->Width();
+ } else {
+ // A vector constructor can only accept vectors and scalars.
+ AddError("expected vector or scalar type in vector constructor; found: " +
+ TypeNameOf(value_ty),
+ value->source);
+ return false;
+ }
+ }
+
+ // A correct vector constructor must either be a zero-value expression,
+ // a single-value initializer (splat) expression, or the number of components
+ // of all constructor arguments must add up to the vector cardinality.
+ if (value_cardinality_sum > 1 && value_cardinality_sum != vec_type->Width()) {
+ if (values.empty()) {
+ TINT_ICE(Resolver, diagnostics_)
+ << "constructor arguments expected to be non-empty!";
+ }
+ const Source& values_start = values[0]->source;
+ const Source& values_end = values[values.size() - 1]->source;
+ AddError("attempted to construct '" + TypeNameOf(vec_type) + "' with " +
+ std::to_string(value_cardinality_sum) + " component(s)",
+ Source::Combine(values_start, values_end));
+ return false;
+ }
+ return true;
+}
+
+bool Resolver::ValidateVector(const sem::Vector* ty, const Source& source) {
+ if (!ty->type()->is_scalar()) {
+ AddError("vector element type must be 'bool', 'f32', 'i32' or 'u32'",
+ source);
+ return false;
+ }
+ return true;
+}
+
+bool Resolver::ValidateMatrix(const sem::Matrix* ty, const Source& source) {
+ if (!ty->is_float_matrix()) {
+ AddError("matrix element type must be 'f32'", source);
+ return false;
+ }
+ return true;
+}
+
+bool Resolver::ValidateMatrixConstructorOrCast(const ast::CallExpression* ctor,
+ const sem::Matrix* matrix_ty) {
+ auto& values = ctor->args;
+ // Zero Value expression
+ if (values.empty()) {
+ return true;
+ }
+
+ if (!ValidateMatrix(matrix_ty, ctor->source)) {
+ return false;
+ }
+
+ std::vector<const sem::Type*> arg_tys;
+ arg_tys.reserve(values.size());
+ for (auto* value : values) {
+ arg_tys.emplace_back(TypeOf(value)->UnwrapRef());
+ }
+
+ auto* elem_type = matrix_ty->type();
+ auto num_elements = matrix_ty->columns() * matrix_ty->rows();
+
+ // Print a generic error for an invalid matrix constructor, showing the
+ // available overloads.
+ auto print_error = [&]() {
+ const Source& values_start = values[0]->source;
+ const Source& values_end = values[values.size() - 1]->source;
+ auto type_name = TypeNameOf(matrix_ty);
+ auto elem_type_name = TypeNameOf(elem_type);
+ std::stringstream ss;
+ ss << "no matching constructor " + type_name << "(";
+ for (size_t i = 0; i < values.size(); i++) {
+ if (i > 0) {
+ ss << ", ";
+ }
+ ss << arg_tys[i]->FriendlyName(builder_->Symbols());
+ }
+ ss << ")" << std::endl << std::endl;
+ ss << "3 candidates available:" << std::endl;
+ ss << " " << type_name << "()" << std::endl;
+ ss << " " << type_name << "(" << elem_type_name << ",...,"
+ << elem_type_name << ")"
+ << " // " << std::to_string(num_elements) << " arguments" << std::endl;
+ ss << " " << type_name << "(";
+ for (uint32_t c = 0; c < matrix_ty->columns(); c++) {
+ if (c > 0) {
+ ss << ", ";
+ }
+ ss << VectorPretty(matrix_ty->rows(), elem_type);
+ }
+ ss << ")" << std::endl;
+ AddError(ss.str(), Source::Combine(values_start, values_end));
+ };
+
+ const sem::Type* expected_arg_type = nullptr;
+ if (num_elements == values.size()) {
+ // Column-major construction from scalar elements.
+ expected_arg_type = matrix_ty->type();
+ } else if (matrix_ty->columns() == values.size()) {
+ // Column-by-column construction from vectors.
+ expected_arg_type = matrix_ty->ColumnType();
+ } else {
+ print_error();
+ return false;
+ }
+
+ for (auto* arg_ty : arg_tys) {
+ if (arg_ty != expected_arg_type) {
+ print_error();
+ return false;
+ }
+ }
+
+ return true;
+}
+
+bool Resolver::ValidateScalarConstructorOrCast(const ast::CallExpression* ctor,
+ const sem::Type* ty) {
+ if (ctor->args.size() == 0) {
+ return true;
+ }
+ if (ctor->args.size() > 1) {
+ AddError("expected zero or one value in constructor, got " +
+ std::to_string(ctor->args.size()),
+ ctor->source);
+ return false;
+ }
+
+ // Validate constructor
+ auto* value = ctor->args[0];
+ auto* value_ty = TypeOf(value)->UnwrapRef();
+
+ using Bool = sem::Bool;
+ using I32 = sem::I32;
+ using U32 = sem::U32;
+ using F32 = sem::F32;
+
+ const bool is_valid = (ty->Is<Bool>() && value_ty->is_scalar()) ||
+ (ty->Is<I32>() && value_ty->is_scalar()) ||
+ (ty->Is<U32>() && value_ty->is_scalar()) ||
+ (ty->Is<F32>() && value_ty->is_scalar());
+ if (!is_valid) {
+ AddError("cannot construct '" + TypeNameOf(ty) +
+ "' with a value of type '" + TypeNameOf(value_ty) + "'",
+ ctor->source);
+
+ return false;
+ }
+
+ return true;
+}
+
+bool Resolver::ValidatePipelineStages() {
+ auto check_workgroup_storage = [&](const sem::Function* func,
+ const sem::Function* entry_point) {
+ auto stage = entry_point->Declaration()->PipelineStage();
+ if (stage != ast::PipelineStage::kCompute) {
+ for (auto* var : func->DirectlyReferencedGlobals()) {
+ if (var->StorageClass() == ast::StorageClass::kWorkgroup) {
+ std::stringstream stage_name;
+ stage_name << stage;
+ for (auto* user : var->Users()) {
+ if (func == user->Stmt()->Function()) {
+ AddError("workgroup memory cannot be used by " +
+ stage_name.str() + " pipeline stage",
+ user->Declaration()->source);
+ break;
+ }
+ }
+ AddNote("variable is declared here", var->Declaration()->source);
+ if (func != entry_point) {
+ TraverseCallChain(diagnostics_, entry_point, func,
+ [&](const sem::Function* f) {
+ AddNote("called by function '" +
+ builder_->Symbols().NameFor(
+ f->Declaration()->symbol) +
+ "'",
+ f->Declaration()->source);
+ });
+ AddNote("called by entry point '" +
+ builder_->Symbols().NameFor(
+ entry_point->Declaration()->symbol) +
+ "'",
+ entry_point->Declaration()->source);
+ }
+ return false;
+ }
+ }
+ }
+ return true;
+ };
+
+ for (auto* entry_point : entry_points_) {
+ if (!check_workgroup_storage(entry_point, entry_point)) {
+ return false;
+ }
+ for (auto* func : entry_point->TransitivelyCalledFunctions()) {
+ if (!check_workgroup_storage(func, entry_point)) {
+ return false;
+ }
+ }
+ }
+
+ auto check_builtin_calls = [&](const sem::Function* func,
+ const sem::Function* entry_point) {
+ auto stage = entry_point->Declaration()->PipelineStage();
+ for (auto* builtin : func->DirectlyCalledBuiltins()) {
+ if (!builtin->SupportedStages().Contains(stage)) {
+ auto* call = func->FindDirectCallTo(builtin);
+ std::stringstream err;
+ err << "built-in cannot be used by " << stage << " pipeline stage";
+ AddError(err.str(), call ? call->Declaration()->source
+ : func->Declaration()->source);
+ if (func != entry_point) {
+ TraverseCallChain(
+ diagnostics_, entry_point, func, [&](const sem::Function* f) {
+ AddNote(
+ "called by function '" +
+ builder_->Symbols().NameFor(f->Declaration()->symbol) +
+ "'",
+ f->Declaration()->source);
+ });
+ AddNote("called by entry point '" +
+ builder_->Symbols().NameFor(
+ entry_point->Declaration()->symbol) +
+ "'",
+ entry_point->Declaration()->source);
+ }
+ return false;
+ }
+ }
+ return true;
+ };
+
+ for (auto* entry_point : entry_points_) {
+ if (!check_builtin_calls(entry_point, entry_point)) {
+ return false;
+ }
+ for (auto* func : entry_point->TransitivelyCalledFunctions()) {
+ if (!check_builtin_calls(func, entry_point)) {
+ return false;
+ }
+ }
+ }
+ return true;
+}
+
+bool Resolver::ValidateArray(const sem::Array* arr, const Source& source) {
+ auto* el_ty = arr->ElemType();
+
+ if (!IsFixedFootprint(el_ty)) {
+ AddError("an array element type cannot contain a runtime-sized array",
+ source);
+ return false;
+ }
+ return true;
+}
+
+bool Resolver::ValidateArrayStrideAttribute(const ast::StrideAttribute* attr,
+ uint32_t el_size,
+ uint32_t el_align,
+ const Source& source) {
+ auto stride = attr->stride;
+ bool is_valid_stride =
+ (stride >= el_size) && (stride >= el_align) && (stride % el_align == 0);
+ if (!is_valid_stride) {
+ // https://gpuweb.github.io/gpuweb/wgsl/#array-layout-rules
+ // Arrays decorated with the stride attribute must have a stride that is
+ // at least the size of the element type, and be a multiple of the
+ // element type's alignment value.
+ AddError(
+ "arrays decorated with the stride attribute must have a stride "
+ "that is at least the size of the element type, and be a multiple "
+ "of the element type's alignment value.",
+ source);
+ return false;
+ }
+ return true;
+}
+
+bool Resolver::ValidateAlias(const ast::Alias* alias) {
+ auto name = builder_->Symbols().NameFor(alias->name);
+ if (sem::ParseBuiltinType(name) != sem::BuiltinType::kNone) {
+ AddError("'" + name + "' is a builtin and cannot be redeclared as an alias",
+ alias->source);
+ return false;
+ }
+
+ return true;
+}
+
+bool Resolver::ValidateStructure(const sem::Struct* str) {
+ auto name = builder_->Symbols().NameFor(str->Declaration()->name);
+ if (sem::ParseBuiltinType(name) != sem::BuiltinType::kNone) {
+ AddError("'" + name + "' is a builtin and cannot be redeclared as a struct",
+ str->Declaration()->source);
+ return false;
+ }
+
+ if (str->Members().empty()) {
+ AddError("structures must have at least one member",
+ str->Declaration()->source);
+ return false;
+ }
+
+ std::unordered_set<uint32_t> locations;
+ for (auto* member : str->Members()) {
+ if (auto* r = member->Type()->As<sem::Array>()) {
+ if (r->IsRuntimeSized()) {
+ if (member != str->Members().back()) {
+ AddError(
+ "runtime arrays may only appear as the last member of a struct",
+ member->Declaration()->source);
+ return false;
+ }
+ }
+ } else if (!IsFixedFootprint(member->Type())) {
+ AddError(
+ "a struct that contains a runtime array cannot be nested inside "
+ "another struct",
+ member->Declaration()->source);
+ return false;
+ }
+
+ auto has_location = false;
+ auto has_position = false;
+ const ast::InvariantAttribute* invariant_attribute = nullptr;
+ const ast::InterpolateAttribute* interpolate_attribute = nullptr;
+ for (auto* attr : member->Declaration()->attributes) {
+ if (!attr->IsAnyOf<ast::BuiltinAttribute, //
+ ast::InternalAttribute, //
+ ast::InterpolateAttribute, //
+ ast::InvariantAttribute, //
+ ast::LocationAttribute, //
+ ast::StructMemberOffsetAttribute, //
+ ast::StructMemberSizeAttribute, //
+ ast::StructMemberAlignAttribute>()) {
+ if (attr->Is<ast::StrideAttribute>() &&
+ IsValidationDisabled(
+ member->Declaration()->attributes,
+ ast::DisabledValidation::kIgnoreStrideAttribute)) {
+ continue;
+ }
+ AddError("attribute is not valid for structure members", attr->source);
+ return false;
+ }
+
+ if (auto* invariant = attr->As<ast::InvariantAttribute>()) {
+ invariant_attribute = invariant;
+ } else if (auto* location = attr->As<ast::LocationAttribute>()) {
+ has_location = true;
+ if (!ValidateLocationAttribute(location, member->Type(), locations,
+ member->Declaration()->source)) {
+ return false;
+ }
+ } else if (auto* builtin = attr->As<ast::BuiltinAttribute>()) {
+ if (!ValidateBuiltinAttribute(builtin, member->Type(),
+ /* is_input */ false)) {
+ return false;
+ }
+ if (builtin->builtin == ast::Builtin::kPosition) {
+ has_position = true;
+ }
+ } else if (auto* interpolate = attr->As<ast::InterpolateAttribute>()) {
+ interpolate_attribute = interpolate;
+ if (!ValidateInterpolateAttribute(interpolate, member->Type())) {
+ return false;
+ }
+ }
+ }
+
+ if (invariant_attribute && !has_position) {
+ AddError("invariant attribute must only be applied to a position builtin",
+ invariant_attribute->source);
+ return false;
+ }
+
+ if (interpolate_attribute && !has_location) {
+ AddError("interpolate attribute must only be used with @location",
+ interpolate_attribute->source);
+ return false;
+ }
+ }
+
+ for (auto* attr : str->Declaration()->attributes) {
+ if (!(attr->IsAnyOf<ast::StructBlockAttribute, ast::InternalAttribute>())) {
+ AddError("attribute is not valid for struct declarations", attr->source);
+ return false;
+ }
+ }
+
+ return true;
+}
+
+bool Resolver::ValidateLocationAttribute(
+ const ast::LocationAttribute* location,
+ const sem::Type* type,
+ std::unordered_set<uint32_t>& locations,
+ const Source& source,
+ const bool is_input) {
+ std::string inputs_or_output = is_input ? "inputs" : "output";
+ if (current_function_ && current_function_->Declaration()->PipelineStage() ==
+ ast::PipelineStage::kCompute) {
+ AddError("attribute is not valid for compute shader " + inputs_or_output,
+ location->source);
+ return false;
+ }
+
+ if (!type->is_numeric_scalar_or_vector()) {
+ std::string invalid_type = TypeNameOf(type);
+ AddError("cannot apply 'location' attribute to declaration of type '" +
+ invalid_type + "'",
+ source);
+ AddNote(
+ "'location' attribute must only be applied to declarations of "
+ "numeric scalar or numeric vector type",
+ location->source);
+ return false;
+ }
+
+ if (locations.count(location->value)) {
+ AddError(attr_to_str(location) + " attribute appears multiple times",
+ location->source);
+ return false;
+ }
+ locations.emplace(location->value);
+
+ return true;
+}
+
+bool Resolver::ValidateReturn(const ast::ReturnStatement* ret) {
+ auto* func_type = current_function_->ReturnType();
+
+ auto* ret_type = ret->value ? TypeOf(ret->value)->UnwrapRef()
+ : builder_->create<sem::Void>();
+
+ if (func_type->UnwrapRef() != ret_type) {
+ AddError(
+ "return statement type must match its function "
+ "return type, returned '" +
+ TypeNameOf(ret_type) + "', expected '" + TypeNameOf(func_type) +
+ "'",
+ ret->source);
+ return false;
+ }
+
+ auto* sem = Sem(ret);
+ if (auto* continuing = ClosestContinuing(/*stop_at_loop*/ false)) {
+ AddError("continuing blocks must not contain a return statement",
+ ret->source);
+ if (continuing != sem->Declaration() &&
+ continuing != sem->Parent()->Declaration()) {
+ AddNote("see continuing block here", continuing->source);
+ }
+ return false;
+ }
+
+ return true;
+}
+
+bool Resolver::ValidateSwitch(const ast::SwitchStatement* s) {
+ auto* cond_ty = TypeOf(s->condition)->UnwrapRef();
+ if (!cond_ty->is_integer_scalar()) {
+ AddError(
+ "switch statement selector expression must be of a "
+ "scalar integer type",
+ s->condition->source);
+ return false;
+ }
+
+ bool has_default = false;
+ std::unordered_map<uint32_t, Source> selectors;
+
+ for (auto* case_stmt : s->body) {
+ if (case_stmt->IsDefault()) {
+ if (has_default) {
+ // More than one default clause
+ AddError("switch statement must have exactly one default clause",
+ case_stmt->source);
+ return false;
+ }
+ has_default = true;
+ }
+
+ for (auto* selector : case_stmt->selectors) {
+ if (cond_ty != TypeOf(selector)) {
+ AddError(
+ "the case selector values must have the same "
+ "type as the selector expression.",
+ case_stmt->source);
+ return false;
+ }
+
+ auto v = selector->ValueAsU32();
+ auto it = selectors.find(v);
+ if (it != selectors.end()) {
+ auto val = selector->Is<ast::IntLiteralExpression>()
+ ? std::to_string(selector->ValueAsI32())
+ : std::to_string(selector->ValueAsU32());
+ AddError("duplicate switch case '" + val + "'", selector->source);
+ AddNote("previous case declared here", it->second);
+ return false;
+ }
+ selectors.emplace(v, selector->source);
+ }
+ }
+
+ if (!has_default) {
+ // No default clause
+ AddError("switch statement must have a default clause", s->source);
+ return false;
+ }
+
+ return true;
+}
+
+bool Resolver::ValidateAssignment(const ast::AssignmentStatement* a) {
+ auto const* rhs_ty = TypeOf(a->rhs);
+
+ if (a->lhs->Is<ast::PhonyExpression>()) {
+ // https://www.w3.org/TR/WGSL/#phony-assignment-section
+ auto* ty = rhs_ty->UnwrapRef();
+ if (!ty->IsConstructible() &&
+ !ty->IsAnyOf<sem::Pointer, sem::Texture, sem::Sampler>()) {
+ AddError(
+ "cannot assign '" + TypeNameOf(rhs_ty) +
+ "' to '_'. '_' can only be assigned a constructible, pointer, "
+ "texture or sampler type",
+ a->rhs->source);
+ return false;
+ }
+ return true; // RHS can be anything.
+ }
+
+ // https://gpuweb.github.io/gpuweb/wgsl/#assignment-statement
+ auto const* lhs_ty = TypeOf(a->lhs);
+
+ if (auto* var = ResolvedSymbol<sem::Variable>(a->lhs)) {
+ auto* decl = var->Declaration();
+ if (var->Is<sem::Parameter>()) {
+ AddError("cannot assign to function parameter", a->lhs->source);
+ AddNote("'" + builder_->Symbols().NameFor(decl->symbol) +
+ "' is declared here:",
+ decl->source);
+ return false;
+ }
+ if (decl->is_const) {
+ AddError("cannot assign to const", a->lhs->source);
+ AddNote("'" + builder_->Symbols().NameFor(decl->symbol) +
+ "' is declared here:",
+ decl->source);
+ return false;
+ }
+ }
+
+ auto* lhs_ref = lhs_ty->As<sem::Reference>();
+ if (!lhs_ref) {
+ // LHS is not a reference, so it has no storage.
+ AddError("cannot assign to value of type '" + TypeNameOf(lhs_ty) + "'",
+ a->lhs->source);
+ return false;
+ }
+
+ auto* storage_ty = lhs_ref->StoreType();
+ auto* value_type = rhs_ty->UnwrapRef(); // Implicit load of RHS
+
+ // Value type has to match storage type
+ if (storage_ty != value_type) {
+ AddError("cannot assign '" + TypeNameOf(rhs_ty) + "' to '" +
+ TypeNameOf(lhs_ty) + "'",
+ a->source);
+ return false;
+ }
+ if (!storage_ty->IsConstructible()) {
+ AddError("storage type of assignment must be constructible", a->source);
+ return false;
+ }
+ if (lhs_ref->Access() == ast::Access::kRead) {
+ AddError(
+ "cannot store into a read-only type '" + RawTypeNameOf(lhs_ty) + "'",
+ a->source);
+ return false;
+ }
+ return true;
+}
+
+bool Resolver::ValidateNoDuplicateAttributes(
+ const ast::AttributeList& attributes) {
+ std::unordered_map<const TypeInfo*, Source> seen;
+ for (auto* d : attributes) {
+ auto res = seen.emplace(&d->TypeInfo(), d->source);
+ if (!res.second && !d->Is<ast::InternalAttribute>()) {
+ AddError("duplicate " + d->Name() + " attribute", d->source);
+ AddNote("first attribute declared here", res.first->second);
+ return false;
+ }
+ }
+ return true;
+}
+
+bool Resolver::IsValidationDisabled(const ast::AttributeList& attributes,
+ ast::DisabledValidation validation) const {
+ for (auto* attribute : attributes) {
+ if (auto* dv = attribute->As<ast::DisableValidationAttribute>()) {
+ if (dv->validation == validation) {
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
+bool Resolver::IsValidationEnabled(const ast::AttributeList& attributes,
+ ast::DisabledValidation validation) const {
+ return !IsValidationDisabled(attributes, validation);
+}
+
+} // namespace resolver
+} // namespace tint
diff --git a/src/tint/resolver/side_effects_test.cc b/src/tint/resolver/side_effects_test.cc
new file mode 100644
index 0000000..944ff5d
--- /dev/null
+++ b/src/tint/resolver/side_effects_test.cc
@@ -0,0 +1,371 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/resolver/resolver.h"
+
+#include "gtest/gtest.h"
+#include "src/tint/resolver/resolver_test_helper.h"
+#include "src/tint/sem/expression.h"
+#include "src/tint/sem/member_accessor_expression.h"
+
+namespace tint::resolver {
+namespace {
+
+struct SideEffectsTest : ResolverTest {
+ template <typename T>
+ void MakeSideEffectFunc(const char* name) {
+ auto global = Sym();
+ Global(global, ty.Of<T>(), ast::StorageClass::kPrivate);
+ auto local = Sym();
+ Func(name, {}, ty.Of<T>(),
+ {
+ Decl(Var(local, ty.Of<T>())),
+ Assign(global, local),
+ Return(global),
+ });
+ }
+
+ template <typename MAKE_TYPE_FUNC>
+ void MakeSideEffectFunc(const char* name, MAKE_TYPE_FUNC make_type) {
+ auto global = Sym();
+ Global(global, make_type(), ast::StorageClass::kPrivate);
+ auto local = Sym();
+ Func(name, {}, make_type(),
+ {
+ Decl(Var(local, make_type())),
+ Assign(global, local),
+ Return(global),
+ });
+ }
+};
+
+TEST_F(SideEffectsTest, Phony) {
+ auto* expr = Phony();
+ auto* body = Assign(expr, 1);
+ WrapInFunction(body);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ EXPECT_FALSE(sem->HasSideEffects());
+}
+
+TEST_F(SideEffectsTest, Literal) {
+ auto* expr = Expr(1);
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ EXPECT_FALSE(sem->HasSideEffects());
+}
+
+TEST_F(SideEffectsTest, VariableUser) {
+ auto* var = Decl(Var("a", ty.i32()));
+ auto* expr = Expr("a");
+ WrapInFunction(var, expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ EXPECT_TRUE(sem->Is<sem::VariableUser>());
+ EXPECT_FALSE(sem->HasSideEffects());
+}
+
+TEST_F(SideEffectsTest, Call_Builtin_NoSE) {
+ Global("a", ty.f32(), ast::StorageClass::kPrivate);
+ auto* expr = Call("dpdx", "a");
+ Func("f", {}, ty.void_(), {Ignore(expr)},
+ {create<ast::StageAttribute>(ast::PipelineStage::kFragment)});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ EXPECT_TRUE(sem->Is<sem::Call>());
+ EXPECT_FALSE(sem->HasSideEffects());
+}
+
+TEST_F(SideEffectsTest, Call_Builtin_NoSE_WithSEArg) {
+ MakeSideEffectFunc<f32>("se");
+ auto* expr = Call("dpdx", Call("se"));
+ Func("f", {}, ty.void_(), {Ignore(expr)},
+ {create<ast::StageAttribute>(ast::PipelineStage::kFragment)});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ EXPECT_TRUE(sem->Is<sem::Call>());
+ EXPECT_TRUE(sem->HasSideEffects());
+}
+
+TEST_F(SideEffectsTest, Call_Builtin_SE) {
+ Global("a", ty.atomic(ty.i32()), ast::StorageClass::kWorkgroup);
+ auto* expr = Call("atomicAdd", AddressOf("a"), 1);
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ EXPECT_TRUE(sem->Is<sem::Call>());
+ EXPECT_TRUE(sem->HasSideEffects());
+}
+
+TEST_F(SideEffectsTest, Call_Function) {
+ Func("f", {}, ty.i32(), {Return(1)});
+ auto* expr = Call("f");
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ EXPECT_TRUE(sem->Is<sem::Call>());
+ EXPECT_TRUE(sem->HasSideEffects());
+}
+
+TEST_F(SideEffectsTest, Call_TypeConversion_NoSE) {
+ auto* var = Decl(Var("a", ty.i32()));
+ auto* expr = Construct(ty.f32(), "a");
+ WrapInFunction(var, expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ EXPECT_TRUE(sem->Is<sem::Call>());
+ EXPECT_FALSE(sem->HasSideEffects());
+}
+
+TEST_F(SideEffectsTest, Call_TypeConversion_SE) {
+ MakeSideEffectFunc<i32>("se");
+ auto* expr = Construct(ty.f32(), Call("se"));
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ EXPECT_TRUE(sem->Is<sem::Call>());
+ EXPECT_TRUE(sem->HasSideEffects());
+}
+
+TEST_F(SideEffectsTest, Call_TypeConstructor_NoSE) {
+ auto* var = Decl(Var("a", ty.f32()));
+ auto* expr = Construct(ty.f32(), "a");
+ WrapInFunction(var, expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ EXPECT_TRUE(sem->Is<sem::Call>());
+ EXPECT_FALSE(sem->HasSideEffects());
+}
+
+TEST_F(SideEffectsTest, Call_TypeConstructor_SE) {
+ MakeSideEffectFunc<f32>("se");
+ auto* expr = Construct(ty.f32(), Call("se"));
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ EXPECT_TRUE(sem->Is<sem::Call>());
+ EXPECT_TRUE(sem->HasSideEffects());
+}
+
+TEST_F(SideEffectsTest, MemberAccessor_Struct_NoSE) {
+ auto* s = Structure("S", {Member("m", ty.i32())});
+ auto* var = Decl(Var("a", ty.Of(s)));
+ auto* expr = MemberAccessor("a", "m");
+ WrapInFunction(var, expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ EXPECT_FALSE(sem->HasSideEffects());
+}
+
+TEST_F(SideEffectsTest, MemberAccessor_Struct_SE) {
+ auto* s = Structure("S", {Member("m", ty.i32())});
+ MakeSideEffectFunc("se", [&] { return ty.Of(s); });
+ auto* expr = MemberAccessor(Call("se"), "m");
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ EXPECT_TRUE(sem->HasSideEffects());
+}
+
+TEST_F(SideEffectsTest, MemberAccessor_Vector) {
+ auto* var = Decl(Var("a", ty.vec4<f32>()));
+ auto* expr = MemberAccessor("a", "x");
+ WrapInFunction(var, expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ auto* sem = Sem().Get(expr);
+ EXPECT_TRUE(sem->Is<sem::MemberAccessorExpression>());
+ ASSERT_NE(sem, nullptr);
+ EXPECT_FALSE(sem->HasSideEffects());
+}
+
+TEST_F(SideEffectsTest, MemberAccessor_VectorSwizzle) {
+ auto* var = Decl(Var("a", ty.vec4<f32>()));
+ auto* expr = MemberAccessor("a", "xzyw");
+ WrapInFunction(var, expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ auto* sem = Sem().Get(expr);
+ EXPECT_TRUE(sem->Is<sem::Swizzle>());
+ ASSERT_NE(sem, nullptr);
+ EXPECT_FALSE(sem->HasSideEffects());
+}
+
+TEST_F(SideEffectsTest, Binary_NoSE) {
+ auto* a = Decl(Var("a", ty.i32()));
+ auto* b = Decl(Var("b", ty.i32()));
+ auto* expr = Add("a", "b");
+ WrapInFunction(a, b, expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ EXPECT_FALSE(sem->HasSideEffects());
+}
+
+TEST_F(SideEffectsTest, Binary_LeftSE) {
+ MakeSideEffectFunc<i32>("se");
+ auto* b = Decl(Var("b", ty.i32()));
+ auto* expr = Add(Call("se"), "b");
+ WrapInFunction(b, expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ EXPECT_TRUE(sem->HasSideEffects());
+}
+
+TEST_F(SideEffectsTest, Binary_RightSE) {
+ MakeSideEffectFunc<i32>("se");
+ auto* a = Decl(Var("a", ty.i32()));
+ auto* expr = Add("a", Call("se"));
+ WrapInFunction(a, expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ EXPECT_TRUE(sem->HasSideEffects());
+}
+
+TEST_F(SideEffectsTest, Binary_BothSE) {
+ MakeSideEffectFunc<i32>("se1");
+ MakeSideEffectFunc<i32>("se2");
+ auto* expr = Add(Call("se1"), Call("se2"));
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ EXPECT_TRUE(sem->HasSideEffects());
+}
+
+TEST_F(SideEffectsTest, Unary_NoSE) {
+ auto* var = Decl(Var("a", ty.bool_()));
+ auto* expr = Not("a");
+ WrapInFunction(var, expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ EXPECT_FALSE(sem->HasSideEffects());
+}
+
+TEST_F(SideEffectsTest, Unary_SE) {
+ MakeSideEffectFunc<bool>("se");
+ auto* expr = Not(Call("se"));
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ EXPECT_TRUE(sem->HasSideEffects());
+}
+
+TEST_F(SideEffectsTest, IndexAccessor_NoSE) {
+ auto* var = Decl(Var("a", ty.array<i32, 10>()));
+ auto* expr = IndexAccessor("a", 0);
+ WrapInFunction(var, expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ EXPECT_FALSE(sem->HasSideEffects());
+}
+
+TEST_F(SideEffectsTest, IndexAccessor_ObjSE) {
+ MakeSideEffectFunc("se", [&] { return ty.array<i32, 10>(); });
+ auto* expr = IndexAccessor(Call("se"), 0);
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ EXPECT_TRUE(sem->HasSideEffects());
+}
+
+TEST_F(SideEffectsTest, IndexAccessor_IndexSE) {
+ MakeSideEffectFunc<i32>("se");
+ auto* var = Decl(Var("a", ty.array<i32, 10>()));
+ auto* expr = IndexAccessor("a", Call("se"));
+ WrapInFunction(var, expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ EXPECT_TRUE(sem->HasSideEffects());
+}
+
+TEST_F(SideEffectsTest, IndexAccessor_BothSE) {
+ MakeSideEffectFunc("se1", [&] { return ty.array<i32, 10>(); });
+ MakeSideEffectFunc<i32>("se2");
+ auto* expr = IndexAccessor(Call("se1"), Call("se2"));
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ EXPECT_TRUE(sem->HasSideEffects());
+}
+
+TEST_F(SideEffectsTest, Bitcast_NoSE) {
+ auto* var = Decl(Var("a", ty.i32()));
+ auto* expr = Bitcast<f32>("a");
+ WrapInFunction(var, expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ EXPECT_FALSE(sem->HasSideEffects());
+}
+
+TEST_F(SideEffectsTest, Bitcast_SE) {
+ MakeSideEffectFunc<i32>("se");
+ auto* expr = Bitcast<f32>(Call("se"));
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ EXPECT_TRUE(sem->HasSideEffects());
+}
+
+} // namespace
+} // namespace tint::resolver
diff --git a/src/tint/resolver/storage_class_layout_validation_test.cc b/src/tint/resolver/storage_class_layout_validation_test.cc
new file mode 100644
index 0000000..406fdb8
--- /dev/null
+++ b/src/tint/resolver/storage_class_layout_validation_test.cc
@@ -0,0 +1,573 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/resolver/resolver.h"
+
+#include "gmock/gmock.h"
+#include "src/tint/resolver/resolver_test_helper.h"
+
+namespace tint {
+namespace resolver {
+namespace {
+
+using ResolverStorageClassLayoutValidationTest = ResolverTest;
+
+// Detect unaligned member for storage buffers
+TEST_F(ResolverStorageClassLayoutValidationTest,
+ StorageBuffer_UnalignedMember) {
+ // [[block]]
+ // struct S {
+ // @size(5) a : f32;
+ // @align(1) b : f32;
+ // };
+ // @group(0) @binding(0)
+ // var<storage> a : S;
+
+ Structure(Source{{12, 34}}, "S",
+ {Member("a", ty.f32(), {MemberSize(5)}),
+ Member(Source{{34, 56}}, "b", ty.f32(), {MemberAlign(1)})},
+ {StructBlock()});
+
+ Global(Source{{78, 90}}, "a", ty.type_name("S"), ast::StorageClass::kStorage,
+ GroupAndBinding(0, 0));
+
+ ASSERT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(34:56 error: the offset of a struct member of type 'f32' in storage class 'storage' must be a multiple of 4 bytes, but 'b' is currently at offset 5. Consider setting @align(4) on this member
+12:34 note: see layout of struct:
+/* align(4) size(12) */ struct S {
+/* offset(0) align(4) size( 5) */ a : f32;
+/* offset(5) align(1) size( 4) */ b : f32;
+/* offset(9) align(1) size( 3) */ // -- implicit struct size padding --;
+/* */ };
+78:90 note: see declaration of variable)");
+}
+
+TEST_F(ResolverStorageClassLayoutValidationTest,
+ StorageBuffer_UnalignedMember_SuggestedFix) {
+ // [[block]]
+ // struct S {
+ // @size(5) a : f32;
+ // @align(4) b : f32;
+ // };
+ // @group(0) @binding(0)
+ // var<storage> a : S;
+
+ Structure(Source{{12, 34}}, "S",
+ {Member("a", ty.f32(), {MemberSize(5)}),
+ Member(Source{{34, 56}}, "b", ty.f32(), {MemberAlign(4)})},
+ {StructBlock()});
+
+ Global(Source{{78, 90}}, "a", ty.type_name("S"), ast::StorageClass::kStorage,
+ GroupAndBinding(0, 0));
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+// Detect unaligned struct member for uniform buffers
+TEST_F(ResolverStorageClassLayoutValidationTest,
+ UniformBuffer_UnalignedMember_Struct) {
+ // struct Inner {
+ // scalar : i32;
+ // };
+ //
+ // [[block]]
+ // struct Outer {
+ // scalar : f32;
+ // inner : Inner;
+ // };
+ //
+ // @group(0) @binding(0)
+ // var<uniform> a : Outer;
+
+ Structure(Source{{12, 34}}, "Inner", {Member("scalar", ty.i32())});
+
+ Structure(Source{{34, 56}}, "Outer",
+ {
+ Member("scalar", ty.f32()),
+ Member(Source{{56, 78}}, "inner", ty.type_name("Inner")),
+ },
+ {StructBlock()});
+
+ Global(Source{{78, 90}}, "a", ty.type_name("Outer"),
+ ast::StorageClass::kUniform, GroupAndBinding(0, 0));
+
+ ASSERT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(56:78 error: the offset of a struct member of type 'Inner' in storage class 'uniform' must be a multiple of 16 bytes, but 'inner' is currently at offset 4. Consider setting @align(16) on this member
+34:56 note: see layout of struct:
+/* align(4) size(8) */ struct Outer {
+/* offset(0) align(4) size(4) */ scalar : f32;
+/* offset(4) align(4) size(4) */ inner : Inner;
+/* */ };
+12:34 note: and layout of struct member:
+/* align(4) size(4) */ struct Inner {
+/* offset(0) align(4) size(4) */ scalar : i32;
+/* */ };
+78:90 note: see declaration of variable)");
+}
+
+TEST_F(ResolverStorageClassLayoutValidationTest,
+ UniformBuffer_UnalignedMember_Struct_SuggestedFix) {
+ // struct Inner {
+ // scalar : i32;
+ // };
+ //
+ // [[block]]
+ // struct Outer {
+ // scalar : f32;
+ // @align(16) inner : Inner;
+ // };
+ //
+ // @group(0) @binding(0)
+ // var<uniform> a : Outer;
+
+ Structure(Source{{12, 34}}, "Inner", {Member("scalar", ty.i32())});
+
+ Structure(Source{{34, 56}}, "Outer",
+ {
+ Member("scalar", ty.f32()),
+ Member(Source{{56, 78}}, "inner", ty.type_name("Inner"),
+ {MemberAlign(16)}),
+ },
+ {StructBlock()});
+
+ Global(Source{{78, 90}}, "a", ty.type_name("Outer"),
+ ast::StorageClass::kUniform, GroupAndBinding(0, 0));
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+// Detect unaligned array member for uniform buffers
+TEST_F(ResolverStorageClassLayoutValidationTest,
+ UniformBuffer_UnalignedMember_Array) {
+ // type Inner = @stride(16) array<f32, 10>;
+ //
+ // [[block]]
+ // struct Outer {
+ // scalar : f32;
+ // inner : Inner;
+ // };
+ //
+ // @group(0) @binding(0)
+ // var<uniform> a : Outer;
+ Alias("Inner", ty.array(ty.f32(), 10, 16));
+
+ Structure(Source{{12, 34}}, "Outer",
+ {
+ Member("scalar", ty.f32()),
+ Member(Source{{56, 78}}, "inner", ty.type_name("Inner")),
+ },
+ {StructBlock()});
+
+ Global(Source{{78, 90}}, "a", ty.type_name("Outer"),
+ ast::StorageClass::kUniform, GroupAndBinding(0, 0));
+
+ ASSERT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(56:78 error: the offset of a struct member of type '@stride(16) array<f32, 10>' in storage class 'uniform' must be a multiple of 16 bytes, but 'inner' is currently at offset 4. Consider setting @align(16) on this member
+12:34 note: see layout of struct:
+/* align(4) size(164) */ struct Outer {
+/* offset( 0) align(4) size( 4) */ scalar : f32;
+/* offset( 4) align(4) size(160) */ inner : @stride(16) array<f32, 10>;
+/* */ };
+78:90 note: see declaration of variable)");
+}
+
+TEST_F(ResolverStorageClassLayoutValidationTest,
+ UniformBuffer_UnalignedMember_Array_SuggestedFix) {
+ // type Inner = @stride(16) array<f32, 10>;
+ //
+ // [[block]]
+ // struct Outer {
+ // scalar : f32;
+ // @align(16) inner : Inner;
+ // };
+ //
+ // @group(0) @binding(0)
+ // var<uniform> a : Outer;
+ Alias("Inner", ty.array(ty.f32(), 10, 16));
+
+ Structure(Source{{12, 34}}, "Outer",
+ {
+ Member("scalar", ty.f32()),
+ Member(Source{{34, 56}}, "inner", ty.type_name("Inner"),
+ {MemberAlign(16)}),
+ },
+ {StructBlock()});
+
+ Global(Source{{78, 90}}, "a", ty.type_name("Outer"),
+ ast::StorageClass::kUniform, GroupAndBinding(0, 0));
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+// Detect uniform buffers with byte offset between 2 members that is not a
+// multiple of 16 bytes
+TEST_F(ResolverStorageClassLayoutValidationTest,
+ UniformBuffer_MembersOffsetNotMultipleOf16) {
+ // struct Inner {
+ // @align(1) @size(5) scalar : i32;
+ // };
+ //
+ // [[block]]
+ // struct Outer {
+ // inner : Inner;
+ // scalar : i32;
+ // };
+ //
+ // @group(0) @binding(0)
+ // var<uniform> a : Outer;
+
+ Structure(Source{{12, 34}}, "Inner",
+ {Member("scalar", ty.i32(), {MemberAlign(1), MemberSize(5)})});
+
+ Structure(Source{{34, 56}}, "Outer",
+ {
+ Member(Source{{56, 78}}, "inner", ty.type_name("Inner")),
+ Member(Source{{78, 90}}, "scalar", ty.i32()),
+ },
+ {StructBlock()});
+
+ Global(Source{{22, 24}}, "a", ty.type_name("Outer"),
+ ast::StorageClass::kUniform, GroupAndBinding(0, 0));
+
+ ASSERT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(78:90 error: uniform storage requires that the number of bytes between the start of the previous member of type struct and the current member be a multiple of 16 bytes, but there are currently 8 bytes between 'inner' and 'scalar'. Consider setting @align(16) on this member
+34:56 note: see layout of struct:
+/* align(4) size(12) */ struct Outer {
+/* offset( 0) align(1) size( 5) */ inner : Inner;
+/* offset( 5) align(1) size( 3) */ // -- implicit field alignment padding --;
+/* offset( 8) align(4) size( 4) */ scalar : i32;
+/* */ };
+12:34 note: and layout of previous member struct:
+/* align(1) size(5) */ struct Inner {
+/* offset(0) align(1) size(5) */ scalar : i32;
+/* */ };
+22:24 note: see declaration of variable)");
+}
+
+// See https://crbug.com/tint/1344
+TEST_F(ResolverStorageClassLayoutValidationTest,
+ UniformBuffer_MembersOffsetNotMultipleOf16_InnerMoreMembersThanOuter) {
+ // struct Inner {
+ // a : i32;
+ // b : i32;
+ // c : i32;
+ // @align(1) @size(5) scalar : i32;
+ // };
+ //
+ // [[block]]
+ // struct Outer {
+ // inner : Inner;
+ // scalar : i32;
+ // };
+ //
+ // @group(0) @binding(0)
+ // var<uniform> a : Outer;
+
+ Structure(Source{{12, 34}}, "Inner",
+ {
+ Member("a", ty.i32()),
+ Member("b", ty.i32()),
+ Member("c", ty.i32()),
+ Member("scalar", ty.i32(), {MemberAlign(1), MemberSize(5)}),
+ });
+
+ Structure(Source{{34, 56}}, "Outer",
+ {
+ Member(Source{{56, 78}}, "inner", ty.type_name("Inner")),
+ Member(Source{{78, 90}}, "scalar", ty.i32()),
+ },
+ {StructBlock()});
+
+ Global(Source{{22, 24}}, "a", ty.type_name("Outer"),
+ ast::StorageClass::kUniform, GroupAndBinding(0, 0));
+
+ ASSERT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(78:90 error: uniform storage requires that the number of bytes between the start of the previous member of type struct and the current member be a multiple of 16 bytes, but there are currently 20 bytes between 'inner' and 'scalar'. Consider setting @align(16) on this member
+34:56 note: see layout of struct:
+/* align(4) size(24) */ struct Outer {
+/* offset( 0) align(4) size(20) */ inner : Inner;
+/* offset(20) align(4) size( 4) */ scalar : i32;
+/* */ };
+12:34 note: and layout of previous member struct:
+/* align(4) size(20) */ struct Inner {
+/* offset( 0) align(4) size( 4) */ a : i32;
+/* offset( 4) align(4) size( 4) */ b : i32;
+/* offset( 8) align(4) size( 4) */ c : i32;
+/* offset(12) align(1) size( 5) */ scalar : i32;
+/* offset(17) align(1) size( 3) */ // -- implicit struct size padding --;
+/* */ };
+22:24 note: see declaration of variable)");
+}
+
+TEST_F(ResolverStorageClassLayoutValidationTest,
+ UniformBuffer_MembersOffsetNotMultipleOf16_SuggestedFix) {
+ // struct Inner {
+ // @align(1) @size(5) scalar : i32;
+ // };
+ //
+ // [[block]]
+ // struct Outer {
+ // @align(16) inner : Inner;
+ // scalar : i32;
+ // };
+ //
+ // @group(0) @binding(0)
+ // var<uniform> a : Outer;
+
+ Structure(Source{{12, 34}}, "Inner",
+ {Member("scalar", ty.i32(), {MemberAlign(1), MemberSize(5)})});
+
+ Structure(Source{{34, 56}}, "Outer",
+ {
+ Member(Source{{56, 78}}, "inner", ty.type_name("Inner")),
+ Member(Source{{78, 90}}, "scalar", ty.i32(), {MemberAlign(16)}),
+ },
+ {StructBlock()});
+
+ Global(Source{{22, 34}}, "a", ty.type_name("Outer"),
+ ast::StorageClass::kUniform, GroupAndBinding(0, 0));
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+// Make sure that this doesn't fail validation because vec3's align is 16, but
+// size is 12. 's' should be at offset 12, which is okay here.
+TEST_F(ResolverStorageClassLayoutValidationTest,
+ UniformBuffer_Vec3MemberOffset_NoFail) {
+ // [[block]]
+ // struct ScalarPackedAtEndOfVec3 {
+ // v : vec3<f32>;
+ // s : f32;
+ // };
+ // @group(0) @binding(0)
+ // var<uniform> a : ScalarPackedAtEndOfVec3;
+
+ Structure("ScalarPackedAtEndOfVec3",
+ {
+ Member("v", ty.vec3(ty.f32())),
+ Member("s", ty.f32()),
+ },
+ {StructBlock()});
+
+ Global(Source{{78, 90}}, "a", ty.type_name("ScalarPackedAtEndOfVec3"),
+ ast::StorageClass::kUniform, GroupAndBinding(0, 0));
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+// Detect array stride must be a multiple of 16 bytes for uniform buffers
+TEST_F(ResolverStorageClassLayoutValidationTest,
+ UniformBuffer_InvalidArrayStride_Scalar) {
+ // type Inner = array<f32, 10>;
+ //
+ // [[block]]
+ // struct Outer {
+ // inner : Inner;
+ // scalar : i32;
+ // };
+ //
+ // @group(0) @binding(0)
+ // var<uniform> a : Outer;
+
+ Alias("Inner", ty.array(ty.f32(), 10));
+
+ Structure(Source{{12, 34}}, "Outer",
+ {
+ Member("inner", ty.type_name(Source{{34, 56}}, "Inner")),
+ Member("scalar", ty.i32()),
+ },
+ {StructBlock()});
+
+ Global(Source{{78, 90}}, "a", ty.type_name("Outer"),
+ ast::StorageClass::kUniform, GroupAndBinding(0, 0));
+
+ ASSERT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(34:56 error: uniform storage requires that array elements be aligned to 16 bytes, but array element alignment is currently 4. Consider using a vector or struct as the element type instead.
+12:34 note: see layout of struct:
+/* align(4) size(44) */ struct Outer {
+/* offset( 0) align(4) size(40) */ inner : array<f32, 10>;
+/* offset(40) align(4) size( 4) */ scalar : i32;
+/* */ };
+78:90 note: see declaration of variable)");
+}
+
+TEST_F(ResolverStorageClassLayoutValidationTest,
+ UniformBuffer_InvalidArrayStride_Vector) {
+ // type Inner = array<vec2<f32>, 10>;
+ //
+ // [[block]]
+ // struct Outer {
+ // inner : Inner;
+ // scalar : i32;
+ // };
+ //
+ // @group(0) @binding(0)
+ // var<uniform> a : Outer;
+
+ Alias("Inner", ty.array(ty.vec2<f32>(), 10));
+
+ Structure(Source{{12, 34}}, "Outer",
+ {
+ Member("inner", ty.type_name(Source{{34, 56}}, "Inner")),
+ Member("scalar", ty.i32()),
+ },
+ {StructBlock()});
+
+ Global(Source{{78, 90}}, "a", ty.type_name("Outer"),
+ ast::StorageClass::kUniform, GroupAndBinding(0, 0));
+
+ ASSERT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(34:56 error: uniform storage requires that array elements be aligned to 16 bytes, but array element alignment is currently 8. Consider using a vec4 instead.
+12:34 note: see layout of struct:
+/* align(8) size(88) */ struct Outer {
+/* offset( 0) align(8) size(80) */ inner : array<vec2<f32>, 10>;
+/* offset(80) align(4) size( 4) */ scalar : i32;
+/* offset(84) align(1) size( 4) */ // -- implicit struct size padding --;
+/* */ };
+78:90 note: see declaration of variable)");
+}
+
+TEST_F(ResolverStorageClassLayoutValidationTest,
+ UniformBuffer_InvalidArrayStride_Struct) {
+ // struct ArrayElem {
+ // a : f32;
+ // b : i32;
+ // }
+ // type Inner = array<ArrayElem, 10>;
+ //
+ // [[block]]
+ // struct Outer {
+ // inner : Inner;
+ // scalar : i32;
+ // };
+ //
+ // @group(0) @binding(0)
+ // var<uniform> a : Outer;
+
+ auto* array_elem = Structure("ArrayElem", {
+ Member("a", ty.f32()),
+ Member("b", ty.i32()),
+ });
+ Alias("Inner", ty.array(ty.Of(array_elem), 10));
+
+ Structure(Source{{12, 34}}, "Outer",
+ {
+ Member("inner", ty.type_name(Source{{34, 56}}, "Inner")),
+ Member("scalar", ty.i32()),
+ },
+ {StructBlock()});
+
+ Global(Source{{78, 90}}, "a", ty.type_name("Outer"),
+ ast::StorageClass::kUniform, GroupAndBinding(0, 0));
+
+ ASSERT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(34:56 error: uniform storage requires that array elements be aligned to 16 bytes, but array element alignment is currently 8. Consider using the @size attribute on the last struct member.
+12:34 note: see layout of struct:
+/* align(4) size(84) */ struct Outer {
+/* offset( 0) align(4) size(80) */ inner : array<ArrayElem, 10>;
+/* offset(80) align(4) size( 4) */ scalar : i32;
+/* */ };
+78:90 note: see declaration of variable)");
+}
+
+TEST_F(ResolverStorageClassLayoutValidationTest,
+ UniformBuffer_InvalidArrayStride_TopLevelArray) {
+ // @group(0) @binding(0)
+ // var<uniform> a : array<f32, 4>;
+ Global(Source{{78, 90}}, "a", ty.array(Source{{34, 56}}, ty.f32(), 4),
+ ast::StorageClass::kUniform, GroupAndBinding(0, 0));
+
+ ASSERT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(34:56 error: uniform storage requires that array elements be aligned to 16 bytes, but array element alignment is currently 4. Consider using a vector or struct as the element type instead.)");
+}
+
+TEST_F(ResolverStorageClassLayoutValidationTest,
+ UniformBuffer_InvalidArrayStride_NestedArray) {
+ // struct Outer {
+ // inner : array<array<f32, 4>, 4>
+ // };
+ //
+ // @group(0) @binding(0)
+ // var<uniform> a : array<Outer, 4>;
+
+ Structure(
+ Source{{12, 34}}, "Outer",
+ {
+ Member("inner", ty.array(Source{{34, 56}}, ty.array(ty.f32(), 4), 4)),
+ },
+ {StructBlock()});
+
+ Global(Source{{78, 90}}, "a", ty.type_name("Outer"),
+ ast::StorageClass::kUniform, GroupAndBinding(0, 0));
+
+ ASSERT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(34:56 error: uniform storage requires that array elements be aligned to 16 bytes, but array element alignment is currently 4. Consider using a vector or struct as the element type instead.
+12:34 note: see layout of struct:
+/* align(4) size(64) */ struct Outer {
+/* offset( 0) align(4) size(64) */ inner : array<array<f32, 4>, 4>;
+/* */ };
+78:90 note: see declaration of variable)");
+}
+
+TEST_F(ResolverStorageClassLayoutValidationTest,
+ UniformBuffer_InvalidArrayStride_SuggestedFix) {
+ // type Inner = @stride(16) array<f32, 10>;
+ //
+ // [[block]]
+ // struct Outer {
+ // inner : Inner;
+ // scalar : i32;
+ // };
+ //
+ // @group(0) @binding(0)
+ // var<uniform> a : Outer;
+
+ Alias("Inner", ty.array(ty.f32(), 10, 16));
+
+ Structure(Source{{12, 34}}, "Outer",
+ {
+ Member("inner", ty.type_name(Source{{34, 56}}, "Inner")),
+ Member("scalar", ty.i32()),
+ },
+ {StructBlock()});
+
+ Global(Source{{78, 90}}, "a", ty.type_name("Outer"),
+ ast::StorageClass::kUniform, GroupAndBinding(0, 0));
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+} // namespace
+} // namespace resolver
+} // namespace tint
diff --git a/src/tint/resolver/storage_class_validation_test.cc b/src/tint/resolver/storage_class_validation_test.cc
new file mode 100644
index 0000000..0c33b8b
--- /dev/null
+++ b/src/tint/resolver/storage_class_validation_test.cc
@@ -0,0 +1,370 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/resolver/resolver.h"
+
+#include "gmock/gmock.h"
+#include "src/tint/ast/struct_block_attribute.h"
+#include "src/tint/resolver/resolver_test_helper.h"
+#include "src/tint/sem/struct.h"
+
+namespace tint {
+namespace resolver {
+namespace {
+
+using ResolverStorageClassValidationTest = ResolverTest;
+
+TEST_F(ResolverStorageClassValidationTest, GlobalVariableNoStorageClass_Fail) {
+ // var g : f32;
+ Global(Source{{12, 34}}, "g", ty.f32(), ast::StorageClass::kNone);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: global variables must have a storage class");
+}
+
+TEST_F(ResolverStorageClassValidationTest,
+ GlobalVariableFunctionStorageClass_Fail) {
+ // var<function> g : f32;
+ Global(Source{{12, 34}}, "g", ty.f32(), ast::StorageClass::kFunction);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: variables declared at module scope must not be in "
+ "the function storage class");
+}
+
+TEST_F(ResolverStorageClassValidationTest, Private_RuntimeArray) {
+ Global(Source{{12, 34}}, "v", ty.array(ty.i32()),
+ ast::StorageClass::kPrivate);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: runtime-sized arrays can only be used in the <storage> storage class
+12:34 note: while instantiating variable v)");
+}
+
+TEST_F(ResolverStorageClassValidationTest, Private_RuntimeArrayInStruct) {
+ auto* s = Structure("S", {Member("m", ty.array(ty.i32()))}, {StructBlock()});
+ Global(Source{{12, 34}}, "v", ty.Of(s), ast::StorageClass::kPrivate);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: runtime-sized arrays can only be used in the <storage> storage class
+note: while analysing structure member S.m
+12:34 note: while instantiating variable v)");
+}
+
+TEST_F(ResolverStorageClassValidationTest, Workgroup_RuntimeArray) {
+ Global(Source{{12, 34}}, "v", ty.array(ty.i32()),
+ ast::StorageClass::kWorkgroup);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: runtime-sized arrays can only be used in the <storage> storage class
+12:34 note: while instantiating variable v)");
+}
+
+TEST_F(ResolverStorageClassValidationTest, Workgroup_RuntimeArrayInStruct) {
+ auto* s = Structure("S", {Member("m", ty.array(ty.i32()))}, {StructBlock()});
+ Global(Source{{12, 34}}, "v", ty.Of(s), ast::StorageClass::kWorkgroup);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: runtime-sized arrays can only be used in the <storage> storage class
+note: while analysing structure member S.m
+12:34 note: while instantiating variable v)");
+}
+
+TEST_F(ResolverStorageClassValidationTest, StorageBufferBool) {
+ // var<storage> g : bool;
+ Global(Source{{56, 78}}, "g", ty.bool_(), ast::StorageClass::kStorage,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(0),
+ create<ast::GroupAttribute>(0),
+ });
+
+ ASSERT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(
+ r()->error(),
+ R"(56:78 error: Type 'bool' cannot be used in storage class 'storage' as it is non-host-shareable
+56:78 note: while instantiating variable g)");
+}
+
+TEST_F(ResolverStorageClassValidationTest, StorageBufferPointer) {
+ // var<storage> g : ptr<private, f32>;
+ Global(Source{{56, 78}}, "g",
+ ty.pointer(ty.f32(), ast::StorageClass::kPrivate),
+ ast::StorageClass::kStorage,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(0),
+ create<ast::GroupAttribute>(0),
+ });
+
+ ASSERT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(
+ r()->error(),
+ R"(56:78 error: Type 'ptr<private, f32, read_write>' cannot be used in storage class 'storage' as it is non-host-shareable
+56:78 note: while instantiating variable g)");
+}
+
+TEST_F(ResolverStorageClassValidationTest, StorageBufferIntScalar) {
+ // var<storage> g : i32;
+ Global(Source{{56, 78}}, "g", ty.i32(), ast::StorageClass::kStorage,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(0),
+ create<ast::GroupAttribute>(0),
+ });
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverStorageClassValidationTest, StorageBufferVector) {
+ // var<storage> g : vec4<f32>;
+ Global(Source{{56, 78}}, "g", ty.vec4<f32>(), ast::StorageClass::kStorage,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(0),
+ create<ast::GroupAttribute>(0),
+ });
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverStorageClassValidationTest, StorageBufferArray) {
+ // var<storage, read> g : array<S, 3>;
+ auto* s = Structure("S", {Member("a", ty.f32())});
+ auto* a = ty.array(ty.Of(s), 3);
+ Global(Source{{56, 78}}, "g", a, ast::StorageClass::kStorage,
+ ast::Access::kRead,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(0),
+ create<ast::GroupAttribute>(0),
+ });
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverStorageClassValidationTest, StorageBufferBoolAlias) {
+ // type a = bool;
+ // var<storage, read> g : a;
+ auto* a = Alias("a", ty.bool_());
+ Global(Source{{56, 78}}, "g", ty.Of(a), ast::StorageClass::kStorage,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(0),
+ create<ast::GroupAttribute>(0),
+ });
+
+ ASSERT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(
+ r()->error(),
+ R"(56:78 error: Type 'bool' cannot be used in storage class 'storage' as it is non-host-shareable
+56:78 note: while instantiating variable g)");
+}
+
+TEST_F(ResolverStorageClassValidationTest, NotStorage_AccessMode) {
+ // var<private, read> g : a;
+ Global(Source{{56, 78}}, "g", ty.i32(), ast::StorageClass::kPrivate,
+ ast::Access::kRead);
+
+ ASSERT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(
+ r()->error(),
+ R"(56:78 error: only variables in <storage> storage class may declare an access mode)");
+}
+
+TEST_F(ResolverStorageClassValidationTest, StorageBufferNoError_Basic) {
+ // [[block]] struct S { x : i32 };
+ // var<storage, read> g : S;
+ auto* s = Structure("S", {Member(Source{{12, 34}}, "x", ty.i32())},
+ {create<ast::StructBlockAttribute>()});
+ Global(Source{{56, 78}}, "g", ty.Of(s), ast::StorageClass::kStorage,
+ ast::Access::kRead,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(0),
+ create<ast::GroupAttribute>(0),
+ });
+
+ ASSERT_TRUE(r()->Resolve());
+}
+
+TEST_F(ResolverStorageClassValidationTest, StorageBufferNoError_Aliases) {
+ // [[block]] struct S { x : i32 };
+ // type a1 = S;
+ // var<storage, read> g : a1;
+ auto* s = Structure("S", {Member(Source{{12, 34}}, "x", ty.i32())},
+ {create<ast::StructBlockAttribute>()});
+ auto* a1 = Alias("a1", ty.Of(s));
+ auto* a2 = Alias("a2", ty.Of(a1));
+ Global(Source{{56, 78}}, "g", ty.Of(a2), ast::StorageClass::kStorage,
+ ast::Access::kRead,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(0),
+ create<ast::GroupAttribute>(0),
+ });
+
+ ASSERT_TRUE(r()->Resolve());
+}
+
+TEST_F(ResolverStorageClassValidationTest, UniformBuffer_Struct_Runtime) {
+ // [[block]] struct S { m: array<f32>; };
+ // @group(0) @binding(0) var<uniform, > svar : S;
+
+ auto* s = Structure(Source{{12, 34}}, "S", {Member("m", ty.array<i32>())},
+ {create<ast::StructBlockAttribute>()});
+
+ Global(Source{{56, 78}}, "svar", ty.Of(s), ast::StorageClass::kUniform,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(0),
+ create<ast::GroupAttribute>(0),
+ });
+
+ ASSERT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(56:78 error: runtime-sized arrays can only be used in the <storage> storage class
+note: while analysing structure member S.m
+56:78 note: while instantiating variable svar)");
+}
+
+TEST_F(ResolverStorageClassValidationTest, UniformBufferBool) {
+ // var<uniform> g : bool;
+ Global(Source{{56, 78}}, "g", ty.bool_(), ast::StorageClass::kUniform,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(0),
+ create<ast::GroupAttribute>(0),
+ });
+
+ ASSERT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(
+ r()->error(),
+ R"(56:78 error: Type 'bool' cannot be used in storage class 'uniform' as it is non-host-shareable
+56:78 note: while instantiating variable g)");
+}
+
+TEST_F(ResolverStorageClassValidationTest, UniformBufferPointer) {
+ // var<uniform> g : ptr<private, f32>;
+ Global(Source{{56, 78}}, "g",
+ ty.pointer(ty.f32(), ast::StorageClass::kPrivate),
+ ast::StorageClass::kUniform,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(0),
+ create<ast::GroupAttribute>(0),
+ });
+
+ ASSERT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(
+ r()->error(),
+ R"(56:78 error: Type 'ptr<private, f32, read_write>' cannot be used in storage class 'uniform' as it is non-host-shareable
+56:78 note: while instantiating variable g)");
+}
+
+TEST_F(ResolverStorageClassValidationTest, UniformBufferIntScalar) {
+ // var<uniform> g : i32;
+ Global(Source{{56, 78}}, "g", ty.i32(), ast::StorageClass::kUniform,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(0),
+ create<ast::GroupAttribute>(0),
+ });
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverStorageClassValidationTest, UniformBufferVector) {
+ // var<uniform> g : vec4<f32>;
+ Global(Source{{56, 78}}, "g", ty.vec4<f32>(), ast::StorageClass::kUniform,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(0),
+ create<ast::GroupAttribute>(0),
+ });
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverStorageClassValidationTest, UniformBufferArray) {
+ // struct S {
+ // @size(16) f : f32;
+ // }
+ // var<uniform> g : array<S, 3>;
+ auto* s = Structure("S", {Member("a", ty.f32(), {MemberSize(16)})});
+ auto* a = ty.array(ty.Of(s), 3);
+ Global(Source{{56, 78}}, "g", a, ast::StorageClass::kUniform,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(0),
+ create<ast::GroupAttribute>(0),
+ });
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverStorageClassValidationTest, UniformBufferBoolAlias) {
+ // type a = bool;
+ // var<uniform> g : a;
+ auto* a = Alias("a", ty.bool_());
+ Global(Source{{56, 78}}, "g", ty.Of(a), ast::StorageClass::kUniform,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(0),
+ create<ast::GroupAttribute>(0),
+ });
+
+ ASSERT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(
+ r()->error(),
+ R"(56:78 error: Type 'bool' cannot be used in storage class 'uniform' as it is non-host-shareable
+56:78 note: while instantiating variable g)");
+}
+
+TEST_F(ResolverStorageClassValidationTest, UniformBufferNoError_Basic) {
+ // [[block]] struct S { x : i32 };
+ // var<uniform> g : S;
+ auto* s = Structure("S", {Member(Source{{12, 34}}, "x", ty.i32())},
+ {create<ast::StructBlockAttribute>()});
+ Global(Source{{56, 78}}, "g", ty.Of(s), ast::StorageClass::kUniform,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(0),
+ create<ast::GroupAttribute>(0),
+ });
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverStorageClassValidationTest, UniformBufferNoError_Aliases) {
+ // [[block]] struct S { x : i32 };
+ // type a1 = S;
+ // var<uniform> g : a1;
+ auto* s = Structure("S", {Member(Source{{12, 34}}, "x", ty.i32())},
+ {create<ast::StructBlockAttribute>()});
+ auto* a1 = Alias("a1", ty.Of(s));
+ Global(Source{{56, 78}}, "g", ty.Of(a1), ast::StorageClass::kUniform,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(0),
+ create<ast::GroupAttribute>(0),
+ });
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+} // namespace
+} // namespace resolver
+} // namespace tint
diff --git a/src/tint/resolver/struct_layout_test.cc b/src/tint/resolver/struct_layout_test.cc
new file mode 100644
index 0000000..f8e76fd
--- /dev/null
+++ b/src/tint/resolver/struct_layout_test.cc
@@ -0,0 +1,410 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/resolver/resolver.h"
+
+#include "gmock/gmock.h"
+#include "src/tint/ast/struct_block_attribute.h"
+#include "src/tint/resolver/resolver_test_helper.h"
+#include "src/tint/sem/struct.h"
+
+namespace tint {
+namespace resolver {
+namespace {
+
+using ResolverStructLayoutTest = ResolverTest;
+
+TEST_F(ResolverStructLayoutTest, Scalars) {
+ auto* s = Structure("S", {
+ Member("a", ty.f32()),
+ Member("b", ty.u32()),
+ Member("c", ty.i32()),
+ });
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = TypeOf(s)->As<sem::Struct>();
+ ASSERT_NE(sem, nullptr);
+ EXPECT_EQ(sem->Size(), 12u);
+ EXPECT_EQ(sem->SizeNoPadding(), 12u);
+ EXPECT_EQ(sem->Align(), 4u);
+ ASSERT_EQ(sem->Members().size(), 3u);
+ EXPECT_EQ(sem->Members()[0]->Offset(), 0u);
+ EXPECT_EQ(sem->Members()[0]->Align(), 4u);
+ EXPECT_EQ(sem->Members()[0]->Size(), 4u);
+ EXPECT_EQ(sem->Members()[1]->Offset(), 4u);
+ EXPECT_EQ(sem->Members()[1]->Align(), 4u);
+ EXPECT_EQ(sem->Members()[1]->Size(), 4u);
+ EXPECT_EQ(sem->Members()[2]->Offset(), 8u);
+ EXPECT_EQ(sem->Members()[2]->Align(), 4u);
+ EXPECT_EQ(sem->Members()[2]->Size(), 4u);
+}
+
+TEST_F(ResolverStructLayoutTest, Alias) {
+ auto* alias_a = Alias("a", ty.f32());
+ auto* alias_b = Alias("b", ty.f32());
+
+ auto* s = Structure("S", {
+ Member("a", ty.Of(alias_a)),
+ Member("b", ty.Of(alias_b)),
+ });
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = TypeOf(s)->As<sem::Struct>();
+ ASSERT_NE(sem, nullptr);
+ EXPECT_EQ(sem->Size(), 8u);
+ EXPECT_EQ(sem->SizeNoPadding(), 8u);
+ EXPECT_EQ(sem->Align(), 4u);
+ ASSERT_EQ(sem->Members().size(), 2u);
+ EXPECT_EQ(sem->Members()[0]->Offset(), 0u);
+ EXPECT_EQ(sem->Members()[0]->Align(), 4u);
+ EXPECT_EQ(sem->Members()[0]->Size(), 4u);
+ EXPECT_EQ(sem->Members()[1]->Offset(), 4u);
+ EXPECT_EQ(sem->Members()[1]->Align(), 4u);
+ EXPECT_EQ(sem->Members()[1]->Size(), 4u);
+}
+
+TEST_F(ResolverStructLayoutTest, ImplicitStrideArrayStaticSize) {
+ auto* s = Structure("S", {
+ Member("a", ty.array<i32, 3>()),
+ Member("b", ty.array<f32, 5>()),
+ Member("c", ty.array<f32, 1>()),
+ });
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = TypeOf(s)->As<sem::Struct>();
+ ASSERT_NE(sem, nullptr);
+ EXPECT_EQ(sem->Size(), 36u);
+ EXPECT_EQ(sem->SizeNoPadding(), 36u);
+ EXPECT_EQ(sem->Align(), 4u);
+ ASSERT_EQ(sem->Members().size(), 3u);
+ EXPECT_EQ(sem->Members()[0]->Offset(), 0u);
+ EXPECT_EQ(sem->Members()[0]->Align(), 4u);
+ EXPECT_EQ(sem->Members()[0]->Size(), 12u);
+ EXPECT_EQ(sem->Members()[1]->Offset(), 12u);
+ EXPECT_EQ(sem->Members()[1]->Align(), 4u);
+ EXPECT_EQ(sem->Members()[1]->Size(), 20u);
+ EXPECT_EQ(sem->Members()[2]->Offset(), 32u);
+ EXPECT_EQ(sem->Members()[2]->Align(), 4u);
+ EXPECT_EQ(sem->Members()[2]->Size(), 4u);
+}
+
+TEST_F(ResolverStructLayoutTest, ExplicitStrideArrayStaticSize) {
+ auto* s = Structure("S", {
+ Member("a", ty.array<i32, 3>(/*stride*/ 8)),
+ Member("b", ty.array<f32, 5>(/*stride*/ 16)),
+ Member("c", ty.array<f32, 1>(/*stride*/ 32)),
+ });
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = TypeOf(s)->As<sem::Struct>();
+ ASSERT_NE(sem, nullptr);
+ EXPECT_EQ(sem->Size(), 136u);
+ EXPECT_EQ(sem->SizeNoPadding(), 136u);
+ EXPECT_EQ(sem->Align(), 4u);
+ ASSERT_EQ(sem->Members().size(), 3u);
+ EXPECT_EQ(sem->Members()[0]->Offset(), 0u);
+ EXPECT_EQ(sem->Members()[0]->Align(), 4u);
+ EXPECT_EQ(sem->Members()[0]->Size(), 24u);
+ EXPECT_EQ(sem->Members()[1]->Offset(), 24u);
+ EXPECT_EQ(sem->Members()[1]->Align(), 4u);
+ EXPECT_EQ(sem->Members()[1]->Size(), 80u);
+ EXPECT_EQ(sem->Members()[2]->Offset(), 104u);
+ EXPECT_EQ(sem->Members()[2]->Align(), 4u);
+ EXPECT_EQ(sem->Members()[2]->Size(), 32u);
+}
+
+TEST_F(ResolverStructLayoutTest, ImplicitStrideArrayRuntimeSized) {
+ auto* s = Structure("S",
+ {
+ Member("c", ty.array<f32>()),
+ },
+ ast::AttributeList{create<ast::StructBlockAttribute>()});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = TypeOf(s)->As<sem::Struct>();
+ ASSERT_NE(sem, nullptr);
+ EXPECT_EQ(sem->Size(), 4u);
+ EXPECT_EQ(sem->SizeNoPadding(), 4u);
+ EXPECT_EQ(sem->Align(), 4u);
+ ASSERT_EQ(sem->Members().size(), 1u);
+ EXPECT_EQ(sem->Members()[0]->Offset(), 0u);
+ EXPECT_EQ(sem->Members()[0]->Align(), 4u);
+ EXPECT_EQ(sem->Members()[0]->Size(), 4u);
+}
+
+TEST_F(ResolverStructLayoutTest, ExplicitStrideArrayRuntimeSized) {
+ auto* s = Structure("S",
+ {
+ Member("c", ty.array<f32>(/*stride*/ 32)),
+ },
+ ast::AttributeList{create<ast::StructBlockAttribute>()});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = TypeOf(s)->As<sem::Struct>();
+ ASSERT_NE(sem, nullptr);
+ EXPECT_EQ(sem->Size(), 32u);
+ EXPECT_EQ(sem->SizeNoPadding(), 32u);
+ EXPECT_EQ(sem->Align(), 4u);
+ ASSERT_EQ(sem->Members().size(), 1u);
+ EXPECT_EQ(sem->Members()[0]->Offset(), 0u);
+ EXPECT_EQ(sem->Members()[0]->Align(), 4u);
+ EXPECT_EQ(sem->Members()[0]->Size(), 32u);
+}
+
+TEST_F(ResolverStructLayoutTest, ImplicitStrideArrayOfExplicitStrideArray) {
+ auto* inner = ty.array<i32, 2>(/*stride*/ 16); // size: 32
+ auto* outer = ty.array(inner, 12); // size: 12 * 32
+ auto* s = Structure("S", {
+ Member("c", outer),
+ });
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = TypeOf(s)->As<sem::Struct>();
+ ASSERT_NE(sem, nullptr);
+ EXPECT_EQ(sem->Size(), 384u);
+ EXPECT_EQ(sem->SizeNoPadding(), 384u);
+ EXPECT_EQ(sem->Align(), 4u);
+ ASSERT_EQ(sem->Members().size(), 1u);
+ EXPECT_EQ(sem->Members()[0]->Offset(), 0u);
+ EXPECT_EQ(sem->Members()[0]->Align(), 4u);
+ EXPECT_EQ(sem->Members()[0]->Size(), 384u);
+}
+
+TEST_F(ResolverStructLayoutTest, ImplicitStrideArrayOfStructure) {
+ auto* inner = Structure("Inner", {
+ Member("a", ty.vec2<i32>()),
+ Member("b", ty.vec3<i32>()),
+ Member("c", ty.vec4<i32>()),
+ }); // size: 48
+ auto* outer = ty.array(ty.Of(inner), 12); // size: 12 * 48
+ auto* s = Structure("S", {
+ Member("c", outer),
+ });
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = TypeOf(s)->As<sem::Struct>();
+ ASSERT_NE(sem, nullptr);
+ EXPECT_EQ(sem->Size(), 576u);
+ EXPECT_EQ(sem->SizeNoPadding(), 576u);
+ EXPECT_EQ(sem->Align(), 16u);
+ ASSERT_EQ(sem->Members().size(), 1u);
+ EXPECT_EQ(sem->Members()[0]->Offset(), 0u);
+ EXPECT_EQ(sem->Members()[0]->Align(), 16u);
+ EXPECT_EQ(sem->Members()[0]->Size(), 576u);
+}
+
+TEST_F(ResolverStructLayoutTest, Vector) {
+ auto* s = Structure("S", {
+ Member("a", ty.vec2<i32>()),
+ Member("b", ty.vec3<i32>()),
+ Member("c", ty.vec4<i32>()),
+ });
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = TypeOf(s)->As<sem::Struct>();
+ ASSERT_NE(sem, nullptr);
+ EXPECT_EQ(sem->Size(), 48u);
+ EXPECT_EQ(sem->SizeNoPadding(), 48u);
+ EXPECT_EQ(sem->Align(), 16u);
+ ASSERT_EQ(sem->Members().size(), 3u);
+ EXPECT_EQ(sem->Members()[0]->Offset(), 0u); // vec2
+ EXPECT_EQ(sem->Members()[0]->Align(), 8u);
+ EXPECT_EQ(sem->Members()[0]->Size(), 8u);
+ EXPECT_EQ(sem->Members()[1]->Offset(), 16u); // vec3
+ EXPECT_EQ(sem->Members()[1]->Align(), 16u);
+ EXPECT_EQ(sem->Members()[1]->Size(), 12u);
+ EXPECT_EQ(sem->Members()[2]->Offset(), 32u); // vec4
+ EXPECT_EQ(sem->Members()[2]->Align(), 16u);
+ EXPECT_EQ(sem->Members()[2]->Size(), 16u);
+}
+
+TEST_F(ResolverStructLayoutTest, Matrix) {
+ auto* s = Structure("S", {
+ Member("a", ty.mat2x2<f32>()),
+ Member("b", ty.mat2x3<f32>()),
+ Member("c", ty.mat2x4<f32>()),
+ Member("d", ty.mat3x2<f32>()),
+ Member("e", ty.mat3x3<f32>()),
+ Member("f", ty.mat3x4<f32>()),
+ Member("g", ty.mat4x2<f32>()),
+ Member("h", ty.mat4x3<f32>()),
+ Member("i", ty.mat4x4<f32>()),
+ });
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = TypeOf(s)->As<sem::Struct>();
+ ASSERT_NE(sem, nullptr);
+ EXPECT_EQ(sem->Size(), 368u);
+ EXPECT_EQ(sem->SizeNoPadding(), 368u);
+ EXPECT_EQ(sem->Align(), 16u);
+ ASSERT_EQ(sem->Members().size(), 9u);
+ EXPECT_EQ(sem->Members()[0]->Offset(), 0u); // mat2x2
+ EXPECT_EQ(sem->Members()[0]->Align(), 8u);
+ EXPECT_EQ(sem->Members()[0]->Size(), 16u);
+ EXPECT_EQ(sem->Members()[1]->Offset(), 16u); // mat2x3
+ EXPECT_EQ(sem->Members()[1]->Align(), 16u);
+ EXPECT_EQ(sem->Members()[1]->Size(), 32u);
+ EXPECT_EQ(sem->Members()[2]->Offset(), 48u); // mat2x4
+ EXPECT_EQ(sem->Members()[2]->Align(), 16u);
+ EXPECT_EQ(sem->Members()[2]->Size(), 32u);
+ EXPECT_EQ(sem->Members()[3]->Offset(), 80u); // mat3x2
+ EXPECT_EQ(sem->Members()[3]->Align(), 8u);
+ EXPECT_EQ(sem->Members()[3]->Size(), 24u);
+ EXPECT_EQ(sem->Members()[4]->Offset(), 112u); // mat3x3
+ EXPECT_EQ(sem->Members()[4]->Align(), 16u);
+ EXPECT_EQ(sem->Members()[4]->Size(), 48u);
+ EXPECT_EQ(sem->Members()[5]->Offset(), 160u); // mat3x4
+ EXPECT_EQ(sem->Members()[5]->Align(), 16u);
+ EXPECT_EQ(sem->Members()[5]->Size(), 48u);
+ EXPECT_EQ(sem->Members()[6]->Offset(), 208u); // mat4x2
+ EXPECT_EQ(sem->Members()[6]->Align(), 8u);
+ EXPECT_EQ(sem->Members()[6]->Size(), 32u);
+ EXPECT_EQ(sem->Members()[7]->Offset(), 240u); // mat4x3
+ EXPECT_EQ(sem->Members()[7]->Align(), 16u);
+ EXPECT_EQ(sem->Members()[7]->Size(), 64u);
+ EXPECT_EQ(sem->Members()[8]->Offset(), 304u); // mat4x4
+ EXPECT_EQ(sem->Members()[8]->Align(), 16u);
+ EXPECT_EQ(sem->Members()[8]->Size(), 64u);
+}
+
+TEST_F(ResolverStructLayoutTest, NestedStruct) {
+ auto* inner = Structure("Inner", {
+ Member("a", ty.mat3x3<f32>()),
+ });
+ auto* s = Structure("S", {
+ Member("a", ty.i32()),
+ Member("b", ty.Of(inner)),
+ Member("c", ty.i32()),
+ });
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = TypeOf(s)->As<sem::Struct>();
+ ASSERT_NE(sem, nullptr);
+ EXPECT_EQ(sem->Size(), 80u);
+ EXPECT_EQ(sem->SizeNoPadding(), 68u);
+ EXPECT_EQ(sem->Align(), 16u);
+ ASSERT_EQ(sem->Members().size(), 3u);
+ EXPECT_EQ(sem->Members()[0]->Offset(), 0u);
+ EXPECT_EQ(sem->Members()[0]->Align(), 4u);
+ EXPECT_EQ(sem->Members()[0]->Size(), 4u);
+ EXPECT_EQ(sem->Members()[1]->Offset(), 16u);
+ EXPECT_EQ(sem->Members()[1]->Align(), 16u);
+ EXPECT_EQ(sem->Members()[1]->Size(), 48u);
+ EXPECT_EQ(sem->Members()[2]->Offset(), 64u);
+ EXPECT_EQ(sem->Members()[2]->Align(), 4u);
+ EXPECT_EQ(sem->Members()[2]->Size(), 4u);
+}
+
+TEST_F(ResolverStructLayoutTest, SizeAttributes) {
+ auto* inner = Structure("Inner", {
+ Member("a", ty.f32(), {MemberSize(8)}),
+ Member("b", ty.f32(), {MemberSize(16)}),
+ Member("c", ty.f32(), {MemberSize(8)}),
+ });
+ auto* s = Structure("S", {
+ Member("a", ty.f32(), {MemberSize(4)}),
+ Member("b", ty.u32(), {MemberSize(8)}),
+ Member("c", ty.Of(inner)),
+ Member("d", ty.i32(), {MemberSize(32)}),
+ });
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = TypeOf(s)->As<sem::Struct>();
+ ASSERT_NE(sem, nullptr);
+ EXPECT_EQ(sem->Size(), 76u);
+ EXPECT_EQ(sem->SizeNoPadding(), 76u);
+ EXPECT_EQ(sem->Align(), 4u);
+ ASSERT_EQ(sem->Members().size(), 4u);
+ EXPECT_EQ(sem->Members()[0]->Offset(), 0u);
+ EXPECT_EQ(sem->Members()[0]->Align(), 4u);
+ EXPECT_EQ(sem->Members()[0]->Size(), 4u);
+ EXPECT_EQ(sem->Members()[1]->Offset(), 4u);
+ EXPECT_EQ(sem->Members()[1]->Align(), 4u);
+ EXPECT_EQ(sem->Members()[1]->Size(), 8u);
+ EXPECT_EQ(sem->Members()[2]->Offset(), 12u);
+ EXPECT_EQ(sem->Members()[2]->Align(), 4u);
+ EXPECT_EQ(sem->Members()[2]->Size(), 32u);
+ EXPECT_EQ(sem->Members()[3]->Offset(), 44u);
+ EXPECT_EQ(sem->Members()[3]->Align(), 4u);
+ EXPECT_EQ(sem->Members()[3]->Size(), 32u);
+}
+
+TEST_F(ResolverStructLayoutTest, AlignAttributes) {
+ auto* inner = Structure("Inner", {
+ Member("a", ty.f32(), {MemberAlign(8)}),
+ Member("b", ty.f32(), {MemberAlign(16)}),
+ Member("c", ty.f32(), {MemberAlign(4)}),
+ });
+ auto* s = Structure("S", {
+ Member("a", ty.f32(), {MemberAlign(4)}),
+ Member("b", ty.u32(), {MemberAlign(8)}),
+ Member("c", ty.Of(inner)),
+ Member("d", ty.i32(), {MemberAlign(32)}),
+ });
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = TypeOf(s)->As<sem::Struct>();
+ ASSERT_NE(sem, nullptr);
+ EXPECT_EQ(sem->Size(), 96u);
+ EXPECT_EQ(sem->SizeNoPadding(), 68u);
+ EXPECT_EQ(sem->Align(), 32u);
+ ASSERT_EQ(sem->Members().size(), 4u);
+ EXPECT_EQ(sem->Members()[0]->Offset(), 0u);
+ EXPECT_EQ(sem->Members()[0]->Align(), 4u);
+ EXPECT_EQ(sem->Members()[0]->Size(), 4u);
+ EXPECT_EQ(sem->Members()[1]->Offset(), 8u);
+ EXPECT_EQ(sem->Members()[1]->Align(), 8u);
+ EXPECT_EQ(sem->Members()[1]->Size(), 4u);
+ EXPECT_EQ(sem->Members()[2]->Offset(), 16u);
+ EXPECT_EQ(sem->Members()[2]->Align(), 16u);
+ EXPECT_EQ(sem->Members()[2]->Size(), 32u);
+ EXPECT_EQ(sem->Members()[3]->Offset(), 64u);
+ EXPECT_EQ(sem->Members()[3]->Align(), 32u);
+ EXPECT_EQ(sem->Members()[3]->Size(), 4u);
+}
+
+TEST_F(ResolverStructLayoutTest, StructWithLotsOfPadding) {
+ auto* s = Structure("S", {
+ Member("a", ty.i32(), {MemberAlign(1024)}),
+ });
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = TypeOf(s)->As<sem::Struct>();
+ ASSERT_NE(sem, nullptr);
+ EXPECT_EQ(sem->Size(), 1024u);
+ EXPECT_EQ(sem->SizeNoPadding(), 4u);
+ EXPECT_EQ(sem->Align(), 1024u);
+ ASSERT_EQ(sem->Members().size(), 1u);
+ EXPECT_EQ(sem->Members()[0]->Offset(), 0u);
+ EXPECT_EQ(sem->Members()[0]->Align(), 1024u);
+ EXPECT_EQ(sem->Members()[0]->Size(), 4u);
+}
+
+} // namespace
+} // namespace resolver
+} // namespace tint
diff --git a/src/tint/resolver/struct_pipeline_stage_use_test.cc b/src/tint/resolver/struct_pipeline_stage_use_test.cc
new file mode 100644
index 0000000..e13b5f2
--- /dev/null
+++ b/src/tint/resolver/struct_pipeline_stage_use_test.cc
@@ -0,0 +1,191 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/resolver/resolver.h"
+
+#include "gmock/gmock.h"
+#include "src/tint/ast/stage_attribute.h"
+#include "src/tint/resolver/resolver_test_helper.h"
+#include "src/tint/sem/struct.h"
+
+using ::testing::UnorderedElementsAre;
+
+namespace tint {
+namespace resolver {
+namespace {
+
+using ResolverPipelineStageUseTest = ResolverTest;
+
+TEST_F(ResolverPipelineStageUseTest, UnusedStruct) {
+ auto* s = Structure("S", {Member("a", ty.f32(), {Location(0)})});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = TypeOf(s)->As<sem::Struct>();
+ ASSERT_NE(sem, nullptr);
+ EXPECT_TRUE(sem->PipelineStageUses().empty());
+}
+
+TEST_F(ResolverPipelineStageUseTest, StructUsedAsNonEntryPointParam) {
+ auto* s = Structure("S", {Member("a", ty.f32(), {Location(0)})});
+
+ Func("foo", {Param("param", ty.Of(s))}, ty.void_(), {}, {});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = TypeOf(s)->As<sem::Struct>();
+ ASSERT_NE(sem, nullptr);
+ EXPECT_TRUE(sem->PipelineStageUses().empty());
+}
+
+TEST_F(ResolverPipelineStageUseTest, StructUsedAsNonEntryPointReturnType) {
+ auto* s = Structure("S", {Member("a", ty.f32(), {Location(0)})});
+
+ Func("foo", {}, ty.Of(s), {Return(Construct(ty.Of(s), Expr(0.f)))}, {});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = TypeOf(s)->As<sem::Struct>();
+ ASSERT_NE(sem, nullptr);
+ EXPECT_TRUE(sem->PipelineStageUses().empty());
+}
+
+TEST_F(ResolverPipelineStageUseTest, StructUsedAsVertexShaderParam) {
+ auto* s = Structure("S", {Member("a", ty.f32(), {Location(0)})});
+
+ Func("main", {Param("param", ty.Of(s))}, ty.vec4<f32>(),
+ {Return(Construct(ty.vec4<f32>()))},
+ {Stage(ast::PipelineStage::kVertex)},
+ {Builtin(ast::Builtin::kPosition)});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = TypeOf(s)->As<sem::Struct>();
+ ASSERT_NE(sem, nullptr);
+ EXPECT_THAT(sem->PipelineStageUses(),
+ UnorderedElementsAre(sem::PipelineStageUsage::kVertexInput));
+}
+
+TEST_F(ResolverPipelineStageUseTest, StructUsedAsVertexShaderReturnType) {
+ auto* s = Structure(
+ "S", {Member("a", ty.vec4<f32>(), {Builtin(ast::Builtin::kPosition)})});
+
+ Func("main", {}, ty.Of(s), {Return(Construct(ty.Of(s)))},
+ {Stage(ast::PipelineStage::kVertex)});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = TypeOf(s)->As<sem::Struct>();
+ ASSERT_NE(sem, nullptr);
+ EXPECT_THAT(sem->PipelineStageUses(),
+ UnorderedElementsAre(sem::PipelineStageUsage::kVertexOutput));
+}
+
+TEST_F(ResolverPipelineStageUseTest, StructUsedAsFragmentShaderParam) {
+ auto* s = Structure("S", {Member("a", ty.f32(), {Location(0)})});
+
+ Func("main", {Param("param", ty.Of(s))}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kFragment)});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = TypeOf(s)->As<sem::Struct>();
+ ASSERT_NE(sem, nullptr);
+ EXPECT_THAT(sem->PipelineStageUses(),
+ UnorderedElementsAre(sem::PipelineStageUsage::kFragmentInput));
+}
+
+TEST_F(ResolverPipelineStageUseTest, StructUsedAsFragmentShaderReturnType) {
+ auto* s = Structure("S", {Member("a", ty.f32(), {Location(0)})});
+
+ Func("main", {}, ty.Of(s), {Return(Construct(ty.Of(s), Expr(0.f)))},
+ {Stage(ast::PipelineStage::kFragment)});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = TypeOf(s)->As<sem::Struct>();
+ ASSERT_NE(sem, nullptr);
+ EXPECT_THAT(sem->PipelineStageUses(),
+ UnorderedElementsAre(sem::PipelineStageUsage::kFragmentOutput));
+}
+
+TEST_F(ResolverPipelineStageUseTest, StructUsedAsComputeShaderParam) {
+ auto* s = Structure(
+ "S",
+ {Member("a", ty.u32(), {Builtin(ast::Builtin::kLocalInvocationIndex)})});
+
+ Func("main", {Param("param", ty.Of(s))}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute), WorkgroupSize(1)});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = TypeOf(s)->As<sem::Struct>();
+ ASSERT_NE(sem, nullptr);
+ EXPECT_THAT(sem->PipelineStageUses(),
+ UnorderedElementsAre(sem::PipelineStageUsage::kComputeInput));
+}
+
+TEST_F(ResolverPipelineStageUseTest, StructUsedMultipleStages) {
+ auto* s = Structure(
+ "S", {Member("a", ty.vec4<f32>(), {Builtin(ast::Builtin::kPosition)})});
+
+ Func("vert_main", {}, ty.Of(s), {Return(Construct(ty.Of(s)))},
+ {Stage(ast::PipelineStage::kVertex)});
+
+ Func("frag_main", {Param("param", ty.Of(s))}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kFragment)});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = TypeOf(s)->As<sem::Struct>();
+ ASSERT_NE(sem, nullptr);
+ EXPECT_THAT(sem->PipelineStageUses(),
+ UnorderedElementsAre(sem::PipelineStageUsage::kVertexOutput,
+ sem::PipelineStageUsage::kFragmentInput));
+}
+
+TEST_F(ResolverPipelineStageUseTest, StructUsedAsShaderParamViaAlias) {
+ auto* s = Structure("S", {Member("a", ty.f32(), {Location(0)})});
+ auto* s_alias = Alias("S_alias", ty.Of(s));
+
+ Func("main", {Param("param", ty.Of(s_alias))}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kFragment)});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = TypeOf(s)->As<sem::Struct>();
+ ASSERT_NE(sem, nullptr);
+ EXPECT_THAT(sem->PipelineStageUses(),
+ UnorderedElementsAre(sem::PipelineStageUsage::kFragmentInput));
+}
+
+TEST_F(ResolverPipelineStageUseTest, StructUsedAsShaderReturnTypeViaAlias) {
+ auto* s = Structure("S", {Member("a", ty.f32(), {Location(0)})});
+ auto* s_alias = Alias("S_alias", ty.Of(s));
+
+ Func("main", {}, ty.Of(s_alias),
+ {Return(Construct(ty.Of(s_alias), Expr(0.f)))},
+ {Stage(ast::PipelineStage::kFragment)});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = TypeOf(s)->As<sem::Struct>();
+ ASSERT_NE(sem, nullptr);
+ EXPECT_THAT(sem->PipelineStageUses(),
+ UnorderedElementsAre(sem::PipelineStageUsage::kFragmentOutput));
+}
+
+} // namespace
+} // namespace resolver
+} // namespace tint
diff --git a/src/tint/resolver/struct_storage_class_use_test.cc b/src/tint/resolver/struct_storage_class_use_test.cc
new file mode 100644
index 0000000..2c0e9cf
--- /dev/null
+++ b/src/tint/resolver/struct_storage_class_use_test.cc
@@ -0,0 +1,197 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/resolver/resolver.h"
+
+#include "gmock/gmock.h"
+#include "src/tint/ast/struct_block_attribute.h"
+#include "src/tint/resolver/resolver_test_helper.h"
+#include "src/tint/sem/struct.h"
+
+using ::testing::UnorderedElementsAre;
+
+namespace tint {
+namespace resolver {
+namespace {
+
+using ResolverStorageClassUseTest = ResolverTest;
+
+TEST_F(ResolverStorageClassUseTest, UnreachableStruct) {
+ auto* s = Structure("S", {Member("a", ty.f32())});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = TypeOf(s)->As<sem::Struct>();
+ ASSERT_NE(sem, nullptr);
+ EXPECT_TRUE(sem->StorageClassUsage().empty());
+}
+
+TEST_F(ResolverStorageClassUseTest, StructReachableFromParameter) {
+ auto* s = Structure("S", {Member("a", ty.f32())});
+
+ Func("f", {Param("param", ty.Of(s))}, ty.void_(), {}, {});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = TypeOf(s)->As<sem::Struct>();
+ ASSERT_NE(sem, nullptr);
+ EXPECT_THAT(sem->StorageClassUsage(),
+ UnorderedElementsAre(ast::StorageClass::kNone));
+}
+
+TEST_F(ResolverStorageClassUseTest, StructReachableFromReturnType) {
+ auto* s = Structure("S", {Member("a", ty.f32())});
+
+ Func("f", {}, ty.Of(s), {Return(Construct(ty.Of(s)))}, {});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = TypeOf(s)->As<sem::Struct>();
+ ASSERT_NE(sem, nullptr);
+ EXPECT_THAT(sem->StorageClassUsage(),
+ UnorderedElementsAre(ast::StorageClass::kNone));
+}
+
+TEST_F(ResolverStorageClassUseTest, StructReachableFromGlobal) {
+ auto* s = Structure("S", {Member("a", ty.f32())});
+
+ Global("g", ty.Of(s), ast::StorageClass::kPrivate);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = TypeOf(s)->As<sem::Struct>();
+ ASSERT_NE(sem, nullptr);
+ EXPECT_THAT(sem->StorageClassUsage(),
+ UnorderedElementsAre(ast::StorageClass::kPrivate));
+}
+
+TEST_F(ResolverStorageClassUseTest, StructReachableViaGlobalAlias) {
+ auto* s = Structure("S", {Member("a", ty.f32())});
+ auto* a = Alias("A", ty.Of(s));
+ Global("g", ty.Of(a), ast::StorageClass::kPrivate);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = TypeOf(s)->As<sem::Struct>();
+ ASSERT_NE(sem, nullptr);
+ EXPECT_THAT(sem->StorageClassUsage(),
+ UnorderedElementsAre(ast::StorageClass::kPrivate));
+}
+
+TEST_F(ResolverStorageClassUseTest, StructReachableViaGlobalStruct) {
+ auto* s = Structure("S", {Member("a", ty.f32())});
+ auto* o = Structure("O", {Member("a", ty.Of(s))});
+ Global("g", ty.Of(o), ast::StorageClass::kPrivate);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = TypeOf(s)->As<sem::Struct>();
+ ASSERT_NE(sem, nullptr);
+ EXPECT_THAT(sem->StorageClassUsage(),
+ UnorderedElementsAre(ast::StorageClass::kPrivate));
+}
+
+TEST_F(ResolverStorageClassUseTest, StructReachableViaGlobalArray) {
+ auto* s = Structure("S", {Member("a", ty.f32())});
+ auto* a = ty.array(ty.Of(s), 3);
+ Global("g", a, ast::StorageClass::kPrivate);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = TypeOf(s)->As<sem::Struct>();
+ ASSERT_NE(sem, nullptr);
+ EXPECT_THAT(sem->StorageClassUsage(),
+ UnorderedElementsAre(ast::StorageClass::kPrivate));
+}
+
+TEST_F(ResolverStorageClassUseTest, StructReachableFromLocal) {
+ auto* s = Structure("S", {Member("a", ty.f32())});
+
+ WrapInFunction(Var("g", ty.Of(s)));
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = TypeOf(s)->As<sem::Struct>();
+ ASSERT_NE(sem, nullptr);
+ EXPECT_THAT(sem->StorageClassUsage(),
+ UnorderedElementsAre(ast::StorageClass::kFunction));
+}
+
+TEST_F(ResolverStorageClassUseTest, StructReachableViaLocalAlias) {
+ auto* s = Structure("S", {Member("a", ty.f32())});
+ auto* a = Alias("A", ty.Of(s));
+ WrapInFunction(Var("g", ty.Of(a)));
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = TypeOf(s)->As<sem::Struct>();
+ ASSERT_NE(sem, nullptr);
+ EXPECT_THAT(sem->StorageClassUsage(),
+ UnorderedElementsAre(ast::StorageClass::kFunction));
+}
+
+TEST_F(ResolverStorageClassUseTest, StructReachableViaLocalStruct) {
+ auto* s = Structure("S", {Member("a", ty.f32())});
+ auto* o = Structure("O", {Member("a", ty.Of(s))});
+ WrapInFunction(Var("g", ty.Of(o)));
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = TypeOf(s)->As<sem::Struct>();
+ ASSERT_NE(sem, nullptr);
+ EXPECT_THAT(sem->StorageClassUsage(),
+ UnorderedElementsAre(ast::StorageClass::kFunction));
+}
+
+TEST_F(ResolverStorageClassUseTest, StructReachableViaLocalArray) {
+ auto* s = Structure("S", {Member("a", ty.f32())});
+ auto* a = ty.array(ty.Of(s), 3);
+ WrapInFunction(Var("g", a));
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = TypeOf(s)->As<sem::Struct>();
+ ASSERT_NE(sem, nullptr);
+ EXPECT_THAT(sem->StorageClassUsage(),
+ UnorderedElementsAre(ast::StorageClass::kFunction));
+}
+
+TEST_F(ResolverStorageClassUseTest, StructMultipleStorageClassUses) {
+ auto* s = Structure("S", {Member("a", ty.f32())},
+ {create<ast::StructBlockAttribute>()});
+ Global("x", ty.Of(s), ast::StorageClass::kUniform,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(0),
+ create<ast::GroupAttribute>(0),
+ });
+ Global("y", ty.Of(s), ast::StorageClass::kStorage, ast::Access::kRead,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(1),
+ create<ast::GroupAttribute>(0),
+ });
+ WrapInFunction(Var("g", ty.Of(s)));
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = TypeOf(s)->As<sem::Struct>();
+ ASSERT_NE(sem, nullptr);
+ EXPECT_THAT(sem->StorageClassUsage(),
+ UnorderedElementsAre(ast::StorageClass::kUniform,
+ ast::StorageClass::kStorage,
+ ast::StorageClass::kFunction));
+}
+
+} // namespace
+} // namespace resolver
+} // namespace tint
diff --git a/src/tint/resolver/type_constructor_validation_test.cc b/src/tint/resolver/type_constructor_validation_test.cc
new file mode 100644
index 0000000..21bc698
--- /dev/null
+++ b/src/tint/resolver/type_constructor_validation_test.cc
@@ -0,0 +1,2937 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gmock/gmock.h"
+#include "src/tint/resolver/resolver_test_helper.h"
+#include "src/tint/sem/reference_type.h"
+#include "src/tint/sem/type_constructor.h"
+#include "src/tint/sem/type_conversion.h"
+
+namespace tint {
+namespace resolver {
+namespace {
+
+using ::testing::HasSubstr;
+
+// Helpers and typedefs
+using builder::alias;
+using builder::alias1;
+using builder::alias2;
+using builder::alias3;
+using builder::CreatePtrs;
+using builder::CreatePtrsFor;
+using builder::DataType;
+using builder::f32;
+using builder::i32;
+using builder::mat2x2;
+using builder::mat2x3;
+using builder::mat3x2;
+using builder::mat3x3;
+using builder::mat4x4;
+using builder::u32;
+using builder::vec2;
+using builder::vec3;
+using builder::vec4;
+
+class ResolverTypeConstructorValidationTest : public resolver::TestHelper,
+ public testing::Test {};
+
+namespace InferTypeTest {
+struct Params {
+ builder::ast_type_func_ptr create_rhs_ast_type;
+ builder::ast_expr_func_ptr create_rhs_ast_value;
+ builder::sem_type_func_ptr create_rhs_sem_type;
+};
+
+template <typename T>
+constexpr Params ParamsFor() {
+ return Params{DataType<T>::AST, DataType<T>::Expr, DataType<T>::Sem};
+}
+
+TEST_F(ResolverTypeConstructorValidationTest, InferTypeTest_Simple) {
+ // var a = 1;
+ // var b = a;
+ auto* a = Var("a", nullptr, ast::StorageClass::kNone, Expr(1));
+ auto* b = Var("b", nullptr, ast::StorageClass::kNone, Expr("a"));
+ auto* a_ident = Expr("a");
+ auto* b_ident = Expr("b");
+
+ WrapInFunction(a, b, Assign(a_ident, "a"), Assign(b_ident, "b"));
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(TypeOf(a_ident)->Is<sem::Reference>());
+ EXPECT_TRUE(
+ TypeOf(a_ident)->As<sem::Reference>()->StoreType()->Is<sem::I32>());
+ EXPECT_EQ(TypeOf(a_ident)->As<sem::Reference>()->StorageClass(),
+ ast::StorageClass::kFunction);
+ ASSERT_TRUE(TypeOf(b_ident)->Is<sem::Reference>());
+ EXPECT_TRUE(
+ TypeOf(b_ident)->As<sem::Reference>()->StoreType()->Is<sem::I32>());
+ EXPECT_EQ(TypeOf(b_ident)->As<sem::Reference>()->StorageClass(),
+ ast::StorageClass::kFunction);
+}
+
+using InferTypeTest_FromConstructorExpression = ResolverTestWithParam<Params>;
+TEST_P(InferTypeTest_FromConstructorExpression, All) {
+ // e.g. for vec3<f32>
+ // {
+ // var a = vec3<f32>(0.0, 0.0, 0.0)
+ // }
+ auto& params = GetParam();
+
+ auto* constructor_expr = params.create_rhs_ast_value(*this, 0);
+
+ auto* a = Var("a", nullptr, ast::StorageClass::kNone, constructor_expr);
+ // Self-assign 'a' to force the expression to be resolved so we can test its
+ // type below
+ auto* a_ident = Expr("a");
+ WrapInFunction(Decl(a), Assign(a_ident, "a"));
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+ auto* got = TypeOf(a_ident);
+ auto* expected = create<sem::Reference>(params.create_rhs_sem_type(*this),
+ ast::StorageClass::kFunction,
+ ast::Access::kReadWrite);
+ ASSERT_EQ(got, expected) << "got: " << FriendlyName(got) << "\n"
+ << "expected: " << FriendlyName(expected) << "\n";
+}
+
+static constexpr Params from_constructor_expression_cases[] = {
+ ParamsFor<bool>(),
+ ParamsFor<i32>(),
+ ParamsFor<u32>(),
+ ParamsFor<f32>(),
+ ParamsFor<vec3<i32>>(),
+ ParamsFor<vec3<u32>>(),
+ ParamsFor<vec3<f32>>(),
+ ParamsFor<mat3x3<f32>>(),
+ ParamsFor<alias<bool>>(),
+ ParamsFor<alias<i32>>(),
+ ParamsFor<alias<u32>>(),
+ ParamsFor<alias<f32>>(),
+ ParamsFor<alias<vec3<i32>>>(),
+ ParamsFor<alias<vec3<u32>>>(),
+ ParamsFor<alias<vec3<f32>>>(),
+ ParamsFor<alias<mat3x3<f32>>>(),
+};
+INSTANTIATE_TEST_SUITE_P(ResolverTypeConstructorValidationTest,
+ InferTypeTest_FromConstructorExpression,
+ testing::ValuesIn(from_constructor_expression_cases));
+
+using InferTypeTest_FromArithmeticExpression = ResolverTestWithParam<Params>;
+TEST_P(InferTypeTest_FromArithmeticExpression, All) {
+ // e.g. for vec3<f32>
+ // {
+ // var a = vec3<f32>(2.0, 2.0, 2.0) * 3.0;
+ // }
+ auto& params = GetParam();
+
+ auto* arith_lhs_expr = params.create_rhs_ast_value(*this, 2);
+ auto* arith_rhs_expr = params.create_rhs_ast_value(*this, 3);
+ auto* constructor_expr = Mul(arith_lhs_expr, arith_rhs_expr);
+
+ auto* a = Var("a", nullptr, constructor_expr);
+ // Self-assign 'a' to force the expression to be resolved so we can test its
+ // type below
+ auto* a_ident = Expr("a");
+ WrapInFunction(Decl(a), Assign(a_ident, "a"));
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+ auto* got = TypeOf(a_ident);
+ auto* expected = create<sem::Reference>(params.create_rhs_sem_type(*this),
+ ast::StorageClass::kFunction,
+ ast::Access::kReadWrite);
+ ASSERT_EQ(got, expected) << "got: " << FriendlyName(got) << "\n"
+ << "expected: " << FriendlyName(expected) << "\n";
+}
+static constexpr Params from_arithmetic_expression_cases[] = {
+ ParamsFor<i32>(), ParamsFor<u32>(), ParamsFor<f32>(),
+ ParamsFor<vec3<f32>>(), ParamsFor<mat3x3<f32>>(),
+
+ // TODO(amaiorano): Uncomment once https://crbug.com/tint/680 is fixed
+ // ParamsFor<alias<ty_i32>>(),
+ // ParamsFor<alias<ty_u32>>(),
+ // ParamsFor<alias<ty_f32>>(),
+ // ParamsFor<alias<ty_vec3<f32>>>(),
+ // ParamsFor<alias<ty_mat3x3<f32>>>(),
+};
+INSTANTIATE_TEST_SUITE_P(ResolverTypeConstructorValidationTest,
+ InferTypeTest_FromArithmeticExpression,
+ testing::ValuesIn(from_arithmetic_expression_cases));
+
+using InferTypeTest_FromCallExpression = ResolverTestWithParam<Params>;
+TEST_P(InferTypeTest_FromCallExpression, All) {
+ // e.g. for vec3<f32>
+ //
+ // fn foo() -> vec3<f32> {
+ // return vec3<f32>();
+ // }
+ //
+ // fn bar()
+ // {
+ // var a = foo();
+ // }
+ auto& params = GetParam();
+
+ Func("foo", {}, params.create_rhs_ast_type(*this),
+ {Return(Construct(params.create_rhs_ast_type(*this)))}, {});
+
+ auto* a = Var("a", nullptr, Call("foo"));
+ // Self-assign 'a' to force the expression to be resolved so we can test its
+ // type below
+ auto* a_ident = Expr("a");
+ WrapInFunction(Decl(a), Assign(a_ident, "a"));
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+ auto* got = TypeOf(a_ident);
+ auto* expected = create<sem::Reference>(params.create_rhs_sem_type(*this),
+ ast::StorageClass::kFunction,
+ ast::Access::kReadWrite);
+ ASSERT_EQ(got, expected) << "got: " << FriendlyName(got) << "\n"
+ << "expected: " << FriendlyName(expected) << "\n";
+}
+static constexpr Params from_call_expression_cases[] = {
+ ParamsFor<bool>(),
+ ParamsFor<i32>(),
+ ParamsFor<u32>(),
+ ParamsFor<f32>(),
+ ParamsFor<vec3<i32>>(),
+ ParamsFor<vec3<u32>>(),
+ ParamsFor<vec3<f32>>(),
+ ParamsFor<mat3x3<f32>>(),
+ ParamsFor<alias<bool>>(),
+ ParamsFor<alias<i32>>(),
+ ParamsFor<alias<u32>>(),
+ ParamsFor<alias<f32>>(),
+ ParamsFor<alias<vec3<i32>>>(),
+ ParamsFor<alias<vec3<u32>>>(),
+ ParamsFor<alias<vec3<f32>>>(),
+ ParamsFor<alias<mat3x3<f32>>>(),
+};
+INSTANTIATE_TEST_SUITE_P(ResolverTypeConstructorValidationTest,
+ InferTypeTest_FromCallExpression,
+ testing::ValuesIn(from_call_expression_cases));
+
+} // namespace InferTypeTest
+
+namespace ConversionConstructTest {
+enum class Kind {
+ Construct,
+ Conversion,
+};
+
+struct Params {
+ Kind kind;
+ builder::ast_type_func_ptr lhs_type;
+ builder::ast_type_func_ptr rhs_type;
+ builder::ast_expr_func_ptr rhs_value_expr;
+};
+
+template <typename LhsType, typename RhsType>
+constexpr Params ParamsFor(Kind kind) {
+ return Params{kind, DataType<LhsType>::AST, DataType<RhsType>::AST,
+ DataType<RhsType>::Expr};
+}
+
+static constexpr Params valid_cases[] = {
+ // Direct init (non-conversions)
+ ParamsFor<bool, bool>(Kind::Construct), //
+ ParamsFor<i32, i32>(Kind::Construct), //
+ ParamsFor<u32, u32>(Kind::Construct), //
+ ParamsFor<f32, f32>(Kind::Construct), //
+ ParamsFor<vec3<bool>, vec3<bool>>(Kind::Construct), //
+ ParamsFor<vec3<i32>, vec3<i32>>(Kind::Construct), //
+ ParamsFor<vec3<u32>, vec3<u32>>(Kind::Construct), //
+ ParamsFor<vec3<f32>, vec3<f32>>(Kind::Construct), //
+
+ // Splat
+ ParamsFor<vec3<bool>, bool>(Kind::Construct), //
+ ParamsFor<vec3<i32>, i32>(Kind::Construct), //
+ ParamsFor<vec3<u32>, u32>(Kind::Construct), //
+ ParamsFor<vec3<f32>, f32>(Kind::Construct), //
+
+ // Conversion
+ ParamsFor<bool, u32>(Kind::Conversion), //
+ ParamsFor<bool, i32>(Kind::Conversion), //
+ ParamsFor<bool, f32>(Kind::Conversion), //
+
+ ParamsFor<i32, bool>(Kind::Conversion), //
+ ParamsFor<i32, u32>(Kind::Conversion), //
+ ParamsFor<i32, f32>(Kind::Conversion), //
+
+ ParamsFor<u32, bool>(Kind::Conversion), //
+ ParamsFor<u32, i32>(Kind::Conversion), //
+ ParamsFor<u32, f32>(Kind::Conversion), //
+
+ ParamsFor<f32, bool>(Kind::Conversion), //
+ ParamsFor<f32, u32>(Kind::Conversion), //
+ ParamsFor<f32, i32>(Kind::Conversion), //
+
+ ParamsFor<vec3<bool>, vec3<u32>>(Kind::Conversion), //
+ ParamsFor<vec3<bool>, vec3<i32>>(Kind::Conversion), //
+ ParamsFor<vec3<bool>, vec3<f32>>(Kind::Conversion), //
+
+ ParamsFor<vec3<i32>, vec3<bool>>(Kind::Conversion), //
+ ParamsFor<vec3<i32>, vec3<u32>>(Kind::Conversion), //
+ ParamsFor<vec3<i32>, vec3<f32>>(Kind::Conversion), //
+
+ ParamsFor<vec3<u32>, vec3<bool>>(Kind::Conversion), //
+ ParamsFor<vec3<u32>, vec3<i32>>(Kind::Conversion), //
+ ParamsFor<vec3<u32>, vec3<f32>>(Kind::Conversion), //
+
+ ParamsFor<vec3<f32>, vec3<bool>>(Kind::Conversion), //
+ ParamsFor<vec3<f32>, vec3<u32>>(Kind::Conversion), //
+ ParamsFor<vec3<f32>, vec3<i32>>(Kind::Conversion), //
+};
+
+using ConversionConstructorValidTest = ResolverTestWithParam<Params>;
+TEST_P(ConversionConstructorValidTest, All) {
+ auto& params = GetParam();
+
+ // var a : <lhs_type1> = <lhs_type2>(<rhs_type>(<rhs_value_expr>));
+ auto* lhs_type1 = params.lhs_type(*this);
+ auto* lhs_type2 = params.lhs_type(*this);
+ auto* rhs_type = params.rhs_type(*this);
+ auto* rhs_value_expr = params.rhs_value_expr(*this, 0);
+
+ std::stringstream ss;
+ ss << FriendlyName(lhs_type1) << " = " << FriendlyName(lhs_type2) << "("
+ << FriendlyName(rhs_type) << "(<rhs value expr>))";
+ SCOPED_TRACE(ss.str());
+
+ auto* arg = Construct(rhs_type, rhs_value_expr);
+ auto* tc = Construct(lhs_type2, arg);
+ auto* a = Var("a", lhs_type1, ast::StorageClass::kNone, tc);
+
+ // Self-assign 'a' to force the expression to be resolved so we can test its
+ // type below
+ auto* a_ident = Expr("a");
+ WrapInFunction(Decl(a), Assign(a_ident, "a"));
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* call = Sem().Get(tc);
+ ASSERT_NE(call, nullptr);
+ switch (params.kind) {
+ case Kind::Construct: {
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 1u);
+ EXPECT_EQ(ctor->Parameters()[0]->Type(), TypeOf(arg));
+ break;
+ }
+ case Kind::Conversion: {
+ auto* conv = call->Target()->As<sem::TypeConversion>();
+ ASSERT_NE(conv, nullptr);
+ EXPECT_EQ(call->Type(), conv->ReturnType());
+ ASSERT_EQ(conv->Parameters().size(), 1u);
+ EXPECT_EQ(conv->Parameters()[0]->Type(), TypeOf(arg));
+ break;
+ }
+ }
+}
+INSTANTIATE_TEST_SUITE_P(ResolverTypeConstructorValidationTest,
+ ConversionConstructorValidTest,
+ testing::ValuesIn(valid_cases));
+
+constexpr CreatePtrs all_types[] = {
+ CreatePtrsFor<bool>(), //
+ CreatePtrsFor<u32>(), //
+ CreatePtrsFor<i32>(), //
+ CreatePtrsFor<f32>(), //
+ CreatePtrsFor<vec3<bool>>(), //
+ CreatePtrsFor<vec3<i32>>(), //
+ CreatePtrsFor<vec3<u32>>(), //
+ CreatePtrsFor<vec3<f32>>(), //
+ CreatePtrsFor<mat3x3<i32>>(), //
+ CreatePtrsFor<mat3x3<u32>>(), //
+ CreatePtrsFor<mat3x3<f32>>(), //
+ CreatePtrsFor<mat2x3<i32>>(), //
+ CreatePtrsFor<mat2x3<u32>>(), //
+ CreatePtrsFor<mat2x3<f32>>(), //
+ CreatePtrsFor<mat3x2<i32>>(), //
+ CreatePtrsFor<mat3x2<u32>>(), //
+ CreatePtrsFor<mat3x2<f32>>() //
+};
+
+using ConversionConstructorInvalidTest =
+ ResolverTestWithParam<std::tuple<CreatePtrs, // lhs
+ CreatePtrs // rhs
+ >>;
+TEST_P(ConversionConstructorInvalidTest, All) {
+ auto& params = GetParam();
+
+ auto& lhs_params = std::get<0>(params);
+ auto& rhs_params = std::get<1>(params);
+
+ // Skip test for valid cases
+ for (auto& v : valid_cases) {
+ if (v.lhs_type == lhs_params.ast && v.rhs_type == rhs_params.ast &&
+ v.rhs_value_expr == rhs_params.expr) {
+ return;
+ }
+ }
+ // Skip non-conversions
+ if (lhs_params.ast == rhs_params.ast) {
+ return;
+ }
+
+ // var a : <lhs_type1> = <lhs_type2>(<rhs_type>(<rhs_value_expr>));
+ auto* lhs_type1 = lhs_params.ast(*this);
+ auto* lhs_type2 = lhs_params.ast(*this);
+ auto* rhs_type = rhs_params.ast(*this);
+ auto* rhs_value_expr = rhs_params.expr(*this, 0);
+
+ std::stringstream ss;
+ ss << FriendlyName(lhs_type1) << " = " << FriendlyName(lhs_type2) << "("
+ << FriendlyName(rhs_type) << "(<rhs value expr>))";
+ SCOPED_TRACE(ss.str());
+
+ auto* a = Var("a", lhs_type1, ast::StorageClass::kNone,
+ Construct(lhs_type2, Construct(rhs_type, rhs_value_expr)));
+
+ // Self-assign 'a' to force the expression to be resolved so we can test its
+ // type below
+ auto* a_ident = Expr("a");
+ WrapInFunction(Decl(a), Assign(a_ident, "a"));
+
+ ASSERT_FALSE(r()->Resolve());
+}
+INSTANTIATE_TEST_SUITE_P(ResolverTypeConstructorValidationTest,
+ ConversionConstructorInvalidTest,
+ testing::Combine(testing::ValuesIn(all_types),
+ testing::ValuesIn(all_types)));
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ ConversionConstructorInvalid_TooManyInitializers) {
+ auto* a = Var("a", ty.f32(), ast::StorageClass::kNone,
+ Construct(Source{{12, 34}}, ty.f32(), Expr(1.0f), Expr(2.0f)));
+ WrapInFunction(a);
+
+ ASSERT_FALSE(r()->Resolve());
+ ASSERT_EQ(r()->error(),
+ "12:34 error: expected zero or one value in constructor, got 2");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ ConversionConstructorInvalid_InvalidInitializer) {
+ auto* a =
+ Var("a", ty.f32(), ast::StorageClass::kNone,
+ Construct(Source{{12, 34}}, ty.f32(), Construct(ty.array<f32, 4>())));
+ WrapInFunction(a);
+
+ ASSERT_FALSE(r()->Resolve());
+ ASSERT_EQ(r()->error(),
+ "12:34 error: cannot construct 'f32' with a value of type "
+ "'array<f32, 4>'");
+}
+
+} // namespace ConversionConstructTest
+
+namespace ArrayConstructor {
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Array_ZeroValue_Pass) {
+ // array<u32, 10>();
+ auto* tc = array<u32, 10>();
+ WrapInFunction(tc);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* call = Sem().Get(tc);
+ ASSERT_NE(call, nullptr);
+ EXPECT_TRUE(call->Type()->Is<sem::Array>());
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 0u);
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Array_type_match) {
+ // array<u32, 3>(0u, 10u. 20u);
+ auto* tc = array<u32, 3>(Expr(0u), Expr(10u), Expr(20u));
+ WrapInFunction(tc);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* call = Sem().Get(tc);
+ ASSERT_NE(call, nullptr);
+ EXPECT_TRUE(call->Type()->Is<sem::Array>());
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 3u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::U32>());
+ EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::U32>());
+ EXPECT_TRUE(ctor->Parameters()[2]->Type()->Is<sem::U32>());
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Array_type_Mismatch_U32F32) {
+ // array<u32, 3>(0u, 1.0f, 20u);
+ auto* tc = array<u32, 3>(Expr(0u), Expr(Source{{12, 34}}, 1.0f), Expr(20u));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: type in array constructor does not match array type: "
+ "expected 'u32', found 'f32'");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Array_ScalarArgumentTypeMismatch_F32I32) {
+ // array<f32, 1>(1);
+ auto* tc = array<f32, 1>(Expr(Source{{12, 34}}, 1));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: type in array constructor does not match array type: "
+ "expected 'f32', found 'i32'");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Array_ScalarArgumentTypeMismatch_U32I32) {
+ // array<u32, 6>(1, 0u, 0u, 0u, 0u, 0u);
+ auto* tc = array<u32, 1>(Expr(Source{{12, 34}}, 1), Expr(0u), Expr(0u),
+ Expr(0u), Expr(0u));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: type in array constructor does not match array type: "
+ "expected 'u32', found 'i32'");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Array_ScalarArgumentTypeMismatch_Vec2) {
+ // array<i32, 3>(1, vec2<i32>());
+ auto* tc =
+ array<i32, 3>(Expr(1), Construct(Source{{12, 34}}, ty.vec2<i32>()));
+ WrapInFunction(tc);
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: type in array constructor does not match array type: "
+ "expected 'i32', found 'vec2<i32>'");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_ArrayOfVector_SubElemTypeMismatch_I32U32) {
+ // array<vec3<i32>, 2>(vec3<i32>(), vec3<u32>());
+ auto* e0 = vec3<i32>();
+ SetSource(Source::Location({12, 34}));
+ auto* e1 = vec3<u32>();
+ auto* t = Construct(ty.array(ty.vec3<i32>(), 2), e0, e1);
+ WrapInFunction(t);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: type in array constructor does not match array type: "
+ "expected 'vec3<i32>', found 'vec3<u32>'");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_ArrayOfVector_SubElemTypeMismatch_I32Bool) {
+ // array<vec3<i32>, 2>(vec3<i32>(), vec3<bool>(true, true, false));
+ SetSource(Source::Location({12, 34}));
+ auto* e0 = vec3<bool>(true, true, false);
+ auto* e1 = vec3<i32>();
+ auto* t = Construct(ty.array(ty.vec3<i32>(), 2), e0, e1);
+ WrapInFunction(t);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: type in array constructor does not match array type: "
+ "expected 'vec3<i32>', found 'vec3<bool>'");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_ArrayOfArray_SubElemSizeMismatch) {
+ // array<array<i32, 2>, 2>(array<i32, 3>(), array<i32, 2>());
+ SetSource(Source::Location({12, 34}));
+ auto* e0 = array<i32, 3>();
+ auto* e1 = array<i32, 2>();
+ auto* t = Construct(ty.array(ty.array<i32, 2>(), 2), e0, e1);
+ WrapInFunction(t);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: type in array constructor does not match array type: "
+ "expected 'array<i32, 2>', found 'array<i32, 3>'");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_ArrayOfArray_SubElemTypeMismatch) {
+ // array<array<i32, 2>, 2>(array<i32, 2>(), array<u32, 2>());
+ auto* e0 = array<i32, 2>();
+ SetSource(Source::Location({12, 34}));
+ auto* e1 = array<u32, 2>();
+ auto* t = Construct(ty.array(ty.array<i32, 2>(), 2), e0, e1);
+ WrapInFunction(t);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: type in array constructor does not match array type: "
+ "expected 'array<i32, 2>', found 'array<u32, 2>'");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Array_TooFewElements) {
+ // array<i32, 4>(1, 2, 3);
+ SetSource(Source::Location({12, 34}));
+ auto* tc = array<i32, 4>(Expr(1), Expr(2), Expr(3));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: array constructor has too few elements: expected 4, "
+ "found 3");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Array_TooManyElements) {
+ // array<i32, 4>(1, 2, 3, 4, 5);
+ SetSource(Source::Location({12, 34}));
+ auto* tc = array<i32, 4>(Expr(1), Expr(2), Expr(3), Expr(4), Expr(5));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: array constructor has too many "
+ "elements: expected 4, "
+ "found 5");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest, Expr_Constructor_Array_Runtime) {
+ // array<i32>(1);
+ auto* tc = array(ty.i32(), nullptr, Expr(Source{{12, 34}}, 1));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "error: cannot init a runtime-sized array");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Array_RuntimeZeroValue) {
+ // array<i32>();
+ auto* tc = array(ty.i32(), nullptr);
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "error: cannot init a runtime-sized array");
+}
+
+} // namespace ArrayConstructor
+
+namespace ScalarConstructor {
+
+TEST_F(ResolverTypeConstructorValidationTest, Expr_Construct_i32_Success) {
+ auto* expr = Construct<i32>(Expr(123));
+ WrapInFunction(expr);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(expr), nullptr);
+ ASSERT_TRUE(TypeOf(expr)->Is<sem::I32>());
+
+ auto* call = Sem().Get(expr);
+ ASSERT_NE(call, nullptr);
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 1u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::I32>());
+}
+
+TEST_F(ResolverTypeConstructorValidationTest, Expr_Construct_u32_Success) {
+ auto* expr = Construct<u32>(Expr(123u));
+ WrapInFunction(expr);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(expr), nullptr);
+ ASSERT_TRUE(TypeOf(expr)->Is<sem::U32>());
+
+ auto* call = Sem().Get(expr);
+ ASSERT_NE(call, nullptr);
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 1u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::U32>());
+}
+
+TEST_F(ResolverTypeConstructorValidationTest, Expr_Construct_f32_Success) {
+ auto* expr = Construct<f32>(Expr(1.23f));
+ WrapInFunction(expr);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(expr), nullptr);
+ ASSERT_TRUE(TypeOf(expr)->Is<sem::F32>());
+
+ auto* call = Sem().Get(expr);
+ ASSERT_NE(call, nullptr);
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 1u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::F32>());
+}
+
+TEST_F(ResolverTypeConstructorValidationTest, Expr_Convert_f32_to_i32_Success) {
+ auto* expr = Construct<i32>(1.23f);
+ WrapInFunction(expr);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(expr), nullptr);
+ ASSERT_TRUE(TypeOf(expr)->Is<sem::I32>());
+
+ auto* call = Sem().Get(expr);
+ ASSERT_NE(call, nullptr);
+ auto* ctor = call->Target()->As<sem::TypeConversion>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 1u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::F32>());
+}
+
+TEST_F(ResolverTypeConstructorValidationTest, Expr_Convert_i32_to_u32_Success) {
+ auto* expr = Construct<u32>(123);
+ WrapInFunction(expr);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(expr), nullptr);
+ ASSERT_TRUE(TypeOf(expr)->Is<sem::U32>());
+
+ auto* call = Sem().Get(expr);
+ ASSERT_NE(call, nullptr);
+ auto* ctor = call->Target()->As<sem::TypeConversion>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 1u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::I32>());
+}
+
+TEST_F(ResolverTypeConstructorValidationTest, Expr_Convert_u32_to_f32_Success) {
+ auto* expr = Construct<f32>(123u);
+ WrapInFunction(expr);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(expr), nullptr);
+ ASSERT_TRUE(TypeOf(expr)->Is<sem::F32>());
+
+ auto* call = Sem().Get(expr);
+ ASSERT_NE(call, nullptr);
+ auto* ctor = call->Target()->As<sem::TypeConversion>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 1u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::U32>());
+}
+
+} // namespace ScalarConstructor
+
+namespace VectorConstructor {
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec2F32_Error_ScalarArgumentTypeMismatch) {
+ auto* tc = vec2<f32>(Expr(Source{{12, 34}}, 1), 1.0f);
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: type in vector constructor does not match vector "
+ "type: expected 'f32', found 'i32'");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec2U32_Error_ScalarArgumentTypeMismatch) {
+ auto* tc = vec2<u32>(1u, Expr(Source{{12, 34}}, 1));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: type in vector constructor does not match vector "
+ "type: expected 'u32', found 'i32'");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec2I32_Error_ScalarArgumentTypeMismatch) {
+ auto* tc = vec2<i32>(Expr(Source{{12, 34}}, 1u), 1);
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: type in vector constructor does not match vector "
+ "type: expected 'i32', found 'u32'");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec2Bool_Error_ScalarArgumentTypeMismatch) {
+ auto* tc = vec2<bool>(true, Expr(Source{{12, 34}}, 1));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: type in vector constructor does not match vector "
+ "type: expected 'bool', found 'i32'");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec2_Error_Vec3ArgumentCardinalityTooLarge) {
+ auto* tc = vec2<f32>(Construct(Source{{12, 34}}, ty.vec3<f32>()));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: attempted to construct 'vec2<f32>' with 3 component(s)");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec2_Error_Vec4ArgumentCardinalityTooLarge) {
+ auto* tc = vec2<f32>(Construct(Source{{12, 34}}, ty.vec4<f32>()));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: attempted to construct 'vec2<f32>' with 4 component(s)");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec2_Error_TooManyArgumentsScalar) {
+ auto* tc =
+ vec2<f32>(Expr(Source{{12, 34}}, 1.0f), Expr(Source{{12, 40}}, 1.0f),
+ Expr(Source{{12, 46}}, 1.0f));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: attempted to construct 'vec2<f32>' with 3 component(s)");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec2_Error_TooManyArgumentsVector) {
+ auto* tc = vec2<f32>(Construct(Source{{12, 34}}, ty.vec2<f32>()),
+ Construct(Source{{12, 40}}, ty.vec2<f32>()));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: attempted to construct 'vec2<f32>' with 4 component(s)");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec2_Error_TooManyArgumentsVectorAndScalar) {
+ auto* tc = vec2<f32>(Construct(Source{{12, 34}}, ty.vec2<f32>()),
+ Expr(Source{{12, 40}}, 1.0f));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: attempted to construct 'vec2<f32>' with 3 component(s)");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec2_Error_InvalidArgumentType) {
+ auto* tc = vec2<f32>(Construct(Source{{12, 34}}, ty.mat2x2<f32>()));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: expected vector or scalar type in vector "
+ "constructor; found: mat2x2<f32>");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec2_Success_ZeroValue) {
+ auto* tc = vec2<f32>();
+ WrapInFunction(tc);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(tc), nullptr);
+ ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 2u);
+
+ auto* call = Sem().Get(tc);
+ ASSERT_NE(call, nullptr);
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 0u);
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec2F32_Success_Scalar) {
+ auto* tc = vec2<f32>(1.0f, 1.0f);
+ WrapInFunction(tc);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(tc), nullptr);
+ ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 2u);
+
+ auto* call = Sem().Get(tc);
+ ASSERT_NE(call, nullptr);
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 2u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::F32>());
+ EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::F32>());
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec2U32_Success_Scalar) {
+ auto* tc = vec2<u32>(1u, 1u);
+ WrapInFunction(tc);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(tc), nullptr);
+ ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::U32>());
+ EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 2u);
+
+ auto* call = Sem().Get(tc);
+ ASSERT_NE(call, nullptr);
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 2u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::U32>());
+ EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::U32>());
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec2I32_Success_Scalar) {
+ auto* tc = vec2<i32>(1, 1);
+ WrapInFunction(tc);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(tc), nullptr);
+ ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::I32>());
+ EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 2u);
+
+ auto* call = Sem().Get(tc);
+ ASSERT_NE(call, nullptr);
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 2u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::I32>());
+ EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::I32>());
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec2Bool_Success_Scalar) {
+ auto* tc = vec2<bool>(true, false);
+ WrapInFunction(tc);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(tc), nullptr);
+ ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::Bool>());
+ EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 2u);
+
+ auto* call = Sem().Get(tc);
+ ASSERT_NE(call, nullptr);
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 2u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::Bool>());
+ EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::Bool>());
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec2_Success_Identity) {
+ auto* tc = vec2<f32>(vec2<f32>());
+ WrapInFunction(tc);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(tc), nullptr);
+ ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 2u);
+
+ auto* call = Sem().Get(tc);
+ ASSERT_NE(call, nullptr);
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 1u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::Vector>());
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec2_Success_Vec2TypeConversion) {
+ auto* tc = vec2<f32>(vec2<i32>());
+ WrapInFunction(tc);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(tc), nullptr);
+ ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 2u);
+
+ auto* call = Sem().Get(tc);
+ ASSERT_NE(call, nullptr);
+ auto* ctor = call->Target()->As<sem::TypeConversion>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 1u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::Vector>());
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec3F32_Error_ScalarArgumentTypeMismatch) {
+ auto* tc = vec3<f32>(1.0f, 1.0f, Expr(Source{{12, 34}}, 1));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: type in vector constructor does not match vector "
+ "type: expected 'f32', found 'i32'");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec3U32_Error_ScalarArgumentTypeMismatch) {
+ auto* tc = vec3<u32>(1u, Expr(Source{{12, 34}}, 1), 1u);
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: type in vector constructor does not match vector "
+ "type: expected 'u32', found 'i32'");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec3I32_Error_ScalarArgumentTypeMismatch) {
+ auto* tc = vec3<i32>(1, Expr(Source{{12, 34}}, 1u), 1);
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: type in vector constructor does not match vector "
+ "type: expected 'i32', found 'u32'");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec3Bool_Error_ScalarArgumentTypeMismatch) {
+ auto* tc = vec3<bool>(true, Expr(Source{{12, 34}}, 1), false);
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: type in vector constructor does not match vector "
+ "type: expected 'bool', found 'i32'");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec3_Error_Vec4ArgumentCardinalityTooLarge) {
+ auto* tc = vec3<f32>(Construct(Source{{12, 34}}, ty.vec4<f32>()));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: attempted to construct 'vec3<f32>' with 4 component(s)");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec3_Error_TooFewArgumentsScalar) {
+ auto* tc =
+ vec3<f32>(Expr(Source{{12, 34}}, 1.0f), Expr(Source{{12, 40}}, 1.0f));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: attempted to construct 'vec3<f32>' with 2 component(s)");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec3_Error_TooManyArgumentsScalar) {
+ auto* tc =
+ vec3<f32>(Expr(Source{{12, 34}}, 1.0f), Expr(Source{{12, 40}}, 1.0f),
+ Expr(Source{{12, 46}}, 1.0f), Expr(Source{{12, 52}}, 1.0f));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: attempted to construct 'vec3<f32>' with 4 component(s)");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec3_Error_TooFewArgumentsVec2) {
+ auto* tc = vec3<f32>(Construct(Source{{12, 34}}, ty.vec2<f32>()));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: attempted to construct 'vec3<f32>' with 2 component(s)");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec3_Error_TooManyArgumentsVec2) {
+ auto* tc = vec3<f32>(Construct(Source{{12, 34}}, ty.vec2<f32>()),
+ Construct(Source{{12, 40}}, ty.vec2<f32>()));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: attempted to construct 'vec3<f32>' with 4 component(s)");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec3_Error_TooManyArgumentsVec2AndScalar) {
+ auto* tc =
+ vec3<f32>(Construct(Source{{12, 34}}, ty.vec2<f32>()),
+ Expr(Source{{12, 40}}, 1.0f), Expr(Source{{12, 46}}, 1.0f));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: attempted to construct 'vec3<f32>' with 4 component(s)");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec3_Error_TooManyArgumentsVec3) {
+ auto* tc = vec3<f32>(Construct(Source{{12, 34}}, ty.vec3<f32>()),
+ Expr(Source{{12, 40}}, 1.0f));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: attempted to construct 'vec3<f32>' with 4 component(s)");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec3_Error_InvalidArgumentType) {
+ auto* tc = vec3<f32>(Construct(Source{{12, 34}}, ty.mat2x2<f32>()));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: expected vector or scalar type in vector "
+ "constructor; found: mat2x2<f32>");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec3_Success_ZeroValue) {
+ auto* tc = vec3<f32>();
+ WrapInFunction(tc);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(tc), nullptr);
+ ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 3u);
+
+ auto* call = Sem().Get(tc);
+ ASSERT_NE(call, nullptr);
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 0u);
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec3F32_Success_Scalar) {
+ auto* tc = vec3<f32>(1.0f, 1.0f, 1.0f);
+ WrapInFunction(tc);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(tc), nullptr);
+ ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 3u);
+
+ auto* call = Sem().Get(tc);
+ ASSERT_NE(call, nullptr);
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 3u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::F32>());
+ EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::F32>());
+ EXPECT_TRUE(ctor->Parameters()[2]->Type()->Is<sem::F32>());
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec3U32_Success_Scalar) {
+ auto* tc = vec3<u32>(1u, 1u, 1u);
+ WrapInFunction(tc);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(tc), nullptr);
+ ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::U32>());
+ EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 3u);
+
+ auto* call = Sem().Get(tc);
+ ASSERT_NE(call, nullptr);
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 3u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::U32>());
+ EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::U32>());
+ EXPECT_TRUE(ctor->Parameters()[2]->Type()->Is<sem::U32>());
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec3I32_Success_Scalar) {
+ auto* tc = vec3<i32>(1, 1, 1);
+ WrapInFunction(tc);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(tc), nullptr);
+ ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::I32>());
+ EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 3u);
+
+ auto* call = Sem().Get(tc);
+ ASSERT_NE(call, nullptr);
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 3u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::I32>());
+ EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::I32>());
+ EXPECT_TRUE(ctor->Parameters()[2]->Type()->Is<sem::I32>());
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec3Bool_Success_Scalar) {
+ auto* tc = vec3<bool>(true, false, true);
+ WrapInFunction(tc);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(tc), nullptr);
+ ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::Bool>());
+ EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 3u);
+
+ auto* call = Sem().Get(tc);
+ ASSERT_NE(call, nullptr);
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 3u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::Bool>());
+ EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::Bool>());
+ EXPECT_TRUE(ctor->Parameters()[2]->Type()->Is<sem::Bool>());
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec3_Success_Vec2AndScalar) {
+ auto* tc = vec3<f32>(vec2<f32>(), 1.0f);
+ WrapInFunction(tc);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(tc), nullptr);
+ ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 3u);
+
+ auto* call = Sem().Get(tc);
+ ASSERT_NE(call, nullptr);
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 2u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::Vector>());
+ EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::F32>());
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec3_Success_ScalarAndVec2) {
+ auto* tc = vec3<f32>(1.0f, vec2<f32>());
+ WrapInFunction(tc);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(tc), nullptr);
+ ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 3u);
+
+ auto* call = Sem().Get(tc);
+ ASSERT_NE(call, nullptr);
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 2u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::F32>());
+ EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::Vector>());
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec3_Success_Identity) {
+ auto* tc = vec3<f32>(vec3<f32>());
+ WrapInFunction(tc);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(tc), nullptr);
+ ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 3u);
+
+ auto* call = Sem().Get(tc);
+ ASSERT_NE(call, nullptr);
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 1u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::Vector>());
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec3_Success_Vec3TypeConversion) {
+ auto* tc = vec3<f32>(vec3<i32>());
+ WrapInFunction(tc);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(tc), nullptr);
+ ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 3u);
+
+ auto* call = Sem().Get(tc);
+ ASSERT_NE(call, nullptr);
+ auto* ctor = call->Target()->As<sem::TypeConversion>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 1u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::Vector>());
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec4F32_Error_ScalarArgumentTypeMismatch) {
+ auto* tc = vec4<f32>(1.0f, 1.0f, Expr(Source{{12, 34}}, 1), 1.0f);
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: type in vector constructor does not match vector "
+ "type: expected 'f32', found 'i32'");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec4U32_Error_ScalarArgumentTypeMismatch) {
+ auto* tc = vec4<u32>(1u, 1u, Expr(Source{{12, 34}}, 1), 1u);
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: type in vector constructor does not match vector "
+ "type: expected 'u32', found 'i32'");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec4I32_Error_ScalarArgumentTypeMismatch) {
+ auto* tc = vec4<i32>(1, 1, Expr(Source{{12, 34}}, 1u), 1);
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: type in vector constructor does not match vector "
+ "type: expected 'i32', found 'u32'");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec4Bool_Error_ScalarArgumentTypeMismatch) {
+ auto* tc = vec4<bool>(true, false, Expr(Source{{12, 34}}, 1), true);
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: type in vector constructor does not match vector "
+ "type: expected 'bool', found 'i32'");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec4_Error_TooFewArgumentsScalar) {
+ auto* tc =
+ vec4<f32>(Expr(Source{{12, 34}}, 1.0f), Expr(Source{{12, 40}}, 1.0f),
+ Expr(Source{{12, 46}}, 1.0f));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: attempted to construct 'vec4<f32>' with 3 component(s)");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec4_Error_TooManyArgumentsScalar) {
+ auto* tc =
+ vec4<f32>(Expr(Source{{12, 34}}, 1.0f), Expr(Source{{12, 40}}, 1.0f),
+ Expr(Source{{12, 46}}, 1.0f), Expr(Source{{12, 52}}, 1.0f),
+ Expr(Source{{12, 58}}, 1.0f));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: attempted to construct 'vec4<f32>' with 5 component(s)");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec4_Error_TooFewArgumentsVec2AndScalar) {
+ auto* tc = vec4<f32>(Construct(Source{{12, 34}}, ty.vec2<f32>()),
+ Expr(Source{{12, 40}}, 1.0f));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: attempted to construct 'vec4<f32>' with 3 component(s)");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec4_Error_TooManyArgumentsVec2AndScalars) {
+ auto* tc = vec4<f32>(
+ Construct(Source{{12, 34}}, ty.vec2<f32>()), Expr(Source{{12, 40}}, 1.0f),
+ Expr(Source{{12, 46}}, 1.0f), Expr(Source{{12, 52}}, 1.0f));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: attempted to construct 'vec4<f32>' with 5 component(s)");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec4_Error_TooManyArgumentsVec2Vec2Scalar) {
+ auto* tc = vec4<f32>(Construct(Source{{12, 34}}, ty.vec2<f32>()),
+ Construct(Source{{12, 40}}, ty.vec2<f32>()),
+ Expr(Source{{12, 46}}, 1.0f));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: attempted to construct 'vec4<f32>' with 5 component(s)");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec4_Error_TooManyArgumentsVec2Vec2Vec2) {
+ auto* tc = vec4<f32>(Construct(Source{{12, 34}}, ty.vec2<f32>()),
+ Construct(Source{{12, 40}}, ty.vec2<f32>()),
+ Construct(Source{{12, 40}}, ty.vec2<f32>()));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: attempted to construct 'vec4<f32>' with 6 component(s)");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec4_Error_TooFewArgumentsVec3) {
+ auto* tc = vec4<f32>(Construct(Source{{12, 34}}, ty.vec3<f32>()));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: attempted to construct 'vec4<f32>' with 3 component(s)");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec4_Error_TooManyArgumentsVec3AndScalars) {
+ auto* tc =
+ vec4<f32>(Construct(Source{{12, 34}}, ty.vec3<f32>()),
+ Expr(Source{{12, 40}}, 1.0f), Expr(Source{{12, 46}}, 1.0f));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: attempted to construct 'vec4<f32>' with 5 component(s)");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec4_Error_TooManyArgumentsVec3AndVec2) {
+ auto* tc = vec4<f32>(Construct(Source{{12, 34}}, ty.vec3<f32>()),
+ Construct(Source{{12, 40}}, ty.vec2<f32>()));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: attempted to construct 'vec4<f32>' with 5 component(s)");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec4_Error_TooManyArgumentsVec2AndVec3) {
+ auto* tc = vec4<f32>(Construct(Source{{12, 34}}, ty.vec2<f32>()),
+ Construct(Source{{12, 40}}, ty.vec3<f32>()));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: attempted to construct 'vec4<f32>' with 5 component(s)");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec4_Error_TooManyArgumentsVec3AndVec3) {
+ auto* tc = vec4<f32>(Construct(Source{{12, 34}}, ty.vec3<f32>()),
+ Construct(Source{{12, 40}}, ty.vec3<f32>()));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: attempted to construct 'vec4<f32>' with 6 component(s)");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec4_Error_InvalidArgumentType) {
+ auto* tc = vec4<f32>(Construct(Source{{12, 34}}, ty.mat2x2<f32>()));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: expected vector or scalar type in vector "
+ "constructor; found: mat2x2<f32>");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec4_Success_ZeroValue) {
+ auto* tc = vec4<f32>();
+ WrapInFunction(tc);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(tc), nullptr);
+ ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 4u);
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec4F32_Success_Scalar) {
+ auto* tc = vec4<f32>(1.0f, 1.0f, 1.0f, 1.0f);
+ WrapInFunction(tc);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(tc), nullptr);
+ ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 4u);
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec4U32_Success_Scalar) {
+ auto* tc = vec4<u32>(1u, 1u, 1u, 1u);
+ WrapInFunction(tc);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(tc), nullptr);
+ ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::U32>());
+ EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 4u);
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec4I32_Success_Scalar) {
+ auto* tc = vec4<i32>(1, 1, 1, 1);
+ WrapInFunction(tc);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(tc), nullptr);
+ ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::I32>());
+ EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 4u);
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec4Bool_Success_Scalar) {
+ auto* tc = vec4<bool>(true, false, true, false);
+ WrapInFunction(tc);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(tc), nullptr);
+ ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::Bool>());
+ EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 4u);
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec4_Success_Vec2ScalarScalar) {
+ auto* tc = vec4<f32>(vec2<f32>(), 1.0f, 1.0f);
+ WrapInFunction(tc);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(tc), nullptr);
+ ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 4u);
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec4_Success_ScalarVec2Scalar) {
+ auto* tc = vec4<f32>(1.0f, vec2<f32>(), 1.0f);
+ WrapInFunction(tc);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(tc), nullptr);
+ ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 4u);
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec4_Success_ScalarScalarVec2) {
+ auto* tc = vec4<f32>(1.0f, 1.0f, vec2<f32>());
+ WrapInFunction(tc);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(tc), nullptr);
+ ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 4u);
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec4_Success_Vec2AndVec2) {
+ auto* tc = vec4<f32>(vec2<f32>(), vec2<f32>());
+ WrapInFunction(tc);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(tc), nullptr);
+ ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 4u);
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec4_Success_Vec3AndScalar) {
+ auto* tc = vec4<f32>(vec3<f32>(), 1.0f);
+ WrapInFunction(tc);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(tc), nullptr);
+ ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 4u);
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec4_Success_ScalarAndVec3) {
+ auto* tc = vec4<f32>(1.0f, vec3<f32>());
+ WrapInFunction(tc);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(tc), nullptr);
+ ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 4u);
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec4_Success_Identity) {
+ auto* tc = vec4<f32>(vec4<f32>());
+ WrapInFunction(tc);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(tc), nullptr);
+ ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 4u);
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vec4_Success_Vec4TypeConversion) {
+ auto* tc = vec4<f32>(vec4<i32>());
+ WrapInFunction(tc);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(tc), nullptr);
+ ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 4u);
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_NestedVectorConstructors_InnerError) {
+ auto* tc = vec4<f32>(vec4<f32>(1.0f, 1.0f,
+ vec3<f32>(Expr(Source{{12, 34}}, 1.0f),
+ Expr(Source{{12, 34}}, 1.0f))),
+ 1.0f);
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: attempted to construct 'vec3<f32>' with 2 component(s)");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_NestedVectorConstructors_Success) {
+ auto* tc = vec4<f32>(vec3<f32>(vec2<f32>(1.0f, 1.0f), 1.0f), 1.0f);
+ WrapInFunction(tc);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(tc), nullptr);
+ ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 4u);
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vector_Alias_Argument_Error) {
+ auto* alias = Alias("UnsignedInt", ty.u32());
+ Global("uint_var", ty.Of(alias), ast::StorageClass::kPrivate);
+
+ auto* tc = vec2<f32>(Expr(Source{{12, 34}}, "uint_var"));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: type in vector constructor does not match vector "
+ "type: expected 'f32', found 'u32'");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vector_Alias_Argument_Success) {
+ auto* f32_alias = Alias("Float32", ty.f32());
+ auto* vec2_alias = Alias("VectorFloat2", ty.vec2<f32>());
+ Global("my_f32", ty.Of(f32_alias), ast::StorageClass::kPrivate);
+ Global("my_vec2", ty.Of(vec2_alias), ast::StorageClass::kPrivate);
+
+ auto* tc = vec3<f32>("my_vec2", "my_f32");
+ WrapInFunction(tc);
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vector_ElementTypeAlias_Error) {
+ auto* f32_alias = Alias("Float32", ty.f32());
+
+ // vec2<Float32>(1.0f, 1u)
+ auto* vec_type = ty.vec(ty.Of(f32_alias), 2);
+ auto* tc =
+ Construct(Source{{12, 34}}, vec_type, 1.0f, Expr(Source{{12, 40}}, 1u));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:40 error: type in vector constructor does not match vector "
+ "type: expected 'f32', found 'u32'");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vector_ElementTypeAlias_Success) {
+ auto* f32_alias = Alias("Float32", ty.f32());
+
+ // vec2<Float32>(1.0f, 1.0f)
+ auto* vec_type = ty.vec(ty.Of(f32_alias), 2);
+ auto* tc = Construct(Source{{12, 34}}, vec_type, 1.0f, 1.0f);
+ WrapInFunction(tc);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vector_ArgumentElementTypeAlias_Error) {
+ auto* f32_alias = Alias("Float32", ty.f32());
+
+ // vec3<u32>(vec<Float32>(), 1.0f)
+ auto* vec_type = ty.vec(ty.Of(f32_alias), 2);
+ auto* tc = vec3<u32>(Construct(Source{{12, 34}}, vec_type), 1.0f);
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: type in vector constructor does not match vector "
+ "type: expected 'u32', found 'f32'");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_Constructor_Vector_ArgumentElementTypeAlias_Success) {
+ auto* f32_alias = Alias("Float32", ty.f32());
+
+ // vec3<f32>(vec<Float32>(), 1.0f)
+ auto* vec_type = ty.vec(ty.Of(f32_alias), 2);
+ auto* tc = vec3<f32>(Construct(Source{{12, 34}}, vec_type), 1.0f);
+ WrapInFunction(tc);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverTypeConstructorValidationTest, InferVec2ElementTypeFromScalars) {
+ auto* vec2_bool =
+ Construct(create<ast::Vector>(nullptr, 2), Expr(true), Expr(false));
+ auto* vec2_i32 = Construct(create<ast::Vector>(nullptr, 2), Expr(1), Expr(2));
+ auto* vec2_u32 =
+ Construct(create<ast::Vector>(nullptr, 2), Expr(1u), Expr(2u));
+ auto* vec2_f32 =
+ Construct(create<ast::Vector>(nullptr, 2), Expr(1.0f), Expr(2.0f));
+ WrapInFunction(vec2_bool, vec2_i32, vec2_u32, vec2_f32);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_TRUE(TypeOf(vec2_bool)->Is<sem::Vector>());
+ ASSERT_TRUE(TypeOf(vec2_i32)->Is<sem::Vector>());
+ ASSERT_TRUE(TypeOf(vec2_u32)->Is<sem::Vector>());
+ ASSERT_TRUE(TypeOf(vec2_f32)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(vec2_bool)->As<sem::Vector>()->type()->Is<sem::Bool>());
+ EXPECT_TRUE(TypeOf(vec2_i32)->As<sem::Vector>()->type()->Is<sem::I32>());
+ EXPECT_TRUE(TypeOf(vec2_u32)->As<sem::Vector>()->type()->Is<sem::U32>());
+ EXPECT_TRUE(TypeOf(vec2_f32)->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(TypeOf(vec2_bool)->As<sem::Vector>()->Width(), 2u);
+ EXPECT_EQ(TypeOf(vec2_i32)->As<sem::Vector>()->Width(), 2u);
+ EXPECT_EQ(TypeOf(vec2_u32)->As<sem::Vector>()->Width(), 2u);
+ EXPECT_EQ(TypeOf(vec2_f32)->As<sem::Vector>()->Width(), 2u);
+ EXPECT_EQ(TypeOf(vec2_bool), TypeOf(vec2_bool->target.type));
+ EXPECT_EQ(TypeOf(vec2_i32), TypeOf(vec2_i32->target.type));
+ EXPECT_EQ(TypeOf(vec2_u32), TypeOf(vec2_u32->target.type));
+ EXPECT_EQ(TypeOf(vec2_f32), TypeOf(vec2_f32->target.type));
+}
+
+TEST_F(ResolverTypeConstructorValidationTest, InferVec2ElementTypeFromVec2) {
+ auto* vec2_bool =
+ Construct(create<ast::Vector>(nullptr, 2), vec2<bool>(true, false));
+ auto* vec2_i32 = Construct(create<ast::Vector>(nullptr, 2), vec2<i32>(1, 2));
+ auto* vec2_u32 =
+ Construct(create<ast::Vector>(nullptr, 2), vec2<u32>(1u, 2u));
+ auto* vec2_f32 =
+ Construct(create<ast::Vector>(nullptr, 2), vec2<f32>(1.0f, 2.0f));
+ WrapInFunction(vec2_bool, vec2_i32, vec2_u32, vec2_f32);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_TRUE(TypeOf(vec2_bool)->Is<sem::Vector>());
+ ASSERT_TRUE(TypeOf(vec2_i32)->Is<sem::Vector>());
+ ASSERT_TRUE(TypeOf(vec2_u32)->Is<sem::Vector>());
+ ASSERT_TRUE(TypeOf(vec2_f32)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(vec2_bool)->As<sem::Vector>()->type()->Is<sem::Bool>());
+ EXPECT_TRUE(TypeOf(vec2_i32)->As<sem::Vector>()->type()->Is<sem::I32>());
+ EXPECT_TRUE(TypeOf(vec2_u32)->As<sem::Vector>()->type()->Is<sem::U32>());
+ EXPECT_TRUE(TypeOf(vec2_f32)->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(TypeOf(vec2_bool)->As<sem::Vector>()->Width(), 2u);
+ EXPECT_EQ(TypeOf(vec2_i32)->As<sem::Vector>()->Width(), 2u);
+ EXPECT_EQ(TypeOf(vec2_u32)->As<sem::Vector>()->Width(), 2u);
+ EXPECT_EQ(TypeOf(vec2_f32)->As<sem::Vector>()->Width(), 2u);
+ EXPECT_EQ(TypeOf(vec2_bool), TypeOf(vec2_bool->target.type));
+ EXPECT_EQ(TypeOf(vec2_i32), TypeOf(vec2_i32->target.type));
+ EXPECT_EQ(TypeOf(vec2_u32), TypeOf(vec2_u32->target.type));
+ EXPECT_EQ(TypeOf(vec2_f32), TypeOf(vec2_f32->target.type));
+}
+
+TEST_F(ResolverTypeConstructorValidationTest, InferVec3ElementTypeFromScalars) {
+ auto* vec3_bool = Construct(create<ast::Vector>(nullptr, 3), Expr(true),
+ Expr(false), Expr(true));
+ auto* vec3_i32 =
+ Construct(create<ast::Vector>(nullptr, 3), Expr(1), Expr(2), Expr(3));
+ auto* vec3_u32 =
+ Construct(create<ast::Vector>(nullptr, 3), Expr(1u), Expr(2u), Expr(3u));
+ auto* vec3_f32 = Construct(create<ast::Vector>(nullptr, 3), Expr(1.0f),
+ Expr(2.0f), Expr(3.0f));
+ WrapInFunction(vec3_bool, vec3_i32, vec3_u32, vec3_f32);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_TRUE(TypeOf(vec3_bool)->Is<sem::Vector>());
+ ASSERT_TRUE(TypeOf(vec3_i32)->Is<sem::Vector>());
+ ASSERT_TRUE(TypeOf(vec3_u32)->Is<sem::Vector>());
+ ASSERT_TRUE(TypeOf(vec3_f32)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(vec3_bool)->As<sem::Vector>()->type()->Is<sem::Bool>());
+ EXPECT_TRUE(TypeOf(vec3_i32)->As<sem::Vector>()->type()->Is<sem::I32>());
+ EXPECT_TRUE(TypeOf(vec3_u32)->As<sem::Vector>()->type()->Is<sem::U32>());
+ EXPECT_TRUE(TypeOf(vec3_f32)->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(TypeOf(vec3_bool)->As<sem::Vector>()->Width(), 3u);
+ EXPECT_EQ(TypeOf(vec3_i32)->As<sem::Vector>()->Width(), 3u);
+ EXPECT_EQ(TypeOf(vec3_u32)->As<sem::Vector>()->Width(), 3u);
+ EXPECT_EQ(TypeOf(vec3_f32)->As<sem::Vector>()->Width(), 3u);
+ EXPECT_EQ(TypeOf(vec3_bool), TypeOf(vec3_bool->target.type));
+ EXPECT_EQ(TypeOf(vec3_i32), TypeOf(vec3_i32->target.type));
+ EXPECT_EQ(TypeOf(vec3_u32), TypeOf(vec3_u32->target.type));
+ EXPECT_EQ(TypeOf(vec3_f32), TypeOf(vec3_f32->target.type));
+}
+
+TEST_F(ResolverTypeConstructorValidationTest, InferVec3ElementTypeFromVec3) {
+ auto* vec3_bool =
+ Construct(create<ast::Vector>(nullptr, 3), vec3<bool>(true, false, true));
+ auto* vec3_i32 =
+ Construct(create<ast::Vector>(nullptr, 3), vec3<i32>(1, 2, 3));
+ auto* vec3_u32 =
+ Construct(create<ast::Vector>(nullptr, 3), vec3<u32>(1u, 2u, 3u));
+ auto* vec3_f32 =
+ Construct(create<ast::Vector>(nullptr, 3), vec3<f32>(1.0f, 2.0f, 3.0f));
+ WrapInFunction(vec3_bool, vec3_i32, vec3_u32, vec3_f32);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_TRUE(TypeOf(vec3_bool)->Is<sem::Vector>());
+ ASSERT_TRUE(TypeOf(vec3_i32)->Is<sem::Vector>());
+ ASSERT_TRUE(TypeOf(vec3_u32)->Is<sem::Vector>());
+ ASSERT_TRUE(TypeOf(vec3_f32)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(vec3_bool)->As<sem::Vector>()->type()->Is<sem::Bool>());
+ EXPECT_TRUE(TypeOf(vec3_i32)->As<sem::Vector>()->type()->Is<sem::I32>());
+ EXPECT_TRUE(TypeOf(vec3_u32)->As<sem::Vector>()->type()->Is<sem::U32>());
+ EXPECT_TRUE(TypeOf(vec3_f32)->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(TypeOf(vec3_bool)->As<sem::Vector>()->Width(), 3u);
+ EXPECT_EQ(TypeOf(vec3_i32)->As<sem::Vector>()->Width(), 3u);
+ EXPECT_EQ(TypeOf(vec3_u32)->As<sem::Vector>()->Width(), 3u);
+ EXPECT_EQ(TypeOf(vec3_f32)->As<sem::Vector>()->Width(), 3u);
+ EXPECT_EQ(TypeOf(vec3_bool), TypeOf(vec3_bool->target.type));
+ EXPECT_EQ(TypeOf(vec3_i32), TypeOf(vec3_i32->target.type));
+ EXPECT_EQ(TypeOf(vec3_u32), TypeOf(vec3_u32->target.type));
+ EXPECT_EQ(TypeOf(vec3_f32), TypeOf(vec3_f32->target.type));
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ InferVec3ElementTypeFromScalarAndVec2) {
+ auto* vec3_bool = Construct(create<ast::Vector>(nullptr, 3), Expr(true),
+ vec2<bool>(false, true));
+ auto* vec3_i32 =
+ Construct(create<ast::Vector>(nullptr, 3), Expr(1), vec2<i32>(2, 3));
+ auto* vec3_u32 =
+ Construct(create<ast::Vector>(nullptr, 3), Expr(1u), vec2<u32>(2u, 3u));
+ auto* vec3_f32 = Construct(create<ast::Vector>(nullptr, 3), Expr(1.0f),
+ vec2<f32>(2.0f, 3.0f));
+ WrapInFunction(vec3_bool, vec3_i32, vec3_u32, vec3_f32);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_TRUE(TypeOf(vec3_bool)->Is<sem::Vector>());
+ ASSERT_TRUE(TypeOf(vec3_i32)->Is<sem::Vector>());
+ ASSERT_TRUE(TypeOf(vec3_u32)->Is<sem::Vector>());
+ ASSERT_TRUE(TypeOf(vec3_f32)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(vec3_bool)->As<sem::Vector>()->type()->Is<sem::Bool>());
+ EXPECT_TRUE(TypeOf(vec3_i32)->As<sem::Vector>()->type()->Is<sem::I32>());
+ EXPECT_TRUE(TypeOf(vec3_u32)->As<sem::Vector>()->type()->Is<sem::U32>());
+ EXPECT_TRUE(TypeOf(vec3_f32)->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(TypeOf(vec3_bool)->As<sem::Vector>()->Width(), 3u);
+ EXPECT_EQ(TypeOf(vec3_i32)->As<sem::Vector>()->Width(), 3u);
+ EXPECT_EQ(TypeOf(vec3_u32)->As<sem::Vector>()->Width(), 3u);
+ EXPECT_EQ(TypeOf(vec3_f32)->As<sem::Vector>()->Width(), 3u);
+ EXPECT_EQ(TypeOf(vec3_bool), TypeOf(vec3_bool->target.type));
+ EXPECT_EQ(TypeOf(vec3_i32), TypeOf(vec3_i32->target.type));
+ EXPECT_EQ(TypeOf(vec3_u32), TypeOf(vec3_u32->target.type));
+ EXPECT_EQ(TypeOf(vec3_f32), TypeOf(vec3_f32->target.type));
+}
+
+TEST_F(ResolverTypeConstructorValidationTest, InferVec4ElementTypeFromScalars) {
+ auto* vec4_bool = Construct(create<ast::Vector>(nullptr, 4), Expr(true),
+ Expr(false), Expr(true), Expr(false));
+ auto* vec4_i32 = Construct(create<ast::Vector>(nullptr, 4), Expr(1), Expr(2),
+ Expr(3), Expr(4));
+ auto* vec4_u32 = Construct(create<ast::Vector>(nullptr, 4), Expr(1u),
+ Expr(2u), Expr(3u), Expr(4u));
+ auto* vec4_f32 = Construct(create<ast::Vector>(nullptr, 4), Expr(1.0f),
+ Expr(2.0f), Expr(3.0f), Expr(4.0f));
+ WrapInFunction(vec4_bool, vec4_i32, vec4_u32, vec4_f32);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_TRUE(TypeOf(vec4_bool)->Is<sem::Vector>());
+ ASSERT_TRUE(TypeOf(vec4_i32)->Is<sem::Vector>());
+ ASSERT_TRUE(TypeOf(vec4_u32)->Is<sem::Vector>());
+ ASSERT_TRUE(TypeOf(vec4_f32)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(vec4_bool)->As<sem::Vector>()->type()->Is<sem::Bool>());
+ EXPECT_TRUE(TypeOf(vec4_i32)->As<sem::Vector>()->type()->Is<sem::I32>());
+ EXPECT_TRUE(TypeOf(vec4_u32)->As<sem::Vector>()->type()->Is<sem::U32>());
+ EXPECT_TRUE(TypeOf(vec4_f32)->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(TypeOf(vec4_bool)->As<sem::Vector>()->Width(), 4u);
+ EXPECT_EQ(TypeOf(vec4_i32)->As<sem::Vector>()->Width(), 4u);
+ EXPECT_EQ(TypeOf(vec4_u32)->As<sem::Vector>()->Width(), 4u);
+ EXPECT_EQ(TypeOf(vec4_f32)->As<sem::Vector>()->Width(), 4u);
+ EXPECT_EQ(TypeOf(vec4_bool), TypeOf(vec4_bool->target.type));
+ EXPECT_EQ(TypeOf(vec4_i32), TypeOf(vec4_i32->target.type));
+ EXPECT_EQ(TypeOf(vec4_u32), TypeOf(vec4_u32->target.type));
+ EXPECT_EQ(TypeOf(vec4_f32), TypeOf(vec4_f32->target.type));
+}
+
+TEST_F(ResolverTypeConstructorValidationTest, InferVec4ElementTypeFromVec4) {
+ auto* vec4_bool = Construct(create<ast::Vector>(nullptr, 4),
+ vec4<bool>(true, false, true, false));
+ auto* vec4_i32 =
+ Construct(create<ast::Vector>(nullptr, 4), vec4<i32>(1, 2, 3, 4));
+ auto* vec4_u32 =
+ Construct(create<ast::Vector>(nullptr, 4), vec4<u32>(1u, 2u, 3u, 4u));
+ auto* vec4_f32 = Construct(create<ast::Vector>(nullptr, 4),
+ vec4<f32>(1.0f, 2.0f, 3.0f, 4.0f));
+ WrapInFunction(vec4_bool, vec4_i32, vec4_u32, vec4_f32);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_TRUE(TypeOf(vec4_bool)->Is<sem::Vector>());
+ ASSERT_TRUE(TypeOf(vec4_i32)->Is<sem::Vector>());
+ ASSERT_TRUE(TypeOf(vec4_u32)->Is<sem::Vector>());
+ ASSERT_TRUE(TypeOf(vec4_f32)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(vec4_bool)->As<sem::Vector>()->type()->Is<sem::Bool>());
+ EXPECT_TRUE(TypeOf(vec4_i32)->As<sem::Vector>()->type()->Is<sem::I32>());
+ EXPECT_TRUE(TypeOf(vec4_u32)->As<sem::Vector>()->type()->Is<sem::U32>());
+ EXPECT_TRUE(TypeOf(vec4_f32)->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(TypeOf(vec4_bool)->As<sem::Vector>()->Width(), 4u);
+ EXPECT_EQ(TypeOf(vec4_i32)->As<sem::Vector>()->Width(), 4u);
+ EXPECT_EQ(TypeOf(vec4_u32)->As<sem::Vector>()->Width(), 4u);
+ EXPECT_EQ(TypeOf(vec4_f32)->As<sem::Vector>()->Width(), 4u);
+ EXPECT_EQ(TypeOf(vec4_bool), TypeOf(vec4_bool->target.type));
+ EXPECT_EQ(TypeOf(vec4_i32), TypeOf(vec4_i32->target.type));
+ EXPECT_EQ(TypeOf(vec4_u32), TypeOf(vec4_u32->target.type));
+ EXPECT_EQ(TypeOf(vec4_f32), TypeOf(vec4_f32->target.type));
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ InferVec4ElementTypeFromScalarAndVec3) {
+ auto* vec4_bool = Construct(create<ast::Vector>(nullptr, 4), Expr(true),
+ vec3<bool>(false, true, false));
+ auto* vec4_i32 =
+ Construct(create<ast::Vector>(nullptr, 4), Expr(1), vec3<i32>(2, 3, 4));
+ auto* vec4_u32 = Construct(create<ast::Vector>(nullptr, 4), Expr(1u),
+ vec3<u32>(2u, 3u, 4u));
+ auto* vec4_f32 = Construct(create<ast::Vector>(nullptr, 4), Expr(1.0f),
+ vec3<f32>(2.0f, 3.0f, 4.0f));
+ WrapInFunction(vec4_bool, vec4_i32, vec4_u32, vec4_f32);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_TRUE(TypeOf(vec4_bool)->Is<sem::Vector>());
+ ASSERT_TRUE(TypeOf(vec4_i32)->Is<sem::Vector>());
+ ASSERT_TRUE(TypeOf(vec4_u32)->Is<sem::Vector>());
+ ASSERT_TRUE(TypeOf(vec4_f32)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(vec4_bool)->As<sem::Vector>()->type()->Is<sem::Bool>());
+ EXPECT_TRUE(TypeOf(vec4_i32)->As<sem::Vector>()->type()->Is<sem::I32>());
+ EXPECT_TRUE(TypeOf(vec4_u32)->As<sem::Vector>()->type()->Is<sem::U32>());
+ EXPECT_TRUE(TypeOf(vec4_f32)->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(TypeOf(vec4_bool)->As<sem::Vector>()->Width(), 4u);
+ EXPECT_EQ(TypeOf(vec4_i32)->As<sem::Vector>()->Width(), 4u);
+ EXPECT_EQ(TypeOf(vec4_u32)->As<sem::Vector>()->Width(), 4u);
+ EXPECT_EQ(TypeOf(vec4_f32)->As<sem::Vector>()->Width(), 4u);
+ EXPECT_EQ(TypeOf(vec4_bool), TypeOf(vec4_bool->target.type));
+ EXPECT_EQ(TypeOf(vec4_i32), TypeOf(vec4_i32->target.type));
+ EXPECT_EQ(TypeOf(vec4_u32), TypeOf(vec4_u32->target.type));
+ EXPECT_EQ(TypeOf(vec4_f32), TypeOf(vec4_f32->target.type));
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ InferVec4ElementTypeFromVec2AndVec2) {
+ auto* vec4_bool = Construct(create<ast::Vector>(nullptr, 4),
+ vec2<bool>(true, false), vec2<bool>(true, false));
+ auto* vec4_i32 = Construct(create<ast::Vector>(nullptr, 4), vec2<i32>(1, 2),
+ vec2<i32>(3, 4));
+ auto* vec4_u32 = Construct(create<ast::Vector>(nullptr, 4), vec2<u32>(1u, 2u),
+ vec2<u32>(3u, 4u));
+ auto* vec4_f32 = Construct(create<ast::Vector>(nullptr, 4),
+ vec2<f32>(1.0f, 2.0f), vec2<f32>(3.0f, 4.0f));
+ WrapInFunction(vec4_bool, vec4_i32, vec4_u32, vec4_f32);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_TRUE(TypeOf(vec4_bool)->Is<sem::Vector>());
+ ASSERT_TRUE(TypeOf(vec4_i32)->Is<sem::Vector>());
+ ASSERT_TRUE(TypeOf(vec4_u32)->Is<sem::Vector>());
+ ASSERT_TRUE(TypeOf(vec4_f32)->Is<sem::Vector>());
+ EXPECT_TRUE(TypeOf(vec4_bool)->As<sem::Vector>()->type()->Is<sem::Bool>());
+ EXPECT_TRUE(TypeOf(vec4_i32)->As<sem::Vector>()->type()->Is<sem::I32>());
+ EXPECT_TRUE(TypeOf(vec4_u32)->As<sem::Vector>()->type()->Is<sem::U32>());
+ EXPECT_TRUE(TypeOf(vec4_f32)->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(TypeOf(vec4_bool)->As<sem::Vector>()->Width(), 4u);
+ EXPECT_EQ(TypeOf(vec4_i32)->As<sem::Vector>()->Width(), 4u);
+ EXPECT_EQ(TypeOf(vec4_u32)->As<sem::Vector>()->Width(), 4u);
+ EXPECT_EQ(TypeOf(vec4_f32)->As<sem::Vector>()->Width(), 4u);
+ EXPECT_EQ(TypeOf(vec4_bool), TypeOf(vec4_bool->target.type));
+ EXPECT_EQ(TypeOf(vec4_i32), TypeOf(vec4_i32->target.type));
+ EXPECT_EQ(TypeOf(vec4_u32), TypeOf(vec4_u32->target.type));
+ EXPECT_EQ(TypeOf(vec4_f32), TypeOf(vec4_f32->target.type));
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ CannotInferVectorElementTypeWithoutArgs) {
+ WrapInFunction(Construct(create<ast::Vector>(Source{{12, 34}}, nullptr, 3)));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: missing vector element type");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ CannotInferVec2ElementTypeFromScalarsMismatch) {
+ WrapInFunction(Construct(Source{{1, 1}}, create<ast::Vector>(nullptr, 2),
+ Expr(Source{{1, 2}}, 1), //
+ Expr(Source{{1, 3}}, 2u)));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(1:1 error: cannot infer vector element type, as constructor arguments have different types
+1:2 note: argument 0 has type i32
+1:3 note: argument 1 has type u32)");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ CannotInferVec3ElementTypeFromScalarsMismatch) {
+ WrapInFunction(Construct(Source{{1, 1}}, create<ast::Vector>(nullptr, 3),
+ Expr(Source{{1, 2}}, 1), //
+ Expr(Source{{1, 3}}, 2u), //
+ Expr(Source{{1, 4}}, 3)));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(1:1 error: cannot infer vector element type, as constructor arguments have different types
+1:2 note: argument 0 has type i32
+1:3 note: argument 1 has type u32
+1:4 note: argument 2 has type i32)");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ CannotInferVec3ElementTypeFromScalarAndVec2Mismatch) {
+ WrapInFunction(
+ Construct(Source{{1, 1}}, create<ast::Vector>(nullptr, 3),
+ Expr(Source{{1, 2}}, 1), //
+ Construct(Source{{1, 3}}, ty.vec2<f32>(), 2.0f, 3.0f)));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(1:1 error: cannot infer vector element type, as constructor arguments have different types
+1:2 note: argument 0 has type i32
+1:3 note: argument 1 has type vec2<f32>)");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ CannotInferVec4ElementTypeFromScalarsMismatch) {
+ WrapInFunction(Construct(Source{{1, 1}}, create<ast::Vector>(nullptr, 4),
+ Expr(Source{{1, 2}}, 1), //
+ Expr(Source{{1, 3}}, 2), //
+ Expr(Source{{1, 4}}, 3.0f), //
+ Expr(Source{{1, 5}}, 4)));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(1:1 error: cannot infer vector element type, as constructor arguments have different types
+1:2 note: argument 0 has type i32
+1:3 note: argument 1 has type i32
+1:4 note: argument 2 has type f32
+1:5 note: argument 3 has type i32)");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ CannotInferVec4ElementTypeFromScalarAndVec3Mismatch) {
+ WrapInFunction(
+ Construct(Source{{1, 1}}, create<ast::Vector>(nullptr, 4),
+ Expr(Source{{1, 2}}, 1), //
+ Construct(Source{{1, 3}}, ty.vec3<u32>(), 2u, 3u, 4u)));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(1:1 error: cannot infer vector element type, as constructor arguments have different types
+1:2 note: argument 0 has type i32
+1:3 note: argument 1 has type vec3<u32>)");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ CannotInferVec4ElementTypeFromVec2AndVec2Mismatch) {
+ WrapInFunction(Construct(Source{{1, 1}}, create<ast::Vector>(nullptr, 4),
+ Construct(Source{{1, 2}}, ty.vec2<i32>(), 3, 4), //
+ Construct(Source{{1, 3}}, ty.vec2<u32>(), 3u, 4u)));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(1:1 error: cannot infer vector element type, as constructor arguments have different types
+1:2 note: argument 0 has type vec2<i32>
+1:3 note: argument 1 has type vec2<u32>)");
+}
+
+} // namespace VectorConstructor
+
+namespace MatrixConstructor {
+struct MatrixDimensions {
+ uint32_t rows;
+ uint32_t columns;
+};
+
+static std::string MatrixStr(const MatrixDimensions& dimensions) {
+ return "mat" + std::to_string(dimensions.columns) + "x" +
+ std::to_string(dimensions.rows) + "<f32>";
+}
+
+using MatrixConstructorTest = ResolverTestWithParam<MatrixDimensions>;
+
+TEST_P(MatrixConstructorTest, Expr_ColumnConstructor_Error_TooFewArguments) {
+ // matNxM<f32>(vecM<f32>(), ...); with N - 1 arguments
+
+ const auto param = GetParam();
+
+ std::stringstream args_tys;
+ ast::ExpressionList args;
+ for (uint32_t i = 1; i <= param.columns - 1; i++) {
+ auto* vec_type = ty.vec<f32>(param.rows);
+ args.push_back(Construct(Source{{12, i}}, vec_type));
+ if (i > 1) {
+ args_tys << ", ";
+ }
+ args_tys << "vec" << param.rows << "<f32>";
+ }
+
+ auto* matrix_type = ty.mat<f32>(param.columns, param.rows);
+ auto* tc = Construct(Source{}, matrix_type, std::move(args));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_THAT(r()->error(), HasSubstr("12:1 error: no matching constructor " +
+ MatrixStr(param) + "(" + args_tys.str() +
+ ")\n\n3 candidates available:"));
+}
+
+TEST_P(MatrixConstructorTest, Expr_ElementConstructor_Error_TooFewArguments) {
+ // matNxM<f32>(f32,...,f32); with N*M - 1 arguments
+
+ const auto param = GetParam();
+
+ std::stringstream args_tys;
+ ast::ExpressionList args;
+ for (uint32_t i = 1; i <= param.columns * param.rows - 1; i++) {
+ args.push_back(Construct(Source{{12, i}}, ty.f32()));
+ if (i > 1) {
+ args_tys << ", ";
+ }
+ args_tys << "f32";
+ }
+
+ auto* matrix_type = ty.mat<f32>(param.columns, param.rows);
+ auto* tc = Construct(Source{}, matrix_type, std::move(args));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_THAT(r()->error(), HasSubstr("12:1 error: no matching constructor " +
+ MatrixStr(param) + "(" + args_tys.str() +
+ ")\n\n3 candidates available:"));
+}
+
+TEST_P(MatrixConstructorTest, Expr_ColumnConstructor_Error_TooManyArguments) {
+ // matNxM<f32>(vecM<f32>(), ...); with N + 1 arguments
+
+ const auto param = GetParam();
+
+ std::stringstream args_tys;
+ ast::ExpressionList args;
+ for (uint32_t i = 1; i <= param.columns + 1; i++) {
+ auto* vec_type = ty.vec<f32>(param.rows);
+ args.push_back(Construct(Source{{12, i}}, vec_type));
+ if (i > 1) {
+ args_tys << ", ";
+ }
+ args_tys << "vec" << param.rows << "<f32>";
+ }
+
+ auto* matrix_type = ty.mat<f32>(param.columns, param.rows);
+ auto* tc = Construct(Source{}, matrix_type, std::move(args));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_THAT(r()->error(), HasSubstr("12:1 error: no matching constructor " +
+ MatrixStr(param) + "(" + args_tys.str() +
+ ")\n\n3 candidates available:"));
+}
+
+TEST_P(MatrixConstructorTest, Expr_ElementConstructor_Error_TooManyArguments) {
+ // matNxM<f32>(f32,...,f32); with N*M + 1 arguments
+
+ const auto param = GetParam();
+
+ std::stringstream args_tys;
+ ast::ExpressionList args;
+ for (uint32_t i = 1; i <= param.columns * param.rows + 1; i++) {
+ args.push_back(Construct(Source{{12, i}}, ty.f32()));
+ if (i > 1) {
+ args_tys << ", ";
+ }
+ args_tys << "f32";
+ }
+
+ auto* matrix_type = ty.mat<f32>(param.columns, param.rows);
+ auto* tc = Construct(Source{}, matrix_type, std::move(args));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_THAT(r()->error(), HasSubstr("12:1 error: no matching constructor " +
+ MatrixStr(param) + "(" + args_tys.str() +
+ ")\n\n3 candidates available:"));
+}
+
+TEST_P(MatrixConstructorTest,
+ Expr_ColumnConstructor_Error_InvalidArgumentType) {
+ // matNxM<f32>(vec<u32>, vec<u32>, ...); N arguments
+
+ const auto param = GetParam();
+
+ std::stringstream args_tys;
+ ast::ExpressionList args;
+ for (uint32_t i = 1; i <= param.columns; i++) {
+ auto* vec_type = ty.vec<u32>(param.rows);
+ args.push_back(Construct(Source{{12, i}}, vec_type));
+ if (i > 1) {
+ args_tys << ", ";
+ }
+ args_tys << "vec" << param.rows << "<u32>";
+ }
+
+ auto* matrix_type = ty.mat<f32>(param.columns, param.rows);
+ auto* tc = Construct(Source{}, matrix_type, std::move(args));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_THAT(r()->error(), HasSubstr("12:1 error: no matching constructor " +
+ MatrixStr(param) + "(" + args_tys.str() +
+ ")\n\n3 candidates available:"));
+}
+
+TEST_P(MatrixConstructorTest,
+ Expr_ElementConstructor_Error_InvalidArgumentType) {
+ // matNxM<f32>(u32, u32, ...); N*M arguments
+
+ const auto param = GetParam();
+
+ std::stringstream args_tys;
+ ast::ExpressionList args;
+ for (uint32_t i = 1; i <= param.columns; i++) {
+ args.push_back(Expr(Source{{12, i}}, 1u));
+ if (i > 1) {
+ args_tys << ", ";
+ }
+ args_tys << "u32";
+ }
+
+ auto* matrix_type = ty.mat<f32>(param.columns, param.rows);
+ auto* tc = Construct(Source{}, matrix_type, std::move(args));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_THAT(r()->error(), HasSubstr("12:1 error: no matching constructor " +
+ MatrixStr(param) + "(" + args_tys.str() +
+ ")\n\n3 candidates available:"));
+}
+
+TEST_P(MatrixConstructorTest,
+ Expr_ColumnConstructor_Error_TooFewRowsInVectorArgument) {
+ // matNxM<f32>(vecM<f32>(),...,vecM-1<f32>());
+
+ const auto param = GetParam();
+
+ // Skip the test if parameters would have resulted in an invalid vec1 type.
+ if (param.rows == 2) {
+ return;
+ }
+
+ std::stringstream args_tys;
+ ast::ExpressionList args;
+ for (uint32_t i = 1; i <= param.columns - 1; i++) {
+ auto* valid_vec_type = ty.vec<f32>(param.rows);
+ args.push_back(Construct(Source{{12, i}}, valid_vec_type));
+ if (i > 1) {
+ args_tys << ", ";
+ }
+ args_tys << "vec" << param.rows << "<f32>";
+ }
+ const size_t kInvalidLoc = 2 * (param.columns - 1);
+ auto* invalid_vec_type = ty.vec<f32>(param.rows - 1);
+ args.push_back(Construct(Source{{12, kInvalidLoc}}, invalid_vec_type));
+ args_tys << ", vec" << (param.rows - 1) << "<f32>";
+
+ auto* matrix_type = ty.mat<f32>(param.columns, param.rows);
+ auto* tc = Construct(Source{}, matrix_type, std::move(args));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_THAT(r()->error(), HasSubstr("12:1 error: no matching constructor " +
+ MatrixStr(param) + "(" + args_tys.str() +
+ ")\n\n3 candidates available:"));
+}
+
+TEST_P(MatrixConstructorTest,
+ Expr_ColumnConstructor_Error_TooManyRowsInVectorArgument) {
+ // matNxM<f32>(vecM<f32>(),...,vecM+1<f32>());
+
+ const auto param = GetParam();
+
+ // Skip the test if parameters would have resulted in an invalid vec5 type.
+ if (param.rows == 4) {
+ return;
+ }
+
+ std::stringstream args_tys;
+ ast::ExpressionList args;
+ for (uint32_t i = 1; i <= param.columns - 1; i++) {
+ auto* valid_vec_type = ty.vec<f32>(param.rows);
+ args.push_back(Construct(Source{{12, i}}, valid_vec_type));
+ if (i > 1) {
+ args_tys << ", ";
+ }
+ args_tys << "vec" << param.rows << "<f32>";
+ }
+ const size_t kInvalidLoc = 2 * (param.columns - 1);
+ auto* invalid_vec_type = ty.vec<f32>(param.rows + 1);
+ args.push_back(Construct(Source{{12, kInvalidLoc}}, invalid_vec_type));
+ args_tys << ", vec" << (param.rows + 1) << "<f32>";
+
+ auto* matrix_type = ty.mat<f32>(param.columns, param.rows);
+ auto* tc = Construct(Source{}, matrix_type, std::move(args));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_THAT(r()->error(), HasSubstr("12:1 error: no matching constructor " +
+ MatrixStr(param) + "(" + args_tys.str() +
+ ")\n\n3 candidates available:"));
+}
+
+TEST_P(MatrixConstructorTest, Expr_Constructor_ZeroValue_Success) {
+ // matNxM<f32>();
+
+ const auto param = GetParam();
+ auto* matrix_type = ty.mat<f32>(param.columns, param.rows);
+ auto* tc = Construct(Source{{12, 40}}, matrix_type);
+ WrapInFunction(tc);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_P(MatrixConstructorTest, Expr_Constructor_WithColumns_Success) {
+ // matNxM<f32>(vecM<f32>(), ...); with N arguments
+
+ const auto param = GetParam();
+
+ ast::ExpressionList args;
+ for (uint32_t i = 1; i <= param.columns; i++) {
+ auto* vec_type = ty.vec<f32>(param.rows);
+ args.push_back(Construct(Source{{12, i}}, vec_type));
+ }
+
+ auto* matrix_type = ty.mat<f32>(param.columns, param.rows);
+ auto* tc = Construct(Source{}, matrix_type, std::move(args));
+ WrapInFunction(tc);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_P(MatrixConstructorTest, Expr_Constructor_WithElements_Success) {
+ // matNxM<f32>(f32,...,f32); with N*M arguments
+
+ const auto param = GetParam();
+
+ ast::ExpressionList args;
+ for (uint32_t i = 1; i <= param.columns * param.rows; i++) {
+ args.push_back(Construct(Source{{12, i}}, ty.f32()));
+ }
+
+ auto* matrix_type = ty.mat<f32>(param.columns, param.rows);
+ auto* tc = Construct(Source{}, matrix_type, std::move(args));
+ WrapInFunction(tc);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_P(MatrixConstructorTest, Expr_Constructor_ElementTypeAlias_Error) {
+ // matNxM<Float32>(vecM<u32>(), ...); with N arguments
+
+ const auto param = GetParam();
+ auto* f32_alias = Alias("Float32", ty.f32());
+
+ std::stringstream args_tys;
+ ast::ExpressionList args;
+ for (uint32_t i = 1; i <= param.columns; i++) {
+ auto* vec_type = ty.vec(ty.u32(), param.rows);
+ args.push_back(Construct(Source{{12, i}}, vec_type));
+ if (i > 1) {
+ args_tys << ", ";
+ }
+ args_tys << "vec" << param.rows << "<u32>";
+ }
+
+ auto* matrix_type = ty.mat(ty.Of(f32_alias), param.columns, param.rows);
+ auto* tc = Construct(Source{}, matrix_type, std::move(args));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_THAT(r()->error(), HasSubstr("12:1 error: no matching constructor " +
+ MatrixStr(param) + "(" + args_tys.str() +
+ ")\n\n3 candidates available:"));
+}
+
+TEST_P(MatrixConstructorTest, Expr_Constructor_ElementTypeAlias_Success) {
+ // matNxM<Float32>(vecM<f32>(), ...); with N arguments
+
+ const auto param = GetParam();
+ auto* f32_alias = Alias("Float32", ty.f32());
+
+ ast::ExpressionList args;
+ for (uint32_t i = 1; i <= param.columns; i++) {
+ auto* vec_type = ty.vec<f32>(param.rows);
+ args.push_back(Construct(Source{{12, i}}, vec_type));
+ }
+
+ auto* matrix_type = ty.mat(ty.Of(f32_alias), param.columns, param.rows);
+ auto* tc = Construct(Source{}, matrix_type, std::move(args));
+ WrapInFunction(tc);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ Expr_MatrixConstructor_ArgumentTypeAlias_Error) {
+ auto* alias = Alias("VectorUnsigned2", ty.vec2<u32>());
+ auto* tc =
+ mat2x2<f32>(Construct(Source{{12, 34}}, ty.Of(alias)), vec2<f32>());
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: no matching constructor mat2x2<f32>(vec2<u32>, vec2<f32>)
+
+3 candidates available:
+ mat2x2<f32>()
+ mat2x2<f32>(f32,...,f32) // 4 arguments
+ mat2x2<f32>(vec2<f32>, vec2<f32>)
+)");
+}
+
+TEST_P(MatrixConstructorTest, Expr_Constructor_ArgumentTypeAlias_Success) {
+ const auto param = GetParam();
+ auto* matrix_type = ty.mat<f32>(param.columns, param.rows);
+ auto* vec_type = ty.vec<f32>(param.rows);
+ auto* vec_alias = Alias("VectorFloat2", vec_type);
+
+ ast::ExpressionList args;
+ for (uint32_t i = 1; i <= param.columns; i++) {
+ args.push_back(Construct(Source{{12, i}}, ty.Of(vec_alias)));
+ }
+
+ auto* tc = Construct(Source{}, matrix_type, std::move(args));
+ WrapInFunction(tc);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_P(MatrixConstructorTest, Expr_Constructor_ArgumentElementTypeAlias_Error) {
+ const auto param = GetParam();
+ auto* matrix_type = ty.mat<f32>(param.columns, param.rows);
+ auto* f32_alias = Alias("UnsignedInt", ty.u32());
+
+ std::stringstream args_tys;
+ ast::ExpressionList args;
+ for (uint32_t i = 1; i <= param.columns; i++) {
+ auto* vec_type = ty.vec(ty.Of(f32_alias), param.rows);
+ args.push_back(Construct(Source{{12, i}}, vec_type));
+ if (i > 1) {
+ args_tys << ", ";
+ }
+ args_tys << "vec" << param.rows << "<u32>";
+ }
+
+ auto* tc = Construct(Source{}, matrix_type, std::move(args));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_THAT(r()->error(), HasSubstr("12:1 error: no matching constructor " +
+ MatrixStr(param) + "(" + args_tys.str() +
+ ")\n\n3 candidates available:"));
+}
+
+TEST_P(MatrixConstructorTest,
+ Expr_Constructor_ArgumentElementTypeAlias_Success) {
+ const auto param = GetParam();
+ auto* f32_alias = Alias("Float32", ty.f32());
+
+ ast::ExpressionList args;
+ for (uint32_t i = 1; i <= param.columns; i++) {
+ auto* vec_type = ty.vec(ty.Of(f32_alias), param.rows);
+ args.push_back(Construct(Source{{12, i}}, vec_type));
+ }
+
+ auto* matrix_type = ty.mat<f32>(param.columns, param.rows);
+ auto* tc = Construct(Source{}, matrix_type, std::move(args));
+ WrapInFunction(tc);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_P(MatrixConstructorTest, InferElementTypeFromVectors) {
+ const auto param = GetParam();
+
+ ast::ExpressionList args;
+ for (uint32_t i = 1; i <= param.columns; i++) {
+ args.push_back(Construct(ty.vec<f32>(param.rows)));
+ }
+
+ auto* matrix_type = create<ast::Matrix>(nullptr, param.rows, param.columns);
+ auto* tc = Construct(Source{}, matrix_type, std::move(args));
+ WrapInFunction(tc);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_P(MatrixConstructorTest, InferElementTypeFromScalars) {
+ const auto param = GetParam();
+
+ ast::ExpressionList args;
+ for (uint32_t i = 0; i < param.rows * param.columns; i++) {
+ args.push_back(Expr(static_cast<f32>(i)));
+ }
+
+ auto* matrix_type = create<ast::Matrix>(nullptr, param.rows, param.columns);
+ WrapInFunction(Construct(Source{{12, 34}}, matrix_type, std::move(args)));
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_P(MatrixConstructorTest, CannotInferElementTypeFromVectors_Mismatch) {
+ const auto param = GetParam();
+
+ std::stringstream err;
+ err << "12:34 error: cannot infer matrix element type, as constructor "
+ "arguments have different types";
+
+ ast::ExpressionList args;
+ for (uint32_t i = 0; i < param.columns; i++) {
+ err << "\n";
+ auto src = Source{{1, 10 + i}};
+ if (i == 1) {
+ // Odd one out
+ args.push_back(Construct(src, ty.vec<i32>(param.rows)));
+ err << src << " note: argument " << i << " has type vec" << param.rows
+ << "<i32>";
+ } else {
+ args.push_back(Construct(src, ty.vec<f32>(param.rows)));
+ err << src << " note: argument " << i << " has type vec" << param.rows
+ << "<f32>";
+ }
+ }
+
+ auto* matrix_type = create<ast::Matrix>(nullptr, param.rows, param.columns);
+ WrapInFunction(Construct(Source{{12, 34}}, matrix_type, std::move(args)));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_THAT(r()->error(), err.str());
+}
+
+TEST_P(MatrixConstructorTest, CannotInferElementTypeFromScalars_Mismatch) {
+ const auto param = GetParam();
+
+ std::stringstream err;
+ err << "12:34 error: cannot infer matrix element type, as constructor "
+ "arguments have different types";
+ ast::ExpressionList args;
+ for (uint32_t i = 0; i < param.rows * param.columns; i++) {
+ err << "\n";
+ auto src = Source{{1, 10 + i}};
+ if (i == 3) {
+ args.push_back(Expr(src, static_cast<i32>(i))); // The odd one out
+ err << src << " note: argument " << i << " has type i32";
+ } else {
+ args.push_back(Expr(src, static_cast<f32>(i)));
+ err << src << " note: argument " << i << " has type f32";
+ }
+ }
+
+ auto* matrix_type = create<ast::Matrix>(nullptr, param.rows, param.columns);
+ WrapInFunction(Construct(Source{{12, 34}}, matrix_type, std::move(args)));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_THAT(r()->error(), err.str());
+}
+
+INSTANTIATE_TEST_SUITE_P(ResolverTypeConstructorValidationTest,
+ MatrixConstructorTest,
+ testing::Values(MatrixDimensions{2, 2},
+ MatrixDimensions{3, 2},
+ MatrixDimensions{4, 2},
+ MatrixDimensions{2, 3},
+ MatrixDimensions{3, 3},
+ MatrixDimensions{4, 3},
+ MatrixDimensions{2, 4},
+ MatrixDimensions{3, 4},
+ MatrixDimensions{4, 4}));
+} // namespace MatrixConstructor
+
+namespace StructConstructor {
+using builder::CreatePtrs;
+using builder::CreatePtrsFor;
+using builder::f32;
+using builder::i32;
+using builder::mat2x2;
+using builder::mat3x3;
+using builder::mat4x4;
+using builder::u32;
+using builder::vec2;
+using builder::vec3;
+using builder::vec4;
+
+constexpr CreatePtrs all_types[] = {
+ CreatePtrsFor<bool>(), //
+ CreatePtrsFor<u32>(), //
+ CreatePtrsFor<i32>(), //
+ CreatePtrsFor<f32>(), //
+ CreatePtrsFor<vec4<bool>>(), //
+ CreatePtrsFor<vec2<i32>>(), //
+ CreatePtrsFor<vec3<u32>>(), //
+ CreatePtrsFor<vec4<f32>>(), //
+ CreatePtrsFor<mat2x2<f32>>(), //
+ CreatePtrsFor<mat3x3<f32>>(), //
+ CreatePtrsFor<mat4x4<f32>>() //
+};
+
+auto number_of_members = testing::Values(2u, 32u, 64u);
+
+using StructConstructorInputsTest =
+ ResolverTestWithParam<std::tuple<CreatePtrs, // struct member type
+ uint32_t>>; // number of struct members
+TEST_P(StructConstructorInputsTest, TooFew) {
+ auto& param = GetParam();
+ auto& str_params = std::get<0>(param);
+ uint32_t N = std::get<1>(param);
+
+ ast::StructMemberList members;
+ ast::ExpressionList values;
+ for (uint32_t i = 0; i < N; i++) {
+ auto* struct_type = str_params.ast(*this);
+ members.push_back(Member("member_" + std::to_string(i), struct_type));
+ if (i < N - 1) {
+ auto* ctor_value_expr = str_params.expr(*this, 0);
+ values.push_back(ctor_value_expr);
+ }
+ }
+ auto* s = Structure("s", members);
+ auto* tc = Construct(Source{{12, 34}}, ty.Of(s), values);
+ WrapInFunction(tc);
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: struct constructor has too few inputs: expected " +
+ std::to_string(N) + ", found " + std::to_string(N - 1));
+}
+
+TEST_P(StructConstructorInputsTest, TooMany) {
+ auto& param = GetParam();
+ auto& str_params = std::get<0>(param);
+ uint32_t N = std::get<1>(param);
+
+ ast::StructMemberList members;
+ ast::ExpressionList values;
+ for (uint32_t i = 0; i < N + 1; i++) {
+ if (i < N) {
+ auto* struct_type = str_params.ast(*this);
+ members.push_back(Member("member_" + std::to_string(i), struct_type));
+ }
+ auto* ctor_value_expr = str_params.expr(*this, 0);
+ values.push_back(ctor_value_expr);
+ }
+ auto* s = Structure("s", members);
+ auto* tc = Construct(Source{{12, 34}}, ty.Of(s), values);
+ WrapInFunction(tc);
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: struct constructor has too many inputs: expected " +
+ std::to_string(N) + ", found " + std::to_string(N + 1));
+}
+
+INSTANTIATE_TEST_SUITE_P(ResolverTypeConstructorValidationTest,
+ StructConstructorInputsTest,
+ testing::Combine(testing::ValuesIn(all_types),
+ number_of_members));
+using StructConstructorTypeTest =
+ ResolverTestWithParam<std::tuple<CreatePtrs, // struct member type
+ CreatePtrs, // constructor value type
+ uint32_t>>; // number of struct members
+TEST_P(StructConstructorTypeTest, AllTypes) {
+ auto& param = GetParam();
+ auto& str_params = std::get<0>(param);
+ auto& ctor_params = std::get<1>(param);
+ uint32_t N = std::get<2>(param);
+
+ if (str_params.ast == ctor_params.ast) {
+ return;
+ }
+
+ ast::StructMemberList members;
+ ast::ExpressionList values;
+ // make the last value of the constructor to have a different type
+ uint32_t constructor_value_with_different_type = N - 1;
+ for (uint32_t i = 0; i < N; i++) {
+ auto* struct_type = str_params.ast(*this);
+ members.push_back(Member("member_" + std::to_string(i), struct_type));
+ auto* ctor_value_expr = (i == constructor_value_with_different_type)
+ ? ctor_params.expr(*this, 0)
+ : str_params.expr(*this, 0);
+ values.push_back(ctor_value_expr);
+ }
+ auto* s = Structure("s", members);
+ auto* tc = Construct(ty.Of(s), values);
+ WrapInFunction(tc);
+
+ std::string found = FriendlyName(ctor_params.ast(*this));
+ std::string expected = FriendlyName(str_params.ast(*this));
+ std::stringstream err;
+ err << "error: type in struct constructor does not match struct member ";
+ err << "type: expected '" << expected << "', found '" << found << "'";
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), err.str());
+}
+
+INSTANTIATE_TEST_SUITE_P(ResolverTypeConstructorValidationTest,
+ StructConstructorTypeTest,
+ testing::Combine(testing::ValuesIn(all_types),
+ testing::ValuesIn(all_types),
+ number_of_members));
+
+TEST_F(ResolverTypeConstructorValidationTest, Expr_Constructor_Struct_Nested) {
+ auto* inner_m = Member("m", ty.i32());
+ auto* inner_s = Structure("inner_s", {inner_m});
+
+ auto* m0 = Member("m0", ty.i32());
+ auto* m1 = Member("m1", ty.Of(inner_s));
+ auto* m2 = Member("m2", ty.i32());
+ auto* s = Structure("s", {m0, m1, m2});
+
+ auto* tc = Construct(Source{{12, 34}}, ty.Of(s), 1, 1, 1);
+ WrapInFunction(tc);
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "error: type in struct constructor does not match struct member "
+ "type: expected 'inner_s', found 'i32'");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest, Expr_Constructor_Struct) {
+ auto* m = Member("m", ty.i32());
+ auto* s = Structure("MyInputs", {m});
+ auto* tc = Construct(Source{{12, 34}}, ty.Of(s));
+ WrapInFunction(tc);
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverTypeConstructorValidationTest, Expr_Constructor_Struct_Empty) {
+ auto* str = Structure("S", {
+ Member("a", ty.i32()),
+ Member("b", ty.f32()),
+ Member("c", ty.vec3<i32>()),
+ });
+
+ WrapInFunction(Construct(ty.Of(str)));
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+} // namespace StructConstructor
+
+TEST_F(ResolverTypeConstructorValidationTest, NonConstructibleType_Atomic) {
+ WrapInFunction(
+ Assign(Phony(), Construct(Source{{12, 34}}, ty.atomic(ty.i32()))));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: type is not constructible");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ NonConstructibleType_AtomicArray) {
+ WrapInFunction(Assign(
+ Phony(), Construct(Source{{12, 34}}, ty.array(ty.atomic(ty.i32()), 4))));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: array constructor has non-constructible element type");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest,
+ NonConstructibleType_AtomicStructMember) {
+ auto* str = Structure("S", {Member("a", ty.atomic(ty.i32()))});
+ WrapInFunction(Assign(Phony(), Construct(Source{{12, 34}}, ty.Of(str))));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: struct constructor has non-constructible type");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest, NonConstructibleType_Sampler) {
+ WrapInFunction(Assign(
+ Phony(),
+ Construct(Source{{12, 34}}, ty.sampler(ast::SamplerKind::kSampler))));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: type is not constructible");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest, TypeConstructorAsStatement) {
+ WrapInFunction(
+ CallStmt(Construct(Source{{12, 34}}, ty.vec2<f32>(), 1.f, 2.f)));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: type constructor evaluated but not used");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest, TypeConversionAsStatement) {
+ WrapInFunction(CallStmt(Construct(Source{{12, 34}}, ty.f32(), 1)));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: type cast evaluated but not used");
+}
+
+} // namespace
+} // namespace resolver
+} // namespace tint
diff --git a/src/tint/resolver/type_validation_test.cc b/src/tint/resolver/type_validation_test.cc
new file mode 100644
index 0000000..1d41858
--- /dev/null
+++ b/src/tint/resolver/type_validation_test.cc
@@ -0,0 +1,1163 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ast/id_attribute.h"
+#include "src/tint/ast/return_statement.h"
+#include "src/tint/ast/stage_attribute.h"
+#include "src/tint/ast/struct_block_attribute.h"
+#include "src/tint/resolver/resolver.h"
+#include "src/tint/resolver/resolver_test_helper.h"
+#include "src/tint/sem/multisampled_texture_type.h"
+#include "src/tint/sem/storage_texture_type.h"
+
+#include "gmock/gmock.h"
+
+namespace tint {
+namespace resolver {
+namespace {
+
+// Helpers and typedefs
+template <typename T>
+using DataType = builder::DataType<T>;
+template <typename T>
+using vec2 = builder::vec2<T>;
+template <typename T>
+using vec3 = builder::vec3<T>;
+template <typename T>
+using vec4 = builder::vec4<T>;
+template <typename T>
+using mat2x2 = builder::mat2x2<T>;
+template <typename T>
+using mat3x3 = builder::mat3x3<T>;
+template <typename T>
+using mat4x4 = builder::mat4x4<T>;
+template <int N, typename T>
+using array = builder::array<N, T>;
+template <typename T>
+using alias = builder::alias<T>;
+template <typename T>
+using alias1 = builder::alias1<T>;
+template <typename T>
+using alias2 = builder::alias2<T>;
+template <typename T>
+using alias3 = builder::alias3<T>;
+using f32 = builder::f32;
+using i32 = builder::i32;
+using u32 = builder::u32;
+
+class ResolverTypeValidationTest : public resolver::TestHelper,
+ public testing::Test {};
+
+TEST_F(ResolverTypeValidationTest, VariableDeclNoConstructor_Pass) {
+ // {
+ // var a :i32;
+ // a = 2;
+ // }
+ auto* var = Var("a", ty.i32(), ast::StorageClass::kNone, nullptr);
+ auto* lhs = Expr("a");
+ auto* rhs = Expr(2);
+
+ auto* body =
+ Block(Decl(var), Assign(Source{Source::Location{12, 34}}, lhs, rhs));
+
+ WrapInFunction(body);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_NE(TypeOf(lhs), nullptr);
+ ASSERT_NE(TypeOf(rhs), nullptr);
+}
+
+TEST_F(ResolverTypeValidationTest, GlobalConstantNoConstructor_Pass) {
+ // @id(0) override a :i32;
+ Override(Source{{12, 34}}, "a", ty.i32(), nullptr, ast::AttributeList{Id(0)});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverTypeValidationTest, GlobalVariableWithStorageClass_Pass) {
+ // var<private> global_var: f32;
+ Global(Source{{12, 34}}, "global_var", ty.f32(), ast::StorageClass::kPrivate);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverTypeValidationTest, GlobalConstantWithStorageClass_Fail) {
+ // const<private> global_var: f32;
+ AST().AddGlobalVariable(create<ast::Variable>(
+ Source{{12, 34}}, Symbols().Register("global_var"),
+ ast::StorageClass::kPrivate, ast::Access::kUndefined, ty.f32(), true,
+ false, Expr(1.23f), ast::AttributeList{}));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: global constants shouldn't have a storage class");
+}
+
+TEST_F(ResolverTypeValidationTest, GlobalConstNoStorageClass_Pass) {
+ // let global_var: f32;
+ GlobalConst(Source{{12, 34}}, "global_var", ty.f32(), Construct(ty.f32()));
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverTypeValidationTest, GlobalVariableUnique_Pass) {
+ // var global_var0 : f32 = 0.1;
+ // var global_var1 : i32 = 0;
+
+ Global("global_var0", ty.f32(), ast::StorageClass::kPrivate, Expr(0.1f));
+
+ Global(Source{{12, 34}}, "global_var1", ty.f32(), ast::StorageClass::kPrivate,
+ Expr(1.0f));
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverTypeValidationTest,
+ GlobalVariableFunctionVariableNotUnique_Pass) {
+ // fn my_func() {
+ // var a: f32 = 2.0;
+ // }
+ // var a: f32 = 2.1;
+
+ auto* var = Var("a", ty.f32(), ast::StorageClass::kNone, Expr(2.0f));
+
+ Func("my_func", ast::VariableList{}, ty.void_(), {Decl(var)});
+
+ Global("a", ty.f32(), ast::StorageClass::kPrivate, Expr(2.1f));
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverTypeValidationTest, RedeclaredIdentifierInnerScope_Pass) {
+ // {
+ // if (true) { var a : f32 = 2.0; }
+ // var a : f32 = 3.14;
+ // }
+ auto* var = Var("a", ty.f32(), ast::StorageClass::kNone, Expr(2.0f));
+
+ auto* cond = Expr(true);
+ auto* body = Block(Decl(var));
+
+ auto* var_a_float = Var("a", ty.f32(), ast::StorageClass::kNone, Expr(3.1f));
+
+ auto* outer_body =
+ Block(create<ast::IfStatement>(cond, body, ast::ElseStatementList{}),
+ Decl(Source{{12, 34}}, var_a_float));
+
+ WrapInFunction(outer_body);
+
+ EXPECT_TRUE(r()->Resolve());
+}
+
+TEST_F(ResolverTypeValidationTest, RedeclaredIdentifierInnerScopeBlock_Pass) {
+ // {
+ // { var a : f32; }
+ // var a : f32;
+ // }
+ auto* var_inner = Var("a", ty.f32(), ast::StorageClass::kNone);
+ auto* inner = Block(Decl(Source{{12, 34}}, var_inner));
+
+ auto* var_outer = Var("a", ty.f32(), ast::StorageClass::kNone);
+ auto* outer_body = Block(inner, Decl(var_outer));
+
+ WrapInFunction(outer_body);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverTypeValidationTest,
+ RedeclaredIdentifierDifferentFunctions_Pass) {
+ // func0 { var a : f32 = 2.0; return; }
+ // func1 { var a : f32 = 3.0; return; }
+ auto* var0 = Var("a", ty.f32(), ast::StorageClass::kNone, Expr(2.0f));
+
+ auto* var1 = Var("a", ty.f32(), ast::StorageClass::kNone, Expr(1.0f));
+
+ Func("func0", ast::VariableList{}, ty.void_(),
+ ast::StatementList{
+ Decl(Source{{12, 34}}, var0),
+ Return(),
+ },
+ ast::AttributeList{});
+
+ Func("func1", ast::VariableList{}, ty.void_(),
+ ast::StatementList{
+ Decl(Source{{13, 34}}, var1),
+ Return(),
+ });
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverTypeValidationTest, ArraySize_UnsignedLiteral_Pass) {
+ // var<private> a : array<f32, 4u>;
+ Global("a", ty.array(ty.f32(), Expr(Source{{12, 34}}, 4u)),
+ ast::StorageClass::kPrivate);
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverTypeValidationTest, ArraySize_SignedLiteral_Pass) {
+ // var<private> a : array<f32, 4>;
+ Global("a", ty.array(ty.f32(), Expr(Source{{12, 34}}, 4)),
+ ast::StorageClass::kPrivate);
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverTypeValidationTest, ArraySize_UnsignedConstant_Pass) {
+ // let size = 4u;
+ // var<private> a : array<f32, size>;
+ GlobalConst("size", nullptr, Expr(4u));
+ Global("a", ty.array(ty.f32(), Expr(Source{{12, 34}}, "size")),
+ ast::StorageClass::kPrivate);
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverTypeValidationTest, ArraySize_SignedConstant_Pass) {
+ // let size = 4;
+ // var<private> a : array<f32, size>;
+ GlobalConst("size", nullptr, Expr(4));
+ Global("a", ty.array(ty.f32(), Expr(Source{{12, 34}}, "size")),
+ ast::StorageClass::kPrivate);
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverTypeValidationTest, ArraySize_UnsignedLiteral_Zero) {
+ // var<private> a : array<f32, 0u>;
+ Global("a", ty.array(ty.f32(), Expr(Source{{12, 34}}, 0u)),
+ ast::StorageClass::kPrivate);
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: array size must be at least 1");
+}
+
+TEST_F(ResolverTypeValidationTest, ArraySize_SignedLiteral_Zero) {
+ // var<private> a : array<f32, 0>;
+ Global("a", ty.array(ty.f32(), Expr(Source{{12, 34}}, 0)),
+ ast::StorageClass::kPrivate);
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: array size must be at least 1");
+}
+
+TEST_F(ResolverTypeValidationTest, ArraySize_SignedLiteral_Negative) {
+ // var<private> a : array<f32, -10>;
+ Global("a", ty.array(ty.f32(), Expr(Source{{12, 34}}, -10)),
+ ast::StorageClass::kPrivate);
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: array size must be at least 1");
+}
+
+TEST_F(ResolverTypeValidationTest, ArraySize_UnsignedConstant_Zero) {
+ // let size = 0u;
+ // var<private> a : array<f32, size>;
+ GlobalConst("size", nullptr, Expr(0u));
+ Global("a", ty.array(ty.f32(), Expr(Source{{12, 34}}, "size")),
+ ast::StorageClass::kPrivate);
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: array size must be at least 1");
+}
+
+TEST_F(ResolverTypeValidationTest, ArraySize_SignedConstant_Zero) {
+ // let size = 0;
+ // var<private> a : array<f32, size>;
+ GlobalConst("size", nullptr, Expr(0));
+ Global("a", ty.array(ty.f32(), Expr(Source{{12, 34}}, "size")),
+ ast::StorageClass::kPrivate);
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: array size must be at least 1");
+}
+
+TEST_F(ResolverTypeValidationTest, ArraySize_SignedConstant_Negative) {
+ // let size = -10;
+ // var<private> a : array<f32, size>;
+ GlobalConst("size", nullptr, Expr(-10));
+ Global("a", ty.array(ty.f32(), Expr(Source{{12, 34}}, "size")),
+ ast::StorageClass::kPrivate);
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: array size must be at least 1");
+}
+
+TEST_F(ResolverTypeValidationTest, ArraySize_FloatLiteral) {
+ // var<private> a : array<f32, 10.0>;
+ Global("a", ty.array(ty.f32(), Expr(Source{{12, 34}}, 10.f)),
+ ast::StorageClass::kPrivate);
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: array size must be integer scalar");
+}
+
+TEST_F(ResolverTypeValidationTest, ArraySize_IVecLiteral) {
+ // var<private> a : array<f32, vec2<i32>(10, 10)>;
+ Global(
+ "a",
+ ty.array(ty.f32(), Construct(Source{{12, 34}}, ty.vec2<i32>(), 10, 10)),
+ ast::StorageClass::kPrivate);
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: array size must be integer scalar");
+}
+
+TEST_F(ResolverTypeValidationTest, ArraySize_FloatConstant) {
+ // let size = 10.0;
+ // var<private> a : array<f32, size>;
+ GlobalConst("size", nullptr, Expr(10.f));
+ Global("a", ty.array(ty.f32(), Expr(Source{{12, 34}}, "size")),
+ ast::StorageClass::kPrivate);
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: array size must be integer scalar");
+}
+
+TEST_F(ResolverTypeValidationTest, ArraySize_IVecConstant) {
+ // let size = vec2<i32>(100, 100);
+ // var<private> a : array<f32, size>;
+ GlobalConst("size", nullptr, Construct(ty.vec2<i32>(), 100, 100));
+ Global("a", ty.array(ty.f32(), Expr(Source{{12, 34}}, "size")),
+ ast::StorageClass::kPrivate);
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: array size must be integer scalar");
+}
+
+TEST_F(ResolverTypeValidationTest, ArraySize_TooBig_ImplicitStride) {
+ // var<private> a : array<f32, 0x40000000>;
+ Global("a", ty.array(Source{{12, 34}}, ty.f32(), 0x40000000),
+ ast::StorageClass::kPrivate);
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: array size in bytes must not exceed 0xffffffff, but "
+ "is 0x100000000");
+}
+
+TEST_F(ResolverTypeValidationTest, ArraySize_TooBig_ExplicitStride) {
+ // var<private> a : @stride(8) array<f32, 0x20000000>;
+ Global("a", ty.array(Source{{12, 34}}, ty.f32(), 0x20000000, 8),
+ ast::StorageClass::kPrivate);
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: array size in bytes must not exceed 0xffffffff, but "
+ "is 0x100000000");
+}
+
+TEST_F(ResolverTypeValidationTest, ArraySize_OverridableConstant) {
+ // override size = 10;
+ // var<private> a : array<f32, size>;
+ Override("size", nullptr, Expr(10));
+ Global("a", ty.array(ty.f32(), Expr(Source{{12, 34}}, "size")),
+ ast::StorageClass::kPrivate);
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: array size expression must not be pipeline-overridable");
+}
+
+TEST_F(ResolverTypeValidationTest, ArraySize_ModuleVar) {
+ // var<private> size : i32 = 10;
+ // var<private> a : array<f32, size>;
+ Global("size", ty.i32(), Expr(10), ast::StorageClass::kPrivate);
+ Global("a", ty.array(ty.f32(), Expr(Source{{12, 34}}, "size")),
+ ast::StorageClass::kPrivate);
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: array size identifier must be a module-scope constant");
+}
+
+TEST_F(ResolverTypeValidationTest, ArraySize_FunctionConstant) {
+ // {
+ // let size = 10;
+ // var a : array<f32, size>;
+ // }
+ auto* size = Const("size", nullptr, Expr(10));
+ auto* a = Var("a", ty.array(ty.f32(), Expr(Source{{12, 34}}, "size")));
+ WrapInFunction(Block(Decl(size), Decl(a)));
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: array size identifier must be a module-scope constant");
+}
+
+TEST_F(ResolverTypeValidationTest, ArraySize_InvalidExpr) {
+ // var a : array<f32, i32(4)>;
+ auto* size = Const("size", nullptr, Expr(10));
+ auto* a =
+ Var("a", ty.array(ty.f32(), Construct(Source{{12, 34}}, ty.i32(), 4)));
+ WrapInFunction(Block(Decl(size), Decl(a)));
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: array size expression must be either a literal or a "
+ "module-scope constant");
+}
+
+TEST_F(ResolverTypeValidationTest, RuntimeArrayInFunction_Fail) {
+ /// @stage(vertex)
+ // fn func() { var a : array<i32>; }
+
+ auto* var =
+ Var(Source{{12, 34}}, "a", ty.array<i32>(), ast::StorageClass::kNone);
+
+ Func("func", ast::VariableList{}, ty.void_(),
+ ast::StatementList{
+ Decl(var),
+ },
+ ast::AttributeList{
+ Stage(ast::PipelineStage::kVertex),
+ });
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: runtime-sized arrays can only be used in the <storage> storage class
+12:34 note: while instantiating variable a)");
+}
+
+TEST_F(ResolverTypeValidationTest, Struct_Member_VectorNoType) {
+ // struct S {
+ // a: vec3;
+ // };
+
+ Structure("S",
+ {Member("a", create<ast::Vector>(Source{{12, 34}}, nullptr, 3))});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: missing vector element type");
+}
+
+TEST_F(ResolverTypeValidationTest, Struct_Member_MatrixNoType) {
+ // struct S {
+ // a: mat3x3;
+ // };
+ Structure(
+ "S", {Member("a", create<ast::Matrix>(Source{{12, 34}}, nullptr, 3, 3))});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: missing matrix element type");
+}
+
+TEST_F(ResolverTypeValidationTest, Struct_TooBig) {
+ // struct Foo {
+ // a: array<f32, 0x20000000>;
+ // b: array<f32, 0x20000000>;
+ // };
+
+ Structure(Source{{12, 34}}, "Foo",
+ {
+ Member("a", ty.array<f32, 0x20000000>()),
+ Member("b", ty.array<f32, 0x20000000>()),
+ });
+
+ WrapInFunction();
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: struct size in bytes must not exceed 0xffffffff, but "
+ "is 0x100000000");
+}
+
+TEST_F(ResolverTypeValidationTest, Struct_MemberOffset_TooBig) {
+ // struct Foo {
+ // a: array<f32, 0x3fffffff>;
+ // b: f32;
+ // c: f32;
+ // };
+
+ Structure("Foo", {
+ Member("a", ty.array<f32, 0x3fffffff>()),
+ Member("b", ty.f32()),
+ Member(Source{{12, 34}}, "c", ty.f32()),
+ });
+
+ WrapInFunction();
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: struct member has byte offset 0x100000000, but must "
+ "not exceed 0xffffffff");
+}
+
+TEST_F(ResolverTypeValidationTest, RuntimeArrayIsLast_Pass) {
+ // [[block]]
+ // struct Foo {
+ // vf: f32;
+ // rt: array<f32>;
+ // };
+
+ Structure("Foo",
+ {
+ Member("vf", ty.f32()),
+ Member("rt", ty.array<f32>()),
+ },
+ {create<ast::StructBlockAttribute>()});
+
+ WrapInFunction();
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverTypeValidationTest, RuntimeArrayInArray) {
+ // struct Foo {
+ // rt : array<array<f32>, 4>;
+ // };
+
+ Structure("Foo",
+ {Member("rt", ty.array(Source{{12, 34}}, ty.array<f32>(), 4))});
+
+ EXPECT_FALSE(r()->Resolve()) << r()->error();
+ EXPECT_EQ(r()->error(),
+ "12:34 error: an array element type cannot contain a runtime-sized "
+ "array");
+}
+
+TEST_F(ResolverTypeValidationTest, RuntimeArrayInStructInArray) {
+ // struct Foo {
+ // rt : array<f32>;
+ // };
+ // var<private> a : array<Foo, 4>;
+
+ auto* foo = Structure("Foo", {Member("rt", ty.array<f32>())});
+ Global("v", ty.array(Source{{12, 34}}, ty.Of(foo), 4),
+ ast::StorageClass::kPrivate);
+
+ EXPECT_FALSE(r()->Resolve()) << r()->error();
+ EXPECT_EQ(r()->error(),
+ "12:34 error: an array element type cannot contain a runtime-sized "
+ "array");
+}
+
+TEST_F(ResolverTypeValidationTest, RuntimeArrayInStructInStruct) {
+ // struct Foo {
+ // rt : array<f32>;
+ // };
+ // struct Outer {
+ // inner : Foo;
+ // };
+
+ auto* foo = Structure("Foo", {Member("rt", ty.array<f32>())});
+ Structure("Outer", {Member(Source{{12, 34}}, "inner", ty.Of(foo))});
+
+ EXPECT_FALSE(r()->Resolve()) << r()->error();
+ EXPECT_EQ(r()->error(),
+ "12:34 error: a struct that contains a runtime array cannot be "
+ "nested inside another struct");
+}
+
+TEST_F(ResolverTypeValidationTest, RuntimeArrayIsNotLast_Fail) {
+ // [[block]]
+ // struct Foo {
+ // rt: array<f32>;
+ // vf: f32;
+ // };
+
+ Structure("Foo",
+ {
+ Member(Source{{12, 34}}, "rt", ty.array<f32>()),
+ Member("vf", ty.f32()),
+ },
+ {create<ast::StructBlockAttribute>()});
+
+ WrapInFunction();
+
+ EXPECT_FALSE(r()->Resolve()) << r()->error();
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: runtime arrays may only appear as the last member of a struct)");
+}
+
+TEST_F(ResolverTypeValidationTest, RuntimeArrayAsGlobalVariable) {
+ Global(Source{{56, 78}}, "g", ty.array<i32>(), ast::StorageClass::kPrivate);
+
+ ASSERT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(
+ r()->error(),
+ R"(56:78 error: runtime-sized arrays can only be used in the <storage> storage class
+56:78 note: while instantiating variable g)");
+}
+
+TEST_F(ResolverTypeValidationTest, RuntimeArrayAsLocalVariable) {
+ auto* v = Var(Source{{56, 78}}, "g", ty.array<i32>());
+ WrapInFunction(v);
+
+ ASSERT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(
+ r()->error(),
+ R"(56:78 error: runtime-sized arrays can only be used in the <storage> storage class
+56:78 note: while instantiating variable g)");
+}
+
+TEST_F(ResolverTypeValidationTest, RuntimeArrayAsParameter_Fail) {
+ // fn func(a : array<u32>) {}
+ // @stage(vertex) fn main() {}
+
+ auto* param = Param(Source{{12, 34}}, "a", ty.array<i32>());
+
+ Func("func", ast::VariableList{param}, ty.void_(),
+ ast::StatementList{
+ Return(),
+ },
+ ast::AttributeList{});
+
+ Func("main", ast::VariableList{}, ty.void_(),
+ ast::StatementList{
+ Return(),
+ },
+ ast::AttributeList{
+ Stage(ast::PipelineStage::kVertex),
+ });
+
+ EXPECT_FALSE(r()->Resolve()) << r()->error();
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: runtime-sized arrays can only be used in the <storage> storage class
+12:34 note: while instantiating parameter a)");
+}
+
+TEST_F(ResolverTypeValidationTest, PtrToRuntimeArrayAsParameter_Fail) {
+ // fn func(a : ptr<workgroup, array<u32>>) {}
+
+ auto* param =
+ Param(Source{{12, 34}}, "a",
+ ty.pointer(ty.array<i32>(), ast::StorageClass::kWorkgroup));
+
+ Func("func", ast::VariableList{param}, ty.void_(),
+ ast::StatementList{
+ Return(),
+ },
+ ast::AttributeList{});
+
+ EXPECT_FALSE(r()->Resolve()) << r()->error();
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: runtime-sized arrays can only be used in the <storage> storage class
+12:34 note: while instantiating parameter a)");
+}
+
+TEST_F(ResolverTypeValidationTest, AliasRuntimeArrayIsNotLast_Fail) {
+ // [[block]]
+ // type RTArr = array<u32>;
+ // struct s {
+ // b: RTArr;
+ // a: u32;
+ //}
+
+ auto* alias = Alias("RTArr", ty.array<u32>());
+ Structure("s",
+ {
+ Member(Source{{12, 34}}, "b", ty.Of(alias)),
+ Member("a", ty.u32()),
+ },
+ {create<ast::StructBlockAttribute>()});
+
+ WrapInFunction();
+
+ EXPECT_FALSE(r()->Resolve()) << r()->error();
+ EXPECT_EQ(r()->error(),
+ "12:34 error: runtime arrays may only appear as the last member of "
+ "a struct");
+}
+
+TEST_F(ResolverTypeValidationTest, AliasRuntimeArrayIsLast_Pass) {
+ // [[block]]
+ // type RTArr = array<u32>;
+ // struct s {
+ // a: u32;
+ // b: RTArr;
+ //}
+
+ auto* alias = Alias("RTArr", ty.array<u32>());
+ Structure("s",
+ {
+ Member("a", ty.u32()),
+ Member("b", ty.Of(alias)),
+ },
+ {create<ast::StructBlockAttribute>()});
+
+ WrapInFunction();
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverTypeValidationTest, ArrayOfNonStorableType) {
+ auto* tex_ty = ty.sampled_texture(ast::TextureDimension::k2d, ty.f32());
+ Global("arr", ty.array(Source{{12, 34}}, tex_ty, 4),
+ ast::StorageClass::kPrivate);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: texture_2d<f32> cannot be used as an element type of "
+ "an array");
+}
+
+TEST_F(ResolverTypeValidationTest, VariableAsType) {
+ // var<private> a : i32;
+ // var<private> b : a;
+ Global("a", ty.i32(), ast::StorageClass::kPrivate);
+ Global("b", ty.type_name("a"), ast::StorageClass::kPrivate);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(error: cannot use variable 'a' as type
+note: 'a' declared here)");
+}
+
+TEST_F(ResolverTypeValidationTest, FunctionAsType) {
+ // fn f() {}
+ // var<private> v : f;
+ Func("f", {}, ty.void_(), {});
+ Global("v", ty.type_name("f"), ast::StorageClass::kPrivate);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(error: cannot use function 'f' as type
+note: 'f' declared here)");
+}
+
+namespace GetCanonicalTests {
+struct Params {
+ builder::ast_type_func_ptr create_ast_type;
+ builder::sem_type_func_ptr create_sem_type;
+};
+
+template <typename T>
+constexpr Params ParamsFor() {
+ return Params{DataType<T>::AST, DataType<T>::Sem};
+}
+
+static constexpr Params cases[] = {
+ ParamsFor<bool>(),
+ ParamsFor<alias<bool>>(),
+ ParamsFor<alias1<alias<bool>>>(),
+
+ ParamsFor<vec3<f32>>(),
+ ParamsFor<alias<vec3<f32>>>(),
+ ParamsFor<alias1<alias<vec3<f32>>>>(),
+
+ ParamsFor<vec3<alias<f32>>>(),
+ ParamsFor<alias1<vec3<alias<f32>>>>(),
+ ParamsFor<alias2<alias1<vec3<alias<f32>>>>>(),
+ ParamsFor<alias3<alias2<vec3<alias1<alias<f32>>>>>>(),
+
+ ParamsFor<mat3x3<alias<f32>>>(),
+ ParamsFor<alias1<mat3x3<alias<f32>>>>(),
+ ParamsFor<alias2<alias1<mat3x3<alias<f32>>>>>(),
+ ParamsFor<alias3<alias2<mat3x3<alias1<alias<f32>>>>>>(),
+
+ ParamsFor<alias1<alias<bool>>>(),
+ ParamsFor<alias1<alias<vec3<f32>>>>(),
+ ParamsFor<alias1<alias<mat3x3<f32>>>>(),
+};
+
+using CanonicalTest = ResolverTestWithParam<Params>;
+TEST_P(CanonicalTest, All) {
+ auto& params = GetParam();
+
+ auto* type = params.create_ast_type(*this);
+
+ auto* var = Var("v", type);
+ auto* expr = Expr("v");
+ WrapInFunction(var, expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* got = TypeOf(expr)->UnwrapRef();
+ auto* expected = params.create_sem_type(*this);
+
+ EXPECT_EQ(got, expected) << "got: " << FriendlyName(got) << "\n"
+ << "expected: " << FriendlyName(expected) << "\n";
+}
+INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest,
+ CanonicalTest,
+ testing::ValuesIn(cases));
+
+} // namespace GetCanonicalTests
+
+namespace MultisampledTextureTests {
+struct DimensionParams {
+ ast::TextureDimension dim;
+ bool is_valid;
+};
+
+static constexpr DimensionParams dimension_cases[] = {
+ DimensionParams{ast::TextureDimension::k1d, false},
+ DimensionParams{ast::TextureDimension::k2d, true},
+ DimensionParams{ast::TextureDimension::k2dArray, false},
+ DimensionParams{ast::TextureDimension::k3d, false},
+ DimensionParams{ast::TextureDimension::kCube, false},
+ DimensionParams{ast::TextureDimension::kCubeArray, false}};
+
+using MultisampledTextureDimensionTest = ResolverTestWithParam<DimensionParams>;
+TEST_P(MultisampledTextureDimensionTest, All) {
+ auto& params = GetParam();
+ Global(Source{{12, 34}}, "a", ty.multisampled_texture(params.dim, ty.i32()),
+ ast::StorageClass::kNone, nullptr,
+ ast::AttributeList{GroupAndBinding(0, 0)});
+
+ if (params.is_valid) {
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ } else {
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: only 2d multisampled textures are supported");
+ }
+}
+INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest,
+ MultisampledTextureDimensionTest,
+ testing::ValuesIn(dimension_cases));
+
+struct TypeParams {
+ builder::ast_type_func_ptr type_func;
+ bool is_valid;
+};
+
+template <typename T>
+constexpr TypeParams TypeParamsFor(bool is_valid) {
+ return TypeParams{DataType<T>::AST, is_valid};
+}
+
+static constexpr TypeParams type_cases[] = {
+ TypeParamsFor<bool>(false),
+ TypeParamsFor<i32>(true),
+ TypeParamsFor<u32>(true),
+ TypeParamsFor<f32>(true),
+
+ TypeParamsFor<alias<bool>>(false),
+ TypeParamsFor<alias<i32>>(true),
+ TypeParamsFor<alias<u32>>(true),
+ TypeParamsFor<alias<f32>>(true),
+
+ TypeParamsFor<vec3<f32>>(false),
+ TypeParamsFor<mat3x3<f32>>(false),
+
+ TypeParamsFor<alias<vec3<f32>>>(false),
+ TypeParamsFor<alias<mat3x3<f32>>>(false),
+};
+
+using MultisampledTextureTypeTest = ResolverTestWithParam<TypeParams>;
+TEST_P(MultisampledTextureTypeTest, All) {
+ auto& params = GetParam();
+ Global(Source{{12, 34}}, "a",
+ ty.multisampled_texture(ast::TextureDimension::k2d,
+ params.type_func(*this)),
+ ast::StorageClass::kNone, nullptr,
+ ast::AttributeList{GroupAndBinding(0, 0)});
+
+ if (params.is_valid) {
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ } else {
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: texture_multisampled_2d<type>: type must be f32, "
+ "i32 or u32");
+ }
+}
+INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest,
+ MultisampledTextureTypeTest,
+ testing::ValuesIn(type_cases));
+
+} // namespace MultisampledTextureTests
+
+namespace StorageTextureTests {
+struct DimensionParams {
+ ast::TextureDimension dim;
+ bool is_valid;
+};
+
+static constexpr DimensionParams Dimension_cases[] = {
+ DimensionParams{ast::TextureDimension::k1d, true},
+ DimensionParams{ast::TextureDimension::k2d, true},
+ DimensionParams{ast::TextureDimension::k2dArray, true},
+ DimensionParams{ast::TextureDimension::k3d, true},
+ DimensionParams{ast::TextureDimension::kCube, false},
+ DimensionParams{ast::TextureDimension::kCubeArray, false}};
+
+using StorageTextureDimensionTest = ResolverTestWithParam<DimensionParams>;
+TEST_P(StorageTextureDimensionTest, All) {
+ // @group(0) @binding(0)
+ // var a : texture_storage_*<ru32int, write>;
+ auto& params = GetParam();
+
+ auto* st =
+ ty.storage_texture(Source{{12, 34}}, params.dim,
+ ast::TexelFormat::kR32Uint, ast::Access::kWrite);
+
+ Global("a", st, ast::StorageClass::kNone,
+ ast::AttributeList{GroupAndBinding(0, 0)});
+
+ if (params.is_valid) {
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ } else {
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: cube dimensions for storage textures are not "
+ "supported");
+ }
+}
+INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest,
+ StorageTextureDimensionTest,
+ testing::ValuesIn(Dimension_cases));
+
+struct FormatParams {
+ ast::TexelFormat format;
+ bool is_valid;
+};
+
+static constexpr FormatParams format_cases[] = {
+ FormatParams{ast::TexelFormat::kR32Float, true},
+ FormatParams{ast::TexelFormat::kR32Sint, true},
+ FormatParams{ast::TexelFormat::kR32Uint, true},
+ FormatParams{ast::TexelFormat::kRg32Float, true},
+ FormatParams{ast::TexelFormat::kRg32Sint, true},
+ FormatParams{ast::TexelFormat::kRg32Uint, true},
+ FormatParams{ast::TexelFormat::kRgba16Float, true},
+ FormatParams{ast::TexelFormat::kRgba16Sint, true},
+ FormatParams{ast::TexelFormat::kRgba16Uint, true},
+ FormatParams{ast::TexelFormat::kRgba32Float, true},
+ FormatParams{ast::TexelFormat::kRgba32Sint, true},
+ FormatParams{ast::TexelFormat::kRgba32Uint, true},
+ FormatParams{ast::TexelFormat::kRgba8Sint, true},
+ FormatParams{ast::TexelFormat::kRgba8Snorm, true},
+ FormatParams{ast::TexelFormat::kRgba8Uint, true},
+ FormatParams{ast::TexelFormat::kRgba8Unorm, true}};
+
+using StorageTextureFormatTest = ResolverTestWithParam<FormatParams>;
+TEST_P(StorageTextureFormatTest, All) {
+ auto& params = GetParam();
+ // @group(0) @binding(0)
+ // var a : texture_storage_1d<*, write>;
+ // @group(0) @binding(1)
+ // var b : texture_storage_2d<*, write>;
+ // @group(0) @binding(2)
+ // var c : texture_storage_2d_array<*, write>;
+ // @group(0) @binding(3)
+ // var d : texture_storage_3d<*, write>;
+
+ auto* st_a = ty.storage_texture(Source{{12, 34}}, ast::TextureDimension::k1d,
+ params.format, ast::Access::kWrite);
+ Global("a", st_a, ast::StorageClass::kNone,
+ ast::AttributeList{GroupAndBinding(0, 0)});
+
+ auto* st_b = ty.storage_texture(ast::TextureDimension::k2d, params.format,
+ ast::Access::kWrite);
+ Global("b", st_b, ast::StorageClass::kNone,
+ ast::AttributeList{GroupAndBinding(0, 1)});
+
+ auto* st_c = ty.storage_texture(ast::TextureDimension::k2dArray,
+ params.format, ast::Access::kWrite);
+ Global("c", st_c, ast::StorageClass::kNone,
+ ast::AttributeList{GroupAndBinding(0, 2)});
+
+ auto* st_d = ty.storage_texture(ast::TextureDimension::k3d, params.format,
+ ast::Access::kWrite);
+ Global("d", st_d, ast::StorageClass::kNone,
+ ast::AttributeList{GroupAndBinding(0, 3)});
+
+ if (params.is_valid) {
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ } else {
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: image format must be one of the texel formats "
+ "specified for storage textues in "
+ "https://gpuweb.github.io/gpuweb/wgsl/#texel-formats");
+ }
+}
+INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest,
+ StorageTextureFormatTest,
+ testing::ValuesIn(format_cases));
+
+using StorageTextureAccessTest = ResolverTest;
+
+TEST_F(StorageTextureAccessTest, MissingAccess_Fail) {
+ // @group(0) @binding(0)
+ // var a : texture_storage_1d<ru32int>;
+
+ auto* st =
+ ty.storage_texture(Source{{12, 34}}, ast::TextureDimension::k1d,
+ ast::TexelFormat::kR32Uint, ast::Access::kUndefined);
+
+ Global("a", st, ast::StorageClass::kNone,
+ ast::AttributeList{GroupAndBinding(0, 0)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: storage texture missing access control");
+}
+
+TEST_F(StorageTextureAccessTest, RWAccess_Fail) {
+ // @group(0) @binding(0)
+ // var a : texture_storage_1d<ru32int, read_write>;
+
+ auto* st =
+ ty.storage_texture(Source{{12, 34}}, ast::TextureDimension::k1d,
+ ast::TexelFormat::kR32Uint, ast::Access::kReadWrite);
+
+ Global("a", st, ast::StorageClass::kNone, nullptr,
+ ast::AttributeList{GroupAndBinding(0, 0)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: storage textures currently only support 'write' "
+ "access control");
+}
+
+TEST_F(StorageTextureAccessTest, ReadOnlyAccess_Fail) {
+ // @group(0) @binding(0)
+ // var a : texture_storage_1d<ru32int, read>;
+
+ auto* st = ty.storage_texture(Source{{12, 34}}, ast::TextureDimension::k1d,
+ ast::TexelFormat::kR32Uint, ast::Access::kRead);
+
+ Global("a", st, ast::StorageClass::kNone, nullptr,
+ ast::AttributeList{GroupAndBinding(0, 0)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: storage textures currently only support 'write' "
+ "access control");
+}
+
+TEST_F(StorageTextureAccessTest, WriteOnlyAccess_Pass) {
+ // @group(0) @binding(0)
+ // var a : texture_storage_1d<ru32int, write>;
+
+ auto* st =
+ ty.storage_texture(ast::TextureDimension::k1d, ast::TexelFormat::kR32Uint,
+ ast::Access::kWrite);
+
+ Global("a", st, ast::StorageClass::kNone, nullptr,
+ ast::AttributeList{GroupAndBinding(0, 0)});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+} // namespace StorageTextureTests
+
+namespace MatrixTests {
+struct Params {
+ uint32_t columns;
+ uint32_t rows;
+ builder::ast_type_func_ptr elem_ty;
+};
+
+template <typename T>
+constexpr Params ParamsFor(uint32_t columns, uint32_t rows) {
+ return Params{columns, rows, DataType<T>::AST};
+}
+
+using ValidMatrixTypes = ResolverTestWithParam<Params>;
+TEST_P(ValidMatrixTypes, Okay) {
+ // var a : matNxM<EL_TY>;
+ auto& params = GetParam();
+ Global("a", ty.mat(params.elem_ty(*this), params.columns, params.rows),
+ ast::StorageClass::kPrivate);
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest,
+ ValidMatrixTypes,
+ testing::Values(ParamsFor<f32>(2, 2),
+ ParamsFor<f32>(2, 3),
+ ParamsFor<f32>(2, 4),
+ ParamsFor<f32>(3, 2),
+ ParamsFor<f32>(3, 3),
+ ParamsFor<f32>(3, 4),
+ ParamsFor<f32>(4, 2),
+ ParamsFor<f32>(4, 3),
+ ParamsFor<f32>(4, 4),
+ ParamsFor<alias<f32>>(4, 2),
+ ParamsFor<alias<f32>>(4, 3),
+ ParamsFor<alias<f32>>(4, 4)));
+
+using InvalidMatrixElementTypes = ResolverTestWithParam<Params>;
+TEST_P(InvalidMatrixElementTypes, InvalidElementType) {
+ // var a : matNxM<EL_TY>;
+ auto& params = GetParam();
+ Global("a",
+ ty.mat(Source{{12, 34}}, params.elem_ty(*this), params.columns,
+ params.rows),
+ ast::StorageClass::kPrivate);
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: matrix element type must be 'f32'");
+}
+INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest,
+ InvalidMatrixElementTypes,
+ testing::Values(ParamsFor<bool>(4, 2),
+ ParamsFor<i32>(4, 3),
+ ParamsFor<u32>(4, 4),
+ ParamsFor<vec2<f32>>(2, 2),
+ ParamsFor<vec3<i32>>(2, 3),
+ ParamsFor<vec4<u32>>(2, 4),
+ ParamsFor<mat2x2<f32>>(3, 2),
+ ParamsFor<mat3x3<f32>>(3, 3),
+ ParamsFor<mat4x4<f32>>(3, 4),
+ ParamsFor<array<2, f32>>(4, 2)));
+} // namespace MatrixTests
+
+namespace VectorTests {
+struct Params {
+ uint32_t width;
+ builder::ast_type_func_ptr elem_ty;
+};
+
+template <typename T>
+constexpr Params ParamsFor(uint32_t width) {
+ return Params{width, DataType<T>::AST};
+}
+
+using ValidVectorTypes = ResolverTestWithParam<Params>;
+TEST_P(ValidVectorTypes, Okay) {
+ // var a : vecN<EL_TY>;
+ auto& params = GetParam();
+ Global("a", ty.vec(params.elem_ty(*this), params.width),
+ ast::StorageClass::kPrivate);
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest,
+ ValidVectorTypes,
+ testing::Values(ParamsFor<bool>(2),
+ ParamsFor<f32>(2),
+ ParamsFor<i32>(2),
+ ParamsFor<u32>(2),
+ ParamsFor<bool>(3),
+ ParamsFor<f32>(3),
+ ParamsFor<i32>(3),
+ ParamsFor<u32>(3),
+ ParamsFor<bool>(4),
+ ParamsFor<f32>(4),
+ ParamsFor<i32>(4),
+ ParamsFor<u32>(4),
+ ParamsFor<alias<bool>>(4),
+ ParamsFor<alias<f32>>(4),
+ ParamsFor<alias<i32>>(4),
+ ParamsFor<alias<u32>>(4)));
+
+using InvalidVectorElementTypes = ResolverTestWithParam<Params>;
+TEST_P(InvalidVectorElementTypes, InvalidElementType) {
+ // var a : vecN<EL_TY>;
+ auto& params = GetParam();
+ Global("a", ty.vec(Source{{12, 34}}, params.elem_ty(*this), params.width),
+ ast::StorageClass::kPrivate);
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: vector element type must be 'bool', 'f32', 'i32' "
+ "or 'u32'");
+}
+INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest,
+ InvalidVectorElementTypes,
+ testing::Values(ParamsFor<vec2<f32>>(2),
+ ParamsFor<vec3<i32>>(2),
+ ParamsFor<vec4<u32>>(2),
+ ParamsFor<mat2x2<f32>>(2),
+ ParamsFor<mat3x3<f32>>(2),
+ ParamsFor<mat4x4<f32>>(2),
+ ParamsFor<array<2, f32>>(2)));
+} // namespace VectorTests
+
+} // namespace
+} // namespace resolver
+} // namespace tint
diff --git a/src/tint/resolver/validation_test.cc b/src/tint/resolver/validation_test.cc
new file mode 100644
index 0000000..9198487
--- /dev/null
+++ b/src/tint/resolver/validation_test.cc
@@ -0,0 +1,1320 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/resolver/resolver.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest-spi.h"
+#include "src/tint/ast/assignment_statement.h"
+#include "src/tint/ast/bitcast_expression.h"
+#include "src/tint/ast/break_statement.h"
+#include "src/tint/ast/builtin_texture_helper_test.h"
+#include "src/tint/ast/call_statement.h"
+#include "src/tint/ast/continue_statement.h"
+#include "src/tint/ast/discard_statement.h"
+#include "src/tint/ast/if_statement.h"
+#include "src/tint/ast/loop_statement.h"
+#include "src/tint/ast/return_statement.h"
+#include "src/tint/ast/stage_attribute.h"
+#include "src/tint/ast/switch_statement.h"
+#include "src/tint/ast/unary_op_expression.h"
+#include "src/tint/ast/variable_decl_statement.h"
+#include "src/tint/resolver/resolver_test_helper.h"
+#include "src/tint/sem/call.h"
+#include "src/tint/sem/function.h"
+#include "src/tint/sem/member_accessor_expression.h"
+#include "src/tint/sem/sampled_texture_type.h"
+#include "src/tint/sem/statement.h"
+#include "src/tint/sem/variable.h"
+
+using ::testing::ElementsAre;
+using ::testing::HasSubstr;
+
+namespace tint {
+namespace resolver {
+namespace {
+
+using ResolverValidationTest = ResolverTest;
+
+class FakeStmt : public Castable<FakeStmt, ast::Statement> {
+ public:
+ FakeStmt(ProgramID pid, Source src) : Base(pid, src) {}
+ FakeStmt* Clone(CloneContext*) const override { return nullptr; }
+};
+
+class FakeExpr : public Castable<FakeExpr, ast::Expression> {
+ public:
+ FakeExpr(ProgramID pid, Source src) : Base(pid, src) {}
+ FakeExpr* Clone(CloneContext*) const override { return nullptr; }
+};
+
+TEST_F(ResolverValidationTest, WorkgroupMemoryUsedInVertexStage) {
+ Global(Source{{1, 2}}, "wg", ty.vec4<f32>(), ast::StorageClass::kWorkgroup);
+ Global("dst", ty.vec4<f32>(), ast::StorageClass::kPrivate);
+ auto* stmt = Assign(Expr("dst"), Expr(Source{{3, 4}}, "wg"));
+
+ Func(Source{{9, 10}}, "f0", ast::VariableList{}, ty.vec4<f32>(),
+ {stmt, Return(Expr("dst"))},
+ ast::AttributeList{Stage(ast::PipelineStage::kVertex)},
+ ast::AttributeList{Builtin(ast::Builtin::kPosition)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "3:4 error: workgroup memory cannot be used by vertex pipeline "
+ "stage\n1:2 note: variable is declared here");
+}
+
+TEST_F(ResolverValidationTest, WorkgroupMemoryUsedInFragmentStage) {
+ // var<workgroup> wg : vec4<f32>;
+ // var<workgroup> dst : vec4<f32>;
+ // fn f2(){ dst = wg; }
+ // fn f1() { f2(); }
+ // @stage(fragment)
+ // fn f0() {
+ // f1();
+ //}
+
+ Global(Source{{1, 2}}, "wg", ty.vec4<f32>(), ast::StorageClass::kWorkgroup);
+ Global("dst", ty.vec4<f32>(), ast::StorageClass::kPrivate);
+ auto* stmt = Assign(Expr("dst"), Expr(Source{{3, 4}}, "wg"));
+
+ Func(Source{{5, 6}}, "f2", {}, ty.void_(), {stmt});
+ Func(Source{{7, 8}}, "f1", {}, ty.void_(), {CallStmt(Call("f2"))});
+ Func(Source{{9, 10}}, "f0", {}, ty.void_(), {CallStmt(Call("f1"))},
+ ast::AttributeList{Stage(ast::PipelineStage::kFragment)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(3:4 error: workgroup memory cannot be used by fragment pipeline stage
+1:2 note: variable is declared here
+5:6 note: called by function 'f2'
+7:8 note: called by function 'f1'
+9:10 note: called by entry point 'f0')");
+}
+
+TEST_F(ResolverValidationTest, UnhandledStmt) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b;
+ b.WrapInFunction(b.create<FakeStmt>());
+ Program(std::move(b));
+ },
+ "internal compiler error: unhandled node type: tint::resolver::FakeStmt");
+}
+
+TEST_F(ResolverValidationTest, Stmt_If_NonBool) {
+ // if (1.23f) {}
+
+ WrapInFunction(If(Expr(Source{{12, 34}}, 1.23f), Block()));
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(),
+ "12:34 error: if statement condition must be bool, got f32");
+}
+
+TEST_F(ResolverValidationTest, Stmt_Else_NonBool) {
+ // else (1.23f) {}
+
+ WrapInFunction(
+ If(Expr(true), Block(), Else(Expr(Source{{12, 34}}, 1.23f), Block())));
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(),
+ "12:34 error: else statement condition must be bool, got f32");
+}
+
+TEST_F(ResolverValidationTest, Expr_ErrUnknownExprType) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b;
+ b.WrapInFunction(b.create<FakeExpr>());
+ Resolver(&b).Resolve();
+ },
+ "internal compiler error: unhandled expression type: "
+ "tint::resolver::FakeExpr");
+}
+
+TEST_F(ResolverValidationTest, Expr_DontCall_Function) {
+ Func("func", {}, ty.void_(), {}, {});
+ WrapInFunction(Expr(Source{{{3, 3}, {3, 8}}}, "func"));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "3:8 error: missing '(' for function call");
+}
+
+TEST_F(ResolverValidationTest, Expr_DontCall_Builtin) {
+ WrapInFunction(Expr(Source{{{3, 3}, {3, 8}}}, "round"));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "3:8 error: missing '(' for builtin call");
+}
+
+TEST_F(ResolverValidationTest, Expr_DontCall_Type) {
+ Alias("T", ty.u32());
+ WrapInFunction(Expr(Source{{{3, 3}, {3, 8}}}, "T"));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "3:8 error: missing '(' for type constructor or cast");
+}
+
+TEST_F(ResolverValidationTest, AssignmentStmt_InvalidLHS_BuiltinFunctionName) {
+ // normalize = 2;
+
+ auto* lhs = Expr(Source{{12, 34}}, "normalize");
+ auto* rhs = Expr(2);
+ auto* assign = Assign(lhs, rhs);
+ WrapInFunction(assign);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: missing '(' for builtin call");
+}
+
+TEST_F(ResolverValidationTest, UsingUndefinedVariable_Fail) {
+ // b = 2;
+
+ auto* lhs = Expr(Source{{12, 34}}, "b");
+ auto* rhs = Expr(2);
+ auto* assign = Assign(lhs, rhs);
+ WrapInFunction(assign);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: unknown identifier: 'b'");
+}
+
+TEST_F(ResolverValidationTest, UsingUndefinedVariableInBlockStatement_Fail) {
+ // {
+ // b = 2;
+ // }
+
+ auto* lhs = Expr(Source{{12, 34}}, "b");
+ auto* rhs = Expr(2);
+
+ auto* body = Block(Assign(lhs, rhs));
+ WrapInFunction(body);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: unknown identifier: 'b'");
+}
+
+TEST_F(ResolverValidationTest, UsingUndefinedVariableGlobalVariable_Pass) {
+ // var global_var: f32 = 2.1;
+ // fn my_func() {
+ // global_var = 3.14;
+ // return;
+ // }
+
+ Global("global_var", ty.f32(), ast::StorageClass::kPrivate, Expr(2.1f));
+
+ Func("my_func", ast::VariableList{}, ty.void_(),
+ {
+ Assign(Expr(Source{{12, 34}}, "global_var"), 3.14f),
+ Return(),
+ });
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverValidationTest, UsingUndefinedVariableInnerScope_Fail) {
+ // {
+ // if (true) { var a : f32 = 2.0; }
+ // a = 3.14;
+ // }
+ auto* var = Var("a", ty.f32(), ast::StorageClass::kNone, Expr(2.0f));
+
+ auto* cond = Expr(true);
+ auto* body = Block(Decl(var));
+
+ SetSource(Source{{12, 34}});
+ auto* lhs = Expr(Source{{12, 34}}, "a");
+ auto* rhs = Expr(3.14f);
+
+ auto* outer_body =
+ Block(create<ast::IfStatement>(cond, body, ast::ElseStatementList{}),
+ Assign(lhs, rhs));
+
+ WrapInFunction(outer_body);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: unknown identifier: 'a'");
+}
+
+TEST_F(ResolverValidationTest, UsingUndefinedVariableOuterScope_Pass) {
+ // {
+ // var a : f32 = 2.0;
+ // if (true) { a = 3.14; }
+ // }
+ auto* var = Var("a", ty.f32(), ast::StorageClass::kNone, Expr(2.0f));
+
+ auto* lhs = Expr(Source{{12, 34}}, "a");
+ auto* rhs = Expr(3.14f);
+
+ auto* cond = Expr(true);
+ auto* body = Block(Assign(lhs, rhs));
+
+ auto* outer_body =
+ Block(Decl(var),
+ create<ast::IfStatement>(cond, body, ast::ElseStatementList{}));
+
+ WrapInFunction(outer_body);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverValidationTest, UsingUndefinedVariableDifferentScope_Fail) {
+ // {
+ // { var a : f32 = 2.0; }
+ // { a = 3.14; }
+ // }
+ auto* var = Var("a", ty.f32(), ast::StorageClass::kNone, Expr(2.0f));
+ auto* first_body = Block(Decl(var));
+
+ auto* lhs = Expr(Source{{12, 34}}, "a");
+ auto* rhs = Expr(3.14f);
+ auto* second_body = Block(Assign(lhs, rhs));
+
+ auto* outer_body = Block(first_body, second_body);
+
+ WrapInFunction(outer_body);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: unknown identifier: 'a'");
+}
+
+TEST_F(ResolverValidationTest, StorageClass_FunctionVariableWorkgroupClass) {
+ auto* var = Var("var", ty.i32(), ast::StorageClass::kWorkgroup);
+
+ auto* stmt = Decl(var);
+ Func("func", ast::VariableList{}, ty.void_(), {stmt}, ast::AttributeList{});
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(),
+ "error: function variable has a non-function storage class");
+}
+
+TEST_F(ResolverValidationTest, StorageClass_FunctionVariableI32) {
+ auto* var = Var("s", ty.i32(), ast::StorageClass::kPrivate);
+
+ auto* stmt = Decl(var);
+ Func("func", ast::VariableList{}, ty.void_(), {stmt}, ast::AttributeList{});
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(),
+ "error: function variable has a non-function storage class");
+}
+
+TEST_F(ResolverValidationTest, StorageClass_SamplerExplicitStorageClass) {
+ auto* t = ty.sampler(ast::SamplerKind::kSampler);
+ Global(Source{{12, 34}}, "var", t, ast::StorageClass::kUniformConstant,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(0),
+ create<ast::GroupAttribute>(0),
+ });
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: variables of type 'sampler' must not have a storage class)");
+}
+
+TEST_F(ResolverValidationTest, StorageClass_TextureExplicitStorageClass) {
+ auto* t = ty.sampled_texture(ast::TextureDimension::k1d, ty.f32());
+ Global(Source{{12, 34}}, "var", t, ast::StorageClass::kUniformConstant,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(0),
+ create<ast::GroupAttribute>(0),
+ });
+
+ EXPECT_FALSE(r()->Resolve()) << r()->error();
+
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: variables of type 'texture_1d<f32>' must not have a storage class)");
+}
+
+TEST_F(ResolverValidationTest, Expr_MemberAccessor_VectorSwizzle_BadChar) {
+ Global("my_vec", ty.vec3<f32>(), ast::StorageClass::kPrivate);
+
+ auto* ident = Expr(Source{{{3, 3}, {3, 7}}}, "xyqz");
+
+ auto* mem = MemberAccessor("my_vec", ident);
+ WrapInFunction(mem);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "3:5 error: invalid vector swizzle character");
+}
+
+TEST_F(ResolverValidationTest, Expr_MemberAccessor_VectorSwizzle_MixedChars) {
+ Global("my_vec", ty.vec4<f32>(), ast::StorageClass::kPrivate);
+
+ auto* ident = Expr(Source{{{3, 3}, {3, 7}}}, "rgyw");
+
+ auto* mem = MemberAccessor("my_vec", ident);
+ WrapInFunction(mem);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "3:3 error: invalid mixing of vector swizzle characters rgba with xyzw");
+}
+
+TEST_F(ResolverValidationTest, Expr_MemberAccessor_VectorSwizzle_BadLength) {
+ Global("my_vec", ty.vec3<f32>(), ast::StorageClass::kPrivate);
+
+ auto* ident = Expr(Source{{{3, 3}, {3, 8}}}, "zzzzz");
+ auto* mem = MemberAccessor("my_vec", ident);
+ WrapInFunction(mem);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "3:3 error: invalid vector swizzle size");
+}
+
+TEST_F(ResolverValidationTest, Expr_MemberAccessor_VectorSwizzle_BadIndex) {
+ Global("my_vec", ty.vec2<f32>(), ast::StorageClass::kPrivate);
+
+ auto* ident = Expr(Source{{3, 3}}, "z");
+ auto* mem = MemberAccessor("my_vec", ident);
+ WrapInFunction(mem);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "3:3 error: invalid vector swizzle member");
+}
+
+TEST_F(ResolverValidationTest, Expr_MemberAccessor_BadParent) {
+ // var param: vec4<f32>
+ // let ret: f32 = *(¶m).x;
+ auto* param = Var("param", ty.vec4<f32>());
+ auto* x = Expr(Source{{{3, 3}, {3, 8}}}, "x");
+
+ auto* addressOf_expr = AddressOf(Source{{12, 34}}, param);
+ auto* accessor_expr = MemberAccessor(addressOf_expr, x);
+ auto* star_p = Deref(accessor_expr);
+ auto* ret = Var("r", ty.f32(), star_p);
+ WrapInFunction(Decl(param), Decl(ret));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: invalid member accessor expression. Expected vector "
+ "or struct, got 'ptr<function, vec4<f32>, read_write>'");
+}
+
+TEST_F(ResolverValidationTest, EXpr_MemberAccessor_FuncGoodParent) {
+ // fn func(p: ptr<function, vec4<f32>>) -> f32 {
+ // let x: f32 = (*p).z;
+ // return x;
+ // }
+ auto* p =
+ Param("p", ty.pointer(ty.vec4<f32>(), ast::StorageClass::kFunction));
+ auto* star_p = Deref(p);
+ auto* z = Expr(Source{{{3, 3}, {3, 8}}}, "z");
+ auto* accessor_expr = MemberAccessor(star_p, z);
+ auto* x = Var("x", ty.f32(), accessor_expr);
+ Func("func", {p}, ty.f32(), {Decl(x), Return(x)});
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverValidationTest, EXpr_MemberAccessor_FuncBadParent) {
+ // fn func(p: ptr<function, vec4<f32>>) -> f32 {
+ // let x: f32 = *p.z;
+ // return x;
+ // }
+ auto* p =
+ Param("p", ty.pointer(ty.vec4<f32>(), ast::StorageClass::kFunction));
+ auto* z = Expr(Source{{{3, 3}, {3, 8}}}, "z");
+ auto* accessor_expr = MemberAccessor(p, z);
+ auto* star_p = Deref(accessor_expr);
+ auto* x = Var("x", ty.f32(), star_p);
+ Func("func", {p}, ty.f32(), {Decl(x), Return(x)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "error: invalid member accessor expression. "
+ "Expected vector or struct, got 'ptr<function, vec4<f32>, read_write>'");
+}
+
+TEST_F(ResolverValidationTest,
+ Stmt_Loop_ContinueInLoopBodyBeforeDeclAndAfterDecl_UsageInContinuing) {
+ // loop {
+ // continue; // Bypasses z decl
+ // var z : i32; // unreachable
+ //
+ // continuing {
+ // z = 2;
+ // }
+ // }
+
+ auto error_loc = Source{{12, 34}};
+ auto* body =
+ Block(Continue(),
+ Decl(error_loc, Var("z", ty.i32(), ast::StorageClass::kNone)));
+ auto* continuing = Block(Assign(Expr("z"), 2));
+ auto* loop_stmt = Loop(body, continuing);
+ WrapInFunction(loop_stmt);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(12:34 warning: code is unreachable
+error: continue statement bypasses declaration of 'z'
+note: identifier 'z' declared here
+note: identifier 'z' referenced in continuing block here)");
+}
+
+TEST_F(ResolverValidationTest,
+ Stmt_Loop_ContinueInLoopBodyAfterDecl_UsageInContinuing_InBlocks) {
+ // loop {
+ // if (false) { break; }
+ // var z : i32;
+ // {{{continue;}}}
+ // continue; // Ok
+ //
+ // continuing {
+ // z = 2;
+ // }
+ // }
+
+ auto* body = Block(If(false, Block(Break())), //
+ Decl(Var("z", ty.i32(), ast::StorageClass::kNone)),
+ Block(Block(Block(Continue()))));
+ auto* continuing = Block(Assign(Expr("z"), 2));
+ auto* loop_stmt = Loop(body, continuing);
+ WrapInFunction(loop_stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverValidationTest,
+ Stmt_Loop_ContinueInLoopBodySubscopeBeforeDecl_UsageInContinuing) {
+ // loop {
+ // if (true) {
+ // continue; // Still bypasses z decl (if we reach here)
+ // }
+ // var z : i32;
+ // continuing {
+ // z = 2;
+ // }
+ // }
+
+ auto cont_loc = Source{{12, 34}};
+ auto decl_loc = Source{{56, 78}};
+ auto ref_loc = Source{{90, 12}};
+ auto* body =
+ Block(If(Expr(true), Block(Continue(cont_loc))),
+ Decl(Var(decl_loc, "z", ty.i32(), ast::StorageClass::kNone)));
+ auto* continuing = Block(Assign(Expr(ref_loc, "z"), 2));
+ auto* loop_stmt = Loop(body, continuing);
+ WrapInFunction(loop_stmt);
+
+ EXPECT_FALSE(r()->Resolve()) << r()->error();
+ EXPECT_EQ(r()->error(),
+ R"(12:34 error: continue statement bypasses declaration of 'z'
+56:78 note: identifier 'z' declared here
+90:12 note: identifier 'z' referenced in continuing block here)");
+}
+
+TEST_F(
+ ResolverValidationTest,
+ Stmt_Loop_ContinueInLoopBodySubscopeBeforeDecl_UsageInContinuingSubscope) {
+ // loop {
+ // if (true) {
+ // continue; // Still bypasses z decl (if we reach here)
+ // }
+ // var z : i32;
+ // continuing {
+ // if (true) {
+ // z = 2; // Must fail even if z is in a sub-scope
+ // }
+ // }
+ // }
+
+ auto cont_loc = Source{{12, 34}};
+ auto decl_loc = Source{{56, 78}};
+ auto ref_loc = Source{{90, 12}};
+ auto* body =
+ Block(If(Expr(true), Block(Continue(cont_loc))),
+ Decl(Var(decl_loc, "z", ty.i32(), ast::StorageClass::kNone)));
+
+ auto* continuing =
+ Block(If(Expr(true), Block(Assign(Expr(ref_loc, "z"), 2))));
+ auto* loop_stmt = Loop(body, continuing);
+ WrapInFunction(loop_stmt);
+
+ EXPECT_FALSE(r()->Resolve()) << r()->error();
+ EXPECT_EQ(r()->error(),
+ R"(12:34 error: continue statement bypasses declaration of 'z'
+56:78 note: identifier 'z' declared here
+90:12 note: identifier 'z' referenced in continuing block here)");
+}
+
+TEST_F(ResolverValidationTest,
+ Stmt_Loop_ContinueInLoopBodySubscopeBeforeDecl_UsageOutsideBlock) {
+ // loop {
+ // if (true) {
+ // continue; // bypasses z decl (if we reach here)
+ // }
+ // var z : i32;
+ // continuing {
+ // // Must fail even if z is used in an expression that isn't
+ // // directly contained inside a block.
+ // if (z < 2) {
+ // }
+ // }
+ // }
+
+ auto cont_loc = Source{{12, 34}};
+ auto decl_loc = Source{{56, 78}};
+ auto ref_loc = Source{{90, 12}};
+ auto* body =
+ Block(If(Expr(true), Block(Continue(cont_loc))),
+ Decl(Var(decl_loc, "z", ty.i32(), ast::StorageClass::kNone)));
+ auto* compare = create<ast::BinaryExpression>(ast::BinaryOp::kLessThan,
+ Expr(ref_loc, "z"), Expr(2));
+ auto* continuing = Block(If(compare, Block()));
+ auto* loop_stmt = Loop(body, continuing);
+ WrapInFunction(loop_stmt);
+
+ EXPECT_FALSE(r()->Resolve()) << r()->error();
+ EXPECT_EQ(r()->error(),
+ R"(12:34 error: continue statement bypasses declaration of 'z'
+56:78 note: identifier 'z' declared here
+90:12 note: identifier 'z' referenced in continuing block here)");
+}
+
+TEST_F(ResolverValidationTest,
+ Stmt_Loop_ContinueInLoopBodySubscopeBeforeDecl_UsageInContinuingLoop) {
+ // loop {
+ // if (true) {
+ // continue; // Still bypasses z decl (if we reach here)
+ // }
+ // var z : i32;
+ // continuing {
+ // loop {
+ // z = 2; // Must fail even if z is in a sub-scope
+ // }
+ // }
+ // }
+
+ auto cont_loc = Source{{12, 34}};
+ auto decl_loc = Source{{56, 78}};
+ auto ref_loc = Source{{90, 12}};
+ auto* body =
+ Block(If(Expr(true), Block(Continue(cont_loc))),
+ Decl(Var(decl_loc, "z", ty.i32(), ast::StorageClass::kNone)));
+
+ auto* continuing = Block(Loop(Block(Assign(Expr(ref_loc, "z"), 2))));
+ auto* loop_stmt = Loop(body, continuing);
+ WrapInFunction(loop_stmt);
+
+ EXPECT_FALSE(r()->Resolve()) << r()->error();
+ EXPECT_EQ(r()->error(),
+ R"(12:34 error: continue statement bypasses declaration of 'z'
+56:78 note: identifier 'z' declared here
+90:12 note: identifier 'z' referenced in continuing block here)");
+}
+
+TEST_F(ResolverValidationTest,
+ Stmt_Loop_ContinueInNestedLoopBodyBeforeDecl_UsageInContinuing) {
+ // loop {
+ // loop {
+ // if (true) { continue; } // OK: not part of the outer loop
+ // break;
+ // }
+ // var z : i32;
+ // break;
+ // continuing {
+ // z = 2;
+ // }
+ // }
+
+ auto* inner_loop = Loop(Block( //
+ If(true, Block(Continue())), //
+ Break()));
+ auto* body = Block(inner_loop, //
+ Decl(Var("z", ty.i32(), ast::StorageClass::kNone)), //
+ Break());
+ auto* continuing = Block(Assign("z", 2));
+ auto* loop_stmt = Loop(body, continuing);
+ WrapInFunction(loop_stmt);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverValidationTest,
+ Stmt_Loop_ContinueInNestedLoopBodyBeforeDecl_UsageInContinuingSubscope) {
+ // loop {
+ // loop {
+ // if (true) { continue; } // OK: not part of the outer loop
+ // break;
+ // }
+ // var z : i32;
+ // break;
+ // continuing {
+ // if (true) {
+ // z = 2;
+ // }
+ // }
+ // }
+
+ auto* inner_loop = Loop(Block(If(true, Block(Continue())), //
+ Break()));
+ auto* body = Block(inner_loop, //
+ Decl(Var("z", ty.i32(), ast::StorageClass::kNone)), //
+ Break());
+ auto* continuing = Block(If(Expr(true), Block(Assign("z", 2))));
+ auto* loop_stmt = Loop(body, continuing);
+ WrapInFunction(loop_stmt);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverValidationTest,
+ Stmt_Loop_ContinueInNestedLoopBodyBeforeDecl_UsageInContinuingLoop) {
+ // loop {
+ // loop {
+ // if (true) { continue; } // OK: not part of the outer loop
+ // break;
+ // }
+ // var z : i32;
+ // break;
+ // continuing {
+ // loop {
+ // z = 2;
+ // break;
+ // }
+ // }
+ // }
+
+ auto* inner_loop = Loop(Block(If(true, Block(Continue())), //
+ Break()));
+ auto* body = Block(inner_loop, //
+ Decl(Var("z", ty.i32(), ast::StorageClass::kNone)), //
+ Break());
+ auto* continuing = Block(Loop(Block(Assign("z", 2), //
+ Break())));
+ auto* loop_stmt = Loop(body, continuing);
+ WrapInFunction(loop_stmt);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverTest, Stmt_Loop_ContinueInLoopBodyAfterDecl_UsageInContinuing) {
+ // loop {
+ // var z : i32;
+ // if (true) { continue; }
+ // break;
+ // continuing {
+ // z = 2;
+ // }
+ // }
+
+ auto error_loc = Source{{12, 34}};
+ auto* body = Block(Decl(Var("z", ty.i32(), ast::StorageClass::kNone)),
+ If(true, Block(Continue())), //
+ Break());
+ auto* continuing = Block(Assign(Expr(error_loc, "z"), 2));
+ auto* loop_stmt = Loop(body, continuing);
+ WrapInFunction(loop_stmt);
+
+ EXPECT_TRUE(r()->Resolve());
+}
+
+TEST_F(ResolverTest, Stmt_Loop_ReturnInContinuing_Direct) {
+ // loop {
+ // continuing {
+ // return;
+ // }
+ // }
+
+ WrapInFunction(Loop( // loop
+ Block(), // loop block
+ Block( // loop continuing block
+ Return(Source{{12, 34}}))));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: continuing blocks must not contain a return statement)");
+}
+
+TEST_F(ResolverTest, Stmt_Loop_ReturnInContinuing_Indirect) {
+ // loop {
+ // if (false) { break; }
+ // continuing {
+ // loop {
+ // return;
+ // }
+ // }
+ // }
+
+ WrapInFunction(Loop( // outer loop
+ Block(If(false, Block(Break()))), // outer loop block
+ Block(Source{{56, 78}}, // outer loop continuing block
+ Loop( // inner loop
+ Block( // inner loop block
+ Return(Source{{12, 34}}))))));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: continuing blocks must not contain a return statement
+56:78 note: see continuing block here)");
+}
+
+TEST_F(ResolverTest, Stmt_Loop_DiscardInContinuing_Direct) {
+ // loop {
+ // continuing {
+ // discard;
+ // }
+ // }
+
+ WrapInFunction(Loop( // loop
+ Block(), // loop block
+ Block( // loop continuing block
+ Discard(Source{{12, 34}}))));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: continuing blocks must not contain a discard statement)");
+}
+
+TEST_F(ResolverTest, Stmt_Loop_DiscardInContinuing_Indirect) {
+ // loop {
+ // if (false) { break; }
+ // continuing {
+ // loop { discard; }
+ // }
+ // }
+
+ WrapInFunction(Loop( // outer loop
+ Block(If(false, Block(Break()))), // outer loop block
+ Block(Source{{56, 78}}, // outer loop continuing block
+ Loop( // inner loop
+ Block( // inner loop block
+ Discard(Source{{12, 34}}))))));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: continuing blocks must not contain a discard statement
+56:78 note: see continuing block here)");
+}
+
+TEST_F(ResolverTest, Stmt_Loop_DiscardInContinuing_Indirect_ViaCall) {
+ // fn MayDiscard() { if (true) { discard; } }
+ // fn F() { MayDiscard(); }
+ // loop {
+ // continuing {
+ // loop { F(); }
+ // }
+ // }
+
+ Func("MayDiscard", {}, ty.void_(), {If(true, Block(Discard()))});
+ Func("SomeFunc", {}, ty.void_(), {CallStmt(Call("MayDiscard"))});
+
+ WrapInFunction(Loop( // outer loop
+ Block(), // outer loop block
+ Block(Source{{56, 78}}, // outer loop continuing block
+ Loop( // inner loop
+ Block( // inner loop block
+ CallStmt(Call(Source{{12, 34}}, "SomeFunc")))))));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: cannot call a function that may discard inside a continuing block
+56:78 note: see continuing block here)");
+}
+
+TEST_F(ResolverTest, Stmt_Loop_ContinueInContinuing_Direct) {
+ // loop {
+ // continuing {
+ // continue;
+ // }
+ // }
+
+ WrapInFunction(Loop( // loop
+ Block(), // loop block
+ Block(Source{{56, 78}}, // loop continuing block
+ Continue(Source{{12, 34}}))));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: continuing blocks must not contain a continue statement");
+}
+
+TEST_F(ResolverTest, Stmt_Loop_ContinueInContinuing_Indirect) {
+ // loop {
+ // if (false) { break; }
+ // continuing {
+ // loop {
+ // if (false) { break; }
+ // continue;
+ // }
+ // }
+ // }
+
+ WrapInFunction(Loop( // outer loop
+ Block( // outer loop block
+ If(false, Block(Break()))), // if (false) { break; }
+ Block( // outer loop continuing block
+ Loop( // inner loop
+ Block( // inner loop block
+ If(false, Block(Break())), // if (false) { break; }
+ Continue(Source{{12, 34}})))))); // continue
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverTest, Stmt_ForLoop_ReturnInContinuing_Direct) {
+ // for(;; return) {
+ // break;
+ // }
+
+ WrapInFunction(For(nullptr, nullptr, Return(Source{{12, 34}}), //
+ Block(Break())));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: continuing blocks must not contain a return statement)");
+}
+
+TEST_F(ResolverTest, Stmt_ForLoop_ReturnInContinuing_Indirect) {
+ // for(;; loop { return }) {
+ // break;
+ // }
+
+ WrapInFunction(For(nullptr, nullptr,
+ Loop(Source{{56, 78}}, //
+ Block(Return(Source{{12, 34}}))), //
+ Block(Break())));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: continuing blocks must not contain a return statement
+56:78 note: see continuing block here)");
+}
+
+TEST_F(ResolverTest, Stmt_ForLoop_DiscardInContinuing_Direct) {
+ // for(;; discard) {
+ // break;
+ // }
+
+ WrapInFunction(For(nullptr, nullptr, Discard(Source{{12, 34}}), //
+ Block(Break())));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: continuing blocks must not contain a discard statement)");
+}
+
+TEST_F(ResolverTest, Stmt_ForLoop_DiscardInContinuing_Indirect) {
+ // for(;; loop { discard }) {
+ // break;
+ // }
+
+ WrapInFunction(For(nullptr, nullptr,
+ Loop(Source{{56, 78}}, //
+ Block(Discard(Source{{12, 34}}))), //
+ Block(Break())));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: continuing blocks must not contain a discard statement
+56:78 note: see continuing block here)");
+}
+
+TEST_F(ResolverTest, Stmt_ForLoop_DiscardInContinuing_Indirect_ViaCall) {
+ // fn MayDiscard() { if (true) { discard; } }
+ // fn F() { MayDiscard(); }
+ // for(;; loop { F() }) {
+ // break;
+ // }
+
+ Func("MayDiscard", {}, ty.void_(), {If(true, Block(Discard()))});
+ Func("F", {}, ty.void_(), {CallStmt(Call("MayDiscard"))});
+
+ WrapInFunction(For(nullptr, nullptr,
+ Loop(Source{{56, 78}}, //
+ Block(CallStmt(Call(Source{{12, 34}}, "F")))), //
+ Block(Break())));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: cannot call a function that may discard inside a continuing block
+56:78 note: see continuing block here)");
+}
+
+TEST_F(ResolverTest, Stmt_ForLoop_ContinueInContinuing_Direct) {
+ // for(;; continue) {
+ // break;
+ // }
+
+ WrapInFunction(For(nullptr, nullptr, Continue(Source{{12, 34}}), //
+ Block(Break())));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: continuing blocks must not contain a continue statement");
+}
+
+TEST_F(ResolverTest, Stmt_ForLoop_ContinueInContinuing_Indirect) {
+ // for(;; loop { if (false) { break; } continue }) {
+ // break;
+ // }
+
+ WrapInFunction(For(nullptr, nullptr,
+ Loop( //
+ Block(If(false, Block(Break())), //
+ Continue(Source{{12, 34}}))), //
+ Block(Break())));
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverTest, Stmt_ForLoop_CondIsBoolRef) {
+ // var cond : bool = true;
+ // for (; cond; ) {
+ // }
+
+ auto* cond = Var("cond", ty.bool_(), Expr(true));
+ WrapInFunction(Decl(cond), For(nullptr, "cond", nullptr, Block()));
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverTest, Stmt_ForLoop_CondIsNotBool) {
+ // for (; 1.0f; ) {
+ // }
+
+ WrapInFunction(For(nullptr, Expr(Source{{12, 34}}, 1.0f), nullptr, Block()));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: for-loop condition must be bool, got f32");
+}
+
+TEST_F(ResolverValidationTest, Stmt_ContinueInLoop) {
+ WrapInFunction(Loop(Block(If(false, Block(Break())), //
+ Continue(Source{{12, 34}}))));
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverValidationTest, Stmt_ContinueNotInLoop) {
+ WrapInFunction(Continue(Source{{12, 34}}));
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: continue statement must be in a loop");
+}
+
+TEST_F(ResolverValidationTest, Stmt_BreakInLoop) {
+ WrapInFunction(Loop(Block(Break(Source{{12, 34}}))));
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverValidationTest, Stmt_BreakInSwitch) {
+ WrapInFunction(Loop(Block(Switch(Expr(1), //
+ Case(Expr(1), //
+ Block(Break())), //
+ DefaultCase()), //
+ Break()))); //
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverValidationTest, Stmt_BreakInIfTrueInContinuing) {
+ auto* cont = Block( // continuing {
+ If(true, Block( // if(true) {
+ Break(Source{{12, 34}})))); // break;
+ // }
+ // }
+ WrapInFunction(Loop(Block(), cont));
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverValidationTest, Stmt_BreakInIfElseInContinuing) {
+ auto* cont = Block( // continuing {
+ If(true, Block(), // if(true) {
+ Else(Block( // } else {
+ Break(Source{{12, 34}}))))); // break;
+ // }
+ // }
+ WrapInFunction(Loop(Block(), cont));
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverValidationTest, Stmt_BreakInContinuing) {
+ auto* cont = Block( // continuing {
+ Block(Break(Source{{12, 34}}))); // break;
+ // }
+ WrapInFunction(Loop(Block(), cont));
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: break statement in a continuing block must be the single "
+ "statement of an if statement's true or false block, and that if "
+ "statement must be the last statement of the continuing block\n"
+ "12:34 note: break statement is not directly in if statement block");
+}
+
+TEST_F(ResolverValidationTest, Stmt_BreakInIfInIfInContinuing) {
+ auto* cont = Block( // continuing {
+ If(true, Block( // if(true) {
+ If(Source{{56, 78}}, true, // if(true) {
+ Block(Break(Source{{12, 34}})))))); // break;
+ // }
+ // }
+ // }
+ WrapInFunction(Loop(Block(), cont));
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: break statement in a continuing block must be the single "
+ "statement of an if statement's true or false block, and that if "
+ "statement must be the last statement of the continuing block\n"
+ "56:78 note: if statement containing break statement is not directly in "
+ "continuing block");
+}
+
+TEST_F(ResolverValidationTest, Stmt_BreakInIfTrueMultipleStmtsInContinuing) {
+ auto* cont = Block( // continuing {
+ If(true, Block(Source{{56, 78}}, // if(true) {
+ Assign(Phony(), 1), // _ = 1;
+ Break(Source{{12, 34}})))); // break;
+ // }
+ // }
+ WrapInFunction(Loop(Block(), cont));
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: break statement in a continuing block must be the single "
+ "statement of an if statement's true or false block, and that if "
+ "statement must be the last statement of the continuing block\n"
+ "56:78 note: if statement block contains multiple statements");
+}
+
+TEST_F(ResolverValidationTest, Stmt_BreakInIfElseMultipleStmtsInContinuing) {
+ auto* cont = Block( // continuing {
+ If(true, Block(), // if(true) {
+ Else(Block(Source{{56, 78}}, // } else {
+ Assign(Phony(), 1), // _ = 1;
+ Break(Source{{12, 34}}))))); // break;
+ // }
+ // }
+ WrapInFunction(Loop(Block(), cont));
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: break statement in a continuing block must be the single "
+ "statement of an if statement's true or false block, and that if "
+ "statement must be the last statement of the continuing block\n"
+ "56:78 note: if statement block contains multiple statements");
+}
+
+TEST_F(ResolverValidationTest, Stmt_BreakInIfElseIfInContinuing) {
+ auto* cont = Block( // continuing {
+ If(true, Block(), // if(true) {
+ Else(Expr(Source{{56, 78}}, true), // } else if (true) {
+ Block(Break(Source{{12, 34}}))))); // break;
+ // }
+ // }
+ WrapInFunction(Loop(Block(), cont));
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: break statement in a continuing block must be the single "
+ "statement of an if statement's true or false block, and that if "
+ "statement must be the last statement of the continuing block\n"
+ "56:78 note: else has condition");
+}
+
+TEST_F(ResolverValidationTest, Stmt_BreakInIfNonEmptyElseInContinuing) {
+ auto* cont = Block( // continuing {
+ If(true, // if(true) {
+ Block(Break(Source{{12, 34}})), // break;
+ Else(Block(Source{{56, 78}}, // } else {
+ Assign(Phony(), 1))))); // _ = 1;
+ // }
+ // }
+ WrapInFunction(Loop(Block(), cont));
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: break statement in a continuing block must be the single "
+ "statement of an if statement's true or false block, and that if "
+ "statement must be the last statement of the continuing block\n"
+ "56:78 note: non-empty false block");
+}
+
+TEST_F(ResolverValidationTest, Stmt_BreakInIfElseNonEmptyTrueInContinuing) {
+ auto* cont = Block( // continuing {
+ If(true, // if(true) {
+ Block(Source{{56, 78}}, Assign(Phony(), 1)), // _ = 1;
+ Else(Block( // } else {
+ Break(Source{{12, 34}}))))); // break;
+ // }
+ // }
+ WrapInFunction(Loop(Block(), cont));
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: break statement in a continuing block must be the single "
+ "statement of an if statement's true or false block, and that if "
+ "statement must be the last statement of the continuing block\n"
+ "56:78 note: non-empty true block");
+}
+
+TEST_F(ResolverValidationTest, Stmt_BreakInIfInContinuingNotLast) {
+ auto* cont = Block( // continuing {
+ If(Source{{56, 78}}, true, // if(true) {
+ Block(Break(Source{{12, 34}}))), // break;
+ // }
+ Assign(Phony(), 1)); // _ = 1;
+ // }
+ WrapInFunction(Loop(Block(), cont));
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: break statement in a continuing block must be the single "
+ "statement of an if statement's true or false block, and that if "
+ "statement must be the last statement of the continuing block\n"
+ "56:78 note: if statement containing break statement is not the last "
+ "statement of the continuing block");
+}
+
+TEST_F(ResolverValidationTest, Stmt_BreakNotInLoopOrSwitch) {
+ WrapInFunction(Break(Source{{12, 34}}));
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: break statement must be in a loop or switch case");
+}
+
+TEST_F(ResolverValidationTest, StructMemberDuplicateName) {
+ Structure("S", {Member(Source{{12, 34}}, "a", ty.i32()),
+ Member(Source{{56, 78}}, "a", ty.i32())});
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "56:78 error: redefinition of 'a'\n12:34 note: previous definition "
+ "is here");
+}
+TEST_F(ResolverValidationTest, StructMemberDuplicateNameDifferentTypes) {
+ Structure("S", {Member(Source{{12, 34}}, "a", ty.bool_()),
+ Member(Source{{12, 34}}, "a", ty.vec3<f32>())});
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: redefinition of 'a'\n12:34 note: previous definition "
+ "is here");
+}
+TEST_F(ResolverValidationTest, StructMemberDuplicateNamePass) {
+ Structure("S", {Member("a", ty.i32()), Member("b", ty.f32())});
+ Structure("S1", {Member("a", ty.i32()), Member("b", ty.f32())});
+ EXPECT_TRUE(r()->Resolve());
+}
+
+TEST_F(ResolverValidationTest, NonPOTStructMemberAlignAttribute) {
+ Structure("S", {
+ Member("a", ty.f32(), {MemberAlign(Source{{12, 34}}, 3)}),
+ });
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: align value must be a positive, power-of-two integer");
+}
+
+TEST_F(ResolverValidationTest, ZeroStructMemberAlignAttribute) {
+ Structure("S", {
+ Member("a", ty.f32(), {MemberAlign(Source{{12, 34}}, 0)}),
+ });
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: align value must be a positive, power-of-two integer");
+}
+
+TEST_F(ResolverValidationTest, ZeroStructMemberSizeAttribute) {
+ Structure("S", {
+ Member("a", ty.f32(), {MemberSize(Source{{12, 34}}, 0)}),
+ });
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: size must be at least as big as the type's size (4)");
+}
+
+TEST_F(ResolverValidationTest, OffsetAndSizeAttribute) {
+ Structure("S", {
+ Member(Source{{12, 34}}, "a", ty.f32(),
+ {MemberOffset(0), MemberSize(4)}),
+ });
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: offset attributes cannot be used with align or size "
+ "attributes");
+}
+
+TEST_F(ResolverValidationTest, OffsetAndAlignAttribute) {
+ Structure("S", {
+ Member(Source{{12, 34}}, "a", ty.f32(),
+ {MemberOffset(0), MemberAlign(4)}),
+ });
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: offset attributes cannot be used with align or size "
+ "attributes");
+}
+
+TEST_F(ResolverValidationTest, OffsetAndAlignAndSizeAttribute) {
+ Structure("S", {
+ Member(Source{{12, 34}}, "a", ty.f32(),
+ {MemberOffset(0), MemberAlign(4), MemberSize(4)}),
+ });
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: offset attributes cannot be used with align or size "
+ "attributes");
+}
+
+TEST_F(ResolverTest, Expr_Constructor_Cast_Pointer) {
+ auto* vf = Var("vf", ty.f32());
+ auto* c =
+ Construct(Source{{12, 34}}, ty.pointer<i32>(ast::StorageClass::kFunction),
+ ExprList(vf));
+ auto* ip = Const("ip", ty.pointer<i32>(ast::StorageClass::kFunction), c);
+ WrapInFunction(Decl(vf), Decl(ip));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: type is not constructible");
+}
+
+} // namespace
+} // namespace resolver
+} // namespace tint
+
+TINT_INSTANTIATE_TYPEINFO(tint::resolver::FakeStmt);
+TINT_INSTANTIATE_TYPEINFO(tint::resolver::FakeExpr);
diff --git a/src/tint/resolver/var_let_test.cc b/src/tint/resolver/var_let_test.cc
new file mode 100644
index 0000000..79f2f0a
--- /dev/null
+++ b/src/tint/resolver/var_let_test.cc
@@ -0,0 +1,701 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ast/struct_block_attribute.h"
+#include "src/tint/resolver/resolver.h"
+#include "src/tint/resolver/resolver_test_helper.h"
+#include "src/tint/sem/reference_type.h"
+
+#include "gmock/gmock.h"
+
+namespace tint {
+namespace resolver {
+namespace {
+
+struct ResolverVarLetTest : public resolver::TestHelper,
+ public testing::Test {};
+
+TEST_F(ResolverVarLetTest, VarDeclWithoutConstructor) {
+ // struct S { i : i32; }
+ // alias A = S;
+ // fn F(){
+ // var i : i32;
+ // var u : u32;
+ // var f : f32;
+ // var b : bool;
+ // var s : S;
+ // var a : A;
+ // }
+
+ auto* S = Structure("S", {Member("i", ty.i32())});
+ auto* A = Alias("A", ty.Of(S));
+
+ auto* i = Var("i", ty.i32(), ast::StorageClass::kNone);
+ auto* u = Var("u", ty.u32(), ast::StorageClass::kNone);
+ auto* f = Var("f", ty.f32(), ast::StorageClass::kNone);
+ auto* b = Var("b", ty.bool_(), ast::StorageClass::kNone);
+ auto* s = Var("s", ty.Of(S), ast::StorageClass::kNone);
+ auto* a = Var("a", ty.Of(A), ast::StorageClass::kNone);
+
+ Func("F", {}, ty.void_(),
+ {
+ Decl(i),
+ Decl(u),
+ Decl(f),
+ Decl(b),
+ Decl(s),
+ Decl(a),
+ });
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ // `var` declarations are always of reference type
+ ASSERT_TRUE(TypeOf(i)->Is<sem::Reference>());
+ ASSERT_TRUE(TypeOf(u)->Is<sem::Reference>());
+ ASSERT_TRUE(TypeOf(f)->Is<sem::Reference>());
+ ASSERT_TRUE(TypeOf(b)->Is<sem::Reference>());
+ ASSERT_TRUE(TypeOf(s)->Is<sem::Reference>());
+ ASSERT_TRUE(TypeOf(a)->Is<sem::Reference>());
+
+ EXPECT_TRUE(TypeOf(i)->As<sem::Reference>()->StoreType()->Is<sem::I32>());
+ EXPECT_TRUE(TypeOf(u)->As<sem::Reference>()->StoreType()->Is<sem::U32>());
+ EXPECT_TRUE(TypeOf(f)->As<sem::Reference>()->StoreType()->Is<sem::F32>());
+ EXPECT_TRUE(TypeOf(b)->As<sem::Reference>()->StoreType()->Is<sem::Bool>());
+ EXPECT_TRUE(TypeOf(s)->As<sem::Reference>()->StoreType()->Is<sem::Struct>());
+ EXPECT_TRUE(TypeOf(a)->As<sem::Reference>()->StoreType()->Is<sem::Struct>());
+
+ EXPECT_EQ(Sem().Get(i)->Constructor(), nullptr);
+ EXPECT_EQ(Sem().Get(u)->Constructor(), nullptr);
+ EXPECT_EQ(Sem().Get(f)->Constructor(), nullptr);
+ EXPECT_EQ(Sem().Get(b)->Constructor(), nullptr);
+ EXPECT_EQ(Sem().Get(s)->Constructor(), nullptr);
+ EXPECT_EQ(Sem().Get(a)->Constructor(), nullptr);
+}
+
+TEST_F(ResolverVarLetTest, VarDeclWithConstructor) {
+ // struct S { i : i32; }
+ // alias A = S;
+ // fn F(){
+ // var i : i32 = 1;
+ // var u : u32 = 1u;
+ // var f : f32 = 1.f;
+ // var b : bool = true;
+ // var s : S = S(1);
+ // var a : A = A(1);
+ // }
+
+ auto* S = Structure("S", {Member("i", ty.i32())});
+ auto* A = Alias("A", ty.Of(S));
+
+ auto* i_c = Expr(1);
+ auto* u_c = Expr(1u);
+ auto* f_c = Expr(1.f);
+ auto* b_c = Expr(true);
+ auto* s_c = Construct(ty.Of(S), Expr(1));
+ auto* a_c = Construct(ty.Of(A), Expr(1));
+
+ auto* i = Var("i", ty.i32(), ast::StorageClass::kNone, i_c);
+ auto* u = Var("u", ty.u32(), ast::StorageClass::kNone, u_c);
+ auto* f = Var("f", ty.f32(), ast::StorageClass::kNone, f_c);
+ auto* b = Var("b", ty.bool_(), ast::StorageClass::kNone, b_c);
+ auto* s = Var("s", ty.Of(S), ast::StorageClass::kNone, s_c);
+ auto* a = Var("a", ty.Of(A), ast::StorageClass::kNone, a_c);
+
+ Func("F", {}, ty.void_(),
+ {
+ Decl(i),
+ Decl(u),
+ Decl(f),
+ Decl(b),
+ Decl(s),
+ Decl(a),
+ });
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ // `var` declarations are always of reference type
+ ASSERT_TRUE(TypeOf(i)->Is<sem::Reference>());
+ ASSERT_TRUE(TypeOf(u)->Is<sem::Reference>());
+ ASSERT_TRUE(TypeOf(f)->Is<sem::Reference>());
+ ASSERT_TRUE(TypeOf(b)->Is<sem::Reference>());
+ ASSERT_TRUE(TypeOf(s)->Is<sem::Reference>());
+ ASSERT_TRUE(TypeOf(a)->Is<sem::Reference>());
+
+ EXPECT_TRUE(TypeOf(i)->As<sem::Reference>()->StoreType()->Is<sem::I32>());
+ EXPECT_TRUE(TypeOf(u)->As<sem::Reference>()->StoreType()->Is<sem::U32>());
+ EXPECT_TRUE(TypeOf(f)->As<sem::Reference>()->StoreType()->Is<sem::F32>());
+ EXPECT_TRUE(TypeOf(b)->As<sem::Reference>()->StoreType()->Is<sem::Bool>());
+ EXPECT_TRUE(TypeOf(s)->As<sem::Reference>()->StoreType()->Is<sem::Struct>());
+ EXPECT_TRUE(TypeOf(a)->As<sem::Reference>()->StoreType()->Is<sem::Struct>());
+
+ EXPECT_EQ(Sem().Get(i)->Constructor()->Declaration(), i_c);
+ EXPECT_EQ(Sem().Get(u)->Constructor()->Declaration(), u_c);
+ EXPECT_EQ(Sem().Get(f)->Constructor()->Declaration(), f_c);
+ EXPECT_EQ(Sem().Get(b)->Constructor()->Declaration(), b_c);
+ EXPECT_EQ(Sem().Get(s)->Constructor()->Declaration(), s_c);
+ EXPECT_EQ(Sem().Get(a)->Constructor()->Declaration(), a_c);
+}
+
+TEST_F(ResolverVarLetTest, LetDecl) {
+ // struct S { i : i32; }
+ // fn F(){
+ // var v : i32;
+ // let i : i32 = 1;
+ // let u : u32 = 1u;
+ // let f : f32 = 1.;
+ // let b : bool = true;
+ // let s : S = S(1);
+ // let a : A = A(1);
+ // let p : pointer<function, i32> = &v;
+ // }
+
+ auto* S = Structure("S", {Member("i", ty.i32())});
+ auto* A = Alias("A", ty.Of(S));
+ auto* v = Var("v", ty.i32(), ast::StorageClass::kNone);
+
+ auto* i_c = Expr(1);
+ auto* u_c = Expr(1u);
+ auto* f_c = Expr(1.f);
+ auto* b_c = Expr(true);
+ auto* s_c = Construct(ty.Of(S), Expr(1));
+ auto* a_c = Construct(ty.Of(A), Expr(1));
+ auto* p_c = AddressOf(v);
+
+ auto* i = Const("i", ty.i32(), i_c);
+ auto* u = Const("u", ty.u32(), u_c);
+ auto* f = Const("f", ty.f32(), f_c);
+ auto* b = Const("b", ty.bool_(), b_c);
+ auto* s = Const("s", ty.Of(S), s_c);
+ auto* a = Const("a", ty.Of(A), a_c);
+ auto* p = Const("p", ty.pointer<i32>(ast::StorageClass::kFunction), p_c);
+
+ Func("F", {}, ty.void_(),
+ {
+ Decl(v),
+ Decl(i),
+ Decl(u),
+ Decl(f),
+ Decl(b),
+ Decl(s),
+ Decl(a),
+ Decl(p),
+ });
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ // `let` declarations are always of the storage type
+ ASSERT_TRUE(TypeOf(i)->Is<sem::I32>());
+ ASSERT_TRUE(TypeOf(u)->Is<sem::U32>());
+ ASSERT_TRUE(TypeOf(f)->Is<sem::F32>());
+ ASSERT_TRUE(TypeOf(b)->Is<sem::Bool>());
+ ASSERT_TRUE(TypeOf(s)->Is<sem::Struct>());
+ ASSERT_TRUE(TypeOf(a)->Is<sem::Struct>());
+ ASSERT_TRUE(TypeOf(p)->Is<sem::Pointer>());
+ ASSERT_TRUE(TypeOf(p)->As<sem::Pointer>()->StoreType()->Is<sem::I32>());
+
+ EXPECT_EQ(Sem().Get(i)->Constructor()->Declaration(), i_c);
+ EXPECT_EQ(Sem().Get(u)->Constructor()->Declaration(), u_c);
+ EXPECT_EQ(Sem().Get(f)->Constructor()->Declaration(), f_c);
+ EXPECT_EQ(Sem().Get(b)->Constructor()->Declaration(), b_c);
+ EXPECT_EQ(Sem().Get(s)->Constructor()->Declaration(), s_c);
+ EXPECT_EQ(Sem().Get(a)->Constructor()->Declaration(), a_c);
+ EXPECT_EQ(Sem().Get(p)->Constructor()->Declaration(), p_c);
+}
+
+TEST_F(ResolverVarLetTest, DefaultVarStorageClass) {
+ // https://gpuweb.github.io/gpuweb/wgsl/#storage-class
+
+ auto* buf = Structure("S", {Member("m", ty.i32())},
+ {create<ast::StructBlockAttribute>()});
+ auto* function = Var("f", ty.i32());
+ auto* private_ = Global("p", ty.i32(), ast::StorageClass::kPrivate);
+ auto* workgroup = Global("w", ty.i32(), ast::StorageClass::kWorkgroup);
+ auto* uniform = Global("ub", ty.Of(buf), ast::StorageClass::kUniform,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(0),
+ create<ast::GroupAttribute>(0),
+ });
+ auto* storage = Global("sb", ty.Of(buf), ast::StorageClass::kStorage,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(1),
+ create<ast::GroupAttribute>(0),
+ });
+ auto* handle = Global("h", ty.depth_texture(ast::TextureDimension::k2d),
+ ast::AttributeList{
+ create<ast::BindingAttribute>(2),
+ create<ast::GroupAttribute>(0),
+ });
+
+ WrapInFunction(function);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_TRUE(TypeOf(function)->Is<sem::Reference>());
+ ASSERT_TRUE(TypeOf(private_)->Is<sem::Reference>());
+ ASSERT_TRUE(TypeOf(workgroup)->Is<sem::Reference>());
+ ASSERT_TRUE(TypeOf(uniform)->Is<sem::Reference>());
+ ASSERT_TRUE(TypeOf(storage)->Is<sem::Reference>());
+ ASSERT_TRUE(TypeOf(handle)->Is<sem::Reference>());
+
+ EXPECT_EQ(TypeOf(function)->As<sem::Reference>()->Access(),
+ ast::Access::kReadWrite);
+ EXPECT_EQ(TypeOf(private_)->As<sem::Reference>()->Access(),
+ ast::Access::kReadWrite);
+ EXPECT_EQ(TypeOf(workgroup)->As<sem::Reference>()->Access(),
+ ast::Access::kReadWrite);
+ EXPECT_EQ(TypeOf(uniform)->As<sem::Reference>()->Access(),
+ ast::Access::kRead);
+ EXPECT_EQ(TypeOf(storage)->As<sem::Reference>()->Access(),
+ ast::Access::kRead);
+ EXPECT_EQ(TypeOf(handle)->As<sem::Reference>()->Access(), ast::Access::kRead);
+}
+
+TEST_F(ResolverVarLetTest, ExplicitVarStorageClass) {
+ // https://gpuweb.github.io/gpuweb/wgsl/#storage-class
+
+ auto* buf = Structure("S", {Member("m", ty.i32())},
+ {create<ast::StructBlockAttribute>()});
+ auto* storage = Global("sb", ty.Of(buf), ast::StorageClass::kStorage,
+ ast::Access::kReadWrite,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(1),
+ create<ast::GroupAttribute>(0),
+ });
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_TRUE(TypeOf(storage)->Is<sem::Reference>());
+
+ EXPECT_EQ(TypeOf(storage)->As<sem::Reference>()->Access(),
+ ast::Access::kReadWrite);
+}
+
+TEST_F(ResolverVarLetTest, LetInheritsAccessFromOriginatingVariable) {
+ // struct Inner {
+ // arr: array<i32, 4>;
+ // }
+ // [[block]] struct S {
+ // inner: Inner;
+ // }
+ // @group(0) @binding(0) var<storage, read_write> s : S;
+ // fn f() {
+ // let p = &s.inner.arr[2];
+ // }
+ auto* inner = Structure("Inner", {Member("arr", ty.array<i32, 4>())});
+ auto* buf = Structure("S", {Member("inner", ty.Of(inner))},
+ {create<ast::StructBlockAttribute>()});
+ auto* storage = Global("s", ty.Of(buf), ast::StorageClass::kStorage,
+ ast::Access::kReadWrite,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(0),
+ create<ast::GroupAttribute>(0),
+ });
+
+ auto* expr =
+ IndexAccessor(MemberAccessor(MemberAccessor(storage, "inner"), "arr"), 4);
+ auto* ptr = Const("p", nullptr, AddressOf(expr));
+
+ WrapInFunction(ptr);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_TRUE(TypeOf(expr)->Is<sem::Reference>());
+ ASSERT_TRUE(TypeOf(ptr)->Is<sem::Pointer>());
+
+ EXPECT_EQ(TypeOf(expr)->As<sem::Reference>()->Access(),
+ ast::Access::kReadWrite);
+ EXPECT_EQ(TypeOf(ptr)->As<sem::Pointer>()->Access(), ast::Access::kReadWrite);
+}
+
+TEST_F(ResolverVarLetTest, LocalShadowsAlias) {
+ // type a = i32;
+ //
+ // fn X() {
+ // var a = false;
+ // }
+ //
+ // fn Y() {
+ // let a = true;
+ // }
+
+ auto* t = Alias("a", ty.i32());
+ auto* v = Var("a", nullptr, Expr(false));
+ auto* l = Const("a", nullptr, Expr(false));
+ Func("X", {}, ty.void_(), {Decl(v)});
+ Func("Y", {}, ty.void_(), {Decl(l)});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* type_t = Sem().Get(t);
+ auto* local_v = Sem().Get<sem::LocalVariable>(v);
+ auto* local_l = Sem().Get<sem::LocalVariable>(l);
+
+ ASSERT_NE(local_v, nullptr);
+ ASSERT_NE(local_l, nullptr);
+
+ EXPECT_EQ(local_v->Shadows(), type_t);
+ EXPECT_EQ(local_l->Shadows(), type_t);
+}
+
+TEST_F(ResolverVarLetTest, LocalShadowsStruct) {
+ // struct a {
+ // m : i32;
+ // };
+ //
+ // fn X() {
+ // var a = true;
+ // }
+ //
+ // fn Y() {
+ // let a = false;
+ // }
+
+ auto* t = Structure("a", {Member("m", ty.i32())});
+ auto* v = Var("a", nullptr, Expr(false));
+ auto* l = Const("a", nullptr, Expr(false));
+ Func("X", {}, ty.void_(), {Decl(v)});
+ Func("Y", {}, ty.void_(), {Decl(l)});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* type_t = Sem().Get(t);
+ auto* local_v = Sem().Get<sem::LocalVariable>(v);
+ auto* local_l = Sem().Get<sem::LocalVariable>(l);
+
+ ASSERT_NE(local_v, nullptr);
+ ASSERT_NE(local_l, nullptr);
+
+ EXPECT_EQ(local_v->Shadows(), type_t);
+ EXPECT_EQ(local_l->Shadows(), type_t);
+}
+
+TEST_F(ResolverVarLetTest, LocalShadowsFunction) {
+ // fn a() {
+ // var a = true;
+ // }
+ //
+ // fn b() {
+ // let b = false;
+ // }
+
+ auto* v = Var("a", nullptr, Expr(false));
+ auto* l = Const("b", nullptr, Expr(false));
+ auto* fa = Func("a", {}, ty.void_(), {Decl(v)});
+ auto* fb = Func("b", {}, ty.void_(), {Decl(l)});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* local_v = Sem().Get<sem::LocalVariable>(v);
+ auto* local_l = Sem().Get<sem::LocalVariable>(l);
+ auto* func_a = Sem().Get(fa);
+ auto* func_b = Sem().Get(fb);
+
+ ASSERT_NE(local_v, nullptr);
+ ASSERT_NE(local_l, nullptr);
+ ASSERT_NE(func_a, nullptr);
+ ASSERT_NE(func_b, nullptr);
+
+ EXPECT_EQ(local_v->Shadows(), func_a);
+ EXPECT_EQ(local_l->Shadows(), func_b);
+}
+
+TEST_F(ResolverVarLetTest, LocalShadowsGlobalVar) {
+ // var<private> a : i32;
+ //
+ // fn X() {
+ // var a = a;
+ // }
+ //
+ // fn Y() {
+ // let a = a;
+ // }
+
+ auto* g = Global("a", ty.i32(), ast::StorageClass::kPrivate);
+ auto* v = Var("a", nullptr, Expr("a"));
+ auto* l = Const("a", nullptr, Expr("a"));
+ Func("X", {}, ty.void_(), {Decl(v)});
+ Func("Y", {}, ty.void_(), {Decl(l)});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* global = Sem().Get(g);
+ auto* local_v = Sem().Get<sem::LocalVariable>(v);
+ auto* local_l = Sem().Get<sem::LocalVariable>(l);
+
+ ASSERT_NE(local_v, nullptr);
+ ASSERT_NE(local_l, nullptr);
+
+ EXPECT_EQ(local_v->Shadows(), global);
+ EXPECT_EQ(local_l->Shadows(), global);
+
+ auto* user_v =
+ Sem().Get<sem::VariableUser>(local_v->Declaration()->constructor);
+ auto* user_l =
+ Sem().Get<sem::VariableUser>(local_l->Declaration()->constructor);
+
+ ASSERT_NE(user_v, nullptr);
+ ASSERT_NE(user_l, nullptr);
+
+ EXPECT_EQ(user_v->Variable(), global);
+ EXPECT_EQ(user_l->Variable(), global);
+}
+
+TEST_F(ResolverVarLetTest, LocalShadowsGlobalLet) {
+ // let a : i32 = 1;
+ //
+ // fn X() {
+ // var a = (a == 123);
+ // }
+ //
+ // fn Y() {
+ // let a = (a == 321);
+ // }
+
+ auto* g = GlobalConst("a", ty.i32(), Expr(1));
+ auto* v = Var("a", nullptr, Expr("a"));
+ auto* l = Const("a", nullptr, Expr("a"));
+ Func("X", {}, ty.void_(), {Decl(v)});
+ Func("Y", {}, ty.void_(), {Decl(l)});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* global = Sem().Get(g);
+ auto* local_v = Sem().Get<sem::LocalVariable>(v);
+ auto* local_l = Sem().Get<sem::LocalVariable>(l);
+
+ ASSERT_NE(local_v, nullptr);
+ ASSERT_NE(local_l, nullptr);
+
+ EXPECT_EQ(local_v->Shadows(), global);
+ EXPECT_EQ(local_l->Shadows(), global);
+
+ auto* user_v =
+ Sem().Get<sem::VariableUser>(local_v->Declaration()->constructor);
+ auto* user_l =
+ Sem().Get<sem::VariableUser>(local_l->Declaration()->constructor);
+
+ ASSERT_NE(user_v, nullptr);
+ ASSERT_NE(user_l, nullptr);
+
+ EXPECT_EQ(user_v->Variable(), global);
+ EXPECT_EQ(user_l->Variable(), global);
+}
+
+TEST_F(ResolverVarLetTest, LocalShadowsLocalVar) {
+ // fn X() {
+ // var a : i32;
+ // {
+ // var a = a;
+ // }
+ // {
+ // let a = a;
+ // }
+ // }
+
+ auto* s = Var("a", ty.i32(), Expr(1));
+ auto* v = Var("a", nullptr, Expr("a"));
+ auto* l = Const("a", nullptr, Expr("a"));
+ Func("X", {}, ty.void_(), {Decl(s), Block(Decl(v)), Block(Decl(l))});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* local_s = Sem().Get<sem::LocalVariable>(s);
+ auto* local_v = Sem().Get<sem::LocalVariable>(v);
+ auto* local_l = Sem().Get<sem::LocalVariable>(l);
+
+ ASSERT_NE(local_s, nullptr);
+ ASSERT_NE(local_v, nullptr);
+ ASSERT_NE(local_l, nullptr);
+
+ EXPECT_EQ(local_v->Shadows(), local_s);
+ EXPECT_EQ(local_l->Shadows(), local_s);
+
+ auto* user_v =
+ Sem().Get<sem::VariableUser>(local_v->Declaration()->constructor);
+ auto* user_l =
+ Sem().Get<sem::VariableUser>(local_l->Declaration()->constructor);
+
+ ASSERT_NE(user_v, nullptr);
+ ASSERT_NE(user_l, nullptr);
+
+ EXPECT_EQ(user_v->Variable(), local_s);
+ EXPECT_EQ(user_l->Variable(), local_s);
+}
+
+TEST_F(ResolverVarLetTest, LocalShadowsLocalLet) {
+ // fn X() {
+ // let a = 1;
+ // {
+ // var a = (a == 123);
+ // }
+ // {
+ // let a = (a == 321);
+ // }
+ // }
+
+ auto* s = Const("a", ty.i32(), Expr(1));
+ auto* v = Var("a", nullptr, Expr("a"));
+ auto* l = Const("a", nullptr, Expr("a"));
+ Func("X", {}, ty.void_(), {Decl(s), Block(Decl(v)), Block(Decl(l))});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* local_s = Sem().Get<sem::LocalVariable>(s);
+ auto* local_v = Sem().Get<sem::LocalVariable>(v);
+ auto* local_l = Sem().Get<sem::LocalVariable>(l);
+
+ ASSERT_NE(local_s, nullptr);
+ ASSERT_NE(local_v, nullptr);
+ ASSERT_NE(local_l, nullptr);
+
+ EXPECT_EQ(local_v->Shadows(), local_s);
+ EXPECT_EQ(local_l->Shadows(), local_s);
+
+ auto* user_v =
+ Sem().Get<sem::VariableUser>(local_v->Declaration()->constructor);
+ auto* user_l =
+ Sem().Get<sem::VariableUser>(local_l->Declaration()->constructor);
+
+ ASSERT_NE(user_v, nullptr);
+ ASSERT_NE(user_l, nullptr);
+
+ EXPECT_EQ(user_v->Variable(), local_s);
+ EXPECT_EQ(user_l->Variable(), local_s);
+}
+
+TEST_F(ResolverVarLetTest, LocalShadowsParam) {
+ // fn F(a : i32) {
+ // {
+ // var a = a;
+ // }
+ // {
+ // let a = a;
+ // }
+ // }
+
+ auto* p = Param("a", ty.i32());
+ auto* v = Var("a", nullptr, Expr("a"));
+ auto* l = Const("a", nullptr, Expr("a"));
+ Func("X", {p}, ty.void_(), {Block(Decl(v)), Block(Decl(l))});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* param = Sem().Get<sem::Parameter>(p);
+ auto* local_v = Sem().Get<sem::LocalVariable>(v);
+ auto* local_l = Sem().Get<sem::LocalVariable>(l);
+
+ ASSERT_NE(param, nullptr);
+ ASSERT_NE(local_v, nullptr);
+ ASSERT_NE(local_l, nullptr);
+
+ EXPECT_EQ(local_v->Shadows(), param);
+ EXPECT_EQ(local_l->Shadows(), param);
+
+ auto* user_v =
+ Sem().Get<sem::VariableUser>(local_v->Declaration()->constructor);
+ auto* user_l =
+ Sem().Get<sem::VariableUser>(local_l->Declaration()->constructor);
+
+ ASSERT_NE(user_v, nullptr);
+ ASSERT_NE(user_l, nullptr);
+
+ EXPECT_EQ(user_v->Variable(), param);
+ EXPECT_EQ(user_l->Variable(), param);
+}
+
+TEST_F(ResolverVarLetTest, ParamShadowsFunction) {
+ // fn a(a : bool) {
+ // }
+
+ auto* p = Param("a", ty.bool_());
+ auto* f = Func("a", {p}, ty.void_(), {});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* func = Sem().Get(f);
+ auto* param = Sem().Get<sem::Parameter>(p);
+
+ ASSERT_NE(func, nullptr);
+ ASSERT_NE(param, nullptr);
+
+ EXPECT_EQ(param->Shadows(), func);
+}
+
+TEST_F(ResolverVarLetTest, ParamShadowsGlobalVar) {
+ // var<private> a : i32;
+ //
+ // fn F(a : bool) {
+ // }
+
+ auto* g = Global("a", ty.i32(), ast::StorageClass::kPrivate);
+ auto* p = Param("a", ty.bool_());
+ Func("F", {p}, ty.void_(), {});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* global = Sem().Get(g);
+ auto* param = Sem().Get<sem::Parameter>(p);
+
+ ASSERT_NE(global, nullptr);
+ ASSERT_NE(param, nullptr);
+
+ EXPECT_EQ(param->Shadows(), global);
+}
+
+TEST_F(ResolverVarLetTest, ParamShadowsGlobalLet) {
+ // let a : i32 = 1;
+ //
+ // fn F(a : bool) {
+ // }
+
+ auto* g = GlobalConst("a", ty.i32(), Expr(1));
+ auto* p = Param("a", ty.bool_());
+ Func("F", {p}, ty.void_(), {});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* global = Sem().Get(g);
+ auto* param = Sem().Get<sem::Parameter>(p);
+
+ ASSERT_NE(global, nullptr);
+ ASSERT_NE(param, nullptr);
+
+ EXPECT_EQ(param->Shadows(), global);
+}
+
+TEST_F(ResolverVarLetTest, ParamShadowsAlias) {
+ // type a = i32;
+ //
+ // fn F(a : a) {
+ // }
+
+ auto* a = Alias("a", ty.i32());
+ auto* p = Param("a", ty.type_name("a"));
+ Func("F", {p}, ty.void_(), {});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* alias = Sem().Get(a);
+ auto* param = Sem().Get<sem::Parameter>(p);
+
+ ASSERT_NE(alias, nullptr);
+ ASSERT_NE(param, nullptr);
+
+ EXPECT_EQ(param->Shadows(), alias);
+ EXPECT_EQ(param->Type(), alias);
+}
+
+} // namespace
+} // namespace resolver
+} // namespace tint
diff --git a/src/tint/resolver/var_let_validation_test.cc b/src/tint/resolver/var_let_validation_test.cc
new file mode 100644
index 0000000..fbb570e
--- /dev/null
+++ b/src/tint/resolver/var_let_validation_test.cc
@@ -0,0 +1,352 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ast/struct_block_attribute.h"
+#include "src/tint/resolver/resolver.h"
+#include "src/tint/resolver/resolver_test_helper.h"
+
+#include "gmock/gmock.h"
+
+namespace tint {
+namespace resolver {
+namespace {
+
+struct ResolverVarLetValidationTest : public resolver::TestHelper,
+ public testing::Test {};
+
+TEST_F(ResolverVarLetValidationTest, LetNoInitializer) {
+ // let a : i32;
+ WrapInFunction(Const(Source{{12, 34}}, "a", ty.i32(), nullptr));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: let declaration must have an initializer");
+}
+
+TEST_F(ResolverVarLetValidationTest, GlobalLetNoInitializer) {
+ // let a : i32;
+ GlobalConst(Source{{12, 34}}, "a", ty.i32(), nullptr);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: let declaration must have an initializer");
+}
+
+TEST_F(ResolverVarLetValidationTest, VarNoInitializerNoType) {
+ // var a;
+ WrapInFunction(Var(Source{{12, 34}}, "a", nullptr));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: function scope var declaration requires a type or "
+ "initializer");
+}
+
+TEST_F(ResolverVarLetValidationTest, GlobalVarNoInitializerNoType) {
+ // var a;
+ Global(Source{{12, 34}}, "a", nullptr);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: module scope var declaration requires a type and "
+ "initializer");
+}
+
+TEST_F(ResolverVarLetValidationTest, VarTypeNotStorable) {
+ // var i : i32;
+ // var p : pointer<function, i32> = &v;
+ auto* i = Var("i", ty.i32(), ast::StorageClass::kNone);
+ auto* p =
+ Var(Source{{56, 78}}, "a", ty.pointer<i32>(ast::StorageClass::kFunction),
+ ast::StorageClass::kNone, AddressOf(Source{{12, 34}}, "i"));
+ WrapInFunction(i, p);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "56:78 error: ptr<function, i32, read_write> cannot be used as the "
+ "type of a var");
+}
+
+TEST_F(ResolverVarLetValidationTest, LetTypeNotConstructible) {
+ // @group(0) @binding(0) var t1 : texture_2d<f32>;
+ // let t2 : t1;
+ auto* t1 =
+ Global("t1", ty.sampled_texture(ast::TextureDimension::k2d, ty.f32()),
+ GroupAndBinding(0, 0));
+ auto* t2 = Const(Source{{56, 78}}, "t2", nullptr, Expr(t1));
+ WrapInFunction(t2);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "56:78 error: texture_2d<f32> cannot be used as the type of a let");
+}
+
+TEST_F(ResolverVarLetValidationTest, LetConstructorWrongType) {
+ // var v : i32 = 2u
+ WrapInFunction(Const(Source{{3, 3}}, "v", ty.i32(), Expr(2u)));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(3:3 error: cannot initialize let of type 'i32' with value of type 'u32')");
+}
+
+TEST_F(ResolverVarLetValidationTest, VarConstructorWrongType) {
+ // var v : i32 = 2u
+ WrapInFunction(
+ Var(Source{{3, 3}}, "v", ty.i32(), ast::StorageClass::kNone, Expr(2u)));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(3:3 error: cannot initialize var of type 'i32' with value of type 'u32')");
+}
+
+TEST_F(ResolverVarLetValidationTest, LetConstructorWrongTypeViaAlias) {
+ auto* a = Alias("I32", ty.i32());
+ WrapInFunction(Const(Source{{3, 3}}, "v", ty.Of(a), Expr(2u)));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(3:3 error: cannot initialize let of type 'i32' with value of type 'u32')");
+}
+
+TEST_F(ResolverVarLetValidationTest, VarConstructorWrongTypeViaAlias) {
+ auto* a = Alias("I32", ty.i32());
+ WrapInFunction(
+ Var(Source{{3, 3}}, "v", ty.Of(a), ast::StorageClass::kNone, Expr(2u)));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(3:3 error: cannot initialize var of type 'i32' with value of type 'u32')");
+}
+
+TEST_F(ResolverVarLetValidationTest, LetOfPtrConstructedWithRef) {
+ // var a : f32;
+ // let b : ptr<function,f32> = a;
+ const auto priv = ast::StorageClass::kFunction;
+ auto* var_a = Var("a", ty.f32(), priv);
+ auto* var_b =
+ Const(Source{{12, 34}}, "b", ty.pointer<float>(priv), Expr("a"), {});
+ WrapInFunction(var_a, var_b);
+
+ ASSERT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: cannot initialize let of type 'ptr<function, f32, read_write>' with value of type 'f32')");
+}
+
+TEST_F(ResolverVarLetValidationTest, LocalLetRedeclared) {
+ // let l : f32 = 1.;
+ // let l : i32 = 0;
+ auto* l1 = Const("l", ty.f32(), Expr(1.f));
+ auto* l2 = Const(Source{{12, 34}}, "l", ty.i32(), Expr(0));
+ WrapInFunction(l1, l2);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: redeclaration of 'l'\nnote: 'l' previously declared here");
+}
+
+TEST_F(ResolverVarLetValidationTest, GlobalVarRedeclaredAsLocal) {
+ // var v : f32 = 2.1;
+ // fn my_func() {
+ // var v : f32 = 2.0;
+ // return 0;
+ // }
+
+ Global("v", ty.f32(), ast::StorageClass::kPrivate, Expr(2.1f));
+
+ WrapInFunction(Var(Source{{12, 34}}, "v", ty.f32(), ast::StorageClass::kNone,
+ Expr(2.0f)));
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverVarLetValidationTest, VarRedeclaredInInnerBlock) {
+ // {
+ // var v : f32;
+ // { var v : f32; }
+ // }
+ auto* var_outer = Var("v", ty.f32(), ast::StorageClass::kNone);
+ auto* var_inner =
+ Var(Source{{12, 34}}, "v", ty.f32(), ast::StorageClass::kNone);
+ auto* inner = Block(Decl(var_inner));
+ auto* outer_body = Block(Decl(var_outer), inner);
+
+ WrapInFunction(outer_body);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverVarLetValidationTest, VarRedeclaredInIfBlock) {
+ // {
+ // var v : f32 = 3.14;
+ // if (true) { var v : f32 = 2.0; }
+ // }
+ auto* var_a_float = Var("v", ty.f32(), ast::StorageClass::kNone, Expr(3.1f));
+
+ auto* var = Var(Source{{12, 34}}, "v", ty.f32(), ast::StorageClass::kNone,
+ Expr(2.0f));
+
+ auto* cond = Expr(true);
+ auto* body = Block(Decl(var));
+
+ auto* outer_body =
+ Block(Decl(var_a_float),
+ create<ast::IfStatement>(cond, body, ast::ElseStatementList{}));
+
+ WrapInFunction(outer_body);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverVarLetValidationTest, InferredPtrStorageAccessMismatch) {
+ // struct Inner {
+ // arr: array<i32, 4>;
+ // }
+ // [[block]] struct S {
+ // inner: Inner;
+ // }
+ // @group(0) @binding(0) var<storage> s : S;
+ // fn f() {
+ // let p : pointer<storage, i32, read_write> = &s.inner.arr[2];
+ // }
+ auto* inner = Structure("Inner", {Member("arr", ty.array<i32, 4>())});
+ auto* buf = Structure("S", {Member("inner", ty.Of(inner))},
+ {create<ast::StructBlockAttribute>()});
+ auto* storage = Global("s", ty.Of(buf), ast::StorageClass::kStorage,
+ ast::AttributeList{
+ create<ast::BindingAttribute>(0),
+ create<ast::GroupAttribute>(0),
+ });
+
+ auto* expr =
+ IndexAccessor(MemberAccessor(MemberAccessor(storage, "inner"), "arr"), 4);
+ auto* ptr = Const(
+ Source{{12, 34}}, "p",
+ ty.pointer<i32>(ast::StorageClass::kStorage, ast::Access::kReadWrite),
+ AddressOf(expr));
+
+ WrapInFunction(ptr);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: cannot initialize let of type "
+ "'ptr<storage, i32, read_write>' with value of type "
+ "'ptr<storage, i32, read>'");
+}
+
+TEST_F(ResolverVarLetValidationTest, NonConstructibleType_Atomic) {
+ auto* v = Var("v", ty.atomic(Source{{12, 34}}, ty.i32()));
+ WrapInFunction(v);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: function variable must have a constructible type");
+}
+
+TEST_F(ResolverVarLetValidationTest, NonConstructibleType_RuntimeArray) {
+ auto* s = Structure("S", {Member(Source{{56, 78}}, "m", ty.array(ty.i32()))},
+ {StructBlock()});
+ auto* v = Var(Source{{12, 34}}, "v", ty.Of(s));
+ WrapInFunction(v);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: runtime-sized arrays can only be used in the <storage> storage class
+56:78 note: while analysing structure member S.m
+12:34 note: while instantiating variable v)");
+}
+
+TEST_F(ResolverVarLetValidationTest, NonConstructibleType_Struct_WithAtomic) {
+ auto* s = Structure("S", {Member("m", ty.atomic(ty.i32()))});
+ auto* v = Var("v", ty.Of(s));
+ WrapInFunction(v);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "error: function variable must have a constructible type");
+}
+
+TEST_F(ResolverVarLetValidationTest, NonConstructibleType_InferredType) {
+ // @group(0) @binding(0) var s : sampler;
+ // fn foo() {
+ // var v = s;
+ // }
+ Global("s", ty.sampler(ast::SamplerKind::kSampler), GroupAndBinding(0, 0));
+ auto* v = Var(Source{{12, 34}}, "v", nullptr, Expr("s"));
+ WrapInFunction(v);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: function variable must have a constructible type");
+}
+
+TEST_F(ResolverVarLetValidationTest, InvalidStorageClassForInitializer) {
+ // var<workgroup> v : f32 = 1.23;
+ Global(Source{{12, 34}}, "v", ty.f32(), ast::StorageClass::kWorkgroup,
+ Expr(1.23f));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: var of storage class 'workgroup' cannot have "
+ "an initializer. var initializers are only supported for the "
+ "storage classes 'private' and 'function'");
+}
+
+TEST_F(ResolverVarLetValidationTest, VectorLetNoType) {
+ // let a : mat3x3 = mat3x3<f32>();
+ WrapInFunction(Const("a", create<ast::Vector>(Source{{12, 34}}, nullptr, 3),
+ vec3<f32>()));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: missing vector element type");
+}
+
+TEST_F(ResolverVarLetValidationTest, VectorVarNoType) {
+ // var a : mat3x3;
+ WrapInFunction(Var("a", create<ast::Vector>(Source{{12, 34}}, nullptr, 3)));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: missing vector element type");
+}
+
+TEST_F(ResolverVarLetValidationTest, MatrixLetNoType) {
+ // let a : mat3x3 = mat3x3<f32>();
+ WrapInFunction(Const("a",
+ create<ast::Matrix>(Source{{12, 34}}, nullptr, 3, 3),
+ mat3x3<f32>()));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: missing matrix element type");
+}
+
+TEST_F(ResolverVarLetValidationTest, MatrixVarNoType) {
+ // var a : mat3x3;
+ WrapInFunction(
+ Var("a", create<ast::Matrix>(Source{{12, 34}}, nullptr, 3, 3)));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: missing matrix element type");
+}
+
+} // namespace
+} // namespace resolver
+} // namespace tint