[spirv-writer] Handle non-pointer struct member accessors
These map to OpCompositeExtract instructions.
Fixed: tint:662
Change-Id: Ibd865bdb16326de7932157cbdfe543394415b3ff
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/45940
Auto-Submit: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: James Price <jrprice@google.com>
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index bec0606..440699a 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -861,30 +861,43 @@
// If the data_type is a structure we're accessing a member, if it's a
// vector we're accessing a swizzle.
if (data_type->Is<type::Struct>()) {
- if (!info->source_type->Is<type::Pointer>()) {
- error_ =
- "Attempting to access a struct member on a non-pointer. Something is "
- "wrong";
- return false;
- }
-
auto* strct = data_type->As<type::Struct>()->impl();
auto symbol = expr->member()->symbol();
- uint32_t i = 0;
- for (; i < strct->members().size(); ++i) {
- auto* member = strct->members()[i];
+ uint32_t idx = 0;
+ for (; idx < strct->members().size(); ++idx) {
+ auto* member = strct->members()[idx];
if (member->symbol() == symbol) {
break;
}
}
- auto idx_id = GenerateConstantIfNeeded(ScalarConstant::U32(i));
- if (idx_id == 0) {
- return 0;
+ if (info->source_type->Is<type::Pointer>()) {
+ auto idx_id = GenerateConstantIfNeeded(ScalarConstant::U32(idx));
+ if (idx_id == 0) {
+ return 0;
+ }
+ info->access_chain_indices.push_back(idx_id);
+ info->source_type = expr_type;
+ } else {
+ auto result_type_id = GenerateTypeIfNeeded(expr_type);
+ if (result_type_id == 0) {
+ return false;
+ }
+
+ auto extract = result_op();
+ auto extract_id = extract.to_i();
+ if (!push_function_inst(
+ spv::Op::OpCompositeExtract,
+ {Operand::Int(result_type_id), extract,
+ Operand::Int(info->source_id), Operand::Int(idx)})) {
+ return false;
+ }
+
+ info->source_id = extract_id;
+ info->source_type = expr_type;
}
- info->access_chain_indices.push_back(idx_id);
- info->source_type = expr_type;
+
return true;
}
diff --git a/src/writer/spirv/builder_accessor_expression_test.cc b/src/writer/spirv/builder_accessor_expression_test.cc
index 78ec1ad..accad83 100644
--- a/src/writer/spirv/builder_accessor_expression_test.cc
+++ b/src/writer/spirv/builder_accessor_expression_test.cc
@@ -255,6 +255,7 @@
TEST_F(BuilderTest, MemberAccessor_Nested) {
// inner_struct {
// a : f32
+ // b : f32
// }
// my_struct {
// inner : inner_struct
@@ -270,7 +271,7 @@
auto* s_type = Structure("my_struct", {Member("inner", inner_struct)});
auto* var = Global("ident", s_type, ast::StorageClass::kFunction);
- auto* expr = MemberAccessor(MemberAccessor("ident", "inner"), "a");
+ auto* expr = MemberAccessor(MemberAccessor("ident", "inner"), "b");
WrapInFunction(expr);
spirv::Builder& b = Build();
@@ -278,7 +279,7 @@
b.push_function(Function{});
ASSERT_TRUE(b.GenerateFunctionVariable(var)) << b.error();
- EXPECT_EQ(b.GenerateAccessorExpression(expr), 10u);
+ EXPECT_EQ(b.GenerateAccessorExpression(expr), 11u);
EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeFloat 32
%4 = OpTypeStruct %5 %5
@@ -287,13 +288,92 @@
%6 = OpConstantNull %3
%7 = OpTypeInt 32 0
%8 = OpConstant %7 0
-%9 = OpTypePointer Function %5
+%9 = OpConstant %7 1
+%10 = OpTypePointer Function %5
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].variables()),
R"(%1 = OpVariable %2 Function %6
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
- R"(%10 = OpAccessChain %9 %1 %8 %8
+ R"(%11 = OpAccessChain %10 %1 %8 %9
+)");
+}
+
+TEST_F(BuilderTest, MemberAccessor_NonPointer) {
+ // my_struct {
+ // a : f32
+ // b : f32
+ // }
+ // const ident : my_struct = my_struct();
+ // ident.b
+
+ auto* s = Structure("my_struct", {
+ Member("a", ty.f32()),
+ Member("b", ty.f32()),
+ });
+
+ auto* var = GlobalConst("ident", s, Construct(s, 0.f, 0.f));
+
+ auto* expr = MemberAccessor("ident", "b");
+ WrapInFunction(expr);
+
+ spirv::Builder& b = Build();
+
+ b.push_function(Function{});
+ ASSERT_TRUE(b.GenerateFunctionVariable(var)) << b.error();
+
+ EXPECT_EQ(b.GenerateAccessorExpression(expr), 5u);
+
+ EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32
+%1 = OpTypeStruct %2 %2
+%3 = OpConstant %2 0
+%4 = OpConstantComposite %1 %3 %3
+)");
+ EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+ R"(%5 = OpCompositeExtract %2 %4 1
+)");
+}
+
+TEST_F(BuilderTest, MemberAccessor_Nested_NonPointer) {
+ // inner_struct {
+ // a : f32
+ // b : f32
+ // }
+ // my_struct {
+ // inner : inner_struct
+ // }
+ //
+ // const ident : my_struct = my_struct();
+ // ident.inner.a
+ auto* inner_struct = Structure("Inner", {
+ Member("a", ty.f32()),
+ Member("b", ty.f32()),
+ });
+
+ auto* s_type = Structure("my_struct", {Member("inner", inner_struct)});
+
+ auto* var = GlobalConst("ident", s_type,
+ Construct(s_type, Construct(inner_struct, 0.f, 0.f)));
+ auto* expr = MemberAccessor(MemberAccessor("ident", "inner"), "b");
+ WrapInFunction(expr);
+
+ spirv::Builder& b = Build();
+
+ b.push_function(Function{});
+ ASSERT_TRUE(b.GenerateFunctionVariable(var)) << b.error();
+
+ EXPECT_EQ(b.GenerateAccessorExpression(expr), 8u);
+
+ EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeFloat 32
+%2 = OpTypeStruct %3 %3
+%1 = OpTypeStruct %2
+%4 = OpConstant %3 0
+%5 = OpConstantComposite %2 %4 %4
+%6 = OpConstantComposite %1 %5
+)");
+ EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+ R"(%7 = OpCompositeExtract %2 %6 0
+%8 = OpCompositeExtract %3 %7 1
)");
}
@@ -753,31 +833,6 @@
// out of the ScalarConstructor as extract requires integer indices.
}
-TEST_F(BuilderTest, DISABLED_Accessor_Struct_NonPointer) {
- // type A = struct {
- // a : f32;
- // b : f32;
- // };
- // const b : A;
- // b.b
- //
- // This needs to do an OpCompositeExtract on the struct.
-}
-
-TEST_F(BuilderTest, DISABLED_Accessor_NonPointer_Multi) {
- // type A = struct {
- // a : f32;
- // b : vec3<f32, 3>;
- // };
- // type B = struct {
- // c : A;
- // }
- // const b : array<B, 3>;
- // b[2].c.b.yx.x
- //
- // This needs to do an OpCompositeExtract similar to the AccessChain case
-}
-
} // namespace
} // namespace spirv
} // namespace writer