[wgsl] Fix 8-bit subgroup matrix constructors The signatures for subgroup matrix constructors need to use the shader scalar types, which are 32-bit integers for the 8-bit matrices. Fixed: 458773229 Change-Id: I710af6fd705cde43d0978f70c1a0a191d7cc175f Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/272174 Reviewed-by: dan sinclair <dsinclair@chromium.org> Commit-Queue: James Price <jrprice@google.com>
diff --git a/src/tint/lang/wgsl/resolver/resolver.cc b/src/tint/lang/wgsl/resolver/resolver.cc index 655f3f1..d740796 100644 --- a/src/tint/lang/wgsl/resolver/resolver.cc +++ b/src/tint/lang/wgsl/resolver/resolver.cc
@@ -2097,10 +2097,17 @@ auto* call_target = subgroup_matrix_ctors_.GetOrAdd( SubgroupMatrixConstructorSig{{m, args.Length()}}, [&]() -> sem::ValueConstructor* { + // 8-bit integer matrices use 32-bit shader scalar types. + auto* scalar_ty = m->Type(); + if (m->Type()->Is<core::type::I8>()) { + scalar_ty = b.create<core::type::I32>(); + } else if (m->Type()->Is<core::type::U8>()) { + scalar_ty = b.create<core::type::U32>(); + } auto params = tint::Transform(args, [&](auto, size_t i) { return b.create<sem::Parameter>(nullptr, // declaration static_cast<uint32_t>(i), // index - m->Type()); + scalar_ty); }); return b.create<sem::ValueConstructor>(m, std::move(params), core::EvaluationStage::kRuntime); @@ -2110,7 +2117,7 @@ return nullptr; } - if (DAWN_UNLIKELY(!validator_.SubgroupMatrixConstructor(expr, m))) { + if (DAWN_UNLIKELY(!validator_.SubgroupMatrixConstructor(expr, m, call_target))) { return nullptr; }
diff --git a/src/tint/lang/wgsl/resolver/subgroup_matrix_test.cc b/src/tint/lang/wgsl/resolver/subgroup_matrix_test.cc index caa327e..a4e06f4 100644 --- a/src/tint/lang/wgsl/resolver/subgroup_matrix_test.cc +++ b/src/tint/lang/wgsl/resolver/subgroup_matrix_test.cc
@@ -292,6 +292,44 @@ EXPECT_EQ(target->Stage(), core::EvaluationStage::kRuntime); } +TEST_F(ResolverSubgroupMatrixTest, SingleValueConstructor_U8_Abstract) { + Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix); + auto* call = Call(Ident("subgroup_matrix_result", ty.u8(), 8_a, 8_a), 1_a); + Func("foo", Empty, ty.void_(), + Vector{ + Assign(Phony(), call), + }); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); + + auto call_sem = Sem().Get(call)->As<sem::Call>(); + ASSERT_NE(call_sem, nullptr); + auto* target = call_sem->Target()->As<sem::ValueConstructor>(); + ASSERT_NE(target, nullptr); + EXPECT_TRUE(target->ReturnType()->Is<core::type::SubgroupMatrix>()); + EXPECT_EQ(target->Parameters().Length(), 1u); + EXPECT_EQ(target->Stage(), core::EvaluationStage::kRuntime); +} + +TEST_F(ResolverSubgroupMatrixTest, SingleValueConstructor_U8_U32) { + Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix); + auto* call = Call(Ident("subgroup_matrix_result", ty.u8(), 8_a, 8_a), 1_u); + Func("foo", Empty, ty.void_(), + Vector{ + Assign(Phony(), call), + }); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); + + auto call_sem = Sem().Get(call)->As<sem::Call>(); + ASSERT_NE(call_sem, nullptr); + auto* target = call_sem->Target()->As<sem::ValueConstructor>(); + ASSERT_NE(target, nullptr); + EXPECT_TRUE(target->ReturnType()->Is<core::type::SubgroupMatrix>()); + EXPECT_EQ(target->Parameters().Length(), 1u); + EXPECT_EQ(target->Stage(), core::EvaluationStage::kRuntime); +} + TEST_F(ResolverSubgroupMatrixTest, ConstructorTooManyArgs) { Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix); auto* call = Call(Ident("subgroup_matrix_result", ty.f32(), 8_a, 8_a), 1_f, 2_f);
diff --git a/src/tint/lang/wgsl/resolver/validator.cc b/src/tint/lang/wgsl/resolver/validator.cc index 59ae08b..2da058c 100644 --- a/src/tint/lang/wgsl/resolver/validator.cc +++ b/src/tint/lang/wgsl/resolver/validator.cc
@@ -2389,14 +2389,15 @@ return true; } -bool Validator::SubgroupMatrixConstructor( - const ast::CallExpression* ctor, - const core::type::SubgroupMatrix* subgroup_matrix_type) const { +bool Validator::SubgroupMatrixConstructor(const ast::CallExpression* ctor, + const core::type::SubgroupMatrix* subgroup_matrix_type, + const sem::CallTarget* signature) const { auto& values = ctor->args; if (values.Length() == 1) { auto* elem_ty = subgroup_matrix_type->Type(); auto* value_ty = sem_.TypeOf(values[0])->UnwrapRef(); - if (core::type::Type::ConversionRank(value_ty, elem_ty) == + auto* expected_ty = signature->Parameters()[0]->Type(); + if (core::type::Type::ConversionRank(value_ty, expected_ty) == core::type::Type::kNoConversion) { AddError(values[0]->source) << style::Type(sem_.TypeNameOf(value_ty)) << " cannot be used to construct a subgroup matrix of "
diff --git a/src/tint/lang/wgsl/resolver/validator.h b/src/tint/lang/wgsl/resolver/validator.h index de3b8b4..7142ec2 100644 --- a/src/tint/lang/wgsl/resolver/validator.h +++ b/src/tint/lang/wgsl/resolver/validator.h
@@ -536,9 +536,11 @@ /// Validates a subgroup matrix constructor /// @param ctor the call expression to validate /// @param subgroup_matrix_type the type of the subgroup matrix + /// @param signature the construct signature to validate against /// @returns true on success, false otherwise bool SubgroupMatrixConstructor(const ast::CallExpression* ctor, - const core::type::SubgroupMatrix* subgroup_matrix_type) const; + const core::type::SubgroupMatrix* subgroup_matrix_type, + const sem::CallTarget* signature) const; /// Validates a subgroupShuffle builtin functions including Up,Down, and Xor. /// @param fn the builtin call type