writer/spirv: Simplify member accesses
Using semantic info.
Change-Id: Iec9a592d9d66930535ead78fab69a6085a57a941
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/50302
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index a3162c2..78f91fa 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -2535,7 +2535,8 @@
offset = utils::RoundUp(align, offset);
auto* sem_member = builder_->create<sem::StructMember>(
- member, const_cast<sem::Type*>(type), offset, align, size);
+ member, const_cast<sem::Type*>(type),
+ static_cast<uint32_t>(sem_members.size()), offset, align, size);
builder_->Sem().Add(member, sem_member);
sem_members.emplace_back(sem_member);
diff --git a/src/resolver/resolver_test.cc b/src/resolver/resolver_test.cc
index 32a900a..88cfbb6 100644
--- a/src/resolver/resolver_test.cc
+++ b/src/resolver/resolver_test.cc
@@ -901,6 +901,7 @@
auto* sma = Sem().Get(mem)->As<sem::StructMemberAccess>();
ASSERT_NE(sma, nullptr);
EXPECT_EQ(sma->Member()->Type(), ty.f32());
+ EXPECT_EQ(sma->Member()->Index(), 1u);
EXPECT_EQ(sma->Member()->Declaration()->symbol(),
Symbols().Get("second_member"));
}
@@ -925,6 +926,7 @@
auto* sma = Sem().Get(mem)->As<sem::StructMemberAccess>();
ASSERT_NE(sma, nullptr);
EXPECT_EQ(sma->Member()->Type(), ty.f32());
+ EXPECT_EQ(sma->Member()->Index(), 1u);
}
TEST_F(ResolverTest, Expr_MemberAccessor_VectorSwizzle) {
diff --git a/src/sem/struct.cc b/src/sem/struct.cc
index 40f3e06..f8eefc1 100644
--- a/src/sem/struct.cc
+++ b/src/sem/struct.cc
@@ -56,11 +56,13 @@
StructMember::StructMember(ast::StructMember* declaration,
sem::Type* type,
+ uint32_t index,
uint32_t offset,
uint32_t align,
uint32_t size)
: declaration_(declaration),
type_(type),
+ index_(index),
offset_(offset),
align_(align),
size_(size) {}
diff --git a/src/sem/struct.h b/src/sem/struct.h
index fda7b4f..6694d6d 100644
--- a/src/sem/struct.h
+++ b/src/sem/struct.h
@@ -166,11 +166,13 @@
/// Constructor
/// @param declaration the AST declaration node
/// @param type the type of the member
+ /// @param index the index of the member in the structure
/// @param offset the byte offset from the base of the structure
/// @param align the byte alignment of the member
/// @param size the byte size of the member
StructMember(ast::StructMember* declaration,
sem::Type* type,
+ uint32_t index,
uint32_t offset,
uint32_t align,
uint32_t size);
@@ -184,6 +186,9 @@
/// @returns the type of the member
sem::Type* Type() const { return type_; }
+ /// @returns the member index
+ uint32_t Index() const { return index_; }
+
/// @returns byte offset from base of structure
uint32_t Offset() const { return offset_; }
@@ -196,9 +201,10 @@
private:
ast::StructMember* const declaration_;
sem::Type* const type_;
- uint32_t const offset_; // Byte offset from base of structure
- uint32_t const align_; // Byte alignment of the member
- uint32_t const size_; // Byte size of the member
+ uint32_t const index_;
+ uint32_t const offset_;
+ uint32_t const align_;
+ uint32_t const size_;
};
} // namespace sem
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index dee0885..935387d 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -26,6 +26,7 @@
#include "src/sem/depth_texture_type.h"
#include "src/sem/function.h"
#include "src/sem/intrinsic.h"
+#include "src/sem/member_accessor_expression.h"
#include "src/sem/multisampled_texture_type.h"
#include "src/sem/sampled_texture_type.h"
#include "src/sem/struct.h"
@@ -89,24 +90,6 @@
last->Is<ast::FallthroughStatement>();
}
-uint32_t IndexFromName(char name) {
- switch (name) {
- case 'x':
- case 'r':
- return 0;
- case 'y':
- case 'g':
- return 1;
- case 'z':
- case 'b':
- return 2;
- case 'w':
- case 'a':
- return 3;
- }
- return std::numeric_limits<uint32_t>::max();
-}
-
/// Returns the matrix type that is `type` or that is wrapped by
/// one or more levels of an arrays inside of `type`.
/// @param type the given type, which must not be null
@@ -880,23 +863,11 @@
bool Builder::GenerateMemberAccessor(ast::MemberAccessorExpression* expr,
AccessorInfo* info) {
- auto* data_type =
- TypeOf(expr->structure())->UnwrapPtrIfNeeded()->UnwrapIfNeeded();
- auto* expr_type = TypeOf(expr);
+ auto* expr_sem = builder_.Sem().Get(expr);
+ auto* expr_type = expr_sem->Type();
- // If the data_type is a structure we're accessing a member, if it's a
- // vector we're accessing a swizzle.
- if (auto* str = data_type->As<sem::Struct>()) {
- auto* impl = str->Declaration();
- auto symbol = expr->member()->symbol();
-
- uint32_t idx = 0;
- for (; idx < impl->members().size(); ++idx) {
- auto* member = impl->members()[idx];
- if (member->symbol() == symbol) {
- break;
- }
- }
+ if (auto* access = expr_sem->As<sem::StructMemberAccess>()) {
+ uint32_t idx = access->Member()->Index();
if (info->source_type->Is<sem::Pointer>()) {
auto idx_id = GenerateConstantIfNeeded(ScalarConstant::U32(idx));
@@ -927,106 +898,93 @@
return true;
}
- if (!data_type->Is<sem::Vector>()) {
- error_ = "Member accessor without a struct or vector. Something is wrong";
- return false;
- }
+ if (auto* swizzle = expr_sem->As<sem::Swizzle>()) {
+ // Single element swizzle is either an access chain or a composite extract
+ auto& indices = swizzle->Indices();
+ if (indices.size() == 1) {
+ if (info->source_type->Is<sem::Pointer>()) {
+ auto idx_id = GenerateConstantIfNeeded(ScalarConstant::U32(indices[0]));
+ if (idx_id == 0) {
+ return 0;
+ }
+ info->access_chain_indices.push_back(idx_id);
+ } else {
+ auto result_type_id = GenerateTypeIfNeeded(expr_type);
+ if (result_type_id == 0) {
+ return 0;
+ }
- // TODO(dsinclair): Swizzle stuff
- auto swiz = builder_.Symbols().NameFor(expr->member()->symbol());
- // Single element swizzle is either an access chain or a composite extract
- if (swiz.size() == 1) {
- auto val = IndexFromName(swiz[0]);
- if (val == std::numeric_limits<uint32_t>::max()) {
- error_ = "invalid swizzle name: " + swiz;
- return false;
+ auto extract = result_op();
+ auto extract_id = extract.to_i();
+ if (!push_function_inst(
+ spv::Op::OpCompositeExtract,
+ {Operand::Int(result_type_id), extract,
+ Operand::Int(info->source_id), Operand::Int(indices[0])})) {
+ return false;
+ }
+
+ info->source_id = extract_id;
+ info->source_type = expr_type;
+ }
+ return true;
}
- if (info->source_type->Is<sem::Pointer>()) {
- auto idx_id = GenerateConstantIfNeeded(ScalarConstant::U32(val));
- if (idx_id == 0) {
- return 0;
- }
- info->access_chain_indices.push_back(idx_id);
- } else {
- auto result_type_id = GenerateTypeIfNeeded(expr_type);
+ // Store the type away as it may change if we run the access chain
+ auto* incoming_type = info->source_type;
+
+ // Multi-item extract is a VectorShuffle. We have to emit any existing
+ // access chain data, then load the access chain and shuffle that.
+ if (!info->access_chain_indices.empty()) {
+ auto result_type_id = GenerateTypeIfNeeded(info->source_type);
if (result_type_id == 0) {
return 0;
}
-
auto extract = result_op();
auto extract_id = extract.to_i();
- if (!push_function_inst(
- spv::Op::OpCompositeExtract,
- {Operand::Int(result_type_id), extract,
- Operand::Int(info->source_id), Operand::Int(val)})) {
+
+ OperandList ops = {Operand::Int(result_type_id), extract,
+ Operand::Int(info->source_id)};
+ for (auto id : info->access_chain_indices) {
+ ops.push_back(Operand::Int(id));
+ }
+
+ if (!push_function_inst(spv::Op::OpAccessChain, ops)) {
return false;
}
- info->source_id = extract_id;
- info->source_type = expr_type;
+ info->source_id = GenerateLoadIfNeeded(expr_type, extract_id);
+ info->source_type = expr_type->UnwrapPtrIfNeeded();
+ info->access_chain_indices.clear();
}
+
+ auto result_type_id = GenerateTypeIfNeeded(expr_type);
+ if (result_type_id == 0) {
+ return false;
+ }
+
+ auto vec_id = GenerateLoadIfNeeded(incoming_type, info->source_id);
+
+ auto result = result_op();
+ auto result_id = result.to_i();
+
+ OperandList ops = {Operand::Int(result_type_id), result,
+ Operand::Int(vec_id), Operand::Int(vec_id)};
+
+ for (auto idx : indices) {
+ ops.push_back(Operand::Int(idx));
+ }
+
+ if (!push_function_inst(spv::Op::OpVectorShuffle, ops)) {
+ return false;
+ }
+ info->source_id = result_id;
+ info->source_type = expr_type;
return true;
}
- // Store the type away as it may change if we run the access chain
- auto* incoming_type = info->source_type;
-
- // Multi-item extract is a VectorShuffle. We have to emit any existing access
- // chain data, then load the access chain and shuffle that.
- if (!info->access_chain_indices.empty()) {
- auto result_type_id = GenerateTypeIfNeeded(info->source_type);
- if (result_type_id == 0) {
- return 0;
- }
- auto extract = result_op();
- auto extract_id = extract.to_i();
-
- OperandList ops = {Operand::Int(result_type_id), extract,
- Operand::Int(info->source_id)};
- for (auto id : info->access_chain_indices) {
- ops.push_back(Operand::Int(id));
- }
-
- if (!push_function_inst(spv::Op::OpAccessChain, ops)) {
- return false;
- }
-
- info->source_id = GenerateLoadIfNeeded(expr_type, extract_id);
- info->source_type = expr_type->UnwrapPtrIfNeeded();
- info->access_chain_indices.clear();
- }
-
- auto result_type_id = GenerateTypeIfNeeded(expr_type);
- if (result_type_id == 0) {
- return false;
- }
-
- auto vec_id = GenerateLoadIfNeeded(incoming_type, info->source_id);
-
- auto result = result_op();
- auto result_id = result.to_i();
-
- OperandList ops = {Operand::Int(result_type_id), result, Operand::Int(vec_id),
- Operand::Int(vec_id)};
-
- for (uint32_t i = 0; i < swiz.size(); ++i) {
- auto val = IndexFromName(swiz[i]);
- if (val == std::numeric_limits<uint32_t>::max()) {
- error_ = "invalid swizzle name: " + swiz;
- return false;
- }
-
- ops.push_back(Operand::Int(val));
- }
-
- if (!push_function_inst(spv::Op::OpVectorShuffle, ops)) {
- return false;
- }
- info->source_id = result_id;
- info->source_type = expr_type;
-
- return true;
+ TINT_ICE(builder_.Diagnostics())
+ << "unhandled member index type: " << expr_sem->TypeInfo().name;
+ return false;
}
uint32_t Builder::GenerateAccessorExpression(ast::Expression* expr) {