[spirv-writer] Fix accessing array of non-scalars.
Currently, if we access an array of non-scalar items we'll incorrectly
emit an OpVectorExtractDynamic which will fail as the result is not
scalar.
This CL updates the array accessor code such that if the base array is
an array of non-scalars we'll do load of the array and then access chain
into the loaded variable.
Bug: tint:267
Change-Id: Ia4d7052b57d8b31b835714b7b922c7859e3dce1f
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/29844
Commit-Queue: dan sinclair <dsinclair@chromium.org>
Commit-Queue: David Neto <dneto@google.com>
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/type_determiner.cc b/src/type_determiner.cc
index 2500468..06e8a2a 100644
--- a/src/type_determiner.cc
+++ b/src/type_determiner.cc
@@ -352,6 +352,13 @@
if (res->IsPointer()) {
ret = ctx_.type_mgr().Get(std::make_unique<ast::type::PointerType>(
ret, res->AsPointer()->storage_class()));
+ } else if (parent_type->IsArray() &&
+ !parent_type->AsArray()->type()->is_scalar()) {
+ // If we extract a non-scalar from an array then we also get a pointer. We
+ // will generate a Function storage class variable to store this
+ // into.
+ ret = ctx_.type_mgr().Get(std::make_unique<ast::type::PointerType>(
+ ret, ast::StorageClass::kFunction));
}
expr->set_result_type(ret);
diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc
index 4b11bf4..e80c60f 100644
--- a/src/type_determiner_test.cc
+++ b/src/type_determiner_test.cc
@@ -528,7 +528,7 @@
std::make_unique<ast::IdentifierExpression>("my_var"), std::move(idx));
EXPECT_TRUE(td()->DetermineResultType(&acc));
ASSERT_NE(acc.result_type(), nullptr);
- EXPECT_TRUE(acc.result_type()->IsF32());
+ EXPECT_TRUE(acc.result_type()->IsF32()) << acc.result_type()->type_name();
}
TEST_F(TypeDeterminerTest, Expr_ArrayAccessor_Matrix) {
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index f720116..6a99791 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -783,8 +783,11 @@
}
idx_id = GenerateLoadIfNeeded(expr->idx_expr()->result_type(), idx_id);
- // If the source is a pointer we access chain into it.
- if (info->source_type->IsPointer()) {
+ // If the source is a pointer we access chain into it. We also access chain
+ // into an array of non-scalar types.
+ if (info->source_type->IsPointer() ||
+ (info->source_type->IsArray() &&
+ !info->source_type->AsArray()->type()->is_scalar())) {
info->access_chain_indices.push_back(idx_id);
info->source_type = expr->result_type();
return true;
@@ -967,6 +970,37 @@
}
info.source_type = source->result_type();
+ // If our initial access in into an array, and that array is not a pointer,
+ // then we need to load that array into a variable in order to be access
+ // chain into the array
+ if (accessors[0]->IsArrayAccessor()) {
+ auto* ary_res_type =
+ accessors[0]->AsArrayAccessor()->array()->result_type();
+ if (!ary_res_type->IsPointer()) {
+ ast::type::PointerType ptr(ary_res_type, ast::StorageClass::kFunction);
+ auto result_type_id = GenerateTypeIfNeeded(&ptr);
+ if (result_type_id == 0) {
+ return 0;
+ }
+
+ auto ary_result = result_op();
+
+ ast::NullLiteral nl(ary_res_type);
+ auto init = GenerateLiteralIfNeeded(nullptr, &nl);
+
+ // If we're access chaining into an array then we must be in a function
+ push_function_var(
+ {Operand::Int(result_type_id), ary_result,
+ Operand::Int(ConvertStorageClass(ast::StorageClass::kFunction)),
+ Operand::Int(init)});
+
+ push_function_inst(spv::Op::OpStore,
+ {ary_result, Operand::Int(info.source_id)});
+
+ info.source_id = ary_result.to_i();
+ }
+ }
+
std::vector<uint32_t> access_chain_indices;
for (auto* accessor : accessors) {
if (accessor->IsArrayAccessor()) {
diff --git a/src/writer/spirv/builder_accessor_expression_test.cc b/src/writer/spirv/builder_accessor_expression_test.cc
index 7bafbed..d88f8e7 100644
--- a/src/writer/spirv/builder_accessor_expression_test.cc
+++ b/src/writer/spirv/builder_accessor_expression_test.cc
@@ -28,7 +28,10 @@
#include "src/ast/type/f32_type.h"
#include "src/ast/type/i32_type.h"
#include "src/ast/type/struct_type.h"
+#include "src/ast/type/u32_type.h"
#include "src/ast/type/vector_type.h"
+#include "src/ast/type_constructor_expression.h"
+#include "src/ast/uint_literal.h"
#include "src/ast/variable.h"
#include "src/context.h"
#include "src/type_determiner.h"
@@ -941,6 +944,90 @@
)");
}
+TEST_F(BuilderTest, Accessor_Array_Of_Vec) {
+ // const pos : array<vec2<f32>, 3> = array<vec2<f32>, 3>(
+ // vec2<f32>(0.0, 0.5),
+ // vec2<f32>(-0.5, -0.5),
+ // vec2<f32>(0.5, -0.5));
+ // pos[1]
+
+ ast::type::F32Type f32;
+ ast::type::U32Type u32;
+ ast::type::VectorType vec(&f32, 2);
+ ast::type::ArrayType arr(&vec, 3);
+
+ ast::ExpressionList ary_params;
+
+ ast::ExpressionList vec_params;
+ vec_params.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 0.0)));
+ vec_params.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 0.5)));
+ ary_params.push_back(std::make_unique<ast::TypeConstructorExpression>(
+ &vec, std::move(vec_params)));
+
+ vec_params.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, -0.5)));
+ vec_params.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, -0.5)));
+ ary_params.push_back(std::make_unique<ast::TypeConstructorExpression>(
+ &vec, std::move(vec_params)));
+
+ vec_params.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 0.5)));
+ vec_params.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, -0.5)));
+ ary_params.push_back(std::make_unique<ast::TypeConstructorExpression>(
+ &vec, std::move(vec_params)));
+
+ ast::Variable var("pos", ast::StorageClass::kPrivate, &arr);
+ var.set_is_const(true);
+ var.set_constructor(std::make_unique<ast::TypeConstructorExpression>(
+ &arr, std::move(ary_params)));
+
+ ast::ArrayAccessorExpression expr(
+ std::make_unique<ast::IdentifierExpression>("pos"),
+ std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::UintLiteral>(&u32, 1)));
+
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ td.RegisterVariableForTesting(&var);
+ ASSERT_TRUE(td.DetermineResultType(var.constructor())) << td.error();
+ ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+
+ Builder b(&mod);
+ b.push_function(Function{});
+ ASSERT_TRUE(b.GenerateFunctionVariable(&var)) << b.error();
+ EXPECT_EQ(b.GenerateAccessorExpression(&expr), 18u) << b.error();
+
+ EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeFloat 32
+%2 = OpTypeVector %3 2
+%4 = OpTypeInt 32 0
+%5 = OpConstant %4 3
+%1 = OpTypeArray %2 %5
+%6 = OpConstant %3 0
+%7 = OpConstant %3 0.5
+%8 = OpConstantComposite %2 %6 %7
+%9 = OpConstant %3 -0.5
+%10 = OpConstantComposite %2 %9 %9
+%11 = OpConstantComposite %2 %7 %9
+%12 = OpConstantComposite %1 %8 %10 %11
+%13 = OpTypePointer Function %1
+%15 = OpConstantNull %1
+%16 = OpConstant %4 1
+%17 = OpTypePointer Function %2
+)");
+ EXPECT_EQ(DumpInstructions(b.functions()[0].variables()),
+ R"(%14 = OpVariable %13 Function %15
+)");
+ EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+ R"(OpStore %14 %12
+%18 = OpAccessChain %17 %14 %16
+)");
+}
+
TEST_F(BuilderTest, DISABLED_Accessor_Array_NonPointer) {
// const a : array<f32, 3>;
// a[2]