spirv-reader: support NumWorkgroups
Fixed: tint:1065
Change-Id: Id2a8af247e7da79933703e634478f1dec25f9145
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/110220
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
Auto-Submit: David Neto <dneto@google.com>
diff --git a/src/tint/reader/spirv/enum_converter.cc b/src/tint/reader/spirv/enum_converter.cc
index d590507..8c70156 100644
--- a/src/tint/reader/spirv/enum_converter.cc
+++ b/src/tint/reader/spirv/enum_converter.cc
@@ -82,6 +82,8 @@
return ast::BuiltinValue::kLocalInvocationIndex;
case spv::BuiltIn::GlobalInvocationId:
return ast::BuiltinValue::kGlobalInvocationId;
+ case spv::BuiltIn::NumWorkgroups:
+ return ast::BuiltinValue::kNumWorkgroups;
case spv::BuiltIn::WorkgroupId:
return ast::BuiltinValue::kWorkgroupId;
case spv::BuiltIn::SampleId:
diff --git a/src/tint/reader/spirv/enum_converter_test.cc b/src/tint/reader/spirv/enum_converter_test.cc
index b366ebd..bdb0247 100644
--- a/src/tint/reader/spirv/enum_converter_test.cc
+++ b/src/tint/reader/spirv/enum_converter_test.cc
@@ -192,6 +192,7 @@
BuiltinCase{spv::BuiltIn::LocalInvocationIndex, true,
ast::BuiltinValue::kLocalInvocationIndex},
BuiltinCase{spv::BuiltIn::GlobalInvocationId, true, ast::BuiltinValue::kGlobalInvocationId},
+ BuiltinCase{spv::BuiltIn::NumWorkgroups, true, ast::BuiltinValue::kNumWorkgroups},
BuiltinCase{spv::BuiltIn::WorkgroupId, true, ast::BuiltinValue::kWorkgroupId},
BuiltinCase{spv::BuiltIn::SampleId, true, ast::BuiltinValue::kSampleIndex},
BuiltinCase{spv::BuiltIn::SampleMask, true, ast::BuiltinValue::kSampleMask}));
@@ -208,8 +209,6 @@
testing::Values(BuiltinCase{static_cast<spv::BuiltIn>(9999), false,
ast::BuiltinValue::kUndefined},
BuiltinCase{static_cast<spv::BuiltIn>(9999), false,
- ast::BuiltinValue::kUndefined},
- BuiltinCase{spv::BuiltIn::NumWorkgroups, false,
ast::BuiltinValue::kUndefined}));
// Dim
diff --git a/src/tint/reader/spirv/parser_impl_module_var_test.cc b/src/tint/reader/spirv/parser_impl_module_var_test.cc
index 44fc4b3..be931a6 100644
--- a/src/tint/reader/spirv/parser_impl_module_var_test.cc
+++ b/src/tint/reader/spirv/parser_impl_module_var_test.cc
@@ -3106,6 +3106,14 @@
// Returns the start of a shader for testing LocalInvocationIndex,
// parameterized by store type of %int or %uint
std::string ComputeBuiltinInputPreamble(std::string builtin, std::string store_type) {
+ std::string ptr_component_type;
+ if (store_type == "%v3int") {
+ ptr_component_type = " %ptr_comp_ty = OpTypePointer Input %int\n";
+ }
+ if (store_type == "%v3uint") {
+ ptr_component_type = " %ptr_comp_ty = OpTypePointer Input %uint\n";
+ }
+
return R"(
OpCapability Shader
OpMemoryModel Logical Simple
@@ -3118,10 +3126,11 @@
%float = OpTypeFloat 32
%uint = OpTypeInt 32 0
%int = OpTypeInt 32 1
+ %int_1 = OpConstant %int 1
%v3uint = OpTypeVector %uint 3
%v3int = OpTypeVector %int 3
%ptr_ty = OpTypePointer Input )" +
- store_type + R"(
+ store_type + ptr_component_type + R"(
%1 = OpVariable %ptr_ty Input
)";
}
@@ -3331,14 +3340,84 @@
{"LocalInvocationId", "%v3int", "local_invocation_id"},
{"GlobalInvocationId", "%v3uint", "global_invocation_id"},
{"GlobalInvocationId", "%v3int", "global_invocation_id"},
+ {"NumWorkgroups", "%v3uint", "num_workgroups"},
+ {"NumWorkgroups", "%v3int", "num_workgroups"},
{"WorkgroupId", "%v3uint", "workgroup_id"},
{"WorkgroupId", "%v3int", "workgroup_id"}}));
-// TODO(dneto): crbug.com/tint/752
-// NumWorkgroups support is blocked by crbug.com/tint/752
-// When the AST supports NumWorkgroups, add these cases:
-// {"NumWorkgroups", "%uint", "num_workgroups"}
-// {"NumWorkgroups", "%int", "num_workgroups"}
+// For compute shader builtins that are vectors, test loading one component.
+struct ComputeBuiltinInputVectorCase {
+ std::string spirv_builtin;
+ std::string spirv_store_type;
+ std::string spirv_component_store_type;
+ std::string wgsl_builtin;
+};
+inline std::ostream& operator<<(std::ostream& o, ComputeBuiltinInputVectorCase c) {
+ return o << "ComputeBuiltinInputVectorCase(" << c.spirv_builtin << " " << c.spirv_store_type
+ << " " << c.spirv_component_store_type << " " << c.wgsl_builtin << ")";
+}
+
+using SpvModuleScopeVarParserTest_ComputeBuiltinVector =
+ SpvParserTestBase<::testing::TestWithParam<ComputeBuiltinInputVectorCase>>;
+
+TEST_P(SpvModuleScopeVarParserTest_ComputeBuiltinVector, Load_Component_Direct) {
+ const auto wgsl_type = WgslType(GetParam().spirv_store_type);
+ const auto wgsl_component_type = WgslType(GetParam().spirv_component_store_type);
+ const auto wgsl_builtin = GetParam().wgsl_builtin;
+ const auto unsigned_wgsl_type = UnsignedWgslType(wgsl_type);
+ const auto signed_wgsl_type = SignedWgslType(wgsl_type);
+ const std::string assembly =
+ ComputeBuiltinInputPreamble(GetParam().spirv_builtin, GetParam().spirv_store_type) +
+ R"(
+ %main = OpFunction %void None %voidfn
+ %entry = OpLabel
+ %3 = OpAccessChain %ptr_comp_ty %1 %int_1
+ %2 = OpLoad )" +
+ GetParam().spirv_component_store_type + R"( %3
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModule()) << p->error() << assembly;
+ EXPECT_TRUE(p->error().empty());
+ const auto module_str = test::ToString(p->program());
+ std::string expected = R"(var<private> x_1 : ${wgsl_type};
+
+fn main_1() {
+ let x_2 : ${wgsl_component_type} = x_1.y;
+ return;
+}
+
+@compute @workgroup_size(1i, 1i, 1i)
+fn main(@builtin(${wgsl_builtin}) x_1_param : ${unsigned_wgsl_type}) {
+ x_1 = ${assignment_value};
+ main_1();
+}
+)";
+
+ expected = utils::ReplaceAll(expected, "${wgsl_type}", wgsl_type);
+ expected = utils::ReplaceAll(expected, "${wgsl_component_type}", wgsl_component_type);
+ expected = utils::ReplaceAll(expected, "${unsigned_wgsl_type}", unsigned_wgsl_type);
+ expected = utils::ReplaceAll(expected, "${wgsl_builtin}", wgsl_builtin);
+ expected = utils::ReplaceAll(expected, "${assignment_value}",
+ (wgsl_type == unsigned_wgsl_type)
+ ? "x_1_param"
+ : "bitcast<" + signed_wgsl_type + ">(x_1_param)");
+
+ EXPECT_EQ(module_str, expected) << module_str;
+}
+
+INSTANTIATE_TEST_SUITE_P(Samples,
+ SpvModuleScopeVarParserTest_ComputeBuiltinVector,
+ ::testing::ValuesIn(std::vector<ComputeBuiltinInputVectorCase>{
+ {"LocalInvocationId", "%v3uint", "%uint", "local_invocation_id"},
+ {"LocalInvocationId", "%v3int", "%int", "local_invocation_id"},
+ {"GlobalInvocationId", "%v3uint", "%uint", "global_invocation_id"},
+ {"GlobalInvocationId", "%v3int", "%int", "global_invocation_id"},
+ {"NumWorkgroups", "%v3uint", "%uint", "num_workgroups"},
+ {"NumWorkgroups", "%v3int", "%int", "num_workgroups"},
+ {"WorkgroupId", "%v3uint", "%uint", "workgroup_id"},
+ {"WorkgroupId", "%v3int", "%int", "workgroup_id"}}));
TEST_F(SpvModuleScopeVarParserTest, RegisterInputOutputVars) {
const std::string assembly =