spirv-reader: handle locations on structure members
Bug: tint:912
Change-Id: Ia179a3152cfcdfed812f1673aaa4dba8b565dadf
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/56341
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: David Neto <dneto@google.com>
Auto-Submit: David Neto <dneto@google.com>
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index 85ab480..03f45d9 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -980,11 +980,14 @@
index_prefix.push_back(0);
for (int i = 0; i < static_cast<int>(members.size()); ++i) {
index_prefix.back() = i;
+ auto* location = parser_impl_.GetMemberLocation(*struct_type, i);
+ auto* saved_location = SetLocation(decos, location);
if (!EmitPipelineInput(var_name, var_type, decos, index_prefix,
members[i], forced_param_type, params,
statements)) {
return false;
}
+ SetLocation(decos, saved_location);
}
return success();
}
@@ -1036,6 +1039,12 @@
statements->push_back(builder_.Assign(store_dest, param_value));
// Increment the location attribute, in case more parameters will follow.
+ IncrementLocation(decos);
+
+ return success();
+}
+
+void FunctionEmitter::IncrementLocation(ast::DecorationList* decos) {
for (auto*& deco : *decos) {
if (auto* loc_deco = deco->As<ast::LocationDecoration>()) {
// Replace this location decoration with a new one with one higher index.
@@ -1044,8 +1053,27 @@
deco = builder_.Location(loc_deco->source(), loc_deco->value() + 1);
}
}
+}
- return success();
+ast::Decoration* FunctionEmitter::SetLocation(
+ ast::DecorationList* decos,
+ ast::Decoration* replacement) {
+ if (!replacement) {
+ return nullptr;
+ }
+ for (auto*& deco : *decos) {
+ if (deco->Is<ast::LocationDecoration>()) {
+ // Replace this location decoration with a new one with one higher index.
+ // The old one doesn't leak because it's kept in the builder's AST node
+ // list.
+ ast::Decoration* result = deco;
+ deco = replacement;
+ return result;
+ }
+ }
+ // The list didn't have a location. Add it.
+ decos->push_back(replacement);
+ return nullptr;
}
bool FunctionEmitter::EmitPipelineOutput(std::string var_name,
@@ -1095,11 +1123,14 @@
index_prefix.push_back(0);
for (int i = 0; i < static_cast<int>(members.size()); ++i) {
index_prefix.back() = i;
+ auto* location = parser_impl_.GetMemberLocation(*struct_type, i);
+ auto* saved_location = SetLocation(decos, location);
if (!EmitPipelineOutput(var_name, var_type, decos, index_prefix,
members[i], forced_member_type, return_members,
return_exprs)) {
return false;
}
+ SetLocation(decos, saved_location);
}
return success();
}
@@ -1150,14 +1181,7 @@
return_exprs->push_back(load_source);
// Increment the location attribute, in case more parameters will follow.
- for (auto*& deco : *decos) {
- if (auto* loc_deco = deco->As<ast::LocationDecoration>()) {
- // Replace this location decoration with a new one with one higher index.
- // The old one doesn't leak because it's kept in the builder's AST node
- // list.
- deco = builder_.Location(loc_deco->source(), loc_deco->value() + 1);
- }
- }
+ IncrementLocation(decos);
return success();
}
diff --git a/src/reader/spirv/function.h b/src/reader/spirv/function.h
index 3a287a4..06f0450 100644
--- a/src/reader/spirv/function.h
+++ b/src/reader/spirv/function.h
@@ -466,6 +466,23 @@
ast::StructMemberList* return_members,
ast::ExpressionList* return_exprs);
+ /// Updates the decoration list, replacing an existing Location decoration
+ /// with another having one higher location value. Does nothing if no
+ /// location decoration exists.
+ /// Assumes the list contains at most one Location decoration.
+ /// @param decos the decoration list to modify
+ void IncrementLocation(ast::DecorationList* decos);
+
+ /// Updates the decoration list, placing a non-null location decoration into
+ /// the list, replacing an existing one if it exists. Does nothing if the
+ /// replacement is nullptr.
+ /// Assumes the list contains at most one Location decoration.
+ /// @param decos the decoration list to modify
+ /// @param replacement the location decoration to place into the list
+ /// @returns the location decoration that was replaced, if one was replaced.
+ ast::Decoration* SetLocation(ast::DecorationList* decos,
+ ast::Decoration* replacement);
+
/// Create an ast::BlockStatement representing the body of the function.
/// This creates the statement stack, which is non-empty for the lifetime
/// of the function.
diff --git a/src/reader/spirv/function_var_test.cc b/src/reader/spirv/function_var_test.cc
index cf2e9aa..9b07da7 100644
--- a/src/reader/spirv/function_var_test.cc
+++ b/src/reader/spirv/function_var_test.cc
@@ -2784,7 +2784,6 @@
OpReturn
OpFunctionEnd
)";
- std::cout << assembly << std::endl;
auto p = parser(test::Assemble(assembly));
ASSERT_TRUE(p->BuildAndParseInternalModule()) << p->error() << assembly;
auto fe = p->function_emitter(100);
diff --git a/src/reader/spirv/parser_impl.cc b/src/reader/spirv/parser_impl.cc
index db4383a..c5830c5 100644
--- a/src/reader/spirv/parser_impl.cc
+++ b/src/reader/spirv/parser_impl.cc
@@ -1098,6 +1098,8 @@
// apply the ReadOnly access control to the containing struct if all
// the members are non-writable.
is_non_writable = true;
+ } else if (decoration[0] == SpvDecorationLocation) {
+ // Location decorations are handled when emitting the entry point.
} else {
auto* ast_member_decoration =
ConvertMemberDecoration(type_id, member_index, decoration);
@@ -1134,7 +1136,7 @@
// Now make the struct.
auto sym = builder_.Symbols().Register(name);
ast::DecorationList ast_struct_decorations;
- if (is_block_decorated) {
+ if (is_block_decorated && struct_types_for_buffers_.count(type_id)) {
ast_struct_decorations.emplace_back(
create<ast::StructBlockDecoration>(Source{}));
}
@@ -1208,6 +1210,37 @@
if (!success_) {
return false;
}
+
+ // First record the structure types that should have a `block` decoration
+ // in WGSL. In particular, exclude user-defined pipeline IO in a
+ // block-decorated struct.
+ for (const auto& type_or_value : module_->types_values()) {
+ if (type_or_value.opcode() != SpvOpVariable) {
+ continue;
+ }
+ const auto& var = type_or_value;
+ const auto spirv_storage_class =
+ SpvStorageClass(var.GetSingleWordInOperand(0));
+ if ((spirv_storage_class != SpvStorageClassStorageBuffer) &&
+ (spirv_storage_class != SpvStorageClassUniform)) {
+ continue;
+ }
+ const auto* ptr_type = def_use_mgr_->GetDef(var.type_id());
+ if (ptr_type->opcode() != SpvOpTypePointer) {
+ return Fail() << "OpVariable type expected to be a pointer: "
+ << var.PrettyPrint();
+ }
+ const auto* store_type =
+ def_use_mgr_->GetDef(ptr_type->GetSingleWordInOperand(1));
+ if (store_type->opcode() == SpvOpTypeStruct) {
+ struct_types_for_buffers_.insert(store_type->result_id());
+ } else {
+ Fail() << "WGSL does not support arrays of buffers: "
+ << var.PrettyPrint();
+ }
+ }
+
+ // Now convert each type.
for (auto& type_or_const : module_->types_values()) {
const auto* type = type_mgr_->GetType(type_or_const.result_id());
if (type == nullptr) {
@@ -2630,6 +2663,22 @@
return namer_.GetMemberName(where->second, member_index);
}
+ast::Decoration* ParserImpl::GetMemberLocation(const Struct& struct_type,
+ int member_index) {
+ auto where = struct_id_for_symbol_.find(struct_type.name);
+ if (where == struct_id_for_symbol_.end()) {
+ Fail() << "no structure type registered for symbol";
+ return nullptr;
+ }
+ const auto type_id = where->second;
+ for (auto& deco : GetDecorationsForMember(type_id, member_index)) {
+ if ((deco.size() == 2) && (deco[0] == SpvDecorationLocation)) {
+ return create<ast::LocationDecoration>(Source{}, deco[1]);
+ }
+ }
+ return nullptr;
+}
+
WorkgroupSizeInfo::WorkgroupSizeInfo() = default;
WorkgroupSizeInfo::~WorkgroupSizeInfo() = default;
diff --git a/src/reader/spirv/parser_impl.h b/src/reader/spirv/parser_impl.h
index 3321fd9..b041605 100644
--- a/src/reader/spirv/parser_impl.h
+++ b/src/reader/spirv/parser_impl.h
@@ -379,6 +379,13 @@
/// @returns the field name
std::string GetMemberName(const Struct& struct_type, int member_index);
+ /// Returns the location decoration, if any on a struct member.
+ /// @param struct_type the parser's structure type.
+ /// @param member_index the member index
+ /// @returns a newly created location node, or nullptr
+ ast::Decoration* GetMemberLocation(const Struct& struct_type,
+ int member_index);
+
/// Creates an AST Variable node for a SPIR-V ID, including any attached
/// decorations, unless it's an ignorable builtin variable.
/// @param id the SPIR-V result ID
@@ -765,6 +772,10 @@
// "NonSemanticInfo." import is ignored.
std::unordered_set<uint32_t> ignored_imports_;
+ // The SPIR-V IDs of structure types that are the store type for buffer
+ // variables, either UBO or SSBO.
+ std::unordered_set<uint32_t> struct_types_for_buffers_;
+
// Bookkeeping for the gl_Position builtin.
// In Vulkan SPIR-V, it's the 0 member of the gl_PerVertex structure.
// But in WGSL we make a module-scope variable:
diff --git a/src/reader/spirv/parser_impl_module_var_test.cc b/src/reader/spirv/parser_impl_module_var_test.cc
index 78e1ea9..beeef85 100644
--- a/src/reader/spirv/parser_impl_module_var_test.cc
+++ b/src/reader/spirv/parser_impl_module_var_test.cc
@@ -6528,7 +6528,7 @@
EXPECT_EQ(got, expected) << got;
}
-TEST_F(SpvModuleScopeVarParserTest, Input_FlattenStruct) {
+TEST_F(SpvModuleScopeVarParserTest, Input_FlattenStruct_LocOnVariable) {
const std::string assembly = R"(
OpCapability Shader
OpMemoryModel Logical Simple
@@ -6993,7 +6993,7 @@
EXPECT_EQ(got, expected) << got;
}
-TEST_F(SpvModuleScopeVarParserTest, Output_FlattenStruct) {
+TEST_F(SpvModuleScopeVarParserTest, Output_FlattenStruct_LocOnVariable) {
const std::string assembly = R"(
OpCapability Shader
OpMemoryModel Logical Simple
@@ -7092,6 +7092,192 @@
EXPECT_EQ(got, expected) << got;
}
+TEST_F(SpvModuleScopeVarParserTest, FlattenStruct_LocOnMembers) {
+ // Block-decorated struct may have its members decorated with Location.
+ const std::string assembly = R"(
+ OpCapability Shader
+ OpMemoryModel Logical Simple
+ OpEntryPoint Vertex %main "main" %1 %2 %3
+
+ OpName %strct "Communicators"
+ OpMemberName %strct 0 "alice"
+ OpMemberName %strct 1 "bob"
+
+ OpMemberDecorate %strct 0 Location 9
+ OpMemberDecorate %strct 1 Location 11
+ OpDecorate %strct Block
+ OpDecorate %2 BuiltIn Position
+
+ %void = OpTypeVoid
+ %voidfn = OpTypeFunction %void
+ %float = OpTypeFloat 32
+ %v4float = OpTypeVector %float 4
+ %strct = OpTypeStruct %float %v4float
+
+ %11 = OpTypePointer Input %strct
+ %13 = OpTypePointer Output %strct
+
+ %1 = OpVariable %11 Input
+ %3 = OpVariable %13 Output
+
+ %12 = OpTypePointer Output %v4float
+ %2 = OpVariable %12 Output
+
+ %main = OpFunction %void None %voidfn
+ %entry = OpLabel
+ OpReturn
+ OpFunctionEnd
+)";
+ auto p = parser(test::Assemble(assembly));
+
+ ASSERT_TRUE(p->Parse()) << p->error() << assembly;
+ EXPECT_TRUE(p->error().empty());
+
+ const auto got = p->program().to_str();
+ const std::string expected = R"(Module{
+ Struct Communicators {
+ StructMember{alice: __f32}
+ StructMember{bob: __vec_4__f32}
+ }
+ Struct main_out {
+ StructMember{[[ BuiltinDecoration{position}
+ ]] x_2_1: __vec_4__f32}
+ StructMember{[[ LocationDecoration{9}
+ ]] x_3_1: __f32}
+ StructMember{[[ LocationDecoration{11}
+ ]] x_3_2: __vec_4__f32}
+ }
+ Variable{
+ x_1
+ private
+ undefined
+ __type_name_Communicators
+ }
+ Variable{
+ x_3
+ private
+ undefined
+ __type_name_Communicators
+ }
+ Variable{
+ x_2
+ private
+ undefined
+ __vec_4__f32
+ }
+ Function main_1 -> __void
+ ()
+ {
+ Return{}
+ }
+ Function main -> __type_name_main_out
+ StageDecoration{vertex}
+ (
+ VariableConst{
+ Decorations{
+ LocationDecoration{9}
+ }
+ x_1_param
+ none
+ undefined
+ __f32
+ }
+ VariableConst{
+ Decorations{
+ LocationDecoration{11}
+ }
+ x_1_param_1
+ none
+ undefined
+ __vec_4__f32
+ }
+ )
+ {
+ Assignment{
+ MemberAccessor[not set]{
+ Identifier[not set]{x_1}
+ Identifier[not set]{alice}
+ }
+ Identifier[not set]{x_1_param}
+ }
+ Assignment{
+ MemberAccessor[not set]{
+ Identifier[not set]{x_1}
+ Identifier[not set]{bob}
+ }
+ Identifier[not set]{x_1_param_1}
+ }
+ Call[not set]{
+ Identifier[not set]{main_1}
+ (
+ )
+ }
+ Return{
+ {
+ TypeConstructor[not set]{
+ __type_name_main_out
+ Identifier[not set]{x_2}
+ MemberAccessor[not set]{
+ Identifier[not set]{x_3}
+ Identifier[not set]{alice}
+ }
+ MemberAccessor[not set]{
+ Identifier[not set]{x_3}
+ Identifier[not set]{bob}
+ }
+ }
+ }
+ }
+ }
+}
+)";
+ EXPECT_EQ(got, expected) << got;
+}
+
+TEST_F(SpvModuleScopeVarParserTest, FlattenStruct_LocOnStruct) {
+ const std::string assembly = R"(
+ OpCapability Shader
+ OpMemoryModel Logical Simple
+ OpEntryPoint Vertex %main "main" %1 %2 %3
+
+ OpName %strct "Communicators"
+ OpMemberName %strct 0 "alice"
+ OpMemberName %strct 1 "bob"
+
+ OpDecorate %strct Location 9
+ OpDecorate %strct Block
+ OpDecorate %2 BuiltIn Position
+
+ %void = OpTypeVoid
+ %voidfn = OpTypeFunction %void
+ %float = OpTypeFloat 32
+ %v4float = OpTypeVector %float 4
+ %strct = OpTypeStruct %float %v4float
+
+ %11 = OpTypePointer Input %strct
+ %13 = OpTypePointer Output %strct
+
+ %1 = OpVariable %11 Input
+ %3 = OpVariable %13 Output
+
+ %12 = OpTypePointer Output %v4float
+ %2 = OpVariable %12 Output
+
+ %main = OpFunction %void None %voidfn
+ %entry = OpLabel
+ OpReturn
+ OpFunctionEnd
+)";
+ auto p = parser(test::Assemble(assembly));
+
+ // The validator rejects this because Location decorations
+ // can only go on OpVariable or members of a structure type.
+ ASSERT_FALSE(p->Parse()) << p->error() << assembly;
+ EXPECT_THAT(p->error(),
+ HasSubstr("Location decoration can only be applied to a variable "
+ "or member of a structure type"));
+}
+
} // namespace
} // namespace spirv
} // namespace reader