[hlsl] Simplify type emission in HLSL IR printer.
With the addition of the `ByteAddressBuffer` type we can simplify the
`EmitType` method. The `ByteAddressBuffer` can emit itself as part of
the regular `Switch`. The `uniform` handling is pulled out and handled
directly when we emit the global uniform variable.
Bug: 42251045
Change-Id: Ie3451db082dfbfdb604218e557b248c9b7f417ab
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/196695
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/hlsl/writer/printer/printer.cc b/src/tint/lang/hlsl/writer/printer/printer.cc
index 97b3600..f8cdf56 100644
--- a/src/tint/lang/hlsl/writer/printer/printer.cc
+++ b/src/tint/lang/hlsl/writer/printer/printer.cc
@@ -262,8 +262,7 @@
out << "inout ";
}
}
- EmitTypeAndName(out, param->Type(), core::AddressSpace::kUndefined,
- core::Access::kUndefined, NameOf(param));
+ EmitTypeAndName(out, param->Type(), NameOf(param));
}
out << ") {";
@@ -281,7 +280,7 @@
auto name = UniqueIdentifier("ary_ret");
out << "typedef ";
- EmitTypeAndName(out, ty, core::AddressSpace::kUndefined, core::Access::kReadWrite, name);
+ EmitTypeAndName(out, ty, name);
out << ";\n" << name;
}
@@ -514,19 +513,15 @@
{
const ScopedIndent si(this);
- auto out = Line();
- EmitTypeAndName(out, ptr->StoreType(), core::AddressSpace::kUniform, ptr->Access(),
- NameOf(var->Result(0)));
-
- out << ";";
+ auto array_length = (ptr->StoreType()->Size() + 15) / 16;
+ Line() << "uint4 " << NameOf(var->Result(0)) << "[" << array_length << "];";
}
Line() << "};";
}
void EmitStorageVariable(const core::ir::Var* var, const hlsl::type::ByteAddressBuffer* buf) {
auto out = Line();
- EmitTypeAndName(out, var->Result(0)->Type(), core::AddressSpace::kStorage, buf->Access(),
- NameOf(var->Result(0)));
+ EmitTypeAndName(out, var->Result(0)->Type(), NameOf(var->Result(0)));
auto bp = var->BindingPoint();
TINT_ASSERT(bp.has_value());
@@ -557,8 +552,7 @@
// TODO(dsinclair): Handle PixelLocal::RasterizerOrderedView attribute
auto out = Line();
- EmitTypeAndName(out, var->Result(0)->Type(), ptr->AddressSpace(), ptr->Access(),
- NameOf(var->Result(0)));
+ EmitTypeAndName(out, var->Result(0)->Type(), NameOf(var->Result(0)));
out << RegisterAndSpace(register_space, bp.value()) << ";";
}
@@ -568,7 +562,7 @@
auto space = ptr->AddressSpace();
- EmitTypeAndName(out, var->Result(0)->Type(), space, ptr->Access(), NameOf(var->Result(0)));
+ EmitTypeAndName(out, var->Result(0)->Type(), NameOf(var->Result(0)));
if (var->Initializer()) {
out << " = ";
@@ -598,8 +592,7 @@
// TODO(dsinclair): Investigate using `const` here as well, the AST printer doesn't emit
// const with a let, but we should be able to.
- EmitTypeAndName(out, l->Result(0)->Type(), core::AddressSpace::kUndefined,
- core::Access::kUndefined, NameOf(l->Result(0)));
+ EmitTypeAndName(out, l->Result(0)->Type(), NameOf(l->Result(0)));
out << " = ";
EmitValue(out, l->Value());
out << ";";
@@ -1152,13 +1145,9 @@
out << "}";
}
- void EmitTypeAndName(StringStream& out,
- const core::type::Type* type,
- core::AddressSpace address_space,
- core::Access access,
- const std::string& name) {
+ void EmitTypeAndName(StringStream& out, const core::type::Type* type, const std::string& name) {
bool name_printed = false;
- EmitType(out, type, address_space, access, name, &name_printed);
+ EmitType(out, type, name, &name_printed);
if (!name.empty() && !name_printed) {
out << " " << name;
@@ -1167,35 +1156,20 @@
void EmitType(StringStream& out,
const core::type::Type* ty,
- core::AddressSpace address_space = core::AddressSpace::kUndefined,
- core::Access access = core::Access::kUndefined,
const std::string& name = "",
bool* name_printed = nullptr) {
if (name_printed) {
*name_printed = false;
}
- switch (address_space) {
- case core::AddressSpace::kStorage:
- if (access != core::Access::kRead) {
+ Switch(
+ ty,
+ [&](const hlsl::type::ByteAddressBuffer* buf) {
+ if (buf->Access() != core::Access::kRead) {
out << "RW";
}
out << "ByteAddressBuffer";
- return;
- case core::AddressSpace::kUniform: {
- auto array_length = (ty->Size() + 15) / 16;
- out << "uint4 " << name << "[" << array_length << "]";
- if (name_printed) {
- *name_printed = true;
- }
- return;
- }
- default:
- break;
- }
-
- Switch(
- ty, //
+ },
[&](const core::type::Bool*) { out << "bool"; }, //
[&](const core::type::F16*) { out << "float16_t"; }, //
[&](const core::type::F32*) { out << "float"; }, //
@@ -1203,21 +1177,17 @@
[&](const core::type::U32*) { out << "uint"; }, //
[&](const core::type::Void*) { out << "void"; }, //
- [&](const core::type::Atomic* atomic) {
- EmitType(out, atomic->Type(), address_space, access, name);
- },
- [&](const core::type::Array* ary) {
- EmitArrayType(out, ary, address_space, access, name, name_printed);
- },
- [&](const core::type::Vector* vec) { EmitVectorType(out, vec, address_space, access); },
- [&](const core::type::Matrix* mat) { EmitMatrixType(out, mat, address_space, access); },
+ [&](const core::type::Atomic* atomic) { EmitType(out, atomic->Type(), name); },
+ [&](const core::type::Array* ary) { EmitArrayType(out, ary, name, name_printed); },
+ [&](const core::type::Vector* vec) { EmitVectorType(out, vec); },
+ [&](const core::type::Matrix* mat) { EmitMatrixType(out, mat); },
[&](const core::type::Struct* str) {
out << StructName(str);
EmitStructType(str);
},
[&](const core::type::Pointer* p) {
- EmitType(out, p->StoreType(), p->AddressSpace(), p->Access(), name, name_printed);
+ EmitType(out, p->StoreType(), name, name_printed);
},
[&](const core::type::Sampler* sampler) { EmitSamplerType(out, sampler); },
[&](const core::type::Texture* tex) { EmitTextureType(out, tex); },
@@ -1226,8 +1196,6 @@
void EmitArrayType(StringStream& out,
const core::type::Array* ary,
- core::AddressSpace address_space,
- core::Access access,
const std::string& name,
bool* name_printed) {
const core::type::Type* base_type = ary;
@@ -1244,7 +1212,7 @@
sizes.push_back(count.value());
base_type = arr->ElemType();
}
- EmitType(out, base_type, address_space, access);
+ EmitType(out, base_type);
if (!name.empty()) {
out << " " << name;
@@ -1258,10 +1226,7 @@
}
}
- void EmitVectorType(StringStream& out,
- const core::type::Vector* vec,
- core::AddressSpace address_space,
- core::Access access) {
+ void EmitVectorType(StringStream& out, const core::type::Vector* vec) {
auto width = vec->Width();
if (vec->type()->Is<core::type::F32>()) {
out << "float" << width;
@@ -1274,24 +1239,21 @@
} else {
// For example, use "vector<float16_t, N>" for f16 vector.
out << "vector<";
- EmitType(out, vec->type(), address_space, access);
+ EmitType(out, vec->type());
out << ", " << width << ">";
}
}
- void EmitMatrixType(StringStream& out,
- const core::type::Matrix* mat,
- core::AddressSpace address_space,
- core::Access access) {
+ void EmitMatrixType(StringStream& out, const core::type::Matrix* mat) {
if (mat->type()->Is<core::type::F16>()) {
// Use matrix<type, N, M> for f16 matrix
out << "matrix<";
- EmitType(out, mat->type(), address_space, access);
+ EmitType(out, mat->type());
out << ", " << mat->columns() << ", " << mat->rows() << ">";
return;
}
- EmitType(out, mat->type(), address_space, access);
+ EmitType(out, mat->type());
// Note: HLSL's matrices are declared as <type>NxM, where N is the
// number of rows and M is the number of columns. Despite HLSL's
@@ -1441,8 +1403,7 @@
}
out << pre;
- EmitTypeAndName(out, ty, core::AddressSpace::kUndefined, core::Access::kReadWrite,
- mem_name);
+ EmitTypeAndName(out, ty, mem_name);
out << post << ";";
}
}