TypeDeterminer: Resolve swizzles
Have the TD resolve swizzles down to indices, erroring out if they're not valid.
Resolving these at TD time removes swizzle parsing in the HLSL writer, and is generally useful information.
If we don't sanitize in the TD, we can end up trying to construct a resulting vector of an invalid size (> 4) triggering an assert in the type::Vector constructor.
Fixed: chromium:1180634
Bug: tint:79
Change-Id: If1282c933d65eb02d26a8dc7e190f27801ef9dc5
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/42221
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/program_builder.h b/src/program_builder.h
index 24a0186..d4b8ff6 100644
--- a/src/program_builder.h
+++ b/src/program_builder.h
@@ -486,7 +486,10 @@
/// @param expr the expression
/// @return expr
- ast::Expression* Expr(ast::Expression* expr) { return expr; }
+ template <typename T>
+ traits::EnableIfIsType<T, ast::Expression>* Expr(T* expr) {
+ return expr;
+ }
/// @param name the identifier name
/// @return an ast::IdentifierExpression with the given name
@@ -948,7 +951,7 @@
/// @param idx the index argument for the array accessor expression
/// @returns a `ast::MemberAccessorExpression` that indexes `obj` with `idx`
template <typename OBJ, typename IDX>
- ast::Expression* MemberAccessor(OBJ&& obj, IDX&& idx) {
+ ast::MemberAccessorExpression* MemberAccessor(OBJ&& obj, IDX&& idx) {
return create<ast::MemberAccessorExpression>(Expr(std::forward<OBJ>(obj)),
Expr(std::forward<IDX>(idx)));
}
diff --git a/src/semantic/member_accessor_expression.h b/src/semantic/member_accessor_expression.h
index 4a7d7f6..71268b4 100644
--- a/src/semantic/member_accessor_expression.h
+++ b/src/semantic/member_accessor_expression.h
@@ -15,6 +15,8 @@
#ifndef SRC_SEMANTIC_MEMBER_ACCESSOR_EXPRESSION_H_
#define SRC_SEMANTIC_MEMBER_ACCESSOR_EXPRESSION_H_
+#include <vector>
+
#include "src/semantic/expression.h"
namespace tint {
@@ -29,17 +31,24 @@
/// @param declaration the AST node
/// @param type the resolved type of the expression
/// @param statement the statement that owns this expression
- /// @param is_swizzle true if this member access is for a vector swizzle
+ /// @param swizzle if this member access is for a vector swizzle, the swizzle
+ /// indices
MemberAccessorExpression(ast::Expression* declaration,
type::Type* type,
Statement* statement,
- bool is_swizzle);
+ std::vector<uint32_t> swizzle);
+
+ /// Destructor
+ ~MemberAccessorExpression() override;
/// @return true if this member access is for a vector swizzle
- bool IsSwizzle() const { return is_swizzle_; }
+ bool IsSwizzle() const { return !swizzle_.empty(); }
+
+ /// @return the swizzle indices, if this is a vector swizzle
+ const std::vector<uint32_t>& Swizzle() const { return swizzle_; }
private:
- bool const is_swizzle_;
+ std::vector<uint32_t> const swizzle_;
};
} // namespace semantic
diff --git a/src/semantic/sem_member_accessor_expression.cc b/src/semantic/sem_member_accessor_expression.cc
index 470f059..fdbaeb3 100644
--- a/src/semantic/sem_member_accessor_expression.cc
+++ b/src/semantic/sem_member_accessor_expression.cc
@@ -19,11 +19,14 @@
namespace tint {
namespace semantic {
-MemberAccessorExpression::MemberAccessorExpression(ast::Expression* declaration,
- type::Type* type,
- Statement* statement,
- bool is_swizzle)
- : Base(declaration, type, statement), is_swizzle_(is_swizzle) {}
+MemberAccessorExpression::MemberAccessorExpression(
+ ast::Expression* declaration,
+ type::Type* type,
+ Statement* statement,
+ std::vector<uint32_t> swizzle)
+ : Base(declaration, type, statement), swizzle_(std::move(swizzle)) {}
+
+MemberAccessorExpression::~MemberAccessorExpression() = default;
} // namespace semantic
} // namespace tint
diff --git a/src/source.h b/src/source.h
index 9b6e74c..381d737 100644
--- a/src/source.h
+++ b/src/source.h
@@ -84,6 +84,13 @@
/// @param e the range end location
inline Range(const Location& b, const Location& e) : begin(b), end(e) {}
+ /// Return a column-shifted Range
+ /// @param n the number of characters to shift by
+ /// @returns a Range with a #begin and #end column shifted by `n`
+ inline Range operator+(size_t n) const {
+ return Range{{begin.line, begin.column + n}, {end.line, end.column + n}};
+ }
+
/// The location of the first character in the range.
Location begin;
/// The location of one-past the last character in the range.
@@ -127,6 +134,13 @@
return Source(Range{range.end}, file_path, file_content);
}
+ /// Return a column-shifted Source
+ /// @param n the number of characters to shift by
+ /// @returns a Source with the range's columns shifted by `n`
+ inline Source operator+(size_t n) const {
+ return Source(range + n, file_path, file_content);
+ }
+
/// range is the span of text this source refers to in #file_path
Range range;
/// file is the optional file path this source refers to
diff --git a/src/type_determiner.cc b/src/type_determiner.cc
index fd334c2..066cf0a 100644
--- a/src/type_determiner.cc
+++ b/src/type_determiner.cc
@@ -752,7 +752,7 @@
auto* data_type = res->UnwrapPtrIfNeeded()->UnwrapIfNeeded();
type::Type* ret = nullptr;
- bool is_swizzle = false;
+ std::vector<uint32_t> swizzle;
if (auto* ty = data_type->As<type::Struct>()) {
auto* strct = ty->impl();
@@ -777,9 +777,42 @@
ret = builder_->create<type::Pointer>(ret, ptr->storage_class());
}
} else if (auto* vec = data_type->As<type::Vector>()) {
- is_swizzle = true;
+ std::string str = builder_->Symbols().NameFor(expr->member()->symbol());
+ auto size = str.size();
+ swizzle.reserve(str.size());
- auto size = builder_->Symbols().NameFor(expr->member()->symbol()).size();
+ for (auto c : str) {
+ switch (c) {
+ case 'x':
+ case 'r':
+ swizzle.emplace_back(0);
+ break;
+ case 'y':
+ case 'g':
+ swizzle.emplace_back(1);
+ break;
+ case 'z':
+ case 'b':
+ swizzle.emplace_back(2);
+ break;
+ case 'w':
+ case 'a':
+ swizzle.emplace_back(3);
+ break;
+ default:
+ diagnostics_.add_error(
+ "invalid vector swizzle character",
+ expr->member()->source().Begin() + swizzle.size());
+ return false;
+ }
+ }
+
+ if (size < 1 || size > 4) {
+ diagnostics_.add_error("invalid vector swizzle size",
+ expr->member()->source());
+ return false;
+ }
+
if (size == 1) {
// A single element swizzle is just the type of the vector.
ret = vec->type();
@@ -788,15 +821,15 @@
ret = builder_->create<type::Pointer>(ret, ptr->storage_class());
}
} else {
- // The vector will have a number of components equal to the length of the
- // swizzle. This assumes the validator will check that the swizzle
+ // The vector will have a number of components equal to the length of
+ // the swizzle. This assumes the validator will check that the swizzle
// is correct.
ret = builder_->create<type::Vector>(vec->type(),
static_cast<uint32_t>(size));
}
} else {
diagnostics_.add_error(
- "v-0007: invalid use of member accessor on a non-vector/non-struct " +
+ "invalid use of member accessor on a non-vector/non-struct " +
data_type->type_name(),
expr->source());
return false;
@@ -804,7 +837,7 @@
builder_->Sem().Add(expr,
builder_->create<semantic::MemberAccessorExpression>(
- expr, ret, current_statement_, is_swizzle));
+ expr, ret, current_statement_, std::move(swizzle)));
SetType(expr, ret);
return true;
diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc
index 801260d..6f39628 100644
--- a/src/type_determiner_test.cc
+++ b/src/type_determiner_test.cc
@@ -55,6 +55,7 @@
#include "src/semantic/call.h"
#include "src/semantic/expression.h"
#include "src/semantic/function.h"
+#include "src/semantic/member_accessor_expression.h"
#include "src/semantic/statement.h"
#include "src/semantic/variable.h"
#include "src/type/access_control_type.h"
@@ -75,6 +76,7 @@
#include "src/type/u32_type.h"
#include "src/type/vector_type.h"
+using ::testing::ElementsAre;
using ::testing::HasSubstr;
namespace tint {
@@ -1005,7 +1007,7 @@
TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle) {
Global("my_vec", ty.vec3<f32>(), ast::StorageClass::kNone);
- auto* mem = MemberAccessor("my_vec", "xy");
+ auto* mem = MemberAccessor("my_vec", "xzyw");
WrapInFunction(mem);
EXPECT_TRUE(td()->Determine()) << td()->error();
@@ -1013,13 +1015,14 @@
ASSERT_NE(TypeOf(mem), nullptr);
ASSERT_TRUE(TypeOf(mem)->Is<type::Vector>());
EXPECT_TRUE(TypeOf(mem)->As<type::Vector>()->type()->Is<type::F32>());
- EXPECT_EQ(TypeOf(mem)->As<type::Vector>()->size(), 2u);
+ EXPECT_EQ(TypeOf(mem)->As<type::Vector>()->size(), 4u);
+ EXPECT_THAT(Sem().Get(mem)->Swizzle(), ElementsAre(0, 2, 1, 3));
}
TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle_SingleElement) {
Global("my_vec", ty.vec3<f32>(), ast::StorageClass::kNone);
- auto* mem = MemberAccessor("my_vec", "x");
+ auto* mem = MemberAccessor("my_vec", "b");
WrapInFunction(mem);
EXPECT_TRUE(td()->Determine()) << td()->error();
@@ -1029,6 +1032,34 @@
auto* ptr = TypeOf(mem)->As<type::Pointer>();
ASSERT_TRUE(ptr->type()->Is<type::F32>());
+ EXPECT_THAT(Sem().Get(mem)->Swizzle(), ElementsAre(2));
+}
+
+TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle_BadChar) {
+ Global("my_vec", ty.vec3<f32>(), ast::StorageClass::kNone);
+
+ auto* ident = create<ast::IdentifierExpression>(
+ Source{{Source::Location{3, 3}, Source::Location{3, 7}}},
+ Symbols().Register("xyqz"));
+
+ auto* mem = MemberAccessor("my_vec", ident);
+ WrapInFunction(mem);
+
+ EXPECT_FALSE(td()->Determine());
+ EXPECT_EQ(td()->error(), "3:5 error: invalid vector swizzle character");
+}
+
+TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle_BadLength) {
+ Global("my_vec", ty.vec3<f32>(), ast::StorageClass::kNone);
+
+ auto* ident = create<ast::IdentifierExpression>(
+ Source{{Source::Location{3, 3}, Source::Location{3, 8}}},
+ Symbols().Register("zzzzz"));
+ auto* mem = MemberAccessor("my_vec", ident);
+ WrapInFunction(mem);
+
+ EXPECT_FALSE(td()->Determine());
+ EXPECT_EQ(td()->error(), "3:3 error: invalid vector swizzle size");
}
TEST_F(TypeDeterminerTest, Expr_Accessor_MultiLevel) {
diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc
index 0a732f3..7fb193d 100644
--- a/src/writer/hlsl/generator_impl.cc
+++ b/src/writer/hlsl/generator_impl.cc
@@ -91,22 +91,6 @@
stmts->last()->Is<ast::FallthroughStatement>();
}
-uint32_t convert_swizzle_to_index(const std::string& swizzle) {
- if (swizzle == "r" || swizzle == "x") {
- return 0;
- }
- if (swizzle == "g" || swizzle == "y") {
- return 1;
- }
- if (swizzle == "b" || swizzle == "z") {
- return 2;
- }
- if (swizzle == "a" || swizzle == "w") {
- return 3;
- }
- return 0;
-}
-
const char* image_format_to_rwtexture_type(type::ImageFormat image_format) {
switch (image_format) {
case type::ImageFormat::kRgba8Unorm:
@@ -2084,11 +2068,13 @@
out << str_member->offset();
} else if (res_type->Is<type::Vector>()) {
+ auto swizzle = builder_.Sem().Get(mem)->Swizzle();
+
// TODO(dsinclair): Swizzle stuff
//
// This must be a single element swizzle if we've got a vector at this
// point.
- if (builder_.Symbols().NameFor(mem->member()->symbol()).size() != 1) {
+ if (swizzle.size() != 1) {
diagnostics_.add_error(
"Encountered multi-element swizzle when should have only one "
"level");
@@ -2098,10 +2084,7 @@
// TODO(dsinclair): All our types are currently 4 bytes (f32, i32, u32)
// so this is assuming 4. This will need to be fixed when we get f16 or
// f64 types.
- out << "(4 * "
- << convert_swizzle_to_index(
- builder_.Symbols().NameFor(mem->member()->symbol()))
- << ")";
+ out << "(4 * " << swizzle[0] << ")";
} else {
diagnostics_.add_error("Invalid result type for member accessor: " +
res_type->type_name());