[hlsl] Implement more type emission.
This Cl fixes the emission of array types and adds nested struct
emission.
Bug: 42251045
Change-Id: Id2a05336e684a6c5bd4e2d0eab7fa00eaa959b67
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/194020
Commit-Queue: dan sinclair <dsinclair@chromium.org>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/tint/lang/hlsl/writer/constant_test.cc b/src/tint/lang/hlsl/writer/constant_test.cc
index b0c7646..8b31b9d 100644
--- a/src/tint/lang/hlsl/writer/constant_test.cc
+++ b/src/tint/lang/hlsl/writer/constant_test.cc
@@ -481,8 +481,7 @@
)");
}
-// TODO(dsinclair): Need `load`
-TEST_F(HlslWriterTest, DISABLED_ConstantTypeMatIdentityF32) {
+TEST_F(HlslWriterTest, ConstantTypeMatIdentityF32) {
// fn f() {
// var m_1: mat4x4<f32> = mat4x4<f32>();
// var m_2: mat4x4<f32> = mat4x4<f32>(m_1);
@@ -496,8 +495,10 @@
});
ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
- EXPECT_EQ(output_.hlsl, R"(void a() {
- float4x4 m_2 = float4x4(m_1);
+ EXPECT_EQ(output_.hlsl, R"(
+void a() {
+ float4x4 m_1 = float4x4((0.0f).xxxx, (0.0f).xxxx, (0.0f).xxxx, (0.0f).xxxx);
+ float4x4 m_2 = m_1;
}
[numthreads(1, 1, 1)]
@@ -507,8 +508,7 @@
)");
}
-// TODO(dsinclair): Need `load`
-TEST_F(HlslWriterTest, DISABLED_ConstantTypeMatIdentityF16) {
+TEST_F(HlslWriterTest, ConstantTypeMatIdentityF16) {
// fn f() {
// var m_1: mat4x4<f16> = mat4x4<f16>();
// var m_2: mat4x4<f16> = mat4x4<f16>(m_1);
@@ -522,8 +522,10 @@
});
ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
- EXPECT_EQ(output_.hlsl, R"(void a() {
- matrix<float16_t, 4, 4> m_2 = matrix<float16_t, 4, 4>(m_1);
+ EXPECT_EQ(output_.hlsl, R"(
+void a() {
+ matrix<float16_t, 4, 4> m_1 = matrix<float16_t, 4, 4>((float16_t(0.0h)).xxxx, (float16_t(0.0h)).xxxx, (float16_t(0.0h)).xxxx, (float16_t(0.0h)).xxxx);
+ matrix<float16_t, 4, 4> m_2 = m_1;
}
[numthreads(1, 1, 1)]
@@ -533,7 +535,7 @@
)");
}
-TEST_F(HlslWriterTest, DISABLED_ConstantTypeArray) {
+TEST_F(HlslWriterTest, DISABLED_ConstantTypeArrayFunctionReturn) {
auto* f = b.Function("a", ty.array<vec3<f32>, 3>());
b.Append(f->Block(), [&] {
b.Return(f,
@@ -545,8 +547,7 @@
ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
EXPECT_EQ(output_.hlsl, R"(typedef float3 a_ret[3]
a_ret a() {
- float3 tint_symbol[3] = {float3(1.0f, 2.0f, 3.0f), float3(4.0f, 5.0f, 6.0f), float3(7.0f, 8.0f, 9.0f)};
- return tint_symbol;
+ return {float3(1.0f, 2.0f, 3.0f), float3(4.0f, 5.0f, 6.0f), float3(7.0f, 8.0f, 9.0f)};
}
[numthreads(1, 1, 1)]
@@ -556,13 +557,14 @@
)");
}
-TEST_F(HlslWriterTest, ConstantType_Array_Empty) {
+TEST_F(HlslWriterTest, DISABLED_ConstantTypeArrayEmptyFunctionReturn) {
auto* f = b.Function("a", ty.array<vec3<f32>, 3>());
b.Append(f->Block(), [&] { b.Return(f, b.Zero<array<vec3<f32>, 3>>()); });
ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
EXPECT_EQ(output_.hlsl, R"(
-float3[3] a() {
+typedef float3 a_ret[3];
+a_ret a() {
return (float3[3])0;
}
@@ -573,6 +575,46 @@
)");
}
+TEST_F(HlslWriterTest, ConstantTypeArray) {
+ auto* f = b.Function("a", ty.void_(), core::ir::Function::PipelineStage::kCompute);
+ f->SetWorkgroupSize(1, 1, 1);
+
+ b.Append(f->Block(), [&] {
+ b.Var("v", b.Composite(ty.array<vec3<f32>, 3>(), b.Composite(ty.vec3<f32>(), 1_f, 2_f, 3_f),
+ b.Composite(ty.vec3<f32>(), 4_f, 5_f, 6_f),
+ b.Composite(ty.vec3<f32>(), 7_f, 8_f, 9_f)));
+ b.Return(f);
+ });
+
+ ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+ EXPECT_EQ(output_.hlsl, R"(
+[numthreads(1, 1, 1)]
+void a() {
+ float3 v[3] = {float3(1.0f, 2.0f, 3.0f), float3(4.0f, 5.0f, 6.0f), float3(7.0f, 8.0f, 9.0f)};
+}
+
+)");
+}
+
+TEST_F(HlslWriterTest, ConstantTypeArrayEmpty) {
+ auto* f = b.Function("a", ty.void_(), core::ir::Function::PipelineStage::kCompute);
+ f->SetWorkgroupSize(1, 1, 1);
+
+ b.Append(f->Block(), [&] {
+ b.Var("v", b.Zero<array<vec3<f32>, 3>>());
+ b.Return(f);
+ });
+
+ ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+ EXPECT_EQ(output_.hlsl, R"(
+[numthreads(1, 1, 1)]
+void a() {
+ float3 v[3] = (float3[3])0;
+}
+
+)");
+}
+
// TODO(dsinclair): needs `construct`
TEST_F(HlslWriterTest, DISABLED_ConstantTypeStruct) {
Vector members{
@@ -608,6 +650,52 @@
)");
}
+TEST_F(HlslWriterTest, ConstantTypeStructNested) {
+ Vector members_a{
+ ty.Get<core::type::StructMember>(b.ir.symbols.New("d"), ty.i32(), 0u, 0u, 4u, 4u,
+ core::type::StructMemberAttributes{}),
+ ty.Get<core::type::StructMember>(b.ir.symbols.New("e"), ty.f32(), 1u, 4u, 4u, 4u,
+ core::type::StructMemberAttributes{}),
+ };
+ auto* a_strct = ty.Struct(b.ir.symbols.New("A"), std::move(members_a));
+
+ Vector members_s{
+ ty.Get<core::type::StructMember>(b.ir.symbols.New("a"), ty.i32(), 0u, 0u, 4u, 4u,
+ core::type::StructMemberAttributes{}),
+ ty.Get<core::type::StructMember>(b.ir.symbols.New("b"), ty.f32(), 1u, 4u, 4u, 4u,
+ core::type::StructMemberAttributes{}),
+ ty.Get<core::type::StructMember>(b.ir.symbols.New("c"), a_strct, 2u, 8u, 8u, 8u,
+ core::type::StructMemberAttributes{}),
+ };
+ auto* s_strct = ty.Struct(b.ir.symbols.New("S"), std::move(members_s));
+
+ auto* f = b.Function("a", s_strct);
+ b.Append(f->Block(), [&] { b.Return(f, b.Zero(s_strct)); });
+
+ ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+ EXPECT_EQ(output_.hlsl, R"(struct A {
+ int d;
+ float e;
+};
+
+struct S {
+ int a;
+ float b;
+ A c;
+};
+
+
+S a() {
+ return (S)0;
+}
+
+[numthreads(1, 1, 1)]
+void unused_entry_point() {
+}
+
+)");
+}
+
TEST_F(HlslWriterTest, ConstantTypeStructEmpty) {
Vector members{
ty.Get<core::type::StructMember>(b.ir.symbols.New("a"), ty.i32(), 0u, 0u, 4u, 4u,
@@ -616,6 +704,8 @@
core::type::StructMemberAttributes{}),
ty.Get<core::type::StructMember>(b.ir.symbols.New("c"), ty.vec3<i32>(), 2u, 8u, 16u, 16u,
core::type::StructMemberAttributes{}),
+ ty.Get<core::type::StructMember>(b.ir.symbols.New("d"), ty.array<f32, 3>(), 2u, 8u, 16u,
+ 16u, core::type::StructMemberAttributes{}),
};
auto* strct = ty.Struct(b.ir.symbols.New("S"), std::move(members));
@@ -627,8 +717,10 @@
int a;
float b;
int3 c;
+ float d[3];
};
+
S a() {
return (S)0;
}
diff --git a/src/tint/lang/hlsl/writer/printer/printer.cc b/src/tint/lang/hlsl/writer/printer/printer.cc
index 813937f..9d81451 100644
--- a/src/tint/lang/hlsl/writer/printer/printer.cc
+++ b/src/tint/lang/hlsl/writer/printer/printer.cc
@@ -584,7 +584,7 @@
void EmitConstantStruct(StringStream& out,
const core::constant::Value* c,
const core::type::Struct* s) {
- EmitStructType(&preamble_buffer_, s);
+ EmitStructType(s);
if (c->AllZero()) {
out << "(" << StructName(s) << ")0";
@@ -592,6 +592,19 @@
}
}
+ void EmitTypeAndName(StringStream& out,
+ const core::type::Type* type,
+ core::AddressSpace address_space,
+ core::Access access,
+ const std::string& name) {
+ bool name_printed = false;
+ EmitType(out, type, address_space, access, name, &name_printed);
+
+ if (!name.empty() && !name_printed) {
+ out << " " << name;
+ }
+ }
+
void EmitType(StringStream& out,
const core::type::Type* ty,
core::AddressSpace address_space = core::AddressSpace::kUndefined,
@@ -633,36 +646,30 @@
[&](const core::type::Atomic* atomic) {
EmitType(out, atomic->Type(), address_space, access, name);
},
- [&](const core::type::Array* ary) { EmitArrayType(out, ary, address_space, access); },
+ [&](const core::type::Array* ary) {
+ EmitArrayType(out, ary, address_space, access, name, name_printed);
+ },
[&](const core::type::Vector* vec) { EmitVectorType(out, vec, address_space, access); },
[&](const core::type::Matrix* mat) { EmitMatrixType(out, mat, address_space, access); },
- [&](const core::type::Struct* str) { out << StructName(str); },
+ [&](const core::type::Struct* str) {
+ out << StructName(str);
+ EmitStructType(str);
+ },
[&](const core::type::Pointer* p) {
- EmitType(out, p->StoreType(), p->AddressSpace(), p->Access());
+ EmitType(out, p->StoreType(), p->AddressSpace(), p->Access(), name, name_printed);
},
[&](const core::type::Sampler* sampler) { EmitSamplerType(out, sampler); },
[&](const core::type::Texture* tex) { EmitTextureType(out, tex); },
TINT_ICE_ON_NO_MATCH);
}
- void EmitTypeAndName(StringStream& out,
- const core::type::Type* type,
- core::AddressSpace address_space,
- core::Access access,
- const std::string& name) {
- bool name_printed = false;
- EmitType(out, type, address_space, access, name, &name_printed);
-
- if (!name.empty() && !name_printed) {
- out << " " << name;
- }
- }
-
void EmitArrayType(StringStream& out,
const core::type::Array* ary,
core::AddressSpace address_space,
- core::Access access) {
+ core::Access access,
+ const std::string& name,
+ bool* name_printed) {
const core::type::Type* base_type = ary;
std::vector<uint32_t> sizes;
while (auto* arr = base_type->As<core::type::Array>()) {
@@ -679,6 +686,13 @@
}
EmitType(out, base_type, address_space, access);
+ if (!name.empty()) {
+ out << " " << name;
+ if (name_printed) {
+ *name_printed = true;
+ }
+ }
+
for (const uint32_t size : sizes) {
out << "[" << size << "]";
}
@@ -800,23 +814,25 @@
out << "State";
}
- void EmitStructType(TextBuffer* b, const core::type::Struct* str) {
+ void EmitStructType(const core::type::Struct* str) {
auto it = emitted_structs_.emplace(str);
if (!it.second) {
return;
}
- Line(b) << "struct " << StructName(str) << " {";
+ TextBuffer str_buf;
+ Line(&str_buf) << "struct " << StructName(str) << " {";
{
- const ScopedIndent si(b);
+ const ScopedIndent si(&str_buf);
for (auto* mem : str->Members()) {
auto mem_name = mem->Name().Name();
auto* ty = mem->Type();
- auto out = Line(b);
- std::string pre, post;
+ auto out = Line(&str_buf);
auto& attributes = mem->Attributes();
+ std::string pre;
+ std::string post;
if (auto location = attributes.location) {
auto& pipeline_stage_uses = str->PipelineStageUses();
if (TINT_UNLIKELY(pipeline_stage_uses.Count() != 1)) {
@@ -871,7 +887,10 @@
}
}
- Line(b) << "};";
+ Line(&str_buf) << "};";
+ Line(&str_buf) << "";
+
+ preamble_buffer_.Append(str_buf);
}
std::string builtin_to_attribute(core::BuiltinValue builtin) const {