Fixup single element swizzle type determination.
For a swizzle with one element (eg vec.x) the result type is just the
type of the vector, instead of a new vector.
Change-Id: I04ddb22da61db1c3553d465e4e5f9d6b32beae83
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/20062
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/type_determiner.cc b/src/type_determiner.cc
index d0dd7ff..b732370 100644
--- a/src/type_determiner.cc
+++ b/src/type_determiner.cc
@@ -400,12 +400,17 @@
if (data_type->IsVector()) {
auto* vec = data_type->AsVector();
- // The vector will have a number of components equal to the length of the
- // swizzle. This assumes the validator will check that the swizzle
- // is correct.
- expr->set_result_type(
- ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
- vec->type(), expr->member()->name().size())));
+ auto size = expr->member()->name().size();
+ if (size == 1) {
+ // A single element swizzle is just the type of the vector.
+ expr->set_result_type(vec->type());
+ } else {
+ // The vector will have a number of components equal to the length of the
+ // swizzle. This assumes the validator will check that the swizzle
+ // is correct.
+ expr->set_result_type(ctx_.type_mgr().Get(
+ std::make_unique<ast::type::VectorType>(vec->type(), size)));
+ }
return true;
}
diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc
index b1c6cd8..d6ebdbf 100644
--- a/src/type_determiner_test.cc
+++ b/src/type_determiner_test.cc
@@ -745,6 +745,26 @@
EXPECT_EQ(mem.result_type()->AsVector()->size(), 2u);
}
+TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle_SingleElement) {
+ ast::type::F32Type f32;
+ ast::type::VectorType vec3(&f32, 3);
+
+ auto var = std::make_unique<ast::Variable>("my_vec", ast::StorageClass::kNone,
+ &vec3);
+ mod()->AddGlobalVariable(std::move(var));
+
+ // Register the global
+ EXPECT_TRUE(td()->Determine());
+
+ auto ident = std::make_unique<ast::IdentifierExpression>("my_vec");
+ auto swizzle = std::make_unique<ast::IdentifierExpression>("x");
+
+ ast::MemberAccessorExpression mem(std::move(ident), std::move(swizzle));
+ EXPECT_TRUE(td()->DetermineResultType(&mem)) << td()->error();
+ ASSERT_NE(mem.result_type(), nullptr);
+ ASSERT_TRUE(mem.result_type()->IsF32());
+}
+
TEST_F(TypeDeterminerTest, Expr_MultiLevel) {
// struct b {
// vec4<f32> foo