[wgsl] Handle subgroup matrix constructors
These have to be resolved manually due to the arbitrary column and row
counts.
Support zero value and single value constructors only.
Bug: 348702031
Change-Id: I550adb23c27411fdddb09aa54bac026967a91c14
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/224055
Commit-Queue: James Price <jrprice@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/wgsl/resolver/resolver.cc b/src/tint/lang/wgsl/resolver/resolver.cc
index 7c19a50..a181969 100644
--- a/src/tint/lang/wgsl/resolver/resolver.cc
+++ b/src/tint/lang/wgsl/resolver/resolver.cc
@@ -2270,6 +2270,36 @@
return arr_or_str_init(str, call_target);
},
+ [&](const core::type::SubgroupMatrix* m) -> sem::Call* {
+ auto* call_target = subgroup_matrix_ctors_.GetOrAdd(
+ SubgroupMatrixConstructorSig{{m, args.Length()}},
+ [&]() -> sem::ValueConstructor* {
+ auto params = tint::Transform(args, [&](auto, size_t i) {
+ return b.create<sem::Parameter>(nullptr, // declaration
+ static_cast<uint32_t>(i), // index
+ m->Type());
+ });
+ return b.create<sem::ValueConstructor>(m, std::move(params),
+ core::EvaluationStage::kRuntime);
+ });
+
+ if (DAWN_UNLIKELY(!MaybeMaterializeAndLoadArguments(args, call_target))) {
+ return nullptr;
+ }
+
+ if (DAWN_UNLIKELY(!validator_.SubgroupMatrixConstructor(expr, m))) {
+ return nullptr;
+ }
+
+ // Subgroup matrix constructors are never const-evaluated.
+ auto stage = core::EvaluationStage::kRuntime;
+ if (not_evaluated_.Contains(expr)) {
+ stage = core::EvaluationStage::kNotEvaluated;
+ }
+
+ return b.create<sem::Call>(expr, call_target, stage, std::move(args),
+ current_statement_, nullptr, has_side_effects);
+ },
[&](Default) {
AddError(expr->source) << "type is not constructible";
return nullptr;
diff --git a/src/tint/lang/wgsl/resolver/resolver.h b/src/tint/lang/wgsl/resolver/resolver.h
index eb811eb..12e73e2 100644
--- a/src/tint/lang/wgsl/resolver/resolver.h
+++ b/src/tint/lang/wgsl/resolver/resolver.h
@@ -662,6 +662,11 @@
using StructConstructorSig = tint::UnorderedKeyWrapper<
std::tuple<const core::type::Struct*, size_t, core::EvaluationStage>>;
+ // SubgroupMatrixConstructorSig represents a unique subgroup matrix constructor signature.
+ // It is a tuple of the subgroup matrix type and the number of arguments provided.
+ using SubgroupMatrixConstructorSig =
+ tint::UnorderedKeyWrapper<std::tuple<const core::type::SubgroupMatrix*, size_t>>;
+
/// ExprEvalStageConstraint describes a constraint on when expressions can be evaluated.
struct ExprEvalStageConstraint {
/// The latest stage that the expression can be evaluated
@@ -701,6 +706,7 @@
Hashmap<OverrideId, const sem::Variable*, 8> override_ids_;
Hashmap<ArrayConstructorSig, sem::CallTarget*, 8> array_ctors_;
Hashmap<StructConstructorSig, sem::CallTarget*, 8> struct_ctors_;
+ Hashmap<SubgroupMatrixConstructorSig, sem::CallTarget*, 8> subgroup_matrix_ctors_;
sem::Function* current_function_ = nullptr;
sem::Statement* current_statement_ = nullptr;
sem::CompoundStatement* current_compound_statement_ = nullptr;
diff --git a/src/tint/lang/wgsl/resolver/subgroup_matrix_test.cc b/src/tint/lang/wgsl/resolver/subgroup_matrix_test.cc
index daf419e..31b9f2c 100644
--- a/src/tint/lang/wgsl/resolver/subgroup_matrix_test.cc
+++ b/src/tint/lang/wgsl/resolver/subgroup_matrix_test.cc
@@ -27,6 +27,7 @@
#include "src/tint/lang/wgsl/resolver/resolver.h"
#include "src/tint/lang/wgsl/resolver/resolver_helper_test.h"
+#include "src/tint/lang/wgsl/sem/value_constructor.h"
#include "gmock/gmock.h"
@@ -222,5 +223,69 @@
R"(error: subgroup matrix row count must be a constant positive integer)");
}
+TEST_F(ResolverSubgroupMatrixTest, ZeroValueConstructor) {
+ Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+ auto* call = Call(Ident("subgroup_matrix_result", ty.f32(), 8_a, 8_a));
+ Func("foo", Empty, ty.void_(),
+ Vector{
+ Assign(Phony(), call),
+ });
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto call_sem = Sem().Get(call)->As<sem::Call>();
+ ASSERT_NE(call_sem, nullptr);
+ auto* target = call_sem->Target()->As<sem::ValueConstructor>();
+ ASSERT_NE(target, nullptr);
+ EXPECT_TRUE(target->ReturnType()->Is<core::type::SubgroupMatrix>());
+ EXPECT_EQ(target->Parameters().Length(), 0u);
+ EXPECT_EQ(target->Stage(), core::EvaluationStage::kRuntime);
+}
+
+TEST_F(ResolverSubgroupMatrixTest, SingleValueConstructor) {
+ Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+ auto* call = Call(Ident("subgroup_matrix_result", ty.f32(), 8_a, 8_a), 1_a);
+ Func("foo", Empty, ty.void_(),
+ Vector{
+ Assign(Phony(), call),
+ });
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto call_sem = Sem().Get(call)->As<sem::Call>();
+ ASSERT_NE(call_sem, nullptr);
+ auto* target = call_sem->Target()->As<sem::ValueConstructor>();
+ ASSERT_NE(target, nullptr);
+ EXPECT_TRUE(target->ReturnType()->Is<core::type::SubgroupMatrix>());
+ EXPECT_EQ(target->Parameters().Length(), 1u);
+ EXPECT_EQ(target->Stage(), core::EvaluationStage::kRuntime);
+}
+
+TEST_F(ResolverSubgroupMatrixTest, ConstructorTooManyArgs) {
+ Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+ auto* call = Call(Ident("subgroup_matrix_result", ty.f32(), 8_a, 8_a), 1_f, 2_f);
+ Func("foo", Empty, ty.void_(),
+ Vector{
+ Assign(Phony(), call),
+ });
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(error: subgroup_matrix constructor can only have zero or one elements)");
+}
+
+TEST_F(ResolverSubgroupMatrixTest, ConstructorWrongType) {
+ Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+ auto* call = Call(Ident("subgroup_matrix_result", ty.u32(), 8_a, 8_a), 1_f);
+ Func("foo", Empty, ty.void_(),
+ Vector{
+ Assign(Phony(), call),
+ });
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(error: 'f32' cannot be used to construct a subgroup matrix of 'u32')");
+}
+
} // namespace
} // namespace tint::resolver
diff --git a/src/tint/lang/wgsl/resolver/validator.cc b/src/tint/lang/wgsl/resolver/validator.cc
index a3f559f..bc900a9 100644
--- a/src/tint/lang/wgsl/resolver/validator.cc
+++ b/src/tint/lang/wgsl/resolver/validator.cc
@@ -2323,6 +2323,29 @@
return true;
}
+bool Validator::SubgroupMatrixConstructor(
+ const ast::CallExpression* ctor,
+ const core::type::SubgroupMatrix* subgroup_matrix_type) const {
+ auto& values = ctor->args;
+ if (values.Length() == 1) {
+ auto* elem_ty = subgroup_matrix_type->Type();
+ auto* value_ty = sem_.TypeOf(values[0])->UnwrapRef();
+ if (core::type::Type::ConversionRank(value_ty, elem_ty) ==
+ core::type::Type::kNoConversion) {
+ AddError(values[0]->source) << style::Type(sem_.TypeNameOf(value_ty))
+ << " cannot be used to construct a subgroup matrix of "
+ << style::Type(sem_.TypeNameOf(elem_ty));
+ return false;
+ }
+ } else if (values.Length() > 1) {
+ AddError(ctor->target->source)
+ << "subgroup_matrix constructor can only have zero or one elements";
+ return false;
+ }
+
+ return true;
+}
+
bool Validator::Vector(const core::type::Type* el_ty, const Source& source) const {
if (!el_ty->Is<core::type::Scalar>()) {
AddError(source) << "vector element type must be " << style::Type("bool") << ", "
diff --git a/src/tint/lang/wgsl/resolver/validator.h b/src/tint/lang/wgsl/resolver/validator.h
index 4d533f6..7907f9d 100644
--- a/src/tint/lang/wgsl/resolver/validator.h
+++ b/src/tint/lang/wgsl/resolver/validator.h
@@ -529,11 +529,18 @@
bool Vector(const core::type::Type* el_ty, const Source& source) const;
/// Validates an array constructor
- /// @param ctor the call expresion to validate
+ /// @param ctor the call expression to validate
/// @param arr_type the type of the array
/// @returns true on success, false otherwise
bool ArrayConstructor(const ast::CallExpression* ctor, const sem::Array* arr_type) const;
+ /// Validates a subgroup matrix constructor
+ /// @param ctor the call expression to validate
+ /// @param subgroup_matrix_type the type of the subgroup matrix
+ /// @returns true on success, false otherwise
+ bool SubgroupMatrixConstructor(const ast::CallExpression* ctor,
+ const core::type::SubgroupMatrix* subgroup_matrix_type) const;
+
/// Validates a subgroupShuffle builtin functions including Up,Down, and Xor.
/// @param fn the builtin call type
/// @param call the builtin call to validate