[tint] Check uniformity for subgroup matrix constructors
They must only be called from uniform control flow.
Fixed: 403610976
Change-Id: Iaca57415a3ae316f145df60c4ccffba9fafa42b9
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/235656
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: David Neto <dneto@google.com>
diff --git a/src/tint/lang/wgsl/resolver/uniformity.cc b/src/tint/lang/wgsl/resolver/uniformity.cc
index a9b0bec..2bb5cdc 100644
--- a/src/tint/lang/wgsl/resolver/uniformity.cc
+++ b/src/tint/lang/wgsl/resolver/uniformity.cc
@@ -401,6 +401,9 @@
/// The function currently being analyzed.
FunctionInfo* current_function_;
+ /// A map from composite type to true/false to indicate whether it contains a subgroup matrix.
+ Hashmap<const core::type::Type*, bool, 16> composite_subgroup_matrix_info_;
+
/// Create a new node.
/// @param tag_list a string list that will be used to identify the node for debugging purposes
/// @param ast the optional AST node that this node corresponds to
@@ -1746,9 +1749,18 @@
function_tag = info->function_tag;
func_info = info.value;
},
- [&](const sem::ValueConstructor*) {
- callsite_tag = {CallSiteTag::CallSiteNoRestriction};
- function_tag = NoRestriction;
+ [&](const sem::ValueConstructor* construct) {
+ if (ContainsSubgroupMatrix(construct->ReturnType())) {
+ // Get the severity of subgroup matrix uniformity violations in this context.
+ auto severity = sem_.DiagnosticSeverity(
+ call, wgsl::ChromiumDiagnosticRule::kSubgroupMatrixUniformity);
+ if (severity != wgsl::DiagnosticSeverity::kOff) {
+ callsite_tag = {CallSiteTag::CallSiteRequiredToBeUniform, severity};
+ }
+ } else {
+ callsite_tag = {CallSiteTag::CallSiteNoRestriction};
+ function_tag = NoRestriction;
+ }
},
[&](const sem::ValueConversion*) {
callsite_tag = {CallSiteTag::CallSiteNoRestriction};
@@ -1908,8 +1920,8 @@
const ast::CallExpression* call,
wgsl::DiagnosticSeverity severity) {
auto* target = SemCall(call)->Target();
- if (target->Is<sem::BuiltinFn>()) {
- // This is a call to a builtin, so we must be done.
+ if (target->IsAnyOf<sem::BuiltinFn, sem::ValueConstructor>()) {
+ // This is a call to a builtin or constructor, so we must be done.
return call;
} else if (auto* user = target->As<sem::Function>()) {
// This is a call to a user-defined function, so inspect the functions called by that
@@ -2156,6 +2168,28 @@
const sem::Call* SemCall(const ast::CallExpression* expr) const {
return sem_.Get(expr)->UnwrapMaterialize()->As<sem::Call>();
}
+
+ /// @returns true if @p type is or contains a subgroup matrix type
+ bool ContainsSubgroupMatrix(const core::type::Type* type) {
+ if (type->Is<core::type::SubgroupMatrix>()) {
+ return true;
+ }
+ return composite_subgroup_matrix_info_.GetOrAdd(type, [&] {
+ return Switch(
+ type, //
+ [&](const core::type::Array* arr) {
+ return ContainsSubgroupMatrix(arr->ElemType());
+ },
+ [&](const core::type::Struct* str) {
+ for (auto* member : str->Members()) {
+ if (ContainsSubgroupMatrix(member->Type())) {
+ return true;
+ }
+ }
+ return false;
+ });
+ });
+ }
};
} // namespace
diff --git a/src/tint/lang/wgsl/resolver/uniformity_test.cc b/src/tint/lang/wgsl/resolver/uniformity_test.cc
index 7365faf..500dfda 100644
--- a/src/tint/lang/wgsl/resolver/uniformity_test.cc
+++ b/src/tint/lang/wgsl/resolver/uniformity_test.cc
@@ -171,6 +171,7 @@
kQuadSwapY,
kQuadSwapDiagonal,
// Subgroup matrix functions:
+ kSubgroupMatrixConstruct,
kSubgroupMatrixLoad,
kSubgroupMatrixStore,
kSubgroupMatrixMultiply,
@@ -309,6 +310,8 @@
return "_ = quadSwapY(1.0)";
case kQuadSwapDiagonal:
return "_ = quadSwapDiagonal(1.0)";
+ case kSubgroupMatrixConstruct:
+ return "_ = subgroup_matrix_result<f32, 8, 8>()";
case kSubgroupMatrixLoad:
return "_ = subgroupMatrixLoad<subgroup_matrix_result<f32, 8, 8>>("
"&subgroup_matrix_data, 0, false, 4)";
@@ -409,6 +412,7 @@
CASE(kQuadSwapX);
CASE(kQuadSwapY);
CASE(kQuadSwapDiagonal);
+ CASE(kSubgroupMatrixConstruct);
CASE(kSubgroupMatrixLoad);
CASE(kSubgroupMatrixStore);
CASE(kSubgroupMatrixMultiply);
@@ -9787,6 +9791,76 @@
RunTest(src, true);
}
+TEST_F(UniformityAnalysisTest, SubgroupMatrixConstructor_NestedInStruct) {
+ std::string src = R"(
+enable chromium_experimental_subgroup_matrix;
+
+@group(0) @binding(0) var<storage, read_write> non_uniform_global : i32;
+
+struct Inner {
+ u : u32,
+ m : subgroup_matrix_result<f32, 8, 8>,
+}
+
+struct S {
+ u : u32,
+ inner : Inner,
+}
+
+fn foo() {
+ if (non_uniform_global == 0) {
+ _ = S();
+ }
+}
+)";
+
+ RunTest(src, false);
+ EXPECT_EQ(error_,
+ R"(test:18:9 error: 'S' must only be called from uniform control flow
+ _ = S();
+ ^^^
+
+test:17:3 note: control flow depends on possibly non-uniform value
+ if (non_uniform_global == 0) {
+ ^^
+
+test:17:7 note: reading from read_write storage buffer 'non_uniform_global' may result in a non-uniform value
+ if (non_uniform_global == 0) {
+ ^^^^^^^^^^^^^^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, SubgroupMatrixConstructor_NestedInArray) {
+ std::string src = R"(
+enable chromium_experimental_subgroup_matrix;
+
+@group(0) @binding(0) var<storage, read_write> non_uniform_global : i32;
+
+alias ArrayType = array<array<subgroup_matrix_result<f32, 8, 8>, 4>, 4>;
+
+fn foo() {
+ if (non_uniform_global == 0) {
+ _ = ArrayType();
+ }
+}
+)";
+
+ RunTest(src, false);
+ EXPECT_EQ(error_,
+ R"(test:10:9 error: 'ArrayType' must only be called from uniform control flow
+ _ = ArrayType();
+ ^^^^^^^^^^^
+
+test:9:3 note: control flow depends on possibly non-uniform value
+ if (non_uniform_global == 0) {
+ ^^
+
+test:9:7 note: reading from read_write storage buffer 'non_uniform_global' may result in a non-uniform value
+ if (non_uniform_global == 0) {
+ ^^^^^^^^^^^^^^^^^^
+)");
+}
+
TEST_F(UniformityAnalysisTest, StressGraphTraversalDepth) {
// Create a function with a very long sequence of variable declarations and assignments to
// test traversals of very deep graphs. This requires a non-recursive traversal algorithm.
@@ -9920,7 +9994,7 @@
}
}
-TEST_P(UniformityAnalysisDiagnosticFilterTest, Directive_SubgroupMatrixUniformity_Callsite) {
+TEST_P(UniformityAnalysisDiagnosticFilterTest, Directive_SubgroupMatrixUniformity_BuiltinFunction) {
auto& param = GetParam();
StringStream ss;
ss << "enable chromium_experimental_subgroup_matrix;\n"
@@ -9947,6 +10021,33 @@
}
}
+TEST_P(UniformityAnalysisDiagnosticFilterTest, Directive_SubgroupMatrixUniformity_Constructor) {
+ auto& param = GetParam();
+ StringStream ss;
+ ss << "enable chromium_experimental_subgroup_matrix;\n"
+ << "diagnostic(" << param << ", chromium.subgroup_matrix_uniformity);" << R"(
+@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
+
+@group(0) @binding(1) var<storage, read_write> data : array<f32>;
+
+fn foo() {
+ if (non_uniform == 42) {
+ _ = subgroup_matrix_left<f32, 8, 8>(1);
+ }
+}
+)";
+
+ RunTest(ss.str(), param != wgsl::DiagnosticSeverity::kError);
+
+ if (param == wgsl::DiagnosticSeverity::kOff) {
+ EXPECT_TRUE(error_.empty());
+ } else {
+ StringStream err;
+ err << ToStr(param) << ": 'subgroup_matrix_left' must only be called";
+ EXPECT_THAT(error_, ::testing::HasSubstr(err.str()));
+ }
+}
+
TEST_P(UniformityAnalysisDiagnosticFilterTest, AttributeOnFunction_DerivativeUniformity) {
auto& param = GetParam();
StringStream ss;