[hlsl-writer] Use entry point interface canonicalization transform
This replaces the entry point IO component of the HLSL sanitizing
transform, and completes support for the new entry point IO syntax.
Struct emission in the HLSL writer is updated to use the correct
attributes depending on the pipeline stage usage.
Fixed: tint:511
Change-Id: I6a30ed2182ee19b2f25262a30a83685ffcb5ef25
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/46521
Auto-Submit: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: James Price <jrprice@google.com>
diff --git a/samples/main.cc b/samples/main.cc
index 34ba6ee..e0d2c6b 100644
--- a/samples/main.cc
+++ b/samples/main.cc
@@ -714,6 +714,8 @@
#endif // TINT_BUILD_MSL_WRITER
#if TINT_BUILD_HLSL_WRITER
case Format::kHlsl:
+ transform_manager.append(
+ std::make_unique<tint::transform::CanonicalizeEntryPointIO>());
transform_manager.append(std::make_unique<tint::transform::Hlsl>());
break;
#endif // TINT_BUILD_HLSL_WRITER
diff --git a/src/transform/hlsl.cc b/src/transform/hlsl.cc
index 87e8c6e..21b9bf7 100644
--- a/src/transform/hlsl.cc
+++ b/src/transform/hlsl.cc
@@ -15,7 +15,6 @@
#include "src/transform/hlsl.h"
#include <utility>
-#include <vector>
#include "src/ast/variable_decl_statement.h"
#include "src/program_builder.h"
@@ -33,7 +32,6 @@
ProgramBuilder out;
CloneContext ctx(&out, in);
PromoteArrayInitializerToConstVar(ctx);
- HandleEntryPointIOTypes(ctx);
ctx.Clone();
return Output{Program(std::move(out))};
}
@@ -106,134 +104,5 @@
}
}
-void Hlsl::HandleEntryPointIOTypes(CloneContext& ctx) const {
- // Collect entry point parameters into a struct.
- // Insert function-scope const declarations to replace those parameters.
- //
- // Before:
- // ```
- // [[stage(fragment)]]
- // fn frag_main([[builtin(frag_coord)]] coord : vec4<f32>,
- // [[location(1)]] loc1 : f32,
- // [[location(2)]] loc2 : vec4<u32>) -> void {
- // var col : f32 = (coord.x * loc1);
- // }
- // ```
- //
- // After:
- // ```
- // struct frag_main_in {
- // [[builtin(frag_coord)]] coord : vec4<f32>;
- // [[location(1)]] loc1 : f32;
- // [[location(2)]] loc2 : vec4<u32>
- // };
-
- // [[stage(fragment)]]
- // fn frag_main(in : frag_main_in) -> void {
- // const coord : vec4<f32> = in.coord;
- // const loc1 : f32 = in.loc1;
- // const loc2 : vec4<u32> = in.loc2;
- // var col : f32 = (coord.x * loc1);
- // }
- // ```
-
- for (auto* func : ctx.src->AST().Functions()) {
- if (!func->IsEntryPoint()) {
- continue;
- }
-
- // Build a new structure to hold the non-struct input parameters.
- ast::StructMemberList struct_members;
- for (auto* param : func->params()) {
- auto* type = ctx.src->Sem().Get(param)->Type();
- if (type->Is<type::Struct>()) {
- // Already a struct, nothing to do.
- continue;
- }
-
- if (param->decorations().size() != 1) {
- TINT_ICE(ctx.dst->Diagnostics()) << "Unsupported entry point parameter";
- }
-
- auto name = ctx.src->Symbols().NameFor(param->symbol());
-
- auto* deco = param->decorations()[0];
- if (auto* builtin = deco->As<ast::BuiltinDecoration>()) {
- // Create a struct member with the builtin decoration.
- struct_members.push_back(ctx.dst->Member(
- name, ctx.Clone(type), ast::DecorationList{ctx.Clone(builtin)}));
- } else if (auto* loc = deco->As<ast::LocationDecoration>()) {
- // Create a struct member with the location decoration.
- struct_members.push_back(ctx.dst->Member(
- name, ctx.Clone(type), ast::DecorationList{ctx.Clone(loc)}));
- } else {
- TINT_ICE(ctx.dst->Diagnostics())
- << "Unsupported entry point parameter decoration";
- }
- }
-
- if (struct_members.empty()) {
- // Nothing to do.
- continue;
- }
-
- ast::VariableList new_parameters;
- ast::StatementList new_body;
-
- // Create a struct type to hold all of the non-struct input parameters.
- auto* in_struct = ctx.dst->create<type::Struct>(
- ctx.dst->Symbols().New(),
- ctx.dst->create<ast::Struct>(struct_members, ast::DecorationList{}));
- ctx.InsertBefore(func, in_struct);
-
- // Create a new function parameter using this struct type.
- auto struct_param_symbol = ctx.dst->Symbols().New();
- auto* struct_param =
- ctx.dst->Var(struct_param_symbol, in_struct, ast::StorageClass::kNone);
- new_parameters.push_back(struct_param);
-
- // Replace the original parameters with function-scope constants.
- for (auto* param : func->params()) {
- auto* type = ctx.src->Sem().Get(param)->Type();
- if (type->Is<type::Struct>()) {
- // Keep struct parameters unchanged.
- new_parameters.push_back(ctx.Clone(param));
- continue;
- }
-
- auto name = ctx.src->Symbols().NameFor(param->symbol());
-
- // Create a function-scope const to replace the parameter.
- // Initialize it with the value extracted from the struct parameter.
- auto func_const_symbol = ctx.dst->Symbols().Register(name);
- auto* func_const =
- ctx.dst->Const(func_const_symbol, ctx.Clone(type),
- ctx.dst->MemberAccessor(struct_param_symbol, name));
-
- new_body.push_back(ctx.dst->WrapInStatement(func_const));
-
- // Replace all uses of the function parameter with the function const.
- for (auto* user : ctx.src->Sem().Get(param)->Users()) {
- ctx.Replace<ast::Expression>(user->Declaration(),
- ctx.dst->Expr(func_const_symbol));
- }
- }
-
- // Copy over the rest of the function body unchanged.
- for (auto* stmt : func->body()->list()) {
- new_body.push_back(ctx.Clone(stmt));
- }
-
- // Rewrite the function header with the new parameters.
- auto* new_func = ctx.dst->create<ast::Function>(
- func->source(), ctx.Clone(func->symbol()), new_parameters,
- ctx.Clone(func->return_type()),
- ctx.dst->create<ast::BlockStatement>(new_body),
- ctx.Clone(func->decorations()),
- ctx.Clone(func->return_type_decorations()));
- ctx.Replace(func, new_func);
- }
-}
-
} // namespace transform
} // namespace tint
diff --git a/src/transform/hlsl.h b/src/transform/hlsl.h
index bf6a124..53ad35f 100644
--- a/src/transform/hlsl.h
+++ b/src/transform/hlsl.h
@@ -44,9 +44,6 @@
/// the array usage statement.
/// See crbug.com/tint/406 for more details
void PromoteArrayInitializerToConstVar(CloneContext& ctx) const;
-
- /// Hoist entry point parameters out to struct members.
- void HandleEntryPointIOTypes(CloneContext& ctx) const;
};
} // namespace transform
diff --git a/src/transform/hlsl_test.cc b/src/transform/hlsl_test.cc
index fb59a5d..bc6b264 100644
--- a/src/transform/hlsl_test.cc
+++ b/src/transform/hlsl_test.cc
@@ -143,133 +143,6 @@
EXPECT_EQ(expect, str(got));
}
-TEST_F(HlslTest, HandleEntryPointIOTypes_Parameters) {
- auto* src = R"(
-struct FragIn {
- [[location(2)]]
- loc2 : f32;
-};
-
-[[stage(fragment)]]
-fn frag_main([[builtin(frag_coord)]] coord : vec4<f32>,
- [[location(1)]] loc1 : f32,
- frag_in : FragIn) -> void {
- var col : f32 = (coord.x * loc1 + frag_in.loc2);
-}
-)";
-
- auto* expect = R"(
-struct FragIn {
- [[location(2)]]
- loc2 : f32;
-};
-
-struct tint_symbol_3 {
- [[builtin(frag_coord)]]
- coord : vec4<f32>;
- [[location(1)]]
- loc1 : f32;
-};
-
-[[stage(fragment)]]
-fn frag_main(tint_symbol_4 : tint_symbol_3, frag_in : FragIn) -> void {
- const coord : vec4<f32> = tint_symbol_4.coord;
- const loc1 : f32 = tint_symbol_4.loc1;
- var col : f32 = ((coord.x * loc1) + frag_in.loc2);
-}
-)";
-
- auto got = Run<Hlsl>(src);
-
- EXPECT_EQ(expect, str(got));
-}
-
-TEST_F(HlslTest, HandleEntryPointIOTypes_Parameter_TypeAlias) {
- auto* src = R"(
-type myf32 = f32;
-
-[[stage(fragment)]]
-fn frag_main([[location(1)]] loc1 : myf32) -> void {
-}
-)";
-
- auto* expect = R"(
-type myf32 = f32;
-
-struct tint_symbol_3 {
- [[location(1)]]
- loc1 : myf32;
-};
-
-[[stage(fragment)]]
-fn frag_main(tint_symbol_4 : tint_symbol_3) -> void {
- const loc1 : myf32 = tint_symbol_4.loc1;
-}
-)";
-
- auto got = Run<Hlsl>(src);
-
- EXPECT_EQ(expect, str(got));
-}
-
-TEST_F(HlslTest, HandleEntryPointIOTypes_OnlyStructParameters) {
- // Expect no change.
- auto* src = R"(
-struct FragBuiltins {
- [[builtin(frag_coord)]]
- coord : vec4<f32>;
-};
-
-struct FragInputs {
- [[location(1)]]
- loc1 : f32;
- [[location(2)]]
- loc2 : vec4<u32>;
-};
-
-[[stage(fragment)]]
-fn frag_main(builtins : FragBuiltins, inputs : FragInputs) -> void {
- var col : f32 = (builtins.coord.x * inputs.loc1);
-}
-)";
-
- auto got = Run<Hlsl>(src);
-
- EXPECT_EQ(src, str(got));
-}
-
-TEST_F(HlslTest, HandleEntryPointIOTypes_Parameters_EmptyBody) {
- auto* src = R"(
-[[stage(fragment)]]
-fn frag_main([[builtin(frag_coord)]] coord : vec4<f32>,
- [[location(1)]] loc1 : f32,
- [[location(2)]] loc2 : vec4<u32>) -> void {
-}
-)";
-
- auto* expect = R"(
-struct tint_symbol_4 {
- [[builtin(frag_coord)]]
- coord : vec4<f32>;
- [[location(1)]]
- loc1 : f32;
- [[location(2)]]
- loc2 : vec4<u32>;
-};
-
-[[stage(fragment)]]
-fn frag_main(tint_symbol_5 : tint_symbol_4) -> void {
- const coord : vec4<f32> = tint_symbol_5.coord;
- const loc1 : f32 = tint_symbol_5.loc1;
- const loc2 : vec4<u32> = tint_symbol_5.loc2;
-}
-)";
-
- auto got = Run<Hlsl>(src);
-
- EXPECT_EQ(expect, str(got));
-}
-
} // namespace
} // namespace transform
} // namespace tint
diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc
index c019ca7..705a3d4 100644
--- a/src/writer/hlsl/generator_impl.cc
+++ b/src/writer/hlsl/generator_impl.cc
@@ -1571,8 +1571,7 @@
auto* func_sem = builder_.Sem().Get(func);
auto func_sym = func->symbol();
- // TODO(jrprice): Remove this when we remove support for entry point
- // inputs/outputs as module-scope globals.
+ // TODO(crbug.com/tint/697): Remove this.
for (auto data : func_sem->ReferencedLocationVariables()) {
auto* var = data.first;
auto* decl = var->Declaration();
@@ -1585,8 +1584,7 @@
}
}
- // TODO(jrprice): Remove this when we remove support for entry point
- // inputs/outputs as module-scope globals.
+ // TODO(crbug.com/tint/697): Remove this.
for (auto data : func_sem->ReferencedBuiltinVariables()) {
auto* var = data.first;
auto* decl = var->Declaration();
@@ -1670,8 +1668,7 @@
out << std::endl;
}
- // TODO(jrprice): Remove this when we remove support for entry point inputs as
- // module-scope globals.
+ // TODO(crbug.com/tint/697): Remove this.
if (!in_variables.empty()) {
auto in_struct_name = generate_name(builder_.Symbols().NameFor(func_sym) +
"_" + kInStructNameSuffix);
@@ -1721,8 +1718,7 @@
out << "};" << std::endl << std::endl;
}
- // TODO(jrprice): Remove this when we remove support for entry point outputs
- // as module-scope globals.
+ // TODO(crbug.com/tint/697): Remove this.
if (!outvariables.empty()) {
auto outstruct_name = generate_name(builder_.Symbols().NameFor(func_sym) +
"_" + kOutStructNameSuffix);
@@ -1882,16 +1878,21 @@
auto outdata = ep_sym_to_out_data_.find(current_ep_sym_);
bool has_outdata = outdata != ep_sym_to_out_data_.end();
if (has_outdata) {
+ // TODO(crbug.com/tint/697): Remove this.
+ if (!func->return_type()->Is<type::Void>()) {
+ TINT_ICE(diagnostics_) << "Mixing module-scope variables and return "
+ "types for shader outputs";
+ }
out << outdata->second.struct_name;
} else {
- out << "void";
+ out << func->return_type()->FriendlyName(builder_.Symbols());
}
// TODO(dsinclair): This should output the remapped name
out << " " << namer_.NameFor(builder_.Symbols().NameFor(current_ep_sym_))
<< "(";
bool first = true;
- // TODO(jrprice): Remove this when we remove support for inputs as globals.
+ // TODO(crbug.com/tint/697): Remove this.
auto in_data = ep_sym_to_in_data_.find(current_ep_sym_);
if (in_data != ep_sym_to_in_data_.end()) {
out << in_data->second.struct_name << " " << in_data->second.var_name;
@@ -2367,13 +2368,7 @@
bool GeneratorImpl::EmitReturn(std::ostream& out, ast::ReturnStatement* stmt) {
make_indent(out);
- if (generating_entry_point_) {
- out << "return";
- auto outdata = ep_sym_to_out_data_.find(current_ep_sym_);
- if (outdata != ep_sym_to_out_data_.end()) {
- out << " " << outdata->second.var_name;
- }
- } else if (stmt->has_value()) {
+ if (stmt->has_value()) {
std::ostringstream pre;
std::ostringstream ret_out;
if (!EmitExpression(pre, ret_out, stmt->value())) {
@@ -2381,6 +2376,13 @@
}
out << pre.str();
out << "return " << ret_out.str();
+ } else if (generating_entry_point_) {
+ // TODO(crbug.com/tint/697): Remove this (and generating_entry_point_)
+ out << "return";
+ auto outdata = ep_sym_to_out_data_.find(current_ep_sym_);
+ if (outdata != ep_sym_to_out_data_.end()) {
+ out << " " << outdata->second.var_name;
+ }
} else {
out << "return";
}
@@ -2650,7 +2652,27 @@
for (auto* deco : mem->decorations()) {
if (auto* location = deco->As<ast::LocationDecoration>()) {
- out << " : TEXCOORD" << location->value();
+ auto& pipeline_stage_uses =
+ builder_.Sem().Get(str)->PipelineStageUses();
+ if (pipeline_stage_uses.size() != 1) {
+ TINT_ICE(diagnostics_) << "invalid entry point IO struct uses";
+ }
+
+ if (pipeline_stage_uses.count(
+ semantic::PipelineStageUsage::kVertexInput)) {
+ out << " : TEXCOORD" + std::to_string(location->value());
+ } else if (pipeline_stage_uses.count(
+ semantic::PipelineStageUsage::kVertexOutput)) {
+ out << " : TEXCOORD" + std::to_string(location->value());
+ } else if (pipeline_stage_uses.count(
+ semantic::PipelineStageUsage::kFragmentInput)) {
+ out << " : TEXCOORD" + std::to_string(location->value());
+ } else if (pipeline_stage_uses.count(
+ semantic::PipelineStageUsage::kFragmentOutput)) {
+ out << " : SV_Target" + std::to_string(location->value());
+ } else {
+ TINT_ICE(diagnostics_) << "invalid use of location decoration";
+ }
} else if (auto* builtin = deco->As<ast::BuiltinDecoration>()) {
auto attr = builtin_to_attribute(builtin->value());
if (attr.empty()) {
diff --git a/src/writer/hlsl/generator_impl_function_test.cc b/src/writer/hlsl/generator_impl_function_test.cc
index 35dadb9..c2f364f 100644
--- a/src/writer/hlsl/generator_impl_function_test.cc
+++ b/src/writer/hlsl/generator_impl_function_test.cc
@@ -105,136 +105,205 @@
}
)");
-}
-TEST_F(HlslGeneratorImplTest_Function,
- Emit_Decoration_EntryPoint_NoReturn_InOut) {
- auto* foo_in = Var("foo", ty.f32(), ast::StorageClass::kNone, nullptr,
- ast::DecorationList{
- create<ast::LocationDecoration>(0),
- });
-
- // TODO(jrprice): Make this the return value when supported.
- Global("bar", ty.f32(), ast::StorageClass::kOutput, nullptr,
- ast::DecorationList{
- create<ast::LocationDecoration>(1),
- });
-
- Func("main", ast::VariableList{foo_in}, ty.void_(),
- ast::StatementList{
- create<ast::AssignmentStatement>(Expr("bar"), Expr("foo")),
- /* no explicit return */},
- ast::DecorationList{
- create<ast::StageDecoration>(ast::PipelineStage::kFragment),
- });
-
- GeneratorImpl& gen = SanitizeAndBuild();
-
- ASSERT_TRUE(gen.Generate(out)) << gen.error();
- EXPECT_EQ(result(), R"(struct tint_symbol_2 {
- float foo : TEXCOORD0;
-};
-
-struct main_out {
- float bar : SV_Target1;
-};
-
-main_out main(tint_symbol_2 tint_symbol_3) {
- main_out tint_out = (main_out)0;
- const float foo = tint_symbol_3.foo;
- tint_out.bar = foo;
- return tint_out;
-}
-
-)");
+ Validate();
}
TEST_F(HlslGeneratorImplTest_Function,
Emit_Decoration_EntryPoint_WithInOutVars) {
- auto* foo_in = Var("foo", ty.f32(), ast::StorageClass::kNone, nullptr,
- ast::DecorationList{
- create<ast::LocationDecoration>(0),
- });
-
- // TODO(jrprice): Make this the return value when supported.
- Global("bar", ty.f32(), ast::StorageClass::kOutput, nullptr,
- ast::DecorationList{
- create<ast::LocationDecoration>(1),
- });
-
- Func("frag_main", ast::VariableList{foo_in}, ty.void_(),
- ast::StatementList{
- create<ast::AssignmentStatement>(Expr("bar"), Expr("foo")),
- create<ast::ReturnStatement>(),
- },
- ast::DecorationList{
- create<ast::StageDecoration>(ast::PipelineStage::kFragment),
- });
+ // fn frag_main([[location(0)]] foo : f32) -> [[location(1)]] f32 {
+ // return foo;
+ // }
+ auto* foo_in =
+ Const("foo", ty.f32(), nullptr, {create<ast::LocationDecoration>(0)});
+ Func("frag_main", ast::VariableList{foo_in}, ty.f32(),
+ {create<ast::ReturnStatement>(Expr("foo"))},
+ {create<ast::StageDecoration>(ast::PipelineStage::kFragment)},
+ {create<ast::LocationDecoration>(1)});
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate(out)) << gen.error();
- EXPECT_EQ(result(), R"(struct tint_symbol_2 {
+ EXPECT_EQ(result(), R"(struct tint_symbol_3 {
float foo : TEXCOORD0;
};
-
-struct frag_main_out {
- float bar : SV_Target1;
+struct tint_symbol_5 {
+ float value : SV_Target1;
};
-frag_main_out frag_main(tint_symbol_2 tint_symbol_3) {
- frag_main_out tint_out = (frag_main_out)0;
- const float foo = tint_symbol_3.foo;
- tint_out.bar = foo;
- return tint_out;
+tint_symbol_5 frag_main(tint_symbol_3 tint_symbol_1) {
+ const float foo = tint_symbol_1.foo;
+ return tint_symbol_5(foo);
}
)");
+
+ Validate();
}
TEST_F(HlslGeneratorImplTest_Function,
Emit_Decoration_EntryPoint_WithInOut_Builtins) {
+ // fn frag_main([[position(0)]] coord : vec4<f32>) -> [[frag_depth]] f32 {
+ // return coord.x;
+ // }
auto* coord_in =
- Var("coord", ty.vec4<f32>(), ast::StorageClass::kNone, nullptr,
- ast::DecorationList{
- create<ast::BuiltinDecoration>(ast::Builtin::kFragCoord),
- });
-
- // TODO(jrprice): Make this the return value when supported.
- Global("depth", ty.f32(), ast::StorageClass::kOutput, nullptr,
- ast::DecorationList{
- create<ast::BuiltinDecoration>(ast::Builtin::kFragDepth),
- });
-
- Func("frag_main", ast::VariableList{coord_in}, ty.void_(),
- ast::StatementList{
- create<ast::AssignmentStatement>(Expr("depth"),
- MemberAccessor("coord", "x")),
- create<ast::ReturnStatement>(),
- },
- ast::DecorationList{
- create<ast::StageDecoration>(ast::PipelineStage::kFragment),
- });
+ Const("coord", ty.vec4<f32>(), nullptr,
+ {create<ast::BuiltinDecoration>(ast::Builtin::kFragCoord)});
+ Func("frag_main", ast::VariableList{coord_in}, ty.f32(),
+ {create<ast::ReturnStatement>(MemberAccessor("coord", "x"))},
+ {create<ast::StageDecoration>(ast::PipelineStage::kFragment)},
+ {create<ast::BuiltinDecoration>(ast::Builtin::kFragDepth)});
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate(out)) << gen.error();
- EXPECT_EQ(result(), R"(struct tint_symbol_2 {
+ EXPECT_EQ(result(), R"(struct tint_symbol_3 {
float4 coord : SV_Position;
};
-
-struct frag_main_out {
- float depth : SV_Depth;
+struct tint_symbol_5 {
+ float value : SV_Depth;
};
-frag_main_out frag_main(tint_symbol_2 tint_symbol_3) {
- frag_main_out tint_out = (frag_main_out)0;
- const float4 coord = tint_symbol_3.coord;
- tint_out.depth = coord.x;
- return tint_out;
+tint_symbol_5 frag_main(tint_symbol_3 tint_symbol_1) {
+ const float4 coord = tint_symbol_1.coord;
+ return tint_symbol_5(coord.x);
}
)");
+
+ Validate();
+}
+
+TEST_F(HlslGeneratorImplTest_Function,
+ Emit_Decoration_EntryPoint_SharedStruct_DifferentStages) {
+ // struct Interface {
+ // [[location(1)]] col1 : f32;
+ // [[location(2)]] col2 : f32;
+ // };
+ // fn vert_main() -> Interface {
+ // return Interface(0.4, 0.6);
+ // }
+ // fn frag_main(colors : Interface) -> void {
+ // const r = colors.col1;
+ // const g = colors.col2;
+ // }
+ auto* interface_struct = Structure(
+ "Interface",
+ {Member("col1", ty.f32(), {create<ast::LocationDecoration>(1)}),
+ Member("col2", ty.f32(), {create<ast::LocationDecoration>(2)})});
+
+ Func("vert_main", {}, interface_struct,
+ {create<ast::ReturnStatement>(
+ Construct(interface_struct, Expr(0.5f), Expr(0.25f)))},
+ {create<ast::StageDecoration>(ast::PipelineStage::kVertex)});
+
+ Func("frag_main", {Const("colors", interface_struct)}, ty.void_(),
+ {
+ WrapInStatement(
+ Const("r", ty.f32(), MemberAccessor(Expr("colors"), "col1"))),
+ WrapInStatement(
+ Const("g", ty.f32(), MemberAccessor(Expr("colors"), "col2"))),
+ },
+ {create<ast::StageDecoration>(ast::PipelineStage::kFragment)});
+
+ GeneratorImpl& gen = SanitizeAndBuild();
+
+ ASSERT_TRUE(gen.Generate(out)) << gen.error();
+ EXPECT_EQ(result(), R"(struct Interface {
+ float col1;
+ float col2;
+};
+struct tint_symbol_4 {
+ float col1 : TEXCOORD1;
+ float col2 : TEXCOORD2;
+};
+struct tint_symbol_9 {
+ float col1 : TEXCOORD1;
+ float col2 : TEXCOORD2;
+};
+
+tint_symbol_4 vert_main() {
+ const Interface tint_symbol_5 = Interface(0.5f, 0.25f);
+ return tint_symbol_4(tint_symbol_5.col1, tint_symbol_5.col2);
+}
+
+void frag_main(tint_symbol_9 tint_symbol_7) {
+ const Interface colors = Interface(tint_symbol_7.col1, tint_symbol_7.col2);
+ const float r = colors.col1;
+ const float g = colors.col2;
+ return;
+}
+
+)");
+
+ Validate();
+}
+
+TEST_F(HlslGeneratorImplTest_Function,
+ Emit_Decoration_EntryPoint_SharedStruct_HelperFunction) {
+ // struct VertexOutput {
+ // [[builtin(position)]] pos : vec4<f32>;
+ // };
+ // fn foo(x : f32) -> VertexOutput {
+ // return VertexOutput(vec4<f32>(x, x, x, 1.0));
+ // }
+ // fn vert_main1() -> VertexOutput {
+ // return foo(0.5);
+ // }
+ // fn vert_main2() -> VertexOutput {
+ // return foo(0.25);
+ // }
+ auto* vertex_output_struct = Structure(
+ "VertexOutput",
+ {Member("pos", ty.vec4<f32>(),
+ {create<ast::BuiltinDecoration>(ast::Builtin::kPosition)})});
+
+ Func("foo", {Const("x", ty.f32())}, vertex_output_struct,
+ {create<ast::ReturnStatement>(Construct(
+ vertex_output_struct, Construct(ty.vec4<f32>(), Expr("x"), Expr("x"),
+ Expr("x"), Expr(1.f))))},
+ {});
+
+ Func("vert_main1", {}, vertex_output_struct,
+ {create<ast::ReturnStatement>(
+ Construct(vertex_output_struct, Expr(Call("foo", Expr(0.5f)))))},
+ {create<ast::StageDecoration>(ast::PipelineStage::kVertex)});
+
+ Func("vert_main2", {}, vertex_output_struct,
+ {create<ast::ReturnStatement>(
+ Construct(vertex_output_struct, Expr(Call("foo", Expr(0.25f)))))},
+ {create<ast::StageDecoration>(ast::PipelineStage::kVertex)});
+
+ GeneratorImpl& gen = SanitizeAndBuild();
+
+ ASSERT_TRUE(gen.Generate(out)) << gen.error();
+ EXPECT_EQ(result(), R"(struct VertexOutput {
+ float4 pos;
+};
+struct tint_symbol_3 {
+ float4 pos : SV_Position;
+};
+struct tint_symbol_7 {
+ float4 pos : SV_Position;
+};
+
+VertexOutput foo(float x) {
+ return VertexOutput(float4(x, x, x, 1.0f));
+}
+
+tint_symbol_3 vert_main1() {
+ const VertexOutput tint_symbol_5 = VertexOutput(foo(0.5f));
+ return tint_symbol_3(tint_symbol_5.pos);
+}
+
+tint_symbol_7 vert_main2() {
+ const VertexOutput tint_symbol_8 = VertexOutput(foo(0.25f));
+ return tint_symbol_7(tint_symbol_8.pos);
+}
+
+)");
+
+ Validate();
}
TEST_F(HlslGeneratorImplTest_Function,
@@ -270,6 +339,8 @@
}
)");
+
+ Validate();
}
TEST_F(HlslGeneratorImplTest_Function,
@@ -310,6 +381,8 @@
}
)");
+
+ Validate();
}
TEST_F(HlslGeneratorImplTest_Function,
@@ -349,6 +422,8 @@
float v = asfloat(coord.Load(4));
return;
})"));
+
+ Validate();
}
TEST_F(HlslGeneratorImplTest_Function,
@@ -388,6 +463,8 @@
float v = asfloat(coord.Load(4));
return;
})"));
+
+ Validate();
}
TEST_F(HlslGeneratorImplTest_Function,
@@ -425,6 +502,8 @@
coord.Store(4, asuint(2.0f));
return;
})"));
+
+ Validate();
}
TEST_F(HlslGeneratorImplTest_Function,
@@ -462,15 +541,18 @@
coord.Store(4, asuint(2.0f));
return;
})"));
+
+ Validate();
}
+// TODO(crbug.com/tint/697): Remove this test
TEST_F(
HlslGeneratorImplTest_Function,
Emit_Decoration_Called_By_EntryPoints_WithLocationGlobals_And_Params) { // NOLINT
- auto* foo_in = Var("foo", ty.f32(), ast::StorageClass::kNone, nullptr,
- ast::DecorationList{
- create<ast::LocationDecoration>(0),
- });
+ Global("foo", ty.f32(), ast::StorageClass::kInput, nullptr,
+ ast::DecorationList{
+ create<ast::LocationDecoration>(0),
+ });
Global("bar", ty.f32(), ast::StorageClass::kOutput, nullptr,
ast::DecorationList{
@@ -483,8 +565,7 @@
});
Func("sub_func",
- ast::VariableList{Var("param", ty.f32(), ast::StorageClass::kNone),
- Var("foo", ty.f32(), ast::StorageClass::kNone)},
+ ast::VariableList{Var("param", ty.f32(), ast::StorageClass::kNone)},
ty.f32(),
ast::StatementList{
create<ast::AssignmentStatement>(Expr("bar"), Expr("foo")),
@@ -493,20 +574,20 @@
},
ast::DecorationList{});
- Func("ep_1", ast::VariableList{foo_in}, ty.void_(),
- ast::StatementList{
- create<ast::AssignmentStatement>(
- Expr("bar"), Call("sub_func", 1.0f, Expr("foo"))),
- create<ast::ReturnStatement>(),
- },
- ast::DecorationList{
- create<ast::StageDecoration>(ast::PipelineStage::kFragment),
- });
+ Func(
+ "ep_1", ast::VariableList{}, ty.void_(),
+ ast::StatementList{
+ create<ast::AssignmentStatement>(Expr("bar"), Call("sub_func", 1.0f)),
+ create<ast::ReturnStatement>(),
+ },
+ ast::DecorationList{
+ create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+ });
- GeneratorImpl& gen = SanitizeAndBuild();
+ GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate(out)) << gen.error();
- EXPECT_EQ(result(), R"(struct tint_symbol_2 {
+ EXPECT_EQ(result(), R"(struct ep_1_in {
float foo : TEXCOORD0;
};
@@ -515,20 +596,21 @@
float val : SV_Target0;
};
-float sub_func_ep_1(out ep_1_out tint_out, float param, float foo) {
- tint_out.bar = foo;
+float sub_func_ep_1(in ep_1_in tint_in, out ep_1_out tint_out, float param) {
+ tint_out.bar = tint_in.foo;
tint_out.val = param;
- return foo;
+ return tint_in.foo;
}
-ep_1_out ep_1(tint_symbol_2 tint_symbol_3) {
+ep_1_out ep_1(ep_1_in tint_in) {
ep_1_out tint_out = (ep_1_out)0;
- const float foo = tint_symbol_3.foo;
- tint_out.bar = sub_func_ep_1(tint_out, 1.0f, foo);
+ tint_out.bar = sub_func_ep_1(tint_in, tint_out, 1.0f);
return tint_out;
}
)");
+
+ Validate();
}
TEST_F(HlslGeneratorImplTest_Function,
@@ -574,39 +656,38 @@
}
)");
+
+ Validate();
}
+// TODO(crbug.com/tint/697): Remove this test
TEST_F(
HlslGeneratorImplTest_Function,
Emit_Decoration_Called_By_EntryPoints_WithBuiltinGlobals_And_Params) { // NOLINT
- auto* coord_in =
- Var("coord", ty.vec4<f32>(), ast::StorageClass::kNone, nullptr,
- ast::DecorationList{
- create<ast::BuiltinDecoration>(ast::Builtin::kFragCoord),
- });
+ Global("coord", ty.vec4<f32>(), ast::StorageClass::kInput, nullptr,
+ ast::DecorationList{
+ create<ast::BuiltinDecoration>(ast::Builtin::kFragCoord),
+ });
- // TODO(jrprice): Make this the return value when supported.
Global("depth", ty.f32(), ast::StorageClass::kOutput, nullptr,
ast::DecorationList{
create<ast::BuiltinDecoration>(ast::Builtin::kFragDepth),
});
- Func(
- "sub_func",
- ast::VariableList{Var("param", ty.f32(), ast::StorageClass::kNone),
- Var("coord", ty.vec4<f32>(), ast::StorageClass::kNone)},
- ty.f32(),
- ast::StatementList{
- create<ast::AssignmentStatement>(Expr("depth"),
- MemberAccessor("coord", "x")),
- create<ast::ReturnStatement>(Expr("param")),
- },
- ast::DecorationList{});
-
- Func("ep_1", ast::VariableList{coord_in}, ty.void_(),
+ Func("sub_func",
+ ast::VariableList{Var("param", ty.f32(), ast::StorageClass::kNone)},
+ ty.f32(),
ast::StatementList{
- create<ast::AssignmentStatement>(
- Expr("depth"), Call("sub_func", 1.0f, Expr("coord"))),
+ create<ast::AssignmentStatement>(Expr("depth"),
+ MemberAccessor("coord", "x")),
+ create<ast::ReturnStatement>(Expr("param")),
+ },
+ ast::DecorationList{});
+
+ Func("ep_1", ast::VariableList{}, ty.void_(),
+ ast::StatementList{
+ create<ast::AssignmentStatement>(Expr("depth"),
+ Call("sub_func", 1.0f)),
create<ast::ReturnStatement>(),
},
ast::DecorationList{
@@ -616,7 +697,7 @@
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate(out)) << gen.error();
- EXPECT_EQ(result(), R"(struct tint_symbol_2 {
+ EXPECT_EQ(result(), R"(struct ep_1_in {
float4 coord : SV_Position;
};
@@ -624,19 +705,20 @@
float depth : SV_Depth;
};
-float sub_func_ep_1(out ep_1_out tint_out, float param, float4 coord) {
- tint_out.depth = coord.x;
+float sub_func_ep_1(in ep_1_in tint_in, out ep_1_out tint_out, float param) {
+ tint_out.depth = tint_in.coord.x;
return param;
}
-ep_1_out ep_1(tint_symbol_2 tint_symbol_3) {
+ep_1_out ep_1(ep_1_in tint_in) {
ep_1_out tint_out = (ep_1_out)0;
- const float4 coord = tint_symbol_3.coord;
- tint_out.depth = sub_func_ep_1(tint_out, 1.0f, coord);
+ tint_out.depth = sub_func_ep_1(tint_in, tint_out, 1.0f);
return tint_out;
}
)");
+
+ Validate();
}
TEST_F(HlslGeneratorImplTest_Function,
@@ -684,6 +766,8 @@
}
)");
+
+ Validate();
}
TEST_F(HlslGeneratorImplTest_Function,
@@ -729,6 +813,8 @@
float v = sub_func(1.0f);
return;
})"));
+
+ Validate();
}
TEST_F(HlslGeneratorImplTest_Function,
@@ -772,6 +858,8 @@
}
)");
+
+ Validate();
}
TEST_F(HlslGeneratorImplTest_Function,
@@ -789,6 +877,8 @@
}
)");
+
+ Validate();
}
TEST_F(HlslGeneratorImplTest_Function, Emit_Decoration_EntryPoint_Compute) {
@@ -809,6 +899,8 @@
}
)");
+
+ Validate();
}
TEST_F(HlslGeneratorImplTest_Function,
@@ -831,6 +923,8 @@
}
)");
+
+ Validate();
}
TEST_F(HlslGeneratorImplTest_Function, Emit_Function_WithArrayParams) {
@@ -853,6 +947,8 @@
}
)");
+
+ Validate();
}
// https://crbug.com/tint/297
@@ -932,6 +1028,8 @@
}
)");
+
+ Validate();
}
} // namespace
diff --git a/src/writer/hlsl/test_helper.h b/src/writer/hlsl/test_helper.h
index 3d07994..2f5f356 100644
--- a/src/writer/hlsl/test_helper.h
+++ b/src/writer/hlsl/test_helper.h
@@ -20,7 +20,9 @@
#include <utility>
#include "gtest/gtest.h"
+#include "src/transform/canonicalize_entry_point_io.h"
#include "src/transform/hlsl.h"
+#include "src/transform/manager.h"
#include "src/writer/hlsl/generator_impl.h"
namespace tint {
@@ -98,7 +100,12 @@
ASSERT_TRUE(program->IsValid())
<< formatter.format(program->Diagnostics());
}();
- auto result = transform::Hlsl().Run(program.get());
+
+ tint::transform::Manager transform_manager;
+ transform_manager.append(
+ std::make_unique<tint::transform::CanonicalizeEntryPointIO>());
+ transform_manager.append(std::make_unique<tint::transform::Hlsl>());
+ auto result = transform_manager.Run(program.get());
[&]() {
ASSERT_TRUE(result.program.IsValid())
<< formatter.format(result.program.Diagnostics());