Record when an identifier expression is a swizzle.
This CL updates the type determiner to record if an identifier
expression is a swizzle and then uses that in the MSL and HLSL generator
to output the swizzle name directly.
Change-Id: I77c0e1e80dce9e2f09cbbd37476a146b06555ee2
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/38960
Auto-Submit: dan sinclair <dsinclair@chromium.org>
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/ast/identifier_expression.h b/src/ast/identifier_expression.h
index b6b9b62..d1a2b77 100644
--- a/src/ast/identifier_expression.h
+++ b/src/ast/identifier_expression.h
@@ -58,6 +58,12 @@
/// @returns true if this identifier is for an intrinsic
bool IsIntrinsic() const { return intrinsic_ != Intrinsic::kNone; }
+ /// Sets the identifier as a swizzle
+ void SetIsSwizzle() { is_swizzle_ = true; }
+
+ /// @returns true if this is a swizzle identifier
+ bool IsSwizzle() const { return is_swizzle_; }
+
/// Clones this node and all transitive child nodes using the `CloneContext`
/// `ctx`.
/// @note Semantic information such as resolved expression type and intrinsic
@@ -81,6 +87,7 @@
Intrinsic intrinsic_ = Intrinsic::kNone; // Semantic info
std::unique_ptr<intrinsic::Signature> intrinsic_sig_; // Semantic info
+ bool is_swizzle_ = false; // Semantic info
};
} // namespace ast
diff --git a/src/type_determiner.cc b/src/type_determiner.cc
index f0f8c61..518f5c9 100644
--- a/src/type_determiner.cc
+++ b/src/type_determiner.cc
@@ -1118,7 +1118,7 @@
ret = builder_->create<type::Pointer>(ret, ptr->storage_class());
}
} else if (auto* vec = data_type->As<type::Vector>()) {
- // TODO(dsinclair): Swizzle, record into the identifier experesion
+ expr->member()->SetIsSwizzle();
auto size = builder_->Symbols().NameFor(expr->member()->symbol()).size();
if (size == 1) {
diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc
index bfc6cd9..ac24076 100644
--- a/src/writer/hlsl/generator_impl.cc
+++ b/src/writer/hlsl/generator_impl.cc
@@ -1171,7 +1171,13 @@
out << name << ".";
}
}
- out << namer_.NameFor(program_->Symbols().NameFor(ident->symbol()));
+
+ // Swizzles output the name directly
+ if (ident->IsSwizzle()) {
+ out << program_->Symbols().NameFor(ident->symbol());
+ } else {
+ out << namer_.NameFor(program_->Symbols().NameFor(ident->symbol()));
+ }
return true;
}
diff --git a/src/writer/hlsl/generator_impl_member_accessor_test.cc b/src/writer/hlsl/generator_impl_member_accessor_test.cc
index adeaee0..fc885da 100644
--- a/src/writer/hlsl/generator_impl_member_accessor_test.cc
+++ b/src/writer/hlsl/generator_impl_member_accessor_test.cc
@@ -948,6 +948,34 @@
)");
}
+TEST_F(HlslGeneratorImplTest_MemberAccessor,
+ EmitExpression_MemberAccessor_Swizzle_xyz) {
+ auto* vec = Var("my_vec", ast::StorageClass::kPrivate, ty.vec4<f32>());
+ td.RegisterVariableForTesting(vec);
+
+ auto* expr = MemberAccessor("my_vec", "xyz");
+
+ ASSERT_TRUE(td.DetermineResultType(expr)) << td.error();
+
+ GeneratorImpl& gen = Build();
+ ASSERT_TRUE(gen.EmitExpression(pre, out, expr)) << gen.error();
+ EXPECT_EQ(result(), "my_vec.xyz");
+}
+
+TEST_F(HlslGeneratorImplTest_MemberAccessor,
+ EmitExpression_MemberAccessor_Swizzle_gbr) {
+ auto* vec = Var("my_vec", ast::StorageClass::kPrivate, ty.vec4<f32>());
+ td.RegisterVariableForTesting(vec);
+
+ auto* expr = MemberAccessor("my_vec", "gbr");
+
+ ASSERT_TRUE(td.DetermineResultType(expr)) << td.error();
+
+ GeneratorImpl& gen = Build();
+ ASSERT_TRUE(gen.EmitExpression(pre, out, expr)) << gen.error();
+ EXPECT_EQ(result(), "my_vec.gbr");
+}
+
} // namespace
} // namespace hlsl
} // namespace writer
diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc
index 67a8843..34a1390 100644
--- a/src/writer/msl/generator_impl.cc
+++ b/src/writer/msl/generator_impl.cc
@@ -1592,7 +1592,13 @@
out_ << name << ".";
}
}
- out_ << namer_.NameFor(program_->Symbols().NameFor(ident->symbol()));
+
+ // Swizzles get written out directly
+ if (ident->IsSwizzle()) {
+ out_ << program_->Symbols().NameFor(ident->symbol());
+ } else {
+ out_ << namer_.NameFor(program_->Symbols().NameFor(ident->symbol()));
+ }
return true;
}
diff --git a/src/writer/msl/generator_impl_member_accessor_test.cc b/src/writer/msl/generator_impl_member_accessor_test.cc
index 89ac2c6..a6ff053 100644
--- a/src/writer/msl/generator_impl_member_accessor_test.cc
+++ b/src/writer/msl/generator_impl_member_accessor_test.cc
@@ -37,6 +37,31 @@
EXPECT_EQ(gen.result(), "str.mem");
}
+TEST_F(MslGeneratorImplTest, EmitExpression_MemberAccessor_Swizzle_xyz) {
+ auto* vec = Var("my_vec", ast::StorageClass::kPrivate, ty.vec4<f32>());
+
+ td.RegisterVariableForTesting(vec);
+
+ auto* expr = MemberAccessor("my_vec", "xyz");
+ ASSERT_TRUE(td.DetermineResultType(expr)) << td.error();
+
+ GeneratorImpl& gen = Build();
+ ASSERT_TRUE(gen.EmitExpression(expr)) << gen.error();
+ EXPECT_EQ(gen.result(), "my_vec.xyz");
+}
+
+TEST_F(MslGeneratorImplTest, EmitExpression_MemberAccessor_Swizzle_gbr) {
+ auto* vec = Var("my_vec", ast::StorageClass::kPrivate, ty.vec4<f32>());
+ td.RegisterVariableForTesting(vec);
+
+ auto* expr = MemberAccessor("my_vec", "gbr");
+ ASSERT_TRUE(td.DetermineResultType(expr)) << td.error();
+
+ GeneratorImpl& gen = Build();
+ ASSERT_TRUE(gen.EmitExpression(expr)) << gen.error();
+ EXPECT_EQ(gen.result(), "my_vec.gbr");
+}
+
} // namespace
} // namespace msl
} // namespace writer