Consistent formatting for Dawn/Tint.
This CL updates the clang-format files so that Dawn and Tint share a
single format. The major changes are: tabs are 4 spaces, lines are
100 columns, and namespaces are not indented.
Bug: dawn:1339
Change-Id: I4208742c95643998d9fd14e77a9cc558071ded39
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/87603
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/transform/add_empty_entry_point.cc b/src/tint/transform/add_empty_entry_point.cc
index 0710d2b..18c5688 100644
--- a/src/tint/transform/add_empty_entry_point.cc
+++ b/src/tint/transform/add_empty_entry_point.cc
@@ -26,24 +26,19 @@
AddEmptyEntryPoint::~AddEmptyEntryPoint() = default;
-bool AddEmptyEntryPoint::ShouldRun(const Program* program,
- const DataMap&) const {
- for (auto* func : program->AST().Functions()) {
- if (func->IsEntryPoint()) {
- return false;
+bool AddEmptyEntryPoint::ShouldRun(const Program* program, const DataMap&) const {
+ for (auto* func : program->AST().Functions()) {
+ if (func->IsEntryPoint()) {
+ return false;
+ }
}
- }
- return true;
+ return true;
}
-void AddEmptyEntryPoint::Run(CloneContext& ctx,
- const DataMap&,
- DataMap&) const {
- ctx.dst->Func(ctx.dst->Symbols().New("unused_entry_point"), {},
- ctx.dst->ty.void_(), {},
- {ctx.dst->Stage(ast::PipelineStage::kCompute),
- ctx.dst->WorkgroupSize(1)});
- ctx.Clone();
+void AddEmptyEntryPoint::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
+ ctx.dst->Func(ctx.dst->Symbols().New("unused_entry_point"), {}, ctx.dst->ty.void_(), {},
+ {ctx.dst->Stage(ast::PipelineStage::kCompute), ctx.dst->WorkgroupSize(1)});
+ ctx.Clone();
}
} // namespace tint::transform
diff --git a/src/tint/transform/add_empty_entry_point.h b/src/tint/transform/add_empty_entry_point.h
index eb9dccd..5530355 100644
--- a/src/tint/transform/add_empty_entry_point.h
+++ b/src/tint/transform/add_empty_entry_point.h
@@ -20,30 +20,26 @@
namespace tint::transform {
/// Add an empty entry point to the module, if no other entry points exist.
-class AddEmptyEntryPoint final
- : public Castable<AddEmptyEntryPoint, Transform> {
- public:
- /// Constructor
- AddEmptyEntryPoint();
- /// Destructor
- ~AddEmptyEntryPoint() override;
+class AddEmptyEntryPoint final : public Castable<AddEmptyEntryPoint, Transform> {
+ public:
+ /// Constructor
+ AddEmptyEntryPoint();
+ /// Destructor
+ ~AddEmptyEntryPoint() override;
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program,
- const DataMap& data = {}) const override;
+ /// @param program the program to inspect
+ /// @param data optional extra transform-specific input data
+ /// @returns true if this transform should be run for the given program
+ bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx,
- const DataMap& inputs,
- DataMap& outputs) const override;
+ protected:
+ /// Runs the transform using the CloneContext built for transforming a
+ /// program. Run() is responsible for calling Clone() on the CloneContext.
+ /// @param ctx the CloneContext primed with the input program and
+ /// ProgramBuilder
+ /// @param inputs optional extra transform-specific input data
+ /// @param outputs optional extra transform-specific output data
+ void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/add_empty_entry_point_test.cc b/src/tint/transform/add_empty_entry_point_test.cc
index 0854251..cbbd9c3 100644
--- a/src/tint/transform/add_empty_entry_point_test.cc
+++ b/src/tint/transform/add_empty_entry_point_test.cc
@@ -24,52 +24,52 @@
using AddEmptyEntryPointTest = TransformTest;
TEST_F(AddEmptyEntryPointTest, ShouldRunEmptyModule) {
- auto* src = R"()";
+ auto* src = R"()";
- EXPECT_TRUE(ShouldRun<AddEmptyEntryPoint>(src));
+ EXPECT_TRUE(ShouldRun<AddEmptyEntryPoint>(src));
}
TEST_F(AddEmptyEntryPointTest, ShouldRunExistingEntryPoint) {
- auto* src = R"(
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn existing() {}
)";
- EXPECT_FALSE(ShouldRun<AddEmptyEntryPoint>(src));
+ EXPECT_FALSE(ShouldRun<AddEmptyEntryPoint>(src));
}
TEST_F(AddEmptyEntryPointTest, EmptyModule) {
- auto* src = R"()";
+ auto* src = R"()";
- auto* expect = R"(
+ auto* expect = R"(
@stage(compute) @workgroup_size(1)
fn unused_entry_point() {
}
)";
- auto got = Run<AddEmptyEntryPoint>(src);
+ auto got = Run<AddEmptyEntryPoint>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(AddEmptyEntryPointTest, ExistingEntryPoint) {
- auto* src = R"(
+ auto* src = R"(
@stage(fragment)
fn main() {
}
)";
- auto* expect = src;
+ auto* expect = src;
- auto got = Run<AddEmptyEntryPoint>(src);
+ auto got = Run<AddEmptyEntryPoint>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(AddEmptyEntryPointTest, NameClash) {
- auto* src = R"(var<private> unused_entry_point : f32;)";
+ auto* src = R"(var<private> unused_entry_point : f32;)";
- auto* expect = R"(
+ auto* expect = R"(
@stage(compute) @workgroup_size(1)
fn unused_entry_point_1() {
}
@@ -77,9 +77,9 @@
var<private> unused_entry_point : f32;
)";
- auto got = Run<AddEmptyEntryPoint>(src);
+ auto got = Run<AddEmptyEntryPoint>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
} // namespace
diff --git a/src/tint/transform/add_spirv_block_attribute.cc b/src/tint/transform/add_spirv_block_attribute.cc
index 91ab991..38e0de6 100644
--- a/src/tint/transform/add_spirv_block_attribute.cc
+++ b/src/tint/transform/add_spirv_block_attribute.cc
@@ -23,8 +23,7 @@
#include "src/tint/utils/map.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::AddSpirvBlockAttribute);
-TINT_INSTANTIATE_TYPEINFO(
- tint::transform::AddSpirvBlockAttribute::SpirvBlockAttribute);
+TINT_INSTANTIATE_TYPEINFO(tint::transform::AddSpirvBlockAttribute::SpirvBlockAttribute);
namespace tint::transform {
@@ -32,89 +31,81 @@
AddSpirvBlockAttribute::~AddSpirvBlockAttribute() = default;
-void AddSpirvBlockAttribute::Run(CloneContext& ctx,
- const DataMap&,
- DataMap&) const {
- auto& sem = ctx.src->Sem();
+void AddSpirvBlockAttribute::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
+ auto& sem = ctx.src->Sem();
- // Collect the set of structs that are nested in other types.
- std::unordered_set<const sem::Struct*> nested_structs;
- for (auto* node : ctx.src->ASTNodes().Objects()) {
- if (auto* arr = sem.Get<sem::Array>(node->As<ast::Array>())) {
- if (auto* nested_str = arr->ElemType()->As<sem::Struct>()) {
- nested_structs.insert(nested_str);
- }
- } else if (auto* str = sem.Get<sem::Struct>(node->As<ast::Struct>())) {
- for (auto* member : str->Members()) {
- if (auto* nested_str = member->Type()->As<sem::Struct>()) {
- nested_structs.insert(nested_str);
+ // Collect the set of structs that are nested in other types.
+ std::unordered_set<const sem::Struct*> nested_structs;
+ for (auto* node : ctx.src->ASTNodes().Objects()) {
+ if (auto* arr = sem.Get<sem::Array>(node->As<ast::Array>())) {
+ if (auto* nested_str = arr->ElemType()->As<sem::Struct>()) {
+ nested_structs.insert(nested_str);
+ }
+ } else if (auto* str = sem.Get<sem::Struct>(node->As<ast::Struct>())) {
+ for (auto* member : str->Members()) {
+ if (auto* nested_str = member->Type()->As<sem::Struct>()) {
+ nested_structs.insert(nested_str);
+ }
+ }
}
- }
- }
- }
-
- // A map from a type in the source program to a block-decorated wrapper that
- // contains it in the destination program.
- std::unordered_map<const sem::Type*, const ast::Struct*> wrapper_structs;
-
- // Process global variables that are buffers.
- for (auto* var : ctx.src->AST().GlobalVariables()) {
- auto* sem_var = sem.Get<sem::GlobalVariable>(var);
- if (var->declared_storage_class != ast::StorageClass::kStorage &&
- var->declared_storage_class != ast::StorageClass::kUniform) {
- continue;
}
- auto* ty = sem.Get(var->type);
- auto* str = ty->As<sem::Struct>();
- if (!str || nested_structs.count(str)) {
- const char* kMemberName = "inner";
+ // A map from a type in the source program to a block-decorated wrapper that
+ // contains it in the destination program.
+ std::unordered_map<const sem::Type*, const ast::Struct*> wrapper_structs;
- // This is a non-struct or a struct that is nested somewhere else, so we
- // need to wrap it first.
- auto* wrapper = utils::GetOrCreate(wrapper_structs, ty, [&]() {
- auto* block =
- ctx.dst->ASTNodes().Create<SpirvBlockAttribute>(ctx.dst->ID());
- auto wrapper_name = ctx.src->Symbols().NameFor(var->symbol) + "_block";
- auto* ret = ctx.dst->create<ast::Struct>(
- ctx.dst->Symbols().New(wrapper_name),
- ast::StructMemberList{
- ctx.dst->Member(kMemberName, CreateASTTypeFor(ctx, ty))},
- ast::AttributeList{block});
- ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), var, ret);
- return ret;
- });
- ctx.Replace(var->type, ctx.dst->ty.Of(wrapper));
+ // Process global variables that are buffers.
+ for (auto* var : ctx.src->AST().GlobalVariables()) {
+ auto* sem_var = sem.Get<sem::GlobalVariable>(var);
+ if (var->declared_storage_class != ast::StorageClass::kStorage &&
+ var->declared_storage_class != ast::StorageClass::kUniform) {
+ continue;
+ }
- // Insert a member accessor to get the original type from the wrapper at
- // any usage of the original variable.
- for (auto* user : sem_var->Users()) {
- ctx.Replace(
- user->Declaration(),
- ctx.dst->MemberAccessor(ctx.Clone(var->symbol), kMemberName));
- }
- } else {
- // Add a block attribute to this struct directly.
- auto* block =
- ctx.dst->ASTNodes().Create<SpirvBlockAttribute>(ctx.dst->ID());
- ctx.InsertFront(str->Declaration()->attributes, block);
+ auto* ty = sem.Get(var->type);
+ auto* str = ty->As<sem::Struct>();
+ if (!str || nested_structs.count(str)) {
+ const char* kMemberName = "inner";
+
+ // This is a non-struct or a struct that is nested somewhere else, so we
+ // need to wrap it first.
+ auto* wrapper = utils::GetOrCreate(wrapper_structs, ty, [&]() {
+ auto* block = ctx.dst->ASTNodes().Create<SpirvBlockAttribute>(ctx.dst->ID());
+ auto wrapper_name = ctx.src->Symbols().NameFor(var->symbol) + "_block";
+ auto* ret = ctx.dst->create<ast::Struct>(
+ ctx.dst->Symbols().New(wrapper_name),
+ ast::StructMemberList{ctx.dst->Member(kMemberName, CreateASTTypeFor(ctx, ty))},
+ ast::AttributeList{block});
+ ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), var, ret);
+ return ret;
+ });
+ ctx.Replace(var->type, ctx.dst->ty.Of(wrapper));
+
+ // Insert a member accessor to get the original type from the wrapper at
+ // any usage of the original variable.
+ for (auto* user : sem_var->Users()) {
+ ctx.Replace(user->Declaration(),
+ ctx.dst->MemberAccessor(ctx.Clone(var->symbol), kMemberName));
+ }
+ } else {
+ // Add a block attribute to this struct directly.
+ auto* block = ctx.dst->ASTNodes().Create<SpirvBlockAttribute>(ctx.dst->ID());
+ ctx.InsertFront(str->Declaration()->attributes, block);
+ }
}
- }
- ctx.Clone();
+ ctx.Clone();
}
-AddSpirvBlockAttribute::SpirvBlockAttribute::SpirvBlockAttribute(ProgramID pid)
- : Base(pid) {}
+AddSpirvBlockAttribute::SpirvBlockAttribute::SpirvBlockAttribute(ProgramID pid) : Base(pid) {}
AddSpirvBlockAttribute::SpirvBlockAttribute::~SpirvBlockAttribute() = default;
std::string AddSpirvBlockAttribute::SpirvBlockAttribute::InternalName() const {
- return "spirv_block";
+ return "spirv_block";
}
const AddSpirvBlockAttribute::SpirvBlockAttribute*
AddSpirvBlockAttribute::SpirvBlockAttribute::Clone(CloneContext* ctx) const {
- return ctx->dst->ASTNodes()
- .Create<AddSpirvBlockAttribute::SpirvBlockAttribute>(ctx->dst->ID());
+ return ctx->dst->ASTNodes().Create<AddSpirvBlockAttribute::SpirvBlockAttribute>(ctx->dst->ID());
}
} // namespace tint::transform
diff --git a/src/tint/transform/add_spirv_block_attribute.h b/src/tint/transform/add_spirv_block_attribute.h
index 386a341..67faaa5 100644
--- a/src/tint/transform/add_spirv_block_attribute.h
+++ b/src/tint/transform/add_spirv_block_attribute.h
@@ -27,46 +27,42 @@
/// store type of a buffer. If that structure is nested inside another structure
/// or an array, then it is wrapped inside another structure which gets the
/// `@internal(spirv_block)` attribute instead.
-class AddSpirvBlockAttribute final
- : public Castable<AddSpirvBlockAttribute, Transform> {
- public:
- /// SpirvBlockAttribute is an InternalAttribute that is used to decorate a
- // structure that needs a SPIR-V block attribute.
- class SpirvBlockAttribute final
- : public Castable<SpirvBlockAttribute, ast::InternalAttribute> {
- public:
+class AddSpirvBlockAttribute final : public Castable<AddSpirvBlockAttribute, Transform> {
+ public:
+ /// SpirvBlockAttribute is an InternalAttribute that is used to decorate a
+ // structure that needs a SPIR-V block attribute.
+ class SpirvBlockAttribute final : public Castable<SpirvBlockAttribute, ast::InternalAttribute> {
+ public:
+ /// Constructor
+ /// @param program_id the identifier of the program that owns this node
+ explicit SpirvBlockAttribute(ProgramID program_id);
+ /// Destructor
+ ~SpirvBlockAttribute() override;
+
+ /// @return a short description of the internal attribute which will be
+ /// displayed as `@internal(<name>)`
+ std::string InternalName() const override;
+
+ /// Performs a deep clone of this object using the CloneContext `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned object
+ const SpirvBlockAttribute* Clone(CloneContext* ctx) const override;
+ };
+
/// Constructor
- /// @param program_id the identifier of the program that owns this node
- explicit SpirvBlockAttribute(ProgramID program_id);
+ AddSpirvBlockAttribute();
+
/// Destructor
- ~SpirvBlockAttribute() override;
+ ~AddSpirvBlockAttribute() override;
- /// @return a short description of the internal attribute which will be
- /// displayed as `@internal(<name>)`
- std::string InternalName() const override;
-
- /// Performs a deep clone of this object using the CloneContext `ctx`.
- /// @param ctx the clone context
- /// @return the newly cloned object
- const SpirvBlockAttribute* Clone(CloneContext* ctx) const override;
- };
-
- /// Constructor
- AddSpirvBlockAttribute();
-
- /// Destructor
- ~AddSpirvBlockAttribute() override;
-
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx,
- const DataMap& inputs,
- DataMap& outputs) const override;
+ protected:
+ /// Runs the transform using the CloneContext built for transforming a
+ /// program. Run() is responsible for calling Clone() on the CloneContext.
+ /// @param ctx the CloneContext primed with the input program and
+ /// ProgramBuilder
+ /// @param inputs optional extra transform-specific input data
+ /// @param outputs optional extra transform-specific output data
+ void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/add_spirv_block_attribute_test.cc b/src/tint/transform/add_spirv_block_attribute_test.cc
index b68920c..14ba929 100644
--- a/src/tint/transform/add_spirv_block_attribute_test.cc
+++ b/src/tint/transform/add_spirv_block_attribute_test.cc
@@ -25,16 +25,16 @@
using AddSpirvBlockAttributeTest = TransformTest;
TEST_F(AddSpirvBlockAttributeTest, EmptyModule) {
- auto* src = "";
- auto* expect = "";
+ auto* src = "";
+ auto* expect = "";
- auto got = Run<AddSpirvBlockAttribute>(src);
+ auto got = Run<AddSpirvBlockAttribute>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(AddSpirvBlockAttributeTest, Noop_UsedForPrivateVar) {
- auto* src = R"(
+ auto* src = R"(
struct S {
f : f32,
}
@@ -46,15 +46,15 @@
p.f = 1.0;
}
)";
- auto* expect = src;
+ auto* expect = src;
- auto got = Run<AddSpirvBlockAttribute>(src);
+ auto got = Run<AddSpirvBlockAttribute>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(AddSpirvBlockAttributeTest, Noop_UsedForShaderIO) {
- auto* src = R"(
+ auto* src = R"(
struct S {
@location(0)
f : f32,
@@ -65,15 +65,15 @@
return S();
}
)";
- auto* expect = src;
+ auto* expect = src;
- auto got = Run<AddSpirvBlockAttribute>(src);
+ auto got = Run<AddSpirvBlockAttribute>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(AddSpirvBlockAttributeTest, BasicScalar) {
- auto* src = R"(
+ auto* src = R"(
@group(0) @binding(0)
var<uniform> u : f32;
@@ -82,7 +82,7 @@
let f = u;
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@internal(spirv_block)
struct u_block {
inner : f32,
@@ -96,13 +96,13 @@
}
)";
- auto got = Run<AddSpirvBlockAttribute>(src);
+ auto got = Run<AddSpirvBlockAttribute>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(AddSpirvBlockAttributeTest, BasicArray) {
- auto* src = R"(
+ auto* src = R"(
@group(0) @binding(0)
var<uniform> u : array<vec4<f32>, 4u>;
@@ -111,7 +111,7 @@
let a = u;
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@internal(spirv_block)
struct u_block {
inner : array<vec4<f32>, 4u>,
@@ -125,13 +125,13 @@
}
)";
- auto got = Run<AddSpirvBlockAttribute>(src);
+ auto got = Run<AddSpirvBlockAttribute>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(AddSpirvBlockAttributeTest, BasicArray_Alias) {
- auto* src = R"(
+ auto* src = R"(
type Numbers = array<vec4<f32>, 4u>;
@group(0) @binding(0)
@@ -142,7 +142,7 @@
let a = u;
}
)";
- auto* expect = R"(
+ auto* expect = R"(
type Numbers = array<vec4<f32>, 4u>;
@internal(spirv_block)
@@ -158,13 +158,13 @@
}
)";
- auto got = Run<AddSpirvBlockAttribute>(src);
+ auto got = Run<AddSpirvBlockAttribute>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(AddSpirvBlockAttributeTest, BasicStruct) {
- auto* src = R"(
+ auto* src = R"(
struct S {
f : f32,
};
@@ -177,7 +177,7 @@
let f = u.f;
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@internal(spirv_block)
struct S {
f : f32,
@@ -191,13 +191,13 @@
}
)";
- auto got = Run<AddSpirvBlockAttribute>(src);
+ auto got = Run<AddSpirvBlockAttribute>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(AddSpirvBlockAttributeTest, Nested_OuterBuffer_InnerNotBuffer) {
- auto* src = R"(
+ auto* src = R"(
struct Inner {
f : f32,
};
@@ -214,7 +214,7 @@
let f = u.i.f;
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct Inner {
f : f32,
}
@@ -232,13 +232,13 @@
}
)";
- auto got = Run<AddSpirvBlockAttribute>(src);
+ auto got = Run<AddSpirvBlockAttribute>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(AddSpirvBlockAttributeTest, Nested_OuterBuffer_InnerBuffer) {
- auto* src = R"(
+ auto* src = R"(
struct Inner {
f : f32,
};
@@ -259,7 +259,7 @@
let f1 = u1.f;
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct Inner {
f : f32,
}
@@ -285,13 +285,13 @@
}
)";
- auto got = Run<AddSpirvBlockAttribute>(src);
+ auto got = Run<AddSpirvBlockAttribute>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(AddSpirvBlockAttributeTest, Nested_OuterNotBuffer_InnerBuffer) {
- auto* src = R"(
+ auto* src = R"(
struct Inner {
f : f32,
};
@@ -311,7 +311,7 @@
let f1 = u.f;
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct Inner {
f : f32,
}
@@ -336,13 +336,13 @@
}
)";
- auto got = Run<AddSpirvBlockAttribute>(src);
+ auto got = Run<AddSpirvBlockAttribute>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(AddSpirvBlockAttributeTest, Nested_InnerUsedForMultipleBuffers) {
- auto* src = R"(
+ auto* src = R"(
struct Inner {
f : f32,
};
@@ -367,7 +367,7 @@
let f2 = u2.f;
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct Inner {
f : f32,
}
@@ -396,13 +396,13 @@
}
)";
- auto got = Run<AddSpirvBlockAttribute>(src);
+ auto got = Run<AddSpirvBlockAttribute>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(AddSpirvBlockAttributeTest, StructInArray) {
- auto* src = R"(
+ auto* src = R"(
struct S {
f : f32,
};
@@ -416,7 +416,7 @@
let a = array<S, 4>();
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct S {
f : f32,
}
@@ -435,13 +435,13 @@
}
)";
- auto got = Run<AddSpirvBlockAttribute>(src);
+ auto got = Run<AddSpirvBlockAttribute>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(AddSpirvBlockAttributeTest, StructInArray_MultipleBuffers) {
- auto* src = R"(
+ auto* src = R"(
struct S {
f : f32,
};
@@ -459,7 +459,7 @@
let a = array<S, 4>();
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct S {
f : f32,
}
@@ -481,13 +481,13 @@
}
)";
- auto got = Run<AddSpirvBlockAttribute>(src);
+ auto got = Run<AddSpirvBlockAttribute>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(AddSpirvBlockAttributeTest, Aliases_Nested_OuterBuffer_InnerBuffer) {
- auto* src = R"(
+ auto* src = R"(
struct Inner {
f : f32,
};
@@ -512,7 +512,7 @@
let f1 = u1.f;
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct Inner {
f : f32,
}
@@ -542,14 +542,13 @@
}
)";
- auto got = Run<AddSpirvBlockAttribute>(src);
+ auto got = Run<AddSpirvBlockAttribute>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
-TEST_F(AddSpirvBlockAttributeTest,
- Aliases_Nested_OuterBuffer_InnerBuffer_OutOfOrder) {
- auto* src = R"(
+TEST_F(AddSpirvBlockAttributeTest, Aliases_Nested_OuterBuffer_InnerBuffer_OutOfOrder) {
+ auto* src = R"(
@stage(fragment)
fn main() {
let f0 = u0.i.f;
@@ -574,7 +573,7 @@
f : f32,
};
)";
- auto* expect = R"(
+ auto* expect = R"(
@stage(fragment)
fn main() {
let f0 = u0.i.f;
@@ -604,9 +603,9 @@
}
)";
- auto got = Run<AddSpirvBlockAttribute>(src);
+ auto got = Run<AddSpirvBlockAttribute>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
} // namespace
diff --git a/src/tint/transform/array_length_from_uniform.cc b/src/tint/transform/array_length_from_uniform.cc
index 52c68f2..0741ee77 100644
--- a/src/tint/transform/array_length_from_uniform.cc
+++ b/src/tint/transform/array_length_from_uniform.cc
@@ -43,187 +43,173 @@
/// sem::GlobalVariable for the storage buffer.
template <typename F>
static void IterateArrayLengthOnStorageVar(CloneContext& ctx, F&& functor) {
- auto& sem = ctx.src->Sem();
+ auto& sem = ctx.src->Sem();
- // Find all calls to the arrayLength() builtin.
- for (auto* node : ctx.src->ASTNodes().Objects()) {
- auto* call_expr = node->As<ast::CallExpression>();
- if (!call_expr) {
- continue;
- }
-
- auto* call = sem.Get(call_expr);
- auto* builtin = call->Target()->As<sem::Builtin>();
- if (!builtin || builtin->Type() != sem::BuiltinType::kArrayLength) {
- continue;
- }
-
- // Get the storage buffer that contains the runtime array.
- // Since we require SimplifyPointers, we can assume that the arrayLength()
- // call has one of two forms:
- // arrayLength(&struct_var.array_member)
- // arrayLength(&array_var)
- auto* param = call_expr->args[0]->As<ast::UnaryOpExpression>();
- if (!param || param->op != ast::UnaryOp::kAddressOf) {
- TINT_ICE(Transform, ctx.dst->Diagnostics())
- << "expected form of arrayLength argument to be &array_var or "
- "&struct_var.array_member";
- break;
- }
- auto* storage_buffer_expr = param->expr;
- if (auto* accessor = param->expr->As<ast::MemberAccessorExpression>()) {
- storage_buffer_expr = accessor->structure;
- }
- auto* storage_buffer_sem = sem.Get<sem::VariableUser>(storage_buffer_expr);
- if (!storage_buffer_sem) {
- TINT_ICE(Transform, ctx.dst->Diagnostics())
- << "expected form of arrayLength argument to be &array_var or "
- "&struct_var.array_member";
- break;
- }
-
- // Get the index to use for the buffer size array.
- auto* var = tint::As<sem::GlobalVariable>(storage_buffer_sem->Variable());
- if (!var) {
- TINT_ICE(Transform, ctx.dst->Diagnostics())
- << "storage buffer is not a global variable";
- break;
- }
- functor(call_expr, storage_buffer_sem, var);
- }
-}
-
-bool ArrayLengthFromUniform::ShouldRun(const Program* program,
- const DataMap&) const {
- for (auto* fn : program->AST().Functions()) {
- if (auto* sem_fn = program->Sem().Get(fn)) {
- for (auto* builtin : sem_fn->DirectlyCalledBuiltins()) {
- if (builtin->Type() == sem::BuiltinType::kArrayLength) {
- return true;
+ // Find all calls to the arrayLength() builtin.
+ for (auto* node : ctx.src->ASTNodes().Objects()) {
+ auto* call_expr = node->As<ast::CallExpression>();
+ if (!call_expr) {
+ continue;
}
- }
+
+ auto* call = sem.Get(call_expr);
+ auto* builtin = call->Target()->As<sem::Builtin>();
+ if (!builtin || builtin->Type() != sem::BuiltinType::kArrayLength) {
+ continue;
+ }
+
+ // Get the storage buffer that contains the runtime array.
+ // Since we require SimplifyPointers, we can assume that the arrayLength()
+ // call has one of two forms:
+ // arrayLength(&struct_var.array_member)
+ // arrayLength(&array_var)
+ auto* param = call_expr->args[0]->As<ast::UnaryOpExpression>();
+ if (!param || param->op != ast::UnaryOp::kAddressOf) {
+ TINT_ICE(Transform, ctx.dst->Diagnostics())
+ << "expected form of arrayLength argument to be &array_var or "
+ "&struct_var.array_member";
+ break;
+ }
+ auto* storage_buffer_expr = param->expr;
+ if (auto* accessor = param->expr->As<ast::MemberAccessorExpression>()) {
+ storage_buffer_expr = accessor->structure;
+ }
+ auto* storage_buffer_sem = sem.Get<sem::VariableUser>(storage_buffer_expr);
+ if (!storage_buffer_sem) {
+ TINT_ICE(Transform, ctx.dst->Diagnostics())
+ << "expected form of arrayLength argument to be &array_var or "
+ "&struct_var.array_member";
+ break;
+ }
+
+ // Get the index to use for the buffer size array.
+ auto* var = tint::As<sem::GlobalVariable>(storage_buffer_sem->Variable());
+ if (!var) {
+ TINT_ICE(Transform, ctx.dst->Diagnostics())
+ << "storage buffer is not a global variable";
+ break;
+ }
+ functor(call_expr, storage_buffer_sem, var);
}
- }
- return false;
}
-void ArrayLengthFromUniform::Run(CloneContext& ctx,
- const DataMap& inputs,
- DataMap& outputs) const {
- auto* cfg = inputs.Get<Config>();
- if (cfg == nullptr) {
- ctx.dst->Diagnostics().add_error(
- diag::System::Transform,
- "missing transform data for " + std::string(TypeInfo().name));
- return;
- }
+bool ArrayLengthFromUniform::ShouldRun(const Program* program, const DataMap&) const {
+ for (auto* fn : program->AST().Functions()) {
+ if (auto* sem_fn = program->Sem().Get(fn)) {
+ for (auto* builtin : sem_fn->DirectlyCalledBuiltins()) {
+ if (builtin->Type() == sem::BuiltinType::kArrayLength) {
+ return true;
+ }
+ }
+ }
+ }
+ return false;
+}
- const char* kBufferSizeMemberName = "buffer_size";
+void ArrayLengthFromUniform::Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const {
+ auto* cfg = inputs.Get<Config>();
+ if (cfg == nullptr) {
+ ctx.dst->Diagnostics().add_error(
+ diag::System::Transform, "missing transform data for " + std::string(TypeInfo().name));
+ return;
+ }
- // Determine the size of the buffer size array.
- uint32_t max_buffer_size_index = 0;
+ const char* kBufferSizeMemberName = "buffer_size";
- IterateArrayLengthOnStorageVar(
- ctx, [&](const ast::CallExpression*, const sem::VariableUser*,
- const sem::GlobalVariable* var) {
+ // Determine the size of the buffer size array.
+ uint32_t max_buffer_size_index = 0;
+
+ IterateArrayLengthOnStorageVar(ctx, [&](const ast::CallExpression*, const sem::VariableUser*,
+ const sem::GlobalVariable* var) {
auto binding = var->BindingPoint();
auto idx_itr = cfg->bindpoint_to_size_index.find(binding);
if (idx_itr == cfg->bindpoint_to_size_index.end()) {
- return;
+ return;
}
if (idx_itr->second > max_buffer_size_index) {
- max_buffer_size_index = idx_itr->second;
+ max_buffer_size_index = idx_itr->second;
}
- });
+ });
- // Get (or create, on first call) the uniform buffer that will receive the
- // size of each storage buffer in the module.
- const ast::Variable* buffer_size_ubo = nullptr;
- auto get_ubo = [&]() {
- if (!buffer_size_ubo) {
- // Emit an array<vec4<u32>, N>, where N is 1/4 number of elements.
- // We do this because UBOs require an element stride that is 16-byte
- // aligned.
- auto* buffer_size_struct = ctx.dst->Structure(
- ctx.dst->Sym(),
- {ctx.dst->Member(
- kBufferSizeMemberName,
- ctx.dst->ty.array(ctx.dst->ty.vec4(ctx.dst->ty.u32()),
- (max_buffer_size_index / 4) + 1))});
- buffer_size_ubo = ctx.dst->Global(
- ctx.dst->Sym(), ctx.dst->ty.Of(buffer_size_struct),
- ast::StorageClass::kUniform,
- ast::AttributeList{ctx.dst->GroupAndBinding(
- cfg->ubo_binding.group, cfg->ubo_binding.binding)});
- }
- return buffer_size_ubo;
- };
-
- std::unordered_set<uint32_t> used_size_indices;
-
- IterateArrayLengthOnStorageVar(
- ctx, [&](const ast::CallExpression* call_expr,
- const sem::VariableUser* storage_buffer_sem,
- const sem::GlobalVariable* var) {
- auto binding = var->BindingPoint();
- auto idx_itr = cfg->bindpoint_to_size_index.find(binding);
- if (idx_itr == cfg->bindpoint_to_size_index.end()) {
- return;
+ // Get (or create, on first call) the uniform buffer that will receive the
+ // size of each storage buffer in the module.
+ const ast::Variable* buffer_size_ubo = nullptr;
+ auto get_ubo = [&]() {
+ if (!buffer_size_ubo) {
+ // Emit an array<vec4<u32>, N>, where N is 1/4 number of elements.
+ // We do this because UBOs require an element stride that is 16-byte
+ // aligned.
+ auto* buffer_size_struct = ctx.dst->Structure(
+ ctx.dst->Sym(),
+ {ctx.dst->Member(kBufferSizeMemberName,
+ ctx.dst->ty.array(ctx.dst->ty.vec4(ctx.dst->ty.u32()),
+ (max_buffer_size_index / 4) + 1))});
+ buffer_size_ubo = ctx.dst->Global(
+ ctx.dst->Sym(), ctx.dst->ty.Of(buffer_size_struct), ast::StorageClass::kUniform,
+ ast::AttributeList{
+ ctx.dst->GroupAndBinding(cfg->ubo_binding.group, cfg->ubo_binding.binding)});
}
+ return buffer_size_ubo;
+ };
- uint32_t size_index = idx_itr->second;
- used_size_indices.insert(size_index);
+ std::unordered_set<uint32_t> used_size_indices;
- // Load the total storage buffer size from the UBO.
- uint32_t array_index = size_index / 4;
- auto* vec_expr = ctx.dst->IndexAccessor(
- ctx.dst->MemberAccessor(get_ubo()->symbol, kBufferSizeMemberName),
- array_index);
- uint32_t vec_index = size_index % 4;
- auto* total_storage_buffer_size =
- ctx.dst->IndexAccessor(vec_expr, vec_index);
+ IterateArrayLengthOnStorageVar(
+ ctx, [&](const ast::CallExpression* call_expr, const sem::VariableUser* storage_buffer_sem,
+ const sem::GlobalVariable* var) {
+ auto binding = var->BindingPoint();
+ auto idx_itr = cfg->bindpoint_to_size_index.find(binding);
+ if (idx_itr == cfg->bindpoint_to_size_index.end()) {
+ return;
+ }
- // Calculate actual array length
- // total_storage_buffer_size - array_offset
- // array_length = ----------------------------------------
- // array_stride
- const ast::Expression* total_size = total_storage_buffer_size;
- auto* storage_buffer_type = storage_buffer_sem->Type()->UnwrapRef();
- const sem::Array* array_type = nullptr;
- if (auto* str = storage_buffer_type->As<sem::Struct>()) {
- // The variable is a struct, so subtract the byte offset of the array
- // member.
- auto* array_member_sem = str->Members().back();
- array_type = array_member_sem->Type()->As<sem::Array>();
- total_size = ctx.dst->Sub(total_storage_buffer_size,
- array_member_sem->Offset());
- } else if (auto* arr = storage_buffer_type->As<sem::Array>()) {
- array_type = arr;
- } else {
- TINT_ICE(Transform, ctx.dst->Diagnostics())
- << "expected form of arrayLength argument to be &array_var or "
- "&struct_var.array_member";
- return;
- }
- auto* array_length = ctx.dst->Div(total_size, array_type->Stride());
+ uint32_t size_index = idx_itr->second;
+ used_size_indices.insert(size_index);
- ctx.Replace(call_expr, array_length);
- });
+ // Load the total storage buffer size from the UBO.
+ uint32_t array_index = size_index / 4;
+ auto* vec_expr = ctx.dst->IndexAccessor(
+ ctx.dst->MemberAccessor(get_ubo()->symbol, kBufferSizeMemberName), array_index);
+ uint32_t vec_index = size_index % 4;
+ auto* total_storage_buffer_size = ctx.dst->IndexAccessor(vec_expr, vec_index);
- ctx.Clone();
+ // Calculate actual array length
+ // total_storage_buffer_size - array_offset
+ // array_length = ----------------------------------------
+ // array_stride
+ const ast::Expression* total_size = total_storage_buffer_size;
+ auto* storage_buffer_type = storage_buffer_sem->Type()->UnwrapRef();
+ const sem::Array* array_type = nullptr;
+ if (auto* str = storage_buffer_type->As<sem::Struct>()) {
+ // The variable is a struct, so subtract the byte offset of the array
+ // member.
+ auto* array_member_sem = str->Members().back();
+ array_type = array_member_sem->Type()->As<sem::Array>();
+ total_size = ctx.dst->Sub(total_storage_buffer_size, array_member_sem->Offset());
+ } else if (auto* arr = storage_buffer_type->As<sem::Array>()) {
+ array_type = arr;
+ } else {
+ TINT_ICE(Transform, ctx.dst->Diagnostics())
+ << "expected form of arrayLength argument to be &array_var or "
+ "&struct_var.array_member";
+ return;
+ }
+ auto* array_length = ctx.dst->Div(total_size, array_type->Stride());
- outputs.Add<Result>(used_size_indices);
+ ctx.Replace(call_expr, array_length);
+ });
+
+ ctx.Clone();
+
+ outputs.Add<Result>(used_size_indices);
}
-ArrayLengthFromUniform::Config::Config(sem::BindingPoint ubo_bp)
- : ubo_binding(ubo_bp) {}
+ArrayLengthFromUniform::Config::Config(sem::BindingPoint ubo_bp) : ubo_binding(ubo_bp) {}
ArrayLengthFromUniform::Config::Config(const Config&) = default;
-ArrayLengthFromUniform::Config& ArrayLengthFromUniform::Config::operator=(
- const Config&) = default;
+ArrayLengthFromUniform::Config& ArrayLengthFromUniform::Config::operator=(const Config&) = default;
ArrayLengthFromUniform::Config::~Config() = default;
-ArrayLengthFromUniform::Result::Result(
- std::unordered_set<uint32_t> used_size_indices_in)
+ArrayLengthFromUniform::Result::Result(std::unordered_set<uint32_t> used_size_indices_in)
: used_size_indices(std::move(used_size_indices_in)) {}
ArrayLengthFromUniform::Result::Result(const Result&) = default;
ArrayLengthFromUniform::Result::~Result() = default;
diff --git a/src/tint/transform/array_length_from_uniform.h b/src/tint/transform/array_length_from_uniform.h
index 9a3a5d5..c34c529 100644
--- a/src/tint/transform/array_length_from_uniform.h
+++ b/src/tint/transform/array_length_from_uniform.h
@@ -52,71 +52,67 @@
///
/// @note Depends on the following transforms to have been run first:
/// * SimplifyPointers
-class ArrayLengthFromUniform final
- : public Castable<ArrayLengthFromUniform, Transform> {
- public:
- /// Constructor
- ArrayLengthFromUniform();
- /// Destructor
- ~ArrayLengthFromUniform() override;
-
- /// Configuration options for the ArrayLengthFromUniform transform.
- struct Config final : public Castable<Data, transform::Data> {
+class ArrayLengthFromUniform final : public Castable<ArrayLengthFromUniform, Transform> {
+ public:
/// Constructor
- /// @param ubo_bp the binding point to use for the generated uniform buffer.
- explicit Config(sem::BindingPoint ubo_bp);
-
- /// Copy constructor
- Config(const Config&);
-
- /// Copy assignment
- /// @return this Config
- Config& operator=(const Config&);
-
+ ArrayLengthFromUniform();
/// Destructor
- ~Config() override;
+ ~ArrayLengthFromUniform() override;
- /// The binding point to use for the generated uniform buffer.
- sem::BindingPoint ubo_binding;
+ /// Configuration options for the ArrayLengthFromUniform transform.
+ struct Config final : public Castable<Data, transform::Data> {
+ /// Constructor
+ /// @param ubo_bp the binding point to use for the generated uniform buffer.
+ explicit Config(sem::BindingPoint ubo_bp);
- /// The mapping from binding point to the index for the buffer size lookup.
- std::unordered_map<sem::BindingPoint, uint32_t> bindpoint_to_size_index;
- };
+ /// Copy constructor
+ Config(const Config&);
- /// Information produced about what the transform did.
- /// If there were no calls to the arrayLength() builtin, then no Result will
- /// be emitted.
- struct Result final : public Castable<Result, transform::Data> {
- /// Constructor
- /// @param used_size_indices Indices into the UBO that are statically used.
- explicit Result(std::unordered_set<uint32_t> used_size_indices);
+ /// Copy assignment
+ /// @return this Config
+ Config& operator=(const Config&);
- /// Copy constructor
- Result(const Result&);
+ /// Destructor
+ ~Config() override;
- /// Destructor
- ~Result() override;
+ /// The binding point to use for the generated uniform buffer.
+ sem::BindingPoint ubo_binding;
- /// Indices into the UBO that are statically used.
- std::unordered_set<uint32_t> used_size_indices;
- };
+ /// The mapping from binding point to the index for the buffer size lookup.
+ std::unordered_map<sem::BindingPoint, uint32_t> bindpoint_to_size_index;
+ };
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program,
- const DataMap& data = {}) const override;
+ /// Information produced about what the transform did.
+ /// If there were no calls to the arrayLength() builtin, then no Result will
+ /// be emitted.
+ struct Result final : public Castable<Result, transform::Data> {
+ /// Constructor
+ /// @param used_size_indices Indices into the UBO that are statically used.
+ explicit Result(std::unordered_set<uint32_t> used_size_indices);
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx,
- const DataMap& inputs,
- DataMap& outputs) const override;
+ /// Copy constructor
+ Result(const Result&);
+
+ /// Destructor
+ ~Result() override;
+
+ /// Indices into the UBO that are statically used.
+ std::unordered_set<uint32_t> used_size_indices;
+ };
+
+ /// @param program the program to inspect
+ /// @param data optional extra transform-specific input data
+ /// @returns true if this transform should be run for the given program
+ bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
+
+ protected:
+ /// Runs the transform using the CloneContext built for transforming a
+ /// program. Run() is responsible for calling Clone() on the CloneContext.
+ /// @param ctx the CloneContext primed with the input program and
+ /// ProgramBuilder
+ /// @param inputs optional extra transform-specific input data
+ /// @param outputs optional extra transform-specific output data
+ void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/array_length_from_uniform_test.cc b/src/tint/transform/array_length_from_uniform_test.cc
index 42a334c..ee0d4b1 100644
--- a/src/tint/transform/array_length_from_uniform_test.cc
+++ b/src/tint/transform/array_length_from_uniform_test.cc
@@ -26,13 +26,13 @@
using ArrayLengthFromUniformTest = TransformTest;
TEST_F(ArrayLengthFromUniformTest, ShouldRunEmptyModule) {
- auto* src = R"()";
+ auto* src = R"()";
- EXPECT_FALSE(ShouldRun<ArrayLengthFromUniform>(src));
+ EXPECT_FALSE(ShouldRun<ArrayLengthFromUniform>(src));
}
TEST_F(ArrayLengthFromUniformTest, ShouldRunNoArrayLength) {
- auto* src = R"(
+ auto* src = R"(
struct SB {
x : i32,
arr : array<i32>,
@@ -45,11 +45,11 @@
}
)";
- EXPECT_FALSE(ShouldRun<ArrayLengthFromUniform>(src));
+ EXPECT_FALSE(ShouldRun<ArrayLengthFromUniform>(src));
}
TEST_F(ArrayLengthFromUniformTest, ShouldRunWithArrayLength) {
- auto* src = R"(
+ auto* src = R"(
struct SB {
x : i32,
arr : array<i32>,
@@ -63,11 +63,11 @@
}
)";
- EXPECT_TRUE(ShouldRun<ArrayLengthFromUniform>(src));
+ EXPECT_TRUE(ShouldRun<ArrayLengthFromUniform>(src));
}
TEST_F(ArrayLengthFromUniformTest, Error_MissingTransformData) {
- auto* src = R"(
+ auto* src = R"(
struct SB {
x : i32,
arr : array<i32>,
@@ -81,17 +81,17 @@
}
)";
- auto* expect =
- "error: missing transform data for "
- "tint::transform::ArrayLengthFromUniform";
+ auto* expect =
+ "error: missing transform data for "
+ "tint::transform::ArrayLengthFromUniform";
- auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src);
+ auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ArrayLengthFromUniformTest, Basic) {
- auto* src = R"(
+ auto* src = R"(
@group(0) @binding(0) var<storage, read> sb : array<i32>;
@stage(compute) @workgroup_size(1)
@@ -100,7 +100,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol {
buffer_size : array<vec4<u32>, 1u>,
}
@@ -115,21 +115,21 @@
}
)";
- ArrayLengthFromUniform::Config cfg({0, 30u});
- cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0);
+ ArrayLengthFromUniform::Config cfg({0, 30u});
+ cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0);
- DataMap data;
- data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
+ DataMap data;
+ data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
- auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
+ auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
- EXPECT_EQ(expect, str(got));
- EXPECT_EQ(std::unordered_set<uint32_t>({0}),
- got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
+ EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(std::unordered_set<uint32_t>({0}),
+ got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
}
TEST_F(ArrayLengthFromUniformTest, BasicInStruct) {
- auto* src = R"(
+ auto* src = R"(
struct SB {
x : i32,
arr : array<i32>,
@@ -143,7 +143,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol {
buffer_size : array<vec4<u32>, 1u>,
}
@@ -163,21 +163,21 @@
}
)";
- ArrayLengthFromUniform::Config cfg({0, 30u});
- cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0);
+ ArrayLengthFromUniform::Config cfg({0, 30u});
+ cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0);
- DataMap data;
- data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
+ DataMap data;
+ data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
- auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
+ auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
- EXPECT_EQ(expect, str(got));
- EXPECT_EQ(std::unordered_set<uint32_t>({0}),
- got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
+ EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(std::unordered_set<uint32_t>({0}),
+ got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
}
TEST_F(ArrayLengthFromUniformTest, MultipleStorageBuffers) {
- auto* src = R"(
+ auto* src = R"(
struct SB1 {
x : i32,
arr1 : array<i32>,
@@ -208,7 +208,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol {
buffer_size : array<vec4<u32>, 2u>,
}
@@ -251,25 +251,25 @@
}
)";
- ArrayLengthFromUniform::Config cfg({0, 30u});
- cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 2u}, 0);
- cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{1u, 2u}, 1);
- cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{2u, 2u}, 2);
- cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{3u, 2u}, 3);
- cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{4u, 2u}, 4);
+ ArrayLengthFromUniform::Config cfg({0, 30u});
+ cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 2u}, 0);
+ cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{1u, 2u}, 1);
+ cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{2u, 2u}, 2);
+ cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{3u, 2u}, 3);
+ cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{4u, 2u}, 4);
- DataMap data;
- data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
+ DataMap data;
+ data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
- auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
+ auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
- EXPECT_EQ(expect, str(got));
- EXPECT_EQ(std::unordered_set<uint32_t>({0, 1, 2, 3, 4}),
- got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
+ EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(std::unordered_set<uint32_t>({0, 1, 2, 3, 4}),
+ got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
}
TEST_F(ArrayLengthFromUniformTest, MultipleUnusedStorageBuffers) {
- auto* src = R"(
+ auto* src = R"(
struct SB1 {
x : i32,
arr1 : array<i32>,
@@ -297,7 +297,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol {
buffer_size : array<vec4<u32>, 1u>,
}
@@ -337,25 +337,25 @@
}
)";
- ArrayLengthFromUniform::Config cfg({0, 30u});
- cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 2u}, 0);
- cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{1u, 2u}, 1);
- cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{2u, 2u}, 2);
- cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{3u, 2u}, 3);
- cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{4u, 2u}, 4);
+ ArrayLengthFromUniform::Config cfg({0, 30u});
+ cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 2u}, 0);
+ cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{1u, 2u}, 1);
+ cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{2u, 2u}, 2);
+ cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{3u, 2u}, 3);
+ cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{4u, 2u}, 4);
- DataMap data;
- data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
+ DataMap data;
+ data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
- auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
+ auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
- EXPECT_EQ(expect, str(got));
- EXPECT_EQ(std::unordered_set<uint32_t>({0, 2}),
- got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
+ EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(std::unordered_set<uint32_t>({0, 2}),
+ got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
}
TEST_F(ArrayLengthFromUniformTest, NoArrayLengthCalls) {
- auto* src = R"(
+ auto* src = R"(
struct SB {
x : i32,
arr : array<i32>,
@@ -369,20 +369,20 @@
}
)";
- ArrayLengthFromUniform::Config cfg({0, 30u});
- cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0);
+ ArrayLengthFromUniform::Config cfg({0, 30u});
+ cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0);
- DataMap data;
- data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
+ DataMap data;
+ data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
- auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
+ auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
- EXPECT_EQ(src, str(got));
- EXPECT_EQ(got.data.Get<ArrayLengthFromUniform::Result>(), nullptr);
+ EXPECT_EQ(src, str(got));
+ EXPECT_EQ(got.data.Get<ArrayLengthFromUniform::Result>(), nullptr);
}
TEST_F(ArrayLengthFromUniformTest, MissingBindingPointToIndexMapping) {
- auto* src = R"(
+ auto* src = R"(
struct SB1 {
x : i32,
arr1 : array<i32>,
@@ -405,7 +405,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol {
buffer_size : array<vec4<u32>, 1u>,
}
@@ -434,21 +434,21 @@
}
)";
- ArrayLengthFromUniform::Config cfg({0, 30u});
- cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 2}, 0);
+ ArrayLengthFromUniform::Config cfg({0, 30u});
+ cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 2}, 0);
- DataMap data;
- data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
+ DataMap data;
+ data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
- auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
+ auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
- EXPECT_EQ(expect, str(got));
- EXPECT_EQ(std::unordered_set<uint32_t>({0}),
- got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
+ EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(std::unordered_set<uint32_t>({0}),
+ got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
}
TEST_F(ArrayLengthFromUniformTest, OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn main() {
var len : u32 = arrayLength(&sb.arr);
@@ -462,7 +462,7 @@
};
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol {
buffer_size : array<vec4<u32>, 1u>,
}
@@ -482,17 +482,17 @@
}
)";
- ArrayLengthFromUniform::Config cfg({0, 30u});
- cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0);
+ ArrayLengthFromUniform::Config cfg({0, 30u});
+ cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0);
- DataMap data;
- data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
+ DataMap data;
+ data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
- auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
+ auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
- EXPECT_EQ(expect, str(got));
- EXPECT_EQ(std::unordered_set<uint32_t>({0}),
- got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
+ EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(std::unordered_set<uint32_t>({0}),
+ got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
}
} // namespace
diff --git a/src/tint/transform/binding_remapper.cc b/src/tint/transform/binding_remapper.cc
index 3934b20..e3b7afd 100644
--- a/src/tint/transform/binding_remapper.cc
+++ b/src/tint/transform/binding_remapper.cc
@@ -28,9 +28,7 @@
namespace tint::transform {
-BindingRemapper::Remappings::Remappings(BindingPoints bp,
- AccessControls ac,
- bool may_collide)
+BindingRemapper::Remappings::Remappings(BindingPoints bp, AccessControls ac, bool may_collide)
: binding_points(std::move(bp)),
access_controls(std::move(ac)),
allow_collisions(may_collide) {}
@@ -42,120 +40,112 @@
BindingRemapper::~BindingRemapper() = default;
bool BindingRemapper::ShouldRun(const Program*, const DataMap& inputs) const {
- if (auto* remappings = inputs.Get<Remappings>()) {
- return !remappings->binding_points.empty() ||
- !remappings->access_controls.empty();
- }
- return false;
+ if (auto* remappings = inputs.Get<Remappings>()) {
+ return !remappings->binding_points.empty() || !remappings->access_controls.empty();
+ }
+ return false;
}
-void BindingRemapper::Run(CloneContext& ctx,
- const DataMap& inputs,
- DataMap&) const {
- auto* remappings = inputs.Get<Remappings>();
- if (!remappings) {
- ctx.dst->Diagnostics().add_error(
- diag::System::Transform,
- "missing transform data for " + std::string(TypeInfo().name));
- return;
- }
-
- // A set of post-remapped binding points that need to be decorated with a
- // DisableValidationAttribute to disable binding-point-collision validation
- std::unordered_set<sem::BindingPoint> add_collision_attr;
-
- if (remappings->allow_collisions) {
- // Scan for binding point collisions generated by this transform.
- // Populate all collisions in the `add_collision_attr` set.
- for (auto* func_ast : ctx.src->AST().Functions()) {
- if (!func_ast->IsEntryPoint()) {
- continue;
- }
- auto* func = ctx.src->Sem().Get(func_ast);
- std::unordered_map<sem::BindingPoint, int> binding_point_counts;
- for (auto* var : func->TransitivelyReferencedGlobals()) {
- if (auto binding_point = var->Declaration()->BindingPoint()) {
- BindingPoint from{binding_point.group->value,
- binding_point.binding->value};
- auto bp_it = remappings->binding_points.find(from);
- if (bp_it != remappings->binding_points.end()) {
- // Remapped
- BindingPoint to = bp_it->second;
- if (binding_point_counts[to]++) {
- add_collision_attr.emplace(to);
- }
- } else {
- // No remapping
- if (binding_point_counts[from]++) {
- add_collision_attr.emplace(from);
- }
- }
- }
- }
+void BindingRemapper::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const {
+ auto* remappings = inputs.Get<Remappings>();
+ if (!remappings) {
+ ctx.dst->Diagnostics().add_error(
+ diag::System::Transform, "missing transform data for " + std::string(TypeInfo().name));
+ return;
}
- }
- for (auto* var : ctx.src->AST().GlobalVariables()) {
- if (auto binding_point = var->BindingPoint()) {
- // The original binding point
- BindingPoint from{binding_point.group->value,
- binding_point.binding->value};
+ // A set of post-remapped binding points that need to be decorated with a
+ // DisableValidationAttribute to disable binding-point-collision validation
+ std::unordered_set<sem::BindingPoint> add_collision_attr;
- // The binding point after remapping
- BindingPoint bp = from;
-
- // Replace any group or binding attributes.
- // Note: This has to be performed *before* remapping access controls, as
- // `ctx.Clone(var->attributes)` depend on these replacements.
- auto bp_it = remappings->binding_points.find(from);
- if (bp_it != remappings->binding_points.end()) {
- BindingPoint to = bp_it->second;
- auto* new_group = ctx.dst->create<ast::GroupAttribute>(to.group);
- auto* new_binding = ctx.dst->create<ast::BindingAttribute>(to.binding);
-
- ctx.Replace(binding_point.group, new_group);
- ctx.Replace(binding_point.binding, new_binding);
- bp = to;
- }
-
- // Replace any access controls.
- auto ac_it = remappings->access_controls.find(from);
- if (ac_it != remappings->access_controls.end()) {
- ast::Access ac = ac_it->second;
- if (ac > ast::Access::kLastValid) {
- ctx.dst->Diagnostics().add_error(
- diag::System::Transform,
- "invalid access mode (" +
- std::to_string(static_cast<uint32_t>(ac)) + ")");
- return;
+ if (remappings->allow_collisions) {
+ // Scan for binding point collisions generated by this transform.
+ // Populate all collisions in the `add_collision_attr` set.
+ for (auto* func_ast : ctx.src->AST().Functions()) {
+ if (!func_ast->IsEntryPoint()) {
+ continue;
+ }
+ auto* func = ctx.src->Sem().Get(func_ast);
+ std::unordered_map<sem::BindingPoint, int> binding_point_counts;
+ for (auto* var : func->TransitivelyReferencedGlobals()) {
+ if (auto binding_point = var->Declaration()->BindingPoint()) {
+ BindingPoint from{binding_point.group->value, binding_point.binding->value};
+ auto bp_it = remappings->binding_points.find(from);
+ if (bp_it != remappings->binding_points.end()) {
+ // Remapped
+ BindingPoint to = bp_it->second;
+ if (binding_point_counts[to]++) {
+ add_collision_attr.emplace(to);
+ }
+ } else {
+ // No remapping
+ if (binding_point_counts[from]++) {
+ add_collision_attr.emplace(from);
+ }
+ }
+ }
+ }
}
- auto* sem = ctx.src->Sem().Get(var);
- if (sem->StorageClass() != ast::StorageClass::kStorage) {
- ctx.dst->Diagnostics().add_error(
- diag::System::Transform,
- "cannot apply access control to variable with storage class " +
- std::string(ast::ToString(sem->StorageClass())));
- return;
- }
- auto* ty = sem->Type()->UnwrapRef();
- const ast::Type* inner_ty = CreateASTTypeFor(ctx, ty);
- auto* new_var = ctx.dst->create<ast::Variable>(
- ctx.Clone(var->source), ctx.Clone(var->symbol),
- var->declared_storage_class, ac, inner_ty, false, false,
- ctx.Clone(var->constructor), ctx.Clone(var->attributes));
- ctx.Replace(var, new_var);
- }
-
- // Add `DisableValidationAttribute`s if required
- if (add_collision_attr.count(bp)) {
- auto* attribute =
- ctx.dst->Disable(ast::DisabledValidation::kBindingPointCollision);
- ctx.InsertBefore(var->attributes, *var->attributes.begin(), attribute);
- }
}
- }
- ctx.Clone();
+ for (auto* var : ctx.src->AST().GlobalVariables()) {
+ if (auto binding_point = var->BindingPoint()) {
+ // The original binding point
+ BindingPoint from{binding_point.group->value, binding_point.binding->value};
+
+ // The binding point after remapping
+ BindingPoint bp = from;
+
+ // Replace any group or binding attributes.
+ // Note: This has to be performed *before* remapping access controls, as
+ // `ctx.Clone(var->attributes)` depend on these replacements.
+ auto bp_it = remappings->binding_points.find(from);
+ if (bp_it != remappings->binding_points.end()) {
+ BindingPoint to = bp_it->second;
+ auto* new_group = ctx.dst->create<ast::GroupAttribute>(to.group);
+ auto* new_binding = ctx.dst->create<ast::BindingAttribute>(to.binding);
+
+ ctx.Replace(binding_point.group, new_group);
+ ctx.Replace(binding_point.binding, new_binding);
+ bp = to;
+ }
+
+ // Replace any access controls.
+ auto ac_it = remappings->access_controls.find(from);
+ if (ac_it != remappings->access_controls.end()) {
+ ast::Access ac = ac_it->second;
+ if (ac > ast::Access::kLastValid) {
+ ctx.dst->Diagnostics().add_error(
+ diag::System::Transform,
+ "invalid access mode (" + std::to_string(static_cast<uint32_t>(ac)) + ")");
+ return;
+ }
+ auto* sem = ctx.src->Sem().Get(var);
+ if (sem->StorageClass() != ast::StorageClass::kStorage) {
+ ctx.dst->Diagnostics().add_error(
+ diag::System::Transform,
+ "cannot apply access control to variable with storage class " +
+ std::string(ast::ToString(sem->StorageClass())));
+ return;
+ }
+ auto* ty = sem->Type()->UnwrapRef();
+ const ast::Type* inner_ty = CreateASTTypeFor(ctx, ty);
+ auto* new_var = ctx.dst->create<ast::Variable>(
+ ctx.Clone(var->source), ctx.Clone(var->symbol), var->declared_storage_class, ac,
+ inner_ty, false, false, ctx.Clone(var->constructor),
+ ctx.Clone(var->attributes));
+ ctx.Replace(var, new_var);
+ }
+
+ // Add `DisableValidationAttribute`s if required
+ if (add_collision_attr.count(bp)) {
+ auto* attribute = ctx.dst->Disable(ast::DisabledValidation::kBindingPointCollision);
+ ctx.InsertBefore(var->attributes, *var->attributes.begin(), attribute);
+ }
+ }
+ }
+
+ ctx.Clone();
}
} // namespace tint::transform
diff --git a/src/tint/transform/binding_remapper.h b/src/tint/transform/binding_remapper.h
index 3e9f613..77fc5bc 100644
--- a/src/tint/transform/binding_remapper.h
+++ b/src/tint/transform/binding_remapper.h
@@ -29,60 +29,57 @@
/// BindingRemapper is a transform used to remap resource binding points and
/// access controls.
class BindingRemapper final : public Castable<BindingRemapper, Transform> {
- public:
- /// BindingPoints is a map of old binding point to new binding point
- using BindingPoints = std::unordered_map<BindingPoint, BindingPoint>;
+ public:
+ /// BindingPoints is a map of old binding point to new binding point
+ using BindingPoints = std::unordered_map<BindingPoint, BindingPoint>;
- /// AccessControls is a map of old binding point to new access control
- using AccessControls = std::unordered_map<BindingPoint, ast::Access>;
+ /// AccessControls is a map of old binding point to new access control
+ using AccessControls = std::unordered_map<BindingPoint, ast::Access>;
- /// Remappings is consumed by the BindingRemapper transform.
- /// Data holds information about shader usage and constant buffer offsets.
- struct Remappings final : public Castable<Data, transform::Data> {
+ /// Remappings is consumed by the BindingRemapper transform.
+ /// Data holds information about shader usage and constant buffer offsets.
+ struct Remappings final : public Castable<Data, transform::Data> {
+ /// Constructor
+ /// @param bp a map of new binding points
+ /// @param ac a map of new access controls
+ /// @param may_collide If true, then validation will be disabled for
+ /// binding point collisions generated by this transform
+ Remappings(BindingPoints bp, AccessControls ac, bool may_collide = true);
+
+ /// Copy constructor
+ Remappings(const Remappings&);
+
+ /// Destructor
+ ~Remappings() override;
+
+ /// A map of old binding point to new binding point
+ const BindingPoints binding_points;
+
+ /// A map of old binding point to new access controls
+ const AccessControls access_controls;
+
+ /// If true, then validation will be disabled for binding point collisions
+ /// generated by this transform
+ const bool allow_collisions;
+ };
+
/// Constructor
- /// @param bp a map of new binding points
- /// @param ac a map of new access controls
- /// @param may_collide If true, then validation will be disabled for
- /// binding point collisions generated by this transform
- Remappings(BindingPoints bp, AccessControls ac, bool may_collide = true);
+ BindingRemapper();
+ ~BindingRemapper() override;
- /// Copy constructor
- Remappings(const Remappings&);
+ /// @param program the program to inspect
+ /// @param data optional extra transform-specific input data
+ /// @returns true if this transform should be run for the given program
+ bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
- /// Destructor
- ~Remappings() override;
-
- /// A map of old binding point to new binding point
- const BindingPoints binding_points;
-
- /// A map of old binding point to new access controls
- const AccessControls access_controls;
-
- /// If true, then validation will be disabled for binding point collisions
- /// generated by this transform
- const bool allow_collisions;
- };
-
- /// Constructor
- BindingRemapper();
- ~BindingRemapper() override;
-
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program,
- const DataMap& data = {}) const override;
-
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx,
- const DataMap& inputs,
- DataMap& outputs) const override;
+ protected:
+ /// Runs the transform using the CloneContext built for transforming a
+ /// program. Run() is responsible for calling Clone() on the CloneContext.
+ /// @param ctx the CloneContext primed with the input program and
+ /// ProgramBuilder
+ /// @param inputs optional extra transform-specific input data
+ /// @param outputs optional extra transform-specific output data
+ void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/binding_remapper_test.cc b/src/tint/transform/binding_remapper_test.cc
index 70c7232..29a96c3 100644
--- a/src/tint/transform/binding_remapper_test.cc
+++ b/src/tint/transform/binding_remapper_test.cc
@@ -24,48 +24,48 @@
using BindingRemapperTest = TransformTest;
TEST_F(BindingRemapperTest, ShouldRunNoRemappings) {
- auto* src = R"()";
+ auto* src = R"()";
- EXPECT_FALSE(ShouldRun<BindingRemapper>(src));
+ EXPECT_FALSE(ShouldRun<BindingRemapper>(src));
}
TEST_F(BindingRemapperTest, ShouldRunEmptyRemappings) {
- auto* src = R"()";
+ auto* src = R"()";
- DataMap data;
- data.Add<BindingRemapper::Remappings>(BindingRemapper::BindingPoints{},
- BindingRemapper::AccessControls{});
+ DataMap data;
+ data.Add<BindingRemapper::Remappings>(BindingRemapper::BindingPoints{},
+ BindingRemapper::AccessControls{});
- EXPECT_FALSE(ShouldRun<BindingRemapper>(src, data));
+ EXPECT_FALSE(ShouldRun<BindingRemapper>(src, data));
}
TEST_F(BindingRemapperTest, ShouldRunBindingPointRemappings) {
- auto* src = R"()";
+ auto* src = R"()";
- DataMap data;
- data.Add<BindingRemapper::Remappings>(
- BindingRemapper::BindingPoints{
- {{2, 1}, {1, 2}},
- },
- BindingRemapper::AccessControls{});
+ DataMap data;
+ data.Add<BindingRemapper::Remappings>(
+ BindingRemapper::BindingPoints{
+ {{2, 1}, {1, 2}},
+ },
+ BindingRemapper::AccessControls{});
- EXPECT_TRUE(ShouldRun<BindingRemapper>(src, data));
+ EXPECT_TRUE(ShouldRun<BindingRemapper>(src, data));
}
TEST_F(BindingRemapperTest, ShouldRunAccessControlRemappings) {
- auto* src = R"()";
+ auto* src = R"()";
- DataMap data;
- data.Add<BindingRemapper::Remappings>(BindingRemapper::BindingPoints{},
- BindingRemapper::AccessControls{
- {{2, 1}, ast::Access::kWrite},
- });
+ DataMap data;
+ data.Add<BindingRemapper::Remappings>(BindingRemapper::BindingPoints{},
+ BindingRemapper::AccessControls{
+ {{2, 1}, ast::Access::kWrite},
+ });
- EXPECT_TRUE(ShouldRun<BindingRemapper>(src, data));
+ EXPECT_TRUE(ShouldRun<BindingRemapper>(src, data));
}
TEST_F(BindingRemapperTest, NoRemappings) {
- auto* src = R"(
+ auto* src = R"(
struct S {
a : f32,
}
@@ -79,18 +79,18 @@
}
)";
- auto* expect = src;
+ auto* expect = src;
- DataMap data;
- data.Add<BindingRemapper::Remappings>(BindingRemapper::BindingPoints{},
- BindingRemapper::AccessControls{});
- auto got = Run<BindingRemapper>(src, data);
+ DataMap data;
+ data.Add<BindingRemapper::Remappings>(BindingRemapper::BindingPoints{},
+ BindingRemapper::AccessControls{});
+ auto got = Run<BindingRemapper>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(BindingRemapperTest, RemapBindingPoints) {
- auto* src = R"(
+ auto* src = R"(
struct S {
a : f32,
};
@@ -104,7 +104,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct S {
a : f32,
}
@@ -118,21 +118,21 @@
}
)";
- DataMap data;
- data.Add<BindingRemapper::Remappings>(
- BindingRemapper::BindingPoints{
- {{2, 1}, {1, 2}}, // Remap
- {{4, 5}, {6, 7}}, // Not found
- // Keep @group(3) @binding(2) as is
- },
- BindingRemapper::AccessControls{});
- auto got = Run<BindingRemapper>(src, data);
+ DataMap data;
+ data.Add<BindingRemapper::Remappings>(
+ BindingRemapper::BindingPoints{
+ {{2, 1}, {1, 2}}, // Remap
+ {{4, 5}, {6, 7}}, // Not found
+ // Keep @group(3) @binding(2) as is
+ },
+ BindingRemapper::AccessControls{});
+ auto got = Run<BindingRemapper>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(BindingRemapperTest, RemapAccessControls) {
- auto* src = R"(
+ auto* src = R"(
struct S {
a : f32,
};
@@ -148,7 +148,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct S {
a : f32,
}
@@ -164,21 +164,21 @@
}
)";
- DataMap data;
- data.Add<BindingRemapper::Remappings>(
- BindingRemapper::BindingPoints{},
- BindingRemapper::AccessControls{
- {{2, 1}, ast::Access::kWrite}, // Modify access control
- // Keep @group(3) @binding(2) as is
- {{4, 3}, ast::Access::kRead}, // Add access control
- });
- auto got = Run<BindingRemapper>(src, data);
+ DataMap data;
+ data.Add<BindingRemapper::Remappings>(
+ BindingRemapper::BindingPoints{},
+ BindingRemapper::AccessControls{
+ {{2, 1}, ast::Access::kWrite}, // Modify access control
+ // Keep @group(3) @binding(2) as is
+ {{4, 3}, ast::Access::kRead}, // Add access control
+ });
+ auto got = Run<BindingRemapper>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(BindingRemapperTest, RemapAll) {
- auto* src = R"(
+ auto* src = R"(
struct S {
a : f32,
};
@@ -192,7 +192,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct S {
a : f32,
}
@@ -206,23 +206,23 @@
}
)";
- DataMap data;
- data.Add<BindingRemapper::Remappings>(
- BindingRemapper::BindingPoints{
- {{2, 1}, {4, 5}},
- {{3, 2}, {6, 7}},
- },
- BindingRemapper::AccessControls{
- {{2, 1}, ast::Access::kWrite},
- {{3, 2}, ast::Access::kWrite},
- });
- auto got = Run<BindingRemapper>(src, data);
+ DataMap data;
+ data.Add<BindingRemapper::Remappings>(
+ BindingRemapper::BindingPoints{
+ {{2, 1}, {4, 5}},
+ {{3, 2}, {6, 7}},
+ },
+ BindingRemapper::AccessControls{
+ {{2, 1}, ast::Access::kWrite},
+ {{3, 2}, ast::Access::kWrite},
+ });
+ auto got = Run<BindingRemapper>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(BindingRemapperTest, BindingCollisionsSameEntryPoint) {
- auto* src = R"(
+ auto* src = R"(
struct S {
i : i32,
};
@@ -241,7 +241,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct S {
i : i32,
}
@@ -260,21 +260,21 @@
}
)";
- DataMap data;
- data.Add<BindingRemapper::Remappings>(
- BindingRemapper::BindingPoints{
- {{2, 1}, {1, 1}},
- {{3, 2}, {1, 1}},
- {{4, 3}, {5, 4}},
- },
- BindingRemapper::AccessControls{}, true);
- auto got = Run<BindingRemapper>(src, data);
+ DataMap data;
+ data.Add<BindingRemapper::Remappings>(
+ BindingRemapper::BindingPoints{
+ {{2, 1}, {1, 1}},
+ {{3, 2}, {1, 1}},
+ {{4, 3}, {5, 4}},
+ },
+ BindingRemapper::AccessControls{}, true);
+ auto got = Run<BindingRemapper>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(BindingRemapperTest, BindingCollisionsDifferentEntryPoints) {
- auto* src = R"(
+ auto* src = R"(
struct S {
i : i32,
};
@@ -298,7 +298,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct S {
i : i32,
}
@@ -322,21 +322,21 @@
}
)";
- DataMap data;
- data.Add<BindingRemapper::Remappings>(
- BindingRemapper::BindingPoints{
- {{2, 1}, {1, 1}},
- {{3, 2}, {1, 1}},
- {{4, 3}, {5, 4}},
- },
- BindingRemapper::AccessControls{}, true);
- auto got = Run<BindingRemapper>(src, data);
+ DataMap data;
+ data.Add<BindingRemapper::Remappings>(
+ BindingRemapper::BindingPoints{
+ {{2, 1}, {1, 1}},
+ {{3, 2}, {1, 1}},
+ {{4, 3}, {5, 4}},
+ },
+ BindingRemapper::AccessControls{}, true);
+ auto got = Run<BindingRemapper>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(BindingRemapperTest, NoData) {
- auto* src = R"(
+ auto* src = R"(
struct S {
a : f32,
}
@@ -350,11 +350,11 @@
}
)";
- auto* expect = src;
+ auto* expect = src;
- auto got = Run<BindingRemapper>(src);
+ auto got = Run<BindingRemapper>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
} // namespace
diff --git a/src/tint/transform/builtin_polyfill.cc b/src/tint/transform/builtin_polyfill.cc
index 91f5131..ba4be61 100644
--- a/src/tint/transform/builtin_polyfill.cc
+++ b/src/tint/transform/builtin_polyfill.cc
@@ -28,567 +28,529 @@
/// The PIMPL state for the BuiltinPolyfill transform
struct BuiltinPolyfill::State {
- /// Constructor
- /// @param c the CloneContext
- /// @param p the builtins to polyfill
- State(CloneContext& c, Builtins p) : ctx(c), polyfill(p) {}
+ /// Constructor
+ /// @param c the CloneContext
+ /// @param p the builtins to polyfill
+ State(CloneContext& c, Builtins p) : ctx(c), polyfill(p) {}
- /// The clone context
- CloneContext& ctx;
- /// The builtins to polyfill
- Builtins polyfill;
- /// The destination program builder
- ProgramBuilder& b = *ctx.dst;
- /// The source clone context
- const sem::Info& sem = ctx.src->Sem();
+ /// The clone context
+ CloneContext& ctx;
+ /// The builtins to polyfill
+ Builtins polyfill;
+ /// The destination program builder
+ ProgramBuilder& b = *ctx.dst;
+ /// The semantic information of the source program
+ const sem::Info& sem = ctx.src->Sem();
- /// Builds the polyfill function for the `countLeadingZeros` builtin
- /// @param ty the parameter and return type for the function
- /// @return the polyfill function name
- Symbol countLeadingZeros(const sem::Type* ty) {
- auto name = b.Symbols().New("tint_count_leading_zeros");
- uint32_t width = WidthOf(ty);
+ /// Builds the polyfill function for the `countLeadingZeros` builtin
+ /// @param ty the parameter and return type for the function
+ /// @return the polyfill function name
+ Symbol countLeadingZeros(const sem::Type* ty) {
+ auto name = b.Symbols().New("tint_count_leading_zeros");
+ uint32_t width = WidthOf(ty);
- // Returns either u32 or vecN<u32>
- auto U = [&]() -> const ast::Type* {
- if (width == 1) {
- return b.ty.u32();
- }
- return b.ty.vec<u32>(width);
- };
- auto V = [&](uint32_t value) -> const ast::Expression* {
- return ScalarOrVector(width, value);
- };
- b.Func(name, {b.Param("v", T(ty))}, T(ty),
- {
- // var x = U(v);
- b.Decl(b.Var("x", nullptr, b.Construct(U(), b.Expr("v")))),
- // let b16 = select(0, 16, x <= 0x0000ffff);
- b.Decl(b.Let("b16", nullptr,
- b.Call("select", V(0), V(16),
- b.LessThanEqual("x", V(0x0000ffff))))),
- // x = x << b16;
- b.Assign("x", b.Shl("x", "b16")),
- // let b8 = select(0, 8, x <= 0x00ffffff);
- b.Decl(b.Let("b8", nullptr,
- b.Call("select", V(0), V(8),
- b.LessThanEqual("x", V(0x00ffffff))))),
- // x = x << b8;
- b.Assign("x", b.Shl("x", "b8")),
- // let b4 = select(0, 4, x <= 0x0fffffff);
- b.Decl(b.Let("b4", nullptr,
- b.Call("select", V(0), V(4),
- b.LessThanEqual("x", V(0x0fffffff))))),
- // x = x << b4;
- b.Assign("x", b.Shl("x", "b4")),
- // let b2 = select(0, 2, x <= 0x3fffffff);
- b.Decl(b.Let("b2", nullptr,
- b.Call("select", V(0), V(2),
- b.LessThanEqual("x", V(0x3fffffff))))),
- // x = x << b2;
- b.Assign("x", b.Shl("x", "b2")),
- // let b1 = select(0, 1, x <= 0x7fffffff);
- b.Decl(b.Let("b1", nullptr,
- b.Call("select", V(0), V(1),
- b.LessThanEqual("x", V(0x7fffffff))))),
- // let is_zero = select(0, 1, x == 0);
- b.Decl(b.Let("is_zero", nullptr,
- b.Call("select", V(0), V(1), b.Equal("x", V(0))))),
- // return R((b16 | b8 | b4 | b2 | b1) + zero);
- b.Return(b.Construct(
- T(ty),
- b.Add(b.Or(b.Or(b.Or(b.Or("b16", "b8"), "b4"), "b2"), "b1"),
- "is_zero"))),
- });
- return name;
- }
-
- /// Builds the polyfill function for the `countTrailingZeros` builtin
- /// @param ty the parameter and return type for the function
- /// @return the polyfill function name
- Symbol countTrailingZeros(const sem::Type* ty) {
- auto name = b.Symbols().New("tint_count_trailing_zeros");
- uint32_t width = WidthOf(ty);
-
- // Returns either u32 or vecN<u32>
- auto U = [&]() -> const ast::Type* {
- if (width == 1) {
- return b.ty.u32();
- }
- return b.ty.vec<u32>(width);
- };
- auto V = [&](uint32_t value) -> const ast::Expression* {
- return ScalarOrVector(width, value);
- };
- auto B = [&](const ast::Expression* value) -> const ast::Expression* {
- if (width == 1) {
- return b.Construct<bool>(value);
- }
- return b.Construct(b.ty.vec<bool>(width), value);
- };
- b.Func(name, {b.Param("v", T(ty))}, T(ty),
- {
- // var x = U(v);
- b.Decl(b.Var("x", nullptr, b.Construct(U(), b.Expr("v")))),
- // let b16 = select(16, 0, bool(x & 0x0000ffff));
- b.Decl(b.Let("b16", nullptr,
- b.Call("select", V(16), V(0),
- B(b.And("x", V(0x0000ffff)))))),
- // x = x >> b16;
- b.Assign("x", b.Shr("x", "b16")),
- // let b8 = select(8, 0, bool(x & 0x000000ff));
- b.Decl(b.Let(
- "b8", nullptr,
- b.Call("select", V(8), V(0), B(b.And("x", V(0x000000ff)))))),
- // x = x >> b8;
- b.Assign("x", b.Shr("x", "b8")),
- // let b4 = select(4, 0, bool(x & 0x0000000f));
- b.Decl(b.Let(
- "b4", nullptr,
- b.Call("select", V(4), V(0), B(b.And("x", V(0x0000000f)))))),
- // x = x >> b4;
- b.Assign("x", b.Shr("x", "b4")),
- // let b2 = select(2, 0, bool(x & 0x00000003));
- b.Decl(b.Let(
- "b2", nullptr,
- b.Call("select", V(2), V(0), B(b.And("x", V(0x00000003)))))),
- // x = x >> b2;
- b.Assign("x", b.Shr("x", "b2")),
- // let b1 = select(1, 0, bool(x & 0x00000001));
- b.Decl(b.Let(
- "b1", nullptr,
- b.Call("select", V(1), V(0), B(b.And("x", V(0x00000001)))))),
- // let is_zero = select(0, 1, x == 0);
- b.Decl(b.Let("is_zero", nullptr,
- b.Call("select", V(0), V(1), b.Equal("x", V(0))))),
- // return R((b16 | b8 | b4 | b2 | b1) + zero);
- b.Return(b.Construct(
- T(ty),
- b.Add(b.Or(b.Or(b.Or(b.Or("b16", "b8"), "b4"), "b2"), "b1"),
- "is_zero"))),
- });
- return name;
- }
-
- /// Builds the polyfill function for the `extractBits` builtin
- /// @param ty the parameter and return type for the function
- /// @return the polyfill function name
- Symbol extractBits(const sem::Type* ty) {
- auto name = b.Symbols().New("tint_extract_bits");
- uint32_t width = WidthOf(ty);
-
- constexpr uint32_t W = 32u; // 32-bit
-
- auto vecN_u32 =
- [&](const ast::Expression* value) -> const ast::Expression* {
- if (width == 1) {
- return value;
- }
- return b.Construct(b.ty.vec<u32>(width), value);
- };
-
- ast::StatementList body = {
- b.Decl(b.Let("s", nullptr, b.Call("min", "offset", W))),
- b.Decl(b.Let("e", nullptr, b.Call("min", W, b.Add("s", "count")))),
- };
-
- switch (polyfill.extract_bits) {
- case Level::kFull:
- body.emplace_back(b.Decl(b.Let("shl", nullptr, b.Sub(W, "e"))));
- body.emplace_back(b.Decl(b.Let("shr", nullptr, b.Add("shl", "s"))));
- body.emplace_back(b.Return(b.Shr(b.Shl("v", vecN_u32(b.Expr("shl"))),
- vecN_u32(b.Expr("shr")))));
- break;
- case Level::kClampParameters:
- body.emplace_back(
- b.Return(b.Call("extractBits", "v", "s", b.Sub("e", "s"))));
- break;
- default:
- TINT_ICE(Transform, b.Diagnostics())
- << "unhandled polyfill level: "
- << static_cast<int>(polyfill.extract_bits);
- return {};
+ // Returns either u32 or vecN<u32>
+ auto U = [&]() -> const ast::Type* {
+ if (width == 1) {
+ return b.ty.u32();
+ }
+ return b.ty.vec<u32>(width);
+ };
+ auto V = [&](uint32_t value) -> const ast::Expression* {
+ return ScalarOrVector(width, value);
+ };
+ b.Func(
+ name, {b.Param("v", T(ty))}, T(ty),
+ {
+ // var x = U(v);
+ b.Decl(b.Var("x", nullptr, b.Construct(U(), b.Expr("v")))),
+ // let b16 = select(0, 16, x <= 0x0000ffff);
+ b.Decl(b.Let("b16", nullptr,
+ b.Call("select", V(0), V(16), b.LessThanEqual("x", V(0x0000ffff))))),
+ // x = x << b16;
+ b.Assign("x", b.Shl("x", "b16")),
+ // let b8 = select(0, 8, x <= 0x00ffffff);
+ b.Decl(b.Let("b8", nullptr,
+ b.Call("select", V(0), V(8), b.LessThanEqual("x", V(0x00ffffff))))),
+ // x = x << b8;
+ b.Assign("x", b.Shl("x", "b8")),
+ // let b4 = select(0, 4, x <= 0x0fffffff);
+ b.Decl(b.Let("b4", nullptr,
+ b.Call("select", V(0), V(4), b.LessThanEqual("x", V(0x0fffffff))))),
+ // x = x << b4;
+ b.Assign("x", b.Shl("x", "b4")),
+ // let b2 = select(0, 2, x <= 0x3fffffff);
+ b.Decl(b.Let("b2", nullptr,
+ b.Call("select", V(0), V(2), b.LessThanEqual("x", V(0x3fffffff))))),
+ // x = x << b2;
+ b.Assign("x", b.Shl("x", "b2")),
+ // let b1 = select(0, 1, x <= 0x7fffffff);
+ b.Decl(b.Let("b1", nullptr,
+ b.Call("select", V(0), V(1), b.LessThanEqual("x", V(0x7fffffff))))),
+ // let is_zero = select(0, 1, x == 0);
+ b.Decl(b.Let("is_zero", nullptr, b.Call("select", V(0), V(1), b.Equal("x", V(0))))),
+ // return R((b16 | b8 | b4 | b2 | b1) + is_zero);
+ b.Return(b.Construct(
+ T(ty),
+ b.Add(b.Or(b.Or(b.Or(b.Or("b16", "b8"), "b4"), "b2"), "b1"), "is_zero"))),
+ });
+ return name;
}
- b.Func(name,
- {
- b.Param("v", T(ty)),
- b.Param("offset", b.ty.u32()),
- b.Param("count", b.ty.u32()),
- },
- T(ty), body);
+ /// Builds the polyfill function for the `countTrailingZeros` builtin
+ /// @param ty the parameter and return type for the function
+ /// @return the polyfill function name
+ Symbol countTrailingZeros(const sem::Type* ty) {
+ auto name = b.Symbols().New("tint_count_trailing_zeros");
+ uint32_t width = WidthOf(ty);
- return name;
- }
-
- /// Builds the polyfill function for the `firstLeadingBit` builtin
- /// @param ty the parameter and return type for the function
- /// @return the polyfill function name
- Symbol firstLeadingBit(const sem::Type* ty) {
- auto name = b.Symbols().New("tint_first_leading_bit");
- uint32_t width = WidthOf(ty);
-
- // Returns either u32 or vecN<u32>
- auto U = [&]() -> const ast::Type* {
- if (width == 1) {
- return b.ty.u32();
- }
- return b.ty.vec<u32>(width);
- };
- auto V = [&](uint32_t value) -> const ast::Expression* {
- return ScalarOrVector(width, value);
- };
- auto B = [&](const ast::Expression* value) -> const ast::Expression* {
- if (width == 1) {
- return b.Construct<bool>(value);
- }
- return b.Construct(b.ty.vec<bool>(width), value);
- };
-
- const ast::Expression* x = nullptr;
- if (ty->is_unsigned_scalar_or_vector()) {
- x = b.Expr("v");
- } else {
- // If ty is signed, then the value is inverted if the sign is negative
- x = b.Call("select", //
- b.Construct(U(), "v"), //
- b.Construct(U(), b.Complement("v")), //
- b.LessThan("v", ScalarOrVector(width, 0)));
+ // Returns either u32 or vecN<u32>
+ auto U = [&]() -> const ast::Type* {
+ if (width == 1) {
+ return b.ty.u32();
+ }
+ return b.ty.vec<u32>(width);
+ };
+ auto V = [&](uint32_t value) -> const ast::Expression* {
+ return ScalarOrVector(width, value);
+ };
+ auto B = [&](const ast::Expression* value) -> const ast::Expression* {
+ if (width == 1) {
+ return b.Construct<bool>(value);
+ }
+ return b.Construct(b.ty.vec<bool>(width), value);
+ };
+ b.Func(
+ name, {b.Param("v", T(ty))}, T(ty),
+ {
+ // var x = U(v);
+ b.Decl(b.Var("x", nullptr, b.Construct(U(), b.Expr("v")))),
+ // let b16 = select(16, 0, bool(x & 0x0000ffff));
+ b.Decl(b.Let("b16", nullptr,
+ b.Call("select", V(16), V(0), B(b.And("x", V(0x0000ffff)))))),
+ // x = x >> b16;
+ b.Assign("x", b.Shr("x", "b16")),
+ // let b8 = select(8, 0, bool(x & 0x000000ff));
+ b.Decl(b.Let("b8", nullptr,
+ b.Call("select", V(8), V(0), B(b.And("x", V(0x000000ff)))))),
+ // x = x >> b8;
+ b.Assign("x", b.Shr("x", "b8")),
+ // let b4 = select(4, 0, bool(x & 0x0000000f));
+ b.Decl(b.Let("b4", nullptr,
+ b.Call("select", V(4), V(0), B(b.And("x", V(0x0000000f)))))),
+ // x = x >> b4;
+ b.Assign("x", b.Shr("x", "b4")),
+ // let b2 = select(2, 0, bool(x & 0x00000003));
+ b.Decl(b.Let("b2", nullptr,
+ b.Call("select", V(2), V(0), B(b.And("x", V(0x00000003)))))),
+ // x = x >> b2;
+ b.Assign("x", b.Shr("x", "b2")),
+ // let b1 = select(1, 0, bool(x & 0x00000001));
+ b.Decl(b.Let("b1", nullptr,
+ b.Call("select", V(1), V(0), B(b.And("x", V(0x00000001)))))),
+ // let is_zero = select(0, 1, x == 0);
+ b.Decl(b.Let("is_zero", nullptr, b.Call("select", V(0), V(1), b.Equal("x", V(0))))),
+ // return R((b16 | b8 | b4 | b2 | b1) + is_zero);
+ b.Return(b.Construct(
+ T(ty),
+ b.Add(b.Or(b.Or(b.Or(b.Or("b16", "b8"), "b4"), "b2"), "b1"), "is_zero"))),
+ });
+ return name;
}
- b.Func(name, {b.Param("v", T(ty))}, T(ty),
- {
- // var x = v; (unsigned)
- // var x = select(U(v), ~U(v), v < 0); (signed)
- b.Decl(b.Var("x", nullptr, x)),
- // let b16 = select(0, 16, bool(x & 0xffff0000));
- b.Decl(b.Let("b16", nullptr,
- b.Call("select", V(0), V(16),
- B(b.And("x", V(0xffff0000)))))),
- // x = x >> b16;
- b.Assign("x", b.Shr("x", "b16")),
- // let b8 = select(0, 8, bool(x & 0x0000ff00));
- b.Decl(b.Let(
- "b8", nullptr,
- b.Call("select", V(0), V(8), B(b.And("x", V(0x0000ff00)))))),
- // x = x >> b8;
- b.Assign("x", b.Shr("x", "b8")),
- // let b4 = select(0, 4, bool(x & 0x000000f0));
- b.Decl(b.Let(
- "b4", nullptr,
- b.Call("select", V(0), V(4), B(b.And("x", V(0x000000f0)))))),
- // x = x >> b4;
- b.Assign("x", b.Shr("x", "b4")),
- // let b2 = select(0, 2, bool(x & 0x0000000c));
- b.Decl(b.Let(
- "b2", nullptr,
- b.Call("select", V(0), V(2), B(b.And("x", V(0x0000000c)))))),
- // x = x >> b2;
- b.Assign("x", b.Shr("x", "b2")),
- // let b1 = select(0, 1, bool(x & 0x00000002));
- b.Decl(b.Let(
- "b1", nullptr,
- b.Call("select", V(0), V(1), B(b.And("x", V(0x00000002)))))),
- // let is_zero = select(0, 0xffffffff, x == 0);
- b.Decl(b.Let(
- "is_zero", nullptr,
- b.Call("select", V(0), V(0xffffffff), b.Equal("x", V(0))))),
- // return R(b16 | b8 | b4 | b2 | b1 | zero);
- b.Return(b.Construct(
- T(ty),
- b.Or(b.Or(b.Or(b.Or(b.Or("b16", "b8"), "b4"), "b2"), "b1"),
- "is_zero"))),
- });
- return name;
- }
+ /// Builds the polyfill function for the `extractBits` builtin
+ /// @param ty the parameter and return type for the function
+ /// @return the polyfill function name
+ Symbol extractBits(const sem::Type* ty) {
+ auto name = b.Symbols().New("tint_extract_bits");
+ uint32_t width = WidthOf(ty);
- /// Builds the polyfill function for the `firstTrailingBit` builtin
- /// @param ty the parameter and return type for the function
- /// @return the polyfill function name
- Symbol firstTrailingBit(const sem::Type* ty) {
- auto name = b.Symbols().New("tint_first_trailing_bit");
- uint32_t width = WidthOf(ty);
+ constexpr uint32_t W = 32u; // 32-bit
- // Returns either u32 or vecN<u32>
- auto U = [&]() -> const ast::Type* {
- if (width == 1) {
- return b.ty.u32();
- }
- return b.ty.vec<u32>(width);
- };
- auto V = [&](uint32_t value) -> const ast::Expression* {
- return ScalarOrVector(width, value);
- };
- auto B = [&](const ast::Expression* value) -> const ast::Expression* {
- if (width == 1) {
- return b.Construct<bool>(value);
- }
- return b.Construct(b.ty.vec<bool>(width), value);
- };
- b.Func(name, {b.Param("v", T(ty))}, T(ty),
- {
- // var x = U(v);
- b.Decl(b.Var("x", nullptr, b.Construct(U(), b.Expr("v")))),
- // let b16 = select(16, 0, bool(x & 0x0000ffff));
- b.Decl(b.Let("b16", nullptr,
- b.Call("select", V(16), V(0),
- B(b.And("x", V(0x0000ffff)))))),
- // x = x >> b16;
- b.Assign("x", b.Shr("x", "b16")),
- // let b8 = select(8, 0, bool(x & 0x000000ff));
- b.Decl(b.Let(
- "b8", nullptr,
- b.Call("select", V(8), V(0), B(b.And("x", V(0x000000ff)))))),
- // x = x >> b8;
- b.Assign("x", b.Shr("x", "b8")),
- // let b4 = select(4, 0, bool(x & 0x0000000f));
- b.Decl(b.Let(
- "b4", nullptr,
- b.Call("select", V(4), V(0), B(b.And("x", V(0x0000000f)))))),
- // x = x >> b4;
- b.Assign("x", b.Shr("x", "b4")),
- // let b2 = select(2, 0, bool(x & 0x00000003));
- b.Decl(b.Let(
- "b2", nullptr,
- b.Call("select", V(2), V(0), B(b.And("x", V(0x00000003)))))),
- // x = x >> b2;
- b.Assign("x", b.Shr("x", "b2")),
- // let b1 = select(1, 0, bool(x & 0x00000001));
- b.Decl(b.Let(
- "b1", nullptr,
- b.Call("select", V(1), V(0), B(b.And("x", V(0x00000001)))))),
- // let is_zero = select(0, 0xffffffff, x == 0);
- b.Decl(b.Let(
- "is_zero", nullptr,
- b.Call("select", V(0), V(0xffffffff), b.Equal("x", V(0))))),
- // return R(b16 | b8 | b4 | b2 | b1 | is_zero);
- b.Return(b.Construct(
- T(ty),
- b.Or(b.Or(b.Or(b.Or(b.Or("b16", "b8"), "b4"), "b2"), "b1"),
- "is_zero"))),
- });
- return name;
- }
+ auto vecN_u32 = [&](const ast::Expression* value) -> const ast::Expression* {
+ if (width == 1) {
+ return value;
+ }
+ return b.Construct(b.ty.vec<u32>(width), value);
+ };
- /// Builds the polyfill function for the `insertBits` builtin
- /// @param ty the parameter and return type for the function
- /// @return the polyfill function name
- Symbol insertBits(const sem::Type* ty) {
- auto name = b.Symbols().New("tint_insert_bits");
- uint32_t width = WidthOf(ty);
+ ast::StatementList body = {
+ b.Decl(b.Let("s", nullptr, b.Call("min", "offset", W))),
+ b.Decl(b.Let("e", nullptr, b.Call("min", W, b.Add("s", "count")))),
+ };
- constexpr uint32_t W = 32u; // 32-bit
+ switch (polyfill.extract_bits) {
+ case Level::kFull:
+ body.emplace_back(b.Decl(b.Let("shl", nullptr, b.Sub(W, "e"))));
+ body.emplace_back(b.Decl(b.Let("shr", nullptr, b.Add("shl", "s"))));
+ body.emplace_back(
+ b.Return(b.Shr(b.Shl("v", vecN_u32(b.Expr("shl"))), vecN_u32(b.Expr("shr")))));
+ break;
+ case Level::kClampParameters:
+ body.emplace_back(b.Return(b.Call("extractBits", "v", "s", b.Sub("e", "s"))));
+ break;
+ default:
+ TINT_ICE(Transform, b.Diagnostics())
+ << "unhandled polyfill level: " << static_cast<int>(polyfill.extract_bits);
+ return {};
+ }
- auto V = [&](auto value) -> const ast::Expression* {
- const ast::Expression* expr = b.Expr(value);
- if (!ty->is_unsigned_scalar_or_vector()) {
- expr = b.Construct<i32>(expr);
- }
- if (ty->Is<sem::Vector>()) {
- expr = b.Construct(T(ty), expr);
- }
- return expr;
- };
- auto U = [&](auto value) -> const ast::Expression* {
- if (width == 1) {
- return b.Expr(value);
- }
- return b.vec(b.ty.u32(), width, value);
- };
+ b.Func(name,
+ {
+ b.Param("v", T(ty)),
+ b.Param("offset", b.ty.u32()),
+ b.Param("count", b.ty.u32()),
+ },
+ T(ty), body);
- ast::StatementList body = {
- b.Decl(b.Let("s", nullptr, b.Call("min", "offset", W))),
- b.Decl(b.Let("e", nullptr, b.Call("min", W, b.Add("s", "count")))),
- };
-
- switch (polyfill.insert_bits) {
- case Level::kFull:
- // let mask = ((1 << s) - 1) ^ ((1 << e) - 1)
- body.emplace_back(b.Decl(b.Let(
- "mask", nullptr,
- b.Xor(b.Sub(b.Shl(1u, "s"), 1u), b.Sub(b.Shl(1u, "e"), 1u)))));
- // return ((n << s) & mask) | (v & ~mask)
- body.emplace_back(b.Return(b.Or(b.And(b.Shl("n", U("s")), V("mask")),
- b.And("v", V(b.Complement("mask"))))));
- break;
- case Level::kClampParameters:
- body.emplace_back(
- b.Return(b.Call("insertBits", "v", "n", "s", b.Sub("e", "s"))));
- break;
- default:
- TINT_ICE(Transform, b.Diagnostics())
- << "unhandled polyfill level: "
- << static_cast<int>(polyfill.insert_bits);
- return {};
+ return name;
}
- b.Func(name,
- {
- b.Param("v", T(ty)),
- b.Param("n", T(ty)),
- b.Param("offset", b.ty.u32()),
- b.Param("count", b.ty.u32()),
- },
- T(ty), body);
+ /// Builds the polyfill function for the `firstLeadingBit` builtin
+ /// @param ty the parameter and return type for the function
+ /// @return the polyfill function name
+ Symbol firstLeadingBit(const sem::Type* ty) {
+ auto name = b.Symbols().New("tint_first_leading_bit");
+ uint32_t width = WidthOf(ty);
- return name;
- }
+ // Returns either u32 or vecN<u32>
+ auto U = [&]() -> const ast::Type* {
+ if (width == 1) {
+ return b.ty.u32();
+ }
+ return b.ty.vec<u32>(width);
+ };
+ auto V = [&](uint32_t value) -> const ast::Expression* {
+ return ScalarOrVector(width, value);
+ };
+ auto B = [&](const ast::Expression* value) -> const ast::Expression* {
+ if (width == 1) {
+ return b.Construct<bool>(value);
+ }
+ return b.Construct(b.ty.vec<bool>(width), value);
+ };
- private:
- /// Aliases
- using u32 = ProgramBuilder::u32;
- using i32 = ProgramBuilder::i32;
+ const ast::Expression* x = nullptr;
+ if (ty->is_unsigned_scalar_or_vector()) {
+ x = b.Expr("v");
+ } else {
+ // If ty is signed, then the value is inverted if the sign is negative
+ x = b.Call("select", //
+ b.Construct(U(), "v"), //
+ b.Construct(U(), b.Complement("v")), //
+ b.LessThan("v", ScalarOrVector(width, 0)));
+ }
- /// @returns the AST type for the given sem type
- const ast::Type* T(const sem::Type* ty) const {
- return CreateASTTypeFor(ctx, ty);
- }
-
- /// @returns 1 if `ty` is not a vector, otherwise the vector width
- uint32_t WidthOf(const sem::Type* ty) const {
- if (auto* v = ty->As<sem::Vector>()) {
- return v->Width();
+ b.Func(
+ name, {b.Param("v", T(ty))}, T(ty),
+ {
+ // var x = v; (unsigned)
+ // var x = select(U(v), U(~v), v < 0); (signed)
+ b.Decl(b.Var("x", nullptr, x)),
+ // let b16 = select(0, 16, bool(x & 0xffff0000));
+ b.Decl(b.Let("b16", nullptr,
+ b.Call("select", V(0), V(16), B(b.And("x", V(0xffff0000)))))),
+ // x = x >> b16;
+ b.Assign("x", b.Shr("x", "b16")),
+ // let b8 = select(0, 8, bool(x & 0x0000ff00));
+ b.Decl(b.Let("b8", nullptr,
+ b.Call("select", V(0), V(8), B(b.And("x", V(0x0000ff00)))))),
+ // x = x >> b8;
+ b.Assign("x", b.Shr("x", "b8")),
+ // let b4 = select(0, 4, bool(x & 0x000000f0));
+ b.Decl(b.Let("b4", nullptr,
+ b.Call("select", V(0), V(4), B(b.And("x", V(0x000000f0)))))),
+ // x = x >> b4;
+ b.Assign("x", b.Shr("x", "b4")),
+ // let b2 = select(0, 2, bool(x & 0x0000000c));
+ b.Decl(b.Let("b2", nullptr,
+ b.Call("select", V(0), V(2), B(b.And("x", V(0x0000000c)))))),
+ // x = x >> b2;
+ b.Assign("x", b.Shr("x", "b2")),
+ // let b1 = select(0, 1, bool(x & 0x00000002));
+ b.Decl(b.Let("b1", nullptr,
+ b.Call("select", V(0), V(1), B(b.And("x", V(0x00000002)))))),
+ // let is_zero = select(0, 0xffffffff, x == 0);
+ b.Decl(b.Let("is_zero", nullptr,
+ b.Call("select", V(0), V(0xffffffff), b.Equal("x", V(0))))),
+ // return R(b16 | b8 | b4 | b2 | b1 | is_zero);
+ b.Return(b.Construct(
+ T(ty), b.Or(b.Or(b.Or(b.Or(b.Or("b16", "b8"), "b4"), "b2"), "b1"), "is_zero"))),
+ });
+ return name;
}
- return 1;
- }
- /// @returns a scalar or vector with the given width, with each element with
- /// the given value.
- template <typename T>
- const ast::Expression* ScalarOrVector(uint32_t width, T value) const {
- if (width == 1) {
- return b.Expr(value);
+ /// Builds the polyfill function for the `firstTrailingBit` builtin
+ /// @param ty the parameter and return type for the function
+ /// @return the polyfill function name
+ Symbol firstTrailingBit(const sem::Type* ty) {
+ auto name = b.Symbols().New("tint_first_trailing_bit");
+ uint32_t width = WidthOf(ty);
+
+ // Returns either u32 or vecN<u32>
+ auto U = [&]() -> const ast::Type* {
+ if (width == 1) {
+ return b.ty.u32();
+ }
+ return b.ty.vec<u32>(width);
+ };
+ auto V = [&](uint32_t value) -> const ast::Expression* {
+ return ScalarOrVector(width, value);
+ };
+ auto B = [&](const ast::Expression* value) -> const ast::Expression* {
+ if (width == 1) {
+ return b.Construct<bool>(value);
+ }
+ return b.Construct(b.ty.vec<bool>(width), value);
+ };
+ b.Func(
+ name, {b.Param("v", T(ty))}, T(ty),
+ {
+ // var x = U(v);
+ b.Decl(b.Var("x", nullptr, b.Construct(U(), b.Expr("v")))),
+ // let b16 = select(16, 0, bool(x & 0x0000ffff));
+ b.Decl(b.Let("b16", nullptr,
+ b.Call("select", V(16), V(0), B(b.And("x", V(0x0000ffff)))))),
+ // x = x >> b16;
+ b.Assign("x", b.Shr("x", "b16")),
+ // let b8 = select(8, 0, bool(x & 0x000000ff));
+ b.Decl(b.Let("b8", nullptr,
+ b.Call("select", V(8), V(0), B(b.And("x", V(0x000000ff)))))),
+ // x = x >> b8;
+ b.Assign("x", b.Shr("x", "b8")),
+ // let b4 = select(4, 0, bool(x & 0x0000000f));
+ b.Decl(b.Let("b4", nullptr,
+ b.Call("select", V(4), V(0), B(b.And("x", V(0x0000000f)))))),
+ // x = x >> b4;
+ b.Assign("x", b.Shr("x", "b4")),
+ // let b2 = select(2, 0, bool(x & 0x00000003));
+ b.Decl(b.Let("b2", nullptr,
+ b.Call("select", V(2), V(0), B(b.And("x", V(0x00000003)))))),
+ // x = x >> b2;
+ b.Assign("x", b.Shr("x", "b2")),
+ // let b1 = select(1, 0, bool(x & 0x00000001));
+ b.Decl(b.Let("b1", nullptr,
+ b.Call("select", V(1), V(0), B(b.And("x", V(0x00000001)))))),
+ // let is_zero = select(0, 0xffffffff, x == 0);
+ b.Decl(b.Let("is_zero", nullptr,
+ b.Call("select", V(0), V(0xffffffff), b.Equal("x", V(0))))),
+ // return R(b16 | b8 | b4 | b2 | b1 | is_zero);
+ b.Return(b.Construct(
+ T(ty), b.Or(b.Or(b.Or(b.Or(b.Or("b16", "b8"), "b4"), "b2"), "b1"), "is_zero"))),
+ });
+ return name;
}
- return b.Construct(b.ty.vec<T>(width), value);
- }
+
+ /// Builds the polyfill function for the `insertBits` builtin
+ /// @param ty the parameter and return type for the function
+ /// @return the polyfill function name
+ Symbol insertBits(const sem::Type* ty) {
+ auto name = b.Symbols().New("tint_insert_bits");
+ uint32_t width = WidthOf(ty);
+
+ constexpr uint32_t W = 32u; // 32-bit
+
+ auto V = [&](auto value) -> const ast::Expression* {
+ const ast::Expression* expr = b.Expr(value);
+ if (!ty->is_unsigned_scalar_or_vector()) {
+ expr = b.Construct<i32>(expr);
+ }
+ if (ty->Is<sem::Vector>()) {
+ expr = b.Construct(T(ty), expr);
+ }
+ return expr;
+ };
+ auto U = [&](auto value) -> const ast::Expression* {
+ if (width == 1) {
+ return b.Expr(value);
+ }
+ return b.vec(b.ty.u32(), width, value);
+ };
+
+ ast::StatementList body = {
+ b.Decl(b.Let("s", nullptr, b.Call("min", "offset", W))),
+ b.Decl(b.Let("e", nullptr, b.Call("min", W, b.Add("s", "count")))),
+ };
+
+ switch (polyfill.insert_bits) {
+ case Level::kFull:
+ // let mask = ((1 << s) - 1) ^ ((1 << e) - 1)
+ body.emplace_back(b.Decl(b.Let(
+ "mask", nullptr, b.Xor(b.Sub(b.Shl(1u, "s"), 1u), b.Sub(b.Shl(1u, "e"), 1u)))));
+ // return ((n << s) & mask) | (v & ~mask)
+ body.emplace_back(b.Return(b.Or(b.And(b.Shl("n", U("s")), V("mask")),
+ b.And("v", V(b.Complement("mask"))))));
+ break;
+ case Level::kClampParameters:
+ body.emplace_back(b.Return(b.Call("insertBits", "v", "n", "s", b.Sub("e", "s"))));
+ break;
+ default:
+ TINT_ICE(Transform, b.Diagnostics())
+ << "unhandled polyfill level: " << static_cast<int>(polyfill.insert_bits);
+ return {};
+ }
+
+ b.Func(name,
+ {
+ b.Param("v", T(ty)),
+ b.Param("n", T(ty)),
+ b.Param("offset", b.ty.u32()),
+ b.Param("count", b.ty.u32()),
+ },
+ T(ty), body);
+
+ return name;
+ }
+
+ private:
+ /// Aliases
+ using u32 = ProgramBuilder::u32;
+ using i32 = ProgramBuilder::i32;
+
+ /// @returns the AST type for the given sem type
+ const ast::Type* T(const sem::Type* ty) const { return CreateASTTypeFor(ctx, ty); }
+
+ /// @returns 1 if `ty` is not a vector, otherwise the vector width
+ uint32_t WidthOf(const sem::Type* ty) const {
+ if (auto* v = ty->As<sem::Vector>()) {
+ return v->Width();
+ }
+ return 1;
+ }
+
+ /// @returns a scalar or vector with the given width, with each element set to
+ /// the given value.
+ template <typename T>
+ const ast::Expression* ScalarOrVector(uint32_t width, T value) const {
+ if (width == 1) {
+ return b.Expr(value);
+ }
+ return b.Construct(b.ty.vec<T>(width), value);
+ }
};
BuiltinPolyfill::BuiltinPolyfill() = default;
BuiltinPolyfill::~BuiltinPolyfill() = default;
-bool BuiltinPolyfill::ShouldRun(const Program* program,
- const DataMap& data) const {
- if (auto* cfg = data.Get<Config>()) {
- auto builtins = cfg->builtins;
- auto& sem = program->Sem();
- for (auto* node : program->ASTNodes().Objects()) {
- if (auto* call = sem.Get<sem::Call>(node)) {
- if (auto* builtin = call->Target()->As<sem::Builtin>()) {
- switch (builtin->Type()) {
- case sem::BuiltinType::kCountLeadingZeros:
- if (builtins.count_leading_zeros) {
- return true;
- }
- break;
- case sem::BuiltinType::kCountTrailingZeros:
- if (builtins.count_trailing_zeros) {
- return true;
- }
- break;
- case sem::BuiltinType::kExtractBits:
- if (builtins.extract_bits != Level::kNone) {
- return true;
- }
- break;
- case sem::BuiltinType::kFirstLeadingBit:
- if (builtins.first_leading_bit) {
- return true;
- }
- break;
- case sem::BuiltinType::kFirstTrailingBit:
- if (builtins.first_trailing_bit) {
- return true;
- }
- break;
- case sem::BuiltinType::kInsertBits:
- if (builtins.insert_bits != Level::kNone) {
- return true;
- }
- break;
- default:
- break;
- }
+bool BuiltinPolyfill::ShouldRun(const Program* program, const DataMap& data) const {
+ if (auto* cfg = data.Get<Config>()) {
+ auto builtins = cfg->builtins;
+ auto& sem = program->Sem();
+ for (auto* node : program->ASTNodes().Objects()) {
+ if (auto* call = sem.Get<sem::Call>(node)) {
+ if (auto* builtin = call->Target()->As<sem::Builtin>()) {
+ switch (builtin->Type()) {
+ case sem::BuiltinType::kCountLeadingZeros:
+ if (builtins.count_leading_zeros) {
+ return true;
+ }
+ break;
+ case sem::BuiltinType::kCountTrailingZeros:
+ if (builtins.count_trailing_zeros) {
+ return true;
+ }
+ break;
+ case sem::BuiltinType::kExtractBits:
+ if (builtins.extract_bits != Level::kNone) {
+ return true;
+ }
+ break;
+ case sem::BuiltinType::kFirstLeadingBit:
+ if (builtins.first_leading_bit) {
+ return true;
+ }
+ break;
+ case sem::BuiltinType::kFirstTrailingBit:
+ if (builtins.first_trailing_bit) {
+ return true;
+ }
+ break;
+ case sem::BuiltinType::kInsertBits:
+ if (builtins.insert_bits != Level::kNone) {
+ return true;
+ }
+ break;
+ default:
+ break;
+ }
+ }
+ }
}
- }
}
- }
- return false;
+ return false;
}
-void BuiltinPolyfill::Run(CloneContext& ctx,
- const DataMap& data,
- DataMap&) const {
- auto* cfg = data.Get<Config>();
- if (!cfg) {
- ctx.Clone();
- return;
- }
+void BuiltinPolyfill::Run(CloneContext& ctx, const DataMap& data, DataMap&) const {
+ auto* cfg = data.Get<Config>();
+ if (!cfg) {
+ ctx.Clone();
+ return;
+ }
- std::unordered_map<const sem::Builtin*, Symbol> polyfills;
+ std::unordered_map<const sem::Builtin*, Symbol> polyfills;
- ctx.ReplaceAll(
- [&](const ast::CallExpression* expr) -> const ast::CallExpression* {
+ ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::CallExpression* {
auto builtins = cfg->builtins;
State s{ctx, builtins};
if (auto* call = s.sem.Get<sem::Call>(expr)) {
- if (auto* builtin = call->Target()->As<sem::Builtin>()) {
- Symbol polyfill;
- switch (builtin->Type()) {
- case sem::BuiltinType::kCountLeadingZeros:
- if (builtins.count_leading_zeros) {
- polyfill = utils::GetOrCreate(polyfills, builtin, [&] {
- return s.countLeadingZeros(builtin->ReturnType());
- });
+ if (auto* builtin = call->Target()->As<sem::Builtin>()) {
+ Symbol polyfill;
+ switch (builtin->Type()) {
+ case sem::BuiltinType::kCountLeadingZeros:
+ if (builtins.count_leading_zeros) {
+ polyfill = utils::GetOrCreate(polyfills, builtin, [&] {
+ return s.countLeadingZeros(builtin->ReturnType());
+ });
+ }
+ break;
+ case sem::BuiltinType::kCountTrailingZeros:
+ if (builtins.count_trailing_zeros) {
+ polyfill = utils::GetOrCreate(polyfills, builtin, [&] {
+ return s.countTrailingZeros(builtin->ReturnType());
+ });
+ }
+ break;
+ case sem::BuiltinType::kExtractBits:
+ if (builtins.extract_bits != Level::kNone) {
+ polyfill = utils::GetOrCreate(polyfills, builtin, [&] {
+ return s.extractBits(builtin->ReturnType());
+ });
+ }
+ break;
+ case sem::BuiltinType::kFirstLeadingBit:
+ if (builtins.first_leading_bit) {
+ polyfill = utils::GetOrCreate(polyfills, builtin, [&] {
+ return s.firstLeadingBit(builtin->ReturnType());
+ });
+ }
+ break;
+ case sem::BuiltinType::kFirstTrailingBit:
+ if (builtins.first_trailing_bit) {
+ polyfill = utils::GetOrCreate(polyfills, builtin, [&] {
+ return s.firstTrailingBit(builtin->ReturnType());
+ });
+ }
+ break;
+ case sem::BuiltinType::kInsertBits:
+ if (builtins.insert_bits != Level::kNone) {
+ polyfill = utils::GetOrCreate(polyfills, builtin, [&] {
+ return s.insertBits(builtin->ReturnType());
+ });
+ }
+ break;
+ default:
+ break;
}
- break;
- case sem::BuiltinType::kCountTrailingZeros:
- if (builtins.count_trailing_zeros) {
- polyfill = utils::GetOrCreate(polyfills, builtin, [&] {
- return s.countTrailingZeros(builtin->ReturnType());
- });
+ if (polyfill.IsValid()) {
+ return s.b.Call(polyfill, ctx.Clone(call->Declaration()->args));
}
- break;
- case sem::BuiltinType::kExtractBits:
- if (builtins.extract_bits != Level::kNone) {
- polyfill = utils::GetOrCreate(polyfills, builtin, [&] {
- return s.extractBits(builtin->ReturnType());
- });
- }
- break;
- case sem::BuiltinType::kFirstLeadingBit:
- if (builtins.first_leading_bit) {
- polyfill = utils::GetOrCreate(polyfills, builtin, [&] {
- return s.firstLeadingBit(builtin->ReturnType());
- });
- }
- break;
- case sem::BuiltinType::kFirstTrailingBit:
- if (builtins.first_trailing_bit) {
- polyfill = utils::GetOrCreate(polyfills, builtin, [&] {
- return s.firstTrailingBit(builtin->ReturnType());
- });
- }
- break;
- case sem::BuiltinType::kInsertBits:
- if (builtins.insert_bits != Level::kNone) {
- polyfill = utils::GetOrCreate(polyfills, builtin, [&] {
- return s.insertBits(builtin->ReturnType());
- });
- }
- break;
- default:
- break;
}
- if (polyfill.IsValid()) {
- return s.b.Call(polyfill, ctx.Clone(call->Declaration()->args));
- }
- }
}
return nullptr;
- });
+ });
- ctx.Clone();
+ ctx.Clone();
}
BuiltinPolyfill::Config::Config(const Builtins& b) : builtins(b) {}
diff --git a/src/tint/transform/builtin_polyfill.h b/src/tint/transform/builtin_polyfill.h
index ada1015..8453189 100644
--- a/src/tint/transform/builtin_polyfill.h
+++ b/src/tint/transform/builtin_polyfill.h
@@ -21,73 +21,70 @@
/// Implements builtins for backends that do not have a native implementation.
class BuiltinPolyfill final : public Castable<BuiltinPolyfill, Transform> {
- public:
- /// Constructor
- BuiltinPolyfill();
- /// Destructor
- ~BuiltinPolyfill() override;
-
- /// Enumerator of polyfill levels
- enum class Level {
- /// No polyfill needed, supported by the backend.
- kNone,
- /// Clamp the parameters to the inner implementation.
- kClampParameters,
- /// Polyfill the entire function
- kFull,
- };
-
- /// Specifies the builtins that should be polyfilled by the transform.
- struct Builtins {
- /// Should `countLeadingZeros()` be polyfilled?
- bool count_leading_zeros = false;
- /// Should `countTrailingZeros()` be polyfilled?
- bool count_trailing_zeros = false;
- /// What level should `extractBits()` be polyfilled?
- Level extract_bits = Level::kNone;
- /// Should `firstLeadingBit()` be polyfilled?
- bool first_leading_bit = false;
- /// Should `firstTrailingBit()` be polyfilled?
- bool first_trailing_bit = false;
- /// Should `insertBits()` be polyfilled?
- Level insert_bits = Level::kNone;
- };
-
- /// Config is consumed by the BuiltinPolyfill transform.
- /// Config specifies the builtins that should be polyfilled.
- struct Config final : public Castable<Data, transform::Data> {
+ public:
/// Constructor
- /// @param b the list of builtins to polyfill
- explicit Config(const Builtins& b);
-
- /// Copy constructor
- Config(const Config&);
-
+ BuiltinPolyfill();
/// Destructor
- ~Config() override;
+ ~BuiltinPolyfill() override;
- /// The builtins to polyfill
- const Builtins builtins;
- };
+ /// Enumerator of polyfill levels
+ enum class Level {
+ /// No polyfill needed, supported by the backend.
+ kNone,
+ /// Clamp the parameters to the inner implementation.
+ kClampParameters,
+ /// Polyfill the entire function
+ kFull,
+ };
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program,
- const DataMap& data = {}) const override;
+ /// Specifies the builtins that should be polyfilled by the transform.
+ struct Builtins {
+ /// Should `countLeadingZeros()` be polyfilled?
+ bool count_leading_zeros = false;
+ /// Should `countTrailingZeros()` be polyfilled?
+ bool count_trailing_zeros = false;
+ /// What level should `extractBits()` be polyfilled?
+ Level extract_bits = Level::kNone;
+ /// Should `firstLeadingBit()` be polyfilled?
+ bool first_leading_bit = false;
+ /// Should `firstTrailingBit()` be polyfilled?
+ bool first_trailing_bit = false;
+ /// Should `insertBits()` be polyfilled?
+ Level insert_bits = Level::kNone;
+ };
- protected:
- struct State;
+ /// Config is consumed by the BuiltinPolyfill transform.
+ /// Config specifies the builtins that should be polyfilled.
+ struct Config final : public Castable<Data, transform::Data> {
+ /// Constructor
+ /// @param b the list of builtins to polyfill
+ explicit Config(const Builtins& b);
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx,
- const DataMap& inputs,
- DataMap& outputs) const override;
+ /// Copy constructor
+ Config(const Config&);
+
+ /// Destructor
+ ~Config() override;
+
+ /// The builtins to polyfill
+ const Builtins builtins;
+ };
+
+ /// @param program the program to inspect
+ /// @param data optional extra transform-specific input data
+ /// @returns true if this transform should be run for the given program
+ bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
+
+ protected:
+ struct State;
+
+ /// Runs the transform using the CloneContext built for transforming a
+ /// program. Run() is responsible for calling Clone() on the CloneContext.
+ /// @param ctx the CloneContext primed with the input program and
+ /// ProgramBuilder
+ /// @param inputs optional extra transform-specific input data
+ /// @param outputs optional extra transform-specific output data
+ void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/builtin_polyfill_test.cc b/src/tint/transform/builtin_polyfill_test.cc
index c5cc2c5..e3a4eae 100644
--- a/src/tint/transform/builtin_polyfill_test.cc
+++ b/src/tint/transform/builtin_polyfill_test.cc
@@ -26,51 +26,51 @@
using BuiltinPolyfillTest = TransformTest;
TEST_F(BuiltinPolyfillTest, ShouldRunEmptyModule) {
- auto* src = R"()";
+ auto* src = R"()";
- EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));
}
TEST_F(BuiltinPolyfillTest, EmptyModule) {
- auto* src = R"()";
+ auto* src = R"()";
- auto* expect = src;
+ auto* expect = src;
- auto got = Run<BuiltinPolyfill>(src);
+ auto got = Run<BuiltinPolyfill>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
////////////////////////////////////////////////////////////////////////////////
// countLeadingZeros
////////////////////////////////////////////////////////////////////////////////
DataMap polyfillCountLeadingZeros() {
- BuiltinPolyfill::Builtins builtins;
- builtins.count_leading_zeros = true;
- DataMap data;
- data.Add<BuiltinPolyfill::Config>(builtins);
- return data;
+ BuiltinPolyfill::Builtins builtins;
+ builtins.count_leading_zeros = true;
+ DataMap data;
+ data.Add<BuiltinPolyfill::Config>(builtins);
+ return data;
}
TEST_F(BuiltinPolyfillTest, ShouldRunCountLeadingZeros) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
countLeadingZeros(0xf);
}
)";
- EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));
- EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillCountLeadingZeros()));
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));
+ EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillCountLeadingZeros()));
}
TEST_F(BuiltinPolyfillTest, CountLeadingZeros_i32) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
let r : i32 = countLeadingZeros(15);
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn tint_count_leading_zeros(v : i32) -> i32 {
var x = u32(v);
let b16 = select(0u, 16u, (x <= 65535u));
@@ -91,19 +91,19 @@
}
)";
- auto got = Run<BuiltinPolyfill>(src, polyfillCountLeadingZeros());
+ auto got = Run<BuiltinPolyfill>(src, polyfillCountLeadingZeros());
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(BuiltinPolyfillTest, CountLeadingZeros_u32) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
let r : u32 = countLeadingZeros(15u);
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn tint_count_leading_zeros(v : u32) -> u32 {
var x = u32(v);
let b16 = select(0u, 16u, (x <= 65535u));
@@ -124,19 +124,19 @@
}
)";
- auto got = Run<BuiltinPolyfill>(src, polyfillCountLeadingZeros());
+ auto got = Run<BuiltinPolyfill>(src, polyfillCountLeadingZeros());
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(BuiltinPolyfillTest, CountLeadingZeros_vec3_i32) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
let r : vec3<i32> = countLeadingZeros(vec3<i32>(15));
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn tint_count_leading_zeros(v : vec3<i32>) -> vec3<i32> {
var x = vec3<u32>(v);
let b16 = select(vec3<u32>(0u), vec3<u32>(16u), (x <= vec3<u32>(65535u)));
@@ -157,19 +157,19 @@
}
)";
- auto got = Run<BuiltinPolyfill>(src, polyfillCountLeadingZeros());
+ auto got = Run<BuiltinPolyfill>(src, polyfillCountLeadingZeros());
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(BuiltinPolyfillTest, CountLeadingZeros_vec3_u32) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
let r : vec3<u32> = countLeadingZeros(vec3<u32>(15u));
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn tint_count_leading_zeros(v : vec3<u32>) -> vec3<u32> {
var x = vec3<u32>(v);
let b16 = select(vec3<u32>(0u), vec3<u32>(16u), (x <= vec3<u32>(65535u)));
@@ -190,41 +190,41 @@
}
)";
- auto got = Run<BuiltinPolyfill>(src, polyfillCountLeadingZeros());
+ auto got = Run<BuiltinPolyfill>(src, polyfillCountLeadingZeros());
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
////////////////////////////////////////////////////////////////////////////////
// countTrailingZeros
////////////////////////////////////////////////////////////////////////////////
DataMap polyfillCountTrailingZeros() {
- BuiltinPolyfill::Builtins builtins;
- builtins.count_trailing_zeros = true;
- DataMap data;
- data.Add<BuiltinPolyfill::Config>(builtins);
- return data;
+ BuiltinPolyfill::Builtins builtins;
+ builtins.count_trailing_zeros = true;
+ DataMap data;
+ data.Add<BuiltinPolyfill::Config>(builtins);
+ return data;
}
TEST_F(BuiltinPolyfillTest, ShouldRunCountTrailingZeros) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
countTrailingZeros(0xf);
}
)";
- EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));
- EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillCountTrailingZeros()));
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));
+ EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillCountTrailingZeros()));
}
TEST_F(BuiltinPolyfillTest, CountTrailingZeros_i32) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
let r : i32 = countTrailingZeros(15);
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn tint_count_trailing_zeros(v : i32) -> i32 {
var x = u32(v);
let b16 = select(16u, 0u, bool((x & 65535u)));
@@ -245,19 +245,19 @@
}
)";
- auto got = Run<BuiltinPolyfill>(src, polyfillCountTrailingZeros());
+ auto got = Run<BuiltinPolyfill>(src, polyfillCountTrailingZeros());
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(BuiltinPolyfillTest, CountTrailingZeros_u32) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
let r : u32 = countTrailingZeros(15u);
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn tint_count_trailing_zeros(v : u32) -> u32 {
var x = u32(v);
let b16 = select(16u, 0u, bool((x & 65535u)));
@@ -278,19 +278,19 @@
}
)";
- auto got = Run<BuiltinPolyfill>(src, polyfillCountTrailingZeros());
+ auto got = Run<BuiltinPolyfill>(src, polyfillCountTrailingZeros());
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(BuiltinPolyfillTest, CountTrailingZeros_vec3_i32) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
let r : vec3<i32> = countTrailingZeros(vec3<i32>(15));
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn tint_count_trailing_zeros(v : vec3<i32>) -> vec3<i32> {
var x = vec3<u32>(v);
let b16 = select(vec3<u32>(16u), vec3<u32>(0u), vec3<bool>((x & vec3<u32>(65535u))));
@@ -311,19 +311,19 @@
}
)";
- auto got = Run<BuiltinPolyfill>(src, polyfillCountTrailingZeros());
+ auto got = Run<BuiltinPolyfill>(src, polyfillCountTrailingZeros());
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(BuiltinPolyfillTest, CountTrailingZeros_vec3_u32) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
let r : vec3<u32> = countTrailingZeros(vec3<u32>(15u));
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn tint_count_trailing_zeros(v : vec3<u32>) -> vec3<u32> {
var x = vec3<u32>(v);
let b16 = select(vec3<u32>(16u), vec3<u32>(0u), vec3<bool>((x & vec3<u32>(65535u))));
@@ -344,46 +344,43 @@
}
)";
- auto got = Run<BuiltinPolyfill>(src, polyfillCountTrailingZeros());
+ auto got = Run<BuiltinPolyfill>(src, polyfillCountTrailingZeros());
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
////////////////////////////////////////////////////////////////////////////////
// extractBits
////////////////////////////////////////////////////////////////////////////////
DataMap polyfillExtractBits(Level level) {
- BuiltinPolyfill::Builtins builtins;
- builtins.extract_bits = level;
- DataMap data;
- data.Add<BuiltinPolyfill::Config>(builtins);
- return data;
+ BuiltinPolyfill::Builtins builtins;
+ builtins.extract_bits = level;
+ DataMap data;
+ data.Add<BuiltinPolyfill::Config>(builtins);
+ return data;
}
TEST_F(BuiltinPolyfillTest, ShouldRunExtractBits) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
extractBits(1234, 5u, 6u);
}
)";
- EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));
- EXPECT_FALSE(
- ShouldRun<BuiltinPolyfill>(src, polyfillExtractBits(Level::kNone)));
- EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(
- src, polyfillExtractBits(Level::kClampParameters)));
- EXPECT_TRUE(
- ShouldRun<BuiltinPolyfill>(src, polyfillExtractBits(Level::kFull)));
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src, polyfillExtractBits(Level::kNone)));
+ EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillExtractBits(Level::kClampParameters)));
+ EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillExtractBits(Level::kFull)));
}
TEST_F(BuiltinPolyfillTest, ExtractBits_Full_i32) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
let r : i32 = extractBits(1234, 5u, 6u);
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn tint_extract_bits(v : i32, offset : u32, count : u32) -> i32 {
let s = min(offset, 32u);
let e = min(32u, (s + count));
@@ -397,19 +394,19 @@
}
)";
- auto got = Run<BuiltinPolyfill>(src, polyfillExtractBits(Level::kFull));
+ auto got = Run<BuiltinPolyfill>(src, polyfillExtractBits(Level::kFull));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(BuiltinPolyfillTest, ExtractBits_Full_u32) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
let r : u32 = extractBits(1234u, 5u, 6u);
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn tint_extract_bits(v : u32, offset : u32, count : u32) -> u32 {
let s = min(offset, 32u);
let e = min(32u, (s + count));
@@ -423,19 +420,19 @@
}
)";
- auto got = Run<BuiltinPolyfill>(src, polyfillExtractBits(Level::kFull));
+ auto got = Run<BuiltinPolyfill>(src, polyfillExtractBits(Level::kFull));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(BuiltinPolyfillTest, ExtractBits_Full_vec3_i32) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
let r : vec3<i32> = extractBits(vec3<i32>(1234), 5u, 6u);
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn tint_extract_bits(v : vec3<i32>, offset : u32, count : u32) -> vec3<i32> {
let s = min(offset, 32u);
let e = min(32u, (s + count));
@@ -449,19 +446,19 @@
}
)";
- auto got = Run<BuiltinPolyfill>(src, polyfillExtractBits(Level::kFull));
+ auto got = Run<BuiltinPolyfill>(src, polyfillExtractBits(Level::kFull));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(BuiltinPolyfillTest, ExtractBits_Full_vec3_u32) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
let r : vec3<u32> = extractBits(vec3<u32>(1234u), 5u, 6u);
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn tint_extract_bits(v : vec3<u32>, offset : u32, count : u32) -> vec3<u32> {
let s = min(offset, 32u);
let e = min(32u, (s + count));
@@ -475,19 +472,19 @@
}
)";
- auto got = Run<BuiltinPolyfill>(src, polyfillExtractBits(Level::kFull));
+ auto got = Run<BuiltinPolyfill>(src, polyfillExtractBits(Level::kFull));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(BuiltinPolyfillTest, ExtractBits_Clamp_i32) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
let r : i32 = extractBits(1234, 5u, 6u);
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn tint_extract_bits(v : i32, offset : u32, count : u32) -> i32 {
let s = min(offset, 32u);
let e = min(32u, (s + count));
@@ -499,20 +496,19 @@
}
)";
- auto got =
- Run<BuiltinPolyfill>(src, polyfillExtractBits(Level::kClampParameters));
+ auto got = Run<BuiltinPolyfill>(src, polyfillExtractBits(Level::kClampParameters));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(BuiltinPolyfillTest, ExtractBits_Clamp_u32) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
let r : u32 = extractBits(1234u, 5u, 6u);
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn tint_extract_bits(v : u32, offset : u32, count : u32) -> u32 {
let s = min(offset, 32u);
let e = min(32u, (s + count));
@@ -524,20 +520,19 @@
}
)";
- auto got =
- Run<BuiltinPolyfill>(src, polyfillExtractBits(Level::kClampParameters));
+ auto got = Run<BuiltinPolyfill>(src, polyfillExtractBits(Level::kClampParameters));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(BuiltinPolyfillTest, ExtractBits_Clamp_vec3_i32) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
let r : vec3<i32> = extractBits(vec3<i32>(1234), 5u, 6u);
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn tint_extract_bits(v : vec3<i32>, offset : u32, count : u32) -> vec3<i32> {
let s = min(offset, 32u);
let e = min(32u, (s + count));
@@ -549,20 +544,19 @@
}
)";
- auto got =
- Run<BuiltinPolyfill>(src, polyfillExtractBits(Level::kClampParameters));
+ auto got = Run<BuiltinPolyfill>(src, polyfillExtractBits(Level::kClampParameters));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(BuiltinPolyfillTest, ExtractBits_Clamp_vec3_u32) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
let r : vec3<u32> = extractBits(vec3<u32>(1234u), 5u, 6u);
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn tint_extract_bits(v : vec3<u32>, offset : u32, count : u32) -> vec3<u32> {
let s = min(offset, 32u);
let e = min(32u, (s + count));
@@ -574,42 +568,41 @@
}
)";
- auto got =
- Run<BuiltinPolyfill>(src, polyfillExtractBits(Level::kClampParameters));
+ auto got = Run<BuiltinPolyfill>(src, polyfillExtractBits(Level::kClampParameters));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
////////////////////////////////////////////////////////////////////////////////
// firstLeadingBit
////////////////////////////////////////////////////////////////////////////////
DataMap polyfillFirstLeadingBit() {
- BuiltinPolyfill::Builtins builtins;
- builtins.first_leading_bit = true;
- DataMap data;
- data.Add<BuiltinPolyfill::Config>(builtins);
- return data;
+ BuiltinPolyfill::Builtins builtins;
+ builtins.first_leading_bit = true;
+ DataMap data;
+ data.Add<BuiltinPolyfill::Config>(builtins);
+ return data;
}
TEST_F(BuiltinPolyfillTest, ShouldRunFirstLeadingBit) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
firstLeadingBit(0xf);
}
)";
- EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));
- EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillFirstLeadingBit()));
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));
+ EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillFirstLeadingBit()));
}
TEST_F(BuiltinPolyfillTest, FirstLeadingBit_i32) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
let r : i32 = firstLeadingBit(15);
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn tint_first_leading_bit(v : i32) -> i32 {
var x = select(u32(v), u32(~(v)), (v < 0));
let b16 = select(0u, 16u, bool((x & 4294901760u)));
@@ -630,19 +623,19 @@
}
)";
- auto got = Run<BuiltinPolyfill>(src, polyfillFirstLeadingBit());
+ auto got = Run<BuiltinPolyfill>(src, polyfillFirstLeadingBit());
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(BuiltinPolyfillTest, FirstLeadingBit_u32) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
let r : u32 = firstLeadingBit(15u);
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn tint_first_leading_bit(v : u32) -> u32 {
var x = v;
let b16 = select(0u, 16u, bool((x & 4294901760u)));
@@ -663,19 +656,19 @@
}
)";
- auto got = Run<BuiltinPolyfill>(src, polyfillFirstLeadingBit());
+ auto got = Run<BuiltinPolyfill>(src, polyfillFirstLeadingBit());
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(BuiltinPolyfillTest, FirstLeadingBit_vec3_i32) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
let r : vec3<i32> = firstLeadingBit(vec3<i32>(15));
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn tint_first_leading_bit(v : vec3<i32>) -> vec3<i32> {
var x = select(vec3<u32>(v), vec3<u32>(~(v)), (v < vec3<i32>(0)));
let b16 = select(vec3<u32>(0u), vec3<u32>(16u), vec3<bool>((x & vec3<u32>(4294901760u))));
@@ -696,19 +689,19 @@
}
)";
- auto got = Run<BuiltinPolyfill>(src, polyfillFirstLeadingBit());
+ auto got = Run<BuiltinPolyfill>(src, polyfillFirstLeadingBit());
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(BuiltinPolyfillTest, FirstLeadingBit_vec3_u32) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
let r : vec3<u32> = firstLeadingBit(vec3<u32>(15u));
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn tint_first_leading_bit(v : vec3<u32>) -> vec3<u32> {
var x = v;
let b16 = select(vec3<u32>(0u), vec3<u32>(16u), vec3<bool>((x & vec3<u32>(4294901760u))));
@@ -729,41 +722,41 @@
}
)";
- auto got = Run<BuiltinPolyfill>(src, polyfillFirstLeadingBit());
+ auto got = Run<BuiltinPolyfill>(src, polyfillFirstLeadingBit());
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
////////////////////////////////////////////////////////////////////////////////
// firstTrailingBit
////////////////////////////////////////////////////////////////////////////////
DataMap polyfillFirstTrailingBit() {
- BuiltinPolyfill::Builtins builtins;
- builtins.first_trailing_bit = true;
- DataMap data;
- data.Add<BuiltinPolyfill::Config>(builtins);
- return data;
+ BuiltinPolyfill::Builtins builtins;
+ builtins.first_trailing_bit = true;
+ DataMap data;
+ data.Add<BuiltinPolyfill::Config>(builtins);
+ return data;
}
TEST_F(BuiltinPolyfillTest, ShouldRunFirstTrailingBit) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
firstTrailingBit(0xf);
}
)";
- EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));
- EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillFirstTrailingBit()));
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));
+ EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillFirstTrailingBit()));
}
TEST_F(BuiltinPolyfillTest, FirstTrailingBit_i32) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
let r : i32 = firstTrailingBit(15);
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn tint_first_trailing_bit(v : i32) -> i32 {
var x = u32(v);
let b16 = select(16u, 0u, bool((x & 65535u)));
@@ -784,19 +777,19 @@
}
)";
- auto got = Run<BuiltinPolyfill>(src, polyfillFirstTrailingBit());
+ auto got = Run<BuiltinPolyfill>(src, polyfillFirstTrailingBit());
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(BuiltinPolyfillTest, FirstTrailingBit_u32) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
let r : u32 = firstTrailingBit(15u);
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn tint_first_trailing_bit(v : u32) -> u32 {
var x = u32(v);
let b16 = select(16u, 0u, bool((x & 65535u)));
@@ -817,19 +810,19 @@
}
)";
- auto got = Run<BuiltinPolyfill>(src, polyfillFirstTrailingBit());
+ auto got = Run<BuiltinPolyfill>(src, polyfillFirstTrailingBit());
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(BuiltinPolyfillTest, FirstTrailingBit_vec3_i32) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
let r : vec3<i32> = firstTrailingBit(vec3<i32>(15));
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn tint_first_trailing_bit(v : vec3<i32>) -> vec3<i32> {
var x = vec3<u32>(v);
let b16 = select(vec3<u32>(16u), vec3<u32>(0u), vec3<bool>((x & vec3<u32>(65535u))));
@@ -850,19 +843,19 @@
}
)";
- auto got = Run<BuiltinPolyfill>(src, polyfillFirstTrailingBit());
+ auto got = Run<BuiltinPolyfill>(src, polyfillFirstTrailingBit());
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(BuiltinPolyfillTest, FirstTrailingBit_vec3_u32) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
let r : vec3<u32> = firstTrailingBit(vec3<u32>(15u));
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn tint_first_trailing_bit(v : vec3<u32>) -> vec3<u32> {
var x = vec3<u32>(v);
let b16 = select(vec3<u32>(16u), vec3<u32>(0u), vec3<bool>((x & vec3<u32>(65535u))));
@@ -883,46 +876,43 @@
}
)";
- auto got = Run<BuiltinPolyfill>(src, polyfillFirstTrailingBit());
+ auto got = Run<BuiltinPolyfill>(src, polyfillFirstTrailingBit());
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
////////////////////////////////////////////////////////////////////////////////
// insertBits
////////////////////////////////////////////////////////////////////////////////
DataMap polyfillInsertBits(Level level) {
- BuiltinPolyfill::Builtins builtins;
- builtins.insert_bits = level;
- DataMap data;
- data.Add<BuiltinPolyfill::Config>(builtins);
- return data;
+ BuiltinPolyfill::Builtins builtins;
+ builtins.insert_bits = level;
+ DataMap data;
+ data.Add<BuiltinPolyfill::Config>(builtins);
+ return data;
}
TEST_F(BuiltinPolyfillTest, ShouldRunInsertBits) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
insertBits(1234, 5678, 5u, 6u);
}
)";
- EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));
- EXPECT_FALSE(
- ShouldRun<BuiltinPolyfill>(src, polyfillInsertBits(Level::kNone)));
- EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(
- src, polyfillInsertBits(Level::kClampParameters)));
- EXPECT_TRUE(
- ShouldRun<BuiltinPolyfill>(src, polyfillInsertBits(Level::kFull)));
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src, polyfillInsertBits(Level::kNone)));
+ EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillInsertBits(Level::kClampParameters)));
+ EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillInsertBits(Level::kFull)));
}
TEST_F(BuiltinPolyfillTest, InsertBits_Full_i32) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
let r : i32 = insertBits(1234, 5678, 5u, 6u);
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn tint_insert_bits(v : i32, n : i32, offset : u32, count : u32) -> i32 {
let s = min(offset, 32u);
let e = min(32u, (s + count));
@@ -935,19 +925,19 @@
}
)";
- auto got = Run<BuiltinPolyfill>(src, polyfillInsertBits(Level::kFull));
+ auto got = Run<BuiltinPolyfill>(src, polyfillInsertBits(Level::kFull));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(BuiltinPolyfillTest, InsertBits_Full_u32) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
let r : u32 = insertBits(1234u, 5678u, 5u, 6u);
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn tint_insert_bits(v : u32, n : u32, offset : u32, count : u32) -> u32 {
let s = min(offset, 32u);
let e = min(32u, (s + count));
@@ -960,19 +950,19 @@
}
)";
- auto got = Run<BuiltinPolyfill>(src, polyfillInsertBits(Level::kFull));
+ auto got = Run<BuiltinPolyfill>(src, polyfillInsertBits(Level::kFull));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(BuiltinPolyfillTest, InsertBits_Full_vec3_i32) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
let r : vec3<i32> = insertBits(vec3<i32>(1234), vec3<i32>(5678), 5u, 6u);
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn tint_insert_bits(v : vec3<i32>, n : vec3<i32>, offset : u32, count : u32) -> vec3<i32> {
let s = min(offset, 32u);
let e = min(32u, (s + count));
@@ -985,19 +975,19 @@
}
)";
- auto got = Run<BuiltinPolyfill>(src, polyfillInsertBits(Level::kFull));
+ auto got = Run<BuiltinPolyfill>(src, polyfillInsertBits(Level::kFull));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(BuiltinPolyfillTest, InsertBits_Full_vec3_u32) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
let r : vec3<u32> = insertBits(vec3<u32>(1234u), vec3<u32>(5678u), 5u, 6u);
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn tint_insert_bits(v : vec3<u32>, n : vec3<u32>, offset : u32, count : u32) -> vec3<u32> {
let s = min(offset, 32u);
let e = min(32u, (s + count));
@@ -1010,19 +1000,19 @@
}
)";
- auto got = Run<BuiltinPolyfill>(src, polyfillInsertBits(Level::kFull));
+ auto got = Run<BuiltinPolyfill>(src, polyfillInsertBits(Level::kFull));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(BuiltinPolyfillTest, InsertBits_Clamp_i32) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
let r : i32 = insertBits(1234, 5678, 5u, 6u);
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn tint_insert_bits(v : i32, n : i32, offset : u32, count : u32) -> i32 {
let s = min(offset, 32u);
let e = min(32u, (s + count));
@@ -1034,20 +1024,19 @@
}
)";
- auto got =
- Run<BuiltinPolyfill>(src, polyfillInsertBits(Level::kClampParameters));
+ auto got = Run<BuiltinPolyfill>(src, polyfillInsertBits(Level::kClampParameters));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(BuiltinPolyfillTest, InsertBits_Clamp_u32) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
let r : u32 = insertBits(1234u, 5678u, 5u, 6u);
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn tint_insert_bits(v : u32, n : u32, offset : u32, count : u32) -> u32 {
let s = min(offset, 32u);
let e = min(32u, (s + count));
@@ -1059,20 +1048,19 @@
}
)";
- auto got =
- Run<BuiltinPolyfill>(src, polyfillInsertBits(Level::kClampParameters));
+ auto got = Run<BuiltinPolyfill>(src, polyfillInsertBits(Level::kClampParameters));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(BuiltinPolyfillTest, InsertBits_Clamp_vec3_i32) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
let r : vec3<i32> = insertBits(vec3<i32>(1234), vec3<i32>(5678), 5u, 6u);
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn tint_insert_bits(v : vec3<i32>, n : vec3<i32>, offset : u32, count : u32) -> vec3<i32> {
let s = min(offset, 32u);
let e = min(32u, (s + count));
@@ -1084,20 +1072,19 @@
}
)";
- auto got =
- Run<BuiltinPolyfill>(src, polyfillInsertBits(Level::kClampParameters));
+ auto got = Run<BuiltinPolyfill>(src, polyfillInsertBits(Level::kClampParameters));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(BuiltinPolyfillTest, InsertBits_Clamp_vec3_u32) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
let r : vec3<u32> = insertBits(vec3<u32>(1234u), vec3<u32>(5678u), 5u, 6u);
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn tint_insert_bits(v : vec3<u32>, n : vec3<u32>, offset : u32, count : u32) -> vec3<u32> {
let s = min(offset, 32u);
let e = min(32u, (s + count));
@@ -1109,10 +1096,9 @@
}
)";
- auto got =
- Run<BuiltinPolyfill>(src, polyfillInsertBits(Level::kClampParameters));
+ auto got = Run<BuiltinPolyfill>(src, polyfillInsertBits(Level::kClampParameters));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
} // namespace
diff --git a/src/tint/transform/calculate_array_length.cc b/src/tint/transform/calculate_array_length.cc
index b2ba299..c2128c7 100644
--- a/src/tint/transform/calculate_array_length.cc
+++ b/src/tint/transform/calculate_array_length.cc
@@ -31,8 +31,7 @@
#include "src/tint/utils/map.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::CalculateArrayLength);
-TINT_INSTANTIATE_TYPEINFO(
- tint::transform::CalculateArrayLength::BufferSizeIntrinsic);
+TINT_INSTANTIATE_TYPEINFO(tint::transform::CalculateArrayLength::BufferSizeIntrinsic);
namespace tint::transform {
@@ -41,201 +40,189 @@
/// ArrayUsage describes a runtime array usage.
/// It is used as a key by the array_length_by_usage map.
struct ArrayUsage {
- ast::BlockStatement const* const block;
- sem::Variable const* const buffer;
- bool operator==(const ArrayUsage& rhs) const {
- return block == rhs.block && buffer == rhs.buffer;
- }
- struct Hasher {
- inline std::size_t operator()(const ArrayUsage& u) const {
- return utils::Hash(u.block, u.buffer);
+ ast::BlockStatement const* const block;
+ sem::Variable const* const buffer;
+ bool operator==(const ArrayUsage& rhs) const {
+ return block == rhs.block && buffer == rhs.buffer;
}
- };
+ struct Hasher {
+ inline std::size_t operator()(const ArrayUsage& u) const {
+ return utils::Hash(u.block, u.buffer);
+ }
+ };
};
} // namespace
-CalculateArrayLength::BufferSizeIntrinsic::BufferSizeIntrinsic(ProgramID pid)
- : Base(pid) {}
+CalculateArrayLength::BufferSizeIntrinsic::BufferSizeIntrinsic(ProgramID pid) : Base(pid) {}
CalculateArrayLength::BufferSizeIntrinsic::~BufferSizeIntrinsic() = default;
std::string CalculateArrayLength::BufferSizeIntrinsic::InternalName() const {
- return "intrinsic_buffer_size";
+ return "intrinsic_buffer_size";
}
-const CalculateArrayLength::BufferSizeIntrinsic*
-CalculateArrayLength::BufferSizeIntrinsic::Clone(CloneContext* ctx) const {
- return ctx->dst->ASTNodes().Create<CalculateArrayLength::BufferSizeIntrinsic>(
- ctx->dst->ID());
+const CalculateArrayLength::BufferSizeIntrinsic* CalculateArrayLength::BufferSizeIntrinsic::Clone(
+ CloneContext* ctx) const {
+ return ctx->dst->ASTNodes().Create<CalculateArrayLength::BufferSizeIntrinsic>(ctx->dst->ID());
}
CalculateArrayLength::CalculateArrayLength() = default;
CalculateArrayLength::~CalculateArrayLength() = default;
-bool CalculateArrayLength::ShouldRun(const Program* program,
- const DataMap&) const {
- for (auto* fn : program->AST().Functions()) {
- if (auto* sem_fn = program->Sem().Get(fn)) {
- for (auto* builtin : sem_fn->DirectlyCalledBuiltins()) {
- if (builtin->Type() == sem::BuiltinType::kArrayLength) {
- return true;
+bool CalculateArrayLength::ShouldRun(const Program* program, const DataMap&) const {
+ for (auto* fn : program->AST().Functions()) {
+ if (auto* sem_fn = program->Sem().Get(fn)) {
+ for (auto* builtin : sem_fn->DirectlyCalledBuiltins()) {
+ if (builtin->Type() == sem::BuiltinType::kArrayLength) {
+ return true;
+ }
+ }
}
- }
}
- }
- return false;
+ return false;
}
-void CalculateArrayLength::Run(CloneContext& ctx,
- const DataMap&,
- DataMap&) const {
- auto& sem = ctx.src->Sem();
+void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
+ auto& sem = ctx.src->Sem();
- // get_buffer_size_intrinsic() emits the function decorated with
- // BufferSizeIntrinsic that is transformed by the HLSL writer into a call to
- // [RW]ByteAddressBuffer.GetDimensions().
- std::unordered_map<const sem::Type*, Symbol> buffer_size_intrinsics;
- auto get_buffer_size_intrinsic = [&](const sem::Type* buffer_type) {
- return utils::GetOrCreate(buffer_size_intrinsics, buffer_type, [&] {
- auto name = ctx.dst->Sym();
- auto* type = CreateASTTypeFor(ctx, buffer_type);
- auto* disable_validation = ctx.dst->Disable(
- ast::DisabledValidation::kIgnoreConstructibleFunctionParameter);
- ctx.dst->AST().AddFunction(ctx.dst->create<ast::Function>(
- name,
- ast::VariableList{
- // Note: The buffer parameter requires the kStorage StorageClass
- // in order for HLSL to emit this as a ByteAddressBuffer.
- ctx.dst->create<ast::Variable>(
- ctx.dst->Sym("buffer"), ast::StorageClass::kStorage,
- ast::Access::kUndefined, type, true, false, nullptr,
- ast::AttributeList{disable_validation}),
- ctx.dst->Param("result",
- ctx.dst->ty.pointer(ctx.dst->ty.u32(),
- ast::StorageClass::kFunction)),
- },
- ctx.dst->ty.void_(), nullptr,
- ast::AttributeList{
- ctx.dst->ASTNodes().Create<BufferSizeIntrinsic>(ctx.dst->ID()),
- },
- ast::AttributeList{}));
+ // get_buffer_size_intrinsic() emits the function decorated with
+ // BufferSizeIntrinsic that is transformed by the HLSL writer into a call to
+ // [RW]ByteAddressBuffer.GetDimensions().
+ std::unordered_map<const sem::Type*, Symbol> buffer_size_intrinsics;
+ auto get_buffer_size_intrinsic = [&](const sem::Type* buffer_type) {
+ return utils::GetOrCreate(buffer_size_intrinsics, buffer_type, [&] {
+ auto name = ctx.dst->Sym();
+ auto* type = CreateASTTypeFor(ctx, buffer_type);
+ auto* disable_validation =
+ ctx.dst->Disable(ast::DisabledValidation::kIgnoreConstructibleFunctionParameter);
+ ctx.dst->AST().AddFunction(ctx.dst->create<ast::Function>(
+ name,
+ ast::VariableList{
+ // Note: The buffer parameter requires the kStorage StorageClass
+ // in order for HLSL to emit this as a ByteAddressBuffer.
+ ctx.dst->create<ast::Variable>(ctx.dst->Sym("buffer"),
+ ast::StorageClass::kStorage,
+ ast::Access::kUndefined, type, true, false,
+ nullptr, ast::AttributeList{disable_validation}),
+ ctx.dst->Param("result", ctx.dst->ty.pointer(ctx.dst->ty.u32(),
+ ast::StorageClass::kFunction)),
+ },
+ ctx.dst->ty.void_(), nullptr,
+ ast::AttributeList{
+ ctx.dst->ASTNodes().Create<BufferSizeIntrinsic>(ctx.dst->ID()),
+ },
+ ast::AttributeList{}));
- return name;
- });
- };
+ return name;
+ });
+ };
- std::unordered_map<ArrayUsage, Symbol, ArrayUsage::Hasher>
- array_length_by_usage;
+ std::unordered_map<ArrayUsage, Symbol, ArrayUsage::Hasher> array_length_by_usage;
- // Find all the arrayLength() calls...
- for (auto* node : ctx.src->ASTNodes().Objects()) {
- if (auto* call_expr = node->As<ast::CallExpression>()) {
- auto* call = sem.Get(call_expr);
- if (auto* builtin = call->Target()->As<sem::Builtin>()) {
- if (builtin->Type() == sem::BuiltinType::kArrayLength) {
- // We're dealing with an arrayLength() call
+ // Find all the arrayLength() calls...
+ for (auto* node : ctx.src->ASTNodes().Objects()) {
+ if (auto* call_expr = node->As<ast::CallExpression>()) {
+ auto* call = sem.Get(call_expr);
+ if (auto* builtin = call->Target()->As<sem::Builtin>()) {
+ if (builtin->Type() == sem::BuiltinType::kArrayLength) {
+ // We're dealing with an arrayLength() call
- // A runtime-sized array can only appear as the store type of a
- // variable, or the last element of a structure (which cannot itself
- // be nested). Given that we require SimplifyPointers, we can assume
- // that the arrayLength() call has one of two forms:
- // arrayLength(&struct_var.array_member)
- // arrayLength(&array_var)
- auto* arg = call_expr->args[0];
- auto* address_of = arg->As<ast::UnaryOpExpression>();
- if (!address_of || address_of->op != ast::UnaryOp::kAddressOf) {
- TINT_ICE(Transform, ctx.dst->Diagnostics())
- << "arrayLength() expected address-of, got "
- << arg->TypeInfo().name;
- }
- auto* storage_buffer_expr = address_of->expr;
- if (auto* accessor =
- storage_buffer_expr->As<ast::MemberAccessorExpression>()) {
- storage_buffer_expr = accessor->structure;
- }
- auto* storage_buffer_sem =
- sem.Get<sem::VariableUser>(storage_buffer_expr);
- if (!storage_buffer_sem) {
- TINT_ICE(Transform, ctx.dst->Diagnostics())
- << "expected form of arrayLength argument to be &array_var or "
- "&struct_var.array_member";
- break;
- }
- auto* storage_buffer_var = storage_buffer_sem->Variable();
- auto* storage_buffer_type = storage_buffer_sem->Type()->UnwrapRef();
+ // A runtime-sized array can only appear as the store type of a
+ // variable, or the last element of a structure (which cannot itself
+ // be nested). Given that we require SimplifyPointers, we can assume
+ // that the arrayLength() call has one of two forms:
+ // arrayLength(&struct_var.array_member)
+ // arrayLength(&array_var)
+ auto* arg = call_expr->args[0];
+ auto* address_of = arg->As<ast::UnaryOpExpression>();
+ if (!address_of || address_of->op != ast::UnaryOp::kAddressOf) {
+ TINT_ICE(Transform, ctx.dst->Diagnostics())
+ << "arrayLength() expected address-of, got " << arg->TypeInfo().name;
+ }
+ auto* storage_buffer_expr = address_of->expr;
+ if (auto* accessor = storage_buffer_expr->As<ast::MemberAccessorExpression>()) {
+ storage_buffer_expr = accessor->structure;
+ }
+ auto* storage_buffer_sem = sem.Get<sem::VariableUser>(storage_buffer_expr);
+ if (!storage_buffer_sem) {
+ TINT_ICE(Transform, ctx.dst->Diagnostics())
+ << "expected form of arrayLength argument to be &array_var or "
+ "&struct_var.array_member";
+ break;
+ }
+ auto* storage_buffer_var = storage_buffer_sem->Variable();
+ auto* storage_buffer_type = storage_buffer_sem->Type()->UnwrapRef();
- // Generate BufferSizeIntrinsic for this storage type if we haven't
- // already
- auto buffer_size = get_buffer_size_intrinsic(storage_buffer_type);
+ // Generate BufferSizeIntrinsic for this storage type if we haven't
+ // already
+ auto buffer_size = get_buffer_size_intrinsic(storage_buffer_type);
- // Find the current statement block
- auto* block = call->Stmt()->Block()->Declaration();
+ // Find the current statement block
+ auto* block = call->Stmt()->Block()->Declaration();
- auto array_length = utils::GetOrCreate(
- array_length_by_usage, {block, storage_buffer_var}, [&] {
- // First time this array length is used for this block.
- // Let's calculate it.
+ auto array_length =
+ utils::GetOrCreate(array_length_by_usage, {block, storage_buffer_var}, [&] {
+ // First time this array length is used for this block.
+ // Let's calculate it.
- // Construct the variable that'll hold the result of
- // RWByteAddressBuffer.GetDimensions()
- auto* buffer_size_result = ctx.dst->Decl(
- ctx.dst->Var(ctx.dst->Sym(), ctx.dst->ty.u32(),
- ast::StorageClass::kNone, ctx.dst->Expr(0u)));
+ // Construct the variable that'll hold the result of
+ // RWByteAddressBuffer.GetDimensions()
+ auto* buffer_size_result = ctx.dst->Decl(
+ ctx.dst->Var(ctx.dst->Sym(), ctx.dst->ty.u32(),
+ ast::StorageClass::kNone, ctx.dst->Expr(0u)));
- // Call storage_buffer.GetDimensions(&buffer_size_result)
- auto* call_get_dims = ctx.dst->CallStmt(ctx.dst->Call(
- // BufferSizeIntrinsic(X, ARGS...) is
- // translated to:
- // X.GetDimensions(ARGS..) by the writer
- buffer_size, ctx.Clone(storage_buffer_expr),
- ctx.dst->AddressOf(
- ctx.dst->Expr(buffer_size_result->variable->symbol))));
+ // Call storage_buffer.GetDimensions(&buffer_size_result)
+ auto* call_get_dims = ctx.dst->CallStmt(ctx.dst->Call(
+ // BufferSizeIntrinsic(X, ARGS...) is
+ // translated to:
+ // X.GetDimensions(ARGS..) by the writer
+ buffer_size, ctx.Clone(storage_buffer_expr),
+ ctx.dst->AddressOf(
+ ctx.dst->Expr(buffer_size_result->variable->symbol))));
- // Calculate actual array length
- // total_storage_buffer_size - array_offset
- // array_length = ----------------------------------------
- // array_stride
- auto name = ctx.dst->Sym();
- const ast::Expression* total_size =
- ctx.dst->Expr(buffer_size_result->variable);
- const sem::Array* array_type = nullptr;
- if (auto* str = storage_buffer_type->As<sem::Struct>()) {
- // The variable is a struct, so subtract the byte offset of
- // the array member.
- auto* array_member_sem = str->Members().back();
- array_type = array_member_sem->Type()->As<sem::Array>();
- total_size =
- ctx.dst->Sub(total_size, array_member_sem->Offset());
- } else if (auto* arr = storage_buffer_type->As<sem::Array>()) {
- array_type = arr;
- } else {
- TINT_ICE(Transform, ctx.dst->Diagnostics())
- << "expected form of arrayLength argument to be "
- "&array_var or &struct_var.array_member";
- return name;
+ // Calculate actual array length
+ // total_storage_buffer_size - array_offset
+ // array_length = ----------------------------------------
+ // array_stride
+ auto name = ctx.dst->Sym();
+ const ast::Expression* total_size =
+ ctx.dst->Expr(buffer_size_result->variable);
+ const sem::Array* array_type = nullptr;
+ if (auto* str = storage_buffer_type->As<sem::Struct>()) {
+ // The variable is a struct, so subtract the byte offset of
+ // the array member.
+ auto* array_member_sem = str->Members().back();
+ array_type = array_member_sem->Type()->As<sem::Array>();
+ total_size = ctx.dst->Sub(total_size, array_member_sem->Offset());
+ } else if (auto* arr = storage_buffer_type->As<sem::Array>()) {
+ array_type = arr;
+ } else {
+ TINT_ICE(Transform, ctx.dst->Diagnostics())
+ << "expected form of arrayLength argument to be "
+ "&array_var or &struct_var.array_member";
+ return name;
+ }
+ uint32_t array_stride = array_type->Size();
+ auto* array_length_var = ctx.dst->Decl(ctx.dst->Let(
+ name, ctx.dst->ty.u32(), ctx.dst->Div(total_size, array_stride)));
+
+ // Insert the array length calculations at the top of the block
+ ctx.InsertBefore(block->statements, block->statements[0],
+ buffer_size_result);
+ ctx.InsertBefore(block->statements, block->statements[0],
+ call_get_dims);
+ ctx.InsertBefore(block->statements, block->statements[0],
+ array_length_var);
+ return name;
+ });
+
+ // Replace the call to arrayLength() with the array length variable
+ ctx.Replace(call_expr, ctx.dst->Expr(array_length));
}
- uint32_t array_stride = array_type->Size();
- auto* array_length_var = ctx.dst->Decl(
- ctx.dst->Let(name, ctx.dst->ty.u32(),
- ctx.dst->Div(total_size, array_stride)));
-
- // Insert the array length calculations at the top of the block
- ctx.InsertBefore(block->statements, block->statements[0],
- buffer_size_result);
- ctx.InsertBefore(block->statements, block->statements[0],
- call_get_dims);
- ctx.InsertBefore(block->statements, block->statements[0],
- array_length_var);
- return name;
- });
-
- // Replace the call to arrayLength() with the array length variable
- ctx.Replace(call_expr, ctx.dst->Expr(array_length));
+ }
}
- }
}
- }
- ctx.Clone();
+ ctx.Clone();
}
} // namespace tint::transform
diff --git a/src/tint/transform/calculate_array_length.h b/src/tint/transform/calculate_array_length.h
index 344f6f0..401e081 100644
--- a/src/tint/transform/calculate_array_length.h
+++ b/src/tint/transform/calculate_array_length.h
@@ -32,50 +32,45 @@
///
/// @note Depends on the following transforms to have been run first:
/// * SimplifyPointers
-class CalculateArrayLength final
- : public Castable<CalculateArrayLength, Transform> {
- public:
- /// BufferSizeIntrinsic is an InternalAttribute that's applied to intrinsic
- /// functions used to obtain the runtime size of a storage buffer.
- class BufferSizeIntrinsic final
- : public Castable<BufferSizeIntrinsic, ast::InternalAttribute> {
- public:
+class CalculateArrayLength final : public Castable<CalculateArrayLength, Transform> {
+ public:
+ /// BufferSizeIntrinsic is an InternalAttribute that's applied to intrinsic
+ /// functions used to obtain the runtime size of a storage buffer.
+ class BufferSizeIntrinsic final : public Castable<BufferSizeIntrinsic, ast::InternalAttribute> {
+ public:
+ /// Constructor
+ /// @param program_id the identifier of the program that owns this node
+ explicit BufferSizeIntrinsic(ProgramID program_id);
+ /// Destructor
+ ~BufferSizeIntrinsic() override;
+
+ /// @return "buffer_size"
+ std::string InternalName() const override;
+
+ /// Performs a deep clone of this object using the CloneContext `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned object
+ const BufferSizeIntrinsic* Clone(CloneContext* ctx) const override;
+ };
+
/// Constructor
- /// @param program_id the identifier of the program that owns this node
- explicit BufferSizeIntrinsic(ProgramID program_id);
+ CalculateArrayLength();
/// Destructor
- ~BufferSizeIntrinsic() override;
+ ~CalculateArrayLength() override;
- /// @return "buffer_size"
- std::string InternalName() const override;
+ /// @param program the program to inspect
+ /// @param data optional extra transform-specific input data
+ /// @returns true if this transform should be run for the given program
+ bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
- /// Performs a deep clone of this object using the CloneContext `ctx`.
- /// @param ctx the clone context
- /// @return the newly cloned object
- const BufferSizeIntrinsic* Clone(CloneContext* ctx) const override;
- };
-
- /// Constructor
- CalculateArrayLength();
- /// Destructor
- ~CalculateArrayLength() override;
-
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program,
- const DataMap& data = {}) const override;
-
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx,
- const DataMap& inputs,
- DataMap& outputs) const override;
+ protected:
+ /// Runs the transform using the CloneContext built for transforming a
+ /// program. Run() is responsible for calling Clone() on the CloneContext.
+ /// @param ctx the CloneContext primed with the input program and
+ /// ProgramBuilder
+ /// @param inputs optional extra transform-specific input data
+ /// @param outputs optional extra transform-specific output data
+ void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/calculate_array_length_test.cc b/src/tint/transform/calculate_array_length_test.cc
index dec1698..9c7c3ac 100644
--- a/src/tint/transform/calculate_array_length_test.cc
+++ b/src/tint/transform/calculate_array_length_test.cc
@@ -24,13 +24,13 @@
using CalculateArrayLengthTest = TransformTest;
TEST_F(CalculateArrayLengthTest, ShouldRunEmptyModule) {
- auto* src = R"()";
+ auto* src = R"()";
- EXPECT_FALSE(ShouldRun<CalculateArrayLength>(src));
+ EXPECT_FALSE(ShouldRun<CalculateArrayLength>(src));
}
TEST_F(CalculateArrayLengthTest, ShouldRunNoArrayLength) {
- auto* src = R"(
+ auto* src = R"(
struct SB {
x : i32,
arr : array<i32>,
@@ -43,11 +43,11 @@
}
)";
- EXPECT_FALSE(ShouldRun<CalculateArrayLength>(src));
+ EXPECT_FALSE(ShouldRun<CalculateArrayLength>(src));
}
TEST_F(CalculateArrayLengthTest, ShouldRunWithArrayLength) {
- auto* src = R"(
+ auto* src = R"(
struct SB {
x : i32,
arr : array<i32>,
@@ -61,11 +61,11 @@
}
)";
- EXPECT_TRUE(ShouldRun<CalculateArrayLength>(src));
+ EXPECT_TRUE(ShouldRun<CalculateArrayLength>(src));
}
TEST_F(CalculateArrayLengthTest, BasicArray) {
- auto* src = R"(
+ auto* src = R"(
@group(0) @binding(0) var<storage, read> sb : array<i32>;
@stage(compute) @workgroup_size(1)
@@ -74,7 +74,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@internal(intrinsic_buffer_size)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : array<i32>, result : ptr<function, u32>)
@@ -89,13 +89,13 @@
}
)";
- auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
+ auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CalculateArrayLengthTest, BasicInStruct) {
- auto* src = R"(
+ auto* src = R"(
struct SB {
x : i32,
arr : array<i32>,
@@ -109,7 +109,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@internal(intrinsic_buffer_size)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB, result : ptr<function, u32>)
@@ -129,13 +129,13 @@
}
)";
- auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
+ auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CalculateArrayLengthTest, ArrayOfStruct) {
- auto* src = R"(
+ auto* src = R"(
struct S {
f : f32,
}
@@ -147,7 +147,7 @@
let len = arrayLength(&arr);
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@internal(intrinsic_buffer_size)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : array<S>, result : ptr<function, u32>)
@@ -166,13 +166,13 @@
}
)";
- auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
+ auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CalculateArrayLengthTest, ArrayOfArrayOfStruct) {
- auto* src = R"(
+ auto* src = R"(
struct S {
f : f32,
}
@@ -184,7 +184,7 @@
let len = arrayLength(&arr);
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@internal(intrinsic_buffer_size)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : array<array<S, 4u>>, result : ptr<function, u32>)
@@ -203,13 +203,13 @@
}
)";
- auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
+ auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CalculateArrayLengthTest, InSameBlock) {
- auto* src = R"(
+ auto* src = R"(
@group(0) @binding(0) var<storage, read> sb : array<i32>;;
@stage(compute) @workgroup_size(1)
@@ -220,7 +220,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@internal(intrinsic_buffer_size)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : array<i32>, result : ptr<function, u32>)
@@ -237,13 +237,13 @@
}
)";
- auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
+ auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CalculateArrayLengthTest, InSameBlock_Struct) {
- auto* src = R"(
+ auto* src = R"(
struct SB {
x : i32,
arr : array<i32>,
@@ -259,7 +259,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@internal(intrinsic_buffer_size)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB, result : ptr<function, u32>)
@@ -281,13 +281,13 @@
}
)";
- auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
+ auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CalculateArrayLengthTest, Nested) {
- auto* src = R"(
+ auto* src = R"(
struct SB {
x : i32,
arr : array<i32>,
@@ -307,7 +307,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@internal(intrinsic_buffer_size)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB, result : ptr<function, u32>)
@@ -336,13 +336,13 @@
}
)";
- auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
+ auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CalculateArrayLengthTest, MultipleStorageBuffers) {
- auto* src = R"(
+ auto* src = R"(
struct SB1 {
x : i32,
arr1 : array<i32>,
@@ -368,7 +368,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@internal(intrinsic_buffer_size)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB1, result : ptr<function, u32>)
@@ -412,13 +412,13 @@
}
)";
- auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
+ auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CalculateArrayLengthTest, Shadowing) {
- auto* src = R"(
+ auto* src = R"(
struct SB {
x : i32,
arr : array<i32>,
@@ -437,8 +437,8 @@
}
)";
- auto* expect =
- R"(
+ auto* expect =
+ R"(
@internal(intrinsic_buffer_size)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB, result : ptr<function, u32>)
@@ -466,13 +466,13 @@
}
)";
- auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
+ auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CalculateArrayLengthTest, OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn main() {
var len1 : u32 = arrayLength(&(sb1.arr1));
@@ -498,7 +498,7 @@
@group(0) @binding(2) var<storage, read> sb3 : array<i32>;
)";
- auto* expect = R"(
+ auto* expect = R"(
@internal(intrinsic_buffer_size)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB1, result : ptr<function, u32>)
@@ -542,9 +542,9 @@
@group(0) @binding(2) var<storage, read> sb3 : array<i32>;
)";
- auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
+ auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
} // namespace
diff --git a/src/tint/transform/canonicalize_entry_point_io.cc b/src/tint/transform/canonicalize_entry_point_io.cc
index b9aabb2..06f9f6b 100644
--- a/src/tint/transform/canonicalize_entry_point_io.cc
+++ b/src/tint/transform/canonicalize_entry_point_io.cc
@@ -38,730 +38,702 @@
// Comparison function used to reorder struct members such that all members with
// location attributes appear first (ordered by location slot), followed by
// those with builtin attributes.
-bool StructMemberComparator(const ast::StructMember* a,
- const ast::StructMember* b) {
- auto* a_loc = ast::GetAttribute<ast::LocationAttribute>(a->attributes);
- auto* b_loc = ast::GetAttribute<ast::LocationAttribute>(b->attributes);
- auto* a_blt = ast::GetAttribute<ast::BuiltinAttribute>(a->attributes);
- auto* b_blt = ast::GetAttribute<ast::BuiltinAttribute>(b->attributes);
- if (a_loc) {
- if (!b_loc) {
- // `a` has location attribute and `b` does not: `a` goes first.
- return true;
+bool StructMemberComparator(const ast::StructMember* a, const ast::StructMember* b) {
+ auto* a_loc = ast::GetAttribute<ast::LocationAttribute>(a->attributes);
+ auto* b_loc = ast::GetAttribute<ast::LocationAttribute>(b->attributes);
+ auto* a_blt = ast::GetAttribute<ast::BuiltinAttribute>(a->attributes);
+ auto* b_blt = ast::GetAttribute<ast::BuiltinAttribute>(b->attributes);
+ if (a_loc) {
+ if (!b_loc) {
+ // `a` has location attribute and `b` does not: `a` goes first.
+ return true;
+ }
+ // Both have location attributes: smallest goes first.
+ return a_loc->value < b_loc->value;
+ } else {
+ if (b_loc) {
+ // `b` has location attribute and `a` does not: `b` goes first.
+ return false;
+ }
+ // Both are builtins: order doesn't matter, just use enum value.
+ return a_blt->builtin < b_blt->builtin;
}
- // Both have location attributes: smallest goes first.
- return a_loc->value < b_loc->value;
- } else {
- if (b_loc) {
- // `b` has location attribute and `a` does not: `b` goes first.
- return false;
- }
- // Both are builtins: order doesn't matter, just use enum value.
- return a_blt->builtin < b_blt->builtin;
- }
}
// Returns true if `attr` is a shader IO attribute.
bool IsShaderIOAttribute(const ast::Attribute* attr) {
- return attr->IsAnyOf<ast::BuiltinAttribute, ast::InterpolateAttribute,
- ast::InvariantAttribute, ast::LocationAttribute>();
+ return attr->IsAnyOf<ast::BuiltinAttribute, ast::InterpolateAttribute, ast::InvariantAttribute,
+ ast::LocationAttribute>();
}
// Returns true if `attrs` contains a `sample_mask` builtin.
bool HasSampleMask(const ast::AttributeList& attrs) {
- auto* builtin = ast::GetAttribute<ast::BuiltinAttribute>(attrs);
- return builtin && builtin->builtin == ast::Builtin::kSampleMask;
+ auto* builtin = ast::GetAttribute<ast::BuiltinAttribute>(attrs);
+ return builtin && builtin->builtin == ast::Builtin::kSampleMask;
}
} // namespace
/// State holds the current transform state for a single entry point.
struct CanonicalizeEntryPointIO::State {
- /// OutputValue represents a shader result that the wrapper function produces.
- struct OutputValue {
- /// The name of the output value.
- std::string name;
- /// The type of the output value.
- const ast::Type* type;
- /// The shader IO attributes.
- ast::AttributeList attributes;
- /// The value itself.
- const ast::Expression* value;
- };
-
- /// The clone context.
- CloneContext& ctx;
- /// The transform config.
- CanonicalizeEntryPointIO::Config const cfg;
- /// The entry point function (AST).
- const ast::Function* func_ast;
- /// The entry point function (SEM).
- const sem::Function* func_sem;
-
- /// The new entry point wrapper function's parameters.
- ast::VariableList wrapper_ep_parameters;
- /// The members of the wrapper function's struct parameter.
- ast::StructMemberList wrapper_struct_param_members;
- /// The name of the wrapper function's struct parameter.
- Symbol wrapper_struct_param_name;
- /// The parameters that will be passed to the original function.
- ast::ExpressionList inner_call_parameters;
- /// The members of the wrapper function's struct return type.
- ast::StructMemberList wrapper_struct_output_members;
- /// The wrapper function output values.
- std::vector<OutputValue> wrapper_output_values;
- /// The body of the wrapper function.
- ast::StatementList wrapper_body;
- /// Input names used by the entrypoint
- std::unordered_set<std::string> input_names;
-
- /// Constructor
- /// @param context the clone context
- /// @param config the transform config
- /// @param function the entry point function
- State(CloneContext& context,
- const CanonicalizeEntryPointIO::Config& config,
- const ast::Function* function)
- : ctx(context),
- cfg(config),
- func_ast(function),
- func_sem(ctx.src->Sem().Get(function)) {}
-
- /// Clones the shader IO attributes from `src`.
- /// @param src the attributes to clone
- /// @param do_interpolate whether to clone InterpolateAttribute
- /// @return the cloned attributes
- ast::AttributeList CloneShaderIOAttributes(const ast::AttributeList& src,
- bool do_interpolate) {
- ast::AttributeList new_attributes;
- for (auto* attr : src) {
- if (IsShaderIOAttribute(attr) &&
- (do_interpolate || !attr->Is<ast::InterpolateAttribute>())) {
- new_attributes.push_back(ctx.Clone(attr));
- }
- }
- return new_attributes;
- }
-
- /// Create or return a symbol for the wrapper function's struct parameter.
- /// @returns the symbol for the struct parameter
- Symbol InputStructSymbol() {
- if (!wrapper_struct_param_name.IsValid()) {
- wrapper_struct_param_name = ctx.dst->Sym();
- }
- return wrapper_struct_param_name;
- }
-
- /// Add a shader input to the entry point.
- /// @param name the name of the shader input
- /// @param type the type of the shader input
- /// @param attributes the attributes to apply to the shader input
- /// @returns an expression which evaluates to the value of the shader input
- const ast::Expression* AddInput(std::string name,
- const sem::Type* type,
- ast::AttributeList attributes) {
- auto* ast_type = CreateASTTypeFor(ctx, type);
- if (cfg.shader_style == ShaderStyle::kSpirv ||
- cfg.shader_style == ShaderStyle::kGlsl) {
- // Vulkan requires that integer user-defined fragment inputs are
- // always decorated with `Flat`.
- // TODO(crbug.com/tint/1224): Remove this once a flat interpolation
- // attribute is required for integers.
- if (type->is_integer_scalar_or_vector() &&
- ast::HasAttribute<ast::LocationAttribute>(attributes) &&
- !ast::HasAttribute<ast::InterpolateAttribute>(attributes) &&
- func_ast->PipelineStage() == ast::PipelineStage::kFragment) {
- attributes.push_back(ctx.dst->Interpolate(
- ast::InterpolationType::kFlat, ast::InterpolationSampling::kNone));
- }
-
- // Disable validation for use of the `input` storage class.
- attributes.push_back(
- ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass));
-
- // In GLSL, if it's a builtin, override the name with the
- // corresponding gl_ builtin name
- auto* builtin = ast::GetAttribute<ast::BuiltinAttribute>(attributes);
- if (cfg.shader_style == ShaderStyle::kGlsl && builtin) {
- name = GLSLBuiltinToString(builtin->builtin, func_ast->PipelineStage(),
- ast::StorageClass::kInput);
- }
- auto symbol = ctx.dst->Symbols().New(name);
-
- // Create the global variable and use its value for the shader input.
- const ast::Expression* value = ctx.dst->Expr(symbol);
-
- if (builtin) {
- if (cfg.shader_style == ShaderStyle::kGlsl) {
- value = FromGLSLBuiltin(builtin->builtin, value, ast_type);
- } else if (builtin->builtin == ast::Builtin::kSampleMask) {
- // Vulkan requires the type of a SampleMask builtin to be an array.
- // Declare it as array<u32, 1> and then load the first element.
- ast_type = ctx.dst->ty.array(ast_type, 1);
- value = ctx.dst->IndexAccessor(value, 0);
- }
- }
- ctx.dst->Global(symbol, ast_type, ast::StorageClass::kInput,
- std::move(attributes));
- return value;
- } else if (cfg.shader_style == ShaderStyle::kMsl &&
- ast::HasAttribute<ast::BuiltinAttribute>(attributes)) {
- // If this input is a builtin and we are targeting MSL, then add it to the
- // parameter list and pass it directly to the inner function.
- Symbol symbol = input_names.emplace(name).second
- ? ctx.dst->Symbols().Register(name)
- : ctx.dst->Symbols().New(name);
- wrapper_ep_parameters.push_back(
- ctx.dst->Param(symbol, ast_type, std::move(attributes)));
- return ctx.dst->Expr(symbol);
- } else {
- // Otherwise, move it to the new structure member list.
- Symbol symbol = input_names.emplace(name).second
- ? ctx.dst->Symbols().Register(name)
- : ctx.dst->Symbols().New(name);
- wrapper_struct_param_members.push_back(
- ctx.dst->Member(symbol, ast_type, std::move(attributes)));
- return ctx.dst->MemberAccessor(InputStructSymbol(), symbol);
- }
- }
-
- /// Add a shader output to the entry point.
- /// @param name the name of the shader output
- /// @param type the type of the shader output
- /// @param attributes the attributes to apply to the shader output
- /// @param value the value of the shader output
- void AddOutput(std::string name,
- const sem::Type* type,
- ast::AttributeList attributes,
- const ast::Expression* value) {
- // Vulkan requires that integer user-defined vertex outputs are
- // always decorated with `Flat`.
- // TODO(crbug.com/tint/1224): Remove this once a flat interpolation
- // attribute is required for integers.
- if (cfg.shader_style == ShaderStyle::kSpirv &&
- type->is_integer_scalar_or_vector() &&
- ast::HasAttribute<ast::LocationAttribute>(attributes) &&
- !ast::HasAttribute<ast::InterpolateAttribute>(attributes) &&
- func_ast->PipelineStage() == ast::PipelineStage::kVertex) {
- attributes.push_back(ctx.dst->Interpolate(
- ast::InterpolationType::kFlat, ast::InterpolationSampling::kNone));
- }
-
- // In GLSL, if it's a builtin, override the name with the
- // corresponding gl_ builtin name
- if (cfg.shader_style == ShaderStyle::kGlsl) {
- if (auto* b = ast::GetAttribute<ast::BuiltinAttribute>(attributes)) {
- name = GLSLBuiltinToString(b->builtin, func_ast->PipelineStage(),
- ast::StorageClass::kOutput);
- value = ToGLSLBuiltin(b->builtin, value, type);
- }
- }
-
- OutputValue output;
- output.name = name;
- output.type = CreateASTTypeFor(ctx, type);
- output.attributes = std::move(attributes);
- output.value = value;
- wrapper_output_values.push_back(output);
- }
-
- /// Process a non-struct parameter.
- /// This creates a new object for the shader input, moving the shader IO
- /// attributes to it. It also adds an expression to the list of parameters
- /// that will be passed to the original function.
- /// @param param the original function parameter
- void ProcessNonStructParameter(const sem::Parameter* param) {
- // Remove the shader IO attributes from the inner function parameter, and
- // attach them to the new object instead.
- ast::AttributeList attributes;
- for (auto* attr : param->Declaration()->attributes) {
- if (IsShaderIOAttribute(attr)) {
- ctx.Remove(param->Declaration()->attributes, attr);
- attributes.push_back(ctx.Clone(attr));
- }
- }
-
- auto name = ctx.src->Symbols().NameFor(param->Declaration()->symbol);
- auto* input_expr = AddInput(name, param->Type(), std::move(attributes));
- inner_call_parameters.push_back(input_expr);
- }
-
- /// Process a struct parameter.
- /// This creates new objects for each struct member, moving the shader IO
- /// attributes to them. It also creates the structure that will be passed to
- /// the original function.
- /// @param param the original function parameter
- void ProcessStructParameter(const sem::Parameter* param) {
- auto* str = param->Type()->As<sem::Struct>();
-
- // Recreate struct members in the outer entry point and build an initializer
- // list to pass them through to the inner function.
- ast::ExpressionList inner_struct_values;
- for (auto* member : str->Members()) {
- if (member->Type()->Is<sem::Struct>()) {
- TINT_ICE(Transform, ctx.dst->Diagnostics()) << "nested IO struct";
- continue;
- }
-
- auto* member_ast = member->Declaration();
- auto name = ctx.src->Symbols().NameFor(member_ast->symbol);
-
- // In GLSL, do not add interpolation attributes on vertex input
- bool do_interpolate = true;
- if (cfg.shader_style == ShaderStyle::kGlsl &&
- func_ast->PipelineStage() == ast::PipelineStage::kVertex) {
- do_interpolate = false;
- }
- auto attributes =
- CloneShaderIOAttributes(member_ast->attributes, do_interpolate);
- auto* input_expr = AddInput(name, member->Type(), std::move(attributes));
- inner_struct_values.push_back(input_expr);
- }
-
- // Construct the original structure using the new shader input objects.
- inner_call_parameters.push_back(ctx.dst->Construct(
- ctx.Clone(param->Declaration()->type), inner_struct_values));
- }
-
- /// Process the entry point return type.
- /// This generates a list of output values that are returned by the original
- /// function.
- /// @param inner_ret_type the original function return type
- /// @param original_result the result object produced by the original function
- void ProcessReturnType(const sem::Type* inner_ret_type,
- Symbol original_result) {
- bool do_interpolate = true;
- // In GLSL, do not add interpolation attributes on fragment output
- if (cfg.shader_style == ShaderStyle::kGlsl &&
- func_ast->PipelineStage() == ast::PipelineStage::kFragment) {
- do_interpolate = false;
- }
- if (auto* str = inner_ret_type->As<sem::Struct>()) {
- for (auto* member : str->Members()) {
- if (member->Type()->Is<sem::Struct>()) {
- TINT_ICE(Transform, ctx.dst->Diagnostics()) << "nested IO struct";
- continue;
- }
-
- auto* member_ast = member->Declaration();
- auto name = ctx.src->Symbols().NameFor(member_ast->symbol);
- auto attributes =
- CloneShaderIOAttributes(member_ast->attributes, do_interpolate);
-
- // Extract the original structure member.
- AddOutput(name, member->Type(), std::move(attributes),
- ctx.dst->MemberAccessor(original_result, name));
- }
- } else if (!inner_ret_type->Is<sem::Void>()) {
- auto attributes = CloneShaderIOAttributes(
- func_ast->return_type_attributes, do_interpolate);
-
- // Propagate the non-struct return value as is.
- AddOutput("value", func_sem->ReturnType(), std::move(attributes),
- ctx.dst->Expr(original_result));
- }
- }
-
- /// Add a fixed sample mask to the wrapper function output.
- /// If there is already a sample mask, bitwise-and it with the fixed mask.
- /// Otherwise, create a new output value from the fixed mask.
- void AddFixedSampleMask() {
- // Check the existing output values for a sample mask builtin.
- for (auto& outval : wrapper_output_values) {
- if (HasSampleMask(outval.attributes)) {
- // Combine the authored sample mask with the fixed mask.
- outval.value = ctx.dst->And(outval.value, cfg.fixed_sample_mask);
- return;
- }
- }
-
- // No existing sample mask builtin was found, so create a new output value
- // using the fixed sample mask.
- AddOutput("fixed_sample_mask", ctx.dst->create<sem::U32>(),
- {ctx.dst->Builtin(ast::Builtin::kSampleMask)},
- ctx.dst->Expr(cfg.fixed_sample_mask));
- }
-
- /// Add a point size builtin to the wrapper function output.
- void AddVertexPointSize() {
- // Create a new output value and assign it a literal 1.0 value.
- AddOutput("vertex_point_size", ctx.dst->create<sem::F32>(),
- {ctx.dst->Builtin(ast::Builtin::kPointSize)}, ctx.dst->Expr(1.f));
- }
-
- /// Create an expression for gl_Position.[component]
- /// @param component the component of gl_Position to access
- /// @returns the new expression
- const ast::Expression* GLPosition(const char* component) {
- Symbol pos = ctx.dst->Symbols().Register("gl_Position");
- Symbol c = ctx.dst->Symbols().Register(component);
- return ctx.dst->MemberAccessor(ctx.dst->Expr(pos), ctx.dst->Expr(c));
- }
-
- /// Create the wrapper function's struct parameter and type objects.
- void CreateInputStruct() {
- // Sort the struct members to satisfy HLSL interfacing matching rules.
- std::sort(wrapper_struct_param_members.begin(),
- wrapper_struct_param_members.end(), StructMemberComparator);
-
- // Create the new struct type.
- auto struct_name = ctx.dst->Sym();
- auto* in_struct = ctx.dst->create<ast::Struct>(
- struct_name, wrapper_struct_param_members, ast::AttributeList{});
- ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func_ast, in_struct);
-
- // Create a new function parameter using this struct type.
- auto* param =
- ctx.dst->Param(InputStructSymbol(), ctx.dst->ty.type_name(struct_name));
- wrapper_ep_parameters.push_back(param);
- }
-
- /// Create and return the wrapper function's struct result object.
- /// @returns the struct type
- ast::Struct* CreateOutputStruct() {
- ast::StatementList assignments;
-
- auto wrapper_result = ctx.dst->Symbols().New("wrapper_result");
-
- // Create the struct members and their corresponding assignment statements.
- std::unordered_set<std::string> member_names;
- for (auto& outval : wrapper_output_values) {
- // Use the original output name, unless that is already taken.
- Symbol name;
- if (member_names.count(outval.name)) {
- name = ctx.dst->Symbols().New(outval.name);
- } else {
- name = ctx.dst->Symbols().Register(outval.name);
- }
- member_names.insert(ctx.dst->Symbols().NameFor(name));
-
- wrapper_struct_output_members.push_back(
- ctx.dst->Member(name, outval.type, std::move(outval.attributes)));
- assignments.push_back(ctx.dst->Assign(
- ctx.dst->MemberAccessor(wrapper_result, name), outval.value));
- }
-
- // Sort the struct members to satisfy HLSL interfacing matching rules.
- std::sort(wrapper_struct_output_members.begin(),
- wrapper_struct_output_members.end(), StructMemberComparator);
-
- // Create the new struct type.
- auto* out_struct = ctx.dst->create<ast::Struct>(
- ctx.dst->Sym(), wrapper_struct_output_members, ast::AttributeList{});
- ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func_ast, out_struct);
-
- // Create the output struct object, assign its members, and return it.
- auto* result_object =
- ctx.dst->Var(wrapper_result, ctx.dst->ty.type_name(out_struct->name));
- wrapper_body.push_back(ctx.dst->Decl(result_object));
- wrapper_body.insert(wrapper_body.end(), assignments.begin(),
- assignments.end());
- wrapper_body.push_back(ctx.dst->Return(wrapper_result));
-
- return out_struct;
- }
-
- /// Create and assign the wrapper function's output variables.
- void CreateGlobalOutputVariables() {
- for (auto& outval : wrapper_output_values) {
- // Disable validation for use of the `output` storage class.
- ast::AttributeList attributes = std::move(outval.attributes);
- attributes.push_back(
- ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass));
-
- // Create the global variable and assign it the output value.
- auto name = ctx.dst->Symbols().New(outval.name);
- auto* type = outval.type;
- const ast::Expression* lhs = ctx.dst->Expr(name);
- if (HasSampleMask(attributes)) {
- // Vulkan requires the type of a SampleMask builtin to be an array.
- // Declare it as array<u32, 1> and then store to the first element.
- type = ctx.dst->ty.array(type, 1);
- lhs = ctx.dst->IndexAccessor(lhs, 0);
- }
- ctx.dst->Global(name, type, ast::StorageClass::kOutput,
- std::move(attributes));
- wrapper_body.push_back(ctx.dst->Assign(lhs, outval.value));
- }
- }
-
- // Recreate the original function without entry point attributes and call it.
- /// @returns the inner function call expression
- const ast::CallExpression* CallInnerFunction() {
- Symbol inner_name;
- if (cfg.shader_style == ShaderStyle::kGlsl) {
- // In GLSL, clone the original entry point name, as the wrapper will be
- // called "main".
- inner_name = ctx.Clone(func_ast->symbol);
- } else {
- // Add a suffix to the function name, as the wrapper function will take
- // the original entry point name.
- auto ep_name = ctx.src->Symbols().NameFor(func_ast->symbol);
- inner_name = ctx.dst->Symbols().New(ep_name + "_inner");
- }
-
- // Clone everything, dropping the function and return type attributes.
- // The parameter attributes will have already been stripped during
- // processing.
- auto* inner_function = ctx.dst->create<ast::Function>(
- inner_name, ctx.Clone(func_ast->params),
- ctx.Clone(func_ast->return_type), ctx.Clone(func_ast->body),
- ast::AttributeList{}, ast::AttributeList{});
- ctx.Replace(func_ast, inner_function);
-
- // Call the function.
- return ctx.dst->Call(inner_function->symbol, inner_call_parameters);
- }
-
- /// Process the entry point function.
- void Process() {
- bool needs_fixed_sample_mask = false;
- bool needs_vertex_point_size = false;
- if (func_ast->PipelineStage() == ast::PipelineStage::kFragment &&
- cfg.fixed_sample_mask != 0xFFFFFFFF) {
- needs_fixed_sample_mask = true;
- }
- if (func_ast->PipelineStage() == ast::PipelineStage::kVertex &&
- cfg.emit_vertex_point_size) {
- needs_vertex_point_size = true;
- }
-
- // Exit early if there is no shader IO to handle.
- if (func_sem->Parameters().size() == 0 &&
- func_sem->ReturnType()->Is<sem::Void>() && !needs_fixed_sample_mask &&
- !needs_vertex_point_size && cfg.shader_style != ShaderStyle::kGlsl) {
- return;
- }
-
- // Process the entry point parameters, collecting those that need to be
- // aggregated into a single structure.
- if (!func_sem->Parameters().empty()) {
- for (auto* param : func_sem->Parameters()) {
- if (param->Type()->Is<sem::Struct>()) {
- ProcessStructParameter(param);
- } else {
- ProcessNonStructParameter(param);
- }
- }
-
- // Create a structure parameter for the outer entry point if necessary.
- if (!wrapper_struct_param_members.empty()) {
- CreateInputStruct();
- }
- }
-
- // Recreate the original function and call it.
- auto* call_inner = CallInnerFunction();
-
- // Process the return type, and start building the wrapper function body.
- std::function<const ast::Type*()> wrapper_ret_type = [&] {
- return ctx.dst->ty.void_();
+ /// OutputValue represents a shader result that the wrapper function produces.
+ struct OutputValue {
+ /// The name of the output value.
+ std::string name;
+ /// The type of the output value.
+ const ast::Type* type;
+ /// The shader IO attributes.
+ ast::AttributeList attributes;
+ /// The value itself.
+ const ast::Expression* value;
};
- if (func_sem->ReturnType()->Is<sem::Void>()) {
- // The function call is just a statement with no result.
- wrapper_body.push_back(ctx.dst->CallStmt(call_inner));
- } else {
- // Capture the result of calling the original function.
- auto* inner_result = ctx.dst->Let(ctx.dst->Symbols().New("inner_result"),
- nullptr, call_inner);
- wrapper_body.push_back(ctx.dst->Decl(inner_result));
- // Process the original return type to determine the outputs that the
- // outer function needs to produce.
- ProcessReturnType(func_sem->ReturnType(), inner_result->symbol);
- }
+ /// The clone context.
+ CloneContext& ctx;
+ /// The transform config.
+ CanonicalizeEntryPointIO::Config const cfg;
+ /// The entry point function (AST).
+ const ast::Function* func_ast;
+ /// The entry point function (SEM).
+ const sem::Function* func_sem;
- // Add a fixed sample mask, if necessary.
- if (needs_fixed_sample_mask) {
- AddFixedSampleMask();
- }
+ /// The new entry point wrapper function's parameters.
+ ast::VariableList wrapper_ep_parameters;
+ /// The members of the wrapper function's struct parameter.
+ ast::StructMemberList wrapper_struct_param_members;
+ /// The name of the wrapper function's struct parameter.
+ Symbol wrapper_struct_param_name;
+ /// The parameters that will be passed to the original function.
+ ast::ExpressionList inner_call_parameters;
+ /// The members of the wrapper function's struct return type.
+ ast::StructMemberList wrapper_struct_output_members;
+ /// The wrapper function output values.
+ std::vector<OutputValue> wrapper_output_values;
+ /// The body of the wrapper function.
+ ast::StatementList wrapper_body;
+ /// Input names used by the entrypoint
+ std::unordered_set<std::string> input_names;
- // Add the pointsize builtin, if necessary.
- if (needs_vertex_point_size) {
- AddVertexPointSize();
- }
+ /// Constructor
+ /// @param context the clone context
+ /// @param config the transform config
+ /// @param function the entry point function
+ State(CloneContext& context,
+ const CanonicalizeEntryPointIO::Config& config,
+ const ast::Function* function)
+ : ctx(context), cfg(config), func_ast(function), func_sem(ctx.src->Sem().Get(function)) {}
- // Produce the entry point outputs, if necessary.
- if (!wrapper_output_values.empty()) {
- if (cfg.shader_style == ShaderStyle::kSpirv ||
- cfg.shader_style == ShaderStyle::kGlsl) {
- CreateGlobalOutputVariables();
- } else {
- auto* output_struct = CreateOutputStruct();
- wrapper_ret_type = [&, output_struct] {
- return ctx.dst->ty.type_name(output_struct->name);
- };
- }
- }
-
- if (cfg.shader_style == ShaderStyle::kGlsl &&
- func_ast->PipelineStage() == ast::PipelineStage::kVertex) {
- auto* pos_y = GLPosition("y");
- auto* negate_pos_y = ctx.dst->create<ast::UnaryOpExpression>(
- ast::UnaryOp::kNegation, GLPosition("y"));
- wrapper_body.push_back(ctx.dst->Assign(pos_y, negate_pos_y));
-
- auto* two_z = ctx.dst->Mul(ctx.dst->Expr(2.0f), GLPosition("z"));
- auto* fixed_z = ctx.dst->Sub(two_z, GLPosition("w"));
- wrapper_body.push_back(ctx.dst->Assign(GLPosition("z"), fixed_z));
- }
-
- // Create the wrapper entry point function.
- // For GLSL, use "main", otherwise take the name of the original
- // entry point function.
- Symbol name;
- if (cfg.shader_style == ShaderStyle::kGlsl) {
- name = ctx.dst->Symbols().New("main");
- } else {
- name = ctx.Clone(func_ast->symbol);
- }
-
- auto* wrapper_func = ctx.dst->create<ast::Function>(
- name, wrapper_ep_parameters, wrapper_ret_type(),
- ctx.dst->Block(wrapper_body), ctx.Clone(func_ast->attributes),
- ast::AttributeList{});
- ctx.InsertAfter(ctx.src->AST().GlobalDeclarations(), func_ast,
- wrapper_func);
- }
-
- /// Retrieve the gl_ string corresponding to a builtin.
- /// @param builtin the builtin
- /// @param stage the current pipeline stage
- /// @param storage_class the storage class (input or output)
- /// @returns the gl_ string corresponding to that builtin
- const char* GLSLBuiltinToString(ast::Builtin builtin,
- ast::PipelineStage stage,
- ast::StorageClass storage_class) {
- switch (builtin) {
- case ast::Builtin::kPosition:
- switch (stage) {
- case ast::PipelineStage::kVertex:
- return "gl_Position";
- case ast::PipelineStage::kFragment:
- return "gl_FragCoord";
- default:
- return "";
+ /// Clones the shader IO attributes from `src`.
+ /// @param src the attributes to clone
+ /// @param do_interpolate whether to clone InterpolateAttribute
+ /// @return the cloned attributes
+ ast::AttributeList CloneShaderIOAttributes(const ast::AttributeList& src, bool do_interpolate) {
+ ast::AttributeList new_attributes;
+ for (auto* attr : src) {
+ if (IsShaderIOAttribute(attr) &&
+ (do_interpolate || !attr->Is<ast::InterpolateAttribute>())) {
+ new_attributes.push_back(ctx.Clone(attr));
+ }
}
- case ast::Builtin::kVertexIndex:
- return "gl_VertexID";
- case ast::Builtin::kInstanceIndex:
- return "gl_InstanceID";
- case ast::Builtin::kFrontFacing:
- return "gl_FrontFacing";
- case ast::Builtin::kFragDepth:
- return "gl_FragDepth";
- case ast::Builtin::kLocalInvocationId:
- return "gl_LocalInvocationID";
- case ast::Builtin::kLocalInvocationIndex:
- return "gl_LocalInvocationIndex";
- case ast::Builtin::kGlobalInvocationId:
- return "gl_GlobalInvocationID";
- case ast::Builtin::kNumWorkgroups:
- return "gl_NumWorkGroups";
- case ast::Builtin::kWorkgroupId:
- return "gl_WorkGroupID";
- case ast::Builtin::kSampleIndex:
- return "gl_SampleID";
- case ast::Builtin::kSampleMask:
- if (storage_class == ast::StorageClass::kInput) {
- return "gl_SampleMaskIn";
+ return new_attributes;
+ }
+
+ /// Create or return a symbol for the wrapper function's struct parameter.
+ /// @returns the symbol for the struct parameter
+ Symbol InputStructSymbol() {
+ if (!wrapper_struct_param_name.IsValid()) {
+ wrapper_struct_param_name = ctx.dst->Sym();
+ }
+ return wrapper_struct_param_name;
+ }
+
+ /// Add a shader input to the entry point.
+ /// @param name the name of the shader input
+ /// @param type the type of the shader input
+ /// @param attributes the attributes to apply to the shader input
+ /// @returns an expression which evaluates to the value of the shader input
+ const ast::Expression* AddInput(std::string name,
+ const sem::Type* type,
+ ast::AttributeList attributes) {
+ auto* ast_type = CreateASTTypeFor(ctx, type);
+ if (cfg.shader_style == ShaderStyle::kSpirv || cfg.shader_style == ShaderStyle::kGlsl) {
+ // Vulkan requires that integer user-defined fragment inputs are
+ // always decorated with `Flat`.
+ // TODO(crbug.com/tint/1224): Remove this once a flat interpolation
+ // attribute is required for integers.
+ if (type->is_integer_scalar_or_vector() &&
+ ast::HasAttribute<ast::LocationAttribute>(attributes) &&
+ !ast::HasAttribute<ast::InterpolateAttribute>(attributes) &&
+ func_ast->PipelineStage() == ast::PipelineStage::kFragment) {
+ attributes.push_back(ctx.dst->Interpolate(ast::InterpolationType::kFlat,
+ ast::InterpolationSampling::kNone));
+ }
+
+ // Disable validation for use of the `input` storage class.
+ attributes.push_back(ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass));
+
+ // In GLSL, if it's a builtin, override the name with the
+ // corresponding gl_ builtin name
+ auto* builtin = ast::GetAttribute<ast::BuiltinAttribute>(attributes);
+ if (cfg.shader_style == ShaderStyle::kGlsl && builtin) {
+ name = GLSLBuiltinToString(builtin->builtin, func_ast->PipelineStage(),
+ ast::StorageClass::kInput);
+ }
+ auto symbol = ctx.dst->Symbols().New(name);
+
+ // Create the global variable and use its value for the shader input.
+ const ast::Expression* value = ctx.dst->Expr(symbol);
+
+ if (builtin) {
+ if (cfg.shader_style == ShaderStyle::kGlsl) {
+ value = FromGLSLBuiltin(builtin->builtin, value, ast_type);
+ } else if (builtin->builtin == ast::Builtin::kSampleMask) {
+ // Vulkan requires the type of a SampleMask builtin to be an array.
+ // Declare it as array<u32, 1> and then load the first element.
+ ast_type = ctx.dst->ty.array(ast_type, 1);
+ value = ctx.dst->IndexAccessor(value, 0);
+ }
+ }
+ ctx.dst->Global(symbol, ast_type, ast::StorageClass::kInput, std::move(attributes));
+ return value;
+ } else if (cfg.shader_style == ShaderStyle::kMsl &&
+ ast::HasAttribute<ast::BuiltinAttribute>(attributes)) {
+ // If this input is a builtin and we are targeting MSL, then add it to the
+ // parameter list and pass it directly to the inner function.
+ Symbol symbol = input_names.emplace(name).second ? ctx.dst->Symbols().Register(name)
+ : ctx.dst->Symbols().New(name);
+ wrapper_ep_parameters.push_back(
+ ctx.dst->Param(symbol, ast_type, std::move(attributes)));
+ return ctx.dst->Expr(symbol);
} else {
- return "gl_SampleMask";
+ // Otherwise, move it to the new structure member list.
+ Symbol symbol = input_names.emplace(name).second ? ctx.dst->Symbols().Register(name)
+ : ctx.dst->Symbols().New(name);
+ wrapper_struct_param_members.push_back(
+ ctx.dst->Member(symbol, ast_type, std::move(attributes)));
+ return ctx.dst->MemberAccessor(InputStructSymbol(), symbol);
}
- default:
- return "";
}
- }
- /// Convert a given GLSL builtin value to the corresponding WGSL value.
- /// @param builtin the builtin variable
- /// @param value the value to convert
- /// @param ast_type (inout) the incoming WGSL and outgoing GLSL types
- /// @returns an expression representing the GLSL builtin converted to what
- /// WGSL expects
- const ast::Expression* FromGLSLBuiltin(ast::Builtin builtin,
+ /// Add a shader output to the entry point.
+ /// @param name the name of the shader output
+ /// @param type the type of the shader output
+ /// @param attributes the attributes to apply to the shader output
+ /// @param value the value of the shader output
+ void AddOutput(std::string name,
+ const sem::Type* type,
+ ast::AttributeList attributes,
+ const ast::Expression* value) {
+ // Vulkan requires that integer user-defined vertex outputs are
+ // always decorated with `Flat`.
+ // TODO(crbug.com/tint/1224): Remove this once a flat interpolation
+ // attribute is required for integers.
+ if (cfg.shader_style == ShaderStyle::kSpirv && type->is_integer_scalar_or_vector() &&
+ ast::HasAttribute<ast::LocationAttribute>(attributes) &&
+ !ast::HasAttribute<ast::InterpolateAttribute>(attributes) &&
+ func_ast->PipelineStage() == ast::PipelineStage::kVertex) {
+ attributes.push_back(ctx.dst->Interpolate(ast::InterpolationType::kFlat,
+ ast::InterpolationSampling::kNone));
+ }
+
+ // In GLSL, if it's a builtin, override the name with the
+ // corresponding gl_ builtin name
+ if (cfg.shader_style == ShaderStyle::kGlsl) {
+ if (auto* b = ast::GetAttribute<ast::BuiltinAttribute>(attributes)) {
+ name = GLSLBuiltinToString(b->builtin, func_ast->PipelineStage(),
+ ast::StorageClass::kOutput);
+ value = ToGLSLBuiltin(b->builtin, value, type);
+ }
+ }
+
+ OutputValue output;
+ output.name = name;
+ output.type = CreateASTTypeFor(ctx, type);
+ output.attributes = std::move(attributes);
+ output.value = value;
+ wrapper_output_values.push_back(output);
+ }
+
+ /// Process a non-struct parameter.
+ /// This creates a new object for the shader input, moving the shader IO
+ /// attributes to it. It also adds an expression to the list of parameters
+ /// that will be passed to the original function.
+ /// @param param the original function parameter
+ void ProcessNonStructParameter(const sem::Parameter* param) {
+ // Remove the shader IO attributes from the inner function parameter, and
+ // attach them to the new object instead.
+ ast::AttributeList attributes;
+ for (auto* attr : param->Declaration()->attributes) {
+ if (IsShaderIOAttribute(attr)) {
+ ctx.Remove(param->Declaration()->attributes, attr);
+ attributes.push_back(ctx.Clone(attr));
+ }
+ }
+
+ auto name = ctx.src->Symbols().NameFor(param->Declaration()->symbol);
+ auto* input_expr = AddInput(name, param->Type(), std::move(attributes));
+ inner_call_parameters.push_back(input_expr);
+ }
+
+ /// Process a struct parameter.
+ /// Each member is passed to AddInput() with the attributes produced by
+ /// CloneShaderIOAttributes(), and the resulting expressions are used to construct a value of
+ /// the original struct type, which is appended to `inner_call_parameters`. A nested struct member raises an ICE.
+ /// @param param the original function parameter
+ void ProcessStructParameter(const sem::Parameter* param) {
+ auto* str = param->Type()->As<sem::Struct>();
+
+ // Create a shader input for each struct member and gather the resulting
+ // expressions, in member order, to pass them through to the inner function.
+ ast::ExpressionList inner_struct_values;
+ for (auto* member : str->Members()) {
+ if (member->Type()->Is<sem::Struct>()) {
+ TINT_ICE(Transform, ctx.dst->Diagnostics()) << "nested IO struct";
+ continue;
+ }
+
+ auto* member_ast = member->Declaration();
+ auto name = ctx.src->Symbols().NameFor(member_ast->symbol);
+
+ // In GLSL, do not add interpolation attributes on vertex inputs.
+ bool do_interpolate = true;
+ if (cfg.shader_style == ShaderStyle::kGlsl &&
+ func_ast->PipelineStage() == ast::PipelineStage::kVertex) {
+ do_interpolate = false;
+ }
+ auto attributes = CloneShaderIOAttributes(member_ast->attributes, do_interpolate);
+ auto* input_expr = AddInput(name, member->Type(), std::move(attributes));
+ inner_struct_values.push_back(input_expr);
+ }
+
+ // Construct the original structure using the new shader input expressions.
+ inner_call_parameters.push_back(
+ ctx.dst->Construct(ctx.Clone(param->Declaration()->type), inner_struct_values));
+ }
+
+ /// Process the entry point return type.
+ /// Registers one output via AddOutput() per struct member, or a single output named
+ /// "value" for a non-struct, non-void return type. A nested struct member raises an ICE.
+ /// @param inner_ret_type the original function return type
+ /// @param original_result the result object produced by the original function
+ void ProcessReturnType(const sem::Type* inner_ret_type, Symbol original_result) {
+ bool do_interpolate = true;
+ // In GLSL, do not add interpolation attributes on fragment outputs.
+ if (cfg.shader_style == ShaderStyle::kGlsl &&
+ func_ast->PipelineStage() == ast::PipelineStage::kFragment) {
+ do_interpolate = false;
+ }
+ if (auto* str = inner_ret_type->As<sem::Struct>()) {
+ for (auto* member : str->Members()) {
+ if (member->Type()->Is<sem::Struct>()) {
+ TINT_ICE(Transform, ctx.dst->Diagnostics()) << "nested IO struct";
+ continue;
+ }
+
+ auto* member_ast = member->Declaration();
+ auto name = ctx.src->Symbols().NameFor(member_ast->symbol);
+ auto attributes = CloneShaderIOAttributes(member_ast->attributes, do_interpolate);
+
+ // The output value is the matching member read from the original result.
+ AddOutput(name, member->Type(), std::move(attributes),
+ ctx.dst->MemberAccessor(original_result, name));
+ }
+ } else if (!inner_ret_type->Is<sem::Void>()) {
+ auto attributes =
+ CloneShaderIOAttributes(func_ast->return_type_attributes, do_interpolate);
+
+ // Propagate the non-struct return value as is.
+ AddOutput("value", func_sem->ReturnType(), std::move(attributes),
+ ctx.dst->Expr(original_result));
+ }
+ }
+
+ /// Add a fixed sample mask to the wrapper function output.
+ /// If an output already has a sample mask builtin, bitwise-and its value with the fixed
+ /// mask. Otherwise, create a new "fixed_sample_mask" u32 output from the fixed mask.
+ void AddFixedSampleMask() {
+ // Check the existing output values for a sample mask builtin.
+ for (auto& outval : wrapper_output_values) {
+ if (HasSampleMask(outval.attributes)) {
+ // Combine the authored sample mask with the fixed mask, then stop searching.
+ outval.value = ctx.dst->And(outval.value, cfg.fixed_sample_mask);
+ return;
+ }
+ }
+
+ // No existing sample mask builtin was found, so create a new output value
+ // using the fixed sample mask.
+ AddOutput("fixed_sample_mask", ctx.dst->create<sem::U32>(),
+ {ctx.dst->Builtin(ast::Builtin::kSampleMask)},
+ ctx.dst->Expr(cfg.fixed_sample_mask));
+ }
+
+ /// Add an f32 point size builtin output, "vertex_point_size", to the wrapper function.
+ void AddVertexPointSize() {
+ // Create a new output value and assign it a literal 1.0 value.
+ AddOutput("vertex_point_size", ctx.dst->create<sem::F32>(),
+ {ctx.dst->Builtin(ast::Builtin::kPointSize)}, ctx.dst->Expr(1.f));
+ }
+
+ /// Create a member accessor expression for gl_Position.[component].
+ /// @param component the name of the gl_Position component to access (e.g. "y")
+ /// @returns the new expression, built from symbols registered with `ctx.dst`
+ const ast::Expression* GLPosition(const char* component) {
+ Symbol pos = ctx.dst->Symbols().Register("gl_Position");
+ Symbol c = ctx.dst->Symbols().Register(component);
+ return ctx.dst->MemberAccessor(ctx.dst->Expr(pos), ctx.dst->Expr(c));
+ }
+
+ /// Create the wrapper function's input struct type and the parameter that uses it.
+ void CreateInputStruct() {
+ // Sort the struct members to satisfy HLSL interface matching rules.
+ std::sort(wrapper_struct_param_members.begin(), wrapper_struct_param_members.end(),
+ StructMemberComparator);
+
+ // Create the new struct type and insert it before the original function.
+ auto struct_name = ctx.dst->Sym();
+ auto* in_struct = ctx.dst->create<ast::Struct>(struct_name, wrapper_struct_param_members,
+ ast::AttributeList{});
+ ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func_ast, in_struct);
+
+ // Create a new function parameter using this struct type.
+ auto* param = ctx.dst->Param(InputStructSymbol(), ctx.dst->ty.type_name(struct_name));
+ wrapper_ep_parameters.push_back(param);
+ }
+
+ /// Create the wrapper's output struct type and emit the statements that fill and return it.
+ /// @returns the new output struct type
+ ast::Struct* CreateOutputStruct() {
+ ast::StatementList assignments;
+
+ auto wrapper_result = ctx.dst->Symbols().New("wrapper_result");
+
+ // Create the struct members and their corresponding assignment statements.
+ std::unordered_set<std::string> member_names;
+ for (auto& outval : wrapper_output_values) {
+ // Use the original output name, unless this struct already has a member with that name.
+ Symbol name;
+ if (member_names.count(outval.name)) {
+ name = ctx.dst->Symbols().New(outval.name);
+ } else {
+ name = ctx.dst->Symbols().Register(outval.name);
+ }
+ member_names.insert(ctx.dst->Symbols().NameFor(name));
+
+ wrapper_struct_output_members.push_back(
+ ctx.dst->Member(name, outval.type, std::move(outval.attributes)));
+ assignments.push_back(
+ ctx.dst->Assign(ctx.dst->MemberAccessor(wrapper_result, name), outval.value));
+ }
+
+ // Sort the struct members to satisfy HLSL interface matching rules.
+ std::sort(wrapper_struct_output_members.begin(), wrapper_struct_output_members.end(),
+ StructMemberComparator);
+
+ // Create the new struct type and insert it before the original function.
+ auto* out_struct = ctx.dst->create<ast::Struct>(
+ ctx.dst->Sym(), wrapper_struct_output_members, ast::AttributeList{});
+ ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func_ast, out_struct);
+
+ // Declare the output struct object, assign its members, and return it.
+ auto* result_object = ctx.dst->Var(wrapper_result, ctx.dst->ty.type_name(out_struct->name));
+ wrapper_body.push_back(ctx.dst->Decl(result_object));
+ wrapper_body.insert(wrapper_body.end(), assignments.begin(), assignments.end());
+ wrapper_body.push_back(ctx.dst->Return(wrapper_result));
+
+ return out_struct;
+ }
+
+ /// Create a global `output` storage class variable for each output value and assign to it.
+ void CreateGlobalOutputVariables() {
+ for (auto& outval : wrapper_output_values) {
+ // Disable validation for use of the `output` storage class.
+ ast::AttributeList attributes = std::move(outval.attributes);
+ attributes.push_back(ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass));
+
+ // Create the global variable and assign it the output value in the wrapper body.
+ auto name = ctx.dst->Symbols().New(outval.name);
+ auto* type = outval.type;
+ const ast::Expression* lhs = ctx.dst->Expr(name);
+ if (HasSampleMask(attributes)) {
+ // Vulkan requires the type of a SampleMask builtin to be an array.
+ // Declare it as a one-element array of the output type and store to the first element.
+ type = ctx.dst->ty.array(type, 1);
+ lhs = ctx.dst->IndexAccessor(lhs, 0);
+ }
+ ctx.dst->Global(name, type, ast::StorageClass::kOutput, std::move(attributes));
+ wrapper_body.push_back(ctx.dst->Assign(lhs, outval.value));
+ }
+ }
+
+ /// Recreate the original function without entry point attributes and call it.
+ /// @returns the inner function call expression
+ const ast::CallExpression* CallInnerFunction() {
+ Symbol inner_name;
+ if (cfg.shader_style == ShaderStyle::kGlsl) {
+ // In GLSL, clone the original entry point name, as the wrapper will be
+ // called "main".
+ inner_name = ctx.Clone(func_ast->symbol);
+ } else {
+ // Add an "_inner" suffix to the function name, as the wrapper function will take
+ // the original entry point name.
+ auto ep_name = ctx.src->Symbols().NameFor(func_ast->symbol);
+ inner_name = ctx.dst->Symbols().New(ep_name + "_inner");
+ }
+
+ // Clone everything, dropping the function and return type attributes.
+ // The parameter attributes will have already been stripped during
+ // processing. The new function replaces the original one.
+ auto* inner_function = ctx.dst->create<ast::Function>(
+ inner_name, ctx.Clone(func_ast->params), ctx.Clone(func_ast->return_type),
+ ctx.Clone(func_ast->body), ast::AttributeList{}, ast::AttributeList{});
+ ctx.Replace(func_ast, inner_function);
+
+ // Call the function, passing the arguments collected in `inner_call_parameters`.
+ return ctx.dst->Call(inner_function->symbol, inner_call_parameters);
+ }
+
+ /// Process the entry point function: build the inner function and the wrapper entry point.
+ void Process() {
+ bool needs_fixed_sample_mask = false;
+ bool needs_vertex_point_size = false;
+ if (func_ast->PipelineStage() == ast::PipelineStage::kFragment &&
+ cfg.fixed_sample_mask != 0xFFFFFFFF) {
+ needs_fixed_sample_mask = true;
+ }
+ if (func_ast->PipelineStage() == ast::PipelineStage::kVertex &&
+ cfg.emit_vertex_point_size) {
+ needs_vertex_point_size = true;
+ }
+
+ // Exit early if there is no shader IO to handle. GLSL entry points are always wrapped.
+ if (func_sem->Parameters().size() == 0 && func_sem->ReturnType()->Is<sem::Void>() &&
+ !needs_fixed_sample_mask && !needs_vertex_point_size &&
+ cfg.shader_style != ShaderStyle::kGlsl) {
+ return;
+ }
+
+ // Process the entry point parameters, collecting those that need to be
+ // aggregated into a single structure.
+ if (!func_sem->Parameters().empty()) {
+ for (auto* param : func_sem->Parameters()) {
+ if (param->Type()->Is<sem::Struct>()) {
+ ProcessStructParameter(param);
+ } else {
+ ProcessNonStructParameter(param);
+ }
+ }
+
+ // Create a structure parameter for the outer entry point if necessary.
+ if (!wrapper_struct_param_members.empty()) {
+ CreateInputStruct();
+ }
+ }
+
+ // Recreate the original function and call it.
+ auto* call_inner = CallInnerFunction();
+
+ // Process the return type, and start building the wrapper function body (void return by default).
+ std::function<const ast::Type*()> wrapper_ret_type = [&] { return ctx.dst->ty.void_(); };
+ if (func_sem->ReturnType()->Is<sem::Void>()) {
+ // The function call is just a statement with no result.
+ wrapper_body.push_back(ctx.dst->CallStmt(call_inner));
+ } else {
+ // Capture the result of calling the original function.
+ auto* inner_result =
+ ctx.dst->Let(ctx.dst->Symbols().New("inner_result"), nullptr, call_inner);
+ wrapper_body.push_back(ctx.dst->Decl(inner_result));
+
+ // Process the original return type to determine the outputs that the
+ // outer function needs to produce.
+ ProcessReturnType(func_sem->ReturnType(), inner_result->symbol);
+ }
+
+ // Add a fixed sample mask, if necessary.
+ if (needs_fixed_sample_mask) {
+ AddFixedSampleMask();
+ }
+
+ // Add the pointsize builtin, if necessary.
+ if (needs_vertex_point_size) {
+ AddVertexPointSize();
+ }
+
+ // Produce the entry point outputs, if necessary: globals for SPIR-V and GLSL, else a returned struct.
+ if (!wrapper_output_values.empty()) {
+ if (cfg.shader_style == ShaderStyle::kSpirv || cfg.shader_style == ShaderStyle::kGlsl) {
+ CreateGlobalOutputVariables();
+ } else {
+ auto* output_struct = CreateOutputStruct();
+ wrapper_ret_type = [&, output_struct] {
+ return ctx.dst->ty.type_name(output_struct->name);
+ };
+ }
+ }
+
+ if (cfg.shader_style == ShaderStyle::kGlsl &&
+ func_ast->PipelineStage() == ast::PipelineStage::kVertex) {
+ auto* pos_y = GLPosition("y");
+ auto* negate_pos_y =
+ ctx.dst->create<ast::UnaryOpExpression>(ast::UnaryOp::kNegation, GLPosition("y"));
+ wrapper_body.push_back(ctx.dst->Assign(pos_y, negate_pos_y));  // gl_Position.y = -gl_Position.y
+
+ auto* two_z = ctx.dst->Mul(ctx.dst->Expr(2.0f), GLPosition("z"));
+ auto* fixed_z = ctx.dst->Sub(two_z, GLPosition("w"));
+ wrapper_body.push_back(ctx.dst->Assign(GLPosition("z"), fixed_z));  // gl_Position.z = 2.0 * gl_Position.z - gl_Position.w
+ }
+
+ // Create the wrapper entry point function.
+ // For GLSL, use "main", otherwise take the name of the original
+ // entry point function.
+ Symbol name;
+ if (cfg.shader_style == ShaderStyle::kGlsl) {
+ name = ctx.dst->Symbols().New("main");
+ } else {
+ name = ctx.Clone(func_ast->symbol);
+ }
+
+ auto* wrapper_func = ctx.dst->create<ast::Function>(
+ name, wrapper_ep_parameters, wrapper_ret_type(), ctx.dst->Block(wrapper_body),
+ ctx.Clone(func_ast->attributes), ast::AttributeList{});
+ ctx.InsertAfter(ctx.src->AST().GlobalDeclarations(), func_ast, wrapper_func);
+ }
+
+ /// Retrieve the gl_ string corresponding to a builtin.
+ /// @param builtin the builtin
+ /// @param stage the current pipeline stage; only consulted for the position builtin
+ /// @param storage_class the storage class; only consulted for the sample mask builtin
+ /// @returns the gl_ string for that builtin, or an empty string if there is no mapping
+ const char* GLSLBuiltinToString(ast::Builtin builtin,
+ ast::PipelineStage stage,
+ ast::StorageClass storage_class) {
+ switch (builtin) {
+ case ast::Builtin::kPosition:
+ switch (stage) {
+ case ast::PipelineStage::kVertex:
+ return "gl_Position";
+ case ast::PipelineStage::kFragment:
+ return "gl_FragCoord";
+ default:
+ return "";
+ }
+ case ast::Builtin::kVertexIndex:
+ return "gl_VertexID";
+ case ast::Builtin::kInstanceIndex:
+ return "gl_InstanceID";
+ case ast::Builtin::kFrontFacing:
+ return "gl_FrontFacing";
+ case ast::Builtin::kFragDepth:
+ return "gl_FragDepth";
+ case ast::Builtin::kLocalInvocationId:
+ return "gl_LocalInvocationID";
+ case ast::Builtin::kLocalInvocationIndex:
+ return "gl_LocalInvocationIndex";
+ case ast::Builtin::kGlobalInvocationId:
+ return "gl_GlobalInvocationID";
+ case ast::Builtin::kNumWorkgroups:
+ return "gl_NumWorkGroups";
+ case ast::Builtin::kWorkgroupId:
+ return "gl_WorkGroupID";
+ case ast::Builtin::kSampleIndex:
+ return "gl_SampleID";
+ case ast::Builtin::kSampleMask:
+ if (storage_class == ast::StorageClass::kInput) {
+ return "gl_SampleMaskIn";
+ } else {
+ return "gl_SampleMask";
+ }
+ default:
+ return "";
+ }
+ }
+
+ /// Convert a given GLSL builtin value to the corresponding WGSL value.
+ /// @param builtin the builtin variable
+ /// @param value the value to convert
+ /// @param ast_type (inout) the incoming WGSL type; set to the GLSL type when a conversion is made
+ /// @returns an expression representing the GLSL builtin converted to what
+ /// WGSL expects, or `value` unchanged for builtins that need no conversion
+ const ast::Expression* FromGLSLBuiltin(ast::Builtin builtin,
+ const ast::Expression* value,
+ const ast::Type*& ast_type) {
+ switch (builtin) {
+ case ast::Builtin::kVertexIndex:
+ case ast::Builtin::kInstanceIndex:
+ case ast::Builtin::kSampleIndex:
+ // GLSL uses i32 for these, so bitcast to the incoming WGSL type (u32).
+ value = ctx.dst->Bitcast(ast_type, value);
+ ast_type = ctx.dst->ty.i32();
+ break;
+ case ast::Builtin::kSampleMask:
+ // gl_SampleMask is an array of i32. Retrieve the first element and
+ // bitcast it to the incoming WGSL type (u32).
+ value = ctx.dst->IndexAccessor(value, 0);
+ value = ctx.dst->Bitcast(ast_type, value);
+ ast_type = ctx.dst->ty.array(ctx.dst->ty.i32(), 1);
+ break;
+ default:
+ break;
+ }
+ return value;
+ }
+
+ /// Convert a given WGSL value to the type expected when assigning to a
+ /// GLSL builtin.
+ /// @param builtin the builtin variable
+ /// @param value the value to convert
+ /// @param type (out) the type to which the value was converted
+ /// @returns the converted value which can be assigned to the GLSL builtin
+ const ast::Expression* ToGLSLBuiltin(ast::Builtin builtin,
const ast::Expression* value,
- const ast::Type*& ast_type) {
- switch (builtin) {
- case ast::Builtin::kVertexIndex:
- case ast::Builtin::kInstanceIndex:
- case ast::Builtin::kSampleIndex:
- // GLSL uses i32 for these, so bitcast to u32.
- value = ctx.dst->Bitcast(ast_type, value);
- ast_type = ctx.dst->ty.i32();
- break;
- case ast::Builtin::kSampleMask:
- // gl_SampleMask is an array of i32. Retrieve the first element and
- // bitcast it to u32.
- value = ctx.dst->IndexAccessor(value, 0);
- value = ctx.dst->Bitcast(ast_type, value);
- ast_type = ctx.dst->ty.array(ctx.dst->ty.i32(), 1);
- break;
- default:
- break;
+ const sem::Type*& type) {
+ switch (builtin) {
+ case ast::Builtin::kVertexIndex:
+ case ast::Builtin::kInstanceIndex:
+ case ast::Builtin::kSampleIndex:
+ case ast::Builtin::kSampleMask:
+ type = ctx.dst->create<sem::I32>();
+ value = ctx.dst->Bitcast(CreateASTTypeFor(ctx, type), value);
+ break;
+ default:
+ break;
+ }
+ return value;
}
- return value;
- }
-
- /// Convert a given WGSL value to the type expected when assigning to a
- /// GLSL builtin.
- /// @param builtin the builtin variable
- /// @param value the value to convert
- /// @param type (out) the type to which the value was converted
- /// @returns the converted value which can be assigned to the GLSL builtin
- const ast::Expression* ToGLSLBuiltin(ast::Builtin builtin,
- const ast::Expression* value,
- const sem::Type*& type) {
- switch (builtin) {
- case ast::Builtin::kVertexIndex:
- case ast::Builtin::kInstanceIndex:
- case ast::Builtin::kSampleIndex:
- case ast::Builtin::kSampleMask:
- type = ctx.dst->create<sem::I32>();
- value = ctx.dst->Bitcast(CreateASTTypeFor(ctx, type), value);
- break;
- default:
- break;
- }
- return value;
- }
};
-void CanonicalizeEntryPointIO::Run(CloneContext& ctx,
- const DataMap& inputs,
- DataMap&) const {
- auto* cfg = inputs.Get<Config>();
- if (cfg == nullptr) {
- ctx.dst->Diagnostics().add_error(
- diag::System::Transform,
- "missing transform data for " + std::string(TypeInfo().name));
- return;
- }
+void CanonicalizeEntryPointIO::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const {
+ auto* cfg = inputs.Get<Config>();
+ if (cfg == nullptr) {
+ ctx.dst->Diagnostics().add_error(
+ diag::System::Transform, "missing transform data for " + std::string(TypeInfo().name));
+ return;
+ }
- // Remove entry point IO attributes from struct declarations.
- // New structures will be created for each entry point, as necessary.
- for (auto* ty : ctx.src->AST().TypeDecls()) {
- if (auto* struct_ty = ty->As<ast::Struct>()) {
- for (auto* member : struct_ty->members) {
- for (auto* attr : member->attributes) {
- if (IsShaderIOAttribute(attr)) {
- ctx.Remove(member->attributes, attr);
- }
+ // Remove entry point IO attributes from struct declarations.
+ // New structures will be created for each entry point, as necessary.
+ for (auto* ty : ctx.src->AST().TypeDecls()) {
+ if (auto* struct_ty = ty->As<ast::Struct>()) {
+ for (auto* member : struct_ty->members) {
+ for (auto* attr : member->attributes) {
+ if (IsShaderIOAttribute(attr)) {
+ ctx.Remove(member->attributes, attr);
+ }
+ }
+ }
}
- }
- }
- }
-
- for (auto* func_ast : ctx.src->AST().Functions()) {
- if (!func_ast->IsEntryPoint()) {
- continue;
}
- State state(ctx, *cfg, func_ast);
- state.Process();
- }
+ for (auto* func_ast : ctx.src->AST().Functions()) {
+ if (!func_ast->IsEntryPoint()) {
+ continue;
+ }
- ctx.Clone();
+ State state(ctx, *cfg, func_ast);
+ state.Process();
+ }
+
+ ctx.Clone();
}
CanonicalizeEntryPointIO::Config::Config(ShaderStyle style,
diff --git a/src/tint/transform/canonicalize_entry_point_io.h b/src/tint/transform/canonicalize_entry_point_io.h
index eab4128..64a10f2 100644
--- a/src/tint/transform/canonicalize_entry_point_io.h
+++ b/src/tint/transform/canonicalize_entry_point_io.h
@@ -82,64 +82,61 @@
///
/// @note Depends on the following transforms to have been run first:
/// * Unshadow
-class CanonicalizeEntryPointIO final
- : public Castable<CanonicalizeEntryPointIO, Transform> {
- public:
- /// ShaderStyle is an enumerator of different ways to emit shader IO.
- enum class ShaderStyle {
- /// Target SPIR-V (using global variables).
- kSpirv,
- /// Target GLSL (using global variables).
- kGlsl,
- /// Target MSL (using non-struct function parameters for builtins).
- kMsl,
- /// Target HLSL (using structures for all IO).
- kHlsl,
- };
+class CanonicalizeEntryPointIO final : public Castable<CanonicalizeEntryPointIO, Transform> {
+ public:
+ /// ShaderStyle is an enumerator of different ways to emit shader IO.
+ enum class ShaderStyle {
+ /// Target SPIR-V (using global variables).
+ kSpirv,
+ /// Target GLSL (using global variables).
+ kGlsl,
+ /// Target MSL (using non-struct function parameters for builtins).
+ kMsl,
+ /// Target HLSL (using structures for all IO).
+ kHlsl,
+ };
- /// Configuration options for the transform.
- struct Config final : public Castable<Config, Data> {
+ /// Configuration options for the transform.
+ struct Config final : public Castable<Config, Data> {
+ /// Constructor
+ /// @param style the approach to use for emitting shader IO.
+ /// @param sample_mask an optional sample mask to combine with shader masks
+ /// @param emit_vertex_point_size `true` to generate a pointsize builtin
+ explicit Config(ShaderStyle style,
+ uint32_t sample_mask = 0xFFFFFFFF,
+ bool emit_vertex_point_size = false);
+
+ /// Copy constructor
+ Config(const Config&);
+
+ /// Destructor
+ ~Config() override;
+
+ /// The approach to use for emitting shader IO.
+ const ShaderStyle shader_style;
+
+ /// A fixed sample mask to combine into masks produced by fragment shaders.
+ const uint32_t fixed_sample_mask;
+
+ /// Set to `true` to generate a pointsize builtin and have it set to 1.0
+ /// from all vertex shaders in the module.
+ const bool emit_vertex_point_size;
+ };
+
/// Constructor
- /// @param style the approach to use for emitting shader IO.
- /// @param sample_mask an optional sample mask to combine with shader masks
- /// @param emit_vertex_point_size `true` to generate a pointsize builtin
- explicit Config(ShaderStyle style,
- uint32_t sample_mask = 0xFFFFFFFF,
- bool emit_vertex_point_size = false);
+ CanonicalizeEntryPointIO();
+ ~CanonicalizeEntryPointIO() override;
- /// Copy constructor
- Config(const Config&);
+ protected:
+ /// Runs the transform using the CloneContext built for transforming a
+ /// program. Run() is responsible for calling Clone() on the CloneContext.
+ /// @param ctx the CloneContext primed with the input program and
+ /// ProgramBuilder
+ /// @param inputs optional extra transform-specific input data
+ /// @param outputs optional extra transform-specific output data
+ void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
- /// Destructor
- ~Config() override;
-
- /// The approach to use for emitting shader IO.
- const ShaderStyle shader_style;
-
- /// A fixed sample mask to combine into masks produced by fragment shaders.
- const uint32_t fixed_sample_mask;
-
- /// Set to `true` to generate a pointsize builtin and have it set to 1.0
- /// from all vertex shaders in the module.
- const bool emit_vertex_point_size;
- };
-
- /// Constructor
- CanonicalizeEntryPointIO();
- ~CanonicalizeEntryPointIO() override;
-
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx,
- const DataMap& inputs,
- DataMap& outputs) const override;
-
- struct State;
+ struct State;
};
} // namespace tint::transform
diff --git a/src/tint/transform/canonicalize_entry_point_io_test.cc b/src/tint/transform/canonicalize_entry_point_io_test.cc
index bf4d699..70a3ad5 100644
--- a/src/tint/transform/canonicalize_entry_point_io_test.cc
+++ b/src/tint/transform/canonicalize_entry_point_io_test.cc
@@ -23,21 +23,21 @@
using CanonicalizeEntryPointIOTest = TransformTest;
TEST_F(CanonicalizeEntryPointIOTest, Error_MissingTransformData) {
- auto* src = "";
+ auto* src = "";
- auto* expect =
- "error: missing transform data for "
- "tint::transform::CanonicalizeEntryPointIO";
+ auto* expect =
+ "error: missing transform data for "
+ "tint::transform::CanonicalizeEntryPointIO";
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, NoShaderIO) {
- // Test that we do not introduce wrapper functions when there is no shader IO
- // to process.
- auto* src = R"(
+ // Test that we do not introduce wrapper functions when there is no shader IO
+ // to process.
+ auto* src = R"(
@stage(fragment)
fn frag_main() {
}
@@ -47,18 +47,17 @@
}
)";
- auto* expect = src;
+ auto* expect = src;
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kMsl);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, Parameters_Spirv) {
- auto* src = R"(
+ auto* src = R"(
@stage(fragment)
fn frag_main(@location(1) loc1 : f32,
@location(2) @interpolate(flat) loc2 : vec4<u32>,
@@ -67,7 +66,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@location(1) @internal(disable_validation__ignore_storage_class) var<in> loc1_1 : f32;
@location(2) @interpolate(flat) @internal(disable_validation__ignore_storage_class) var<in> loc2_1 : vec4<u32>;
@@ -84,16 +83,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kSpirv);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kSpirv);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, Parameters_Msl) {
- auto* src = R"(
+ auto* src = R"(
@stage(fragment)
fn frag_main(@location(1) loc1 : f32,
@location(2) @interpolate(flat) loc2 : vec4<u32>,
@@ -102,7 +100,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol_1 {
@location(1)
loc1 : f32,
@@ -120,16 +118,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kMsl);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, Parameters_Hlsl) {
- auto* src = R"(
+ auto* src = R"(
@stage(fragment)
fn frag_main(@location(1) loc1 : f32,
@location(2) @interpolate(flat) loc2 : vec4<u32>,
@@ -138,7 +135,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol_1 {
@location(1)
loc1 : f32,
@@ -158,16 +155,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, Parameter_TypeAlias) {
- auto* src = R"(
+ auto* src = R"(
type myf32 = f32;
@stage(fragment)
@@ -176,7 +172,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
type myf32 = f32;
struct tint_symbol_1 {
@@ -194,16 +190,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kMsl);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, Parameter_TypeAlias_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(fragment)
fn frag_main(@location(1) loc1 : myf32) {
var x : myf32 = loc1;
@@ -212,7 +207,7 @@
type myf32 = f32;
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol_1 {
@location(1)
loc1 : f32,
@@ -230,16 +225,15 @@
type myf32 = f32;
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kMsl);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, StructParameters_Spirv) {
- auto* src = R"(
+ auto* src = R"(
struct FragBuiltins {
@builtin(position) coord : vec4<f32>,
};
@@ -256,7 +250,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@location(0) @internal(disable_validation__ignore_storage_class) var<in> loc0_1 : f32;
@location(1) @internal(disable_validation__ignore_storage_class) var<in> loc1_1 : f32;
@@ -284,16 +278,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kSpirv);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kSpirv);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, StructParameters_Spirv_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(fragment)
fn frag_main(@location(0) loc0 : f32,
locations : FragLocations,
@@ -310,7 +303,7 @@
};
)";
- auto* expect = R"(
+ auto* expect = R"(
@location(0) @internal(disable_validation__ignore_storage_class) var<in> loc0_1 : f32;
@location(1) @internal(disable_validation__ignore_storage_class) var<in> loc1_1 : f32;
@@ -338,16 +331,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kSpirv);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kSpirv);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, StructParameters_kMsl) {
- auto* src = R"(
+ auto* src = R"(
struct FragBuiltins {
@builtin(position) coord : vec4<f32>,
};
@@ -364,7 +356,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct FragBuiltins {
coord : vec4<f32>,
}
@@ -393,16 +385,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kMsl);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, StructParameters_kMsl_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(fragment)
fn frag_main(@location(0) loc0 : f32,
locations : FragLocations,
@@ -419,7 +410,7 @@
};
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol_1 {
@location(0)
loc0 : f32,
@@ -448,16 +439,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kMsl);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, StructParameters_Hlsl) {
- auto* src = R"(
+ auto* src = R"(
struct FragBuiltins {
@builtin(position) coord : vec4<f32>,
};
@@ -474,7 +464,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct FragBuiltins {
coord : vec4<f32>,
}
@@ -505,16 +495,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, StructParameters_Hlsl_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(fragment)
fn frag_main(@location(0) loc0 : f32,
locations : FragLocations,
@@ -531,7 +520,7 @@
};
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol_1 {
@location(0)
loc0 : f32,
@@ -562,23 +551,22 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, Return_NonStruct_Spirv) {
- auto* src = R"(
+ auto* src = R"(
@stage(fragment)
fn frag_main() -> @builtin(frag_depth) f32 {
return 1.0;
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@builtin(frag_depth) @internal(disable_validation__ignore_storage_class) var<out> value : f32;
fn frag_main_inner() -> f32 {
@@ -592,23 +580,22 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kSpirv);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kSpirv);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, Return_NonStruct_Msl) {
- auto* src = R"(
+ auto* src = R"(
@stage(fragment)
fn frag_main() -> @builtin(frag_depth) f32 {
return 1.0;
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol {
@builtin(frag_depth)
value : f32,
@@ -627,23 +614,22 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kMsl);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, Return_NonStruct_Hlsl) {
- auto* src = R"(
+ auto* src = R"(
@stage(fragment)
fn frag_main() -> @builtin(frag_depth) f32 {
return 1.0;
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol {
@builtin(frag_depth)
value : f32,
@@ -662,16 +648,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, Return_Struct_Spirv) {
- auto* src = R"(
+ auto* src = R"(
struct FragOutput {
@location(0) color : vec4<f32>,
@builtin(frag_depth) depth : f32,
@@ -688,7 +673,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@location(0) @internal(disable_validation__ignore_storage_class) var<out> color_1 : vec4<f32>;
@builtin(frag_depth) @internal(disable_validation__ignore_storage_class) var<out> depth_1 : f32;
@@ -718,16 +703,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kSpirv);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kSpirv);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, Return_Struct_Spirv_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(fragment)
fn frag_main() -> FragOutput {
var output : FragOutput;
@@ -744,7 +728,7 @@
};
)";
- auto* expect = R"(
+ auto* expect = R"(
@location(0) @internal(disable_validation__ignore_storage_class) var<out> color_1 : vec4<f32>;
@builtin(frag_depth) @internal(disable_validation__ignore_storage_class) var<out> depth_1 : f32;
@@ -774,16 +758,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kSpirv);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kSpirv);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, Return_Struct_Msl) {
- auto* src = R"(
+ auto* src = R"(
struct FragOutput {
@location(0) color : vec4<f32>,
@builtin(frag_depth) depth : f32,
@@ -800,7 +783,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct FragOutput {
color : vec4<f32>,
depth : f32,
@@ -835,16 +818,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kMsl);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, Return_Struct_Msl_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(fragment)
fn frag_main() -> FragOutput {
var output : FragOutput;
@@ -861,7 +843,7 @@
};
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol {
@location(0)
color : vec4<f32>,
@@ -896,16 +878,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kMsl);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, Return_Struct_Hlsl) {
- auto* src = R"(
+ auto* src = R"(
struct FragOutput {
@location(0) color : vec4<f32>,
@builtin(frag_depth) depth : f32,
@@ -922,7 +903,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct FragOutput {
color : vec4<f32>,
depth : f32,
@@ -957,16 +938,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, Return_Struct_Hlsl_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(fragment)
fn frag_main() -> FragOutput {
var output : FragOutput;
@@ -983,7 +963,7 @@
};
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol {
@location(0)
color : vec4<f32>,
@@ -1018,17 +998,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
-TEST_F(CanonicalizeEntryPointIOTest,
- StructParameters_SharedDeviceFunction_Spirv) {
- auto* src = R"(
+TEST_F(CanonicalizeEntryPointIOTest, StructParameters_SharedDeviceFunction_Spirv) {
+ auto* src = R"(
struct FragmentInput {
@location(0) value : f32,
@location(1) mul : f32,
@@ -1049,7 +1027,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@location(0) @internal(disable_validation__ignore_storage_class) var<in> value_1 : f32;
@location(1) @internal(disable_validation__ignore_storage_class) var<in> mul_1 : f32;
@@ -1086,17 +1064,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kSpirv);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kSpirv);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
-TEST_F(CanonicalizeEntryPointIOTest,
- StructParameters_SharedDeviceFunction_Spirv_OutOfOrder) {
- auto* src = R"(
+TEST_F(CanonicalizeEntryPointIOTest, StructParameters_SharedDeviceFunction_Spirv_OutOfOrder) {
+ auto* src = R"(
@stage(fragment)
fn frag_main1(inputs : FragmentInput) {
var x : f32 = foo(inputs);
@@ -1117,7 +1093,7 @@
};
)";
- auto* expect = R"(
+ auto* expect = R"(
@location(0) @internal(disable_validation__ignore_storage_class) var<in> value_1 : f32;
@location(1) @internal(disable_validation__ignore_storage_class) var<in> mul_1 : f32;
@@ -1154,17 +1130,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kSpirv);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kSpirv);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
-TEST_F(CanonicalizeEntryPointIOTest,
- StructParameters_SharedDeviceFunction_Msl) {
- auto* src = R"(
+TEST_F(CanonicalizeEntryPointIOTest, StructParameters_SharedDeviceFunction_Msl) {
+ auto* src = R"(
struct FragmentInput {
@location(0) value : f32,
@location(1) mul : f32,
@@ -1185,7 +1159,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct FragmentInput {
value : f32,
mul : f32,
@@ -1228,17 +1202,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kMsl);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
-TEST_F(CanonicalizeEntryPointIOTest,
- StructParameters_SharedDeviceFunction_Msl_OutOfOrder) {
- auto* src = R"(
+TEST_F(CanonicalizeEntryPointIOTest, StructParameters_SharedDeviceFunction_Msl_OutOfOrder) {
+ auto* src = R"(
@stage(fragment)
fn frag_main1(inputs : FragmentInput) {
var x : f32 = foo(inputs);
@@ -1259,7 +1231,7 @@
};
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol_1 {
@location(0)
value : f32,
@@ -1302,17 +1274,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kMsl);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
-TEST_F(CanonicalizeEntryPointIOTest,
- StructParameters_SharedDeviceFunction_Hlsl) {
- auto* src = R"(
+TEST_F(CanonicalizeEntryPointIOTest, StructParameters_SharedDeviceFunction_Hlsl) {
+ auto* src = R"(
struct FragmentInput {
@location(0) value : f32,
@location(1) mul : f32,
@@ -1333,7 +1303,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct FragmentInput {
value : f32,
mul : f32,
@@ -1376,17 +1346,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
-TEST_F(CanonicalizeEntryPointIOTest,
- StructParameters_SharedDeviceFunction_Hlsl_OutOfOrder) {
- auto* src = R"(
+TEST_F(CanonicalizeEntryPointIOTest, StructParameters_SharedDeviceFunction_Hlsl_OutOfOrder) {
+ auto* src = R"(
@stage(fragment)
fn frag_main1(inputs : FragmentInput) {
var x : f32 = foo(inputs);
@@ -1407,7 +1375,7 @@
};
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol_1 {
@location(0)
value : f32,
@@ -1450,16 +1418,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, Struct_ModuleScopeVariable) {
- auto* src = R"(
+ auto* src = R"(
struct FragmentInput {
@location(0) col1 : f32,
@location(1) col2 : f32,
@@ -1483,7 +1450,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct FragmentInput {
col1 : f32,
col2 : f32,
@@ -1518,16 +1485,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kMsl);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, Struct_ModuleScopeVariable_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(fragment)
fn frag_main1(inputs : FragmentInput) {
global_inputs = inputs;
@@ -1551,7 +1517,7 @@
};
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol_1 {
@location(0)
col1 : f32,
@@ -1586,16 +1552,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kMsl);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, Struct_TypeAliases) {
- auto* src = R"(
+ auto* src = R"(
type myf32 = f32;
struct FragmentInput {
@@ -1623,7 +1588,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
type myf32 = f32;
struct FragmentInput {
@@ -1673,16 +1638,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kMsl);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, Struct_TypeAliases_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(fragment)
fn frag_main(inputs : MyFragmentInput) -> MyFragmentOutput {
var x : myf32 = foo(inputs);
@@ -1710,7 +1674,7 @@
type myf32 = f32;
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol_1 {
@location(0)
col1 : f32,
@@ -1760,16 +1724,15 @@
type myf32 = f32;
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kMsl);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, InterpolateAttributes) {
- auto* src = R"(
+ auto* src = R"(
struct VertexOut {
@builtin(position) pos : vec4<f32>,
@location(1) @interpolate(flat) loc1 : f32,
@@ -1794,7 +1757,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct VertexOut {
pos : vec4<f32>,
loc1 : f32,
@@ -1852,16 +1815,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, InterpolateAttributes_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(fragment)
fn frag_main(inputs : FragmentIn,
@location(3) @interpolate(perspective, centroid) loc3 : f32) {
@@ -1886,7 +1848,7 @@
};
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol_1 {
@location(1) @interpolate(flat)
loc1 : f32,
@@ -1944,18 +1906,17 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, InterpolateAttributes_Integers_Spirv) {
- // Test that we add a Flat attribute to integers that are vertex outputs and
- // fragment inputs, but not vertex inputs or fragment outputs.
- auto* src = R"(
+ // Test that we add a Flat attribute to integers that are vertex outputs and
+ // fragment inputs, but not vertex inputs or fragment outputs.
+ auto* src = R"(
struct VertexIn {
@location(0) i : i32,
@location(1) u : u32,
@@ -1989,8 +1950,8 @@
}
)";
- auto* expect =
- R"(
+ auto* expect =
+ R"(
@location(0) @internal(disable_validation__ignore_storage_class) var<in> i_1 : i32;
@location(1) @internal(disable_validation__ignore_storage_class) var<in> u_1 : u32;
@@ -2075,19 +2036,17 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kSpirv);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kSpirv);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
-TEST_F(CanonicalizeEntryPointIOTest,
- InterpolateAttributes_Integers_Spirv_OutOfOrder) {
- // Test that we add a Flat attribute to integers that are vertex outputs and
- // fragment inputs, but not vertex inputs or fragment outputs.
- auto* src = R"(
+TEST_F(CanonicalizeEntryPointIOTest, InterpolateAttributes_Integers_Spirv_OutOfOrder) {
+ // Test that we add a Flat attribute to integers that are vertex outputs and
+ // fragment inputs, but not vertex inputs or fragment outputs.
+ auto* src = R"(
@stage(vertex)
fn vert_main(in : VertexIn) -> VertexOut {
return VertexOut(in.i, in.u, in.vi, in.vu, vec4<f32>());
@@ -2121,8 +2080,8 @@
};
)";
- auto* expect =
- R"(
+ auto* expect =
+ R"(
@location(0) @internal(disable_validation__ignore_storage_class) var<in> i_1 : i32;
@location(1) @internal(disable_validation__ignore_storage_class) var<in> u_1 : u32;
@@ -2207,16 +2166,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kSpirv);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kSpirv);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, InvariantAttributes) {
- auto* src = R"(
+ auto* src = R"(
struct VertexOut {
@builtin(position) @invariant pos : vec4<f32>,
};
@@ -2232,7 +2190,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct VertexOut {
pos : vec4<f32>,
}
@@ -2272,16 +2230,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, InvariantAttributes_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(vertex)
fn main1() -> VertexOut {
return VertexOut();
@@ -2297,7 +2254,7 @@
};
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol {
@builtin(position) @invariant
pos : vec4<f32>,
@@ -2337,16 +2294,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, Struct_LayoutAttributes) {
- auto* src = R"(
+ auto* src = R"(
struct FragmentInput {
@size(16) @location(1) value : f32,
@builtin(position) @align(32) coord : vec4<f32>,
@@ -2363,7 +2319,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct FragmentInput {
@size(16)
value : f32,
@@ -2405,16 +2361,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, Struct_LayoutAttributes_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(fragment)
fn frag_main(inputs : FragmentInput) -> FragmentOutput {
return FragmentOutput(inputs.coord.x * inputs.value + inputs.loc0);
@@ -2431,7 +2386,7 @@
};
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol_1 {
@location(0) @interpolate(linear, sample)
loc0 : f32,
@@ -2473,16 +2428,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, SortedMembers) {
- auto* src = R"(
+ auto* src = R"(
struct VertexOutput {
@location(1) @interpolate(flat) b : u32,
@builtin(position) pos : vec4<f32>,
@@ -2510,7 +2464,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct VertexOutput {
b : u32,
pos : vec4<f32>,
@@ -2578,16 +2532,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, SortedMembers_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(vertex)
fn vert_main() -> VertexOutput {
return VertexOutput();
@@ -2615,7 +2568,7 @@
};
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol {
@location(0)
a : f32,
@@ -2683,22 +2636,21 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, DontRenameSymbols) {
- auto* src = R"(
+ auto* src = R"(
@stage(fragment)
fn tint_symbol_1(@location(0) col : f32) {
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol_2 {
@location(0)
col : f32,
@@ -2713,22 +2665,21 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kMsl);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, FixedSampleMask_VoidNoReturn) {
- auto* src = R"(
+ auto* src = R"(
@stage(fragment)
fn frag_main() {
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol {
@builtin(sample_mask)
fixed_sample_mask : u32,
@@ -2746,23 +2697,22 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kMsl, 0x03u);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl, 0x03u);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, FixedSampleMask_VoidWithReturn) {
- auto* src = R"(
+ auto* src = R"(
@stage(fragment)
fn frag_main() {
return;
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol {
@builtin(sample_mask)
fixed_sample_mask : u32,
@@ -2781,23 +2731,22 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kMsl, 0x03u);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl, 0x03u);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, FixedSampleMask_WithAuthoredMask) {
- auto* src = R"(
+ auto* src = R"(
@stage(fragment)
fn frag_main() -> @builtin(sample_mask) u32 {
return 7u;
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol {
@builtin(sample_mask)
value : u32,
@@ -2816,23 +2765,22 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kMsl, 0x03u);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl, 0x03u);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, FixedSampleMask_WithoutAuthoredMask) {
- auto* src = R"(
+ auto* src = R"(
@stage(fragment)
fn frag_main() -> @location(0) f32 {
return 1.0;
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol {
@location(0)
value : f32,
@@ -2854,16 +2802,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kMsl, 0x03u);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl, 0x03u);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, FixedSampleMask_StructWithAuthoredMask) {
- auto* src = R"(
+ auto* src = R"(
struct Output {
@builtin(frag_depth) depth : f32,
@builtin(sample_mask) mask : u32,
@@ -2876,7 +2823,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct Output {
depth : f32,
mask : u32,
@@ -2907,17 +2854,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kMsl, 0x03u);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl, 0x03u);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
-TEST_F(CanonicalizeEntryPointIOTest,
- FixedSampleMask_StructWithAuthoredMask_OutOfOrder) {
- auto* src = R"(
+TEST_F(CanonicalizeEntryPointIOTest, FixedSampleMask_StructWithAuthoredMask_OutOfOrder) {
+ auto* src = R"(
@stage(fragment)
fn frag_main() -> Output {
return Output(0.5, 7u, 1.0);
@@ -2930,7 +2875,7 @@
};
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol {
@location(0)
value : f32,
@@ -2961,17 +2906,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kMsl, 0x03u);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl, 0x03u);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
-TEST_F(CanonicalizeEntryPointIOTest,
- FixedSampleMask_StructWithoutAuthoredMask) {
- auto* src = R"(
+TEST_F(CanonicalizeEntryPointIOTest, FixedSampleMask_StructWithoutAuthoredMask) {
+ auto* src = R"(
struct Output {
@builtin(frag_depth) depth : f32,
@location(0) value : f32,
@@ -2983,7 +2926,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct Output {
depth : f32,
value : f32,
@@ -3013,17 +2956,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kMsl, 0x03u);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl, 0x03u);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
-TEST_F(CanonicalizeEntryPointIOTest,
- FixedSampleMask_StructWithoutAuthoredMask_OutOfOrder) {
- auto* src = R"(
+TEST_F(CanonicalizeEntryPointIOTest, FixedSampleMask_StructWithoutAuthoredMask_OutOfOrder) {
+ auto* src = R"(
@stage(fragment)
fn frag_main() -> Output {
return Output(0.5, 1.0);
@@ -3035,7 +2976,7 @@
};
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol {
@location(0)
value : f32,
@@ -3065,16 +3006,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kMsl, 0x03u);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl, 0x03u);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, FixedSampleMask_MultipleShaders) {
- auto* src = R"(
+ auto* src = R"(
@stage(fragment)
fn frag_main1() -> @builtin(sample_mask) u32 {
return 7u;
@@ -3095,7 +3035,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol {
@builtin(sample_mask)
value : u32,
@@ -3155,16 +3095,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kMsl, 0x03u);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl, 0x03u);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, FixedSampleMask_AvoidNameClash) {
- auto* src = R"(
+ auto* src = R"(
struct FragOut {
@location(0) fixed_sample_mask : vec4<f32>,
@location(1) fixed_sample_mask_1 : vec4<f32>,
@@ -3176,7 +3115,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct FragOut {
fixed_sample_mask : vec4<f32>,
fixed_sample_mask_1 : vec4<f32>,
@@ -3206,24 +3145,22 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kMsl, 0x03);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl, 0x03);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
-TEST_F(CanonicalizeEntryPointIOTest,
- EmitVertexPointSize_ReturnNonStruct_Spirv) {
- auto* src = R"(
+TEST_F(CanonicalizeEntryPointIOTest, EmitVertexPointSize_ReturnNonStruct_Spirv) {
+ auto* src = R"(
@stage(vertex)
fn vert_main() -> @builtin(position) vec4<f32> {
return vec4<f32>();
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@builtin(position) @internal(disable_validation__ignore_storage_class) var<out> value : vec4<f32>;
@builtin(pointsize) @internal(disable_validation__ignore_storage_class) var<out> vertex_point_size : f32;
@@ -3240,23 +3177,23 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kSpirv, 0xFFFFFFFF, true);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kSpirv,
+ 0xFFFFFFFF, true);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, EmitVertexPointSize_ReturnNonStruct_Msl) {
- auto* src = R"(
+ auto* src = R"(
@stage(vertex)
fn vert_main() -> @builtin(position) vec4<f32> {
return vec4<f32>();
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol {
@builtin(position)
value : vec4<f32>,
@@ -3278,16 +3215,16 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kMsl, 0xFFFFFFFF, true);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl,
+ 0xFFFFFFFF, true);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, EmitVertexPointSize_ReturnStruct_Spirv) {
- auto* src = R"(
+ auto* src = R"(
struct VertOut {
@builtin(position) pos : vec4<f32>,
};
@@ -3298,7 +3235,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@builtin(position) @internal(disable_validation__ignore_storage_class) var<out> pos_1 : vec4<f32>;
@builtin(pointsize) @internal(disable_validation__ignore_storage_class) var<out> vertex_point_size : f32;
@@ -3319,17 +3256,16 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kSpirv, 0xFFFFFFFF, true);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kSpirv,
+ 0xFFFFFFFF, true);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
-TEST_F(CanonicalizeEntryPointIOTest,
- EmitVertexPointSize_ReturnStruct_Spirv_OutOfOrder) {
- auto* src = R"(
+TEST_F(CanonicalizeEntryPointIOTest, EmitVertexPointSize_ReturnStruct_Spirv_OutOfOrder) {
+ auto* src = R"(
@stage(vertex)
fn vert_main() -> VertOut {
return VertOut();
@@ -3340,7 +3276,7 @@
};
)";
- auto* expect = R"(
+ auto* expect = R"(
@builtin(position) @internal(disable_validation__ignore_storage_class) var<out> pos_1 : vec4<f32>;
@builtin(pointsize) @internal(disable_validation__ignore_storage_class) var<out> vertex_point_size : f32;
@@ -3361,16 +3297,16 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kSpirv, 0xFFFFFFFF, true);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kSpirv,
+ 0xFFFFFFFF, true);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, EmitVertexPointSize_ReturnStruct_Msl) {
- auto* src = R"(
+ auto* src = R"(
struct VertOut {
@builtin(position) pos : vec4<f32>,
};
@@ -3381,7 +3317,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct VertOut {
pos : vec4<f32>,
}
@@ -3407,17 +3343,16 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kMsl, 0xFFFFFFFF, true);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl,
+ 0xFFFFFFFF, true);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
-TEST_F(CanonicalizeEntryPointIOTest,
- EmitVertexPointSize_ReturnStruct_Msl_OutOfOrder) {
- auto* src = R"(
+TEST_F(CanonicalizeEntryPointIOTest, EmitVertexPointSize_ReturnStruct_Msl_OutOfOrder) {
+ auto* src = R"(
@stage(vertex)
fn vert_main() -> VertOut {
return VertOut();
@@ -3428,7 +3363,7 @@
};
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol {
@builtin(position)
pos : vec4<f32>,
@@ -3454,16 +3389,16 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kMsl, 0xFFFFFFFF, true);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl,
+ 0xFFFFFFFF, true);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, EmitVertexPointSize_AvoidNameClash_Spirv) {
- auto* src = R"(
+ auto* src = R"(
var<private> vertex_point_size : f32;
var<private> vertex_point_size_1 : f32;
var<private> vertex_point_size_2 : f32;
@@ -3488,7 +3423,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@location(0) @internal(disable_validation__ignore_storage_class) var<in> collide_2 : f32;
@location(1) @internal(disable_validation__ignore_storage_class) var<in> collide_3 : f32;
@@ -3532,17 +3467,16 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kSpirv, 0xFFFFFFFF, true);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kSpirv,
+ 0xFFFFFFFF, true);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
-TEST_F(CanonicalizeEntryPointIOTest,
- EmitVertexPointSize_AvoidNameClash_Spirv_OutOfOrder) {
- auto* src = R"(
+TEST_F(CanonicalizeEntryPointIOTest, EmitVertexPointSize_AvoidNameClash_Spirv_OutOfOrder) {
+ auto* src = R"(
@stage(vertex)
fn vert_main(collide : VertIn1, collide_1 : VertIn2) -> VertOut {
let x = collide.collide + collide_1.collide;
@@ -3567,7 +3501,7 @@
};
)";
- auto* expect = R"(
+ auto* expect = R"(
@location(0) @internal(disable_validation__ignore_storage_class) var<in> collide_2 : f32;
@location(1) @internal(disable_validation__ignore_storage_class) var<in> collide_3 : f32;
@@ -3611,16 +3545,16 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kSpirv, 0xFFFFFFFF, true);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kSpirv,
+ 0xFFFFFFFF, true);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, EmitVertexPointSize_AvoidNameClash_Msl) {
- auto* src = R"(
+ auto* src = R"(
struct VertIn1 {
@location(0) collide : f32,
};
@@ -3641,7 +3575,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct VertIn1 {
collide : f32,
}
@@ -3687,17 +3621,16 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kMsl, 0xFFFFFFFF, true);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl,
+ 0xFFFFFFFF, true);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
-TEST_F(CanonicalizeEntryPointIOTest,
- EmitVertexPointSize_AvoidNameClash_Msl_OutOfOrder) {
- auto* src = R"(
+TEST_F(CanonicalizeEntryPointIOTest, EmitVertexPointSize_AvoidNameClash_Msl_OutOfOrder) {
+ auto* src = R"(
@stage(vertex)
fn vert_main(collide : VertIn1, collide_1 : VertIn2) -> VertOut {
let x = collide.collide + collide_1.collide;
@@ -3718,7 +3651,7 @@
};
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol_1 {
@location(0)
collide : f32,
@@ -3764,16 +3697,16 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kMsl, 0xFFFFFFFF, true);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl,
+ 0xFFFFFFFF, true);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, EmitVertexPointSize_AvoidNameClash_Hlsl) {
- auto* src = R"(
+ auto* src = R"(
struct VertIn1 {
@location(0) collide : f32,
};
@@ -3794,7 +3727,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct VertIn1 {
collide : f32,
}
@@ -3840,17 +3773,16 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kHlsl, 0xFFFFFFFF, true);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl,
+ 0xFFFFFFFF, true);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
-TEST_F(CanonicalizeEntryPointIOTest,
- EmitVertexPointSize_AvoidNameClash_Hlsl_OutOfOrder) {
- auto* src = R"(
+TEST_F(CanonicalizeEntryPointIOTest, EmitVertexPointSize_AvoidNameClash_Hlsl_OutOfOrder) {
+ auto* src = R"(
@stage(vertex)
fn vert_main(collide : VertIn1, collide_1 : VertIn2) -> VertOut {
let x = collide.collide + collide_1.collide;
@@ -3871,7 +3803,7 @@
};
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol_1 {
@location(0)
collide : f32,
@@ -3917,16 +3849,16 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kHlsl, 0xFFFFFFFF, true);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl,
+ 0xFFFFFFFF, true);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, SpirvSampleMaskBuiltins) {
- auto* src = R"(
+ auto* src = R"(
@stage(fragment)
fn main(@builtin(sample_index) sample_index : u32,
@builtin(sample_mask) mask_in : u32
@@ -3935,7 +3867,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@builtin(sample_index) @internal(disable_validation__ignore_storage_class) var<in> sample_index_1 : u32;
@builtin(sample_mask) @internal(disable_validation__ignore_storage_class) var<in> mask_in_1 : array<u32, 1>;
@@ -3953,16 +3885,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kSpirv);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kSpirv);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, GLSLSampleMaskBuiltins) {
- auto* src = R"(
+ auto* src = R"(
@stage(fragment)
fn fragment_main(@builtin(sample_index) sample_index : u32,
@builtin(sample_mask) mask_in : u32
@@ -3971,7 +3902,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@builtin(sample_index) @internal(disable_validation__ignore_storage_class) var<in> gl_SampleID : i32;
@builtin(sample_mask) @internal(disable_validation__ignore_storage_class) var<in> gl_SampleMaskIn : array<i32, 1>;
@@ -3989,16 +3920,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kGlsl);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kGlsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, GLSLVertexInstanceIndexBuiltins) {
- auto* src = R"(
+ auto* src = R"(
@stage(vertex)
fn vertex_main(@builtin(vertex_index) vertexID : u32,
@builtin(instance_index) instanceID : u32
@@ -4007,7 +3937,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@builtin(vertex_index) @internal(disable_validation__ignore_storage_class) var<in> gl_VertexID : i32;
@builtin(instance_index) @internal(disable_validation__ignore_storage_class) var<in> gl_InstanceID : i32;
@@ -4027,12 +3957,11 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kGlsl);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kGlsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
} // namespace
diff --git a/src/tint/transform/combine_samplers.cc b/src/tint/transform/combine_samplers.cc
index 4b1e892..66b5b937 100644
--- a/src/tint/transform/combine_samplers.cc
+++ b/src/tint/transform/combine_samplers.cc
@@ -30,8 +30,8 @@
namespace {
bool IsGlobal(const tint::sem::VariablePair& pair) {
- return pair.first->Is<tint::sem::GlobalVariable>() &&
- (!pair.second || pair.second->Is<tint::sem::GlobalVariable>());
+ return pair.first->Is<tint::sem::GlobalVariable>() &&
+ (!pair.second || pair.second->Is<tint::sem::GlobalVariable>());
}
} // namespace
@@ -46,308 +46,296 @@
/// The PIMPL state for the CombineSamplers transform
struct CombineSamplers::State {
- /// The clone context
- CloneContext& ctx;
+ /// The clone context
+ CloneContext& ctx;
- /// The binding info
- const BindingInfo* binding_info;
+ /// The binding info
+ const BindingInfo* binding_info;
- /// Map from a texture/sampler pair to the corresponding combined sampler
- /// variable
- using CombinedTextureSamplerMap =
- std::unordered_map<sem::VariablePair, const ast::Variable*>;
+ /// Map from a texture/sampler pair to the corresponding combined sampler
+ /// variable
+ using CombinedTextureSamplerMap = std::unordered_map<sem::VariablePair, const ast::Variable*>;
- /// Use sem::BindingPoint without scope.
- using BindingPoint = sem::BindingPoint;
+ /// Use sem::BindingPoint without scope.
+ using BindingPoint = sem::BindingPoint;
- /// A map of all global texture/sampler variable pairs to the global
- /// combined sampler variable that will replace it.
- CombinedTextureSamplerMap global_combined_texture_samplers_;
+ /// A map of all global texture/sampler variable pairs to the global
+ /// combined sampler variable that will replace it.
+ CombinedTextureSamplerMap global_combined_texture_samplers_;
- /// A map of all texture/sampler variable pairs that contain a function
- /// parameter to the combined sampler function paramter that will replace it.
- std::unordered_map<const sem::Function*, CombinedTextureSamplerMap>
- function_combined_texture_samplers_;
+ /// A map of all texture/sampler variable pairs that contain a function
+ /// parameter to the combined sampler function paramter that will replace it.
+ std::unordered_map<const sem::Function*, CombinedTextureSamplerMap>
+ function_combined_texture_samplers_;
- /// Placeholder global samplers used when a function contains texture-only
- /// references (one comparison sampler, one regular). These are also used as
- /// temporary sampler parameters to the texture builtins to satisfy the WGSL
- /// resolver, but are then ignored and removed by the GLSL writer.
- const ast::Variable* placeholder_samplers_[2] = {};
+ /// Placeholder global samplers used when a function contains texture-only
+ /// references (one comparison sampler, one regular). These are also used as
+ /// temporary sampler parameters to the texture builtins to satisfy the WGSL
+ /// resolver, but are then ignored and removed by the GLSL writer.
+ const ast::Variable* placeholder_samplers_[2] = {};
- /// Group and binding attributes used by all combined sampler globals.
- /// Group 0 and binding 0 are used, with collisions disabled.
- /// @returns the newly-created attribute list
- ast::AttributeList Attributes() const {
- auto attributes = ctx.dst->GroupAndBinding(0, 0);
- attributes.push_back(
- ctx.dst->Disable(ast::DisabledValidation::kBindingPointCollision));
- return attributes;
- }
-
- /// Constructor
- /// @param context the clone context
- /// @param info the binding map information
- State(CloneContext& context, const BindingInfo* info)
- : ctx(context), binding_info(info) {}
-
- /// Creates a combined sampler global variables.
- /// (Note this is actually a Texture node at the AST level, but it will be
- /// written as the corresponding sampler (eg., sampler2D) on GLSL output.)
- /// @param texture_var the texture (global) variable
- /// @param sampler_var the sampler (global) variable
- /// @param name the default name to use (may be overridden by map lookup)
- /// @returns the newly-created global variable
- const ast::Variable* CreateCombinedGlobal(const sem::Variable* texture_var,
- const sem::Variable* sampler_var,
- std::string name) {
- SamplerTexturePair bp_pair;
- bp_pair.texture_binding_point =
- texture_var->As<sem::GlobalVariable>()->BindingPoint();
- bp_pair.sampler_binding_point =
- sampler_var ? sampler_var->As<sem::GlobalVariable>()->BindingPoint()
- : binding_info->placeholder_binding_point;
- auto it = binding_info->binding_map.find(bp_pair);
- if (it != binding_info->binding_map.end()) {
- name = it->second;
- }
- const ast::Type* type = CreateCombinedASTTypeFor(texture_var, sampler_var);
- Symbol symbol = ctx.dst->Symbols().New(name);
- return ctx.dst->Global(symbol, type, Attributes());
- }
-
- /// Creates placeholder global sampler variables.
- /// @param kind the sampler kind to create for
- /// @returns the newly-created global variable
- const ast::Variable* CreatePlaceholder(ast::SamplerKind kind) {
- const ast::Type* type = ctx.dst->ty.sampler(kind);
- const char* name = kind == ast::SamplerKind::kComparisonSampler
- ? "placeholder_comparison_sampler"
- : "placeholder_sampler";
- Symbol symbol = ctx.dst->Symbols().New(name);
- return ctx.dst->Global(symbol, type, Attributes());
- }
-
- /// Creates ast::Type for a given texture and sampler variable pair.
- /// Depth textures with no samplers are turned into the corresponding
- /// f32 texture (e.g., texture_depth_2d -> texture_2d<f32>).
- /// @param texture the texture variable of interest
- /// @param sampler the texture variable of interest
- /// @returns the newly-created type
- const ast::Type* CreateCombinedASTTypeFor(const sem::Variable* texture,
- const sem::Variable* sampler) {
- const sem::Type* texture_type = texture->Type()->UnwrapRef();
- const sem::DepthTexture* depth = texture_type->As<sem::DepthTexture>();
- if (depth && !sampler) {
- return ctx.dst->create<ast::SampledTexture>(depth->dim(),
- ctx.dst->create<ast::F32>());
- } else {
- return CreateASTTypeFor(ctx, texture_type);
- }
- }
-
- /// Performs the transformation
- void Run() {
- auto& sem = ctx.src->Sem();
-
- // Remove all texture and sampler global variables. These will be replaced
- // by combined samplers.
- for (auto* var : ctx.src->AST().GlobalVariables()) {
- auto* type = sem.Get(var->type);
- if (type && type->IsAnyOf<sem::Texture, sem::Sampler>() &&
- !type->Is<sem::StorageTexture>()) {
- ctx.Remove(ctx.src->AST().GlobalDeclarations(), var);
- } else if (auto binding_point = var->BindingPoint()) {
- if (binding_point.group->value == 0 &&
- binding_point.binding->value == 0) {
- auto* attribute =
- ctx.dst->Disable(ast::DisabledValidation::kBindingPointCollision);
- ctx.InsertFront(var->attributes, attribute);
- }
- }
+ /// Group and binding attributes used by all combined sampler globals.
+ /// Group 0 and binding 0 are used, with collisions disabled.
+ /// @returns the newly-created attribute list
+ ast::AttributeList Attributes() const {
+ auto attributes = ctx.dst->GroupAndBinding(0, 0);
+ attributes.push_back(ctx.dst->Disable(ast::DisabledValidation::kBindingPointCollision));
+ return attributes;
}
- // Rewrite all function signatures to use combined samplers, and remove
- // separate textures & samplers. Create new combined globals where found.
- ctx.ReplaceAll([&](const ast::Function* src) -> const ast::Function* {
- if (auto* func = sem.Get(src)) {
- auto pairs = func->TextureSamplerPairs();
- if (pairs.empty()) {
- return nullptr;
- }
- ast::VariableList params;
- for (auto pair : func->TextureSamplerPairs()) {
- const sem::Variable* texture_var = pair.first;
- const sem::Variable* sampler_var = pair.second;
- std::string name =
- ctx.src->Symbols().NameFor(texture_var->Declaration()->symbol);
- if (sampler_var) {
- name += "_" + ctx.src->Symbols().NameFor(
- sampler_var->Declaration()->symbol);
- }
- if (IsGlobal(pair)) {
- // Both texture and sampler are global; add a new global variable
- // to represent the combined sampler (if not already created).
- utils::GetOrCreate(global_combined_texture_samplers_, pair, [&] {
- return CreateCombinedGlobal(texture_var, sampler_var, name);
- });
- } else {
- // Either texture or sampler (or both) is a function parameter;
- // add a new function parameter to represent the combined sampler.
- const ast::Type* type =
- CreateCombinedASTTypeFor(texture_var, sampler_var);
- const ast::Variable* var =
- ctx.dst->Param(ctx.dst->Symbols().New(name), type);
- params.push_back(var);
- function_combined_texture_samplers_[func][pair] = var;
- }
- }
- // Filter out separate textures and samplers from the original
- // function signature.
- for (auto* var : src->params) {
- if (!sem.Get(var->type)->IsAnyOf<sem::Texture, sem::Sampler>()) {
- params.push_back(ctx.Clone(var));
- }
- }
- // Create a new function signature that differs only in the parameter
- // list.
- auto symbol = ctx.Clone(src->symbol);
- auto* return_type = ctx.Clone(src->return_type);
- auto* body = ctx.Clone(src->body);
- auto attributes = ctx.Clone(src->attributes);
- auto return_type_attributes = ctx.Clone(src->return_type_attributes);
- return ctx.dst->create<ast::Function>(
- symbol, params, return_type, body, std::move(attributes),
- std::move(return_type_attributes));
- }
- return nullptr;
- });
+ /// Constructor
+ /// @param context the clone context
+ /// @param info the binding map information
+ State(CloneContext& context, const BindingInfo* info) : ctx(context), binding_info(info) {}
- // Replace all function call expressions containing texture or
- // sampler parameters to use the current function's combined samplers or
- // the combined global samplers, as appropriate.
- ctx.ReplaceAll([&](const ast::CallExpression* expr)
- -> const ast::Expression* {
- if (auto* call = sem.Get(expr)) {
- ast::ExpressionList args;
- // Replace all texture builtin calls.
- if (auto* builtin = call->Target()->As<sem::Builtin>()) {
- const auto& signature = builtin->Signature();
- int sampler_index = signature.IndexOf(sem::ParameterUsage::kSampler);
- int texture_index = signature.IndexOf(sem::ParameterUsage::kTexture);
- if (texture_index == -1) {
+ /// Creates a combined sampler global variables.
+ /// (Note this is actually a Texture node at the AST level, but it will be
+ /// written as the corresponding sampler (eg., sampler2D) on GLSL output.)
+ /// @param texture_var the texture (global) variable
+ /// @param sampler_var the sampler (global) variable
+ /// @param name the default name to use (may be overridden by map lookup)
+ /// @returns the newly-created global variable
+ const ast::Variable* CreateCombinedGlobal(const sem::Variable* texture_var,
+ const sem::Variable* sampler_var,
+ std::string name) {
+ SamplerTexturePair bp_pair;
+ bp_pair.texture_binding_point = texture_var->As<sem::GlobalVariable>()->BindingPoint();
+ bp_pair.sampler_binding_point = sampler_var
+ ? sampler_var->As<sem::GlobalVariable>()->BindingPoint()
+ : binding_info->placeholder_binding_point;
+ auto it = binding_info->binding_map.find(bp_pair);
+ if (it != binding_info->binding_map.end()) {
+ name = it->second;
+ }
+ const ast::Type* type = CreateCombinedASTTypeFor(texture_var, sampler_var);
+ Symbol symbol = ctx.dst->Symbols().New(name);
+ return ctx.dst->Global(symbol, type, Attributes());
+ }
+
+ /// Creates placeholder global sampler variables.
+ /// @param kind the sampler kind to create for
+ /// @returns the newly-created global variable
+ const ast::Variable* CreatePlaceholder(ast::SamplerKind kind) {
+ const ast::Type* type = ctx.dst->ty.sampler(kind);
+ const char* name = kind == ast::SamplerKind::kComparisonSampler
+ ? "placeholder_comparison_sampler"
+ : "placeholder_sampler";
+ Symbol symbol = ctx.dst->Symbols().New(name);
+ return ctx.dst->Global(symbol, type, Attributes());
+ }
+
+ /// Creates ast::Type for a given texture and sampler variable pair.
+ /// Depth textures with no samplers are turned into the corresponding
+ /// f32 texture (e.g., texture_depth_2d -> texture_2d<f32>).
+ /// @param texture the texture variable of interest
+ /// @param sampler the texture variable of interest
+ /// @returns the newly-created type
+ const ast::Type* CreateCombinedASTTypeFor(const sem::Variable* texture,
+ const sem::Variable* sampler) {
+ const sem::Type* texture_type = texture->Type()->UnwrapRef();
+ const sem::DepthTexture* depth = texture_type->As<sem::DepthTexture>();
+ if (depth && !sampler) {
+ return ctx.dst->create<ast::SampledTexture>(depth->dim(), ctx.dst->create<ast::F32>());
+ } else {
+ return CreateASTTypeFor(ctx, texture_type);
+ }
+ }
+
+ /// Performs the transformation
+ void Run() {
+ auto& sem = ctx.src->Sem();
+
+ // Remove all texture and sampler global variables. These will be replaced
+ // by combined samplers.
+ for (auto* var : ctx.src->AST().GlobalVariables()) {
+ auto* type = sem.Get(var->type);
+ if (type && type->IsAnyOf<sem::Texture, sem::Sampler>() &&
+ !type->Is<sem::StorageTexture>()) {
+ ctx.Remove(ctx.src->AST().GlobalDeclarations(), var);
+ } else if (auto binding_point = var->BindingPoint()) {
+ if (binding_point.group->value == 0 && binding_point.binding->value == 0) {
+ auto* attribute =
+ ctx.dst->Disable(ast::DisabledValidation::kBindingPointCollision);
+ ctx.InsertFront(var->attributes, attribute);
+ }
+ }
+ }
+
+ // Rewrite all function signatures to use combined samplers, and remove
+ // separate textures & samplers. Create new combined globals where found.
+ ctx.ReplaceAll([&](const ast::Function* src) -> const ast::Function* {
+ if (auto* func = sem.Get(src)) {
+ auto pairs = func->TextureSamplerPairs();
+ if (pairs.empty()) {
+ return nullptr;
+ }
+ ast::VariableList params;
+ for (auto pair : func->TextureSamplerPairs()) {
+ const sem::Variable* texture_var = pair.first;
+ const sem::Variable* sampler_var = pair.second;
+ std::string name =
+ ctx.src->Symbols().NameFor(texture_var->Declaration()->symbol);
+ if (sampler_var) {
+ name +=
+ "_" + ctx.src->Symbols().NameFor(sampler_var->Declaration()->symbol);
+ }
+ if (IsGlobal(pair)) {
+ // Both texture and sampler are global; add a new global variable
+ // to represent the combined sampler (if not already created).
+ utils::GetOrCreate(global_combined_texture_samplers_, pair, [&] {
+ return CreateCombinedGlobal(texture_var, sampler_var, name);
+ });
+ } else {
+ // Either texture or sampler (or both) is a function parameter;
+ // add a new function parameter to represent the combined sampler.
+ const ast::Type* type = CreateCombinedASTTypeFor(texture_var, sampler_var);
+ const ast::Variable* var =
+ ctx.dst->Param(ctx.dst->Symbols().New(name), type);
+ params.push_back(var);
+ function_combined_texture_samplers_[func][pair] = var;
+ }
+ }
+ // Filter out separate textures and samplers from the original
+ // function signature.
+ for (auto* var : src->params) {
+ if (!sem.Get(var->type)->IsAnyOf<sem::Texture, sem::Sampler>()) {
+ params.push_back(ctx.Clone(var));
+ }
+ }
+ // Create a new function signature that differs only in the parameter
+ // list.
+ auto symbol = ctx.Clone(src->symbol);
+ auto* return_type = ctx.Clone(src->return_type);
+ auto* body = ctx.Clone(src->body);
+ auto attributes = ctx.Clone(src->attributes);
+ auto return_type_attributes = ctx.Clone(src->return_type_attributes);
+ return ctx.dst->create<ast::Function>(symbol, params, return_type, body,
+ std::move(attributes),
+ std::move(return_type_attributes));
+ }
return nullptr;
- }
- const sem::Expression* texture = call->Arguments()[texture_index];
- // We don't want to combine storage textures with anything, since
- // they never have associated samplers in GLSL.
- if (texture->Type()->UnwrapRef()->Is<sem::StorageTexture>()) {
- return nullptr;
- }
- const sem::Expression* sampler =
- sampler_index != -1 ? call->Arguments()[sampler_index] : nullptr;
- auto* texture_var = texture->As<sem::VariableUser>()->Variable();
- auto* sampler_var =
- sampler ? sampler->As<sem::VariableUser>()->Variable() : nullptr;
- sem::VariablePair new_pair(texture_var, sampler_var);
- for (auto* arg : expr->args) {
- auto* type = ctx.src->TypeOf(arg)->UnwrapRef();
- if (type->Is<sem::Texture>()) {
- const ast::Variable* var =
- IsGlobal(new_pair)
- ? global_combined_texture_samplers_[new_pair]
- : function_combined_texture_samplers_
- [call->Stmt()->Function()][new_pair];
- args.push_back(ctx.dst->Expr(var->symbol));
- } else if (auto* sampler_type = type->As<sem::Sampler>()) {
- ast::SamplerKind kind = sampler_type->kind();
- int index = (kind == ast::SamplerKind::kSampler) ? 0 : 1;
- const ast::Variable*& p = placeholder_samplers_[index];
- if (!p) {
- p = CreatePlaceholder(kind);
- }
- args.push_back(ctx.dst->Expr(p->symbol));
- } else {
- args.push_back(ctx.Clone(arg));
- }
- }
- const ast::Expression* value =
- ctx.dst->Call(ctx.Clone(expr->target.name), args);
- if (builtin->Type() == sem::BuiltinType::kTextureLoad &&
- texture_var->Type()->UnwrapRef()->Is<sem::DepthTexture>() &&
- !call->Stmt()->Declaration()->Is<ast::CallStatement>()) {
- value = ctx.dst->MemberAccessor(value, "x");
- }
- return value;
- }
- // Replace all function calls.
- if (auto* callee = call->Target()->As<sem::Function>()) {
- for (auto pair : callee->TextureSamplerPairs()) {
- // Global pairs used by the callee do not require a function
- // parameter at the call site.
- if (IsGlobal(pair)) {
- continue;
- }
- const sem::Variable* texture_var = pair.first;
- const sem::Variable* sampler_var = pair.second;
- if (auto* param = texture_var->As<sem::Parameter>()) {
- const sem::Expression* texture =
- call->Arguments()[param->Index()];
- texture_var = texture->As<sem::VariableUser>()->Variable();
- }
- if (sampler_var) {
- if (auto* param = sampler_var->As<sem::Parameter>()) {
- const sem::Expression* sampler =
- call->Arguments()[param->Index()];
- sampler_var = sampler->As<sem::VariableUser>()->Variable();
- }
- }
- sem::VariablePair new_pair(texture_var, sampler_var);
- // If both texture and sampler are (now) global, pass that
- // global variable to the callee. Otherwise use the caller's
- // function parameter for this pair.
- const ast::Variable* var =
- IsGlobal(new_pair) ? global_combined_texture_samplers_[new_pair]
- : function_combined_texture_samplers_
- [call->Stmt()->Function()][new_pair];
- auto* arg = ctx.dst->Expr(var->symbol);
- args.push_back(arg);
- }
- // Append all of the remaining non-texture and non-sampler
- // parameters.
- for (auto* arg : expr->args) {
- if (!ctx.src->TypeOf(arg)
- ->UnwrapRef()
- ->IsAnyOf<sem::Texture, sem::Sampler>()) {
- args.push_back(ctx.Clone(arg));
- }
- }
- return ctx.dst->Call(ctx.Clone(expr->target.name), args);
- }
- }
- return nullptr;
- });
+ });
- ctx.Clone();
- }
+ // Replace all function call expressions containing texture or
+ // sampler parameters to use the current function's combined samplers or
+ // the combined global samplers, as appropriate.
+ ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::Expression* {
+ if (auto* call = sem.Get(expr)) {
+ ast::ExpressionList args;
+ // Replace all texture builtin calls.
+ if (auto* builtin = call->Target()->As<sem::Builtin>()) {
+ const auto& signature = builtin->Signature();
+ int sampler_index = signature.IndexOf(sem::ParameterUsage::kSampler);
+ int texture_index = signature.IndexOf(sem::ParameterUsage::kTexture);
+ if (texture_index == -1) {
+ return nullptr;
+ }
+ const sem::Expression* texture = call->Arguments()[texture_index];
+ // We don't want to combine storage textures with anything, since
+ // they never have associated samplers in GLSL.
+ if (texture->Type()->UnwrapRef()->Is<sem::StorageTexture>()) {
+ return nullptr;
+ }
+ const sem::Expression* sampler =
+ sampler_index != -1 ? call->Arguments()[sampler_index] : nullptr;
+ auto* texture_var = texture->As<sem::VariableUser>()->Variable();
+ auto* sampler_var =
+ sampler ? sampler->As<sem::VariableUser>()->Variable() : nullptr;
+ sem::VariablePair new_pair(texture_var, sampler_var);
+ for (auto* arg : expr->args) {
+ auto* type = ctx.src->TypeOf(arg)->UnwrapRef();
+ if (type->Is<sem::Texture>()) {
+ const ast::Variable* var =
+ IsGlobal(new_pair)
+ ? global_combined_texture_samplers_[new_pair]
+ : function_combined_texture_samplers_[call->Stmt()->Function()]
+ [new_pair];
+ args.push_back(ctx.dst->Expr(var->symbol));
+ } else if (auto* sampler_type = type->As<sem::Sampler>()) {
+ ast::SamplerKind kind = sampler_type->kind();
+ int index = (kind == ast::SamplerKind::kSampler) ? 0 : 1;
+ const ast::Variable*& p = placeholder_samplers_[index];
+ if (!p) {
+ p = CreatePlaceholder(kind);
+ }
+ args.push_back(ctx.dst->Expr(p->symbol));
+ } else {
+ args.push_back(ctx.Clone(arg));
+ }
+ }
+ const ast::Expression* value =
+ ctx.dst->Call(ctx.Clone(expr->target.name), args);
+ if (builtin->Type() == sem::BuiltinType::kTextureLoad &&
+ texture_var->Type()->UnwrapRef()->Is<sem::DepthTexture>() &&
+ !call->Stmt()->Declaration()->Is<ast::CallStatement>()) {
+ value = ctx.dst->MemberAccessor(value, "x");
+ }
+ return value;
+ }
+ // Replace all function calls.
+ if (auto* callee = call->Target()->As<sem::Function>()) {
+ for (auto pair : callee->TextureSamplerPairs()) {
+ // Global pairs used by the callee do not require a function
+ // parameter at the call site.
+ if (IsGlobal(pair)) {
+ continue;
+ }
+ const sem::Variable* texture_var = pair.first;
+ const sem::Variable* sampler_var = pair.second;
+ if (auto* param = texture_var->As<sem::Parameter>()) {
+ const sem::Expression* texture = call->Arguments()[param->Index()];
+ texture_var = texture->As<sem::VariableUser>()->Variable();
+ }
+ if (sampler_var) {
+ if (auto* param = sampler_var->As<sem::Parameter>()) {
+ const sem::Expression* sampler = call->Arguments()[param->Index()];
+ sampler_var = sampler->As<sem::VariableUser>()->Variable();
+ }
+ }
+ sem::VariablePair new_pair(texture_var, sampler_var);
+ // If both texture and sampler are (now) global, pass that
+ // global variable to the callee. Otherwise use the caller's
+ // function parameter for this pair.
+ const ast::Variable* var =
+ IsGlobal(new_pair)
+ ? global_combined_texture_samplers_[new_pair]
+ : function_combined_texture_samplers_[call->Stmt()->Function()]
+ [new_pair];
+ auto* arg = ctx.dst->Expr(var->symbol);
+ args.push_back(arg);
+ }
+ // Append all of the remaining non-texture and non-sampler
+ // parameters.
+ for (auto* arg : expr->args) {
+ if (!ctx.src->TypeOf(arg)
+ ->UnwrapRef()
+ ->IsAnyOf<sem::Texture, sem::Sampler>()) {
+ args.push_back(ctx.Clone(arg));
+ }
+ }
+ return ctx.dst->Call(ctx.Clone(expr->target.name), args);
+ }
+ }
+ return nullptr;
+ });
+
+ ctx.Clone();
+ }
};
CombineSamplers::CombineSamplers() = default;
CombineSamplers::~CombineSamplers() = default;
-void CombineSamplers::Run(CloneContext& ctx,
- const DataMap& inputs,
- DataMap&) const {
- auto* binding_info = inputs.Get<BindingInfo>();
- if (!binding_info) {
- ctx.dst->Diagnostics().add_error(
- diag::System::Transform,
- "missing transform data for " + std::string(TypeInfo().name));
- return;
- }
+void CombineSamplers::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const {
+ auto* binding_info = inputs.Get<BindingInfo>();
+ if (!binding_info) {
+ ctx.dst->Diagnostics().add_error(
+ diag::System::Transform, "missing transform data for " + std::string(TypeInfo().name));
+ return;
+ }
- State(ctx, binding_info).Run();
+ State(ctx, binding_info).Run();
}
} // namespace tint::transform
diff --git a/src/tint/transform/combine_samplers.h b/src/tint/transform/combine_samplers.h
index d15d2ab..8dfc098 100644
--- a/src/tint/transform/combine_samplers.h
+++ b/src/tint/transform/combine_samplers.h
@@ -53,54 +53,52 @@
/// (dimensionality, component type, etc). The GLSL writer outputs such
/// (Tint) Textures as (GLSL) Samplers.
class CombineSamplers final : public Castable<CombineSamplers, Transform> {
- public:
- /// A pair of binding points.
- using SamplerTexturePair = sem::SamplerTexturePair;
+ public:
+ /// A pair of binding points.
+ using SamplerTexturePair = sem::SamplerTexturePair;
- /// A map from a sampler/texture pair to a named global.
- using BindingMap = std::unordered_map<SamplerTexturePair, std::string>;
+ /// A map from a sampler/texture pair to a named global.
+ using BindingMap = std::unordered_map<SamplerTexturePair, std::string>;
- /// The client-provided mapping from separate texture and sampler binding
- /// points to combined sampler binding point.
- struct BindingInfo final : public Castable<Data, transform::Data> {
+ /// The client-provided mapping from separate texture and sampler binding
+ /// points to combined sampler binding point.
+ struct BindingInfo final : public Castable<Data, transform::Data> {
+ /// Constructor
+ /// @param map the map of all (texture, sampler) -> (combined) pairs
+ /// @param placeholder the binding point to use for placeholder samplers.
+ BindingInfo(const BindingMap& map, const sem::BindingPoint& placeholder);
+
+ /// Copy constructor
+ /// @param other the other BindingInfo to copy
+ BindingInfo(const BindingInfo& other);
+
+ /// Destructor
+ ~BindingInfo() override;
+
+ /// A map of bindings from (texture, sampler) -> combined sampler.
+ BindingMap binding_map;
+
+ /// The binding point to use for placeholder samplers.
+ sem::BindingPoint placeholder_binding_point;
+ };
+
/// Constructor
- /// @param map the map of all (texture, sampler) -> (combined) pairs
- /// @param placeholder the binding point to use for placeholder samplers.
- BindingInfo(const BindingMap& map, const sem::BindingPoint& placeholder);
-
- /// Copy constructor
- /// @param other the other BindingInfo to copy
- BindingInfo(const BindingInfo& other);
+ CombineSamplers();
/// Destructor
- ~BindingInfo() override;
+ ~CombineSamplers() override;
- /// A map of bindings from (texture, sampler) -> combined sampler.
- BindingMap binding_map;
+ protected:
+ /// The PIMPL state for this transform
+ struct State;
- /// The binding point to use for placeholder samplers.
- sem::BindingPoint placeholder_binding_point;
- };
-
- /// Constructor
- CombineSamplers();
-
- /// Destructor
- ~CombineSamplers() override;
-
- protected:
- /// The PIMPL state for this transform
- struct State;
-
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx,
- const DataMap& inputs,
- DataMap& outputs) const override;
+ /// Runs the transform using the CloneContext built for transforming a
+ /// program. Run() is responsible for calling Clone() on the CloneContext.
+ /// @param ctx the CloneContext primed with the input program and
+ /// ProgramBuilder
+ /// @param inputs optional extra transform-specific input data
+ /// @param outputs optional extra transform-specific output data
+ void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/combine_samplers_test.cc b/src/tint/transform/combine_samplers_test.cc
index cb60103..1d84859 100644
--- a/src/tint/transform/combine_samplers_test.cc
+++ b/src/tint/transform/combine_samplers_test.cc
@@ -25,19 +25,18 @@
using CombineSamplersTest = TransformTest;
TEST_F(CombineSamplersTest, EmptyModule) {
- auto* src = "";
- auto* expect = "";
+ auto* src = "";
+ auto* expect = "";
- DataMap data;
- data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(),
- sem::BindingPoint());
- auto got = Run<CombineSamplers>(src, data);
+ DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(), sem::BindingPoint());
+ auto got = Run<CombineSamplers>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CombineSamplersTest, SimplePair) {
- auto* src = R"(
+ auto* src = R"(
@group(0) @binding(0) var t : texture_2d<f32>;
@group(0) @binding(1) var s : sampler;
@@ -46,7 +45,7 @@
return textureSample(t, s, vec2<f32>(1.0, 2.0));
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var t_s : texture_2d<f32>;
@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var placeholder_sampler : sampler;
@@ -56,16 +55,15 @@
}
)";
- DataMap data;
- data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(),
- sem::BindingPoint());
- auto got = Run<CombineSamplers>(src, data);
+ DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(), sem::BindingPoint());
+ auto got = Run<CombineSamplers>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CombineSamplersTest, SimplePair_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn main() -> vec4<f32> {
return textureSample(t, s, vec2<f32>(1.0, 2.0));
}
@@ -74,7 +72,7 @@
@group(0) @binding(1) var s : sampler;
)";
- auto* expect = R"(
+ auto* expect = R"(
@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var t_s : texture_2d<f32>;
@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var placeholder_sampler : sampler;
@@ -84,16 +82,15 @@
}
)";
- DataMap data;
- data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(),
- sem::BindingPoint());
- auto got = Run<CombineSamplers>(src, data);
+ DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(), sem::BindingPoint());
+ auto got = Run<CombineSamplers>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CombineSamplersTest, SimplePairInAFunction) {
- auto* src = R"(
+ auto* src = R"(
@group(0) @binding(0) var t : texture_2d<f32>;
@group(0) @binding(1) var s : sampler;
@@ -106,7 +103,7 @@
return sample(t, s, vec2<f32>(1.0, 2.0));
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var placeholder_sampler : sampler;
fn sample(t_s : texture_2d<f32>, coords : vec2<f32>) -> vec4<f32> {
@@ -120,16 +117,15 @@
}
)";
- DataMap data;
- data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(),
- sem::BindingPoint());
- auto got = Run<CombineSamplers>(src, data);
+ DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(), sem::BindingPoint());
+ auto got = Run<CombineSamplers>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CombineSamplersTest, SimplePairInAFunction_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn main() -> vec4<f32> {
return sample(t, s, vec2<f32>(1.0, 2.0));
}
@@ -142,7 +138,7 @@
@group(0) @binding(0) var t : texture_2d<f32>;
)";
- auto* expect = R"(
+ auto* expect = R"(
@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var t_s : texture_2d<f32>;
fn main() -> vec4<f32> {
@@ -156,16 +152,15 @@
}
)";
- DataMap data;
- data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(),
- sem::BindingPoint());
- auto got = Run<CombineSamplers>(src, data);
+ DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(), sem::BindingPoint());
+ auto got = Run<CombineSamplers>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CombineSamplersTest, SimplePairRename) {
- auto* src = R"(
+ auto* src = R"(
@group(0) @binding(1) var t : texture_2d<f32>;
@group(2) @binding(3) var s : sampler;
@@ -174,7 +169,7 @@
return textureSample(t, s, vec2<f32>(1.0, 2.0));
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var fuzzy : texture_2d<f32>;
@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var placeholder_sampler : sampler;
@@ -184,23 +179,23 @@
}
)";
- DataMap data;
- CombineSamplers::BindingMap map;
- sem::SamplerTexturePair pair;
- pair.texture_binding_point.group = 0;
- pair.texture_binding_point.binding = 1;
- pair.sampler_binding_point.group = 2;
- pair.sampler_binding_point.binding = 3;
- map[pair] = "fuzzy";
- sem::BindingPoint placeholder{1024, 0};
- data.Add<CombineSamplers::BindingInfo>(map, placeholder);
- auto got = Run<CombineSamplers>(src, data);
+ DataMap data;
+ CombineSamplers::BindingMap map;
+ sem::SamplerTexturePair pair;
+ pair.texture_binding_point.group = 0;
+ pair.texture_binding_point.binding = 1;
+ pair.sampler_binding_point.group = 2;
+ pair.sampler_binding_point.binding = 3;
+ map[pair] = "fuzzy";
+ sem::BindingPoint placeholder{1024, 0};
+ data.Add<CombineSamplers::BindingInfo>(map, placeholder);
+ auto got = Run<CombineSamplers>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CombineSamplersTest, SimplePairRenameMiss) {
- auto* src = R"(
+ auto* src = R"(
@group(0) @binding(1) var t : texture_2d<f32>;
@group(2) @binding(3) var s : sampler;
@@ -209,7 +204,7 @@
return textureSample(t, s, vec2<f32>(1.0, 2.0));
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var t_s : texture_2d<f32>;
@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var placeholder_sampler : sampler;
@@ -219,23 +214,23 @@
}
)";
- DataMap data;
- CombineSamplers::BindingMap map;
- sem::SamplerTexturePair pair;
- pair.texture_binding_point.group = 3;
- pair.texture_binding_point.binding = 2;
- pair.sampler_binding_point.group = 1;
- pair.sampler_binding_point.binding = 0;
- map[pair] = "fuzzy";
- sem::BindingPoint placeholder{1024, 0};
- data.Add<CombineSamplers::BindingInfo>(map, placeholder);
- auto got = Run<CombineSamplers>(src, data);
+ DataMap data;
+ CombineSamplers::BindingMap map;
+ sem::SamplerTexturePair pair;
+ pair.texture_binding_point.group = 3;
+ pair.texture_binding_point.binding = 2;
+ pair.sampler_binding_point.group = 1;
+ pair.sampler_binding_point.binding = 0;
+ map[pair] = "fuzzy";
+ sem::BindingPoint placeholder{1024, 0};
+ data.Add<CombineSamplers::BindingInfo>(map, placeholder);
+ auto got = Run<CombineSamplers>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CombineSamplersTest, AliasedTypes) {
- auto* src = R"(
+ auto* src = R"(
type Tex2d = texture_2d<f32>;
@@ -251,7 +246,7 @@
return sample(t, s, vec2<f32>(1.0, 2.0));
}
)";
- auto* expect = R"(
+ auto* expect = R"(
type Tex2d = texture_2d<f32>;
@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var placeholder_sampler : sampler;
@@ -267,16 +262,15 @@
}
)";
- DataMap data;
- data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(),
- sem::BindingPoint());
- auto got = Run<CombineSamplers>(src, data);
+ DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(), sem::BindingPoint());
+ auto got = Run<CombineSamplers>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CombineSamplersTest, AliasedTypes_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn main() -> vec4<f32> {
return sample(t, s, vec2<f32>(1.0, 2.0));
}
@@ -290,7 +284,7 @@
type Tex2d = texture_2d<f32>;
)";
- auto* expect = R"(
+ auto* expect = R"(
@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var t_s : texture_2d<f32>;
fn main() -> vec4<f32> {
@@ -306,16 +300,15 @@
type Tex2d = texture_2d<f32>;
)";
- DataMap data;
- data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(),
- sem::BindingPoint());
- auto got = Run<CombineSamplers>(src, data);
+ DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(), sem::BindingPoint());
+ auto got = Run<CombineSamplers>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CombineSamplersTest, SimplePairInTwoFunctions) {
- auto* src = R"(
+ auto* src = R"(
@group(0) @binding(0) var t : texture_2d<f32>;
@group(0) @binding(1) var s : sampler;
@@ -332,7 +325,7 @@
return f(t, s, vec2<f32>(1.0, 2.0));
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var placeholder_sampler : sampler;
fn g(t_s : texture_2d<f32>, coords : vec2<f32>) -> vec4<f32> {
@@ -350,16 +343,15 @@
}
)";
- DataMap data;
- data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(),
- sem::BindingPoint());
- auto got = Run<CombineSamplers>(src, data);
+ DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(), sem::BindingPoint());
+ auto got = Run<CombineSamplers>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CombineSamplersTest, SimplePairInTwoFunctions_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn main() -> vec4<f32> {
return f(t, s, vec2<f32>(1.0, 2.0));
}
@@ -375,7 +367,7 @@
@group(0) @binding(1) var s : sampler;
@group(0) @binding(0) var t : texture_2d<f32>;
)";
- auto* expect = R"(
+ auto* expect = R"(
@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var t_s : texture_2d<f32>;
fn main() -> vec4<f32> {
@@ -393,16 +385,15 @@
}
)";
- DataMap data;
- data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(),
- sem::BindingPoint());
- auto got = Run<CombineSamplers>(src, data);
+ DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(), sem::BindingPoint());
+ auto got = Run<CombineSamplers>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CombineSamplersTest, TwoFunctionsGenerateSamePair) {
- auto* src = R"(
+ auto* src = R"(
@group(1) @binding(0) var tex : texture_2d<f32>;
@group(1) @binding(1) var samp : sampler;
@@ -419,7 +410,7 @@
return f() + g();
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var tex_samp : texture_2d<f32>;
@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var placeholder_sampler : sampler;
@@ -437,16 +428,15 @@
}
)";
- DataMap data;
- data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(),
- sem::BindingPoint());
- auto got = Run<CombineSamplers>(src, data);
+ DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(), sem::BindingPoint());
+ auto got = Run<CombineSamplers>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CombineSamplersTest, ThreeTexturesThreeSamplers) {
- auto* src = R"(
+ auto* src = R"(
@group(0) @binding(0) var tex1 : texture_2d<f32>;
@group(0) @binding(1) var tex2 : texture_2d<f32>;
@group(0) @binding(2) var tex3 : texture_2d<f32>;
@@ -471,7 +461,7 @@
+ sample(tex3, samp3);
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var placeholder_sampler : sampler;
fn sample(t_s : texture_2d<f32>) -> vec4<f32> {
@@ -501,16 +491,15 @@
}
)";
- DataMap data;
- data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(),
- sem::BindingPoint());
- auto got = Run<CombineSamplers>(src, data);
+ DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(), sem::BindingPoint());
+ auto got = Run<CombineSamplers>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CombineSamplersTest, TwoFunctionsTwoTexturesDiamond) {
- auto* src = R"(
+ auto* src = R"(
@group(0) @binding(0) var tex1 : texture_2d<f32>;
@group(0) @binding(1) var tex2 : texture_2d<f32>;
@@ -529,7 +518,7 @@
return f(tex1, tex2, samp, vec2<f32>(1.0, 2.0));
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var placeholder_sampler : sampler;
fn sample(t_s : texture_2d<f32>, coords : vec2<f32>) -> vec4<f32> {
@@ -549,16 +538,15 @@
}
)";
- DataMap data;
- data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(),
- sem::BindingPoint());
- auto got = Run<CombineSamplers>(src, data);
+ DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(), sem::BindingPoint());
+ auto got = Run<CombineSamplers>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CombineSamplersTest, TwoFunctionsTwoSamplersDiamond) {
- auto* src = R"(
+ auto* src = R"(
@group(0) @binding(0) var tex : texture_2d<f32>;
@group(0) @binding(1) var samp1 : sampler;
@@ -577,7 +565,7 @@
return f(tex, samp1, samp2, vec2<f32>(1.0, 2.0));
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var placeholder_sampler : sampler;
fn sample(t_s : texture_2d<f32>, coords : vec2<f32>) -> vec4<f32> {
@@ -597,16 +585,15 @@
}
)";
- DataMap data;
- data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(),
- sem::BindingPoint());
- auto got = Run<CombineSamplers>(src, data);
+ DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(), sem::BindingPoint());
+ auto got = Run<CombineSamplers>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CombineSamplersTest, GlobalTextureLocalSampler) {
- auto* src = R"(
+ auto* src = R"(
@group(0) @binding(0) var tex : texture_2d<f32>;
@group(0) @binding(1) var samp1 : sampler;
@@ -621,7 +608,7 @@
return f(samp1, samp2, vec2<f32>(1.0, 2.0));
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var placeholder_sampler : sampler;
fn f(tex_s1 : texture_2d<f32>, tex_s2 : texture_2d<f32>, coords : vec2<f32>) -> vec4<f32> {
@@ -637,16 +624,15 @@
}
)";
- DataMap data;
- data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(),
- sem::BindingPoint());
- auto got = Run<CombineSamplers>(src, data);
+ DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(), sem::BindingPoint());
+ auto got = Run<CombineSamplers>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CombineSamplersTest, GlobalTextureLocalSampler_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn main() -> vec4<f32> {
return f(samp1, samp2, vec2<f32>(1.0, 2.0));
}
@@ -659,7 +645,7 @@
@group(0) @binding(2) var samp2 : sampler;
@group(0) @binding(0) var tex : texture_2d<f32>;
)";
- auto* expect = R"(
+ auto* expect = R"(
@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var tex_samp1 : texture_2d<f32>;
@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var tex_samp2 : texture_2d<f32>;
@@ -675,16 +661,15 @@
}
)";
- DataMap data;
- data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(),
- sem::BindingPoint());
- auto got = Run<CombineSamplers>(src, data);
+ DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(), sem::BindingPoint());
+ auto got = Run<CombineSamplers>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CombineSamplersTest, LocalTextureGlobalSampler) {
- auto* src = R"(
+ auto* src = R"(
@group(0) @binding(0) var tex1 : texture_2d<f32>;
@group(0) @binding(1) var tex2 : texture_2d<f32>;
@@ -699,7 +684,7 @@
return f(tex1, tex2, vec2<f32>(1.0, 2.0));
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var placeholder_sampler : sampler;
fn f(t1_samp : texture_2d<f32>, t2_samp : texture_2d<f32>, coords : vec2<f32>) -> vec4<f32> {
@@ -715,16 +700,15 @@
}
)";
- DataMap data;
- data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(),
- sem::BindingPoint());
- auto got = Run<CombineSamplers>(src, data);
+ DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(), sem::BindingPoint());
+ auto got = Run<CombineSamplers>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CombineSamplersTest, LocalTextureGlobalSampler_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn main() -> vec4<f32> {
return f(tex1, tex2, vec2<f32>(1.0, 2.0));
}
@@ -737,7 +721,7 @@
@group(0) @binding(0) var tex1 : texture_2d<f32>;
@group(0) @binding(1) var tex2 : texture_2d<f32>;
)";
- auto* expect = R"(
+ auto* expect = R"(
@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var tex1_samp : texture_2d<f32>;
@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var tex2_samp : texture_2d<f32>;
@@ -753,16 +737,15 @@
}
)";
- DataMap data;
- data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(),
- sem::BindingPoint());
- auto got = Run<CombineSamplers>(src, data);
+ DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(), sem::BindingPoint());
+ auto got = Run<CombineSamplers>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CombineSamplersTest, TextureLoadNoSampler) {
- auto* src = R"(
+ auto* src = R"(
@group(0) @binding(0) var tex : texture_2d<f32>;
fn f(t : texture_2d<f32>, coords : vec2<i32>) -> vec4<f32> {
@@ -773,7 +756,7 @@
return f(tex, vec2<i32>(1, 2));
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f(t_1 : texture_2d<f32>, coords : vec2<i32>) -> vec4<f32> {
return textureLoad(t_1, coords, 0);
}
@@ -785,23 +768,23 @@
}
)";
- sem::BindingPoint placeholder{1024, 0};
- sem::SamplerTexturePair pair;
- pair.texture_binding_point.group = 0;
- pair.texture_binding_point.binding = 0;
- pair.sampler_binding_point.group = placeholder.group;
- pair.sampler_binding_point.binding = placeholder.binding;
- CombineSamplers::BindingMap map;
- map[pair] = "fred";
- DataMap data;
- data.Add<CombineSamplers::BindingInfo>(map, placeholder);
- auto got = Run<CombineSamplers>(src, data);
+ sem::BindingPoint placeholder{1024, 0};
+ sem::SamplerTexturePair pair;
+ pair.texture_binding_point.group = 0;
+ pair.texture_binding_point.binding = 0;
+ pair.sampler_binding_point.group = placeholder.group;
+ pair.sampler_binding_point.binding = placeholder.binding;
+ CombineSamplers::BindingMap map;
+ map[pair] = "fred";
+ DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(map, placeholder);
+ auto got = Run<CombineSamplers>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CombineSamplersTest, TextureWithAndWithoutSampler) {
- auto* src = R"(
+ auto* src = R"(
@group(0) @binding(0) var tex : texture_2d<f32>;
@group(0) @binding(1) var samp : sampler;
@@ -810,7 +793,7 @@
textureSample(tex, samp, vec2<f32>());
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var fred : texture_2d<f32>;
@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var barney : texture_2d<f32>;
@@ -822,30 +805,30 @@
}
)";
- sem::BindingPoint placeholder{1024, 0};
- sem::BindingPoint tex{0, 0};
- sem::BindingPoint samp{0, 1};
- sem::SamplerTexturePair pair, placeholder_pair;
- pair.texture_binding_point.group = tex.group;
- pair.texture_binding_point.binding = tex.binding;
- pair.sampler_binding_point.group = samp.group;
- pair.sampler_binding_point.binding = samp.binding;
- placeholder_pair.texture_binding_point.group = tex.group;
- placeholder_pair.texture_binding_point.binding = tex.binding;
- placeholder_pair.sampler_binding_point.group = placeholder.group;
- placeholder_pair.sampler_binding_point.binding = placeholder.binding;
- CombineSamplers::BindingMap map;
- map[pair] = "barney";
- map[placeholder_pair] = "fred";
- DataMap data;
- data.Add<CombineSamplers::BindingInfo>(map, placeholder);
- auto got = Run<CombineSamplers>(src, data);
+ sem::BindingPoint placeholder{1024, 0};
+ sem::BindingPoint tex{0, 0};
+ sem::BindingPoint samp{0, 1};
+ sem::SamplerTexturePair pair, placeholder_pair;
+ pair.texture_binding_point.group = tex.group;
+ pair.texture_binding_point.binding = tex.binding;
+ pair.sampler_binding_point.group = samp.group;
+ pair.sampler_binding_point.binding = samp.binding;
+ placeholder_pair.texture_binding_point.group = tex.group;
+ placeholder_pair.texture_binding_point.binding = tex.binding;
+ placeholder_pair.sampler_binding_point.group = placeholder.group;
+ placeholder_pair.sampler_binding_point.binding = placeholder.binding;
+ CombineSamplers::BindingMap map;
+ map[pair] = "barney";
+ map[placeholder_pair] = "fred";
+ DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(map, placeholder);
+ auto got = Run<CombineSamplers>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CombineSamplersTest, TextureSampleCompare) {
- auto* src = R"(
+ auto* src = R"(
@group(0) @binding(0) var tex : texture_depth_2d;
@group(0) @binding(1) var samp : sampler_comparison;
@@ -854,7 +837,7 @@
return vec4<f32>(textureSampleCompare(tex, samp, vec2<f32>(1.0, 2.0), 0.5));
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var tex_samp : texture_depth_2d;
@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var placeholder_comparison_sampler : sampler_comparison;
@@ -864,16 +847,15 @@
}
)";
- DataMap data;
- data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(),
- sem::BindingPoint());
- auto got = Run<CombineSamplers>(src, data);
+ DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(), sem::BindingPoint());
+ auto got = Run<CombineSamplers>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CombineSamplersTest, TextureSampleCompareInAFunction) {
- auto* src = R"(
+ auto* src = R"(
@group(0) @binding(0) var tex : texture_depth_2d;
@group(0) @binding(1) var samp : sampler_comparison;
@@ -886,7 +868,7 @@
return vec4<f32>(f(tex, samp, vec2<f32>(1.0, 2.0)));
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var placeholder_comparison_sampler : sampler_comparison;
fn f(t_s : texture_depth_2d, coords : vec2<f32>) -> f32 {
@@ -900,16 +882,15 @@
}
)";
- DataMap data;
- data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(),
- sem::BindingPoint());
- auto got = Run<CombineSamplers>(src, data);
+ DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(), sem::BindingPoint());
+ auto got = Run<CombineSamplers>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CombineSamplersTest, TextureSampleCompareInAFunction_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn main() -> vec4<f32> {
return vec4<f32>(f(tex, samp, vec2<f32>(1.0, 2.0)));
}
@@ -921,7 +902,7 @@
@group(0) @binding(0) var tex : texture_depth_2d;
@group(0) @binding(1) var samp : sampler_comparison;
)";
- auto* expect = R"(
+ auto* expect = R"(
@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var tex_samp : texture_depth_2d;
fn main() -> vec4<f32> {
@@ -935,16 +916,15 @@
}
)";
- DataMap data;
- data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(),
- sem::BindingPoint());
- auto got = Run<CombineSamplers>(src, data);
+ DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(), sem::BindingPoint());
+ auto got = Run<CombineSamplers>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CombineSamplersTest, BindingPointCollision) {
- auto* src = R"(
+ auto* src = R"(
@group(1) @binding(0) var tex : texture_2d<f32>;
@group(1) @binding(1) var samp : sampler;
@@ -955,7 +935,7 @@
return textureSample(tex, samp, gcoords);
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@internal(disable_validation__binding_point_collision) @group(0) @binding(0) var<uniform> gcoords : vec2<f32>;
@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var tex_samp : texture_2d<f32>;
@@ -967,16 +947,15 @@
}
)";
- DataMap data;
- data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(),
- sem::BindingPoint());
- auto got = Run<CombineSamplers>(src, data);
+ DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(), sem::BindingPoint());
+ auto got = Run<CombineSamplers>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(CombineSamplersTest, BindingPointCollision_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn main() -> vec4<f32> {
return textureSample(tex, samp, gcoords);
}
@@ -986,7 +965,7 @@
@group(1) @binding(0) var tex : texture_2d<f32>;
)";
- auto* expect = R"(
+ auto* expect = R"(
@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var tex_samp : texture_2d<f32>;
@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var placeholder_sampler : sampler;
@@ -998,12 +977,11 @@
@internal(disable_validation__binding_point_collision) @group(0) @binding(0) var<uniform> gcoords : vec2<f32>;
)";
- DataMap data;
- data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(),
- sem::BindingPoint());
- auto got = Run<CombineSamplers>(src, data);
+ DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(), sem::BindingPoint());
+ auto got = Run<CombineSamplers>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
} // namespace
diff --git a/src/tint/transform/decompose_memory_access.cc b/src/tint/transform/decompose_memory_access.cc
index a357c93..4b21e73 100644
--- a/src/tint/transform/decompose_memory_access.cc
+++ b/src/tint/transform/decompose_memory_access.cc
@@ -48,176 +48,169 @@
/// Offset is a simple ast::Expression builder interface, used to build byte
/// offsets for storage and uniform buffer accesses.
struct Offset : Castable<Offset> {
- /// @returns builds and returns the ast::Expression in `ctx.dst`
- virtual const ast::Expression* Build(CloneContext& ctx) const = 0;
+ /// @returns builds and returns the ast::Expression in `ctx.dst`
+ virtual const ast::Expression* Build(CloneContext& ctx) const = 0;
};
/// OffsetExpr is an implementation of Offset that clones and casts the given
/// expression to `u32`.
struct OffsetExpr : Offset {
- const ast::Expression* const expr = nullptr;
+ const ast::Expression* const expr = nullptr;
- explicit OffsetExpr(const ast::Expression* e) : expr(e) {}
+ explicit OffsetExpr(const ast::Expression* e) : expr(e) {}
- const ast::Expression* Build(CloneContext& ctx) const override {
- auto* type = ctx.src->Sem().Get(expr)->Type()->UnwrapRef();
- auto* res = ctx.Clone(expr);
- if (!type->Is<sem::U32>()) {
- res = ctx.dst->Construct<ProgramBuilder::u32>(res);
+ const ast::Expression* Build(CloneContext& ctx) const override {
+ auto* type = ctx.src->Sem().Get(expr)->Type()->UnwrapRef();
+ auto* res = ctx.Clone(expr);
+ if (!type->Is<sem::U32>()) {
+ res = ctx.dst->Construct<ProgramBuilder::u32>(res);
+ }
+ return res;
}
- return res;
- }
};
/// OffsetLiteral is an implementation of Offset that constructs a u32 literal
/// value.
struct OffsetLiteral : Castable<OffsetLiteral, Offset> {
- uint32_t const literal = 0;
+ uint32_t const literal = 0;
- explicit OffsetLiteral(uint32_t lit) : literal(lit) {}
+ explicit OffsetLiteral(uint32_t lit) : literal(lit) {}
- const ast::Expression* Build(CloneContext& ctx) const override {
- return ctx.dst->Expr(literal);
- }
+ const ast::Expression* Build(CloneContext& ctx) const override {
+ return ctx.dst->Expr(literal);
+ }
};
/// OffsetBinOp is an implementation of Offset that constructs a binary-op of
/// two Offsets.
struct OffsetBinOp : Offset {
- ast::BinaryOp op;
- Offset const* lhs = nullptr;
- Offset const* rhs = nullptr;
+ ast::BinaryOp op;
+ Offset const* lhs = nullptr;
+ Offset const* rhs = nullptr;
- const ast::Expression* Build(CloneContext& ctx) const override {
- return ctx.dst->create<ast::BinaryExpression>(op, lhs->Build(ctx),
- rhs->Build(ctx));
- }
+ const ast::Expression* Build(CloneContext& ctx) const override {
+ return ctx.dst->create<ast::BinaryExpression>(op, lhs->Build(ctx), rhs->Build(ctx));
+ }
};
/// LoadStoreKey is the unordered map key to a load or store intrinsic.
struct LoadStoreKey {
- ast::StorageClass const storage_class; // buffer storage class
- sem::Type const* buf_ty = nullptr; // buffer type
- sem::Type const* el_ty = nullptr; // element type
- bool operator==(const LoadStoreKey& rhs) const {
- return storage_class == rhs.storage_class && buf_ty == rhs.buf_ty &&
- el_ty == rhs.el_ty;
- }
- struct Hasher {
- inline std::size_t operator()(const LoadStoreKey& u) const {
- return utils::Hash(u.storage_class, u.buf_ty, u.el_ty);
+ ast::StorageClass const storage_class; // buffer storage class
+ sem::Type const* buf_ty = nullptr; // buffer type
+ sem::Type const* el_ty = nullptr; // element type
+ bool operator==(const LoadStoreKey& rhs) const {
+ return storage_class == rhs.storage_class && buf_ty == rhs.buf_ty && el_ty == rhs.el_ty;
}
- };
+ struct Hasher {
+ inline std::size_t operator()(const LoadStoreKey& u) const {
+ return utils::Hash(u.storage_class, u.buf_ty, u.el_ty);
+ }
+ };
};
/// AtomicKey is the unordered map key to an atomic intrinsic.
struct AtomicKey {
- sem::Type const* buf_ty = nullptr; // buffer type
- sem::Type const* el_ty = nullptr; // element type
- sem::BuiltinType const op; // atomic op
- bool operator==(const AtomicKey& rhs) const {
- return buf_ty == rhs.buf_ty && el_ty == rhs.el_ty && op == rhs.op;
- }
- struct Hasher {
- inline std::size_t operator()(const AtomicKey& u) const {
- return utils::Hash(u.buf_ty, u.el_ty, u.op);
+ sem::Type const* buf_ty = nullptr; // buffer type
+ sem::Type const* el_ty = nullptr; // element type
+ sem::BuiltinType const op; // atomic op
+ bool operator==(const AtomicKey& rhs) const {
+ return buf_ty == rhs.buf_ty && el_ty == rhs.el_ty && op == rhs.op;
}
- };
+ struct Hasher {
+ inline std::size_t operator()(const AtomicKey& u) const {
+ return utils::Hash(u.buf_ty, u.el_ty, u.op);
+ }
+ };
};
-bool IntrinsicDataTypeFor(const sem::Type* ty,
- DecomposeMemoryAccess::Intrinsic::DataType& out) {
- if (ty->Is<sem::I32>()) {
- out = DecomposeMemoryAccess::Intrinsic::DataType::kI32;
- return true;
- }
- if (ty->Is<sem::U32>()) {
- out = DecomposeMemoryAccess::Intrinsic::DataType::kU32;
- return true;
- }
- if (ty->Is<sem::F32>()) {
- out = DecomposeMemoryAccess::Intrinsic::DataType::kF32;
- return true;
- }
- if (auto* vec = ty->As<sem::Vector>()) {
- switch (vec->Width()) {
- case 2:
- if (vec->type()->Is<sem::I32>()) {
- out = DecomposeMemoryAccess::Intrinsic::DataType::kVec2I32;
- return true;
- }
- if (vec->type()->Is<sem::U32>()) {
- out = DecomposeMemoryAccess::Intrinsic::DataType::kVec2U32;
- return true;
- }
- if (vec->type()->Is<sem::F32>()) {
- out = DecomposeMemoryAccess::Intrinsic::DataType::kVec2F32;
- return true;
- }
- break;
- case 3:
- if (vec->type()->Is<sem::I32>()) {
- out = DecomposeMemoryAccess::Intrinsic::DataType::kVec3I32;
- return true;
- }
- if (vec->type()->Is<sem::U32>()) {
- out = DecomposeMemoryAccess::Intrinsic::DataType::kVec3U32;
- return true;
- }
- if (vec->type()->Is<sem::F32>()) {
- out = DecomposeMemoryAccess::Intrinsic::DataType::kVec3F32;
- return true;
- }
- break;
- case 4:
- if (vec->type()->Is<sem::I32>()) {
- out = DecomposeMemoryAccess::Intrinsic::DataType::kVec4I32;
- return true;
- }
- if (vec->type()->Is<sem::U32>()) {
- out = DecomposeMemoryAccess::Intrinsic::DataType::kVec4U32;
- return true;
- }
- if (vec->type()->Is<sem::F32>()) {
- out = DecomposeMemoryAccess::Intrinsic::DataType::kVec4F32;
- return true;
- }
- break;
+bool IntrinsicDataTypeFor(const sem::Type* ty, DecomposeMemoryAccess::Intrinsic::DataType& out) {
+ if (ty->Is<sem::I32>()) {
+ out = DecomposeMemoryAccess::Intrinsic::DataType::kI32;
+ return true;
}
- return false;
- }
+ if (ty->Is<sem::U32>()) {
+ out = DecomposeMemoryAccess::Intrinsic::DataType::kU32;
+ return true;
+ }
+ if (ty->Is<sem::F32>()) {
+ out = DecomposeMemoryAccess::Intrinsic::DataType::kF32;
+ return true;
+ }
+ if (auto* vec = ty->As<sem::Vector>()) {
+ switch (vec->Width()) {
+ case 2:
+ if (vec->type()->Is<sem::I32>()) {
+ out = DecomposeMemoryAccess::Intrinsic::DataType::kVec2I32;
+ return true;
+ }
+ if (vec->type()->Is<sem::U32>()) {
+ out = DecomposeMemoryAccess::Intrinsic::DataType::kVec2U32;
+ return true;
+ }
+ if (vec->type()->Is<sem::F32>()) {
+ out = DecomposeMemoryAccess::Intrinsic::DataType::kVec2F32;
+ return true;
+ }
+ break;
+ case 3:
+ if (vec->type()->Is<sem::I32>()) {
+ out = DecomposeMemoryAccess::Intrinsic::DataType::kVec3I32;
+ return true;
+ }
+ if (vec->type()->Is<sem::U32>()) {
+ out = DecomposeMemoryAccess::Intrinsic::DataType::kVec3U32;
+ return true;
+ }
+ if (vec->type()->Is<sem::F32>()) {
+ out = DecomposeMemoryAccess::Intrinsic::DataType::kVec3F32;
+ return true;
+ }
+ break;
+ case 4:
+ if (vec->type()->Is<sem::I32>()) {
+ out = DecomposeMemoryAccess::Intrinsic::DataType::kVec4I32;
+ return true;
+ }
+ if (vec->type()->Is<sem::U32>()) {
+ out = DecomposeMemoryAccess::Intrinsic::DataType::kVec4U32;
+ return true;
+ }
+ if (vec->type()->Is<sem::F32>()) {
+ out = DecomposeMemoryAccess::Intrinsic::DataType::kVec4F32;
+ return true;
+ }
+ break;
+ }
+ return false;
+ }
- return false;
+ return false;
}
/// @returns a DecomposeMemoryAccess::Intrinsic attribute that can be applied
/// to a stub function to load the type `ty`.
-DecomposeMemoryAccess::Intrinsic* IntrinsicLoadFor(
- ProgramBuilder* builder,
- ast::StorageClass storage_class,
- const sem::Type* ty) {
- DecomposeMemoryAccess::Intrinsic::DataType type;
- if (!IntrinsicDataTypeFor(ty, type)) {
- return nullptr;
- }
- return builder->ASTNodes().Create<DecomposeMemoryAccess::Intrinsic>(
- builder->ID(), DecomposeMemoryAccess::Intrinsic::Op::kLoad, storage_class,
- type);
+DecomposeMemoryAccess::Intrinsic* IntrinsicLoadFor(ProgramBuilder* builder,
+ ast::StorageClass storage_class,
+ const sem::Type* ty) {
+ DecomposeMemoryAccess::Intrinsic::DataType type;
+ if (!IntrinsicDataTypeFor(ty, type)) {
+ return nullptr;
+ }
+ return builder->ASTNodes().Create<DecomposeMemoryAccess::Intrinsic>(
+ builder->ID(), DecomposeMemoryAccess::Intrinsic::Op::kLoad, storage_class, type);
}
/// @returns a DecomposeMemoryAccess::Intrinsic attribute that can be applied
/// to a stub function to store the type `ty`.
-DecomposeMemoryAccess::Intrinsic* IntrinsicStoreFor(
- ProgramBuilder* builder,
- ast::StorageClass storage_class,
- const sem::Type* ty) {
- DecomposeMemoryAccess::Intrinsic::DataType type;
- if (!IntrinsicDataTypeFor(ty, type)) {
- return nullptr;
- }
- return builder->ASTNodes().Create<DecomposeMemoryAccess::Intrinsic>(
- builder->ID(), DecomposeMemoryAccess::Intrinsic::Op::kStore,
- storage_class, type);
+DecomposeMemoryAccess::Intrinsic* IntrinsicStoreFor(ProgramBuilder* builder,
+ ast::StorageClass storage_class,
+ const sem::Type* ty) {
+ DecomposeMemoryAccess::Intrinsic::DataType type;
+ if (!IntrinsicDataTypeFor(ty, type)) {
+ return nullptr;
+ }
+ return builder->ASTNodes().Create<DecomposeMemoryAccess::Intrinsic>(
+ builder->ID(), DecomposeMemoryAccess::Intrinsic::Op::kStore, storage_class, type);
}
/// @returns a DecomposeMemoryAccess::Intrinsic attribute that can be applied
@@ -225,769 +218,737 @@
DecomposeMemoryAccess::Intrinsic* IntrinsicAtomicFor(ProgramBuilder* builder,
sem::BuiltinType ity,
const sem::Type* ty) {
- auto op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicLoad;
- switch (ity) {
- case sem::BuiltinType::kAtomicLoad:
- op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicLoad;
- break;
- case sem::BuiltinType::kAtomicStore:
- op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicStore;
- break;
- case sem::BuiltinType::kAtomicAdd:
- op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicAdd;
- break;
- case sem::BuiltinType::kAtomicSub:
- op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicSub;
- break;
- case sem::BuiltinType::kAtomicMax:
- op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicMax;
- break;
- case sem::BuiltinType::kAtomicMin:
- op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicMin;
- break;
- case sem::BuiltinType::kAtomicAnd:
- op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicAnd;
- break;
- case sem::BuiltinType::kAtomicOr:
- op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicOr;
- break;
- case sem::BuiltinType::kAtomicXor:
- op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicXor;
- break;
- case sem::BuiltinType::kAtomicExchange:
- op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicExchange;
- break;
- case sem::BuiltinType::kAtomicCompareExchangeWeak:
- op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicCompareExchangeWeak;
- break;
- default:
- TINT_ICE(Transform, builder->Diagnostics())
- << "invalid IntrinsicType for DecomposeMemoryAccess::Intrinsic: "
- << ty->TypeInfo().name;
- break;
- }
+ auto op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicLoad;
+ switch (ity) {
+ case sem::BuiltinType::kAtomicLoad:
+ op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicLoad;
+ break;
+ case sem::BuiltinType::kAtomicStore:
+ op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicStore;
+ break;
+ case sem::BuiltinType::kAtomicAdd:
+ op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicAdd;
+ break;
+ case sem::BuiltinType::kAtomicSub:
+ op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicSub;
+ break;
+ case sem::BuiltinType::kAtomicMax:
+ op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicMax;
+ break;
+ case sem::BuiltinType::kAtomicMin:
+ op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicMin;
+ break;
+ case sem::BuiltinType::kAtomicAnd:
+ op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicAnd;
+ break;
+ case sem::BuiltinType::kAtomicOr:
+ op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicOr;
+ break;
+ case sem::BuiltinType::kAtomicXor:
+ op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicXor;
+ break;
+ case sem::BuiltinType::kAtomicExchange:
+ op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicExchange;
+ break;
+ case sem::BuiltinType::kAtomicCompareExchangeWeak:
+ op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicCompareExchangeWeak;
+ break;
+ default:
+ TINT_ICE(Transform, builder->Diagnostics())
+ << "invalid IntrinsicType for DecomposeMemoryAccess::Intrinsic: "
+ << ty->TypeInfo().name;
+ break;
+ }
- DecomposeMemoryAccess::Intrinsic::DataType type;
- if (!IntrinsicDataTypeFor(ty, type)) {
- return nullptr;
- }
- return builder->ASTNodes().Create<DecomposeMemoryAccess::Intrinsic>(
- builder->ID(), op, ast::StorageClass::kStorage, type);
+ DecomposeMemoryAccess::Intrinsic::DataType type;
+ if (!IntrinsicDataTypeFor(ty, type)) {
+ return nullptr;
+ }
+ return builder->ASTNodes().Create<DecomposeMemoryAccess::Intrinsic>(
+ builder->ID(), op, ast::StorageClass::kStorage, type);
}
/// BufferAccess describes a single storage or uniform buffer access
struct BufferAccess {
- sem::Expression const* var = nullptr; // Storage buffer variable
- Offset const* offset = nullptr; // The byte offset on var
- sem::Type const* type = nullptr; // The type of the access
- operator bool() const { return var; } // Returns true if valid
+ sem::Expression const* var = nullptr; // Storage buffer variable
+ Offset const* offset = nullptr; // The byte offset on var
+ sem::Type const* type = nullptr; // The type of the access
+ operator bool() const { return var; } // Returns true if valid
};
/// Store describes a single storage or uniform buffer write
struct Store {
- const ast::AssignmentStatement* assignment; // The AST assignment statement
- BufferAccess target; // The target for the write
+ const ast::AssignmentStatement* assignment; // The AST assignment statement
+ BufferAccess target; // The target for the write
};
} // namespace
/// State holds the current transform state
struct DecomposeMemoryAccess::State {
- /// The clone context
- CloneContext& ctx;
- /// Alias to `*ctx.dst`
- ProgramBuilder& b;
- /// Map of AST expression to storage or uniform buffer access
- /// This map has entries added when encountered, and removed when outer
- /// expressions chain the access.
- /// Subset of #expression_order, as expressions are not removed from
- /// #expression_order.
- std::unordered_map<const ast::Expression*, BufferAccess> accesses;
- /// The visited order of AST expressions (superset of #accesses)
- std::vector<const ast::Expression*> expression_order;
- /// [buffer-type, element-type] -> load function name
- std::unordered_map<LoadStoreKey, Symbol, LoadStoreKey::Hasher> load_funcs;
- /// [buffer-type, element-type] -> store function name
- std::unordered_map<LoadStoreKey, Symbol, LoadStoreKey::Hasher> store_funcs;
- /// [buffer-type, element-type, atomic-op] -> load function name
- std::unordered_map<AtomicKey, Symbol, AtomicKey::Hasher> atomic_funcs;
- /// List of storage or uniform buffer writes
- std::vector<Store> stores;
- /// Allocations for offsets
- utils::BlockAllocator<Offset> offsets_;
+ /// The clone context
+ CloneContext& ctx;
+ /// Alias to `*ctx.dst`
+ ProgramBuilder& b;
+ /// Map of AST expression to storage or uniform buffer access
+ /// This map has entries added when encountered, and removed when outer
+ /// expressions chain the access.
+ /// Subset of #expression_order, as expressions are not removed from
+ /// #expression_order.
+ std::unordered_map<const ast::Expression*, BufferAccess> accesses;
+ /// The visited order of AST expressions (superset of #accesses)
+ std::vector<const ast::Expression*> expression_order;
+ /// [buffer-type, element-type] -> load function name
+ std::unordered_map<LoadStoreKey, Symbol, LoadStoreKey::Hasher> load_funcs;
+ /// [buffer-type, element-type] -> store function name
+ std::unordered_map<LoadStoreKey, Symbol, LoadStoreKey::Hasher> store_funcs;
+ /// [buffer-type, element-type, atomic-op] -> load function name
+ std::unordered_map<AtomicKey, Symbol, AtomicKey::Hasher> atomic_funcs;
+ /// List of storage or uniform buffer writes
+ std::vector<Store> stores;
+ /// Allocations for offsets
+ utils::BlockAllocator<Offset> offsets_;
- /// Constructor
- /// @param context the CloneContext
- explicit State(CloneContext& context) : ctx(context), b(*ctx.dst) {}
+ /// Constructor
+ /// @param context the CloneContext
+ explicit State(CloneContext& context) : ctx(context), b(*ctx.dst) {}
- /// @param offset the offset value to wrap in an Offset
- /// @returns an Offset for the given literal value
- const Offset* ToOffset(uint32_t offset) {
- return offsets_.Create<OffsetLiteral>(offset);
- }
+ /// @param offset the offset value to wrap in an Offset
+ /// @returns an Offset for the given literal value
+ const Offset* ToOffset(uint32_t offset) { return offsets_.Create<OffsetLiteral>(offset); }
- /// @param expr the expression to convert to an Offset
- /// @returns an Offset for the given ast::Expression
- const Offset* ToOffset(const ast::Expression* expr) {
- if (auto* u32 = expr->As<ast::UintLiteralExpression>()) {
- return offsets_.Create<OffsetLiteral>(u32->value);
- } else if (auto* i32 = expr->As<ast::SintLiteralExpression>()) {
- if (i32->value > 0) {
- return offsets_.Create<OffsetLiteral>(i32->value);
- }
- }
- return offsets_.Create<OffsetExpr>(expr);
- }
-
- /// @param offset the Offset that is returned
- /// @returns the given offset (pass-through)
- const Offset* ToOffset(const Offset* offset) { return offset; }
-
- /// @param lhs_ the left-hand side of the add expression
- /// @param rhs_ the right-hand side of the add expression
- /// @return an Offset that is a sum of lhs and rhs, performing basic constant
- /// folding if possible
- template <typename LHS, typename RHS>
- const Offset* Add(LHS&& lhs_, RHS&& rhs_) {
- auto* lhs = ToOffset(std::forward<LHS>(lhs_));
- auto* rhs = ToOffset(std::forward<RHS>(rhs_));
- auto* lhs_lit = tint::As<OffsetLiteral>(lhs);
- auto* rhs_lit = tint::As<OffsetLiteral>(rhs);
- if (lhs_lit && lhs_lit->literal == 0) {
- return rhs;
- }
- if (rhs_lit && rhs_lit->literal == 0) {
- return lhs;
- }
- if (lhs_lit && rhs_lit) {
- if (static_cast<uint64_t>(lhs_lit->literal) +
- static_cast<uint64_t>(rhs_lit->literal) <=
- 0xffffffff) {
- return offsets_.Create<OffsetLiteral>(lhs_lit->literal +
- rhs_lit->literal);
- }
- }
- auto* out = offsets_.Create<OffsetBinOp>();
- out->op = ast::BinaryOp::kAdd;
- out->lhs = lhs;
- out->rhs = rhs;
- return out;
- }
-
- /// @param lhs_ the left-hand side of the multiply expression
- /// @param rhs_ the right-hand side of the multiply expression
- /// @return an Offset that is the multiplication of lhs and rhs, performing
- /// basic constant folding if possible
- template <typename LHS, typename RHS>
- const Offset* Mul(LHS&& lhs_, RHS&& rhs_) {
- auto* lhs = ToOffset(std::forward<LHS>(lhs_));
- auto* rhs = ToOffset(std::forward<RHS>(rhs_));
- auto* lhs_lit = tint::As<OffsetLiteral>(lhs);
- auto* rhs_lit = tint::As<OffsetLiteral>(rhs);
- if (lhs_lit && lhs_lit->literal == 0) {
- return offsets_.Create<OffsetLiteral>(0);
- }
- if (rhs_lit && rhs_lit->literal == 0) {
- return offsets_.Create<OffsetLiteral>(0);
- }
- if (lhs_lit && lhs_lit->literal == 1) {
- return rhs;
- }
- if (rhs_lit && rhs_lit->literal == 1) {
- return lhs;
- }
- if (lhs_lit && rhs_lit) {
- return offsets_.Create<OffsetLiteral>(lhs_lit->literal *
- rhs_lit->literal);
- }
- auto* out = offsets_.Create<OffsetBinOp>();
- out->op = ast::BinaryOp::kMultiply;
- out->lhs = lhs;
- out->rhs = rhs;
- return out;
- }
-
- /// AddAccess() adds the `expr -> access` map item to #accesses, and `expr`
- /// to #expression_order.
- /// @param expr the expression that performs the access
- /// @param access the access
- void AddAccess(const ast::Expression* expr, const BufferAccess& access) {
- TINT_ASSERT(Transform, access.type);
- accesses.emplace(expr, access);
- expression_order.emplace_back(expr);
- }
-
- /// TakeAccess() removes the `node` item from #accesses (if it exists),
- /// returning the BufferAccess. If #accesses does not hold an item for
- /// `node`, an invalid BufferAccess is returned.
- /// @param node the expression that performed an access
- /// @return the BufferAccess for the given expression
- BufferAccess TakeAccess(const ast::Expression* node) {
- auto lhs_it = accesses.find(node);
- if (lhs_it == accesses.end()) {
- return {};
- }
- auto access = lhs_it->second;
- accesses.erase(node);
- return access;
- }
-
- /// LoadFunc() returns a symbol to an intrinsic function that loads an element
- /// of type `el_ty` from a storage or uniform buffer of type `buf_ty`.
- /// The emitted function has the signature:
- /// `fn load(buf : buf_ty, offset : u32) -> el_ty`
- /// @param buf_ty the storage or uniform buffer type
- /// @param el_ty the storage or uniform buffer element type
- /// @param var_user the variable user
- /// @return the name of the function that performs the load
- Symbol LoadFunc(const sem::Type* buf_ty,
- const sem::Type* el_ty,
- const sem::VariableUser* var_user) {
- auto storage_class = var_user->Variable()->StorageClass();
- return utils::GetOrCreate(
- load_funcs, LoadStoreKey{storage_class, buf_ty, el_ty}, [&] {
- auto* buf_ast_ty = CreateASTTypeFor(ctx, buf_ty);
- auto* disable_validation = b.Disable(
- ast::DisabledValidation::kIgnoreConstructibleFunctionParameter);
-
- ast::VariableList params = {
- // Note: The buffer parameter requires the StorageClass in
- // order for HLSL to emit this as a ByteAddressBuffer or cbuffer
- // array.
- b.create<ast::Variable>(b.Sym("buffer"), storage_class,
- var_user->Variable()->Access(),
- buf_ast_ty, true, false, nullptr,
- ast::AttributeList{disable_validation}),
- b.Param("offset", b.ty.u32()),
- };
-
- auto name = b.Sym();
-
- if (auto* intrinsic =
- IntrinsicLoadFor(ctx.dst, storage_class, el_ty)) {
- auto* el_ast_ty = CreateASTTypeFor(ctx, el_ty);
- auto* func = b.create<ast::Function>(
- name, params, el_ast_ty, nullptr,
- ast::AttributeList{
- intrinsic,
- b.Disable(ast::DisabledValidation::kFunctionHasNoBody),
- },
- ast::AttributeList{});
- b.AST().AddFunction(func);
- } else if (auto* arr_ty = el_ty->As<sem::Array>()) {
- // fn load_func(buf : buf_ty, offset : u32) -> array<T, N> {
- // var arr : array<T, N>;
- // for (var i = 0u; i < array_count; i = i + 1) {
- // arr[i] = el_load_func(buf, offset + i * array_stride)
- // }
- // return arr;
- // }
- auto load =
- LoadFunc(buf_ty, arr_ty->ElemType()->UnwrapRef(), var_user);
- auto* arr =
- b.Var(b.Symbols().New("arr"), CreateASTTypeFor(ctx, arr_ty));
- auto* i = b.Var(b.Symbols().New("i"), nullptr, b.Expr(0u));
- auto* for_init = b.Decl(i);
- auto* for_cond = b.create<ast::BinaryExpression>(
- ast::BinaryOp::kLessThan, b.Expr(i), b.Expr(arr_ty->Count()));
- auto* for_cont = b.Assign(i, b.Add(i, 1u));
- auto* arr_el = b.IndexAccessor(arr, i);
- auto* el_offset =
- b.Add(b.Expr("offset"), b.Mul(i, arr_ty->Stride()));
- auto* el_val = b.Call(load, "buffer", el_offset);
- auto* for_loop = b.For(for_init, for_cond, for_cont,
- b.Block(b.Assign(arr_el, el_val)));
-
- b.Func(name, params, CreateASTTypeFor(ctx, arr_ty),
- {
- b.Decl(arr),
- for_loop,
- b.Return(arr),
- });
- } else {
- ast::ExpressionList values;
- if (auto* mat_ty = el_ty->As<sem::Matrix>()) {
- auto* vec_ty = mat_ty->ColumnType();
- Symbol load = LoadFunc(buf_ty, vec_ty, var_user);
- for (uint32_t i = 0; i < mat_ty->columns(); i++) {
- auto* offset = b.Add("offset", i * mat_ty->ColumnStride());
- values.emplace_back(b.Call(load, "buffer", offset));
- }
- } else if (auto* str = el_ty->As<sem::Struct>()) {
- for (auto* member : str->Members()) {
- auto* offset = b.Add("offset", member->Offset());
- Symbol load =
- LoadFunc(buf_ty, member->Type()->UnwrapRef(), var_user);
- values.emplace_back(b.Call(load, "buffer", offset));
- }
+ /// @param expr the expression to convert to an Offset
+ /// @returns an Offset for the given ast::Expression
+ const Offset* ToOffset(const ast::Expression* expr) {
+ if (auto* u32 = expr->As<ast::UintLiteralExpression>()) {
+ return offsets_.Create<OffsetLiteral>(u32->value);
+ } else if (auto* i32 = expr->As<ast::SintLiteralExpression>()) {
+ if (i32->value > 0) {
+ return offsets_.Create<OffsetLiteral>(i32->value);
}
- b.Func(
- name, params, CreateASTTypeFor(ctx, el_ty),
- {
- b.Return(b.Construct(CreateASTTypeFor(ctx, el_ty), values)),
- });
- }
- return name;
- });
- }
+ }
+ return offsets_.Create<OffsetExpr>(expr);
+ }
- /// StoreFunc() returns a symbol to an intrinsic function that stores an
- /// element of type `el_ty` to a storage buffer of type `buf_ty`.
- /// The function has the signature:
- /// `fn store(buf : buf_ty, offset : u32, value : el_ty)`
- /// @param buf_ty the storage buffer type
- /// @param el_ty the storage buffer element type
- /// @param var_user the variable user
- /// @return the name of the function that performs the store
- Symbol StoreFunc(const sem::Type* buf_ty,
- const sem::Type* el_ty,
- const sem::VariableUser* var_user) {
- auto storage_class = var_user->Variable()->StorageClass();
- return utils::GetOrCreate(
- store_funcs, LoadStoreKey{storage_class, buf_ty, el_ty}, [&] {
- auto* buf_ast_ty = CreateASTTypeFor(ctx, buf_ty);
- auto* el_ast_ty = CreateASTTypeFor(ctx, el_ty);
- auto* disable_validation = b.Disable(
- ast::DisabledValidation::kIgnoreConstructibleFunctionParameter);
- ast::VariableList params{
- // Note: The buffer parameter requires the StorageClass in
- // order for HLSL to emit this as a ByteAddressBuffer.
+ /// @param offset the Offset that is returned
+ /// @returns the given offset (pass-through)
+ const Offset* ToOffset(const Offset* offset) { return offset; }
- b.create<ast::Variable>(b.Sym("buffer"), storage_class,
- var_user->Variable()->Access(),
- buf_ast_ty, true, false, nullptr,
- ast::AttributeList{disable_validation}),
- b.Param("offset", b.ty.u32()),
- b.Param("value", el_ast_ty),
- };
-
- auto name = b.Sym();
-
- if (auto* intrinsic =
- IntrinsicStoreFor(ctx.dst, storage_class, el_ty)) {
- auto* func = b.create<ast::Function>(
- name, params, b.ty.void_(), nullptr,
- ast::AttributeList{
- intrinsic,
- b.Disable(ast::DisabledValidation::kFunctionHasNoBody),
- },
- ast::AttributeList{});
- b.AST().AddFunction(func);
- } else {
- ast::StatementList body;
- if (auto* arr_ty = el_ty->As<sem::Array>()) {
- // fn store_func(buf : buf_ty, offset : u32, value : el_ty) {
- // var array = value; // No dynamic indexing on constant arrays
- // for (var i = 0u; i < array_count; i = i + 1) {
- // arr[i] = el_store_func(buf, offset + i * array_stride,
- // value[i])
- // }
- // return arr;
- // }
- auto* array =
- b.Var(b.Symbols().New("array"), nullptr, b.Expr("value"));
- auto store =
- StoreFunc(buf_ty, arr_ty->ElemType()->UnwrapRef(), var_user);
- auto* i = b.Var(b.Symbols().New("i"), nullptr, b.Expr(0u));
- auto* for_init = b.Decl(i);
- auto* for_cond = b.create<ast::BinaryExpression>(
- ast::BinaryOp::kLessThan, b.Expr(i), b.Expr(arr_ty->Count()));
- auto* for_cont = b.Assign(i, b.Add(i, 1u));
- auto* arr_el = b.IndexAccessor(array, i);
- auto* el_offset =
- b.Add(b.Expr("offset"), b.Mul(i, arr_ty->Stride()));
- auto* store_stmt =
- b.CallStmt(b.Call(store, "buffer", el_offset, arr_el));
- auto* for_loop =
- b.For(for_init, for_cond, for_cont, b.Block(store_stmt));
-
- body = {b.Decl(array), for_loop};
- } else if (auto* mat_ty = el_ty->As<sem::Matrix>()) {
- auto* vec_ty = mat_ty->ColumnType();
- Symbol store = StoreFunc(buf_ty, vec_ty, var_user);
- for (uint32_t i = 0; i < mat_ty->columns(); i++) {
- auto* offset = b.Add("offset", i * mat_ty->ColumnStride());
- auto* access = b.IndexAccessor("value", i);
- auto* call = b.Call(store, "buffer", offset, access);
- body.emplace_back(b.CallStmt(call));
- }
- } else if (auto* str = el_ty->As<sem::Struct>()) {
- for (auto* member : str->Members()) {
- auto* offset = b.Add("offset", member->Offset());
- auto* access = b.MemberAccessor(
- "value", ctx.Clone(member->Declaration()->symbol));
- Symbol store =
- StoreFunc(buf_ty, member->Type()->UnwrapRef(), var_user);
- auto* call = b.Call(store, "buffer", offset, access);
- body.emplace_back(b.CallStmt(call));
- }
+ /// @param lhs_ the left-hand side of the add expression
+ /// @param rhs_ the right-hand side of the add expression
+ /// @return an Offset that is a sum of lhs and rhs, performing basic constant
+ /// folding if possible
+ template <typename LHS, typename RHS>
+ const Offset* Add(LHS&& lhs_, RHS&& rhs_) {
+ auto* lhs = ToOffset(std::forward<LHS>(lhs_));
+ auto* rhs = ToOffset(std::forward<RHS>(rhs_));
+ auto* lhs_lit = tint::As<OffsetLiteral>(lhs);
+ auto* rhs_lit = tint::As<OffsetLiteral>(rhs);
+ if (lhs_lit && lhs_lit->literal == 0) {
+ return rhs;
+ }
+ if (rhs_lit && rhs_lit->literal == 0) {
+ return lhs;
+ }
+ if (lhs_lit && rhs_lit) {
+ if (static_cast<uint64_t>(lhs_lit->literal) + static_cast<uint64_t>(rhs_lit->literal) <=
+ 0xffffffff) {
+ return offsets_.Create<OffsetLiteral>(lhs_lit->literal + rhs_lit->literal);
}
- b.Func(name, params, b.ty.void_(), body);
- }
+ }
+ auto* out = offsets_.Create<OffsetBinOp>();
+ out->op = ast::BinaryOp::kAdd;
+ out->lhs = lhs;
+ out->rhs = rhs;
+ return out;
+ }
- return name;
- });
- }
+ /// @param lhs_ the left-hand side of the multiply expression
+ /// @param rhs_ the right-hand side of the multiply expression
+ /// @return an Offset that is the multiplication of lhs and rhs, performing
+ /// basic constant folding if possible
+ template <typename LHS, typename RHS>
+ const Offset* Mul(LHS&& lhs_, RHS&& rhs_) {
+ auto* lhs = ToOffset(std::forward<LHS>(lhs_));
+ auto* rhs = ToOffset(std::forward<RHS>(rhs_));
+ auto* lhs_lit = tint::As<OffsetLiteral>(lhs);
+ auto* rhs_lit = tint::As<OffsetLiteral>(rhs);
+ if (lhs_lit && lhs_lit->literal == 0) {
+ return offsets_.Create<OffsetLiteral>(0);
+ }
+ if (rhs_lit && rhs_lit->literal == 0) {
+ return offsets_.Create<OffsetLiteral>(0);
+ }
+ if (lhs_lit && lhs_lit->literal == 1) {
+ return rhs;
+ }
+ if (rhs_lit && rhs_lit->literal == 1) {
+ return lhs;
+ }
+ if (lhs_lit && rhs_lit) {
+ return offsets_.Create<OffsetLiteral>(lhs_lit->literal * rhs_lit->literal);
+ }
+ auto* out = offsets_.Create<OffsetBinOp>();
+ out->op = ast::BinaryOp::kMultiply;
+ out->lhs = lhs;
+ out->rhs = rhs;
+ return out;
+ }
- /// AtomicFunc() returns a symbol to an intrinsic function that performs an
- /// atomic operation from a storage buffer of type `buf_ty`. The function has
- /// the signature:
- // `fn atomic_op(buf : buf_ty, offset : u32, ...) -> T`
- /// @param buf_ty the storage buffer type
- /// @param el_ty the storage buffer element type
- /// @param intrinsic the atomic intrinsic
- /// @param var_user the variable user
- /// @return the name of the function that performs the load
- Symbol AtomicFunc(const sem::Type* buf_ty,
+ /// AddAccess() adds the `expr -> access` map item to #accesses, and `expr`
+ /// to #expression_order.
+ /// @param expr the expression that performs the access
+ /// @param access the access
+ void AddAccess(const ast::Expression* expr, const BufferAccess& access) {
+ TINT_ASSERT(Transform, access.type);
+ accesses.emplace(expr, access);
+ expression_order.emplace_back(expr);
+ }
+
+ /// TakeAccess() removes the `node` item from #accesses (if it exists),
+ /// returning the BufferAccess. If #accesses does not hold an item for
+ /// `node`, an invalid BufferAccess is returned.
+ /// @param node the expression that performed an access
+ /// @return the BufferAccess for the given expression
+ BufferAccess TakeAccess(const ast::Expression* node) {
+ auto lhs_it = accesses.find(node);
+ if (lhs_it == accesses.end()) {
+ return {};
+ }
+ auto access = lhs_it->second;
+ accesses.erase(node);
+ return access;
+ }
+
+ /// LoadFunc() returns a symbol to an intrinsic function that loads an element
+ /// of type `el_ty` from a storage or uniform buffer of type `buf_ty`.
+ /// The emitted function has the signature:
+ /// `fn load(buf : buf_ty, offset : u32) -> el_ty`
+ /// @param buf_ty the storage or uniform buffer type
+ /// @param el_ty the storage or uniform buffer element type
+ /// @param var_user the variable user
+ /// @return the name of the function that performs the load
+ Symbol LoadFunc(const sem::Type* buf_ty,
const sem::Type* el_ty,
- const sem::Builtin* intrinsic,
const sem::VariableUser* var_user) {
- auto op = intrinsic->Type();
- return utils::GetOrCreate(atomic_funcs, AtomicKey{buf_ty, el_ty, op}, [&] {
- auto* buf_ast_ty = CreateASTTypeFor(ctx, buf_ty);
- auto* disable_validation = b.Disable(
- ast::DisabledValidation::kIgnoreConstructibleFunctionParameter);
- // The first parameter to all WGSL atomics is the expression to the
- // atomic. This is replaced with two parameters: the buffer and offset.
+ auto storage_class = var_user->Variable()->StorageClass();
+ return utils::GetOrCreate(load_funcs, LoadStoreKey{storage_class, buf_ty, el_ty}, [&] {
+ auto* buf_ast_ty = CreateASTTypeFor(ctx, buf_ty);
+ auto* disable_validation =
+ b.Disable(ast::DisabledValidation::kIgnoreConstructibleFunctionParameter);
- ast::VariableList params = {
- // Note: The buffer parameter requires the kStorage StorageClass in
- // order for HLSL to emit this as a ByteAddressBuffer.
- b.create<ast::Variable>(b.Sym("buffer"), ast::StorageClass::kStorage,
- var_user->Variable()->Access(), buf_ast_ty,
- true, false, nullptr,
- ast::AttributeList{disable_validation}),
- b.Param("offset", b.ty.u32()),
- };
+ ast::VariableList params = {
+ // Note: The buffer parameter requires the StorageClass in
+ // order for HLSL to emit this as a ByteAddressBuffer or cbuffer
+ // array.
+ b.create<ast::Variable>(b.Sym("buffer"), storage_class,
+ var_user->Variable()->Access(), buf_ast_ty, true, false,
+ nullptr, ast::AttributeList{disable_validation}),
+ b.Param("offset", b.ty.u32()),
+ };
- // Other parameters are copied as-is:
- for (size_t i = 1; i < intrinsic->Parameters().size(); i++) {
- auto* param = intrinsic->Parameters()[i];
- auto* ty = CreateASTTypeFor(ctx, param->Type());
- params.emplace_back(b.Param("param_" + std::to_string(i), ty));
- }
+ auto name = b.Sym();
- auto* atomic = IntrinsicAtomicFor(ctx.dst, op, el_ty);
- if (atomic == nullptr) {
- TINT_ICE(Transform, b.Diagnostics())
- << "IntrinsicAtomicFor() returned nullptr for op " << op
- << " and type " << el_ty->TypeInfo().name;
- }
+ if (auto* intrinsic = IntrinsicLoadFor(ctx.dst, storage_class, el_ty)) {
+ auto* el_ast_ty = CreateASTTypeFor(ctx, el_ty);
+ auto* func = b.create<ast::Function>(
+ name, params, el_ast_ty, nullptr,
+ ast::AttributeList{
+ intrinsic,
+ b.Disable(ast::DisabledValidation::kFunctionHasNoBody),
+ },
+ ast::AttributeList{});
+ b.AST().AddFunction(func);
+ } else if (auto* arr_ty = el_ty->As<sem::Array>()) {
+ // fn load_func(buf : buf_ty, offset : u32) -> array<T, N> {
+ // var arr : array<T, N>;
+ // for (var i = 0u; i < array_count; i = i + 1) {
+ // arr[i] = el_load_func(buf, offset + i * array_stride)
+ // }
+ // return arr;
+ // }
+ auto load = LoadFunc(buf_ty, arr_ty->ElemType()->UnwrapRef(), var_user);
+ auto* arr = b.Var(b.Symbols().New("arr"), CreateASTTypeFor(ctx, arr_ty));
+ auto* i = b.Var(b.Symbols().New("i"), nullptr, b.Expr(0u));
+ auto* for_init = b.Decl(i);
+ auto* for_cond = b.create<ast::BinaryExpression>(
+ ast::BinaryOp::kLessThan, b.Expr(i), b.Expr(arr_ty->Count()));
+ auto* for_cont = b.Assign(i, b.Add(i, 1u));
+ auto* arr_el = b.IndexAccessor(arr, i);
+ auto* el_offset = b.Add(b.Expr("offset"), b.Mul(i, arr_ty->Stride()));
+ auto* el_val = b.Call(load, "buffer", el_offset);
+ auto* for_loop =
+ b.For(for_init, for_cond, for_cont, b.Block(b.Assign(arr_el, el_val)));
- auto* ret_ty = CreateASTTypeFor(ctx, intrinsic->ReturnType());
- auto* func = b.create<ast::Function>(
- b.Sym(), params, ret_ty, nullptr,
- ast::AttributeList{
- atomic,
- b.Disable(ast::DisabledValidation::kFunctionHasNoBody),
- },
- ast::AttributeList{});
+ b.Func(name, params, CreateASTTypeFor(ctx, arr_ty),
+ {
+ b.Decl(arr),
+ for_loop,
+ b.Return(arr),
+ });
+ } else {
+ ast::ExpressionList values;
+ if (auto* mat_ty = el_ty->As<sem::Matrix>()) {
+ auto* vec_ty = mat_ty->ColumnType();
+ Symbol load = LoadFunc(buf_ty, vec_ty, var_user);
+ for (uint32_t i = 0; i < mat_ty->columns(); i++) {
+ auto* offset = b.Add("offset", i * mat_ty->ColumnStride());
+ values.emplace_back(b.Call(load, "buffer", offset));
+ }
+ } else if (auto* str = el_ty->As<sem::Struct>()) {
+ for (auto* member : str->Members()) {
+ auto* offset = b.Add("offset", member->Offset());
+ Symbol load = LoadFunc(buf_ty, member->Type()->UnwrapRef(), var_user);
+ values.emplace_back(b.Call(load, "buffer", offset));
+ }
+ }
+ b.Func(name, params, CreateASTTypeFor(ctx, el_ty),
+ {
+ b.Return(b.Construct(CreateASTTypeFor(ctx, el_ty), values)),
+ });
+ }
+ return name;
+ });
+ }
- b.AST().AddFunction(func);
- return func->symbol;
- });
- }
+ /// StoreFunc() returns a symbol to an intrinsic function that stores an
+ /// element of type `el_ty` to a storage buffer of type `buf_ty`.
+ /// The function has the signature:
+ /// `fn store(buf : buf_ty, offset : u32, value : el_ty)`
+ /// @param buf_ty the storage buffer type
+ /// @param el_ty the storage buffer element type
+ /// @param var_user the variable user
+ /// @return the name of the function that performs the store
+ Symbol StoreFunc(const sem::Type* buf_ty,
+ const sem::Type* el_ty,
+ const sem::VariableUser* var_user) {
+ auto storage_class = var_user->Variable()->StorageClass();
+ return utils::GetOrCreate(store_funcs, LoadStoreKey{storage_class, buf_ty, el_ty}, [&] {
+ auto* buf_ast_ty = CreateASTTypeFor(ctx, buf_ty);
+ auto* el_ast_ty = CreateASTTypeFor(ctx, el_ty);
+ auto* disable_validation =
+ b.Disable(ast::DisabledValidation::kIgnoreConstructibleFunctionParameter);
+ ast::VariableList params{
+ // Note: The buffer parameter requires the StorageClass in
+ // order for HLSL to emit this as a ByteAddressBuffer.
+
+ b.create<ast::Variable>(b.Sym("buffer"), storage_class,
+ var_user->Variable()->Access(), buf_ast_ty, true, false,
+ nullptr, ast::AttributeList{disable_validation}),
+ b.Param("offset", b.ty.u32()),
+ b.Param("value", el_ast_ty),
+ };
+
+ auto name = b.Sym();
+
+ if (auto* intrinsic = IntrinsicStoreFor(ctx.dst, storage_class, el_ty)) {
+ auto* func = b.create<ast::Function>(
+ name, params, b.ty.void_(), nullptr,
+ ast::AttributeList{
+ intrinsic,
+ b.Disable(ast::DisabledValidation::kFunctionHasNoBody),
+ },
+ ast::AttributeList{});
+ b.AST().AddFunction(func);
+ } else {
+ ast::StatementList body;
+ if (auto* arr_ty = el_ty->As<sem::Array>()) {
+ // fn store_func(buf : buf_ty, offset : u32, value : el_ty) {
+ // var array = value; // No dynamic indexing on constant arrays
+ // for (var i = 0u; i < array_count; i = i + 1) {
+ // arr[i] = el_store_func(buf, offset + i * array_stride,
+ // value[i])
+ // }
+ // return arr;
+ // }
+ auto* array = b.Var(b.Symbols().New("array"), nullptr, b.Expr("value"));
+ auto store = StoreFunc(buf_ty, arr_ty->ElemType()->UnwrapRef(), var_user);
+ auto* i = b.Var(b.Symbols().New("i"), nullptr, b.Expr(0u));
+ auto* for_init = b.Decl(i);
+ auto* for_cond = b.create<ast::BinaryExpression>(
+ ast::BinaryOp::kLessThan, b.Expr(i), b.Expr(arr_ty->Count()));
+ auto* for_cont = b.Assign(i, b.Add(i, 1u));
+ auto* arr_el = b.IndexAccessor(array, i);
+ auto* el_offset = b.Add(b.Expr("offset"), b.Mul(i, arr_ty->Stride()));
+ auto* store_stmt = b.CallStmt(b.Call(store, "buffer", el_offset, arr_el));
+ auto* for_loop = b.For(for_init, for_cond, for_cont, b.Block(store_stmt));
+
+ body = {b.Decl(array), for_loop};
+ } else if (auto* mat_ty = el_ty->As<sem::Matrix>()) {
+ auto* vec_ty = mat_ty->ColumnType();
+ Symbol store = StoreFunc(buf_ty, vec_ty, var_user);
+ for (uint32_t i = 0; i < mat_ty->columns(); i++) {
+ auto* offset = b.Add("offset", i * mat_ty->ColumnStride());
+ auto* access = b.IndexAccessor("value", i);
+ auto* call = b.Call(store, "buffer", offset, access);
+ body.emplace_back(b.CallStmt(call));
+ }
+ } else if (auto* str = el_ty->As<sem::Struct>()) {
+ for (auto* member : str->Members()) {
+ auto* offset = b.Add("offset", member->Offset());
+ auto* access =
+ b.MemberAccessor("value", ctx.Clone(member->Declaration()->symbol));
+ Symbol store = StoreFunc(buf_ty, member->Type()->UnwrapRef(), var_user);
+ auto* call = b.Call(store, "buffer", offset, access);
+ body.emplace_back(b.CallStmt(call));
+ }
+ }
+ b.Func(name, params, b.ty.void_(), body);
+ }
+
+ return name;
+ });
+ }
+
+ /// AtomicFunc() returns a symbol to an intrinsic function that performs an
+ /// atomic operation from a storage buffer of type `buf_ty`. The function has
+ /// the signature:
+ // `fn atomic_op(buf : buf_ty, offset : u32, ...) -> T`
+ /// @param buf_ty the storage buffer type
+ /// @param el_ty the storage buffer element type
+ /// @param intrinsic the atomic intrinsic
+ /// @param var_user the variable user
+ /// @return the name of the function that performs the load
+ Symbol AtomicFunc(const sem::Type* buf_ty,
+ const sem::Type* el_ty,
+ const sem::Builtin* intrinsic,
+ const sem::VariableUser* var_user) {
+ auto op = intrinsic->Type();
+ return utils::GetOrCreate(atomic_funcs, AtomicKey{buf_ty, el_ty, op}, [&] {
+ auto* buf_ast_ty = CreateASTTypeFor(ctx, buf_ty);
+ auto* disable_validation =
+ b.Disable(ast::DisabledValidation::kIgnoreConstructibleFunctionParameter);
+ // The first parameter to all WGSL atomics is the expression to the
+ // atomic. This is replaced with two parameters: the buffer and offset.
+
+ ast::VariableList params = {
+ // Note: The buffer parameter requires the kStorage StorageClass in
+ // order for HLSL to emit this as a ByteAddressBuffer.
+ b.create<ast::Variable>(b.Sym("buffer"), ast::StorageClass::kStorage,
+ var_user->Variable()->Access(), buf_ast_ty, true, false,
+ nullptr, ast::AttributeList{disable_validation}),
+ b.Param("offset", b.ty.u32()),
+ };
+
+ // Other parameters are copied as-is:
+ for (size_t i = 1; i < intrinsic->Parameters().size(); i++) {
+ auto* param = intrinsic->Parameters()[i];
+ auto* ty = CreateASTTypeFor(ctx, param->Type());
+ params.emplace_back(b.Param("param_" + std::to_string(i), ty));
+ }
+
+ auto* atomic = IntrinsicAtomicFor(ctx.dst, op, el_ty);
+ if (atomic == nullptr) {
+ TINT_ICE(Transform, b.Diagnostics())
+ << "IntrinsicAtomicFor() returned nullptr for op " << op << " and type "
+ << el_ty->TypeInfo().name;
+ }
+
+ auto* ret_ty = CreateASTTypeFor(ctx, intrinsic->ReturnType());
+ auto* func =
+ b.create<ast::Function>(b.Sym(), params, ret_ty, nullptr,
+ ast::AttributeList{
+ atomic,
+ b.Disable(ast::DisabledValidation::kFunctionHasNoBody),
+ },
+ ast::AttributeList{});
+
+ b.AST().AddFunction(func);
+ return func->symbol;
+ });
+ }
};
-DecomposeMemoryAccess::Intrinsic::Intrinsic(ProgramID pid,
- Op o,
- ast::StorageClass sc,
- DataType ty)
+DecomposeMemoryAccess::Intrinsic::Intrinsic(ProgramID pid, Op o, ast::StorageClass sc, DataType ty)
: Base(pid), op(o), storage_class(sc), type(ty) {}
DecomposeMemoryAccess::Intrinsic::~Intrinsic() = default;
std::string DecomposeMemoryAccess::Intrinsic::InternalName() const {
- std::stringstream ss;
- switch (op) {
- case Op::kLoad:
- ss << "intrinsic_load_";
- break;
- case Op::kStore:
- ss << "intrinsic_store_";
- break;
- case Op::kAtomicLoad:
- ss << "intrinsic_atomic_load_";
- break;
- case Op::kAtomicStore:
- ss << "intrinsic_atomic_store_";
- break;
- case Op::kAtomicAdd:
- ss << "intrinsic_atomic_add_";
- break;
- case Op::kAtomicSub:
- ss << "intrinsic_atomic_sub_";
- break;
- case Op::kAtomicMax:
- ss << "intrinsic_atomic_max_";
- break;
- case Op::kAtomicMin:
- ss << "intrinsic_atomic_min_";
- break;
- case Op::kAtomicAnd:
- ss << "intrinsic_atomic_and_";
- break;
- case Op::kAtomicOr:
- ss << "intrinsic_atomic_or_";
- break;
- case Op::kAtomicXor:
- ss << "intrinsic_atomic_xor_";
- break;
- case Op::kAtomicExchange:
- ss << "intrinsic_atomic_exchange_";
- break;
- case Op::kAtomicCompareExchangeWeak:
- ss << "intrinsic_atomic_compare_exchange_weak_";
- break;
- }
- ss << storage_class << "_";
- switch (type) {
- case DataType::kU32:
- ss << "u32";
- break;
- case DataType::kF32:
- ss << "f32";
- break;
- case DataType::kI32:
- ss << "i32";
- break;
- case DataType::kVec2U32:
- ss << "vec2_u32";
- break;
- case DataType::kVec2F32:
- ss << "vec2_f32";
- break;
- case DataType::kVec2I32:
- ss << "vec2_i32";
- break;
- case DataType::kVec3U32:
- ss << "vec3_u32";
- break;
- case DataType::kVec3F32:
- ss << "vec3_f32";
- break;
- case DataType::kVec3I32:
- ss << "vec3_i32";
- break;
- case DataType::kVec4U32:
- ss << "vec4_u32";
- break;
- case DataType::kVec4F32:
- ss << "vec4_f32";
- break;
- case DataType::kVec4I32:
- ss << "vec4_i32";
- break;
- }
- return ss.str();
+ std::stringstream ss;
+ switch (op) {
+ case Op::kLoad:
+ ss << "intrinsic_load_";
+ break;
+ case Op::kStore:
+ ss << "intrinsic_store_";
+ break;
+ case Op::kAtomicLoad:
+ ss << "intrinsic_atomic_load_";
+ break;
+ case Op::kAtomicStore:
+ ss << "intrinsic_atomic_store_";
+ break;
+ case Op::kAtomicAdd:
+ ss << "intrinsic_atomic_add_";
+ break;
+ case Op::kAtomicSub:
+ ss << "intrinsic_atomic_sub_";
+ break;
+ case Op::kAtomicMax:
+ ss << "intrinsic_atomic_max_";
+ break;
+ case Op::kAtomicMin:
+ ss << "intrinsic_atomic_min_";
+ break;
+ case Op::kAtomicAnd:
+ ss << "intrinsic_atomic_and_";
+ break;
+ case Op::kAtomicOr:
+ ss << "intrinsic_atomic_or_";
+ break;
+ case Op::kAtomicXor:
+ ss << "intrinsic_atomic_xor_";
+ break;
+ case Op::kAtomicExchange:
+ ss << "intrinsic_atomic_exchange_";
+ break;
+ case Op::kAtomicCompareExchangeWeak:
+ ss << "intrinsic_atomic_compare_exchange_weak_";
+ break;
+ }
+ ss << storage_class << "_";
+ switch (type) {
+ case DataType::kU32:
+ ss << "u32";
+ break;
+ case DataType::kF32:
+ ss << "f32";
+ break;
+ case DataType::kI32:
+ ss << "i32";
+ break;
+ case DataType::kVec2U32:
+ ss << "vec2_u32";
+ break;
+ case DataType::kVec2F32:
+ ss << "vec2_f32";
+ break;
+ case DataType::kVec2I32:
+ ss << "vec2_i32";
+ break;
+ case DataType::kVec3U32:
+ ss << "vec3_u32";
+ break;
+ case DataType::kVec3F32:
+ ss << "vec3_f32";
+ break;
+ case DataType::kVec3I32:
+ ss << "vec3_i32";
+ break;
+ case DataType::kVec4U32:
+ ss << "vec4_u32";
+ break;
+ case DataType::kVec4F32:
+ ss << "vec4_f32";
+ break;
+ case DataType::kVec4I32:
+ ss << "vec4_i32";
+ break;
+ }
+ return ss.str();
}
const DecomposeMemoryAccess::Intrinsic* DecomposeMemoryAccess::Intrinsic::Clone(
CloneContext* ctx) const {
- return ctx->dst->ASTNodes().Create<DecomposeMemoryAccess::Intrinsic>(
- ctx->dst->ID(), op, storage_class, type);
+ return ctx->dst->ASTNodes().Create<DecomposeMemoryAccess::Intrinsic>(ctx->dst->ID(), op,
+ storage_class, type);
}
DecomposeMemoryAccess::DecomposeMemoryAccess() = default;
DecomposeMemoryAccess::~DecomposeMemoryAccess() = default;
-bool DecomposeMemoryAccess::ShouldRun(const Program* program,
- const DataMap&) const {
- for (auto* decl : program->AST().GlobalDeclarations()) {
- if (auto* var = program->Sem().Get<sem::Variable>(decl)) {
- if (var->StorageClass() == ast::StorageClass::kStorage ||
- var->StorageClass() == ast::StorageClass::kUniform) {
- return true;
- }
+bool DecomposeMemoryAccess::ShouldRun(const Program* program, const DataMap&) const {
+ for (auto* decl : program->AST().GlobalDeclarations()) {
+ if (auto* var = program->Sem().Get<sem::Variable>(decl)) {
+ if (var->StorageClass() == ast::StorageClass::kStorage ||
+ var->StorageClass() == ast::StorageClass::kUniform) {
+ return true;
+ }
+ }
}
- }
- return false;
+ return false;
}
-void DecomposeMemoryAccess::Run(CloneContext& ctx,
- const DataMap&,
- DataMap&) const {
- auto& sem = ctx.src->Sem();
+void DecomposeMemoryAccess::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
+ auto& sem = ctx.src->Sem();
- State state(ctx);
+ State state(ctx);
- // Scan the AST nodes for storage and uniform buffer accesses. Complex
- // expression chains (e.g. `storage_buffer.foo.bar[20].x`) are handled by
- // maintaining an offset chain via the `state.TakeAccess()`,
- // `state.AddAccess()` methods.
- //
- // Inner-most expression nodes are guaranteed to be visited first because AST
- // nodes are fully immutable and require their children to be constructed
- // first so their pointer can be passed to the parent's constructor.
- for (auto* node : ctx.src->ASTNodes().Objects()) {
- if (auto* ident = node->As<ast::IdentifierExpression>()) {
- // X
- if (auto* var = sem.Get<sem::VariableUser>(ident)) {
- if (var->Variable()->StorageClass() == ast::StorageClass::kStorage ||
- var->Variable()->StorageClass() == ast::StorageClass::kUniform) {
- // Variable to a storage or uniform buffer
- state.AddAccess(ident, {
- var,
- state.ToOffset(0u),
- var->Type()->UnwrapRef(),
- });
+ // Scan the AST nodes for storage and uniform buffer accesses. Complex
+ // expression chains (e.g. `storage_buffer.foo.bar[20].x`) are handled by
+ // maintaining an offset chain via the `state.TakeAccess()`,
+ // `state.AddAccess()` methods.
+ //
+ // Inner-most expression nodes are guaranteed to be visited first because AST
+ // nodes are fully immutable and require their children to be constructed
+ // first so their pointer can be passed to the parent's constructor.
+ for (auto* node : ctx.src->ASTNodes().Objects()) {
+ if (auto* ident = node->As<ast::IdentifierExpression>()) {
+ // X
+ if (auto* var = sem.Get<sem::VariableUser>(ident)) {
+ if (var->Variable()->StorageClass() == ast::StorageClass::kStorage ||
+ var->Variable()->StorageClass() == ast::StorageClass::kUniform) {
+ // Variable to a storage or uniform buffer
+ state.AddAccess(ident, {
+ var,
+ state.ToOffset(0u),
+ var->Type()->UnwrapRef(),
+ });
+ }
+ }
+ continue;
}
- }
- continue;
+
+ if (auto* accessor = node->As<ast::MemberAccessorExpression>()) {
+ // X.Y
+ auto* accessor_sem = sem.Get(accessor);
+ if (auto* swizzle = accessor_sem->As<sem::Swizzle>()) {
+ if (swizzle->Indices().size() == 1) {
+ if (auto access = state.TakeAccess(accessor->structure)) {
+ auto* vec_ty = access.type->As<sem::Vector>();
+ auto* offset = state.Mul(vec_ty->type()->Size(), swizzle->Indices()[0]);
+ state.AddAccess(accessor, {
+ access.var,
+ state.Add(access.offset, offset),
+ vec_ty->type()->UnwrapRef(),
+ });
+ }
+ }
+ } else {
+ if (auto access = state.TakeAccess(accessor->structure)) {
+ auto* str_ty = access.type->As<sem::Struct>();
+ auto* member = str_ty->FindMember(accessor->member->symbol);
+ auto offset = member->Offset();
+ state.AddAccess(accessor, {
+ access.var,
+ state.Add(access.offset, offset),
+ member->Type()->UnwrapRef(),
+ });
+ }
+ }
+ continue;
+ }
+
+ if (auto* accessor = node->As<ast::IndexAccessorExpression>()) {
+ if (auto access = state.TakeAccess(accessor->object)) {
+ // X[Y]
+ if (auto* arr = access.type->As<sem::Array>()) {
+ auto* offset = state.Mul(arr->Stride(), accessor->index);
+ state.AddAccess(accessor, {
+ access.var,
+ state.Add(access.offset, offset),
+ arr->ElemType()->UnwrapRef(),
+ });
+ continue;
+ }
+ if (auto* vec_ty = access.type->As<sem::Vector>()) {
+ auto* offset = state.Mul(vec_ty->type()->Size(), accessor->index);
+ state.AddAccess(accessor, {
+ access.var,
+ state.Add(access.offset, offset),
+ vec_ty->type()->UnwrapRef(),
+ });
+ continue;
+ }
+ if (auto* mat_ty = access.type->As<sem::Matrix>()) {
+ auto* offset = state.Mul(mat_ty->ColumnStride(), accessor->index);
+ state.AddAccess(accessor, {
+ access.var,
+ state.Add(access.offset, offset),
+ mat_ty->ColumnType(),
+ });
+ continue;
+ }
+ }
+ }
+
+ if (auto* op = node->As<ast::UnaryOpExpression>()) {
+ if (op->op == ast::UnaryOp::kAddressOf) {
+ // &X
+ if (auto access = state.TakeAccess(op->expr)) {
+ // HLSL does not support pointers, so just take the access from the
+ // reference and place it on the pointer.
+ state.AddAccess(op, access);
+ continue;
+ }
+ }
+ }
+
+ if (auto* assign = node->As<ast::AssignmentStatement>()) {
+ // X = Y
+ // Move the LHS access to a store.
+ if (auto lhs = state.TakeAccess(assign->lhs)) {
+ state.stores.emplace_back(Store{assign, lhs});
+ }
+ }
+
+ if (auto* call_expr = node->As<ast::CallExpression>()) {
+ auto* call = sem.Get(call_expr);
+ if (auto* builtin = call->Target()->As<sem::Builtin>()) {
+ if (builtin->Type() == sem::BuiltinType::kArrayLength) {
+ // arrayLength(X)
+ // Don't convert X into a load, this builtin actually requires the
+ // real pointer.
+ state.TakeAccess(call_expr->args[0]);
+ continue;
+ }
+ if (builtin->IsAtomic()) {
+ if (auto access = state.TakeAccess(call_expr->args[0])) {
+ // atomic___(X)
+ ctx.Replace(call_expr, [=, &ctx, &state] {
+ auto* buf = access.var->Declaration();
+ auto* offset = access.offset->Build(ctx);
+ auto* buf_ty = access.var->Type()->UnwrapRef();
+ auto* el_ty = access.type->UnwrapRef()->As<sem::Atomic>()->Type();
+ Symbol func = state.AtomicFunc(buf_ty, el_ty, builtin,
+ access.var->As<sem::VariableUser>());
+
+ ast::ExpressionList args{ctx.Clone(buf), offset};
+ for (size_t i = 1; i < call_expr->args.size(); i++) {
+ auto* arg = call_expr->args[i];
+ args.emplace_back(ctx.Clone(arg));
+ }
+ return ctx.dst->Call(func, args);
+ });
+ }
+ }
+ }
+ }
}
- if (auto* accessor = node->As<ast::MemberAccessorExpression>()) {
- // X.Y
- auto* accessor_sem = sem.Get(accessor);
- if (auto* swizzle = accessor_sem->As<sem::Swizzle>()) {
- if (swizzle->Indices().size() == 1) {
- if (auto access = state.TakeAccess(accessor->structure)) {
- auto* vec_ty = access.type->As<sem::Vector>();
- auto* offset =
- state.Mul(vec_ty->type()->Size(), swizzle->Indices()[0]);
- state.AddAccess(accessor, {
- access.var,
- state.Add(access.offset, offset),
- vec_ty->type()->UnwrapRef(),
- });
- }
+ // All remaining accesses are loads, transform these into calls to the
+ // corresponding load function
+ for (auto* expr : state.expression_order) {
+ auto access_it = state.accesses.find(expr);
+ if (access_it == state.accesses.end()) {
+ continue;
}
- } else {
- if (auto access = state.TakeAccess(accessor->structure)) {
- auto* str_ty = access.type->As<sem::Struct>();
- auto* member = str_ty->FindMember(accessor->member->symbol);
- auto offset = member->Offset();
- state.AddAccess(accessor, {
- access.var,
- state.Add(access.offset, offset),
- member->Type()->UnwrapRef(),
- });
- }
- }
- continue;
+ BufferAccess access = access_it->second;
+ ctx.Replace(expr, [=, &ctx, &state] {
+ auto* buf = access.var->Declaration();
+ auto* offset = access.offset->Build(ctx);
+ auto* buf_ty = access.var->Type()->UnwrapRef();
+ auto* el_ty = access.type->UnwrapRef();
+ Symbol func = state.LoadFunc(buf_ty, el_ty, access.var->As<sem::VariableUser>());
+ return ctx.dst->Call(func, ctx.CloneWithoutTransform(buf), offset);
+ });
}
- if (auto* accessor = node->As<ast::IndexAccessorExpression>()) {
- if (auto access = state.TakeAccess(accessor->object)) {
- // X[Y]
- if (auto* arr = access.type->As<sem::Array>()) {
- auto* offset = state.Mul(arr->Stride(), accessor->index);
- state.AddAccess(accessor, {
- access.var,
- state.Add(access.offset, offset),
- arr->ElemType()->UnwrapRef(),
- });
- continue;
- }
- if (auto* vec_ty = access.type->As<sem::Vector>()) {
- auto* offset = state.Mul(vec_ty->type()->Size(), accessor->index);
- state.AddAccess(accessor, {
- access.var,
- state.Add(access.offset, offset),
- vec_ty->type()->UnwrapRef(),
- });
- continue;
- }
- if (auto* mat_ty = access.type->As<sem::Matrix>()) {
- auto* offset = state.Mul(mat_ty->ColumnStride(), accessor->index);
- state.AddAccess(accessor, {
- access.var,
- state.Add(access.offset, offset),
- mat_ty->ColumnType(),
- });
- continue;
- }
- }
+ // And replace all storage and uniform buffer assignments with stores
+ for (auto store : state.stores) {
+ ctx.Replace(store.assignment, [=, &ctx, &state] {
+ auto* buf = store.target.var->Declaration();
+ auto* offset = store.target.offset->Build(ctx);
+ auto* buf_ty = store.target.var->Type()->UnwrapRef();
+ auto* el_ty = store.target.type->UnwrapRef();
+ auto* value = store.assignment->rhs;
+ Symbol func = state.StoreFunc(buf_ty, el_ty, store.target.var->As<sem::VariableUser>());
+ auto* call =
+ ctx.dst->Call(func, ctx.CloneWithoutTransform(buf), offset, ctx.Clone(value));
+ return ctx.dst->CallStmt(call);
+ });
}
- if (auto* op = node->As<ast::UnaryOpExpression>()) {
- if (op->op == ast::UnaryOp::kAddressOf) {
- // &X
- if (auto access = state.TakeAccess(op->expr)) {
- // HLSL does not support pointers, so just take the access from the
- // reference and place it on the pointer.
- state.AddAccess(op, access);
- continue;
- }
- }
- }
-
- if (auto* assign = node->As<ast::AssignmentStatement>()) {
- // X = Y
- // Move the LHS access to a store.
- if (auto lhs = state.TakeAccess(assign->lhs)) {
- state.stores.emplace_back(Store{assign, lhs});
- }
- }
-
- if (auto* call_expr = node->As<ast::CallExpression>()) {
- auto* call = sem.Get(call_expr);
- if (auto* builtin = call->Target()->As<sem::Builtin>()) {
- if (builtin->Type() == sem::BuiltinType::kArrayLength) {
- // arrayLength(X)
- // Don't convert X into a load, this builtin actually requires the
- // real pointer.
- state.TakeAccess(call_expr->args[0]);
- continue;
- }
- if (builtin->IsAtomic()) {
- if (auto access = state.TakeAccess(call_expr->args[0])) {
- // atomic___(X)
- ctx.Replace(call_expr, [=, &ctx, &state] {
- auto* buf = access.var->Declaration();
- auto* offset = access.offset->Build(ctx);
- auto* buf_ty = access.var->Type()->UnwrapRef();
- auto* el_ty = access.type->UnwrapRef()->As<sem::Atomic>()->Type();
- Symbol func = state.AtomicFunc(
- buf_ty, el_ty, builtin, access.var->As<sem::VariableUser>());
-
- ast::ExpressionList args{ctx.Clone(buf), offset};
- for (size_t i = 1; i < call_expr->args.size(); i++) {
- auto* arg = call_expr->args[i];
- args.emplace_back(ctx.Clone(arg));
- }
- return ctx.dst->Call(func, args);
- });
- }
- }
- }
- }
- }
-
- // All remaining accesses are loads, transform these into calls to the
- // corresponding load function
- for (auto* expr : state.expression_order) {
- auto access_it = state.accesses.find(expr);
- if (access_it == state.accesses.end()) {
- continue;
- }
- BufferAccess access = access_it->second;
- ctx.Replace(expr, [=, &ctx, &state] {
- auto* buf = access.var->Declaration();
- auto* offset = access.offset->Build(ctx);
- auto* buf_ty = access.var->Type()->UnwrapRef();
- auto* el_ty = access.type->UnwrapRef();
- Symbol func =
- state.LoadFunc(buf_ty, el_ty, access.var->As<sem::VariableUser>());
- return ctx.dst->Call(func, ctx.CloneWithoutTransform(buf), offset);
- });
- }
-
- // And replace all storage and uniform buffer assignments with stores
- for (auto store : state.stores) {
- ctx.Replace(store.assignment, [=, &ctx, &state] {
- auto* buf = store.target.var->Declaration();
- auto* offset = store.target.offset->Build(ctx);
- auto* buf_ty = store.target.var->Type()->UnwrapRef();
- auto* el_ty = store.target.type->UnwrapRef();
- auto* value = store.assignment->rhs;
- Symbol func = state.StoreFunc(buf_ty, el_ty,
- store.target.var->As<sem::VariableUser>());
- auto* call = ctx.dst->Call(func, ctx.CloneWithoutTransform(buf), offset,
- ctx.Clone(value));
- return ctx.dst->CallStmt(call);
- });
- }
-
- ctx.Clone();
+ ctx.Clone();
}
} // namespace tint::transform
diff --git a/src/tint/transform/decompose_memory_access.h b/src/tint/transform/decompose_memory_access.h
index 9aa0eb5..7a7b783 100644
--- a/src/tint/transform/decompose_memory_access.h
+++ b/src/tint/transform/decompose_memory_access.h
@@ -30,99 +30,95 @@
/// DecomposeMemoryAccess is a transform used to replace storage and uniform
/// buffer accesses with a combination of load, store or atomic functions on
/// primitive types.
-class DecomposeMemoryAccess final
- : public Castable<DecomposeMemoryAccess, Transform> {
- public:
- /// Intrinsic is an InternalAttribute that's used to decorate a stub function
- /// so that the HLSL transforms this into calls to
- /// `[RW]ByteAddressBuffer.Load[N]()` or `[RW]ByteAddressBuffer.Store[N]()`,
- /// with a possible cast.
- class Intrinsic final : public Castable<Intrinsic, ast::InternalAttribute> {
- public:
- /// Intrinsic op
- enum class Op {
- kLoad,
- kStore,
- kAtomicLoad,
- kAtomicStore,
- kAtomicAdd,
- kAtomicSub,
- kAtomicMax,
- kAtomicMin,
- kAtomicAnd,
- kAtomicOr,
- kAtomicXor,
- kAtomicExchange,
- kAtomicCompareExchangeWeak,
- };
+class DecomposeMemoryAccess final : public Castable<DecomposeMemoryAccess, Transform> {
+ public:
+ /// Intrinsic is an InternalAttribute that's used to decorate a stub function
+ /// so that the HLSL transforms this into calls to
+ /// `[RW]ByteAddressBuffer.Load[N]()` or `[RW]ByteAddressBuffer.Store[N]()`,
+ /// with a possible cast.
+ class Intrinsic final : public Castable<Intrinsic, ast::InternalAttribute> {
+ public:
+ /// Intrinsic op
+ enum class Op {
+ kLoad,
+ kStore,
+ kAtomicLoad,
+ kAtomicStore,
+ kAtomicAdd,
+ kAtomicSub,
+ kAtomicMax,
+ kAtomicMin,
+ kAtomicAnd,
+ kAtomicOr,
+ kAtomicXor,
+ kAtomicExchange,
+ kAtomicCompareExchangeWeak,
+ };
- /// Intrinsic data type
- enum class DataType {
- kU32,
- kF32,
- kI32,
- kVec2U32,
- kVec2F32,
- kVec2I32,
- kVec3U32,
- kVec3F32,
- kVec3I32,
- kVec4U32,
- kVec4F32,
- kVec4I32,
+ /// Intrinsic data type
+ enum class DataType {
+ kU32,
+ kF32,
+ kI32,
+ kVec2U32,
+ kVec2F32,
+ kVec2I32,
+ kVec3U32,
+ kVec3F32,
+ kVec3I32,
+ kVec4U32,
+ kVec4F32,
+ kVec4I32,
+ };
+
+ /// Constructor
+ /// @param program_id the identifier of the program that owns this node
+ /// @param o the op of the intrinsic
+ /// @param sc the storage class of the buffer
+ /// @param ty the data type of the intrinsic
+ Intrinsic(ProgramID program_id, Op o, ast::StorageClass sc, DataType ty);
+ /// Destructor
+ ~Intrinsic() override;
+
+ /// @return a short description of the internal attribute which will be
+ /// displayed as `@internal(<name>)`
+ std::string InternalName() const override;
+
+ /// Performs a deep clone of this object using the CloneContext `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned object
+ const Intrinsic* Clone(CloneContext* ctx) const override;
+
+ /// The op of the intrinsic
+ const Op op;
+
+ /// The storage class of the buffer this intrinsic operates on
+ ast::StorageClass const storage_class;
+
+ /// The type of the intrinsic
+ const DataType type;
};
/// Constructor
- /// @param program_id the identifier of the program that owns this node
- /// @param o the op of the intrinsic
- /// @param sc the storage class of the buffer
- /// @param ty the data type of the intrinsic
- Intrinsic(ProgramID program_id, Op o, ast::StorageClass sc, DataType ty);
+ DecomposeMemoryAccess();
/// Destructor
- ~Intrinsic() override;
+ ~DecomposeMemoryAccess() override;
- /// @return a short description of the internal attribute which will be
- /// displayed as `@internal(<name>)`
- std::string InternalName() const override;
+ /// @param program the program to inspect
+ /// @param data optional extra transform-specific input data
+ /// @returns true if this transform should be run for the given program
+ bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
- /// Performs a deep clone of this object using the CloneContext `ctx`.
- /// @param ctx the clone context
- /// @return the newly cloned object
- const Intrinsic* Clone(CloneContext* ctx) const override;
+ protected:
+ /// Runs the transform using the CloneContext built for transforming a
+ /// program. Run() is responsible for calling Clone() on the CloneContext.
+ /// @param ctx the CloneContext primed with the input program and
+ /// ProgramBuilder
+ /// @param inputs optional extra transform-specific input data
+ /// @param outputs optional extra transform-specific output data
+ void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
- /// The op of the intrinsic
- const Op op;
-
- /// The storage class of the buffer this intrinsic operates on
- ast::StorageClass const storage_class;
-
- /// The type of the intrinsic
- const DataType type;
- };
-
- /// Constructor
- DecomposeMemoryAccess();
- /// Destructor
- ~DecomposeMemoryAccess() override;
-
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program,
- const DataMap& data = {}) const override;
-
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx,
- const DataMap& inputs,
- DataMap& outputs) const override;
-
- struct State;
+ struct State;
};
} // namespace tint::transform
diff --git a/src/tint/transform/decompose_memory_access_test.cc b/src/tint/transform/decompose_memory_access_test.cc
index dadb422..22b5da4 100644
--- a/src/tint/transform/decompose_memory_access_test.cc
+++ b/src/tint/transform/decompose_memory_access_test.cc
@@ -22,35 +22,35 @@
using DecomposeMemoryAccessTest = TransformTest;
TEST_F(DecomposeMemoryAccessTest, ShouldRunEmptyModule) {
- auto* src = R"()";
+ auto* src = R"()";
- EXPECT_FALSE(ShouldRun<DecomposeMemoryAccess>(src));
+ EXPECT_FALSE(ShouldRun<DecomposeMemoryAccess>(src));
}
TEST_F(DecomposeMemoryAccessTest, ShouldRunStorageBuffer) {
- auto* src = R"(
+ auto* src = R"(
struct Buffer {
i : i32,
};
@group(0) @binding(0) var<storage, read_write> sb : Buffer;
)";
- EXPECT_TRUE(ShouldRun<DecomposeMemoryAccess>(src));
+ EXPECT_TRUE(ShouldRun<DecomposeMemoryAccess>(src));
}
TEST_F(DecomposeMemoryAccessTest, ShouldRunUniformBuffer) {
- auto* src = R"(
+ auto* src = R"(
struct Buffer {
i : i32,
};
@group(0) @binding(0) var<uniform> ub : Buffer;
)";
- EXPECT_TRUE(ShouldRun<DecomposeMemoryAccess>(src));
+ EXPECT_TRUE(ShouldRun<DecomposeMemoryAccess>(src));
}
TEST_F(DecomposeMemoryAccessTest, SB_BasicLoad) {
- auto* src = R"(
+ auto* src = R"(
struct SB {
a : i32,
b : u32,
@@ -105,7 +105,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct SB {
a : i32,
b : u32,
@@ -240,13 +240,13 @@
}
)";
- auto got = Run<DecomposeMemoryAccess>(src);
+ auto got = Run<DecomposeMemoryAccess>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeMemoryAccessTest, SB_BasicLoad_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn main() {
var a : i32 = sb.a;
@@ -301,7 +301,7 @@
};
)";
- auto* expect = R"(
+ auto* expect = R"(
@internal(intrinsic_load_storage_i32) @internal(disable_validation__function_has_no_body)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB, offset : u32) -> i32
@@ -436,13 +436,13 @@
}
)";
- auto got = Run<DecomposeMemoryAccess>(src);
+ auto got = Run<DecomposeMemoryAccess>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeMemoryAccessTest, UB_BasicLoad) {
- auto* src = R"(
+ auto* src = R"(
struct UB {
a : i32,
b : u32,
@@ -497,7 +497,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct UB {
a : i32,
b : u32,
@@ -632,13 +632,13 @@
}
)";
- auto got = Run<DecomposeMemoryAccess>(src);
+ auto got = Run<DecomposeMemoryAccess>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeMemoryAccessTest, UB_BasicLoad_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn main() {
var a : i32 = ub.a;
@@ -693,7 +693,7 @@
};
)";
- auto* expect = R"(
+ auto* expect = R"(
@internal(intrinsic_load_uniform_i32) @internal(disable_validation__function_has_no_body)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : UB, offset : u32) -> i32
@@ -828,13 +828,13 @@
}
)";
- auto got = Run<DecomposeMemoryAccess>(src);
+ auto got = Run<DecomposeMemoryAccess>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeMemoryAccessTest, SB_BasicStore) {
- auto* src = R"(
+ auto* src = R"(
struct SB {
a : i32,
b : u32,
@@ -889,7 +889,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct SB {
a : i32,
b : u32,
@@ -1041,13 +1041,13 @@
}
)";
- auto got = Run<DecomposeMemoryAccess>(src);
+ auto got = Run<DecomposeMemoryAccess>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeMemoryAccessTest, SB_BasicStore_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn main() {
sb.a = i32();
@@ -1102,7 +1102,7 @@
};
)";
- auto* expect = R"(
+ auto* expect = R"(
@internal(intrinsic_store_storage_i32) @internal(disable_validation__function_has_no_body)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB, offset : u32, value : i32)
@@ -1254,13 +1254,13 @@
}
)";
- auto got = Run<DecomposeMemoryAccess>(src);
+ auto got = Run<DecomposeMemoryAccess>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeMemoryAccessTest, LoadStructure) {
- auto* src = R"(
+ auto* src = R"(
struct SB {
a : i32,
b : u32,
@@ -1294,7 +1294,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct SB {
a : i32,
b : u32,
@@ -1412,13 +1412,13 @@
}
)";
- auto got = Run<DecomposeMemoryAccess>(src);
+ auto got = Run<DecomposeMemoryAccess>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeMemoryAccessTest, LoadStructure_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn main() {
var x : SB = sb;
@@ -1452,7 +1452,7 @@
};
)";
- auto* expect = R"(
+ auto* expect = R"(
@internal(intrinsic_load_storage_i32) @internal(disable_validation__function_has_no_body)
fn tint_symbol_1(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB, offset : u32) -> i32
@@ -1570,13 +1570,13 @@
}
)";
- auto got = Run<DecomposeMemoryAccess>(src);
+ auto got = Run<DecomposeMemoryAccess>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeMemoryAccessTest, StoreStructure) {
- auto* src = R"(
+ auto* src = R"(
struct SB {
a : i32,
b : u32,
@@ -1610,7 +1610,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct SB {
a : i32,
b : u32,
@@ -1766,13 +1766,13 @@
}
)";
- auto got = Run<DecomposeMemoryAccess>(src);
+ auto got = Run<DecomposeMemoryAccess>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeMemoryAccessTest, StoreStructure_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn main() {
sb = SB();
@@ -1806,7 +1806,7 @@
};
)";
- auto* expect = R"(
+ auto* expect = R"(
@internal(intrinsic_store_storage_i32) @internal(disable_validation__function_has_no_body)
fn tint_symbol_1(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB, offset : u32, value : i32)
@@ -1962,13 +1962,13 @@
}
)";
- auto got = Run<DecomposeMemoryAccess>(src);
+ auto got = Run<DecomposeMemoryAccess>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeMemoryAccessTest, ComplexStaticAccessChain) {
- auto* src = R"(
+ auto* src = R"(
// sizeof(S1) == 32
// alignof(S1) == 16
struct S1 {
@@ -1999,14 +1999,14 @@
}
)";
- // sb.b[4].b[1].b.z
- // ^ ^ ^ ^ ^ ^
- // | | | | | |
- // 128 | |688 | 712
- // | | |
- // 640 656 704
+ // sb.b[4].b[1].b.z
+ // ^ ^ ^ ^ ^ ^
+ // | | | | | |
+ // 128 | |688 | 712
+ // | | |
+ // 640 656 704
- auto* expect = R"(
+ auto* expect = R"(
struct S1 {
a : i32,
b : vec3<f32>,
@@ -2036,13 +2036,13 @@
}
)";
- auto got = Run<DecomposeMemoryAccess>(src);
+ auto got = Run<DecomposeMemoryAccess>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeMemoryAccessTest, ComplexStaticAccessChain_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn main() {
var x : f32 = sb.b[4].b[1].b.z;
@@ -2069,14 +2069,14 @@
};
)";
- // sb.b[4].b[1].b.z
- // ^ ^ ^ ^ ^ ^
- // | | | | | |
- // 128 | |688 | 712
- // | | |
- // 640 656 704
+ // sb.b[4].b[1].b.z
+ // ^ ^ ^ ^ ^ ^
+ // | | | | | |
+ // 128 | |688 | 712
+ // | | |
+ // 640 656 704
- auto* expect = R"(
+ auto* expect = R"(
@internal(intrinsic_load_storage_f32) @internal(disable_validation__function_has_no_body)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB, offset : u32) -> f32
@@ -2106,13 +2106,13 @@
}
)";
- auto got = Run<DecomposeMemoryAccess>(src);
+ auto got = Run<DecomposeMemoryAccess>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeMemoryAccessTest, ComplexDynamicAccessChain) {
- auto* src = R"(
+ auto* src = R"(
struct S1 {
a : i32,
b : vec3<f32>,
@@ -2142,7 +2142,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct S1 {
a : i32,
b : vec3<f32>,
@@ -2175,13 +2175,13 @@
}
)";
- auto got = Run<DecomposeMemoryAccess>(src);
+ auto got = Run<DecomposeMemoryAccess>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeMemoryAccessTest, ComplexDynamicAccessChain_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn main() {
var i : i32 = 4;
@@ -2211,7 +2211,7 @@
};
)";
- auto* expect = R"(
+ auto* expect = R"(
@internal(intrinsic_load_storage_f32) @internal(disable_validation__function_has_no_body)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB, offset : u32) -> f32
@@ -2244,13 +2244,13 @@
}
)";
- auto got = Run<DecomposeMemoryAccess>(src);
+ auto got = Run<DecomposeMemoryAccess>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeMemoryAccessTest, ComplexDynamicAccessChainWithAliases) {
- auto* src = R"(
+ auto* src = R"(
struct S1 {
a : i32,
b : vec3<f32>,
@@ -2288,7 +2288,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct S1 {
a : i32,
b : vec3<f32>,
@@ -2329,14 +2329,13 @@
}
)";
- auto got = Run<DecomposeMemoryAccess>(src);
+ auto got = Run<DecomposeMemoryAccess>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
-TEST_F(DecomposeMemoryAccessTest,
- ComplexDynamicAccessChainWithAliases_OutOfOrder) {
- auto* src = R"(
+TEST_F(DecomposeMemoryAccessTest, ComplexDynamicAccessChainWithAliases_OutOfOrder) {
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn main() {
var i : i32 = 4;
@@ -2374,7 +2373,7 @@
};
)";
- auto* expect = R"(
+ auto* expect = R"(
@internal(intrinsic_load_storage_f32) @internal(disable_validation__function_has_no_body)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB, offset : u32) -> f32
@@ -2415,13 +2414,13 @@
}
)";
- auto got = Run<DecomposeMemoryAccess>(src);
+ auto got = Run<DecomposeMemoryAccess>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeMemoryAccessTest, StorageBufferAtomics) {
- auto* src = R"(
+ auto* src = R"(
struct SB {
padding : vec4<f32>,
a : atomic<i32>,
@@ -2458,7 +2457,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct SB {
padding : vec4<f32>,
a : atomic<i32>,
@@ -2560,13 +2559,13 @@
}
)";
- auto got = Run<DecomposeMemoryAccess>(src);
+ auto got = Run<DecomposeMemoryAccess>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeMemoryAccessTest, StorageBufferAtomics_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn main() {
atomicStore(&sb.a, 123);
@@ -2603,7 +2602,7 @@
};
)";
- auto* expect = R"(
+ auto* expect = R"(
@internal(intrinsic_atomic_store_storage_i32) @internal(disable_validation__function_has_no_body)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB, offset : u32, param_1 : i32)
@@ -2705,13 +2704,13 @@
}
)";
- auto got = Run<DecomposeMemoryAccess>(src);
+ auto got = Run<DecomposeMemoryAccess>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeMemoryAccessTest, WorkgroupBufferAtomics) {
- auto* src = R"(
+ auto* src = R"(
struct S {
padding : vec4<f32>,
a : atomic<i32>,
@@ -2747,15 +2746,15 @@
}
)";
- auto* expect = src;
+ auto* expect = src;
- auto got = Run<DecomposeMemoryAccess>(src);
+ auto got = Run<DecomposeMemoryAccess>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeMemoryAccessTest, WorkgroupBufferAtomics_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn main() {
atomicStore(&(w.a), 123);
@@ -2791,11 +2790,11 @@
}
)";
- auto* expect = src;
+ auto* expect = src;
- auto got = Run<DecomposeMemoryAccess>(src);
+ auto got = Run<DecomposeMemoryAccess>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
} // namespace
diff --git a/src/tint/transform/decompose_strided_array.cc b/src/tint/transform/decompose_strided_array.cc
index 74b6903..bf36c06 100644
--- a/src/tint/transform/decompose_strided_array.cc
+++ b/src/tint/transform/decompose_strided_array.cc
@@ -40,121 +40,115 @@
DecomposeStridedArray::~DecomposeStridedArray() = default;
-bool DecomposeStridedArray::ShouldRun(const Program* program,
- const DataMap&) const {
- for (auto* node : program->ASTNodes().Objects()) {
- if (auto* ast = node->As<ast::Array>()) {
- if (ast::GetAttribute<ast::StrideAttribute>(ast->attributes)) {
- return true;
- }
+bool DecomposeStridedArray::ShouldRun(const Program* program, const DataMap&) const {
+ for (auto* node : program->ASTNodes().Objects()) {
+ if (auto* ast = node->As<ast::Array>()) {
+ if (ast::GetAttribute<ast::StrideAttribute>(ast->attributes)) {
+ return true;
+ }
+ }
}
- }
- return false;
+ return false;
}
-void DecomposeStridedArray::Run(CloneContext& ctx,
- const DataMap&,
- DataMap&) const {
- const auto& sem = ctx.src->Sem();
+void DecomposeStridedArray::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
+ const auto& sem = ctx.src->Sem();
- static constexpr const char* kMemberName = "el";
+ static constexpr const char* kMemberName = "el";
- // Maps an array type in the source program to the name of the struct wrapper
- // type in the target program.
- std::unordered_map<const sem::Array*, Symbol> decomposed;
+ // Maps an array type in the source program to the name of the struct wrapper
+ // type in the target program.
+ std::unordered_map<const sem::Array*, Symbol> decomposed;
- // Find and replace all arrays with a @stride attribute with a array that has
- // the @stride removed. If the source array stride does not match the natural
- // stride for the array element type, then replace the array element type with
- // a structure, holding a single field with a @size attribute equal to the
- // array stride.
- ctx.ReplaceAll([&](const ast::Array* ast) -> const ast::Array* {
- if (auto* arr = sem.Get(ast)) {
- if (!arr->IsStrideImplicit()) {
- auto el_ty = utils::GetOrCreate(decomposed, arr, [&] {
- auto name = ctx.dst->Symbols().New("strided_arr");
- auto* member_ty = ctx.Clone(ast->type);
- auto* member = ctx.dst->Member(kMemberName, member_ty,
- {ctx.dst->MemberSize(arr->Stride())});
- ctx.dst->Structure(name, {member});
- return name;
- });
- auto* count = ctx.Clone(ast->count);
- return ctx.dst->ty.array(ctx.dst->ty.type_name(el_ty), count);
- }
- if (ast::GetAttribute<ast::StrideAttribute>(ast->attributes)) {
- // Strip the @stride attribute
- auto* ty = ctx.Clone(ast->type);
- auto* count = ctx.Clone(ast->count);
- return ctx.dst->ty.array(ty, count);
- }
- }
- return nullptr;
- });
-
- // Find all array index-accessors expressions for arrays that have had their
- // element changed to a single field structure. These expressions are adjusted
- // to insert an additional member accessor for the single structure field.
- // Example: `arr[i]` -> `arr[i].el`
- ctx.ReplaceAll(
- [&](const ast::IndexAccessorExpression* idx) -> const ast::Expression* {
- if (auto* ty = ctx.src->TypeOf(idx->object)) {
- if (auto* arr = ty->UnwrapRef()->As<sem::Array>()) {
+ // Find and replace all arrays with a @stride attribute with an array that has
+ // the @stride removed. If the source array stride does not match the natural
+ // stride for the array element type, then replace the array element type with
+ // a structure, holding a single field with a @size attribute equal to the
+ // array stride.
+ ctx.ReplaceAll([&](const ast::Array* ast) -> const ast::Array* {
+ if (auto* arr = sem.Get(ast)) {
if (!arr->IsStrideImplicit()) {
- auto* expr = ctx.CloneWithoutTransform(idx);
- return ctx.dst->MemberAccessor(expr, kMemberName);
+ auto el_ty = utils::GetOrCreate(decomposed, arr, [&] {
+ auto name = ctx.dst->Symbols().New("strided_arr");
+ auto* member_ty = ctx.Clone(ast->type);
+ auto* member = ctx.dst->Member(kMemberName, member_ty,
+ {ctx.dst->MemberSize(arr->Stride())});
+ ctx.dst->Structure(name, {member});
+ return name;
+ });
+ auto* count = ctx.Clone(ast->count);
+ return ctx.dst->ty.array(ctx.dst->ty.type_name(el_ty), count);
}
- }
+ if (ast::GetAttribute<ast::StrideAttribute>(ast->attributes)) {
+ // Strip the @stride attribute
+ auto* ty = ctx.Clone(ast->type);
+ auto* count = ctx.Clone(ast->count);
+ return ctx.dst->ty.array(ty, count);
+ }
}
return nullptr;
- });
+ });
- // Find all array type constructor expressions for array types that have had
- // their element changed to a single field structure. These constructors are
- // adjusted to wrap each of the arguments with an additional constructor for
- // the new element structure type.
- // Example:
- // `@stride(32) array<i32, 3>(1, 2, 3)`
- // ->
- // `array<strided_arr, 3>(strided_arr(1), strided_arr(2), strided_arr(3))`
- ctx.ReplaceAll(
- [&](const ast::CallExpression* expr) -> const ast::Expression* {
+ // Find all array index-accessors expressions for arrays that have had their
+ // element changed to a single field structure. These expressions are adjusted
+ // to insert an additional member accessor for the single structure field.
+ // Example: `arr[i]` -> `arr[i].el`
+ ctx.ReplaceAll([&](const ast::IndexAccessorExpression* idx) -> const ast::Expression* {
+ if (auto* ty = ctx.src->TypeOf(idx->object)) {
+ if (auto* arr = ty->UnwrapRef()->As<sem::Array>()) {
+ if (!arr->IsStrideImplicit()) {
+ auto* expr = ctx.CloneWithoutTransform(idx);
+ return ctx.dst->MemberAccessor(expr, kMemberName);
+ }
+ }
+ }
+ return nullptr;
+ });
+
+ // Find all array type constructor expressions for array types that have had
+ // their element changed to a single field structure. These constructors are
+ // adjusted to wrap each of the arguments with an additional constructor for
+ // the new element structure type.
+ // Example:
+ // `@stride(32) array<i32, 3>(1, 2, 3)`
+ // ->
+ // `array<strided_arr, 3>(strided_arr(1), strided_arr(2), strided_arr(3))`
+ ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::Expression* {
if (!expr->args.empty()) {
- if (auto* call = sem.Get(expr)) {
- if (auto* ctor = call->Target()->As<sem::TypeConstructor>()) {
- if (auto* arr = ctor->ReturnType()->As<sem::Array>()) {
- // Begin by cloning the array constructor type or name
- // If this is an unaliased array, this may add a new entry to
- // decomposed.
- // If this is an aliased array, decomposed should already be
- // populated with any strided aliases.
- ast::CallExpression::Target target;
- if (expr->target.type) {
- target.type = ctx.Clone(expr->target.type);
- } else {
- target.name = ctx.Clone(expr->target.name);
- }
+ if (auto* call = sem.Get(expr)) {
+ if (auto* ctor = call->Target()->As<sem::TypeConstructor>()) {
+ if (auto* arr = ctor->ReturnType()->As<sem::Array>()) {
+ // Begin by cloning the array constructor type or name
+ // If this is an unaliased array, this may add a new entry to
+ // decomposed.
+ // If this is an aliased array, decomposed should already be
+ // populated with any strided aliases.
+ ast::CallExpression::Target target;
+ if (expr->target.type) {
+ target.type = ctx.Clone(expr->target.type);
+ } else {
+ target.name = ctx.Clone(expr->target.name);
+ }
- ast::ExpressionList args;
- if (auto it = decomposed.find(arr); it != decomposed.end()) {
- args.reserve(expr->args.size());
- for (auto* arg : expr->args) {
- args.emplace_back(
- ctx.dst->Call(it->second, ctx.Clone(arg)));
- }
- } else {
- args = ctx.Clone(expr->args);
- }
+ ast::ExpressionList args;
+ if (auto it = decomposed.find(arr); it != decomposed.end()) {
+ args.reserve(expr->args.size());
+ for (auto* arg : expr->args) {
+ args.emplace_back(ctx.dst->Call(it->second, ctx.Clone(arg)));
+ }
+ } else {
+ args = ctx.Clone(expr->args);
+ }
- return target.type ? ctx.dst->Construct(target.type, args)
- : ctx.dst->Call(target.name, args);
- }
+ return target.type ? ctx.dst->Construct(target.type, args)
+ : ctx.dst->Call(target.name, args);
+ }
+ }
}
- }
}
return nullptr;
- });
- ctx.Clone();
+ });
+ ctx.Clone();
}
} // namespace tint::transform
diff --git a/src/tint/transform/decompose_strided_array.h b/src/tint/transform/decompose_strided_array.h
index 505f5cb..5dbaaa5 100644
--- a/src/tint/transform/decompose_strided_array.h
+++ b/src/tint/transform/decompose_strided_array.h
@@ -27,31 +27,27 @@
///
/// @note Depends on the following transforms to have been run first:
/// * SimplifyPointers
-class DecomposeStridedArray final
- : public Castable<DecomposeStridedArray, Transform> {
- public:
- /// Constructor
- DecomposeStridedArray();
+class DecomposeStridedArray final : public Castable<DecomposeStridedArray, Transform> {
+ public:
+ /// Constructor
+ DecomposeStridedArray();
- /// Destructor
- ~DecomposeStridedArray() override;
+ /// Destructor
+ ~DecomposeStridedArray() override;
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program,
- const DataMap& data = {}) const override;
+ /// @param program the program to inspect
+ /// @param data optional extra transform-specific input data
+ /// @returns true if this transform should be run for the given program
+ bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx,
- const DataMap& inputs,
- DataMap& outputs) const override;
+ protected:
+ /// Runs the transform using the CloneContext built for transforming a
+ /// program. Run() is responsible for calling Clone() on the CloneContext.
+ /// @param ctx the CloneContext primed with the input program and
+ /// ProgramBuilder
+ /// @param inputs optional extra transform-specific input data
+ /// @param outputs optional extra transform-specific output data
+ void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/decompose_strided_array_test.cc b/src/tint/transform/decompose_strided_array_test.cc
index 16e08fb..1891527 100644
--- a/src/tint/transform/decompose_strided_array_test.cc
+++ b/src/tint/transform/decompose_strided_array_test.cc
@@ -30,65 +30,65 @@
using f32 = ProgramBuilder::f32;
TEST_F(DecomposeStridedArrayTest, ShouldRunEmptyModule) {
- ProgramBuilder b;
- EXPECT_FALSE(ShouldRun<DecomposeStridedArray>(Program(std::move(b))));
+ ProgramBuilder b;
+ EXPECT_FALSE(ShouldRun<DecomposeStridedArray>(Program(std::move(b))));
}
TEST_F(DecomposeStridedArrayTest, ShouldRunNonStridedArray) {
- // var<private> arr : array<f32, 4>
+ // var<private> arr : array<f32, 4>
- ProgramBuilder b;
- b.Global("arr", b.ty.array<f32, 4>(), ast::StorageClass::kPrivate);
- EXPECT_FALSE(ShouldRun<DecomposeStridedArray>(Program(std::move(b))));
+ ProgramBuilder b;
+ b.Global("arr", b.ty.array<f32, 4>(), ast::StorageClass::kPrivate);
+ EXPECT_FALSE(ShouldRun<DecomposeStridedArray>(Program(std::move(b))));
}
TEST_F(DecomposeStridedArrayTest, ShouldRunDefaultStridedArray) {
- // var<private> arr : @stride(4) array<f32, 4>
+ // var<private> arr : @stride(4) array<f32, 4>
- ProgramBuilder b;
- b.Global("arr", b.ty.array<f32, 4>(4), ast::StorageClass::kPrivate);
- EXPECT_TRUE(ShouldRun<DecomposeStridedArray>(Program(std::move(b))));
+ ProgramBuilder b;
+ b.Global("arr", b.ty.array<f32, 4>(4), ast::StorageClass::kPrivate);
+ EXPECT_TRUE(ShouldRun<DecomposeStridedArray>(Program(std::move(b))));
}
TEST_F(DecomposeStridedArrayTest, ShouldRunExplicitStridedArray) {
- // var<private> arr : @stride(16) array<f32, 4>
+ // var<private> arr : @stride(16) array<f32, 4>
- ProgramBuilder b;
- b.Global("arr", b.ty.array<f32, 4>(16), ast::StorageClass::kPrivate);
- EXPECT_TRUE(ShouldRun<DecomposeStridedArray>(Program(std::move(b))));
+ ProgramBuilder b;
+ b.Global("arr", b.ty.array<f32, 4>(16), ast::StorageClass::kPrivate);
+ EXPECT_TRUE(ShouldRun<DecomposeStridedArray>(Program(std::move(b))));
}
TEST_F(DecomposeStridedArrayTest, Empty) {
- auto* src = R"()";
- auto* expect = src;
+ auto* src = R"()";
+ auto* expect = src;
- auto got = Run<DecomposeStridedArray>(src);
+ auto got = Run<DecomposeStridedArray>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedArrayTest, PrivateDefaultStridedArray) {
- // var<private> arr : @stride(4) array<f32, 4>
- //
- // @stage(compute) @workgroup_size(1)
- // fn f() {
- // let a : @stride(4) array<f32, 4> = a;
- // let b : f32 = arr[1];
- // }
+ // var<private> arr : @stride(4) array<f32, 4>
+ //
+ // @stage(compute) @workgroup_size(1)
+ // fn f() {
+ // let a : @stride(4) array<f32, 4> = arr;
+ // let b : f32 = arr[1];
+ // }
- ProgramBuilder b;
- b.Global("arr", b.ty.array<f32, 4>(4), ast::StorageClass::kPrivate);
- b.Func("f", {}, b.ty.void_(),
- {
- b.Decl(b.Let("a", b.ty.array<f32, 4>(4), b.Expr("arr"))),
- b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor("arr", 1))),
- },
- {
- b.Stage(ast::PipelineStage::kCompute),
- b.WorkgroupSize(1),
- });
+ ProgramBuilder b;
+ b.Global("arr", b.ty.array<f32, 4>(4), ast::StorageClass::kPrivate);
+ b.Func("f", {}, b.ty.void_(),
+ {
+ b.Decl(b.Let("a", b.ty.array<f32, 4>(4), b.Expr("arr"))),
+ b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor("arr", 1))),
+ },
+ {
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1),
+ });
- auto* expect = R"(
+ auto* expect = R"(
var<private> arr : array<f32, 4>;
@stage(compute) @workgroup_size(1)
@@ -98,34 +98,33 @@
}
)";
- auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
- Program(std::move(b)));
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(Program(std::move(b)));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedArrayTest, PrivateStridedArray) {
- // var<private> arr : @stride(32) array<f32, 4>
- //
- // @stage(compute) @workgroup_size(1)
- // fn f() {
- // let a : @stride(32) array<f32, 4> = a;
- // let b : f32 = arr[1];
- // }
+ // var<private> arr : @stride(32) array<f32, 4>
+ //
+ // @stage(compute) @workgroup_size(1)
+ // fn f() {
+ // let a : @stride(32) array<f32, 4> = arr;
+ // let b : f32 = arr[1];
+ // }
- ProgramBuilder b;
- b.Global("arr", b.ty.array<f32, 4>(32), ast::StorageClass::kPrivate);
- b.Func("f", {}, b.ty.void_(),
- {
- b.Decl(b.Let("a", b.ty.array<f32, 4>(32), b.Expr("arr"))),
- b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor("arr", 1))),
- },
- {
- b.Stage(ast::PipelineStage::kCompute),
- b.WorkgroupSize(1),
- });
+ ProgramBuilder b;
+ b.Global("arr", b.ty.array<f32, 4>(32), ast::StorageClass::kPrivate);
+ b.Func("f", {}, b.ty.void_(),
+ {
+ b.Decl(b.Let("a", b.ty.array<f32, 4>(32), b.Expr("arr"))),
+ b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor("arr", 1))),
+ },
+ {
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1),
+ });
- auto* expect = R"(
+ auto* expect = R"(
struct strided_arr {
@size(32)
el : f32,
@@ -140,40 +139,36 @@
}
)";
- auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
- Program(std::move(b)));
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(Program(std::move(b)));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedArrayTest, ReadUniformStridedArray) {
- // struct S {
- // a : @stride(32) array<f32, 4>,
- // };
- // @group(0) @binding(0) var<uniform> s : S;
- //
- // @stage(compute) @workgroup_size(1)
- // fn f() {
- // let a : @stride(32) array<f32, 4> = s.a;
- // let b : f32 = s.a[1];
- // }
- ProgramBuilder b;
- auto* S = b.Structure("S", {b.Member("a", b.ty.array<f32, 4>(32))});
- b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform,
- b.GroupAndBinding(0, 0));
- b.Func("f", {}, b.ty.void_(),
- {
- b.Decl(b.Let("a", b.ty.array<f32, 4>(32),
- b.MemberAccessor("s", "a"))),
- b.Decl(b.Let("b", b.ty.f32(),
- b.IndexAccessor(b.MemberAccessor("s", "a"), 1))),
- },
- {
- b.Stage(ast::PipelineStage::kCompute),
- b.WorkgroupSize(1),
- });
+ // struct S {
+ // a : @stride(32) array<f32, 4>,
+ // };
+ // @group(0) @binding(0) var<uniform> s : S;
+ //
+ // @stage(compute) @workgroup_size(1)
+ // fn f() {
+ // let a : @stride(32) array<f32, 4> = s.a;
+ // let b : f32 = s.a[1];
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure("S", {b.Member("a", b.ty.array<f32, 4>(32))});
+ b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform, b.GroupAndBinding(0, 0));
+ b.Func("f", {}, b.ty.void_(),
+ {
+ b.Decl(b.Let("a", b.ty.array<f32, 4>(32), b.MemberAccessor("s", "a"))),
+ b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor(b.MemberAccessor("s", "a"), 1))),
+ },
+ {
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1),
+ });
- auto* expect = R"(
+ auto* expect = R"(
struct strided_arr {
@size(32)
el : f32,
@@ -192,44 +187,38 @@
}
)";
- auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
- Program(std::move(b)));
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(Program(std::move(b)));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedArrayTest, ReadUniformDefaultStridedArray) {
- // struct S {
- // a : @stride(16) array<vec4<f32>, 4>,
- // };
- // @group(0) @binding(0) var<uniform> s : S;
- //
- // @stage(compute) @workgroup_size(1)
- // fn f() {
- // let a : @stride(16) array<vec4<f32>, 4> = s.a;
- // let b : f32 = s.a[1][2];
- // }
- ProgramBuilder b;
- auto* S =
- b.Structure("S", {b.Member("a", b.ty.array(b.ty.vec4<f32>(), 4, 16))});
- b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform,
- b.GroupAndBinding(0, 0));
- b.Func(
- "f", {}, b.ty.void_(),
- {
- b.Decl(b.Let("a", b.ty.array(b.ty.vec4<f32>(), 4, 16),
- b.MemberAccessor("s", "a"))),
- b.Decl(b.Let("b", b.ty.f32(),
- b.IndexAccessor(
- b.IndexAccessor(b.MemberAccessor("s", "a"), 1), 2))),
- },
- {
- b.Stage(ast::PipelineStage::kCompute),
- b.WorkgroupSize(1),
- });
+ // struct S {
+ // a : @stride(16) array<vec4<f32>, 4>,
+ // };
+ // @group(0) @binding(0) var<uniform> s : S;
+ //
+ // @stage(compute) @workgroup_size(1)
+ // fn f() {
+ // let a : @stride(16) array<vec4<f32>, 4> = s.a;
+ // let b : f32 = s.a[1][2];
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure("S", {b.Member("a", b.ty.array(b.ty.vec4<f32>(), 4, 16))});
+ b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform, b.GroupAndBinding(0, 0));
+ b.Func("f", {}, b.ty.void_(),
+ {
+ b.Decl(b.Let("a", b.ty.array(b.ty.vec4<f32>(), 4, 16), b.MemberAccessor("s", "a"))),
+ b.Decl(b.Let("b", b.ty.f32(),
+ b.IndexAccessor(b.IndexAccessor(b.MemberAccessor("s", "a"), 1), 2))),
+ },
+ {
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1),
+ });
- auto* expect =
- R"(
+ auto* expect =
+ R"(
struct S {
a : array<vec4<f32>, 4>,
}
@@ -243,40 +232,36 @@
}
)";
- auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
- Program(std::move(b)));
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(Program(std::move(b)));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedArrayTest, ReadStorageStridedArray) {
- // struct S {
- // a : @stride(32) array<f32, 4>,
- // };
- // @group(0) @binding(0) var<storage> s : S;
- //
- // @stage(compute) @workgroup_size(1)
- // fn f() {
- // let a : @stride(32) array<f32, 4> = s.a;
- // let b : f32 = s.a[1];
- // }
- ProgramBuilder b;
- auto* S = b.Structure("S", {b.Member("a", b.ty.array<f32, 4>(32))});
- b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
- b.GroupAndBinding(0, 0));
- b.Func("f", {}, b.ty.void_(),
- {
- b.Decl(b.Let("a", b.ty.array<f32, 4>(32),
- b.MemberAccessor("s", "a"))),
- b.Decl(b.Let("b", b.ty.f32(),
- b.IndexAccessor(b.MemberAccessor("s", "a"), 1))),
- },
- {
- b.Stage(ast::PipelineStage::kCompute),
- b.WorkgroupSize(1),
- });
+ // struct S {
+ // a : @stride(32) array<f32, 4>,
+ // };
+ // @group(0) @binding(0) var<storage> s : S;
+ //
+ // @stage(compute) @workgroup_size(1)
+ // fn f() {
+ // let a : @stride(32) array<f32, 4> = s.a;
+ // let b : f32 = s.a[1];
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure("S", {b.Member("a", b.ty.array<f32, 4>(32))});
+ b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage, b.GroupAndBinding(0, 0));
+ b.Func("f", {}, b.ty.void_(),
+ {
+ b.Decl(b.Let("a", b.ty.array<f32, 4>(32), b.MemberAccessor("s", "a"))),
+ b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor(b.MemberAccessor("s", "a"), 1))),
+ },
+ {
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1),
+ });
- auto* expect = R"(
+ auto* expect = R"(
struct strided_arr {
@size(32)
el : f32,
@@ -295,40 +280,36 @@
}
)";
- auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
- Program(std::move(b)));
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(Program(std::move(b)));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedArrayTest, ReadStorageDefaultStridedArray) {
- // struct S {
- // a : @stride(4) array<f32, 4>,
- // };
- // @group(0) @binding(0) var<storage> s : S;
- //
- // @stage(compute) @workgroup_size(1)
- // fn f() {
- // let a : @stride(4) array<f32, 4> = s.a;
- // let b : f32 = s.a[1];
- // }
- ProgramBuilder b;
- auto* S = b.Structure("S", {b.Member("a", b.ty.array<f32, 4>(4))});
- b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
- b.GroupAndBinding(0, 0));
- b.Func(
- "f", {}, b.ty.void_(),
- {
- b.Decl(b.Let("a", b.ty.array<f32, 4>(4), b.MemberAccessor("s", "a"))),
- b.Decl(b.Let("b", b.ty.f32(),
- b.IndexAccessor(b.MemberAccessor("s", "a"), 1))),
- },
- {
- b.Stage(ast::PipelineStage::kCompute),
- b.WorkgroupSize(1),
- });
+ // struct S {
+ // a : @stride(4) array<f32, 4>,
+ // };
+ // @group(0) @binding(0) var<storage> s : S;
+ //
+ // @stage(compute) @workgroup_size(1)
+ // fn f() {
+ // let a : @stride(4) array<f32, 4> = s.a;
+ // let b : f32 = s.a[1];
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure("S", {b.Member("a", b.ty.array<f32, 4>(4))});
+ b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage, b.GroupAndBinding(0, 0));
+ b.Func("f", {}, b.ty.void_(),
+ {
+ b.Decl(b.Let("a", b.ty.array<f32, 4>(4), b.MemberAccessor("s", "a"))),
+ b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor(b.MemberAccessor("s", "a"), 1))),
+ },
+ {
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1),
+ });
- auto* expect = R"(
+ auto* expect = R"(
struct S {
a : array<f32, 4>,
}
@@ -342,44 +323,41 @@
}
)";
- auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
- Program(std::move(b)));
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(Program(std::move(b)));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedArrayTest, WriteStorageStridedArray) {
- // struct S {
- // a : @stride(32) array<f32, 4>,
- // };
- // @group(0) @binding(0) var<storage, read_write> s : S;
- //
- // @stage(compute) @workgroup_size(1)
- // fn f() {
- // s.a = @stride(32) array<f32, 4>();
- // s.a = @stride(32) array<f32, 4>(1.0, 2.0, 3.0, 4.0);
- // s.a[1] = 5.0;
- // }
- ProgramBuilder b;
- auto* S = b.Structure("S", {b.Member("a", b.ty.array<f32, 4>(32))});
- b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
- ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
- b.Func(
- "f", {}, b.ty.void_(),
- {
- b.Assign(b.MemberAccessor("s", "a"),
- b.Construct(b.ty.array<f32, 4>(32))),
- b.Assign(b.MemberAccessor("s", "a"),
- b.Construct(b.ty.array<f32, 4>(32), 1.0f, 2.0f, 3.0f, 4.0f)),
- b.Assign(b.IndexAccessor(b.MemberAccessor("s", "a"), 1), 5.0f),
- },
- {
- b.Stage(ast::PipelineStage::kCompute),
- b.WorkgroupSize(1),
- });
+ // struct S {
+ // a : @stride(32) array<f32, 4>,
+ // };
+ // @group(0) @binding(0) var<storage, read_write> s : S;
+ //
+ // @stage(compute) @workgroup_size(1)
+ // fn f() {
+ // s.a = @stride(32) array<f32, 4>();
+ // s.a = @stride(32) array<f32, 4>(1.0, 2.0, 3.0, 4.0);
+ // s.a[1] = 5.0;
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure("S", {b.Member("a", b.ty.array<f32, 4>(32))});
+ b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage, ast::Access::kReadWrite,
+ b.GroupAndBinding(0, 0));
+ b.Func("f", {}, b.ty.void_(),
+ {
+ b.Assign(b.MemberAccessor("s", "a"), b.Construct(b.ty.array<f32, 4>(32))),
+ b.Assign(b.MemberAccessor("s", "a"),
+ b.Construct(b.ty.array<f32, 4>(32), 1.0f, 2.0f, 3.0f, 4.0f)),
+ b.Assign(b.IndexAccessor(b.MemberAccessor("s", "a"), 1), 5.0f),
+ },
+ {
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1),
+ });
- auto* expect =
- R"(
+ auto* expect =
+ R"(
struct strided_arr {
@size(32)
el : f32,
@@ -399,44 +377,41 @@
}
)";
- auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
- Program(std::move(b)));
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(Program(std::move(b)));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedArrayTest, WriteStorageDefaultStridedArray) {
- // struct S {
- // a : @stride(4) array<f32, 4>,
- // };
- // @group(0) @binding(0) var<storage, read_write> s : S;
- //
- // @stage(compute) @workgroup_size(1)
- // fn f() {
- // s.a = @stride(4) array<f32, 4>();
- // s.a = @stride(4) array<f32, 4>(1.0, 2.0, 3.0, 4.0);
- // s.a[1] = 5.0;
- // }
- ProgramBuilder b;
- auto* S = b.Structure("S", {b.Member("a", b.ty.array<f32, 4>(4))});
- b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
- ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
- b.Func(
- "f", {}, b.ty.void_(),
- {
- b.Assign(b.MemberAccessor("s", "a"),
- b.Construct(b.ty.array<f32, 4>(4))),
- b.Assign(b.MemberAccessor("s", "a"),
- b.Construct(b.ty.array<f32, 4>(4), 1.0f, 2.0f, 3.0f, 4.0f)),
- b.Assign(b.IndexAccessor(b.MemberAccessor("s", "a"), 1), 5.0f),
- },
- {
- b.Stage(ast::PipelineStage::kCompute),
- b.WorkgroupSize(1),
- });
+ // struct S {
+ // a : @stride(4) array<f32, 4>,
+ // };
+ // @group(0) @binding(0) var<storage, read_write> s : S;
+ //
+ // @stage(compute) @workgroup_size(1)
+ // fn f() {
+ // s.a = @stride(4) array<f32, 4>();
+ // s.a = @stride(4) array<f32, 4>(1.0, 2.0, 3.0, 4.0);
+ // s.a[1] = 5.0;
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure("S", {b.Member("a", b.ty.array<f32, 4>(4))});
+ b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage, ast::Access::kReadWrite,
+ b.GroupAndBinding(0, 0));
+ b.Func("f", {}, b.ty.void_(),
+ {
+ b.Assign(b.MemberAccessor("s", "a"), b.Construct(b.ty.array<f32, 4>(4))),
+ b.Assign(b.MemberAccessor("s", "a"),
+ b.Construct(b.ty.array<f32, 4>(4), 1.0f, 2.0f, 3.0f, 4.0f)),
+ b.Assign(b.IndexAccessor(b.MemberAccessor("s", "a"), 1), 5.0f),
+ },
+ {
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1),
+ });
- auto* expect =
- R"(
+ auto* expect =
+ R"(
struct S {
a : array<f32, 4>,
}
@@ -451,50 +426,46 @@
}
)";
- auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
- Program(std::move(b)));
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(Program(std::move(b)));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedArrayTest, ReadWriteViaPointerLets) {
- // struct S {
- // a : @stride(32) array<f32, 4>,
- // };
- // @group(0) @binding(0) var<storage, read_write> s : S;
- //
- // @stage(compute) @workgroup_size(1)
- // fn f() {
- // let a = &s.a;
- // let b = &*&*(a);
- // let c = *b;
- // let d = (*b)[1];
- // (*b) = @stride(32) array<f32, 4>(1.0, 2.0, 3.0, 4.0);
- // (*b)[1] = 5.0;
- // }
- ProgramBuilder b;
- auto* S = b.Structure("S", {b.Member("a", b.ty.array<f32, 4>(32))});
- b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
- ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
- b.Func(
- "f", {}, b.ty.void_(),
- {
- b.Decl(b.Let("a", nullptr, b.AddressOf(b.MemberAccessor("s", "a")))),
- b.Decl(b.Let("b", nullptr,
- b.AddressOf(b.Deref(b.AddressOf(b.Deref("a")))))),
- b.Decl(b.Let("c", nullptr, b.Deref("b"))),
- b.Decl(b.Let("d", nullptr, b.IndexAccessor(b.Deref("b"), 1))),
- b.Assign(b.Deref("b"),
- b.Construct(b.ty.array<f32, 4>(32), 1.0f, 2.0f, 3.0f, 4.0f)),
- b.Assign(b.IndexAccessor(b.Deref("b"), 1), 5.0f),
- },
- {
- b.Stage(ast::PipelineStage::kCompute),
- b.WorkgroupSize(1),
- });
+ // struct S {
+ // a : @stride(32) array<f32, 4>,
+ // };
+ // @group(0) @binding(0) var<storage, read_write> s : S;
+ //
+ // @stage(compute) @workgroup_size(1)
+ // fn f() {
+ // let a = &s.a;
+ // let b = &*&*(a);
+ // let c = *b;
+ // let d = (*b)[1];
+ // (*b) = @stride(32) array<f32, 4>(1.0, 2.0, 3.0, 4.0);
+ // (*b)[1] = 5.0;
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure("S", {b.Member("a", b.ty.array<f32, 4>(32))});
+ b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage, ast::Access::kReadWrite,
+ b.GroupAndBinding(0, 0));
+ b.Func("f", {}, b.ty.void_(),
+ {
+ b.Decl(b.Let("a", nullptr, b.AddressOf(b.MemberAccessor("s", "a")))),
+ b.Decl(b.Let("b", nullptr, b.AddressOf(b.Deref(b.AddressOf(b.Deref("a")))))),
+ b.Decl(b.Let("c", nullptr, b.Deref("b"))),
+ b.Decl(b.Let("d", nullptr, b.IndexAccessor(b.Deref("b"), 1))),
+ b.Assign(b.Deref("b"), b.Construct(b.ty.array<f32, 4>(32), 1.0f, 2.0f, 3.0f, 4.0f)),
+ b.Assign(b.IndexAccessor(b.Deref("b"), 1), 5.0f),
+ },
+ {
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1),
+ });
- auto* expect =
- R"(
+ auto* expect =
+ R"(
struct strided_arr {
@size(32)
el : f32,
@@ -515,50 +486,46 @@
}
)";
- auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
- Program(std::move(b)));
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(Program(std::move(b)));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedArrayTest, PrivateAliasedStridedArray) {
- // type ARR = @stride(32) array<f32, 4>;
- // struct S {
- // a : ARR,
- // };
- // @group(0) @binding(0) var<storage, read_write> s : S;
- //
- // @stage(compute) @workgroup_size(1)
- // fn f() {
- // let a : ARR = s.a;
- // let b : f32 = s.a[1];
- // s.a = ARR();
- // s.a = ARR(1.0, 2.0, 3.0, 4.0);
- // s.a[1] = 5.0;
- // }
- ProgramBuilder b;
- b.Alias("ARR", b.ty.array<f32, 4>(32));
- auto* S = b.Structure("S", {b.Member("a", b.ty.type_name("ARR"))});
- b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
- ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
- b.Func(
- "f", {}, b.ty.void_(),
- {
- b.Decl(b.Let("a", b.ty.type_name("ARR"), b.MemberAccessor("s", "a"))),
- b.Decl(b.Let("b", b.ty.f32(),
- b.IndexAccessor(b.MemberAccessor("s", "a"), 1))),
- b.Assign(b.MemberAccessor("s", "a"),
- b.Construct(b.ty.type_name("ARR"))),
- b.Assign(b.MemberAccessor("s", "a"),
- b.Construct(b.ty.type_name("ARR"), 1.0f, 2.0f, 3.0f, 4.0f)),
- b.Assign(b.IndexAccessor(b.MemberAccessor("s", "a"), 1), 5.0f),
- },
- {
- b.Stage(ast::PipelineStage::kCompute),
- b.WorkgroupSize(1),
- });
+ // type ARR = @stride(32) array<f32, 4>;
+ // struct S {
+ // a : ARR,
+ // };
+ // @group(0) @binding(0) var<storage, read_write> s : S;
+ //
+ // @stage(compute) @workgroup_size(1)
+ // fn f() {
+ // let a : ARR = s.a;
+ // let b : f32 = s.a[1];
+ // s.a = ARR();
+ // s.a = ARR(1.0, 2.0, 3.0, 4.0);
+ // s.a[1] = 5.0;
+ // }
+ ProgramBuilder b;
+ b.Alias("ARR", b.ty.array<f32, 4>(32));
+ auto* S = b.Structure("S", {b.Member("a", b.ty.type_name("ARR"))});
+ b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage, ast::Access::kReadWrite,
+ b.GroupAndBinding(0, 0));
+ b.Func("f", {}, b.ty.void_(),
+ {
+ b.Decl(b.Let("a", b.ty.type_name("ARR"), b.MemberAccessor("s", "a"))),
+ b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor(b.MemberAccessor("s", "a"), 1))),
+ b.Assign(b.MemberAccessor("s", "a"), b.Construct(b.ty.type_name("ARR"))),
+ b.Assign(b.MemberAccessor("s", "a"),
+ b.Construct(b.ty.type_name("ARR"), 1.0f, 2.0f, 3.0f, 4.0f)),
+ b.Assign(b.IndexAccessor(b.MemberAccessor("s", "a"), 1), 5.0f),
+ },
+ {
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1),
+ });
- auto* expect = R"(
+ auto* expect = R"(
struct strided_arr {
@size(32)
el : f32,
@@ -582,79 +549,76 @@
}
)";
- auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
- Program(std::move(b)));
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(Program(std::move(b)));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedArrayTest, PrivateNestedStridedArray) {
- // type ARR_A = @stride(8) array<f32, 2>;
- // type ARR_B = @stride(128) array<@stride(16) array<ARR_A, 3>, 4>;
- // struct S {
- // a : ARR_B,
- // };
- // @group(0) @binding(0) var<storage, read_write> s : S;
- //
- // @stage(compute) @workgroup_size(1)
- // fn f() {
- // let a : ARR_B = s.a;
- // let b : array<@stride(8) array<f32, 2>, 3> = s.a[3];
- // let c = s.a[3][2];
- // let d = s.a[3][2][1];
- // s.a = ARR_B();
- // s.a[3][2][1] = 5.0;
- // }
+ // type ARR_A = @stride(8) array<f32, 2>;
+ // type ARR_B = @stride(128) array<@stride(16) array<ARR_A, 3>, 4>;
+ // struct S {
+ // a : ARR_B,
+ // };
+ // @group(0) @binding(0) var<storage, read_write> s : S;
+ //
+ // @stage(compute) @workgroup_size(1)
+ // fn f() {
+ // let a : ARR_B = s.a;
+ // let b : @stride(16) array<ARR_A, 3> = s.a[3];
+ // let c = s.a[3][2];
+ // let d = s.a[3][2][1];
+ // s.a = ARR_B();
+ // s.a[3][2][1] = 5.0;
+ // }
- ProgramBuilder b;
- b.Alias("ARR_A", b.ty.array<f32, 2>(8));
- b.Alias("ARR_B",
- b.ty.array( //
- b.ty.array(b.ty.type_name("ARR_A"), 3, 16), //
- 4, 128));
- auto* S = b.Structure("S", {b.Member("a", b.ty.type_name("ARR_B"))});
- b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
- ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
- b.Func("f", {}, b.ty.void_(),
- {
- b.Decl(b.Let("a", b.ty.type_name("ARR_B"),
- b.MemberAccessor("s", "a"))),
- b.Decl(b.Let("b", b.ty.array(b.ty.type_name("ARR_A"), 3, 16),
- b.IndexAccessor( //
- b.MemberAccessor("s", "a"), //
- 3))),
- b.Decl(b.Let("c", b.ty.type_name("ARR_A"),
- b.IndexAccessor( //
- b.IndexAccessor( //
- b.MemberAccessor("s", "a"), //
- 3),
- 2))),
- b.Decl(b.Let("d", b.ty.f32(),
- b.IndexAccessor( //
- b.IndexAccessor( //
- b.IndexAccessor( //
- b.MemberAccessor("s", "a"), //
- 3),
- 2),
- 1))),
- b.Assign(b.MemberAccessor("s", "a"),
- b.Construct(b.ty.type_name("ARR_B"))),
- b.Assign(b.IndexAccessor( //
- b.IndexAccessor( //
- b.IndexAccessor( //
- b.MemberAccessor("s", "a"), //
- 3),
- 2),
- 1),
- 5.0f),
- },
- {
- b.Stage(ast::PipelineStage::kCompute),
- b.WorkgroupSize(1),
- });
+ ProgramBuilder b;
+ b.Alias("ARR_A", b.ty.array<f32, 2>(8));
+ b.Alias("ARR_B",
+ b.ty.array( //
+ b.ty.array(b.ty.type_name("ARR_A"), 3, 16), //
+ 4, 128));
+ auto* S = b.Structure("S", {b.Member("a", b.ty.type_name("ARR_B"))});
+ b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage, ast::Access::kReadWrite,
+ b.GroupAndBinding(0, 0));
+ b.Func("f", {}, b.ty.void_(),
+ {
+ b.Decl(b.Let("a", b.ty.type_name("ARR_B"), b.MemberAccessor("s", "a"))),
+ b.Decl(b.Let("b", b.ty.array(b.ty.type_name("ARR_A"), 3, 16),
+ b.IndexAccessor( //
+ b.MemberAccessor("s", "a"), //
+ 3))),
+ b.Decl(b.Let("c", b.ty.type_name("ARR_A"),
+ b.IndexAccessor( //
+ b.IndexAccessor( //
+ b.MemberAccessor("s", "a"), //
+ 3),
+ 2))),
+ b.Decl(b.Let("d", b.ty.f32(),
+ b.IndexAccessor( //
+ b.IndexAccessor( //
+ b.IndexAccessor( //
+ b.MemberAccessor("s", "a"), //
+ 3),
+ 2),
+ 1))),
+ b.Assign(b.MemberAccessor("s", "a"), b.Construct(b.ty.type_name("ARR_B"))),
+ b.Assign(b.IndexAccessor( //
+ b.IndexAccessor( //
+ b.IndexAccessor( //
+ b.MemberAccessor("s", "a"), //
+ 3),
+ 2),
+ 1),
+ 5.0f),
+ },
+ {
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1),
+ });
- auto* expect =
- R"(
+ auto* expect =
+ R"(
struct strided_arr {
@size(8)
el : f32,
@@ -686,10 +650,9 @@
}
)";
- auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
- Program(std::move(b)));
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(Program(std::move(b)));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
} // namespace
} // namespace tint::transform
diff --git a/src/tint/transform/decompose_strided_matrix.cc b/src/tint/transform/decompose_strided_matrix.cc
index fd7194d..85e8d9e 100644
--- a/src/tint/transform/decompose_strided_matrix.cc
+++ b/src/tint/transform/decompose_strided_matrix.cc
@@ -32,28 +32,26 @@
/// MatrixInfo describes a matrix member with a custom stride
struct MatrixInfo {
- /// The stride in bytes between columns of the matrix
- uint32_t stride = 0;
- /// The type of the matrix
- const sem::Matrix* matrix = nullptr;
+ /// The stride in bytes between columns of the matrix
+ uint32_t stride = 0;
+ /// The type of the matrix
+ const sem::Matrix* matrix = nullptr;
- /// @returns a new ast::Array that holds an vector column for each row of the
- /// matrix.
- const ast::Array* array(ProgramBuilder* b) const {
- return b->ty.array(b->ty.vec<ProgramBuilder::f32>(matrix->rows()),
- matrix->columns(), stride);
- }
-
- /// Equality operator
- bool operator==(const MatrixInfo& info) const {
- return stride == info.stride && matrix == info.matrix;
- }
- /// Hash function
- struct Hasher {
- size_t operator()(const MatrixInfo& t) const {
- return utils::Hash(t.stride, t.matrix);
+ /// @returns a new ast::Array that holds one column vector for each column of the
+ /// matrix.
+ const ast::Array* array(ProgramBuilder* b) const {
+ return b->ty.array(b->ty.vec<ProgramBuilder::f32>(matrix->rows()), matrix->columns(),
+ stride);
}
- };
+
+ /// Equality operator
+ bool operator==(const MatrixInfo& info) const {
+ return stride == info.stride && matrix == info.matrix;
+ }
+ /// Hash function
+ struct Hasher {
+ size_t operator()(const MatrixInfo& t) const { return utils::Hash(t.stride, t.matrix); }
+ };
};
/// Return type of the callback function of GatherCustomStrideMatrixMembers
@@ -71,33 +69,33 @@
/// scanning will continue.
template <typename F>
void GatherCustomStrideMatrixMembers(const Program* program, F&& callback) {
- for (auto* node : program->ASTNodes().Objects()) {
- if (auto* str = node->As<ast::Struct>()) {
- auto* str_ty = program->Sem().Get(str);
- if (!str_ty->UsedAs(ast::StorageClass::kUniform) &&
- !str_ty->UsedAs(ast::StorageClass::kStorage)) {
- continue;
- }
- for (auto* member : str_ty->Members()) {
- auto* matrix = member->Type()->As<sem::Matrix>();
- if (!matrix) {
- continue;
+ for (auto* node : program->ASTNodes().Objects()) {
+ if (auto* str = node->As<ast::Struct>()) {
+ auto* str_ty = program->Sem().Get(str);
+ if (!str_ty->UsedAs(ast::StorageClass::kUniform) &&
+ !str_ty->UsedAs(ast::StorageClass::kStorage)) {
+ continue;
+ }
+ for (auto* member : str_ty->Members()) {
+ auto* matrix = member->Type()->As<sem::Matrix>();
+ if (!matrix) {
+ continue;
+ }
+ auto* attr =
+ ast::GetAttribute<ast::StrideAttribute>(member->Declaration()->attributes);
+ if (!attr) {
+ continue;
+ }
+ uint32_t stride = attr->stride;
+ if (matrix->ColumnStride() == stride) {
+ continue;
+ }
+ if (callback(member, matrix, stride) == GatherResult::kStop) {
+ return;
+ }
+ }
}
- auto* attr = ast::GetAttribute<ast::StrideAttribute>(
- member->Declaration()->attributes);
- if (!attr) {
- continue;
- }
- uint32_t stride = attr->stride;
- if (matrix->ColumnStride() == stride) {
- continue;
- }
- if (callback(member, matrix, stride) == GatherResult::kStop) {
- return;
- }
- }
}
- }
}
} // namespace
@@ -106,144 +104,133 @@
DecomposeStridedMatrix::~DecomposeStridedMatrix() = default;
-bool DecomposeStridedMatrix::ShouldRun(const Program* program,
- const DataMap&) const {
- bool should_run = false;
- GatherCustomStrideMatrixMembers(
- program, [&](const sem::StructMember*, sem::Matrix*, uint32_t) {
+bool DecomposeStridedMatrix::ShouldRun(const Program* program, const DataMap&) const {
+ bool should_run = false;
+ GatherCustomStrideMatrixMembers(program, [&](const sem::StructMember*, sem::Matrix*, uint32_t) {
should_run = true;
return GatherResult::kStop;
- });
- return should_run;
+ });
+ return should_run;
}
-void DecomposeStridedMatrix::Run(CloneContext& ctx,
- const DataMap&,
- DataMap&) const {
- // Scan the program for all storage and uniform structure matrix members with
- // a custom stride attribute. Replace these matrices with an equivalent array,
- // and populate the `decomposed` map with the members that have been replaced.
- std::unordered_map<const ast::StructMember*, MatrixInfo> decomposed;
- GatherCustomStrideMatrixMembers(
- ctx.src, [&](const sem::StructMember* member, sem::Matrix* matrix,
- uint32_t stride) {
- // We've got ourselves a struct member of a matrix type with a custom
- // stride. Replace this with an array of column vectors.
- MatrixInfo info{stride, matrix};
- auto* replacement = ctx.dst->Member(
- member->Offset(), ctx.Clone(member->Name()), info.array(ctx.dst));
- ctx.Replace(member->Declaration(), replacement);
- decomposed.emplace(member->Declaration(), info);
- return GatherResult::kContinue;
- });
+void DecomposeStridedMatrix::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
+ // Scan the program for all storage and uniform structure matrix members with
+ // a custom stride attribute. Replace these matrices with an equivalent array,
+ // and populate the `decomposed` map with the members that have been replaced.
+ std::unordered_map<const ast::StructMember*, MatrixInfo> decomposed;
+ GatherCustomStrideMatrixMembers(
+ ctx.src, [&](const sem::StructMember* member, sem::Matrix* matrix, uint32_t stride) {
+ // We've got ourselves a struct member of a matrix type with a custom
+ // stride. Replace this with an array of column vectors.
+ MatrixInfo info{stride, matrix};
+ auto* replacement =
+ ctx.dst->Member(member->Offset(), ctx.Clone(member->Name()), info.array(ctx.dst));
+ ctx.Replace(member->Declaration(), replacement);
+ decomposed.emplace(member->Declaration(), info);
+ return GatherResult::kContinue;
+ });
- // For all expressions where a single matrix column vector was indexed, we can
- // preserve these without calling conversion functions.
- // Example:
- // ssbo.mat[2] -> ssbo.mat[2]
- ctx.ReplaceAll([&](const ast::IndexAccessorExpression* expr)
- -> const ast::IndexAccessorExpression* {
- if (auto* access =
- ctx.src->Sem().Get<sem::StructMemberAccess>(expr->object)) {
- auto it = decomposed.find(access->Member()->Declaration());
- if (it != decomposed.end()) {
- auto* obj = ctx.CloneWithoutTransform(expr->object);
- auto* idx = ctx.Clone(expr->index);
- return ctx.dst->IndexAccessor(obj, idx);
- }
- }
- return nullptr;
- });
-
- // For all struct member accesses to the matrix on the LHS of an assignment,
- // we need to convert the matrix to the array before assigning to the
- // structure.
- // Example:
- // ssbo.mat = mat_to_arr(m)
- std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> mat_to_arr;
- ctx.ReplaceAll([&](const ast::AssignmentStatement* stmt)
- -> const ast::Statement* {
- if (auto* access = ctx.src->Sem().Get<sem::StructMemberAccess>(stmt->lhs)) {
- auto it = decomposed.find(access->Member()->Declaration());
- if (it == decomposed.end()) {
- return nullptr;
- }
- MatrixInfo info = it->second;
- auto fn = utils::GetOrCreate(mat_to_arr, info, [&] {
- auto name = ctx.dst->Symbols().New(
- "mat" + std::to_string(info.matrix->columns()) + "x" +
- std::to_string(info.matrix->rows()) + "_stride_" +
- std::to_string(info.stride) + "_to_arr");
-
- auto matrix = [&] { return CreateASTTypeFor(ctx, info.matrix); };
- auto array = [&] { return info.array(ctx.dst); };
-
- auto mat = ctx.dst->Sym("m");
- ast::ExpressionList columns(info.matrix->columns());
- for (uint32_t i = 0; i < static_cast<uint32_t>(columns.size()); i++) {
- columns[i] = ctx.dst->IndexAccessor(mat, i);
- }
- ctx.dst->Func(name,
- {
- ctx.dst->Param(mat, matrix()),
- },
- array(),
- {
- ctx.dst->Return(ctx.dst->Construct(array(), columns)),
- });
- return name;
- });
- auto* lhs = ctx.CloneWithoutTransform(stmt->lhs);
- auto* rhs = ctx.dst->Call(fn, ctx.Clone(stmt->rhs));
- return ctx.dst->Assign(lhs, rhs);
- }
- return nullptr;
- });
-
- // For all other struct member accesses, we need to convert the array to the
- // matrix type. Example:
- // m = arr_to_mat(ssbo.mat)
- std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> arr_to_mat;
- ctx.ReplaceAll(
- [&](const ast::MemberAccessorExpression* expr) -> const ast::Expression* {
- if (auto* access = ctx.src->Sem().Get<sem::StructMemberAccess>(expr)) {
- auto it = decomposed.find(access->Member()->Declaration());
- if (it == decomposed.end()) {
- return nullptr;
- }
- MatrixInfo info = it->second;
- auto fn = utils::GetOrCreate(arr_to_mat, info, [&] {
- auto name = ctx.dst->Symbols().New(
- "arr_to_mat" + std::to_string(info.matrix->columns()) + "x" +
- std::to_string(info.matrix->rows()) + "_stride_" +
- std::to_string(info.stride));
-
- auto matrix = [&] { return CreateASTTypeFor(ctx, info.matrix); };
- auto array = [&] { return info.array(ctx.dst); };
-
- auto arr = ctx.dst->Sym("arr");
- ast::ExpressionList columns(info.matrix->columns());
- for (uint32_t i = 0; i < static_cast<uint32_t>(columns.size());
- i++) {
- columns[i] = ctx.dst->IndexAccessor(arr, i);
+ // For all expressions where a single matrix column vector was indexed, we can
+ // preserve these without calling conversion functions.
+ // Example:
+ // ssbo.mat[2] -> ssbo.mat[2]
+ ctx.ReplaceAll(
+ [&](const ast::IndexAccessorExpression* expr) -> const ast::IndexAccessorExpression* {
+ if (auto* access = ctx.src->Sem().Get<sem::StructMemberAccess>(expr->object)) {
+ auto it = decomposed.find(access->Member()->Declaration());
+ if (it != decomposed.end()) {
+ auto* obj = ctx.CloneWithoutTransform(expr->object);
+ auto* idx = ctx.Clone(expr->index);
+ return ctx.dst->IndexAccessor(obj, idx);
+ }
}
- ctx.dst->Func(
- name,
- {
- ctx.dst->Param(arr, array()),
- },
- matrix(),
- {
- ctx.dst->Return(ctx.dst->Construct(matrix(), columns)),
- });
- return name;
- });
- return ctx.dst->Call(fn, ctx.CloneWithoutTransform(expr));
+ return nullptr;
+ });
+
+ // For all struct member accesses to the matrix on the LHS of an assignment,
+ // we need to convert the matrix to the array before assigning to the
+ // structure.
+ // Example:
+ // ssbo.mat = mat_to_arr(m)
+ std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> mat_to_arr;
+ ctx.ReplaceAll([&](const ast::AssignmentStatement* stmt) -> const ast::Statement* {
+ if (auto* access = ctx.src->Sem().Get<sem::StructMemberAccess>(stmt->lhs)) {
+ auto it = decomposed.find(access->Member()->Declaration());
+ if (it == decomposed.end()) {
+ return nullptr;
+ }
+ MatrixInfo info = it->second;
+ auto fn = utils::GetOrCreate(mat_to_arr, info, [&] {
+ auto name =
+ ctx.dst->Symbols().New("mat" + std::to_string(info.matrix->columns()) + "x" +
+ std::to_string(info.matrix->rows()) + "_stride_" +
+ std::to_string(info.stride) + "_to_arr");
+
+ auto matrix = [&] { return CreateASTTypeFor(ctx, info.matrix); };
+ auto array = [&] { return info.array(ctx.dst); };
+
+ auto mat = ctx.dst->Sym("m");
+ ast::ExpressionList columns(info.matrix->columns());
+ for (uint32_t i = 0; i < static_cast<uint32_t>(columns.size()); i++) {
+ columns[i] = ctx.dst->IndexAccessor(mat, i);
+ }
+ ctx.dst->Func(name,
+ {
+ ctx.dst->Param(mat, matrix()),
+ },
+ array(),
+ {
+ ctx.dst->Return(ctx.dst->Construct(array(), columns)),
+ });
+ return name;
+ });
+ auto* lhs = ctx.CloneWithoutTransform(stmt->lhs);
+ auto* rhs = ctx.dst->Call(fn, ctx.Clone(stmt->rhs));
+ return ctx.dst->Assign(lhs, rhs);
}
return nullptr;
- });
+ });
- ctx.Clone();
+ // For all other struct member accesses, we need to convert the array to the
+ // matrix type. Example:
+ // m = arr_to_mat(ssbo.mat)
+ std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> arr_to_mat;
+ ctx.ReplaceAll([&](const ast::MemberAccessorExpression* expr) -> const ast::Expression* {
+ if (auto* access = ctx.src->Sem().Get<sem::StructMemberAccess>(expr)) {
+ auto it = decomposed.find(access->Member()->Declaration());
+ if (it == decomposed.end()) {
+ return nullptr;
+ }
+ MatrixInfo info = it->second;
+ auto fn = utils::GetOrCreate(arr_to_mat, info, [&] {
+ auto name = ctx.dst->Symbols().New(
+ "arr_to_mat" + std::to_string(info.matrix->columns()) + "x" +
+ std::to_string(info.matrix->rows()) + "_stride_" + std::to_string(info.stride));
+
+ auto matrix = [&] { return CreateASTTypeFor(ctx, info.matrix); };
+ auto array = [&] { return info.array(ctx.dst); };
+
+ auto arr = ctx.dst->Sym("arr");
+ ast::ExpressionList columns(info.matrix->columns());
+ for (uint32_t i = 0; i < static_cast<uint32_t>(columns.size()); i++) {
+ columns[i] = ctx.dst->IndexAccessor(arr, i);
+ }
+ ctx.dst->Func(name,
+ {
+ ctx.dst->Param(arr, array()),
+ },
+ matrix(),
+ {
+ ctx.dst->Return(ctx.dst->Construct(matrix(), columns)),
+ });
+ return name;
+ });
+ return ctx.dst->Call(fn, ctx.CloneWithoutTransform(expr));
+ }
+ return nullptr;
+ });
+
+ ctx.Clone();
}
} // namespace tint::transform
diff --git a/src/tint/transform/decompose_strided_matrix.h b/src/tint/transform/decompose_strided_matrix.h
index bcde5aa..40e9c3e 100644
--- a/src/tint/transform/decompose_strided_matrix.h
+++ b/src/tint/transform/decompose_strided_matrix.h
@@ -27,31 +27,27 @@
///
/// @note Depends on the following transforms to have been run first:
/// * SimplifyPointers
-class DecomposeStridedMatrix final
- : public Castable<DecomposeStridedMatrix, Transform> {
- public:
- /// Constructor
- DecomposeStridedMatrix();
+class DecomposeStridedMatrix final : public Castable<DecomposeStridedMatrix, Transform> {
+ public:
+ /// Constructor
+ DecomposeStridedMatrix();
- /// Destructor
- ~DecomposeStridedMatrix() override;
+ /// Destructor
+ ~DecomposeStridedMatrix() override;
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program,
- const DataMap& data = {}) const override;
+ /// @param program the program to inspect
+ /// @param data optional extra transform-specific input data
+ /// @returns true if this transform should be run for the given program
+ bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx,
- const DataMap& inputs,
- DataMap& outputs) const override;
+ protected:
+ /// Runs the transform using the CloneContext built for transforming a
+ /// program. Run() is responsible for calling Clone() on the CloneContext.
+ /// @param ctx the CloneContext primed with the input program and
+ /// ProgramBuilder
+ /// @param inputs optional extra transform-specific input data
+ /// @param outputs optional extra transform-specific output data
+ void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/decompose_strided_matrix_test.cc b/src/tint/transform/decompose_strided_matrix_test.cc
index 4246c6e..8784839 100644
--- a/src/tint/transform/decompose_strided_matrix_test.cc
+++ b/src/tint/transform/decompose_strided_matrix_test.cc
@@ -31,64 +31,61 @@
using f32 = ProgramBuilder::f32;
TEST_F(DecomposeStridedMatrixTest, ShouldRunEmptyModule) {
- auto* src = R"()";
+ auto* src = R"()";
- EXPECT_FALSE(ShouldRun<DecomposeStridedMatrix>(src));
+ EXPECT_FALSE(ShouldRun<DecomposeStridedMatrix>(src));
}
TEST_F(DecomposeStridedMatrixTest, ShouldRunNonStridedMatrox) {
- auto* src = R"(
+ auto* src = R"(
var<private> m : mat3x2<f32>;
)";
- EXPECT_FALSE(ShouldRun<DecomposeStridedMatrix>(src));
+ EXPECT_FALSE(ShouldRun<DecomposeStridedMatrix>(src));
}
TEST_F(DecomposeStridedMatrixTest, Empty) {
- auto* src = R"()";
- auto* expect = src;
+ auto* src = R"()";
+ auto* expect = src;
- auto got = Run<DecomposeStridedMatrix>(src);
+ auto got = Run<DecomposeStridedMatrix>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, ReadUniformMatrix) {
- // struct S {
- // @offset(16) @stride(32)
- // @internal(ignore_stride_attribute)
- // m : mat2x2<f32>,
- // };
- // @group(0) @binding(0) var<uniform> s : S;
- //
- // @stage(compute) @workgroup_size(1)
- // fn f() {
- // let x : mat2x2<f32> = s.m;
- // }
- ProgramBuilder b;
- auto* S = b.Structure(
- "S",
- {
- b.Member(
- "m", b.ty.mat2x2<f32>(),
- {
- b.create<ast::StructMemberOffsetAttribute>(16),
- b.create<ast::StrideAttribute>(32),
- b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
- }),
- });
- b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform,
- b.GroupAndBinding(0, 0));
- b.Func("f", {}, b.ty.void_(),
- {
- b.Decl(b.Let("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
- },
- {
- b.Stage(ast::PipelineStage::kCompute),
- b.WorkgroupSize(1),
- });
+ // struct S {
+ // @offset(16) @stride(32)
+ // @internal(ignore_stride_attribute)
+ // m : mat2x2<f32>,
+ // };
+ // @group(0) @binding(0) var<uniform> s : S;
+ //
+ // @stage(compute) @workgroup_size(1)
+ // fn f() {
+ // let x : mat2x2<f32> = s.m;
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure(
+ "S", {
+ b.Member("m", b.ty.mat2x2<f32>(),
+ {
+ b.create<ast::StructMemberOffsetAttribute>(16),
+ b.create<ast::StrideAttribute>(32),
+ b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
+ }),
+ });
+ b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform, b.GroupAndBinding(0, 0));
+ b.Func("f", {}, b.ty.void_(),
+ {
+ b.Decl(b.Let("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
+ },
+ {
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1),
+ });
- auto* expect = R"(
+ auto* expect = R"(
struct S {
@size(16)
padding : u32,
@@ -107,49 +104,44 @@
}
)";
- auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(
- Program(std::move(b)));
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(Program(std::move(b)));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, ReadUniformColumn) {
- // struct S {
- // @offset(16) @stride(32)
- // @internal(ignore_stride_attribute)
- // m : mat2x2<f32>,
- // };
- // @group(0) @binding(0) var<uniform> s : S;
- //
- // @stage(compute) @workgroup_size(1)
- // fn f() {
- // let x : vec2<f32> = s.m[1];
- // }
- ProgramBuilder b;
- auto* S = b.Structure(
- "S",
- {
- b.Member(
- "m", b.ty.mat2x2<f32>(),
- {
- b.create<ast::StructMemberOffsetAttribute>(16),
- b.create<ast::StrideAttribute>(32),
- b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
- }),
- });
- b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform,
- b.GroupAndBinding(0, 0));
- b.Func("f", {}, b.ty.void_(),
- {
- b.Decl(b.Let("x", b.ty.vec2<f32>(),
- b.IndexAccessor(b.MemberAccessor("s", "m"), 1))),
- },
- {
- b.Stage(ast::PipelineStage::kCompute),
- b.WorkgroupSize(1),
- });
+ // struct S {
+ // @offset(16) @stride(32)
+ // @internal(ignore_stride_attribute)
+ // m : mat2x2<f32>,
+ // };
+ // @group(0) @binding(0) var<uniform> s : S;
+ //
+ // @stage(compute) @workgroup_size(1)
+ // fn f() {
+ // let x : vec2<f32> = s.m[1];
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure(
+ "S", {
+ b.Member("m", b.ty.mat2x2<f32>(),
+ {
+ b.create<ast::StructMemberOffsetAttribute>(16),
+ b.create<ast::StrideAttribute>(32),
+ b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
+ }),
+ });
+ b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform, b.GroupAndBinding(0, 0));
+ b.Func("f", {}, b.ty.void_(),
+ {
+ b.Decl(b.Let("x", b.ty.vec2<f32>(), b.IndexAccessor(b.MemberAccessor("s", "m"), 1))),
+ },
+ {
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1),
+ });
- auto* expect = R"(
+ auto* expect = R"(
struct S {
@size(16)
padding : u32,
@@ -164,48 +156,44 @@
}
)";
- auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(
- Program(std::move(b)));
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(Program(std::move(b)));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, ReadUniformMatrix_DefaultStride) {
- // struct S {
- // @offset(16) @stride(8)
- // @internal(ignore_stride_attribute)
- // m : mat2x2<f32>,
- // };
- // @group(0) @binding(0) var<uniform> s : S;
- //
- // @stage(compute) @workgroup_size(1)
- // fn f() {
- // let x : mat2x2<f32> = s.m;
- // }
- ProgramBuilder b;
- auto* S = b.Structure(
- "S",
- {
- b.Member(
- "m", b.ty.mat2x2<f32>(),
- {
- b.create<ast::StructMemberOffsetAttribute>(16),
- b.create<ast::StrideAttribute>(8),
- b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
- }),
- });
- b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform,
- b.GroupAndBinding(0, 0));
- b.Func("f", {}, b.ty.void_(),
- {
- b.Decl(b.Let("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
- },
- {
- b.Stage(ast::PipelineStage::kCompute),
- b.WorkgroupSize(1),
- });
+ // struct S {
+ // @offset(16) @stride(8)
+ // @internal(ignore_stride_attribute)
+ // m : mat2x2<f32>,
+ // };
+ // @group(0) @binding(0) var<uniform> s : S;
+ //
+ // @stage(compute) @workgroup_size(1)
+ // fn f() {
+ // let x : mat2x2<f32> = s.m;
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure(
+ "S", {
+ b.Member("m", b.ty.mat2x2<f32>(),
+ {
+ b.create<ast::StructMemberOffsetAttribute>(16),
+ b.create<ast::StrideAttribute>(8),
+ b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
+ }),
+ });
+ b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform, b.GroupAndBinding(0, 0));
+ b.Func("f", {}, b.ty.void_(),
+ {
+ b.Decl(b.Let("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
+ },
+ {
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1),
+ });
- auto* expect = R"(
+ auto* expect = R"(
struct S {
@size(16)
padding : u32,
@@ -221,48 +209,45 @@
}
)";
- auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(
- Program(std::move(b)));
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(Program(std::move(b)));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, ReadStorageMatrix) {
- // struct S {
- // @offset(8) @stride(32)
- // @internal(ignore_stride_attribute)
- // m : mat2x2<f32>,
- // };
- // @group(0) @binding(0) var<storage, read_write> s : S;
- //
- // @stage(compute) @workgroup_size(1)
- // fn f() {
- // let x : mat2x2<f32> = s.m;
- // }
- ProgramBuilder b;
- auto* S = b.Structure(
- "S",
- {
- b.Member(
- "m", b.ty.mat2x2<f32>(),
- {
- b.create<ast::StructMemberOffsetAttribute>(8),
- b.create<ast::StrideAttribute>(32),
- b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
- }),
- });
- b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
- ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
- b.Func("f", {}, b.ty.void_(),
- {
- b.Decl(b.Let("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
- },
- {
- b.Stage(ast::PipelineStage::kCompute),
- b.WorkgroupSize(1),
- });
+ // struct S {
+ // @offset(8) @stride(32)
+ // @internal(ignore_stride_attribute)
+ // m : mat2x2<f32>,
+ // };
+ // @group(0) @binding(0) var<storage, read_write> s : S;
+ //
+ // @stage(compute) @workgroup_size(1)
+ // fn f() {
+ // let x : mat2x2<f32> = s.m;
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure(
+ "S", {
+ b.Member("m", b.ty.mat2x2<f32>(),
+ {
+ b.create<ast::StructMemberOffsetAttribute>(8),
+ b.create<ast::StrideAttribute>(32),
+ b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
+ }),
+ });
+ b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage, ast::Access::kReadWrite,
+ b.GroupAndBinding(0, 0));
+ b.Func("f", {}, b.ty.void_(),
+ {
+ b.Decl(b.Let("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
+ },
+ {
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1),
+ });
- auto* expect = R"(
+ auto* expect = R"(
struct S {
@size(8)
padding : u32,
@@ -281,49 +266,45 @@
}
)";
- auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(
- Program(std::move(b)));
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(Program(std::move(b)));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, ReadStorageColumn) {
- // struct S {
- // @offset(16) @stride(32)
- // @internal(ignore_stride_attribute)
- // m : mat2x2<f32>,
- // };
- // @group(0) @binding(0) var<storage, read_write> s : S;
- //
- // @stage(compute) @workgroup_size(1)
- // fn f() {
- // let x : vec2<f32> = s.m[1];
- // }
- ProgramBuilder b;
- auto* S = b.Structure(
- "S",
- {
- b.Member(
- "m", b.ty.mat2x2<f32>(),
- {
- b.create<ast::StructMemberOffsetAttribute>(16),
- b.create<ast::StrideAttribute>(32),
- b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
- }),
- });
- b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
- ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
- b.Func("f", {}, b.ty.void_(),
- {
- b.Decl(b.Let("x", b.ty.vec2<f32>(),
- b.IndexAccessor(b.MemberAccessor("s", "m"), 1))),
- },
- {
- b.Stage(ast::PipelineStage::kCompute),
- b.WorkgroupSize(1),
- });
+ // struct S {
+ // @offset(16) @stride(32)
+ // @internal(ignore_stride_attribute)
+ // m : mat2x2<f32>,
+ // };
+ // @group(0) @binding(0) var<storage, read_write> s : S;
+ //
+ // @stage(compute) @workgroup_size(1)
+ // fn f() {
+ // let x : vec2<f32> = s.m[1];
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure(
+ "S", {
+ b.Member("m", b.ty.mat2x2<f32>(),
+ {
+ b.create<ast::StructMemberOffsetAttribute>(16),
+ b.create<ast::StrideAttribute>(32),
+ b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
+ }),
+ });
+ b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage, ast::Access::kReadWrite,
+ b.GroupAndBinding(0, 0));
+ b.Func("f", {}, b.ty.void_(),
+ {
+ b.Decl(b.Let("x", b.ty.vec2<f32>(), b.IndexAccessor(b.MemberAccessor("s", "m"), 1))),
+ },
+ {
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1),
+ });
- auto* expect = R"(
+ auto* expect = R"(
struct S {
@size(16)
padding : u32,
@@ -338,50 +319,46 @@
}
)";
- auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(
- Program(std::move(b)));
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(Program(std::move(b)));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, WriteStorageMatrix) {
- // struct S {
- // @offset(8) @stride(32)
- // @internal(ignore_stride_attribute)
- // m : mat2x2<f32>,
- // };
- // @group(0) @binding(0) var<storage, read_write> s : S;
- //
- // @stage(compute) @workgroup_size(1)
- // fn f() {
- // s.m = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
- // }
- ProgramBuilder b;
- auto* S = b.Structure(
- "S",
- {
- b.Member(
- "m", b.ty.mat2x2<f32>(),
- {
- b.create<ast::StructMemberOffsetAttribute>(8),
- b.create<ast::StrideAttribute>(32),
- b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
- }),
- });
- b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
- ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
- b.Func("f", {}, b.ty.void_(),
- {
- b.Assign(b.MemberAccessor("s", "m"),
- b.mat2x2<f32>(b.vec2<f32>(1.0f, 2.0f),
- b.vec2<f32>(3.0f, 4.0f))),
- },
- {
- b.Stage(ast::PipelineStage::kCompute),
- b.WorkgroupSize(1),
- });
+ // struct S {
+ // @offset(8) @stride(32)
+ // @internal(ignore_stride_attribute)
+ // m : mat2x2<f32>,
+ // };
+ // @group(0) @binding(0) var<storage, read_write> s : S;
+ //
+ // @stage(compute) @workgroup_size(1)
+ // fn f() {
+ // s.m = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure(
+ "S", {
+ b.Member("m", b.ty.mat2x2<f32>(),
+ {
+ b.create<ast::StructMemberOffsetAttribute>(8),
+ b.create<ast::StrideAttribute>(32),
+ b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
+ }),
+ });
+ b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage, ast::Access::kReadWrite,
+ b.GroupAndBinding(0, 0));
+ b.Func("f", {}, b.ty.void_(),
+ {
+ b.Assign(b.MemberAccessor("s", "m"),
+ b.mat2x2<f32>(b.vec2<f32>(1.0f, 2.0f), b.vec2<f32>(3.0f, 4.0f))),
+ },
+ {
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1),
+ });
- auto* expect = R"(
+ auto* expect = R"(
struct S {
@size(8)
padding : u32,
@@ -400,49 +377,45 @@
}
)";
- auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(
- Program(std::move(b)));
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(Program(std::move(b)));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, WriteStorageColumn) {
- // struct S {
- // @offset(8) @stride(32)
- // @internal(ignore_stride_attribute)
- // m : mat2x2<f32>,
- // };
- // @group(0) @binding(0) var<storage, read_write> s : S;
- //
- // @stage(compute) @workgroup_size(1)
- // fn f() {
- // s.m[1] = vec2<f32>(1.0, 2.0);
- // }
- ProgramBuilder b;
- auto* S = b.Structure(
- "S",
- {
- b.Member(
- "m", b.ty.mat2x2<f32>(),
- {
- b.create<ast::StructMemberOffsetAttribute>(8),
- b.create<ast::StrideAttribute>(32),
- b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
- }),
- });
- b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
- ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
- b.Func("f", {}, b.ty.void_(),
- {
- b.Assign(b.IndexAccessor(b.MemberAccessor("s", "m"), 1),
- b.vec2<f32>(1.0f, 2.0f)),
- },
- {
- b.Stage(ast::PipelineStage::kCompute),
- b.WorkgroupSize(1),
- });
+ // struct S {
+ // @offset(8) @stride(32)
+ // @internal(ignore_stride_attribute)
+ // m : mat2x2<f32>,
+ // };
+ // @group(0) @binding(0) var<storage, read_write> s : S;
+ //
+ // @stage(compute) @workgroup_size(1)
+ // fn f() {
+ // s.m[1] = vec2<f32>(1.0, 2.0);
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure(
+ "S", {
+ b.Member("m", b.ty.mat2x2<f32>(),
+ {
+ b.create<ast::StructMemberOffsetAttribute>(8),
+ b.create<ast::StrideAttribute>(32),
+ b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
+ }),
+ });
+ b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage, ast::Access::kReadWrite,
+ b.GroupAndBinding(0, 0));
+ b.Func("f", {}, b.ty.void_(),
+ {
+ b.Assign(b.IndexAccessor(b.MemberAccessor("s", "m"), 1), b.vec2<f32>(1.0f, 2.0f)),
+ },
+ {
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1),
+ });
- auto* expect = R"(
+ auto* expect = R"(
struct S {
@size(8)
padding : u32,
@@ -457,63 +430,58 @@
}
)";
- auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(
- Program(std::move(b)));
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(Program(std::move(b)));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, ReadWriteViaPointerLets) {
- // struct S {
- // @offset(8) @stride(32)
- // @internal(ignore_stride_attribute)
- // m : mat2x2<f32>,
- // };
- // @group(0) @binding(0) var<storage, read_write> s : S;
- //
- // @stage(compute) @workgroup_size(1)
- // fn f() {
- // let a = &s.m;
- // let b = &*&*(a);
- // let x = *b;
- // let y = (*b)[1];
- // let z = x[1];
- // (*b) = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
- // (*b)[1] = vec2<f32>(5.0, 6.0);
- // }
- ProgramBuilder b;
- auto* S = b.Structure(
- "S",
- {
- b.Member(
- "m", b.ty.mat2x2<f32>(),
- {
- b.create<ast::StructMemberOffsetAttribute>(8),
- b.create<ast::StrideAttribute>(32),
- b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
- }),
- });
- b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
- ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
- b.Func(
- "f", {}, b.ty.void_(),
- {
- b.Decl(b.Let("a", nullptr, b.AddressOf(b.MemberAccessor("s", "m")))),
- b.Decl(b.Let("b", nullptr,
- b.AddressOf(b.Deref(b.AddressOf(b.Deref("a")))))),
- b.Decl(b.Let("x", nullptr, b.Deref("b"))),
- b.Decl(b.Let("y", nullptr, b.IndexAccessor(b.Deref("b"), 1))),
- b.Decl(b.Let("z", nullptr, b.IndexAccessor("x", 1))),
- b.Assign(b.Deref("b"), b.mat2x2<f32>(b.vec2<f32>(1.0f, 2.0f),
- b.vec2<f32>(3.0f, 4.0f))),
- b.Assign(b.IndexAccessor(b.Deref("b"), 1), b.vec2<f32>(5.0f, 6.0f)),
- },
- {
- b.Stage(ast::PipelineStage::kCompute),
- b.WorkgroupSize(1),
- });
+ // struct S {
+ // @offset(8) @stride(32)
+ // @internal(ignore_stride_attribute)
+ // m : mat2x2<f32>,
+ // };
+ // @group(0) @binding(0) var<storage, read_write> s : S;
+ //
+ // @stage(compute) @workgroup_size(1)
+ // fn f() {
+ // let a = &s.m;
+ // let b = &*&*(a);
+ // let x = *b;
+ // let y = (*b)[1];
+ // let z = x[1];
+ // (*b) = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
+ // (*b)[1] = vec2<f32>(5.0, 6.0);
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure(
+ "S", {
+ b.Member("m", b.ty.mat2x2<f32>(),
+ {
+ b.create<ast::StructMemberOffsetAttribute>(8),
+ b.create<ast::StrideAttribute>(32),
+ b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
+ }),
+ });
+ b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage, ast::Access::kReadWrite,
+ b.GroupAndBinding(0, 0));
+ b.Func(
+ "f", {}, b.ty.void_(),
+ {
+ b.Decl(b.Let("a", nullptr, b.AddressOf(b.MemberAccessor("s", "m")))),
+ b.Decl(b.Let("b", nullptr, b.AddressOf(b.Deref(b.AddressOf(b.Deref("a")))))),
+ b.Decl(b.Let("x", nullptr, b.Deref("b"))),
+ b.Decl(b.Let("y", nullptr, b.IndexAccessor(b.Deref("b"), 1))),
+ b.Decl(b.Let("z", nullptr, b.IndexAccessor("x", 1))),
+ b.Assign(b.Deref("b"), b.mat2x2<f32>(b.vec2<f32>(1.0f, 2.0f), b.vec2<f32>(3.0f, 4.0f))),
+ b.Assign(b.IndexAccessor(b.Deref("b"), 1), b.vec2<f32>(5.0f, 6.0f)),
+ },
+ {
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1),
+ });
- auto* expect = R"(
+ auto* expect = R"(
struct S {
@size(8)
padding : u32,
@@ -540,47 +508,44 @@
}
)";
- auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(
- Program(std::move(b)));
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(Program(std::move(b)));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, ReadPrivateMatrix) {
- // struct S {
- // @offset(8) @stride(32)
- // @internal(ignore_stride_attribute)
- // m : mat2x2<f32>,
- // };
- // var<private> s : S;
- //
- // @stage(compute) @workgroup_size(1)
- // fn f() {
- // let x : mat2x2<f32> = s.m;
- // }
- ProgramBuilder b;
- auto* S = b.Structure(
- "S",
- {
- b.Member(
- "m", b.ty.mat2x2<f32>(),
- {
- b.create<ast::StructMemberOffsetAttribute>(8),
- b.create<ast::StrideAttribute>(32),
- b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
- }),
- });
- b.Global("s", b.ty.Of(S), ast::StorageClass::kPrivate);
- b.Func("f", {}, b.ty.void_(),
- {
- b.Decl(b.Let("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
- },
- {
- b.Stage(ast::PipelineStage::kCompute),
- b.WorkgroupSize(1),
- });
+ // struct S {
+ // @offset(8) @stride(32)
+ // @internal(ignore_stride_attribute)
+ // m : mat2x2<f32>,
+ // };
+ // var<private> s : S;
+ //
+ // @stage(compute) @workgroup_size(1)
+ // fn f() {
+ // let x : mat2x2<f32> = s.m;
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure(
+ "S", {
+ b.Member("m", b.ty.mat2x2<f32>(),
+ {
+ b.create<ast::StructMemberOffsetAttribute>(8),
+ b.create<ast::StrideAttribute>(32),
+ b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
+ }),
+ });
+ b.Global("s", b.ty.Of(S), ast::StorageClass::kPrivate);
+ b.Func("f", {}, b.ty.void_(),
+ {
+ b.Decl(b.Let("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
+ },
+ {
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1),
+ });
- auto* expect = R"(
+ auto* expect = R"(
struct S {
@size(8)
padding : u32,
@@ -596,49 +561,45 @@
}
)";
- auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(
- Program(std::move(b)));
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(Program(std::move(b)));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, WritePrivateMatrix) {
- // struct S {
- // @offset(8) @stride(32)
- // @internal(ignore_stride_attribute)
- // m : mat2x2<f32>,
- // };
- // var<private> s : S;
- //
- // @stage(compute) @workgroup_size(1)
- // fn f() {
- // s.m = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
- // }
- ProgramBuilder b;
- auto* S = b.Structure(
- "S",
- {
- b.Member(
- "m", b.ty.mat2x2<f32>(),
- {
- b.create<ast::StructMemberOffsetAttribute>(8),
- b.create<ast::StrideAttribute>(32),
- b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
- }),
- });
- b.Global("s", b.ty.Of(S), ast::StorageClass::kPrivate);
- b.Func("f", {}, b.ty.void_(),
- {
- b.Assign(b.MemberAccessor("s", "m"),
- b.mat2x2<f32>(b.vec2<f32>(1.0f, 2.0f),
- b.vec2<f32>(3.0f, 4.0f))),
- },
- {
- b.Stage(ast::PipelineStage::kCompute),
- b.WorkgroupSize(1),
- });
+ // struct S {
+ // @offset(8) @stride(32)
+ // @internal(ignore_stride_attribute)
+ // m : mat2x2<f32>,
+ // };
+ // var<private> s : S;
+ //
+ // @stage(compute) @workgroup_size(1)
+ // fn f() {
+ // s.m = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure(
+ "S", {
+ b.Member("m", b.ty.mat2x2<f32>(),
+ {
+ b.create<ast::StructMemberOffsetAttribute>(8),
+ b.create<ast::StrideAttribute>(32),
+ b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
+ }),
+ });
+ b.Global("s", b.ty.Of(S), ast::StorageClass::kPrivate);
+ b.Func("f", {}, b.ty.void_(),
+ {
+ b.Assign(b.MemberAccessor("s", "m"),
+ b.mat2x2<f32>(b.vec2<f32>(1.0f, 2.0f), b.vec2<f32>(3.0f, 4.0f))),
+ },
+ {
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1),
+ });
- auto* expect = R"(
+ auto* expect = R"(
struct S {
@size(8)
padding : u32,
@@ -654,10 +615,9 @@
}
)";
- auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(
- Program(std::move(b)));
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(Program(std::move(b)));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
} // namespace
diff --git a/src/tint/transform/expand_compound_assignment.cc b/src/tint/transform/expand_compound_assignment.cc
index eda4164..ccc4f92 100644
--- a/src/tint/transform/expand_compound_assignment.cc
+++ b/src/tint/transform/expand_compound_assignment.cc
@@ -33,162 +33,152 @@
ExpandCompoundAssignment::~ExpandCompoundAssignment() = default;
-bool ExpandCompoundAssignment::ShouldRun(const Program* program,
- const DataMap&) const {
- for (auto* node : program->ASTNodes().Objects()) {
- if (node->IsAnyOf<ast::CompoundAssignmentStatement,
- ast::IncrementDecrementStatement>()) {
- return true;
+bool ExpandCompoundAssignment::ShouldRun(const Program* program, const DataMap&) const {
+ for (auto* node : program->ASTNodes().Objects()) {
+ if (node->IsAnyOf<ast::CompoundAssignmentStatement, ast::IncrementDecrementStatement>()) {
+ return true;
+ }
}
- }
- return false;
+ return false;
}
/// Internal class used to collect statement expansions during the transform.
class State {
- private:
- /// The clone context.
- CloneContext& ctx;
+ private:
+ /// The clone context.
+ CloneContext& ctx;
- /// The program builder.
- ProgramBuilder& b;
+ /// The program builder.
+ ProgramBuilder& b;
- /// The HoistToDeclBefore helper instance.
- HoistToDeclBefore hoist_to_decl_before;
+ /// The HoistToDeclBefore helper instance.
+ HoistToDeclBefore hoist_to_decl_before;
- public:
- /// Constructor
- /// @param context the clone context
- explicit State(CloneContext& context)
- : ctx(context), b(*ctx.dst), hoist_to_decl_before(ctx) {}
+ public:
+ /// Constructor
+ /// @param context the clone context
+ explicit State(CloneContext& context) : ctx(context), b(*ctx.dst), hoist_to_decl_before(ctx) {}
- /// Replace `stmt` with a regular assignment statement of the form:
- /// lhs = lhs op rhs
- /// The LHS expression will only be evaluated once, and any side effects will
- /// be hoisted to `let` declarations above the assignment statement.
- /// @param stmt the statement to replace
- /// @param lhs the lhs expression from the source statement
- /// @param rhs the rhs expression in the destination module
- /// @param op the binary operator
- void Expand(const ast::Statement* stmt,
- const ast::Expression* lhs,
- const ast::Expression* rhs,
- ast::BinaryOp op) {
- // Helper function to create the new LHS expression. This will be called
- // twice when building the non-compound assignment statement, so must
- // not produce expressions that cause side effects.
- std::function<const ast::Expression*()> new_lhs;
+ /// Replace `stmt` with a regular assignment statement of the form:
+ /// lhs = lhs op rhs
+ /// The LHS expression will only be evaluated once, and any side effects will
+ /// be hoisted to `let` declarations above the assignment statement.
+ /// @param stmt the statement to replace
+ /// @param lhs the lhs expression from the source statement
+ /// @param rhs the rhs expression in the destination module
+ /// @param op the binary operator
+ void Expand(const ast::Statement* stmt,
+ const ast::Expression* lhs,
+ const ast::Expression* rhs,
+ ast::BinaryOp op) {
+ // Helper function to create the new LHS expression. This will be called
+ // twice when building the non-compound assignment statement, so must
+ // not produce expressions that cause side effects.
+ std::function<const ast::Expression*()> new_lhs;
- // Helper function to create a variable that is a pointer to `expr`.
- auto hoist_pointer_to = [&](const ast::Expression* expr) {
- auto name = b.Sym();
- auto* ptr = b.AddressOf(ctx.Clone(expr));
- auto* decl = b.Decl(b.Let(name, nullptr, ptr));
- hoist_to_decl_before.InsertBefore(ctx.src->Sem().Get(stmt), decl);
- return name;
- };
+ // Helper function to create a variable that is a pointer to `expr`.
+ auto hoist_pointer_to = [&](const ast::Expression* expr) {
+ auto name = b.Sym();
+ auto* ptr = b.AddressOf(ctx.Clone(expr));
+ auto* decl = b.Decl(b.Let(name, nullptr, ptr));
+ hoist_to_decl_before.InsertBefore(ctx.src->Sem().Get(stmt), decl);
+ return name;
+ };
- // Helper function to hoist `expr` to a let declaration.
- auto hoist_expr_to_let = [&](const ast::Expression* expr) {
- auto name = b.Sym();
- auto* decl = b.Decl(b.Let(name, nullptr, ctx.Clone(expr)));
- hoist_to_decl_before.InsertBefore(ctx.src->Sem().Get(stmt), decl);
- return name;
- };
+ // Helper function to hoist `expr` to a let declaration.
+ auto hoist_expr_to_let = [&](const ast::Expression* expr) {
+ auto name = b.Sym();
+ auto* decl = b.Decl(b.Let(name, nullptr, ctx.Clone(expr)));
+ hoist_to_decl_before.InsertBefore(ctx.src->Sem().Get(stmt), decl);
+ return name;
+ };
- // Helper function that returns `true` if the type of `expr` is a vector.
- auto is_vec = [&](const ast::Expression* expr) {
- return ctx.src->Sem().Get(expr)->Type()->UnwrapRef()->Is<sem::Vector>();
- };
+ // Helper function that returns `true` if the type of `expr` is a vector.
+ auto is_vec = [&](const ast::Expression* expr) {
+ return ctx.src->Sem().Get(expr)->Type()->UnwrapRef()->Is<sem::Vector>();
+ };
- // Hoist the LHS expression subtree into local constants to produce a new
- // LHS that we can evaluate twice.
- // We need to special case compound assignments to vector components since
- // we cannot take the address of a vector component.
- auto* index_accessor = lhs->As<ast::IndexAccessorExpression>();
- auto* member_accessor = lhs->As<ast::MemberAccessorExpression>();
- if (lhs->Is<ast::IdentifierExpression>() ||
- (member_accessor &&
- member_accessor->structure->Is<ast::IdentifierExpression>())) {
- // This is the simple case with no side effects, so we can just use the
- // original LHS expression directly.
- // Before:
- // foo.bar += rhs;
- // After:
- // foo.bar = foo.bar + rhs;
- new_lhs = [&]() { return ctx.Clone(lhs); };
- } else if (index_accessor && is_vec(index_accessor->object)) {
- // This is the case for vector component via an array accessor. We need
- // to capture a pointer to the vector and also the index value.
- // Before:
- // v[idx()] += rhs;
- // After:
- // let vec_ptr = &v;
- // let index = idx();
- // (*vec_ptr)[index] = (*vec_ptr)[index] + rhs;
- auto lhs_ptr = hoist_pointer_to(index_accessor->object);
- auto index = hoist_expr_to_let(index_accessor->index);
- new_lhs = [&, lhs_ptr, index]() {
- return b.IndexAccessor(b.Deref(lhs_ptr), index);
- };
- } else if (member_accessor && is_vec(member_accessor->structure)) {
- // This is the case for vector component via a member accessor. We just
- // need to capture a pointer to the vector.
- // Before:
- // a[idx()].y += rhs;
- // After:
- // let vec_ptr = &a[idx()];
- // (*vec_ptr).y = (*vec_ptr).y + rhs;
- auto lhs_ptr = hoist_pointer_to(member_accessor->structure);
- new_lhs = [&, lhs_ptr]() {
- return b.MemberAccessor(b.Deref(lhs_ptr),
- ctx.Clone(member_accessor->member));
- };
- } else {
- // For all other statements that may have side-effecting expressions, we
- // just need to capture a pointer to the whole LHS.
- // Before:
- // a[idx()] += rhs;
- // After:
- // let lhs_ptr = &a[idx()];
- // (*lhs_ptr) = (*lhs_ptr) + rhs;
- auto lhs_ptr = hoist_pointer_to(lhs);
- new_lhs = [&, lhs_ptr]() { return b.Deref(lhs_ptr); };
+ // Hoist the LHS expression subtree into local constants to produce a new
+ // LHS that we can evaluate twice.
+ // We need to special case compound assignments to vector components since
+ // we cannot take the address of a vector component.
+ auto* index_accessor = lhs->As<ast::IndexAccessorExpression>();
+ auto* member_accessor = lhs->As<ast::MemberAccessorExpression>();
+ if (lhs->Is<ast::IdentifierExpression>() ||
+ (member_accessor && member_accessor->structure->Is<ast::IdentifierExpression>())) {
+ // This is the simple case with no side effects, so we can just use the
+ // original LHS expression directly.
+ // Before:
+ // foo.bar += rhs;
+ // After:
+ // foo.bar = foo.bar + rhs;
+ new_lhs = [&]() { return ctx.Clone(lhs); };
+ } else if (index_accessor && is_vec(index_accessor->object)) {
+ // This is the case for vector component via an array accessor. We need
+ // to capture a pointer to the vector and also the index value.
+ // Before:
+ // v[idx()] += rhs;
+ // After:
+ // let vec_ptr = &v;
+ // let index = idx();
+ // (*vec_ptr)[index] = (*vec_ptr)[index] + rhs;
+ auto lhs_ptr = hoist_pointer_to(index_accessor->object);
+ auto index = hoist_expr_to_let(index_accessor->index);
+ new_lhs = [&, lhs_ptr, index]() { return b.IndexAccessor(b.Deref(lhs_ptr), index); };
+ } else if (member_accessor && is_vec(member_accessor->structure)) {
+ // This is the case for vector component via a member accessor. We just
+ // need to capture a pointer to the vector.
+ // Before:
+ // a[idx()].y += rhs;
+ // After:
+ // let vec_ptr = &a[idx()];
+ // (*vec_ptr).y = (*vec_ptr).y + rhs;
+ auto lhs_ptr = hoist_pointer_to(member_accessor->structure);
+ new_lhs = [&, lhs_ptr]() {
+ return b.MemberAccessor(b.Deref(lhs_ptr), ctx.Clone(member_accessor->member));
+ };
+ } else {
+ // For all other statements that may have side-effecting expressions, we
+ // just need to capture a pointer to the whole LHS.
+ // Before:
+ // a[idx()] += rhs;
+ // After:
+ // let lhs_ptr = &a[idx()];
+ // (*lhs_ptr) = (*lhs_ptr) + rhs;
+ auto lhs_ptr = hoist_pointer_to(lhs);
+ new_lhs = [&, lhs_ptr]() { return b.Deref(lhs_ptr); };
+ }
+
+ // Replace the statement with a regular assignment statement.
+ auto* value = b.create<ast::BinaryExpression>(op, new_lhs(), rhs);
+ ctx.Replace(stmt, b.Assign(new_lhs(), value));
}
- // Replace the statement with a regular assignment statement.
- auto* value = b.create<ast::BinaryExpression>(op, new_lhs(), rhs);
- ctx.Replace(stmt, b.Assign(new_lhs(), value));
- }
-
- /// Finalize the transformation and clone the module.
- void Finalize() {
- hoist_to_decl_before.Apply();
- ctx.Clone();
- }
+ /// Finalize the transformation and clone the module.
+ void Finalize() {
+ hoist_to_decl_before.Apply();
+ ctx.Clone();
+ }
};
-void ExpandCompoundAssignment::Run(CloneContext& ctx,
- const DataMap&,
- DataMap&) const {
- State state(ctx);
- for (auto* node : ctx.src->ASTNodes().Objects()) {
- if (auto* assign = node->As<ast::CompoundAssignmentStatement>()) {
- state.Expand(assign, assign->lhs, ctx.Clone(assign->rhs), assign->op);
- } else if (auto* inc_dec = node->As<ast::IncrementDecrementStatement>()) {
- // For increment/decrement statements, `i++` becomes `i = i + 1`.
- // TODO(jrprice): Simplify this when we have untyped literals.
- auto* sem_lhs = ctx.src->Sem().Get(inc_dec->lhs);
- const ast::IntLiteralExpression* one =
- sem_lhs->Type()->UnwrapRef()->is_signed_integer_scalar()
- ? ctx.dst->Expr(1)->As<ast::IntLiteralExpression>()
- : ctx.dst->Expr(1u)->As<ast::IntLiteralExpression>();
- auto op =
- inc_dec->increment ? ast::BinaryOp::kAdd : ast::BinaryOp::kSubtract;
- state.Expand(inc_dec, inc_dec->lhs, one, op);
+void ExpandCompoundAssignment::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
+ State state(ctx);
+ for (auto* node : ctx.src->ASTNodes().Objects()) {
+ if (auto* assign = node->As<ast::CompoundAssignmentStatement>()) {
+ state.Expand(assign, assign->lhs, ctx.Clone(assign->rhs), assign->op);
+ } else if (auto* inc_dec = node->As<ast::IncrementDecrementStatement>()) {
+ // For increment/decrement statements, `i++` becomes `i = i + 1`.
+ // TODO(jrprice): Simplify this when we have untyped literals.
+ auto* sem_lhs = ctx.src->Sem().Get(inc_dec->lhs);
+ const ast::IntLiteralExpression* one =
+ sem_lhs->Type()->UnwrapRef()->is_signed_integer_scalar()
+ ? ctx.dst->Expr(1)->As<ast::IntLiteralExpression>()
+ : ctx.dst->Expr(1u)->As<ast::IntLiteralExpression>();
+ auto op = inc_dec->increment ? ast::BinaryOp::kAdd : ast::BinaryOp::kSubtract;
+ state.Expand(inc_dec, inc_dec->lhs, one, op);
+ }
}
- }
- state.Finalize();
+ state.Finalize();
}
} // namespace tint::transform
diff --git a/src/tint/transform/expand_compound_assignment.h b/src/tint/transform/expand_compound_assignment.h
index b461bed..d38d297 100644
--- a/src/tint/transform/expand_compound_assignment.h
+++ b/src/tint/transform/expand_compound_assignment.h
@@ -38,30 +38,26 @@
///
/// This transform also handles increment and decrement statements in the same
/// manner, by replacing `i++` with `i = i + 1`.
-class ExpandCompoundAssignment
- : public Castable<ExpandCompoundAssignment, Transform> {
- public:
- /// Constructor
- ExpandCompoundAssignment();
- /// Destructor
- ~ExpandCompoundAssignment() override;
+class ExpandCompoundAssignment : public Castable<ExpandCompoundAssignment, Transform> {
+ public:
+ /// Constructor
+ ExpandCompoundAssignment();
+ /// Destructor
+ ~ExpandCompoundAssignment() override;
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program,
- const DataMap& data = {}) const override;
+ /// @param program the program to inspect
+ /// @param data optional extra transform-specific input data
+ /// @returns true if this transform should be run for the given program
+ bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx,
- const DataMap& inputs,
- DataMap& outputs) const override;
+ protected:
+ /// Runs the transform using the CloneContext built for transforming a
+ /// program. Run() is responsible for calling Clone() on the CloneContext.
+ /// @param ctx the CloneContext primed with the input program and
+ /// ProgramBuilder
+ /// @param inputs optional extra transform-specific input data
+ /// @param outputs optional extra transform-specific output data
+ void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/expand_compound_assignment_test.cc b/src/tint/transform/expand_compound_assignment_test.cc
index d3fa510..2d343b9 100644
--- a/src/tint/transform/expand_compound_assignment_test.cc
+++ b/src/tint/transform/expand_compound_assignment_test.cc
@@ -24,55 +24,55 @@
using ExpandCompoundAssignmentTest = TransformTest;
TEST_F(ExpandCompoundAssignmentTest, ShouldRunEmptyModule) {
- auto* src = R"()";
+ auto* src = R"()";
- EXPECT_FALSE(ShouldRun<ExpandCompoundAssignment>(src));
+ EXPECT_FALSE(ShouldRun<ExpandCompoundAssignment>(src));
}
TEST_F(ExpandCompoundAssignmentTest, ShouldRunHasCompoundAssignment) {
- auto* src = R"(
+ auto* src = R"(
fn foo() {
var v : i32;
v += 1;
}
)";
- EXPECT_TRUE(ShouldRun<ExpandCompoundAssignment>(src));
+ EXPECT_TRUE(ShouldRun<ExpandCompoundAssignment>(src));
}
TEST_F(ExpandCompoundAssignmentTest, ShouldRunHasIncrementDecrement) {
- auto* src = R"(
+ auto* src = R"(
fn foo() {
var v : i32;
v++;
}
)";
- EXPECT_TRUE(ShouldRun<ExpandCompoundAssignment>(src));
+ EXPECT_TRUE(ShouldRun<ExpandCompoundAssignment>(src));
}
TEST_F(ExpandCompoundAssignmentTest, Basic) {
- auto* src = R"(
+ auto* src = R"(
fn main() {
var v : i32;
v += 1;
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn main() {
var v : i32;
v = (v + 1);
}
)";
- auto got = Run<ExpandCompoundAssignment>(src);
+ auto got = Run<ExpandCompoundAssignment>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ExpandCompoundAssignmentTest, LhsPointer) {
- auto* src = R"(
+ auto* src = R"(
fn main() {
var v : i32;
let p = &v;
@@ -80,7 +80,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn main() {
var v : i32;
let p = &(v);
@@ -89,13 +89,13 @@
}
)";
- auto got = Run<ExpandCompoundAssignment>(src);
+ auto got = Run<ExpandCompoundAssignment>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ExpandCompoundAssignmentTest, LhsStructMember) {
- auto* src = R"(
+ auto* src = R"(
struct S {
m : f32,
}
@@ -106,7 +106,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct S {
m : f32,
}
@@ -117,13 +117,13 @@
}
)";
- auto got = Run<ExpandCompoundAssignment>(src);
+ auto got = Run<ExpandCompoundAssignment>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ExpandCompoundAssignmentTest, LhsArrayElement) {
- auto* src = R"(
+ auto* src = R"(
var<private> a : array<i32, 4>;
fn idx() -> i32 {
@@ -136,7 +136,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> a : array<i32, 4>;
fn idx() -> i32 {
@@ -150,13 +150,13 @@
}
)";
- auto got = Run<ExpandCompoundAssignment>(src);
+ auto got = Run<ExpandCompoundAssignment>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ExpandCompoundAssignmentTest, LhsVectorComponent_ArrayAccessor) {
- auto* src = R"(
+ auto* src = R"(
var<private> v : vec4<i32>;
fn idx() -> i32 {
@@ -169,7 +169,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> v : vec4<i32>;
fn idx() -> i32 {
@@ -184,33 +184,33 @@
}
)";
- auto got = Run<ExpandCompoundAssignment>(src);
+ auto got = Run<ExpandCompoundAssignment>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ExpandCompoundAssignmentTest, LhsVectorComponent_MemberAccessor) {
- auto* src = R"(
+ auto* src = R"(
fn main() {
var v : vec4<i32>;
v.y += 1;
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn main() {
var v : vec4<i32>;
v.y = (v.y + 1);
}
)";
- auto got = Run<ExpandCompoundAssignment>(src);
+ auto got = Run<ExpandCompoundAssignment>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ExpandCompoundAssignmentTest, LhsMatrixColumn) {
- auto* src = R"(
+ auto* src = R"(
var<private> m : mat4x4<f32>;
fn idx() -> i32 {
@@ -223,7 +223,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> m : mat4x4<f32>;
fn idx() -> i32 {
@@ -237,13 +237,13 @@
}
)";
- auto got = Run<ExpandCompoundAssignment>(src);
+ auto got = Run<ExpandCompoundAssignment>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ExpandCompoundAssignmentTest, LhsMatrixElement) {
- auto* src = R"(
+ auto* src = R"(
var<private> m : mat4x4<f32>;
fn idx1() -> i32 {
@@ -261,7 +261,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> m : mat4x4<f32>;
fn idx1() -> i32 {
@@ -281,13 +281,13 @@
}
)";
- auto got = Run<ExpandCompoundAssignment>(src);
+ auto got = Run<ExpandCompoundAssignment>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ExpandCompoundAssignmentTest, LhsMultipleSideEffects) {
- auto* src = R"(
+ auto* src = R"(
struct S {
a : array<vec4<f32>, 3>,
}
@@ -316,7 +316,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct S {
a : array<vec4<f32>, 3>,
}
@@ -347,13 +347,13 @@
}
)";
- auto got = Run<ExpandCompoundAssignment>(src);
+ auto got = Run<ExpandCompoundAssignment>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ExpandCompoundAssignmentTest, ForLoopInit) {
- auto* src = R"(
+ auto* src = R"(
var<private> a : array<vec4<i32>, 4>;
var<private> p : i32;
@@ -375,7 +375,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> a : array<vec4<i32>, 4>;
var<private> p : i32;
@@ -399,13 +399,13 @@
}
)";
- auto got = Run<ExpandCompoundAssignment>(src);
+ auto got = Run<ExpandCompoundAssignment>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ExpandCompoundAssignmentTest, ForLoopCont) {
- auto* src = R"(
+ auto* src = R"(
var<private> a : array<vec4<i32>, 4>;
var<private> p : i32;
@@ -427,7 +427,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> a : array<vec4<i32>, 4>;
var<private> p : i32;
@@ -457,93 +457,93 @@
}
)";
- auto got = Run<ExpandCompoundAssignment>(src);
+ auto got = Run<ExpandCompoundAssignment>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ExpandCompoundAssignmentTest, Increment_I32) {
- auto* src = R"(
+ auto* src = R"(
fn main() {
var v : i32;
v++;
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn main() {
var v : i32;
v = (v + 1);
}
)";
- auto got = Run<ExpandCompoundAssignment>(src);
+ auto got = Run<ExpandCompoundAssignment>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ExpandCompoundAssignmentTest, Increment_U32) {
- auto* src = R"(
+ auto* src = R"(
fn main() {
var v : u32;
v++;
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn main() {
var v : u32;
v = (v + 1u);
}
)";
- auto got = Run<ExpandCompoundAssignment>(src);
+ auto got = Run<ExpandCompoundAssignment>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ExpandCompoundAssignmentTest, Decrement_I32) {
- auto* src = R"(
+ auto* src = R"(
fn main() {
var v : i32;
v--;
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn main() {
var v : i32;
v = (v - 1);
}
)";
- auto got = Run<ExpandCompoundAssignment>(src);
+ auto got = Run<ExpandCompoundAssignment>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ExpandCompoundAssignmentTest, Decrement_U32) {
- auto* src = R"(
+ auto* src = R"(
fn main() {
var v : u32;
v--;
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn main() {
var v : u32;
v = (v - 1u);
}
)";
- auto got = Run<ExpandCompoundAssignment>(src);
+ auto got = Run<ExpandCompoundAssignment>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ExpandCompoundAssignmentTest, Increment_LhsPointer) {
- auto* src = R"(
+ auto* src = R"(
fn main() {
var v : i32;
let p = &v;
@@ -551,7 +551,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn main() {
var v : i32;
let p = &(v);
@@ -560,13 +560,13 @@
}
)";
- auto got = Run<ExpandCompoundAssignment>(src);
+ auto got = Run<ExpandCompoundAssignment>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ExpandCompoundAssignmentTest, Increment_LhsStructMember) {
- auto* src = R"(
+ auto* src = R"(
struct S {
m : i32,
}
@@ -577,7 +577,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct S {
m : i32,
}
@@ -588,13 +588,13 @@
}
)";
- auto got = Run<ExpandCompoundAssignment>(src);
+ auto got = Run<ExpandCompoundAssignment>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ExpandCompoundAssignmentTest, Increment_LhsArrayElement) {
- auto* src = R"(
+ auto* src = R"(
var<private> a : array<i32, 4>;
fn idx() -> i32 {
@@ -607,7 +607,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> a : array<i32, 4>;
fn idx() -> i32 {
@@ -621,14 +621,13 @@
}
)";
- auto got = Run<ExpandCompoundAssignment>(src);
+ auto got = Run<ExpandCompoundAssignment>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
-TEST_F(ExpandCompoundAssignmentTest,
- Increment_LhsVectorComponent_ArrayAccessor) {
- auto* src = R"(
+TEST_F(ExpandCompoundAssignmentTest, Increment_LhsVectorComponent_ArrayAccessor) {
+ auto* src = R"(
var<private> v : vec4<i32>;
fn idx() -> i32 {
@@ -641,7 +640,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> v : vec4<i32>;
fn idx() -> i32 {
@@ -656,34 +655,33 @@
}
)";
- auto got = Run<ExpandCompoundAssignment>(src);
+ auto got = Run<ExpandCompoundAssignment>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
-TEST_F(ExpandCompoundAssignmentTest,
- Increment_LhsVectorComponent_MemberAccessor) {
- auto* src = R"(
+TEST_F(ExpandCompoundAssignmentTest, Increment_LhsVectorComponent_MemberAccessor) {
+ auto* src = R"(
fn main() {
var v : vec4<i32>;
v.y++;
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn main() {
var v : vec4<i32>;
v.y = (v.y + 1);
}
)";
- auto got = Run<ExpandCompoundAssignment>(src);
+ auto got = Run<ExpandCompoundAssignment>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ExpandCompoundAssignmentTest, Increment_ForLoopCont) {
- auto* src = R"(
+ auto* src = R"(
var<private> a : array<vec4<i32>, 4>;
var<private> p : i32;
@@ -705,7 +703,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> a : array<vec4<i32>, 4>;
var<private> p : i32;
@@ -735,9 +733,9 @@
}
)";
- auto got = Run<ExpandCompoundAssignment>(src);
+ auto got = Run<ExpandCompoundAssignment>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
} // namespace
diff --git a/src/tint/transform/first_index_offset.cc b/src/tint/transform/first_index_offset.cc
index e9d2629..9d89b83 100644
--- a/src/tint/transform/first_index_offset.cc
+++ b/src/tint/transform/first_index_offset.cc
@@ -38,8 +38,7 @@
} // namespace
FirstIndexOffset::BindingPoint::BindingPoint() = default;
-FirstIndexOffset::BindingPoint::BindingPoint(uint32_t b, uint32_t g)
- : binding(b), group(g) {}
+FirstIndexOffset::BindingPoint::BindingPoint(uint32_t b, uint32_t g) : binding(b), group(g) {}
FirstIndexOffset::BindingPoint::~BindingPoint() = default;
FirstIndexOffset::Data::Data(bool has_vtx_or_inst_index)
@@ -51,115 +50,109 @@
FirstIndexOffset::~FirstIndexOffset() = default;
bool FirstIndexOffset::ShouldRun(const Program* program, const DataMap&) const {
- for (auto* fn : program->AST().Functions()) {
- if (fn->PipelineStage() == ast::PipelineStage::kVertex) {
- return true;
+ for (auto* fn : program->AST().Functions()) {
+ if (fn->PipelineStage() == ast::PipelineStage::kVertex) {
+ return true;
+ }
}
- }
- return false;
+ return false;
}
-void FirstIndexOffset::Run(CloneContext& ctx,
- const DataMap& inputs,
- DataMap& outputs) const {
- // Get the uniform buffer binding point
- uint32_t ub_binding = binding_;
- uint32_t ub_group = group_;
- if (auto* binding_point = inputs.Get<BindingPoint>()) {
- ub_binding = binding_point->binding;
- ub_group = binding_point->group;
- }
-
- // Map of builtin usages
- std::unordered_map<const sem::Variable*, const char*> builtin_vars;
- std::unordered_map<const sem::StructMember*, const char*> builtin_members;
-
- bool has_vertex_or_instance_index = false;
-
- // Traverse the AST scanning for builtin accesses via variables (includes
- // parameters) or structure member accesses.
- for (auto* node : ctx.src->ASTNodes().Objects()) {
- if (auto* var = node->As<ast::Variable>()) {
- for (auto* attr : var->attributes) {
- if (auto* builtin_attr = attr->As<ast::BuiltinAttribute>()) {
- ast::Builtin builtin = builtin_attr->builtin;
- if (builtin == ast::Builtin::kVertexIndex) {
- auto* sem_var = ctx.src->Sem().Get(var);
- builtin_vars.emplace(sem_var, kFirstVertexName);
- has_vertex_or_instance_index = true;
- }
- if (builtin == ast::Builtin::kInstanceIndex) {
- auto* sem_var = ctx.src->Sem().Get(var);
- builtin_vars.emplace(sem_var, kFirstInstanceName);
- has_vertex_or_instance_index = true;
- }
- }
- }
+void FirstIndexOffset::Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const {
+ // Get the uniform buffer binding point
+ uint32_t ub_binding = binding_;
+ uint32_t ub_group = group_;
+ if (auto* binding_point = inputs.Get<BindingPoint>()) {
+ ub_binding = binding_point->binding;
+ ub_group = binding_point->group;
}
- if (auto* member = node->As<ast::StructMember>()) {
- for (auto* attr : member->attributes) {
- if (auto* builtin_attr = attr->As<ast::BuiltinAttribute>()) {
- ast::Builtin builtin = builtin_attr->builtin;
- if (builtin == ast::Builtin::kVertexIndex) {
- auto* sem_mem = ctx.src->Sem().Get(member);
- builtin_members.emplace(sem_mem, kFirstVertexName);
- has_vertex_or_instance_index = true;
- }
- if (builtin == ast::Builtin::kInstanceIndex) {
- auto* sem_mem = ctx.src->Sem().Get(member);
- builtin_members.emplace(sem_mem, kFirstInstanceName);
- has_vertex_or_instance_index = true;
- }
+
+ // Map of builtin usages
+ std::unordered_map<const sem::Variable*, const char*> builtin_vars;
+ std::unordered_map<const sem::StructMember*, const char*> builtin_members;
+
+ bool has_vertex_or_instance_index = false;
+
+ // Traverse the AST scanning for builtin accesses via variables (includes
+ // parameters) or structure member accesses.
+ for (auto* node : ctx.src->ASTNodes().Objects()) {
+ if (auto* var = node->As<ast::Variable>()) {
+ for (auto* attr : var->attributes) {
+ if (auto* builtin_attr = attr->As<ast::BuiltinAttribute>()) {
+ ast::Builtin builtin = builtin_attr->builtin;
+ if (builtin == ast::Builtin::kVertexIndex) {
+ auto* sem_var = ctx.src->Sem().Get(var);
+ builtin_vars.emplace(sem_var, kFirstVertexName);
+ has_vertex_or_instance_index = true;
+ }
+ if (builtin == ast::Builtin::kInstanceIndex) {
+ auto* sem_var = ctx.src->Sem().Get(var);
+ builtin_vars.emplace(sem_var, kFirstInstanceName);
+ has_vertex_or_instance_index = true;
+ }
+ }
+ }
}
- }
+ if (auto* member = node->As<ast::StructMember>()) {
+ for (auto* attr : member->attributes) {
+ if (auto* builtin_attr = attr->As<ast::BuiltinAttribute>()) {
+ ast::Builtin builtin = builtin_attr->builtin;
+ if (builtin == ast::Builtin::kVertexIndex) {
+ auto* sem_mem = ctx.src->Sem().Get(member);
+ builtin_members.emplace(sem_mem, kFirstVertexName);
+ has_vertex_or_instance_index = true;
+ }
+ if (builtin == ast::Builtin::kInstanceIndex) {
+ auto* sem_mem = ctx.src->Sem().Get(member);
+ builtin_members.emplace(sem_mem, kFirstInstanceName);
+ has_vertex_or_instance_index = true;
+ }
+ }
+ }
+ }
}
- }
- if (has_vertex_or_instance_index) {
- // Add uniform buffer members and calculate byte offsets
- ast::StructMemberList members;
- members.push_back(ctx.dst->Member(kFirstVertexName, ctx.dst->ty.u32()));
- members.push_back(ctx.dst->Member(kFirstInstanceName, ctx.dst->ty.u32()));
- auto* struct_ = ctx.dst->Structure(ctx.dst->Sym(), std::move(members));
+ if (has_vertex_or_instance_index) {
+ // Add uniform buffer members and calculate byte offsets
+ ast::StructMemberList members;
+ members.push_back(ctx.dst->Member(kFirstVertexName, ctx.dst->ty.u32()));
+ members.push_back(ctx.dst->Member(kFirstInstanceName, ctx.dst->ty.u32()));
+ auto* struct_ = ctx.dst->Structure(ctx.dst->Sym(), std::move(members));
- // Create a global to hold the uniform buffer
- Symbol buffer_name = ctx.dst->Sym();
- ctx.dst->Global(buffer_name, ctx.dst->ty.Of(struct_),
- ast::StorageClass::kUniform, nullptr,
- ast::AttributeList{
- ctx.dst->create<ast::BindingAttribute>(ub_binding),
- ctx.dst->create<ast::GroupAttribute>(ub_group),
- });
+ // Create a global to hold the uniform buffer
+ Symbol buffer_name = ctx.dst->Sym();
+ ctx.dst->Global(buffer_name, ctx.dst->ty.Of(struct_), ast::StorageClass::kUniform, nullptr,
+ ast::AttributeList{
+ ctx.dst->create<ast::BindingAttribute>(ub_binding),
+ ctx.dst->create<ast::GroupAttribute>(ub_group),
+ });
- // Fix up all references to the builtins with the offsets
- ctx.ReplaceAll(
- [=, &ctx](const ast::Expression* expr) -> const ast::Expression* {
- if (auto* sem = ctx.src->Sem().Get(expr)) {
- if (auto* user = sem->As<sem::VariableUser>()) {
- auto it = builtin_vars.find(user->Variable());
- if (it != builtin_vars.end()) {
- return ctx.dst->Add(
- ctx.CloneWithoutTransform(expr),
- ctx.dst->MemberAccessor(buffer_name, it->second));
- }
+ // Fix up all references to the builtins with the offsets
+ ctx.ReplaceAll([=, &ctx](const ast::Expression* expr) -> const ast::Expression* {
+ if (auto* sem = ctx.src->Sem().Get(expr)) {
+ if (auto* user = sem->As<sem::VariableUser>()) {
+ auto it = builtin_vars.find(user->Variable());
+ if (it != builtin_vars.end()) {
+ return ctx.dst->Add(ctx.CloneWithoutTransform(expr),
+ ctx.dst->MemberAccessor(buffer_name, it->second));
+ }
+ }
+ if (auto* access = sem->As<sem::StructMemberAccess>()) {
+ auto it = builtin_members.find(access->Member());
+ if (it != builtin_members.end()) {
+ return ctx.dst->Add(ctx.CloneWithoutTransform(expr),
+ ctx.dst->MemberAccessor(buffer_name, it->second));
+ }
+ }
}
- if (auto* access = sem->As<sem::StructMemberAccess>()) {
- auto it = builtin_members.find(access->Member());
- if (it != builtin_members.end()) {
- return ctx.dst->Add(
- ctx.CloneWithoutTransform(expr),
- ctx.dst->MemberAccessor(buffer_name, it->second));
- }
- }
- }
- // Not interested in this experssion. Just clone.
- return nullptr;
+ // Not interested in this expression. Just clone.
+ return nullptr;
});
- }
+ }
- ctx.Clone();
+ ctx.Clone();
- outputs.Add<Data>(has_vertex_or_instance_index);
+ outputs.Add<Data>(has_vertex_or_instance_index);
}
} // namespace tint::transform
diff --git a/src/tint/transform/first_index_offset.h b/src/tint/transform/first_index_offset.h
index f0ec791..04758cd 100644
--- a/src/tint/transform/first_index_offset.h
+++ b/src/tint/transform/first_index_offset.h
@@ -58,71 +58,68 @@
/// ```
///
class FirstIndexOffset final : public Castable<FirstIndexOffset, Transform> {
- public:
- /// BindingPoint is consumed by the FirstIndexOffset transform.
- /// BindingPoint specifies the binding point of the first index uniform
- /// buffer.
- struct BindingPoint final : public Castable<BindingPoint, transform::Data> {
- /// Constructor
- BindingPoint();
+ public:
+ /// BindingPoint is consumed by the FirstIndexOffset transform.
+ /// BindingPoint specifies the binding point of the first index uniform
+ /// buffer.
+ struct BindingPoint final : public Castable<BindingPoint, transform::Data> {
+ /// Constructor
+ BindingPoint();
+
+ /// Constructor
+ /// @param b the binding index
+ /// @param g the binding group
+ BindingPoint(uint32_t b, uint32_t g);
+
+ /// Destructor
+ ~BindingPoint() override;
+
+ /// `@binding()` for the first vertex / first instance uniform buffer
+ uint32_t binding = 0;
+ /// `@group()` for the first vertex / first instance uniform buffer
+ uint32_t group = 0;
+ };
+
+ /// Data is outputted by the FirstIndexOffset transform.
+ /// Data holds information about shader usage and constant buffer offsets.
+ struct Data final : public Castable<Data, transform::Data> {
+ /// Constructor
+ /// @param has_vtx_or_inst_index True if the shader uses vertex_index or
+ /// instance_index
+ explicit Data(bool has_vtx_or_inst_index);
+
+ /// Copy constructor
+ Data(const Data&);
+
+ /// Destructor
+ ~Data() override;
+
+ /// True if the shader uses vertex_index or instance_index
+ const bool has_vertex_or_instance_index;
+ };
/// Constructor
- /// @param b the binding index
- /// @param g the binding group
- BindingPoint(uint32_t b, uint32_t g);
-
+ FirstIndexOffset();
/// Destructor
- ~BindingPoint() override;
+ ~FirstIndexOffset() override;
- /// `@binding()` for the first vertex / first instance uniform buffer
- uint32_t binding = 0;
- /// `@group()` for the first vertex / first instance uniform buffer
- uint32_t group = 0;
- };
+ /// @param program the program to inspect
+ /// @param data optional extra transform-specific input data
+ /// @returns true if this transform should be run for the given program
+ bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
- /// Data is outputted by the FirstIndexOffset transform.
- /// Data holds information about shader usage and constant buffer offsets.
- struct Data final : public Castable<Data, transform::Data> {
- /// Constructor
- /// @param has_vtx_or_inst_index True if the shader uses vertex_index or
- /// instance_index
- explicit Data(bool has_vtx_or_inst_index);
+ protected:
+ /// Runs the transform using the CloneContext built for transforming a
+ /// program. Run() is responsible for calling Clone() on the CloneContext.
+ /// @param ctx the CloneContext primed with the input program and
+ /// ProgramBuilder
+ /// @param inputs optional extra transform-specific input data
+ /// @param outputs optional extra transform-specific output data
+ void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
- /// Copy constructor
- Data(const Data&);
-
- /// Destructor
- ~Data() override;
-
- /// True if the shader uses vertex_index
- const bool has_vertex_or_instance_index;
- };
-
- /// Constructor
- FirstIndexOffset();
- /// Destructor
- ~FirstIndexOffset() override;
-
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program,
- const DataMap& data = {}) const override;
-
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx,
- const DataMap& inputs,
- DataMap& outputs) const override;
-
- private:
- uint32_t binding_ = 0;
- uint32_t group_ = 0;
+ private:
+ uint32_t binding_ = 0;
+ uint32_t group_ = 0;
};
} // namespace tint::transform
diff --git a/src/tint/transform/first_index_offset_test.cc b/src/tint/transform/first_index_offset_test.cc
index 3dc4c71..a467c17 100644
--- a/src/tint/transform/first_index_offset_test.cc
+++ b/src/tint/transform/first_index_offset_test.cc
@@ -26,71 +26,71 @@
using FirstIndexOffsetTest = TransformTest;
TEST_F(FirstIndexOffsetTest, ShouldRunEmptyModule) {
- auto* src = R"()";
+ auto* src = R"()";
- EXPECT_FALSE(ShouldRun<FirstIndexOffset>(src));
+ EXPECT_FALSE(ShouldRun<FirstIndexOffset>(src));
}
TEST_F(FirstIndexOffsetTest, ShouldRunFragmentStage) {
- auto* src = R"(
+ auto* src = R"(
@stage(fragment)
fn entry() {
return;
}
)";
- EXPECT_FALSE(ShouldRun<FirstIndexOffset>(src));
+ EXPECT_FALSE(ShouldRun<FirstIndexOffset>(src));
}
TEST_F(FirstIndexOffsetTest, ShouldRunVertexStage) {
- auto* src = R"(
+ auto* src = R"(
@stage(vertex)
fn entry() -> @builtin(position) vec4<f32> {
return vec4<f32>();
}
)";
- EXPECT_TRUE(ShouldRun<FirstIndexOffset>(src));
+ EXPECT_TRUE(ShouldRun<FirstIndexOffset>(src));
}
TEST_F(FirstIndexOffsetTest, EmptyModule) {
- auto* src = "";
- auto* expect = "";
+ auto* src = "";
+ auto* expect = "";
- DataMap config;
- config.Add<FirstIndexOffset::BindingPoint>(0, 0);
- auto got = Run<FirstIndexOffset>(src, std::move(config));
+ DataMap config;
+ config.Add<FirstIndexOffset::BindingPoint>(0, 0);
+ auto got = Run<FirstIndexOffset>(src, std::move(config));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
- auto* data = got.data.Get<FirstIndexOffset::Data>();
+ auto* data = got.data.Get<FirstIndexOffset::Data>();
- EXPECT_EQ(data, nullptr);
+ EXPECT_EQ(data, nullptr);
}
TEST_F(FirstIndexOffsetTest, BasicVertexShader) {
- auto* src = R"(
+ auto* src = R"(
@stage(vertex)
fn entry() -> @builtin(position) vec4<f32> {
return vec4<f32>();
}
)";
- auto* expect = src;
+ auto* expect = src;
- DataMap config;
- config.Add<FirstIndexOffset::BindingPoint>(0, 0);
- auto got = Run<FirstIndexOffset>(src, std::move(config));
+ DataMap config;
+ config.Add<FirstIndexOffset::BindingPoint>(0, 0);
+ auto got = Run<FirstIndexOffset>(src, std::move(config));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
- auto* data = got.data.Get<FirstIndexOffset::Data>();
+ auto* data = got.data.Get<FirstIndexOffset::Data>();
- ASSERT_NE(data, nullptr);
- EXPECT_EQ(data->has_vertex_or_instance_index, false);
+ ASSERT_NE(data, nullptr);
+ EXPECT_EQ(data->has_vertex_or_instance_index, false);
}
TEST_F(FirstIndexOffsetTest, BasicModuleVertexIndex) {
- auto* src = R"(
+ auto* src = R"(
fn test(vert_idx : u32) -> u32 {
return vert_idx;
}
@@ -102,7 +102,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol {
first_vertex_index : u32,
first_instance_index : u32,
@@ -121,20 +121,20 @@
}
)";
- DataMap config;
- config.Add<FirstIndexOffset::BindingPoint>(1, 2);
- auto got = Run<FirstIndexOffset>(src, std::move(config));
+ DataMap config;
+ config.Add<FirstIndexOffset::BindingPoint>(1, 2);
+ auto got = Run<FirstIndexOffset>(src, std::move(config));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
- auto* data = got.data.Get<FirstIndexOffset::Data>();
+ auto* data = got.data.Get<FirstIndexOffset::Data>();
- ASSERT_NE(data, nullptr);
- EXPECT_EQ(data->has_vertex_or_instance_index, true);
+ ASSERT_NE(data, nullptr);
+ EXPECT_EQ(data->has_vertex_or_instance_index, true);
}
TEST_F(FirstIndexOffsetTest, BasicModuleVertexIndex_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(vertex)
fn entry(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
test(vert_idx);
@@ -146,7 +146,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol {
first_vertex_index : u32,
first_instance_index : u32,
@@ -165,20 +165,20 @@
}
)";
- DataMap config;
- config.Add<FirstIndexOffset::BindingPoint>(1, 2);
- auto got = Run<FirstIndexOffset>(src, std::move(config));
+ DataMap config;
+ config.Add<FirstIndexOffset::BindingPoint>(1, 2);
+ auto got = Run<FirstIndexOffset>(src, std::move(config));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
- auto* data = got.data.Get<FirstIndexOffset::Data>();
+ auto* data = got.data.Get<FirstIndexOffset::Data>();
- ASSERT_NE(data, nullptr);
- EXPECT_EQ(data->has_vertex_or_instance_index, true);
+ ASSERT_NE(data, nullptr);
+ EXPECT_EQ(data->has_vertex_or_instance_index, true);
}
TEST_F(FirstIndexOffsetTest, BasicModuleInstanceIndex) {
- auto* src = R"(
+ auto* src = R"(
fn test(inst_idx : u32) -> u32 {
return inst_idx;
}
@@ -190,7 +190,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol {
first_vertex_index : u32,
first_instance_index : u32,
@@ -209,20 +209,20 @@
}
)";
- DataMap config;
- config.Add<FirstIndexOffset::BindingPoint>(1, 7);
- auto got = Run<FirstIndexOffset>(src, std::move(config));
+ DataMap config;
+ config.Add<FirstIndexOffset::BindingPoint>(1, 7);
+ auto got = Run<FirstIndexOffset>(src, std::move(config));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
- auto* data = got.data.Get<FirstIndexOffset::Data>();
+ auto* data = got.data.Get<FirstIndexOffset::Data>();
- ASSERT_NE(data, nullptr);
- EXPECT_EQ(data->has_vertex_or_instance_index, true);
+ ASSERT_NE(data, nullptr);
+ EXPECT_EQ(data->has_vertex_or_instance_index, true);
}
TEST_F(FirstIndexOffsetTest, BasicModuleInstanceIndex_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(vertex)
fn entry(@builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
test(inst_idx);
@@ -234,7 +234,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol {
first_vertex_index : u32,
first_instance_index : u32,
@@ -253,20 +253,20 @@
}
)";
- DataMap config;
- config.Add<FirstIndexOffset::BindingPoint>(1, 7);
- auto got = Run<FirstIndexOffset>(src, std::move(config));
+ DataMap config;
+ config.Add<FirstIndexOffset::BindingPoint>(1, 7);
+ auto got = Run<FirstIndexOffset>(src, std::move(config));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
- auto* data = got.data.Get<FirstIndexOffset::Data>();
+ auto* data = got.data.Get<FirstIndexOffset::Data>();
- ASSERT_NE(data, nullptr);
- EXPECT_EQ(data->has_vertex_or_instance_index, true);
+ ASSERT_NE(data, nullptr);
+ EXPECT_EQ(data->has_vertex_or_instance_index, true);
}
TEST_F(FirstIndexOffsetTest, BasicModuleBothIndex) {
- auto* src = R"(
+ auto* src = R"(
fn test(instance_idx : u32, vert_idx : u32) -> u32 {
return instance_idx + vert_idx;
}
@@ -283,7 +283,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol {
first_vertex_index : u32,
first_instance_index : u32,
@@ -309,20 +309,20 @@
}
)";
- DataMap config;
- config.Add<FirstIndexOffset::BindingPoint>(1, 2);
- auto got = Run<FirstIndexOffset>(src, std::move(config));
+ DataMap config;
+ config.Add<FirstIndexOffset::BindingPoint>(1, 2);
+ auto got = Run<FirstIndexOffset>(src, std::move(config));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
- auto* data = got.data.Get<FirstIndexOffset::Data>();
+ auto* data = got.data.Get<FirstIndexOffset::Data>();
- ASSERT_NE(data, nullptr);
- EXPECT_EQ(data->has_vertex_or_instance_index, true);
+ ASSERT_NE(data, nullptr);
+ EXPECT_EQ(data->has_vertex_or_instance_index, true);
}
TEST_F(FirstIndexOffsetTest, BasicModuleBothIndex_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(vertex)
fn entry(inputs : Inputs) -> @builtin(position) vec4<f32> {
test(inputs.instance_idx, inputs.vert_idx);
@@ -339,7 +339,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol {
first_vertex_index : u32,
first_instance_index : u32,
@@ -365,20 +365,20 @@
}
)";
- DataMap config;
- config.Add<FirstIndexOffset::BindingPoint>(1, 2);
- auto got = Run<FirstIndexOffset>(src, std::move(config));
+ DataMap config;
+ config.Add<FirstIndexOffset::BindingPoint>(1, 2);
+ auto got = Run<FirstIndexOffset>(src, std::move(config));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
- auto* data = got.data.Get<FirstIndexOffset::Data>();
+ auto* data = got.data.Get<FirstIndexOffset::Data>();
- ASSERT_NE(data, nullptr);
- EXPECT_EQ(data->has_vertex_or_instance_index, true);
+ ASSERT_NE(data, nullptr);
+ EXPECT_EQ(data->has_vertex_or_instance_index, true);
}
TEST_F(FirstIndexOffsetTest, NestedCalls) {
- auto* src = R"(
+ auto* src = R"(
fn func1(vert_idx : u32) -> u32 {
return vert_idx;
}
@@ -394,7 +394,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol {
first_vertex_index : u32,
first_instance_index : u32,
@@ -417,20 +417,20 @@
}
)";
- DataMap config;
- config.Add<FirstIndexOffset::BindingPoint>(1, 2);
- auto got = Run<FirstIndexOffset>(src, std::move(config));
+ DataMap config;
+ config.Add<FirstIndexOffset::BindingPoint>(1, 2);
+ auto got = Run<FirstIndexOffset>(src, std::move(config));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
- auto* data = got.data.Get<FirstIndexOffset::Data>();
+ auto* data = got.data.Get<FirstIndexOffset::Data>();
- ASSERT_NE(data, nullptr);
- EXPECT_EQ(data->has_vertex_or_instance_index, true);
+ ASSERT_NE(data, nullptr);
+ EXPECT_EQ(data->has_vertex_or_instance_index, true);
}
TEST_F(FirstIndexOffsetTest, NestedCalls_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(vertex)
fn entry(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
func2(vert_idx);
@@ -446,7 +446,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol {
first_vertex_index : u32,
first_instance_index : u32,
@@ -469,20 +469,20 @@
}
)";
- DataMap config;
- config.Add<FirstIndexOffset::BindingPoint>(1, 2);
- auto got = Run<FirstIndexOffset>(src, std::move(config));
+ DataMap config;
+ config.Add<FirstIndexOffset::BindingPoint>(1, 2);
+ auto got = Run<FirstIndexOffset>(src, std::move(config));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
- auto* data = got.data.Get<FirstIndexOffset::Data>();
+ auto* data = got.data.Get<FirstIndexOffset::Data>();
- ASSERT_NE(data, nullptr);
- EXPECT_EQ(data->has_vertex_or_instance_index, true);
+ ASSERT_NE(data, nullptr);
+ EXPECT_EQ(data->has_vertex_or_instance_index, true);
}
TEST_F(FirstIndexOffsetTest, MultipleEntryPoints) {
- auto* src = R"(
+ auto* src = R"(
fn func(i : u32) -> u32 {
return i;
}
@@ -506,7 +506,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol {
first_vertex_index : u32,
first_instance_index : u32,
@@ -537,20 +537,20 @@
}
)";
- DataMap config;
- config.Add<FirstIndexOffset::BindingPoint>(1, 2);
- auto got = Run<FirstIndexOffset>(src, std::move(config));
+ DataMap config;
+ config.Add<FirstIndexOffset::BindingPoint>(1, 2);
+ auto got = Run<FirstIndexOffset>(src, std::move(config));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
- auto* data = got.data.Get<FirstIndexOffset::Data>();
+ auto* data = got.data.Get<FirstIndexOffset::Data>();
- ASSERT_NE(data, nullptr);
- EXPECT_EQ(data->has_vertex_or_instance_index, true);
+ ASSERT_NE(data, nullptr);
+ EXPECT_EQ(data->has_vertex_or_instance_index, true);
}
TEST_F(FirstIndexOffsetTest, MultipleEntryPoints_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(vertex)
fn entry_a(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
func(vert_idx);
@@ -574,7 +574,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol {
first_vertex_index : u32,
first_instance_index : u32,
@@ -605,16 +605,16 @@
}
)";
- DataMap config;
- config.Add<FirstIndexOffset::BindingPoint>(1, 2);
- auto got = Run<FirstIndexOffset>(src, std::move(config));
+ DataMap config;
+ config.Add<FirstIndexOffset::BindingPoint>(1, 2);
+ auto got = Run<FirstIndexOffset>(src, std::move(config));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
- auto* data = got.data.Get<FirstIndexOffset::Data>();
+ auto* data = got.data.Get<FirstIndexOffset::Data>();
- ASSERT_NE(data, nullptr);
- EXPECT_EQ(data->has_vertex_or_instance_index, true);
+ ASSERT_NE(data, nullptr);
+ EXPECT_EQ(data->has_vertex_or_instance_index, true);
}
} // namespace
diff --git a/src/tint/transform/fold_constants.cc b/src/tint/transform/fold_constants.cc
index b814c5c..be547b1 100644
--- a/src/tint/transform/fold_constants.cc
+++ b/src/tint/transform/fold_constants.cc
@@ -33,65 +33,62 @@
FoldConstants::~FoldConstants() = default;
void FoldConstants::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
- ctx.ReplaceAll([&](const ast::Expression* expr) -> const ast::Expression* {
- auto* call = ctx.src->Sem().Get<sem::Call>(expr);
- if (!call) {
- return nullptr;
- }
+ ctx.ReplaceAll([&](const ast::Expression* expr) -> const ast::Expression* {
+ auto* call = ctx.src->Sem().Get<sem::Call>(expr);
+ if (!call) {
+ return nullptr;
+ }
- auto value = call->ConstantValue();
- if (!value.IsValid()) {
- return nullptr;
- }
+ auto value = call->ConstantValue();
+ if (!value.IsValid()) {
+ return nullptr;
+ }
- auto* ty = call->Type();
+ auto* ty = call->Type();
- if (!call->Target()->IsAnyOf<sem::TypeConversion, sem::TypeConstructor>()) {
- return nullptr;
- }
+ if (!call->Target()->IsAnyOf<sem::TypeConversion, sem::TypeConstructor>()) {
+ return nullptr;
+ }
- // If original ctor expression had no init values, don't replace the
- // expression
- if (call->Arguments().empty()) {
- return nullptr;
- }
+ // If original ctor expression had no init values, don't replace the
+ // expression
+ if (call->Arguments().empty()) {
+ return nullptr;
+ }
- if (auto* vec = ty->As<sem::Vector>()) {
- uint32_t vec_size = static_cast<uint32_t>(vec->Width());
+ if (auto* vec = ty->As<sem::Vector>()) {
+ uint32_t vec_size = static_cast<uint32_t>(vec->Width());
- // We'd like to construct the new vector with the same number of
- // constructor args that the original node had, but after folding
- // constants, cases like the following are problematic:
- //
- // vec3<f32> = vec3<f32>(vec2<f32>, 1.0) // vec_size=3, ctor_size=2
- //
- // In this case, creating a vec3 with 2 args is invalid, so we should
- // create it with 3. So what we do is construct with vec_size args,
- // except if the original vector was single-value initialized, in
- // which case, we only construct with one arg again.
- uint32_t ctor_size = (call->Arguments().size() == 1) ? 1 : vec_size;
+ // We'd like to construct the new vector with the same number of
+ // constructor args that the original node had, but after folding
+ // constants, cases like the following are problematic:
+ //
+ // vec3<f32> = vec3<f32>(vec2<f32>, 1.0) // vec_size=3, ctor_size=2
+ //
+ // In this case, creating a vec3 with 2 args is invalid, so we should
+ // create it with 3. So what we do is construct with vec_size args,
+ // except if the original vector was single-value initialized, in
+ // which case, we only construct with one arg again.
+ uint32_t ctor_size = (call->Arguments().size() == 1) ? 1 : vec_size;
- ast::ExpressionList ctors;
- for (uint32_t i = 0; i < ctor_size; ++i) {
- value.WithScalarAt(
- i, [&](auto&& s) { ctors.emplace_back(ctx.dst->Expr(s)); });
- }
+ ast::ExpressionList ctors;
+ for (uint32_t i = 0; i < ctor_size; ++i) {
+ value.WithScalarAt(i, [&](auto&& s) { ctors.emplace_back(ctx.dst->Expr(s)); });
+ }
- auto* el_ty = CreateASTTypeFor(ctx, vec->type());
- return ctx.dst->vec(el_ty, vec_size, ctors);
- }
+ auto* el_ty = CreateASTTypeFor(ctx, vec->type());
+ return ctx.dst->vec(el_ty, vec_size, ctors);
+ }
- if (ty->is_scalar()) {
- return value.WithScalarAt(0,
- [&](auto&& s) -> const ast::LiteralExpression* {
- return ctx.dst->Expr(s);
- });
- }
+ if (ty->is_scalar()) {
+ return value.WithScalarAt(
+ 0, [&](auto&& s) -> const ast::LiteralExpression* { return ctx.dst->Expr(s); });
+ }
- return nullptr;
- });
+ return nullptr;
+ });
- ctx.Clone();
+ ctx.Clone();
}
} // namespace tint::transform
diff --git a/src/tint/transform/fold_constants.h b/src/tint/transform/fold_constants.h
index fbe7abd..ed3e205 100644
--- a/src/tint/transform/fold_constants.h
+++ b/src/tint/transform/fold_constants.h
@@ -21,23 +21,21 @@
/// FoldConstants transforms the AST by folding constant expressions
class FoldConstants final : public Castable<FoldConstants, Transform> {
- public:
- /// Constructor
- FoldConstants();
+ public:
+ /// Constructor
+ FoldConstants();
- /// Destructor
- ~FoldConstants() override;
+ /// Destructor
+ ~FoldConstants() override;
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx,
- const DataMap& inputs,
- DataMap& outputs) const override;
+ protected:
+ /// Runs the transform using the CloneContext built for transforming a
+ /// program. Run() is responsible for calling Clone() on the CloneContext.
+ /// @param ctx the CloneContext primed with the input program and
+ /// ProgramBuilder
+ /// @param inputs optional extra transform-specific input data
+ /// @param outputs optional extra transform-specific output data
+ void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/fold_constants_test.cc b/src/tint/transform/fold_constants_test.cc
index d8121bc..5c81fb8 100644
--- a/src/tint/transform/fold_constants_test.cc
+++ b/src/tint/transform/fold_constants_test.cc
@@ -26,7 +26,7 @@
using FoldConstantsTest = TransformTest;
TEST_F(FoldConstantsTest, Module_Scalar_NoConversion) {
- auto* src = R"(
+ auto* src = R"(
var<private> a : i32 = i32(123);
var<private> b : u32 = u32(123u);
var<private> c : f32 = f32(123.0);
@@ -36,7 +36,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> a : i32 = 123;
var<private> b : u32 = 123u;
@@ -49,13 +49,13 @@
}
)";
- auto got = Run<FoldConstants>(src);
+ auto got = Run<FoldConstants>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(FoldConstantsTest, Module_Scalar_Conversion) {
- auto* src = R"(
+ auto* src = R"(
var<private> a : i32 = i32(123.0);
var<private> b : u32 = u32(123);
var<private> c : f32 = f32(123u);
@@ -65,7 +65,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> a : i32 = 123;
var<private> b : u32 = 123u;
@@ -78,13 +78,13 @@
}
)";
- auto got = Run<FoldConstants>(src);
+ auto got = Run<FoldConstants>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(FoldConstantsTest, Module_Scalar_MultipleConversions) {
- auto* src = R"(
+ auto* src = R"(
var<private> a : i32 = i32(u32(f32(u32(i32(123.0)))));
var<private> b : u32 = u32(i32(f32(i32(u32(123)))));
var<private> c : f32 = f32(u32(i32(u32(f32(123u)))));
@@ -94,7 +94,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> a : i32 = 123;
var<private> b : u32 = 123u;
@@ -107,13 +107,13 @@
}
)";
- auto got = Run<FoldConstants>(src);
+ auto got = Run<FoldConstants>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(FoldConstantsTest, Module_Vector_NoConversion) {
- auto* src = R"(
+ auto* src = R"(
var<private> a : vec3<i32> = vec3<i32>(123);
var<private> b : vec3<u32> = vec3<u32>(123u);
var<private> c : vec3<f32> = vec3<f32>(123.0);
@@ -123,7 +123,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> a : vec3<i32> = vec3<i32>(123);
var<private> b : vec3<u32> = vec3<u32>(123u);
@@ -136,13 +136,13 @@
}
)";
- auto got = Run<FoldConstants>(src);
+ auto got = Run<FoldConstants>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(FoldConstantsTest, Module_Vector_Conversion) {
- auto* src = R"(
+ auto* src = R"(
var<private> a : vec3<i32> = vec3<i32>(vec3<f32>(123.0));
var<private> b : vec3<u32> = vec3<u32>(vec3<i32>(123));
var<private> c : vec3<f32> = vec3<f32>(vec3<u32>(123u));
@@ -152,7 +152,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> a : vec3<i32> = vec3<i32>(123);
var<private> b : vec3<u32> = vec3<u32>(123u);
@@ -165,13 +165,13 @@
}
)";
- auto got = Run<FoldConstants>(src);
+ auto got = Run<FoldConstants>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(FoldConstantsTest, Module_Vector_MultipleConversions) {
- auto* src = R"(
+ auto* src = R"(
var<private> a : vec3<i32> = vec3<i32>(vec3<u32>(vec3<f32>(vec3<u32>(u32(123.0)))));
var<private> b : vec3<u32> = vec3<u32>(vec3<i32>(vec3<f32>(vec3<i32>(i32(123)))));
var<private> c : vec3<f32> = vec3<f32>(vec3<u32>(vec3<i32>(vec3<u32>(u32(123u)))));
@@ -181,7 +181,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> a : vec3<i32> = vec3<i32>(123);
var<private> b : vec3<u32> = vec3<u32>(123u);
@@ -194,13 +194,13 @@
}
)";
- auto got = Run<FoldConstants>(src);
+ auto got = Run<FoldConstants>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(FoldConstantsTest, Module_Vector_MixedSizeConversions) {
- auto* src = R"(
+ auto* src = R"(
var<private> a : vec4<i32> = vec4<i32>(vec3<i32>(vec3<u32>(1u, 2u, 3u)), 4);
var<private> b : vec4<i32> = vec4<i32>(vec2<i32>(vec2<u32>(1u, 2u)), vec2<i32>(4, 5));
var<private> c : vec4<i32> = vec4<i32>(1, vec2<i32>(vec2<f32>(2.0, 3.0)), 4);
@@ -211,7 +211,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> a : vec4<i32> = vec4<i32>(1, 2, 3, 4);
var<private> b : vec4<i32> = vec4<i32>(1, 2, 4, 5);
@@ -226,13 +226,13 @@
}
)";
- auto got = Run<FoldConstants>(src);
+ auto got = Run<FoldConstants>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(FoldConstantsTest, Function_Scalar_NoConversion) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var a : i32 = i32(123);
var b : u32 = u32(123u);
@@ -241,7 +241,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var a : i32 = 123;
var b : u32 = 123u;
@@ -250,13 +250,13 @@
}
)";
- auto got = Run<FoldConstants>(src);
+ auto got = Run<FoldConstants>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(FoldConstantsTest, Function_Scalar_Conversion) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var a : i32 = i32(123.0);
var b : u32 = u32(123);
@@ -265,7 +265,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var a : i32 = 123;
var b : u32 = 123u;
@@ -274,13 +274,13 @@
}
)";
- auto got = Run<FoldConstants>(src);
+ auto got = Run<FoldConstants>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(FoldConstantsTest, Function_Scalar_MultipleConversions) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var a : i32 = i32(u32(f32(u32(i32(123.0)))));
var b : u32 = u32(i32(f32(i32(u32(123)))));
@@ -289,7 +289,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var a : i32 = 123;
var b : u32 = 123u;
@@ -298,13 +298,13 @@
}
)";
- auto got = Run<FoldConstants>(src);
+ auto got = Run<FoldConstants>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(FoldConstantsTest, Function_Vector_NoConversion) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var a : vec3<i32> = vec3<i32>(123);
var b : vec3<u32> = vec3<u32>(123u);
@@ -313,7 +313,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var a : vec3<i32> = vec3<i32>(123);
var b : vec3<u32> = vec3<u32>(123u);
@@ -322,13 +322,13 @@
}
)";
- auto got = Run<FoldConstants>(src);
+ auto got = Run<FoldConstants>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(FoldConstantsTest, Function_Vector_Conversion) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var a : vec3<i32> = vec3<i32>(vec3<f32>(123.0));
var b : vec3<u32> = vec3<u32>(vec3<i32>(123));
@@ -337,7 +337,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var a : vec3<i32> = vec3<i32>(123);
var b : vec3<u32> = vec3<u32>(123u);
@@ -346,13 +346,13 @@
}
)";
- auto got = Run<FoldConstants>(src);
+ auto got = Run<FoldConstants>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(FoldConstantsTest, Function_Vector_MultipleConversions) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var a : vec3<i32> = vec3<i32>(vec3<u32>(vec3<f32>(vec3<u32>(u32(123.0)))));
var b : vec3<u32> = vec3<u32>(vec3<i32>(vec3<f32>(vec3<i32>(i32(123)))));
@@ -361,7 +361,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var a : vec3<i32> = vec3<i32>(123);
var b : vec3<u32> = vec3<u32>(123u);
@@ -370,13 +370,13 @@
}
)";
- auto got = Run<FoldConstants>(src);
+ auto got = Run<FoldConstants>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(FoldConstantsTest, Function_Vector_MixedSizeConversions) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var a : vec4<i32> = vec4<i32>(vec3<i32>(vec3<u32>(1u, 2u, 3u)), 4);
var b : vec4<i32> = vec4<i32>(vec2<i32>(vec2<u32>(1u, 2u)), vec2<i32>(4, 5));
@@ -386,7 +386,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var a : vec4<i32> = vec4<i32>(1, 2, 3, 4);
var b : vec4<i32> = vec4<i32>(1, 2, 4, 5);
@@ -396,29 +396,29 @@
}
)";
- auto got = Run<FoldConstants>(src);
+ auto got = Run<FoldConstants>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(FoldConstantsTest, Function_Vector_ConstantWithNonConstant) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var a : f32 = f32();
var b : vec2<f32> = vec2<f32>(f32(i32(1)), a);
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var a : f32 = f32();
var b : vec2<f32> = vec2<f32>(1.0, a);
}
)";
- auto got = Run<FoldConstants>(src);
+ auto got = Run<FoldConstants>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
} // namespace
diff --git a/src/tint/transform/fold_trivial_single_use_lets.cc b/src/tint/transform/fold_trivial_single_use_lets.cc
index a0f02a8..5bcdaa4 100644
--- a/src/tint/transform/fold_trivial_single_use_lets.cc
+++ b/src/tint/transform/fold_trivial_single_use_lets.cc
@@ -27,19 +27,19 @@
namespace {
const ast::VariableDeclStatement* AsTrivialLetDecl(const ast::Statement* stmt) {
- auto* var_decl = stmt->As<ast::VariableDeclStatement>();
- if (!var_decl) {
- return nullptr;
- }
- auto* var = var_decl->variable;
- if (!var->is_const) {
- return nullptr;
- }
- auto* ctor = var->constructor;
- if (!IsAnyOf<ast::IdentifierExpression, ast::LiteralExpression>(ctor)) {
- return nullptr;
- }
- return var_decl;
+ auto* var_decl = stmt->As<ast::VariableDeclStatement>();
+ if (!var_decl) {
+ return nullptr;
+ }
+ auto* var = var_decl->variable;
+ if (!var->is_const) {
+ return nullptr;
+ }
+ auto* ctor = var->constructor;
+ if (!IsAnyOf<ast::IdentifierExpression, ast::LiteralExpression>(ctor)) {
+ return nullptr;
+ }
+ return var_decl;
}
} // namespace
@@ -48,43 +48,41 @@
FoldTrivialSingleUseLets::~FoldTrivialSingleUseLets() = default;
-void FoldTrivialSingleUseLets::Run(CloneContext& ctx,
- const DataMap&,
- DataMap&) const {
- for (auto* node : ctx.src->ASTNodes().Objects()) {
- if (auto* block = node->As<ast::BlockStatement>()) {
- auto& stmts = block->statements;
- for (size_t stmt_idx = 0; stmt_idx < stmts.size(); stmt_idx++) {
- auto* stmt = stmts[stmt_idx];
- if (auto* let_decl = AsTrivialLetDecl(stmt)) {
- auto* let = let_decl->variable;
- auto* sem_let = ctx.src->Sem().Get(let);
- auto& users = sem_let->Users();
- if (users.size() != 1) {
- continue; // Does not have a single user.
- }
+void FoldTrivialSingleUseLets::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
+ for (auto* node : ctx.src->ASTNodes().Objects()) {
+ if (auto* block = node->As<ast::BlockStatement>()) {
+ auto& stmts = block->statements;
+ for (size_t stmt_idx = 0; stmt_idx < stmts.size(); stmt_idx++) {
+ auto* stmt = stmts[stmt_idx];
+ if (auto* let_decl = AsTrivialLetDecl(stmt)) {
+ auto* let = let_decl->variable;
+ auto* sem_let = ctx.src->Sem().Get(let);
+ auto& users = sem_let->Users();
+ if (users.size() != 1) {
+ continue; // Does not have a single user.
+ }
- auto* user = users[0];
- auto* user_stmt = user->Stmt()->Declaration();
+ auto* user = users[0];
+ auto* user_stmt = user->Stmt()->Declaration();
- for (size_t i = stmt_idx; i < stmts.size(); i++) {
- if (user_stmt == stmts[i]) {
- auto* user_expr = user->Declaration();
- ctx.Remove(stmts, let_decl);
- ctx.Replace(user_expr, ctx.Clone(let->constructor));
+ for (size_t i = stmt_idx; i < stmts.size(); i++) {
+ if (user_stmt == stmts[i]) {
+ auto* user_expr = user->Declaration();
+ ctx.Remove(stmts, let_decl);
+ ctx.Replace(user_expr, ctx.Clone(let->constructor));
+ }
+ if (!AsTrivialLetDecl(stmts[i])) {
+ // Stop if we hit a statement that isn't the single use of the
+ // let, and isn't a let itself.
+ break;
+ }
+ }
+ }
}
- if (!AsTrivialLetDecl(stmts[i])) {
- // Stop if we hit a statement that isn't the single use of the
- // let, and isn't a let itself.
- break;
- }
- }
}
- }
}
- }
- ctx.Clone();
+ ctx.Clone();
}
} // namespace tint::transform
diff --git a/src/tint/transform/fold_trivial_single_use_lets.h b/src/tint/transform/fold_trivial_single_use_lets.h
index d343b76..4036f02 100644
--- a/src/tint/transform/fold_trivial_single_use_lets.h
+++ b/src/tint/transform/fold_trivial_single_use_lets.h
@@ -33,25 +33,22 @@
/// single usage.
/// These rules prevent any hoisting of the let that may affect execution
/// behaviour.
-class FoldTrivialSingleUseLets final
- : public Castable<FoldTrivialSingleUseLets, Transform> {
- public:
- /// Constructor
- FoldTrivialSingleUseLets();
+class FoldTrivialSingleUseLets final : public Castable<FoldTrivialSingleUseLets, Transform> {
+ public:
+ /// Constructor
+ FoldTrivialSingleUseLets();
- /// Destructor
- ~FoldTrivialSingleUseLets() override;
+ /// Destructor
+ ~FoldTrivialSingleUseLets() override;
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx,
- const DataMap& inputs,
- DataMap& outputs) const override;
+ protected:
+ /// Runs the transform using the CloneContext built for transforming a
+ /// program. Run() is responsible for calling Clone() on the CloneContext.
+ /// @param ctx the CloneContext primed with the input program and
+ /// ProgramBuilder
+ /// @param inputs optional extra transform-specific input data
+ /// @param outputs optional extra transform-specific output data
+ void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/fold_trivial_single_use_lets_test.cc b/src/tint/transform/fold_trivial_single_use_lets_test.cc
index e08c191..00159e9 100644
--- a/src/tint/transform/fold_trivial_single_use_lets_test.cc
+++ b/src/tint/transform/fold_trivial_single_use_lets_test.cc
@@ -22,35 +22,35 @@
using FoldTrivialSingleUseLetsTest = TransformTest;
TEST_F(FoldTrivialSingleUseLetsTest, EmptyModule) {
- auto* src = "";
- auto* expect = "";
+ auto* src = "";
+ auto* expect = "";
- auto got = Run<FoldTrivialSingleUseLets>(src);
+ auto got = Run<FoldTrivialSingleUseLets>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(FoldTrivialSingleUseLetsTest, Single) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
let x = 1;
_ = x;
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
_ = 1;
}
)";
- auto got = Run<FoldTrivialSingleUseLets>(src);
+ auto got = Run<FoldTrivialSingleUseLets>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(FoldTrivialSingleUseLetsTest, Multiple) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
let x = 1;
let y = 2;
@@ -59,19 +59,19 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
_ = ((1 + 2) + 3);
}
)";
- auto got = Run<FoldTrivialSingleUseLets>(src);
+ auto got = Run<FoldTrivialSingleUseLets>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(FoldTrivialSingleUseLetsTest, Chained) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
let x = 1;
let y = x;
@@ -80,19 +80,19 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
_ = 1;
}
)";
- auto got = Run<FoldTrivialSingleUseLets>(src);
+ auto got = Run<FoldTrivialSingleUseLets>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(FoldTrivialSingleUseLetsTest, NoFold_NonTrivialLet) {
- auto* src = R"(
+ auto* src = R"(
fn function_with_posssible_side_effect() -> i32 {
return 1;
}
@@ -104,15 +104,15 @@
}
)";
- auto* expect = src;
+ auto* expect = src;
- auto got = Run<FoldTrivialSingleUseLets>(src);
+ auto got = Run<FoldTrivialSingleUseLets>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(FoldTrivialSingleUseLetsTest, NoFold_NonTrivialLet_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
let x = 1;
let y = function_with_posssible_side_effect();
@@ -124,15 +124,15 @@
}
)";
- auto* expect = src;
+ auto* expect = src;
- auto got = Run<FoldTrivialSingleUseLets>(src);
+ auto got = Run<FoldTrivialSingleUseLets>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(FoldTrivialSingleUseLetsTest, NoFold_UseInSubBlock) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
let x = 1;
{
@@ -141,30 +141,30 @@
}
)";
- auto* expect = src;
+ auto* expect = src;
- auto got = Run<FoldTrivialSingleUseLets>(src);
+ auto got = Run<FoldTrivialSingleUseLets>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(FoldTrivialSingleUseLetsTest, NoFold_MultipleUses) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
let x = 1;
_ = (x + x);
}
)";
- auto* expect = src;
+ auto* expect = src;
- auto got = Run<FoldTrivialSingleUseLets>(src);
+ auto got = Run<FoldTrivialSingleUseLets>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(FoldTrivialSingleUseLetsTest, NoFold_Shadowing) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var y = 1;
let x = y;
@@ -175,11 +175,11 @@
}
)";
- auto* expect = src;
+ auto* expect = src;
- auto got = Run<FoldTrivialSingleUseLets>(src);
+ auto got = Run<FoldTrivialSingleUseLets>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
} // namespace
diff --git a/src/tint/transform/for_loop_to_loop.cc b/src/tint/transform/for_loop_to_loop.cc
index 14d5edb..8fff0a8 100644
--- a/src/tint/transform/for_loop_to_loop.cc
+++ b/src/tint/transform/for_loop_to_loop.cc
@@ -25,50 +25,48 @@
ForLoopToLoop::~ForLoopToLoop() = default;
bool ForLoopToLoop::ShouldRun(const Program* program, const DataMap&) const {
- for (auto* node : program->ASTNodes().Objects()) {
- if (node->Is<ast::ForLoopStatement>()) {
- return true;
+ for (auto* node : program->ASTNodes().Objects()) {
+ if (node->Is<ast::ForLoopStatement>()) {
+ return true;
+ }
}
- }
- return false;
+ return false;
}
void ForLoopToLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
- ctx.ReplaceAll(
- [&](const ast::ForLoopStatement* for_loop) -> const ast::Statement* {
+ ctx.ReplaceAll([&](const ast::ForLoopStatement* for_loop) -> const ast::Statement* {
ast::StatementList stmts;
if (auto* cond = for_loop->condition) {
- // !condition
- auto* not_cond = ctx.dst->create<ast::UnaryOpExpression>(
- ast::UnaryOp::kNot, ctx.Clone(cond));
+ // !condition
+ auto* not_cond =
+ ctx.dst->create<ast::UnaryOpExpression>(ast::UnaryOp::kNot, ctx.Clone(cond));
- // { break; }
- auto* break_body =
- ctx.dst->Block(ctx.dst->create<ast::BreakStatement>());
+ // { break; }
+ auto* break_body = ctx.dst->Block(ctx.dst->create<ast::BreakStatement>());
- // if (!condition) { break; }
- stmts.emplace_back(ctx.dst->If(not_cond, break_body));
+ // if (!condition) { break; }
+ stmts.emplace_back(ctx.dst->If(not_cond, break_body));
}
for (auto* stmt : for_loop->body->statements) {
- stmts.emplace_back(ctx.Clone(stmt));
+ stmts.emplace_back(ctx.Clone(stmt));
}
const ast::BlockStatement* continuing = nullptr;
if (auto* cont = for_loop->continuing) {
- continuing = ctx.dst->Block(ctx.Clone(cont));
+ continuing = ctx.dst->Block(ctx.Clone(cont));
}
auto* body = ctx.dst->Block(stmts);
auto* loop = ctx.dst->create<ast::LoopStatement>(body, continuing);
if (auto* init = for_loop->initializer) {
- return ctx.dst->Block(ctx.Clone(init), loop);
+ return ctx.dst->Block(ctx.Clone(init), loop);
}
return loop;
- });
+ });
- ctx.Clone();
+ ctx.Clone();
}
} // namespace tint::transform
diff --git a/src/tint/transform/for_loop_to_loop.h b/src/tint/transform/for_loop_to_loop.h
index 54286d4..5ab690a 100644
--- a/src/tint/transform/for_loop_to_loop.h
+++ b/src/tint/transform/for_loop_to_loop.h
@@ -22,29 +22,26 @@
/// ForLoopToLoop is a Transform that converts a for-loop statement into a loop
/// statement. This is required by the SPIR-V writer.
class ForLoopToLoop final : public Castable<ForLoopToLoop, Transform> {
- public:
- /// Constructor
- ForLoopToLoop();
+ public:
+ /// Constructor
+ ForLoopToLoop();
- /// Destructor
- ~ForLoopToLoop() override;
+ /// Destructor
+ ~ForLoopToLoop() override;
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program,
- const DataMap& data = {}) const override;
+ /// @param program the program to inspect
+ /// @param data optional extra transform-specific input data
+ /// @returns true if this transform should be run for the given program
+ bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx,
- const DataMap& inputs,
- DataMap& outputs) const override;
+ protected:
+ /// Runs the transform using the CloneContext built for transforming a
+ /// program. Run() is responsible for calling Clone() on the CloneContext.
+ /// @param ctx the CloneContext primed with the input program and
+ /// ProgramBuilder
+ /// @param inputs optional extra transform-specific input data
+ /// @param outputs optional extra transform-specific output data
+ void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/for_loop_to_loop_test.cc b/src/tint/transform/for_loop_to_loop_test.cc
index 84ffa98..172e1fc 100644
--- a/src/tint/transform/for_loop_to_loop_test.cc
+++ b/src/tint/transform/for_loop_to_loop_test.cc
@@ -22,13 +22,13 @@
using ForLoopToLoopTest = TransformTest;
TEST_F(ForLoopToLoopTest, ShouldRunEmptyModule) {
- auto* src = R"()";
+ auto* src = R"()";
- EXPECT_FALSE(ShouldRun<ForLoopToLoop>(src));
+ EXPECT_FALSE(ShouldRun<ForLoopToLoop>(src));
}
TEST_F(ForLoopToLoopTest, ShouldRunHasForLoop) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
for (;;) {
break;
@@ -36,21 +36,21 @@
}
)";
- EXPECT_TRUE(ShouldRun<ForLoopToLoop>(src));
+ EXPECT_TRUE(ShouldRun<ForLoopToLoop>(src));
}
TEST_F(ForLoopToLoopTest, EmptyModule) {
- auto* src = "";
- auto* expect = src;
+ auto* src = "";
+ auto* expect = src;
- auto got = Run<ForLoopToLoop>(src);
+ auto got = Run<ForLoopToLoop>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
// Test an empty for loop.
TEST_F(ForLoopToLoopTest, Empty) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
for (;;) {
break;
@@ -58,7 +58,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
loop {
break;
@@ -66,14 +66,14 @@
}
)";
- auto got = Run<ForLoopToLoop>(src);
+ auto got = Run<ForLoopToLoop>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
// Test a for loop with non-empty body.
TEST_F(ForLoopToLoopTest, Body) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
for (;;) {
discard;
@@ -81,7 +81,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
loop {
discard;
@@ -89,14 +89,14 @@
}
)";
- auto got = Run<ForLoopToLoop>(src);
+ auto got = Run<ForLoopToLoop>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
// Test a for loop declaring a variable in the initializer statement.
TEST_F(ForLoopToLoopTest, InitializerStatementDecl) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
for (var i: i32;;) {
break;
@@ -104,7 +104,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
{
var i : i32;
@@ -115,15 +115,15 @@
}
)";
- auto got = Run<ForLoopToLoop>(src);
+ auto got = Run<ForLoopToLoop>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
// Test a for loop declaring and initializing a variable in the initializer
// statement.
TEST_F(ForLoopToLoopTest, InitializerStatementDeclEqual) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
for (var i: i32 = 0;;) {
break;
@@ -131,7 +131,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
{
var i : i32 = 0;
@@ -142,14 +142,14 @@
}
)";
- auto got = Run<ForLoopToLoop>(src);
+ auto got = Run<ForLoopToLoop>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
// Test a for loop declaring a const variable in the initializer statement.
TEST_F(ForLoopToLoopTest, InitializerStatementConstDecl) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
for (let i: i32 = 0;;) {
break;
@@ -157,7 +157,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
{
let i : i32 = 0;
@@ -168,14 +168,14 @@
}
)";
- auto got = Run<ForLoopToLoop>(src);
+ auto got = Run<ForLoopToLoop>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
// Test a for loop assigning a variable in the initializer statement.
TEST_F(ForLoopToLoopTest, InitializerStatementAssignment) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var i: i32;
for (i = 0;;) {
@@ -184,7 +184,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var i : i32;
{
@@ -196,14 +196,14 @@
}
)";
- auto got = Run<ForLoopToLoop>(src);
+ auto got = Run<ForLoopToLoop>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
// Test a for loop calling a function in the initializer statement.
TEST_F(ForLoopToLoopTest, InitializerStatementFuncCall) {
- auto* src = R"(
+ auto* src = R"(
fn a(x : i32, y : i32) {
}
@@ -216,7 +216,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(x : i32, y : i32) {
}
@@ -232,21 +232,21 @@
}
)";
- auto got = Run<ForLoopToLoop>(src);
+ auto got = Run<ForLoopToLoop>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
// Test a for loop with a break condition
TEST_F(ForLoopToLoopTest, BreakCondition) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
for (; 0 == 1;) {
}
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
loop {
if (!((0 == 1))) {
@@ -256,14 +256,14 @@
}
)";
- auto got = Run<ForLoopToLoop>(src);
+ auto got = Run<ForLoopToLoop>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
// Test a for loop assigning a variable in the continuing statement.
TEST_F(ForLoopToLoopTest, ContinuingAssignment) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var x: i32;
for (;;x = 2) {
@@ -272,7 +272,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var x : i32;
loop {
@@ -285,14 +285,14 @@
}
)";
- auto got = Run<ForLoopToLoop>(src);
+ auto got = Run<ForLoopToLoop>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
// Test a for loop calling a function in the continuing statement.
TEST_F(ForLoopToLoopTest, ContinuingFuncCall) {
- auto* src = R"(
+ auto* src = R"(
fn a(x : i32, y : i32) {
}
@@ -305,7 +305,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(x : i32, y : i32) {
}
@@ -322,14 +322,14 @@
}
)";
- auto got = Run<ForLoopToLoop>(src);
+ auto got = Run<ForLoopToLoop>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
// Test a for loop with all statements non-empty.
TEST_F(ForLoopToLoopTest, All) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var a : i32;
for(var i : i32 = 0; i < 4; i = i + 1) {
@@ -341,7 +341,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var a : i32;
{
@@ -363,9 +363,9 @@
}
)";
- auto got = Run<ForLoopToLoop>(src);
+ auto got = Run<ForLoopToLoop>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
} // namespace
diff --git a/src/tint/transform/localize_struct_array_assignment.cc b/src/tint/transform/localize_struct_array_assignment.cc
index fda3faa..d6cdded 100644
--- a/src/tint/transform/localize_struct_array_assignment.cc
+++ b/src/tint/transform/localize_struct_array_assignment.cc
@@ -34,174 +34,166 @@
/// Private implementation of LocalizeStructArrayAssignment transform
class LocalizeStructArrayAssignment::State {
- private:
- CloneContext& ctx;
- ProgramBuilder& b;
+ private:
+ CloneContext& ctx;
+ ProgramBuilder& b;
- /// Returns true if `expr` contains an index accessor expression to a
- /// structure member of array type.
- bool ContainsStructArrayIndex(const ast::Expression* expr) {
- bool result = false;
- ast::TraverseExpressions(
- expr, b.Diagnostics(), [&](const ast::IndexAccessorExpression* ia) {
- // Indexing using a runtime value?
- auto* idx_sem = ctx.src->Sem().Get(ia->index);
- if (!idx_sem->ConstantValue().IsValid()) {
- // Indexing a member access expr?
- if (auto* ma = ia->object->As<ast::MemberAccessorExpression>()) {
- // That accesses an array?
- if (ctx.src->TypeOf(ma)->UnwrapRef()->Is<sem::Array>()) {
- result = true;
- return ast::TraverseAction::Stop;
- }
+ /// Returns true if `expr` contains an index accessor expression to a
+ /// structure member of array type.
+ bool ContainsStructArrayIndex(const ast::Expression* expr) {
+ bool result = false;
+ ast::TraverseExpressions(
+ expr, b.Diagnostics(), [&](const ast::IndexAccessorExpression* ia) {
+ // Indexing using a runtime value?
+ auto* idx_sem = ctx.src->Sem().Get(ia->index);
+ if (!idx_sem->ConstantValue().IsValid()) {
+ // Indexing a member access expr?
+ if (auto* ma = ia->object->As<ast::MemberAccessorExpression>()) {
+ // That accesses an array?
+ if (ctx.src->TypeOf(ma)->UnwrapRef()->Is<sem::Array>()) {
+ result = true;
+ return ast::TraverseAction::Stop;
+ }
+ }
+ }
+ return ast::TraverseAction::Descend;
+ });
+
+ return result;
+ }
+
+ // Returns the type and storage class of the originating variable of the lhs
+ // of the assignment statement.
+ // See https://www.w3.org/TR/WGSL/#originating-variable-section
+ std::pair<const sem::Type*, ast::StorageClass> GetOriginatingTypeAndStorageClass(
+ const ast::AssignmentStatement* assign_stmt) {
+ auto* source_var = ctx.src->Sem().Get(assign_stmt->lhs)->SourceVariable();
+ if (!source_var) {
+ TINT_ICE(Transform, b.Diagnostics())
+ << "Unable to determine originating variable for lhs of assignment "
+ "statement";
+ return {};
+ }
+
+ auto* type = source_var->Type();
+ if (auto* ref = type->As<sem::Reference>()) {
+ return {ref->StoreType(), ref->StorageClass()};
+ } else if (auto* ptr = type->As<sem::Pointer>()) {
+ return {ptr->StoreType(), ptr->StorageClass()};
+ }
+
+ TINT_ICE(Transform, b.Diagnostics())
+ << "Expecting to find variable of type pointer or reference on lhs "
+ "of assignment statement";
+ return {};
+ }
+
+ public:
+ /// Constructor
+ /// @param ctx_in the CloneContext primed with the input program and
+ /// ProgramBuilder
+ explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst) {}
+
+ /// Runs the transform
+ void Run() {
+ struct Shared {
+ bool process_nested_nodes = false;
+ ast::StatementList insert_before_stmts;
+ ast::StatementList insert_after_stmts;
+ } s;
+
+ ctx.ReplaceAll([&](const ast::AssignmentStatement* assign_stmt) -> const ast::Statement* {
+ // Process if it's an assignment statement to a dynamically indexed array
+ // within a struct on a function or private storage variable. This
+ // specific use-case is what FXC fails to compile with:
+ // error X3500: array reference cannot be used as an l-value; not natively
+ // addressable
+ if (!ContainsStructArrayIndex(assign_stmt->lhs)) {
+ return nullptr;
}
- }
- return ast::TraverseAction::Descend;
+ auto og = GetOriginatingTypeAndStorageClass(assign_stmt);
+ if (!(og.first->Is<sem::Struct>() && (og.second == ast::StorageClass::kFunction ||
+ og.second == ast::StorageClass::kPrivate))) {
+ return nullptr;
+ }
+
+ // Reset shared state for this assignment statement
+ s = Shared{};
+
+ const ast::Expression* new_lhs = nullptr;
+ {
+ TINT_SCOPED_ASSIGNMENT(s.process_nested_nodes, true);
+ new_lhs = ctx.Clone(assign_stmt->lhs);
+ }
+
+ auto* new_assign_stmt = b.Assign(new_lhs, ctx.Clone(assign_stmt->rhs));
+
+ // Combine insert_before_stmts + new_assign_stmt + insert_after_stmts into
+ // a block and return it
+ ast::StatementList stmts = std::move(s.insert_before_stmts);
+ stmts.reserve(1 + s.insert_after_stmts.size());
+ stmts.emplace_back(new_assign_stmt);
+ stmts.insert(stmts.end(), s.insert_after_stmts.begin(), s.insert_after_stmts.end());
+
+ return b.Block(std::move(stmts));
});
- return result;
- }
+ ctx.ReplaceAll(
+ [&](const ast::IndexAccessorExpression* index_access) -> const ast::Expression* {
+ if (!s.process_nested_nodes) {
+ return nullptr;
+ }
- // Returns the type and storage class of the originating variable of the lhs
- // of the assignment statement.
- // See https://www.w3.org/TR/WGSL/#originating-variable-section
- std::pair<const sem::Type*, ast::StorageClass>
- GetOriginatingTypeAndStorageClass(
- const ast::AssignmentStatement* assign_stmt) {
- auto* source_var = ctx.src->Sem().Get(assign_stmt->lhs)->SourceVariable();
- if (!source_var) {
- TINT_ICE(Transform, b.Diagnostics())
- << "Unable to determine originating variable for lhs of assignment "
- "statement";
- return {};
+ // Indexing a member access expr?
+ auto* mem_access = index_access->object->As<ast::MemberAccessorExpression>();
+ if (!mem_access) {
+ return nullptr;
+ }
+
+ // Process any nested IndexAccessorExpressions
+ mem_access = ctx.Clone(mem_access);
+
+ // Store the address of the member access into a let as we need to read
+ // the value twice e.g. let tint_symbol = &(s.a1);
+ auto mem_access_ptr = b.Sym();
+ s.insert_before_stmts.push_back(
+ b.Decl(b.Let(mem_access_ptr, nullptr, b.AddressOf(mem_access))));
+
+ // Disable further transforms when cloning
+ TINT_SCOPED_ASSIGNMENT(s.process_nested_nodes, false);
+
+ // Copy entire array out of struct into local temp var
+ // e.g. var tint_symbol_1 = *(tint_symbol);
+ auto tmp_var = b.Sym();
+ s.insert_before_stmts.push_back(
+ b.Decl(b.Var(tmp_var, nullptr, b.Deref(mem_access_ptr))));
+
+ // Replace input index_access with a clone of itself, but with its
+ // .object replaced by the new temp var. This is returned from this
+ // function to modify the original assignment statement. e.g.
+ // tint_symbol_1[uniforms.i]
+ auto* new_index_access = b.IndexAccessor(tmp_var, ctx.Clone(index_access->index));
+
+ // Assign temp var back to array
+ // e.g. *(tint_symbol) = tint_symbol_1;
+ auto* assign_rhs_to_temp = b.Assign(b.Deref(mem_access_ptr), tmp_var);
+ s.insert_after_stmts.insert(s.insert_after_stmts.begin(),
+ assign_rhs_to_temp); // push_front
+
+ return new_index_access;
+ });
+
+ ctx.Clone();
}
-
- auto* type = source_var->Type();
- if (auto* ref = type->As<sem::Reference>()) {
- return {ref->StoreType(), ref->StorageClass()};
- } else if (auto* ptr = type->As<sem::Pointer>()) {
- return {ptr->StoreType(), ptr->StorageClass()};
- }
-
- TINT_ICE(Transform, b.Diagnostics())
- << "Expecting to find variable of type pointer or reference on lhs "
- "of assignment statement";
- return {};
- }
-
- public:
- /// Constructor
- /// @param ctx_in the CloneContext primed with the input program and
- /// ProgramBuilder
- explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst) {}
-
- /// Runs the transform
- void Run() {
- struct Shared {
- bool process_nested_nodes = false;
- ast::StatementList insert_before_stmts;
- ast::StatementList insert_after_stmts;
- } s;
-
- ctx.ReplaceAll([&](const ast::AssignmentStatement* assign_stmt)
- -> const ast::Statement* {
- // Process if it's an assignment statement to a dynamically indexed array
- // within a struct on a function or private storage variable. This
- // specific use-case is what FXC fails to compile with:
- // error X3500: array reference cannot be used as an l-value; not natively
- // addressable
- if (!ContainsStructArrayIndex(assign_stmt->lhs)) {
- return nullptr;
- }
- auto og = GetOriginatingTypeAndStorageClass(assign_stmt);
- if (!(og.first->Is<sem::Struct>() &&
- (og.second == ast::StorageClass::kFunction ||
- og.second == ast::StorageClass::kPrivate))) {
- return nullptr;
- }
-
- // Reset shared state for this assignment statement
- s = Shared{};
-
- const ast::Expression* new_lhs = nullptr;
- {
- TINT_SCOPED_ASSIGNMENT(s.process_nested_nodes, true);
- new_lhs = ctx.Clone(assign_stmt->lhs);
- }
-
- auto* new_assign_stmt = b.Assign(new_lhs, ctx.Clone(assign_stmt->rhs));
-
- // Combine insert_before_stmts + new_assign_stmt + insert_after_stmts into
- // a block and return it
- ast::StatementList stmts = std::move(s.insert_before_stmts);
- stmts.reserve(1 + s.insert_after_stmts.size());
- stmts.emplace_back(new_assign_stmt);
- stmts.insert(stmts.end(), s.insert_after_stmts.begin(),
- s.insert_after_stmts.end());
-
- return b.Block(std::move(stmts));
- });
-
- ctx.ReplaceAll([&](const ast::IndexAccessorExpression* index_access)
- -> const ast::Expression* {
- if (!s.process_nested_nodes) {
- return nullptr;
- }
-
- // Indexing a member access expr?
- auto* mem_access =
- index_access->object->As<ast::MemberAccessorExpression>();
- if (!mem_access) {
- return nullptr;
- }
-
- // Process any nested IndexAccessorExpressions
- mem_access = ctx.Clone(mem_access);
-
- // Store the address of the member access into a let as we need to read
- // the value twice e.g. let tint_symbol = &(s.a1);
- auto mem_access_ptr = b.Sym();
- s.insert_before_stmts.push_back(
- b.Decl(b.Let(mem_access_ptr, nullptr, b.AddressOf(mem_access))));
-
- // Disable further transforms when cloning
- TINT_SCOPED_ASSIGNMENT(s.process_nested_nodes, false);
-
- // Copy entire array out of struct into local temp var
- // e.g. var tint_symbol_1 = *(tint_symbol);
- auto tmp_var = b.Sym();
- s.insert_before_stmts.push_back(
- b.Decl(b.Var(tmp_var, nullptr, b.Deref(mem_access_ptr))));
-
- // Replace input index_access with a clone of itself, but with its
- // .object replaced by the new temp var. This is returned from this
- // function to modify the original assignment statement. e.g.
- // tint_symbol_1[uniforms.i]
- auto* new_index_access =
- b.IndexAccessor(tmp_var, ctx.Clone(index_access->index));
-
- // Assign temp var back to array
- // e.g. *(tint_symbol) = tint_symbol_1;
- auto* assign_rhs_to_temp = b.Assign(b.Deref(mem_access_ptr), tmp_var);
- s.insert_after_stmts.insert(s.insert_after_stmts.begin(),
- assign_rhs_to_temp); // push_front
-
- return new_index_access;
- });
-
- ctx.Clone();
- }
};
LocalizeStructArrayAssignment::LocalizeStructArrayAssignment() = default;
LocalizeStructArrayAssignment::~LocalizeStructArrayAssignment() = default;
-void LocalizeStructArrayAssignment::Run(CloneContext& ctx,
- const DataMap&,
- DataMap&) const {
- State state(ctx);
- state.Run();
+void LocalizeStructArrayAssignment::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
+ State state(ctx);
+ state.Run();
}
} // namespace tint::transform
diff --git a/src/tint/transform/localize_struct_array_assignment.h b/src/tint/transform/localize_struct_array_assignment.h
index 2c45203..129c849 100644
--- a/src/tint/transform/localize_struct_array_assignment.h
+++ b/src/tint/transform/localize_struct_array_assignment.h
@@ -27,28 +27,25 @@
///
/// @note Depends on the following transforms to have been run first:
/// * SimplifyPointers
-class LocalizeStructArrayAssignment
- : public Castable<LocalizeStructArrayAssignment, Transform> {
- public:
- /// Constructor
- LocalizeStructArrayAssignment();
+class LocalizeStructArrayAssignment : public Castable<LocalizeStructArrayAssignment, Transform> {
+ public:
+ /// Constructor
+ LocalizeStructArrayAssignment();
- /// Destructor
- ~LocalizeStructArrayAssignment() override;
+ /// Destructor
+ ~LocalizeStructArrayAssignment() override;
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx,
- const DataMap& inputs,
- DataMap& outputs) const override;
+ protected:
+ /// Runs the transform using the CloneContext built for transforming a
+ /// program. Run() is responsible for calling Clone() on the CloneContext.
+ /// @param ctx the CloneContext primed with the input program and
+ /// ProgramBuilder
+ /// @param inputs optional extra transform-specific input data
+ /// @param outputs optional extra transform-specific output data
+ void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
- private:
- class State;
+ private:
+ class State;
};
} // namespace tint::transform
diff --git a/src/tint/transform/localize_struct_array_assignment_test.cc b/src/tint/transform/localize_struct_array_assignment_test.cc
index d202785..ee6df9f 100644
--- a/src/tint/transform/localize_struct_array_assignment_test.cc
+++ b/src/tint/transform/localize_struct_array_assignment_test.cc
@@ -24,15 +24,14 @@
using LocalizeStructArrayAssignmentTest = TransformTest;
TEST_F(LocalizeStructArrayAssignmentTest, EmptyModule) {
- auto* src = R"()";
- auto* expect = src;
- auto got =
- Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
- EXPECT_EQ(expect, str(got));
+ auto* src = R"()";
+ auto* expect = src;
+ auto got = Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
+ EXPECT_EQ(expect, str(got));
}
TEST_F(LocalizeStructArrayAssignmentTest, StructArray) {
- auto* src = R"(
+ auto* src = R"(
struct Uniforms {
i : u32,
};
@@ -55,7 +54,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct Uniforms {
i : u32,
}
@@ -83,13 +82,12 @@
}
)";
- auto got =
- Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
- EXPECT_EQ(expect, str(got));
+ auto got = Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
+ EXPECT_EQ(expect, str(got));
}
TEST_F(LocalizeStructArrayAssignmentTest, StructArray_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn main() {
var v : InnerS;
@@ -112,7 +110,7 @@
};
)";
- auto* expect = R"(
+ auto* expect = R"(
@stage(compute) @workgroup_size(1)
fn main() {
var v : InnerS;
@@ -140,13 +138,12 @@
}
)";
- auto got =
- Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
- EXPECT_EQ(expect, str(got));
+ auto got = Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
+ EXPECT_EQ(expect, str(got));
}
TEST_F(LocalizeStructArrayAssignmentTest, StructStructArray) {
- auto* src = R"(
+ auto* src = R"(
struct Uniforms {
i : u32,
};
@@ -173,7 +170,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct Uniforms {
i : u32,
}
@@ -205,13 +202,12 @@
}
)";
- auto got =
- Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
- EXPECT_EQ(expect, str(got));
+ auto got = Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
+ EXPECT_EQ(expect, str(got));
}
TEST_F(LocalizeStructArrayAssignmentTest, StructStructArray_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn main() {
var v : InnerS;
@@ -238,7 +234,7 @@
};
)";
- auto* expect = R"(
+ auto* expect = R"(
@stage(compute) @workgroup_size(1)
fn main() {
var v : InnerS;
@@ -270,13 +266,12 @@
}
)";
- auto got =
- Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
- EXPECT_EQ(expect, str(got));
+ auto got = Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
+ EXPECT_EQ(expect, str(got));
}
TEST_F(LocalizeStructArrayAssignmentTest, StructArrayArray) {
- auto* src = R"(
+ auto* src = R"(
struct Uniforms {
i : u32,
j : u32,
@@ -300,7 +295,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct Uniforms {
i : u32,
j : u32,
@@ -329,13 +324,12 @@
}
)";
- auto got =
- Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
- EXPECT_EQ(expect, str(got));
+ auto got = Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
+ EXPECT_EQ(expect, str(got));
}
TEST_F(LocalizeStructArrayAssignmentTest, StructArrayStruct) {
- auto* src = R"(
+ auto* src = R"(
struct Uniforms {
i : u32,
};
@@ -362,7 +356,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct Uniforms {
i : u32,
}
@@ -394,13 +388,12 @@
}
)";
- auto got =
- Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
- EXPECT_EQ(expect, str(got));
+ auto got = Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
+ EXPECT_EQ(expect, str(got));
}
TEST_F(LocalizeStructArrayAssignmentTest, StructArrayStructArray) {
- auto* src = R"(
+ auto* src = R"(
struct Uniforms {
i : u32,
j : u32,
@@ -428,7 +421,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct Uniforms {
i : u32,
j : u32,
@@ -464,13 +457,12 @@
}
)";
- auto got =
- Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
- EXPECT_EQ(expect, str(got));
+ auto got = Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
+ EXPECT_EQ(expect, str(got));
}
TEST_F(LocalizeStructArrayAssignmentTest, IndexingWithSideEffectFunc) {
- auto* src = R"(
+ auto* src = R"(
struct Uniforms {
i : u32,
j : u32,
@@ -504,7 +496,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct Uniforms {
i : u32,
j : u32,
@@ -547,14 +539,12 @@
}
)";
- auto got =
- Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
- EXPECT_EQ(expect, str(got));
+ auto got = Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
+ EXPECT_EQ(expect, str(got));
}
-TEST_F(LocalizeStructArrayAssignmentTest,
- IndexingWithSideEffectFunc_OutOfOrder) {
- auto* src = R"(
+TEST_F(LocalizeStructArrayAssignmentTest, IndexingWithSideEffectFunc_OutOfOrder) {
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn main() {
var v : InnerS;
@@ -588,7 +578,7 @@
};
)";
- auto* expect = R"(
+ auto* expect = R"(
@stage(compute) @workgroup_size(1)
fn main() {
var v : InnerS;
@@ -631,13 +621,12 @@
}
)";
- auto got =
- Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
- EXPECT_EQ(expect, str(got));
+ auto got = Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
+ EXPECT_EQ(expect, str(got));
}
TEST_F(LocalizeStructArrayAssignmentTest, ViaPointerArg) {
- auto* src = R"(
+ auto* src = R"(
struct Uniforms {
i : u32,
};
@@ -661,7 +650,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct Uniforms {
i : u32,
}
@@ -693,13 +682,12 @@
}
)";
- auto got =
- Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
- EXPECT_EQ(expect, str(got));
+ auto got = Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
+ EXPECT_EQ(expect, str(got));
}
TEST_F(LocalizeStructArrayAssignmentTest, ViaPointerArg_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn main() {
var s1 : OuterS;
@@ -725,7 +713,7 @@
};
)";
- auto* expect = R"(
+ auto* expect = R"(
@stage(compute) @workgroup_size(1)
fn main() {
var s1 : OuterS;
@@ -757,13 +745,12 @@
}
)";
- auto got =
- Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
- EXPECT_EQ(expect, str(got));
+ auto got = Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
+ EXPECT_EQ(expect, str(got));
}
TEST_F(LocalizeStructArrayAssignmentTest, ViaPointerVar) {
- auto* src = R"(
+ auto* src = R"(
struct Uniforms {
i : u32,
};
@@ -791,7 +778,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct Uniforms {
i : u32,
}
@@ -824,13 +811,12 @@
}
)";
- auto got =
- Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
- EXPECT_EQ(expect, str(got));
+ auto got = Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
+ EXPECT_EQ(expect, str(got));
}
TEST_F(LocalizeStructArrayAssignmentTest, VectorAssignment) {
- auto* src = R"(
+ auto* src = R"(
struct Uniforms {
i : u32,
}
@@ -854,13 +840,12 @@
}
)";
- // Transform does nothing here as we're not actually assigning to the array in
- // the struct.
- auto* expect = src;
+ // Transform does nothing here as we're not actually assigning to the array in
+ // the struct.
+ auto* expect = src;
- auto got =
- Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
- EXPECT_EQ(expect, str(got));
+ auto got = Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
+ EXPECT_EQ(expect, str(got));
}
} // namespace
diff --git a/src/tint/transform/loop_to_for_loop.cc b/src/tint/transform/loop_to_for_loop.cc
index f5bc05d..3e0a4b5 100644
--- a/src/tint/transform/loop_to_for_loop.cc
+++ b/src/tint/transform/loop_to_for_loop.cc
@@ -29,24 +29,22 @@
namespace {
bool IsBlockWithSingleBreak(const ast::BlockStatement* block) {
- if (block->statements.size() != 1) {
- return false;
- }
- return block->statements[0]->Is<ast::BreakStatement>();
+ if (block->statements.size() != 1) {
+ return false;
+ }
+ return block->statements[0]->Is<ast::BreakStatement>();
}
-bool IsVarUsedByStmt(const sem::Info& sem,
- const ast::Variable* var,
- const ast::Statement* stmt) {
- auto* var_sem = sem.Get(var);
- for (auto* user : var_sem->Users()) {
- if (auto* s = user->Stmt()) {
- if (s->Declaration() == stmt) {
- return true;
- }
+bool IsVarUsedByStmt(const sem::Info& sem, const ast::Variable* var, const ast::Statement* stmt) {
+ auto* var_sem = sem.Get(var);
+ for (auto* user : var_sem->Users()) {
+ if (auto* s = user->Stmt()) {
+ if (s->Declaration() == stmt) {
+ return true;
+ }
+ }
}
- }
- return false;
+ return false;
}
} // namespace
@@ -56,88 +54,83 @@
LoopToForLoop::~LoopToForLoop() = default;
bool LoopToForLoop::ShouldRun(const Program* program, const DataMap&) const {
- for (auto* node : program->ASTNodes().Objects()) {
- if (node->Is<ast::LoopStatement>()) {
- return true;
+ for (auto* node : program->ASTNodes().Objects()) {
+ if (node->Is<ast::LoopStatement>()) {
+ return true;
+ }
}
- }
- return false;
+ return false;
}
void LoopToForLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
- ctx.ReplaceAll([&](const ast::LoopStatement* loop) -> const ast::Statement* {
- // For loop condition is taken from the first statement in the loop.
- // This requires an if-statement with either:
- // * A true block with no else statements, and the true block contains a
- // single 'break' statement.
- // * An empty true block with a single, no-condition else statement
- // containing a single 'break' statement.
- // Examples:
- // loop { if (condition) { break; } ... }
- // loop { if (condition) {} else { break; } ... }
- auto& stmts = loop->body->statements;
- if (stmts.empty()) {
- return nullptr;
- }
- auto* if_stmt = stmts[0]->As<ast::IfStatement>();
- if (!if_stmt) {
- return nullptr;
- }
- auto* else_stmt = tint::As<ast::BlockStatement>(if_stmt->else_statement);
-
- bool negate_condition = false;
- if (IsBlockWithSingleBreak(if_stmt->body) &&
- if_stmt->else_statement == nullptr) {
- negate_condition = true;
- } else if (if_stmt->body->Empty() && else_stmt &&
- IsBlockWithSingleBreak(else_stmt)) {
- negate_condition = false;
- } else {
- return nullptr;
- }
-
- // The continuing block must be empty or contain a single, assignment or
- // function call statement.
- const ast::Statement* continuing = nullptr;
- if (auto* loop_cont = loop->continuing) {
- if (loop_cont->statements.size() != 1) {
- return nullptr;
- }
-
- continuing = loop_cont->statements[0];
- if (!continuing
- ->IsAnyOf<ast::AssignmentStatement, ast::CallStatement>()) {
- return nullptr;
- }
-
- // And the continuing statement must not use any of the variables declared
- // in the loop body.
- for (auto* stmt : loop->body->statements) {
- if (auto* var_decl = stmt->As<ast::VariableDeclStatement>()) {
- if (IsVarUsedByStmt(ctx.src->Sem(), var_decl->variable, continuing)) {
+ ctx.ReplaceAll([&](const ast::LoopStatement* loop) -> const ast::Statement* {
+ // For loop condition is taken from the first statement in the loop.
+ // This requires an if-statement with either:
+ // * A true block with no else statements, and the true block contains a
+ // single 'break' statement.
+ // * An empty true block with a single, no-condition else statement
+ // containing a single 'break' statement.
+ // Examples:
+ // loop { if (condition) { break; } ... }
+ // loop { if (condition) {} else { break; } ... }
+ auto& stmts = loop->body->statements;
+ if (stmts.empty()) {
return nullptr;
- }
}
- }
+ auto* if_stmt = stmts[0]->As<ast::IfStatement>();
+ if (!if_stmt) {
+ return nullptr;
+ }
+ auto* else_stmt = tint::As<ast::BlockStatement>(if_stmt->else_statement);
- continuing = ctx.Clone(continuing);
- }
+ bool negate_condition = false;
+ if (IsBlockWithSingleBreak(if_stmt->body) && if_stmt->else_statement == nullptr) {
+ negate_condition = true;
+ } else if (if_stmt->body->Empty() && else_stmt && IsBlockWithSingleBreak(else_stmt)) {
+ negate_condition = false;
+ } else {
+ return nullptr;
+ }
- auto* condition = ctx.Clone(if_stmt->condition);
- if (negate_condition) {
- condition = ctx.dst->create<ast::UnaryOpExpression>(ast::UnaryOp::kNot,
- condition);
- }
+ // The continuing block must be empty or contain a single, assignment or
+ // function call statement.
+ const ast::Statement* continuing = nullptr;
+ if (auto* loop_cont = loop->continuing) {
+ if (loop_cont->statements.size() != 1) {
+ return nullptr;
+ }
- ast::Statement* initializer = nullptr;
+ continuing = loop_cont->statements[0];
+ if (!continuing->IsAnyOf<ast::AssignmentStatement, ast::CallStatement>()) {
+ return nullptr;
+ }
- ctx.Remove(loop->body->statements, if_stmt);
- auto* body = ctx.Clone(loop->body);
- return ctx.dst->create<ast::ForLoopStatement>(initializer, condition,
- continuing, body);
- });
+ // And the continuing statement must not use any of the variables declared
+ // in the loop body.
+ for (auto* stmt : loop->body->statements) {
+ if (auto* var_decl = stmt->As<ast::VariableDeclStatement>()) {
+ if (IsVarUsedByStmt(ctx.src->Sem(), var_decl->variable, continuing)) {
+ return nullptr;
+ }
+ }
+ }
- ctx.Clone();
+ continuing = ctx.Clone(continuing);
+ }
+
+ auto* condition = ctx.Clone(if_stmt->condition);
+ if (negate_condition) {
+ condition = ctx.dst->create<ast::UnaryOpExpression>(ast::UnaryOp::kNot, condition);
+ }
+
+ ast::Statement* initializer = nullptr;
+
+ ctx.Remove(loop->body->statements, if_stmt);
+ auto* body = ctx.Clone(loop->body);
+ return ctx.dst->create<ast::ForLoopStatement>(initializer, condition, continuing, body);
+ });
+
+ ctx.Clone();
}
} // namespace tint::transform
diff --git a/src/tint/transform/loop_to_for_loop.h b/src/tint/transform/loop_to_for_loop.h
index b6482ae..0623d79 100644
--- a/src/tint/transform/loop_to_for_loop.h
+++ b/src/tint/transform/loop_to_for_loop.h
@@ -22,29 +22,26 @@
/// LoopToForLoop is a Transform that attempts to convert WGSL `loop {}`
/// statements into a for-loop statement.
class LoopToForLoop : public Castable<LoopToForLoop, Transform> {
- public:
- /// Constructor
- LoopToForLoop();
+ public:
+ /// Constructor
+ LoopToForLoop();
- /// Destructor
- ~LoopToForLoop() override;
+ /// Destructor
+ ~LoopToForLoop() override;
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program,
- const DataMap& data = {}) const override;
+ /// @param program the program to inspect
+ /// @param data optional extra transform-specific input data
+ /// @returns true if this transform should be run for the given program
+ bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx,
- const DataMap& inputs,
- DataMap& outputs) const override;
+ protected:
+ /// Runs the transform using the CloneContext built for transforming a
+ /// program. Run() is responsible for calling Clone() on the CloneContext.
+ /// @param ctx the CloneContext primed with the input program and
+ /// ProgramBuilder
+ /// @param inputs optional extra transform-specific input data
+ /// @param outputs optional extra transform-specific output data
+ void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/loop_to_for_loop_test.cc b/src/tint/transform/loop_to_for_loop_test.cc
index d4a7693..e3d7ecc 100644
--- a/src/tint/transform/loop_to_for_loop_test.cc
+++ b/src/tint/transform/loop_to_for_loop_test.cc
@@ -22,13 +22,13 @@
using LoopToForLoopTest = TransformTest;
TEST_F(LoopToForLoopTest, ShouldRunEmptyModule) {
- auto* src = R"()";
+ auto* src = R"()";
- EXPECT_FALSE(ShouldRun<LoopToForLoop>(src));
+ EXPECT_FALSE(ShouldRun<LoopToForLoop>(src));
}
TEST_F(LoopToForLoopTest, ShouldRunHasForLoop) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
loop {
break;
@@ -36,20 +36,20 @@
}
)";
- EXPECT_TRUE(ShouldRun<LoopToForLoop>(src));
+ EXPECT_TRUE(ShouldRun<LoopToForLoop>(src));
}
TEST_F(LoopToForLoopTest, EmptyModule) {
- auto* src = "";
- auto* expect = "";
+ auto* src = "";
+ auto* expect = "";
- auto got = Run<LoopToForLoop>(src);
+ auto got = Run<LoopToForLoop>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(LoopToForLoopTest, IfBreak) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var i : i32;
i = 0;
@@ -67,7 +67,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var i : i32;
i = 0;
@@ -77,13 +77,13 @@
}
)";
- auto got = Run<LoopToForLoop>(src);
+ auto got = Run<LoopToForLoop>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(LoopToForLoopTest, IfElseBreak) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var i : i32;
i = 0;
@@ -102,7 +102,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var i : i32;
i = 0;
@@ -112,13 +112,13 @@
}
)";
- auto got = Run<LoopToForLoop>(src);
+ auto got = Run<LoopToForLoop>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(LoopToForLoopTest, Nested) {
- auto* src = R"(
+ auto* src = R"(
let N = 16u;
fn f() {
@@ -150,7 +150,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
let N = 16u;
fn f() {
@@ -167,13 +167,13 @@
}
)";
- auto got = Run<LoopToForLoop>(src);
+ auto got = Run<LoopToForLoop>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(LoopToForLoopTest, NoTransform_IfMultipleStmts) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var i : i32;
i = 0;
@@ -191,15 +191,15 @@
}
)";
- auto* expect = src;
+ auto* expect = src;
- auto got = Run<LoopToForLoop>(src);
+ auto got = Run<LoopToForLoop>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(LoopToForLoopTest, NoTransform_IfElseMultipleStmts) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var i : i32;
i = 0;
@@ -218,15 +218,15 @@
}
)";
- auto* expect = src;
+ auto* expect = src;
- auto got = Run<LoopToForLoop>(src);
+ auto got = Run<LoopToForLoop>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(LoopToForLoopTest, NoTransform_ContinuingIsCompound) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var i : i32;
i = 0;
@@ -244,15 +244,15 @@
}
)";
- auto* expect = src;
+ auto* expect = src;
- auto got = Run<LoopToForLoop>(src);
+ auto got = Run<LoopToForLoop>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(LoopToForLoopTest, NoTransform_ContinuingMultipleStmts) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var i : i32;
i = 0;
@@ -270,15 +270,15 @@
}
)";
- auto* expect = src;
+ auto* expect = src;
- auto got = Run<LoopToForLoop>(src);
+ auto got = Run<LoopToForLoop>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(LoopToForLoopTest, NoTransform_ContinuingUsesVarDeclInLoopBody) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var i : i32;
i = 0;
@@ -295,15 +295,15 @@
}
)";
- auto* expect = src;
+ auto* expect = src;
- auto got = Run<LoopToForLoop>(src);
+ auto got = Run<LoopToForLoop>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(LoopToForLoopTest, NoTransform_IfBreakWithElse) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var i : i32;
i = 0;
@@ -321,15 +321,15 @@
}
)";
- auto* expect = src;
+ auto* expect = src;
- auto got = Run<LoopToForLoop>(src);
+ auto got = Run<LoopToForLoop>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(LoopToForLoopTest, NoTransform_IfBreakWithElseIf) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var i : i32;
i = 0;
@@ -347,11 +347,11 @@
}
)";
- auto* expect = src;
+ auto* expect = src;
- auto got = Run<LoopToForLoop>(src);
+ auto got = Run<LoopToForLoop>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
} // namespace
diff --git a/src/tint/transform/manager.cc b/src/tint/transform/manager.cc
index a52f175..823474c 100644
--- a/src/tint/transform/manager.cc
+++ b/src/tint/transform/manager.cc
@@ -32,53 +32,47 @@
Manager::~Manager() = default;
Output Manager::Run(const Program* program, const DataMap& data) const {
- const Program* in = program;
+ const Program* in = program;
#if TINT_PRINT_PROGRAM_FOR_EACH_TRANSFORM
- auto print_program = [&](const char* msg, const Transform* transform) {
- auto wgsl = Program::printer(in);
- std::cout << "---------------------------------------------------------"
- << std::endl;
- std::cout << "-- " << msg << " " << transform->TypeInfo().name << ":"
- << std::endl;
- std::cout << "---------------------------------------------------------"
- << std::endl;
- std::cout << wgsl << std::endl;
- std::cout << "---------------------------------------------------------"
- << std::endl
- << std::endl;
- };
+ auto print_program = [&](const char* msg, const Transform* transform) {
+ auto wgsl = Program::printer(in);
+ std::cout << "---------------------------------------------------------" << std::endl;
+ std::cout << "-- " << msg << " " << transform->TypeInfo().name << ":" << std::endl;
+ std::cout << "---------------------------------------------------------" << std::endl;
+ std::cout << wgsl << std::endl;
+ std::cout << "---------------------------------------------------------" << std::endl
+ << std::endl;
+ };
#endif
- Output out;
- for (const auto& transform : transforms_) {
- if (!transform->ShouldRun(in, data)) {
- TINT_IF_PRINT_PROGRAM(std::cout << "Skipping "
- << transform->TypeInfo().name);
- continue;
- }
- TINT_IF_PRINT_PROGRAM(print_program("Input to", transform.get()));
+ Output out;
+ for (const auto& transform : transforms_) {
+ if (!transform->ShouldRun(in, data)) {
+ TINT_IF_PRINT_PROGRAM(std::cout << "Skipping " << transform->TypeInfo().name);
+ continue;
+ }
+ TINT_IF_PRINT_PROGRAM(print_program("Input to", transform.get()));
- auto res = transform->Run(in, data);
- out.program = std::move(res.program);
- out.data.Add(std::move(res.data));
- in = &out.program;
- if (!in->IsValid()) {
- TINT_IF_PRINT_PROGRAM(
- print_program("Invalid output of", transform.get()));
- return out;
+ auto res = transform->Run(in, data);
+ out.program = std::move(res.program);
+ out.data.Add(std::move(res.data));
+ in = &out.program;
+ if (!in->IsValid()) {
+ TINT_IF_PRINT_PROGRAM(print_program("Invalid output of", transform.get()));
+ return out;
+ }
+
+ if (transform == transforms_.back()) {
+ TINT_IF_PRINT_PROGRAM(print_program("Output of", transform.get()));
+ }
}
- if (transform == transforms_.back()) {
- TINT_IF_PRINT_PROGRAM(print_program("Output of", transform.get()));
+ if (program == in) {
+ out.program = program->Clone();
}
- }
- if (program == in) {
- out.program = program->Clone();
- }
-
- return out;
+ return out;
}
} // namespace tint::transform
diff --git a/src/tint/transform/manager.h b/src/tint/transform/manager.h
index fb614d3..9f5c6bc 100644
--- a/src/tint/transform/manager.h
+++ b/src/tint/transform/manager.h
@@ -28,33 +28,33 @@
/// If any inner transform fails the manager will return immediately and
/// the error can be retrieved with the Output's diagnostics.
class Manager : public Castable<Manager, Transform> {
- public:
- /// Constructor
- Manager();
- ~Manager() override;
+ public:
+ /// Constructor
+ Manager();
+ ~Manager() override;
- /// Add pass to the manager
- /// @param transform the transform to append
- void append(std::unique_ptr<Transform> transform) {
- transforms_.push_back(std::move(transform));
- }
+ /// Add pass to the manager
+ /// @param transform the transform to append
+ void append(std::unique_ptr<Transform> transform) {
+ transforms_.push_back(std::move(transform));
+ }
- /// Add pass to the manager of type `T`, constructed with the provided
- /// arguments.
- /// @param args the arguments to forward to the `T` constructor
- template <typename T, typename... ARGS>
- void Add(ARGS&&... args) {
- transforms_.emplace_back(std::make_unique<T>(std::forward<ARGS>(args)...));
- }
+ /// Add pass to the manager of type `T`, constructed with the provided
+ /// arguments.
+ /// @param args the arguments to forward to the `T` constructor
+ template <typename T, typename... ARGS>
+ void Add(ARGS&&... args) {
+ transforms_.emplace_back(std::make_unique<T>(std::forward<ARGS>(args)...));
+ }
- /// Runs the transforms on `program`, returning the transformation result.
- /// @param program the source program to transform
- /// @param data optional extra transform-specific input data
- /// @returns the transformed program and diagnostics
- Output Run(const Program* program, const DataMap& data = {}) const override;
+ /// Runs the transforms on `program`, returning the transformation result.
+ /// @param program the source program to transform
+ /// @param data optional extra transform-specific input data
+ /// @returns the transformed program and diagnostics
+ Output Run(const Program* program, const DataMap& data = {}) const override;
- private:
- std::vector<std::unique_ptr<Transform>> transforms_;
+ private:
+ std::vector<std::unique_ptr<Transform>> transforms_;
};
} // namespace tint::transform
diff --git a/src/tint/transform/module_scope_var_to_entry_point_param.cc b/src/tint/transform/module_scope_var_to_entry_point_param.cc
index a1789f4..22bcd5c 100644
--- a/src/tint/transform/module_scope_var_to_entry_point_param.cc
+++ b/src/tint/transform/module_scope_var_to_entry_point_param.cc
@@ -32,366 +32,350 @@
namespace {
// Returns `true` if `type` is or contains a matrix type.
bool ContainsMatrix(const sem::Type* type) {
- type = type->UnwrapRef();
- if (type->Is<sem::Matrix>()) {
- return true;
- } else if (auto* ary = type->As<sem::Array>()) {
- return ContainsMatrix(ary->ElemType());
- } else if (auto* str = type->As<sem::Struct>()) {
- for (auto* member : str->Members()) {
- if (ContainsMatrix(member->Type())) {
+ type = type->UnwrapRef();
+ if (type->Is<sem::Matrix>()) {
return true;
- }
+ } else if (auto* ary = type->As<sem::Array>()) {
+ return ContainsMatrix(ary->ElemType());
+ } else if (auto* str = type->As<sem::Struct>()) {
+ for (auto* member : str->Members()) {
+ if (ContainsMatrix(member->Type())) {
+ return true;
+ }
+ }
}
- }
- return false;
+ return false;
}
} // namespace
/// State holds the current transform state.
struct ModuleScopeVarToEntryPointParam::State {
- /// The clone context.
- CloneContext& ctx;
+ /// The clone context.
+ CloneContext& ctx;
- /// Constructor
- /// @param context the clone context
- explicit State(CloneContext& context) : ctx(context) {}
+ /// Constructor
+ /// @param context the clone context
+ explicit State(CloneContext& context) : ctx(context) {}
- /// Clone any struct types that are contained in `ty` (including `ty` itself),
- /// and add it to the global declarations now, so that they precede new global
- /// declarations that need to reference them.
- /// @param ty the type to clone
- void CloneStructTypes(const sem::Type* ty) {
- if (auto* str = ty->As<sem::Struct>()) {
- if (!cloned_structs_.emplace(str).second) {
- // The struct has already been cloned.
- return;
- }
-
- // Recurse into members.
- for (auto* member : str->Members()) {
- CloneStructTypes(member->Type());
- }
-
- // Clone the struct and add it to the global declaration list.
- // Remove the old declaration.
- auto* ast_str = str->Declaration();
- ctx.dst->AST().AddTypeDecl(ctx.Clone(ast_str));
- ctx.Remove(ctx.src->AST().GlobalDeclarations(), ast_str);
- } else if (auto* arr = ty->As<sem::Array>()) {
- CloneStructTypes(arr->ElemType());
- }
- }
-
- /// Process the module.
- void Process() {
- // Predetermine the list of function calls that need to be replaced.
- using CallList = std::vector<const ast::CallExpression*>;
- std::unordered_map<const ast::Function*, CallList> calls_to_replace;
-
- std::vector<const ast::Function*> functions_to_process;
-
- // Build a list of functions that transitively reference any module-scope
- // variables.
- for (auto* func_ast : ctx.src->AST().Functions()) {
- auto* func_sem = ctx.src->Sem().Get(func_ast);
-
- bool needs_processing = false;
- for (auto* var : func_sem->TransitivelyReferencedGlobals()) {
- if (var->StorageClass() != ast::StorageClass::kNone) {
- needs_processing = true;
- break;
- }
- }
- if (needs_processing) {
- functions_to_process.push_back(func_ast);
-
- // Find all of the calls to this function that will need to be replaced.
- for (auto* call : func_sem->CallSites()) {
- calls_to_replace[call->Stmt()->Function()->Declaration()].push_back(
- call->Declaration());
- }
- }
- }
-
- // Build a list of `&ident` expressions. We'll use this later to avoid
- // generating expressions of the form `&*ident`, which break WGSL validation
- // rules when this expression is passed to a function.
- // TODO(jrprice): We should add support for bidirectional SEM tree traversal
- // so that we can do this on the fly instead.
- std::unordered_map<const ast::IdentifierExpression*,
- const ast::UnaryOpExpression*>
- ident_to_address_of;
- for (auto* node : ctx.src->ASTNodes().Objects()) {
- auto* address_of = node->As<ast::UnaryOpExpression>();
- if (!address_of || address_of->op != ast::UnaryOp::kAddressOf) {
- continue;
- }
- if (auto* ident = address_of->expr->As<ast::IdentifierExpression>()) {
- ident_to_address_of[ident] = address_of;
- }
- }
-
- for (auto* func_ast : functions_to_process) {
- auto* func_sem = ctx.src->Sem().Get(func_ast);
- bool is_entry_point = func_ast->IsEntryPoint();
-
- // Map module-scope variables onto their replacement.
- struct NewVar {
- Symbol symbol;
- bool is_pointer;
- bool is_wrapped;
- };
- const char* kWrappedArrayMemberName = "arr";
- std::unordered_map<const sem::Variable*, NewVar> var_to_newvar;
-
- // We aggregate all workgroup variables into a struct to avoid hitting
- // MSL's limit for threadgroup memory arguments.
- Symbol workgroup_parameter_symbol;
- ast::StructMemberList workgroup_parameter_members;
- auto workgroup_param = [&]() {
- if (!workgroup_parameter_symbol.IsValid()) {
- workgroup_parameter_symbol = ctx.dst->Sym();
- }
- return workgroup_parameter_symbol;
- };
-
- for (auto* var : func_sem->TransitivelyReferencedGlobals()) {
- auto sc = var->StorageClass();
- auto* ty = var->Type()->UnwrapRef();
- if (sc == ast::StorageClass::kNone) {
- continue;
- }
- if (sc != ast::StorageClass::kPrivate &&
- sc != ast::StorageClass::kStorage &&
- sc != ast::StorageClass::kUniform &&
- sc != ast::StorageClass::kHandle &&
- sc != ast::StorageClass::kWorkgroup) {
- TINT_ICE(Transform, ctx.dst->Diagnostics())
- << "unhandled module-scope storage class (" << sc << ")";
- }
-
- // This is the symbol for the variable that replaces the module-scope
- // var.
- auto new_var_symbol = ctx.dst->Sym();
-
- // Helper to create an AST node for the store type of the variable.
- auto store_type = [&]() { return CreateASTTypeFor(ctx, ty); };
-
- // Track whether the new variable is a pointer or not.
- bool is_pointer = false;
-
- // Track whether the new variable was wrapped in a struct or not.
- bool is_wrapped = false;
-
- if (is_entry_point) {
- if (var->Type()->UnwrapRef()->is_handle()) {
- // For a texture or sampler variable, redeclare it as an entry point
- // parameter. Disable entry point parameter validation.
- auto* disable_validation =
- ctx.dst->Disable(ast::DisabledValidation::kEntryPointParameter);
- auto attrs = ctx.Clone(var->Declaration()->attributes);
- attrs.push_back(disable_validation);
- auto* param = ctx.dst->Param(new_var_symbol, store_type(), attrs);
- ctx.InsertFront(func_ast->params, param);
- } else if (sc == ast::StorageClass::kStorage ||
- sc == ast::StorageClass::kUniform) {
- // Variables into the Storage and Uniform storage classes are
- // redeclared as entry point parameters with a pointer type.
- auto attributes = ctx.Clone(var->Declaration()->attributes);
- attributes.push_back(ctx.dst->Disable(
- ast::DisabledValidation::kEntryPointParameter));
- attributes.push_back(
- ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass));
-
- auto* param_type = store_type();
- if (auto* arr = ty->As<sem::Array>();
- arr && arr->IsRuntimeSized()) {
- // Wrap runtime-sized arrays in structures, so that we can declare
- // pointers to them. Ideally we'd just emit the array itself as a
- // pointer, but this is not representable in Tint's AST.
- CloneStructTypes(ty);
- auto* wrapper = ctx.dst->Structure(
- ctx.dst->Sym(),
- {ctx.dst->Member(kWrappedArrayMemberName, param_type)});
- param_type = ctx.dst->ty.Of(wrapper);
- is_wrapped = true;
+ /// Clone any struct types that are contained in `ty` (including `ty` itself),
+ /// and add it to the global declarations now, so that they precede new global
+ /// declarations that need to reference them.
+ /// @param ty the type to clone
+ void CloneStructTypes(const sem::Type* ty) {
+ if (auto* str = ty->As<sem::Struct>()) {
+ if (!cloned_structs_.emplace(str).second) {
+ // The struct has already been cloned.
+ return;
}
- param_type = ctx.dst->ty.pointer(
- param_type, sc, var->Declaration()->declared_access);
- auto* param =
- ctx.dst->Param(new_var_symbol, param_type, attributes);
- ctx.InsertFront(func_ast->params, param);
- is_pointer = true;
- } else if (sc == ast::StorageClass::kWorkgroup &&
- ContainsMatrix(var->Type())) {
- // Due to a bug in the MSL compiler, we use a threadgroup memory
- // argument for any workgroup allocation that contains a matrix.
- // See crbug.com/tint/938.
- // TODO(jrprice): Do this for all other workgroup variables too.
+ // Recurse into members.
+ for (auto* member : str->Members()) {
+ CloneStructTypes(member->Type());
+ }
- // Create a member in the workgroup parameter struct.
- auto member = ctx.Clone(var->Declaration()->symbol);
- workgroup_parameter_members.push_back(
- ctx.dst->Member(member, store_type()));
- CloneStructTypes(var->Type()->UnwrapRef());
+ // Clone the struct and add it to the global declaration list.
+ // Remove the old declaration.
+ auto* ast_str = str->Declaration();
+ ctx.dst->AST().AddTypeDecl(ctx.Clone(ast_str));
+ ctx.Remove(ctx.src->AST().GlobalDeclarations(), ast_str);
+ } else if (auto* arr = ty->As<sem::Array>()) {
+ CloneStructTypes(arr->ElemType());
+ }
+ }
- // Create a function-scope variable that is a pointer to the member.
- auto* member_ptr = ctx.dst->AddressOf(ctx.dst->MemberAccessor(
- ctx.dst->Deref(workgroup_param()), member));
- auto* local_var =
- ctx.dst->Let(new_var_symbol,
- ctx.dst->ty.pointer(store_type(),
- ast::StorageClass::kWorkgroup),
- member_ptr);
- ctx.InsertFront(func_ast->body->statements,
- ctx.dst->Decl(local_var));
- is_pointer = true;
- } else {
- // Variables in the Private and Workgroup storage classes are
- // redeclared at function scope. Disable storage class validation on
- // this variable.
- auto* disable_validation =
- ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass);
- auto* constructor = ctx.Clone(var->Declaration()->constructor);
- auto* local_var =
- ctx.dst->Var(new_var_symbol, store_type(), sc, constructor,
- ast::AttributeList{disable_validation});
- ctx.InsertFront(func_ast->body->statements,
- ctx.dst->Decl(local_var));
- }
- } else {
- // For a regular function, redeclare the variable as a parameter.
- // Use a pointer for non-handle types.
- auto* param_type = store_type();
- ast::AttributeList attributes;
- if (!var->Type()->UnwrapRef()->is_handle()) {
- param_type = ctx.dst->ty.pointer(
- param_type, sc, var->Declaration()->declared_access);
- is_pointer = true;
+ /// Process the module.
+ void Process() {
+ // Predetermine the list of function calls that need to be replaced.
+ using CallList = std::vector<const ast::CallExpression*>;
+ std::unordered_map<const ast::Function*, CallList> calls_to_replace;
- // Disable validation of the parameter's storage class and of
- // arguments passed it.
- attributes.push_back(
- ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass));
- attributes.push_back(ctx.dst->Disable(
- ast::DisabledValidation::kIgnoreInvalidPointerArgument));
- }
- ctx.InsertBack(
- func_ast->params,
- ctx.dst->Param(new_var_symbol, param_type, attributes));
+ std::vector<const ast::Function*> functions_to_process;
+
+ // Build a list of functions that transitively reference any module-scope
+ // variables.
+ for (auto* func_ast : ctx.src->AST().Functions()) {
+ auto* func_sem = ctx.src->Sem().Get(func_ast);
+
+ bool needs_processing = false;
+ for (auto* var : func_sem->TransitivelyReferencedGlobals()) {
+ if (var->StorageClass() != ast::StorageClass::kNone) {
+ needs_processing = true;
+ break;
+ }
+ }
+ if (needs_processing) {
+ functions_to_process.push_back(func_ast);
+
+ // Find all of the calls to this function that will need to be replaced.
+ for (auto* call : func_sem->CallSites()) {
+ calls_to_replace[call->Stmt()->Function()->Declaration()].push_back(
+ call->Declaration());
+ }
+ }
}
- // Replace all uses of the module-scope variable.
- // For non-entry points, dereference non-handle pointer parameters.
- for (auto* user : var->Users()) {
- if (user->Stmt()->Function()->Declaration() == func_ast) {
- const ast::Expression* expr = ctx.dst->Expr(new_var_symbol);
- if (is_pointer) {
- // If this identifier is used by an address-of operator, just
- // remove the address-of instead of adding a deref, since we
- // already have a pointer.
- auto* ident =
- user->Declaration()->As<ast::IdentifierExpression>();
- if (ident_to_address_of.count(ident)) {
- ctx.Replace(ident_to_address_of[ident], expr);
+ // Build a list of `&ident` expressions. We'll use this later to avoid
+ // generating expressions of the form `&*ident`, which break WGSL validation
+ // rules when this expression is passed to a function.
+ // TODO(jrprice): We should add support for bidirectional SEM tree traversal
+ // so that we can do this on the fly instead.
+ std::unordered_map<const ast::IdentifierExpression*, const ast::UnaryOpExpression*>
+ ident_to_address_of;
+ for (auto* node : ctx.src->ASTNodes().Objects()) {
+ auto* address_of = node->As<ast::UnaryOpExpression>();
+ if (!address_of || address_of->op != ast::UnaryOp::kAddressOf) {
continue;
- }
-
- expr = ctx.dst->Deref(expr);
}
- if (is_wrapped) {
- // Get the member from the wrapper structure.
- expr = ctx.dst->MemberAccessor(expr, kWrappedArrayMemberName);
+ if (auto* ident = address_of->expr->As<ast::IdentifierExpression>()) {
+ ident_to_address_of[ident] = address_of;
}
- ctx.Replace(user->Declaration(), expr);
- }
}
- var_to_newvar[var] = {new_var_symbol, is_pointer, is_wrapped};
- }
+ for (auto* func_ast : functions_to_process) {
+ auto* func_sem = ctx.src->Sem().Get(func_ast);
+ bool is_entry_point = func_ast->IsEntryPoint();
- if (!workgroup_parameter_members.empty()) {
- // Create the workgroup memory parameter.
- // The parameter is a struct that contains members for each workgroup
- // variable.
- auto* str = ctx.dst->Structure(ctx.dst->Sym(),
- std::move(workgroup_parameter_members));
- auto* param_type = ctx.dst->ty.pointer(ctx.dst->ty.Of(str),
- ast::StorageClass::kWorkgroup);
- auto* disable_validation =
- ctx.dst->Disable(ast::DisabledValidation::kEntryPointParameter);
- auto* param =
- ctx.dst->Param(workgroup_param(), param_type, {disable_validation});
- ctx.InsertFront(func_ast->params, param);
- }
+ // Map module-scope variables onto their replacement.
+ struct NewVar {
+ Symbol symbol;
+ bool is_pointer;
+ bool is_wrapped;
+ };
+ const char* kWrappedArrayMemberName = "arr";
+ std::unordered_map<const sem::Variable*, NewVar> var_to_newvar;
- // Pass the variables as pointers to any functions that need them.
- for (auto* call : calls_to_replace[func_ast]) {
- auto* target =
- ctx.src->AST().Functions().Find(call->target.name->symbol);
- auto* target_sem = ctx.src->Sem().Get(target);
+ // We aggregate all workgroup variables into a struct to avoid hitting
+ // MSL's limit for threadgroup memory arguments.
+ Symbol workgroup_parameter_symbol;
+ ast::StructMemberList workgroup_parameter_members;
+ auto workgroup_param = [&]() {
+ if (!workgroup_parameter_symbol.IsValid()) {
+ workgroup_parameter_symbol = ctx.dst->Sym();
+ }
+ return workgroup_parameter_symbol;
+ };
- // Add new arguments for any variables that are needed by the callee.
- // For entry points, pass non-handle types as pointers.
- for (auto* target_var : target_sem->TransitivelyReferencedGlobals()) {
- auto sc = target_var->StorageClass();
- if (sc == ast::StorageClass::kNone) {
- continue;
- }
+ for (auto* var : func_sem->TransitivelyReferencedGlobals()) {
+ auto sc = var->StorageClass();
+ auto* ty = var->Type()->UnwrapRef();
+ if (sc == ast::StorageClass::kNone) {
+ continue;
+ }
+ if (sc != ast::StorageClass::kPrivate && sc != ast::StorageClass::kStorage &&
+ sc != ast::StorageClass::kUniform && sc != ast::StorageClass::kHandle &&
+ sc != ast::StorageClass::kWorkgroup) {
+ TINT_ICE(Transform, ctx.dst->Diagnostics())
+ << "unhandled module-scope storage class (" << sc << ")";
+ }
- auto new_var = var_to_newvar[target_var];
- bool is_handle = target_var->Type()->UnwrapRef()->is_handle();
- const ast::Expression* arg = ctx.dst->Expr(new_var.symbol);
- if (new_var.is_wrapped) {
- // The variable is wrapped in a struct, so we need to pass a pointer
- // to the struct member instead.
- arg = ctx.dst->AddressOf(ctx.dst->MemberAccessor(
- ctx.dst->Deref(arg), kWrappedArrayMemberName));
- } else if (is_entry_point && !is_handle && !new_var.is_pointer) {
- // We need to pass a pointer and we don't already have one, so take
- // the address of the new variable.
- arg = ctx.dst->AddressOf(arg);
- }
- ctx.InsertBack(call->args, arg);
+ // This is the symbol for the variable that replaces the module-scope
+ // var.
+ auto new_var_symbol = ctx.dst->Sym();
+
+ // Helper to create an AST node for the store type of the variable.
+ auto store_type = [&]() { return CreateASTTypeFor(ctx, ty); };
+
+ // Track whether the new variable is a pointer or not.
+ bool is_pointer = false;
+
+ // Track whether the new variable was wrapped in a struct or not.
+ bool is_wrapped = false;
+
+ if (is_entry_point) {
+ if (var->Type()->UnwrapRef()->is_handle()) {
+ // For a texture or sampler variable, redeclare it as an entry point
+ // parameter. Disable entry point parameter validation.
+ auto* disable_validation =
+ ctx.dst->Disable(ast::DisabledValidation::kEntryPointParameter);
+ auto attrs = ctx.Clone(var->Declaration()->attributes);
+ attrs.push_back(disable_validation);
+ auto* param = ctx.dst->Param(new_var_symbol, store_type(), attrs);
+ ctx.InsertFront(func_ast->params, param);
+ } else if (sc == ast::StorageClass::kStorage ||
+ sc == ast::StorageClass::kUniform) {
+ // Variables in the Storage and Uniform storage classes are
+ // redeclared as entry point parameters with a pointer type.
+ auto attributes = ctx.Clone(var->Declaration()->attributes);
+ attributes.push_back(
+ ctx.dst->Disable(ast::DisabledValidation::kEntryPointParameter));
+ attributes.push_back(
+ ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass));
+
+ auto* param_type = store_type();
+ if (auto* arr = ty->As<sem::Array>(); arr && arr->IsRuntimeSized()) {
+ // Wrap runtime-sized arrays in structures, so that we can declare
+ // pointers to them. Ideally we'd just emit the array itself as a
+ // pointer, but this is not representable in Tint's AST.
+ CloneStructTypes(ty);
+ auto* wrapper = ctx.dst->Structure(
+ ctx.dst->Sym(),
+ {ctx.dst->Member(kWrappedArrayMemberName, param_type)});
+ param_type = ctx.dst->ty.Of(wrapper);
+ is_wrapped = true;
+ }
+
+ param_type = ctx.dst->ty.pointer(param_type, sc,
+ var->Declaration()->declared_access);
+ auto* param = ctx.dst->Param(new_var_symbol, param_type, attributes);
+ ctx.InsertFront(func_ast->params, param);
+ is_pointer = true;
+ } else if (sc == ast::StorageClass::kWorkgroup && ContainsMatrix(var->Type())) {
+ // Due to a bug in the MSL compiler, we use a threadgroup memory
+ // argument for any workgroup allocation that contains a matrix.
+ // See crbug.com/tint/938.
+ // TODO(jrprice): Do this for all other workgroup variables too.
+
+ // Create a member in the workgroup parameter struct.
+ auto member = ctx.Clone(var->Declaration()->symbol);
+ workgroup_parameter_members.push_back(
+ ctx.dst->Member(member, store_type()));
+ CloneStructTypes(var->Type()->UnwrapRef());
+
+ // Create a function-scope variable that is a pointer to the member.
+ auto* member_ptr = ctx.dst->AddressOf(
+ ctx.dst->MemberAccessor(ctx.dst->Deref(workgroup_param()), member));
+ auto* local_var = ctx.dst->Let(
+ new_var_symbol,
+ ctx.dst->ty.pointer(store_type(), ast::StorageClass::kWorkgroup),
+ member_ptr);
+ ctx.InsertFront(func_ast->body->statements, ctx.dst->Decl(local_var));
+ is_pointer = true;
+ } else {
+ // Variables in the Private and Workgroup storage classes are
+ // redeclared at function scope. Disable storage class validation on
+ // this variable.
+ auto* disable_validation =
+ ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass);
+ auto* constructor = ctx.Clone(var->Declaration()->constructor);
+ auto* local_var =
+ ctx.dst->Var(new_var_symbol, store_type(), sc, constructor,
+ ast::AttributeList{disable_validation});
+ ctx.InsertFront(func_ast->body->statements, ctx.dst->Decl(local_var));
+ }
+ } else {
+ // For a regular function, redeclare the variable as a parameter.
+ // Use a pointer for non-handle types.
+ auto* param_type = store_type();
+ ast::AttributeList attributes;
+ if (!var->Type()->UnwrapRef()->is_handle()) {
+ param_type = ctx.dst->ty.pointer(param_type, sc,
+ var->Declaration()->declared_access);
+ is_pointer = true;
+
+ // Disable validation of the parameter's storage class and of
+ // arguments passed to it.
+ attributes.push_back(
+ ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass));
+ attributes.push_back(ctx.dst->Disable(
+ ast::DisabledValidation::kIgnoreInvalidPointerArgument));
+ }
+ ctx.InsertBack(func_ast->params,
+ ctx.dst->Param(new_var_symbol, param_type, attributes));
+ }
+
+ // Replace all uses of the module-scope variable.
+ // For non-entry points, dereference non-handle pointer parameters.
+ for (auto* user : var->Users()) {
+ if (user->Stmt()->Function()->Declaration() == func_ast) {
+ const ast::Expression* expr = ctx.dst->Expr(new_var_symbol);
+ if (is_pointer) {
+ // If this identifier is used by an address-of operator, just
+ // remove the address-of instead of adding a deref, since we
+ // already have a pointer.
+ auto* ident = user->Declaration()->As<ast::IdentifierExpression>();
+ if (ident_to_address_of.count(ident)) {
+ ctx.Replace(ident_to_address_of[ident], expr);
+ continue;
+ }
+
+ expr = ctx.dst->Deref(expr);
+ }
+ if (is_wrapped) {
+ // Get the member from the wrapper structure.
+ expr = ctx.dst->MemberAccessor(expr, kWrappedArrayMemberName);
+ }
+ ctx.Replace(user->Declaration(), expr);
+ }
+ }
+
+ var_to_newvar[var] = {new_var_symbol, is_pointer, is_wrapped};
+ }
+
+ if (!workgroup_parameter_members.empty()) {
+ // Create the workgroup memory parameter.
+ // The parameter is a struct that contains members for each workgroup
+ // variable.
+ auto* str =
+ ctx.dst->Structure(ctx.dst->Sym(), std::move(workgroup_parameter_members));
+ auto* param_type =
+ ctx.dst->ty.pointer(ctx.dst->ty.Of(str), ast::StorageClass::kWorkgroup);
+ auto* disable_validation =
+ ctx.dst->Disable(ast::DisabledValidation::kEntryPointParameter);
+ auto* param = ctx.dst->Param(workgroup_param(), param_type, {disable_validation});
+ ctx.InsertFront(func_ast->params, param);
+ }
+
+ // Pass the variables as pointers to any functions that need them.
+ for (auto* call : calls_to_replace[func_ast]) {
+ auto* target = ctx.src->AST().Functions().Find(call->target.name->symbol);
+ auto* target_sem = ctx.src->Sem().Get(target);
+
+ // Add new arguments for any variables that are needed by the callee.
+ // For entry points, pass non-handle types as pointers.
+ for (auto* target_var : target_sem->TransitivelyReferencedGlobals()) {
+ auto sc = target_var->StorageClass();
+ if (sc == ast::StorageClass::kNone) {
+ continue;
+ }
+
+ auto new_var = var_to_newvar[target_var];
+ bool is_handle = target_var->Type()->UnwrapRef()->is_handle();
+ const ast::Expression* arg = ctx.dst->Expr(new_var.symbol);
+ if (new_var.is_wrapped) {
+ // The variable is wrapped in a struct, so we need to pass a pointer
+ // to the struct member instead.
+ arg = ctx.dst->AddressOf(
+ ctx.dst->MemberAccessor(ctx.dst->Deref(arg), kWrappedArrayMemberName));
+ } else if (is_entry_point && !is_handle && !new_var.is_pointer) {
+ // We need to pass a pointer and we don't already have one, so take
+ // the address of the new variable.
+ arg = ctx.dst->AddressOf(arg);
+ }
+ ctx.InsertBack(call->args, arg);
+ }
+ }
}
- }
+
+ // Now remove all module-scope variables with these storage classes.
+ for (auto* var_ast : ctx.src->AST().GlobalVariables()) {
+ auto* var_sem = ctx.src->Sem().Get(var_ast);
+ if (var_sem->StorageClass() != ast::StorageClass::kNone) {
+ ctx.Remove(ctx.src->AST().GlobalDeclarations(), var_ast);
+ }
+ }
}
- // Now remove all module-scope variables with these storage classes.
- for (auto* var_ast : ctx.src->AST().GlobalVariables()) {
- auto* var_sem = ctx.src->Sem().Get(var_ast);
- if (var_sem->StorageClass() != ast::StorageClass::kNone) {
- ctx.Remove(ctx.src->AST().GlobalDeclarations(), var_ast);
- }
- }
- }
-
- private:
- std::unordered_set<const sem::Struct*> cloned_structs_;
+ private:
+ std::unordered_set<const sem::Struct*> cloned_structs_;
};
ModuleScopeVarToEntryPointParam::ModuleScopeVarToEntryPointParam() = default;
ModuleScopeVarToEntryPointParam::~ModuleScopeVarToEntryPointParam() = default;
-bool ModuleScopeVarToEntryPointParam::ShouldRun(const Program* program,
- const DataMap&) const {
- for (auto* decl : program->AST().GlobalDeclarations()) {
- if (decl->Is<ast::Variable>()) {
- return true;
+bool ModuleScopeVarToEntryPointParam::ShouldRun(const Program* program, const DataMap&) const {
+ for (auto* decl : program->AST().GlobalDeclarations()) {
+ if (decl->Is<ast::Variable>()) {
+ return true;
+ }
}
- }
- return false;
+ return false;
}
-void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx,
- const DataMap&,
- DataMap&) const {
- State state{ctx};
- state.Process();
- ctx.Clone();
+void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
+ State state{ctx};
+ state.Process();
+ ctx.Clone();
}
} // namespace tint::transform
diff --git a/src/tint/transform/module_scope_var_to_entry_point_param.h b/src/tint/transform/module_scope_var_to_entry_point_param.h
index 8297057..f268197 100644
--- a/src/tint/transform/module_scope_var_to_entry_point_param.h
+++ b/src/tint/transform/module_scope_var_to_entry_point_param.h
@@ -63,30 +63,27 @@
/// ```
class ModuleScopeVarToEntryPointParam
: public Castable<ModuleScopeVarToEntryPointParam, Transform> {
- public:
- /// Constructor
- ModuleScopeVarToEntryPointParam();
- /// Destructor
- ~ModuleScopeVarToEntryPointParam() override;
+ public:
+ /// Constructor
+ ModuleScopeVarToEntryPointParam();
+ /// Destructor
+ ~ModuleScopeVarToEntryPointParam() override;
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program,
- const DataMap& data = {}) const override;
+ /// @param program the program to inspect
+ /// @param data optional extra transform-specific input data
+ /// @returns true if this transform should be run for the given program
+ bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx,
- const DataMap& inputs,
- DataMap& outputs) const override;
+ protected:
+ /// Runs the transform using the CloneContext built for transforming a
+ /// program. Run() is responsible for calling Clone() on the CloneContext.
+ /// @param ctx the CloneContext primed with the input program and
+ /// ProgramBuilder
+ /// @param inputs optional extra transform-specific input data
+ /// @param outputs optional extra transform-specific output data
+ void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
- struct State;
+ struct State;
};
} // namespace tint::transform
diff --git a/src/tint/transform/module_scope_var_to_entry_point_param_test.cc b/src/tint/transform/module_scope_var_to_entry_point_param_test.cc
index 3089355..580e695 100644
--- a/src/tint/transform/module_scope_var_to_entry_point_param_test.cc
+++ b/src/tint/transform/module_scope_var_to_entry_point_param_test.cc
@@ -24,21 +24,21 @@
using ModuleScopeVarToEntryPointParamTest = TransformTest;
TEST_F(ModuleScopeVarToEntryPointParamTest, ShouldRunEmptyModule) {
- auto* src = R"()";
+ auto* src = R"()";
- EXPECT_FALSE(ShouldRun<ModuleScopeVarToEntryPointParam>(src));
+ EXPECT_FALSE(ShouldRun<ModuleScopeVarToEntryPointParam>(src));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, ShouldRunHasGlobal) {
- auto* src = R"(
+ auto* src = R"(
var<private> v : i32;
)";
- EXPECT_TRUE(ShouldRun<ModuleScopeVarToEntryPointParam>(src));
+ EXPECT_TRUE(ShouldRun<ModuleScopeVarToEntryPointParam>(src));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, Basic) {
- auto* src = R"(
+ auto* src = R"(
var<private> p : f32;
var<workgroup> w : f32;
@@ -48,7 +48,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@stage(compute) @workgroup_size(1)
fn main() {
@internal(disable_validation__ignore_storage_class) var<workgroup> tint_symbol : f32;
@@ -57,13 +57,13 @@
}
)";
- auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, Basic_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn main() {
w = p;
@@ -73,7 +73,7 @@
var<private> p : f32;
)";
- auto* expect = R"(
+ auto* expect = R"(
@stage(compute) @workgroup_size(1)
fn main() {
@internal(disable_validation__ignore_storage_class) var<workgroup> tint_symbol : f32;
@@ -82,13 +82,13 @@
}
)";
- auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, FunctionCalls) {
- auto* src = R"(
+ auto* src = R"(
var<private> p : f32;
var<workgroup> w : f32;
@@ -112,7 +112,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn no_uses() {
}
@@ -135,13 +135,13 @@
}
)";
- auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, FunctionCalls_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn main() {
foo(1.0);
@@ -165,7 +165,7 @@
var<workgroup> w : f32;
)";
- auto* expect = R"(
+ auto* expect = R"(
@stage(compute) @workgroup_size(1)
fn main() {
@internal(disable_validation__ignore_storage_class) var<private> tint_symbol : f32;
@@ -188,13 +188,13 @@
}
)";
- auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, Constructors) {
- auto* src = R"(
+ auto* src = R"(
var<private> a : f32 = 1.0;
var<private> b : f32 = f32();
@@ -204,7 +204,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@stage(compute) @workgroup_size(1)
fn main() {
@internal(disable_validation__ignore_storage_class) var<private> tint_symbol : f32 = 1.0;
@@ -213,13 +213,13 @@
}
)";
- auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, Constructors_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn main() {
let x : f32 = a + b;
@@ -229,7 +229,7 @@
var<private> a : f32 = 1.0;
)";
- auto* expect = R"(
+ auto* expect = R"(
@stage(compute) @workgroup_size(1)
fn main() {
@internal(disable_validation__ignore_storage_class) var<private> tint_symbol : f32 = 1.0;
@@ -238,13 +238,13 @@
}
)";
- auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, Pointers) {
- auto* src = R"(
+ auto* src = R"(
var<private> p : f32;
var<workgroup> w : f32;
@@ -257,7 +257,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@stage(compute) @workgroup_size(1)
fn main() {
@internal(disable_validation__ignore_storage_class) var<private> tint_symbol : f32;
@@ -269,13 +269,13 @@
}
)";
- auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, Pointers_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn main() {
let p_ptr : ptr<private, f32> = &p;
@@ -288,7 +288,7 @@
var<private> p : f32;
)";
- auto* expect = R"(
+ auto* expect = R"(
@stage(compute) @workgroup_size(1)
fn main() {
@internal(disable_validation__ignore_storage_class) var<private> tint_symbol : f32;
@@ -300,13 +300,13 @@
}
)";
- auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, FoldAddressOfDeref) {
- auto* src = R"(
+ auto* src = R"(
var<private> v : f32;
fn bar(p : ptr<private, f32>) {
@@ -323,7 +323,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn bar(p : ptr<private, f32>) {
*(p) = 0.0;
}
@@ -339,13 +339,13 @@
}
)";
- auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, FoldAddressOfDeref_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn main() {
foo();
@@ -362,7 +362,7 @@
var<private> v : f32;
)";
- auto* expect = R"(
+ auto* expect = R"(
@stage(compute) @workgroup_size(1)
fn main() {
@internal(disable_validation__ignore_storage_class) var<private> tint_symbol : f32;
@@ -378,13 +378,13 @@
}
)";
- auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, Buffers_Basic) {
- auto* src = R"(
+ auto* src = R"(
struct S {
a : f32,
};
@@ -401,7 +401,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct S {
a : f32,
}
@@ -413,13 +413,13 @@
}
)";
- auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, Buffers_Basic_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn main() {
_ = u;
@@ -435,7 +435,7 @@
)";
- auto* expect = R"(
+ auto* expect = R"(
@stage(compute) @workgroup_size(1)
fn main(@group(0) @binding(0) @internal(disable_validation__entry_point_parameter) @internal(disable_validation__ignore_storage_class) tint_symbol : ptr<uniform, S>, @group(0) @binding(1) @internal(disable_validation__entry_point_parameter) @internal(disable_validation__ignore_storage_class) tint_symbol_1 : ptr<storage, S>) {
_ = *(tint_symbol);
@@ -447,13 +447,13 @@
}
)";
- auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, Buffer_RuntimeArray) {
- auto* src = R"(
+ auto* src = R"(
@group(0) @binding(0)
var<storage> buffer : array<f32>;
@@ -463,7 +463,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol_1 {
arr : array<f32>,
}
@@ -474,13 +474,13 @@
}
)";
- auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, Buffer_RuntimeArray_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn main() {
_ = buffer[0];
@@ -490,7 +490,7 @@
var<storage> buffer : array<f32>;
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol_1 {
arr : array<f32>,
}
@@ -501,13 +501,13 @@
}
)";
- auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, Buffer_RuntimeArrayInsideFunction) {
- auto* src = R"(
+ auto* src = R"(
@group(0) @binding(0)
var<storage> buffer : array<f32>;
@@ -521,7 +521,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol_2 {
arr : array<f32>,
}
@@ -536,14 +536,13 @@
}
)";
- auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
-TEST_F(ModuleScopeVarToEntryPointParamTest,
- Buffer_RuntimeArrayInsideFunction_OutOfOrder) {
- auto* src = R"(
+TEST_F(ModuleScopeVarToEntryPointParamTest, Buffer_RuntimeArrayInsideFunction_OutOfOrder) {
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn main() {
foo();
@@ -556,7 +555,7 @@
@group(0) @binding(0) var<storage> buffer : array<f32>;
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol_1 {
arr : array<f32>,
}
@@ -571,13 +570,13 @@
}
)";
- auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, Buffer_RuntimeArray_Alias) {
- auto* src = R"(
+ auto* src = R"(
type myarray = array<f32>;
@group(0) @binding(0)
@@ -589,7 +588,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol_1 {
arr : array<f32>,
}
@@ -602,14 +601,13 @@
}
)";
- auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
-TEST_F(ModuleScopeVarToEntryPointParamTest,
- Buffer_RuntimeArray_Alias_OutOfOrder) {
- auto* src = R"(
+TEST_F(ModuleScopeVarToEntryPointParamTest, Buffer_RuntimeArray_Alias_OutOfOrder) {
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn main() {
_ = buffer[0];
@@ -620,7 +618,7 @@
type myarray = array<f32>;
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol_1 {
arr : array<f32>,
}
@@ -633,13 +631,13 @@
type myarray = array<f32>;
)";
- auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, Buffer_ArrayOfStruct) {
- auto* src = R"(
+ auto* src = R"(
struct S {
f : f32,
};
@@ -653,7 +651,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct S {
f : f32,
}
@@ -668,13 +666,13 @@
}
)";
- auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, Buffer_ArrayOfStruct_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn main() {
_ = buffer[0];
@@ -687,7 +685,7 @@
};
)";
- auto* expect = R"(
+ auto* expect = R"(
struct S {
f : f32,
}
@@ -702,13 +700,13 @@
}
)";
- auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, Buffers_FunctionCalls) {
- auto* src = R"(
+ auto* src = R"(
struct S {
a : f32,
};
@@ -739,7 +737,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct S {
a : f32,
}
@@ -765,13 +763,13 @@
}
)";
- auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, Buffers_FunctionCalls_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn main() {
foo(1.0);
@@ -802,7 +800,7 @@
var<storage> s : S;
)";
- auto* expect = R"(
+ auto* expect = R"(
@stage(compute) @workgroup_size(1)
fn main(@group(0) @binding(0) @internal(disable_validation__entry_point_parameter) @internal(disable_validation__ignore_storage_class) tint_symbol : ptr<uniform, S>, @group(0) @binding(1) @internal(disable_validation__entry_point_parameter) @internal(disable_validation__ignore_storage_class) tint_symbol_1 : ptr<storage, S>) {
foo(1.0, tint_symbol, tint_symbol_1);
@@ -828,13 +826,13 @@
}
)";
- auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, HandleTypes_Basic) {
- auto* src = R"(
+ auto* src = R"(
@group(0) @binding(0) var t : texture_2d<f32>;
@group(0) @binding(1) var s : sampler;
@@ -845,7 +843,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@stage(compute) @workgroup_size(1)
fn main(@group(0) @binding(0) @internal(disable_validation__entry_point_parameter) tint_symbol : texture_2d<f32>, @group(0) @binding(1) @internal(disable_validation__entry_point_parameter) tint_symbol_1 : sampler) {
_ = tint_symbol;
@@ -853,13 +851,13 @@
}
)";
- auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, HandleTypes_FunctionCalls) {
- auto* src = R"(
+ auto* src = R"(
@group(0) @binding(0) var t : texture_2d<f32>;
@group(0) @binding(1) var s : sampler;
@@ -884,7 +882,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn no_uses() {
}
@@ -906,14 +904,13 @@
}
)";
- auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
-TEST_F(ModuleScopeVarToEntryPointParamTest,
- HandleTypes_FunctionCalls_OutOfOrder) {
- auto* src = R"(
+TEST_F(ModuleScopeVarToEntryPointParamTest, HandleTypes_FunctionCalls_OutOfOrder) {
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn main() {
foo(1.0);
@@ -938,7 +935,7 @@
@group(0) @binding(1) var s : sampler;
)";
- auto* expect = R"(
+ auto* expect = R"(
@stage(compute) @workgroup_size(1)
fn main(@group(0) @binding(0) @internal(disable_validation__entry_point_parameter) tint_symbol : texture_2d<f32>, @group(0) @binding(1) @internal(disable_validation__entry_point_parameter) tint_symbol_1 : sampler) {
foo(1.0, tint_symbol, tint_symbol_1);
@@ -960,13 +957,13 @@
}
)";
- auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, Matrix) {
- auto* src = R"(
+ auto* src = R"(
var<workgroup> m : mat2x2<f32>;
@stage(compute) @workgroup_size(1)
@@ -975,7 +972,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol_2 {
m : mat2x2<f32>,
}
@@ -987,13 +984,13 @@
}
)";
- auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, NestedMatrix) {
- auto* src = R"(
+ auto* src = R"(
struct S1 {
m : mat2x2<f32>,
};
@@ -1008,7 +1005,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct S1 {
m : mat2x2<f32>,
}
@@ -1028,15 +1025,15 @@
}
)";
- auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
// Test that we do not duplicate a struct type used by multiple workgroup
// variables that are promoted to threadgroup memory arguments.
TEST_F(ModuleScopeVarToEntryPointParamTest, DuplicateThreadgroupArgumentTypes) {
- auto* src = R"(
+ auto* src = R"(
struct S {
m : mat2x2<f32>,
};
@@ -1052,7 +1049,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct S {
m : mat2x2<f32>,
}
@@ -1071,16 +1068,15 @@
}
)";
- auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
// Test that we do not duplicate a struct type used by multiple workgroup
// variables that are promoted to threadgroup memory arguments.
-TEST_F(ModuleScopeVarToEntryPointParamTest,
- DuplicateThreadgroupArgumentTypes_OutOfOrder) {
- auto* src = R"(
+TEST_F(ModuleScopeVarToEntryPointParamTest, DuplicateThreadgroupArgumentTypes_OutOfOrder) {
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn main() {
let x = a;
@@ -1095,7 +1091,7 @@
};
)";
- auto* expect = R"(
+ auto* expect = R"(
struct S {
m : mat2x2<f32>,
}
@@ -1114,13 +1110,13 @@
}
)";
- auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, UnusedVariables) {
- auto* src = R"(
+ auto* src = R"(
struct S {
a : f32,
};
@@ -1141,7 +1137,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct S {
a : f32,
}
@@ -1151,17 +1147,17 @@
}
)";
- auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, EmtpyModule) {
- auto* src = "";
+ auto* src = "";
- auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
- EXPECT_EQ(src, str(got));
+ EXPECT_EQ(src, str(got));
}
} // namespace
diff --git a/src/tint/transform/multiplanar_external_texture.cc b/src/tint/transform/multiplanar_external_texture.cc
index 34df013..93af4ca 100644
--- a/src/tint/transform/multiplanar_external_texture.cc
+++ b/src/tint/transform/multiplanar_external_texture.cc
@@ -24,8 +24,7 @@
#include "src/tint/sem/variable.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::MultiplanarExternalTexture);
-TINT_INSTANTIATE_TYPEINFO(
- tint::transform::MultiplanarExternalTexture::NewBindingPoints);
+TINT_INSTANTIATE_TYPEINFO(tint::transform::MultiplanarExternalTexture::NewBindingPoints);
namespace tint::transform {
namespace {
@@ -33,462 +32,429 @@
/// This struct stores symbols for new bindings created as a result of
/// transforming a texture_external instance.
struct NewBindingSymbols {
- Symbol params;
- Symbol plane_0;
- Symbol plane_1;
+ Symbol params;
+ Symbol plane_0;
+ Symbol plane_1;
};
} // namespace
/// State holds the current transform state
struct MultiplanarExternalTexture::State {
- /// The clone context.
- CloneContext& ctx;
+ /// The clone context.
+ CloneContext& ctx;
- /// ProgramBuilder for the context
- ProgramBuilder& b;
+ /// ProgramBuilder for the context
+ ProgramBuilder& b;
- /// Destination binding locations for the expanded texture_external provided
- /// as input into the transform.
- const NewBindingPoints* new_binding_points;
+ /// Destination binding locations for the expanded texture_external provided
+ /// as input into the transform.
+ const NewBindingPoints* new_binding_points;
- /// Symbol for the GammaTransferParams
- Symbol gamma_transfer_struct_sym;
+ /// Symbol for the GammaTransferParams
+ Symbol gamma_transfer_struct_sym;
- /// Symbol for the ExternalTextureParams struct
- Symbol params_struct_sym;
+ /// Symbol for the ExternalTextureParams struct
+ Symbol params_struct_sym;
- /// Symbol for the textureLoadExternal function
- Symbol texture_load_external_sym;
+ /// Symbol for the textureLoadExternal function
+ Symbol texture_load_external_sym;
- /// Symbol for the textureSampleExternal function
- Symbol texture_sample_external_sym;
+ /// Symbol for the textureSampleExternal function
+ Symbol texture_sample_external_sym;
- /// Symbol for the gammaCorrection function
- Symbol gamma_correction_sym;
+ /// Symbol for the gammaCorrection function
+ Symbol gamma_correction_sym;
- /// Storage for new bindings that have been created corresponding to an
- /// original texture_external binding.
- std::unordered_map<const sem::Variable*, NewBindingSymbols>
- new_binding_symbols;
+ /// Storage for new bindings that have been created corresponding to an
+ /// original texture_external binding.
+ std::unordered_map<const sem::Variable*, NewBindingSymbols> new_binding_symbols;
- /// Constructor
- /// @param context the clone
- /// @param newBindingPoints the input destination binding locations for the
- /// expanded texture_external
- State(CloneContext& context, const NewBindingPoints* newBindingPoints)
- : ctx(context), b(*context.dst), new_binding_points(newBindingPoints) {}
+ /// Constructor
+ /// @param context the clone context
+ /// @param newBindingPoints the input destination binding locations for the
+ /// expanded texture_external
+ State(CloneContext& context, const NewBindingPoints* newBindingPoints)
+ : ctx(context), b(*context.dst), new_binding_points(newBindingPoints) {}
- /// Processes the module
- void Process() {
- auto& sem = ctx.src->Sem();
+ /// Processes the module
+ void Process() {
+ auto& sem = ctx.src->Sem();
- // For each texture_external binding, we replace it with a texture_2d<f32>
- // binding and create two additional bindings (one texture_2d<f32> to
- // represent the secondary plane and one uniform buffer for the
- // ExternalTextureParams struct).
- for (auto* var : ctx.src->AST().GlobalVariables()) {
- auto* sem_var = sem.Get(var);
- if (!sem_var->Type()->UnwrapRef()->Is<sem::ExternalTexture>()) {
- continue;
- }
+ // For each texture_external binding, we replace it with a texture_2d<f32>
+ // binding and create two additional bindings (one texture_2d<f32> to
+ // represent the secondary plane and one uniform buffer for the
+ // ExternalTextureParams struct).
+ for (auto* var : ctx.src->AST().GlobalVariables()) {
+ auto* sem_var = sem.Get(var);
+ if (!sem_var->Type()->UnwrapRef()->Is<sem::ExternalTexture>()) {
+ continue;
+ }
- // If the attributes are empty, then this must be a texture_external
- // passed as a function parameter. These variables are transformed
- // elsewhere.
- if (var->attributes.empty()) {
- continue;
- }
+ // If the attributes are empty, then this must be a texture_external
+ // passed as a function parameter. These variables are transformed
+ // elsewhere.
+ if (var->attributes.empty()) {
+ continue;
+ }
- // If we find a texture_external binding, we know we must emit the
- // ExternalTextureParams struct.
- if (!params_struct_sym.IsValid()) {
- createExtTexParamsStructs();
- }
+ // If we find a texture_external binding, we know we must emit the
+ // ExternalTextureParams struct.
+ if (!params_struct_sym.IsValid()) {
+ createExtTexParamsStructs();
+ }
- // The binding points for the newly introduced bindings must have been
- // provided to this transform. We fetch the new binding points by
- // providing the original texture_external binding points into the
- // passed map.
- BindingPoint bp = {var->BindingPoint().group->value,
- var->BindingPoint().binding->value};
+ // The binding points for the newly introduced bindings must have been
+ // provided to this transform. We fetch the new binding points by
+ // providing the original texture_external binding points into the
+ // passed map.
+ BindingPoint bp = {var->BindingPoint().group->value,
+ var->BindingPoint().binding->value};
- BindingsMap::const_iterator it =
- new_binding_points->bindings_map.find(bp);
- if (it == new_binding_points->bindings_map.end()) {
- b.Diagnostics().add_error(
- diag::System::Transform,
- "missing new binding points for texture_external at binding {" +
- std::to_string(bp.group) + "," + std::to_string(bp.binding) +
- "}");
- continue;
- }
+ BindingsMap::const_iterator it = new_binding_points->bindings_map.find(bp);
+ if (it == new_binding_points->bindings_map.end()) {
+ b.Diagnostics().add_error(
+ diag::System::Transform,
+ "missing new binding points for texture_external at binding {" +
+ std::to_string(bp.group) + "," + std::to_string(bp.binding) + "}");
+ continue;
+ }
- BindingPoints bps = it->second;
+ BindingPoints bps = it->second;
- // Symbols for the newly created bindings must be saved so they can be
- // passed as parameters later. These are placed in a map and keyed by
- // the source symbol associated with the texture_external binding that
- // corresponds with the new destination bindings.
- // NewBindingSymbols new_binding_syms;
- auto& syms = new_binding_symbols[sem_var];
- syms.plane_0 = ctx.Clone(var->symbol);
- syms.plane_1 = b.Symbols().New("ext_tex_plane_1");
- b.Global(syms.plane_1,
- b.ty.sampled_texture(ast::TextureDimension::k2d, b.ty.f32()),
- b.GroupAndBinding(bps.plane_1.group, bps.plane_1.binding));
- syms.params = b.Symbols().New("ext_tex_params");
- b.Global(syms.params, b.ty.type_name("ExternalTextureParams"),
- ast::StorageClass::kUniform,
- b.GroupAndBinding(bps.params.group, bps.params.binding));
+ // Symbols for the newly created bindings must be saved so they can be
+ // passed as parameters later. These are placed in a map and keyed by
+ // the source symbol associated with the texture_external binding that
+ // corresponds with the new destination bindings.
+ // NewBindingSymbols new_binding_syms;
+ auto& syms = new_binding_symbols[sem_var];
+ syms.plane_0 = ctx.Clone(var->symbol);
+ syms.plane_1 = b.Symbols().New("ext_tex_plane_1");
+ b.Global(syms.plane_1, b.ty.sampled_texture(ast::TextureDimension::k2d, b.ty.f32()),
+ b.GroupAndBinding(bps.plane_1.group, bps.plane_1.binding));
+ syms.params = b.Symbols().New("ext_tex_params");
+ b.Global(syms.params, b.ty.type_name("ExternalTextureParams"),
+ ast::StorageClass::kUniform,
+ b.GroupAndBinding(bps.params.group, bps.params.binding));
- // Replace the original texture_external binding with a texture_2d<f32>
- // binding.
- ast::AttributeList cloned_attributes = ctx.Clone(var->attributes);
- const ast::Expression* cloned_constructor = ctx.Clone(var->constructor);
+ // Replace the original texture_external binding with a texture_2d<f32>
+ // binding.
+ ast::AttributeList cloned_attributes = ctx.Clone(var->attributes);
+ const ast::Expression* cloned_constructor = ctx.Clone(var->constructor);
- auto* replacement =
- b.Var(syms.plane_0,
- b.ty.sampled_texture(ast::TextureDimension::k2d, b.ty.f32()),
- cloned_constructor, cloned_attributes);
- ctx.Replace(var, replacement);
- }
-
- // We must update all the texture_external parameters for user declared
- // functions.
- for (auto* fn : ctx.src->AST().Functions()) {
- for (const ast::Variable* param : fn->params) {
- if (auto* sem_var = sem.Get(param)) {
- if (!sem_var->Type()->UnwrapRef()->Is<sem::ExternalTexture>()) {
- continue;
- }
- // If we find a texture_external, we must ensure the
- // ExternalTextureParams struct exists.
- if (!params_struct_sym.IsValid()) {
- createExtTexParamsStructs();
- }
- // When a texture_external is found, we insert all components
- // the texture_external into the parameter list. We must also place
- // the new symbols into the transform state so they can be used when
- // transforming function calls.
- auto& syms = new_binding_symbols[sem_var];
- syms.plane_0 = ctx.Clone(param->symbol);
- syms.plane_1 = b.Symbols().New("ext_tex_plane_1");
- syms.params = b.Symbols().New("ext_tex_params");
- auto tex2d_f32 = [&] {
- return b.ty.sampled_texture(ast::TextureDimension::k2d, b.ty.f32());
- };
- ctx.Replace(param, b.Param(syms.plane_0, tex2d_f32()));
- ctx.InsertAfter(fn->params, param,
- b.Param(syms.plane_1, tex2d_f32()));
- ctx.InsertAfter(
- fn->params, param,
- b.Param(syms.params, b.ty.type_name(params_struct_sym)));
+ auto* replacement =
+ b.Var(syms.plane_0, b.ty.sampled_texture(ast::TextureDimension::k2d, b.ty.f32()),
+ cloned_constructor, cloned_attributes);
+ ctx.Replace(var, replacement);
}
- }
- }
- // Transform the original textureLoad and textureSampleLevel calls into
- // textureLoadExternal and textureSampleExternal calls.
- ctx.ReplaceAll(
- [&](const ast::CallExpression* expr) -> const ast::CallExpression* {
- auto* builtin = sem.Get(expr)->Target()->As<sem::Builtin>();
-
- if (builtin && !builtin->Parameters().empty() &&
- builtin->Parameters()[0]->Type()->Is<sem::ExternalTexture>() &&
- builtin->Type() != sem::BuiltinType::kTextureDimensions) {
- if (auto* var_user = sem.Get<sem::VariableUser>(expr->args[0])) {
- auto it = new_binding_symbols.find(var_user->Variable());
- if (it == new_binding_symbols.end()) {
- // If valid new binding locations were not provided earlier, we
- // would have been unable to create these symbols. An error
- // message was emitted earlier, so just return early to avoid
- // internal compiler errors and retain a clean error message.
- return nullptr;
- }
- auto& syms = it->second;
-
- if (builtin->Type() == sem::BuiltinType::kTextureLoad) {
- return createTexLdExt(expr, syms);
- }
-
- if (builtin->Type() == sem::BuiltinType::kTextureSampleLevel) {
- return createTexSmpExt(expr, syms);
- }
- }
-
- } else if (sem.Get(expr)->Target()->Is<sem::Function>()) {
- // The call expression may be to a user-defined function that
- // contains a texture_external parameter. These need to be expanded
- // out to multiple plane textures and the texture parameters
- // structure.
- for (auto* arg : expr->args) {
- if (auto* var_user = sem.Get<sem::VariableUser>(arg)) {
- // Check if a parameter is a texture_external by trying to find
- // it in the transform state.
- auto it = new_binding_symbols.find(var_user->Variable());
- if (it != new_binding_symbols.end()) {
- auto& syms = it->second;
- // When we find a texture_external, we must unpack it into its
- // components.
- ctx.Replace(arg, b.Expr(syms.plane_0));
- ctx.InsertAfter(expr->args, arg, b.Expr(syms.plane_1));
- ctx.InsertAfter(expr->args, arg, b.Expr(syms.params));
+ // We must update all the texture_external parameters for user declared
+ // functions.
+ for (auto* fn : ctx.src->AST().Functions()) {
+ for (const ast::Variable* param : fn->params) {
+ if (auto* sem_var = sem.Get(param)) {
+ if (!sem_var->Type()->UnwrapRef()->Is<sem::ExternalTexture>()) {
+ continue;
+ }
+ // If we find a texture_external, we must ensure the
+ // ExternalTextureParams struct exists.
+ if (!params_struct_sym.IsValid()) {
+ createExtTexParamsStructs();
+ }
+ // When a texture_external is found, we insert all components of
+ // the texture_external into the parameter list. We must also place
+ // the new symbols into the transform state so they can be used when
+ // transforming function calls.
+ auto& syms = new_binding_symbols[sem_var];
+ syms.plane_0 = ctx.Clone(param->symbol);
+ syms.plane_1 = b.Symbols().New("ext_tex_plane_1");
+ syms.params = b.Symbols().New("ext_tex_params");
+ auto tex2d_f32 = [&] {
+ return b.ty.sampled_texture(ast::TextureDimension::k2d, b.ty.f32());
+ };
+ ctx.Replace(param, b.Param(syms.plane_0, tex2d_f32()));
+ ctx.InsertAfter(fn->params, param, b.Param(syms.plane_1, tex2d_f32()));
+ ctx.InsertAfter(fn->params, param,
+ b.Param(syms.params, b.ty.type_name(params_struct_sym)));
}
- }
}
- }
+ }
- return nullptr;
+ // Transform the original textureLoad and textureSampleLevel calls into
+ // textureLoadExternal and textureSampleExternal calls.
+ ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::CallExpression* {
+ auto* builtin = sem.Get(expr)->Target()->As<sem::Builtin>();
+
+ if (builtin && !builtin->Parameters().empty() &&
+ builtin->Parameters()[0]->Type()->Is<sem::ExternalTexture>() &&
+ builtin->Type() != sem::BuiltinType::kTextureDimensions) {
+ if (auto* var_user = sem.Get<sem::VariableUser>(expr->args[0])) {
+ auto it = new_binding_symbols.find(var_user->Variable());
+ if (it == new_binding_symbols.end()) {
+ // If valid new binding locations were not provided earlier, we
+ // would have been unable to create these symbols. An error
+ // message was emitted earlier, so just return early to avoid
+ // internal compiler errors and retain a clean error message.
+ return nullptr;
+ }
+ auto& syms = it->second;
+
+ if (builtin->Type() == sem::BuiltinType::kTextureLoad) {
+ return createTexLdExt(expr, syms);
+ }
+
+ if (builtin->Type() == sem::BuiltinType::kTextureSampleLevel) {
+ return createTexSmpExt(expr, syms);
+ }
+ }
+
+ } else if (sem.Get(expr)->Target()->Is<sem::Function>()) {
+ // The call expression may be to a user-defined function that
+ // contains a texture_external parameter. These need to be expanded
+ // out to multiple plane textures and the texture parameters
+ // structure.
+ for (auto* arg : expr->args) {
+ if (auto* var_user = sem.Get<sem::VariableUser>(arg)) {
+ // Check if a parameter is a texture_external by trying to find
+ // it in the transform state.
+ auto it = new_binding_symbols.find(var_user->Variable());
+ if (it != new_binding_symbols.end()) {
+ auto& syms = it->second;
+ // When we find a texture_external, we must unpack it into its
+ // components.
+ ctx.Replace(arg, b.Expr(syms.plane_0));
+ ctx.InsertAfter(expr->args, arg, b.Expr(syms.plane_1));
+ ctx.InsertAfter(expr->args, arg, b.Expr(syms.params));
+ }
+ }
+ }
+ }
+
+ return nullptr;
});
- }
-
- /// Creates the parameter structs associated with the transform.
- void createExtTexParamsStructs() {
- // Create GammaTransferParams struct.
- ast::StructMemberList gamma_transfer_member_list = {
- b.Member("G", b.ty.f32()), b.Member("A", b.ty.f32()),
- b.Member("B", b.ty.f32()), b.Member("C", b.ty.f32()),
- b.Member("D", b.ty.f32()), b.Member("E", b.ty.f32()),
- b.Member("F", b.ty.f32()), b.Member("padding", b.ty.u32())};
-
- gamma_transfer_struct_sym = b.Symbols().New("GammaTransferParams");
-
- b.Structure(gamma_transfer_struct_sym, gamma_transfer_member_list);
-
- // Create ExternalTextureParams struct.
- ast::StructMemberList ext_tex_params_member_list = {
- b.Member("numPlanes", b.ty.u32()),
- b.Member("yuvToRgbConversionMatrix", b.ty.mat3x4(b.ty.f32())),
- b.Member("gammaDecodeParams", b.ty.type_name("GammaTransferParams")),
- b.Member("gammaEncodeParams", b.ty.type_name("GammaTransferParams")),
- b.Member("gamutConversionMatrix", b.ty.mat3x3(b.ty.f32()))};
-
- params_struct_sym = b.Symbols().New("ExternalTextureParams");
-
- b.Structure(params_struct_sym, ext_tex_params_member_list);
- }
-
- /// Creates the gammaCorrection function if needed and returns a call
- /// expression to it.
- void createGammaCorrectionFn() {
- using f32 = ProgramBuilder::f32;
- ast::VariableList varList = {
- b.Param("v", b.ty.vec3<f32>()),
- b.Param("params", b.ty.type_name(gamma_transfer_struct_sym))};
-
- ast::StatementList statementList = {
- // let cond = abs(v) < vec3(params.D);
- b.Decl(b.Let("cond", nullptr,
- b.LessThan(b.Call("abs", "v"),
- b.vec3<f32>(b.MemberAccessor("params", "D"))))),
- // let t = sign(v) * ((params.C * abs(v)) + params.F);
- b.Decl(b.Let("t", nullptr,
- b.Mul(b.Call("sign", "v"),
- b.Add(b.Mul(b.MemberAccessor("params", "C"),
- b.Call("abs", "v")),
- b.MemberAccessor("params", "F"))))),
- // let f = (sign(v) * pow(((params.A * abs(v)) + params.B),
- // vec3(params.G))) + params.E;
- b.Decl(b.Let(
- "f", nullptr,
- b.Mul(b.Call("sign", "v"),
- b.Add(b.Call("pow",
- b.Add(b.Mul(b.MemberAccessor("params", "A"),
- b.Call("abs", "v")),
- b.MemberAccessor("params", "B")),
- b.vec3<f32>(b.MemberAccessor("params", "G"))),
- b.MemberAccessor("params", "E"))))),
- // return select(f, t, cond);
- b.Return(b.Call("select", "f", "t", "cond"))};
-
- gamma_correction_sym = b.Symbols().New("gammaCorrection");
-
- b.Func(gamma_correction_sym, varList, b.ty.vec3<f32>(), statementList, {});
- }
-
- /// Constructs a StatementList containing all the statements making up the
- /// bodies of the textureSampleExternal and textureLoadExternal functions.
- /// @param call_type determines which function body to generate
- /// @returns a statement list that makes of the body of the chosen function
- ast::StatementList createTexFnExtStatementList(sem::BuiltinType call_type) {
- using f32 = ProgramBuilder::f32;
- const ast::CallExpression* single_plane_call = nullptr;
- const ast::CallExpression* plane_0_call = nullptr;
- const ast::CallExpression* plane_1_call = nullptr;
- if (call_type == sem::BuiltinType::kTextureSampleLevel) {
- // textureSampleLevel(plane0, smp, coord.xy, 0.0);
- single_plane_call =
- b.Call("textureSampleLevel", "plane0", "smp", "coord", 0.0f);
- // textureSampleLevel(plane0, smp, coord.xy, 0.0);
- plane_0_call =
- b.Call("textureSampleLevel", "plane0", "smp", "coord", 0.0f);
- // textureSampleLevel(plane1, smp, coord.xy, 0.0);
- plane_1_call =
- b.Call("textureSampleLevel", "plane1", "smp", "coord", 0.0f);
- } else if (call_type == sem::BuiltinType::kTextureLoad) {
- // textureLoad(plane0, coords.xy, 0);
- single_plane_call = b.Call("textureLoad", "plane0", "coord", 0);
- // textureLoad(plane0, coords.xy, 0);
- plane_0_call = b.Call("textureLoad", "plane0", "coord", 0);
- // textureLoad(plane1, coords.xy, 0);
- plane_1_call = b.Call("textureLoad", "plane1", "coord", 0);
- } else {
- TINT_ICE(Transform, b.Diagnostics())
- << "unhandled builtin: " << call_type;
}
- return {
- // var color: vec3<f32>;
- b.Decl(b.Var("color", b.ty.vec3(b.ty.f32()))),
- // if ((params.numPlanes == 1u))
- b.If(b.create<ast::BinaryExpression>(
- ast::BinaryOp::kEqual, b.MemberAccessor("params", "numPlanes"),
- b.Expr(1u)),
- b.Block(
- // color = textureLoad(plane0, coord, 0).rgb;
- b.Assign("color", b.MemberAccessor(single_plane_call, "rgb"))),
- b.Block(
- // color = vec4<f32>(plane_0_call.r, plane_1_call.rg, 1.0) *
- // params.yuvToRgbConversionMatrix;
- b.Assign("color",
- b.Mul(b.vec4<f32>(
- b.MemberAccessor(plane_0_call, "r"),
- b.MemberAccessor(plane_1_call, "rg"), 1.0f),
- b.MemberAccessor(
- "params", "yuvToRgbConversionMatrix"))))),
- // color = gammaConversion(color, gammaDecodeParams);
- b.Assign("color",
- b.Call("gammaCorrection", "color",
- b.MemberAccessor("params", "gammaDecodeParams"))),
- // color = (params.gamutConversionMatrix * color);
- b.Assign("color",
- b.Mul(b.MemberAccessor("params", "gamutConversionMatrix"),
- "color")),
- // color = gammaConversion(color, gammaEncodeParams);
- b.Assign("color",
- b.Call("gammaCorrection", "color",
- b.MemberAccessor("params", "gammaEncodeParams"))),
- // return vec4<f32>(color, 1.0f);
- b.Return(b.vec4<f32>("color", 1.0f))};
- }
+ /// Creates the parameter structs associated with the transform.
+ void createExtTexParamsStructs() {
+ // Create GammaTransferParams struct.
+ ast::StructMemberList gamma_transfer_member_list = {
+ b.Member("G", b.ty.f32()), b.Member("A", b.ty.f32()), b.Member("B", b.ty.f32()),
+ b.Member("C", b.ty.f32()), b.Member("D", b.ty.f32()), b.Member("E", b.ty.f32()),
+ b.Member("F", b.ty.f32()), b.Member("padding", b.ty.u32())};
- /// Creates the textureSampleExternal function if needed and returns a call
- /// expression to it.
- /// @param expr the call expression being transformed
- /// @param syms the expanded symbols to be used in the new call
- /// @returns a call expression to textureSampleExternal
- const ast::CallExpression* createTexSmpExt(const ast::CallExpression* expr,
- NewBindingSymbols syms) {
- ast::ExpressionList params;
- const ast::Expression* plane_0_binding_param = ctx.Clone(expr->args[0]);
+ gamma_transfer_struct_sym = b.Symbols().New("GammaTransferParams");
- if (expr->args.size() != 3) {
- TINT_ICE(Transform, b.Diagnostics())
- << "expected textureSampleLevel call with a "
- "texture_external to have 3 parameters, found "
- << expr->args.size() << " parameters";
+ b.Structure(gamma_transfer_struct_sym, gamma_transfer_member_list);
+
+ // Create ExternalTextureParams struct.
+ ast::StructMemberList ext_tex_params_member_list = {
+ b.Member("numPlanes", b.ty.u32()),
+ b.Member("yuvToRgbConversionMatrix", b.ty.mat3x4(b.ty.f32())),
+ b.Member("gammaDecodeParams", b.ty.type_name("GammaTransferParams")),
+ b.Member("gammaEncodeParams", b.ty.type_name("GammaTransferParams")),
+ b.Member("gamutConversionMatrix", b.ty.mat3x3(b.ty.f32()))};
+
+ params_struct_sym = b.Symbols().New("ExternalTextureParams");
+
+ b.Structure(params_struct_sym, ext_tex_params_member_list);
}
- // TextureSampleExternal calls the gammaCorrection function, so ensure it
- // exists.
- if (!gamma_correction_sym.IsValid()) {
- createGammaCorrectionFn();
+ /// Creates the gammaCorrection function and stores its symbol in
+ /// gamma_correction_sym. Nothing is returned.
+ void createGammaCorrectionFn() {
+ using f32 = ProgramBuilder::f32;
+ ast::VariableList varList = {b.Param("v", b.ty.vec3<f32>()),
+ b.Param("params", b.ty.type_name(gamma_transfer_struct_sym))};
+
+ ast::StatementList statementList = {
+ // let cond = abs(v) < vec3(params.D);
+ b.Decl(b.Let(
+ "cond", nullptr,
+ b.LessThan(b.Call("abs", "v"), b.vec3<f32>(b.MemberAccessor("params", "D"))))),
+ // let t = sign(v) * ((params.C * abs(v)) + params.F);
+ b.Decl(b.Let("t", nullptr,
+ b.Mul(b.Call("sign", "v"),
+ b.Add(b.Mul(b.MemberAccessor("params", "C"), b.Call("abs", "v")),
+ b.MemberAccessor("params", "F"))))),
+ // let f = (sign(v) * pow(((params.A * abs(v)) + params.B),
+ // vec3(params.G))) + params.E;
+ b.Decl(b.Let(
+ "f", nullptr,
+ b.Mul(b.Call("sign", "v"),
+ b.Add(b.Call("pow",
+ b.Add(b.Mul(b.MemberAccessor("params", "A"), b.Call("abs", "v")),
+ b.MemberAccessor("params", "B")),
+ b.vec3<f32>(b.MemberAccessor("params", "G"))),
+ b.MemberAccessor("params", "E"))))),
+ // return select(f, t, cond);
+ b.Return(b.Call("select", "f", "t", "cond"))};
+
+ gamma_correction_sym = b.Symbols().New("gammaCorrection");
+
+ b.Func(gamma_correction_sym, varList, b.ty.vec3<f32>(), statementList, {});
}
- if (!texture_sample_external_sym.IsValid()) {
- texture_sample_external_sym = b.Symbols().New("textureSampleExternal");
+ /// Constructs a StatementList containing all the statements making up the
+ /// bodies of the textureSampleExternal and textureLoadExternal functions.
+ /// @param call_type determines which function body to generate
+ /// @returns a statement list that makes up the body of the chosen function
+ ast::StatementList createTexFnExtStatementList(sem::BuiltinType call_type) {
+ using f32 = ProgramBuilder::f32;
+ const ast::CallExpression* single_plane_call = nullptr;
+ const ast::CallExpression* plane_0_call = nullptr;
+ const ast::CallExpression* plane_1_call = nullptr;
+ if (call_type == sem::BuiltinType::kTextureSampleLevel) {
+ // textureSampleLevel(plane0, smp, coord, 0.0);
+ single_plane_call = b.Call("textureSampleLevel", "plane0", "smp", "coord", 0.0f);
+ // textureSampleLevel(plane0, smp, coord, 0.0);
+ plane_0_call = b.Call("textureSampleLevel", "plane0", "smp", "coord", 0.0f);
+ // textureSampleLevel(plane1, smp, coord, 0.0);
+ plane_1_call = b.Call("textureSampleLevel", "plane1", "smp", "coord", 0.0f);
+ } else if (call_type == sem::BuiltinType::kTextureLoad) {
+ // textureLoad(plane0, coord, 0);
+ single_plane_call = b.Call("textureLoad", "plane0", "coord", 0);
+ // textureLoad(plane0, coord, 0);
+ plane_0_call = b.Call("textureLoad", "plane0", "coord", 0);
+ // textureLoad(plane1, coord, 0);
+ plane_1_call = b.Call("textureLoad", "plane1", "coord", 0);
+ } else {
+ TINT_ICE(Transform, b.Diagnostics()) << "unhandled builtin: " << call_type;
+ }
- // Emit the textureSampleExternal function.
- ast::VariableList varList = {
- b.Param("plane0",
- b.ty.sampled_texture(ast::TextureDimension::k2d, b.ty.f32())),
- b.Param("plane1",
- b.ty.sampled_texture(ast::TextureDimension::k2d, b.ty.f32())),
- b.Param("smp", b.ty.sampler(ast::SamplerKind::kSampler)),
- b.Param("coord", b.ty.vec2(b.ty.f32())),
- b.Param("params", b.ty.type_name(params_struct_sym))};
-
- ast::StatementList statementList =
- createTexFnExtStatementList(sem::BuiltinType::kTextureSampleLevel);
-
- b.Func(texture_sample_external_sym, varList, b.ty.vec4(b.ty.f32()),
- statementList, {});
+ return {
+ // var color: vec3<f32>;
+ b.Decl(b.Var("color", b.ty.vec3(b.ty.f32()))),
+ // if ((params.numPlanes == 1u))
+ b.If(b.create<ast::BinaryExpression>(
+ ast::BinaryOp::kEqual, b.MemberAccessor("params", "numPlanes"), b.Expr(1u)),
+ b.Block(
+ // color = single_plane_call.rgb;
+ b.Assign("color", b.MemberAccessor(single_plane_call, "rgb"))),
+ b.Block(
+ // color = vec4<f32>(plane_0_call.r, plane_1_call.rg, 1.0) *
+ // params.yuvToRgbConversionMatrix;
+ b.Assign("color",
+ b.Mul(b.vec4<f32>(b.MemberAccessor(plane_0_call, "r"),
+ b.MemberAccessor(plane_1_call, "rg"), 1.0f),
+ b.MemberAccessor("params", "yuvToRgbConversionMatrix"))))),
+ // color = gammaCorrection(color, params.gammaDecodeParams);
+ b.Assign("color", b.Call("gammaCorrection", "color",
+ b.MemberAccessor("params", "gammaDecodeParams"))),
+ // color = (params.gamutConversionMatrix * color);
+ b.Assign("color", b.Mul(b.MemberAccessor("params", "gamutConversionMatrix"), "color")),
+ // color = gammaCorrection(color, params.gammaEncodeParams);
+ b.Assign("color", b.Call("gammaCorrection", "color",
+ b.MemberAccessor("params", "gammaEncodeParams"))),
+ // return vec4<f32>(color, 1.0f);
+ b.Return(b.vec4<f32>("color", 1.0f))};
}
- const ast::IdentifierExpression* exp = b.Expr(texture_sample_external_sym);
- params = {plane_0_binding_param, b.Expr(syms.plane_1),
- ctx.Clone(expr->args[1]), ctx.Clone(expr->args[2]),
- b.Expr(syms.params)};
- return b.Call(exp, params);
- }
+ /// Creates the textureSampleExternal function if needed and returns a call
+ /// expression to it.
+ /// @param expr the call expression being transformed
+ /// @param syms the expanded symbols to be used in the new call
+ /// @returns a call expression to textureSampleExternal
+ const ast::CallExpression* createTexSmpExt(const ast::CallExpression* expr,
+ NewBindingSymbols syms) {
+ ast::ExpressionList params;
+ const ast::Expression* plane_0_binding_param = ctx.Clone(expr->args[0]);
- /// Creates the textureLoadExternal function if needed and returns a call
- /// expression to it.
- /// @param expr the call expression being transformed
- /// @param syms the expanded symbols to be used in the new call
- /// @returns a call expression to textureLoadExternal
- const ast::CallExpression* createTexLdExt(const ast::CallExpression* expr,
- NewBindingSymbols syms) {
- ast::ExpressionList params;
- const ast::Expression* plane_0_binding_param = ctx.Clone(expr->args[0]);
+ if (expr->args.size() != 3) {
+ TINT_ICE(Transform, b.Diagnostics()) << "expected textureSampleLevel call with a "
+ "texture_external to have 3 parameters, found "
+ << expr->args.size() << " parameters";
+ }
- if (expr->args.size() != 2) {
- TINT_ICE(Transform, b.Diagnostics())
- << "expected textureLoad call with a texture_external "
- "to have 2 parameters, found "
- << expr->args.size() << " parameters";
+ // TextureSampleExternal calls the gammaCorrection function, so ensure it
+ // exists.
+ if (!gamma_correction_sym.IsValid()) {
+ createGammaCorrectionFn();
+ }
+
+ if (!texture_sample_external_sym.IsValid()) {
+ texture_sample_external_sym = b.Symbols().New("textureSampleExternal");
+
+ // Emit the textureSampleExternal function.
+ ast::VariableList varList = {
+ b.Param("plane0", b.ty.sampled_texture(ast::TextureDimension::k2d, b.ty.f32())),
+ b.Param("plane1", b.ty.sampled_texture(ast::TextureDimension::k2d, b.ty.f32())),
+ b.Param("smp", b.ty.sampler(ast::SamplerKind::kSampler)),
+ b.Param("coord", b.ty.vec2(b.ty.f32())),
+ b.Param("params", b.ty.type_name(params_struct_sym))};
+
+ ast::StatementList statementList =
+ createTexFnExtStatementList(sem::BuiltinType::kTextureSampleLevel);
+
+ b.Func(texture_sample_external_sym, varList, b.ty.vec4(b.ty.f32()), statementList, {});
+ }
+
+ const ast::IdentifierExpression* exp = b.Expr(texture_sample_external_sym);
+ params = {plane_0_binding_param, b.Expr(syms.plane_1), ctx.Clone(expr->args[1]),
+ ctx.Clone(expr->args[2]), b.Expr(syms.params)};
+ return b.Call(exp, params);
}
- // TextureLoadExternal calls the gammaCorrection function, so ensure it
- // exists.
- if (!gamma_correction_sym.IsValid()) {
- createGammaCorrectionFn();
+ /// Creates the textureLoadExternal function if needed and returns a call
+ /// expression to it.
+ /// @param expr the call expression being transformed
+ /// @param syms the expanded symbols to be used in the new call
+ /// @returns a call expression to textureLoadExternal
+ const ast::CallExpression* createTexLdExt(const ast::CallExpression* expr,
+ NewBindingSymbols syms) {
+ ast::ExpressionList params;
+ const ast::Expression* plane_0_binding_param = ctx.Clone(expr->args[0]);
+
+ if (expr->args.size() != 2) {
+ TINT_ICE(Transform, b.Diagnostics())
+ << "expected textureLoad call with a texture_external "
+ "to have 2 parameters, found "
+ << expr->args.size() << " parameters";
+ }
+
+ // TextureLoadExternal calls the gammaCorrection function, so ensure it
+ // exists.
+ if (!gamma_correction_sym.IsValid()) {
+ createGammaCorrectionFn();
+ }
+
+ if (!texture_load_external_sym.IsValid()) {
+ texture_load_external_sym = b.Symbols().New("textureLoadExternal");
+
+ // Emit the textureLoadExternal function.
+ ast::VariableList var_list = {
+ b.Param("plane0", b.ty.sampled_texture(ast::TextureDimension::k2d, b.ty.f32())),
+ b.Param("plane1", b.ty.sampled_texture(ast::TextureDimension::k2d, b.ty.f32())),
+ b.Param("coord", b.ty.vec2(b.ty.i32())),
+ b.Param("params", b.ty.type_name(params_struct_sym))};
+
+ ast::StatementList statement_list =
+ createTexFnExtStatementList(sem::BuiltinType::kTextureLoad);
+
+ b.Func(texture_load_external_sym, var_list, b.ty.vec4(b.ty.f32()), statement_list, {});
+ }
+
+ const ast::IdentifierExpression* exp = b.Expr(texture_load_external_sym);
+ params = {plane_0_binding_param, b.Expr(syms.plane_1), ctx.Clone(expr->args[1]),
+ b.Expr(syms.params)};
+ return b.Call(exp, params);
}
-
- if (!texture_load_external_sym.IsValid()) {
- texture_load_external_sym = b.Symbols().New("textureLoadExternal");
-
- // Emit the textureLoadExternal function.
- ast::VariableList var_list = {
- b.Param("plane0",
- b.ty.sampled_texture(ast::TextureDimension::k2d, b.ty.f32())),
- b.Param("plane1",
- b.ty.sampled_texture(ast::TextureDimension::k2d, b.ty.f32())),
- b.Param("coord", b.ty.vec2(b.ty.i32())),
- b.Param("params", b.ty.type_name(params_struct_sym))};
-
- ast::StatementList statement_list =
- createTexFnExtStatementList(sem::BuiltinType::kTextureLoad);
-
- b.Func(texture_load_external_sym, var_list, b.ty.vec4(b.ty.f32()),
- statement_list, {});
- }
-
- const ast::IdentifierExpression* exp = b.Expr(texture_load_external_sym);
- params = {plane_0_binding_param, b.Expr(syms.plane_1),
- ctx.Clone(expr->args[1]), b.Expr(syms.params)};
- return b.Call(exp, params);
- }
};
-MultiplanarExternalTexture::NewBindingPoints::NewBindingPoints(
- BindingsMap inputBindingsMap)
+MultiplanarExternalTexture::NewBindingPoints::NewBindingPoints(BindingsMap inputBindingsMap)
: bindings_map(std::move(inputBindingsMap)) {}
MultiplanarExternalTexture::NewBindingPoints::~NewBindingPoints() = default;
MultiplanarExternalTexture::MultiplanarExternalTexture() = default;
MultiplanarExternalTexture::~MultiplanarExternalTexture() = default;
-bool MultiplanarExternalTexture::ShouldRun(const Program* program,
- const DataMap&) const {
- for (auto* node : program->ASTNodes().Objects()) {
- if (auto* ty = node->As<ast::Type>()) {
- if (program->Sem().Get<sem::ExternalTexture>(ty)) {
- return true;
- }
+bool MultiplanarExternalTexture::ShouldRun(const Program* program, const DataMap&) const {
+ for (auto* node : program->ASTNodes().Objects()) {
+ if (auto* ty = node->As<ast::Type>()) {
+ if (program->Sem().Get<sem::ExternalTexture>(ty)) {
+ return true;
+ }
+ }
}
- }
- return false;
+ return false;
}
// Within this transform, an instance of a texture_external binding is unpacked
@@ -498,23 +464,21 @@
// texture_external parameter will be transformed into a newly generated version
// of the function, which can perform the desired operation on a single RGBA
// plane or on separate Y and UV planes.
-void MultiplanarExternalTexture::Run(CloneContext& ctx,
- const DataMap& inputs,
- DataMap&) const {
- auto* new_binding_points = inputs.Get<NewBindingPoints>();
+void MultiplanarExternalTexture::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const {
+ auto* new_binding_points = inputs.Get<NewBindingPoints>();
- if (!new_binding_points) {
- ctx.dst->Diagnostics().add_error(
- diag::System::Transform,
- "missing new binding point data for " + std::string(TypeInfo().name));
- return;
- }
+ if (!new_binding_points) {
+ ctx.dst->Diagnostics().add_error(
+ diag::System::Transform,
+ "missing new binding point data for " + std::string(TypeInfo().name));
+ return;
+ }
- State state(ctx, new_binding_points);
+ State state(ctx, new_binding_points);
- state.Process();
+ state.Process();
- ctx.Clone();
+ ctx.Clone();
}
} // namespace tint::transform
diff --git a/src/tint/transform/multiplanar_external_texture.h b/src/tint/transform/multiplanar_external_texture.h
index ab2298c..88cbc981 100644
--- a/src/tint/transform/multiplanar_external_texture.h
+++ b/src/tint/transform/multiplanar_external_texture.h
@@ -31,12 +31,12 @@
/// This struct identifies the binding groups and locations for new bindings to
/// use when transforming a texture_external instance.
struct BindingPoints {
- /// The desired binding location of the texture_2d representing plane #1 when
- /// a texture_external binding is expanded.
- BindingPoint plane_1;
- /// The desired binding location of the ExternalTextureParams uniform when a
- /// texture_external binding is expanded.
- BindingPoint params;
+ /// The desired binding location of the texture_2d representing plane #1 when
+ /// a texture_external binding is expanded.
+ BindingPoint plane_1;
+ /// The desired binding location of the ExternalTextureParams uniform when a
+ /// texture_external binding is expanded.
+ BindingPoint params;
};
/// Within the MultiplanarExternalTexture transform, each instance of a
@@ -47,52 +47,48 @@
/// transformed into a newly generated version of the function, which can
/// perform the desired operation on a single RGBA plane or on seperate Y and UV
/// planes.
-class MultiplanarExternalTexture
- : public Castable<MultiplanarExternalTexture, Transform> {
- public:
- /// BindingsMap is a map where the key is the binding location of a
- /// texture_external and the value is a struct containing the desired
- /// locations for new bindings expanded from the texture_external instance.
- using BindingsMap = std::unordered_map<BindingPoint, BindingPoints>;
+class MultiplanarExternalTexture : public Castable<MultiplanarExternalTexture, Transform> {
+ public:
+ /// BindingsMap is a map where the key is the binding location of a
+ /// texture_external and the value is a struct containing the desired
+ /// locations for new bindings expanded from the texture_external instance.
+ using BindingsMap = std::unordered_map<BindingPoint, BindingPoints>;
- /// NewBindingPoints is consumed by the MultiplanarExternalTexture transform.
- /// Data holds information about location of each texture_external binding and
- /// which binding slots it should expand into.
- struct NewBindingPoints : public Castable<Data, transform::Data> {
+ /// NewBindingPoints is consumed by the MultiplanarExternalTexture transform.
+ /// Data holds information about location of each texture_external binding and
+ /// which binding slots it should expand into.
+ struct NewBindingPoints : public Castable<Data, transform::Data> {
+ /// Constructor
+ /// @param bm a map to the new binding slots to use.
+ explicit NewBindingPoints(BindingsMap bm);
+
+ /// Destructor
+ ~NewBindingPoints() override;
+
+ /// A map of new binding points to use.
+ const BindingsMap bindings_map;
+ };
+
/// Constructor
- /// @param bm a map to the new binding slots to use.
- explicit NewBindingPoints(BindingsMap bm);
-
+ MultiplanarExternalTexture();
/// Destructor
- ~NewBindingPoints() override;
+ ~MultiplanarExternalTexture() override;
- /// A map of new binding points to use.
- const BindingsMap bindings_map;
- };
+ /// @param program the program to inspect
+ /// @param data optional extra transform-specific input data
+ /// @returns true if this transform should be run for the given program
+ bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
- /// Constructor
- MultiplanarExternalTexture();
- /// Destructor
- ~MultiplanarExternalTexture() override;
+ protected:
+ struct State;
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program,
- const DataMap& data = {}) const override;
-
- protected:
- struct State;
-
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx,
- const DataMap& inputs,
- DataMap& outputs) const override;
+ /// Runs the transform using the CloneContext built for transforming a
+ /// program. Run() is responsible for calling Clone() on the CloneContext.
+ /// @param ctx the CloneContext primed with the input program and
+ /// ProgramBuilder
+ /// @param inputs optional extra transform-specific input data
+ /// @param outputs optional extra transform-specific output data
+ void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/multiplanar_external_texture_test.cc b/src/tint/transform/multiplanar_external_texture_test.cc
index 77448a7..ac05545 100644
--- a/src/tint/transform/multiplanar_external_texture_test.cc
+++ b/src/tint/transform/multiplanar_external_texture_test.cc
@@ -21,38 +21,38 @@
using MultiplanarExternalTextureTest = TransformTest;
TEST_F(MultiplanarExternalTextureTest, ShouldRunEmptyModule) {
- auto* src = R"()";
+ auto* src = R"()";
- EXPECT_FALSE(ShouldRun<MultiplanarExternalTexture>(src));
+ EXPECT_FALSE(ShouldRun<MultiplanarExternalTexture>(src));
}
TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureAlias) {
- auto* src = R"(
+ auto* src = R"(
type ET = texture_external;
)";
- EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src));
+ EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src));
}
TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureGlobal) {
- auto* src = R"(
+ auto* src = R"(
@group(0) @binding(0) var ext_tex : texture_external;
)";
- EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src));
+ EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src));
}
TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureParam) {
- auto* src = R"(
+ auto* src = R"(
fn f(ext_tex : texture_external) {}
)";
- EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src));
+ EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src));
}
// Running the transform without passing in data for the new bindings should
// result in an error.
TEST_F(MultiplanarExternalTextureTest, ErrorNoPassedData) {
- auto* src = R"(
+ auto* src = R"(
@group(0) @binding(0) var s : sampler;
@group(0) @binding(1) var ext_tex : texture_external;
@@ -61,16 +61,16 @@
return textureSampleLevel(ext_tex, s, coord.xy);
}
)";
- auto* expect =
- R"(error: missing new binding point data for tint::transform::MultiplanarExternalTexture)";
+ auto* expect =
+ R"(error: missing new binding point data for tint::transform::MultiplanarExternalTexture)";
- auto got = Run<MultiplanarExternalTexture>(src);
- EXPECT_EQ(expect, str(got));
+ auto got = Run<MultiplanarExternalTexture>(src);
+ EXPECT_EQ(expect, str(got));
}
// Running the transform with incorrect binding data should result in an error.
TEST_F(MultiplanarExternalTextureTest, ErrorIncorrectBindingPont) {
- auto* src = R"(
+ auto* src = R"(
@group(0) @binding(0) var s : sampler;
@group(0) @binding(1) var ext_tex : texture_external;
@@ -80,21 +80,20 @@
}
)";
- auto* expect =
- R"(error: missing new binding points for texture_external at binding {0,1})";
+ auto* expect = R"(error: missing new binding points for texture_external at binding {0,1})";
- DataMap data;
- // This bindings map specifies 0,0 as the location of the texture_external,
- // which is incorrect.
- data.Add<MultiplanarExternalTexture::NewBindingPoints>(
- MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}});
- auto got = Run<MultiplanarExternalTexture>(src, data);
- EXPECT_EQ(expect, str(got));
+ DataMap data;
+ // This bindings map specifies 0,0 as the location of the texture_external,
+ // which is incorrect.
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(
+ MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}});
+ auto got = Run<MultiplanarExternalTexture>(src, data);
+ EXPECT_EQ(expect, str(got));
}
// Tests that the transform works with a textureDimensions call.
TEST_F(MultiplanarExternalTextureTest, Dimensions) {
- auto* src = R"(
+ auto* src = R"(
@group(0) @binding(0) var ext_tex : texture_external;
@stage(fragment)
@@ -105,7 +104,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct GammaTransferParams {
G : f32,
A : f32,
@@ -139,16 +138,16 @@
}
)";
- DataMap data;
- data.Add<MultiplanarExternalTexture::NewBindingPoints>(
- MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}});
- auto got = Run<MultiplanarExternalTexture>(src, data);
- EXPECT_EQ(expect, str(got));
+ DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(
+ MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}});
+ auto got = Run<MultiplanarExternalTexture>(src, data);
+ EXPECT_EQ(expect, str(got));
}
// Tests that the transform works with a textureDimensions call.
TEST_F(MultiplanarExternalTextureTest, Dimensions_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(fragment)
fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
var dim : vec2<i32>;
@@ -159,7 +158,7 @@
@group(0) @binding(0) var ext_tex : texture_external;
)";
- auto* expect = R"(
+ auto* expect = R"(
struct GammaTransferParams {
G : f32,
A : f32,
@@ -193,16 +192,16 @@
@group(0) @binding(0) var ext_tex : texture_2d<f32>;
)";
- DataMap data;
- data.Add<MultiplanarExternalTexture::NewBindingPoints>(
- MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}});
- auto got = Run<MultiplanarExternalTexture>(src, data);
- EXPECT_EQ(expect, str(got));
+ DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(
+ MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}});
+ auto got = Run<MultiplanarExternalTexture>(src, data);
+ EXPECT_EQ(expect, str(got));
}
// Test that the transform works with a textureSampleLevel call.
TEST_F(MultiplanarExternalTextureTest, BasicTextureSampleLevel) {
- auto* src = R"(
+ auto* src = R"(
@group(0) @binding(0) var s : sampler;
@group(0) @binding(1) var ext_tex : texture_external;
@@ -212,7 +211,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct GammaTransferParams {
G : f32,
A : f32,
@@ -266,16 +265,16 @@
}
)";
- DataMap data;
- data.Add<MultiplanarExternalTexture::NewBindingPoints>(
- MultiplanarExternalTexture::BindingsMap{{{0, 1}, {{0, 2}, {0, 3}}}});
- auto got = Run<MultiplanarExternalTexture>(src, data);
- EXPECT_EQ(expect, str(got));
+ DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(
+ MultiplanarExternalTexture::BindingsMap{{{0, 1}, {{0, 2}, {0, 3}}}});
+ auto got = Run<MultiplanarExternalTexture>(src, data);
+ EXPECT_EQ(expect, str(got));
}
// Test that the transform works with a textureSampleLevel call.
TEST_F(MultiplanarExternalTextureTest, BasicTextureSampleLevel_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(fragment)
fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
return textureSampleLevel(ext_tex, s, coord.xy);
@@ -285,7 +284,7 @@
@group(0) @binding(0) var s : sampler;
)";
- auto* expect = R"(
+ auto* expect = R"(
struct GammaTransferParams {
G : f32,
A : f32,
@@ -339,16 +338,16 @@
@group(0) @binding(0) var s : sampler;
)";
- DataMap data;
- data.Add<MultiplanarExternalTexture::NewBindingPoints>(
- MultiplanarExternalTexture::BindingsMap{{{0, 1}, {{0, 2}, {0, 3}}}});
- auto got = Run<MultiplanarExternalTexture>(src, data);
- EXPECT_EQ(expect, str(got));
+ DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(
+ MultiplanarExternalTexture::BindingsMap{{{0, 1}, {{0, 2}, {0, 3}}}});
+ auto got = Run<MultiplanarExternalTexture>(src, data);
+ EXPECT_EQ(expect, str(got));
}
// Tests that the transform works with a textureLoad call.
TEST_F(MultiplanarExternalTextureTest, BasicTextureLoad) {
- auto* src = R"(
+ auto* src = R"(
@group(0) @binding(0) var ext_tex : texture_external;
@stage(fragment)
@@ -357,7 +356,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct GammaTransferParams {
G : f32,
A : f32,
@@ -409,16 +408,16 @@
}
)";
- DataMap data;
- data.Add<MultiplanarExternalTexture::NewBindingPoints>(
- MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}});
- auto got = Run<MultiplanarExternalTexture>(src, data);
- EXPECT_EQ(expect, str(got));
+ DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(
+ MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}});
+ auto got = Run<MultiplanarExternalTexture>(src, data);
+ EXPECT_EQ(expect, str(got));
}
// Tests that the transform works with a textureLoad call.
TEST_F(MultiplanarExternalTextureTest, BasicTextureLoad_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(fragment)
fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
return textureLoad(ext_tex, vec2<i32>(1, 1));
@@ -427,7 +426,7 @@
@group(0) @binding(0) var ext_tex : texture_external;
)";
- auto* expect = R"(
+ auto* expect = R"(
struct GammaTransferParams {
G : f32,
A : f32,
@@ -479,17 +478,17 @@
@group(0) @binding(0) var ext_tex : texture_2d<f32>;
)";
- DataMap data;
- data.Add<MultiplanarExternalTexture::NewBindingPoints>(
- MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}});
- auto got = Run<MultiplanarExternalTexture>(src, data);
- EXPECT_EQ(expect, str(got));
+ DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(
+ MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}});
+ auto got = Run<MultiplanarExternalTexture>(src, data);
+ EXPECT_EQ(expect, str(got));
}
// Tests that the transform works with both a textureSampleLevel and textureLoad
// call.
TEST_F(MultiplanarExternalTextureTest, TextureSampleAndTextureLoad) {
- auto* src = R"(
+ auto* src = R"(
@group(0) @binding(0) var s : sampler;
@group(0) @binding(1) var ext_tex : texture_external;
@@ -499,7 +498,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct GammaTransferParams {
G : f32,
A : f32,
@@ -566,17 +565,17 @@
}
)";
- DataMap data;
- data.Add<MultiplanarExternalTexture::NewBindingPoints>(
- MultiplanarExternalTexture::BindingsMap{{{0, 1}, {{0, 2}, {0, 3}}}});
- auto got = Run<MultiplanarExternalTexture>(src, data);
- EXPECT_EQ(expect, str(got));
+ DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(
+ MultiplanarExternalTexture::BindingsMap{{{0, 1}, {{0, 2}, {0, 3}}}});
+ auto got = Run<MultiplanarExternalTexture>(src, data);
+ EXPECT_EQ(expect, str(got));
}
// Tests that the transform works with both a textureSampleLevel and textureLoad
// call.
TEST_F(MultiplanarExternalTextureTest, TextureSampleAndTextureLoad_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(fragment)
fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
return textureSampleLevel(ext_tex, s, coord.xy) + textureLoad(ext_tex, vec2<i32>(1, 1));
@@ -586,7 +585,7 @@
@group(0) @binding(1) var ext_tex : texture_external;
)";
- auto* expect = R"(
+ auto* expect = R"(
struct GammaTransferParams {
G : f32,
A : f32,
@@ -653,16 +652,16 @@
@group(0) @binding(1) var ext_tex : texture_2d<f32>;
)";
- DataMap data;
- data.Add<MultiplanarExternalTexture::NewBindingPoints>(
- MultiplanarExternalTexture::BindingsMap{{{0, 1}, {{0, 2}, {0, 3}}}});
- auto got = Run<MultiplanarExternalTexture>(src, data);
- EXPECT_EQ(expect, str(got));
+ DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(
+ MultiplanarExternalTexture::BindingsMap{{{0, 1}, {{0, 2}, {0, 3}}}});
+ auto got = Run<MultiplanarExternalTexture>(src, data);
+ EXPECT_EQ(expect, str(got));
}
// Tests that the transform works with many instances of texture_external.
TEST_F(MultiplanarExternalTextureTest, ManyTextureSampleLevel) {
- auto* src = R"(
+ auto* src = R"(
@group(0) @binding(0) var s : sampler;
@group(0) @binding(1) var ext_tex : texture_external;
@group(0) @binding(2) var ext_tex_1 : texture_external;
@@ -675,7 +674,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct GammaTransferParams {
G : f32,
A : f32,
@@ -747,22 +746,21 @@
}
)";
- DataMap data;
- data.Add<MultiplanarExternalTexture::NewBindingPoints>(
- MultiplanarExternalTexture::BindingsMap{
- {{0, 1}, {{0, 4}, {0, 5}}},
- {{0, 2}, {{0, 6}, {0, 7}}},
- {{0, 3}, {{0, 8}, {0, 9}}},
- {{1, 0}, {{1, 1}, {1, 2}}},
- });
- auto got = Run<MultiplanarExternalTexture>(src, data);
- EXPECT_EQ(expect, str(got));
+ DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(MultiplanarExternalTexture::BindingsMap{
+ {{0, 1}, {{0, 4}, {0, 5}}},
+ {{0, 2}, {{0, 6}, {0, 7}}},
+ {{0, 3}, {{0, 8}, {0, 9}}},
+ {{1, 0}, {{1, 1}, {1, 2}}},
+ });
+ auto got = Run<MultiplanarExternalTexture>(src, data);
+ EXPECT_EQ(expect, str(got));
}
// Tests that the texture_external passed as a function parameter produces the
// correct output.
TEST_F(MultiplanarExternalTextureTest, ExternalTexturePassedAsParam) {
- auto* src = R"(
+ auto* src = R"(
fn f(t : texture_external, s : sampler) {
textureSampleLevel(t, s, vec2<f32>(1.0, 2.0));
}
@@ -776,7 +774,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct GammaTransferParams {
G : f32,
A : f32,
@@ -833,20 +831,18 @@
f(ext_tex, ext_tex_plane_1, ext_tex_params, smp);
}
)";
- DataMap data;
- data.Add<MultiplanarExternalTexture::NewBindingPoints>(
- MultiplanarExternalTexture::BindingsMap{
- {{0, 0}, {{0, 2}, {0, 3}}},
- });
- auto got = Run<MultiplanarExternalTexture>(src, data);
- EXPECT_EQ(expect, str(got));
+ DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(MultiplanarExternalTexture::BindingsMap{
+ {{0, 0}, {{0, 2}, {0, 3}}},
+ });
+ auto got = Run<MultiplanarExternalTexture>(src, data);
+ EXPECT_EQ(expect, str(got));
}
// Tests that the texture_external passed as a function parameter produces the
// correct output.
-TEST_F(MultiplanarExternalTextureTest,
- ExternalTexturePassedAsParam_OutOfOrder) {
- auto* src = R"(
+TEST_F(MultiplanarExternalTextureTest, ExternalTexturePassedAsParam_OutOfOrder) {
+ auto* src = R"(
@stage(fragment)
fn main() {
f(ext_tex, smp);
@@ -860,7 +856,7 @@
@group(0) @binding(1) var smp : sampler;
)";
- auto* expect = R"(
+ auto* expect = R"(
struct GammaTransferParams {
G : f32,
A : f32,
@@ -917,19 +913,18 @@
@group(0) @binding(1) var smp : sampler;
)";
- DataMap data;
- data.Add<MultiplanarExternalTexture::NewBindingPoints>(
- MultiplanarExternalTexture::BindingsMap{
- {{0, 0}, {{0, 2}, {0, 3}}},
- });
- auto got = Run<MultiplanarExternalTexture>(src, data);
- EXPECT_EQ(expect, str(got));
+ DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(MultiplanarExternalTexture::BindingsMap{
+ {{0, 0}, {{0, 2}, {0, 3}}},
+ });
+ auto got = Run<MultiplanarExternalTexture>(src, data);
+ EXPECT_EQ(expect, str(got));
}
// Tests that the texture_external passed as a parameter not in the first
// position produces the correct output.
TEST_F(MultiplanarExternalTextureTest, ExternalTexturePassedAsSecondParam) {
- auto* src = R"(
+ auto* src = R"(
fn f(s : sampler, t : texture_external) {
textureSampleLevel(t, s, vec2<f32>(1.0, 2.0));
}
@@ -943,7 +938,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct GammaTransferParams {
G : f32,
A : f32,
@@ -1000,19 +995,18 @@
f(smp, ext_tex, ext_tex_plane_1, ext_tex_params);
}
)";
- DataMap data;
- data.Add<MultiplanarExternalTexture::NewBindingPoints>(
- MultiplanarExternalTexture::BindingsMap{
- {{0, 0}, {{0, 2}, {0, 3}}},
- });
- auto got = Run<MultiplanarExternalTexture>(src, data);
- EXPECT_EQ(expect, str(got));
+ DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(MultiplanarExternalTexture::BindingsMap{
+ {{0, 0}, {{0, 2}, {0, 3}}},
+ });
+ auto got = Run<MultiplanarExternalTexture>(src, data);
+ EXPECT_EQ(expect, str(got));
}
// Tests that multiple texture_external params passed to a function produces the
// correct output.
TEST_F(MultiplanarExternalTextureTest, ExternalTexturePassedAsParamMultiple) {
- auto* src = R"(
+ auto* src = R"(
fn f(t : texture_external, s : sampler, t2 : texture_external) {
textureSampleLevel(t, s, vec2<f32>(1.0, 2.0));
textureSampleLevel(t2, s, vec2<f32>(1.0, 2.0));
@@ -1028,7 +1022,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct GammaTransferParams {
G : f32,
A : f32,
@@ -1092,21 +1086,19 @@
f(ext_tex, ext_tex_plane_1, ext_tex_params, smp, ext_tex2, ext_tex_plane_1_1, ext_tex_params_1);
}
)";
- DataMap data;
- data.Add<MultiplanarExternalTexture::NewBindingPoints>(
- MultiplanarExternalTexture::BindingsMap{
- {{0, 0}, {{0, 3}, {0, 4}}},
- {{0, 2}, {{0, 5}, {0, 6}}},
- });
- auto got = Run<MultiplanarExternalTexture>(src, data);
- EXPECT_EQ(expect, str(got));
+ DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(MultiplanarExternalTexture::BindingsMap{
+ {{0, 0}, {{0, 3}, {0, 4}}},
+ {{0, 2}, {{0, 5}, {0, 6}}},
+ });
+ auto got = Run<MultiplanarExternalTexture>(src, data);
+ EXPECT_EQ(expect, str(got));
}
// Tests that multiple texture_external params passed to a function produces the
// correct output.
-TEST_F(MultiplanarExternalTextureTest,
- ExternalTexturePassedAsParamMultiple_OutOfOrder) {
- auto* src = R"(
+TEST_F(MultiplanarExternalTextureTest, ExternalTexturePassedAsParamMultiple_OutOfOrder) {
+ auto* src = R"(
@stage(fragment)
fn main() {
f(ext_tex, smp, ext_tex2);
@@ -1123,7 +1115,7 @@
)";
- auto* expect = R"(
+ auto* expect = R"(
struct GammaTransferParams {
G : f32,
A : f32,
@@ -1187,20 +1179,19 @@
@group(0) @binding(2) var ext_tex2 : texture_2d<f32>;
)";
- DataMap data;
- data.Add<MultiplanarExternalTexture::NewBindingPoints>(
- MultiplanarExternalTexture::BindingsMap{
- {{0, 0}, {{0, 3}, {0, 4}}},
- {{0, 2}, {{0, 5}, {0, 6}}},
- });
- auto got = Run<MultiplanarExternalTexture>(src, data);
- EXPECT_EQ(expect, str(got));
+ DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(MultiplanarExternalTexture::BindingsMap{
+ {{0, 0}, {{0, 3}, {0, 4}}},
+ {{0, 2}, {{0, 5}, {0, 6}}},
+ });
+ auto got = Run<MultiplanarExternalTexture>(src, data);
+ EXPECT_EQ(expect, str(got));
}
// Tests that the texture_external passed to as a parameter to multiple
// functions produces the correct output.
TEST_F(MultiplanarExternalTextureTest, ExternalTexturePassedAsParamNested) {
- auto* src = R"(
+ auto* src = R"(
fn nested(t : texture_external, s : sampler) {
textureSampleLevel(t, s, vec2<f32>(1.0, 2.0));
}
@@ -1218,7 +1209,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct GammaTransferParams {
G : f32,
A : f32,
@@ -1279,20 +1270,18 @@
f(ext_tex, ext_tex_plane_1, ext_tex_params, smp);
}
)";
- DataMap data;
- data.Add<MultiplanarExternalTexture::NewBindingPoints>(
- MultiplanarExternalTexture::BindingsMap{
- {{0, 0}, {{0, 2}, {0, 3}}},
- });
- auto got = Run<MultiplanarExternalTexture>(src, data);
- EXPECT_EQ(expect, str(got));
+ DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(MultiplanarExternalTexture::BindingsMap{
+ {{0, 0}, {{0, 2}, {0, 3}}},
+ });
+ auto got = Run<MultiplanarExternalTexture>(src, data);
+ EXPECT_EQ(expect, str(got));
}
// Tests that the texture_external passed to as a parameter to multiple
// functions produces the correct output.
-TEST_F(MultiplanarExternalTextureTest,
- ExternalTexturePassedAsParamNested_OutOfOrder) {
- auto* src = R"(
+TEST_F(MultiplanarExternalTextureTest, ExternalTexturePassedAsParamNested_OutOfOrder) {
+ auto* src = R"(
fn nested(t : texture_external, s : sampler) {
textureSampleLevel(t, s, vec2<f32>(1.0, 2.0));
}
@@ -1310,7 +1299,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct GammaTransferParams {
G : f32,
A : f32,
@@ -1371,26 +1360,24 @@
f(ext_tex, ext_tex_plane_1, ext_tex_params, smp);
}
)";
- DataMap data;
- data.Add<MultiplanarExternalTexture::NewBindingPoints>(
- MultiplanarExternalTexture::BindingsMap{
- {{0, 0}, {{0, 2}, {0, 3}}},
- });
- auto got = Run<MultiplanarExternalTexture>(src, data);
- EXPECT_EQ(expect, str(got));
+ DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(MultiplanarExternalTexture::BindingsMap{
+ {{0, 0}, {{0, 2}, {0, 3}}},
+ });
+ auto got = Run<MultiplanarExternalTexture>(src, data);
+ EXPECT_EQ(expect, str(got));
}
// Tests that the transform works with a function using an external texture,
// even if there's no external texture declared at module scope.
-TEST_F(MultiplanarExternalTextureTest,
- ExternalTexturePassedAsParamWithoutGlobalDecl) {
- auto* src = R"(
+TEST_F(MultiplanarExternalTextureTest, ExternalTexturePassedAsParamWithoutGlobalDecl) {
+ auto* src = R"(
fn f(ext_tex : texture_external) -> vec2<i32> {
return textureDimensions(ext_tex);
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct GammaTransferParams {
G : f32,
A : f32,
@@ -1415,16 +1402,16 @@
}
)";
- DataMap data;
- data.Add<MultiplanarExternalTexture::NewBindingPoints>(
- MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}});
- auto got = Run<MultiplanarExternalTexture>(src, data);
- EXPECT_EQ(expect, str(got));
+ DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(
+ MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}});
+ auto got = Run<MultiplanarExternalTexture>(src, data);
+ EXPECT_EQ(expect, str(got));
}
// Tests that the the transform handles aliases to external textures
TEST_F(MultiplanarExternalTextureTest, ExternalTextureAlias) {
- auto* src = R"(
+ auto* src = R"(
type ET = texture_external;
fn f(t : ET, s : sampler) {
@@ -1440,7 +1427,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct GammaTransferParams {
G : f32,
A : f32,
@@ -1499,18 +1486,17 @@
f(ext_tex, ext_tex_plane_1, ext_tex_params, smp);
}
)";
- DataMap data;
- data.Add<MultiplanarExternalTexture::NewBindingPoints>(
- MultiplanarExternalTexture::BindingsMap{
- {{0, 0}, {{0, 2}, {0, 3}}},
- });
- auto got = Run<MultiplanarExternalTexture>(src, data);
- EXPECT_EQ(expect, str(got));
+ DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(MultiplanarExternalTexture::BindingsMap{
+ {{0, 0}, {{0, 2}, {0, 3}}},
+ });
+ auto got = Run<MultiplanarExternalTexture>(src, data);
+ EXPECT_EQ(expect, str(got));
}
// Tests that the the transform handles aliases to external textures
TEST_F(MultiplanarExternalTextureTest, ExternalTextureAlias_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(fragment)
fn main() {
f(ext_tex, smp);
@@ -1526,7 +1512,7 @@
type ET = texture_external;
)";
- auto* expect = R"(
+ auto* expect = R"(
struct GammaTransferParams {
G : f32,
A : f32,
@@ -1585,13 +1571,12 @@
type ET = texture_external;
)";
- DataMap data;
- data.Add<MultiplanarExternalTexture::NewBindingPoints>(
- MultiplanarExternalTexture::BindingsMap{
- {{0, 0}, {{0, 2}, {0, 3}}},
- });
- auto got = Run<MultiplanarExternalTexture>(src, data);
- EXPECT_EQ(expect, str(got));
+ DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(MultiplanarExternalTexture::BindingsMap{
+ {{0, 0}, {{0, 2}, {0, 3}}},
+ });
+ auto got = Run<MultiplanarExternalTexture>(src, data);
+ EXPECT_EQ(expect, str(got));
}
} // namespace
diff --git a/src/tint/transform/num_workgroups_from_uniform.cc b/src/tint/transform/num_workgroups_from_uniform.cc
index 72f7fb8..17814df 100644
--- a/src/tint/transform/num_workgroups_from_uniform.cc
+++ b/src/tint/transform/num_workgroups_from_uniform.cc
@@ -32,136 +32,126 @@
/// Accessor describes the identifiers used in a member accessor that is being
/// used to retrieve the num_workgroups builtin from a parameter.
struct Accessor {
- Symbol param;
- Symbol member;
+ Symbol param;
+ Symbol member;
- /// Equality operator
- bool operator==(const Accessor& other) const {
- return param == other.param && member == other.member;
- }
- /// Hash function
- struct Hasher {
- size_t operator()(const Accessor& a) const {
- return utils::Hash(a.param, a.member);
+ /// Equality operator
+ bool operator==(const Accessor& other) const {
+ return param == other.param && member == other.member;
}
- };
+ /// Hash function
+ struct Hasher {
+ size_t operator()(const Accessor& a) const { return utils::Hash(a.param, a.member); }
+ };
};
} // namespace
NumWorkgroupsFromUniform::NumWorkgroupsFromUniform() = default;
NumWorkgroupsFromUniform::~NumWorkgroupsFromUniform() = default;
-bool NumWorkgroupsFromUniform::ShouldRun(const Program* program,
- const DataMap&) const {
- for (auto* node : program->ASTNodes().Objects()) {
- if (auto* attr = node->As<ast::BuiltinAttribute>()) {
- if (attr->builtin == ast::Builtin::kNumWorkgroups) {
- return true;
- }
+bool NumWorkgroupsFromUniform::ShouldRun(const Program* program, const DataMap&) const {
+ for (auto* node : program->ASTNodes().Objects()) {
+ if (auto* attr = node->As<ast::BuiltinAttribute>()) {
+ if (attr->builtin == ast::Builtin::kNumWorkgroups) {
+ return true;
+ }
+ }
}
- }
- return false;
+ return false;
}
-void NumWorkgroupsFromUniform::Run(CloneContext& ctx,
- const DataMap& inputs,
- DataMap&) const {
- auto* cfg = inputs.Get<Config>();
- if (cfg == nullptr) {
- ctx.dst->Diagnostics().add_error(
- diag::System::Transform,
- "missing transform data for " + std::string(TypeInfo().name));
- return;
- }
-
- const char* kNumWorkgroupsMemberName = "num_workgroups";
-
- // Find all entry point parameters that declare the num_workgroups builtin.
- std::unordered_set<Accessor, Accessor::Hasher> to_replace;
- for (auto* func : ctx.src->AST().Functions()) {
- // num_workgroups is only valid for compute stages.
- if (func->PipelineStage() != ast::PipelineStage::kCompute) {
- continue;
+void NumWorkgroupsFromUniform::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const {
+ auto* cfg = inputs.Get<Config>();
+ if (cfg == nullptr) {
+ ctx.dst->Diagnostics().add_error(
+ diag::System::Transform, "missing transform data for " + std::string(TypeInfo().name));
+ return;
}
- for (auto* param : ctx.src->Sem().Get(func)->Parameters()) {
- // Because the CanonicalizeEntryPointIO transform has been run, builtins
- // will only appear as struct members.
- auto* str = param->Type()->As<sem::Struct>();
- if (!str) {
- continue;
- }
+ const char* kNumWorkgroupsMemberName = "num_workgroups";
- for (auto* member : str->Members()) {
- auto* builtin = ast::GetAttribute<ast::BuiltinAttribute>(
- member->Declaration()->attributes);
- if (!builtin || builtin->builtin != ast::Builtin::kNumWorkgroups) {
- continue;
+ // Find all entry point parameters that declare the num_workgroups builtin.
+ std::unordered_set<Accessor, Accessor::Hasher> to_replace;
+ for (auto* func : ctx.src->AST().Functions()) {
+ // num_workgroups is only valid for compute stages.
+ if (func->PipelineStage() != ast::PipelineStage::kCompute) {
+ continue;
}
- // Capture the symbols that would be used to access this member, which
- // we will replace later. We currently have no way to get from the
- // parameter directly to the member accessor expressions that use it.
- to_replace.insert(
- {param->Declaration()->symbol, member->Declaration()->symbol});
+ for (auto* param : ctx.src->Sem().Get(func)->Parameters()) {
+ // Because the CanonicalizeEntryPointIO transform has been run, builtins
+ // will only appear as struct members.
+ auto* str = param->Type()->As<sem::Struct>();
+ if (!str) {
+ continue;
+ }
- // Remove the struct member.
- // The CanonicalizeEntryPointIO transform will have generated this
- // struct uniquely for this particular entry point, so we know that
- // there will be no other uses of this struct in the module and that we
- // can safely modify it here.
- ctx.Remove(str->Declaration()->members, member->Declaration());
+ for (auto* member : str->Members()) {
+ auto* builtin =
+ ast::GetAttribute<ast::BuiltinAttribute>(member->Declaration()->attributes);
+ if (!builtin || builtin->builtin != ast::Builtin::kNumWorkgroups) {
+ continue;
+ }
- // If this is the only member, remove the struct and parameter too.
- if (str->Members().size() == 1) {
- ctx.Remove(func->params, param->Declaration());
- ctx.Remove(ctx.src->AST().GlobalDeclarations(), str->Declaration());
+ // Capture the symbols that would be used to access this member, which
+ // we will replace later. We currently have no way to get from the
+ // parameter directly to the member accessor expressions that use it.
+ to_replace.insert({param->Declaration()->symbol, member->Declaration()->symbol});
+
+ // Remove the struct member.
+ // The CanonicalizeEntryPointIO transform will have generated this
+ // struct uniquely for this particular entry point, so we know that
+ // there will be no other uses of this struct in the module and that we
+ // can safely modify it here.
+ ctx.Remove(str->Declaration()->members, member->Declaration());
+
+ // If this is the only member, remove the struct and parameter too.
+ if (str->Members().size() == 1) {
+ ctx.Remove(func->params, param->Declaration());
+ ctx.Remove(ctx.src->AST().GlobalDeclarations(), str->Declaration());
+ }
+ }
}
- }
- }
- }
-
- // Get (or create, on first call) the uniform buffer that will receive the
- // number of workgroups.
- const ast::Variable* num_workgroups_ubo = nullptr;
- auto get_ubo = [&]() {
- if (!num_workgroups_ubo) {
- auto* num_workgroups_struct = ctx.dst->Structure(
- ctx.dst->Sym(),
- {ctx.dst->Member(kNumWorkgroupsMemberName,
- ctx.dst->ty.vec3(ctx.dst->ty.u32()))});
- num_workgroups_ubo = ctx.dst->Global(
- ctx.dst->Sym(), ctx.dst->ty.Of(num_workgroups_struct),
- ast::StorageClass::kUniform,
- ast::AttributeList{ctx.dst->GroupAndBinding(
- cfg->ubo_binding.group, cfg->ubo_binding.binding)});
- }
- return num_workgroups_ubo;
- };
-
- // Now replace all the places where the builtins are accessed with the value
- // loaded from the uniform buffer.
- for (auto* node : ctx.src->ASTNodes().Objects()) {
- auto* accessor = node->As<ast::MemberAccessorExpression>();
- if (!accessor) {
- continue;
- }
- auto* ident = accessor->structure->As<ast::IdentifierExpression>();
- if (!ident) {
- continue;
}
- if (to_replace.count({ident->symbol, accessor->member->symbol})) {
- ctx.Replace(accessor, ctx.dst->MemberAccessor(get_ubo()->symbol,
- kNumWorkgroupsMemberName));
- }
- }
+ // Get (or create, on first call) the uniform buffer that will receive the
+ // number of workgroups.
+ const ast::Variable* num_workgroups_ubo = nullptr;
+ auto get_ubo = [&]() {
+ if (!num_workgroups_ubo) {
+ auto* num_workgroups_struct = ctx.dst->Structure(
+ ctx.dst->Sym(),
+ {ctx.dst->Member(kNumWorkgroupsMemberName, ctx.dst->ty.vec3(ctx.dst->ty.u32()))});
+ num_workgroups_ubo = ctx.dst->Global(
+ ctx.dst->Sym(), ctx.dst->ty.Of(num_workgroups_struct), ast::StorageClass::kUniform,
+ ast::AttributeList{
+ ctx.dst->GroupAndBinding(cfg->ubo_binding.group, cfg->ubo_binding.binding)});
+ }
+ return num_workgroups_ubo;
+ };
- ctx.Clone();
+ // Now replace all the places where the builtins are accessed with the value
+ // loaded from the uniform buffer.
+ for (auto* node : ctx.src->ASTNodes().Objects()) {
+ auto* accessor = node->As<ast::MemberAccessorExpression>();
+ if (!accessor) {
+ continue;
+ }
+ auto* ident = accessor->structure->As<ast::IdentifierExpression>();
+ if (!ident) {
+ continue;
+ }
+
+ if (to_replace.count({ident->symbol, accessor->member->symbol})) {
+ ctx.Replace(accessor,
+ ctx.dst->MemberAccessor(get_ubo()->symbol, kNumWorkgroupsMemberName));
+ }
+ }
+
+ ctx.Clone();
}
-NumWorkgroupsFromUniform::Config::Config(sem::BindingPoint ubo_bp)
- : ubo_binding(ubo_bp) {}
+NumWorkgroupsFromUniform::Config::Config(sem::BindingPoint ubo_bp) : ubo_binding(ubo_bp) {}
NumWorkgroupsFromUniform::Config::Config(const Config&) = default;
NumWorkgroupsFromUniform::Config::~Config() = default;
diff --git a/src/tint/transform/num_workgroups_from_uniform.h b/src/tint/transform/num_workgroups_from_uniform.h
index e4cf20e..93c4f15 100644
--- a/src/tint/transform/num_workgroups_from_uniform.h
+++ b/src/tint/transform/num_workgroups_from_uniform.h
@@ -42,46 +42,42 @@
///
/// @note Depends on the following transforms to have been run first:
/// * CanonicalizeEntryPointIO
-class NumWorkgroupsFromUniform
- : public Castable<NumWorkgroupsFromUniform, Transform> {
- public:
- /// Constructor
- NumWorkgroupsFromUniform();
- /// Destructor
- ~NumWorkgroupsFromUniform() override;
-
- /// Configuration options for the NumWorkgroupsFromUniform transform.
- struct Config : public Castable<Data, transform::Data> {
+class NumWorkgroupsFromUniform : public Castable<NumWorkgroupsFromUniform, Transform> {
+ public:
/// Constructor
- /// @param ubo_bp the binding point to use for the generated uniform buffer.
- explicit Config(sem::BindingPoint ubo_bp);
-
- /// Copy constructor
- Config(const Config&);
-
+ NumWorkgroupsFromUniform();
/// Destructor
- ~Config() override;
+ ~NumWorkgroupsFromUniform() override;
- /// The binding point to use for the generated uniform buffer.
- sem::BindingPoint ubo_binding;
- };
+ /// Configuration options for the NumWorkgroupsFromUniform transform.
+ struct Config : public Castable<Data, transform::Data> {
+ /// Constructor
+ /// @param ubo_bp the binding point to use for the generated uniform buffer.
+ explicit Config(sem::BindingPoint ubo_bp);
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program,
- const DataMap& data = {}) const override;
+ /// Copy constructor
+ Config(const Config&);
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx,
- const DataMap& inputs,
- DataMap& outputs) const override;
+ /// Destructor
+ ~Config() override;
+
+ /// The binding point to use for the generated uniform buffer.
+ sem::BindingPoint ubo_binding;
+ };
+
+ /// @param program the program to inspect
+ /// @param data optional extra transform-specific input data
+ /// @returns true if this transform should be run for the given program
+ bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
+
+ protected:
+ /// Runs the transform using the CloneContext built for transforming a
+ /// program. Run() is responsible for calling Clone() on the CloneContext.
+ /// @param ctx the CloneContext primed with the input program and
+ /// ProgramBuilder
+ /// @param inputs optional extra transform-specific input data
+ /// @param outputs optional extra transform-specific output data
+ void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/num_workgroups_from_uniform_test.cc b/src/tint/transform/num_workgroups_from_uniform_test.cc
index 734d11b..de6c665 100644
--- a/src/tint/transform/num_workgroups_from_uniform_test.cc
+++ b/src/tint/transform/num_workgroups_from_uniform_test.cc
@@ -26,43 +26,41 @@
using NumWorkgroupsFromUniformTest = TransformTest;
TEST_F(NumWorkgroupsFromUniformTest, ShouldRunEmptyModule) {
- auto* src = R"()";
+ auto* src = R"()";
- EXPECT_FALSE(ShouldRun<NumWorkgroupsFromUniform>(src));
+ EXPECT_FALSE(ShouldRun<NumWorkgroupsFromUniform>(src));
}
TEST_F(NumWorkgroupsFromUniformTest, ShouldRunHasNumWorkgroups) {
- auto* src = R"(
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn main(@builtin(num_workgroups) num_wgs : vec3<u32>) {
}
)";
- EXPECT_TRUE(ShouldRun<NumWorkgroupsFromUniform>(src));
+ EXPECT_TRUE(ShouldRun<NumWorkgroupsFromUniform>(src));
}
TEST_F(NumWorkgroupsFromUniformTest, Error_MissingTransformData) {
- auto* src = R"(
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn main(@builtin(num_workgroups) num_wgs : vec3<u32>) {
}
)";
- auto* expect =
- "error: missing transform data for "
- "tint::transform::NumWorkgroupsFromUniform";
+ auto* expect =
+ "error: missing transform data for "
+ "tint::transform::NumWorkgroupsFromUniform";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
- auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(
- src, data);
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(NumWorkgroupsFromUniformTest, Basic) {
- auto* src = R"(
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn main(@builtin(num_workgroups) num_wgs : vec3<u32>) {
let groups_x = num_wgs.x;
@@ -71,7 +69,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol_2 {
num_workgroups : vec3<u32>,
}
@@ -90,17 +88,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
- data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
- auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(
- src, data);
- EXPECT_EQ(expect, str(got));
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(src, data);
+ EXPECT_EQ(expect, str(got));
}
TEST_F(NumWorkgroupsFromUniformTest, StructOnlyMember) {
- auto* src = R"(
+ auto* src = R"(
struct Builtins {
@builtin(num_workgroups) num_wgs : vec3<u32>,
};
@@ -113,7 +109,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol_2 {
num_workgroups : vec3<u32>,
}
@@ -136,17 +132,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
- data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
- auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(
- src, data);
- EXPECT_EQ(expect, str(got));
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(src, data);
+ EXPECT_EQ(expect, str(got));
}
TEST_F(NumWorkgroupsFromUniformTest, StructOnlyMember_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn main(in : Builtins) {
let groups_x = in.num_wgs.x;
@@ -159,7 +153,7 @@
};
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol_2 {
num_workgroups : vec3<u32>,
}
@@ -182,17 +176,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
- data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
- auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(
- src, data);
- EXPECT_EQ(expect, str(got));
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(src, data);
+ EXPECT_EQ(expect, str(got));
}
TEST_F(NumWorkgroupsFromUniformTest, StructMultipleMembers) {
- auto* src = R"(
+ auto* src = R"(
struct Builtins {
@builtin(global_invocation_id) gid : vec3<u32>,
@builtin(num_workgroups) num_wgs : vec3<u32>,
@@ -207,7 +199,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol_2 {
num_workgroups : vec3<u32>,
}
@@ -239,17 +231,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
- data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
- auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(
- src, data);
- EXPECT_EQ(expect, str(got));
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(src, data);
+ EXPECT_EQ(expect, str(got));
}
TEST_F(NumWorkgroupsFromUniformTest, StructMultipleMembers_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn main(in : Builtins) {
let groups_x = in.num_wgs.x;
@@ -265,7 +255,7 @@
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol_2 {
num_workgroups : vec3<u32>,
}
@@ -297,17 +287,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
- data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
- auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(
- src, data);
- EXPECT_EQ(expect, str(got));
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(src, data);
+ EXPECT_EQ(expect, str(got));
}
TEST_F(NumWorkgroupsFromUniformTest, MultipleEntryPoints) {
- auto* src = R"(
+ auto* src = R"(
struct Builtins1 {
@builtin(num_workgroups) num_wgs : vec3<u32>,
};
@@ -340,7 +328,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_symbol_6 {
num_workgroups : vec3<u32>,
}
@@ -398,17 +386,15 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
- data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
- auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(
- src, data);
- EXPECT_EQ(expect, str(got));
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(src, data);
+ EXPECT_EQ(expect, str(got));
}
TEST_F(NumWorkgroupsFromUniformTest, NoUsages) {
- auto* src = R"(
+ auto* src = R"(
struct Builtins {
@builtin(global_invocation_id) gid : vec3<u32>,
@builtin(workgroup_id) wgid : vec3<u32>,
@@ -419,7 +405,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct Builtins {
gid : vec3<u32>,
wgid : vec3<u32>,
@@ -441,13 +427,11 @@
}
)";
- DataMap data;
- data.Add<CanonicalizeEntryPointIO::Config>(
- CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
- data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
- auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(
- src, data);
- EXPECT_EQ(expect, str(got));
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(src, data);
+ EXPECT_EQ(expect, str(got));
}
} // namespace
diff --git a/src/tint/transform/promote_initializers_to_const_var.cc b/src/tint/transform/promote_initializers_to_const_var.cc
index a60dd6b..81b5603 100644
--- a/src/tint/transform/promote_initializers_to_const_var.cc
+++ b/src/tint/transform/promote_initializers_to_const_var.cc
@@ -27,57 +27,55 @@
PromoteInitializersToConstVar::~PromoteInitializersToConstVar() = default;
-void PromoteInitializersToConstVar::Run(CloneContext& ctx,
- const DataMap&,
- DataMap&) const {
- HoistToDeclBefore hoist_to_decl_before(ctx);
+void PromoteInitializersToConstVar::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
+ HoistToDeclBefore hoist_to_decl_before(ctx);
- // Hoists array and structure initializers to a constant variable, declared
- // just before the statement of usage.
- auto type_ctor_to_let = [&](const ast::CallExpression* expr) {
- auto* ctor = ctx.src->Sem().Get(expr);
- if (!ctor->Target()->Is<sem::TypeConstructor>()) {
- return true;
- }
- auto* sem_stmt = ctor->Stmt();
- if (!sem_stmt) {
- // Expression is outside of a statement. This usually means the
- // expression is part of a global (module-scope) constant declaration.
- // These must be constexpr, and so cannot contain the type of
- // expressions that must be sanitized.
- return true;
+ // Hoists array and structure initializers to a constant variable, declared
+ // just before the statement of usage.
+ auto type_ctor_to_let = [&](const ast::CallExpression* expr) {
+ auto* ctor = ctx.src->Sem().Get(expr);
+ if (!ctor->Target()->Is<sem::TypeConstructor>()) {
+ return true;
+ }
+ auto* sem_stmt = ctor->Stmt();
+ if (!sem_stmt) {
+ // Expression is outside of a statement. This usually means the
+ // expression is part of a global (module-scope) constant declaration.
+ // These must be constexpr, and so cannot contain the type of
+ // expressions that must be sanitized.
+ return true;
+ }
+
+ auto* stmt = sem_stmt->Declaration();
+
+ if (auto* src_var_decl = stmt->As<ast::VariableDeclStatement>()) {
+ if (src_var_decl->variable->constructor == expr) {
+ // This statement is just a variable declaration with the
+ // initializer as the constructor value. This is what we're
+ // attempting to transform to, and so ignore.
+ return true;
+ }
+ }
+
+ auto* src_ty = ctor->Type();
+ if (!src_ty->IsAnyOf<sem::Array, sem::Struct>()) {
+ // We only care about array and struct initializers
+ return true;
+ }
+
+ return hoist_to_decl_before.Add(ctor, expr, true);
+ };
+
+ for (auto* node : ctx.src->ASTNodes().Objects()) {
+ if (auto* call_expr = node->As<ast::CallExpression>()) {
+ if (!type_ctor_to_let(call_expr)) {
+ return;
+ }
+ }
}
- auto* stmt = sem_stmt->Declaration();
-
- if (auto* src_var_decl = stmt->As<ast::VariableDeclStatement>()) {
- if (src_var_decl->variable->constructor == expr) {
- // This statement is just a variable declaration with the
- // initializer as the constructor value. This is what we're
- // attempting to transform to, and so ignore.
- return true;
- }
- }
-
- auto* src_ty = ctor->Type();
- if (!src_ty->IsAnyOf<sem::Array, sem::Struct>()) {
- // We only care about array and struct initializers
- return true;
- }
-
- return hoist_to_decl_before.Add(ctor, expr, true);
- };
-
- for (auto* node : ctx.src->ASTNodes().Objects()) {
- if (auto* call_expr = node->As<ast::CallExpression>()) {
- if (!type_ctor_to_let(call_expr)) {
- return;
- }
- }
- }
-
- hoist_to_decl_before.Apply();
- ctx.Clone();
+ hoist_to_decl_before.Apply();
+ ctx.Clone();
}
} // namespace tint::transform
diff --git a/src/tint/transform/promote_initializers_to_const_var.h b/src/tint/transform/promote_initializers_to_const_var.h
index 586e27d..67a32c4 100644
--- a/src/tint/transform/promote_initializers_to_const_var.h
+++ b/src/tint/transform/promote_initializers_to_const_var.h
@@ -22,25 +22,22 @@
/// A transform that hoists the array and structure initializers to a constant
/// variable, declared just before the statement of usage.
/// @see crbug.com/tint/406
-class PromoteInitializersToConstVar
- : public Castable<PromoteInitializersToConstVar, Transform> {
- public:
- /// Constructor
- PromoteInitializersToConstVar();
+class PromoteInitializersToConstVar : public Castable<PromoteInitializersToConstVar, Transform> {
+ public:
+ /// Constructor
+ PromoteInitializersToConstVar();
- /// Destructor
- ~PromoteInitializersToConstVar() override;
+ /// Destructor
+ ~PromoteInitializersToConstVar() override;
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx,
- const DataMap& inputs,
- DataMap& outputs) const override;
+ protected:
+ /// Runs the transform using the CloneContext built for transforming a
+ /// program. Run() is responsible for calling Clone() on the CloneContext.
+ /// @param ctx the CloneContext primed with the input program and
+ /// ProgramBuilder
+ /// @param inputs optional extra transform-specific input data
+ /// @param outputs optional extra transform-specific output data
+ void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/promote_initializers_to_const_var_test.cc b/src/tint/transform/promote_initializers_to_const_var_test.cc
index 87b9edc..f322478 100644
--- a/src/tint/transform/promote_initializers_to_const_var_test.cc
+++ b/src/tint/transform/promote_initializers_to_const_var_test.cc
@@ -22,16 +22,16 @@
using PromoteInitializersToConstVarTest = TransformTest;
TEST_F(PromoteInitializersToConstVarTest, EmptyModule) {
- auto* src = "";
- auto* expect = "";
+ auto* src = "";
+ auto* expect = "";
- auto got = Run<PromoteInitializersToConstVar>(src);
+ auto got = Run<PromoteInitializersToConstVar>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteInitializersToConstVarTest, BasicArray) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var f0 = 1.0;
var f1 = 2.0;
@@ -41,7 +41,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var f0 = 1.0;
var f1 = 2.0;
@@ -52,14 +52,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteInitializersToConstVar>(src);
+ DataMap data;
+ auto got = Run<PromoteInitializersToConstVar>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteInitializersToConstVarTest, BasicStruct) {
- auto* src = R"(
+ auto* src = R"(
struct S {
a : i32,
b : f32,
@@ -71,7 +71,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct S {
a : i32,
b : f32,
@@ -84,14 +84,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteInitializersToConstVar>(src);
+ DataMap data;
+ auto got = Run<PromoteInitializersToConstVar>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteInitializersToConstVarTest, BasicStruct_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var x = S(1, 2.0, vec3<f32>()).b;
}
@@ -103,7 +103,7 @@
};
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
let tint_symbol = S(1, 2.0, vec3<f32>());
var x = tint_symbol.b;
@@ -116,14 +116,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteInitializersToConstVar>(src);
+ DataMap data;
+ auto got = Run<PromoteInitializersToConstVar>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteInitializersToConstVarTest, ArrayInForLoopInit) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var insert_after = 1;
for(var i = array<f32, 4u>(0.0, 1.0, 2.0, 3.0)[2]; ; ) {
@@ -132,7 +132,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var insert_after = 1;
let tint_symbol = array<f32, 4u>(0.0, 1.0, 2.0, 3.0);
@@ -142,14 +142,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteInitializersToConstVar>(src);
+ DataMap data;
+ auto got = Run<PromoteInitializersToConstVar>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteInitializersToConstVarTest, StructInForLoopInit) {
- auto* src = R"(
+ auto* src = R"(
struct S {
a : i32,
b : f32,
@@ -164,7 +164,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct S {
a : i32,
b : f32,
@@ -180,14 +180,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteInitializersToConstVar>(src);
+ DataMap data;
+ auto got = Run<PromoteInitializersToConstVar>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteInitializersToConstVarTest, StructInForLoopInit_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var insert_after = 1;
for(var x = S(1, 2.0, vec3<f32>()).b; ; ) {
@@ -202,7 +202,7 @@
};
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var insert_after = 1;
let tint_symbol = S(1, 2.0, vec3<f32>());
@@ -218,14 +218,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteInitializersToConstVar>(src);
+ DataMap data;
+ auto got = Run<PromoteInitializersToConstVar>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteInitializersToConstVarTest, ArrayInForLoopCond) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var f = 1.0;
for(; f == array<f32, 1u>(f)[0]; f = f + 1.0) {
@@ -234,7 +234,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var f = 1.0;
loop {
@@ -253,14 +253,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteInitializersToConstVar>(src);
+ DataMap data;
+ auto got = Run<PromoteInitializersToConstVar>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteInitializersToConstVarTest, ArrayInForLoopCont) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var f = 0.0;
for(; f < 10.0; f = f + array<f32, 1u>(1.0)[0]) {
@@ -269,7 +269,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var f = 0.0;
loop {
@@ -288,14 +288,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteInitializersToConstVar>(src);
+ DataMap data;
+ auto got = Run<PromoteInitializersToConstVar>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteInitializersToConstVarTest, ArrayInForLoopInitCondCont) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
for(var f = array<f32, 1u>(0.0)[0];
f < array<f32, 1u>(1.0)[0];
@@ -305,7 +305,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
let tint_symbol = array<f32, 1u>(0.0);
{
@@ -328,14 +328,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteInitializersToConstVar>(src);
+ DataMap data;
+ auto got = Run<PromoteInitializersToConstVar>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteInitializersToConstVarTest, ArrayInElseIf) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var f = 1.0;
if (true) {
@@ -346,7 +346,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var f = 1.0;
if (true) {
@@ -360,14 +360,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteInitializersToConstVar>(src);
+ DataMap data;
+ auto got = Run<PromoteInitializersToConstVar>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteInitializersToConstVarTest, ArrayInElseIfChain) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var f = 1.0;
if (true) {
@@ -386,7 +386,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var f = 1.0;
if (true) {
@@ -411,20 +411,20 @@
}
)";
- DataMap data;
- auto got = Run<PromoteInitializersToConstVar>(src);
+ DataMap data;
+ auto got = Run<PromoteInitializersToConstVar>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteInitializersToConstVarTest, ArrayInArrayArray) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var i = array<array<f32, 2u>, 2u>(array<f32, 2u>(1.0, 2.0), array<f32, 2u>(3.0, 4.0))[0][1];
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
let tint_symbol = array<f32, 2u>(1.0, 2.0);
let tint_symbol_1 = array<f32, 2u>(3.0, 4.0);
@@ -433,14 +433,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteInitializersToConstVar>(src);
+ DataMap data;
+ auto got = Run<PromoteInitializersToConstVar>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteInitializersToConstVarTest, StructNested) {
- auto* src = R"(
+ auto* src = R"(
struct S1 {
a : i32,
};
@@ -460,7 +460,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct S1 {
a : i32,
}
@@ -483,14 +483,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteInitializersToConstVar>(src);
+ DataMap data;
+ auto got = Run<PromoteInitializersToConstVar>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteInitializersToConstVarTest, Mixed) {
- auto* src = R"(
+ auto* src = R"(
struct S1 {
a : i32,
};
@@ -504,7 +504,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct S1 {
a : i32,
}
@@ -523,14 +523,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteInitializersToConstVar>(src);
+ DataMap data;
+ auto got = Run<PromoteInitializersToConstVar>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteInitializersToConstVarTest, Mixed_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var x = S2(array<S1, 3u>(S1(1), S1(2), S1(3))).a[1].a;
}
@@ -544,7 +544,7 @@
};
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
let tint_symbol = S1(1);
let tint_symbol_1 = S1(2);
@@ -563,14 +563,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteInitializersToConstVar>(src);
+ DataMap data;
+ auto got = Run<PromoteInitializersToConstVar>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteInitializersToConstVarTest, NoChangeOnVarDecl) {
- auto* src = R"(
+ auto* src = R"(
struct S {
a : i32,
b : f32,
@@ -587,16 +587,16 @@
let module_str : S = S(1, 2.0, 3);
)";
- auto* expect = src;
+ auto* expect = src;
- DataMap data;
- auto got = Run<PromoteInitializersToConstVar>(src);
+ DataMap data;
+ auto got = Run<PromoteInitializersToConstVar>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteInitializersToConstVarTest, NoChangeOnVarDecl_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var local_arr = array<f32, 4u>(0.0, 1.0, 2.0, 3.0);
var local_str = S(1, 2.0, 3);
@@ -613,12 +613,12 @@
let module_arr : array<f32, 4u> = array<f32, 4u>(0.0, 1.0, 2.0, 3.0);
)";
- auto* expect = src;
+ auto* expect = src;
- DataMap data;
- auto got = Run<PromoteInitializersToConstVar>(src);
+ DataMap data;
+ auto got = Run<PromoteInitializersToConstVar>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
} // namespace
diff --git a/src/tint/transform/promote_side_effects_to_decl.cc b/src/tint/transform/promote_side_effects_to_decl.cc
index 146d9a7..6f1cc4c 100644
--- a/src/tint/transform/promote_side_effects_to_decl.cc
+++ b/src/tint/transform/promote_side_effects_to_decl.cc
@@ -39,62 +39,58 @@
// Base state class for common members
class StateBase {
- protected:
- CloneContext& ctx;
- ProgramBuilder& b;
- const sem::Info& sem;
+ protected:
+ CloneContext& ctx;
+ ProgramBuilder& b;
+ const sem::Info& sem;
- explicit StateBase(CloneContext& ctx_in)
- : ctx(ctx_in), b(*ctx_in.dst), sem(ctx_in.src->Sem()) {}
+ explicit StateBase(CloneContext& ctx_in)
+ : ctx(ctx_in), b(*ctx_in.dst), sem(ctx_in.src->Sem()) {}
};
// This first transform converts side-effecting for-loops to loops and else-ifs
// to else {if}s so that the next transform, DecomposeSideEffects, can insert
// hoisted expressions above their current location.
-struct SimplifySideEffectStatements
- : Castable<PromoteSideEffectsToDecl, Transform> {
- class State;
- void Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const override;
+struct SimplifySideEffectStatements : Castable<PromoteSideEffectsToDecl, Transform> {
+ class State;
+ void Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const override;
};
class SimplifySideEffectStatements::State : public StateBase {
- HoistToDeclBefore hoist_to_decl_before;
+ HoistToDeclBefore hoist_to_decl_before;
- public:
- explicit State(CloneContext& ctx_in)
- : StateBase(ctx_in), hoist_to_decl_before(ctx_in) {}
+ public:
+ explicit State(CloneContext& ctx_in) : StateBase(ctx_in), hoist_to_decl_before(ctx_in) {}
- void Run() {
- for (auto* node : ctx.src->ASTNodes().Objects()) {
- if (auto* expr = node->As<ast::Expression>()) {
- auto* sem_expr = sem.Get(expr);
- if (!sem_expr || !sem_expr->HasSideEffects()) {
- continue;
+ void Run() {
+ for (auto* node : ctx.src->ASTNodes().Objects()) {
+ if (auto* expr = node->As<ast::Expression>()) {
+ auto* sem_expr = sem.Get(expr);
+ if (!sem_expr || !sem_expr->HasSideEffects()) {
+ continue;
+ }
+
+ hoist_to_decl_before.Prepare(sem_expr);
+ }
}
- hoist_to_decl_before.Prepare(sem_expr);
- }
+ hoist_to_decl_before.Apply();
+ ctx.Clone();
}
-
- hoist_to_decl_before.Apply();
- ctx.Clone();
- }
};
-void SimplifySideEffectStatements::Run(CloneContext& ctx,
- const DataMap&,
- DataMap&) const {
- State state(ctx);
- state.Run();
+void SimplifySideEffectStatements::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
+ State state(ctx);
+ state.Run();
}
// Decomposes side-effecting expressions to ensure order of evaluation. This
// handles both breaking down logical binary expressions for short-circuit
// evaluation, as well as hoisting expressions to ensure order of evaluation.
struct DecomposeSideEffects : Castable<PromoteSideEffectsToDecl, Transform> {
- class CollectHoistsState;
- class DecomposeState;
- void Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const override;
+ class CollectHoistsState;
+ class DecomposeState;
+ void Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const override;
};
// CollectHoistsState traverses the AST top-down, identifying which expressions
@@ -103,589 +99,567 @@
// expressions.
using ToHoistSet = std::unordered_set<const ast::Expression*>;
class DecomposeSideEffects::CollectHoistsState : public StateBase {
- // Expressions to hoist because they either cause or receive side-effects.
- ToHoistSet to_hoist;
+ // Expressions to hoist because they either cause or receive side-effects.
+ ToHoistSet to_hoist;
- // Used to mark expressions as not or no longer having side-effects.
- std::unordered_set<const ast::Expression*> no_side_effects;
+ // Used to mark expressions as not or no longer having side-effects.
+ std::unordered_set<const ast::Expression*> no_side_effects;
- // Returns true if `expr` has side-effects. Unlike invoking
- // sem::Expression::HasSideEffects(), this function takes into account whether
- // `expr` has been hoisted, returning false in that case. Furthermore, it
- // returns the correct result on parent expression nodes by traversing the
- // expression tree, memoizing the results to ensure O(1) amortized lookup.
- bool HasSideEffects(const ast::Expression* expr) {
- if (no_side_effects.count(expr)) {
- return false;
- }
-
- return Switch(
- expr,
- [&](const ast::CallExpression* e) -> bool {
- return sem.Get(e)->HasSideEffects();
- },
- [&](const ast::BinaryExpression* e) {
- if (HasSideEffects(e->lhs) || HasSideEffects(e->rhs)) {
- return true;
- }
- no_side_effects.insert(e);
- return false;
- },
- [&](const ast::IndexAccessorExpression* e) {
- if (HasSideEffects(e->object) || HasSideEffects(e->index)) {
- return true;
- }
- no_side_effects.insert(e);
- return false;
- },
- [&](const ast::MemberAccessorExpression* e) {
- if (HasSideEffects(e->structure) || HasSideEffects(e->member)) {
- return true;
- }
- no_side_effects.insert(e);
- return false;
- },
- [&](const ast::BitcastExpression* e) { //
- if (HasSideEffects(e->expr)) {
- return true;
- }
- no_side_effects.insert(e);
- return false;
- },
-
- [&](const ast::UnaryOpExpression* e) { //
- if (HasSideEffects(e->expr)) {
- return true;
- }
- no_side_effects.insert(e);
- return false;
- },
- [&](const ast::IdentifierExpression* e) {
- no_side_effects.insert(e);
- return false;
- },
- [&](const ast::LiteralExpression* e) {
- no_side_effects.insert(e);
- return false;
- },
- [&](const ast::PhonyExpression* e) {
- no_side_effects.insert(e);
- return false;
- },
- [&](Default) {
- TINT_ICE(Transform, b.Diagnostics()) << "Unhandled expression type";
- return false;
- });
- }
-
- // Adds `e` to `to_hoist` for hoisting to a let later on.
- void Hoist(const ast::Expression* e) {
- no_side_effects.insert(e);
- to_hoist.emplace(e);
- }
-
- // Hoists any expressions in `maybe_hoist` and clears it
- void Flush(ast::ExpressionList& maybe_hoist) {
- for (auto* m : maybe_hoist) {
- Hoist(m);
- }
- maybe_hoist.clear();
- }
-
- // Recursive function that processes expressions for side-effects. It
- // traverses the expression tree child before parent, left-to-right. Each call
- // returns whether the input expression should maybe be hoisted, allowing the
- // parent node to decide whether to hoist or not. Generally:
- // * When 'true' is returned, the expression is added to the maybe_hoist list.
- // * When a side-effecting expression is met, we flush the expressions in the
- // maybe_hoist list, as they are potentially receivers of the side-effects.
- // * For index and member accessor expressions, special care is taken to not
- // over-hoist the lhs expressions, as these may be be chained to refer to a
- // single memory location.
- bool ProcessExpression(const ast::Expression* expr,
- ast::ExpressionList& maybe_hoist) {
- auto process = [&](const ast::Expression* e) -> bool {
- return ProcessExpression(e, maybe_hoist);
- };
-
- auto default_process = [&](const ast::Expression* e) {
- auto maybe = process(e);
- if (maybe) {
- maybe_hoist.emplace_back(e);
- }
- if (HasSideEffects(e)) {
- Flush(maybe_hoist);
- }
- return false;
- };
-
- auto binary_process = [&](auto* lhs, auto* rhs) {
- // If neither side causes side-effects, but at least one receives them,
- // let parent node hoist. This avoids over-hoisting side-effect receivers
- // of compound binary expressions (e.g. for "((a && b) && c) && f()", we
- // don't want to hoist each of "a", "b", and "c" separately, but want to
- // hoist "((a && b) && c)".
- if (!HasSideEffects(lhs) && !HasSideEffects(rhs)) {
- auto lhs_maybe = process(lhs);
- auto rhs_maybe = process(rhs);
- if (lhs_maybe || rhs_maybe) {
- return true;
+ // Returns true if `expr` has side-effects. Unlike invoking
+ // sem::Expression::HasSideEffects(), this function takes into account whether
+ // `expr` has been hoisted, returning false in that case. Furthermore, it
+ // returns the correct result on parent expression nodes by traversing the
+ // expression tree, memoizing the results to ensure O(1) amortized lookup.
+ bool HasSideEffects(const ast::Expression* expr) {
+ if (no_side_effects.count(expr)) {
+ return false;
}
- return false;
- }
- default_process(lhs);
- default_process(rhs);
- return false;
- };
-
- auto accessor_process = [&](auto* lhs, auto* rhs) {
- auto maybe = process(lhs);
- // If lhs is a variable, let parent node hoist otherwise flush it right
- // away. This is to avoid over-hoisting the lhs of accessor chains (e.g.
- // for "v[a][b][c] + g()" we want to hoist all of "v[a][b][c]", not "t1 =
- // v[a]", then "t2 = t1[b]" then "t3 = t2[c]").
- if (maybe && HasSideEffects(lhs)) {
- maybe_hoist.emplace_back(lhs);
- Flush(maybe_hoist);
- maybe = false;
- }
- default_process(rhs);
- return maybe;
- };
-
- return Switch(
- expr,
- [&](const ast::CallExpression* e) -> bool {
- // We eagerly flush any variables in maybe_hoist for the current
- // call expression. Then we scope maybe_hoist to the processing of
- // the call args. This ensures that given: g(c, a(0), d) we hoist
- // 'c' because of 'a(0)', but not 'd' because there's no need, since
- // the call to g() will be hoisted if necessary.
- if (HasSideEffects(e)) {
- Flush(maybe_hoist);
- }
-
- TINT_SCOPED_ASSIGNMENT(maybe_hoist, {});
- for (auto* a : e->args) {
- default_process(a);
- }
-
- // Always hoist this call, even if it has no side-effects to ensure
- // left-to-right order of evaluation.
- // E.g. for "no_side_effects() + side_effects()", we want to hoist
- // no_side_effects() first.
- return true;
- },
- [&](const ast::IdentifierExpression* e) {
- if (auto* sem_e = sem.Get(e)) {
- if (auto* var_user = sem_e->As<sem::VariableUser>()) {
- // Don't hoist constants.
- if (var_user->ConstantValue().IsValid()) {
+ return Switch(
+ expr,
+ [&](const ast::CallExpression* e) -> bool { return sem.Get(e)->HasSideEffects(); },
+ [&](const ast::BinaryExpression* e) {
+ if (HasSideEffects(e->lhs) || HasSideEffects(e->rhs)) {
+ return true;
+ }
+ no_side_effects.insert(e);
return false;
- }
- // Don't hoist read-only variables as they cannot receive
- // side-effects.
- if (var_user->Variable()->Access() == ast::Access::kRead) {
+ },
+ [&](const ast::IndexAccessorExpression* e) {
+ if (HasSideEffects(e->object) || HasSideEffects(e->index)) {
+ return true;
+ }
+ no_side_effects.insert(e);
return false;
- }
- return true;
+ },
+ [&](const ast::MemberAccessorExpression* e) {
+ if (HasSideEffects(e->structure) || HasSideEffects(e->member)) {
+ return true;
+ }
+ no_side_effects.insert(e);
+ return false;
+ },
+ [&](const ast::BitcastExpression* e) { //
+ if (HasSideEffects(e->expr)) {
+ return true;
+ }
+ no_side_effects.insert(e);
+ return false;
+ },
+
+ [&](const ast::UnaryOpExpression* e) { //
+ if (HasSideEffects(e->expr)) {
+ return true;
+ }
+ no_side_effects.insert(e);
+ return false;
+ },
+ [&](const ast::IdentifierExpression* e) {
+ no_side_effects.insert(e);
+ return false;
+ },
+ [&](const ast::LiteralExpression* e) {
+ no_side_effects.insert(e);
+ return false;
+ },
+ [&](const ast::PhonyExpression* e) {
+ no_side_effects.insert(e);
+ return false;
+ },
+ [&](Default) {
+ TINT_ICE(Transform, b.Diagnostics()) << "Unhandled expression type";
+ return false;
+ });
+ }
+
+ // Adds `e` to `to_hoist` for hoisting to a let later on.
+ void Hoist(const ast::Expression* e) {
+ no_side_effects.insert(e);
+ to_hoist.emplace(e);
+ }
+
+ // Hoists any expressions in `maybe_hoist` and clears it
+ void Flush(ast::ExpressionList& maybe_hoist) {
+ for (auto* m : maybe_hoist) {
+ Hoist(m);
+ }
+ maybe_hoist.clear();
+ }
+
+ // Recursive function that processes expressions for side-effects. It
+ // traverses the expression tree child before parent, left-to-right. Each call
+ // returns whether the input expression should maybe be hoisted, allowing the
+ // parent node to decide whether to hoist or not. Generally:
+ // * When 'true' is returned, the expression is added to the maybe_hoist list.
+ // * When a side-effecting expression is met, we flush the expressions in the
+ // maybe_hoist list, as they are potentially receivers of the side-effects.
+ // * For index and member accessor expressions, special care is taken to not
+ // over-hoist the lhs expressions, as these may be be chained to refer to a
+ // single memory location.
+ bool ProcessExpression(const ast::Expression* expr, ast::ExpressionList& maybe_hoist) {
+ auto process = [&](const ast::Expression* e) -> bool {
+ return ProcessExpression(e, maybe_hoist);
+ };
+
+ auto default_process = [&](const ast::Expression* e) {
+ auto maybe = process(e);
+ if (maybe) {
+ maybe_hoist.emplace_back(e);
}
- }
- return false;
- },
- [&](const ast::BinaryExpression* e) {
- if (e->IsLogical() && HasSideEffects(e)) {
- // Don't hoist children of logical binary expressions with
- // side-effects. These will be handled by DecomposeState.
- process(e->lhs);
- process(e->rhs);
+ if (HasSideEffects(e)) {
+ Flush(maybe_hoist);
+ }
return false;
- }
- return binary_process(e->lhs, e->rhs);
- },
- [&](const ast::BitcastExpression* e) { //
- return process(e->expr);
- },
- [&](const ast::UnaryOpExpression* e) { //
- auto r = process(e->expr);
- // Don't hoist address-of expressions.
- // E.g. for "g(&b, a(0))", we hoist "a(0)" only.
- if (e->op == ast::UnaryOp::kAddressOf) {
+ };
+
+ auto binary_process = [&](auto* lhs, auto* rhs) {
+ // If neither side causes side-effects, but at least one receives them,
+ // let parent node hoist. This avoids over-hoisting side-effect receivers
+ // of compound binary expressions (e.g. for "((a && b) && c) && f()", we
+ // don't want to hoist each of "a", "b", and "c" separately, but want to
+ // hoist "((a && b) && c)".
+ if (!HasSideEffects(lhs) && !HasSideEffects(rhs)) {
+ auto lhs_maybe = process(lhs);
+ auto rhs_maybe = process(rhs);
+ if (lhs_maybe || rhs_maybe) {
+ return true;
+ }
+ return false;
+ }
+
+ default_process(lhs);
+ default_process(rhs);
return false;
- }
- return r;
- },
- [&](const ast::IndexAccessorExpression* e) {
- return accessor_process(e->object, e->index);
- },
- [&](const ast::MemberAccessorExpression* e) {
- return accessor_process(e->structure, e->member);
- },
- [&](const ast::LiteralExpression*) {
- // Leaf
- return false;
- },
- [&](const ast::PhonyExpression*) {
- // Leaf
- return false;
- },
- [&](Default) {
- TINT_ICE(Transform, b.Diagnostics()) << "Unhandled expression type";
- return false;
- });
- }
+ };
- // Starts the recursive processing of a statement's expression(s) to hoist
- // side-effects to lets.
- void ProcessStatement(const ast::Expression* expr) {
- if (!expr) {
- return;
+ auto accessor_process = [&](auto* lhs, auto* rhs) {
+ auto maybe = process(lhs);
+ // If lhs is a variable, let parent node hoist otherwise flush it right
+ // away. This is to avoid over-hoisting the lhs of accessor chains (e.g.
+ // for "v[a][b][c] + g()" we want to hoist all of "v[a][b][c]", not "t1 =
+ // v[a]", then "t2 = t1[b]" then "t3 = t2[c]").
+ if (maybe && HasSideEffects(lhs)) {
+ maybe_hoist.emplace_back(lhs);
+ Flush(maybe_hoist);
+ maybe = false;
+ }
+ default_process(rhs);
+ return maybe;
+ };
+
+ return Switch(
+ expr,
+ [&](const ast::CallExpression* e) -> bool {
+ // We eagerly flush any variables in maybe_hoist for the current
+ // call expression. Then we scope maybe_hoist to the processing of
+ // the call args. This ensures that given: g(c, a(0), d) we hoist
+ // 'c' because of 'a(0)', but not 'd' because there's no need, since
+ // the call to g() will be hoisted if necessary.
+ if (HasSideEffects(e)) {
+ Flush(maybe_hoist);
+ }
+
+ TINT_SCOPED_ASSIGNMENT(maybe_hoist, {});
+ for (auto* a : e->args) {
+ default_process(a);
+ }
+
+ // Always hoist this call, even if it has no side-effects to ensure
+ // left-to-right order of evaluation.
+ // E.g. for "no_side_effects() + side_effects()", we want to hoist
+ // no_side_effects() first.
+ return true;
+ },
+ [&](const ast::IdentifierExpression* e) {
+ if (auto* sem_e = sem.Get(e)) {
+ if (auto* var_user = sem_e->As<sem::VariableUser>()) {
+ // Don't hoist constants.
+ if (var_user->ConstantValue().IsValid()) {
+ return false;
+ }
+ // Don't hoist read-only variables as they cannot receive
+ // side-effects.
+ if (var_user->Variable()->Access() == ast::Access::kRead) {
+ return false;
+ }
+ return true;
+ }
+ }
+ return false;
+ },
+ [&](const ast::BinaryExpression* e) {
+ if (e->IsLogical() && HasSideEffects(e)) {
+ // Don't hoist children of logical binary expressions with
+ // side-effects. These will be handled by DecomposeState.
+ process(e->lhs);
+ process(e->rhs);
+ return false;
+ }
+ return binary_process(e->lhs, e->rhs);
+ },
+ [&](const ast::BitcastExpression* e) { //
+ return process(e->expr);
+ },
+ [&](const ast::UnaryOpExpression* e) { //
+ auto r = process(e->expr);
+ // Don't hoist address-of expressions.
+ // E.g. for "g(&b, a(0))", we hoist "a(0)" only.
+ if (e->op == ast::UnaryOp::kAddressOf) {
+ return false;
+ }
+ return r;
+ },
+ [&](const ast::IndexAccessorExpression* e) {
+ return accessor_process(e->object, e->index);
+ },
+ [&](const ast::MemberAccessorExpression* e) {
+ return accessor_process(e->structure, e->member);
+ },
+ [&](const ast::LiteralExpression*) {
+ // Leaf
+ return false;
+ },
+ [&](const ast::PhonyExpression*) {
+ // Leaf
+ return false;
+ },
+ [&](Default) {
+ TINT_ICE(Transform, b.Diagnostics()) << "Unhandled expression type";
+ return false;
+ });
}
- ast::ExpressionList maybe_hoist;
- ProcessExpression(expr, maybe_hoist);
- }
+ // Starts the recursive processing of a statement's expression(s) to hoist
+ // side-effects to lets.
+ void ProcessStatement(const ast::Expression* expr) {
+ if (!expr) {
+ return;
+ }
- // Special case for processing assignment statement expressions, as we must
- // evaluate the rhs before the lhs, and possibly hoist the rhs expression.
- void ProcessAssignment(const ast::Expression* lhs,
- const ast::Expression* rhs) {
- // Evaluate rhs before lhs
- ast::ExpressionList maybe_hoist;
- if (ProcessExpression(rhs, maybe_hoist)) {
- maybe_hoist.emplace_back(rhs);
+ ast::ExpressionList maybe_hoist;
+ ProcessExpression(expr, maybe_hoist);
}
- // If the rhs has side-effects, it may affect the lhs, so hoist it right
- // away. e.g. "b[c] = a(0);"
- if (HasSideEffects(rhs)) {
- // Technically, we can always hoist rhs, but don't bother doing so when
- // the lhs is just a variable or phony.
- if (!lhs->IsAnyOf<ast::IdentifierExpression, ast::PhonyExpression>()) {
- Flush(maybe_hoist);
- }
+ // Special case for processing assignment statement expressions, as we must
+ // evaluate the rhs before the lhs, and possibly hoist the rhs expression.
+ void ProcessAssignment(const ast::Expression* lhs, const ast::Expression* rhs) {
+ // Evaluate rhs before lhs
+ ast::ExpressionList maybe_hoist;
+ if (ProcessExpression(rhs, maybe_hoist)) {
+ maybe_hoist.emplace_back(rhs);
+ }
+
+ // If the rhs has side-effects, it may affect the lhs, so hoist it right
+ // away. e.g. "b[c] = a(0);"
+ if (HasSideEffects(rhs)) {
+ // Technically, we can always hoist rhs, but don't bother doing so when
+ // the lhs is just a variable or phony.
+ if (!lhs->IsAnyOf<ast::IdentifierExpression, ast::PhonyExpression>()) {
+ Flush(maybe_hoist);
+ }
+ }
+
+ // If maybe_hoist still has values, it means they are potential side-effect
+ // receivers. We pass this in while processing the lhs, in which case they
+ // may get hoisted if the lhs has side-effects. E.g. "b[a(0)] = c;".
+ ProcessExpression(lhs, maybe_hoist);
}
- // If maybe_hoist still has values, it means they are potential side-effect
- // receivers. We pass this in while processing the lhs, in which case they
- // may get hoisted if the lhs has side-effects. E.g. "b[a(0)] = c;".
- ProcessExpression(lhs, maybe_hoist);
- }
+ public:
+ explicit CollectHoistsState(CloneContext& ctx_in) : StateBase(ctx_in) {}
- public:
- explicit CollectHoistsState(CloneContext& ctx_in) : StateBase(ctx_in) {}
+ ToHoistSet Run() {
+ // Traverse all statements, recursively processing their expression tree(s)
+ // to hoist side-effects to lets.
+ for (auto* node : ctx.src->ASTNodes().Objects()) {
+ auto* stmt = node->As<ast::Statement>();
+ if (!stmt) {
+ continue;
+ }
- ToHoistSet Run() {
- // Traverse all statements, recursively processing their expression tree(s)
- // to hoist side-effects to lets.
- for (auto* node : ctx.src->ASTNodes().Objects()) {
- auto* stmt = node->As<ast::Statement>();
- if (!stmt) {
- continue;
- }
+ Switch(
+ stmt, [&](const ast::AssignmentStatement* s) { ProcessAssignment(s->lhs, s->rhs); },
+ [&](const ast::CallStatement* s) { //
+ ProcessStatement(s->expr);
+ },
+ [&](const ast::ForLoopStatement* s) { ProcessStatement(s->condition); },
+ [&](const ast::IfStatement* s) { //
+ ProcessStatement(s->condition);
+ },
+ [&](const ast::ReturnStatement* s) { //
+ ProcessStatement(s->value);
+ },
+ [&](const ast::SwitchStatement* s) { ProcessStatement(s->condition); },
+ [&](const ast::VariableDeclStatement* s) {
+ ProcessStatement(s->variable->constructor);
+ });
+ }
- Switch(
- stmt,
- [&](const ast::AssignmentStatement* s) {
- ProcessAssignment(s->lhs, s->rhs);
- },
- [&](const ast::CallStatement* s) { //
- ProcessStatement(s->expr);
- },
- [&](const ast::ForLoopStatement* s) {
- ProcessStatement(s->condition);
- },
- [&](const ast::IfStatement* s) { //
- ProcessStatement(s->condition);
- },
- [&](const ast::ReturnStatement* s) { //
- ProcessStatement(s->value);
- },
- [&](const ast::SwitchStatement* s) {
- ProcessStatement(s->condition);
- },
- [&](const ast::VariableDeclStatement* s) {
- ProcessStatement(s->variable->constructor);
- });
+ return std::move(to_hoist);
}
-
- return std::move(to_hoist);
- }
};
// DecomposeState performs the actual transforming of the AST to ensure order of
// evaluation, using the set of expressions to hoist collected by
// CollectHoistsState.
class DecomposeSideEffects::DecomposeState : public StateBase {
- ToHoistSet to_hoist;
+ ToHoistSet to_hoist;
- // Returns true if `binary_expr` should be decomposed for short-circuit eval.
- bool IsLogicalWithSideEffects(const ast::BinaryExpression* binary_expr) {
- return binary_expr->IsLogical() &&
- (sem.Get(binary_expr->lhs)->HasSideEffects() ||
- sem.Get(binary_expr->rhs)->HasSideEffects());
- }
-
- // Recursive function used to decompose an expression for short-circuit eval.
- const ast::Expression* Decompose(const ast::Expression* expr,
- ast::StatementList* curr_stmts) {
- // Helper to avoid passing in same args.
- auto decompose = [&](auto& e) { return Decompose(e, curr_stmts); };
-
- // Clones `expr`, possibly hoisting it to a let.
- auto clone_maybe_hoisted =
- [&](const ast::Expression* e) -> const ast::Expression* {
- if (to_hoist.count(e)) {
- auto name = b.Symbols().New();
- auto* v = b.Let(name, nullptr, ctx.Clone(e));
- auto* decl = b.Decl(v);
- curr_stmts->push_back(decl);
- return b.Expr(name);
- }
- return ctx.Clone(e);
- };
-
- return Switch(
- expr,
- [&](const ast::BinaryExpression* bin_expr) -> const ast::Expression* {
- if (!IsLogicalWithSideEffects(bin_expr)) {
- // No short-circuit, emit usual binary expr
- ctx.Replace(bin_expr->lhs, decompose(bin_expr->lhs));
- ctx.Replace(bin_expr->rhs, decompose(bin_expr->rhs));
- return clone_maybe_hoisted(bin_expr);
- }
-
- // Decompose into ifs to implement short-circuiting
- // For example, 'let r = a && b' becomes:
- //
- // var temp = a;
- // if (temp) {
- // temp = b;
- // }
- // let r = temp;
- //
- // and similarly, 'let r = a || b' becomes:
- //
- // var temp = a;
- // if (!temp) {
- // temp = b;
- // }
- // let r = temp;
- //
- // Further, compound logical binary expressions are also handled
- // recursively, for example, 'let r = (a && (b && c))' becomes:
- //
- // var temp = a;
- // if (temp) {
- // var temp2 = b;
- // if (temp2) {
- // temp2 = c;
- // }
- // temp = temp2;
- // }
- // let r = temp;
-
- auto name = b.Sym();
- curr_stmts->push_back(
- b.Decl(b.Var(name, nullptr, decompose(bin_expr->lhs))));
-
- const ast::Expression* if_cond = nullptr;
- if (bin_expr->IsLogicalOr()) {
- if_cond = b.Not(name);
- } else {
- if_cond = b.Expr(name);
- }
-
- const ast::BlockStatement* if_body = nullptr;
- {
- ast::StatementList stmts;
- TINT_SCOPED_ASSIGNMENT(curr_stmts, &stmts);
- auto* new_rhs = decompose(bin_expr->rhs);
- curr_stmts->push_back(b.Assign(name, new_rhs));
- if_body = b.Block(std::move(*curr_stmts));
- }
-
- curr_stmts->push_back(b.If(if_cond, if_body));
-
- return b.Expr(name);
- },
- [&](const ast::IndexAccessorExpression* idx) {
- ctx.Replace(idx->object, decompose(idx->object));
- ctx.Replace(idx->index, decompose(idx->index));
- return clone_maybe_hoisted(idx);
- },
- [&](const ast::BitcastExpression* bitcast) {
- ctx.Replace(bitcast->expr, decompose(bitcast->expr));
- return clone_maybe_hoisted(bitcast);
- },
- [&](const ast::CallExpression* call) {
- if (call->target.name) {
- ctx.Replace(call->target.name, decompose(call->target.name));
- }
- for (auto* a : call->args) {
- ctx.Replace(a, decompose(a));
- }
- return clone_maybe_hoisted(call);
- },
- [&](const ast::MemberAccessorExpression* member) {
- ctx.Replace(member->structure, decompose(member->structure));
- ctx.Replace(member->member, decompose(member->member));
- return clone_maybe_hoisted(member);
- },
- [&](const ast::UnaryOpExpression* unary) {
- ctx.Replace(unary->expr, decompose(unary->expr));
- return clone_maybe_hoisted(unary);
- },
- [&](const ast::LiteralExpression* lit) {
- return clone_maybe_hoisted(lit); // Leaf expression, just clone as is
- },
- [&](const ast::IdentifierExpression* id) {
- return clone_maybe_hoisted(id); // Leaf expression, just clone as is
- },
- [&](const ast::PhonyExpression* phony) {
- return clone_maybe_hoisted(
- phony); // Leaf expression, just clone as is
- },
- [&](Default) {
- TINT_ICE(AST, b.Diagnostics())
- << "unhandled expression type: " << expr->TypeInfo().name;
- return nullptr;
- });
- }
-
- // Inserts statements in `stmts` before `stmt`
- void InsertBefore(const ast::StatementList& stmts,
- const ast::Statement* stmt) {
- if (!stmts.empty()) {
- auto ip = utils::GetInsertionPoint(ctx, stmt);
- for (auto* s : stmts) {
- ctx.InsertBefore(ip.first->Declaration()->statements, ip.second, s);
- }
+ // Returns true if `binary_expr` should be decomposed for short-circuit eval.
+ bool IsLogicalWithSideEffects(const ast::BinaryExpression* binary_expr) {
+ return binary_expr->IsLogical() && (sem.Get(binary_expr->lhs)->HasSideEffects() ||
+ sem.Get(binary_expr->rhs)->HasSideEffects());
}
- }
- // Decomposes expressions of `stmt`, returning a replacement statement or
- // nullptr if not replacing it.
- const ast::Statement* DecomposeStatement(const ast::Statement* stmt) {
- return Switch(
- stmt,
- [&](const ast::AssignmentStatement* s) -> const ast::Statement* {
- if (!sem.Get(s->lhs)->HasSideEffects() &&
- !sem.Get(s->rhs)->HasSideEffects()) {
- return nullptr;
- }
- // rhs before lhs
- ast::StatementList stmts;
- ctx.Replace(s->rhs, Decompose(s->rhs, &stmts));
- ctx.Replace(s->lhs, Decompose(s->lhs, &stmts));
- InsertBefore(stmts, s);
- return ctx.CloneWithoutTransform(s);
- },
- [&](const ast::CallStatement* s) -> const ast::Statement* {
- if (!sem.Get(s->expr)->HasSideEffects()) {
- return nullptr;
- }
- ast::StatementList stmts;
- ctx.Replace(s->expr, Decompose(s->expr, &stmts));
- InsertBefore(stmts, s);
- return ctx.CloneWithoutTransform(s);
- },
- [&](const ast::ForLoopStatement* s) -> const ast::Statement* {
- if (!s->condition || !sem.Get(s->condition)->HasSideEffects()) {
- return nullptr;
- }
- ast::StatementList stmts;
- ctx.Replace(s->condition, Decompose(s->condition, &stmts));
- InsertBefore(stmts, s);
- return ctx.CloneWithoutTransform(s);
- },
- [&](const ast::IfStatement* s) -> const ast::Statement* {
- if (!sem.Get(s->condition)->HasSideEffects()) {
- return nullptr;
- }
- ast::StatementList stmts;
- ctx.Replace(s->condition, Decompose(s->condition, &stmts));
- InsertBefore(stmts, s);
- return ctx.CloneWithoutTransform(s);
- },
- [&](const ast::ReturnStatement* s) -> const ast::Statement* {
- if (!s->value || !sem.Get(s->value)->HasSideEffects()) {
- return nullptr;
- }
- ast::StatementList stmts;
- ctx.Replace(s->value, Decompose(s->value, &stmts));
- InsertBefore(stmts, s);
- return ctx.CloneWithoutTransform(s);
- },
- [&](const ast::SwitchStatement* s) -> const ast::Statement* {
- if (!sem.Get(s->condition)) {
- return nullptr;
- }
- ast::StatementList stmts;
- ctx.Replace(s->condition, Decompose(s->condition, &stmts));
- InsertBefore(stmts, s);
- return ctx.CloneWithoutTransform(s);
- },
- [&](const ast::VariableDeclStatement* s) -> const ast::Statement* {
- auto* var = s->variable;
- if (!var->constructor ||
- !sem.Get(var->constructor)->HasSideEffects()) {
- return nullptr;
- }
- ast::StatementList stmts;
- ctx.Replace(var->constructor, Decompose(var->constructor, &stmts));
- InsertBefore(stmts, s);
- return b.Decl(ctx.CloneWithoutTransform(var));
- },
- [](Default) -> const ast::Statement* {
- // Other statement types don't have expressions
- return nullptr;
- });
- }
+ // Recursive function used to decompose an expression for short-circuit eval.
+ const ast::Expression* Decompose(const ast::Expression* expr, ast::StatementList* curr_stmts) {
+ // Helper to avoid passing in same args.
+ auto decompose = [&](auto& e) { return Decompose(e, curr_stmts); };
- public:
- explicit DecomposeState(CloneContext& ctx_in, ToHoistSet to_hoist_in)
- : StateBase(ctx_in), to_hoist(std::move(to_hoist_in)) {}
-
- void Run() {
- // We replace all BlockStatements as this allows us to iterate over the
- // block statements and ctx.InsertBefore hoisted declarations on them.
- ctx.ReplaceAll(
- [&](const ast::BlockStatement* block) -> const ast::Statement* {
- for (auto* stmt : block->statements) {
- if (auto* new_stmt = DecomposeStatement(stmt)) {
- ctx.Replace(stmt, new_stmt);
+ // Clones `expr`, possibly hoisting it to a let.
+ auto clone_maybe_hoisted = [&](const ast::Expression* e) -> const ast::Expression* {
+ if (to_hoist.count(e)) {
+ auto name = b.Symbols().New();
+ auto* v = b.Let(name, nullptr, ctx.Clone(e));
+ auto* decl = b.Decl(v);
+ curr_stmts->push_back(decl);
+ return b.Expr(name);
}
+ return ctx.Clone(e);
+ };
- // Handle for loops, as they are the only other AST node that
- // contains statements outside of BlockStatements.
- if (auto* fl = stmt->As<ast::ForLoopStatement>()) {
- if (auto* new_stmt = DecomposeStatement(fl->initializer)) {
- ctx.Replace(fl->initializer, new_stmt);
- }
- if (auto* new_stmt = DecomposeStatement(fl->continuing)) {
- ctx.Replace(fl->continuing, new_stmt);
- }
+ return Switch(
+ expr,
+ [&](const ast::BinaryExpression* bin_expr) -> const ast::Expression* {
+ if (!IsLogicalWithSideEffects(bin_expr)) {
+ // No short-circuit, emit usual binary expr
+ ctx.Replace(bin_expr->lhs, decompose(bin_expr->lhs));
+ ctx.Replace(bin_expr->rhs, decompose(bin_expr->rhs));
+ return clone_maybe_hoisted(bin_expr);
+ }
+
+ // Decompose into ifs to implement short-circuiting
+ // For example, 'let r = a && b' becomes:
+ //
+ // var temp = a;
+ // if (temp) {
+ // temp = b;
+ // }
+ // let r = temp;
+ //
+ // and similarly, 'let r = a || b' becomes:
+ //
+ // var temp = a;
+ // if (!temp) {
+ // temp = b;
+ // }
+ // let r = temp;
+ //
+ // Further, compound logical binary expressions are also handled
+ // recursively, for example, 'let r = (a && (b && c))' becomes:
+ //
+ // var temp = a;
+ // if (temp) {
+ // var temp2 = b;
+ // if (temp2) {
+ // temp2 = c;
+ // }
+ // temp = temp2;
+ // }
+ // let r = temp;
+
+ auto name = b.Sym();
+ curr_stmts->push_back(b.Decl(b.Var(name, nullptr, decompose(bin_expr->lhs))));
+
+ const ast::Expression* if_cond = nullptr;
+ if (bin_expr->IsLogicalOr()) {
+ if_cond = b.Not(name);
+ } else {
+ if_cond = b.Expr(name);
+ }
+
+ const ast::BlockStatement* if_body = nullptr;
+ {
+ ast::StatementList stmts;
+ TINT_SCOPED_ASSIGNMENT(curr_stmts, &stmts);
+ auto* new_rhs = decompose(bin_expr->rhs);
+ curr_stmts->push_back(b.Assign(name, new_rhs));
+ if_body = b.Block(std::move(*curr_stmts));
+ }
+
+ curr_stmts->push_back(b.If(if_cond, if_body));
+
+ return b.Expr(name);
+ },
+ [&](const ast::IndexAccessorExpression* idx) {
+ ctx.Replace(idx->object, decompose(idx->object));
+ ctx.Replace(idx->index, decompose(idx->index));
+ return clone_maybe_hoisted(idx);
+ },
+ [&](const ast::BitcastExpression* bitcast) {
+ ctx.Replace(bitcast->expr, decompose(bitcast->expr));
+ return clone_maybe_hoisted(bitcast);
+ },
+ [&](const ast::CallExpression* call) {
+ if (call->target.name) {
+ ctx.Replace(call->target.name, decompose(call->target.name));
+ }
+ for (auto* a : call->args) {
+ ctx.Replace(a, decompose(a));
+ }
+ return clone_maybe_hoisted(call);
+ },
+ [&](const ast::MemberAccessorExpression* member) {
+ ctx.Replace(member->structure, decompose(member->structure));
+ ctx.Replace(member->member, decompose(member->member));
+ return clone_maybe_hoisted(member);
+ },
+ [&](const ast::UnaryOpExpression* unary) {
+ ctx.Replace(unary->expr, decompose(unary->expr));
+ return clone_maybe_hoisted(unary);
+ },
+ [&](const ast::LiteralExpression* lit) {
+ return clone_maybe_hoisted(lit); // Leaf expression, just clone as is
+ },
+ [&](const ast::IdentifierExpression* id) {
+ return clone_maybe_hoisted(id); // Leaf expression, just clone as is
+ },
+ [&](const ast::PhonyExpression* phony) {
+ return clone_maybe_hoisted(phony); // Leaf expression, just clone as is
+ },
+ [&](Default) {
+ TINT_ICE(AST, b.Diagnostics())
+ << "unhandled expression type: " << expr->TypeInfo().name;
+ return nullptr;
+ });
+ }
+
+ // Inserts statements in `stmts` before `stmt`
+ void InsertBefore(const ast::StatementList& stmts, const ast::Statement* stmt) {
+ if (!stmts.empty()) {
+ auto ip = utils::GetInsertionPoint(ctx, stmt);
+ for (auto* s : stmts) {
+ ctx.InsertBefore(ip.first->Declaration()->statements, ip.second, s);
}
- }
- return nullptr;
+ }
+ }
+
+ // Decomposes expressions of `stmt`, returning a replacement statement or
+ // nullptr if not replacing it.
+ const ast::Statement* DecomposeStatement(const ast::Statement* stmt) {
+ return Switch(
+ stmt,
+ [&](const ast::AssignmentStatement* s) -> const ast::Statement* {
+ if (!sem.Get(s->lhs)->HasSideEffects() && !sem.Get(s->rhs)->HasSideEffects()) {
+ return nullptr;
+ }
+ // rhs before lhs
+ ast::StatementList stmts;
+ ctx.Replace(s->rhs, Decompose(s->rhs, &stmts));
+ ctx.Replace(s->lhs, Decompose(s->lhs, &stmts));
+ InsertBefore(stmts, s);
+ return ctx.CloneWithoutTransform(s);
+ },
+ [&](const ast::CallStatement* s) -> const ast::Statement* {
+ if (!sem.Get(s->expr)->HasSideEffects()) {
+ return nullptr;
+ }
+ ast::StatementList stmts;
+ ctx.Replace(s->expr, Decompose(s->expr, &stmts));
+ InsertBefore(stmts, s);
+ return ctx.CloneWithoutTransform(s);
+ },
+ [&](const ast::ForLoopStatement* s) -> const ast::Statement* {
+ if (!s->condition || !sem.Get(s->condition)->HasSideEffects()) {
+ return nullptr;
+ }
+ ast::StatementList stmts;
+ ctx.Replace(s->condition, Decompose(s->condition, &stmts));
+ InsertBefore(stmts, s);
+ return ctx.CloneWithoutTransform(s);
+ },
+ [&](const ast::IfStatement* s) -> const ast::Statement* {
+ if (!sem.Get(s->condition)->HasSideEffects()) {
+ return nullptr;
+ }
+ ast::StatementList stmts;
+ ctx.Replace(s->condition, Decompose(s->condition, &stmts));
+ InsertBefore(stmts, s);
+ return ctx.CloneWithoutTransform(s);
+ },
+ [&](const ast::ReturnStatement* s) -> const ast::Statement* {
+ if (!s->value || !sem.Get(s->value)->HasSideEffects()) {
+ return nullptr;
+ }
+ ast::StatementList stmts;
+ ctx.Replace(s->value, Decompose(s->value, &stmts));
+ InsertBefore(stmts, s);
+ return ctx.CloneWithoutTransform(s);
+ },
+ [&](const ast::SwitchStatement* s) -> const ast::Statement* {
+ if (!sem.Get(s->condition)) {
+ return nullptr;
+ }
+ ast::StatementList stmts;
+ ctx.Replace(s->condition, Decompose(s->condition, &stmts));
+ InsertBefore(stmts, s);
+ return ctx.CloneWithoutTransform(s);
+ },
+ [&](const ast::VariableDeclStatement* s) -> const ast::Statement* {
+ auto* var = s->variable;
+ if (!var->constructor || !sem.Get(var->constructor)->HasSideEffects()) {
+ return nullptr;
+ }
+ ast::StatementList stmts;
+ ctx.Replace(var->constructor, Decompose(var->constructor, &stmts));
+ InsertBefore(stmts, s);
+ return b.Decl(ctx.CloneWithoutTransform(var));
+ },
+ [](Default) -> const ast::Statement* {
+ // Other statement types don't have expressions
+ return nullptr;
+ });
+ }
+
+ public:
+ explicit DecomposeState(CloneContext& ctx_in, ToHoistSet to_hoist_in)
+ : StateBase(ctx_in), to_hoist(std::move(to_hoist_in)) {}
+
+ void Run() {
+ // We replace all BlockStatements as this allows us to iterate over the
+ // block statements and ctx.InsertBefore hoisted declarations on them.
+ ctx.ReplaceAll([&](const ast::BlockStatement* block) -> const ast::Statement* {
+ for (auto* stmt : block->statements) {
+ if (auto* new_stmt = DecomposeStatement(stmt)) {
+ ctx.Replace(stmt, new_stmt);
+ }
+
+ // Handle for loops, as they are the only other AST node that
+ // contains statements outside of BlockStatements.
+ if (auto* fl = stmt->As<ast::ForLoopStatement>()) {
+ if (auto* new_stmt = DecomposeStatement(fl->initializer)) {
+ ctx.Replace(fl->initializer, new_stmt);
+ }
+ if (auto* new_stmt = DecomposeStatement(fl->continuing)) {
+ ctx.Replace(fl->continuing, new_stmt);
+ }
+ }
+ }
+ return nullptr;
});
- ctx.Clone();
- }
+ ctx.Clone();
+ }
};
-void DecomposeSideEffects::Run(CloneContext& ctx,
- const DataMap&,
- DataMap&) const {
- // First collect side-effecting expressions to hoist
- CollectHoistsState collect_hoists_state{ctx};
- auto to_hoist = collect_hoists_state.Run();
+void DecomposeSideEffects::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
+ // First collect side-effecting expressions to hoist
+ CollectHoistsState collect_hoists_state{ctx};
+ auto to_hoist = collect_hoists_state.Run();
- // Now decompose these expressions
- DecomposeState decompose_state{ctx, std::move(to_hoist)};
- decompose_state.Run();
+ // Now decompose these expressions
+ DecomposeState decompose_state{ctx, std::move(to_hoist)};
+ decompose_state.Run();
}
} // namespace
@@ -693,14 +667,13 @@
PromoteSideEffectsToDecl::PromoteSideEffectsToDecl() = default;
PromoteSideEffectsToDecl::~PromoteSideEffectsToDecl() = default;
-Output PromoteSideEffectsToDecl::Run(const Program* program,
- const DataMap& data) const {
- transform::Manager manager;
- manager.Add<SimplifySideEffectStatements>();
- manager.Add<DecomposeSideEffects>();
+Output PromoteSideEffectsToDecl::Run(const Program* program, const DataMap& data) const {
+ transform::Manager manager;
+ manager.Add<SimplifySideEffectStatements>();
+ manager.Add<DecomposeSideEffects>();
- auto output = manager.Run(program, data);
- return output;
+ auto output = manager.Run(program, data);
+ return output;
}
} // namespace tint::transform
diff --git a/src/tint/transform/promote_side_effects_to_decl.h b/src/tint/transform/promote_side_effects_to_decl.h
index cdc9241..1e629b3 100644
--- a/src/tint/transform/promote_side_effects_to_decl.h
+++ b/src/tint/transform/promote_side_effects_to_decl.h
@@ -23,21 +23,20 @@
/// declarations before the statement of usage with the goal of ensuring
/// left-to-right order of evaluation, while respecting short-circuit
/// evaluation.
-class PromoteSideEffectsToDecl
- : public Castable<PromoteSideEffectsToDecl, Transform> {
- public:
- /// Constructor
- PromoteSideEffectsToDecl();
+class PromoteSideEffectsToDecl : public Castable<PromoteSideEffectsToDecl, Transform> {
+ public:
+ /// Constructor
+ PromoteSideEffectsToDecl();
- /// Destructor
- ~PromoteSideEffectsToDecl() override;
+ /// Destructor
+ ~PromoteSideEffectsToDecl() override;
- protected:
- /// Runs the transform on `program`, returning the transformation result.
- /// @param program the source program to transform
- /// @param data optional extra transform-specific data
- /// @returns the transformation result
- Output Run(const Program* program, const DataMap& data = {}) const override;
+ protected:
+ /// Runs the transform on `program`, returning the transformation result.
+ /// @param program the source program to transform
+ /// @param data optional extra transform-specific data
+ /// @returns the transformation result
+ Output Run(const Program* program, const DataMap& data = {}) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/promote_side_effects_to_decl_test.cc b/src/tint/transform/promote_side_effects_to_decl_test.cc
index 299d706..9d9115f 100644
--- a/src/tint/transform/promote_side_effects_to_decl_test.cc
+++ b/src/tint/transform/promote_side_effects_to_decl_test.cc
@@ -22,17 +22,17 @@
using PromoteSideEffectsToDeclTest = TransformTest;
TEST_F(PromoteSideEffectsToDeclTest, EmptyModule) {
- auto* src = "";
- auto* expect = "";
+ auto* src = "";
+ auto* expect = "";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Unary_Arith_SE) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -42,16 +42,16 @@
}
)";
- auto* expect = src;
+ auto* expect = src;
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_BothSE) {
- auto* src = R"(
+ auto* src = R"(
fn a() -> i32 {
return 1;
}
@@ -65,7 +65,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a() -> i32 {
return 1;
}
@@ -81,14 +81,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_LeftSE) {
- auto* src = R"(
+ auto* src = R"(
fn a() -> i32 {
return 1;
}
@@ -99,7 +99,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a() -> i32 {
return 1;
}
@@ -111,14 +111,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_RightSE) {
- auto* src = R"(
+ auto* src = R"(
fn a() -> i32 {
return 1;
}
@@ -129,7 +129,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a() -> i32 {
return 1;
}
@@ -142,14 +142,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_LeftmostSE) {
- auto* src = R"(
+ auto* src = R"(
fn a() -> i32 {
return 1;
}
@@ -162,7 +162,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a() -> i32 {
return 1;
}
@@ -176,14 +176,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_RightmostSE) {
- auto* src = R"(
+ auto* src = R"(
fn a() -> i32 {
return 1;
}
@@ -196,7 +196,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a() -> i32 {
return 1;
}
@@ -211,14 +211,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_MiddleSE) {
- auto* src = R"(
+ auto* src = R"(
fn a() -> i32 {
return 1;
}
@@ -232,7 +232,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a() -> i32 {
return 1;
}
@@ -248,14 +248,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_ThreeSE) {
- auto* src = R"(
+ auto* src = R"(
fn a(v : i32) -> i32 {
return v;
}
@@ -265,7 +265,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(v : i32) -> i32 {
return v;
}
@@ -278,14 +278,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_Constants_NoRecvSE) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -295,7 +295,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -306,14 +306,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_Constants_RecvSE) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -324,7 +324,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -338,14 +338,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_Constants_ConstAndSEAndVar) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -357,7 +357,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -370,14 +370,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_Constants_VarAndSEAndConst) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -388,7 +388,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -401,15 +401,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
-TEST_F(PromoteSideEffectsToDeclTest,
- Binary_Arith_Constants_SEAndVarAndConstAndVar) {
- auto* src = R"(
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_Constants_SEAndVarAndConstAndVar) {
+ auto* src = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -421,7 +420,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -434,14 +433,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_Builtins_WithSE) {
- auto* src = R"(
+ auto* src = R"(
struct SB {
a : atomic<i32>,
}
@@ -454,7 +453,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct SB {
a : atomic<i32>,
}
@@ -468,14 +467,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_Builtins_NoSEAndVar) {
- auto* src = R"(
+ auto* src = R"(
struct SB {
a : atomic<i32>,
}
@@ -488,16 +487,16 @@
}
)";
- auto* expect = src;
+ auto* expect = src;
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_Builtins_NoSEAndSE) {
- auto* src = R"(
+ auto* src = R"(
struct SB {
a : atomic<i32>,
}
@@ -514,7 +513,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct SB {
a : atomic<i32>,
}
@@ -533,14 +532,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_Vector_RightSE) {
- auto* src = R"(
+ auto* src = R"(
fn a() -> i32 {
return 1;
}
@@ -552,7 +551,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a() -> i32 {
return 1;
}
@@ -567,14 +566,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_InCall) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -589,7 +588,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -609,14 +608,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_InTypeCtor) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return 1;
@@ -628,7 +627,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -646,14 +645,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_InTypeConversion) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return 1;
@@ -665,7 +664,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -680,14 +679,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_InIntrinsic) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return i;
@@ -699,7 +698,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -717,14 +716,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_InMemberAccessor) {
- auto* src = R"(
+ auto* src = R"(
struct S {
v : i32,
@@ -740,7 +739,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct S {
v : i32,
}
@@ -758,14 +757,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_InUnary) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -776,7 +775,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -790,14 +789,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_InBitcast) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -808,7 +807,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -821,14 +820,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_InForLoopInit) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -842,7 +841,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -857,14 +856,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_InForLoopCond) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -877,7 +876,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -896,14 +895,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_InForLoopCont) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -918,7 +917,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -940,14 +939,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_InForLoopInitCondCont) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -963,7 +962,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -994,14 +993,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_InElseIf) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -1016,7 +1015,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -1034,14 +1033,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_InElseIfChain) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -1064,7 +1063,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -1094,14 +1093,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_InReturn) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -1112,7 +1111,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -1125,14 +1124,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_InSwitch) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -1146,7 +1145,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -1162,14 +1161,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_LeftSE) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -1180,7 +1179,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -1195,14 +1194,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_RightSE) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -1213,7 +1212,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -1228,14 +1227,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_BothSE) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -1245,7 +1244,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -1259,14 +1258,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_LeftmostSE) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -1279,7 +1278,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -1304,14 +1303,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_RightmostSE) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -1324,7 +1323,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -1341,14 +1340,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_MiddleSE) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -1361,7 +1360,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -1386,14 +1385,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_Constants_NoRecvSE) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -1403,7 +1402,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -1425,14 +1424,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_Constants_RecvSE) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -1443,7 +1442,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -1466,15 +1465,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
-TEST_F(PromoteSideEffectsToDeclTest,
- Binary_Logical_Constants_ConstAndSEAndVar) {
- auto* src = R"(
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_Constants_ConstAndSEAndVar) {
+ auto* src = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -1486,7 +1484,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -1506,15 +1504,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
-TEST_F(PromoteSideEffectsToDeclTest,
- Binary_Logical_Constants_VarAndSEAndConst) {
- auto* src = R"(
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_Constants_VarAndSEAndConst) {
+ auto* src = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -1525,7 +1522,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -1544,15 +1541,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
-TEST_F(PromoteSideEffectsToDeclTest,
- Binary_Logical_Constants_SEAndVarAndConstAndVar) {
- auto* src = R"(
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_Constants_SEAndVarAndConstAndVar) {
+ auto* src = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -1564,7 +1560,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -1588,14 +1584,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_MixedSE) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -1608,7 +1604,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -1645,14 +1641,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_NestedAnds) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -1662,7 +1658,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -1692,14 +1688,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_NestedOrs) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -1709,7 +1705,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -1739,14 +1735,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_MultipleStatements) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -1758,7 +1754,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -1778,14 +1774,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_InCall) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -1800,7 +1796,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -1819,14 +1815,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_InTypeCtor) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> bool {
return true;
@@ -1838,7 +1834,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -1866,14 +1862,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_InTypeConversion) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return 1;
@@ -1885,7 +1881,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -1906,16 +1902,16 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
// Make sure we process logical binary expressions of non-logical binary
// expressions.
TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_OfNonLogical) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return 1;
@@ -1927,7 +1923,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -1944,14 +1940,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_InIntrinsic) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> bool {
return true;
@@ -1963,7 +1959,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -1991,14 +1987,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_InMemberAccessor) {
- auto* src = R"(
+ auto* src = R"(
struct S {
v : bool,
@@ -2014,7 +2010,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct S {
v : bool,
}
@@ -2039,14 +2035,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_InUnary) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> bool {
return true;
@@ -2058,7 +2054,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -2073,14 +2069,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_InBitcast) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -2091,7 +2087,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -2106,14 +2102,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_InForLoopInit) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -2127,7 +2123,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -2145,14 +2141,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_InForLoopCond) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -2165,7 +2161,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -2187,14 +2183,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_InForLoopCont) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -2209,7 +2205,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -2234,14 +2230,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_InForLoopInitCondCont) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -2257,7 +2253,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -2297,14 +2293,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_InElseIf) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -2319,7 +2315,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -2340,14 +2336,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_InElseIfChain) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -2370,7 +2366,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> bool {
return true;
}
@@ -2405,14 +2401,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Call_NoSE) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -2428,16 +2424,16 @@
}
)";
- auto* expect = src;
+ auto* expect = src;
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Call_OneSE) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -2452,7 +2448,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -2468,14 +2464,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Call_AllSE) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -2489,7 +2485,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -2506,14 +2502,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Call_MiddleNotSE) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -2529,7 +2525,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -2547,14 +2543,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Call_InBinary) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -2571,7 +2567,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -2593,14 +2589,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, IndexAccessor_2D_LeftSE) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -2612,7 +2608,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -2625,14 +2621,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, IndexAccessor_2D_RightSE) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -2646,7 +2642,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -2660,14 +2656,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, IndexAccessor_2D_BothSE) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -2678,7 +2674,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -2691,14 +2687,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Assignment_ToPhony) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -2708,7 +2704,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -2718,14 +2714,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Assignment_ToArray1D) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -2736,7 +2732,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -2749,14 +2745,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Assignment_ToArray2D) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -2767,7 +2763,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -2781,14 +2777,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Assignment_ToArray3D) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -2799,7 +2795,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -2814,14 +2810,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Assignment_ToArray_FromArray) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -2834,7 +2830,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -2851,14 +2847,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Assignment_ToVec_BothSE) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -2869,7 +2865,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -2882,14 +2878,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Assignment_ToVec_LeftSE) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -2901,7 +2897,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -2915,14 +2911,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Assignment_ToVec_RightSE) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -2934,7 +2930,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -2947,14 +2943,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, TypeConstructor_Struct) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -2970,7 +2966,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -2989,14 +2985,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, TypeConstructor_Array1D) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -3006,7 +3002,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -3019,14 +3015,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, TypeConstructor_Array2D) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -3036,7 +3032,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return 1;
}
@@ -3052,14 +3048,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, MemberAccessor_Vec) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> vec3<i32> {
return vec3<i32>();
}
@@ -3069,7 +3065,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> vec3<i32> {
return vec3<i32>();
}
@@ -3081,14 +3077,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, MemberAccessor_Struct) {
- auto* src = R"(
+ auto* src = R"(
struct S {
x : i32,
y : i32,
@@ -3103,7 +3099,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct S {
x : i32,
y : i32,
@@ -3120,14 +3116,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, MemberAccessor_Struct_Mixed) {
- auto* src = R"(
+ auto* src = R"(
struct S {
x : i32,
y : i32,
@@ -3152,7 +3148,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct S {
x : i32,
y : i32,
@@ -3184,14 +3180,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, IndexAccessor_Plus_SE) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -3202,7 +3198,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -3215,14 +3211,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, IndexAccessor_Of_SE) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -3233,7 +3229,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -3245,14 +3241,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, IndexAccessor2_Of_LeftSE) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -3263,7 +3259,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -3275,14 +3271,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, IndexAccessor2_Of_RightSE) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -3293,7 +3289,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -3305,14 +3301,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, IndexAccessor2_Of_SEAndVar) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -3324,7 +3320,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -3337,14 +3333,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, IndexAccessor2_Of_VarAndSE) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -3356,7 +3352,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -3370,14 +3366,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, IndexAccessorOfVar_Plus_SE) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -3389,7 +3385,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -3404,14 +3400,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, IndexAccessor_Plus_IndexAccessorOfSE) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -3422,7 +3418,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -3435,15 +3431,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
-TEST_F(PromoteSideEffectsToDeclTest,
- AssignTo_IndexAccessorOfIndexAccessorOfSE) {
- auto* src = R"(
+TEST_F(PromoteSideEffectsToDeclTest, AssignTo_IndexAccessorOfIndexAccessorOfSE) {
+ auto* src = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -3455,7 +3450,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -3468,15 +3463,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
-TEST_F(PromoteSideEffectsToDeclTest,
- AssignTo_IndexAccessorOfIndexAccessorOfLiteralPlusSE) {
- auto* src = R"(
+TEST_F(PromoteSideEffectsToDeclTest, AssignTo_IndexAccessorOfIndexAccessorOfLiteralPlusSE) {
+ auto* src = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -3488,7 +3482,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -3502,15 +3496,15 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest,
AssignTo_IndexAccessorOfIndexAccessorOfLiteralPlusIndexAccessorOfSE) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -3522,7 +3516,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -3536,14 +3530,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, IndexAccessorOfLhsSERhsSE) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -3557,7 +3551,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -3573,14 +3567,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, BinaryIndexAccessorOfLhsSERhsSE) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -3594,7 +3588,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return i;
}
@@ -3610,16 +3604,16 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, BinaryMemberAccessorPlusSE) {
- // bclayton@'s example:
- // https://dawn-review.googlesource.com/c/tint/+/78620/6..8/src/transform/promote_side_effects_to_decl.cc#b490
- auto* src = R"(
+ // bclayton@'s example:
+ // https://dawn-review.googlesource.com/c/tint/+/78620/6..8/src/transform/promote_side_effects_to_decl.cc#b490
+ auto* src = R"(
fn modify_vec(p : ptr<function, vec4<i32>>) -> i32 {
(*p).x = 42;
return 0;
@@ -3632,7 +3626,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn modify_vec(p : ptr<function, vec4<i32>>) -> i32 {
(*(p)).x = 42;
return 0;
@@ -3646,15 +3640,15 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Call_ReadOnlyArgAndSE) {
- // Make sure that read-only args don't get hoisted (tex and samp)
- auto* src = R"(
+ // Make sure that read-only args don't get hoisted (tex and samp)
+ auto* src = R"(
@group(1) @binding(1) var tex: texture_2d_array<u32>;
@group(1) @binding(2) var samp: sampler;
@@ -3667,7 +3661,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@group(1) @binding(1) var tex : texture_2d_array<u32>;
@group(1) @binding(2) var samp : sampler;
@@ -3682,15 +3676,15 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Call_PtrArgAndSE) {
- // Make sure that read-only args don't get hoisted (tex and samp)
- auto* src = R"(
+ // Make sure that read-only args don't get hoisted (tex and samp)
+ auto* src = R"(
var<private> b : i32 = 0;
@@ -3710,7 +3704,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> b : i32 = 0;
fn a(i : i32) -> i32 {
@@ -3728,14 +3722,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, TypeCtor_VarPlusI32CtorPlusVar) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var b = 0;
var c = 0;
@@ -3744,16 +3738,16 @@
}
)";
- auto* expect = src;
+ auto* expect = src;
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Mixed_ArithPlusLogical) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return 0;
}
@@ -3771,7 +3765,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return 0;
}
@@ -3795,14 +3789,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Mixed_LogicalPlusArith) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return 0;
}
@@ -3820,7 +3814,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return 0;
}
@@ -3844,14 +3838,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Mixed_ArithAndLogicalArgs) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return 0;
}
@@ -3869,7 +3863,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return 0;
}
@@ -3892,14 +3886,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Mixed_LogicalAndArithArgs) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return 0;
}
@@ -3917,7 +3911,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return 0;
}
@@ -3940,14 +3934,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Mixed_Complex) {
- auto* src = R"(
+ auto* src = R"(
fn a(i : i32) -> i32 {
return 0;
}
@@ -3969,7 +3963,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(i : i32) -> i32 {
return 0;
}
@@ -4001,10 +3995,10 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
} // namespace
} // namespace tint::transform
diff --git a/src/tint/transform/remove_continue_in_switch.cc b/src/tint/transform/remove_continue_in_switch.cc
index 9ee05c0..5c2413e 100644
--- a/src/tint/transform/remove_continue_in_switch.cc
+++ b/src/tint/transform/remove_continue_in_switch.cc
@@ -34,95 +34,89 @@
namespace {
class State {
- private:
- CloneContext& ctx;
- ProgramBuilder& b;
- const sem::Info& sem;
+ private:
+ CloneContext& ctx;
+ ProgramBuilder& b;
+ const sem::Info& sem;
- // Map of switch statement to 'tint_continue' variable.
- std::unordered_map<const ast::SwitchStatement*, Symbol>
- switch_to_cont_var_name;
+ // Map of switch statement to 'tint_continue' variable.
+ std::unordered_map<const ast::SwitchStatement*, Symbol> switch_to_cont_var_name;
- // If `cont` is within a switch statement within a loop, returns a pointer to
- // that switch statement.
- static const ast::SwitchStatement* GetParentSwitchInLoop(
- const sem::Info& sem,
- const ast::ContinueStatement* cont) {
- // Find whether first parent is a switch or a loop
- auto* sem_stmt = sem.Get(cont);
- auto* sem_parent =
- sem_stmt->FindFirstParent<sem::SwitchStatement, sem::LoopBlockStatement,
- sem::ForLoopStatement>();
- if (!sem_parent) {
- return nullptr;
- }
- return sem_parent->Declaration()->As<ast::SwitchStatement>();
- }
-
- public:
- /// Constructor
- /// @param ctx_in the context
- explicit State(CloneContext& ctx_in)
- : ctx(ctx_in), b(*ctx_in.dst), sem(ctx_in.src->Sem()) {}
-
- /// Returns true if this transform should be run for the given program
- static bool ShouldRun(const Program* program) {
- for (auto* node : program->ASTNodes().Objects()) {
- auto* stmt = node->As<ast::ContinueStatement>();
- if (!stmt) {
- continue;
- }
- if (GetParentSwitchInLoop(program->Sem(), stmt)) {
- return true;
- }
- }
- return false;
- }
-
- /// Runs the transform
- void Run() {
- for (auto* node : ctx.src->ASTNodes().Objects()) {
- auto* cont = node->As<ast::ContinueStatement>();
- if (!cont) {
- continue;
- }
-
- // If first parent is not a switch within a loop, skip
- auto* switch_stmt = GetParentSwitchInLoop(sem, cont);
- if (!switch_stmt) {
- continue;
- }
-
- auto cont_var_name =
- tint::utils::GetOrCreate(switch_to_cont_var_name, switch_stmt, [&]() {
- // Create and insert 'var tint_continue : bool = false;' before the
- // switch.
- auto var_name = b.Symbols().New("tint_continue");
- auto* decl = b.Decl(b.Var(var_name, b.ty.bool_(), b.Expr(false)));
- auto ip = utils::GetInsertionPoint(ctx, switch_stmt);
- ctx.InsertBefore(ip.first->Declaration()->statements, ip.second,
- decl);
-
- // Create and insert 'if (tint_continue) { continue; }' after
- // switch.
- auto* if_stmt = b.If(b.Expr(var_name), b.Block(b.Continue()));
- ctx.InsertAfter(ip.first->Declaration()->statements, ip.second,
- if_stmt);
-
- // Return the new var name
- return var_name;
- });
-
- // Replace 'continue;' with '{ tint_continue = true; break; }'
- auto* new_stmt = b.Block( //
- b.Assign(b.Expr(cont_var_name), true), //
- b.Break());
-
- ctx.Replace(cont, new_stmt);
+ // If `cont` is within a switch statement within a loop, returns a pointer to
+ // that switch statement.
+ static const ast::SwitchStatement* GetParentSwitchInLoop(const sem::Info& sem,
+ const ast::ContinueStatement* cont) {
+ // Find whether first parent is a switch or a loop
+ auto* sem_stmt = sem.Get(cont);
+ auto* sem_parent = sem_stmt->FindFirstParent<sem::SwitchStatement, sem::LoopBlockStatement,
+ sem::ForLoopStatement>();
+ if (!sem_parent) {
+ return nullptr;
+ }
+ return sem_parent->Declaration()->As<ast::SwitchStatement>();
}
- ctx.Clone();
- }
+ public:
+ /// Constructor
+ /// @param ctx_in the context
+ explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst), sem(ctx_in.src->Sem()) {}
+
+ /// Returns true if this transform should be run for the given program
+ static bool ShouldRun(const Program* program) {
+ for (auto* node : program->ASTNodes().Objects()) {
+ auto* stmt = node->As<ast::ContinueStatement>();
+ if (!stmt) {
+ continue;
+ }
+ if (GetParentSwitchInLoop(program->Sem(), stmt)) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ /// Runs the transform
+ void Run() {
+ for (auto* node : ctx.src->ASTNodes().Objects()) {
+ auto* cont = node->As<ast::ContinueStatement>();
+ if (!cont) {
+ continue;
+ }
+
+ // If first parent is not a switch within a loop, skip
+ auto* switch_stmt = GetParentSwitchInLoop(sem, cont);
+ if (!switch_stmt) {
+ continue;
+ }
+
+ auto cont_var_name =
+ tint::utils::GetOrCreate(switch_to_cont_var_name, switch_stmt, [&]() {
+ // Create and insert 'var tint_continue : bool = false;' before the
+ // switch.
+ auto var_name = b.Symbols().New("tint_continue");
+ auto* decl = b.Decl(b.Var(var_name, b.ty.bool_(), b.Expr(false)));
+ auto ip = utils::GetInsertionPoint(ctx, switch_stmt);
+ ctx.InsertBefore(ip.first->Declaration()->statements, ip.second, decl);
+
+ // Create and insert 'if (tint_continue) { continue; }' after
+ // switch.
+ auto* if_stmt = b.If(b.Expr(var_name), b.Block(b.Continue()));
+ ctx.InsertAfter(ip.first->Declaration()->statements, ip.second, if_stmt);
+
+ // Return the new var name
+ return var_name;
+ });
+
+ // Replace 'continue;' with '{ tint_continue = true; break; }'
+ auto* new_stmt = b.Block( //
+ b.Assign(b.Expr(cont_var_name), true), //
+ b.Break());
+
+ ctx.Replace(cont, new_stmt);
+ }
+
+ ctx.Clone();
+ }
};
} // namespace
@@ -130,16 +124,13 @@
RemoveContinueInSwitch::RemoveContinueInSwitch() = default;
RemoveContinueInSwitch::~RemoveContinueInSwitch() = default;
-bool RemoveContinueInSwitch::ShouldRun(const Program* program,
- const DataMap& /*data*/) const {
- return State::ShouldRun(program);
+bool RemoveContinueInSwitch::ShouldRun(const Program* program, const DataMap& /*data*/) const {
+ return State::ShouldRun(program);
}
-void RemoveContinueInSwitch::Run(CloneContext& ctx,
- const DataMap&,
- DataMap&) const {
- State state(ctx);
- state.Run();
+void RemoveContinueInSwitch::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
+ State state(ctx);
+ state.Run();
}
} // namespace tint::transform
diff --git a/src/tint/transform/remove_continue_in_switch.h b/src/tint/transform/remove_continue_in_switch.h
index f875660..e706225 100644
--- a/src/tint/transform/remove_continue_in_switch.h
+++ b/src/tint/transform/remove_continue_in_switch.h
@@ -23,31 +23,27 @@
/// bool variable, and checking if the variable is set after the switch to
/// continue. It is necessary to work around FXC "error X3708: continue cannot
/// be used in a switch". See crbug.com/tint/1080.
-class RemoveContinueInSwitch
- : public Castable<RemoveContinueInSwitch, Transform> {
- public:
- /// Constructor
- RemoveContinueInSwitch();
+class RemoveContinueInSwitch : public Castable<RemoveContinueInSwitch, Transform> {
+ public:
+ /// Constructor
+ RemoveContinueInSwitch();
- /// Destructor
- ~RemoveContinueInSwitch() override;
+ /// Destructor
+ ~RemoveContinueInSwitch() override;
- protected:
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program,
- const DataMap& data = {}) const override;
+ protected:
+ /// @param program the program to inspect
+ /// @param data optional extra transform-specific input data
+ /// @returns true if this transform should be run for the given program
+ bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx,
- const DataMap& inputs,
- DataMap& outputs) const override;
+ /// Runs the transform using the CloneContext built for transforming a
+ /// program. Run() is responsible for calling Clone() on the CloneContext.
+ /// @param ctx the CloneContext primed with the input program and
+ /// ProgramBuilder
+ /// @param inputs optional extra transform-specific input data
+ /// @param outputs optional extra transform-specific output data
+ void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/remove_continue_in_switch_test.cc b/src/tint/transform/remove_continue_in_switch_test.cc
index 70f167d..a1e7b6e 100644
--- a/src/tint/transform/remove_continue_in_switch_test.cc
+++ b/src/tint/transform/remove_continue_in_switch_test.cc
@@ -21,7 +21,7 @@
using RemoveContinueInSwitchTest = TransformTest;
TEST_F(RemoveContinueInSwitchTest, ShouldRun_True) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var i = 0;
loop {
@@ -39,17 +39,17 @@
}
)";
- EXPECT_TRUE(ShouldRun<RemoveContinueInSwitch>(src));
+ EXPECT_TRUE(ShouldRun<RemoveContinueInSwitch>(src));
}
TEST_F(RemoveContinueInSwitchTest, ShouldRunEmptyModule_False) {
- auto* src = "";
+ auto* src = "";
- EXPECT_FALSE(ShouldRun<RemoveContinueInSwitch>(src));
+ EXPECT_FALSE(ShouldRun<RemoveContinueInSwitch>(src));
}
TEST_F(RemoveContinueInSwitchTest, ShouldRunContinueNotInSwitch_False) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var i = 0;
loop {
@@ -70,11 +70,11 @@
}
)";
- EXPECT_FALSE(ShouldRun<RemoveContinueInSwitch>(src));
+ EXPECT_FALSE(ShouldRun<RemoveContinueInSwitch>(src));
}
TEST_F(RemoveContinueInSwitchTest, ShouldRunContinueInLoopInSwitch_False) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var i = 0;
switch(i) {
@@ -94,21 +94,21 @@
}
)";
- EXPECT_FALSE(ShouldRun<RemoveContinueInSwitch>(src));
+ EXPECT_FALSE(ShouldRun<RemoveContinueInSwitch>(src));
}
TEST_F(RemoveContinueInSwitchTest, EmptyModule) {
- auto* src = "";
- auto* expect = src;
+ auto* src = "";
+ auto* expect = src;
- DataMap data;
- auto got = Run<RemoveContinueInSwitch>(src, data);
+ DataMap data;
+ auto got = Run<RemoveContinueInSwitch>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveContinueInSwitchTest, SingleContinue) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var i = 0;
loop {
@@ -132,7 +132,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var i = 0;
loop {
@@ -163,14 +163,14 @@
}
)";
- DataMap data;
- auto got = Run<RemoveContinueInSwitch>(src, data);
+ DataMap data;
+ auto got = Run<RemoveContinueInSwitch>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveContinueInSwitchTest, MultipleContinues) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var i = 0;
loop {
@@ -202,7 +202,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var i = 0;
loop {
@@ -247,14 +247,14 @@
}
)";
- DataMap data;
- auto got = Run<RemoveContinueInSwitch>(src, data);
+ DataMap data;
+ auto got = Run<RemoveContinueInSwitch>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveContinueInSwitchTest, MultipleSwitch) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var i = 0;
loop {
@@ -287,7 +287,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var i = 0;
loop {
@@ -332,14 +332,14 @@
}
)";
- DataMap data;
- auto got = Run<RemoveContinueInSwitch>(src, data);
+ DataMap data;
+ auto got = Run<RemoveContinueInSwitch>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveContinueInSwitchTest, NestedLoopSwitch) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var i = 0;
loop {
@@ -374,7 +374,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var i = 0;
loop {
@@ -423,14 +423,14 @@
}
)";
- DataMap data;
- auto got = Run<RemoveContinueInSwitch>(src, data);
+ DataMap data;
+ auto got = Run<RemoveContinueInSwitch>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveContinueInSwitchTest, ExtraScopes) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var i = 0;
var a = true;
@@ -462,7 +462,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var i = 0;
var a = true;
@@ -501,14 +501,14 @@
}
)";
- DataMap data;
- auto got = Run<RemoveContinueInSwitch>(src, data);
+ DataMap data;
+ auto got = Run<RemoveContinueInSwitch>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveContinueInSwitchTest, ForLoop) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
for (var i = 0; i < 4; i = i + 1) {
let marker1 = 0;
@@ -527,7 +527,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
for(var i = 0; (i < 4); i = (i + 1)) {
let marker1 = 0;
@@ -553,10 +553,10 @@
}
)";
- DataMap data;
- auto got = Run<RemoveContinueInSwitch>(src, data);
+ DataMap data;
+ auto got = Run<RemoveContinueInSwitch>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
} // namespace
diff --git a/src/tint/transform/remove_phonies.cc b/src/tint/transform/remove_phonies.cc
index 9904e34..dc5c092 100644
--- a/src/tint/transform/remove_phonies.cc
+++ b/src/tint/transform/remove_phonies.cc
@@ -34,31 +34,31 @@
namespace {
struct SinkSignature {
- std::vector<const sem::Type*> types;
+ std::vector<const sem::Type*> types;
- bool operator==(const SinkSignature& other) const {
- if (types.size() != other.types.size()) {
- return false;
+ bool operator==(const SinkSignature& other) const {
+ if (types.size() != other.types.size()) {
+ return false;
+ }
+ for (size_t i = 0; i < types.size(); i++) {
+ if (types[i] != other.types[i]) {
+ return false;
+ }
+ }
+ return true;
}
- for (size_t i = 0; i < types.size(); i++) {
- if (types[i] != other.types[i]) {
- return false;
- }
- }
- return true;
- }
- struct Hasher {
- /// @param sig the CallTargetSignature to hash
- /// @return the hash value
- std::size_t operator()(const SinkSignature& sig) const {
- size_t hash = tint::utils::Hash(sig.types.size());
- for (auto* ty : sig.types) {
- tint::utils::HashCombine(&hash, ty);
- }
- return hash;
- }
- };
+ struct Hasher {
+ /// @param sig the CallTargetSignature to hash
+ /// @return the hash value
+ std::size_t operator()(const SinkSignature& sig) const {
+ size_t hash = tint::utils::Hash(sig.types.size());
+ for (auto* ty : sig.types) {
+ tint::utils::HashCombine(&hash, ty);
+ }
+ return hash;
+ }
+ };
};
} // namespace
@@ -68,87 +68,83 @@
RemovePhonies::~RemovePhonies() = default;
bool RemovePhonies::ShouldRun(const Program* program, const DataMap&) const {
- for (auto* node : program->ASTNodes().Objects()) {
- if (node->Is<ast::PhonyExpression>()) {
- return true;
+ for (auto* node : program->ASTNodes().Objects()) {
+ if (node->Is<ast::PhonyExpression>()) {
+ return true;
+ }
}
- }
- return false;
+ return false;
}
void RemovePhonies::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
- auto& sem = ctx.src->Sem();
+ auto& sem = ctx.src->Sem();
- std::unordered_map<SinkSignature, Symbol, SinkSignature::Hasher> sinks;
+ std::unordered_map<SinkSignature, Symbol, SinkSignature::Hasher> sinks;
- for (auto* node : ctx.src->ASTNodes().Objects()) {
- if (auto* stmt = node->As<ast::AssignmentStatement>()) {
- if (stmt->lhs->Is<ast::PhonyExpression>()) {
- std::vector<const ast::Expression*> side_effects;
- if (!ast::TraverseExpressions(
- stmt->rhs, ctx.dst->Diagnostics(),
- [&](const ast::CallExpression* call) {
- // ast::CallExpression may map to a function or builtin call
- // (both may have side-effects), or a type constructor or
- // type conversion (both do not have side effects).
- if (sem.Get(call)
- ->Target()
- ->IsAnyOf<sem::Function, sem::Builtin>()) {
- side_effects.push_back(call);
- return ast::TraverseAction::Skip;
- }
- return ast::TraverseAction::Descend;
- })) {
- return;
- }
+ for (auto* node : ctx.src->ASTNodes().Objects()) {
+ if (auto* stmt = node->As<ast::AssignmentStatement>()) {
+ if (stmt->lhs->Is<ast::PhonyExpression>()) {
+ std::vector<const ast::Expression*> side_effects;
+ if (!ast::TraverseExpressions(
+ stmt->rhs, ctx.dst->Diagnostics(), [&](const ast::CallExpression* call) {
+ // ast::CallExpression may map to a function or builtin call
+ // (both may have side-effects), or a type constructor or
+ // type conversion (both do not have side effects).
+ if (sem.Get(call)->Target()->IsAnyOf<sem::Function, sem::Builtin>()) {
+ side_effects.push_back(call);
+ return ast::TraverseAction::Skip;
+ }
+ return ast::TraverseAction::Descend;
+ })) {
+ return;
+ }
- if (side_effects.empty()) {
- // Phony assignment with no side effects.
- // Just remove it.
- RemoveStatement(ctx, stmt);
- continue;
- }
+ if (side_effects.empty()) {
+ // Phony assignment with no side effects.
+ // Just remove it.
+ RemoveStatement(ctx, stmt);
+ continue;
+ }
- if (side_effects.size() == 1) {
- if (auto* call = side_effects[0]->As<ast::CallExpression>()) {
- // Phony assignment with single call side effect.
- // Replace phony assignment with call.
- ctx.Replace(
- stmt, [&, call] { return ctx.dst->CallStmt(ctx.Clone(call)); });
- continue;
- }
- }
+ if (side_effects.size() == 1) {
+ if (auto* call = side_effects[0]->As<ast::CallExpression>()) {
+ // Phony assignment with single call side effect.
+ // Replace phony assignment with call.
+ ctx.Replace(stmt, [&, call] { return ctx.dst->CallStmt(ctx.Clone(call)); });
+ continue;
+ }
+ }
- // Phony assignment with multiple side effects.
- // Generate a call to a placeholder function with the side
- // effects as arguments.
- ctx.Replace(stmt, [&, side_effects] {
- SinkSignature sig;
- for (auto* arg : side_effects) {
- sig.types.push_back(sem.Get(arg)->Type()->UnwrapRef());
- }
- auto sink = utils::GetOrCreate(sinks, sig, [&] {
- auto name = ctx.dst->Symbols().New("phony_sink");
- ast::VariableList params;
- for (auto* ty : sig.types) {
- auto* ast_ty = CreateASTTypeFor(ctx, ty);
- params.push_back(
- ctx.dst->Param("p" + std::to_string(params.size()), ast_ty));
+ // Phony assignment with multiple side effects.
+ // Generate a call to a placeholder function with the side
+ // effects as arguments.
+ ctx.Replace(stmt, [&, side_effects] {
+ SinkSignature sig;
+ for (auto* arg : side_effects) {
+ sig.types.push_back(sem.Get(arg)->Type()->UnwrapRef());
+ }
+ auto sink = utils::GetOrCreate(sinks, sig, [&] {
+ auto name = ctx.dst->Symbols().New("phony_sink");
+ ast::VariableList params;
+ for (auto* ty : sig.types) {
+ auto* ast_ty = CreateASTTypeFor(ctx, ty);
+ params.push_back(
+ ctx.dst->Param("p" + std::to_string(params.size()), ast_ty));
+ }
+ ctx.dst->Func(name, params, ctx.dst->ty.void_(), {});
+ return name;
+ });
+ ast::ExpressionList args;
+ for (auto* arg : side_effects) {
+ args.push_back(ctx.Clone(arg));
+ }
+ return ctx.dst->CallStmt(ctx.dst->Call(sink, args));
+ });
}
- ctx.dst->Func(name, params, ctx.dst->ty.void_(), {});
- return name;
- });
- ast::ExpressionList args;
- for (auto* arg : side_effects) {
- args.push_back(ctx.Clone(arg));
- }
- return ctx.dst->CallStmt(ctx.dst->Call(sink, args));
- });
- }
+ }
}
- }
- ctx.Clone();
+ ctx.Clone();
}
} // namespace tint::transform
diff --git a/src/tint/transform/remove_phonies.h b/src/tint/transform/remove_phonies.h
index 6e355f1..20128a0 100644
--- a/src/tint/transform/remove_phonies.h
+++ b/src/tint/transform/remove_phonies.h
@@ -26,29 +26,26 @@
/// while preserving function call expressions in the RHS of the assignment that
/// may have side-effects.
class RemovePhonies : public Castable<RemovePhonies, Transform> {
- public:
- /// Constructor
- RemovePhonies();
+ public:
+ /// Constructor
+ RemovePhonies();
- /// Destructor
- ~RemovePhonies() override;
+ /// Destructor
+ ~RemovePhonies() override;
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program,
- const DataMap& data = {}) const override;
+ /// @param program the program to inspect
+ /// @param data optional extra transform-specific input data
+ /// @returns true if this transform should be run for the given program
+ bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx,
- const DataMap& inputs,
- DataMap& outputs) const override;
+ protected:
+ /// Runs the transform using the CloneContext built for transforming a
+ /// program. Run() is responsible for calling Clone() on the CloneContext.
+ /// @param ctx the CloneContext primed with the input program and
+ /// ProgramBuilder
+ /// @param inputs optional extra transform-specific input data
+ /// @param outputs optional extra transform-specific output data
+ void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/remove_phonies_test.cc b/src/tint/transform/remove_phonies_test.cc
index e6faa3e..220f1db 100644
--- a/src/tint/transform/remove_phonies_test.cc
+++ b/src/tint/transform/remove_phonies_test.cc
@@ -26,32 +26,32 @@
using RemovePhoniesTest = TransformTest;
TEST_F(RemovePhoniesTest, ShouldRunEmptyModule) {
- auto* src = R"()";
+ auto* src = R"()";
- EXPECT_FALSE(ShouldRun<RemovePhonies>(src));
+ EXPECT_FALSE(ShouldRun<RemovePhonies>(src));
}
TEST_F(RemovePhoniesTest, ShouldRunHasPhony) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
_ = 1;
}
)";
- EXPECT_TRUE(ShouldRun<RemovePhonies>(src));
+ EXPECT_TRUE(ShouldRun<RemovePhonies>(src));
}
TEST_F(RemovePhoniesTest, EmptyModule) {
- auto* src = "";
- auto* expect = "";
+ auto* src = "";
+ auto* expect = "";
- auto got = Run<RemovePhonies>(src);
+ auto got = Run<RemovePhonies>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RemovePhoniesTest, NoSideEffects) {
- auto* src = R"(
+ auto* src = R"(
@group(0) @binding(0) var t : texture_2d<f32>;
fn f() {
@@ -68,7 +68,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@group(0) @binding(0) var t : texture_2d<f32>;
fn f() {
@@ -76,13 +76,13 @@
}
)";
- auto got = Run<RemovePhonies>(src);
+ auto got = Run<RemovePhonies>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RemovePhoniesTest, SingleSideEffects) {
- auto* src = R"(
+ auto* src = R"(
fn neg(a : i32) -> i32 {
return -(a);
}
@@ -103,7 +103,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn neg(a : i32) -> i32 {
return -(a);
}
@@ -124,13 +124,13 @@
}
)";
- auto got = Run<RemovePhonies>(src);
+ auto got = Run<RemovePhonies>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RemovePhoniesTest, SingleSideEffects_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
_ = neg(1);
_ = add(2, 3);
@@ -151,7 +151,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
neg(1);
add(2, 3);
@@ -172,13 +172,13 @@
}
)";
- auto got = Run<RemovePhonies>(src);
+ auto got = Run<RemovePhonies>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RemovePhoniesTest, MultipleSideEffects) {
- auto* src = R"(
+ auto* src = R"(
fn neg(a : i32) -> i32 {
return -(a);
}
@@ -199,7 +199,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn neg(a : i32) -> i32 {
return -(a);
}
@@ -229,13 +229,13 @@
}
)";
- auto got = Run<RemovePhonies>(src);
+ auto got = Run<RemovePhonies>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RemovePhoniesTest, MultipleSideEffects_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
_ = (1 + add(2 + add(3, 4), 5)) * add(6, 7) * neg(8);
_ = add(9, neg(10)) + neg(11);
@@ -256,7 +256,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn phony_sink(p0 : i32, p1 : i32, p2 : i32) {
}
@@ -286,13 +286,13 @@
}
)";
- auto got = Run<RemovePhonies>(src);
+ auto got = Run<RemovePhonies>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RemovePhoniesTest, ForLoop) {
- auto* src = R"(
+ auto* src = R"(
struct S {
arr : array<i32>,
};
@@ -321,7 +321,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct S {
arr : array<i32>,
}
@@ -353,13 +353,13 @@
}
)";
- auto got = Run<RemovePhonies>(src);
+ auto got = Run<RemovePhonies>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RemovePhoniesTest, ForLoop_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
for (_ = &s.arr; ;_ = &s.arr) {
break;
@@ -388,7 +388,7 @@
@group(0) @binding(0) var<storage, read_write> s : S;
)";
- auto* expect = R"(
+ auto* expect = R"(
fn phony_sink(p0 : i32, p1 : i32) {
}
@@ -420,9 +420,9 @@
@group(0) @binding(0) var<storage, read_write> s : S;
)";
- auto got = Run<RemovePhonies>(src);
+ auto got = Run<RemovePhonies>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
} // namespace
diff --git a/src/tint/transform/remove_unreachable_statements.cc b/src/tint/transform/remove_unreachable_statements.cc
index 3e13ad7..964d767 100644
--- a/src/tint/transform/remove_unreachable_statements.cc
+++ b/src/tint/transform/remove_unreachable_statements.cc
@@ -36,30 +36,27 @@
RemoveUnreachableStatements::~RemoveUnreachableStatements() = default;
-bool RemoveUnreachableStatements::ShouldRun(const Program* program,
- const DataMap&) const {
- for (auto* node : program->ASTNodes().Objects()) {
- if (auto* stmt = program->Sem().Get<sem::Statement>(node)) {
- if (!stmt->IsReachable()) {
- return true;
- }
+bool RemoveUnreachableStatements::ShouldRun(const Program* program, const DataMap&) const {
+ for (auto* node : program->ASTNodes().Objects()) {
+ if (auto* stmt = program->Sem().Get<sem::Statement>(node)) {
+ if (!stmt->IsReachable()) {
+ return true;
+ }
+ }
}
- }
- return false;
+ return false;
}
-void RemoveUnreachableStatements::Run(CloneContext& ctx,
- const DataMap&,
- DataMap&) const {
- for (auto* node : ctx.src->ASTNodes().Objects()) {
- if (auto* stmt = ctx.src->Sem().Get<sem::Statement>(node)) {
- if (!stmt->IsReachable()) {
- RemoveStatement(ctx, stmt->Declaration());
- }
+void RemoveUnreachableStatements::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
+ for (auto* node : ctx.src->ASTNodes().Objects()) {
+ if (auto* stmt = ctx.src->Sem().Get<sem::Statement>(node)) {
+ if (!stmt->IsReachable()) {
+ RemoveStatement(ctx, stmt->Declaration());
+ }
+ }
}
- }
- ctx.Clone();
+ ctx.Clone();
}
} // namespace tint::transform
diff --git a/src/tint/transform/remove_unreachable_statements.h b/src/tint/transform/remove_unreachable_statements.h
index a474efb..c75da3d 100644
--- a/src/tint/transform/remove_unreachable_statements.h
+++ b/src/tint/transform/remove_unreachable_statements.h
@@ -24,31 +24,27 @@
/// RemoveUnreachableStatements is a Transform that removes all statements
/// marked as unreachable.
-class RemoveUnreachableStatements
- : public Castable<RemoveUnreachableStatements, Transform> {
- public:
- /// Constructor
- RemoveUnreachableStatements();
+class RemoveUnreachableStatements : public Castable<RemoveUnreachableStatements, Transform> {
+ public:
+ /// Constructor
+ RemoveUnreachableStatements();
- /// Destructor
- ~RemoveUnreachableStatements() override;
+ /// Destructor
+ ~RemoveUnreachableStatements() override;
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program,
- const DataMap& data = {}) const override;
+ /// @param program the program to inspect
+ /// @param data optional extra transform-specific input data
+ /// @returns true if this transform should be run for the given program
+ bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx,
- const DataMap& inputs,
- DataMap& outputs) const override;
+ protected:
+ /// Runs the transform using the CloneContext built for transforming a
+ /// program. Run() is responsible for calling Clone() on the CloneContext.
+ /// @param ctx the CloneContext primed with the input program and
+ /// ProgramBuilder
+ /// @param inputs optional extra transform-specific input data
+ /// @param outputs optional extra transform-specific output data
+ void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/remove_unreachable_statements_test.cc b/src/tint/transform/remove_unreachable_statements_test.cc
index 43c1950..4b0a265 100644
--- a/src/tint/transform/remove_unreachable_statements_test.cc
+++ b/src/tint/transform/remove_unreachable_statements_test.cc
@@ -22,13 +22,13 @@
using RemoveUnreachableStatementsTest = TransformTest;
TEST_F(RemoveUnreachableStatementsTest, ShouldRunEmptyModule) {
- auto* src = R"()";
+ auto* src = R"()";
- EXPECT_FALSE(ShouldRun<RemoveUnreachableStatements>(src));
+ EXPECT_FALSE(ShouldRun<RemoveUnreachableStatements>(src));
}
TEST_F(RemoveUnreachableStatementsTest, ShouldRunHasNoUnreachable) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
if (true) {
var x = 1;
@@ -36,11 +36,11 @@
}
)";
- EXPECT_FALSE(ShouldRun<RemoveUnreachableStatements>(src));
+ EXPECT_FALSE(ShouldRun<RemoveUnreachableStatements>(src));
}
TEST_F(RemoveUnreachableStatementsTest, ShouldRunHasUnreachable) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
return;
if (true) {
@@ -49,20 +49,20 @@
}
)";
- EXPECT_TRUE(ShouldRun<RemoveUnreachableStatements>(src));
+ EXPECT_TRUE(ShouldRun<RemoveUnreachableStatements>(src));
}
TEST_F(RemoveUnreachableStatementsTest, EmptyModule) {
- auto* src = "";
- auto* expect = "";
+ auto* src = "";
+ auto* expect = "";
- auto got = Run<RemoveUnreachableStatements>(src);
+ auto got = Run<RemoveUnreachableStatements>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveUnreachableStatementsTest, Return) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
return;
var remove_me = 1;
@@ -72,19 +72,19 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
return;
}
)";
- auto got = Run<RemoveUnreachableStatements>(src);
+ auto got = Run<RemoveUnreachableStatements>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveUnreachableStatementsTest, NestedReturn) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
{
{
@@ -98,7 +98,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
{
{
@@ -108,13 +108,13 @@
}
)";
- auto got = Run<RemoveUnreachableStatements>(src);
+ auto got = Run<RemoveUnreachableStatements>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveUnreachableStatementsTest, Discard) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
discard;
var remove_me = 1;
@@ -124,19 +124,19 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
discard;
}
)";
- auto got = Run<RemoveUnreachableStatements>(src);
+ auto got = Run<RemoveUnreachableStatements>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveUnreachableStatementsTest, NestedDiscard) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
{
{
@@ -150,7 +150,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
{
{
@@ -160,13 +160,13 @@
}
)";
- auto got = Run<RemoveUnreachableStatements>(src);
+ auto got = Run<RemoveUnreachableStatements>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveUnreachableStatementsTest, CallToFuncWithDiscard) {
- auto* src = R"(
+ auto* src = R"(
fn DISCARD() {
discard;
}
@@ -180,7 +180,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn DISCARD() {
discard;
}
@@ -190,13 +190,13 @@
}
)";
- auto got = Run<RemoveUnreachableStatements>(src);
+ auto got = Run<RemoveUnreachableStatements>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveUnreachableStatementsTest, CallToFuncWithIfDiscard) {
- auto* src = R"(
+ auto* src = R"(
fn DISCARD() {
if (true) {
discard;
@@ -212,15 +212,15 @@
}
)";
- auto* expect = src;
+ auto* expect = src;
- auto got = Run<RemoveUnreachableStatements>(src);
+ auto got = Run<RemoveUnreachableStatements>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveUnreachableStatementsTest, IfDiscardElseDiscard) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
if (true) {
discard;
@@ -234,7 +234,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
if (true) {
discard;
@@ -244,13 +244,13 @@
}
)";
- auto got = Run<RemoveUnreachableStatements>(src);
+ auto got = Run<RemoveUnreachableStatements>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveUnreachableStatementsTest, IfDiscardElseReturn) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
if (true) {
discard;
@@ -264,7 +264,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
if (true) {
discard;
@@ -274,13 +274,13 @@
}
)";
- auto got = Run<RemoveUnreachableStatements>(src);
+ auto got = Run<RemoveUnreachableStatements>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveUnreachableStatementsTest, IfDiscard) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
if (true) {
discard;
@@ -292,15 +292,15 @@
}
)";
- auto* expect = src;
+ auto* expect = src;
- auto got = Run<RemoveUnreachableStatements>(src);
+ auto got = Run<RemoveUnreachableStatements>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveUnreachableStatementsTest, IfReturn) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
if (true) {
return;
@@ -312,15 +312,15 @@
}
)";
- auto* expect = src;
+ auto* expect = src;
- auto got = Run<RemoveUnreachableStatements>(src);
+ auto got = Run<RemoveUnreachableStatements>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveUnreachableStatementsTest, IfElseDiscard) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
if (true) {
} else {
@@ -333,15 +333,15 @@
}
)";
- auto* expect = src;
+ auto* expect = src;
- auto got = Run<RemoveUnreachableStatements>(src);
+ auto got = Run<RemoveUnreachableStatements>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveUnreachableStatementsTest, IfElseReturn) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
if (true) {
} else {
@@ -354,15 +354,15 @@
}
)";
- auto* expect = src;
+ auto* expect = src;
- auto got = Run<RemoveUnreachableStatements>(src);
+ auto got = Run<RemoveUnreachableStatements>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveUnreachableStatementsTest, LoopWithDiscard) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
loop {
var a = 1;
@@ -379,7 +379,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
loop {
var a = 1;
@@ -392,13 +392,13 @@
}
)";
- auto got = Run<RemoveUnreachableStatements>(src);
+ auto got = Run<RemoveUnreachableStatements>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveUnreachableStatementsTest, LoopWithConditionalBreak) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
loop {
var a = 1;
@@ -417,15 +417,15 @@
}
)";
- auto* expect = src;
+ auto* expect = src;
- auto got = Run<RemoveUnreachableStatements>(src);
+ auto got = Run<RemoveUnreachableStatements>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveUnreachableStatementsTest, LoopWithConditionalBreakInContinuing) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
loop {
@@ -442,15 +442,15 @@
}
)";
- auto* expect = src;
+ auto* expect = src;
- auto got = Run<RemoveUnreachableStatements>(src);
+ auto got = Run<RemoveUnreachableStatements>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveUnreachableStatementsTest, SwitchDefaultDiscard) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
switch(1) {
default: {
@@ -464,7 +464,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
switch(1) {
default: {
@@ -474,13 +474,13 @@
}
)";
- auto got = Run<RemoveUnreachableStatements>(src);
+ auto got = Run<RemoveUnreachableStatements>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveUnreachableStatementsTest, SwitchCaseReturnDefaultDiscard) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
switch(1) {
case 0: {
@@ -497,7 +497,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
switch(1) {
case 0: {
@@ -510,13 +510,13 @@
}
)";
- auto got = Run<RemoveUnreachableStatements>(src);
+ auto got = Run<RemoveUnreachableStatements>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveUnreachableStatementsTest, SwitchCaseBreakDefaultDiscard) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
switch(1) {
case 0: {
@@ -533,15 +533,15 @@
}
)";
- auto* expect = src;
+ auto* expect = src;
- auto got = Run<RemoveUnreachableStatements>(src);
+ auto got = Run<RemoveUnreachableStatements>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveUnreachableStatementsTest, SwitchCaseReturnDefaultBreak) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
switch(1) {
case 0: {
@@ -558,11 +558,11 @@
}
)";
- auto* expect = src;
+ auto* expect = src;
- auto got = Run<RemoveUnreachableStatements>(src);
+ auto got = Run<RemoveUnreachableStatements>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
} // namespace
diff --git a/src/tint/transform/renamer.cc b/src/tint/transform/renamer.cc
index 9627911..50cd781 100644
--- a/src/tint/transform/renamer.cc
+++ b/src/tint/transform/renamer.cc
@@ -1253,114 +1253,107 @@
Renamer::~Renamer() = default;
Output Renamer::Run(const Program* in, const DataMap& inputs) const {
- ProgramBuilder out;
- // Disable auto-cloning of symbols, since we want to rename them.
- CloneContext ctx(&out, in, false);
+ ProgramBuilder out;
+ // Disable auto-cloning of symbols, since we want to rename them.
+ CloneContext ctx(&out, in, false);
- // Swizzles, builtin calls and builtin structure members need to keep their
- // symbols preserved.
- std::unordered_set<const ast::IdentifierExpression*> preserve;
- for (auto* node : in->ASTNodes().Objects()) {
- if (auto* member = node->As<ast::MemberAccessorExpression>()) {
- auto* sem = in->Sem().Get(member);
- if (!sem) {
- TINT_ICE(Transform, out.Diagnostics())
- << "MemberAccessorExpression has no semantic info";
- continue;
- }
- if (sem->Is<sem::Swizzle>()) {
- preserve.emplace(member->member);
- } else if (auto* str_expr = in->Sem().Get(member->structure)) {
- if (auto* ty = str_expr->Type()->UnwrapRef()->As<sem::Struct>()) {
- if (ty->Declaration() == nullptr) { // Builtin structure
- preserve.emplace(member->member);
- }
+ // Swizzles, builtin calls and builtin structure members need to keep their
+ // symbols preserved.
+ std::unordered_set<const ast::IdentifierExpression*> preserve;
+ for (auto* node : in->ASTNodes().Objects()) {
+ if (auto* member = node->As<ast::MemberAccessorExpression>()) {
+ auto* sem = in->Sem().Get(member);
+ if (!sem) {
+ TINT_ICE(Transform, out.Diagnostics())
+ << "MemberAccessorExpression has no semantic info";
+ continue;
+ }
+ if (sem->Is<sem::Swizzle>()) {
+ preserve.emplace(member->member);
+ } else if (auto* str_expr = in->Sem().Get(member->structure)) {
+ if (auto* ty = str_expr->Type()->UnwrapRef()->As<sem::Struct>()) {
+ if (ty->Declaration() == nullptr) { // Builtin structure
+ preserve.emplace(member->member);
+ }
+ }
+ }
+ } else if (auto* call = node->As<ast::CallExpression>()) {
+ auto* sem = in->Sem().Get(call);
+ if (!sem) {
+ TINT_ICE(Transform, out.Diagnostics()) << "CallExpression has no semantic info";
+ continue;
+ }
+ if (sem->Target()->Is<sem::Builtin>()) {
+ preserve.emplace(call->target.name);
+ }
}
- }
- } else if (auto* call = node->As<ast::CallExpression>()) {
- auto* sem = in->Sem().Get(call);
- if (!sem) {
- TINT_ICE(Transform, out.Diagnostics())
- << "CallExpression has no semantic info";
- continue;
- }
- if (sem->Target()->Is<sem::Builtin>()) {
- preserve.emplace(call->target.name);
- }
- }
- }
-
- Data::Remappings remappings;
-
- Target target = Target::kAll;
- bool preserve_unicode = false;
-
- if (auto* cfg = inputs.Get<Config>()) {
- target = cfg->target;
- preserve_unicode = cfg->preserve_unicode;
- }
-
- ctx.ReplaceAll([&](Symbol sym_in) {
- auto name_in = ctx.src->Symbols().NameFor(sym_in);
- if (preserve_unicode || text::utf8::IsASCII(name_in)) {
- switch (target) {
- case Target::kAll:
- // Always rename.
- break;
- case Target::kGlslKeywords:
- if (!std::binary_search(
- kReservedKeywordsGLSL,
- kReservedKeywordsGLSL +
- sizeof(kReservedKeywordsGLSL) / sizeof(const char*),
- name_in) &&
- name_in.compare(0, 3, "gl_")) {
- // No match, just reuse the original name.
- return ctx.dst->Symbols().New(name_in);
- }
- break;
- case Target::kHlslKeywords:
- if (!std::binary_search(
- kReservedKeywordsHLSL,
- kReservedKeywordsHLSL +
- sizeof(kReservedKeywordsHLSL) / sizeof(const char*),
- name_in)) {
- // No match, just reuse the original name.
- return ctx.dst->Symbols().New(name_in);
- }
- break;
- case Target::kMslKeywords:
- if (!std::binary_search(
- kReservedKeywordsMSL,
- kReservedKeywordsMSL +
- sizeof(kReservedKeywordsMSL) / sizeof(const char*),
- name_in)) {
- // No match, just reuse the original name.
- return ctx.dst->Symbols().New(name_in);
- }
- break;
- }
}
- auto sym_out = ctx.dst->Sym();
- remappings.emplace(name_in, ctx.dst->Symbols().NameFor(sym_out));
- return sym_out;
- });
+ Data::Remappings remappings;
- ctx.ReplaceAll([&](const ast::IdentifierExpression* ident)
- -> const ast::IdentifierExpression* {
- if (preserve.count(ident)) {
- auto sym_in = ident->symbol;
- auto str = in->Symbols().NameFor(sym_in);
- auto sym_out = out.Symbols().Register(str);
- return ctx.dst->create<ast::IdentifierExpression>(
- ctx.Clone(ident->source), sym_out);
+ Target target = Target::kAll;
+ bool preserve_unicode = false;
+
+ if (auto* cfg = inputs.Get<Config>()) {
+ target = cfg->target;
+ preserve_unicode = cfg->preserve_unicode;
}
- return nullptr; // Clone ident. Uses the symbol remapping above.
- });
- ctx.Clone();
- return Output(Program(std::move(out)),
- std::make_unique<Data>(std::move(remappings)));
+ ctx.ReplaceAll([&](Symbol sym_in) {
+ auto name_in = ctx.src->Symbols().NameFor(sym_in);
+ if (preserve_unicode || text::utf8::IsASCII(name_in)) {
+ switch (target) {
+ case Target::kAll:
+ // Always rename.
+ break;
+ case Target::kGlslKeywords:
+ if (!std::binary_search(kReservedKeywordsGLSL,
+ kReservedKeywordsGLSL +
+ sizeof(kReservedKeywordsGLSL) / sizeof(const char*),
+ name_in) &&
+ name_in.compare(0, 3, "gl_")) {
+ // No match, just reuse the original name.
+ return ctx.dst->Symbols().New(name_in);
+ }
+ break;
+ case Target::kHlslKeywords:
+ if (!std::binary_search(kReservedKeywordsHLSL,
+ kReservedKeywordsHLSL +
+ sizeof(kReservedKeywordsHLSL) / sizeof(const char*),
+ name_in)) {
+ // No match, just reuse the original name.
+ return ctx.dst->Symbols().New(name_in);
+ }
+ break;
+ case Target::kMslKeywords:
+ if (!std::binary_search(kReservedKeywordsMSL,
+ kReservedKeywordsMSL +
+ sizeof(kReservedKeywordsMSL) / sizeof(const char*),
+ name_in)) {
+ // No match, just reuse the original name.
+ return ctx.dst->Symbols().New(name_in);
+ }
+ break;
+ }
+ }
+
+ auto sym_out = ctx.dst->Sym();
+ remappings.emplace(name_in, ctx.dst->Symbols().NameFor(sym_out));
+ return sym_out;
+ });
+
+ ctx.ReplaceAll([&](const ast::IdentifierExpression* ident) -> const ast::IdentifierExpression* {
+ if (preserve.count(ident)) {
+ auto sym_in = ident->symbol;
+ auto str = in->Symbols().NameFor(sym_in);
+ auto sym_out = out.Symbols().Register(str);
+ return ctx.dst->create<ast::IdentifierExpression>(ctx.Clone(ident->source), sym_out);
+ }
+ return nullptr; // Clone ident. Uses the symbol remapping above.
+ });
+ ctx.Clone();
+
+ return Output(Program(std::move(out)), std::make_unique<Data>(std::move(remappings)));
}
} // namespace tint::transform
diff --git a/src/tint/transform/renamer.h b/src/tint/transform/renamer.h
index ad37b0c..354acda 100644
--- a/src/tint/transform/renamer.h
+++ b/src/tint/transform/renamer.h
@@ -24,72 +24,72 @@
/// Renamer is a Transform that renames all the symbols in a program.
class Renamer : public Castable<Renamer, Transform> {
- public:
- /// Data is outputted by the Renamer transform.
- /// Data holds information about shader usage and constant buffer offsets.
- struct Data : public Castable<Data, transform::Data> {
- /// Remappings is a map of old symbol name to new symbol name
- using Remappings = std::unordered_map<std::string, std::string>;
+ public:
+ /// Data is outputted by the Renamer transform.
+ /// Data holds information about shader usage and constant buffer offsets.
+ struct Data : public Castable<Data, transform::Data> {
+ /// Remappings is a map of old symbol name to new symbol name
+ using Remappings = std::unordered_map<std::string, std::string>;
- /// Constructor
- /// @param remappings the symbol remappings
- explicit Data(Remappings&& remappings);
+ /// Constructor
+ /// @param remappings the symbol remappings
+ explicit Data(Remappings&& remappings);
- /// Copy constructor
- Data(const Data&);
+ /// Copy constructor
+ Data(const Data&);
+
+ /// Destructor
+ ~Data() override;
+
+ /// A map of old symbol name to new symbol name
+ const Remappings remappings;
+ };
+
+ /// Target is an enumerator of rename targets that can be used
+ enum class Target {
+ /// Rename every symbol.
+ kAll,
+ /// Only rename symbols that are reserved keywords in GLSL.
+ kGlslKeywords,
+ /// Only rename symbols that are reserved keywords in HLSL.
+ kHlslKeywords,
+ /// Only rename symbols that are reserved keywords in MSL.
+ kMslKeywords,
+ };
+
+ /// Optional configuration options for the transform.
+ /// If omitted, then the renamer will use Target::kAll.
+ struct Config : public Castable<Config, transform::Data> {
+ /// Constructor
+ /// @param tgt the targets to rename
+ /// @param keep_unicode if false, symbols with non-ascii code-points are
+ /// renamed
+ explicit Config(Target tgt, bool keep_unicode = false);
+
+ /// Copy constructor
+ Config(const Config&);
+
+ /// Destructor
+ ~Config() override;
+
+ /// The targets to rename
+ Target const target = Target::kAll;
+
+ /// If false, symbols with non-ascii code-points are renamed.
+ bool preserve_unicode = false;
+ };
+
+ /// Constructor using a the configuration provided in the input Data
+ Renamer();
/// Destructor
- ~Data() override;
+ ~Renamer() override;
- /// A map of old symbol name to new symbol name
- const Remappings remappings;
- };
-
- /// Target is an enumerator of rename targets that can be used
- enum class Target {
- /// Rename every symbol.
- kAll,
- /// Only rename symbols that are reserved keywords in GLSL.
- kGlslKeywords,
- /// Only rename symbols that are reserved keywords in HLSL.
- kHlslKeywords,
- /// Only rename symbols that are reserved keywords in MSL.
- kMslKeywords,
- };
-
- /// Optional configuration options for the transform.
- /// If omitted, then the renamer will use Target::kAll.
- struct Config : public Castable<Config, transform::Data> {
- /// Constructor
- /// @param tgt the targets to rename
- /// @param keep_unicode if false, symbols with non-ascii code-points are
- /// renamed
- explicit Config(Target tgt, bool keep_unicode = false);
-
- /// Copy constructor
- Config(const Config&);
-
- /// Destructor
- ~Config() override;
-
- /// The targets to rename
- Target const target = Target::kAll;
-
- /// If false, symbols with non-ascii code-points are renamed.
- bool preserve_unicode = false;
- };
-
- /// Constructor using a the configuration provided in the input Data
- Renamer();
-
- /// Destructor
- ~Renamer() override;
-
- /// Runs the transform on `program`, returning the transformation result.
- /// @param program the source program to transform
- /// @param data optional extra transform-specific input data
- /// @returns the transformation result
- Output Run(const Program* program, const DataMap& data = {}) const override;
+ /// Runs the transform on `program`, returning the transformation result.
+ /// @param program the source program to transform
+ /// @param data optional extra transform-specific input data
+ /// @returns the transformation result
+ Output Run(const Program* program, const DataMap& data = {}) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/renamer_test.cc b/src/tint/transform/renamer_test.cc
index e3f9458..f56971e 100644
--- a/src/tint/transform/renamer_test.cc
+++ b/src/tint/transform/renamer_test.cc
@@ -32,20 +32,20 @@
using RenamerTest = TransformTest;
TEST_F(RenamerTest, EmptyModule) {
- auto* src = "";
- auto* expect = "";
+ auto* src = "";
+ auto* expect = "";
- auto got = Run<Renamer>(src);
+ auto got = Run<Renamer>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
- auto* data = got.data.Get<Renamer::Data>();
+ auto* data = got.data.Get<Renamer::Data>();
- ASSERT_EQ(data->remappings.size(), 0u);
+ ASSERT_EQ(data->remappings.size(), 0u);
}
TEST_F(RenamerTest, BasicModuleVertexIndex) {
- auto* src = R"(
+ auto* src = R"(
fn test(vert_idx : u32) -> u32 {
return vert_idx;
}
@@ -58,7 +58,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn tint_symbol(tint_symbol_1 : u32) -> u32 {
return tint_symbol_1;
}
@@ -70,23 +70,23 @@
}
)";
- auto got = Run<Renamer>(src);
+ auto got = Run<Renamer>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
- auto* data = got.data.Get<Renamer::Data>();
+ auto* data = got.data.Get<Renamer::Data>();
- ASSERT_NE(data, nullptr);
- Renamer::Data::Remappings expected_remappings = {
- {"vert_idx", "tint_symbol_1"},
- {"test", "tint_symbol"},
- {"entry", "tint_symbol_2"},
- };
- EXPECT_THAT(data->remappings, ContainerEq(expected_remappings));
+ ASSERT_NE(data, nullptr);
+ Renamer::Data::Remappings expected_remappings = {
+ {"vert_idx", "tint_symbol_1"},
+ {"test", "tint_symbol"},
+ {"entry", "tint_symbol_2"},
+ };
+ EXPECT_THAT(data->remappings, ContainerEq(expected_remappings));
}
TEST_F(RenamerTest, PreserveSwizzles) {
- auto* src = R"(
+ auto* src = R"(
@stage(vertex)
fn entry() -> @builtin(position) vec4<f32> {
var v : vec4<f32>;
@@ -96,7 +96,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@stage(vertex)
fn tint_symbol() -> @builtin(position) vec4<f32> {
var tint_symbol_1 : vec4<f32>;
@@ -106,24 +106,24 @@
}
)";
- auto got = Run<Renamer>(src);
+ auto got = Run<Renamer>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
- auto* data = got.data.Get<Renamer::Data>();
+ auto* data = got.data.Get<Renamer::Data>();
- ASSERT_NE(data, nullptr);
- Renamer::Data::Remappings expected_remappings = {
- {"entry", "tint_symbol"},
- {"v", "tint_symbol_1"},
- {"rgba", "tint_symbol_2"},
- {"xyzw", "tint_symbol_3"},
- };
- EXPECT_THAT(data->remappings, ContainerEq(expected_remappings));
+ ASSERT_NE(data, nullptr);
+ Renamer::Data::Remappings expected_remappings = {
+ {"entry", "tint_symbol"},
+ {"v", "tint_symbol_1"},
+ {"rgba", "tint_symbol_2"},
+ {"xyzw", "tint_symbol_3"},
+ };
+ EXPECT_THAT(data->remappings, ContainerEq(expected_remappings));
}
TEST_F(RenamerTest, PreserveBuiltins) {
- auto* src = R"(
+ auto* src = R"(
@stage(vertex)
fn entry() -> @builtin(position) vec4<f32> {
var blah : vec4<f32>;
@@ -131,7 +131,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@stage(vertex)
fn tint_symbol() -> @builtin(position) vec4<f32> {
var tint_symbol_1 : vec4<f32>;
@@ -139,22 +139,22 @@
}
)";
- auto got = Run<Renamer>(src);
+ auto got = Run<Renamer>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
- auto* data = got.data.Get<Renamer::Data>();
+ auto* data = got.data.Get<Renamer::Data>();
- ASSERT_NE(data, nullptr);
- Renamer::Data::Remappings expected_remappings = {
- {"entry", "tint_symbol"},
- {"blah", "tint_symbol_1"},
- };
- EXPECT_THAT(data->remappings, ContainerEq(expected_remappings));
+ ASSERT_NE(data, nullptr);
+ Renamer::Data::Remappings expected_remappings = {
+ {"entry", "tint_symbol"},
+ {"blah", "tint_symbol_1"},
+ };
+ EXPECT_THAT(data->remappings, ContainerEq(expected_remappings));
}
TEST_F(RenamerTest, PreserveBuiltinTypes) {
- auto* src = R"(
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn entry() {
var a = modf(1.0).whole;
@@ -164,7 +164,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@stage(compute) @workgroup_size(1)
fn tint_symbol() {
var tint_symbol_1 = modf(1.0).whole;
@@ -174,41 +174,41 @@
}
)";
- auto got = Run<Renamer>(src);
+ auto got = Run<Renamer>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
- auto* data = got.data.Get<Renamer::Data>();
+ auto* data = got.data.Get<Renamer::Data>();
- ASSERT_NE(data, nullptr);
- Renamer::Data::Remappings expected_remappings = {
- {"entry", "tint_symbol"}, {"a", "tint_symbol_1"}, {"b", "tint_symbol_2"},
- {"c", "tint_symbol_3"}, {"d", "tint_symbol_4"},
- };
- EXPECT_THAT(data->remappings, ContainerEq(expected_remappings));
+ ASSERT_NE(data, nullptr);
+ Renamer::Data::Remappings expected_remappings = {
+ {"entry", "tint_symbol"}, {"a", "tint_symbol_1"}, {"b", "tint_symbol_2"},
+ {"c", "tint_symbol_3"}, {"d", "tint_symbol_4"},
+ };
+ EXPECT_THAT(data->remappings, ContainerEq(expected_remappings));
}
TEST_F(RenamerTest, PreserveUnicode) {
- auto src = R"(
+ auto src = R"(
@stage(fragment)
fn frag_main() {
var )" + std::string(kUnicodeIdentifier) +
- R"( : i32;
+ R"( : i32;
}
)";
- auto expect = src;
+ auto expect = src;
- DataMap inputs;
- inputs.Add<Renamer::Config>(Renamer::Target::kMslKeywords,
- /* preserve_unicode */ true);
- auto got = Run<Renamer>(src, inputs);
+ DataMap inputs;
+ inputs.Add<Renamer::Config>(Renamer::Target::kMslKeywords,
+ /* preserve_unicode */ true);
+ auto got = Run<Renamer>(src, inputs);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RenamerTest, AttemptSymbolCollision) {
- auto* src = R"(
+ auto* src = R"(
@stage(vertex)
fn entry() -> @builtin(position) vec4<f32> {
var tint_symbol : vec4<f32>;
@@ -218,7 +218,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@stage(vertex)
fn tint_symbol() -> @builtin(position) vec4<f32> {
var tint_symbol_1 : vec4<f32>;
@@ -228,20 +228,20 @@
}
)";
- auto got = Run<Renamer>(src);
+ auto got = Run<Renamer>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
- auto* data = got.data.Get<Renamer::Data>();
+ auto* data = got.data.Get<Renamer::Data>();
- ASSERT_NE(data, nullptr);
- Renamer::Data::Remappings expected_remappings = {
- {"entry", "tint_symbol"},
- {"tint_symbol", "tint_symbol_1"},
- {"tint_symbol_2", "tint_symbol_2"},
- {"tint_symbol_4", "tint_symbol_3"},
- };
- EXPECT_THAT(data->remappings, ContainerEq(expected_remappings));
+ ASSERT_NE(data, nullptr);
+ Renamer::Data::Remappings expected_remappings = {
+ {"entry", "tint_symbol"},
+ {"tint_symbol", "tint_symbol_1"},
+ {"tint_symbol_2", "tint_symbol_2"},
+ {"tint_symbol_4", "tint_symbol_3"},
+ };
+ EXPECT_THAT(data->remappings, ContainerEq(expected_remappings));
}
using RenamerTestGlsl = TransformTestWithParam<std::string>;
@@ -249,81 +249,81 @@
using RenamerTestMsl = TransformTestWithParam<std::string>;
TEST_P(RenamerTestGlsl, Keywords) {
- auto keyword = GetParam();
+ auto keyword = GetParam();
- auto src = R"(
+ auto src = R"(
@stage(fragment)
fn frag_main() {
var )" + keyword +
- R"( : i32;
+ R"( : i32;
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@stage(fragment)
fn frag_main() {
var tint_symbol : i32;
}
)";
- DataMap inputs;
- inputs.Add<Renamer::Config>(Renamer::Target::kGlslKeywords,
- /* preserve_unicode */ false);
- auto got = Run<Renamer>(src, inputs);
+ DataMap inputs;
+ inputs.Add<Renamer::Config>(Renamer::Target::kGlslKeywords,
+ /* preserve_unicode */ false);
+ auto got = Run<Renamer>(src, inputs);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_P(RenamerTestHlsl, Keywords) {
- auto keyword = GetParam();
+ auto keyword = GetParam();
- auto src = R"(
+ auto src = R"(
@stage(fragment)
fn frag_main() {
var )" + keyword +
- R"( : i32;
+ R"( : i32;
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@stage(fragment)
fn frag_main() {
var tint_symbol : i32;
}
)";
- DataMap inputs;
- inputs.Add<Renamer::Config>(Renamer::Target::kHlslKeywords,
- /* preserve_unicode */ false);
- auto got = Run<Renamer>(src, inputs);
+ DataMap inputs;
+ inputs.Add<Renamer::Config>(Renamer::Target::kHlslKeywords,
+ /* preserve_unicode */ false);
+ auto got = Run<Renamer>(src, inputs);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_P(RenamerTestMsl, Keywords) {
- auto keyword = GetParam();
+ auto keyword = GetParam();
- auto src = R"(
+ auto src = R"(
@stage(fragment)
fn frag_main() {
var )" + keyword +
- R"( : i32;
+ R"( : i32;
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@stage(fragment)
fn frag_main() {
var tint_symbol : i32;
}
)";
- DataMap inputs;
- inputs.Add<Renamer::Config>(Renamer::Target::kMslKeywords,
- /* preserve_unicode */ false);
- auto got = Run<Renamer>(src, inputs);
+ DataMap inputs;
+ inputs.Add<Renamer::Config>(Renamer::Target::kMslKeywords,
+ /* preserve_unicode */ false);
+ auto got = Run<Renamer>(src, inputs);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
INSTANTIATE_TEST_SUITE_P(RenamerTestGlsl,
diff --git a/src/tint/transform/robustness.cc b/src/tint/transform/robustness.cc
index 082c4ac..d946d23 100644
--- a/src/tint/transform/robustness.cc
+++ b/src/tint/transform/robustness.cc
@@ -32,255 +32,247 @@
/// State holds the current transform state
struct Robustness::State {
- /// The clone context
- CloneContext& ctx;
+ /// The clone context
+ CloneContext& ctx;
- /// Set of storage classes to not apply the transform to
- std::unordered_set<ast::StorageClass> omitted_classes;
+ /// Set of storage classes to not apply the transform to
+ std::unordered_set<ast::StorageClass> omitted_classes;
- /// Applies the transformation state to `ctx`.
- void Transform() {
- ctx.ReplaceAll([&](const ast::IndexAccessorExpression* expr) {
- return Transform(expr);
- });
- ctx.ReplaceAll(
- [&](const ast::CallExpression* expr) { return Transform(expr); });
- }
-
- /// Apply bounds clamping to array, vector and matrix indexing
- /// @param expr the array, vector or matrix index expression
- /// @return the clamped replacement expression, or nullptr if `expr` should be
- /// cloned without changes.
- const ast::IndexAccessorExpression* Transform(
- const ast::IndexAccessorExpression* expr) {
- auto* ret_type = ctx.src->Sem().Get(expr->object)->Type();
-
- auto* ref = ret_type->As<sem::Reference>();
- if (ref && omitted_classes.count(ref->StorageClass()) != 0) {
- return nullptr;
+ /// Applies the transformation state to `ctx`.
+ void Transform() {
+ ctx.ReplaceAll([&](const ast::IndexAccessorExpression* expr) { return Transform(expr); });
+ ctx.ReplaceAll([&](const ast::CallExpression* expr) { return Transform(expr); });
}
- auto* ret_unwrapped = ret_type->UnwrapRef();
+ /// Apply bounds clamping to array, vector and matrix indexing
+ /// @param expr the array, vector or matrix index expression
+ /// @return the clamped replacement expression, or nullptr if `expr` should be
+ /// cloned without changes.
+ const ast::IndexAccessorExpression* Transform(const ast::IndexAccessorExpression* expr) {
+ auto* ret_type = ctx.src->Sem().Get(expr->object)->Type();
- ProgramBuilder& b = *ctx.dst;
- using u32 = ProgramBuilder::u32;
-
- struct Value {
- const ast::Expression* expr = nullptr; // If null, then is a constant
- union {
- uint32_t u32 = 0; // use if is_signed == false
- int32_t i32; // use if is_signed == true
- };
- bool is_signed = false;
- };
-
- Value size; // size of the array, vector or matrix
- size.is_signed = false; // size is always unsigned
- if (auto* vec = ret_unwrapped->As<sem::Vector>()) {
- size.u32 = vec->Width();
-
- } else if (auto* arr = ret_unwrapped->As<sem::Array>()) {
- size.u32 = arr->Count();
- } else if (auto* mat = ret_unwrapped->As<sem::Matrix>()) {
- // The row accessor would have been an embedded index accessor and already
- // handled, so we just need to do columns here.
- size.u32 = mat->columns();
- } else {
- return nullptr;
- }
-
- if (size.u32 == 0) {
- if (!ret_unwrapped->Is<sem::Array>()) {
- b.Diagnostics().add_error(diag::System::Transform,
- "invalid 0 sized non-array", expr->source);
- return nullptr;
- }
- // Runtime sized array
- auto* arr = ctx.Clone(expr->object);
- size.expr = b.Call("arrayLength", b.AddressOf(arr));
- }
-
- // Calculate the maximum possible index value (size-1u)
- // Size must be positive (non-zero), so we can safely subtract 1 here
- // without underflow.
- Value limit;
- limit.is_signed = false; // Like size, limit is always unsigned.
- if (size.expr) {
- // Dynamic size
- limit.expr = b.Sub(size.expr, 1u);
- } else {
- // Constant size
- limit.u32 = size.u32 - 1u;
- }
-
- Value idx; // index value
-
- auto* idx_sem = ctx.src->Sem().Get(expr->index);
- auto* idx_ty = idx_sem->Type()->UnwrapRef();
- if (!idx_ty->IsAnyOf<sem::I32, sem::U32>()) {
- TINT_ICE(Transform, b.Diagnostics()) << "index must be u32 or i32, got "
- << idx_sem->Type()->TypeInfo().name;
- return nullptr;
- }
-
- if (auto idx_constant = idx_sem->ConstantValue()) {
- // Constant value index
- if (idx_constant.Type()->Is<sem::I32>()) {
- idx.i32 = idx_constant.Elements()[0].i32;
- idx.is_signed = true;
- } else if (idx_constant.Type()->Is<sem::U32>()) {
- idx.u32 = idx_constant.Elements()[0].u32;
- idx.is_signed = false;
- } else {
- TINT_ICE(Transform, b.Diagnostics())
- << "unsupported constant value for accessor "
- << idx_constant.Type()->TypeInfo().name;
- return nullptr;
- }
- } else {
- // Dynamic value index
- idx.expr = ctx.Clone(expr->index);
- idx.is_signed = idx_ty->Is<sem::I32>();
- }
-
- // Clamp the index so that it cannot exceed limit.
- if (idx.expr || limit.expr) {
- // One of, or both of idx and limit are non-constant.
-
- // If the index is signed, cast it to a u32 (with clamping if constant).
- if (idx.is_signed) {
- if (idx.expr) {
- // We don't use a max(idx, 0) here, as that incurs a runtime
- // performance cost, and if the unsigned value will be clamped by
- // limit, resulting in a value between [0..limit)
- idx.expr = b.Construct<u32>(idx.expr);
- idx.is_signed = false;
- } else {
- idx.u32 = static_cast<uint32_t>(std::max(idx.i32, 0));
- idx.is_signed = false;
+ auto* ref = ret_type->As<sem::Reference>();
+ if (ref && omitted_classes.count(ref->StorageClass()) != 0) {
+ return nullptr;
}
- }
- // Convert idx and limit to expressions, so we can emit `min(idx, limit)`.
- if (!idx.expr) {
- idx.expr = b.Expr(idx.u32);
- }
- if (!limit.expr) {
- limit.expr = b.Expr(limit.u32);
- }
+ auto* ret_unwrapped = ret_type->UnwrapRef();
- // Perform the clamp with `min(idx, limit)`
- idx.expr = b.Call("min", idx.expr, limit.expr);
- } else {
- // Both idx and max are constant.
- if (idx.is_signed) {
- // The index is signed. Calculate limit as signed.
- int32_t signed_limit = static_cast<int32_t>(
- std::min<uint32_t>(limit.u32, std::numeric_limits<int32_t>::max()));
- idx.i32 = std::max(idx.i32, 0);
- idx.i32 = std::min(idx.i32, signed_limit);
- } else {
- // The index is unsigned.
- idx.u32 = std::min(idx.u32, limit.u32);
- }
+ ProgramBuilder& b = *ctx.dst;
+ using u32 = ProgramBuilder::u32;
+
+ struct Value {
+ const ast::Expression* expr = nullptr; // If null, then is a constant
+ union {
+ uint32_t u32 = 0; // use if is_signed == false
+ int32_t i32; // use if is_signed == true
+ };
+ bool is_signed = false;
+ };
+
+ Value size; // size of the array, vector or matrix
+ size.is_signed = false; // size is always unsigned
+ if (auto* vec = ret_unwrapped->As<sem::Vector>()) {
+ size.u32 = vec->Width();
+
+ } else if (auto* arr = ret_unwrapped->As<sem::Array>()) {
+ size.u32 = arr->Count();
+ } else if (auto* mat = ret_unwrapped->As<sem::Matrix>()) {
+ // The row accessor would have been an embedded index accessor and already
+ // handled, so we just need to do columns here.
+ size.u32 = mat->columns();
+ } else {
+ return nullptr;
+ }
+
+ if (size.u32 == 0) {
+ if (!ret_unwrapped->Is<sem::Array>()) {
+ b.Diagnostics().add_error(diag::System::Transform, "invalid 0 sized non-array",
+ expr->source);
+ return nullptr;
+ }
+ // Runtime sized array
+ auto* arr = ctx.Clone(expr->object);
+ size.expr = b.Call("arrayLength", b.AddressOf(arr));
+ }
+
+ // Calculate the maximum possible index value (size-1u)
+ // Size must be positive (non-zero), so we can safely subtract 1 here
+ // without underflow.
+ Value limit;
+ limit.is_signed = false; // Like size, limit is always unsigned.
+ if (size.expr) {
+ // Dynamic size
+ limit.expr = b.Sub(size.expr, 1u);
+ } else {
+ // Constant size
+ limit.u32 = size.u32 - 1u;
+ }
+
+ Value idx; // index value
+
+ auto* idx_sem = ctx.src->Sem().Get(expr->index);
+ auto* idx_ty = idx_sem->Type()->UnwrapRef();
+ if (!idx_ty->IsAnyOf<sem::I32, sem::U32>()) {
+ TINT_ICE(Transform, b.Diagnostics())
+ << "index must be u32 or i32, got " << idx_sem->Type()->TypeInfo().name;
+ return nullptr;
+ }
+
+ if (auto idx_constant = idx_sem->ConstantValue()) {
+ // Constant value index
+ if (idx_constant.Type()->Is<sem::I32>()) {
+ idx.i32 = idx_constant.Elements()[0].i32;
+ idx.is_signed = true;
+ } else if (idx_constant.Type()->Is<sem::U32>()) {
+ idx.u32 = idx_constant.Elements()[0].u32;
+ idx.is_signed = false;
+ } else {
+ TINT_ICE(Transform, b.Diagnostics()) << "unsupported constant value for accessor "
+ << idx_constant.Type()->TypeInfo().name;
+ return nullptr;
+ }
+ } else {
+ // Dynamic value index
+ idx.expr = ctx.Clone(expr->index);
+ idx.is_signed = idx_ty->Is<sem::I32>();
+ }
+
+ // Clamp the index so that it cannot exceed limit.
+ if (idx.expr || limit.expr) {
+ // One of, or both of idx and limit are non-constant.
+
+ // If the index is signed, cast it to a u32 (with clamping if constant).
+ if (idx.is_signed) {
+ if (idx.expr) {
+ // We don't use a max(idx, 0) here, as that incurs a runtime
+ // performance cost, and if the unsigned value will be clamped by
+ // limit, resulting in a value between [0..limit)
+ idx.expr = b.Construct<u32>(idx.expr);
+ idx.is_signed = false;
+ } else {
+ idx.u32 = static_cast<uint32_t>(std::max(idx.i32, 0));
+ idx.is_signed = false;
+ }
+ }
+
+ // Convert idx and limit to expressions, so we can emit `min(idx, limit)`.
+ if (!idx.expr) {
+ idx.expr = b.Expr(idx.u32);
+ }
+ if (!limit.expr) {
+ limit.expr = b.Expr(limit.u32);
+ }
+
+ // Perform the clamp with `min(idx, limit)`
+ idx.expr = b.Call("min", idx.expr, limit.expr);
+ } else {
+ // Both idx and max are constant.
+ if (idx.is_signed) {
+ // The index is signed. Calculate limit as signed.
+ int32_t signed_limit = static_cast<int32_t>(
+ std::min<uint32_t>(limit.u32, std::numeric_limits<int32_t>::max()));
+ idx.i32 = std::max(idx.i32, 0);
+ idx.i32 = std::min(idx.i32, signed_limit);
+ } else {
+ // The index is unsigned.
+ idx.u32 = std::min(idx.u32, limit.u32);
+ }
+ }
+
+ // Convert idx to an expression, so we can emit the new accessor.
+ if (!idx.expr) {
+ idx.expr = idx.is_signed ? static_cast<const ast::Expression*>(b.Expr(idx.i32))
+ : static_cast<const ast::Expression*>(b.Expr(idx.u32));
+ }
+
+ // Clone arguments outside of create() call to have deterministic ordering
+ auto src = ctx.Clone(expr->source);
+ auto* obj = ctx.Clone(expr->object);
+ return b.IndexAccessor(src, obj, idx.expr);
}
- // Convert idx to an expression, so we can emit the new accessor.
- if (!idx.expr) {
- idx.expr = idx.is_signed
- ? static_cast<const ast::Expression*>(b.Expr(idx.i32))
- : static_cast<const ast::Expression*>(b.Expr(idx.u32));
+ /// @param type builtin type
+ /// @returns true if the given builtin is a texture function that requires
+ /// argument clamping,
+ bool TextureBuiltinNeedsClamping(sem::BuiltinType type) {
+ return type == sem::BuiltinType::kTextureLoad || type == sem::BuiltinType::kTextureStore;
}
- // Clone arguments outside of create() call to have deterministic ordering
- auto src = ctx.Clone(expr->source);
- auto* obj = ctx.Clone(expr->object);
- return b.IndexAccessor(src, obj, idx.expr);
- }
+ /// Apply bounds clamping to the coordinates, array index and level arguments
+ /// of the `textureLoad()` and `textureStore()` builtins.
+ /// @param expr the builtin call expression
+ /// @return the clamped replacement call expression, or nullptr if `expr`
+ /// should be cloned without changes.
+ const ast::CallExpression* Transform(const ast::CallExpression* expr) {
+ auto* call = ctx.src->Sem().Get(expr);
+ auto* call_target = call->Target();
+ auto* builtin = call_target->As<sem::Builtin>();
+ if (!builtin || !TextureBuiltinNeedsClamping(builtin->Type())) {
+ return nullptr; // No transform, just clone.
+ }
- /// @param type builtin type
- /// @returns true if the given builtin is a texture function that requires
- /// argument clamping,
- bool TextureBuiltinNeedsClamping(sem::BuiltinType type) {
- return type == sem::BuiltinType::kTextureLoad ||
- type == sem::BuiltinType::kTextureStore;
- }
+ ProgramBuilder& b = *ctx.dst;
- /// Apply bounds clamping to the coordinates, array index and level arguments
- /// of the `textureLoad()` and `textureStore()` builtins.
- /// @param expr the builtin call expression
- /// @return the clamped replacement call expression, or nullptr if `expr`
- /// should be cloned without changes.
- const ast::CallExpression* Transform(const ast::CallExpression* expr) {
- auto* call = ctx.src->Sem().Get(expr);
- auto* call_target = call->Target();
- auto* builtin = call_target->As<sem::Builtin>();
- if (!builtin || !TextureBuiltinNeedsClamping(builtin->Type())) {
- return nullptr; // No transform, just clone.
+ // Indices of the mandatory texture and coords parameters, and the optional
+ // array and level parameters.
+ auto& signature = builtin->Signature();
+ auto texture_idx = signature.IndexOf(sem::ParameterUsage::kTexture);
+ auto coords_idx = signature.IndexOf(sem::ParameterUsage::kCoords);
+ auto array_idx = signature.IndexOf(sem::ParameterUsage::kArrayIndex);
+ auto level_idx = signature.IndexOf(sem::ParameterUsage::kLevel);
+
+ auto* texture_arg = expr->args[texture_idx];
+ auto* coords_arg = expr->args[coords_idx];
+ auto* coords_ty = builtin->Parameters()[coords_idx]->Type();
+
+ // If the level is provided, then we need to clamp this. As the level is
+ // used by textureDimensions() and the texture[Load|Store]() calls, we need
+ // to clamp both usages.
+ // TODO(bclayton): We probably want to place this into a let so that the
+ // calculation can be reused. This is fiddly to get right.
+ std::function<const ast::Expression*()> level_arg;
+ if (level_idx >= 0) {
+ level_arg = [&] {
+ auto* arg = expr->args[level_idx];
+ auto* num_levels = b.Call("textureNumLevels", ctx.Clone(texture_arg));
+ auto* zero = b.Expr(0);
+ auto* max = ctx.dst->Sub(num_levels, 1);
+ auto* clamped = b.Call("clamp", ctx.Clone(arg), zero, max);
+ return clamped;
+ };
+ }
+
+ // Clamp the coordinates argument
+ {
+ auto* texture_dims =
+ level_arg ? b.Call("textureDimensions", ctx.Clone(texture_arg), level_arg())
+ : b.Call("textureDimensions", ctx.Clone(texture_arg));
+ auto* zero = b.Construct(CreateASTTypeFor(ctx, coords_ty));
+ auto* max =
+ ctx.dst->Sub(texture_dims, b.Construct(CreateASTTypeFor(ctx, coords_ty), 1));
+ auto* clamped_coords = b.Call("clamp", ctx.Clone(coords_arg), zero, max);
+ ctx.Replace(coords_arg, clamped_coords);
+ }
+
+ // Clamp the array_index argument, if provided
+ if (array_idx >= 0) {
+ auto* arg = expr->args[array_idx];
+ auto* num_layers = b.Call("textureNumLayers", ctx.Clone(texture_arg));
+ auto* zero = b.Expr(0);
+ auto* max = ctx.dst->Sub(num_layers, 1);
+ auto* clamped = b.Call("clamp", ctx.Clone(arg), zero, max);
+ ctx.Replace(arg, clamped);
+ }
+
+ // Clamp the level argument, if provided
+ if (level_idx >= 0) {
+ auto* arg = expr->args[level_idx];
+ ctx.Replace(arg, level_arg ? level_arg() : ctx.dst->Expr(0));
+ }
+
+ return nullptr; // Clone, which will use the argument replacements above.
}
-
- ProgramBuilder& b = *ctx.dst;
-
- // Indices of the mandatory texture and coords parameters, and the optional
- // array and level parameters.
- auto& signature = builtin->Signature();
- auto texture_idx = signature.IndexOf(sem::ParameterUsage::kTexture);
- auto coords_idx = signature.IndexOf(sem::ParameterUsage::kCoords);
- auto array_idx = signature.IndexOf(sem::ParameterUsage::kArrayIndex);
- auto level_idx = signature.IndexOf(sem::ParameterUsage::kLevel);
-
- auto* texture_arg = expr->args[texture_idx];
- auto* coords_arg = expr->args[coords_idx];
- auto* coords_ty = builtin->Parameters()[coords_idx]->Type();
-
- // If the level is provided, then we need to clamp this. As the level is
- // used by textureDimensions() and the texture[Load|Store]() calls, we need
- // to clamp both usages.
- // TODO(bclayton): We probably want to place this into a let so that the
- // calculation can be reused. This is fiddly to get right.
- std::function<const ast::Expression*()> level_arg;
- if (level_idx >= 0) {
- level_arg = [&] {
- auto* arg = expr->args[level_idx];
- auto* num_levels = b.Call("textureNumLevels", ctx.Clone(texture_arg));
- auto* zero = b.Expr(0);
- auto* max = ctx.dst->Sub(num_levels, 1);
- auto* clamped = b.Call("clamp", ctx.Clone(arg), zero, max);
- return clamped;
- };
- }
-
- // Clamp the coordinates argument
- {
- auto* texture_dims =
- level_arg
- ? b.Call("textureDimensions", ctx.Clone(texture_arg), level_arg())
- : b.Call("textureDimensions", ctx.Clone(texture_arg));
- auto* zero = b.Construct(CreateASTTypeFor(ctx, coords_ty));
- auto* max = ctx.dst->Sub(
- texture_dims, b.Construct(CreateASTTypeFor(ctx, coords_ty), 1));
- auto* clamped_coords = b.Call("clamp", ctx.Clone(coords_arg), zero, max);
- ctx.Replace(coords_arg, clamped_coords);
- }
-
- // Clamp the array_index argument, if provided
- if (array_idx >= 0) {
- auto* arg = expr->args[array_idx];
- auto* num_layers = b.Call("textureNumLayers", ctx.Clone(texture_arg));
- auto* zero = b.Expr(0);
- auto* max = ctx.dst->Sub(num_layers, 1);
- auto* clamped = b.Call("clamp", ctx.Clone(arg), zero, max);
- ctx.Replace(arg, clamped);
- }
-
- // Clamp the level argument, if provided
- if (level_idx >= 0) {
- auto* arg = expr->args[level_idx];
- ctx.Replace(arg, level_arg ? level_arg() : ctx.dst->Expr(0));
- }
-
- return nullptr; // Clone, which will use the argument replacements above.
- }
};
Robustness::Config::Config() = default;
@@ -292,27 +284,27 @@
Robustness::~Robustness() = default;
void Robustness::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const {
- Config cfg;
- if (auto* cfg_data = inputs.Get<Config>()) {
- cfg = *cfg_data;
- }
-
- std::unordered_set<ast::StorageClass> omitted_classes;
- for (auto sc : cfg.omitted_classes) {
- switch (sc) {
- case StorageClass::kUniform:
- omitted_classes.insert(ast::StorageClass::kUniform);
- break;
- case StorageClass::kStorage:
- omitted_classes.insert(ast::StorageClass::kStorage);
- break;
+ Config cfg;
+ if (auto* cfg_data = inputs.Get<Config>()) {
+ cfg = *cfg_data;
}
- }
- State state{ctx, std::move(omitted_classes)};
+ std::unordered_set<ast::StorageClass> omitted_classes;
+ for (auto sc : cfg.omitted_classes) {
+ switch (sc) {
+ case StorageClass::kUniform:
+ omitted_classes.insert(ast::StorageClass::kUniform);
+ break;
+ case StorageClass::kStorage:
+ omitted_classes.insert(ast::StorageClass::kStorage);
+ break;
+ }
+ }
- state.Transform();
- ctx.Clone();
+ State state{ctx, std::move(omitted_classes)};
+
+ state.Transform();
+ ctx.Clone();
}
} // namespace tint::transform
diff --git a/src/tint/transform/robustness.h b/src/tint/transform/robustness.h
index 79ddf08..138b48c 100644
--- a/src/tint/transform/robustness.h
+++ b/src/tint/transform/robustness.h
@@ -32,51 +32,49 @@
/// to zero and any access past the end of the array will clamp to
/// (array length - 1).
class Robustness : public Castable<Robustness, Transform> {
- public:
- /// Storage class to be skipped in the transform
- enum class StorageClass {
- kUniform,
- kStorage,
- };
+ public:
+ /// Storage class to be skipped in the transform
+ enum class StorageClass {
+ kUniform,
+ kStorage,
+ };
- /// Configuration options for the transform
- struct Config : public Castable<Config, Data> {
+ /// Configuration options for the transform
+ struct Config : public Castable<Config, Data> {
+ /// Constructor
+ Config();
+
+ /// Copy constructor
+ Config(const Config&);
+
+ /// Destructor
+ ~Config() override;
+
+ /// Assignment operator
+ /// @returns this Config
+ Config& operator=(const Config&);
+
+ /// Storage classes to omit from apply the transform to.
+ /// This allows for optimizing on hardware that provide safe accesses.
+ std::unordered_set<StorageClass> omitted_classes;
+ };
+
/// Constructor
- Config();
-
- /// Copy constructor
- Config(const Config&);
-
+ Robustness();
/// Destructor
- ~Config() override;
+ ~Robustness() override;
- /// Assignment operator
- /// @returns this Config
- Config& operator=(const Config&);
+ protected:
+ /// Runs the transform using the CloneContext built for transforming a
+ /// program. Run() is responsible for calling Clone() on the CloneContext.
+ /// @param ctx the CloneContext primed with the input program and
+ /// ProgramBuilder
+ /// @param inputs optional extra transform-specific input data
+ /// @param outputs optional extra transform-specific output data
+ void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
- /// Storage classes to omit from apply the transform to.
- /// This allows for optimizing on hardware that provide safe accesses.
- std::unordered_set<StorageClass> omitted_classes;
- };
-
- /// Constructor
- Robustness();
- /// Destructor
- ~Robustness() override;
-
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx,
- const DataMap& inputs,
- DataMap& outputs) const override;
-
- private:
- struct State;
+ private:
+ struct State;
};
} // namespace tint::transform
diff --git a/src/tint/transform/robustness_test.cc b/src/tint/transform/robustness_test.cc
index db113b2..481be31 100644
--- a/src/tint/transform/robustness_test.cc
+++ b/src/tint/transform/robustness_test.cc
@@ -22,7 +22,7 @@
using RobustnessTest = TransformTest;
TEST_F(RobustnessTest, Array_Idx_Clamp) {
- auto* src = R"(
+ auto* src = R"(
var<private> a : array<f32, 3>;
let c : u32 = 1u;
@@ -32,7 +32,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> a : array<f32, 3>;
let c : u32 = 1u;
@@ -42,13 +42,13 @@
}
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, Array_Idx_Clamp_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
let b : f32 = a[c];
}
@@ -58,7 +58,7 @@
var<private> a : array<f32, 3>;
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
let b : f32 = a[1u];
}
@@ -68,13 +68,13 @@
var<private> a : array<f32, 3>;
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, Array_Idx_Nested_Scalar) {
- auto* src = R"(
+ auto* src = R"(
var<private> a : array<f32, 3>;
var<private> b : array<i32, 5>;
@@ -86,7 +86,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> a : array<f32, 3>;
var<private> b : array<i32, 5>;
@@ -98,13 +98,13 @@
}
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, Array_Idx_Nested_Scalar_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var c : f32 = a[ b[i] ];
}
@@ -116,7 +116,7 @@
var<private> a : array<f32, 3>;
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var c : f32 = a[min(u32(b[min(i, 4u)]), 2u)];
}
@@ -128,13 +128,13 @@
var<private> a : array<f32, 3>;
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, Array_Idx_Scalar) {
- auto* src = R"(
+ auto* src = R"(
var<private> a : array<f32, 3>;
fn f() {
@@ -142,7 +142,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> a : array<f32, 3>;
fn f() {
@@ -150,13 +150,13 @@
}
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, Array_Idx_Scalar_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var b : f32 = a[1];
}
@@ -164,7 +164,7 @@
var<private> a : array<f32, 3>;
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var b : f32 = a[1];
}
@@ -172,13 +172,13 @@
var<private> a : array<f32, 3>;
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, Array_Idx_Expr) {
- auto* src = R"(
+ auto* src = R"(
var<private> a : array<f32, 3>;
var<private> c : i32;
@@ -188,7 +188,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> a : array<f32, 3>;
var<private> c : i32;
@@ -198,13 +198,13 @@
}
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, Array_Idx_Expr_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var b : f32 = a[c + 2 - 3];
}
@@ -214,7 +214,7 @@
var<private> a : array<f32, 3>;
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var b : f32 = a[min(u32(((c + 2) - 3)), 2u)];
}
@@ -224,13 +224,13 @@
var<private> a : array<f32, 3>;
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, Array_Idx_Negative) {
- auto* src = R"(
+ auto* src = R"(
var<private> a : array<f32, 3>;
fn f() {
@@ -238,7 +238,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> a : array<f32, 3>;
fn f() {
@@ -246,13 +246,13 @@
}
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, Array_Idx_Negative_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var b : f32 = a[-1];
}
@@ -260,7 +260,7 @@
var<private> a : array<f32, 3>;
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var b : f32 = a[0];
}
@@ -268,13 +268,13 @@
var<private> a : array<f32, 3>;
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, Array_Idx_OutOfBounds) {
- auto* src = R"(
+ auto* src = R"(
var<private> a : array<f32, 3>;
fn f() {
@@ -282,7 +282,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> a : array<f32, 3>;
fn f() {
@@ -290,13 +290,13 @@
}
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, Array_Idx_OutOfBounds_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var b : f32 = a[3];
}
@@ -304,7 +304,7 @@
var<private> a : array<f32, 3>;
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var b : f32 = a[2];
}
@@ -312,15 +312,15 @@
var<private> a : array<f32, 3>;
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
// TODO(crbug.com/tint/1177) - Validation currently forbids arrays larger than
// 0xffffffff. If WGSL supports 64-bit indexing, re-enable this test.
TEST_F(RobustnessTest, DISABLED_LargeArrays_Idx) {
- auto* src = R"(
+ auto* src = R"(
struct S {
a : array<f32, 0x7fffffff>,
b : array<f32>,
@@ -358,7 +358,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct S {
a : array<f32, 2147483647>,
b : array<f32>,
@@ -392,13 +392,13 @@
}
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, Vector_Idx_Scalar) {
- auto* src = R"(
+ auto* src = R"(
var<private> a : vec3<f32>;
fn f() {
@@ -406,7 +406,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> a : vec3<f32>;
fn f() {
@@ -414,13 +414,13 @@
}
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, Vector_Idx_Scalar_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var b : f32 = a[1];
}
@@ -428,7 +428,7 @@
var<private> a : vec3<f32>;
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var b : f32 = a[1];
}
@@ -436,13 +436,13 @@
var<private> a : vec3<f32>;
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, Vector_Idx_Expr) {
- auto* src = R"(
+ auto* src = R"(
var<private> a : vec3<f32>;
var<private> c : i32;
@@ -452,7 +452,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> a : vec3<f32>;
var<private> c : i32;
@@ -462,13 +462,13 @@
}
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, Vector_Idx_Expr_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var b : f32 = a[c + 2 - 3];
}
@@ -478,7 +478,7 @@
var<private> a : vec3<f32>;
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var b : f32 = a[min(u32(((c + 2) - 3)), 2u)];
}
@@ -488,13 +488,13 @@
var<private> a : vec3<f32>;
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, Vector_Swizzle_Idx_Scalar) {
- auto* src = R"(
+ auto* src = R"(
var<private> a : vec3<f32>;
fn f() {
@@ -502,7 +502,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> a : vec3<f32>;
fn f() {
@@ -510,13 +510,13 @@
}
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, Vector_Swizzle_Idx_Scalar_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var b : f32 = a.xy[2];
}
@@ -524,7 +524,7 @@
var<private> a : vec3<f32>;
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var b : f32 = a.xy[1];
}
@@ -532,13 +532,13 @@
var<private> a : vec3<f32>;
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, Vector_Swizzle_Idx_Var) {
- auto* src = R"(
+ auto* src = R"(
var<private> a : vec3<f32>;
var<private> c : i32;
@@ -548,7 +548,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> a : vec3<f32>;
var<private> c : i32;
@@ -558,13 +558,13 @@
}
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, Vector_Swizzle_Idx_Var_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var b : f32 = a.xy[c];
}
@@ -574,7 +574,7 @@
var<private> a : vec3<f32>;
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var b : f32 = a.xy[min(u32(c), 1u)];
}
@@ -584,13 +584,13 @@
var<private> a : vec3<f32>;
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, Vector_Swizzle_Idx_Expr) {
- auto* src = R"(
+ auto* src = R"(
var<private> a : vec3<f32>;
var<private> c : i32;
@@ -600,7 +600,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> a : vec3<f32>;
var<private> c : i32;
@@ -610,13 +610,13 @@
}
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, Vector_Swizzle_Idx_Expr_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var b : f32 = a.xy[c + 2 - 3];
}
@@ -626,7 +626,7 @@
var<private> a : vec3<f32>;
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var b : f32 = a.xy[min(u32(((c + 2) - 3)), 1u)];
}
@@ -636,13 +636,13 @@
var<private> a : vec3<f32>;
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, Vector_Idx_Negative) {
- auto* src = R"(
+ auto* src = R"(
var<private> a : vec3<f32>;
fn f() {
@@ -650,7 +650,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> a : vec3<f32>;
fn f() {
@@ -658,13 +658,13 @@
}
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, Vector_Idx_Negative_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var b : f32 = a[-1];
}
@@ -672,7 +672,7 @@
var<private> a : vec3<f32>;
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var b : f32 = a[0];
}
@@ -680,13 +680,13 @@
var<private> a : vec3<f32>;
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, Vector_Idx_OutOfBounds) {
- auto* src = R"(
+ auto* src = R"(
var<private> a : vec3<f32>;
fn f() {
@@ -694,7 +694,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> a : vec3<f32>;
fn f() {
@@ -702,13 +702,13 @@
}
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, Vector_Idx_OutOfBounds_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var b : f32 = a[3];
}
@@ -716,7 +716,7 @@
var<private> a : vec3<f32>;
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var b : f32 = a[2];
}
@@ -724,13 +724,13 @@
var<private> a : vec3<f32>;
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, Matrix_Idx_Scalar) {
- auto* src = R"(
+ auto* src = R"(
var<private> a : mat3x2<f32>;
fn f() {
@@ -738,7 +738,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> a : mat3x2<f32>;
fn f() {
@@ -746,13 +746,13 @@
}
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, Matrix_Idx_Scalar_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var b : f32 = a[2][1];
}
@@ -760,7 +760,7 @@
var<private> a : mat3x2<f32>;
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var b : f32 = a[2][1];
}
@@ -768,13 +768,13 @@
var<private> a : mat3x2<f32>;
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, Matrix_Idx_Expr_Column) {
- auto* src = R"(
+ auto* src = R"(
var<private> a : mat3x2<f32>;
var<private> c : i32;
@@ -784,7 +784,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> a : mat3x2<f32>;
var<private> c : i32;
@@ -794,13 +794,13 @@
}
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, Matrix_Idx_Expr_Column_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var b : f32 = a[c + 2 - 3][1];
}
@@ -810,7 +810,7 @@
var<private> a : mat3x2<f32>;
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var b : f32 = a[min(u32(((c + 2) - 3)), 2u)][1];
}
@@ -820,13 +820,13 @@
var<private> a : mat3x2<f32>;
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, Matrix_Idx_Expr_Row) {
- auto* src = R"(
+ auto* src = R"(
var<private> a : mat3x2<f32>;
var<private> c : i32;
@@ -836,7 +836,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> a : mat3x2<f32>;
var<private> c : i32;
@@ -846,13 +846,13 @@
}
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, Matrix_Idx_Expr_Row_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var b : f32 = a[1][c + 2 - 3];
}
@@ -862,7 +862,7 @@
var<private> a : mat3x2<f32>;
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var b : f32 = a[1][min(u32(((c + 2) - 3)), 1u)];
}
@@ -872,13 +872,13 @@
var<private> a : mat3x2<f32>;
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, Matrix_Idx_Negative_Column) {
- auto* src = R"(
+ auto* src = R"(
var<private> a : mat3x2<f32>;
fn f() {
@@ -886,7 +886,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> a : mat3x2<f32>;
fn f() {
@@ -894,13 +894,13 @@
}
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, Matrix_Idx_Negative_Column_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var b : f32 = a[-1][1];
}
@@ -908,7 +908,7 @@
var<private> a : mat3x2<f32>;
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var b : f32 = a[0][1];
}
@@ -916,13 +916,13 @@
var<private> a : mat3x2<f32>;
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, Matrix_Idx_Negative_Row) {
- auto* src = R"(
+ auto* src = R"(
var<private> a : mat3x2<f32>;
fn f() {
@@ -930,7 +930,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> a : mat3x2<f32>;
fn f() {
@@ -938,13 +938,13 @@
}
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, Matrix_Idx_Negative_Row_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var b : f32 = a[2][-1];
}
@@ -952,7 +952,7 @@
var<private> a : mat3x2<f32>;
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var b : f32 = a[2][0];
}
@@ -960,13 +960,13 @@
var<private> a : mat3x2<f32>;
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, Matrix_Idx_OutOfBounds_Column) {
- auto* src = R"(
+ auto* src = R"(
var<private> a : mat3x2<f32>;
fn f() {
@@ -974,7 +974,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> a : mat3x2<f32>;
fn f() {
@@ -982,13 +982,13 @@
}
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, Matrix_Idx_OutOfBounds_Column_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var b : f32 = a[5][1];
}
@@ -996,7 +996,7 @@
var<private> a : mat3x2<f32>;
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var b : f32 = a[2][1];
}
@@ -1004,13 +1004,13 @@
var<private> a : mat3x2<f32>;
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, Matrix_Idx_OutOfBounds_Row) {
- auto* src = R"(
+ auto* src = R"(
var<private> a : mat3x2<f32>;
fn f() {
@@ -1018,7 +1018,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> a : mat3x2<f32>;
fn f() {
@@ -1026,13 +1026,13 @@
}
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, Matrix_Idx_OutOfBounds_Row_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var b : f32 = a[2][5];
}
@@ -1040,7 +1040,7 @@
var<private> a : mat3x2<f32>;
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var b : f32 = a[2][1];
}
@@ -1048,49 +1048,49 @@
var<private> a : mat3x2<f32>;
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
// TODO(dsinclair): Implement when constant_id exists
TEST_F(RobustnessTest, DISABLED_Vector_Constant_Id_Clamps) {
- // @id(1300) override idx : i32;
- // var a : vec3<f32>
- // var b : f32 = a[idx]
- //
- // ->var b : f32 = a[min(u32(idx), 2)]
+ // @id(1300) override idx : i32;
+ // var a : vec3<f32>
+ // var b : f32 = a[idx]
+ //
+ // ->var b : f32 = a[min(u32(idx), 2)]
}
// TODO(dsinclair): Implement when constant_id exists
TEST_F(RobustnessTest, DISABLED_Array_Constant_Id_Clamps) {
- // @id(1300) override idx : i32;
- // var a : array<f32, 4>
- // var b : f32 = a[idx]
- //
- // -> var b : f32 = a[min(u32(idx), 3)]
+ // @id(1300) override idx : i32;
+ // var a : array<f32, 4>
+ // var b : f32 = a[idx]
+ //
+ // -> var b : f32 = a[min(u32(idx), 3)]
}
// TODO(dsinclair): Implement when constant_id exists
TEST_F(RobustnessTest, DISABLED_Matrix_Column_Constant_Id_Clamps) {
- // @id(1300) override idx : i32;
- // var a : mat3x2<f32>
- // var b : f32 = a[idx][1]
- //
- // -> var b : f32 = a[min(u32(idx), 2)][1]
+ // @id(1300) override idx : i32;
+ // var a : mat3x2<f32>
+ // var b : f32 = a[idx][1]
+ //
+ // -> var b : f32 = a[min(u32(idx), 2)][1]
}
// TODO(dsinclair): Implement when constant_id exists
TEST_F(RobustnessTest, DISABLED_Matrix_Row_Constant_Id_Clamps) {
- // @id(1300) override idx : i32;
- // var a : mat3x2<f32>
- // var b : f32 = a[1][idx]
- //
- // -> var b : f32 = a[1][min(u32(idx), 0, 1)]
+ // @id(1300) override idx : i32;
+ // var a : mat3x2<f32>
+ // var b : f32 = a[1][idx]
+ //
+ // -> var b : f32 = a[1][min(u32(idx), 0, 1)]
}
TEST_F(RobustnessTest, RuntimeArray_Clamps) {
- auto* src = R"(
+ auto* src = R"(
struct S {
a : f32,
b : array<f32>,
@@ -1102,7 +1102,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct S {
a : f32,
b : array<f32>,
@@ -1115,13 +1115,13 @@
}
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, RuntimeArray_Clamps_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var d : f32 = s.b[25];
}
@@ -1134,7 +1134,7 @@
};
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var d : f32 = s.b[min(25u, (arrayLength(&(s.b)) - 1u))];
}
@@ -1147,14 +1147,14 @@
}
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
// Clamp textureLoad() coord, array_index and level values
TEST_F(RobustnessTest, TextureLoad_Clamp) {
- auto* src = R"(
+ auto* src = R"(
@group(0) @binding(0) var tex_1d : texture_1d<f32>;
@group(0) @binding(0) var tex_2d : texture_2d<f32>;
@group(0) @binding(0) var tex_2d_arr : texture_2d_array<f32>;
@@ -1180,8 +1180,8 @@
}
)";
- auto* expect =
- R"(
+ auto* expect =
+ R"(
@group(0) @binding(0) var tex_1d : texture_1d<f32>;
@group(0) @binding(0) var tex_2d : texture_2d<f32>;
@@ -1213,14 +1213,14 @@
}
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
// Clamp textureLoad() coord, array_index and level values
TEST_F(RobustnessTest, TextureLoad_Clamp_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var array_idx : i32;
var level_idx : i32;
@@ -1246,8 +1246,8 @@
@group(0) @binding(0) var tex_external : texture_external;
)";
- auto* expect =
- R"(
+ auto* expect =
+ R"(
fn f() {
var array_idx : i32;
var level_idx : i32;
@@ -1279,14 +1279,14 @@
@group(0) @binding(0) var tex_external : texture_external;
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
// Clamp textureStore() coord, array_index and level values
TEST_F(RobustnessTest, TextureStore_Clamp) {
- auto* src = R"(
+ auto* src = R"(
@group(0) @binding(0) var tex1d : texture_storage_1d<rgba8sint, write>;
@group(0) @binding(1) var tex2d : texture_storage_2d<rgba8sint, write>;
@@ -1303,7 +1303,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@group(0) @binding(0) var tex1d : texture_storage_1d<rgba8sint, write>;
@group(0) @binding(1) var tex2d : texture_storage_2d<rgba8sint, write>;
@@ -1320,14 +1320,14 @@
}
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
// Clamp textureStore() coord, array_index and level values
TEST_F(RobustnessTest, TextureStore_Clamp_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
textureStore(tex1d, 10, vec4<i32>());
textureStore(tex2d, vec2<i32>(10, 20), vec4<i32>());
@@ -1345,7 +1345,7 @@
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
textureStore(tex1d, clamp(10, i32(), (textureDimensions(tex1d) - i32(1))), vec4<i32>());
textureStore(tex2d, clamp(vec2<i32>(10, 20), vec2<i32>(), (textureDimensions(tex2d) - vec2<i32>(1))), vec4<i32>());
@@ -1362,29 +1362,29 @@
@group(0) @binding(3) var tex3d : texture_storage_3d<rgba8sint, write>;
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
// TODO(dsinclair): Test for scoped variables when shadowing is implemented
TEST_F(RobustnessTest, DISABLED_Shadowed_Variable) {
- // var a : array<f32, 3>;
- // var i : u32;
- // {
- // var a : array<f32, 5>;
- // var b : f32 = a[i];
- // }
- // var c : f32 = a[i];
- //
- // -> var b : f32 = a[min(u32(i), 4)];
- // var c : f32 = a[min(u32(i), 2)];
- FAIL();
+ // var a : array<f32, 3>;
+ // var i : u32;
+ // {
+ // var a : array<f32, 5>;
+ // var b : f32 = a[i];
+ // }
+ // var c : f32 = a[i];
+ //
+ // -> var b : f32 = a[min(u32(i), 4)];
+ // var c : f32 = a[min(u32(i), 2)];
+ FAIL();
}
// Check that existing use of min() and arrayLength() do not get renamed.
TEST_F(RobustnessTest, DontRenameSymbols) {
- auto* src = R"(
+ auto* src = R"(
struct S {
a : f32,
b : array<f32>,
@@ -1401,7 +1401,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct S {
a : f32,
b : array<f32>,
@@ -1418,9 +1418,9 @@
}
)";
- auto got = Run<Robustness>(src);
+ auto got = Run<Robustness>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
const char* kOmitSourceShader = R"(
@@ -1481,7 +1481,7 @@
)";
TEST_F(RobustnessTest, OmitNone) {
- auto* expect = R"(
+ auto* expect = R"(
struct S {
a : array<f32, 4>,
b : array<f32>,
@@ -1534,17 +1534,17 @@
}
)";
- Robustness::Config cfg;
- DataMap data;
- data.Add<Robustness::Config>(cfg);
+ Robustness::Config cfg;
+ DataMap data;
+ data.Add<Robustness::Config>(cfg);
- auto got = Run<Robustness>(kOmitSourceShader, data);
+ auto got = Run<Robustness>(kOmitSourceShader, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, OmitStorage) {
- auto* expect = R"(
+ auto* expect = R"(
struct S {
a : array<f32, 4>,
b : array<f32>,
@@ -1597,19 +1597,19 @@
}
)";
- Robustness::Config cfg;
- cfg.omitted_classes.insert(Robustness::StorageClass::kStorage);
+ Robustness::Config cfg;
+ cfg.omitted_classes.insert(Robustness::StorageClass::kStorage);
- DataMap data;
- data.Add<Robustness::Config>(cfg);
+ DataMap data;
+ data.Add<Robustness::Config>(cfg);
- auto got = Run<Robustness>(kOmitSourceShader, data);
+ auto got = Run<Robustness>(kOmitSourceShader, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, OmitUniform) {
- auto* expect = R"(
+ auto* expect = R"(
struct S {
a : array<f32, 4>,
b : array<f32>,
@@ -1662,19 +1662,19 @@
}
)";
- Robustness::Config cfg;
- cfg.omitted_classes.insert(Robustness::StorageClass::kUniform);
+ Robustness::Config cfg;
+ cfg.omitted_classes.insert(Robustness::StorageClass::kUniform);
- DataMap data;
- data.Add<Robustness::Config>(cfg);
+ DataMap data;
+ data.Add<Robustness::Config>(cfg);
- auto got = Run<Robustness>(kOmitSourceShader, data);
+ auto got = Run<Robustness>(kOmitSourceShader, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, OmitBoth) {
- auto* expect = R"(
+ auto* expect = R"(
struct S {
a : array<f32, 4>,
b : array<f32>,
@@ -1727,16 +1727,16 @@
}
)";
- Robustness::Config cfg;
- cfg.omitted_classes.insert(Robustness::StorageClass::kStorage);
- cfg.omitted_classes.insert(Robustness::StorageClass::kUniform);
+ Robustness::Config cfg;
+ cfg.omitted_classes.insert(Robustness::StorageClass::kStorage);
+ cfg.omitted_classes.insert(Robustness::StorageClass::kUniform);
- DataMap data;
- data.Add<Robustness::Config>(cfg);
+ DataMap data;
+ data.Add<Robustness::Config>(cfg);
- auto got = Run<Robustness>(kOmitSourceShader, data);
+ auto got = Run<Robustness>(kOmitSourceShader, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
} // namespace
diff --git a/src/tint/transform/simplify_pointers.cc b/src/tint/transform/simplify_pointers.cc
index d102ca4..b8f82bc 100644
--- a/src/tint/transform/simplify_pointers.cc
+++ b/src/tint/transform/simplify_pointers.cc
@@ -35,195 +35,193 @@
/// PointerOp describes either possible indirection or address-of action on an
/// expression.
struct PointerOp {
- /// Positive: Number of times the `expr` was dereferenced (*expr)
- /// Negative: Number of times the `expr` was 'addressed-of' (&expr)
- /// Zero: no pointer op on `expr`
- int indirections = 0;
- /// The expression being operated on
- const ast::Expression* expr = nullptr;
+ /// Positive: Number of times the `expr` was dereferenced (*expr)
+ /// Negative: Number of times the `expr` was 'addressed-of' (&expr)
+ /// Zero: no pointer op on `expr`
+ int indirections = 0;
+ /// The expression being operated on
+ const ast::Expression* expr = nullptr;
};
} // namespace
/// The PIMPL state for the SimplifyPointers transform
struct SimplifyPointers::State {
- /// The clone context
- CloneContext& ctx;
+ /// The clone context
+ CloneContext& ctx;
- /// Constructor
- /// @param context the clone context
- explicit State(CloneContext& context) : ctx(context) {}
+ /// Constructor
+ /// @param context the clone context
+ explicit State(CloneContext& context) : ctx(context) {}
- /// Traverses the expression `expr` looking for non-literal array indexing
- /// expressions that would affect the computed address of a pointer
- /// expression. The function-like argument `cb` is called for each found.
- /// @param expr the expression to traverse
- /// @param cb a function-like object with the signature
- /// `void(const ast::Expression*)`, which is called for each array index
- /// expression
- template <typename F>
- static void CollectSavedArrayIndices(const ast::Expression* expr, F&& cb) {
- if (auto* a = expr->As<ast::IndexAccessorExpression>()) {
- CollectSavedArrayIndices(a->object, cb);
- if (!a->index->Is<ast::LiteralExpression>()) {
- cb(a->index);
- }
- return;
- }
-
- if (auto* m = expr->As<ast::MemberAccessorExpression>()) {
- CollectSavedArrayIndices(m->structure, cb);
- return;
- }
-
- if (auto* u = expr->As<ast::UnaryOpExpression>()) {
- CollectSavedArrayIndices(u->expr, cb);
- return;
- }
-
- // Note: Other ast::Expression types can be safely ignored as they cannot be
- // used to generate a reference or pointer.
- // See https://gpuweb.github.io/gpuweb/wgsl/#forming-references-and-pointers
- }
-
- /// Reduce walks the expression chain, collapsing all address-of and
- /// indirection ops into a PointerOp.
- /// @param in the expression to walk
- /// @returns the reduced PointerOp
- PointerOp Reduce(const ast::Expression* in) const {
- PointerOp op{0, in};
- while (true) {
- if (auto* unary = op.expr->As<ast::UnaryOpExpression>()) {
- switch (unary->op) {
- case ast::UnaryOp::kIndirection:
- op.indirections++;
- op.expr = unary->expr;
- continue;
- case ast::UnaryOp::kAddressOf:
- op.indirections--;
- op.expr = unary->expr;
- continue;
- default:
- break;
- }
- }
- if (auto* user = ctx.src->Sem().Get<sem::VariableUser>(op.expr)) {
- auto* var = user->Variable();
- if (var->Is<sem::LocalVariable>() && //
- var->Declaration()->is_const && //
- var->Type()->Is<sem::Pointer>()) {
- op.expr = var->Declaration()->constructor;
- continue;
- }
- }
- return op;
- }
- }
-
- /// Performs the transformation
- void Run() {
- // A map of saved expressions to their saved variable name
- std::unordered_map<const ast::Expression*, Symbol> saved_vars;
-
- // Register the ast::Expression transform handler.
- // This performs two different transformations:
- // * Identifiers that resolve to the pointer-typed `let` declarations are
- // replaced with the recursively inlined initializer expression for the
- // `let` declaration.
- // * Sub-expressions inside the pointer-typed `let` initializer expression
- // that have been hoisted to a saved variable are replaced with the saved
- // variable identifier.
- ctx.ReplaceAll([&](const ast::Expression* expr) -> const ast::Expression* {
- // Look to see if we need to swap this Expression with a saved variable.
- auto it = saved_vars.find(expr);
- if (it != saved_vars.end()) {
- return ctx.dst->Expr(it->second);
- }
-
- // Reduce the expression, folding away chains of address-of / indirections
- auto op = Reduce(expr);
-
- // Clone the reduced root expression
- expr = ctx.CloneWithoutTransform(op.expr);
-
- // And reapply the minimum number of address-of / indirections
- for (int i = 0; i < op.indirections; i++) {
- expr = ctx.dst->Deref(expr);
- }
- for (int i = 0; i > op.indirections; i--) {
- expr = ctx.dst->AddressOf(expr);
- }
- return expr;
- });
-
- // Find all the pointer-typed `let` declarations.
- // Note that these must be function-scoped, as module-scoped `let`s are not
- // permitted.
- for (auto* node : ctx.src->ASTNodes().Objects()) {
- if (auto* let = node->As<ast::VariableDeclStatement>()) {
- if (!let->variable->is_const) {
- continue; // Not a `let` declaration. Ignore.
+ /// Traverses the expression `expr` looking for non-literal array indexing
+ /// expressions that would affect the computed address of a pointer
+ /// expression. The function-like argument `cb` is called for each found.
+ /// @param expr the expression to traverse
+ /// @param cb a function-like object with the signature
+ /// `void(const ast::Expression*)`, which is called for each array index
+ /// expression
+ template <typename F>
+ static void CollectSavedArrayIndices(const ast::Expression* expr, F&& cb) {
+ if (auto* a = expr->As<ast::IndexAccessorExpression>()) {
+ CollectSavedArrayIndices(a->object, cb);
+ if (!a->index->Is<ast::LiteralExpression>()) {
+ cb(a->index);
+ }
+ return;
}
- auto* var = ctx.src->Sem().Get(let->variable);
- if (!var->Type()->Is<sem::Pointer>()) {
- continue; // Not a pointer type. Ignore.
+ if (auto* m = expr->As<ast::MemberAccessorExpression>()) {
+ CollectSavedArrayIndices(m->structure, cb);
+ return;
}
- // We're dealing with a pointer-typed `let` declaration.
-
- // Scan the initializer expression for array index expressions that need
- // to be hoist to temporary "saved" variables.
- std::vector<const ast::VariableDeclStatement*> saved;
- CollectSavedArrayIndices(
- var->Declaration()->constructor,
- [&](const ast::Expression* idx_expr) {
- // We have a sub-expression that needs to be saved.
- // Create a new variable
- auto saved_name = ctx.dst->Symbols().New(
- ctx.src->Symbols().NameFor(var->Declaration()->symbol) +
- "_save");
- auto* decl = ctx.dst->Decl(
- ctx.dst->Let(saved_name, nullptr, ctx.Clone(idx_expr)));
- saved.emplace_back(decl);
- // Record the substitution of `idx_expr` to the saved variable
- // with the symbol `saved_name`. This will be used by the
- // ReplaceAll() handler above.
- saved_vars.emplace(idx_expr, saved_name);
- });
-
- // Find the place to insert the saved declarations.
- // Special care needs to be made for lets declared as the initializer
- // part of for-loops. In this case the block will hold the for-loop
- // statement, not the let.
- if (!saved.empty()) {
- auto* stmt = ctx.src->Sem().Get(let);
- auto* block = stmt->Block();
- // Find the statement owned by the block (either the let decl or a
- // for-loop)
- while (block != stmt->Parent()) {
- stmt = stmt->Parent();
- }
- // Declare the stored variables just before stmt. Order here is
- // important as order-of-operations needs to be preserved.
- // CollectSavedArrayIndices() visits the LHS of an index accessor
- // before the index expression.
- for (auto* decl : saved) {
- // Note that repeated calls to InsertBefore() with the same `before`
- // argument will result in nodes to inserted in the order the
- // calls are made (last call is inserted last).
- ctx.InsertBefore(block->Declaration()->statements,
- stmt->Declaration(), decl);
- }
+ if (auto* u = expr->As<ast::UnaryOpExpression>()) {
+ CollectSavedArrayIndices(u->expr, cb);
+ return;
}
- // As the original `let` declaration will be fully inlined, there's no
- // need for the original declaration to exist. Remove it.
- RemoveStatement(ctx, let);
- }
+ // Note: Other ast::Expression types can be safely ignored as they cannot be
+ // used to generate a reference or pointer.
+ // See https://gpuweb.github.io/gpuweb/wgsl/#forming-references-and-pointers
}
- ctx.Clone();
- }
+
+ /// Reduce walks the expression chain, collapsing all address-of and
+ /// indirection ops into a PointerOp.
+ /// @param in the expression to walk
+ /// @returns the reduced PointerOp
+ PointerOp Reduce(const ast::Expression* in) const {
+ PointerOp op{0, in};
+ while (true) {
+ if (auto* unary = op.expr->As<ast::UnaryOpExpression>()) {
+ switch (unary->op) {
+ case ast::UnaryOp::kIndirection:
+ op.indirections++;
+ op.expr = unary->expr;
+ continue;
+ case ast::UnaryOp::kAddressOf:
+ op.indirections--;
+ op.expr = unary->expr;
+ continue;
+ default:
+ break;
+ }
+ }
+ if (auto* user = ctx.src->Sem().Get<sem::VariableUser>(op.expr)) {
+ auto* var = user->Variable();
+ if (var->Is<sem::LocalVariable>() && //
+ var->Declaration()->is_const && //
+ var->Type()->Is<sem::Pointer>()) {
+ op.expr = var->Declaration()->constructor;
+ continue;
+ }
+ }
+ return op;
+ }
+ }
+
+ /// Performs the transformation
+ void Run() {
+ // A map of saved expressions to their saved variable name
+ std::unordered_map<const ast::Expression*, Symbol> saved_vars;
+
+ // Register the ast::Expression transform handler.
+ // This performs two different transformations:
+ // * Identifiers that resolve to the pointer-typed `let` declarations are
+ // replaced with the recursively inlined initializer expression for the
+ // `let` declaration.
+ // * Sub-expressions inside the pointer-typed `let` initializer expression
+ // that have been hoisted to a saved variable are replaced with the saved
+ // variable identifier.
+ ctx.ReplaceAll([&](const ast::Expression* expr) -> const ast::Expression* {
+ // Look to see if we need to swap this Expression with a saved variable.
+ auto it = saved_vars.find(expr);
+ if (it != saved_vars.end()) {
+ return ctx.dst->Expr(it->second);
+ }
+
+ // Reduce the expression, folding away chains of address-of / indirections
+ auto op = Reduce(expr);
+
+ // Clone the reduced root expression
+ expr = ctx.CloneWithoutTransform(op.expr);
+
+ // And reapply the minimum number of address-of / indirections
+ for (int i = 0; i < op.indirections; i++) {
+ expr = ctx.dst->Deref(expr);
+ }
+ for (int i = 0; i > op.indirections; i--) {
+ expr = ctx.dst->AddressOf(expr);
+ }
+ return expr;
+ });
+
+ // Find all the pointer-typed `let` declarations.
+ // Note that these must be function-scoped, as module-scoped `let`s are not
+ // permitted.
+ for (auto* node : ctx.src->ASTNodes().Objects()) {
+ if (auto* let = node->As<ast::VariableDeclStatement>()) {
+ if (!let->variable->is_const) {
+ continue; // Not a `let` declaration. Ignore.
+ }
+
+ auto* var = ctx.src->Sem().Get(let->variable);
+ if (!var->Type()->Is<sem::Pointer>()) {
+ continue; // Not a pointer type. Ignore.
+ }
+
+ // We're dealing with a pointer-typed `let` declaration.
+
+ // Scan the initializer expression for array index expressions that need
+ // to be hoist to temporary "saved" variables.
+ std::vector<const ast::VariableDeclStatement*> saved;
+ CollectSavedArrayIndices(
+ var->Declaration()->constructor, [&](const ast::Expression* idx_expr) {
+ // We have a sub-expression that needs to be saved.
+ // Create a new variable
+ auto saved_name = ctx.dst->Symbols().New(
+ ctx.src->Symbols().NameFor(var->Declaration()->symbol) + "_save");
+ auto* decl =
+ ctx.dst->Decl(ctx.dst->Let(saved_name, nullptr, ctx.Clone(idx_expr)));
+ saved.emplace_back(decl);
+ // Record the substitution of `idx_expr` to the saved variable
+ // with the symbol `saved_name`. This will be used by the
+ // ReplaceAll() handler above.
+ saved_vars.emplace(idx_expr, saved_name);
+ });
+
+ // Find the place to insert the saved declarations.
+ // Special care needs to be made for lets declared as the initializer
+ // part of for-loops. In this case the block will hold the for-loop
+ // statement, not the let.
+ if (!saved.empty()) {
+ auto* stmt = ctx.src->Sem().Get(let);
+ auto* block = stmt->Block();
+ // Find the statement owned by the block (either the let decl or a
+ // for-loop)
+ while (block != stmt->Parent()) {
+ stmt = stmt->Parent();
+ }
+ // Declare the stored variables just before stmt. Order here is
+ // important as order-of-operations needs to be preserved.
+ // CollectSavedArrayIndices() visits the LHS of an index accessor
+ // before the index expression.
+ for (auto* decl : saved) {
+ // Note that repeated calls to InsertBefore() with the same `before`
+ // argument will result in nodes to inserted in the order the
+ // calls are made (last call is inserted last).
+ ctx.InsertBefore(block->Declaration()->statements, stmt->Declaration(),
+ decl);
+ }
+ }
+
+ // As the original `let` declaration will be fully inlined, there's no
+ // need for the original declaration to exist. Remove it.
+ RemoveStatement(ctx, let);
+ }
+ }
+ ctx.Clone();
+ }
};
SimplifyPointers::SimplifyPointers() = default;
@@ -231,7 +229,7 @@
SimplifyPointers::~SimplifyPointers() = default;
void SimplifyPointers::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
- State(ctx).Run();
+ State(ctx).Run();
}
} // namespace tint::transform
diff --git a/src/tint/transform/simplify_pointers.h b/src/tint/transform/simplify_pointers.h
index 3bd4950..267b7b2 100644
--- a/src/tint/transform/simplify_pointers.h
+++ b/src/tint/transform/simplify_pointers.h
@@ -32,25 +32,23 @@
/// @note Depends on the following transforms to have been run first:
/// * Unshadow
class SimplifyPointers : public Castable<SimplifyPointers, Transform> {
- public:
- /// Constructor
- SimplifyPointers();
+ public:
+ /// Constructor
+ SimplifyPointers();
- /// Destructor
- ~SimplifyPointers() override;
+ /// Destructor
+ ~SimplifyPointers() override;
- protected:
- struct State;
+ protected:
+ struct State;
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx,
- const DataMap& inputs,
- DataMap& outputs) const override;
+ /// Runs the transform using the CloneContext built for transforming a
+ /// program. Run() is responsible for calling Clone() on the CloneContext.
+ /// @param ctx the CloneContext primed with the input program and
+ /// ProgramBuilder
+ /// @param inputs optional extra transform-specific input data
+ /// @param outputs optional extra transform-specific output data
+ void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/simplify_pointers_test.cc b/src/tint/transform/simplify_pointers_test.cc
index 6266b6f..f5658de 100644
--- a/src/tint/transform/simplify_pointers_test.cc
+++ b/src/tint/transform/simplify_pointers_test.cc
@@ -23,16 +23,16 @@
using SimplifyPointersTest = TransformTest;
TEST_F(SimplifyPointersTest, EmptyModule) {
- auto* src = "";
- auto* expect = "";
+ auto* src = "";
+ auto* expect = "";
- auto got = Run<Unshadow, SimplifyPointers>(src);
+ auto got = Run<Unshadow, SimplifyPointers>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(SimplifyPointersTest, FoldPointer) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var v : i32;
let p : ptr<function, i32> = &v;
@@ -40,20 +40,20 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var v : i32;
let x : i32 = v;
}
)";
- auto got = Run<Unshadow, SimplifyPointers>(src);
+ auto got = Run<Unshadow, SimplifyPointers>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(SimplifyPointersTest, AddressOfDeref) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var v : i32;
let p : ptr<function, i32> = &(v);
@@ -66,7 +66,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var v : i32;
var a = v;
@@ -75,13 +75,13 @@
}
)";
- auto got = Run<Unshadow, SimplifyPointers>(src);
+ auto got = Run<Unshadow, SimplifyPointers>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(SimplifyPointersTest, DerefAddressOf) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var v : i32;
let x : i32 = *(&(v));
@@ -90,7 +90,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var v : i32;
let x : i32 = v;
@@ -99,13 +99,13 @@
}
)";
- auto got = Run<Unshadow, SimplifyPointers>(src);
+ auto got = Run<Unshadow, SimplifyPointers>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(SimplifyPointersTest, ComplexChain) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var a : array<mat4x4<f32>, 4>;
let ap : ptr<function, array<mat4x4<f32>, 4>> = &a;
@@ -115,20 +115,20 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var a : array<mat4x4<f32>, 4>;
let v : vec4<f32> = a[3][2];
}
)";
- auto got = Run<Unshadow, SimplifyPointers>(src);
+ auto got = Run<Unshadow, SimplifyPointers>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(SimplifyPointersTest, SavedVars) {
- auto* src = R"(
+ auto* src = R"(
struct S {
i : i32,
};
@@ -152,7 +152,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct S {
i : i32,
}
@@ -176,13 +176,13 @@
}
)";
- auto got = Run<Unshadow, SimplifyPointers>(src);
+ auto got = Run<Unshadow, SimplifyPointers>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(SimplifyPointersTest, DontSaveLiterals) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var arr : array<i32, 2>;
let p1 : ptr<function, i32> = &arr[1];
@@ -190,20 +190,20 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var arr : array<i32, 2>;
arr[1] = 4;
}
)";
- auto got = Run<Unshadow, SimplifyPointers>(src);
+ auto got = Run<Unshadow, SimplifyPointers>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(SimplifyPointersTest, SavedVarsChain) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var arr : array<array<i32, 2>, 2>;
let i : i32 = 0;
@@ -214,7 +214,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var arr : array<array<i32, 2>, 2>;
let i : i32 = 0;
@@ -225,13 +225,13 @@
}
)";
- auto got = Run<Unshadow, SimplifyPointers>(src);
+ auto got = Run<Unshadow, SimplifyPointers>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(SimplifyPointersTest, ForLoopInit) {
- auto* src = R"(
+ auto* src = R"(
fn foo() -> i32 {
return 1;
}
@@ -246,7 +246,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn foo() -> i32 {
return 1;
}
@@ -262,13 +262,13 @@
}
)";
- auto got = Run<Unshadow, SimplifyPointers>(src);
+ auto got = Run<Unshadow, SimplifyPointers>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(SimplifyPointersTest, MultiSavedVarsInSinglePtrLetExpr) {
- auto* src = R"(
+ auto* src = R"(
fn x() -> i32 {
return 1;
}
@@ -297,7 +297,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn x() -> i32 {
return 1;
}
@@ -328,13 +328,13 @@
}
)";
- auto got = Run<Unshadow, SimplifyPointers>(src);
+ auto got = Run<Unshadow, SimplifyPointers>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(SimplifyPointersTest, ShadowPointer) {
- auto* src = R"(
+ auto* src = R"(
var<private> a : array<i32, 2>;
@stage(compute) @workgroup_size(1)
@@ -347,7 +347,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> a : array<i32, 2>;
@stage(compute) @workgroup_size(1)
@@ -359,9 +359,9 @@
}
)";
- auto got = Run<Unshadow, SimplifyPointers>(src);
+ auto got = Run<Unshadow, SimplifyPointers>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
} // namespace
diff --git a/src/tint/transform/single_entry_point.cc b/src/tint/transform/single_entry_point.cc
index 5fd21d1..82324c7 100644
--- a/src/tint/transform/single_entry_point.cc
+++ b/src/tint/transform/single_entry_point.cc
@@ -30,88 +30,82 @@
SingleEntryPoint::~SingleEntryPoint() = default;
-void SingleEntryPoint::Run(CloneContext& ctx,
- const DataMap& inputs,
- DataMap&) const {
- auto* cfg = inputs.Get<Config>();
- if (cfg == nullptr) {
- ctx.dst->Diagnostics().add_error(
- diag::System::Transform,
- "missing transform data for " + std::string(TypeInfo().name));
+void SingleEntryPoint::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const {
+ auto* cfg = inputs.Get<Config>();
+ if (cfg == nullptr) {
+ ctx.dst->Diagnostics().add_error(
+ diag::System::Transform, "missing transform data for " + std::string(TypeInfo().name));
- return;
- }
-
- // Find the target entry point.
- const ast::Function* entry_point = nullptr;
- for (auto* f : ctx.src->AST().Functions()) {
- if (!f->IsEntryPoint()) {
- continue;
+ return;
}
- if (ctx.src->Symbols().NameFor(f->symbol) == cfg->entry_point_name) {
- entry_point = f;
- break;
- }
- }
- if (entry_point == nullptr) {
- ctx.dst->Diagnostics().add_error(
- diag::System::Transform,
- "entry point '" + cfg->entry_point_name + "' not found");
- return;
- }
- auto& sem = ctx.src->Sem();
-
- // Build set of referenced module-scope variables for faster lookups later.
- std::unordered_set<const ast::Variable*> referenced_vars;
- for (auto* var : sem.Get(entry_point)->TransitivelyReferencedGlobals()) {
- referenced_vars.emplace(var->Declaration());
- }
-
- // Clone any module-scope variables, types, and functions that are statically
- // referenced by the target entry point.
- for (auto* decl : ctx.src->AST().GlobalDeclarations()) {
- if (auto* ty = decl->As<ast::TypeDecl>()) {
- // TODO(jrprice): Strip unused types.
- ctx.dst->AST().AddTypeDecl(ctx.Clone(ty));
- } else if (auto* var = decl->As<ast::Variable>()) {
- if (referenced_vars.count(var)) {
- if (var->is_overridable) {
- // It is an overridable constant
- if (!ast::HasAttribute<ast::IdAttribute>(var->attributes)) {
- // If the constant doesn't already have an @id() attribute, add one
- // so that its allocated ID so that it won't be affected by other
- // stripped away constants
- auto* global = sem.Get(var)->As<sem::GlobalVariable>();
- const auto* id = ctx.dst->Id(global->ConstantId());
- ctx.InsertFront(var->attributes, id);
- }
+ // Find the target entry point.
+ const ast::Function* entry_point = nullptr;
+ for (auto* f : ctx.src->AST().Functions()) {
+ if (!f->IsEntryPoint()) {
+ continue;
}
- ctx.dst->AST().AddGlobalVariable(ctx.Clone(var));
- }
- } else if (auto* func = decl->As<ast::Function>()) {
- if (sem.Get(func)->HasAncestorEntryPoint(entry_point->symbol)) {
- ctx.dst->AST().AddFunction(ctx.Clone(func));
- }
- } else if (auto* ext = decl->As<ast::Enable>()) {
- ctx.dst->AST().AddEnable(ctx.Clone(ext));
- } else {
- TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
- << "unhandled global declaration: " << decl->TypeInfo().name;
- return;
+ if (ctx.src->Symbols().NameFor(f->symbol) == cfg->entry_point_name) {
+ entry_point = f;
+ break;
+ }
}
- }
+ if (entry_point == nullptr) {
+ ctx.dst->Diagnostics().add_error(diag::System::Transform,
+ "entry point '" + cfg->entry_point_name + "' not found");
+ return;
+ }
- // Clone the entry point.
- ctx.dst->AST().AddFunction(ctx.Clone(entry_point));
+ auto& sem = ctx.src->Sem();
+
+ // Build set of referenced module-scope variables for faster lookups later.
+ std::unordered_set<const ast::Variable*> referenced_vars;
+ for (auto* var : sem.Get(entry_point)->TransitivelyReferencedGlobals()) {
+ referenced_vars.emplace(var->Declaration());
+ }
+
+ // Clone any module-scope variables, types, and functions that are statically
+ // referenced by the target entry point.
+ for (auto* decl : ctx.src->AST().GlobalDeclarations()) {
+ if (auto* ty = decl->As<ast::TypeDecl>()) {
+ // TODO(jrprice): Strip unused types.
+ ctx.dst->AST().AddTypeDecl(ctx.Clone(ty));
+ } else if (auto* var = decl->As<ast::Variable>()) {
+ if (referenced_vars.count(var)) {
+ if (var->is_overridable) {
+ // It is an overridable constant
+ if (!ast::HasAttribute<ast::IdAttribute>(var->attributes)) {
+ // If the constant doesn't already have an @id() attribute, add one
+ // so that its allocated ID so that it won't be affected by other
+ // stripped away constants
+ auto* global = sem.Get(var)->As<sem::GlobalVariable>();
+ const auto* id = ctx.dst->Id(global->ConstantId());
+ ctx.InsertFront(var->attributes, id);
+ }
+ }
+ ctx.dst->AST().AddGlobalVariable(ctx.Clone(var));
+ }
+ } else if (auto* func = decl->As<ast::Function>()) {
+ if (sem.Get(func)->HasAncestorEntryPoint(entry_point->symbol)) {
+ ctx.dst->AST().AddFunction(ctx.Clone(func));
+ }
+ } else if (auto* ext = decl->As<ast::Enable>()) {
+ ctx.dst->AST().AddEnable(ctx.Clone(ext));
+ } else {
+ TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
+ << "unhandled global declaration: " << decl->TypeInfo().name;
+ return;
+ }
+ }
+
+ // Clone the entry point.
+ ctx.dst->AST().AddFunction(ctx.Clone(entry_point));
}
-SingleEntryPoint::Config::Config(std::string entry_point)
- : entry_point_name(entry_point) {}
+SingleEntryPoint::Config::Config(std::string entry_point) : entry_point_name(entry_point) {}
SingleEntryPoint::Config::Config(const Config&) = default;
SingleEntryPoint::Config::~Config() = default;
-SingleEntryPoint::Config& SingleEntryPoint::Config::operator=(const Config&) =
- default;
+SingleEntryPoint::Config& SingleEntryPoint::Config::operator=(const Config&) = default;
} // namespace tint::transform
diff --git a/src/tint/transform/single_entry_point.h b/src/tint/transform/single_entry_point.h
index b5aed68..0a922a7 100644
--- a/src/tint/transform/single_entry_point.h
+++ b/src/tint/transform/single_entry_point.h
@@ -26,43 +26,41 @@
/// All module-scope variables, types, and functions that are not used by the
/// target entry point will also be removed.
class SingleEntryPoint : public Castable<SingleEntryPoint, Transform> {
- public:
- /// Configuration options for the transform
- struct Config : public Castable<Config, Data> {
- /// Constructor
- /// @param entry_point the name of the entry point to keep
- explicit Config(std::string entry_point = "");
+ public:
+ /// Configuration options for the transform
+ struct Config : public Castable<Config, Data> {
+ /// Constructor
+ /// @param entry_point the name of the entry point to keep
+ explicit Config(std::string entry_point = "");
- /// Copy constructor
- Config(const Config&);
+ /// Copy constructor
+ Config(const Config&);
+
+ /// Destructor
+ ~Config() override;
+
+ /// Assignment operator
+ /// @returns this Config
+ Config& operator=(const Config&);
+
+ /// The name of the entry point to keep.
+ std::string entry_point_name;
+ };
+
+ /// Constructor
+ SingleEntryPoint();
/// Destructor
- ~Config() override;
+ ~SingleEntryPoint() override;
- /// Assignment operator
- /// @returns this Config
- Config& operator=(const Config&);
-
- /// The name of the entry point to keep.
- std::string entry_point_name;
- };
-
- /// Constructor
- SingleEntryPoint();
-
- /// Destructor
- ~SingleEntryPoint() override;
-
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx,
- const DataMap& inputs,
- DataMap& outputs) const override;
+ protected:
+ /// Runs the transform using the CloneContext built for transforming a
+ /// program. Run() is responsible for calling Clone() on the CloneContext.
+ /// @param ctx the CloneContext primed with the input program and
+ /// ProgramBuilder
+ /// @param inputs optional extra transform-specific input data
+ /// @param outputs optional extra transform-specific output data
+ void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/single_entry_point_test.cc b/src/tint/transform/single_entry_point_test.cc
index 750f5c3..8044621 100644
--- a/src/tint/transform/single_entry_point_test.cc
+++ b/src/tint/transform/single_entry_point_test.cc
@@ -24,84 +24,83 @@
using SingleEntryPointTest = TransformTest;
TEST_F(SingleEntryPointTest, Error_MissingTransformData) {
- auto* src = "";
+ auto* src = "";
- auto* expect =
- "error: missing transform data for tint::transform::SingleEntryPoint";
+ auto* expect = "error: missing transform data for tint::transform::SingleEntryPoint";
- auto got = Run<SingleEntryPoint>(src);
+ auto got = Run<SingleEntryPoint>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(SingleEntryPointTest, Error_NoEntryPoints) {
- auto* src = "";
+ auto* src = "";
- auto* expect = "error: entry point 'main' not found";
+ auto* expect = "error: entry point 'main' not found";
- DataMap data;
- data.Add<SingleEntryPoint::Config>("main");
- auto got = Run<SingleEntryPoint>(src, data);
+ DataMap data;
+ data.Add<SingleEntryPoint::Config>("main");
+ auto got = Run<SingleEntryPoint>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(SingleEntryPointTest, Error_InvalidEntryPoint) {
- auto* src = R"(
+ auto* src = R"(
@stage(vertex)
fn main() -> @builtin(position) vec4<f32> {
return vec4<f32>();
}
)";
- auto* expect = "error: entry point '_' not found";
+ auto* expect = "error: entry point '_' not found";
- SingleEntryPoint::Config cfg("_");
+ SingleEntryPoint::Config cfg("_");
- DataMap data;
- data.Add<SingleEntryPoint::Config>(cfg);
- auto got = Run<SingleEntryPoint>(src, data);
+ DataMap data;
+ data.Add<SingleEntryPoint::Config>(cfg);
+ auto got = Run<SingleEntryPoint>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(SingleEntryPointTest, Error_NotAnEntryPoint) {
- auto* src = R"(
+ auto* src = R"(
fn foo() {}
@stage(fragment)
fn main() {}
)";
- auto* expect = "error: entry point 'foo' not found";
+ auto* expect = "error: entry point 'foo' not found";
- SingleEntryPoint::Config cfg("foo");
+ SingleEntryPoint::Config cfg("foo");
- DataMap data;
- data.Add<SingleEntryPoint::Config>(cfg);
- auto got = Run<SingleEntryPoint>(src, data);
+ DataMap data;
+ data.Add<SingleEntryPoint::Config>(cfg);
+ auto got = Run<SingleEntryPoint>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(SingleEntryPointTest, SingleEntryPoint) {
- auto* src = R"(
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn main() {
}
)";
- SingleEntryPoint::Config cfg("main");
+ SingleEntryPoint::Config cfg("main");
- DataMap data;
- data.Add<SingleEntryPoint::Config>(cfg);
- auto got = Run<SingleEntryPoint>(src, data);
+ DataMap data;
+ data.Add<SingleEntryPoint::Config>(cfg);
+ auto got = Run<SingleEntryPoint>(src, data);
- EXPECT_EQ(src, str(got));
+ EXPECT_EQ(src, str(got));
}
TEST_F(SingleEntryPointTest, MultipleEntryPoints) {
- auto* src = R"(
+ auto* src = R"(
@stage(vertex)
fn vert_main() -> @builtin(position) vec4<f32> {
return vec4<f32>();
@@ -120,23 +119,23 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
@stage(compute) @workgroup_size(1)
fn comp_main1() {
}
)";
- SingleEntryPoint::Config cfg("comp_main1");
+ SingleEntryPoint::Config cfg("comp_main1");
- DataMap data;
- data.Add<SingleEntryPoint::Config>(cfg);
- auto got = Run<SingleEntryPoint>(src, data);
+ DataMap data;
+ data.Add<SingleEntryPoint::Config>(cfg);
+ auto got = Run<SingleEntryPoint>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(SingleEntryPointTest, GlobalVariables) {
- auto* src = R"(
+ auto* src = R"(
var<private> a : f32;
var<private> b : f32;
@@ -167,7 +166,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> c : f32;
@stage(compute) @workgroup_size(1)
@@ -176,17 +175,17 @@
}
)";
- SingleEntryPoint::Config cfg("comp_main1");
+ SingleEntryPoint::Config cfg("comp_main1");
- DataMap data;
- data.Add<SingleEntryPoint::Config>(cfg);
- auto got = Run<SingleEntryPoint>(src, data);
+ DataMap data;
+ data.Add<SingleEntryPoint::Config>(cfg);
+ auto got = Run<SingleEntryPoint>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(SingleEntryPointTest, GlobalConstants) {
- auto* src = R"(
+ auto* src = R"(
let a : f32 = 1.0;
let b : f32 = 1.0;
@@ -217,7 +216,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
let c : f32 = 1.0;
@stage(compute) @workgroup_size(1)
@@ -226,17 +225,17 @@
}
)";
- SingleEntryPoint::Config cfg("comp_main1");
+ SingleEntryPoint::Config cfg("comp_main1");
- DataMap data;
- data.Add<SingleEntryPoint::Config>(cfg);
- auto got = Run<SingleEntryPoint>(src, data);
+ DataMap data;
+ data.Add<SingleEntryPoint::Config>(cfg);
+ auto got = Run<SingleEntryPoint>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(SingleEntryPointTest, WorkgroupSizeLetPreserved) {
- auto* src = R"(
+ auto* src = R"(
let size : i32 = 1;
@stage(compute) @workgroup_size(size)
@@ -244,19 +243,19 @@
}
)";
- auto* expect = src;
+ auto* expect = src;
- SingleEntryPoint::Config cfg("main");
+ SingleEntryPoint::Config cfg("main");
- DataMap data;
- data.Add<SingleEntryPoint::Config>(cfg);
- auto got = Run<SingleEntryPoint>(src, data);
+ DataMap data;
+ data.Add<SingleEntryPoint::Config>(cfg);
+ auto got = Run<SingleEntryPoint>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(SingleEntryPointTest, OverridableConstants) {
- auto* src = R"(
+ auto* src = R"(
@id(1001) override c1 : u32 = 1u;
override c2 : u32 = 1u;
@id(0) override c3 : u32 = 1u;
@@ -288,9 +287,9 @@
}
)";
- {
- SingleEntryPoint::Config cfg("comp_main1");
- auto* expect = R"(
+ {
+ SingleEntryPoint::Config cfg("comp_main1");
+ auto* expect = R"(
@id(1001) override c1 : u32 = 1u;
@stage(compute) @workgroup_size(1)
@@ -298,17 +297,17 @@
let local_d = c1;
}
)";
- DataMap data;
- data.Add<SingleEntryPoint::Config>(cfg);
- auto got = Run<SingleEntryPoint>(src, data);
- EXPECT_EQ(expect, str(got));
- }
+ DataMap data;
+ data.Add<SingleEntryPoint::Config>(cfg);
+ auto got = Run<SingleEntryPoint>(src, data);
+ EXPECT_EQ(expect, str(got));
+ }
- {
- SingleEntryPoint::Config cfg("comp_main2");
- // The decorator is replaced with the one with explicit id
- // And should not be affected by other constants stripped away
- auto* expect = R"(
+ {
+ SingleEntryPoint::Config cfg("comp_main2");
+ // The decorator is replaced with the one with explicit id
+ // And should not be affected by other constants stripped away
+ auto* expect = R"(
@id(1) override c2 : u32 = 1u;
@stage(compute) @workgroup_size(1)
@@ -316,15 +315,15 @@
let local_d = c2;
}
)";
- DataMap data;
- data.Add<SingleEntryPoint::Config>(cfg);
- auto got = Run<SingleEntryPoint>(src, data);
- EXPECT_EQ(expect, str(got));
- }
+ DataMap data;
+ data.Add<SingleEntryPoint::Config>(cfg);
+ auto got = Run<SingleEntryPoint>(src, data);
+ EXPECT_EQ(expect, str(got));
+ }
- {
- SingleEntryPoint::Config cfg("comp_main3");
- auto* expect = R"(
+ {
+ SingleEntryPoint::Config cfg("comp_main3");
+ auto* expect = R"(
@id(0) override c3 : u32 = 1u;
@stage(compute) @workgroup_size(1)
@@ -332,15 +331,15 @@
let local_d = c3;
}
)";
- DataMap data;
- data.Add<SingleEntryPoint::Config>(cfg);
- auto got = Run<SingleEntryPoint>(src, data);
- EXPECT_EQ(expect, str(got));
- }
+ DataMap data;
+ data.Add<SingleEntryPoint::Config>(cfg);
+ auto got = Run<SingleEntryPoint>(src, data);
+ EXPECT_EQ(expect, str(got));
+ }
- {
- SingleEntryPoint::Config cfg("comp_main4");
- auto* expect = R"(
+ {
+ SingleEntryPoint::Config cfg("comp_main4");
+ auto* expect = R"(
@id(9999) override c4 : u32 = 1u;
@stage(compute) @workgroup_size(1)
@@ -348,29 +347,29 @@
let local_d = c4;
}
)";
- DataMap data;
- data.Add<SingleEntryPoint::Config>(cfg);
- auto got = Run<SingleEntryPoint>(src, data);
- EXPECT_EQ(expect, str(got));
- }
+ DataMap data;
+ data.Add<SingleEntryPoint::Config>(cfg);
+ auto got = Run<SingleEntryPoint>(src, data);
+ EXPECT_EQ(expect, str(got));
+ }
- {
- SingleEntryPoint::Config cfg("comp_main5");
- auto* expect = R"(
+ {
+ SingleEntryPoint::Config cfg("comp_main5");
+ auto* expect = R"(
@stage(compute) @workgroup_size(1)
fn comp_main5() {
let local_d = 1u;
}
)";
- DataMap data;
- data.Add<SingleEntryPoint::Config>(cfg);
- auto got = Run<SingleEntryPoint>(src, data);
- EXPECT_EQ(expect, str(got));
- }
+ DataMap data;
+ data.Add<SingleEntryPoint::Config>(cfg);
+ auto got = Run<SingleEntryPoint>(src, data);
+ EXPECT_EQ(expect, str(got));
+ }
}
TEST_F(SingleEntryPointTest, CalledFunctions) {
- auto* src = R"(
+ auto* src = R"(
fn inner1() {
}
@@ -401,7 +400,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn inner1() {
}
@@ -419,17 +418,17 @@
}
)";
- SingleEntryPoint::Config cfg("comp_main1");
+ SingleEntryPoint::Config cfg("comp_main1");
- DataMap data;
- data.Add<SingleEntryPoint::Config>(cfg);
- auto got = Run<SingleEntryPoint>(src, data);
+ DataMap data;
+ data.Add<SingleEntryPoint::Config>(cfg);
+ auto got = Run<SingleEntryPoint>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(SingleEntryPointTest, GlobalsReferencedByCalledFunctions) {
- auto* src = R"(
+ auto* src = R"(
var<private> inner1_var : f32;
var<private> inner2_var : f32;
@@ -475,7 +474,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> inner1_var : f32;
var<private> inner_shared_var : f32;
@@ -502,13 +501,13 @@
}
)";
- SingleEntryPoint::Config cfg("comp_main1");
+ SingleEntryPoint::Config cfg("comp_main1");
- DataMap data;
- data.Add<SingleEntryPoint::Config>(cfg);
- auto got = Run<SingleEntryPoint>(src, data);
+ DataMap data;
+ data.Add<SingleEntryPoint::Config>(cfg);
+ auto got = Run<SingleEntryPoint>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
} // namespace
diff --git a/src/tint/transform/test_helper.h b/src/tint/transform/test_helper.h
index 7d7015c..42218a7 100644
--- a/src/tint/transform/test_helper.h
+++ b/src/tint/transform/test_helper.h
@@ -31,115 +31,112 @@
/// @returns the output program as a WGSL string, or an error string if the
/// program is not valid.
inline std::string str(const Program& program) {
- diag::Formatter::Style style;
- style.print_newline_at_end = false;
+ diag::Formatter::Style style;
+ style.print_newline_at_end = false;
- if (!program.IsValid()) {
- return diag::Formatter(style).format(program.Diagnostics());
- }
+ if (!program.IsValid()) {
+ return diag::Formatter(style).format(program.Diagnostics());
+ }
- writer::wgsl::Options options;
- auto result = writer::wgsl::Generate(&program, options);
- if (!result.success) {
- return "WGSL writer failed:\n" + result.error;
- }
+ writer::wgsl::Options options;
+ auto result = writer::wgsl::Generate(&program, options);
+ if (!result.success) {
+ return "WGSL writer failed:\n" + result.error;
+ }
- auto res = result.wgsl;
- if (res.empty()) {
- return res;
- }
- // The WGSL sometimes has two trailing newlines. Strip them
- while (res.back() == '\n') {
- res.pop_back();
- }
- if (res.empty()) {
- return res;
- }
- return "\n" + res + "\n";
+ auto res = result.wgsl;
+ if (res.empty()) {
+ return res;
+ }
+ // The WGSL sometimes has two trailing newlines. Strip them
+ while (res.back() == '\n') {
+ res.pop_back();
+ }
+ if (res.empty()) {
+ return res;
+ }
+ return "\n" + res + "\n";
}
/// Helper class for testing transforms
template <typename BASE>
class TransformTestBase : public BASE {
- public:
- /// Transforms and returns the WGSL source `in`, transformed using
- /// `transform`.
- /// @param transform the transform to apply
- /// @param in the input WGSL source
- /// @param data the optional DataMap to pass to Transform::Run()
- /// @return the transformed output
- Output Run(std::string in,
- std::unique_ptr<transform::Transform> transform,
- const DataMap& data = {}) {
- std::vector<std::unique_ptr<transform::Transform>> transforms;
- transforms.emplace_back(std::move(transform));
- return Run(std::move(in), std::move(transforms), data);
- }
-
- /// Transforms and returns the WGSL source `in`, transformed using
- /// a transform of type `TRANSFORM`.
- /// @param in the input WGSL source
- /// @param data the optional DataMap to pass to Transform::Run()
- /// @return the transformed output
- template <typename... TRANSFORMS>
- Output Run(std::string in, const DataMap& data = {}) {
- auto file = std::make_unique<Source::File>("test", in);
- auto program = reader::wgsl::Parse(file.get());
-
- // Keep this pointer alive after Transform() returns
- files_.emplace_back(std::move(file));
-
- return Run<TRANSFORMS...>(std::move(program), data);
- }
-
- /// Transforms and returns program `program`, transformed using a transform of
- /// type `TRANSFORM`.
- /// @param program the input Program
- /// @param data the optional DataMap to pass to Transform::Run()
- /// @return the transformed output
- template <typename... TRANSFORMS>
- Output Run(Program&& program, const DataMap& data = {}) {
- if (!program.IsValid()) {
- return Output(std::move(program));
+ public:
+ /// Transforms and returns the WGSL source `in`, transformed using
+ /// `transform`.
+ /// @param transform the transform to apply
+ /// @param in the input WGSL source
+ /// @param data the optional DataMap to pass to Transform::Run()
+ /// @return the transformed output
+ Output Run(std::string in,
+ std::unique_ptr<transform::Transform> transform,
+ const DataMap& data = {}) {
+ std::vector<std::unique_ptr<transform::Transform>> transforms;
+ transforms.emplace_back(std::move(transform));
+ return Run(std::move(in), std::move(transforms), data);
}
- Manager manager;
- for (auto* transform_ptr :
- std::initializer_list<Transform*>{new TRANSFORMS()...}) {
- manager.append(std::unique_ptr<Transform>(transform_ptr));
+ /// Transforms and returns the WGSL source `in`, transformed using
+ /// a transform of type `TRANSFORM`.
+ /// @param in the input WGSL source
+ /// @param data the optional DataMap to pass to Transform::Run()
+ /// @return the transformed output
+ template <typename... TRANSFORMS>
+ Output Run(std::string in, const DataMap& data = {}) {
+ auto file = std::make_unique<Source::File>("test", in);
+ auto program = reader::wgsl::Parse(file.get());
+
+ // Keep this pointer alive after Transform() returns
+ files_.emplace_back(std::move(file));
+
+ return Run<TRANSFORMS...>(std::move(program), data);
}
- return manager.Run(&program, data);
- }
- /// @param program the input program
- /// @param data the optional DataMap to pass to Transform::Run()
- /// @return true if the transform should be run for the given input.
- template <typename TRANSFORM>
- bool ShouldRun(Program&& program, const DataMap& data = {}) {
- EXPECT_TRUE(program.IsValid()) << program.Diagnostics().str();
- const Transform& t = TRANSFORM();
- return t.ShouldRun(&program, data);
- }
+ /// Transforms and returns program `program`, transformed using a transform of
+ /// type `TRANSFORM`.
+ /// @param program the input Program
+ /// @param data the optional DataMap to pass to Transform::Run()
+ /// @return the transformed output
+ template <typename... TRANSFORMS>
+ Output Run(Program&& program, const DataMap& data = {}) {
+ if (!program.IsValid()) {
+ return Output(std::move(program));
+ }
- /// @param in the input WGSL source
- /// @param data the optional DataMap to pass to Transform::Run()
- /// @return true if the transform should be run for the given input.
- template <typename TRANSFORM>
- bool ShouldRun(std::string in, const DataMap& data = {}) {
- auto file = std::make_unique<Source::File>("test", in);
- auto program = reader::wgsl::Parse(file.get());
- return ShouldRun<TRANSFORM>(std::move(program), data);
- }
+ Manager manager;
+ for (auto* transform_ptr : std::initializer_list<Transform*>{new TRANSFORMS()...}) {
+ manager.append(std::unique_ptr<Transform>(transform_ptr));
+ }
+ return manager.Run(&program, data);
+ }
- /// @param output the output of the transform
- /// @returns the output program as a WGSL string, or an error string if the
- /// program is not valid.
- std::string str(const Output& output) {
- return transform::str(output.program);
- }
+ /// @param program the input program
+ /// @param data the optional DataMap to pass to Transform::Run()
+ /// @return true if the transform should be run for the given input.
+ template <typename TRANSFORM>
+ bool ShouldRun(Program&& program, const DataMap& data = {}) {
+ EXPECT_TRUE(program.IsValid()) << program.Diagnostics().str();
+ const Transform& t = TRANSFORM();
+ return t.ShouldRun(&program, data);
+ }
- private:
- std::vector<std::unique_ptr<Source::File>> files_;
+ /// @param in the input WGSL source
+ /// @param data the optional DataMap to pass to Transform::Run()
+ /// @return true if the transform should be run for the given input.
+ template <typename TRANSFORM>
+ bool ShouldRun(std::string in, const DataMap& data = {}) {
+ auto file = std::make_unique<Source::File>("test", in);
+ auto program = reader::wgsl::Parse(file.get());
+ return ShouldRun<TRANSFORM>(std::move(program), data);
+ }
+
+ /// @param output the output of the transform
+ /// @returns the output program as a WGSL string, or an error string if the
+ /// program is not valid.
+ std::string str(const Output& output) { return transform::str(output.program); }
+
+ private:
+ std::vector<std::unique_ptr<Source::File>> files_;
};
using TransformTest = TransformTestBase<testing::Test>;
diff --git a/src/tint/transform/transform.cc b/src/tint/transform/transform.cc
index adb709b..f1873f6 100644
--- a/src/tint/transform/transform.cc
+++ b/src/tint/transform/transform.cc
@@ -45,114 +45,109 @@
Transform::Transform() = default;
Transform::~Transform() = default;
-Output Transform::Run(const Program* program,
- const DataMap& data /* = {} */) const {
- ProgramBuilder builder;
- CloneContext ctx(&builder, program);
- Output output;
- Run(ctx, data, output.data);
- output.program = Program(std::move(builder));
- return output;
+Output Transform::Run(const Program* program, const DataMap& data /* = {} */) const {
+ ProgramBuilder builder;
+ CloneContext ctx(&builder, program);
+ Output output;
+ Run(ctx, data, output.data);
+ output.program = Program(std::move(builder));
+ return output;
}
void Transform::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
- TINT_UNIMPLEMENTED(Transform, ctx.dst->Diagnostics())
- << "Transform::Run() unimplemented for " << TypeInfo().name;
+ TINT_UNIMPLEMENTED(Transform, ctx.dst->Diagnostics())
+ << "Transform::Run() unimplemented for " << TypeInfo().name;
}
bool Transform::ShouldRun(const Program*, const DataMap&) const {
- return true;
+ return true;
}
void Transform::RemoveStatement(CloneContext& ctx, const ast::Statement* stmt) {
- auto* sem = ctx.src->Sem().Get(stmt);
- if (auto* block = tint::As<sem::BlockStatement>(sem->Parent())) {
- ctx.Remove(block->Declaration()->statements, stmt);
- return;
- }
- if (tint::Is<sem::ForLoopStatement>(sem->Parent())) {
- ctx.Replace(stmt, static_cast<ast::Expression*>(nullptr));
- return;
- }
- TINT_ICE(Transform, ctx.dst->Diagnostics())
- << "unable to remove statement from parent of type "
- << sem->TypeInfo().name;
+ auto* sem = ctx.src->Sem().Get(stmt);
+ if (auto* block = tint::As<sem::BlockStatement>(sem->Parent())) {
+ ctx.Remove(block->Declaration()->statements, stmt);
+ return;
+ }
+ if (tint::Is<sem::ForLoopStatement>(sem->Parent())) {
+ ctx.Replace(stmt, static_cast<ast::Expression*>(nullptr));
+ return;
+ }
+ TINT_ICE(Transform, ctx.dst->Diagnostics())
+ << "unable to remove statement from parent of type " << sem->TypeInfo().name;
}
-const ast::Type* Transform::CreateASTTypeFor(CloneContext& ctx,
- const sem::Type* ty) {
- if (ty->Is<sem::Void>()) {
- return ctx.dst->create<ast::Void>();
- }
- if (ty->Is<sem::I32>()) {
- return ctx.dst->create<ast::I32>();
- }
- if (ty->Is<sem::U32>()) {
- return ctx.dst->create<ast::U32>();
- }
- if (ty->Is<sem::F32>()) {
- return ctx.dst->create<ast::F32>();
- }
- if (ty->Is<sem::Bool>()) {
- return ctx.dst->create<ast::Bool>();
- }
- if (auto* m = ty->As<sem::Matrix>()) {
- auto* el = CreateASTTypeFor(ctx, m->type());
- return ctx.dst->create<ast::Matrix>(el, m->rows(), m->columns());
- }
- if (auto* v = ty->As<sem::Vector>()) {
- auto* el = CreateASTTypeFor(ctx, v->type());
- return ctx.dst->create<ast::Vector>(el, v->Width());
- }
- if (auto* a = ty->As<sem::Array>()) {
- auto* el = CreateASTTypeFor(ctx, a->ElemType());
- ast::AttributeList attrs;
- if (!a->IsStrideImplicit()) {
- attrs.emplace_back(ctx.dst->create<ast::StrideAttribute>(a->Stride()));
+const ast::Type* Transform::CreateASTTypeFor(CloneContext& ctx, const sem::Type* ty) {
+ if (ty->Is<sem::Void>()) {
+ return ctx.dst->create<ast::Void>();
}
- if (a->IsRuntimeSized()) {
- return ctx.dst->ty.array(el, nullptr, std::move(attrs));
- } else {
- return ctx.dst->ty.array(el, a->Count(), std::move(attrs));
+ if (ty->Is<sem::I32>()) {
+ return ctx.dst->create<ast::I32>();
}
- }
- if (auto* s = ty->As<sem::Struct>()) {
- return ctx.dst->create<ast::TypeName>(ctx.Clone(s->Declaration()->name));
- }
- if (auto* s = ty->As<sem::Reference>()) {
- return CreateASTTypeFor(ctx, s->StoreType());
- }
- if (auto* a = ty->As<sem::Atomic>()) {
- return ctx.dst->create<ast::Atomic>(CreateASTTypeFor(ctx, a->Type()));
- }
- if (auto* t = ty->As<sem::DepthTexture>()) {
- return ctx.dst->create<ast::DepthTexture>(t->dim());
- }
- if (auto* t = ty->As<sem::DepthMultisampledTexture>()) {
- return ctx.dst->create<ast::DepthMultisampledTexture>(t->dim());
- }
- if (ty->Is<sem::ExternalTexture>()) {
- return ctx.dst->create<ast::ExternalTexture>();
- }
- if (auto* t = ty->As<sem::MultisampledTexture>()) {
- return ctx.dst->create<ast::MultisampledTexture>(
- t->dim(), CreateASTTypeFor(ctx, t->type()));
- }
- if (auto* t = ty->As<sem::SampledTexture>()) {
- return ctx.dst->create<ast::SampledTexture>(
- t->dim(), CreateASTTypeFor(ctx, t->type()));
- }
- if (auto* t = ty->As<sem::StorageTexture>()) {
- return ctx.dst->create<ast::StorageTexture>(
- t->dim(), t->texel_format(), CreateASTTypeFor(ctx, t->type()),
- t->access());
- }
- if (auto* s = ty->As<sem::Sampler>()) {
- return ctx.dst->create<ast::Sampler>(s->kind());
- }
- TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
- << "Unhandled type: " << ty->TypeInfo().name;
- return nullptr;
+ if (ty->Is<sem::U32>()) {
+ return ctx.dst->create<ast::U32>();
+ }
+ if (ty->Is<sem::F32>()) {
+ return ctx.dst->create<ast::F32>();
+ }
+ if (ty->Is<sem::Bool>()) {
+ return ctx.dst->create<ast::Bool>();
+ }
+ if (auto* m = ty->As<sem::Matrix>()) {
+ auto* el = CreateASTTypeFor(ctx, m->type());
+ return ctx.dst->create<ast::Matrix>(el, m->rows(), m->columns());
+ }
+ if (auto* v = ty->As<sem::Vector>()) {
+ auto* el = CreateASTTypeFor(ctx, v->type());
+ return ctx.dst->create<ast::Vector>(el, v->Width());
+ }
+ if (auto* a = ty->As<sem::Array>()) {
+ auto* el = CreateASTTypeFor(ctx, a->ElemType());
+ ast::AttributeList attrs;
+ if (!a->IsStrideImplicit()) {
+ attrs.emplace_back(ctx.dst->create<ast::StrideAttribute>(a->Stride()));
+ }
+ if (a->IsRuntimeSized()) {
+ return ctx.dst->ty.array(el, nullptr, std::move(attrs));
+ } else {
+ return ctx.dst->ty.array(el, a->Count(), std::move(attrs));
+ }
+ }
+ if (auto* s = ty->As<sem::Struct>()) {
+ return ctx.dst->create<ast::TypeName>(ctx.Clone(s->Declaration()->name));
+ }
+ if (auto* s = ty->As<sem::Reference>()) {
+ return CreateASTTypeFor(ctx, s->StoreType());
+ }
+ if (auto* a = ty->As<sem::Atomic>()) {
+ return ctx.dst->create<ast::Atomic>(CreateASTTypeFor(ctx, a->Type()));
+ }
+ if (auto* t = ty->As<sem::DepthTexture>()) {
+ return ctx.dst->create<ast::DepthTexture>(t->dim());
+ }
+ if (auto* t = ty->As<sem::DepthMultisampledTexture>()) {
+ return ctx.dst->create<ast::DepthMultisampledTexture>(t->dim());
+ }
+ if (ty->Is<sem::ExternalTexture>()) {
+ return ctx.dst->create<ast::ExternalTexture>();
+ }
+ if (auto* t = ty->As<sem::MultisampledTexture>()) {
+ return ctx.dst->create<ast::MultisampledTexture>(t->dim(),
+ CreateASTTypeFor(ctx, t->type()));
+ }
+ if (auto* t = ty->As<sem::SampledTexture>()) {
+ return ctx.dst->create<ast::SampledTexture>(t->dim(), CreateASTTypeFor(ctx, t->type()));
+ }
+ if (auto* t = ty->As<sem::StorageTexture>()) {
+ return ctx.dst->create<ast::StorageTexture>(t->dim(), t->texel_format(),
+ CreateASTTypeFor(ctx, t->type()), t->access());
+ }
+ if (auto* s = ty->As<sem::Sampler>()) {
+ return ctx.dst->create<ast::Sampler>(s->kind());
+ }
+ TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
+ << "Unhandled type: " << ty->TypeInfo().name;
+ return nullptr;
}
} // namespace tint::transform
diff --git a/src/tint/transform/transform.h b/src/tint/transform/transform.h
index 0a3b4f0..de57617 100644
--- a/src/tint/transform/transform.h
+++ b/src/tint/transform/transform.h
@@ -27,176 +27,171 @@
/// Data is the base class for transforms that accept extra input or emit extra
/// output information along with a Program.
class Data : public Castable<Data> {
- public:
- /// Constructor
- Data();
+ public:
+ /// Constructor
+ Data();
- /// Copy constructor
- Data(const Data&);
+ /// Copy constructor
+ Data(const Data&);
- /// Destructor
- ~Data() override;
+ /// Destructor
+ ~Data() override;
- /// Assignment operator
- /// @returns this Data
- Data& operator=(const Data&);
+ /// Assignment operator
+ /// @returns this Data
+ Data& operator=(const Data&);
};
/// DataMap is a map of Data unique pointers keyed by the Data's ClassID.
class DataMap {
- public:
- /// Constructor
- DataMap();
+ public:
+ /// Constructor
+ DataMap();
- /// Move constructor
- DataMap(DataMap&&);
+ /// Move constructor
+ DataMap(DataMap&&);
- /// Constructor
- /// @param data_unique_ptrs a variadic list of additional data unique_ptrs
- /// produced by the transform
- template <typename... DATA>
- explicit DataMap(DATA... data_unique_ptrs) {
- PutAll(std::forward<DATA>(data_unique_ptrs)...);
- }
-
- /// Destructor
- ~DataMap();
-
- /// Move assignment operator
- /// @param rhs the DataMap to move into this DataMap
- /// @return this DataMap
- DataMap& operator=(DataMap&& rhs);
-
- /// Adds the data into DataMap keyed by the ClassID of type T.
- /// @param data the data to add to the DataMap
- template <typename T>
- void Put(std::unique_ptr<T>&& data) {
- static_assert(std::is_base_of<Data, T>::value,
- "T does not derive from Data");
- map_[&TypeInfo::Of<T>()] = std::move(data);
- }
-
- /// Creates the data of type `T` with the provided arguments and adds it into
- /// DataMap keyed by the ClassID of type T.
- /// @param args the arguments forwarded to the constructor for type T
- template <typename T, typename... ARGS>
- void Add(ARGS&&... args) {
- Put(std::make_unique<T>(std::forward<ARGS>(args)...));
- }
-
- /// @returns a pointer to the Data placed into the DataMap with a call to
- /// Put()
- template <typename T>
- T const* Get() const {
- return const_cast<DataMap*>(this)->Get<T>();
- }
-
- /// @returns a pointer to the Data placed into the DataMap with a call to
- /// Put()
- template <typename T>
- T* Get() {
- auto it = map_.find(&TypeInfo::Of<T>());
- if (it == map_.end()) {
- return nullptr;
+ /// Constructor
+ /// @param data_unique_ptrs a variadic list of additional data unique_ptrs
+ /// produced by the transform
+ template <typename... DATA>
+ explicit DataMap(DATA... data_unique_ptrs) {
+ PutAll(std::forward<DATA>(data_unique_ptrs)...);
}
- return static_cast<T*>(it->second.get());
- }
- /// Add moves all the data from other into this DataMap
- /// @param other the DataMap to move into this DataMap
- void Add(DataMap&& other) {
- for (auto& it : other.map_) {
- map_.emplace(it.first, std::move(it.second));
+ /// Destructor
+ ~DataMap();
+
+ /// Move assignment operator
+ /// @param rhs the DataMap to move into this DataMap
+ /// @return this DataMap
+ DataMap& operator=(DataMap&& rhs);
+
+ /// Adds the data into DataMap keyed by the ClassID of type T.
+ /// @param data the data to add to the DataMap
+ template <typename T>
+ void Put(std::unique_ptr<T>&& data) {
+ static_assert(std::is_base_of<Data, T>::value, "T does not derive from Data");
+ map_[&TypeInfo::Of<T>()] = std::move(data);
}
- other.map_.clear();
- }
- private:
- template <typename T0>
- void PutAll(T0&& first) {
- Put(std::forward<T0>(first));
- }
+ /// Creates the data of type `T` with the provided arguments and adds it into
+ /// DataMap keyed by the ClassID of type T.
+ /// @param args the arguments forwarded to the constructor for type T
+ template <typename T, typename... ARGS>
+ void Add(ARGS&&... args) {
+ Put(std::make_unique<T>(std::forward<ARGS>(args)...));
+ }
- template <typename T0, typename... Tn>
- void PutAll(T0&& first, Tn&&... remainder) {
- Put(std::forward<T0>(first));
- PutAll(std::forward<Tn>(remainder)...);
- }
+ /// @returns a pointer to the Data placed into the DataMap with a call to
+ /// Put()
+ template <typename T>
+ T const* Get() const {
+ return const_cast<DataMap*>(this)->Get<T>();
+ }
- std::unordered_map<const TypeInfo*, std::unique_ptr<Data>> map_;
+ /// @returns a pointer to the Data placed into the DataMap with a call to
+ /// Put()
+ template <typename T>
+ T* Get() {
+ auto it = map_.find(&TypeInfo::Of<T>());
+ if (it == map_.end()) {
+ return nullptr;
+ }
+ return static_cast<T*>(it->second.get());
+ }
+
+ /// Add moves all the data from other into this DataMap
+ /// @param other the DataMap to move into this DataMap
+ void Add(DataMap&& other) {
+ for (auto& it : other.map_) {
+ map_.emplace(it.first, std::move(it.second));
+ }
+ other.map_.clear();
+ }
+
+ private:
+ template <typename T0>
+ void PutAll(T0&& first) {
+ Put(std::forward<T0>(first));
+ }
+
+ template <typename T0, typename... Tn>
+ void PutAll(T0&& first, Tn&&... remainder) {
+ Put(std::forward<T0>(first));
+ PutAll(std::forward<Tn>(remainder)...);
+ }
+
+ std::unordered_map<const TypeInfo*, std::unique_ptr<Data>> map_;
};
/// The return type of Run()
class Output {
- public:
- /// Constructor
- Output();
+ public:
+ /// Constructor
+ Output();
- /// Constructor
- /// @param program the program to move into this Output
- explicit Output(Program&& program);
+ /// Constructor
+ /// @param program the program to move into this Output
+ explicit Output(Program&& program);
- /// Constructor
- /// @param program_ the program to move into this Output
- /// @param data_ a variadic list of additional data unique_ptrs produced by
- /// the transform
- template <typename... DATA>
- Output(Program&& program_, DATA... data_)
- : program(std::move(program_)), data(std::forward<DATA>(data_)...) {}
+ /// Constructor
+ /// @param program_ the program to move into this Output
+ /// @param data_ a variadic list of additional data unique_ptrs produced by
+ /// the transform
+ template <typename... DATA>
+ Output(Program&& program_, DATA... data_)
+ : program(std::move(program_)), data(std::forward<DATA>(data_)...) {}
- /// The transformed program. May be empty on error.
- Program program;
+ /// The transformed program. May be empty on error.
+ Program program;
- /// Extra output generated by the transforms.
- DataMap data;
+ /// Extra output generated by the transforms.
+ DataMap data;
};
/// Interface for Program transforms
class Transform : public Castable<Transform> {
- public:
- /// Constructor
- Transform();
- /// Destructor
- ~Transform() override;
+ public:
+ /// Constructor
+ Transform();
+ /// Destructor
+ ~Transform() override;
- /// Runs the transform on `program`, returning the transformation result.
- /// @param program the source program to transform
- /// @param data optional extra transform-specific input data
- /// @returns the transformation result
- virtual Output Run(const Program* program, const DataMap& data = {}) const;
+ /// Runs the transform on `program`, returning the transformation result.
+ /// @param program the source program to transform
+ /// @param data optional extra transform-specific input data
+ /// @returns the transformation result
+ virtual Output Run(const Program* program, const DataMap& data = {}) const;
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- virtual bool ShouldRun(const Program* program,
- const DataMap& data = {}) const;
+ /// @param program the program to inspect
+ /// @param data optional extra transform-specific input data
+ /// @returns true if this transform should be run for the given program
+ virtual bool ShouldRun(const Program* program, const DataMap& data = {}) const;
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- virtual void Run(CloneContext& ctx,
- const DataMap& inputs,
- DataMap& outputs) const;
+ protected:
+ /// Runs the transform using the CloneContext built for transforming a
+ /// program. Run() is responsible for calling Clone() on the CloneContext.
+ /// @param ctx the CloneContext primed with the input program and
+ /// ProgramBuilder
+ /// @param inputs optional extra transform-specific input data
+ /// @param outputs optional extra transform-specific output data
+ virtual void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const;
- /// Removes the statement `stmt` from the transformed program.
- /// RemoveStatement handles edge cases, like statements in the initializer and
- /// continuing of for-loops.
- /// @param ctx the clone context
- /// @param stmt the statement to remove when the program is cloned
- static void RemoveStatement(CloneContext& ctx, const ast::Statement* stmt);
+ /// Removes the statement `stmt` from the transformed program.
+ /// RemoveStatement handles edge cases, like statements in the initializer and
+ /// continuing of for-loops.
+ /// @param ctx the clone context
+ /// @param stmt the statement to remove when the program is cloned
+ static void RemoveStatement(CloneContext& ctx, const ast::Statement* stmt);
- /// CreateASTTypeFor constructs new ast::Type nodes that reconstructs the
- /// semantic type `ty`.
- /// @param ctx the clone context
- /// @param ty the semantic type to reconstruct
- /// @returns a ast::Type that when resolved, will produce the semantic type
- /// `ty`.
- static const ast::Type* CreateASTTypeFor(CloneContext& ctx,
- const sem::Type* ty);
+ /// CreateASTTypeFor constructs new ast::Type nodes that reconstructs the
+ /// semantic type `ty`.
+ /// @param ctx the clone context
+ /// @param ty the semantic type to reconstruct
+ /// @returns a ast::Type that when resolved, will produce the semantic type
+ /// `ty`.
+ static const ast::Type* CreateASTTypeFor(CloneContext& ctx, const sem::Type* ty);
};
} // namespace tint::transform
diff --git a/src/tint/transform/transform_test.cc b/src/tint/transform/transform_test.cc
index c100c09..cefe18f 100644
--- a/src/tint/transform/transform_test.cc
+++ b/src/tint/transform/transform_test.cc
@@ -23,98 +23,82 @@
// Inherit from Transform so we have access to protected methods
struct CreateASTTypeForTest : public testing::Test, public Transform {
- Output Run(const Program*, const DataMap&) const override { return {}; }
+ Output Run(const Program*, const DataMap&) const override { return {}; }
- const ast::Type* create(
- std::function<sem::Type*(ProgramBuilder&)> create_sem_type) {
- ProgramBuilder sem_type_builder;
- auto* sem_type = create_sem_type(sem_type_builder);
- Program program(std::move(sem_type_builder));
- CloneContext ctx(&ast_type_builder, &program, false);
- return CreateASTTypeFor(ctx, sem_type);
- }
+ const ast::Type* create(std::function<sem::Type*(ProgramBuilder&)> create_sem_type) {
+ ProgramBuilder sem_type_builder;
+ auto* sem_type = create_sem_type(sem_type_builder);
+ Program program(std::move(sem_type_builder));
+ CloneContext ctx(&ast_type_builder, &program, false);
+ return CreateASTTypeFor(ctx, sem_type);
+ }
- ProgramBuilder ast_type_builder;
+ ProgramBuilder ast_type_builder;
};
TEST_F(CreateASTTypeForTest, Basic) {
- EXPECT_TRUE(create([](ProgramBuilder& b) {
- return b.create<sem::I32>();
- })->Is<ast::I32>());
- EXPECT_TRUE(create([](ProgramBuilder& b) {
- return b.create<sem::U32>();
- })->Is<ast::U32>());
- EXPECT_TRUE(create([](ProgramBuilder& b) {
- return b.create<sem::F32>();
- })->Is<ast::F32>());
- EXPECT_TRUE(create([](ProgramBuilder& b) {
- return b.create<sem::Bool>();
- })->Is<ast::Bool>());
- EXPECT_TRUE(create([](ProgramBuilder& b) {
- return b.create<sem::Void>();
- })->Is<ast::Void>());
+ EXPECT_TRUE(create([](ProgramBuilder& b) { return b.create<sem::I32>(); })->Is<ast::I32>());
+ EXPECT_TRUE(create([](ProgramBuilder& b) { return b.create<sem::U32>(); })->Is<ast::U32>());
+ EXPECT_TRUE(create([](ProgramBuilder& b) { return b.create<sem::F32>(); })->Is<ast::F32>());
+ EXPECT_TRUE(create([](ProgramBuilder& b) { return b.create<sem::Bool>(); })->Is<ast::Bool>());
+ EXPECT_TRUE(create([](ProgramBuilder& b) { return b.create<sem::Void>(); })->Is<ast::Void>());
}
TEST_F(CreateASTTypeForTest, Matrix) {
- auto* mat = create([](ProgramBuilder& b) {
- auto* column_type = b.create<sem::Vector>(b.create<sem::F32>(), 2u);
- return b.create<sem::Matrix>(column_type, 3u);
- });
- ASSERT_TRUE(mat->Is<ast::Matrix>());
- ASSERT_TRUE(mat->As<ast::Matrix>()->type->Is<ast::F32>());
- ASSERT_EQ(mat->As<ast::Matrix>()->columns, 3u);
- ASSERT_EQ(mat->As<ast::Matrix>()->rows, 2u);
+ auto* mat = create([](ProgramBuilder& b) {
+ auto* column_type = b.create<sem::Vector>(b.create<sem::F32>(), 2u);
+ return b.create<sem::Matrix>(column_type, 3u);
+ });
+ ASSERT_TRUE(mat->Is<ast::Matrix>());
+ ASSERT_TRUE(mat->As<ast::Matrix>()->type->Is<ast::F32>());
+ ASSERT_EQ(mat->As<ast::Matrix>()->columns, 3u);
+ ASSERT_EQ(mat->As<ast::Matrix>()->rows, 2u);
}
TEST_F(CreateASTTypeForTest, Vector) {
- auto* vec = create([](ProgramBuilder& b) {
- return b.create<sem::Vector>(b.create<sem::F32>(), 2u);
- });
- ASSERT_TRUE(vec->Is<ast::Vector>());
- ASSERT_TRUE(vec->As<ast::Vector>()->type->Is<ast::F32>());
- ASSERT_EQ(vec->As<ast::Vector>()->width, 2u);
+ auto* vec =
+ create([](ProgramBuilder& b) { return b.create<sem::Vector>(b.create<sem::F32>(), 2u); });
+ ASSERT_TRUE(vec->Is<ast::Vector>());
+ ASSERT_TRUE(vec->As<ast::Vector>()->type->Is<ast::F32>());
+ ASSERT_EQ(vec->As<ast::Vector>()->width, 2u);
}
TEST_F(CreateASTTypeForTest, ArrayImplicitStride) {
- auto* arr = create([](ProgramBuilder& b) {
- return b.create<sem::Array>(b.create<sem::F32>(), 2u, 4u, 4u, 32u, 32u);
- });
- ASSERT_TRUE(arr->Is<ast::Array>());
- ASSERT_TRUE(arr->As<ast::Array>()->type->Is<ast::F32>());
- ASSERT_EQ(arr->As<ast::Array>()->attributes.size(), 0u);
+ auto* arr = create([](ProgramBuilder& b) {
+ return b.create<sem::Array>(b.create<sem::F32>(), 2u, 4u, 4u, 32u, 32u);
+ });
+ ASSERT_TRUE(arr->Is<ast::Array>());
+ ASSERT_TRUE(arr->As<ast::Array>()->type->Is<ast::F32>());
+ ASSERT_EQ(arr->As<ast::Array>()->attributes.size(), 0u);
- auto* size = arr->As<ast::Array>()->count->As<ast::IntLiteralExpression>();
- ASSERT_NE(size, nullptr);
- EXPECT_EQ(size->ValueAsI32(), 2);
+ auto* size = arr->As<ast::Array>()->count->As<ast::IntLiteralExpression>();
+ ASSERT_NE(size, nullptr);
+ EXPECT_EQ(size->ValueAsI32(), 2);
}
TEST_F(CreateASTTypeForTest, ArrayNonImplicitStride) {
- auto* arr = create([](ProgramBuilder& b) {
- return b.create<sem::Array>(b.create<sem::F32>(), 2u, 4u, 4u, 64u, 32u);
- });
- ASSERT_TRUE(arr->Is<ast::Array>());
- ASSERT_TRUE(arr->As<ast::Array>()->type->Is<ast::F32>());
- ASSERT_EQ(arr->As<ast::Array>()->attributes.size(), 1u);
- ASSERT_TRUE(arr->As<ast::Array>()->attributes[0]->Is<ast::StrideAttribute>());
- ASSERT_EQ(
- arr->As<ast::Array>()->attributes[0]->As<ast::StrideAttribute>()->stride,
- 64u);
+ auto* arr = create([](ProgramBuilder& b) {
+ return b.create<sem::Array>(b.create<sem::F32>(), 2u, 4u, 4u, 64u, 32u);
+ });
+ ASSERT_TRUE(arr->Is<ast::Array>());
+ ASSERT_TRUE(arr->As<ast::Array>()->type->Is<ast::F32>());
+ ASSERT_EQ(arr->As<ast::Array>()->attributes.size(), 1u);
+ ASSERT_TRUE(arr->As<ast::Array>()->attributes[0]->Is<ast::StrideAttribute>());
+ ASSERT_EQ(arr->As<ast::Array>()->attributes[0]->As<ast::StrideAttribute>()->stride, 64u);
- auto* size = arr->As<ast::Array>()->count->As<ast::IntLiteralExpression>();
- ASSERT_NE(size, nullptr);
- EXPECT_EQ(size->ValueAsI32(), 2);
+ auto* size = arr->As<ast::Array>()->count->As<ast::IntLiteralExpression>();
+ ASSERT_NE(size, nullptr);
+ EXPECT_EQ(size->ValueAsI32(), 2);
}
TEST_F(CreateASTTypeForTest, Struct) {
- auto* str = create([](ProgramBuilder& b) {
- auto* decl = b.Structure("S", {});
- return b.create<sem::Struct>(decl, decl->name, sem::StructMemberList{},
- 4u /* align */, 4u /* size */,
- 4u /* size_no_padding */);
- });
- ASSERT_TRUE(str->Is<ast::TypeName>());
- EXPECT_EQ(ast_type_builder.Symbols().NameFor(str->As<ast::TypeName>()->name),
- "S");
+ auto* str = create([](ProgramBuilder& b) {
+ auto* decl = b.Structure("S", {});
+ return b.create<sem::Struct>(decl, decl->name, sem::StructMemberList{}, 4u /* align */,
+ 4u /* size */, 4u /* size_no_padding */);
+ });
+ ASSERT_TRUE(str->Is<ast::TypeName>());
+ EXPECT_EQ(ast_type_builder.Symbols().NameFor(str->As<ast::TypeName>()->name), "S");
}
} // namespace
diff --git a/src/tint/transform/unshadow.cc b/src/tint/transform/unshadow.cc
index 9c28675..dcf90da 100644
--- a/src/tint/transform/unshadow.cc
+++ b/src/tint/transform/unshadow.cc
@@ -30,60 +30,60 @@
/// The PIMPL state for the Unshadow transform
struct Unshadow::State {
- /// The clone context
- CloneContext& ctx;
+ /// The clone context
+ CloneContext& ctx;
- /// Constructor
- /// @param context the clone context
- explicit State(CloneContext& context) : ctx(context) {}
+ /// Constructor
+ /// @param context the clone context
+ explicit State(CloneContext& context) : ctx(context) {}
- /// Performs the transformation
- void Run() {
- auto& sem = ctx.src->Sem();
+ /// Performs the transformation
+ void Run() {
+ auto& sem = ctx.src->Sem();
- // Maps a variable to its new name.
- std::unordered_map<const sem::Variable*, Symbol> renamed_to;
+ // Maps a variable to its new name.
+ std::unordered_map<const sem::Variable*, Symbol> renamed_to;
- auto rename = [&](const sem::Variable* var) -> const ast::Variable* {
- auto* decl = var->Declaration();
- auto name = ctx.src->Symbols().NameFor(decl->symbol);
- auto symbol = ctx.dst->Symbols().New(name);
- renamed_to.emplace(var, symbol);
+ auto rename = [&](const sem::Variable* var) -> const ast::Variable* {
+ auto* decl = var->Declaration();
+ auto name = ctx.src->Symbols().NameFor(decl->symbol);
+ auto symbol = ctx.dst->Symbols().New(name);
+ renamed_to.emplace(var, symbol);
- auto source = ctx.Clone(decl->source);
- auto* type = ctx.Clone(decl->type);
- auto* constructor = ctx.Clone(decl->constructor);
- auto attributes = ctx.Clone(decl->attributes);
- return ctx.dst->create<ast::Variable>(
- source, symbol, decl->declared_storage_class, decl->declared_access,
- type, decl->is_const, decl->is_overridable, constructor, attributes);
- };
+ auto source = ctx.Clone(decl->source);
+ auto* type = ctx.Clone(decl->type);
+ auto* constructor = ctx.Clone(decl->constructor);
+ auto attributes = ctx.Clone(decl->attributes);
+ return ctx.dst->create<ast::Variable>(source, symbol, decl->declared_storage_class,
+ decl->declared_access, type, decl->is_const,
+ decl->is_overridable, constructor, attributes);
+ };
- ctx.ReplaceAll([&](const ast::Variable* var) -> const ast::Variable* {
- if (auto* local = sem.Get<sem::LocalVariable>(var)) {
- if (local->Shadows()) {
- return rename(local);
- }
- }
- if (auto* param = sem.Get<sem::Parameter>(var)) {
- if (param->Shadows()) {
- return rename(param);
- }
- }
- return nullptr;
- });
- ctx.ReplaceAll([&](const ast::IdentifierExpression* ident)
- -> const tint::ast::IdentifierExpression* {
- if (auto* user = sem.Get<sem::VariableUser>(ident)) {
- auto it = renamed_to.find(user->Variable());
- if (it != renamed_to.end()) {
- return ctx.dst->Expr(it->second);
- }
- }
- return nullptr;
- });
- ctx.Clone();
- }
+ ctx.ReplaceAll([&](const ast::Variable* var) -> const ast::Variable* {
+ if (auto* local = sem.Get<sem::LocalVariable>(var)) {
+ if (local->Shadows()) {
+ return rename(local);
+ }
+ }
+ if (auto* param = sem.Get<sem::Parameter>(var)) {
+ if (param->Shadows()) {
+ return rename(param);
+ }
+ }
+ return nullptr;
+ });
+ ctx.ReplaceAll(
+ [&](const ast::IdentifierExpression* ident) -> const tint::ast::IdentifierExpression* {
+ if (auto* user = sem.Get<sem::VariableUser>(ident)) {
+ auto it = renamed_to.find(user->Variable());
+ if (it != renamed_to.end()) {
+ return ctx.dst->Expr(it->second);
+ }
+ }
+ return nullptr;
+ });
+ ctx.Clone();
+ }
};
Unshadow::Unshadow() = default;
@@ -91,7 +91,7 @@
Unshadow::~Unshadow() = default;
void Unshadow::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
- State(ctx).Run();
+ State(ctx).Run();
}
} // namespace tint::transform
diff --git a/src/tint/transform/unshadow.h b/src/tint/transform/unshadow.h
index bfea677..ce5e975 100644
--- a/src/tint/transform/unshadow.h
+++ b/src/tint/transform/unshadow.h
@@ -22,25 +22,23 @@
/// Unshadow is a Transform that renames any variables that shadow another
/// variable.
class Unshadow : public Castable<Unshadow, Transform> {
- public:
- /// Constructor
- Unshadow();
+ public:
+ /// Constructor
+ Unshadow();
- /// Destructor
- ~Unshadow() override;
+ /// Destructor
+ ~Unshadow() override;
- protected:
- struct State;
+ protected:
+ struct State;
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx,
- const DataMap& inputs,
- DataMap& outputs) const override;
+ /// Runs the transform using the CloneContext built for transforming a
+ /// program. Run() is responsible for calling Clone() on the CloneContext.
+ /// @param ctx the CloneContext primed with the input program and
+ /// ProgramBuilder
+ /// @param inputs optional extra transform-specific input data
+ /// @param outputs optional extra transform-specific output data
+ void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/unshadow_test.cc b/src/tint/transform/unshadow_test.cc
index ccb9fba..30e1db5 100644
--- a/src/tint/transform/unshadow_test.cc
+++ b/src/tint/transform/unshadow_test.cc
@@ -22,16 +22,16 @@
using UnshadowTest = TransformTest;
TEST_F(UnshadowTest, EmptyModule) {
- auto* src = "";
- auto* expect = "";
+ auto* src = "";
+ auto* expect = "";
- auto got = Run<Unshadow>(src);
+ auto got = Run<Unshadow>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(UnshadowTest, Noop) {
- auto* src = R"(
+ auto* src = R"(
var<private> a : i32;
let b : i32 = 1;
@@ -46,15 +46,15 @@
}
)";
- auto* expect = src;
+ auto* expect = src;
- auto got = Run<Unshadow>(src);
+ auto got = Run<Unshadow>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(UnshadowTest, LocalShadowsAlias) {
- auto* src = R"(
+ auto* src = R"(
type a = i32;
fn X() {
@@ -66,7 +66,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
type a = i32;
fn X() {
@@ -78,13 +78,13 @@
}
)";
- auto got = Run<Unshadow>(src);
+ auto got = Run<Unshadow>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(UnshadowTest, LocalShadowsAlias_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn X() {
var a = false;
}
@@ -96,7 +96,7 @@
type a = i32;
)";
- auto* expect = R"(
+ auto* expect = R"(
fn X() {
var a_1 = false;
}
@@ -108,13 +108,13 @@
type a = i32;
)";
- auto got = Run<Unshadow>(src);
+ auto got = Run<Unshadow>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(UnshadowTest, LocalShadowsStruct) {
- auto* src = R"(
+ auto* src = R"(
struct a {
m : i32,
};
@@ -128,7 +128,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct a {
m : i32,
}
@@ -142,13 +142,13 @@
}
)";
- auto got = Run<Unshadow>(src);
+ auto got = Run<Unshadow>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(UnshadowTest, LocalShadowsStruct_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn X() {
var a = true;
}
@@ -163,7 +163,7 @@
)";
- auto* expect = R"(
+ auto* expect = R"(
fn X() {
var a_1 = true;
}
@@ -177,13 +177,13 @@
}
)";
- auto got = Run<Unshadow>(src);
+ auto got = Run<Unshadow>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(UnshadowTest, LocalShadowsFunction) {
- auto* src = R"(
+ auto* src = R"(
fn a() {
var a = true;
var b = false;
@@ -195,7 +195,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a() {
var a_1 = true;
var b_1 = false;
@@ -207,13 +207,13 @@
}
)";
- auto got = Run<Unshadow>(src);
+ auto got = Run<Unshadow>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(UnshadowTest, LocalShadowsFunction_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn b() {
let a = true;
let b = false;
@@ -226,7 +226,7 @@
)";
- auto* expect = R"(
+ auto* expect = R"(
fn b() {
let a_1 = true;
let b_1 = false;
@@ -238,13 +238,13 @@
}
)";
- auto got = Run<Unshadow>(src);
+ auto got = Run<Unshadow>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(UnshadowTest, LocalShadowsGlobalVar) {
- auto* src = R"(
+ auto* src = R"(
var<private> a : i32;
fn X() {
@@ -256,7 +256,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> a : i32;
fn X() {
@@ -268,13 +268,13 @@
}
)";
- auto got = Run<Unshadow>(src);
+ auto got = Run<Unshadow>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(UnshadowTest, LocalShadowsGlobalVar_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn X() {
var a = (a == 123);
}
@@ -286,7 +286,7 @@
var<private> a : i32;
)";
- auto* expect = R"(
+ auto* expect = R"(
fn X() {
var a_1 = (a == 123);
}
@@ -298,13 +298,13 @@
var<private> a : i32;
)";
- auto got = Run<Unshadow>(src);
+ auto got = Run<Unshadow>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(UnshadowTest, LocalShadowsGlobalLet) {
- auto* src = R"(
+ auto* src = R"(
let a : i32 = 1;
fn X() {
@@ -316,7 +316,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
let a : i32 = 1;
fn X() {
@@ -328,13 +328,13 @@
}
)";
- auto got = Run<Unshadow>(src);
+ auto got = Run<Unshadow>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(UnshadowTest, LocalShadowsGlobalLet_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn X() {
var a = (a == 123);
}
@@ -346,7 +346,7 @@
let a : i32 = 1;
)";
- auto* expect = R"(
+ auto* expect = R"(
fn X() {
var a_1 = (a == 123);
}
@@ -358,13 +358,13 @@
let a : i32 = 1;
)";
- auto got = Run<Unshadow>(src);
+ auto got = Run<Unshadow>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(UnshadowTest, LocalShadowsLocalVar) {
- auto* src = R"(
+ auto* src = R"(
fn X() {
var a : i32;
{
@@ -376,7 +376,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn X() {
var a : i32;
{
@@ -388,13 +388,13 @@
}
)";
- auto got = Run<Unshadow>(src);
+ auto got = Run<Unshadow>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(UnshadowTest, LocalShadowsLocalLet) {
- auto* src = R"(
+ auto* src = R"(
fn X() {
let a = 1;
{
@@ -406,7 +406,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn X() {
let a = 1;
{
@@ -418,13 +418,13 @@
}
)";
- auto got = Run<Unshadow>(src);
+ auto got = Run<Unshadow>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(UnshadowTest, LocalShadowsParam) {
- auto* src = R"(
+ auto* src = R"(
fn F(a : i32) {
{
var a = (a == 123);
@@ -435,7 +435,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn F(a : i32) {
{
var a_1 = (a == 123);
@@ -446,13 +446,13 @@
}
)";
- auto got = Run<Unshadow>(src);
+ auto got = Run<Unshadow>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(UnshadowTest, ParamShadowsFunction) {
- auto* src = R"(
+ auto* src = R"(
fn a(a : i32) {
{
var a = (a == 123);
@@ -463,7 +463,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn a(a_1 : i32) {
{
var a_2 = (a_1 == 123);
@@ -474,73 +474,73 @@
}
)";
- auto got = Run<Unshadow>(src);
+ auto got = Run<Unshadow>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(UnshadowTest, ParamShadowsGlobalVar) {
- auto* src = R"(
+ auto* src = R"(
var<private> a : i32;
fn F(a : bool) {
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> a : i32;
fn F(a_1 : bool) {
}
)";
- auto got = Run<Unshadow>(src);
+ auto got = Run<Unshadow>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(UnshadowTest, ParamShadowsGlobalLet) {
- auto* src = R"(
+ auto* src = R"(
let a : i32 = 1;
fn F(a : bool) {
}
)";
- auto* expect = R"(
+ auto* expect = R"(
let a : i32 = 1;
fn F(a_1 : bool) {
}
)";
- auto got = Run<Unshadow>(src);
+ auto got = Run<Unshadow>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(UnshadowTest, ParamShadowsGlobalLet_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn F(a : bool) {
}
let a : i32 = 1;
)";
- auto* expect = R"(
+ auto* expect = R"(
fn F(a_1 : bool) {
}
let a : i32 = 1;
)";
- auto got = Run<Unshadow>(src);
+ auto got = Run<Unshadow>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(UnshadowTest, ParamShadowsAlias) {
- auto* src = R"(
+ auto* src = R"(
type a = i32;
fn F(a : a) {
@@ -553,7 +553,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
type a = i32;
fn F(a_1 : a) {
@@ -566,13 +566,13 @@
}
)";
- auto got = Run<Unshadow>(src);
+ auto got = Run<Unshadow>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(UnshadowTest, ParamShadowsAlias_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn F(a : a) {
{
var a = (a == 123);
@@ -585,7 +585,7 @@
type a = i32;
)";
- auto* expect = R"(
+ auto* expect = R"(
fn F(a_1 : a) {
{
var a_2 = (a_1 == 123);
@@ -598,9 +598,9 @@
type a = i32;
)";
- auto got = Run<Unshadow>(src);
+ auto got = Run<Unshadow>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
} // namespace
diff --git a/src/tint/transform/unwind_discard_functions.cc b/src/tint/transform/unwind_discard_functions.cc
index 81b1d6d..e1ba74c 100644
--- a/src/tint/transform/unwind_discard_functions.cc
+++ b/src/tint/transform/unwind_discard_functions.cc
@@ -36,304 +36,299 @@
namespace {
class State {
- private:
- CloneContext& ctx;
- ProgramBuilder& b;
- const sem::Info& sem;
- Symbol module_discard_var_name; // Use ModuleDiscardVarName() to read
- Symbol module_discard_func_name; // Use ModuleDiscardFuncName() to read
+ private:
+ CloneContext& ctx;
+ ProgramBuilder& b;
+ const sem::Info& sem;
+ Symbol module_discard_var_name; // Use ModuleDiscardVarName() to read
+ Symbol module_discard_func_name; // Use ModuleDiscardFuncName() to read
- // Returns true if `sem_expr` contains a call expression that may
- // (transitively) execute a discard statement.
- bool MayDiscard(const sem::Expression* sem_expr) {
- return sem_expr && sem_expr->Behaviors().Contains(sem::Behavior::kDiscard);
- }
-
- // Lazily creates and returns the name of the module bool variable for whether
- // to discard: "tint_discard".
- Symbol ModuleDiscardVarName() {
- if (!module_discard_var_name.IsValid()) {
- module_discard_var_name = b.Symbols().New("tint_discard");
- ctx.dst->Global(module_discard_var_name, b.ty.bool_(), b.Expr(false),
- ast::StorageClass::kPrivate);
- }
- return module_discard_var_name;
- }
-
- // Lazily creates and returns the name of the function that contains a single
- // discard statement: "tint_discard_func".
- // We do this to avoid having multiple discard statements in a single program,
- // which causes problems in certain backends (see crbug.com/1118).
- Symbol ModuleDiscardFuncName() {
- if (!module_discard_func_name.IsValid()) {
- module_discard_func_name = b.Symbols().New("tint_discard_func");
- b.Func(module_discard_func_name, {}, b.ty.void_(), {b.Discard()});
- }
- return module_discard_func_name;
- }
-
- // Creates "return <default return value>;" based on the return type of
- // `stmt`'s owning function.
- const ast::ReturnStatement* Return(const ast::Statement* stmt) {
- const ast::Expression* ret_val = nullptr;
- auto* ret_type = sem.Get(stmt)->Function()->Declaration()->return_type;
- if (!ret_type->Is<ast::Void>()) {
- ret_val = b.Construct(ctx.Clone(ret_type));
- }
- return b.Return(ret_val);
- }
-
- // Returns true if the function `stmt` is in is an entry point
- bool IsInEntryPointFunc(const ast::Statement* stmt) {
- return sem.Get(stmt)->Function()->Declaration()->IsEntryPoint();
- }
-
- // Creates "tint_discard_func();"
- const ast::CallStatement* CallDiscardFunc() {
- auto func_name = ModuleDiscardFuncName();
- return b.CallStmt(b.Call(func_name));
- }
-
- // Creates and returns a new if-statement of the form:
- //
- // if (tint_discard) {
- // return <default value>;
- // }
- //
- // or if `stmt` is in a entry point function:
- //
- // if (tint_discard) {
- // tint_discard_func();
- // return <default value>;
- // }
- //
- const ast::IfStatement* IfDiscardReturn(const ast::Statement* stmt) {
- ast::StatementList stmts;
-
- // For entry point functions, also emit the discard statement
- if (IsInEntryPointFunc(stmt)) {
- stmts.emplace_back(CallDiscardFunc());
+ // Returns true if `sem_expr` contains a call expression that may
+ // (transitively) execute a discard statement.
+ bool MayDiscard(const sem::Expression* sem_expr) {
+ return sem_expr && sem_expr->Behaviors().Contains(sem::Behavior::kDiscard);
}
- stmts.emplace_back(Return(stmt));
-
- auto var_name = ModuleDiscardVarName();
- return b.If(var_name, b.Block(stmts));
- }
-
- // Hoists `sem_expr` to a let followed by an `IfDiscardReturn` before `stmt`.
- // For example, if `stmt` is:
- //
- // return f();
- //
- // This function will transform this to:
- //
- // let t1 = f();
- // if (tint_discard) {
- // return;
- // }
- // return t1;
- //
- const ast::Statement* HoistAndInsertBefore(const ast::Statement* stmt,
- const sem::Expression* sem_expr) {
- auto* expr = sem_expr->Declaration();
-
- auto ip = utils::GetInsertionPoint(ctx, stmt);
- auto var_name = b.Sym();
- auto* decl = b.Decl(b.Var(var_name, nullptr, ctx.Clone(expr)));
- ctx.InsertBefore(ip.first->Declaration()->statements, ip.second, decl);
-
- ctx.InsertBefore(ip.first->Declaration()->statements, ip.second,
- IfDiscardReturn(stmt));
-
- auto* var_expr = b.Expr(var_name);
-
- // Special handling for CallStatement as we can only replace its expression
- // with a CallExpression.
- if (stmt->Is<ast::CallStatement>()) {
- // We could replace the call statement with no statement, but we can't do
- // that with transforms (yet), so just return a phony assignment.
- return b.Assign(b.Phony(), var_expr);
+ // Lazily creates and returns the name of the module bool variable for whether
+ // to discard: "tint_discard".
+ Symbol ModuleDiscardVarName() {
+ if (!module_discard_var_name.IsValid()) {
+ module_discard_var_name = b.Symbols().New("tint_discard");
+ ctx.dst->Global(module_discard_var_name, b.ty.bool_(), b.Expr(false),
+ ast::StorageClass::kPrivate);
+ }
+ return module_discard_var_name;
}
- ctx.Replace(expr, var_expr);
- return ctx.CloneWithoutTransform(stmt);
- }
-
- // Returns true if `stmt` is a for-loop initializer statement.
- bool IsForLoopInitStatement(const ast::Statement* stmt) {
- if (auto* sem_stmt = sem.Get(stmt)) {
- if (auto* sem_fl = As<sem::ForLoopStatement>(sem_stmt->Parent())) {
- return sem_fl->Declaration()->initializer == stmt;
- }
- }
- return false;
- }
-
- // Inserts an `IfDiscardReturn` after `stmt` if possible (i.e. `stmt` is not
- // in a for-loop init), otherwise falls back to HoistAndInsertBefore, hoisting
- // `sem_expr` to a let followed by an `IfDiscardReturn` before `stmt`.
- //
- // For example, if `stmt` is:
- //
- // let r = f();
- //
- // This function will transform this to:
- //
- // let r = f();
- // if (tint_discard) {
- // return;
- // }
- const ast::Statement* TryInsertAfter(const ast::Statement* stmt,
- const sem::Expression* sem_expr) {
- // If `stmt` is the init of a for-loop, hoist and insert before instead.
- if (IsForLoopInitStatement(stmt)) {
- return HoistAndInsertBefore(stmt, sem_expr);
+ // Lazily creates and returns the name of the function that contains a single
+ // discard statement: "tint_discard_func".
+ // We do this to avoid having multiple discard statements in a single program,
+ // which causes problems in certain backends (see crbug.com/1118).
+ Symbol ModuleDiscardFuncName() {
+ if (!module_discard_func_name.IsValid()) {
+ module_discard_func_name = b.Symbols().New("tint_discard_func");
+ b.Func(module_discard_func_name, {}, b.ty.void_(), {b.Discard()});
+ }
+ return module_discard_func_name;
}
- auto ip = utils::GetInsertionPoint(ctx, stmt);
- ctx.InsertAfter(ip.first->Declaration()->statements, ip.second,
- IfDiscardReturn(stmt));
- return nullptr; // Don't replace current statement
- }
-
- // Replaces the input discard statement with either setting the module level
- // discard bool ("tint_discard = true"), or calling the discard function
- // ("tint_discard_func()"), followed by a default return statement.
- //
- // Replaces "discard;" with:
- //
- // tint_discard = true;
- // return;
- //
- // Or if `stmt` is a entry point function, replaces with:
- //
- // tint_discard_func();
- // return;
- //
- const ast::Statement* ReplaceDiscardStatement(
- const ast::DiscardStatement* stmt) {
- const ast::Statement* to_insert = nullptr;
- if (IsInEntryPointFunc(stmt)) {
- to_insert = CallDiscardFunc();
- } else {
- auto var_name = ModuleDiscardVarName();
- to_insert = b.Assign(var_name, true);
+ // Creates "return <default return value>;" based on the return type of
+ // `stmt`'s owning function.
+ const ast::ReturnStatement* Return(const ast::Statement* stmt) {
+ const ast::Expression* ret_val = nullptr;
+ auto* ret_type = sem.Get(stmt)->Function()->Declaration()->return_type;
+ if (!ret_type->Is<ast::Void>()) {
+ ret_val = b.Construct(ctx.Clone(ret_type));
+ }
+ return b.Return(ret_val);
}
- auto ip = utils::GetInsertionPoint(ctx, stmt);
- ctx.InsertBefore(ip.first->Declaration()->statements, ip.second, to_insert);
- return Return(stmt);
- }
+ // Returns true if the function `stmt` is in is an entry point
+ bool IsInEntryPointFunc(const ast::Statement* stmt) {
+ return sem.Get(stmt)->Function()->Declaration()->IsEntryPoint();
+ }
- // Handle statement
- const ast::Statement* Statement(const ast::Statement* stmt) {
- return Switch(
- stmt,
- [&](const ast::DiscardStatement* s) -> const ast::Statement* {
- return ReplaceDiscardStatement(s);
- },
- [&](const ast::AssignmentStatement* s) -> const ast::Statement* {
- auto* sem_lhs = sem.Get(s->lhs);
- auto* sem_rhs = sem.Get(s->rhs);
- if (MayDiscard(sem_lhs)) {
- if (MayDiscard(sem_rhs)) {
- TINT_ICE(Transform, b.Diagnostics())
- << "Unexpected: both sides of assignment statement may "
- "discard. Make sure transform::PromoteSideEffectsToDecl "
- "was run first.";
+ // Creates "tint_discard_func();"
+ const ast::CallStatement* CallDiscardFunc() {
+ auto func_name = ModuleDiscardFuncName();
+ return b.CallStmt(b.Call(func_name));
+ }
+
+ // Creates and returns a new if-statement of the form:
+ //
+ // if (tint_discard) {
+ // return <default value>;
+ // }
+ //
+ // or if `stmt` is in a entry point function:
+ //
+ // if (tint_discard) {
+ // tint_discard_func();
+ // return <default value>;
+ // }
+ //
+ const ast::IfStatement* IfDiscardReturn(const ast::Statement* stmt) {
+ ast::StatementList stmts;
+
+ // For entry point functions, also emit the discard statement
+ if (IsInEntryPointFunc(stmt)) {
+ stmts.emplace_back(CallDiscardFunc());
+ }
+
+ stmts.emplace_back(Return(stmt));
+
+ auto var_name = ModuleDiscardVarName();
+ return b.If(var_name, b.Block(stmts));
+ }
+
+ // Hoists `sem_expr` to a let followed by an `IfDiscardReturn` before `stmt`.
+ // For example, if `stmt` is:
+ //
+ // return f();
+ //
+ // This function will transform this to:
+ //
+ // let t1 = f();
+ // if (tint_discard) {
+ // return;
+ // }
+ // return t1;
+ //
+ const ast::Statement* HoistAndInsertBefore(const ast::Statement* stmt,
+ const sem::Expression* sem_expr) {
+ auto* expr = sem_expr->Declaration();
+
+ auto ip = utils::GetInsertionPoint(ctx, stmt);
+ auto var_name = b.Sym();
+ auto* decl = b.Decl(b.Var(var_name, nullptr, ctx.Clone(expr)));
+ ctx.InsertBefore(ip.first->Declaration()->statements, ip.second, decl);
+
+ ctx.InsertBefore(ip.first->Declaration()->statements, ip.second, IfDiscardReturn(stmt));
+
+ auto* var_expr = b.Expr(var_name);
+
+ // Special handling for CallStatement as we can only replace its expression
+ // with a CallExpression.
+ if (stmt->Is<ast::CallStatement>()) {
+ // We could replace the call statement with no statement, but we can't do
+ // that with transforms (yet), so just return a phony assignment.
+ return b.Assign(b.Phony(), var_expr);
+ }
+
+ ctx.Replace(expr, var_expr);
+ return ctx.CloneWithoutTransform(stmt);
+ }
+
+ // Returns true if `stmt` is a for-loop initializer statement.
+ bool IsForLoopInitStatement(const ast::Statement* stmt) {
+ if (auto* sem_stmt = sem.Get(stmt)) {
+ if (auto* sem_fl = As<sem::ForLoopStatement>(sem_stmt->Parent())) {
+ return sem_fl->Declaration()->initializer == stmt;
}
- return TryInsertAfter(s, sem_lhs);
- } else if (MayDiscard(sem_rhs)) {
- return TryInsertAfter(s, sem_rhs);
- }
- return nullptr;
- },
- [&](const ast::CallStatement* s) -> const ast::Statement* {
- auto* sem_expr = sem.Get(s->expr);
- if (!MayDiscard(sem_expr)) {
- return nullptr;
- }
- return TryInsertAfter(s, sem_expr);
- },
- [&](const ast::ForLoopStatement* s) -> const ast::Statement* {
- if (MayDiscard(sem.Get(s->condition))) {
- TINT_ICE(Transform, b.Diagnostics())
- << "Unexpected ForLoopStatement condition that may discard. "
- "Make sure transform::PromoteSideEffectsToDecl was run "
- "first.";
- }
- return nullptr;
- },
- [&](const ast::IfStatement* s) -> const ast::Statement* {
- auto* sem_expr = sem.Get(s->condition);
- if (!MayDiscard(sem_expr)) {
- return nullptr;
- }
- return HoistAndInsertBefore(s, sem_expr);
- },
- [&](const ast::ReturnStatement* s) -> const ast::Statement* {
- auto* sem_expr = sem.Get(s->value);
- if (!MayDiscard(sem_expr)) {
- return nullptr;
- }
- return HoistAndInsertBefore(s, sem_expr);
- },
- [&](const ast::SwitchStatement* s) -> const ast::Statement* {
- auto* sem_expr = sem.Get(s->condition);
- if (!MayDiscard(sem_expr)) {
- return nullptr;
- }
- return HoistAndInsertBefore(s, sem_expr);
- },
- [&](const ast::VariableDeclStatement* s) -> const ast::Statement* {
- auto* var = s->variable;
- if (!var->constructor) {
- return nullptr;
- }
- auto* sem_expr = sem.Get(var->constructor);
- if (!MayDiscard(sem_expr)) {
- return nullptr;
- }
- return TryInsertAfter(s, sem_expr);
- });
- }
+ }
+ return false;
+ }
- public:
- /// Constructor
- /// @param ctx_in the context
- explicit State(CloneContext& ctx_in)
- : ctx(ctx_in), b(*ctx_in.dst), sem(ctx_in.src->Sem()) {}
+ // Inserts an `IfDiscardReturn` after `stmt` if possible (i.e. `stmt` is not
+ // in a for-loop init), otherwise falls back to HoistAndInsertBefore, hoisting
+ // `sem_expr` to a let followed by an `IfDiscardReturn` before `stmt`.
+ //
+ // For example, if `stmt` is:
+ //
+ // let r = f();
+ //
+ // This function will transform this to:
+ //
+ // let r = f();
+ // if (tint_discard) {
+ // return;
+ // }
+ const ast::Statement* TryInsertAfter(const ast::Statement* stmt,
+ const sem::Expression* sem_expr) {
+ // If `stmt` is the init of a for-loop, hoist and insert before instead.
+ if (IsForLoopInitStatement(stmt)) {
+ return HoistAndInsertBefore(stmt, sem_expr);
+ }
- /// Runs the transform
- void Run() {
- ctx.ReplaceAll(
- [&](const ast::BlockStatement* block) -> const ast::Statement* {
- // Iterate block statements and replace them as needed.
- for (auto* stmt : block->statements) {
- if (auto* new_stmt = Statement(stmt)) {
- ctx.Replace(stmt, new_stmt);
+ auto ip = utils::GetInsertionPoint(ctx, stmt);
+ ctx.InsertAfter(ip.first->Declaration()->statements, ip.second, IfDiscardReturn(stmt));
+ return nullptr; // Don't replace current statement
+ }
+
+ // Replaces the input discard statement with either setting the module level
+ // discard bool ("tint_discard = true"), or calling the discard function
+ // ("tint_discard_func()"), followed by a default return statement.
+ //
+ // Replaces "discard;" with:
+ //
+ // tint_discard = true;
+ // return;
+ //
+ // Or if `stmt` is a entry point function, replaces with:
+ //
+ // tint_discard_func();
+ // return;
+ //
+ const ast::Statement* ReplaceDiscardStatement(const ast::DiscardStatement* stmt) {
+ const ast::Statement* to_insert = nullptr;
+ if (IsInEntryPointFunc(stmt)) {
+ to_insert = CallDiscardFunc();
+ } else {
+ auto var_name = ModuleDiscardVarName();
+ to_insert = b.Assign(var_name, true);
+ }
+
+ auto ip = utils::GetInsertionPoint(ctx, stmt);
+ ctx.InsertBefore(ip.first->Declaration()->statements, ip.second, to_insert);
+ return Return(stmt);
+ }
+
+ // Handle statement
+ const ast::Statement* Statement(const ast::Statement* stmt) {
+ return Switch(
+ stmt,
+ [&](const ast::DiscardStatement* s) -> const ast::Statement* {
+ return ReplaceDiscardStatement(s);
+ },
+ [&](const ast::AssignmentStatement* s) -> const ast::Statement* {
+ auto* sem_lhs = sem.Get(s->lhs);
+ auto* sem_rhs = sem.Get(s->rhs);
+ if (MayDiscard(sem_lhs)) {
+ if (MayDiscard(sem_rhs)) {
+ TINT_ICE(Transform, b.Diagnostics())
+ << "Unexpected: both sides of assignment statement may "
+ "discard. Make sure transform::PromoteSideEffectsToDecl "
+ "was run first.";
+ }
+ return TryInsertAfter(s, sem_lhs);
+ } else if (MayDiscard(sem_rhs)) {
+ return TryInsertAfter(s, sem_rhs);
+ }
+ return nullptr;
+ },
+ [&](const ast::CallStatement* s) -> const ast::Statement* {
+ auto* sem_expr = sem.Get(s->expr);
+ if (!MayDiscard(sem_expr)) {
+ return nullptr;
+ }
+ return TryInsertAfter(s, sem_expr);
+ },
+ [&](const ast::ForLoopStatement* s) -> const ast::Statement* {
+ if (MayDiscard(sem.Get(s->condition))) {
+ TINT_ICE(Transform, b.Diagnostics())
+ << "Unexpected ForLoopStatement condition that may discard. "
+ "Make sure transform::PromoteSideEffectsToDecl was run "
+ "first.";
+ }
+ return nullptr;
+ },
+ [&](const ast::IfStatement* s) -> const ast::Statement* {
+ auto* sem_expr = sem.Get(s->condition);
+ if (!MayDiscard(sem_expr)) {
+ return nullptr;
+ }
+ return HoistAndInsertBefore(s, sem_expr);
+ },
+ [&](const ast::ReturnStatement* s) -> const ast::Statement* {
+ auto* sem_expr = sem.Get(s->value);
+ if (!MayDiscard(sem_expr)) {
+ return nullptr;
+ }
+ return HoistAndInsertBefore(s, sem_expr);
+ },
+ [&](const ast::SwitchStatement* s) -> const ast::Statement* {
+ auto* sem_expr = sem.Get(s->condition);
+ if (!MayDiscard(sem_expr)) {
+ return nullptr;
+ }
+ return HoistAndInsertBefore(s, sem_expr);
+ },
+ [&](const ast::VariableDeclStatement* s) -> const ast::Statement* {
+ auto* var = s->variable;
+ if (!var->constructor) {
+ return nullptr;
+ }
+ auto* sem_expr = sem.Get(var->constructor);
+ if (!MayDiscard(sem_expr)) {
+ return nullptr;
+ }
+ return TryInsertAfter(s, sem_expr);
+ });
+ }
+
+ public:
+ /// Constructor
+ /// @param ctx_in the context
+ explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst), sem(ctx_in.src->Sem()) {}
+
+ /// Runs the transform
+ void Run() {
+ ctx.ReplaceAll([&](const ast::BlockStatement* block) -> const ast::Statement* {
+ // Iterate block statements and replace them as needed.
+ for (auto* stmt : block->statements) {
+ if (auto* new_stmt = Statement(stmt)) {
+ ctx.Replace(stmt, new_stmt);
+ }
+
+ // Handle for loops, as they are the only other AST node that
+ // contains statements outside of BlockStatements.
+ if (auto* fl = stmt->As<ast::ForLoopStatement>()) {
+ if (auto* new_stmt = Statement(fl->initializer)) {
+ ctx.Replace(fl->initializer, new_stmt);
+ }
+ if (auto* new_stmt = Statement(fl->continuing)) {
+ // NOTE: Should never reach here as we cannot discard in a
+ // continuing block.
+ ctx.Replace(fl->continuing, new_stmt);
+ }
+ }
}
- // Handle for loops, as they are the only other AST node that
- // contains statements outside of BlockStatements.
- if (auto* fl = stmt->As<ast::ForLoopStatement>()) {
- if (auto* new_stmt = Statement(fl->initializer)) {
- ctx.Replace(fl->initializer, new_stmt);
- }
- if (auto* new_stmt = Statement(fl->continuing)) {
- // NOTE: Should never reach here as we cannot discard in a
- // continuing block.
- ctx.Replace(fl->continuing, new_stmt);
- }
- }
- }
-
- return nullptr;
+ return nullptr;
});
- ctx.Clone();
- }
+ ctx.Clone();
+ }
};
} // namespace
@@ -341,22 +336,19 @@
UnwindDiscardFunctions::UnwindDiscardFunctions() = default;
UnwindDiscardFunctions::~UnwindDiscardFunctions() = default;
-void UnwindDiscardFunctions::Run(CloneContext& ctx,
- const DataMap&,
- DataMap&) const {
- State state(ctx);
- state.Run();
+void UnwindDiscardFunctions::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
+ State state(ctx);
+ state.Run();
}
-bool UnwindDiscardFunctions::ShouldRun(const Program* program,
- const DataMap& /*data*/) const {
- auto& sem = program->Sem();
- for (auto* f : program->AST().Functions()) {
- if (sem.Get(f)->Behaviors().Contains(sem::Behavior::kDiscard)) {
- return true;
+bool UnwindDiscardFunctions::ShouldRun(const Program* program, const DataMap& /*data*/) const {
+ auto& sem = program->Sem();
+ for (auto* f : program->AST().Functions()) {
+ if (sem.Get(f)->Behaviors().Contains(sem::Behavior::kDiscard)) {
+ return true;
+ }
}
- }
- return false;
+ return false;
}
} // namespace tint::transform
diff --git a/src/tint/transform/unwind_discard_functions.h b/src/tint/transform/unwind_discard_functions.h
index 42bbecb..3b1d838 100644
--- a/src/tint/transform/unwind_discard_functions.h
+++ b/src/tint/transform/unwind_discard_functions.h
@@ -36,31 +36,27 @@
///
/// @note Depends on the following transforms to have been run first:
/// * PromoteSideEffectsToDecl
-class UnwindDiscardFunctions
- : public Castable<UnwindDiscardFunctions, Transform> {
- public:
- /// Constructor
- UnwindDiscardFunctions();
+class UnwindDiscardFunctions : public Castable<UnwindDiscardFunctions, Transform> {
+ public:
+ /// Constructor
+ UnwindDiscardFunctions();
- /// Destructor
- ~UnwindDiscardFunctions() override;
+ /// Destructor
+ ~UnwindDiscardFunctions() override;
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx,
- const DataMap& inputs,
- DataMap& outputs) const override;
+ protected:
+ /// Runs the transform using the CloneContext built for transforming a
+ /// program. Run() is responsible for calling Clone() on the CloneContext.
+ /// @param ctx the CloneContext primed with the input program and
+ /// ProgramBuilder
+ /// @param inputs optional extra transform-specific input data
+ /// @param outputs optional extra transform-specific output data
+ void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program,
- const DataMap& data = {}) const override;
+ /// @param program the program to inspect
+ /// @param data optional extra transform-specific input data
+ /// @returns true if this transform should be run for the given program
+ bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/unwind_discard_functions_test.cc b/src/tint/transform/unwind_discard_functions_test.cc
index 0b8c0fc..70b4218 100644
--- a/src/tint/transform/unwind_discard_functions_test.cc
+++ b/src/tint/transform/unwind_discard_functions_test.cc
@@ -22,31 +22,31 @@
using UnwindDiscardFunctionsTest = TransformTest;
TEST_F(UnwindDiscardFunctionsTest, EmptyModule) {
- auto* src = "";
- auto* expect = src;
+ auto* src = "";
+ auto* expect = src;
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(UnwindDiscardFunctionsTest, ShouldRun_NoDiscardFunc) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
}
)";
- EXPECT_FALSE(ShouldRun<UnwindDiscardFunctions>(src));
+ EXPECT_FALSE(ShouldRun<UnwindDiscardFunctions>(src));
}
TEST_F(UnwindDiscardFunctionsTest, SingleDiscardFunc_NoCall) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
discard;
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> tint_discard : bool = false;
fn f() {
@@ -55,14 +55,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(UnwindDiscardFunctionsTest, MultipleDiscardFuncs_NoCall) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
discard;
let marker1 = 0;
@@ -73,7 +73,7 @@
let marker1 = 0;
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> tint_discard : bool = false;
fn f() {
@@ -89,14 +89,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(UnwindDiscardFunctionsTest, Call_VoidReturn) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
discard;
let marker1 = 0;
@@ -109,7 +109,7 @@
return vec4<f32>();
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> tint_discard : bool = false;
fn f() {
@@ -134,14 +134,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(UnwindDiscardFunctionsTest, Call_NonVoidReturn) {
- auto* src = R"(
+ auto* src = R"(
struct S {
x : i32,
y : i32,
@@ -164,7 +164,7 @@
return vec4<f32>();
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct S {
x : i32,
y : i32,
@@ -199,14 +199,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(UnwindDiscardFunctionsTest, Call_Nested) {
- auto* src = R"(
+ auto* src = R"(
fn f() -> i32 {
let marker1 = 0;
if (true) {
@@ -238,7 +238,7 @@
return vec4<f32>();
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> tint_discard : bool = false;
fn f() -> i32 {
@@ -288,14 +288,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(UnwindDiscardFunctionsTest, Call_Multiple) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
discard;
let marker1 = 0;
@@ -323,7 +323,7 @@
return vec4<f32>();
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> tint_discard : bool = false;
fn f() {
@@ -373,14 +373,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(UnwindDiscardFunctionsTest, Call_DiscardFuncDeclaredBelow) {
- auto* src = R"(
+ auto* src = R"(
@stage(fragment)
fn main(@builtin(position) coord_in: vec4<f32>) -> @location(0) vec4<f32> {
f();
@@ -393,7 +393,7 @@
let marker1 = 0;
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn tint_discard_func() {
discard;
}
@@ -418,14 +418,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(UnwindDiscardFunctionsTest, If) {
- auto* src = R"(
+ auto* src = R"(
fn f() -> i32 {
if (true) {
discard;
@@ -441,7 +441,7 @@
return vec4<f32>();
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> tint_discard : bool = false;
fn f() -> i32 {
@@ -470,14 +470,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(UnwindDiscardFunctionsTest, ElseIf) {
- auto* src = R"(
+ auto* src = R"(
fn f() -> i32 {
if (true) {
discard;
@@ -497,7 +497,7 @@
return vec4<f32>();
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> tint_discard : bool = false;
fn f() -> i32 {
@@ -532,14 +532,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(UnwindDiscardFunctionsTest, ForLoop_Init_Assignment) {
- auto* src = R"(
+ auto* src = R"(
fn f() -> i32 {
if (true) {
discard;
@@ -558,7 +558,7 @@
return vec4<f32>();
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> tint_discard : bool = false;
fn f() -> i32 {
@@ -590,14 +590,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(UnwindDiscardFunctionsTest, ForLoop_Init_Call) {
- auto* src = R"(
+ auto* src = R"(
fn f() -> i32 {
if (true) {
discard;
@@ -615,7 +615,7 @@
return vec4<f32>();
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> tint_discard : bool = false;
fn f() -> i32 {
@@ -646,14 +646,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(UnwindDiscardFunctionsTest, ForLoop_Init_VariableDecl) {
- auto* src = R"(
+ auto* src = R"(
fn f() -> i32 {
if (true) {
discard;
@@ -671,7 +671,7 @@
return vec4<f32>();
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> tint_discard : bool = false;
fn f() -> i32 {
@@ -702,14 +702,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(UnwindDiscardFunctionsTest, ForLoop_Cond) {
- auto* src = R"(
+ auto* src = R"(
fn f() -> i32 {
if (true) {
discard;
@@ -727,7 +727,7 @@
return vec4<f32>();
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> tint_discard : bool = false;
fn f() -> i32 {
@@ -763,14 +763,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(UnwindDiscardFunctionsTest, ForLoop_Cont) {
- auto* src = R"(
+ auto* src = R"(
fn f() -> i32 {
if (true) {
discard;
@@ -788,20 +788,20 @@
return vec4<f32>();
}
)";
- auto* expect =
- R"(test:12:12 error: cannot call a function that may discard inside a continuing block
+ auto* expect =
+ R"(test:12:12 error: cannot call a function that may discard inside a continuing block
for (; ; f()) {
^
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(UnwindDiscardFunctionsTest, Switch) {
- auto* src = R"(
+ auto* src = R"(
fn f() -> i32 {
if (true) {
discard;
@@ -828,7 +828,7 @@
return vec4<f32>();
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> tint_discard : bool = false;
fn f() -> i32 {
@@ -868,14 +868,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(UnwindDiscardFunctionsTest, Return) {
- auto* src = R"(
+ auto* src = R"(
struct S {
x : i32,
y : i32,
@@ -900,7 +900,7 @@
return vec4<f32>();
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct S {
x : i32,
y : i32,
@@ -941,14 +941,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(UnwindDiscardFunctionsTest, VariableDecl) {
- auto* src = R"(
+ auto* src = R"(
fn f() -> i32 {
if (true) {
discard;
@@ -963,7 +963,7 @@
return vec4<f32>();
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> tint_discard : bool = false;
fn f() -> i32 {
@@ -990,14 +990,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(UnwindDiscardFunctionsTest, Assignment_RightDiscard) {
- auto* src = R"(
+ auto* src = R"(
fn f() -> i32 {
if (true) {
discard;
@@ -1013,7 +1013,7 @@
return vec4<f32>();
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> tint_discard : bool = false;
fn f() -> i32 {
@@ -1041,14 +1041,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(UnwindDiscardFunctionsTest, Assignment_LeftDiscard) {
- auto* src = R"(
+ auto* src = R"(
fn f() -> i32 {
if (true) {
discard;
@@ -1064,7 +1064,7 @@
return vec4<f32>();
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> tint_discard : bool = false;
fn f() -> i32 {
@@ -1093,14 +1093,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(UnwindDiscardFunctionsTest, Assignment_BothDiscard) {
- auto* src = R"(
+ auto* src = R"(
fn f() -> i32 {
if (true) {
discard;
@@ -1123,7 +1123,7 @@
return vec4<f32>();
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> tint_discard : bool = false;
fn f() -> i32 {
@@ -1165,14 +1165,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(UnwindDiscardFunctionsTest, Binary_Arith_MultipleDiscardFuncs) {
- auto* src = R"(
+ auto* src = R"(
fn f() -> i32 {
if (true) {
discard;
@@ -1202,7 +1202,7 @@
return vec4<f32>();
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> tint_discard : bool = false;
fn f() -> i32 {
@@ -1257,14 +1257,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(UnwindDiscardFunctionsTest, Binary_Logical_MultipleDiscardFuncs) {
- auto* src = R"(
+ auto* src = R"(
fn f() -> i32 {
if (true) {
discard;
@@ -1294,7 +1294,7 @@
return vec4<f32>();
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> tint_discard : bool = false;
fn f() -> i32 {
@@ -1357,14 +1357,14 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(UnwindDiscardFunctionsTest, EnsureNoSymbolCollision) {
- auto* src = R"(
+ auto* src = R"(
var<private> tint_discard_func : i32;
var<private> tint_discard : i32;
@@ -1380,7 +1380,7 @@
return vec4<f32>();
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<private> tint_discard_func : i32;
var<private> tint_discard : i32;
@@ -1409,10 +1409,10 @@
}
)";
- DataMap data;
- auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
} // namespace
diff --git a/src/tint/transform/utils/get_insertion_point.cc b/src/tint/transform/utils/get_insertion_point.cc
index 0f00e0c..d10d134 100644
--- a/src/tint/transform/utils/get_insertion_point.cc
+++ b/src/tint/transform/utils/get_insertion_point.cc
@@ -19,40 +19,39 @@
namespace tint::transform::utils {
-InsertionPoint GetInsertionPoint(CloneContext& ctx,
- const ast::Statement* stmt) {
- auto& sem = ctx.src->Sem();
- auto& diag = ctx.dst->Diagnostics();
- using RetType = std::pair<const sem::BlockStatement*, const ast::Statement*>;
+InsertionPoint GetInsertionPoint(CloneContext& ctx, const ast::Statement* stmt) {
+ auto& sem = ctx.src->Sem();
+ auto& diag = ctx.dst->Diagnostics();
+ using RetType = std::pair<const sem::BlockStatement*, const ast::Statement*>;
- if (auto* sem_stmt = sem.Get(stmt)) {
- auto* parent = sem_stmt->Parent();
- return Switch(
- parent,
- [&](const sem::BlockStatement* block) -> RetType {
- // Common case, can insert in the current block above/below the input
- // statement.
- return {block, stmt};
- },
- [&](const sem::ForLoopStatement* fl) -> RetType {
- // `stmt` is either the for loop initializer or the continuing
- // statement of a for-loop.
- if (fl->Declaration()->initializer == stmt) {
- // For loop init, can insert above the for loop itself.
- return {fl->Block(), fl->Declaration()};
- }
+ if (auto* sem_stmt = sem.Get(stmt)) {
+ auto* parent = sem_stmt->Parent();
+ return Switch(
+ parent,
+ [&](const sem::BlockStatement* block) -> RetType {
+ // Common case, can insert in the current block above/below the input
+ // statement.
+ return {block, stmt};
+ },
+ [&](const sem::ForLoopStatement* fl) -> RetType {
+ // `stmt` is either the for loop initializer or the continuing
+ // statement of a for-loop.
+ if (fl->Declaration()->initializer == stmt) {
+ // For loop init, can insert above the for loop itself.
+ return {fl->Block(), fl->Declaration()};
+ }
- // Cannot insert before or after continuing statement of a for-loop
- return {};
- },
- [&](Default) -> RetType {
- TINT_ICE(Transform, diag) << "expected parent of statement to be "
- "either a block or for loop";
- return {};
- });
- }
+ // Cannot insert before or after continuing statement of a for-loop
+ return {};
+ },
+ [&](Default) -> RetType {
+ TINT_ICE(Transform, diag) << "expected parent of statement to be "
+ "either a block or for loop";
+ return {};
+ });
+ }
- return {};
+ return {};
}
} // namespace tint::transform::utils
diff --git a/src/tint/transform/utils/get_insertion_point.h b/src/tint/transform/utils/get_insertion_point.h
index 85abcea..14e867c 100644
--- a/src/tint/transform/utils/get_insertion_point.h
+++ b/src/tint/transform/utils/get_insertion_point.h
@@ -24,8 +24,7 @@
/// InsertionPoint is a pair of the block (`first`) within which, and the
/// statement (`second`) before or after which to insert.
-using InsertionPoint =
- std::pair<const sem::BlockStatement*, const ast::Statement*>;
+using InsertionPoint = std::pair<const sem::BlockStatement*, const ast::Statement*>;
/// For the input statement, returns the block and statement within that
/// block to insert before/after. If `stmt` is a for-loop continue statement,
diff --git a/src/tint/transform/utils/get_insertion_point_test.cc b/src/tint/transform/utils/get_insertion_point_test.cc
index 48e358e..83eb76d 100644
--- a/src/tint/transform/utils/get_insertion_point_test.cc
+++ b/src/tint/transform/utils/get_insertion_point_test.cc
@@ -26,68 +26,68 @@
using GetInsertionPointTest = ::testing::Test;
TEST_F(GetInsertionPointTest, Block) {
- // fn f() {
- // var a = 1;
- // }
- ProgramBuilder b;
- auto* expr = b.Expr(1);
- auto* var = b.Decl(b.Var("a", nullptr, expr));
- auto* block = b.Block(var);
- b.Func("f", {}, b.ty.void_(), {block});
+ // fn f() {
+ // var a = 1;
+ // }
+ ProgramBuilder b;
+ auto* expr = b.Expr(1);
+ auto* var = b.Decl(b.Var("a", nullptr, expr));
+ auto* block = b.Block(var);
+ b.Func("f", {}, b.ty.void_(), {block});
- Program original(std::move(b));
- ProgramBuilder cloned_b;
- CloneContext ctx(&cloned_b, &original);
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
- // Can insert in block containing the variable, above or below the input
- // statement.
- auto ip = utils::GetInsertionPoint(ctx, var);
- ASSERT_EQ(ip.first->Declaration(), block);
- ASSERT_EQ(ip.second, var);
+ // Can insert in block containing the variable, above or below the input
+ // statement.
+ auto ip = utils::GetInsertionPoint(ctx, var);
+ ASSERT_EQ(ip.first->Declaration(), block);
+ ASSERT_EQ(ip.second, var);
}
TEST_F(GetInsertionPointTest, ForLoopInit) {
- // fn f() {
- // for(var a = 1; true; ) {
- // }
- // }
- ProgramBuilder b;
- auto* expr = b.Expr(1);
- auto* var = b.Decl(b.Var("a", nullptr, expr));
- auto* fl = b.For(var, b.Expr(true), {}, b.Block());
- auto* func_block = b.Block(fl);
- b.Func("f", {}, b.ty.void_(), {func_block});
+ // fn f() {
+ // for(var a = 1; true; ) {
+ // }
+ // }
+ ProgramBuilder b;
+ auto* expr = b.Expr(1);
+ auto* var = b.Decl(b.Var("a", nullptr, expr));
+ auto* fl = b.For(var, b.Expr(true), {}, b.Block());
+ auto* func_block = b.Block(fl);
+ b.Func("f", {}, b.ty.void_(), {func_block});
- Program original(std::move(b));
- ProgramBuilder cloned_b;
- CloneContext ctx(&cloned_b, &original);
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
- // Can insert in block containing for-loop above the for-loop itself.
- auto ip = utils::GetInsertionPoint(ctx, var);
- ASSERT_EQ(ip.first->Declaration(), func_block);
- ASSERT_EQ(ip.second, fl);
+ // Can insert in block containing for-loop above the for-loop itself.
+ auto ip = utils::GetInsertionPoint(ctx, var);
+ ASSERT_EQ(ip.first->Declaration(), func_block);
+ ASSERT_EQ(ip.second, fl);
}
TEST_F(GetInsertionPointTest, ForLoopCont_Invalid) {
- // fn f() {
- // for(; true; var a = 1) {
- // }
- // }
- ProgramBuilder b;
- auto* expr = b.Expr(1);
- auto* var = b.Decl(b.Var("a", nullptr, expr));
- auto* s = b.For({}, b.Expr(true), var, b.Block());
- b.Func("f", {}, b.ty.void_(), {s});
+ // fn f() {
+ // for(; true; var a = 1) {
+ // }
+ // }
+ ProgramBuilder b;
+ auto* expr = b.Expr(1);
+ auto* var = b.Decl(b.Var("a", nullptr, expr));
+ auto* s = b.For({}, b.Expr(true), var, b.Block());
+ b.Func("f", {}, b.ty.void_(), {s});
- Program original(std::move(b));
- ProgramBuilder cloned_b;
- CloneContext ctx(&cloned_b, &original);
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
- // Can't insert before/after for loop continue statement (would ned to be
- // converted to loop).
- auto ip = utils::GetInsertionPoint(ctx, var);
- ASSERT_EQ(ip.first, nullptr);
- ASSERT_EQ(ip.second, nullptr);
+ // Can't insert before/after for loop continue statement (would need to be
+ // converted to loop).
+ auto ip = utils::GetInsertionPoint(ctx, var);
+ ASSERT_EQ(ip.first, nullptr);
+ ASSERT_EQ(ip.second, nullptr);
}
} // namespace
diff --git a/src/tint/transform/utils/hoist_to_decl_before.cc b/src/tint/transform/utils/hoist_to_decl_before.cc
index 78ba992..9917412 100644
--- a/src/tint/transform/utils/hoist_to_decl_before.cc
+++ b/src/tint/transform/utils/hoist_to_decl_before.cc
@@ -28,247 +28,241 @@
/// Private implementation of HoistToDeclBefore transform
class HoistToDeclBefore::State {
- CloneContext& ctx;
- ProgramBuilder& b;
+ CloneContext& ctx;
+ ProgramBuilder& b;
- /// Holds information about a for-loop that needs to be decomposed into a
- /// loop, so that declaration statements can be inserted before the
- /// condition expression or continuing statement.
- struct LoopInfo {
- ast::StatementList cond_decls;
- ast::StatementList cont_decls;
- };
+ /// Holds information about a for-loop that needs to be decomposed into a
+ /// loop, so that declaration statements can be inserted before the
+ /// condition expression or continuing statement.
+ struct LoopInfo {
+ ast::StatementList cond_decls;
+ ast::StatementList cont_decls;
+ };
- /// Info for each else-if that needs decomposing
- struct ElseIfInfo {
- /// Decls to insert before condition
- ast::StatementList cond_decls;
- };
+ /// Info for each else-if that needs decomposing
+ struct ElseIfInfo {
+ /// Decls to insert before condition
+ ast::StatementList cond_decls;
+ };
- /// For-loops that need to be decomposed to loops.
- std::unordered_map<const sem::ForLoopStatement*, LoopInfo> loops;
+ /// For-loops that need to be decomposed to loops.
+ std::unordered_map<const sem::ForLoopStatement*, LoopInfo> loops;
- /// 'else if' statements that need to be decomposed to 'else {if}'
- std::unordered_map<const ast::IfStatement*, ElseIfInfo> else_ifs;
+ /// 'else if' statements that need to be decomposed to 'else {if}'
+ std::unordered_map<const ast::IfStatement*, ElseIfInfo> else_ifs;
- // Converts any for-loops marked for conversion to loops, inserting
- // registered declaration statements before the condition or continuing
- // statement.
- void ForLoopsToLoops() {
- if (loops.empty()) {
- return;
- }
+ // Converts any for-loops marked for conversion to loops, inserting
+ // registered declaration statements before the condition or continuing
+ // statement.
+ void ForLoopsToLoops() {
+ if (loops.empty()) {
+ return;
+ }
- // At least one for-loop needs to be transformed into a loop.
- ctx.ReplaceAll(
- [&](const ast::ForLoopStatement* stmt) -> const ast::Statement* {
- auto& sem = ctx.src->Sem();
+ // At least one for-loop needs to be transformed into a loop.
+ ctx.ReplaceAll([&](const ast::ForLoopStatement* stmt) -> const ast::Statement* {
+ auto& sem = ctx.src->Sem();
- if (auto* fl = sem.Get(stmt)) {
- if (auto it = loops.find(fl); it != loops.end()) {
- auto& info = it->second;
- auto* for_loop = fl->Declaration();
- // For-loop needs to be decomposed to a loop.
- // Build the loop body's statements.
- // Start with any let declarations for the conditional
- // expression.
- auto body_stmts = info.cond_decls;
- // If the for-loop has a condition, emit this next as:
- // if (!cond) { break; }
- if (auto* cond = for_loop->condition) {
- // !condition
- auto* not_cond = b.create<ast::UnaryOpExpression>(
- ast::UnaryOp::kNot, ctx.Clone(cond));
- // { break; }
- auto* break_body = b.Block(b.create<ast::BreakStatement>());
- // if (!condition) { break; }
- body_stmts.emplace_back(b.If(not_cond, break_body));
- }
- // Next emit the for-loop body
- body_stmts.emplace_back(ctx.Clone(for_loop->body));
+ if (auto* fl = sem.Get(stmt)) {
+ if (auto it = loops.find(fl); it != loops.end()) {
+ auto& info = it->second;
+ auto* for_loop = fl->Declaration();
+ // For-loop needs to be decomposed to a loop.
+ // Build the loop body's statements.
+ // Start with any let declarations for the conditional
+ // expression.
+ auto body_stmts = info.cond_decls;
+ // If the for-loop has a condition, emit this next as:
+ // if (!cond) { break; }
+ if (auto* cond = for_loop->condition) {
+ // !condition
+ auto* not_cond =
+ b.create<ast::UnaryOpExpression>(ast::UnaryOp::kNot, ctx.Clone(cond));
+ // { break; }
+ auto* break_body = b.Block(b.create<ast::BreakStatement>());
+ // if (!condition) { break; }
+ body_stmts.emplace_back(b.If(not_cond, break_body));
+ }
+ // Next emit the for-loop body
+ body_stmts.emplace_back(ctx.Clone(for_loop->body));
- // Finally create the continuing block if there was one.
- const ast::BlockStatement* continuing = nullptr;
- if (auto* cont = for_loop->continuing) {
- // Continuing block starts with any let declarations used by
- // the continuing.
- auto cont_stmts = info.cont_decls;
- cont_stmts.emplace_back(ctx.Clone(cont));
- continuing = b.Block(cont_stmts);
- }
+ // Finally create the continuing block if there was one.
+ const ast::BlockStatement* continuing = nullptr;
+ if (auto* cont = for_loop->continuing) {
+ // Continuing block starts with any let declarations used by
+ // the continuing.
+ auto cont_stmts = info.cont_decls;
+ cont_stmts.emplace_back(ctx.Clone(cont));
+ continuing = b.Block(cont_stmts);
+ }
- auto* body = b.Block(body_stmts);
- auto* loop = b.Loop(body, continuing);
- if (auto* init = for_loop->initializer) {
- return b.Block(ctx.Clone(init), loop);
- }
- return loop;
+ auto* body = b.Block(body_stmts);
+ auto* loop = b.Loop(body, continuing);
+ if (auto* init = for_loop->initializer) {
+ return b.Block(ctx.Clone(init), loop);
+ }
+ return loop;
+ }
}
- }
- return nullptr;
- });
- }
-
- void ElseIfsToElseWithNestedIfs() {
- // Decompose 'else-if' statements into 'else { if }' blocks.
- ctx.ReplaceAll(
- [&](const ast::IfStatement* else_if) -> const ast::Statement* {
- if (!else_ifs.count(else_if)) {
return nullptr;
- }
- auto& else_if_info = else_ifs[else_if];
-
- // Build the else block's body statements, starting with let decls for
- // the conditional expression.
- auto& body_stmts = else_if_info.cond_decls;
-
- // Move the 'else-if' into the new `else` block as a plain 'if'.
- auto* cond = ctx.Clone(else_if->condition);
- auto* body = ctx.Clone(else_if->body);
- auto* new_if = b.If(cond, body, ctx.Clone(else_if->else_statement));
- body_stmts.emplace_back(new_if);
-
- // Replace the 'else-if' with the new 'else' block.
- return b.Block(body_stmts);
});
- }
-
- public:
- /// Constructor
- /// @param ctx_in the clone context
- explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst) {}
-
- /// Hoists `expr` to a `let` or `var` with optional `decl_name`, inserting it
- /// before `before_expr`.
- /// @param before_expr expression to insert `expr` before
- /// @param expr expression to hoist
- /// @param as_const hoist to `let` if true, otherwise to `var`
- /// @param decl_name optional name to use for the variable/constant name
- /// @return true on success
- bool Add(const sem::Expression* before_expr,
- const ast::Expression* expr,
- bool as_const,
- const char* decl_name) {
- auto name = b.Symbols().New(decl_name);
-
- // Construct the let/var that holds the hoisted expr
- auto* v = as_const ? b.Let(name, nullptr, ctx.Clone(expr))
- : b.Var(name, nullptr, ctx.Clone(expr));
- auto* decl = b.Decl(v);
-
- if (!InsertBefore(before_expr->Stmt(), decl)) {
- return false;
}
- // Replace the initializer expression with a reference to the let
- ctx.Replace(expr, b.Expr(name));
- return true;
- }
+ void ElseIfsToElseWithNestedIfs() {
+ // Decompose 'else-if' statements into 'else { if }' blocks.
+ ctx.ReplaceAll([&](const ast::IfStatement* else_if) -> const ast::Statement* {
+ if (!else_ifs.count(else_if)) {
+ return nullptr;
+ }
+ auto& else_if_info = else_ifs[else_if];
- /// Inserts `stmt` before `before_stmt`, possibly marking a for-loop to be
- /// converted to a loop, or an else-if to an else { if }. If `decl` is
- /// nullptr, for-loop and else-if conversions are marked, but no hoisting
- /// takes place.
- /// @param before_stmt statement to insert `stmt` before
- /// @param stmt statement to insert
- /// @return true on success
- bool InsertBefore(const sem::Statement* before_stmt,
- const ast::Statement* stmt) {
- auto* ip = before_stmt->Declaration();
+ // Build the else block's body statements, starting with let decls for
+ // the conditional expression.
+ auto& body_stmts = else_if_info.cond_decls;
- auto* else_if = before_stmt->As<sem::IfStatement>();
- if (else_if && else_if->Parent()->Is<sem::IfStatement>()) {
- // Insertion point is an 'else if' condition.
- // Need to convert 'else if' to 'else { if }'.
- auto& else_if_info = else_ifs[else_if->Declaration()];
+ // Move the 'else-if' into the new `else` block as a plain 'if'.
+ auto* cond = ctx.Clone(else_if->condition);
+ auto* body = ctx.Clone(else_if->body);
+ auto* new_if = b.If(cond, body, ctx.Clone(else_if->else_statement));
+ body_stmts.emplace_back(new_if);
- // Index the map to convert this else if, even if `stmt` is nullptr.
- auto& decls = else_if_info.cond_decls;
- if (stmt) {
- decls.emplace_back(stmt);
- }
- return true;
+ // Replace the 'else-if' with the new 'else' block.
+ return b.Block(body_stmts);
+ });
}
- if (auto* fl = before_stmt->As<sem::ForLoopStatement>()) {
- // Insertion point is a for-loop condition.
- // For-loop needs to be decomposed to a loop.
+ public:
+ /// Constructor
+ /// @param ctx_in the clone context
+ explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst) {}
- // Index the map to convert this for-loop, even if `stmt` is nullptr.
- auto& decls = loops[fl].cond_decls;
- if (stmt) {
- decls.emplace_back(stmt);
- }
- return true;
- }
+ /// Hoists `expr` to a `let` or `var` with optional `decl_name`, inserting it
+ /// before `before_expr`.
+ /// @param before_expr expression to insert `expr` before
+ /// @param expr expression to hoist
+ /// @param as_const hoist to `let` if true, otherwise to `var`
+ /// @param decl_name optional name to use for the variable/constant name
+ /// @return true on success
+ bool Add(const sem::Expression* before_expr,
+ const ast::Expression* expr,
+ bool as_const,
+ const char* decl_name) {
+ auto name = b.Symbols().New(decl_name);
- auto* parent = before_stmt->Parent(); // The statement's parent
- if (auto* block = parent->As<sem::BlockStatement>()) {
- // Insert point sits in a block. Simple case.
- // Insert the stmt before the parent statement.
- if (stmt) {
- ctx.InsertBefore(block->Declaration()->statements, ip, stmt);
- }
- return true;
- }
+ // Construct the let/var that holds the hoisted expr
+ auto* v = as_const ? b.Let(name, nullptr, ctx.Clone(expr))
+ : b.Var(name, nullptr, ctx.Clone(expr));
+ auto* decl = b.Decl(v);
- if (auto* fl = parent->As<sem::ForLoopStatement>()) {
- // Insertion point is a for-loop initializer or continuing statement.
- // These require special care.
- if (fl->Declaration()->initializer == ip) {
- // Insertion point is a for-loop initializer.
- // Insert the new statement above the for-loop.
- if (stmt) {
- ctx.InsertBefore(fl->Block()->Declaration()->statements,
- fl->Declaration(), stmt);
+ if (!InsertBefore(before_expr->Stmt(), decl)) {
+ return false;
}
+
+ // Replace the initializer expression with a reference to the let
+ ctx.Replace(expr, b.Expr(name));
return true;
- }
-
- if (fl->Declaration()->continuing == ip) {
- // Insertion point is a for-loop continuing statement.
- // For-loop needs to be decomposed to a loop.
-
- // Index the map to convert this for-loop, even if `stmt` is nullptr.
- auto& decls = loops[fl].cont_decls;
- if (stmt) {
- decls.emplace_back(stmt);
- }
- return true;
- }
-
- TINT_ICE(Transform, b.Diagnostics())
- << "unhandled use of expression in for-loop";
- return false;
}
- TINT_ICE(Transform, b.Diagnostics())
- << "unhandled expression parent statement type: "
- << parent->TypeInfo().name;
- return false;
- }
+ /// Inserts `stmt` before `before_stmt`, possibly marking a for-loop to be
+ /// converted to a loop, or an else-if to an else { if }. If `stmt` is
+ /// nullptr, for-loop and else-if conversions are marked, but no hoisting
+ /// takes place.
+ /// @param before_stmt statement to insert `stmt` before
+ /// @param stmt statement to insert
+ /// @return true on success
+ bool InsertBefore(const sem::Statement* before_stmt, const ast::Statement* stmt) {
+ auto* ip = before_stmt->Declaration();
- /// Use to signal that we plan on hoisting a decl before `before_expr`. This
- /// will convert 'for-loop's to 'loop's and 'else-if's to 'else {if}'s if
- /// needed.
- /// @param before_expr expression we would hoist a decl before
- /// @return true on success
- bool Prepare(const sem::Expression* before_expr) {
- return InsertBefore(before_expr->Stmt(), nullptr);
- }
+ auto* else_if = before_stmt->As<sem::IfStatement>();
+ if (else_if && else_if->Parent()->Is<sem::IfStatement>()) {
+ // Insertion point is an 'else if' condition.
+ // Need to convert 'else if' to 'else { if }'.
+ auto& else_if_info = else_ifs[else_if->Declaration()];
- /// Applies any scheduled insertions from previous calls to Add() to
- /// CloneContext. Call this once before ctx.Clone().
- /// @return true on success
- bool Apply() {
- ForLoopsToLoops();
- ElseIfsToElseWithNestedIfs();
- return true;
- }
+ // Index the map to convert this else if, even if `stmt` is nullptr.
+ auto& decls = else_if_info.cond_decls;
+ if (stmt) {
+ decls.emplace_back(stmt);
+ }
+ return true;
+ }
+
+ if (auto* fl = before_stmt->As<sem::ForLoopStatement>()) {
+ // Insertion point is a for-loop condition.
+ // For-loop needs to be decomposed to a loop.
+
+ // Index the map to convert this for-loop, even if `stmt` is nullptr.
+ auto& decls = loops[fl].cond_decls;
+ if (stmt) {
+ decls.emplace_back(stmt);
+ }
+ return true;
+ }
+
+ auto* parent = before_stmt->Parent(); // The statement's parent
+ if (auto* block = parent->As<sem::BlockStatement>()) {
+ // Insert point sits in a block. Simple case.
+ // Insert the stmt before the parent statement.
+ if (stmt) {
+ ctx.InsertBefore(block->Declaration()->statements, ip, stmt);
+ }
+ return true;
+ }
+
+ if (auto* fl = parent->As<sem::ForLoopStatement>()) {
+ // Insertion point is a for-loop initializer or continuing statement.
+ // These require special care.
+ if (fl->Declaration()->initializer == ip) {
+ // Insertion point is a for-loop initializer.
+ // Insert the new statement above the for-loop.
+ if (stmt) {
+ ctx.InsertBefore(fl->Block()->Declaration()->statements, fl->Declaration(),
+ stmt);
+ }
+ return true;
+ }
+
+ if (fl->Declaration()->continuing == ip) {
+ // Insertion point is a for-loop continuing statement.
+ // For-loop needs to be decomposed to a loop.
+
+ // Index the map to convert this for-loop, even if `stmt` is nullptr.
+ auto& decls = loops[fl].cont_decls;
+ if (stmt) {
+ decls.emplace_back(stmt);
+ }
+ return true;
+ }
+
+ TINT_ICE(Transform, b.Diagnostics()) << "unhandled use of expression in for-loop";
+ return false;
+ }
+
+ TINT_ICE(Transform, b.Diagnostics())
+ << "unhandled expression parent statement type: " << parent->TypeInfo().name;
+ return false;
+ }
+
+ /// Use to signal that we plan on hoisting a decl before `before_expr`. This
+ /// will convert 'for-loop's to 'loop's and 'else-if's to 'else {if}'s if
+ /// needed.
+ /// @param before_expr expression we would hoist a decl before
+ /// @return true on success
+ bool Prepare(const sem::Expression* before_expr) {
+ return InsertBefore(before_expr->Stmt(), nullptr);
+ }
+
+ /// Applies any scheduled insertions from previous calls to Add() to
+ /// CloneContext. Call this once before ctx.Clone().
+ /// @return true on success
+ bool Apply() {
+ ForLoopsToLoops();
+ ElseIfsToElseWithNestedIfs();
+ return true;
+ }
};
-HoistToDeclBefore::HoistToDeclBefore(CloneContext& ctx)
- : state_(std::make_unique<State>(ctx)) {}
+HoistToDeclBefore::HoistToDeclBefore(CloneContext& ctx) : state_(std::make_unique<State>(ctx)) {}
HoistToDeclBefore::~HoistToDeclBefore() {}
@@ -276,20 +270,20 @@
const ast::Expression* expr,
bool as_const,
const char* decl_name) {
- return state_->Add(before_expr, expr, as_const, decl_name);
+ return state_->Add(before_expr, expr, as_const, decl_name);
}
bool HoistToDeclBefore::InsertBefore(const sem::Statement* before_stmt,
const ast::Statement* stmt) {
- return state_->InsertBefore(before_stmt, stmt);
+ return state_->InsertBefore(before_stmt, stmt);
}
bool HoistToDeclBefore::Prepare(const sem::Expression* before_expr) {
- return state_->Prepare(before_expr);
+ return state_->Prepare(before_expr);
}
bool HoistToDeclBefore::Apply() {
- return state_->Apply();
+ return state_->Apply();
}
} // namespace tint::transform
diff --git a/src/tint/transform/utils/hoist_to_decl_before.h b/src/tint/transform/utils/hoist_to_decl_before.h
index 2d94f52..d0b96e0 100644
--- a/src/tint/transform/utils/hoist_to_decl_before.h
+++ b/src/tint/transform/utils/hoist_to_decl_before.h
@@ -26,49 +26,48 @@
/// expressions, possibly converting 'for-loop's to 'loop's and 'else-if's to
// 'else {if}'s.
class HoistToDeclBefore {
- public:
- /// Constructor
- /// @param ctx the clone context
- explicit HoistToDeclBefore(CloneContext& ctx);
+ public:
+ /// Constructor
+ /// @param ctx the clone context
+ explicit HoistToDeclBefore(CloneContext& ctx);
- /// Destructor
- ~HoistToDeclBefore();
+ /// Destructor
+ ~HoistToDeclBefore();
- /// Hoists `expr` to a `let` or `var` with optional `decl_name`, inserting it
- /// before `before_expr`.
- /// @param before_expr expression to insert `expr` before
- /// @param expr expression to hoist
- /// @param as_const hoist to `let` if true, otherwise to `var`
- /// @param decl_name optional name to use for the variable/constant name
- /// @return true on success
- bool Add(const sem::Expression* before_expr,
- const ast::Expression* expr,
- bool as_const,
- const char* decl_name = "");
+ /// Hoists `expr` to a `let` or `var` with optional `decl_name`, inserting it
+ /// before `before_expr`.
+ /// @param before_expr expression to insert `expr` before
+ /// @param expr expression to hoist
+ /// @param as_const hoist to `let` if true, otherwise to `var`
+ /// @param decl_name optional name to use for the variable/constant name
+ /// @return true on success
+ bool Add(const sem::Expression* before_expr,
+ const ast::Expression* expr,
+ bool as_const,
+ const char* decl_name = "");
- /// Inserts `stmt` before `before_stmt`, possibly converting 'for-loop's to
- /// 'loop's if necessary.
- /// @param before_stmt statement to insert `stmt` before
- /// @param stmt statement to insert
- /// @return true on success
- bool InsertBefore(const sem::Statement* before_stmt,
- const ast::Statement* stmt);
+ /// Inserts `stmt` before `before_stmt`, possibly converting 'for-loop's to
+ /// 'loop's if necessary.
+ /// @param before_stmt statement to insert `stmt` before
+ /// @param stmt statement to insert
+ /// @return true on success
+ bool InsertBefore(const sem::Statement* before_stmt, const ast::Statement* stmt);
- /// Use to signal that we plan on hoisting a decl before `before_expr`. This
- /// will convert 'for-loop's to 'loop's and 'else-if's to 'else {if}'s if
- /// needed.
- /// @param before_expr expression we would hoist a decl before
- /// @return true on success
- bool Prepare(const sem::Expression* before_expr);
+ /// Use to signal that we plan on hoisting a decl before `before_expr`. This
+ /// will convert 'for-loop's to 'loop's and 'else-if's to 'else {if}'s if
+ /// needed.
+ /// @param before_expr expression we would hoist a decl before
+ /// @return true on success
+ bool Prepare(const sem::Expression* before_expr);
- /// Applies any scheduled insertions from previous calls to Add() to
- /// CloneContext. Call this once before ctx.Clone().
- /// @return true on success
- bool Apply();
+ /// Applies any scheduled insertions from previous calls to Add() to
+ /// CloneContext. Call this once before ctx.Clone().
+ /// @return true on success
+ bool Apply();
- private:
- class State;
- std::unique_ptr<State> state_;
+ private:
+ class State;
+ std::unique_ptr<State> state_;
};
} // namespace tint::transform
diff --git a/src/tint/transform/utils/hoist_to_decl_before_test.cc b/src/tint/transform/utils/hoist_to_decl_before_test.cc
index e313d91..bce27df 100644
--- a/src/tint/transform/utils/hoist_to_decl_before_test.cc
+++ b/src/tint/transform/utils/hoist_to_decl_before_test.cc
@@ -27,60 +27,59 @@
using HoistToDeclBeforeTest = ::testing::Test;
TEST_F(HoistToDeclBeforeTest, VarInit) {
- // fn f() {
- // var a = 1;
- // }
- ProgramBuilder b;
- auto* expr = b.Expr(1);
- auto* var = b.Decl(b.Var("a", nullptr, expr));
- b.Func("f", {}, b.ty.void_(), {var});
+ // fn f() {
+ // var a = 1;
+ // }
+ ProgramBuilder b;
+ auto* expr = b.Expr(1);
+ auto* var = b.Decl(b.Var("a", nullptr, expr));
+ b.Func("f", {}, b.ty.void_(), {var});
- Program original(std::move(b));
- ProgramBuilder cloned_b;
- CloneContext ctx(&cloned_b, &original);
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
- HoistToDeclBefore hoistToDeclBefore(ctx);
- auto* sem_expr = ctx.src->Sem().Get(expr);
- hoistToDeclBefore.Add(sem_expr, expr, true);
- hoistToDeclBefore.Apply();
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* sem_expr = ctx.src->Sem().Get(expr);
+ hoistToDeclBefore.Add(sem_expr, expr, true);
+ hoistToDeclBefore.Apply();
- ctx.Clone();
- Program cloned(std::move(cloned_b));
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
let tint_symbol = 1;
var a = tint_symbol;
}
)";
- EXPECT_EQ(expect, str(cloned));
+ EXPECT_EQ(expect, str(cloned));
}
TEST_F(HoistToDeclBeforeTest, ForLoopInit) {
- // fn f() {
- // for(var a = 1; true; ) {
- // }
- // }
- ProgramBuilder b;
- auto* expr = b.Expr(1);
- auto* s =
- b.For(b.Decl(b.Var("a", nullptr, expr)), b.Expr(true), {}, b.Block());
- b.Func("f", {}, b.ty.void_(), {s});
+ // fn f() {
+ // for(var a = 1; true; ) {
+ // }
+ // }
+ ProgramBuilder b;
+ auto* expr = b.Expr(1);
+ auto* s = b.For(b.Decl(b.Var("a", nullptr, expr)), b.Expr(true), {}, b.Block());
+ b.Func("f", {}, b.ty.void_(), {s});
- Program original(std::move(b));
- ProgramBuilder cloned_b;
- CloneContext ctx(&cloned_b, &original);
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
- HoistToDeclBefore hoistToDeclBefore(ctx);
- auto* sem_expr = ctx.src->Sem().Get(expr);
- hoistToDeclBefore.Add(sem_expr, expr, true);
- hoistToDeclBefore.Apply();
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* sem_expr = ctx.src->Sem().Get(expr);
+ hoistToDeclBefore.Add(sem_expr, expr, true);
+ hoistToDeclBefore.Apply();
- ctx.Clone();
- Program cloned(std::move(cloned_b));
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
let tint_symbol = 1;
for(var a = tint_symbol; true; ) {
@@ -88,34 +87,34 @@
}
)";
- EXPECT_EQ(expect, str(cloned));
+ EXPECT_EQ(expect, str(cloned));
}
TEST_F(HoistToDeclBeforeTest, ForLoopCond) {
- // fn f() {
- // var a : bool;
- // for(; a; ) {
- // }
- // }
- ProgramBuilder b;
- auto* var = b.Decl(b.Var("a", b.ty.bool_()));
- auto* expr = b.Expr("a");
- auto* s = b.For({}, expr, {}, b.Block());
- b.Func("f", {}, b.ty.void_(), {var, s});
+ // fn f() {
+ // var a : bool;
+ // for(; a; ) {
+ // }
+ // }
+ ProgramBuilder b;
+ auto* var = b.Decl(b.Var("a", b.ty.bool_()));
+ auto* expr = b.Expr("a");
+ auto* s = b.For({}, expr, {}, b.Block());
+ b.Func("f", {}, b.ty.void_(), {var, s});
- Program original(std::move(b));
- ProgramBuilder cloned_b;
- CloneContext ctx(&cloned_b, &original);
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
- HoistToDeclBefore hoistToDeclBefore(ctx);
- auto* sem_expr = ctx.src->Sem().Get(expr);
- hoistToDeclBefore.Add(sem_expr, expr, true);
- hoistToDeclBefore.Apply();
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* sem_expr = ctx.src->Sem().Get(expr);
+ hoistToDeclBefore.Add(sem_expr, expr, true);
+ hoistToDeclBefore.Apply();
- ctx.Clone();
- Program cloned(std::move(cloned_b));
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var a : bool;
loop {
@@ -129,33 +128,32 @@
}
)";
- EXPECT_EQ(expect, str(cloned));
+ EXPECT_EQ(expect, str(cloned));
}
TEST_F(HoistToDeclBeforeTest, ForLoopCont) {
- // fn f() {
- // for(; true; var a = 1) {
- // }
- // }
- ProgramBuilder b;
- auto* expr = b.Expr(1);
- auto* s =
- b.For({}, b.Expr(true), b.Decl(b.Var("a", nullptr, expr)), b.Block());
- b.Func("f", {}, b.ty.void_(), {s});
+ // fn f() {
+ // for(; true; var a = 1) {
+ // }
+ // }
+ ProgramBuilder b;
+ auto* expr = b.Expr(1);
+ auto* s = b.For({}, b.Expr(true), b.Decl(b.Var("a", nullptr, expr)), b.Block());
+ b.Func("f", {}, b.ty.void_(), {s});
- Program original(std::move(b));
- ProgramBuilder cloned_b;
- CloneContext ctx(&cloned_b, &original);
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
- HoistToDeclBefore hoistToDeclBefore(ctx);
- auto* sem_expr = ctx.src->Sem().Get(expr);
- hoistToDeclBefore.Add(sem_expr, expr, true);
- hoistToDeclBefore.Apply();
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* sem_expr = ctx.src->Sem().Get(expr);
+ hoistToDeclBefore.Add(sem_expr, expr, true);
+ hoistToDeclBefore.Apply();
- ctx.Clone();
- Program cloned(std::move(cloned_b));
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
loop {
if (!(true)) {
@@ -172,38 +170,38 @@
}
)";
- EXPECT_EQ(expect, str(cloned));
+ EXPECT_EQ(expect, str(cloned));
}
TEST_F(HoistToDeclBeforeTest, ElseIf) {
- // fn f() {
- // var a : bool;
- // if (true) {
- // } else if (a) {
- // } else {
- // }
- // }
- ProgramBuilder b;
- auto* var = b.Decl(b.Var("a", b.ty.bool_()));
- auto* expr = b.Expr("a");
- auto* s = b.If(b.Expr(true), b.Block(), //
- b.If(expr, b.Block(), //
- b.Block()));
- b.Func("f", {}, b.ty.void_(), {var, s});
+ // fn f() {
+ // var a : bool;
+ // if (true) {
+ // } else if (a) {
+ // } else {
+ // }
+ // }
+ ProgramBuilder b;
+ auto* var = b.Decl(b.Var("a", b.ty.bool_()));
+ auto* expr = b.Expr("a");
+ auto* s = b.If(b.Expr(true), b.Block(), //
+ b.If(expr, b.Block(), //
+ b.Block()));
+ b.Func("f", {}, b.ty.void_(), {var, s});
- Program original(std::move(b));
- ProgramBuilder cloned_b;
- CloneContext ctx(&cloned_b, &original);
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
- HoistToDeclBefore hoistToDeclBefore(ctx);
- auto* sem_expr = ctx.src->Sem().Get(expr);
- hoistToDeclBefore.Add(sem_expr, expr, true);
- hoistToDeclBefore.Apply();
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* sem_expr = ctx.src->Sem().Get(expr);
+ hoistToDeclBefore.Add(sem_expr, expr, true);
+ hoistToDeclBefore.Apply();
- ctx.Clone();
- Program cloned(std::move(cloned_b));
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var a : bool;
if (true) {
@@ -216,33 +214,33 @@
}
)";
- EXPECT_EQ(expect, str(cloned));
+ EXPECT_EQ(expect, str(cloned));
}
TEST_F(HoistToDeclBeforeTest, Array1D) {
- // fn f() {
- // var a : array<i32, 10>;
- // var b = a[0];
- // }
- ProgramBuilder b;
- auto* var1 = b.Decl(b.Var("a", b.ty.array<ProgramBuilder::i32, 10>()));
- auto* expr = b.IndexAccessor("a", 0);
- auto* var2 = b.Decl(b.Var("b", nullptr, expr));
- b.Func("f", {}, b.ty.void_(), {var1, var2});
+ // fn f() {
+ // var a : array<i32, 10>;
+ // var b = a[0];
+ // }
+ ProgramBuilder b;
+ auto* var1 = b.Decl(b.Var("a", b.ty.array<ProgramBuilder::i32, 10>()));
+ auto* expr = b.IndexAccessor("a", 0);
+ auto* var2 = b.Decl(b.Var("b", nullptr, expr));
+ b.Func("f", {}, b.ty.void_(), {var1, var2});
- Program original(std::move(b));
- ProgramBuilder cloned_b;
- CloneContext ctx(&cloned_b, &original);
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
- HoistToDeclBefore hoistToDeclBefore(ctx);
- auto* sem_expr = ctx.src->Sem().Get(expr);
- hoistToDeclBefore.Add(sem_expr, expr, true);
- hoistToDeclBefore.Apply();
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* sem_expr = ctx.src->Sem().Get(expr);
+ hoistToDeclBefore.Add(sem_expr, expr, true);
+ hoistToDeclBefore.Apply();
- ctx.Clone();
- Program cloned(std::move(cloned_b));
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var a : array<i32, 10>;
let tint_symbol = a[0];
@@ -250,35 +248,34 @@
}
)";
- EXPECT_EQ(expect, str(cloned));
+ EXPECT_EQ(expect, str(cloned));
}
TEST_F(HoistToDeclBeforeTest, Array2D) {
- // fn f() {
- // var a : array<array<i32, 10>, 10>;
- // var b = a[0][0];
- // }
- ProgramBuilder b;
+ // fn f() {
+ // var a : array<array<i32, 10>, 10>;
+ // var b = a[0][0];
+ // }
+ ProgramBuilder b;
- auto* var1 =
- b.Decl(b.Var("a", b.ty.array(b.ty.array<ProgramBuilder::i32, 10>(), 10)));
- auto* expr = b.IndexAccessor(b.IndexAccessor("a", 0), 0);
- auto* var2 = b.Decl(b.Var("b", nullptr, expr));
- b.Func("f", {}, b.ty.void_(), {var1, var2});
+ auto* var1 = b.Decl(b.Var("a", b.ty.array(b.ty.array<ProgramBuilder::i32, 10>(), 10)));
+ auto* expr = b.IndexAccessor(b.IndexAccessor("a", 0), 0);
+ auto* var2 = b.Decl(b.Var("b", nullptr, expr));
+ b.Func("f", {}, b.ty.void_(), {var1, var2});
- Program original(std::move(b));
- ProgramBuilder cloned_b;
- CloneContext ctx(&cloned_b, &original);
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
- HoistToDeclBefore hoistToDeclBefore(ctx);
- auto* sem_expr = ctx.src->Sem().Get(expr);
- hoistToDeclBefore.Add(sem_expr, expr, true);
- hoistToDeclBefore.Apply();
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* sem_expr = ctx.src->Sem().Get(expr);
+ hoistToDeclBefore.Add(sem_expr, expr, true);
+ hoistToDeclBefore.Apply();
- ctx.Clone();
- Program cloned(std::move(cloned_b));
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var a : array<array<i32, 10>, 10>;
let tint_symbol = a[0][0];
@@ -286,34 +283,34 @@
}
)";
- EXPECT_EQ(expect, str(cloned));
+ EXPECT_EQ(expect, str(cloned));
}
TEST_F(HoistToDeclBeforeTest, Prepare_ForLoopCond) {
- // fn f() {
- // var a : bool;
- // for(; a; ) {
- // }
- // }
- ProgramBuilder b;
- auto* var = b.Decl(b.Var("a", b.ty.bool_()));
- auto* expr = b.Expr("a");
- auto* s = b.For({}, expr, {}, b.Block());
- b.Func("f", {}, b.ty.void_(), {var, s});
+ // fn f() {
+ // var a : bool;
+ // for(; a; ) {
+ // }
+ // }
+ ProgramBuilder b;
+ auto* var = b.Decl(b.Var("a", b.ty.bool_()));
+ auto* expr = b.Expr("a");
+ auto* s = b.For({}, expr, {}, b.Block());
+ b.Func("f", {}, b.ty.void_(), {var, s});
- Program original(std::move(b));
- ProgramBuilder cloned_b;
- CloneContext ctx(&cloned_b, &original);
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
- HoistToDeclBefore hoistToDeclBefore(ctx);
- auto* sem_expr = ctx.src->Sem().Get(expr);
- hoistToDeclBefore.Prepare(sem_expr);
- hoistToDeclBefore.Apply();
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* sem_expr = ctx.src->Sem().Get(expr);
+ hoistToDeclBefore.Prepare(sem_expr);
+ hoistToDeclBefore.Apply();
- ctx.Clone();
- Program cloned(std::move(cloned_b));
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var a : bool;
loop {
@@ -326,33 +323,32 @@
}
)";
- EXPECT_EQ(expect, str(cloned));
+ EXPECT_EQ(expect, str(cloned));
}
TEST_F(HoistToDeclBeforeTest, Prepare_ForLoopCont) {
- // fn f() {
- // for(; true; var a = 1) {
- // }
- // }
- ProgramBuilder b;
- auto* expr = b.Expr(1);
- auto* s =
- b.For({}, b.Expr(true), b.Decl(b.Var("a", nullptr, expr)), b.Block());
- b.Func("f", {}, b.ty.void_(), {s});
+ // fn f() {
+ // for(; true; var a = 1) {
+ // }
+ // }
+ ProgramBuilder b;
+ auto* expr = b.Expr(1);
+ auto* s = b.For({}, b.Expr(true), b.Decl(b.Var("a", nullptr, expr)), b.Block());
+ b.Func("f", {}, b.ty.void_(), {s});
- Program original(std::move(b));
- ProgramBuilder cloned_b;
- CloneContext ctx(&cloned_b, &original);
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
- HoistToDeclBefore hoistToDeclBefore(ctx);
- auto* sem_expr = ctx.src->Sem().Get(expr);
- hoistToDeclBefore.Prepare(sem_expr);
- hoistToDeclBefore.Apply();
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* sem_expr = ctx.src->Sem().Get(expr);
+ hoistToDeclBefore.Prepare(sem_expr);
+ hoistToDeclBefore.Apply();
- ctx.Clone();
- Program cloned(std::move(cloned_b));
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
loop {
if (!(true)) {
@@ -368,38 +364,38 @@
}
)";
- EXPECT_EQ(expect, str(cloned));
+ EXPECT_EQ(expect, str(cloned));
}
TEST_F(HoistToDeclBeforeTest, Prepare_ElseIf) {
- // fn f() {
- // var a : bool;
- // if (true) {
- // } else if (a) {
- // } else {
- // }
- // }
- ProgramBuilder b;
- auto* var = b.Decl(b.Var("a", b.ty.bool_()));
- auto* expr = b.Expr("a");
- auto* s = b.If(b.Expr(true), b.Block(), //
- b.If(expr, b.Block(), //
- b.Block()));
- b.Func("f", {}, b.ty.void_(), {var, s});
+ // fn f() {
+ // var a : bool;
+ // if (true) {
+ // } else if (a) {
+ // } else {
+ // }
+ // }
+ ProgramBuilder b;
+ auto* var = b.Decl(b.Var("a", b.ty.bool_()));
+ auto* expr = b.Expr("a");
+ auto* s = b.If(b.Expr(true), b.Block(), //
+ b.If(expr, b.Block(), //
+ b.Block()));
+ b.Func("f", {}, b.ty.void_(), {var, s});
- Program original(std::move(b));
- ProgramBuilder cloned_b;
- CloneContext ctx(&cloned_b, &original);
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
- HoistToDeclBefore hoistToDeclBefore(ctx);
- auto* sem_expr = ctx.src->Sem().Get(expr);
- hoistToDeclBefore.Prepare(sem_expr);
- hoistToDeclBefore.Apply();
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* sem_expr = ctx.src->Sem().Get(expr);
+ hoistToDeclBefore.Prepare(sem_expr);
+ hoistToDeclBefore.Apply();
- ctx.Clone();
- Program cloned(std::move(cloned_b));
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var a : bool;
if (true) {
@@ -411,34 +407,34 @@
}
)";
- EXPECT_EQ(expect, str(cloned));
+ EXPECT_EQ(expect, str(cloned));
}
TEST_F(HoistToDeclBeforeTest, InsertBefore_Block) {
- // fn foo() {
- // }
- // fn f() {
- // var a = 1;
- // }
- ProgramBuilder b;
- b.Func("foo", {}, b.ty.void_(), {});
- auto* var = b.Decl(b.Var("a", nullptr, b.Expr(1)));
- b.Func("f", {}, b.ty.void_(), {var});
+ // fn foo() {
+ // }
+ // fn f() {
+ // var a = 1;
+ // }
+ ProgramBuilder b;
+ b.Func("foo", {}, b.ty.void_(), {});
+ auto* var = b.Decl(b.Var("a", nullptr, b.Expr(1)));
+ b.Func("f", {}, b.ty.void_(), {var});
- Program original(std::move(b));
- ProgramBuilder cloned_b;
- CloneContext ctx(&cloned_b, &original);
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
- HoistToDeclBefore hoistToDeclBefore(ctx);
- auto* before_stmt = ctx.src->Sem().Get(var);
- auto* new_stmt = ctx.dst->CallStmt(ctx.dst->Call("foo"));
- hoistToDeclBefore.InsertBefore(before_stmt, new_stmt);
- hoistToDeclBefore.Apply();
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* before_stmt = ctx.src->Sem().Get(var);
+ auto* new_stmt = ctx.dst->CallStmt(ctx.dst->Call("foo"));
+ hoistToDeclBefore.InsertBefore(before_stmt, new_stmt);
+ hoistToDeclBefore.Apply();
- ctx.Clone();
- Program cloned(std::move(cloned_b));
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
- auto* expect = R"(
+ auto* expect = R"(
fn foo() {
}
@@ -448,36 +444,36 @@
}
)";
- EXPECT_EQ(expect, str(cloned));
+ EXPECT_EQ(expect, str(cloned));
}
TEST_F(HoistToDeclBeforeTest, InsertBefore_ForLoopInit) {
- // fn foo() {
- // }
- // fn f() {
- // for(var a = 1; true;) {
- // }
- // }
- ProgramBuilder b;
- b.Func("foo", {}, b.ty.void_(), {});
- auto* var = b.Decl(b.Var("a", nullptr, b.Expr(1)));
- auto* s = b.For(var, b.Expr(true), {}, b.Block());
- b.Func("f", {}, b.ty.void_(), {s});
+ // fn foo() {
+ // }
+ // fn f() {
+ // for(var a = 1; true;) {
+ // }
+ // }
+ ProgramBuilder b;
+ b.Func("foo", {}, b.ty.void_(), {});
+ auto* var = b.Decl(b.Var("a", nullptr, b.Expr(1)));
+ auto* s = b.For(var, b.Expr(true), {}, b.Block());
+ b.Func("f", {}, b.ty.void_(), {s});
- Program original(std::move(b));
- ProgramBuilder cloned_b;
- CloneContext ctx(&cloned_b, &original);
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
- HoistToDeclBefore hoistToDeclBefore(ctx);
- auto* before_stmt = ctx.src->Sem().Get(var);
- auto* new_stmt = ctx.dst->CallStmt(ctx.dst->Call("foo"));
- hoistToDeclBefore.InsertBefore(before_stmt, new_stmt);
- hoistToDeclBefore.Apply();
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* before_stmt = ctx.src->Sem().Get(var);
+ auto* new_stmt = ctx.dst->CallStmt(ctx.dst->Call("foo"));
+ hoistToDeclBefore.InsertBefore(before_stmt, new_stmt);
+ hoistToDeclBefore.Apply();
- ctx.Clone();
- Program cloned(std::move(cloned_b));
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
- auto* expect = R"(
+ auto* expect = R"(
fn foo() {
}
@@ -488,38 +484,38 @@
}
)";
- EXPECT_EQ(expect, str(cloned));
+ EXPECT_EQ(expect, str(cloned));
}
TEST_F(HoistToDeclBeforeTest, InsertBefore_ForLoopCont) {
- // fn foo() {
- // }
- // fn f() {
- // var a = 1;
- // for(; true; a+=1) {
- // }
- // }
- ProgramBuilder b;
- b.Func("foo", {}, b.ty.void_(), {});
- auto* var = b.Decl(b.Var("a", nullptr, b.Expr(1)));
- auto* cont = b.CompoundAssign("a", b.Expr(1), ast::BinaryOp::kAdd);
- auto* s = b.For({}, b.Expr(true), cont, b.Block());
- b.Func("f", {}, b.ty.void_(), {var, s});
+ // fn foo() {
+ // }
+ // fn f() {
+ // var a = 1;
+ // for(; true; a+=1) {
+ // }
+ // }
+ ProgramBuilder b;
+ b.Func("foo", {}, b.ty.void_(), {});
+ auto* var = b.Decl(b.Var("a", nullptr, b.Expr(1)));
+ auto* cont = b.CompoundAssign("a", b.Expr(1), ast::BinaryOp::kAdd);
+ auto* s = b.For({}, b.Expr(true), cont, b.Block());
+ b.Func("f", {}, b.ty.void_(), {var, s});
- Program original(std::move(b));
- ProgramBuilder cloned_b;
- CloneContext ctx(&cloned_b, &original);
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
- HoistToDeclBefore hoistToDeclBefore(ctx);
- auto* before_stmt = ctx.src->Sem().Get(cont->As<ast::Statement>());
- auto* new_stmt = ctx.dst->CallStmt(ctx.dst->Call("foo"));
- hoistToDeclBefore.InsertBefore(before_stmt, new_stmt);
- hoistToDeclBefore.Apply();
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* before_stmt = ctx.src->Sem().Get(cont->As<ast::Statement>());
+ auto* new_stmt = ctx.dst->CallStmt(ctx.dst->Call("foo"));
+ hoistToDeclBefore.InsertBefore(before_stmt, new_stmt);
+ hoistToDeclBefore.Apply();
- ctx.Clone();
- Program cloned(std::move(cloned_b));
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
- auto* expect = R"(
+ auto* expect = R"(
fn foo() {
}
@@ -540,41 +536,41 @@
}
)";
- EXPECT_EQ(expect, str(cloned));
+ EXPECT_EQ(expect, str(cloned));
}
TEST_F(HoistToDeclBeforeTest, InsertBefore_ElseIf) {
- // fn foo() {
- // }
- // fn f() {
- // var a : bool;
- // if (true) {
- // } else if (a) {
- // } else {
- // }
- // }
- ProgramBuilder b;
- b.Func("foo", {}, b.ty.void_(), {});
- auto* var = b.Decl(b.Var("a", b.ty.bool_()));
- auto* elseif = b.If(b.Expr("a"), b.Block(), b.Block());
- auto* s = b.If(b.Expr(true), b.Block(), //
- elseif);
- b.Func("f", {}, b.ty.void_(), {var, s});
+ // fn foo() {
+ // }
+ // fn f() {
+ // var a : bool;
+ // if (true) {
+ // } else if (a) {
+ // } else {
+ // }
+ // }
+ ProgramBuilder b;
+ b.Func("foo", {}, b.ty.void_(), {});
+ auto* var = b.Decl(b.Var("a", b.ty.bool_()));
+ auto* elseif = b.If(b.Expr("a"), b.Block(), b.Block());
+ auto* s = b.If(b.Expr(true), b.Block(), //
+ elseif);
+ b.Func("f", {}, b.ty.void_(), {var, s});
- Program original(std::move(b));
- ProgramBuilder cloned_b;
- CloneContext ctx(&cloned_b, &original);
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
- HoistToDeclBefore hoistToDeclBefore(ctx);
- auto* before_stmt = ctx.src->Sem().Get(elseif);
- auto* new_stmt = ctx.dst->CallStmt(ctx.dst->Call("foo"));
- hoistToDeclBefore.InsertBefore(before_stmt, new_stmt);
- hoistToDeclBefore.Apply();
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* before_stmt = ctx.src->Sem().Get(elseif);
+ auto* new_stmt = ctx.dst->CallStmt(ctx.dst->Call("foo"));
+ hoistToDeclBefore.InsertBefore(before_stmt, new_stmt);
+ hoistToDeclBefore.Apply();
- ctx.Clone();
- Program cloned(std::move(cloned_b));
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
- auto* expect = R"(
+ auto* expect = R"(
fn foo() {
}
@@ -590,7 +586,7 @@
}
)";
- EXPECT_EQ(expect, str(cloned));
+ EXPECT_EQ(expect, str(cloned));
}
} // namespace
diff --git a/src/tint/transform/var_for_dynamic_index.cc b/src/tint/transform/var_for_dynamic_index.cc
index ccd1215..aaebdc7 100644
--- a/src/tint/transform/var_for_dynamic_index.cc
+++ b/src/tint/transform/var_for_dynamic_index.cc
@@ -22,47 +22,43 @@
VarForDynamicIndex::~VarForDynamicIndex() = default;
-void VarForDynamicIndex::Run(CloneContext& ctx,
- const DataMap&,
- DataMap&) const {
- HoistToDeclBefore hoist_to_decl_before(ctx);
+void VarForDynamicIndex::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
+ HoistToDeclBefore hoist_to_decl_before(ctx);
- // Extracts array and matrix values that are dynamically indexed to a
- // temporary `var` local that is then indexed.
- auto dynamic_index_to_var =
- [&](const ast::IndexAccessorExpression* access_expr) {
+ // Extracts array and matrix values that are dynamically indexed to a
+ // temporary `var` local that is then indexed.
+ auto dynamic_index_to_var = [&](const ast::IndexAccessorExpression* access_expr) {
auto* index_expr = access_expr->index;
auto* object_expr = access_expr->object;
auto& sem = ctx.src->Sem();
if (sem.Get(index_expr)->ConstantValue()) {
- // Index expression resolves to a compile time value.
- // As this isn't a dynamic index, we can ignore this.
- return true;
+ // Index expression resolves to a compile time value.
+ // As this isn't a dynamic index, we can ignore this.
+ return true;
}
auto* indexed = sem.Get(object_expr);
if (!indexed->Type()->IsAnyOf<sem::Array, sem::Matrix>()) {
- // We only care about array and matrices.
- return true;
+ // We only care about array and matrices.
+ return true;
}
// TODO(bclayton): group multiple accesses in the same object.
// e.g. arr[i] + arr[i+1] // Don't create two vars for this
- return hoist_to_decl_before.Add(indexed, object_expr, false,
- "var_for_index");
- };
+ return hoist_to_decl_before.Add(indexed, object_expr, false, "var_for_index");
+ };
- for (auto* node : ctx.src->ASTNodes().Objects()) {
- if (auto* access_expr = node->As<ast::IndexAccessorExpression>()) {
- if (!dynamic_index_to_var(access_expr)) {
- return;
- }
+ for (auto* node : ctx.src->ASTNodes().Objects()) {
+ if (auto* access_expr = node->As<ast::IndexAccessorExpression>()) {
+ if (!dynamic_index_to_var(access_expr)) {
+ return;
+ }
+ }
}
- }
- hoist_to_decl_before.Apply();
- ctx.Clone();
+ hoist_to_decl_before.Apply();
+ ctx.Clone();
}
} // namespace tint::transform
diff --git a/src/tint/transform/var_for_dynamic_index.h b/src/tint/transform/var_for_dynamic_index.h
index e7e2815..39ef2f2 100644
--- a/src/tint/transform/var_for_dynamic_index.h
+++ b/src/tint/transform/var_for_dynamic_index.h
@@ -24,23 +24,21 @@
/// transform is used by the SPIR-V writer as there is no SPIR-V instruction
/// that can dynamically index a non-pointer composite.
class VarForDynamicIndex : public Transform {
- public:
- /// Constructor
- VarForDynamicIndex();
+ public:
+ /// Constructor
+ VarForDynamicIndex();
- /// Destructor
- ~VarForDynamicIndex() override;
+ /// Destructor
+ ~VarForDynamicIndex() override;
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx,
- const DataMap& inputs,
- DataMap& outputs) const override;
+ protected:
+ /// Runs the transform using the CloneContext built for transforming a
+ /// program. Run() is responsible for calling Clone() on the CloneContext.
+ /// @param ctx the CloneContext primed with the input program and
+ /// ProgramBuilder
+ /// @param inputs optional extra transform-specific input data
+ /// @param outputs optional extra transform-specific output data
+ void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/var_for_dynamic_index_test.cc b/src/tint/transform/var_for_dynamic_index_test.cc
index a222ad7..ca767c9 100644
--- a/src/tint/transform/var_for_dynamic_index_test.cc
+++ b/src/tint/transform/var_for_dynamic_index_test.cc
@@ -23,16 +23,16 @@
using VarForDynamicIndexTest = TransformTest;
TEST_F(VarForDynamicIndexTest, EmptyModule) {
- auto* src = "";
- auto* expect = "";
+ auto* src = "";
+ auto* expect = "";
- auto got = Run<ForLoopToLoop, VarForDynamicIndex>(src);
+ auto got = Run<ForLoopToLoop, VarForDynamicIndex>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(VarForDynamicIndexTest, ArrayIndexDynamic) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var i : i32;
let p = array<i32, 4>(1, 2, 3, 4);
@@ -40,7 +40,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var i : i32;
let p = array<i32, 4>(1, 2, 3, 4);
@@ -49,14 +49,14 @@
}
)";
- DataMap data;
- auto got = Run<VarForDynamicIndex>(src, data);
+ DataMap data;
+ auto got = Run<VarForDynamicIndex>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(VarForDynamicIndexTest, MatrixIndexDynamic) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var i : i32;
let p = mat2x2(1.0, 2.0, 3.0, 4.0);
@@ -64,7 +64,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var i : i32;
let p = mat2x2(1.0, 2.0, 3.0, 4.0);
@@ -73,14 +73,14 @@
}
)";
- DataMap data;
- auto got = Run<VarForDynamicIndex>(src, data);
+ DataMap data;
+ auto got = Run<VarForDynamicIndex>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(VarForDynamicIndexTest, ArrayIndexDynamicChain) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var i : i32;
var j : i32;
@@ -89,12 +89,12 @@
}
)";
- // TODO(bclayton): Optimize this case:
- // This output is not as efficient as it could be.
- // We only actually need to hoist the inner-most array to a `var`
- // (`var_for_index`), as later indexing operations will be working with
- // references, not values.
- auto* expect = R"(
+ // TODO(bclayton): Optimize this case:
+ // This output is not as efficient as it could be.
+ // We only actually need to hoist the inner-most array to a `var`
+ // (`var_for_index`), as later indexing operations will be working with
+ // references, not values.
+ auto* expect = R"(
fn f() {
var i : i32;
var j : i32;
@@ -105,14 +105,14 @@
}
)";
- DataMap data;
- auto got = Run<VarForDynamicIndex>(src, data);
+ DataMap data;
+ auto got = Run<VarForDynamicIndex>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(VarForDynamicIndexTest, ArrayIndexInForLoopInit) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var i : i32;
let p = array<array<i32, 2>, 2>(array<i32, 2>(1, 2), array<i32, 2>(3, 4));
@@ -122,7 +122,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var i : i32;
let p = array<array<i32, 2>, 2>(array<i32, 2>(1, 2), array<i32, 2>(3, 4));
@@ -133,14 +133,14 @@
}
)";
- DataMap data;
- auto got = Run<VarForDynamicIndex>(src, data);
+ DataMap data;
+ auto got = Run<VarForDynamicIndex>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(VarForDynamicIndexTest, MatrixIndexInForLoopInit) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var i : i32;
let p = mat2x2(1.0, 2.0, 3.0, 4.0);
@@ -150,7 +150,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var i : i32;
let p = mat2x2(1.0, 2.0, 3.0, 4.0);
@@ -161,14 +161,14 @@
}
)";
- DataMap data;
- auto got = Run<VarForDynamicIndex>(src, data);
+ DataMap data;
+ auto got = Run<VarForDynamicIndex>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(VarForDynamicIndexTest, ArrayIndexInForLoopCond) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var i : i32;
let p = array<i32, 2>(1, 2);
@@ -178,7 +178,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var i : i32;
let p = array<i32, 2>(1, 2);
@@ -194,14 +194,14 @@
}
)";
- DataMap data;
- auto got = Run<VarForDynamicIndex>(src, data);
+ DataMap data;
+ auto got = Run<VarForDynamicIndex>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(VarForDynamicIndexTest, MatrixIndexInForLoopCond) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var i : i32;
let p = mat2x2(1.0, 2.0, 3.0, 4.0);
@@ -211,7 +211,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var i : i32;
let p = mat2x2(1.0, 2.0, 3.0, 4.0);
@@ -227,14 +227,14 @@
}
)";
- DataMap data;
- auto got = Run<VarForDynamicIndex>(src, data);
+ DataMap data;
+ auto got = Run<VarForDynamicIndex>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(VarForDynamicIndexTest, MatrixIndexInForLoopCondWithNestedIndex) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var i : i32;
let p = mat2x2(1.0, 2.0, 3.0, 4.0);
@@ -247,7 +247,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var i : i32;
let p = mat2x2(1.0, 2.0, 3.0, 4.0);
@@ -267,14 +267,14 @@
}
)";
- DataMap data;
- auto got = Run<VarForDynamicIndex>(src, data);
+ DataMap data;
+ auto got = Run<VarForDynamicIndex>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(VarForDynamicIndexTest, ArrayIndexInElseIf) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var i : i32;
let p = array<i32, 2>(1, 2);
@@ -286,7 +286,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var i : i32;
let p = array<i32, 2>(1, 2);
@@ -301,14 +301,14 @@
}
)";
- DataMap data;
- auto got = Run<VarForDynamicIndex>(src, data);
+ DataMap data;
+ auto got = Run<VarForDynamicIndex>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(VarForDynamicIndexTest, ArrayIndexInElseIfChain) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var i : i32;
let p = array<i32, 2>(1, 2);
@@ -328,7 +328,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var i : i32;
let p = array<i32, 2>(1, 2);
@@ -354,14 +354,14 @@
}
)";
- DataMap data;
- auto got = Run<VarForDynamicIndex>(src, data);
+ DataMap data;
+ auto got = Run<VarForDynamicIndex>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(VarForDynamicIndexTest, MatrixIndexInElseIf) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var i : i32;
let p = mat2x2(1.0, 2.0, 3.0, 4.0);
@@ -373,7 +373,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var i : i32;
let p = mat2x2(1.0, 2.0, 3.0, 4.0);
@@ -388,14 +388,14 @@
}
)";
- DataMap data;
- auto got = Run<VarForDynamicIndex>(src, data);
+ DataMap data;
+ auto got = Run<VarForDynamicIndex>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(VarForDynamicIndexTest, MatrixIndexInElseIfChain) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var i : i32;
let p = mat2x2(1.0, 2.0, 3.0, 4.0);
@@ -415,7 +415,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
fn f() {
var i : i32;
let p = mat2x2(1.0, 2.0, 3.0, 4.0);
@@ -441,46 +441,46 @@
}
)";
- DataMap data;
- auto got = Run<VarForDynamicIndex>(src, data);
+ DataMap data;
+ auto got = Run<VarForDynamicIndex>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(VarForDynamicIndexTest, ArrayIndexLiteral) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
let p = array<i32, 4>(1, 2, 3, 4);
let x = p[1];
}
)";
- auto* expect = src;
+ auto* expect = src;
- DataMap data;
- auto got = Run<VarForDynamicIndex>(src, data);
+ DataMap data;
+ auto got = Run<VarForDynamicIndex>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(VarForDynamicIndexTest, MatrixIndexLiteral) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
let p = mat2x2(1.0, 2.0, 3.0, 4.0);
let x = p[1];
}
)";
- auto* expect = src;
+ auto* expect = src;
- DataMap data;
- auto got = Run<VarForDynamicIndex>(src, data);
+ DataMap data;
+ auto got = Run<VarForDynamicIndex>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(VarForDynamicIndexTest, ArrayIndexConstantLet) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
let p = array<i32, 4>(1, 2, 3, 4);
let c = 1;
@@ -488,16 +488,16 @@
}
)";
- auto* expect = src;
+ auto* expect = src;
- DataMap data;
- auto got = Run<VarForDynamicIndex>(src, data);
+ DataMap data;
+ auto got = Run<VarForDynamicIndex>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(VarForDynamicIndexTest, MatrixIndexConstantLet) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
let p = mat2x2(1.0, 2.0, 3.0, 4.0);
let c = 1;
@@ -505,16 +505,16 @@
}
)";
- auto* expect = src;
+ auto* expect = src;
- DataMap data;
- auto got = Run<VarForDynamicIndex>(src, data);
+ DataMap data;
+ auto got = Run<VarForDynamicIndex>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(VarForDynamicIndexTest, ArrayIndexLiteralChain) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
let a = array<i32, 2>(1, 2);
let b = array<i32, 2>(3, 4);
@@ -523,28 +523,28 @@
}
)";
- auto* expect = src;
+ auto* expect = src;
- DataMap data;
- auto got = Run<VarForDynamicIndex>(src, data);
+ DataMap data;
+ auto got = Run<VarForDynamicIndex>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(VarForDynamicIndexTest, MatrixIndexLiteralChain) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
let p = mat2x2(1.0, 2.0, 3.0, 4.0);
let x = p[0][1];
}
)";
- auto* expect = src;
+ auto* expect = src;
- DataMap data;
- auto got = Run<VarForDynamicIndex>(src, data);
+ DataMap data;
+ auto got = Run<VarForDynamicIndex>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
} // namespace
diff --git a/src/tint/transform/vectorize_scalar_matrix_constructors.cc b/src/tint/transform/vectorize_scalar_matrix_constructors.cc
index 8af0272..80e8b1e 100644
--- a/src/tint/transform/vectorize_scalar_matrix_constructors.cc
+++ b/src/tint/transform/vectorize_scalar_matrix_constructors.cc
@@ -25,71 +25,63 @@
namespace tint::transform {
-VectorizeScalarMatrixConstructors::VectorizeScalarMatrixConstructors() =
- default;
+VectorizeScalarMatrixConstructors::VectorizeScalarMatrixConstructors() = default;
-VectorizeScalarMatrixConstructors::~VectorizeScalarMatrixConstructors() =
- default;
+VectorizeScalarMatrixConstructors::~VectorizeScalarMatrixConstructors() = default;
-bool VectorizeScalarMatrixConstructors::ShouldRun(const Program* program,
- const DataMap&) const {
- for (auto* node : program->ASTNodes().Objects()) {
- if (auto* call = program->Sem().Get<sem::Call>(node)) {
- if (call->Target()->Is<sem::TypeConstructor>() &&
- call->Type()->Is<sem::Matrix>()) {
- auto& args = call->Arguments();
- if (args.size() > 0 && args[0]->Type()->is_scalar()) {
- return true;
+bool VectorizeScalarMatrixConstructors::ShouldRun(const Program* program, const DataMap&) const {
+ for (auto* node : program->ASTNodes().Objects()) {
+ if (auto* call = program->Sem().Get<sem::Call>(node)) {
+ if (call->Target()->Is<sem::TypeConstructor>() && call->Type()->Is<sem::Matrix>()) {
+ auto& args = call->Arguments();
+ if (args.size() > 0 && args[0]->Type()->is_scalar()) {
+ return true;
+ }
+ }
}
- }
}
- }
- return false;
+ return false;
}
-void VectorizeScalarMatrixConstructors::Run(CloneContext& ctx,
- const DataMap&,
- DataMap&) const {
- ctx.ReplaceAll(
- [&](const ast::CallExpression* expr) -> const ast::CallExpression* {
+void VectorizeScalarMatrixConstructors::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
+ ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::CallExpression* {
auto* call = ctx.src->Sem().Get(expr);
auto* ty_ctor = call->Target()->As<sem::TypeConstructor>();
if (!ty_ctor) {
- return nullptr;
+ return nullptr;
}
// Check if this is a matrix constructor with scalar arguments.
auto* mat_type = call->Type()->As<sem::Matrix>();
if (!mat_type) {
- return nullptr;
+ return nullptr;
}
auto& args = call->Arguments();
if (args.size() == 0) {
- return nullptr;
+ return nullptr;
}
if (!args[0]->Type()->is_scalar()) {
- return nullptr;
+ return nullptr;
}
// Build a list of vector expressions for each column.
ast::ExpressionList columns;
for (uint32_t c = 0; c < mat_type->columns(); c++) {
- // Build a list of scalar expressions for each value in the column.
- ast::ExpressionList row_values;
- for (uint32_t r = 0; r < mat_type->rows(); r++) {
- row_values.push_back(
- ctx.Clone(args[c * mat_type->rows() + r]->Declaration()));
- }
+ // Build a list of scalar expressions for each value in the column.
+ ast::ExpressionList row_values;
+ for (uint32_t r = 0; r < mat_type->rows(); r++) {
+ row_values.push_back(ctx.Clone(args[c * mat_type->rows() + r]->Declaration()));
+ }
- // Construct the column vector.
- auto* col = ctx.dst->vec(CreateASTTypeFor(ctx, mat_type->type()),
- mat_type->rows(), row_values);
- columns.push_back(col);
+ // Construct the column vector.
+ auto* col =
+ ctx.dst->vec(CreateASTTypeFor(ctx, mat_type->type()), mat_type->rows(), row_values);
+ columns.push_back(col);
}
return ctx.dst->Construct(CreateASTTypeFor(ctx, mat_type), columns);
- });
+ });
- ctx.Clone();
+ ctx.Clone();
}
} // namespace tint::transform
diff --git a/src/tint/transform/vectorize_scalar_matrix_constructors.h b/src/tint/transform/vectorize_scalar_matrix_constructors.h
index f1abb4f..83c4ce1 100644
--- a/src/tint/transform/vectorize_scalar_matrix_constructors.h
+++ b/src/tint/transform/vectorize_scalar_matrix_constructors.h
@@ -22,29 +22,26 @@
/// A transform that converts scalar matrix constructors to the vector form.
class VectorizeScalarMatrixConstructors
: public Castable<VectorizeScalarMatrixConstructors, Transform> {
- public:
- /// Constructor
- VectorizeScalarMatrixConstructors();
+ public:
+ /// Constructor
+ VectorizeScalarMatrixConstructors();
- /// Destructor
- ~VectorizeScalarMatrixConstructors() override;
+ /// Destructor
+ ~VectorizeScalarMatrixConstructors() override;
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program,
- const DataMap& data = {}) const override;
+ /// @param program the program to inspect
+ /// @param data optional extra transform-specific input data
+ /// @returns true if this transform should be run for the given program
+ bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx,
- const DataMap& inputs,
- DataMap& outputs) const override;
+ protected:
+ /// Runs the transform using the CloneContext built for transforming a
+ /// program. Run() is responsible for calling Clone() on the CloneContext.
+ /// @param ctx the CloneContext primed with the input program and
+ /// ProgramBuilder
+ /// @param inputs optional extra transform-specific input data
+ /// @param outputs optional extra transform-specific output data
+ void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/vectorize_scalar_matrix_constructors_test.cc b/src/tint/transform/vectorize_scalar_matrix_constructors_test.cc
index edd68e2..242151a 100644
--- a/src/tint/transform/vectorize_scalar_matrix_constructors_test.cc
+++ b/src/tint/transform/vectorize_scalar_matrix_constructors_test.cc
@@ -23,87 +23,84 @@
namespace tint::transform {
namespace {
-using VectorizeScalarMatrixConstructorsTest =
- TransformTestWithParam<std::pair<uint32_t, uint32_t>>;
+using VectorizeScalarMatrixConstructorsTest = TransformTestWithParam<std::pair<uint32_t, uint32_t>>;
TEST_F(VectorizeScalarMatrixConstructorsTest, ShouldRunEmptyModule) {
- auto* src = R"()";
+ auto* src = R"()";
- EXPECT_FALSE(ShouldRun<VectorizeScalarMatrixConstructors>(src));
+ EXPECT_FALSE(ShouldRun<VectorizeScalarMatrixConstructors>(src));
}
TEST_P(VectorizeScalarMatrixConstructorsTest, Basic) {
- uint32_t cols = GetParam().first;
- uint32_t rows = GetParam().second;
- std::string mat_type =
- "mat" + std::to_string(cols) + "x" + std::to_string(rows) + "<f32>";
- std::string vec_type = "vec" + std::to_string(rows) + "<f32>";
- std::string scalar_values;
- std::string vector_values;
- for (uint32_t c = 0; c < cols; c++) {
- if (c > 0) {
- vector_values += ", ";
- scalar_values += ", ";
+ uint32_t cols = GetParam().first;
+ uint32_t rows = GetParam().second;
+ std::string mat_type = "mat" + std::to_string(cols) + "x" + std::to_string(rows) + "<f32>";
+ std::string vec_type = "vec" + std::to_string(rows) + "<f32>";
+ std::string scalar_values;
+ std::string vector_values;
+ for (uint32_t c = 0; c < cols; c++) {
+ if (c > 0) {
+ vector_values += ", ";
+ scalar_values += ", ";
+ }
+ vector_values += vec_type + "(";
+ for (uint32_t r = 0; r < rows; r++) {
+ if (r > 0) {
+ scalar_values += ", ";
+ vector_values += ", ";
+ }
+ auto value = std::to_string(c * rows + r) + ".0";
+ scalar_values += value;
+ vector_values += value;
+ }
+ vector_values += ")";
}
- vector_values += vec_type + "(";
- for (uint32_t r = 0; r < rows; r++) {
- if (r > 0) {
- scalar_values += ", ";
- vector_values += ", ";
- }
- auto value = std::to_string(c * rows + r) + ".0";
- scalar_values += value;
- vector_values += value;
- }
- vector_values += ")";
- }
- std::string tmpl = R"(
+ std::string tmpl = R"(
@stage(fragment)
fn main() {
let m = ${matrix}(${values});
}
)";
- tmpl = utils::ReplaceAll(tmpl, "${matrix}", mat_type);
- auto src = utils::ReplaceAll(tmpl, "${values}", scalar_values);
- auto expect = utils::ReplaceAll(tmpl, "${values}", vector_values);
+ tmpl = utils::ReplaceAll(tmpl, "${matrix}", mat_type);
+ auto src = utils::ReplaceAll(tmpl, "${values}", scalar_values);
+ auto expect = utils::ReplaceAll(tmpl, "${values}", vector_values);
- EXPECT_TRUE(ShouldRun<VectorizeScalarMatrixConstructors>(src));
+ EXPECT_TRUE(ShouldRun<VectorizeScalarMatrixConstructors>(src));
- auto got = Run<VectorizeScalarMatrixConstructors>(src);
+ auto got = Run<VectorizeScalarMatrixConstructors>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_P(VectorizeScalarMatrixConstructorsTest, NonScalarConstructors) {
- uint32_t cols = GetParam().first;
- uint32_t rows = GetParam().second;
- std::string mat_type =
- "mat" + std::to_string(cols) + "x" + std::to_string(rows) + "<f32>";
- std::string vec_type = "vec" + std::to_string(rows) + "<f32>";
- std::string columns;
- for (uint32_t c = 0; c < cols; c++) {
- if (c > 0) {
- columns += ", ";
+ uint32_t cols = GetParam().first;
+ uint32_t rows = GetParam().second;
+ std::string mat_type = "mat" + std::to_string(cols) + "x" + std::to_string(rows) + "<f32>";
+ std::string vec_type = "vec" + std::to_string(rows) + "<f32>";
+ std::string columns;
+ for (uint32_t c = 0; c < cols; c++) {
+ if (c > 0) {
+ columns += ", ";
+ }
+ columns += vec_type + "()";
}
- columns += vec_type + "()";
- }
- std::string tmpl = R"(
+ std::string tmpl = R"(
@stage(fragment)
fn main() {
let m = ${matrix}(${columns});
}
)";
- tmpl = utils::ReplaceAll(tmpl, "${matrix}", mat_type);
- auto src = utils::ReplaceAll(tmpl, "${columns}", columns);
- auto expect = src;
+ tmpl = utils::ReplaceAll(tmpl, "${matrix}", mat_type);
+ auto src = utils::ReplaceAll(tmpl, "${columns}", columns);
+ auto expect = src;
- EXPECT_FALSE(ShouldRun<VectorizeScalarMatrixConstructors>(src));
+ EXPECT_FALSE(ShouldRun<VectorizeScalarMatrixConstructors>(src));
- auto got = Run<VectorizeScalarMatrixConstructors>(src);
+ auto got = Run<VectorizeScalarMatrixConstructors>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
INSTANTIATE_TEST_SUITE_P(VectorizeScalarMatrixConstructorsTest,
diff --git a/src/tint/transform/vertex_pulling.cc b/src/tint/transform/vertex_pulling.cc
index a30ea2a..d383860 100644
--- a/src/tint/transform/vertex_pulling.cc
+++ b/src/tint/transform/vertex_pulling.cc
@@ -35,10 +35,10 @@
/// The base type of a component.
/// The format type is either this type or a vector of this type.
enum class BaseType {
- kInvalid,
- kU32,
- kI32,
- kF32,
+ kInvalid,
+ kU32,
+ kI32,
+ kF32,
};
/// Writes the BaseType to the std::ostream.
@@ -46,17 +46,17 @@
/// @param format the BaseType to write
/// @returns out so calls can be chained
std::ostream& operator<<(std::ostream& out, BaseType format) {
- switch (format) {
- case BaseType::kInvalid:
- return out << "invalid";
- case BaseType::kU32:
- return out << "u32";
- case BaseType::kI32:
- return out << "i32";
- case BaseType::kF32:
- return out << "f32";
- }
- return out << "<unknown>";
+ switch (format) {
+ case BaseType::kInvalid:
+ return out << "invalid";
+ case BaseType::kU32:
+ return out << "u32";
+ case BaseType::kI32:
+ return out << "i32";
+ case BaseType::kF32:
+ return out << "f32";
+ }
+ return out << "<unknown>";
}
/// Writes the VertexFormat to the std::ostream.
@@ -64,837 +64,804 @@
/// @param format the VertexFormat to write
/// @returns out so calls can be chained
std::ostream& operator<<(std::ostream& out, VertexFormat format) {
- switch (format) {
- case VertexFormat::kUint8x2:
- return out << "uint8x2";
- case VertexFormat::kUint8x4:
- return out << "uint8x4";
- case VertexFormat::kSint8x2:
- return out << "sint8x2";
- case VertexFormat::kSint8x4:
- return out << "sint8x4";
- case VertexFormat::kUnorm8x2:
- return out << "unorm8x2";
- case VertexFormat::kUnorm8x4:
- return out << "unorm8x4";
- case VertexFormat::kSnorm8x2:
- return out << "snorm8x2";
- case VertexFormat::kSnorm8x4:
- return out << "snorm8x4";
- case VertexFormat::kUint16x2:
- return out << "uint16x2";
- case VertexFormat::kUint16x4:
- return out << "uint16x4";
- case VertexFormat::kSint16x2:
- return out << "sint16x2";
- case VertexFormat::kSint16x4:
- return out << "sint16x4";
- case VertexFormat::kUnorm16x2:
- return out << "unorm16x2";
- case VertexFormat::kUnorm16x4:
- return out << "unorm16x4";
- case VertexFormat::kSnorm16x2:
- return out << "snorm16x2";
- case VertexFormat::kSnorm16x4:
- return out << "snorm16x4";
- case VertexFormat::kFloat16x2:
- return out << "float16x2";
- case VertexFormat::kFloat16x4:
- return out << "float16x4";
- case VertexFormat::kFloat32:
- return out << "float32";
- case VertexFormat::kFloat32x2:
- return out << "float32x2";
- case VertexFormat::kFloat32x3:
- return out << "float32x3";
- case VertexFormat::kFloat32x4:
- return out << "float32x4";
- case VertexFormat::kUint32:
- return out << "uint32";
- case VertexFormat::kUint32x2:
- return out << "uint32x2";
- case VertexFormat::kUint32x3:
- return out << "uint32x3";
- case VertexFormat::kUint32x4:
- return out << "uint32x4";
- case VertexFormat::kSint32:
- return out << "sint32";
- case VertexFormat::kSint32x2:
- return out << "sint32x2";
- case VertexFormat::kSint32x3:
- return out << "sint32x3";
- case VertexFormat::kSint32x4:
- return out << "sint32x4";
- }
- return out << "<unknown>";
+ switch (format) {
+ case VertexFormat::kUint8x2:
+ return out << "uint8x2";
+ case VertexFormat::kUint8x4:
+ return out << "uint8x4";
+ case VertexFormat::kSint8x2:
+ return out << "sint8x2";
+ case VertexFormat::kSint8x4:
+ return out << "sint8x4";
+ case VertexFormat::kUnorm8x2:
+ return out << "unorm8x2";
+ case VertexFormat::kUnorm8x4:
+ return out << "unorm8x4";
+ case VertexFormat::kSnorm8x2:
+ return out << "snorm8x2";
+ case VertexFormat::kSnorm8x4:
+ return out << "snorm8x4";
+ case VertexFormat::kUint16x2:
+ return out << "uint16x2";
+ case VertexFormat::kUint16x4:
+ return out << "uint16x4";
+ case VertexFormat::kSint16x2:
+ return out << "sint16x2";
+ case VertexFormat::kSint16x4:
+ return out << "sint16x4";
+ case VertexFormat::kUnorm16x2:
+ return out << "unorm16x2";
+ case VertexFormat::kUnorm16x4:
+ return out << "unorm16x4";
+ case VertexFormat::kSnorm16x2:
+ return out << "snorm16x2";
+ case VertexFormat::kSnorm16x4:
+ return out << "snorm16x4";
+ case VertexFormat::kFloat16x2:
+ return out << "float16x2";
+ case VertexFormat::kFloat16x4:
+ return out << "float16x4";
+ case VertexFormat::kFloat32:
+ return out << "float32";
+ case VertexFormat::kFloat32x2:
+ return out << "float32x2";
+ case VertexFormat::kFloat32x3:
+ return out << "float32x3";
+ case VertexFormat::kFloat32x4:
+ return out << "float32x4";
+ case VertexFormat::kUint32:
+ return out << "uint32";
+ case VertexFormat::kUint32x2:
+ return out << "uint32x2";
+ case VertexFormat::kUint32x3:
+ return out << "uint32x3";
+ case VertexFormat::kUint32x4:
+ return out << "uint32x4";
+ case VertexFormat::kSint32:
+ return out << "sint32";
+ case VertexFormat::kSint32x2:
+ return out << "sint32x2";
+ case VertexFormat::kSint32x3:
+ return out << "sint32x3";
+ case VertexFormat::kSint32x4:
+ return out << "sint32x4";
+ }
+ return out << "<unknown>";
}
/// A vertex attribute data format.
struct DataType {
- BaseType base_type;
- uint32_t width; // 1 for scalar, 2+ for a vector
+ BaseType base_type;
+ uint32_t width; // 1 for scalar, 2+ for a vector
};
DataType DataTypeOf(const sem::Type* ty) {
- if (ty->Is<sem::I32>()) {
- return {BaseType::kI32, 1};
- }
- if (ty->Is<sem::U32>()) {
- return {BaseType::kU32, 1};
- }
- if (ty->Is<sem::F32>()) {
- return {BaseType::kF32, 1};
- }
- if (auto* vec = ty->As<sem::Vector>()) {
- return {DataTypeOf(vec->type()).base_type, vec->Width()};
- }
- return {BaseType::kInvalid, 0};
+ if (ty->Is<sem::I32>()) {
+ return {BaseType::kI32, 1};
+ }
+ if (ty->Is<sem::U32>()) {
+ return {BaseType::kU32, 1};
+ }
+ if (ty->Is<sem::F32>()) {
+ return {BaseType::kF32, 1};
+ }
+ if (auto* vec = ty->As<sem::Vector>()) {
+ return {DataTypeOf(vec->type()).base_type, vec->Width()};
+ }
+ return {BaseType::kInvalid, 0};
}
DataType DataTypeOf(VertexFormat format) {
- switch (format) {
- case VertexFormat::kUint32:
- return {BaseType::kU32, 1};
- case VertexFormat::kUint8x2:
- case VertexFormat::kUint16x2:
- case VertexFormat::kUint32x2:
- return {BaseType::kU32, 2};
- case VertexFormat::kUint32x3:
- return {BaseType::kU32, 3};
- case VertexFormat::kUint8x4:
- case VertexFormat::kUint16x4:
- case VertexFormat::kUint32x4:
- return {BaseType::kU32, 4};
- case VertexFormat::kSint32:
- return {BaseType::kI32, 1};
- case VertexFormat::kSint8x2:
- case VertexFormat::kSint16x2:
- case VertexFormat::kSint32x2:
- return {BaseType::kI32, 2};
- case VertexFormat::kSint32x3:
- return {BaseType::kI32, 3};
- case VertexFormat::kSint8x4:
- case VertexFormat::kSint16x4:
- case VertexFormat::kSint32x4:
- return {BaseType::kI32, 4};
- case VertexFormat::kFloat32:
- return {BaseType::kF32, 1};
- case VertexFormat::kUnorm8x2:
- case VertexFormat::kSnorm8x2:
- case VertexFormat::kUnorm16x2:
- case VertexFormat::kSnorm16x2:
- case VertexFormat::kFloat16x2:
- case VertexFormat::kFloat32x2:
- return {BaseType::kF32, 2};
- case VertexFormat::kFloat32x3:
- return {BaseType::kF32, 3};
- case VertexFormat::kUnorm8x4:
- case VertexFormat::kSnorm8x4:
- case VertexFormat::kUnorm16x4:
- case VertexFormat::kSnorm16x4:
- case VertexFormat::kFloat16x4:
- case VertexFormat::kFloat32x4:
- return {BaseType::kF32, 4};
- }
- return {BaseType::kInvalid, 0};
+ switch (format) {
+ case VertexFormat::kUint32:
+ return {BaseType::kU32, 1};
+ case VertexFormat::kUint8x2:
+ case VertexFormat::kUint16x2:
+ case VertexFormat::kUint32x2:
+ return {BaseType::kU32, 2};
+ case VertexFormat::kUint32x3:
+ return {BaseType::kU32, 3};
+ case VertexFormat::kUint8x4:
+ case VertexFormat::kUint16x4:
+ case VertexFormat::kUint32x4:
+ return {BaseType::kU32, 4};
+ case VertexFormat::kSint32:
+ return {BaseType::kI32, 1};
+ case VertexFormat::kSint8x2:
+ case VertexFormat::kSint16x2:
+ case VertexFormat::kSint32x2:
+ return {BaseType::kI32, 2};
+ case VertexFormat::kSint32x3:
+ return {BaseType::kI32, 3};
+ case VertexFormat::kSint8x4:
+ case VertexFormat::kSint16x4:
+ case VertexFormat::kSint32x4:
+ return {BaseType::kI32, 4};
+ case VertexFormat::kFloat32:
+ return {BaseType::kF32, 1};
+ case VertexFormat::kUnorm8x2:
+ case VertexFormat::kSnorm8x2:
+ case VertexFormat::kUnorm16x2:
+ case VertexFormat::kSnorm16x2:
+ case VertexFormat::kFloat16x2:
+ case VertexFormat::kFloat32x2:
+ return {BaseType::kF32, 2};
+ case VertexFormat::kFloat32x3:
+ return {BaseType::kF32, 3};
+ case VertexFormat::kUnorm8x4:
+ case VertexFormat::kSnorm8x4:
+ case VertexFormat::kUnorm16x4:
+ case VertexFormat::kSnorm16x4:
+ case VertexFormat::kFloat16x4:
+ case VertexFormat::kFloat32x4:
+ return {BaseType::kF32, 4};
+ }
+ return {BaseType::kInvalid, 0};
}
struct State {
- State(CloneContext& context, const VertexPulling::Config& c)
- : ctx(context), cfg(c) {}
- State(const State&) = default;
- ~State() = default;
+ State(CloneContext& context, const VertexPulling::Config& c) : ctx(context), cfg(c) {}
+ State(const State&) = default;
+ ~State() = default;
- /// LocationReplacement describes an ast::Variable replacement for a
- /// location input.
- struct LocationReplacement {
- /// The variable to replace in the source Program
- ast::Variable* from;
- /// The replacement to use in the target ProgramBuilder
- ast::Variable* to;
- };
+ /// LocationReplacement describes an ast::Variable replacement for a
+ /// location input.
+ struct LocationReplacement {
+ /// The variable to replace in the source Program
+ ast::Variable* from;
+ /// The replacement to use in the target ProgramBuilder
+ ast::Variable* to;
+ };
- struct LocationInfo {
- std::function<const ast::Expression*()> expr;
- const sem::Type* type;
- };
+ struct LocationInfo {
+ std::function<const ast::Expression*()> expr;
+ const sem::Type* type;
+ };
- CloneContext& ctx;
- VertexPulling::Config const cfg;
- std::unordered_map<uint32_t, LocationInfo> location_info;
- std::function<const ast::Expression*()> vertex_index_expr = nullptr;
- std::function<const ast::Expression*()> instance_index_expr = nullptr;
- Symbol pulling_position_name;
- Symbol struct_buffer_name;
- std::unordered_map<uint32_t, Symbol> vertex_buffer_names;
- ast::VariableList new_function_parameters;
+ CloneContext& ctx;
+ VertexPulling::Config const cfg;
+ std::unordered_map<uint32_t, LocationInfo> location_info;
+ std::function<const ast::Expression*()> vertex_index_expr = nullptr;
+ std::function<const ast::Expression*()> instance_index_expr = nullptr;
+ Symbol pulling_position_name;
+ Symbol struct_buffer_name;
+ std::unordered_map<uint32_t, Symbol> vertex_buffer_names;
+ ast::VariableList new_function_parameters;
- /// Generate the vertex buffer binding name
- /// @param index index to append to buffer name
- Symbol GetVertexBufferName(uint32_t index) {
- return utils::GetOrCreate(vertex_buffer_names, index, [&] {
- static const char kVertexBufferNamePrefix[] =
- "tint_pulling_vertex_buffer_";
- return ctx.dst->Symbols().New(kVertexBufferNamePrefix +
- std::to_string(index));
- });
- }
-
- /// Lazily generates the structure buffer symbol
- Symbol GetStructBufferName() {
- if (!struct_buffer_name.IsValid()) {
- static const char kStructBufferName[] = "tint_vertex_data";
- struct_buffer_name = ctx.dst->Symbols().New(kStructBufferName);
- }
- return struct_buffer_name;
- }
-
- /// Adds storage buffer decorated variables for the vertex buffers
- void AddVertexStorageBuffers() {
- // Creating the struct type
- static const char kStructName[] = "TintVertexData";
- auto* struct_type = ctx.dst->Structure(
- ctx.dst->Symbols().New(kStructName),
- {
- ctx.dst->Member(GetStructBufferName(),
- ctx.dst->ty.array<ProgramBuilder::u32>()),
+ /// Generate the vertex buffer binding name
+ /// @param index index to append to buffer name
+ Symbol GetVertexBufferName(uint32_t index) {
+ return utils::GetOrCreate(vertex_buffer_names, index, [&] {
+ static const char kVertexBufferNamePrefix[] = "tint_pulling_vertex_buffer_";
+ return ctx.dst->Symbols().New(kVertexBufferNamePrefix + std::to_string(index));
});
- for (uint32_t i = 0; i < cfg.vertex_state.size(); ++i) {
- // The decorated variable with struct type
- ctx.dst->Global(
- GetVertexBufferName(i), ctx.dst->ty.Of(struct_type),
- ast::StorageClass::kStorage, ast::Access::kRead,
- ast::AttributeList{
- ctx.dst->create<ast::BindingAttribute>(i),
- ctx.dst->create<ast::GroupAttribute>(cfg.pulling_group),
- });
}
- }
- /// Creates and returns the assignment to the variables from the buffers
- ast::BlockStatement* CreateVertexPullingPreamble() {
- // Assign by looking at the vertex descriptor to find attributes with
- // matching location.
-
- ast::StatementList stmts;
-
- for (uint32_t buffer_idx = 0; buffer_idx < cfg.vertex_state.size();
- ++buffer_idx) {
- const VertexBufferLayoutDescriptor& buffer_layout =
- cfg.vertex_state[buffer_idx];
-
- if ((buffer_layout.array_stride & 3) != 0) {
- ctx.dst->Diagnostics().add_error(
- diag::System::Transform,
- "WebGPU requires that vertex stride must be a multiple of 4 bytes, "
- "but VertexPulling array stride for buffer " +
- std::to_string(buffer_idx) + " was " +
- std::to_string(buffer_layout.array_stride) + " bytes");
- return nullptr;
- }
-
- auto* index_expr = buffer_layout.step_mode == VertexStepMode::kVertex
- ? vertex_index_expr()
- : instance_index_expr();
-
- // buffer_array_base is the base array offset for all the vertex
- // attributes. These are units of uint (4 bytes).
- auto buffer_array_base = ctx.dst->Symbols().New(
- "buffer_array_base_" + std::to_string(buffer_idx));
-
- auto* attribute_offset = index_expr;
- if (buffer_layout.array_stride != 4) {
- attribute_offset =
- ctx.dst->Mul(index_expr, buffer_layout.array_stride / 4u);
- }
-
- // let pulling_offset_n = <attribute_offset>
- stmts.emplace_back(ctx.dst->Decl(
- ctx.dst->Let(buffer_array_base, nullptr, attribute_offset)));
-
- for (const VertexAttributeDescriptor& attribute_desc :
- buffer_layout.attributes) {
- auto it = location_info.find(attribute_desc.shader_location);
- if (it == location_info.end()) {
- continue;
+ /// Lazily generates the structure buffer symbol
+ Symbol GetStructBufferName() {
+ if (!struct_buffer_name.IsValid()) {
+ static const char kStructBufferName[] = "tint_vertex_data";
+ struct_buffer_name = ctx.dst->Symbols().New(kStructBufferName);
}
- auto& var = it->second;
+ return struct_buffer_name;
+ }
- // Data type of the target WGSL variable
- auto var_dt = DataTypeOf(var.type);
- // Data type of the vertex stream attribute
- auto fmt_dt = DataTypeOf(attribute_desc.format);
+ /// Adds storage buffer decorated variables for the vertex buffers
+ void AddVertexStorageBuffers() {
+ // Creating the struct type
+ static const char kStructName[] = "TintVertexData";
+ auto* struct_type = ctx.dst->Structure(
+ ctx.dst->Symbols().New(kStructName),
+ {
+ ctx.dst->Member(GetStructBufferName(), ctx.dst->ty.array<ProgramBuilder::u32>()),
+ });
+ for (uint32_t i = 0; i < cfg.vertex_state.size(); ++i) {
+ // The decorated variable with struct type
+ ctx.dst->Global(GetVertexBufferName(i), ctx.dst->ty.Of(struct_type),
+ ast::StorageClass::kStorage, ast::Access::kRead,
+ ast::AttributeList{
+ ctx.dst->create<ast::BindingAttribute>(i),
+ ctx.dst->create<ast::GroupAttribute>(cfg.pulling_group),
+ });
+ }
+ }
- // Base types must match between the vertex stream and the WGSL variable
- if (var_dt.base_type != fmt_dt.base_type) {
- std::stringstream err;
- err << "VertexAttributeDescriptor for location "
- << std::to_string(attribute_desc.shader_location)
- << " has format " << attribute_desc.format
- << " but shader expects "
- << var.type->FriendlyName(ctx.src->Symbols());
- ctx.dst->Diagnostics().add_error(diag::System::Transform, err.str());
- return nullptr;
+ /// Creates and returns the assignment to the variables from the buffers
+ ast::BlockStatement* CreateVertexPullingPreamble() {
+ // Assign by looking at the vertex descriptor to find attributes with
+ // matching location.
+
+ ast::StatementList stmts;
+
+ for (uint32_t buffer_idx = 0; buffer_idx < cfg.vertex_state.size(); ++buffer_idx) {
+ const VertexBufferLayoutDescriptor& buffer_layout = cfg.vertex_state[buffer_idx];
+
+ if ((buffer_layout.array_stride & 3) != 0) {
+ ctx.dst->Diagnostics().add_error(
+ diag::System::Transform,
+ "WebGPU requires that vertex stride must be a multiple of 4 bytes, "
+ "but VertexPulling array stride for buffer " +
+ std::to_string(buffer_idx) + " was " +
+ std::to_string(buffer_layout.array_stride) + " bytes");
+ return nullptr;
+ }
+
+ auto* index_expr = buffer_layout.step_mode == VertexStepMode::kVertex
+ ? vertex_index_expr()
+ : instance_index_expr();
+
+ // buffer_array_base is the base array offset for all the vertex
+ // attributes. These are units of uint (4 bytes).
+ auto buffer_array_base =
+ ctx.dst->Symbols().New("buffer_array_base_" + std::to_string(buffer_idx));
+
+ auto* attribute_offset = index_expr;
+ if (buffer_layout.array_stride != 4) {
+ attribute_offset = ctx.dst->Mul(index_expr, buffer_layout.array_stride / 4u);
+ }
+
+ // let pulling_offset_n = <attribute_offset>
+ stmts.emplace_back(
+ ctx.dst->Decl(ctx.dst->Let(buffer_array_base, nullptr, attribute_offset)));
+
+ for (const VertexAttributeDescriptor& attribute_desc : buffer_layout.attributes) {
+ auto it = location_info.find(attribute_desc.shader_location);
+ if (it == location_info.end()) {
+ continue;
+ }
+ auto& var = it->second;
+
+ // Data type of the target WGSL variable
+ auto var_dt = DataTypeOf(var.type);
+ // Data type of the vertex stream attribute
+ auto fmt_dt = DataTypeOf(attribute_desc.format);
+
+ // Base types must match between the vertex stream and the WGSL variable
+ if (var_dt.base_type != fmt_dt.base_type) {
+ std::stringstream err;
+ err << "VertexAttributeDescriptor for location "
+ << std::to_string(attribute_desc.shader_location) << " has format "
+ << attribute_desc.format << " but shader expects "
+ << var.type->FriendlyName(ctx.src->Symbols());
+ ctx.dst->Diagnostics().add_error(diag::System::Transform, err.str());
+ return nullptr;
+ }
+
+ // Load the attribute value
+ auto* fetch = Fetch(buffer_array_base, attribute_desc.offset, buffer_idx,
+ attribute_desc.format);
+
+ // The attribute value may not be of the desired vector width. If it is
+ // not, we'll need to either reduce the width with a swizzle, or append
+ // 0's and / or a 1.
+ auto* value = fetch;
+ if (var_dt.width < fmt_dt.width) {
+ // WGSL variable vector width is smaller than the loaded vector width
+ switch (var_dt.width) {
+ case 1:
+ value = ctx.dst->MemberAccessor(fetch, "x");
+ break;
+ case 2:
+ value = ctx.dst->MemberAccessor(fetch, "xy");
+ break;
+ case 3:
+ value = ctx.dst->MemberAccessor(fetch, "xyz");
+ break;
+ default:
+ TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics()) << var_dt.width;
+ return nullptr;
+ }
+ } else if (var_dt.width > fmt_dt.width) {
+ // WGSL variable vector width is wider than the loaded vector width
+ const ast::Type* ty = nullptr;
+ ast::ExpressionList values{fetch};
+ switch (var_dt.base_type) {
+ case BaseType::kI32:
+ ty = ctx.dst->ty.i32();
+ for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) {
+ values.emplace_back(ctx.dst->Expr((i == 3) ? 1 : 0));
+ }
+ break;
+ case BaseType::kU32:
+ ty = ctx.dst->ty.u32();
+ for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) {
+ values.emplace_back(ctx.dst->Expr((i == 3) ? 1u : 0u));
+ }
+ break;
+ case BaseType::kF32:
+ ty = ctx.dst->ty.f32();
+ for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) {
+ values.emplace_back(ctx.dst->Expr((i == 3) ? 1.f : 0.f));
+ }
+ break;
+ default:
+ TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics()) << var_dt.base_type;
+ return nullptr;
+ }
+ value = ctx.dst->Construct(ctx.dst->ty.vec(ty, var_dt.width), values);
+ }
+
+ // Assign the value to the WGSL variable
+ stmts.emplace_back(ctx.dst->Assign(var.expr(), value));
+ }
}
- // Load the attribute value
- auto* fetch = Fetch(buffer_array_base, attribute_desc.offset,
- buffer_idx, attribute_desc.format);
-
- // The attribute value may not be of the desired vector width. If it is
- // not, we'll need to either reduce the width with a swizzle, or append
- // 0's and / or a 1.
- auto* value = fetch;
- if (var_dt.width < fmt_dt.width) {
- // WGSL variable vector width is smaller than the loaded vector width
- switch (var_dt.width) {
- case 1:
- value = ctx.dst->MemberAccessor(fetch, "x");
- break;
- case 2:
- value = ctx.dst->MemberAccessor(fetch, "xy");
- break;
- case 3:
- value = ctx.dst->MemberAccessor(fetch, "xyz");
- break;
- default:
- TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
- << var_dt.width;
- return nullptr;
- }
- } else if (var_dt.width > fmt_dt.width) {
- // WGSL variable vector width is wider than the loaded vector width
- const ast::Type* ty = nullptr;
- ast::ExpressionList values{fetch};
- switch (var_dt.base_type) {
- case BaseType::kI32:
- ty = ctx.dst->ty.i32();
- for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) {
- values.emplace_back(ctx.dst->Expr((i == 3) ? 1 : 0));
- }
- break;
- case BaseType::kU32:
- ty = ctx.dst->ty.u32();
- for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) {
- values.emplace_back(ctx.dst->Expr((i == 3) ? 1u : 0u));
- }
- break;
- case BaseType::kF32:
- ty = ctx.dst->ty.f32();
- for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) {
- values.emplace_back(ctx.dst->Expr((i == 3) ? 1.f : 0.f));
- }
- break;
- default:
- TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
- << var_dt.base_type;
- return nullptr;
- }
- value = ctx.dst->Construct(ctx.dst->ty.vec(ty, var_dt.width), values);
+ if (stmts.empty()) {
+ return nullptr;
}
- // Assign the value to the WGSL variable
- stmts.emplace_back(ctx.dst->Assign(var.expr(), value));
- }
+ return ctx.dst->create<ast::BlockStatement>(stmts);
}
- if (stmts.empty()) {
- return nullptr;
- }
-
- return ctx.dst->create<ast::BlockStatement>(stmts);
- }
-
- /// Generates an expression reading from a buffer a specific format.
- /// @param array_base the symbol of the variable holding the base array offset
- /// of the vertex array (each index is 4-bytes).
- /// @param offset the byte offset of the data from `buffer_base`
- /// @param buffer the index of the vertex buffer
- /// @param format the format to read
- const ast::Expression* Fetch(Symbol array_base,
- uint32_t offset,
- uint32_t buffer,
- VertexFormat format) {
- using u32 = ProgramBuilder::u32;
- using i32 = ProgramBuilder::i32;
- using f32 = ProgramBuilder::f32;
-
- // Returns a u32 loaded from buffer_base + offset.
- auto load_u32 = [&] {
- return LoadPrimitive(array_base, offset, buffer, VertexFormat::kUint32);
- };
-
- // Returns a i32 loaded from buffer_base + offset.
- auto load_i32 = [&] { return ctx.dst->Bitcast<i32>(load_u32()); };
-
- // Returns a u32 loaded from buffer_base + offset + 4.
- auto load_next_u32 = [&] {
- return LoadPrimitive(array_base, offset + 4, buffer,
- VertexFormat::kUint32);
- };
-
- // Returns a i32 loaded from buffer_base + offset + 4.
- auto load_next_i32 = [&] { return ctx.dst->Bitcast<i32>(load_next_u32()); };
-
- // Returns a u16 loaded from offset, packed in the high 16 bits of a u32.
- // The low 16 bits are 0.
- // `min_alignment` must be a power of two.
- // `offset` must be `min_alignment` bytes aligned.
- auto load_u16_h = [&] {
- auto low_u32_offset = offset & ~3u;
- auto* low_u32 = LoadPrimitive(array_base, low_u32_offset, buffer,
- VertexFormat::kUint32);
- switch (offset & 3) {
- case 0:
- return ctx.dst->Shl(low_u32, 16u);
- case 1:
- return ctx.dst->And(ctx.dst->Shl(low_u32, 8u), 0xffff0000u);
- case 2:
- return ctx.dst->And(low_u32, 0xffff0000u);
- default: { // 3:
- auto* high_u32 = LoadPrimitive(array_base, low_u32_offset + 4, buffer,
- VertexFormat::kUint32);
- auto* shr = ctx.dst->Shr(low_u32, 8u);
- auto* shl = ctx.dst->Shl(high_u32, 24u);
- return ctx.dst->And(ctx.dst->Or(shl, shr), 0xffff0000u);
- }
- }
- };
-
- // Returns a u16 loaded from offset, packed in the low 16 bits of a u32.
- // The high 16 bits are 0.
- auto load_u16_l = [&] {
- auto low_u32_offset = offset & ~3u;
- auto* low_u32 = LoadPrimitive(array_base, low_u32_offset, buffer,
- VertexFormat::kUint32);
- switch (offset & 3) {
- case 0:
- return ctx.dst->And(low_u32, 0xffffu);
- case 1:
- return ctx.dst->And(ctx.dst->Shr(low_u32, 8u), 0xffffu);
- case 2:
- return ctx.dst->Shr(low_u32, 16u);
- default: { // 3:
- auto* high_u32 = LoadPrimitive(array_base, low_u32_offset + 4, buffer,
- VertexFormat::kUint32);
- auto* shr = ctx.dst->Shr(low_u32, 24u);
- auto* shl = ctx.dst->Shl(high_u32, 8u);
- return ctx.dst->And(ctx.dst->Or(shl, shr), 0xffffu);
- }
- }
- };
-
- // Returns a i16 loaded from offset, packed in the high 16 bits of a u32.
- // The low 16 bits are 0.
- auto load_i16_h = [&] { return ctx.dst->Bitcast<i32>(load_u16_h()); };
-
- // Assumptions are made that alignment must be at least as large as the size
- // of a single component.
- switch (format) {
- // Basic primitives
- case VertexFormat::kUint32:
- case VertexFormat::kSint32:
- case VertexFormat::kFloat32:
- return LoadPrimitive(array_base, offset, buffer, format);
-
- // Vectors of basic primitives
- case VertexFormat::kUint32x2:
- return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.u32(),
- VertexFormat::kUint32, 2);
- case VertexFormat::kUint32x3:
- return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.u32(),
- VertexFormat::kUint32, 3);
- case VertexFormat::kUint32x4:
- return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.u32(),
- VertexFormat::kUint32, 4);
- case VertexFormat::kSint32x2:
- return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.i32(),
- VertexFormat::kSint32, 2);
- case VertexFormat::kSint32x3:
- return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.i32(),
- VertexFormat::kSint32, 3);
- case VertexFormat::kSint32x4:
- return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.i32(),
- VertexFormat::kSint32, 4);
- case VertexFormat::kFloat32x2:
- return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.f32(),
- VertexFormat::kFloat32, 2);
- case VertexFormat::kFloat32x3:
- return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.f32(),
- VertexFormat::kFloat32, 3);
- case VertexFormat::kFloat32x4:
- return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.f32(),
- VertexFormat::kFloat32, 4);
-
- case VertexFormat::kUint8x2: {
- // yyxx0000, yyxx0000
- auto* u16s = ctx.dst->vec2<u32>(load_u16_h());
- // xx000000, yyxx0000
- auto* shl = ctx.dst->Shl(u16s, ctx.dst->vec2<u32>(8u, 0u));
- // 000000xx, 000000yy
- return ctx.dst->Shr(shl, ctx.dst->vec2<u32>(24u));
- }
- case VertexFormat::kUint8x4: {
- // wwzzyyxx, wwzzyyxx, wwzzyyxx, wwzzyyxx
- auto* u32s = ctx.dst->vec4<u32>(load_u32());
- // xx000000, yyxx0000, zzyyxx00, wwzzyyxx
- auto* shl = ctx.dst->Shl(u32s, ctx.dst->vec4<u32>(24u, 16u, 8u, 0u));
- // 000000xx, 000000yy, 000000zz, 000000ww
- return ctx.dst->Shr(shl, ctx.dst->vec4<u32>(24u));
- }
- case VertexFormat::kUint16x2: {
- // yyyyxxxx, yyyyxxxx
- auto* u32s = ctx.dst->vec2<u32>(load_u32());
- // xxxx0000, yyyyxxxx
- auto* shl = ctx.dst->Shl(u32s, ctx.dst->vec2<u32>(16u, 0u));
- // 0000xxxx, 0000yyyy
- return ctx.dst->Shr(shl, ctx.dst->vec2<u32>(16u));
- }
- case VertexFormat::kUint16x4: {
- // yyyyxxxx, wwwwzzzz
- auto* u32s = ctx.dst->vec2<u32>(load_u32(), load_next_u32());
- // yyyyxxxx, yyyyxxxx, wwwwzzzz, wwwwzzzz
- auto* xxyy = ctx.dst->MemberAccessor(u32s, "xxyy");
- // xxxx0000, yyyyxxxx, zzzz0000, wwwwzzzz
- auto* shl = ctx.dst->Shl(xxyy, ctx.dst->vec4<u32>(16u, 0u, 16u, 0u));
- // 0000xxxx, 0000yyyy, 0000zzzz, 0000wwww
- return ctx.dst->Shr(shl, ctx.dst->vec4<u32>(16u));
- }
- case VertexFormat::kSint8x2: {
- // yyxx0000, yyxx0000
- auto* i16s = ctx.dst->vec2<i32>(load_i16_h());
- // xx000000, yyxx0000
- auto* shl = ctx.dst->Shl(i16s, ctx.dst->vec2<u32>(8u, 0u));
- // ssssssxx, ssssssyy
- return ctx.dst->Shr(shl, ctx.dst->vec2<u32>(24u));
- }
- case VertexFormat::kSint8x4: {
- // wwzzyyxx, wwzzyyxx, wwzzyyxx, wwzzyyxx
- auto* i32s = ctx.dst->vec4<i32>(load_i32());
- // xx000000, yyxx0000, zzyyxx00, wwzzyyxx
- auto* shl = ctx.dst->Shl(i32s, ctx.dst->vec4<u32>(24u, 16u, 8u, 0u));
- // ssssssxx, ssssssyy, sssssszz, ssssssww
- return ctx.dst->Shr(shl, ctx.dst->vec4<u32>(24u));
- }
- case VertexFormat::kSint16x2: {
- // yyyyxxxx, yyyyxxxx
- auto* i32s = ctx.dst->vec2<i32>(load_i32());
- // xxxx0000, yyyyxxxx
- auto* shl = ctx.dst->Shl(i32s, ctx.dst->vec2<u32>(16u, 0u));
- // ssssxxxx, ssssyyyy
- return ctx.dst->Shr(shl, ctx.dst->vec2<u32>(16u));
- }
- case VertexFormat::kSint16x4: {
- // yyyyxxxx, wwwwzzzz
- auto* i32s = ctx.dst->vec2<i32>(load_i32(), load_next_i32());
- // yyyyxxxx, yyyyxxxx, wwwwzzzz, wwwwzzzz
- auto* xxyy = ctx.dst->MemberAccessor(i32s, "xxyy");
- // xxxx0000, yyyyxxxx, zzzz0000, wwwwzzzz
- auto* shl = ctx.dst->Shl(xxyy, ctx.dst->vec4<u32>(16u, 0u, 16u, 0u));
- // ssssxxxx, ssssyyyy, sssszzzz, sssswwww
- return ctx.dst->Shr(shl, ctx.dst->vec4<u32>(16u));
- }
- case VertexFormat::kUnorm8x2:
- return ctx.dst->MemberAccessor(
- ctx.dst->Call("unpack4x8unorm", load_u16_l()), "xy");
- case VertexFormat::kSnorm8x2:
- return ctx.dst->MemberAccessor(
- ctx.dst->Call("unpack4x8snorm", load_u16_l()), "xy");
- case VertexFormat::kUnorm8x4:
- return ctx.dst->Call("unpack4x8unorm", load_u32());
- case VertexFormat::kSnorm8x4:
- return ctx.dst->Call("unpack4x8snorm", load_u32());
- case VertexFormat::kUnorm16x2:
- return ctx.dst->Call("unpack2x16unorm", load_u32());
- case VertexFormat::kSnorm16x2:
- return ctx.dst->Call("unpack2x16snorm", load_u32());
- case VertexFormat::kFloat16x2:
- return ctx.dst->Call("unpack2x16float", load_u32());
- case VertexFormat::kUnorm16x4:
- return ctx.dst->vec4<f32>(
- ctx.dst->Call("unpack2x16unorm", load_u32()),
- ctx.dst->Call("unpack2x16unorm", load_next_u32()));
- case VertexFormat::kSnorm16x4:
- return ctx.dst->vec4<f32>(
- ctx.dst->Call("unpack2x16snorm", load_u32()),
- ctx.dst->Call("unpack2x16snorm", load_next_u32()));
- case VertexFormat::kFloat16x4:
- return ctx.dst->vec4<f32>(
- ctx.dst->Call("unpack2x16float", load_u32()),
- ctx.dst->Call("unpack2x16float", load_next_u32()));
- }
-
- TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
- << "format " << static_cast<int>(format);
- return nullptr;
- }
-
- /// Generates an expression reading an aligned basic type (u32, i32, f32) from
- /// a vertex buffer.
- /// @param array_base the symbol of the variable holding the base array offset
- /// of the vertex array (each index is 4-bytes).
- /// @param offset the byte offset of the data from `buffer_base`
- /// @param buffer the index of the vertex buffer
- /// @param format VertexFormat::kUint32, VertexFormat::kSint32 or
- /// VertexFormat::kFloat32
- const ast::Expression* LoadPrimitive(Symbol array_base,
- uint32_t offset,
- uint32_t buffer,
- VertexFormat format) {
- const ast::Expression* u32 = nullptr;
- if ((offset & 3) == 0) {
- // Aligned load.
-
- const ast ::Expression* index = nullptr;
- if (offset > 0) {
- index = ctx.dst->Add(array_base, offset / 4);
- } else {
- index = ctx.dst->Expr(array_base);
- }
- u32 = ctx.dst->IndexAccessor(
- ctx.dst->MemberAccessor(GetVertexBufferName(buffer),
- GetStructBufferName()),
- index);
-
- } else {
- // Unaligned load
- uint32_t offset_aligned = offset & ~3u;
- auto* low = LoadPrimitive(array_base, offset_aligned, buffer,
- VertexFormat::kUint32);
- auto* high = LoadPrimitive(array_base, offset_aligned + 4u, buffer,
- VertexFormat::kUint32);
-
- uint32_t shift = 8u * (offset & 3u);
-
- auto* low_shr = ctx.dst->Shr(low, shift);
- auto* high_shl = ctx.dst->Shl(high, 32u - shift);
- u32 = ctx.dst->Or(low_shr, high_shl);
- }
-
- switch (format) {
- case VertexFormat::kUint32:
- return u32;
- case VertexFormat::kSint32:
- return ctx.dst->Bitcast(ctx.dst->ty.i32(), u32);
- case VertexFormat::kFloat32:
- return ctx.dst->Bitcast(ctx.dst->ty.f32(), u32);
- default:
- break;
- }
- TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
- << "invalid format for LoadPrimitive" << static_cast<int>(format);
- return nullptr;
- }
-
- /// Generates an expression reading a vec2/3/4 from a vertex buffer.
- /// @param array_base the symbol of the variable holding the base array offset
- /// of the vertex array (each index is 4-bytes).
- /// @param offset the byte offset of the data from `buffer_base`
- /// @param buffer the index of the vertex buffer
- /// @param element_stride stride between elements, in bytes
- /// @param base_type underlying AST type
- /// @param base_format underlying vertex format
- /// @param count how many elements the vector has
- const ast::Expression* LoadVec(Symbol array_base,
+ /// Generates an expression reading from a buffer a specific format.
+ /// @param array_base the symbol of the variable holding the base array offset
+ /// of the vertex array (each index is 4-bytes).
+ /// @param offset the byte offset of the data from `buffer_base`
+ /// @param buffer the index of the vertex buffer
+ /// @param format the format to read
+ const ast::Expression* Fetch(Symbol array_base,
uint32_t offset,
uint32_t buffer,
- uint32_t element_stride,
- const ast::Type* base_type,
- VertexFormat base_format,
- uint32_t count) {
- ast::ExpressionList expr_list;
- for (uint32_t i = 0; i < count; ++i) {
- // Offset read position by element_stride for each component
- uint32_t primitive_offset = offset + element_stride * i;
- expr_list.push_back(
- LoadPrimitive(array_base, primitive_offset, buffer, base_format));
- }
+ VertexFormat format) {
+ using u32 = ProgramBuilder::u32;
+ using i32 = ProgramBuilder::i32;
+ using f32 = ProgramBuilder::f32;
- return ctx.dst->Construct(ctx.dst->create<ast::Vector>(base_type, count),
- std::move(expr_list));
- }
-
- /// Process a non-struct entry point parameter.
- /// Generate function-scope variables for location parameters, and record
- /// vertex_index and instance_index builtins if present.
- /// @param func the entry point function
- /// @param param the parameter to process
- void ProcessNonStructParameter(const ast::Function* func,
- const ast::Variable* param) {
- if (auto* location =
- ast::GetAttribute<ast::LocationAttribute>(param->attributes)) {
- // Create a function-scope variable to replace the parameter.
- auto func_var_sym = ctx.Clone(param->symbol);
- auto* func_var_type = ctx.Clone(param->type);
- auto* func_var = ctx.dst->Var(func_var_sym, func_var_type);
- ctx.InsertFront(func->body->statements, ctx.dst->Decl(func_var));
- // Capture mapping from location to the new variable.
- LocationInfo info;
- info.expr = [this, func_var]() { return ctx.dst->Expr(func_var); };
- info.type = ctx.src->Sem().Get(param)->Type();
- location_info[location->value] = info;
- } else if (auto* builtin = ast::GetAttribute<ast::BuiltinAttribute>(
- param->attributes)) {
- // Check for existing vertex_index and instance_index builtins.
- if (builtin->builtin == ast::Builtin::kVertexIndex) {
- vertex_index_expr = [this, param]() {
- return ctx.dst->Expr(ctx.Clone(param->symbol));
+ // Returns a u32 loaded from buffer_base + offset.
+ auto load_u32 = [&] {
+ return LoadPrimitive(array_base, offset, buffer, VertexFormat::kUint32);
};
- } else if (builtin->builtin == ast::Builtin::kInstanceIndex) {
- instance_index_expr = [this, param]() {
- return ctx.dst->Expr(ctx.Clone(param->symbol));
+
+ // Returns a i32 loaded from buffer_base + offset.
+ auto load_i32 = [&] { return ctx.dst->Bitcast<i32>(load_u32()); };
+
+ // Returns a u32 loaded from buffer_base + offset + 4.
+ auto load_next_u32 = [&] {
+ return LoadPrimitive(array_base, offset + 4, buffer, VertexFormat::kUint32);
};
- }
- new_function_parameters.push_back(ctx.Clone(param));
- } else {
- TINT_ICE(Transform, ctx.dst->Diagnostics())
- << "Invalid entry point parameter";
- }
- }
- /// Process a struct entry point parameter.
- /// If the struct has members with location attributes, push the parameter to
- /// a function-scope variable and create a new struct parameter without those
- /// attributes. Record expressions for members that are vertex_index and
- /// instance_index builtins.
- /// @param func the entry point function
- /// @param param the parameter to process
- /// @param struct_ty the structure type
- void ProcessStructParameter(const ast::Function* func,
- const ast::Variable* param,
- const ast::Struct* struct_ty) {
- auto param_sym = ctx.Clone(param->symbol);
+ // Returns a i32 loaded from buffer_base + offset + 4.
+ auto load_next_i32 = [&] { return ctx.dst->Bitcast<i32>(load_next_u32()); };
- // Process the struct members.
- bool has_locations = false;
- ast::StructMemberList members_to_clone;
- for (auto* member : struct_ty->members) {
- auto member_sym = ctx.Clone(member->symbol);
- std::function<const ast::Expression*()> member_expr = [this, param_sym,
- member_sym]() {
- return ctx.dst->MemberAccessor(param_sym, member_sym);
- };
+ // Returns a u16 loaded from offset, packed in the high 16 bits of a u32.
+ // The low 16 bits are 0.
+ // `min_alignment` must be a power of two.
+ // `offset` must be `min_alignment` bytes aligned.
+ auto load_u16_h = [&] {
+ auto low_u32_offset = offset & ~3u;
+ auto* low_u32 =
+ LoadPrimitive(array_base, low_u32_offset, buffer, VertexFormat::kUint32);
+ switch (offset & 3) {
+ case 0:
+ return ctx.dst->Shl(low_u32, 16u);
+ case 1:
+ return ctx.dst->And(ctx.dst->Shl(low_u32, 8u), 0xffff0000u);
+ case 2:
+ return ctx.dst->And(low_u32, 0xffff0000u);
+ default: { // 3:
+ auto* high_u32 = LoadPrimitive(array_base, low_u32_offset + 4, buffer,
+ VertexFormat::kUint32);
+ auto* shr = ctx.dst->Shr(low_u32, 8u);
+ auto* shl = ctx.dst->Shl(high_u32, 24u);
+ return ctx.dst->And(ctx.dst->Or(shl, shr), 0xffff0000u);
+ }
+ }
+ };
- if (auto* location =
- ast::GetAttribute<ast::LocationAttribute>(member->attributes)) {
- // Capture mapping from location to struct member.
- LocationInfo info;
- info.expr = member_expr;
- info.type = ctx.src->Sem().Get(member)->Type();
- location_info[location->value] = info;
- has_locations = true;
- } else if (auto* builtin = ast::GetAttribute<ast::BuiltinAttribute>(
- member->attributes)) {
- // Check for existing vertex_index and instance_index builtins.
- if (builtin->builtin == ast::Builtin::kVertexIndex) {
- vertex_index_expr = member_expr;
- } else if (builtin->builtin == ast::Builtin::kInstanceIndex) {
- instance_index_expr = member_expr;
+ // Returns a u16 loaded from offset, packed in the low 16 bits of a u32.
+ // The high 16 bits are 0.
+ auto load_u16_l = [&] {
+ auto low_u32_offset = offset & ~3u;
+ auto* low_u32 =
+ LoadPrimitive(array_base, low_u32_offset, buffer, VertexFormat::kUint32);
+ switch (offset & 3) {
+ case 0:
+ return ctx.dst->And(low_u32, 0xffffu);
+ case 1:
+ return ctx.dst->And(ctx.dst->Shr(low_u32, 8u), 0xffffu);
+ case 2:
+ return ctx.dst->Shr(low_u32, 16u);
+ default: { // 3:
+ auto* high_u32 = LoadPrimitive(array_base, low_u32_offset + 4, buffer,
+ VertexFormat::kUint32);
+ auto* shr = ctx.dst->Shr(low_u32, 24u);
+ auto* shl = ctx.dst->Shl(high_u32, 8u);
+ return ctx.dst->And(ctx.dst->Or(shl, shr), 0xffffu);
+ }
+ }
+ };
+
+ // Returns a i16 loaded from offset, packed in the high 16 bits of a u32.
+ // The low 16 bits are 0.
+ auto load_i16_h = [&] { return ctx.dst->Bitcast<i32>(load_u16_h()); };
+
+ // Assumptions are made that alignment must be at least as large as the size
+ // of a single component.
+ switch (format) {
+ // Basic primitives
+ case VertexFormat::kUint32:
+ case VertexFormat::kSint32:
+ case VertexFormat::kFloat32:
+ return LoadPrimitive(array_base, offset, buffer, format);
+
+ // Vectors of basic primitives
+ case VertexFormat::kUint32x2:
+ return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.u32(),
+ VertexFormat::kUint32, 2);
+ case VertexFormat::kUint32x3:
+ return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.u32(),
+ VertexFormat::kUint32, 3);
+ case VertexFormat::kUint32x4:
+ return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.u32(),
+ VertexFormat::kUint32, 4);
+ case VertexFormat::kSint32x2:
+ return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.i32(),
+ VertexFormat::kSint32, 2);
+ case VertexFormat::kSint32x3:
+ return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.i32(),
+ VertexFormat::kSint32, 3);
+ case VertexFormat::kSint32x4:
+ return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.i32(),
+ VertexFormat::kSint32, 4);
+ case VertexFormat::kFloat32x2:
+ return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.f32(),
+ VertexFormat::kFloat32, 2);
+ case VertexFormat::kFloat32x3:
+ return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.f32(),
+ VertexFormat::kFloat32, 3);
+ case VertexFormat::kFloat32x4:
+ return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.f32(),
+ VertexFormat::kFloat32, 4);
+
+ case VertexFormat::kUint8x2: {
+ // yyxx0000, yyxx0000
+ auto* u16s = ctx.dst->vec2<u32>(load_u16_h());
+ // xx000000, yyxx0000
+ auto* shl = ctx.dst->Shl(u16s, ctx.dst->vec2<u32>(8u, 0u));
+ // 000000xx, 000000yy
+ return ctx.dst->Shr(shl, ctx.dst->vec2<u32>(24u));
+ }
+ case VertexFormat::kUint8x4: {
+ // wwzzyyxx, wwzzyyxx, wwzzyyxx, wwzzyyxx
+ auto* u32s = ctx.dst->vec4<u32>(load_u32());
+ // xx000000, yyxx0000, zzyyxx00, wwzzyyxx
+ auto* shl = ctx.dst->Shl(u32s, ctx.dst->vec4<u32>(24u, 16u, 8u, 0u));
+ // 000000xx, 000000yy, 000000zz, 000000ww
+ return ctx.dst->Shr(shl, ctx.dst->vec4<u32>(24u));
+ }
+ case VertexFormat::kUint16x2: {
+ // yyyyxxxx, yyyyxxxx
+ auto* u32s = ctx.dst->vec2<u32>(load_u32());
+ // xxxx0000, yyyyxxxx
+ auto* shl = ctx.dst->Shl(u32s, ctx.dst->vec2<u32>(16u, 0u));
+ // 0000xxxx, 0000yyyy
+ return ctx.dst->Shr(shl, ctx.dst->vec2<u32>(16u));
+ }
+ case VertexFormat::kUint16x4: {
+ // yyyyxxxx, wwwwzzzz
+ auto* u32s = ctx.dst->vec2<u32>(load_u32(), load_next_u32());
+ // yyyyxxxx, yyyyxxxx, wwwwzzzz, wwwwzzzz
+ auto* xxyy = ctx.dst->MemberAccessor(u32s, "xxyy");
+ // xxxx0000, yyyyxxxx, zzzz0000, wwwwzzzz
+ auto* shl = ctx.dst->Shl(xxyy, ctx.dst->vec4<u32>(16u, 0u, 16u, 0u));
+ // 0000xxxx, 0000yyyy, 0000zzzz, 0000wwww
+ return ctx.dst->Shr(shl, ctx.dst->vec4<u32>(16u));
+ }
+ case VertexFormat::kSint8x2: {
+ // yyxx0000, yyxx0000
+ auto* i16s = ctx.dst->vec2<i32>(load_i16_h());
+ // xx000000, yyxx0000
+ auto* shl = ctx.dst->Shl(i16s, ctx.dst->vec2<u32>(8u, 0u));
+ // ssssssxx, ssssssyy
+ return ctx.dst->Shr(shl, ctx.dst->vec2<u32>(24u));
+ }
+ case VertexFormat::kSint8x4: {
+ // wwzzyyxx, wwzzyyxx, wwzzyyxx, wwzzyyxx
+ auto* i32s = ctx.dst->vec4<i32>(load_i32());
+ // xx000000, yyxx0000, zzyyxx00, wwzzyyxx
+ auto* shl = ctx.dst->Shl(i32s, ctx.dst->vec4<u32>(24u, 16u, 8u, 0u));
+ // ssssssxx, ssssssyy, sssssszz, ssssssww
+ return ctx.dst->Shr(shl, ctx.dst->vec4<u32>(24u));
+ }
+ case VertexFormat::kSint16x2: {
+ // yyyyxxxx, yyyyxxxx
+ auto* i32s = ctx.dst->vec2<i32>(load_i32());
+ // xxxx0000, yyyyxxxx
+ auto* shl = ctx.dst->Shl(i32s, ctx.dst->vec2<u32>(16u, 0u));
+ // ssssxxxx, ssssyyyy
+ return ctx.dst->Shr(shl, ctx.dst->vec2<u32>(16u));
+ }
+ case VertexFormat::kSint16x4: {
+ // yyyyxxxx, wwwwzzzz
+ auto* i32s = ctx.dst->vec2<i32>(load_i32(), load_next_i32());
+ // yyyyxxxx, yyyyxxxx, wwwwzzzz, wwwwzzzz
+ auto* xxyy = ctx.dst->MemberAccessor(i32s, "xxyy");
+ // xxxx0000, yyyyxxxx, zzzz0000, wwwwzzzz
+ auto* shl = ctx.dst->Shl(xxyy, ctx.dst->vec4<u32>(16u, 0u, 16u, 0u));
+ // ssssxxxx, ssssyyyy, sssszzzz, sssswwww
+ return ctx.dst->Shr(shl, ctx.dst->vec4<u32>(16u));
+ }
+ case VertexFormat::kUnorm8x2:
+ return ctx.dst->MemberAccessor(ctx.dst->Call("unpack4x8unorm", load_u16_l()), "xy");
+ case VertexFormat::kSnorm8x2:
+ return ctx.dst->MemberAccessor(ctx.dst->Call("unpack4x8snorm", load_u16_l()), "xy");
+ case VertexFormat::kUnorm8x4:
+ return ctx.dst->Call("unpack4x8unorm", load_u32());
+ case VertexFormat::kSnorm8x4:
+ return ctx.dst->Call("unpack4x8snorm", load_u32());
+ case VertexFormat::kUnorm16x2:
+ return ctx.dst->Call("unpack2x16unorm", load_u32());
+ case VertexFormat::kSnorm16x2:
+ return ctx.dst->Call("unpack2x16snorm", load_u32());
+ case VertexFormat::kFloat16x2:
+ return ctx.dst->Call("unpack2x16float", load_u32());
+ case VertexFormat::kUnorm16x4:
+ return ctx.dst->vec4<f32>(ctx.dst->Call("unpack2x16unorm", load_u32()),
+ ctx.dst->Call("unpack2x16unorm", load_next_u32()));
+ case VertexFormat::kSnorm16x4:
+ return ctx.dst->vec4<f32>(ctx.dst->Call("unpack2x16snorm", load_u32()),
+ ctx.dst->Call("unpack2x16snorm", load_next_u32()));
+ case VertexFormat::kFloat16x4:
+ return ctx.dst->vec4<f32>(ctx.dst->Call("unpack2x16float", load_u32()),
+ ctx.dst->Call("unpack2x16float", load_next_u32()));
}
- members_to_clone.push_back(member);
- } else {
- TINT_ICE(Transform, ctx.dst->Diagnostics())
- << "Invalid entry point parameter";
- }
+
+ TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
+ << "format " << static_cast<int>(format);
+ return nullptr;
}
- if (!has_locations) {
- // Nothing to do.
- new_function_parameters.push_back(ctx.Clone(param));
- return;
- }
+ /// Generates an expression reading an aligned basic type (u32, i32, f32) from
+ /// a vertex buffer.
+ /// @param array_base the symbol of the variable holding the base array offset
+ /// of the vertex array (each index is 4-bytes).
+ /// @param offset the byte offset of the data from `buffer_base`
+ /// @param buffer the index of the vertex buffer
+ /// @param format VertexFormat::kUint32, VertexFormat::kSint32 or
+ /// VertexFormat::kFloat32
+ const ast::Expression* LoadPrimitive(Symbol array_base,
+ uint32_t offset,
+ uint32_t buffer,
+ VertexFormat format) {
+ const ast::Expression* u32 = nullptr;
+ if ((offset & 3) == 0) {
+ // Aligned load.
- // Create a function-scope variable to replace the parameter.
- auto* func_var = ctx.dst->Var(param_sym, ctx.Clone(param->type));
- ctx.InsertFront(func->body->statements, ctx.dst->Decl(func_var));
+ const ast ::Expression* index = nullptr;
+ if (offset > 0) {
+ index = ctx.dst->Add(array_base, offset / 4);
+ } else {
+ index = ctx.dst->Expr(array_base);
+ }
+ u32 = ctx.dst->IndexAccessor(
+ ctx.dst->MemberAccessor(GetVertexBufferName(buffer), GetStructBufferName()), index);
- if (!members_to_clone.empty()) {
- // Create a new struct without the location attributes.
- ast::StructMemberList new_members;
- for (auto* member : members_to_clone) {
- auto member_sym = ctx.Clone(member->symbol);
- auto* member_type = ctx.Clone(member->type);
- auto member_attrs = ctx.Clone(member->attributes);
- new_members.push_back(
- ctx.dst->Member(member_sym, member_type, std::move(member_attrs)));
- }
- auto* new_struct = ctx.dst->Structure(ctx.dst->Sym(), new_members);
+ } else {
+ // Unaligned load
+ uint32_t offset_aligned = offset & ~3u;
+ auto* low = LoadPrimitive(array_base, offset_aligned, buffer, VertexFormat::kUint32);
+ auto* high =
+ LoadPrimitive(array_base, offset_aligned + 4u, buffer, VertexFormat::kUint32);
- // Create a new function parameter with this struct.
- auto* new_param =
- ctx.dst->Param(ctx.dst->Sym(), ctx.dst->ty.Of(new_struct));
- new_function_parameters.push_back(new_param);
+ uint32_t shift = 8u * (offset & 3u);
- // Copy values from the new parameter to the function-scope variable.
- for (auto* member : members_to_clone) {
- auto member_name = ctx.Clone(member->symbol);
- ctx.InsertFront(
- func->body->statements,
- ctx.dst->Assign(ctx.dst->MemberAccessor(func_var, member_name),
- ctx.dst->MemberAccessor(new_param, member_name)));
- }
- }
- }
-
- /// Process an entry point function.
- /// @param func the entry point function
- void Process(const ast::Function* func) {
- if (func->body->Empty()) {
- return;
- }
-
- // Process entry point parameters.
- for (auto* param : func->params) {
- auto* sem = ctx.src->Sem().Get(param);
- if (auto* str = sem->Type()->As<sem::Struct>()) {
- ProcessStructParameter(func, param, str->Declaration());
- } else {
- ProcessNonStructParameter(func, param);
- }
- }
-
- // Insert new parameters for vertex_index and instance_index if needed.
- if (!vertex_index_expr) {
- for (const VertexBufferLayoutDescriptor& layout : cfg.vertex_state) {
- if (layout.step_mode == VertexStepMode::kVertex) {
- auto name = ctx.dst->Symbols().New("tint_pulling_vertex_index");
- new_function_parameters.push_back(
- ctx.dst->Param(name, ctx.dst->ty.u32(),
- {ctx.dst->Builtin(ast::Builtin::kVertexIndex)}));
- vertex_index_expr = [this, name]() { return ctx.dst->Expr(name); };
- break;
+ auto* low_shr = ctx.dst->Shr(low, shift);
+ auto* high_shl = ctx.dst->Shl(high, 32u - shift);
+ u32 = ctx.dst->Or(low_shr, high_shl);
}
- }
- }
- if (!instance_index_expr) {
- for (const VertexBufferLayoutDescriptor& layout : cfg.vertex_state) {
- if (layout.step_mode == VertexStepMode::kInstance) {
- auto name = ctx.dst->Symbols().New("tint_pulling_instance_index");
- new_function_parameters.push_back(
- ctx.dst->Param(name, ctx.dst->ty.u32(),
- {ctx.dst->Builtin(ast::Builtin::kInstanceIndex)}));
- instance_index_expr = [this, name]() { return ctx.dst->Expr(name); };
- break;
+
+ switch (format) {
+ case VertexFormat::kUint32:
+ return u32;
+ case VertexFormat::kSint32:
+ return ctx.dst->Bitcast(ctx.dst->ty.i32(), u32);
+ case VertexFormat::kFloat32:
+ return ctx.dst->Bitcast(ctx.dst->ty.f32(), u32);
+ default:
+ break;
}
- }
+ TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
+ << "invalid format for LoadPrimitive" << static_cast<int>(format);
+ return nullptr;
}
- // Generate vertex pulling preamble.
- if (auto* block = CreateVertexPullingPreamble()) {
- ctx.InsertFront(func->body->statements, block);
+ /// Generates an expression reading a vec2/3/4 from a vertex buffer.
+ /// @param array_base the symbol of the variable holding the base array offset
+ /// of the vertex array (each index is 4-bytes).
+ /// @param offset the byte offset of the data from `buffer_base`
+ /// @param buffer the index of the vertex buffer
+ /// @param element_stride stride between elements, in bytes
+ /// @param base_type underlying AST type
+ /// @param base_format underlying vertex format
+ /// @param count how many elements the vector has
+ const ast::Expression* LoadVec(Symbol array_base,
+ uint32_t offset,
+ uint32_t buffer,
+ uint32_t element_stride,
+ const ast::Type* base_type,
+ VertexFormat base_format,
+ uint32_t count) {
+ ast::ExpressionList expr_list;
+ for (uint32_t i = 0; i < count; ++i) {
+ // Offset read position by element_stride for each component
+ uint32_t primitive_offset = offset + element_stride * i;
+ expr_list.push_back(LoadPrimitive(array_base, primitive_offset, buffer, base_format));
+ }
+
+ return ctx.dst->Construct(ctx.dst->create<ast::Vector>(base_type, count),
+ std::move(expr_list));
}
- // Rewrite the function header with the new parameters.
- auto func_sym = ctx.Clone(func->symbol);
- auto* ret_type = ctx.Clone(func->return_type);
- auto* body = ctx.Clone(func->body);
- auto attrs = ctx.Clone(func->attributes);
- auto ret_attrs = ctx.Clone(func->return_type_attributes);
- auto* new_func = ctx.dst->create<ast::Function>(
- func->source, func_sym, new_function_parameters, ret_type, body,
- std::move(attrs), std::move(ret_attrs));
- ctx.Replace(func, new_func);
- }
+ /// Process a non-struct entry point parameter.
+ /// Generate function-scope variables for location parameters, and record
+ /// vertex_index and instance_index builtins if present.
+ /// @param func the entry point function
+ /// @param param the parameter to process
+ void ProcessNonStructParameter(const ast::Function* func, const ast::Variable* param) {
+ if (auto* location = ast::GetAttribute<ast::LocationAttribute>(param->attributes)) {
+ // Create a function-scope variable to replace the parameter.
+ auto func_var_sym = ctx.Clone(param->symbol);
+ auto* func_var_type = ctx.Clone(param->type);
+ auto* func_var = ctx.dst->Var(func_var_sym, func_var_type);
+ ctx.InsertFront(func->body->statements, ctx.dst->Decl(func_var));
+ // Capture mapping from location to the new variable.
+ LocationInfo info;
+ info.expr = [this, func_var]() { return ctx.dst->Expr(func_var); };
+ info.type = ctx.src->Sem().Get(param)->Type();
+ location_info[location->value] = info;
+ } else if (auto* builtin = ast::GetAttribute<ast::BuiltinAttribute>(param->attributes)) {
+ // Check for existing vertex_index and instance_index builtins.
+ if (builtin->builtin == ast::Builtin::kVertexIndex) {
+ vertex_index_expr = [this, param]() {
+ return ctx.dst->Expr(ctx.Clone(param->symbol));
+ };
+ } else if (builtin->builtin == ast::Builtin::kInstanceIndex) {
+ instance_index_expr = [this, param]() {
+ return ctx.dst->Expr(ctx.Clone(param->symbol));
+ };
+ }
+ new_function_parameters.push_back(ctx.Clone(param));
+ } else {
+ TINT_ICE(Transform, ctx.dst->Diagnostics()) << "Invalid entry point parameter";
+ }
+ }
+
+ /// Process a struct entry point parameter.
+ /// If the struct has members with location attributes, push the parameter to
+ /// a function-scope variable and create a new struct parameter without those
+ /// attributes. Record expressions for members that are vertex_index and
+ /// instance_index builtins.
+ /// @param func the entry point function
+ /// @param param the parameter to process
+ /// @param struct_ty the structure type
+ void ProcessStructParameter(const ast::Function* func,
+ const ast::Variable* param,
+ const ast::Struct* struct_ty) {
+ auto param_sym = ctx.Clone(param->symbol);
+
+ // Process the struct members.
+ bool has_locations = false;
+ ast::StructMemberList members_to_clone;
+ for (auto* member : struct_ty->members) {
+ auto member_sym = ctx.Clone(member->symbol);
+ std::function<const ast::Expression*()> member_expr = [this, param_sym, member_sym]() {
+ return ctx.dst->MemberAccessor(param_sym, member_sym);
+ };
+
+ if (auto* location = ast::GetAttribute<ast::LocationAttribute>(member->attributes)) {
+ // Capture mapping from location to struct member.
+ LocationInfo info;
+ info.expr = member_expr;
+ info.type = ctx.src->Sem().Get(member)->Type();
+ location_info[location->value] = info;
+ has_locations = true;
+ } else if (auto* builtin =
+ ast::GetAttribute<ast::BuiltinAttribute>(member->attributes)) {
+ // Check for existing vertex_index and instance_index builtins.
+ if (builtin->builtin == ast::Builtin::kVertexIndex) {
+ vertex_index_expr = member_expr;
+ } else if (builtin->builtin == ast::Builtin::kInstanceIndex) {
+ instance_index_expr = member_expr;
+ }
+ members_to_clone.push_back(member);
+ } else {
+ TINT_ICE(Transform, ctx.dst->Diagnostics()) << "Invalid entry point parameter";
+ }
+ }
+
+ if (!has_locations) {
+ // Nothing to do.
+ new_function_parameters.push_back(ctx.Clone(param));
+ return;
+ }
+
+ // Create a function-scope variable to replace the parameter.
+ auto* func_var = ctx.dst->Var(param_sym, ctx.Clone(param->type));
+ ctx.InsertFront(func->body->statements, ctx.dst->Decl(func_var));
+
+ if (!members_to_clone.empty()) {
+ // Create a new struct without the location attributes.
+ ast::StructMemberList new_members;
+ for (auto* member : members_to_clone) {
+ auto member_sym = ctx.Clone(member->symbol);
+ auto* member_type = ctx.Clone(member->type);
+ auto member_attrs = ctx.Clone(member->attributes);
+ new_members.push_back(
+ ctx.dst->Member(member_sym, member_type, std::move(member_attrs)));
+ }
+ auto* new_struct = ctx.dst->Structure(ctx.dst->Sym(), new_members);
+
+ // Create a new function parameter with this struct.
+ auto* new_param = ctx.dst->Param(ctx.dst->Sym(), ctx.dst->ty.Of(new_struct));
+ new_function_parameters.push_back(new_param);
+
+ // Copy values from the new parameter to the function-scope variable.
+ for (auto* member : members_to_clone) {
+ auto member_name = ctx.Clone(member->symbol);
+ ctx.InsertFront(func->body->statements,
+ ctx.dst->Assign(ctx.dst->MemberAccessor(func_var, member_name),
+ ctx.dst->MemberAccessor(new_param, member_name)));
+ }
+ }
+ }
+
+ /// Process an entry point function.
+ /// @param func the entry point function
+ void Process(const ast::Function* func) {
+ if (func->body->Empty()) {
+ return;
+ }
+
+ // Process entry point parameters.
+ for (auto* param : func->params) {
+ auto* sem = ctx.src->Sem().Get(param);
+ if (auto* str = sem->Type()->As<sem::Struct>()) {
+ ProcessStructParameter(func, param, str->Declaration());
+ } else {
+ ProcessNonStructParameter(func, param);
+ }
+ }
+
+ // Insert new parameters for vertex_index and instance_index if needed.
+ if (!vertex_index_expr) {
+ for (const VertexBufferLayoutDescriptor& layout : cfg.vertex_state) {
+ if (layout.step_mode == VertexStepMode::kVertex) {
+ auto name = ctx.dst->Symbols().New("tint_pulling_vertex_index");
+ new_function_parameters.push_back(ctx.dst->Param(
+ name, ctx.dst->ty.u32(), {ctx.dst->Builtin(ast::Builtin::kVertexIndex)}));
+ vertex_index_expr = [this, name]() { return ctx.dst->Expr(name); };
+ break;
+ }
+ }
+ }
+ if (!instance_index_expr) {
+ for (const VertexBufferLayoutDescriptor& layout : cfg.vertex_state) {
+ if (layout.step_mode == VertexStepMode::kInstance) {
+ auto name = ctx.dst->Symbols().New("tint_pulling_instance_index");
+ new_function_parameters.push_back(ctx.dst->Param(
+ name, ctx.dst->ty.u32(), {ctx.dst->Builtin(ast::Builtin::kInstanceIndex)}));
+ instance_index_expr = [this, name]() { return ctx.dst->Expr(name); };
+ break;
+ }
+ }
+ }
+
+ // Generate vertex pulling preamble.
+ if (auto* block = CreateVertexPullingPreamble()) {
+ ctx.InsertFront(func->body->statements, block);
+ }
+
+ // Rewrite the function header with the new parameters.
+ auto func_sym = ctx.Clone(func->symbol);
+ auto* ret_type = ctx.Clone(func->return_type);
+ auto* body = ctx.Clone(func->body);
+ auto attrs = ctx.Clone(func->attributes);
+ auto ret_attrs = ctx.Clone(func->return_type_attributes);
+ auto* new_func =
+ ctx.dst->create<ast::Function>(func->source, func_sym, new_function_parameters,
+ ret_type, body, std::move(attrs), std::move(ret_attrs));
+ ctx.Replace(func, new_func);
+ }
};
} // namespace
@@ -902,42 +869,38 @@
VertexPulling::VertexPulling() = default;
VertexPulling::~VertexPulling() = default;
-void VertexPulling::Run(CloneContext& ctx,
- const DataMap& inputs,
- DataMap&) const {
- auto cfg = cfg_;
- if (auto* cfg_data = inputs.Get<Config>()) {
- cfg = *cfg_data;
- }
+void VertexPulling::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const {
+ auto cfg = cfg_;
+ if (auto* cfg_data = inputs.Get<Config>()) {
+ cfg = *cfg_data;
+ }
- // Find entry point
- auto* func = ctx.src->AST().Functions().Find(
- ctx.src->Symbols().Get(cfg.entry_point_name),
- ast::PipelineStage::kVertex);
- if (func == nullptr) {
- ctx.dst->Diagnostics().add_error(diag::System::Transform,
- "Vertex stage entry point not found");
- return;
- }
+ // Find entry point
+ auto* func = ctx.src->AST().Functions().Find(ctx.src->Symbols().Get(cfg.entry_point_name),
+ ast::PipelineStage::kVertex);
+ if (func == nullptr) {
+ ctx.dst->Diagnostics().add_error(diag::System::Transform,
+ "Vertex stage entry point not found");
+ return;
+ }
- // TODO(idanr): Need to check shader locations in descriptor cover all
- // attributes
+ // TODO(idanr): Need to check shader locations in descriptor cover all
+ // attributes
- // TODO(idanr): Make sure we covered all error cases, to guarantee the
- // following stages will pass
+ // TODO(idanr): Make sure we covered all error cases, to guarantee the
+ // following stages will pass
- State state{ctx, cfg};
- state.AddVertexStorageBuffers();
- state.Process(func);
+ State state{ctx, cfg};
+ state.AddVertexStorageBuffers();
+ state.Process(func);
- ctx.Clone();
+ ctx.Clone();
}
VertexPulling::Config::Config() = default;
VertexPulling::Config::Config(const Config&) = default;
VertexPulling::Config::~Config() = default;
-VertexPulling::Config& VertexPulling::Config::operator=(const Config&) =
- default;
+VertexPulling::Config& VertexPulling::Config::operator=(const Config&) = default;
VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor() = default;
diff --git a/src/tint/transform/vertex_pulling.h b/src/tint/transform/vertex_pulling.h
index ec0769b..7875600 100644
--- a/src/tint/transform/vertex_pulling.h
+++ b/src/tint/transform/vertex_pulling.h
@@ -26,38 +26,38 @@
/// Describes the format of data in a vertex buffer
enum class VertexFormat {
- kUint8x2, // uint8x2
- kUint8x4, // uint8x4
- kSint8x2, // sint8x2
- kSint8x4, // sint8x4
- kUnorm8x2, // unorm8x2
- kUnorm8x4, // unorm8x4
- kSnorm8x2, // snorm8x2
- kSnorm8x4, // snorm8x4
- kUint16x2, // uint16x2
- kUint16x4, // uint16x4
- kSint16x2, // sint16x2
- kSint16x4, // sint16x4
- kUnorm16x2, // unorm16x2
- kUnorm16x4, // unorm16x4
- kSnorm16x2, // snorm16x2
- kSnorm16x4, // snorm16x4
- kFloat16x2, // float16x2
- kFloat16x4, // float16x4
- kFloat32, // float32
- kFloat32x2, // float32x2
- kFloat32x3, // float32x3
- kFloat32x4, // float32x4
- kUint32, // uint32
- kUint32x2, // uint32x2
- kUint32x3, // uint32x3
- kUint32x4, // uint32x4
- kSint32, // sint32
- kSint32x2, // sint32x2
- kSint32x3, // sint32x3
- kSint32x4, // sint32x4
+ kUint8x2, // uint8x2
+ kUint8x4, // uint8x4
+ kSint8x2, // sint8x2
+ kSint8x4, // sint8x4
+ kUnorm8x2, // unorm8x2
+ kUnorm8x4, // unorm8x4
+ kSnorm8x2, // snorm8x2
+ kSnorm8x4, // snorm8x4
+ kUint16x2, // uint16x2
+ kUint16x4, // uint16x4
+ kSint16x2, // sint16x2
+ kSint16x4, // sint16x4
+ kUnorm16x2, // unorm16x2
+ kUnorm16x4, // unorm16x4
+ kSnorm16x2, // snorm16x2
+ kSnorm16x4, // snorm16x4
+ kFloat16x2, // float16x2
+ kFloat16x4, // float16x4
+ kFloat32, // float32
+ kFloat32x2, // float32x2
+ kFloat32x3, // float32x3
+ kFloat32x4, // float32x4
+ kUint32, // uint32
+ kUint32x2, // uint32x2
+ kUint32x3, // uint32x3
+ kUint32x4, // uint32x4
+ kSint32, // sint32
+ kSint32x2, // sint32x2
+ kSint32x3, // sint32x3
+ kSint32x4, // sint32x4
- kLastEntry = kSint32x4,
+ kLastEntry = kSint32x4,
};
/// Describes if a vertex attributes increments with vertex index or instance
@@ -66,44 +66,42 @@
/// Describes a vertex attribute within a buffer
struct VertexAttributeDescriptor {
- /// The format of the attribute
- VertexFormat format;
- /// The byte offset of the attribute in the buffer
- uint32_t offset;
- /// The shader location used for the attribute
- uint32_t shader_location;
+ /// The format of the attribute
+ VertexFormat format;
+ /// The byte offset of the attribute in the buffer
+ uint32_t offset;
+ /// The shader location used for the attribute
+ uint32_t shader_location;
};
/// Describes a buffer containing multiple vertex attributes
struct VertexBufferLayoutDescriptor {
- /// Constructor
- VertexBufferLayoutDescriptor();
- /// Constructor
- /// @param in_array_stride the array stride of the in buffer
- /// @param in_step_mode the step mode of the in buffer
- /// @param in_attributes the in attributes
- VertexBufferLayoutDescriptor(
- uint32_t in_array_stride,
- VertexStepMode in_step_mode,
- std::vector<VertexAttributeDescriptor> in_attributes);
- /// Copy constructor
- /// @param other the struct to copy
- VertexBufferLayoutDescriptor(const VertexBufferLayoutDescriptor& other);
+ /// Constructor
+ VertexBufferLayoutDescriptor();
+ /// Constructor
+ /// @param in_array_stride the array stride of the in buffer
+ /// @param in_step_mode the step mode of the in buffer
+ /// @param in_attributes the in attributes
+ VertexBufferLayoutDescriptor(uint32_t in_array_stride,
+ VertexStepMode in_step_mode,
+ std::vector<VertexAttributeDescriptor> in_attributes);
+ /// Copy constructor
+ /// @param other the struct to copy
+ VertexBufferLayoutDescriptor(const VertexBufferLayoutDescriptor& other);
- /// Assignment operator
- /// @param other the struct to copy
- /// @returns this struct
- VertexBufferLayoutDescriptor& operator=(
- const VertexBufferLayoutDescriptor& other);
+ /// Assignment operator
+ /// @param other the struct to copy
+ /// @returns this struct
+ VertexBufferLayoutDescriptor& operator=(const VertexBufferLayoutDescriptor& other);
- ~VertexBufferLayoutDescriptor();
+ ~VertexBufferLayoutDescriptor();
- /// The array stride used in the in buffer
- uint32_t array_stride = 0u;
- /// The input step mode used
- VertexStepMode step_mode = VertexStepMode::kVertex;
- /// The vertex attributes
- std::vector<VertexAttributeDescriptor> attributes;
+ /// The array stride used in the in buffer
+ uint32_t array_stride = 0u;
+ /// The input step mode used
+ VertexStepMode step_mode = VertexStepMode::kVertex;
+ /// The vertex attributes
+ std::vector<VertexAttributeDescriptor> attributes;
};
/// Describes vertex state, which consists of many buffers containing vertex
@@ -131,52 +129,50 @@
/// these smaller types into the base types such as `f32` and `u32` for the
/// shader to use.
class VertexPulling : public Castable<VertexPulling, Transform> {
- public:
- /// Configuration options for the transform
- struct Config : public Castable<Config, Data> {
- /// Constructor
- Config();
+ public:
+ /// Configuration options for the transform
+ struct Config : public Castable<Config, Data> {
+ /// Constructor
+ Config();
- /// Copy constructor
- Config(const Config&);
+ /// Copy constructor
+ Config(const Config&);
+
+ /// Destructor
+ ~Config() override;
+
+ /// Assignment operator
+ /// @returns this Config
+ Config& operator=(const Config&);
+
+ /// The entry point to add assignments into
+ std::string entry_point_name;
+
+ /// The vertex state descriptor, containing info about attributes
+ VertexStateDescriptor vertex_state;
+
+ /// The "group" we will put all our vertex buffers into (as storage buffers)
+ /// Default to 4 as it is past the limits of user-accessible groups
+ uint32_t pulling_group = 4u;
+ };
+
+ /// Constructor
+ VertexPulling();
/// Destructor
- ~Config() override;
+ ~VertexPulling() override;
- /// Assignment operator
- /// @returns this Config
- Config& operator=(const Config&);
+ protected:
+ /// Runs the transform using the CloneContext built for transforming a
+ /// program. Run() is responsible for calling Clone() on the CloneContext.
+ /// @param ctx the CloneContext primed with the input program and
+ /// ProgramBuilder
+ /// @param inputs optional extra transform-specific input data
+ /// @param outputs optional extra transform-specific output data
+ void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
- /// The entry point to add assignments into
- std::string entry_point_name;
-
- /// The vertex state descriptor, containing info about attributes
- VertexStateDescriptor vertex_state;
-
- /// The "group" we will put all our vertex buffers into (as storage buffers)
- /// Default to 4 as it is past the limits of user-accessible groups
- uint32_t pulling_group = 4u;
- };
-
- /// Constructor
- VertexPulling();
-
- /// Destructor
- ~VertexPulling() override;
-
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx,
- const DataMap& inputs,
- DataMap& outputs) const override;
-
- private:
- Config cfg_;
+ private:
+ Config cfg_;
};
} // namespace tint::transform
diff --git a/src/tint/transform/vertex_pulling_test.cc b/src/tint/transform/vertex_pulling_test.cc
index 3c19aa6..82e28b3 100644
--- a/src/tint/transform/vertex_pulling_test.cc
+++ b/src/tint/transform/vertex_pulling_test.cc
@@ -24,88 +24,87 @@
using VertexPullingTest = TransformTest;
TEST_F(VertexPullingTest, Error_NoEntryPoint) {
- auto* src = "";
+ auto* src = "";
- auto* expect = "error: Vertex stage entry point not found";
+ auto* expect = "error: Vertex stage entry point not found";
- DataMap data;
- data.Add<VertexPulling::Config>();
- auto got = Run<VertexPulling>(src, data);
+ DataMap data;
+ data.Add<VertexPulling::Config>();
+ auto got = Run<VertexPulling>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(VertexPullingTest, Error_InvalidEntryPoint) {
- auto* src = R"(
+ auto* src = R"(
@stage(vertex)
fn main() -> @builtin(position) vec4<f32> {
return vec4<f32>();
}
)";
- auto* expect = "error: Vertex stage entry point not found";
+ auto* expect = "error: Vertex stage entry point not found";
- VertexPulling::Config cfg;
- cfg.entry_point_name = "_";
+ VertexPulling::Config cfg;
+ cfg.entry_point_name = "_";
- DataMap data;
- data.Add<VertexPulling::Config>(cfg);
- auto got = Run<VertexPulling>(src, data);
+ DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(VertexPullingTest, Error_EntryPointWrongStage) {
- auto* src = R"(
+ auto* src = R"(
@stage(fragment)
fn main() {}
)";
- auto* expect = "error: Vertex stage entry point not found";
+ auto* expect = "error: Vertex stage entry point not found";
- VertexPulling::Config cfg;
- cfg.entry_point_name = "main";
+ VertexPulling::Config cfg;
+ cfg.entry_point_name = "main";
- DataMap data;
- data.Add<VertexPulling::Config>(cfg);
- auto got = Run<VertexPulling>(src, data);
+ DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(VertexPullingTest, Error_BadStride) {
- auto* src = R"(
+ auto* src = R"(
@stage(vertex)
fn main(@location(0) var_a : f32) -> @builtin(position) vec4<f32> {
return vec4<f32>(var_a, 0.0, 0.0, 1.0);
}
)";
- auto* expect =
- "error: WebGPU requires that vertex stride must be a multiple of 4 "
- "bytes, but VertexPulling array stride for buffer 0 was 15 bytes";
+ auto* expect =
+ "error: WebGPU requires that vertex stride must be a multiple of 4 "
+ "bytes, but VertexPulling array stride for buffer 0 was 15 bytes";
- VertexPulling::Config cfg;
- cfg.vertex_state = {
- {{15, VertexStepMode::kVertex, {{VertexFormat::kFloat32, 0, 0}}}}};
- cfg.entry_point_name = "main";
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{{15, VertexStepMode::kVertex, {{VertexFormat::kFloat32, 0, 0}}}}};
+ cfg.entry_point_name = "main";
- DataMap data;
- data.Add<VertexPulling::Config>(cfg);
- auto got = Run<VertexPulling>(src, data);
+ DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(VertexPullingTest, BasicModule) {
- auto* src = R"(
+ auto* src = R"(
@stage(vertex)
fn main() -> @builtin(position) vec4<f32> {
return vec4<f32>();
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct TintVertexData {
tint_vertex_data : array<u32>,
}
@@ -116,25 +115,25 @@
}
)";
- VertexPulling::Config cfg;
- cfg.entry_point_name = "main";
+ VertexPulling::Config cfg;
+ cfg.entry_point_name = "main";
- DataMap data;
- data.Add<VertexPulling::Config>(cfg);
- auto got = Run<VertexPulling>(src, data);
+ DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(VertexPullingTest, OneAttribute) {
- auto* src = R"(
+ auto* src = R"(
@stage(vertex)
fn main(@location(0) var_a : f32) -> @builtin(position) vec4<f32> {
return vec4<f32>(var_a, 0.0, 0.0, 1.0);
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct TintVertexData {
tint_vertex_data : array<u32>,
}
@@ -152,27 +151,26 @@
}
)";
- VertexPulling::Config cfg;
- cfg.vertex_state = {
- {{4, VertexStepMode::kVertex, {{VertexFormat::kFloat32, 0, 0}}}}};
- cfg.entry_point_name = "main";
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{{4, VertexStepMode::kVertex, {{VertexFormat::kFloat32, 0, 0}}}}};
+ cfg.entry_point_name = "main";
- DataMap data;
- data.Add<VertexPulling::Config>(cfg);
- auto got = Run<VertexPulling>(src, data);
+ DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(VertexPullingTest, OneInstancedAttribute) {
- auto* src = R"(
+ auto* src = R"(
@stage(vertex)
fn main(@location(0) var_a : f32) -> @builtin(position) vec4<f32> {
return vec4<f32>(var_a, 0.0, 0.0, 1.0);
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct TintVertexData {
tint_vertex_data : array<u32>,
}
@@ -190,27 +188,26 @@
}
)";
- VertexPulling::Config cfg;
- cfg.vertex_state = {
- {{4, VertexStepMode::kInstance, {{VertexFormat::kFloat32, 0, 0}}}}};
- cfg.entry_point_name = "main";
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{{4, VertexStepMode::kInstance, {{VertexFormat::kFloat32, 0, 0}}}}};
+ cfg.entry_point_name = "main";
- DataMap data;
- data.Add<VertexPulling::Config>(cfg);
- auto got = Run<VertexPulling>(src, data);
+ DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(VertexPullingTest, OneAttributeDifferentOutputSet) {
- auto* src = R"(
+ auto* src = R"(
@stage(vertex)
fn main(@location(0) var_a : f32) -> @builtin(position) vec4<f32> {
return vec4<f32>(var_a, 0.0, 0.0, 1.0);
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct TintVertexData {
tint_vertex_data : array<u32>,
}
@@ -228,21 +225,20 @@
}
)";
- VertexPulling::Config cfg;
- cfg.vertex_state = {
- {{4, VertexStepMode::kVertex, {{VertexFormat::kFloat32, 0, 0}}}}};
- cfg.pulling_group = 5;
- cfg.entry_point_name = "main";
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{{4, VertexStepMode::kVertex, {{VertexFormat::kFloat32, 0, 0}}}}};
+ cfg.pulling_group = 5;
+ cfg.entry_point_name = "main";
- DataMap data;
- data.Add<VertexPulling::Config>(cfg);
- auto got = Run<VertexPulling>(src, data);
+ DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(VertexPullingTest, OneAttribute_Struct) {
- auto* src = R"(
+ auto* src = R"(
struct Inputs {
@location(0) var_a : f32,
};
@@ -253,7 +249,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct TintVertexData {
tint_vertex_data : array<u32>,
}
@@ -276,21 +272,20 @@
}
)";
- VertexPulling::Config cfg;
- cfg.vertex_state = {
- {{4, VertexStepMode::kVertex, {{VertexFormat::kFloat32, 0, 0}}}}};
- cfg.entry_point_name = "main";
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{{4, VertexStepMode::kVertex, {{VertexFormat::kFloat32, 0, 0}}}}};
+ cfg.entry_point_name = "main";
- DataMap data;
- data.Add<VertexPulling::Config>(cfg);
- auto got = Run<VertexPulling>(src, data);
+ DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
// We expect the transform to use an existing builtin variables if it finds them
TEST_F(VertexPullingTest, ExistingVertexIndexAndInstanceIndex) {
- auto* src = R"(
+ auto* src = R"(
@stage(vertex)
fn main(@location(0) var_a : f32,
@location(1) var_b : f32,
@@ -301,7 +296,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct TintVertexData {
tint_vertex_data : array<u32>,
}
@@ -324,30 +319,30 @@
}
)";
- VertexPulling::Config cfg;
- cfg.vertex_state = {{
- {
- 4,
- VertexStepMode::kVertex,
- {{VertexFormat::kFloat32, 0, 0}},
- },
- {
- 4,
- VertexStepMode::kInstance,
- {{VertexFormat::kFloat32, 0, 1}},
- },
- }};
- cfg.entry_point_name = "main";
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{
+ {
+ 4,
+ VertexStepMode::kVertex,
+ {{VertexFormat::kFloat32, 0, 0}},
+ },
+ {
+ 4,
+ VertexStepMode::kInstance,
+ {{VertexFormat::kFloat32, 0, 1}},
+ },
+ }};
+ cfg.entry_point_name = "main";
- DataMap data;
- data.Add<VertexPulling::Config>(cfg);
- auto got = Run<VertexPulling>(src, data);
+ DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(VertexPullingTest, ExistingVertexIndexAndInstanceIndex_Struct) {
- auto* src = R"(
+ auto* src = R"(
struct Inputs {
@location(0) var_a : f32,
@location(1) var_b : f32,
@@ -361,7 +356,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct TintVertexData {
tint_vertex_data : array<u32>,
}
@@ -403,31 +398,30 @@
}
)";
- VertexPulling::Config cfg;
- cfg.vertex_state = {{
- {
- 4,
- VertexStepMode::kVertex,
- {{VertexFormat::kFloat32, 0, 0}},
- },
- {
- 4,
- VertexStepMode::kInstance,
- {{VertexFormat::kFloat32, 0, 1}},
- },
- }};
- cfg.entry_point_name = "main";
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{
+ {
+ 4,
+ VertexStepMode::kVertex,
+ {{VertexFormat::kFloat32, 0, 0}},
+ },
+ {
+ 4,
+ VertexStepMode::kInstance,
+ {{VertexFormat::kFloat32, 0, 1}},
+ },
+ }};
+ cfg.entry_point_name = "main";
- DataMap data;
- data.Add<VertexPulling::Config>(cfg);
- auto got = Run<VertexPulling>(src, data);
+ DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
-TEST_F(VertexPullingTest,
- ExistingVertexIndexAndInstanceIndex_Struct_OutOfOrder) {
- auto* src = R"(
+TEST_F(VertexPullingTest, ExistingVertexIndexAndInstanceIndex_Struct_OutOfOrder) {
+ auto* src = R"(
@stage(vertex)
fn main(inputs : Inputs) -> @builtin(position) vec4<f32> {
return vec4<f32>(inputs.var_a, inputs.var_b, 0.0, 1.0);
@@ -441,7 +435,7 @@
};
)";
- auto* expect = R"(
+ auto* expect = R"(
struct TintVertexData {
tint_vertex_data : array<u32>,
}
@@ -483,30 +477,30 @@
}
)";
- VertexPulling::Config cfg;
- cfg.vertex_state = {{
- {
- 4,
- VertexStepMode::kVertex,
- {{VertexFormat::kFloat32, 0, 0}},
- },
- {
- 4,
- VertexStepMode::kInstance,
- {{VertexFormat::kFloat32, 0, 1}},
- },
- }};
- cfg.entry_point_name = "main";
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{
+ {
+ 4,
+ VertexStepMode::kVertex,
+ {{VertexFormat::kFloat32, 0, 0}},
+ },
+ {
+ 4,
+ VertexStepMode::kInstance,
+ {{VertexFormat::kFloat32, 0, 1}},
+ },
+ }};
+ cfg.entry_point_name = "main";
- DataMap data;
- data.Add<VertexPulling::Config>(cfg);
- auto got = Run<VertexPulling>(src, data);
+ DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(VertexPullingTest, ExistingVertexIndexAndInstanceIndex_SeparateStruct) {
- auto* src = R"(
+ auto* src = R"(
struct Inputs {
@location(0) var_a : f32,
@location(1) var_b : f32,
@@ -523,7 +517,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct TintVertexData {
tint_vertex_data : array<u32>,
}
@@ -559,31 +553,30 @@
}
)";
- VertexPulling::Config cfg;
- cfg.vertex_state = {{
- {
- 4,
- VertexStepMode::kVertex,
- {{VertexFormat::kFloat32, 0, 0}},
- },
- {
- 4,
- VertexStepMode::kInstance,
- {{VertexFormat::kFloat32, 0, 1}},
- },
- }};
- cfg.entry_point_name = "main";
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{
+ {
+ 4,
+ VertexStepMode::kVertex,
+ {{VertexFormat::kFloat32, 0, 0}},
+ },
+ {
+ 4,
+ VertexStepMode::kInstance,
+ {{VertexFormat::kFloat32, 0, 1}},
+ },
+ }};
+ cfg.entry_point_name = "main";
- DataMap data;
- data.Add<VertexPulling::Config>(cfg);
- auto got = Run<VertexPulling>(src, data);
+ DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
-TEST_F(VertexPullingTest,
- ExistingVertexIndexAndInstanceIndex_SeparateStruct_OutOfOrder) {
- auto* src = R"(
+TEST_F(VertexPullingTest, ExistingVertexIndexAndInstanceIndex_SeparateStruct_OutOfOrder) {
+ auto* src = R"(
@stage(vertex)
fn main(inputs : Inputs, indices : Indices) -> @builtin(position) vec4<f32> {
return vec4<f32>(inputs.var_a, inputs.var_b, 0.0, 1.0);
@@ -600,7 +593,7 @@
};
)";
- auto* expect = R"(
+ auto* expect = R"(
struct TintVertexData {
tint_vertex_data : array<u32>,
}
@@ -636,30 +629,30 @@
}
)";
- VertexPulling::Config cfg;
- cfg.vertex_state = {{
- {
- 4,
- VertexStepMode::kVertex,
- {{VertexFormat::kFloat32, 0, 0}},
- },
- {
- 4,
- VertexStepMode::kInstance,
- {{VertexFormat::kFloat32, 0, 1}},
- },
- }};
- cfg.entry_point_name = "main";
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{
+ {
+ 4,
+ VertexStepMode::kVertex,
+ {{VertexFormat::kFloat32, 0, 0}},
+ },
+ {
+ 4,
+ VertexStepMode::kInstance,
+ {{VertexFormat::kFloat32, 0, 1}},
+ },
+ }};
+ cfg.entry_point_name = "main";
- DataMap data;
- data.Add<VertexPulling::Config>(cfg);
- auto got = Run<VertexPulling>(src, data);
+ DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(VertexPullingTest, TwoAttributesSameBuffer) {
- auto* src = R"(
+ auto* src = R"(
@stage(vertex)
fn main(@location(0) var_a : f32,
@location(1) var_b : vec4<f32>) -> @builtin(position) vec4<f32> {
@@ -667,7 +660,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct TintVertexData {
tint_vertex_data : array<u32>,
}
@@ -687,22 +680,21 @@
}
)";
- VertexPulling::Config cfg;
- cfg.vertex_state = {
- {{16,
- VertexStepMode::kVertex,
- {{VertexFormat::kFloat32, 0, 0}, {VertexFormat::kFloat32x4, 0, 1}}}}};
- cfg.entry_point_name = "main";
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{{16,
+ VertexStepMode::kVertex,
+ {{VertexFormat::kFloat32, 0, 0}, {VertexFormat::kFloat32x4, 0, 1}}}}};
+ cfg.entry_point_name = "main";
- DataMap data;
- data.Add<VertexPulling::Config>(cfg);
- auto got = Run<VertexPulling>(src, data);
+ DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(VertexPullingTest, FloatVectorAttributes) {
- auto* src = R"(
+ auto* src = R"(
@stage(vertex)
fn main(@location(0) var_a : vec2<f32>,
@location(1) var_b : vec3<f32>,
@@ -712,7 +704,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct TintVertexData {
tint_vertex_data : array<u32>,
}
@@ -740,23 +732,23 @@
}
)";
- VertexPulling::Config cfg;
- cfg.vertex_state = {{
- {8, VertexStepMode::kVertex, {{VertexFormat::kFloat32x2, 0, 0}}},
- {12, VertexStepMode::kVertex, {{VertexFormat::kFloat32x3, 0, 1}}},
- {16, VertexStepMode::kVertex, {{VertexFormat::kFloat32x4, 0, 2}}},
- }};
- cfg.entry_point_name = "main";
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{
+ {8, VertexStepMode::kVertex, {{VertexFormat::kFloat32x2, 0, 0}}},
+ {12, VertexStepMode::kVertex, {{VertexFormat::kFloat32x3, 0, 1}}},
+ {16, VertexStepMode::kVertex, {{VertexFormat::kFloat32x4, 0, 2}}},
+ }};
+ cfg.entry_point_name = "main";
- DataMap data;
- data.Add<VertexPulling::Config>(cfg);
- auto got = Run<VertexPulling>(src, data);
+ DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(VertexPullingTest, AttemptSymbolCollision) {
- auto* src = R"(
+ auto* src = R"(
@stage(vertex)
fn main(@location(0) var_a : f32,
@location(1) var_b : vec4<f32>) -> @builtin(position) vec4<f32> {
@@ -768,7 +760,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct TintVertexData {
tint_vertex_data_1 : array<u32>,
}
@@ -792,22 +784,21 @@
}
)";
- VertexPulling::Config cfg;
- cfg.vertex_state = {
- {{16,
- VertexStepMode::kVertex,
- {{VertexFormat::kFloat32, 0, 0}, {VertexFormat::kFloat32x4, 0, 1}}}}};
- cfg.entry_point_name = "main";
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{{16,
+ VertexStepMode::kVertex,
+ {{VertexFormat::kFloat32, 0, 0}, {VertexFormat::kFloat32x4, 0, 1}}}}};
+ cfg.entry_point_name = "main";
- DataMap data;
- data.Add<VertexPulling::Config>(cfg);
- auto got = Run<VertexPulling>(src, std::move(data));
+ DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, std::move(data));
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(VertexPullingTest, FormatsAligned) {
- auto* src = R"(
+ auto* src = R"(
@stage(vertex)
fn main(
@location(0) uint8x2 : vec2<u32>,
@@ -845,7 +836,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct TintVertexData {
tint_vertex_data : array<u32>,
}
@@ -921,52 +912,38 @@
}
)";
- VertexPulling::Config cfg;
- cfg.vertex_state = {{{256,
- VertexStepMode::kVertex,
- {
- {VertexFormat::kUint8x2, 64, 0},
- {VertexFormat::kUint8x4, 64, 1},
- {VertexFormat::kSint8x2, 64, 2},
- {VertexFormat::kSint8x4, 64, 3},
- {VertexFormat::kUnorm8x2, 64, 4},
- {VertexFormat::kUnorm8x4, 64, 5},
- {VertexFormat::kSnorm8x2, 64, 6},
- {VertexFormat::kSnorm8x4, 64, 7},
- {VertexFormat::kUint16x2, 64, 8},
- {VertexFormat::kUint16x4, 64, 9},
- {VertexFormat::kSint16x2, 64, 10},
- {VertexFormat::kSint16x4, 64, 11},
- {VertexFormat::kUnorm16x2, 64, 12},
- {VertexFormat::kUnorm16x4, 64, 13},
- {VertexFormat::kSnorm16x2, 64, 14},
- {VertexFormat::kSnorm16x4, 64, 15},
- {VertexFormat::kFloat16x2, 64, 16},
- {VertexFormat::kFloat16x4, 64, 17},
- {VertexFormat::kFloat32, 64, 18},
- {VertexFormat::kFloat32x2, 64, 19},
- {VertexFormat::kFloat32x3, 64, 20},
- {VertexFormat::kFloat32x4, 64, 21},
- {VertexFormat::kUint32, 64, 22},
- {VertexFormat::kUint32x2, 64, 23},
- {VertexFormat::kUint32x3, 64, 24},
- {VertexFormat::kUint32x4, 64, 25},
- {VertexFormat::kSint32, 64, 26},
- {VertexFormat::kSint32x2, 64, 27},
- {VertexFormat::kSint32x3, 64, 28},
- {VertexFormat::kSint32x4, 64, 29},
- }}}};
- cfg.entry_point_name = "main";
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {
+ {{256,
+ VertexStepMode::kVertex,
+ {
+ {VertexFormat::kUint8x2, 64, 0}, {VertexFormat::kUint8x4, 64, 1},
+ {VertexFormat::kSint8x2, 64, 2}, {VertexFormat::kSint8x4, 64, 3},
+ {VertexFormat::kUnorm8x2, 64, 4}, {VertexFormat::kUnorm8x4, 64, 5},
+ {VertexFormat::kSnorm8x2, 64, 6}, {VertexFormat::kSnorm8x4, 64, 7},
+ {VertexFormat::kUint16x2, 64, 8}, {VertexFormat::kUint16x4, 64, 9},
+ {VertexFormat::kSint16x2, 64, 10}, {VertexFormat::kSint16x4, 64, 11},
+ {VertexFormat::kUnorm16x2, 64, 12}, {VertexFormat::kUnorm16x4, 64, 13},
+ {VertexFormat::kSnorm16x2, 64, 14}, {VertexFormat::kSnorm16x4, 64, 15},
+ {VertexFormat::kFloat16x2, 64, 16}, {VertexFormat::kFloat16x4, 64, 17},
+ {VertexFormat::kFloat32, 64, 18}, {VertexFormat::kFloat32x2, 64, 19},
+ {VertexFormat::kFloat32x3, 64, 20}, {VertexFormat::kFloat32x4, 64, 21},
+ {VertexFormat::kUint32, 64, 22}, {VertexFormat::kUint32x2, 64, 23},
+ {VertexFormat::kUint32x3, 64, 24}, {VertexFormat::kUint32x4, 64, 25},
+ {VertexFormat::kSint32, 64, 26}, {VertexFormat::kSint32x2, 64, 27},
+ {VertexFormat::kSint32x3, 64, 28}, {VertexFormat::kSint32x4, 64, 29},
+ }}}};
+ cfg.entry_point_name = "main";
- DataMap data;
- data.Add<VertexPulling::Config>(cfg);
- auto got = Run<VertexPulling>(src, data);
+ DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(VertexPullingTest, FormatsStrideUnaligned) {
- auto* src = R"(
+ auto* src = R"(
@stage(vertex)
fn main(
@location(0) uint8x2 : vec2<u32>,
@@ -1004,8 +981,8 @@
}
)";
- auto* expect =
- R"(
+ auto* expect =
+ R"(
struct TintVertexData {
tint_vertex_data : array<u32>,
}
@@ -1081,52 +1058,38 @@
}
)";
- VertexPulling::Config cfg;
- cfg.vertex_state = {{{256,
- VertexStepMode::kVertex,
- {
- {VertexFormat::kUint8x2, 63, 0},
- {VertexFormat::kUint8x4, 63, 1},
- {VertexFormat::kSint8x2, 63, 2},
- {VertexFormat::kSint8x4, 63, 3},
- {VertexFormat::kUnorm8x2, 63, 4},
- {VertexFormat::kUnorm8x4, 63, 5},
- {VertexFormat::kSnorm8x2, 63, 6},
- {VertexFormat::kSnorm8x4, 63, 7},
- {VertexFormat::kUint16x2, 63, 8},
- {VertexFormat::kUint16x4, 63, 9},
- {VertexFormat::kSint16x2, 63, 10},
- {VertexFormat::kSint16x4, 63, 11},
- {VertexFormat::kUnorm16x2, 63, 12},
- {VertexFormat::kUnorm16x4, 63, 13},
- {VertexFormat::kSnorm16x2, 63, 14},
- {VertexFormat::kSnorm16x4, 63, 15},
- {VertexFormat::kFloat16x2, 63, 16},
- {VertexFormat::kFloat16x4, 63, 17},
- {VertexFormat::kFloat32, 63, 18},
- {VertexFormat::kFloat32x2, 63, 19},
- {VertexFormat::kFloat32x3, 63, 20},
- {VertexFormat::kFloat32x4, 63, 21},
- {VertexFormat::kUint32, 63, 22},
- {VertexFormat::kUint32x2, 63, 23},
- {VertexFormat::kUint32x3, 63, 24},
- {VertexFormat::kUint32x4, 63, 25},
- {VertexFormat::kSint32, 63, 26},
- {VertexFormat::kSint32x2, 63, 27},
- {VertexFormat::kSint32x3, 63, 28},
- {VertexFormat::kSint32x4, 63, 29},
- }}}};
- cfg.entry_point_name = "main";
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {
+ {{256,
+ VertexStepMode::kVertex,
+ {
+ {VertexFormat::kUint8x2, 63, 0}, {VertexFormat::kUint8x4, 63, 1},
+ {VertexFormat::kSint8x2, 63, 2}, {VertexFormat::kSint8x4, 63, 3},
+ {VertexFormat::kUnorm8x2, 63, 4}, {VertexFormat::kUnorm8x4, 63, 5},
+ {VertexFormat::kSnorm8x2, 63, 6}, {VertexFormat::kSnorm8x4, 63, 7},
+ {VertexFormat::kUint16x2, 63, 8}, {VertexFormat::kUint16x4, 63, 9},
+ {VertexFormat::kSint16x2, 63, 10}, {VertexFormat::kSint16x4, 63, 11},
+ {VertexFormat::kUnorm16x2, 63, 12}, {VertexFormat::kUnorm16x4, 63, 13},
+ {VertexFormat::kSnorm16x2, 63, 14}, {VertexFormat::kSnorm16x4, 63, 15},
+ {VertexFormat::kFloat16x2, 63, 16}, {VertexFormat::kFloat16x4, 63, 17},
+ {VertexFormat::kFloat32, 63, 18}, {VertexFormat::kFloat32x2, 63, 19},
+ {VertexFormat::kFloat32x3, 63, 20}, {VertexFormat::kFloat32x4, 63, 21},
+ {VertexFormat::kUint32, 63, 22}, {VertexFormat::kUint32x2, 63, 23},
+ {VertexFormat::kUint32x3, 63, 24}, {VertexFormat::kUint32x4, 63, 25},
+ {VertexFormat::kSint32, 63, 26}, {VertexFormat::kSint32x2, 63, 27},
+ {VertexFormat::kSint32x3, 63, 28}, {VertexFormat::kSint32x4, 63, 29},
+ }}}};
+ cfg.entry_point_name = "main";
- DataMap data;
- data.Add<VertexPulling::Config>(cfg);
- auto got = Run<VertexPulling>(src, data);
+ DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(VertexPullingTest, FormatsWithVectorsResized) {
- auto* src = R"(
+ auto* src = R"(
@stage(vertex)
fn main(
@location(0) uint8x2 : vec3<u32>,
@@ -1164,7 +1127,7 @@
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct TintVertexData {
tint_vertex_data : array<u32>,
}
@@ -1240,48 +1203,34 @@
}
)";
- VertexPulling::Config cfg;
- cfg.vertex_state = {{{256,
- VertexStepMode::kVertex,
- {
- {VertexFormat::kUint8x2, 64, 0},
- {VertexFormat::kUint8x4, 64, 1},
- {VertexFormat::kSint8x2, 64, 2},
- {VertexFormat::kSint8x4, 64, 3},
- {VertexFormat::kUnorm8x2, 64, 4},
- {VertexFormat::kUnorm8x4, 64, 5},
- {VertexFormat::kSnorm8x2, 64, 6},
- {VertexFormat::kSnorm8x4, 64, 7},
- {VertexFormat::kUint16x2, 64, 8},
- {VertexFormat::kUint16x4, 64, 9},
- {VertexFormat::kSint16x2, 64, 10},
- {VertexFormat::kSint16x4, 64, 11},
- {VertexFormat::kUnorm16x2, 64, 12},
- {VertexFormat::kUnorm16x4, 64, 13},
- {VertexFormat::kSnorm16x2, 64, 14},
- {VertexFormat::kSnorm16x4, 64, 15},
- {VertexFormat::kFloat16x2, 64, 16},
- {VertexFormat::kFloat16x4, 64, 17},
- {VertexFormat::kFloat32, 64, 18},
- {VertexFormat::kFloat32x2, 64, 19},
- {VertexFormat::kFloat32x3, 64, 20},
- {VertexFormat::kFloat32x4, 64, 21},
- {VertexFormat::kUint32, 64, 22},
- {VertexFormat::kUint32x2, 64, 23},
- {VertexFormat::kUint32x3, 64, 24},
- {VertexFormat::kUint32x4, 64, 25},
- {VertexFormat::kSint32, 64, 26},
- {VertexFormat::kSint32x2, 64, 27},
- {VertexFormat::kSint32x3, 64, 28},
- {VertexFormat::kSint32x4, 64, 29},
- }}}};
- cfg.entry_point_name = "main";
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {
+ {{256,
+ VertexStepMode::kVertex,
+ {
+ {VertexFormat::kUint8x2, 64, 0}, {VertexFormat::kUint8x4, 64, 1},
+ {VertexFormat::kSint8x2, 64, 2}, {VertexFormat::kSint8x4, 64, 3},
+ {VertexFormat::kUnorm8x2, 64, 4}, {VertexFormat::kUnorm8x4, 64, 5},
+ {VertexFormat::kSnorm8x2, 64, 6}, {VertexFormat::kSnorm8x4, 64, 7},
+ {VertexFormat::kUint16x2, 64, 8}, {VertexFormat::kUint16x4, 64, 9},
+ {VertexFormat::kSint16x2, 64, 10}, {VertexFormat::kSint16x4, 64, 11},
+ {VertexFormat::kUnorm16x2, 64, 12}, {VertexFormat::kUnorm16x4, 64, 13},
+ {VertexFormat::kSnorm16x2, 64, 14}, {VertexFormat::kSnorm16x4, 64, 15},
+ {VertexFormat::kFloat16x2, 64, 16}, {VertexFormat::kFloat16x4, 64, 17},
+ {VertexFormat::kFloat32, 64, 18}, {VertexFormat::kFloat32x2, 64, 19},
+ {VertexFormat::kFloat32x3, 64, 20}, {VertexFormat::kFloat32x4, 64, 21},
+ {VertexFormat::kUint32, 64, 22}, {VertexFormat::kUint32x2, 64, 23},
+ {VertexFormat::kUint32x3, 64, 24}, {VertexFormat::kUint32x4, 64, 25},
+ {VertexFormat::kSint32, 64, 26}, {VertexFormat::kSint32x2, 64, 27},
+ {VertexFormat::kSint32x3, 64, 28}, {VertexFormat::kSint32x4, 64, 29},
+ }}}};
+ cfg.entry_point_name = "main";
- DataMap data;
- data.Add<VertexPulling::Config>(cfg);
- auto got = Run<VertexPulling>(src, data);
+ DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
} // namespace
diff --git a/src/tint/transform/wrap_arrays_in_structs.cc b/src/tint/transform/wrap_arrays_in_structs.cc
index 7cf3fcb..b47fa63 100644
--- a/src/tint/transform/wrap_arrays_in_structs.cc
+++ b/src/tint/transform/wrap_arrays_in_structs.cc
@@ -29,141 +29,130 @@
namespace tint::transform {
WrapArraysInStructs::WrappedArrayInfo::WrappedArrayInfo() = default;
-WrapArraysInStructs::WrappedArrayInfo::WrappedArrayInfo(
- const WrappedArrayInfo&) = default;
+WrapArraysInStructs::WrappedArrayInfo::WrappedArrayInfo(const WrappedArrayInfo&) = default;
WrapArraysInStructs::WrappedArrayInfo::~WrappedArrayInfo() = default;
WrapArraysInStructs::WrapArraysInStructs() = default;
WrapArraysInStructs::~WrapArraysInStructs() = default;
-bool WrapArraysInStructs::ShouldRun(const Program* program,
- const DataMap&) const {
- for (auto* node : program->ASTNodes().Objects()) {
- if (program->Sem().Get<sem::Array>(node->As<ast::Type>())) {
- return true;
+bool WrapArraysInStructs::ShouldRun(const Program* program, const DataMap&) const {
+ for (auto* node : program->ASTNodes().Objects()) {
+ if (program->Sem().Get<sem::Array>(node->As<ast::Type>())) {
+ return true;
+ }
}
- }
- return false;
+ return false;
}
-void WrapArraysInStructs::Run(CloneContext& ctx,
- const DataMap&,
- DataMap&) const {
- auto& sem = ctx.src->Sem();
+void WrapArraysInStructs::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
+ auto& sem = ctx.src->Sem();
- std::unordered_map<const sem::Array*, WrappedArrayInfo> wrapped_arrays;
- auto wrapper = [&](const sem::Array* array) {
- return WrapArray(ctx, wrapped_arrays, array);
- };
- auto wrapper_typename = [&](const sem::Array* arr) -> ast::TypeName* {
- auto info = wrapper(arr);
- return info ? ctx.dst->create<ast::TypeName>(info.wrapper_name) : nullptr;
- };
+ std::unordered_map<const sem::Array*, WrappedArrayInfo> wrapped_arrays;
+ auto wrapper = [&](const sem::Array* array) { return WrapArray(ctx, wrapped_arrays, array); };
+ auto wrapper_typename = [&](const sem::Array* arr) -> ast::TypeName* {
+ auto info = wrapper(arr);
+ return info ? ctx.dst->create<ast::TypeName>(info.wrapper_name) : nullptr;
+ };
- // Replace all array types with their corresponding wrapper
- ctx.ReplaceAll([&](const ast::Type* ast_type) -> const ast::Type* {
- auto* type = ctx.src->TypeOf(ast_type);
- if (auto* array = type->UnwrapRef()->As<sem::Array>()) {
- return wrapper_typename(array);
- }
- return nullptr;
- });
-
- // Fix up index accessors so `a[1]` becomes `a.arr[1]`
- ctx.ReplaceAll([&](const ast::IndexAccessorExpression* accessor)
- -> const ast::IndexAccessorExpression* {
- if (auto* array = ::tint::As<sem::Array>(
- sem.Get(accessor->object)->Type()->UnwrapRef())) {
- if (wrapper(array)) {
- // Array is wrapped in a structure. Emit a member accessor to get
- // to the actual array.
- auto* arr = ctx.Clone(accessor->object);
- auto* idx = ctx.Clone(accessor->index);
- auto* unwrapped = ctx.dst->MemberAccessor(arr, "arr");
- return ctx.dst->IndexAccessor(accessor->source, unwrapped, idx);
- }
- }
- return nullptr;
- });
-
- // Fix up array constructors so `A(1,2)` becomes `tint_array_wrapper(A(1,2))`
- ctx.ReplaceAll(
- [&](const ast::CallExpression* expr) -> const ast::Expression* {
- if (auto* call = sem.Get(expr)) {
- if (auto* ctor = call->Target()->As<sem::TypeConstructor>()) {
- if (auto* array = ctor->ReturnType()->As<sem::Array>()) {
- if (auto w = wrapper(array)) {
- // Wrap the array type constructor with another constructor for
- // the wrapper
- auto* wrapped_array_ty = ctx.dst->ty.type_name(w.wrapper_name);
- auto* array_ty = w.array_type(ctx);
- auto args = utils::Transform(
- call->Arguments(), [&](const tint::sem::Expression* s) {
- return ctx.Clone(s->Declaration());
- });
- auto* arr_ctor = ctx.dst->Construct(array_ty, args);
- return ctx.dst->Construct(wrapped_array_ty, arr_ctor);
- }
- }
- }
+ // Replace all array types with their corresponding wrapper
+ ctx.ReplaceAll([&](const ast::Type* ast_type) -> const ast::Type* {
+ auto* type = ctx.src->TypeOf(ast_type);
+ if (auto* array = type->UnwrapRef()->As<sem::Array>()) {
+ return wrapper_typename(array);
}
return nullptr;
- });
+ });
- ctx.Clone();
+ // Fix up index accessors so `a[1]` becomes `a.arr[1]`
+ ctx.ReplaceAll(
+ [&](const ast::IndexAccessorExpression* accessor) -> const ast::IndexAccessorExpression* {
+ if (auto* array =
+ ::tint::As<sem::Array>(sem.Get(accessor->object)->Type()->UnwrapRef())) {
+ if (wrapper(array)) {
+ // Array is wrapped in a structure. Emit a member accessor to get
+ // to the actual array.
+ auto* arr = ctx.Clone(accessor->object);
+ auto* idx = ctx.Clone(accessor->index);
+ auto* unwrapped = ctx.dst->MemberAccessor(arr, "arr");
+ return ctx.dst->IndexAccessor(accessor->source, unwrapped, idx);
+ }
+ }
+ return nullptr;
+ });
+
+ // Fix up array constructors so `A(1,2)` becomes `tint_array_wrapper(A(1,2))`
+ ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::Expression* {
+ if (auto* call = sem.Get(expr)) {
+ if (auto* ctor = call->Target()->As<sem::TypeConstructor>()) {
+ if (auto* array = ctor->ReturnType()->As<sem::Array>()) {
+ if (auto w = wrapper(array)) {
+ // Wrap the array type constructor with another constructor for
+ // the wrapper
+ auto* wrapped_array_ty = ctx.dst->ty.type_name(w.wrapper_name);
+ auto* array_ty = w.array_type(ctx);
+ auto args = utils::Transform(call->Arguments(),
+ [&](const tint::sem::Expression* s) {
+ return ctx.Clone(s->Declaration());
+ });
+ auto* arr_ctor = ctx.dst->Construct(array_ty, args);
+ return ctx.dst->Construct(wrapped_array_ty, arr_ctor);
+ }
+ }
+ }
+ }
+ return nullptr;
+ });
+
+ ctx.Clone();
}
WrapArraysInStructs::WrappedArrayInfo WrapArraysInStructs::WrapArray(
CloneContext& ctx,
std::unordered_map<const sem::Array*, WrappedArrayInfo>& wrapped_arrays,
const sem::Array* array) const {
- if (array->IsRuntimeSized()) {
- return {}; // We don't want to wrap runtime sized arrays
- }
+ if (array->IsRuntimeSized()) {
+ return {}; // We don't want to wrap runtime sized arrays
+ }
- return utils::GetOrCreate(wrapped_arrays, array, [&] {
- WrappedArrayInfo info;
+ return utils::GetOrCreate(wrapped_arrays, array, [&] {
+ WrappedArrayInfo info;
- // Generate a unique name for the array wrapper
- info.wrapper_name = ctx.dst->Symbols().New("tint_array_wrapper");
+ // Generate a unique name for the array wrapper
+ info.wrapper_name = ctx.dst->Symbols().New("tint_array_wrapper");
- // Examine the element type. Is it also an array?
- std::function<const ast::Type*(CloneContext&)> el_type;
- if (auto* el_array = array->ElemType()->As<sem::Array>()) {
- // Array of array - call WrapArray() on the element type
- if (auto el = WrapArray(ctx, wrapped_arrays, el_array)) {
- el_type = [=](CloneContext& c) {
- return c.dst->create<ast::TypeName>(el.wrapper_name);
+ // Examine the element type. Is it also an array?
+ std::function<const ast::Type*(CloneContext&)> el_type;
+ if (auto* el_array = array->ElemType()->As<sem::Array>()) {
+ // Array of array - call WrapArray() on the element type
+ if (auto el = WrapArray(ctx, wrapped_arrays, el_array)) {
+ el_type = [=](CloneContext& c) {
+ return c.dst->create<ast::TypeName>(el.wrapper_name);
+ };
+ }
+ }
+
+ // If the element wasn't an array, just create the typical AST type for it
+ if (!el_type) {
+ el_type = [=](CloneContext& c) { return CreateASTTypeFor(c, array->ElemType()); };
+ }
+
+ // Construct the single structure field type
+ info.array_type = [=](CloneContext& c) {
+ ast::AttributeList attrs;
+ if (!array->IsStrideImplicit()) {
+ attrs.emplace_back(c.dst->create<ast::StrideAttribute>(array->Stride()));
+ }
+ return c.dst->ty.array(el_type(c), array->Count(), std::move(attrs));
};
- }
- }
- // If the element wasn't an array, just create the typical AST type for it
- if (!el_type) {
- el_type = [=](CloneContext& c) {
- return CreateASTTypeFor(c, array->ElemType());
- };
- }
-
- // Construct the single structure field type
- info.array_type = [=](CloneContext& c) {
- ast::AttributeList attrs;
- if (!array->IsStrideImplicit()) {
- attrs.emplace_back(
- c.dst->create<ast::StrideAttribute>(array->Stride()));
- }
- return c.dst->ty.array(el_type(c), array->Count(), std::move(attrs));
- };
-
- // Structure() will create and append the ast::Struct to the
- // global declarations of `ctx.dst`. As we haven't finished building the
- // current module-scope statement or function, this will be placed
- // immediately before the usage.
- ctx.dst->Structure(info.wrapper_name,
- {ctx.dst->Member("arr", info.array_type(ctx))});
- return info;
- });
+ // Structure() will create and append the ast::Struct to the
+ // global declarations of `ctx.dst`. As we haven't finished building the
+ // current module-scope statement or function, this will be placed
+ // immediately before the usage.
+ ctx.dst->Structure(info.wrapper_name, {ctx.dst->Member("arr", info.array_type(ctx))});
+ return info;
+ });
}
} // namespace tint::transform
diff --git a/src/tint/transform/wrap_arrays_in_structs.h b/src/tint/transform/wrap_arrays_in_structs.h
index a256ff8..4653c6b 100644
--- a/src/tint/transform/wrap_arrays_in_structs.h
+++ b/src/tint/transform/wrap_arrays_in_structs.h
@@ -34,56 +34,53 @@
/// This transform helps with backends that cannot directly return arrays or use
/// them as parameters.
class WrapArraysInStructs : public Castable<WrapArraysInStructs, Transform> {
- public:
- /// Constructor
- WrapArraysInStructs();
+ public:
+ /// Constructor
+ WrapArraysInStructs();
- /// Destructor
- ~WrapArraysInStructs() override;
+ /// Destructor
+ ~WrapArraysInStructs() override;
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program,
- const DataMap& data = {}) const override;
+ /// @param program the program to inspect
+ /// @param data optional extra transform-specific input data
+ /// @returns true if this transform should be run for the given program
+ bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx,
- const DataMap& inputs,
- DataMap& outputs) const override;
+ protected:
+ /// Runs the transform using the CloneContext built for transforming a
+ /// program. Run() is responsible for calling Clone() on the CloneContext.
+ /// @param ctx the CloneContext primed with the input program and
+ /// ProgramBuilder
+ /// @param inputs optional extra transform-specific input data
+ /// @param outputs optional extra transform-specific output data
+ void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
- private:
- struct WrappedArrayInfo {
- WrappedArrayInfo();
- WrappedArrayInfo(const WrappedArrayInfo&);
- ~WrappedArrayInfo();
+ private:
+ struct WrappedArrayInfo {
+ WrappedArrayInfo();
+ WrappedArrayInfo(const WrappedArrayInfo&);
+ ~WrappedArrayInfo();
- Symbol wrapper_name;
- std::function<const ast::Type*(CloneContext&)> array_type;
+ Symbol wrapper_name;
+ std::function<const ast::Type*(CloneContext&)> array_type;
- operator bool() { return wrapper_name.IsValid(); }
- };
+ operator bool() { return wrapper_name.IsValid(); }
+ };
- /// WrapArray wraps the fixed-size array type in a new structure (if it hasn't
- /// already been wrapped). WrapArray will recursively wrap arrays-of-arrays.
- /// The new structure will be added to module-scope type declarations of
- /// `ctx.dst`.
- /// @param ctx the CloneContext
- /// @param wrapped_arrays a map of src array type to the wrapped structure
- /// name
- /// @param array the array type
- /// @return the name of the structure that wraps the array, or an invalid
- /// Symbol if this array should not be wrapped
- WrappedArrayInfo WrapArray(
- CloneContext& ctx,
- std::unordered_map<const sem::Array*, WrappedArrayInfo>& wrapped_arrays,
- const sem::Array* array) const;
+ /// WrapArray wraps the fixed-size array type in a new structure (if it hasn't
+ /// already been wrapped). WrapArray will recursively wrap arrays-of-arrays.
+ /// The new structure will be added to module-scope type declarations of
+ /// `ctx.dst`.
+ /// @param ctx the CloneContext
+ /// @param wrapped_arrays a map of src array type to the wrapped structure
+ /// name
+ /// @param array the array type
+ /// @return the name of the structure that wraps the array, or an invalid
+ /// Symbol if this array should not be wrapped
+ WrappedArrayInfo WrapArray(
+ CloneContext& ctx,
+ std::unordered_map<const sem::Array*, WrappedArrayInfo>& wrapped_arrays,
+ const sem::Array* array) const;
};
} // namespace tint::transform
diff --git a/src/tint/transform/wrap_arrays_in_structs_test.cc b/src/tint/transform/wrap_arrays_in_structs_test.cc
index 7ba884c..7a7a6b3 100644
--- a/src/tint/transform/wrap_arrays_in_structs_test.cc
+++ b/src/tint/transform/wrap_arrays_in_structs_test.cc
@@ -25,33 +25,33 @@
using WrapArraysInStructsTest = TransformTest;
TEST_F(WrapArraysInStructsTest, ShouldRunEmptyModule) {
- auto* src = R"()";
+ auto* src = R"()";
- EXPECT_FALSE(ShouldRun<WrapArraysInStructs>(src));
+ EXPECT_FALSE(ShouldRun<WrapArraysInStructs>(src));
}
TEST_F(WrapArraysInStructsTest, ShouldRunHasArray) {
- auto* src = R"(
+ auto* src = R"(
var<private> arr : array<i32, 4>;
)";
- EXPECT_TRUE(ShouldRun<WrapArraysInStructs>(src));
+ EXPECT_TRUE(ShouldRun<WrapArraysInStructs>(src));
}
TEST_F(WrapArraysInStructsTest, EmptyModule) {
- auto* src = R"()";
- auto* expect = src;
+ auto* src = R"()";
+ auto* expect = src;
- auto got = Run<WrapArraysInStructs>(src);
+ auto got = Run<WrapArraysInStructs>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(WrapArraysInStructsTest, ArrayAsGlobal) {
- auto* src = R"(
+ auto* src = R"(
var<private> arr : array<i32, 4>;
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_array_wrapper {
arr : array<i32, 4u>,
}
@@ -59,19 +59,19 @@
var<private> arr : tint_array_wrapper;
)";
- auto got = Run<WrapArraysInStructs>(src);
+ auto got = Run<WrapArraysInStructs>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(WrapArraysInStructsTest, ArrayAsFunctionVar) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var arr : array<i32, 4>;
let x = arr[3];
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_array_wrapper {
arr : array<i32, 4u>,
}
@@ -82,18 +82,18 @@
}
)";
- auto got = Run<WrapArraysInStructs>(src);
+ auto got = Run<WrapArraysInStructs>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(WrapArraysInStructsTest, ArrayAsParam) {
- auto* src = R"(
+ auto* src = R"(
fn f(a : array<i32, 4>) -> i32 {
return a[2];
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_array_wrapper {
arr : array<i32, 4u>,
}
@@ -103,18 +103,18 @@
}
)";
- auto got = Run<WrapArraysInStructs>(src);
+ auto got = Run<WrapArraysInStructs>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(WrapArraysInStructsTest, ArrayAsReturn) {
- auto* src = R"(
+ auto* src = R"(
fn f() -> array<i32, 4> {
return array<i32, 4>(1, 2, 3, 4);
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_array_wrapper {
arr : array<i32, 4u>,
}
@@ -124,13 +124,13 @@
}
)";
- auto got = Run<WrapArraysInStructs>(src);
+ auto got = Run<WrapArraysInStructs>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(WrapArraysInStructsTest, ArrayAlias) {
- auto* src = R"(
+ auto* src = R"(
type Inner = array<i32, 2>;
type Array = array<Inner, 2>;
@@ -143,7 +143,7 @@
let x = arr[3];
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_array_wrapper {
arr : array<i32, 2u>,
}
@@ -166,13 +166,13 @@
}
)";
- auto got = Run<WrapArraysInStructs>(src);
+ auto got = Run<WrapArraysInStructs>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(WrapArraysInStructsTest, ArrayAlias_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn f() {
var arr : Array;
arr = Array();
@@ -185,7 +185,7 @@
type Array = array<Inner, 2>;
type Inner = array<i32, 2>;
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_array_wrapper_1 {
arr : array<i32, 2u>,
}
@@ -208,20 +208,20 @@
type Inner = tint_array_wrapper_1;
)";
- auto got = Run<WrapArraysInStructs>(src);
+ auto got = Run<WrapArraysInStructs>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(WrapArraysInStructsTest, ArraysInStruct) {
- auto* src = R"(
+ auto* src = R"(
struct S {
a : array<i32, 4>,
b : array<i32, 8>,
c : array<i32, 4>,
};
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_array_wrapper {
arr : array<i32, 4u>,
}
@@ -237,20 +237,20 @@
}
)";
- auto got = Run<WrapArraysInStructs>(src);
+ auto got = Run<WrapArraysInStructs>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(WrapArraysInStructsTest, ArraysOfArraysInStruct) {
- auto* src = R"(
+ auto* src = R"(
struct S {
a : array<i32, 4>,
b : array<array<i32, 4>, 4>,
c : array<array<array<i32, 4>, 4>, 4>,
};
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_array_wrapper {
arr : array<i32, 4u>,
}
@@ -270,13 +270,13 @@
}
)";
- auto got = Run<WrapArraysInStructs>(src);
+ auto got = Run<WrapArraysInStructs>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(WrapArraysInStructsTest, AccessArraysOfArraysInStruct) {
- auto* src = R"(
+ auto* src = R"(
struct S {
a : array<i32, 4>,
b : array<array<i32, 4>, 4>,
@@ -287,7 +287,7 @@
return s.a[2] + s.b[1][2] + s.c[3][1][2];
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_array_wrapper {
arr : array<i32, 4u>,
}
@@ -311,13 +311,13 @@
}
)";
- auto got = Run<WrapArraysInStructs>(src);
+ auto got = Run<WrapArraysInStructs>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(WrapArraysInStructsTest, DeclarationOrder) {
- auto* src = R"(
+ auto* src = R"(
type T0 = i32;
type T1 = array<i32, 1>;
@@ -333,7 +333,7 @@
var v : array<i32, 3>;
}
)";
- auto* expect = R"(
+ auto* expect = R"(
type T0 = i32;
struct tint_array_wrapper {
@@ -362,13 +362,13 @@
}
)";
- auto got = Run<WrapArraysInStructs>(src);
+ auto got = Run<WrapArraysInStructs>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(WrapArraysInStructsTest, DeclarationOrder_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
fn f2() {
var v : array<i32, 3>;
}
@@ -384,7 +384,7 @@
type T0 = i32;
)";
- auto* expect = R"(
+ auto* expect = R"(
struct tint_array_wrapper {
arr : array<i32, 3u>,
}
@@ -413,9 +413,9 @@
type T0 = i32;
)";
- auto got = Run<WrapArraysInStructs>(src);
+ auto got = Run<WrapArraysInStructs>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
} // namespace
diff --git a/src/tint/transform/zero_init_workgroup_memory.cc b/src/tint/transform/zero_init_workgroup_memory.cc
index 13f2f79..b299c59 100644
--- a/src/tint/transform/zero_init_workgroup_memory.cc
+++ b/src/tint/transform/zero_init_workgroup_memory.cc
@@ -34,426 +34,410 @@
/// PIMPL state for the ZeroInitWorkgroupMemory transform
struct ZeroInitWorkgroupMemory::State {
- /// The clone context
- CloneContext& ctx;
+ /// The clone context
+ CloneContext& ctx;
- /// An alias to *ctx.dst
- ProgramBuilder& b = *ctx.dst;
+ /// An alias to *ctx.dst
+ ProgramBuilder& b = *ctx.dst;
- /// The constant size of the workgroup. If 0, then #workgroup_size_expr should
- /// be used instead.
- uint32_t workgroup_size_const = 0;
- /// The size of the workgroup as an expression generator. Use if
- /// #workgroup_size_const is 0.
- std::function<const ast::Expression*()> workgroup_size_expr;
+ /// The constant size of the workgroup. If 0, then #workgroup_size_expr should
+ /// be used instead.
+ uint32_t workgroup_size_const = 0;
+ /// The size of the workgroup as an expression generator. Use if
+ /// #workgroup_size_const is 0.
+ std::function<const ast::Expression*()> workgroup_size_expr;
- /// ArrayIndex represents a function on the local invocation index, of
- /// the form: `array_index = (local_invocation_index % modulo) / division`
- struct ArrayIndex {
- /// The RHS of the modulus part of the expression
- uint32_t modulo = 1;
- /// The RHS of the division part of the expression
- uint32_t division = 1;
+ /// ArrayIndex represents a function on the local invocation index, of
+ /// the form: `array_index = (local_invocation_index % modulo) / division`
+ struct ArrayIndex {
+ /// The RHS of the modulus part of the expression
+ uint32_t modulo = 1;
+ /// The RHS of the division part of the expression
+ uint32_t division = 1;
- /// Equality operator
- /// @param i the ArrayIndex to compare to this ArrayIndex
- /// @returns true if `i` and this ArrayIndex are equal
- bool operator==(const ArrayIndex& i) const {
- return modulo == i.modulo && division == i.division;
- }
-
- /// Hash function for the ArrayIndex type
- struct Hasher {
- /// @param i the ArrayIndex to calculate a hash for
- /// @returns the hash value for the ArrayIndex `i`
- size_t operator()(const ArrayIndex& i) const {
- return utils::Hash(i.modulo, i.division);
- }
- };
- };
-
- /// A list of unique ArrayIndex
- using ArrayIndices = utils::UniqueVector<ArrayIndex, ArrayIndex::Hasher>;
-
- /// Expression holds information about an expression that is being built for a
- /// statement will zero workgroup values.
- struct Expression {
- /// The AST expression node
- const ast::Expression* expr = nullptr;
- /// The number of iterations required to zero the value
- uint32_t num_iterations = 0;
- /// All array indices used by this expression
- ArrayIndices array_indices;
- };
-
- /// Statement holds information about a statement that will zero workgroup
- /// values.
- struct Statement {
- /// The AST statement node
- const ast::Statement* stmt;
- /// The number of iterations required to zero the value
- uint32_t num_iterations;
- /// All array indices used by this statement
- ArrayIndices array_indices;
- };
-
- /// All statements that zero workgroup memory
- std::vector<Statement> statements;
-
- /// A map of ArrayIndex to the name reserved for the `let` declaration of that
- /// index.
- std::unordered_map<ArrayIndex, Symbol, ArrayIndex::Hasher> array_index_names;
-
- /// Constructor
- /// @param c the CloneContext used for the transform
- explicit State(CloneContext& c) : ctx(c) {}
-
- /// Run inserts the workgroup memory zero-initialization logic at the top of
- /// the given function
- /// @param fn a compute shader entry point function
- void Run(const ast::Function* fn) {
- auto& sem = ctx.src->Sem();
-
- CalculateWorkgroupSize(
- ast::GetAttribute<ast::WorkgroupAttribute>(fn->attributes));
-
- // Generate a list of statements to zero initialize each of the
- // workgroup storage variables used by `fn`. This will populate #statements.
- auto* func = sem.Get(fn);
- for (auto* var : func->TransitivelyReferencedGlobals()) {
- if (var->StorageClass() == ast::StorageClass::kWorkgroup) {
- BuildZeroingStatements(
- var->Type()->UnwrapRef(), [&](uint32_t num_values) {
- auto var_name = ctx.Clone(var->Declaration()->symbol);
- return Expression{b.Expr(var_name), num_values, ArrayIndices{}};
- });
- }
- }
-
- if (statements.empty()) {
- return; // No workgroup variables to initialize.
- }
-
- // Scan the entry point for an existing local_invocation_index builtin
- // parameter
- std::function<const ast::Expression*()> local_index;
- for (auto* param : fn->params) {
- if (auto* builtin =
- ast::GetAttribute<ast::BuiltinAttribute>(param->attributes)) {
- if (builtin->builtin == ast::Builtin::kLocalInvocationIndex) {
- local_index = [=] { return b.Expr(ctx.Clone(param->symbol)); };
- break;
+ /// Equality operator
+ /// @param i the ArrayIndex to compare to this ArrayIndex
+ /// @returns true if `i` and this ArrayIndex are equal
+ bool operator==(const ArrayIndex& i) const {
+ return modulo == i.modulo && division == i.division;
}
- }
- if (auto* str = sem.Get(param)->Type()->As<sem::Struct>()) {
- for (auto* member : str->Members()) {
- if (auto* builtin = ast::GetAttribute<ast::BuiltinAttribute>(
- member->Declaration()->attributes)) {
- if (builtin->builtin == ast::Builtin::kLocalInvocationIndex) {
- local_index = [=] {
- auto* param_expr = b.Expr(ctx.Clone(param->symbol));
- auto member_name = ctx.Clone(member->Declaration()->symbol);
- return b.MemberAccessor(param_expr, member_name);
- };
- break;
+ /// Hash function for the ArrayIndex type
+ struct Hasher {
+ /// @param i the ArrayIndex to calculate a hash for
+ /// @returns the hash value for the ArrayIndex `i`
+ size_t operator()(const ArrayIndex& i) const {
+ return utils::Hash(i.modulo, i.division);
}
- }
- }
- }
- }
- if (!local_index) {
- // No existing local index parameter. Append one to the entry point.
- auto* param =
- b.Param(b.Symbols().New("local_invocation_index"), b.ty.u32(),
- {b.Builtin(ast::Builtin::kLocalInvocationIndex)});
- ctx.InsertBack(fn->params, param);
- local_index = [=] { return b.Expr(param->symbol); };
- }
-
- // Take the zeroing statements and bin them by the number of iterations
- // required to zero the workgroup data. We then emit these in blocks,
- // possibly wrapped in if-statements or for-loops.
- std::unordered_map<uint32_t, std::vector<Statement>>
- stmts_by_num_iterations;
- std::vector<uint32_t> num_sorted_iterations;
- for (auto& s : statements) {
- auto& stmts = stmts_by_num_iterations[s.num_iterations];
- if (stmts.empty()) {
- num_sorted_iterations.emplace_back(s.num_iterations);
- }
- stmts.emplace_back(s);
- }
- std::sort(num_sorted_iterations.begin(), num_sorted_iterations.end());
-
- // Loop over the statements, grouped by num_iterations.
- for (auto num_iterations : num_sorted_iterations) {
- auto& stmts = stmts_by_num_iterations[num_iterations];
-
- // Gather all the array indices used by all the statements in the block.
- ArrayIndices array_indices;
- for (auto& s : stmts) {
- for (auto& idx : s.array_indices) {
- array_indices.add(idx);
- }
- }
-
- // Determine the block type used to emit these statements.
-
- if (workgroup_size_const == 0 || num_iterations > workgroup_size_const) {
- // Either the workgroup size is dynamic, or smaller than num_iterations.
- // In either case, we need to generate a for loop to ensure we
- // initialize all the array elements.
- //
- // for (var idx : u32 = local_index;
- // idx < num_iterations;
- // idx += workgroup_size) {
- // ...
- // }
- auto idx = b.Symbols().New("idx");
- auto* init = b.Decl(b.Var(idx, b.ty.u32(), local_index()));
- auto* cond = b.create<ast::BinaryExpression>(
- ast::BinaryOp::kLessThan, b.Expr(idx), b.Expr(num_iterations));
- auto* cont = b.Assign(
- idx, b.Add(idx, workgroup_size_const ? b.Expr(workgroup_size_const)
- : workgroup_size_expr()));
-
- auto block = DeclareArrayIndices(num_iterations, array_indices,
- [&] { return b.Expr(idx); });
- for (auto& s : stmts) {
- block.emplace_back(s.stmt);
- }
- auto* for_loop = b.For(init, cond, cont, b.Block(block));
- ctx.InsertFront(fn->body->statements, for_loop);
- } else if (num_iterations < workgroup_size_const) {
- // Workgroup size is a known constant, but is greater than
- // num_iterations. Emit an if statement:
- //
- // if (local_index < num_iterations) {
- // ...
- // }
- auto* cond = b.create<ast::BinaryExpression>(
- ast::BinaryOp::kLessThan, local_index(), b.Expr(num_iterations));
- auto block = DeclareArrayIndices(num_iterations, array_indices,
- [&] { return b.Expr(local_index()); });
- for (auto& s : stmts) {
- block.emplace_back(s.stmt);
- }
- auto* if_stmt = b.If(cond, b.Block(block));
- ctx.InsertFront(fn->body->statements, if_stmt);
- } else {
- // Workgroup size exactly equals num_iterations.
- // No need for any conditionals. Just emit a basic block:
- //
- // {
- // ...
- // }
- auto block = DeclareArrayIndices(num_iterations, array_indices,
- [&] { return b.Expr(local_index()); });
- for (auto& s : stmts) {
- block.emplace_back(s.stmt);
- }
- ctx.InsertFront(fn->body->statements, b.Block(block));
- }
- }
-
- // Append a single workgroup barrier after the zero initialization.
- ctx.InsertFront(fn->body->statements,
- b.CallStmt(b.Call("workgroupBarrier")));
- }
-
- /// BuildZeroingExpr is a function that builds a sub-expression used to zero
- /// workgroup values. `num_values` is the number of elements that the
- /// expression will be used to zero. Returns the expression.
- using BuildZeroingExpr = std::function<Expression(uint32_t num_values)>;
-
- /// BuildZeroingStatements() generates the statements required to zero
- /// initialize the workgroup storage expression of type `ty`.
- /// @param ty the expression type
- /// @param get_expr a function that builds the AST nodes for the expression.
- void BuildZeroingStatements(const sem::Type* ty,
- const BuildZeroingExpr& get_expr) {
- if (CanTriviallyZero(ty)) {
- auto var = get_expr(1u);
- auto* zero_init = b.Construct(CreateASTTypeFor(ctx, ty));
- statements.emplace_back(Statement{b.Assign(var.expr, zero_init),
- var.num_iterations, var.array_indices});
- return;
- }
-
- if (auto* atomic = ty->As<sem::Atomic>()) {
- auto* zero_init = b.Construct(CreateASTTypeFor(ctx, atomic->Type()));
- auto expr = get_expr(1u);
- auto* store = b.Call("atomicStore", b.AddressOf(expr.expr), zero_init);
- statements.emplace_back(Statement{b.CallStmt(store), expr.num_iterations,
- expr.array_indices});
- return;
- }
-
- if (auto* str = ty->As<sem::Struct>()) {
- for (auto* member : str->Members()) {
- auto name = ctx.Clone(member->Declaration()->symbol);
- BuildZeroingStatements(member->Type(), [&](uint32_t num_values) {
- auto s = get_expr(num_values);
- return Expression{b.MemberAccessor(s.expr, name), s.num_iterations,
- s.array_indices};
- });
- }
- return;
- }
-
- if (auto* arr = ty->As<sem::Array>()) {
- BuildZeroingStatements(arr->ElemType(), [&](uint32_t num_values) {
- // num_values is the number of values to zero for the element type.
- // The number of iterations required to zero the array and its elements
- // is:
- // `num_values * arr->Count()`
- // The index for this array is:
- // `(idx % modulo) / division`
- auto modulo = num_values * arr->Count();
- auto division = num_values;
- auto a = get_expr(modulo);
- auto array_indices = a.array_indices;
- array_indices.add(ArrayIndex{modulo, division});
- auto index =
- utils::GetOrCreate(array_index_names, ArrayIndex{modulo, division},
- [&] { return b.Symbols().New("i"); });
- return Expression{b.IndexAccessor(a.expr, index), a.num_iterations,
- array_indices};
- });
- return;
- }
-
- TINT_UNREACHABLE(Transform, b.Diagnostics())
- << "could not zero workgroup type: "
- << ty->FriendlyName(ctx.src->Symbols());
- }
-
- /// DeclareArrayIndices returns a list of statements that contain the `let`
- /// declarations for all of the ArrayIndices.
- /// @param num_iterations the number of iterations for the block
- /// @param array_indices the list of array indices to generate `let`
- /// declarations for
- /// @param iteration a function that returns the index of the current
- /// iteration.
- /// @returns the list of `let` statements that declare the array indices
- ast::StatementList DeclareArrayIndices(
- uint32_t num_iterations,
- const ArrayIndices& array_indices,
- const std::function<const ast::Expression*()>& iteration) {
- ast::StatementList stmts;
- std::map<Symbol, ArrayIndex> indices_by_name;
- for (auto index : array_indices) {
- auto name = array_index_names.at(index);
- auto* mod =
- (num_iterations > index.modulo)
- ? b.create<ast::BinaryExpression>(
- ast::BinaryOp::kModulo, iteration(), b.Expr(index.modulo))
- : iteration();
- auto* div = (index.division != 1u) ? b.Div(mod, index.division) : mod;
- auto* decl = b.Decl(b.Let(name, b.ty.u32(), div));
- stmts.emplace_back(decl);
- }
- return stmts;
- }
-
- /// CalculateWorkgroupSize initializes the members #workgroup_size_const and
- /// #workgroup_size_expr with the linear workgroup size.
- /// @param attr the workgroup attribute applied to the entry point function
- void CalculateWorkgroupSize(const ast::WorkgroupAttribute* attr) {
- bool is_signed = false;
- workgroup_size_const = 1u;
- workgroup_size_expr = nullptr;
- for (auto* expr : attr->Values()) {
- if (!expr) {
- continue;
- }
- auto* sem = ctx.src->Sem().Get(expr);
- if (auto c = sem->ConstantValue()) {
- if (c.ElementType()->Is<sem::I32>()) {
- workgroup_size_const *= static_cast<uint32_t>(c.Elements()[0].i32);
- continue;
- } else if (c.ElementType()->Is<sem::U32>()) {
- workgroup_size_const *= c.Elements()[0].u32;
- continue;
- }
- }
- // Constant value could not be found. Build expression instead.
- workgroup_size_expr = [this, expr, size = workgroup_size_expr] {
- auto* e = ctx.Clone(expr);
- if (ctx.src->TypeOf(expr)->UnwrapRef()->Is<sem::I32>()) {
- e = b.Construct<ProgramBuilder::u32>(e);
- }
- return size ? b.Mul(size(), e) : e;
- };
- }
- if (workgroup_size_expr) {
- if (workgroup_size_const != 1) {
- // Fold workgroup_size_const in to workgroup_size_expr
- workgroup_size_expr = [this, is_signed,
- const_size = workgroup_size_const,
- expr_size = workgroup_size_expr] {
- return is_signed
- ? b.Mul(expr_size(), static_cast<int32_t>(const_size))
- : b.Mul(expr_size(), const_size);
};
- }
- // Indicate that workgroup_size_expr should be used instead of the
- // constant.
- workgroup_size_const = 0;
- }
- }
+ };
- /// @returns true if a variable with store type `ty` can be efficiently zeroed
- /// by assignment of a type constructor without operands. If
- /// CanTriviallyZero() returns false, then the type needs to be
- /// initialized by decomposing the initialization into multiple
- /// sub-initializations.
- /// @param ty the type to inspect
- bool CanTriviallyZero(const sem::Type* ty) {
- if (ty->Is<sem::Atomic>()) {
- return false;
- }
- if (auto* str = ty->As<sem::Struct>()) {
- for (auto* member : str->Members()) {
- if (!CanTriviallyZero(member->Type())) {
- return false;
+ /// A list of unique ArrayIndex
+ using ArrayIndices = utils::UniqueVector<ArrayIndex, ArrayIndex::Hasher>;
+
+ /// Expression holds information about an expression that is being built for a
+ /// statement will zero workgroup values.
+ struct Expression {
+ /// The AST expression node
+ const ast::Expression* expr = nullptr;
+ /// The number of iterations required to zero the value
+ uint32_t num_iterations = 0;
+ /// All array indices used by this expression
+ ArrayIndices array_indices;
+ };
+
+ /// Statement holds information about a statement that will zero workgroup
+ /// values.
+ struct Statement {
+ /// The AST statement node
+ const ast::Statement* stmt;
+ /// The number of iterations required to zero the value
+ uint32_t num_iterations;
+ /// All array indices used by this statement
+ ArrayIndices array_indices;
+ };
+
+ /// All statements that zero workgroup memory
+ std::vector<Statement> statements;
+
+ /// A map of ArrayIndex to the name reserved for the `let` declaration of that
+ /// index.
+ std::unordered_map<ArrayIndex, Symbol, ArrayIndex::Hasher> array_index_names;
+
+ /// Constructor
+ /// @param c the CloneContext used for the transform
+ explicit State(CloneContext& c) : ctx(c) {}
+
+ /// Run inserts the workgroup memory zero-initialization logic at the top of
+ /// the given function
+ /// @param fn a compute shader entry point function
+ void Run(const ast::Function* fn) {
+ auto& sem = ctx.src->Sem();
+
+ CalculateWorkgroupSize(ast::GetAttribute<ast::WorkgroupAttribute>(fn->attributes));
+
+ // Generate a list of statements to zero initialize each of the
+ // workgroup storage variables used by `fn`. This will populate #statements.
+ auto* func = sem.Get(fn);
+ for (auto* var : func->TransitivelyReferencedGlobals()) {
+ if (var->StorageClass() == ast::StorageClass::kWorkgroup) {
+ BuildZeroingStatements(var->Type()->UnwrapRef(), [&](uint32_t num_values) {
+ auto var_name = ctx.Clone(var->Declaration()->symbol);
+ return Expression{b.Expr(var_name), num_values, ArrayIndices{}};
+ });
+ }
}
- }
+
+ if (statements.empty()) {
+ return; // No workgroup variables to initialize.
+ }
+
+ // Scan the entry point for an existing local_invocation_index builtin
+ // parameter
+ std::function<const ast::Expression*()> local_index;
+ for (auto* param : fn->params) {
+ if (auto* builtin = ast::GetAttribute<ast::BuiltinAttribute>(param->attributes)) {
+ if (builtin->builtin == ast::Builtin::kLocalInvocationIndex) {
+ local_index = [=] { return b.Expr(ctx.Clone(param->symbol)); };
+ break;
+ }
+ }
+
+ if (auto* str = sem.Get(param)->Type()->As<sem::Struct>()) {
+ for (auto* member : str->Members()) {
+ if (auto* builtin = ast::GetAttribute<ast::BuiltinAttribute>(
+ member->Declaration()->attributes)) {
+ if (builtin->builtin == ast::Builtin::kLocalInvocationIndex) {
+ local_index = [=] {
+ auto* param_expr = b.Expr(ctx.Clone(param->symbol));
+ auto member_name = ctx.Clone(member->Declaration()->symbol);
+ return b.MemberAccessor(param_expr, member_name);
+ };
+ break;
+ }
+ }
+ }
+ }
+ }
+ if (!local_index) {
+ // No existing local index parameter. Append one to the entry point.
+ auto* param = b.Param(b.Symbols().New("local_invocation_index"), b.ty.u32(),
+ {b.Builtin(ast::Builtin::kLocalInvocationIndex)});
+ ctx.InsertBack(fn->params, param);
+ local_index = [=] { return b.Expr(param->symbol); };
+ }
+
+ // Take the zeroing statements and bin them by the number of iterations
+ // required to zero the workgroup data. We then emit these in blocks,
+ // possibly wrapped in if-statements or for-loops.
+ std::unordered_map<uint32_t, std::vector<Statement>> stmts_by_num_iterations;
+ std::vector<uint32_t> num_sorted_iterations;
+ for (auto& s : statements) {
+ auto& stmts = stmts_by_num_iterations[s.num_iterations];
+ if (stmts.empty()) {
+ num_sorted_iterations.emplace_back(s.num_iterations);
+ }
+ stmts.emplace_back(s);
+ }
+ std::sort(num_sorted_iterations.begin(), num_sorted_iterations.end());
+
+ // Loop over the statements, grouped by num_iterations.
+ for (auto num_iterations : num_sorted_iterations) {
+ auto& stmts = stmts_by_num_iterations[num_iterations];
+
+ // Gather all the array indices used by all the statements in the block.
+ ArrayIndices array_indices;
+ for (auto& s : stmts) {
+ for (auto& idx : s.array_indices) {
+ array_indices.add(idx);
+ }
+ }
+
+ // Determine the block type used to emit these statements.
+
+ if (workgroup_size_const == 0 || num_iterations > workgroup_size_const) {
+ // Either the workgroup size is dynamic, or smaller than num_iterations.
+ // In either case, we need to generate a for loop to ensure we
+ // initialize all the array elements.
+ //
+ // for (var idx : u32 = local_index;
+ // idx < num_iterations;
+ // idx += workgroup_size) {
+ // ...
+ // }
+ auto idx = b.Symbols().New("idx");
+ auto* init = b.Decl(b.Var(idx, b.ty.u32(), local_index()));
+ auto* cond = b.create<ast::BinaryExpression>(ast::BinaryOp::kLessThan, b.Expr(idx),
+ b.Expr(num_iterations));
+ auto* cont =
+ b.Assign(idx, b.Add(idx, workgroup_size_const ? b.Expr(workgroup_size_const)
+ : workgroup_size_expr()));
+
+ auto block =
+ DeclareArrayIndices(num_iterations, array_indices, [&] { return b.Expr(idx); });
+ for (auto& s : stmts) {
+ block.emplace_back(s.stmt);
+ }
+ auto* for_loop = b.For(init, cond, cont, b.Block(block));
+ ctx.InsertFront(fn->body->statements, for_loop);
+ } else if (num_iterations < workgroup_size_const) {
+ // Workgroup size is a known constant, but is greater than
+ // num_iterations. Emit an if statement:
+ //
+ // if (local_index < num_iterations) {
+ // ...
+ // }
+ auto* cond = b.create<ast::BinaryExpression>(ast::BinaryOp::kLessThan,
+ local_index(), b.Expr(num_iterations));
+ auto block = DeclareArrayIndices(num_iterations, array_indices,
+ [&] { return b.Expr(local_index()); });
+ for (auto& s : stmts) {
+ block.emplace_back(s.stmt);
+ }
+ auto* if_stmt = b.If(cond, b.Block(block));
+ ctx.InsertFront(fn->body->statements, if_stmt);
+ } else {
+ // Workgroup size exactly equals num_iterations.
+ // No need for any conditionals. Just emit a basic block:
+ //
+ // {
+ // ...
+ // }
+ auto block = DeclareArrayIndices(num_iterations, array_indices,
+ [&] { return b.Expr(local_index()); });
+ for (auto& s : stmts) {
+ block.emplace_back(s.stmt);
+ }
+ ctx.InsertFront(fn->body->statements, b.Block(block));
+ }
+ }
+
+ // Append a single workgroup barrier after the zero initialization.
+ ctx.InsertFront(fn->body->statements, b.CallStmt(b.Call("workgroupBarrier")));
}
- if (ty->Is<sem::Array>()) {
- return false;
+
+ /// BuildZeroingExpr is a function that builds a sub-expression used to zero
+ /// workgroup values. `num_values` is the number of elements that the
+ /// expression will be used to zero. Returns the expression.
+ using BuildZeroingExpr = std::function<Expression(uint32_t num_values)>;
+
+ /// BuildZeroingStatements() generates the statements required to zero
+ /// initialize the workgroup storage expression of type `ty`.
+ /// @param ty the expression type
+ /// @param get_expr a function that builds the AST nodes for the expression.
+ void BuildZeroingStatements(const sem::Type* ty, const BuildZeroingExpr& get_expr) {
+ if (CanTriviallyZero(ty)) {
+ auto var = get_expr(1u);
+ auto* zero_init = b.Construct(CreateASTTypeFor(ctx, ty));
+ statements.emplace_back(
+ Statement{b.Assign(var.expr, zero_init), var.num_iterations, var.array_indices});
+ return;
+ }
+
+ if (auto* atomic = ty->As<sem::Atomic>()) {
+ auto* zero_init = b.Construct(CreateASTTypeFor(ctx, atomic->Type()));
+ auto expr = get_expr(1u);
+ auto* store = b.Call("atomicStore", b.AddressOf(expr.expr), zero_init);
+ statements.emplace_back(
+ Statement{b.CallStmt(store), expr.num_iterations, expr.array_indices});
+ return;
+ }
+
+ if (auto* str = ty->As<sem::Struct>()) {
+ for (auto* member : str->Members()) {
+ auto name = ctx.Clone(member->Declaration()->symbol);
+ BuildZeroingStatements(member->Type(), [&](uint32_t num_values) {
+ auto s = get_expr(num_values);
+ return Expression{b.MemberAccessor(s.expr, name), s.num_iterations,
+ s.array_indices};
+ });
+ }
+ return;
+ }
+
+ if (auto* arr = ty->As<sem::Array>()) {
+ BuildZeroingStatements(arr->ElemType(), [&](uint32_t num_values) {
+ // num_values is the number of values to zero for the element type.
+ // The number of iterations required to zero the array and its elements
+ // is:
+ // `num_values * arr->Count()`
+ // The index for this array is:
+ // `(idx % modulo) / division`
+ auto modulo = num_values * arr->Count();
+ auto division = num_values;
+ auto a = get_expr(modulo);
+ auto array_indices = a.array_indices;
+ array_indices.add(ArrayIndex{modulo, division});
+ auto index = utils::GetOrCreate(array_index_names, ArrayIndex{modulo, division},
+ [&] { return b.Symbols().New("i"); });
+ return Expression{b.IndexAccessor(a.expr, index), a.num_iterations, array_indices};
+ });
+ return;
+ }
+
+ TINT_UNREACHABLE(Transform, b.Diagnostics())
+ << "could not zero workgroup type: " << ty->FriendlyName(ctx.src->Symbols());
}
- // True for all other storable types
- return true;
- }
+
+ /// DeclareArrayIndices returns a list of statements that contain the `let`
+ /// declarations for all of the ArrayIndices.
+ /// @param num_iterations the number of iterations for the block
+ /// @param array_indices the list of array indices to generate `let`
+ /// declarations for
+ /// @param iteration a function that returns the index of the current
+ /// iteration.
+ /// @returns the list of `let` statements that declare the array indices
+ ast::StatementList DeclareArrayIndices(
+ uint32_t num_iterations,
+ const ArrayIndices& array_indices,
+ const std::function<const ast::Expression*()>& iteration) {
+ ast::StatementList stmts;
+ std::map<Symbol, ArrayIndex> indices_by_name;
+ for (auto index : array_indices) {
+ auto name = array_index_names.at(index);
+ auto* mod = (num_iterations > index.modulo)
+ ? b.create<ast::BinaryExpression>(ast::BinaryOp::kModulo, iteration(),
+ b.Expr(index.modulo))
+ : iteration();
+ auto* div = (index.division != 1u) ? b.Div(mod, index.division) : mod;
+ auto* decl = b.Decl(b.Let(name, b.ty.u32(), div));
+ stmts.emplace_back(decl);
+ }
+ return stmts;
+ }
+
+ /// CalculateWorkgroupSize initializes the members #workgroup_size_const and
+ /// #workgroup_size_expr with the linear workgroup size.
+ /// @param attr the workgroup attribute applied to the entry point function
+ void CalculateWorkgroupSize(const ast::WorkgroupAttribute* attr) {
+ bool is_signed = false;
+ workgroup_size_const = 1u;
+ workgroup_size_expr = nullptr;
+ for (auto* expr : attr->Values()) {
+ if (!expr) {
+ continue;
+ }
+ auto* sem = ctx.src->Sem().Get(expr);
+ if (auto c = sem->ConstantValue()) {
+ if (c.ElementType()->Is<sem::I32>()) {
+ workgroup_size_const *= static_cast<uint32_t>(c.Elements()[0].i32);
+ continue;
+ } else if (c.ElementType()->Is<sem::U32>()) {
+ workgroup_size_const *= c.Elements()[0].u32;
+ continue;
+ }
+ }
+ // Constant value could not be found. Build expression instead.
+ workgroup_size_expr = [this, expr, size = workgroup_size_expr] {
+ auto* e = ctx.Clone(expr);
+ if (ctx.src->TypeOf(expr)->UnwrapRef()->Is<sem::I32>()) {
+ e = b.Construct<ProgramBuilder::u32>(e);
+ }
+ return size ? b.Mul(size(), e) : e;
+ };
+ }
+ if (workgroup_size_expr) {
+ if (workgroup_size_const != 1) {
+ // Fold workgroup_size_const in to workgroup_size_expr
+ workgroup_size_expr = [this, is_signed, const_size = workgroup_size_const,
+ expr_size = workgroup_size_expr] {
+ return is_signed ? b.Mul(expr_size(), static_cast<int32_t>(const_size))
+ : b.Mul(expr_size(), const_size);
+ };
+ }
+ // Indicate that workgroup_size_expr should be used instead of the
+ // constant.
+ workgroup_size_const = 0;
+ }
+ }
+
+ /// @returns true if a variable with store type `ty` can be efficiently zeroed
+ /// by assignment of a type constructor without operands. If
+ /// CanTriviallyZero() returns false, then the type needs to be
+ /// initialized by decomposing the initialization into multiple
+ /// sub-initializations.
+ /// @param ty the type to inspect
+ bool CanTriviallyZero(const sem::Type* ty) {
+ if (ty->Is<sem::Atomic>()) {
+ return false;
+ }
+ if (auto* str = ty->As<sem::Struct>()) {
+ for (auto* member : str->Members()) {
+ if (!CanTriviallyZero(member->Type())) {
+ return false;
+ }
+ }
+ }
+ if (ty->Is<sem::Array>()) {
+ return false;
+ }
+ // True for all other storable types
+ return true;
+ }
};
ZeroInitWorkgroupMemory::ZeroInitWorkgroupMemory() = default;
ZeroInitWorkgroupMemory::~ZeroInitWorkgroupMemory() = default;
-bool ZeroInitWorkgroupMemory::ShouldRun(const Program* program,
- const DataMap&) const {
- for (auto* decl : program->AST().GlobalDeclarations()) {
- if (auto* var = decl->As<ast::Variable>()) {
- if (var->declared_storage_class == ast::StorageClass::kWorkgroup) {
- return true;
- }
+bool ZeroInitWorkgroupMemory::ShouldRun(const Program* program, const DataMap&) const {
+ for (auto* decl : program->AST().GlobalDeclarations()) {
+ if (auto* var = decl->As<ast::Variable>()) {
+ if (var->declared_storage_class == ast::StorageClass::kWorkgroup) {
+ return true;
+ }
+ }
}
- }
- return false;
+ return false;
}
-void ZeroInitWorkgroupMemory::Run(CloneContext& ctx,
- const DataMap&,
- DataMap&) const {
- for (auto* fn : ctx.src->AST().Functions()) {
- if (fn->PipelineStage() == ast::PipelineStage::kCompute) {
- State{ctx}.Run(fn);
+void ZeroInitWorkgroupMemory::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
+ for (auto* fn : ctx.src->AST().Functions()) {
+ if (fn->PipelineStage() == ast::PipelineStage::kCompute) {
+ State{ctx}.Run(fn);
+ }
}
- }
- ctx.Clone();
+ ctx.Clone();
}
} // namespace tint::transform
diff --git a/src/tint/transform/zero_init_workgroup_memory.h b/src/tint/transform/zero_init_workgroup_memory.h
index 33ae52c..c757725 100644
--- a/src/tint/transform/zero_init_workgroup_memory.h
+++ b/src/tint/transform/zero_init_workgroup_memory.h
@@ -22,34 +22,30 @@
/// ZeroInitWorkgroupMemory is a transform that injects code at the top of entry
/// points to zero-initialize workgroup memory used by that entry point (and all
/// transitive functions called by that entry point)
-class ZeroInitWorkgroupMemory
- : public Castable<ZeroInitWorkgroupMemory, Transform> {
- public:
- /// Constructor
- ZeroInitWorkgroupMemory();
+class ZeroInitWorkgroupMemory : public Castable<ZeroInitWorkgroupMemory, Transform> {
+ public:
+ /// Constructor
+ ZeroInitWorkgroupMemory();
- /// Destructor
- ~ZeroInitWorkgroupMemory() override;
+ /// Destructor
+ ~ZeroInitWorkgroupMemory() override;
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program,
- const DataMap& data = {}) const override;
+ /// @param program the program to inspect
+ /// @param data optional extra transform-specific input data
+ /// @returns true if this transform should be run for the given program
+ bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx,
- const DataMap& inputs,
- DataMap& outputs) const override;
+ protected:
+ /// Runs the transform using the CloneContext built for transforming a
+ /// program. Run() is responsible for calling Clone() on the CloneContext.
+ /// @param ctx the CloneContext primed with the input program and
+ /// ProgramBuilder
+ /// @param inputs optional extra transform-specific input data
+ /// @param outputs optional extra transform-specific output data
+ void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
- private:
- struct State;
+ private:
+ struct State;
};
} // namespace tint::transform
diff --git a/src/tint/transform/zero_init_workgroup_memory_test.cc b/src/tint/transform/zero_init_workgroup_memory_test.cc
index 32c73db..c846d55 100644
--- a/src/tint/transform/zero_init_workgroup_memory_test.cc
+++ b/src/tint/transform/zero_init_workgroup_memory_test.cc
@@ -24,53 +24,53 @@
using ZeroInitWorkgroupMemoryTest = TransformTest;
TEST_F(ZeroInitWorkgroupMemoryTest, ShouldRunEmptyModule) {
- auto* src = R"()";
+ auto* src = R"()";
- EXPECT_FALSE(ShouldRun<ZeroInitWorkgroupMemory>(src));
+ EXPECT_FALSE(ShouldRun<ZeroInitWorkgroupMemory>(src));
}
TEST_F(ZeroInitWorkgroupMemoryTest, ShouldRunHasNoWorkgroupVars) {
- auto* src = R"(
+ auto* src = R"(
var<private> v : i32;
)";
- EXPECT_FALSE(ShouldRun<ZeroInitWorkgroupMemory>(src));
+ EXPECT_FALSE(ShouldRun<ZeroInitWorkgroupMemory>(src));
}
TEST_F(ZeroInitWorkgroupMemoryTest, ShouldRunHasWorkgroupVars) {
- auto* src = R"(
+ auto* src = R"(
var<workgroup> a : i32;
)";
- EXPECT_TRUE(ShouldRun<ZeroInitWorkgroupMemory>(src));
+ EXPECT_TRUE(ShouldRun<ZeroInitWorkgroupMemory>(src));
}
TEST_F(ZeroInitWorkgroupMemoryTest, EmptyModule) {
- auto* src = "";
- auto* expect = src;
+ auto* src = "";
+ auto* expect = src;
- auto got = Run<ZeroInitWorkgroupMemory>(src);
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ZeroInitWorkgroupMemoryTest, NoWorkgroupVars) {
- auto* src = R"(
+ auto* src = R"(
var<private> v : i32;
fn f() {
v = 1;
}
)";
- auto* expect = src;
+ auto* expect = src;
- auto got = Run<ZeroInitWorkgroupMemory>(src);
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ZeroInitWorkgroupMemoryTest, UnreferencedWorkgroupVars) {
- auto* src = R"(
+ auto* src = R"(
var<workgroup> a : i32;
var<workgroup> b : i32;
@@ -85,15 +85,15 @@
fn f() {
}
)";
- auto* expect = src;
+ auto* expect = src;
- auto got = Run<ZeroInitWorkgroupMemory>(src);
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ZeroInitWorkgroupMemoryTest, UnreferencedWorkgroupVars_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn f() {
}
@@ -108,15 +108,15 @@
var<workgroup> c : i32;
)";
- auto* expect = src;
+ auto* expect = src;
- auto got = Run<ZeroInitWorkgroupMemory>(src);
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ZeroInitWorkgroupMemoryTest, SingleWorkgroupVar_ExistingLocalIndex) {
- auto* src = R"(
+ auto* src = R"(
var<workgroup> v : i32;
@stage(compute) @workgroup_size(1)
@@ -124,7 +124,7 @@
_ = v; // Initialization should be inserted above this statement
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<workgroup> v : i32;
@stage(compute) @workgroup_size(1)
@@ -137,14 +137,13 @@
}
)";
- auto got = Run<ZeroInitWorkgroupMemory>(src);
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
-TEST_F(ZeroInitWorkgroupMemoryTest,
- SingleWorkgroupVar_ExistingLocalIndex_OutOfOrder) {
- auto* src = R"(
+TEST_F(ZeroInitWorkgroupMemoryTest, SingleWorkgroupVar_ExistingLocalIndex_OutOfOrder) {
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn f(@builtin(local_invocation_index) local_idx : u32) {
_ = v; // Initialization should be inserted above this statement
@@ -152,7 +151,7 @@
var<workgroup> v : i32;
)";
- auto* expect = R"(
+ auto* expect = R"(
@stage(compute) @workgroup_size(1)
fn f(@builtin(local_invocation_index) local_idx : u32) {
{
@@ -165,14 +164,13 @@
var<workgroup> v : i32;
)";
- auto got = Run<ZeroInitWorkgroupMemory>(src);
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
-TEST_F(ZeroInitWorkgroupMemoryTest,
- SingleWorkgroupVar_ExistingLocalIndexInStruct) {
- auto* src = R"(
+TEST_F(ZeroInitWorkgroupMemoryTest, SingleWorkgroupVar_ExistingLocalIndexInStruct) {
+ auto* src = R"(
var<workgroup> v : i32;
struct Params {
@@ -184,7 +182,7 @@
_ = v; // Initialization should be inserted above this statement
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<workgroup> v : i32;
struct Params {
@@ -202,14 +200,13 @@
}
)";
- auto got = Run<ZeroInitWorkgroupMemory>(src);
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
-TEST_F(ZeroInitWorkgroupMemoryTest,
- SingleWorkgroupVar_ExistingLocalIndexInStruct_OutOfOrder) {
- auto* src = R"(
+TEST_F(ZeroInitWorkgroupMemoryTest, SingleWorkgroupVar_ExistingLocalIndexInStruct_OutOfOrder) {
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn f(params : Params) {
_ = v; // Initialization should be inserted above this statement
@@ -221,7 +218,7 @@
var<workgroup> v : i32;
)";
- auto* expect = R"(
+ auto* expect = R"(
@stage(compute) @workgroup_size(1)
fn f(params : Params) {
{
@@ -239,13 +236,13 @@
var<workgroup> v : i32;
)";
- auto got = Run<ZeroInitWorkgroupMemory>(src);
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ZeroInitWorkgroupMemoryTest, SingleWorkgroupVar_InjectedLocalIndex) {
- auto* src = R"(
+ auto* src = R"(
var<workgroup> v : i32;
@stage(compute) @workgroup_size(1)
@@ -253,7 +250,7 @@
_ = v; // Initialization should be inserted above this statement
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<workgroup> v : i32;
@stage(compute) @workgroup_size(1)
@@ -266,14 +263,13 @@
}
)";
- auto got = Run<ZeroInitWorkgroupMemory>(src);
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
-TEST_F(ZeroInitWorkgroupMemoryTest,
- SingleWorkgroupVar_InjectedLocalIndex_OutOfOrder) {
- auto* src = R"(
+TEST_F(ZeroInitWorkgroupMemoryTest, SingleWorkgroupVar_InjectedLocalIndex_OutOfOrder) {
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn f() {
_ = v; // Initialization should be inserted above this statement
@@ -281,7 +277,7 @@
var<workgroup> v : i32;
)";
- auto* expect = R"(
+ auto* expect = R"(
@stage(compute) @workgroup_size(1)
fn f(@builtin(local_invocation_index) local_invocation_index : u32) {
{
@@ -294,14 +290,13 @@
var<workgroup> v : i32;
)";
- auto got = Run<ZeroInitWorkgroupMemory>(src);
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
-TEST_F(ZeroInitWorkgroupMemoryTest,
- MultipleWorkgroupVar_ExistingLocalIndex_Size1) {
- auto* src = R"(
+TEST_F(ZeroInitWorkgroupMemoryTest, MultipleWorkgroupVar_ExistingLocalIndex_Size1) {
+ auto* src = R"(
struct S {
x : i32,
y : array<i32, 8>,
@@ -320,7 +315,7 @@
_ = c;
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct S {
x : i32,
y : array<i32, 8>,
@@ -358,14 +353,13 @@
}
)";
- auto got = Run<ZeroInitWorkgroupMemory>(src);
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
-TEST_F(ZeroInitWorkgroupMemoryTest,
- MultipleWorkgroupVar_ExistingLocalIndex_Size1_OutOfOrder) {
- auto* src = R"(
+TEST_F(ZeroInitWorkgroupMemoryTest, MultipleWorkgroupVar_ExistingLocalIndex_Size1_OutOfOrder) {
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn f(@builtin(local_invocation_index) local_idx : u32) {
_ = a; // Initialization should be inserted above this statement
@@ -384,7 +378,7 @@
y : array<i32, 8>,
};
)";
- auto* expect = R"(
+ auto* expect = R"(
@stage(compute) @workgroup_size(1)
fn f(@builtin(local_invocation_index) local_idx : u32) {
{
@@ -422,14 +416,13 @@
}
)";
- auto got = Run<ZeroInitWorkgroupMemory>(src);
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
-TEST_F(ZeroInitWorkgroupMemoryTest,
- MultipleWorkgroupVar_ExistingLocalIndex_Size_2_3) {
- auto* src = R"(
+TEST_F(ZeroInitWorkgroupMemoryTest, MultipleWorkgroupVar_ExistingLocalIndex_Size_2_3) {
+ auto* src = R"(
struct S {
x : i32,
y : array<i32, 8>,
@@ -448,7 +441,7 @@
_ = c;
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct S {
x : i32,
y : array<i32, 8>,
@@ -486,14 +479,13 @@
}
)";
- auto got = Run<ZeroInitWorkgroupMemory>(src);
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
-TEST_F(ZeroInitWorkgroupMemoryTest,
- MultipleWorkgroupVar_ExistingLocalIndex_Size_2_3_X) {
- auto* src = R"(
+TEST_F(ZeroInitWorkgroupMemoryTest, MultipleWorkgroupVar_ExistingLocalIndex_Size_2_3_X) {
+ auto* src = R"(
struct S {
x : i32,
y : array<i32, 8>,
@@ -514,8 +506,8 @@
_ = c;
}
)";
- auto* expect =
- R"(
+ auto* expect =
+ R"(
struct S {
x : i32,
y : array<i32, 8>,
@@ -555,14 +547,13 @@
}
)";
- auto got = Run<ZeroInitWorkgroupMemory>(src);
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
-TEST_F(ZeroInitWorkgroupMemoryTest,
- MultipleWorkgroupVar_ExistingLocalIndex_Size_5u_X_10u) {
- auto* src = R"(
+TEST_F(ZeroInitWorkgroupMemoryTest, MultipleWorkgroupVar_ExistingLocalIndex_Size_5u_X_10u) {
+ auto* src = R"(
struct S {
x : array<array<i32, 8>, 10>,
y : array<i32, 8>,
@@ -584,8 +575,8 @@
_ = c;
}
)";
- auto* expect =
- R"(
+ auto* expect =
+ R"(
struct S {
x : array<array<i32, 8>, 10>,
y : array<i32, 8>,
@@ -645,13 +636,13 @@
}
)";
- auto got = Run<ZeroInitWorkgroupMemory>(src);
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ZeroInitWorkgroupMemoryTest, MultipleWorkgroupVar_InjectedLocalIndex) {
- auto* src = R"(
+ auto* src = R"(
struct S {
x : i32,
y : array<i32, 8>,
@@ -670,7 +661,7 @@
_ = c;
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct S {
x : i32,
y : array<i32, 8>,
@@ -708,14 +699,13 @@
}
)";
- auto got = Run<ZeroInitWorkgroupMemory>(src);
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
-TEST_F(ZeroInitWorkgroupMemoryTest,
- MultipleWorkgroupVar_InjectedLocalIndex_OutOfOrder) {
- auto* src = R"(
+TEST_F(ZeroInitWorkgroupMemoryTest, MultipleWorkgroupVar_InjectedLocalIndex_OutOfOrder) {
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn f(@builtin(local_invocation_id) local_invocation_id : vec3<u32>) {
_ = a; // Initialization should be inserted above this statement
@@ -734,7 +724,7 @@
y : array<i32, 8>,
};
)";
- auto* expect = R"(
+ auto* expect = R"(
@stage(compute) @workgroup_size(1)
fn f(@builtin(local_invocation_id) local_invocation_id : vec3<u32>, @builtin(local_invocation_index) local_invocation_index : u32) {
{
@@ -772,13 +762,13 @@
}
)";
- auto got = Run<ZeroInitWorkgroupMemory>(src);
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ZeroInitWorkgroupMemoryTest, MultipleWorkgroupVar_MultipleEntryPoints) {
- auto* src = R"(
+ auto* src = R"(
struct S {
x : i32,
y : array<i32, 8>,
@@ -807,7 +797,7 @@
_ = a;
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct S {
x : i32,
y : array<i32, 8>,
@@ -871,14 +861,13 @@
}
)";
- auto got = Run<ZeroInitWorkgroupMemory>(src);
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
-TEST_F(ZeroInitWorkgroupMemoryTest,
- MultipleWorkgroupVar_MultipleEntryPoints_OutOfOrder) {
- auto* src = R"(
+TEST_F(ZeroInitWorkgroupMemoryTest, MultipleWorkgroupVar_MultipleEntryPoints_OutOfOrder) {
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn f1() {
_ = a; // Initialization should be inserted above this statement
@@ -907,7 +896,7 @@
y : array<i32, 8>,
};
)";
- auto* expect = R"(
+ auto* expect = R"(
@stage(compute) @workgroup_size(1)
fn f1(@builtin(local_invocation_index) local_invocation_index : u32) {
{
@@ -971,13 +960,13 @@
}
)";
- auto got = Run<ZeroInitWorkgroupMemory>(src);
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ZeroInitWorkgroupMemoryTest, TransitiveUsage) {
- auto* src = R"(
+ auto* src = R"(
var<workgroup> v : i32;
fn use_v() {
@@ -993,7 +982,7 @@
call_use_v(); // Initialization should be inserted above this statement
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<workgroup> v : i32;
fn use_v() {
@@ -1014,13 +1003,13 @@
}
)";
- auto got = Run<ZeroInitWorkgroupMemory>(src);
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ZeroInitWorkgroupMemoryTest, TransitiveUsage_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn f(@builtin(local_invocation_index) local_idx : u32) {
call_use_v(); // Initialization should be inserted above this statement
@@ -1036,7 +1025,7 @@
var<workgroup> v : i32;
)";
- auto* expect = R"(
+ auto* expect = R"(
@stage(compute) @workgroup_size(1)
fn f(@builtin(local_invocation_index) local_idx : u32) {
{
@@ -1057,13 +1046,13 @@
var<workgroup> v : i32;
)";
- auto got = Run<ZeroInitWorkgroupMemory>(src);
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ZeroInitWorkgroupMemoryTest, WorkgroupAtomics) {
- auto* src = R"(
+ auto* src = R"(
var<workgroup> i : atomic<i32>;
var<workgroup> u : atomic<u32>;
@@ -1073,7 +1062,7 @@
atomicLoad(&(u));
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<workgroup> i : atomic<i32>;
var<workgroup> u : atomic<u32>;
@@ -1090,13 +1079,13 @@
}
)";
- auto got = Run<ZeroInitWorkgroupMemory>(src);
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ZeroInitWorkgroupMemoryTest, WorkgroupAtomics_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn f() {
atomicLoad(&(i)); // Initialization should be inserted above this statement
@@ -1106,7 +1095,7 @@
var<workgroup> i : atomic<i32>;
var<workgroup> u : atomic<u32>;
)";
- auto* expect = R"(
+ auto* expect = R"(
@stage(compute) @workgroup_size(1)
fn f(@builtin(local_invocation_index) local_invocation_index : u32) {
{
@@ -1123,13 +1112,13 @@
var<workgroup> u : atomic<u32>;
)";
- auto got = Run<ZeroInitWorkgroupMemory>(src);
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ZeroInitWorkgroupMemoryTest, WorkgroupStructOfAtomics) {
- auto* src = R"(
+ auto* src = R"(
struct S {
a : i32,
i : atomic<i32>,
@@ -1145,7 +1134,7 @@
_ = w.a; // Initialization should be inserted above this statement
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct S {
a : i32,
i : atomic<i32>,
@@ -1170,13 +1159,13 @@
}
)";
- auto got = Run<ZeroInitWorkgroupMemory>(src);
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ZeroInitWorkgroupMemoryTest, WorkgroupStructOfAtomics_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn f() {
_ = w.a; // Initialization should be inserted above this statement
@@ -1192,7 +1181,7 @@
c : u32,
};
)";
- auto* expect = R"(
+ auto* expect = R"(
@stage(compute) @workgroup_size(1)
fn f(@builtin(local_invocation_index) local_invocation_index : u32) {
{
@@ -1217,13 +1206,13 @@
}
)";
- auto got = Run<ZeroInitWorkgroupMemory>(src);
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ZeroInitWorkgroupMemoryTest, WorkgroupArrayOfAtomics) {
- auto* src = R"(
+ auto* src = R"(
var<workgroup> w : array<atomic<u32>, 4>;
@stage(compute) @workgroup_size(1)
@@ -1231,7 +1220,7 @@
atomicLoad(&w[0]); // Initialization should be inserted above this statement
}
)";
- auto* expect = R"(
+ auto* expect = R"(
var<workgroup> w : array<atomic<u32>, 4>;
@stage(compute) @workgroup_size(1)
@@ -1245,13 +1234,13 @@
}
)";
- auto got = Run<ZeroInitWorkgroupMemory>(src);
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ZeroInitWorkgroupMemoryTest, WorkgroupArrayOfAtomics_OutOfOrder) {
- auto* src = R"(
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn f() {
atomicLoad(&w[0]); // Initialization should be inserted above this statement
@@ -1259,7 +1248,7 @@
var<workgroup> w : array<atomic<u32>, 4>;
)";
- auto* expect = R"(
+ auto* expect = R"(
@stage(compute) @workgroup_size(1)
fn f(@builtin(local_invocation_index) local_invocation_index : u32) {
for(var idx : u32 = local_invocation_index; (idx < 4u); idx = (idx + 1u)) {
@@ -1273,13 +1262,13 @@
var<workgroup> w : array<atomic<u32>, 4>;
)";
- auto got = Run<ZeroInitWorkgroupMemory>(src);
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
TEST_F(ZeroInitWorkgroupMemoryTest, WorkgroupArrayOfStructOfAtomics) {
- auto* src = R"(
+ auto* src = R"(
struct S {
a : i32,
i : atomic<i32>,
@@ -1295,7 +1284,7 @@
_ = w[0].a; // Initialization should be inserted above this statement
}
)";
- auto* expect = R"(
+ auto* expect = R"(
struct S {
a : i32,
i : atomic<i32>,
@@ -1321,14 +1310,13 @@
}
)";
- auto got = Run<ZeroInitWorkgroupMemory>(src);
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
-TEST_F(ZeroInitWorkgroupMemoryTest,
- WorkgroupArrayOfStructOfAtomics_OutOfOrder) {
- auto* src = R"(
+TEST_F(ZeroInitWorkgroupMemoryTest, WorkgroupArrayOfStructOfAtomics_OutOfOrder) {
+ auto* src = R"(
@stage(compute) @workgroup_size(1)
fn f() {
_ = w[0].a; // Initialization should be inserted above this statement
@@ -1344,7 +1332,7 @@
c : u32,
};
)";
- auto* expect = R"(
+ auto* expect = R"(
@stage(compute) @workgroup_size(1)
fn f(@builtin(local_invocation_index) local_invocation_index : u32) {
for(var idx : u32 = local_invocation_index; (idx < 4u); idx = (idx + 1u)) {
@@ -1370,9 +1358,9 @@
}
)";
- auto got = Run<ZeroInitWorkgroupMemory>(src);
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(expect, str(got));
}
} // namespace