[spirv-reader] Support access chain
Bug: tint:3
Change-Id: Ibdb6698c4a97ce66ed533a9bf007bc352a09244e
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/21641
Reviewed-by: dan sinclair <dsinclair@google.com>
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index 94aa686..8b2617c 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -23,10 +23,12 @@
#include "source/opt/function.h"
#include "source/opt/instruction.h"
#include "source/opt/module.h"
+#include "src/ast/array_accessor_expression.h"
#include "src/ast/as_expression.h"
#include "src/ast/assignment_statement.h"
#include "src/ast/binary_expression.h"
#include "src/ast/identifier_expression.h"
+#include "src/ast/member_accessor_expression.h"
#include "src/ast/scalar_constructor_expression.h"
#include "src/ast/storage_class.h"
#include "src/ast/uint_literal.h"
@@ -1492,32 +1494,31 @@
return Fail() << "unhandled instruction with opcode " << inst.opcode();
}
+TypedExpression FunctionEmitter::MakeOperand(
+ const spvtools::opt::Instruction& inst,
+ uint32_t operand_index) {
+ auto expr = this->MakeExpression(inst.GetSingleWordInOperand(operand_index));
+ return parser_impl_.RectifyOperandSignedness(inst.opcode(), std::move(expr));
+}
+
TypedExpression FunctionEmitter::MaybeEmitCombinatorialValue(
const spvtools::opt::Instruction& inst) {
if (inst.result_id() == 0) {
return {};
}
- // TODO(dneto): Fill in the following cases.
-
- auto operand = [this, &inst](uint32_t operand_index) {
- auto expr =
- this->MakeExpression(inst.GetSingleWordInOperand(operand_index));
- return parser_impl_.RectifyOperandSignedness(inst.opcode(),
- std::move(expr));
- };
+ const auto opcode = inst.opcode();
ast::type::Type* ast_type =
inst.type_id() != 0 ? parser_impl_.ConvertType(inst.type_id()) : nullptr;
- auto binary_op = ConvertBinaryOp(inst.opcode());
+ auto binary_op = ConvertBinaryOp(opcode);
if (binary_op != ast::BinaryOp::kNone) {
- auto arg0 = operand(0);
- auto arg1 = operand(1);
+ auto arg0 = MakeOperand(inst, 0);
+ auto arg1 = MakeOperand(inst, 1);
auto binary_expr = std::make_unique<ast::BinaryExpression>(
binary_op, std::move(arg0.expr), std::move(arg1.expr));
- auto* forced_result_ty =
- parser_impl_.ForcedResultType(inst.opcode(), arg0.type);
+ auto* forced_result_ty = parser_impl_.ForcedResultType(opcode, arg0.type);
if (forced_result_ty && forced_result_ty != ast_type) {
return {ast_type, std::make_unique<ast::AsExpression>(
ast_type, std::move(binary_expr))};
@@ -1526,12 +1527,11 @@
}
auto unary_op = ast::UnaryOp::kNegation;
- if (GetUnaryOp(inst.opcode(), &unary_op)) {
- auto arg0 = operand(0);
+ if (GetUnaryOp(opcode, &unary_op)) {
+ auto arg0 = MakeOperand(inst, 0);
auto unary_expr = std::make_unique<ast::UnaryOpExpression>(
unary_op, std::move(arg0.expr));
- auto* forced_result_ty =
- parser_impl_.ForcedResultType(inst.opcode(), arg0.type);
+ auto* forced_result_ty = parser_impl_.ForcedResultType(opcode, arg0.type);
if (forced_result_ty && forced_result_ty != ast_type) {
return {ast_type, std::make_unique<ast::AsExpression>(
ast_type, std::move(unary_expr))};
@@ -1539,16 +1539,19 @@
return {ast_type, std::move(unary_expr)};
}
- if (inst.opcode() == SpvOpBitcast) {
- auto* target_ty = parser_impl_.ConvertType(inst.type_id());
- return {target_ty,
- std::make_unique<ast::AsExpression>(target_ty, operand(0).expr)};
+ if (opcode == SpvOpAccessChain || opcode == SpvOpInBoundsAccessChain) {
+ return MakeAccessChain(inst);
}
- auto negated_op = NegatedFloatCompare(inst.opcode());
+ if (opcode == SpvOpBitcast) {
+ return {ast_type, std::make_unique<ast::AsExpression>(
+ ast_type, MakeOperand(inst, 0).expr)};
+ }
+
+ auto negated_op = NegatedFloatCompare(opcode);
if (negated_op != ast::BinaryOp::kNone) {
- auto arg0 = operand(0);
- auto arg1 = operand(1);
+ auto arg0 = MakeOperand(inst, 0);
+ auto arg1 = MakeOperand(inst, 1);
auto binary_expr = std::make_unique<ast::BinaryExpression>(
negated_op, std::move(arg0.expr), std::move(arg1.expr));
auto negated_expr = std::make_unique<ast::UnaryOpExpression>(
@@ -1578,8 +1581,6 @@
// OpGenericCastToPtr // Not in Vulkan
// OpGenericCastToPtrExplicit // Not in Vulkan
//
- // OpAccessChain
- // OpInBoundsAccessChain
// OpArrayLength
// OpVectorExtractDynamic
// OpVectorInsertDynamic
@@ -1589,6 +1590,130 @@
return {};
}
+TypedExpression FunctionEmitter::MakeAccessChain(
+ const spvtools::opt::Instruction& inst) {
+ if (inst.NumInOperands() < 1) {
+ // Binary parsing will fail on this anyway.
+ Fail() << "invalid access chain: has no input operands";
+ return {};
+ }
+
+ // A SPIR-V access chain is a single instruction with multiple indices
+ // walking down into composites. The Tint AST represents this as ever-deeper
+ // nested indexing expresions.
+ // Start off with an expression for the base, and then bury that inside
+ // nested indexing expressions.
+ TypedExpression current_expr(MakeOperand(inst, 0));
+
+ const auto constants = constant_mgr_->GetOperandConstants(&inst);
+ static const char* swizzles[] = {"x", "y", "z", "w"};
+
+ const auto base_id = inst.GetSingleWordInOperand(0);
+ const auto ptr_ty_id = def_use_mgr_->GetDef(base_id)->type_id();
+ const auto* ptr_type = type_mgr_->GetType(ptr_ty_id);
+ if (!ptr_type || !ptr_type->AsPointer()) {
+ Fail() << "Access chain %" << inst.result_id()
+ << " base pointer is not of pointer type";
+ return {};
+ }
+ const auto* pointee_type = ptr_type->AsPointer()->pointee_type();
+ const auto num_in_operands = inst.NumInOperands();
+ for (uint32_t index = 1; index < num_in_operands; ++index) {
+ const auto* index_const =
+ constants[index] ? constants[index]->AsIntConstant() : nullptr;
+ const int64_t index_const_val =
+ index_const ? index_const->GetSignExtendedValue() : 0;
+ std::unique_ptr<ast::Expression> next_expr;
+ switch (pointee_type->kind()) {
+ case spvtools::opt::analysis::Type::kVector:
+ if (index_const) {
+ // Try generating a MemberAccessor expression.
+ if (index_const_val < 0 ||
+ pointee_type->AsVector()->element_count() <= index_const_val) {
+ Fail() << "Access chain %" << inst.result_id() << " index %"
+ << inst.GetSingleWordInOperand(index) << " value "
+ << index_const_val
+ << " is out of bounds for vector of "
+ << pointee_type->AsVector()->element_count()
+ << " elements";
+ return {};
+ }
+ if (uint64_t(index_const_val) >=
+ sizeof(swizzles) / sizeof(swizzles[0])) {
+ Fail() << "internal error: swizzle index " << index_const_val
+ << " is too big. Max handled index is "
+ << ((sizeof(swizzles) / sizeof(swizzles[0])) - 1);
+ }
+ auto letter_index = std::make_unique<ast::IdentifierExpression>(
+ swizzles[index_const_val]);
+ next_expr = std::make_unique<ast::MemberAccessorExpression>(
+ std::move(current_expr.expr), std::move(letter_index));
+ } else {
+ // Non-constant index. Use array syntax
+ next_expr = std::make_unique<ast::ArrayAccessorExpression>(
+ std::move(current_expr.expr),
+ std::move(MakeOperand(inst, index).expr));
+ }
+ pointee_type = pointee_type->AsVector()->element_type();
+ break;
+ case spvtools::opt::analysis::Type::kMatrix:
+ // Use array syntax.
+ next_expr = std::make_unique<ast::ArrayAccessorExpression>(
+ std::move(current_expr.expr), std::move(MakeOperand(inst, index).expr));
+ pointee_type = pointee_type->AsMatrix()->element_type();
+ break;
+ case spvtools::opt::analysis::Type::kArray:
+ next_expr = std::make_unique<ast::ArrayAccessorExpression>(
+ std::move(current_expr.expr), std::move(MakeOperand(inst, index).expr));
+ pointee_type = pointee_type->AsArray()->element_type();
+ break;
+ case spvtools::opt::analysis::Type::kRuntimeArray:
+ next_expr = std::make_unique<ast::ArrayAccessorExpression>(
+ std::move(current_expr.expr), std::move(MakeOperand(inst, index).expr));
+ pointee_type = pointee_type->AsRuntimeArray()->element_type();
+ break;
+ case spvtools::opt::analysis::Type::kStruct: {
+ if (!index_const) {
+ Fail() << "Access chain %" << inst.result_id() << " index %"
+ << inst.GetSingleWordInOperand(index)
+ << " is a non-constant index into a structure %"
+ << type_mgr_->GetId(pointee_type);
+ return {};
+ }
+ if ((index_const_val < 0) ||
+ pointee_type->AsStruct()->element_types().size() <=
+ uint64_t(index_const_val)) {
+ Fail() << "Access chain %" << inst.result_id()
+ << " index value " << index_const_val
+ << " is out of bounds for structure %"
+ << type_mgr_->GetId(pointee_type) << " having "
+ << pointee_type->AsStruct()->element_types().size()
+ << " elements";
+ return {};
+ }
+ auto member_access =
+ std::make_unique<ast::IdentifierExpression>(namer_.GetMemberName(
+ type_mgr_->GetId(pointee_type), uint32_t(index_const_val)));
+
+ next_expr = std::make_unique<ast::MemberAccessorExpression>(
+ std::move(current_expr.expr), std::move(member_access));
+ pointee_type =
+ pointee_type->AsStruct()->element_types()[index_const_val];
+ break;
+ }
+ default:
+ Fail() << "Access chain with unknown pointee type %"
+ << type_mgr_->GetId(pointee_type) << " "
+ << pointee_type->str();
+ return {};
+ }
+ current_expr.reset(TypedExpression(
+ parser_impl_.ConvertType(type_mgr_->GetId(pointee_type)),
+ std::move(next_expr)));
+ }
+ return current_expr;
+}
+
} // namespace spirv
} // namespace reader
} // namespace tint
diff --git a/src/reader/spirv/function.h b/src/reader/spirv/function.h
index 86a9a6f..36490d4 100644
--- a/src/reader/spirv/function.h
+++ b/src/reader/spirv/function.h
@@ -316,6 +316,21 @@
ast::type::Type* GetVariableStoreType(
const spvtools::opt::Instruction& var_decl_inst);
+ /// Returns an expression for an instruction operand. Signedness conversion is
+ /// performed to match the result type of the SPIR-V instruction.
+ /// @param inst the SPIR-V instruction
+ /// @param operand_index the index of the operand, counting 0 as the first
+ /// input operand
+ /// @returns a new expression node
+ TypedExpression MakeOperand(const spvtools::opt::Instruction& inst,
+ uint32_t operand_index);
+
+ /// Returns an expression for a SPIR-V OpAccessChain or OpInBoundsAccessChain
+ /// instruction.
+ /// @param inst the SPIR-V instruction
+ /// @returns an expression
+ TypedExpression MakeAccessChain(const spvtools::opt::Instruction& inst);
+
/// Finds the header block for a structured construct that we can "break"
/// out from, from deeply nested control flow, if such a block exists.
/// If the construct is:
diff --git a/src/reader/spirv/function_memory_test.cc b/src/reader/spirv/function_memory_test.cc
index aafc6cb..5bdfb2c 100644
--- a/src/reader/spirv/function_memory_test.cc
+++ b/src/reader/spirv/function_memory_test.cc
@@ -26,6 +26,7 @@
namespace spirv {
namespace {
+using ::testing::Eq;
using ::testing::HasSubstr;
TEST_F(SpvParserTest, EmitStatement_StoreBoolConst) {
@@ -279,6 +280,434 @@
})"));
}
+TEST_F(SpvParserTest, EmitStatement_AccessChain_NoOperands) {
+ auto err = test::AssembleFailure(R"(
+ %void = OpTypeVoid
+ %voidfn = OpTypeFunction %void
+ %ty = OpTypeInt 32 0
+ %val = OpConstant %ty 42
+ %ptr_ty = OpTypePointer Workgroup %ty
+ %1 = OpVariable %ptr_ty Workgroup
+ %100 = OpFunction %void None %voidfn
+ %entry = OpLabel
+
+ %2 = OpAccessChain %ptr_ty ; Needs a base operand
+ OpStore %1 %val
+ OpReturn
+ )");
+ EXPECT_THAT(err,
+ Eq("11:5: Expected operand, found next instruction instead."));
+}
+
+TEST_F(SpvParserTest, EmitStatement_AccessChain_BaseIsNotPointer) {
+ auto* p = parser(test::Assemble(R"(
+ %void = OpTypeVoid
+ %voidfn = OpTypeFunction %void
+ %10 = OpTypeInt 32 0
+ %val = OpConstant %10 42
+ %ptr_ty = OpTypePointer Workgroup %10
+ %20 = OpVariable %10 Workgroup ; bad pointer type
+ %100 = OpFunction %void None %voidfn
+ %entry = OpLabel
+ %1 = OpAccessChain %ptr_ty %20
+ OpStore %1 %val
+ OpReturn
+ )"));
+ EXPECT_FALSE(p->BuildAndParseInternalModuleExceptFunctions());
+ EXPECT_THAT(p->error(), Eq("variable with ID 20 has non-pointer type 10"));
+}
+
+TEST_F(SpvParserTest, EmitStatement_AccessChain_VectorSwizzle) {
+ const std::string assembly = R"(
+ OpName %1 "myvar"
+ %void = OpTypeVoid
+ %voidfn = OpTypeFunction %void
+ %uint = OpTypeInt 32 0
+ %store_ty = OpTypeVector %uint 4
+ %uint_2 = OpConstant %uint 2
+ %uint_42 = OpConstant %uint 42
+ %elem_ty = OpTypePointer Workgroup %uint
+ %var_ty = OpTypePointer Workgroup %store_ty
+ %1 = OpVariable %var_ty Workgroup
+ %100 = OpFunction %void None %voidfn
+ %entry = OpLabel
+ %2 = OpAccessChain %elem_ty %1 %uint_2
+ OpStore %2 %uint_42
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto* p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions())
+ << assembly << p->error();
+ FunctionEmitter fe(p, *spirv_function(100));
+ EXPECT_TRUE(fe.EmitBody());
+ EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Assignment{
+ MemberAccessor{
+ Identifier{myvar}
+ Identifier{z}
+ }
+ ScalarConstructor{42}
+})"));
+}
+
+TEST_F(SpvParserTest, EmitStatement_AccessChain_VectorConstOutOfBounds) {
+ const std::string assembly = R"(
+ OpName %1 "myvar"
+ %void = OpTypeVoid
+ %voidfn = OpTypeFunction %void
+ %uint = OpTypeInt 32 0
+ %store_ty = OpTypeVector %uint 4
+ %42 = OpConstant %uint 42
+ %uint_99 = OpConstant %uint 99
+ %elem_ty = OpTypePointer Workgroup %uint
+ %var_ty = OpTypePointer Workgroup %store_ty
+ %1 = OpVariable %var_ty Workgroup
+ %100 = OpFunction %void None %voidfn
+ %entry = OpLabel
+ %2 = OpAccessChain %elem_ty %1 %42
+ OpStore %2 %uint_99
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto* p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions())
+ << assembly << p->error();
+ FunctionEmitter fe(p, *spirv_function(100));
+ EXPECT_FALSE(fe.EmitBody());
+ EXPECT_THAT(p->error(), Eq("Access chain %2 index %42 value 42 is out of "
+ "bounds for vector of 4 elements"));
+}
+
+TEST_F(SpvParserTest, EmitStatement_AccessChain_VectorNonConstIndex) {
+ const std::string assembly = R"(
+ OpName %1 "myvar"
+ %void = OpTypeVoid
+ %voidfn = OpTypeFunction %void
+ %uint = OpTypeInt 32 0
+ %store_ty = OpTypeVector %uint 4
+ %uint_2 = OpConstant %uint 2
+ %uint_42 = OpConstant %uint 42
+ %elem_ty = OpTypePointer Workgroup %uint
+ %var_ty = OpTypePointer Workgroup %store_ty
+ %1 = OpVariable %var_ty Workgroup
+ %10 = OpVariable %var_ty Workgroup
+ %100 = OpFunction %void None %voidfn
+ %entry = OpLabel
+ %11 = OpLoad %uint %10
+ %2 = OpAccessChain %elem_ty %1 %11
+ OpStore %2 %uint_42
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto* p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions())
+ << assembly << p->error();
+ FunctionEmitter fe(p, *spirv_function(100));
+ EXPECT_TRUE(fe.EmitBody());
+ EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Assignment{
+ ArrayAccessor{
+ Identifier{myvar}
+ Identifier{x_11}
+ }
+ ScalarConstructor{42}
+})"));
+}
+
+TEST_F(SpvParserTest, EmitStatement_AccessChain_Matrix) {
+ const std::string assembly = R"(
+ OpName %1 "myvar"
+ %void = OpTypeVoid
+ %voidfn = OpTypeFunction %void
+ %float = OpTypeFloat 32
+ %v4float = OpTypeVector %float 4
+ %m3v4float = OpTypeMatrix %v4float 3
+ %elem_ty = OpTypePointer Workgroup %v4float
+ %var_ty = OpTypePointer Workgroup %m3v4float
+ %uint = OpTypeInt 32 0
+ %uint_2 = OpConstant %uint 2
+ %float_42 = OpConstant %float 42
+ %v4float_42 = OpConstantComposite %v4float %float_42 %float_42 %float_42 %float_42
+
+ %1 = OpVariable %var_ty Workgroup
+ %100 = OpFunction %void None %voidfn
+ %entry = OpLabel
+ %2 = OpAccessChain %elem_ty %1 %uint_2
+ OpStore %2 %v4float_42
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto* p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions())
+ << assembly << p->error();
+ FunctionEmitter fe(p, *spirv_function(100));
+ EXPECT_TRUE(fe.EmitBody());
+ EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Assignment{
+ ArrayAccessor{
+ Identifier{myvar}
+ ScalarConstructor{2}
+ }
+ TypeConstructor{
+ __vec_4__f32
+ ScalarConstructor{42.000000}
+ ScalarConstructor{42.000000}
+ ScalarConstructor{42.000000}
+ ScalarConstructor{42.000000}
+ }
+})"));
+}
+
+TEST_F(SpvParserTest, EmitStatement_AccessChain_Array) {
+ const std::string assembly = R"(
+ OpName %1 "myvar"
+ %void = OpTypeVoid
+ %voidfn = OpTypeFunction %void
+ %float = OpTypeFloat 32
+ %v4float = OpTypeVector %float 4
+ %m3v4float = OpTypeMatrix %v4float 3
+ %elem_ty = OpTypePointer Workgroup %v4float
+ %var_ty = OpTypePointer Workgroup %m3v4float
+ %uint = OpTypeInt 32 0
+ %uint_2 = OpConstant %uint 2
+ %float_42 = OpConstant %float 42
+ %v4float_42 = OpConstantComposite %v4float %float_42 %float_42 %float_42 %float_42
+
+ %1 = OpVariable %var_ty Workgroup
+ %100 = OpFunction %void None %voidfn
+ %entry = OpLabel
+ %2 = OpAccessChain %elem_ty %1 %uint_2
+ OpStore %2 %v4float_42
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto* p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions())
+ << assembly << p->error();
+ FunctionEmitter fe(p, *spirv_function(100));
+ EXPECT_TRUE(fe.EmitBody());
+ EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Assignment{
+ ArrayAccessor{
+ Identifier{myvar}
+ ScalarConstructor{2}
+ }
+ TypeConstructor{
+ __vec_4__f32
+ ScalarConstructor{42.000000}
+ ScalarConstructor{42.000000}
+ ScalarConstructor{42.000000}
+ ScalarConstructor{42.000000}
+ }
+})"));
+}
+
+TEST_F(SpvParserTest, EmitStatement_AccessChain_Struct) {
+ const std::string assembly = R"(
+ OpName %1 "myvar"
+ OpMemberName %strct 1 "age"
+ %void = OpTypeVoid
+ %voidfn = OpTypeFunction %void
+ %float = OpTypeFloat 32
+ %float_42 = OpConstant %float 42
+ %strct = OpTypeStruct %float %float
+ %elem_ty = OpTypePointer Workgroup %float
+ %var_ty = OpTypePointer Workgroup %strct
+ %uint = OpTypeInt 32 0
+ %uint_1 = OpConstant %uint 1
+
+ %1 = OpVariable %var_ty Workgroup
+ %100 = OpFunction %void None %voidfn
+ %entry = OpLabel
+ %2 = OpAccessChain %elem_ty %1 %uint_1
+ OpStore %2 %float_42
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto* p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions())
+ << assembly << p->error();
+ FunctionEmitter fe(p, *spirv_function(100));
+ EXPECT_TRUE(fe.EmitBody());
+ EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Assignment{
+ MemberAccessor{
+ Identifier{myvar}
+ Identifier{age}
+ }
+ ScalarConstructor{42.000000}
+})"));
+}
+
+TEST_F(SpvParserTest, EmitStatement_AccessChain_StructNonConstIndex) {
+ const std::string assembly = R"(
+ OpName %1 "myvar"
+ OpMemberName %55 1 "age"
+ %void = OpTypeVoid
+ %voidfn = OpTypeFunction %void
+ %float = OpTypeFloat 32
+ %float_42 = OpConstant %float 42
+ %55 = OpTypeStruct %float %float
+ %elem_ty = OpTypePointer Workgroup %float
+ %var_ty = OpTypePointer Workgroup %55
+ %uint = OpTypeInt 32 0
+ %uint_1 = OpConstant %uint 1
+ %uint_ptr = OpTypePointer Workgroup %uint
+ %uintvar = OpVariable %uint_ptr Workgroup
+
+ %1 = OpVariable %var_ty Workgroup
+ %100 = OpFunction %void None %voidfn
+ %entry = OpLabel
+ %10 = OpLoad %uint %uintvar
+ %2 = OpAccessChain %elem_ty %1 %10
+ OpStore %2 %float_42
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto* p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions())
+ << assembly << p->error();
+ FunctionEmitter fe(p, *spirv_function(100));
+ EXPECT_FALSE(fe.EmitBody());
+ EXPECT_THAT(p->error(), Eq("Access chain %2 index %10 is a non-constant "
+ "index into a structure %55"));
+}
+
+TEST_F(SpvParserTest, EmitStatement_AccessChain_StructConstOutOfBounds) {
+ const std::string assembly = R"(
+ OpName %1 "myvar"
+ OpMemberName %55 1 "age"
+ %void = OpTypeVoid
+ %voidfn = OpTypeFunction %void
+ %float = OpTypeFloat 32
+ %float_42 = OpConstant %float 42
+ %55 = OpTypeStruct %float %float
+ %elem_ty = OpTypePointer Workgroup %float
+ %var_ty = OpTypePointer Workgroup %55
+ %uint = OpTypeInt 32 0
+ %uint_99 = OpConstant %uint 99
+
+ %1 = OpVariable %var_ty Workgroup
+ %100 = OpFunction %void None %voidfn
+ %entry = OpLabel
+ %2 = OpAccessChain %elem_ty %1 %uint_99
+ OpStore %2 %float_42
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto* p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions())
+ << assembly << p->error();
+ FunctionEmitter fe(p, *spirv_function(100));
+ EXPECT_FALSE(fe.EmitBody());
+ EXPECT_THAT(p->error(), Eq("Access chain %2 index value 99 is out of bounds "
+ "for structure %55 having 2 elements"));
+}
+
+TEST_F(SpvParserTest, EmitStatement_AccessChain_Struct_RuntimeArray) {
+ const std::string assembly = R"(
+ OpName %1 "myvar"
+ OpMemberName %strct 1 "age"
+ %void = OpTypeVoid
+ %voidfn = OpTypeFunction %void
+ %float = OpTypeFloat 32
+ %float_42 = OpConstant %float 42
+ %rtarr = OpTypeRuntimeArray %float
+ %strct = OpTypeStruct %float %rtarr
+ %elem_ty = OpTypePointer Workgroup %float
+ %var_ty = OpTypePointer Workgroup %strct
+ %uint = OpTypeInt 32 0
+ %uint_1 = OpConstant %uint 1
+ %uint_2 = OpConstant %uint 2
+
+ %1 = OpVariable %var_ty Workgroup
+ %100 = OpFunction %void None %voidfn
+ %entry = OpLabel
+ %2 = OpAccessChain %elem_ty %1 %uint_1 %uint_2
+ OpStore %2 %float_42
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto* p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions())
+ << assembly << p->error();
+ FunctionEmitter fe(p, *spirv_function(100));
+ EXPECT_TRUE(fe.EmitBody());
+ EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Assignment{
+ ArrayAccessor{
+ MemberAccessor{
+ Identifier{myvar}
+ Identifier{age}
+ }
+ ScalarConstructor{2}
+ }
+ ScalarConstructor{42.000000}
+})"));
+}
+
+TEST_F(SpvParserTest, EmitStatement_AccessChain_Compound_Matrix_Vector) {
+ const std::string assembly = R"(
+ OpName %1 "myvar"
+ %void = OpTypeVoid
+ %voidfn = OpTypeFunction %void
+ %float = OpTypeFloat 32
+ %v4float = OpTypeVector %float 4
+ %m3v4float = OpTypeMatrix %v4float 3
+ %elem_ty = OpTypePointer Workgroup %float
+ %var_ty = OpTypePointer Workgroup %m3v4float
+ %uint = OpTypeInt 32 0
+ %uint_2 = OpConstant %uint 2
+ %uint_3 = OpConstant %uint 3
+ %float_42 = OpConstant %float 42
+
+ %1 = OpVariable %var_ty Workgroup
+ %100 = OpFunction %void None %voidfn
+ %entry = OpLabel
+ %2 = OpAccessChain %elem_ty %1 %uint_2 %uint_3
+ OpStore %2 %float_42
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto* p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions())
+ << assembly << p->error();
+ FunctionEmitter fe(p, *spirv_function(100));
+ EXPECT_TRUE(fe.EmitBody());
+ EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Assignment{
+ MemberAccessor{
+ ArrayAccessor{
+ Identifier{myvar}
+ ScalarConstructor{2}
+ }
+ Identifier{w}
+ }
+ ScalarConstructor{42.000000}
+})"));
+}
+
+TEST_F(SpvParserTest, EmitStatement_AccessChain_InvalidPointeeType) {
+ const std::string assembly = R"(
+ OpName %1 "myvar"
+ %55 = OpTypeVoid
+ %voidfn = OpTypeFunction %55
+ %float = OpTypeFloat 32
+ %60 = OpTypePointer Workgroup %55
+ %var_ty = OpTypePointer Workgroup %60
+ %uint = OpTypeInt 32 0
+ %uint_2 = OpConstant %uint 2
+
+ %1 = OpVariable %var_ty Workgroup
+ %100 = OpFunction %55 None %voidfn
+ %entry = OpLabel
+ %2 = OpAccessChain %60 %1 %uint_2
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto* p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions())
+ << assembly << p->error();
+ FunctionEmitter fe(p, *spirv_function(100));
+ EXPECT_FALSE(fe.EmitBody());
+ EXPECT_THAT(p->error(),
+ HasSubstr("Access chain with unknown pointee type %60 void"));
+}
+
} // namespace
} // namespace spirv
} // namespace reader
diff --git a/src/reader/spirv/parser_impl.cc b/src/reader/spirv/parser_impl.cc
index a09401c..29a952e 100644
--- a/src/reader/spirv/parser_impl.cc
+++ b/src/reader/spirv/parser_impl.cc
@@ -183,6 +183,11 @@
TypedExpression::~TypedExpression() {}
+void TypedExpression::reset(TypedExpression&& other) {
+ type = other.type;
+ expr = std::move(other.expr);
+}
+
ParserImpl::ParserImpl(Context* ctx, const std::vector<uint32_t>& spv_binary)
: Reader(ctx),
spv_binary_(spv_binary),
@@ -786,6 +791,10 @@
"SPIR-V type with ID: "
<< var.type_id();
}
+ if (!ast_type->IsPointer()) {
+ return Fail() << "variable with ID " << var.result_id()
+ << " has non-pointer type " << var.type_id();
+ }
auto* ast_store_type = ast_type->AsPointer()->type();
auto ast_var =
MakeVariable(var.result_id(), ast_storage_class, ast_store_type);
diff --git a/src/reader/spirv/parser_impl.h b/src/reader/spirv/parser_impl.h
index 554fe27..d99b73c 100644
--- a/src/reader/spirv/parser_impl.h
+++ b/src/reader/spirv/parser_impl.h
@@ -65,6 +65,9 @@
TypedExpression(TypedExpression&& other);
/// Destructor
~TypedExpression();
+ /// Takes values from another typed expression.
+ /// @param other the other typed expression
+ void reset(TypedExpression&& other);
/// The type
ast::type::Type* type;
/// The expression
diff --git a/src/reader/spirv/parser_impl_module_var_test.cc b/src/reader/spirv/parser_impl_module_var_test.cc
index d46ab0a..ef06c26 100644
--- a/src/reader/spirv/parser_impl_module_var_test.cc
+++ b/src/reader/spirv/parser_impl_module_var_test.cc
@@ -99,6 +99,20 @@
"AST type for SPIR-V type with ID: 3"));
}
+TEST_F(SpvParserTest, ModuleScopeVar_NonPointerType) {
+ auto* p = parser(test::Assemble(R"(
+ %float = OpTypeFloat 32
+ %5 = OpTypeFunction %float
+ %3 = OpTypePointer Private %5
+ %52 = OpVariable %float Private
+ )"));
+ EXPECT_TRUE(p->BuildInternalModule());
+ EXPECT_FALSE(p->RegisterTypes());
+ EXPECT_THAT(
+ p->error(),
+ HasSubstr("SPIR-V pointer type with ID 3 has invalid pointee type 5"));
+}
+
TEST_F(SpvParserTest, ModuleScopeVar_AnonWorkgroupVar) {
auto* p = parser(test::Assemble(R"(
%float = OpTypeFloat 32
diff --git a/src/reader/spirv/spirv_tools_helpers_test.cc b/src/reader/spirv/spirv_tools_helpers_test.cc
index 2a31f7d..81dd5e7 100644
--- a/src/reader/spirv/spirv_tools_helpers_test.cc
+++ b/src/reader/spirv/spirv_tools_helpers_test.cc
@@ -49,6 +49,26 @@
return result;
}
+std::string AssembleFailure(const std::string& spirv_assembly) {
+ // TODO(dneto): Use ScopedTrace?
+
+ // (The target environment doesn't affect assembly.
+ spvtools::SpirvTools tools(SPV_ENV_UNIVERSAL_1_0);
+ std::stringstream errors;
+ std::vector<uint32_t> result;
+ tools.SetMessageConsumer([&errors](spv_message_level_t, const char*,
+ const spv_position_t& position,
+ const char* message) {
+ errors << position.line << ":" << position.column << ": " << message;
+ });
+
+ const auto success = tools.Assemble(
+ spirv_assembly, &result, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ EXPECT_FALSE(success);
+
+ return errors.str();
+}
+
} // namespace test
} // namespace spirv
} // namespace reader
diff --git a/src/reader/spirv/spirv_tools_helpers_test.h b/src/reader/spirv/spirv_tools_helpers_test.h
index d6db2a1..ba4ad68 100644
--- a/src/reader/spirv/spirv_tools_helpers_test.h
+++ b/src/reader/spirv/spirv_tools_helpers_test.h
@@ -28,6 +28,10 @@
/// are preserved.
std::vector<uint32_t> Assemble(const std::string& spirv_assembly);
+/// Attempts to assemble given SPIR-V assembly text. Expect it to fail.
+/// @returns the failure message.
+std::string AssembleFailure(const std::string& spirv_assembly);
+
} // namespace test
} // namespace spirv
} // namespace reader