[hlsl-writer] Add support for input locations and builtins.
This CL adds the beginning of support for input/output locations and
builtins in the HLSL backend.
Bug: tint:7
Change-Id: I8fb01707b50635a800b0d7317cf4a8f62f12cfca
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/26780
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: David Neto <dneto@google.com>
diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc
index eb6edf4..6e53100 100644
--- a/src/writer/hlsl/generator_impl.cc
+++ b/src/writer/hlsl/generator_impl.cc
@@ -20,6 +20,7 @@
#include "src/ast/binary_expression.h"
#include "src/ast/bool_literal.h"
#include "src/ast/case_statement.h"
+#include "src/ast/decorated_variable.h"
#include "src/ast/else_statement.h"
#include "src/ast/float_literal.h"
#include "src/ast/identifier_expression.h"
@@ -46,6 +47,11 @@
namespace hlsl {
namespace {
+const char kInStructNameSuffix[] = "in";
+const char kOutStructNameSuffix[] = "out";
+const char kTintStructInVarPrefix[] = "tint_in";
+const char kTintStructOutVarPrefix[] = "tint_out";
+
bool last_is_break_or_fallthrough(const ast::BlockStatement* stmts) {
if (stmts->empty()) {
return false;
@@ -74,6 +80,11 @@
out_ << std::endl;
}
+ for (const auto& ep : module_->entry_points()) {
+ if (!EmitEntryPointData(ep.get())) {
+ return false;
+ }
+ }
for (const auto& func : module_->functions()) {
if (!EmitFunction(func.get())) {
return false;
@@ -89,6 +100,17 @@
return true;
}
+std::string GeneratorImpl::generate_name(const std::string& prefix) {
+ std::string name = prefix;
+ uint32_t i = 0;
+ while (namer_.IsMapped(name)) {
+ name = prefix + "_" + std::to_string(i);
+ ++i;
+ }
+ namer_.RegisterRemappedName(name);
+ return name;
+}
+
std::string GeneratorImpl::current_ep_var_name(VarType type) {
std::string name = "";
switch (type) {
@@ -431,8 +453,12 @@
return false;
}
-bool GeneratorImpl::global_is_in_struct(ast::Variable*) const {
- return false;
+bool GeneratorImpl::global_is_in_struct(ast::Variable* var) const {
+ return var->IsDecorated() &&
+ (var->AsDecorated()->HasLocationDecoration() ||
+ var->AsDecorated()->HasBuiltinDecoration()) &&
+ (var->storage_class() == ast::StorageClass::kInput ||
+ var->storage_class() == ast::StorageClass::kOutput);
}
bool GeneratorImpl::EmitIdentifier(ast::IdentifierExpression* expr) {
@@ -499,6 +525,25 @@
return EmitBlock(stmt->body());
}
+bool GeneratorImpl::has_referenced_var_needing_struct(ast::Function* func) {
+ for (auto data : func->referenced_location_variables()) {
+ auto* var = data.first;
+ if (var->storage_class() == ast::StorageClass::kOutput ||
+ var->storage_class() == ast::StorageClass::kInput) {
+ return true;
+ }
+ }
+
+ for (auto data : func->referenced_builtin_variables()) {
+ auto* var = data.first;
+ if (var->storage_class() == ast::StorageClass::kOutput ||
+ var->storage_class() == ast::StorageClass::kInput) {
+ return true;
+ }
+ }
+ return false;
+}
+
bool GeneratorImpl::EmitFunction(ast::Function* func) {
make_indent();
@@ -507,6 +552,33 @@
return true;
}
+ // TODO(dsinclair): This could be smarter. If the input/outputs for multiple
+ // entry points are the same we could generate a single struct and then have
+ // this determine it's the same struct and just emit once.
+ bool emit_duplicate_functions = func->ancestor_entry_points().size() > 0 &&
+ has_referenced_var_needing_struct(func);
+
+ if (emit_duplicate_functions) {
+ for (const auto& ep_name : func->ancestor_entry_points()) {
+ if (!EmitFunctionInternal(func, emit_duplicate_functions, ep_name)) {
+ return false;
+ }
+ out_ << std::endl;
+ }
+ } else {
+ // Emit as non-duplicated
+ if (!EmitFunctionInternal(func, false, "")) {
+ return false;
+ }
+ out_ << std::endl;
+ }
+
+ return true;
+}
+
+bool GeneratorImpl::EmitFunctionInternal(ast::Function* func,
+ bool emit_duplicate_functions,
+ const std::string& ep_name) {
auto name = func->name();
if (!EmitType(func->return_type(), "")) {
@@ -516,6 +588,30 @@
out_ << " " << namer_.NameFor(name) << "(";
bool first = true;
+
+ // If we're emitting duplicate functions that means the function takes
+ // the stage_in or stage_out value from the entry point, emit them.
+ //
+ // We emit both of them if they're there regardless of if they're both used.
+ if (emit_duplicate_functions) {
+ auto in_it = ep_name_to_in_data_.find(ep_name);
+ if (in_it != ep_name_to_in_data_.end()) {
+ out_ << "in " << in_it->second.struct_name << " "
+ << in_it->second.var_name;
+ first = false;
+ }
+
+ auto out_it = ep_name_to_out_data_.find(ep_name);
+ if (out_it != ep_name_to_out_data_.end()) {
+ if (!first) {
+ out_ << ", ";
+ }
+ out_ << "out " << out_it->second.struct_name << " "
+ << out_it->second.var_name;
+ first = false;
+ }
+ }
+
for (const auto& v : func->params()) {
if (!first) {
out_ << ", ";
@@ -533,19 +629,188 @@
out_ << ") ";
+ current_ep_name_ = ep_name;
+
if (!EmitBlockAndNewline(func->body())) {
return false;
}
+ current_ep_name_ = "";
+
+ return true;
+}
+
+bool GeneratorImpl::EmitEntryPointData(ast::EntryPoint* ep) {
+ auto* func = module_->FindFunctionByName(ep->function_name());
+ if (func == nullptr) {
+ error_ = "Unable to find entry point function: " + ep->function_name();
+ return false;
+ }
+
+ std::vector<std::pair<ast::Variable*, ast::VariableDecoration*>> in_variables;
+ std::vector<std::pair<ast::Variable*, ast::VariableDecoration*>>
+ out_variables;
+ for (auto data : func->referenced_location_variables()) {
+ auto* var = data.first;
+ auto* deco = data.second;
+
+ if (var->storage_class() == ast::StorageClass::kInput) {
+ in_variables.push_back({var, deco});
+ } else if (var->storage_class() == ast::StorageClass::kOutput) {
+ out_variables.push_back({var, deco});
+ }
+ }
+
+ for (auto data : func->referenced_builtin_variables()) {
+ auto* var = data.first;
+ auto* deco = data.second;
+
+ if (var->storage_class() == ast::StorageClass::kInput) {
+ in_variables.push_back({var, deco});
+ } else if (var->storage_class() == ast::StorageClass::kOutput) {
+ out_variables.push_back({var, deco});
+ }
+ }
+
+ auto ep_name = ep->name();
+ if (ep_name.empty()) {
+ ep_name = ep->function_name();
+ }
+
+ // TODO(dsinclair): There is a potential bug here. Entry points can have the
+ // same name in WGSL if they have different pipeline stages. This does not
+ // take that into account and will emit duplicate struct names. I'm ignoring
+ // this until https://github.com/gpuweb/gpuweb/issues/662 is resolved as it
+ // may remove this issue and entry point names will need to be unique.
+ if (!in_variables.empty()) {
+ auto in_struct_name = generate_name(ep_name + "_" + kInStructNameSuffix);
+ auto in_var_name = generate_name(kTintStructInVarPrefix);
+ ep_name_to_in_data_[ep_name] = {in_struct_name, in_var_name};
+
+ make_indent();
+ out_ << "struct " << in_struct_name << " {" << std::endl;
+
+ increment_indent();
+
+ for (auto& data : in_variables) {
+ auto* var = data.first;
+ auto* deco = data.second;
+
+ make_indent();
+ if (!EmitType(var->type(), var->name())) {
+ return false;
+ }
+
+ out_ << " " << var->name() << " : ";
+ if (deco->IsLocation()) {
+ out_ << "TEXCOORD" << deco->AsLocation()->value();
+ } else if (deco->IsBuiltin()) {
+ auto attr = builtin_to_attribute(deco->AsBuiltin()->value());
+ if (attr.empty()) {
+ error_ = "unsupported builtin";
+ return false;
+ }
+ out_ << attr;
+ } else {
+ error_ = "unsupported variable decoration for entry point output";
+ return false;
+ }
+ out_ << ";" << std::endl;
+ }
+ decrement_indent();
+ make_indent();
+
+ out_ << "};" << std::endl << std::endl;
+ }
+
+ if (!out_variables.empty()) {
+ auto out_struct_name = generate_name(ep_name + "_" + kOutStructNameSuffix);
+ auto out_var_name = generate_name(kTintStructOutVarPrefix);
+ ep_name_to_out_data_[ep_name] = {out_struct_name, out_var_name};
+
+ make_indent();
+ out_ << "struct " << out_struct_name << " {" << std::endl;
+
+ increment_indent();
+ for (auto& data : out_variables) {
+ auto* var = data.first;
+ auto* deco = data.second;
+
+ make_indent();
+ if (!EmitType(var->type(), var->name())) {
+ return false;
+ }
+
+ out_ << " " << var->name() << " : ";
+
+ if (deco->IsLocation()) {
+ auto loc = deco->AsLocation()->value();
+ if (ep->stage() == ast::PipelineStage::kVertex) {
+ out_ << "TEXCOORD" << loc;
+ } else if (ep->stage() == ast::PipelineStage::kFragment) {
+ out_ << "SV_Target" << loc << "";
+ } else {
+ error_ = "invalid location variable for pipeline stage";
+ return false;
+ }
+ } else if (deco->IsBuiltin()) {
+ auto attr = builtin_to_attribute(deco->AsBuiltin()->value());
+ if (attr.empty()) {
+ error_ = "unsupported builtin";
+ return false;
+ }
+ out_ << attr;
+ } else {
+ error_ = "unsupported variable decoration for entry point output";
+ return false;
+ }
+ out_ << ";" << std::endl;
+ }
+ decrement_indent();
+ make_indent();
+ out_ << "};" << std::endl << std::endl;
+ }
+
return true;
}
+std::string GeneratorImpl::builtin_to_attribute(ast::Builtin builtin) const {
+ switch (builtin) {
+ case ast::Builtin::kPosition:
+ return "SV_Position";
+ case ast::Builtin::kVertexIdx:
+ return "SV_VertexID";
+ case ast::Builtin::kInstanceIdx:
+ return "SV_InstanceID";
+ case ast::Builtin::kFrontFacing:
+ return "SV_IsFrontFacing";
+ case ast::Builtin::kFragCoord:
+ return "SV_Position";
+ case ast::Builtin::kFragDepth:
+ return "SV_Depth";
+ // TODO(dsinclair): Ignore for now. This has been removed as a builtin
+ // in the spec. Need to update Tint to match.
+ // https://github.com/gpuweb/gpuweb/pull/824
+ case ast::Builtin::kWorkgroupSize:
+ return "";
+ case ast::Builtin::kLocalInvocationId:
+ return "SV_GroupThreadID";
+ case ast::Builtin::kLocalInvocationIdx:
+ return "SV_GroupIndex";
+ case ast::Builtin::kGlobalInvocationId:
+ return "SV_DispatchThreadID";
+ default:
+ break;
+ }
+ return "";
+}
+
bool GeneratorImpl::EmitEntryPointFunction(ast::EntryPoint* ep) {
make_indent();
- auto current_ep_name = ep->name();
- if (current_ep_name.empty()) {
- current_ep_name = ep->function_name();
+ current_ep_name_ = ep->name();
+ if (current_ep_name_.empty()) {
+ current_ep_name_ = ep->function_name();
}
auto* func = module_->FindFunctionByName(ep->function_name());
@@ -554,19 +819,43 @@
return false;
}
- out_ << "void " << namer_.NameFor(current_ep_name) << "() {" << std::endl;
+ auto out_data = ep_name_to_out_data_.find(current_ep_name_);
+ bool has_out_data = out_data != ep_name_to_out_data_.end();
+ if (has_out_data) {
+ out_ << out_data->second.struct_name;
+ } else {
+ out_ << "void";
+ }
+ out_ << " " << namer_.NameFor(current_ep_name_) << "(";
+
+ auto in_data = ep_name_to_in_data_.find(current_ep_name_);
+ if (in_data != ep_name_to_in_data_.end()) {
+ out_ << in_data->second.struct_name << " " << in_data->second.var_name;
+ }
+ out_ << ") {" << std::endl;
+
increment_indent();
+ if (has_out_data) {
+ make_indent();
+ out_ << out_data->second.struct_name << " " << out_data->second.var_name
+ << ";" << std::endl;
+ }
+
+ generating_entry_point_ = true;
for (const auto& s : *(func->body())) {
if (!EmitStatement(s.get())) {
return false;
}
}
+ generating_entry_point_ = false;
decrement_indent();
make_indent();
out_ << "}" << std::endl;
+ current_ep_name_ = "";
+
return true;
}
diff --git a/src/writer/hlsl/generator_impl.h b/src/writer/hlsl/generator_impl.h
index c24f977..b65b94d 100644
--- a/src/writer/hlsl/generator_impl.h
+++ b/src/writer/hlsl/generator_impl.h
@@ -110,6 +110,19 @@
/// @param func the function to generate
/// @returns true if the function was emitted
bool EmitFunction(ast::Function* func);
+ /// Internal helper for emitting functions
+ /// @param func the function to emit
+ /// @param emit_duplicate_functions set true if we need to duplicate per entry
+ /// point
+ /// @param ep_name the current entry point or blank if none set
+ /// @returns true if the function was emitted.
+ bool EmitFunctionInternal(ast::Function* func,
+ bool emit_duplicate_functions,
+ const std::string& ep_name);
+ /// Handles emitting information for an entry point
+ /// @param ep the entry point
+ /// @returns true if the entry point data was emitted
+ bool EmitEntryPointData(ast::EntryPoint* ep);
/// Handles emitting the entry point function
/// @param ep the entry point
/// @returns true if the entry point function was emitted
@@ -168,6 +181,21 @@
/// @param var the variable to check
/// @returns true if the global is in an input or output struct
bool global_is_in_struct(ast::Variable* var) const;
+ /// Generates a name for the prefix
+ /// @param prefix the prefix of the name to generate
+ /// @returns the name
+ std::string generate_name(const std::string& prefix);
+ /// Converts a builtin to an attribute name
+ /// @param builtin the builtin to convert
+ /// @returns the string name of the builtin or blank on error
+ std::string builtin_to_attribute(ast::Builtin builtin) const;
+ /// Determines if any used module variable requires an input or output struct.
+ /// @param func the function to check
+ /// @returns true if an input or output struct is required.
+ bool has_referenced_var_needing_struct(ast::Function* func);
+
+ /// @returns the namer for testing
+ Namer* namer_for_testing() { return &namer_; }
private:
enum class VarType { kIn, kOut };
diff --git a/src/writer/hlsl/generator_impl_function_test.cc b/src/writer/hlsl/generator_impl_function_test.cc
index 2216df5..0983554 100644
--- a/src/writer/hlsl/generator_impl_function_test.cc
+++ b/src/writer/hlsl/generator_impl_function_test.cc
@@ -67,6 +67,7 @@
EXPECT_EQ(g.result(), R"( void my_func() {
return;
}
+
)");
}
@@ -90,6 +91,7 @@
EXPECT_EQ(g.result(), R"( void GeometryShader_tint_0() {
return;
}
+
)");
}
@@ -121,6 +123,7 @@
EXPECT_EQ(g.result(), R"( void my_func(float a, int b) {
return;
}
+
)");
}
@@ -144,7 +147,7 @@
)");
}
-TEST_F(HlslGeneratorImplTest, DISABLED_Emit_Function_EntryPoint_WithInOutVars) {
+TEST_F(HlslGeneratorImplTest, Emit_Function_EntryPoint_WithInOutVars) {
ast::type::VoidType void_type;
ast::type::F32Type f32;
@@ -207,8 +210,7 @@
)");
}
-TEST_F(HlslGeneratorImplTest,
- DISABLED_Emit_Function_EntryPoint_WithInOut_Builtins) {
+TEST_F(HlslGeneratorImplTest, Emit_Function_EntryPoint_WithInOut_Builtins) {
ast::type::VoidType void_type;
ast::type::F32Type f32;
ast::type::VectorType vec4(&f32, 4);
@@ -262,7 +264,7 @@
GeneratorImpl g(&mod);
ASSERT_TRUE(g.Generate()) << g.error();
EXPECT_EQ(g.result(), R"(struct frag_main_in {
- float gl_FragCoord : SV_Position;
+ vector<float, 4> coord : SV_Position;
};
struct frag_main_out {
@@ -271,7 +273,7 @@
frag_main_out frag_main(frag_main_in tint_in) {
frag_main_out tint_out;
- tint_out.depth = tint_in.gl_FragCoord.x;
+ tint_out.depth = tint_in.coord.x;
return tint_out;
}
@@ -377,6 +379,7 @@
EXPECT_EQ(g.result(), R"( ... )");
}
+// TODO(dsinclair): Requires CallExpression
TEST_F(
HlslGeneratorImplTest,
DISABLED_Emit_Function_Called_By_EntryPoints_WithLocationGlobals_And_Params) {
@@ -480,6 +483,7 @@
)");
}
+// TODO(dsinclair): Requires CallExpression
TEST_F(HlslGeneratorImplTest,
DISABLED_Emit_Function_Called_By_EntryPoints_NoUsedGlobals) {
ast::type::VoidType void_type;
@@ -558,6 +562,7 @@
)");
}
+// TODO(dsinclair): Requires CallExpression
TEST_F(
HlslGeneratorImplTest,
DISABLED_Emit_Function_Called_By_EntryPoints_WithBuiltinGlobals_And_Params) {
@@ -1098,6 +1103,7 @@
EXPECT_EQ(g.result(), R"( void my_func(float a[5]) {
return;
}
+
)");
}
diff --git a/src/writer/hlsl/generator_impl_test.cc b/src/writer/hlsl/generator_impl_test.cc
index df128ad..b1b16db 100644
--- a/src/writer/hlsl/generator_impl_test.cc
+++ b/src/writer/hlsl/generator_impl_test.cc
@@ -19,6 +19,7 @@
#include "gtest/gtest.h"
#include "src/ast/entry_point.h"
#include "src/ast/function.h"
+#include "src/ast/identifier_expression.h"
#include "src/ast/module.h"
#include "src/ast/type/void_type.h"
@@ -46,6 +47,66 @@
)");
}
+TEST_F(HlslGeneratorImplTest, InputStructName) {
+ ast::Module m;
+ GeneratorImpl g(&m);
+ ASSERT_EQ(g.generate_name("func_main_in"), "func_main_in");
+}
+
+TEST_F(HlslGeneratorImplTest, InputStructName_ConflictWithExisting) {
+ ast::Module m;
+ GeneratorImpl g(&m);
+
+ // Register the struct name as existing.
+ auto* namer = g.namer_for_testing();
+ namer->NameFor("func_main_out");
+
+ ASSERT_EQ(g.generate_name("func_main_out"), "func_main_out_0");
+}
+
+TEST_F(HlslGeneratorImplTest, NameConflictWith_InputStructName) {
+ ast::Module m;
+ GeneratorImpl g(&m);
+ ASSERT_EQ(g.generate_name("func_main_in"), "func_main_in");
+
+ ast::IdentifierExpression ident("func_main_in");
+ ASSERT_TRUE(g.EmitIdentifier(&ident));
+ EXPECT_EQ(g.result(), "func_main_in_0");
+}
+
+struct HlslBuiltinData {
+ ast::Builtin builtin;
+ const char* attribute_name;
+};
+inline std::ostream& operator<<(std::ostream& out, HlslBuiltinData data) {
+ out << data.builtin;
+ return out;
+}
+using HlslBuiltinConversionTest = testing::TestWithParam<HlslBuiltinData>;
+TEST_P(HlslBuiltinConversionTest, Emit) {
+ auto params = GetParam();
+
+ ast::Module m;
+ GeneratorImpl g(&m);
+ EXPECT_EQ(g.builtin_to_attribute(params.builtin),
+ std::string(params.attribute_name));
+}
+INSTANTIATE_TEST_SUITE_P(
+ HlslGeneratorImplTest,
+ HlslBuiltinConversionTest,
+ testing::Values(
+ HlslBuiltinData{ast::Builtin::kPosition, "SV_Position"},
+ HlslBuiltinData{ast::Builtin::kVertexIdx, "SV_VertexID"},
+ HlslBuiltinData{ast::Builtin::kInstanceIdx, "SV_InstanceID"},
+ HlslBuiltinData{ast::Builtin::kFrontFacing, "SV_IsFrontFacing"},
+ HlslBuiltinData{ast::Builtin::kFragCoord, "SV_Position"},
+ HlslBuiltinData{ast::Builtin::kFragDepth, "SV_Depth"},
+ HlslBuiltinData{ast::Builtin::kWorkgroupSize, ""},
+ HlslBuiltinData{ast::Builtin::kLocalInvocationId, "SV_GroupThreadID"},
+ HlslBuiltinData{ast::Builtin::kLocalInvocationIdx, "SV_GroupIndex"},
+ HlslBuiltinData{ast::Builtin::kGlobalInvocationId,
+ "SV_DispatchThreadID"}));
+
} // namespace
} // namespace hlsl
} // namespace writer