Record when an identifier expression is a swizzle. This CL updates the type determiner to record if an identifier expression is a swizzle and then uses that in the MSL and HLSL generator to output the swizzle name directly. Change-Id: I77c0e1e80dce9e2f09cbbd37476a146b06555ee2 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/38960 Auto-Submit: dan sinclair <dsinclair@chromium.org> Commit-Queue: Ben Clayton <bclayton@google.com> Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/ast/identifier_expression.h b/src/ast/identifier_expression.h index b6b9b62..d1a2b77 100644 --- a/src/ast/identifier_expression.h +++ b/src/ast/identifier_expression.h
@@ -58,6 +58,12 @@ /// @returns true if this identifier is for an intrinsic bool IsIntrinsic() const { return intrinsic_ != Intrinsic::kNone; } + /// Sets the identifier as a swizzle + void SetIsSwizzle() { is_swizzle_ = true; } + + /// @returns true if this is a swizzle identifier + bool IsSwizzle() const { return is_swizzle_; } + /// Clones this node and all transitive child nodes using the `CloneContext` /// `ctx`. /// @note Semantic information such as resolved expression type and intrinsic @@ -81,6 +87,7 @@ Intrinsic intrinsic_ = Intrinsic::kNone; // Semantic info std::unique_ptr<intrinsic::Signature> intrinsic_sig_; // Semantic info + bool is_swizzle_ = false; // Semantic info }; } // namespace ast
diff --git a/src/type_determiner.cc b/src/type_determiner.cc index f0f8c61..518f5c9 100644 --- a/src/type_determiner.cc +++ b/src/type_determiner.cc
@@ -1118,7 +1118,7 @@ ret = builder_->create<type::Pointer>(ret, ptr->storage_class()); } } else if (auto* vec = data_type->As<type::Vector>()) { - // TODO(dsinclair): Swizzle, record into the identifier experesion + expr->member()->SetIsSwizzle(); auto size = builder_->Symbols().NameFor(expr->member()->symbol()).size(); if (size == 1) {
diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc index bfc6cd9..ac24076 100644 --- a/src/writer/hlsl/generator_impl.cc +++ b/src/writer/hlsl/generator_impl.cc
@@ -1171,7 +1171,13 @@ out << name << "."; } } - out << namer_.NameFor(program_->Symbols().NameFor(ident->symbol())); + + // Swizzles output the name directly + if (ident->IsSwizzle()) { + out << program_->Symbols().NameFor(ident->symbol()); + } else { + out << namer_.NameFor(program_->Symbols().NameFor(ident->symbol())); + } return true; }
diff --git a/src/writer/hlsl/generator_impl_member_accessor_test.cc b/src/writer/hlsl/generator_impl_member_accessor_test.cc index adeaee0..fc885da 100644 --- a/src/writer/hlsl/generator_impl_member_accessor_test.cc +++ b/src/writer/hlsl/generator_impl_member_accessor_test.cc
@@ -948,6 +948,34 @@ )"); } +TEST_F(HlslGeneratorImplTest_MemberAccessor, + EmitExpression_MemberAccessor_Swizzle_xyz) { + auto* vec = Var("my_vec", ast::StorageClass::kPrivate, ty.vec4<f32>()); + td.RegisterVariableForTesting(vec); + + auto* expr = MemberAccessor("my_vec", "xyz"); + + ASSERT_TRUE(td.DetermineResultType(expr)) << td.error(); + + GeneratorImpl& gen = Build(); + ASSERT_TRUE(gen.EmitExpression(pre, out, expr)) << gen.error(); + EXPECT_EQ(result(), "my_vec.xyz"); +} + +TEST_F(HlslGeneratorImplTest_MemberAccessor, + EmitExpression_MemberAccessor_Swizzle_gbr) { + auto* vec = Var("my_vec", ast::StorageClass::kPrivate, ty.vec4<f32>()); + td.RegisterVariableForTesting(vec); + + auto* expr = MemberAccessor("my_vec", "gbr"); + + ASSERT_TRUE(td.DetermineResultType(expr)) << td.error(); + + GeneratorImpl& gen = Build(); + ASSERT_TRUE(gen.EmitExpression(pre, out, expr)) << gen.error(); + EXPECT_EQ(result(), "my_vec.gbr"); +} + } // namespace } // namespace hlsl } // namespace writer
diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc index 67a8843..34a1390 100644 --- a/src/writer/msl/generator_impl.cc +++ b/src/writer/msl/generator_impl.cc
@@ -1592,7 +1592,13 @@ out_ << name << "."; } } - out_ << namer_.NameFor(program_->Symbols().NameFor(ident->symbol())); + + // Swizzles get written out directly + if (ident->IsSwizzle()) { + out_ << program_->Symbols().NameFor(ident->symbol()); + } else { + out_ << namer_.NameFor(program_->Symbols().NameFor(ident->symbol())); + } return true; }
diff --git a/src/writer/msl/generator_impl_member_accessor_test.cc b/src/writer/msl/generator_impl_member_accessor_test.cc index 89ac2c6..a6ff053 100644 --- a/src/writer/msl/generator_impl_member_accessor_test.cc +++ b/src/writer/msl/generator_impl_member_accessor_test.cc
@@ -37,6 +37,31 @@ EXPECT_EQ(gen.result(), "str.mem"); } +TEST_F(MslGeneratorImplTest, EmitExpression_MemberAccessor_Swizzle_xyz) { + auto* vec = Var("my_vec", ast::StorageClass::kPrivate, ty.vec4<f32>()); + + td.RegisterVariableForTesting(vec); + + auto* expr = MemberAccessor("my_vec", "xyz"); + ASSERT_TRUE(td.DetermineResultType(expr)) << td.error(); + + GeneratorImpl& gen = Build(); + ASSERT_TRUE(gen.EmitExpression(expr)) << gen.error(); + EXPECT_EQ(gen.result(), "my_vec.xyz"); +} + +TEST_F(MslGeneratorImplTest, EmitExpression_MemberAccessor_Swizzle_gbr) { + auto* vec = Var("my_vec", ast::StorageClass::kPrivate, ty.vec4<f32>()); + td.RegisterVariableForTesting(vec); + + auto* expr = MemberAccessor("my_vec", "gbr"); + ASSERT_TRUE(td.DetermineResultType(expr)) << td.error(); + + GeneratorImpl& gen = Build(); + ASSERT_TRUE(gen.EmitExpression(expr)) << gen.error(); + EXPECT_EQ(gen.result(), "my_vec.gbr"); +} + } // namespace } // namespace msl } // namespace writer