spirv-reader: flatten struct pipeline inputs
Bug: tint:912
Change-Id: I01002f4996d3205af06edda092a1c18dcf6213e3
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/56220
Kokoro: Kokoro <noreply+kokoro@google.com>
Auto-Submit: David Neto <dneto@google.com>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index 43466e5..55c0a27 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -974,6 +974,18 @@
}
}
return success();
+ } else if (auto* struct_type = tip_type->As<Struct>()) {
+ const auto& members = struct_type->members;
+ index_prefix.push_back(0);
+ for (int i = 0; i < static_cast<int>(members.size()); ++i) {
+ index_prefix.back() = i;
+ if (!EmitInputParameter(var_name, var_type, decos, index_prefix,
+ members[i], forced_param_type, params,
+ statements)) {
+ return false;
+ }
+ }
+ return success();
}
const bool is_builtin = ast::HasDecoration<ast::BuiltinDecoration>(*decos);
@@ -1005,6 +1017,11 @@
} else if (auto* array_type = current_type->As<Array>()) {
store_dest = builder_.IndexAccessor(store_dest, builder_.Expr(index));
current_type = array_type->type->UnwrapAlias();
+ } else if (auto* struct_type = current_type->As<Struct>()) {
+ store_dest = builder_.MemberAccessor(
+ store_dest,
+ builder_.Expr(parser_impl_.GetMemberName(*struct_type, index)));
+ current_type = struct_type->members[index];
}
}
diff --git a/src/reader/spirv/parser_impl.cc b/src/reader/spirv/parser_impl.cc
index 5272aa1..db4383a 100644
--- a/src/reader/spirv/parser_impl.cc
+++ b/src/reader/spirv/parser_impl.cc
@@ -1144,7 +1144,9 @@
read_only_struct_types_.insert(ast_struct->name());
}
AddTypeDecl(sym, ast_struct);
- return ty_.Struct(sym, std::move(ast_member_types));
+ const auto* result = ty_.Struct(sym, std::move(ast_member_types));
+ struct_id_for_symbol_[sym] = type_id;
+ return result;
}
void ParserImpl::AddTypeDecl(Symbol name, ast::TypeDecl* decl) {
@@ -2618,6 +2620,16 @@
return def_use_mgr_ ? def_use_mgr_->GetDef(id) : nullptr;
}
+std::string ParserImpl::GetMemberName(const Struct& struct_type,
+ int member_index) {
+ auto where = struct_id_for_symbol_.find(struct_type.name);
+ if (where == struct_id_for_symbol_.end()) {
+ Fail() << "no structure type registered for symbol";
+ return "";
+ }
+ return namer_.GetMemberName(where->second, member_index);
+}
+
WorkgroupSizeInfo::WorkgroupSizeInfo() = default;
WorkgroupSizeInfo::~WorkgroupSizeInfo() = default;
diff --git a/src/reader/spirv/parser_impl.h b/src/reader/spirv/parser_impl.h
index 2b818dd..3321fd9 100644
--- a/src/reader/spirv/parser_impl.h
+++ b/src/reader/spirv/parser_impl.h
@@ -373,6 +373,12 @@
/// @returns the integer constant for its array size, or nullptr.
const spvtools::opt::analysis::IntConstant* GetArraySize(uint32_t var_id);
+ /// Returns the member name for the struct member.
+ /// @param struct_type the parser's structure type.
+ /// @param member_index the member index
+ /// @returns the field name
+ std::string GetMemberName(const Struct& struct_type, int member_index);
+
/// Creates an AST Variable node for a SPIR-V ID, including any attached
/// decorations, unless it's an ignorable builtin variable.
/// @param id the SPIR-V result ID
@@ -809,6 +815,9 @@
// adding duplicates.
std::unordered_set<Symbol> declared_types_;
+ // Maps a struct type name to the SPIR-V ID for the structure type.
+ std::unordered_map<Symbol, uint32_t> struct_id_for_symbol_;
+
/// Maps the SPIR-V ID of a module-scope builtin variable that should be
/// ignored or type-converted, to its builtin kind.
/// See also BuiltInPositionInfo which is a separate mechanism for a more
diff --git a/src/reader/spirv/parser_impl_module_var_test.cc b/src/reader/spirv/parser_impl_module_var_test.cc
index 684ee82..4243eb5 100644
--- a/src/reader/spirv/parser_impl_module_var_test.cc
+++ b/src/reader/spirv/parser_impl_module_var_test.cc
@@ -6528,6 +6528,126 @@
EXPECT_EQ(got, expected) << got;
}
+TEST_F(SpvModuleScopeVarParserTest, Input_FlattenStruct) {
+ const std::string assembly = R"(
+ OpCapability Shader
+ OpMemoryModel Logical Simple
+ OpEntryPoint Vertex %main "main" %1 %2
+
+ OpName %strct "Communicators"
+ OpMemberName %strct 0 "alice"
+ OpMemberName %strct 1 "bob"
+
+ OpDecorate %1 Location 9
+ OpDecorate %2 BuiltIn Position
+
+
+ %void = OpTypeVoid
+ %voidfn = OpTypeFunction %void
+ %float = OpTypeFloat 32
+ %v4float = OpTypeVector %float 4
+ %strct = OpTypeStruct %float %v4float
+
+ %11 = OpTypePointer Input %strct
+
+ %1 = OpVariable %11 Input
+
+ %12 = OpTypePointer Output %v4float
+ %2 = OpVariable %12 Output
+
+ %main = OpFunction %void None %voidfn
+ %entry = OpLabel
+ OpReturn
+ OpFunctionEnd
+)";
+ auto p = parser(test::Assemble(assembly));
+
+ ASSERT_TRUE(p->Parse()) << p->error() << assembly;
+ EXPECT_TRUE(p->error().empty());
+
+ const auto got = p->program().to_str();
+ const std::string expected = R"(Module{
+ Struct Communicators {
+ StructMember{alice: __f32}
+ StructMember{bob: __vec_4__f32}
+ }
+ Struct main_out {
+ StructMember{[[ BuiltinDecoration{position}
+ ]] x_2: __vec_4__f32}
+ }
+ Variable{
+ x_1
+ private
+ undefined
+ __type_name_Communicators
+ }
+ Variable{
+ x_2
+ private
+ undefined
+ __vec_4__f32
+ }
+ Function main_1 -> __void
+ ()
+ {
+ Return{}
+ }
+ Function main -> __type_name_main_out
+ StageDecoration{vertex}
+ (
+ VariableConst{
+ Decorations{
+ LocationDecoration{9}
+ }
+ x_1_param
+ none
+ undefined
+ __f32
+ }
+ VariableConst{
+ Decorations{
+ LocationDecoration{10}
+ }
+ x_1_param_1
+ none
+ undefined
+ __vec_4__f32
+ }
+ )
+ {
+ Assignment{
+ MemberAccessor[not set]{
+ Identifier[not set]{x_1}
+ Identifier[not set]{alice}
+ }
+ Identifier[not set]{x_1_param}
+ }
+ Assignment{
+ MemberAccessor[not set]{
+ Identifier[not set]{x_1}
+ Identifier[not set]{bob}
+ }
+ Identifier[not set]{x_1_param_1}
+ }
+ Call[not set]{
+ Identifier[not set]{main_1}
+ (
+ )
+ }
+ Return{
+ {
+ TypeConstructor[not set]{
+ __type_name_main_out
+ Identifier[not set]{x_2}
+ }
+ }
+ }
+ }
+}
+)";
+ EXPECT_EQ(got, expected) << got;
+}
+
TEST_F(SpvModuleScopeVarParserTest, Input_FlattenNested) {
const std::string assembly = R"(
OpCapability Shader