[hlsl] Emit constant structures.
This CL extends the constant struct emission to handle non-zero
structures.
Bug: 42251045
Change-Id: I2dc6e5beeb205027d6c8189ddde8399d4338bf89
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/195554
Commit-Queue: James Price <jrprice@google.com>
Reviewed-by: James Price <jrprice@google.com>
Auto-Submit: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/hlsl/writer/constant_test.cc b/src/tint/lang/hlsl/writer/constant_test.cc
index de3d61b..aff0f2a 100644
--- a/src/tint/lang/hlsl/writer/constant_test.cc
+++ b/src/tint/lang/hlsl/writer/constant_test.cc
@@ -25,6 +25,9 @@
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#include <utility>
+
+#include "src/tint/lang/core/fluent_types.h"
#include "src/tint/lang/hlsl/writer/helper_test.h"
using namespace tint::core::fluent_types; // NOLINT
@@ -553,7 +556,7 @@
)");
}
-TEST_F(HlslWriterTest, ConstantTypeStructNested) {
+TEST_F(HlslWriterTest, ConstantTypeStructNestedEmpty) {
Vector members_a{
ty.Get<core::type::StructMember>(b.ir.symbols.New("d"), ty.i32(), 0u, 0u, 4u, 4u,
core::type::StructMemberAttributes{}),
@@ -599,6 +602,134 @@
)");
}
+TEST_F(HlslWriterTest, ConstantTypeStructNested) {
+ Vector members_a{
+ ty.Get<core::type::StructMember>(b.ir.symbols.New("e"), ty.vec4<f32>(), 0u, 0u, 16u, 16u,
+ core::type::StructMemberAttributes{}),
+ };
+ auto* a_strct = ty.Struct(b.ir.symbols.New("A"), std::move(members_a));
+
+ Vector members_s{
+ ty.Get<core::type::StructMember>(b.ir.symbols.New("c"), a_strct, 0u, 0u, 16u, 16u,
+ core::type::StructMemberAttributes{}),
+ };
+ auto* s_strct = ty.Struct(b.ir.symbols.New("S"), std::move(members_s));
+
+ auto* f = b.Function("a", s_strct);
+ b.Append(f->Block(), [&] {
+ b.Return(f, b.Construct(s_strct, b.Construct(a_strct, b.Splat(ty.vec4<f32>(), 1_f))));
+ });
+
+ ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+ EXPECT_EQ(output_.hlsl, R"(struct A {
+ float4 e;
+};
+
+struct S {
+ A c;
+};
+
+
+S a() {
+ return {{(1.0f).xxxx}};
+}
+
+[numthreads(1, 1, 1)]
+void unused_entry_point() {
+}
+
+)");
+}
+
+TEST_F(HlslWriterTest, ConstantTypeLetStructComposite) {
+ Vector members_a{
+ ty.Get<core::type::StructMember>(b.ir.symbols.New("e"), ty.vec4<f32>(), 0u, 0u, 16u, 16u,
+ core::type::StructMemberAttributes{}),
+ };
+ auto* a_strct = ty.Struct(b.ir.symbols.New("A"), std::move(members_a));
+
+ Vector members_s{
+ ty.Get<core::type::StructMember>(b.ir.symbols.New("c"), a_strct, 0u, 0u, 16u, 16u,
+ core::type::StructMemberAttributes{}),
+ };
+ auto* s_strct = ty.Struct(b.ir.symbols.New("S"), std::move(members_s));
+
+ auto* f = b.Function("a", ty.f32());
+ b.Append(f->Block(), [&] {
+ b.Let("z", b.Composite(s_strct, b.Composite(a_strct, b.Splat(ty.vec4<f32>(), 1_f))));
+ b.Return(f, 1_f);
+ });
+
+ ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+ EXPECT_EQ(output_.hlsl, R"(struct A {
+ float4 e;
+};
+
+struct S {
+ A c;
+};
+
+
+float a() {
+ S z = {{(1.0f).xxxx}};
+ return 1.0f;
+}
+
+[numthreads(1, 1, 1)]
+void unused_entry_point() {
+}
+
+)");
+}
+
+// TODO(dsinclair): Need support for `static const` variables
+TEST_F(HlslWriterTest, DISABLED_ConstantTypeLetStructCompositeModuleScoped) {
+ Vector members_a{
+ ty.Get<core::type::StructMember>(b.ir.symbols.New("e"), ty.vec4<f32>(), 0u, 0u, 16u, 16u,
+ core::type::StructMemberAttributes{}),
+ };
+ auto* a_strct = ty.Struct(b.ir.symbols.New("A"), std::move(members_a));
+
+ Vector members_s{
+ ty.Get<core::type::StructMember>(b.ir.symbols.New("c"), a_strct, 0u, 0u, 16u, 16u,
+ core::type::StructMemberAttributes{}),
+ };
+ auto* s_strct = ty.Struct(b.ir.symbols.New("S"), std::move(members_s));
+
+ b.ir.root_block->Append(b.Var<private_>(
+ "z", b.Composite(s_strct, b.Composite(a_strct, b.Splat(ty.vec4<f32>(), 1_f)))));
+
+ auto* f = b.Function("a", ty.f32());
+ b.Append(f->Block(), [&] {
+ b.Var<function>("t",
+ b.Composite(s_strct, b.Composite(a_strct, b.Splat(ty.vec4<f32>(), 1_f))));
+ b.Return(f, 1_f);
+ });
+
+ ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+ EXPECT_EQ(output_.hlsl, R"(struct A {
+ float4 e;
+};
+
+struct S {
+ A c;
+};
+
+static const A c_1 = {(1.f).xxxx};
+static const S c_2 = {c_1};
+static S z = c_2;
+float a() {
+ S t = {{(1.0f).xxxx}};
+ return 1.0f;
+}
+
+[numthreads(1, 1, 1)]
+void unused_entry_point() {
+}
+
+)");
+}
+
TEST_F(HlslWriterTest, ConstantTypeStructEmpty) {
Vector members{
ty.Get<core::type::StructMember>(b.ir.symbols.New("a"), ty.i32(), 0u, 0u, 4u, 4u,
@@ -635,7 +766,87 @@
)");
}
-TEST_F(HlslWriterTest, ConstantTypeStructStatic) {
+TEST_F(HlslWriterTest, ConstantTypeStruct) {
+ Vector members{
+ ty.Get<core::type::StructMember>(b.ir.symbols.New("a"), ty.i32(), 0u, 0u, 4u, 4u,
+ core::type::StructMemberAttributes{}),
+ ty.Get<core::type::StructMember>(b.ir.symbols.New("b"), ty.f32(), 1u, 4u, 4u, 4u,
+ core::type::StructMemberAttributes{}),
+ ty.Get<core::type::StructMember>(b.ir.symbols.New("c"), ty.vec3<i32>(), 2u, 16u, 16u, 16u,
+ core::type::StructMemberAttributes{}),
+ ty.Get<core::type::StructMember>(b.ir.symbols.New("d"), ty.vec4<f32>(), 2u, 32u, 16u, 16u,
+ core::type::StructMemberAttributes{}),
+ };
+ auto* strct = ty.Struct(b.ir.symbols.New("S"), std::move(members));
+
+ auto* f = b.Function("a", strct);
+ b.Append(f->Block(), [&] {
+ b.Return(f, b.Construct(strct, 1_i, 1_f, b.Splat(ty.vec3<i32>(), 2_i),
+ b.Splat(ty.vec4<f32>(), 3_f)));
+ });
+
+ ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+ EXPECT_EQ(output_.hlsl, R"(struct S {
+ int a;
+ float b;
+ int3 c;
+ float4 d;
+};
+
+
+S a() {
+ return {1, 1.0f, (2).xxx, (3.0f).xxxx};
+}
+
+[numthreads(1, 1, 1)]
+void unused_entry_point() {
+}
+
+)");
+}
+
+TEST_F(HlslWriterTest, ConstantTypeLetStruct) {
+ Vector members{
+ ty.Get<core::type::StructMember>(b.ir.symbols.New("a"), ty.i32(), 0u, 0u, 4u, 4u,
+ core::type::StructMemberAttributes{}),
+ ty.Get<core::type::StructMember>(b.ir.symbols.New("b"), ty.f32(), 1u, 4u, 4u, 4u,
+ core::type::StructMemberAttributes{}),
+ ty.Get<core::type::StructMember>(b.ir.symbols.New("c"), ty.vec3<i32>(), 2u, 16u, 16u, 16u,
+ core::type::StructMemberAttributes{}),
+ ty.Get<core::type::StructMember>(b.ir.symbols.New("d"), ty.vec4<f32>(), 2u, 32u, 16u, 16u,
+ core::type::StructMemberAttributes{}),
+ };
+ auto* strct = ty.Struct(b.ir.symbols.New("S"), std::move(members));
+
+ auto* f = b.Function("a", ty.f32());
+ b.Append(f->Block(), [&] {
+ b.Let("z", b.Construct(strct, 1_i, 1_f, b.Splat(ty.vec3<i32>(), 2_i),
+ b.Splat(ty.vec4<f32>(), 3_f)));
+ b.Return(f, 1_f);
+ });
+
+ ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+ EXPECT_EQ(output_.hlsl, R"(struct S {
+ int a;
+ float b;
+ int3 c;
+ float4 d;
+};
+
+
+float a() {
+ S z = {1, 1.0f, (2).xxx, (3.0f).xxxx};
+ return 1.0f;
+}
+
+[numthreads(1, 1, 1)]
+void unused_entry_point() {
+}
+
+)");
+}
+
+TEST_F(HlslWriterTest, ConstantTypeStructStaticEmpty) {
Vector members{
ty.Get<core::type::StructMember>(b.ir.symbols.New("a"), ty.i32(), 0u, 0u, 4u, 4u,
core::type::StructMemberAttributes{}),
@@ -659,5 +870,30 @@
)");
}
+// TODO(dsinclair): Need suppport for `static const` variables
+TEST_F(HlslWriterTest, DISABLED_ConstantTypeStructStatic) {
+ Vector members{
+ ty.Get<core::type::StructMember>(b.ir.symbols.New("a"), ty.i32(), 0u, 0u, 4u, 4u,
+ core::type::StructMemberAttributes{}),
+ };
+ auto* strct = ty.Struct(b.ir.symbols.New("S"), std::move(members));
+
+ b.Append(b.ir.root_block, [&] { b.Var<private_>("p", b.Construct(strct, 3_i)); });
+
+ ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+ EXPECT_EQ(output_.hlsl, R"(struct S {
+ int a;
+};
+
+
+static const
+S p = {3};
+[numthreads(1, 1, 1)]
+void unused_entry_point() {
+}
+
+)");
+}
+
} // namespace
} // namespace tint::hlsl::writer
diff --git a/src/tint/lang/hlsl/writer/printer/printer.cc b/src/tint/lang/hlsl/writer/printer/printer.cc
index 6b3fa75..24b2508e 100644
--- a/src/tint/lang/hlsl/writer/printer/printer.cc
+++ b/src/tint/lang/hlsl/writer/printer/printer.cc
@@ -1088,6 +1088,15 @@
out << "(" << StructName(s) << ")0";
return;
}
+
+ out << "{";
+ for (size_t i = 0; i < s->Members().Length(); i++) {
+ if (i > 0) {
+ out << ", ";
+ }
+ EmitConstant(out, c->Index(i));
+ }
+ out << "}";
}
void EmitTypeAndName(StringStream& out,