tint: add const eval array constructor tests
Bug: tint:1581
Change-Id: Ia6c4ba974b40cdff8dc28ddbd510189355ed27cb
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/115400
Reviewed-by: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/tint/resolver/const_eval_construction_test.cc b/src/tint/resolver/const_eval_construction_test.cc
index bb23121..1462f7b 100644
--- a/src/tint/resolver/const_eval_construction_test.cc
+++ b/src/tint/resolver/const_eval_construction_test.cc
@@ -1623,6 +1623,102 @@
EXPECT_EQ(sem->ConstantValue()->Index(3)->ValueAs<i32>(), 40_i);
}
+namespace ArrayInit {
+struct Case {
+ Value input;
+};
+static Case C(Value input) {
+ return Case{std::move(input)};
+}
+static std::ostream& operator<<(std::ostream& o, const Case& c) {
+ return o << "input: " << c.input;
+}
+
+using ResolverConstEvalArrayInitTest = ResolverTestWithParam<Case>;
+TEST_P(ResolverConstEvalArrayInitTest, Test) {
+ Enable(ast::Extension::kF16);
+ auto& param = GetParam();
+ auto* expr = param.input.Expr(*this);
+ auto* a = Const("a", expr);
+ WrapInFunction(a);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ auto* arr = sem->Type()->As<type::Array>();
+ ASSERT_NE(arr, nullptr);
+
+ EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
+ // Constant values should match input values
+ CheckConstant(sem->ConstantValue(), param.input);
+}
+template <typename T>
+std::vector<Case> ArrayInitCases() {
+ return {
+ C(Array(T(0))), //
+ C(Array(T(0))), //
+ C(Array(T(0), T(1))), //
+ C(Array(T(0), T(1), T(2))), //
+ C(Array(T(2), T(1), T(0))), //
+ C(Array(T(2), T(0), T(1))), //
+ };
+}
+INSTANTIATE_TEST_SUITE_P( //
+ ArrayInit,
+ ResolverConstEvalArrayInitTest,
+ testing::ValuesIn(Concat(ArrayInitCases<AInt>(), //
+ ArrayInitCases<AFloat>(), //
+ ArrayInitCases<i32>(), //
+ ArrayInitCases<u32>(), //
+ ArrayInitCases<f32>(), //
+ ArrayInitCases<f16>(), //
+ ArrayInitCases<bool>())));
+} // namespace ArrayInit
+
+TEST_F(ResolverConstEvalTest, ArrayInit_Nested_f32) {
+ auto inner_ty = [&] { return ty.array<f32, 2>(); };
+ auto outer_ty = ty.array(inner_ty(), Expr(3_i));
+
+ auto* expr = Construct(outer_ty, //
+ Construct(inner_ty(), 1_f, 2_f), //
+ Construct(inner_ty(), 3_f, 4_f), //
+ Construct(inner_ty(), 5_f, 6_f));
+
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ auto* outer_arr = sem->Type()->As<type::Array>();
+ ASSERT_NE(outer_arr, nullptr);
+ EXPECT_TRUE(outer_arr->ElemType()->Is<type::Array>());
+ EXPECT_TRUE(outer_arr->ElemType()->As<type::Array>()->ElemType()->Is<type::F32>());
+
+ auto* arr = sem->ConstantValue();
+ EXPECT_FALSE(arr->AllEqual());
+ EXPECT_FALSE(arr->AnyZero());
+ EXPECT_FALSE(arr->AllZero());
+
+ EXPECT_FALSE(arr->Index(0)->AllEqual());
+ EXPECT_FALSE(arr->Index(0)->AnyZero());
+ EXPECT_FALSE(arr->Index(0)->AllZero());
+ EXPECT_FALSE(arr->Index(1)->AllEqual());
+ EXPECT_FALSE(arr->Index(1)->AnyZero());
+ EXPECT_FALSE(arr->Index(1)->AllZero());
+ EXPECT_FALSE(arr->Index(2)->AllEqual());
+ EXPECT_FALSE(arr->Index(2)->AnyZero());
+ EXPECT_FALSE(arr->Index(2)->AllZero());
+
+ EXPECT_EQ(arr->Index(0)->Index(0)->ValueAs<f32>(), 1.0f);
+ EXPECT_EQ(arr->Index(0)->Index(1)->ValueAs<f32>(), 2.0f);
+ EXPECT_EQ(arr->Index(1)->Index(0)->ValueAs<f32>(), 3.0f);
+ EXPECT_EQ(arr->Index(1)->Index(1)->ValueAs<f32>(), 4.0f);
+ EXPECT_EQ(arr->Index(2)->Index(0)->ValueAs<f32>(), 5.0f);
+ EXPECT_EQ(arr->Index(2)->Index(1)->ValueAs<f32>(), 6.0f);
+}
+
TEST_F(ResolverConstEvalTest, Array_f32_Elements) {
auto* expr = Construct(ty.array<f32, 4>(), 10_f, 20_f, 30_f, 40_f);
WrapInFunction(expr);
diff --git a/src/tint/resolver/const_eval_test.h b/src/tint/resolver/const_eval_test.h
index c31d89f..001d3d1 100644
--- a/src/tint/resolver/const_eval_test.h
+++ b/src/tint/resolver/const_eval_test.h
@@ -245,6 +245,7 @@
return ss.str();
}
+using builder::Array;
using builder::IsValue;
using builder::Mat;
using builder::Val;
diff --git a/src/tint/resolver/resolver_test_helper.h b/src/tint/resolver/resolver_test_helper.h
index 33818f2..24aea11 100644
--- a/src/tint/resolver/resolver_test_helper.h
+++ b/src/tint/resolver/resolver_test_helper.h
@@ -822,6 +822,17 @@
return Value::Create<vec<N, FirstT>>(std::move(v));
}
+/// Creates a Value of DataType<array<N, T>> from N scalar `args`
+template <typename... Ts>
+Value Array(Ts... args) {
+ using FirstT = std::tuple_element_t<0, std::tuple<Ts...>>;
+ static_assert(std::conjunction_v<std::is_same<FirstT, Ts>...>,
+ "Array args must all be the same type");
+ constexpr size_t N = sizeof...(args);
+ utils::Vector<Scalar, sizeof...(args)> v{args...};
+ return Value::Create<array<N, FirstT>>(std::move(v));
+}
+
/// Creates a Value of DataType<mat<C,R,T> from C*R scalar `args`
template <size_t C, size_t R, typename T>
Value Mat(const T (&m_in)[C][R]) {
@@ -884,7 +895,6 @@
}
return Value::Create<mat<C, R, T>>(std::move(m));
}
-
} // namespace builder
} // namespace tint::resolver