Tint truncate interstage variable transform Add a transform to truncate unused user interstage variables by adding a new truncated shader io struct wrapper of the original one, with a truncate function to do the assignments called at the return statement. This transform is meant to be run after CanonicalizeEntryPointIO, and will only be run under hlsl/generator_impl.cc to workaround the extra register limitation for interstage variables on D3D FXC. Bug: dawn:1493 Change-Id: I69081189ad7d4b76f2371fcc079f67dced2e9944 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/104620 Kokoro: Kokoro <noreply+kokoro@google.com> Reviewed-by: Ben Clayton <bclayton@google.com> Commit-Queue: Shrek Shao <shrekshao@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn index 2ced868..491fd7d 100644 --- a/src/tint/BUILD.gn +++ b/src/tint/BUILD.gn
@@ -550,6 +550,8 @@ "transform/substitute_override.h", "transform/transform.cc", "transform/transform.h", + "transform/truncate_interstage_variables.cc", + "transform/truncate_interstage_variables.h", "transform/unshadow.cc", "transform/unshadow.h", "transform/utils/get_insertion_point.cc", @@ -1248,6 +1250,7 @@ "transform/substitute_override_test.cc", "transform/test_helper.h", "transform/transform_test.cc", + "transform/truncate_interstage_variables_test.cc", "transform/unshadow_test.cc", "transform/utils/get_insertion_point_test.cc", "transform/utils/hoist_to_decl_before_test.cc",
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt index 2521260..16ba8cf 100644 --- a/src/tint/CMakeLists.txt +++ b/src/tint/CMakeLists.txt
@@ -475,6 +475,8 @@ transform/substitute_override.h transform/transform.cc transform/transform.h + transform/truncate_interstage_variables.cc + transform/truncate_interstage_variables.h transform/unshadow.cc transform/unshadow.h transform/utils/get_insertion_point.cc @@ -1209,6 +1211,7 @@ transform/std140_test.cc transform/substitute_override_test.cc transform/test_helper.h + transform/truncate_interstage_variables_test.cc transform/unshadow_test.cc transform/var_for_dynamic_index_test.cc transform/vectorize_matrix_conversions_test.cc
diff --git a/src/tint/transform/truncate_interstage_variables.cc b/src/tint/transform/truncate_interstage_variables.cc new file mode 100644 index 0000000..a5e7256 --- /dev/null +++ b/src/tint/transform/truncate_interstage_variables.cc
@@ -0,0 +1,199 @@ +// Copyright 2022 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/tint/transform/truncate_interstage_variables.h" + +#include <memory> +#include <string> +#include <utility> + +#include "src/tint/program_builder.h" +#include "src/tint/sem/call.h" +#include "src/tint/sem/function.h" +#include "src/tint/sem/member_accessor_expression.h" +#include "src/tint/sem/statement.h" +#include "src/tint/sem/variable.h" +#include "src/tint/text/unicode.h" + +TINT_INSTANTIATE_TYPEINFO(tint::transform::TruncateInterstageVariables); +TINT_INSTANTIATE_TYPEINFO(tint::transform::TruncateInterstageVariables::Config); + +namespace tint::transform { + +namespace { + +struct TruncatedStructAndConverter { + /// The symbol of the truncated structure. + Symbol truncated_struct; + /// The symbol of the helper function that takes the original structure as a single argument and + /// returns the truncated structure type. + Symbol truncate_fn; +}; + +} // anonymous namespace + +TruncateInterstageVariables::TruncateInterstageVariables() = default; +TruncateInterstageVariables::~TruncateInterstageVariables() = default; + +Transform::ApplyResult TruncateInterstageVariables::Apply(const Program* src, + const DataMap& config, + DataMap&) const { + ProgramBuilder b; + CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; + + const auto* data = config.Get<Config>(); + if (data == nullptr) { + b.Diagnostics().add_error( + diag::System::Transform, + "missing transform data for " + + std::string(TypeInfo::Of<TruncateInterstageVariables>().name)); + return Program(std::move(b)); + } + + auto& sem = ctx.src->Sem(); + + bool should_run = false; + + utils::Hashmap<const sem::Function*, Symbol, 4u> entry_point_functions_to_truncate_functions; + utils::Hashmap<const sem::Struct*, TruncatedStructAndConverter, 4u> + old_shader_io_structs_to_new_struct_and_truncate_functions; + + for (auto* func_ast : ctx.src->AST().Functions()) { + if (!func_ast->IsEntryPoint()) { + continue; + } + + if (func_ast->PipelineStage() != ast::PipelineStage::kVertex) { + // Currently only vertex stage could have interstage output variables that need + // truncated. + continue; + } + + auto* func_sem = sem.Get(func_ast); + auto* str = func_sem->ReturnType()->As<sem::Struct>(); + + if (!str) { + TINT_ICE(Transform, ctx.dst->Diagnostics()) + << "Entrypoint function return type is non-struct.\n" + << "TruncateInterstageVariables transform needs to run after " + "CanonicalizeEntryPointIO transform."; + continue; + } + + // This transform is run after CanonicalizeEntryPointIO transform, + // So it is guaranteed that entry point inputs are already grouped in a struct. + const ast::Struct* struct_ty = str->Declaration(); + + // A prepass to check if any interstage variable locations in the entry point needs + // truncating. If not we don't really need to handle this entry point. + utils::Hashset<const sem::StructMember*, 16u> omit_members; + + for (auto* member : struct_ty->members) { + if (ast::GetAttribute<ast::LocationAttribute>(member->attributes)) { + auto* m = sem.Get(member); + uint32_t location = m->Location().value(); + if (!data->interstage_locations.test(location)) { + omit_members.Add(m); + } + } + } + + if (omit_members.IsEmpty()) { + continue; + } + + // Now we are sure the transform needs to be run. + should_run = true; + + // Get or create a new truncated struct/truncate function for the interstage inputs & + // outputs. + auto entry = + old_shader_io_structs_to_new_struct_and_truncate_functions.GetOrCreate(str, [&] { + auto new_struct_sym = b.Symbols().New(); + + utils::Vector<const ast::StructMember*, 20> truncated_members; + utils::Vector<const ast::Expression*, 20> initializer_exprs; + + for (auto* member : str->Members()) { + if (omit_members.Contains(member)) { + continue; + } + + truncated_members.Push(ctx.Clone(member->Declaration())); + initializer_exprs.Push( + b.MemberAccessor("io", ctx.Clone(member->Declaration()->symbol))); + } + + // Create the new shader io struct. + b.Structure(new_struct_sym, std::move(truncated_members)); + + // Create the mapping function to truncate the shader io. + auto mapping_fn_sym = b.Symbols().New("truncate_shader_output"); + b.Func(mapping_fn_sym, + utils::Vector{b.Param("io", ctx.Clone(func_ast->return_type))}, + b.ty.type_name(new_struct_sym), + utils::Vector{b.Return(b.Construct(b.ty.type_name(new_struct_sym), + std::move(initializer_exprs)))}); + return TruncatedStructAndConverter{new_struct_sym, mapping_fn_sym}; + }); + + ctx.Replace(func_ast->return_type, b.ty.type_name(entry.truncated_struct)); + + entry_point_functions_to_truncate_functions.Add(func_sem, entry.truncate_fn); + } + + if (!should_run) { + return SkipTransform; + } + + // Replace return statements with new truncated shader IO struct + ctx.ReplaceAll( + [&](const ast::ReturnStatement* return_statement) -> const ast::ReturnStatement* { + auto* return_sem = sem.Get(return_statement); + if (auto* mapping_fn_sym = + entry_point_functions_to_truncate_functions.Find(return_sem->Function())) { + return b.Return(return_statement->source, + b.Call(*mapping_fn_sym, ctx.Clone(return_statement->value))); + } + return nullptr; + }); + + // Remove IO attributes from old shader IO struct which is not used as entry point output + // anymore. + for (auto it : old_shader_io_structs_to_new_struct_and_truncate_functions) { + const ast::Struct* struct_ty = it.key->Declaration(); + for (auto* member : struct_ty->members) { + for (auto* attr : member->attributes) { + if (attr->IsAnyOf<ast::BuiltinAttribute, ast::LocationAttribute, + ast::InterpolateAttribute, ast::InvariantAttribute>()) { + ctx.Remove(member->attributes, attr); + } + } + } + } + + ctx.Clone(); + return Program(std::move(b)); +} + +TruncateInterstageVariables::Config::Config() = default; + +TruncateInterstageVariables::Config::Config(const Config&) = default; + +TruncateInterstageVariables::Config::~Config() = default; + +TruncateInterstageVariables::Config& TruncateInterstageVariables::Config::operator=(const Config&) = + default; + +} // namespace tint::transform
diff --git a/src/tint/transform/truncate_interstage_variables.h b/src/tint/transform/truncate_interstage_variables.h new file mode 100644 index 0000000..bed226c --- /dev/null +++ b/src/tint/transform/truncate_interstage_variables.h
@@ -0,0 +1,130 @@ +// Copyright 2022 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_TINT_TRANSFORM_TRUNCATE_INTERSTAGE_VARIABLES_H_ +#define SRC_TINT_TRANSFORM_TRUNCATE_INTERSTAGE_VARIABLES_H_ + +#include <bitset> + +#include "src/tint/sem/binding_point.h" +#include "src/tint/transform/transform.h" + +namespace tint::transform { + +/// TruncateInterstageVariables is a transform that truncate interstage variables. +/// It must be run after CanonicalizeEntryPointIO which guarantees all interstage variables of +/// a given entry point are grouped into one shader IO struct. +/// It replaces `original shader IO struct` with a `new wrapper struct` containing builtin IOs +/// and user-defined IO whose locations are marked in the interstage_locations bitset from the +/// config. The return statements of `original shader IO struct` are wrapped by a mapping function +/// that initializes the members of `new wrapper struct` with values from `original shader IO +/// struct`. IO attributes of members in `original shader IO struct` are removed, other attributes +/// still preserve. +/// +/// For example: +/// +/// ``` +/// struct ShaderIO { +/// @builtin(position) @invariant pos: vec4<f32>, +/// @location(1) f_1: f32, +/// @location(3) @align(16) f_3: f32, +/// @location(5) @interpolate(flat) @align(16) @size(16) f_5: u32, +/// } +/// @vertex +/// fn f() -> ShaderIO { +/// var io: ShaderIO; +/// io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0); +/// io.f_1 = 1.0; +/// io.f_3 = io.f_1 + 3.0; +/// io.f_5 = 1u; +/// return io; +/// } +/// ``` +/// +/// With config.interstage_locations[3] and [5] set to true, is transformed to: +/// +/// ``` +/// struct tint_symbol { +/// @builtin(position) @invariant +/// pos : vec4<f32>, +/// @location(3) @align(16) +/// f_3 : f32, +/// @location(5) @interpolate(flat) @align(16) @size(16) +/// f_5 : u32, +/// } +/// +/// fn truncate_shader_output(io : ShaderIO) -> tint_symbol { +/// return tint_symbol(io.pos, io.f_3, io.f_5); +/// } +/// +/// struct ShaderIO { +/// pos : vec4<f32>, +/// f_1 : f32, +/// @align(16) +/// f_3 : f32, +/// @align(16) @size(16) +/// f_5 : u32, +/// } +/// +/// @vertex +/// fn f() -> tint_symbol { +/// var io : ShaderIO; +/// io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0); +/// io.f_1 = 1.0; +/// io.f_3 = (io.f_1 + 3.0); +/// io.f_5 = 1u; +/// return truncate_shader_output(io); +/// } +/// ``` +/// +class TruncateInterstageVariables final : public Castable<TruncateInterstageVariables, Transform> { + public: + /// Configuration options for the transform + struct Config final : public Castable<Config, Data> { + /// Constructor + Config(); + + /// Copy constructor + Config(const Config&); + + /// Destructor + ~Config() override; + + /// Assignment operator + /// @returns this Config + Config& operator=(const Config&); + + /// Indicate which interstage io locations are actually used by the later stage. + /// There can be at most 16 user defined interstage variables with locations. + std::bitset<16> interstage_locations; + + /// Reflect the fields of this class so that it can be used by tint::ForeachField() + TINT_REFLECT(interstage_variables); + }; + + /// Constructor using a the configuration provided in the input Data + TruncateInterstageVariables(); + + /// Destructor + ~TruncateInterstageVariables() override; + + /// @copydoc Transform::Apply + ApplyResult Apply(const Program* program, + const DataMap& inputs, + DataMap& outputs) const override; +}; + +} // namespace tint::transform + +#endif // SRC_TINT_TRANSFORM_TRUNCATE_INTERSTAGE_VARIABLES_H_
diff --git a/src/tint/transform/truncate_interstage_variables_test.cc b/src/tint/transform/truncate_interstage_variables_test.cc new file mode 100644 index 0000000..9ab8fa2 --- /dev/null +++ b/src/tint/transform/truncate_interstage_variables_test.cc
@@ -0,0 +1,599 @@ +// Copyright 2022 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/tint/transform/truncate_interstage_variables.h" +#include "src/tint/transform/canonicalize_entry_point_io.h" + +#include "gmock/gmock.h" +#include "src/tint/transform/test_helper.h" + +namespace tint::transform { +namespace { + +using ::testing::ContainerEq; + +using TruncateInterstageVariablesTest = TransformTest; + +TEST_F(TruncateInterstageVariablesTest, ShouldRunVertex) { + auto* src = R"( +struct ShaderIO { + @builtin(position) pos: vec4<f32>, + @location(0) f_0: f32, + @location(2) f_2: f32, +} +@vertex +fn f() -> ShaderIO { + var io: ShaderIO; + io.f_0 = 1.0; + io.f_2 = io.f_2 + 3.0; + return io; +} +)"; + + { + auto* expect = + "error: missing transform data for " + "tint::transform::TruncateInterstageVariables"; + auto got = Run<TruncateInterstageVariables>(src); + EXPECT_EQ(expect, str(got)); + } + + { + // Empty interstage_locations: truncate all interstage variables, should run + TruncateInterstageVariables::Config cfg; + DataMap data; + data.Add<TruncateInterstageVariables::Config>(cfg); + EXPECT_TRUE(ShouldRun<TruncateInterstageVariables>(src, data)); + } + + { + // All existing interstage_locations are marked: should not run + TruncateInterstageVariables::Config cfg; + cfg.interstage_locations[0] = true; + cfg.interstage_locations[2] = true; + DataMap data; + data.Add<TruncateInterstageVariables::Config>(cfg); + EXPECT_FALSE(ShouldRun<TruncateInterstageVariables>(src, data)); + } + + { + // Partial interstage_locations are marked: should run + TruncateInterstageVariables::Config cfg; + cfg.interstage_locations[2] = true; + DataMap data; + data.Add<TruncateInterstageVariables::Config>(cfg); + EXPECT_TRUE(ShouldRun<TruncateInterstageVariables>(src, data)); + } +} + +TEST_F(TruncateInterstageVariablesTest, ShouldRunFragment) { + auto* src = R"( +struct ShaderIO { + @location(0) f_0: f32, + @location(2) f_2: f32, +} +@fragment +fn f(io: ShaderIO) -> @location(1) vec4<f32> { + return vec4<f32>(io.f_0, io.f_2, 0.0, 1.0); +} +)"; + + TruncateInterstageVariables::Config cfg; + cfg.interstage_locations[2] = true; + + DataMap data; + data.Add<TruncateInterstageVariables::Config>(cfg); + + EXPECT_FALSE(ShouldRun<TruncateInterstageVariables>(src, data)); +} + +// Test that this transform should run after canoicalize entry point io, where shader io is already +// grouped into a struct. +TEST_F(TruncateInterstageVariablesTest, ShouldRunAfterCanonicalizeEntryPointIO) { + auto* src = R"( +@vertex +fn f() -> @builtin(position) vec4<f32> { + return vec4<f32>(1.0, 1.0, 1.0, 1.0); +} +)"; + + TruncateInterstageVariables::Config cfg; + cfg.interstage_locations[0] = true; + DataMap data; + data.Add<TruncateInterstageVariables::Config>(cfg); + + data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl); + auto got = Run<CanonicalizeEntryPointIO>(src, data); + + // Inevitably entry point can write only one variable if not using struct + // So the truncate won't run. + EXPECT_FALSE(ShouldRun<TruncateInterstageVariables>(str(got), data)); +} + +TEST_F(TruncateInterstageVariablesTest, BasicVertexTrimLocationInMid) { + auto* src = R"( +struct ShaderIO { + @builtin(position) pos: vec4<f32>, + @location(1) f_1: f32, + @location(3) f_3: f32, +} +@vertex +fn f() -> ShaderIO { + var io: ShaderIO; + io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0); + io.f_1 = 1.0; + io.f_3 = io.f_1 + 3.0; + return io; +} +)"; + + auto* expect = R"( +struct tint_symbol { + @builtin(position) + pos : vec4<f32>, + @location(1) + f_1 : f32, +} + +fn truncate_shader_output(io : ShaderIO) -> tint_symbol { + return tint_symbol(io.pos, io.f_1); +} + +struct ShaderIO { + pos : vec4<f32>, + f_1 : f32, + f_3 : f32, +} + +@vertex +fn f() -> tint_symbol { + var io : ShaderIO; + io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0); + io.f_1 = 1.0; + io.f_3 = (io.f_1 + 3.0); + return truncate_shader_output(io); +} +)"; + + TruncateInterstageVariables::Config cfg; + // fragment has input at @location(1) + cfg.interstage_locations[1] = true; + + DataMap data; + data.Add<TruncateInterstageVariables::Config>(cfg); + + auto got = Run<TruncateInterstageVariables>(src, data); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(TruncateInterstageVariablesTest, BasicVertexTrimLocationAtEnd) { + auto* src = R"( +struct ShaderIO { + @builtin(position) pos: vec4<f32>, + @location(1) f_1: f32, + @location(3) f_3: f32, +} +@vertex +fn f() -> ShaderIO { + var io: ShaderIO; + io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0); + io.f_1 = 1.0; + io.f_3 = io.f_1 + 3.0; + return io; +} +)"; + + auto* expect = R"( +struct tint_symbol { + @builtin(position) + pos : vec4<f32>, + @location(3) + f_3 : f32, +} + +fn truncate_shader_output(io : ShaderIO) -> tint_symbol { + return tint_symbol(io.pos, io.f_3); +} + +struct ShaderIO { + pos : vec4<f32>, + f_1 : f32, + f_3 : f32, +} + +@vertex +fn f() -> tint_symbol { + var io : ShaderIO; + io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0); + io.f_1 = 1.0; + io.f_3 = (io.f_1 + 3.0); + return truncate_shader_output(io); +} +)"; + + TruncateInterstageVariables::Config cfg; + // fragment has input at @location(3) + cfg.interstage_locations[3] = true; + + DataMap data; + data.Add<TruncateInterstageVariables::Config>(cfg); + + auto got = Run<TruncateInterstageVariables>(src, data); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(TruncateInterstageVariablesTest, TruncateAllLocations) { + auto* src = R"( +struct ShaderIO { + @builtin(position) pos: vec4<f32>, + @location(1) f_1: f32, + @location(3) f_3: f32, +} +@vertex +fn f() -> ShaderIO { + var io: ShaderIO; + io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0); + io.f_1 = 1.0; + io.f_3 = io.f_1 + 3.0; + return io; +} +)"; + + { + auto* expect = R"( +struct tint_symbol { + @builtin(position) + pos : vec4<f32>, +} + +fn truncate_shader_output(io : ShaderIO) -> tint_symbol { + return tint_symbol(io.pos); +} + +struct ShaderIO { + pos : vec4<f32>, + f_1 : f32, + f_3 : f32, +} + +@vertex +fn f() -> tint_symbol { + var io : ShaderIO; + io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0); + io.f_1 = 1.0; + io.f_3 = (io.f_1 + 3.0); + return truncate_shader_output(io); +} +)"; + + TruncateInterstageVariables::Config cfg; + DataMap data; + data.Add<TruncateInterstageVariables::Config>(cfg); + + auto got = Run<TruncateInterstageVariables>(src, data); + + EXPECT_EQ(expect, str(got)); + } +} + +// Test that the transform only removes IO attributes and preserve other attributes in the old +// Shader IO struct. +TEST_F(TruncateInterstageVariablesTest, RemoveIOAttributes) { + auto* src = R"( +struct ShaderIO { + @builtin(position) @invariant pos: vec4<f32>, + @location(1) f_1: f32, + @location(3) @align(16) f_3: f32, + @location(5) @interpolate(flat) @align(16) @size(16) f_5: u32, +} +@vertex +fn f() -> ShaderIO { + var io: ShaderIO; + io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0); + io.f_1 = 1.0; + io.f_3 = io.f_1 + 3.0; + io.f_5 = 1u; + return io; +} +)"; + + { + auto* expect = R"( +struct tint_symbol { + @builtin(position) @invariant + pos : vec4<f32>, + @location(3) @align(16) + f_3 : f32, + @location(5) @interpolate(flat) @align(16) @size(16) + f_5 : u32, +} + +fn truncate_shader_output(io : ShaderIO) -> tint_symbol { + return tint_symbol(io.pos, io.f_3, io.f_5); +} + +struct ShaderIO { + pos : vec4<f32>, + f_1 : f32, + @align(16) + f_3 : f32, + @align(16) @size(16) + f_5 : u32, +} + +@vertex +fn f() -> tint_symbol { + var io : ShaderIO; + io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0); + io.f_1 = 1.0; + io.f_3 = (io.f_1 + 3.0); + io.f_5 = 1u; + return truncate_shader_output(io); +} +)"; + + TruncateInterstageVariables::Config cfg; + // Missing @location[1] intentionally to make sure the transform run. + cfg.interstage_locations[3] = true; + cfg.interstage_locations[5] = true; + + DataMap data; + data.Add<TruncateInterstageVariables::Config>(cfg); + + auto got = Run<TruncateInterstageVariables>(src, data); + + EXPECT_EQ(expect, str(got)); + } +} + +TEST_F(TruncateInterstageVariablesTest, MultipleEntryPointsSharingStruct) { + auto* src = R"( +struct ShaderIO { + @builtin(position) pos: vec4<f32>, + @location(1) f_1: f32, + @location(3) f_3: f32, + @location(5) f_5: f32, +} + +@vertex +fn f1() -> ShaderIO { + var io: ShaderIO; + io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0); + io.f_1 = 1.0; + io.f_3 = 1.0; + return io; +} + +@vertex +fn f2() -> ShaderIO { + var io: ShaderIO; + io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0); + io.f_5 = 2.0; + return io; +} +)"; + + auto* expect = R"( +struct tint_symbol { + @builtin(position) + pos : vec4<f32>, + @location(3) + f_3 : f32, +} + +fn truncate_shader_output(io : ShaderIO) -> tint_symbol { + return tint_symbol(io.pos, io.f_3); +} + +struct ShaderIO { + pos : vec4<f32>, + f_1 : f32, + f_3 : f32, + f_5 : f32, +} + +@vertex +fn f1() -> tint_symbol { + var io : ShaderIO; + io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0); + io.f_1 = 1.0; + io.f_3 = 1.0; + return truncate_shader_output(io); +} + +@vertex +fn f2() -> tint_symbol { + var io : ShaderIO; + io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0); + io.f_5 = 2.0; + return truncate_shader_output(io); +} +)"; + + TruncateInterstageVariables::Config cfg; + // fragment has input at @location(3) + cfg.interstage_locations[3] = true; + + DataMap data; + data.Add<TruncateInterstageVariables::Config>(cfg); + + auto got = Run<TruncateInterstageVariables>(src, data); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(TruncateInterstageVariablesTest, MultipleEntryPoints) { + auto* src = R"( +struct ShaderIO1 { + @builtin(position) pos: vec4<f32>, + @location(1) f_1: f32, + @location(3) f_3: f32, + @location(5) f_5: f32, +} + +@vertex +fn f1() -> ShaderIO1 { + var io: ShaderIO1; + io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0); + io.f_1 = 1.0; + io.f_3 = 1.0; + return io; +} + +struct ShaderIO2 { + @builtin(position) pos: vec4<f32>, + @location(2) f_2: f32, +} + +@vertex +fn f2() -> ShaderIO2 { + var io: ShaderIO2; + io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0); + io.f_2 = 2.0; + return io; +} +)"; + + auto* expect = R"( +struct tint_symbol { + @builtin(position) + pos : vec4<f32>, + @location(3) + f_3 : f32, +} + +fn truncate_shader_output(io : ShaderIO1) -> tint_symbol { + return tint_symbol(io.pos, io.f_3); +} + +struct tint_symbol_1 { + @builtin(position) + pos : vec4<f32>, +} + +fn truncate_shader_output_1(io : ShaderIO2) -> tint_symbol_1 { + return tint_symbol_1(io.pos); +} + +struct ShaderIO1 { + pos : vec4<f32>, + f_1 : f32, + f_3 : f32, + f_5 : f32, +} + +@vertex +fn f1() -> tint_symbol { + var io : ShaderIO1; + io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0); + io.f_1 = 1.0; + io.f_3 = 1.0; + return truncate_shader_output(io); +} + +struct ShaderIO2 { + pos : vec4<f32>, + f_2 : f32, +} + +@vertex +fn f2() -> tint_symbol_1 { + var io : ShaderIO2; + io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0); + io.f_2 = 2.0; + return truncate_shader_output_1(io); +} +)"; + + TruncateInterstageVariables::Config cfg; + // fragment has input at @location(3) + cfg.interstage_locations[3] = true; + + DataMap data; + data.Add<TruncateInterstageVariables::Config>(cfg); + + auto got = Run<TruncateInterstageVariables>(src, data); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(TruncateInterstageVariablesTest, MultipleReturnStatements) { + auto* src = R"( +struct ShaderIO { + @builtin(position) pos: vec4<f32>, + @location(1) f_1: f32, + @location(3) f_3: f32, +} +@vertex +fn f(@builtin(vertex_index) vid: u32) -> ShaderIO { + var io: ShaderIO; + if (vid > 10u) { + io.f_1 = 2.0; + return io; + } + io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0); + io.f_1 = 1.0; + io.f_3 = io.f_1 + 3.0; + return io; +} +)"; + + auto* expect = R"( +struct tint_symbol { + @builtin(position) + pos : vec4<f32>, + @location(3) + f_3 : f32, +} + +fn truncate_shader_output(io : ShaderIO) -> tint_symbol { + return tint_symbol(io.pos, io.f_3); +} + +struct ShaderIO { + pos : vec4<f32>, + f_1 : f32, + f_3 : f32, +} + +@vertex +fn f(@builtin(vertex_index) vid : u32) -> tint_symbol { + var io : ShaderIO; + if ((vid > 10u)) { + io.f_1 = 2.0; + return truncate_shader_output(io); + } + io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0); + io.f_1 = 1.0; + io.f_3 = (io.f_1 + 3.0); + return truncate_shader_output(io); +} +)"; + + TruncateInterstageVariables::Config cfg; + // fragment has input at @location(3) + cfg.interstage_locations[3] = true; + + DataMap data; + data.Add<TruncateInterstageVariables::Config>(cfg); + + auto got = Run<TruncateInterstageVariables>(src, data); + + EXPECT_EQ(expect, str(got)); +} + +} // namespace +} // namespace tint::transform
diff --git a/src/tint/utils/bitset.h b/src/tint/utils/bitset.h index 86dccd6..a37f9a2 100644 --- a/src/tint/utils/bitset.h +++ b/src/tint/utils/bitset.h
@@ -92,6 +92,25 @@ return Bit{word, mask}; } + /// Const index operator + /// @param index the index of the bit to access + /// @return bool value of the indexed bit + bool operator[](size_t index) const { + const auto& word = vec_[index / kWordBits]; + auto mask = static_cast<Word>(1) << (index % kWordBits); + return word & mask; + } + + /// @returns true iff the all bits are unset (0) + bool AllBitsZero() const { + for (auto word : vec_) { + if (word) { + return false; + } + } + return true; + } + private: Vector<size_t, NumWords(N)> vec_; size_t len_ = 0;
diff --git a/src/tint/utils/bitset_test.cc b/src/tint/utils/bitset_test.cc index b07cf74..550d946 100644 --- a/src/tint/utils/bitset_test.cc +++ b/src/tint/utils/bitset_test.cc
@@ -26,6 +26,32 @@ EXPECT_EQ(bits.Length(), 100u); } +TEST(Bitset, AllBitsZero) { + Bitset<8> bits; + EXPECT_TRUE(bits.AllBitsZero()); + + bits.Resize(4u); + EXPECT_TRUE(bits.AllBitsZero()); + + bits.Resize(100u); + EXPECT_TRUE(bits.AllBitsZero()); + + bits[63] = true; + EXPECT_FALSE(bits.AllBitsZero()); + + bits.Resize(60); + EXPECT_TRUE(bits.AllBitsZero()); + + bits.Resize(64); + EXPECT_TRUE(bits.AllBitsZero()); + + bits[4] = true; + EXPECT_FALSE(bits.AllBitsZero()); + + bits.Resize(8); + EXPECT_FALSE(bits.AllBitsZero()); +} + TEST(Bitset, InitCleared_NoSpill) { Bitset<256> bits; bits.Resize(256);