[reader-spirv] Convert array, runtime array types
Bug: tint:3
Change-Id: If0d7d38cc777bce3d86dfd83669c1572331d4ed6
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/17800
Reviewed-by: dan sinclair <dsinclair@google.com>
diff --git a/src/reader/spirv/parser_impl.cc b/src/reader/spirv/parser_impl.cc
index 1f6b6d4..f0fe50f 100644
--- a/src/reader/spirv/parser_impl.cc
+++ b/src/reader/spirv/parser_impl.cc
@@ -15,6 +15,7 @@
#include "src/reader/spirv/parser_impl.h"
#include <cstring>
+#include <limits>
#include <memory>
#include <string>
#include <utility>
@@ -24,6 +25,7 @@
#include "source/opt/module.h"
#include "source/opt/type_manager.h"
#include "spirv-tools/libspirv.hpp"
+#include "src/ast/type/array_type.h"
#include "src/ast/type/bool_type.h"
#include "src/ast/type/f32_type.h"
#include "src/ast/type/i32_type.h"
@@ -46,6 +48,7 @@
ParserImpl::ParserImpl(Context* ctx, const std::vector<uint32_t>& spv_binary)
: Reader(ctx),
+ ctx_(ctx),
spv_binary_(spv_binary),
fail_stream_(&success_, &errors_),
namer_(fail_stream_),
@@ -180,6 +183,56 @@
// In the error case, we'll already have emitted a diagnostic.
break;
}
+ case spvtools::opt::analysis::Type::kRuntimeArray: {
+ const auto* rtarr_ty = spirv_type->AsRuntimeArray();
+ auto* ast_elem_ty =
+ ConvertType(type_mgr_->GetId(rtarr_ty->element_type()));
+ if (ast_elem_ty != nullptr) {
+ result = ctx_.type_mgr->Get(
+ std::make_unique<ast::type::ArrayType>(ast_elem_ty));
+ }
+ // In the error case, we'll already have emitted a diagnostic.
+ break;
+ }
+ case spvtools::opt::analysis::Type::kArray: {
+ const auto* arr_ty = spirv_type->AsArray();
+ auto* ast_elem_ty = ConvertType(type_mgr_->GetId(arr_ty->element_type()));
+ if (ast_elem_ty == nullptr) {
+ // In the error case, we'll already have emitted a diagnostic.
+ break;
+ }
+ const auto& length_info = arr_ty->length_info();
+ if (length_info.words.empty()) {
+ // The internal representation is invalid. The discriminant vector
+ // is mal-formed.
+ Fail() << "internal error: Array length info is invalid";
+ return nullptr;
+ }
+ if (length_info.words[0] !=
+ spvtools::opt::analysis::Array::LengthInfo::kConstant) {
+ Fail() << "Array type " << type_id
+ << " length is a specialization constant";
+ return nullptr;
+ }
+ const auto* constant =
+ constant_mgr_->FindDeclaredConstant(length_info.id);
+ if (constant == nullptr) {
+ Fail() << "Array type " << type_id << " length ID " << length_info.id
+ << " does not name an OpConstant";
+ return nullptr;
+ }
+ const uint64_t num_elem = constant->GetZeroExtendedValue();
+ // For now, limit to only 32bits.
+ if (num_elem > std::numeric_limits<uint32_t>::max()) {
+ Fail() << "Array type " << type_id
+ << " has too many elements (more than can fit in 32 bits): "
+ << num_elem;
+ return nullptr;
+ }
+ result = ctx_.type_mgr->Get(std::make_unique<ast::type::ArrayType>(
+ ast_elem_ty, static_cast<uint32_t>(num_elem)));
+ break;
+ }
default:
// The error diagnostic will be generated below because result is still
// nullptr.
diff --git a/src/reader/spirv/parser_impl.h b/src/reader/spirv/parser_impl.h
index a422cb6..9c4fa58 100644
--- a/src/reader/spirv/parser_impl.h
+++ b/src/reader/spirv/parser_impl.h
@@ -122,6 +122,8 @@
/// Emit entry point AST nodes.
bool EmitEntryPoints();
+ // The Tint context
+ Context ctx_;
// The SPIR-V binary we're parsing
std::vector<uint32_t> spv_binary_;
diff --git a/src/reader/spirv/parser_impl_convert_type_test.cc b/src/reader/spirv/parser_impl_convert_type_test.cc
index 4a0dfdf..72c6fc3 100644
--- a/src/reader/spirv/parser_impl_convert_type_test.cc
+++ b/src/reader/spirv/parser_impl_convert_type_test.cc
@@ -17,6 +17,7 @@
#include <vector>
#include "gmock/gmock.h"
+#include "src/ast/type/array_type.h"
#include "src/ast/type/matrix_type.h"
#include "src/ast/type/vector_type.h"
#include "src/reader/spirv/parser_impl.h"
@@ -324,6 +325,96 @@
EXPECT_TRUE(p->error().empty());
}
+TEST_F(SpvParserTest, ConvertType_RuntimeArray) {
+ auto p = parser(test::Assemble(R"(
+ %uint = OpTypeInt 32 0
+ %10 = OpTypeRuntimeArray %uint
+ )"));
+ EXPECT_TRUE(p->BuildAndParseInternalModule());
+
+ auto* type = p->ConvertType(10);
+ ASSERT_NE(type, nullptr);
+ EXPECT_TRUE(type->IsArray());
+ auto* arr_type = type->AsArray();
+ EXPECT_TRUE(arr_type->IsRuntimeArray());
+ ASSERT_NE(arr_type, nullptr);
+ ASSERT_EQ(arr_type->size(), 0u);
+ auto* elem_type = arr_type->type();
+ ASSERT_NE(elem_type, nullptr);
+ EXPECT_TRUE(elem_type->IsU32());
+ EXPECT_TRUE(p->error().empty());
+}
+
+TEST_F(SpvParserTest, ConvertType_Array) {
+ auto p = parser(test::Assemble(R"(
+ %uint = OpTypeInt 32 0
+ %uint_42 = OpConstant %uint 42
+ %10 = OpTypeArray %uint %uint_42
+ )"));
+ EXPECT_TRUE(p->BuildAndParseInternalModule());
+
+ auto* type = p->ConvertType(10);
+ ASSERT_NE(type, nullptr);
+ EXPECT_TRUE(type->IsArray());
+ auto* arr_type = type->AsArray();
+ EXPECT_FALSE(arr_type->IsRuntimeArray());
+ ASSERT_NE(arr_type, nullptr);
+ ASSERT_EQ(arr_type->size(), 42u);
+ auto* elem_type = arr_type->type();
+ ASSERT_NE(elem_type, nullptr);
+ EXPECT_TRUE(elem_type->IsU32());
+ EXPECT_TRUE(p->error().empty());
+}
+
+TEST_F(SpvParserTest, ConvertType_ArrayBadLengthIsSpecConstantValue) {
+ auto p = parser(test::Assemble(R"(
+ OpDecorate %uint_42 SpecId 12
+ %uint = OpTypeInt 32 0
+ %uint_42 = OpSpecConstant %uint 42
+ %10 = OpTypeArray %uint %uint_42
+ )"));
+ EXPECT_TRUE(p->BuildAndParseInternalModule());
+
+ auto* type = p->ConvertType(10);
+ ASSERT_EQ(type, nullptr);
+ EXPECT_THAT(p->error(),
+ Eq("Array type 10 length is a specialization constant"));
+}
+
+TEST_F(SpvParserTest, ConvertType_ArrayBadLengthIsSpecConstantExpr) {
+ auto p = parser(test::Assemble(R"(
+ %uint = OpTypeInt 32 0
+ %uint_42 = OpConstant %uint 42
+ %sum = OpSpecConstantOp %uint IAdd %uint_42 %uint_42
+ %10 = OpTypeArray %uint %sum
+ )"));
+ EXPECT_TRUE(p->BuildAndParseInternalModule());
+
+ auto* type = p->ConvertType(10);
+ ASSERT_EQ(type, nullptr);
+ EXPECT_THAT(p->error(),
+ Eq("Array type 10 length is a specialization constant"));
+}
+
+// TODO(dneto): Maybe add a test where the length operand is not a constant.
+// E.g. it's the ID of a type. That won't validate, and the SPIRV-Tools
+// optimizer representation doesn't handle it and asserts out instead.
+
+TEST_F(SpvParserTest, ConvertType_ArrayBadTooBig) {
+ auto p = parser(test::Assemble(R"(
+ %uint64 = OpTypeInt 64 0
+ %uint64_big = OpConstant %uint64 5000000000
+ %10 = OpTypeArray %uint64 %uint64_big
+ )"));
+ EXPECT_TRUE(p->BuildAndParseInternalModule());
+
+ auto* type = p->ConvertType(10);
+ ASSERT_EQ(type, nullptr);
+ // TODO(dneto): Right now it's rejected earlier in the flow because
+ // we can't even utter the uint64 type.
+ EXPECT_THAT(p->error(), Eq("unhandled integer width: 64"));
+}
+
} // namespace
} // namespace spirv
} // namespace reader
diff --git a/src/reader/spirv/parser_impl_test_helper.h b/src/reader/spirv/parser_impl_test_helper.h
index 97ac096..8c40a2b 100644
--- a/src/reader/spirv/parser_impl_test_helper.h
+++ b/src/reader/spirv/parser_impl_test_helper.h
@@ -42,8 +42,8 @@
}
/// Retrieves the parser from the helper
- /// @param input the string to parse
- /// @returns the parser implementation
+ /// @param input the SPIR-V binary to parse
+ /// @returns a parser for the given binary
ParserImpl* parser(const std::vector<uint32_t>& input) {
impl_ = std::make_unique<ParserImpl>(&ctx_, input);
return impl_.get();