writer/spirv: Fix dynamic array accessors
If the initial array accessor in a chain uses a non-literal index, use
the path that copies the source to a function variable, and then
perform a load from the OpAccessChain result if necessary.
Fixed: tint:426
Change-Id: Ie2f3f388170c02c1d6b73355f0b3bc49c3d3a4e5
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/49800
Commit-Queue: James Price <jrprice@google.com>
Auto-Submit: James Price <jrprice@google.com>
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index 8a74ad5..7148a78 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -822,11 +822,8 @@
auto* type = TypeOf(expr->idx_expr());
idx_id = GenerateLoadIfNeeded(type, idx_id);
- // If the source is a pointer we access chain into it. We also access chain
- // into an array of non-scalar types.
- if (info->source_type->Is<sem::Pointer>() ||
- (info->source_type->Is<sem::ArrayType>() &&
- !info->source_type->As<sem::ArrayType>()->type()->is_scalar())) {
+ // If the source is a pointer, we access chain into it.
+ if (info->source_type->Is<sem::Pointer>()) {
info->access_chain_indices.push_back(idx_id);
info->source_type = TypeOf(expr);
return true;
@@ -1063,17 +1060,21 @@
}
info.source_type = TypeOf(source);
- // If our initial access is into an array of non-scalar types, and that array
- // is not a pointer, then we need to load that array into a variable in order
- // to access chain into the array.
+ // If our initial access is into a non-pointer array, and either has a
+ // non-scalar element type or the accessor uses a non-literal index, then we
+ // need to load that array into a variable in order to access chain into it.
+ // TODO(jrprice): The non-scalar part shouldn't be necessary, but is tied to
+ // how the Resolver currently determines the type of these expression. This
+ // should be fixed when proper support for ptr/ref types is implemented.
if (auto* array = accessors[0]->As<ast::ArrayAccessorExpression>()) {
- auto* ary_res_type = TypeOf(array->array());
-
- if (!ary_res_type->Is<sem::Pointer>() &&
- (ary_res_type->Is<sem::ArrayType>() &&
- !ary_res_type->As<sem::ArrayType>()->type()->is_scalar())) {
- sem::Pointer ptr(ary_res_type, ast::StorageClass::kFunction);
- auto result_type_id = GenerateTypeIfNeeded(&ptr);
+ auto* ary_res_type = TypeOf(array->array())->As<sem::ArrayType>();
+ if (ary_res_type &&
+ (!ary_res_type->type()->is_scalar() ||
+ !array->idx_expr()->Is<ast::ScalarConstructorExpression>())) {
+ // Wrap the source type in a pointer to function storage.
+ auto ptr =
+ builder_.ty.pointer(ary_res_type, ast::StorageClass::kFunction);
+ auto result_type_id = GenerateTypeIfNeeded(ptr);
if (result_type_id == 0) {
return 0;
}
@@ -1094,6 +1095,7 @@
}
info.source_id = ary_result.to_i();
+ info.source_type = ptr;
}
}
@@ -1115,7 +1117,17 @@
}
if (!info.access_chain_indices.empty()) {
- auto result_type_id = GenerateTypeIfNeeded(TypeOf(expr));
+ bool needs_load = false;
+ auto* ptr = TypeOf(expr);
+ if (!ptr->Is<sem::Pointer>()) {
+ // We are performing an access chain but the final result is not a
+ // pointer, so we need to perform a load to get it. This happens when we
+ // have to copy the source expression into a function variable.
+ ptr = builder_.ty.pointer(ptr, ast::StorageClass::kFunction);
+ needs_load = true;
+ }
+
+ auto result_type_id = GenerateTypeIfNeeded(ptr);
if (result_type_id == 0) {
return 0;
}
@@ -1133,6 +1145,11 @@
return false;
}
info.source_id = result_id;
+
+ // Load from the access chain result if required.
+ if (needs_load) {
+ info.source_id = GenerateLoadIfNeeded(ptr, result_id);
+ }
}
return info.source_id;
diff --git a/src/writer/spirv/builder_accessor_expression_test.cc b/src/writer/spirv/builder_accessor_expression_test.cc
index d0c409b..80819a0 100644
--- a/src/writer/spirv/builder_accessor_expression_test.cc
+++ b/src/writer/spirv/builder_accessor_expression_test.cc
@@ -895,13 +895,54 @@
)");
}
-TEST_F(BuilderTest, DISABLED_Accessor_Array_NonPointer_Dynamic) {
+TEST_F(BuilderTest, Accessor_Array_NonPointer_Dynamic) {
// let a : array<f32, 3>;
// idx : i32
// a[idx]
- //
- // This needs to copy the array to an OpVariable in the Function storage class
- // and then access chain into it and load the result.
+
+ auto* var = GlobalConst("a", ty.array<f32, 3>(),
+ Construct(ty.array<f32, 3>(), 0.0f, 0.5f, 1.0f));
+
+ auto* idx = Var("idx", ty.i32(), ast::StorageClass::kFunction);
+ auto* expr = IndexAccessor("a", idx);
+
+ ast::StatementList body;
+ body.push_back(WrapInStatement(idx));
+ body.push_back(WrapInStatement(expr));
+ WrapInFunction(body);
+
+ spirv::Builder& b = Build();
+
+ b.push_function(Function{});
+ ASSERT_TRUE(b.GenerateGlobalVariable(var)) << b.error();
+ ASSERT_TRUE(b.GenerateFunctionVariable(idx)) << b.error();
+ EXPECT_EQ(b.GenerateAccessorExpression(expr), 19u) << b.error();
+
+ EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32
+%3 = OpTypeInt 32 0
+%4 = OpConstant %3 3
+%1 = OpTypeArray %2 %4
+%5 = OpConstant %2 0
+%6 = OpConstant %2 0.5
+%7 = OpConstant %2 1
+%8 = OpConstantComposite %1 %5 %6 %7
+%11 = OpTypeInt 32 1
+%10 = OpTypePointer Function %11
+%12 = OpConstantNull %11
+%13 = OpTypePointer Function %1
+%15 = OpConstantNull %1
+%17 = OpTypePointer Function %2
+)");
+ EXPECT_EQ(DumpInstructions(b.functions()[0].variables()),
+ R"(%9 = OpVariable %10 Function %12
+%14 = OpVariable %13 Function %15
+)");
+ EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+ R"(OpStore %14 %8
+%16 = OpLoad %11 %9
+%18 = OpAccessChain %17 %14 %16
+%19 = OpLoad %2 %18
+)");
}
} // namespace