[wgsl] Fix 8-bit subgroup matrix constructors
The signatures for subgroup matrix constructors need to use the shader
scalar types, which are 32-bit integers for the 8-bit matrices.
Fixed: 458773229
Change-Id: I710af6fd705cde43d0978f70c1a0a191d7cc175f
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/272174
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Commit-Queue: James Price <jrprice@google.com>
diff --git a/src/tint/lang/wgsl/resolver/resolver.cc b/src/tint/lang/wgsl/resolver/resolver.cc
index 655f3f1..d740796 100644
--- a/src/tint/lang/wgsl/resolver/resolver.cc
+++ b/src/tint/lang/wgsl/resolver/resolver.cc
@@ -2097,10 +2097,17 @@
auto* call_target = subgroup_matrix_ctors_.GetOrAdd(
SubgroupMatrixConstructorSig{{m, args.Length()}},
[&]() -> sem::ValueConstructor* {
+ // 8-bit integer matrices use 32-bit shader scalar types.
+ auto* scalar_ty = m->Type();
+ if (m->Type()->Is<core::type::I8>()) {
+ scalar_ty = b.create<core::type::I32>();
+ } else if (m->Type()->Is<core::type::U8>()) {
+ scalar_ty = b.create<core::type::U32>();
+ }
auto params = tint::Transform(args, [&](auto, size_t i) {
return b.create<sem::Parameter>(nullptr, // declaration
static_cast<uint32_t>(i), // index
- m->Type());
+ scalar_ty);
});
return b.create<sem::ValueConstructor>(m, std::move(params),
core::EvaluationStage::kRuntime);
@@ -2110,7 +2117,7 @@
return nullptr;
}
- if (DAWN_UNLIKELY(!validator_.SubgroupMatrixConstructor(expr, m))) {
+ if (DAWN_UNLIKELY(!validator_.SubgroupMatrixConstructor(expr, m, call_target))) {
return nullptr;
}
diff --git a/src/tint/lang/wgsl/resolver/subgroup_matrix_test.cc b/src/tint/lang/wgsl/resolver/subgroup_matrix_test.cc
index caa327e..a4e06f4 100644
--- a/src/tint/lang/wgsl/resolver/subgroup_matrix_test.cc
+++ b/src/tint/lang/wgsl/resolver/subgroup_matrix_test.cc
@@ -292,6 +292,44 @@
EXPECT_EQ(target->Stage(), core::EvaluationStage::kRuntime);
}
+TEST_F(ResolverSubgroupMatrixTest, SingleValueConstructor_U8_Abstract) {
+ Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+ auto* call = Call(Ident("subgroup_matrix_result", ty.u8(), 8_a, 8_a), 1_a);
+ Func("foo", Empty, ty.void_(),
+ Vector{
+ Assign(Phony(), call),
+ });
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto call_sem = Sem().Get(call)->As<sem::Call>();
+ ASSERT_NE(call_sem, nullptr);
+ auto* target = call_sem->Target()->As<sem::ValueConstructor>();
+ ASSERT_NE(target, nullptr);
+ EXPECT_TRUE(target->ReturnType()->Is<core::type::SubgroupMatrix>());
+ EXPECT_EQ(target->Parameters().Length(), 1u);
+ EXPECT_EQ(target->Stage(), core::EvaluationStage::kRuntime);
+}
+
+TEST_F(ResolverSubgroupMatrixTest, SingleValueConstructor_U8_U32) {
+ Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+ auto* call = Call(Ident("subgroup_matrix_result", ty.u8(), 8_a, 8_a), 1_u);
+ Func("foo", Empty, ty.void_(),
+ Vector{
+ Assign(Phony(), call),
+ });
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto call_sem = Sem().Get(call)->As<sem::Call>();
+ ASSERT_NE(call_sem, nullptr);
+ auto* target = call_sem->Target()->As<sem::ValueConstructor>();
+ ASSERT_NE(target, nullptr);
+ EXPECT_TRUE(target->ReturnType()->Is<core::type::SubgroupMatrix>());
+ EXPECT_EQ(target->Parameters().Length(), 1u);
+ EXPECT_EQ(target->Stage(), core::EvaluationStage::kRuntime);
+}
+
TEST_F(ResolverSubgroupMatrixTest, ConstructorTooManyArgs) {
Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
auto* call = Call(Ident("subgroup_matrix_result", ty.f32(), 8_a, 8_a), 1_f, 2_f);
diff --git a/src/tint/lang/wgsl/resolver/validator.cc b/src/tint/lang/wgsl/resolver/validator.cc
index 59ae08b..2da058c 100644
--- a/src/tint/lang/wgsl/resolver/validator.cc
+++ b/src/tint/lang/wgsl/resolver/validator.cc
@@ -2389,14 +2389,15 @@
return true;
}
-bool Validator::SubgroupMatrixConstructor(
- const ast::CallExpression* ctor,
- const core::type::SubgroupMatrix* subgroup_matrix_type) const {
+bool Validator::SubgroupMatrixConstructor(const ast::CallExpression* ctor,
+ const core::type::SubgroupMatrix* subgroup_matrix_type,
+ const sem::CallTarget* signature) const {
auto& values = ctor->args;
if (values.Length() == 1) {
auto* elem_ty = subgroup_matrix_type->Type();
auto* value_ty = sem_.TypeOf(values[0])->UnwrapRef();
- if (core::type::Type::ConversionRank(value_ty, elem_ty) ==
+ auto* expected_ty = signature->Parameters()[0]->Type();
+ if (core::type::Type::ConversionRank(value_ty, expected_ty) ==
core::type::Type::kNoConversion) {
AddError(values[0]->source) << style::Type(sem_.TypeNameOf(value_ty))
<< " cannot be used to construct a subgroup matrix of "
diff --git a/src/tint/lang/wgsl/resolver/validator.h b/src/tint/lang/wgsl/resolver/validator.h
index de3b8b4..7142ec2 100644
--- a/src/tint/lang/wgsl/resolver/validator.h
+++ b/src/tint/lang/wgsl/resolver/validator.h
@@ -536,9 +536,11 @@
/// Validates a subgroup matrix constructor
/// @param ctor the call expression to validate
/// @param subgroup_matrix_type the type of the subgroup matrix
+ /// @param signature the construct signature to validate against
/// @returns true on success, false otherwise
bool SubgroupMatrixConstructor(const ast::CallExpression* ctor,
- const core::type::SubgroupMatrix* subgroup_matrix_type) const;
+ const core::type::SubgroupMatrix* subgroup_matrix_type,
+ const sem::CallTarget* signature) const;
/// Validates a subgroupShuffle builtin functions including Up,Down, and Xor.
/// @param fn the builtin call type