Enable InsertBefore to work on global declarations
Use it for entry point IO sanitizing transforms to fix cases where structures were being inserted before type aliases that they reference.
Also fixes up some ordering issues with the FirstIndexOffset
transform.
Change-Id: I50d472ccb844b388f69914dcecbc0fcda1a579ed
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/45000
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Auto-Submit: James Price <jrprice@google.com>
diff --git a/src/ast/module.cc b/src/ast/module.cc
index 892bcce..41e3629 100644
--- a/src/ast/module.cc
+++ b/src/ast/module.cc
@@ -25,7 +25,7 @@
Module::Module(const Source& source) : Base(source) {}
-Module::Module(const Source& source, std::vector<CastableBase*> global_decls)
+Module::Module(const Source& source, std::vector<Cloneable*> global_decls)
: Base(source), global_declarations_(std::move(global_decls)) {
for (auto* decl : global_declarations_) {
if (decl == nullptr) {
@@ -54,14 +54,14 @@
}
void Module::Copy(CloneContext* ctx, const Module* src) {
- for (auto* decl : src->global_declarations_) {
+ for (auto* decl : ctx->Clone(src->global_declarations_)) {
assert(decl);
if (auto* ty = decl->As<type::Type>()) {
- AddConstructedType(ctx->Clone(ty));
+ AddConstructedType(ty);
} else if (auto* func = decl->As<Function>()) {
- AddFunction(ctx->Clone(func));
+ AddFunction(func);
} else if (auto* var = decl->As<Variable>()) {
- AddGlobalVariable(ctx->Clone(var));
+ AddGlobalVariable(var);
} else {
TINT_ICE(ctx->dst->Diagnostics()) << "Unknown global declaration type";
}
diff --git a/src/ast/module.h b/src/ast/module.h
index 0f518e1..4defcc8 100644
--- a/src/ast/module.h
+++ b/src/ast/module.h
@@ -35,13 +35,13 @@
/// @param source the source of the module
/// @param global_decls the list of global types, functions, and variables, in
/// the order they were declared in the source program
- Module(const Source& source, std::vector<CastableBase*> global_decls);
+ Module(const Source& source, std::vector<Cloneable*> global_decls);
/// Destructor
~Module() override;
/// @returns the ordered global declarations for the translation unit
- const std::vector<CastableBase*>& GlobalDeclarations() const {
+ const std::vector<Cloneable*>& GlobalDeclarations() const {
return global_declarations_;
}
@@ -108,7 +108,7 @@
std::string to_str(const semantic::Info& sem) const;
private:
- std::vector<CastableBase*> global_declarations_;
+ std::vector<Cloneable*> global_declarations_;
std::vector<type::Type*> constructed_types_;
FunctionList functions_;
VariableList global_variables_;
diff --git a/src/transform/first_index_offset.cc b/src/transform/first_index_offset.cc
index 6acabe0..eccd813 100644
--- a/src/transform/first_index_offset.cc
+++ b/src/transform/first_index_offset.cc
@@ -191,10 +191,11 @@
dst->create<ast::GroupDecoration>(Source{}, group),
});
- dst->AST().AddGlobalVariable(idx_var);
dst->AST().AddConstructedType(struct_type);
+ dst->AST().AddGlobalVariable(idx_var);
+
return idx_var;
}
diff --git a/src/transform/first_index_offset_test.cc b/src/transform/first_index_offset_test.cc
index ef8965b..53d8d0f 100644
--- a/src/transform/first_index_offset_test.cc
+++ b/src/transform/first_index_offset_test.cc
@@ -92,15 +92,15 @@
)";
auto* expect = R"(
-[[builtin(vertex_index)]] var<in> tint_first_index_offset_vert_idx : u32;
-
-[[binding(1), group(2)]] var<uniform> tint_first_index_data : TintFirstIndexOffsetData;
-
[[block]]
struct TintFirstIndexOffsetData {
tint_first_vertex_index : u32;
};
+[[binding(1), group(2)]] var<uniform> tint_first_index_data : TintFirstIndexOffsetData;
+
+[[builtin(vertex_index)]] var<in> tint_first_index_offset_vert_idx : u32;
+
fn test() -> u32 {
const vert_idx : u32 = (tint_first_index_offset_vert_idx + tint_first_index_data.tint_first_vertex_index);
return vert_idx;
@@ -140,15 +140,15 @@
)";
auto* expect = R"(
-[[builtin(instance_index)]] var<in> tint_first_index_offset_inst_idx : u32;
-
-[[binding(1), group(7)]] var<uniform> tint_first_index_data : TintFirstIndexOffsetData;
-
[[block]]
struct TintFirstIndexOffsetData {
tint_first_instance_index : u32;
};
+[[binding(1), group(7)]] var<uniform> tint_first_index_data : TintFirstIndexOffsetData;
+
+[[builtin(instance_index)]] var<in> tint_first_index_offset_inst_idx : u32;
+
fn test() -> u32 {
const inst_idx : u32 = (tint_first_index_offset_inst_idx + tint_first_index_data.tint_first_instance_index);
return inst_idx;
@@ -189,18 +189,18 @@
)";
auto* expect = R"(
-[[builtin(instance_index)]] var<in> tint_first_index_offset_instance_idx : u32;
-
-[[builtin(vertex_index)]] var<in> tint_first_index_offset_vert_idx : u32;
-
-[[binding(1), group(2)]] var<uniform> tint_first_index_data : TintFirstIndexOffsetData;
-
[[block]]
struct TintFirstIndexOffsetData {
tint_first_vertex_index : u32;
tint_first_instance_index : u32;
};
+[[binding(1), group(2)]] var<uniform> tint_first_index_data : TintFirstIndexOffsetData;
+
+[[builtin(instance_index)]] var<in> tint_first_index_offset_instance_idx : u32;
+
+[[builtin(vertex_index)]] var<in> tint_first_index_offset_vert_idx : u32;
+
fn test() -> u32 {
const instance_idx : u32 = (tint_first_index_offset_instance_idx + tint_first_index_data.tint_first_instance_index);
const vert_idx : u32 = (tint_first_index_offset_vert_idx + tint_first_index_data.tint_first_vertex_index);
@@ -245,15 +245,15 @@
)";
auto* expect = R"(
-[[builtin(vertex_index)]] var<in> tint_first_index_offset_vert_idx : u32;
-
-[[binding(1), group(2)]] var<uniform> tint_first_index_data : TintFirstIndexOffsetData;
-
[[block]]
struct TintFirstIndexOffsetData {
tint_first_vertex_index : u32;
};
+[[binding(1), group(2)]] var<uniform> tint_first_index_data : TintFirstIndexOffsetData;
+
+[[builtin(vertex_index)]] var<in> tint_first_index_offset_vert_idx : u32;
+
fn func1() -> u32 {
const vert_idx : u32 = (tint_first_index_offset_vert_idx + tint_first_index_data.tint_first_vertex_index);
return vert_idx;
diff --git a/src/transform/hlsl.cc b/src/transform/hlsl.cc
index ee7f3b3..494ca38 100644
--- a/src/transform/hlsl.cc
+++ b/src/transform/hlsl.cc
@@ -185,7 +185,7 @@
auto* in_struct = ctx.dst->create<type::Struct>(
ctx.dst->Symbols().New(),
ctx.dst->create<ast::Struct>(struct_members, ast::DecorationList{}));
- ctx.dst->AST().AddConstructedType(in_struct);
+ ctx.InsertBefore(func, in_struct);
// Create a new function parameter using this struct type.
auto struct_param_symbol = ctx.dst->Symbols().New();
diff --git a/src/transform/hlsl_test.cc b/src/transform/hlsl_test.cc
index 82c732d..5f6d7d6 100644
--- a/src/transform/hlsl_test.cc
+++ b/src/transform/hlsl_test.cc
@@ -159,6 +159,11 @@
)";
auto* expect = R"(
+struct FragIn {
+ [[location(2)]]
+ loc2 : f32;
+};
+
struct tint_symbol_3 {
[[builtin(frag_coord)]]
coord : vec4<f32>;
@@ -166,11 +171,6 @@
loc1 : f32;
};
-struct FragIn {
- [[location(2)]]
- loc2 : f32;
-};
-
[[stage(fragment)]]
fn frag_main(tint_symbol_4 : tint_symbol_3, frag_in : FragIn) -> void {
const coord : vec4<f32> = tint_symbol_4.coord;
@@ -184,6 +184,34 @@
EXPECT_EQ(expect, str(got));
}
+TEST_F(HlslTest, HandleEntryPointIOTypes_Parameter_TypeAlias) {
+ auto* src = R"(
+type myf32 = f32;
+
+[[stage(fragment)]]
+fn frag_main([[location(1)]] loc1 : myf32) -> void {
+}
+)";
+
+ auto* expect = R"(
+type myf32 = f32;
+
+struct tint_symbol_3 {
+ [[location(1)]]
+ loc1 : myf32;
+};
+
+[[stage(fragment)]]
+fn frag_main(tint_symbol_4 : tint_symbol_3) -> void {
+ const loc1 : myf32 = tint_symbol_4.loc1;
+}
+)";
+
+ auto got = Run<Hlsl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
TEST_F(HlslTest, HandleEntryPointIOTypes_OnlyStructParameters) {
// Expect no change.
auto* src = R"(
diff --git a/src/transform/msl.cc b/src/transform/msl.cc
index 3733d10..3110418 100644
--- a/src/transform/msl.cc
+++ b/src/transform/msl.cc
@@ -347,7 +347,7 @@
auto* in_struct = ctx.dst->create<type::Struct>(
ctx.dst->Symbols().New(),
ctx.dst->create<ast::Struct>(struct_members, ast::DecorationList{}));
- ctx.dst->AST().AddConstructedType(in_struct);
+ ctx.InsertBefore(func, in_struct);
// Create a new function parameter using this struct type.
auto struct_param_symbol = ctx.dst->Symbols().New();
diff --git a/src/transform/msl_test.cc b/src/transform/msl_test.cc
index cda0a5f..d1ddac8 100644
--- a/src/transform/msl_test.cc
+++ b/src/transform/msl_test.cc
@@ -372,6 +372,34 @@
EXPECT_EQ(src, str(got));
}
+TEST_F(MslEntryPointIOTest, HandleEntryPointIOTypes_Parameter_TypeAlias) {
+ auto* src = R"(
+type myf32 = f32;
+
+[[stage(fragment)]]
+fn frag_main([[location(1)]] loc1 : myf32) -> void {
+}
+)";
+
+ auto* expect = R"(
+type myf32 = f32;
+
+struct tint_symbol_3 {
+ [[location(1)]]
+ loc1 : myf32;
+};
+
+[[stage(fragment)]]
+fn frag_main(tint_symbol_4 : tint_symbol_3) -> void {
+ const loc1 : myf32 = tint_symbol_4.loc1;
+}
+)";
+
+ auto got = Run<Msl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
TEST_F(MslEntryPointIOTest, HandleEntryPointIOTypes_Parameters_EmptyBody) {
auto* src = R"(
[[stage(fragment)]]
diff --git a/src/transform/spirv.cc b/src/transform/spirv.cc
index 2c1bb7c..97ce585 100644
--- a/src/transform/spirv.cc
+++ b/src/transform/spirv.cc
@@ -114,9 +114,10 @@
// Create a new symbol for the global variable.
auto var_symbol = ctx.dst->Symbols().New();
// Create the global variable.
- ctx.dst->Global(var_symbol, ctx.Clone(param->type()),
- ast::StorageClass::kInput, nullptr,
- ctx.Clone(param->decorations()));
+ auto* var = ctx.dst->Var(var_symbol, ctx.Clone(param->type()),
+ ast::StorageClass::kInput, nullptr,
+ ctx.Clone(param->decorations()));
+ ctx.InsertBefore(func, var);
// Replace all uses of the function parameter with the global variable.
for (auto* user : ctx.src->Sem().Get(param)->Users()) {
diff --git a/src/transform/spirv_test.cc b/src/transform/spirv_test.cc
index f85e318..2f03ad4 100644
--- a/src/transform/spirv_test.cc
+++ b/src/transform/spirv_test.cc
@@ -42,15 +42,15 @@
[[location(1)]] var<in> tint_symbol_2 : f32;
-[[builtin(local_invocation_id)]] var<in> tint_symbol_6 : vec3<u32>;
-
-[[builtin(local_invocation_index)]] var<in> tint_symbol_7 : u32;
-
[[stage(fragment)]]
fn frag_main() -> void {
var col : f32 = (tint_symbol_1.x * tint_symbol_2);
}
+[[builtin(local_invocation_id)]] var<in> tint_symbol_6 : vec3<u32>;
+
+[[builtin(local_invocation_index)]] var<in> tint_symbol_7 : u32;
+
[[stage(compute)]]
fn compute_main() -> void {
var id_x : u32 = tint_symbol_6.x;
@@ -62,6 +62,30 @@
EXPECT_EQ(expect, str(got));
}
+TEST_F(SpirvTest, HandleEntryPointIOTypes_Parameter_TypeAlias) {
+ auto* src = R"(
+type myf32 = f32;
+
+[[stage(fragment)]]
+fn frag_main([[location(1)]] loc1 : myf32) -> void {
+}
+)";
+
+ auto* expect = R"(
+type myf32 = f32;
+
+[[location(1)]] var<in> tint_symbol_1 : myf32;
+
+[[stage(fragment)]]
+fn frag_main() -> void {
+}
+)";
+
+ auto got = Run<Spirv>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
TEST_F(SpirvTest, HandleSampleMaskBuiltins_Basic) {
auto* src = R"(
[[builtin(sample_index)]] var<in> sample_index : u32;
@@ -152,12 +176,12 @@
)";
auto* expect = R"(
+[[builtin(sample_mask_out)]] var<out> mask_out : array<u32, 1>;
+
[[builtin(sample_index)]] var<in> tint_symbol_1 : u32;
[[builtin(sample_mask_in)]] var<in> tint_symbol_2 : array<u32, 1>;
-[[builtin(sample_mask_out)]] var<out> mask_out : array<u32, 1>;
-
[[stage(fragment)]]
fn main() -> void {
mask_out[0] = tint_symbol_2[0];