[spirv-reader] Add OpCompositeExtract
Bug: tint:3
Change-Id: I9d8c1cf2545e28ef0ddf89e55ce45ec19c50022a
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/23161
Reviewed-by: dan sinclair <dsinclair@google.com>
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index 05a9c9e..75276f2 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -45,6 +45,7 @@
#include "src/ast/sint_literal.h"
#include "src/ast/storage_class.h"
#include "src/ast/switch_statement.h"
+#include "src/ast/type/u32_type.h"
#include "src/ast/type_constructor_expression.h"
#include "src/ast/uint_literal.h"
#include "src/ast/unary_op.h"
@@ -2400,6 +2401,10 @@
ast_type, std::move(operands))};
}
+ if (opcode == SpvOpCompositeExtract) {
+ return MakeCompositeExtract(inst);
+ }
+
// builtin readonly function
// glsl.std.450 readonly function
@@ -2462,7 +2467,7 @@
// A SPIR-V access chain is a single instruction with multiple indices
// walking down into composites. The Tint AST represents this as
- // ever-deeper nested indexing expresions. Start off with an expression
+ // ever-deeper nested indexing expressions. Start off with an expression
// for the base, and then bury that inside nested indexing expressions.
TypedExpression current_expr(MakeOperand(inst, 0));
@@ -2574,6 +2579,113 @@
return current_expr;
}
+TypedExpression FunctionEmitter::MakeCompositeExtract(
+ const spvtools::opt::Instruction& inst) {
+ // This is structurally similar to creating an access chain, but
+ // the SPIR-V instruction has literal indices instead of IDs for indices.
+
+ // A SPIR-V composite extract is a single instruction with multiple
+ // literal indices walking down into composites. The Tint AST represents
+ // this as ever-deeper nested indexing expressions. Start off with an
+ // expression for the composite, and then bury that inside nested indexing
+ // expressions.
+ TypedExpression current_expr(MakeOperand(inst, 0));
+
+ auto make_index = [](uint32_t literal) {
+ ast::type::U32Type u32;
+ return std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::UintLiteral>(&u32, literal));
+ };
+ static const char* swizzles[] = {"x", "y", "z", "w"};
+
+ const auto composite = inst.GetSingleWordInOperand(0);
+ const auto composite_type_id = def_use_mgr_->GetDef(composite)->type_id();
+ const auto* current_type = type_mgr_->GetType(composite_type_id);
+ const auto num_in_operands = inst.NumInOperands();
+ for (uint32_t index = 1; index < num_in_operands; ++index) {
+ const uint32_t index_val = inst.GetSingleWordInOperand(index);
+ std::unique_ptr<ast::Expression> next_expr;
+ switch (current_type->kind()) {
+ case spvtools::opt::analysis::Type::kVector: {
+ // Try generating a MemberAccessor expression. That result in something
+ // like "foo.z", which is more idiomatic than "foo[2]".
+ if (current_type->AsVector()->element_count() <= index_val) {
+ Fail() << "CompositeExtract %" << inst.result_id() << " index value "
+ << index_val << " is out of bounds for vector of "
+ << current_type->AsVector()->element_count() << " elements";
+ return {};
+ }
+ if (index_val >= sizeof(swizzles) / sizeof(swizzles[0])) {
+ Fail() << "internal error: swizzle index " << index_val
+ << " is too big. Max handled index is "
+ << ((sizeof(swizzles) / sizeof(swizzles[0])) - 1);
+ }
+ auto letter_index =
+ std::make_unique<ast::IdentifierExpression>(swizzles[index_val]);
+ next_expr = std::make_unique<ast::MemberAccessorExpression>(
+ std::move(current_expr.expr), std::move(letter_index));
+ current_type = current_type->AsVector()->element_type();
+ break;
+ }
+ case spvtools::opt::analysis::Type::kMatrix:
+ // Check bounds
+ if (current_type->AsMatrix()->element_count() <= index_val) {
+ Fail() << "CompositeExtract %" << inst.result_id() << " index value "
+ << index_val << " is out of bounds for matrix of "
+ << current_type->AsMatrix()->element_count() << " elements";
+ return {};
+ }
+ if (index_val >= sizeof(swizzles) / sizeof(swizzles[0])) {
+ Fail() << "internal error: swizzle index " << index_val
+ << " is too big. Max handled index is "
+ << ((sizeof(swizzles) / sizeof(swizzles[0])) - 1);
+ }
+ // Use array syntax.
+ next_expr = std::make_unique<ast::ArrayAccessorExpression>(
+ std::move(current_expr.expr), make_index(index_val));
+ current_type = current_type->AsMatrix()->element_type();
+ break;
+ case spvtools::opt::analysis::Type::kArray:
+ // The array size could be a spec constant, and so it's not always
+ // statically checkable. Instead, rely on a runtime index clamp
+ // or runtime check to keep this safe.
+ next_expr = std::make_unique<ast::ArrayAccessorExpression>(
+ std::move(current_expr.expr), make_index(index_val));
+ current_type = current_type->AsArray()->element_type();
+ break;
+ case spvtools::opt::analysis::Type::kRuntimeArray:
+ Fail() << "can't do OpCompositeExtract on a runtime array";
+ return {};
+ case spvtools::opt::analysis::Type::kStruct: {
+ if (current_type->AsStruct()->element_types().size() <= index_val) {
+ Fail() << "CompositeExtract %" << inst.result_id() << " index value "
+ << index_val << " is out of bounds for structure %"
+ << type_mgr_->GetId(current_type) << " having "
+ << current_type->AsStruct()->element_types().size()
+ << " elements";
+ return {};
+ }
+ auto member_access =
+ std::make_unique<ast::IdentifierExpression>(namer_.GetMemberName(
+ type_mgr_->GetId(current_type), uint32_t(index_val)));
+
+ next_expr = std::make_unique<ast::MemberAccessorExpression>(
+ std::move(current_expr.expr), std::move(member_access));
+ current_type = current_type->AsStruct()->element_types()[index_val];
+ break;
+ }
+ default:
+ Fail() << "CompositeExtract with bad type %"
+ << type_mgr_->GetId(current_type) << " " << current_type->str();
+ return {};
+ }
+ current_expr.reset(TypedExpression(
+ parser_impl_.ConvertType(type_mgr_->GetId(current_type)),
+ std::move(next_expr)));
+ }
+ return current_expr;
+}
+
} // namespace spirv
} // namespace reader
} // namespace tint
diff --git a/src/reader/spirv/function.h b/src/reader/spirv/function.h
index 0e92907..686c924 100644
--- a/src/reader/spirv/function.h
+++ b/src/reader/spirv/function.h
@@ -433,6 +433,11 @@
/// @returns an AST expression for the instruction, or nullptr.
TypedExpression EmitGlslStd450ExtInst(const spvtools::opt::Instruction& inst);
+ /// Creates an expression for OpCompositeExtract
+ /// @param inst an OpCompositeExtract instruction.
+ /// @returns an AST expression for the instruction, or nullptr.
+ TypedExpression MakeCompositeExtract(const spvtools::opt::Instruction& inst);
+
/// Gets the block info for a block ID, if any exists
/// @param id the SPIR-V ID of the OpLabel instruction starting the block
/// @returns the block info for the given ID, if it exists, or nullptr
diff --git a/src/reader/spirv/function_composite_test.cc b/src/reader/spirv/function_composite_test.cc
index 6030a90..944db61 100644
--- a/src/reader/spirv/function_composite_test.cc
+++ b/src/reader/spirv/function_composite_test.cc
@@ -27,6 +27,7 @@
namespace spirv {
namespace {
+using ::testing::Eq;
using ::testing::HasSubstr;
std::string Preamble() {
@@ -54,6 +55,7 @@
%v2float = OpTypeVector %float 2
%m3v2float = OpTypeMatrix %v2float 3
+ %m3v2float_0 = OpConstantNull %m3v2float
%s_v2f_u_i = OpTypeStruct %v2float %uint %int
%a_u_5 = OpTypeArray %uint %uint_5
@@ -229,6 +231,283 @@
<< ToString(fe.ast_body());
}
+using SpvParserTest_CompositeExtract = SpvParserTest;
+
+TEST_F(SpvParserTest_CompositeExtract, Vector) {
+ const auto assembly = Preamble() + R"(
+ %100 = OpFunction %void None %voidfn
+ %entry = OpLabel
+ %1 = OpCompositeExtract %float %v2float_50_60 1
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto* p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+ FunctionEmitter fe(p, *spirv_function(100));
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+ EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(
+ Variable{
+ x_1
+ none
+ __f32
+ {
+ MemberAccessor{
+ TypeConstructor{
+ __vec_2__f32
+ ScalarConstructor{50.000000}
+ ScalarConstructor{60.000000}
+ }
+ Identifier{y}
+ }
+ }
+ })")) << ToString(fe.ast_body());
+}
+
+TEST_F(SpvParserTest_CompositeExtract, Vector_IndexTooBigError) {
+ const auto assembly = Preamble() + R"(
+ %100 = OpFunction %void None %voidfn
+ %entry = OpLabel
+ %1 = OpCompositeExtract %float %v2float_50_60 900
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto* p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+ FunctionEmitter fe(p, *spirv_function(100));
+ EXPECT_FALSE(fe.EmitBody());
+ EXPECT_THAT(p->error(), Eq("CompositeExtract %1 index value 900 is out of "
+ "bounds for vector of 2 elements"));
+}
+
+TEST_F(SpvParserTest_CompositeExtract, Matrix) {
+ const auto assembly = Preamble() + R"(
+ %ptr = OpTypePointer Function %m3v2float
+ %var = OpVariable %ptr Function
+
+ %100 = OpFunction %void None %voidfn
+ %entry = OpLabel
+ %1 = OpLoad %m3v2float %var
+ %2 = OpCompositeExtract %v2float %1 2
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto* p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+ FunctionEmitter fe(p, *spirv_function(100));
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+ EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(
+ Variable{
+ x_2
+ none
+ __vec_2__f32
+ {
+ ArrayAccessor{
+ Identifier{x_1}
+ ScalarConstructor{2}
+ }
+ }
+ })"))
+ << ToString(fe.ast_body());
+}
+
+TEST_F(SpvParserTest_CompositeExtract, Matrix_IndexTooBigError) {
+ const auto assembly = Preamble() + R"(
+ %ptr = OpTypePointer Function %m3v2float
+ %var = OpVariable %ptr Function
+
+ %100 = OpFunction %void None %voidfn
+ %entry = OpLabel
+ %1 = OpLoad %m3v2float %var
+ %2 = OpCompositeExtract %v2float %1 3
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto* p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+ FunctionEmitter fe(p, *spirv_function(100));
+ EXPECT_FALSE(fe.EmitBody()) << p->error();
+ EXPECT_THAT(p->error(), Eq("CompositeExtract %2 index value 3 is out of "
+ "bounds for matrix of 3 elements"));
+}
+
+TEST_F(SpvParserTest_CompositeExtract, Matrix_Vector) {
+ const auto assembly = Preamble() + R"(
+ %ptr = OpTypePointer Function %m3v2float
+ %var = OpVariable %ptr Function
+
+ %100 = OpFunction %void None %voidfn
+ %entry = OpLabel
+ %1 = OpLoad %m3v2float %var
+ %2 = OpCompositeExtract %float %1 2 1
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto* p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+ FunctionEmitter fe(p, *spirv_function(100));
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+ EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(
+ Variable{
+ x_2
+ none
+ __f32
+ {
+ MemberAccessor{
+ ArrayAccessor{
+ Identifier{x_1}
+ ScalarConstructor{2}
+ }
+ Identifier{y}
+ }
+ }
+ })"))
+ << ToString(fe.ast_body());
+}
+
+TEST_F(SpvParserTest_CompositeExtract, Array) {
+ const auto assembly = Preamble() + R"(
+ %ptr = OpTypePointer Function %a_u_5
+ %var = OpVariable %ptr Function
+
+ %100 = OpFunction %void None %voidfn
+ %entry = OpLabel
+ %1 = OpLoad %a_u_5 %var
+ %2 = OpCompositeExtract %uint %1 3
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto* p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+ FunctionEmitter fe(p, *spirv_function(100));
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+ EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(
+ Variable{
+ x_2
+ none
+ __u32
+ {
+ ArrayAccessor{
+ Identifier{x_1}
+ ScalarConstructor{3}
+ }
+ }
+ })"))
+ << ToString(fe.ast_body());
+}
+
+TEST_F(SpvParserTest_CompositeExtract, RuntimeArray_IsError) {
+ const auto assembly = Preamble() + R"(
+ %rtarr = OpTypeRuntimeArray %uint
+ %ptr = OpTypePointer Function %rtarr
+ %var = OpVariable %ptr Function
+
+ %100 = OpFunction %void None %voidfn
+ %entry = OpLabel
+ %1 = OpLoad %rtarr %var
+ %2 = OpCompositeExtract %uint %1 3
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto* p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+ FunctionEmitter fe(p, *spirv_function(100));
+ EXPECT_FALSE(fe.EmitBody()) << p->error();
+ EXPECT_THAT(p->error(), Eq("can't do OpCompositeExtract on a runtime array"));
+}
+
+TEST_F(SpvParserTest_CompositeExtract, Struct) {
+ const auto assembly = Preamble() + R"(
+ %ptr = OpTypePointer Function %s_v2f_u_i
+ %var = OpVariable %ptr Function
+
+ %100 = OpFunction %void None %voidfn
+ %entry = OpLabel
+ %1 = OpLoad %s_v2f_u_i %var
+ %2 = OpCompositeExtract %int %1 2
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto* p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+ FunctionEmitter fe(p, *spirv_function(100));
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+ EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(
+ Variable{
+ x_2
+ none
+ __i32
+ {
+ MemberAccessor{
+ Identifier{x_1}
+ Identifier{field2}
+ }
+ }
+ })"))
+ << ToString(fe.ast_body());
+}
+
+TEST_F(SpvParserTest_CompositeExtract, Struct_IndexTooBigError) {
+ const auto assembly = Preamble() + R"(
+ %ptr = OpTypePointer Function %s_v2f_u_i
+ %var = OpVariable %ptr Function
+
+ %100 = OpFunction %void None %voidfn
+ %entry = OpLabel
+ %1 = OpLoad %s_v2f_u_i %var
+ %2 = OpCompositeExtract %int %1 40
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto* p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+ FunctionEmitter fe(p, *spirv_function(100));
+ EXPECT_FALSE(fe.EmitBody());
+ EXPECT_THAT(p->error(), Eq("CompositeExtract %2 index value 40 is out of "
+ "bounds for structure %23 having 3 elements"));
+}
+
+TEST_F(SpvParserTest_CompositeExtract, Struct_Array_Matrix_Vector) {
+ const auto assembly = Preamble() + R"(
+ %a_mat = OpTypeArray %m3v2float %uint_3
+ %s = OpTypeStruct %uint %a_mat
+ %ptr = OpTypePointer Function %s
+ %var = OpVariable %ptr Function
+
+ %100 = OpFunction %void None %voidfn
+ %entry = OpLabel
+ %1 = OpLoad %s %var
+ %2 = OpCompositeExtract %float %1 1 2 0 1
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto* p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+ FunctionEmitter fe(p, *spirv_function(100));
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+ EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(
+ Variable{
+ x_2
+ none
+ __f32
+ {
+ MemberAccessor{
+ ArrayAccessor{
+ ArrayAccessor{
+ MemberAccessor{
+ Identifier{x_1}
+ Identifier{field1}
+ }
+ ScalarConstructor{2}
+ }
+ ScalarConstructor{0}
+ }
+ Identifier{y}
+ }
+ }
+ })"))
+ << ToString(fe.ast_body());
+}
+
} // namespace
} // namespace spirv
} // namespace reader