spirv-reader: refactor swizzle creation
Change-Id: I6a09756026b7cbc436d5f232be9331255615e8c3
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/34040
Auto-Submit: David Neto <dneto@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: David Neto <dneto@google.com>
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index 5d1e5c4..15286e8 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -172,6 +172,8 @@
namespace {
+constexpr uint32_t kMaxVectorLen = 4;
+
// Gets the AST unary opcode for the given SPIR-V opcode, if any
// @param opcode SPIR-V opcode
// @param ast_unary_op return parameter
@@ -2874,6 +2876,16 @@
return {ast_type, call};
}
+ast::IdentifierExpression* FunctionEmitter::Swizzle(uint32_t i) {
+ if (i >= kMaxVectorLen) {
+ Fail() << "vector component index is larger than " << kMaxVectorLen - 1
+ << ": " << i;
+ return nullptr;
+ }
+ const char* names[] = {"x", "y", "z", "w"};
+ return ast_module_.create<ast::IdentifierExpression>(names[i & 3]);
+}
+
TypedExpression FunctionEmitter::MakeAccessChain(
const spvtools::opt::Instruction& inst) {
if (inst.NumInOperands() < 1) {
@@ -2888,7 +2900,6 @@
// for the base, and then bury that inside nested indexing expressions.
TypedExpression current_expr(MakeOperand(inst, 0));
const auto constants = constant_mgr_->GetOperandConstants(&inst);
- static const char* swizzles[] = {"x", "y", "z", "w"};
const auto base_id = inst.GetSingleWordInOperand(0);
auto ptr_ty_id = def_use_mgr_->GetDef(base_id)->type_id();
@@ -2981,16 +2992,12 @@
<< num_elems << " elements";
return {};
}
- if (uint64_t(index_const_val) >=
- sizeof(swizzles) / sizeof(swizzles[0])) {
+ if (uint64_t(index_const_val) >= kMaxVectorLen) {
Fail() << "internal error: swizzle index " << index_const_val
- << " is too big. Max handled index is "
- << ((sizeof(swizzles) / sizeof(swizzles[0])) - 1);
+ << " is too big. Max handled index is " << kMaxVectorLen - 1;
}
- auto* letter_index =
- create<ast::IdentifierExpression>(swizzles[index_const_val]);
- next_expr = create<ast::MemberAccessorExpression>(current_expr.expr,
- letter_index);
+ next_expr = create<ast::MemberAccessorExpression>(
+ current_expr.expr, Swizzle(uint32_t(index_const_val)));
} else {
// Non-constant index. Use array syntax
next_expr = create<ast::ArrayAccessorExpression>(
@@ -3072,7 +3079,6 @@
return create<ast::ScalarConstructorExpression>(
create<ast::UintLiteral>(&u32, literal));
};
- static const char* swizzles[] = {"x", "y", "z", "w"};
const auto composite = inst.GetSingleWordInOperand(0);
auto current_type_id = def_use_mgr_->GetDef(composite)->type_id();
@@ -3102,15 +3108,12 @@
<< " elements";
return {};
}
- if (index_val >= sizeof(swizzles) / sizeof(swizzles[0])) {
+ if (index_val >= kMaxVectorLen) {
Fail() << "internal error: swizzle index " << index_val
- << " is too big. Max handled index is "
- << ((sizeof(swizzles) / sizeof(swizzles[0])) - 1);
+ << " is too big. Max handled index is " << kMaxVectorLen - 1;
}
- auto* letter_index =
- create<ast::IdentifierExpression>(swizzles[index_val]);
next_expr = create<ast::MemberAccessorExpression>(current_expr.expr,
- letter_index);
+ Swizzle(index_val));
// All vector components are the same type.
current_type_id = current_type_inst->GetSingleWordInOperand(0);
break;
@@ -3124,10 +3127,9 @@
<< " elements";
return {};
}
- if (index_val >= sizeof(swizzles) / sizeof(swizzles[0])) {
+ if (index_val >= kMaxVectorLen) {
Fail() << "internal error: swizzle index " << index_val
- << " is too big. Max handled index is "
- << ((sizeof(swizzles) / sizeof(swizzles[0])) - 1);
+ << " is too big. Max handled index is " << kMaxVectorLen - 1;
}
// Use array syntax.
next_expr = create<ast::ArrayAccessorExpression>(current_expr.expr,
@@ -3197,7 +3199,6 @@
type_mgr_->GetType(vec1.type_id())->AsVector()->element_count();
// Idiomatic vector accessors.
- const char* swizzles[] = {"x", "y", "z", "w"};
// Generate an ast::TypeConstructor expression.
// Assume the literal indices are valid, and there is a valid number of them.
@@ -3207,16 +3208,13 @@
for (uint32_t i = 2; i < inst.NumInOperands(); ++i) {
const auto index = inst.GetSingleWordInOperand(i);
if (index < vec0_len) {
- assert(index < sizeof(swizzles) / sizeof(swizzles[0]));
values.emplace_back(create<ast::MemberAccessorExpression>(
- MakeExpression(vec0_id).expr,
- create<ast::IdentifierExpression>(swizzles[index])));
+ MakeExpression(vec0_id).expr, Swizzle(index)));
} else if (index < vec0_len + vec1_len) {
const auto sub_index = index - vec0_len;
- assert(sub_index < sizeof(swizzles) / sizeof(swizzles[0]));
+ assert(sub_index < kMaxVectorLen);
values.emplace_back(create<ast::MemberAccessorExpression>(
- MakeExpression(vec1_id).expr,
- create<ast::IdentifierExpression>(swizzles[sub_index])));
+ MakeExpression(vec1_id).expr, Swizzle(sub_index)));
} else if (index == 0xFFFFFFFF) {
// By rule, this maps to OpUndef. Instead, make it zero.
values.emplace_back(parser_impl_.MakeNullValue(result_type->type()));
diff --git a/src/reader/spirv/function.h b/src/reader/spirv/function.h
index cfebe80..f5c3872 100644
--- a/src/reader/spirv/function.h
+++ b/src/reader/spirv/function.h
@@ -672,6 +672,13 @@
/// @returns the associated loop construct, or nullptr
const Construct* SiblingLoopConstruct(const Construct* c) const;
+ /// Returns an identifier expression for the swizzle name of the given
+ /// index into a vector. Emits an error and returns nullptr if the
+ /// index is out of range, i.e. 4 or higher.
+ /// @param i index of the subcomponent
+ /// @returns the identifier expression for the @p i'th component
+ ast::IdentifierExpression* Swizzle(uint32_t i);
+
private:
/// @returns the store type for the OpVariable instruction, or
/// null on failure.
diff --git a/src/reader/spirv/function_misc_test.cc b/src/reader/spirv/function_misc_test.cc
index f5d1bd9..e215fdd 100644
--- a/src/reader/spirv/function_misc_test.cc
+++ b/src/reader/spirv/function_misc_test.cc
@@ -16,6 +16,7 @@
#include <vector>
#include "gmock/gmock.h"
+#include "src/ast/identifier_expression.h"
#include "src/reader/spirv/function.h"
#include "src/reader/spirv/parser_impl.h"
#include "src/reader/spirv/parser_impl_test_helper.h"
@@ -295,6 +296,53 @@
)")) << ToString(fe.ast_body());
}
+// Test swizzle generation.
+
+struct SwizzleCase {
+ uint32_t index;
+ std::string expected_expr;
+ std::string expected_error;
+};
+using SpvParserSwizzleTest =
+ SpvParserTestBase<::testing::TestWithParam<SwizzleCase>>;
+
+TEST_P(SpvParserSwizzleTest, Sample) {
+ // We need a function so we can get a FunctionEmitter.
+ const auto assembly = CommonTypes() + R"(
+ %100 = OpFunction %void None %voidfn
+ %entry = OpLabel
+ OpReturn
+ OpFunctionEnd
+)";
+ auto p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+ FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+
+ auto* result = fe.Swizzle(GetParam().index);
+ if (GetParam().expected_error.empty()) {
+ EXPECT_TRUE(fe.success());
+ ASSERT_NE(result, nullptr);
+ std::ostringstream ss;
+ result->to_str(ss, 0);
+ EXPECT_THAT(ss.str(), Eq(GetParam().expected_expr));
+ } else {
+ EXPECT_EQ(result, nullptr);
+ EXPECT_FALSE(fe.success());
+ EXPECT_THAT(p->error(), Eq(GetParam().expected_error));
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ ValidIndex,
+ SpvParserSwizzleTest,
+ ::testing::ValuesIn(std::vector<SwizzleCase>{
+ {0, "Identifier[not set]{x}\n", ""},
+ {1, "Identifier[not set]{y}\n", ""},
+ {2, "Identifier[not set]{z}\n", ""},
+ {3, "Identifier[not set]{w}\n", ""},
+ {4, "", "vector component index is larger than 3: 4"},
+ {99999, "", "vector component index is larger than 3: 99999"}}));
+
// TODO(dneto): OpSizeof : requires Kernel (OpenCL)
} // namespace