tint: Implement constant expressions for structures
Bug: tint:1611
Change-Id: Id04c31ade297a68e7e2941efafbd812ba631fc41
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/95946
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index 9891f29..12fca5a 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -1943,13 +1943,13 @@
sem::Expression* Resolver::MemberAccessor(const ast::MemberAccessorExpression* expr) {
auto* structure = sem_.TypeOf(expr->structure);
auto* storage_ty = structure->UnwrapRef();
- auto* source_var = sem_.Get(expr->structure)->SourceVariable();
+ auto* object = sem_.Get(expr->structure); // assumed non-null: dereferenced unconditionally below
+ auto* source_var = object->SourceVariable();
const sem::Type* ret = nullptr;
std::vector<uint32_t> swizzle;
// Object may be a side-effecting expression (e.g. function call).
- auto* object = sem_.Get(expr->structure);
- bool has_side_effects = object && object->HasSideEffects();
+ bool has_side_effects = object->HasSideEffects();
if (auto* str = storage_ty->As<sem::Struct>()) {
@@ -1976,7 +1976,7 @@
ret = builder_->create<sem::Reference>(ret, ref->StorageClass(), ref->Access());
}
- sem::Constant* val = nullptr; // TODO(crbug.com/tint/1611): Add structure support.
+ auto* val = EvaluateMemberAccessValue(object, member); // nullptr unless object is constant
return builder_->create<sem::StructMemberAccess>(expr, ret, current_statement_, val, object,
member, has_side_effects, source_var);
}
diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h
index 9a9811f..20f487c 100644
--- a/src/tint/resolver/resolver.h
+++ b/src/tint/resolver/resolver.h
@@ -66,6 +66,7 @@
class IfStatement;
class LoopStatement;
class Statement;
+class StructMember; // used by EvaluateMemberAccessValue()
class SwitchStatement;
class TypeConstructor;
class WhileStatement;
@@ -218,6 +219,8 @@
const sem::Type* ty); // Note: ty is not an array or structure
const sem::Constant* EvaluateIndexValue(const sem::Expression* obj, const sem::Expression* idx);
const sem::Constant* EvaluateLiteralValue(const ast::LiteralExpression*, const sem::Type*);
+ const sem::Constant* EvaluateMemberAccessValue(const sem::Expression* obj,
+ const sem::StructMember* member); // nullptr if obj has no constant value
const sem::Constant* EvaluateSwizzleValue(const sem::Expression* vector,
const sem::Type* type,
const std::vector<uint32_t>& indices);
diff --git a/src/tint/resolver/resolver_constants.cc b/src/tint/resolver/resolver_constants.cc
index 798d116..c5798be 100644
--- a/src/tint/resolver/resolver_constants.cc
+++ b/src/tint/resolver/resolver_constants.cc
@@ -22,6 +22,7 @@
#include "src/tint/sem/member_accessor_expression.h"
#include "src/tint/sem/type_constructor.h"
#include "src/tint/utils/compiler_macros.h"
+#include "src/tint/utils/map.h"
#include "src/tint/utils/transform.h"
using namespace tint::number_suffixes; // NOLINT
@@ -277,6 +278,24 @@
}
return nullptr;
},
+ [&](const sem::Struct* s) -> const Constant* { // zero value of a structure, member by member
+ std::unordered_map<sem::Type*, const Constant*> zero_by_type; // one zero per member type
+ std::vector<const Constant*> zeros;
+ zeros.reserve(s->Members().size());
+ for (auto* member : s->Members()) {
+ auto* zero = utils::GetOrCreate(zero_by_type, member->Type(),
+ [&] { return ZeroValue(builder, member->Type()); });
+ if (!zero) { // member type has no constant zero value
+ return nullptr;
+ }
+ zeros.emplace_back(zero);
+ }
+ if (zero_by_type.size() == 1) {
+ // All members were of the same type, so the zero value is the same for all members.
+ return builder.create<Splat>(type, zeros[0], s->Members().size());
+ }
+ return CreateComposite(builder, s, std::move(zeros)); // members of differing types
+ },
[&](Default) -> const Constant* {
return TypeDispatch(type, [&](auto zero) -> const Constant* {
return CreateElement(builder, type, zero);
@@ -335,6 +354,9 @@
bool all_equal = true;
auto* first = elements.front();
for (auto* el : elements) {
+ if (!el) { // a non-constant element makes the whole composite non-constant
+ return nullptr;
+ }
if (!any_zero && el->AnyZero()) {
any_zero = true;
}
@@ -395,13 +417,7 @@
return ZeroValue(*builder_, ty);
}
- uint32_t el_count = 0;
- auto* el_ty = sem::Type::ElementOf(ty, &el_count);
- if (!el_ty) {
- return nullptr; // Target type does not support constant values
- }
-
- if (args.size() == 1) {
+ if (auto* el_ty = sem::Type::ElementOf(ty); el_ty && args.size() == 1) {
// Type constructor or conversion that takes a single argument.
auto& src = args[0]->Declaration()->source;
auto* arg = static_cast<const Constant*>(args[0]->ConstantValue());
@@ -431,33 +447,25 @@
return nullptr;
}
- // Multiple arguments. Must be a type constructor.
-
- std::vector<const Constant*> els; // The constant elements for the composite constant.
- els.reserve(std::min<uint32_t>(el_count, 256u)); // min() as el_count is unbounded input
-
- // Helper for pushing all the argument constants to `els`.
- auto push_all_args = [&] {
- for (auto* expr : args) {
- auto* arg = static_cast<const Constant*>(expr->ConstantValue());
- if (!arg) {
- return;
- }
- els.emplace_back(arg);
- }
+ // Helper returning the constant value of each argument (nullptr where non-constant).
+ auto args_as_constants = [&] {
+ return utils::Transform(
+ args, [&](auto* expr) { return static_cast<const Constant*>(expr->ConstantValue()); });
};
- // TODO(crbug.com/tint/1611): Add structure support.
+ // Multiple arguments. Must be a type constructor.
- Switch(
+ return Switch(
ty, // What's the target type being constructed?
- [&](const sem::Vector*) {
+ [&](const sem::Vector* vec) -> const Constant* {
// Vector can be constructed with a mix of scalars / abstract numerics and smaller
// vectors.
+ std::vector<const Constant*> els;
+ els.reserve(vec->Width()); // vector arguments are flattened into scalar elements
for (auto* expr : args) {
auto* arg = static_cast<const Constant*>(expr->ConstantValue());
if (!arg) {
- return;
+ return nullptr;
}
auto* arg_ty = arg->Type();
if (auto* arg_vec = arg_ty->As<sem::Vector>()) {
@@ -465,7 +473,7 @@
for (uint32_t i = 0; i < arg_vec->Width(); i++) {
auto* el = static_cast<const Constant*>(arg->Index(i));
if (!el) {
- return;
+ return nullptr;
}
els.emplace_back(el);
}
@@ -473,12 +481,15 @@
els.emplace_back(arg);
}
}
+ return CreateComposite(*builder_, ty, std::move(els));
},
- [&](const sem::Matrix* m) {
+ [&](const sem::Matrix* m) -> const Constant* {
// Matrix can be constructed with a set of scalars / abstract numerics, or column
// vectors.
if (args.size() == m->columns() * m->rows()) {
// Matrix built from scalars / abstract numerics
+ std::vector<const Constant*> els;
+ els.reserve(m->columns()); // one element per column vector
for (uint32_t c = 0; c < m->columns(); c++) {
std::vector<const Constant*> column;
column.reserve(m->rows());
@@ -486,28 +497,25 @@
auto* arg =
static_cast<const Constant*>(args[r + c * m->rows()]->ConstantValue());
if (!arg) {
- return;
+ return nullptr;
}
column.emplace_back(arg);
}
els.push_back(CreateComposite(*builder_, m->ColumnType(), std::move(column)));
}
- } else if (args.size() == m->columns()) {
- // Matrix built from column vectors
- push_all_args();
+ return CreateComposite(*builder_, ty, std::move(els));
}
+ // Matrix built from column vectors
+ return CreateComposite(*builder_, ty, args_as_constants());
},
[&](const sem::Array*) {
// Arrays must be constructed using a list of elements
- push_all_args();
+ return CreateComposite(*builder_, ty, args_as_constants());
+ },
+ [&](const sem::Struct*) {
+ // Structures must be constructed using a list of elements
+ return CreateComposite(*builder_, ty, args_as_constants());
});
-
- if (els.size() != el_count) {
- // If the number of constant elements doesn't match the type, then something went wrong.
- return nullptr;
- }
- // Construct and return either a Composite or Splat.
- return CreateComposite(*builder_, ty, std::move(els));
}
const sem::Constant* Resolver::EvaluateIndexValue(const sem::Expression* obj_expr,
@@ -538,6 +546,15 @@
return obj_val->Index(static_cast<size_t>(idx));
}
+const sem::Constant* Resolver::EvaluateMemberAccessValue(const sem::Expression* obj_expr,
+ const sem::StructMember* member) {
+ auto* obj_val = obj_expr->ConstantValue();
+ if (!obj_val) {
+ return nullptr; // The structure expression has no constant value
+ }
+ return obj_val->Index(static_cast<size_t>(member->Index())); // Indexed by member position
+}
+
const sem::Constant* Resolver::EvaluateSwizzleValue(const sem::Expression* vec_expr,
const sem::Type* type,
const std::vector<uint32_t>& indices) {
@@ -546,7 +563,7 @@
return nullptr;
}
if (indices.size() == 1) {
- return static_cast<const Constant*>(vec_val->Index(indices[0]));
+ return static_cast<const Constant*>(vec_val->Index(static_cast<size_t>(indices[0]))); // one element
} else {
auto values = utils::Transform(
indices, [&](uint32_t i) { return static_cast<const Constant*>(vec_val->Index(i)); });
diff --git a/src/tint/resolver/resolver_constants_test.cc b/src/tint/resolver/resolver_constants_test.cc
index b8ef1b1..1774a16 100644
--- a/src/tint/resolver/resolver_constants_test.cc
+++ b/src/tint/resolver/resolver_constants_test.cc
@@ -84,6 +84,7 @@
TEST_F(ResolverConstantsTest, Scalar_f16) {
Enable(ast::Extension::kF16);
+
auto* expr = Expr(9.9_h);
WrapInFunction(expr);
@@ -217,6 +218,7 @@
TEST_F(ResolverConstantsTest, Vec3_ZeroInit_f16) {
Enable(ast::Extension::kF16);
+
auto* expr = vec3<f16>();
WrapInFunction(expr);
@@ -383,6 +385,7 @@
TEST_F(ResolverConstantsTest, Vec3_Splat_f16) {
Enable(ast::Extension::kF16);
+
auto* expr = vec3<f16>(9.9_h);
WrapInFunction(expr);
@@ -550,6 +553,7 @@
TEST_F(ResolverConstantsTest, Vec3_FullConstruct_f16) {
Enable(ast::Extension::kF16);
+
auto* expr = vec3<f16>(1_h, 2_h, 3_h);
WrapInFunction(expr);
@@ -848,6 +852,7 @@
TEST_F(ResolverConstantsTest, Vec3_MixConstruct_f16) {
Enable(ast::Extension::kF16);
+
auto* expr = vec3<f16>(1_h, vec2<f16>(2_h, 3_h));
WrapInFunction(expr);
@@ -882,6 +887,7 @@
TEST_F(ResolverConstantsTest, Vec3_MixConstruct_f16_all_10) {
Enable(ast::Extension::kF16);
+
auto* expr = vec3<f16>(10_h, vec2<f16>(10_h, 10_h));
WrapInFunction(expr);
@@ -916,6 +922,7 @@
TEST_F(ResolverConstantsTest, Vec3_MixConstruct_f16_all_positive_0) {
Enable(ast::Extension::kF16);
+
auto* expr = vec3<f16>(0_h, vec2<f16>(0_h, 0_h));
WrapInFunction(expr);
@@ -950,6 +957,7 @@
TEST_F(ResolverConstantsTest, Vec3_MixConstruct_f16_all_negative_0) {
Enable(ast::Extension::kF16);
+
auto* expr = vec3<f16>(vec2<f16>(-0_h, -0_h), -0_h);
WrapInFunction(expr);
@@ -984,6 +992,7 @@
TEST_F(ResolverConstantsTest, Vec3_MixConstruct_f16_mixed_sign_0) {
Enable(ast::Extension::kF16);
+
auto* expr = vec3<f16>(0_h, vec2<f16>(-0_h, 0_h));
WrapInFunction(expr);
@@ -1183,6 +1192,7 @@
TEST_F(ResolverConstantsTest, Vec3_Convert_f16_to_i32) {
Enable(ast::Extension::kF16);
+
auto* expr = vec3<i32>(vec3<f16>(1.1_h, 2.2_h, 3.3_h));
WrapInFunction(expr);
@@ -1217,6 +1227,7 @@
TEST_F(ResolverConstantsTest, Vec3_Convert_u32_to_f16) {
Enable(ast::Extension::kF16);
+
auto* expr = vec3<f16>(vec3<u32>(10_u, 20_u, 30_u));
WrapInFunction(expr);
@@ -1715,6 +1726,48 @@
EXPECT_EQ(sem->ConstantValue()->Index(1)->Index(2)->As<f32>(), 0_f);
}
+TEST_F(ResolverConstantsTest, Array_Struct_f32_Zero) {
+ Structure("S", {
+ Member("m1", ty.f32()),
+ Member("m2", ty.f32()),
+ });
+ auto* expr = Construct(ty.array(ty.type_name("S"), 2_u)); // array<S, 2u>()
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ auto* arr = sem->Type()->As<sem::Array>();
+ ASSERT_NE(arr, nullptr);
+ EXPECT_TRUE(arr->ElemType()->Is<sem::Struct>());
+ EXPECT_EQ(arr->Count(), 2u);
+ EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
+ EXPECT_TRUE(sem->ConstantValue()->AllEqual());
+ EXPECT_TRUE(sem->ConstantValue()->AnyZero());
+ EXPECT_TRUE(sem->ConstantValue()->AllZero());
+
+ EXPECT_TRUE(sem->ConstantValue()->Index(0)->Index(0)->AllEqual());
+ EXPECT_TRUE(sem->ConstantValue()->Index(0)->Index(0)->AnyZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(0)->Index(0)->AllZero());
+ EXPECT_EQ(sem->ConstantValue()->Index(0)->Index(0)->As<f32>(), 0_f);
+
+ EXPECT_TRUE(sem->ConstantValue()->Index(0)->Index(1)->AllEqual());
+ EXPECT_TRUE(sem->ConstantValue()->Index(0)->Index(1)->AnyZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(0)->Index(1)->AllZero());
+ EXPECT_EQ(sem->ConstantValue()->Index(0)->Index(1)->As<f32>(), 0_f);
+
+ EXPECT_TRUE(sem->ConstantValue()->Index(1)->Index(0)->AllEqual());
+ EXPECT_TRUE(sem->ConstantValue()->Index(1)->Index(0)->AnyZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(1)->Index(0)->AllZero());
+ EXPECT_EQ(sem->ConstantValue()->Index(1)->Index(0)->As<f32>(), 0_f);
+
+ EXPECT_TRUE(sem->ConstantValue()->Index(1)->Index(1)->AllEqual());
+ EXPECT_TRUE(sem->ConstantValue()->Index(1)->Index(1)->AnyZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(1)->Index(1)->AllZero());
+ EXPECT_EQ(sem->ConstantValue()->Index(1)->Index(1)->As<f32>(), 0_f);
+}
+
TEST_F(ResolverConstantsTest, Array_i32_Elements) {
auto* expr = Construct(ty.array<i32, 4>(), 10_i, 20_i, 30_i, 40_i);
WrapInFunction(expr);
@@ -1816,6 +1869,523 @@
EXPECT_EQ(sem->ConstantValue()->Index(1)->Index(2)->As<f32>(), 6_f);
}
+TEST_F(ResolverConstantsTest, Array_Struct_f32_Elements) {
+ Structure("S", {
+ Member("m1", ty.f32()),
+ Member("m2", ty.f32()),
+ });
+ auto* expr = Construct(ty.array(ty.type_name("S"), 2_u), // array<S, 2u>(S(1f, 2f), S(3f, 4f))
+ Construct(ty.type_name("S"), 1_f, 2_f), //
+ Construct(ty.type_name("S"), 3_f, 4_f));
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ auto* arr = sem->Type()->As<sem::Array>();
+ ASSERT_NE(arr, nullptr);
+ EXPECT_TRUE(arr->ElemType()->Is<sem::Struct>());
+ EXPECT_EQ(arr->Count(), 2u);
+ EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
+ EXPECT_FALSE(sem->ConstantValue()->AllEqual());
+ EXPECT_FALSE(sem->ConstantValue()->AnyZero());
+ EXPECT_FALSE(sem->ConstantValue()->AllZero());
+
+ EXPECT_TRUE(sem->ConstantValue()->Index(0)->Index(0)->AllEqual());
+ EXPECT_FALSE(sem->ConstantValue()->Index(0)->Index(0)->AnyZero());
+ EXPECT_FALSE(sem->ConstantValue()->Index(0)->Index(0)->AllZero());
+ EXPECT_EQ(sem->ConstantValue()->Index(0)->Index(0)->As<f32>(), 1_f);
+
+ EXPECT_TRUE(sem->ConstantValue()->Index(0)->Index(1)->AllEqual());
+ EXPECT_FALSE(sem->ConstantValue()->Index(0)->Index(1)->AnyZero());
+ EXPECT_FALSE(sem->ConstantValue()->Index(0)->Index(1)->AllZero());
+ EXPECT_EQ(sem->ConstantValue()->Index(0)->Index(1)->As<f32>(), 2_f);
+
+ EXPECT_TRUE(sem->ConstantValue()->Index(1)->Index(0)->AllEqual());
+ EXPECT_FALSE(sem->ConstantValue()->Index(1)->Index(0)->AnyZero());
+ EXPECT_FALSE(sem->ConstantValue()->Index(1)->Index(0)->AllZero());
+ EXPECT_EQ(sem->ConstantValue()->Index(1)->Index(0)->As<f32>(), 3_f);
+
+ EXPECT_TRUE(sem->ConstantValue()->Index(1)->Index(1)->AllEqual());
+ EXPECT_FALSE(sem->ConstantValue()->Index(1)->Index(1)->AnyZero());
+ EXPECT_FALSE(sem->ConstantValue()->Index(1)->Index(1)->AllZero());
+ EXPECT_EQ(sem->ConstantValue()->Index(1)->Index(1)->As<f32>(), 4_f);
+}
+
+TEST_F(ResolverConstantsTest, Struct_I32s_ZeroInit) {
+ Structure("S", {Member("m1", ty.i32()), Member("m2", ty.i32()), Member("m3", ty.i32())});
+ auto* expr = Construct(ty.type_name("S")); // S()
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ auto* str = sem->Type()->As<sem::Struct>();
+ ASSERT_NE(str, nullptr);
+ EXPECT_EQ(str->Members().size(), 3u);
+ ASSERT_NE(sem->ConstantValue(), nullptr);
+ EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
+ EXPECT_TRUE(sem->ConstantValue()->AllEqual());
+ EXPECT_TRUE(sem->ConstantValue()->AnyZero());
+ EXPECT_TRUE(sem->ConstantValue()->AllZero());
+
+ EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllEqual());
+ EXPECT_TRUE(sem->ConstantValue()->Index(0)->AnyZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(0)->Type()->Is<sem::I32>());
+ EXPECT_EQ(sem->ConstantValue()->Index(0)->As<i32>(), 0_i);
+
+ EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllEqual());
+ EXPECT_TRUE(sem->ConstantValue()->Index(1)->AnyZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(1)->Type()->Is<sem::I32>());
+ EXPECT_EQ(sem->ConstantValue()->Index(1)->As<i32>(), 0_i);
+
+ EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllEqual());
+ EXPECT_TRUE(sem->ConstantValue()->Index(2)->AnyZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(2)->Type()->Is<sem::I32>());
+ EXPECT_EQ(sem->ConstantValue()->Index(2)->As<i32>(), 0_i);
+}
+
+TEST_F(ResolverConstantsTest, Struct_MixedScalars_ZeroInit) {
+ Enable(ast::Extension::kF16);
+
+ Structure("S", {
+ Member("m1", ty.i32()),
+ Member("m2", ty.u32()),
+ Member("m3", ty.f32()),
+ Member("m4", ty.f16()),
+ Member("m5", ty.bool_()),
+ });
+ auto* expr = Construct(ty.type_name("S")); // S()
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ auto* str = sem->Type()->As<sem::Struct>();
+ ASSERT_NE(str, nullptr);
+ EXPECT_EQ(str->Members().size(), 5u);
+ ASSERT_NE(sem->ConstantValue(), nullptr);
+ EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
+ EXPECT_FALSE(sem->ConstantValue()->AllEqual());
+ EXPECT_TRUE(sem->ConstantValue()->AnyZero());
+ EXPECT_TRUE(sem->ConstantValue()->AllZero());
+
+ EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllEqual());
+ EXPECT_TRUE(sem->ConstantValue()->Index(0)->AnyZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(0)->Type()->Is<sem::I32>());
+ EXPECT_EQ(sem->ConstantValue()->Index(0)->As<i32>(), 0_i);
+
+ EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllEqual());
+ EXPECT_TRUE(sem->ConstantValue()->Index(1)->AnyZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(1)->Type()->Is<sem::U32>());
+ EXPECT_EQ(sem->ConstantValue()->Index(1)->As<u32>(), 0_u);
+
+ EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllEqual());
+ EXPECT_TRUE(sem->ConstantValue()->Index(2)->AnyZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(2)->Type()->Is<sem::F32>());
+ EXPECT_EQ(sem->ConstantValue()->Index(2)->As<f32>(), 0._f);
+
+ EXPECT_TRUE(sem->ConstantValue()->Index(3)->AllEqual());
+ EXPECT_TRUE(sem->ConstantValue()->Index(3)->AnyZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(3)->AllZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(3)->Type()->Is<sem::F16>());
+ EXPECT_EQ(sem->ConstantValue()->Index(3)->As<f16>(), 0._h);
+
+ EXPECT_TRUE(sem->ConstantValue()->Index(4)->AllEqual());
+ EXPECT_TRUE(sem->ConstantValue()->Index(4)->AnyZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(4)->AllZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(4)->Type()->Is<sem::Bool>());
+ EXPECT_EQ(sem->ConstantValue()->Index(4)->As<bool>(), false);
+}
+
+TEST_F(ResolverConstantsTest, Struct_VectorF32s_ZeroInit) {
+ Structure("S", {
+ Member("m1", ty.vec3<f32>()),
+ Member("m2", ty.vec3<f32>()),
+ Member("m3", ty.vec3<f32>()),
+ });
+ auto* expr = Construct(ty.type_name("S")); // S()
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ auto* str = sem->Type()->As<sem::Struct>();
+ ASSERT_NE(str, nullptr);
+ EXPECT_EQ(str->Members().size(), 3u);
+ ASSERT_NE(sem->ConstantValue(), nullptr);
+ EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
+ EXPECT_TRUE(sem->ConstantValue()->AllEqual());
+ EXPECT_TRUE(sem->ConstantValue()->AnyZero());
+ EXPECT_TRUE(sem->ConstantValue()->AllZero());
+
+ EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllEqual());
+ EXPECT_TRUE(sem->ConstantValue()->Index(0)->AnyZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(0)->Type()->Is<sem::Vector>());
+ EXPECT_TRUE(sem->ConstantValue()->Index(0)->Type()->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(sem->ConstantValue()->Index(0)->Index(0)->As<f32>(), 0._f);
+ EXPECT_EQ(sem->ConstantValue()->Index(0)->Index(1)->As<f32>(), 0._f);
+ EXPECT_EQ(sem->ConstantValue()->Index(0)->Index(2)->As<f32>(), 0._f);
+
+ EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllEqual());
+ EXPECT_TRUE(sem->ConstantValue()->Index(1)->AnyZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(1)->Type()->Is<sem::Vector>());
+ EXPECT_TRUE(sem->ConstantValue()->Index(1)->Type()->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(sem->ConstantValue()->Index(1)->Index(0)->As<f32>(), 0._f);
+ EXPECT_EQ(sem->ConstantValue()->Index(1)->Index(1)->As<f32>(), 0._f);
+ EXPECT_EQ(sem->ConstantValue()->Index(1)->Index(2)->As<f32>(), 0._f);
+
+ EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllEqual());
+ EXPECT_TRUE(sem->ConstantValue()->Index(2)->AnyZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(2)->Type()->Is<sem::Vector>());
+ EXPECT_TRUE(sem->ConstantValue()->Index(2)->Type()->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(sem->ConstantValue()->Index(2)->Index(0)->As<f32>(), 0._f);
+ EXPECT_EQ(sem->ConstantValue()->Index(2)->Index(1)->As<f32>(), 0._f);
+ EXPECT_EQ(sem->ConstantValue()->Index(2)->Index(2)->As<f32>(), 0._f);
+}
+
+TEST_F(ResolverConstantsTest, Struct_MixedVectors_ZeroInit) {
+ Enable(ast::Extension::kF16);
+
+ Structure("S", {
+ Member("m1", ty.vec2<i32>()),
+ Member("m2", ty.vec3<u32>()),
+ Member("m3", ty.vec4<f32>()),
+ Member("m4", ty.vec3<f16>()),
+ Member("m5", ty.vec2<bool>()),
+ });
+ auto* expr = Construct(ty.type_name("S")); // S()
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ auto* str = sem->Type()->As<sem::Struct>();
+ ASSERT_NE(str, nullptr);
+ EXPECT_EQ(str->Members().size(), 5u);
+ ASSERT_NE(sem->ConstantValue(), nullptr);
+ EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
+ EXPECT_FALSE(sem->ConstantValue()->AllEqual());
+ EXPECT_TRUE(sem->ConstantValue()->AnyZero());
+ EXPECT_TRUE(sem->ConstantValue()->AllZero());
+
+ EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllEqual());
+ EXPECT_TRUE(sem->ConstantValue()->Index(0)->AnyZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(0)->Type()->Is<sem::Vector>());
+ EXPECT_TRUE(sem->ConstantValue()->Index(0)->Type()->As<sem::Vector>()->type()->Is<sem::I32>());
+ EXPECT_EQ(sem->ConstantValue()->Index(0)->Index(0)->As<i32>(), 0_i);
+ EXPECT_EQ(sem->ConstantValue()->Index(0)->Index(1)->As<i32>(), 0_i);
+
+ EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllEqual());
+ EXPECT_TRUE(sem->ConstantValue()->Index(1)->AnyZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(1)->Type()->Is<sem::Vector>());
+ EXPECT_TRUE(sem->ConstantValue()->Index(1)->Type()->As<sem::Vector>()->type()->Is<sem::U32>());
+ EXPECT_EQ(sem->ConstantValue()->Index(1)->Index(0)->As<u32>(), 0_u);
+ EXPECT_EQ(sem->ConstantValue()->Index(1)->Index(1)->As<u32>(), 0_u);
+ EXPECT_EQ(sem->ConstantValue()->Index(1)->Index(2)->As<u32>(), 0_u);
+
+ EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllEqual());
+ EXPECT_TRUE(sem->ConstantValue()->Index(2)->AnyZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(2)->Type()->Is<sem::Vector>());
+ EXPECT_TRUE(sem->ConstantValue()->Index(2)->Type()->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(sem->ConstantValue()->Index(2)->Index(0)->As<f32>(), 0._f);
+ EXPECT_EQ(sem->ConstantValue()->Index(2)->Index(1)->As<f32>(), 0._f);
+ EXPECT_EQ(sem->ConstantValue()->Index(2)->Index(2)->As<f32>(), 0._f);
+ EXPECT_EQ(sem->ConstantValue()->Index(2)->Index(3)->As<f32>(), 0._f);
+
+ EXPECT_TRUE(sem->ConstantValue()->Index(3)->AllEqual());
+ EXPECT_TRUE(sem->ConstantValue()->Index(3)->AnyZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(3)->AllZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(3)->Type()->Is<sem::Vector>());
+ EXPECT_TRUE(sem->ConstantValue()->Index(3)->Type()->As<sem::Vector>()->type()->Is<sem::F16>());
+ EXPECT_EQ(sem->ConstantValue()->Index(3)->Index(0)->As<f16>(), 0._h);
+ EXPECT_EQ(sem->ConstantValue()->Index(3)->Index(1)->As<f16>(), 0._h);
+ EXPECT_EQ(sem->ConstantValue()->Index(3)->Index(2)->As<f16>(), 0._h);
+
+ EXPECT_TRUE(sem->ConstantValue()->Index(4)->AllEqual());
+ EXPECT_TRUE(sem->ConstantValue()->Index(4)->AnyZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(4)->AllZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(4)->Type()->Is<sem::Vector>());
+ EXPECT_TRUE(sem->ConstantValue()->Index(4)->Type()->As<sem::Vector>()->type()->Is<sem::Bool>());
+ EXPECT_EQ(sem->ConstantValue()->Index(4)->Index(0)->As<bool>(), false);
+ EXPECT_EQ(sem->ConstantValue()->Index(4)->Index(1)->As<bool>(), false);
+}
+
+TEST_F(ResolverConstantsTest, Struct_Struct_ZeroInit) {
+ Structure("Inner", {
+ Member("m1", ty.i32()),
+ Member("m2", ty.u32()),
+ Member("m3", ty.f32()),
+ });
+
+ Structure("Outer", {
+ Member("m1", ty.type_name("Inner")),
+ Member("m2", ty.type_name("Inner")),
+ });
+ auto* expr = Construct(ty.type_name("Outer")); // Outer()
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ auto* str = sem->Type()->As<sem::Struct>();
+ ASSERT_NE(str, nullptr);
+ EXPECT_EQ(str->Members().size(), 2u);
+ ASSERT_NE(sem->ConstantValue(), nullptr);
+ EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
+ EXPECT_TRUE(sem->ConstantValue()->AllEqual());
+ EXPECT_TRUE(sem->ConstantValue()->AnyZero());
+ EXPECT_TRUE(sem->ConstantValue()->AllZero());
+
+ EXPECT_FALSE(sem->ConstantValue()->Index(0)->AllEqual());
+ EXPECT_TRUE(sem->ConstantValue()->Index(0)->AnyZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(0)->Type()->Is<sem::Struct>());
+ EXPECT_EQ(sem->ConstantValue()->Index(0)->Index(0)->As<i32>(), 0_i);
+ EXPECT_EQ(sem->ConstantValue()->Index(0)->Index(1)->As<u32>(), 0_u);
+ EXPECT_EQ(sem->ConstantValue()->Index(0)->Index(2)->As<f32>(), 0_f);
+
+ EXPECT_FALSE(sem->ConstantValue()->Index(1)->AllEqual());
+ EXPECT_TRUE(sem->ConstantValue()->Index(1)->AnyZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(1)->Type()->Is<sem::Struct>());
+ EXPECT_EQ(sem->ConstantValue()->Index(1)->Index(0)->As<i32>(), 0_i);
+ EXPECT_EQ(sem->ConstantValue()->Index(1)->Index(1)->As<u32>(), 0_u);
+ EXPECT_EQ(sem->ConstantValue()->Index(1)->Index(2)->As<f32>(), 0_f);
+}
+
+TEST_F(ResolverConstantsTest, Struct_MixedScalars_Construct) {
+ Enable(ast::Extension::kF16);
+
+ Structure("S", {
+ Member("m1", ty.i32()),
+ Member("m2", ty.u32()),
+ Member("m3", ty.f32()),
+ Member("m4", ty.f16()),
+ Member("m5", ty.bool_()),
+ });
+ auto* expr = Construct(ty.type_name("S"), 1_i, 2_u, 3_f, 4_h, false); // S(1i, 2u, 3f, 4h, false)
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ auto* str = sem->Type()->As<sem::Struct>();
+ ASSERT_NE(str, nullptr);
+ EXPECT_EQ(str->Members().size(), 5u);
+ ASSERT_NE(sem->ConstantValue(), nullptr);
+ EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
+ EXPECT_FALSE(sem->ConstantValue()->AllEqual());
+ EXPECT_TRUE(sem->ConstantValue()->AnyZero());
+ EXPECT_FALSE(sem->ConstantValue()->AllZero());
+
+ EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllEqual());
+ EXPECT_FALSE(sem->ConstantValue()->Index(0)->AnyZero());
+ EXPECT_FALSE(sem->ConstantValue()->Index(0)->AllZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(0)->Type()->Is<sem::I32>());
+ EXPECT_EQ(sem->ConstantValue()->Index(0)->As<i32>(), 1_i);
+
+ EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllEqual());
+ EXPECT_FALSE(sem->ConstantValue()->Index(1)->AnyZero());
+ EXPECT_FALSE(sem->ConstantValue()->Index(1)->AllZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(1)->Type()->Is<sem::U32>());
+ EXPECT_EQ(sem->ConstantValue()->Index(1)->As<u32>(), 2_u);
+
+ EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllEqual());
+ EXPECT_FALSE(sem->ConstantValue()->Index(2)->AnyZero());
+ EXPECT_FALSE(sem->ConstantValue()->Index(2)->AllZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(2)->Type()->Is<sem::F32>());
+ EXPECT_EQ(sem->ConstantValue()->Index(2)->As<f32>(), 3._f);
+
+ EXPECT_TRUE(sem->ConstantValue()->Index(3)->AllEqual());
+ EXPECT_FALSE(sem->ConstantValue()->Index(3)->AnyZero());
+ EXPECT_FALSE(sem->ConstantValue()->Index(3)->AllZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(3)->Type()->Is<sem::F16>());
+ EXPECT_EQ(sem->ConstantValue()->Index(3)->As<f16>(), 4._h);
+
+ EXPECT_TRUE(sem->ConstantValue()->Index(4)->AllEqual());
+ EXPECT_TRUE(sem->ConstantValue()->Index(4)->AnyZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(4)->AllZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(4)->Type()->Is<sem::Bool>());
+ EXPECT_EQ(sem->ConstantValue()->Index(4)->As<bool>(), false);
+}
+
+TEST_F(ResolverConstantsTest, Struct_MixedVectors_Construct) {
+ Enable(ast::Extension::kF16);
+
+ Structure("S", {
+ Member("m1", ty.vec2<i32>()),
+ Member("m2", ty.vec3<u32>()),
+ Member("m3", ty.vec4<f32>()),
+ Member("m4", ty.vec3<f16>()),
+ Member("m5", ty.vec2<bool>()),
+ });
+ auto* expr = Construct(ty.type_name("S"), vec2<i32>(1_i), vec3<u32>(2_u), vec4<f32>(3_f),
+ vec3<f16>(4_h), vec2<bool>(false)); // each member is splat-constructed from one scalar
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ auto* str = sem->Type()->As<sem::Struct>();
+ ASSERT_NE(str, nullptr);
+ EXPECT_EQ(str->Members().size(), 5u);
+ ASSERT_NE(sem->ConstantValue(), nullptr);
+ EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
+ EXPECT_FALSE(sem->ConstantValue()->AllEqual());
+ EXPECT_TRUE(sem->ConstantValue()->AnyZero());
+ EXPECT_FALSE(sem->ConstantValue()->AllZero());
+
+ EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllEqual());
+ EXPECT_FALSE(sem->ConstantValue()->Index(0)->AnyZero());
+ EXPECT_FALSE(sem->ConstantValue()->Index(0)->AllZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(0)->Type()->Is<sem::Vector>());
+ EXPECT_TRUE(sem->ConstantValue()->Index(0)->Type()->As<sem::Vector>()->type()->Is<sem::I32>());
+ EXPECT_EQ(sem->ConstantValue()->Index(0)->Index(0)->As<i32>(), 1_i);
+ EXPECT_EQ(sem->ConstantValue()->Index(0)->Index(1)->As<i32>(), 1_i);
+
+ EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllEqual());
+ EXPECT_FALSE(sem->ConstantValue()->Index(1)->AnyZero());
+ EXPECT_FALSE(sem->ConstantValue()->Index(1)->AllZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(1)->Type()->Is<sem::Vector>());
+ EXPECT_TRUE(sem->ConstantValue()->Index(1)->Type()->As<sem::Vector>()->type()->Is<sem::U32>());
+ EXPECT_EQ(sem->ConstantValue()->Index(1)->Index(0)->As<u32>(), 2_u);
+ EXPECT_EQ(sem->ConstantValue()->Index(1)->Index(1)->As<u32>(), 2_u);
+ EXPECT_EQ(sem->ConstantValue()->Index(1)->Index(2)->As<u32>(), 2_u);
+
+ EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllEqual());
+ EXPECT_FALSE(sem->ConstantValue()->Index(2)->AnyZero());
+ EXPECT_FALSE(sem->ConstantValue()->Index(2)->AllZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(2)->Type()->Is<sem::Vector>());
+ EXPECT_TRUE(sem->ConstantValue()->Index(2)->Type()->As<sem::Vector>()->type()->Is<sem::F32>());
+ EXPECT_EQ(sem->ConstantValue()->Index(2)->Index(0)->As<f32>(), 3._f);
+ EXPECT_EQ(sem->ConstantValue()->Index(2)->Index(1)->As<f32>(), 3._f);
+ EXPECT_EQ(sem->ConstantValue()->Index(2)->Index(2)->As<f32>(), 3._f);
+ EXPECT_EQ(sem->ConstantValue()->Index(2)->Index(3)->As<f32>(), 3._f);
+
+ EXPECT_TRUE(sem->ConstantValue()->Index(3)->AllEqual());
+ EXPECT_FALSE(sem->ConstantValue()->Index(3)->AnyZero());
+ EXPECT_FALSE(sem->ConstantValue()->Index(3)->AllZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(3)->Type()->Is<sem::Vector>());
+ EXPECT_TRUE(sem->ConstantValue()->Index(3)->Type()->As<sem::Vector>()->type()->Is<sem::F16>());
+ EXPECT_EQ(sem->ConstantValue()->Index(3)->Index(0)->As<f16>(), 4._h);
+ EXPECT_EQ(sem->ConstantValue()->Index(3)->Index(1)->As<f16>(), 4._h);
+ EXPECT_EQ(sem->ConstantValue()->Index(3)->Index(2)->As<f16>(), 4._h);
+
+ EXPECT_TRUE(sem->ConstantValue()->Index(4)->AllEqual());
+ EXPECT_TRUE(sem->ConstantValue()->Index(4)->AnyZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(4)->AllZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(4)->Type()->Is<sem::Vector>());
+ EXPECT_TRUE(sem->ConstantValue()->Index(4)->Type()->As<sem::Vector>()->type()->Is<sem::Bool>());
+ EXPECT_EQ(sem->ConstantValue()->Index(4)->Index(0)->As<bool>(), false);
+ EXPECT_EQ(sem->ConstantValue()->Index(4)->Index(1)->As<bool>(), false);
+}
+
+TEST_F(ResolverConstantsTest, Struct_Struct_Construct) {
+ Structure("Inner", {
+ Member("m1", ty.i32()),
+ Member("m2", ty.u32()),
+ Member("m3", ty.f32()),
+ });
+
+ Structure("Outer", {
+ Member("m1", ty.type_name("Inner")),
+ Member("m2", ty.type_name("Inner")),
+ });
+ auto* expr = Construct(ty.type_name("Outer"), // Outer(Inner(1i, 2u, 3f), Inner(4i, 0u, 6f))
+ Construct(ty.type_name("Inner"), 1_i, 2_u, 3_f),
+ Construct(ty.type_name("Inner"), 4_i, 0_u, 6_f));
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ auto* str = sem->Type()->As<sem::Struct>();
+ ASSERT_NE(str, nullptr);
+ EXPECT_EQ(str->Members().size(), 2u);
+ ASSERT_NE(sem->ConstantValue(), nullptr);
+ EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
+ EXPECT_FALSE(sem->ConstantValue()->AllEqual());
+ EXPECT_TRUE(sem->ConstantValue()->AnyZero());
+ EXPECT_FALSE(sem->ConstantValue()->AllZero());
+
+ EXPECT_FALSE(sem->ConstantValue()->Index(0)->AllEqual());
+ EXPECT_FALSE(sem->ConstantValue()->Index(0)->AnyZero());
+ EXPECT_FALSE(sem->ConstantValue()->Index(0)->AllZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(0)->Type()->Is<sem::Struct>());
+ EXPECT_EQ(sem->ConstantValue()->Index(0)->Index(0)->As<i32>(), 1_i);
+ EXPECT_EQ(sem->ConstantValue()->Index(0)->Index(1)->As<u32>(), 2_u);
+ EXPECT_EQ(sem->ConstantValue()->Index(0)->Index(2)->As<f32>(), 3_f);
+
+ EXPECT_FALSE(sem->ConstantValue()->Index(1)->AllEqual());
+ EXPECT_TRUE(sem->ConstantValue()->Index(1)->AnyZero());
+ EXPECT_FALSE(sem->ConstantValue()->Index(1)->AllZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(1)->Type()->Is<sem::Struct>());
+ EXPECT_EQ(sem->ConstantValue()->Index(1)->Index(0)->As<i32>(), 4_i);
+ EXPECT_EQ(sem->ConstantValue()->Index(1)->Index(1)->As<u32>(), 0_u);
+ EXPECT_EQ(sem->ConstantValue()->Index(1)->Index(2)->As<f32>(), 6_f);
+}
+
+TEST_F(ResolverConstantsTest, Struct_Array_Construct) {
+ Structure("S", {
+ Member("m1", ty.array<i32, 2>()),
+ Member("m2", ty.array<f32, 3>()),
+ });
+ auto* expr = Construct(ty.type_name("S"), // S(array<i32, 2>(1i, 2i), array<f32, 3>(1f, 2f, 3f))
+ Construct(ty.array<i32, 2>(), 1_i, 2_i),
+ Construct(ty.array<f32, 3>(), 1_f, 2_f, 3_f));
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ auto* str = sem->Type()->As<sem::Struct>();
+ ASSERT_NE(str, nullptr);
+ EXPECT_EQ(str->Members().size(), 2u);
+ ASSERT_NE(sem->ConstantValue(), nullptr);
+ EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
+ EXPECT_FALSE(sem->ConstantValue()->AllEqual());
+ EXPECT_FALSE(sem->ConstantValue()->AnyZero());
+ EXPECT_FALSE(sem->ConstantValue()->AllZero());
+
+ EXPECT_FALSE(sem->ConstantValue()->Index(0)->AllEqual());
+ EXPECT_FALSE(sem->ConstantValue()->Index(0)->AnyZero());
+ EXPECT_FALSE(sem->ConstantValue()->Index(0)->AllZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(0)->Type()->Is<sem::Array>());
+ EXPECT_EQ(sem->ConstantValue()->Index(0)->Index(0)->As<i32>(), 1_i);
+ EXPECT_EQ(sem->ConstantValue()->Index(0)->Index(1)->As<i32>(), 2_i);
+
+ EXPECT_FALSE(sem->ConstantValue()->Index(1)->AllEqual());
+ EXPECT_FALSE(sem->ConstantValue()->Index(1)->AnyZero());
+ EXPECT_FALSE(sem->ConstantValue()->Index(1)->AllZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(1)->Type()->Is<sem::Array>());
+ EXPECT_EQ(sem->ConstantValue()->Index(1)->Index(0)->As<f32>(), 1_f);
+ EXPECT_EQ(sem->ConstantValue()->Index(1)->Index(1)->As<f32>(), 2_f);
+ EXPECT_EQ(sem->ConstantValue()->Index(1)->Index(2)->As<f32>(), 3_f);
+}
+
////////////////////////////////////////////////////////////////////////////////////////////////////
// Indexing
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -2302,5 +2872,63 @@
}
}
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// Member accessing
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Constant values must propagate through chained structure member accesses (outer.o1.i2).
+TEST_F(ResolverConstantsTest, MemberAccess) {
+ Structure("Inner", {
+ Member("i1", ty.i32()),
+ Member("i2", ty.u32()),
+ Member("i3", ty.f32()),
+ });
+
+ Structure("Outer", {
+ Member("o1", ty.type_name("Inner")),
+ Member("o2", ty.type_name("Inner")),
+ });
+ auto* outer_expr = Construct(ty.type_name("Outer"), //
+ Construct(ty.type_name("Inner"), 1_i, 2_u, 3_f),
+ Construct(ty.type_name("Inner")));
+ auto* o1_expr = MemberAccessor(outer_expr, "o1");
+ auto* i2_expr = MemberAccessor(o1_expr, "i2");
+ WrapInFunction(i2_expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* outer = Sem().Get(outer_expr);
+ ASSERT_NE(outer, nullptr);
+ auto* str = outer->Type()->As<sem::Struct>();
+ ASSERT_NE(str, nullptr);
+ EXPECT_EQ(str->Members().size(), 2u);
+ ASSERT_NE(outer->ConstantValue(), nullptr);
+ EXPECT_TYPE(outer->ConstantValue()->Type(), outer->Type());
+ EXPECT_FALSE(outer->ConstantValue()->AllEqual());
+ // 'o2' is zero-initialized, so Outer has some, but not all, zero elements.
+ EXPECT_TRUE(outer->ConstantValue()->AnyZero());
+ EXPECT_FALSE(outer->ConstantValue()->AllZero());
+
+ auto* o1 = Sem().Get(o1_expr);
+ ASSERT_NE(o1, nullptr);
+ ASSERT_NE(o1->ConstantValue(), nullptr);
+ EXPECT_FALSE(o1->ConstantValue()->AllEqual());
+ EXPECT_FALSE(o1->ConstantValue()->AnyZero());
+ EXPECT_FALSE(o1->ConstantValue()->AllZero());
+ EXPECT_TRUE(o1->ConstantValue()->Type()->Is<sem::Struct>());
+ EXPECT_EQ(o1->ConstantValue()->Index(0)->As<i32>(), 1_i);
+ EXPECT_EQ(o1->ConstantValue()->Index(1)->As<u32>(), 2_u);
+ EXPECT_EQ(o1->ConstantValue()->Index(2)->As<f32>(), 3_f);
+
+ auto* i2 = Sem().Get(i2_expr);
+ ASSERT_NE(i2, nullptr);
+ ASSERT_NE(i2->ConstantValue(), nullptr);
+ EXPECT_TRUE(i2->ConstantValue()->AllEqual());
+ EXPECT_FALSE(i2->ConstantValue()->AnyZero());
+ EXPECT_FALSE(i2->ConstantValue()->AllZero());
+ EXPECT_TRUE(i2->ConstantValue()->Type()->Is<sem::U32>());
+ EXPECT_EQ(i2->ConstantValue()->As<u32>(), 2_u);
+}
+
} // namespace
} // namespace tint::resolver
diff --git a/src/tint/resolver/variable_test.cc b/src/tint/resolver/variable_test.cc
index 7eb0dea..49c1f8b 100644
--- a/src/tint/resolver/variable_test.cc
+++ b/src/tint/resolver/variable_test.cc
@@ -892,6 +892,8 @@
}
TEST_F(ResolverVariableTest, LocalConst_ExplicitType_Decls) {
+ Structure("S", {Member("m", ty.u32())});
+
auto* c_i32 = Const("a", ty.i32(), Expr(0_i));
auto* c_u32 = Const("b", ty.u32(), Expr(0_u));
auto* c_f32 = Const("c", ty.f32(), Expr(0_f));
@@ -899,8 +901,9 @@
auto* c_vu32 = Const("e", ty.vec3<u32>(), vec3<u32>());
auto* c_vf32 = Const("f", ty.vec3<f32>(), vec3<f32>());
auto* c_mf32 = Const("g", ty.mat3x3<f32>(), mat3x3<f32>());
+ auto* c_s = Const("h", ty.type_name("S"), Construct(ty.type_name("S")));
- WrapInFunction(c_i32, c_u32, c_f32, c_vi32, c_vu32, c_vf32, c_mf32);
+ WrapInFunction(c_i32, c_u32, c_f32, c_vi32, c_vu32, c_vf32, c_mf32, c_s);
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -911,6 +914,7 @@
EXPECT_EQ(Sem().Get(c_vu32)->Declaration(), c_vu32);
EXPECT_EQ(Sem().Get(c_vf32)->Declaration(), c_vf32);
EXPECT_EQ(Sem().Get(c_mf32)->Declaration(), c_mf32);
+ EXPECT_EQ(Sem().Get(c_s)->Declaration(), c_s);
ASSERT_TRUE(TypeOf(c_i32)->Is<sem::I32>());
ASSERT_TRUE(TypeOf(c_u32)->Is<sem::U32>());
@@ -919,6 +923,7 @@
ASSERT_TRUE(TypeOf(c_vu32)->Is<sem::Vector>());
ASSERT_TRUE(TypeOf(c_vf32)->Is<sem::Vector>());
ASSERT_TRUE(TypeOf(c_mf32)->Is<sem::Matrix>());
+ ASSERT_TRUE(TypeOf(c_s)->Is<sem::Struct>());
EXPECT_TRUE(Sem().Get(c_i32)->ConstantValue()->AllZero());
EXPECT_TRUE(Sem().Get(c_u32)->ConstantValue()->AllZero());
@@ -927,9 +932,12 @@
EXPECT_TRUE(Sem().Get(c_vu32)->ConstantValue()->AllZero());
EXPECT_TRUE(Sem().Get(c_vf32)->ConstantValue()->AllZero());
EXPECT_TRUE(Sem().Get(c_mf32)->ConstantValue()->AllZero());
+ EXPECT_TRUE(Sem().Get(c_s)->ConstantValue()->AllZero());
}
TEST_F(ResolverVariableTest, LocalConst_ImplicitType_Decls) {
+ Structure("S", {Member("m", ty.u32())});
+
auto* c_i32 = Const("a", nullptr, Expr(0_i));
auto* c_u32 = Const("b", nullptr, Expr(0_u));
auto* c_f32 = Const("c", nullptr, Expr(0_f));
@@ -946,9 +954,10 @@
Construct(ty.vec(nullptr, 3), Expr(0._a)),
Construct(ty.vec(nullptr, 3), Expr(0._a)),
Construct(ty.vec(nullptr, 3), Expr(0._a))));
+ auto* c_s = Const("m", nullptr, Construct(ty.type_name("S")));
WrapInFunction(c_i32, c_u32, c_f32, c_ai, c_af, c_vi32, c_vu32, c_vf32, c_vai, c_vaf, c_mf32,
- c_maf32);
+ c_maf32, c_s);
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -964,6 +973,7 @@
EXPECT_EQ(Sem().Get(c_vaf)->Declaration(), c_vaf);
EXPECT_EQ(Sem().Get(c_mf32)->Declaration(), c_mf32);
EXPECT_EQ(Sem().Get(c_maf32)->Declaration(), c_maf32);
+ EXPECT_EQ(Sem().Get(c_s)->Declaration(), c_s);
ASSERT_TRUE(TypeOf(c_i32)->Is<sem::I32>());
ASSERT_TRUE(TypeOf(c_u32)->Is<sem::U32>());
@@ -977,6 +987,7 @@
ASSERT_TRUE(TypeOf(c_vaf)->Is<sem::Vector>());
ASSERT_TRUE(TypeOf(c_mf32)->Is<sem::Matrix>());
ASSERT_TRUE(TypeOf(c_maf32)->Is<sem::Matrix>());
+ ASSERT_TRUE(TypeOf(c_s)->Is<sem::Struct>());
EXPECT_TRUE(Sem().Get(c_i32)->ConstantValue()->AllZero());
EXPECT_TRUE(Sem().Get(c_u32)->ConstantValue()->AllZero());
@@ -990,6 +1001,7 @@
EXPECT_TRUE(Sem().Get(c_vaf)->ConstantValue()->AllZero());
EXPECT_TRUE(Sem().Get(c_mf32)->ConstantValue()->AllZero());
EXPECT_TRUE(Sem().Get(c_maf32)->ConstantValue()->AllZero());
+ EXPECT_TRUE(Sem().Get(c_s)->ConstantValue()->AllZero());
}
TEST_F(ResolverVariableTest, LocalConst_PropagateConstValue) {
diff --git a/src/tint/resolver/variable_validation_test.cc b/src/tint/resolver/variable_validation_test.cc
index 79b269c..b538db0 100644
--- a/src/tint/resolver/variable_validation_test.cc
+++ b/src/tint/resolver/variable_validation_test.cc
@@ -365,23 +365,6 @@
EXPECT_EQ(r()->error(), "12:34 error: missing matrix element type");
}
-TEST_F(ResolverVariableValidationTest, ConstStructure) {
- auto* s = Structure("S", {Member("m", ty.i32())});
- auto* c = Const("c", ty.Of(s), Construct(Source{{12, 34}}, ty.Of(s)));
- WrapInFunction(c);
-
- EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), R"(12:34 error: 'const' initializer must be constant expression)");
-}
-
-TEST_F(ResolverVariableValidationTest, GlobalConstStructure) {
- auto* s = Structure("S", {Member("m", ty.i32())});
- GlobalConst("c", ty.Of(s), Construct(Source{{12, 34}}, ty.Of(s)));
-
- EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), R"(12:34 error: 'const' initializer must be constant expression)");
-}
-
TEST_F(ResolverVariableValidationTest, ConstInitWithVar) {
auto* v = Var("v", nullptr, Expr(1_i));
auto* c = Const("c", nullptr, Expr(Source{{12, 34}}, v));
diff --git a/src/tint/sem/constant.h b/src/tint/sem/constant.h
index b46127c..875864f 100644
--- a/src/tint/sem/constant.h
+++ b/src/tint/sem/constant.h
@@ -39,11 +39,15 @@
virtual const sem::Type* Type() const = 0;
/// @returns the value of this Constant, if this constant is of a scalar value or abstract
- /// numeric, otherwsie std::monostate.
+ /// numeric, otherwise std::monostate.
virtual std::variant<std::monostate, AInt, AFloat> Value() const = 0;
/// @returns the child constant element with the given index, or nullptr if the constant has no
/// children, or the index is out of bounds.
+ /// For arrays, this returns the i'th element of the array.
+ /// For vectors, this returns the i'th element of the vector.
+ /// For matrices, this returns the i'th column vector of the matrix.
+ /// For structures, this returns the i'th member field of the structure.
virtual const Constant* Index(size_t) const = 0;
/// @returns true if child elements of this constant are positive-zero valued.
diff --git a/src/tint/writer/glsl/generator_impl.cc b/src/tint/writer/glsl/generator_impl.cc
index 01cbba0..767fecb 100644
--- a/src/tint/writer/glsl/generator_impl.cc
+++ b/src/tint/writer/glsl/generator_impl.cc
@@ -862,15 +862,10 @@
return EmitZeroValue(out, type);
}
- auto it = structure_builders_.find(As<sem::Struct>(type));
- if (it != structure_builders_.end()) {
- out << it->second << "(";
- } else {
- if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kReadWrite, "")) {
- return false;
- }
- out << "(";
+ if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kReadWrite, "")) {
+ return false;
}
+ ScopedParen sp(out);
bool first = true;
for (auto* arg : call->Arguments()) {
@@ -884,7 +879,6 @@
}
}
- out << ")";
return true;
}
@@ -2300,6 +2294,24 @@
return true;
},
+ [&](const sem::Struct* s) {
+ if (!EmitType(out, s, ast::StorageClass::kNone, ast::Access::kUndefined, "")) {
+ return false;
+ }
+ // Structure constant: emit a constructor-style expression, Type(m0, m1, ...).
+ ScopedParen sp(out);
+ // `sp` supplies the surrounding parentheses; member values follow in declaration order.
+ for (size_t i = 0; i < s->Members().size(); i++) {
+ if (i > 0) {
+ out << ", ";
+ }
+ if (!EmitConstant(out, constant->Index(i))) {
+ return false;
+ }
+ }
+
+ return true;
+ },
[&](Default) {
diagnostics_.add_error(
diag::System::Writer,
diff --git a/src/tint/writer/glsl/generator_impl.h b/src/tint/writer/glsl/generator_impl.h
index 153e316..58a0e50 100644
--- a/src/tint/writer/glsl/generator_impl.h
+++ b/src/tint/writer/glsl/generator_impl.h
@@ -520,7 +520,6 @@
std::function<bool()> emit_continuing_;
std::unordered_map<DMAIntrinsic, std::string, DMAIntrinsic::Hasher> dma_intrinsics_;
std::unordered_map<const sem::Builtin*, std::string> builtins_;
- std::unordered_map<const sem::Struct*, std::string> structure_builders_;
std::unordered_map<const sem::Vector*, std::string> dynamic_vector_write_;
std::unordered_map<const sem::Vector*, std::string> int_dot_funcs_;
std::unordered_map<const sem::Type*, std::string> float_modulo_funcs_;
diff --git a/src/tint/writer/hlsl/generator_impl.cc b/src/tint/writer/hlsl/generator_impl.cc
index c9861a1..0a6a196 100644
--- a/src/tint/writer/hlsl/generator_impl.cc
+++ b/src/tint/writer/hlsl/generator_impl.cc
@@ -1142,11 +1142,7 @@
call->Arguments().size() == 1 &&
ctor->Parameters()[0]->Type()->is_scalar();
- auto it = structure_builders_.find(As<sem::Struct>(type));
- if (it != structure_builders_.end()) {
- out << it->second << "(";
- brackets = false;
- } else if (brackets) {
+ if (brackets) {
out << "{";
} else {
if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kReadWrite, "")) {
@@ -3219,6 +3215,30 @@
return true;
},
+ [&](const sem::Struct* s) {
+ if (constant->AllZero()) {  // Zero value: emit a cast of zero, "(Type)0".
+ out << "(";
+ if (!EmitType(out, s, ast::StorageClass::kNone, ast::Access::kUndefined, "")) {
+ return false;
+ }
+ out << ")0";
+ return true;
+ }
+ // Otherwise emit a brace-enclosed initializer list; TINT_DEFER writes the closing '}'.
+ out << "{";
+ TINT_DEFER(out << "}");
+ // NOTE(review): HLSL accepts "{...}" only when initializing a declaration; confirm this path is never reached for a constant used as a sub-expression (e.g. a call argument).
+ for (size_t i = 0; i < s->Members().size(); i++) {
+ if (i > 0) {
+ out << ", ";
+ }
+ if (!EmitConstant(out, constant->Index(i))) {
+ return false;
+ }
+ }
+
+ return true;
+ },
[&](Default) {
diagnostics_.add_error(
diag::System::Writer,
diff --git a/src/tint/writer/hlsl/generator_impl.h b/src/tint/writer/hlsl/generator_impl.h
index bf2debe..74ce70c 100644
--- a/src/tint/writer/hlsl/generator_impl.h
+++ b/src/tint/writer/hlsl/generator_impl.h
@@ -543,7 +543,6 @@
std::function<bool()> emit_continuing_;
std::unordered_map<const sem::Matrix*, std::string> matrix_scalar_ctors_;
std::unordered_map<const sem::Builtin*, std::string> builtins_;
- std::unordered_map<const sem::Struct*, std::string> structure_builders_;
std::unordered_map<const sem::Vector*, std::string> dynamic_vector_write_;
std::unordered_map<const sem::Matrix*, std::string> dynamic_matrix_vector_write_;
std::unordered_map<const sem::Matrix*, std::string> dynamic_matrix_scalar_write_;
diff --git a/src/tint/writer/msl/generator_impl.cc b/src/tint/writer/msl/generator_impl.cc
index e2b14a6..2af36c9 100644
--- a/src/tint/writer/msl/generator_impl.cc
+++ b/src/tint/writer/msl/generator_impl.cc
@@ -1676,14 +1676,13 @@
return false;
}
- if (constant->AllZero()) {
- out << "{}";
- return true;
- }
-
out << "{";
TINT_DEFER(out << "}");
+ if (constant->AllZero()) {
+ return true;
+ }
+
for (size_t i = 0; i < a->Count(); i++) {
if (i > 0) {
out << ", ";
@@ -1695,6 +1694,27 @@
return true;
},
+ [&](const sem::Struct* s) {
+ out << "{";
+ TINT_DEFER(out << "}");
+ // The closing '}' is written by TINT_DEFER, so an all-zero structure is emitted as "{}".
+ if (constant->AllZero()) {
+ return true;
+ }
+ // Otherwise name each member with a designated initializer, in member order: {.a=..., .b=...}.
+ auto& members = s->Members();
+ for (size_t i = 0; i < members.size(); i++) {
+ if (i > 0) {
+ out << ", ";
+ }
+ out << "." << program_->Symbols().NameFor(members[i]->Name()) << "=";
+ if (!EmitConstant(out, constant->Index(i))) {
+ return false;
+ }
+ }
+
+ return true;
+ },
[&](Default) {
diagnostics_.add_error(
diag::System::Writer,
diff --git a/src/tint/writer/spirv/builder.cc b/src/tint/writer/spirv/builder.cc
index b34fec3..92fcf78 100644
--- a/src/tint/writer/spirv/builder.cc
+++ b/src/tint/writer/spirv/builder.cc
@@ -1770,6 +1770,7 @@
[&](const sem::Vector* v) { return composite(v->Width()); },
[&](const sem::Matrix* m) { return composite(m->columns()); },
[&](const sem::Array* a) { return composite(a->Count()); },
+ [&](const sem::Struct* s) { return composite(s->Members().size()); },  // one element per member, like the array case
[&](Default) {
error_ = "unhandled constant type: " + builder_.FriendlyName(ty);
return false;
diff --git a/src/tint/writer/spirv/builder_accessor_expression_test.cc b/src/tint/writer/spirv/builder_accessor_expression_test.cc
index 8996b8d..23dc783 100644
--- a/src/tint/writer/spirv/builder_accessor_expression_test.cc
+++ b/src/tint/writer/spirv/builder_accessor_expression_test.cc
@@ -47,7 +47,8 @@
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].variables()), R"(%12 = OpVariable %13 Function %14
)");
- EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), R"(%11 = OpCompositeExtract %6 %10 1
+ EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+ R"(%11 = OpCompositeExtract %6 %10 1
OpStore %12 %11
OpReturn
)");
@@ -773,7 +774,8 @@
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].variables()), R"(%18 = OpVariable %19 Function %20
)");
- EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), R"(%17 = OpCompositeExtract %6 %14 1
+ EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+ R"(%17 = OpCompositeExtract %6 %14 1
OpStore %18 %17
OpReturn
)");
@@ -1009,11 +1011,10 @@
%1 = OpTypeFunction %2
%6 = OpTypeFloat 32
%5 = OpTypeStruct %6 %6
-%7 = OpConstantNull %6
-%8 = OpConstantComposite %5 %7 %7
+%7 = OpConstantNull %5
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].variables()), R"()");
- EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), R"(%9 = OpCompositeExtract %6 %8 1
+ EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), R"(%8 = OpCompositeExtract %6 %7 1
OpReturn
)");
@@ -1052,14 +1053,12 @@
%7 = OpTypeFloat 32
%6 = OpTypeStruct %7 %7
%5 = OpTypeStruct %6
-%8 = OpConstantNull %7
-%9 = OpConstantComposite %6 %8 %8
-%10 = OpConstantComposite %5 %9
+%8 = OpConstantNull %5
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].variables()), R"()");
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
- R"(%11 = OpCompositeExtract %6 %10 0
-%12 = OpCompositeExtract %7 %11 1
+ R"(%9 = OpCompositeExtract %6 %8 0
+%10 = OpCompositeExtract %7 %9 1
OpReturn
)");