[shuffle] Move ast to new structure.
This CL moves the `src/tint/ast` folder into the `src/tint/lang/wgsl/ast`
folder and updates the includes to match. The namespaces and build groups
are not updated in this CL; it only performs the code move.
Bug: tint:1988
Change-Id: Iff69e4d02f55d2c7f1fcd279715ddd81d3ede4b1
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/142003
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/lang/glsl/ast_writer/generator.cc b/src/tint/lang/glsl/ast_writer/generator.cc
index 146b386..7db15ba 100644
--- a/src/tint/lang/glsl/ast_writer/generator.cc
+++ b/src/tint/lang/glsl/ast_writer/generator.cc
@@ -14,9 +14,9 @@
#include "src/tint/lang/glsl/ast_writer/generator.h"
-#include "src/tint/ast/transform/binding_remapper.h"
-#include "src/tint/ast/transform/combine_samplers.h"
#include "src/tint/lang/glsl/ast_writer/generator_impl.h"
+#include "src/tint/lang/wgsl/ast/transform/binding_remapper.h"
+#include "src/tint/lang/wgsl/ast/transform/combine_samplers.h"
namespace tint::writer::glsl {
diff --git a/src/tint/lang/glsl/ast_writer/generator.h b/src/tint/lang/glsl/ast_writer/generator.h
index 5e32660..039d75a 100644
--- a/src/tint/lang/glsl/ast_writer/generator.h
+++ b/src/tint/lang/glsl/ast_writer/generator.h
@@ -21,9 +21,9 @@
#include <utility>
#include <vector>
-#include "src/tint/ast/pipeline_stage.h"
#include "src/tint/builtin/access.h"
#include "src/tint/lang/glsl/ast_writer/version.h"
+#include "src/tint/lang/wgsl/ast/pipeline_stage.h"
#include "src/tint/sem/binding_point.h"
#include "src/tint/sem/sampler_texture_pair.h"
#include "src/tint/writer/external_texture_options.h"
diff --git a/src/tint/lang/glsl/ast_writer/generator_bench.cc b/src/tint/lang/glsl/ast_writer/generator_bench.cc
index 07a7cdf..43345fb 100644
--- a/src/tint/lang/glsl/ast_writer/generator_bench.cc
+++ b/src/tint/lang/glsl/ast_writer/generator_bench.cc
@@ -14,9 +14,9 @@
#include <string>
-#include "src/tint/ast/identifier.h"
-#include "src/tint/ast/module.h"
#include "src/tint/bench/benchmark.h"
+#include "src/tint/lang/wgsl/ast/identifier.h"
+#include "src/tint/lang/wgsl/ast/module.h"
namespace tint::writer::glsl {
namespace {
diff --git a/src/tint/lang/glsl/ast_writer/generator_impl.cc b/src/tint/lang/glsl/ast_writer/generator_impl.cc
index d996deb..a877cb4 100644
--- a/src/tint/lang/glsl/ast_writer/generator_impl.cc
+++ b/src/tint/lang/glsl/ast_writer/generator_impl.cc
@@ -22,39 +22,39 @@
#include <utility>
#include <vector>
-#include "src/tint/ast/call_statement.h"
-#include "src/tint/ast/id_attribute.h"
-#include "src/tint/ast/internal_attribute.h"
-#include "src/tint/ast/interpolate_attribute.h"
-#include "src/tint/ast/transform/add_block_attribute.h"
-#include "src/tint/ast/transform/add_empty_entry_point.h"
-#include "src/tint/ast/transform/binding_remapper.h"
-#include "src/tint/ast/transform/builtin_polyfill.h"
-#include "src/tint/ast/transform/canonicalize_entry_point_io.h"
-#include "src/tint/ast/transform/combine_samplers.h"
-#include "src/tint/ast/transform/decompose_memory_access.h"
-#include "src/tint/ast/transform/demote_to_helper.h"
-#include "src/tint/ast/transform/direct_variable_access.h"
-#include "src/tint/ast/transform/disable_uniformity_analysis.h"
-#include "src/tint/ast/transform/expand_compound_assignment.h"
-#include "src/tint/ast/transform/multiplanar_external_texture.h"
-#include "src/tint/ast/transform/pad_structs.h"
-#include "src/tint/ast/transform/preserve_padding.h"
-#include "src/tint/ast/transform/promote_initializers_to_let.h"
-#include "src/tint/ast/transform/promote_side_effects_to_decl.h"
-#include "src/tint/ast/transform/remove_phonies.h"
-#include "src/tint/ast/transform/renamer.h"
-#include "src/tint/ast/transform/robustness.h"
-#include "src/tint/ast/transform/simplify_pointers.h"
-#include "src/tint/ast/transform/single_entry_point.h"
-#include "src/tint/ast/transform/std140.h"
-#include "src/tint/ast/transform/texture_1d_to_2d.h"
-#include "src/tint/ast/transform/unshadow.h"
-#include "src/tint/ast/transform/zero_init_workgroup_memory.h"
-#include "src/tint/ast/variable_decl_statement.h"
#include "src/tint/constant/splat.h"
#include "src/tint/constant/value.h"
#include "src/tint/debug.h"
+#include "src/tint/lang/wgsl/ast/call_statement.h"
+#include "src/tint/lang/wgsl/ast/id_attribute.h"
+#include "src/tint/lang/wgsl/ast/internal_attribute.h"
+#include "src/tint/lang/wgsl/ast/interpolate_attribute.h"
+#include "src/tint/lang/wgsl/ast/transform/add_block_attribute.h"
+#include "src/tint/lang/wgsl/ast/transform/add_empty_entry_point.h"
+#include "src/tint/lang/wgsl/ast/transform/binding_remapper.h"
+#include "src/tint/lang/wgsl/ast/transform/builtin_polyfill.h"
+#include "src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.h"
+#include "src/tint/lang/wgsl/ast/transform/combine_samplers.h"
+#include "src/tint/lang/wgsl/ast/transform/decompose_memory_access.h"
+#include "src/tint/lang/wgsl/ast/transform/demote_to_helper.h"
+#include "src/tint/lang/wgsl/ast/transform/direct_variable_access.h"
+#include "src/tint/lang/wgsl/ast/transform/disable_uniformity_analysis.h"
+#include "src/tint/lang/wgsl/ast/transform/expand_compound_assignment.h"
+#include "src/tint/lang/wgsl/ast/transform/multiplanar_external_texture.h"
+#include "src/tint/lang/wgsl/ast/transform/pad_structs.h"
+#include "src/tint/lang/wgsl/ast/transform/preserve_padding.h"
+#include "src/tint/lang/wgsl/ast/transform/promote_initializers_to_let.h"
+#include "src/tint/lang/wgsl/ast/transform/promote_side_effects_to_decl.h"
+#include "src/tint/lang/wgsl/ast/transform/remove_phonies.h"
+#include "src/tint/lang/wgsl/ast/transform/renamer.h"
+#include "src/tint/lang/wgsl/ast/transform/robustness.h"
+#include "src/tint/lang/wgsl/ast/transform/simplify_pointers.h"
+#include "src/tint/lang/wgsl/ast/transform/single_entry_point.h"
+#include "src/tint/lang/wgsl/ast/transform/std140.h"
+#include "src/tint/lang/wgsl/ast/transform/texture_1d_to_2d.h"
+#include "src/tint/lang/wgsl/ast/transform/unshadow.h"
+#include "src/tint/lang/wgsl/ast/transform/zero_init_workgroup_memory.h"
+#include "src/tint/lang/wgsl/ast/variable_decl_statement.h"
#include "src/tint/sem/block_statement.h"
#include "src/tint/sem/call.h"
#include "src/tint/sem/function.h"
diff --git a/src/tint/lang/glsl/ast_writer/generator_impl.h b/src/tint/lang/glsl/ast_writer/generator_impl.h
index aa4f87e..91cd6ab 100644
--- a/src/tint/lang/glsl/ast_writer/generator_impl.h
+++ b/src/tint/lang/glsl/ast_writer/generator_impl.h
@@ -21,21 +21,21 @@
#include <unordered_set>
#include <utility>
-#include "src/tint/ast/assignment_statement.h"
-#include "src/tint/ast/bitcast_expression.h"
-#include "src/tint/ast/break_statement.h"
-#include "src/tint/ast/continue_statement.h"
-#include "src/tint/ast/discard_statement.h"
-#include "src/tint/ast/for_loop_statement.h"
-#include "src/tint/ast/if_statement.h"
-#include "src/tint/ast/loop_statement.h"
-#include "src/tint/ast/return_statement.h"
-#include "src/tint/ast/switch_statement.h"
-#include "src/tint/ast/transform/decompose_memory_access.h"
-#include "src/tint/ast/unary_op_expression.h"
#include "src/tint/builtin/builtin_value.h"
#include "src/tint/lang/glsl/ast_writer/generator.h"
#include "src/tint/lang/glsl/ast_writer/version.h"
+#include "src/tint/lang/wgsl/ast/assignment_statement.h"
+#include "src/tint/lang/wgsl/ast/bitcast_expression.h"
+#include "src/tint/lang/wgsl/ast/break_statement.h"
+#include "src/tint/lang/wgsl/ast/continue_statement.h"
+#include "src/tint/lang/wgsl/ast/discard_statement.h"
+#include "src/tint/lang/wgsl/ast/for_loop_statement.h"
+#include "src/tint/lang/wgsl/ast/if_statement.h"
+#include "src/tint/lang/wgsl/ast/loop_statement.h"
+#include "src/tint/lang/wgsl/ast/return_statement.h"
+#include "src/tint/lang/wgsl/ast/switch_statement.h"
+#include "src/tint/lang/wgsl/ast/transform/decompose_memory_access.h"
+#include "src/tint/lang/wgsl/ast/unary_op_expression.h"
#include "src/tint/program_builder.h"
#include "src/tint/scope_stack.h"
#include "src/tint/utils/hash.h"
diff --git a/src/tint/lang/glsl/ast_writer/generator_impl_binary_test.cc b/src/tint/lang/glsl/ast_writer/generator_impl_binary_test.cc
index 0c77852..2afea56 100644
--- a/src/tint/lang/glsl/ast_writer/generator_impl_binary_test.cc
+++ b/src/tint/lang/glsl/ast_writer/generator_impl_binary_test.cc
@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/call_statement.h"
-#include "src/tint/ast/variable_decl_statement.h"
#include "src/tint/lang/glsl/ast_writer/test_helper.h"
+#include "src/tint/lang/wgsl/ast/call_statement.h"
+#include "src/tint/lang/wgsl/ast/variable_decl_statement.h"
#include "src/tint/utils/string_stream.h"
#include "gmock/gmock.h"
diff --git a/src/tint/lang/glsl/ast_writer/generator_impl_builtin_test.cc b/src/tint/lang/glsl/ast_writer/generator_impl_builtin_test.cc
index 76c6e97..4ea0f0d 100644
--- a/src/tint/lang/glsl/ast_writer/generator_impl_builtin_test.cc
+++ b/src/tint/lang/glsl/ast_writer/generator_impl_builtin_test.cc
@@ -13,9 +13,9 @@
// limitations under the License.
#include "gmock/gmock.h"
-#include "src/tint/ast/call_statement.h"
-#include "src/tint/ast/stage_attribute.h"
#include "src/tint/lang/glsl/ast_writer/test_helper.h"
+#include "src/tint/lang/wgsl/ast/call_statement.h"
+#include "src/tint/lang/wgsl/ast/stage_attribute.h"
#include "src/tint/sem/call.h"
#include "src/tint/utils/string_stream.h"
diff --git a/src/tint/lang/glsl/ast_writer/generator_impl_builtin_texture_test.cc b/src/tint/lang/glsl/ast_writer/generator_impl_builtin_texture_test.cc
index 41de2eb..f6cdaa0 100644
--- a/src/tint/lang/glsl/ast_writer/generator_impl_builtin_texture_test.cc
+++ b/src/tint/lang/glsl/ast_writer/generator_impl_builtin_texture_test.cc
@@ -13,10 +13,10 @@
// limitations under the License.
#include "gmock/gmock.h"
-#include "src/tint/ast/builtin_texture_helper_test.h"
-#include "src/tint/ast/call_statement.h"
-#include "src/tint/ast/stage_attribute.h"
#include "src/tint/lang/glsl/ast_writer/test_helper.h"
+#include "src/tint/lang/wgsl/ast/builtin_texture_helper_test.h"
+#include "src/tint/lang/wgsl/ast/call_statement.h"
+#include "src/tint/lang/wgsl/ast/stage_attribute.h"
namespace tint::writer::glsl {
namespace {
diff --git a/src/tint/lang/glsl/ast_writer/generator_impl_call_test.cc b/src/tint/lang/glsl/ast_writer/generator_impl_call_test.cc
index 76ed9f0..6c1b548 100644
--- a/src/tint/lang/glsl/ast_writer/generator_impl_call_test.cc
+++ b/src/tint/lang/glsl/ast_writer/generator_impl_call_test.cc
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/call_statement.h"
#include "src/tint/lang/glsl/ast_writer/test_helper.h"
+#include "src/tint/lang/wgsl/ast/call_statement.h"
#include "src/tint/utils/string_stream.h"
#include "gmock/gmock.h"
diff --git a/src/tint/lang/glsl/ast_writer/generator_impl_function_test.cc b/src/tint/lang/glsl/ast_writer/generator_impl_function_test.cc
index 7ad94c6..23e91a2 100644
--- a/src/tint/lang/glsl/ast_writer/generator_impl_function_test.cc
+++ b/src/tint/lang/glsl/ast_writer/generator_impl_function_test.cc
@@ -13,10 +13,10 @@
// limitations under the License.
#include "gmock/gmock.h"
-#include "src/tint/ast/stage_attribute.h"
-#include "src/tint/ast/variable_decl_statement.h"
-#include "src/tint/ast/workgroup_attribute.h"
#include "src/tint/lang/glsl/ast_writer/test_helper.h"
+#include "src/tint/lang/wgsl/ast/stage_attribute.h"
+#include "src/tint/lang/wgsl/ast/variable_decl_statement.h"
+#include "src/tint/lang/wgsl/ast/workgroup_attribute.h"
using ::testing::HasSubstr;
diff --git a/src/tint/lang/glsl/ast_writer/generator_impl_loop_test.cc b/src/tint/lang/glsl/ast_writer/generator_impl_loop_test.cc
index ac464cf..73b2cad 100644
--- a/src/tint/lang/glsl/ast_writer/generator_impl_loop_test.cc
+++ b/src/tint/lang/glsl/ast_writer/generator_impl_loop_test.cc
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/variable_decl_statement.h"
#include "src/tint/lang/glsl/ast_writer/test_helper.h"
+#include "src/tint/lang/wgsl/ast/variable_decl_statement.h"
#include "gmock/gmock.h"
diff --git a/src/tint/lang/glsl/ast_writer/generator_impl_member_accessor_test.cc b/src/tint/lang/glsl/ast_writer/generator_impl_member_accessor_test.cc
index 9bb2ea7..84a1a97 100644
--- a/src/tint/lang/glsl/ast_writer/generator_impl_member_accessor_test.cc
+++ b/src/tint/lang/glsl/ast_writer/generator_impl_member_accessor_test.cc
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/stage_attribute.h"
#include "src/tint/lang/glsl/ast_writer/test_helper.h"
+#include "src/tint/lang/wgsl/ast/stage_attribute.h"
#include "gmock/gmock.h"
diff --git a/src/tint/lang/glsl/ast_writer/generator_impl_module_constant_test.cc b/src/tint/lang/glsl/ast_writer/generator_impl_module_constant_test.cc
index 850a3fc..672d359 100644
--- a/src/tint/lang/glsl/ast_writer/generator_impl_module_constant_test.cc
+++ b/src/tint/lang/glsl/ast_writer/generator_impl_module_constant_test.cc
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/id_attribute.h"
#include "src/tint/lang/glsl/ast_writer/test_helper.h"
+#include "src/tint/lang/wgsl/ast/id_attribute.h"
#include "gmock/gmock.h"
diff --git a/src/tint/lang/glsl/ast_writer/generator_impl_sanitizer_test.cc b/src/tint/lang/glsl/ast_writer/generator_impl_sanitizer_test.cc
index f087225..8e7df8c 100644
--- a/src/tint/lang/glsl/ast_writer/generator_impl_sanitizer_test.cc
+++ b/src/tint/lang/glsl/ast_writer/generator_impl_sanitizer_test.cc
@@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/call_statement.h"
-#include "src/tint/ast/stage_attribute.h"
-#include "src/tint/ast/variable_decl_statement.h"
#include "src/tint/lang/glsl/ast_writer/test_helper.h"
+#include "src/tint/lang/wgsl/ast/call_statement.h"
+#include "src/tint/lang/wgsl/ast/stage_attribute.h"
+#include "src/tint/lang/wgsl/ast/variable_decl_statement.h"
#include "gmock/gmock.h"
diff --git a/src/tint/lang/glsl/ast_writer/generator_impl_type_test.cc b/src/tint/lang/glsl/ast_writer/generator_impl_type_test.cc
index bebadda..7a04894 100644
--- a/src/tint/lang/glsl/ast_writer/generator_impl_type_test.cc
+++ b/src/tint/lang/glsl/ast_writer/generator_impl_type_test.cc
@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/call_statement.h"
-#include "src/tint/ast/stage_attribute.h"
#include "src/tint/lang/glsl/ast_writer/test_helper.h"
+#include "src/tint/lang/wgsl/ast/call_statement.h"
+#include "src/tint/lang/wgsl/ast/stage_attribute.h"
#include "src/tint/type/depth_texture.h"
#include "src/tint/type/multisampled_texture.h"
#include "src/tint/type/sampled_texture.h"
diff --git a/src/tint/lang/glsl/ast_writer/generator_impl_variable_decl_statement_test.cc b/src/tint/lang/glsl/ast_writer/generator_impl_variable_decl_statement_test.cc
index 8ff35a7..ce87f7c 100644
--- a/src/tint/lang/glsl/ast_writer/generator_impl_variable_decl_statement_test.cc
+++ b/src/tint/lang/glsl/ast_writer/generator_impl_variable_decl_statement_test.cc
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/variable_decl_statement.h"
#include "src/tint/lang/glsl/ast_writer/test_helper.h"
+#include "src/tint/lang/wgsl/ast/variable_decl_statement.h"
#include "gmock/gmock.h"
diff --git a/src/tint/lang/glsl/ast_writer/generator_impl_workgroup_var_test.cc b/src/tint/lang/glsl/ast_writer/generator_impl_workgroup_var_test.cc
index b3a1c35..eb95abb 100644
--- a/src/tint/lang/glsl/ast_writer/generator_impl_workgroup_var_test.cc
+++ b/src/tint/lang/glsl/ast_writer/generator_impl_workgroup_var_test.cc
@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/id_attribute.h"
-#include "src/tint/ast/stage_attribute.h"
#include "src/tint/lang/glsl/ast_writer/test_helper.h"
+#include "src/tint/lang/wgsl/ast/id_attribute.h"
+#include "src/tint/lang/wgsl/ast/stage_attribute.h"
#include "gmock/gmock.h"
diff --git a/src/tint/lang/hlsl/ast_writer/generator.h b/src/tint/lang/hlsl/ast_writer/generator.h
index 7ade099..7f74834 100644
--- a/src/tint/lang/hlsl/ast_writer/generator.h
+++ b/src/tint/lang/hlsl/ast_writer/generator.h
@@ -23,7 +23,7 @@
#include <utility>
#include <vector>
-#include "src/tint/ast/pipeline_stage.h"
+#include "src/tint/lang/wgsl/ast/pipeline_stage.h"
#include "src/tint/reflection.h"
#include "src/tint/sem/binding_point.h"
#include "src/tint/utils/bitset.h"
diff --git a/src/tint/lang/hlsl/ast_writer/generator_impl.cc b/src/tint/lang/hlsl/ast_writer/generator_impl.cc
index 38dcb1e..eb40897 100644
--- a/src/tint/lang/hlsl/ast_writer/generator_impl.cc
+++ b/src/tint/lang/hlsl/ast_writer/generator_impl.cc
@@ -22,38 +22,38 @@
#include <utility>
#include <vector>
-#include "src/tint/ast/call_statement.h"
-#include "src/tint/ast/id_attribute.h"
-#include "src/tint/ast/internal_attribute.h"
-#include "src/tint/ast/interpolate_attribute.h"
-#include "src/tint/ast/transform/add_empty_entry_point.h"
-#include "src/tint/ast/transform/array_length_from_uniform.h"
-#include "src/tint/ast/transform/binding_remapper.h"
-#include "src/tint/ast/transform/builtin_polyfill.h"
-#include "src/tint/ast/transform/calculate_array_length.h"
-#include "src/tint/ast/transform/canonicalize_entry_point_io.h"
-#include "src/tint/ast/transform/decompose_memory_access.h"
-#include "src/tint/ast/transform/demote_to_helper.h"
-#include "src/tint/ast/transform/direct_variable_access.h"
-#include "src/tint/ast/transform/disable_uniformity_analysis.h"
-#include "src/tint/ast/transform/expand_compound_assignment.h"
-#include "src/tint/ast/transform/localize_struct_array_assignment.h"
-#include "src/tint/ast/transform/multiplanar_external_texture.h"
-#include "src/tint/ast/transform/num_workgroups_from_uniform.h"
-#include "src/tint/ast/transform/promote_initializers_to_let.h"
-#include "src/tint/ast/transform/promote_side_effects_to_decl.h"
-#include "src/tint/ast/transform/remove_continue_in_switch.h"
-#include "src/tint/ast/transform/remove_phonies.h"
-#include "src/tint/ast/transform/robustness.h"
-#include "src/tint/ast/transform/simplify_pointers.h"
-#include "src/tint/ast/transform/truncate_interstage_variables.h"
-#include "src/tint/ast/transform/unshadow.h"
-#include "src/tint/ast/transform/vectorize_scalar_matrix_initializers.h"
-#include "src/tint/ast/transform/zero_init_workgroup_memory.h"
-#include "src/tint/ast/variable_decl_statement.h"
#include "src/tint/constant/splat.h"
#include "src/tint/constant/value.h"
#include "src/tint/debug.h"
+#include "src/tint/lang/wgsl/ast/call_statement.h"
+#include "src/tint/lang/wgsl/ast/id_attribute.h"
+#include "src/tint/lang/wgsl/ast/internal_attribute.h"
+#include "src/tint/lang/wgsl/ast/interpolate_attribute.h"
+#include "src/tint/lang/wgsl/ast/transform/add_empty_entry_point.h"
+#include "src/tint/lang/wgsl/ast/transform/array_length_from_uniform.h"
+#include "src/tint/lang/wgsl/ast/transform/binding_remapper.h"
+#include "src/tint/lang/wgsl/ast/transform/builtin_polyfill.h"
+#include "src/tint/lang/wgsl/ast/transform/calculate_array_length.h"
+#include "src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.h"
+#include "src/tint/lang/wgsl/ast/transform/decompose_memory_access.h"
+#include "src/tint/lang/wgsl/ast/transform/demote_to_helper.h"
+#include "src/tint/lang/wgsl/ast/transform/direct_variable_access.h"
+#include "src/tint/lang/wgsl/ast/transform/disable_uniformity_analysis.h"
+#include "src/tint/lang/wgsl/ast/transform/expand_compound_assignment.h"
+#include "src/tint/lang/wgsl/ast/transform/localize_struct_array_assignment.h"
+#include "src/tint/lang/wgsl/ast/transform/multiplanar_external_texture.h"
+#include "src/tint/lang/wgsl/ast/transform/num_workgroups_from_uniform.h"
+#include "src/tint/lang/wgsl/ast/transform/promote_initializers_to_let.h"
+#include "src/tint/lang/wgsl/ast/transform/promote_side_effects_to_decl.h"
+#include "src/tint/lang/wgsl/ast/transform/remove_continue_in_switch.h"
+#include "src/tint/lang/wgsl/ast/transform/remove_phonies.h"
+#include "src/tint/lang/wgsl/ast/transform/robustness.h"
+#include "src/tint/lang/wgsl/ast/transform/simplify_pointers.h"
+#include "src/tint/lang/wgsl/ast/transform/truncate_interstage_variables.h"
+#include "src/tint/lang/wgsl/ast/transform/unshadow.h"
+#include "src/tint/lang/wgsl/ast/transform/vectorize_scalar_matrix_initializers.h"
+#include "src/tint/lang/wgsl/ast/transform/zero_init_workgroup_memory.h"
+#include "src/tint/lang/wgsl/ast/variable_decl_statement.h"
#include "src/tint/sem/block_statement.h"
#include "src/tint/sem/call.h"
#include "src/tint/sem/function.h"
diff --git a/src/tint/lang/hlsl/ast_writer/generator_impl.h b/src/tint/lang/hlsl/ast_writer/generator_impl.h
index 3e440a0..8e355ff 100644
--- a/src/tint/lang/hlsl/ast_writer/generator_impl.h
+++ b/src/tint/lang/hlsl/ast_writer/generator_impl.h
@@ -21,20 +21,20 @@
#include <unordered_set>
#include <utility>
-#include "src/tint/ast/assignment_statement.h"
-#include "src/tint/ast/bitcast_expression.h"
-#include "src/tint/ast/break_statement.h"
-#include "src/tint/ast/continue_statement.h"
-#include "src/tint/ast/discard_statement.h"
-#include "src/tint/ast/for_loop_statement.h"
-#include "src/tint/ast/if_statement.h"
-#include "src/tint/ast/loop_statement.h"
-#include "src/tint/ast/return_statement.h"
-#include "src/tint/ast/switch_statement.h"
-#include "src/tint/ast/transform/decompose_memory_access.h"
-#include "src/tint/ast/unary_op_expression.h"
#include "src/tint/builtin/builtin_value.h"
#include "src/tint/lang/hlsl/ast_writer/generator.h"
+#include "src/tint/lang/wgsl/ast/assignment_statement.h"
+#include "src/tint/lang/wgsl/ast/bitcast_expression.h"
+#include "src/tint/lang/wgsl/ast/break_statement.h"
+#include "src/tint/lang/wgsl/ast/continue_statement.h"
+#include "src/tint/lang/wgsl/ast/discard_statement.h"
+#include "src/tint/lang/wgsl/ast/for_loop_statement.h"
+#include "src/tint/lang/wgsl/ast/if_statement.h"
+#include "src/tint/lang/wgsl/ast/loop_statement.h"
+#include "src/tint/lang/wgsl/ast/return_statement.h"
+#include "src/tint/lang/wgsl/ast/switch_statement.h"
+#include "src/tint/lang/wgsl/ast/transform/decompose_memory_access.h"
+#include "src/tint/lang/wgsl/ast/unary_op_expression.h"
#include "src/tint/program_builder.h"
#include "src/tint/scope_stack.h"
#include "src/tint/sem/binding_point.h"
diff --git a/src/tint/lang/hlsl/ast_writer/generator_impl_binary_test.cc b/src/tint/lang/hlsl/ast_writer/generator_impl_binary_test.cc
index 9816982..affce16 100644
--- a/src/tint/lang/hlsl/ast_writer/generator_impl_binary_test.cc
+++ b/src/tint/lang/hlsl/ast_writer/generator_impl_binary_test.cc
@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/call_statement.h"
-#include "src/tint/ast/variable_decl_statement.h"
#include "src/tint/lang/hlsl/ast_writer/test_helper.h"
+#include "src/tint/lang/wgsl/ast/call_statement.h"
+#include "src/tint/lang/wgsl/ast/variable_decl_statement.h"
#include "src/tint/utils/string_stream.h"
namespace tint::writer::hlsl {
diff --git a/src/tint/lang/hlsl/ast_writer/generator_impl_builtin_test.cc b/src/tint/lang/hlsl/ast_writer/generator_impl_builtin_test.cc
index 07bfd35..082c735 100644
--- a/src/tint/lang/hlsl/ast_writer/generator_impl_builtin_test.cc
+++ b/src/tint/lang/hlsl/ast_writer/generator_impl_builtin_test.cc
@@ -13,9 +13,9 @@
// limitations under the License.
#include "gmock/gmock.h"
-#include "src/tint/ast/call_statement.h"
-#include "src/tint/ast/stage_attribute.h"
#include "src/tint/lang/hlsl/ast_writer/test_helper.h"
+#include "src/tint/lang/wgsl/ast/call_statement.h"
+#include "src/tint/lang/wgsl/ast/stage_attribute.h"
#include "src/tint/sem/call.h"
#include "src/tint/utils/string_stream.h"
diff --git a/src/tint/lang/hlsl/ast_writer/generator_impl_builtin_texture_test.cc b/src/tint/lang/hlsl/ast_writer/generator_impl_builtin_texture_test.cc
index 9c7acbd..5e2e59a 100644
--- a/src/tint/lang/hlsl/ast_writer/generator_impl_builtin_texture_test.cc
+++ b/src/tint/lang/hlsl/ast_writer/generator_impl_builtin_texture_test.cc
@@ -13,10 +13,10 @@
// limitations under the License.
#include "gmock/gmock.h"
-#include "src/tint/ast/builtin_texture_helper_test.h"
-#include "src/tint/ast/call_statement.h"
-#include "src/tint/ast/stage_attribute.h"
#include "src/tint/lang/hlsl/ast_writer/test_helper.h"
+#include "src/tint/lang/wgsl/ast/builtin_texture_helper_test.h"
+#include "src/tint/lang/wgsl/ast/call_statement.h"
+#include "src/tint/lang/wgsl/ast/stage_attribute.h"
namespace tint::writer::hlsl {
namespace {
diff --git a/src/tint/lang/hlsl/ast_writer/generator_impl_call_test.cc b/src/tint/lang/hlsl/ast_writer/generator_impl_call_test.cc
index ac266d6..bb9743a 100644
--- a/src/tint/lang/hlsl/ast_writer/generator_impl_call_test.cc
+++ b/src/tint/lang/hlsl/ast_writer/generator_impl_call_test.cc
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/call_statement.h"
#include "src/tint/lang/hlsl/ast_writer/test_helper.h"
+#include "src/tint/lang/wgsl/ast/call_statement.h"
#include "src/tint/utils/string_stream.h"
using namespace tint::number_suffixes; // NOLINT
diff --git a/src/tint/lang/hlsl/ast_writer/generator_impl_function_test.cc b/src/tint/lang/hlsl/ast_writer/generator_impl_function_test.cc
index b892067..32454e9 100644
--- a/src/tint/lang/hlsl/ast_writer/generator_impl_function_test.cc
+++ b/src/tint/lang/hlsl/ast_writer/generator_impl_function_test.cc
@@ -13,10 +13,10 @@
// limitations under the License.
#include "gmock/gmock.h"
-#include "src/tint/ast/stage_attribute.h"
-#include "src/tint/ast/variable_decl_statement.h"
-#include "src/tint/ast/workgroup_attribute.h"
#include "src/tint/lang/hlsl/ast_writer/test_helper.h"
+#include "src/tint/lang/wgsl/ast/stage_attribute.h"
+#include "src/tint/lang/wgsl/ast/variable_decl_statement.h"
+#include "src/tint/lang/wgsl/ast/workgroup_attribute.h"
using ::testing::HasSubstr;
diff --git a/src/tint/lang/hlsl/ast_writer/generator_impl_loop_test.cc b/src/tint/lang/hlsl/ast_writer/generator_impl_loop_test.cc
index 67663e4..cdea5a4 100644
--- a/src/tint/lang/hlsl/ast_writer/generator_impl_loop_test.cc
+++ b/src/tint/lang/hlsl/ast_writer/generator_impl_loop_test.cc
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/variable_decl_statement.h"
#include "src/tint/lang/hlsl/ast_writer/test_helper.h"
+#include "src/tint/lang/wgsl/ast/variable_decl_statement.h"
using namespace tint::number_suffixes; // NOLINT
diff --git a/src/tint/lang/hlsl/ast_writer/generator_impl_member_accessor_test.cc b/src/tint/lang/hlsl/ast_writer/generator_impl_member_accessor_test.cc
index d6b42c9..755b51d 100644
--- a/src/tint/lang/hlsl/ast_writer/generator_impl_member_accessor_test.cc
+++ b/src/tint/lang/hlsl/ast_writer/generator_impl_member_accessor_test.cc
@@ -13,8 +13,8 @@
// limitations under the License.
#include "gmock/gmock.h"
-#include "src/tint/ast/stage_attribute.h"
#include "src/tint/lang/hlsl/ast_writer/test_helper.h"
+#include "src/tint/lang/wgsl/ast/stage_attribute.h"
using ::testing::HasSubstr;
diff --git a/src/tint/lang/hlsl/ast_writer/generator_impl_module_constant_test.cc b/src/tint/lang/hlsl/ast_writer/generator_impl_module_constant_test.cc
index 96e2da9..a1ee291 100644
--- a/src/tint/lang/hlsl/ast_writer/generator_impl_module_constant_test.cc
+++ b/src/tint/lang/hlsl/ast_writer/generator_impl_module_constant_test.cc
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/id_attribute.h"
#include "src/tint/lang/hlsl/ast_writer/test_helper.h"
+#include "src/tint/lang/wgsl/ast/id_attribute.h"
namespace tint::writer::hlsl {
namespace {
diff --git a/src/tint/lang/hlsl/ast_writer/generator_impl_sanitizer_test.cc b/src/tint/lang/hlsl/ast_writer/generator_impl_sanitizer_test.cc
index 0ba1d22..d716b6b 100644
--- a/src/tint/lang/hlsl/ast_writer/generator_impl_sanitizer_test.cc
+++ b/src/tint/lang/hlsl/ast_writer/generator_impl_sanitizer_test.cc
@@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/call_statement.h"
-#include "src/tint/ast/stage_attribute.h"
-#include "src/tint/ast/variable_decl_statement.h"
#include "src/tint/lang/hlsl/ast_writer/test_helper.h"
+#include "src/tint/lang/wgsl/ast/call_statement.h"
+#include "src/tint/lang/wgsl/ast/stage_attribute.h"
+#include "src/tint/lang/wgsl/ast/variable_decl_statement.h"
namespace tint::writer::hlsl {
namespace {
diff --git a/src/tint/lang/hlsl/ast_writer/generator_impl_type_test.cc b/src/tint/lang/hlsl/ast_writer/generator_impl_type_test.cc
index b66f8d6..f640f68 100644
--- a/src/tint/lang/hlsl/ast_writer/generator_impl_type_test.cc
+++ b/src/tint/lang/hlsl/ast_writer/generator_impl_type_test.cc
@@ -13,9 +13,9 @@
// limitations under the License.
#include "gmock/gmock.h"
-#include "src/tint/ast/call_statement.h"
-#include "src/tint/ast/stage_attribute.h"
#include "src/tint/lang/hlsl/ast_writer/test_helper.h"
+#include "src/tint/lang/wgsl/ast/call_statement.h"
+#include "src/tint/lang/wgsl/ast/stage_attribute.h"
#include "src/tint/type/depth_texture.h"
#include "src/tint/type/multisampled_texture.h"
#include "src/tint/type/sampled_texture.h"
diff --git a/src/tint/lang/hlsl/ast_writer/generator_impl_variable_decl_statement_test.cc b/src/tint/lang/hlsl/ast_writer/generator_impl_variable_decl_statement_test.cc
index 0a61ec6..b8dfa9f 100644
--- a/src/tint/lang/hlsl/ast_writer/generator_impl_variable_decl_statement_test.cc
+++ b/src/tint/lang/hlsl/ast_writer/generator_impl_variable_decl_statement_test.cc
@@ -13,8 +13,8 @@
// limitations under the License.
#include "gmock/gmock.h"
-#include "src/tint/ast/variable_decl_statement.h"
#include "src/tint/lang/hlsl/ast_writer/test_helper.h"
+#include "src/tint/lang/wgsl/ast/variable_decl_statement.h"
namespace tint::writer::hlsl {
namespace {
diff --git a/src/tint/lang/hlsl/ast_writer/generator_impl_workgroup_var_test.cc b/src/tint/lang/hlsl/ast_writer/generator_impl_workgroup_var_test.cc
index 2d28d82..c062375 100644
--- a/src/tint/lang/hlsl/ast_writer/generator_impl_workgroup_var_test.cc
+++ b/src/tint/lang/hlsl/ast_writer/generator_impl_workgroup_var_test.cc
@@ -13,9 +13,9 @@
// limitations under the License.
#include "gmock/gmock.h"
-#include "src/tint/ast/id_attribute.h"
-#include "src/tint/ast/stage_attribute.h"
#include "src/tint/lang/hlsl/ast_writer/test_helper.h"
+#include "src/tint/lang/wgsl/ast/id_attribute.h"
+#include "src/tint/lang/wgsl/ast/stage_attribute.h"
using ::testing::HasSubstr;
diff --git a/src/tint/lang/hlsl/ast_writer/test_helper.h b/src/tint/lang/hlsl/ast_writer/test_helper.h
index 68ac3b2..06d81d8 100644
--- a/src/tint/lang/hlsl/ast_writer/test_helper.h
+++ b/src/tint/lang/hlsl/ast_writer/test_helper.h
@@ -20,9 +20,9 @@
#include <utility>
#include "gtest/gtest.h"
-#include "src/tint/ast/transform/renamer.h"
#include "src/tint/lang/hlsl/ast_writer/generator.h"
#include "src/tint/lang/hlsl/ast_writer/generator_impl.h"
+#include "src/tint/lang/wgsl/ast/transform/renamer.h"
#include "src/tint/transform/manager.h"
namespace tint::writer::hlsl {
diff --git a/src/tint/lang/msl/ast_writer/generator_bench.cc b/src/tint/lang/msl/ast_writer/generator_bench.cc
index 98d5ece..c1478c9 100644
--- a/src/tint/lang/msl/ast_writer/generator_bench.cc
+++ b/src/tint/lang/msl/ast_writer/generator_bench.cc
@@ -14,8 +14,8 @@
#include <string>
-#include "src/tint/ast/module.h"
#include "src/tint/bench/benchmark.h"
+#include "src/tint/lang/wgsl/ast/module.h"
#include "src/tint/sem/variable.h"
namespace tint::writer::msl {
diff --git a/src/tint/lang/msl/ast_writer/generator_impl.cc b/src/tint/lang/msl/ast_writer/generator_impl.cc
index 11f71ca..c82807c 100644
--- a/src/tint/lang/msl/ast_writer/generator_impl.cc
+++ b/src/tint/lang/msl/ast_writer/generator_impl.cc
@@ -21,37 +21,37 @@
#include <utility>
#include <vector>
-#include "src/tint/ast/alias.h"
-#include "src/tint/ast/bool_literal_expression.h"
-#include "src/tint/ast/call_statement.h"
-#include "src/tint/ast/disable_validation_attribute.h"
-#include "src/tint/ast/float_literal_expression.h"
-#include "src/tint/ast/id_attribute.h"
-#include "src/tint/ast/interpolate_attribute.h"
-#include "src/tint/ast/module.h"
-#include "src/tint/ast/transform/array_length_from_uniform.h"
-#include "src/tint/ast/transform/binding_remapper.h"
-#include "src/tint/ast/transform/builtin_polyfill.h"
-#include "src/tint/ast/transform/canonicalize_entry_point_io.h"
-#include "src/tint/ast/transform/demote_to_helper.h"
-#include "src/tint/ast/transform/disable_uniformity_analysis.h"
-#include "src/tint/ast/transform/expand_compound_assignment.h"
-#include "src/tint/ast/transform/module_scope_var_to_entry_point_param.h"
-#include "src/tint/ast/transform/multiplanar_external_texture.h"
-#include "src/tint/ast/transform/packed_vec3.h"
-#include "src/tint/ast/transform/preserve_padding.h"
-#include "src/tint/ast/transform/promote_initializers_to_let.h"
-#include "src/tint/ast/transform/promote_side_effects_to_decl.h"
-#include "src/tint/ast/transform/remove_phonies.h"
-#include "src/tint/ast/transform/robustness.h"
-#include "src/tint/ast/transform/simplify_pointers.h"
-#include "src/tint/ast/transform/unshadow.h"
-#include "src/tint/ast/transform/vectorize_scalar_matrix_initializers.h"
-#include "src/tint/ast/transform/zero_init_workgroup_memory.h"
-#include "src/tint/ast/variable_decl_statement.h"
#include "src/tint/constant/splat.h"
#include "src/tint/constant/value.h"
#include "src/tint/lang/msl/ast_writer/generator_support.h"
+#include "src/tint/lang/wgsl/ast/alias.h"
+#include "src/tint/lang/wgsl/ast/bool_literal_expression.h"
+#include "src/tint/lang/wgsl/ast/call_statement.h"
+#include "src/tint/lang/wgsl/ast/disable_validation_attribute.h"
+#include "src/tint/lang/wgsl/ast/float_literal_expression.h"
+#include "src/tint/lang/wgsl/ast/id_attribute.h"
+#include "src/tint/lang/wgsl/ast/interpolate_attribute.h"
+#include "src/tint/lang/wgsl/ast/module.h"
+#include "src/tint/lang/wgsl/ast/transform/array_length_from_uniform.h"
+#include "src/tint/lang/wgsl/ast/transform/binding_remapper.h"
+#include "src/tint/lang/wgsl/ast/transform/builtin_polyfill.h"
+#include "src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.h"
+#include "src/tint/lang/wgsl/ast/transform/demote_to_helper.h"
+#include "src/tint/lang/wgsl/ast/transform/disable_uniformity_analysis.h"
+#include "src/tint/lang/wgsl/ast/transform/expand_compound_assignment.h"
+#include "src/tint/lang/wgsl/ast/transform/module_scope_var_to_entry_point_param.h"
+#include "src/tint/lang/wgsl/ast/transform/multiplanar_external_texture.h"
+#include "src/tint/lang/wgsl/ast/transform/packed_vec3.h"
+#include "src/tint/lang/wgsl/ast/transform/preserve_padding.h"
+#include "src/tint/lang/wgsl/ast/transform/promote_initializers_to_let.h"
+#include "src/tint/lang/wgsl/ast/transform/promote_side_effects_to_decl.h"
+#include "src/tint/lang/wgsl/ast/transform/remove_phonies.h"
+#include "src/tint/lang/wgsl/ast/transform/robustness.h"
+#include "src/tint/lang/wgsl/ast/transform/simplify_pointers.h"
+#include "src/tint/lang/wgsl/ast/transform/unshadow.h"
+#include "src/tint/lang/wgsl/ast/transform/vectorize_scalar_matrix_initializers.h"
+#include "src/tint/lang/wgsl/ast/transform/zero_init_workgroup_memory.h"
+#include "src/tint/lang/wgsl/ast/variable_decl_statement.h"
#include "src/tint/sem/call.h"
#include "src/tint/sem/function.h"
#include "src/tint/sem/member_accessor_expression.h"
diff --git a/src/tint/lang/msl/ast_writer/generator_impl.h b/src/tint/lang/msl/ast_writer/generator_impl.h
index 9b691e6..07837c4 100644
--- a/src/tint/lang/msl/ast_writer/generator_impl.h
+++ b/src/tint/lang/msl/ast_writer/generator_impl.h
@@ -21,23 +21,23 @@
#include <unordered_set>
#include <vector>
-#include "src/tint/ast/assignment_statement.h"
-#include "src/tint/ast/binary_expression.h"
-#include "src/tint/ast/bitcast_expression.h"
-#include "src/tint/ast/break_statement.h"
-#include "src/tint/ast/continue_statement.h"
-#include "src/tint/ast/discard_statement.h"
-#include "src/tint/ast/expression.h"
-#include "src/tint/ast/if_statement.h"
-#include "src/tint/ast/index_accessor_expression.h"
-#include "src/tint/ast/interpolate_attribute.h"
-#include "src/tint/ast/loop_statement.h"
-#include "src/tint/ast/member_accessor_expression.h"
-#include "src/tint/ast/return_statement.h"
-#include "src/tint/ast/switch_statement.h"
-#include "src/tint/ast/unary_op_expression.h"
#include "src/tint/builtin/builtin_value.h"
#include "src/tint/lang/msl/ast_writer/generator.h"
+#include "src/tint/lang/wgsl/ast/assignment_statement.h"
+#include "src/tint/lang/wgsl/ast/binary_expression.h"
+#include "src/tint/lang/wgsl/ast/bitcast_expression.h"
+#include "src/tint/lang/wgsl/ast/break_statement.h"
+#include "src/tint/lang/wgsl/ast/continue_statement.h"
+#include "src/tint/lang/wgsl/ast/discard_statement.h"
+#include "src/tint/lang/wgsl/ast/expression.h"
+#include "src/tint/lang/wgsl/ast/if_statement.h"
+#include "src/tint/lang/wgsl/ast/index_accessor_expression.h"
+#include "src/tint/lang/wgsl/ast/interpolate_attribute.h"
+#include "src/tint/lang/wgsl/ast/loop_statement.h"
+#include "src/tint/lang/wgsl/ast/member_accessor_expression.h"
+#include "src/tint/lang/wgsl/ast/return_statement.h"
+#include "src/tint/lang/wgsl/ast/switch_statement.h"
+#include "src/tint/lang/wgsl/ast/unary_op_expression.h"
#include "src/tint/program.h"
#include "src/tint/scope_stack.h"
#include "src/tint/sem/struct.h"
diff --git a/src/tint/lang/msl/ast_writer/generator_impl_builtin_test.cc b/src/tint/lang/msl/ast_writer/generator_impl_builtin_test.cc
index 16aeb79..9305d9a 100644
--- a/src/tint/lang/msl/ast_writer/generator_impl_builtin_test.cc
+++ b/src/tint/lang/msl/ast_writer/generator_impl_builtin_test.cc
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/call_statement.h"
#include "src/tint/lang/msl/ast_writer/test_helper.h"
+#include "src/tint/lang/wgsl/ast/call_statement.h"
#include "src/tint/sem/call.h"
#include "src/tint/utils/string_stream.h"
diff --git a/src/tint/lang/msl/ast_writer/generator_impl_builtin_texture_test.cc b/src/tint/lang/msl/ast_writer/generator_impl_builtin_texture_test.cc
index 171e217..ded73ba 100644
--- a/src/tint/lang/msl/ast_writer/generator_impl_builtin_texture_test.cc
+++ b/src/tint/lang/msl/ast_writer/generator_impl_builtin_texture_test.cc
@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/builtin_texture_helper_test.h"
-#include "src/tint/ast/call_statement.h"
#include "src/tint/lang/msl/ast_writer/test_helper.h"
+#include "src/tint/lang/wgsl/ast/builtin_texture_helper_test.h"
+#include "src/tint/lang/wgsl/ast/call_statement.h"
#include "src/tint/utils/string_stream.h"
namespace tint::writer::msl {
diff --git a/src/tint/lang/msl/ast_writer/generator_impl_call_test.cc b/src/tint/lang/msl/ast_writer/generator_impl_call_test.cc
index fe3fad8..f923738 100644
--- a/src/tint/lang/msl/ast_writer/generator_impl_call_test.cc
+++ b/src/tint/lang/msl/ast_writer/generator_impl_call_test.cc
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/call_statement.h"
#include "src/tint/lang/msl/ast_writer/test_helper.h"
+#include "src/tint/lang/wgsl/ast/call_statement.h"
#include "src/tint/utils/string_stream.h"
using namespace tint::number_suffixes; // NOLINT
diff --git a/src/tint/lang/msl/ast_writer/generator_impl_function_test.cc b/src/tint/lang/msl/ast_writer/generator_impl_function_test.cc
index c9ac56d..ea07e30 100644
--- a/src/tint/lang/msl/ast_writer/generator_impl_function_test.cc
+++ b/src/tint/lang/msl/ast_writer/generator_impl_function_test.cc
@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/stage_attribute.h"
-#include "src/tint/ast/variable_decl_statement.h"
#include "src/tint/lang/msl/ast_writer/test_helper.h"
+#include "src/tint/lang/wgsl/ast/stage_attribute.h"
+#include "src/tint/lang/wgsl/ast/variable_decl_statement.h"
namespace tint::writer::msl {
namespace {
diff --git a/src/tint/lang/msl/ast_writer/generator_impl_loop_test.cc b/src/tint/lang/msl/ast_writer/generator_impl_loop_test.cc
index ed3ca24..c314afa 100644
--- a/src/tint/lang/msl/ast_writer/generator_impl_loop_test.cc
+++ b/src/tint/lang/msl/ast_writer/generator_impl_loop_test.cc
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/variable_decl_statement.h"
#include "src/tint/lang/msl/ast_writer/test_helper.h"
+#include "src/tint/lang/wgsl/ast/variable_decl_statement.h"
using namespace tint::number_suffixes; // NOLINT
diff --git a/src/tint/lang/msl/ast_writer/generator_impl_module_constant_test.cc b/src/tint/lang/msl/ast_writer/generator_impl_module_constant_test.cc
index a831d11..b4b746d 100644
--- a/src/tint/lang/msl/ast_writer/generator_impl_module_constant_test.cc
+++ b/src/tint/lang/msl/ast_writer/generator_impl_module_constant_test.cc
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/id_attribute.h"
#include "src/tint/lang/msl/ast_writer/test_helper.h"
+#include "src/tint/lang/wgsl/ast/id_attribute.h"
namespace tint::writer::msl {
namespace {
diff --git a/src/tint/lang/msl/ast_writer/generator_impl_sanitizer_test.cc b/src/tint/lang/msl/ast_writer/generator_impl_sanitizer_test.cc
index 6ed2974..dee22b1 100644
--- a/src/tint/lang/msl/ast_writer/generator_impl_sanitizer_test.cc
+++ b/src/tint/lang/msl/ast_writer/generator_impl_sanitizer_test.cc
@@ -13,10 +13,10 @@
// limitations under the License.
#include "gmock/gmock.h"
-#include "src/tint/ast/call_statement.h"
-#include "src/tint/ast/stage_attribute.h"
-#include "src/tint/ast/variable_decl_statement.h"
#include "src/tint/lang/msl/ast_writer/test_helper.h"
+#include "src/tint/lang/wgsl/ast/call_statement.h"
+#include "src/tint/lang/wgsl/ast/stage_attribute.h"
+#include "src/tint/lang/wgsl/ast/variable_decl_statement.h"
namespace tint::writer::msl {
namespace {
diff --git a/src/tint/lang/msl/ast_writer/generator_impl_test.cc b/src/tint/lang/msl/ast_writer/generator_impl_test.cc
index e898da1..cd0026a 100644
--- a/src/tint/lang/msl/ast_writer/generator_impl_test.cc
+++ b/src/tint/lang/msl/ast_writer/generator_impl_test.cc
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/stage_attribute.h"
#include "src/tint/lang/msl/ast_writer/test_helper.h"
+#include "src/tint/lang/wgsl/ast/stage_attribute.h"
using namespace tint::number_suffixes; // NOLINT
diff --git a/src/tint/lang/msl/ast_writer/generator_impl_variable_decl_statement_test.cc b/src/tint/lang/msl/ast_writer/generator_impl_variable_decl_statement_test.cc
index 0432096..e9dab8d 100644
--- a/src/tint/lang/msl/ast_writer/generator_impl_variable_decl_statement_test.cc
+++ b/src/tint/lang/msl/ast_writer/generator_impl_variable_decl_statement_test.cc
@@ -13,8 +13,8 @@
// limitations under the License.
#include "gmock/gmock.h"
-#include "src/tint/ast/variable_decl_statement.h"
#include "src/tint/lang/msl/ast_writer/test_helper.h"
+#include "src/tint/lang/wgsl/ast/variable_decl_statement.h"
namespace tint::writer::msl {
namespace {
diff --git a/src/tint/lang/spirv/reader/attributes.h b/src/tint/lang/spirv/reader/attributes.h
index f35ff0f..14aea32 100644
--- a/src/tint/lang/spirv/reader/attributes.h
+++ b/src/tint/lang/spirv/reader/attributes.h
@@ -15,8 +15,8 @@
#ifndef SRC_TINT_LANG_SPIRV_READER_ATTRIBUTES_H_
#define SRC_TINT_LANG_SPIRV_READER_ATTRIBUTES_H_
-#include "src/tint/ast/attribute.h"
#include "src/tint/builtin/builtin_value.h"
+#include "src/tint/lang/wgsl/ast/attribute.h"
#include "src/tint/program_builder.h"
#include "src/tint/utils/enum_set.h"
#include "src/tint/utils/vector.h"
diff --git a/src/tint/lang/spirv/reader/entry_point_info.h b/src/tint/lang/spirv/reader/entry_point_info.h
index b3e69f9..b7ae8da 100644
--- a/src/tint/lang/spirv/reader/entry_point_info.h
+++ b/src/tint/lang/spirv/reader/entry_point_info.h
@@ -17,7 +17,7 @@
#include <string>
-#include "src/tint/ast/pipeline_stage.h"
+#include "src/tint/lang/wgsl/ast/pipeline_stage.h"
#include "src/tint/utils/vector.h"
namespace tint::reader::spirv {
diff --git a/src/tint/lang/spirv/reader/enum_converter.h b/src/tint/lang/spirv/reader/enum_converter.h
index 2d8ff4e..800f6c8 100644
--- a/src/tint/lang/spirv/reader/enum_converter.h
+++ b/src/tint/lang/spirv/reader/enum_converter.h
@@ -17,10 +17,10 @@
#include "spirv/unified1/spirv.h"
#include "spirv/unified1/spirv.hpp11"
-#include "src/tint/ast/pipeline_stage.h"
#include "src/tint/builtin/address_space.h"
#include "src/tint/builtin/builtin_value.h"
#include "src/tint/lang/spirv/reader/fail_stream.h"
+#include "src/tint/lang/wgsl/ast/pipeline_stage.h"
#include "src/tint/type/storage_texture.h"
#include "src/tint/type/texture_dimension.h"
diff --git a/src/tint/lang/spirv/reader/function.cc b/src/tint/lang/spirv/reader/function.cc
index acc5633..601ab60 100644
--- a/src/tint/lang/spirv/reader/function.cc
+++ b/src/tint/lang/spirv/reader/function.cc
@@ -17,23 +17,23 @@
#include <algorithm>
#include <array>
-#include "src/tint/ast/assignment_statement.h"
-#include "src/tint/ast/bitcast_expression.h"
-#include "src/tint/ast/break_statement.h"
-#include "src/tint/ast/builtin_attribute.h"
-#include "src/tint/ast/call_statement.h"
-#include "src/tint/ast/continue_statement.h"
-#include "src/tint/ast/discard_statement.h"
-#include "src/tint/ast/if_statement.h"
-#include "src/tint/ast/loop_statement.h"
-#include "src/tint/ast/return_statement.h"
-#include "src/tint/ast/stage_attribute.h"
-#include "src/tint/ast/switch_statement.h"
-#include "src/tint/ast/transform/spirv_atomic.h"
-#include "src/tint/ast/unary_op_expression.h"
-#include "src/tint/ast/variable_decl_statement.h"
#include "src/tint/builtin/builtin_value.h"
#include "src/tint/builtin/function.h"
+#include "src/tint/lang/wgsl/ast/assignment_statement.h"
+#include "src/tint/lang/wgsl/ast/bitcast_expression.h"
+#include "src/tint/lang/wgsl/ast/break_statement.h"
+#include "src/tint/lang/wgsl/ast/builtin_attribute.h"
+#include "src/tint/lang/wgsl/ast/call_statement.h"
+#include "src/tint/lang/wgsl/ast/continue_statement.h"
+#include "src/tint/lang/wgsl/ast/discard_statement.h"
+#include "src/tint/lang/wgsl/ast/if_statement.h"
+#include "src/tint/lang/wgsl/ast/loop_statement.h"
+#include "src/tint/lang/wgsl/ast/return_statement.h"
+#include "src/tint/lang/wgsl/ast/stage_attribute.h"
+#include "src/tint/lang/wgsl/ast/switch_statement.h"
+#include "src/tint/lang/wgsl/ast/transform/spirv_atomic.h"
+#include "src/tint/lang/wgsl/ast/unary_op_expression.h"
+#include "src/tint/lang/wgsl/ast/variable_decl_statement.h"
#include "src/tint/switch.h"
#include "src/tint/type/depth_texture.h"
#include "src/tint/type/sampled_texture.h"
diff --git a/src/tint/lang/spirv/reader/parser.cc b/src/tint/lang/spirv/reader/parser.cc
index 845778e..e4244ea 100644
--- a/src/tint/lang/spirv/reader/parser.cc
+++ b/src/tint/lang/spirv/reader/parser.cc
@@ -16,14 +16,14 @@
#include <utility>
-#include "src/tint/ast/transform/decompose_strided_array.h"
-#include "src/tint/ast/transform/decompose_strided_matrix.h"
-#include "src/tint/ast/transform/fold_trivial_lets.h"
-#include "src/tint/ast/transform/remove_unreachable_statements.h"
-#include "src/tint/ast/transform/simplify_pointers.h"
-#include "src/tint/ast/transform/spirv_atomic.h"
-#include "src/tint/ast/transform/unshadow.h"
#include "src/tint/lang/spirv/reader/parser_impl.h"
+#include "src/tint/lang/wgsl/ast/transform/decompose_strided_array.h"
+#include "src/tint/lang/wgsl/ast/transform/decompose_strided_matrix.h"
+#include "src/tint/lang/wgsl/ast/transform/fold_trivial_lets.h"
+#include "src/tint/lang/wgsl/ast/transform/remove_unreachable_statements.h"
+#include "src/tint/lang/wgsl/ast/transform/simplify_pointers.h"
+#include "src/tint/lang/wgsl/ast/transform/spirv_atomic.h"
+#include "src/tint/lang/wgsl/ast/transform/unshadow.h"
#include "src/tint/transform/manager.h"
namespace tint::reader::spirv {
diff --git a/src/tint/lang/spirv/reader/parser_impl.cc b/src/tint/lang/spirv/reader/parser_impl.cc
index e79499c..f8b82d0c 100644
--- a/src/tint/lang/spirv/reader/parser_impl.cc
+++ b/src/tint/lang/spirv/reader/parser_impl.cc
@@ -20,12 +20,12 @@
#include <utility>
#include "source/opt/build_module.h"
-#include "src/tint/ast/bitcast_expression.h"
-#include "src/tint/ast/disable_validation_attribute.h"
-#include "src/tint/ast/id_attribute.h"
-#include "src/tint/ast/interpolate_attribute.h"
-#include "src/tint/ast/unary_op_expression.h"
#include "src/tint/lang/spirv/reader/function.h"
+#include "src/tint/lang/wgsl/ast/bitcast_expression.h"
+#include "src/tint/lang/wgsl/ast/disable_validation_attribute.h"
+#include "src/tint/lang/wgsl/ast/id_attribute.h"
+#include "src/tint/lang/wgsl/ast/interpolate_attribute.h"
+#include "src/tint/lang/wgsl/ast/unary_op_expression.h"
#include "src/tint/switch.h"
#include "src/tint/type/depth_texture.h"
#include "src/tint/type/multisampled_texture.h"
diff --git a/src/tint/lang/spirv/reader/parser_impl_barrier_test.cc b/src/tint/lang/spirv/reader/parser_impl_barrier_test.cc
index c867b2e..e245fa7 100644
--- a/src/tint/lang/spirv/reader/parser_impl_barrier_test.cc
+++ b/src/tint/lang/spirv/reader/parser_impl_barrier_test.cc
@@ -13,10 +13,10 @@
// limitations under the License.
#include "gmock/gmock.h"
-#include "src/tint/ast/call_statement.h"
#include "src/tint/lang/spirv/reader/function.h"
#include "src/tint/lang/spirv/reader/parser_impl_test_helper.h"
#include "src/tint/lang/spirv/reader/spirv_tools_helpers_test.h"
+#include "src/tint/lang/wgsl/ast/call_statement.h"
#include "src/tint/sem/call.h"
namespace tint::reader::spirv {
diff --git a/src/tint/lang/spirv/reader/parser_type.h b/src/tint/lang/spirv/reader/parser_type.h
index afe7ad2..ee53f117 100644
--- a/src/tint/lang/spirv/reader/parser_type.h
+++ b/src/tint/lang/spirv/reader/parser_type.h
@@ -19,10 +19,10 @@
#include <string>
#include <vector>
-#include "src/tint/ast/type.h"
#include "src/tint/builtin/access.h"
#include "src/tint/builtin/address_space.h"
#include "src/tint/builtin/texel_format.h"
+#include "src/tint/lang/wgsl/ast/type.h"
#include "src/tint/symbol.h"
#include "src/tint/type/sampler_kind.h"
#include "src/tint/type/texture_dimension.h"
diff --git a/src/tint/lang/wgsl/ast/accessor_expression.cc b/src/tint/lang/wgsl/ast/accessor_expression.cc
new file mode 100644
index 0000000..cd0130d
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/accessor_expression.cc
@@ -0,0 +1,34 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/accessor_expression.h"
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::AccessorExpression);
+
+namespace tint::ast {
+
+AccessorExpression::AccessorExpression(ProgramID prog, NodeID id, const Source& loc,
+                                       const Expression* accessed)
+    : Base(prog, id, loc), object(accessed) {
+    // An accessor without an object to access is a construction error.
+    TINT_ASSERT(AST, object);
+    // The object's program id is checked against this node's program id.
+    TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, object, program_id);
+}
+
+AccessorExpression::~AccessorExpression() = default;
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/accessor_expression.h b/src/tint/lang/wgsl/ast/accessor_expression.h
new file mode 100644
index 0000000..2551b5e
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/accessor_expression.h
@@ -0,0 +1,41 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_ACCESSOR_EXPRESSION_H_
+#define SRC_TINT_LANG_WGSL_AST_ACCESSOR_EXPRESSION_H_
+
+#include "src/tint/lang/wgsl/ast/expression.h"
+
+namespace tint::ast {
+
+/// Base class for IndexAccessorExpression and MemberAccessorExpression; holds the accessed object
+class AccessorExpression : public utils::Castable<AccessorExpression, Expression> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param source the accessor expression source
+ /// @param object the object being accessed (must not be null)
+ AccessorExpression(ProgramID pid, NodeID nid, const Source& source, const Expression* object);
+
+ /// Destructor
+ ~AccessorExpression() override;
+
+ /// The object being accessed
+ const Expression* const object;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_ACCESSOR_EXPRESSION_H_
diff --git a/src/tint/lang/wgsl/ast/alias.cc b/src/tint/lang/wgsl/ast/alias.cc
new file mode 100644
index 0000000..7261255
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/alias.cc
@@ -0,0 +1,38 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/alias.h"
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::Alias);
+
+namespace tint::ast {
+
+Alias::Alias(ProgramID prog, NodeID id, const Source& loc, const Identifier* ident, Type aliased)
+    : Base(prog, id, loc, ident), type(aliased) {
+    TINT_ASSERT(AST, type);  // reject an alias that has no type
+}
+
+Alias::~Alias() = default;
+
+const Alias* Alias::Clone(CloneContext* ctx) const {
+    // Each argument is cloned in its own statement, ahead of create(), to fix the clone order
+    auto cloned_source = ctx->Clone(source);
+    auto cloned_name = ctx->Clone(name);
+    auto cloned_type = ctx->Clone(type);
+    return ctx->dst->create<Alias>(cloned_source, cloned_name, cloned_type);
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/alias.h b/src/tint/lang/wgsl/ast/alias.h
new file mode 100644
index 0000000..553f9c9
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/alias.h
@@ -0,0 +1,50 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_ALIAS_H_
+#define SRC_TINT_LANG_WGSL_AST_ALIAS_H_
+
+#include <string>
+
+#include "src/tint/lang/wgsl/ast/type.h"
+#include "src/tint/lang/wgsl/ast/type_decl.h"
+
+namespace tint::ast {
+
+/// A type alias declaration. Holds the alias name and the type that it aliases.
+class Alias final : public utils::Castable<Alias, TypeDecl> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ /// @param name the identifier that names the alias
+ /// @param subtype the aliased type
+ Alias(ProgramID pid, NodeID nid, const Source& src, const Identifier* name, Type subtype);
+
+ /// Destructor
+ ~Alias() override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext` `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const Alias* Clone(CloneContext* ctx) const override;
+
+ /// The type that this alias refers to
+ const Type type;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_ALIAS_H_
diff --git a/src/tint/lang/wgsl/ast/alias_test.cc b/src/tint/lang/wgsl/ast/alias_test.cc
new file mode 100644
index 0000000..04c16b8
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/alias_test.cc
@@ -0,0 +1,32 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/alias.h"
+#include "src/tint/builtin/access.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+namespace tint::ast {
+namespace {
+
+using AstAliasTest = TestHelper;
+
+TEST_F(AstAliasTest, Create) {
+    auto aliased = ty.u32();
+    auto* decl = Alias("a_type", aliased);
+    CheckIdentifier(decl->name, "a_type");  // the alias keeps the name it was given
+    CheckIdentifier(decl->type, "u32");     // and refers to the u32 type
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/assignment_statement.cc b/src/tint/lang/wgsl/ast/assignment_statement.cc
new file mode 100644
index 0000000..dcb57f9
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/assignment_statement.cc
@@ -0,0 +1,45 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/assignment_statement.h"
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::AssignmentStatement);
+
+namespace tint::ast {
+
+AssignmentStatement::AssignmentStatement(ProgramID prog, NodeID id, const Source& loc,
+                                         const Expression* left, const Expression* right)
+    : Base(prog, id, loc), lhs(left), rhs(right) {
+    // The left-hand side must exist; its program id is checked against this node's.
+    TINT_ASSERT(AST, lhs);
+    TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, lhs, program_id);
+
+    // The same two checks apply to the right-hand side.
+    TINT_ASSERT(AST, rhs);
+    TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, rhs, program_id);
+}
+
+AssignmentStatement::~AssignmentStatement() = default;
+
+const AssignmentStatement* AssignmentStatement::Clone(CloneContext* ctx) const {
+    // Each argument is cloned in its own statement, ahead of create(), to fix the clone order
+    auto cloned_source = ctx->Clone(source);
+    auto* cloned_lhs = ctx->Clone(lhs);
+    auto* cloned_rhs = ctx->Clone(rhs);
+    return ctx->dst->create<AssignmentStatement>(cloned_source, cloned_lhs, cloned_rhs);
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/assignment_statement.h b/src/tint/lang/wgsl/ast/assignment_statement.h
new file mode 100644
index 0000000..a635987
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/assignment_statement.h
@@ -0,0 +1,56 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_ASSIGNMENT_STATEMENT_H_
+#define SRC_TINT_LANG_WGSL_AST_ASSIGNMENT_STATEMENT_H_
+
+#include "src/tint/lang/wgsl/ast/expression.h"
+#include "src/tint/lang/wgsl/ast/statement.h"
+
+namespace tint::ast {
+
+/// An assignment statement, made of a left-hand side and a right-hand side expression
+class AssignmentStatement final : public utils::Castable<AssignmentStatement, Statement> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param source the assignment statement source
+ /// @param lhs the left side of the assignment (must not be null)
+ /// @param rhs the right side of the assignment (must not be null)
+ AssignmentStatement(ProgramID pid,
+ NodeID nid,
+ const Source& source,
+ const Expression* lhs,
+ const Expression* rhs);
+
+ /// Destructor
+ ~AssignmentStatement() override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const AssignmentStatement* Clone(CloneContext* ctx) const override;
+
+ /// The left side expression of the assignment
+ const Expression* const lhs;
+
+ /// The right side expression of the assignment
+ const Expression* const rhs;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_ASSIGNMENT_STATEMENT_H_
diff --git a/src/tint/lang/wgsl/ast/assignment_statement_test.cc b/src/tint/lang/wgsl/ast/assignment_statement_test.cc
new file mode 100644
index 0000000..9785a72
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/assignment_statement_test.cc
@@ -0,0 +1,93 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/assignment_statement.h"
+
+#include "gtest/gtest-spi.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+using namespace tint::number_suffixes; // NOLINT
+
+namespace tint::ast {
+namespace {
+
+using AssignmentStatementTest = TestHelper;
+
+TEST_F(AssignmentStatementTest, Creation) {
+    // Both operands must be stored unchanged on the new node.
+    auto* target = Expr("lhs");
+    auto* value = Expr("rhs");
+    auto* assign = create<AssignmentStatement>(target, value);
+    EXPECT_EQ(assign->lhs, target);
+    EXPECT_EQ(assign->rhs, value);
+}
+
+TEST_F(AssignmentStatementTest, CreationWithSource) {
+    // The source location given at creation must be readable from the node.
+    auto* target = Expr("lhs");
+    auto* value = Expr("rhs");
+    auto* assign = create<AssignmentStatement>(Source{Source::Location{20, 2}}, target, value);
+    const auto& begin = assign->source.range.begin;
+    EXPECT_EQ(begin.line, 20u);
+    EXPECT_EQ(begin.column, 2u);
+}
+
+TEST_F(AssignmentStatementTest, IsAssign) {
+    // The node's runtime type query must report AssignmentStatement.
+    auto* target = Expr("lhs");
+    auto* value = Expr("rhs");
+    auto* assign = create<AssignmentStatement>(target, value);
+    EXPECT_TRUE(assign->Is<AssignmentStatement>());
+}
+
+TEST_F(AssignmentStatementTest, Assert_Null_LHS) {  // a null LHS must be rejected
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder builder;
+            builder.create<AssignmentStatement>(nullptr, builder.Expr(1_i));
+        },
+        "internal compiler error");
+}
+
+TEST_F(AssignmentStatementTest, Assert_Null_RHS) {  // a null RHS must be rejected
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder builder;
+            builder.create<AssignmentStatement>(builder.Expr(1_i), nullptr);
+        },
+        "internal compiler error");
+}
+
+TEST_F(AssignmentStatementTest, Assert_DifferentProgramID_LHS) {  // LHS from another builder
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder owner;
+            ProgramBuilder other;
+            owner.create<AssignmentStatement>(other.Expr("lhs"), owner.Expr("rhs"));
+        },
+        "internal compiler error");
+}
+
+TEST_F(AssignmentStatementTest, Assert_DifferentProgramID_RHS) {  // RHS from another builder
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder owner;
+            ProgramBuilder other;
+            owner.create<AssignmentStatement>(owner.Expr("lhs"), other.Expr("rhs"));
+        },
+        "internal compiler error");
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/attribute.cc b/src/tint/lang/wgsl/ast/attribute.cc
new file mode 100644
index 0000000..a2238f6
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/attribute.cc
@@ -0,0 +1,23 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/attribute.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::Attribute);
+
+namespace tint::ast {
+
+Attribute::~Attribute() = default;  // defaulted here; attribute.h only declares the destructor
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/attribute.h b/src/tint/lang/wgsl/ast/attribute.h
new file mode 100644
index 0000000..6356bb4
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/attribute.h
@@ -0,0 +1,67 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_ATTRIBUTE_H_
+#define SRC_TINT_LANG_WGSL_AST_ATTRIBUTE_H_
+
+#include <string>
+#include <vector>
+
+#include "src/tint/lang/wgsl/ast/node.h"
+
+namespace tint::ast {
+
+/// The base class for all attributes. Abstract: Name() is pure virtual.
+class Attribute : public utils::Castable<Attribute, Node> {
+ public:
+ ~Attribute() override;  // Destructor
+
+ /// @returns the WGSL name for the attribute
+ virtual std::string Name() const = 0;
+
+ protected:
+ /// Constructor. Protected, so it is not callable by users of the class, only by derived classes.
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ Attribute(ProgramID pid, NodeID nid, const Source& src) : Base(pid, nid, src) {}
+};
+
+/// Searches a list of attributes for an attribute of any of the given types.
+/// @param attributes the list of attributes to search
+/// @returns true if `attributes` includes an attribute of any of the types `Ts`
+template <typename... Ts>
+bool HasAttribute(utils::VectorRef<const Attribute*> attributes) {
+    bool found = false;
+    for (const Attribute* candidate : attributes) {
+        found = found || candidate->IsAnyOf<Ts...>();  // no further checks once found
+    }
+    return found;
+}
+
+/// @param attributes the list of attributes to search
+/// @returns a pointer to the first `T` in `attributes` if found, otherwise nullptr.
+template <typename T>
+const T* GetAttribute(utils::VectorRef<const Attribute*> attributes) {
+    const T* match = nullptr;
+    for (const Attribute* candidate : attributes) {
+        const bool take = !match && candidate->Is<T>();  // only the first match is kept
+        match = take ? candidate->As<T>() : match;
+    }
+    return match;
+}
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_ATTRIBUTE_H_
diff --git a/src/tint/lang/wgsl/ast/binary_expression.cc b/src/tint/lang/wgsl/ast/binary_expression.cc
new file mode 100644
index 0000000..2de3b57
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/binary_expression.cc
@@ -0,0 +1,47 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/binary_expression.h"
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::BinaryExpression);
+
+namespace tint::ast {
+
+BinaryExpression::BinaryExpression(ProgramID pid,
+                                   NodeID nid,
+                                   const Source& src,
+                                   BinaryOp oper,
+                                   const Expression* left,
+                                   const Expression* right)
+    : Base(pid, nid, src), op(oper), lhs(left), rhs(right) {
+    TINT_ASSERT(AST, lhs);  // both operands are required
+    TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, lhs, program_id);
+    TINT_ASSERT(AST, rhs);
+    TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, rhs, program_id);
+    TINT_ASSERT(AST, op != BinaryOp::kNone);  // kNone is not accepted as an operator
+}
+
+BinaryExpression::~BinaryExpression() = default;
+
+const BinaryExpression* BinaryExpression::Clone(CloneContext* ctx) const {
+    // Each argument is cloned in its own statement, so the clone order is fixed
+    auto cloned_source = ctx->Clone(source);
+    auto* cloned_lhs = ctx->Clone(lhs);
+    auto* cloned_rhs = ctx->Clone(rhs);
+    return ctx->dst->create<BinaryExpression>(cloned_source, op, cloned_lhs, cloned_rhs);
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/binary_expression.h b/src/tint/lang/wgsl/ast/binary_expression.h
new file mode 100644
index 0000000..12aa167
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/binary_expression.h
@@ -0,0 +1,310 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_BINARY_EXPRESSION_H_
+#define SRC_TINT_LANG_WGSL_AST_BINARY_EXPRESSION_H_
+
+#include "src/tint/lang/wgsl/ast/expression.h"
+
+namespace tint::ast {
+
+/// The operator type
+enum class BinaryOp {
+ kNone = 0, // no operator; asserted against by the BinaryExpression constructor
+ kAnd, // &
+ kOr, // |
+ kXor, // ^
+ kLogicalAnd, // &&
+ kLogicalOr, // ||
+ kEqual, // ==
+ kNotEqual, // !=
+ kLessThan, // <
+ kGreaterThan, // >
+ kLessThanEqual, // <=
+ kGreaterThanEqual, // >=
+ kShiftLeft, // <<
+ kShiftRight, // >>
+ kAdd, // +
+ kSubtract, // -
+ kMultiply, // *
+ kDivide, // '/'
+ kModulo, // %
+};
+
+/// A binary expression
+class BinaryExpression final : public utils::Castable<BinaryExpression, Expression> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param source the binary expression source
+ /// @param op the operation type; BinaryOp::kNone is asserted against
+ /// @param lhs the left side of the expression; asserted to be non-null
+ /// @param rhs the right side of the expression; asserted to be non-null
+ BinaryExpression(ProgramID pid,
+ NodeID nid,
+ const Source& source,
+ BinaryOp op,
+ const Expression* lhs,
+ const Expression* rhs);
+ /// Move constructor. NOTE(review): binary_expression.cc defines none - confirm it is needed
+ BinaryExpression(BinaryExpression&&);
+ ~BinaryExpression() override;  // Destructor
+
+ /// @returns true if the op is bitwise and (`&`)
+ bool IsAnd() const { return op == BinaryOp::kAnd; }
+ /// @returns true if the op is bitwise or (`|`)
+ bool IsOr() const { return op == BinaryOp::kOr; }
+ /// @returns true if the op is xor (`^`)
+ bool IsXor() const { return op == BinaryOp::kXor; }
+ /// @returns true if the op is logical and (`&&`)
+ bool IsLogicalAnd() const { return op == BinaryOp::kLogicalAnd; }
+ /// @returns true if the op is logical or (`||`)
+ bool IsLogicalOr() const { return op == BinaryOp::kLogicalOr; }
+ /// @returns true if the op is equal (`==`)
+ bool IsEqual() const { return op == BinaryOp::kEqual; }
+ /// @returns true if the op is not equal (`!=`)
+ bool IsNotEqual() const { return op == BinaryOp::kNotEqual; }
+ /// @returns true if the op is less than (`<`)
+ bool IsLessThan() const { return op == BinaryOp::kLessThan; }
+ /// @returns true if the op is greater than (`>`)
+ bool IsGreaterThan() const { return op == BinaryOp::kGreaterThan; }
+ /// @returns true if the op is less than equal (`<=`)
+ bool IsLessThanEqual() const { return op == BinaryOp::kLessThanEqual; }
+ /// @returns true if the op is greater than equal (`>=`)
+ bool IsGreaterThanEqual() const { return op == BinaryOp::kGreaterThanEqual; }
+ /// @returns true if the op is shift left (`<<`)
+ bool IsShiftLeft() const { return op == BinaryOp::kShiftLeft; }
+ /// @returns true if the op is shift right (`>>`)
+ bool IsShiftRight() const { return op == BinaryOp::kShiftRight; }
+ /// @returns true if the op is add (`+`)
+ bool IsAdd() const { return op == BinaryOp::kAdd; }
+ /// @returns true if the op is subtract (`-`)
+ bool IsSubtract() const { return op == BinaryOp::kSubtract; }
+ /// @returns true if the op is multiply (`*`)
+ bool IsMultiply() const { return op == BinaryOp::kMultiply; }
+ /// @returns true if the op is divide (`/`)
+ bool IsDivide() const { return op == BinaryOp::kDivide; }
+ /// @returns true if the op is modulo (`%`)
+ bool IsModulo() const { return op == BinaryOp::kModulo; }
+ /// @returns true if the op is an arithmetic operation (add, subtract, multiply, divide, modulo)
+ bool IsArithmetic() const;
+ /// @returns true if the op is a comparison operation (any of the six equality / ordering ops)
+ bool IsComparison() const;
+ /// @returns true if the op is a bitwise operation (and, or, xor)
+ bool IsBitwise() const;
+ /// @returns true if the op is a bit shift operation (shift left, shift right)
+ bool IsBitshift() const;
+ /// @returns true if the op is a logical expression (logical and, logical or)
+ bool IsLogical() const;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const BinaryExpression* Clone(CloneContext* ctx) const override;
+
+ /// the binary op type
+ const BinaryOp op;
+ /// the left side expression
+ const Expression* const lhs;
+ /// the right side expression
+ const Expression* const rhs;
+};
+
+/// Tests whether a binary operator is an arithmetic operator.
+/// @param op the operator
+/// @returns true if the op is `+`, `-`, `*`, `/` or `%`, otherwise false
+inline bool IsArithmetic(BinaryOp op) {
+    // Additive operators.
+    if (op == BinaryOp::kAdd || op == BinaryOp::kSubtract) {
+        return true;
+    }
+    // Multiplicative operators.
+    if (op == BinaryOp::kMultiply || op == BinaryOp::kDivide) {
+        return true;
+    }
+    return op == BinaryOp::kModulo;
+}
+
+/// Tests whether a binary operator is a comparison operator.
+/// @param op the operator
+/// @returns true if the op is `==`, `!=`, `<`, `<=`, `>` or `>=`, otherwise false
+inline bool IsComparison(BinaryOp op) {
+    // Equality operators.
+    if (op == BinaryOp::kEqual || op == BinaryOp::kNotEqual) {
+        return true;
+    }
+    // Strict orderings.
+    if (op == BinaryOp::kLessThan || op == BinaryOp::kGreaterThan) {
+        return true;
+    }
+    // Non-strict orderings.
+    return op == BinaryOp::kLessThanEqual || op == BinaryOp::kGreaterThanEqual;
+}
+
+/// Tests whether a binary operator is a bitwise operator.
+///
+/// The logical operators `&&` and `||` are not reported as bitwise.
+///
+/// @param op the operator
+/// @returns true if the op is `&`, `|` or `^`, otherwise false
+inline bool IsBitwise(BinaryOp op) {
+    const bool is_and = op == BinaryOp::kAnd;
+    const bool is_or = op == BinaryOp::kOr;
+    const bool is_xor = op == BinaryOp::kXor;
+    return is_and || is_or || is_xor;
+}
+
+/// Tests whether a binary operator is a bit shift operator.
+///
+/// @param op the operator
+/// @returns true if the op is `<<` or `>>`, otherwise false
+inline bool IsBitshift(BinaryOp op) {
+    if (op == BinaryOp::kShiftLeft) {
+        return true;
+    }
+    // Any remaining operator is a shift only if it is `>>`.
+    return op == BinaryOp::kShiftRight;
+}
+
+inline bool BinaryExpression::IsLogical() const {
+    // Only `&&` and `||` count as logical operators; `&`, `|` and `^` are
+    // classified by IsBitwise() instead.
+    if (IsLogicalAnd()) {
+        return true;
+    }
+
+    return IsLogicalOr();
+}
+
+inline bool BinaryExpression::IsArithmetic() const {
+ return ast::IsArithmetic(op);  // `ast::` selects the free function, not this member
+}
+
+inline bool BinaryExpression::IsComparison() const {
+ return ast::IsComparison(op);  // `ast::` selects the free function, not this member
+}
+
+inline bool BinaryExpression::IsBitwise() const {
+ return ast::IsBitwise(op);  // `ast::` selects the free function, not this member
+}
+
+inline bool BinaryExpression::IsBitshift() const {
+ return ast::IsBitshift(op);  // `ast::` selects the free function, not this member
+}
+
+/// @returns the human readable name of the given BinaryOp, or "<invalid>" for an unlisted value
+/// @param op the BinaryOp
+constexpr const char* FriendlyName(BinaryOp op) {
+ switch (op) {  // one case per enumerator, no default
+ case BinaryOp::kNone:
+ return "none";
+ case BinaryOp::kAnd:
+ return "and";
+ case BinaryOp::kOr:
+ return "or";
+ case BinaryOp::kXor:
+ return "xor";
+ case BinaryOp::kLogicalAnd:
+ return "logical_and";
+ case BinaryOp::kLogicalOr:
+ return "logical_or";
+ case BinaryOp::kEqual:
+ return "equal";
+ case BinaryOp::kNotEqual:
+ return "not_equal";
+ case BinaryOp::kLessThan:
+ return "less_than";
+ case BinaryOp::kGreaterThan:
+ return "greater_than";
+ case BinaryOp::kLessThanEqual:
+ return "less_than_equal";
+ case BinaryOp::kGreaterThanEqual:
+ return "greater_than_equal";
+ case BinaryOp::kShiftLeft:
+ return "shift_left";
+ case BinaryOp::kShiftRight:
+ return "shift_right";
+ case BinaryOp::kAdd:
+ return "add";
+ case BinaryOp::kSubtract:
+ return "subtract";
+ case BinaryOp::kMultiply:
+ return "multiply";
+ case BinaryOp::kDivide:
+ return "divide";
+ case BinaryOp::kModulo:
+ return "modulo";
+ }
+ return "<invalid>";  // reached only when `op` matches none of the cases above
+}
+
+/// @returns the WGSL operator of the BinaryOp, or "<invalid>" when `op` has no case (e.g. kNone)
+/// @param op the BinaryOp
+constexpr const char* Operator(BinaryOp op) {
+ switch (op) {
+ case BinaryOp::kAnd:
+ return "&";
+ case BinaryOp::kOr:
+ return "|";
+ case BinaryOp::kXor:
+ return "^";
+ case BinaryOp::kLogicalAnd:
+ return "&&";
+ case BinaryOp::kLogicalOr:
+ return "||";
+ case BinaryOp::kEqual:
+ return "==";
+ case BinaryOp::kNotEqual:
+ return "!=";
+ case BinaryOp::kLessThan:
+ return "<";
+ case BinaryOp::kGreaterThan:
+ return ">";
+ case BinaryOp::kLessThanEqual:
+ return "<=";
+ case BinaryOp::kGreaterThanEqual:
+ return ">=";
+ case BinaryOp::kShiftLeft:
+ return "<<";
+ case BinaryOp::kShiftRight:
+ return ">>";
+ case BinaryOp::kAdd:
+ return "+";
+ case BinaryOp::kSubtract:
+ return "-";
+ case BinaryOp::kMultiply:
+ return "*";
+ case BinaryOp::kDivide:
+ return "/";
+ case BinaryOp::kModulo:
+ return "%";
+ default:  // kNone and any unlisted value fall through to "<invalid>"
+ break;
+ }
+ return "<invalid>";
+}
+
+/// @param out the stream that the FriendlyName() of `op` is written to
+/// @param op the BinaryOp
+/// @return `out`, so calls can be chained
+inline utils::StringStream& operator<<(utils::StringStream& out, BinaryOp op) {
+ out << FriendlyName(op);
+ return out;
+}
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_BINARY_EXPRESSION_H_
diff --git a/src/tint/lang/wgsl/ast/binary_expression_test.cc b/src/tint/lang/wgsl/ast/binary_expression_test.cc
new file mode 100644
index 0000000..6385ae1
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/binary_expression_test.cc
@@ -0,0 +1,90 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gtest/gtest-spi.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+namespace tint::ast {
+namespace {
+
+using BinaryExpressionTest = TestHelper;
+
+TEST_F(BinaryExpressionTest, Creation) {
+    // The operator and both operands must be stored unchanged on the new node.
+    auto* left = Expr("lhs");
+    auto* right = Expr("rhs");
+    auto* binary = create<BinaryExpression>(BinaryOp::kEqual, left, right);
+    EXPECT_EQ(binary->lhs, left);
+    EXPECT_EQ(binary->rhs, right);
+    EXPECT_EQ(binary->op, BinaryOp::kEqual);
+}
+
+TEST_F(BinaryExpressionTest, Creation_WithSource) {
+    // The source location given at creation must be readable from the node.
+    auto* left = Expr("lhs");
+    auto* right = Expr("rhs");
+    auto* binary =
+        create<BinaryExpression>(Source{Source::Location{20, 2}}, BinaryOp::kEqual, left, right);
+    EXPECT_EQ(binary->source.range.begin.line, 20u);
+    EXPECT_EQ(binary->source.range.begin.column, 2u);
+}
+
+TEST_F(BinaryExpressionTest, IsBinary) {
+    // The node's runtime type query must report BinaryExpression.
+    auto* left = Expr("lhs");
+    auto* right = Expr("rhs");
+    auto* binary = create<BinaryExpression>(BinaryOp::kEqual, left, right);
+    EXPECT_TRUE(binary->Is<BinaryExpression>());
+}
+
+TEST_F(BinaryExpressionTest, Assert_Null_LHS) {  // a null LHS must be rejected
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder builder;
+            builder.create<BinaryExpression>(BinaryOp::kEqual, nullptr, builder.Expr("rhs"));
+        },
+        "internal compiler error");
+}
+
+TEST_F(BinaryExpressionTest, Assert_Null_RHS) {  // a null RHS must be rejected
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder builder;
+            builder.create<BinaryExpression>(BinaryOp::kEqual, builder.Expr("lhs"), nullptr);
+        },
+        "internal compiler error");
+}
+
+TEST_F(BinaryExpressionTest, Assert_DifferentProgramID_LHS) {  // LHS from another builder
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder owner;
+            ProgramBuilder other;
+            owner.create<BinaryExpression>(BinaryOp::kEqual, other.Expr("lhs"), owner.Expr("rhs"));
+        },
+        "internal compiler error");
+}
+
+TEST_F(BinaryExpressionTest, Assert_DifferentProgramID_RHS) {  // RHS from another builder
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder owner;
+            ProgramBuilder other;
+            owner.create<BinaryExpression>(BinaryOp::kEqual, owner.Expr("lhs"), other.Expr("rhs"));
+        },
+        "internal compiler error");
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/binding_attribute.cc b/src/tint/lang/wgsl/ast/binding_attribute.cc
new file mode 100644
index 0000000..3f06e95
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/binding_attribute.cc
@@ -0,0 +1,44 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/binding_attribute.h"
+
+#include <string>
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::BindingAttribute);
+
+namespace tint::ast {
+
+BindingAttribute::BindingAttribute(ProgramID pid,
+                                   NodeID nid,
+                                   const Source& src,
+                                   const Expression* binding_expr)
+    : Base(pid, nid, src), expr(binding_expr) {}  // NOTE(review): expr is not asserted non-null
+
+BindingAttribute::~BindingAttribute() = default;
+
+std::string BindingAttribute::Name() const {
+    return std::string("binding");  // the WGSL name of this attribute
+}
+
+const BindingAttribute* BindingAttribute::Clone(CloneContext* ctx) const {
+    // Each argument is cloned in its own statement, so the clone order is fixed
+    auto cloned_source = ctx->Clone(source);
+    auto* cloned_expr = ctx->Clone(expr);
+    return ctx->dst->create<BindingAttribute>(cloned_source, cloned_expr);
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/binding_attribute.h b/src/tint/lang/wgsl/ast/binding_attribute.h
new file mode 100644
index 0000000..918674e
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/binding_attribute.h
@@ -0,0 +1,51 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_BINDING_ATTRIBUTE_H_
+#define SRC_TINT_LANG_WGSL_AST_BINDING_ATTRIBUTE_H_
+
+#include <string>
+
+#include "src/tint/lang/wgsl/ast/attribute.h"
+#include "src/tint/lang/wgsl/ast/expression.h"
+
+namespace tint::ast {
+
+/// A binding attribute, holding the expression given as the attribute's argument
+class BindingAttribute final : public utils::Castable<BindingAttribute, Attribute> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ /// @param expr the binding expression
+ BindingAttribute(ProgramID pid, NodeID nid, const Source& src, const Expression* expr);
+ ~BindingAttribute() override;  // Destructor
+
+ /// @returns the WGSL name for the attribute, "binding"
+ std::string Name() const override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const BindingAttribute* Clone(CloneContext* ctx) const override;
+
+ /// The binding expression, as passed to the constructor
+ const Expression* const expr;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_BINDING_ATTRIBUTE_H_
diff --git a/src/tint/lang/wgsl/ast/binding_attribute_test.cc b/src/tint/lang/wgsl/ast/binding_attribute_test.cc
new file mode 100644
index 0000000..02f895c
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/binding_attribute_test.cc
@@ -0,0 +1,29 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+namespace tint::ast {
+namespace {
+
+using namespace tint::number_suffixes; // NOLINT
+using BindingAttributeTest = TestHelper;
+
+TEST_F(BindingAttributeTest, Creation) {
+    auto* binding = Binding(2_a);  // the literal argument must become an IntLiteralExpression
+    EXPECT_TRUE(binding->expr->Is<IntLiteralExpression>());
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/bitcast_expression.cc b/src/tint/lang/wgsl/ast/bitcast_expression.cc
new file mode 100644
index 0000000..3480d57
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/bitcast_expression.cc
@@ -0,0 +1,44 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/bitcast_expression.h"
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::BitcastExpression);
+
+namespace tint::ast {
+
+BitcastExpression::BitcastExpression(ProgramID pid,
+                                     NodeID nid,
+                                     const Source& src,
+                                     Type target,
+                                     const Expression* operand)
+    : Base(pid, nid, src), type(target), expr(operand) {
+    TINT_ASSERT(AST, type);  // a target type is required
+    TINT_ASSERT(AST, expr);  // an operand is required
+    TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, expr, program_id);
+}
+
+BitcastExpression::~BitcastExpression() = default;
+
+const BitcastExpression* BitcastExpression::Clone(CloneContext* ctx) const {
+    // Each argument is cloned in its own statement, so the clone order is fixed
+    auto cloned_source = ctx->Clone(source);
+    auto cloned_type = ctx->Clone(type);
+    auto* cloned_expr = ctx->Clone(expr);
+    return ctx->dst->create<BitcastExpression>(cloned_source, cloned_type, cloned_expr);
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/bitcast_expression.h b/src/tint/lang/wgsl/ast/bitcast_expression.h
new file mode 100644
index 0000000..ef2de22
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/bitcast_expression.h
@@ -0,0 +1,55 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_BITCAST_EXPRESSION_H_
+#define SRC_TINT_LANG_WGSL_AST_BITCAST_EXPRESSION_H_
+
+#include "src/tint/lang/wgsl/ast/expression.h"
+#include "src/tint/lang/wgsl/ast/type.h"
+
+namespace tint::ast {
+
+/// A bitcast expression, holding a target type and the expression being cast
+class BitcastExpression final : public utils::Castable<BitcastExpression, Expression> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param source the bitcast expression source
+ /// @param type the target cast type; asserted to be valid
+ /// @param expr the expression to cast; asserted to be non-null
+ BitcastExpression(ProgramID pid,
+ NodeID nid,
+ const Source& source,
+ Type type,
+ const Expression* expr);
+
+ /// Destructor
+ ~BitcastExpression() override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const BitcastExpression* Clone(CloneContext* ctx) const override;
+
+ /// the target cast type
+ const Type type;
+ /// the expression being cast
+ const Expression* const expr;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_BITCAST_EXPRESSION_H_
diff --git a/src/tint/lang/wgsl/ast/bitcast_expression_test.cc b/src/tint/lang/wgsl/ast/bitcast_expression_test.cc
new file mode 100644
index 0000000..54df821
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/bitcast_expression_test.cc
@@ -0,0 +1,77 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/bitcast_expression.h"
+
+#include "gtest/gtest-spi.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+namespace tint::ast {
+namespace {
+
+using BitcastExpressionTest = TestHelper;
+
+TEST_F(BitcastExpressionTest, Create) {
+    auto* operand = Expr("expr");
+    auto* bitcast = Bitcast(ty.f32(), operand);
+    CheckIdentifier(bitcast->type, "f32");  // the target type must be the identifier "f32"
+    ASSERT_EQ(bitcast->expr, operand);
+}
+
+TEST_F(BitcastExpressionTest, CreateWithSource) {
+    // The source location given at creation must be readable from the node.
+    auto* operand = Expr("expr");
+    auto* bitcast = Bitcast(Source{Source::Location{20, 2}}, ty.f32(), operand);
+    const auto& begin = bitcast->source.range.begin;
+    EXPECT_EQ(begin.line, 20u);
+    EXPECT_EQ(begin.column, 2u);
+}
+
+TEST_F(BitcastExpressionTest, IsBitcast) {
+    // The node's runtime type query must report BitcastExpression.
+    auto* operand = Expr("expr");
+    auto* bitcast = Bitcast(ty.f32(), operand);
+    EXPECT_TRUE(bitcast->Is<BitcastExpression>());
+}
+
+TEST_F(BitcastExpressionTest, Assert_Null_Type) {  // a default-constructed type must be rejected
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder builder;
+            builder.Bitcast(ast::Type(), "idx");
+        },
+        "internal compiler error");
+}
+
+TEST_F(BitcastExpressionTest, Assert_Null_Expr) {  // a null operand must be rejected
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder builder;
+            builder.Bitcast(builder.ty.f32(), nullptr);
+        },
+        "internal compiler error");
+}
+
+TEST_F(BitcastExpressionTest, Assert_DifferentProgramID_Expr) {  // operand from another builder
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder owner;
+            ProgramBuilder other;
+            owner.Bitcast(owner.ty.f32(), other.Expr("idx"));
+        },
+        "internal compiler error");
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/block_statement.cc b/src/tint/lang/wgsl/ast/block_statement.cc
new file mode 100644
index 0000000..a87d28f
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/block_statement.cc
@@ -0,0 +1,49 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/block_statement.h"
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::BlockStatement);
+
+namespace tint::ast {
+
+BlockStatement::BlockStatement(ProgramID pid,
+                               NodeID nid,
+                               const Source& src,
+                               utils::VectorRef<const Statement*> stmts,  // by value: moved below
+                               utils::VectorRef<const Attribute*> attrs)  // by value: moved below
+    : Base(pid, nid, src), statements(std::move(stmts)), attributes(std::move(attrs)) {
+    for (auto* stmt : statements) {  // every statement must be non-null and pass the ID check
+        TINT_ASSERT(AST, stmt);
+        TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, stmt, program_id);
+    }
+    for (auto* attr : attributes) {  // the same requirements apply to every attribute
+        TINT_ASSERT(AST, attr);
+        TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, attr, program_id);
+    }
+}
+
+BlockStatement::~BlockStatement() = default;
+
+const BlockStatement* BlockStatement::Clone(CloneContext* ctx) const {
+    // Each argument is cloned in its own statement, so the clone order is fixed
+    auto cloned_source = ctx->Clone(source);
+    auto cloned_stmts = ctx->Clone(statements);
+    auto cloned_attrs = ctx->Clone(attributes);
+    return ctx->dst->create<BlockStatement>(cloned_source, std::move(cloned_stmts), std::move(cloned_attrs));
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/block_statement.h b/src/tint/lang/wgsl/ast/block_statement.h
new file mode 100644
index 0000000..befb6c7
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/block_statement.h
@@ -0,0 +1,68 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_BLOCK_STATEMENT_H_
+#define SRC_TINT_LANG_WGSL_AST_BLOCK_STATEMENT_H_
+
+#include <utility>
+
+#include "src/tint/lang/wgsl/ast/statement.h"
+
+// Forward declarations
+namespace tint::ast {
+class Attribute;
+} // namespace tint::ast
+
+namespace tint::ast {
+
+/// A block statement
+class BlockStatement final : public utils::Castable<BlockStatement, Statement> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param source the block statement source
+ /// @param statements the statements
+ /// @param attributes the block statement attributes
+ BlockStatement(ProgramID pid,
+ NodeID nid,
+ const Source& source,
+ utils::VectorRef<const Statement*> statements,
+ utils::VectorRef<const Attribute*> attributes);
+
+ /// Destructor
+ ~BlockStatement() override;
+
+ /// @returns true if the block has no statements
+ bool Empty() const { return statements.IsEmpty(); }
+
+ /// @returns the last statement in the block, or nullptr if the block is empty
+ const Statement* Last() const { return statements.IsEmpty() ? nullptr : statements.Back(); }
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const BlockStatement* Clone(CloneContext* ctx) const override;
+
+ /// the statement list
+ const utils::Vector<const Statement*, 8> statements;
+
+ /// the attribute list
+ const utils::Vector<const Attribute*, 4> attributes;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_BLOCK_STATEMENT_H_
diff --git a/src/tint/lang/wgsl/ast/block_statement_test.cc b/src/tint/lang/wgsl/ast/block_statement_test.cc
new file mode 100644
index 0000000..d4ec4cf
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/block_statement_test.cc
@@ -0,0 +1,82 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gmock/gmock.h"
+#include "gtest/gtest-spi.h"
+#include "src/tint/lang/wgsl/ast/discard_statement.h"
+#include "src/tint/lang/wgsl/ast/if_statement.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+namespace tint::ast {
+namespace {
+
+using BlockStatementTest = TestHelper;
+
+TEST_F(BlockStatementTest, Creation) {
+ auto* d = create<DiscardStatement>();
+ auto* ptr = d;
+
+ auto* b = create<BlockStatement>(utils::Vector{d}, utils::Empty);
+
+ ASSERT_EQ(b->statements.Length(), 1u);
+ EXPECT_EQ(b->statements[0], ptr);
+ EXPECT_EQ(b->attributes.Length(), 0u);
+}
+
+TEST_F(BlockStatementTest, Creation_WithSource) {
+ auto* b = create<BlockStatement>(Source{Source::Location{20, 2}}, utils::Empty, utils::Empty);
+ auto src = b->source;
+ EXPECT_EQ(src.range.begin.line, 20u);
+ EXPECT_EQ(src.range.begin.column, 2u);
+}
+
+TEST_F(BlockStatementTest, Creation_WithAttributes) {
+ auto* d = create<DiscardStatement>();
+ auto* ptr = d;
+
+ auto* attr1 = DiagnosticAttribute(builtin::DiagnosticSeverity::kOff, "foo");
+ auto* attr2 = DiagnosticAttribute(builtin::DiagnosticSeverity::kOff, "bar");
+ auto* b = create<BlockStatement>(utils::Vector{d}, utils::Vector{attr1, attr2});
+
+ ASSERT_EQ(b->statements.Length(), 1u);
+ EXPECT_EQ(b->statements[0], ptr);
+ EXPECT_THAT(b->attributes, testing::ElementsAre(attr1, attr2));
+}
+
+TEST_F(BlockStatementTest, IsBlock) {
+ auto* b = create<BlockStatement>(utils::Empty, utils::Empty);
+ EXPECT_TRUE(b->Is<BlockStatement>());
+}
+
+TEST_F(BlockStatementTest, Assert_Null_Statement) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b;
+ b.create<BlockStatement>(utils::Vector<const Statement*, 1>{nullptr}, utils::Empty);
+ },
+ "internal compiler error");
+}
+
+TEST_F(BlockStatementTest, Assert_DifferentProgramID_Statement) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.create<BlockStatement>(utils::Vector{b2.create<DiscardStatement>()}, utils::Empty);
+ },
+ "internal compiler error");
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/bool_literal_expression.cc b/src/tint/lang/wgsl/ast/bool_literal_expression.cc
new file mode 100644
index 0000000..bf275ea
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/bool_literal_expression.cc
@@ -0,0 +1,34 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/bool_literal_expression.h"
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::BoolLiteralExpression);
+
+namespace tint::ast {
+
+BoolLiteralExpression::BoolLiteralExpression(ProgramID pid, NodeID nid, const Source& src, bool val)
+ : Base(pid, nid, src), value(val) {}
+
+BoolLiteralExpression::~BoolLiteralExpression() = default;
+
+const BoolLiteralExpression* BoolLiteralExpression::Clone(CloneContext* ctx) const {
+ // Clone arguments outside of create() call to have deterministic ordering
+ auto src = ctx->Clone(source);
+ return ctx->dst->create<BoolLiteralExpression>(src, value);
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/bool_literal_expression.h b/src/tint/lang/wgsl/ast/bool_literal_expression.h
new file mode 100644
index 0000000..eef02fa
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/bool_literal_expression.h
@@ -0,0 +1,48 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_BOOL_LITERAL_EXPRESSION_H_
+#define SRC_TINT_LANG_WGSL_AST_BOOL_LITERAL_EXPRESSION_H_
+
+#include <string>
+
+#include "src/tint/lang/wgsl/ast/literal_expression.h"
+
+namespace tint::ast {
+
+/// A boolean literal
+class BoolLiteralExpression final
+ : public utils::Castable<BoolLiteralExpression, LiteralExpression> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ /// @param value the bool literal's value
+ BoolLiteralExpression(ProgramID pid, NodeID nid, const Source& src, bool value);
+ ~BoolLiteralExpression() override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const BoolLiteralExpression* Clone(CloneContext* ctx) const override;
+
+ /// The boolean literal value
+ const bool value;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_BOOL_LITERAL_EXPRESSION_H_
diff --git a/src/tint/lang/wgsl/ast/bool_literal_expression_test.cc b/src/tint/lang/wgsl/ast/bool_literal_expression_test.cc
new file mode 100644
index 0000000..134bea9
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/bool_literal_expression_test.cc
@@ -0,0 +1,35 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+namespace tint::ast {
+namespace {
+
+using BoolLiteralExpressionTest = TestHelper;
+
+TEST_F(BoolLiteralExpressionTest, True) {
+ auto* b = create<BoolLiteralExpression>(true);
+ ASSERT_TRUE(b->Is<BoolLiteralExpression>());
+ ASSERT_TRUE(b->value);
+}
+
+TEST_F(BoolLiteralExpressionTest, False) {
+ auto* b = create<BoolLiteralExpression>(false);
+ ASSERT_TRUE(b->Is<BoolLiteralExpression>());
+ ASSERT_FALSE(b->value);
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/break_if_statement.cc b/src/tint/lang/wgsl/ast/break_if_statement.cc
new file mode 100644
index 0000000..ddd0e90
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/break_if_statement.cc
@@ -0,0 +1,41 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/break_if_statement.h"
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::BreakIfStatement);
+
+namespace tint::ast {
+
+BreakIfStatement::BreakIfStatement(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const Expression* cond)
+ : Base(pid, nid, src), condition(cond) {
+ TINT_ASSERT(AST, condition);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, condition, program_id);
+}
+
+BreakIfStatement::~BreakIfStatement() = default;
+
+const BreakIfStatement* BreakIfStatement::Clone(CloneContext* ctx) const {
+ // Clone arguments outside of create() call to have deterministic ordering
+ auto src = ctx->Clone(source);
+ auto* cond = ctx->Clone(condition);
+ return ctx->dst->create<BreakIfStatement>(src, cond);
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/break_if_statement.h b/src/tint/lang/wgsl/ast/break_if_statement.h
new file mode 100644
index 0000000..c986e90
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/break_if_statement.h
@@ -0,0 +1,49 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_BREAK_IF_STATEMENT_H_
+#define SRC_TINT_LANG_WGSL_AST_BREAK_IF_STATEMENT_H_
+
+#include <utility>
+
+#include "src/tint/lang/wgsl/ast/block_statement.h"
+#include "src/tint/lang/wgsl/ast/expression.h"
+
+namespace tint::ast {
+
+/// A break-if statement
+class BreakIfStatement final : public utils::Castable<BreakIfStatement, Statement> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ /// @param condition the if condition
+ BreakIfStatement(ProgramID pid, NodeID nid, const Source& src, const Expression* condition);
+
+ /// Destructor
+ ~BreakIfStatement() override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext` `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const BreakIfStatement* Clone(CloneContext* ctx) const override;
+
+ /// The break-if condition. The constructor asserts that it is non-null.
+ const Expression* const condition;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_BREAK_IF_STATEMENT_H_
diff --git a/src/tint/lang/wgsl/ast/break_if_statement_test.cc b/src/tint/lang/wgsl/ast/break_if_statement_test.cc
new file mode 100644
index 0000000..2eb0091
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/break_if_statement_test.cc
@@ -0,0 +1,58 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/break_if_statement.h"
+
+#include "gtest/gtest-spi.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+namespace tint::ast {
+namespace {
+
+using BreakIfStatementTest = TestHelper;
+
+TEST_F(BreakIfStatementTest, Creation) {
+ auto* cond = Expr("cond");
+ auto* stmt = BreakIf(Source{Source::Location{20, 2}}, cond);
+ auto src = stmt->source;
+ EXPECT_EQ(src.range.begin.line, 20u);
+ EXPECT_EQ(src.range.begin.column, 2u);
+}
+
+TEST_F(BreakIfStatementTest, IsBreakIf) {
+ auto* stmt = BreakIf(Expr(true));
+ EXPECT_TRUE(stmt->Is<BreakIfStatement>());
+}
+
+TEST_F(BreakIfStatementTest, Assert_Null_Condition) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b;
+ b.BreakIf(nullptr);
+ },
+ "internal compiler error");
+}
+
+TEST_F(BreakIfStatementTest, Assert_DifferentProgramID_Cond) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.BreakIf(b2.Expr(true));
+ },
+ "internal compiler error");
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/break_statement.cc b/src/tint/lang/wgsl/ast/break_statement.cc
new file mode 100644
index 0000000..89c8c3a
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/break_statement.cc
@@ -0,0 +1,34 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/break_statement.h"
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::BreakStatement);
+
+namespace tint::ast {
+
+BreakStatement::BreakStatement(ProgramID pid, NodeID nid, const Source& src)
+ : Base(pid, nid, src) {}
+
+BreakStatement::~BreakStatement() = default;
+
+const BreakStatement* BreakStatement::Clone(CloneContext* ctx) const {
+ // Clone arguments outside of create() call to have deterministic ordering
+ auto src = ctx->Clone(source);
+ return ctx->dst->create<BreakStatement>(src);
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/break_statement.h b/src/tint/lang/wgsl/ast/break_statement.h
new file mode 100644
index 0000000..8d48e02
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/break_statement.h
@@ -0,0 +1,43 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_BREAK_STATEMENT_H_
+#define SRC_TINT_LANG_WGSL_AST_BREAK_STATEMENT_H_
+
+#include "src/tint/lang/wgsl/ast/statement.h"
+
+namespace tint::ast {
+
+/// A break statement
+class BreakStatement final : public utils::Castable<BreakStatement, Statement> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ BreakStatement(ProgramID pid, NodeID nid, const Source& src);
+
+ /// Destructor
+ ~BreakStatement() override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const BreakStatement* Clone(CloneContext* ctx) const override;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_BREAK_STATEMENT_H_
diff --git a/src/tint/lang/wgsl/ast/break_statement_test.cc b/src/tint/lang/wgsl/ast/break_statement_test.cc
new file mode 100644
index 0000000..324d53d
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/break_statement_test.cc
@@ -0,0 +1,37 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/break_statement.h"
+
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+namespace tint::ast {
+namespace {
+
+using BreakStatementTest = TestHelper;
+
+TEST_F(BreakStatementTest, Creation_WithSource) {
+ auto* stmt = create<BreakStatement>(Source{Source::Location{20, 2}});
+ auto src = stmt->source;
+ EXPECT_EQ(src.range.begin.line, 20u);
+ EXPECT_EQ(src.range.begin.column, 2u);
+}
+
+TEST_F(BreakStatementTest, IsBreak) {
+ auto* stmt = create<BreakStatement>();
+ EXPECT_TRUE(stmt->Is<BreakStatement>());
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/builtin_attribute.cc b/src/tint/lang/wgsl/ast/builtin_attribute.cc
new file mode 100644
index 0000000..3b33e6a
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/builtin_attribute.cc
@@ -0,0 +1,46 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/builtin_attribute.h"
+
+#include <string>
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::BuiltinAttribute);
+
+namespace tint::ast {
+
+BuiltinAttribute::BuiltinAttribute(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const Expression* b)
+ : Base(pid, nid, src), builtin(b) {
+ TINT_ASSERT_PROGRAM_IDS_EQUAL(AST, b, program_id);
+}
+
+BuiltinAttribute::~BuiltinAttribute() = default;
+
+std::string BuiltinAttribute::Name() const {
+ return "builtin";
+}
+
+const BuiltinAttribute* BuiltinAttribute::Clone(CloneContext* ctx) const {
+ // Clone arguments outside of create() call to have deterministic ordering
+ auto src = ctx->Clone(source);
+ auto b = ctx->Clone(builtin);
+ return ctx->dst->create<BuiltinAttribute>(src, b);
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/builtin_attribute.h b/src/tint/lang/wgsl/ast/builtin_attribute.h
new file mode 100644
index 0000000..a5157c2
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/builtin_attribute.h
@@ -0,0 +1,55 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_BUILTIN_ATTRIBUTE_H_
+#define SRC_TINT_LANG_WGSL_AST_BUILTIN_ATTRIBUTE_H_
+
+#include <string>
+
+#include "src/tint/lang/wgsl/ast/attribute.h"
+
+// Forward declarations
+namespace tint::ast {
+class Expression;
+}
+
+namespace tint::ast {
+
+/// A builtin attribute
+class BuiltinAttribute final : public utils::Castable<BuiltinAttribute, Attribute> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ /// @param builtin the builtin value
+ BuiltinAttribute(ProgramID pid, NodeID nid, const Source& src, const Expression* builtin);
+ ~BuiltinAttribute() override;
+
+ /// @returns the WGSL name for the attribute
+ std::string Name() const override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const BuiltinAttribute* Clone(CloneContext* ctx) const override;
+
+ /// The builtin value
+ const Expression* const builtin;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_BUILTIN_ATTRIBUTE_H_
diff --git a/src/tint/lang/wgsl/ast/builtin_attribute_test.cc b/src/tint/lang/wgsl/ast/builtin_attribute_test.cc
new file mode 100644
index 0000000..98254eb
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/builtin_attribute_test.cc
@@ -0,0 +1,50 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gtest/gtest-spi.h"
+
+#include "src/tint/builtin/builtin_value.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+namespace tint::ast {
+namespace {
+
+using BuiltinAttributeTest = TestHelper;
+
+TEST_F(BuiltinAttributeTest, Creation) {
+ auto* d = Builtin(builtin::BuiltinValue::kFragDepth);
+ CheckIdentifier(d->builtin, "frag_depth");
+}
+
+TEST_F(BuiltinAttributeTest, Assert_Null_Builtin) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b;
+ b.Builtin(nullptr);
+ },
+ "internal compiler error");
+}
+
+TEST_F(BuiltinAttributeTest, Assert_DifferentProgramID_Builtin) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.Builtin(b2.Expr("bang"));
+ },
+ "internal compiler error");
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/builtin_texture_helper_test.cc b/src/tint/lang/wgsl/ast/builtin_texture_helper_test.cc
new file mode 100644
index 0000000..5219e73
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/builtin_texture_helper_test.cc
@@ -0,0 +1,2558 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/builtin_texture_helper_test.h"
+
+#include "src/tint/builtin/texel_format.h"
+#include "src/tint/type/depth_texture.h"
+#include "src/tint/type/multisampled_texture.h"
+#include "src/tint/type/sampled_texture.h"
+#include "src/tint/type/texture_dimension.h"
+
+namespace tint::ast::test {
+namespace {
+
+using namespace tint::builtin::fluent_types; // NOLINT
+using namespace tint::number_suffixes; // NOLINT
+
+utils::StringStream& operator<<(utils::StringStream& out, const TextureKind& kind) {
+ switch (kind) {
+ case TextureKind::kRegular:
+ out << "regular";
+ break;
+ case TextureKind::kDepth:
+ out << "depth";
+ break;
+ case TextureKind::kDepthMultisampled:
+ out << "depth-multisampled";
+ break;
+ case TextureKind::kMultisampled:
+ out << "multisampled";
+ break;
+ case TextureKind::kStorage:
+ out << "storage";
+ break;
+ }
+ return out;
+}
+
+utils::StringStream& operator<<(utils::StringStream& out, const TextureDataType& ty) {
+ switch (ty) {
+ case TextureDataType::kF32:
+ out << "f32";
+ break;
+ case TextureDataType::kU32:
+ out << "u32";
+ break;
+ case TextureDataType::kI32:
+ out << "i32";
+ break;
+ }
+ return out;
+}
+
+} // namespace
+
+TextureOverloadCase::TextureOverloadCase(ValidTextureOverload o,
+ const char* desc,
+ TextureKind tk,
+ type::SamplerKind sk,
+ type::TextureDimension dims,
+ TextureDataType datatype,
+ const char* f,
+ std::function<Args(ProgramBuilder*)> a,
+ bool ret_val)
+ : overload(o),
+ description(desc),
+ texture_kind(tk),
+ sampler_kind(sk),
+ texture_dimension(dims),
+ texture_data_type(datatype),
+ function(f),
+ args(std::move(a)),
+ returns_value(ret_val) {}
+TextureOverloadCase::TextureOverloadCase(ValidTextureOverload o,
+ const char* desc,
+ TextureKind tk,
+ type::TextureDimension dims,
+ TextureDataType datatype,
+ const char* f,
+ std::function<Args(ProgramBuilder*)> a,
+ bool ret_val)
+ : overload(o),
+ description(desc),
+ texture_kind(tk),
+ texture_dimension(dims),
+ texture_data_type(datatype),
+ function(f),
+ args(std::move(a)),
+ returns_value(ret_val) {}
+TextureOverloadCase::TextureOverloadCase(ValidTextureOverload o,
+ const char* d,
+ tint::builtin::Access acc,
+ tint::builtin::TexelFormat fmt,
+ type::TextureDimension dims,
+ TextureDataType datatype,
+ const char* f,
+ std::function<Args(ProgramBuilder*)> a,
+ bool ret_val)
+ : overload(o),
+ description(d),
+ texture_kind(TextureKind::kStorage),
+ access(acc),
+ texel_format(fmt),
+ texture_dimension(dims),
+ texture_data_type(datatype),
+ function(f),
+ args(std::move(a)),
+ returns_value(ret_val) {}
+TextureOverloadCase::TextureOverloadCase(const TextureOverloadCase&) = default;
+TextureOverloadCase::~TextureOverloadCase() = default;
+
+std::ostream& operator<<(std::ostream& out, const TextureOverloadCase& data) {
+ utils::StringStream str;
+ str << "TextureOverloadCase " << static_cast<int>(data.overload) << "\n";
+ str << data.description << "\n";
+ str << "texture_kind: " << data.texture_kind << "\n";
+ str << "sampler_kind: ";
+ if (data.texture_kind != TextureKind::kStorage) {
+ str << data.sampler_kind;
+ } else {
+ str << "<unused>";
+ }
+ str << "\n";
+ str << "access: " << data.access << "\n";
+ str << "texel_format: " << data.texel_format << "\n";
+ str << "texture_dimension: " << data.texture_dimension << "\n";
+ str << "texture_data_type: " << data.texture_data_type << "\n";
+
+ out << str.str();
+ return out;
+}
+
+Type TextureOverloadCase::BuildResultVectorComponentType(ProgramBuilder* b) const {
+ switch (texture_data_type) {
+ case test::TextureDataType::kF32:
+ return b->ty.f32();
+ case test::TextureDataType::kU32:
+ return b->ty.u32();
+ case test::TextureDataType::kI32:
+ return b->ty.i32();
+ }
+
+ TINT_UNREACHABLE(AST, b->Diagnostics());
+ return {};
+}
+
+const Variable* TextureOverloadCase::BuildTextureVariable(ProgramBuilder* b) const {
+ utils::Vector attrs{
+ b->Group(0_u),
+ b->Binding(0_a),
+ };
+ switch (texture_kind) {
+ case test::TextureKind::kRegular:
+ return b->GlobalVar(
+ kTextureName,
+ b->ty.sampled_texture(texture_dimension, BuildResultVectorComponentType(b)), attrs);
+
+ case test::TextureKind::kDepth:
+ return b->GlobalVar(kTextureName, b->ty.depth_texture(texture_dimension), attrs);
+
+ case test::TextureKind::kDepthMultisampled:
+ return b->GlobalVar(kTextureName, b->ty.depth_multisampled_texture(texture_dimension),
+ attrs);
+
+ case test::TextureKind::kMultisampled:
+ return b->GlobalVar(
+ kTextureName,
+ b->ty.multisampled_texture(texture_dimension, BuildResultVectorComponentType(b)),
+ attrs);
+
+ case test::TextureKind::kStorage: {
+ auto st = b->ty.storage_texture(texture_dimension, texel_format, access);
+ return b->GlobalVar(kTextureName, st, attrs);
+ }
+ }
+
+ TINT_UNREACHABLE(AST, b->Diagnostics());
+ return nullptr;
+}
+
+const Variable* TextureOverloadCase::BuildSamplerVariable(ProgramBuilder* b) const {
+ utils::Vector attrs = {b->Group(0_a), b->Binding(1_a)};
+ return b->GlobalVar(kSamplerName, b->ty.sampler(sampler_kind), attrs);
+}
+
+std::vector<TextureOverloadCase> TextureOverloadCase::ValidCases() {
+ return {
+ {
+ ValidTextureOverload::kDimensions1d,
+ "textureDimensions(t : texture_1d<f32>) -> u32",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k1d,
+ TextureDataType::kF32,
+ "textureDimensions",
+ [](ProgramBuilder* b) { return b->ExprList(kTextureName); },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kDimensions2d,
+ "textureDimensions(t : texture_2d<f32>) -> vec2<u32>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2d,
+ TextureDataType::kF32,
+ "textureDimensions",
+ [](ProgramBuilder* b) { return b->ExprList(kTextureName); },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kDimensions2dLevel,
+ "textureDimensions(t : texture_2d<f32>,\n"
+ " level : i32) -> vec2<u32>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2d,
+ TextureDataType::kF32,
+ "textureDimensions",
+ [](ProgramBuilder* b) { return b->ExprList(kTextureName, 1_i); },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kDimensions2dArray,
+ "textureDimensions(t : texture_2d_array<f32>) -> vec2<u32>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2dArray,
+ TextureDataType::kF32,
+ "textureDimensions",
+ [](ProgramBuilder* b) { return b->ExprList(kTextureName); },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kDimensions2dArrayLevel,
+ "textureDimensions(t : texture_2d_array<f32>,\n"
+ " level : i32) -> vec2<u32>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2dArray,
+ TextureDataType::kF32,
+ "textureDimensions",
+ [](ProgramBuilder* b) { return b->ExprList(kTextureName, 1_i); },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kDimensions3d,
+ "textureDimensions(t : texture_3d<f32>) -> vec3<u32>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k3d,
+ TextureDataType::kF32,
+ "textureDimensions",
+ [](ProgramBuilder* b) { return b->ExprList(kTextureName); },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kDimensions3dLevel,
+ "textureDimensions(t : texture_3d<f32>,\n"
+ " level : i32) -> vec3<u32>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k3d,
+ TextureDataType::kF32,
+ "textureDimensions",
+ [](ProgramBuilder* b) { return b->ExprList(kTextureName, 1_i); },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kDimensionsCube,
+ "textureDimensions(t : texture_cube<f32>) -> vec2<u32>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::kCube,
+ TextureDataType::kF32,
+ "textureDimensions",
+ [](ProgramBuilder* b) { return b->ExprList(kTextureName); },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kDimensionsCubeLevel,
+ "textureDimensions(t : texture_cube<f32>,\n"
+ " level : i32) -> vec2<u32>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::kCube,
+ TextureDataType::kF32,
+ "textureDimensions",
+ [](ProgramBuilder* b) { return b->ExprList(kTextureName, 1_i); },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kDimensionsCubeArray,
+ "textureDimensions(t : texture_cube_array<f32>) -> vec2<u32>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::kCubeArray,
+ TextureDataType::kF32,
+ "textureDimensions",
+ [](ProgramBuilder* b) { return b->ExprList(kTextureName); },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kDimensionsCubeArrayLevel,
+ "textureDimensions(t : texture_cube_array<f32>,\n"
+ " level : i32) -> vec2<u32>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::kCubeArray,
+ TextureDataType::kF32,
+ "textureDimensions",
+ [](ProgramBuilder* b) { return b->ExprList(kTextureName, 1_i); },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kDimensionsMultisampled2d,
+ "textureDimensions(t : texture_multisampled_2d<f32>)-> vec2<u32>",
+ TextureKind::kMultisampled,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2d,
+ TextureDataType::kF32,
+ "textureDimensions",
+ [](ProgramBuilder* b) { return b->ExprList(kTextureName); },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kDimensionsDepth2d,
+ "textureDimensions(t : texture_depth_2d) -> vec2<u32>",
+ TextureKind::kDepth,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2d,
+ TextureDataType::kF32,
+ "textureDimensions",
+ [](ProgramBuilder* b) { return b->ExprList(kTextureName); },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kDimensionsDepth2dLevel,
+ "textureDimensions(t : texture_depth_2d,\n"
+ " level : i32) -> vec2<u32>",
+ TextureKind::kDepth,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2d,
+ TextureDataType::kF32,
+ "textureDimensions",
+ [](ProgramBuilder* b) { return b->ExprList(kTextureName, 1_i); },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kDimensionsDepth2dArray,
+ "textureDimensions(t : texture_depth_2d_array) -> vec2<u32>",
+ TextureKind::kDepth,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2dArray,
+ TextureDataType::kF32,
+ "textureDimensions",
+ [](ProgramBuilder* b) { return b->ExprList(kTextureName); },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kDimensionsDepth2dArrayLevel,
+ "textureDimensions(t : texture_depth_2d_array,\n"
+ " level : i32) -> vec2<u32>",
+ TextureKind::kDepth,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2dArray,
+ TextureDataType::kF32,
+ "textureDimensions",
+ [](ProgramBuilder* b) { return b->ExprList(kTextureName, 1_i); },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kDimensionsDepthCube,
+ "textureDimensions(t : texture_depth_cube) -> vec2<u32>",
+ TextureKind::kDepth,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::kCube,
+ TextureDataType::kF32,
+ "textureDimensions",
+ [](ProgramBuilder* b) { return b->ExprList(kTextureName); },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kDimensionsDepthCubeLevel,
+ "textureDimensions(t : texture_depth_cube,\n"
+ " level : i32) -> vec2<u32>",
+ TextureKind::kDepth,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::kCube,
+ TextureDataType::kF32,
+ "textureDimensions",
+ [](ProgramBuilder* b) { return b->ExprList(kTextureName, 1_i); },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kDimensionsDepthCubeArray,
+ "textureDimensions(t : texture_depth_cube_array) -> vec2<u32>",
+ TextureKind::kDepth,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::kCubeArray,
+ TextureDataType::kF32,
+ "textureDimensions",
+ [](ProgramBuilder* b) { return b->ExprList(kTextureName); },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kDimensionsDepthCubeArrayLevel,
+ "textureDimensions(t : texture_depth_cube_array,\n"
+ " level : i32) -> vec2<u32>",
+ TextureKind::kDepth,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::kCubeArray,
+ TextureDataType::kF32,
+ "textureDimensions",
+ [](ProgramBuilder* b) { return b->ExprList(kTextureName, 1_i); },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kDimensionsDepthMultisampled2d,
+ "textureDimensions(t : texture_depth_multisampled_2d) -> vec2<u32>",
+ TextureKind::kDepthMultisampled,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2d,
+ TextureDataType::kF32,
+ "textureDimensions",
+ [](ProgramBuilder* b) { return b->ExprList(kTextureName); },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kDimensionsStorageWO1d,
+ "textureDimensions(t : texture_storage_1d<rgba32float>) -> u32",
+ tint::builtin::Access::kWrite,
+ tint::builtin::TexelFormat::kRgba32Float,
+ type::TextureDimension::k1d,
+ TextureDataType::kF32,
+ "textureDimensions",
+ [](ProgramBuilder* b) { return b->ExprList(kTextureName); },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kDimensionsStorageWO2d,
+ "textureDimensions(t : texture_storage_2d<rgba32float>) -> vec2<u32>",
+ tint::builtin::Access::kWrite,
+ tint::builtin::TexelFormat::kRgba32Float,
+ type::TextureDimension::k2d,
+ TextureDataType::kF32,
+ "textureDimensions",
+ [](ProgramBuilder* b) { return b->ExprList(kTextureName); },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kDimensionsStorageWO2dArray,
+ "textureDimensions(t : texture_storage_2d_array<rgba32float>) -> vec2<u32>",
+ tint::builtin::Access::kWrite,
+ tint::builtin::TexelFormat::kRgba32Float,
+ type::TextureDimension::k2dArray,
+ TextureDataType::kF32,
+ "textureDimensions",
+ [](ProgramBuilder* b) { return b->ExprList(kTextureName); },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kDimensionsStorageWO3d,
+ "textureDimensions(t : texture_storage_3d<rgba32float>) -> vec3<u32>",
+ tint::builtin::Access::kWrite,
+ tint::builtin::TexelFormat::kRgba32Float,
+ type::TextureDimension::k3d,
+ TextureDataType::kF32,
+ "textureDimensions",
+ [](ProgramBuilder* b) { return b->ExprList(kTextureName); },
+ /* returns value */ true,
+ },
+
+ {
+ ValidTextureOverload::kGather2dF32,
+ "textureGather(component : i32,\n"
+ " t : texture_2d<T>,\n"
+ " s : sampler,\n"
+ " coords : vec2<f32>) -> vec4<T>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2d,
+ TextureDataType::kF32,
+ "textureGather",
+ [](ProgramBuilder* b) {
+ return b->ExprList(0_i, // component
+ kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec2<f32>>(1_f, 2_f)); // coords
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kGather2dOffsetF32,
+ "textureGather(component : u32,\n"
+ " t : texture_2d<T>,\n"
+ " s : sampler,\n"
+ " coords : vec2<f32>,\n"
+ " offset : vec2<i32>) -> vec4<T>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2d,
+ TextureDataType::kF32,
+ "textureGather",
+ [](ProgramBuilder* b) {
+ return b->ExprList(0_u, // component
+ kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec2<f32>>(1_f, 2_f), // coords
+ b->Call<vec2<i32>>(3_i, 4_i)); // offset
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kGather2dArrayF32,
+ "textureGather(component : i32,\n"
+ " t : texture_2d_array<T>,\n"
+ " s : sampler,\n"
+ " coords : vec2<f32>,\n"
+ " array_index : i32) -> vec4<T>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2dArray,
+ TextureDataType::kF32,
+ "textureGather",
+ [](ProgramBuilder* b) {
+ return b->ExprList(0_i, // component
+ kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec2<f32>>(1_f, 2_f), // coords
+ 3_i); // array index
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kGather2dArrayOffsetF32,
+ "textureGather(component : u32,\n"
+ " t : texture_2d_array<T>,\n"
+ " s : sampler,\n"
+ " coords : vec2<f32>,\n"
+ " array_index : u32,\n"
+ " offset : vec2<i32>) -> vec4<T>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2dArray,
+ TextureDataType::kF32,
+ "textureGather",
+ [](ProgramBuilder* b) {
+ return b->ExprList(0_u, // component
+ kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec2<f32>>(1_f, 2_f), // coords
+ 3_u, // array_index
+ b->Call<vec2<i32>>(4_i, 5_i)); // offset
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kGatherCubeF32,
+ "textureGather(component : i32,\n"
+ " t : texture_cube<T>,\n"
+ " s : sampler,\n"
+ " coords : vec3<f32>) -> vec4<T>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::kCube,
+ TextureDataType::kF32,
+ "textureGather",
+ [](ProgramBuilder* b) {
+ return b->ExprList(0_i, // component
+ kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec3<f32>>(1_f, 2_f, 3_f)); // coords
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kGatherCubeArrayF32,
+ "textureGather(component : u32,\n"
+ " t : texture_cube_array<T>,\n"
+ " s : sampler,\n"
+ " coords : vec3<f32>,\n"
+ " array_index : u32) -> vec4<T>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::kCubeArray,
+ TextureDataType::kF32,
+ "textureGather",
+ [](ProgramBuilder* b) {
+ return b->ExprList(0_u, // component
+ kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec3<f32>>(1_f, 2_f, 3_f), // coords
+ 4_u); // array_index
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kGatherDepth2dF32,
+ "textureGather(t : texture_depth_2d,\n"
+ " s : sampler,\n"
+ " coords : vec2<f32>) -> vec4<f32>",
+ TextureKind::kDepth,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2d,
+ TextureDataType::kF32,
+ "textureGather",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec2<f32>>(1_f, 2_f)); // coords
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kGatherDepth2dOffsetF32,
+ "textureGather(t : texture_depth_2d,\n"
+ " s : sampler,\n"
+ " coords : vec2<f32>,\n"
+ " offset : vec2<i32>) -> vec4<f32>",
+ TextureKind::kDepth,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2d,
+ TextureDataType::kF32,
+ "textureGather",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec2<f32>>(1_f, 2_f), // coords
+ b->Call<vec2<i32>>(3_i, 4_i)); // offset
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kGatherDepth2dArrayF32,
+ "textureGather(t : texture_depth_2d_array,\n"
+ " s : sampler,\n"
+ " coords : vec2<f32>,\n"
+ " array_index : u32) -> vec4<f32>",
+ TextureKind::kDepth,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2dArray,
+ TextureDataType::kF32,
+ "textureGather",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec2<f32>>(1_f, 2_f), // coords
+ 3_u); // array_index
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kGatherDepth2dArrayOffsetF32,
+ "textureGather(t : texture_depth_2d_array,\n"
+ " s : sampler,\n"
+ " coords : vec2<f32>,\n"
+ " array_index : i32,\n"
+ " offset : vec2<i32>) -> vec4<f32>",
+ TextureKind::kDepth,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2dArray,
+ TextureDataType::kF32,
+ "textureGather",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec2<f32>>(1_f, 2_f), // coords
+ 3_i, // array_index
+ b->Call<vec2<i32>>(4_i, 5_i)); // offset
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kGatherDepthCubeF32,
+ "textureGather(t : texture_depth_cube,\n"
+ " s : sampler,\n"
+ " coords : vec3<f32>) -> vec4<f32>",
+ TextureKind::kDepth,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::kCube,
+ TextureDataType::kF32,
+ "textureGather",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec3<f32>>(1_f, 2_f, 3_f)); // coords
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kGatherDepthCubeArrayF32,
+ "textureGather(t : texture_depth_cube_array,\n"
+ " s : sampler,\n"
+ " coords : vec3<f32>,\n"
+ " array_index : u32) -> vec4<f32>",
+ TextureKind::kDepth,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::kCubeArray,
+ TextureDataType::kF32,
+ "textureGather",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec3<f32>>(1_f, 2_f, 3_f), // coords
+ 4_u); // array_index
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kGatherCompareDepth2dF32,
+ "textureGatherCompare(t : texture_depth_2d,\n"
+ " s : sampler_comparison,\n"
+ " coords : vec2<f32>,\n"
+ " depth_ref : f32) -> vec4<f32>",
+ TextureKind::kDepth,
+ type::SamplerKind::kComparisonSampler,
+ type::TextureDimension::k2d,
+ TextureDataType::kF32,
+ "textureGatherCompare",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec2<f32>>(1_f, 2_f), // coords
+ 3_f); // depth_ref
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kGatherCompareDepth2dOffsetF32,
+ "textureGatherCompare(t : texture_depth_2d,\n"
+ " s : sampler_comparison,\n"
+ " coords : vec2<f32>,\n"
+ " depth_ref : f32,\n"
+ " offset : vec2<i32>) -> vec4<f32>",
+ TextureKind::kDepth,
+ type::SamplerKind::kComparisonSampler,
+ type::TextureDimension::k2d,
+ TextureDataType::kF32,
+ "textureGatherCompare",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec2<f32>>(1_f, 2_f), // coords
+ 3_f, // depth_ref
+ b->Call<vec2<i32>>(4_i, 5_i)); // offset
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kGatherCompareDepth2dArrayF32,
+ "textureGatherCompare(t : texture_depth_2d_array,\n"
+ " s : sampler_comparison,\n"
+ " coords : vec2<f32>,\n"
+ " array_index : i32,\n"
+ " depth_ref : f32) -> vec4<f32>",
+ TextureKind::kDepth,
+ type::SamplerKind::kComparisonSampler,
+ type::TextureDimension::k2dArray,
+ TextureDataType::kF32,
+ "textureGatherCompare",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec2<f32>>(1_f, 2_f), // coords
+ 3_i, // array_index
+ 4_f); // depth_ref
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kGatherCompareDepth2dArrayOffsetF32,
+ "textureGatherCompare(t : texture_depth_2d_array,\n"
+ " s : sampler_comparison,\n"
+ " coords : vec2<f32>,\n"
+ " array_index : i32,\n"
+ " depth_ref : f32,\n"
+ " offset : vec2<i32>) -> vec4<f32>",
+ TextureKind::kDepth,
+ type::SamplerKind::kComparisonSampler,
+ type::TextureDimension::k2dArray,
+ TextureDataType::kF32,
+ "textureGatherCompare",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec2<f32>>(1_f, 2_f), // coords
+ 3_i, // array_index
+ 4_f, // depth_ref
+ b->Call<vec2<i32>>(5_i, 6_i)); // offset
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kGatherCompareDepthCubeF32,
+ "textureGatherCompare(t : texture_depth_cube,\n"
+ " s : sampler_comparison,\n"
+ " coords : vec3<f32>,\n"
+ " depth_ref : f32) -> vec4<f32>",
+ TextureKind::kDepth,
+ type::SamplerKind::kComparisonSampler,
+ type::TextureDimension::kCube,
+ TextureDataType::kF32,
+ "textureGatherCompare",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec3<f32>>(1_f, 2_f, 3_f), // coords
+ 4_f); // depth_ref
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kGatherCompareDepthCubeArrayF32,
+ "textureGatherCompare(t : texture_depth_cube_array,\n"
+ " s : sampler_comparison,\n"
+ " coords : vec3<f32>,\n"
+ " array_index : u32,\n"
+ " depth_ref : f32) -> vec4<f32>",
+ TextureKind::kDepth,
+ type::SamplerKind::kComparisonSampler,
+ type::TextureDimension::kCubeArray,
+ TextureDataType::kF32,
+ "textureGatherCompare",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec3<f32>>(1_f, 2_f, 3_f), // coords
+ 4_u, // array_index
+ 5_f); // depth_ref
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kNumLayers2dArray,
+ "textureNumLayers(t : texture_2d_array<f32>) -> u32",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2dArray,
+ TextureDataType::kF32,
+ "textureNumLayers",
+ [](ProgramBuilder* b) { return b->ExprList(kTextureName); },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kNumLayersCubeArray,
+ "textureNumLayers(t : texture_cube_array<f32>) -> u32",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::kCubeArray,
+ TextureDataType::kF32,
+ "textureNumLayers",
+ [](ProgramBuilder* b) { return b->ExprList(kTextureName); },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kNumLayersDepth2dArray,
+ "textureNumLayers(t : texture_depth_2d_array) -> u32",
+ TextureKind::kDepth,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2dArray,
+ TextureDataType::kF32,
+ "textureNumLayers",
+ [](ProgramBuilder* b) { return b->ExprList(kTextureName); },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kNumLayersDepthCubeArray,
+ "textureNumLayers(t : texture_depth_cube_array) -> u32",
+ TextureKind::kDepth,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::kCubeArray,
+ TextureDataType::kF32,
+ "textureNumLayers",
+ [](ProgramBuilder* b) { return b->ExprList(kTextureName); },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kNumLayersStorageWO2dArray,
+ "textureNumLayers(t : texture_storage_2d_array<rgba32float>) -> u32",
+ tint::builtin::Access::kWrite,
+ tint::builtin::TexelFormat::kRgba32Float,
+ type::TextureDimension::k2dArray,
+ TextureDataType::kF32,
+ "textureNumLayers",
+ [](ProgramBuilder* b) { return b->ExprList(kTextureName); },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kNumLevels2d,
+ "textureNumLevels(t : texture_2d<f32>) -> u32",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2d,
+ TextureDataType::kF32,
+ "textureNumLevels",
+ [](ProgramBuilder* b) { return b->ExprList(kTextureName); },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kNumLevels2dArray,
+ "textureNumLevels(t : texture_2d_array<f32>) -> u32",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2dArray,
+ TextureDataType::kF32,
+ "textureNumLevels",
+ [](ProgramBuilder* b) { return b->ExprList(kTextureName); },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kNumLevels3d,
+ "textureNumLevels(t : texture_3d<f32>) -> u32",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k3d,
+ TextureDataType::kF32,
+ "textureNumLevels",
+ [](ProgramBuilder* b) { return b->ExprList(kTextureName); },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kNumLevelsCube,
+ "textureNumLevels(t : texture_cube<f32>) -> u32",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::kCube,
+ TextureDataType::kF32,
+ "textureNumLevels",
+ [](ProgramBuilder* b) { return b->ExprList(kTextureName); },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kNumLevelsCubeArray,
+ "textureNumLevels(t : texture_cube_array<f32>) -> u32",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::kCubeArray,
+ TextureDataType::kF32,
+ "textureNumLevels",
+ [](ProgramBuilder* b) { return b->ExprList(kTextureName); },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kNumLevelsDepth2d,
+ "textureNumLevels(t : texture_depth_2d) -> u32",
+ TextureKind::kDepth,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2d,
+ TextureDataType::kF32,
+ "textureNumLevels",
+ [](ProgramBuilder* b) { return b->ExprList(kTextureName); },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kNumLevelsDepth2dArray,
+ "textureNumLevels(t : texture_depth_2d_array) -> u32",
+ TextureKind::kDepth,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2dArray,
+ TextureDataType::kF32,
+ "textureNumLevels",
+ [](ProgramBuilder* b) { return b->ExprList(kTextureName); },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kNumLevelsDepthCube,
+ "textureNumLevels(t : texture_depth_cube) -> u32",
+ TextureKind::kDepth,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::kCube,
+ TextureDataType::kF32,
+ "textureNumLevels",
+ [](ProgramBuilder* b) { return b->ExprList(kTextureName); },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kNumLevelsDepthCubeArray,
+ "textureNumLevels(t : texture_depth_cube_array) -> u32",
+ TextureKind::kDepth,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::kCubeArray,
+ TextureDataType::kF32,
+ "textureNumLevels",
+ [](ProgramBuilder* b) { return b->ExprList(kTextureName); },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kNumSamplesMultisampled2d,
+ "textureNumSamples(t : texture_multisampled_2d<f32>) -> u32",
+ TextureKind::kMultisampled,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2d,
+ TextureDataType::kF32,
+ "textureNumSamples",
+ [](ProgramBuilder* b) { return b->ExprList(kTextureName); },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kNumSamplesDepthMultisampled2d,
+ "textureNumSamples(t : texture_depth_multisampled_2d<f32>) -> u32",
+ TextureKind::kMultisampled,
+ type::SamplerKind::kComparisonSampler,
+ type::TextureDimension::k2d,
+ TextureDataType::kF32,
+ "textureNumSamples",
+ [](ProgramBuilder* b) { return b->ExprList(kTextureName); },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSample1dF32,
+ "textureSample(t : texture_1d<f32>,\n"
+ " s : sampler,\n"
+ " coords : f32) -> vec4<f32>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k1d,
+ TextureDataType::kF32,
+ "textureSample",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ 1_f); // coords
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSample2dF32,
+ "textureSample(t : texture_2d<f32>,\n"
+ " s : sampler,\n"
+ " coords : vec2<f32>) -> vec4<f32>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2d,
+ TextureDataType::kF32,
+ "textureSample",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec2<f32>>(1_f, 2_f)); // coords
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSample2dOffsetF32,
+ "textureSample(t : texture_2d<f32>,\n"
+ " s : sampler,\n"
+ " coords : vec2<f32>\n"
+ " offset : vec2<i32>) -> vec4<f32>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2d,
+ TextureDataType::kF32,
+ "textureSample",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec2<f32>>(1_f, 2_f), // coords
+ b->Call<vec2<i32>>(3_i, 4_i)); // offset
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSample2dArrayF32,
+ "textureSample(t : texture_2d_array<f32>,\n"
+ " s : sampler,\n"
+ " coords : vec2<f32>,\n"
+ " array_index : i32) -> vec4<f32>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2dArray,
+ TextureDataType::kF32,
+ "textureSample",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec2<f32>>(1_f, 2_f), // coords
+ 3_i); // array_index
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSample2dArrayOffsetF32,
+ "textureSample(t : texture_2d_array<f32>,\n"
+ " s : sampler,\n"
+ " coords : vec2<f32>,\n"
+            "              array_index : u32,\n"
+ " offset : vec2<i32>) -> vec4<f32>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2dArray,
+ TextureDataType::kF32,
+ "textureSample",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec2<f32>>(1_f, 2_f), // coords
+ 3_u, // array_index
+ b->Call<vec2<i32>>(4_i, 5_i)); // offset
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSample3dF32,
+ "textureSample(t : texture_3d<f32>,\n"
+ " s : sampler,\n"
+ " coords : vec3<f32>) -> vec4<f32>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k3d,
+ TextureDataType::kF32,
+ "textureSample",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec3<f32>>(1_f, 2_f, 3_f)); // coords
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSample3dOffsetF32,
+ "textureSample(t : texture_3d<f32>,\n"
+ " s : sampler,\n"
+            "              coords : vec3<f32>,\n"
+ " offset : vec3<i32>) -> vec4<f32>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k3d,
+ TextureDataType::kF32,
+ "textureSample",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec3<f32>>(1_f, 2_f, 3_f), // coords
+ b->Call<vec3<i32>>(4_i, 5_i, 6_i)); // offset
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleCubeF32,
+ "textureSample(t : texture_cube<f32>,\n"
+ " s : sampler,\n"
+ " coords : vec3<f32>) -> vec4<f32>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::kCube,
+ TextureDataType::kF32,
+ "textureSample",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec3<f32>>(1_f, 2_f, 3_f)); // coords
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleCubeArrayF32,
+ "textureSample(t : texture_cube_array<f32>,\n"
+ " s : sampler,\n"
+ " coords : vec3<f32>,\n"
+ " array_index : i32) -> vec4<f32>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::kCubeArray,
+ TextureDataType::kF32,
+ "textureSample",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec3<f32>>(1_f, 2_f, 3_f), // coords
+ 4_i); // array_index
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleDepth2dF32,
+ "textureSample(t : texture_depth_2d,\n"
+ " s : sampler,\n"
+ " coords : vec2<f32>) -> f32",
+ TextureKind::kDepth,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2d,
+ TextureDataType::kF32,
+ "textureSample",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec2<f32>>(1_f, 2_f)); // coords
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleDepth2dOffsetF32,
+ "textureSample(t : texture_depth_2d,\n"
+ " s : sampler,\n"
+            "              coords : vec2<f32>,\n"
+ " offset : vec2<i32>) -> f32",
+ TextureKind::kDepth,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2d,
+ TextureDataType::kF32,
+ "textureSample",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec2<f32>>(1_f, 2_f), // coords
+ b->Call<vec2<i32>>(3_i, 4_i)); // offset
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleDepth2dArrayF32,
+ "textureSample(t : texture_depth_2d_array,\n"
+ " s : sampler,\n"
+ " coords : vec2<f32>,\n"
+ " array_index : i32) -> f32",
+ TextureKind::kDepth,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2dArray,
+ TextureDataType::kF32,
+ "textureSample",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec2<f32>>(1_f, 2_f), // coords
+ 3_i); // array_index
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleDepth2dArrayOffsetF32,
+ "textureSample(t : texture_depth_2d_array,\n"
+ " s : sampler,\n"
+ " coords : vec2<f32>,\n"
+            "              array_index : i32,\n"
+ " offset : vec2<i32>) -> f32",
+ TextureKind::kDepth,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2dArray,
+ TextureDataType::kF32,
+ "textureSample",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec2<f32>>(1_f, 2_f), // coords
+ 3_i, // array_index
+ b->Call<vec2<i32>>(4_i, 5_i)); // offset
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleDepthCubeF32,
+ "textureSample(t : texture_depth_cube,\n"
+ " s : sampler,\n"
+ " coords : vec3<f32>) -> f32",
+ TextureKind::kDepth,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::kCube,
+ TextureDataType::kF32,
+ "textureSample",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec3<f32>>(1_f, 2_f, 3_f)); // coords
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleDepthCubeArrayF32,
+ "textureSample(t : texture_depth_cube_array,\n"
+ " s : sampler,\n"
+ " coords : vec3<f32>,\n"
+ " array_index : u32) -> f32",
+ TextureKind::kDepth,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::kCubeArray,
+ TextureDataType::kF32,
+ "textureSample",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec3<f32>>(1_f, 2_f, 3_f), // coords
+ 4_u); // array_index
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleBias2dF32,
+ "textureSampleBias(t : texture_2d<f32>,\n"
+ " s : sampler,\n"
+ " coords : vec2<f32>,\n"
+ " bias : f32) -> vec4<f32>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2d,
+ TextureDataType::kF32,
+ "textureSampleBias",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec2<f32>>(1_f, 2_f), // coords
+ 3_f); // bias
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleBias2dOffsetF32,
+ "textureSampleBias(t : texture_2d<f32>,\n"
+ " s : sampler,\n"
+ " coords : vec2<f32>,\n"
+ " bias : f32,\n"
+ " offset : vec2<i32>) -> vec4<f32>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2d,
+ TextureDataType::kF32,
+ "textureSampleBias",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec2<f32>>(1_f, 2_f), // coords
+ 3_f, // bias
+ b->Call<vec2<i32>>(4_i, 5_i)); // offset
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleBias2dArrayF32,
+ "textureSampleBias(t : texture_2d_array<f32>,\n"
+ " s : sampler,\n"
+ " coords : vec2<f32>,\n"
+ " array_index : u32,\n"
+ " bias : f32) -> vec4<f32>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2dArray,
+ TextureDataType::kF32,
+ "textureSampleBias",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec2<f32>>(1_f, 2_f), // coords
+ 4_u, // array_index
+ 3_f); // bias
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleBias2dArrayOffsetF32,
+ "textureSampleBias(t : texture_2d_array<f32>,\n"
+ " s : sampler,\n"
+ " coords : vec2<f32>,\n"
+ " array_index : i32,\n"
+ " bias : f32,\n"
+ " offset : vec2<i32>) -> vec4<f32>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2dArray,
+ TextureDataType::kF32,
+ "textureSampleBias",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec2<f32>>(1_f, 2_f), // coords
+ 3_i, // array_index
+ 4_f, // bias
+ b->Call<vec2<i32>>(5_i, 6_i)); // offset
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleBias3dF32,
+ "textureSampleBias(t : texture_3d<f32>,\n"
+ " s : sampler,\n"
+ " coords : vec3<f32>,\n"
+ " bias : f32) -> vec4<f32>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k3d,
+ TextureDataType::kF32,
+ "textureSampleBias",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec3<f32>>(1_f, 2_f, 3_f), // coords
+ 4_f); // bias
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleBias3dOffsetF32,
+ "textureSampleBias(t : texture_3d<f32>,\n"
+ " s : sampler,\n"
+ " coords : vec3<f32>,\n"
+ " bias : f32,\n"
+ " offset : vec3<i32>) -> vec4<f32>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k3d,
+ TextureDataType::kF32,
+ "textureSampleBias",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec3<f32>>(1_f, 2_f, 3_f), // coords
+ 4_f, // bias
+ b->Call<vec3<i32>>(5_i, 6_i, 7_i)); // offset
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleBiasCubeF32,
+ "textureSampleBias(t : texture_cube<f32>,\n"
+ " s : sampler,\n"
+ " coords : vec3<f32>,\n"
+ " bias : f32) -> vec4<f32>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::kCube,
+ TextureDataType::kF32,
+ "textureSampleBias",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec3<f32>>(1_f, 2_f, 3_f), // coords
+ 4_f); // bias
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleBiasCubeArrayF32,
+ "textureSampleBias(t : texture_cube_array<f32>,\n"
+ " s : sampler,\n"
+ " coords : vec3<f32>,\n"
+ " array_index : i32,\n"
+ " bias : f32) -> vec4<f32>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::kCubeArray,
+ TextureDataType::kF32,
+ "textureSampleBias",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec3<f32>>(1_f, 2_f, 3_f), // coords
+ 3_i, // array_index
+ 4_f); // bias
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleLevel2dF32,
+ "textureSampleLevel(t : texture_2d<f32>,\n"
+ " s : sampler,\n"
+ " coords : vec2<f32>,\n"
+ " level : f32) -> vec4<f32>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2d,
+ TextureDataType::kF32,
+ "textureSampleLevel",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec2<f32>>(1_f, 2_f), // coords
+ 3_f); // level
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleLevel2dOffsetF32,
+ "textureSampleLevel(t : texture_2d<f32>,\n"
+ " s : sampler,\n"
+ " coords : vec2<f32>,\n"
+ " level : f32,\n"
+ " offset : vec2<i32>) -> vec4<f32>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2d,
+ TextureDataType::kF32,
+ "textureSampleLevel",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec2<f32>>(1_f, 2_f), // coords
+ 3_f, // level
+ b->Call<vec2<i32>>(4_i, 5_i)); // offset
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleLevel2dArrayF32,
+ "textureSampleLevel(t : texture_2d_array<f32>,\n"
+ " s : sampler,\n"
+ " coords : vec2<f32>,\n"
+ " array_index : i32,\n"
+ " level : f32) -> vec4<f32>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2dArray,
+ TextureDataType::kF32,
+ "textureSampleLevel",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec2<f32>>(1_f, 2_f), // coords
+ 3_i, // array_index
+ 4_f); // level
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleLevel2dArrayOffsetF32,
+ "textureSampleLevel(t : texture_2d_array<f32>,\n"
+ " s : sampler,\n"
+ " coords : vec2<f32>,\n"
+ " array_index : i32,\n"
+ " level : f32,\n"
+ " offset : vec2<i32>) -> vec4<f32>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2dArray,
+ TextureDataType::kF32,
+ "textureSampleLevel",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec2<f32>>(1_f, 2_f), // coords
+ 3_i, // array_index
+ 4_f, // level
+ b->Call<vec2<i32>>(5_i, 6_i)); // offset
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleLevel3dF32,
+ "textureSampleLevel(t : texture_3d<f32>,\n"
+ " s : sampler,\n"
+ " coords : vec3<f32>,\n"
+ " level : f32) -> vec4<f32>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k3d,
+ TextureDataType::kF32,
+ "textureSampleLevel",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec3<f32>>(1_f, 2_f, 3_f), // coords
+ 4_f); // level
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleLevel3dOffsetF32,
+ "textureSampleLevel(t : texture_3d<f32>,\n"
+ " s : sampler,\n"
+ " coords : vec3<f32>,\n"
+ " level : f32,\n"
+ " offset : vec3<i32>) -> vec4<f32>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k3d,
+ TextureDataType::kF32,
+ "textureSampleLevel",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec3<f32>>(1_f, 2_f, 3_f), // coords
+ 4_f, // level
+ b->Call<vec3<i32>>(5_i, 6_i, 7_i)); // offset
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleLevelCubeF32,
+ "textureSampleLevel(t : texture_cube<f32>,\n"
+ " s : sampler,\n"
+ " coords : vec3<f32>,\n"
+ " level : f32) -> vec4<f32>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::kCube,
+ TextureDataType::kF32,
+ "textureSampleLevel",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec3<f32>>(1_f, 2_f, 3_f), // coords
+ 4_f); // level
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleLevelCubeArrayF32,
+ "textureSampleLevel(t : texture_cube_array<f32>,\n"
+ " s : sampler,\n"
+ " coords : vec3<f32>,\n"
+ " array_index : i32,\n"
+ " level : f32) -> vec4<f32>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::kCubeArray,
+ TextureDataType::kF32,
+ "textureSampleLevel",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec3<f32>>(1_f, 2_f, 3_f), // coords
+ 4_i, // array_index
+ 5_f); // level
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleLevelDepth2dF32,
+ "textureSampleLevel(t : texture_depth_2d,\n"
+ " s : sampler,\n"
+ " coords : vec2<f32>,\n"
+ " level : u32) -> f32",
+ TextureKind::kDepth,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2d,
+ TextureDataType::kF32,
+ "textureSampleLevel",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec2<f32>>(1_f, 2_f), // coords
+ 3_u); // level
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleLevelDepth2dOffsetF32,
+ "textureSampleLevel(t : texture_depth_2d,\n"
+ " s : sampler,\n"
+ " coords : vec2<f32>,\n"
+ " level : i32,\n"
+ " offset : vec2<i32>) -> f32",
+ TextureKind::kDepth,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2d,
+ TextureDataType::kF32,
+ "textureSampleLevel",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec2<f32>>(1_f, 2_f), // coords
+ 3_i, // level
+ b->Call<vec2<i32>>(4_i, 5_i)); // offset
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleLevelDepth2dArrayF32,
+ "textureSampleLevel(t : texture_depth_2d_array,\n"
+ " s : sampler,\n"
+ " coords : vec2<f32>,\n"
+ " array_index : u32,\n"
+ " level : u32) -> f32",
+ TextureKind::kDepth,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2dArray,
+ TextureDataType::kF32,
+ "textureSampleLevel",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec2<f32>>(1_f, 2_f), // coords
+ 3_u, // array_index
+ 4_u); // level
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleLevelDepth2dArrayOffsetF32,
+ "textureSampleLevel(t : texture_depth_2d_array,\n"
+ " s : sampler,\n"
+ " coords : vec2<f32>,\n"
+ " array_index : u32,\n"
+ " level : u32,\n"
+ " offset : vec2<i32>) -> f32",
+ TextureKind::kDepth,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2dArray,
+ TextureDataType::kF32,
+ "textureSampleLevel",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec2<f32>>(1_f, 2_f), // coords
+ 3_u, // array_index
+ 4_u, // level
+ b->Call<vec2<i32>>(5_i, 6_i)); // offset
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleLevelDepthCubeF32,
+ "textureSampleLevel(t : texture_depth_cube,\n"
+ " s : sampler,\n"
+ " coords : vec3<f32>,\n"
+ " level : i32) -> f32",
+ TextureKind::kDepth,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::kCube,
+ TextureDataType::kF32,
+ "textureSampleLevel",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec3<f32>>(1_f, 2_f, 3_f), // coords
+ 4_i); // level
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleLevelDepthCubeArrayF32,
+ "textureSampleLevel(t : texture_depth_cube_array,\n"
+ " s : sampler,\n"
+ " coords : vec3<f32>,\n"
+ " array_index : i32,\n"
+ " level : i32) -> f32",
+ TextureKind::kDepth,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::kCubeArray,
+ TextureDataType::kF32,
+ "textureSampleLevel",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec3<f32>>(1_f, 2_f, 3_f), // coords
+ 4_i, // array_index
+ 5_i); // level
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleGrad2dF32,
+ "textureSampleGrad(t : texture_2d<f32>,\n"
+ " s : sampler,\n"
+            "                  coords : vec2<f32>,\n"
+ " ddx : vec2<f32>,\n"
+ " ddy : vec2<f32>) -> vec4<f32>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2d,
+ TextureDataType::kF32,
+ "textureSampleGrad",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec2<f32>>(1_f, 2_f), // coords
+ b->Call<vec2<f32>>(3_f, 4_f), // ddx
+ b->Call<vec2<f32>>(5_f, 6_f)); // ddy
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleGrad2dOffsetF32,
+ "textureSampleGrad(t : texture_2d<f32>,\n"
+ " s : sampler,\n"
+ " coords : vec2<f32>,\n"
+ " ddx : vec2<f32>,\n"
+ " ddy : vec2<f32>,\n"
+ " offset : vec2<i32>) -> vec4<f32>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2d,
+ TextureDataType::kF32,
+ "textureSampleGrad",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec2<f32>>(1_f, 2_f), // coords
+ b->Call<vec2<f32>>(3_f, 4_f), // ddx
+ b->Call<vec2<f32>>(5_f, 6_f), // ddy
+ b->Call<vec2<i32>>(7_i, 7_i)); // offset
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleGrad2dArrayF32,
+ "textureSampleGrad(t : texture_2d_array<f32>,\n"
+ " s : sampler,\n"
+ " coords : vec2<f32>,\n"
+ " array_index : i32,\n"
+ " ddx : vec2<f32>,\n"
+ " ddy : vec2<f32>) -> vec4<f32>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2dArray,
+ TextureDataType::kF32,
+ "textureSampleGrad",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec2<f32>>(1_f, 2_f), // coords
+ 3_i, // array_index
+ b->Call<vec2<f32>>(4_f, 5_f), // ddx
+ b->Call<vec2<f32>>(6_f, 7_f)); // ddy
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleGrad2dArrayOffsetF32,
+ "textureSampleGrad(t : texture_2d_array<f32>,\n"
+ " s : sampler,\n"
+ " coords : vec2<f32>,\n"
+ " array_index : u32,\n"
+ " ddx : vec2<f32>,\n"
+ " ddy : vec2<f32>,\n"
+ " offset : vec2<i32>) -> vec4<f32>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k2dArray,
+ TextureDataType::kF32,
+ "textureSampleGrad",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec2<f32>>(1_f, 2_f), // coords
+ 3_u, // array_index
+ b->Call<vec2<f32>>(4_f, 5_f), // ddx
+ b->Call<vec2<f32>>(6_f, 7_f), // ddy
+ b->Call<vec2<i32>>(6_i, 7_i)); // offset
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleGrad3dF32,
+ "textureSampleGrad(t : texture_3d<f32>,\n"
+ " s : sampler,\n"
+ " coords : vec3<f32>,\n"
+ " ddx : vec3<f32>,\n"
+ " ddy : vec3<f32>) -> vec4<f32>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k3d,
+ TextureDataType::kF32,
+ "textureSampleGrad",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec3<f32>>(1_f, 2_f, 3_f), // coords
+ b->Call<vec3<f32>>(4_f, 5_f, 6_f), // ddx
+ b->Call<vec3<f32>>(7_f, 8_f, 9_f)); // ddy
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleGrad3dOffsetF32,
+ "textureSampleGrad(t : texture_3d<f32>,\n"
+ " s : sampler,\n"
+ " coords : vec3<f32>,\n"
+ " ddx : vec3<f32>,\n"
+ " ddy : vec3<f32>,\n"
+ " offset : vec3<i32>) -> vec4<f32>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::k3d,
+ TextureDataType::kF32,
+ "textureSampleGrad",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec3<f32>>(1_f, 2_f, 3_f), // coords
+ b->Call<vec3<f32>>(4_f, 5_f, 6_f), // ddx
+ b->Call<vec3<f32>>(7_f, 8_f, 9_f), // ddy
+ b->Call<vec3<i32>>(0_i, 1_i, 2_i)); // offset
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleGradCubeF32,
+ "textureSampleGrad(t : texture_cube<f32>,\n"
+ " s : sampler,\n"
+ " coords : vec3<f32>,\n"
+ " ddx : vec3<f32>,\n"
+ " ddy : vec3<f32>) -> vec4<f32>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::kCube,
+ TextureDataType::kF32,
+ "textureSampleGrad",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec3<f32>>(1_f, 2_f, 3_f), // coords
+ b->Call<vec3<f32>>(4_f, 5_f, 6_f), // ddx
+ b->Call<vec3<f32>>(7_f, 8_f, 9_f)); // ddy
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleGradCubeArrayF32,
+ "textureSampleGrad(t : texture_cube_array<f32>,\n"
+ " s : sampler,\n"
+ " coords : vec3<f32>,\n"
+ " array_index : u32,\n"
+ " ddx : vec3<f32>,\n"
+ " ddy : vec3<f32>) -> vec4<f32>",
+ TextureKind::kRegular,
+ type::SamplerKind::kSampler,
+ type::TextureDimension::kCubeArray,
+ TextureDataType::kF32,
+ "textureSampleGrad",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec3<f32>>(1_f, 2_f, 3_f), // coords
+ 4_u, // array_index
+ b->Call<vec3<f32>>(5_f, 6_f, 7_f), // ddx
+ b->Call<vec3<f32>>(8_f, 9_f, 10_f)); // ddy
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleCompareDepth2dF32,
+ "textureSampleCompare(t : texture_depth_2d,\n"
+ " s : sampler_comparison,\n"
+ " coords : vec2<f32>,\n"
+ " depth_ref : f32) -> f32",
+ TextureKind::kDepth,
+ type::SamplerKind::kComparisonSampler,
+ type::TextureDimension::k2d,
+ TextureDataType::kF32,
+ "textureSampleCompare",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec2<f32>>(1_f, 2_f), // coords
+ 3_f); // depth_ref
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleCompareDepth2dOffsetF32,
+ "textureSampleCompare(t : texture_depth_2d,\n"
+ " s : sampler_comparison,\n"
+ " coords : vec2<f32>,\n"
+ " depth_ref : f32,\n"
+ " offset : vec2<i32>) -> f32",
+ TextureKind::kDepth,
+ type::SamplerKind::kComparisonSampler,
+ type::TextureDimension::k2d,
+ TextureDataType::kF32,
+ "textureSampleCompare",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec2<f32>>(1_f, 2_f), // coords
+ 3_f, // depth_ref
+ b->Call<vec2<i32>>(4_i, 5_i)); // offset
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleCompareDepth2dArrayF32,
+ "textureSampleCompare(t : texture_depth_2d_array,\n"
+ " s : sampler_comparison,\n"
+ " coords : vec2<f32>,\n"
+ " array_index : i32,\n"
+ " depth_ref : f32) -> f32",
+ TextureKind::kDepth,
+ type::SamplerKind::kComparisonSampler,
+ type::TextureDimension::k2dArray,
+ TextureDataType::kF32,
+ "textureSampleCompare",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec2<f32>>(1_f, 2_f), // coords
+ 4_i, // array_index
+ 3_f); // depth_ref
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleCompareDepth2dArrayOffsetF32,
+ "textureSampleCompare(t : texture_depth_2d_array,\n"
+ " s : sampler_comparison,\n"
+ " coords : vec2<f32>,\n"
+ " array_index : u32,\n"
+ " depth_ref : f32,\n"
+ " offset : vec2<i32>) -> f32",
+ TextureKind::kDepth,
+ type::SamplerKind::kComparisonSampler,
+ type::TextureDimension::k2dArray,
+ TextureDataType::kF32,
+ "textureSampleCompare",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec2<f32>>(1_f, 2_f), // coords
+ 4_u, // array_index
+ 3_f, // depth_ref
+ b->Call<vec2<i32>>(5_i, 6_i)); // offset
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleCompareDepthCubeF32,
+ "textureSampleCompare(t : texture_depth_cube,\n"
+ " s : sampler_comparison,\n"
+ " coords : vec3<f32>,\n"
+ " depth_ref : f32) -> f32",
+ TextureKind::kDepth,
+ type::SamplerKind::kComparisonSampler,
+ type::TextureDimension::kCube,
+ TextureDataType::kF32,
+ "textureSampleCompare",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec3<f32>>(1_f, 2_f, 3_f), // coords
+ 4_f); // depth_ref
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleCompareDepthCubeArrayF32,
+ "textureSampleCompare(t : texture_depth_cube_array,\n"
+ " s : sampler_comparison,\n"
+ " coords : vec3<f32>,\n"
+ " array_index : i32,\n"
+ " depth_ref : f32) -> f32",
+ TextureKind::kDepth,
+ type::SamplerKind::kComparisonSampler,
+ type::TextureDimension::kCubeArray,
+ TextureDataType::kF32,
+ "textureSampleCompare",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec3<f32>>(1_f, 2_f, 3_f), // coords
+ 4_i, // array_index
+ 5_f); // depth_ref
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleCompareLevelDepth2dF32,
+ "textureSampleCompareLevel(t : texture_depth_2d,\n"
+ " s : sampler_comparison,\n"
+ " coords : vec2<f32>,\n"
+ " depth_ref : f32) -> f32",
+ TextureKind::kDepth,
+ type::SamplerKind::kComparisonSampler,
+ type::TextureDimension::k2d,
+ TextureDataType::kF32,
+ "textureSampleCompareLevel",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec2<f32>>(1_f, 2_f), // coords
+ 3_f); // depth_ref
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleCompareLevelDepth2dOffsetF32,
+ "textureSampleCompareLevel(t : texture_depth_2d,\n"
+ " s : sampler_comparison,\n"
+ " coords : vec2<f32>,\n"
+ " depth_ref : f32,\n"
+ " offset : vec2<i32>) -> f32",
+ TextureKind::kDepth,
+ type::SamplerKind::kComparisonSampler,
+ type::TextureDimension::k2d,
+ TextureDataType::kF32,
+ "textureSampleCompareLevel",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec2<f32>>(1_f, 2_f), // coords
+ 3_f, // depth_ref
+ b->Call<vec2<i32>>(4_i, 5_i)); // offset
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleCompareLevelDepth2dArrayF32,
+ "textureSampleCompareLevel(t : texture_depth_2d_array,\n"
+ " s : sampler_comparison,\n"
+ " coords : vec2<f32>,\n"
+ " array_index : i32,\n"
+ " depth_ref : f32) -> f32",
+ TextureKind::kDepth,
+ type::SamplerKind::kComparisonSampler,
+ type::TextureDimension::k2dArray,
+ TextureDataType::kF32,
+ "textureSampleCompareLevel",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec2<f32>>(1_f, 2_f), // coords
+ 3_i, // array_index
+ 4_f); // depth_ref
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleCompareLevelDepth2dArrayOffsetF32,
+ "textureSampleCompareLevel(t : texture_depth_2d_array,\n"
+ " s : sampler_comparison,\n"
+ " coords : vec2<f32>,\n"
+ " array_index : i32,\n"
+ " depth_ref : f32,\n"
+ " offset : vec2<i32>) -> f32",
+ TextureKind::kDepth,
+ type::SamplerKind::kComparisonSampler,
+ type::TextureDimension::k2dArray,
+ TextureDataType::kF32,
+ "textureSampleCompareLevel",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec2<f32>>(1_f, 2_f), // coords
+ 3_i, // array_index
+ 4_f, // depth_ref
+ b->Call<vec2<i32>>(5_i, 6_i)); // offset
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleCompareLevelDepthCubeF32,
+ "textureSampleCompareLevel(t : texture_depth_cube,\n"
+ " s : sampler_comparison,\n"
+ " coords : vec3<f32>,\n"
+ " depth_ref : f32) -> f32",
+ TextureKind::kDepth,
+ type::SamplerKind::kComparisonSampler,
+ type::TextureDimension::kCube,
+ TextureDataType::kF32,
+ "textureSampleCompareLevel",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec3<f32>>(1_f, 2_f, 3_f), // coords
+ 4_f); // depth_ref
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kSampleCompareLevelDepthCubeArrayF32,
+ "textureSampleCompareLevel(t : texture_depth_cube_array,\n"
+ " s : sampler_comparison,\n"
+ " coords : vec3<f32>,\n"
+ " array_index : i32,\n"
+ " depth_ref : f32) -> f32",
+ TextureKind::kDepth,
+ type::SamplerKind::kComparisonSampler,
+ type::TextureDimension::kCubeArray,
+ TextureDataType::kF32,
+ "textureSampleCompareLevel",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ kSamplerName, // s
+ b->Call<vec3<f32>>(1_f, 2_f, 3_f), // coords
+ 4_i, // array_index
+ 5_f); // depth_ref
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kLoad1dLevelF32,
+ "textureLoad(t : texture_1d<f32>,\n"
+ " coords : u32,\n"
+ " level : u32) -> vec4<f32>",
+ TextureKind::kRegular,
+ type::TextureDimension::k1d,
+ TextureDataType::kF32,
+ "textureLoad",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ 1_u, // coords
+ 3_u); // level
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kLoad1dLevelU32,
+ "textureLoad(t : texture_1d<u32>,\n"
+ " coords : i32,\n"
+ " level : i32) -> vec4<u32>",
+ TextureKind::kRegular,
+ type::TextureDimension::k1d,
+ TextureDataType::kU32,
+ "textureLoad",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ 1_i, // coords
+ 3_i); // level
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kLoad1dLevelI32,
+ "textureLoad(t : texture_1d<i32>,\n"
+ " coords : i32,\n"
+ " level : i32) -> vec4<i32>",
+ TextureKind::kRegular,
+ type::TextureDimension::k1d,
+ TextureDataType::kI32,
+ "textureLoad",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ 1_i, // coords
+ 3_i); // level
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kLoad2dLevelF32,
+ "textureLoad(t : texture_2d<f32>,\n"
+ " coords : vec2<u32>,\n"
+ " level : u32) -> vec4<f32>",
+ TextureKind::kRegular,
+ type::TextureDimension::k2d,
+ TextureDataType::kF32,
+ "textureLoad",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ b->Call<vec2<u32>>(1_u, 2_u), // coords
+ 3_u); // level
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kLoad2dLevelU32,
+ "textureLoad(t : texture_2d<u32>,\n"
+ " coords : vec2<i32>,\n"
+ " level : i32) -> vec4<u32>",
+ TextureKind::kRegular,
+ type::TextureDimension::k2d,
+ TextureDataType::kU32,
+ "textureLoad",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ b->Call<vec2<i32>>(1_i, 2_i), // coords
+ 3_i); // level
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kLoad2dLevelI32,
+ "textureLoad(t : texture_2d<i32>,\n"
+ " coords : vec2<u32>,\n"
+ " level : u32) -> vec4<i32>",
+ TextureKind::kRegular,
+ type::TextureDimension::k2d,
+ TextureDataType::kI32,
+ "textureLoad",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ b->Call<vec2<u32>>(1_u, 2_u), // coords
+ 3_u); // level
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kLoad2dArrayLevelF32,
+ "textureLoad(t : texture_2d_array<f32>,\n"
+ " coords : vec2<i32>,\n"
+ " array_index : i32,\n"
+ " level : i32) -> vec4<f32>",
+ TextureKind::kRegular,
+ type::TextureDimension::k2dArray,
+ TextureDataType::kF32,
+ "textureLoad",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ b->Call<vec2<i32>>(1_i, 2_i), // coords
+ 3_i, // array_index
+ 4_i); // level
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kLoad2dArrayLevelU32,
+ "textureLoad(t : texture_2d_array<u32>,\n"
+ " coords : vec2<i32>,\n"
+ " array_index : i32,\n"
+ " level : i32) -> vec4<u32>",
+ TextureKind::kRegular,
+ type::TextureDimension::k2dArray,
+ TextureDataType::kU32,
+ "textureLoad",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ b->Call<vec2<i32>>(1_i, 2_i), // coords
+ 3_i, // array_index
+ 4_i); // level
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kLoad2dArrayLevelI32,
+ "textureLoad(t : texture_2d_array<i32>,\n"
+ " coords : vec2<u32>,\n"
+ " array_index : u32,\n"
+ " level : u32) -> vec4<i32>",
+ TextureKind::kRegular,
+ type::TextureDimension::k2dArray,
+ TextureDataType::kI32,
+ "textureLoad",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ b->Call<vec2<u32>>(1_u, 2_u), // coords
+ 3_u, // array_index
+ 4_u); // level
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kLoad3dLevelF32,
+ "textureLoad(t : texture_3d<f32>,\n"
+ " coords : vec3<i32>,\n"
+ " level : i32) -> vec4<f32>",
+ TextureKind::kRegular,
+ type::TextureDimension::k3d,
+ TextureDataType::kF32,
+ "textureLoad",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ b->Call<vec3<i32>>(1_i, 2_i, 3_i), // coords
+ 4_i); // level
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kLoad3dLevelU32,
+ "textureLoad(t : texture_3d<u32>,\n"
+ " coords : vec3<i32>,\n"
+ " level : i32) -> vec4<u32>",
+ TextureKind::kRegular,
+ type::TextureDimension::k3d,
+ TextureDataType::kU32,
+ "textureLoad",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ b->Call<vec3<i32>>(1_i, 2_i, 3_i), // coords
+ 4_i); // level
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kLoad3dLevelI32,
+ "textureLoad(t : texture_3d<i32>,\n"
+ " coords : vec3<u32>,\n"
+ " level : u32) -> vec4<i32>",
+ TextureKind::kRegular,
+ type::TextureDimension::k3d,
+ TextureDataType::kI32,
+ "textureLoad",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ b->Call<vec3<u32>>(1_u, 2_u, 3_u), // coords
+ 4_u); // level
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kLoadMultisampled2dF32,
+ "textureLoad(t : texture_multisampled_2d<f32>,\n"
+ " coords : vec2<i32>,\n"
+ " sample_index : i32) -> vec4<f32>",
+ TextureKind::kMultisampled,
+ type::TextureDimension::k2d,
+ TextureDataType::kF32,
+ "textureLoad",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ b->Call<vec2<i32>>(1_i, 2_i), // coords
+ 3_i); // sample_index
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kLoadMultisampled2dU32,
+ "textureLoad(t : texture_multisampled_2d<u32>,\n"
+ " coords : vec2<i32>,\n"
+ " sample_index : i32) -> vec4<u32>",
+ TextureKind::kMultisampled,
+ type::TextureDimension::k2d,
+ TextureDataType::kU32,
+ "textureLoad",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ b->Call<vec2<i32>>(1_i, 2_i), // coords
+ 3_i); // sample_index
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kLoadMultisampled2dI32,
+ "textureLoad(t : texture_multisampled_2d<i32>,\n"
+ " coords : vec2<u32>,\n"
+ " sample_index : u32) -> vec4<i32>",
+ TextureKind::kMultisampled,
+ type::TextureDimension::k2d,
+ TextureDataType::kI32,
+ "textureLoad",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ b->Call<vec2<u32>>(1_u, 2_u), // coords
+ 3_u); // sample_index
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kLoadDepth2dLevelF32,
+ "textureLoad(t : texture_depth_2d,\n"
+ " coords : vec2<i32>,\n"
+ " level : i32) -> f32",
+ TextureKind::kDepth,
+ type::TextureDimension::k2d,
+ TextureDataType::kF32,
+ "textureLoad",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ b->Call<vec2<i32>>(1_i, 2_i), // coords
+ 3_i); // level
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kLoadDepth2dArrayLevelF32,
+ "textureLoad(t : texture_depth_2d_array,\n"
+ " coords : vec2<u32>,\n"
+ " array_index : u32,\n"
+ " level : u32) -> f32",
+ TextureKind::kDepth,
+ type::TextureDimension::k2dArray,
+ TextureDataType::kF32,
+ "textureLoad",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ b->Call<vec2<u32>>(1_u, 2_u), // coords
+ 3_u, // array_index
+ 4_u); // level
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kLoadDepthMultisampled2dF32,
+ "textureLoad(t : texture_depth_multisampled_2d,\n"
+ " coords : vec2<u32>,\n"
+ " sample_index : u32) -> f32",
+ TextureKind::kDepthMultisampled,
+ type::TextureDimension::k2d,
+ TextureDataType::kF32,
+ "textureLoad",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ b->Call<vec2<u32>>(1_u, 2_u), // coords
+ 3_u); // sample_index
+ },
+ /* returns value */ true,
+ },
+ {
+ ValidTextureOverload::kStoreWO1dRgba32float,
+ "textureStore(t : texture_storage_1d<rgba32float>,\n"
+ " coords : i32,\n"
+ " value : vec4<T>)",
+ tint::builtin::Access::kWrite,
+ tint::builtin::TexelFormat::kRgba32Float,
+ type::TextureDimension::k1d,
+ TextureDataType::kF32,
+ "textureStore",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ 1_i, // coords
+ b->Call<vec4<f32>>(2_f, 3_f, 4_f, 5_f)); // value
+ },
+ /* returns value */ false,
+ },
+ {
+ ValidTextureOverload::kStoreWO2dRgba32float,
+ "textureStore(t : texture_storage_2d<rgba32float>,\n"
+ " coords : vec2<i32>,\n"
+ " value : vec4<T>)",
+ tint::builtin::Access::kWrite,
+ tint::builtin::TexelFormat::kRgba32Float,
+ type::TextureDimension::k2d,
+ TextureDataType::kF32,
+ "textureStore",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ b->Call<vec2<i32>>(1_i, 2_i), // coords
+ b->Call<vec4<f32>>(3_f, 4_f, 5_f, 6_f)); // value
+ },
+ /* returns value */ false,
+ },
+ {
+ ValidTextureOverload::kStoreWO2dArrayRgba32float,
+ "textureStore(t : texture_storage_2d_array<rgba32float>,\n"
+ " coords : vec2<u32>,\n"
+ " array_index : u32,\n"
+ " value : vec4<T>)",
+ tint::builtin::Access::kWrite,
+ tint::builtin::TexelFormat::kRgba32Float,
+ type::TextureDimension::k2dArray,
+ TextureDataType::kF32,
+ "textureStore",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ b->Call<vec2<u32>>(1_u, 2_u), // coords
+ 3_u, // array_index
+ b->Call<vec4<f32>>(4_f, 5_f, 6_f, 7_f)); // value
+ },
+ /* returns value */ false,
+ },
+ {
+ ValidTextureOverload::kStoreWO3dRgba32float,
+ "textureStore(t : texture_storage_3d<rgba32float>,\n"
+ " coords : vec3<u32>,\n"
+ " value : vec4<T>)",
+ tint::builtin::Access::kWrite,
+ tint::builtin::TexelFormat::kRgba32Float,
+ type::TextureDimension::k3d,
+ TextureDataType::kF32,
+ "textureStore",
+ [](ProgramBuilder* b) {
+ return b->ExprList(kTextureName, // t
+ b->Call<vec3<u32>>(1_u, 2_u, 3_u), // coords
+ b->Call<vec4<f32>>(4_f, 5_f, 6_f, 7_f)); // value
+ },
+ /* returns value */ false,
+ },
+ };
+}
+
+bool ReturnsVoid(ValidTextureOverload texture_overload) {
+    // Only the four kStore* overloads are treated as returning no value;
+    // every other enumerator is reported as returning a value.
+    using Overload = ValidTextureOverload;
+    const bool is_store = texture_overload == Overload::kStoreWO1dRgba32float ||
+                          texture_overload == Overload::kStoreWO2dRgba32float ||
+                          texture_overload == Overload::kStoreWO2dArrayRgba32float ||
+                          texture_overload == Overload::kStoreWO3dRgba32float;
+
+    return is_store;
+}
+
+} // namespace tint::ast::test
diff --git a/src/tint/lang/wgsl/ast/builtin_texture_helper_test.h b/src/tint/lang/wgsl/ast/builtin_texture_helper_test.h
new file mode 100644
index 0000000..a928835
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/builtin_texture_helper_test.h
@@ -0,0 +1,273 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_BUILTIN_TEXTURE_HELPER_TEST_H_
+#define SRC_TINT_LANG_WGSL_AST_BUILTIN_TEXTURE_HELPER_TEST_H_
+
+#include <vector>
+
+#include "src/tint/builtin/access.h"
+#include "src/tint/builtin/texel_format.h"
+#include "src/tint/program_builder.h"
+#include "src/tint/type/storage_texture.h"
+#include "src/tint/type/texture_dimension.h"
+
+namespace tint::ast::test {
+
+/// The name of the texture global variable used by the tests.
+static constexpr const char* kTextureName = "Texture";
+
+/// The name of the sampler global variable used by the tests.
+static constexpr const char* kSamplerName = "Sampler";
+
+enum class TextureKind { kRegular, kDepth, kDepthMultisampled, kMultisampled, kStorage };
+enum class TextureDataType { kF32, kU32, kI32 };
+
+std::ostream& operator<<(std::ostream& out, const TextureKind& kind);
+std::ostream& operator<<(std::ostream& out, const TextureDataType& ty);
+
+/// Identifies each texture builtin overload exercised by the tests (a non-exhaustive list)
+enum class ValidTextureOverload {
+ kDimensions1d,
+ kDimensions2d,
+ kDimensions2dLevel,
+ kDimensions2dArray,
+ kDimensions2dArrayLevel,
+ kDimensions3d,
+ kDimensions3dLevel,
+ kDimensionsCube,
+ kDimensionsCubeLevel,
+ kDimensionsCubeArray,
+ kDimensionsCubeArrayLevel,
+ kDimensionsMultisampled2d,
+ kDimensionsDepth2d,
+ kDimensionsDepth2dLevel,
+ kDimensionsDepth2dArray,
+ kDimensionsDepth2dArrayLevel,
+ kDimensionsDepthCube,
+ kDimensionsDepthCubeLevel,
+ kDimensionsDepthCubeArray,
+ kDimensionsDepthCubeArrayLevel,
+ kDimensionsDepthMultisampled2d,
+ kDimensionsStorageWO1d,
+ kDimensionsStorageWO2d,
+ kDimensionsStorageWO2dArray,
+ kDimensionsStorageWO3d,
+ kGather2dF32,
+ kGather2dOffsetF32,
+ kGather2dArrayF32,
+ kGather2dArrayOffsetF32,
+ kGatherCubeF32,
+ kGatherCubeArrayF32,
+ kGatherDepth2dF32,
+ kGatherDepth2dOffsetF32,
+ kGatherDepth2dArrayF32,
+ kGatherDepth2dArrayOffsetF32,
+ kGatherDepthCubeF32,
+ kGatherDepthCubeArrayF32,
+ kGatherCompareDepth2dF32,
+ kGatherCompareDepth2dOffsetF32,
+ kGatherCompareDepth2dArrayF32,
+ kGatherCompareDepth2dArrayOffsetF32,
+ kGatherCompareDepthCubeF32,
+ kGatherCompareDepthCubeArrayF32,
+ kNumLayers2dArray,
+ kNumLayersCubeArray,
+ kNumLayersDepth2dArray,
+ kNumLayersDepthCubeArray,
+ kNumLayersStorageWO2dArray,
+ kNumLevels2d,
+ kNumLevels2dArray,
+ kNumLevels3d,
+ kNumLevelsCube,
+ kNumLevelsCubeArray,
+ kNumLevelsDepth2d,
+ kNumLevelsDepth2dArray,
+ kNumLevelsDepthCube,
+ kNumLevelsDepthCubeArray,
+ kNumSamplesMultisampled2d,
+ kNumSamplesDepthMultisampled2d,
+ kSample1dF32,
+ kSample2dF32,
+ kSample2dOffsetF32,
+ kSample2dArrayF32,
+ kSample2dArrayOffsetF32,
+ kSample3dF32,
+ kSample3dOffsetF32,
+ kSampleCubeF32,
+ kSampleCubeArrayF32,
+ kSampleDepth2dF32,
+ kSampleDepth2dOffsetF32,
+ kSampleDepth2dArrayF32,
+ kSampleDepth2dArrayOffsetF32,
+ kSampleDepthCubeF32,
+ kSampleDepthCubeArrayF32,
+ kSampleBias2dF32,
+ kSampleBias2dOffsetF32,
+ kSampleBias2dArrayF32,
+ kSampleBias2dArrayOffsetF32,
+ kSampleBias3dF32,
+ kSampleBias3dOffsetF32,
+ kSampleBiasCubeF32,
+ kSampleBiasCubeArrayF32,
+ kSampleLevel2dF32,
+ kSampleLevel2dOffsetF32,
+ kSampleLevel2dArrayF32,
+ kSampleLevel2dArrayOffsetF32,
+ kSampleLevel3dF32,
+ kSampleLevel3dOffsetF32,
+ kSampleLevelCubeF32,
+ kSampleLevelCubeArrayF32,
+ kSampleLevelDepth2dF32,
+ kSampleLevelDepth2dOffsetF32,
+ kSampleLevelDepth2dArrayF32,
+ kSampleLevelDepth2dArrayOffsetF32,
+ kSampleLevelDepthCubeF32,
+ kSampleLevelDepthCubeArrayF32,
+ kSampleGrad2dF32,
+ kSampleGrad2dOffsetF32,
+ kSampleGrad2dArrayF32,
+ kSampleGrad2dArrayOffsetF32,
+ kSampleGrad3dF32,
+ kSampleGrad3dOffsetF32,
+ kSampleGradCubeF32,
+ kSampleGradCubeArrayF32,
+ kSampleCompareDepth2dF32,
+ kSampleCompareDepth2dOffsetF32,
+ kSampleCompareDepth2dArrayF32,
+ kSampleCompareDepth2dArrayOffsetF32,
+ kSampleCompareDepthCubeF32,
+ kSampleCompareDepthCubeArrayF32,
+ kSampleCompareLevelDepth2dF32,
+ kSampleCompareLevelDepth2dOffsetF32,
+ kSampleCompareLevelDepth2dArrayF32,
+ kSampleCompareLevelDepth2dArrayOffsetF32,
+ kSampleCompareLevelDepthCubeF32,
+ kSampleCompareLevelDepthCubeArrayF32,
+ kLoad1dLevelF32,
+ kLoad1dLevelU32,
+ kLoad1dLevelI32,
+ kLoad2dLevelF32,
+ kLoad2dLevelU32,
+ kLoad2dLevelI32,
+ kLoad2dArrayLevelF32,
+ kLoad2dArrayLevelU32,
+ kLoad2dArrayLevelI32,
+ kLoad3dLevelF32,
+ kLoad3dLevelU32,
+ kLoad3dLevelI32,
+ kLoadMultisampled2dF32,
+ kLoadMultisampled2dU32,
+ kLoadMultisampled2dI32,
+ kLoadDepth2dLevelF32,
+ kLoadDepth2dArrayLevelF32,
+ kLoadDepthMultisampled2dF32,
+ kStoreWO1dRgba32float, // Not permuted for all texel formats
+ kStoreWO2dRgba32float, // Not permuted for all texel formats
+ kStoreWO2dArrayRgba32float, // Not permuted for all texel formats
+ kStoreWO3dRgba32float, // Not permuted for all texel formats
+};
+
+/// @param texture_overload the texture builtin overload to query
+/// @returns true if the builtin overload returns no value (the kStore* overloads), else false.
+bool ReturnsVoid(ValidTextureOverload texture_overload);
+
+/// Describes a single texture builtin overload, and how to build a call to it for a test
+struct TextureOverloadCase {
+ /// Args is a list of Expression used as arguments to the texture overload case.
+ using Args = utils::Vector<const Expression*, 8>;
+
+ /// Constructor for overloads that also take a sampler, e.g. textureSample...() functions
+ TextureOverloadCase(ValidTextureOverload,
+ const char*,
+ TextureKind,
+ type::SamplerKind,
+ type::TextureDimension,
+ TextureDataType,
+ const char*,
+ std::function<Args(ProgramBuilder*)>,
+ bool /* returns_value */);
+ /// Constructor for overloads on non-storage textures without a sampler, e.g. textureLoad()
+ TextureOverloadCase(ValidTextureOverload,
+ const char*,
+ TextureKind,
+ type::TextureDimension,
+ TextureDataType,
+ const char*,
+ std::function<Args(ProgramBuilder*)>,
+ bool /* returns_value */);
+ /// Constructor for overloads on storage textures, e.g. textureStore()
+ TextureOverloadCase(ValidTextureOverload,
+ const char*,
+ tint::builtin::Access,
+ tint::builtin::TexelFormat,
+ type::TextureDimension,
+ TextureDataType,
+ const char*,
+ std::function<Args(ProgramBuilder*)>,
+ bool /* returns_value */);
+ /// Copy constructor
+ TextureOverloadCase(const TextureOverloadCase&);
+
+ /// Destructor
+ ~TextureOverloadCase();
+
+ /// @return a vector containing a large number (non-exhaustive) of valid
+ /// texture overloads.
+ static std::vector<TextureOverloadCase> ValidCases();
+
+ /// @param builder the AST builder used for the test
+ /// @returns the vector component type of the texture function return value
+ Type BuildResultVectorComponentType(ProgramBuilder* builder) const;
+ /// @param builder the AST builder used for the test
+ /// @returns a variable holding the test texture, automatically registered as
+ /// a global variable.
+ const Variable* BuildTextureVariable(ProgramBuilder* builder) const;
+ /// @param builder the AST builder used for the test
+ /// @returns a Variable holding the test sampler, automatically registered as
+ /// a global variable.
+ const Variable* BuildSamplerVariable(ProgramBuilder* builder) const;
+
+ /// The enumerator for this overload
+ const ValidTextureOverload overload;
+ /// A human readable description of the overload (its WGSL-style signature)
+ const char* const description;
+ /// The texture kind for the texture parameter
+ const TextureKind texture_kind;
+ /// The sampler kind for the sampler parameter
+ /// Used only when texture_kind is not kStorage
+ type::SamplerKind const sampler_kind = type::SamplerKind::kSampler;
+ /// The access control for the storage texture
+ /// Used only when texture_kind is kStorage
+ tint::builtin::Access const access = tint::builtin::Access::kReadWrite;
+ /// The texel format for the storage texture
+ /// Used only when texture_kind is kStorage
+ tint::builtin::TexelFormat const texel_format = tint::builtin::TexelFormat::kUndefined;
+ /// The dimensions of the texture parameter
+ type::TextureDimension const texture_dimension;
+ /// The data type of the texture parameter
+ const TextureDataType texture_data_type;
+ /// Name of the function. e.g. `textureSample`, `textureSampleGrad`, etc
+ const char* const function;
+ /// A function that builds the AST arguments for the overload
+ std::function<Args(ProgramBuilder*)> const args;
+ /// True if the function returns a value
+ const bool returns_value;
+};
+
+std::ostream& operator<<(std::ostream& out, const TextureOverloadCase& data);
+
+} // namespace tint::ast::test
+
+#endif // SRC_TINT_LANG_WGSL_AST_BUILTIN_TEXTURE_HELPER_TEST_H_
diff --git a/src/tint/lang/wgsl/ast/call_expression.cc b/src/tint/lang/wgsl/ast/call_expression.cc
new file mode 100644
index 0000000..5f2663a
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/call_expression.cc
@@ -0,0 +1,49 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/call_expression.h"
+
+#include <utility>
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::CallExpression);
+
+namespace tint::ast {
+
+CallExpression::CallExpression(ProgramID pid, NodeID nid, const Source& src,
+                               const IdentifierExpression* t,
+                               utils::VectorRef<const Expression*> a)
+    : Base(pid, nid, src), target(t), args(std::move(a)) {
+    // The call target is mandatory and is checked against this node's program.
+    TINT_ASSERT(AST, target);
+    TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, target, program_id);
+    // Every argument gets the same two checks, in argument order.
+    for (const Expression* argument : args) {
+        TINT_ASSERT(AST, argument);
+        TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, argument, program_id);
+    }
+}
+
+CallExpression::~CallExpression() = default;
+
+const CallExpression* CallExpression::Clone(CloneContext* ctx) const {
+    // Clone the children before calling create() so cloning happens in a fixed order
+    auto cloned_source = ctx->Clone(source);
+    auto cloned_args = ctx->Clone(args);
+    auto cloned_target = ctx->Clone(target);
+    return ctx->dst->create<CallExpression>(cloned_source, cloned_target, std::move(cloned_args));
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/call_expression.h b/src/tint/lang/wgsl/ast/call_expression.h
new file mode 100644
index 0000000..ee76612
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/call_expression.h
@@ -0,0 +1,64 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_CALL_EXPRESSION_H_
+#define SRC_TINT_LANG_WGSL_AST_CALL_EXPRESSION_H_
+
+#include "src/tint/lang/wgsl/ast/expression.h"
+
+// Forward declarations
+namespace tint::ast {
+class IdentifierExpression;
+} // namespace tint::ast
+
+namespace tint::ast {
+
+/// A call expression - represents a call to either a:
+/// * sem::Function
+/// * sem::Builtin
+/// * sem::ValueConstructor
+/// * sem::ValueConversion
+class CallExpression final : public utils::Castable<CallExpression, Expression> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param source the call expression source
+ /// @param target the identifier expression naming the function or type being called
+ /// @param args the call arguments, in order
+ CallExpression(ProgramID pid,
+ NodeID nid,
+ const Source& source,
+ const IdentifierExpression* target,
+ utils::VectorRef<const Expression*> args);
+
+ /// Destructor
+ ~CallExpression() override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const CallExpression* Clone(CloneContext* ctx) const override;
+
+ /// The target function or type of the call
+ const IdentifierExpression* target;
+
+ /// The arguments passed to the call, in order
+ const utils::Vector<const Expression*, 8> args;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_CALL_EXPRESSION_H_
diff --git a/src/tint/lang/wgsl/ast/call_expression_test.cc b/src/tint/lang/wgsl/ast/call_expression_test.cc
new file mode 100644
index 0000000..c9117ee
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/call_expression_test.cc
@@ -0,0 +1,134 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gtest/gtest-spi.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+namespace tint::ast {
+namespace {
+
+using CallExpressionTest = TestHelper;
+
+TEST_F(CallExpressionTest, CreationIdentifier) {
+    // Build `func(param1, param2)` with an identifier as the call target.
+    auto* target = Expr("func");
+    utils::Vector arguments{
+        Expr("param1"),
+        Expr("param2"),
+    };
+    auto* call = Call(target, arguments);
+    EXPECT_EQ(call->target, target);
+
+    const auto& stored = call->args;
+    ASSERT_EQ(stored.Length(), 2u);
+    EXPECT_EQ(stored[0], arguments[0]);
+    EXPECT_EQ(stored[1], arguments[1]);
+}
+
+TEST_F(CallExpressionTest, CreationIdentifier_WithSource) {
+    auto* target = Expr("func");
+    auto* call = Call(Source{{20, 2}}, target);
+    EXPECT_EQ(call->target, target);
+    // The source location given at creation must be readable back from the node.
+    const auto& begin = call->source.range.begin;
+    EXPECT_EQ(begin.line, 20u);
+    EXPECT_EQ(begin.column, 2u);
+}
+
+TEST_F(CallExpressionTest, CreationType) {
+    // Same as CreationIdentifier, but the call target is the f32 type.
+    auto* target = Expr(ty.f32());
+    utils::Vector arguments{
+        Expr("param1"),
+        Expr("param2"),
+    };
+    auto* call = Call(target, arguments);
+    EXPECT_EQ(call->target, target);
+
+    const auto& stored = call->args;
+    ASSERT_EQ(stored.Length(), 2u);
+    EXPECT_EQ(stored[0], arguments[0]);
+    EXPECT_EQ(stored[1], arguments[1]);
+}
+
+TEST_F(CallExpressionTest, CreationType_WithSource) {
+    auto* target = Expr(ty.f32());
+    auto* call = Call(Source{{20, 2}}, target);
+    EXPECT_EQ(call->target, target);
+    // The source location given at creation must be readable back from the node.
+    const auto& begin = call->source.range.begin;
+    EXPECT_EQ(begin.line, 20u);
+    EXPECT_EQ(begin.column, 2u);
+}
+
+TEST_F(CallExpressionTest, IsCall) {
+    // A call built without arguments is identified as a CallExpression.
+    auto* call = Call(Expr("func"));
+    EXPECT_TRUE(call->Is<CallExpression>());
+}
+
+TEST_F(CallExpressionTest, Assert_Null_Identifier) {
+    // Building a call from a null identifier is expected to fail fatally.
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder builder;
+            builder.Call(static_cast<Identifier*>(nullptr));
+        }, "internal compiler error");
+}
+
+TEST_F(CallExpressionTest, Assert_Null_Param) {
+    // A null entry inside the argument list is expected to fail fatally.
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder builder;
+            builder.Call(builder.Ident("func"), utils::Vector{
+                                                    builder.Expr("param1"),
+                                                    nullptr,
+                                                    builder.Expr("param2"),
+                                                });
+        }, "internal compiler error");
+}
+
+TEST_F(CallExpressionTest, Assert_DifferentProgramID_Identifier) {
+    // The target identifier comes from a different ProgramBuilder than the call.
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder owner;
+            ProgramBuilder other;
+            owner.Call(other.Ident("func"));
+        }, "internal compiler error");
+}
+
+TEST_F(CallExpressionTest, Assert_DifferentProgramID_Type) {
+    // The target type comes from a different ProgramBuilder than the call.
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder owner;
+            ProgramBuilder other;
+            owner.Call(other.ty.f32());
+        }, "internal compiler error");
+}
+
+TEST_F(CallExpressionTest, Assert_DifferentProgramID_Param) {
+    // The argument comes from a different ProgramBuilder than the call and its target.
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder owner;
+            ProgramBuilder other;
+            owner.Call(owner.Ident("func"), other.Expr("param1"));
+        }, "internal compiler error");
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/call_statement.cc b/src/tint/lang/wgsl/ast/call_statement.cc
new file mode 100644
index 0000000..4ae4d64
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/call_statement.cc
@@ -0,0 +1,41 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/call_statement.h"
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::CallStatement);
+
+namespace tint::ast {
+
+CallStatement::CallStatement(ProgramID pid, NodeID nid, const Source& src,
+                             const CallExpression* call)
+    : Base(pid, nid, src), expr(call) {
+    // The wrapped call expression is mandatory and is checked against this
+    // node's program.
+    TINT_ASSERT(AST, expr);
+    TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, expr, program_id);
+}
+
+CallStatement::~CallStatement() = default;
+
+const CallStatement* CallStatement::Clone(CloneContext* ctx) const {
+    // Clone the children before calling create() so cloning happens in a fixed order
+    auto cloned_source = ctx->Clone(source);
+    auto* cloned_call = ctx->Clone(expr);
+    return ctx->dst->create<CallStatement>(cloned_source, cloned_call);
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/call_statement.h b/src/tint/lang/wgsl/ast/call_statement.h
new file mode 100644
index 0000000..fd0c3e7
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/call_statement.h
@@ -0,0 +1,48 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_CALL_STATEMENT_H_
+#define SRC_TINT_LANG_WGSL_AST_CALL_STATEMENT_H_
+
+#include "src/tint/lang/wgsl/ast/call_expression.h"
+#include "src/tint/lang/wgsl/ast/statement.h"
+
+namespace tint::ast {
+
+/// A call statement - a statement that wraps a single call expression
+class CallStatement final : public utils::Castable<CallStatement, Statement> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node for the statement
+ /// @param call the call expression wrapped by this statement, must not be null
+ CallStatement(ProgramID pid, NodeID nid, const Source& src, const CallExpression* call);
+
+ /// Destructor
+ ~CallStatement() override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const CallStatement* Clone(CloneContext* ctx) const override;
+
+ /// The call expression wrapped by this statement
+ const CallExpression* const expr;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_CALL_STATEMENT_H_
diff --git a/src/tint/lang/wgsl/ast/call_statement_test.cc b/src/tint/lang/wgsl/ast/call_statement_test.cc
new file mode 100644
index 0000000..bcf3002
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/call_statement_test.cc
@@ -0,0 +1,57 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/call_statement.h"
+
+#include "gtest/gtest-spi.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+namespace tint::ast {
+namespace {
+
+using CallStatementTest = TestHelper;
+
+TEST_F(CallStatementTest, Creation) {
+    // The statement must hold the exact call expression it was built from.
+    auto* call = Call("func");
+    auto* stmt = CallStmt(call);
+    EXPECT_EQ(stmt->expr, call);
+}
+
+TEST_F(CallStatementTest, IsCall) {
+    auto* stmt = CallStmt(Call("f"));  // wraps the call `f()`
+    EXPECT_TRUE(stmt->Is<CallStatement>());
+}
+
+TEST_F(CallStatementTest, Assert_Null_Call) {
+    // Building a call statement from a null call expression is expected to fail fatally.
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder builder;
+            builder.CallStmt(nullptr);
+        }, "internal compiler error");
+}
+
+TEST_F(CallStatementTest, Assert_DifferentProgramID_Call) {
+    // The call expression comes from a different ProgramBuilder than the statement.
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder owner;
+            ProgramBuilder other;
+            owner.CallStmt(other.Call("func"));
+        }, "internal compiler error");
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/case_selector.cc b/src/tint/lang/wgsl/ast/case_selector.cc
new file mode 100644
index 0000000..2ba540e
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/case_selector.cc
@@ -0,0 +1,37 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/case_selector.h"
+
+#include <utility>
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::CaseSelector);
+
+namespace tint::ast {
+
+CaseSelector::CaseSelector(ProgramID pid, NodeID nid, const Source& src, const Expression* e)
+ : Base(pid, nid, src), expr(e) {}  // `e` may be nullptr: that denotes a `default` selector
+
+CaseSelector::~CaseSelector() = default;
+
+const CaseSelector* CaseSelector::Clone(CloneContext* ctx) const {
+    // Clone the children before calling create() so cloning happens in a fixed order
+    auto cloned_source = ctx->Clone(source);
+    auto cloned_expr = ctx->Clone(expr);
+    return ctx->dst->create<CaseSelector>(cloned_source, cloned_expr);
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/case_selector.h b/src/tint/lang/wgsl/ast/case_selector.h
new file mode 100644
index 0000000..b91dce9
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/case_selector.h
@@ -0,0 +1,52 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_CASE_SELECTOR_H_
+#define SRC_TINT_LANG_WGSL_AST_CASE_SELECTOR_H_
+
+#include <vector>
+
+#include "src/tint/lang/wgsl/ast/block_statement.h"
+#include "src/tint/lang/wgsl/ast/expression.h"
+
+namespace tint::ast {
+
+/// A single selector of a case statement: either an expression or the `default` selector
+class CaseSelector final : public utils::Castable<CaseSelector, Node> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ /// @param expr the selector expression, |nullptr| for a `default` selector
+ CaseSelector(ProgramID pid, NodeID nid, const Source& src, const Expression* expr = nullptr);
+
+ /// Destructor
+ ~CaseSelector() override;
+
+ /// @returns true if this is a default selector, i.e. #expr is nullptr
+ bool IsDefault() const { return expr == nullptr; }
+
+ /// Clones this node and all transitive child nodes using the `CloneContext` `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const CaseSelector* Clone(CloneContext* ctx) const override;
+
+ /// The selector, nullptr for a default selector
+ const Expression* const expr = nullptr;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_CASE_SELECTOR_H_
diff --git a/src/tint/lang/wgsl/ast/case_selector_test.cc b/src/tint/lang/wgsl/ast/case_selector_test.cc
new file mode 100644
index 0000000..9b5e36e
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/case_selector_test.cc
@@ -0,0 +1,40 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/case_selector.h"
+
+#include "gtest/gtest-spi.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+using namespace tint::number_suffixes; // NOLINT
+
+namespace tint::ast {
+namespace {
+
+using CaseSelectorTest = TestHelper;
+
+TEST_F(CaseSelectorTest, NonDefault) {
+ auto* e = Expr(2_i);
+ auto* c = CaseSelector(e);
+ EXPECT_FALSE(c->IsDefault());
+ EXPECT_EQ(e, c->expr);
+}
+
+TEST_F(CaseSelectorTest, Default) {
+ auto* c = DefaultCaseSelector();
+ EXPECT_TRUE(c->IsDefault());
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/case_statement.cc b/src/tint/lang/wgsl/ast/case_statement.cc
new file mode 100644
index 0000000..16e66b2
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/case_statement.cc
@@ -0,0 +1,59 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/case_statement.h"
+
+#include <utility>
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::CaseStatement);
+
+namespace tint::ast {
+
+CaseStatement::CaseStatement(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ utils::VectorRef<const CaseSelector*> s,
+ const BlockStatement* b)
+ : Base(pid, nid, src), selectors(std::move(s)), body(b) {
+ TINT_ASSERT(AST, body);
+ TINT_ASSERT(AST, !selectors.IsEmpty());
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, body, program_id);
+ for (auto* selector : selectors) {
+ TINT_ASSERT(AST, selector);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, selector, program_id);
+ }
+}
+
+CaseStatement::~CaseStatement() = default;
+
+bool CaseStatement::ContainsDefault() const {
+ for (const auto* sel : selectors) {
+ if (sel->IsDefault()) {
+ return true;
+ }
+ }
+ return false;
+}
+
+const CaseStatement* CaseStatement::Clone(CloneContext* ctx) const {
+ // Clone arguments outside of create() call to have deterministic ordering
+ auto src = ctx->Clone(source);
+ auto sel = ctx->Clone(selectors);
+ auto* b = ctx->Clone(body);
+ return ctx->dst->create<CaseStatement>(src, std::move(sel), b);
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/case_statement.h b/src/tint/lang/wgsl/ast/case_statement.h
new file mode 100644
index 0000000..026e14b
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/case_statement.h
@@ -0,0 +1,61 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_CASE_STATEMENT_H_
+#define SRC_TINT_LANG_WGSL_AST_CASE_STATEMENT_H_
+
+#include <vector>
+
+#include "src/tint/lang/wgsl/ast/block_statement.h"
+#include "src/tint/lang/wgsl/ast/case_selector.h"
+
+namespace tint::ast {
+
+/// A case statement
+class CaseStatement final : public utils::Castable<CaseStatement, Statement> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ /// @param selectors the case selectors. Must not be empty, and no selector may be nullptr.
+ /// @param body the case body. Must not be nullptr.
+ CaseStatement(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ utils::VectorRef<const CaseSelector*> selectors,
+ const BlockStatement* body);
+
+ /// Destructor
+ ~CaseStatement() override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const CaseStatement* Clone(CloneContext* ctx) const override;
+
+ /// @returns true if any of the selectors is a default selector
+ bool ContainsDefault() const;
+
+ /// The case selectors. Never empty: the constructor asserts at least one selector.
+ const utils::Vector<const CaseSelector*, 4> selectors;
+
+ /// The case body
+ const BlockStatement* const body;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_CASE_STATEMENT_H_
diff --git a/src/tint/lang/wgsl/ast/case_statement_test.cc b/src/tint/lang/wgsl/ast/case_statement_test.cc
new file mode 100644
index 0000000..4ef5471
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/case_statement_test.cc
@@ -0,0 +1,131 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/case_statement.h"
+
+#include "gtest/gtest-spi.h"
+#include "src/tint/lang/wgsl/ast/discard_statement.h"
+#include "src/tint/lang/wgsl/ast/if_statement.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+using namespace tint::number_suffixes; // NOLINT
+
+namespace tint::ast {
+namespace {
+
+using CaseStatementTest = TestHelper;
+
+TEST_F(CaseStatementTest, Creation_i32) {
+ auto* selector = CaseSelector(2_i);
+ utils::Vector b{selector};
+
+ auto* discard = create<DiscardStatement>();
+ auto* body = create<BlockStatement>(utils::Vector{discard}, utils::Empty);
+
+ auto* c = create<CaseStatement>(b, body);
+ ASSERT_EQ(c->selectors.Length(), 1u);
+ EXPECT_EQ(c->selectors[0], selector);
+ ASSERT_EQ(c->body->statements.Length(), 1u);
+ EXPECT_EQ(c->body->statements[0], discard);
+}
+
+TEST_F(CaseStatementTest, Creation_u32) {
+ auto* selector = CaseSelector(2_u);
+ utils::Vector b{selector};
+
+ auto* discard = create<DiscardStatement>();
+ auto* body = create<BlockStatement>(utils::Vector{discard}, utils::Empty);
+
+ auto* c = create<CaseStatement>(b, body);
+ ASSERT_EQ(c->selectors.Length(), 1u);
+ EXPECT_EQ(c->selectors[0], selector);
+ ASSERT_EQ(c->body->statements.Length(), 1u);
+ EXPECT_EQ(c->body->statements[0], discard);
+}
+
+TEST_F(CaseStatementTest, ContainsDefault_WithDefault) {
+ utils::Vector b{CaseSelector(2_u), DefaultCaseSelector()};
+ auto* c = create<CaseStatement>(b, create<BlockStatement>(utils::Empty, utils::Empty));
+ EXPECT_TRUE(c->ContainsDefault());
+}
+
+TEST_F(CaseStatementTest, ContainsDefault_WithOutDefault) {
+ utils::Vector b{CaseSelector(2_u), CaseSelector(3_u)};
+ auto* c = create<CaseStatement>(b, create<BlockStatement>(utils::Empty, utils::Empty));
+ EXPECT_FALSE(c->ContainsDefault());
+}
+
+TEST_F(CaseStatementTest, Creation_WithSource) {
+ utils::Vector b{CaseSelector(2_i)};
+
+ auto* body = create<BlockStatement>(
+ utils::Vector{
+ create<DiscardStatement>(),
+ },
+ utils::Empty);
+ auto* c = create<CaseStatement>(Source{Source::Location{20, 2}}, b, body);
+ auto src = c->source;
+ EXPECT_EQ(src.range.begin.line, 20u);
+ EXPECT_EQ(src.range.begin.column, 2u);
+}
+
+TEST_F(CaseStatementTest, IsCase) {
+ auto* c = create<CaseStatement>(utils::Vector{DefaultCaseSelector()},
+ create<BlockStatement>(utils::Empty, utils::Empty));
+ EXPECT_TRUE(c->Is<CaseStatement>());
+}
+
+TEST_F(CaseStatementTest, Assert_Null_Body) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b;
+ b.create<CaseStatement>(utils::Vector{b.DefaultCaseSelector()}, nullptr);
+ },
+ "internal compiler error");
+}
+
+TEST_F(CaseStatementTest, Assert_Null_Selector) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b;
+ b.create<CaseStatement>(utils::Vector<const ast::CaseSelector*, 1>{nullptr},
+ b.create<BlockStatement>(utils::Empty, utils::Empty));
+ },
+ "internal compiler error");
+}
+
+TEST_F(CaseStatementTest, Assert_DifferentProgramID_Body) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;  // Owns the case body, so its program ID differs from b1's.
+ b1.create<CaseStatement>(utils::Vector{b1.DefaultCaseSelector()},
+ b2.create<BlockStatement>(utils::Empty, utils::Empty));
+ },
+ "internal compiler error");
+}
+
+TEST_F(CaseStatementTest, Assert_DifferentProgramID_Selector) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.create<CaseStatement>(utils::Vector{b2.CaseSelector(b2.Expr(2_i))},
+ b1.create<BlockStatement>(utils::Empty, utils::Empty));
+ },
+ "internal compiler error");
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/compound_assignment_statement.cc b/src/tint/lang/wgsl/ast/compound_assignment_statement.cc
new file mode 100644
index 0000000..c516a5b
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/compound_assignment_statement.cc
@@ -0,0 +1,46 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/compound_assignment_statement.h"
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::CompoundAssignmentStatement);
+
+namespace tint::ast {
+
+CompoundAssignmentStatement::CompoundAssignmentStatement(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const Expression* l,
+ const Expression* r,
+ BinaryOp o)
+ : Base(pid, nid, src), lhs(l), rhs(r), op(o) {
+ TINT_ASSERT(AST, lhs);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, lhs, program_id);
+ TINT_ASSERT(AST, rhs);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, rhs, program_id);
+}
+
+CompoundAssignmentStatement::~CompoundAssignmentStatement() = default;
+
+const CompoundAssignmentStatement* CompoundAssignmentStatement::Clone(CloneContext* ctx) const {
+ // Clone arguments outside of create() call to have deterministic ordering
+ auto src = ctx->Clone(source);
+ auto* l = ctx->Clone(lhs);
+ auto* r = ctx->Clone(rhs);
+ return ctx->dst->create<CompoundAssignmentStatement>(src, l, r, op);
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/compound_assignment_statement.h b/src/tint/lang/wgsl/ast/compound_assignment_statement.h
new file mode 100644
index 0000000..c59e5e7
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/compound_assignment_statement.h
@@ -0,0 +1,63 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_COMPOUND_ASSIGNMENT_STATEMENT_H_
+#define SRC_TINT_LANG_WGSL_AST_COMPOUND_ASSIGNMENT_STATEMENT_H_
+
+#include "src/tint/lang/wgsl/ast/binary_expression.h"
+#include "src/tint/lang/wgsl/ast/expression.h"
+#include "src/tint/lang/wgsl/ast/statement.h"
+
+namespace tint::ast {
+
+/// A compound assignment statement
+class CompoundAssignmentStatement final
+ : public utils::Castable<CompoundAssignmentStatement, Statement> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param source the compound assignment statement source
+ /// @param lhs the left side of the expression
+ /// @param rhs the right side of the expression
+ /// @param op the binary operator
+ CompoundAssignmentStatement(ProgramID pid,
+ NodeID nid,
+ const Source& source,
+ const Expression* lhs,
+ const Expression* rhs,
+ BinaryOp op);
+
+ /// Destructor
+ ~CompoundAssignmentStatement() override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const CompoundAssignmentStatement* Clone(CloneContext* ctx) const override;
+
+ /// left side expression
+ const Expression* const lhs;
+
+ /// right side expression
+ const Expression* const rhs;
+
+ /// the binary operator
+ const BinaryOp op;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_COMPOUND_ASSIGNMENT_STATEMENT_H_
diff --git a/src/tint/lang/wgsl/ast/compound_assignment_statement_test.cc b/src/tint/lang/wgsl/ast/compound_assignment_statement_test.cc
new file mode 100644
index 0000000..2595c6a
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/compound_assignment_statement_test.cc
@@ -0,0 +1,97 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/compound_assignment_statement.h"
+
+#include "gtest/gtest-spi.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+using namespace tint::number_suffixes; // NOLINT
+
+namespace tint::ast {
+namespace {
+
+using CompoundAssignmentStatementTest = TestHelper;
+
+TEST_F(CompoundAssignmentStatementTest, Creation) {
+ auto* lhs = Expr("lhs");
+ auto* rhs = Expr("rhs");
+ auto op = BinaryOp::kAdd;
+
+ auto* stmt = create<CompoundAssignmentStatement>(lhs, rhs, op);
+ EXPECT_EQ(stmt->lhs, lhs);
+ EXPECT_EQ(stmt->rhs, rhs);
+ EXPECT_EQ(stmt->op, op);
+}
+
+TEST_F(CompoundAssignmentStatementTest, CreationWithSource) {
+ auto* lhs = Expr("lhs");
+ auto* rhs = Expr("rhs");
+ auto op = BinaryOp::kMultiply;
+
+ auto* stmt = create<CompoundAssignmentStatement>(Source{Source::Location{20, 2}}, lhs, rhs, op);
+ auto src = stmt->source;
+ EXPECT_EQ(src.range.begin.line, 20u);
+ EXPECT_EQ(src.range.begin.column, 2u);
+}
+
+TEST_F(CompoundAssignmentStatementTest, IsCompoundAssign) {
+ auto* lhs = Expr("lhs");
+ auto* rhs = Expr("rhs");
+ auto op = BinaryOp::kSubtract;
+
+ auto* stmt = create<CompoundAssignmentStatement>(lhs, rhs, op);
+ EXPECT_TRUE(stmt->Is<CompoundAssignmentStatement>());
+}
+
+TEST_F(CompoundAssignmentStatementTest, Assert_Null_LHS) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b;
+ b.create<CompoundAssignmentStatement>(nullptr, b.Expr(1_i), BinaryOp::kAdd);
+ },
+ "internal compiler error");
+}
+
+TEST_F(CompoundAssignmentStatementTest, Assert_Null_RHS) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b;
+ b.create<CompoundAssignmentStatement>(b.Expr(1_i), nullptr, BinaryOp::kAdd);
+ },
+ "internal compiler error");
+}
+
+TEST_F(CompoundAssignmentStatementTest, Assert_DifferentProgramID_LHS) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.create<CompoundAssignmentStatement>(b2.Expr("lhs"), b1.Expr("rhs"), BinaryOp::kAdd);
+ },
+ "internal compiler error");
+}
+
+TEST_F(CompoundAssignmentStatementTest, Assert_DifferentProgramID_RHS) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.create<CompoundAssignmentStatement>(b1.Expr("lhs"), b2.Expr("rhs"), BinaryOp::kAdd);
+ },
+ "internal compiler error");
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/const.cc b/src/tint/lang/wgsl/ast/const.cc
new file mode 100644
index 0000000..6d51d22
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/const.cc
@@ -0,0 +1,51 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/const.h"
+
+#include <utility>
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::Const);
+
+namespace tint::ast {
+
+Const::Const(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const Identifier* n,
+ Type ty,
+ const Expression* init,
+ utils::VectorRef<const Attribute*> attrs)
+ : Base(pid, nid, src, n, ty, init, std::move(attrs)) {
+ TINT_ASSERT(AST, init != nullptr);
+}
+
+Const::~Const() = default;
+
+const char* Const::Kind() const {
+ return "const";
+}
+
+const Const* Const::Clone(CloneContext* ctx) const {
+ auto src = ctx->Clone(source);
+ auto n = ctx->Clone(name);
+ auto ty = ctx->Clone(type);
+ auto* init = ctx->Clone(initializer);
+ auto attrs = ctx->Clone(attributes);
+ return ctx->dst->create<Const>(src, n, ty, init, std::move(attrs));
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/const.h b/src/tint/lang/wgsl/ast/const.h
new file mode 100644
index 0000000..66c7882
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/const.h
@@ -0,0 +1,66 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_CONST_H_
+#define SRC_TINT_LANG_WGSL_AST_CONST_H_
+
+#include "src/tint/lang/wgsl/ast/variable.h"
+
+namespace tint::ast {
+
+/// A "const" declaration is a name for a module-scoped or function-scoped creation-time value.
+/// A const must have an initializer expression.
+///
+/// Examples:
+///
+/// ```
+/// const n = 123; // Abstract-integer typed constant
+/// const pi = 3.14159265359; // Abstract-float typed constant
+/// const max_f32 : f32 = 0x1.fffffep+127; // f32 typed constant
+/// ```
+/// @see https://www.w3.org/TR/WGSL/#creation-time-consts
+class Const final : public utils::Castable<Const, Variable> {
+ public:
+ /// Create a 'const' creation-time value variable.
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param source the variable source
+ /// @param name the variable name
+ /// @param type the declared variable type
+ /// @param initializer the initializer expression. Must not be nullptr.
+ /// @param attributes the variable attributes
+ Const(ProgramID pid,
+ NodeID nid,
+ const Source& source,
+ const Identifier* name,
+ Type type,
+ const Expression* initializer,
+ utils::VectorRef<const Attribute*> attributes);
+
+ /// Destructor
+ ~Const() override;
+
+ /// @returns "const"
+ const char* Kind() const override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const Const* Clone(CloneContext* ctx) const override;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_CONST_H_
diff --git a/src/tint/lang/wgsl/ast/const_assert.cc b/src/tint/lang/wgsl/ast/const_assert.cc
new file mode 100644
index 0000000..3a60f40
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/const_assert.cc
@@ -0,0 +1,38 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/const_assert.h"
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::ConstAssert);
+
+namespace tint::ast {
+
+ConstAssert::ConstAssert(ProgramID pid, NodeID nid, const Source& src, const Expression* cond)
+ : Base(pid, nid, src), condition(cond) {
+ TINT_ASSERT(AST, cond);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, cond, program_id);
+}
+
+ConstAssert::~ConstAssert() = default;
+
+const ConstAssert* ConstAssert::Clone(CloneContext* ctx) const {
+ // Clone arguments outside of create() call to have deterministic ordering
+ auto src = ctx->Clone(source);
+ auto* cond = ctx->Clone(condition);
+ return ctx->dst->create<ConstAssert>(src, cond);
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/const_assert.h b/src/tint/lang/wgsl/ast/const_assert.h
new file mode 100644
index 0000000..a799093
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/const_assert.h
@@ -0,0 +1,47 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_CONST_ASSERT_H_
+#define SRC_TINT_LANG_WGSL_AST_CONST_ASSERT_H_
+
+#include "src/tint/lang/wgsl/ast/statement.h"
+#include "src/tint/lang/wgsl/ast/variable.h"
+
+namespace tint::ast {
+
+/// A `const_assert` statement
+class ConstAssert final : public utils::Castable<ConstAssert, Statement> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param source the `const_assert` statement source
+ /// @param condition the assertion condition. Must not be nullptr.
+ ConstAssert(ProgramID pid, NodeID nid, const Source& source, const Expression* condition);
+
+ /// Destructor
+ ~ConstAssert() override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext` `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const ConstAssert* Clone(CloneContext* ctx) const override;
+
+ /// The assertion condition
+ const Expression* const condition;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_CONST_ASSERT_H_
diff --git a/src/tint/lang/wgsl/ast/const_assert_test.cc b/src/tint/lang/wgsl/ast/const_assert_test.cc
new file mode 100644
index 0000000..b43cf61
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/const_assert_test.cc
@@ -0,0 +1,66 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/const_assert.h"
+
+#include "gtest/gtest-spi.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+namespace tint::ast {
+namespace {
+
+using ConstAssertTest = TestHelper;
+
+TEST_F(ConstAssertTest, Creation) {
+ auto* cond = Expr(true);
+ auto* stmt = ConstAssert(cond);
+ EXPECT_EQ(stmt->condition, cond);
+}
+
+TEST_F(ConstAssertTest, Creation_WithSource) {
+ auto* cond = Expr(true);
+ auto* stmt = ConstAssert(Source{{20, 2}}, cond);
+ auto src = stmt->source;
+ EXPECT_EQ(src.range.begin.line, 20u);
+ EXPECT_EQ(src.range.begin.column, 2u);
+}
+
+TEST_F(ConstAssertTest, IsConstAssert) {
+ auto* cond = Expr(true);
+
+ auto* stmt = ConstAssert(cond);
+ EXPECT_TRUE(stmt->Is<ast::ConstAssert>());
+}
+
+TEST_F(ConstAssertTest, Assert_Null_Condition) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b;
+ b.ConstAssert(nullptr);
+ },
+ "internal compiler error");
+}
+
+TEST_F(ConstAssertTest, Assert_DifferentProgramID_Condition) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.ConstAssert(b2.Expr(i32(123)));
+ },
+ "internal compiler error");
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/continue_statement.cc b/src/tint/lang/wgsl/ast/continue_statement.cc
new file mode 100644
index 0000000..4b8d92b
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/continue_statement.cc
@@ -0,0 +1,34 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/continue_statement.h"
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::ContinueStatement);
+
+namespace tint::ast {
+
+ContinueStatement::ContinueStatement(ProgramID pid, NodeID nid, const Source& src)
+ : Base(pid, nid, src) {}
+
+ContinueStatement::~ContinueStatement() = default;
+
+const ContinueStatement* ContinueStatement::Clone(CloneContext* ctx) const {
+ // Clone arguments outside of create() call to have deterministic ordering
+ auto src = ctx->Clone(source);
+ return ctx->dst->create<ContinueStatement>(src);
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/continue_statement.h b/src/tint/lang/wgsl/ast/continue_statement.h
new file mode 100644
index 0000000..3b1201c
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/continue_statement.h
@@ -0,0 +1,43 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_CONTINUE_STATEMENT_H_
+#define SRC_TINT_LANG_WGSL_AST_CONTINUE_STATEMENT_H_
+
+#include "src/tint/lang/wgsl/ast/statement.h"
+
+namespace tint::ast {
+
+/// A continue statement
+class ContinueStatement final : public utils::Castable<ContinueStatement, Statement> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ ContinueStatement(ProgramID pid, NodeID nid, const Source& src);
+
+ /// Destructor
+ ~ContinueStatement() override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const ContinueStatement* Clone(CloneContext* ctx) const override;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_CONTINUE_STATEMENT_H_
diff --git a/src/tint/lang/wgsl/ast/continue_statement_test.cc b/src/tint/lang/wgsl/ast/continue_statement_test.cc
new file mode 100644
index 0000000..6e9a1e7
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/continue_statement_test.cc
@@ -0,0 +1,37 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/continue_statement.h"
+
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+namespace tint::ast {
+namespace {
+
+using ContinueStatementTest = TestHelper;
+
+TEST_F(ContinueStatementTest, Creation_WithSource) {
+ // The location passed at creation must be readable back from the node's source.
+ auto* node = create<ContinueStatement>(Source{Source::Location{20, 2}});
+ EXPECT_EQ(node->source.range.begin.line, 20u);
+ EXPECT_EQ(node->source.range.begin.column, 2u);
+}
+
+TEST_F(ContinueStatementTest, IsContinue) {
+ auto* node = create<ContinueStatement>();
+ EXPECT_TRUE(node->Is<ContinueStatement>());
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/diagnostic_attribute.cc b/src/tint/lang/wgsl/ast/diagnostic_attribute.cc
new file mode 100644
index 0000000..0e3ce31
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/diagnostic_attribute.cc
@@ -0,0 +1,46 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/diagnostic_attribute.h"
+
+#include <string>
+#include <utility>
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::DiagnosticAttribute);
+
+namespace tint::ast {
+
+DiagnosticAttribute::DiagnosticAttribute(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ DiagnosticControl&& dc)
+ : Base(pid, nid, src), control(std::move(dc)) {}
+
+DiagnosticAttribute::~DiagnosticAttribute() = default;
+
+std::string DiagnosticAttribute::Name() const {
+ return "diagnostic";
+}
+
+const DiagnosticAttribute* DiagnosticAttribute::Clone(CloneContext* ctx) const {
+ // Clone arguments outside of create() call to have deterministic ordering
+ auto src = ctx->Clone(source);
+ auto rule = ctx->Clone(control.rule_name);
+ DiagnosticControl dc(control.severity, rule);  // same severity, cloned rule name
+ return ctx->dst->create<DiagnosticAttribute>(src, std::move(dc));
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/diagnostic_attribute.h b/src/tint/lang/wgsl/ast/diagnostic_attribute.h
new file mode 100644
index 0000000..38fbe23
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/diagnostic_attribute.h
@@ -0,0 +1,50 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_DIAGNOSTIC_ATTRIBUTE_H_
+#define SRC_TINT_LANG_WGSL_AST_DIAGNOSTIC_ATTRIBUTE_H_
+
+#include <string>
+
+#include "src/tint/lang/wgsl/ast/attribute.h"
+#include "src/tint/lang/wgsl/ast/diagnostic_control.h"
+
+namespace tint::ast {
+
+/// A diagnostic attribute
+class DiagnosticAttribute final : public utils::Castable<DiagnosticAttribute, Attribute> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ /// @param dc the diagnostic control, moved into `control`
+ DiagnosticAttribute(ProgramID pid, NodeID nid, const Source& src, DiagnosticControl&& dc);
+ ~DiagnosticAttribute() override;  ///< Destructor
+
+ /// @returns the WGSL name for the attribute ("diagnostic")
+ std::string Name() const override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext` `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const DiagnosticAttribute* Clone(CloneContext* ctx) const override;
+
+ /// The diagnostic control (severity and rule name) of this attribute.
+ const DiagnosticControl control;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_DIAGNOSTIC_ATTRIBUTE_H_
diff --git a/src/tint/lang/wgsl/ast/diagnostic_attribute_test.cc b/src/tint/lang/wgsl/ast/diagnostic_attribute_test.cc
new file mode 100644
index 0000000..49de81c
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/diagnostic_attribute_test.cc
@@ -0,0 +1,40 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+namespace tint::ast {
+namespace {
+
+using namespace tint::number_suffixes; // NOLINT
+using DiagnosticAttributeTest = TestHelper;
+
+TEST_F(DiagnosticAttributeTest, Name) {
+ auto* attr = DiagnosticAttribute(builtin::DiagnosticSeverity::kWarning, "foo");
+ EXPECT_EQ(attr->Name(), "diagnostic");
+ EXPECT_EQ(attr->control.severity, builtin::DiagnosticSeverity::kWarning);
+ EXPECT_EQ(attr->control.rule_name->category, nullptr);  // no category was given
+ CheckIdentifier(attr->control.rule_name->name, "foo");
+}
+
+TEST_F(DiagnosticAttributeTest, CategoryAndName) {
+ auto* attr = DiagnosticAttribute(builtin::DiagnosticSeverity::kWarning, "foo", "bar");
+ EXPECT_EQ(attr->Name(), "diagnostic");
+ EXPECT_EQ(attr->control.severity, builtin::DiagnosticSeverity::kWarning);
+ CheckIdentifier(attr->control.rule_name->category, "foo");
+ CheckIdentifier(attr->control.rule_name->name, "bar");
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/diagnostic_control.cc b/src/tint/lang/wgsl/ast/diagnostic_control.cc
new file mode 100644
index 0000000..fdae705
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/diagnostic_control.cc
@@ -0,0 +1,35 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/diagnostic_control.h"
+
+#include <string>
+
+#include "src/tint/builtin/diagnostic_severity.h"
+#include "src/tint/lang/wgsl/ast/identifier.h"
+#include "src/tint/lang/wgsl/ast/templated_identifier.h"
+
+namespace tint::ast {
+
+DiagnosticControl::DiagnosticControl() = default;
+
+DiagnosticControl::DiagnosticControl(builtin::DiagnosticSeverity sev,
+ const DiagnosticRuleName* rule)
+ : severity(sev), rule_name(rule) {
+ TINT_ASSERT(AST, rule != nullptr);  // an explicitly built control must name a rule
+}
+
+DiagnosticControl::DiagnosticControl(DiagnosticControl&&) = default;
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/diagnostic_control.h b/src/tint/lang/wgsl/ast/diagnostic_control.h
new file mode 100644
index 0000000..5b5bcf8
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/diagnostic_control.h
@@ -0,0 +1,54 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_DIAGNOSTIC_CONTROL_H_
+#define SRC_TINT_LANG_WGSL_AST_DIAGNOSTIC_CONTROL_H_
+
+#include <string>
+#include <unordered_map>
+
+#include "src/tint/builtin/diagnostic_severity.h"
+#include "src/tint/diagnostic/diagnostic.h"
+
+// Forward declarations
+namespace tint::ast {
+class DiagnosticRuleName;
+} // namespace tint::ast
+
+namespace tint::ast {
+
+/// A diagnostic control used for diagnostic directives and attributes.
+struct DiagnosticControl {
+ public:
+ /// Default constructor. Leaves `severity` as kUndefined and `rule_name` as nullptr.
+ DiagnosticControl();
+
+ /// Constructor
+ /// @param sev the diagnostic severity
+ /// @param rule the diagnostic rule name, must not be nullptr
+ DiagnosticControl(builtin::DiagnosticSeverity sev, const DiagnosticRuleName* rule);
+
+ /// Move constructor
+ DiagnosticControl(DiagnosticControl&&);
+
+ /// The diagnostic severity control.
+ builtin::DiagnosticSeverity severity = builtin::DiagnosticSeverity::kUndefined;
+
+ /// The diagnostic rule name.
+ const DiagnosticRuleName* rule_name = nullptr;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_DIAGNOSTIC_CONTROL_H_
diff --git a/src/tint/lang/wgsl/ast/diagnostic_control_test.cc b/src/tint/lang/wgsl/ast/diagnostic_control_test.cc
new file mode 100644
index 0000000..74e855d
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/diagnostic_control_test.cc
@@ -0,0 +1,37 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+
+#include "gtest/gtest-spi.h"
+#include "src/tint/builtin/diagnostic_severity.h"
+#include "src/tint/lang/wgsl/ast/diagnostic_control.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+namespace tint::ast {
+namespace {
+
+using DiagnosticControlTest = TestHelper;
+
+TEST_F(DiagnosticControlTest, Assert_RuleNotNull) {
+ EXPECT_FATAL_FAILURE(
+ {  // constructing a control with a null rule must trip the constructor's assert
+ ProgramBuilder b;
+ DiagnosticControl control(builtin::DiagnosticSeverity::kWarning, nullptr);
+ },
+ "internal compiler error");
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/diagnostic_directive.cc b/src/tint/lang/wgsl/ast/diagnostic_directive.cc
new file mode 100644
index 0000000..d2cceff
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/diagnostic_directive.cc
@@ -0,0 +1,38 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/diagnostic_directive.h"
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::DiagnosticDirective);
+
+namespace tint::ast {
+
+DiagnosticDirective::DiagnosticDirective(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ DiagnosticControl&& dc)
+ : Base(pid, nid, src), control(std::move(dc)) {}
+
+DiagnosticDirective::~DiagnosticDirective() = default;
+
+const DiagnosticDirective* DiagnosticDirective::Clone(CloneContext* ctx) const {
+ auto src = ctx->Clone(source);  // arguments are cloned outside of the create() call
+ auto rule = ctx->Clone(control.rule_name);
+ DiagnosticControl dc(control.severity, rule);  // same severity, cloned rule name
+ return ctx->dst->create<DiagnosticDirective>(src, std::move(dc));
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/diagnostic_directive.h b/src/tint/lang/wgsl/ast/diagnostic_directive.h
new file mode 100644
index 0000000..5a917d9
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/diagnostic_directive.h
@@ -0,0 +1,55 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_DIAGNOSTIC_DIRECTIVE_H_
+#define SRC_TINT_LANG_WGSL_AST_DIAGNOSTIC_DIRECTIVE_H_
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "src/tint/lang/wgsl/ast/diagnostic_control.h"
+#include "src/tint/lang/wgsl/ast/node.h"
+
+namespace tint::ast {
+
+/// A "diagnostic" directive. Example:
+/// ```
+/// // Turn off diagnostics for derivative uniformity violations.
+/// diagnostic(off, derivative_uniformity);
+/// ```
+class DiagnosticDirective final : public utils::Castable<DiagnosticDirective, Node> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ /// @param dc the diagnostic control, moved into `control`
+ DiagnosticDirective(ProgramID pid, NodeID nid, const Source& src, DiagnosticControl&& dc);
+
+ /// Destructor
+ ~DiagnosticDirective() override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext` `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const DiagnosticDirective* Clone(CloneContext* ctx) const override;
+
+ /// The diagnostic control.
+ const DiagnosticControl control;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_DIAGNOSTIC_DIRECTIVE_H_
diff --git a/src/tint/lang/wgsl/ast/diagnostic_directive_test.cc b/src/tint/lang/wgsl/ast/diagnostic_directive_test.cc
new file mode 100644
index 0000000..009edb0
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/diagnostic_directive_test.cc
@@ -0,0 +1,49 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/diagnostic_directive.h"
+
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+namespace tint::ast {
+namespace {
+
+using DiagnosticDirectiveTest = TestHelper;
+
+TEST_F(DiagnosticDirectiveTest, Name) {
+ auto* directive = DiagnosticDirective(Source{{{10, 5}, {10, 15}}},
+ builtin::DiagnosticSeverity::kWarning, "foo");
+ EXPECT_EQ(directive->source.range.begin.line, 10u);
+ EXPECT_EQ(directive->source.range.begin.column, 5u);
+ EXPECT_EQ(directive->source.range.end.line, 10u);
+ EXPECT_EQ(directive->source.range.end.column, 15u);
+ EXPECT_EQ(directive->control.severity, builtin::DiagnosticSeverity::kWarning);
+ EXPECT_EQ(directive->control.rule_name->category, nullptr);  // no category was given
+ CheckIdentifier(directive->control.rule_name->name, "foo");
+}
+
+TEST_F(DiagnosticDirectiveTest, CategoryAndName) {
+ auto* directive = DiagnosticDirective(Source{{{10, 5}, {10, 15}}},
+ builtin::DiagnosticSeverity::kWarning, "foo", "bar");
+ EXPECT_EQ(directive->source.range.begin.line, 10u);
+ EXPECT_EQ(directive->source.range.begin.column, 5u);
+ EXPECT_EQ(directive->source.range.end.line, 10u);
+ EXPECT_EQ(directive->source.range.end.column, 15u);
+ EXPECT_EQ(directive->control.severity, builtin::DiagnosticSeverity::kWarning);
+ CheckIdentifier(directive->control.rule_name->category, "foo");
+ CheckIdentifier(directive->control.rule_name->name, "bar");
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/diagnostic_rule_name.cc b/src/tint/lang/wgsl/ast/diagnostic_rule_name.cc
new file mode 100644
index 0000000..d0994c9
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/diagnostic_rule_name.cc
@@ -0,0 +1,74 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/diagnostic_rule_name.h"
+
+#include <string>
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::DiagnosticRuleName);
+
+namespace tint::ast {
+
+DiagnosticRuleName::DiagnosticRuleName(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const Identifier* n)
+ : Base(pid, nid, src), name(n) {
+ TINT_ASSERT(AST, name != nullptr);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, name, program_id);
+ if (name) {
+ // It is invalid for a diagnostic rule name to be templated
+ TINT_ASSERT(AST, !name->Is<TemplatedIdentifier>());
+ }
+}
+
+DiagnosticRuleName::DiagnosticRuleName(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const Identifier* c,
+ const Identifier* n)
+ : Base(pid, nid, src), category(c), name(n) {
+ TINT_ASSERT(AST, name != nullptr);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, name, program_id);
+ if (name) {
+ // It is invalid for a diagnostic rule name to be templated
+ TINT_ASSERT(AST, !name->Is<TemplatedIdentifier>());
+ }
+ if (category) {  // unlike the name, the category is not asserted to be non-null
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, category, program_id);
+ // It is invalid for a diagnostic rule category to be templated
+ TINT_ASSERT(AST, !category->Is<TemplatedIdentifier>());
+ }
+}
+
+const DiagnosticRuleName* DiagnosticRuleName::Clone(CloneContext* ctx) const {
+ auto src = ctx->Clone(source);
+ auto n = ctx->Clone(name);
+ if (auto c = ctx->Clone(category)) {  // non-null cloned category: keep `category.name` form
+ return ctx->dst->create<DiagnosticRuleName>(src, c, n);
+ }
+ return ctx->dst->create<DiagnosticRuleName>(src, n);
+}
+
+std::string DiagnosticRuleName::String() const {
+ // A rule without a category is spelled with just its name.
+ if (!category) {
+ return name->symbol.Name();
+ }
+ return category->symbol.Name() + "." + name->symbol.Name();
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/diagnostic_rule_name.h b/src/tint/lang/wgsl/ast/diagnostic_rule_name.h
new file mode 100644
index 0000000..b8692f9
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/diagnostic_rule_name.h
@@ -0,0 +1,68 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_DIAGNOSTIC_RULE_NAME_H_
+#define SRC_TINT_LANG_WGSL_AST_DIAGNOSTIC_RULE_NAME_H_
+
+#include <string>
+
+#include "src/tint/lang/wgsl/ast/node.h"
+
+// Forward declarations
+namespace tint::ast {
+class Identifier;
+} // namespace tint::ast
+
+namespace tint::ast {
+
+/// A diagnostic rule name used for diagnostic directives and attributes.
+class DiagnosticRuleName final : public utils::Castable<DiagnosticRuleName, Node> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ /// @param name the rule name, must not be nullptr or a TemplatedIdentifier
+ DiagnosticRuleName(ProgramID pid, NodeID nid, const Source& src, const Identifier* name);
+
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ /// @param category the rule category, may be nullptr but must not be a TemplatedIdentifier
+ /// @param name the rule name, must not be nullptr or a TemplatedIdentifier
+ DiagnosticRuleName(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const Identifier* category,
+ const Identifier* name);
+
+ /// Clones this node and all transitive child nodes using the `CloneContext` `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const DiagnosticRuleName* Clone(CloneContext* ctx) const override;
+
+ /// @return the full name of this diagnostic rule, either as `name` or `category.name`.
+ std::string String() const;
+
+ /// The diagnostic rule category (the `category` in `category.name`), or nullptr if none.
+ Identifier const* const category = nullptr;
+
+ /// The diagnostic rule name.
+ Identifier const* const name;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_DIAGNOSTIC_RULE_NAME_H_
diff --git a/src/tint/lang/wgsl/ast/diagnostic_rule_name_test.cc b/src/tint/lang/wgsl/ast/diagnostic_rule_name_test.cc
new file mode 100644
index 0000000..d7ebb69
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/diagnostic_rule_name_test.cc
@@ -0,0 +1,50 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+
+#include "gtest/gtest-spi.h"
+#include "src/tint/lang/wgsl/ast/diagnostic_rule_name.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+namespace tint::ast {
+namespace {
+
+using DiagnosticRuleNameTest = TestHelper;
+
+TEST_F(DiagnosticRuleNameTest, String) {
+ EXPECT_EQ(DiagnosticRuleName("name")->String(), "name");
+ EXPECT_EQ(DiagnosticRuleName("category", "name")->String(), "category.name");
+}
+
+TEST_F(DiagnosticRuleNameTest, Assert_NameNotTemplated) {
+ EXPECT_FATAL_FAILURE(
+ {  // single-identifier constructor: the templated identifier is the rule name
+ ProgramBuilder b;
+ b.create<ast::DiagnosticRuleName>(b.Ident("name", "a", "b", "c"));
+ },
+ "internal compiler error");
+}
+
+TEST_F(DiagnosticRuleNameTest, Assert_CategoryNotTemplated) {
+ EXPECT_FATAL_FAILURE(
+ {  // constructor order is (category, name): the templated identifier must be the category
+ ProgramBuilder b;
+ b.create<ast::DiagnosticRuleName>(b.Ident("category", "a", "b", "c"), b.Ident("name"));
+ },
+ "internal compiler error");
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/disable_validation_attribute.cc b/src/tint/lang/wgsl/ast/disable_validation_attribute.cc
new file mode 100644
index 0000000..979b2d0
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/disable_validation_attribute.cc
@@ -0,0 +1,59 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/disable_validation_attribute.h"
+#include "src/tint/clone_context.h"
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::DisableValidationAttribute);
+
+namespace tint::ast {
+
+DisableValidationAttribute::DisableValidationAttribute(ProgramID pid,
+ NodeID nid,
+ DisabledValidation val)
+ : Base(pid, nid, utils::Empty), validation(val) {}
+
+DisableValidationAttribute::~DisableValidationAttribute() = default;
+
+std::string DisableValidationAttribute::InternalName() const {
+ switch (validation) {
+ case DisabledValidation::kFunctionHasNoBody:
+ return "disable_validation__function_has_no_body";
+ case DisabledValidation::kBindingPointCollision:
+ return "disable_validation__binding_point_collision";
+ case DisabledValidation::kIgnoreAddressSpace:
+ return "disable_validation__ignore_address_space";
+ case DisabledValidation::kEntryPointParameter:
+ return "disable_validation__entry_point_parameter";
+ case DisabledValidation::kFunctionParameter:
+ return "disable_validation__function_parameter";
+ case DisabledValidation::kIgnoreStrideAttribute:
+ return "disable_validation__ignore_stride";
+ case DisabledValidation::kIgnoreInvalidPointerArgument:
+ return "disable_validation__ignore_invalid_pointer_argument";
+ case DisabledValidation::kIgnorePointerAliasing:
+ return "disable_validation__ignore_pointer_aliasing";
+ case DisabledValidation::kIgnoreStructMemberLimit:
+ return "disable_validation__ignore_struct_member";
+ }
+ return "<invalid>";  // only reached if `validation` matches none of the cases above
+}
+
+const DisableValidationAttribute* DisableValidationAttribute::Clone(CloneContext* ctx) const {
+ return ctx->dst->ASTNodes().Create<DisableValidationAttribute>(  // new node in ctx->dst
+ ctx->dst->ID(), ctx->dst->AllocateNodeID(), validation);  // only `validation` is copied
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/disable_validation_attribute.h b/src/tint/lang/wgsl/ast/disable_validation_attribute.h
new file mode 100644
index 0000000..44b815b
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/disable_validation_attribute.h
@@ -0,0 +1,82 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_DISABLE_VALIDATION_ATTRIBUTE_H_
+#define SRC_TINT_LANG_WGSL_AST_DISABLE_VALIDATION_ATTRIBUTE_H_
+
+#include <string>
+
+#include "src/tint/lang/wgsl/ast/internal_attribute.h"
+
+namespace tint::ast {
+
+/// Enumeration of validation features that can be disabled with a
+/// DisableValidationAttribute attribute.
+enum class DisabledValidation {
+ /// When applied to a function, the validator will not complain there is no body to a function.
+ kFunctionHasNoBody,
+ /// When applied to a module-scoped variable, the validator will not complain if two resource
+ /// variables have the same binding points.
+ kBindingPointCollision,
+ /// When applied to a variable, the validator will not complain about the declared address
+ /// space.
+ kIgnoreAddressSpace,
+ /// When applied to an entry-point function parameter, the validator will not check for entry IO
+ /// attributes.
+ kEntryPointParameter,
+ /// When applied to a function parameter, the parameter will not be validated.
+ kFunctionParameter,
+ /// When applied to a member attribute, a stride attribute may be applied to non-array types.
+ kIgnoreStrideAttribute,
+ /// When applied to a pointer function parameter, the validator will not require a function call
+ /// argument passed for that parameter to have a certain form.
+ kIgnoreInvalidPointerArgument,
+ /// When applied to a function declaration, the validator will not complain if multiple
+ /// pointer arguments alias when that function is called.
+ kIgnorePointerAliasing,
+ /// When applied to a struct, validation of max number of members is skipped.
+ kIgnoreStructMemberLimit,
+};
+
+/// An internal attribute used to tell the validator to ignore specific
+/// violations. Typically generated by transforms that need to produce ASTs that
+/// would otherwise cause validation errors.
+class DisableValidationAttribute final
+ : public utils::Castable<DisableValidationAttribute, InternalAttribute> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param validation the validation to disable
+ explicit DisableValidationAttribute(ProgramID pid, NodeID nid, DisabledValidation validation);
+
+ /// Destructor
+ ~DisableValidationAttribute() override;
+
+ /// @return a short description of the internal attribute which will be
+ /// displayed in WGSL as `@internal(<name>)` (but is not parsable).
+ std::string InternalName() const override;
+
+ /// Performs a deep clone of this object using the CloneContext `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned object
+ const DisableValidationAttribute* Clone(CloneContext* ctx) const override;
+
+ /// The validation that this attribute disables
+ const DisabledValidation validation;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_DISABLE_VALIDATION_ATTRIBUTE_H_
diff --git a/src/tint/lang/wgsl/ast/discard_statement.cc b/src/tint/lang/wgsl/ast/discard_statement.cc
new file mode 100644
index 0000000..2f95e27
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/discard_statement.cc
@@ -0,0 +1,34 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/discard_statement.h"
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::DiscardStatement);
+
+namespace tint::ast {
+
+DiscardStatement::DiscardStatement(ProgramID pid, NodeID nid, const Source& src)
+ : Base(pid, nid, src) {}
+
+DiscardStatement::~DiscardStatement() = default;
+
+const DiscardStatement* DiscardStatement::Clone(CloneContext* ctx) const {
+ // Clone arguments outside of create() call to have deterministic ordering
+ auto src = ctx->Clone(source);
+ return ctx->dst->create<DiscardStatement>(src);
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/discard_statement.h b/src/tint/lang/wgsl/ast/discard_statement.h
new file mode 100644
index 0000000..b192da9
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/discard_statement.h
@@ -0,0 +1,43 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_DISCARD_STATEMENT_H_
+#define SRC_TINT_LANG_WGSL_AST_DISCARD_STATEMENT_H_
+
+#include "src/tint/lang/wgsl/ast/statement.h"
+
+namespace tint::ast {
+
+/// A discard statement. The class declares no data members of its own.
+class DiscardStatement final : public utils::Castable<DiscardStatement, Statement> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ DiscardStatement(ProgramID pid, NodeID nid, const Source& src);
+
+ /// Destructor
+ ~DiscardStatement() override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const DiscardStatement* Clone(CloneContext* ctx) const override;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_DISCARD_STATEMENT_H_
diff --git a/src/tint/lang/wgsl/ast/discard_statement_test.cc b/src/tint/lang/wgsl/ast/discard_statement_test.cc
new file mode 100644
index 0000000..be926a9
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/discard_statement_test.cc
@@ -0,0 +1,47 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/discard_statement.h"
+
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+namespace tint::ast {
+namespace {
+
+using DiscardStatementTest = TestHelper;
+
+TEST_F(DiscardStatementTest, Creation) {
+    const auto& range = create<DiscardStatement>()->source.range;
+    EXPECT_EQ(range.begin.line, 0u);
+    EXPECT_EQ(range.begin.column, 0u);
+    EXPECT_EQ(range.end.line, 0u);
+    EXPECT_EQ(range.end.column, 0u);
+}
+
+TEST_F(DiscardStatementTest, Creation_WithSource) {
+    auto* discard = create<DiscardStatement>(Source{{{20, 2}, {20, 5}}});
+    const auto& range = discard->source.range;
+    EXPECT_EQ(range.begin.line, 20u);
+    EXPECT_EQ(range.begin.column, 2u);
+    EXPECT_EQ(range.end.line, 20u);
+    EXPECT_EQ(range.end.column, 5u);
+}
+
+TEST_F(DiscardStatementTest, IsDiscard) {
+    // A freshly created node must identify itself as a DiscardStatement.
+    EXPECT_TRUE(create<DiscardStatement>()->Is<DiscardStatement>());
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/enable.cc b/src/tint/lang/wgsl/ast/enable.cc
new file mode 100644
index 0000000..439f45e
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/enable.cc
@@ -0,0 +1,46 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/enable.h"
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::Enable);
+
+namespace tint::ast {
+
+Enable::Enable(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ utils::VectorRef<const Extension*> exts)
+ : Base(pid, nid, src), extensions(std::move(exts)) {}
+
+Enable::~Enable() = default;
+
+bool Enable::HasExtension(builtin::Extension ext) const {
+    // Scans the extensions listed by this directive for one named `ext`.
+    bool found = false;
+    for (auto* listed : extensions) {
+        found = found || listed->name == ext;
+    }
+    return found;
+}
+
+const Enable* Enable::Clone(CloneContext* ctx) const {
+    auto cloned_source = ctx->Clone(source);  // cloned before create() for stable ordering
+    auto cloned_exts = ctx->Clone(extensions);
+    return ctx->dst->create<Enable>(cloned_source, std::move(cloned_exts));
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/enable.h b/src/tint/lang/wgsl/ast/enable.h
new file mode 100644
index 0000000..fa32c57
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/enable.h
@@ -0,0 +1,58 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_ENABLE_H_
+#define SRC_TINT_LANG_WGSL_AST_ENABLE_H_
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "src/tint/lang/wgsl/ast/extension.h"
+
+namespace tint::ast {
+
+/// An "enable" directive. Example:
+/// ```
+/// // Enable an extension named "f16"
+/// enable f16;
+/// ```
+class Enable final : public utils::Castable<Enable, Node> {
+ public:
+ /// Create an enable directive
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ /// @param exts the extensions being enabled by this directive
+ Enable(ProgramID pid, NodeID nid, const Source& src, utils::VectorRef<const Extension*> exts);
+
+ /// Destructor
+ ~Enable() override;
+
+ /// @param ext the extension to search for
+ /// @returns true if this Enable lists the given extension
+ bool HasExtension(builtin::Extension ext) const;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext` `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const Enable* Clone(CloneContext* ctx) const override;
+
+ /// The extensions being enabled by this directive
+ const utils::Vector<const Extension*, 4> extensions;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_ENABLE_H_
diff --git a/src/tint/lang/wgsl/ast/enable_test.cc b/src/tint/lang/wgsl/ast/enable_test.cc
new file mode 100644
index 0000000..61c7bbc
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/enable_test.cc
@@ -0,0 +1,41 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/enable.h"
+
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+namespace tint::ast {
+namespace {
+
+using EnableTest = TestHelper;
+
+TEST_F(EnableTest, Creation) {
+    auto* directive = Enable(Source{{{20, 2}, {20, 5}}}, builtin::Extension::kF16);
+    EXPECT_EQ(directive->source.range.begin.line, 20u);
+    EXPECT_EQ(directive->source.range.begin.column, 2u);
+    EXPECT_EQ(directive->source.range.end.line, 20u);
+    EXPECT_EQ(directive->source.range.end.column, 5u);
+    ASSERT_EQ(directive->extensions.Length(), 1u);  // guards the index below
+    EXPECT_EQ(directive->extensions[0]->name, builtin::Extension::kF16);
+}
+
+TEST_F(EnableTest, HasExtension) {
+    auto* directive = Enable(Source{{{20, 2}, {20, 5}}}, builtin::Extension::kF16);
+    EXPECT_TRUE(directive->HasExtension(builtin::Extension::kF16));  // the one listed
+    EXPECT_FALSE(directive->HasExtension(builtin::Extension::kChromiumDisableUniformityAnalysis));
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/expression.cc b/src/tint/lang/wgsl/ast/expression.cc
new file mode 100644
index 0000000..4321e57
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/expression.cc
@@ -0,0 +1,25 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/expression.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::Expression);
+
+namespace tint::ast {
+
+Expression::Expression(ProgramID pid, NodeID nid, const Source& src) : Base(pid, nid, src) {}
+
+Expression::~Expression() = default;
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/expression.h b/src/tint/lang/wgsl/ast/expression.h
new file mode 100644
index 0000000..1ad9e2d
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/expression.h
@@ -0,0 +1,40 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_EXPRESSION_H_
+#define SRC_TINT_LANG_WGSL_AST_EXPRESSION_H_
+
+#include <string>
+#include <vector>
+
+#include "src/tint/lang/wgsl/ast/node.h"
+
+namespace tint::ast {
+
+/// Base expression class. Constructible only by derived classes (the constructor is protected).
+class Expression : public utils::Castable<Expression, Node> {
+ public:
+ ~Expression() override;  ///< Destructor
+
+ protected:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ Expression(ProgramID pid, NodeID nid, const Source& src);
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_EXPRESSION_H_
diff --git a/src/tint/lang/wgsl/ast/extension.cc b/src/tint/lang/wgsl/ast/extension.cc
new file mode 100644
index 0000000..4a00aa1
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/extension.cc
@@ -0,0 +1,38 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/extension.h"
+
+#include "src/tint/program_builder.h"
+
+//! @cond Doxygen_Suppress
+// Doxygen gets confused with tint::ast::Extension and tint::builtin::Extension
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::Extension);
+
+namespace tint::ast {
+
+Extension::Extension(ProgramID pid, NodeID nid, const Source& src, builtin::Extension ext)
+ : Base(pid, nid, src), name(ext) {}
+
+Extension::~Extension() = default;
+
+const Extension* Extension::Clone(CloneContext* ctx) const {
+    auto cloned_source = ctx->Clone(source);  // `name` is passed through unchanged
+    return ctx->dst->create<Extension>(cloned_source, name);
+}
+
+} // namespace tint::ast
+
+//! @endcond
diff --git a/src/tint/lang/wgsl/ast/extension.h b/src/tint/lang/wgsl/ast/extension.h
new file mode 100644
index 0000000..ba651b2
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/extension.h
@@ -0,0 +1,51 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_EXTENSION_H_
+#define SRC_TINT_LANG_WGSL_AST_EXTENSION_H_
+
+#include "src/tint/builtin/extension.h"
+#include "src/tint/lang/wgsl/ast/node.h"
+
+namespace tint::ast {
+
+/// An extension used in an "enable" directive. Example:
+/// ```
+/// enable f16;
+/// ```
+class Extension final : public utils::Castable<Extension, Node> {
+ public:
+ /// Create an extension
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ /// @param ext the extension
+ Extension(ProgramID pid, NodeID nid, const Source& src, builtin::Extension ext);
+
+ /// Destructor
+ ~Extension() override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const Extension* Clone(CloneContext* ctx) const override;
+
+ /// The extension name
+ const builtin::Extension name;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_EXTENSION_H_
diff --git a/src/tint/lang/wgsl/ast/float_literal_expression.cc b/src/tint/lang/wgsl/ast/float_literal_expression.cc
new file mode 100644
index 0000000..6cb5dd9
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/float_literal_expression.cc
@@ -0,0 +1,51 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/float_literal_expression.h"
+
+#include <limits>
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::FloatLiteralExpression);
+
+namespace tint::ast {
+
+FloatLiteralExpression::FloatLiteralExpression(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ double val,
+ Suffix suf)
+ : Base(pid, nid, src), value(val), suffix(suf) {}
+
+FloatLiteralExpression::~FloatLiteralExpression() = default;
+
+const FloatLiteralExpression* FloatLiteralExpression::Clone(CloneContext* ctx) const {
+    // The source is cloned first; value and suffix are handed to create() as they are
+    auto cloned_source = ctx->Clone(source);
+    return ctx->dst->create<FloatLiteralExpression>(cloned_source, value, suffix);
+}
+
+utils::StringStream& operator<<(utils::StringStream& out, FloatLiteralExpression::Suffix suffix) {
+    // Only kF and kH emit a character; any other value leaves the stream untouched.
+    if (suffix == FloatLiteralExpression::Suffix::kF) {
+        return out << "f";
+    }
+    if (suffix == FloatLiteralExpression::Suffix::kH) {
+        return out << "h";
+    }
+    return out;
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/float_literal_expression.h b/src/tint/lang/wgsl/ast/float_literal_expression.h
new file mode 100644
index 0000000..e156505
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/float_literal_expression.h
@@ -0,0 +1,68 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_FLOAT_LITERAL_EXPRESSION_H_
+#define SRC_TINT_LANG_WGSL_AST_FLOAT_LITERAL_EXPRESSION_H_
+
+#include <string>
+
+#include "src/tint/lang/wgsl/ast/literal_expression.h"
+
+namespace tint::ast {
+
+/// A float literal expression: a double value plus an optional 'f' or 'h' suffix
+class FloatLiteralExpression final
+ : public utils::Castable<FloatLiteralExpression, LiteralExpression> {
+ public:
+ /// Literal suffix
+ enum class Suffix {
+ /// No suffix
+ kNone,
+ /// 'f' suffix (f32)
+ kF,
+ /// 'h' suffix (f16)
+ kH,
+ };
+
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ /// @param val the literal value
+ /// @param suf the literal suffix
+ FloatLiteralExpression(ProgramID pid, NodeID nid, const Source& src, double val, Suffix suf);
+ ~FloatLiteralExpression() override;  ///< Destructor
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const FloatLiteralExpression* Clone(CloneContext* ctx) const override;
+
+ /// The literal value
+ const double value;
+
+ /// The literal suffix
+ const Suffix suffix;
+};
+
+/// Writes the float literal suffix to the stream.
+/// @param out the stream to write to
+/// @param suffix the suffix to write
+/// @returns out so calls can be chained
+utils::StringStream& operator<<(utils::StringStream& out, FloatLiteralExpression::Suffix suffix);
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_FLOAT_LITERAL_EXPRESSION_H_
diff --git a/src/tint/lang/wgsl/ast/float_literal_expression_test.cc b/src/tint/lang/wgsl/ast/float_literal_expression_test.cc
new file mode 100644
index 0000000..98e6276
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/float_literal_expression_test.cc
@@ -0,0 +1,58 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+#include "src/tint/utils/string_stream.h"
+
+namespace tint::ast {
+namespace {
+
+using FloatLiteralExpressionTest = TestHelper;
+
+TEST_F(FloatLiteralExpressionTest, SuffixNone) {
+    auto* literal = create<FloatLiteralExpression>(42.0, FloatLiteralExpression::Suffix::kNone);
+    ASSERT_TRUE(literal->Is<FloatLiteralExpression>());
+    EXPECT_EQ(literal->value, 42);
+    EXPECT_EQ(literal->suffix, FloatLiteralExpression::Suffix::kNone);
+}
+
+TEST_F(FloatLiteralExpressionTest, SuffixF) {
+    auto* literal = create<FloatLiteralExpression>(42.0, FloatLiteralExpression::Suffix::kF);
+    ASSERT_TRUE(literal->Is<FloatLiteralExpression>());
+    EXPECT_EQ(literal->value, 42);
+    EXPECT_EQ(literal->suffix, FloatLiteralExpression::Suffix::kF);
+}
+
+TEST_F(FloatLiteralExpressionTest, SuffixH) {
+    auto* literal = create<FloatLiteralExpression>(42.0, FloatLiteralExpression::Suffix::kH);
+    ASSERT_TRUE(literal->Is<FloatLiteralExpression>());
+    EXPECT_EQ(literal->value, 42);
+    EXPECT_EQ(literal->suffix, FloatLiteralExpression::Suffix::kH);
+}
+
+TEST_F(FloatLiteralExpressionTest, SuffixStringStream) {
+    // Streams a suffix into a fresh StringStream and returns the text it produced.
+    auto print = [](FloatLiteralExpression::Suffix s) {
+        utils::StringStream stream;
+        stream << s;
+        return stream.str();
+    };
+    EXPECT_EQ("", print(FloatLiteralExpression::Suffix::kNone));
+    EXPECT_EQ("f", print(FloatLiteralExpression::Suffix::kF));
+    EXPECT_EQ("h", print(FloatLiteralExpression::Suffix::kH));
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/for_loop_statement.cc b/src/tint/lang/wgsl/ast/for_loop_statement.cc
new file mode 100644
index 0000000..1bcd9c2
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/for_loop_statement.cc
@@ -0,0 +1,65 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/for_loop_statement.h"
+
+#include <utility>
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::ForLoopStatement);
+
+namespace tint::ast {
+
+ForLoopStatement::ForLoopStatement(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const Statement* init,
+ const Expression* cond,
+ const Statement* cont,
+ const BlockStatement* b,
+ utils::VectorRef<const ast::Attribute*> attrs)
+ : Base(pid, nid, src),
+ initializer(init),
+ condition(cond),
+ continuing(cont),
+ body(b),
+ attributes(std::move(attrs)) {
+ TINT_ASSERT(AST, body);  // the body is the only child that must be non-null
+ // The remaining children may be null; non-null ones must come from this node's program.
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, initializer, program_id);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, condition, program_id);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, continuing, program_id);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, body, program_id);
+ for (auto* attr : attributes) {
+ TINT_ASSERT(AST, attr);  // null attributes are rejected
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, attr, program_id);
+ }
+}
+
+ForLoopStatement::~ForLoopStatement() = default;
+
+const ForLoopStatement* ForLoopStatement::Clone(CloneContext* ctx) const {
+    // Each child is cloned in its own statement, before create(), for deterministic ordering
+    auto new_source = ctx->Clone(source);
+    auto* new_init = ctx->Clone(initializer);
+    auto* new_cond = ctx->Clone(condition);
+    auto* new_cont = ctx->Clone(continuing);
+    auto* new_body = ctx->Clone(body);
+    auto new_attrs = ctx->Clone(attributes);
+    return ctx->dst->create<ForLoopStatement>(new_source, new_init, new_cond, new_cont, new_body,
+                                              std::move(new_attrs));
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/for_loop_statement.h b/src/tint/lang/wgsl/ast/for_loop_statement.h
new file mode 100644
index 0000000..9156143
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/for_loop_statement.h
@@ -0,0 +1,72 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_FOR_LOOP_STATEMENT_H_
+#define SRC_TINT_LANG_WGSL_AST_FOR_LOOP_STATEMENT_H_
+
+#include "src/tint/lang/wgsl/ast/block_statement.h"
+
+namespace tint::ast {
+
+class Expression;
+
+/// A for loop statement
+class ForLoopStatement final : public utils::Castable<ForLoopStatement, Statement> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param source the for loop statement source
+ /// @param initializer the optional loop initializer statement
+ /// @param condition the optional loop condition expression
+ /// @param continuing the optional continuing statement
+ /// @param body the loop body
+ /// @param attributes the for loop statement attributes
+ ForLoopStatement(ProgramID pid,
+ NodeID nid,
+ const Source& source,
+ const Statement* initializer,
+ const Expression* condition,
+ const Statement* continuing,
+ const BlockStatement* body,
+ utils::VectorRef<const ast::Attribute*> attributes);
+
+ /// Destructor
+ ~ForLoopStatement() override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const ForLoopStatement* Clone(CloneContext* ctx) const override;
+
+ /// The initializer statement
+ const Statement* const initializer;
+
+ /// The condition expression
+ const Expression* const condition;
+
+ /// The continuing statement
+ const Statement* const continuing;
+
+ /// The loop body block
+ const BlockStatement* const body;
+
+ /// The attribute list
+ const utils::Vector<const Attribute*, 1> attributes;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_FOR_LOOP_STATEMENT_H_
diff --git a/src/tint/lang/wgsl/ast/for_loop_statement_test.cc b/src/tint/lang/wgsl/ast/for_loop_statement_test.cc
new file mode 100644
index 0000000..f6f145c
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/for_loop_statement_test.cc
@@ -0,0 +1,113 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gmock/gmock.h"
+#include "gtest/gtest-spi.h"
+#include "src/tint/lang/wgsl/ast/binary_expression.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+using namespace tint::number_suffixes; // NOLINT
+
+namespace tint::ast {
+namespace {
+
+using ForLoopStatementTest = TestHelper;
+
+TEST_F(ForLoopStatementTest, Creation) {
+    auto* init_stmt = Decl(Var("i", ty.u32()));
+    auto* cond_expr = create<BinaryExpression>(BinaryOp::kLessThan, Expr("i"), Expr(5_u));
+    auto* cont_stmt = Assign("i", Add("i", 1_u));
+    auto* loop_body = Block(Return());
+    auto* loop = For(init_stmt, cond_expr, cont_stmt, loop_body);
+    // The loop must hold exactly the nodes it was built from.
+    EXPECT_EQ(loop->initializer, init_stmt);
+    EXPECT_EQ(loop->condition, cond_expr);
+    EXPECT_EQ(loop->continuing, cont_stmt);
+    EXPECT_EQ(loop->body, loop_body);
+}
+
+TEST_F(ForLoopStatementTest, Creation_WithSource) {
+    // Only the begin location is supplied, so only the begin location is checked.
+    auto* loop_body = Block(Return());
+    auto* loop = For(Source{{20u, 2u}}, nullptr, nullptr, nullptr, loop_body);
+    EXPECT_EQ(loop->source.range.begin.line, 20u);
+    EXPECT_EQ(loop->source.range.begin.column, 2u);
+}
+
+TEST_F(ForLoopStatementTest, Creation_Null_InitCondCont) {
+    auto* loop_body = Block(Return());
+    auto* loop = For(nullptr, nullptr, nullptr, loop_body);  // body is the only child given
+    EXPECT_EQ(loop->body, loop_body);
+}
+
+TEST_F(ForLoopStatementTest, Creation_WithAttributes) {
+    auto* foo_attr = DiagnosticAttribute(builtin::DiagnosticSeverity::kOff, "foo");
+    auto* bar_attr = DiagnosticAttribute(builtin::DiagnosticSeverity::kOff, "bar");
+    auto* loop_body = Block(Return());
+    auto* loop = For(nullptr, nullptr, nullptr, loop_body, utils::Vector{foo_attr, bar_attr});
+    // Attributes must be stored in the order they were passed.
+    EXPECT_THAT(loop->attributes, testing::ElementsAre(foo_attr, bar_attr));
+}
+
+TEST_F(ForLoopStatementTest, Assert_Null_Body) {
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder builder;
+            builder.For(nullptr, nullptr, nullptr, nullptr);  // last argument is the body
+        },
+        "internal compiler error");
+}
+
+TEST_F(ForLoopStatementTest, Assert_DifferentProgramID_Initializer) {
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder owner;
+            ProgramBuilder other;
+            owner.For(other.Block(), nullptr, nullptr, owner.Block());  // foreign initializer
+        },
+        "internal compiler error");
+}
+
+TEST_F(ForLoopStatementTest, Assert_DifferentProgramID_Condition) {
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder owner;
+            ProgramBuilder other;
+            owner.For(nullptr, other.Expr(true), nullptr, owner.Block());  // foreign condition
+        },
+        "internal compiler error");
+}
+
+TEST_F(ForLoopStatementTest, Assert_DifferentProgramID_Continuing) {
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder owner;
+            ProgramBuilder other;
+            owner.For(nullptr, nullptr, other.Block(), owner.Block());  // foreign continuing
+        },
+        "internal compiler error");
+}
+
+TEST_F(ForLoopStatementTest, Assert_DifferentProgramID_Body) {
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder owner;
+            ProgramBuilder other;
+            owner.For(nullptr, nullptr, nullptr, other.Block());  // foreign body
+        },
+        "internal compiler error");
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/function.cc b/src/tint/lang/wgsl/ast/function.cc
new file mode 100644
index 0000000..d3f54e1
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/function.cc
@@ -0,0 +1,108 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/function.h"
+
+#include "src/tint/lang/wgsl/ast/stage_attribute.h"
+#include "src/tint/lang/wgsl/ast/workgroup_attribute.h"
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::Function);
+
+namespace tint::ast {
+
+Function::Function(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const Identifier* n,
+ utils::VectorRef<const Parameter*> parameters,
+ Type return_ty,
+ const BlockStatement* b,
+ utils::VectorRef<const Attribute*> attrs,
+ utils::VectorRef<const Attribute*> return_type_attrs)
+ : Base(pid, nid, src),
+ name(n),
+ params(std::move(parameters)),
+ return_type(return_ty),
+ body(b),
+ attributes(std::move(attrs)),
+ return_type_attributes(std::move(return_type_attrs)) {
+ TINT_ASSERT(AST, name);  // a function must have a name
+ if (name) {
+ TINT_ASSERT(AST, !name->Is<TemplatedIdentifier>());  // and the name must not be templated
+ }
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, name, program_id);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, return_ty, program_id);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, body, program_id);  // body itself is not null-checked
+ for (auto* param : params) {
+ TINT_ASSERT(AST, tint::Is<Parameter>(param));  // every entry must be a Parameter
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, param, program_id);
+ }
+ for (auto* attr : attributes) {
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, attr, program_id);  // function attributes
+ }
+ for (auto* attr : return_type_attributes) {
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, attr, program_id);  // return type attributes
+ }
+}
+
+Function::~Function() = default;
+
+PipelineStage Function::PipelineStage() const {
+    // The stage is read from the function's stage attribute; a function
+    // without one reports kNone.
+    const auto* stage_attr = GetAttribute<StageAttribute>(attributes);
+    return stage_attr ? stage_attr->stage : ast::PipelineStage::kNone;
+}
+
+const Function* Function::Clone(CloneContext* ctx) const {
+    // Clone each child in its own statement before create() so the clone order is deterministic
+    auto src_ = ctx->Clone(source);
+    auto name_ = ctx->Clone(name);
+    auto params_ = ctx->Clone(params);
+    auto ret_ = ctx->Clone(return_type);
+    auto* body_ = ctx->Clone(body);
+    auto attrs_ = ctx->Clone(attributes);
+    auto ret_attrs_ = ctx->Clone(return_type_attributes);
+    return ctx->dst->create<Function>(src_, name_, params_, ret_, body_, attrs_, ret_attrs_);
+}
+
+const Function* FunctionList::Find(Symbol sym) const {
+    for (const Function* candidate : *this) {
+        if (candidate->name->symbol == sym) {
+            return candidate;  // first match in list order wins
+        }
+    }
+    return nullptr;  // no function in the list has this name
+}
+
+const Function* FunctionList::Find(Symbol sym, PipelineStage stage) const {
+    for (const Function* candidate : *this) {
+        if (candidate->name->symbol == sym && candidate->PipelineStage() == stage) {
+            return candidate;  // first function matching both name and stage
+        }
+    }
+    return nullptr;  // nothing matched both criteria
+}
+
+bool FunctionList::HasStage(ast::PipelineStage stage) const {
+    for (const Function* candidate : *this) {
+        if (candidate->PipelineStage() == stage) {
+            return true;  // one function with this stage is enough
+        }
+    }
+    return false;  // no function in the list has this stage
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/function.h b/src/tint/lang/wgsl/ast/function.h
new file mode 100644
index 0000000..02c4876
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/function.h
@@ -0,0 +1,123 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_FUNCTION_H_
+#define SRC_TINT_LANG_WGSL_AST_FUNCTION_H_
+
+#include <string>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+#include "src/tint/lang/wgsl/ast/attribute.h"
+#include "src/tint/lang/wgsl/ast/binding_attribute.h"
+#include "src/tint/lang/wgsl/ast/block_statement.h"
+#include "src/tint/lang/wgsl/ast/builtin_attribute.h"
+#include "src/tint/lang/wgsl/ast/group_attribute.h"
+#include "src/tint/lang/wgsl/ast/location_attribute.h"
+#include "src/tint/lang/wgsl/ast/parameter.h"
+#include "src/tint/lang/wgsl/ast/pipeline_stage.h"
+
+// Forward declarations
+namespace tint::ast {
+class Identifier;
+class IdentifierExpression;
+} // namespace tint::ast
+
+namespace tint::ast {
+
+/// An AST node describing a function.
+class Function final : public utils::Castable<Function, Node> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param source the function source
+ /// @param name the function name; must be non-null and not a templated identifier
+ /// @param params the function parameters; every entry must be a Parameter
+ /// @param return_type the return type
+ /// @param body the function body, or nullptr if the function has no body
+ /// @param attributes the function attributes
+ /// @param return_type_attributes the return type attributes
+ Function(ProgramID pid,
+ NodeID nid,
+ const Source& source,
+ const Identifier* name,
+ utils::VectorRef<const Parameter*> params,
+ Type return_type,
+ const BlockStatement* body,
+ utils::VectorRef<const Attribute*> attributes,
+ utils::VectorRef<const Attribute*> return_type_attributes);
+
+ /// Destructor
+ ~Function() override;
+
+ /// @returns the function's pipeline stage, or PipelineStage::kNone if it has no stage attribute
+ ast::PipelineStage PipelineStage() const;
+
+ /// @returns true if this function is an entry point, i.e. its pipeline stage is not kNone
+ bool IsEntryPoint() const { return PipelineStage() != ast::PipelineStage::kNone; }
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const Function* Clone(CloneContext* ctx) const override;
+
+ /// The function name
+ const Identifier* const name;
+
+ /// The function parameters
+ const utils::Vector<const Parameter*, 8> params;
+
+ /// The function return type
+ const Type return_type;
+
+ /// The function body (may be nullptr)
+ const BlockStatement* const body;
+
+ /// The attributes attached to this function
+ const utils::Vector<const Attribute*, 2> attributes;
+
+ /// The attributes attached to the function return type.
+ const utils::Vector<const Attribute*, 2> return_type_attributes;
+};
+
+/// A list of functions, with lookup helpers by name and pipeline stage
+class FunctionList : public utils::Vector<const Function*, 8> {
+ public:
+ /// Appends f to the end of the list
+ /// @param f the function to append to this list
+ void Add(const Function* f) { this->Push(f); }
+
+ /// Returns the first function in the list with the given name
+ /// @param sym the function symbol to search for
+ /// @returns the associated function or nullptr if none exists
+ const Function* Find(Symbol sym) const;
+
+ /// Returns the first function in the list with the given name and pipeline stage
+ /// @param sym the function symbol to search for
+ /// @param stage the pipeline stage the function must have
+ /// @returns the associated function or nullptr if none exists
+ const Function* Find(Symbol sym, ast::PipelineStage stage) const;
+
+ /// @param stage the pipeline stage
+ /// @returns true if this list contains at least one function whose
+ /// pipeline stage is the given stage
+ bool HasStage(ast::PipelineStage stage) const;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_FUNCTION_H_
diff --git a/src/tint/lang/wgsl/ast/function_test.cc b/src/tint/lang/wgsl/ast/function_test.cc
new file mode 100644
index 0000000..c48e87a
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/function_test.cc
@@ -0,0 +1,242 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gtest/gtest-spi.h"
+#include "src/tint/lang/wgsl/ast/discard_statement.h"
+#include "src/tint/lang/wgsl/ast/stage_attribute.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+#include "src/tint/lang/wgsl/ast/workgroup_attribute.h"
+
+using namespace tint::number_suffixes; // NOLINT
+
+namespace tint::ast {
+namespace {
+
+using FunctionTest = TestHelper;
+
+TEST_F(FunctionTest, Creation_i32ReturnType) {
+    utils::Vector param_list{Param("var", ty.i32())};
+    auto ret_ty = ty.i32();
+    auto* expected_param = param_list[0];
+
+    auto* fn = Func("func", param_list, ret_ty, utils::Empty);
+    EXPECT_EQ(fn->name->symbol, Symbols().Get("func"));
+    ASSERT_EQ(fn->params.Length(), 1u);
+    CheckIdentifier(fn->return_type, "i32");
+    EXPECT_EQ(fn->params[0], expected_param);
+}
+
+TEST_F(FunctionTest, Creation_NoReturnType) {
+    utils::Vector param_list{Param("var", ty.i32())};
+    auto* expected_param = param_list[0];
+
+    auto* fn = Func("func", param_list, ty.void_(), utils::Empty);
+    EXPECT_EQ(fn->name->symbol, Symbols().Get("func"));
+    ASSERT_EQ(fn->params.Length(), 1u);
+    EXPECT_EQ(fn->return_type, nullptr);
+    EXPECT_EQ(fn->params[0], expected_param);
+}
+
+TEST_F(FunctionTest, Creation_SingleParam) {
+    utils::Vector param_list{Param("var", ty.i32())};
+    auto* expected_param = param_list[0];
+
+    auto* fn = Func("func", param_list, ty.void_(), utils::Empty);
+    EXPECT_EQ(fn->name->symbol, Symbols().Get("func"));
+    ASSERT_EQ(fn->params.Length(), 1u);
+    EXPECT_EQ(fn->return_type, nullptr);
+    EXPECT_EQ(fn->params[0], expected_param);
+    ASSERT_NE(fn->body, nullptr);
+    ASSERT_EQ(fn->body->statements.Length(), 0u);
+}
+
+TEST_F(FunctionTest, Creation_Body_Vector) {
+    auto* fn = Func("func", utils::Empty, ty.void_(), utils::Vector{Discard(), Return()});
+    ASSERT_NE(fn->body, nullptr);
+    ASSERT_EQ(fn->body->statements.Length(), 2u);
+    EXPECT_TRUE(fn->body->statements[0]->Is<DiscardStatement>());
+    EXPECT_TRUE(fn->body->statements[1]->Is<ReturnStatement>());
+}
+
+TEST_F(FunctionTest, Creation_Body_Block) {
+    auto* fn = Func("func", utils::Empty, ty.void_(), Block(Discard(), Return()));
+    ASSERT_NE(fn->body, nullptr);
+    ASSERT_EQ(fn->body->statements.Length(), 2u);
+    EXPECT_TRUE(fn->body->statements[0]->Is<DiscardStatement>());
+    EXPECT_TRUE(fn->body->statements[1]->Is<ReturnStatement>());
+}
+
+TEST_F(FunctionTest, Creation_Body_Stmt) {
+    auto* fn = Func("func", utils::Empty, ty.void_(), Return());
+    ASSERT_NE(fn->body, nullptr);
+    ASSERT_EQ(fn->body->statements.Length(), 1u);
+    EXPECT_TRUE(fn->body->statements[0]->Is<ReturnStatement>());
+}
+
+TEST_F(FunctionTest, Creation_Body_Nullptr) {
+    auto* fn = Func("func", utils::Empty, ty.void_(), nullptr);
+    EXPECT_EQ(fn->body, nullptr);
+}
+
+TEST_F(FunctionTest, Creation_WithSource) {
+    utils::Vector param_list{Param("var", ty.i32())};
+
+    auto* fn = Func(Source{Source::Location{20, 2}}, "func", param_list, ty.void_(), utils::Empty);
+    auto fn_source = fn->source;
+    EXPECT_EQ(fn_source.range.begin.line, 20u);
+    EXPECT_EQ(fn_source.range.begin.column, 2u);
+}
+
+TEST_F(FunctionTest, Assert_NullName) {
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder pb;
+            pb.Func(static_cast<Identifier*>(nullptr), utils::Empty, pb.ty.void_(), utils::Empty);
+        },
+        "internal compiler error");
+}
+
+TEST_F(FunctionTest, Assert_TemplatedName) {
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder pb;
+            pb.Func(pb.Ident("a", "b"), utils::Empty, pb.ty.void_(), utils::Empty);
+        },
+        "internal compiler error");
+}
+
+TEST_F(FunctionTest, Assert_NullParam) {
+    using ParamList = utils::Vector<const Parameter*, 2>;
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder pb;
+            ParamList with_null;
+            with_null.Push(pb.Param("var", pb.ty.i32()));
+            with_null.Push(nullptr);
+            pb.Func("f", with_null, pb.ty.void_(), utils::Empty);
+        },
+        "internal compiler error");
+}
+
+TEST_F(FunctionTest, Assert_DifferentProgramID_Symbol) {
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder owner;
+            ProgramBuilder other;
+            owner.Func(other.Sym("func"), utils::Empty, owner.ty.void_(), utils::Empty);
+        },
+        "internal compiler error");
+}
+
+TEST_F(FunctionTest, Assert_DifferentProgramID_Param) {
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder owner;
+            ProgramBuilder other;
+            owner.Func("func",
+                       utils::Vector{
+                           other.Param("var", other.ty.i32()),
+                       },
+                       owner.ty.void_(), utils::Empty);
+        },
+        "internal compiler error");
+}
+
+TEST_F(FunctionTest, Assert_DifferentProgramID_Attr) {
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder owner;
+            ProgramBuilder other;
+            owner.Func("func", utils::Empty, owner.ty.void_(), utils::Empty,
+                       utils::Vector{
+                           other.WorkgroupSize(2_i, 4_i, 6_i),
+                       });
+        },
+        "internal compiler error");
+}
+
+TEST_F(FunctionTest, Assert_DifferentProgramID_ReturnType) {
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder owner;
+            ProgramBuilder other;
+            owner.Func("func", utils::Empty, other.ty.i32(), utils::Empty);
+        },
+        "internal compiler error");
+}
+
+TEST_F(FunctionTest, Assert_DifferentProgramID_ReturnAttr) {
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder owner;
+            ProgramBuilder other;
+            owner.Func("func", utils::Empty, owner.ty.void_(), utils::Empty, utils::Empty,
+                       utils::Vector{
+                           other.WorkgroupSize(2_i, 4_i, 6_i),
+                       });
+        },
+        "internal compiler error");
+}
+
+using FunctionListTest = TestHelper;
+
+TEST_F(FunctionListTest, FindSymbol) {
+    auto* main_fn = Func("main", utils::Empty, ty.f32(), utils::Empty);
+    FunctionList functions;
+    functions.Add(main_fn);
+    EXPECT_EQ(main_fn, functions.Find(Symbols().Register("main")));
+}
+
+TEST_F(FunctionListTest, FindSymbolMissing) {
+    FunctionList functions;
+    EXPECT_EQ(nullptr, functions.Find(Symbols().Register("Missing")));
+}
+
+TEST_F(FunctionListTest, FindSymbolStage) {
+    auto* frag_fn = Func("main", utils::Empty, ty.f32(), utils::Empty,
+                         utils::Vector{
+                             Stage(PipelineStage::kFragment),
+                         });
+    auto* vert_fn = Func("main", utils::Empty, ty.f32(), utils::Empty,
+                         utils::Vector{
+                             Stage(PipelineStage::kVertex),
+                         });
+    FunctionList functions;
+    functions.Add(frag_fn);
+    functions.Add(vert_fn);
+    EXPECT_EQ(frag_fn, functions.Find(Symbols().Register("main"), PipelineStage::kFragment));
+    EXPECT_EQ(vert_fn, functions.Find(Symbols().Register("main"), PipelineStage::kVertex));
+}
+
+TEST_F(FunctionListTest, FindSymbolStageMissing) {
+    FunctionList functions;
+    functions.Add(Func("main", utils::Empty, ty.f32(), utils::Empty,
+                       utils::Vector{
+                           Stage(PipelineStage::kFragment),
+                       }));
+    EXPECT_EQ(nullptr, functions.Find(Symbols().Register("main"), PipelineStage::kVertex));
+}
+
+TEST_F(FunctionListTest, HasStage) {
+    FunctionList functions;
+    functions.Add(Func("main", utils::Empty, ty.f32(), utils::Empty,
+                       utils::Vector{
+                           Stage(PipelineStage::kFragment),
+                       }));
+    EXPECT_TRUE(functions.HasStage(PipelineStage::kFragment));
+    EXPECT_FALSE(functions.HasStage(PipelineStage::kVertex));
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/group_attribute.cc b/src/tint/lang/wgsl/ast/group_attribute.cc
new file mode 100644
index 0000000..9f3afc3
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/group_attribute.cc
@@ -0,0 +1,41 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/group_attribute.h"
+
+#include <string>
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::GroupAttribute);
+
+namespace tint::ast {
+
+GroupAttribute::GroupAttribute(ProgramID pid, NodeID nid, const Source& src, const Expression* exp)
+    : Base(pid, nid, src), expr(exp) {}  // `exp` is stored as given; nothing is asserted here
+
+GroupAttribute::~GroupAttribute() = default;
+
+std::string GroupAttribute::Name() const {
+    return "group";  // the WGSL name of this attribute
+}
+
+const GroupAttribute* GroupAttribute::Clone(CloneContext* ctx) const {
+    // Clone each child in its own statement before create() so the clone order is deterministic
+    auto cloned_source = ctx->Clone(source);
+    auto* cloned_expr = ctx->Clone(expr);
+    return ctx->dst->create<GroupAttribute>(cloned_source, cloned_expr);
+}
+
+}  // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/group_attribute.h b/src/tint/lang/wgsl/ast/group_attribute.h
new file mode 100644
index 0000000..ea45ea7
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/group_attribute.h
@@ -0,0 +1,51 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_GROUP_ATTRIBUTE_H_
+#define SRC_TINT_LANG_WGSL_AST_GROUP_ATTRIBUTE_H_
+
+#include <string>
+
+#include "src/tint/lang/wgsl/ast/attribute.h"
+#include "src/tint/lang/wgsl/ast/expression.h"
+
+namespace tint::ast {
+
+/// A `group` attribute, holding the group expression
+class GroupAttribute final : public utils::Castable<GroupAttribute, Attribute> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ /// @param expr the group expression
+ GroupAttribute(ProgramID pid, NodeID nid, const Source& src, const Expression* expr);
+ ~GroupAttribute() override;
+
+ /// @returns the WGSL name for the attribute, "group"
+ std::string Name() const override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const GroupAttribute* Clone(CloneContext* ctx) const override;
+
+ /// The group expression, as passed to the constructor
+ const Expression* const expr;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_GROUP_ATTRIBUTE_H_
diff --git a/src/tint/lang/wgsl/ast/group_attribute_test.cc b/src/tint/lang/wgsl/ast/group_attribute_test.cc
new file mode 100644
index 0000000..ad5078a
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/group_attribute_test.cc
@@ -0,0 +1,29 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+namespace tint::ast {
+namespace {
+
+using namespace tint::number_suffixes; // NOLINT
+using GroupAttributeTest = TestHelper;
+
+TEST_F(GroupAttributeTest, Creation) {
+    auto* group_attr = Group(2_a);
+    EXPECT_TRUE(group_attr->expr->Is<IntLiteralExpression>());
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/id_attribute.cc b/src/tint/lang/wgsl/ast/id_attribute.cc
new file mode 100644
index 0000000..687cbe2
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/id_attribute.cc
@@ -0,0 +1,41 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/id_attribute.h"
+
+#include <string>
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::IdAttribute);
+
+namespace tint::ast {
+
+IdAttribute::IdAttribute(ProgramID pid, NodeID nid, const Source& src, const Expression* exp)
+    : Base(pid, nid, src), expr(exp) {}  // `exp` is stored as given; nothing is asserted here
+
+IdAttribute::~IdAttribute() = default;
+
+std::string IdAttribute::Name() const {
+    return "id";  // the WGSL name of this attribute
+}
+
+const IdAttribute* IdAttribute::Clone(CloneContext* ctx) const {
+    // Clone each child in its own statement before create() so the clone order is deterministic
+    auto cloned_source = ctx->Clone(source);
+    auto* cloned_expr = ctx->Clone(expr);
+    return ctx->dst->create<IdAttribute>(cloned_source, cloned_expr);
+}
+
+}  // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/id_attribute.h b/src/tint/lang/wgsl/ast/id_attribute.h
new file mode 100644
index 0000000..119494a
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/id_attribute.h
@@ -0,0 +1,51 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_ID_ATTRIBUTE_H_
+#define SRC_TINT_LANG_WGSL_AST_ID_ATTRIBUTE_H_
+
+#include <string>
+
+#include "src/tint/lang/wgsl/ast/attribute.h"
+#include "src/tint/lang/wgsl/ast/expression.h"
+
+namespace tint::ast {
+
+/// An `id` attribute for pipeline-overridable constants
+class IdAttribute final : public utils::Castable<IdAttribute, Attribute> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ /// @param expr the numeric id expression
+ IdAttribute(ProgramID pid, NodeID nid, const Source& src, const Expression* expr);
+ ~IdAttribute() override;
+
+ /// @returns the WGSL name for the attribute, "id"
+ std::string Name() const override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const IdAttribute* Clone(CloneContext* ctx) const override;
+
+ /// The id expression, as passed to the constructor
+ const Expression* const expr;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_ID_ATTRIBUTE_H_
diff --git a/src/tint/lang/wgsl/ast/id_attribute_test.cc b/src/tint/lang/wgsl/ast/id_attribute_test.cc
new file mode 100644
index 0000000..ee4bb65
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/id_attribute_test.cc
@@ -0,0 +1,31 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/id_attribute.h"
+
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+namespace tint::ast {
+namespace {
+
+using namespace tint::number_suffixes; // NOLINT
+using IdAttributeTest = TestHelper;
+
+TEST_F(IdAttributeTest, Creation) {
+    auto* id_attr = Id(12_a);
+    EXPECT_TRUE(id_attr->expr->Is<IntLiteralExpression>());
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/identifier.cc b/src/tint/lang/wgsl/ast/identifier.cc
new file mode 100644
index 0000000..23ccf35
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/identifier.cc
@@ -0,0 +1,38 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/identifier.h"
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::Identifier);
+
+namespace tint::ast {
+
+Identifier::Identifier(ProgramID pid, NodeID nid, const Source& src, Symbol sym)
+    : Base(pid, nid, src), symbol(sym) {
+    TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, symbol, program_id);  // must match this node
+    TINT_ASSERT(AST, symbol.IsValid());  // an identifier needs a valid symbol
+}
+
+Identifier::~Identifier() = default;
+
+const Identifier* Identifier::Clone(CloneContext* ctx) const {
+    // Clone each child in its own statement before create() so the clone order is deterministic
+    auto cloned_source = ctx->Clone(source);
+    auto cloned_symbol = ctx->Clone(symbol);
+    return ctx->dst->create<Identifier>(cloned_source, cloned_symbol);
+}
+
+}  // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/identifier.h b/src/tint/lang/wgsl/ast/identifier.h
new file mode 100644
index 0000000..7a4a6eb
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/identifier.h
@@ -0,0 +1,47 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_IDENTIFIER_H_
+#define SRC_TINT_LANG_WGSL_AST_IDENTIFIER_H_
+
+#include "src/tint/lang/wgsl/ast/node.h"
+
+namespace tint::ast {
+
+/// An identifier: an AST node that holds a Symbol
+class Identifier : public utils::Castable<Identifier, Node> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ /// @param sym the symbol for the identifier; must be valid
+ Identifier(ProgramID pid, NodeID nid, const Source& src, Symbol sym);
+
+ /// Destructor
+ ~Identifier() override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const Identifier* Clone(CloneContext* ctx) const override;
+
+ /// The symbol for the identifier
+ const Symbol symbol;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_IDENTIFIER_H_
diff --git a/src/tint/lang/wgsl/ast/identifier_expression.cc b/src/tint/lang/wgsl/ast/identifier_expression.cc
new file mode 100644
index 0000000..421c5e0
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/identifier_expression.cc
@@ -0,0 +1,41 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/identifier_expression.h"
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::IdentifierExpression);
+
+namespace tint::ast {
+
+IdentifierExpression::IdentifierExpression(ProgramID pid,
+                                           NodeID nid,
+                                           const Source& src,
+                                           const Identifier* ident)
+    : Base(pid, nid, src), identifier(ident) {
+    TINT_ASSERT(AST, identifier != nullptr);  // the wrapped identifier is mandatory
+    TINT_ASSERT_PROGRAM_IDS_EQUAL(AST, identifier, program_id);  // must match this node
+}
+
+IdentifierExpression::~IdentifierExpression() = default;
+
+const IdentifierExpression* IdentifierExpression::Clone(CloneContext* ctx) const {
+    // Clone each child in its own statement before create() so the clone order is deterministic
+    auto cloned_source = ctx->Clone(source);
+    auto cloned_identifier = ctx->Clone(identifier);
+    return ctx->dst->create<IdentifierExpression>(cloned_source, cloned_identifier);
+}
+
+}  // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/identifier_expression.h b/src/tint/lang/wgsl/ast/identifier_expression.h
new file mode 100644
index 0000000..a62e4b3
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/identifier_expression.h
@@ -0,0 +1,55 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_IDENTIFIER_EXPRESSION_H_
+#define SRC_TINT_LANG_WGSL_AST_IDENTIFIER_EXPRESSION_H_
+
+#include "src/tint/lang/wgsl/ast/expression.h"
+
+// Forward declarations
+namespace tint::ast {
+class Identifier;
+}
+
+namespace tint::ast {
+
+/// An expression that consists of a single identifier
+class IdentifierExpression final : public utils::Castable<IdentifierExpression, Expression> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ /// @param identifier the identifier; must not be null
+ IdentifierExpression(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const Identifier* identifier);
+
+ /// Destructor
+ ~IdentifierExpression() override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const IdentifierExpression* Clone(CloneContext* ctx) const override;
+
+ /// The identifier for the expression (asserted non-null by the constructor)
+ Identifier const* const identifier;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_IDENTIFIER_EXPRESSION_H_
diff --git a/src/tint/lang/wgsl/ast/identifier_expression_test.cc b/src/tint/lang/wgsl/ast/identifier_expression_test.cc
new file mode 100644
index 0000000..5fc21df
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/identifier_expression_test.cc
@@ -0,0 +1,67 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gtest/gtest-spi.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+namespace tint::ast {
+namespace {
+
+using namespace tint::number_suffixes; // NOLINT
+
+using IdentifierExpressionTest = TestHelper;
+
+TEST_F(IdentifierExpressionTest, Creation) {
+    auto* expr = Expr("ident");
+    EXPECT_EQ(expr->identifier->symbol, Symbol(1, ID(), "ident"));
+}
+
+TEST_F(IdentifierExpressionTest, CreationTemplated) {
+    auto* expr = Expr(Ident("ident", true));
+    EXPECT_EQ(expr->identifier->symbol, Symbol(1, ID(), "ident"));
+    auto* templated = expr->identifier->As<TemplatedIdentifier>();
+    ASSERT_NE(templated, nullptr);  // the checks below dereference the templated identifier
+    EXPECT_EQ(templated->arguments.Length(), 1_u);
+    EXPECT_TRUE(templated->arguments[0]->Is<BoolLiteralExpression>());
+}
+
+TEST_F(IdentifierExpressionTest, Creation_WithSource) {
+    auto* expr = Expr(Source{{20, 2}}, "ident");
+    EXPECT_EQ(expr->identifier->symbol, Symbol(1, ID(), "ident"));
+    // Both the expression and the identifier it holds must report the range passed to Expr().
+    EXPECT_EQ(expr->source.range, (Source::Range{{20, 2}}));
+    EXPECT_EQ(expr->identifier->source.range, (Source::Range{{20, 2}}));
+}
+
+TEST_F(IdentifierExpressionTest, Assert_InvalidSymbol) {
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder builder;
+            builder.Expr("");  // an empty name is expected to trip an internal assertion
+        },
+        "internal compiler error");
+}
+
+TEST_F(IdentifierExpressionTest, Assert_DifferentProgramID_Symbol) {
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder owner;
+            ProgramBuilder other;
+            owner.Expr(other.Sym("b2"));  // symbol registered with a different builder
+        },
+        "internal compiler error");
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/identifier_test.cc b/src/tint/lang/wgsl/ast/identifier_test.cc
new file mode 100644
index 0000000..45e6760
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/identifier_test.cc
@@ -0,0 +1,62 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gtest/gtest-spi.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+namespace tint::ast {
+namespace {
+
+using IdentifierTest = TestHelper;
+
+TEST_F(IdentifierTest, Creation) {
+    auto* ident = Ident("ident");
+    EXPECT_EQ(ident->symbol, Symbol(1, ID(), "ident"));
+}
+
+TEST_F(IdentifierTest, Creation_WithSource) {
+    auto* ident = Ident(Source{{20, 2}}, "ident");
+    EXPECT_EQ(ident->symbol, Symbol(1, ID(), "ident"));
+
+    const auto& begin = ident->source.range.begin;
+    EXPECT_EQ(begin.line, 20u);
+    EXPECT_EQ(begin.column, 2u);
+}
+
+TEST_F(IdentifierTest, IsIdentifier) {
+    auto* node = Ident("ident");
+    EXPECT_TRUE(node->Is<Identifier>());
+}
+
+TEST_F(IdentifierTest, Assert_InvalidSymbol) {
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder builder;
+            builder.Ident("");  // an empty name is expected to trip an internal assertion
+        },
+        "internal compiler error");
+}
+
+TEST_F(IdentifierTest, Assert_DifferentProgramID_Symbol) {
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder owner;
+            ProgramBuilder other;
+            owner.Ident(other.Sym("b2"));  // symbol registered with a different builder
+        },
+        "internal compiler error");
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/if_statement.cc b/src/tint/lang/wgsl/ast/if_statement.cc
new file mode 100644
index 0000000..77319ef
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/if_statement.cc
@@ -0,0 +1,61 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/if_statement.h"
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::IfStatement);
+
+namespace tint::ast {
+
+IfStatement::IfStatement(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const Expression* cond,
+ const BlockStatement* b,
+ const Statement* else_stmt,
+ utils::VectorRef<const Attribute*> attrs)
+ : Base(pid, nid, src),
+ condition(cond),
+ body(b),
+ else_statement(else_stmt),
+ attributes(std::move(attrs)) {
+ TINT_ASSERT(AST, condition);  // the condition is mandatory
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, condition, program_id);
+ TINT_ASSERT(AST, body);  // the body is mandatory
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, body, program_id);
+ if (else_statement) {  // optional; must be an IfStatement or a BlockStatement when set
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, else_statement, program_id);
+ TINT_ASSERT(AST, (else_statement->IsAnyOf<IfStatement, BlockStatement>()));
+ }
+ for (auto* attr : attributes) {  // every attribute must be non-null
+ TINT_ASSERT(AST, attr);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, attr, program_id);
+ }
+}
+
+IfStatement::~IfStatement() = default;
+
+const IfStatement* IfStatement::Clone(CloneContext* ctx) const {
+    // Every child is cloned into a local before create() so the clone order is deterministic
+    auto origin = ctx->Clone(source);
+    auto* test = ctx->Clone(condition);
+    auto* then_body = ctx->Clone(body);
+    auto* else_body = ctx->Clone(else_statement);
+    auto attr_list = ctx->Clone(attributes);
+    return ctx->dst->create<IfStatement>(origin, test, then_body, else_body, std::move(attr_list));
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/if_statement.h b/src/tint/lang/wgsl/ast/if_statement.h
new file mode 100644
index 0000000..70bd758
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/if_statement.h
@@ -0,0 +1,68 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_IF_STATEMENT_H_
+#define SRC_TINT_LANG_WGSL_AST_IF_STATEMENT_H_
+
+#include <utility>
+
+#include "src/tint/lang/wgsl/ast/block_statement.h"
+#include "src/tint/lang/wgsl/ast/expression.h"
+
+namespace tint::ast {
+
+/// An if statement
+class IfStatement final : public utils::Castable<IfStatement, Statement> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ /// @param condition the if condition, must not be nullptr
+ /// @param body the if body, must not be nullptr
+ /// @param else_stmt the else statement (an IfStatement or BlockStatement), or nullptr
+ /// @param attributes the if statement attributes
+ IfStatement(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const Expression* condition,
+ const BlockStatement* body,
+ const Statement* else_stmt,
+ utils::VectorRef<const Attribute*> attributes);
+
+ /// Destructor
+ ~IfStatement() override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const IfStatement* Clone(CloneContext* ctx) const override;
+
+ /// The if condition (the constructor asserts that it is not nullptr)
+ const Expression* const condition;
+
+ /// The if body (the constructor asserts that it is not nullptr)
+ const BlockStatement* const body;
+
+ /// The optional else statement, or nullptr
+ const Statement* const else_statement;
+
+ /// The attribute list
+ const utils::Vector<const Attribute*, 1> attributes;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_IF_STATEMENT_H_
diff --git a/src/tint/lang/wgsl/ast/if_statement_test.cc b/src/tint/lang/wgsl/ast/if_statement_test.cc
new file mode 100644
index 0000000..e231f1c
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/if_statement_test.cc
@@ -0,0 +1,107 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/if_statement.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest-spi.h"
+#include "src/tint/lang/wgsl/ast/discard_statement.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+namespace tint::ast {
+namespace {
+
+using IfStatementTest = TestHelper;
+
+TEST_F(IfStatementTest, Creation) {
+    auto* guard = Expr("cond");
+    auto* if_stmt = If(Source{Source::Location{20, 2}}, guard, Block(create<DiscardStatement>()));
+    const auto& begin = if_stmt->source.range.begin;
+    EXPECT_EQ(begin.line, 20u);
+    EXPECT_EQ(begin.column, 2u);
+}
+
+TEST_F(IfStatementTest, Creation_WithAttributes) {
+    auto* foo_off = DiagnosticAttribute(builtin::DiagnosticSeverity::kOff, "foo");
+    auto* bar_off = DiagnosticAttribute(builtin::DiagnosticSeverity::kOff, "bar");
+    auto* guard = Expr("cond");
+    auto* if_stmt = If(guard, Block(), ElseStmt(), utils::Vector{foo_off, bar_off});
+
+    EXPECT_THAT(if_stmt->attributes, testing::ElementsAre(foo_off, bar_off));
+}
+
+TEST_F(IfStatementTest, IsIf) {
+    auto* if_stmt = If(Expr(true), Block());
+    EXPECT_TRUE(if_stmt->Is<IfStatement>());
+}
+
+TEST_F(IfStatementTest, Assert_Null_Condition) {
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder pb;
+            pb.If(nullptr, pb.Block());  // missing condition
+        },
+        "internal compiler error");
+}
+
+TEST_F(IfStatementTest, Assert_Null_Body) {
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder pb;
+            pb.If(pb.Expr(true), nullptr);  // missing body
+        },
+        "internal compiler error");
+}
+
+TEST_F(IfStatementTest, Assert_InvalidElse) {
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder pb;
+            pb.If(pb.Expr(true), pb.Block(), pb.Else(pb.CallStmt(pb.Call("foo"))));
+        },
+        "internal compiler error");
+}
+
+TEST_F(IfStatementTest, Assert_DifferentProgramID_Cond) {
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder own;
+            ProgramBuilder ext;
+            own.If(ext.Expr(true), own.Block());  // condition built by the other builder
+        },
+        "internal compiler error");
+}
+
+TEST_F(IfStatementTest, Assert_DifferentProgramID_Body) {
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder own;
+            ProgramBuilder ext;
+            own.If(own.Expr(true), ext.Block());  // body built by the other builder
+        },
+        "internal compiler error");
+}
+
+TEST_F(IfStatementTest, Assert_DifferentProgramID_ElseStatement) {
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder own;
+            ProgramBuilder ext;
+            own.If(own.Expr(true), own.Block(), ext.Else(ext.If(ext.Expr("ident"), ext.Block())));
+        },
+        "internal compiler error");
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/increment_decrement_statement.cc b/src/tint/lang/wgsl/ast/increment_decrement_statement.cc
new file mode 100644
index 0000000..834da15
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/increment_decrement_statement.cc
@@ -0,0 +1,41 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/increment_decrement_statement.h"
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::IncrementDecrementStatement);
+
+namespace tint::ast {
+
+IncrementDecrementStatement::IncrementDecrementStatement(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const Expression* l,
+ bool inc)
+ : Base(pid, nid, src), lhs(l), increment(inc) {  // NOTE(review): lhs is not null-checked
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, lhs, program_id);
+}
+
+IncrementDecrementStatement::~IncrementDecrementStatement() = default;
+
+const IncrementDecrementStatement* IncrementDecrementStatement::Clone(CloneContext* ctx) const {
+    // Every child is cloned into a local before create() so the clone order is deterministic
+    auto cloned_source = ctx->Clone(source);
+    auto* cloned_lhs = ctx->Clone(lhs);
+    return ctx->dst->create<IncrementDecrementStatement>(cloned_source, cloned_lhs, increment);
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/increment_decrement_statement.h b/src/tint/lang/wgsl/ast/increment_decrement_statement.h
new file mode 100644
index 0000000..d920528
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/increment_decrement_statement.h
@@ -0,0 +1,57 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_INCREMENT_DECREMENT_STATEMENT_H_
+#define SRC_TINT_LANG_WGSL_AST_INCREMENT_DECREMENT_STATEMENT_H_
+
+#include "src/tint/lang/wgsl/ast/expression.h"
+#include "src/tint/lang/wgsl/ast/statement.h"
+
+namespace tint::ast {
+
+/// An increment or decrement statement applied to a single LHS expression
+class IncrementDecrementStatement final
+ : public utils::Castable<IncrementDecrementStatement, Statement> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ /// @param lhs the LHS expression, the operand of the increment or decrement
+ /// @param inc `true` for increment, `false` for decrement
+ IncrementDecrementStatement(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const Expression* lhs,
+ bool inc);
+
+ /// Destructor
+ ~IncrementDecrementStatement() override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const IncrementDecrementStatement* Clone(CloneContext* ctx) const override;
+
+ /// The LHS expression.
+ const Expression* const lhs;
+
+ /// `true` for increment, `false` for decrement. NOTE(review): the only non-const member.
+ bool increment;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_INCREMENT_DECREMENT_STATEMENT_H_
diff --git a/src/tint/lang/wgsl/ast/increment_decrement_statement_test.cc b/src/tint/lang/wgsl/ast/increment_decrement_statement_test.cc
new file mode 100644
index 0000000..fa6f61d
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/increment_decrement_statement_test.cc
@@ -0,0 +1,67 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/increment_decrement_statement.h"
+
+#include "gtest/gtest-spi.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+namespace tint::ast {
+namespace {
+
+using IncrementDecrementStatementTest = TestHelper;
+
+TEST_F(IncrementDecrementStatementTest, Creation) {
+    auto* lhs = Expr("expr");
+
+    auto* stmt = create<IncrementDecrementStatement>(lhs, true);
+    EXPECT_EQ(stmt->lhs, lhs);
+    EXPECT_TRUE(stmt->increment);
+}
+
+TEST_F(IncrementDecrementStatementTest, Creation_WithSource) {
+    auto* lhs = Expr("expr");
+    auto* stmt = create<IncrementDecrementStatement>(Source{Source::Location{20, 2}}, lhs, true);
+    const auto& begin = stmt->source.range.begin;
+    EXPECT_EQ(stmt->lhs, lhs);
+    EXPECT_TRUE(stmt->increment);
+    EXPECT_EQ(begin.line, 20u);
+    EXPECT_EQ(begin.column, 2u);
+}
+
+TEST_F(IncrementDecrementStatementTest, IsIncrementDecrement) {
+    auto* lhs = Expr("expr");
+    auto* stmt = create<IncrementDecrementStatement>(lhs, true);
+    EXPECT_TRUE(stmt->Is<IncrementDecrementStatement>());
+}
+
+TEST_F(IncrementDecrementStatementTest, Decrement) {
+    auto* lhs = Expr("expr");
+    auto* stmt = create<IncrementDecrementStatement>(lhs, false);
+    EXPECT_EQ(stmt->lhs, lhs);
+    EXPECT_FALSE(stmt->increment);
+}
+
+TEST_F(IncrementDecrementStatementTest, Assert_DifferentProgramID_Expr) {
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder owner;
+            ProgramBuilder other;
+            owner.create<IncrementDecrementStatement>(other.Expr(true), true);
+        },
+        "internal compiler error");
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/index_accessor_expression.cc b/src/tint/lang/wgsl/ast/index_accessor_expression.cc
new file mode 100644
index 0000000..48503d47
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/index_accessor_expression.cc
@@ -0,0 +1,43 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/index_accessor_expression.h"
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::IndexAccessorExpression);
+
+namespace tint::ast {
+
+IndexAccessorExpression::IndexAccessorExpression(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const Expression* obj,
+ const Expression* idx)
+ : Base(pid, nid, src, obj), index(idx) {  // obj is forwarded to the base-class constructor
+ TINT_ASSERT(AST, idx);  // the index is mandatory
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, idx, program_id);
+}
+
+IndexAccessorExpression::~IndexAccessorExpression() = default;
+
+const IndexAccessorExpression* IndexAccessorExpression::Clone(CloneContext* ctx) const {
+    // Every child is cloned into a local before create() so the clone order is deterministic
+    auto cloned_source = ctx->Clone(source);
+    auto* cloned_object = ctx->Clone(object);
+    auto* cloned_index = ctx->Clone(index);
+    return ctx->dst->create<IndexAccessorExpression>(cloned_source, cloned_object, cloned_index);
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/index_accessor_expression.h b/src/tint/lang/wgsl/ast/index_accessor_expression.h
new file mode 100644
index 0000000..20ff4bd
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/index_accessor_expression.h
@@ -0,0 +1,53 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_INDEX_ACCESSOR_EXPRESSION_H_
+#define SRC_TINT_LANG_WGSL_AST_INDEX_ACCESSOR_EXPRESSION_H_
+
+#include "src/tint/lang/wgsl/ast/accessor_expression.h"
+
+namespace tint::ast {
+
+/// An accessor expression that pairs an object expression with an index expression
+class IndexAccessorExpression final
+ : public utils::Castable<IndexAccessorExpression, AccessorExpression> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param source the index accessor source
+ /// @param obj the object expression being indexed
+ /// @param idx the index expression, must not be nullptr
+ IndexAccessorExpression(ProgramID pid,
+ NodeID nid,
+ const Source& source,
+ const Expression* obj,
+ const Expression* idx);
+
+ /// Destructor
+ ~IndexAccessorExpression() override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const IndexAccessorExpression* Clone(CloneContext* ctx) const override;
+
+ /// The index expression (the constructor asserts that it is not nullptr)
+ const Expression* const index;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_INDEX_ACCESSOR_EXPRESSION_H_
diff --git a/src/tint/lang/wgsl/ast/index_accessor_expression_test.cc b/src/tint/lang/wgsl/ast/index_accessor_expression_test.cc
new file mode 100644
index 0000000..8177109
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/index_accessor_expression_test.cc
@@ -0,0 +1,89 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gtest/gtest-spi.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+namespace tint::ast {
+namespace {
+
+using IndexAccessorExpressionTest = TestHelper;
+
+TEST_F(IndexAccessorExpressionTest, Create) {
+    auto* object_expr = Expr("obj");
+    auto* index_expr = Expr("idx");
+
+    auto* accessor = IndexAccessor(object_expr, index_expr);
+    ASSERT_EQ(accessor->object, object_expr);
+    ASSERT_EQ(accessor->index, index_expr);
+}
+
+TEST_F(IndexAccessorExpressionTest, CreateWithSource) {
+    auto* object_expr = Expr("obj");
+    auto* index_expr = Expr("idx");
+
+    auto* accessor = IndexAccessor(Source{{20, 2}}, object_expr, index_expr);
+    const auto& begin = accessor->source.range.begin;
+    EXPECT_EQ(begin.line, 20u);
+    EXPECT_EQ(begin.column, 2u);
+}
+
+TEST_F(IndexAccessorExpressionTest, IsIndexAccessor) {
+    auto* object_expr = Expr("obj");
+    auto* index_expr = Expr("idx");
+
+    auto* accessor = IndexAccessor(object_expr, index_expr);
+    EXPECT_TRUE(accessor->Is<IndexAccessorExpression>());
+}
+
+TEST_F(IndexAccessorExpressionTest, Assert_Null_Array) {
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder builder;
+            builder.IndexAccessor(nullptr, builder.Expr("idx"));  // missing object
+        },
+        "internal compiler error");
+}
+
+TEST_F(IndexAccessorExpressionTest, Assert_Null_Index) {
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder builder;
+            builder.IndexAccessor(builder.Expr("arr"), nullptr);  // missing index
+        },
+        "internal compiler error");
+}
+
+TEST_F(IndexAccessorExpressionTest, Assert_DifferentProgramID_Array) {
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder owner;
+            ProgramBuilder other;
+            owner.IndexAccessor(other.Expr("arr"), owner.Expr("idx"));
+        },
+        "internal compiler error");
+}
+
+TEST_F(IndexAccessorExpressionTest, Assert_DifferentProgramID_Index) {
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder owner;
+            ProgramBuilder other;
+            owner.IndexAccessor(owner.Expr("arr"), other.Expr("idx"));
+        },
+        "internal compiler error");
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/index_attribute.cc b/src/tint/lang/wgsl/ast/index_attribute.cc
new file mode 100644
index 0000000..ffe227e
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/index_attribute.cc
@@ -0,0 +1,41 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/index_attribute.h"
+
+#include <string>
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::IndexAttribute);
+
+namespace tint::ast {
+
+IndexAttribute::IndexAttribute(ProgramID pid, NodeID nid, const Source& src, const Expression* exp)
+ : Base(pid, nid, src), expr(exp) {}  // exp is stored as-is; no assertions are made here
+
+IndexAttribute::~IndexAttribute() = default;
+
+std::string IndexAttribute::Name() const {
+ return "index";  // the attribute's name, without any leading '@'
+}
+
+const IndexAttribute* IndexAttribute::Clone(CloneContext* ctx) const {
+    // Every child is cloned into a local before create() so the clone order is deterministic
+    auto cloned_source = ctx->Clone(source);
+    auto* cloned_expr = ctx->Clone(expr);
+    return ctx->dst->create<IndexAttribute>(cloned_source, cloned_expr);
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/index_attribute.h b/src/tint/lang/wgsl/ast/index_attribute.h
new file mode 100644
index 0000000..8c8863b
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/index_attribute.h
@@ -0,0 +1,51 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_INDEX_ATTRIBUTE_H_
+#define SRC_TINT_LANG_WGSL_AST_INDEX_ATTRIBUTE_H_
+
+#include <string>
+
+#include "src/tint/lang/wgsl/ast/attribute.h"
+#include "src/tint/lang/wgsl/ast/expression.h"
+
+namespace tint::ast {
+
+/// An index attribute (WGSL name `index`)
+class IndexAttribute final : public utils::Castable<IndexAttribute, Attribute> {
+ public:
+ /// Create an index attribute.
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ /// @param expr the index expression
+ IndexAttribute(ProgramID pid, NodeID nid, const Source& src, const Expression* expr);
+ ~IndexAttribute() override;
+
+ /// @returns the WGSL name for the attribute ("index")
+ std::string Name() const override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const IndexAttribute* Clone(CloneContext* ctx) const override;
+
+ /// The index expression
+ const Expression* const expr;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_INDEX_ATTRIBUTE_H_
diff --git a/src/tint/lang/wgsl/ast/index_attribute_test.cc b/src/tint/lang/wgsl/ast/index_attribute_test.cc
new file mode 100644
index 0000000..6c6bc6f
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/index_attribute_test.cc
@@ -0,0 +1,31 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/index_attribute.h"
+
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+namespace tint::ast {
+namespace {
+
+using namespace tint::number_suffixes;  // NOLINT
+using IndexAttributeTest = TestHelper;
+
+TEST_F(IndexAttributeTest, Creation) {
+    auto* attr = Index(1_a);  // the literal argument must surface as an int literal expression
+    EXPECT_TRUE(attr->expr->Is<IntLiteralExpression>());
+}
+
+}  // namespace
+}  // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/int_literal_expression.cc b/src/tint/lang/wgsl/ast/int_literal_expression.cc
new file mode 100644
index 0000000..2e5949a
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/int_literal_expression.cc
@@ -0,0 +1,49 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/int_literal_expression.h"
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::IntLiteralExpression);
+
+namespace tint::ast {
+
+IntLiteralExpression::IntLiteralExpression(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ int64_t val,
+ Suffix suf)
+ : Base(pid, nid, src), value(val), suffix(suf) {}  // stores the value and suffix unchanged
+
+IntLiteralExpression::~IntLiteralExpression() = default;
+
+const IntLiteralExpression* IntLiteralExpression::Clone(CloneContext* ctx) const {
+    // The source is cloned into a local before create() so the clone order is deterministic
+    auto cloned_source = ctx->Clone(source);
+    return ctx->dst->create<IntLiteralExpression>(cloned_source, value, suffix);
+}
+
+utils::StringStream& operator<<(utils::StringStream& out, IntLiteralExpression::Suffix suffix) {
+    using Suffix = IntLiteralExpression::Suffix;
+    // Only the 'i' and 'u' suffixes have a textual form; any other value writes nothing.
+    if (suffix == Suffix::kI) {
+        out << "i";
+    } else if (suffix == Suffix::kU) {
+        out << "u";
+    }
+    return out;
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/int_literal_expression.h b/src/tint/lang/wgsl/ast/int_literal_expression.h
new file mode 100644
index 0000000..28e0c75
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/int_literal_expression.h
@@ -0,0 +1,66 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_INT_LITERAL_EXPRESSION_H_
+#define SRC_TINT_LANG_WGSL_AST_INT_LITERAL_EXPRESSION_H_
+
+#include "src/tint/lang/wgsl/ast/literal_expression.h"
+
+namespace tint::ast {
+
+/// An integer literal. The literal may have an 'i', 'u' or no suffix.
+class IntLiteralExpression final : public utils::Castable<IntLiteralExpression, LiteralExpression> {
+ public:
+ /// Literal suffix
+ enum class Suffix {
+ /// No suffix
+ kNone,
+ /// 'i' suffix (i32)
+ kI,
+ /// 'u' suffix (u32)
+ kU,
+ };
+
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ /// @param val the literal value
+ /// @param suf the literal suffix
+ IntLiteralExpression(ProgramID pid, NodeID nid, const Source& src, int64_t val, Suffix suf);
+
+ ~IntLiteralExpression() override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const IntLiteralExpression* Clone(CloneContext* ctx) const override;
+
+ /// The literal value
+ const int64_t value;
+
+ /// The literal suffix
+ const Suffix suffix;
+};
+
+/// Writes the integer literal suffix to the stream.
+/// @param out the stream to write to
+/// @param suffix the suffix to write
+/// @returns out so calls can be chained
+utils::StringStream& operator<<(utils::StringStream& out, IntLiteralExpression::Suffix suffix);
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_INT_LITERAL_EXPRESSION_H_
diff --git a/src/tint/lang/wgsl/ast/int_literal_expression_test.cc b/src/tint/lang/wgsl/ast/int_literal_expression_test.cc
new file mode 100644
index 0000000..a6dec02
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/int_literal_expression_test.cc
@@ -0,0 +1,58 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+#include "src/tint/utils/string_stream.h"
+
+namespace tint::ast {
+namespace {
+
+using IntLiteralExpressionTest = TestHelper;
+
+TEST_F(IntLiteralExpressionTest, SuffixNone) {
+ auto* i = create<IntLiteralExpression>(42, IntLiteralExpression::Suffix::kNone);
+ ASSERT_TRUE(i->Is<IntLiteralExpression>());
+ EXPECT_EQ(i->value, 42);
+ EXPECT_EQ(i->suffix, IntLiteralExpression::Suffix::kNone);
+}
+
+TEST_F(IntLiteralExpressionTest, SuffixI) {
+ auto* i = create<IntLiteralExpression>(42, IntLiteralExpression::Suffix::kI);
+ ASSERT_TRUE(i->Is<IntLiteralExpression>());
+ EXPECT_EQ(i->value, 42);
+ EXPECT_EQ(i->suffix, IntLiteralExpression::Suffix::kI);
+}
+
+TEST_F(IntLiteralExpressionTest, SuffixU) {
+ auto* i = create<IntLiteralExpression>(42, IntLiteralExpression::Suffix::kU);
+ ASSERT_TRUE(i->Is<IntLiteralExpression>());
+ EXPECT_EQ(i->value, 42);
+ EXPECT_EQ(i->suffix, IntLiteralExpression::Suffix::kU);
+}
+
+TEST_F(IntLiteralExpressionTest, SuffixStringStream) {
+ auto to_str = [](IntLiteralExpression::Suffix suffix) {
+ utils::StringStream ss;
+ ss << suffix;
+ return ss.str();
+ };
+
+ EXPECT_EQ("", to_str(IntLiteralExpression::Suffix::kNone));
+ EXPECT_EQ("i", to_str(IntLiteralExpression::Suffix::kI));
+ EXPECT_EQ("u", to_str(IntLiteralExpression::Suffix::kU));
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/internal_attribute.cc b/src/tint/lang/wgsl/ast/internal_attribute.cc
new file mode 100644
index 0000000..72a6249
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/internal_attribute.cc
@@ -0,0 +1,34 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/internal_attribute.h"
+
+#include <utility>
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::InternalAttribute);
+
+namespace tint::ast {
+
+InternalAttribute::InternalAttribute(ProgramID pid,
+ NodeID nid,
+ utils::VectorRef<const IdentifierExpression*> deps)
+ : Base(pid, nid, Source{}), dependencies(std::move(deps)) {}
+
+InternalAttribute::~InternalAttribute() = default;
+
+std::string InternalAttribute::Name() const {
+ return "internal";
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/internal_attribute.h b/src/tint/lang/wgsl/ast/internal_attribute.h
new file mode 100644
index 0000000..3e758e3
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/internal_attribute.h
@@ -0,0 +1,59 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_INTERNAL_ATTRIBUTE_H_
+#define SRC_TINT_LANG_WGSL_AST_INTERNAL_ATTRIBUTE_H_
+
+#include <string>
+
+#include "src/tint/lang/wgsl/ast/attribute.h"
+#include "src/tint/utils/vector.h"
+
+// Forward declarations
+namespace tint::ast {
+class IdentifierExpression;
+} // namespace tint::ast
+
+namespace tint::ast {
+
+/// An attribute used to indicate that a function is tint-internal.
+/// These attributes are not produced by generators, but instead are usually
+/// created by transforms for consumption by a particular backend.
+class InternalAttribute : public utils::Castable<InternalAttribute, Attribute> {
+ public:
+ /// Constructor
+ /// @param program_id the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param deps a list of identifiers that this attribute is dependent on
+ InternalAttribute(ProgramID program_id,
+ NodeID nid,
+ utils::VectorRef<const IdentifierExpression*> deps);
+
+ /// Destructor
+ ~InternalAttribute() override;
+
+ /// @return a short description of the internal attribute which will be
+ /// displayed in WGSL as `@internal(<name>)` (but is not parsable).
+ virtual std::string InternalName() const = 0;
+
+ /// @returns the WGSL name for the attribute
+ std::string Name() const override;
+
+ /// A list of identifiers that this attribute is dependent on
+ const utils::Vector<const IdentifierExpression*, 1> dependencies;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_INTERNAL_ATTRIBUTE_H_
diff --git a/src/tint/lang/wgsl/ast/interpolate_attribute.cc b/src/tint/lang/wgsl/ast/interpolate_attribute.cc
new file mode 100644
index 0000000..383f94a
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/interpolate_attribute.cc
@@ -0,0 +1,46 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/interpolate_attribute.h"
+
+#include <string>
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::InterpolateAttribute);
+
+namespace tint::ast {
+
+InterpolateAttribute::InterpolateAttribute(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const Expression* ty,
+ const Expression* smpl)
+ : Base(pid, nid, src), type(ty), sampling(smpl) {}
+
+InterpolateAttribute::~InterpolateAttribute() = default;
+
+std::string InterpolateAttribute::Name() const {
+ return "interpolate";
+}
+
+const InterpolateAttribute* InterpolateAttribute::Clone(CloneContext* ctx) const {
+ // Clone arguments outside of create() call to have deterministic ordering
+ auto src = ctx->Clone(source);
+ auto* ty = ctx->Clone(type);
+ auto* smpl = ctx->Clone(sampling);
+ return ctx->dst->create<InterpolateAttribute>(src, ty, smpl);
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/interpolate_attribute.h b/src/tint/lang/wgsl/ast/interpolate_attribute.h
new file mode 100644
index 0000000..e8c1475
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/interpolate_attribute.h
@@ -0,0 +1,63 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_INTERPOLATE_ATTRIBUTE_H_
+#define SRC_TINT_LANG_WGSL_AST_INTERPOLATE_ATTRIBUTE_H_
+
+#include <string>
+
+#include "src/tint/lang/wgsl/ast/attribute.h"
+
+// Forward declarations
+namespace tint::ast {
+class Expression;
+}
+
+namespace tint::ast {
+
+/// An interpolate attribute
+class InterpolateAttribute final : public utils::Castable<InterpolateAttribute, Attribute> {
+ public:
+ /// Create an interpolate attribute.
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ /// @param type the interpolation type
+ /// @param sampling the interpolation sampling
+ InterpolateAttribute(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const Expression* type,
+ const Expression* sampling);
+ ~InterpolateAttribute() override;
+
+ /// @returns the WGSL name for the attribute
+ std::string Name() const override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const InterpolateAttribute* Clone(CloneContext* ctx) const override;
+
+ /// The interpolation type
+ const Expression* const type;
+
+ /// The interpolation sampling
+ const Expression* const sampling;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_INTERPOLATE_ATTRIBUTE_H_
diff --git a/src/tint/lang/wgsl/ast/interpolate_attribute_test.cc b/src/tint/lang/wgsl/ast/interpolate_attribute_test.cc
new file mode 100644
index 0000000..4753a4f
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/interpolate_attribute_test.cc
@@ -0,0 +1,35 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/interpolate_attribute.h"
+
+#include <string>
+
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+#include "src/tint/utils/string.h"
+
+namespace tint::ast {
+namespace {
+
+using InterpolateAttributeTest = TestHelper;
+
+TEST_F(InterpolateAttributeTest, Creation) {
+ auto* d =
+ Interpolate(builtin::InterpolationType::kLinear, builtin::InterpolationSampling::kCenter);
+ CheckIdentifier(d->type, "linear");
+ CheckIdentifier(d->sampling, "center");
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/invariant_attribute.cc b/src/tint/lang/wgsl/ast/invariant_attribute.cc
new file mode 100644
index 0000000..096709b
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/invariant_attribute.cc
@@ -0,0 +1,38 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/invariant_attribute.h"
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::InvariantAttribute);
+
+namespace tint::ast {
+
+InvariantAttribute::InvariantAttribute(ProgramID pid, NodeID nid, const Source& src)
+ : Base(pid, nid, src) {}
+
+InvariantAttribute::~InvariantAttribute() = default;
+
+std::string InvariantAttribute::Name() const {
+ return "invariant";
+}
+
+const InvariantAttribute* InvariantAttribute::Clone(CloneContext* ctx) const {
+ // Clone arguments outside of create() call to have deterministic ordering
+ auto src = ctx->Clone(source);
+ return ctx->dst->create<InvariantAttribute>(src);
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/invariant_attribute.h b/src/tint/lang/wgsl/ast/invariant_attribute.h
new file mode 100644
index 0000000..c8bdb4f
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/invariant_attribute.h
@@ -0,0 +1,46 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_INVARIANT_ATTRIBUTE_H_
+#define SRC_TINT_LANG_WGSL_AST_INVARIANT_ATTRIBUTE_H_
+
+#include <string>
+
+#include "src/tint/lang/wgsl/ast/attribute.h"
+
+namespace tint::ast {
+
+/// The invariant attribute
+class InvariantAttribute final : public utils::Castable<InvariantAttribute, Attribute> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ InvariantAttribute(ProgramID pid, NodeID nid, const Source& src);
+ ~InvariantAttribute() override;
+
+ /// @returns the WGSL name for the attribute
+ std::string Name() const override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const InvariantAttribute* Clone(CloneContext* ctx) const override;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_INVARIANT_ATTRIBUTE_H_
diff --git a/src/tint/lang/wgsl/ast/let.cc b/src/tint/lang/wgsl/ast/let.cc
new file mode 100644
index 0000000..f2eee1f
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/let.cc
@@ -0,0 +1,51 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/let.h"
+
+#include <utility>
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::Let);
+
+namespace tint::ast {
+
+Let::Let(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const Identifier* n,
+ Type ty,
+ const Expression* init,
+ utils::VectorRef<const Attribute*> attrs)
+ : Base(pid, nid, src, n, ty, init, std::move(attrs)) {
+ TINT_ASSERT(AST, init != nullptr);
+}
+
+Let::~Let() = default;
+
+const char* Let::Kind() const {
+ return "let";
+}
+
+const Let* Let::Clone(CloneContext* ctx) const {
+ auto src = ctx->Clone(source);
+ auto* n = ctx->Clone(name);
+ auto ty = ctx->Clone(type);
+ auto* init = ctx->Clone(initializer);
+ auto attrs = ctx->Clone(attributes);
+ return ctx->dst->create<Let>(src, n, ty, init, std::move(attrs));
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/let.h b/src/tint/lang/wgsl/ast/let.h
new file mode 100644
index 0000000..ecc119d
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/let.h
@@ -0,0 +1,63 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_LET_H_
+#define SRC_TINT_LANG_WGSL_AST_LET_H_
+
+#include "src/tint/lang/wgsl/ast/variable.h"
+
+namespace tint::ast {
+
+/// A "let" declaration is a name for a function-scoped runtime typed value.
+///
+/// Examples:
+///
+/// ```
+/// let twice_depth : i32 = width + width; // Must have initializer
+/// ```
+/// @see https://www.w3.org/TR/WGSL/#let-decls
+class Let final : public utils::Castable<Let, Variable> {
+ public:
+ /// Create a 'let' variable
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param source the variable source
+ /// @param name the variable name
+ /// @param type the declared variable type
+ /// @param initializer the initializer expression
+ /// @param attributes the variable attributes
+ Let(ProgramID pid,
+ NodeID nid,
+ const Source& source,
+ const Identifier* name,
+ Type type,
+ const Expression* initializer,
+ utils::VectorRef<const Attribute*> attributes);
+
+ /// Destructor
+ ~Let() override;
+
+ /// @returns "let"
+ const char* Kind() const override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const Let* Clone(CloneContext* ctx) const override;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_LET_H_
diff --git a/src/tint/lang/wgsl/ast/literal_expression.cc b/src/tint/lang/wgsl/ast/literal_expression.cc
new file mode 100644
index 0000000..df85321
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/literal_expression.cc
@@ -0,0 +1,26 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/literal_expression.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::LiteralExpression);
+
+namespace tint::ast {
+
+LiteralExpression::LiteralExpression(ProgramID pid, NodeID nid, const Source& src)
+ : Base(pid, nid, src) {}
+
+LiteralExpression::~LiteralExpression() = default;
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/literal_expression.h b/src/tint/lang/wgsl/ast/literal_expression.h
new file mode 100644
index 0000000..9036f9d
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/literal_expression.h
@@ -0,0 +1,39 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_LITERAL_EXPRESSION_H_
+#define SRC_TINT_LANG_WGSL_AST_LITERAL_EXPRESSION_H_
+
+#include <string>
+
+#include "src/tint/lang/wgsl/ast/expression.h"
+
+namespace tint::ast {
+
+/// Base class for literal value expressions
+class LiteralExpression : public utils::Castable<LiteralExpression, Expression> {
+ public:
+ ~LiteralExpression() override;
+
+ protected:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the input source
+ LiteralExpression(ProgramID pid, NodeID nid, const Source& src);
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_LITERAL_EXPRESSION_H_
diff --git a/src/tint/lang/wgsl/ast/location_attribute.cc b/src/tint/lang/wgsl/ast/location_attribute.cc
new file mode 100644
index 0000000..87bc75c
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/location_attribute.cc
@@ -0,0 +1,44 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/location_attribute.h"
+
+#include <string>
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::LocationAttribute);
+
+namespace tint::ast {
+
+LocationAttribute::LocationAttribute(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const Expression* exp)
+ : Base(pid, nid, src), expr(exp) {}
+
+LocationAttribute::~LocationAttribute() = default;
+
+std::string LocationAttribute::Name() const {
+ return "location";
+}
+
+const LocationAttribute* LocationAttribute::Clone(CloneContext* ctx) const {
+ // Clone arguments outside of create() call to have deterministic ordering
+ auto src = ctx->Clone(source);
+ auto expr_ = ctx->Clone(expr);
+ return ctx->dst->create<LocationAttribute>(src, expr_);
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/location_attribute.h b/src/tint/lang/wgsl/ast/location_attribute.h
new file mode 100644
index 0000000..949e2a6
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/location_attribute.h
@@ -0,0 +1,51 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_LOCATION_ATTRIBUTE_H_
+#define SRC_TINT_LANG_WGSL_AST_LOCATION_ATTRIBUTE_H_
+
+#include <string>
+
+#include "src/tint/lang/wgsl/ast/attribute.h"
+#include "src/tint/lang/wgsl/ast/expression.h"
+
+namespace tint::ast {
+
+/// A location attribute
+class LocationAttribute final : public utils::Castable<LocationAttribute, Attribute> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ /// @param expr the location expression
+ LocationAttribute(ProgramID pid, NodeID nid, const Source& src, const Expression* expr);
+ ~LocationAttribute() override;
+
+ /// @returns the WGSL name for the attribute
+ std::string Name() const override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const LocationAttribute* Clone(CloneContext* ctx) const override;
+
+ /// The location expression
+ const Expression* const expr;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_LOCATION_ATTRIBUTE_H_
diff --git a/src/tint/lang/wgsl/ast/location_attribute_test.cc b/src/tint/lang/wgsl/ast/location_attribute_test.cc
new file mode 100644
index 0000000..122dd86
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/location_attribute_test.cc
@@ -0,0 +1,29 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+namespace tint::ast {
+namespace {
+
+using namespace tint::number_suffixes; // NOLINT
+using LocationAttributeTest = TestHelper;
+
+TEST_F(LocationAttributeTest, Creation) {
+ auto* d = Location(2_a);
+ EXPECT_TRUE(d->expr->Is<IntLiteralExpression>());
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/loop_statement.cc b/src/tint/lang/wgsl/ast/loop_statement.cc
new file mode 100644
index 0000000..e9a73cf
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/loop_statement.cc
@@ -0,0 +1,52 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/loop_statement.h"
+
+#include <utility>
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::LoopStatement);
+
+namespace tint::ast {
+
+LoopStatement::LoopStatement(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const BlockStatement* b,
+ const BlockStatement* cont,
+ utils::VectorRef<const ast::Attribute*> attrs)
+ : Base(pid, nid, src), body(b), continuing(cont), attributes(std::move(attrs)) {
+ TINT_ASSERT(AST, body);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, body, program_id);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, continuing, program_id);
+ for (auto* attr : attributes) {
+ TINT_ASSERT(AST, attr);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, attr, program_id);
+ }
+}
+
+LoopStatement::~LoopStatement() = default;
+
+const LoopStatement* LoopStatement::Clone(CloneContext* ctx) const {
+ // Clone arguments outside of create() call to have deterministic ordering
+ auto src = ctx->Clone(source);
+ auto* b = ctx->Clone(body);
+ auto* cont = ctx->Clone(continuing);
+ auto attrs = ctx->Clone(attributes);
+ return ctx->dst->create<LoopStatement>(src, b, cont, std::move(attrs));
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/loop_statement.h b/src/tint/lang/wgsl/ast/loop_statement.h
new file mode 100644
index 0000000..083c38e
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/loop_statement.h
@@ -0,0 +1,59 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_LOOP_STATEMENT_H_
+#define SRC_TINT_LANG_WGSL_AST_LOOP_STATEMENT_H_
+
+#include "src/tint/lang/wgsl/ast/block_statement.h"
+
+namespace tint::ast {
+
+/// A loop statement
+class LoopStatement final : public utils::Castable<LoopStatement, Statement> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param source the loop statement source
+ /// @param body the body statements
+ /// @param continuing the continuing statements
+ /// @param attributes the loop statement attributes
+ LoopStatement(ProgramID pid,
+ NodeID nid,
+ const Source& source,
+ const BlockStatement* body,
+ const BlockStatement* continuing,
+ utils::VectorRef<const ast::Attribute*> attributes);
+ /// Destructor
+ ~LoopStatement() override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const LoopStatement* Clone(CloneContext* ctx) const override;
+
+ /// The loop body
+ const BlockStatement* const body;
+
+ /// The continuing statements
+ const BlockStatement* const continuing;
+
+ /// The attribute list
+ const utils::Vector<const Attribute*, 1> attributes;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_LOOP_STATEMENT_H_
diff --git a/src/tint/lang/wgsl/ast/loop_statement_test.cc b/src/tint/lang/wgsl/ast/loop_statement_test.cc
new file mode 100644
index 0000000..2ca171d
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/loop_statement_test.cc
@@ -0,0 +1,114 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/loop_statement.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest-spi.h"
+#include "src/tint/lang/wgsl/ast/discard_statement.h"
+#include "src/tint/lang/wgsl/ast/if_statement.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+namespace tint::ast {
+namespace {
+
+using LoopStatementTest = TestHelper;
+
+TEST_F(LoopStatementTest, Creation) {  // body and continuing blocks are stored as given
+ auto* body = Block(create<DiscardStatement>());
+ auto* b = body->Last();
+
+ auto* continuing = Block(create<DiscardStatement>());
+
+ auto* l = create<LoopStatement>(body, continuing, utils::Empty);
+ ASSERT_EQ(l->body->statements.Length(), 1u);
+ EXPECT_EQ(l->body->statements[0], b);
+ ASSERT_EQ(l->continuing->statements.Length(), 1u);
+ EXPECT_EQ(l->continuing->statements[0], continuing->Last());
+}
+
+TEST_F(LoopStatementTest, Creation_WithSource) {  // the source location is kept on the node
+ auto* body = Block(create<DiscardStatement>());
+
+ auto* continuing = Block(create<DiscardStatement>());
+
+ auto* l =
+ create<LoopStatement>(Source{Source::Location{20, 2}}, body, continuing, utils::Empty);
+ auto src = l->source;
+ EXPECT_EQ(src.range.begin.line, 20u);
+ EXPECT_EQ(src.range.begin.column, 2u);
+}
+
+TEST_F(LoopStatementTest, Creation_WithAttributes) {  // attributes are kept in order
+ auto* attr1 = DiagnosticAttribute(builtin::DiagnosticSeverity::kOff, "foo");
+ auto* attr2 = DiagnosticAttribute(builtin::DiagnosticSeverity::kOff, "bar");
+
+ auto* body = Block(Return());
+ auto* l = create<LoopStatement>(body, nullptr, utils::Vector{attr1, attr2});
+
+ EXPECT_THAT(l->attributes, testing::ElementsAre(attr1, attr2));
+}
+
+TEST_F(LoopStatementTest, IsLoop) {  // the node reports its own dynamic type
+ auto* l = create<LoopStatement>(Block(), Block(), utils::Empty);
+ EXPECT_TRUE(l->Is<LoopStatement>());
+}
+
+TEST_F(LoopStatementTest, HasContinuing_WithoutContinuing) {  // nullptr continuing stays null
+ auto* body = Block(create<DiscardStatement>());
+
+ auto* l = create<LoopStatement>(body, nullptr, utils::Empty);
+ EXPECT_FALSE(l->continuing);
+}
+
+TEST_F(LoopStatementTest, HasContinuing_WithContinuing) {  // a given continuing is non-null
+ auto* body = Block(create<DiscardStatement>());
+
+ auto* continuing = Block(create<DiscardStatement>());
+
+ auto* l = create<LoopStatement>(body, continuing, utils::Empty);
+ EXPECT_TRUE(l->continuing);
+}
+
+TEST_F(LoopStatementTest, Assert_Null_Body) {  // a null body must raise an ICE
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b;
+ b.create<LoopStatement>(nullptr, nullptr, utils::Empty);
+ },
+ "internal compiler error");
+}
+
+TEST_F(LoopStatementTest, Assert_DifferentProgramID_Body) {  // body built by another builder
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.create<LoopStatement>(b2.Block(), b1.Block(), utils::Empty);
+ },
+ "internal compiler error");
+}
+
+TEST_F(LoopStatementTest, Assert_DifferentProgramID_Continuing) {  // continuing from b2
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.create<LoopStatement>(b1.Block(), b2.Block(), utils::Empty);
+ },
+ "internal compiler error");
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/member_accessor_expression.cc b/src/tint/lang/wgsl/ast/member_accessor_expression.cc
new file mode 100644
index 0000000..f6ccdec
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/member_accessor_expression.cc
@@ -0,0 +1,48 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/member_accessor_expression.h"
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::MemberAccessorExpression);
+
+namespace tint::ast {
+
+MemberAccessorExpression::MemberAccessorExpression(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const Expression* obj,
+ const Identifier* mem)
+ : Base(pid, nid, src, obj), member(mem) {
+ TINT_ASSERT(AST, member);  // the member identifier is mandatory
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, member, program_id);
+
+ // It is currently invalid for a structure to hold a templated member.
+ if (member) {  // avoids a null dereference; presumably TINT_ASSERT can return (confirm)
+ TINT_ASSERT(AST, !member->Is<TemplatedIdentifier>());
+ }
+}
+
+MemberAccessorExpression::~MemberAccessorExpression() = default;
+
+const MemberAccessorExpression* MemberAccessorExpression::Clone(CloneContext* ctx) const {
+ // Clone into locals first: inside the create() call the argument order is unspecified
+ auto src = ctx->Clone(source);
+ auto* obj = ctx->Clone(object);
+ auto* mem = ctx->Clone(member);
+ return ctx->dst->create<MemberAccessorExpression>(src, obj, mem);
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/member_accessor_expression.h b/src/tint/lang/wgsl/ast/member_accessor_expression.h
new file mode 100644
index 0000000..4d34fb1
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/member_accessor_expression.h
@@ -0,0 +1,54 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_MEMBER_ACCESSOR_EXPRESSION_H_
+#define SRC_TINT_LANG_WGSL_AST_MEMBER_ACCESSOR_EXPRESSION_H_
+
+#include "src/tint/lang/wgsl/ast/accessor_expression.h"
+#include "src/tint/lang/wgsl/ast/identifier_expression.h"
+
+namespace tint::ast {
+
+/// A member accessor expression: an object expression plus the member identifier
+class MemberAccessorExpression final
+ : public utils::Castable<MemberAccessorExpression, AccessorExpression> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param source the member accessor expression source
+ /// @param object the expression whose member is accessed
+ /// @param member the member identifier; asserted non-null and not templated
+ MemberAccessorExpression(ProgramID pid,
+ NodeID nid,
+ const Source& source,
+ const Expression* object,
+ const Identifier* member);
+
+ /// Destructor
+ ~MemberAccessorExpression() override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const MemberAccessorExpression* Clone(CloneContext* ctx) const override;
+
+ /// The identifier of the accessed member
+ const Identifier* const member;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_MEMBER_ACCESSOR_EXPRESSION_H_
diff --git a/src/tint/lang/wgsl/ast/member_accessor_expression_test.cc b/src/tint/lang/wgsl/ast/member_accessor_expression_test.cc
new file mode 100644
index 0000000..3f1a4bf
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/member_accessor_expression_test.cc
@@ -0,0 +1,94 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gtest/gtest-spi.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+namespace tint::ast {
+namespace {
+
+using MemberAccessorExpressionTest = TestHelper;
+
+TEST_F(MemberAccessorExpressionTest, Creation) {  // object and member are stored as given
+ auto* str = Expr("structure");
+ auto* mem = Ident("member");
+
+ auto* stmt = create<MemberAccessorExpression>(str, mem);
+ EXPECT_EQ(stmt->object, str);
+ EXPECT_EQ(stmt->member, mem);
+}
+
+TEST_F(MemberAccessorExpressionTest, Creation_WithSource) {  // source location is kept
+ auto* stmt = create<MemberAccessorExpression>(Source{Source::Location{20, 2}},
+ Expr("structure"), Ident("member"));
+ auto src = stmt->source;
+ EXPECT_EQ(src.range.begin.line, 20u);
+ EXPECT_EQ(src.range.begin.column, 2u);
+}
+
+TEST_F(MemberAccessorExpressionTest, IsMemberAccessor) {  // node reports its dynamic type
+ auto* stmt = create<MemberAccessorExpression>(Expr("structure"), Ident("member"));
+ EXPECT_TRUE(stmt->Is<MemberAccessorExpression>());
+}
+
+TEST_F(MemberAccessorExpressionTest, Assert_Null_Struct) {  // null object must ICE
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b;
+ b.create<MemberAccessorExpression>(nullptr, b.Ident("member"));
+ },
+ "internal compiler error");
+}
+
+TEST_F(MemberAccessorExpressionTest, Assert_Null_Member) {  // null member must ICE
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b;
+ b.create<MemberAccessorExpression>(b.Expr("struct"), nullptr);
+ },
+ "internal compiler error");
+}
+
+TEST_F(MemberAccessorExpressionTest, Assert_DifferentProgramID_Struct) {  // object from b2
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.create<MemberAccessorExpression>(b2.Expr("structure"), b1.Ident("member"));
+ },
+ "internal compiler error");
+}
+
+TEST_F(MemberAccessorExpressionTest, Assert_DifferentProgramID_Member) {  // member from b2
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.create<MemberAccessorExpression>(b1.Expr("structure"), b2.Ident("member"));
+ },
+ "internal compiler error");
+}
+
+TEST_F(MemberAccessorExpressionTest, Assert_MemberNotTemplated) {  // templated member ICEs
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b;
+ b.create<MemberAccessorExpression>(b.Expr("structure"),
+ b.Ident("member", "a", "b", "c"));
+ },
+ "internal compiler error");
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/module.cc b/src/tint/lang/wgsl/ast/module.cc
new file mode 100644
index 0000000..9ef1b3e
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/module.cc
@@ -0,0 +1,158 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/module.h"
+
+#include <utility>
+
+#include "src/tint/lang/wgsl/ast/type_decl.h"
+#include "src/tint/program_builder.h"
+#include "src/tint/switch.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::Module);
+
+namespace tint::ast {
+
+Module::Module(ProgramID pid, NodeID nid, const Source& src) : Base(pid, nid, src) {}
+
+Module::Module(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ utils::VectorRef<const Node*> global_decls)
+ : Base(pid, nid, src), global_declarations_(std::move(global_decls)) {
+ for (auto* decl : global_declarations_) {  // sort each declaration into its per-kind list
+ if (decl == nullptr) {  // null entries are skipped without a diagnostic
+ continue;
+ }
+ diag::List diags;  // scratch list; whatever is added to it is dropped after the call
+ BinGlobalDeclaration(decl, diags);
+ }
+}
+
+Module::~Module() = default;
+
+const TypeDecl* Module::LookupType(Symbol name) const {
+ for (auto* ty : TypeDecls()) {  // linear scan; the first declaration that matches wins
+ if (ty->name->symbol == name) {
+ return ty;
+ }
+ }
+ return nullptr;  // no type declaration carries this symbol
+}
+
+void Module::AddGlobalDeclaration(const tint::ast::Node* decl) {
+ diag::List diags;  // local scratch list; its contents are not handed back to the caller
+ BinGlobalDeclaration(decl, diags);  // record the node in its per-kind list
+ global_declarations_.Push(decl);  // and in the declaration-ordered list
+}
+
+void Module::BinGlobalDeclaration(const tint::ast::Node* decl, diag::List& diags) {
+ Switch(
+ decl, // each case checks the node's program ID, then appends it to that kind's list
+ [&](const TypeDecl* type) {
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, type, program_id);
+ type_decls_.Push(type);
+ },
+ [&](const Function* func) {
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, func, program_id);
+ functions_.Push(func);
+ },
+ [&](const Variable* var) {
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, var, program_id);
+ global_variables_.Push(var);
+ },
+ [&](const DiagnosticDirective* diagnostic) {
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, diagnostic, program_id);
+ diagnostic_directives_.Push(diagnostic);
+ },
+ [&](const Enable* enable) {
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, enable, program_id);
+ enables_.Push(enable);
+ },
+ [&](const ConstAssert* assertion) {
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, assertion, program_id);
+ const_asserts_.Push(assertion);
+ },
+ [&](Default) { TINT_ICE(AST, diags) << "Unknown global declaration type"; });
+}
+
+void Module::AddDiagnosticDirective(const DiagnosticDirective* directive) {
+ TINT_ASSERT(AST, directive);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, directive, program_id);
+ global_declarations_.Push(directive);  // declaration-ordered list
+ diagnostic_directives_.Push(directive);  // per-kind list
+}
+
+void Module::AddEnable(const Enable* enable) {
+ TINT_ASSERT(AST, enable);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, enable, program_id);
+ global_declarations_.Push(enable);  // declaration-ordered list
+ enables_.Push(enable);  // per-kind list
+}
+
+void Module::AddGlobalVariable(const Variable* var) {
+ TINT_ASSERT(AST, var);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, var, program_id);
+ global_variables_.Push(var);  // per-kind list
+ global_declarations_.Push(var);  // declaration-ordered list
+}
+
+void Module::AddConstAssert(const ConstAssert* assertion) {
+ TINT_ASSERT(AST, assertion);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, assertion, program_id);
+ const_asserts_.Push(assertion);  // per-kind list
+ global_declarations_.Push(assertion);  // declaration-ordered list
+}
+
+void Module::AddTypeDecl(const TypeDecl* type) {
+ TINT_ASSERT(AST, type);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, type, program_id);
+ type_decls_.Push(type);  // per-kind list
+ global_declarations_.Push(type);  // declaration-ordered list
+}
+
+void Module::AddFunction(const Function* func) {
+ TINT_ASSERT(AST, func);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, func, program_id);
+ functions_.Push(func);  // per-kind list
+ global_declarations_.Push(func);  // declaration-ordered list
+}
+
+const Module* Module::Clone(CloneContext* ctx) const {
+ auto* out = ctx->dst->create<Module>();  // new Module node made through ctx->dst
+ out->Copy(ctx, this);  // clone this module's declarations into it
+ return out;
+}
+
+void Module::Copy(CloneContext* ctx, const Module* src) {
+ ctx->Clone(global_declarations_, src->global_declarations_);
+
+ // Declarations may already have been binned into this module while cloning, and every
+ // entry is re-binned below, so empty all per-kind lists first to avoid duplicate entries.
+ type_decls_.Clear();
+ functions_.Clear();
+ global_variables_.Clear();
+ enables_.Clear();
+ diagnostic_directives_.Clear();
+ const_asserts_.Clear();  // const asserts are binned too, so they must be reset as well
+ for (auto* decl : global_declarations_) {
+ if (TINT_UNLIKELY(!decl)) {
+ TINT_ICE(AST, ctx->dst->Diagnostics()) << "src global declaration was nullptr";
+ continue;
+ }
+ BinGlobalDeclaration(decl, ctx->dst->Diagnostics());
+ }
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/module.h b/src/tint/lang/wgsl/ast/module.h
new file mode 100644
index 0000000..e588454
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/module.h
@@ -0,0 +1,163 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_MODULE_H_
+#define SRC_TINT_LANG_WGSL_AST_MODULE_H_
+
+#include <string>
+
+#include "src/tint/lang/wgsl/ast/const_assert.h"
+#include "src/tint/lang/wgsl/ast/diagnostic_directive.h"
+#include "src/tint/lang/wgsl/ast/enable.h"
+#include "src/tint/lang/wgsl/ast/function.h"
+#include "src/tint/utils/vector.h"
+
+namespace tint::ast {
+
+class TypeDecl;
+
+/// Module holds the top-level AST declarations of a Program: types, functions, global
+/// variables, enable and diagnostic directives, and const assertions.
+class Module final : public utils::Castable<Module, Node> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ Module(ProgramID pid, NodeID nid, const Source& src);
+
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ /// @param global_decls the list of global types, functions, and variables, in
+ /// the order they were declared in the source program
+ Module(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ utils::VectorRef<const Node*> global_decls);
+
+ /// Destructor
+ ~Module() override;
+
+ /// @returns the declaration-ordered global declarations for the module
+ const auto& GlobalDeclarations() const { return global_declarations_; }
+
+ /// Add a global variable to the module
+ /// @param var the variable to add
+ void AddGlobalVariable(const Variable* var);
+
+ /// @returns true if the module has the global declaration `decl`
+ /// @param decl the declaration to check (compared by pointer, never modified)
+ bool HasGlobalDeclaration(const Node* decl) const {
+ for (auto* d : global_declarations_) {
+ if (d == decl) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ /// Adds a global declaration to the module.
+ /// @param decl the declaration to add
+ void AddGlobalDeclaration(const tint::ast::Node* decl);
+
+ /// @returns the global variables for the module
+ const auto& GlobalVariables() const { return global_variables_; }
+
+ /// @returns the mutable list of global variables for the module
+ auto& GlobalVariables() { return global_variables_; }
+
+ /// @returns the global variable declarations of kind 'T' for the module
+ template <typename T, typename = utils::traits::EnableIfIsType<T, Variable>>
+ auto Globals() const {
+ utils::Vector<const T*, 32> out;
+ out.Reserve(global_variables_.Length());
+ for (auto* global : global_variables_) {
+ if (auto* var = global->As<T>()) {
+ out.Push(var);
+ }
+ }
+ return out;
+ }
+
+ /// Add a diagnostic directive to the module
+ /// @param diagnostic the diagnostic directive to add
+ void AddDiagnosticDirective(const DiagnosticDirective* diagnostic);
+
+ /// Add an enable directive to the module
+ /// @param ext the enable directive to add
+ void AddEnable(const Enable* ext);
+
+ /// @returns the diagnostic directives for the module
+ const auto& DiagnosticDirectives() const { return diagnostic_directives_; }
+
+ /// @returns the enable directives for the module
+ const auto& Enables() const { return enables_; }
+
+ /// Add a global const assertion to the module
+ /// @param assertion the const assert to add
+ void AddConstAssert(const ConstAssert* assertion);
+
+ /// @returns the list of global const assertions
+ const auto& ConstAsserts() const { return const_asserts_; }
+
+ /// Adds a type declaration to the module
+ /// @param decl the type declaration to add
+ void AddTypeDecl(const TypeDecl* decl);
+
+ /// @returns the type declaration with the given name, or nullptr if there is none
+ /// @param name the name of the type to search for
+ const TypeDecl* LookupType(Symbol name) const;
+
+ /// @returns the declared types in the module
+ const auto& TypeDecls() const { return type_decls_; }
+
+ /// Add a function to the module
+ /// @param func the function to add
+ void AddFunction(const Function* func);
+
+ /// @returns the functions declared in the module
+ const FunctionList& Functions() const { return functions_; }
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const Module* Clone(CloneContext* ctx) const override;
+
+ /// Copy copies the content of the Module src into this module.
+ /// @param ctx the clone context
+ /// @param src the module whose declarations are cloned into this module
+ void Copy(CloneContext* ctx, const Module* src);
+
+ private:
+ /// Adds `decl` to the list matching its kind: #type_decls_, #functions_,
+ /// #global_variables_, #diagnostic_directives_, #enables_ or #const_asserts_.
+ /// #global_declarations_ is not modified. An unrecognised kind raises an ICE
+ /// on `diags`.
+ void BinGlobalDeclaration(const tint::ast::Node* decl, diag::List& diags);
+
+ utils::Vector<const Node*, 64> global_declarations_;
+ utils::Vector<const TypeDecl*, 16> type_decls_;
+ FunctionList functions_;
+ utils::Vector<const Variable*, 32> global_variables_;
+ utils::Vector<const DiagnosticDirective*, 8> diagnostic_directives_;
+ utils::Vector<const Enable*, 8> enables_;
+ utils::Vector<const ConstAssert*, 8> const_asserts_;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_MODULE_H_
diff --git a/src/tint/lang/wgsl/ast/module_clone_test.cc b/src/tint/lang/wgsl/ast/module_clone_test.cc
new file mode 100644
index 0000000..1e2e6bd
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/module_clone_test.cc
@@ -0,0 +1,179 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <unordered_set>
+
+#include "gtest/gtest.h"
+#include "src/tint/lang/wgsl/ast_writer/generator.h"
+#include "src/tint/lang/wgsl/reader/parser.h"
+
+namespace tint::ast {
+namespace {
+
+TEST(ModuleCloneTest, Clone) {
+ // Shader that exercises the bulk of the AST nodes and types.
+ // See also fuzzers/tint_ast_clone_fuzzer.cc for further coverage of cloning.
+ Source::File file("test.wgsl", R"(enable f16;
+diagnostic(off, chromium.unreachable_code);
+
+struct S0 {
+ @size(4)
+ m0 : u32,
+ m1 : array<u32>,
+};
+
+struct S1 {
+ @size(4)
+ m0 : u32,
+ m1 : array<u32, 6>,
+};
+
+const c0 : i32 = 10;
+const c1 : bool = true;
+
+alias t0 = array<vec4<f32>>;
+alias t1 = array<vec4<f32>>;
+
+var<private> g0 : u32 = 20u;
+var<private> g1 : f32 = 123.0;
+@group(0) @binding(0) var g2 : texture_2d<f32>;
+@group(1) @binding(0) var g3 : texture_depth_2d;
+@group(2) @binding(0) var g4 : texture_storage_2d<rg32float, write>;
+@group(3) @binding(0) var g5 : texture_depth_cube_array;
+@group(4) @binding(0) var g6 : texture_external;
+
+var<private> g7 : vec3<f32>;
+@group(0) @binding(1) var<storage, read_write> g8 : S0;
+@group(1) @binding(1) var<storage, read> g9 : S0;
+@group(2) @binding(1) var<storage, read_write> g10 : S0;
+
+fn f0(p0 : bool) -> f32 {
+ if (p0) {
+ return 1.0;
+ }
+ return 0.0;
+}
+
+@diagnostic(warning, chromium.unreachable_code)
+fn f1(p0 : f32, p1 : i32) -> f32 {
+ var l0 : i32 = 3;
+ var l1 : f32 = 8.0;
+ var l2 : u32 = bitcast<u32>(4);
+ var l3 : vec2<u32> = vec2<u32>(u32(l0), u32(l1));
+ var l4 : S1;
+ var l5 : u32 = l4.m1[5];
+ let l6 : ptr<private, u32> = &g0;
+ const l7 = 123;
+ const l8 : i32 = 123;
+ loop {
+ l0 = (p1 + 2);
+ if (((l0 % 4) == 0)) {
+ break;
+ }
+
+ continuing {
+ if (1 == 2) {
+ l0 = l0 - 1;
+ } else {
+ l0 = l0 - 2;
+ }
+ }
+ }
+ switch(l2) {
+ case 0u: {
+ break;
+ }
+ case 1u: {
+ return f0(true);
+ }
+ default: {
+ discard;
+ }
+ }
+ return 1.0;
+}
+
+@fragment
+fn main() {
+ f1(1.0, 2);
+}
+
+const declaration_order_check_0 : i32 = 1;
+
+alias declaration_order_check_1 = f32;
+
+fn declaration_order_check_2() {}
+
+alias declaration_order_check_3 = f32;
+
+const declaration_order_check_4 : i32 = 1;
+
+)");
+
+ // Parse the wgsl, create the src program
+ auto src = reader::wgsl::Parse(&file);
+
+ ASSERT_TRUE(src.IsValid()) << diag::Formatter().format(src.Diagnostics());
+
+ // Clone the src program to dst
+ Program dst(src.Clone());
+
+ ASSERT_TRUE(dst.IsValid()) << diag::Formatter().format(dst.Diagnostics());
+
+ // Expect the printed strings to match
+ EXPECT_EQ(Program::printer(&src), Program::printer(&dst));
+
+ // Check that none of the AST nodes or type pointers in dst are found in src
+ std::unordered_set<const Node*> src_nodes;
+ for (auto* src_node : src.ASTNodes().Objects()) {
+ src_nodes.emplace(src_node);
+ }
+ std::unordered_set<const type::Type*> src_types;
+ for (auto* src_type : src.Types()) {
+ src_types.emplace(src_type);
+ }
+ for (auto* dst_node : dst.ASTNodes().Objects()) {
+ ASSERT_EQ(src_nodes.count(dst_node), 0u);
+ }
+ for (auto* dst_type : dst.Types()) {
+ ASSERT_EQ(src_types.count(dst_type), 0u);
+ }
+
+ // Regenerate the wgsl for the src program. We use this instead of the
+ // original source so that reformatting doesn't impact the final wgsl
+ // comparison.
+ writer::wgsl::Options options;
+ std::string src_wgsl;
+ {
+ auto result = writer::wgsl::Generate(&src, options);
+ ASSERT_TRUE(result.success) << result.error;
+ src_wgsl = result.wgsl;
+
+ // Move the src program to a temporary that'll be dropped, so that the src
+ // program is released before we attempt to print the dst program. This
+ // guarantees that all the source program nodes and types are destructed and
+ // freed. ASAN should error if there's any remaining references in dst when
+ // we try to reconstruct the WGSL.
+ auto tmp = std::move(src);
+ }
+
+ // Print the dst module, check it matches the original source.
+ // The generator's error text is streamed so a failure here is diagnosable.
+ auto result = writer::wgsl::Generate(&dst, options);
+ ASSERT_TRUE(result.success) << result.error;
+ auto dst_wgsl = result.wgsl;
+ ASSERT_EQ(src_wgsl, dst_wgsl);
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/module_test.cc b/src/tint/lang/wgsl/ast/module_test.cc
new file mode 100644
index 0000000..744f86f
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/module_test.cc
@@ -0,0 +1,158 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gmock/gmock.h"
+#include "gtest/gtest-spi.h"
+#include "src/tint/clone_context.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+namespace tint::ast {
+namespace {
+
+using ModuleTest = TestHelper;
+
+TEST_F(ModuleTest, Creation) {  // a program built from an empty builder has no functions
+ EXPECT_EQ(Program(std::move(*this)).AST().Functions().Length(), 0u);
+}
+
+TEST_F(ModuleTest, LookupFunction) {  // a declared function is found by its symbol
+ auto* func = Func("main", {}, ty.f32(), {});
+
+ Program program(std::move(*this));
+ EXPECT_EQ(func, program.AST().Functions().Find(program.Symbols().Get("main")));
+}
+
+TEST_F(ModuleTest, LookupFunctionMissing) {  // an undeclared name yields nullptr
+ Program program(std::move(*this));
+ EXPECT_EQ(nullptr, program.AST().Functions().Find(program.Symbols().Get("Missing")));
+}
+
+TEST_F(ModuleTest, Assert_Null_GlobalVariable) {  // adding a null variable must ICE
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder builder;
+ builder.AST().AddGlobalVariable(nullptr);
+ },
+ "internal compiler error");
+}
+
+TEST_F(ModuleTest, Assert_Null_TypeDecl) {  // adding a null type declaration must ICE
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder builder;
+ builder.AST().AddTypeDecl(nullptr);
+ },
+ "internal compiler error");
+}
+
+TEST_F(ModuleTest, Assert_DifferentProgramID_Function) {  // function built by b2, added to b1
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.AST().AddFunction(b2.create<Function>(b2.Ident("func"), utils::Empty, b2.ty.f32(),
+ b2.Block(), utils::Empty, utils::Empty));
+ },
+ "internal compiler error");
+}
+
+TEST_F(ModuleTest, Assert_DifferentProgramID_GlobalVariable) {  // variable built by b2
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.AST().AddGlobalVariable(b2.Var("var", b2.ty.i32(), builtin::AddressSpace::kPrivate));
+ },
+ "internal compiler error");
+}
+
+TEST_F(ModuleTest, Assert_Null_Function) {  // adding a null function must ICE
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder builder;
+ builder.AST().AddFunction(nullptr);
+ },
+ "internal compiler error");
+}
+
+TEST_F(ModuleTest, CloneOrder) {  // declarations inserted during a clone keep their position
+ // Create a program with a function, alias decl and var decl.
+ Program p = [] {
+ ProgramBuilder b;
+ b.Func("F", {}, b.ty.void_(), {});
+ b.Alias("A", b.ty.u32());
+ b.GlobalVar("V", b.ty.i32(), builtin::AddressSpace::kPrivate);
+ return Program(std::move(b));
+ }();
+
+ // Clone the program, using ReplaceAll() to create new module-scope
+ // declarations. We want to test that these are added just before the
+ // declaration that triggered the ReplaceAll().
+ ProgramBuilder cloned;
+ CloneContext ctx(&cloned, &p);
+ ctx.ReplaceAll([&](const Function*) -> const Function* {
+ ctx.dst->Alias("inserted_before_F", cloned.ty.u32());
+ return nullptr;
+ });
+ ctx.ReplaceAll([&](const ast::Alias*) -> const ast::Alias* {
+ ctx.dst->Alias("inserted_before_A", cloned.ty.u32());
+ return nullptr;
+ });
+ ctx.ReplaceAll([&](const Variable*) -> const Variable* {
+ ctx.dst->Alias("inserted_before_V", cloned.ty.u32());
+ return nullptr;
+ });
+ ctx.Clone();
+
+ auto& decls = cloned.AST().GlobalDeclarations();
+ ASSERT_EQ(decls.Length(), 6u);
+ EXPECT_TRUE(decls[1]->Is<Function>());
+ EXPECT_TRUE(decls[3]->Is<ast::Alias>());
+ EXPECT_TRUE(decls[5]->Is<Variable>());
+
+ ASSERT_TRUE(decls[0]->Is<ast::Alias>());
+ ASSERT_TRUE(decls[2]->Is<ast::Alias>());
+ ASSERT_TRUE(decls[4]->Is<ast::Alias>());
+
+ ASSERT_EQ(decls[0]->As<ast::Alias>()->name->symbol.Name(), "inserted_before_F");
+ ASSERT_EQ(decls[2]->As<ast::Alias>()->name->symbol.Name(), "inserted_before_A");
+ ASSERT_EQ(decls[4]->As<ast::Alias>()->name->symbol.Name(), "inserted_before_V");
+}
+
+TEST_F(ModuleTest, Directives) {  // directives appear in both the ordered and per-kind lists
+ auto* enable_1 = Enable(builtin::Extension::kF16);
+ auto* diagnostic_1 = DiagnosticDirective(builtin::DiagnosticSeverity::kWarning, "foo");
+ auto* enable_2 = Enable(builtin::Extension::kChromiumExperimentalFullPtrParameters);
+ auto* diagnostic_2 = DiagnosticDirective(builtin::DiagnosticSeverity::kOff, "bar");
+
+ this->SetResolveOnBuild(false);
+ Program program(std::move(*this));
+ EXPECT_THAT(program.AST().GlobalDeclarations(), ::testing::ContainerEq(utils::Vector{
+ enable_1,
+ diagnostic_1,
+ enable_2,
+ diagnostic_2,
+ }));
+ EXPECT_THAT(program.AST().Enables(), ::testing::ContainerEq(utils::Vector{
+ enable_1,
+ enable_2,
+ }));
+ EXPECT_THAT(program.AST().DiagnosticDirectives(), ::testing::ContainerEq(utils::Vector{
+ diagnostic_1,
+ diagnostic_2,
+ }));
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/must_use_attribute.cc b/src/tint/lang/wgsl/ast/must_use_attribute.cc
new file mode 100644
index 0000000..2b2e5e3
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/must_use_attribute.cc
@@ -0,0 +1,38 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/must_use_attribute.h"
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::MustUseAttribute);
+
+namespace tint::ast {
+
+MustUseAttribute::MustUseAttribute(ProgramID pid, NodeID nid, const Source& src)
+ : Base(pid, nid, src) {}
+
+MustUseAttribute::~MustUseAttribute() = default;
+
+std::string MustUseAttribute::Name() const {
+ return "must_use";
+}
+
+const MustUseAttribute* MustUseAttribute::Clone(CloneContext* ctx) const {
+    // The source is cloned into a local before create() so the call order is deterministic.
+    auto new_src = ctx->Clone(source);
+    return ctx->dst->create<MustUseAttribute>(new_src);
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/must_use_attribute.h b/src/tint/lang/wgsl/ast/must_use_attribute.h
new file mode 100644
index 0000000..6cc0124
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/must_use_attribute.h
@@ -0,0 +1,46 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_MUST_USE_ATTRIBUTE_H_
+#define SRC_TINT_LANG_WGSL_AST_MUST_USE_ATTRIBUTE_H_
+
+#include <string>
+
+#include "src/tint/lang/wgsl/ast/attribute.h"
+
+namespace tint::ast {
+
+/// The `must_use` attribute. The class declares no data members of its own.
+class MustUseAttribute final : public utils::Castable<MustUseAttribute, Attribute> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ MustUseAttribute(ProgramID pid, NodeID nid, const Source& src);
+ ~MustUseAttribute() override;  // Destructor
+
+ /// @returns the WGSL name for the attribute, "must_use"
+ std::string Name() const override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const MustUseAttribute* Clone(CloneContext* ctx) const override;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_MUST_USE_ATTRIBUTE_H_
diff --git a/src/tint/lang/wgsl/ast/node.cc b/src/tint/lang/wgsl/ast/node.cc
new file mode 100644
index 0000000..3384682
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/node.cc
@@ -0,0 +1,26 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/node.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::Node);
+
+namespace tint::ast {
+
+Node::Node(ProgramID pid, NodeID nid, const Source& src)
+ : program_id(pid), node_id(nid), source(src) {}
+
+Node::~Node() = default;
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/node.h b/src/tint/lang/wgsl/ast/node.h
new file mode 100644
index 0000000..4873c05
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/node.h
@@ -0,0 +1,63 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_NODE_H_
+#define SRC_TINT_LANG_WGSL_AST_NODE_H_
+
+#include <string>
+
+#include "src/tint/clone_context.h"
+#include "src/tint/lang/wgsl/ast/node_id.h"
+
+namespace tint::ast {
+
+/// AST base class node, holding the owning program ID, the node ID and the source.
+class Node : public utils::Castable<Node, Cloneable> {
+ public:
+ ~Node() override;  // Destructor
+
+ /// The identifier of the program that owns this node
+ const ProgramID program_id;
+
+ /// The node identifier, unique for the program.
+ const NodeID node_id;
+
+ /// The node source data
+ const Source source;
+
+ protected:
+ /// Create a new node. Protected, so only derived classes can construct a Node.
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the input source for the node
+ Node(ProgramID pid, NodeID nid, const Source& src);
+
+ private:
+ Node(const Node&) = delete;  // nodes cannot be copy-constructed
+ Node(Node&&) = delete;  // nodes cannot be move-constructed
+};
+
+} // namespace tint::ast
+
+namespace tint {
+
+/// @param node a pointer to an AST node, which may be null
+/// @returns the ProgramID of `node`, or a default-constructed ProgramID if `node` is null
+inline ProgramID ProgramIDOf(const ast::Node* node) {
+    return (node == nullptr) ? ProgramID() : node->program_id;
+}
+
+} // namespace tint
+
+#endif // SRC_TINT_LANG_WGSL_AST_NODE_H_
diff --git a/src/tint/lang/wgsl/ast/node_id.h b/src/tint/lang/wgsl/ast/node_id.h
new file mode 100644
index 0000000..ec8b635
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/node_id.h
@@ -0,0 +1,41 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_NODE_ID_H_
+#define SRC_TINT_LANG_WGSL_AST_NODE_ID_H_
+
+#include <stddef.h>
+
+namespace tint::ast {
+
+/// NodeID is a unique node identifier for a given Program.
+/// NodeIDs are sequentially allocated, starting at 0.
+struct NodeID {
+ /// Equality operator
+ /// @param other the other NodeID
+ /// @returns true if both NodeIDs hold the same value
+ bool operator==(NodeID other) const { return value == other.value; }
+
+ /// Less-than comparison operator
+ /// @param other the other NodeID
+ /// @returns true if this NodeID comes before the other (its value is smaller)
+ bool operator<(NodeID other) const { return value < other.value; }
+
+ /// The numerical value for the node identifier
+ size_t value = 0;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_NODE_ID_H_
diff --git a/src/tint/lang/wgsl/ast/override.cc b/src/tint/lang/wgsl/ast/override.cc
new file mode 100644
index 0000000..1239623
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/override.cc
@@ -0,0 +1,49 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/override.h"
+
+#include <utility>
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::Override);
+
+namespace tint::ast {
+
+Override::Override(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const Identifier* n,
+ Type ty,
+ const Expression* init,
+ utils::VectorRef<const Attribute*> attrs)
+ : Base(pid, nid, src, n, ty, init, std::move(attrs)) {}
+
+Override::~Override() = default;
+
+const char* Override::Kind() const {
+ return "override";
+}
+
+const Override* Override::Clone(CloneContext* ctx) const {
+    auto new_src = ctx->Clone(source);  // each field is cloned into a local, in this order
+    auto* new_name = ctx->Clone(name);
+    auto new_ty = ctx->Clone(type);
+    auto* new_init = ctx->Clone(initializer);
+    auto new_attrs = ctx->Clone(attributes);
+    return ctx->dst->create<Override>(new_src, new_name, new_ty, new_init, std::move(new_attrs));
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/override.h b/src/tint/lang/wgsl/ast/override.h
new file mode 100644
index 0000000..bcedfbd
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/override.h
@@ -0,0 +1,66 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_OVERRIDE_H_
+#define SRC_TINT_LANG_WGSL_AST_OVERRIDE_H_
+
+#include <string>
+
+#include "src/tint/lang/wgsl/ast/variable.h"
+
+namespace tint::ast {
+
+/// An "override" declaration - a name for a pipeline-overridable constant.
+/// Examples:
+///
+/// ```
+/// override radius : i32 = 2; // Can be overridden by name.
+/// @id(5) override width : i32 = 2; // Can be overridden by ID.
+/// override scale : f32; // No default - must be overridden.
+/// ```
+/// @see https://www.w3.org/TR/WGSL/#override-decls
+class Override final : public utils::Castable<Override, Variable> {
+ public:
+ /// Create an 'override' pipeline-overridable constant.
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param source the variable source
+ /// @param name the variable name
+ /// @param type the declared variable type
+ /// @param initializer the initializer expression
+ /// @param attributes the variable attributes
+ Override(ProgramID pid,
+ NodeID nid,
+ const Source& source,
+ const Identifier* name,
+ Type type,
+ const Expression* initializer,
+ utils::VectorRef<const Attribute*> attributes);
+
+ /// Destructor
+ ~Override() override;
+
+ /// @returns "override"
+ const char* Kind() const override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const Override* Clone(CloneContext* ctx) const override;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_OVERRIDE_H_
diff --git a/src/tint/lang/wgsl/ast/parameter.cc b/src/tint/lang/wgsl/ast/parameter.cc
new file mode 100644
index 0000000..4f0d598
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/parameter.cc
@@ -0,0 +1,47 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/parameter.h"
+
+#include <utility>
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::Parameter);
+
+namespace tint::ast {
+
+Parameter::Parameter(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const Identifier* n,
+ Type ty,
+ utils::VectorRef<const Attribute*> attrs)
+ : Base(pid, nid, src, n, ty, nullptr, std::move(attrs)) {}  // nullptr: no initializer
+
+Parameter::~Parameter() = default;
+
+const char* Parameter::Kind() const {
+ return "parameter";
+}
+
+const Parameter* Parameter::Clone(CloneContext* ctx) const {
+    auto new_src = ctx->Clone(source);  // each field is cloned into a local, in this order
+    auto* new_name = ctx->Clone(name);
+    auto new_ty = ctx->Clone(type);
+    auto new_attrs = ctx->Clone(attributes);
+    return ctx->dst->create<Parameter>(new_src, new_name, new_ty, std::move(new_attrs));
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/parameter.h b/src/tint/lang/wgsl/ast/parameter.h
new file mode 100644
index 0000000..4ebe26b
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/parameter.h
@@ -0,0 +1,65 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_PARAMETER_H_
+#define SRC_TINT_LANG_WGSL_AST_PARAMETER_H_
+
+#include <vector>
+
+#include "src/tint/lang/wgsl/ast/variable.h"
+
+namespace tint::ast {
+
+/// A formal parameter to a function - a name for a typed value to be passed into a function.
+/// Example:
+///
+/// ```
+/// fn twice(a: i32) -> i32 { // "a:i32" is the formal parameter
+/// return a + a;
+/// }
+/// ```
+/// NOTE(review): the anchor below names creation-time constants, not parameters - confirm.
+/// @see https://www.w3.org/TR/WGSL/#creation-time-consts
+class Parameter final : public utils::Castable<Parameter, Variable> {
+ public:
+ /// Create a formal function parameter. Takes no initializer, unlike Override.
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param source the variable source
+ /// @param name the variable name
+ /// @param type the declared variable type
+ /// @param attributes the variable attributes
+ Parameter(ProgramID pid,
+ NodeID nid,
+ const Source& source,
+ const Identifier* name,
+ Type type,
+ utils::VectorRef<const Attribute*> attributes);
+
+ /// Destructor
+ ~Parameter() override;
+
+ /// @returns "parameter"
+ const char* Kind() const override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const Parameter* Clone(CloneContext* ctx) const override;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_PARAMETER_H_
diff --git a/src/tint/lang/wgsl/ast/phony_expression.cc b/src/tint/lang/wgsl/ast/phony_expression.cc
new file mode 100644
index 0000000..d41dd5c
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/phony_expression.cc
@@ -0,0 +1,34 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/phony_expression.h"
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::PhonyExpression);
+
+namespace tint::ast {
+
+PhonyExpression::PhonyExpression(ProgramID pid, NodeID nid, const Source& src)
+ : Base(pid, nid, src) {}
+
+PhonyExpression::~PhonyExpression() = default;
+
+const PhonyExpression* PhonyExpression::Clone(CloneContext* ctx) const {
+    // The source is cloned into a local before create() so the call order is deterministic.
+    auto new_src = ctx->Clone(source);
+    return ctx->dst->create<PhonyExpression>(new_src);
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/phony_expression.h b/src/tint/lang/wgsl/ast/phony_expression.h
new file mode 100644
index 0000000..067e237
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/phony_expression.h
@@ -0,0 +1,44 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_PHONY_EXPRESSION_H_
+#define SRC_TINT_LANG_WGSL_AST_PHONY_EXPRESSION_H_
+
+#include "src/tint/lang/wgsl/ast/expression.h"
+
+namespace tint::ast {
+
+/// Represents the `_` of a phony assignment `_ = <expr>`
+/// @see https://www.w3.org/TR/WGSL/#phony-assignment-section
+class PhonyExpression final : public utils::Castable<PhonyExpression, Expression> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ PhonyExpression(ProgramID pid, NodeID nid, const Source& src);
+
+ /// Destructor
+ ~PhonyExpression() override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const PhonyExpression* Clone(CloneContext* ctx) const override;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_PHONY_EXPRESSION_H_
diff --git a/src/tint/lang/wgsl/ast/phony_expression_test.cc b/src/tint/lang/wgsl/ast/phony_expression_test.cc
new file mode 100644
index 0000000..099eada
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/phony_expression_test.cc
@@ -0,0 +1,40 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+namespace tint::ast {
+namespace {
+
+using PhonyExpressionTest = TestHelper;  // fixture for the PhonyExpression tests below
+
+TEST_F(PhonyExpressionTest, Creation) {
+    EXPECT_NE(Phony(), nullptr);  // the builder must hand back a node
+}
+
+TEST_F(PhonyExpressionTest, Creation_WithSource) {
+    auto* p = Phony(Source{{20, 2}});
+    // The location given at creation must be readable back from the node.
+    auto src = p->source;
+    EXPECT_EQ(src.range.begin.line, 20u);
+    EXPECT_EQ(src.range.begin.column, 2u);
+}
+
+TEST_F(PhonyExpressionTest, IsPhony) {
+    auto* p = Phony();
+    EXPECT_TRUE(p->Is<PhonyExpression>());
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/pipeline_stage.cc b/src/tint/lang/wgsl/ast/pipeline_stage.cc
new file mode 100644
index 0000000..5cf1f05
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/pipeline_stage.cc
@@ -0,0 +1,41 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/pipeline_stage.h"
+
+namespace tint::ast {
+
+/// Writes the lower-case name of a PipelineStage to a stream.
+/// The switch has no `default` label, so a value matching none of the
+/// enumerators writes nothing.
+utils::StringStream& operator<<(utils::StringStream& out, PipelineStage stage) {
+    switch (stage) {
+        case PipelineStage::kNone:
+            out << "none";
+            break;
+        case PipelineStage::kVertex:
+            out << "vertex";
+            break;
+        case PipelineStage::kFragment:
+            out << "fragment";
+            break;
+        case PipelineStage::kCompute:
+            out << "compute";
+            break;
+    }
+    // `out` is returned so that calls can be chained.
+    return out;
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/pipeline_stage.h b/src/tint/lang/wgsl/ast/pipeline_stage.h
new file mode 100644
index 0000000..8ddc450
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/pipeline_stage.h
@@ -0,0 +1,32 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_PIPELINE_STAGE_H_
+#define SRC_TINT_LANG_WGSL_AST_PIPELINE_STAGE_H_
+
+#include "src/tint/utils/string_stream.h"
+
+namespace tint::ast {
+
+/// The pipeline stage. kNone is -1; kVertex, kFragment and kCompute follow as 0, 1 and 2.
+enum class PipelineStage { kNone = -1, kVertex, kFragment, kCompute };
+
+/// @param out the stream to write to
+/// @param stage the PipelineStage
+/// @return the stream so calls can be chained
+utils::StringStream& operator<<(utils::StringStream& out, PipelineStage stage);
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_PIPELINE_STAGE_H_
diff --git a/src/tint/lang/wgsl/ast/return_statement.cc b/src/tint/lang/wgsl/ast/return_statement.cc
new file mode 100644
index 0000000..6b666db
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/return_statement.cc
@@ -0,0 +1,43 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/return_statement.h"
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::ReturnStatement);
+
+namespace tint::ast {
+
+ReturnStatement::ReturnStatement(ProgramID pid, NodeID nid, const Source& src)
+ : Base(pid, nid, src), value(nullptr) {}  // a return statement with no value
+
+ReturnStatement::ReturnStatement(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const Expression* val)
+ : Base(pid, nid, src), value(val) {  // val must not come from another program
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, value, program_id);
+}
+
+ReturnStatement::~ReturnStatement() = default;
+
+const ReturnStatement* ReturnStatement::Clone(CloneContext* ctx) const {
+    // Source first, then value: both are cloned into locals so the order is fixed.
+    auto new_src = ctx->Clone(source);
+    auto* new_value = ctx->Clone(value);
+    return ctx->dst->create<ReturnStatement>(new_src, new_value);
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/return_statement.h b/src/tint/lang/wgsl/ast/return_statement.h
new file mode 100644
index 0000000..d5d3e40
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/return_statement.h
@@ -0,0 +1,54 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_RETURN_STATEMENT_H_
+#define SRC_TINT_LANG_WGSL_AST_RETURN_STATEMENT_H_
+
+#include "src/tint/lang/wgsl/ast/expression.h"
+#include "src/tint/lang/wgsl/ast/statement.h"
+
+namespace tint::ast {
+
+/// A return statement, with an optional returned expression held in `value`.
+class ReturnStatement final : public utils::Castable<ReturnStatement, Statement> {
+ public:
+ /// Constructor for a return statement with no value (`value` is set to null)
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ ReturnStatement(ProgramID pid, NodeID nid, const Source& src);
+
+ /// Constructor for a return statement with a value
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ /// @param value the return value
+ ReturnStatement(ProgramID pid, NodeID nid, const Source& src, const Expression* value);
+
+ /// Destructor
+ ~ReturnStatement() override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const ReturnStatement* Clone(CloneContext* ctx) const override;
+
+ /// The value returned. May be null.
+ const Expression* const value;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_RETURN_STATEMENT_H_
diff --git a/src/tint/lang/wgsl/ast/return_statement_test.cc b/src/tint/lang/wgsl/ast/return_statement_test.cc
new file mode 100644
index 0000000..7e00c3b
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/return_statement_test.cc
@@ -0,0 +1,66 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/return_statement.h"
+
+#include "gtest/gtest-spi.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+namespace tint::ast {
+namespace {
+
+using ReturnStatementTest = TestHelper;
+
+TEST_F(ReturnStatementTest, Creation) {
+    auto* returned = Expr("expr");
+    auto* stmt = create<ReturnStatement>(returned);
+    // The statement must hold exactly the expression it was created with.
+    EXPECT_EQ(stmt->value, returned);
+}
+
+TEST_F(ReturnStatementTest, Creation_WithSource) {
+    auto* stmt = create<ReturnStatement>(Source{Source::Location{20, 2}});
+    const auto& begin = stmt->source.range.begin;
+    EXPECT_EQ(begin.line, 20u);
+    EXPECT_EQ(begin.column, 2u);
+}
+
+TEST_F(ReturnStatementTest, IsReturn) {
+    auto* stmt = create<ReturnStatement>();
+    EXPECT_TRUE(stmt->Is<ReturnStatement>());
+}
+
+TEST_F(ReturnStatementTest, WithoutValue) {
+    auto* stmt = create<ReturnStatement>();
+    EXPECT_EQ(stmt->value, nullptr);
+}
+
+TEST_F(ReturnStatementTest, WithValue) {
+    auto* returned = Expr("expr");
+    auto* stmt = create<ReturnStatement>(returned);
+    EXPECT_NE(stmt->value, nullptr);
+}
+
+TEST_F(ReturnStatementTest, Assert_DifferentProgramID_Expr) {
+    EXPECT_FATAL_FAILURE(
+        {
+            ProgramBuilder owner;
+            ProgramBuilder other;
+            owner.create<ReturnStatement>(other.Expr(true));
+        },
+        "internal compiler error");
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/stage_attribute.cc b/src/tint/lang/wgsl/ast/stage_attribute.cc
new file mode 100644
index 0000000..aaa361c
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/stage_attribute.cc
@@ -0,0 +1,40 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/stage_attribute.h"
+
+#include <string>
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::StageAttribute);
+
+namespace tint::ast {
+
+StageAttribute::StageAttribute(ProgramID pid, NodeID nid, const Source& src, PipelineStage s)
+ : Base(pid, nid, src), stage(s) {}
+
+StageAttribute::~StageAttribute() = default;
+
+std::string StageAttribute::Name() const {
+ return "stage";
+}
+
+const StageAttribute* StageAttribute::Clone(CloneContext* ctx) const {
+    // Only the source goes through the clone context; `stage` is passed on as-is.
+    auto new_src = ctx->Clone(source);
+    return ctx->dst->create<StageAttribute>(new_src, stage);
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/stage_attribute.h b/src/tint/lang/wgsl/ast/stage_attribute.h
new file mode 100644
index 0000000..708e850
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/stage_attribute.h
@@ -0,0 +1,51 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_STAGE_ATTRIBUTE_H_
+#define SRC_TINT_LANG_WGSL_AST_STAGE_ATTRIBUTE_H_
+
+#include <string>
+
+#include "src/tint/lang/wgsl/ast/attribute.h"
+#include "src/tint/lang/wgsl/ast/pipeline_stage.h"
+
+namespace tint::ast {
+
+/// A pipeline stage attribute
+class StageAttribute final : public utils::Castable<StageAttribute, Attribute> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param source the source of this attribute
+ /// @param stage the pipeline stage
+ StageAttribute(ProgramID pid, NodeID nid, const Source& source, PipelineStage stage);
+ ~StageAttribute() override;
+
+ /// @returns the WGSL name for the attribute, "stage"
+ std::string Name() const override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const StageAttribute* Clone(CloneContext* ctx) const override;
+
+ /// The pipeline stage
+ const PipelineStage stage;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_STAGE_ATTRIBUTE_H_
diff --git a/src/tint/lang/wgsl/ast/stage_attribute_test.cc b/src/tint/lang/wgsl/ast/stage_attribute_test.cc
new file mode 100644
index 0000000..1a81a55
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/stage_attribute_test.cc
@@ -0,0 +1,31 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/stage_attribute.h"
+
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+#include "src/tint/lang/wgsl/ast/workgroup_attribute.h"
+
+namespace tint::ast {
+namespace {
+
+using StageAttributeTest = TestHelper;
+
+TEST_F(StageAttributeTest, Creation_1param) {  // The stage passed to create<>() is stored.
+ auto* d = create<StageAttribute>(PipelineStage::kFragment);
+ EXPECT_EQ(d->stage, PipelineStage::kFragment);
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/statement.cc b/src/tint/lang/wgsl/ast/statement.cc
new file mode 100644
index 0000000..91bbc2f
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/statement.cc
@@ -0,0 +1,76 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/statement.h"
+
+#include "src/tint/lang/wgsl/ast/assignment_statement.h"
+#include "src/tint/lang/wgsl/ast/break_statement.h"
+#include "src/tint/lang/wgsl/ast/call_statement.h"
+#include "src/tint/lang/wgsl/ast/continue_statement.h"
+#include "src/tint/lang/wgsl/ast/discard_statement.h"
+#include "src/tint/lang/wgsl/ast/if_statement.h"
+#include "src/tint/lang/wgsl/ast/loop_statement.h"
+#include "src/tint/lang/wgsl/ast/return_statement.h"
+#include "src/tint/lang/wgsl/ast/switch_statement.h"
+#include "src/tint/lang/wgsl/ast/variable_decl_statement.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::Statement);
+
+namespace tint::ast {
+
+Statement::Statement(ProgramID pid, NodeID nid, const Source& src) : Base(pid, nid, src) {}
+
+Statement::~Statement() = default;
+
+const char* Statement::Name() const {  // Selects the name by testing the node's concrete type.
+ if (Is<AssignmentStatement>()) {
+ return "assignment statement";
+ }
+ if (Is<BlockStatement>()) {
+ return "block statement";
+ }
+ if (Is<BreakStatement>()) {
+ return "break statement";
+ }
+ if (Is<CaseStatement>()) {
+ return "case statement";
+ }
+ if (Is<CallStatement>()) {
+ return "function call";
+ }
+ if (Is<ContinueStatement>()) {
+ return "continue statement";
+ }
+ if (Is<DiscardStatement>()) {
+ return "discard statement";
+ }
+ if (Is<IfStatement>()) {
+ return "if statement";
+ }
+ if (Is<LoopStatement>()) {
+ return "loop statement";
+ }
+ if (Is<ReturnStatement>()) {
+ return "return statement";
+ }
+ if (Is<SwitchStatement>()) {
+ return "switch statement";
+ }
+ if (Is<VariableDeclStatement>()) {
+ return "variable declaration";
+ }
+ return "statement";  // Generic fallback for statement types not listed above.
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/statement.h b/src/tint/lang/wgsl/ast/statement.h
new file mode 100644
index 0000000..9318fd0
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/statement.h
@@ -0,0 +1,42 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_STATEMENT_H_
+#define SRC_TINT_LANG_WGSL_AST_STATEMENT_H_
+
+#include <vector>
+
+#include "src/tint/lang/wgsl/ast/node.h"
+
+namespace tint::ast {
+
+/// Base class for AST statement nodes
+class Statement : public utils::Castable<Statement, Node> {
+ public:
+ ~Statement() override;
+
+ /// @returns the human readable name for the statement type.
+ const char* Name() const;
+
+ protected:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of the statement
+ Statement(ProgramID pid, NodeID nid, const Source& src);
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_STATEMENT_H_
diff --git a/src/tint/lang/wgsl/ast/stride_attribute.cc b/src/tint/lang/wgsl/ast/stride_attribute.cc
new file mode 100644
index 0000000..dc30855
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/stride_attribute.cc
@@ -0,0 +1,40 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/stride_attribute.h"
+
+#include <string>
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::StrideAttribute);
+
+namespace tint::ast {
+
+StrideAttribute::StrideAttribute(ProgramID pid, NodeID nid, const Source& src, uint32_t s)
+ : Base(pid, nid, src), stride(s) {}
+
+StrideAttribute::~StrideAttribute() = default;
+
+std::string StrideAttribute::Name() const {
+ return "stride";
+}
+
+const StrideAttribute* StrideAttribute::Clone(CloneContext* ctx) const {
+ // Clone the source before the create() call: argument evaluation order is unspecified in C++
+ auto src = ctx->Clone(source);
+ return ctx->dst->create<StrideAttribute>(src, stride);
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/stride_attribute.h b/src/tint/lang/wgsl/ast/stride_attribute.h
new file mode 100644
index 0000000..d7875d8
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/stride_attribute.h
@@ -0,0 +1,52 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_STRIDE_ATTRIBUTE_H_
+#define SRC_TINT_LANG_WGSL_AST_STRIDE_ATTRIBUTE_H_
+
+#include <string>
+
+#include "src/tint/lang/wgsl/ast/attribute.h"
+#include "src/tint/lang/wgsl/ast/internal_attribute.h"
+
+namespace tint::ast {
+
+/// A stride attribute used by the SPIR-V reader for strided arrays and
+/// matrices.
+class StrideAttribute final : public utils::Castable<StrideAttribute, Attribute> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ /// @param stride the stride value
+ StrideAttribute(ProgramID pid, NodeID nid, const Source& src, uint32_t stride);
+ ~StrideAttribute() override;
+
+ /// @returns the WGSL name for the attribute ("stride")
+ std::string Name() const override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const StrideAttribute* Clone(CloneContext* ctx) const override;
+
+ /// The stride value
+ const uint32_t stride;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_STRIDE_ATTRIBUTE_H_
diff --git a/src/tint/lang/wgsl/ast/stride_attribute_test.cc b/src/tint/lang/wgsl/ast/stride_attribute_test.cc
new file mode 100644
index 0000000..32a180c
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/stride_attribute_test.cc
@@ -0,0 +1,37 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+namespace tint::ast {
+namespace {
+
+using StrideAttributeTest = TestHelper;
+
+TEST_F(StrideAttributeTest, Creation) {  // The stride value is stored as given.
+ auto* d = create<StrideAttribute>(2u);
+ EXPECT_EQ(2u, d->stride);
+}
+
+TEST_F(StrideAttributeTest, Source) {  // The source range is stored as given.
+ auto* d = create<StrideAttribute>(
+ Source{Source::Range{Source::Location{1, 2}, Source::Location{3, 4}}}, 2u);
+ EXPECT_EQ(d->source.range.begin.line, 1u);
+ EXPECT_EQ(d->source.range.begin.column, 2u);
+ EXPECT_EQ(d->source.range.end.line, 3u);
+ EXPECT_EQ(d->source.range.end.column, 4u);
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/struct.cc b/src/tint/lang/wgsl/ast/struct.cc
new file mode 100644
index 0000000..8543885
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/struct.cc
@@ -0,0 +1,53 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/struct.h"
+
+#include <string>
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::Struct);
+
+namespace tint::ast {
+
+Struct::Struct(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const Identifier* n,
+ utils::VectorRef<const StructMember*> m,
+ utils::VectorRef<const Attribute*> attrs)
+ : Base(pid, nid, src, n), members(std::move(m)), attributes(std::move(attrs)) {
+ for (auto* mem : members) {  // Each member must be non-null and pass the program ID check.
+ TINT_ASSERT(AST, mem);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, mem, program_id);
+ }
+ for (auto* attr : attributes) {  // The same checks apply to each attribute.
+ TINT_ASSERT(AST, attr);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, attr, program_id);
+ }
+}
+
+Struct::~Struct() = default;
+
+const Struct* Struct::Clone(CloneContext* ctx) const {
+ // Clone arguments before the create() call: argument evaluation order is unspecified in C++
+ auto src = ctx->Clone(source);
+ auto n = ctx->Clone(name);
+ auto mem = ctx->Clone(members);
+ auto attrs = ctx->Clone(attributes);
+ return ctx->dst->create<Struct>(src, n, std::move(mem), std::move(attrs));
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/struct.h b/src/tint/lang/wgsl/ast/struct.h
new file mode 100644
index 0000000..2513834
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/struct.h
@@ -0,0 +1,63 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_STRUCT_H_
+#define SRC_TINT_LANG_WGSL_AST_STRUCT_H_
+
+#include <string>
+#include <utility>
+
+#include "src/tint/lang/wgsl/ast/attribute.h"
+#include "src/tint/lang/wgsl/ast/struct_member.h"
+#include "src/tint/lang/wgsl/ast/type_decl.h"
+#include "src/tint/utils/vector.h"
+
+namespace tint::ast {
+
+/// A struct type declaration.
+class Struct final : public utils::Castable<Struct, TypeDecl> {
+ public:
+ /// Create a new struct declaration
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ /// @param name The name of the structure
+ /// @param members The struct members
+ /// @param attributes The struct attributes
+ Struct(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const Identifier* name,
+ utils::VectorRef<const StructMember*> members,
+ utils::VectorRef<const Attribute*> attributes);
+
+ /// Destructor
+ ~Struct() override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const Struct* Clone(CloneContext* ctx) const override;
+
+ /// The struct members
+ const utils::Vector<const StructMember*, 8> members;
+
+ /// The struct attributes
+ const utils::Vector<const Attribute*, 4> attributes;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_STRUCT_H_
diff --git a/src/tint/lang/wgsl/ast/struct_member.cc b/src/tint/lang/wgsl/ast/struct_member.cc
new file mode 100644
index 0000000..bdd3cb6
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/struct_member.cc
@@ -0,0 +1,53 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/struct_member.h"
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::StructMember);
+
+namespace tint::ast {
+
+StructMember::StructMember(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const Identifier* n,
+ Type ty,
+ utils::VectorRef<const Attribute*> attrs)
+
+ : Base(pid, nid, src), name(n), type(ty), attributes(std::move(attrs)) {
+ TINT_ASSERT(AST, name);
+ if (name) {
+ TINT_ASSERT(AST, !name->Is<TemplatedIdentifier>());  // Member names must not be templated.
+ }
+ TINT_ASSERT(AST, type);
+ for (auto* attr : attributes) {  // Each attribute must be non-null and pass the program ID check.
+ TINT_ASSERT(AST, attr);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, attr, program_id);
+ }
+}
+
+StructMember::~StructMember() = default;
+
+const StructMember* StructMember::Clone(CloneContext* ctx) const {
+ // Clone arguments before the create() call: argument evaluation order is unspecified in C++
+ auto src = ctx->Clone(source);
+ auto n = ctx->Clone(name);
+ auto ty = ctx->Clone(type);
+ auto attrs = ctx->Clone(attributes);
+ return ctx->dst->create<StructMember>(src, n, ty, std::move(attrs));
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/struct_member.h b/src/tint/lang/wgsl/ast/struct_member.h
new file mode 100644
index 0000000..f8efe8b
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/struct_member.h
@@ -0,0 +1,68 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_STRUCT_MEMBER_H_
+#define SRC_TINT_LANG_WGSL_AST_STRUCT_MEMBER_H_
+
+#include <utility>
+
+#include "src/tint/lang/wgsl/ast/attribute.h"
+#include "src/tint/lang/wgsl/ast/type.h"
+
+// Forward declarations
+namespace tint::ast {
+class Identifier;
+} // namespace tint::ast
+
+namespace tint::ast {
+
+/// A struct member.
+class StructMember final : public utils::Castable<StructMember, Node> {
+ public:
+ /// Create a new struct member
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ /// @param name The struct member name
+ /// @param type The struct member type
+ /// @param attributes The struct member attributes
+ StructMember(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const Identifier* name,
+ Type type,
+ utils::VectorRef<const Attribute*> attributes);
+
+ /// Destructor
+ ~StructMember() override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const StructMember* Clone(CloneContext* ctx) const override;
+
+ /// The member name
+ const Identifier* const name;
+
+ /// The member type
+ const Type type;
+
+ /// The member attributes
+ const utils::Vector<const Attribute*, 4> attributes;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_STRUCT_MEMBER_H_
diff --git a/src/tint/lang/wgsl/ast/struct_member_align_attribute.cc b/src/tint/lang/wgsl/ast/struct_member_align_attribute.cc
new file mode 100644
index 0000000..4362dc2
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/struct_member_align_attribute.cc
@@ -0,0 +1,45 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/struct_member_align_attribute.h"
+
+#include <string>
+
+#include "src/tint/clone_context.h"
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::StructMemberAlignAttribute);
+
+namespace tint::ast {
+
+StructMemberAlignAttribute::StructMemberAlignAttribute(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const Expression* a)
+ : Base(pid, nid, src), expr(a) {}
+
+StructMemberAlignAttribute::~StructMemberAlignAttribute() = default;
+
+std::string StructMemberAlignAttribute::Name() const {
+ return "align";
+}
+
+const StructMemberAlignAttribute* StructMemberAlignAttribute::Clone(CloneContext* ctx) const {
+ // Clone arguments before the create() call: argument evaluation order is unspecified in C++
+ auto src = ctx->Clone(source);
+ auto* expr_ = ctx->Clone(expr);
+ return ctx->dst->create<StructMemberAlignAttribute>(src, expr_);
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/struct_member_align_attribute.h b/src/tint/lang/wgsl/ast/struct_member_align_attribute.h
new file mode 100644
index 0000000..34e0f48
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/struct_member_align_attribute.h
@@ -0,0 +1,56 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_STRUCT_MEMBER_ALIGN_ATTRIBUTE_H_
+#define SRC_TINT_LANG_WGSL_AST_STRUCT_MEMBER_ALIGN_ATTRIBUTE_H_
+
+#include <stddef.h>
+#include <string>
+
+#include "src/tint/lang/wgsl/ast/attribute.h"
+#include "src/tint/lang/wgsl/ast/expression.h"
+
+namespace tint::ast {
+
+/// A struct member align attribute
+class StructMemberAlignAttribute final
+ : public utils::Castable<StructMemberAlignAttribute, Attribute> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ /// @param align the align expression
+ StructMemberAlignAttribute(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const Expression* align);
+ ~StructMemberAlignAttribute() override;
+
+ /// @returns the WGSL name for the attribute ("align")
+ std::string Name() const override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const StructMemberAlignAttribute* Clone(CloneContext* ctx) const override;
+
+ /// The align expression
+ const Expression* const expr;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_STRUCT_MEMBER_ALIGN_ATTRIBUTE_H_
diff --git a/src/tint/lang/wgsl/ast/struct_member_align_attribute_test.cc b/src/tint/lang/wgsl/ast/struct_member_align_attribute_test.cc
new file mode 100644
index 0000000..8ef77f5
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/struct_member_align_attribute_test.cc
@@ -0,0 +1,32 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/struct_member_align_attribute.h"
+
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+namespace tint::ast {
+namespace {
+
+using StructMemberAlignAttributeTest = TestHelper;
+
+TEST_F(StructMemberAlignAttributeTest, Creation) {  // The attribute keeps the given expression.
+ auto* val = Expr("ident");
+ auto* d = create<StructMemberAlignAttribute>(val);
+ EXPECT_EQ(val, d->expr);
+ EXPECT_TRUE(d->expr->Is<IdentifierExpression>());
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/struct_member_offset_attribute.cc b/src/tint/lang/wgsl/ast/struct_member_offset_attribute.cc
new file mode 100644
index 0000000..995b119
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/struct_member_offset_attribute.cc
@@ -0,0 +1,44 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/struct_member_offset_attribute.h"
+
+#include <string>
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::StructMemberOffsetAttribute);
+
+namespace tint::ast {
+
+StructMemberOffsetAttribute::StructMemberOffsetAttribute(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const Expression* exp)
+ : Base(pid, nid, src), expr(exp) {}
+
+StructMemberOffsetAttribute::~StructMemberOffsetAttribute() = default;
+
+std::string StructMemberOffsetAttribute::Name() const {
+ return "offset";
+}
+
+const StructMemberOffsetAttribute* StructMemberOffsetAttribute::Clone(CloneContext* ctx) const {
+ // Clone arguments before the create() call: argument evaluation order is unspecified in C++
+ auto src = ctx->Clone(source);
+ auto expr_ = ctx->Clone(expr);
+ return ctx->dst->create<StructMemberOffsetAttribute>(src, expr_);
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/struct_member_offset_attribute.h b/src/tint/lang/wgsl/ast/struct_member_offset_attribute.h
new file mode 100644
index 0000000..2b9f2d9
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/struct_member_offset_attribute.h
@@ -0,0 +1,64 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_STRUCT_MEMBER_OFFSET_ATTRIBUTE_H_
+#define SRC_TINT_LANG_WGSL_AST_STRUCT_MEMBER_OFFSET_ATTRIBUTE_H_
+
+#include <string>
+
+#include "src/tint/lang/wgsl/ast/attribute.h"
+#include "src/tint/lang/wgsl/ast/expression.h"
+
+namespace tint::ast {
+
+/// A struct member offset attribute
+/// @note The WGSL spec removed the `@offset(n)` attribute in favor of `@size(n)`
+/// and `@align(n)` in https://github.com/gpuweb/gpuweb/pull/1447. However
+/// this attribute is kept because the SPIR-V reader has to deal with absolute
+/// offsets, and transforming these to size / align is complex and can be done
+/// in a number of ways. The Resolver is responsible for consuming the size and
+/// align attributes and transforming these into absolute offsets. It is
+/// trivial for the Resolver to handle `@offset(n)` or `@size(n)` /
+/// `@align(n)` attributes, so this is what we do, keeping all the layout
+/// logic in one place.
+class StructMemberOffsetAttribute final
+ : public utils::Castable<StructMemberOffsetAttribute, Attribute> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ /// @param expr the offset expression
+ StructMemberOffsetAttribute(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const Expression* expr);
+ ~StructMemberOffsetAttribute() override;
+
+ /// @returns the WGSL name for the attribute ("offset")
+ std::string Name() const override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const StructMemberOffsetAttribute* Clone(CloneContext* ctx) const override;
+
+ /// The offset expression
+ const Expression* const expr;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_STRUCT_MEMBER_OFFSET_ATTRIBUTE_H_
diff --git a/src/tint/lang/wgsl/ast/struct_member_offset_attribute_test.cc b/src/tint/lang/wgsl/ast/struct_member_offset_attribute_test.cc
new file mode 100644
index 0000000..bc665e7
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/struct_member_offset_attribute_test.cc
@@ -0,0 +1,30 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+namespace tint::ast {
+namespace {
+
+using namespace tint::number_suffixes; // NOLINT
+using StructMemberOffsetAttributeTest = TestHelper;
+
+TEST_F(StructMemberOffsetAttributeTest, Creation) {  // MemberOffset(2_u) holds an int literal 2.
+ auto* d = MemberOffset(2_u);
+ ASSERT_TRUE(d->expr->Is<IntLiteralExpression>());
+ EXPECT_EQ(2u, d->expr->As<IntLiteralExpression>()->value);
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/struct_member_size_attribute.cc b/src/tint/lang/wgsl/ast/struct_member_size_attribute.cc
new file mode 100644
index 0000000..26e64ed
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/struct_member_size_attribute.cc
@@ -0,0 +1,45 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/struct_member_size_attribute.h"
+
+#include <string>
+
+#include "src/tint/clone_context.h"
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::StructMemberSizeAttribute);
+
+namespace tint::ast {
+
+StructMemberSizeAttribute::StructMemberSizeAttribute(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const Expression* exp)
+ : Base(pid, nid, src), expr(exp) {}
+
+StructMemberSizeAttribute::~StructMemberSizeAttribute() = default;
+
+std::string StructMemberSizeAttribute::Name() const {
+ return "size";
+}
+
+const StructMemberSizeAttribute* StructMemberSizeAttribute::Clone(CloneContext* ctx) const {
+ // Clone arguments before the create() call: argument evaluation order is unspecified in C++
+ auto src = ctx->Clone(source);
+ auto expr_ = ctx->Clone(expr);
+ return ctx->dst->create<StructMemberSizeAttribute>(src, expr_);
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/struct_member_size_attribute.h b/src/tint/lang/wgsl/ast/struct_member_size_attribute.h
new file mode 100644
index 0000000..5b818a2
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/struct_member_size_attribute.h
@@ -0,0 +1,53 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_STRUCT_MEMBER_SIZE_ATTRIBUTE_H_
+#define SRC_TINT_LANG_WGSL_AST_STRUCT_MEMBER_SIZE_ATTRIBUTE_H_
+
+#include <stddef.h>
+#include <string>
+
+#include "src/tint/lang/wgsl/ast/attribute.h"
+#include "src/tint/lang/wgsl/ast/expression.h"
+
+namespace tint::ast {
+
+/// A struct member size attribute. Holds the expression that gives the member's size.
+class StructMemberSizeAttribute final
+ : public utils::Castable<StructMemberSizeAttribute, Attribute> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ /// @param expr the size expression
+ StructMemberSizeAttribute(ProgramID pid, NodeID nid, const Source& src, const Expression* expr);
+ ~StructMemberSizeAttribute() override;  // Destructor
+
+ /// @returns the WGSL name for the attribute
+ std::string Name() const override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const StructMemberSizeAttribute* Clone(CloneContext* ctx) const override;
+
+ /// The size expression
+ const Expression* const expr;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_STRUCT_MEMBER_SIZE_ATTRIBUTE_H_
diff --git a/src/tint/lang/wgsl/ast/struct_member_size_attribute_test.cc b/src/tint/lang/wgsl/ast/struct_member_size_attribute_test.cc
new file mode 100644
index 0000000..f7f4c2c
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/struct_member_size_attribute_test.cc
@@ -0,0 +1,32 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/struct_member_size_attribute.h"
+
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+namespace tint::ast {
+namespace {
+
+using namespace tint::number_suffixes; // NOLINT
+using StructMemberSizeAttributeTest = TestHelper;
+
+TEST_F(StructMemberSizeAttributeTest, Creation) {
+ auto* d = MemberSize(2_u);
+ ASSERT_TRUE(d->expr->Is<IntLiteralExpression>());
+ EXPECT_EQ(2u, d->expr->As<IntLiteralExpression>()->value);
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/struct_member_test.cc b/src/tint/lang/wgsl/ast/struct_member_test.cc
new file mode 100644
index 0000000..2b82094
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/struct_member_test.cc
@@ -0,0 +1,96 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gtest/gtest-spi.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+namespace tint::ast {
+namespace {
+
+using namespace tint::number_suffixes; // NOLINT
+using StructMemberTest = TestHelper;
+
+TEST_F(StructMemberTest, Creation) {
+ auto* st = Member("a", ty.i32(), utils::Vector{MemberSize(4_a)});
+ CheckIdentifier(st->name, "a");
+ CheckIdentifier(st->type, "i32");
+ EXPECT_EQ(st->attributes.Length(), 1u);
+ EXPECT_TRUE(st->attributes[0]->Is<StructMemberSizeAttribute>());
+ EXPECT_EQ(st->source.range.begin.line, 0u);
+ EXPECT_EQ(st->source.range.begin.column, 0u);
+ EXPECT_EQ(st->source.range.end.line, 0u);
+ EXPECT_EQ(st->source.range.end.column, 0u);
+}
+
+TEST_F(StructMemberTest, CreationWithSource) {
+ auto* st = Member(Source{Source::Range{Source::Location{27, 4}, Source::Location{27, 8}}}, "a",
+ ty.i32());
+ CheckIdentifier(st->name, "a");
+ CheckIdentifier(st->type, "i32");
+ EXPECT_EQ(st->attributes.Length(), 0u);
+ EXPECT_EQ(st->source.range.begin.line, 27u);
+ EXPECT_EQ(st->source.range.begin.column, 4u);
+ EXPECT_EQ(st->source.range.end.line, 27u);
+ EXPECT_EQ(st->source.range.end.column, 8u);
+}
+
+TEST_F(StructMemberTest, Assert_Null_Name) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b;
+ b.Member(static_cast<Identifier*>(nullptr), b.ty.i32());
+ },
+ "internal compiler error");
+}
+
+TEST_F(StructMemberTest, Assert_Null_Type) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b;
+ b.Member("a", Type{});
+ },
+ "internal compiler error");
+}
+
+TEST_F(StructMemberTest, Assert_Null_Attribute) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b;
+ b.Member("a", b.ty.i32(), utils::Vector{b.MemberSize(4_a), nullptr});
+ },
+ "internal compiler error");
+}
+
+TEST_F(StructMemberTest, Assert_DifferentProgramID_Symbol) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.Member(b2.Sym("a"), b1.ty.i32(), utils::Vector{b1.MemberSize(4_a)});
+ },
+ "internal compiler error");
+}
+
+TEST_F(StructMemberTest, Assert_DifferentProgramID_Attribute) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.Member("a", b1.ty.i32(), utils::Vector{b2.MemberSize(4_a)});
+ },
+ "internal compiler error");
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/struct_test.cc b/src/tint/lang/wgsl/ast/struct_test.cc
new file mode 100644
index 0000000..499723b
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/struct_test.cc
@@ -0,0 +1,114 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/struct.h"
+#include "gtest/gtest-spi.h"
+#include "src/tint/lang/wgsl/ast/alias.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+#include "src/tint/lang/wgsl/ast/transform/add_block_attribute.h"
+
+namespace tint::ast {
+namespace {
+
+using AstStructTest = TestHelper;
+using BlockAttribute = transform::AddBlockAttribute::BlockAttribute;
+
+TEST_F(AstStructTest, Creation) {
+ auto name = Sym("s");
+ auto* s = Structure(name, utils::Vector{Member("a", ty.i32())});
+ EXPECT_EQ(s->name->symbol, name);
+ EXPECT_EQ(s->members.Length(), 1u);
+ EXPECT_TRUE(s->attributes.IsEmpty());
+ EXPECT_EQ(s->source.range.begin.line, 0u);
+ EXPECT_EQ(s->source.range.begin.column, 0u);
+ EXPECT_EQ(s->source.range.end.line, 0u);
+ EXPECT_EQ(s->source.range.end.column, 0u);
+}
+
+TEST_F(AstStructTest, Creation_WithAttributes) {
+ auto name = Sym("s");
+
+ auto* s = Structure(name, utils::Vector{Member("a", ty.i32())},
+ utils::Vector{
+ ASTNodes().Create<BlockAttribute>(ID(), AllocateNodeID()),
+ });
+ EXPECT_EQ(s->name->symbol, name);
+ EXPECT_EQ(s->members.Length(), 1u);
+ ASSERT_EQ(s->attributes.Length(), 1u);
+ EXPECT_TRUE(s->attributes[0]->Is<BlockAttribute>());
+ EXPECT_EQ(s->source.range.begin.line, 0u);
+ EXPECT_EQ(s->source.range.begin.column, 0u);
+ EXPECT_EQ(s->source.range.end.line, 0u);
+ EXPECT_EQ(s->source.range.end.column, 0u);
+}
+
+TEST_F(AstStructTest, CreationWithSourceAndAttributes) {
+ auto name = Sym("s");
+ auto* s = Structure(Source{Source::Range{Source::Location{27, 4}, Source::Location{27, 8}}},
+ name, utils::Vector{Member("a", ty.i32())},
+ utils::Vector{ASTNodes().Create<BlockAttribute>(ID(), AllocateNodeID())});
+ EXPECT_EQ(s->name->symbol, name);
+ EXPECT_EQ(s->members.Length(), 1u);
+ ASSERT_EQ(s->attributes.Length(), 1u);
+ EXPECT_TRUE(s->attributes[0]->Is<BlockAttribute>());
+ EXPECT_EQ(s->source.range.begin.line, 27u);
+ EXPECT_EQ(s->source.range.begin.column, 4u);
+ EXPECT_EQ(s->source.range.end.line, 27u);
+ EXPECT_EQ(s->source.range.end.column, 8u);
+}
+
+TEST_F(AstStructTest, Assert_Null_StructMember) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b;
+ b.Structure(b.Sym("S"), utils::Vector{b.Member("a", b.ty.i32()), nullptr},
+ utils::Empty);
+ },
+ "internal compiler error");
+}
+
+TEST_F(AstStructTest, Assert_Null_Attribute) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b;
+ b.Structure(b.Sym("S"), utils::Vector{b.Member("a", b.ty.i32())},
+ utils::Vector<const Attribute*, 1>{nullptr});
+ },
+ "internal compiler error");
+}
+
+TEST_F(AstStructTest, Assert_DifferentProgramID_StructMember) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.Structure(b1.Sym("S"), utils::Vector{b2.Member("a", b2.ty.i32())});
+ },
+ "internal compiler error");
+}
+
+TEST_F(AstStructTest, Assert_DifferentProgramID_Attribute) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.Structure(
+ b1.Sym("S"), utils::Vector{b1.Member("a", b1.ty.i32())},
+ utils::Vector{b2.ASTNodes().Create<BlockAttribute>(b2.ID(), b2.AllocateNodeID())});
+ },
+ "internal compiler error");
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/switch_statement.cc b/src/tint/lang/wgsl/ast/switch_statement.cc
new file mode 100644
index 0000000..858001c
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/switch_statement.cc
@@ -0,0 +1,66 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/switch_statement.h"
+
+#include <utility>
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::SwitchStatement);
+
+namespace tint::ast {
+
+SwitchStatement::SwitchStatement(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const Expression* cond,
+ utils::VectorRef<const CaseStatement*> b,
+ utils::VectorRef<const Attribute*> stmt_attrs,
+ utils::VectorRef<const Attribute*> body_attrs)
+ : Base(pid, nid, src),
+ condition(cond),
+ body(std::move(b)),
+ attributes(std::move(stmt_attrs)),
+ body_attributes(std::move(body_attrs)) {
+ TINT_ASSERT(AST, condition);  // a switch statement must always have a condition
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, condition, program_id);
+ for (auto* stmt : body) {  // every case statement must be non-null
+ TINT_ASSERT(AST, stmt);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, stmt, program_id);
+ }
+ for (auto* attr : attributes) {  // every statement attribute must be non-null
+ TINT_ASSERT(AST, attr);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, attr, program_id);
+ }
+ for (auto* attr : body_attributes) {  // every body attribute must be non-null
+ TINT_ASSERT(AST, attr);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, attr, program_id);
+ }
+}
+
+SwitchStatement::~SwitchStatement() = default;
+
+const SwitchStatement* SwitchStatement::Clone(CloneContext* ctx) const {
+ // Clone into locals first: the evaluation order of create()'s arguments would be unspecified
+ auto src = ctx->Clone(source);
+ auto* cond = ctx->Clone(condition);
+ auto b = ctx->Clone(body);
+ auto attrs = ctx->Clone(attributes);
+ auto body_attrs = ctx->Clone(body_attributes);
+ return ctx->dst->create<SwitchStatement>(src, cond, std::move(b), std::move(attrs),
+ std::move(body_attrs));
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/switch_statement.h b/src/tint/lang/wgsl/ast/switch_statement.h
new file mode 100644
index 0000000..ca958c5
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/switch_statement.h
@@ -0,0 +1,67 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_SWITCH_STATEMENT_H_
+#define SRC_TINT_LANG_WGSL_AST_SWITCH_STATEMENT_H_
+
+#include "src/tint/lang/wgsl/ast/case_statement.h"
+#include "src/tint/lang/wgsl/ast/expression.h"
+
+namespace tint::ast {
+
+/// A switch statement: a condition, a list of case statements and two attribute lists
+class SwitchStatement final : public utils::Castable<SwitchStatement, Statement> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ /// @param condition the switch condition (must not be nullptr)
+ /// @param body the switch body
+ /// @param stmt_attributes the switch statement attributes
+ /// @param body_attributes the switch body attributes
+ SwitchStatement(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const Expression* condition,
+ utils::VectorRef<const CaseStatement*> body,
+ utils::VectorRef<const Attribute*> stmt_attributes,
+ utils::VectorRef<const Attribute*> body_attributes);
+
+ /// Destructor
+ ~SwitchStatement() override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const SwitchStatement* Clone(CloneContext* ctx) const override;
+
+ /// The switch condition. The constructor asserts that this is not nullptr.
+ const Expression* const condition;
+
+ /// The switch body: the list of case statements
+ const utils::Vector<const CaseStatement*, 4> body;
+ SwitchStatement(const SwitchStatement&) = delete;  // not copy-constructible
+
+ /// The attribute list for the statement
+ const utils::Vector<const Attribute*, 1> attributes;
+
+ /// The attribute list for the body
+ const utils::Vector<const Attribute*, 1> body_attributes;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_SWITCH_STATEMENT_H_
diff --git a/src/tint/lang/wgsl/ast/switch_statement_test.cc b/src/tint/lang/wgsl/ast/switch_statement_test.cc
new file mode 100644
index 0000000..a3846dc
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/switch_statement_test.cc
@@ -0,0 +1,138 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/switch_statement.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest-spi.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+using namespace tint::number_suffixes; // NOLINT
+
+namespace tint::ast {
+namespace {
+
+using SwitchStatementTest = TestHelper;
+
+TEST_F(SwitchStatementTest, Creation) {
+ auto* case_stmt = create<CaseStatement>(utils::Vector{CaseSelector(1_u)}, Block());
+ auto* ident = Expr("ident");
+ utils::Vector body{case_stmt};
+
+ auto* stmt = create<SwitchStatement>(ident, body, utils::Empty, utils::Empty);
+ EXPECT_EQ(stmt->condition, ident);
+ ASSERT_EQ(stmt->body.Length(), 1u);
+ EXPECT_EQ(stmt->body[0], case_stmt);
+}
+
+TEST_F(SwitchStatementTest, Creation_WithSource) {
+ auto* ident = Expr("ident");
+ auto* stmt = create<SwitchStatement>(Source{Source::Location{20, 2}}, ident, utils::Empty,
+ utils::Empty, utils::Empty);
+ auto src = stmt->source;
+ EXPECT_EQ(src.range.begin.line, 20u);
+ EXPECT_EQ(src.range.begin.column, 2u);
+}
+
+TEST_F(SwitchStatementTest, Creation_WithAttributes) {
+ auto* attr1 = DiagnosticAttribute(builtin::DiagnosticSeverity::kOff, "foo");
+ auto* attr2 = DiagnosticAttribute(builtin::DiagnosticSeverity::kOff, "bar");
+ auto* ident = Expr("ident");
+ auto* stmt =
+ create<SwitchStatement>(ident, utils::Empty, utils::Vector{attr1, attr2}, utils::Empty);
+
+ EXPECT_THAT(stmt->attributes, testing::ElementsAre(attr1, attr2));
+}
+
+TEST_F(SwitchStatementTest, Creation_WithBodyAttributes) {
+ auto* attr1 = DiagnosticAttribute(builtin::DiagnosticSeverity::kOff, "foo");
+ auto* attr2 = DiagnosticAttribute(builtin::DiagnosticSeverity::kOff, "bar");
+ auto* ident = Expr("ident");
+ auto* stmt =
+ create<SwitchStatement>(ident, utils::Empty, utils::Empty, utils::Vector{attr1, attr2});
+
+ EXPECT_THAT(stmt->body_attributes, testing::ElementsAre(attr1, attr2));
+}
+
+TEST_F(SwitchStatementTest, IsSwitch) {
+ utils::Vector lit{CaseSelector(2_i)};
+ auto* ident = Expr("ident");
+ utils::Vector body{create<CaseStatement>(lit, Block())};
+
+ auto* stmt = create<SwitchStatement>(ident, body, utils::Empty, utils::Empty);
+ EXPECT_TRUE(stmt->Is<SwitchStatement>());
+}
+
+TEST_F(SwitchStatementTest, Assert_Null_Condition) {
+ using CaseStatementList = utils::Vector<const CaseStatement*, 2>;
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b;
+ CaseStatementList cases;
+ cases.Push(
+ b.create<CaseStatement>(utils::Vector{b.CaseSelector(b.Expr(1_i))}, b.Block()));
+ b.create<SwitchStatement>(nullptr, cases, utils::Empty, utils::Empty);
+ },
+ "internal compiler error");
+}
+
+TEST_F(SwitchStatementTest, Assert_Null_CaseStatement) {
+ using CaseStatementList = utils::Vector<const CaseStatement*, 2>;
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b;
+ b.create<SwitchStatement>(b.Expr(true), CaseStatementList{nullptr}, utils::Empty,
+ utils::Empty);
+ },
+ "internal compiler error");
+}
+
+TEST_F(SwitchStatementTest, Assert_DifferentProgramID_Condition) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.create<SwitchStatement>(b2.Expr(true),
+ utils::Vector{
+ b1.create<CaseStatement>(
+ utils::Vector{
+ b1.CaseSelector(b1.Expr(1_i)),
+ },
+ b1.Block()),
+ },
+ utils::Empty, utils::Empty);
+ },
+ "internal compiler error");
+}
+
+TEST_F(SwitchStatementTest, Assert_DifferentProgramID_CaseStatement) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.create<SwitchStatement>(b1.Expr(true),
+ utils::Vector{
+ b2.create<CaseStatement>(
+ utils::Vector{
+ b2.CaseSelector(b2.Expr(1_i)),
+ },
+ b2.Block()),
+ },
+ utils::Empty, utils::Empty);
+ },
+ "internal compiler error");
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/templated_identifier.cc b/src/tint/lang/wgsl/ast/templated_identifier.cc
new file mode 100644
index 0000000..3bb34d4
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/templated_identifier.cc
@@ -0,0 +1,52 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/templated_identifier.h"
+
+#include <utility>
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::TemplatedIdentifier);
+
+namespace tint::ast {
+
+TemplatedIdentifier::TemplatedIdentifier(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const Symbol& sym,
+ utils::VectorRef<const Expression*> args,
+ utils::VectorRef<const Attribute*> attrs)
+ : Base(pid, nid, src, sym), arguments(std::move(args)), attributes(std::move(attrs)) {
+ TINT_ASSERT(AST, !arguments.IsEmpty()); // With no template arguments, use a plain Identifier.
+ for (auto* arg : arguments) {  // NOTE(review): args use the non-_IF_VALID check; confirm
+ TINT_ASSERT_PROGRAM_IDS_EQUAL(AST, arg, program_id);
+ }
+ for (auto* attr : attributes) {  // note: elements are not null-checked in either loop
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, attr, program_id);
+ }
+}
+
+TemplatedIdentifier::~TemplatedIdentifier() = default;
+
+const TemplatedIdentifier* TemplatedIdentifier::Clone(CloneContext* ctx) const {
+ // Clone into locals first: the evaluation order of create()'s arguments would be unspecified
+ auto src = ctx->Clone(source);
+ auto sym = ctx->Clone(symbol);
+ auto args = ctx->Clone(arguments);
+ auto attrs = ctx->Clone(attributes);
+ return ctx->dst->create<TemplatedIdentifier>(src, sym, std::move(args), std::move(attrs));
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/templated_identifier.h b/src/tint/lang/wgsl/ast/templated_identifier.h
new file mode 100644
index 0000000..086e2cd
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/templated_identifier.h
@@ -0,0 +1,62 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TEMPLATED_IDENTIFIER_H_
+#define SRC_TINT_LANG_WGSL_AST_TEMPLATED_IDENTIFIER_H_
+
+#include "src/tint/lang/wgsl/ast/identifier.h"
+
+// Forward declarations
+namespace tint::ast {
+class Attribute;
+class Expression;
+} // namespace tint::ast
+
+namespace tint::ast {
+
+/// A templated identifier: an identifier that carries one or more template arguments
+class TemplatedIdentifier final : public utils::Castable<TemplatedIdentifier, Identifier> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ /// @param sym the symbol for the identifier
+ /// @param args the template arguments (the constructor asserts that this is not empty)
+ /// @param attrs the identifier attributes
+ TemplatedIdentifier(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const Symbol& sym,
+ utils::VectorRef<const Expression*> args,
+ utils::VectorRef<const Attribute*> attrs);
+
+ /// Destructor
+ ~TemplatedIdentifier() override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext` `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const TemplatedIdentifier* Clone(CloneContext* ctx) const override;
+
+ /// The template arguments
+ const utils::Vector<const Expression*, 3> arguments;
+
+ /// Attributes on the identifier
+ const utils::Vector<const Attribute*, 0> attributes;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_TEMPLATED_IDENTIFIER_H_
diff --git a/src/tint/lang/wgsl/ast/templated_identifier_test.cc b/src/tint/lang/wgsl/ast/templated_identifier_test.cc
new file mode 100644
index 0000000..01f243f
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/templated_identifier_test.cc
@@ -0,0 +1,84 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gtest/gtest-spi.h"
+
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+namespace tint::ast {
+namespace {
+
+using namespace tint::number_suffixes; // NOLINT
+
+using TemplatedIdentifierTest = TestHelper;
+
+TEST_F(TemplatedIdentifierTest, Creation) {
+ auto* i = Ident("ident", 1_a, Add("x", "y"), false, "x");
+ EXPECT_EQ(i->symbol, Symbols().Get("ident"));
+ auto* t = i->As<TemplatedIdentifier>();
+ ASSERT_NE(t, nullptr);
+ ASSERT_EQ(t->arguments.Length(), 4u);
+ EXPECT_TRUE(t->arguments[0]->Is<IntLiteralExpression>());
+ EXPECT_TRUE(t->arguments[1]->Is<BinaryExpression>());
+ EXPECT_TRUE(t->arguments[2]->Is<BoolLiteralExpression>());
+ EXPECT_TRUE(t->arguments[3]->Is<IdentifierExpression>());
+}
+
+TEST_F(TemplatedIdentifierTest, Creation_WithSource) {
+ auto* i = Ident(Source{{20, 2}}, "ident", 1_a, Add("x", "y"), false, "x");
+ EXPECT_EQ(i->symbol, Symbols().Get("ident"));
+ auto* t = i->As<TemplatedIdentifier>();
+ ASSERT_NE(t, nullptr);
+ ASSERT_EQ(t->arguments.Length(), 4u);
+ EXPECT_TRUE(t->arguments[0]->Is<IntLiteralExpression>());
+ EXPECT_TRUE(t->arguments[1]->Is<BinaryExpression>());
+ EXPECT_TRUE(t->arguments[2]->Is<BoolLiteralExpression>());
+ EXPECT_TRUE(t->arguments[3]->Is<IdentifierExpression>());
+
+ auto src = i->source;
+ EXPECT_EQ(src.range.begin.line, 20u);
+ EXPECT_EQ(src.range.begin.column, 2u);
+}
+
+TEST_F(TemplatedIdentifierTest, Assert_InvalidSymbol) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b;
+ b.Expr("");  // NOTE(review): empty name and no template arguments are passed; confirm intent
+ },
+ "internal compiler error");
+}
+
+TEST_F(TemplatedIdentifierTest, Assert_DifferentProgramID_Symbol) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.Ident(b2.Sym("b2"), b1.Expr(1_i));
+ },
+ "internal compiler error");
+}
+
+TEST_F(TemplatedIdentifierTest, Assert_DifferentProgramID_TemplateArg) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.Ident("b1", b2.Expr(1_i));
+ },
+ "internal compiler error");
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/test_helper.h b/src/tint/lang/wgsl/ast/test_helper.h
new file mode 100644
index 0000000..3f3656dc
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/test_helper.h
@@ -0,0 +1,170 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TEST_HELPER_H_
+#define SRC_TINT_LANG_WGSL_AST_TEST_HELPER_H_
+
+#include <tuple>
+#include <utility>
+
+#include "gtest/gtest.h"
+#include "src/tint/program_builder.h"
+
+namespace tint::ast {
+
+/// Helper base class for testing. Derives from both `BASE` and ProgramBuilder.
+template <typename BASE>
+class TestHelperBase : public BASE, public ProgramBuilder {};
+
+/// Helper class for testing that derives from testing::Test.
+using TestHelper = TestHelperBase<testing::Test>;
+
+/// Helper class for testing that derives from testing::TestWithParam<T>.
+template <typename T>
+using TestParamHelper = TestHelperBase<testing::TestWithParam<T>>;
+
+/// A structure to hold a TemplatedIdentifier matcher, used by the CheckIdentifier() test helper
+template <typename... ARGS>
+struct TemplatedIdentifierMatcher {
+ /// The expected name of the TemplatedIdentifier (a non-owning view of the caller's string)
+ std::string_view name;
+ /// The expected template arguments of the TemplatedIdentifier, in order
+ std::tuple<ARGS...> args;
+};
+
+/// Deduction guide for TemplatedIdentifierMatcher
+template <typename... ARGS>
+TemplatedIdentifierMatcher(std::string_view, std::tuple<ARGS...>&&)
+ -> TemplatedIdentifierMatcher<ARGS...>;
+
+/// A helper function for building a TemplatedIdentifierMatcher
+/// @param name the expected name of the TemplatedIdentifier
+/// @param args the expected template arguments, stored by value via std::make_tuple()
+/// @return a TemplatedIdentifierMatcher holding `name` and `args`
+template <typename... ARGS>
+auto Template(std::string_view name, ARGS&&... args) {
+ return TemplatedIdentifierMatcher{name, std::make_tuple(std::forward<ARGS>(args)...)};
+}
+
+/// A traits helper for determining whether the type T is a TemplatedIdentifierMatcher.
+template <typename T>
+struct IsTemplatedIdentifierMatcher {
+ /// False: the primary template handles every T that is not a TemplatedIdentifierMatcher
+ static constexpr bool value = false;
+};
+
+/// IsTemplatedIdentifierMatcher specialization for TemplatedIdentifierMatcher.
+template <typename... ARGS>
+struct IsTemplatedIdentifierMatcher<TemplatedIdentifierMatcher<ARGS...>> {
+ /// True: this specialization is only selected for a TemplatedIdentifierMatcher
+ static constexpr bool value = true;
+};
+
+/// A testing utility for checking that an Identifier is not templated and has the expected name.
+/// @param got the identifier
+/// @param expected the expected identifier name
+template <typename... ARGS>
+void CheckIdentifier(const Identifier* got, std::string_view expected) {
+ EXPECT_FALSE(got->Is<TemplatedIdentifier>());
+ EXPECT_EQ(got->symbol.Name(), expected);
+}
+
+/// A testing utility for checking that an Identifier is a TemplatedIdentifier with the expected
+/// name and template arguments, compared pairwise in order.
+/// @param ident the identifier
+/// @param expected the expected identifier name and arguments
+template <typename... ARGS>
+void CheckIdentifier(const Identifier* ident, const TemplatedIdentifierMatcher<ARGS...>& expected) {
+ EXPECT_EQ(ident->symbol.Name(), expected.name);
+ ASSERT_TRUE(ident->Is<TemplatedIdentifier>());
+ auto* got = ident->As<TemplatedIdentifier>();
+ ASSERT_EQ(got->arguments.Length(), std::tuple_size_v<decltype(expected.args)>);
+
+ size_t arg_idx = 0;  // index of the next template argument to compare
+ auto check_arg = [&](auto&& expected_arg) {  // NOTE: ASSERT_* below only exits this lambda
+ const auto* got_arg = got->arguments[arg_idx++];
+
+ using T = std::decay_t<decltype(expected_arg)>;
+ if constexpr (utils::traits::IsStringLike<T>) {
+ ASSERT_TRUE(got_arg->Is<IdentifierExpression>());
+ CheckIdentifier(got_arg->As<IdentifierExpression>()->identifier, expected_arg);
+ } else if constexpr (IsTemplatedIdentifierMatcher<T>::value) {
+ ASSERT_TRUE(got_arg->Is<IdentifierExpression>());
+ auto* got_ident = got_arg->As<IdentifierExpression>()->identifier;
+ ASSERT_TRUE(got_ident->Is<TemplatedIdentifier>());
+ CheckIdentifier(got_ident->As<TemplatedIdentifier>(), expected_arg);
+ } else if constexpr (std::is_same_v<T, bool>) {
+ ASSERT_TRUE(got_arg->Is<BoolLiteralExpression>());
+ EXPECT_EQ(got_arg->As<BoolLiteralExpression>()->value, expected_arg);
+ } else if constexpr (std::is_same_v<T, AInt>) {
+ ASSERT_TRUE(got_arg->Is<IntLiteralExpression>());
+ EXPECT_EQ(got_arg->As<IntLiteralExpression>()->suffix,
+ IntLiteralExpression::Suffix::kNone);
+ EXPECT_EQ(AInt(got_arg->As<IntLiteralExpression>()->value), expected_arg);
+ } else if constexpr (std::is_same_v<T, i32>) {
+ ASSERT_TRUE(got_arg->Is<IntLiteralExpression>());
+ EXPECT_EQ(got_arg->As<IntLiteralExpression>()->suffix,
+ IntLiteralExpression::Suffix::kI);
+ EXPECT_EQ(i32(got_arg->As<IntLiteralExpression>()->value), expected_arg);
+ } else if constexpr (std::is_same_v<T, u32>) {
+ ASSERT_TRUE(got_arg->Is<IntLiteralExpression>());
+ EXPECT_EQ(got_arg->As<IntLiteralExpression>()->suffix,
+ IntLiteralExpression::Suffix::kU);
+ EXPECT_EQ(u32(got_arg->As<IntLiteralExpression>()->value), expected_arg);
+ } else if constexpr (std::is_same_v<T, AFloat>) {
+ ASSERT_TRUE(got_arg->Is<FloatLiteralExpression>());
+ EXPECT_EQ(got_arg->As<FloatLiteralExpression>()->suffix,
+ FloatLiteralExpression::Suffix::kNone);
+ EXPECT_EQ(AFloat(got_arg->As<FloatLiteralExpression>()->value), expected_arg);
+ } else if constexpr (std::is_same_v<T, f32>) {
+ ASSERT_TRUE(got_arg->Is<FloatLiteralExpression>());
+ EXPECT_EQ(got_arg->As<FloatLiteralExpression>()->suffix,
+ FloatLiteralExpression::Suffix::kF);
+ EXPECT_EQ(f32(got_arg->As<FloatLiteralExpression>()->value), expected_arg);
+ } else if constexpr (std::is_same_v<T, f16>) {
+ ASSERT_TRUE(got_arg->Is<FloatLiteralExpression>());
+ EXPECT_EQ(got_arg->As<FloatLiteralExpression>()->suffix,
+ FloatLiteralExpression::Suffix::kH);
+ EXPECT_EQ(f16(got_arg->As<FloatLiteralExpression>()->value), expected_arg);
+ } else {
+ FAIL() << "unhandled expected_args type";
+ }
+ };
+ std::apply([&](auto&&... args) { ((check_arg(args)), ...); }, expected.args);
+}
+
+/// A testing utility for checking that an expression is an IdentifierExpression with this name.
+/// @param expr the expression, asserted to be an IdentifierExpression
+/// @param expected the expected identifier name
+template <typename... ARGS>
+void CheckIdentifier(const Expression* expr, std::string_view expected) {
+ auto* expr_ident = expr->As<IdentifierExpression>();
+ ASSERT_NE(expr_ident, nullptr) << "expression is not a IdentifierExpression";
+ CheckIdentifier(expr_ident->identifier, expected);
+}
+
+/// A testing utility for checking that an expression is an IdentifierExpression with the expected
+/// name and template arguments.
+/// @param expr the expression, asserted to be an IdentifierExpression
+/// @param expected the expected identifier name and arguments
+template <typename... ARGS>
+void CheckIdentifier(const Expression* expr, const TemplatedIdentifierMatcher<ARGS...>& expected) {
+ auto* expr_ident = expr->As<IdentifierExpression>();
+ ASSERT_NE(expr_ident, nullptr) << "expression is not a IdentifierExpression";
+ CheckIdentifier(expr_ident->identifier, expected);
+}
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_TEST_HELPER_H_
diff --git a/src/tint/lang/wgsl/ast/test_helper_test.cc b/src/tint/lang/wgsl/ast/test_helper_test.cc
new file mode 100644
index 0000000..6f397f6
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/test_helper_test.cc
@@ -0,0 +1,42 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+namespace tint::ast {
+
+using namespace tint::number_suffixes; // NOLINT
+
+using AstCheckIdentifierTest = TestHelper;
+
+TEST_F(AstCheckIdentifierTest, NonTemplated) {
+ CheckIdentifier(Ident("abc"), "abc");
+}
+
+TEST_F(AstCheckIdentifierTest, TemplatedScalars) {
+ CheckIdentifier(Ident("abc", 1_i, 2_u, 3_f, 4_h, 5_a, 6._a, true), //
+ Template("abc", 1_i, 2_u, 3_f, 4_h, 5_a, 6._a, true));
+}
+
+TEST_F(AstCheckIdentifierTest, TemplatedIdentifiers) {
+ CheckIdentifier(Ident("abc", "one", "two", "three"), //
+ Template("abc", "one", "two", "three"));
+}
+
+TEST_F(AstCheckIdentifierTest, NestedTemplate) {
+ CheckIdentifier(Ident("abc", "pre", Ident("nested", 42_a), "post"), //
+ Template("abc", "pre", Template("nested", 42_a), "post"));
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/transform/add_block_attribute.cc b/src/tint/lang/wgsl/ast/transform/add_block_attribute.cc
new file mode 100644
index 0000000..1c53ecb
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/add_block_attribute.cc
@@ -0,0 +1,117 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/add_block_attribute.h"
+
+#include <unordered_set>
+#include <utility>
+
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/variable.h"
+#include "src/tint/utils/hashmap.h"
+#include "src/tint/utils/hashset.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::AddBlockAttribute);
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::AddBlockAttribute::BlockAttribute);
+
+namespace tint::ast::transform {
+
+AddBlockAttribute::AddBlockAttribute() = default;
+
+AddBlockAttribute::~AddBlockAttribute() = default;
+
+Transform::ApplyResult AddBlockAttribute::Apply(const Program* src,
+ const DataMap&,
+ DataMap&) const {
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+
+ auto& sem = src->Sem();
+
+ // A map from a type in the source program to the block-decorated wrapper struct that contains
+ // it in the destination program. Variables of the same type share a single wrapper.
+ utils::Hashmap<const type::Type*, const Struct*, 8> wrapper_structs;
+
+ // Process global 'var' declarations that are buffers.
+ bool made_changes = false;
+ for (auto* global : src->AST().GlobalVariables()) {
+ auto* var = sem.Get(global);
+ if (!builtin::IsHostShareable(var->AddressSpace())) {
+ // Not declared in a host-shareable address space
+ continue;
+ }
+
+ made_changes = true;
+
+ auto* ty = var->Type()->UnwrapRef();
+ auto* str = ty->As<sem::Struct>();
+
+ // Always try to wrap the buffer type in a struct. The only case where this is not possible
+ // is a struct without a fixed footprint, i.e. one that has a runtime-sized array member.
+ // Note that such a struct type can only be used as the type of a storage buffer variable.
+ // Also note that any buffer struct type that may be nested inside another type must have a
+ // fixed footprint, and will therefore be wrapped.
+ bool needs_wrapping = !str || // Type is not a structure
+ str->HasFixedFootprint(); // Struct has a fixed footprint
+
+ if (needs_wrapping) {
+ const char* kMemberName = "inner";
+
+ auto* wrapper = wrapper_structs.GetOrCreate(ty, [&] {
+ auto* block = b.ASTNodes().Create<BlockAttribute>(b.ID(), b.AllocateNodeID());
+ auto wrapper_name = global->name->symbol.Name() + "_block";
+ auto* ret = b.create<Struct>(
+ b.Ident(b.Symbols().New(wrapper_name)),
+ utils::Vector{b.Member(kMemberName, CreateASTTypeFor(ctx, ty))},
+ utils::Vector{block});
+ ctx.InsertBefore(src->AST().GlobalDeclarations(), global, ret);
+ return ret;
+ });
+ ctx.Replace(global->type.expr, b.Expr(wrapper->name->symbol));
+
+ // Insert a member accessor to get the original type from the wrapper at
+ // any usage of the original variable.
+ for (auto* user : var->Users()) {
+ ctx.Replace(user->Declaration(),
+ b.MemberAccessor(ctx.Clone(global->name->symbol), kMemberName));
+ }
+ } else {
+ // Add a block attribute to this struct directly (once per variable of this type).
+ auto* block = b.ASTNodes().Create<BlockAttribute>(b.ID(), b.AllocateNodeID());
+ ctx.InsertFront(str->Declaration()->attributes, block);
+ }
+ }
+
+ if (!made_changes) {
+ return SkipTransform;
+ }
+
+ ctx.Clone();
+ return Program(std::move(b));
+}
+
+AddBlockAttribute::BlockAttribute::BlockAttribute(ProgramID pid, NodeID nid)
+ : Base(pid, nid, utils::Empty) {}
+AddBlockAttribute::BlockAttribute::~BlockAttribute() = default;
+std::string AddBlockAttribute::BlockAttribute::InternalName() const {
+ return "block";
+}
+
+const AddBlockAttribute::BlockAttribute* AddBlockAttribute::BlockAttribute::Clone(
+ CloneContext* ctx) const {
+ return ctx->dst->ASTNodes().Create<AddBlockAttribute::BlockAttribute>(
+ ctx->dst->ID(), ctx->dst->AllocateNodeID());
+}
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/add_block_attribute.h b/src/tint/lang/wgsl/ast/transform/add_block_attribute.h
new file mode 100644
index 0000000..e083b6d
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/add_block_attribute.h
@@ -0,0 +1,64 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_ADD_BLOCK_ATTRIBUTE_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_ADD_BLOCK_ATTRIBUTE_H_
+
+#include <string>
+
+#include "src/tint/lang/wgsl/ast/internal_attribute.h"
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+
+namespace tint::ast::transform {
+
+/// AddBlockAttribute is a transform that wraps the store type of a buffer in a struct if possible,
+/// then adds an `@internal(block)` attribute to the wrapper, or to the original struct otherwise.
+class AddBlockAttribute final : public utils::Castable<AddBlockAttribute, Transform> {
+ public:
+ /// BlockAttribute is an InternalAttribute that is used to decorate a
+ /// structure that is used as a buffer in SPIR-V or GLSL.
+ class BlockAttribute final : public utils::Castable<BlockAttribute, InternalAttribute> {
+ public:
+ /// Constructor
+ /// @param program_id the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ BlockAttribute(ProgramID program_id, NodeID nid);
+ /// Destructor
+ ~BlockAttribute() override;
+
+ /// @return a short description of the internal attribute which will be
+ /// displayed as `@internal(<name>)`
+ std::string InternalName() const override;
+
+ /// Performs a deep clone of this object using the CloneContext `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned object
+ const BlockAttribute* Clone(CloneContext* ctx) const override;
+ };
+
+ /// Constructor
+ AddBlockAttribute();
+
+ /// Destructor
+ ~AddBlockAttribute() override;
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_ADD_BLOCK_ATTRIBUTE_H_
diff --git a/src/tint/lang/wgsl/ast/transform/add_block_attribute_test.cc b/src/tint/lang/wgsl/ast/transform/add_block_attribute_test.cc
new file mode 100644
index 0000000..bc4126f
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/add_block_attribute_test.cc
@@ -0,0 +1,1025 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/add_block_attribute.h"
+
+#include <memory>
+#include <utility>
+
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using AddBlockAttributeTest = TransformTest;
+
+TEST_F(AddBlockAttributeTest, EmptyModule) {
+ auto* src = "";
+ auto* expect = "";
+
+ auto got = Run<AddBlockAttribute>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(AddBlockAttributeTest, Noop_UsedForPrivateVar) {
+ auto* src = R"(
+struct S {
+ f : f32,
+}
+
+var<private> p : S;
+
+@fragment
+fn main() {
+ p.f = 1.0;
+}
+)";
+ auto* expect = src;
+
+ auto got = Run<AddBlockAttribute>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(AddBlockAttributeTest, Noop_UsedForShaderIO) {
+ auto* src = R"(
+struct S {
+ @location(0)
+ f : f32,
+}
+
+@fragment
+fn main() -> S {
+ return S();
+}
+)";
+ auto* expect = src;
+
+ auto got = Run<AddBlockAttribute>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(AddBlockAttributeTest, BasicScalar) {
+ auto* src = R"(
+@group(0) @binding(0)
+var<uniform> u : f32;
+
+@fragment
+fn main() {
+ let f = u;
+}
+)";
+ auto* expect = R"(
+@internal(block)
+struct u_block {
+ inner : f32,
+}
+
+@group(0) @binding(0) var<uniform> u : u_block;
+
+@fragment
+fn main() {
+ let f = u.inner;
+}
+)";
+
+ auto got = Run<AddBlockAttribute>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(AddBlockAttributeTest, BasicArray) {
+ auto* src = R"(
+@group(0) @binding(0)
+var<uniform> u : array<vec4<f32>, 4u>;
+
+@fragment
+fn main() {
+ let a = u;
+}
+)";
+ auto* expect = R"(
+@internal(block)
+struct u_block {
+ inner : array<vec4<f32>, 4u>,
+}
+
+@group(0) @binding(0) var<uniform> u : u_block;
+
+@fragment
+fn main() {
+ let a = u.inner;
+}
+)";
+
+ auto got = Run<AddBlockAttribute>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(AddBlockAttributeTest, BasicArray_Alias) {
+ auto* src = R"(
+alias Numbers = array<vec4<f32>, 4u>;
+
+@group(0) @binding(0)
+var<uniform> u : Numbers;
+
+@fragment
+fn main() {
+ let a = u;
+}
+)";
+ auto* expect = R"(
+alias Numbers = array<vec4<f32>, 4u>;
+
+@internal(block)
+struct u_block {
+ inner : array<vec4<f32>, 4u>,
+}
+
+@group(0) @binding(0) var<uniform> u : u_block;
+
+@fragment
+fn main() {
+ let a = u.inner;
+}
+)";
+
+ auto got = Run<AddBlockAttribute>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(AddBlockAttributeTest, BasicStruct_AccessRoot) {
+ auto* src = R"(
+struct S {
+ f : f32,
+};
+
+@group(0) @binding(0)
+var<uniform> u : S;
+
+@fragment
+fn main() {
+ let f = u;
+}
+)";
+ auto* expect = R"(
+struct S {
+ f : f32,
+}
+
+@internal(block)
+struct u_block {
+ inner : S,
+}
+
+@group(0) @binding(0) var<uniform> u : u_block;
+
+@fragment
+fn main() {
+ let f = u.inner;
+}
+)";
+
+ auto got = Run<AddBlockAttribute>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(AddBlockAttributeTest, BasicStruct_Storage_AccessRoot) {
+ auto* src = R"(
+struct S {
+ f : f32,
+};
+
+@group(0) @binding(0)
+var<storage, read_write> s : S;
+
+@fragment
+fn main() {
+ let f = s;
+}
+)";
+ auto* expect = R"(
+struct S {
+ f : f32,
+}
+
+@internal(block)
+struct s_block {
+ inner : S,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : s_block;
+
+@fragment
+fn main() {
+ let f = s.inner;
+}
+)";
+
+ auto got = Run<AddBlockAttribute>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(AddBlockAttributeTest, BasicStruct_Storage_TwoUsage_AccessRoot) {
+ auto* src = R"(
+struct S {
+ f : f32,
+};
+
+@group(0) @binding(0)
+var<storage, read_write> in : S;
+
+@group(0) @binding(1)
+var<storage, read_write> out : S;
+
+@compute @workgroup_size(1)
+fn main() {
+ out = in;
+}
+)";
+ auto* expect = R"(
+struct S {
+ f : f32,
+}
+
+@internal(block)
+struct in_block {
+ inner : S,
+}
+
+@group(0) @binding(0) var<storage, read_write> in : in_block;
+
+@group(0) @binding(1) var<storage, read_write> out : in_block;
+
+@compute @workgroup_size(1)
+fn main() {
+ out.inner = in.inner;
+}
+)";
+
+ auto got = Run<AddBlockAttribute>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(AddBlockAttributeTest, BasicStruct_AccessField) {
+ auto* src = R"(
+struct S {
+ f : f32,
+};
+
+@group(0) @binding(0)
+var<uniform> u : S;
+
+@fragment
+fn main() {
+ let f = u.f;
+}
+)";
+ auto* expect = R"(
+struct S {
+ f : f32,
+}
+
+@internal(block)
+struct u_block {
+ inner : S,
+}
+
+@group(0) @binding(0) var<uniform> u : u_block;
+
+@fragment
+fn main() {
+ let f = u.inner.f;
+}
+)";
+
+ auto got = Run<AddBlockAttribute>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(AddBlockAttributeTest, BasicScalar_PushConstant) {
+ auto* src = R"(
+enable chromium_experimental_push_constant;
+var<push_constant> u : f32;
+
+@fragment
+fn main() {
+ let f = u;
+}
+)";
+ auto* expect = R"(
+enable chromium_experimental_push_constant;
+
+@internal(block)
+struct u_block {
+ inner : f32,
+}
+
+var<push_constant> u : u_block;
+
+@fragment
+fn main() {
+ let f = u.inner;
+}
+)";
+
+ auto got = Run<AddBlockAttribute>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(AddBlockAttributeTest, BasicStruct_PushConstant) {
+ auto* src = R"(
+enable chromium_experimental_push_constant;
+struct S {
+ f : f32,
+};
+var<push_constant> u : S;
+
+@fragment
+fn main() {
+ let f = u.f;
+}
+)";
+ auto* expect = R"(
+enable chromium_experimental_push_constant;
+
+struct S {
+ f : f32,
+}
+
+@internal(block)
+struct u_block {
+ inner : S,
+}
+
+var<push_constant> u : u_block;
+
+@fragment
+fn main() {
+ let f = u.inner.f;
+}
+)";
+
+ auto got = Run<AddBlockAttribute>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(AddBlockAttributeTest, Nested_OuterBuffer_InnerNotBuffer) {
+ auto* src = R"(
+struct Inner {
+ f : f32,
+};
+
+struct Outer {
+ i : Inner,
+};
+
+@group(0) @binding(0)
+var<uniform> u : Outer;
+
+@fragment
+fn main() {
+ let f = u.i.f;
+}
+)";
+ auto* expect = R"(
+struct Inner {
+ f : f32,
+}
+
+struct Outer {
+ i : Inner,
+}
+
+@internal(block)
+struct u_block {
+ inner : Outer,
+}
+
+@group(0) @binding(0) var<uniform> u : u_block;
+
+@fragment
+fn main() {
+ let f = u.inner.i.f;
+}
+)";
+
+ auto got = Run<AddBlockAttribute>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(AddBlockAttributeTest, Nested_OuterBuffer_InnerBuffer) {
+ auto* src = R"(
+struct Inner {
+ f : f32,
+};
+
+struct Outer {
+ i : Inner,
+};
+
+@group(0) @binding(0)
+var<uniform> u0 : Outer;
+
+@group(0) @binding(1)
+var<uniform> u1 : Inner;
+
+@fragment
+fn main() {
+ let f0 = u0.i.f;
+ let f1 = u1.f;
+}
+)";
+ auto* expect = R"(
+struct Inner {
+ f : f32,
+}
+
+struct Outer {
+ i : Inner,
+}
+
+@internal(block)
+struct u0_block {
+ inner : Outer,
+}
+
+@group(0) @binding(0) var<uniform> u0 : u0_block;
+
+@internal(block)
+struct u1_block {
+ inner : Inner,
+}
+
+@group(0) @binding(1) var<uniform> u1 : u1_block;
+
+@fragment
+fn main() {
+ let f0 = u0.inner.i.f;
+ let f1 = u1.inner.f;
+}
+)";
+
+ auto got = Run<AddBlockAttribute>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(AddBlockAttributeTest, Nested_OuterNotBuffer_InnerBuffer) {
+ auto* src = R"(
+struct Inner {
+ f : f32,
+};
+
+struct Outer {
+ i : Inner,
+};
+
+var<private> p : Outer;
+
+@group(0) @binding(1)
+var<uniform> u : Inner;
+
+@fragment
+fn main() {
+ let f0 = p.i.f;
+ let f1 = u.f;
+}
+)";
+ auto* expect = R"(
+struct Inner {
+ f : f32,
+}
+
+struct Outer {
+ i : Inner,
+}
+
+var<private> p : Outer;
+
+@internal(block)
+struct u_block {
+ inner : Inner,
+}
+
+@group(0) @binding(1) var<uniform> u : u_block;
+
+@fragment
+fn main() {
+ let f0 = p.i.f;
+ let f1 = u.inner.f;
+}
+)";
+
+ auto got = Run<AddBlockAttribute>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(AddBlockAttributeTest, Nested_InnerUsedForMultipleBuffers) {
+ auto* src = R"(
+struct Inner {
+ f : f32,
+};
+
+struct S {
+ i : Inner,
+};
+
+@group(0) @binding(0)
+var<uniform> u0 : S;
+
+@group(0) @binding(1)
+var<uniform> u1 : Inner;
+
+@group(0) @binding(2)
+var<uniform> u2 : Inner;
+
+@fragment
+fn main() {
+ let f0 = u0.i.f;
+ let f1 = u1.f;
+ let f2 = u2.f;
+}
+)";
+ auto* expect = R"(
+struct Inner {
+ f : f32,
+}
+
+struct S {
+ i : Inner,
+}
+
+@internal(block)
+struct u0_block {
+ inner : S,
+}
+
+@group(0) @binding(0) var<uniform> u0 : u0_block;
+
+@internal(block)
+struct u1_block {
+ inner : Inner,
+}
+
+@group(0) @binding(1) var<uniform> u1 : u1_block;
+
+@group(0) @binding(2) var<uniform> u2 : u1_block;
+
+@fragment
+fn main() {
+ let f0 = u0.inner.i.f;
+ let f1 = u1.inner.f;
+ let f2 = u2.inner.f;
+}
+)";
+
+ auto got = Run<AddBlockAttribute>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(AddBlockAttributeTest, StructInArray) {
+ auto* src = R"(
+struct S {
+ f : f32,
+};
+
+@group(0) @binding(0)
+var<uniform> u : S;
+
+@fragment
+fn main() {
+ let f = u.f;
+ let a = array<S, 4>();
+}
+)";
+ auto* expect = R"(
+struct S {
+ f : f32,
+}
+
+@internal(block)
+struct u_block {
+ inner : S,
+}
+
+@group(0) @binding(0) var<uniform> u : u_block;
+
+@fragment
+fn main() {
+ let f = u.inner.f;
+ let a = array<S, 4>();
+}
+)";
+
+ auto got = Run<AddBlockAttribute>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(AddBlockAttributeTest, StructInArray_MultipleBuffers) {
+ auto* src = R"(
+struct S {
+ f : f32,
+};
+
+@group(0) @binding(0)
+var<uniform> u0 : S;
+
+@group(0) @binding(1)
+var<uniform> u1 : S;
+
+@fragment
+fn main() {
+ let f0 = u0.f;
+ let f1 = u1.f;
+ let a = array<S, 4>();
+}
+)";
+ auto* expect = R"(
+struct S {
+ f : f32,
+}
+
+@internal(block)
+struct u0_block {
+ inner : S,
+}
+
+@group(0) @binding(0) var<uniform> u0 : u0_block;
+
+@group(0) @binding(1) var<uniform> u1 : u0_block;
+
+@fragment
+fn main() {
+ let f0 = u0.inner.f;
+ let f1 = u1.inner.f;
+ let a = array<S, 4>();
+}
+)";
+
+ auto got = Run<AddBlockAttribute>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(AddBlockAttributeTest, Aliases_Nested_OuterBuffer_InnerBuffer) {
+ auto* src = R"(
+struct Inner {
+ f : f32,
+};
+
+alias MyInner = Inner;
+
+struct Outer {
+ i : MyInner,
+};
+
+alias MyOuter = Outer;
+
+@group(0) @binding(0)
+var<uniform> u0 : MyOuter;
+
+@group(0) @binding(1)
+var<uniform> u1 : MyInner;
+
+@fragment
+fn main() {
+ let f0 = u0.i.f;
+ let f1 = u1.f;
+}
+)";
+ auto* expect = R"(
+struct Inner {
+ f : f32,
+}
+
+alias MyInner = Inner;
+
+struct Outer {
+ i : MyInner,
+}
+
+alias MyOuter = Outer;
+
+@internal(block)
+struct u0_block {
+ inner : Outer,
+}
+
+@group(0) @binding(0) var<uniform> u0 : u0_block;
+
+@internal(block)
+struct u1_block {
+ inner : Inner,
+}
+
+@group(0) @binding(1) var<uniform> u1 : u1_block;
+
+@fragment
+fn main() {
+ let f0 = u0.inner.i.f;
+ let f1 = u1.inner.f;
+}
+)";
+
+ auto got = Run<AddBlockAttribute>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(AddBlockAttributeTest, Aliases_Nested_OuterBuffer_InnerBuffer_OutOfOrder) {
+ auto* src = R"(
+@fragment
+fn main() {
+ let f0 = u0.i.f;
+ let f1 = u1.f;
+}
+
+@group(0) @binding(1)
+var<uniform> u1 : MyInner;
+
+alias MyInner = Inner;
+
+@group(0) @binding(0)
+var<uniform> u0 : MyOuter;
+
+alias MyOuter = Outer;
+
+struct Outer {
+ i : MyInner,
+};
+
+struct Inner {
+ f : f32,
+};
+)";
+ auto* expect = R"(
+@fragment
+fn main() {
+ let f0 = u0.inner.i.f;
+ let f1 = u1.inner.f;
+}
+
+@internal(block)
+struct u1_block {
+ inner : Inner,
+}
+
+@group(0) @binding(1) var<uniform> u1 : u1_block;
+
+alias MyInner = Inner;
+
+@internal(block)
+struct u0_block {
+ inner : Outer,
+}
+
+@group(0) @binding(0) var<uniform> u0 : u0_block;
+
+alias MyOuter = Outer;
+
+struct Outer {
+ i : MyInner,
+}
+
+struct Inner {
+ f : f32,
+}
+)";
+
+ auto got = Run<AddBlockAttribute>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(AddBlockAttributeTest, UniformAndPrivateUsages) {
+ auto* src = R"(
+struct S {
+ f : f32,
+}
+
+@group(0) @binding(0) var<uniform> u : S;
+
+var<private> p : S;
+
+@fragment
+fn main() {
+ p = u;
+}
+)";
+ auto* expect = R"(
+struct S {
+ f : f32,
+}
+
+@internal(block)
+struct u_block {
+ inner : S,
+}
+
+@group(0) @binding(0) var<uniform> u : u_block;
+
+var<private> p : S;
+
+@fragment
+fn main() {
+ p = u.inner;
+}
+)";
+
+ auto got = Run<AddBlockAttribute>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(AddBlockAttributeTest, StorageAndPrivateUsages) {
+ auto* src = R"(
+struct S {
+ f : f32,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+
+var<private> p : S;
+
+@fragment
+fn main() {
+ p = s;
+ p.f = 1234.0;
+ s = p;
+}
+)";
+ auto* expect = R"(
+struct S {
+ f : f32,
+}
+
+@internal(block)
+struct s_block {
+ inner : S,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : s_block;
+
+var<private> p : S;
+
+@fragment
+fn main() {
+ p = s.inner;
+ p.f = 1234.0;
+ s.inner = p;
+}
+)";
+
+ auto got = Run<AddBlockAttribute>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(AddBlockAttributeTest, StorageAndUniformUsages) {
+ auto* src = R"(
+struct S {
+ f : f32,
+}
+
+@group(0) @binding(0) var<uniform> u : S;
+
+@group(0) @binding(1) var<storage, read_write> s : S;
+
+@fragment
+fn main() {
+ s = u;
+}
+)";
+ auto* expect = R"(
+struct S {
+ f : f32,
+}
+
+@internal(block)
+struct u_block {
+ inner : S,
+}
+
+@group(0) @binding(0) var<uniform> u : u_block;
+
+@group(0) @binding(1) var<storage, read_write> s : u_block;
+
+@fragment
+fn main() {
+ s.inner = u.inner;
+}
+)";
+
+ auto got = Run<AddBlockAttribute>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(AddBlockAttributeTest, PrivateUsageOnly) {
+ auto* src = R"(
+struct S {
+ f : f32,
+}
+
+var<private> p : S;
+
+@fragment
+fn main() {
+ p.f = 4321.0f;
+}
+)";
+ auto* expect = src;
+
+ auto got = Run<AddBlockAttribute>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(AddBlockAttributeTest, StorageBufferWithRuntimeArray) {
+ auto* src = R"(
+struct S {
+ f : f32,
+}
+
+struct SWithArr {
+ f : f32,
+ arr : array<f32>,
+}
+
+@group(0) @binding(0)
+var<storage, read> in_1 : S;
+
+@group(0) @binding(1)
+var<storage, read> in_2 : SWithArr;
+
+@group(1) @binding(0)
+var<storage, read_write> out : SWithArr;
+
+@fragment
+fn main() {
+ out.f = in_1.f;
+ out.arr[0] = in_2.arr[1];
+}
+)";
+ auto* expect = R"(
+struct S {
+ f : f32,
+}
+
+@internal(block) @internal(block)
+struct SWithArr {
+ f : f32,
+ arr : array<f32>,
+}
+
+@internal(block)
+struct in_1_block {
+ inner : S,
+}
+
+@group(0) @binding(0) var<storage, read> in_1 : in_1_block;
+
+@group(0) @binding(1) var<storage, read> in_2 : SWithArr;
+
+@group(1) @binding(0) var<storage, read_write> out : SWithArr;
+
+@fragment
+fn main() {
+ out.f = in_1.inner.f;
+ out.arr[0] = in_2.arr[1];
+}
+)";
+
+ auto got = Run<AddBlockAttribute>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/add_empty_entry_point.cc b/src/tint/lang/wgsl/ast/transform/add_empty_entry_point.cc
new file mode 100644
index 0000000..5ba160f
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/add_empty_entry_point.cc
@@ -0,0 +1,63 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/add_empty_entry_point.h"
+
+#include <utility>
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::AddEmptyEntryPoint);
+
+using namespace tint::number_suffixes; // NOLINT
+
+namespace tint::ast::transform {
+namespace {
+
+bool ShouldRun(const Program* program) {
+ for (auto* func : program->AST().Functions()) {
+ if (func->IsEntryPoint()) {
+ return false;
+ }
+ }
+ return true;
+}
+
+} // namespace
+
+AddEmptyEntryPoint::AddEmptyEntryPoint() = default;
+
+AddEmptyEntryPoint::~AddEmptyEntryPoint() = default;
+
+Transform::ApplyResult AddEmptyEntryPoint::Apply(const Program* src,
+ const DataMap&,
+ DataMap&) const {
+ if (!ShouldRun(src)) {
+ return SkipTransform;
+ }
+
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+
+ b.Func(b.Symbols().New("unused_entry_point"), {}, b.ty.void_(), {},
+ utils::Vector{
+ b.Stage(PipelineStage::kCompute),
+ b.WorkgroupSize(1_i),
+ });
+
+ ctx.Clone();
+ return Program(std::move(b));
+}
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/add_empty_entry_point.h b/src/tint/lang/wgsl/ast/transform/add_empty_entry_point.h
new file mode 100644
index 0000000..572faba
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/add_empty_entry_point.h
@@ -0,0 +1,38 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_ADD_EMPTY_ENTRY_POINT_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_ADD_EMPTY_ENTRY_POINT_H_
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+
+namespace tint::ast::transform {
+
+/// Add an empty entry point to the module, if no other entry points exist.
+class AddEmptyEntryPoint final : public utils::Castable<AddEmptyEntryPoint, Transform> {
+ public:
+ /// Constructor
+ AddEmptyEntryPoint();
+ /// Destructor
+ ~AddEmptyEntryPoint() override;
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_ADD_EMPTY_ENTRY_POINT_H_
diff --git a/src/tint/lang/wgsl/ast/transform/add_empty_entry_point_test.cc b/src/tint/lang/wgsl/ast/transform/add_empty_entry_point_test.cc
new file mode 100644
index 0000000..fb5e65b
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/add_empty_entry_point_test.cc
@@ -0,0 +1,86 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/add_empty_entry_point.h"
+
+#include <utility>
+
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using AddEmptyEntryPointTest = TransformTest;
+
+TEST_F(AddEmptyEntryPointTest, ShouldRunEmptyModule) {
+ auto* src = R"()";
+
+ EXPECT_TRUE(ShouldRun<AddEmptyEntryPoint>(src));
+}
+
+TEST_F(AddEmptyEntryPointTest, ShouldRunExistingEntryPoint) {
+ auto* src = R"(
+@compute @workgroup_size(1)
+fn existing() {}
+)";
+
+ EXPECT_FALSE(ShouldRun<AddEmptyEntryPoint>(src));
+}
+
+TEST_F(AddEmptyEntryPointTest, EmptyModule) {
+ auto* src = R"()";
+
+ auto* expect = R"(
+@compute @workgroup_size(1i)
+fn unused_entry_point() {
+}
+)";
+
+ auto got = Run<AddEmptyEntryPoint>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(AddEmptyEntryPointTest, ExistingEntryPoint) {
+ auto* src = R"(
+@fragment
+fn main() {
+}
+)";
+
+ auto* expect = src;
+
+ auto got = Run<AddEmptyEntryPoint>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(AddEmptyEntryPointTest, NameClash) {
+ auto* src = R"(var<private> unused_entry_point : f32;)";
+
+ auto* expect = R"(
+@compute @workgroup_size(1i)
+fn unused_entry_point_1() {
+}
+
+var<private> unused_entry_point : f32;
+)";
+
+ auto got = Run<AddEmptyEntryPoint>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/array_length_from_uniform.cc b/src/tint/lang/wgsl/ast/transform/array_length_from_uniform.cc
new file mode 100644
index 0000000..5f404b4
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/array_length_from_uniform.cc
@@ -0,0 +1,271 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/array_length_from_uniform.h"
+
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "src/tint/lang/wgsl/ast/transform/simplify_pointers.h"
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/call.h"
+#include "src/tint/sem/function.h"
+#include "src/tint/sem/statement.h"
+#include "src/tint/sem/variable.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::ArrayLengthFromUniform);
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::ArrayLengthFromUniform::Config);
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::ArrayLengthFromUniform::Result);
+
+namespace tint::ast::transform {
+
+namespace {
+
+bool ShouldRun(const Program* program) {  // true if any function directly calls arrayLength()
+ for (auto* fn : program->AST().Functions()) {
+ if (auto* sem_fn = program->Sem().Get(fn)) {
+ for (auto* builtin : sem_fn->DirectlyCalledBuiltins()) {
+ if (builtin->Type() == builtin::Function::kArrayLength) {
+ return true;
+ }
+ }
+ }
+ }
+ return false;
+}
+
+} // namespace
+
+ArrayLengthFromUniform::ArrayLengthFromUniform() = default;  // defaulted: nothing to set up
+ArrayLengthFromUniform::~ArrayLengthFromUniform() = default;
+
+/// PIMPL state for the transform
+struct ArrayLengthFromUniform::State {
+ /// Constructor
+ /// @param program the source program
+ /// @param in the input transform data
+ /// @param out the output transform data
+ explicit State(const Program* program, const DataMap& in, DataMap& out)
+ : src(program), inputs(in), outputs(out) {}
+
+ /// Runs the transform. A missing Config input yields a program carrying an error diagnostic.
+ /// @returns the new program or SkipTransform if the transform is not required
+ ApplyResult Run() {
+ auto* cfg = inputs.Get<Config>();
+ if (cfg == nullptr) {
+ b.Diagnostics().add_error(
+ diag::System::Transform,
+ "missing transform data for " +
+ std::string(utils::TypeInfo::Of<ArrayLengthFromUniform>().name));
+ return Program(std::move(b));
+ }
+
+ if (!ShouldRun(ctx.src)) {
+ return SkipTransform;
+ }
+
+ const char* kBufferSizeMemberName = "buffer_size";
+
+ // Find the largest configured size index in use; it determines the buffer size array length.
+ uint32_t max_buffer_size_index = 0;
+
+ IterateArrayLengthOnStorageVar(
+ [&](const CallExpression*, const sem::VariableUser*, const sem::GlobalVariable* var) {
+ if (auto binding = var->BindingPoint()) {
+ auto idx_itr = cfg->bindpoint_to_size_index.find(*binding);
+ if (idx_itr == cfg->bindpoint_to_size_index.end()) {
+ return;
+ }
+ if (idx_itr->second > max_buffer_size_index) {
+ max_buffer_size_index = idx_itr->second;
+ }
+ }
+ });
+
+ // Get (or create, on first call) the uniform buffer that will receive the
+ // size of each storage buffer in the module.
+ const Variable* buffer_size_ubo = nullptr;
+ auto get_ubo = [&] {
+ if (!buffer_size_ubo) {
+ // Emit an array<vec4<u32>, N>, where N is (max_buffer_size_index / 4) + 1,
+ // i.e. four sizes are packed per element. We do this because UBOs require
+ // an element stride that is 16-byte aligned.
+ auto* buffer_size_struct = b.Structure(
+ b.Sym(), utils::Vector{
+ b.Member(kBufferSizeMemberName,
+ b.ty.array(b.ty.vec4(b.ty.u32()),
+ u32((max_buffer_size_index / 4) + 1))),
+ });
+ buffer_size_ubo = b.GlobalVar(b.Sym(), b.ty.Of(buffer_size_struct),
+ builtin::AddressSpace::kUniform,
+ b.Group(AInt(cfg->ubo_binding.group)),
+ b.Binding(AInt(cfg->ubo_binding.binding)));
+ }
+ return buffer_size_ubo;
+ };
+ // Size indices for which an arrayLength() call is rewritten below; reported through Result.
+ std::unordered_set<uint32_t> used_size_indices;
+
+ IterateArrayLengthOnStorageVar([&](const CallExpression* call_expr,
+ const sem::VariableUser* storage_buffer_sem,
+ const sem::GlobalVariable* var) {
+ auto binding = var->BindingPoint();
+ if (!binding) {
+ return;
+ }
+ auto idx_itr = cfg->bindpoint_to_size_index.find(*binding);
+ if (idx_itr == cfg->bindpoint_to_size_index.end()) {
+ return;  // no size index configured for this buffer: the call is left unchanged
+ }
+
+ uint32_t size_index = idx_itr->second;
+ used_size_indices.insert(size_index);
+
+ // Load the total storage buffer size from the UBO.
+ uint32_t array_index = size_index / 4;
+ auto* vec_expr = b.IndexAccessor(
+ b.MemberAccessor(get_ubo()->name->symbol, kBufferSizeMemberName), u32(array_index));
+ uint32_t vec_index = size_index % 4;
+ auto* total_storage_buffer_size = b.IndexAccessor(vec_expr, u32(vec_index));
+
+ // Calculate actual array length
+ // total_storage_buffer_size - array_offset
+ // array_length = ----------------------------------------
+ // array_stride
+ const Expression* total_size = total_storage_buffer_size;
+ auto* storage_buffer_type = storage_buffer_sem->Type()->UnwrapRef();
+ const type::Array* array_type = nullptr;
+ if (auto* str = storage_buffer_type->As<type::Struct>()) {
+ // The variable is a struct, so subtract the byte offset of the array
+ // member.
+ auto* array_member_sem = str->Members().Back();
+ array_type = array_member_sem->Type()->As<type::Array>();
+ total_size = b.Sub(total_storage_buffer_size, u32(array_member_sem->Offset()));
+ } else if (auto* arr = storage_buffer_type->As<type::Array>()) {
+ array_type = arr;
+ } else {
+ TINT_ICE(Transform, b.Diagnostics())
+ << "expected form of arrayLength argument to be &array_var or "
+ "&struct_var.array_member";
+ return;
+ }
+ auto* array_length = b.Div(total_size, u32(array_type->Stride()));
+
+ ctx.Replace(call_expr, array_length);
+ });
+
+ outputs.Add<Result>(used_size_indices);
+
+ ctx.Clone();
+ return Program(std::move(b));
+ }
+
+ private:
+ /// The source program
+ const Program* const src;
+ /// The transform inputs
+ const DataMap& inputs;
+ /// The transform outputs
+ DataMap& outputs;
+ /// The target program builder
+ ProgramBuilder b;
+ /// The clone context
+ CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
+
+ /// Iterate over all arrayLength() builtin calls on storage buffer variables.
+ /// Calls used directly as a statement are removed and not passed to functor.
+ /// @param functor of type void(const CallExpression*, const
+ /// sem::VariableUser*, const sem::GlobalVariable*). It takes in the
+ /// CallExpression of the arrayLength call expression node, a
+ /// sem::VariableUser of the used storage buffer variable, and the
+ /// sem::GlobalVariable for the storage buffer.
+ template <typename F>
+ void IterateArrayLengthOnStorageVar(F&& functor) {
+ auto& sem = src->Sem();
+
+ // Find all calls to the arrayLength() builtin.
+ for (auto* node : src->ASTNodes().Objects()) {
+ auto* call_expr = node->As<CallExpression>();
+ if (!call_expr) {
+ continue;
+ }
+
+ auto* call = sem.Get(call_expr)->UnwrapMaterialize()->As<sem::Call>();
+ auto* builtin = call->Target()->As<sem::Builtin>();
+ if (!builtin || builtin->Type() != builtin::Function::kArrayLength) {
+ continue;
+ }
+
+ if (auto* call_stmt = call->Stmt()->Declaration()->As<CallStatement>()) {
+ if (call_stmt->expr == call_expr) {
+ // arrayLength() is used as a statement.
+ // The argument expression must be side-effect free, so just drop the statement.
+ RemoveStatement(ctx, call_stmt);
+ continue;
+ }
+ }
+
+ // Get the storage buffer that contains the runtime array.
+ // Since we require SimplifyPointers, we can assume that the arrayLength()
+ // call has one of two forms:
+ // arrayLength(&struct_var.array_member)
+ // arrayLength(&array_var)
+ auto* param = call_expr->args[0]->As<UnaryOpExpression>();
+ if (TINT_UNLIKELY(!param || param->op != UnaryOp::kAddressOf)) {
+ TINT_ICE(Transform, b.Diagnostics())
+ << "expected form of arrayLength argument to be &array_var or "
+ "&struct_var.array_member";
+ break;  // stop iterating once an internal compiler error has been raised
+ }
+ auto* storage_buffer_expr = param->expr;
+ if (auto* accessor = param->expr->As<MemberAccessorExpression>()) {
+ storage_buffer_expr = accessor->object;
+ }
+ auto* storage_buffer_sem = sem.Get<sem::VariableUser>(storage_buffer_expr);
+ if (TINT_UNLIKELY(!storage_buffer_sem)) {
+ TINT_ICE(Transform, b.Diagnostics())
+ << "expected form of arrayLength argument to be &array_var or "
+ "&struct_var.array_member";
+ break;
+ }
+
+ // Get the index to use for the buffer size array.
+ auto* var = tint::As<sem::GlobalVariable>(storage_buffer_sem->Variable());
+ if (TINT_UNLIKELY(!var)) {
+ TINT_ICE(Transform, b.Diagnostics()) << "storage buffer is not a global variable";
+ break;
+ }
+ functor(call_expr, storage_buffer_sem, var);
+ }
+ }
+};
+
+Transform::ApplyResult ArrayLengthFromUniform::Apply(const Program* src,
+ const DataMap& inputs,
+ DataMap& outputs) const {
+ return State{src, inputs, outputs}.Run();  // all work happens in a temporary State
+}
+
+ArrayLengthFromUniform::Config::Config(sem::BindingPoint ubo_bp) : ubo_binding(ubo_bp) {}
+ArrayLengthFromUniform::Config::Config(const Config&) = default;  // copy operations are defaulted
+ArrayLengthFromUniform::Config& ArrayLengthFromUniform::Config::operator=(const Config&) = default;
+ArrayLengthFromUniform::Config::~Config() = default;
+
+ArrayLengthFromUniform::Result::Result(std::unordered_set<uint32_t> used_size_indices_in)
+ : used_size_indices(std::move(used_size_indices_in)) {}  // by-value set is moved into the member
+ArrayLengthFromUniform::Result::Result(const Result&) = default;
+ArrayLengthFromUniform::Result::~Result() = default;
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/array_length_from_uniform.h b/src/tint/lang/wgsl/ast/transform/array_length_from_uniform.h
new file mode 100644
index 0000000..9b4cc2f
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/array_length_from_uniform.h
@@ -0,0 +1,114 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_ARRAY_LENGTH_FROM_UNIFORM_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_ARRAY_LENGTH_FROM_UNIFORM_H_
+
+#include <unordered_map>
+#include <unordered_set>
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+#include "src/tint/sem/binding_point.h"
+
+// Forward declarations
+namespace tint {
+class CloneContext;
+} // namespace tint
+
+namespace tint::ast::transform {
+
+/// ArrayLengthFromUniform is a transform that implements calls to arrayLength()
+/// by calculating the length from the total size of the storage buffer, which
+/// is received via a uniform buffer.
+///
+/// The generated uniform buffer will have the form:
+/// ```
+/// struct buffer_size_struct {
+/// buffer_size : array<vec4<u32>, N>,  // N = (largest size index in use / 4) + 1
+/// }
+///
+/// @group(0) @binding(30)
+/// var<uniform> buffer_size_ubo : buffer_size_struct;
+/// ```
+/// The binding group and number used for this uniform buffer is provided via
+/// the `Config` transform input. The `Config` struct also defines the mapping
+/// from a storage buffer's `BindingPoint` to the size index `i` of that buffer;
+/// its size is read from `buffer_size[i / 4][i % 4]`.
+///
+/// This transform assumes that the `SimplifyPointers` transform has been run
+/// before it, so that arguments to the arrayLength builtin always have the
+/// form `&resource.array` or `&array_resource`.
+///
+/// @note Depends on the following transforms to have been run first:
+/// * SimplifyPointers
+class ArrayLengthFromUniform final : public utils::Castable<ArrayLengthFromUniform, Transform> {
+ public:
+ /// Constructor
+ ArrayLengthFromUniform();
+ /// Destructor
+ ~ArrayLengthFromUniform() override;
+
+ /// Configuration options for the ArrayLengthFromUniform transform.
+ struct Config final : public utils::Castable<Config, Data> {
+ /// Constructor
+ /// @param ubo_bp the binding point to use for the generated uniform buffer.
+ explicit Config(sem::BindingPoint ubo_bp);
+
+ /// Copy constructor
+ Config(const Config&);
+
+ /// Copy assignment
+ /// @return this Config
+ Config& operator=(const Config&);
+
+ /// Destructor
+ ~Config() override;
+
+ /// The binding point to use for the generated uniform buffer.
+ sem::BindingPoint ubo_binding;
+
+ /// The mapping from binding point to the index for the buffer size lookup.
+ std::unordered_map<sem::BindingPoint, uint32_t> bindpoint_to_size_index;
+ };
+
+ /// Information produced about what the transform did.
+ /// If there were no calls to the arrayLength() builtin, then no Result will
+ /// be emitted.
+ struct Result final : public utils::Castable<Result, Data> {
+ /// Constructor
+ /// @param used_size_indices Indices into the UBO that are statically used.
+ explicit Result(std::unordered_set<uint32_t> used_size_indices);
+
+ /// Copy constructor
+ Result(const Result&);
+
+ /// Destructor
+ ~Result() override;
+
+ /// Indices into the UBO that are statically used.
+ std::unordered_set<uint32_t> used_size_indices;
+ };
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+
+ private:
+ struct State;
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_ARRAY_LENGTH_FROM_UNIFORM_H_
diff --git a/src/tint/lang/wgsl/ast/transform/array_length_from_uniform_test.cc b/src/tint/lang/wgsl/ast/transform/array_length_from_uniform_test.cc
new file mode 100644
index 0000000..efa223a
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/array_length_from_uniform_test.cc
@@ -0,0 +1,516 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/array_length_from_uniform.h"
+
+#include <utility>
+
+#include "src/tint/lang/wgsl/ast/transform/simplify_pointers.h"
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+#include "src/tint/lang/wgsl/ast/transform/unshadow.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using ArrayLengthFromUniformTest = TransformTest;  // fixture for the TEST_F cases below
+
+TEST_F(ArrayLengthFromUniformTest, ShouldRunEmptyModule) {
+ auto* src = R"()";
+
+ ArrayLengthFromUniform::Config cfg({0, 30u});
+ cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0);
+
+ Transform::DataMap data;
+ data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
+ // No functions means no arrayLength() calls, so the transform should not run.
+ EXPECT_FALSE(ShouldRun<ArrayLengthFromUniform>(src, data));
+}
+
+TEST_F(ArrayLengthFromUniformTest, ShouldRunNoArrayLength) {
+ auto* src = R"(
+struct SB {
+ x : i32,
+ arr : array<i32>,
+};
+
+@group(0) @binding(0) var<storage, read> sb : SB;
+
+@compute @workgroup_size(1)
+fn main() {
+}
+)";
+
+ ArrayLengthFromUniform::Config cfg({0, 30u});
+ cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0);
+
+ Transform::DataMap data;
+ data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
+ // A storage buffer is declared but arrayLength() is never called.
+ EXPECT_FALSE(ShouldRun<ArrayLengthFromUniform>(src, data));
+}
+
+TEST_F(ArrayLengthFromUniformTest, ShouldRunWithArrayLength) {
+ auto* src = R"(
+struct SB {
+ x : i32,
+ arr : array<i32>,
+};
+
+@group(0) @binding(0) var<storage, read> sb : SB;
+
+@compute @workgroup_size(1)
+fn main() {
+ var len : u32 = arrayLength(&sb.arr);
+}
+)";
+
+ ArrayLengthFromUniform::Config cfg({0, 30u});
+ cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0);
+
+ Transform::DataMap data;
+ data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
+ // main() calls arrayLength(), so the transform should run.
+ EXPECT_TRUE(ShouldRun<ArrayLengthFromUniform>(src, data));
+}
+
+TEST_F(ArrayLengthFromUniformTest, Error_MissingTransformData) {
+ auto* src = R"(
+struct SB {
+ x : i32,
+ arr : array<i32>,
+};
+
+@group(0) @binding(0) var<storage, read> sb : SB;
+
+@compute @workgroup_size(1)
+fn main() {
+ var len : u32 = arrayLength(&sb.arr);
+}
+)";
+
+ auto* expect = "error: missing transform data for tint::ast::transform::ArrayLengthFromUniform";
+ // No DataMap is passed to Run(), so the Config input is missing.
+ auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ArrayLengthFromUniformTest, Basic) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read> sb : array<i32>;
+
+@compute @workgroup_size(1)
+fn main() {
+ var len : u32 = arrayLength(&sb);
+}
+)";
+
+ auto* expect = R"(
+struct tint_symbol {
+ buffer_size : array<vec4<u32>, 1u>,
+}
+
+@group(0) @binding(30) var<uniform> tint_symbol_1 : tint_symbol;
+
+@group(0) @binding(0) var<storage, read> sb : array<i32>;
+
+@compute @workgroup_size(1)
+fn main() {
+ var len : u32 = (tint_symbol_1.buffer_size[0u][0u] / 4u);
+}
+)";
+ // Map the storage buffer at (group 0, binding 0) to size index 0; the UBO goes to (0, 30).
+ ArrayLengthFromUniform::Config cfg({0, 30u});
+ cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0);
+
+ Transform::DataMap data;
+ data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
+
+ auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+ auto* val = got.data.Get<ArrayLengthFromUniform::Result>();
+ ASSERT_NE(val, nullptr);
+ EXPECT_EQ(std::unordered_set<uint32_t>({0}), val->used_size_indices);
+}
+
+TEST_F(ArrayLengthFromUniformTest, BasicInStruct) {
+ auto* src = R"(
+struct SB {
+ x : i32,
+ arr : array<i32>,
+};
+
+@group(0) @binding(0) var<storage, read> sb : SB;
+
+@compute @workgroup_size(1)
+fn main() {
+ var len : u32 = arrayLength(&sb.arr);
+}
+)";
+
+ auto* expect = R"(
+struct tint_symbol {
+ buffer_size : array<vec4<u32>, 1u>,
+}
+
+@group(0) @binding(30) var<uniform> tint_symbol_1 : tint_symbol;
+
+struct SB {
+ x : i32,
+ arr : array<i32>,
+}
+
+@group(0) @binding(0) var<storage, read> sb : SB;
+
+@compute @workgroup_size(1)
+fn main() {
+ var len : u32 = ((tint_symbol_1.buffer_size[0u][0u] - 4u) / 4u);
+}
+)";
+ // The 4-byte offset of 'arr' within SB is expected to be subtracted before the division.
+ ArrayLengthFromUniform::Config cfg({0, 30u});
+ cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0);
+
+ Transform::DataMap data;
+ data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
+
+ auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(std::unordered_set<uint32_t>({0}),
+ got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
+}
+
+TEST_F(ArrayLengthFromUniformTest, MultipleStorageBuffers) {
+ auto* src = R"(
+struct SB1 {
+ x : i32,
+ arr1 : array<i32>,
+};
+struct SB2 {
+ x : i32,
+ arr2 : array<vec4<f32>>,
+};
+struct SB4 {
+ x : i32,
+ arr4 : array<vec4<f32>>,
+};
+
+@group(0) @binding(2) var<storage, read> sb1 : SB1;
+@group(1) @binding(2) var<storage, read> sb2 : SB2;
+@group(2) @binding(2) var<storage, read> sb3 : array<vec4<f32>>;
+@group(3) @binding(2) var<storage, read> sb4 : SB4;
+@group(4) @binding(2) var<storage, read> sb5 : array<vec4<f32>>;
+
+@compute @workgroup_size(1)
+fn main() {
+ var len1 : u32 = arrayLength(&(sb1.arr1));
+ var len2 : u32 = arrayLength(&(sb2.arr2));
+ var len3 : u32 = arrayLength(&sb3);
+ var len4 : u32 = arrayLength(&(sb4.arr4));
+ var len5 : u32 = arrayLength(&sb5);
+ var x : u32 = (len1 + len2 + len3 + len4 + len5);
+}
+)";
+
+ auto* expect = R"(
+struct tint_symbol {
+ buffer_size : array<vec4<u32>, 2u>,
+}
+
+@group(0) @binding(30) var<uniform> tint_symbol_1 : tint_symbol;
+
+struct SB1 {
+ x : i32,
+ arr1 : array<i32>,
+}
+
+struct SB2 {
+ x : i32,
+ arr2 : array<vec4<f32>>,
+}
+
+struct SB4 {
+ x : i32,
+ arr4 : array<vec4<f32>>,
+}
+
+@group(0) @binding(2) var<storage, read> sb1 : SB1;
+
+@group(1) @binding(2) var<storage, read> sb2 : SB2;
+
+@group(2) @binding(2) var<storage, read> sb3 : array<vec4<f32>>;
+
+@group(3) @binding(2) var<storage, read> sb4 : SB4;
+
+@group(4) @binding(2) var<storage, read> sb5 : array<vec4<f32>>;
+
+@compute @workgroup_size(1)
+fn main() {
+ var len1 : u32 = ((tint_symbol_1.buffer_size[0u][0u] - 4u) / 4u);
+ var len2 : u32 = ((tint_symbol_1.buffer_size[0u][1u] - 16u) / 16u);
+ var len3 : u32 = (tint_symbol_1.buffer_size[0u][2u] / 16u);
+ var len4 : u32 = ((tint_symbol_1.buffer_size[0u][3u] - 16u) / 16u);
+ var len5 : u32 = (tint_symbol_1.buffer_size[1u][0u] / 16u);
+ var x : u32 = ((((len1 + len2) + len3) + len4) + len5);
+}
+)";
+ // Size indices 0..4 need two vec4<u32> elements; index 4 lands in buffer_size[1u][0u].
+ ArrayLengthFromUniform::Config cfg({0, 30u});
+ cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 2u}, 0);
+ cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{1u, 2u}, 1);
+ cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{2u, 2u}, 2);
+ cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{3u, 2u}, 3);
+ cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{4u, 2u}, 4);
+
+ Transform::DataMap data;
+ data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
+
+ auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(std::unordered_set<uint32_t>({0, 1, 2, 3, 4}),
+ got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
+}
+
+TEST_F(ArrayLengthFromUniformTest, MultipleUnusedStorageBuffers) {
+ auto* src = R"(
+struct SB1 {
+ x : i32,
+ arr1 : array<i32>,
+};
+struct SB2 {
+ x : i32,
+ arr2 : array<vec4<f32>>,
+};
+struct SB4 {
+ x : i32,
+ arr4 : array<vec4<f32>>,
+};
+
+@group(0) @binding(2) var<storage, read> sb1 : SB1;
+@group(1) @binding(2) var<storage, read> sb2 : SB2;
+@group(2) @binding(2) var<storage, read> sb3 : array<vec4<f32>>;
+@group(3) @binding(2) var<storage, read> sb4 : SB4;
+@group(4) @binding(2) var<storage, read> sb5 : array<vec4<f32>>;
+
+@compute @workgroup_size(1)
+fn main() {
+ var len1 : u32 = arrayLength(&(sb1.arr1));
+ var len3 : u32 = arrayLength(&sb3);
+ var x : u32 = (len1 + len3);
+}
+)";
+
+ auto* expect = R"(
+struct tint_symbol {
+ buffer_size : array<vec4<u32>, 1u>,
+}
+
+@group(0) @binding(30) var<uniform> tint_symbol_1 : tint_symbol;
+
+struct SB1 {
+ x : i32,
+ arr1 : array<i32>,
+}
+
+struct SB2 {
+ x : i32,
+ arr2 : array<vec4<f32>>,
+}
+
+struct SB4 {
+ x : i32,
+ arr4 : array<vec4<f32>>,
+}
+
+@group(0) @binding(2) var<storage, read> sb1 : SB1;
+
+@group(1) @binding(2) var<storage, read> sb2 : SB2;
+
+@group(2) @binding(2) var<storage, read> sb3 : array<vec4<f32>>;
+
+@group(3) @binding(2) var<storage, read> sb4 : SB4;
+
+@group(4) @binding(2) var<storage, read> sb5 : array<vec4<f32>>;
+
+@compute @workgroup_size(1)
+fn main() {
+ var len1 : u32 = ((tint_symbol_1.buffer_size[0u][0u] - 4u) / 4u);
+ var len3 : u32 = (tint_symbol_1.buffer_size[0u][2u] / 16u);
+ var x : u32 = (len1 + len3);
+}
+)";
+ // All five buffers are mapped, but only indices 0 and 2 are used, so one vec4<u32> suffices.
+ ArrayLengthFromUniform::Config cfg({0, 30u});
+ cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 2u}, 0);
+ cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{1u, 2u}, 1);
+ cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{2u, 2u}, 2);
+ cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{3u, 2u}, 3);
+ cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{4u, 2u}, 4);
+
+ Transform::DataMap data;
+ data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
+
+ auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(std::unordered_set<uint32_t>({0, 2}),
+ got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
+}
+
+TEST_F(ArrayLengthFromUniformTest, NoArrayLengthCalls) {
+ auto* src = R"(
+struct SB {
+ x : i32,
+ arr : array<i32>,
+}
+
+@group(0) @binding(0) var<storage, read> sb : SB;
+
+@compute @workgroup_size(1)
+fn main() {
+ _ = &(sb.arr);
+}
+)";
+ // Without any arrayLength() call the program is unchanged and no Result is produced.
+ ArrayLengthFromUniform::Config cfg({0, 30u});
+ cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0);
+
+ Transform::DataMap data;
+ data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
+
+ auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
+
+ EXPECT_EQ(src, str(got));
+ EXPECT_EQ(got.data.Get<ArrayLengthFromUniform::Result>(), nullptr);
+}
+
+TEST_F(ArrayLengthFromUniformTest, MissingBindingPointToIndexMapping) {
+ auto* src = R"(
+struct SB1 {
+ x : i32,
+ arr1 : array<i32>,
+};
+
+struct SB2 {
+ x : i32,
+ arr2 : array<vec4<f32>>,
+};
+
+@group(0) @binding(2) var<storage, read> sb1 : SB1;
+
+@group(1) @binding(2) var<storage, read> sb2 : SB2;
+
+@compute @workgroup_size(1)
+fn main() {
+ var len1 : u32 = arrayLength(&(sb1.arr1));
+ var len2 : u32 = arrayLength(&(sb2.arr2));
+ var x : u32 = (len1 + len2);
+}
+)";
+
+ auto* expect = R"(
+struct tint_symbol {
+ buffer_size : array<vec4<u32>, 1u>,
+}
+
+@group(0) @binding(30) var<uniform> tint_symbol_1 : tint_symbol;
+
+struct SB1 {
+ x : i32,
+ arr1 : array<i32>,
+}
+
+struct SB2 {
+ x : i32,
+ arr2 : array<vec4<f32>>,
+}
+
+@group(0) @binding(2) var<storage, read> sb1 : SB1;
+
+@group(1) @binding(2) var<storage, read> sb2 : SB2;
+
+@compute @workgroup_size(1)
+fn main() {
+ var len1 : u32 = ((tint_symbol_1.buffer_size[0u][0u] - 4u) / 4u);
+ var len2 : u32 = arrayLength(&(sb2.arr2));
+ var x : u32 = (len1 + len2);
+}
+)";
+ // Only sb1 is mapped; the arrayLength() call on the unmapped sb2 is expected to be left as-is.
+ ArrayLengthFromUniform::Config cfg({0, 30u});
+ cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 2}, 0);
+
+ Transform::DataMap data;
+ data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
+
+ auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(std::unordered_set<uint32_t>({0}),
+ got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
+}
+
+TEST_F(ArrayLengthFromUniformTest, OutOfOrder) {
+ auto* src = R"(
+@compute @workgroup_size(1)
+fn main() {
+ var len : u32 = arrayLength(&sb.arr);
+}
+
+@group(0) @binding(0) var<storage, read> sb : SB;
+
+struct SB {
+ x : i32,
+ arr : array<i32>,
+};
+)";
+
+ auto* expect = R"(
+struct tint_symbol {
+ buffer_size : array<vec4<u32>, 1u>,
+}
+
+@group(0) @binding(30) var<uniform> tint_symbol_1 : tint_symbol;
+
+@compute @workgroup_size(1)
+fn main() {
+ var len : u32 = ((tint_symbol_1.buffer_size[0u][0u] - 4u) / 4u);
+}
+
+@group(0) @binding(0) var<storage, read> sb : SB;
+
+struct SB {
+ x : i32,
+ arr : array<i32>,
+}
+)";
+ // Declarations appear after their use; the generated struct and UBO are still expected first.
+ ArrayLengthFromUniform::Config cfg({0, 30u});
+ cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0);
+
+ Transform::DataMap data;
+ data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
+
+ auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+ EXPECT_EQ(std::unordered_set<uint32_t>({0}),
+ got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/binding_remapper.cc b/src/tint/lang/wgsl/ast/transform/binding_remapper.cc
new file mode 100644
index 0000000..0b226c0
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/binding_remapper.cc
@@ -0,0 +1,164 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/binding_remapper.h"
+
+#include <string>
+#include <unordered_set>
+#include <utility>
+
+#include "src/tint/lang/wgsl/ast/disable_validation_attribute.h"
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/function.h"
+#include "src/tint/sem/variable.h"
+#include "src/tint/utils/string.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::BindingRemapper);
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::BindingRemapper::Remappings);
+
+namespace tint::ast::transform {
+
+BindingRemapper::Remappings::Remappings(BindingPoints bp, AccessControls ac, bool may_collide)
+ : binding_points(std::move(bp)),  // the by-value maps are moved into the members
+ access_controls(std::move(ac)),
+ allow_collisions(may_collide) {}
+
+BindingRemapper::Remappings::Remappings(const Remappings&) = default;
+BindingRemapper::Remappings::~Remappings() = default;
+
+BindingRemapper::BindingRemapper() = default;  // defaulted: nothing to set up
+BindingRemapper::~BindingRemapper() = default;
+
+Transform::ApplyResult BindingRemapper::Apply(const Program* src,
+ const DataMap& inputs,
+ DataMap&) const {
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+ // The Remappings input is mandatory; without it a program with an error diagnostic is returned.
+ auto* remappings = inputs.Get<Remappings>();
+ if (!remappings) {
+ b.Diagnostics().add_error(diag::System::Transform,
+ "missing transform data for " + std::string(TypeInfo().name));
+ return Program(std::move(b));
+ }
+ // Nothing to remap: skip the transform.
+ if (remappings->binding_points.empty() && remappings->access_controls.empty()) {
+ return SkipTransform;
+ }
+
+ // A set of post-remapped binding points that need to be decorated with a
+ // DisableValidationAttribute to disable binding-point-collision validation
+ std::unordered_set<sem::BindingPoint> add_collision_attr;
+
+ if (remappings->allow_collisions) {
+ // Scan for binding point collisions generated by this transform.
+ // Populate all collisions in the `add_collision_attr` set.
+ for (auto* func_ast : src->AST().Functions()) {
+ if (!func_ast->IsEntryPoint()) {
+ continue;
+ }
+ auto* func = src->Sem().Get(func_ast);
+ std::unordered_map<sem::BindingPoint, int> binding_point_counts;  // fresh per entry point
+ for (auto* global : func->TransitivelyReferencedGlobals()) {
+ if (auto from = global->BindingPoint()) {
+ auto bp_it = remappings->binding_points.find(*from);
+ if (bp_it != remappings->binding_points.end()) {
+ // Remapped
+ BindingPoint to = bp_it->second;
+ if (binding_point_counts[to]++) {  // non-zero before increment: already seen
+ add_collision_attr.emplace(to);
+ }
+ } else {
+ // No remapping
+ if (binding_point_counts[*from]++) {
+ add_collision_attr.emplace(*from);
+ }
+ }
+ }
+ }
+ }
+ }
+ // Rewrite every module-scope var that has a binding point.
+ for (auto* var : src->AST().Globals<Var>()) {
+ if (var->HasBindingPoint()) {
+ auto* global_sem = src->Sem().Get<sem::GlobalVariable>(var);
+
+ // The original binding point
+ BindingPoint from = *global_sem->BindingPoint();
+
+ // The binding point after remapping
+ BindingPoint bp = from;
+
+ // Replace any group or binding attributes.
+ // Note: This has to be performed *before* remapping access controls, as
+ // `ctx.Clone(var->attributes)` depends on these replacements.
+ auto bp_it = remappings->binding_points.find(from);
+ if (bp_it != remappings->binding_points.end()) {
+ BindingPoint to = bp_it->second;
+ auto* new_group = b.Group(AInt(to.group));
+ auto* new_binding = b.Binding(AInt(to.binding));
+
+ auto* old_group = GetAttribute<GroupAttribute>(var->attributes);
+ auto* old_binding = GetAttribute<BindingAttribute>(var->attributes);
+
+ ctx.Replace(old_group, new_group);
+ ctx.Replace(old_binding, new_binding);
+ bp = to;
+ }
+
+ // Replace any access controls. These are looked up by the original binding point.
+ auto ac_it = remappings->access_controls.find(from);
+ if (ac_it != remappings->access_controls.end()) {
+ builtin::Access access = ac_it->second;
+ if (access == builtin::Access::kUndefined) {
+ b.Diagnostics().add_error(diag::System::Transform,
+ "invalid access mode (" +
+ std::to_string(static_cast<uint32_t>(access)) +
+ ")");
+ return Program(std::move(b));
+ }
+ auto* sem = src->Sem().Get(var);
+ if (sem->AddressSpace() != builtin::AddressSpace::kStorage) {
+ b.Diagnostics().add_error(
+ diag::System::Transform,
+ "cannot apply access control to variable with address space " +
+ std::string(utils::ToString(sem->AddressSpace())));
+ return Program(std::move(b));
+ }
+ auto* ty = sem->Type()->UnwrapRef();
+ auto inner_ty = CreateASTTypeFor(ctx, ty);
+ auto* new_var =
+ b.create<Var>(ctx.Clone(var->source), // source
+ b.Ident(ctx.Clone(var->name->symbol)), // name
+ inner_ty, // type
+ ctx.Clone(var->declared_address_space), // address space
+ b.Expr(access), // access
+ ctx.Clone(var->initializer), // initializer
+ ctx.Clone(var->attributes)); // attributes
+ ctx.Replace(var, new_var);
+ }
+
+ // Add a `DisableValidationAttribute` if the post-remapping binding point collides
+ if (add_collision_attr.count(bp)) {
+ auto* attribute = b.Disable(DisabledValidation::kBindingPointCollision);
+ ctx.InsertBefore(var->attributes, *var->attributes.begin(), attribute);
+ }
+ }
+ }
+
+ ctx.Clone();
+ return Program(std::move(b));
+}
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/binding_remapper.h b/src/tint/lang/wgsl/ast/transform/binding_remapper.h
new file mode 100644
index 0000000..00241bc
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/binding_remapper.h
@@ -0,0 +1,78 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_BINDING_REMAPPER_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_BINDING_REMAPPER_H_
+
+#include <unordered_map>
+
+#include "src/tint/builtin/access.h"
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+#include "src/tint/sem/binding_point.h"
+
+namespace tint::ast::transform {
+
+/// BindingPoint is an alias to sem::BindingPoint
+using BindingPoint = sem::BindingPoint;
+
+/// BindingRemapper is a transform used to remap resource binding points and
+/// access controls.
+class BindingRemapper final : public utils::Castable<BindingRemapper, Transform> {
+ public:
+ /// BindingPoints is a map of old binding point to new binding point
+ using BindingPoints = std::unordered_map<BindingPoint, BindingPoint>;
+
+ /// AccessControls is a map of old binding point to new access control
+ using AccessControls = std::unordered_map<BindingPoint, builtin::Access>;
+
+ /// Remappings is consumed by the BindingRemapper transform.
+ /// It holds the binding point and access control remappings to apply.
+ struct Remappings final : public utils::Castable<Remappings, Data> {
+ /// Constructor
+ /// @param bp a map of old binding point to new binding point
+ /// @param ac a map of old binding point to new access control
+ /// @param may_collide If true, then validation will be disabled for
+ /// binding point collisions generated by this transform
+ Remappings(BindingPoints bp, AccessControls ac, bool may_collide = true);
+
+ /// Copy constructor
+ Remappings(const Remappings&);
+
+ /// Destructor
+ ~Remappings() override;
+
+ /// A map of old binding point to new binding point
+ const BindingPoints binding_points;
+
+ /// A map of old binding point to new access controls
+ const AccessControls access_controls;
+
+ /// If true, then validation will be disabled for binding point collisions
+ /// generated by this transform
+ const bool allow_collisions;
+ };
+
+ /// Constructor
+ BindingRemapper();
+ ~BindingRemapper() override;
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_BINDING_REMAPPER_H_
diff --git a/src/tint/lang/wgsl/ast/transform/binding_remapper_test.cc b/src/tint/lang/wgsl/ast/transform/binding_remapper_test.cc
new file mode 100644
index 0000000..354f0ab
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/binding_remapper_test.cc
@@ -0,0 +1,355 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/binding_remapper.h"
+
+#include <utility>
+
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using BindingRemapperTest = TransformTest;
+
+TEST_F(BindingRemapperTest, ShouldRunEmptyRemappings) {
+ auto* src = R"()";
+
+ Transform::DataMap data;
+ data.Add<BindingRemapper::Remappings>(BindingRemapper::BindingPoints{},
+ BindingRemapper::AccessControls{});
+
+ EXPECT_FALSE(ShouldRun<BindingRemapper>(src, data));
+}
+
+TEST_F(BindingRemapperTest, ShouldRunBindingPointRemappings) {
+ auto* src = R"()";
+
+ Transform::DataMap data;
+ data.Add<BindingRemapper::Remappings>(
+ BindingRemapper::BindingPoints{
+ {{2, 1}, {1, 2}},
+ },
+ BindingRemapper::AccessControls{});
+
+ EXPECT_TRUE(ShouldRun<BindingRemapper>(src, data));
+}
+
+TEST_F(BindingRemapperTest, ShouldRunAccessControlRemappings) {
+ auto* src = R"()";
+
+ Transform::DataMap data;
+ data.Add<BindingRemapper::Remappings>(BindingRemapper::BindingPoints{},
+ BindingRemapper::AccessControls{
+ {{2, 1}, builtin::Access::kWrite},
+ });
+
+ EXPECT_TRUE(ShouldRun<BindingRemapper>(src, data));
+}
+
+TEST_F(BindingRemapperTest, NoRemappings) {
+ auto* src = R"(
+struct S {
+ a : f32,
+}
+
+@group(2) @binding(1) var<storage, read> a : S;
+
+@group(3) @binding(2) var<storage, read> b : S;
+
+@compute @workgroup_size(1)
+fn f() {
+}
+)";
+
+ auto* expect = src;
+
+ Transform::DataMap data;
+ data.Add<BindingRemapper::Remappings>(BindingRemapper::BindingPoints{},
+ BindingRemapper::AccessControls{});
+ auto got = Run<BindingRemapper>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BindingRemapperTest, RemapBindingPoints) {
+ auto* src = R"(
+struct S {
+ a : f32,
+};
+
+@group(2) @binding(1) var<storage, read> a : S;
+
+@group(3) @binding(2) var<storage, read> b : S;
+
+@compute @workgroup_size(1)
+fn f() {
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ a : f32,
+}
+
+@group(1) @binding(2) var<storage, read> a : S;
+
+@group(3) @binding(2) var<storage, read> b : S;
+
+@compute @workgroup_size(1)
+fn f() {
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<BindingRemapper::Remappings>(
+ BindingRemapper::BindingPoints{
+ {{2, 1}, {1, 2}}, // Remap
+ {{4, 5}, {6, 7}}, // Not found
+ // Keep @group(3) @binding(2) as is
+ },
+ BindingRemapper::AccessControls{});
+ auto got = Run<BindingRemapper>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BindingRemapperTest, RemapAccessControls) {
+ auto* src = R"(
+struct S {
+ a : f32,
+};
+
+@group(2) @binding(1) var<storage, read_write> a : S;
+
+@group(3) @binding(2) var<storage, read_write> b : S;
+
+@group(4) @binding(3) var<storage, read> c : S;
+
+@compute @workgroup_size(1)
+fn f() {
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ a : f32,
+}
+
+@group(2) @binding(1) var<storage, read_write> a : S;
+
+@group(3) @binding(2) var<storage, read_write> b : S;
+
+@group(4) @binding(3) var<storage, read> c : S;
+
+@compute @workgroup_size(1)
+fn f() {
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<BindingRemapper::Remappings>(
+ BindingRemapper::BindingPoints{},
+ BindingRemapper::AccessControls{
+ {{2, 1}, builtin::Access::kReadWrite}, // Modify access control
+ // Keep @group(3) @binding(2) as is
+ {{4, 3}, builtin::Access::kRead}, // Add access control
+ });
+ auto got = Run<BindingRemapper>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BindingRemapperTest, RemapAll) {
+ auto* src = R"(
+struct S {
+ a : f32,
+};
+
+@group(2) @binding(1) var<storage, read> a : S;
+
+@group(3) @binding(2) var<storage, read> b : S;
+
+@compute @workgroup_size(1)
+fn f() {
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ a : f32,
+}
+
+@group(4) @binding(5) var<storage, read_write> a : S;
+
+@group(6) @binding(7) var<storage, read_write> b : S;
+
+@compute @workgroup_size(1)
+fn f() {
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<BindingRemapper::Remappings>(
+ BindingRemapper::BindingPoints{
+ {{2, 1}, {4, 5}},
+ {{3, 2}, {6, 7}},
+ },
+ BindingRemapper::AccessControls{
+ {{2, 1}, builtin::Access::kReadWrite},
+ {{3, 2}, builtin::Access::kReadWrite},
+ });
+ auto got = Run<BindingRemapper>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BindingRemapperTest, BindingCollisionsSameEntryPoint) {
+ auto* src = R"(
+struct S {
+ i : i32,
+};
+
+@group(2) @binding(1) var<storage, read> a : S;
+
+@group(3) @binding(2) var<storage, read> b : S;
+
+@group(4) @binding(3) var<storage, read> c : S;
+
+@group(5) @binding(4) var<storage, read> d : S;
+
+@compute @workgroup_size(1)
+fn f() {
+ let x : i32 = (((a.i + b.i) + c.i) + d.i);
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ i : i32,
+}
+
+@internal(disable_validation__binding_point_collision) @group(1) @binding(1) var<storage, read> a : S;
+
+@internal(disable_validation__binding_point_collision) @group(1) @binding(1) var<storage, read> b : S;
+
+@internal(disable_validation__binding_point_collision) @group(5) @binding(4) var<storage, read> c : S;
+
+@internal(disable_validation__binding_point_collision) @group(5) @binding(4) var<storage, read> d : S;
+
+@compute @workgroup_size(1)
+fn f() {
+ let x : i32 = (((a.i + b.i) + c.i) + d.i);
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<BindingRemapper::Remappings>(
+ BindingRemapper::BindingPoints{
+ {{2, 1}, {1, 1}},
+ {{3, 2}, {1, 1}},
+ {{4, 3}, {5, 4}},
+ },
+ BindingRemapper::AccessControls{}, true);
+ auto got = Run<BindingRemapper>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BindingRemapperTest, BindingCollisionsDifferentEntryPoints) {
+ auto* src = R"(
+struct S {
+ i : i32,
+};
+
+@group(2) @binding(1) var<storage, read> a : S;
+
+@group(3) @binding(2) var<storage, read> b : S;
+
+@group(4) @binding(3) var<storage, read> c : S;
+
+@group(5) @binding(4) var<storage, read> d : S;
+
+@compute @workgroup_size(1)
+fn f1() {
+ let x : i32 = (a.i + c.i);
+}
+
+@compute @workgroup_size(1)
+fn f2() {
+ let x : i32 = (b.i + d.i);
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ i : i32,
+}
+
+@group(1) @binding(1) var<storage, read> a : S;
+
+@group(1) @binding(1) var<storage, read> b : S;
+
+@group(5) @binding(4) var<storage, read> c : S;
+
+@group(5) @binding(4) var<storage, read> d : S;
+
+@compute @workgroup_size(1)
+fn f1() {
+ let x : i32 = (a.i + c.i);
+}
+
+@compute @workgroup_size(1)
+fn f2() {
+ let x : i32 = (b.i + d.i);
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<BindingRemapper::Remappings>(
+ BindingRemapper::BindingPoints{
+ {{2, 1}, {1, 1}},
+ {{3, 2}, {1, 1}},
+ {{4, 3}, {5, 4}},
+ },
+ BindingRemapper::AccessControls{}, true);
+ auto got = Run<BindingRemapper>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BindingRemapperTest, NoData) {
+ auto* src = R"(
+struct S {
+ a : f32,
+}
+
+@group(2) @binding(1) var<storage, read> a : S;
+
+@group(3) @binding(2) var<storage, read> b : S;
+
+@compute @workgroup_size(1)
+fn f() {
+}
+)";
+
+ auto* expect = R"(error: missing transform data for tint::ast::transform::BindingRemapper)";
+
+ auto got = Run<BindingRemapper>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/builtin_polyfill.cc b/src/tint/lang/wgsl/ast/transform/builtin_polyfill.cc
new file mode 100644
index 0000000..aa6e0d6
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/builtin_polyfill.cc
@@ -0,0 +1,1288 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/builtin_polyfill.h"
+
+#include <algorithm>
+#include <tuple>
+#include <unordered_map>
+#include <utility>
+
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/builtin.h"
+#include "src/tint/sem/call.h"
+#include "src/tint/sem/type_expression.h"
+#include "src/tint/sem/value_conversion.h"
+#include "src/tint/switch.h"
+#include "src/tint/type/storage_texture.h"
+#include "src/tint/type/texture_dimension.h"
+#include "src/tint/utils/map.h"
+
+using namespace tint::builtin::fluent_types; // NOLINT
+using namespace tint::number_suffixes; // NOLINT
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::BuiltinPolyfill);
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::BuiltinPolyfill::Config);
+
+namespace tint::ast::transform {
+
+/// BinaryOpSignature is tuple of a binary op, LHS type and RHS type
+using BinaryOpSignature = std::tuple<BinaryOp, const type::Type*, const type::Type*>;
+
+/// PIMPL state for the transform
+struct BuiltinPolyfill::State {
+ /// Constructor
+ /// @param program the source program
+ /// @param config the transform config
+ State(const Program* program, const Config& config) : src(program), cfg(config) {
+ has_full_ptr_params = false;
+ for (auto* enable : src->AST().Enables()) {
+ if (enable->HasExtension(builtin::Extension::kChromiumExperimentalFullPtrParameters)) {
+ has_full_ptr_params = true;
+ break;
+ }
+ }
+ }
+
+ /// Runs the transform
+ /// @returns the new program or SkipTransform if the transform is not required
+ Transform::ApplyResult Run() {
+ for (auto* node : src->ASTNodes().Objects()) {
+ Switch(
+ node, //
+ [&](const CallExpression* expr) { Call(expr); },
+ [&](const BinaryExpression* bin_op) {
+ if (auto* s = src->Sem().Get(bin_op);
+ !s || s->Stage() == sem::EvaluationStage::kConstant ||
+ s->Stage() == sem::EvaluationStage::kNotEvaluated) {
+ return; // Don't polyfill @const expressions
+ }
+ switch (bin_op->op) {
+ case BinaryOp::kShiftLeft:
+ case BinaryOp::kShiftRight: {
+ if (cfg.builtins.bitshift_modulo) {
+ ctx.Replace(bin_op,
+ [this, bin_op] { return BitshiftModulo(bin_op); });
+ made_changes = true;
+ }
+ break;
+ }
+ case BinaryOp::kDivide: {
+ if (cfg.builtins.int_div_mod) {
+ auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef();
+ if (lhs_ty->is_integer_scalar_or_vector()) {
+ ctx.Replace(bin_op,
+ [this, bin_op] { return IntDivMod(bin_op); });
+ made_changes = true;
+ }
+ }
+ break;
+ }
+ case BinaryOp::kModulo: {
+ if (cfg.builtins.int_div_mod) {
+ auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef();
+ if (lhs_ty->is_integer_scalar_or_vector()) {
+ ctx.Replace(bin_op,
+ [this, bin_op] { return IntDivMod(bin_op); });
+ made_changes = true;
+ }
+ }
+ if (cfg.builtins.precise_float_mod) {
+ auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef();
+ if (lhs_ty->is_float_scalar_or_vector()) {
+ ctx.Replace(bin_op,
+ [this, bin_op] { return PreciseFloatMod(bin_op); });
+ made_changes = true;
+ }
+ }
+ break;
+ }
+ default:
+ break;
+ }
+ },
+ [&](const Expression* expr) {
+ if (cfg.builtins.bgra8unorm) {
+ if (auto* ty_expr = src->Sem().Get<sem::TypeExpression>(expr)) {
+ if (auto* tex = ty_expr->Type()->As<type::StorageTexture>()) {
+ if (tex->texel_format() == builtin::TexelFormat::kBgra8Unorm) {
+ ctx.Replace(expr, [this, tex] {
+ return ctx.dst->Expr(ctx.dst->ty.storage_texture(
+ tex->dim(), builtin::TexelFormat::kRgba8Unorm,
+ tex->access()));
+ });
+ made_changes = true;
+ }
+ }
+ }
+ }
+ });
+ }
+
+ if (!made_changes) {
+ return SkipTransform;
+ }
+
+ ctx.Clone();
+ return Program(std::move(b));
+ }
+
+ private:
+ /// The source program
+ Program const* const src;
+ /// The transform config
+ const Config& cfg;
+ /// The destination program builder
+ ProgramBuilder b;
+ /// The clone context
+ CloneContext ctx{&b, src};
+ /// The semantic information of the source program
+ const sem::Info& sem = src->Sem();
+ /// Polyfill functions for binary operators.
+ utils::Hashmap<BinaryOpSignature, Symbol, 8> binary_op_polyfills;
+ /// Polyfill builtins.
+ utils::Hashmap<const sem::Builtin*, Symbol, 8> builtin_polyfills;
+ /// Polyfill f32 conversion to i32 or u32 (or vectors of)
+ utils::Hashmap<const type::Type*, Symbol, 2> f32_conv_polyfills;
+ /// Tracks whether the chromium_experimental_full_ptr_parameters extension has been enabled.
+ bool has_full_ptr_params = false;
+ /// True if the transform has made changes (i.e. the program needs cloning)
+ bool made_changes = false;
+
+ ////////////////////////////////////////////////////////////////////////////
+ // Function polyfills
+ ////////////////////////////////////////////////////////////////////////////
+
+ /// Builds the polyfill function for the `acosh` builtin
+ /// @param ty the parameter and return type for the function
+ /// @return the polyfill function name
+ Symbol acosh(const type::Type* ty) {
+ auto name = b.Symbols().New("tint_acosh");
+ uint32_t width = WidthOf(ty);
+
+ auto V = [&](AFloat value) -> const Expression* {
+ const Expression* expr = b.Expr(value);
+ if (width == 1) {
+ return expr;
+ }
+ return b.Call(T(ty), expr);
+ };
+
+ utils::Vector<const Statement*, 4> body;
+ switch (cfg.builtins.acosh) {
+ case Level::kFull:
+ // return log(x + sqrt(x*x - 1));
+ body.Push(b.Return(
+ b.Call("log", b.Add("x", b.Call("sqrt", b.Sub(b.Mul("x", "x"), 1_a))))));
+ break;
+ case Level::kRangeCheck: {
+ // return select(acosh(x), 0, x < 1);
+ body.Push(b.Return(
+ b.Call("select", b.Call("acosh", "x"), V(0.0_a), b.LessThan("x", V(1.0_a)))));
+ break;
+ }
+ default:
+ TINT_ICE(Transform, b.Diagnostics())
+ << "unhandled polyfill level: " << static_cast<int>(cfg.builtins.acosh);
+ return {};
+ }
+
+ b.Func(name, utils::Vector{b.Param("x", T(ty))}, T(ty), body);
+
+ return name;
+ }
+
+ /// Builds the polyfill function for the `asinh` builtin
+ /// @param ty the parameter and return type for the function
+ /// @return the polyfill function name
+ Symbol asinh(const type::Type* ty) {
+ auto name = b.Symbols().New("tint_sinh");
+
+ // return log(x + sqrt(x*x + 1));
+ b.Func(name, utils::Vector{b.Param("x", T(ty))}, T(ty),
+ utils::Vector{
+ b.Return(b.Call("log", b.Add("x", b.Call("sqrt", b.Add(b.Mul("x", "x"), 1_a))))),
+ });
+
+ return name;
+ }
+
+ /// Builds the polyfill function for the `atanh` builtin
+ /// @param ty the parameter and return type for the function
+ /// @return the polyfill function name
+ Symbol atanh(const type::Type* ty) {
+ auto name = b.Symbols().New("tint_atanh");
+ uint32_t width = WidthOf(ty);
+
+ auto V = [&](AFloat value) -> const Expression* {
+ const Expression* expr = b.Expr(value);
+ if (width == 1) {
+ return expr;
+ }
+ return b.Call(T(ty), expr);
+ };
+
+ utils::Vector<const Statement*, 1> body;
+ switch (cfg.builtins.atanh) {
+ case Level::kFull:
+ // return log((1+x) / (1-x)) * 0.5
+ body.Push(
+ b.Return(b.Mul(b.Call("log", b.Div(b.Add(1_a, "x"), b.Sub(1_a, "x"))), 0.5_a)));
+ break;
+ case Level::kRangeCheck:
+ // return select(atanh(x), 0, x >= 1);
+ body.Push(b.Return(b.Call("select", b.Call("atanh", "x"), V(0.0_a),
+ b.GreaterThanEqual("x", V(1.0_a)))));
+ break;
+ default:
+ TINT_ICE(Transform, b.Diagnostics())
+ << "unhandled polyfill level: " << static_cast<int>(cfg.builtins.atanh);
+ return {};
+ }
+
+ b.Func(name, utils::Vector{b.Param("x", T(ty))}, T(ty), body);
+
+ return name;
+ }
+
+ /// Builds the polyfill function for the `clamp` builtin when called with integer arguments
+ /// (scalar or vector)
+ /// @param ty the parameter and return type for the function
+ /// @return the polyfill function name
+ Symbol clampInteger(const type::Type* ty) {
+ auto name = b.Symbols().New("tint_clamp");
+
+ b.Func(name,
+ utils::Vector{
+ b.Param("e", T(ty)),
+ b.Param("low", T(ty)),
+ b.Param("high", T(ty)),
+ },
+ T(ty),
+ utils::Vector{
+ // return min(max(e, low), high);
+ b.Return(b.Call("min", b.Call("max", "e", "low"), "high")),
+ });
+ return name;
+ }
+
+ /// Builds the polyfill function for the `countLeadingZeros` builtin
+ /// @param ty the parameter and return type for the function
+ /// @return the polyfill function name
+ Symbol countLeadingZeros(const type::Type* ty) {
+ auto name = b.Symbols().New("tint_count_leading_zeros");
+ uint32_t width = WidthOf(ty);
+
+ // Returns either u32 or vecN<u32>
+ auto U = [&] {
+ if (width == 1) {
+ return b.ty.u32();
+ }
+ return b.ty.vec<u32>(width);
+ };
+ auto V = [&](uint32_t value) -> const Expression* {
+ return ScalarOrVector(width, u32(value));
+ };
+ b.Func(
+ name,
+ utils::Vector{
+ b.Param("v", T(ty)),
+ },
+ T(ty),
+ utils::Vector{
+ // var x = U(v);
+ b.Decl(b.Var("x", b.Call(U(), b.Expr("v")))),
+ // let b16 = select(0, 16, x <= 0x0000ffff);
+ b.Decl(b.Let("b16",
+ b.Call("select", V(0), V(16), b.LessThanEqual("x", V(0x0000ffff))))),
+ // x = x << b16;
+ b.Assign("x", b.Shl("x", "b16")),
+ // let b8 = select(0, 8, x <= 0x00ffffff);
+ b.Decl(
+ b.Let("b8", b.Call("select", V(0), V(8), b.LessThanEqual("x", V(0x00ffffff))))),
+ // x = x << b8;
+ b.Assign("x", b.Shl("x", "b8")),
+ // let b4 = select(0, 4, x <= 0x0fffffff);
+ b.Decl(
+ b.Let("b4", b.Call("select", V(0), V(4), b.LessThanEqual("x", V(0x0fffffff))))),
+ // x = x << b4;
+ b.Assign("x", b.Shl("x", "b4")),
+ // let b2 = select(0, 2, x <= 0x3fffffff);
+ b.Decl(
+ b.Let("b2", b.Call("select", V(0), V(2), b.LessThanEqual("x", V(0x3fffffff))))),
+ // x = x << b2;
+ b.Assign("x", b.Shl("x", "b2")),
+ // let b1 = select(0, 1, x <= 0x7fffffff);
+ b.Decl(
+ b.Let("b1", b.Call("select", V(0), V(1), b.LessThanEqual("x", V(0x7fffffff))))),
+ // let is_zero = select(0, 1, x == 0);
+ b.Decl(b.Let("is_zero", b.Call("select", V(0), V(1), b.Equal("x", V(0))))),
+ // return R((b16 | b8 | b4 | b2 | b1) + is_zero);
+ b.Return(b.Call(T(ty), b.Add(b.Or(b.Or(b.Or(b.Or("b16", "b8"), "b4"), "b2"), "b1"),
+ "is_zero"))),
+ });
+ return name;
+ }
+
+ /// Builds the polyfill function for the `countTrailingZeros` builtin
+ /// @param ty the parameter and return type for the function
+ /// @return the polyfill function name
+ Symbol countTrailingZeros(const type::Type* ty) {
+ auto name = b.Symbols().New("tint_count_trailing_zeros");
+ uint32_t width = WidthOf(ty);
+
+ // Returns either u32 or vecN<u32>
+ auto U = [&] {
+ if (width == 1) {
+ return b.ty.u32();
+ }
+ return b.ty.vec<u32>(width);
+ };
+ auto V = [&](uint32_t value) -> const Expression* {
+ return ScalarOrVector(width, u32(value));
+ };
+ auto B = [&](const Expression* value) -> const Expression* {
+ if (width == 1) {
+ return b.Call<bool>(value);
+ }
+ return b.Call(b.ty.vec<bool>(width), value);
+ };
+ b.Func(
+ name,
+ utils::Vector{
+ b.Param("v", T(ty)),
+ },
+ T(ty),
+ utils::Vector{
+ // var x = U(v);
+ b.Decl(b.Var("x", b.Call(U(), b.Expr("v")))),
+ // let b16 = select(16, 0, bool(x & 0x0000ffff));
+ b.Decl(b.Let("b16", b.Call("select", V(16), V(0), B(b.And("x", V(0x0000ffff)))))),
+ // x = x >> b16;
+ b.Assign("x", b.Shr("x", "b16")),
+ // let b8 = select(8, 0, bool(x & 0x000000ff));
+ b.Decl(b.Let("b8", b.Call("select", V(8), V(0), B(b.And("x", V(0x000000ff)))))),
+ // x = x >> b8;
+ b.Assign("x", b.Shr("x", "b8")),
+ // let b4 = select(4, 0, bool(x & 0x0000000f));
+ b.Decl(b.Let("b4", b.Call("select", V(4), V(0), B(b.And("x", V(0x0000000f)))))),
+ // x = x >> b4;
+ b.Assign("x", b.Shr("x", "b4")),
+ // let b2 = select(2, 0, bool(x & 0x00000003));
+ b.Decl(b.Let("b2", b.Call("select", V(2), V(0), B(b.And("x", V(0x00000003)))))),
+ // x = x >> b2;
+ b.Assign("x", b.Shr("x", "b2")),
+ // let b1 = select(1, 0, bool(x & 0x00000001));
+ b.Decl(b.Let("b1", b.Call("select", V(1), V(0), B(b.And("x", V(0x00000001)))))),
+ // let is_zero = select(0, 1, x == 0);
+ b.Decl(b.Let("is_zero", b.Call("select", V(0), V(1), b.Equal("x", V(0))))),
+ // return R((b16 | b8 | b4 | b2 | b1) + is_zero);
+ b.Return(b.Call(T(ty), b.Add(b.Or(b.Or(b.Or(b.Or("b16", "b8"), "b4"), "b2"), "b1"),
+ "is_zero"))),
+ });
+ return name;
+ }
+
+ /// Builds the polyfill function for the `extractBits` builtin
+ /// @param ty the parameter and return type for the function
+ /// @return the polyfill function name
+ Symbol extractBits(const type::Type* ty) {
+ auto name = b.Symbols().New("tint_extract_bits");
+ uint32_t width = WidthOf(ty);
+
+ constexpr uint32_t W = 32u; // 32-bit
+
+ auto vecN_u32 = [&](const Expression* value) -> const Expression* {
+ if (width == 1) {
+ return value;
+ }
+ return b.Call(b.ty.vec<u32>(width), value);
+ };
+
+ utils::Vector<const Statement*, 8> body{
+ b.Decl(b.Let("s", b.Call("min", "offset", u32(W)))),
+ b.Decl(b.Let("e", b.Call("min", u32(W), b.Add("s", "count")))),
+ };
+
+ switch (cfg.builtins.extract_bits) {
+ case Level::kFull:
+ body.Push(b.Decl(b.Let("shl", b.Sub(u32(W), "e"))));
+ body.Push(b.Decl(b.Let("shr", b.Add("shl", "s"))));
+ // Here we don't want the shl and shr to wrap the rhs, so handle the `rhs >= 32u`
+ // cases using `select`. In order to handle the signed shr `lhs >> rhs` correctly,
+ // use `(lhs >> 31u) >> 1u` if `rhs >= 32u`.
+ body.Push(b.Decl(b.Let("shl_result", b.Call("select", b.Call(T(ty)),
+ b.Shl("v", vecN_u32(b.Expr("shl"))),
+ b.LessThan("shl", 32_u)))));
+ body.Push(b.Return(b.Call(
+ "select",
+ b.Shr(b.Shr("shl_result", vecN_u32(b.Expr(31_u))), vecN_u32(b.Expr(1_u))),
+ b.Shr("shl_result", vecN_u32(b.Expr("shr"))), b.LessThan("shr", 32_u))
+
+ ));
+ break;
+ case Level::kClampParameters:
+ body.Push(b.Return(b.Call("extractBits", "v", "s", b.Sub("e", "s"))));
+ break;
+ default:
+ TINT_ICE(Transform, b.Diagnostics())
+ << "unhandled polyfill level: " << static_cast<int>(cfg.builtins.extract_bits);
+ return {};
+ }
+
+ b.Func(name,
+ utils::Vector{
+ b.Param("v", T(ty)),
+ b.Param("offset", b.ty.u32()),
+ b.Param("count", b.ty.u32()),
+ },
+ T(ty), std::move(body));
+
+ return name;
+ }
+
+ /// Builds the polyfill function for the `firstLeadingBit` builtin
+ /// @param ty the parameter and return type for the function
+ /// @return the polyfill function name
+ Symbol firstLeadingBit(const type::Type* ty) {
+ auto name = b.Symbols().New("tint_first_leading_bit");
+ uint32_t width = WidthOf(ty);
+
+ // Returns either u32 or vecN<u32>
+ auto U = [&] {
+ if (width == 1) {
+ return b.ty.u32();
+ }
+ return b.ty.vec<u32>(width);
+ };
+ auto V = [&](uint32_t value) -> const Expression* {
+ return ScalarOrVector(width, u32(value));
+ };
+ auto B = [&](const Expression* value) -> const Expression* {
+ if (width == 1) {
+ return b.Call<bool>(value);
+ }
+ return b.Call(b.ty.vec<bool>(width), value);
+ };
+
+ const Expression* x = nullptr;
+ if (ty->is_unsigned_integer_scalar_or_vector()) {
+ x = b.Expr("v");
+ } else {
+ // If ty is signed, then the value is inverted if the sign is negative
+ x = b.Call("select", //
+ b.Call(U(), "v"), //
+ b.Call(U(), b.Complement("v")), //
+ b.LessThan("v", ScalarOrVector(width, 0_i)));
+ }
+
+ b.Func(
+ name,
+ utils::Vector{
+ b.Param("v", T(ty)),
+ },
+ T(ty),
+ utils::Vector{
+ // var x = v; (unsigned)
+ // var x = select(U(v), U(~v), v < 0); (signed)
+ b.Decl(b.Var("x", x)),
+ // let b16 = select(0, 16, bool(x & 0xffff0000));
+ b.Decl(b.Let("b16", b.Call("select", V(0), V(16), B(b.And("x", V(0xffff0000)))))),
+ // x = x >> b16;
+ b.Assign("x", b.Shr("x", "b16")),
+ // let b8 = select(0, 8, bool(x & 0x0000ff00));
+ b.Decl(b.Let("b8", b.Call("select", V(0), V(8), B(b.And("x", V(0x0000ff00)))))),
+ // x = x >> b8;
+ b.Assign("x", b.Shr("x", "b8")),
+ // let b4 = select(0, 4, bool(x & 0x000000f0));
+ b.Decl(b.Let("b4", b.Call("select", V(0), V(4), B(b.And("x", V(0x000000f0)))))),
+ // x = x >> b4;
+ b.Assign("x", b.Shr("x", "b4")),
+ // let b2 = select(0, 2, bool(x & 0x0000000c));
+ b.Decl(b.Let("b2", b.Call("select", V(0), V(2), B(b.And("x", V(0x0000000c)))))),
+ // x = x >> b2;
+ b.Assign("x", b.Shr("x", "b2")),
+ // let b1 = select(0, 1, bool(x & 0x00000002));
+ b.Decl(b.Let("b1", b.Call("select", V(0), V(1), B(b.And("x", V(0x00000002)))))),
+ // let is_zero = select(0, 0xffffffff, x == 0);
+ b.Decl(b.Let("is_zero", b.Call("select", V(0), V(0xffffffff), b.Equal("x", V(0))))),
+ // return R(b16 | b8 | b4 | b2 | b1 | is_zero);
+ b.Return(b.Call(
+ T(ty), b.Or(b.Or(b.Or(b.Or(b.Or("b16", "b8"), "b4"), "b2"), "b1"), "is_zero"))),
+ });
+ return name;
+ }
+
+ /// Builds the polyfill function for the `firstTrailingBit` builtin
+ /// @param ty the parameter and return type for the function
+ /// @return the polyfill function name
+ Symbol firstTrailingBit(const type::Type* ty) {
+ auto name = b.Symbols().New("tint_first_trailing_bit");
+ uint32_t width = WidthOf(ty);
+
+ // Returns either u32 or vecN<u32>
+ auto U = [&] {
+ if (width == 1) {
+ return b.ty.u32();
+ }
+ return b.ty.vec<u32>(width);
+ };
+ auto V = [&](uint32_t value) -> const Expression* {
+ return ScalarOrVector(width, u32(value));
+ };
+ auto B = [&](const Expression* value) -> const Expression* {
+ if (width == 1) {
+ return b.Call<bool>(value);
+ }
+ return b.Call(b.ty.vec<bool>(width), value);
+ };
+ b.Func(
+ name,
+ utils::Vector{
+ b.Param("v", T(ty)),
+ },
+ T(ty),
+ utils::Vector{
+ // var x = U(v);
+ b.Decl(b.Var("x", b.Call(U(), b.Expr("v")))),
+ // let b16 = select(16, 0, bool(x & 0x0000ffff));
+ b.Decl(b.Let("b16", b.Call("select", V(16), V(0), B(b.And("x", V(0x0000ffff)))))),
+ // x = x >> b16;
+ b.Assign("x", b.Shr("x", "b16")),
+ // let b8 = select(8, 0, bool(x & 0x000000ff));
+ b.Decl(b.Let("b8", b.Call("select", V(8), V(0), B(b.And("x", V(0x000000ff)))))),
+ // x = x >> b8;
+ b.Assign("x", b.Shr("x", "b8")),
+ // let b4 = select(4, 0, bool(x & 0x0000000f));
+ b.Decl(b.Let("b4", b.Call("select", V(4), V(0), B(b.And("x", V(0x0000000f)))))),
+ // x = x >> b4;
+ b.Assign("x", b.Shr("x", "b4")),
+ // let b2 = select(2, 0, bool(x & 0x00000003));
+ b.Decl(b.Let("b2", b.Call("select", V(2), V(0), B(b.And("x", V(0x00000003)))))),
+ // x = x >> b2;
+ b.Assign("x", b.Shr("x", "b2")),
+ // let b1 = select(1, 0, bool(x & 0x00000001));
+ b.Decl(b.Let("b1", b.Call("select", V(1), V(0), B(b.And("x", V(0x00000001)))))),
+ // let is_zero = select(0, 0xffffffff, x == 0);
+ b.Decl(b.Let("is_zero", b.Call("select", V(0), V(0xffffffff), b.Equal("x", V(0))))),
+ // return R(b16 | b8 | b4 | b2 | b1 | is_zero);
+ b.Return(b.Call(
+ T(ty), b.Or(b.Or(b.Or(b.Or(b.Or("b16", "b8"), "b4"), "b2"), "b1"), "is_zero"))),
+ });
+ return name;
+ }
+
+ /// Builds the polyfill function for the `insertBits` builtin
+ /// @param ty the parameter and return type for the function
+ /// @return the polyfill function name
+ Symbol insertBits(const type::Type* ty) {  // emits `fn tint_insert_bits(v, n, offset, count) -> T`
+ auto name = b.Symbols().New("tint_insert_bits");
+ uint32_t width = WidthOf(ty);
+
+ // Currently in WGSL the parameters of insertBits must be i32, u32, vecN<i32> or vecN<u32>.
+ if (TINT_UNLIKELY(((!ty->DeepestElement()->IsAnyOf<type::I32, type::U32>())))) {
+ TINT_ICE(Transform, b.Diagnostics())
+ << "insertBits polyfill only support i32, u32, and vector of i32 or u32, got "
+ << ty->FriendlyName();
+ return {};
+ }
+
+ constexpr uint32_t W = 32u; // bit-width of the i32 / u32 element type
+
+ auto V = [&](auto value) -> const Expression* {  // wraps `value` so that it has type T
+ const Expression* expr = b.Expr(value);
+ if (!ty->is_unsigned_integer_scalar_or_vector()) {
+ expr = b.Call<i32>(expr);
+ }
+ if (ty->Is<type::Vector>()) {
+ expr = b.Call(T(ty), expr);
+ }
+ return expr;
+ };
+ auto U = [&](auto value) -> const Expression* {  // `value` as u32, or vecN<u32> if T is a vector
+ if (width == 1) {
+ return b.Expr(value);
+ }
+ return b.vec(b.ty.u32(), width, value);
+ };
+
+ // Polyfill algorithm:
+ // s = min(offset, 32u);
+ // e = min(32u, (s + count));
+ // mask = (((1u << s) - 1u) ^ ((1u << e) - 1u));
+ // return (((n << s) & mask) | (v & ~(mask)));
+ // Note that the algorithm above uses left-shifting in the C++ manner, but in WGSL, HLSL and MSL
+ // the rhs is taken modulo the bit-width of the lhs (32u in this case), and in GLSL the result
+ // is undefined if the rhs is greater than or equal to the bit-width of the lhs. The results of
+ // `x << y` in C++ and HLSL differ when `y >= 32u`, and the `s` and `e` defined above can be
+ // 32u, which are cases we must handle specially. Wherever y can be greater than or equal to
+ // 32u, replace `(x << y)` with `select(Tx(), x << y, y < 32u)`, in which `Tx` is the type of
+ // x.
+ // WGSL polyfill function:
+ // fn tint_insert_bits(v : T, n : T, offset : u32, count : u32) -> T {
+ // let e = offset + count;
+ // let mask = (
+ // (select(0u, 1u << offset, offset < 32u) - 1u) ^
+ // (select(0u, 1u << e, e < 32u) - 1u)
+ // );
+ // return ((select(T(), n << offset, offset < 32u) & mask) | (v & ~(mask)));
+ // }
+
+ utils::Vector<const Statement*, 8> body;
+
+ switch (cfg.builtins.insert_bits) {
+ case Level::kFull:
+ // let e = offset + count;
+ body.Push(b.Decl(b.Let("e", b.Add("offset", "count"))));
+
+ // let mask = (
+ // (select(0u, 1u << offset, offset < 32u) - 1u) ^
+ // (select(0u, 1u << e, e < 32u) - 1u)
+ // );
+ body.Push(b.Decl(b.Let(
+ "mask",
+ b.Xor( //
+ b.Sub(
+ b.Call("select", 0_u, b.Shl(1_u, "offset"), b.LessThan("offset", 32_u)),
+ 1_u),
+ b.Sub(b.Call("select", 0_u, b.Shl(1_u, "e"), b.LessThan("e", 32_u)),
+ 1_u) //
+ ))));
+
+ // return ((select(T(), n << offset, offset < 32u) & mask) | (v & ~(mask)));
+ body.Push(
+ b.Return(b.Or(b.And(b.Call("select", b.Call(T(ty)), b.Shl("n", U("offset")),
+ b.LessThan("offset", 32_u)),
+ V("mask")),
+ b.And("v", V(b.Complement("mask"))))));
+
+ break;
+ case Level::kClampParameters:  // clamp offset / count, then call the real insertBits
+ body.Push(b.Decl(b.Let("s", b.Call("min", "offset", u32(W)))));
+ body.Push(b.Decl(b.Let("e", b.Call("min", u32(W), b.Add("s", "count")))));
+ body.Push(b.Return(b.Call("insertBits", "v", "n", "s", b.Sub("e", "s"))));
+ break;
+ default:
+ TINT_ICE(Transform, b.Diagnostics())
+ << "unhandled polyfill level: " << static_cast<int>(cfg.builtins.insert_bits);
+ return {};
+ }
+
+ b.Func(name,
+ utils::Vector{
+ b.Param("v", T(ty)),
+ b.Param("n", T(ty)),
+ b.Param("offset", b.ty.u32()),
+ b.Param("count", b.ty.u32()),
+ },
+ T(ty), body);
+
+ return name;
+ }
+
+ /// Builds the polyfill function for the `reflect` builtin
+ /// @param ty the parameter and return type for the function
+ /// @return the polyfill function name
+ Symbol reflect(const type::Type* ty) {
+ auto name = b.Symbols().New("tint_reflect");
+
+ // WGSL polyfill function:
+ // fn tint_reflect(e1 : T, e2 : T) -> T {
+ // let factor = (-2.0 * dot(e1, e2));
+ // return (e1 + (factor * e2));
+ // }
+ // `factor` uses -2.0 rather than 2.0 to prevent an optimization that causes wrong results.
+ // See https://crbug.com/tint/1798 for more details.
+ auto body = utils::Vector{
+ b.Decl(b.Let("factor", b.Mul(-2.0_a, b.Call("dot", "e1", "e2")))),
+ b.Return(b.Add("e1", b.Mul("factor", "e2"))),
+ };
+ b.Func(name,
+ utils::Vector{
+ b.Param("e1", T(ty)),
+ b.Param("e2", T(ty)),
+ },
+ T(ty), body);
+
+ return name;
+ }
+
+ /// Builds the polyfill function for the `saturate` builtin, implemented with `clamp`
+ /// @param ty the parameter and return type for the function
+ /// @return the polyfill function name
+ Symbol saturate(const type::Type* ty) {
+ auto name = b.Symbols().New("tint_saturate");
+ auto body = utils::Vector{
+ b.Return(b.Call("clamp", "v", b.Call(T(ty), 0_a), b.Call(T(ty), 1_a))),  // clamp(v, T(0), T(1))
+ };
+ b.Func(name,
+ utils::Vector{
+ b.Param("v", T(ty)),
+ },
+ T(ty), body);
+
+ return name;
+ }
+
+ /// Builds the polyfill function for the `sign` builtin when the element type is integer
+ /// @param ty the parameter and return type for the function
+ /// @return the polyfill function name
+ Symbol sign_int(const type::Type* ty) {
+ const uint32_t width = WidthOf(ty);
+ auto zero = [&] { return ScalarOrVector(width, 0_a); };  // builds a fresh zero expression per use
+
+ // pos_or_neg_one = (v > 0) ? 1 : -1
+ auto pos_or_neg_one = b.Call("select", //
+ ScalarOrVector(width, -1_a), //
+ ScalarOrVector(width, 1_a), //
+ b.GreaterThan("v", zero()));
+
+ auto name = b.Symbols().New("tint_sign");
+ b.Func(name,
+ utils::Vector{
+ b.Param("v", T(ty)),
+ },
+ T(ty),
+ utils::Vector{
+ b.Return(b.Call("select", pos_or_neg_one, zero(), b.Equal("v", zero()))),  // (v == 0) ? 0 : ±1
+ });
+
+ return name;
+ }
+
+ /// Builds the polyfill function for the `textureSampleBaseClampToEdge` builtin, when the
+ /// texture type is texture_2d<f32>. Coordinates are clamped to [half_texel, 1 - half_texel].
+ /// @return the polyfill function name
+ Symbol textureSampleBaseClampToEdge_2d_f32() {
+ auto name = b.Symbols().New("tint_textureSampleBaseClampToEdge");
+ auto body = utils::Vector{
+ b.Decl(b.Let("dims", b.Call(b.ty.vec2<f32>(), b.Call("textureDimensions", "t", 0_a)))),
+ b.Decl(b.Let("half_texel", b.Div(b.Call<vec2<f32>>(0.5_a), "dims"))),  // 0.5 / dims
+ b.Decl(
+ b.Let("clamped", b.Call("clamp", "coord", "half_texel", b.Sub(1_a, "half_texel")))),
+ b.Return(b.Call("textureSampleLevel", "t", "s", "clamped", 0_a)),  // samples level 0
+ };
+ b.Func(name,
+ utils::Vector{
+ b.Param("t", b.ty.sampled_texture(type::TextureDimension::k2d, b.ty.f32())),
+ b.Param("s", b.ty.sampler(type::SamplerKind::kSampler)),
+ b.Param("coord", b.ty.vec2<f32>()),
+ },
+ b.ty.vec4<f32>(), body);
+ return name;
+ }
+
+ /// Builds the polyfill function for the `quantizeToF16` builtin, by replacing the vector form
+ /// with one scalar `quantizeToF16` call per element.
+ /// @param vec the vector type
+ /// @return the polyfill function name
+ Symbol quantizeToF16(const type::Vector* vec) {
+ auto name = b.Symbols().New("tint_quantizeToF16");
+ utils::Vector<const Expression*, 4> args;
+ for (uint32_t i = 0; i < vec->Width(); i++) {
+ args.Push(b.Call("quantizeToF16", b.IndexAccessor("v", u32(i))));  // quantizeToF16(v[i])
+ }
+ b.Func(name,
+ utils::Vector{
+ b.Param("v", T(vec)),
+ },
+ T(vec),
+ utils::Vector{
+ b.Return(b.Call(T(vec), std::move(args))),  // rebuild the vector from the scalar results
+ });
+ return name;
+ }
+
+ /// Builds the polyfill function for the `workgroupUniformLoad` builtin.
+ /// @param type the type being loaded
+ /// @return the polyfill function name
+ Symbol workgroupUniformLoad(const type::Type* type) {
+ if (!has_full_ptr_params) {  // the helper takes a ptr<workgroup, T> parameter
+ b.Enable(builtin::Extension::kChromiumExperimentalFullPtrParameters);
+ has_full_ptr_params = true;  // remember, so the extension is not enabled again
+ }
+ auto name = b.Symbols().New("tint_workgroupUniformLoad");
+ b.Func(name,
+ utils::Vector{
+ b.Param("p", b.ty.ptr<workgroup>(T(type))),
+ },
+ T(type),
+ utils::Vector{
+ b.CallStmt(b.Call("workgroupBarrier")),
+ b.Decl(b.Let("result", b.Deref("p"))),  // the load is bracketed by the two barriers
+ b.CallStmt(b.Call("workgroupBarrier")),
+ b.Return("result"),
+ });
+ return name;
+ }
+
+ /// Builds the polyfill function to value convert a scalar or vector of f32 to an i32 or u32 (or
+ /// vector of), returning a fixed limit value when the input is outside the condition bounds.
+ /// @param source the type of the value being converted
+ /// @param target the target conversion type
+ /// @return the polyfill function name
+ Symbol ConvF32ToIU32(const type::Type* source, const type::Type* target) {
+ struct Limits {
+ AFloat low_condition;  // inputs below this yield `low_limit`
+ AInt low_limit;
+ AFloat high_condition;  // inputs not below this yield `high_limit`
+ AInt high_limit;
+ };
+ const bool is_signed = target->is_signed_integer_scalar_or_vector();
+ const Limits limits = is_signed ? Limits{
+ /* low_condition */ -AFloat(0x80000000),
+ /* low_limit */ -AInt(0x80000000),
+ /* high_condition */ AFloat(0x7fffff80),
+ /* high_limit */ AInt(0x7fffffff),
+ }
+ : Limits{
+ /* low_condition */ AFloat(0),
+ /* low_limit */ AInt(0),
+ /* high_condition */ AFloat(0xffffff00),
+ /* high_limit */ AInt(0xffffffff),
+ };
+
+ const uint32_t width = WidthOf(target);
+
+ // select(target(v), low_limit, v < low_condition)
+ auto* select_low = b.Call(builtin::Function::kSelect, //
+ b.Call(T(target), "v"), //
+ ScalarOrVector(width, limits.low_limit), //
+ b.LessThan("v", ScalarOrVector(width, limits.low_condition)));
+
+ // select(high_limit, select_low, v < high_condition)
+ auto* select_high = b.Call(builtin::Function::kSelect, //
+ ScalarOrVector(width, limits.high_limit), //
+ select_low, //
+ b.LessThan("v", ScalarOrVector(width, limits.high_condition)));
+
+ auto name = b.Symbols().New(is_signed ? "tint_ftoi" : "tint_ftou");
+ b.Func(name, utils::Vector{b.Param("v", T(source))}, T(target),
+ utils::Vector{b.Return(select_high)});
+ return name;
+ }
+
+ ////////////////////////////////////////////////////////////////////////////
+ // Inline polyfills
+ ////////////////////////////////////////////////////////////////////////////
+
+ /// Builds the polyfill inline expression for a bitshift left or bitshift right, ensuring that
+ /// the RHS is modulo the bit-width of the LHS, by and-ing the RHS with (bit-width - 1).
+ /// @param bin_op the original BinaryExpression
+ /// @return the polyfill value for bitshift operation
+ const Expression* BitshiftModulo(const BinaryExpression* bin_op) {
+ auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef();
+ auto* rhs_ty = src->TypeOf(bin_op->rhs)->UnwrapRef();
+ auto* lhs_el_ty = lhs_ty->DeepestElement();
+ const Expression* mask = b.Expr(AInt(lhs_el_ty->Size() * 8 - 1));  // LHS element bits - 1
+ if (rhs_ty->Is<type::Vector>()) {
+ mask = b.Call(CreateASTTypeFor(ctx, rhs_ty), mask);  // splat the mask to the RHS vector type
+ }
+ auto* lhs = ctx.Clone(bin_op->lhs);
+ auto* rhs = b.And(ctx.Clone(bin_op->rhs), mask);
+ return b.create<BinaryExpression>(ctx.Clone(bin_op->source), bin_op->op, lhs, rhs);
+ }
+
+ /// Builds the polyfill inline expression for an integer divide or modulo, preventing DBZs and
+ /// integer overflows. One helper function is generated per (operator, lhs type, rhs type).
+ /// @param bin_op the original BinaryExpression
+ /// @return a call to the polyfill divide or modulo function
+ const Expression* IntDivMod(const BinaryExpression* bin_op) {
+ auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef();
+ auto* rhs_ty = src->TypeOf(bin_op->rhs)->UnwrapRef();
+ BinaryOpSignature sig{bin_op->op, lhs_ty, rhs_ty};
+ auto fn = binary_op_polyfills.GetOrCreate(sig, [&] {
+ const bool is_div = bin_op->op == BinaryOp::kDivide;
+
+ const auto [lhs_el_ty, lhs_width] = lhs_ty->Elements(lhs_ty, 1);
+ const auto [rhs_el_ty, rhs_width] = rhs_ty->Elements(rhs_ty, 1);
+
+ const uint32_t width = std::max(lhs_width, rhs_width);
+
+ const char* lhs = "lhs";
+ const char* rhs = "rhs";
+
+ utils::Vector<const Statement*, 4> body;
+
+ if (lhs_width < width) {
+ // lhs is scalar, rhs is vector. Convert lhs to vector.
+ body.Push(b.Decl(b.Let("l", b.vec(T(lhs_el_ty), width, b.Expr(lhs)))));
+ lhs = "l";
+ }
+ if (rhs_width < width) {
+ // lhs is vector, rhs is scalar. Convert rhs to vector.
+ body.Push(b.Decl(b.Let("r", b.vec(T(rhs_el_ty), width, b.Expr(rhs)))));
+ rhs = "r";
+ }
+
+ auto name = b.Symbols().New(is_div ? "tint_div" : "tint_mod");
+
+ auto* rhs_is_zero = b.Equal(rhs, ScalarOrVector(width, 0_a));
+
+ if (lhs_ty->is_signed_integer_scalar_or_vector()) {
+ const auto bits = lhs_el_ty->Size() * 8;
+ auto min_int = AInt(AInt::kLowestValue >> (AInt::kNumBits - bits));
+ const Expression* lhs_is_min = b.Equal(lhs, ScalarOrVector(width, min_int));
+ const Expression* rhs_is_minus_one = b.Equal(rhs, ScalarOrVector(width, -1_a));
+ // use_one = rhs_is_zero | ((lhs == MIN_INT) & (rhs == -1))
+ auto* use_one = b.Or(rhs_is_zero, b.And(lhs_is_min, rhs_is_minus_one));
+
+ // Special handling for mod in case either operand is negative, as negative operands
+ // for % is undefined behaviour for most backends (HLSL, MSL, GLSL, SPIR-V).
+ if (!is_div) {
+ const char* rhs_or_one = "rhs_or_one";
+ body.Push(b.Decl(b.Let(
+ rhs_or_one, b.Call("select", rhs, ScalarOrVector(width, 1_a), use_one))));
+
+ // Is either operand negative?
+ // (lhs | rhs) & (1<<31)
+ auto sign_bit_mask = ScalarOrVector(width, u32(1u << (bits - 1)));  // unsigned shift
+ auto* lhs_or_rhs = CastScalarOrVector<u32>(width, b.Or(lhs, rhs_or_one));
+ auto* lhs_or_rhs_is_neg =
+ b.NotEqual(b.And(lhs_or_rhs, sign_bit_mask), ScalarOrVector(width, 0_u));
+
+ // lhs - trunc(lhs / rhs) * rhs (note: integral division truncates)
+ auto* slow_mod = b.Sub(lhs, b.Mul(b.Div(lhs, rhs_or_one), rhs_or_one));
+
+ // lhs % rhs
+ auto* fast_mod = b.Mod(lhs, rhs_or_one);
+
+ auto* use_slow = b.Call("any", lhs_or_rhs_is_neg);
+
+ body.Push(b.If(use_slow, b.Block(b.Return(slow_mod)),
+ b.Else(b.Block(b.Return(fast_mod)))));
+
+ } else {
+ auto* rhs_or_one = b.Call("select", rhs, ScalarOrVector(width, 1_a), use_one);
+ body.Push(b.Return(b.Div(lhs, rhs_or_one)));  // is_div is always true in this branch
+ }
+
+ } else {
+ auto* rhs_or_one = b.Call("select", rhs, ScalarOrVector(width, 1_a), rhs_is_zero);
+ body.Push(b.Return(is_div ? b.Div(lhs, rhs_or_one) : b.Mod(lhs, rhs_or_one)));
+ }
+
+ b.Func(name,
+ utils::Vector{
+ b.Param("lhs", T(lhs_ty)),
+ b.Param("rhs", T(rhs_ty)),
+ },
+ width == 1 ? T(lhs_ty) : b.ty.vec(T(lhs_el_ty), width), // return type
+ std::move(body));
+
+ return name;
+ });
+ auto* lhs = ctx.Clone(bin_op->lhs);
+ auto* rhs = ctx.Clone(bin_op->rhs);
+ return b.Call(fn, lhs, rhs);
+ }
+
+ /// Builds the polyfill inline expression for a precise float modulo, as defined in the spec.
+ /// @param bin_op the original BinaryExpression
+ /// @return a call to the polyfill modulo function
+ const Expression* PreciseFloatMod(const BinaryExpression* bin_op) {
+ auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef();
+ auto* rhs_ty = src->TypeOf(bin_op->rhs)->UnwrapRef();
+ BinaryOpSignature sig{bin_op->op, lhs_ty, rhs_ty};
+ auto fn = binary_op_polyfills.GetOrCreate(sig, [&] {  // one helper per (op, lhs type, rhs type)
+ const auto [lhs_el_ty, lhs_width] = lhs_ty->Elements(lhs_ty, 1);
+ const auto [rhs_el_ty, rhs_width] = rhs_ty->Elements(rhs_ty, 1);
+
+ const uint32_t width = std::max(lhs_width, rhs_width);
+
+ const char* lhs = "lhs";
+ const char* rhs = "rhs";
+
+ utils::Vector<const Statement*, 4> body;
+
+ if (lhs_width < width) {
+ // lhs is scalar, rhs is vector. Convert lhs to vector.
+ body.Push(b.Decl(b.Let("l", b.vec(T(lhs_el_ty), width, b.Expr(lhs)))));
+ lhs = "l";
+ }
+ if (rhs_width < width) {
+ // lhs is vector, rhs is scalar. Convert rhs to vector.
+ body.Push(b.Decl(b.Let("r", b.vec(T(rhs_el_ty), width, b.Expr(rhs)))));
+ rhs = "r";
+ }
+
+ auto name = b.Symbols().New("tint_float_mod");
+
+ // lhs - trunc(lhs / rhs) * rhs
+ auto* precise_mod = b.Sub(lhs, b.Mul(b.Call("trunc", b.Div(lhs, rhs)), rhs));
+ body.Push(b.Return(precise_mod));
+
+ b.Func(name,
+ utils::Vector{
+ b.Param("lhs", T(lhs_ty)),
+ b.Param("rhs", T(rhs_ty)),
+ },
+ width == 1 ? T(lhs_ty) : b.ty.vec(T(lhs_el_ty), width), // return type
+ std::move(body));
+
+ return name;
+ });
+ auto* lhs = ctx.Clone(bin_op->lhs);  // the original operands become the helper's arguments
+ auto* rhs = ctx.Clone(bin_op->rhs);
+ return b.Call(fn, lhs, rhs);
+ }
+
+ /// @returns the AST type that CreateASTTypeFor() produces for the semantic type `ty`
+ Type T(const type::Type* ty) { return CreateASTTypeFor(ctx, ty); }
+
+ /// @returns the element count of `ty`: the vector width, or 1 for any non-vector type
+ uint32_t WidthOf(const type::Type* ty) const {
+ auto* vec_ty = ty->As<type::Vector>();
+ if (vec_ty == nullptr) {
+ return 1;
+ }
+ return vec_ty->Width();
+ }
+
+ /// @returns an expression for `value` itself when `width` is 1, otherwise a vector of
+ /// `width` elements constructed from `value`.
+ template <typename T>
+ const Expression* ScalarOrVector(uint32_t width, T value) {
+ if (width != 1) {
+ return b.Call(b.ty.vec<T>(width), value);
+ }
+ return b.Expr(value);
+ }
+
+ template <typename To>  // To: the element type the expression is converted to
+ const Expression* CastScalarOrVector(uint32_t width, const Expression* e) {
+ if (width != 1) {  // vector: convert to a `width`-element vector of To
+ return b.Call(b.ty.vec<To>(width), e);
+ }
+ return b.Call(b.ty.Of<To>(), e);
+ }
+
+ /// Examines the call expression @p expr, applying any necessary polyfill transforms
+ void Call(const CallExpression* expr) {
+ auto* call = src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>();
+ if (!call || call->Stage() == sem::EvaluationStage::kConstant ||
+ call->Stage() == sem::EvaluationStage::kNotEvaluated) {
+ return; // Don't polyfill @const expressions
+ }
+ Symbol fn = Switch(  // the polyfill function to call instead, or an invalid Symbol for none
+ call->Target(), //
+ [&](const sem::Builtin* builtin) {
+ switch (builtin->Type()) {
+ case builtin::Function::kAcosh:
+ if (cfg.builtins.acosh != Level::kNone) {
+ return builtin_polyfills.GetOrCreate(
+ builtin, [&] { return acosh(builtin->ReturnType()); });
+ }
+ return Symbol{};
+
+ case builtin::Function::kAsinh:
+ if (cfg.builtins.asinh) {
+ return builtin_polyfills.GetOrCreate(
+ builtin, [&] { return asinh(builtin->ReturnType()); });
+ }
+ return Symbol{};
+
+ case builtin::Function::kAtanh:
+ if (cfg.builtins.atanh != Level::kNone) {
+ return builtin_polyfills.GetOrCreate(
+ builtin, [&] { return atanh(builtin->ReturnType()); });
+ }
+ return Symbol{};
+
+ case builtin::Function::kClamp:
+ if (cfg.builtins.clamp_int) {
+ auto& sig = builtin->Signature();
+ if (sig.parameters[0]->Type()->is_integer_scalar_or_vector()) {
+ return builtin_polyfills.GetOrCreate(
+ builtin, [&] { return clampInteger(builtin->ReturnType()); });
+ }
+ }
+ return Symbol{};
+
+ case builtin::Function::kCountLeadingZeros:
+ if (cfg.builtins.count_leading_zeros) {
+ return builtin_polyfills.GetOrCreate(
+ builtin, [&] { return countLeadingZeros(builtin->ReturnType()); });
+ }
+ return Symbol{};
+
+ case builtin::Function::kCountTrailingZeros:
+ if (cfg.builtins.count_trailing_zeros) {
+ return builtin_polyfills.GetOrCreate(
+ builtin, [&] { return countTrailingZeros(builtin->ReturnType()); });
+ }
+ return Symbol{};
+
+ case builtin::Function::kExtractBits:
+ if (cfg.builtins.extract_bits != Level::kNone) {
+ return builtin_polyfills.GetOrCreate(
+ builtin, [&] { return extractBits(builtin->ReturnType()); });
+ }
+ return Symbol{};
+
+ case builtin::Function::kFirstLeadingBit:
+ if (cfg.builtins.first_leading_bit) {
+ return builtin_polyfills.GetOrCreate(
+ builtin, [&] { return firstLeadingBit(builtin->ReturnType()); });
+ }
+ return Symbol{};
+
+ case builtin::Function::kFirstTrailingBit:
+ if (cfg.builtins.first_trailing_bit) {
+ return builtin_polyfills.GetOrCreate(
+ builtin, [&] { return firstTrailingBit(builtin->ReturnType()); });
+ }
+ return Symbol{};
+
+ case builtin::Function::kInsertBits:
+ if (cfg.builtins.insert_bits != Level::kNone) {
+ return builtin_polyfills.GetOrCreate(
+ builtin, [&] { return insertBits(builtin->ReturnType()); });
+ }
+ return Symbol{};
+
+ case builtin::Function::kReflect:
+ // Only polyfill for vec2<f32>. See https://crbug.com/tint/1798 for
+ // more details.
+ if (cfg.builtins.reflect_vec2_f32) {
+ auto& sig = builtin->Signature();
+ auto* vec = sig.return_type->As<type::Vector>();
+ if (vec && vec->Width() == 2 && vec->type()->Is<type::F32>()) {
+ return builtin_polyfills.GetOrCreate(
+ builtin, [&] { return reflect(builtin->ReturnType()); });
+ }
+ }
+ return Symbol{};
+
+ case builtin::Function::kSaturate:
+ if (cfg.builtins.saturate) {
+ return builtin_polyfills.GetOrCreate(
+ builtin, [&] { return saturate(builtin->ReturnType()); });
+ }
+ return Symbol{};
+
+ case builtin::Function::kSign:
+ if (cfg.builtins.sign_int) {
+ auto* ty = builtin->ReturnType();
+ if (ty->is_signed_integer_scalar_or_vector()) {
+ return builtin_polyfills.GetOrCreate(builtin,
+ [&] { return sign_int(ty); });
+ }
+ }
+ return Symbol{};
+
+ case builtin::Function::kTextureSampleBaseClampToEdge:
+ if (cfg.builtins.texture_sample_base_clamp_to_edge_2d_f32) {
+ auto& sig = builtin->Signature();
+ auto* tex = sig.Parameter(sem::ParameterUsage::kTexture);
+ if (auto* stex = tex->Type()->As<type::SampledTexture>()) {
+ if (stex->type()->Is<type::F32>()) {
+ return builtin_polyfills.GetOrCreate(builtin, [&] {
+ return textureSampleBaseClampToEdge_2d_f32();
+ });
+ }
+ }
+ }
+ return Symbol{};
+
+ case builtin::Function::kTextureStore:  // rewritten in place; no helper function is built
+ if (cfg.builtins.bgra8unorm) {
+ auto& sig = builtin->Signature();
+ auto* tex = sig.Parameter(sem::ParameterUsage::kTexture);
+ if (auto* stex = tex->Type()->As<type::StorageTexture>()) {
+ if (stex->texel_format() == builtin::TexelFormat::kBgra8Unorm) {
+ size_t value_idx = static_cast<size_t>(
+ sig.IndexOf(sem::ParameterUsage::kValue));
+ ctx.Replace(expr, [this, expr, value_idx] {
+ utils::Vector<const Expression*, 3> args;
+ for (auto* arg : expr->args) {
+ arg = ctx.Clone(arg);
+ if (args.Length() == value_idx) { // value
+ arg = ctx.dst->MemberAccessor(arg, "bgra");
+ }
+ args.Push(arg);
+ }
+ return ctx.dst->Call(
+ utils::ToString(builtin::Function::kTextureStore),
+ std::move(args));
+ });
+ made_changes = true;
+ }
+ }
+ }
+ return Symbol{};
+
+ case builtin::Function::kQuantizeToF16:
+ if (cfg.builtins.quantize_to_vec_f16) {
+ if (auto* vec = builtin->ReturnType()->As<type::Vector>()) {
+ return builtin_polyfills.GetOrCreate(
+ builtin, [&] { return quantizeToF16(vec); });
+ }
+ }
+ return Symbol{};
+
+ case builtin::Function::kWorkgroupUniformLoad:
+ if (cfg.builtins.workgroup_uniform_load) {
+ return builtin_polyfills.GetOrCreate(builtin, [&] {
+ return workgroupUniformLoad(builtin->ReturnType());
+ });
+ }
+ return Symbol{};
+
+ default:
+ return Symbol{};
+ }
+ },
+ [&](const sem::ValueConversion* conv) {  // f32 -> i32 / u32 value conversions
+ if (cfg.builtins.conv_f32_to_iu32) {
+ auto* src_ty = conv->Source();
+ if (tint::Is<type::F32>(src_ty->Elements(src_ty).type)) {
+ auto* dst_ty = conv->Target();
+ if (tint::utils::IsAnyOf<type::I32, type::U32>(
+ dst_ty->Elements(dst_ty).type)) {
+ return f32_conv_polyfills.GetOrCreate(dst_ty, [&] { //
+ return ConvF32ToIU32(src_ty, dst_ty);
+ });
+ }
+ }
+ }
+ return Symbol{};
+ });
+
+ if (fn.IsValid()) {  // swap the original call for a call to the polyfill, same arguments
+ ctx.Replace(call->Declaration(),
+ [this, fn, expr] { return ctx.dst->Call(fn, ctx.Clone(expr->args)); });
+ made_changes = true;
+ }
+ }
+};
+
+BuiltinPolyfill::BuiltinPolyfill() = default;  // configuration is read from the DataMap in Apply()
+
+BuiltinPolyfill::~BuiltinPolyfill() = default;
+
+Transform::ApplyResult BuiltinPolyfill::Apply(const Program* src,
+ const DataMap& data,
+ DataMap&) const {
+ // Without a Config in the input data there is nothing to polyfill, so the transform is skipped.
+ if (auto* polyfill_cfg = data.Get<Config>()) {
+ return State{src, *polyfill_cfg}.Run();
+ }
+ return SkipTransform;
+}
+
+BuiltinPolyfill::Config::Config(const Builtins& b) : builtins(b) {}  // keeps its own copy of `b`
+BuiltinPolyfill::Config::Config(const Config&) = default;
+BuiltinPolyfill::Config::~Config() = default;
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/builtin_polyfill.h b/src/tint/lang/wgsl/ast/transform/builtin_polyfill.h
new file mode 100644
index 0000000..0bdf89f
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/builtin_polyfill.h
@@ -0,0 +1,118 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_BUILTIN_POLYFILL_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_BUILTIN_POLYFILL_H_
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+
+namespace tint::ast::transform {
+
+/// Implements builtins for backends that do not have a native implementation.
+class BuiltinPolyfill final : public utils::Castable<BuiltinPolyfill, Transform> {
+ public:
+ /// Constructor
+ BuiltinPolyfill();
+ /// Destructor
+ ~BuiltinPolyfill() override;
+
+ /// Enumerator of polyfill levels
+ enum class Level {
+ /// No polyfill needed, supported by the backend.
+ kNone,
+ /// Clamp the parameters to the inner implementation.
+ kClampParameters,
+ /// Range check the input.
+ kRangeCheck,
+ /// Polyfill the entire function
+ kFull,
+ };
+
+ /// Specifies the builtins that should be polyfilled by the transform.
+ struct Builtins {
+ /// What level should `acosh` be polyfilled?
+ Level acosh = Level::kNone;
+ /// Should `asinh` be polyfilled?
+ bool asinh = false;
+ /// What level should `atanh` be polyfilled?
+ Level atanh = Level::kNone;
+ /// Should storage textures of format 'bgra8unorm' be replaced with 'rgba8unorm'?
+ bool bgra8unorm = false;
+ /// Should the RHS of `<<` and `>>` be wrapped in a modulo bit-width of LHS?
+ bool bitshift_modulo = false;
+ /// Should `clamp()` be polyfilled for integer values (scalar or vector)?
+ bool clamp_int = false;
+ /// Should `countLeadingZeros()` be polyfilled?
+ bool count_leading_zeros = false;
+ /// Should `countTrailingZeros()` be polyfilled?
+ bool count_trailing_zeros = false;
+ /// Should converting f32 to i32 or u32 be polyfilled?
+ bool conv_f32_to_iu32 = false;
+ /// What level should `extractBits()` be polyfilled?
+ Level extract_bits = Level::kNone;
+ /// Should `firstLeadingBit()` be polyfilled?
+ bool first_leading_bit = false;
+ /// Should `firstTrailingBit()` be polyfilled?
+ bool first_trailing_bit = false;
+ /// What level should `insertBits()` be polyfilled?
+ Level insert_bits = Level::kNone;
+ /// Should integer scalar / vector divides and modulos be polyfilled to avoid DBZ and
+ /// integer overflows?
+ bool int_div_mod = false;
+ /// Should float modulos be polyfilled to emit a precise modulo operation as per the spec?
+ bool precise_float_mod = false;
+ /// Should `reflect()` be polyfilled for vec2<f32>?
+ bool reflect_vec2_f32 = false;
+ /// Should `saturate()` be polyfilled?
+ bool saturate = false;
+ /// Should `sign()` be polyfilled for integer types?
+ bool sign_int = false;
+ /// Should `textureSampleBaseClampToEdge()` be polyfilled for texture_2d<f32> textures?
+ bool texture_sample_base_clamp_to_edge_2d_f32 = false;
+ /// Should the vector form of `quantizeToF16()` be polyfilled with a scalar implementation?
+ /// See crbug.com/tint/1741
+ bool quantize_to_vec_f16 = false;
+ /// Should `workgroupUniformLoad()` be polyfilled?
+ bool workgroup_uniform_load = false;
+ };
+
+ /// Config is consumed by the BuiltinPolyfill transform.
+ /// Config specifies the builtins that should be polyfilled.
+ struct Config final : public utils::Castable<Config, Data> {
+ /// Constructor
+ /// @param b the list of builtins to polyfill
+ explicit Config(const Builtins& b);
+
+ /// Copy constructor
+ Config(const Config&);
+
+ /// Destructor
+ ~Config() override;
+
+ /// The builtins to polyfill
+ const Builtins builtins;
+ };
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+
+ private:
+ struct State;  // implementation detail, defined in the .cc file
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_BUILTIN_POLYFILL_H_
diff --git a/src/tint/lang/wgsl/ast/transform/builtin_polyfill_test.cc b/src/tint/lang/wgsl/ast/transform/builtin_polyfill_test.cc
new file mode 100644
index 0000000..505632e
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/builtin_polyfill_test.cc
@@ -0,0 +1,4033 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/builtin_polyfill.h"
+
+#include <utility>
+
+#include "src/tint/lang/wgsl/ast/transform/direct_variable_access.h"
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using Level = BuiltinPolyfill::Level;  // shorthand for the polyfill level enum
+
+using BuiltinPolyfillTest = TransformTest;  // fixture used by the TEST_F cases below
+
+TEST_F(BuiltinPolyfillTest, ShouldRunEmptyModule) {  // empty module, no Config supplied
+ auto* src = R"()";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));
+}
+
+TEST_F(BuiltinPolyfillTest, EmptyModule) {  // NOTE(review): body is identical to ShouldRunEmptyModule
+ auto* src = R"()";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// acosh
+////////////////////////////////////////////////////////////////////////////////
+Transform::DataMap polyfillAcosh(Level level) {
+ Transform::DataMap inputs;
+ BuiltinPolyfill::Builtins enabled;
+ enabled.acosh = level;  // only `acosh` is requested; every other field keeps its default
+ inputs.Add<BuiltinPolyfill::Config>(enabled);
+ return inputs;
+}
+
+TEST_F(BuiltinPolyfillTest, ShouldRunAcosh) {  // runs only when acosh is kRangeCheck or kFull
+ auto* src = R"(
+fn f() {
+ let v = 1.0;
+ _ = acosh(v);
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src, polyfillAcosh(Level::kNone)));
+ EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillAcosh(Level::kRangeCheck)));
+ EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillAcosh(Level::kFull)));
+}
+
+TEST_F(BuiltinPolyfillTest, Acosh_ConstantExpression) {  // a literal argument is not polyfilled
+ auto* src = R"(
+fn f() {
+ let r : f32 = acosh(1.0);
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src, polyfillAcosh(Level::kFull)));
+}
+
+TEST_F(BuiltinPolyfillTest, Acosh_Full_f32) {  // kFull, scalar: acosh(v) becomes tint_acosh(v)
+ auto* src = R"(
+fn f() {
+ let v = 1.0;
+ let r : f32 = acosh(v);
+}
+)";
+
+ auto* expect = R"(
+fn tint_acosh(x : f32) -> f32 {
+ return log((x + sqrt(((x * x) - 1))));
+}
+
+fn f() {
+ let v = 1.0;
+ let r : f32 = tint_acosh(v);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillAcosh(Level::kFull));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, Acosh_Full_vec3_f32) {  // kFull, vec3<f32>: same body, vector types
+ auto* src = R"(
+fn f() {
+ let v = 1.0;
+ let r : vec3<f32> = acosh(vec3<f32>(v));
+}
+)";
+
+ auto* expect = R"(
+fn tint_acosh(x : vec3<f32>) -> vec3<f32> {
+ return log((x + sqrt(((x * x) - 1))));
+}
+
+fn f() {
+ let v = 1.0;
+ let r : vec3<f32> = tint_acosh(vec3<f32>(v));
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillAcosh(Level::kFull));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, Acosh_Range_f32) {  // kRangeCheck, scalar: yields 0.0 when x < 1.0
+ auto* src = R"(
+fn f() {
+ let v = 1.0;
+ let r : f32 = acosh(v);
+}
+)";
+
+ auto* expect = R"(
+fn tint_acosh(x : f32) -> f32 {
+ return select(acosh(x), 0.0, (x < 1.0));
+}
+
+fn f() {
+ let v = 1.0;
+ let r : f32 = tint_acosh(v);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillAcosh(Level::kRangeCheck));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, Acosh_Range_vec3_f32) {  // kRangeCheck, vec3<f32>: per-element select
+ auto* src = R"(
+fn f() {
+ let v = 1.0;
+ let r : vec3<f32> = acosh(vec3<f32>(v));
+}
+)";
+
+ auto* expect = R"(
+fn tint_acosh(x : vec3<f32>) -> vec3<f32> {
+ return select(acosh(x), vec3<f32>(0.0), (x < vec3<f32>(1.0)));
+}
+
+fn f() {
+ let v = 1.0;
+ let r : vec3<f32> = tint_acosh(vec3<f32>(v));
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillAcosh(Level::kRangeCheck));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// asinh
+////////////////////////////////////////////////////////////////////////////////
+Transform::DataMap polyfillSinh() {
+ Transform::DataMap inputs;
+ BuiltinPolyfill::Builtins enabled;
+ enabled.asinh = true;  // despite the helper's name, this switches on the `asinh` polyfill
+ inputs.Add<BuiltinPolyfill::Config>(enabled);
+ return inputs;
+}
+
+TEST_F(BuiltinPolyfillTest, ShouldRunAsinh) {  // runs only when the asinh polyfill is enabled
+ auto* src = R"(
+fn f() {
+ let v = 1.0;
+ _ = asinh(v);
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));
+ EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillSinh()));
+}
+
+TEST_F(BuiltinPolyfillTest, Asinh_ConstantExpression) {  // a literal argument is not polyfilled
+ auto* src = R"(
+fn f() {
+ let r : f32 = asinh(1.0);
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src, polyfillSinh()));
+}
+
+TEST_F(BuiltinPolyfillTest, Asinh_f32) {  // scalar: asinh(v) becomes a call to tint_sinh(v)
+ auto* src = R"(
+fn f() {
+ let v = 1.0;
+ let r : f32 = asinh(v);
+}
+)";
+
+ auto* expect = R"(
+fn tint_sinh(x : f32) -> f32 {
+ return log((x + sqrt(((x * x) + 1))));
+}
+
+fn f() {
+ let v = 1.0;
+ let r : f32 = tint_sinh(v);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillSinh());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, Asinh_vec3_f32) {  // vec3<f32>: same body, vector types
+ auto* src = R"(
+fn f() {
+ let v = 1.0;
+ let r : vec3<f32> = asinh(vec3<f32>(v));
+}
+)";
+
+ auto* expect = R"(
+fn tint_sinh(x : vec3<f32>) -> vec3<f32> {
+ return log((x + sqrt(((x * x) + 1))));
+}
+
+fn f() {
+ let v = 1.0;
+ let r : vec3<f32> = tint_sinh(vec3<f32>(v));
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillSinh());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// atanh
+////////////////////////////////////////////////////////////////////////////////
+Transform::DataMap polyfillAtanh(Level level) {
+ Transform::DataMap inputs;
+ BuiltinPolyfill::Builtins enabled;
+ enabled.atanh = level;  // only `atanh` is requested; every other field keeps its default
+ inputs.Add<BuiltinPolyfill::Config>(enabled);
+ return inputs;
+}
+
+TEST_F(BuiltinPolyfillTest, ShouldRunAtanh) {  // runs only when atanh is kRangeCheck or kFull
+ auto* src = R"(
+fn f() {
+ let v = 1.0;
+ _ = atanh(v);
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src, polyfillAtanh(Level::kNone)));
+ EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillAtanh(Level::kRangeCheck)));
+ EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillAtanh(Level::kFull)));
+}
+
+TEST_F(BuiltinPolyfillTest, Atanh_ConstantExpression) {  // atanh() called directly on a literal.
+ auto* src = R"(
+fn f() {
+ let r : f32 = atanh(0.23);
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src, polyfillAtanh(Level::kFull)));  // Expected false even at Level::kFull.
+}
+
+TEST_F(BuiltinPolyfillTest, Atanh_Full_f32) {  // Level::kFull: expects atanh() replaced by a tint_atanh() helper written with log().
+ auto* src = R"(
+fn f() {
+ let v = 1.0;
+ let r : f32 = atanh(v);
+}
+)";
+
+ auto* expect = R"(
+fn tint_atanh(x : f32) -> f32 {
+ return (log(((1 + x) / (1 - x))) * 0.5);
+}
+
+fn f() {
+ let v = 1.0;
+ let r : f32 = tint_atanh(v);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillAtanh(Level::kFull));  // Run the transform at Level::kFull.
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the expected WGSL.
+}
+
+TEST_F(BuiltinPolyfillTest, Atanh_Full_vec3_f32) {  // Level::kFull with a vec3<f32> argument; helper signature is vec3<f32>.
+ auto* src = R"(
+fn f() {
+ let v = 1.0;
+ let r : vec3<f32> = atanh(vec3<f32>(v));
+}
+)";
+
+ auto* expect = R"(
+fn tint_atanh(x : vec3<f32>) -> vec3<f32> {
+ return (log(((1 + x) / (1 - x))) * 0.5);
+}
+
+fn f() {
+ let v = 1.0;
+ let r : vec3<f32> = tint_atanh(vec3<f32>(v));
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillAtanh(Level::kFull));  // Run the transform at Level::kFull.
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the expected WGSL.
+}
+
+TEST_F(BuiltinPolyfillTest, Atanh_Range_f32) {  // Level::kRangeCheck: expects a tint_atanh() helper wrapping atanh() in select() on (x >= 1.0).
+ auto* src = R"(
+fn f() {
+ let v = 1.0;
+ let r : f32 = atanh(v);
+}
+)";
+
+ auto* expect = R"(
+fn tint_atanh(x : f32) -> f32 {
+ return select(atanh(x), 0.0, (x >= 1.0));
+}
+
+fn f() {
+ let v = 1.0;
+ let r : f32 = tint_atanh(v);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillAtanh(Level::kRangeCheck));  // Run the transform at Level::kRangeCheck.
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the expected WGSL.
+}
+
+TEST_F(BuiltinPolyfillTest, Atanh_Range_vec3_f32) {  // Level::kRangeCheck with vec3<f32>; the expected select() operands are splatted to vec3<f32>.
+ auto* src = R"(
+fn f() {
+ let v = 1.0;
+ let r : vec3<f32> = atanh(vec3<f32>(v));
+}
+)";
+
+ auto* expect = R"(
+fn tint_atanh(x : vec3<f32>) -> vec3<f32> {
+ return select(atanh(x), vec3<f32>(0.0), (x >= vec3<f32>(1.0)));
+}
+
+fn f() {
+ let v = 1.0;
+ let r : vec3<f32> = tint_atanh(vec3<f32>(v));
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillAtanh(Level::kRangeCheck));  // Run the transform at Level::kRangeCheck.
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the expected WGSL.
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// bgra8unorm
+////////////////////////////////////////////////////////////////////////////////
+Transform::DataMap polyfillBgra8unorm() {  // Returns transform data carrying a BuiltinPolyfill::Config with `bgra8unorm` enabled.
+ BuiltinPolyfill::Builtins builtins;  // All other fields keep their default-constructed values.
+ builtins.bgra8unorm = true;
+ Transform::DataMap data;
+ data.Add<BuiltinPolyfill::Config>(builtins);
+ return data;
+}
+
+TEST_F(BuiltinPolyfillTest, ShouldRunBgra8unorm_StorageTextureVar) {  // ShouldRun gating for a module-scope bgra8unorm storage texture variable.
+ auto* src = R"(
+@group(0) @binding(0) var tex : texture_storage_3d<bgra8unorm, write>;
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));  // No transform data supplied.
+ EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillBgra8unorm()));  // bgra8unorm polyfill enabled.
+}
+
+TEST_F(BuiltinPolyfillTest, ShouldRunBgra8unorm_StorageTextureParam) {  // ShouldRun gating for a bgra8unorm storage texture used as a function parameter type.
+ auto* src = R"(
+fn f(tex : texture_storage_3d<bgra8unorm, write>) {
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));  // No transform data supplied.
+ EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillBgra8unorm()));  // bgra8unorm polyfill enabled.
+}
+
+TEST_F(BuiltinPolyfillTest, Bgra8unorm_StorageTextureVar) {  // Expects the texel format of a module-scope storage texture to change from bgra8unorm to rgba8unorm.
+ auto* src = R"(
+@group(0) @binding(0) var tex : texture_storage_3d<bgra8unorm, write>;
+)";
+
+ auto* expect = R"(
+@group(0) @binding(0) var tex : texture_storage_3d<rgba8unorm, write>;
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillBgra8unorm());  // Run the transform with bgra8unorm enabled.
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the expected WGSL.
+}
+
+TEST_F(BuiltinPolyfillTest, Bgra8unorm_StorageTextureParam) {  // Expects the texel format of a parameter type to change from bgra8unorm to rgba8unorm.
+ auto* src = R"(
+fn f(tex : texture_storage_3d<bgra8unorm, write>) {
+}
+)";
+
+ auto* expect = R"(
+fn f(tex : texture_storage_3d<rgba8unorm, write>) {
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillBgra8unorm());  // Run the transform with bgra8unorm enabled.
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the expected WGSL.
+}
+
+TEST_F(BuiltinPolyfillTest, Bgra8unorm_TextureStore) {  // Expects the format rewritten to rgba8unorm and the stored value swizzled with .bgra.
+ auto* src = R"(
+@group(0) @binding(0) var tex : texture_storage_2d<bgra8unorm, write>;
+
+fn f(coords : vec2<i32>, value : vec4<f32>) {
+ textureStore(tex, coords, value);
+}
+)";
+
+ auto* expect = R"(
+@group(0) @binding(0) var tex : texture_storage_2d<rgba8unorm, write>;
+
+fn f(coords : vec2<i32>, value : vec4<f32>) {
+ textureStore(tex, coords, value.bgra);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillBgra8unorm());  // Run the transform with bgra8unorm enabled.
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the expected WGSL.
+}
+
+TEST_F(BuiltinPolyfillTest, Bgra8unorm_TextureStore_Param) {  // Same rewrite as the module-scope case, with the texture passed as a function parameter.
+ auto* src = R"(
+fn f(tex : texture_storage_2d<bgra8unorm, write>, coords : vec2<i32>, value : vec4<f32>) {
+ textureStore(tex, coords, value);
+}
+)";
+
+ auto* expect = R"(
+fn f(tex : texture_storage_2d<rgba8unorm, write>, coords : vec2<i32>, value : vec4<f32>) {
+ textureStore(tex, coords, value.bgra);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillBgra8unorm());  // Run the transform with bgra8unorm enabled.
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the expected WGSL.
+}
+
+TEST_F(BuiltinPolyfillTest, Bgra8unorm_TextureStore_WithAtanh) {  // Two polyfills in one config: expects atanh() replaced by tint_atanh() and its result swizzled with .bgra.
+ auto* src = R"(
+@group(0) @binding(0) var tex : texture_storage_2d<bgra8unorm, write>;
+
+fn f(coords : vec2<i32>, value : vec4<f32>) {
+ textureStore(tex, coords, atanh(value));
+}
+)";
+
+ auto* expect = R"(
+fn tint_atanh(x : vec4<f32>) -> vec4<f32> {
+ return (log(((1 + x) / (1 - x))) * 0.5);
+}
+
+@group(0) @binding(0) var tex : texture_storage_2d<rgba8unorm, write>;
+
+fn f(coords : vec2<i32>, value : vec4<f32>) {
+ textureStore(tex, coords, tint_atanh(value).bgra);
+}
+)";
+
+ BuiltinPolyfill::Builtins builtins;  // Built inline because no single helper here enables both options.
+ builtins.atanh = BuiltinPolyfill::Level::kFull;
+ builtins.bgra8unorm = true;
+ Transform::DataMap data;
+ data.Add<BuiltinPolyfill::Config>(builtins);
+
+ auto got = Run<BuiltinPolyfill>(src, std::move(data));  // The data map is moved into Run.
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the expected WGSL.
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// bitshiftModulo
+////////////////////////////////////////////////////////////////////////////////
+Transform::DataMap polyfillBitshiftModulo() {  // Returns transform data carrying a BuiltinPolyfill::Config with `bitshift_modulo` enabled.
+ BuiltinPolyfill::Builtins builtins;  // All other fields keep their default-constructed values.
+ builtins.bitshift_modulo = true;
+ Transform::DataMap data;
+ data.Add<BuiltinPolyfill::Config>(builtins);
+ return data;
+}
+
+TEST_F(BuiltinPolyfillTest, ShouldRunBitshiftModulo_shl_scalar) {  // ShouldRun gating for a scalar << whose shift amount is a `let` value.
+ auto* src = R"(
+fn f() {
+ let v = 15u;
+ let r = 1i << v;
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));  // No transform data supplied.
+ EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillBitshiftModulo()));  // bitshift_modulo enabled.
+}
+
+TEST_F(BuiltinPolyfillTest, ShouldRunBitshiftModulo_shl_vector) {  // ShouldRun gating for a vec3 << with a vec3 shift amount built from a `let` value.
+ auto* src = R"(
+fn f() {
+ let v = 15u;
+ let r = vec3(1i) << vec3(v);
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));  // No transform data supplied.
+ EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillBitshiftModulo()));  // bitshift_modulo enabled.
+}
+
+TEST_F(BuiltinPolyfillTest, ShouldRunBitshiftModulo_shr_scalar) {  // ShouldRun gating for a scalar >> whose shift amount is a `let` value.
+ auto* src = R"(
+fn f() {
+ let v = 15u;
+ let r = 1i >> v;
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));  // No transform data supplied.
+ EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillBitshiftModulo()));  // bitshift_modulo enabled.
+}
+
+TEST_F(BuiltinPolyfillTest, ShouldRunBitshiftModulo_shr_vector) {  // ShouldRun gating for a vec3 >> with a vec3 shift amount built from a `let` value.
+ auto* src = R"(
+fn f() {
+ let v = 15u;
+ let r = vec3(1i) >> vec3(v);
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));  // No transform data supplied.
+ EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillBitshiftModulo()));  // bitshift_modulo enabled.
+}
+
+TEST_F(BuiltinPolyfillTest, BitshiftModulo_shl_scalar) {  // Expects the scalar << shift amount to be masked with (v & 31).
+ auto* src = R"(
+fn f() {
+ let v = 15u;
+ let r = 1i << v;
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ let v = 15u;
+ let r = (1i << (v & 31));
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillBitshiftModulo());  // Run the transform with bitshift_modulo enabled.
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the expected WGSL.
+}
+
+TEST_F(BuiltinPolyfillTest, BitshiftModulo_shl_vector) {  // Expects the vec3 << shift amount to be masked with vec3<u32>(31).
+ auto* src = R"(
+fn f() {
+ let v = 15u;
+ let r = vec3(1i) << vec3(v);
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ let v = 15u;
+ let r = (vec3(1i) << (vec3(v) & vec3<u32>(31)));
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillBitshiftModulo());  // Run the transform with bitshift_modulo enabled.
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the expected WGSL.
+}
+
+TEST_F(BuiltinPolyfillTest, BitshiftModulo_shr_scalar) {  // Expects the scalar >> shift amount to be masked with (v & 31).
+ auto* src = R"(
+fn f() {
+ let v = 15u;
+ let r = 1i >> v;
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ let v = 15u;
+ let r = (1i >> (v & 31));
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillBitshiftModulo());  // Run the transform with bitshift_modulo enabled.
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the expected WGSL.
+}
+
+TEST_F(BuiltinPolyfillTest, BitshiftModulo_shr_vector) {  // Expects the vec3 >> shift amount to be masked with vec3<u32>(31).
+ auto* src = R"(
+fn f() {
+ let v = 15u;
+ let r = vec3(1i) >> vec3(v);
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ let v = 15u;
+ let r = (vec3(1i) >> (vec3(v) & vec3<u32>(31)));
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillBitshiftModulo());  // Run the transform with bitshift_modulo enabled.
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the expected WGSL.
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// clampInteger
+////////////////////////////////////////////////////////////////////////////////
+Transform::DataMap polyfillClampInteger() {  // Returns transform data carrying a BuiltinPolyfill::Config with `clamp_int` enabled.
+ BuiltinPolyfill::Builtins builtins;  // All other fields keep their default-constructed values.
+ builtins.clamp_int = true;
+ Transform::DataMap data;
+ data.Add<BuiltinPolyfill::Config>(builtins);
+ return data;
+}
+
+TEST_F(BuiltinPolyfillTest, ShouldRunClampInteger_i32) {  // ShouldRun gating for clamp() on an i32 `let` value.
+ auto* src = R"(
+fn f() {
+ let v = 1i;
+ _ = clamp(v, 2i, 3i);
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));  // No transform data supplied.
+ EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillClampInteger()));  // clamp_int enabled.
+}
+
+TEST_F(BuiltinPolyfillTest, ShouldRunClampInteger_u32) {  // ShouldRun gating for clamp() on a u32 `let` value.
+ auto* src = R"(
+fn f() {
+ let v = 1u;
+ _ = clamp(v, 2u, 3u);
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));  // No transform data supplied.
+ EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillClampInteger()));  // clamp_int enabled.
+}
+
+TEST_F(BuiltinPolyfillTest, ShouldRunClampInteger_f32) {  // clamp() on an f32 value: the integer clamp polyfill is not expected to apply.
+ auto* src = R"(
+fn f() {
+ let v = 1f;
+ _ = clamp(v, 2f, 3f);
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));  // No transform data supplied.
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src, polyfillClampInteger()));  // Still false with clamp_int enabled.
+}
+
+TEST_F(BuiltinPolyfillTest, ShouldRunClampInteger_f16) {  // clamp() on an f16 value (source enables f16): the integer clamp polyfill is not expected to apply.
+ auto* src = R"(
+enable f16;
+
+fn f() {
+ let v = 1h;
+ _ = clamp(v, 2h, 3h);
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));  // No transform data supplied.
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src, polyfillClampInteger()));  // Still false with clamp_int enabled.
+}
+
+TEST_F(BuiltinPolyfillTest, ClampInteger_ConstantExpression) {  // clamp() called with only literal arguments.
+ auto* src = R"(
+fn f() {
+ let r : i32 = clamp(1i, 2i, 3i);
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src, polyfillClampInteger()));  // Expected false even with clamp_int enabled.
+}
+
+TEST_F(BuiltinPolyfillTest, ClampInteger_i32) {  // Expects i32 clamp() replaced by a tint_clamp() helper written as min(max(e, low), high).
+ auto* src = R"(
+fn f() {
+ let v = 1i;
+ let r : i32 = clamp(v, 2i, 3i);
+}
+)";
+
+ auto* expect = R"(
+fn tint_clamp(e : i32, low : i32, high : i32) -> i32 {
+ return min(max(e, low), high);
+}
+
+fn f() {
+ let v = 1i;
+ let r : i32 = tint_clamp(v, 2i, 3i);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillClampInteger());  // Run the transform with clamp_int enabled.
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the expected WGSL.
+}
+
+TEST_F(BuiltinPolyfillTest, ClampInteger_vec3_i32) {  // vec3<i32> variant: expects a tint_clamp() helper with vec3<i32> parameters.
+ auto* src = R"(
+fn f() {
+ let v = 1i;
+ let r : vec3<i32> = clamp(vec3(v), vec3(2i), vec3(3i));
+}
+)";
+
+ auto* expect =
+ R"(
+fn tint_clamp(e : vec3<i32>, low : vec3<i32>, high : vec3<i32>) -> vec3<i32> {
+ return min(max(e, low), high);
+}
+
+fn f() {
+ let v = 1i;
+ let r : vec3<i32> = tint_clamp(vec3(v), vec3(2i), vec3(3i));
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillClampInteger());  // Run the transform with clamp_int enabled.
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the expected WGSL.
+}
+
+TEST_F(BuiltinPolyfillTest, ClampInteger_u32) {  // NOTE(review): unlike ClampInteger_i32, the input uses only literals and the expected output equals the input, so no u32 tint_clamp() helper is checked here - confirm whether a `let v = 1u;` input was intended.
+ auto* src = R"(
+fn f() {
+ let r : u32 = clamp(1u, 2u, 3u);
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ let r : u32 = clamp(1u, 2u, 3u);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillClampInteger());  // Run the transform with clamp_int enabled.
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison; expected text is identical to the source.
+}
+
+TEST_F(BuiltinPolyfillTest, ClampInteger_vec3_u32) {  // vec3<u32> variant: expects a tint_clamp() helper with vec3<u32> parameters.
+ auto* src = R"(
+fn f() {
+ let v = 1u;
+ let r : vec3<u32> = clamp(vec3(v), vec3(2u), vec3(3u));
+}
+)";
+
+ auto* expect =
+ R"(
+fn tint_clamp(e : vec3<u32>, low : vec3<u32>, high : vec3<u32>) -> vec3<u32> {
+ return min(max(e, low), high);
+}
+
+fn f() {
+ let v = 1u;
+ let r : vec3<u32> = tint_clamp(vec3(v), vec3(2u), vec3(3u));
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillClampInteger());  // Run the transform with clamp_int enabled.
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the expected WGSL.
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// conv_f32_to_iu32
+////////////////////////////////////////////////////////////////////////////////
+Transform::DataMap polyfillConvF32ToIU32() {  // Returns transform data carrying a BuiltinPolyfill::Config with `conv_f32_to_iu32` enabled.
+ BuiltinPolyfill::Builtins builtins;  // All other fields keep their default-constructed values.
+ builtins.conv_f32_to_iu32 = true;
+ Transform::DataMap data;
+ data.Add<BuiltinPolyfill::Config>(builtins);
+ return data;
+}
+
+TEST_F(BuiltinPolyfillTest, ShouldRunConvF32ToI32) {  // ShouldRun gating for an i32() conversion of a float `let` value.
+ auto* src = R"(
+fn f() {
+ let f = 42.0;
+ _ = i32(f);
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));  // No transform data supplied.
+ EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillConvF32ToIU32()));  // conv_f32_to_iu32 enabled.
+}
+
+TEST_F(BuiltinPolyfillTest, ShouldRunConvF32ToU32) {  // ShouldRun gating for a u32() conversion of a float `let` value.
+ auto* src = R"(
+fn f() {
+ let f = 42.0;
+ _ = u32(f);
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));  // No transform data supplied.
+ EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillConvF32ToIU32()));  // conv_f32_to_iu32 enabled.
+}
+
+TEST_F(BuiltinPolyfillTest, ShouldRunConvVec3F32ToVec3I32) {  // ShouldRun gating for a vec3<i32>() conversion of a float vector `let` value.
+ auto* src = R"(
+fn f() {
+ let f = vec3(42.0);
+ _ = vec3<i32>(f);
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));  // No transform data supplied.
+ EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillConvF32ToIU32()));  // conv_f32_to_iu32 enabled.
+}
+
+TEST_F(BuiltinPolyfillTest, ShouldRunConvVec3F32ToVec3U32) {  // ShouldRun gating for a vec3<u32>() conversion of a float vector `let` value.
+ auto* src = R"(
+fn f() {
+ let f = vec3(42.0);
+ _ = vec3<u32>(f);
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));  // No transform data supplied.
+ EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillConvF32ToIU32()));  // conv_f32_to_iu32 enabled.
+}
+
+TEST_F(BuiltinPolyfillTest, ConvF32ToI32) {  // Expects i32(f) replaced by a tint_ftoi() helper built from nested select() calls with fixed bounds.
+ auto* src = R"(
+fn f() {
+ let f = 42.0;
+ _ = i32(f);
+}
+)";
+ auto* expect = R"(
+fn tint_ftoi(v : f32) -> i32 {
+ return select(2147483647, select(i32(v), -2147483648, (v < -2147483648.0)), (v < 2147483520.0));
+}
+
+fn f() {
+ let f = 42.0;
+ _ = tint_ftoi(f);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillConvF32ToIU32());  // Run the transform with conv_f32_to_iu32 enabled.
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the expected WGSL.
+}
+
+TEST_F(BuiltinPolyfillTest, ConvF32ToU32) {  // Expects u32(f) replaced by a tint_ftou() helper built from nested select() calls with fixed bounds.
+ auto* src = R"(
+fn f() {
+ let f = 42.0;
+ _ = u32(f);
+}
+)";
+ auto* expect = R"(
+fn tint_ftou(v : f32) -> u32 {
+ return select(4294967295, select(u32(v), 0, (v < 0.0)), (v < 4294967040.0));
+}
+
+fn f() {
+ let f = 42.0;
+ _ = tint_ftou(f);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillConvF32ToIU32());  // Run the transform with conv_f32_to_iu32 enabled.
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the expected WGSL.
+}
+
+TEST_F(BuiltinPolyfillTest, ConvVec3F32ToVec3I32) {  // Vector variant of ConvF32ToI32: the expected helper splats each bound with vec3().
+ auto* src = R"(
+fn f() {
+ let f = vec3(42.0);
+ _ = vec3<i32>(f);
+}
+)";
+ auto* expect = R"(
+fn tint_ftoi(v : vec3<f32>) -> vec3<i32> {
+ return select(vec3(2147483647), select(vec3<i32>(v), vec3(-2147483648), (v < vec3(-2147483648.0))), (v < vec3(2147483520.0)));
+}
+
+fn f() {
+ let f = vec3(42.0);
+ _ = tint_ftoi(f);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillConvF32ToIU32());  // Run the transform with conv_f32_to_iu32 enabled.
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the expected WGSL.
+}
+
+TEST_F(BuiltinPolyfillTest, ConvVec3F32ToVec3U32) {  // Vector variant of ConvF32ToU32: the expected helper splats each bound with vec3().
+ auto* src = R"(
+fn f() {
+ let f = vec3(42.0);
+ _ = vec3<u32>(f);
+}
+)";
+ auto* expect = R"(
+fn tint_ftou(v : vec3<f32>) -> vec3<u32> {
+ return select(vec3(4294967295), select(vec3<u32>(v), vec3(0), (v < vec3(0.0))), (v < vec3(4294967040.0)));
+}
+
+fn f() {
+ let f = vec3(42.0);
+ _ = tint_ftou(f);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillConvF32ToIU32());  // Run the transform with conv_f32_to_iu32 enabled.
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the expected WGSL.
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// countLeadingZeros
+////////////////////////////////////////////////////////////////////////////////
+Transform::DataMap polyfillCountLeadingZeros() {  // Returns transform data carrying a BuiltinPolyfill::Config with `count_leading_zeros` enabled.
+ BuiltinPolyfill::Builtins builtins;  // All other fields keep their default-constructed values.
+ builtins.count_leading_zeros = true;
+ Transform::DataMap data;
+ data.Add<BuiltinPolyfill::Config>(builtins);
+ return data;
+}
+
+TEST_F(BuiltinPolyfillTest, ShouldRunCountLeadingZeros) {  // ShouldRun gating for countLeadingZeros() on a `let` value.
+ auto* src = R"(
+fn f() {
+ let v = 15;
+ _ = countLeadingZeros(v);
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));  // No transform data supplied.
+ EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillCountLeadingZeros()));  // count_leading_zeros enabled.
+}
+
+TEST_F(BuiltinPolyfillTest, CountLeadingZeros_ConstantExpression) {  // countLeadingZeros() called directly on a literal.
+ auto* src = R"(
+fn f() {
+ let r : i32 = countLeadingZeros(15i);
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src, polyfillCountLeadingZeros()));  // Expected false even with count_leading_zeros enabled.
+}
+
+TEST_F(BuiltinPolyfillTest, CountLeadingZeros_i32) {  // Expects i32 countLeadingZeros() replaced by a tint_count_leading_zeros() helper using select() and left shifts in steps of 16, 8, 4, 2, 1.
+ auto* src = R"(
+fn f() {
+ let v = 15i;
+ let r : i32 = countLeadingZeros(v);
+}
+)";
+
+ auto* expect = R"(
+fn tint_count_leading_zeros(v : i32) -> i32 {
+ var x = u32(v);
+ let b16 = select(0u, 16u, (x <= 65535u));
+ x = (x << b16);
+ let b8 = select(0u, 8u, (x <= 16777215u));
+ x = (x << b8);
+ let b4 = select(0u, 4u, (x <= 268435455u));
+ x = (x << b4);
+ let b2 = select(0u, 2u, (x <= 1073741823u));
+ x = (x << b2);
+ let b1 = select(0u, 1u, (x <= 2147483647u));
+ let is_zero = select(0u, 1u, (x == 0u));
+ return i32((((((b16 | b8) | b4) | b2) | b1) + is_zero));
+}
+
+fn f() {
+ let v = 15i;
+ let r : i32 = tint_count_leading_zeros(v);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillCountLeadingZeros());  // Run the transform with count_leading_zeros enabled.
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the expected WGSL.
+}
+
+TEST_F(BuiltinPolyfillTest, CountLeadingZeros_u32) {  // u32 variant: same expected helper body, with u32 parameter and return type.
+ auto* src = R"(
+fn f() {
+ let v = 15u;
+ let r : u32 = countLeadingZeros(v);
+}
+)";
+
+ auto* expect = R"(
+fn tint_count_leading_zeros(v : u32) -> u32 {
+ var x = u32(v);
+ let b16 = select(0u, 16u, (x <= 65535u));
+ x = (x << b16);
+ let b8 = select(0u, 8u, (x <= 16777215u));
+ x = (x << b8);
+ let b4 = select(0u, 4u, (x <= 268435455u));
+ x = (x << b4);
+ let b2 = select(0u, 2u, (x <= 1073741823u));
+ x = (x << b2);
+ let b1 = select(0u, 1u, (x <= 2147483647u));
+ let is_zero = select(0u, 1u, (x == 0u));
+ return u32((((((b16 | b8) | b4) | b2) | b1) + is_zero));
+}
+
+fn f() {
+ let v = 15u;
+ let r : u32 = tint_count_leading_zeros(v);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillCountLeadingZeros());  // Run the transform with count_leading_zeros enabled.
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the expected WGSL.
+}
+
+TEST_F(BuiltinPolyfillTest, CountLeadingZeros_vec3_i32) {  // vec3<i32> variant: every constant in the expected helper is splatted to vec3<u32>.
+ auto* src = R"(
+fn f() {
+ let v = 15i;
+ let r : vec3<i32> = countLeadingZeros(vec3<i32>(v));
+}
+)";
+
+ auto* expect = R"(
+fn tint_count_leading_zeros(v : vec3<i32>) -> vec3<i32> {
+ var x = vec3<u32>(v);
+ let b16 = select(vec3<u32>(0u), vec3<u32>(16u), (x <= vec3<u32>(65535u)));
+ x = (x << b16);
+ let b8 = select(vec3<u32>(0u), vec3<u32>(8u), (x <= vec3<u32>(16777215u)));
+ x = (x << b8);
+ let b4 = select(vec3<u32>(0u), vec3<u32>(4u), (x <= vec3<u32>(268435455u)));
+ x = (x << b4);
+ let b2 = select(vec3<u32>(0u), vec3<u32>(2u), (x <= vec3<u32>(1073741823u)));
+ x = (x << b2);
+ let b1 = select(vec3<u32>(0u), vec3<u32>(1u), (x <= vec3<u32>(2147483647u)));
+ let is_zero = select(vec3<u32>(0u), vec3<u32>(1u), (x == vec3<u32>(0u)));
+ return vec3<i32>((((((b16 | b8) | b4) | b2) | b1) + is_zero));
+}
+
+fn f() {
+ let v = 15i;
+ let r : vec3<i32> = tint_count_leading_zeros(vec3<i32>(v));
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillCountLeadingZeros());  // Run the transform with count_leading_zeros enabled.
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the expected WGSL.
+}
+
+TEST_F(BuiltinPolyfillTest, CountLeadingZeros_vec3_u32) {  // vec3<u32> variant: same expected helper body, with vec3<u32> parameter and return type.
+ auto* src = R"(
+fn f() {
+ let v = 15u;
+ let r : vec3<u32> = countLeadingZeros(vec3<u32>(v));
+}
+)";
+
+ auto* expect = R"(
+fn tint_count_leading_zeros(v : vec3<u32>) -> vec3<u32> {
+ var x = vec3<u32>(v);
+ let b16 = select(vec3<u32>(0u), vec3<u32>(16u), (x <= vec3<u32>(65535u)));
+ x = (x << b16);
+ let b8 = select(vec3<u32>(0u), vec3<u32>(8u), (x <= vec3<u32>(16777215u)));
+ x = (x << b8);
+ let b4 = select(vec3<u32>(0u), vec3<u32>(4u), (x <= vec3<u32>(268435455u)));
+ x = (x << b4);
+ let b2 = select(vec3<u32>(0u), vec3<u32>(2u), (x <= vec3<u32>(1073741823u)));
+ x = (x << b2);
+ let b1 = select(vec3<u32>(0u), vec3<u32>(1u), (x <= vec3<u32>(2147483647u)));
+ let is_zero = select(vec3<u32>(0u), vec3<u32>(1u), (x == vec3<u32>(0u)));
+ return vec3<u32>((((((b16 | b8) | b4) | b2) | b1) + is_zero));
+}
+
+fn f() {
+ let v = 15u;
+ let r : vec3<u32> = tint_count_leading_zeros(vec3<u32>(v));
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillCountLeadingZeros());  // Run the transform with count_leading_zeros enabled.
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the expected WGSL.
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// countTrailingZeros
+////////////////////////////////////////////////////////////////////////////////
+Transform::DataMap polyfillCountTrailingZeros() {  // Returns transform data carrying a BuiltinPolyfill::Config with `count_trailing_zeros` enabled.
+ BuiltinPolyfill::Builtins builtins;  // All other fields keep their default-constructed values.
+ builtins.count_trailing_zeros = true;
+ Transform::DataMap data;
+ data.Add<BuiltinPolyfill::Config>(builtins);
+ return data;
+}
+
+TEST_F(BuiltinPolyfillTest, ShouldRunCountTrailingZeros) {  // ShouldRun gating for countTrailingZeros() on a `let` value.
+ auto* src = R"(
+fn f() {
+ let v = 15;
+ _ = countTrailingZeros(v);
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));  // No transform data supplied.
+ EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillCountTrailingZeros()));  // count_trailing_zeros enabled.
+}
+
+TEST_F(BuiltinPolyfillTest, CountTrailingZeros_ConstantExpression) {  // countTrailingZeros() called directly on a literal.
+ auto* src = R"(
+fn f() {
+ let r : i32 = countTrailingZeros(15i);
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src, polyfillCountTrailingZeros()));  // Expected false even with count_trailing_zeros enabled.
+}
+
+TEST_F(BuiltinPolyfillTest, CountTrailingZeros_i32) {  // Expects i32 countTrailingZeros() replaced by a tint_count_trailing_zeros() helper using bit masks and right shifts in steps of 16, 8, 4, 2, 1.
+ auto* src = R"(
+fn f() {
+ let v = 15i;
+ let r : i32 = countTrailingZeros(v);
+}
+)";
+
+ auto* expect = R"(
+fn tint_count_trailing_zeros(v : i32) -> i32 {
+ var x = u32(v);
+ let b16 = select(16u, 0u, bool((x & 65535u)));
+ x = (x >> b16);
+ let b8 = select(8u, 0u, bool((x & 255u)));
+ x = (x >> b8);
+ let b4 = select(4u, 0u, bool((x & 15u)));
+ x = (x >> b4);
+ let b2 = select(2u, 0u, bool((x & 3u)));
+ x = (x >> b2);
+ let b1 = select(1u, 0u, bool((x & 1u)));
+ let is_zero = select(0u, 1u, (x == 0u));
+ return i32((((((b16 | b8) | b4) | b2) | b1) + is_zero));
+}
+
+fn f() {
+ let v = 15i;
+ let r : i32 = tint_count_trailing_zeros(v);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillCountTrailingZeros());  // Run the transform with count_trailing_zeros enabled.
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the expected WGSL.
+}
+
+TEST_F(BuiltinPolyfillTest, CountTrailingZeros_u32) {  // u32 variant: same expected helper body, with u32 parameter and return type.
+ auto* src = R"(
+fn f() {
+ let v = 15u;
+ let r : u32 = countTrailingZeros(v);
+}
+)";
+
+ auto* expect = R"(
+fn tint_count_trailing_zeros(v : u32) -> u32 {
+ var x = u32(v);
+ let b16 = select(16u, 0u, bool((x & 65535u)));
+ x = (x >> b16);
+ let b8 = select(8u, 0u, bool((x & 255u)));
+ x = (x >> b8);
+ let b4 = select(4u, 0u, bool((x & 15u)));
+ x = (x >> b4);
+ let b2 = select(2u, 0u, bool((x & 3u)));
+ x = (x >> b2);
+ let b1 = select(1u, 0u, bool((x & 1u)));
+ let is_zero = select(0u, 1u, (x == 0u));
+ return u32((((((b16 | b8) | b4) | b2) | b1) + is_zero));
+}
+
+fn f() {
+ let v = 15u;
+ let r : u32 = tint_count_trailing_zeros(v);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillCountTrailingZeros());  // Run the transform with count_trailing_zeros enabled.
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the expected WGSL.
+}
+
+TEST_F(BuiltinPolyfillTest, CountTrailingZeros_vec3_i32) {  // vec3<i32> variant: constants are splatted to vec3<u32> and tests use vec3<bool>().
+ auto* src = R"(
+fn f() {
+ let v = 15i;
+ let r : vec3<i32> = countTrailingZeros(vec3<i32>(v));
+}
+)";
+
+ auto* expect = R"(
+fn tint_count_trailing_zeros(v : vec3<i32>) -> vec3<i32> {
+ var x = vec3<u32>(v);
+ let b16 = select(vec3<u32>(16u), vec3<u32>(0u), vec3<bool>((x & vec3<u32>(65535u))));
+ x = (x >> b16);
+ let b8 = select(vec3<u32>(8u), vec3<u32>(0u), vec3<bool>((x & vec3<u32>(255u))));
+ x = (x >> b8);
+ let b4 = select(vec3<u32>(4u), vec3<u32>(0u), vec3<bool>((x & vec3<u32>(15u))));
+ x = (x >> b4);
+ let b2 = select(vec3<u32>(2u), vec3<u32>(0u), vec3<bool>((x & vec3<u32>(3u))));
+ x = (x >> b2);
+ let b1 = select(vec3<u32>(1u), vec3<u32>(0u), vec3<bool>((x & vec3<u32>(1u))));
+ let is_zero = select(vec3<u32>(0u), vec3<u32>(1u), (x == vec3<u32>(0u)));
+ return vec3<i32>((((((b16 | b8) | b4) | b2) | b1) + is_zero));
+}
+
+fn f() {
+ let v = 15i;
+ let r : vec3<i32> = tint_count_trailing_zeros(vec3<i32>(v));
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillCountTrailingZeros());  // Run the transform with count_trailing_zeros enabled.
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the expected WGSL.
+}
+
+TEST_F(BuiltinPolyfillTest, CountTrailingZeros_vec3_u32) {  // vec3<u32> variant: same expected helper body, with vec3<u32> parameter and return type.
+ auto* src = R"(
+fn f() {
+ let v = 15u;
+ let r : vec3<u32> = countTrailingZeros(vec3<u32>(v));
+}
+)";
+
+ auto* expect = R"(
+fn tint_count_trailing_zeros(v : vec3<u32>) -> vec3<u32> {
+ var x = vec3<u32>(v);
+ let b16 = select(vec3<u32>(16u), vec3<u32>(0u), vec3<bool>((x & vec3<u32>(65535u))));
+ x = (x >> b16);
+ let b8 = select(vec3<u32>(8u), vec3<u32>(0u), vec3<bool>((x & vec3<u32>(255u))));
+ x = (x >> b8);
+ let b4 = select(vec3<u32>(4u), vec3<u32>(0u), vec3<bool>((x & vec3<u32>(15u))));
+ x = (x >> b4);
+ let b2 = select(vec3<u32>(2u), vec3<u32>(0u), vec3<bool>((x & vec3<u32>(3u))));
+ x = (x >> b2);
+ let b1 = select(vec3<u32>(1u), vec3<u32>(0u), vec3<bool>((x & vec3<u32>(1u))));
+ let is_zero = select(vec3<u32>(0u), vec3<u32>(1u), (x == vec3<u32>(0u)));
+ return vec3<u32>((((((b16 | b8) | b4) | b2) | b1) + is_zero));
+}
+
+fn f() {
+ let v = 15u;
+ let r : vec3<u32> = tint_count_trailing_zeros(vec3<u32>(v));
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillCountTrailingZeros());  // Run the transform with count_trailing_zeros enabled.
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the expected WGSL.
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// extractBits
+////////////////////////////////////////////////////////////////////////////////
+Transform::DataMap polyfillExtractBits(Level level) {  // Returns transform data carrying a BuiltinPolyfill::Config with `extract_bits` set to `level`.
+ BuiltinPolyfill::Builtins builtins;  // All other fields keep their default-constructed values.
+ builtins.extract_bits = level;
+ Transform::DataMap data;
+ data.Add<BuiltinPolyfill::Config>(builtins);
+ return data;
+}
+
+TEST_F(BuiltinPolyfillTest, ShouldRunExtractBits) {  // ShouldRun gating for extractBits() on a `let` value, per polyfill level.
+ auto* src = R"(
+fn f() {
+ let v = 1234i;
+ _ = extractBits(v, 5u, 6u);
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));  // No transform data supplied.
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src, polyfillExtractBits(Level::kNone)));  // Level kNone: not expected to run.
+ EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillExtractBits(Level::kClampParameters)));  // Level kClampParameters: expected to run.
+ EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillExtractBits(Level::kFull)));  // Level kFull: expected to run.
+}
+
+TEST_F(BuiltinPolyfillTest, ExtractBits_ConstantExpression) {  // extractBits() called with only literal arguments: ShouldRun must be false even at Level::kFull.
+ auto* src = R"(
+fn f() {
+ let r : i32 = extractBits(1234i, 5u, 6u);
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src, polyfillExtractBits(Level::kFull)));  // The source must call extractBits itself; a different builtin would make this check pass vacuously.
+}
+
+TEST_F(BuiltinPolyfillTest, ExtractBits_Full_i32) {  // Level::kFull: expects extractBits() replaced by a tint_extract_bits() helper implemented with shifts and select().
+ auto* src = R"(
+fn f() {
+ let v = 1234i;
+ let r : i32 = extractBits(v, 5u, 6u);
+}
+)";
+
+ auto* expect = R"(
+fn tint_extract_bits(v : i32, offset : u32, count : u32) -> i32 {
+ let s = min(offset, 32u);
+ let e = min(32u, (s + count));
+ let shl = (32u - e);
+ let shr = (shl + s);
+ let shl_result = select(i32(), (v << shl), (shl < 32u));
+ return select(((shl_result >> 31u) >> 1u), (shl_result >> shr), (shr < 32u));
+}
+
+fn f() {
+ let v = 1234i;
+ let r : i32 = tint_extract_bits(v, 5u, 6u);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillExtractBits(Level::kFull));  // Run the transform at Level::kFull.
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the expected WGSL.
+}
+
+TEST_F(BuiltinPolyfillTest, ExtractBits_Full_u32) {  // Level::kFull, u32 variant: same expected helper body with u32 value and return type.
+ auto* src = R"(
+fn f() {
+ let v = 1234u;
+ let r : u32 = extractBits(v, 5u, 6u);
+}
+)";
+
+ auto* expect = R"(
+fn tint_extract_bits(v : u32, offset : u32, count : u32) -> u32 {
+ let s = min(offset, 32u);
+ let e = min(32u, (s + count));
+ let shl = (32u - e);
+ let shr = (shl + s);
+ let shl_result = select(u32(), (v << shl), (shl < 32u));
+ return select(((shl_result >> 31u) >> 1u), (shl_result >> shr), (shr < 32u));
+}
+
+fn f() {
+ let v = 1234u;
+ let r : u32 = tint_extract_bits(v, 5u, 6u);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillExtractBits(Level::kFull));  // Run the transform at Level::kFull.
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the expected WGSL.
+}
+
+TEST_F(BuiltinPolyfillTest, ExtractBits_Full_vec3_i32) {  // Level::kFull, vec3<i32> variant: offset and count stay scalar u32; shift amounts are splatted to vec3<u32>.
+ auto* src = R"(
+fn f() {
+ let v = 1234i;
+ let r : vec3<i32> = extractBits(vec3<i32>(v), 5u, 6u);
+}
+)";
+
+ auto* expect = R"(
+fn tint_extract_bits(v : vec3<i32>, offset : u32, count : u32) -> vec3<i32> {
+ let s = min(offset, 32u);
+ let e = min(32u, (s + count));
+ let shl = (32u - e);
+ let shr = (shl + s);
+ let shl_result = select(vec3<i32>(), (v << vec3<u32>(shl)), (shl < 32u));
+ return select(((shl_result >> vec3<u32>(31u)) >> vec3<u32>(1u)), (shl_result >> vec3<u32>(shr)), (shr < 32u));
+}
+
+fn f() {
+ let v = 1234i;
+ let r : vec3<i32> = tint_extract_bits(vec3<i32>(v), 5u, 6u);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillExtractBits(Level::kFull));  // Run the transform at Level::kFull.
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the expected WGSL.
+}
+
+TEST_F(BuiltinPolyfillTest, ExtractBits_Full_vec3_u32) {  // Level::kFull, vec3<u32> variant: same expected helper body with vec3<u32> value and return type.
+ auto* src = R"(
+fn f() {
+ let v = 1234u;
+ let r : vec3<u32> = extractBits(vec3<u32>(v), 5u, 6u);
+}
+)";
+
+ auto* expect = R"(
+fn tint_extract_bits(v : vec3<u32>, offset : u32, count : u32) -> vec3<u32> {
+ let s = min(offset, 32u);
+ let e = min(32u, (s + count));
+ let shl = (32u - e);
+ let shr = (shl + s);
+ let shl_result = select(vec3<u32>(), (v << vec3<u32>(shl)), (shl < 32u));
+ return select(((shl_result >> vec3<u32>(31u)) >> vec3<u32>(1u)), (shl_result >> vec3<u32>(shr)), (shr < 32u));
+}
+
+fn f() {
+ let v = 1234u;
+ let r : vec3<u32> = tint_extract_bits(vec3<u32>(v), 5u, 6u);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillExtractBits(Level::kFull));  // Run the transform at Level::kFull.
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the expected WGSL.
+}
+
+TEST_F(BuiltinPolyfillTest, ExtractBits_Clamp_i32) {  // Level::kClampParameters: expects a tint_extract_bits() helper that clamps with min() and then calls extractBits().
+ auto* src = R"(
+fn f() {
+ let v = 1234i;
+ let r : i32 = extractBits(v, 5u, 6u);
+}
+)";
+
+ auto* expect = R"(
+fn tint_extract_bits(v : i32, offset : u32, count : u32) -> i32 {
+ let s = min(offset, 32u);
+ let e = min(32u, (s + count));
+ return extractBits(v, s, (e - s));
+}
+
+fn f() {
+ let v = 1234i;
+ let r : i32 = tint_extract_bits(v, 5u, 6u);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillExtractBits(Level::kClampParameters));  // Run the transform at Level::kClampParameters.
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the expected WGSL.
+}
+
+TEST_F(BuiltinPolyfillTest, ExtractBits_Clamp_u32) {  // Level::kClampParameters, u32 variant: same expected helper body with u32 value and return type.
+ auto* src = R"(
+fn f() {
+ let v = 1234u;
+ let r : u32 = extractBits(v, 5u, 6u);
+}
+)";
+
+ auto* expect = R"(
+fn tint_extract_bits(v : u32, offset : u32, count : u32) -> u32 {
+ let s = min(offset, 32u);
+ let e = min(32u, (s + count));
+ return extractBits(v, s, (e - s));
+}
+
+fn f() {
+ let v = 1234u;
+ let r : u32 = tint_extract_bits(v, 5u, 6u);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillExtractBits(Level::kClampParameters));  // Run the transform at Level::kClampParameters.
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the expected WGSL.
+}
+
+TEST_F(BuiltinPolyfillTest, ExtractBits_Clamp_vec3_i32) {  // Level::kClampParameters, vec3<i32> variant: offset and count remain scalar u32.
+ auto* src = R"(
+fn f() {
+ let v = 1234i;
+ let r : vec3<i32> = extractBits(vec3<i32>(v), 5u, 6u);
+}
+)";
+
+ auto* expect = R"(
+fn tint_extract_bits(v : vec3<i32>, offset : u32, count : u32) -> vec3<i32> {
+ let s = min(offset, 32u);
+ let e = min(32u, (s + count));
+ return extractBits(v, s, (e - s));
+}
+
+fn f() {
+ let v = 1234i;
+ let r : vec3<i32> = tint_extract_bits(vec3<i32>(v), 5u, 6u);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillExtractBits(Level::kClampParameters));  // Run the transform at Level::kClampParameters.
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the expected WGSL.
+}
+
+TEST_F(BuiltinPolyfillTest, ExtractBits_Clamp_vec3_u32) {  // Level::kClampParameters, vec3<u32> variant: same expected helper body with vec3<u32> value and return type.
+ auto* src = R"(
+fn f() {
+ let v = 1234u;
+ let r : vec3<u32> = extractBits(vec3<u32>(v), 5u, 6u);
+}
+)";
+
+ auto* expect = R"(
+fn tint_extract_bits(v : vec3<u32>, offset : u32, count : u32) -> vec3<u32> {
+ let s = min(offset, 32u);
+ let e = min(32u, (s + count));
+ return extractBits(v, s, (e - s));
+}
+
+fn f() {
+ let v = 1234u;
+ let r : vec3<u32> = tint_extract_bits(vec3<u32>(v), 5u, 6u);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillExtractBits(Level::kClampParameters));  // Run the transform at Level::kClampParameters.
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the expected WGSL.
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// firstLeadingBit
+////////////////////////////////////////////////////////////////////////////////
+Transform::DataMap polyfillFirstLeadingBit() {  // Returns transform data carrying a BuiltinPolyfill::Config with `first_leading_bit` enabled.
+ BuiltinPolyfill::Builtins builtins;  // All other fields keep their default-constructed values.
+ builtins.first_leading_bit = true;
+ Transform::DataMap data;
+ data.Add<BuiltinPolyfill::Config>(builtins);
+ return data;
+}
+
+TEST_F(BuiltinPolyfillTest, ShouldRunFirstLeadingBit) {  // ShouldRun gating for firstLeadingBit() on a `let` value.
+ auto* src = R"(
+fn f() {
+ let v = 15i;
+ _ = firstLeadingBit(v);
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));  // No transform data supplied.
+ EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillFirstLeadingBit()));  // first_leading_bit enabled.
+}
+
+TEST_F(BuiltinPolyfillTest, FirstLeadingBit_ConstantExpression) {  // firstLeadingBit() called directly on a literal.
+ auto* src = R"(
+fn f() {
+ let r : i32 = firstLeadingBit(15i);
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src, polyfillFirstLeadingBit()));  // Expected false even with first_leading_bit enabled.
+}
+
+TEST_F(BuiltinPolyfillTest, FirstLeadingBit_i32) {  // Expects i32 firstLeadingBit() replaced by a tint_first_leading_bit() helper; the signed variant first selects between u32(v) and u32(~(v)) on (v < 0i).
+ auto* src = R"(
+fn f() {
+ let v = 15i;
+ let r : i32 = firstLeadingBit(v);
+}
+)";
+
+ auto* expect = R"(
+fn tint_first_leading_bit(v : i32) -> i32 {
+ var x = select(u32(v), u32(~(v)), (v < 0i));
+ let b16 = select(0u, 16u, bool((x & 4294901760u)));
+ x = (x >> b16);
+ let b8 = select(0u, 8u, bool((x & 65280u)));
+ x = (x >> b8);
+ let b4 = select(0u, 4u, bool((x & 240u)));
+ x = (x >> b4);
+ let b2 = select(0u, 2u, bool((x & 12u)));
+ x = (x >> b2);
+ let b1 = select(0u, 1u, bool((x & 2u)));
+ let is_zero = select(0u, 4294967295u, (x == 0u));
+ return i32((((((b16 | b8) | b4) | b2) | b1) | is_zero));
+}
+
+fn f() {
+ let v = 15i;
+ let r : i32 = tint_first_leading_bit(v);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillFirstLeadingBit());  // Run the transform with first_leading_bit enabled.
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the expected WGSL.
+}
+
+TEST_F(BuiltinPolyfillTest, FirstLeadingBit_u32) {
+ auto* src = R"(
+fn f() {
+ let v = 15u;
+ let r : u32 = firstLeadingBit(v);
+}
+)";
+ // Expected: the call goes to tint_first_leading_bit(u32); x starts as v with no sign handling.
+ auto* expect = R"(
+fn tint_first_leading_bit(v : u32) -> u32 {
+ var x = v;
+ let b16 = select(0u, 16u, bool((x & 4294901760u)));
+ x = (x >> b16);
+ let b8 = select(0u, 8u, bool((x & 65280u)));
+ x = (x >> b8);
+ let b4 = select(0u, 4u, bool((x & 240u)));
+ x = (x >> b4);
+ let b2 = select(0u, 2u, bool((x & 12u)));
+ x = (x >> b2);
+ let b1 = select(0u, 1u, bool((x & 2u)));
+ let is_zero = select(0u, 4294967295u, (x == 0u));
+ return u32((((((b16 | b8) | b4) | b2) | b1) | is_zero));
+}
+
+fn f() {
+ let v = 15u;
+ let r : u32 = tint_first_leading_bit(v);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillFirstLeadingBit());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, FirstLeadingBit_vec3_i32) {
+ auto* src = R"(
+fn f() {
+ let v = 15i;
+ let r : vec3<i32> = firstLeadingBit(vec3<i32>(v));
+}
+)";
+ // Expected: a vec3<i32> tint_first_leading_bit() whose constants are wrapped in vec3<...>(...).
+ auto* expect = R"(
+fn tint_first_leading_bit(v : vec3<i32>) -> vec3<i32> {
+ var x = select(vec3<u32>(v), vec3<u32>(~(v)), (v < vec3<i32>(0i)));
+ let b16 = select(vec3<u32>(0u), vec3<u32>(16u), vec3<bool>((x & vec3<u32>(4294901760u))));
+ x = (x >> b16);
+ let b8 = select(vec3<u32>(0u), vec3<u32>(8u), vec3<bool>((x & vec3<u32>(65280u))));
+ x = (x >> b8);
+ let b4 = select(vec3<u32>(0u), vec3<u32>(4u), vec3<bool>((x & vec3<u32>(240u))));
+ x = (x >> b4);
+ let b2 = select(vec3<u32>(0u), vec3<u32>(2u), vec3<bool>((x & vec3<u32>(12u))));
+ x = (x >> b2);
+ let b1 = select(vec3<u32>(0u), vec3<u32>(1u), vec3<bool>((x & vec3<u32>(2u))));
+ let is_zero = select(vec3<u32>(0u), vec3<u32>(4294967295u), (x == vec3<u32>(0u)));
+ return vec3<i32>((((((b16 | b8) | b4) | b2) | b1) | is_zero));
+}
+
+fn f() {
+ let v = 15i;
+ let r : vec3<i32> = tint_first_leading_bit(vec3<i32>(v));
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillFirstLeadingBit());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, FirstLeadingBit_vec3_u32) {
+ auto* src = R"(
+fn f() {
+ let v = 15u;
+ let r : vec3<u32> = firstLeadingBit(vec3<u32>(v));
+}
+)";
+ // Expected: a vec3<u32> tint_first_leading_bit(); x starts as v and constants are vec3<u32>.
+ auto* expect = R"(
+fn tint_first_leading_bit(v : vec3<u32>) -> vec3<u32> {
+ var x = v;
+ let b16 = select(vec3<u32>(0u), vec3<u32>(16u), vec3<bool>((x & vec3<u32>(4294901760u))));
+ x = (x >> b16);
+ let b8 = select(vec3<u32>(0u), vec3<u32>(8u), vec3<bool>((x & vec3<u32>(65280u))));
+ x = (x >> b8);
+ let b4 = select(vec3<u32>(0u), vec3<u32>(4u), vec3<bool>((x & vec3<u32>(240u))));
+ x = (x >> b4);
+ let b2 = select(vec3<u32>(0u), vec3<u32>(2u), vec3<bool>((x & vec3<u32>(12u))));
+ x = (x >> b2);
+ let b1 = select(vec3<u32>(0u), vec3<u32>(1u), vec3<bool>((x & vec3<u32>(2u))));
+ let is_zero = select(vec3<u32>(0u), vec3<u32>(4294967295u), (x == vec3<u32>(0u)));
+ return vec3<u32>((((((b16 | b8) | b4) | b2) | b1) | is_zero));
+}
+
+fn f() {
+ let v = 15u;
+ let r : vec3<u32> = tint_first_leading_bit(vec3<u32>(v));
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillFirstLeadingBit());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// firstTrailingBit
+////////////////////////////////////////////////////////////////////////////////
+Transform::DataMap polyfillFirstTrailingBit() {  // DataMap whose Config sets first_trailing_bit
+ Transform::DataMap config_data;
+ BuiltinPolyfill::Builtins polyfills;
+ polyfills.first_trailing_bit = true;
+ config_data.Add<BuiltinPolyfill::Config>(polyfills);
+ return config_data;
+}
+
+TEST_F(BuiltinPolyfillTest, ShouldRunFirstTrailingBit) {
+ auto* src = R"(
+fn f() {
+ let v = 15i;
+ _ = firstTrailingBit(v);
+}
+)";
+ // With no config there is nothing to do; with the first_trailing_bit config the transform must run.
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));
+ EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillFirstTrailingBit()));
+}
+
+TEST_F(BuiltinPolyfillTest, FirstTrailingBit_ConstantExpression) {
+ auto* src = R"(
+fn f() {
+ let r : i32 = firstTrailingBit(15i);
+}
+)";
+ // The builtin is called on a literal; ShouldRun must stay false even with the polyfill config.
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src, polyfillFirstTrailingBit()));
+}
+
+TEST_F(BuiltinPolyfillTest, FirstTrailingBit_i32) {
+ auto* src = R"(
+fn f() {
+ let v = 15i;
+ let r : i32 = firstTrailingBit(v);
+}
+)";
+ // Expected: the call goes to tint_first_trailing_bit(i32), which works on u32(v).
+ auto* expect = R"(
+fn tint_first_trailing_bit(v : i32) -> i32 {
+ var x = u32(v);
+ let b16 = select(16u, 0u, bool((x & 65535u)));
+ x = (x >> b16);
+ let b8 = select(8u, 0u, bool((x & 255u)));
+ x = (x >> b8);
+ let b4 = select(4u, 0u, bool((x & 15u)));
+ x = (x >> b4);
+ let b2 = select(2u, 0u, bool((x & 3u)));
+ x = (x >> b2);
+ let b1 = select(1u, 0u, bool((x & 1u)));
+ let is_zero = select(0u, 4294967295u, (x == 0u));
+ return i32((((((b16 | b8) | b4) | b2) | b1) | is_zero));
+}
+
+fn f() {
+ let v = 15i;
+ let r : i32 = tint_first_trailing_bit(v);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillFirstTrailingBit());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, FirstTrailingBit_u32) {
+ auto* src = R"(
+fn f() {
+ let v = 15u;
+ let r : u32 = firstTrailingBit(v);
+}
+)";
+ // Expected: the call goes to tint_first_trailing_bit(u32); the helper still emits u32(v).
+ auto* expect = R"(
+fn tint_first_trailing_bit(v : u32) -> u32 {
+ var x = u32(v);
+ let b16 = select(16u, 0u, bool((x & 65535u)));
+ x = (x >> b16);
+ let b8 = select(8u, 0u, bool((x & 255u)));
+ x = (x >> b8);
+ let b4 = select(4u, 0u, bool((x & 15u)));
+ x = (x >> b4);
+ let b2 = select(2u, 0u, bool((x & 3u)));
+ x = (x >> b2);
+ let b1 = select(1u, 0u, bool((x & 1u)));
+ let is_zero = select(0u, 4294967295u, (x == 0u));
+ return u32((((((b16 | b8) | b4) | b2) | b1) | is_zero));
+}
+
+fn f() {
+ let v = 15u;
+ let r : u32 = tint_first_trailing_bit(v);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillFirstTrailingBit());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, FirstTrailingBit_vec3_i32) {
+ auto* src = R"(
+fn f() {
+ let v = 15i;
+ let r : vec3<i32> = firstTrailingBit(vec3<i32>(v));
+}
+)";
+ // Expected: a vec3<i32> tint_first_trailing_bit() working on vec3<u32>(v) with vec3 constants.
+ auto* expect = R"(
+fn tint_first_trailing_bit(v : vec3<i32>) -> vec3<i32> {
+ var x = vec3<u32>(v);
+ let b16 = select(vec3<u32>(16u), vec3<u32>(0u), vec3<bool>((x & vec3<u32>(65535u))));
+ x = (x >> b16);
+ let b8 = select(vec3<u32>(8u), vec3<u32>(0u), vec3<bool>((x & vec3<u32>(255u))));
+ x = (x >> b8);
+ let b4 = select(vec3<u32>(4u), vec3<u32>(0u), vec3<bool>((x & vec3<u32>(15u))));
+ x = (x >> b4);
+ let b2 = select(vec3<u32>(2u), vec3<u32>(0u), vec3<bool>((x & vec3<u32>(3u))));
+ x = (x >> b2);
+ let b1 = select(vec3<u32>(1u), vec3<u32>(0u), vec3<bool>((x & vec3<u32>(1u))));
+ let is_zero = select(vec3<u32>(0u), vec3<u32>(4294967295u), (x == vec3<u32>(0u)));
+ return vec3<i32>((((((b16 | b8) | b4) | b2) | b1) | is_zero));
+}
+
+fn f() {
+ let v = 15i;
+ let r : vec3<i32> = tint_first_trailing_bit(vec3<i32>(v));
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillFirstTrailingBit());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, FirstTrailingBit_vec3_u32) {
+ auto* src = R"(
+fn f() {
+ let v = 15u;
+ let r : vec3<u32> = firstTrailingBit(vec3<u32>(v));
+}
+)";
+ // Expected: a vec3<u32> tint_first_trailing_bit(); the helper still emits vec3<u32>(v).
+ auto* expect = R"(
+fn tint_first_trailing_bit(v : vec3<u32>) -> vec3<u32> {
+ var x = vec3<u32>(v);
+ let b16 = select(vec3<u32>(16u), vec3<u32>(0u), vec3<bool>((x & vec3<u32>(65535u))));
+ x = (x >> b16);
+ let b8 = select(vec3<u32>(8u), vec3<u32>(0u), vec3<bool>((x & vec3<u32>(255u))));
+ x = (x >> b8);
+ let b4 = select(vec3<u32>(4u), vec3<u32>(0u), vec3<bool>((x & vec3<u32>(15u))));
+ x = (x >> b4);
+ let b2 = select(vec3<u32>(2u), vec3<u32>(0u), vec3<bool>((x & vec3<u32>(3u))));
+ x = (x >> b2);
+ let b1 = select(vec3<u32>(1u), vec3<u32>(0u), vec3<bool>((x & vec3<u32>(1u))));
+ let is_zero = select(vec3<u32>(0u), vec3<u32>(4294967295u), (x == vec3<u32>(0u)));
+ return vec3<u32>((((((b16 | b8) | b4) | b2) | b1) | is_zero));
+}
+
+fn f() {
+ let v = 15u;
+ let r : vec3<u32> = tint_first_trailing_bit(vec3<u32>(v));
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillFirstTrailingBit());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// insertBits
+////////////////////////////////////////////////////////////////////////////////
+Transform::DataMap polyfillInsertBits(Level level) {  // DataMap whose Config sets insert_bits
+ Transform::DataMap config_data;
+ BuiltinPolyfill::Builtins polyfills;
+ polyfills.insert_bits = level;
+ config_data.Add<BuiltinPolyfill::Config>(polyfills);
+ return config_data;
+}
+
+TEST_F(BuiltinPolyfillTest, ShouldRunInsertBits) {
+ auto* src = R"(
+fn f() {
+ let v = 1234i;
+ _ = insertBits(v, 5678, 5u, 6u);
+}
+)";
+ // Only kClampParameters and kFull should request a run; no config and kNone should not.
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src, polyfillInsertBits(Level::kNone)));
+ EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillInsertBits(Level::kClampParameters)));
+ EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillInsertBits(Level::kFull)));
+}
+
+TEST_F(BuiltinPolyfillTest, InsertBits_ConstantExpression) {
+ auto* src = R"(
+fn f() {
+ let r : i32 = insertBits(1234i, 5678i, 5u, 6u);
+}
+)";
+ // Every argument is a literal; ShouldRun must stay false even at Level::kFull.
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src, polyfillInsertBits(Level::kFull)));
+}
+
+TEST_F(BuiltinPolyfillTest, InsertBits_Full_i32) {
+ auto* src = R"(
+fn f() {
+ let v = 1234i;
+ let r : i32 = insertBits(v, 5678, 5u, 6u);
+}
+)";
+ // Expected (kFull): tint_insert_bits() builds a u32 mask and merges n into v via i32(mask).
+ auto* expect = R"(
+fn tint_insert_bits(v : i32, n : i32, offset : u32, count : u32) -> i32 {
+ let e = (offset + count);
+ let mask = ((select(0u, (1u << offset), (offset < 32u)) - 1u) ^ (select(0u, (1u << e), (e < 32u)) - 1u));
+ return ((select(i32(), (n << offset), (offset < 32u)) & i32(mask)) | (v & i32(~(mask))));
+}
+
+fn f() {
+ let v = 1234i;
+ let r : i32 = tint_insert_bits(v, 5678, 5u, 6u);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillInsertBits(Level::kFull));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, InsertBits_Full_u32) {
+ auto* src = R"(
+fn f() {
+ let v = 1234u;
+ let r : u32 = insertBits(v, 5678u, 5u, 6u);
+}
+)";
+ // Expected (kFull): tint_insert_bits() builds a u32 mask and merges n into v with no casts.
+ auto* expect = R"(
+fn tint_insert_bits(v : u32, n : u32, offset : u32, count : u32) -> u32 {
+ let e = (offset + count);
+ let mask = ((select(0u, (1u << offset), (offset < 32u)) - 1u) ^ (select(0u, (1u << e), (e < 32u)) - 1u));
+ return ((select(u32(), (n << offset), (offset < 32u)) & mask) | (v & ~(mask)));
+}
+
+fn f() {
+ let v = 1234u;
+ let r : u32 = tint_insert_bits(v, 5678u, 5u, 6u);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillInsertBits(Level::kFull));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, InsertBits_Full_vec3_i32) {
+ auto* src = R"(
+fn f() {
+ let v = 1234i;
+ let r : vec3<i32> = insertBits(vec3<i32>(v), vec3<i32>(5678), 5u, 6u);
+}
+)";
+ // Expected (kFull): scalar u32 mask, widened with vec3<i32>(i32(...)) before the merge.
+ auto* expect = R"(
+fn tint_insert_bits(v : vec3<i32>, n : vec3<i32>, offset : u32, count : u32) -> vec3<i32> {
+ let e = (offset + count);
+ let mask = ((select(0u, (1u << offset), (offset < 32u)) - 1u) ^ (select(0u, (1u << e), (e < 32u)) - 1u));
+ return ((select(vec3<i32>(), (n << vec3<u32>(offset)), (offset < 32u)) & vec3<i32>(i32(mask))) | (v & vec3<i32>(i32(~(mask)))));
+}
+
+fn f() {
+ let v = 1234i;
+ let r : vec3<i32> = tint_insert_bits(vec3<i32>(v), vec3<i32>(5678), 5u, 6u);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillInsertBits(Level::kFull));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, InsertBits_Full_vec3_u32) {
+ auto* src = R"(
+fn f() {
+ let v = 1234u;
+ let r : vec3<u32> = insertBits(vec3<u32>(v), vec3<u32>(5678u), 5u, 6u);
+}
+)";
+ // Expected (kFull): scalar u32 mask, widened with vec3<u32>(...) before the merge.
+ auto* expect = R"(
+fn tint_insert_bits(v : vec3<u32>, n : vec3<u32>, offset : u32, count : u32) -> vec3<u32> {
+ let e = (offset + count);
+ let mask = ((select(0u, (1u << offset), (offset < 32u)) - 1u) ^ (select(0u, (1u << e), (e < 32u)) - 1u));
+ return ((select(vec3<u32>(), (n << vec3<u32>(offset)), (offset < 32u)) & vec3<u32>(mask)) | (v & vec3<u32>(~(mask))));
+}
+
+fn f() {
+ let v = 1234u;
+ let r : vec3<u32> = tint_insert_bits(vec3<u32>(v), vec3<u32>(5678u), 5u, 6u);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillInsertBits(Level::kFull));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, InsertBits_Clamp_i32) {
+ auto* src = R"(
+fn f() {
+ let v = 1234i;
+ let r : i32 = insertBits(v, 5678, 5u, 6u);
+}
+)";
+ // Expected (kClampParameters): tint_insert_bits() clamps the range with min(), then insertBits.
+ auto* expect = R"(
+fn tint_insert_bits(v : i32, n : i32, offset : u32, count : u32) -> i32 {
+ let s = min(offset, 32u);
+ let e = min(32u, (s + count));
+ return insertBits(v, n, s, (e - s));
+}
+
+fn f() {
+ let v = 1234i;
+ let r : i32 = tint_insert_bits(v, 5678, 5u, 6u);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillInsertBits(Level::kClampParameters));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, InsertBits_Clamp_u32) {
+ auto* src = R"(
+fn f() {
+ let v = 1234u;
+ let r : u32 = insertBits(v, 5678u, 5u, 6u);
+}
+)";
+ // Expected (kClampParameters): the u32 wrapper clamps the range with min(), then insertBits.
+ auto* expect = R"(
+fn tint_insert_bits(v : u32, n : u32, offset : u32, count : u32) -> u32 {
+ let s = min(offset, 32u);
+ let e = min(32u, (s + count));
+ return insertBits(v, n, s, (e - s));
+}
+
+fn f() {
+ let v = 1234u;
+ let r : u32 = tint_insert_bits(v, 5678u, 5u, 6u);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillInsertBits(Level::kClampParameters));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, InsertBits_Clamp_vec3_i32) {
+ auto* src = R"(
+fn f() {
+ let v = 1234i;
+ let r : vec3<i32> = insertBits(vec3<i32>(v), vec3<i32>(5678), 5u, 6u);
+}
+)";
+ // Expected (kClampParameters): the vec3<i32> wrapper clamps the scalar range, then insertBits.
+ auto* expect = R"(
+fn tint_insert_bits(v : vec3<i32>, n : vec3<i32>, offset : u32, count : u32) -> vec3<i32> {
+ let s = min(offset, 32u);
+ let e = min(32u, (s + count));
+ return insertBits(v, n, s, (e - s));
+}
+
+fn f() {
+ let v = 1234i;
+ let r : vec3<i32> = tint_insert_bits(vec3<i32>(v), vec3<i32>(5678), 5u, 6u);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillInsertBits(Level::kClampParameters));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, InsertBits_Clamp_vec3_u32) {
+ auto* src = R"(
+fn f() {
+ let v = 1234u;
+ let r : vec3<u32> = insertBits(vec3<u32>(v), vec3<u32>(5678u), 5u, 6u);
+}
+)";
+ // Expected (kClampParameters): the vec3<u32> wrapper clamps the scalar range, then insertBits.
+ auto* expect = R"(
+fn tint_insert_bits(v : vec3<u32>, n : vec3<u32>, offset : u32, count : u32) -> vec3<u32> {
+ let s = min(offset, 32u);
+ let e = min(32u, (s + count));
+ return insertBits(v, n, s, (e - s));
+}
+
+fn f() {
+ let v = 1234u;
+ let r : vec3<u32> = tint_insert_bits(vec3<u32>(v), vec3<u32>(5678u), 5u, 6u);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillInsertBits(Level::kClampParameters));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// precise_float_mod
+////////////////////////////////////////////////////////////////////////////////
+Transform::DataMap polyfillPreciseFloatMod() {  // DataMap whose Config sets precise_float_mod
+ Transform::DataMap config_data;
+ BuiltinPolyfill::Builtins polyfills;
+ polyfills.precise_float_mod = true;
+ config_data.Add<BuiltinPolyfill::Config>(polyfills);
+ return config_data;
+}
+
+TEST_F(BuiltinPolyfillTest, ShouldRunPreciseFloatMod) {
+ auto* src = R"(
+fn f() {
+ let v = 10f;
+ let x = 20f % v;
+}
+)";
+ // A float '%' needs no work without config, but must trigger a run with precise_float_mod set.
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));
+ EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillPreciseFloatMod()));
+}
+
+TEST_F(BuiltinPolyfillTest, PreciseFloatMod_af_f32) {
+ auto* src = R"(
+fn f() {
+ let v = 10f;
+ let x = 20.0 % v;
+}
+)";
+ // Expected: '%' becomes tint_float_mod() (lhs - trunc(lhs / rhs) * rhs); 20.0 stays unsuffixed.
+ auto* expect = R"(
+fn tint_float_mod(lhs : f32, rhs : f32) -> f32 {
+ return (lhs - (trunc((lhs / rhs)) * rhs));
+}
+
+fn f() {
+ let v = 10.0f;
+ let x = tint_float_mod(20.0, v);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillPreciseFloatMod());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, PreciseFloatMod_f32_af) {
+ auto* src = R"(
+fn f() {
+ let v = 10.0;
+ let x = 20f % v;
+}
+)";
+ // Expected: '%' becomes tint_float_mod(f32, f32); the unsuffixed 'let v = 10.0' is kept as is.
+ auto* expect = R"(
+fn tint_float_mod(lhs : f32, rhs : f32) -> f32 {
+ return (lhs - (trunc((lhs / rhs)) * rhs));
+}
+
+fn f() {
+ let v = 10.0;
+ let x = tint_float_mod(20.0f, v);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillPreciseFloatMod());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, PreciseFloatMod_f32_f32) {
+ auto* src = R"(
+fn f() {
+ let v = 10f;
+ let x = 20f % v;
+}
+)";
+ // Expected: with both operands suffixed f32, '%' becomes tint_float_mod(20.0f, v).
+ auto* expect = R"(
+fn tint_float_mod(lhs : f32, rhs : f32) -> f32 {
+ return (lhs - (trunc((lhs / rhs)) * rhs));
+}
+
+fn f() {
+ let v = 10.0f;
+ let x = tint_float_mod(20.0f, v);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillPreciseFloatMod());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, PreciseFloatMod_Overloads) {
+ auto* src = R"(
+fn f() {
+ let v = 10f;
+ let x = 20f % v;
+ let w = 10i;
+ let y = 20i % w;
+ let u = 10u;
+ let z = 20u % u;
+}
+)";
+ // Expected: only the f32 '%' is rewritten; the i32 and u32 '%' expressions remain operators.
+ auto* expect = R"(
+fn tint_float_mod(lhs : f32, rhs : f32) -> f32 {
+ return (lhs - (trunc((lhs / rhs)) * rhs));
+}
+
+fn f() {
+ let v = 10.0f;
+ let x = tint_float_mod(20.0f, v);
+ let w = 10i;
+ let y = (20i % w);
+ let u = 10u;
+ let z = (20u % u);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillPreciseFloatMod());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, PreciseFloatMod_vec3_af_f32) {
+ auto* src = R"(
+fn f() {
+ let v = 10f;
+ let x = vec3(20.0) % v;
+}
+)";
+ // Expected: tint_float_mod(vec3<f32>, f32) splats rhs into r; vec3(20.0) is passed unchanged.
+ auto* expect = R"(
+fn tint_float_mod(lhs : vec3<f32>, rhs : f32) -> vec3<f32> {
+ let r = vec3<f32>(rhs);
+ return (lhs - (trunc((lhs / r)) * r));
+}
+
+fn f() {
+ let v = 10.0f;
+ let x = tint_float_mod(vec3(20.0), v);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillPreciseFloatMod());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, PreciseFloatMod_vec3_f32_af) {
+ auto* src = R"(
+fn f() {
+ let v = 10.0;
+ let x = vec3(20f) % v;
+}
+)";
+ // Expected: tint_float_mod(vec3<f32>, f32) splats rhs into r; 'let v = 10.0' is kept as is.
+ auto* expect = R"(
+fn tint_float_mod(lhs : vec3<f32>, rhs : f32) -> vec3<f32> {
+ let r = vec3<f32>(rhs);
+ return (lhs - (trunc((lhs / r)) * r));
+}
+
+fn f() {
+ let v = 10.0;
+ let x = tint_float_mod(vec3(20.0f), v);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillPreciseFloatMod());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, PreciseFloatMod_vec3_f32_f32) {
+ auto* src = R"(
+fn f() {
+ let v = 10f;
+ let x = vec3(20f) % v;
+}
+)";
+ // Expected: tint_float_mod(vec3<f32>, f32) splats the scalar rhs into r before the formula.
+ auto* expect = R"(
+fn tint_float_mod(lhs : vec3<f32>, rhs : f32) -> vec3<f32> {
+ let r = vec3<f32>(rhs);
+ return (lhs - (trunc((lhs / r)) * r));
+}
+
+fn f() {
+ let v = 10.0f;
+ let x = tint_float_mod(vec3(20.0f), v);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillPreciseFloatMod());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, PreciseFloatMod_vec3_f32_vec3_f32) {
+ auto* src = R"(
+fn f() {
+ let v = 10f;
+ let x = vec3<f32>(20f) % vec3<f32>(v);
+}
+)";
+ // Expected: with two vec3<f32> operands the helper needs no splat and uses rhs directly.
+ auto* expect = R"(
+fn tint_float_mod(lhs : vec3<f32>, rhs : vec3<f32>) -> vec3<f32> {
+ return (lhs - (trunc((lhs / rhs)) * rhs));
+}
+
+fn f() {
+ let v = 10.0f;
+ let x = tint_float_mod(vec3<f32>(20.0f), vec3<f32>(v));
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillPreciseFloatMod());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// int_div_mod
+////////////////////////////////////////////////////////////////////////////////
+Transform::DataMap polyfillIntDivMod() {  // DataMap whose Config sets int_div_mod
+ Transform::DataMap config_data;
+ BuiltinPolyfill::Builtins polyfills;
+ polyfills.int_div_mod = true;
+ config_data.Add<BuiltinPolyfill::Config>(polyfills);
+ return config_data;
+}
+
+TEST_F(BuiltinPolyfillTest, ShouldRunIntDiv) {
+ auto* src = R"(
+fn f() {
+ let v = 10i;
+ let x = 20i / v;
+}
+)";
+ // An integer '/' needs no work without config, but must trigger a run with int_div_mod set.
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));
+ EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillIntDivMod()));
+}
+
+TEST_F(BuiltinPolyfillTest, ShouldRunIntMod) {
+ auto* src = R"(
+fn f() {
+ let v = 10i;
+ let x = 20i % v;
+}
+)";
+ // An integer '%' needs no work without config, but must trigger a run with int_div_mod set.
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));
+ EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillIntDivMod()));
+}
+
+TEST_F(BuiltinPolyfillTest, IntDiv_ai_i32) {
+ auto* src = R"(
+fn f() {
+ let v = 10i;
+ let x = 20 / v;
+}
+)";
+ // Expected: tint_div() divides by 1 when rhs == 0 or when lhs == -2147483648 and rhs == -1.
+ auto* expect = R"(
+fn tint_div(lhs : i32, rhs : i32) -> i32 {
+ return (lhs / select(rhs, 1, ((rhs == 0) | ((lhs == -2147483648) & (rhs == -1)))));
+}
+
+fn f() {
+ let v = 10i;
+ let x = tint_div(20, v);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillIntDivMod());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, IntMod_ai_i32) {
+ auto* src = R"(
+fn f() {
+ let v = 10i;
+ let x = 20 % v;
+}
+)";
+ // Expected: tint_mod() picks rhs_or_one, then uses '%' only when no operand has its sign bit set.
+ auto* expect = R"(
+fn tint_mod(lhs : i32, rhs : i32) -> i32 {
+ let rhs_or_one = select(rhs, 1, ((rhs == 0) | ((lhs == -2147483648) & (rhs == -1))));
+ if (any(((u32((lhs | rhs_or_one)) & 2147483648u) != 0u))) {
+ return (lhs - ((lhs / rhs_or_one) * rhs_or_one));
+ } else {
+ return (lhs % rhs_or_one);
+ }
+}
+
+fn f() {
+ let v = 10i;
+ let x = tint_mod(20, v);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillIntDivMod());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, IntDiv_i32_ai) {
+ auto* src = R"(
+fn f() {
+ let v = 10i;
+ let x = v / 20;
+}
+)";
+ // Expected: 'v / 20' becomes tint_div(v, 20); the unsuffixed literal is passed through as rhs.
+ auto* expect = R"(
+fn tint_div(lhs : i32, rhs : i32) -> i32 {
+ return (lhs / select(rhs, 1, ((rhs == 0) | ((lhs == -2147483648) & (rhs == -1)))));
+}
+
+fn f() {
+ let v = 10i;
+ let x = tint_div(v, 20);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillIntDivMod());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, IntMod_i32_ai) {
+ auto* src = R"(
+fn f() {
+ let v = 10i;
+ let x = v % 20;
+}
+)";
+ // Expected: 'v % 20' becomes tint_mod(v, 20); the unsuffixed literal is passed through as rhs.
+ auto* expect = R"(
+fn tint_mod(lhs : i32, rhs : i32) -> i32 {
+ let rhs_or_one = select(rhs, 1, ((rhs == 0) | ((lhs == -2147483648) & (rhs == -1))));
+ if (any(((u32((lhs | rhs_or_one)) & 2147483648u) != 0u))) {
+ return (lhs - ((lhs / rhs_or_one) * rhs_or_one));
+ } else {
+ return (lhs % rhs_or_one);
+ }
+}
+
+fn f() {
+ let v = 10i;
+ let x = tint_mod(v, 20);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillIntDivMod());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, IntDiv_i32_i32) {
+ auto* src = R"(
+fn f() {
+ let v = 10i;
+ let x = 20i / v;
+}
+)";
+ // Expected: with both operands i32, '20i / v' becomes tint_div(20i, v) with the guarded rhs.
+ auto* expect = R"(
+fn tint_div(lhs : i32, rhs : i32) -> i32 {
+ return (lhs / select(rhs, 1, ((rhs == 0) | ((lhs == -2147483648) & (rhs == -1)))));
+}
+
+fn f() {
+ let v = 10i;
+ let x = tint_div(20i, v);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillIntDivMod());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, IntMod_i32_i32) {
+ auto* src = R"(
+fn f() {
+ let v = 10i;
+ let x = 20i % v;
+}
+)";
+ // Expected: with both operands i32, '20i % v' becomes tint_mod(20i, v) with the guarded rhs.
+ auto* expect = R"(
+fn tint_mod(lhs : i32, rhs : i32) -> i32 {
+ let rhs_or_one = select(rhs, 1, ((rhs == 0) | ((lhs == -2147483648) & (rhs == -1))));
+ if (any(((u32((lhs | rhs_or_one)) & 2147483648u) != 0u))) {
+ return (lhs - ((lhs / rhs_or_one) * rhs_or_one));
+ } else {
+ return (lhs % rhs_or_one);
+ }
+}
+
+fn f() {
+ let v = 10i;
+ let x = tint_mod(20i, v);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillIntDivMod());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, IntDiv_ai_u32) {
+ auto* src = R"(
+fn f() {
+ let v = 10u;
+ let x = 20 / v;
+}
+)";
+ // Expected: tint_div() for u32 only replaces a zero rhs with 1; the literal 20 stays unsuffixed.
+ auto* expect = R"(
+fn tint_div(lhs : u32, rhs : u32) -> u32 {
+ return (lhs / select(rhs, 1, (rhs == 0)));
+}
+
+fn f() {
+ let v = 10u;
+ let x = tint_div(20, v);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillIntDivMod());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, IntMod_ai_u32) {
+ auto* src = R"(
+fn f() {
+ let v = 10u;
+ let x = 20 % v;
+}
+)";
+ // Expected: tint_mod() for u32 only replaces a zero rhs with 1; the literal 20 stays unsuffixed.
+ auto* expect = R"(
+fn tint_mod(lhs : u32, rhs : u32) -> u32 {
+ return (lhs % select(rhs, 1, (rhs == 0)));
+}
+
+fn f() {
+ let v = 10u;
+ let x = tint_mod(20, v);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillIntDivMod());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, IntDiv_u32_ai) {
+ auto* src = R"(
+fn f() {
+ let v = 10u;
+ let x = v / 20;
+}
+)";
+ // Expected: 'v / 20' becomes tint_div(v, 20) on u32, guarding only against rhs == 0.
+ auto* expect = R"(
+fn tint_div(lhs : u32, rhs : u32) -> u32 {
+ return (lhs / select(rhs, 1, (rhs == 0)));
+}
+
+fn f() {
+ let v = 10u;
+ let x = tint_div(v, 20);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillIntDivMod());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, IntMod_u32_ai) {
+ auto* src = R"(
+fn f() {
+ let v = 10u;
+ let x = v % 20;
+}
+)";
+ // Expected: 'v % 20' becomes tint_mod(v, 20) on u32, guarding only against rhs == 0.
+ auto* expect = R"(
+fn tint_mod(lhs : u32, rhs : u32) -> u32 {
+ return (lhs % select(rhs, 1, (rhs == 0)));
+}
+
+fn f() {
+ let v = 10u;
+ let x = tint_mod(v, 20);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillIntDivMod());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, IntDiv_u32_u32) {
+ auto* src = R"(
+fn f() {
+ let v = 10u;
+ let x = 20u / v;
+}
+)";
+ // Expected: with both operands u32, '20u / v' becomes tint_div(20u, v) with the zero guard.
+ auto* expect = R"(
+fn tint_div(lhs : u32, rhs : u32) -> u32 {
+ return (lhs / select(rhs, 1, (rhs == 0)));
+}
+
+fn f() {
+ let v = 10u;
+ let x = tint_div(20u, v);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillIntDivMod());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, IntMod_u32_u32) {
+ auto* src = R"(
+fn f() {
+ let v = 10u;
+ let x = 20u % v;
+}
+)";
+ // Expected: with both operands u32, '20u % v' becomes tint_mod(20u, v) with the zero guard.
+ auto* expect = R"(
+fn tint_mod(lhs : u32, rhs : u32) -> u32 {
+ return (lhs % select(rhs, 1, (rhs == 0)));
+}
+
+fn f() {
+ let v = 10u;
+ let x = tint_mod(20u, v);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillIntDivMod());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, IntDiv_vec3_ai_i32) {
+ auto* src = R"(
+fn f() {
+ let v = 10i;
+ let x = vec3(20) / v;
+}
+)";
+ // Expected: tint_div(vec3<i32>, i32) splats rhs into r, then applies the guarded vector divide.
+ auto* expect = R"(
+fn tint_div(lhs : vec3<i32>, rhs : i32) -> vec3<i32> {
+ let r = vec3<i32>(rhs);
+ return (lhs / select(r, vec3(1), ((r == vec3(0)) | ((lhs == vec3(-2147483648)) & (r == vec3(-1))))));
+}
+
+fn f() {
+ let v = 10i;
+ let x = tint_div(vec3(20), v);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillIntDivMod());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, IntMod_vec3_ai_i32) {
+ auto* src = R"(
+fn f() {
+ let v = 10i;
+ let x = vec3(20) % v;
+}
+)";
+ // Expected: tint_mod(vec3<i32>, i32) splats rhs into r; any() over the sign bits picks the path.
+ auto* expect = R"(
+fn tint_mod(lhs : vec3<i32>, rhs : i32) -> vec3<i32> {
+ let r = vec3<i32>(rhs);
+ let rhs_or_one = select(r, vec3(1), ((r == vec3(0)) | ((lhs == vec3(-2147483648)) & (r == vec3(-1)))));
+ if (any(((vec3<u32>((lhs | rhs_or_one)) & vec3<u32>(2147483648u)) != vec3<u32>(0u)))) {
+ return (lhs - ((lhs / rhs_or_one) * rhs_or_one));
+ } else {
+ return (lhs % rhs_or_one);
+ }
+}
+
+fn f() {
+ let v = 10i;
+ let x = tint_mod(vec3(20), v);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillIntDivMod());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, IntDiv_vec3_i32_ai) {
+ auto* src = R"(
+fn f() {
+ let v = 10i;
+ let x = vec3(v) / 20;
+}
+)";
+ // Expected: 'vec3(v) / 20' becomes tint_div(vec3(v), 20); the scalar rhs is splatted into r.
+ auto* expect = R"(
+fn tint_div(lhs : vec3<i32>, rhs : i32) -> vec3<i32> {
+ let r = vec3<i32>(rhs);
+ return (lhs / select(r, vec3(1), ((r == vec3(0)) | ((lhs == vec3(-2147483648)) & (r == vec3(-1))))));
+}
+
+fn f() {
+ let v = 10i;
+ let x = tint_div(vec3(v), 20);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillIntDivMod());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, IntMod_vec3_i32_ai) {
+ auto* src = R"(
+fn f() {
+ let v = 10i;
+ let x = vec3(v) % 20;
+}
+)";
+ // Expected: 'vec3(v) % 20' becomes tint_mod(vec3(v), 20); the scalar rhs is splatted into r.
+ auto* expect = R"(
+fn tint_mod(lhs : vec3<i32>, rhs : i32) -> vec3<i32> {
+ let r = vec3<i32>(rhs);
+ let rhs_or_one = select(r, vec3(1), ((r == vec3(0)) | ((lhs == vec3(-2147483648)) & (r == vec3(-1)))));
+ if (any(((vec3<u32>((lhs | rhs_or_one)) & vec3<u32>(2147483648u)) != vec3<u32>(0u)))) {
+ return (lhs - ((lhs / rhs_or_one) * rhs_or_one));
+ } else {
+ return (lhs % rhs_or_one);
+ }
+}
+
+fn f() {
+ let v = 10i;
+ let x = tint_mod(vec3(v), 20);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillIntDivMod());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, IntDiv_vec3_i32_i32) {
+ auto* src = R"(
+fn f() {
+ let v = 10i;
+ let x = vec3<i32>(20i) / v;
+}
+)";
+ // Expected: with an explicit vec3<i32> lhs and i32 rhs, the helper splats rhs into r first.
+ auto* expect = R"(
+fn tint_div(lhs : vec3<i32>, rhs : i32) -> vec3<i32> {
+ let r = vec3<i32>(rhs);
+ return (lhs / select(r, vec3(1), ((r == vec3(0)) | ((lhs == vec3(-2147483648)) & (r == vec3(-1))))));
+}
+
+fn f() {
+ let v = 10i;
+ let x = tint_div(vec3<i32>(20i), v);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillIntDivMod());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, IntMod_vec3_i32_i32) {
+ auto* src = R"(
+fn f() {
+ let v = 10i;
+ let x = vec3<i32>(20i) % v;
+}
+)";
+ // Expected: with an explicit vec3<i32> lhs and i32 rhs, tint_mod() splats rhs into r first.
+ auto* expect = R"(
+fn tint_mod(lhs : vec3<i32>, rhs : i32) -> vec3<i32> {
+ let r = vec3<i32>(rhs);
+ let rhs_or_one = select(r, vec3(1), ((r == vec3(0)) | ((lhs == vec3(-2147483648)) & (r == vec3(-1)))));
+ if (any(((vec3<u32>((lhs | rhs_or_one)) & vec3<u32>(2147483648u)) != vec3<u32>(0u)))) {
+ return (lhs - ((lhs / rhs_or_one) * rhs_or_one));
+ } else {
+ return (lhs % rhs_or_one);
+ }
+}
+
+fn f() {
+ let v = 10i;
+ let x = tint_mod(vec3<i32>(20i), v);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillIntDivMod());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, IntDiv_vec3_u32_u32) {
+ auto* src = R"(
+fn f() {
+ let v = 10u;
+ let x = vec3<u32>(20u) / v;
+}
+)";
+ // Expected: tint_div(vec3<u32>, u32) splats rhs into r and only guards against zero lanes.
+ auto* expect = R"(
+fn tint_div(lhs : vec3<u32>, rhs : u32) -> vec3<u32> {
+ let r = vec3<u32>(rhs);
+ return (lhs / select(r, vec3(1), (r == vec3(0))));
+}
+
+fn f() {
+ let v = 10u;
+ let x = tint_div(vec3<u32>(20u), v);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillIntDivMod());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, IntMod_vec3_u32_u32) {
+ auto* src = R"(
+fn f() {
+ let v = 10u;
+ let x = vec3<u32>(20u) % v;
+}
+)";
+ // Expected: tint_mod(vec3<u32>, u32) splats rhs into r and only guards against zero lanes.
+ auto* expect = R"(
+fn tint_mod(lhs : vec3<u32>, rhs : u32) -> vec3<u32> {
+ let r = vec3<u32>(rhs);
+ return (lhs % select(r, vec3(1), (r == vec3(0))));
+}
+
+fn f() {
+ let v = 10u;
+ let x = tint_mod(vec3<u32>(20u), v);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillIntDivMod());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, IntDiv_ai_vec3_i32) {
+ auto* src = R"(
+fn f() {
+ let v = 10i;
+ let x = 20 / vec3(v);
+}
+)";
+ // Expected: tint_div(i32, vec3<i32>) splats lhs into l, then applies the guarded vector divide.
+ auto* expect = R"(
+fn tint_div(lhs : i32, rhs : vec3<i32>) -> vec3<i32> {
+ let l = vec3<i32>(lhs);
+ return (l / select(rhs, vec3(1), ((rhs == vec3(0)) | ((l == vec3(-2147483648)) & (rhs == vec3(-1))))));
+}
+
+fn f() {
+ let v = 10i;
+ let x = tint_div(20, vec3(v));
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillIntDivMod());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, IntMod_ai_vec3_i32) {
+ auto* src = R"(
+fn f() {
+ let v = 10i;
+ let x = 20 % vec3(v);
+}
+)";
+ // Expected: tint_mod(i32, vec3<i32>) splats lhs into l; any() over the sign bits picks the path.
+ auto* expect = R"(
+fn tint_mod(lhs : i32, rhs : vec3<i32>) -> vec3<i32> {
+ let l = vec3<i32>(lhs);
+ let rhs_or_one = select(rhs, vec3(1), ((rhs == vec3(0)) | ((l == vec3(-2147483648)) & (rhs == vec3(-1)))));
+ if (any(((vec3<u32>((l | rhs_or_one)) & vec3<u32>(2147483648u)) != vec3<u32>(0u)))) {
+ return (l - ((l / rhs_or_one) * rhs_or_one));
+ } else {
+ return (l % rhs_or_one);
+ }
+}
+
+fn f() {
+ let v = 10i;
+ let x = tint_mod(20, vec3(v));
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillIntDivMod());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, IntDiv_i32_vec3_i32) {
+ auto* src = R"(
+fn f() {
+ let v = 10i;
+ let x = 20i / vec3<i32>(v);
+}
+)";
+ // Expected: with an i32 lhs and explicit vec3<i32> rhs, tint_div() splats lhs into l first.
+ auto* expect = R"(
+fn tint_div(lhs : i32, rhs : vec3<i32>) -> vec3<i32> {
+ let l = vec3<i32>(lhs);
+ return (l / select(rhs, vec3(1), ((rhs == vec3(0)) | ((l == vec3(-2147483648)) & (rhs == vec3(-1))))));
+}
+
+fn f() {
+ let v = 10i;
+ let x = tint_div(20i, vec3<i32>(v));
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillIntDivMod());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, IntMod_i32_vec3_i32) {
+ auto* src = R"(
+fn f() {
+ let v = 10i;
+ let x = 20i % vec3<i32>(v);
+}
+)";
+ // Expected: with an i32 lhs and explicit vec3<i32> rhs, tint_mod() splats lhs into l first.
+ auto* expect = R"(
+fn tint_mod(lhs : i32, rhs : vec3<i32>) -> vec3<i32> {
+ let l = vec3<i32>(lhs);
+ let rhs_or_one = select(rhs, vec3(1), ((rhs == vec3(0)) | ((l == vec3(-2147483648)) & (rhs == vec3(-1)))));
+ if (any(((vec3<u32>((l | rhs_or_one)) & vec3<u32>(2147483648u)) != vec3<u32>(0u)))) {
+ return (l - ((l / rhs_or_one) * rhs_or_one));
+ } else {
+ return (l % rhs_or_one);
+ }
+}
+
+fn f() {
+ let v = 10i;
+ let x = tint_mod(20i, vec3<i32>(v));
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillIntDivMod());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(BuiltinPolyfillTest, IntDiv_u32_vec3_u32) {  // u32 scalar divided by vec3<u32>
+ auto* src = R"(
+fn f() {
+ let v = 10u;
+ let x = 20u / vec3<u32>(v);
+}
+)";
+
+ auto* expect = R"(
+fn tint_div(lhs : u32, rhs : vec3<u32>) -> vec3<u32> {
+ let l = vec3<u32>(lhs);
+ return (l / select(rhs, vec3(1), (rhs == vec3(0))));
+}
+
+fn f() {
+ let v = 10u;
+ let x = tint_div(20u, vec3<u32>(v));
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillIntDivMod());  // run the transform with the int div/mod config
+
+ EXPECT_EQ(expect, str(got));  // unsigned helper only guards against rhs == 0
+}
+
+TEST_F(BuiltinPolyfillTest, IntMod_u32_vec3_u32) {  // u32 scalar modulo vec3<u32>
+ auto* src = R"(
+fn f() {
+ let v = 10u;
+ let x = 20u % vec3<u32>(v);
+}
+)";
+
+ auto* expect = R"(
+fn tint_mod(lhs : u32, rhs : vec3<u32>) -> vec3<u32> {
+ let l = vec3<u32>(lhs);
+ return (l % select(rhs, vec3(1), (rhs == vec3(0))));
+}
+
+fn f() {
+ let v = 10u;
+ let x = tint_mod(20u, vec3<u32>(v));
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillIntDivMod());  // run the transform with the int div/mod config
+
+ EXPECT_EQ(expect, str(got));  // unsigned helper only guards against rhs == 0
+}
+
+TEST_F(BuiltinPolyfillTest, IntDiv_vec3_i32_vec3_i32) {  // vec3<i32> divided by vec3<i32>
+ auto* src = R"(
+fn f() {
+ let v = 10i;
+ let x = vec3<i32>(20i) / vec3<i32>(v);
+}
+)";
+
+ auto* expect = R"(
+fn tint_div(lhs : vec3<i32>, rhs : vec3<i32>) -> vec3<i32> {
+ return (lhs / select(rhs, vec3(1), ((rhs == vec3(0)) | ((lhs == vec3(-2147483648)) & (rhs == vec3(-1))))));
+}
+
+fn f() {
+ let v = 10i;
+ let x = tint_div(vec3<i32>(20i), vec3<i32>(v));
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillIntDivMod());  // run the transform with the int div/mod config
+
+ EXPECT_EQ(expect, str(got));  // both operands are already vectors, so the helper has no splat step
+}
+
+TEST_F(BuiltinPolyfillTest, IntMod_vec3_i32_vec3_i32) {  // vec3<i32> modulo vec3<i32>
+ auto* src = R"(
+fn f() {
+ let v = 10i;
+ let x = vec3<i32>(20i) % vec3<i32>(v);
+}
+)";
+
+ auto* expect = R"(
+fn tint_mod(lhs : vec3<i32>, rhs : vec3<i32>) -> vec3<i32> {
+ let rhs_or_one = select(rhs, vec3(1), ((rhs == vec3(0)) | ((lhs == vec3(-2147483648)) & (rhs == vec3(-1)))));
+ if (any(((vec3<u32>((lhs | rhs_or_one)) & vec3<u32>(2147483648u)) != vec3<u32>(0u)))) {
+ return (lhs - ((lhs / rhs_or_one) * rhs_or_one));
+ } else {
+ return (lhs % rhs_or_one);
+ }
+}
+
+fn f() {
+ let v = 10i;
+ let x = tint_mod(vec3<i32>(20i), vec3<i32>(v));
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillIntDivMod());  // run the transform with the int div/mod config
+
+ EXPECT_EQ(expect, str(got));  // both operands are already vectors, so the helper has no splat step
+}
+
+TEST_F(BuiltinPolyfillTest, IntDiv_vec3_u32_vec3_u32) {  // vec3<u32> divided by vec3<u32>
+ auto* src = R"(
+fn f() {
+ let v = 10u;
+ let x = vec3<u32>(20u) / vec3<u32>(v);
+}
+)";
+
+ auto* expect = R"(
+fn tint_div(lhs : vec3<u32>, rhs : vec3<u32>) -> vec3<u32> {
+ return (lhs / select(rhs, vec3(1), (rhs == vec3(0))));
+}
+
+fn f() {
+ let v = 10u;
+ let x = tint_div(vec3<u32>(20u), vec3<u32>(v));
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillIntDivMod());  // run the transform with the int div/mod config
+
+ EXPECT_EQ(expect, str(got));  // unsigned helper only guards against rhs == 0
+}
+
+TEST_F(BuiltinPolyfillTest, IntMod_vec3_u32_vec3_u32) {  // vec3<u32> modulo vec3<u32>
+ auto* src = R"(
+fn f() {
+ let v = 10u;
+ let x = vec3<u32>(20u) % vec3<u32>(v);
+}
+)";
+
+ auto* expect = R"(
+fn tint_mod(lhs : vec3<u32>, rhs : vec3<u32>) -> vec3<u32> {
+ return (lhs % select(rhs, vec3(1), (rhs == vec3(0))));
+}
+
+fn f() {
+ let v = 10u;
+ let x = tint_mod(vec3<u32>(20u), vec3<u32>(v));
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillIntDivMod());  // run the transform with the int div/mod config
+
+ EXPECT_EQ(expect, str(got));  // unsigned helper only guards against rhs == 0
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// reflect for vec2<f32>
+////////////////////////////////////////////////////////////////////////////////
+Transform::DataMap polyfillReflectVec2F32() {  // DataMap whose BuiltinPolyfill::Config sets reflect_vec2_f32
+ Transform::DataMap config_data;
+ BuiltinPolyfill::Builtins enabled;
+ enabled.reflect_vec2_f32 = true;
+ config_data.Add<BuiltinPolyfill::Config>(enabled);
+ return config_data;
+}
+
+TEST_F(BuiltinPolyfillTest, ShouldRunReflect_vec2_f32) {  // reflect() on runtime vec2<f32> values
+ auto* src = R"(
+fn f() {
+ let e1 = vec2<f32>(1.0f);
+ let e2 = vec2<f32>(1.0f);
+ let x = reflect(e1, e2);
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));  // no config data supplied
+ EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillReflectVec2F32()));  // reflect_vec2_f32 set
+}
+
+TEST_F(BuiltinPolyfillTest, ShouldRunReflect_vec2_f16) {  // reflect() on vec2<f16>: element type does not match
+ auto* src = R"(
+enable f16;
+
+fn f() {
+ let e1 = vec2<f16>(1.0h);
+ let e2 = vec2<f16>(1.0h);
+ let x = reflect(e1, e2);
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));  // no config data supplied
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src, polyfillReflectVec2F32()));  // still not expected to run
+}
+
+TEST_F(BuiltinPolyfillTest, ShouldRunReflect_vec3_f32) {  // reflect() on vec3<f32>: vector width does not match
+ auto* src = R"(
+fn f() {
+ let e1 = vec3<f32>(1.0f);
+ let e2 = vec3<f32>(1.0f);
+ let x = reflect(e1, e2);
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));  // no config data supplied
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src, polyfillReflectVec2F32()));  // still not expected to run
+}
+
+TEST_F(BuiltinPolyfillTest, ShouldRunReflect_vec3_f16) {  // reflect() on vec3<f16>: neither width nor element type match
+ auto* src = R"(
+enable f16;
+
+fn f() {
+ let e1 = vec3<f16>(1.0h);
+ let e2 = vec3<f16>(1.0h);
+ let x = reflect(e1, e2);
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));  // no config data supplied
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src, polyfillReflectVec2F32()));  // still not expected to run
+}
+
+TEST_F(BuiltinPolyfillTest, ShouldRunReflect_vec4_f32) {  // reflect() on vec4<f32>: vector width does not match
+ auto* src = R"(
+fn f() {
+ let e1 = vec4<f32>(1.0f);
+ let e2 = vec4<f32>(1.0f);
+ let x = reflect(e1, e2);
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));  // no config data supplied
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src, polyfillReflectVec2F32()));  // vec2<f32>-only polyfill must not fire
+}
+
+TEST_F(BuiltinPolyfillTest, ShouldRunReflect_vec4_f16) {  // reflect() on vec4<f16>: neither width nor element type match
+ auto* src = R"(
+enable f16;
+
+fn f() {
+ let e1 = vec4<f16>(1.0h);
+ let e2 = vec4<f16>(1.0h);
+ let x = reflect(e1, e2);
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));  // no config data supplied
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src, polyfillReflectVec2F32()));  // vec2<f32>-only polyfill must not fire
+}
+
+TEST_F(BuiltinPolyfillTest, Reflect_ConstantExpression) {  // reflect() whose arguments are built only from literals
+ auto* src = R"(
+fn f() {
+ let r : vec2<f32> = reflect(vec2<f32>(1.0), vec2<f32>(1.0));
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src, polyfillReflectVec2F32()));  // not expected to run even with the flag set
+}
+
+TEST_F(BuiltinPolyfillTest, Reflect_vec2_f32) {  // reflect() on vec2<f32> built from a runtime value
+ auto* src = R"(
+fn f() {
+ let v = 0.5f;
+ let r : vec2<f32> = reflect(vec2<f32>(v), vec2<f32>(v));
+}
+)";
+
+ auto* expect = R"(
+fn tint_reflect(e1 : vec2<f32>, e2 : vec2<f32>) -> vec2<f32> {
+ let factor = (-2.0 * dot(e1, e2));
+ return (e1 + (factor * e2));
+}
+
+fn f() {
+ let v = 0.5f;
+ let r : vec2<f32> = tint_reflect(vec2<f32>(v), vec2<f32>(v));
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillReflectVec2F32());  // run with reflect_vec2_f32 set
+
+ EXPECT_EQ(expect, str(got));  // expected helper computes e1 + (-2.0 * dot(e1, e2)) * e2
+}
+
+TEST_F(BuiltinPolyfillTest, Reflect_multiple_types) {  // reflect() on vec2/vec3/vec4 of f32 and f16 in one function
+ auto* src = R"(
+enable f16;
+
+fn f() {
+ let in_f32 = 0.5f;
+ let out_f32_vec2 : vec2<f32> = reflect(vec2<f32>(in_f32), vec2<f32>(in_f32));
+ let out_f32_vec3 : vec3<f32> = reflect(vec3<f32>(in_f32), vec3<f32>(in_f32));
+ let out_f32_vec4 : vec4<f32> = reflect(vec4<f32>(in_f32), vec4<f32>(in_f32));
+ let in_f16 = 0.5h;
+ let out_f16_vec2 : vec2<f16> = reflect(vec2<f16>(in_f16), vec2<f16>(in_f16));
+ let out_f16_vec3 : vec3<f16> = reflect(vec3<f16>(in_f16), vec3<f16>(in_f16));
+ let out_f16_vec4 : vec4<f16> = reflect(vec4<f16>(in_f16), vec4<f16>(in_f16));
+}
+)";
+
+ auto* expect = R"(
+enable f16;
+
+fn tint_reflect(e1 : vec2<f32>, e2 : vec2<f32>) -> vec2<f32> {
+ let factor = (-2.0 * dot(e1, e2));
+ return (e1 + (factor * e2));
+}
+
+fn f() {
+ let in_f32 = 0.5f;
+ let out_f32_vec2 : vec2<f32> = tint_reflect(vec2<f32>(in_f32), vec2<f32>(in_f32));
+ let out_f32_vec3 : vec3<f32> = reflect(vec3<f32>(in_f32), vec3<f32>(in_f32));
+ let out_f32_vec4 : vec4<f32> = reflect(vec4<f32>(in_f32), vec4<f32>(in_f32));
+ let in_f16 = 0.5h;
+ let out_f16_vec2 : vec2<f16> = reflect(vec2<f16>(in_f16), vec2<f16>(in_f16));
+ let out_f16_vec3 : vec3<f16> = reflect(vec3<f16>(in_f16), vec3<f16>(in_f16));
+ let out_f16_vec4 : vec4<f16> = reflect(vec4<f16>(in_f16), vec4<f16>(in_f16));
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillReflectVec2F32());  // run with reflect_vec2_f32 set
+
+ EXPECT_EQ(expect, str(got));  // only the vec2<f32> call is rewritten; the other five calls are unchanged
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// saturate
+////////////////////////////////////////////////////////////////////////////////
+Transform::DataMap polyfillSaturate() {  // DataMap whose BuiltinPolyfill::Config sets saturate
+ Transform::DataMap config_data;
+ BuiltinPolyfill::Builtins enabled;
+ enabled.saturate = true;
+ config_data.Add<BuiltinPolyfill::Config>(enabled);
+ return config_data;
+}
+
+TEST_F(BuiltinPolyfillTest, ShouldRunSaturate) {  // saturate() on a runtime f32 value
+ auto* src = R"(
+fn f() {
+ let v = 0.5f;
+ _ = saturate(v);
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));  // no config data supplied
+ EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillSaturate()));  // saturate set
+}
+
+TEST_F(BuiltinPolyfillTest, Saturate_ConstantExpression) {  // saturate() called with a literal argument
+ auto* src = R"(
+fn f() {
+ let r : f32 = saturate(0.5);
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src, polyfillSaturate()));  // not expected to run even with the flag set
+}
+
+TEST_F(BuiltinPolyfillTest, Saturate_f32) {  // saturate() on an f32 scalar
+ auto* src = R"(
+fn f() {
+ let v = 0.5f;
+ let r : f32 = saturate(v);
+}
+)";
+
+ auto* expect = R"(
+fn tint_saturate(v : f32) -> f32 {
+ return clamp(v, f32(0), f32(1));
+}
+
+fn f() {
+ let v = 0.5f;
+ let r : f32 = tint_saturate(v);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillSaturate());  // run with saturate set
+
+ EXPECT_EQ(expect, str(got));  // expected helper is clamp(v, 0, 1) in the operand's type
+}
+
+TEST_F(BuiltinPolyfillTest, Saturate_f16) {  // saturate() on an f16 scalar
+ auto* src = R"(
+enable f16;
+
+fn f() {
+ let v = 0.5h;
+ let r : f16 = saturate(v);
+}
+)";
+
+ auto* expect = R"(
+enable f16;
+
+fn tint_saturate(v : f16) -> f16 {
+ return clamp(v, f16(0), f16(1));
+}
+
+fn f() {
+ let v = 0.5h;
+ let r : f16 = tint_saturate(v);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillSaturate());  // run with saturate set
+
+ EXPECT_EQ(expect, str(got));  // expected helper is clamp(v, 0, 1) in f16; the enable directive stays first
+}
+
+TEST_F(BuiltinPolyfillTest, Saturate_vec3_f32) {  // saturate() on vec3<f32>
+ auto* src = R"(
+fn f() {
+ let v = 0.5f;
+ let r : vec3<f32> = saturate(vec3<f32>(v));
+}
+)";
+
+ auto* expect = R"(
+fn tint_saturate(v : vec3<f32>) -> vec3<f32> {
+ return clamp(v, vec3<f32>(0), vec3<f32>(1));
+}
+
+fn f() {
+ let v = 0.5f;
+ let r : vec3<f32> = tint_saturate(vec3<f32>(v));
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillSaturate());  // run with saturate set
+
+ EXPECT_EQ(expect, str(got));  // expected helper clamps against vec3<f32>(0) and vec3<f32>(1)
+}
+
+TEST_F(BuiltinPolyfillTest, Saturate_vec3_f16) {  // saturate() on vec3<f16>
+ auto* src = R"(
+enable f16;
+
+fn f() {
+ let v = 0.5h;
+ let r : vec3<f16> = saturate(vec3<f16>(v));
+}
+)";
+
+ auto* expect = R"(
+enable f16;
+
+fn tint_saturate(v : vec3<f16>) -> vec3<f16> {
+ return clamp(v, vec3<f16>(0), vec3<f16>(1));
+}
+
+fn f() {
+ let v = 0.5h;
+ let r : vec3<f16> = tint_saturate(vec3<f16>(v));
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillSaturate());  // run with saturate set
+
+ EXPECT_EQ(expect, str(got));  // expected helper clamps against vec3<f16>(0) and vec3<f16>(1)
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// sign_int
+////////////////////////////////////////////////////////////////////////////////
+Transform::DataMap polyfillSignInt() {  // DataMap whose BuiltinPolyfill::Config sets sign_int
+ Transform::DataMap config_data;
+ BuiltinPolyfill::Builtins enabled;
+ enabled.sign_int = true;
+ config_data.Add<BuiltinPolyfill::Config>(enabled);
+ return config_data;
+}
+
+TEST_F(BuiltinPolyfillTest, ShouldRunSign_i32) {  // sign() on a runtime i32 value
+ auto* src = R"(
+fn f() {
+ let v = 1i;
+ _ = sign(v);
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));  // no config data supplied
+ EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillSignInt()));  // sign_int set
+}
+
+TEST_F(BuiltinPolyfillTest, ShouldRunSign_f32) {  // sign() on an f32 value: not an integer operand
+ auto* src = R"(
+fn f() {
+ let v = 1f;
+ _ = sign(v);
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));  // no config data supplied
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src, polyfillSignInt()));  // still not expected to run
+}
+
+TEST_F(BuiltinPolyfillTest, SignInt_ConstantExpression) {  // sign() called with a literal argument
+ auto* src = R"(
+fn f() {
+ let r : i32 = sign(1i);
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src, polyfillSignInt()));  // not expected to run even with the flag set
+}
+
+TEST_F(BuiltinPolyfillTest, SignInt_i32) {  // sign() on an i32 scalar
+ auto* src = R"(
+fn f() {
+ let v = 1i;
+ let r : i32 = sign(v);
+}
+)";
+
+ auto* expect = R"(
+fn tint_sign(v : i32) -> i32 {
+ return select(select(-1, 1, (v > 0)), 0, (v == 0));
+}
+
+fn f() {
+ let v = 1i;
+ let r : i32 = tint_sign(v);
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillSignInt());  // run with sign_int set
+
+ EXPECT_EQ(expect, str(got));  // expected helper yields 0 when v == 0, 1 when v > 0, otherwise -1
+}
+
+TEST_F(BuiltinPolyfillTest, SignInt_vec3_i32) {  // sign() on vec3<i32>
+ auto* src = R"(
+fn f() {
+ let v = 1i;
+ let r : vec3<i32> = sign(vec3<i32>(v));
+}
+)";
+
+ auto* expect = R"(
+fn tint_sign(v : vec3<i32>) -> vec3<i32> {
+ return select(select(vec3(-1), vec3(1), (v > vec3(0))), vec3(0), (v == vec3(0)));
+}
+
+fn f() {
+ let v = 1i;
+ let r : vec3<i32> = tint_sign(vec3<i32>(v));
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillSignInt());  // run with sign_int set
+
+ EXPECT_EQ(expect, str(got));  // same nested-select shape as the scalar case, with vec3 constants
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// textureSampleBaseClampToEdge
+////////////////////////////////////////////////////////////////////////////////
+Transform::DataMap polyfillTextureSampleBaseClampToEdge_2d_f32() {  // Config sets texture_sample_base_clamp_to_edge_2d_f32
+ Transform::DataMap config_data;
+ BuiltinPolyfill::Builtins enabled;
+ enabled.texture_sample_base_clamp_to_edge_2d_f32 = true;
+ config_data.Add<BuiltinPolyfill::Config>(enabled);
+ return config_data;
+}
+
+TEST_F(BuiltinPolyfillTest, ShouldRunTextureSampleBaseClampToEdge_2d_f32) {  // call on a texture_2d<f32>
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+@group(0) @binding(1) var s : sampler;
+
+fn f() {
+ _ = textureSampleBaseClampToEdge(t, s, vec2<f32>(0.5));
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));  // no config data supplied
+ EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillTextureSampleBaseClampToEdge_2d_f32()));  // flag set
+}
+
+TEST_F(BuiltinPolyfillTest, ShouldRunTextureSampleBaseClampToEdge_external) {  // call on a texture_external
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_external;
+@group(0) @binding(1) var s : sampler;
+
+fn f() {
+ _ = textureSampleBaseClampToEdge(t, s, vec2<f32>(0.5));
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));  // no config data supplied
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src, polyfillTextureSampleBaseClampToEdge_2d_f32()));  // 2d_f32 flag does not cover it
+}
+
+TEST_F(BuiltinPolyfillTest, TextureSampleBaseClampToEdge_2d_f32_f32) {  // rewrite of the texture_2d<f32> call
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+@group(0) @binding(1) var s : sampler;
+
+fn f() {
+ let r = textureSampleBaseClampToEdge(t, s, vec2<f32>(0.5));
+}
+)";
+
+ auto* expect = R"(
+fn tint_textureSampleBaseClampToEdge(t : texture_2d<f32>, s : sampler, coord : vec2<f32>) -> vec4<f32> {
+ let dims = vec2<f32>(textureDimensions(t, 0));
+ let half_texel = (vec2<f32>(0.5) / dims);
+ let clamped = clamp(coord, half_texel, (1 - half_texel));
+ return textureSampleLevel(t, s, clamped, 0);
+}
+
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+fn f() {
+ let r = tint_textureSampleBaseClampToEdge(t, s, vec2<f32>(0.5));
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillTextureSampleBaseClampToEdge_2d_f32());  // run with the 2d_f32 flag set
+
+ EXPECT_EQ(expect, str(got));  // helper clamps coord to [half_texel, 1 - half_texel] and samples level 0
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// workgroupUniformLoad
+////////////////////////////////////////////////////////////////////////////////
+Transform::DataMap polyfillWorkgroupUniformLoad() {  // DataMap whose BuiltinPolyfill::Config sets workgroup_uniform_load
+ Transform::DataMap config_data;
+
+ BuiltinPolyfill::Builtins enabled;
+ enabled.workgroup_uniform_load = true;
+ config_data.Add<BuiltinPolyfill::Config>(enabled);
+
+ return config_data;
+}
+
+Transform::DataMap polyfillWorkgroupUniformLoadWithDirectVariableAccess() {  // configs for both transforms
+ BuiltinPolyfill::Builtins enabled;
+ enabled.workgroup_uniform_load = true;
+
+ DirectVariableAccess::Options dva_options;  // default-constructed options
+
+ Transform::DataMap config_data;
+ config_data.Add<BuiltinPolyfill::Config>(enabled);
+ config_data.Add<DirectVariableAccess::Config>(dva_options);
+
+ return config_data;
+}
+
+TEST_F(BuiltinPolyfillTest, ShouldRunWorkgroupUniformLoad) {  // workgroupUniformLoad() on a workgroup i32
+ auto* src = R"(
+var<workgroup> v : i32;
+
+fn f() {
+ _ = workgroupUniformLoad(&v);
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));  // no config data supplied
+ EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillWorkgroupUniformLoad()));  // workgroup_uniform_load set
+}
+
+TEST_F(BuiltinPolyfillTest, WorkgroupUniformLoad_i32) {  // rewrite of workgroupUniformLoad() on an i32
+ auto* src = R"(
+var<workgroup> v : i32;
+
+fn f() {
+ let r = workgroupUniformLoad(&v);
+}
+)";
+
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+fn tint_workgroupUniformLoad(p : ptr<workgroup, i32>) -> i32 {
+ workgroupBarrier();
+ let result = *(p);
+ workgroupBarrier();
+ return result;
+}
+
+var<workgroup> v : i32;
+
+fn f() {
+ let r = tint_workgroupUniformLoad(&(v));
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillWorkgroupUniformLoad());  // run with workgroup_uniform_load set
+
+ EXPECT_EQ(expect, str(got));  // helper loads through the pointer between two workgroupBarrier() calls; an enable is added
+}
+
+TEST_F(BuiltinPolyfillTest, WorkgroupUniformLoad_ComplexType) {  // load of a struct holding an array of structs
+ auto* src = R"(
+struct Inner {
+ b : bool,
+ v : vec4<i32>,
+ m : mat3x3<f32>,
+}
+
+struct Outer {
+ a : array<Inner, 4>,
+}
+
+var<workgroup> v : Outer;
+
+fn f() {
+ let r = workgroupUniformLoad(&v);
+}
+)";
+
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+fn tint_workgroupUniformLoad(p : ptr<workgroup, Outer>) -> Outer {
+ workgroupBarrier();
+ let result = *(p);
+ workgroupBarrier();
+ return result;
+}
+
+struct Inner {
+ b : bool,
+ v : vec4<i32>,
+ m : mat3x3<f32>,
+}
+
+struct Outer {
+ a : array<Inner, 4>,
+}
+
+var<workgroup> v : Outer;
+
+fn f() {
+ let r = tint_workgroupUniformLoad(&(v));
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillWorkgroupUniformLoad());  // run with workgroup_uniform_load set
+
+ EXPECT_EQ(expect, str(got));  // helper is emitted ahead of the struct declarations it refers to
+}
+
+TEST_F(BuiltinPolyfillTest, WorkgroupUniformLoad_AvoidDuplicateEnables) {  // input already has the enable directive
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+var<workgroup> a : i32;
+var<workgroup> b : u32;
+var<workgroup> c : f32;
+
+fn f() {
+ let ra = workgroupUniformLoad(&a);
+ let rb = workgroupUniformLoad(&b);
+ let rc = workgroupUniformLoad(&c);
+}
+)";
+
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+fn tint_workgroupUniformLoad(p : ptr<workgroup, i32>) -> i32 {
+ workgroupBarrier();
+ let result = *(p);
+ workgroupBarrier();
+ return result;
+}
+
+fn tint_workgroupUniformLoad_1(p : ptr<workgroup, u32>) -> u32 {
+ workgroupBarrier();
+ let result = *(p);
+ workgroupBarrier();
+ return result;
+}
+
+fn tint_workgroupUniformLoad_2(p : ptr<workgroup, f32>) -> f32 {
+ workgroupBarrier();
+ let result = *(p);
+ workgroupBarrier();
+ return result;
+}
+
+var<workgroup> a : i32;
+
+var<workgroup> b : u32;
+
+var<workgroup> c : f32;
+
+fn f() {
+ let ra = tint_workgroupUniformLoad(&(a));
+ let rb = tint_workgroupUniformLoad_1(&(b));
+ let rc = tint_workgroupUniformLoad_2(&(c));
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillWorkgroupUniformLoad());  // run with workgroup_uniform_load set
+
+ EXPECT_EQ(expect, str(got));  // output has a single enable line and one helper per distinct store type
+}
+
+TEST_F(BuiltinPolyfillTest, WorkgroupUniformLoad_DirectVariableAccess) {  // polyfill followed by DirectVariableAccess
+ auto* src = R"(
+var<workgroup> v : i32;
+var<workgroup> v2 : i32;
+
+fn f() {
+ let r = workgroupUniformLoad(&v);
+ let s = workgroupUniformLoad(&v2);
+}
+)";
+
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+fn tint_workgroupUniformLoad_v() -> i32 {
+ workgroupBarrier();
+ let result = v;
+ workgroupBarrier();
+ return result;
+}
+
+fn tint_workgroupUniformLoad_v2() -> i32 {
+ workgroupBarrier();
+ let result = v2;
+ workgroupBarrier();
+ return result;
+}
+
+var<workgroup> v : i32;
+
+var<workgroup> v2 : i32;
+
+fn f() {
+ let r = tint_workgroupUniformLoad_v();
+ let s = tint_workgroupUniformLoad_v2();
+}
+)";
+
+ auto got = Run<BuiltinPolyfill, DirectVariableAccess>(  // both transforms are named in the Run<> call
+ src, polyfillWorkgroupUniformLoadWithDirectVariableAccess());
+
+ EXPECT_EQ(expect, str(got));  // expected helpers take no pointer parameter and read v / v2 directly
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// quantizeToF16
+////////////////////////////////////////////////////////////////////////////////
+Transform::DataMap polyfillQuantizeToF16_2d_f32() {  // DataMap whose BuiltinPolyfill::Config sets quantize_to_vec_f16
+ Transform::DataMap config_data;
+ BuiltinPolyfill::Builtins enabled;
+ enabled.quantize_to_vec_f16 = true;
+ config_data.Add<BuiltinPolyfill::Config>(enabled);
+ return config_data;
+}
+
+TEST_F(BuiltinPolyfillTest, ShouldRunQuantizeToF16_Scalar) {  // quantizeToF16() on a runtime scalar
+ auto* src = R"(
+fn f() {
+ let v = 0.5;
+ _ = quantizeToF16(v);
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));  // no config data supplied
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src, polyfillQuantizeToF16_2d_f32()));  // vector-only flag leaves scalars alone
+}
+
+TEST_F(BuiltinPolyfillTest, ShouldRunQuantizeToF16_Vector) {  // quantizeToF16() on a vec2 built from a runtime value
+ auto* src = R"(
+fn f() {
+ let v = 0.5;
+ _ = quantizeToF16(vec2(v));
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));  // no config data supplied
+ EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillQuantizeToF16_2d_f32()));  // quantize_to_vec_f16 set
+}
+
+TEST_F(BuiltinPolyfillTest, QuantizeToF16_Vec2) {  // rewrite of quantizeToF16() on a 2-element vector
+ auto* src = R"(
+fn f() {
+ let v = 0.5;
+ _ = quantizeToF16(vec2(v));
+}
+)";
+
+ auto* expect = R"(
+fn tint_quantizeToF16(v : vec2<f32>) -> vec2<f32> {
+ return vec2<f32>(quantizeToF16(v[0u]), quantizeToF16(v[1u]));
+}
+
+fn f() {
+ let v = 0.5;
+ _ = tint_quantizeToF16(vec2(v));
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillQuantizeToF16_2d_f32());  // run with quantize_to_vec_f16 set
+
+ EXPECT_EQ(expect, str(got));  // expected helper applies scalar quantizeToF16() to each element
+}
+
+TEST_F(BuiltinPolyfillTest, QuantizeToF16_Vec3) {  // rewrite of quantizeToF16() on a 3-element vector
+ auto* src = R"(
+fn f() {
+ let v = 0.5;
+ _ = quantizeToF16(vec3(v));
+}
+)";
+
+ auto* expect = R"(
+fn tint_quantizeToF16(v : vec3<f32>) -> vec3<f32> {
+ return vec3<f32>(quantizeToF16(v[0u]), quantizeToF16(v[1u]), quantizeToF16(v[2u]));
+}
+
+fn f() {
+ let v = 0.5;
+ _ = tint_quantizeToF16(vec3(v));
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillQuantizeToF16_2d_f32());  // run with quantize_to_vec_f16 set
+
+ EXPECT_EQ(expect, str(got));  // expected helper applies scalar quantizeToF16() to each element
+}
+
+TEST_F(BuiltinPolyfillTest, QuantizeToF16_Vec4) {  // rewrite of quantizeToF16() on a 4-element vector
+ auto* src = R"(
+fn f() {
+ let v = 0.5;
+ _ = quantizeToF16(vec4(v));
+}
+)";
+
+ auto* expect = R"(
+fn tint_quantizeToF16(v : vec4<f32>) -> vec4<f32> {
+ return vec4<f32>(quantizeToF16(v[0u]), quantizeToF16(v[1u]), quantizeToF16(v[2u]), quantizeToF16(v[3u]));
+}
+
+fn f() {
+ let v = 0.5;
+ _ = tint_quantizeToF16(vec4(v));
+}
+)";
+
+ auto got = Run<BuiltinPolyfill>(src, polyfillQuantizeToF16_2d_f32());  // run with quantize_to_vec_f16 set
+
+ EXPECT_EQ(expect, str(got));  // expected helper applies scalar quantizeToF16() to each element
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Polyfill combinations
+////////////////////////////////////////////////////////////////////////////////
+
+TEST_F(BuiltinPolyfillTest, BitshiftAndModulo) {  // a shift whose amount is itself a % expression, both polyfills on
+ auto* src = R"(
+fn f(x : i32, y : u32, z : u32) {
+ let l = x << (y % z);
+}
+)";
+
+ auto* expect = R"(
+fn tint_mod(lhs : u32, rhs : u32) -> u32 {
+ return (lhs % select(rhs, 1, (rhs == 0)));
+}
+
+fn f(x : i32, y : u32, z : u32) {
+ let l = (x << (tint_mod(y, z) & 31));
+}
+)";
+
+ BuiltinPolyfill::Builtins builtins;
+ builtins.bitshift_modulo = true;  // expected output masks the shift amount with & 31
+ builtins.int_div_mod = true;  // expected output routes % through tint_mod
+ Transform::DataMap data;
+ data.Add<BuiltinPolyfill::Config>(builtins);
+
+ auto got = Run<BuiltinPolyfill>(src, std::move(data));
+
+ EXPECT_EQ(expect, str(got));  // the mask is applied to the result of the tint_mod() call
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/calculate_array_length.cc b/src/tint/lang/wgsl/ast/transform/calculate_array_length.cc
new file mode 100644
index 0000000..7844aca
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/calculate_array_length.cc
@@ -0,0 +1,246 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/calculate_array_length.h"
+
+#include <unordered_map>
+#include <utility>
+
+#include "src/tint/lang/wgsl/ast/call_statement.h"
+#include "src/tint/lang/wgsl/ast/disable_validation_attribute.h"
+#include "src/tint/lang/wgsl/ast/transform/simplify_pointers.h"
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/block_statement.h"
+#include "src/tint/sem/call.h"
+#include "src/tint/sem/function.h"
+#include "src/tint/sem/statement.h"
+#include "src/tint/sem/struct.h"
+#include "src/tint/sem/variable.h"
+#include "src/tint/switch.h"
+#include "src/tint/type/reference.h"
+#include "src/tint/utils/hash.h"
+#include "src/tint/utils/map.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::CalculateArrayLength);
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::CalculateArrayLength::BufferSizeIntrinsic);
+
+namespace tint::ast::transform {
+namespace {
+
+using namespace tint::builtin::fluent_types; // NOLINT
+using namespace tint::number_suffixes; // NOLINT
+
+bool ShouldRun(const Program* program) {  // true if any function's directly-called builtins include arrayLength
+ for (auto* func_decl : program->AST().Functions()) {
+ auto* func_sem = program->Sem().Get(func_decl);
+ if (func_sem == nullptr) {
+ continue;  // no semantic node for this function
+ }
+ for (auto* called : func_sem->DirectlyCalledBuiltins()) {
+ if (builtin::Function::kArrayLength == called->Type()) { return true; }
+ }
+ }
+ return false;
+}
+
+/// ArrayUsage identifies a runtime array usage by its (block, buffer) pair.
+/// It is the key type of the array_length_by_usage map.
+struct ArrayUsage {
+ const BlockStatement* const block;
+ const sem::Variable* const buffer;
+ bool operator==(const ArrayUsage& other) const {
+ return buffer == other.buffer && block == other.block;
+ }
+ struct Hasher {
+ std::size_t operator()(const ArrayUsage& usage) const {
+ return utils::Hash(usage.block, usage.buffer);
+ }
+ };
+};
+
+} // namespace
+
+CalculateArrayLength::BufferSizeIntrinsic::BufferSizeIntrinsic(ProgramID pid, NodeID nid)
+ : Base(pid, nid, utils::Empty) {}  // forwards both ids to the base class, with utils::Empty as the third argument
+CalculateArrayLength::BufferSizeIntrinsic::~BufferSizeIntrinsic() = default;
+std::string CalculateArrayLength::BufferSizeIntrinsic::InternalName() const {
+ return "intrinsic_buffer_size";  // fixed name for this attribute
+}
+
+const CalculateArrayLength::BufferSizeIntrinsic* CalculateArrayLength::BufferSizeIntrinsic::Clone(
+ CloneContext* ctx) const {
+ auto* dst = ctx->dst;  // the new node is created in the destination builder with a fresh node id
+ return dst->ASTNodes().Create<BufferSizeIntrinsic>(dst->ID(), dst->AllocateNodeID());
+}
+
+CalculateArrayLength::CalculateArrayLength() = default;  // compiler-generated constructor
+CalculateArrayLength::~CalculateArrayLength() = default;  // compiler-generated destructor
+
+/// Replaces each arrayLength() call in `src` with a value computed from the storage buffer size.
+Transform::ApplyResult CalculateArrayLength::Apply(const Program* src,
+ const DataMap&,
+ DataMap&) const {
+ if (!ShouldRun(src)) {
+ return SkipTransform;  // no function directly calls arrayLength()
+ }
+
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+ auto& sem = src->Sem();
+
+ // get_buffer_size_intrinsic() emits the function decorated with BufferSizeIntrinsic that is
+ // transformed by the HLSL writer into a call to [RW]ByteAddressBuffer.GetDimensions().
+ std::unordered_map<const type::Reference*, Symbol> buffer_size_intrinsics;
+ auto get_buffer_size_intrinsic = [&](const type::Reference* buffer_type) {
+ return utils::GetOrCreate(buffer_size_intrinsics, buffer_type, [&] {
+ auto name = b.Sym();
+ auto type = CreateASTTypeFor(ctx, buffer_type);
+ auto* disable_validation = b.Disable(DisabledValidation::kFunctionParameter);
+ b.Func(name,
+ utils::Vector{
+ b.Param("buffer",
+ b.ty.ptr(buffer_type->AddressSpace(), type, buffer_type->Access()),
+ utils::Vector{disable_validation}),
+ b.Param("result", b.ty.ptr<function, u32>()),
+ },
+ b.ty.void_(), nullptr,
+ utils::Vector{
+ b.ASTNodes().Create<BufferSizeIntrinsic>(b.ID(), b.AllocateNodeID()),
+ });
+
+ return name;
+ });
+ };
+
+ std::unordered_map<ArrayUsage, Symbol, ArrayUsage::Hasher> array_length_by_usage;
+
+ // Find all the arrayLength() calls...
+ for (auto* node : src->ASTNodes().Objects()) {
+ if (auto* call_expr = node->As<CallExpression>()) {
+ auto* call = sem.Get(call_expr)->UnwrapMaterialize()->As<sem::Call>();
+ if (auto* builtin = call->Target()->As<sem::Builtin>()) {
+ if (builtin->Type() == builtin::Function::kArrayLength) {
+ // We're dealing with an arrayLength() call
+
+ if (auto* call_stmt = call->Stmt()->Declaration()->As<CallStatement>()) {
+ if (call_stmt->expr == call_expr) {
+ // arrayLength() is used as a statement. The argument expression must be
+ // side-effect free, so just drop the statement.
+ RemoveStatement(ctx, call_stmt);
+ continue;
+ }
+ }
+
+ // A runtime-sized array can only appear as the store type of a variable, or the
+ // last element of a structure (which cannot itself be nested). Given that we
+ // require SimplifyPointers, we can assume that the arrayLength() call has one
+ // of two forms:
+ // arrayLength(&struct_var.array_member)
+ // arrayLength(&array_var)
+ auto* arg = call_expr->args[0];
+ auto* address_of = arg->As<UnaryOpExpression>();
+ if (TINT_UNLIKELY(!address_of || address_of->op != UnaryOp::kAddressOf)) {
+ TINT_ICE(Transform, b.Diagnostics())
+ << "arrayLength() expected address-of, got " << arg->TypeInfo().name;
+ break;  // address_of may be null here; stop, as the next check does
+ }
+ auto* storage_buffer_expr = address_of->expr;
+ if (auto* accessor = storage_buffer_expr->As<MemberAccessorExpression>()) {
+ storage_buffer_expr = accessor->object;
+ }
+ auto* storage_buffer_sem = sem.Get<sem::VariableUser>(storage_buffer_expr);
+ if (TINT_UNLIKELY(!storage_buffer_sem)) {
+ TINT_ICE(Transform, b.Diagnostics())
+ << "expected form of arrayLength argument to be &array_var or "
+ "&struct_var.array_member";
+ break;
+ }
+ auto* storage_buffer_var = storage_buffer_sem->Variable();
+ auto* storage_buffer_type = storage_buffer_sem->Type()->As<type::Reference>();
+
+ // Generate BufferSizeIntrinsic for this storage type if we haven't already
+ auto buffer_size = get_buffer_size_intrinsic(storage_buffer_type);
+
+ // Find the current statement block
+ auto* block = call->Stmt()->Block()->Declaration();
+
+ auto array_length =
+ utils::GetOrCreate(array_length_by_usage, {block, storage_buffer_var}, [&] {
+ // First time this array length is used for this block.
+ // Let's calculate it.
+
+ // Construct the variable that'll hold the result of
+ // RWByteAddressBuffer.GetDimensions()
+ auto* buffer_size_result =
+ b.Decl(b.Var(b.Sym(), b.ty.u32(), b.Expr(0_u)));
+
+ // Call storage_buffer.GetDimensions(&buffer_size_result)
+ auto* call_get_dims = b.CallStmt(b.Call(
+ // BufferSizeIntrinsic(X, ARGS...) is
+ // translated to:
+ // X.GetDimensions(ARGS..) by the writer
+ buffer_size, b.AddressOf(ctx.Clone(storage_buffer_expr)),
+ b.AddressOf(b.Expr(buffer_size_result->variable->name->symbol))));
+
+ // Calculate actual array length
+ // total_storage_buffer_size - array_offset
+ // array_length = ----------------------------------------
+ // array_stride
+ auto name = b.Sym();
+ const Expression* total_size = b.Expr(buffer_size_result->variable);
+
+ const type::Array* array_type = Switch(
+ storage_buffer_type->StoreType(),
+ [&](const type::Struct* str) {
+ // The variable is a struct, so subtract the byte offset of
+ // the array member.
+ auto* array_member_sem = str->Members().Back();
+ total_size = b.Sub(total_size, u32(array_member_sem->Offset()));
+ return array_member_sem->Type()->As<type::Array>();
+ },
+ [&](const type::Array* arr) { return arr; });
+
+ if (TINT_UNLIKELY(!array_type)) {
+ TINT_ICE(Transform, b.Diagnostics())
+ << "expected form of arrayLength argument to be "
+ "&array_var or &struct_var.array_member";
+ return name;
+ }
+
+ uint32_t array_stride = array_type->Size();
+ auto* array_length_var = b.Decl(
+ b.Let(name, b.ty.u32(), b.Div(total_size, u32(array_stride))));
+
+ // Insert the array length calculations at the top of the block
+ ctx.InsertBefore(block->statements, block->statements[0],
+ buffer_size_result);
+ ctx.InsertBefore(block->statements, block->statements[0],
+ call_get_dims);
+ ctx.InsertBefore(block->statements, block->statements[0],
+ array_length_var);
+ return name;
+ });
+
+ // Replace the call to arrayLength() with the array length variable
+ ctx.Replace(call_expr, b.Expr(array_length));
+ }
+ }
+ }
+ }
+
+ ctx.Clone();
+ return Program(std::move(b));
+}
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/calculate_array_length.h b/src/tint/lang/wgsl/ast/transform/calculate_array_length.h
new file mode 100644
index 0000000..c9139ff
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/calculate_array_length.h
@@ -0,0 +1,71 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_CALCULATE_ARRAY_LENGTH_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_CALCULATE_ARRAY_LENGTH_H_
+
+#include <string>
+
+#include "src/tint/lang/wgsl/ast/internal_attribute.h"
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+
+// Forward declarations
+namespace tint {
+class CloneContext;
+} // namespace tint
+
+namespace tint::ast::transform {
+
+/// CalculateArrayLength is a transform that replaces calls to arrayLength() with a value
+/// computed from the storage buffer's byte size: (buffer_size - array_offset) / array_stride.
+///
+/// @note Depends on the following transforms to have been run first:
+/// * SimplifyPointers
+class CalculateArrayLength final : public utils::Castable<CalculateArrayLength, Transform> {
+ public:
+ /// BufferSizeIntrinsic is an InternalAttribute that's applied to intrinsic
+ /// functions used to obtain the runtime size of a storage buffer.
+ class BufferSizeIntrinsic final
+ : public utils::Castable<BufferSizeIntrinsic, InternalAttribute> {
+ public:
+ /// Constructor
+ /// @param program_id the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ BufferSizeIntrinsic(ProgramID program_id, NodeID nid);
+ /// Destructor
+ ~BufferSizeIntrinsic() override;
+
+ /// @return the attribute's internal name (NOTE(review): tests print "intrinsic_buffer_size", not "buffer_size" - confirm)
+ std::string InternalName() const override;
+
+ /// Performs a deep clone of this object using the CloneContext `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned object
+ const BufferSizeIntrinsic* Clone(CloneContext* ctx) const override;
+ };
+
+ /// Constructor
+ CalculateArrayLength();
+ /// Destructor
+ ~CalculateArrayLength() override;
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_CALCULATE_ARRAY_LENGTH_H_
diff --git a/src/tint/lang/wgsl/ast/transform/calculate_array_length_test.cc b/src/tint/lang/wgsl/ast/transform/calculate_array_length_test.cc
new file mode 100644
index 0000000..38115a2
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/calculate_array_length_test.cc
@@ -0,0 +1,551 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/calculate_array_length.h"
+
+#include "src/tint/lang/wgsl/ast/transform/simplify_pointers.h"
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+#include "src/tint/lang/wgsl/ast/transform/unshadow.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using CalculateArrayLengthTest = TransformTest;
+
+TEST_F(CalculateArrayLengthTest, ShouldRunEmptyModule) {  // Empty source: ShouldRun must report false.
+ auto* src = R"()";
+
+ EXPECT_FALSE(ShouldRun<CalculateArrayLength>(src));
+}
+
+TEST_F(CalculateArrayLengthTest, ShouldRunNoArrayLength) {  // Runtime-sized array but no arrayLength() call: ShouldRun must report false.
+ auto* src = R"(
+struct SB {
+ x : i32,
+ arr : array<i32>,
+};
+
+@group(0) @binding(0) var<storage, read> sb : SB;
+
+@compute @workgroup_size(1)
+fn main() {
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<CalculateArrayLength>(src));
+}
+
+TEST_F(CalculateArrayLengthTest, ShouldRunWithArrayLength) {  // One arrayLength() call: ShouldRun must report true.
+ auto* src = R"(
+struct SB {
+ x : i32,
+ arr : array<i32>,
+};
+
+@group(0) @binding(0) var<storage, read> sb : SB;
+
+@compute @workgroup_size(1)
+fn main() {
+ var len : u32 = arrayLength(&sb.arr);
+}
+)";
+
+ EXPECT_TRUE(ShouldRun<CalculateArrayLength>(src));
+}
+
+TEST_F(CalculateArrayLengthTest, BasicArray) {  // Buffer is itself array<i32>: expected length is (buffer size / 4u).
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read> sb : array<i32>;
+
+@compute @workgroup_size(1)
+fn main() {
+ var len : u32 = arrayLength(&sb);
+}
+)";
+
+ auto* expect = R"(
+@internal(intrinsic_buffer_size)
+fn tint_symbol(@internal(disable_validation__function_parameter) buffer : ptr<storage, array<i32>, read>, result : ptr<function, u32>)
+
+@group(0) @binding(0) var<storage, read> sb : array<i32>;
+
+@compute @workgroup_size(1)
+fn main() {
+ var tint_symbol_1 : u32 = 0u;
+ tint_symbol(&(sb), &(tint_symbol_1));
+ let tint_symbol_2 : u32 = (tint_symbol_1 / 4u);
+ var len : u32 = tint_symbol_2;
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);  // transforms listed in the header @note run first
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CalculateArrayLengthTest, BasicInStruct) {  // Array is the last struct member: its 4-byte offset is subtracted before dividing.
+ auto* src = R"(
+struct SB {
+ x : i32,
+ arr : array<i32>,
+};
+
+@group(0) @binding(0) var<storage, read> sb : SB;
+
+@compute @workgroup_size(1)
+fn main() {
+ var len : u32 = arrayLength(&sb.arr);
+}
+)";
+
+ auto* expect = R"(
+@internal(intrinsic_buffer_size)
+fn tint_symbol(@internal(disable_validation__function_parameter) buffer : ptr<storage, SB, read>, result : ptr<function, u32>)
+
+struct SB {
+ x : i32,
+ arr : array<i32>,
+}
+
+@group(0) @binding(0) var<storage, read> sb : SB;
+
+@compute @workgroup_size(1)
+fn main() {
+ var tint_symbol_1 : u32 = 0u;
+ tint_symbol(&(sb), &(tint_symbol_1));
+ let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u);
+ var len : u32 = tint_symbol_2;
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CalculateArrayLengthTest, ArrayOfStruct) {  // Runtime-sized array of a one-f32 struct: expected divisor is 4u.
+ auto* src = R"(
+struct S {
+ f : f32,
+}
+
+@group(0) @binding(0) var<storage, read> arr : array<S>;
+
+@compute @workgroup_size(1)
+fn main() {
+ let len = arrayLength(&arr);
+}
+)";
+ auto* expect = R"(
+@internal(intrinsic_buffer_size)
+fn tint_symbol(@internal(disable_validation__function_parameter) buffer : ptr<storage, array<S>, read>, result : ptr<function, u32>)
+
+struct S {
+ f : f32,
+}
+
+@group(0) @binding(0) var<storage, read> arr : array<S>;
+
+@compute @workgroup_size(1)
+fn main() {
+ var tint_symbol_1 : u32 = 0u;
+ tint_symbol(&(arr), &(tint_symbol_1));
+ let tint_symbol_2 : u32 = (tint_symbol_1 / 4u);
+ let len = tint_symbol_2;
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CalculateArrayLengthTest, ArrayOfArrayOfStruct) {  // Element type is array<S, 4>: expected divisor is 16u.
+ auto* src = R"(
+struct S {
+ f : f32,
+}
+
+@group(0) @binding(0) var<storage, read> arr : array<array<S, 4>>;
+
+@compute @workgroup_size(1)
+fn main() {
+ let len = arrayLength(&arr);
+}
+)";
+ auto* expect = R"(
+@internal(intrinsic_buffer_size)
+fn tint_symbol(@internal(disable_validation__function_parameter) buffer : ptr<storage, array<array<S, 4u>>, read>, result : ptr<function, u32>)
+
+struct S {
+ f : f32,
+}
+
+@group(0) @binding(0) var<storage, read> arr : array<array<S, 4>>;
+
+@compute @workgroup_size(1)
+fn main() {
+ var tint_symbol_1 : u32 = 0u;
+ tint_symbol(&(arr), &(tint_symbol_1));
+ let tint_symbol_2 : u32 = (tint_symbol_1 / 16u);
+ let len = tint_symbol_2;
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CalculateArrayLengthTest, InSameBlock) {  // Three calls on one buffer in one block share a single computed length. NOTE(review): the WGSL source ends its declaration with ';;' - presumably a typo, confirm.
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read> sb : array<i32>;;
+
+@compute @workgroup_size(1)
+fn main() {
+ var a : u32 = arrayLength(&sb);
+ var b : u32 = arrayLength(&sb);
+ var c : u32 = arrayLength(&sb);
+}
+)";
+
+ auto* expect = R"(
+@internal(intrinsic_buffer_size)
+fn tint_symbol(@internal(disable_validation__function_parameter) buffer : ptr<storage, array<i32>, read>, result : ptr<function, u32>)
+
+@group(0) @binding(0) var<storage, read> sb : array<i32>;
+
+@compute @workgroup_size(1)
+fn main() {
+ var tint_symbol_1 : u32 = 0u;
+ tint_symbol(&(sb), &(tint_symbol_1));
+ let tint_symbol_2 : u32 = (tint_symbol_1 / 4u);
+ var a : u32 = tint_symbol_2;
+ var b : u32 = tint_symbol_2;
+ var c : u32 = tint_symbol_2;
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CalculateArrayLengthTest, InSameBlock_Struct) {  // Same as InSameBlock, but the array is a struct member (offset 4u subtracted).
+ auto* src = R"(
+struct SB {
+ x : i32,
+ arr : array<i32>,
+};
+
+@group(0) @binding(0) var<storage, read> sb : SB;
+
+@compute @workgroup_size(1)
+fn main() {
+ var a : u32 = arrayLength(&sb.arr);
+ var b : u32 = arrayLength(&sb.arr);
+ var c : u32 = arrayLength(&sb.arr);
+}
+)";
+
+ auto* expect = R"(
+@internal(intrinsic_buffer_size)
+fn tint_symbol(@internal(disable_validation__function_parameter) buffer : ptr<storage, SB, read>, result : ptr<function, u32>)
+
+struct SB {
+ x : i32,
+ arr : array<i32>,
+}
+
+@group(0) @binding(0) var<storage, read> sb : SB;
+
+@compute @workgroup_size(1)
+fn main() {
+ var tint_symbol_1 : u32 = 0u;
+ tint_symbol(&(sb), &(tint_symbol_1));
+ let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u);
+ var a : u32 = tint_symbol_2;
+ var b : u32 = tint_symbol_2;
+ var c : u32 = tint_symbol_2;
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CalculateArrayLengthTest, Nested) {  // Calls in different blocks: each block gets its own length computation at its top.
+ auto* src = R"(
+struct SB {
+ x : i32,
+ arr : array<i32>,
+};
+
+@group(0) @binding(0) var<storage, read> sb : SB;
+
+@compute @workgroup_size(1)
+fn main() {
+ if (true) {
+ var len : u32 = arrayLength(&sb.arr);
+ } else {
+ if (true) {
+ var len : u32 = arrayLength(&sb.arr);
+ }
+ }
+}
+)";
+
+ auto* expect = R"(
+@internal(intrinsic_buffer_size)
+fn tint_symbol(@internal(disable_validation__function_parameter) buffer : ptr<storage, SB, read>, result : ptr<function, u32>)
+
+struct SB {
+ x : i32,
+ arr : array<i32>,
+}
+
+@group(0) @binding(0) var<storage, read> sb : SB;
+
+@compute @workgroup_size(1)
+fn main() {
+ if (true) {
+ var tint_symbol_1 : u32 = 0u;
+ tint_symbol(&(sb), &(tint_symbol_1));
+ let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u);
+ var len : u32 = tint_symbol_2;
+ } else {
+ if (true) {
+ var tint_symbol_3 : u32 = 0u;
+ tint_symbol(&(sb), &(tint_symbol_3));
+ let tint_symbol_4 : u32 = ((tint_symbol_3 - 4u) / 4u);
+ var len : u32 = tint_symbol_4;
+ }
+ }
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CalculateArrayLengthTest, MultipleStorageBuffers) {  // Three buffers of different types: one intrinsic and one length computation per buffer.
+ auto* src = R"(
+struct SB1 {
+ x : i32,
+ arr1 : array<i32>,
+};
+
+struct SB2 {
+ x : i32,
+ arr2 : array<vec4<f32>>,
+};
+
+@group(0) @binding(0) var<storage, read> sb1 : SB1;
+
+@group(0) @binding(1) var<storage, read> sb2 : SB2;
+
+@group(0) @binding(2) var<storage, read> sb3 : array<i32>;
+
+@compute @workgroup_size(1)
+fn main() {
+ var len1 : u32 = arrayLength(&(sb1.arr1));
+ var len2 : u32 = arrayLength(&(sb2.arr2));
+ var len3 : u32 = arrayLength(&sb3);
+ var x : u32 = (len1 + len2 + len3);
+}
+)";
+
+ auto* expect = R"(
+@internal(intrinsic_buffer_size)
+fn tint_symbol(@internal(disable_validation__function_parameter) buffer : ptr<storage, SB1, read>, result : ptr<function, u32>)
+
+@internal(intrinsic_buffer_size)
+fn tint_symbol_3(@internal(disable_validation__function_parameter) buffer : ptr<storage, SB2, read>, result : ptr<function, u32>)
+
+@internal(intrinsic_buffer_size)
+fn tint_symbol_6(@internal(disable_validation__function_parameter) buffer : ptr<storage, array<i32>, read>, result : ptr<function, u32>)
+
+struct SB1 {
+ x : i32,
+ arr1 : array<i32>,
+}
+
+struct SB2 {
+ x : i32,
+ arr2 : array<vec4<f32>>,
+}
+
+@group(0) @binding(0) var<storage, read> sb1 : SB1;
+
+@group(0) @binding(1) var<storage, read> sb2 : SB2;
+
+@group(0) @binding(2) var<storage, read> sb3 : array<i32>;
+
+@compute @workgroup_size(1)
+fn main() {
+ var tint_symbol_1 : u32 = 0u;
+ tint_symbol(&(sb1), &(tint_symbol_1));
+ let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u);
+ var tint_symbol_4 : u32 = 0u;
+ tint_symbol_3(&(sb2), &(tint_symbol_4));
+ let tint_symbol_5 : u32 = ((tint_symbol_4 - 16u) / 16u);
+ var tint_symbol_7 : u32 = 0u;
+ tint_symbol_6(&(sb3), &(tint_symbol_7));
+ let tint_symbol_8 : u32 = (tint_symbol_7 / 4u);
+ var len1 : u32 = tint_symbol_2;
+ var len2 : u32 = tint_symbol_5;
+ var len3 : u32 = tint_symbol_8;
+ var x : u32 = ((len1 + len2) + len3);
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CalculateArrayLengthTest, Shadowing) {  // Locals shadow globals `a` and `b`: expected output renames them (a_1, b_1) and resolves `*x` to buffer `a`.
+ auto* src = R"(
+struct SB {
+ x : i32,
+ arr : array<i32>,
+};
+
+@group(0) @binding(0) var<storage, read> a : SB;
+@group(0) @binding(1) var<storage, read> b : SB;
+
+@compute @workgroup_size(1)
+fn main() {
+ let x = &a;
+ var a : u32 = arrayLength(&a.arr);
+ {
+ var b : u32 = arrayLength(&((*x).arr));
+ }
+}
+)";
+
+ auto* expect =
+ R"(
+@internal(intrinsic_buffer_size)
+fn tint_symbol(@internal(disable_validation__function_parameter) buffer : ptr<storage, SB, read>, result : ptr<function, u32>)
+
+struct SB {
+ x : i32,
+ arr : array<i32>,
+}
+
+@group(0) @binding(0) var<storage, read> a : SB;
+
+@group(0) @binding(1) var<storage, read> b : SB;
+
+@compute @workgroup_size(1)
+fn main() {
+ var tint_symbol_1 : u32 = 0u;
+ tint_symbol(&(a), &(tint_symbol_1));
+ let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u);
+ var a_1 : u32 = tint_symbol_2;
+ {
+ var tint_symbol_3 : u32 = 0u;
+ tint_symbol(&(a), &(tint_symbol_3));
+ let tint_symbol_4 : u32 = ((tint_symbol_3 - 4u) / 4u);
+ var b_1 : u32 = tint_symbol_4;
+ }
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CalculateArrayLengthTest, OutOfOrder) {  // Same program as MultipleStorageBuffers, with main() declared before the buffers and structs it uses.
+ auto* src = R"(
+@compute @workgroup_size(1)
+fn main() {
+ var len1 : u32 = arrayLength(&(sb1.arr1));
+ var len2 : u32 = arrayLength(&(sb2.arr2));
+ var len3 : u32 = arrayLength(&sb3);
+ var x : u32 = (len1 + len2 + len3);
+}
+
+@group(0) @binding(0) var<storage, read> sb1 : SB1;
+
+struct SB1 {
+ x : i32,
+ arr1 : array<i32>,
+};
+
+@group(0) @binding(1) var<storage, read> sb2 : SB2;
+
+struct SB2 {
+ x : i32,
+ arr2 : array<vec4<f32>>,
+};
+
+@group(0) @binding(2) var<storage, read> sb3 : array<i32>;
+)";
+
+ auto* expect = R"(
+@internal(intrinsic_buffer_size)
+fn tint_symbol(@internal(disable_validation__function_parameter) buffer : ptr<storage, SB1, read>, result : ptr<function, u32>)
+
+@internal(intrinsic_buffer_size)
+fn tint_symbol_3(@internal(disable_validation__function_parameter) buffer : ptr<storage, SB2, read>, result : ptr<function, u32>)
+
+@internal(intrinsic_buffer_size)
+fn tint_symbol_6(@internal(disable_validation__function_parameter) buffer : ptr<storage, array<i32>, read>, result : ptr<function, u32>)
+
+@compute @workgroup_size(1)
+fn main() {
+ var tint_symbol_1 : u32 = 0u;
+ tint_symbol(&(sb1), &(tint_symbol_1));
+ let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u);
+ var tint_symbol_4 : u32 = 0u;
+ tint_symbol_3(&(sb2), &(tint_symbol_4));
+ let tint_symbol_5 : u32 = ((tint_symbol_4 - 16u) / 16u);
+ var tint_symbol_7 : u32 = 0u;
+ tint_symbol_6(&(sb3), &(tint_symbol_7));
+ let tint_symbol_8 : u32 = (tint_symbol_7 / 4u);
+ var len1 : u32 = tint_symbol_2;
+ var len2 : u32 = tint_symbol_5;
+ var len3 : u32 = tint_symbol_8;
+ var x : u32 = ((len1 + len2) + len3);
+}
+
+@group(0) @binding(0) var<storage, read> sb1 : SB1;
+
+struct SB1 {
+ x : i32,
+ arr1 : array<i32>,
+}
+
+@group(0) @binding(1) var<storage, read> sb2 : SB2;
+
+struct SB2 {
+ x : i32,
+ arr2 : array<vec4<f32>>,
+}
+
+@group(0) @binding(2) var<storage, read> sb3 : array<i32>;
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.cc b/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.cc
new file mode 100644
index 0000000..de1c5f5
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.cc
@@ -0,0 +1,879 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.h"
+
+#include <algorithm>
+#include <string>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "src/tint/builtin/builtin_value.h"
+#include "src/tint/lang/wgsl/ast/disable_validation_attribute.h"
+#include "src/tint/lang/wgsl/ast/transform/unshadow.h"
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/function.h"
+
+using namespace tint::number_suffixes; // NOLINT
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::CanonicalizeEntryPointIO);
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::CanonicalizeEntryPointIO::Config);
+
+namespace tint::ast::transform {
+
+CanonicalizeEntryPointIO::CanonicalizeEntryPointIO() = default;  // defaulted constructor
+CanonicalizeEntryPointIO::~CanonicalizeEntryPointIO() = default;  // defaulted destructor
+
+namespace {
+
+/// A struct member bundled with its optional location and index values.
+struct MemberInfo {
+ /// The struct member (AST node)
+ const StructMember* member;
+ /// The struct member location; std::nullopt if none was provided
+ std::optional<uint32_t> location;
+ /// The struct member index; std::nullopt if none was provided
+ std::optional<uint32_t> index;
+};
+
+/// FXC is sensitive to field order in structures. Returns the ordering rank of @p builtin (1..13
+/// for the builtins listed, 0 for any other); per the original note, StructMemberComparator uses it.
+uint32_t BuiltinOrder(builtin::BuiltinValue builtin) {
+ switch (builtin) {
+ case builtin::BuiltinValue::kPosition:
+ return 1;
+ case builtin::BuiltinValue::kVertexIndex:
+ return 2;
+ case builtin::BuiltinValue::kInstanceIndex:
+ return 3;
+ case builtin::BuiltinValue::kFrontFacing:
+ return 4;
+ case builtin::BuiltinValue::kFragDepth:
+ return 5;
+ case builtin::BuiltinValue::kLocalInvocationId:
+ return 6;
+ case builtin::BuiltinValue::kLocalInvocationIndex:
+ return 7;
+ case builtin::BuiltinValue::kGlobalInvocationId:
+ return 8;
+ case builtin::BuiltinValue::kWorkgroupId:
+ return 9;
+ case builtin::BuiltinValue::kNumWorkgroups:
+ return 10;
+ case builtin::BuiltinValue::kSampleIndex:
+ return 11;
+ case builtin::BuiltinValue::kSampleMask:
+ return 12;
+ case builtin::BuiltinValue::kPointSize:
+ return 13;
+ default:
+ break;
+ }
+ return 0;  // any builtin value not listed above
+}
+
+// Returns true if `attr` is a builtin, interpolate, invariant, location or index attribute.
+bool IsShaderIOAttribute(const Attribute* attr) {
+ return attr->IsAnyOf<BuiltinAttribute, InterpolateAttribute, InvariantAttribute,
+ LocationAttribute, IndexAttribute>();
+}
+
+} // namespace
+
+/// PIMPL state for the transform
+struct CanonicalizeEntryPointIO::State {
+ /// OutputValue represents a shader result that the wrapper function produces.
+ struct OutputValue {
+ /// The name of the output value.
+ std::string name;
+ /// The AST type of the output value.
+ Type type;
+ /// The shader IO attributes.
+ utils::Vector<const Attribute*, 8> attributes;
+ /// The expression that yields the output value.
+ const Expression* value;
+ /// The output location, if one was provided.
+ std::optional<uint32_t> location;
+ /// The output index, if one was provided.
+ std::optional<uint32_t> index;
+ };
+
+ /// The clone context.
+ CloneContext& ctx;
+ /// The transform config.
+ CanonicalizeEntryPointIO::Config const cfg;
+ /// The entry point function (AST).
+ const Function* func_ast;
+ /// The entry point function (SEM).
+ const sem::Function* func_sem;
+
+ /// The new entry point wrapper function's parameters.
+ utils::Vector<const Parameter*, 8> wrapper_ep_parameters;
+
+ /// The members of the wrapper function's struct parameter.
+ utils::Vector<MemberInfo, 8> wrapper_struct_param_members;
+ /// The name of the wrapper function's struct parameter.
+ Symbol wrapper_struct_param_name;
+ /// The parameters that will be passed to the original function.
+ utils::Vector<const Expression*, 8> inner_call_parameters;
+ /// The members of the wrapper function's struct return type.
+ utils::Vector<MemberInfo, 8> wrapper_struct_output_members;
+ /// The wrapper function output values.
+ utils::Vector<OutputValue, 8> wrapper_output_values;
+ /// The body of the wrapper function.
+ utils::Vector<const Statement*, 8> wrapper_body;
+ /// Input names used by the entrypoint
+ std::unordered_set<std::string> input_names;
+ /// A map of cloned attribute to builtin value
+ utils::Hashmap<const BuiltinAttribute*, builtin::BuiltinValue, 16> builtin_attrs;
+
+ /// Constructor
+ /// @param context the clone context (held by reference)
+ /// @param config the transform config (copied)
+ /// @param function the entry point function; its semantic node is looked up via ctx.src->Sem()
+ State(CloneContext& context,
+ const CanonicalizeEntryPointIO::Config& config,
+ const Function* function)
+ : ctx(context), cfg(config), func_ast(function), func_sem(ctx.src->Sem().Get(function)) {}
+
+ /// Clones the attribute @p in and appends the clone to @p out. If @p in is a builtin attribute,
+ /// then builtin_attrs is updated to map the clone to the builtin value of @p in.
+ /// @param in the attribute to clone
+ /// @param out the vector that receives the cloned attribute
+ template <size_t N>
+ void CloneAttribute(const Attribute* in, utils::Vector<const Attribute*, N>& out) {
+ auto* cloned = ctx.Clone(in);
+ out.Push(cloned);
+ if (auto* builtin = in->As<BuiltinAttribute>()) {  // value is read from the source program's semantic info
+ builtin_attrs.Add(cloned->As<BuiltinAttribute>(), ctx.src->Sem().Get(builtin)->Value());
+ }
+ }
+
+ /// Clones the shader IO attributes from @p in; all other attributes are skipped.
+ /// @param in the attributes to clone (NOTE(review): taken by value, so the vector is copied)
+ /// @param do_interpolate whether to clone InterpolateAttribute
+ /// @return the cloned attributes
+ template <size_t N>
+ auto CloneShaderIOAttributes(const utils::Vector<const Attribute*, N> in, bool do_interpolate) {
+ utils::Vector<const Attribute*, N> out;
+ for (auto* attr : in) {
+ if (IsShaderIOAttribute(attr) &&
+ (do_interpolate || !attr->template Is<InterpolateAttribute>())) {
+ CloneAttribute(attr, out);  // also records builtin values in builtin_attrs
+ }
+ }
+ return out;
+ }
+
+ /// @param attr the input attribute (owned by either the source or the target program)
+ /// @returns the builtin value of the attribute, or kUndefined after raising an ICE if not found
+ builtin::BuiltinValue BuiltinOf(const BuiltinAttribute* attr) {
+ if (attr->program_id == ctx.dst->ID()) {
+ // attr belongs to the target program.
+ // Obtain the builtin value from #builtin_attrs.
+ if (auto b = builtin_attrs.Get(attr)) {  // a miss falls through to the ICE below
+ return *b;
+ }
+ } else {
+ // attr belongs to the source program.
+ // Obtain the builtin value from the semantic info.
+ return ctx.src->Sem().Get(attr)->Value();
+ }
+ TINT_ICE(Resolver, ctx.dst->Diagnostics())  // NOTE(review): other ICEs in this file use `Transform` - confirm
+ << "could not obtain builtin value from attribute";
+ return builtin::BuiltinValue::kUndefined;
+ }
+
+ /// @param attrs the input attribute list
+ /// @returns the builtin value if any of the attributes in @p attrs is a builtin attribute,
+ /// otherwise builtin::BuiltinValue::kUndefined
+ builtin::BuiltinValue BuiltinOf(utils::VectorRef<const Attribute*> attrs) {
+ if (auto* builtin = GetAttribute<BuiltinAttribute>(attrs)) {
+ return BuiltinOf(builtin);  // delegate to the single-attribute overload
+ }
+ return builtin::BuiltinValue::kUndefined;
+ }
+
+ /// Create or return the symbol for the wrapper function's struct parameter.
+ /// @returns the symbol for the struct parameter
+ Symbol InputStructSymbol() {
+ if (!wrapper_struct_param_name.IsValid()) {  // created on first use, then reused
+ wrapper_struct_param_name = ctx.dst->Sym();
+ }
+ return wrapper_struct_param_name;
+ }
+
+ /// Add a shader input to the entry point (as a global, a wrapper parameter or a struct member).
+ /// @param name the name of the shader input (replaced by the gl_ builtin name for GLSL builtins)
+ /// @param type the type of the shader input
+ /// @param location the location if provided (only recorded on the struct member path)
+ /// @param attrs the attributes to apply to the shader input
+ /// @returns an expression which evaluates to the value of the shader input
+ const Expression* AddInput(std::string name,
+ const type::Type* type,
+ std::optional<uint32_t> location,
+ utils::Vector<const Attribute*, 8> attrs) {
+ auto ast_type = CreateASTTypeFor(ctx, type);
+
+ auto builtin_attr = BuiltinOf(attrs);  // kUndefined when attrs holds no builtin attribute
+
+ if (cfg.shader_style == ShaderStyle::kSpirv || cfg.shader_style == ShaderStyle::kGlsl) {
+ // Vulkan requires that integer user-defined fragment inputs are always decorated with
+ // `Flat`. See:
+ // https://www.khronos.org/registry/vulkan/specs/1.3-extensions/man/html/StandaloneSpirv.html#VUID-StandaloneSpirv-Flat-04744
+ // TODO(crbug.com/tint/1224): Remove this once a flat interpolation attribute is
+ // required for integers.
+ if (func_ast->PipelineStage() == PipelineStage::kFragment &&
+ type->is_integer_scalar_or_vector() && !HasAttribute<InterpolateAttribute>(attrs) &&
+ (HasAttribute<LocationAttribute>(attrs) ||
+ cfg.shader_style == ShaderStyle::kSpirv)) {
+ attrs.Push(ctx.dst->Interpolate(builtin::InterpolationType::kFlat,
+ builtin::InterpolationSampling::kUndefined));
+ }
+
+ // Disable validation for use of the `input` address space.
+ attrs.Push(ctx.dst->Disable(DisabledValidation::kIgnoreAddressSpace));
+
+ // In GLSL, if it's a builtin, override the name with the
+ // corresponding gl_ builtin name
+ if (cfg.shader_style == ShaderStyle::kGlsl &&
+ builtin_attr != builtin::BuiltinValue::kUndefined) {
+ name = GLSLBuiltinToString(builtin_attr, func_ast->PipelineStage(),
+ builtin::AddressSpace::kIn);
+ }
+ auto symbol = ctx.dst->Symbols().New(name);
+
+ // Create the global variable and use its value for the shader input.
+ const Expression* value = ctx.dst->Expr(symbol);
+
+ if (builtin_attr != builtin::BuiltinValue::kUndefined) {
+ if (cfg.shader_style == ShaderStyle::kGlsl) {
+ value = FromGLSLBuiltin(builtin_attr, value, ast_type);
+ } else if (builtin_attr == builtin::BuiltinValue::kSampleMask) {
+ // Vulkan requires the type of a SampleMask builtin to be an array.
+ // Declare it as array<u32, 1> and then load the first element.
+ ast_type = ctx.dst->ty.array(ast_type, 1_u);
+ value = ctx.dst->IndexAccessor(value, 0_i);
+ }
+ }
+ ctx.dst->GlobalVar(symbol, ast_type, builtin::AddressSpace::kIn, std::move(attrs));
+ return value;
+ } else if (cfg.shader_style == ShaderStyle::kMsl &&
+ builtin_attr != builtin::BuiltinValue::kUndefined) {
+ // If this input is a builtin and we are targeting MSL, then add it to the
+ // parameter list and pass it directly to the inner function.
+ Symbol symbol = input_names.emplace(name).second ? ctx.dst->Symbols().Register(name)
+ : ctx.dst->Symbols().New(name);
+ wrapper_ep_parameters.Push(ctx.dst->Param(symbol, ast_type, std::move(attrs)));
+ return ctx.dst->Expr(symbol);
+ } else {
+ // Otherwise, move it to the new structure member list.
+ Symbol symbol = input_names.emplace(name).second ? ctx.dst->Symbols().Register(name)
+ : ctx.dst->Symbols().New(name);
+ wrapper_struct_param_members.Push(
+ {ctx.dst->Member(symbol, ast_type, std::move(attrs)), location, std::nullopt});
+ return ctx.dst->MemberAccessor(InputStructSymbol(), symbol);
+ }
+ }
+
+ /// Add a shader output to the entry point by appending it to wrapper_output_values.
+ /// @param name the name of the shader output (replaced by the gl_ builtin name for GLSL builtins)
+ /// @param type the type of the shader output
+ /// @param location the location if provided
+ /// @param index the index if provided
+ /// @param attrs the attributes to apply to the shader output
+ /// @param value the value of the shader output
+ void AddOutput(std::string name,
+ const type::Type* type,
+ std::optional<uint32_t> location,
+ std::optional<uint32_t> index,
+ utils::Vector<const Attribute*, 8> attrs,
+ const Expression* value) {
+ auto builtin_attr = BuiltinOf(attrs);  // kUndefined when attrs holds no builtin attribute
+ // Vulkan requires that integer user-defined vertex outputs are always decorated with
+ // `Flat`.
+ // TODO(crbug.com/tint/1224): Remove this once a flat interpolation attribute is required
+ // for integers.
+ if (cfg.shader_style == ShaderStyle::kSpirv &&
+ func_ast->PipelineStage() == PipelineStage::kVertex &&
+ type->is_integer_scalar_or_vector() && HasAttribute<LocationAttribute>(attrs) &&
+ !HasAttribute<InterpolateAttribute>(attrs)) {
+ attrs.Push(ctx.dst->Interpolate(builtin::InterpolationType::kFlat,
+ builtin::InterpolationSampling::kUndefined));
+ }
+
+ // In GLSL, if it's a builtin, override the name with the
+ // corresponding gl_ builtin name
+ if (cfg.shader_style == ShaderStyle::kGlsl) {
+ if (builtin_attr != builtin::BuiltinValue::kUndefined) {
+ name = GLSLBuiltinToString(builtin_attr, func_ast->PipelineStage(),
+ builtin::AddressSpace::kOut);
+ value = ToGLSLBuiltin(builtin_attr, value, type);
+ }
+ }
+
+ OutputValue output;
+ output.name = name;
+ output.type = CreateASTTypeFor(ctx, type);
+ output.attributes = std::move(attrs);
+ output.value = value;
+ output.location = location;
+ output.index = index;
+ wrapper_output_values.Push(output);  // NOTE(review): pushes a copy of `output`; a move looks sufficient - confirm
+ }
+
+ /// Process a non-struct parameter.
+ /// This creates a new object for the shader input, moving the shader IO
+ /// attributes to it. It also adds an expression to the list of parameters
+ /// that will be passed to the original function.
+ /// @param param the original function parameter
+ void ProcessNonStructParameter(const sem::Parameter* param) {
+ // Do not add interpolation attributes on vertex input
+ bool do_interpolate = func_ast->PipelineStage() != PipelineStage::kVertex;
+ // Remove the shader IO attributes from the inner function parameter, and attach them to the
+ // new object instead.
+ utils::Vector<const Attribute*, 8> attributes;
+ for (auto* attr : param->Declaration()->attributes) {
+ if (IsShaderIOAttribute(attr)) {
+ ctx.Remove(param->Declaration()->attributes, attr);  // removed even when not cloned below
+ if ((do_interpolate || !attr->Is<InterpolateAttribute>())) {
+ CloneAttribute(attr, attributes);
+ }
+ }
+ }
+
+ auto name = param->Declaration()->name->symbol.Name();
+ auto* input_expr = AddInput(name, param->Type(), param->Location(), std::move(attributes));
+ inner_call_parameters.Push(input_expr);
+ }
+
+ /// Process a struct parameter.
+ /// This creates new objects for each struct member, moving the shader IO
+ /// attributes to them. It also creates the structure that will be passed to
+ /// the original function.
+ /// @param param the original function parameter
+ void ProcessStructParameter(const sem::Parameter* param) {
+ // Do not add interpolation attributes on vertex input
+ bool do_interpolate = func_ast->PipelineStage() != PipelineStage::kVertex;
+
+ auto* str = param->Type()->As<sem::Struct>();  // NOTE(review): used below without a null check; assumes a struct-typed param
+
+ // Recreate struct members in the outer entry point and build an initializer
+ // list to pass them through to the inner function.
+ utils::Vector<const Expression*, 8> inner_struct_values;
+ for (auto* member : str->Members()) {
+ if (TINT_UNLIKELY(member->Type()->Is<type::Struct>())) {
+ TINT_ICE(Transform, ctx.dst->Diagnostics()) << "nested IO struct";
+ continue;  // the nested member contributes no input value
+ }
+
+ auto name = member->Name().Name();
+
+ auto attributes =
+ CloneShaderIOAttributes(member->Declaration()->attributes, do_interpolate);
+ auto* input_expr = AddInput(name, member->Type(), member->Attributes().location,
+ std::move(attributes));
+ inner_struct_values.Push(input_expr);
+ }
+
+ // Construct the original structure using the new shader input objects.
+ inner_call_parameters.Push(
+ ctx.dst->Call(ctx.Clone(param->Declaration()->type), inner_struct_values));
+ }
+
+ /// Process the entry point return type.
+ /// This generates a list of output values that are returned by the original
+ /// function: one per struct member, or a single one named "value" for a non-struct return.
+ /// @param inner_ret_type the original function return type
+ /// @param original_result the result object produced by the original function
+ void ProcessReturnType(const type::Type* inner_ret_type, Symbol original_result) {
+ // Do not add interpolation attributes on fragment output
+ bool do_interpolate = func_ast->PipelineStage() != PipelineStage::kFragment;
+ if (auto* str = inner_ret_type->As<sem::Struct>()) {
+ for (auto* member : str->Members()) {
+ if (TINT_UNLIKELY(member->Type()->Is<type::Struct>())) {
+ TINT_ICE(Transform, ctx.dst->Diagnostics()) << "nested IO struct";
+ continue;  // the nested member contributes no output value
+ }
+
+ auto name = member->Name().Name();
+ auto attributes =
+ CloneShaderIOAttributes(member->Declaration()->attributes, do_interpolate);
+
+ // Extract the original structure member.
+ AddOutput(name, member->Type(), member->Attributes().location,
+ member->Attributes().index, std::move(attributes),
+ ctx.dst->MemberAccessor(original_result, name));
+ }
+ } else if (!inner_ret_type->Is<type::Void>()) {  // a void return adds no output
+ auto attributes =
+ CloneShaderIOAttributes(func_ast->return_type_attributes, do_interpolate);
+
+ // Propagate the non-struct return value as is.
+ AddOutput("value", func_sem->ReturnType(), func_sem->ReturnLocation(),
+ func_sem->ReturnIndex(), std::move(attributes),
+ ctx.dst->Expr(original_result));
+ }
+ }
+
+    /// Apply the configured fixed sample mask to the wrapper function output.
+    /// When an output already carries the sample mask builtin, its value is
+    /// bitwise-and'ed with the fixed mask. Otherwise a new output holding the
+    /// fixed mask is added.
+    void AddFixedSampleMask() {
+        // Look for an authored sample mask among the existing outputs.
+        for (auto& output : wrapper_output_values) {
+            if (BuiltinOf(output.attributes) != builtin::BuiltinValue::kSampleMask) {
+                continue;
+            }
+            // Combine the authored sample mask with the fixed mask.
+            output.value = ctx.dst->And(output.value, u32(cfg.fixed_sample_mask));
+            return;
+        }
+
+        // Nothing writes the sample mask yet: emit the fixed mask as a new output.
+        auto* mask_attr = ctx.dst->Builtin(builtin::BuiltinValue::kSampleMask);
+        builtin_attrs.Add(mask_attr, builtin::BuiltinValue::kSampleMask);
+        AddOutput("fixed_sample_mask", ctx.dst->create<type::U32>(), std::nullopt, std::nullopt,
+                  {mask_attr}, ctx.dst->Expr(u32(cfg.fixed_sample_mask)));
+    }
+
+    /// Add a point size builtin output, fixed to 1.0, to the wrapper function.
+    void AddVertexPointSize() {
+        // The new output carries the point size builtin attribute and a literal 1.0 value.
+        auto* point_size_attr = ctx.dst->Builtin(builtin::BuiltinValue::kPointSize);
+        builtin_attrs.Add(point_size_attr, builtin::BuiltinValue::kPointSize);
+        AddOutput("vertex_point_size", ctx.dst->create<type::F32>(), std::nullopt, std::nullopt,
+                  {point_size_attr}, ctx.dst->Expr(1_f));
+    }
+
+    /// Build a member accessor expression of the form `gl_Position.<component>`.
+    /// @param component the component of gl_Position to access
+    /// @returns the new expression
+    const Expression* GLPosition(const char* component) {
+        // Register the object name first and the component name second.
+        Symbol gl_position = ctx.dst->Symbols().Register("gl_Position");
+        Symbol member = ctx.dst->Symbols().Register(component);
+        auto* object = ctx.dst->Expr(gl_position);
+        return ctx.dst->MemberAccessor(object, member);
+    }
+
+    /// Ordering predicate for wrapper struct members: members that have a location
+    /// attribute come first (ascending by location), and the remaining members are
+    /// ordered by the BuiltinOrder() of their builtin attribute.
+    /// @param a a struct member
+    /// @param b another struct member
+    /// @returns true if a comes before b
+    bool StructMemberComparator(const MemberInfo& a, const MemberInfo& b) {
+        const bool a_has_location =
+            GetAttribute<LocationAttribute>(a.member->attributes) != nullptr;
+        const bool b_has_location =
+            GetAttribute<LocationAttribute>(b.member->attributes) != nullptr;
+
+        if (a_has_location && b_has_location) {
+            // Both have location attributes: smallest goes first.
+            return a.location < b.location;
+        }
+        if (a_has_location != b_has_location) {
+            // Exactly one of them has a location attribute, and that one goes first.
+            return a_has_location;
+        }
+
+        // Both are builtins: order matters for FXC.
+        auto* a_builtin_attr = GetAttribute<BuiltinAttribute>(a.member->attributes);
+        auto* b_builtin_attr = GetAttribute<BuiltinAttribute>(b.member->attributes);
+        auto a_builtin = BuiltinOf(a_builtin_attr);
+        auto b_builtin = BuiltinOf(b_builtin_attr);
+        return BuiltinOrder(a_builtin) < BuiltinOrder(b_builtin);
+    }
+    /// Declare the structure that aggregates the wrapper function's inputs and
+    /// add a wrapper parameter of that structure type.
+    void CreateInputStruct() {
+        // Sort the struct members to satisfy HLSL interfacing matching rules.
+        auto comes_before = [&](auto& lhs, auto& rhs) { return StructMemberComparator(lhs, rhs); };
+        std::sort(wrapper_struct_param_members.begin(), wrapper_struct_param_members.end(),
+                  comes_before);
+
+        utils::Vector<const StructMember*, 8> sorted_members;
+        for (auto& info : wrapper_struct_param_members) {
+            sorted_members.Push(info.member);
+        }
+
+        // Declare the new struct type ahead of the original entry point.
+        auto in_struct_name = ctx.dst->Sym();
+        auto* in_struct_decl = ctx.dst->create<Struct>(ctx.dst->Ident(in_struct_name),
+                                                       std::move(sorted_members), utils::Empty);
+        ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func_ast, in_struct_decl);
+
+        // The wrapper receives the struct through a single new parameter.
+        wrapper_ep_parameters.Push(
+            ctx.dst->Param(InputStructSymbol(), ctx.dst->ty(in_struct_name)));
+    }
+
+ /// Create and return the wrapper function's struct result object.
+ /// One struct member is created for each entry of `wrapper_output_values`, the members are
+ /// sorted with StructMemberComparator(), and the new struct declaration is inserted before
+ /// the original entry point. Statements that declare the result variable, assign each of
+ /// its members and return it are appended to `wrapper_body`.
+ /// @returns the struct type
+ Struct* CreateOutputStruct() {
+ // The member assignments are collected here and only appended to the wrapper body after
+ // the result variable has been declared (see the end of this function).
+ utils::Vector<const Statement*, 8> assignments;
+
+ auto wrapper_result = ctx.dst->Symbols().New("wrapper_result");
+
+ // Create the struct members and their corresponding assignment statements.
+ std::unordered_set<std::string> member_names;
+ for (auto& outval : wrapper_output_values) {
+ // Use the original output name, unless that is already taken.
+ Symbol name;
+ if (member_names.count(outval.name)) {
+ name = ctx.dst->Symbols().New(outval.name);
+ } else {
+ name = ctx.dst->Symbols().Register(outval.name);
+ }
+ member_names.insert(name.Name());
+
+ // Note: the attributes are moved out of `outval`, so `outval.attributes` must not be
+ // relied upon after this point.
+ wrapper_struct_output_members.Push(
+ {ctx.dst->Member(name, outval.type, std::move(outval.attributes)), outval.location,
+ std::nullopt});
+ assignments.Push(
+ ctx.dst->Assign(ctx.dst->MemberAccessor(wrapper_result, name), outval.value));
+ }
+
+ // Sort the struct members to satisfy HLSL interfacing matching rules.
+ std::sort(wrapper_struct_output_members.begin(), wrapper_struct_output_members.end(),
+ [&](auto& a, auto& b) { return StructMemberComparator(a, b); });
+
+ utils::Vector<const StructMember*, 8> members;
+ for (auto& mem : wrapper_struct_output_members) {
+ members.Push(mem.member);
+ }
+
+ // Create the new struct type.
+ auto* out_struct = ctx.dst->create<Struct>(ctx.dst->Ident(ctx.dst->Sym()),
+ std::move(members), utils::Empty);
+ ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func_ast, out_struct);
+
+ // Create the output struct object, assign its members, and return it.
+ auto* result_object = ctx.dst->Var(wrapper_result, ctx.dst->ty(out_struct->name->symbol));
+ wrapper_body.Push(ctx.dst->Decl(result_object));
+ for (auto* assignment : assignments) {
+ wrapper_body.Push(assignment);
+ }
+ wrapper_body.Push(ctx.dst->Return(wrapper_result));
+
+ return out_struct;
+ }
+
+ /// Create and assign the wrapper function's output variables.
+ /// For each entry of `wrapper_output_values` a module-scope variable in the `kOut` address
+ /// space is declared, and a statement that stores the output value to that variable is
+ /// appended to `wrapper_body`.
+ void CreateGlobalOutputVariables() {
+ for (auto& outval : wrapper_output_values) {
+ // Disable validation for use of the `output` address space.
+ // Note: the attributes are moved out of `outval` here.
+ auto attributes = std::move(outval.attributes);
+ attributes.Push(ctx.dst->Disable(DisabledValidation::kIgnoreAddressSpace));
+
+ // Create the global variable and assign it the output value.
+ auto name = ctx.dst->Symbols().New(outval.name);
+ Type type = outval.type;
+ const Expression* lhs = ctx.dst->Expr(name);
+ // The builtin check must happen before `attributes` is moved into GlobalVar() below.
+ if (BuiltinOf(attributes) == builtin::BuiltinValue::kSampleMask) {
+ // Vulkan requires the type of a SampleMask builtin to be an array.
+ // Declare it as array<u32, 1> and then store to the first element.
+ type = ctx.dst->ty.array(type, 1_u);
+ lhs = ctx.dst->IndexAccessor(lhs, 0_i);
+ }
+ ctx.dst->GlobalVar(name, type, builtin::AddressSpace::kOut, std::move(attributes));
+ wrapper_body.Push(ctx.dst->Assign(lhs, outval.value));
+ }
+ }
+
+ /// Recreate the original function without entry point attributes and call it.
+ /// The recreated ("inner") function replaces the original entry point declaration. The
+ /// returned call passes the expressions collected in `inner_call_parameters`.
+ /// @returns the inner function call expression
+ const CallExpression* CallInnerFunction() {
+ Symbol inner_name;
+ if (cfg.shader_style == ShaderStyle::kGlsl) {
+ // In GLSL, clone the original entry point name, as the wrapper will be
+ // called "main".
+ inner_name = ctx.Clone(func_ast->name->symbol);
+ } else {
+ // Add a suffix to the function name, as the wrapper function will take
+ // the original entry point name.
+ auto ep_name = func_ast->name->symbol.Name();
+ inner_name = ctx.dst->Symbols().New(ep_name + "_inner");
+ }
+
+ // Clone everything, dropping the function and return type attributes.
+ // The parameter attributes will have already been stripped during
+ // processing.
+ auto* inner_function =
+ ctx.dst->create<Function>(ctx.dst->Ident(inner_name), ctx.Clone(func_ast->params),
+ ctx.Clone(func_ast->return_type), ctx.Clone(func_ast->body),
+ utils::Empty, utils::Empty);
+ ctx.Replace(func_ast, inner_function);
+
+ // Call the function.
+ return ctx.dst->Call(inner_function->name->symbol, inner_call_parameters);
+ }
+
+ /// Process the entry point function.
+ /// Unless there is nothing to do, this replaces the entry point with an attribute-free
+ /// inner function and inserts, after it, a wrapper entry point that gathers the shader
+ /// inputs, calls the inner function and produces the shader outputs in the form selected
+ /// by `cfg.shader_style`.
+ void Process() {
+ bool needs_fixed_sample_mask = false;
+ bool needs_vertex_point_size = false;
+ // A fixed sample mask of 0xFFFFFFFF is treated as "no fixed mask".
+ if (func_ast->PipelineStage() == PipelineStage::kFragment &&
+ cfg.fixed_sample_mask != 0xFFFFFFFF) {
+ needs_fixed_sample_mask = true;
+ }
+ if (func_ast->PipelineStage() == PipelineStage::kVertex && cfg.emit_vertex_point_size) {
+ needs_vertex_point_size = true;
+ }
+
+ // Exit early if there is no shader IO to handle.
+ // Note that the GLSL style never takes this early exit: a wrapper is always generated.
+ if (func_sem->Parameters().Length() == 0 && func_sem->ReturnType()->Is<type::Void>() &&
+ !needs_fixed_sample_mask && !needs_vertex_point_size &&
+ cfg.shader_style != ShaderStyle::kGlsl) {
+ return;
+ }
+
+ // Process the entry point parameters, collecting those that need to be
+ // aggregated into a single structure.
+ if (!func_sem->Parameters().IsEmpty()) {
+ for (auto* param : func_sem->Parameters()) {
+ if (param->Type()->Is<type::Struct>()) {
+ ProcessStructParameter(param);
+ } else {
+ ProcessNonStructParameter(param);
+ }
+ }
+
+ // Create a structure parameter for the outer entry point if necessary.
+ if (!wrapper_struct_param_members.IsEmpty()) {
+ CreateInputStruct();
+ }
+ }
+
+ // Recreate the original function and call it.
+ auto* call_inner = CallInnerFunction();
+
+ // Process the return type, and start building the wrapper function body.
+ // The return type is held as a callable so that the type expression is only built when
+ // the wrapper function itself is created below. It defaults to void.
+ std::function<Type()> wrapper_ret_type = [&] { return ctx.dst->ty.void_(); };
+ if (func_sem->ReturnType()->Is<type::Void>()) {
+ // The function call is just a statement with no result.
+ wrapper_body.Push(ctx.dst->CallStmt(call_inner));
+ } else {
+ // Capture the result of calling the original function.
+ auto* inner_result = ctx.dst->Let(ctx.dst->Symbols().New("inner_result"), call_inner);
+ wrapper_body.Push(ctx.dst->Decl(inner_result));
+
+ // Process the original return type to determine the outputs that the
+ // outer function needs to produce.
+ ProcessReturnType(func_sem->ReturnType(), inner_result->name->symbol);
+ }
+
+ // Add a fixed sample mask, if necessary.
+ if (needs_fixed_sample_mask) {
+ AddFixedSampleMask();
+ }
+
+ // Add the pointsize builtin, if necessary.
+ if (needs_vertex_point_size) {
+ AddVertexPointSize();
+ }
+
+ // Produce the entry point outputs, if necessary.
+ // SPIR-V and GLSL styles write to global variables; the other styles return a struct.
+ if (!wrapper_output_values.IsEmpty()) {
+ if (cfg.shader_style == ShaderStyle::kSpirv || cfg.shader_style == ShaderStyle::kGlsl) {
+ CreateGlobalOutputVariables();
+ } else {
+ auto* output_struct = CreateOutputStruct();
+ wrapper_ret_type = [&, output_struct] {
+ return ctx.dst->ty(output_struct->name->symbol);
+ };
+ }
+ }
+
+ // For GLSL vertex shaders, append two fix-up statements to the wrapper body:
+ // gl_Position.y = -gl_Position.y;
+ // gl_Position.z = (2.0 * gl_Position.z) - gl_Position.w;
+ // NOTE(review): presumably this adapts the position to GL clip-space conventions - confirm.
+ if (cfg.shader_style == ShaderStyle::kGlsl &&
+ func_ast->PipelineStage() == PipelineStage::kVertex) {
+ auto* pos_y = GLPosition("y");
+ auto* negate_pos_y =
+ ctx.dst->create<UnaryOpExpression>(UnaryOp::kNegation, GLPosition("y"));
+ wrapper_body.Push(ctx.dst->Assign(pos_y, negate_pos_y));
+
+ auto* two_z = ctx.dst->Mul(ctx.dst->Expr(2_f), GLPosition("z"));
+ auto* fixed_z = ctx.dst->Sub(two_z, GLPosition("w"));
+ wrapper_body.Push(ctx.dst->Assign(GLPosition("z"), fixed_z));
+ }
+
+ // Create the wrapper entry point function.
+ // For GLSL, use "main", otherwise take the name of the original
+ // entry point function.
+ Symbol name;
+ if (cfg.shader_style == ShaderStyle::kGlsl) {
+ name = ctx.dst->Symbols().New("main");
+ } else {
+ name = ctx.Clone(func_ast->name->symbol);
+ }
+
+ // The wrapper takes over the original function's attributes and is inserted after the
+ // original declaration (which CallInnerFunction() replaced with the inner function).
+ auto* wrapper_func = ctx.dst->create<Function>(
+ ctx.dst->Ident(name), wrapper_ep_parameters, ctx.dst->ty(wrapper_ret_type()),
+ ctx.dst->Block(wrapper_body), ctx.Clone(func_ast->attributes), utils::Empty);
+ ctx.InsertAfter(ctx.src->AST().GlobalDeclarations(), func_ast, wrapper_func);
+ }
+
+    /// Map a builtin value to the name of the matching GLSL `gl_` variable.
+    /// @param builtin the builtin
+    /// @param stage the current pipeline stage
+    /// @param address_space the address space (input or output)
+    /// @returns the gl_ string corresponding to that builtin, or an empty string
+    /// when there is no mapping
+    const char* GLSLBuiltinToString(builtin::BuiltinValue builtin,
+                                    PipelineStage stage,
+                                    builtin::AddressSpace address_space) {
+        switch (builtin) {
+            case builtin::BuiltinValue::kPosition:
+                // The position builtin maps to a different variable in each stage.
+                if (stage == PipelineStage::kVertex) {
+                    return "gl_Position";
+                }
+                if (stage == PipelineStage::kFragment) {
+                    return "gl_FragCoord";
+                }
+                return "";
+            case builtin::BuiltinValue::kVertexIndex:
+                return "gl_VertexID";
+            case builtin::BuiltinValue::kInstanceIndex:
+                return "gl_InstanceID";
+            case builtin::BuiltinValue::kFrontFacing:
+                return "gl_FrontFacing";
+            case builtin::BuiltinValue::kFragDepth:
+                return "gl_FragDepth";
+            case builtin::BuiltinValue::kLocalInvocationId:
+                return "gl_LocalInvocationID";
+            case builtin::BuiltinValue::kLocalInvocationIndex:
+                return "gl_LocalInvocationIndex";
+            case builtin::BuiltinValue::kGlobalInvocationId:
+                return "gl_GlobalInvocationID";
+            case builtin::BuiltinValue::kNumWorkgroups:
+                return "gl_NumWorkGroups";
+            case builtin::BuiltinValue::kWorkgroupId:
+                return "gl_WorkGroupID";
+            case builtin::BuiltinValue::kSampleIndex:
+                return "gl_SampleID";
+            case builtin::BuiltinValue::kSampleMask:
+                // The sample mask has distinct input and output variables.
+                return address_space == builtin::AddressSpace::kIn ? "gl_SampleMaskIn"
+                                                                   : "gl_SampleMask";
+            default:
+                return "";
+        }
+    }
+
+    /// Convert the value of a GLSL builtin variable to the form WGSL expects.
+    /// @param builtin the builtin variable
+    /// @param value the value to convert
+    /// @param ast_type (inout) the incoming WGSL and outgoing GLSL types
+    /// @returns an expression representing the GLSL builtin converted to what
+    /// WGSL expects
+    const Expression* FromGLSLBuiltin(builtin::BuiltinValue builtin,
+                                      const Expression* value,
+                                      Type& ast_type) {
+        const bool is_index_builtin = builtin == builtin::BuiltinValue::kVertexIndex ||
+                                      builtin == builtin::BuiltinValue::kInstanceIndex ||
+                                      builtin == builtin::BuiltinValue::kSampleIndex;
+        if (is_index_builtin) {
+            // GLSL uses i32 for these, so bitcast to u32.
+            value = ctx.dst->Bitcast(ast_type, value);
+            ast_type = ctx.dst->ty.i32();
+        } else if (builtin == builtin::BuiltinValue::kSampleMask) {
+            // gl_SampleMask is an array of i32. Retrieve the first element and
+            // bitcast it to u32.
+            auto* first_element = ctx.dst->IndexAccessor(value, 0_i);
+            value = ctx.dst->Bitcast(ast_type, first_element);
+            ast_type = ctx.dst->ty.array(ctx.dst->ty.i32(), 1_u);
+        }
+        // Every other builtin is passed through unchanged.
+        return value;
+    }
+
+    /// Convert a WGSL value to the type expected when assigning it to a GLSL
+    /// builtin.
+    /// @param builtin the builtin variable
+    /// @param value the value to convert
+    /// @param type (out) the type to which the value was converted
+    /// @returns the converted value which can be assigned to the GLSL builtin
+    const Expression* ToGLSLBuiltin(builtin::BuiltinValue builtin,
+                                    const Expression* value,
+                                    const type::Type*& type) {
+        const bool needs_i32 = builtin == builtin::BuiltinValue::kVertexIndex ||
+                               builtin == builtin::BuiltinValue::kInstanceIndex ||
+                               builtin == builtin::BuiltinValue::kSampleIndex ||
+                               builtin == builtin::BuiltinValue::kSampleMask;
+        if (!needs_i32) {
+            // No conversion needed; `type` is left untouched.
+            return value;
+        }
+
+        // These builtins are bitcast to i32, and `type` is updated to match.
+        type = ctx.dst->create<type::I32>();
+        return ctx.dst->Bitcast(CreateASTTypeFor(ctx, type), value);
+    }
+};
+
+// Runs the transform: strips shader IO attributes from struct members, then
+// rewrites every entry point through State::Process() before cloning the module.
+Transform::ApplyResult CanonicalizeEntryPointIO::Apply(const Program* src,
+                                                       const DataMap& inputs,
+                                                       DataMap&) const {
+    ProgramBuilder builder;
+    CloneContext ctx{&builder, src, /* auto_clone_symbols */ true};
+
+    // The transform cannot run without its configuration.
+    const auto* config = inputs.Get<Config>();
+    if (!config) {
+        builder.Diagnostics().add_error(
+            diag::System::Transform,
+            "missing transform data for " + std::string(TypeInfo().name));
+        return Program(std::move(builder));
+    }
+
+    // Remove entry point IO attributes from struct declarations.
+    // New structures will be created for each entry point, as necessary.
+    for (auto* type_decl : src->AST().TypeDecls()) {
+        auto* struct_decl = type_decl->As<Struct>();
+        if (!struct_decl) {
+            continue;
+        }
+        for (auto* member : struct_decl->members) {
+            for (auto* attr : member->attributes) {
+                if (IsShaderIOAttribute(attr)) {
+                    ctx.Remove(member->attributes, attr);
+                }
+            }
+        }
+    }
+
+    // Rewrite each entry point; all other functions are left alone.
+    for (auto* func_ast : src->AST().Functions()) {
+        if (func_ast->IsEntryPoint()) {
+            State entry_point_state(ctx, *config, func_ast);
+            entry_point_state.Process();
+        }
+    }
+
+    ctx.Clone();
+    return Program(std::move(builder));
+}
+
+// Stores the three configuration values exactly as supplied.
+CanonicalizeEntryPointIO::Config::Config(ShaderStyle style,
+ uint32_t sample_mask,
+ bool emit_point_size)
+ : shader_style(style),
+ fixed_sample_mask(sample_mask),
+ emit_vertex_point_size(emit_point_size) {}
+
+// Copy construction and destruction use the compiler-generated behavior.
+CanonicalizeEntryPointIO::Config::Config(const Config&) = default;
+CanonicalizeEntryPointIO::Config::~Config() = default;
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.h b/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.h
new file mode 100644
index 0000000..02d4acf
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.h
@@ -0,0 +1,141 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_CANONICALIZE_ENTRY_POINT_IO_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_CANONICALIZE_ENTRY_POINT_IO_H_
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+
+namespace tint::ast::transform {
+
+/// CanonicalizeEntryPointIO is a transform used to rewrite shader entry point
+/// interfaces into a form that the generators can handle. Each entry point
+/// function is stripped of all shader IO attributes and wrapped in a function
+/// that provides the shader interface.
+/// The transform config determines whether to use global variables, structures,
+/// or parameters for the shader inputs and outputs, and optionally adds
+/// additional builtins to the shader interface.
+///
+/// Before:
+/// ```
+/// struct Locations {
+/// @location(1) loc1 : f32,
+/// @location(2) loc2 : vec4<u32>,
+/// }
+///
+/// @fragment
+/// fn frag_main(@builtin(position) coord : vec4<f32>,
+/// locations : Locations) -> @location(0) f32 {
+/// if (coord.w > 1.0) {
+/// return 0.0;
+/// }
+/// var col : f32 = (coord.x * locations.loc1);
+/// return col;
+/// }
+/// ```
+///
+/// After (using structures for all parameters). The names of the generated
+/// structures and of the wrapper parameter are illustrative; the transform
+/// generates its own identifiers for them:
+/// ```
+/// struct Locations {
+/// loc1 : f32,
+/// loc2 : vec4<u32>,
+/// }
+///
+/// struct frag_main_in {
+/// @location(1) loc1 : f32,
+/// @location(2) loc2 : vec4<u32>,
+/// @builtin(position) coord : vec4<f32>,
+/// }
+///
+/// struct frag_main_out {
+/// @location(0) value : f32,
+/// }
+///
+/// fn frag_main_inner(coord : vec4<f32>,
+/// locations : Locations) -> f32 {
+/// if (coord.w > 1.0) {
+/// return 0.0;
+/// }
+/// var col : f32 = (coord.x * locations.loc1);
+/// return col;
+/// }
+///
+/// @fragment
+/// fn frag_main(in : frag_main_in) -> frag_main_out {
+/// let inner_result = frag_main_inner(in.coord, Locations(in.loc1, in.loc2));
+/// var wrapper_result : frag_main_out;
+/// wrapper_result.value = inner_result;
+/// return wrapper_result;
+/// }
+/// ```
+///
+/// @note Depends on the following transforms to have been run first:
+/// * Unshadow
+class CanonicalizeEntryPointIO final : public utils::Castable<CanonicalizeEntryPointIO, Transform> {
+ public:
+ /// ShaderStyle is an enumerator of different ways to emit shader IO.
+ enum class ShaderStyle {
+ /// Target SPIR-V (using global variables).
+ kSpirv,
+ /// Target GLSL (using global variables).
+ kGlsl,
+ /// Target MSL (using non-struct function parameters for builtins).
+ kMsl,
+ /// Target HLSL (using structures for all IO).
+ kHlsl,
+ };
+
+ /// Configuration options for the transform.
+ struct Config final : public utils::Castable<Config, Data> {
+ /// Constructor
+ /// @param style the approach to use for emitting shader IO.
+ /// @param sample_mask an optional sample mask to combine with shader masks.
+ /// The default of 0xFFFFFFFF means that no fixed sample mask is applied.
+ /// @param emit_vertex_point_size `true` to generate a pointsize builtin
+ explicit Config(ShaderStyle style,
+ uint32_t sample_mask = 0xFFFFFFFF,
+ bool emit_vertex_point_size = false);
+
+ /// Copy constructor
+ Config(const Config&);
+
+ /// Destructor
+ ~Config() override;
+
+ /// The approach to use for emitting shader IO.
+ const ShaderStyle shader_style;
+
+ /// A fixed sample mask to combine into masks produced by fragment shaders.
+ /// A value of 0xFFFFFFFF disables the fixed sample mask.
+ const uint32_t fixed_sample_mask;
+
+ /// Set to `true` to generate a pointsize builtin and have it set to 1.0
+ /// from all vertex shaders in the module.
+ const bool emit_vertex_point_size;
+ };
+
+ /// Constructor
+ CanonicalizeEntryPointIO();
+ /// Destructor
+ ~CanonicalizeEntryPointIO() override;
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+
+ private:
+ /// Per-entry-point transform state, defined in the implementation file.
+ struct State;
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_CANONICALIZE_ENTRY_POINT_IO_H_
diff --git a/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io_test.cc b/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io_test.cc
new file mode 100644
index 0000000..d969c98
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io_test.cc
@@ -0,0 +1,4175 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.h"
+
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+#include "src/tint/lang/wgsl/ast/transform/unshadow.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using CanonicalizeEntryPointIOTest = TransformTest;
+
+TEST_F(CanonicalizeEntryPointIOTest, Error_MissingTransformData) {
+ // Test that running the transform without a CanonicalizeEntryPointIO::Config in the
+ // data map reports an error.
+ auto* src = "";
+
+ auto* expect =
+ "error: missing transform data for tint::ast::transform::CanonicalizeEntryPointIO";
+
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, NoShaderIO) {
+ // Test that we do not introduce wrapper functions when there is no shader IO
+ // to process: with the MSL style, the module must come out unchanged.
+ auto* src = R"(
+@fragment
+fn frag_main() {
+}
+
+@compute @workgroup_size(1)
+fn comp_main() {
+}
+)";
+
+ auto* expect = src;
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, Parameters_Spirv) {
+ // Test that, with the SPIR-V style, each entry point parameter becomes a module-scope
+ // `var<__in>` that carries the parameter's IO attributes, and that the wrapper passes
+ // those variables to the inner function.
+ auto* src = R"(
+@fragment
+fn frag_main(@location(1) loc1 : f32,
+ @location(2) @interpolate(flat) loc2 : vec4<u32>,
+ @builtin(position) coord : vec4<f32>) {
+ var col : f32 = (coord.x * loc1);
+}
+)";
+
+ auto* expect = R"(
+@location(1) @internal(disable_validation__ignore_address_space) var<__in> loc1_1 : f32;
+
+@location(2) @interpolate(flat) @internal(disable_validation__ignore_address_space) var<__in> loc2_1 : vec4<u32>;
+
+@builtin(position) @internal(disable_validation__ignore_address_space) var<__in> coord_1 : vec4<f32>;
+
+fn frag_main_inner(loc1 : f32, loc2 : vec4<u32>, coord : vec4<f32>) {
+ var col : f32 = (coord.x * loc1);
+}
+
+@fragment
+fn frag_main() {
+ frag_main_inner(loc1_1, loc2_1, coord_1);
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kSpirv);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, Parameters_Msl) {
+ // Test that, with the MSL style, location inputs are gathered into a generated struct
+ // parameter while the position builtin stays a separate wrapper parameter.
+ auto* src = R"(
+@fragment
+fn frag_main(@location(1) loc1 : f32,
+ @location(2) @interpolate(flat) loc2 : vec4<u32>,
+ @builtin(position) coord : vec4<f32>) {
+ var col : f32 = (coord.x * loc1);
+}
+)";
+
+ auto* expect = R"(
+struct tint_symbol_1 {
+ @location(1)
+ loc1 : f32,
+ @location(2) @interpolate(flat)
+ loc2 : vec4<u32>,
+}
+
+fn frag_main_inner(loc1 : f32, loc2 : vec4<u32>, coord : vec4<f32>) {
+ var col : f32 = (coord.x * loc1);
+}
+
+@fragment
+fn frag_main(@builtin(position) coord : vec4<f32>, tint_symbol : tint_symbol_1) {
+ frag_main_inner(tint_symbol.loc1, tint_symbol.loc2, coord);
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, Parameters_Hlsl) {
+ // Test that, with the HLSL style, all inputs are gathered into a single generated struct
+ // parameter, with the location members placed before the builtin member.
+ auto* src = R"(
+@fragment
+fn frag_main(@location(1) loc1 : f32,
+ @location(2) @interpolate(flat) loc2 : vec4<u32>,
+ @builtin(position) coord : vec4<f32>) {
+ var col : f32 = (coord.x * loc1);
+}
+)";
+
+ auto* expect = R"(
+struct tint_symbol_1 {
+ @location(1)
+ loc1 : f32,
+ @location(2) @interpolate(flat)
+ loc2 : vec4<u32>,
+ @builtin(position)
+ coord : vec4<f32>,
+}
+
+fn frag_main_inner(loc1 : f32, loc2 : vec4<u32>, coord : vec4<f32>) {
+ var col : f32 = (coord.x * loc1);
+}
+
+@fragment
+fn frag_main(tint_symbol : tint_symbol_1) {
+ frag_main_inner(tint_symbol.loc1, tint_symbol.loc2, tint_symbol.coord);
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, Parameter_TypeAlias) {
+ // Test a parameter declared with a type alias: the generated struct member uses the
+ // aliased type (f32), while the inner function keeps using the alias.
+ auto* src = R"(
+alias myf32 = f32;
+
+@fragment
+fn frag_main(@location(1) loc1 : myf32) {
+ var x : myf32 = loc1;
+}
+)";
+
+ auto* expect = R"(
+alias myf32 = f32;
+
+struct tint_symbol_1 {
+ @location(1)
+ loc1 : f32,
+}
+
+fn frag_main_inner(loc1 : myf32) {
+ var x : myf32 = loc1;
+}
+
+@fragment
+fn frag_main(tint_symbol : tint_symbol_1) {
+ frag_main_inner(tint_symbol.loc1);
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, Parameter_TypeAlias_OutOfOrder) {
+ // Same as Parameter_TypeAlias, but with the alias declared after the entry point that
+ // uses it. The generated struct is still emitted before the functions and the alias
+ // stays where it was declared.
+ auto* src = R"(
+@fragment
+fn frag_main(@location(1) loc1 : myf32) {
+ var x : myf32 = loc1;
+}
+
+alias myf32 = f32;
+)";
+
+ auto* expect = R"(
+struct tint_symbol_1 {
+ @location(1)
+ loc1 : f32,
+}
+
+fn frag_main_inner(loc1 : myf32) {
+ var x : myf32 = loc1;
+}
+
+@fragment
+fn frag_main(tint_symbol : tint_symbol_1) {
+ frag_main_inner(tint_symbol.loc1);
+}
+
+alias myf32 = f32;
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, StructParameters_Spirv) {
+ // Test struct parameters with the SPIR-V style: every struct member becomes its own
+ // `var<__in>` global, the IO attributes are stripped from the original struct
+ // declarations, and the wrapper rebuilds the structs when calling the inner function.
+ auto* src = R"(
+struct FragBuiltins {
+ @builtin(position) coord : vec4<f32>,
+};
+struct FragLocations {
+ @location(1) loc1 : f32,
+ @location(2) @interpolate(flat) loc2 : vec4<u32>,
+};
+
+@fragment
+fn frag_main(@location(0) loc0 : f32,
+ locations : FragLocations,
+ builtins : FragBuiltins) {
+ var col : f32 = ((builtins.coord.x * locations.loc1) + loc0);
+}
+)";
+
+ auto* expect = R"(
+@location(0) @internal(disable_validation__ignore_address_space) var<__in> loc0_1 : f32;
+
+@location(1) @internal(disable_validation__ignore_address_space) var<__in> loc1_1 : f32;
+
+@location(2) @interpolate(flat) @internal(disable_validation__ignore_address_space) var<__in> loc2_1 : vec4<u32>;
+
+@builtin(position) @internal(disable_validation__ignore_address_space) var<__in> coord_1 : vec4<f32>;
+
+struct FragBuiltins {
+ coord : vec4<f32>,
+}
+
+struct FragLocations {
+ loc1 : f32,
+ loc2 : vec4<u32>,
+}
+
+fn frag_main_inner(loc0 : f32, locations : FragLocations, builtins : FragBuiltins) {
+ var col : f32 = ((builtins.coord.x * locations.loc1) + loc0);
+}
+
+@fragment
+fn frag_main() {
+ frag_main_inner(loc0_1, FragLocations(loc1_1, loc2_1), FragBuiltins(coord_1));
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kSpirv);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, StructParameters_Spirv_OutOfOrder) {
+ // Same as StructParameters_Spirv, but with the structs declared after the entry point
+ // that uses them. The struct declarations keep their position after the functions.
+ auto* src = R"(
+@fragment
+fn frag_main(@location(0) loc0 : f32,
+ locations : FragLocations,
+ builtins : FragBuiltins) {
+ var col : f32 = ((builtins.coord.x * locations.loc1) + loc0);
+}
+
+struct FragBuiltins {
+ @builtin(position) coord : vec4<f32>,
+};
+struct FragLocations {
+ @location(1) loc1 : f32,
+ @location(2) @interpolate(flat) loc2 : vec4<u32>,
+};
+)";
+
+ auto* expect = R"(
+@location(0) @internal(disable_validation__ignore_address_space) var<__in> loc0_1 : f32;
+
+@location(1) @internal(disable_validation__ignore_address_space) var<__in> loc1_1 : f32;
+
+@location(2) @interpolate(flat) @internal(disable_validation__ignore_address_space) var<__in> loc2_1 : vec4<u32>;
+
+@builtin(position) @internal(disable_validation__ignore_address_space) var<__in> coord_1 : vec4<f32>;
+
+fn frag_main_inner(loc0 : f32, locations : FragLocations, builtins : FragBuiltins) {
+ var col : f32 = ((builtins.coord.x * locations.loc1) + loc0);
+}
+
+@fragment
+fn frag_main() {
+ frag_main_inner(loc0_1, FragLocations(loc1_1, loc2_1), FragBuiltins(coord_1));
+}
+
+struct FragBuiltins {
+ coord : vec4<f32>,
+}
+
+struct FragLocations {
+ loc1 : f32,
+ loc2 : vec4<u32>,
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kSpirv);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, StructParameters_kMsl) {
+ // Test struct parameters with the MSL style: location members of all parameters are
+ // gathered into one generated struct parameter, the position builtin becomes a separate
+ // wrapper parameter, and the wrapper rebuilds the original structs for the inner call.
+ auto* src = R"(
+struct FragBuiltins {
+ @builtin(position) coord : vec4<f32>,
+};
+struct FragLocations {
+ @location(1) loc1 : f32,
+ @location(2) @interpolate(flat) loc2 : vec4<u32>,
+};
+
+@fragment
+fn frag_main(@location(0) loc0 : f32,
+ locations : FragLocations,
+ builtins : FragBuiltins) {
+ var col : f32 = ((builtins.coord.x * locations.loc1) + loc0);
+}
+)";
+
+ auto* expect = R"(
+struct FragBuiltins {
+ coord : vec4<f32>,
+}
+
+struct FragLocations {
+ loc1 : f32,
+ loc2 : vec4<u32>,
+}
+
+struct tint_symbol_1 {
+ @location(0)
+ loc0 : f32,
+ @location(1)
+ loc1 : f32,
+ @location(2) @interpolate(flat)
+ loc2 : vec4<u32>,
+}
+
+fn frag_main_inner(loc0 : f32, locations : FragLocations, builtins : FragBuiltins) {
+ var col : f32 = ((builtins.coord.x * locations.loc1) + loc0);
+}
+
+@fragment
+fn frag_main(@builtin(position) coord : vec4<f32>, tint_symbol : tint_symbol_1) {
+ frag_main_inner(tint_symbol.loc0, FragLocations(tint_symbol.loc1, tint_symbol.loc2), FragBuiltins(coord));
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, StructParameters_kMsl_OutOfOrder) {
+ auto* src = R"(
+@fragment
+fn frag_main(@location(0) loc0 : f32,
+ locations : FragLocations,
+ builtins : FragBuiltins) {
+ var col : f32 = ((builtins.coord.x * locations.loc1) + loc0);
+}
+
+struct FragBuiltins {
+ @builtin(position) coord : vec4<f32>,
+};
+struct FragLocations {
+ @location(1) loc1 : f32,
+ @location(2) @interpolate(flat) loc2 : vec4<u32>,
+};
+)";
+
+ auto* expect = R"(
+struct tint_symbol_1 {
+ @location(0)
+ loc0 : f32,
+ @location(1)
+ loc1 : f32,
+ @location(2) @interpolate(flat)
+ loc2 : vec4<u32>,
+}
+
+fn frag_main_inner(loc0 : f32, locations : FragLocations, builtins : FragBuiltins) {
+ var col : f32 = ((builtins.coord.x * locations.loc1) + loc0);
+}
+
+@fragment
+fn frag_main(@builtin(position) coord : vec4<f32>, tint_symbol : tint_symbol_1) {
+ frag_main_inner(tint_symbol.loc0, FragLocations(tint_symbol.loc1, tint_symbol.loc2), FragBuiltins(coord));
+}
+
+struct FragBuiltins {
+ coord : vec4<f32>,
+}
+
+struct FragLocations {
+ loc1 : f32,
+ loc2 : vec4<u32>,
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, StructParameters_Hlsl) {
+ auto* src = R"(
+struct FragBuiltins {
+ @builtin(position) coord : vec4<f32>,
+};
+struct FragLocations {
+ @location(1) loc1 : f32,
+ @location(2) @interpolate(flat) loc2 : vec4<u32>,
+};
+
+@fragment
+fn frag_main(@location(0) loc0 : f32,
+ locations : FragLocations,
+ builtins : FragBuiltins) {
+ var col : f32 = ((builtins.coord.x * locations.loc1) + loc0);
+}
+)";
+
+ auto* expect = R"(
+struct FragBuiltins {
+ coord : vec4<f32>,
+}
+
+struct FragLocations {
+ loc1 : f32,
+ loc2 : vec4<u32>,
+}
+
+struct tint_symbol_1 {
+ @location(0)
+ loc0 : f32,
+ @location(1)
+ loc1 : f32,
+ @location(2) @interpolate(flat)
+ loc2 : vec4<u32>,
+ @builtin(position)
+ coord : vec4<f32>,
+}
+
+fn frag_main_inner(loc0 : f32, locations : FragLocations, builtins : FragBuiltins) {
+ var col : f32 = ((builtins.coord.x * locations.loc1) + loc0);
+}
+
+@fragment
+fn frag_main(tint_symbol : tint_symbol_1) {
+ frag_main_inner(tint_symbol.loc0, FragLocations(tint_symbol.loc1, tint_symbol.loc2), FragBuiltins(tint_symbol.coord));
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, StructParameters_Hlsl_OutOfOrder) {
+ auto* src = R"(
+@fragment
+fn frag_main(@location(0) loc0 : f32,
+ locations : FragLocations,
+ builtins : FragBuiltins) {
+ var col : f32 = ((builtins.coord.x * locations.loc1) + loc0);
+}
+
+struct FragBuiltins {
+ @builtin(position) coord : vec4<f32>,
+};
+struct FragLocations {
+ @location(1) loc1 : f32,
+ @location(2) @interpolate(flat) loc2 : vec4<u32>,
+};
+)";
+
+ auto* expect = R"(
+struct tint_symbol_1 {
+ @location(0)
+ loc0 : f32,
+ @location(1)
+ loc1 : f32,
+ @location(2) @interpolate(flat)
+ loc2 : vec4<u32>,
+ @builtin(position)
+ coord : vec4<f32>,
+}
+
+fn frag_main_inner(loc0 : f32, locations : FragLocations, builtins : FragBuiltins) {
+ var col : f32 = ((builtins.coord.x * locations.loc1) + loc0);
+}
+
+@fragment
+fn frag_main(tint_symbol : tint_symbol_1) {
+ frag_main_inner(tint_symbol.loc0, FragLocations(tint_symbol.loc1, tint_symbol.loc2), FragBuiltins(tint_symbol.coord));
+}
+
+struct FragBuiltins {
+ coord : vec4<f32>,
+}
+
+struct FragLocations {
+ loc1 : f32,
+ loc2 : vec4<u32>,
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, Return_NonStruct_Spirv) {
+ auto* src = R"(
+@fragment
+fn frag_main() -> @builtin(frag_depth) f32 {
+ return 1.0;
+}
+)";
+
+ auto* expect = R"(
+@builtin(frag_depth) @internal(disable_validation__ignore_address_space) var<__out> value : f32;
+
+fn frag_main_inner() -> f32 {
+ return 1.0;
+}
+
+@fragment
+fn frag_main() {
+ let inner_result = frag_main_inner();
+ value = inner_result;
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kSpirv);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, Return_NonStruct_Msl) {
+ auto* src = R"(
+@fragment
+fn frag_main() -> @builtin(frag_depth) f32 {
+ return 1.0;
+}
+)";
+
+ auto* expect = R"(
+struct tint_symbol {
+ @builtin(frag_depth)
+ value : f32,
+}
+
+fn frag_main_inner() -> f32 {
+ return 1.0;
+}
+
+@fragment
+fn frag_main() -> tint_symbol {
+ let inner_result = frag_main_inner();
+ var wrapper_result : tint_symbol;
+ wrapper_result.value = inner_result;
+ return wrapper_result;
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, Return_NonStruct_Hlsl) {
+ auto* src = R"(
+@fragment
+fn frag_main() -> @builtin(frag_depth) f32 {
+ return 1.0;
+}
+)";
+
+ auto* expect = R"(
+struct tint_symbol {
+ @builtin(frag_depth)
+ value : f32,
+}
+
+fn frag_main_inner() -> f32 {
+ return 1.0;
+}
+
+@fragment
+fn frag_main() -> tint_symbol {
+ let inner_result = frag_main_inner();
+ var wrapper_result : tint_symbol;
+ wrapper_result.value = inner_result;
+ return wrapper_result;
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, Return_Struct_Spirv) {
+ auto* src = R"(
+struct FragOutput {
+ @location(0) color : vec4<f32>,
+ @builtin(frag_depth) depth : f32,
+ @builtin(sample_mask) mask : u32,
+};
+
+@fragment
+fn frag_main() -> FragOutput {
+ var output : FragOutput;
+ output.depth = 1.0;
+ output.mask = 7u;
+ output.color = vec4<f32>(0.5, 0.5, 0.5, 1.0);
+ return output;
+}
+)";
+
+ auto* expect = R"(
+@location(0) @internal(disable_validation__ignore_address_space) var<__out> color_1 : vec4<f32>;
+
+@builtin(frag_depth) @internal(disable_validation__ignore_address_space) var<__out> depth_1 : f32;
+
+@builtin(sample_mask) @internal(disable_validation__ignore_address_space) var<__out> mask_1 : array<u32, 1u>;
+
+struct FragOutput {
+ color : vec4<f32>,
+ depth : f32,
+ mask : u32,
+}
+
+fn frag_main_inner() -> FragOutput {
+ var output : FragOutput;
+ output.depth = 1.0;
+ output.mask = 7u;
+ output.color = vec4<f32>(0.5, 0.5, 0.5, 1.0);
+ return output;
+}
+
+@fragment
+fn frag_main() {
+ let inner_result = frag_main_inner();
+ color_1 = inner_result.color;
+ depth_1 = inner_result.depth;
+ mask_1[0i] = inner_result.mask;
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kSpirv);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, Return_Struct_Spirv_OutOfOrder) {
+ auto* src = R"(
+@fragment
+fn frag_main() -> FragOutput {
+ var output : FragOutput;
+ output.depth = 1.0;
+ output.mask = 7u;
+ output.color = vec4<f32>(0.5, 0.5, 0.5, 1.0);
+ return output;
+}
+
+struct FragOutput {
+ @location(0) color : vec4<f32>,
+ @builtin(frag_depth) depth : f32,
+ @builtin(sample_mask) mask : u32,
+};
+)";
+
+ auto* expect = R"(
+@location(0) @internal(disable_validation__ignore_address_space) var<__out> color_1 : vec4<f32>;
+
+@builtin(frag_depth) @internal(disable_validation__ignore_address_space) var<__out> depth_1 : f32;
+
+@builtin(sample_mask) @internal(disable_validation__ignore_address_space) var<__out> mask_1 : array<u32, 1u>;
+
+fn frag_main_inner() -> FragOutput {
+ var output : FragOutput;
+ output.depth = 1.0;
+ output.mask = 7u;
+ output.color = vec4<f32>(0.5, 0.5, 0.5, 1.0);
+ return output;
+}
+
+@fragment
+fn frag_main() {
+ let inner_result = frag_main_inner();
+ color_1 = inner_result.color;
+ depth_1 = inner_result.depth;
+ mask_1[0i] = inner_result.mask;
+}
+
+struct FragOutput {
+ color : vec4<f32>,
+ depth : f32,
+ mask : u32,
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kSpirv);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, Return_Struct_Msl) {
+ auto* src = R"(
+struct FragOutput {
+ @location(0) color : vec4<f32>,
+ @builtin(frag_depth) depth : f32,
+ @builtin(sample_mask) mask : u32,
+};
+
+@fragment
+fn frag_main() -> FragOutput {
+ var output : FragOutput;
+ output.depth = 1.0;
+ output.mask = 7u;
+ output.color = vec4<f32>(0.5, 0.5, 0.5, 1.0);
+ return output;
+}
+)";
+
+ auto* expect = R"(
+struct FragOutput {
+ color : vec4<f32>,
+ depth : f32,
+ mask : u32,
+}
+
+struct tint_symbol {
+ @location(0)
+ color : vec4<f32>,
+ @builtin(frag_depth)
+ depth : f32,
+ @builtin(sample_mask)
+ mask : u32,
+}
+
+fn frag_main_inner() -> FragOutput {
+ var output : FragOutput;
+ output.depth = 1.0;
+ output.mask = 7u;
+ output.color = vec4<f32>(0.5, 0.5, 0.5, 1.0);
+ return output;
+}
+
+@fragment
+fn frag_main() -> tint_symbol {
+ let inner_result = frag_main_inner();
+ var wrapper_result : tint_symbol;
+ wrapper_result.color = inner_result.color;
+ wrapper_result.depth = inner_result.depth;
+ wrapper_result.mask = inner_result.mask;
+ return wrapper_result;
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, Return_Struct_Msl_OutOfOrder) {
+ auto* src = R"(
+@fragment
+fn frag_main() -> FragOutput {
+ var output : FragOutput;
+ output.depth = 1.0;
+ output.mask = 7u;
+ output.color = vec4<f32>(0.5, 0.5, 0.5, 1.0);
+ return output;
+}
+
+struct FragOutput {
+ @location(0) color : vec4<f32>,
+ @builtin(frag_depth) depth : f32,
+ @builtin(sample_mask) mask : u32,
+};
+)";
+
+ auto* expect = R"(
+struct tint_symbol {
+ @location(0)
+ color : vec4<f32>,
+ @builtin(frag_depth)
+ depth : f32,
+ @builtin(sample_mask)
+ mask : u32,
+}
+
+fn frag_main_inner() -> FragOutput {
+ var output : FragOutput;
+ output.depth = 1.0;
+ output.mask = 7u;
+ output.color = vec4<f32>(0.5, 0.5, 0.5, 1.0);
+ return output;
+}
+
+@fragment
+fn frag_main() -> tint_symbol {
+ let inner_result = frag_main_inner();
+ var wrapper_result : tint_symbol;
+ wrapper_result.color = inner_result.color;
+ wrapper_result.depth = inner_result.depth;
+ wrapper_result.mask = inner_result.mask;
+ return wrapper_result;
+}
+
+struct FragOutput {
+ color : vec4<f32>,
+ depth : f32,
+ mask : u32,
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, Return_Struct_Hlsl) {
+ auto* src = R"(
+struct FragOutput {
+ @location(0) color : vec4<f32>,
+ @builtin(frag_depth) depth : f32,
+ @builtin(sample_mask) mask : u32,
+};
+
+@fragment
+fn frag_main() -> FragOutput {
+ var output : FragOutput;
+ output.depth = 1.0;
+ output.mask = 7u;
+ output.color = vec4<f32>(0.5, 0.5, 0.5, 1.0);
+ return output;
+}
+)";
+
+ auto* expect = R"(
+struct FragOutput {
+ color : vec4<f32>,
+ depth : f32,
+ mask : u32,
+}
+
+struct tint_symbol {
+ @location(0)
+ color : vec4<f32>,
+ @builtin(frag_depth)
+ depth : f32,
+ @builtin(sample_mask)
+ mask : u32,
+}
+
+fn frag_main_inner() -> FragOutput {
+ var output : FragOutput;
+ output.depth = 1.0;
+ output.mask = 7u;
+ output.color = vec4<f32>(0.5, 0.5, 0.5, 1.0);
+ return output;
+}
+
+@fragment
+fn frag_main() -> tint_symbol {
+ let inner_result = frag_main_inner();
+ var wrapper_result : tint_symbol;
+ wrapper_result.color = inner_result.color;
+ wrapper_result.depth = inner_result.depth;
+ wrapper_result.mask = inner_result.mask;
+ return wrapper_result;
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, Return_Struct_Hlsl_OutOfOrder) {
+ auto* src = R"(
+@fragment
+fn frag_main() -> FragOutput {
+ var output : FragOutput;
+ output.depth = 1.0;
+ output.mask = 7u;
+ output.color = vec4<f32>(0.5, 0.5, 0.5, 1.0);
+ return output;
+}
+
+struct FragOutput {
+ @location(0) color : vec4<f32>,
+ @builtin(frag_depth) depth : f32,
+ @builtin(sample_mask) mask : u32,
+};
+)";
+
+ auto* expect = R"(
+struct tint_symbol {
+ @location(0)
+ color : vec4<f32>,
+ @builtin(frag_depth)
+ depth : f32,
+ @builtin(sample_mask)
+ mask : u32,
+}
+
+fn frag_main_inner() -> FragOutput {
+ var output : FragOutput;
+ output.depth = 1.0;
+ output.mask = 7u;
+ output.color = vec4<f32>(0.5, 0.5, 0.5, 1.0);
+ return output;
+}
+
+@fragment
+fn frag_main() -> tint_symbol {
+ let inner_result = frag_main_inner();
+ var wrapper_result : tint_symbol;
+ wrapper_result.color = inner_result.color;
+ wrapper_result.depth = inner_result.depth;
+ wrapper_result.mask = inner_result.mask;
+ return wrapper_result;
+}
+
+struct FragOutput {
+ color : vec4<f32>,
+ depth : f32,
+ mask : u32,
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, StructParameters_SharedDeviceFunction_Spirv) {
+ auto* src = R"(
+struct FragmentInput {
+ @location(0) value : f32,
+ @location(1) mul : f32,
+};
+
+fn foo(x : FragmentInput) -> f32 {
+ return x.value * x.mul;
+}
+
+@fragment
+fn frag_main1(inputs : FragmentInput) {
+ var x : f32 = foo(inputs);
+}
+
+@fragment
+fn frag_main2(inputs : FragmentInput) {
+ var x : f32 = foo(inputs);
+}
+)";
+
+ auto* expect = R"(
+@location(0) @internal(disable_validation__ignore_address_space) var<__in> value_1 : f32;
+
+@location(1) @internal(disable_validation__ignore_address_space) var<__in> mul_1 : f32;
+
+@location(0) @internal(disable_validation__ignore_address_space) var<__in> value_2 : f32;
+
+@location(1) @internal(disable_validation__ignore_address_space) var<__in> mul_2 : f32;
+
+struct FragmentInput {
+ value : f32,
+ mul : f32,
+}
+
+fn foo(x : FragmentInput) -> f32 {
+ return (x.value * x.mul);
+}
+
+fn frag_main1_inner(inputs : FragmentInput) {
+ var x : f32 = foo(inputs);
+}
+
+@fragment
+fn frag_main1() {
+ frag_main1_inner(FragmentInput(value_1, mul_1));
+}
+
+fn frag_main2_inner(inputs : FragmentInput) {
+ var x : f32 = foo(inputs);
+}
+
+@fragment
+fn frag_main2() {
+ frag_main2_inner(FragmentInput(value_2, mul_2));
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kSpirv);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, StructParameters_SharedDeviceFunction_Spirv_OutOfOrder) {
+ auto* src = R"(
+@fragment
+fn frag_main1(inputs : FragmentInput) {
+ var x : f32 = foo(inputs);
+}
+
+@fragment
+fn frag_main2(inputs : FragmentInput) {
+ var x : f32 = foo(inputs);
+}
+
+fn foo(x : FragmentInput) -> f32 {
+ return x.value * x.mul;
+}
+
+struct FragmentInput {
+ @location(0) value : f32,
+ @location(1) mul : f32,
+};
+)";
+
+ auto* expect = R"(
+@location(0) @internal(disable_validation__ignore_address_space) var<__in> value_1 : f32;
+
+@location(1) @internal(disable_validation__ignore_address_space) var<__in> mul_1 : f32;
+
+@location(0) @internal(disable_validation__ignore_address_space) var<__in> value_2 : f32;
+
+@location(1) @internal(disable_validation__ignore_address_space) var<__in> mul_2 : f32;
+
+fn frag_main1_inner(inputs : FragmentInput) {
+ var x : f32 = foo(inputs);
+}
+
+@fragment
+fn frag_main1() {
+ frag_main1_inner(FragmentInput(value_1, mul_1));
+}
+
+fn frag_main2_inner(inputs : FragmentInput) {
+ var x : f32 = foo(inputs);
+}
+
+@fragment
+fn frag_main2() {
+ frag_main2_inner(FragmentInput(value_2, mul_2));
+}
+
+fn foo(x : FragmentInput) -> f32 {
+ return (x.value * x.mul);
+}
+
+struct FragmentInput {
+ value : f32,
+ mul : f32,
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kSpirv);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, StructParameters_SharedDeviceFunction_Msl) {
+ auto* src = R"(
+struct FragmentInput {
+ @location(0) value : f32,
+ @location(1) mul : f32,
+};
+
+fn foo(x : FragmentInput) -> f32 {
+ return x.value * x.mul;
+}
+
+@fragment
+fn frag_main1(inputs : FragmentInput) {
+ var x : f32 = foo(inputs);
+}
+
+@fragment
+fn frag_main2(inputs : FragmentInput) {
+ var x : f32 = foo(inputs);
+}
+)";
+
+ auto* expect = R"(
+struct FragmentInput {
+ value : f32,
+ mul : f32,
+}
+
+fn foo(x : FragmentInput) -> f32 {
+ return (x.value * x.mul);
+}
+
+struct tint_symbol_1 {
+ @location(0)
+ value : f32,
+ @location(1)
+ mul : f32,
+}
+
+fn frag_main1_inner(inputs : FragmentInput) {
+ var x : f32 = foo(inputs);
+}
+
+@fragment
+fn frag_main1(tint_symbol : tint_symbol_1) {
+ frag_main1_inner(FragmentInput(tint_symbol.value, tint_symbol.mul));
+}
+
+struct tint_symbol_3 {
+ @location(0)
+ value : f32,
+ @location(1)
+ mul : f32,
+}
+
+fn frag_main2_inner(inputs : FragmentInput) {
+ var x : f32 = foo(inputs);
+}
+
+@fragment
+fn frag_main2(tint_symbol_2 : tint_symbol_3) {
+ frag_main2_inner(FragmentInput(tint_symbol_2.value, tint_symbol_2.mul));
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, StructParameters_SharedDeviceFunction_Msl_OutOfOrder) {
+ auto* src = R"(
+@fragment
+fn frag_main1(inputs : FragmentInput) {
+ var x : f32 = foo(inputs);
+}
+
+@fragment
+fn frag_main2(inputs : FragmentInput) {
+ var x : f32 = foo(inputs);
+}
+
+fn foo(x : FragmentInput) -> f32 {
+ return x.value * x.mul;
+}
+
+struct FragmentInput {
+ @location(0) value : f32,
+ @location(1) mul : f32,
+};
+)";
+
+ auto* expect = R"(
+struct tint_symbol_1 {
+ @location(0)
+ value : f32,
+ @location(1)
+ mul : f32,
+}
+
+fn frag_main1_inner(inputs : FragmentInput) {
+ var x : f32 = foo(inputs);
+}
+
+@fragment
+fn frag_main1(tint_symbol : tint_symbol_1) {
+ frag_main1_inner(FragmentInput(tint_symbol.value, tint_symbol.mul));
+}
+
+struct tint_symbol_3 {
+ @location(0)
+ value : f32,
+ @location(1)
+ mul : f32,
+}
+
+fn frag_main2_inner(inputs : FragmentInput) {
+ var x : f32 = foo(inputs);
+}
+
+@fragment
+fn frag_main2(tint_symbol_2 : tint_symbol_3) {
+ frag_main2_inner(FragmentInput(tint_symbol_2.value, tint_symbol_2.mul));
+}
+
+fn foo(x : FragmentInput) -> f32 {
+ return (x.value * x.mul);
+}
+
+struct FragmentInput {
+ value : f32,
+ mul : f32,
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, StructParameters_SharedDeviceFunction_Hlsl) {
+ auto* src = R"(
+struct FragmentInput {
+ @location(0) value : f32,
+ @location(1) mul : f32,
+};
+
+fn foo(x : FragmentInput) -> f32 {
+ return x.value * x.mul;
+}
+
+@fragment
+fn frag_main1(inputs : FragmentInput) {
+ var x : f32 = foo(inputs);
+}
+
+@fragment
+fn frag_main2(inputs : FragmentInput) {
+ var x : f32 = foo(inputs);
+}
+)";
+
+ auto* expect = R"(
+struct FragmentInput {
+ value : f32,
+ mul : f32,
+}
+
+fn foo(x : FragmentInput) -> f32 {
+ return (x.value * x.mul);
+}
+
+struct tint_symbol_1 {
+ @location(0)
+ value : f32,
+ @location(1)
+ mul : f32,
+}
+
+fn frag_main1_inner(inputs : FragmentInput) {
+ var x : f32 = foo(inputs);
+}
+
+@fragment
+fn frag_main1(tint_symbol : tint_symbol_1) {
+ frag_main1_inner(FragmentInput(tint_symbol.value, tint_symbol.mul));
+}
+
+struct tint_symbol_3 {
+ @location(0)
+ value : f32,
+ @location(1)
+ mul : f32,
+}
+
+fn frag_main2_inner(inputs : FragmentInput) {
+ var x : f32 = foo(inputs);
+}
+
+@fragment
+fn frag_main2(tint_symbol_2 : tint_symbol_3) {
+ frag_main2_inner(FragmentInput(tint_symbol_2.value, tint_symbol_2.mul));
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, StructParameters_SharedDeviceFunction_Hlsl_OutOfOrder) {
+ auto* src = R"(
+@fragment
+fn frag_main1(inputs : FragmentInput) {
+ var x : f32 = foo(inputs);
+}
+
+@fragment
+fn frag_main2(inputs : FragmentInput) {
+ var x : f32 = foo(inputs);
+}
+
+fn foo(x : FragmentInput) -> f32 {
+ return x.value * x.mul;
+}
+
+struct FragmentInput {
+ @location(0) value : f32,
+ @location(1) mul : f32,
+};
+)";
+
+ auto* expect = R"(
+struct tint_symbol_1 {
+ @location(0)
+ value : f32,
+ @location(1)
+ mul : f32,
+}
+
+fn frag_main1_inner(inputs : FragmentInput) {
+ var x : f32 = foo(inputs);
+}
+
+@fragment
+fn frag_main1(tint_symbol : tint_symbol_1) {
+ frag_main1_inner(FragmentInput(tint_symbol.value, tint_symbol.mul));
+}
+
+struct tint_symbol_3 {
+ @location(0)
+ value : f32,
+ @location(1)
+ mul : f32,
+}
+
+fn frag_main2_inner(inputs : FragmentInput) {
+ var x : f32 = foo(inputs);
+}
+
+@fragment
+fn frag_main2(tint_symbol_2 : tint_symbol_3) {
+ frag_main2_inner(FragmentInput(tint_symbol_2.value, tint_symbol_2.mul));
+}
+
+fn foo(x : FragmentInput) -> f32 {
+ return (x.value * x.mul);
+}
+
+struct FragmentInput {
+ value : f32,
+ mul : f32,
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, Struct_ModuleScopeVariable) {
+ auto* src = R"(
+struct FragmentInput {
+ @location(0) col1 : f32,
+ @location(1) col2 : f32,
+};
+
+var<private> global_inputs : FragmentInput;
+
+fn foo() -> f32 {
+ return global_inputs.col1 * 0.5;
+}
+
+fn bar() -> f32 {
+ return global_inputs.col2 * 2.0;
+}
+
+@fragment
+fn frag_main1(inputs : FragmentInput) {
+ global_inputs = inputs;
+ var r : f32 = foo();
+ var g : f32 = bar();
+}
+)";
+
+ auto* expect = R"(
+struct FragmentInput {
+ col1 : f32,
+ col2 : f32,
+}
+
+var<private> global_inputs : FragmentInput;
+
+fn foo() -> f32 {
+ return (global_inputs.col1 * 0.5);
+}
+
+fn bar() -> f32 {
+ return (global_inputs.col2 * 2.0);
+}
+
+struct tint_symbol_1 {
+ @location(0)
+ col1 : f32,
+ @location(1)
+ col2 : f32,
+}
+
+fn frag_main1_inner(inputs : FragmentInput) {
+ global_inputs = inputs;
+ var r : f32 = foo();
+ var g : f32 = bar();
+}
+
+@fragment
+fn frag_main1(tint_symbol : tint_symbol_1) {
+ frag_main1_inner(FragmentInput(tint_symbol.col1, tint_symbol.col2));
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, Struct_ModuleScopeVariable_OutOfOrder) {
+ auto* src = R"(
+@fragment
+fn frag_main1(inputs : FragmentInput) {
+ global_inputs = inputs;
+ var r : f32 = foo();
+ var g : f32 = bar();
+}
+
+fn foo() -> f32 {
+ return global_inputs.col1 * 0.5;
+}
+
+fn bar() -> f32 {
+ return global_inputs.col2 * 2.0;
+}
+
+var<private> global_inputs : FragmentInput;
+
+struct FragmentInput {
+ @location(0) col1 : f32,
+ @location(1) col2 : f32,
+};
+)";
+
+ auto* expect = R"(
+struct tint_symbol_1 {
+ @location(0)
+ col1 : f32,
+ @location(1)
+ col2 : f32,
+}
+
+fn frag_main1_inner(inputs : FragmentInput) {
+ global_inputs = inputs;
+ var r : f32 = foo();
+ var g : f32 = bar();
+}
+
+@fragment
+fn frag_main1(tint_symbol : tint_symbol_1) {
+ frag_main1_inner(FragmentInput(tint_symbol.col1, tint_symbol.col2));
+}
+
+fn foo() -> f32 {
+ return (global_inputs.col1 * 0.5);
+}
+
+fn bar() -> f32 {
+ return (global_inputs.col2 * 2.0);
+}
+
+var<private> global_inputs : FragmentInput;
+
+struct FragmentInput {
+ col1 : f32,
+ col2 : f32,
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, Struct_TypeAliases) {
+ auto* src = R"(
+alias myf32 = f32;
+
+struct FragmentInput {
+ @location(0) col1 : myf32,
+ @location(1) col2 : myf32,
+};
+
+struct FragmentOutput {
+ @location(0) col1 : myf32,
+ @location(1) col2 : myf32,
+};
+
+alias MyFragmentInput = FragmentInput;
+
+alias MyFragmentOutput = FragmentOutput;
+
+fn foo(x : MyFragmentInput) -> myf32 {
+ return x.col1;
+}
+
+@fragment
+fn frag_main(inputs : MyFragmentInput) -> MyFragmentOutput {
+ var x : myf32 = foo(inputs);
+ return MyFragmentOutput(x, inputs.col2);
+}
+)";
+
+ auto* expect = R"(
+alias myf32 = f32;
+
+struct FragmentInput {
+ col1 : myf32,
+ col2 : myf32,
+}
+
+struct FragmentOutput {
+ col1 : myf32,
+ col2 : myf32,
+}
+
+alias MyFragmentInput = FragmentInput;
+
+alias MyFragmentOutput = FragmentOutput;
+
+fn foo(x : MyFragmentInput) -> myf32 {
+ return x.col1;
+}
+
+struct tint_symbol_1 {
+ @location(0)
+ col1 : f32,
+ @location(1)
+ col2 : f32,
+}
+
+struct tint_symbol_2 {
+ @location(0)
+ col1 : f32,
+ @location(1)
+ col2 : f32,
+}
+
+fn frag_main_inner(inputs : MyFragmentInput) -> MyFragmentOutput {
+ var x : myf32 = foo(inputs);
+ return MyFragmentOutput(x, inputs.col2);
+}
+
+@fragment
+fn frag_main(tint_symbol : tint_symbol_1) -> tint_symbol_2 {
+ let inner_result = frag_main_inner(MyFragmentInput(tint_symbol.col1, tint_symbol.col2));
+ var wrapper_result : tint_symbol_2;
+ wrapper_result.col1 = inner_result.col1;
+ wrapper_result.col2 = inner_result.col2;
+ return wrapper_result;
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, Struct_TypeAliases_OutOfOrder) {
+ auto* src = R"(
+@fragment
+fn frag_main(inputs : MyFragmentInput) -> MyFragmentOutput {
+ var x : myf32 = foo(inputs);
+ return MyFragmentOutput(x, inputs.col2);
+}
+
+alias MyFragmentInput = FragmentInput;
+
+alias MyFragmentOutput = FragmentOutput;
+
+fn foo(x : MyFragmentInput) -> myf32 {
+ return x.col1;
+}
+
+struct FragmentInput {
+ @location(0) col1 : myf32,
+ @location(1) col2 : myf32,
+};
+
+struct FragmentOutput {
+ @location(0) col1 : myf32,
+ @location(1) col2 : myf32,
+};
+
+alias myf32 = f32;
+)";
+
+ auto* expect = R"(
+struct tint_symbol_1 {
+ @location(0)
+ col1 : f32,
+ @location(1)
+ col2 : f32,
+}
+
+struct tint_symbol_2 {
+ @location(0)
+ col1 : f32,
+ @location(1)
+ col2 : f32,
+}
+
+fn frag_main_inner(inputs : MyFragmentInput) -> MyFragmentOutput {
+ var x : myf32 = foo(inputs);
+ return MyFragmentOutput(x, inputs.col2);
+}
+
+@fragment
+fn frag_main(tint_symbol : tint_symbol_1) -> tint_symbol_2 {
+ let inner_result = frag_main_inner(MyFragmentInput(tint_symbol.col1, tint_symbol.col2));
+ var wrapper_result : tint_symbol_2;
+ wrapper_result.col1 = inner_result.col1;
+ wrapper_result.col2 = inner_result.col2;
+ return wrapper_result;
+}
+
+alias MyFragmentInput = FragmentInput;
+
+alias MyFragmentOutput = FragmentOutput;
+
+fn foo(x : MyFragmentInput) -> myf32 {
+ return x.col1;
+}
+
+struct FragmentInput {
+ col1 : myf32,
+ col2 : myf32,
+}
+
+struct FragmentOutput {
+ col1 : myf32,
+ col2 : myf32,
+}
+
+alias myf32 = f32;
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, InterpolateAttributes) {
+ auto* src = R"(
+struct VertexOut {
+ @builtin(position) pos : vec4<f32>,
+ @location(1) @interpolate(flat) loc1 : f32,
+ @location(2) @interpolate(linear, sample) loc2 : f32,
+ @location(3) @interpolate(perspective, centroid) loc3 : f32,
+};
+
+struct FragmentIn {
+ @location(1) @interpolate(flat) loc1 : f32,
+ @location(2) @interpolate(linear, sample) loc2 : f32,
+};
+
+@vertex
+fn vert_main() -> VertexOut {
+ return VertexOut();
+}
+
+@fragment
+fn frag_main(inputs : FragmentIn,
+ @location(3) @interpolate(perspective, centroid) loc3 : f32) {
+ let x = inputs.loc1 + inputs.loc2 + loc3;
+}
+)";
+
+ auto* expect = R"(
+struct VertexOut {
+ pos : vec4<f32>,
+ loc1 : f32,
+ loc2 : f32,
+ loc3 : f32,
+}
+
+struct FragmentIn {
+ loc1 : f32,
+ loc2 : f32,
+}
+
+struct tint_symbol {
+ @location(1) @interpolate(flat)
+ loc1 : f32,
+ @location(2) @interpolate(linear, sample)
+ loc2 : f32,
+ @location(3) @interpolate(perspective, centroid)
+ loc3 : f32,
+ @builtin(position)
+ pos : vec4<f32>,
+}
+
+fn vert_main_inner() -> VertexOut {
+ return VertexOut();
+}
+
+@vertex
+fn vert_main() -> tint_symbol {
+ let inner_result = vert_main_inner();
+ var wrapper_result : tint_symbol;
+ wrapper_result.pos = inner_result.pos;
+ wrapper_result.loc1 = inner_result.loc1;
+ wrapper_result.loc2 = inner_result.loc2;
+ wrapper_result.loc3 = inner_result.loc3;
+ return wrapper_result;
+}
+
+struct tint_symbol_2 {
+ @location(1) @interpolate(flat)
+ loc1 : f32,
+ @location(2) @interpolate(linear, sample)
+ loc2 : f32,
+ @location(3) @interpolate(perspective, centroid)
+ loc3 : f32,
+}
+
+fn frag_main_inner(inputs : FragmentIn, loc3 : f32) {
+ let x = ((inputs.loc1 + inputs.loc2) + loc3);
+}
+
+@fragment
+fn frag_main(tint_symbol_1 : tint_symbol_2) {
+ frag_main_inner(FragmentIn(tint_symbol_1.loc1, tint_symbol_1.loc2), tint_symbol_1.loc3);
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, InterpolateAttributes_OutOfOrder) {  // HLSL style: @interpolate attributes end up on the generated wrapper structs when the entry points are declared before the IO structs they use.
+ auto* src = R"(
+@fragment
+fn frag_main(inputs : FragmentIn,
+ @location(3) @interpolate(perspective, centroid) loc3 : f32) {
+ let x = inputs.loc1 + inputs.loc2 + loc3;
+}
+
+@vertex
+fn vert_main() -> VertexOut {
+ return VertexOut();
+}
+
+struct VertexOut {
+ @builtin(position) pos : vec4<f32>,
+ @location(1) @interpolate(flat) loc1 : f32,
+ @location(2) @interpolate(linear, sample) loc2 : f32,
+ @location(3) @interpolate(perspective, centroid) loc3 : f32,
+};
+
+struct FragmentIn {
+ @location(1) @interpolate(flat) loc1: f32,
+ @location(2) @interpolate(linear, sample) loc2 : f32,
+};
+)";
+
+ auto* expect = R"(
+struct tint_symbol_1 {
+ @location(1) @interpolate(flat)
+ loc1 : f32,
+ @location(2) @interpolate(linear, sample)
+ loc2 : f32,
+ @location(3) @interpolate(perspective, centroid)
+ loc3 : f32,
+}
+
+fn frag_main_inner(inputs : FragmentIn, loc3 : f32) {
+ let x = ((inputs.loc1 + inputs.loc2) + loc3);
+}
+
+@fragment
+fn frag_main(tint_symbol : tint_symbol_1) {
+ frag_main_inner(FragmentIn(tint_symbol.loc1, tint_symbol.loc2), tint_symbol.loc3);
+}
+
+struct tint_symbol_2 {
+ @location(1) @interpolate(flat)
+ loc1 : f32,
+ @location(2) @interpolate(linear, sample)
+ loc2 : f32,
+ @location(3) @interpolate(perspective, centroid)
+ loc3 : f32,
+ @builtin(position)
+ pos : vec4<f32>,
+}
+
+fn vert_main_inner() -> VertexOut {
+ return VertexOut();
+}
+
+@vertex
+fn vert_main() -> tint_symbol_2 {
+ let inner_result = vert_main_inner();
+ var wrapper_result : tint_symbol_2;
+ wrapper_result.pos = inner_result.pos;
+ wrapper_result.loc1 = inner_result.loc1;
+ wrapper_result.loc2 = inner_result.loc2;
+ wrapper_result.loc3 = inner_result.loc3;
+ return wrapper_result;
+}
+
+struct VertexOut {
+ pos : vec4<f32>,
+ loc1 : f32,
+ loc2 : f32,
+ loc3 : f32,
+}
+
+struct FragmentIn {
+ loc1 : f32,
+ loc2 : f32,
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the golden WGSL above.
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, InterpolateAttributes_Integers_Spirv) {  // SPIR-V style: integer IO is lowered to module-scope var<__in> / var<__out> declarations.
+ // Test that @interpolate(flat) on integer IO is emitted for vertex outputs and
+ // fragment inputs, but not for vertex inputs or fragment outputs.
+ auto* src = R"(
+struct VertexIn {
+ @location(0) i : i32,
+ @location(1) u : u32,
+ @location(2) vi : vec4<i32>,
+ @location(3) vu : vec4<u32>,
+};
+
+struct VertexOut {
+ @location(0) @interpolate(flat) i : i32,
+ @location(1) @interpolate(flat) u : u32,
+ @location(2) @interpolate(flat) vi : vec4<i32>,
+ @location(3) @interpolate(flat) vu : vec4<u32>,
+ @builtin(position) pos : vec4<f32>,
+};
+
+struct FragmentInterface {
+ @location(0) @interpolate(flat) i : i32,
+ @location(1) @interpolate(flat) u : u32,
+ @location(2) @interpolate(flat) vi : vec4<i32>,
+ @location(3) @interpolate(flat) vu : vec4<u32>,
+};
+
+@vertex
+fn vert_main(in : VertexIn) -> VertexOut {
+ return VertexOut(in.i, in.u, in.vi, in.vu, vec4<f32>());
+}
+
+@fragment
+fn frag_main(inputs : FragmentInterface) -> FragmentInterface {
+ return inputs;
+}
+)";
+
+ auto* expect =
+ R"(
+@location(0) @internal(disable_validation__ignore_address_space) var<__in> i_1 : i32;
+
+@location(1) @internal(disable_validation__ignore_address_space) var<__in> u_1 : u32;
+
+@location(2) @internal(disable_validation__ignore_address_space) var<__in> vi_1 : vec4<i32>;
+
+@location(3) @internal(disable_validation__ignore_address_space) var<__in> vu_1 : vec4<u32>;
+
+@location(0) @interpolate(flat) @internal(disable_validation__ignore_address_space) var<__out> i_2 : i32;
+
+@location(1) @interpolate(flat) @internal(disable_validation__ignore_address_space) var<__out> u_2 : u32;
+
+@location(2) @interpolate(flat) @internal(disable_validation__ignore_address_space) var<__out> vi_2 : vec4<i32>;
+
+@location(3) @interpolate(flat) @internal(disable_validation__ignore_address_space) var<__out> vu_2 : vec4<u32>;
+
+@builtin(position) @internal(disable_validation__ignore_address_space) var<__out> pos_1 : vec4<f32>;
+
+@location(0) @interpolate(flat) @internal(disable_validation__ignore_address_space) var<__in> i_3 : i32;
+
+@location(1) @interpolate(flat) @internal(disable_validation__ignore_address_space) var<__in> u_3 : u32;
+
+@location(2) @interpolate(flat) @internal(disable_validation__ignore_address_space) var<__in> vi_3 : vec4<i32>;
+
+@location(3) @interpolate(flat) @internal(disable_validation__ignore_address_space) var<__in> vu_3 : vec4<u32>;
+
+@location(0) @internal(disable_validation__ignore_address_space) var<__out> i_4 : i32;
+
+@location(1) @internal(disable_validation__ignore_address_space) var<__out> u_4 : u32;
+
+@location(2) @internal(disable_validation__ignore_address_space) var<__out> vi_4 : vec4<i32>;
+
+@location(3) @internal(disable_validation__ignore_address_space) var<__out> vu_4 : vec4<u32>;
+
+struct VertexIn {
+ i : i32,
+ u : u32,
+ vi : vec4<i32>,
+ vu : vec4<u32>,
+}
+
+struct VertexOut {
+ i : i32,
+ u : u32,
+ vi : vec4<i32>,
+ vu : vec4<u32>,
+ pos : vec4<f32>,
+}
+
+struct FragmentInterface {
+ i : i32,
+ u : u32,
+ vi : vec4<i32>,
+ vu : vec4<u32>,
+}
+
+fn vert_main_inner(in : VertexIn) -> VertexOut {
+ return VertexOut(in.i, in.u, in.vi, in.vu, vec4<f32>());
+}
+
+@vertex
+fn vert_main() {
+ let inner_result = vert_main_inner(VertexIn(i_1, u_1, vi_1, vu_1));
+ i_2 = inner_result.i;
+ u_2 = inner_result.u;
+ vi_2 = inner_result.vi;
+ vu_2 = inner_result.vu;
+ pos_1 = inner_result.pos;
+}
+
+fn frag_main_inner(inputs : FragmentInterface) -> FragmentInterface {
+ return inputs;
+}
+
+@fragment
+fn frag_main() {
+ let inner_result_1 = frag_main_inner(FragmentInterface(i_3, u_3, vi_3, vu_3));
+ i_4 = inner_result_1.i;
+ u_4 = inner_result_1.u;
+ vi_4 = inner_result_1.vi;
+ vu_4 = inner_result_1.vu;
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kSpirv);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the golden WGSL above.
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, InterpolateAttributes_Integers_Spirv_OutOfOrder) {  // SPIR-V style, with the entry points declared before the IO structs they use.
+ // Test that @interpolate(flat) on integer IO is emitted for vertex outputs and
+ // fragment inputs, but not for vertex inputs or fragment outputs.
+ auto* src = R"(
+@vertex
+fn vert_main(in : VertexIn) -> VertexOut {
+ return VertexOut(in.i, in.u, in.vi, in.vu, vec4<f32>());
+}
+
+@fragment
+fn frag_main(inputs : FragmentInterface) -> FragmentInterface {
+ return inputs;
+}
+
+struct VertexIn {
+ @location(0) i : i32,
+ @location(1) u : u32,
+ @location(2) vi : vec4<i32>,
+ @location(3) vu : vec4<u32>,
+};
+
+struct VertexOut {
+ @location(0) @interpolate(flat) i : i32,
+ @location(1) @interpolate(flat) u : u32,
+ @location(2) @interpolate(flat) vi : vec4<i32>,
+ @location(3) @interpolate(flat) vu : vec4<u32>,
+ @builtin(position) pos : vec4<f32>,
+};
+
+struct FragmentInterface {
+ @location(0) @interpolate(flat) i : i32,
+ @location(1) @interpolate(flat) u : u32,
+ @location(2) @interpolate(flat) vi : vec4<i32>,
+ @location(3) @interpolate(flat) vu : vec4<u32>,
+};
+)";
+
+ auto* expect =
+ R"(
+@location(0) @internal(disable_validation__ignore_address_space) var<__in> i_1 : i32;
+
+@location(1) @internal(disable_validation__ignore_address_space) var<__in> u_1 : u32;
+
+@location(2) @internal(disable_validation__ignore_address_space) var<__in> vi_1 : vec4<i32>;
+
+@location(3) @internal(disable_validation__ignore_address_space) var<__in> vu_1 : vec4<u32>;
+
+@location(0) @interpolate(flat) @internal(disable_validation__ignore_address_space) var<__out> i_2 : i32;
+
+@location(1) @interpolate(flat) @internal(disable_validation__ignore_address_space) var<__out> u_2 : u32;
+
+@location(2) @interpolate(flat) @internal(disable_validation__ignore_address_space) var<__out> vi_2 : vec4<i32>;
+
+@location(3) @interpolate(flat) @internal(disable_validation__ignore_address_space) var<__out> vu_2 : vec4<u32>;
+
+@builtin(position) @internal(disable_validation__ignore_address_space) var<__out> pos_1 : vec4<f32>;
+
+@location(0) @interpolate(flat) @internal(disable_validation__ignore_address_space) var<__in> i_3 : i32;
+
+@location(1) @interpolate(flat) @internal(disable_validation__ignore_address_space) var<__in> u_3 : u32;
+
+@location(2) @interpolate(flat) @internal(disable_validation__ignore_address_space) var<__in> vi_3 : vec4<i32>;
+
+@location(3) @interpolate(flat) @internal(disable_validation__ignore_address_space) var<__in> vu_3 : vec4<u32>;
+
+@location(0) @internal(disable_validation__ignore_address_space) var<__out> i_4 : i32;
+
+@location(1) @internal(disable_validation__ignore_address_space) var<__out> u_4 : u32;
+
+@location(2) @internal(disable_validation__ignore_address_space) var<__out> vi_4 : vec4<i32>;
+
+@location(3) @internal(disable_validation__ignore_address_space) var<__out> vu_4 : vec4<u32>;
+
+fn vert_main_inner(in : VertexIn) -> VertexOut {
+ return VertexOut(in.i, in.u, in.vi, in.vu, vec4<f32>());
+}
+
+@vertex
+fn vert_main() {
+ let inner_result = vert_main_inner(VertexIn(i_1, u_1, vi_1, vu_1));
+ i_2 = inner_result.i;
+ u_2 = inner_result.u;
+ vi_2 = inner_result.vi;
+ vu_2 = inner_result.vu;
+ pos_1 = inner_result.pos;
+}
+
+fn frag_main_inner(inputs : FragmentInterface) -> FragmentInterface {
+ return inputs;
+}
+
+@fragment
+fn frag_main() {
+ let inner_result_1 = frag_main_inner(FragmentInterface(i_3, u_3, vi_3, vu_3));
+ i_4 = inner_result_1.i;
+ u_4 = inner_result_1.u;
+ vi_4 = inner_result_1.vi;
+ vu_4 = inner_result_1.vu;
+}
+
+struct VertexIn {
+ i : i32,
+ u : u32,
+ vi : vec4<i32>,
+ vu : vec4<u32>,
+}
+
+struct VertexOut {
+ i : i32,
+ u : u32,
+ vi : vec4<i32>,
+ vu : vec4<u32>,
+ pos : vec4<f32>,
+}
+
+struct FragmentInterface {
+ i : i32,
+ u : u32,
+ vi : vec4<i32>,
+ vu : vec4<u32>,
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kSpirv);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the golden WGSL above.
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, InvariantAttributes) {  // HLSL style: @invariant stays paired with @builtin(position) on the wrapper struct member, for both a struct return and a bare vec4 return.
+ auto* src = R"(
+struct VertexOut {
+ @builtin(position) @invariant pos : vec4<f32>,
+};
+
+@vertex
+fn main1() -> VertexOut {
+ return VertexOut();
+}
+
+@vertex
+fn main2() -> @builtin(position) @invariant vec4<f32> {
+ return vec4<f32>();
+}
+)";
+
+ auto* expect = R"(
+struct VertexOut {
+ pos : vec4<f32>,
+}
+
+struct tint_symbol {
+ @builtin(position) @invariant
+ pos : vec4<f32>,
+}
+
+fn main1_inner() -> VertexOut {
+ return VertexOut();
+}
+
+@vertex
+fn main1() -> tint_symbol {
+ let inner_result = main1_inner();
+ var wrapper_result : tint_symbol;
+ wrapper_result.pos = inner_result.pos;
+ return wrapper_result;
+}
+
+struct tint_symbol_1 {
+ @builtin(position) @invariant
+ value : vec4<f32>,
+}
+
+fn main2_inner() -> vec4<f32> {
+ return vec4<f32>();
+}
+
+@vertex
+fn main2() -> tint_symbol_1 {
+ let inner_result_1 = main2_inner();
+ var wrapper_result_1 : tint_symbol_1;
+ wrapper_result_1.value = inner_result_1;
+ return wrapper_result_1;
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the golden WGSL above.
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, InvariantAttributes_OutOfOrder) {  // HLSL style: @invariant handling when the entry points are declared before the VertexOut struct.
+ auto* src = R"(
+@vertex
+fn main1() -> VertexOut {
+ return VertexOut();
+}
+
+@vertex
+fn main2() -> @builtin(position) @invariant vec4<f32> {
+ return vec4<f32>();
+}
+
+struct VertexOut {
+ @builtin(position) @invariant pos : vec4<f32>,
+};
+)";
+
+ auto* expect = R"(
+struct tint_symbol {
+ @builtin(position) @invariant
+ pos : vec4<f32>,
+}
+
+fn main1_inner() -> VertexOut {
+ return VertexOut();
+}
+
+@vertex
+fn main1() -> tint_symbol {
+ let inner_result = main1_inner();
+ var wrapper_result : tint_symbol;
+ wrapper_result.pos = inner_result.pos;
+ return wrapper_result;
+}
+
+struct tint_symbol_1 {
+ @builtin(position) @invariant
+ value : vec4<f32>,
+}
+
+fn main2_inner() -> vec4<f32> {
+ return vec4<f32>();
+}
+
+@vertex
+fn main2() -> tint_symbol_1 {
+ let inner_result_1 = main2_inner();
+ var wrapper_result_1 : tint_symbol_1;
+ wrapper_result_1.value = inner_result_1;
+ return wrapper_result_1;
+}
+
+struct VertexOut {
+ pos : vec4<f32>,
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the golden WGSL above.
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, Struct_LayoutAttributes) {  // HLSL style: @size/@align stay on the original structs while IO attributes move to the wrapper structs; the fragment output's @interpolate(flat) does not appear on the wrapper.
+ auto* src = R"(
+struct FragmentInput {
+ @size(16) @location(1) value : f32,
+ @builtin(position) @align(32) coord : vec4<f32>,
+ @location(0) @interpolate(linear, sample) @align(128) loc0 : f32,
+};
+
+struct FragmentOutput {
+ @size(16) @location(1) @interpolate(flat) value : f32,
+};
+
+@fragment
+fn frag_main(inputs : FragmentInput) -> FragmentOutput {
+ return FragmentOutput(inputs.coord.x * inputs.value + inputs.loc0);
+}
+)";
+
+ auto* expect = R"(
+struct FragmentInput {
+ @size(16)
+ value : f32,
+ @align(32)
+ coord : vec4<f32>,
+ @align(128)
+ loc0 : f32,
+}
+
+struct FragmentOutput {
+ @size(16)
+ value : f32,
+}
+
+struct tint_symbol_1 {
+ @location(0) @interpolate(linear, sample)
+ loc0 : f32,
+ @location(1)
+ value : f32,
+ @builtin(position)
+ coord : vec4<f32>,
+}
+
+struct tint_symbol_2 {
+ @location(1)
+ value : f32,
+}
+
+fn frag_main_inner(inputs : FragmentInput) -> FragmentOutput {
+ return FragmentOutput(((inputs.coord.x * inputs.value) + inputs.loc0));
+}
+
+@fragment
+fn frag_main(tint_symbol : tint_symbol_1) -> tint_symbol_2 {
+ let inner_result = frag_main_inner(FragmentInput(tint_symbol.value, tint_symbol.coord, tint_symbol.loc0));
+ var wrapper_result : tint_symbol_2;
+ wrapper_result.value = inner_result.value;
+ return wrapper_result;
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the golden WGSL above.
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, Struct_LayoutAttributes_OutOfOrder) {  // HLSL style: layout attributes (@size/@align) are kept on the original structs when the entry point is declared before them.
+ auto* src = R"(
+@fragment
+fn frag_main(inputs : FragmentInput) -> FragmentOutput {
+ return FragmentOutput(inputs.coord.x * inputs.value + inputs.loc0);
+}
+
+struct FragmentInput {
+ @size(16) @location(1) value : f32,
+ @builtin(position) @align(32) coord : vec4<f32>,
+ @location(0) @interpolate(linear, sample) @align(128) loc0 : f32,
+};
+
+struct FragmentOutput {
+ @size(16) @location(1) @interpolate(flat) value : f32,
+};
+)";
+
+ auto* expect = R"(
+struct tint_symbol_1 {
+ @location(0) @interpolate(linear, sample)
+ loc0 : f32,
+ @location(1)
+ value : f32,
+ @builtin(position)
+ coord : vec4<f32>,
+}
+
+struct tint_symbol_2 {
+ @location(1)
+ value : f32,
+}
+
+fn frag_main_inner(inputs : FragmentInput) -> FragmentOutput {
+ return FragmentOutput(((inputs.coord.x * inputs.value) + inputs.loc0));
+}
+
+@fragment
+fn frag_main(tint_symbol : tint_symbol_1) -> tint_symbol_2 {
+ let inner_result = frag_main_inner(FragmentInput(tint_symbol.value, tint_symbol.coord, tint_symbol.loc0));
+ var wrapper_result : tint_symbol_2;
+ wrapper_result.value = inner_result.value;
+ return wrapper_result;
+}
+
+struct FragmentInput {
+ @size(16)
+ value : f32,
+ @align(32)
+ coord : vec4<f32>,
+ @align(128)
+ loc0 : f32,
+}
+
+struct FragmentOutput {
+ @size(16)
+ value : f32,
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the golden WGSL above.
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, SortedMembers) {  // HLSL style: wrapper struct members are emitted with locations in ascending order followed by builtins, regardless of declaration order in the source.
+ auto* src = R"(
+struct VertexOutput {
+ @location(1) @interpolate(flat) b : u32,
+ @builtin(position) pos : vec4<f32>,
+ @location(3) @interpolate(flat) d : u32,
+ @location(0) a : f32,
+ @location(2) @interpolate(flat) c : i32,
+};
+
+struct FragmentInputExtra {
+ @location(3) @interpolate(flat) d : u32,
+ @builtin(position) pos : vec4<f32>,
+ @location(0) a : f32,
+};
+
+@vertex
+fn vert_main() -> VertexOutput {
+ return VertexOutput();
+}
+
+@fragment
+fn frag_main(@builtin(front_facing) ff : bool,
+ @location(2) @interpolate(flat) c : i32,
+ inputs : FragmentInputExtra,
+ @location(1) @interpolate(flat) b : u32) {
+}
+)";
+
+ auto* expect = R"(
+struct VertexOutput {
+ b : u32,
+ pos : vec4<f32>,
+ d : u32,
+ a : f32,
+ c : i32,
+}
+
+struct FragmentInputExtra {
+ d : u32,
+ pos : vec4<f32>,
+ a : f32,
+}
+
+struct tint_symbol {
+ @location(0)
+ a : f32,
+ @location(1) @interpolate(flat)
+ b : u32,
+ @location(2) @interpolate(flat)
+ c : i32,
+ @location(3) @interpolate(flat)
+ d : u32,
+ @builtin(position)
+ pos : vec4<f32>,
+}
+
+fn vert_main_inner() -> VertexOutput {
+ return VertexOutput();
+}
+
+@vertex
+fn vert_main() -> tint_symbol {
+ let inner_result = vert_main_inner();
+ var wrapper_result : tint_symbol;
+ wrapper_result.b = inner_result.b;
+ wrapper_result.pos = inner_result.pos;
+ wrapper_result.d = inner_result.d;
+ wrapper_result.a = inner_result.a;
+ wrapper_result.c = inner_result.c;
+ return wrapper_result;
+}
+
+struct tint_symbol_2 {
+ @location(0)
+ a : f32,
+ @location(1) @interpolate(flat)
+ b : u32,
+ @location(2) @interpolate(flat)
+ c : i32,
+ @location(3) @interpolate(flat)
+ d : u32,
+ @builtin(position)
+ pos : vec4<f32>,
+ @builtin(front_facing)
+ ff : bool,
+}
+
+fn frag_main_inner(ff : bool, c : i32, inputs : FragmentInputExtra, b : u32) {
+}
+
+@fragment
+fn frag_main(tint_symbol_1 : tint_symbol_2) {
+ frag_main_inner(tint_symbol_1.ff, tint_symbol_1.c, FragmentInputExtra(tint_symbol_1.d, tint_symbol_1.pos, tint_symbol_1.a), tint_symbol_1.b);
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the golden WGSL above.
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, SortedMembers_OutOfOrder) {  // HLSL style: wrapper struct member ordering (ascending locations, then builtins) when the entry points are declared before the IO structs.
+ auto* src = R"(
+@vertex
+fn vert_main() -> VertexOutput {
+ return VertexOutput();
+}
+
+@fragment
+fn frag_main(@builtin(front_facing) ff : bool,
+ @location(2) @interpolate(flat) c : i32,
+ inputs : FragmentInputExtra,
+ @location(1) @interpolate(flat) b : u32) {
+}
+
+struct VertexOutput {
+ @location(1) @interpolate(flat) b : u32,
+ @builtin(position) pos : vec4<f32>,
+ @location(3) @interpolate(flat) d : u32,
+ @location(0) a : f32,
+ @location(2) @interpolate(flat) c : i32,
+};
+
+struct FragmentInputExtra {
+ @location(3) @interpolate(flat) d : u32,
+ @builtin(position) pos : vec4<f32>,
+ @location(0) a : f32,
+};
+)";
+
+ auto* expect = R"(
+struct tint_symbol {
+ @location(0)
+ a : f32,
+ @location(1) @interpolate(flat)
+ b : u32,
+ @location(2) @interpolate(flat)
+ c : i32,
+ @location(3) @interpolate(flat)
+ d : u32,
+ @builtin(position)
+ pos : vec4<f32>,
+}
+
+fn vert_main_inner() -> VertexOutput {
+ return VertexOutput();
+}
+
+@vertex
+fn vert_main() -> tint_symbol {
+ let inner_result = vert_main_inner();
+ var wrapper_result : tint_symbol;
+ wrapper_result.b = inner_result.b;
+ wrapper_result.pos = inner_result.pos;
+ wrapper_result.d = inner_result.d;
+ wrapper_result.a = inner_result.a;
+ wrapper_result.c = inner_result.c;
+ return wrapper_result;
+}
+
+struct tint_symbol_2 {
+ @location(0)
+ a : f32,
+ @location(1) @interpolate(flat)
+ b : u32,
+ @location(2) @interpolate(flat)
+ c : i32,
+ @location(3) @interpolate(flat)
+ d : u32,
+ @builtin(position)
+ pos : vec4<f32>,
+ @builtin(front_facing)
+ ff : bool,
+}
+
+fn frag_main_inner(ff : bool, c : i32, inputs : FragmentInputExtra, b : u32) {
+}
+
+@fragment
+fn frag_main(tint_symbol_1 : tint_symbol_2) {
+ frag_main_inner(tint_symbol_1.ff, tint_symbol_1.c, FragmentInputExtra(tint_symbol_1.d, tint_symbol_1.pos, tint_symbol_1.a), tint_symbol_1.b);
+}
+
+struct VertexOutput {
+ b : u32,
+ pos : vec4<f32>,
+ d : u32,
+ a : f32,
+ c : i32,
+}
+
+struct FragmentInputExtra {
+ d : u32,
+ pos : vec4<f32>,
+ a : f32,
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the golden WGSL above.
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, DontRenameSymbols) {  // MSL style: an entry point already named tint_symbol_1 keeps its name; the generated parameter and struct take tint_symbol and tint_symbol_2 instead.
+ auto* src = R"(
+@fragment
+fn tint_symbol_1(@location(0) col : f32) {
+}
+)";
+
+ auto* expect = R"(
+struct tint_symbol_2 {
+ @location(0)
+ col : f32,
+}
+
+fn tint_symbol_1_inner(col : f32) {
+}
+
+@fragment
+fn tint_symbol_1(tint_symbol : tint_symbol_2) {
+ tint_symbol_1_inner(tint_symbol.col);
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));  // Exact string comparison against the golden WGSL above.
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, FixedSampleMask_VoidNoReturn) {  // MSL style with a fixed sample mask: a void fragment entry point with an empty body gains a wrapper struct whose sample_mask output is set to 3u.
+ auto* src = R"(
+@fragment
+fn frag_main() {
+}
+)";
+
+ auto* expect = R"(
+struct tint_symbol {
+ @builtin(sample_mask)
+ fixed_sample_mask : u32,
+}
+
+fn frag_main_inner() {
+}
+
+@fragment
+fn frag_main() -> tint_symbol {
+ frag_main_inner();
+ var wrapper_result : tint_symbol;
+ wrapper_result.fixed_sample_mask = 3u;
+ return wrapper_result;
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl, 0x03u);  // 0x03u shows up as the 3u literal in the expected output.
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, FixedSampleMask_VoidWithReturn) {  // MSL style with a fixed sample mask: a void fragment entry point containing an explicit `return;` still gets the sample_mask output set to 3u in the wrapper.
+ auto* src = R"(
+@fragment
+fn frag_main() {
+ return;
+}
+)";
+
+ auto* expect = R"(
+struct tint_symbol {
+ @builtin(sample_mask)
+ fixed_sample_mask : u32,
+}
+
+fn frag_main_inner() {
+ return;
+}
+
+@fragment
+fn frag_main() -> tint_symbol {
+ frag_main_inner();
+ var wrapper_result : tint_symbol;
+ wrapper_result.fixed_sample_mask = 3u;
+ return wrapper_result;
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl, 0x03u);  // 0x03u shows up as the 3u literal in the expected output.
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, FixedSampleMask_WithAuthoredMask) {  // MSL style with a fixed sample mask: a shader-authored sample_mask return value is ANDed with 3u in the wrapper.
+ auto* src = R"(
+@fragment
+fn frag_main() -> @builtin(sample_mask) u32 {
+ return 7u;
+}
+)";
+
+ auto* expect = R"(
+struct tint_symbol {
+ @builtin(sample_mask)
+ value : u32,
+}
+
+fn frag_main_inner() -> u32 {
+ return 7u;
+}
+
+@fragment
+fn frag_main() -> tint_symbol {
+ let inner_result = frag_main_inner();
+ var wrapper_result : tint_symbol;
+ wrapper_result.value = (inner_result & 3u);
+ return wrapper_result;
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl, 0x03u);  // 0x03u shows up as the 3u literal in the expected output.
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, FixedSampleMask_WithoutAuthoredMask) {  // MSL style with a fixed sample mask: a fragment shader returning only @location(0) gets an extra fixed_sample_mask member assigned 3u.
+ auto* src = R"(
+@fragment
+fn frag_main() -> @location(0) f32 {
+ return 1.0;
+}
+)";
+
+ auto* expect = R"(
+struct tint_symbol {
+ @location(0)
+ value : f32,
+ @builtin(sample_mask)
+ fixed_sample_mask : u32,
+}
+
+fn frag_main_inner() -> f32 {
+ return 1.0;
+}
+
+@fragment
+fn frag_main() -> tint_symbol {
+ let inner_result = frag_main_inner();
+ var wrapper_result : tint_symbol;
+ wrapper_result.value = inner_result;
+ wrapper_result.fixed_sample_mask = 3u;
+ return wrapper_result;
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl, 0x03u);  // 0x03u shows up as the 3u literal in the expected output.
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, FixedSampleMask_StructWithAuthoredMask) {  // MSL style with a fixed sample mask: a sample_mask member of a returned struct is ANDed with 3u; the other members are copied through unchanged.
+ auto* src = R"(
+struct Output {
+ @builtin(frag_depth) depth : f32,
+ @builtin(sample_mask) mask : u32,
+ @location(0) value : f32,
+};
+
+@fragment
+fn frag_main() -> Output {
+ return Output(0.5, 7u, 1.0);
+}
+)";
+
+ auto* expect = R"(
+struct Output {
+ depth : f32,
+ mask : u32,
+ value : f32,
+}
+
+struct tint_symbol {
+ @location(0)
+ value : f32,
+ @builtin(frag_depth)
+ depth : f32,
+ @builtin(sample_mask)
+ mask : u32,
+}
+
+fn frag_main_inner() -> Output {
+ return Output(0.5, 7u, 1.0);
+}
+
+@fragment
+fn frag_main() -> tint_symbol {
+ let inner_result = frag_main_inner();
+ var wrapper_result : tint_symbol;
+ wrapper_result.depth = inner_result.depth;
+ wrapper_result.mask = (inner_result.mask & 3u);
+ wrapper_result.value = inner_result.value;
+ return wrapper_result;
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl, 0x03u);  // 0x03u shows up as the 3u literal in the expected output.
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, FixedSampleMask_StructWithAuthoredMask_OutOfOrder) {  // MSL style with a fixed sample mask: authored struct sample_mask is ANDed with 3u when the entry point is declared before the Output struct.
+ auto* src = R"(
+@fragment
+fn frag_main() -> Output {
+ return Output(0.5, 7u, 1.0);
+}
+
+struct Output {
+ @builtin(frag_depth) depth : f32,
+ @builtin(sample_mask) mask : u32,
+ @location(0) value : f32,
+};
+)";
+
+ auto* expect = R"(
+struct tint_symbol {
+ @location(0)
+ value : f32,
+ @builtin(frag_depth)
+ depth : f32,
+ @builtin(sample_mask)
+ mask : u32,
+}
+
+fn frag_main_inner() -> Output {
+ return Output(0.5, 7u, 1.0);
+}
+
+@fragment
+fn frag_main() -> tint_symbol {
+ let inner_result = frag_main_inner();
+ var wrapper_result : tint_symbol;
+ wrapper_result.depth = inner_result.depth;
+ wrapper_result.mask = (inner_result.mask & 3u);
+ wrapper_result.value = inner_result.value;
+ return wrapper_result;
+}
+
+struct Output {
+ depth : f32,
+ mask : u32,
+ value : f32,
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl, 0x03u);  // 0x03u shows up as the 3u literal in the expected output.
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, FixedSampleMask_StructWithoutAuthoredMask) {  // MSL style with a fixed sample mask: a returned struct without a sample_mask member gets a fixed_sample_mask member added to the wrapper struct and assigned 3u.
+ auto* src = R"(
+struct Output {
+ @builtin(frag_depth) depth : f32,
+ @location(0) value : f32,
+};
+
+@fragment
+fn frag_main() -> Output {
+ return Output(0.5, 1.0);
+}
+)";
+
+ auto* expect = R"(
+struct Output {
+ depth : f32,
+ value : f32,
+}
+
+struct tint_symbol {
+ @location(0)
+ value : f32,
+ @builtin(frag_depth)
+ depth : f32,
+ @builtin(sample_mask)
+ fixed_sample_mask : u32,
+}
+
+fn frag_main_inner() -> Output {
+ return Output(0.5, 1.0);
+}
+
+@fragment
+fn frag_main() -> tint_symbol {
+ let inner_result = frag_main_inner();
+ var wrapper_result : tint_symbol;
+ wrapper_result.depth = inner_result.depth;
+ wrapper_result.value = inner_result.value;
+ wrapper_result.fixed_sample_mask = 3u;
+ return wrapper_result;
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl, 0x03u);  // 0x03u shows up as the 3u literal in the expected output.
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, FixedSampleMask_StructWithoutAuthoredMask_OutOfOrder) {  // MSL style with a fixed sample mask: fixed_sample_mask is added to the wrapper struct when the entry point is declared before the Output struct.
+ auto* src = R"(
+@fragment
+fn frag_main() -> Output {
+ return Output(0.5, 1.0);
+}
+
+struct Output {
+ @builtin(frag_depth) depth : f32,
+ @location(0) value : f32,
+};
+)";
+
+ auto* expect = R"(
+struct tint_symbol {
+ @location(0)
+ value : f32,
+ @builtin(frag_depth)
+ depth : f32,
+ @builtin(sample_mask)
+ fixed_sample_mask : u32,
+}
+
+fn frag_main_inner() -> Output {
+ return Output(0.5, 1.0);
+}
+
+@fragment
+fn frag_main() -> tint_symbol {
+ let inner_result = frag_main_inner();
+ var wrapper_result : tint_symbol;
+ wrapper_result.depth = inner_result.depth;
+ wrapper_result.value = inner_result.value;
+ wrapper_result.fixed_sample_mask = 3u;
+ return wrapper_result;
+}
+
+struct Output {
+ depth : f32,
+ value : f32,
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl, 0x03u);  // 0x03u shows up as the 3u literal in the expected output.
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, FixedSampleMask_MultipleShaders) {  // MSL style with a fixed sample mask: both fragment entry points get the 3u mask applied, while the vertex wrapper carries only position and the compute entry point is emitted unchanged.
+ auto* src = R"(
+@fragment
+fn frag_main1() -> @builtin(sample_mask) u32 {
+ return 7u;
+}
+
+@fragment
+fn frag_main2() -> @location(0) f32 {
+ return 1.0;
+}
+
+@vertex
+fn vert_main1() -> @builtin(position) vec4<f32> {
+ return vec4<f32>();
+}
+
+@compute @workgroup_size(1)
+fn comp_main1() {
+}
+)";
+
+ auto* expect = R"(
+struct tint_symbol {
+ @builtin(sample_mask)
+ value : u32,
+}
+
+fn frag_main1_inner() -> u32 {
+ return 7u;
+}
+
+@fragment
+fn frag_main1() -> tint_symbol {
+ let inner_result = frag_main1_inner();
+ var wrapper_result : tint_symbol;
+ wrapper_result.value = (inner_result & 3u);
+ return wrapper_result;
+}
+
+struct tint_symbol_1 {
+ @location(0)
+ value : f32,
+ @builtin(sample_mask)
+ fixed_sample_mask : u32,
+}
+
+fn frag_main2_inner() -> f32 {
+ return 1.0;
+}
+
+@fragment
+fn frag_main2() -> tint_symbol_1 {
+ let inner_result_1 = frag_main2_inner();
+ var wrapper_result_1 : tint_symbol_1;
+ wrapper_result_1.value = inner_result_1;
+ wrapper_result_1.fixed_sample_mask = 3u;
+ return wrapper_result_1;
+}
+
+struct tint_symbol_2 {
+ @builtin(position)
+ value : vec4<f32>,
+}
+
+fn vert_main1_inner() -> vec4<f32> {
+ return vec4<f32>();
+}
+
+@vertex
+fn vert_main1() -> tint_symbol_2 {
+ let inner_result_2 = vert_main1_inner();
+ var wrapper_result_2 : tint_symbol_2;
+ wrapper_result_2.value = inner_result_2;
+ return wrapper_result_2;
+}
+
+@compute @workgroup_size(1)
+fn comp_main1() {
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl, 0x03u);  // 0x03u shows up as the 3u literal in the expected output.
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, FixedSampleMask_AvoidNameClash) {  // MSL style with a fixed sample mask: user members already named fixed_sample_mask and fixed_sample_mask_1 force the generated member to be named fixed_sample_mask_2.
+ auto* src = R"(
+struct FragOut {
+ @location(0) fixed_sample_mask : vec4<f32>,
+ @location(1) fixed_sample_mask_1 : vec4<f32>,
+};
+
+@fragment
+fn frag_main() -> FragOut {
+ return FragOut();
+}
+)";
+
+ auto* expect = R"(
+struct FragOut {
+ fixed_sample_mask : vec4<f32>,
+ fixed_sample_mask_1 : vec4<f32>,
+}
+
+struct tint_symbol {
+ @location(0)
+ fixed_sample_mask : vec4<f32>,
+ @location(1)
+ fixed_sample_mask_1 : vec4<f32>,
+ @builtin(sample_mask)
+ fixed_sample_mask_2 : u32,
+}
+
+fn frag_main_inner() -> FragOut {
+ return FragOut();
+}
+
+@fragment
+fn frag_main() -> tint_symbol {
+ let inner_result = frag_main_inner();
+ var wrapper_result : tint_symbol;
+ wrapper_result.fixed_sample_mask = inner_result.fixed_sample_mask;
+ wrapper_result.fixed_sample_mask_1 = inner_result.fixed_sample_mask_1;
+ wrapper_result.fixed_sample_mask_2 = 3u;
+ return wrapper_result;
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl, 0x03u);  // Unsigned literal, matching the other FixedSampleMask tests; shows up as 3u in the expected output.
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, EmitVertexPointSize_ReturnNonStruct_Spirv) {  // SPIR-V style: a vertex shader returning a bare position gets an extra @builtin(__point_size) output variable that the wrapper sets to 1.0f.
+ auto* src = R"(
+@vertex
+fn vert_main() -> @builtin(position) vec4<f32> {
+ return vec4<f32>();
+}
+)";
+
+ auto* expect = R"(
+@builtin(position) @internal(disable_validation__ignore_address_space) var<__out> value : vec4<f32>;
+
+@builtin(__point_size) @internal(disable_validation__ignore_address_space) var<__out> vertex_point_size : f32;
+
+fn vert_main_inner() -> vec4<f32> {
+ return vec4<f32>();
+}
+
+@vertex
+fn vert_main() {
+ let inner_result = vert_main_inner();
+ value = inner_result;
+ vertex_point_size = 1.0f;
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kSpirv,
+ 0xFFFFFFFF, true);  // NOTE(review): presumably (fixed sample mask, emit vertex point size) - confirm against the Config constructor.
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, EmitVertexPointSize_ReturnNonStruct_Msl) {
+ auto* src = R"(
+@vertex
+fn vert_main() -> @builtin(position) vec4<f32> {
+ return vec4<f32>();
+}
+)";
+
+ auto* expect = R"(
+struct tint_symbol {
+ @builtin(position)
+ value : vec4<f32>,
+ @builtin(__point_size)
+ vertex_point_size : f32,
+}
+
+fn vert_main_inner() -> vec4<f32> {
+ return vec4<f32>();
+}
+
+@vertex
+fn vert_main() -> tint_symbol {
+ let inner_result = vert_main_inner();
+ var wrapper_result : tint_symbol;
+ wrapper_result.value = inner_result;
+ wrapper_result.vertex_point_size = 1.0f;
+ return wrapper_result;
+}
+)";
+
+ Transform::DataMap data;  // MSL style; final `true` turns on the point-size output
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl,
+ 0xFFFFFFFF, true);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));  // expects a __point_size member in the wrapper struct
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, EmitVertexPointSize_ReturnStruct_Spirv) {
+ auto* src = R"(
+struct VertOut {
+ @builtin(position) pos : vec4<f32>,
+};
+
+@vertex
+fn vert_main() -> VertOut {
+ return VertOut();
+}
+)";
+
+ auto* expect = R"(
+@builtin(position) @internal(disable_validation__ignore_address_space) var<__out> pos_1 : vec4<f32>;
+
+@builtin(__point_size) @internal(disable_validation__ignore_address_space) var<__out> vertex_point_size : f32;
+
+struct VertOut {
+ pos : vec4<f32>,
+}
+
+fn vert_main_inner() -> VertOut {
+ return VertOut();
+}
+
+@vertex
+fn vert_main() {
+ let inner_result = vert_main_inner();
+ pos_1 = inner_result.pos;
+ vertex_point_size = 1.0f;
+}
+)";
+
+ Transform::DataMap data;  // SPIR-V style with the point-size output enabled
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kSpirv,
+ 0xFFFFFFFF, true);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));  // struct return: pos goes to pos_1, plus vertex_point_size
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, EmitVertexPointSize_ReturnStruct_Spirv_OutOfOrder) {
+ auto* src = R"(
+@vertex
+fn vert_main() -> VertOut {
+ return VertOut();
+}
+
+struct VertOut {
+ @builtin(position) pos : vec4<f32>,
+};
+)";
+
+ auto* expect = R"(
+@builtin(position) @internal(disable_validation__ignore_address_space) var<__out> pos_1 : vec4<f32>;
+
+@builtin(__point_size) @internal(disable_validation__ignore_address_space) var<__out> vertex_point_size : f32;
+
+fn vert_main_inner() -> VertOut {
+ return VertOut();
+}
+
+@vertex
+fn vert_main() {
+ let inner_result = vert_main_inner();
+ pos_1 = inner_result.pos;
+ vertex_point_size = 1.0f;
+}
+
+struct VertOut {
+ pos : vec4<f32>,
+}
+)";
+
+ Transform::DataMap data;  // SPIR-V style with the point-size output enabled
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kSpirv,
+ 0xFFFFFFFF, true);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));  // VertOut stays declared after the entry point
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, EmitVertexPointSize_ReturnStruct_Msl) {
+ auto* src = R"(
+struct VertOut {
+ @builtin(position) pos : vec4<f32>,
+};
+
+@vertex
+fn vert_main() -> VertOut {
+ return VertOut();
+}
+)";
+
+ auto* expect = R"(
+struct VertOut {
+ pos : vec4<f32>,
+}
+
+struct tint_symbol {
+ @builtin(position)
+ pos : vec4<f32>,
+ @builtin(__point_size)
+ vertex_point_size : f32,
+}
+
+fn vert_main_inner() -> VertOut {
+ return VertOut();
+}
+
+@vertex
+fn vert_main() -> tint_symbol {
+ let inner_result = vert_main_inner();
+ var wrapper_result : tint_symbol;
+ wrapper_result.pos = inner_result.pos;
+ wrapper_result.vertex_point_size = 1.0f;
+ return wrapper_result;
+}
+)";
+
+ Transform::DataMap data;  // MSL style with the point-size output enabled
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl,
+ 0xFFFFFFFF, true);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));  // wrapper struct carries pos plus vertex_point_size
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, EmitVertexPointSize_ReturnStruct_Msl_OutOfOrder) {
+ auto* src = R"(
+@vertex
+fn vert_main() -> VertOut {
+ return VertOut();
+}
+
+struct VertOut {
+ @builtin(position) pos : vec4<f32>,
+};
+)";
+
+ auto* expect = R"(
+struct tint_symbol {
+ @builtin(position)
+ pos : vec4<f32>,
+ @builtin(__point_size)
+ vertex_point_size : f32,
+}
+
+fn vert_main_inner() -> VertOut {
+ return VertOut();
+}
+
+@vertex
+fn vert_main() -> tint_symbol {
+ let inner_result = vert_main_inner();
+ var wrapper_result : tint_symbol;
+ wrapper_result.pos = inner_result.pos;
+ wrapper_result.vertex_point_size = 1.0f;
+ return wrapper_result;
+}
+
+struct VertOut {
+ pos : vec4<f32>,
+}
+)";
+
+ Transform::DataMap data;  // MSL style with the point-size output enabled
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl,
+ 0xFFFFFFFF, true);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));  // tint_symbol is emitted first; VertOut stays last
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, EmitVertexPointSize_AvoidNameClash_Spirv) {
+ auto* src = R"(
+var<private> vertex_point_size : f32;
+var<private> vertex_point_size_1 : f32;
+var<private> vertex_point_size_2 : f32;
+
+struct VertIn1 {
+ @location(0) collide : f32,
+};
+
+struct VertIn2 {
+ @location(1) collide : f32,
+};
+
+struct VertOut {
+ @location(0) vertex_point_size : f32,
+ @builtin(position) vertex_point_size_1 : vec4<f32>,
+};
+
+@vertex
+fn vert_main(collide : VertIn1, collide_1 : VertIn2) -> VertOut {
+ let x = collide.collide + collide_1.collide;
+ return VertOut();
+}
+)";
+
+ auto* expect = R"(
+@location(0) @internal(disable_validation__ignore_address_space) var<__in> collide_2 : f32;
+
+@location(1) @internal(disable_validation__ignore_address_space) var<__in> collide_3 : f32;
+
+@location(0) @internal(disable_validation__ignore_address_space) var<__out> vertex_point_size_3 : f32;
+
+@builtin(position) @internal(disable_validation__ignore_address_space) var<__out> vertex_point_size_1_1 : vec4<f32>;
+
+@builtin(__point_size) @internal(disable_validation__ignore_address_space) var<__out> vertex_point_size_4 : f32;
+
+var<private> vertex_point_size : f32;
+
+var<private> vertex_point_size_1 : f32;
+
+var<private> vertex_point_size_2 : f32;
+
+struct VertIn1 {
+ collide : f32,
+}
+
+struct VertIn2 {
+ collide : f32,
+}
+
+struct VertOut {
+ vertex_point_size : f32,
+ vertex_point_size_1 : vec4<f32>,
+}
+
+fn vert_main_inner(collide : VertIn1, collide_1 : VertIn2) -> VertOut {
+ let x = (collide.collide + collide_1.collide);
+ return VertOut();
+}
+
+@vertex
+fn vert_main() {
+ let inner_result = vert_main_inner(VertIn1(collide_2), VertIn2(collide_3));
+ vertex_point_size_3 = inner_result.vertex_point_size;
+ vertex_point_size_1_1 = inner_result.vertex_point_size_1;
+ vertex_point_size_4 = 1.0f;
+}
+)";
+
+ Transform::DataMap data;  // SPIR-V style with the point-size output enabled
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kSpirv,
+ 0xFFFFFFFF, true);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));  // point-size var is renamed to vertex_point_size_4
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, EmitVertexPointSize_AvoidNameClash_Spirv_OutOfOrder) {
+ auto* src = R"(
+@vertex
+fn vert_main(collide : VertIn1, collide_1 : VertIn2) -> VertOut {
+ let x = collide.collide + collide_1.collide;
+ return VertOut();
+}
+
+struct VertIn1 {
+ @location(0) collide : f32,
+};
+
+struct VertIn2 {
+ @location(1) collide : f32,
+};
+
+var<private> vertex_point_size : f32;
+var<private> vertex_point_size_1 : f32;
+var<private> vertex_point_size_2 : f32;
+
+struct VertOut {
+ @location(0) vertex_point_size : f32,
+ @builtin(position) vertex_point_size_1 : vec4<f32>,
+};
+)";
+
+ auto* expect = R"(
+@location(0) @internal(disable_validation__ignore_address_space) var<__in> collide_2 : f32;
+
+@location(1) @internal(disable_validation__ignore_address_space) var<__in> collide_3 : f32;
+
+@location(0) @internal(disable_validation__ignore_address_space) var<__out> vertex_point_size_3 : f32;
+
+@builtin(position) @internal(disable_validation__ignore_address_space) var<__out> vertex_point_size_1_1 : vec4<f32>;
+
+@builtin(__point_size) @internal(disable_validation__ignore_address_space) var<__out> vertex_point_size_4 : f32;
+
+fn vert_main_inner(collide : VertIn1, collide_1 : VertIn2) -> VertOut {
+ let x = (collide.collide + collide_1.collide);
+ return VertOut();
+}
+
+@vertex
+fn vert_main() {
+ let inner_result = vert_main_inner(VertIn1(collide_2), VertIn2(collide_3));
+ vertex_point_size_3 = inner_result.vertex_point_size;
+ vertex_point_size_1_1 = inner_result.vertex_point_size_1;
+ vertex_point_size_4 = 1.0f;
+}
+
+struct VertIn1 {
+ collide : f32,
+}
+
+struct VertIn2 {
+ collide : f32,
+}
+
+var<private> vertex_point_size : f32;
+
+var<private> vertex_point_size_1 : f32;
+
+var<private> vertex_point_size_2 : f32;
+
+struct VertOut {
+ vertex_point_size : f32,
+ vertex_point_size_1 : vec4<f32>,
+}
+)";
+
+ Transform::DataMap data;  // SPIR-V style with the point-size output enabled
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kSpirv,
+ 0xFFFFFFFF, true);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));  // renamed I/O vars first; user declarations keep their order
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, EmitVertexPointSize_AvoidNameClash_Msl) {
+ auto* src = R"(
+struct VertIn1 {
+ @location(0) collide : f32,
+};
+
+struct VertIn2 {
+ @location(1) collide : f32,
+};
+
+struct VertOut {
+ @location(0) vertex_point_size : vec4<f32>,
+ @builtin(position) vertex_point_size_1 : vec4<f32>,
+};
+
+@vertex
+fn vert_main(collide : VertIn1, collide_1 : VertIn2) -> VertOut {
+ let x = collide.collide + collide_1.collide;
+ return VertOut();
+}
+)";
+
+ auto* expect = R"(
+struct VertIn1 {
+ collide : f32,
+}
+
+struct VertIn2 {
+ collide : f32,
+}
+
+struct VertOut {
+ vertex_point_size : vec4<f32>,
+ vertex_point_size_1 : vec4<f32>,
+}
+
+struct tint_symbol_1 {
+ @location(0)
+ collide : f32,
+ @location(1)
+ collide_2 : f32,
+}
+
+struct tint_symbol_2 {
+ @location(0)
+ vertex_point_size : vec4<f32>,
+ @builtin(position)
+ vertex_point_size_1 : vec4<f32>,
+ @builtin(__point_size)
+ vertex_point_size_2 : f32,
+}
+
+fn vert_main_inner(collide : VertIn1, collide_1 : VertIn2) -> VertOut {
+ let x = (collide.collide + collide_1.collide);
+ return VertOut();
+}
+
+@vertex
+fn vert_main(tint_symbol : tint_symbol_1) -> tint_symbol_2 {
+ let inner_result = vert_main_inner(VertIn1(tint_symbol.collide), VertIn2(tint_symbol.collide_2));
+ var wrapper_result : tint_symbol_2;
+ wrapper_result.vertex_point_size = inner_result.vertex_point_size;
+ wrapper_result.vertex_point_size_1 = inner_result.vertex_point_size_1;
+ wrapper_result.vertex_point_size_2 = 1.0f;
+ return wrapper_result;
+}
+)";
+
+ Transform::DataMap data;  // MSL style with the point-size output enabled
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl,
+ 0xFFFFFFFF, true);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));  // point-size member is renamed to vertex_point_size_2
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, EmitVertexPointSize_AvoidNameClash_Msl_OutOfOrder) {
+ auto* src = R"(
+@vertex
+fn vert_main(collide : VertIn1, collide_1 : VertIn2) -> VertOut {
+ let x = collide.collide + collide_1.collide;
+ return VertOut();
+}
+
+struct VertIn1 {
+ @location(0) collide : f32,
+};
+
+struct VertIn2 {
+ @location(1) collide : f32,
+};
+
+struct VertOut {
+ @location(0) vertex_point_size : vec4<f32>,
+ @builtin(position) vertex_point_size_1 : vec4<f32>,
+};
+)";
+
+ auto* expect = R"(
+struct tint_symbol_1 {
+ @location(0)
+ collide : f32,
+ @location(1)
+ collide_2 : f32,
+}
+
+struct tint_symbol_2 {
+ @location(0)
+ vertex_point_size : vec4<f32>,
+ @builtin(position)
+ vertex_point_size_1 : vec4<f32>,
+ @builtin(__point_size)
+ vertex_point_size_2 : f32,
+}
+
+fn vert_main_inner(collide : VertIn1, collide_1 : VertIn2) -> VertOut {
+ let x = (collide.collide + collide_1.collide);
+ return VertOut();
+}
+
+@vertex
+fn vert_main(tint_symbol : tint_symbol_1) -> tint_symbol_2 {
+ let inner_result = vert_main_inner(VertIn1(tint_symbol.collide), VertIn2(tint_symbol.collide_2));
+ var wrapper_result : tint_symbol_2;
+ wrapper_result.vertex_point_size = inner_result.vertex_point_size;
+ wrapper_result.vertex_point_size_1 = inner_result.vertex_point_size_1;
+ wrapper_result.vertex_point_size_2 = 1.0f;
+ return wrapper_result;
+}
+
+struct VertIn1 {
+ collide : f32,
+}
+
+struct VertIn2 {
+ collide : f32,
+}
+
+struct VertOut {
+ vertex_point_size : vec4<f32>,
+ vertex_point_size_1 : vec4<f32>,
+}
+)";
+
+ Transform::DataMap data;  // MSL style with the point-size output enabled
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl,
+ 0xFFFFFFFF, true);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));  // wrapper structs first; user structs stay after the functions
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, EmitVertexPointSize_AvoidNameClash_Hlsl) {
+ auto* src = R"(
+struct VertIn1 {
+ @location(0) collide : f32,
+};
+
+struct VertIn2 {
+ @location(1) collide : f32,
+};
+
+struct VertOut {
+ @location(0) vertex_point_size : vec4<f32>,
+ @builtin(position) vertex_point_size_1 : vec4<f32>,
+};
+
+@vertex
+fn vert_main(collide : VertIn1, collide_1 : VertIn2) -> VertOut {
+ let x = collide.collide + collide_1.collide;
+ return VertOut();
+}
+)";
+
+ auto* expect = R"(
+struct VertIn1 {
+ collide : f32,
+}
+
+struct VertIn2 {
+ collide : f32,
+}
+
+struct VertOut {
+ vertex_point_size : vec4<f32>,
+ vertex_point_size_1 : vec4<f32>,
+}
+
+struct tint_symbol_1 {
+ @location(0)
+ collide : f32,
+ @location(1)
+ collide_2 : f32,
+}
+
+struct tint_symbol_2 {
+ @location(0)
+ vertex_point_size : vec4<f32>,
+ @builtin(position)
+ vertex_point_size_1 : vec4<f32>,
+ @builtin(__point_size)
+ vertex_point_size_2 : f32,
+}
+
+fn vert_main_inner(collide : VertIn1, collide_1 : VertIn2) -> VertOut {
+ let x = (collide.collide + collide_1.collide);
+ return VertOut();
+}
+
+@vertex
+fn vert_main(tint_symbol : tint_symbol_1) -> tint_symbol_2 {
+ let inner_result = vert_main_inner(VertIn1(tint_symbol.collide), VertIn2(tint_symbol.collide_2));
+ var wrapper_result : tint_symbol_2;
+ wrapper_result.vertex_point_size = inner_result.vertex_point_size;
+ wrapper_result.vertex_point_size_1 = inner_result.vertex_point_size_1;
+ wrapper_result.vertex_point_size_2 = 1.0f;
+ return wrapper_result;
+}
+)";
+
+ Transform::DataMap data;  // HLSL style with the point-size output enabled
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl,
+ 0xFFFFFFFF, true);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));  // point-size member is renamed to vertex_point_size_2
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, EmitVertexPointSize_AvoidNameClash_Hlsl_OutOfOrder) {
+ auto* src = R"(
+@vertex
+fn vert_main(collide : VertIn1, collide_1 : VertIn2) -> VertOut {
+ let x = collide.collide + collide_1.collide;
+ return VertOut();
+}
+
+struct VertIn1 {
+ @location(0) collide : f32,
+};
+
+struct VertIn2 {
+ @location(1) collide : f32,
+};
+
+struct VertOut {
+ @location(0) vertex_point_size : vec4<f32>,
+ @builtin(position) vertex_point_size_1 : vec4<f32>,
+};
+)";
+
+ auto* expect = R"(
+struct tint_symbol_1 {
+ @location(0)
+ collide : f32,
+ @location(1)
+ collide_2 : f32,
+}
+
+struct tint_symbol_2 {
+ @location(0)
+ vertex_point_size : vec4<f32>,
+ @builtin(position)
+ vertex_point_size_1 : vec4<f32>,
+ @builtin(__point_size)
+ vertex_point_size_2 : f32,
+}
+
+fn vert_main_inner(collide : VertIn1, collide_1 : VertIn2) -> VertOut {
+ let x = (collide.collide + collide_1.collide);
+ return VertOut();
+}
+
+@vertex
+fn vert_main(tint_symbol : tint_symbol_1) -> tint_symbol_2 {
+ let inner_result = vert_main_inner(VertIn1(tint_symbol.collide), VertIn2(tint_symbol.collide_2));
+ var wrapper_result : tint_symbol_2;
+ wrapper_result.vertex_point_size = inner_result.vertex_point_size;
+ wrapper_result.vertex_point_size_1 = inner_result.vertex_point_size_1;
+ wrapper_result.vertex_point_size_2 = 1.0f;
+ return wrapper_result;
+}
+
+struct VertIn1 {
+ collide : f32,
+}
+
+struct VertIn2 {
+ collide : f32,
+}
+
+struct VertOut {
+ vertex_point_size : vec4<f32>,
+ vertex_point_size_1 : vec4<f32>,
+}
+)";
+
+ Transform::DataMap data;  // HLSL style with the point-size output enabled
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl,
+ 0xFFFFFFFF, true);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));  // wrapper structs first; user structs stay after the functions
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, SpirvSampleMaskBuiltins) {
+ auto* src = R"(
+@fragment
+fn main(@builtin(sample_index) sample_index : u32,
+ @builtin(sample_mask) mask_in : u32
+ ) -> @builtin(sample_mask) u32 {
+ return mask_in;
+}
+)";
+
+ auto* expect = R"(
+@builtin(sample_index) @interpolate(flat) @internal(disable_validation__ignore_address_space) var<__in> sample_index_1 : u32;
+
+@builtin(sample_mask) @interpolate(flat) @internal(disable_validation__ignore_address_space) var<__in> mask_in_1 : array<u32, 1u>;
+
+@builtin(sample_mask) @internal(disable_validation__ignore_address_space) var<__out> value : array<u32, 1u>;
+
+fn main_inner(sample_index : u32, mask_in : u32) -> u32 {
+ return mask_in;
+}
+
+@fragment
+fn main() {
+ let inner_result = main_inner(sample_index_1, mask_in_1[0i]);
+ value[0i] = inner_result;
+}
+)";
+
+ Transform::DataMap data;  // SPIR-V style; only the shader style is passed to Config
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kSpirv);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));  // sample_mask in/out become array<u32, 1u> indexed at [0i]
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, GLSLSampleMaskBuiltins) {
+ auto* src = R"(
+@fragment
+fn fragment_main(@builtin(sample_index) sample_index : u32,
+ @builtin(sample_mask) mask_in : u32
+ ) -> @builtin(sample_mask) u32 {
+ return mask_in;
+}
+)";
+
+ auto* expect = R"(
+@builtin(sample_index) @internal(disable_validation__ignore_address_space) var<__in> gl_SampleID : i32;
+
+@builtin(sample_mask) @internal(disable_validation__ignore_address_space) var<__in> gl_SampleMaskIn : array<i32, 1u>;
+
+@builtin(sample_mask) @internal(disable_validation__ignore_address_space) var<__out> gl_SampleMask : array<i32, 1u>;
+
+fn fragment_main(sample_index : u32, mask_in : u32) -> u32 {
+ return mask_in;
+}
+
+@fragment
+fn main() {
+ let inner_result = fragment_main(bitcast<u32>(gl_SampleID), bitcast<u32>(gl_SampleMaskIn[0i]));
+ gl_SampleMask[0i] = bitcast<i32>(inner_result);
+}
+)";
+
+ Transform::DataMap data;  // GLSL style; only the shader style is passed to Config
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kGlsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));  // builtins become gl_* i32 vars bridged with bitcast
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, GLSLVertexInstanceIndexBuiltins) {
+ auto* src = R"(
+@vertex
+fn vertex_main(@builtin(vertex_index) vertexID : u32,
+ @builtin(instance_index) instanceID : u32
+ ) -> @builtin(position) vec4<f32> {
+ return vec4<f32>(f32(vertexID) + f32(instanceID));
+}
+)";
+
+ auto* expect = R"(
+@builtin(vertex_index) @internal(disable_validation__ignore_address_space) var<__in> gl_VertexID : i32;
+
+@builtin(instance_index) @internal(disable_validation__ignore_address_space) var<__in> gl_InstanceID : i32;
+
+@builtin(position) @internal(disable_validation__ignore_address_space) var<__out> gl_Position : vec4<f32>;
+
+fn vertex_main(vertexID : u32, instanceID : u32) -> vec4<f32> {
+ return vec4<f32>((f32(vertexID) + f32(instanceID)));
+}
+
+@vertex
+fn main() {
+ let inner_result = vertex_main(bitcast<u32>(gl_VertexID), bitcast<u32>(gl_InstanceID));
+ gl_Position = inner_result;
+ gl_Position.y = -(gl_Position.y);
+ gl_Position.z = ((2.0f * gl_Position.z) - gl_Position.w);
+}
+)";
+
+ Transform::DataMap data;  // GLSL style; only the shader style is passed to Config
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kGlsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));  // also expects the y negation and z remap of gl_Position
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, Return_Struct_Index_Attribute_Spirv) {
+ auto* src = R"(
+enable chromium_internal_dual_source_blending;
+
+struct FragOutput {
+ @location(0) @index(0) color : vec4<f32>,
+ @location(0) @index(1) blend : vec4<f32>,
+ @builtin(frag_depth) depth : f32,
+ @builtin(sample_mask) mask : u32,
+};
+
+@fragment
+fn frag_main() -> FragOutput {
+ var output : FragOutput;
+ output.depth = 1.0;
+ output.mask = 7u;
+ output.color = vec4<f32>(0.5, 0.5, 0.5, 1.0);
+ output.blend = vec4<f32>(0.5, 0.5, 0.5, 1.0);
+ return output;
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_dual_source_blending;
+
+@location(0) @index(0) @internal(disable_validation__ignore_address_space) var<__out> color_1 : vec4<f32>;
+
+@location(0) @index(1) @internal(disable_validation__ignore_address_space) var<__out> blend_1 : vec4<f32>;
+
+@builtin(frag_depth) @internal(disable_validation__ignore_address_space) var<__out> depth_1 : f32;
+
+@builtin(sample_mask) @internal(disable_validation__ignore_address_space) var<__out> mask_1 : array<u32, 1u>;
+
+struct FragOutput {
+ color : vec4<f32>,
+ blend : vec4<f32>,
+ depth : f32,
+ mask : u32,
+}
+
+fn frag_main_inner() -> FragOutput {
+ var output : FragOutput;
+ output.depth = 1.0;
+ output.mask = 7u;
+ output.color = vec4<f32>(0.5, 0.5, 0.5, 1.0);
+ output.blend = vec4<f32>(0.5, 0.5, 0.5, 1.0);
+ return output;
+}
+
+@fragment
+fn frag_main() {
+ let inner_result = frag_main_inner();
+ color_1 = inner_result.color;
+ blend_1 = inner_result.blend;
+ depth_1 = inner_result.depth;
+ mask_1[0i] = inner_result.mask;
+}
+)";
+
+ Transform::DataMap data;  // SPIR-V style; only the shader style is passed to Config
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kSpirv);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));  // @index(0)/@index(1) must survive on the output vars
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, Return_Struct_Index_Attribute_Msl) {
+ auto* src = R"(
+enable chromium_internal_dual_source_blending;
+
+struct FragOutput {
+ @location(0) @index(0) color : vec4<f32>,
+ @location(0) @index(1) blend : vec4<f32>,
+ @builtin(frag_depth) depth : f32,
+ @builtin(sample_mask) mask : u32,
+};
+
+@fragment
+fn frag_main() -> FragOutput {
+ var output : FragOutput;
+ output.depth = 1.0;
+ output.mask = 7u;
+ output.color = vec4<f32>(0.5, 0.5, 0.5, 1.0);
+ output.blend = vec4<f32>(0.5, 0.5, 0.5, 1.0);
+ return output;
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_dual_source_blending;
+
+struct FragOutput {
+ color : vec4<f32>,
+ blend : vec4<f32>,
+ depth : f32,
+ mask : u32,
+}
+
+struct tint_symbol {
+ @location(0) @index(0)
+ color : vec4<f32>,
+ @location(0) @index(1)
+ blend : vec4<f32>,
+ @builtin(frag_depth)
+ depth : f32,
+ @builtin(sample_mask)
+ mask : u32,
+}
+
+fn frag_main_inner() -> FragOutput {
+ var output : FragOutput;
+ output.depth = 1.0;
+ output.mask = 7u;
+ output.color = vec4<f32>(0.5, 0.5, 0.5, 1.0);
+ output.blend = vec4<f32>(0.5, 0.5, 0.5, 1.0);
+ return output;
+}
+
+@fragment
+fn frag_main() -> tint_symbol {
+ let inner_result = frag_main_inner();
+ var wrapper_result : tint_symbol;
+ wrapper_result.color = inner_result.color;
+ wrapper_result.blend = inner_result.blend;
+ wrapper_result.depth = inner_result.depth;
+ wrapper_result.mask = inner_result.mask;
+ return wrapper_result;
+}
+)";
+
+ Transform::DataMap data;  // MSL style; only the shader style is passed to Config
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kMsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));  // @index attributes must survive on tint_symbol members
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, Return_Struct_Index_Attribute_Hlsl) {
+ auto* src = R"(
+enable chromium_internal_dual_source_blending;
+
+struct FragOutput {
+ @location(0) @index(0) color : vec4<f32>,
+ @location(0) @index(1) blend : vec4<f32>,
+ @builtin(frag_depth) depth : f32,
+ @builtin(sample_mask) mask : u32,
+};
+
+@fragment
+fn frag_main() -> FragOutput {
+ var output : FragOutput;
+ output.depth = 1.0;
+ output.mask = 7u;
+ output.color = vec4<f32>(0.5, 0.5, 0.5, 1.0);
+ output.blend = vec4<f32>(0.5, 0.5, 0.5, 1.0);
+ return output;
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_dual_source_blending;
+
+struct FragOutput {
+ color : vec4<f32>,
+ blend : vec4<f32>,
+ depth : f32,
+ mask : u32,
+}
+
+struct tint_symbol {
+ @location(0) @index(0)
+ color : vec4<f32>,
+ @location(0) @index(1)
+ blend : vec4<f32>,
+ @builtin(frag_depth)
+ depth : f32,
+ @builtin(sample_mask)
+ mask : u32,
+}
+
+fn frag_main_inner() -> FragOutput {
+ var output : FragOutput;
+ output.depth = 1.0;
+ output.mask = 7u;
+ output.color = vec4<f32>(0.5, 0.5, 0.5, 1.0);
+ output.blend = vec4<f32>(0.5, 0.5, 0.5, 1.0);
+ return output;
+}
+
+@fragment
+fn frag_main() -> tint_symbol {
+ let inner_result = frag_main_inner();
+ var wrapper_result : tint_symbol;
+ wrapper_result.color = inner_result.color;
+ wrapper_result.blend = inner_result.blend;
+ wrapper_result.depth = inner_result.depth;
+ wrapper_result.mask = inner_result.mask;
+ return wrapper_result;
+}
+)";
+
+ Transform::DataMap data;  // HLSL style; only the shader style is passed to Config
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));  // @index attributes must survive on tint_symbol members
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/clamp_frag_depth.cc b/src/tint/lang/wgsl/ast/transform/clamp_frag_depth.cc
new file mode 100644
index 0000000..6910727
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/clamp_frag_depth.cc
@@ -0,0 +1,228 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/clamp_frag_depth.h"
+
+#include <utility>
+
+#include "src/tint/builtin/builtin_value.h"
+#include "src/tint/lang/wgsl/ast/attribute.h"
+#include "src/tint/lang/wgsl/ast/builtin_attribute.h"
+#include "src/tint/lang/wgsl/ast/function.h"
+#include "src/tint/lang/wgsl/ast/module.h"
+#include "src/tint/lang/wgsl/ast/struct.h"
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/function.h"
+#include "src/tint/sem/statement.h"
+#include "src/tint/sem/struct.h"
+#include "src/tint/utils/scoped_assignment.h"
+#include "src/tint/utils/vector.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::ClampFragDepth);
+
+namespace tint::ast::transform {
+
+/// PIMPL state for a single run of the ClampFragDepth transform
+struct ClampFragDepth::State {
+ /// The source program
+ const Program* const src;
+ /// The target program builder
+ ProgramBuilder b{};
+ /// The clone context
+ CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
+ /// The sem::Info of the program
+ const sem::Info& sem = src->Sem();
+ /// The symbols of the program
+ const SymbolTable& sym = src->Symbols();
+
+ /// Runs the transform
+ /// @returns the new program or SkipTransform if the transform is not required
+ Transform::ApplyResult Run() {
+ // Raise an ICE if any module-scope `var` already uses the push_constant address space.
+ for (auto* global : src->AST().GlobalVariables()) {
+ if (auto* var = global->As<Var>()) {
+ auto* v = src->Sem().Get(var);
+ if (TINT_UNLIKELY(v->AddressSpace() == builtin::AddressSpace::kPushConstant)) {
+ TINT_ICE(Transform, b.Diagnostics())
+ << "ClampFragDepth doesn't know how to handle module that already use push "
+ "constants";
+ return Program(std::move(b));
+ }
+ }
+ }
+
+ if (!ShouldRun()) {  // no fragment function returns frag_depth
+ return SkipTransform;
+ }
+
+ // At least one entry-point needs clamping. Add the following to the module:
+ //
+ // enable chromium_experimental_push_constant;
+ //
+ // struct FragDepthClampArgs {
+ // min : f32,
+ // max : f32,
+ // }
+ // var<push_constant> frag_depth_clamp_args : FragDepthClampArgs;
+ //
+ // fn clamp_frag_depth(v : f32) -> f32 {
+ // return clamp(v, frag_depth_clamp_args.min, frag_depth_clamp_args.max);
+ // }
+ b.Enable(builtin::Extension::kChromiumExperimentalPushConstant);
+
+ b.Structure(b.Symbols().New("FragDepthClampArgs"),
+ utils::Vector{b.Member("min", b.ty.f32()), b.Member("max", b.ty.f32())});
+
+ auto args_sym = b.Symbols().New("frag_depth_clamp_args");
+ b.GlobalVar(args_sym, b.ty("FragDepthClampArgs"), builtin::AddressSpace::kPushConstant);
+
+ auto base_fn_sym = b.Symbols().New("clamp_frag_depth");
+ b.Func(base_fn_sym, utils::Vector{b.Param("v", b.ty.f32())}, b.ty.f32(),
+ utils::Vector{b.Return(b.Call("clamp", "v", b.MemberAccessor(args_sym, "min"),
+ b.MemberAccessor(args_sym, "max")))});
+
+ // If true, the currently cloned function returns frag depth directly as a scalar
+ bool returns_frag_depth_as_value = false;
+
+ // If valid, the currently cloned function returns frag depth in a struct
+ // The symbol is the name of the helper function to apply the depth clamping.
+ Symbol returns_frag_depth_as_struct_helper;
+
+ // Map of io struct to helper function to return the structure with the depth clamping
+ // applied.
+ utils::Hashmap<const Struct*, Symbol, 4u> io_structs_clamp_helpers;
+
+ // Register a callback that will be called for each visited AST function.
+ // This call wraps the cloning of the function's statements, and will assign to
+ // `returns_frag_depth_as_value` or `returns_frag_depth_as_struct_helper` if the function's
+ // return value requires depth clamping.
+ ctx.ReplaceAll([&](const Function* fn) {
+ if (fn->PipelineStage() != PipelineStage::kFragment) {  // non-fragment: no clamping
+ return ctx.CloneWithoutTransform(fn);
+ }
+
+ if (ReturnsFragDepthAsValue(fn)) {
+ TINT_SCOPED_ASSIGNMENT(returns_frag_depth_as_value, true);
+ return ctx.CloneWithoutTransform(fn);
+ }
+
+ if (ReturnsFragDepthInStruct(fn)) {
+ // At most once per I/O struct, add the conversion function:
+ //
+ // fn clamp_frag_depth_S(s : S) -> S {
+ // return S(s.first, s.second, clamp_frag_depth(s.frag_depth), s.last);
+ // }
+ auto* struct_ty = sem.Get(fn)->ReturnType()->As<sem::Struct>()->Declaration();
+ auto helper = io_structs_clamp_helpers.GetOrCreate(struct_ty, [&] {
+ auto return_ty = fn->return_type;
+ auto fn_sym =
+ b.Symbols().New("clamp_frag_depth_" + struct_ty->name->symbol.Name());
+
+ utils::Vector<const Expression*, 8u> initializer_args;
+ for (auto* member : struct_ty->members) {
+ const Expression* arg =
+ b.MemberAccessor("s", ctx.Clone(member->name->symbol));
+ if (ContainsFragDepth(member->attributes)) {  // wrap only the frag_depth member
+ arg = b.Call(base_fn_sym, arg);
+ }
+ initializer_args.Push(arg);
+ }
+ utils::Vector params{b.Param("s", ctx.Clone(return_ty))};
+ utils::Vector body{
+ b.Return(b.Call(ctx.Clone(return_ty), std::move(initializer_args))),
+ };
+ b.Func(fn_sym, params, ctx.Clone(return_ty), body);
+ return fn_sym;
+ });
+
+ TINT_SCOPED_ASSIGNMENT(returns_frag_depth_as_struct_helper, helper);
+ return ctx.CloneWithoutTransform(fn);
+ }
+
+ return ctx.CloneWithoutTransform(fn);  // fragment function without frag_depth
+ });
+
+ // Replace the return statements `return expr` with `return clamp_frag_depth(expr)`.
+ ctx.ReplaceAll([&](const ReturnStatement* stmt) -> const ReturnStatement* {
+ if (returns_frag_depth_as_value) {
+ return b.Return(stmt->source, b.Call(base_fn_sym, ctx.Clone(stmt->value)));
+ }
+ if (returns_frag_depth_as_struct_helper.IsValid()) {
+ return b.Return(stmt->source, b.Call(returns_frag_depth_as_struct_helper,
+ ctx.Clone(stmt->value)));
+ }
+ return nullptr;  // neither flag is set: no replacement is built here
+ });
+
+ ctx.Clone();
+ return Program(std::move(b));
+ }
+
+ private:
+ /// @returns true if any fragment function returns frag_depth, directly or in a struct
+ bool ShouldRun() {
+ for (auto* fn : src->AST().Functions()) {
+ if (fn->PipelineStage() == PipelineStage::kFragment &&
+ (ReturnsFragDepthAsValue(fn) || ReturnsFragDepthInStruct(fn))) {
+ return true;
+ }
+ }
+
+ return false;
+ }
+ /// @param attrs the attributes to examine
+ /// @returns true if @p attrs contains a `@builtin(frag_depth)` attribute
+ bool ContainsFragDepth(utils::VectorRef<const Attribute*> attrs) {
+ for (auto* attribute : attrs) {
+ if (auto* builtin_attr = attribute->As<BuiltinAttribute>()) {
+ auto builtin = sem.Get(builtin_attr)->Value();
+ if (builtin == builtin::BuiltinValue::kFragDepth) {
+ return true;
+ }
+ }
+ }
+
+ return false;
+ }
+
+ /// @param fn the function to examine
+ /// @returns true if @p fn has a return type with a `@builtin(frag_depth)` attribute
+ bool ReturnsFragDepthAsValue(const Function* fn) {
+ return ContainsFragDepth(fn->return_type_attributes);
+ }
+
+ /// @param fn the function to examine
+ /// @returns true if @p fn has a return structure with a `@builtin(frag_depth)` attribute on one
+ /// of the members
+ bool ReturnsFragDepthInStruct(const Function* fn) {
+ if (auto* struct_ty = sem.Get(fn)->ReturnType()->As<sem::Struct>()) {
+ for (auto* member : struct_ty->Members()) {
+ if (ContainsFragDepth(member->Declaration()->attributes)) {
+ return true;
+ }
+ }
+ }
+
+ return false;
+ }
+};
+
+ClampFragDepth::ClampFragDepth() = default;  // defaulted out of line
+ClampFragDepth::~ClampFragDepth() = default;  // defaulted out of line
+
+Transform::ApplyResult ClampFragDepth::Apply(const Program* src, const DataMap&, DataMap&) const {
+ return State{src}.Run();  // both DataMap parameters are unnamed and unused here
+}
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/clamp_frag_depth.h b/src/tint/lang/wgsl/ast/transform/clamp_frag_depth.h
new file mode 100644
index 0000000..0338dbd
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/clamp_frag_depth.h
@@ -0,0 +1,75 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_CLAMP_FRAG_DEPTH_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_CLAMP_FRAG_DEPTH_H_
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+
+// Forward declarations
+namespace tint {
+class CloneContext;
+} // namespace tint
+
+namespace tint::ast::transform {
+
+/// Add clamping of the `@builtin(frag_depth)` output of fragment shaders using two push constants
+/// provided by the outside environment. For example the following code:
+///
+/// ```
+/// @fragment fn main() -> @builtin(frag_depth) f32 {
+/// return 0.0;
+/// }
+/// ```
+///
+/// Is transformed to:
+///
+/// ```
+/// enable chromium_experimental_push_constant;
+///
+/// struct FragDepthClampArgs {
+/// min : f32,
+/// max : f32,
+/// }
+///
+/// var<push_constant> frag_depth_clamp_args : FragDepthClampArgs;
+///
+/// fn clamp_frag_depth(v : f32) -> f32 {
+/// return clamp(v, frag_depth_clamp_args.min, frag_depth_clamp_args.max);
+/// }
+///
+/// @fragment
+/// fn main() -> @builtin(frag_depth) f32 {
+/// return clamp_frag_depth(0.0);
+/// }
+/// ```
+class ClampFragDepth final : public utils::Castable<ClampFragDepth, Transform> {
+ public:
+ /// Constructor
+ ClampFragDepth();
+ /// Destructor
+ ~ClampFragDepth() override;
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+
+ private:
+ struct State;
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_CLAMP_FRAG_DEPTH_H_
diff --git a/src/tint/lang/wgsl/ast/transform/clamp_frag_depth_test.cc b/src/tint/lang/wgsl/ast/transform/clamp_frag_depth_test.cc
new file mode 100644
index 0000000..d2d567d
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/clamp_frag_depth_test.cc
@@ -0,0 +1,381 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/clamp_frag_depth.h"
+
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using ClampFragDepthTest = TransformTest;
+
+TEST_F(ClampFragDepthTest, ShouldRunEmptyModule) {
+ auto* src = R"()";
+
+ EXPECT_FALSE(ShouldRun<ClampFragDepth>(src));
+}
+
+TEST_F(ClampFragDepthTest, ShouldRunNoFragmentShader) {
+ auto* src = R"(
+ fn f() -> f32 {
+ return 0.0;
+ }
+
+ @compute @workgroup_size(1) fn cs() {
+ }
+
+ @vertex fn vs() -> @builtin(position) vec4<f32> {
+ return vec4<f32>();
+ }
+ )";
+
+ EXPECT_FALSE(ShouldRun<ClampFragDepth>(src));
+}
+
+TEST_F(ClampFragDepthTest, ShouldRunFragmentShaderNoReturnType) {
+ auto* src = R"(
+ @fragment fn main() {
+ }
+ )";
+
+ EXPECT_FALSE(ShouldRun<ClampFragDepth>(src));
+}
+
+TEST_F(ClampFragDepthTest, ShouldRunFragmentShaderNoFragDepth) {
+ auto* src = R"(
+ @fragment fn main() -> @location(0) f32 {
+ return 0.0;
+ }
+
+ struct S {
+ @location(0) a : f32,
+ @builtin(sample_mask) b : u32,
+ }
+ @fragment fn main2() -> S {
+ return S();
+ }
+ )";
+
+ EXPECT_FALSE(ShouldRun<ClampFragDepth>(src));
+}
+
+TEST_F(ClampFragDepthTest, ShouldRunFragDepthAsDirectReturn) {
+ auto* src = R"(
+ @fragment fn main() -> @builtin(frag_depth) f32 {
+ return 0.0;
+ }
+ )";
+
+ EXPECT_TRUE(ShouldRun<ClampFragDepth>(src));
+}
+
+TEST_F(ClampFragDepthTest, ShouldRunFragDepthInStruct) {
+ auto* src = R"(
+ struct S {
+ @location(0) a : f32,
+ @builtin(frag_depth) b : f32,
+ @location(1) c : f32,
+ }
+ @fragment fn main() -> S {
+ return S();
+ }
+ )";
+
+ EXPECT_TRUE(ShouldRun<ClampFragDepth>(src));
+}
+
+TEST_F(ClampFragDepthTest, SingleReturnOfFragDepth) {
+ auto* src = R"(
+ @fragment fn main() -> @builtin(frag_depth) f32 {
+ return 0.0;
+ }
+ )";
+
+ auto* expect = R"(
+enable chromium_experimental_push_constant;
+
+struct FragDepthClampArgs {
+ min : f32,
+ max : f32,
+}
+
+var<push_constant> frag_depth_clamp_args : FragDepthClampArgs;
+
+fn clamp_frag_depth(v : f32) -> f32 {
+ return clamp(v, frag_depth_clamp_args.min, frag_depth_clamp_args.max);
+}
+
+@fragment
+fn main() -> @builtin(frag_depth) f32 {
+ return clamp_frag_depth(0.0);
+}
+)";
+
+ auto got = Run<ClampFragDepth>(src);
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ClampFragDepthTest, MultipleReturnOfFragDepth) {
+ auto* src = R"(
+ @fragment fn main() -> @builtin(frag_depth) f32 {
+ if (false) {
+ return 1.0;
+ }
+ return 0.0;
+ }
+ )";
+
+ auto* expect = R"(
+enable chromium_experimental_push_constant;
+
+struct FragDepthClampArgs {
+ min : f32,
+ max : f32,
+}
+
+var<push_constant> frag_depth_clamp_args : FragDepthClampArgs;
+
+fn clamp_frag_depth(v : f32) -> f32 {
+ return clamp(v, frag_depth_clamp_args.min, frag_depth_clamp_args.max);
+}
+
+@fragment
+fn main() -> @builtin(frag_depth) f32 {
+ if (false) {
+ return clamp_frag_depth(1.0);
+ }
+ return clamp_frag_depth(0.0);
+}
+)";
+
+ auto got = Run<ClampFragDepth>(src);
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ClampFragDepthTest, OtherFunctionWithoutFragDepth) {
+ auto* src = R"(
+ @fragment fn main() -> @builtin(frag_depth) f32 {
+ return 0.0;
+ }
+ @fragment fn other() -> @location(0) f32 {
+ return 0.0;
+ }
+ )";
+
+ auto* expect = R"(
+enable chromium_experimental_push_constant;
+
+struct FragDepthClampArgs {
+ min : f32,
+ max : f32,
+}
+
+var<push_constant> frag_depth_clamp_args : FragDepthClampArgs;
+
+fn clamp_frag_depth(v : f32) -> f32 {
+ return clamp(v, frag_depth_clamp_args.min, frag_depth_clamp_args.max);
+}
+
+@fragment
+fn main() -> @builtin(frag_depth) f32 {
+ return clamp_frag_depth(0.0);
+}
+
+@fragment
+fn other() -> @location(0) f32 {
+ return 0.0;
+}
+)";
+
+ auto got = Run<ClampFragDepth>(src);
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ClampFragDepthTest, SimpleReturnOfStruct) {
+ auto* src = R"(
+ struct S {
+ @builtin(frag_depth) frag_depth : f32,
+ }
+
+ @fragment fn main() -> S {
+ return S(0.0);
+ }
+ )";
+
+ auto* expect = R"(
+enable chromium_experimental_push_constant;
+
+struct FragDepthClampArgs {
+ min : f32,
+ max : f32,
+}
+
+var<push_constant> frag_depth_clamp_args : FragDepthClampArgs;
+
+fn clamp_frag_depth(v : f32) -> f32 {
+ return clamp(v, frag_depth_clamp_args.min, frag_depth_clamp_args.max);
+}
+
+struct S {
+ @builtin(frag_depth)
+ frag_depth : f32,
+}
+
+fn clamp_frag_depth_S(s : S) -> S {
+ return S(clamp_frag_depth(s.frag_depth));
+}
+
+@fragment
+fn main() -> S {
+ return clamp_frag_depth_S(S(0.0));
+}
+)";
+
+ auto got = Run<ClampFragDepth>(src);
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ClampFragDepthTest, MixOfFunctionReturningStruct) {
+ auto* src = R"(
+ struct S {
+ @builtin(frag_depth) frag_depth : f32,
+ }
+ struct S2 {
+ @builtin(frag_depth) frag_depth : f32,
+ }
+
+ @fragment fn returnS() -> S {
+ return S(0.0);
+ }
+ @fragment fn againReturnS() -> S {
+ return S(0.0);
+ }
+ @fragment fn returnS2() -> S2 {
+ return S2(0.0);
+ }
+ )";
+
+ // clamp_frag_depth_S is emitted only once.
+ // S2 gets its own clamping function.
+ auto* expect = R"(
+enable chromium_experimental_push_constant;
+
+struct FragDepthClampArgs {
+ min : f32,
+ max : f32,
+}
+
+var<push_constant> frag_depth_clamp_args : FragDepthClampArgs;
+
+fn clamp_frag_depth(v : f32) -> f32 {
+ return clamp(v, frag_depth_clamp_args.min, frag_depth_clamp_args.max);
+}
+
+struct S {
+ @builtin(frag_depth)
+ frag_depth : f32,
+}
+
+struct S2 {
+ @builtin(frag_depth)
+ frag_depth : f32,
+}
+
+fn clamp_frag_depth_S(s : S) -> S {
+ return S(clamp_frag_depth(s.frag_depth));
+}
+
+@fragment
+fn returnS() -> S {
+ return clamp_frag_depth_S(S(0.0));
+}
+
+@fragment
+fn againReturnS() -> S {
+ return clamp_frag_depth_S(S(0.0));
+}
+
+fn clamp_frag_depth_S2(s : S2) -> S2 {
+ return S2(clamp_frag_depth(s.frag_depth));
+}
+
+@fragment
+fn returnS2() -> S2 {
+ return clamp_frag_depth_S2(S2(0.0));
+}
+)";
+
+ auto got = Run<ClampFragDepth>(src);
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ClampFragDepthTest, ComplexIOStruct) {
+ auto* src = R"(
+ struct S {
+ @location(0) blou : vec4<f32>,
+ @location(1) bi : vec4<f32>,
+ @builtin(frag_depth) frag_depth : f32,
+ @location(2) boul : i32,
+ @builtin(sample_mask) ga : u32,
+ }
+
+ @fragment fn main() -> S {
+ return S(vec4<f32>(), vec4<f32>(), 0.0, 1, 0u);
+ }
+ )";
+
+ auto* expect = R"(
+enable chromium_experimental_push_constant;
+
+struct FragDepthClampArgs {
+ min : f32,
+ max : f32,
+}
+
+var<push_constant> frag_depth_clamp_args : FragDepthClampArgs;
+
+fn clamp_frag_depth(v : f32) -> f32 {
+ return clamp(v, frag_depth_clamp_args.min, frag_depth_clamp_args.max);
+}
+
+struct S {
+ @location(0)
+ blou : vec4<f32>,
+ @location(1)
+ bi : vec4<f32>,
+ @builtin(frag_depth)
+ frag_depth : f32,
+ @location(2)
+ boul : i32,
+ @builtin(sample_mask)
+ ga : u32,
+}
+
+fn clamp_frag_depth_S(s : S) -> S {
+ return S(s.blou, s.bi, clamp_frag_depth(s.frag_depth), s.boul, s.ga);
+}
+
+@fragment
+fn main() -> S {
+ return clamp_frag_depth_S(S(vec4<f32>(), vec4<f32>(), 0.0, 1, 0u));
+}
+)";
+
+ auto got = Run<ClampFragDepth>(src);
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/combine_samplers.cc b/src/tint/lang/wgsl/ast/transform/combine_samplers.cc
new file mode 100644
index 0000000..2e6775f
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/combine_samplers.cc
@@ -0,0 +1,354 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/combine_samplers.h"
+
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/function.h"
+#include "src/tint/sem/statement.h"
+
+#include "src/tint/utils/map.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::CombineSamplers);
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::CombineSamplers::BindingInfo);
+
+namespace {
+
+bool IsGlobal(const tint::sem::VariablePair& pair) {
+ return pair.first->Is<tint::sem::GlobalVariable>() &&
+ (!pair.second || pair.second->Is<tint::sem::GlobalVariable>());
+}
+
+} // namespace
+
+namespace tint::ast::transform {
+
+using namespace tint::number_suffixes; // NOLINT
+
+CombineSamplers::BindingInfo::BindingInfo(const BindingMap& map,
+ const sem::BindingPoint& placeholder)
+ : binding_map(map), placeholder_binding_point(placeholder) {}
+CombineSamplers::BindingInfo::BindingInfo(const BindingInfo& other) = default;
+CombineSamplers::BindingInfo::~BindingInfo() = default;
+
+/// PIMPL state for the transform
+struct CombineSamplers::State {
+ /// The source program
+ const Program* const src;
+ /// The target program builder
+ ProgramBuilder b;
+ /// The clone context
+ CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
+
+ /// The binding info
+ const BindingInfo* binding_info;
+
+ /// Map from a texture/sampler pair to the corresponding combined sampler
+ /// variable
+ using CombinedTextureSamplerMap = std::unordered_map<sem::VariablePair, const Variable*>;
+
+ /// Use sem::BindingPoint without scope.
+ using BindingPoint = sem::BindingPoint;
+
+ /// A map of all global texture/sampler variable pairs to the global
+ /// combined sampler variable that will replace it.
+ CombinedTextureSamplerMap global_combined_texture_samplers_;
+
+ /// A map of all texture/sampler variable pairs that contain a function
+ /// parameter to the combined sampler function parameter that will replace it.
+ std::unordered_map<const sem::Function*, CombinedTextureSamplerMap>
+ function_combined_texture_samplers_;
+
+ /// Placeholder global samplers used when a function contains texture-only
+ /// references (one comparison sampler, one regular). These are also used as
+ /// temporary sampler parameters to the texture builtins to satisfy the WGSL
+ /// resolver, but are then ignored and removed by the GLSL writer.
+ const Variable* placeholder_samplers_[2] = {};
+
+ /// Group and binding attributes used by all combined sampler globals.
+ /// Group 0 and binding 0 are used, with collisions disabled.
+ /// @returns the newly-created attribute list
+ auto Attributes() const {
+ utils::Vector<const Attribute*, 3> attributes{ctx.dst->Group(0_a), ctx.dst->Binding(0_a)};
+ attributes.Push(ctx.dst->Disable(DisabledValidation::kBindingPointCollision));
+ return attributes;
+ }
+
+ /// Constructor
+ /// @param program the source program
+ /// @param info the binding map information
+ State(const Program* program, const BindingInfo* info) : src(program), binding_info(info) {}
+
+ /// Creates a combined sampler global variable.
+ /// (Note this is actually a Texture node at the AST level, but it will be
+ /// written as the corresponding sampler (eg., sampler2D) on GLSL output.)
+ /// @param texture_var the texture (global) variable
+ /// @param sampler_var the sampler (global) variable
+ /// @param name the default name to use (may be overridden by map lookup)
+ /// @returns the newly-created global variable
+ const Variable* CreateCombinedGlobal(const sem::Variable* texture_var,
+ const sem::Variable* sampler_var,
+ std::string name) {
+ SamplerTexturePair bp_pair;
+ bp_pair.texture_binding_point = *texture_var->As<sem::GlobalVariable>()->BindingPoint();
+ bp_pair.sampler_binding_point =
+ sampler_var ? *sampler_var->As<sem::GlobalVariable>()->BindingPoint()
+ : binding_info->placeholder_binding_point;
+ auto it = binding_info->binding_map.find(bp_pair);
+ if (it != binding_info->binding_map.end()) {
+ name = it->second;
+ }
+ Type type = CreateCombinedASTTypeFor(texture_var, sampler_var);
+ Symbol symbol = ctx.dst->Symbols().New(name);
+ return ctx.dst->GlobalVar(symbol, type, Attributes());
+ }
+
+ /// Creates placeholder global sampler variables.
+ /// @param kind the sampler kind to create for
+ /// @returns the newly-created global variable
+ const Variable* CreatePlaceholder(type::SamplerKind kind) {
+ Type type = ctx.dst->ty.sampler(kind);
+ const char* name = kind == type::SamplerKind::kComparisonSampler
+ ? "placeholder_comparison_sampler"
+ : "placeholder_sampler";
+ Symbol symbol = ctx.dst->Symbols().New(name);
+ return ctx.dst->GlobalVar(symbol, type, Attributes());
+ }
+
+ /// Creates Identifier for a given texture and sampler variable pair.
+ /// Depth textures with no samplers are turned into the corresponding
+ /// f32 texture (e.g., texture_depth_2d -> texture_2d<f32>).
+ /// @param texture the texture variable of interest
+ /// @param sampler the sampler variable of interest (may be null)
+ /// @returns the newly-created type
+ Type CreateCombinedASTTypeFor(const sem::Variable* texture, const sem::Variable* sampler) {
+ const type::Type* texture_type = texture->Type()->UnwrapRef();
+ const type::DepthTexture* depth = texture_type->As<type::DepthTexture>();
+ if (depth && !sampler) {
+ return ctx.dst->ty.sampled_texture(depth->dim(), ctx.dst->ty.f32());
+ } else {
+ return CreateASTTypeFor(ctx, texture_type);
+ }
+ }
+
+ /// Runs the transform
+ /// @returns the new program
+ ApplyResult Run() {
+ auto& sem = ctx.src->Sem();
+
+ // Remove all texture and sampler global variables. These will be replaced
+ // by combined samplers.
+ for (auto* global : ctx.src->AST().GlobalVariables()) {
+ auto* global_sem = sem.Get(global)->As<sem::GlobalVariable>();
+ auto* type = ctx.src->TypeOf(global->type);
+ if (tint::utils::IsAnyOf<type::Texture, type::Sampler>(type) &&
+ !type->Is<type::StorageTexture>()) {
+ ctx.Remove(ctx.src->AST().GlobalDeclarations(), global);
+ } else if (auto binding_point = global_sem->BindingPoint()) {
+ if (binding_point->group == 0 && binding_point->binding == 0) {
+ auto* attribute = ctx.dst->Disable(DisabledValidation::kBindingPointCollision);
+ ctx.InsertFront(global->attributes, attribute);
+ }
+ }
+ }
+
+ // Rewrite all function signatures to use combined samplers, and remove
+ // separate textures & samplers. Create new combined globals where found.
+ ctx.ReplaceAll([&](const Function* ast_fn) -> const Function* {
+ if (auto* fn = sem.Get(ast_fn)) {
+ auto pairs = fn->TextureSamplerPairs();
+ if (pairs.IsEmpty()) {
+ return nullptr;
+ }
+ utils::Vector<const Parameter*, 8> params;
+ for (auto pair : fn->TextureSamplerPairs()) {
+ const sem::Variable* texture_var = pair.first;
+ const sem::Variable* sampler_var = pair.second;
+ std::string name = texture_var->Declaration()->name->symbol.Name();
+ if (sampler_var) {
+ name += "_" + sampler_var->Declaration()->name->symbol.Name();
+ }
+ if (IsGlobal(pair)) {
+ // Both texture and sampler are global; add a new global variable
+ // to represent the combined sampler (if not already created).
+ utils::GetOrCreate(global_combined_texture_samplers_, pair, [&] {
+ return CreateCombinedGlobal(texture_var, sampler_var, name);
+ });
+ } else {
+ // Either texture or sampler (or both) is a function parameter;
+ // add a new function parameter to represent the combined sampler.
+ Type type = CreateCombinedASTTypeFor(texture_var, sampler_var);
+ auto* var = ctx.dst->Param(ctx.dst->Symbols().New(name), type);
+ params.Push(var);
+ function_combined_texture_samplers_[fn][pair] = var;
+ }
+ }
+ // Filter out separate textures and samplers from the original
+ // function signature.
+ for (auto* param : fn->Parameters()) {
+ if (!param->Type()->IsAnyOf<type::Texture, type::Sampler>()) {
+ params.Push(ctx.Clone(param->Declaration()));
+ }
+ }
+ // Create a new function signature that differs only in the parameter
+ // list.
+ auto name = ctx.Clone(ast_fn->name);
+ auto return_type = ctx.Clone(ast_fn->return_type);
+ auto* body = ctx.Clone(ast_fn->body);
+ auto attributes = ctx.Clone(ast_fn->attributes);
+ auto return_type_attributes = ctx.Clone(ast_fn->return_type_attributes);
+ return ctx.dst->create<Function>(name, params, return_type, body,
+ std::move(attributes),
+ std::move(return_type_attributes));
+ }
+ return nullptr;
+ });
+
+ // Replace all function call expressions containing texture or
+ // sampler parameters to use the current function's combined samplers or
+ // the combined global samplers, as appropriate.
+ ctx.ReplaceAll([&](const CallExpression* expr) -> const Expression* {
+ if (auto* call = sem.Get(expr)->UnwrapMaterialize()->As<sem::Call>()) {
+ utils::Vector<const Expression*, 8> args;
+ // Replace all texture builtin calls.
+ if (auto* builtin = call->Target()->As<sem::Builtin>()) {
+ const auto& signature = builtin->Signature();
+ auto sampler_index = signature.IndexOf(sem::ParameterUsage::kSampler);
+ auto texture_index = signature.IndexOf(sem::ParameterUsage::kTexture);
+ if (texture_index == -1) {
+ return nullptr;
+ }
+ const sem::ValueExpression* texture =
+ call->Arguments()[static_cast<size_t>(texture_index)];
+ // We don't want to combine storage textures with anything, since
+ // they never have associated samplers in GLSL.
+ if (texture->Type()->UnwrapRef()->Is<type::StorageTexture>()) {
+ return nullptr;
+ }
+ const sem::ValueExpression* sampler =
+ sampler_index != -1 ? call->Arguments()[static_cast<size_t>(sampler_index)]
+ : nullptr;
+ auto* texture_var = texture->UnwrapLoad()->As<sem::VariableUser>()->Variable();
+ auto* sampler_var =
+ sampler ? sampler->UnwrapLoad()->As<sem::VariableUser>()->Variable()
+ : nullptr;
+ sem::VariablePair new_pair(texture_var, sampler_var);
+ for (auto* arg : expr->args) {
+ auto* type = ctx.src->TypeOf(arg)->UnwrapRef();
+ if (type->Is<type::Texture>()) {
+ const Variable* var =
+ IsGlobal(new_pair)
+ ? global_combined_texture_samplers_[new_pair]
+ : function_combined_texture_samplers_[call->Stmt()->Function()]
+ [new_pair];
+ args.Push(ctx.dst->Expr(var->name->symbol));
+ } else if (auto* sampler_type = type->As<type::Sampler>()) {
+ type::SamplerKind kind = sampler_type->kind();
+ int index = (kind == type::SamplerKind::kSampler) ? 0 : 1;
+ const Variable*& p = placeholder_samplers_[index];
+ if (!p) {
+ p = CreatePlaceholder(kind);
+ }
+ args.Push(ctx.dst->Expr(p->name->symbol));
+ } else {
+ args.Push(ctx.Clone(arg));
+ }
+ }
+ const Expression* value = ctx.dst->Call(ctx.Clone(expr->target), args);
+ if (builtin->Type() == builtin::Function::kTextureLoad &&
+ texture_var->Type()->UnwrapRef()->Is<type::DepthTexture>() &&
+ !call->Stmt()->Declaration()->Is<CallStatement>()) {
+ value = ctx.dst->MemberAccessor(value, "x");
+ }
+ return value;
+ }
+ // Replace all function calls.
+ if (auto* callee = call->Target()->As<sem::Function>()) {
+ for (auto pair : callee->TextureSamplerPairs()) {
+ // Global pairs used by the callee do not require a function
+ // parameter at the call site.
+ if (IsGlobal(pair)) {
+ continue;
+ }
+ const sem::Variable* texture_var = pair.first;
+ const sem::Variable* sampler_var = pair.second;
+ if (auto* param = texture_var->As<sem::Parameter>()) {
+ const sem::ValueExpression* texture = call->Arguments()[param->Index()];
+ texture_var =
+ texture->UnwrapLoad()->As<sem::VariableUser>()->Variable();
+ }
+ if (sampler_var) {
+ if (auto* param = sampler_var->As<sem::Parameter>()) {
+ const sem::ValueExpression* sampler =
+ call->Arguments()[param->Index()];
+ sampler_var =
+ sampler->UnwrapLoad()->As<sem::VariableUser>()->Variable();
+ }
+ }
+ sem::VariablePair new_pair(texture_var, sampler_var);
+ // If both texture and sampler are (now) global, pass that
+ // global variable to the callee. Otherwise use the caller's
+ // function parameter for this pair.
+ const Variable* var =
+ IsGlobal(new_pair)
+ ? global_combined_texture_samplers_[new_pair]
+ : function_combined_texture_samplers_[call->Stmt()->Function()]
+ [new_pair];
+ auto* arg = ctx.dst->Expr(var->name->symbol);
+ args.Push(arg);
+ }
+ // Append all of the remaining non-texture and non-sampler
+ // parameters.
+ for (auto* arg : expr->args) {
+ if (!ctx.src->TypeOf(arg)
+ ->UnwrapRef()
+ ->IsAnyOf<type::Texture, type::Sampler>()) {
+ args.Push(ctx.Clone(arg));
+ }
+ }
+ return ctx.dst->Call(ctx.Clone(expr->target), args);
+ }
+ }
+ return nullptr;
+ });
+
+ ctx.Clone();
+ return Program(std::move(b));
+ }
+};
+
+CombineSamplers::CombineSamplers() = default;
+
+CombineSamplers::~CombineSamplers() = default;
+
+Transform::ApplyResult CombineSamplers::Apply(const Program* src,
+ const DataMap& inputs,
+ DataMap&) const {
+ auto* binding_info = inputs.Get<BindingInfo>();
+ if (!binding_info) {
+ ProgramBuilder b;
+ b.Diagnostics().add_error(diag::System::Transform,
+ "missing transform data for " + std::string(TypeInfo().name));
+ return Program(std::move(b));
+ }
+
+ return State(src, binding_info).Run();
+}
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/combine_samplers.h b/src/tint/lang/wgsl/ast/transform/combine_samplers.h
new file mode 100644
index 0000000..ffc793d
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/combine_samplers.h
@@ -0,0 +1,102 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_COMBINE_SAMPLERS_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_COMBINE_SAMPLERS_H_
+
+#include <string>
+#include <unordered_map>
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+#include "src/tint/sem/sampler_texture_pair.h"
+
+namespace tint::ast::transform {
+
+/// This transform converts all separate texture/sampler references in a
+/// program into combined texture/samplers. This is required for GLSL,
+/// which does not support separate texture/samplers.
+///
+/// It utilizes the texture/sampler information collected by the
+/// Resolver and stored on each sem::Function. For each function, all
+/// separate texture/sampler parameters in the function signature are
+/// removed. For each unique pair, if both texture and sampler are
+/// global variables, the function passes the corresponding combined
+/// global stored in global_combined_texture_samplers_ at the call
+/// site. Otherwise, either the texture or sampler must be a function
+/// parameter. In this case, a new parameter is added to the function
+/// signature. All separate texture/sampler parameters are removed.
+///
+/// All texture builtin callsites are modified to pass the combined
+/// texture/sampler as the first argument, and separate texture/sampler
+/// arguments are removed.
+///
+/// Note that the sampler may be null, indicating that only a texture
+/// reference was required (e.g., textureLoad). In this case, a
+/// placeholder global sampler is used at the AST level. This will be
+/// combined with the original texture to give a combined global, and
+/// the placeholder removed (ignored) by the GLSL writer.
+///
+/// Note that the combined samplers are actually represented by a
+/// Texture node at the AST level, since this contains all the
+/// information needed to represent a combined sampler in GLSL
+/// (dimensionality, component type, etc). The GLSL writer outputs such
+/// (Tint) Textures as (GLSL) Samplers.
+class CombineSamplers final : public utils::Castable<CombineSamplers, Transform> {
+ public:
+ /// A pair of binding points.
+ using SamplerTexturePair = sem::SamplerTexturePair;
+
+ /// A map from a sampler/texture pair to a named global.
+ using BindingMap = std::unordered_map<SamplerTexturePair, std::string>;
+
+ /// The client-provided mapping from separate texture and sampler binding
+ /// points to combined sampler binding point.
+ struct BindingInfo final : public utils::Castable<BindingInfo, Data> {
+ /// Constructor
+ /// @param map the map of all (texture, sampler) -> (combined) pairs
+ /// @param placeholder the binding point to use for placeholder samplers.
+ BindingInfo(const BindingMap& map, const sem::BindingPoint& placeholder);
+
+ /// Copy constructor
+ /// @param other the other BindingInfo to copy
+ BindingInfo(const BindingInfo& other);
+
+ /// Destructor
+ ~BindingInfo() override;
+
+ /// A map of bindings from (texture, sampler) -> combined sampler.
+ BindingMap binding_map;
+
+ /// The binding point to use for placeholder samplers.
+ sem::BindingPoint placeholder_binding_point;
+ };
+
+ /// Constructor
+ CombineSamplers();
+
+ /// Destructor
+ ~CombineSamplers() override;
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+
+ private:
+ struct State;
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_COMBINE_SAMPLERS_H_
diff --git a/src/tint/lang/wgsl/ast/transform/combine_samplers_test.cc b/src/tint/lang/wgsl/ast/transform/combine_samplers_test.cc
new file mode 100644
index 0000000..6308232
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/combine_samplers_test.cc
@@ -0,0 +1,988 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/combine_samplers.h"
+
+#include <memory>
+#include <utility>
+
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using CombineSamplersTest = TransformTest;  // Fixture used by every TEST_F below.
+
+TEST_F(CombineSamplersTest, EmptyModule) {  // Empty source must produce empty output.
+ auto* src = "";
+ auto* expect = "";
+
+ Transform::DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(), sem::BindingPoint());
+ auto got = Run<CombineSamplers>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CombineSamplersTest, SimplePair) {  // Global pair -> combined texture + placeholder sampler.
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+fn main() -> vec4<f32> {
+ return textureSample(t, s, vec2<f32>(1.0, 2.0));
+}
+)";
+ auto* expect = R"(
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var t_s : texture_2d<f32>;
+
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var placeholder_sampler : sampler;
+
+fn main() -> vec4<f32> {
+ return textureSample(t_s, placeholder_sampler, vec2<f32>(1.0, 2.0));
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(), sem::BindingPoint());
+ auto got = Run<CombineSamplers>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CombineSamplersTest, SimplePair_OutOfOrder) {  // SimplePair with globals declared after use.
+ auto* src = R"(
+fn main() -> vec4<f32> {
+ return textureSample(t, s, vec2<f32>(1.0, 2.0));
+}
+
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+)";
+ auto* expect = R"(
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var t_s : texture_2d<f32>;
+
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var placeholder_sampler : sampler;
+
+fn main() -> vec4<f32> {
+ return textureSample(t_s, placeholder_sampler, vec2<f32>(1.0, 2.0));
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(), sem::BindingPoint());
+ auto got = Run<CombineSamplers>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CombineSamplersTest, SimplePairInAFunction) {  // Pair passed as params -> one combined param.
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+fn sample(t : texture_2d<f32>, s : sampler, coords : vec2<f32>) -> vec4<f32> {
+ return textureSample(t, s, coords);
+}
+
+fn main() -> vec4<f32> {
+ return sample(t, s, vec2<f32>(1.0, 2.0));
+}
+)";
+ auto* expect = R"(
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var placeholder_sampler : sampler;
+
+fn sample(t_s : texture_2d<f32>, coords : vec2<f32>) -> vec4<f32> {
+ return textureSample(t_s, placeholder_sampler, coords);
+}
+
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var t_s_1 : texture_2d<f32>;
+
+fn main() -> vec4<f32> {
+ return sample(t_s_1, vec2<f32>(1.0, 2.0));
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(), sem::BindingPoint());
+ auto got = Run<CombineSamplers>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CombineSamplersTest, SimplePairInAFunction_OutOfOrder) {  // Same, declarations after use.
+ auto* src = R"(
+fn main() -> vec4<f32> {
+ return sample(t, s, vec2<f32>(1.0, 2.0));
+}
+
+fn sample(t : texture_2d<f32>, s : sampler, coords : vec2<f32>) -> vec4<f32> {
+ return textureSample(t, s, coords);
+}
+
+@group(0) @binding(1) var s : sampler;
+
+@group(0) @binding(0) var t : texture_2d<f32>;
+)";
+ auto* expect = R"(
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var t_s : texture_2d<f32>;
+
+fn main() -> vec4<f32> {
+ return sample(t_s, vec2<f32>(1.0, 2.0));
+}
+
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var placeholder_sampler : sampler;
+
+fn sample(t_s_1 : texture_2d<f32>, coords : vec2<f32>) -> vec4<f32> {
+ return textureSample(t_s_1, placeholder_sampler, coords);
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(), sem::BindingPoint());
+ auto got = Run<CombineSamplers>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CombineSamplersTest, SimplePairRename) {  // A matching BindingMap entry names the combined var.
+ auto* src = R"(
+@group(0) @binding(1) var t : texture_2d<f32>;
+
+@group(2) @binding(3) var s : sampler;
+
+fn main() -> vec4<f32> {
+ return textureSample(t, s, vec2<f32>(1.0, 2.0));
+}
+)";
+ auto* expect = R"(
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var fuzzy : texture_2d<f32>;
+
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var placeholder_sampler : sampler;
+
+fn main() -> vec4<f32> {
+ return textureSample(fuzzy, placeholder_sampler, vec2<f32>(1.0, 2.0));
+}
+)";
+
+ Transform::DataMap data;
+ CombineSamplers::BindingMap map;
+ sem::SamplerTexturePair pair;
+ pair.texture_binding_point.group = 0;
+ pair.texture_binding_point.binding = 1;
+ pair.sampler_binding_point.group = 2;
+ pair.sampler_binding_point.binding = 3;
+ map[pair] = "fuzzy";  // Key matches the binding points of `t` and `s` in src.
+ sem::BindingPoint placeholder{1024, 0};
+ data.Add<CombineSamplers::BindingInfo>(map, placeholder);
+ auto got = Run<CombineSamplers>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CombineSamplersTest, SimplePairRenameMiss) {  // A non-matching map entry is ignored.
+ auto* src = R"(
+@group(0) @binding(1) var t : texture_2d<f32>;
+
+@group(2) @binding(3) var s : sampler;
+
+fn main() -> vec4<f32> {
+ return textureSample(t, s, vec2<f32>(1.0, 2.0));
+}
+)";
+ auto* expect = R"(
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var t_s : texture_2d<f32>;
+
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var placeholder_sampler : sampler;
+
+fn main() -> vec4<f32> {
+ return textureSample(t_s, placeholder_sampler, vec2<f32>(1.0, 2.0));
+}
+)";
+
+ Transform::DataMap data;
+ CombineSamplers::BindingMap map;
+ sem::SamplerTexturePair pair;
+ pair.texture_binding_point.group = 3;
+ pair.texture_binding_point.binding = 2;
+ pair.sampler_binding_point.group = 1;
+ pair.sampler_binding_point.binding = 0;
+ map[pair] = "fuzzy";  // Key does not match `t`/`s` in src, so "t_s" is expected instead.
+ sem::BindingPoint placeholder{1024, 0};
+ data.Add<CombineSamplers::BindingInfo>(map, placeholder);
+ auto got = Run<CombineSamplers>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CombineSamplersTest, AliasedTypes) {  // Aliased texture types are emitted as the aliased type.
+ auto* src = R"(
+
+alias Tex2d = texture_2d<f32>;
+
+@group(0) @binding(0) var t : Tex2d;
+
+@group(0) @binding(1) var s : sampler;
+
+fn sample(t : Tex2d, s : sampler, coords : vec2<f32>) -> vec4<f32> {
+ return textureSample(t, s, coords);
+}
+
+fn main() -> vec4<f32> {
+ return sample(t, s, vec2<f32>(1.0, 2.0));
+}
+)";
+ auto* expect = R"(
+alias Tex2d = texture_2d<f32>;
+
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var placeholder_sampler : sampler;
+
+fn sample(t_s : texture_2d<f32>, coords : vec2<f32>) -> vec4<f32> {
+ return textureSample(t_s, placeholder_sampler, coords);
+}
+
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var t_s_1 : texture_2d<f32>;
+
+fn main() -> vec4<f32> {
+ return sample(t_s_1, vec2<f32>(1.0, 2.0));
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(), sem::BindingPoint());
+ auto got = Run<CombineSamplers>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CombineSamplersTest, AliasedTypes_OutOfOrder) {  // AliasedTypes with the alias declared last.
+ auto* src = R"(
+fn main() -> vec4<f32> {
+ return sample(t, s, vec2<f32>(1.0, 2.0));
+}
+
+fn sample(t : Tex2d, s : sampler, coords : vec2<f32>) -> vec4<f32> {
+ return textureSample(t, s, coords);
+}
+
+@group(0) @binding(0) var t : Tex2d;
+@group(0) @binding(1) var s : sampler;
+
+alias Tex2d = texture_2d<f32>;
+)";
+ auto* expect = R"(
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var t_s : texture_2d<f32>;
+
+fn main() -> vec4<f32> {
+ return sample(t_s, vec2<f32>(1.0, 2.0));
+}
+
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var placeholder_sampler : sampler;
+
+fn sample(t_s_1 : texture_2d<f32>, coords : vec2<f32>) -> vec4<f32> {
+ return textureSample(t_s_1, placeholder_sampler, coords);
+}
+
+alias Tex2d = texture_2d<f32>;
+)";
+
+ Transform::DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(), sem::BindingPoint());
+ auto got = Run<CombineSamplers>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CombineSamplersTest, SimplePairInTwoFunctions) {  // Pair forwarded through two nested calls.
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+fn g(t : texture_2d<f32>, s : sampler, coords : vec2<f32>) -> vec4<f32> {
+ return textureSample(t, s, coords);
+}
+
+fn f(t : texture_2d<f32>, s : sampler, coords : vec2<f32>) -> vec4<f32> {
+ return g(t, s, coords);
+}
+
+fn main() -> vec4<f32> {
+ return f(t, s, vec2<f32>(1.0, 2.0));
+}
+)";
+ auto* expect = R"(
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var placeholder_sampler : sampler;
+
+fn g(t_s : texture_2d<f32>, coords : vec2<f32>) -> vec4<f32> {
+ return textureSample(t_s, placeholder_sampler, coords);
+}
+
+fn f(t_s_1 : texture_2d<f32>, coords : vec2<f32>) -> vec4<f32> {
+ return g(t_s_1, coords);
+}
+
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var t_s_2 : texture_2d<f32>;
+
+fn main() -> vec4<f32> {
+ return f(t_s_2, vec2<f32>(1.0, 2.0));
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(), sem::BindingPoint());
+ auto got = Run<CombineSamplers>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CombineSamplersTest, SimplePairInTwoFunctions_OutOfOrder) {  // Same, declarations after use.
+ auto* src = R"(
+fn main() -> vec4<f32> {
+ return f(t, s, vec2<f32>(1.0, 2.0));
+}
+
+fn f(t : texture_2d<f32>, s : sampler, coords : vec2<f32>) -> vec4<f32> {
+ return g(t, s, coords);
+}
+
+fn g(t : texture_2d<f32>, s : sampler, coords : vec2<f32>) -> vec4<f32> {
+ return textureSample(t, s, coords);
+}
+
+@group(0) @binding(1) var s : sampler;
+@group(0) @binding(0) var t : texture_2d<f32>;
+)";
+ auto* expect = R"(
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var t_s : texture_2d<f32>;
+
+fn main() -> vec4<f32> {
+ return f(t_s, vec2<f32>(1.0, 2.0));
+}
+
+fn f(t_s_1 : texture_2d<f32>, coords : vec2<f32>) -> vec4<f32> {
+ return g(t_s_1, coords);
+}
+
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var placeholder_sampler : sampler;
+
+fn g(t_s_2 : texture_2d<f32>, coords : vec2<f32>) -> vec4<f32> {
+ return textureSample(t_s_2, placeholder_sampler, coords);
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(), sem::BindingPoint());
+ auto got = Run<CombineSamplers>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CombineSamplersTest, TwoFunctionsGenerateSamePair) {  // Same pair in two fns -> one global.
+ auto* src = R"(
+@group(1) @binding(0) var tex : texture_2d<f32>;
+
+@group(1) @binding(1) var samp : sampler;
+
+fn f() -> vec4<f32> {
+ return textureSample(tex, samp, vec2<f32>(1.0, 2.0));
+}
+
+fn g() -> vec4<f32> {
+ return textureSample(tex, samp, vec2<f32>(3.0, 4.0));
+}
+
+fn main() -> vec4<f32> {
+ return f() + g();
+}
+)";
+ auto* expect = R"(
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var tex_samp : texture_2d<f32>;
+
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var placeholder_sampler : sampler;
+
+fn f() -> vec4<f32> {
+ return textureSample(tex_samp, placeholder_sampler, vec2<f32>(1.0, 2.0));
+}
+
+fn g() -> vec4<f32> {
+ return textureSample(tex_samp, placeholder_sampler, vec2<f32>(3.0, 4.0));
+}
+
+fn main() -> vec4<f32> {
+ return (f() + g());
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(), sem::BindingPoint());
+ auto got = Run<CombineSamplers>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CombineSamplersTest, ThreeTexturesThreeSamplers) {  // 3x3 call-site pairs -> 9 combined globals.
+ auto* src = R"(
+@group(0) @binding(0) var tex1 : texture_2d<f32>;
+@group(0) @binding(1) var tex2 : texture_2d<f32>;
+@group(0) @binding(2) var tex3 : texture_2d<f32>;
+
+@group(1) @binding(0) var samp1 : sampler;
+@group(1) @binding(1) var samp2: sampler;
+@group(1) @binding(2) var samp3: sampler;
+
+fn sample(t : texture_2d<f32>, s : sampler) -> vec4<f32> {
+ return textureSample(t, s, vec2<f32>(1.0, 2.0));
+}
+
+fn main() -> vec4<f32> {
+ return sample(tex1, samp1)
+ + sample(tex1, samp2)
+ + sample(tex1, samp3)
+ + sample(tex2, samp1)
+ + sample(tex2, samp2)
+ + sample(tex2, samp3)
+ + sample(tex3, samp1)
+ + sample(tex3, samp2)
+ + sample(tex3, samp3);
+}
+)";
+ auto* expect = R"(
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var placeholder_sampler : sampler;
+
+fn sample(t_s : texture_2d<f32>) -> vec4<f32> {
+ return textureSample(t_s, placeholder_sampler, vec2<f32>(1.0, 2.0));
+}
+
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var tex1_samp1 : texture_2d<f32>;
+
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var tex1_samp2 : texture_2d<f32>;
+
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var tex1_samp3 : texture_2d<f32>;
+
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var tex2_samp1 : texture_2d<f32>;
+
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var tex2_samp2 : texture_2d<f32>;
+
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var tex2_samp3 : texture_2d<f32>;
+
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var tex3_samp1 : texture_2d<f32>;
+
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var tex3_samp2 : texture_2d<f32>;
+
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var tex3_samp3 : texture_2d<f32>;
+
+fn main() -> vec4<f32> {
+ return ((((((((sample(tex1_samp1) + sample(tex1_samp2)) + sample(tex1_samp3)) + sample(tex2_samp1)) + sample(tex2_samp2)) + sample(tex2_samp3)) + sample(tex3_samp1)) + sample(tex3_samp2)) + sample(tex3_samp3));
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(), sem::BindingPoint());
+ auto got = Run<CombineSamplers>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CombineSamplersTest, TwoFunctionsTwoTexturesDiamond) {  // Two textures share one sampler.
+ auto* src = R"(
+@group(0) @binding(0) var tex1 : texture_2d<f32>;
+
+@group(0) @binding(1) var tex2 : texture_2d<f32>;
+
+@group(0) @binding(2) var samp : sampler;
+
+fn sample(t : texture_2d<f32>, s : sampler, coords : vec2<f32>) -> vec4<f32> {
+ return textureSample(t, s, coords);
+}
+
+fn f(t1 : texture_2d<f32>, t2 : texture_2d<f32>, s : sampler, coords : vec2<f32>) -> vec4<f32> {
+ return sample(t1, s, coords) + sample(t2, s, coords);
+}
+
+fn main() -> vec4<f32> {
+ return f(tex1, tex2, samp, vec2<f32>(1.0, 2.0));
+}
+)";
+ auto* expect = R"(
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var placeholder_sampler : sampler;
+
+fn sample(t_s : texture_2d<f32>, coords : vec2<f32>) -> vec4<f32> {
+ return textureSample(t_s, placeholder_sampler, coords);
+}
+
+fn f(t1_s : texture_2d<f32>, t2_s : texture_2d<f32>, coords : vec2<f32>) -> vec4<f32> {
+ return (sample(t1_s, coords) + sample(t2_s, coords));
+}
+
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var tex1_samp : texture_2d<f32>;
+
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var tex2_samp : texture_2d<f32>;
+
+fn main() -> vec4<f32> {
+ return f(tex1_samp, tex2_samp, vec2<f32>(1.0, 2.0));
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(), sem::BindingPoint());
+ auto got = Run<CombineSamplers>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CombineSamplersTest, TwoFunctionsTwoSamplersDiamond) {  // One texture used with two samplers.
+ auto* src = R"(
+@group(0) @binding(0) var tex : texture_2d<f32>;
+
+@group(0) @binding(1) var samp1 : sampler;
+
+@group(0) @binding(2) var samp2 : sampler;
+
+fn sample(t : texture_2d<f32>, s : sampler, coords : vec2<f32>) -> vec4<f32> {
+ return textureSample(t, s, coords);
+}
+
+fn f(t : texture_2d<f32>, s1 : sampler, s2 : sampler, coords : vec2<f32>) -> vec4<f32> {
+ return sample(t, s1, coords) + sample(t, s2, coords);
+}
+
+fn main() -> vec4<f32> {
+ return f(tex, samp1, samp2, vec2<f32>(1.0, 2.0));
+}
+)";
+ auto* expect = R"(
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var placeholder_sampler : sampler;
+
+fn sample(t_s : texture_2d<f32>, coords : vec2<f32>) -> vec4<f32> {
+ return textureSample(t_s, placeholder_sampler, coords);
+}
+
+fn f(t_s1 : texture_2d<f32>, t_s2 : texture_2d<f32>, coords : vec2<f32>) -> vec4<f32> {
+ return (sample(t_s1, coords) + sample(t_s2, coords));
+}
+
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var tex_samp1 : texture_2d<f32>;
+
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var tex_samp2 : texture_2d<f32>;
+
+fn main() -> vec4<f32> {
+ return f(tex_samp1, tex_samp2, vec2<f32>(1.0, 2.0));
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(), sem::BindingPoint());
+ auto got = Run<CombineSamplers>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CombineSamplersTest, GlobalTextureLocalSampler) {  // Global texture + sampler parameters.
+ auto* src = R"(
+@group(0) @binding(0) var tex : texture_2d<f32>;
+
+@group(0) @binding(1) var samp1 : sampler;
+
+@group(0) @binding(2) var samp2 : sampler;
+
+fn f(s1 : sampler, s2 : sampler, coords : vec2<f32>) -> vec4<f32> {
+ return textureSample(tex, s1, coords) + textureSample(tex, s2, coords);
+}
+
+fn main() -> vec4<f32> {
+ return f(samp1, samp2, vec2<f32>(1.0, 2.0));
+}
+)";
+ auto* expect = R"(
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var placeholder_sampler : sampler;
+
+fn f(tex_s1 : texture_2d<f32>, tex_s2 : texture_2d<f32>, coords : vec2<f32>) -> vec4<f32> {
+ return (textureSample(tex_s1, placeholder_sampler, coords) + textureSample(tex_s2, placeholder_sampler, coords));
+}
+
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var tex_samp1 : texture_2d<f32>;
+
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var tex_samp2 : texture_2d<f32>;
+
+fn main() -> vec4<f32> {
+ return f(tex_samp1, tex_samp2, vec2<f32>(1.0, 2.0));
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(), sem::BindingPoint());
+ auto got = Run<CombineSamplers>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CombineSamplersTest, GlobalTextureLocalSampler_OutOfOrder) {  // Same, declarations after use.
+ auto* src = R"(
+fn main() -> vec4<f32> {
+ return f(samp1, samp2, vec2<f32>(1.0, 2.0));
+}
+
+fn f(s1 : sampler, s2 : sampler, coords : vec2<f32>) -> vec4<f32> {
+ return textureSample(tex, s1, coords) + textureSample(tex, s2, coords);
+}
+
+@group(0) @binding(1) var samp1 : sampler;
+@group(0) @binding(2) var samp2 : sampler;
+@group(0) @binding(0) var tex : texture_2d<f32>;
+)";
+ auto* expect = R"(
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var tex_samp1 : texture_2d<f32>;
+
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var tex_samp2 : texture_2d<f32>;
+
+fn main() -> vec4<f32> {
+ return f(tex_samp1, tex_samp2, vec2<f32>(1.0, 2.0));
+}
+
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var placeholder_sampler : sampler;
+
+fn f(tex_s1 : texture_2d<f32>, tex_s2 : texture_2d<f32>, coords : vec2<f32>) -> vec4<f32> {
+ return (textureSample(tex_s1, placeholder_sampler, coords) + textureSample(tex_s2, placeholder_sampler, coords));
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(), sem::BindingPoint());
+ auto got = Run<CombineSamplers>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CombineSamplersTest, LocalTextureGlobalSampler) {  // Texture parameters + global sampler.
+ auto* src = R"(
+@group(0) @binding(0) var tex1 : texture_2d<f32>;
+
+@group(0) @binding(1) var tex2 : texture_2d<f32>;
+
+@group(0) @binding(2) var samp : sampler;
+
+fn f(t1 : texture_2d<f32>, t2 : texture_2d<f32>, coords : vec2<f32>) -> vec4<f32> {
+ return textureSample(t1, samp, coords) + textureSample(t2, samp, coords);
+}
+
+fn main() -> vec4<f32> {
+ return f(tex1, tex2, vec2<f32>(1.0, 2.0));
+}
+)";
+ auto* expect = R"(
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var placeholder_sampler : sampler;
+
+fn f(t1_samp : texture_2d<f32>, t2_samp : texture_2d<f32>, coords : vec2<f32>) -> vec4<f32> {
+ return (textureSample(t1_samp, placeholder_sampler, coords) + textureSample(t2_samp, placeholder_sampler, coords));
+}
+
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var tex1_samp : texture_2d<f32>;
+
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var tex2_samp : texture_2d<f32>;
+
+fn main() -> vec4<f32> {
+ return f(tex1_samp, tex2_samp, vec2<f32>(1.0, 2.0));
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(), sem::BindingPoint());
+ auto got = Run<CombineSamplers>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CombineSamplersTest, LocalTextureGlobalSampler_OutOfOrder) {  // Same, declarations after use.
+ auto* src = R"(
+fn main() -> vec4<f32> {
+ return f(tex1, tex2, vec2<f32>(1.0, 2.0));
+}
+
+fn f(t1 : texture_2d<f32>, t2 : texture_2d<f32>, coords : vec2<f32>) -> vec4<f32> {
+ return textureSample(t1, samp, coords) + textureSample(t2, samp, coords);
+}
+
+@group(0) @binding(2) var samp : sampler;
+@group(0) @binding(0) var tex1 : texture_2d<f32>;
+@group(0) @binding(1) var tex2 : texture_2d<f32>;
+)";
+ auto* expect = R"(
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var tex1_samp : texture_2d<f32>;
+
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var tex2_samp : texture_2d<f32>;
+
+fn main() -> vec4<f32> {
+ return f(tex1_samp, tex2_samp, vec2<f32>(1.0, 2.0));
+}
+
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var placeholder_sampler : sampler;
+
+fn f(t1_samp : texture_2d<f32>, t2_samp : texture_2d<f32>, coords : vec2<f32>) -> vec4<f32> {
+ return (textureSample(t1_samp, placeholder_sampler, coords) + textureSample(t2_samp, placeholder_sampler, coords));
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(), sem::BindingPoint());
+ auto got = Run<CombineSamplers>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CombineSamplersTest, TextureLoadNoSampler) {  // textureLoad-only texture has no real sampler.
+ auto* src = R"(
+@group(0) @binding(0) var tex : texture_2d<f32>;
+
+fn f(t : texture_2d<f32>, coords : vec2<i32>) -> vec4<f32> {
+ return textureLoad(t, coords, 0);
+}
+
+fn main() -> vec4<f32> {
+ return f(tex, vec2<i32>(1, 2));
+}
+)";
+ auto* expect = R"(
+fn f(t_1 : texture_2d<f32>, coords : vec2<i32>) -> vec4<f32> {
+ return textureLoad(t_1, coords, 0);
+}
+
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var fred : texture_2d<f32>;
+
+fn main() -> vec4<f32> {
+ return f(fred, vec2<i32>(1, 2));
+}
+)";
+
+ sem::BindingPoint placeholder{1024, 0};
+ sem::SamplerTexturePair pair;
+ pair.texture_binding_point.group = 0;
+ pair.texture_binding_point.binding = 0;
+ pair.sampler_binding_point.group = placeholder.group;  // Sampler slot uses the placeholder point.
+ pair.sampler_binding_point.binding = placeholder.binding;
+ CombineSamplers::BindingMap map;
+ map[pair] = "fred";
+ Transform::DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(map, placeholder);
+ auto got = Run<CombineSamplers>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CombineSamplersTest, TextureWithAndWithoutSampler) {  // One texture -> two named globals.
+ auto* src = R"(
+@group(0) @binding(0) var tex : texture_2d<f32>;
+@group(0) @binding(1) var samp : sampler;
+
+fn main() -> vec4<f32> {
+ return textureLoad(tex, vec2<i32>(), 0) +
+ textureSample(tex, samp, vec2<f32>());
+}
+)";
+ auto* expect = R"(
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var fred : texture_2d<f32>;
+
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var barney : texture_2d<f32>;
+
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var placeholder_sampler : sampler;
+
+fn main() -> vec4<f32> {
+ return (textureLoad(fred, vec2<i32>(), 0) + textureSample(barney, placeholder_sampler, vec2<f32>()));
+}
+)";
+
+ sem::BindingPoint placeholder{1024, 0};
+ sem::BindingPoint tex{0, 0};
+ sem::BindingPoint samp{0, 1};
+ sem::SamplerTexturePair pair, placeholder_pair;
+ pair.texture_binding_point.group = tex.group;
+ pair.texture_binding_point.binding = tex.binding;
+ pair.sampler_binding_point.group = samp.group;
+ pair.sampler_binding_point.binding = samp.binding;
+ placeholder_pair.texture_binding_point.group = tex.group;
+ placeholder_pair.texture_binding_point.binding = tex.binding;
+ placeholder_pair.sampler_binding_point.group = placeholder.group;
+ placeholder_pair.sampler_binding_point.binding = placeholder.binding;
+ CombineSamplers::BindingMap map;
+ map[pair] = "barney";  // Name for the textureSample use (tex + samp).
+ map[placeholder_pair] = "fred";  // Name for the textureLoad use (tex + placeholder).
+ Transform::DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(map, placeholder);
+ auto got = Run<CombineSamplers>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CombineSamplersTest, TextureSampleCompare) {  // Comparison samplers get their own placeholder.
+ auto* src = R"(
+@group(0) @binding(0) var tex : texture_depth_2d;
+
+@group(0) @binding(1) var samp : sampler_comparison;
+
+fn main() -> vec4<f32> {
+ return vec4<f32>(textureSampleCompare(tex, samp, vec2<f32>(1.0, 2.0), 0.5));
+}
+)";
+ auto* expect = R"(
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var tex_samp : texture_depth_2d;
+
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var placeholder_comparison_sampler : sampler_comparison;
+
+fn main() -> vec4<f32> {
+ return vec4<f32>(textureSampleCompare(tex_samp, placeholder_comparison_sampler, vec2<f32>(1.0, 2.0), 0.5));
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(), sem::BindingPoint());
+ auto got = Run<CombineSamplers>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CombineSamplersTest, TextureSampleCompareInAFunction) {  // Comparison pair passed as params.
+ auto* src = R"(
+@group(0) @binding(0) var tex : texture_depth_2d;
+
+@group(0) @binding(1) var samp : sampler_comparison;
+
+fn f(t : texture_depth_2d, s : sampler_comparison, coords : vec2<f32>) -> f32 {
+ return textureSampleCompare(t, s, coords, 5.0f);
+}
+
+fn main() -> vec4<f32> {
+ return vec4<f32>(f(tex, samp, vec2<f32>(1.0, 2.0)));
+}
+)";
+ auto* expect = R"(
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var placeholder_comparison_sampler : sampler_comparison;
+
+fn f(t_s : texture_depth_2d, coords : vec2<f32>) -> f32 {
+ return textureSampleCompare(t_s, placeholder_comparison_sampler, coords, 5.0f);
+}
+
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var tex_samp : texture_depth_2d;
+
+fn main() -> vec4<f32> {
+ return vec4<f32>(f(tex_samp, vec2<f32>(1.0, 2.0)));
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(), sem::BindingPoint());
+ auto got = Run<CombineSamplers>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CombineSamplersTest, TextureSampleCompareInAFunction_OutOfOrder) {  // Same, decls after use.
+ auto* src = R"(
+fn main() -> vec4<f32> {
+ return vec4<f32>(f(tex, samp, vec2<f32>(1.0, 2.0)));
+}
+
+fn f(t : texture_depth_2d, s : sampler_comparison, coords : vec2<f32>) -> f32 {
+ return textureSampleCompare(t, s, coords, 5.0f);
+}
+
+@group(0) @binding(0) var tex : texture_depth_2d;
+@group(0) @binding(1) var samp : sampler_comparison;
+)";
+ auto* expect = R"(
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var tex_samp : texture_depth_2d;
+
+fn main() -> vec4<f32> {
+ return vec4<f32>(f(tex_samp, vec2<f32>(1.0, 2.0)));
+}
+
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var placeholder_comparison_sampler : sampler_comparison;
+
+fn f(t_s : texture_depth_2d, coords : vec2<f32>) -> f32 {
+ return textureSampleCompare(t_s, placeholder_comparison_sampler, coords, 5.0f);
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(), sem::BindingPoint());
+ auto got = Run<CombineSamplers>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CombineSamplersTest, BindingPointCollision) {  // Uniform at (0, 0) also gets collision check off.
+ auto* src = R"(
+@group(1) @binding(0) var tex : texture_2d<f32>;
+
+@group(1) @binding(1) var samp : sampler;
+
+@group(0) @binding(0) var<uniform> gcoords : vec2<f32>;
+
+fn main() -> vec4<f32> {
+ return textureSample(tex, samp, gcoords);
+}
+)";
+ auto* expect = R"(
+@internal(disable_validation__binding_point_collision) @group(0) @binding(0) var<uniform> gcoords : vec2<f32>;
+
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var tex_samp : texture_2d<f32>;
+
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var placeholder_sampler : sampler;
+
+fn main() -> vec4<f32> {
+ return textureSample(tex_samp, placeholder_sampler, gcoords);
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(), sem::BindingPoint());
+ auto got = Run<CombineSamplers>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CombineSamplersTest, BindingPointCollision_OutOfOrder) {  // Same, declarations after use.
+ auto* src = R"(
+fn main() -> vec4<f32> {
+ return textureSample(tex, samp, gcoords);
+}
+
+@group(1) @binding(1) var samp : sampler;
+@group(0) @binding(0) var<uniform> gcoords : vec2<f32>;
+@group(1) @binding(0) var tex : texture_2d<f32>;
+
+)";
+ auto* expect = R"(
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var tex_samp : texture_2d<f32>;
+
+@group(0) @binding(0) @internal(disable_validation__binding_point_collision) var placeholder_sampler : sampler;
+
+fn main() -> vec4<f32> {
+ return textureSample(tex_samp, placeholder_sampler, gcoords);
+}
+
+@internal(disable_validation__binding_point_collision) @group(0) @binding(0) var<uniform> gcoords : vec2<f32>;
+)";
+
+ Transform::DataMap data;
+ data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(), sem::BindingPoint());
+ auto got = Run<CombineSamplers>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/decompose_memory_access.cc b/src/tint/lang/wgsl/ast/transform/decompose_memory_access.cc
new file mode 100644
index 0000000..b4a06c7
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/decompose_memory_access.cc
@@ -0,0 +1,988 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/decompose_memory_access.h"
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "src/tint/lang/wgsl/ast/assignment_statement.h"
+#include "src/tint/lang/wgsl/ast/call_statement.h"
+#include "src/tint/lang/wgsl/ast/disable_validation_attribute.h"
+#include "src/tint/lang/wgsl/ast/unary_op.h"
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/call.h"
+#include "src/tint/sem/member_accessor_expression.h"
+#include "src/tint/sem/statement.h"
+#include "src/tint/sem/struct.h"
+#include "src/tint/sem/variable.h"
+#include "src/tint/switch.h"
+#include "src/tint/type/array.h"
+#include "src/tint/type/atomic.h"
+#include "src/tint/type/reference.h"
+#include "src/tint/utils/block_allocator.h"
+#include "src/tint/utils/hash.h"
+#include "src/tint/utils/map.h"
+#include "src/tint/utils/string_stream.h"
+
+using namespace tint::number_suffixes; // NOLINT
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::DecomposeMemoryAccess);
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::DecomposeMemoryAccess::Intrinsic);
+
+namespace tint::ast::transform {
+
+namespace {
+
+bool ShouldRun(const Program* program) {
+ for (auto* decl : program->AST().GlobalDeclarations()) {
+ if (auto* var = program->Sem().Get<sem::Variable>(decl)) {
+ if (var->AddressSpace() == builtin::AddressSpace::kStorage ||
+ var->AddressSpace() == builtin::AddressSpace::kUniform) {
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
+/// Offset is a simple Expression builder interface, used to build byte
+/// offsets for storage and uniform buffer accesses.
+struct Offset : utils::Castable<Offset> {
+ /// @returns builds and returns the Expression in `ctx.dst`
+ virtual const Expression* Build(CloneContext& ctx) const = 0;
+};
+
+/// OffsetExpr is an implementation of Offset that clones and casts the given
+/// expression to `u32`.
+struct OffsetExpr : Offset {
+ const Expression* const expr = nullptr;
+
+ explicit OffsetExpr(const Expression* e) : expr(e) {}
+
+ const Expression* Build(CloneContext& ctx) const override {
+ auto* type = ctx.src->Sem().GetVal(expr)->Type()->UnwrapRef();
+ auto* res = ctx.Clone(expr);
+ if (!type->Is<type::U32>()) {
+ res = ctx.dst->Call<u32>(res);
+ }
+ return res;
+ }
+};
+
+/// OffsetLiteral is an implementation of Offset that constructs a u32 literal
+/// value.
+struct OffsetLiteral final : utils::Castable<OffsetLiteral, Offset> {
+ uint32_t const literal = 0;
+
+ explicit OffsetLiteral(uint32_t lit) : literal(lit) {}
+
+ const Expression* Build(CloneContext& ctx) const override {
+ return ctx.dst->Expr(u32(literal));
+ }
+};
+
+/// OffsetBinOp is an implementation of Offset that constructs a binary-op of
+/// two Offsets.
+struct OffsetBinOp : Offset {
+ BinaryOp op;
+ Offset const* lhs = nullptr;
+ Offset const* rhs = nullptr;
+
+ const Expression* Build(CloneContext& ctx) const override {
+ return ctx.dst->create<BinaryExpression>(op, lhs->Build(ctx), rhs->Build(ctx));
+ }
+};
+
+/// LoadStoreKey is the unordered map key to a load or store intrinsic.
+struct LoadStoreKey {
+ type::Type const* el_ty = nullptr; // element type
+ Symbol const buffer; // buffer name
+ bool operator==(const LoadStoreKey& rhs) const {
+ return el_ty == rhs.el_ty && buffer == rhs.buffer;
+ }
+ struct Hasher {
+ inline std::size_t operator()(const LoadStoreKey& u) const {
+ return utils::Hash(u.el_ty, u.buffer);
+ }
+ };
+};
+
+/// AtomicKey is the unordered map key to an atomic intrinsic.
+struct AtomicKey {
+ type::Type const* el_ty = nullptr; // element type
+ builtin::Function const op; // atomic op
+ Symbol const buffer; // buffer name
+ bool operator==(const AtomicKey& rhs) const {
+ return el_ty == rhs.el_ty && op == rhs.op && buffer == rhs.buffer;
+ }
+ struct Hasher {
+ inline std::size_t operator()(const AtomicKey& u) const {
+ return utils::Hash(u.el_ty, u.op, u.buffer);
+ }
+ };
+};
+
+bool IntrinsicDataTypeFor(const type::Type* ty, DecomposeMemoryAccess::Intrinsic::DataType& out) {
+ if (ty->Is<type::I32>()) {
+ out = DecomposeMemoryAccess::Intrinsic::DataType::kI32;
+ return true;
+ }
+ if (ty->Is<type::U32>()) {
+ out = DecomposeMemoryAccess::Intrinsic::DataType::kU32;
+ return true;
+ }
+ if (ty->Is<type::F32>()) {
+ out = DecomposeMemoryAccess::Intrinsic::DataType::kF32;
+ return true;
+ }
+ if (ty->Is<type::F16>()) {
+ out = DecomposeMemoryAccess::Intrinsic::DataType::kF16;
+ return true;
+ }
+ if (auto* vec = ty->As<type::Vector>()) {
+ switch (vec->Width()) {
+ case 2:
+ if (vec->type()->Is<type::I32>()) {
+ out = DecomposeMemoryAccess::Intrinsic::DataType::kVec2I32;
+ return true;
+ }
+ if (vec->type()->Is<type::U32>()) {
+ out = DecomposeMemoryAccess::Intrinsic::DataType::kVec2U32;
+ return true;
+ }
+ if (vec->type()->Is<type::F32>()) {
+ out = DecomposeMemoryAccess::Intrinsic::DataType::kVec2F32;
+ return true;
+ }
+ if (vec->type()->Is<type::F16>()) {
+ out = DecomposeMemoryAccess::Intrinsic::DataType::kVec2F16;
+ return true;
+ }
+ break;
+ case 3:
+ if (vec->type()->Is<type::I32>()) {
+ out = DecomposeMemoryAccess::Intrinsic::DataType::kVec3I32;
+ return true;
+ }
+ if (vec->type()->Is<type::U32>()) {
+ out = DecomposeMemoryAccess::Intrinsic::DataType::kVec3U32;
+ return true;
+ }
+ if (vec->type()->Is<type::F32>()) {
+ out = DecomposeMemoryAccess::Intrinsic::DataType::kVec3F32;
+ return true;
+ }
+ if (vec->type()->Is<type::F16>()) {
+ out = DecomposeMemoryAccess::Intrinsic::DataType::kVec3F16;
+ return true;
+ }
+ break;
+ case 4:
+ if (vec->type()->Is<type::I32>()) {
+ out = DecomposeMemoryAccess::Intrinsic::DataType::kVec4I32;
+ return true;
+ }
+ if (vec->type()->Is<type::U32>()) {
+ out = DecomposeMemoryAccess::Intrinsic::DataType::kVec4U32;
+ return true;
+ }
+ if (vec->type()->Is<type::F32>()) {
+ out = DecomposeMemoryAccess::Intrinsic::DataType::kVec4F32;
+ return true;
+ }
+ if (vec->type()->Is<type::F16>()) {
+ out = DecomposeMemoryAccess::Intrinsic::DataType::kVec4F16;
+ return true;
+ }
+ break;
+ }
+ return false;
+ }
+
+ return false;
+}
+
+/// @returns a DecomposeMemoryAccess::Intrinsic attribute that can be applied to a stub function to
+/// load the type @p ty from the uniform or storage buffer with name @p buffer.
+DecomposeMemoryAccess::Intrinsic* IntrinsicLoadFor(ProgramBuilder* builder,
+ const type::Type* ty,
+ builtin::AddressSpace address_space,
+ const Symbol& buffer) {
+ DecomposeMemoryAccess::Intrinsic::DataType type;
+ if (!IntrinsicDataTypeFor(ty, type)) {
+ return nullptr;
+ }
+ return builder->ASTNodes().Create<DecomposeMemoryAccess::Intrinsic>(
+ builder->ID(), builder->AllocateNodeID(), DecomposeMemoryAccess::Intrinsic::Op::kLoad, type,
+ address_space, builder->Expr(buffer));
+}
+
+/// @returns a DecomposeMemoryAccess::Intrinsic attribute that can be applied to a stub function to
+/// store the type @p ty to the storage buffer with name @p buffer.
+DecomposeMemoryAccess::Intrinsic* IntrinsicStoreFor(ProgramBuilder* builder,
+ const type::Type* ty,
+ const Symbol& buffer) {
+ DecomposeMemoryAccess::Intrinsic::DataType type;
+ if (!IntrinsicDataTypeFor(ty, type)) {
+ return nullptr;
+ }
+ return builder->ASTNodes().Create<DecomposeMemoryAccess::Intrinsic>(
+ builder->ID(), builder->AllocateNodeID(), DecomposeMemoryAccess::Intrinsic::Op::kStore,
+ type, builtin::AddressSpace::kStorage, builder->Expr(buffer));
+}
+
+/// @returns a DecomposeMemoryAccess::Intrinsic attribute for a stub function performing atomic op
+/// @p ity on storage buffer @p buffer, or nullptr if IntrinsicDataTypeFor() rejects type @p ty.
+DecomposeMemoryAccess::Intrinsic* IntrinsicAtomicFor(ProgramBuilder* builder,
+ builtin::Function ity,
+ const type::Type* ty,
+ const Symbol& buffer) {
+ auto op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicLoad; // kept if `ity` is invalid
+ switch (ity) {
+ case builtin::Function::kAtomicLoad:
+ op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicLoad;
+ break;
+ case builtin::Function::kAtomicStore:
+ op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicStore;
+ break;
+ case builtin::Function::kAtomicAdd:
+ op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicAdd;
+ break;
+ case builtin::Function::kAtomicSub:
+ op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicSub;
+ break;
+ case builtin::Function::kAtomicMax:
+ op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicMax;
+ break;
+ case builtin::Function::kAtomicMin:
+ op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicMin;
+ break;
+ case builtin::Function::kAtomicAnd:
+ op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicAnd;
+ break;
+ case builtin::Function::kAtomicOr:
+ op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicOr;
+ break;
+ case builtin::Function::kAtomicXor:
+ op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicXor;
+ break;
+ case builtin::Function::kAtomicExchange:
+ op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicExchange;
+ break;
+ case builtin::Function::kAtomicCompareExchangeWeak:
+ op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicCompareExchangeWeak;
+ break;
+ default:
+ TINT_ICE(Transform, builder->Diagnostics())
+ << "invalid IntrinsicType for DecomposeMemoryAccess::Intrinsic: "
+ << ity; // report the offending builtin, not the unrelated element type
+ break;
+ }
+
+ DecomposeMemoryAccess::Intrinsic::DataType type;
+ if (!IntrinsicDataTypeFor(ty, type)) {
+ return nullptr;
+ }
+ return builder->ASTNodes().Create<DecomposeMemoryAccess::Intrinsic>(
+ builder->ID(), builder->AllocateNodeID(), op, type, builtin::AddressSpace::kStorage,
+ builder->Expr(buffer));
+}
+
+/// BufferAccess describes a single storage or uniform buffer access
+struct BufferAccess {
+ sem::GlobalVariable const* var = nullptr; // Storage or uniform buffer variable
+ Offset const* offset = nullptr; // The byte offset on var
+ type::Type const* type = nullptr; // The type of the access
+ operator bool() const { return var; } // Returns true if valid
+};
+
+/// Store describes a single storage or uniform buffer write
+struct Store {
+ const AssignmentStatement* assignment; // The AST assignment statement
+ BufferAccess target; // The target for the write
+};
+
+} // namespace
+
+/// PIMPL state for the transform
+struct DecomposeMemoryAccess::State {
+ /// The clone context
+ CloneContext& ctx;
+ /// Alias to `*ctx.dst`
+ ProgramBuilder& b;
+ /// Map of AST expression to storage or uniform buffer access
+ /// This map has entries added when encountered, and removed when outer
+ /// expressions chain the access.
+ /// Subset of #expression_order, as expressions are not removed from
+ /// #expression_order.
+ std::unordered_map<const Expression*, BufferAccess> accesses;
+ /// The visited order of AST expressions (superset of #accesses)
+ std::vector<const Expression*> expression_order;
+ /// [element-type, buffer-name] -> load function name
+ std::unordered_map<LoadStoreKey, Symbol, LoadStoreKey::Hasher> load_funcs;
+ /// [element-type, buffer-name] -> store function name
+ std::unordered_map<LoadStoreKey, Symbol, LoadStoreKey::Hasher> store_funcs;
+ /// [element-type, atomic-op, buffer-name] -> atomic function name
+ std::unordered_map<AtomicKey, Symbol, AtomicKey::Hasher> atomic_funcs;
+ /// List of storage or uniform buffer writes
+ std::vector<Store> stores;
+ /// Allocations for offsets
+ utils::BlockAllocator<Offset> offsets_;
+
+ /// Constructor
+ /// @param context the CloneContext
+ explicit State(CloneContext& context) : ctx(context), b(*ctx.dst) {}
+
+ /// @param offset the offset value to wrap in an Offset
+ /// @returns an Offset for the given literal value
+ const Offset* ToOffset(uint32_t offset) { return offsets_.Create<OffsetLiteral>(offset); }
+
+ /// @param expr the expression to convert to an Offset
+ /// @returns an Offset for the given Expression
+ const Offset* ToOffset(const Expression* expr) {
+ if (auto* lit = expr->As<IntLiteralExpression>()) {
+ if (lit->value >= 0) {
+ return offsets_.Create<OffsetLiteral>(static_cast<uint32_t>(lit->value));
+ }
+ }
+ return offsets_.Create<OffsetExpr>(expr);
+ }
+
+ /// @param offset the Offset that is returned
+ /// @returns the given offset (pass-through)
+ const Offset* ToOffset(const Offset* offset) { return offset; }
+
+ /// @param lhs_ the left-hand side of the add expression
+ /// @param rhs_ the right-hand side of the add expression
+ /// @return an Offset that is a sum of lhs and rhs, performing basic constant
+ /// folding if possible
+ template <typename LHS, typename RHS>
+ const Offset* Add(LHS&& lhs_, RHS&& rhs_) {
+ auto* lhs = ToOffset(std::forward<LHS>(lhs_));
+ auto* rhs = ToOffset(std::forward<RHS>(rhs_));
+ auto* lhs_lit = tint::As<OffsetLiteral>(lhs);
+ auto* rhs_lit = tint::As<OffsetLiteral>(rhs);
+ if (lhs_lit && lhs_lit->literal == 0) {
+ return rhs;
+ }
+ if (rhs_lit && rhs_lit->literal == 0) {
+ return lhs;
+ }
+ if (lhs_lit && rhs_lit) {
+ if (static_cast<uint64_t>(lhs_lit->literal) + static_cast<uint64_t>(rhs_lit->literal) <=
+ 0xffffffff) {
+ return offsets_.Create<OffsetLiteral>(lhs_lit->literal + rhs_lit->literal);
+ }
+ }
+ auto* out = offsets_.Create<OffsetBinOp>();
+ out->op = BinaryOp::kAdd;
+ out->lhs = lhs;
+ out->rhs = rhs;
+ return out;
+ }
+
+ /// @param lhs_ the left-hand side of the multiply expression
+ /// @param rhs_ the right-hand side of the multiply expression
+ /// @return an Offset that is the multiplication of lhs and rhs, performing
+ /// basic constant folding if possible
+ template <typename LHS, typename RHS>
+ const Offset* Mul(LHS&& lhs_, RHS&& rhs_) {
+ auto* lhs = ToOffset(std::forward<LHS>(lhs_));
+ auto* rhs = ToOffset(std::forward<RHS>(rhs_));
+ auto* lhs_lit = tint::As<OffsetLiteral>(lhs);
+ auto* rhs_lit = tint::As<OffsetLiteral>(rhs);
+ if (lhs_lit && lhs_lit->literal == 0) {
+ return offsets_.Create<OffsetLiteral>(0u);
+ }
+ if (rhs_lit && rhs_lit->literal == 0) {
+ return offsets_.Create<OffsetLiteral>(0u);
+ }
+ if (lhs_lit && lhs_lit->literal == 1) {
+ return rhs;
+ }
+ if (rhs_lit && rhs_lit->literal == 1) {
+ return lhs;
+ }
+ if (lhs_lit && rhs_lit) {
+ return offsets_.Create<OffsetLiteral>(lhs_lit->literal * rhs_lit->literal);
+ }
+ auto* out = offsets_.Create<OffsetBinOp>();
+ out->op = BinaryOp::kMultiply;
+ out->lhs = lhs;
+ out->rhs = rhs;
+ return out;
+ }
+
+ /// AddAccess() adds the `expr -> access` map item to #accesses, and `expr`
+ /// to #expression_order.
+ /// @param expr the expression that performs the access
+ /// @param access the access
+ void AddAccess(const Expression* expr, const BufferAccess& access) {
+ TINT_ASSERT(Transform, access.type);
+ accesses.emplace(expr, access);
+ expression_order.emplace_back(expr);
+ }
+
+ /// TakeAccess() removes the `node` item from #accesses (if it exists),
+ /// returning the BufferAccess. If #accesses does not hold an item for
+ /// `node`, an invalid BufferAccess is returned. #expression_order is not modified.
+ /// @param node the expression that performed an access
+ /// @return the BufferAccess for the given expression
+ BufferAccess TakeAccess(const Expression* node) {
+ auto it = accesses.find(node);
+ if (it == accesses.end()) {
+ return {};
+ }
+ auto access = it->second; // copy out before the entry is destroyed
+ accesses.erase(it); // erase via the iterator: no second hash lookup of `node`
+ return access;
+ }
+
+ /// LoadFunc() returns a symbol to an intrinsic function that loads an element of type @p el_ty
+ /// from a storage or uniform buffer with name @p buffer.
+ /// The emitted function has the signature:
+ /// `fn load(offset : u32) -> el_ty`
+ /// @param el_ty the storage or uniform buffer element type
+ /// @param address_space either kUniform or kStorage
+ /// @param buffer the symbol of the storage or uniform buffer variable, owned by the target
+ /// ProgramBuilder.
+ /// @return the name of the function that performs the load
+ Symbol LoadFunc(const type::Type* el_ty,
+ builtin::AddressSpace address_space,
+ const Symbol& buffer) {
+ return utils::GetOrCreate(load_funcs, LoadStoreKey{el_ty, buffer}, [&] {
+ utils::Vector params{b.Param("offset", b.ty.u32())};
+
+ auto name = b.Symbols().New(buffer.Name() + "_load");
+
+ if (auto* intrinsic = IntrinsicLoadFor(ctx.dst, el_ty, address_space, buffer)) {
+ auto el_ast_ty = CreateASTTypeFor(ctx, el_ty);
+ b.Func(name, params, el_ast_ty, nullptr,
+ utils::Vector{
+ intrinsic,
+ b.Disable(DisabledValidation::kFunctionHasNoBody),
+ });
+ } else if (auto* arr_ty = el_ty->As<type::Array>()) {
+ // fn load_func(offset : u32) -> array<T, N> {
+ // var arr : array<T, N>;
+ // for (var i = 0u; i < array_count; i = i + 1) {
+ // arr[i] = el_load_func(offset + i * array_stride);
+ // }
+ // return arr;
+ // }
+ auto load = LoadFunc(arr_ty->ElemType()->UnwrapRef(), address_space, buffer);
+ auto* arr = b.Var(b.Symbols().New("arr"), CreateASTTypeFor(ctx, arr_ty));
+ auto* i = b.Var(b.Symbols().New("i"), b.Expr(0_u));
+ auto* for_init = b.Decl(i);
+ auto arr_cnt = arr_ty->ConstantCount();
+ if (TINT_UNLIKELY(!arr_cnt)) {
+ // Non-constant counts should not be possible:
+ // * Override-expression counts can only be applied to workgroup arrays, and
+ // this method only handles storage and uniform.
+ // * Runtime-sized arrays are not loadable.
+ TINT_ICE(Transform, b.Diagnostics()) << "unexpected non-constant array count";
+ arr_cnt = 1;
+ }
+ auto* for_cond = b.create<BinaryExpression>(BinaryOp::kLessThan, b.Expr(i),
+ b.Expr(u32(arr_cnt.value())));
+ auto* for_cont = b.Assign(i, b.Add(i, 1_u));
+ auto* arr_el = b.IndexAccessor(arr, i);
+ auto* el_offset = b.Add(b.Expr("offset"), b.Mul(i, u32(arr_ty->Stride())));
+ auto* el_val = b.Call(load, el_offset);
+ auto* for_loop =
+ b.For(for_init, for_cond, for_cont, b.Block(b.Assign(arr_el, el_val)));
+
+ b.Func(name, params, CreateASTTypeFor(ctx, arr_ty),
+ utils::Vector{
+ b.Decl(arr),
+ for_loop,
+ b.Return(arr),
+ });
+ } else {
+ utils::Vector<const Expression*, 8> values; // one load call per column / member
+ if (auto* mat_ty = el_ty->As<type::Matrix>()) {
+ auto* vec_ty = mat_ty->ColumnType();
+ Symbol load = LoadFunc(vec_ty, address_space, buffer);
+ for (uint32_t i = 0; i < mat_ty->columns(); i++) {
+ auto* offset = b.Add("offset", u32(i * mat_ty->ColumnStride()));
+ values.Push(b.Call(load, offset));
+ }
+ } else if (auto* str = el_ty->As<type::Struct>()) {
+ for (auto* member : str->Members()) {
+ auto* offset = b.Add("offset", u32(member->Offset()));
+ Symbol load = LoadFunc(member->Type()->UnwrapRef(), address_space, buffer);
+ values.Push(b.Call(load, offset));
+ }
+ }
+ b.Func(name, params, CreateASTTypeFor(ctx, el_ty),
+ utils::Vector{
+ b.Return(b.Call(CreateASTTypeFor(ctx, el_ty), values)),
+ });
+ }
+ return name;
+ });
+ }
+
+ /// StoreFunc() returns a symbol to an intrinsic function that stores an element of type @p
+ /// el_ty to the storage buffer @p buffer. The function has the signature:
+ /// `fn store(offset : u32, value : el_ty)`
+ /// @param el_ty the storage buffer element type
+ /// @param buffer the symbol of the storage buffer variable, owned by the target ProgramBuilder.
+ /// @return the name of the function that performs the store
+ Symbol StoreFunc(const type::Type* el_ty, const Symbol& buffer) {
+ return utils::GetOrCreate(store_funcs, LoadStoreKey{el_ty, buffer}, [&] {
+ utils::Vector params{
+ b.Param("offset", b.ty.u32()),
+ b.Param("value", CreateASTTypeFor(ctx, el_ty)),
+ };
+
+ auto name = b.Symbols().New(buffer.Name() + "_store");
+
+ if (auto* intrinsic = IntrinsicStoreFor(ctx.dst, el_ty, buffer)) {
+ b.Func(name, params, b.ty.void_(), nullptr,
+ utils::Vector{
+ intrinsic,
+ b.Disable(DisabledValidation::kFunctionHasNoBody),
+ });
+ } else {
+ auto body = Switch<utils::Vector<const Statement*, 8>>(
+ el_ty, //
+ [&](const type::Array* arr_ty) {
+ // Emitted function (the buffer is bound by the leaf intrinsic attribute, so it
+ // is not a parameter):
+ // fn store_func(offset : u32, value : el_ty) {
+ // var array = value; // No dynamic indexing on constant arrays
+ // for (var i = 0u; i < array_count; i = i + 1) {
+ // el_store_func(offset + i * array_stride, array[i]);
+ // }
+ // }
+ auto* array = b.Var(b.Symbols().New("array"), b.Expr("value"));
+ auto store = StoreFunc(arr_ty->ElemType()->UnwrapRef(), buffer);
+ auto* i = b.Var(b.Symbols().New("i"), b.Expr(0_u));
+ auto* for_init = b.Decl(i);
+ auto arr_cnt = arr_ty->ConstantCount();
+ if (TINT_UNLIKELY(!arr_cnt)) {
+ // Non-constant counts should not be possible:
+ // * Override-expression counts can only be applied to workgroup
+ // arrays, and this method only handles storage and uniform.
+ // * Runtime-sized arrays are not storable.
+ TINT_ICE(Transform, b.Diagnostics())
+ << "unexpected non-constant array count";
+ arr_cnt = 1;
+ }
+ auto* for_cond = b.create<BinaryExpression>(BinaryOp::kLessThan, b.Expr(i),
+ b.Expr(u32(arr_cnt.value())));
+ auto* for_cont = b.Assign(i, b.Add(i, 1_u));
+ auto* arr_el = b.IndexAccessor(array, i);
+ auto* el_offset = b.Add(b.Expr("offset"), b.Mul(i, u32(arr_ty->Stride())));
+ auto* store_stmt = b.CallStmt(b.Call(store, el_offset, arr_el));
+ auto* for_loop = b.For(for_init, for_cond, for_cont, b.Block(store_stmt));
+
+ return utils::Vector{b.Decl(array), for_loop};
+ },
+ [&](const type::Matrix* mat_ty) {
+ auto* vec_ty = mat_ty->ColumnType();
+ Symbol store = StoreFunc(vec_ty, buffer);
+ utils::Vector<const Statement*, 4> stmts; // one store call per column
+ for (uint32_t i = 0; i < mat_ty->columns(); i++) {
+ auto* offset = b.Add("offset", u32(i * mat_ty->ColumnStride()));
+ auto* element = b.IndexAccessor("value", u32(i));
+ auto* call = b.Call(store, offset, element);
+ stmts.Push(b.CallStmt(call));
+ }
+ return stmts;
+ },
+ [&](const type::Struct* str) {
+ utils::Vector<const Statement*, 8> stmts; // one store call per member
+ for (auto* member : str->Members()) {
+ auto* offset = b.Add("offset", u32(member->Offset()));
+ auto* element = b.MemberAccessor("value", ctx.Clone(member->Name()));
+ Symbol store = StoreFunc(member->Type()->UnwrapRef(), buffer);
+ auto* call = b.Call(store, offset, element);
+ stmts.Push(b.CallStmt(call));
+ }
+ return stmts;
+ });
+
+ b.Func(name, params, b.ty.void_(), body);
+ }
+
+ return name;
+ });
+ }
+
+ /// AtomicFunc() returns a symbol to an intrinsic function that performs an atomic operation on
+ /// the storage buffer @p buffer. The function has the signature:
+ /// `fn atomic_op(offset : u32, ...) -> T`
+ /// @param el_ty the storage buffer element type
+ /// @param intrinsic the atomic intrinsic
+ /// @param buffer the symbol of the storage buffer variable, owned by the target ProgramBuilder.
+ /// @return the name of the function that performs the atomic operation
+ Symbol AtomicFunc(const type::Type* el_ty,
+ const sem::Builtin* intrinsic,
+ const Symbol& buffer) {
+ auto op = intrinsic->Type();
+ return utils::GetOrCreate(atomic_funcs, AtomicKey{el_ty, op, buffer}, [&] {
+ // The first parameter to all WGSL atomics is the expression to the atomic. This is
+ // replaced with a single `offset` parameter; the buffer is passed to the attribute below.
+ utils::Vector params{b.Param("offset", b.ty.u32())};
+
+ // Other parameters are copied as-is:
+ for (size_t i = 1; i < intrinsic->Parameters().Length(); i++) {
+ auto* param = intrinsic->Parameters()[i];
+ auto ty = CreateASTTypeFor(ctx, param->Type());
+ params.Push(b.Param("param_" + std::to_string(i), ty));
+ }
+
+ auto* atomic = IntrinsicAtomicFor(ctx.dst, op, el_ty, buffer);
+ if (TINT_UNLIKELY(!atomic)) {
+ TINT_ICE(Transform, b.Diagnostics())
+ << "IntrinsicAtomicFor() returned nullptr for op " << op << " and type "
+ << el_ty->TypeInfo().name;
+ }
+
+ Type ret_ty = CreateASTTypeFor(ctx, intrinsic->ReturnType());
+
+ auto name = b.Symbols().New(buffer.Name() + intrinsic->str());
+ b.Func(name, std::move(params), ret_ty, nullptr,
+ utils::Vector{
+ atomic,
+ b.Disable(DisabledValidation::kFunctionHasNoBody),
+ });
+ return name;
+ });
+ }
+};
+
+DecomposeMemoryAccess::Intrinsic::Intrinsic(ProgramID pid,
+ NodeID nid,
+ Op o,
+ DataType ty,
+ builtin::AddressSpace as,
+ const IdentifierExpression* buf)
+ : Base(pid, nid, utils::Vector{buf}), op(o), type(ty), address_space(as) {}
+DecomposeMemoryAccess::Intrinsic::~Intrinsic() = default;
+std::string DecomposeMemoryAccess::Intrinsic::InternalName() const {
+ utils::StringStream ss;
+ switch (op) {
+ case Op::kLoad:
+ ss << "intrinsic_load_";
+ break;
+ case Op::kStore:
+ ss << "intrinsic_store_";
+ break;
+ case Op::kAtomicLoad:
+ ss << "intrinsic_atomic_load_";
+ break;
+ case Op::kAtomicStore:
+ ss << "intrinsic_atomic_store_";
+ break;
+ case Op::kAtomicAdd:
+ ss << "intrinsic_atomic_add_";
+ break;
+ case Op::kAtomicSub:
+ ss << "intrinsic_atomic_sub_";
+ break;
+ case Op::kAtomicMax:
+ ss << "intrinsic_atomic_max_";
+ break;
+ case Op::kAtomicMin:
+ ss << "intrinsic_atomic_min_";
+ break;
+ case Op::kAtomicAnd:
+ ss << "intrinsic_atomic_and_";
+ break;
+ case Op::kAtomicOr:
+ ss << "intrinsic_atomic_or_";
+ break;
+ case Op::kAtomicXor:
+ ss << "intrinsic_atomic_xor_";
+ break;
+ case Op::kAtomicExchange:
+ ss << "intrinsic_atomic_exchange_";
+ break;
+ case Op::kAtomicCompareExchangeWeak:
+ ss << "intrinsic_atomic_compare_exchange_weak_";
+ break;
+ }
+ ss << address_space << "_";
+ switch (type) {
+ case DataType::kU32:
+ ss << "u32";
+ break;
+ case DataType::kF32:
+ ss << "f32";
+ break;
+ case DataType::kI32:
+ ss << "i32";
+ break;
+ case DataType::kF16:
+ ss << "f16";
+ break;
+ case DataType::kVec2U32:
+ ss << "vec2_u32";
+ break;
+ case DataType::kVec2F32:
+ ss << "vec2_f32";
+ break;
+ case DataType::kVec2I32:
+ ss << "vec2_i32";
+ break;
+ case DataType::kVec2F16:
+ ss << "vec2_f16";
+ break;
+ case DataType::kVec3U32:
+ ss << "vec3_u32";
+ break;
+ case DataType::kVec3F32:
+ ss << "vec3_f32";
+ break;
+ case DataType::kVec3I32:
+ ss << "vec3_i32";
+ break;
+ case DataType::kVec3F16:
+ ss << "vec3_f16";
+ break;
+ case DataType::kVec4U32:
+ ss << "vec4_u32";
+ break;
+ case DataType::kVec4F32:
+ ss << "vec4_f32";
+ break;
+ case DataType::kVec4I32:
+ ss << "vec4_i32";
+ break;
+ case DataType::kVec4F16:
+ ss << "vec4_f16";
+ break;
+ }
+ return ss.str();
+}
+
+const DecomposeMemoryAccess::Intrinsic* DecomposeMemoryAccess::Intrinsic::Clone(
+ CloneContext* ctx) const {
+ auto buf = ctx->Clone(Buffer());
+ return ctx->dst->ASTNodes().Create<DecomposeMemoryAccess::Intrinsic>(
+ ctx->dst->ID(), ctx->dst->AllocateNodeID(), op, type, address_space, buf);
+}
+
+bool DecomposeMemoryAccess::Intrinsic::IsAtomic() const {
+ return op != Op::kLoad && op != Op::kStore;
+}
+
+const IdentifierExpression* DecomposeMemoryAccess::Intrinsic::Buffer() const {
+ return dependencies[0];
+}
+
+DecomposeMemoryAccess::DecomposeMemoryAccess() = default;
+DecomposeMemoryAccess::~DecomposeMemoryAccess() = default;
+
+Transform::ApplyResult DecomposeMemoryAccess::Apply(const Program* src,
+ const DataMap&,
+ DataMap&) const {
+ if (!ShouldRun(src)) {
+ return SkipTransform;
+ }
+
+ auto& sem = src->Sem();
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+ State state(ctx);
+
+ // Scan the AST nodes for storage and uniform buffer accesses. Complex
+ // expression chains (e.g. `storage_buffer.foo.bar[20].x`) are handled by
+ // maintaining an offset chain via the `state.TakeAccess()`,
+ // `state.AddAccess()` methods.
+ //
+ // Inner-most expression nodes are guaranteed to be visited first because AST
+ // nodes are fully immutable and require their children to be constructed
+ // first so their pointer can be passed to the parent's initializer.
+ for (auto* node : src->ASTNodes().Objects()) {
+ if (auto* ident = node->As<IdentifierExpression>()) {
+ // X
+ if (auto* sem_ident = sem.GetVal(ident)) {
+ if (auto* user = sem_ident->UnwrapLoad()->As<sem::VariableUser>()) {
+ if (auto* global = user->Variable()->As<sem::GlobalVariable>()) {
+ if (global->AddressSpace() == builtin::AddressSpace::kStorage ||
+ global->AddressSpace() == builtin::AddressSpace::kUniform) {
+ // Variable to a storage or uniform buffer
+ state.AddAccess(ident, {
+ global,
+ state.ToOffset(0u),
+ global->Type()->UnwrapRef(),
+ });
+ }
+ }
+ }
+ }
+ continue;
+ }
+
+ if (auto* accessor = node->As<MemberAccessorExpression>()) {
+ // X.Y
+ auto* accessor_sem = sem.Get(accessor)->UnwrapLoad();
+ if (auto* swizzle = accessor_sem->As<sem::Swizzle>()) {
+ if (swizzle->Indices().Length() == 1) {
+ if (auto access = state.TakeAccess(accessor->object)) {
+ auto* vec_ty = access.type->As<type::Vector>();
+ auto* offset = state.Mul(vec_ty->type()->Size(), swizzle->Indices()[0u]);
+ state.AddAccess(accessor, {
+ access.var,
+ state.Add(access.offset, offset),
+ vec_ty->type(),
+ });
+ }
+ }
+ } else {
+ if (auto access = state.TakeAccess(accessor->object)) {
+ auto* str_ty = access.type->As<type::Struct>();
+ auto* member = str_ty->FindMember(accessor->member->symbol);
+ auto offset = member->Offset();
+ state.AddAccess(accessor, {
+ access.var,
+ state.Add(access.offset, offset),
+ member->Type(),
+ });
+ }
+ }
+ continue;
+ }
+
+ if (auto* accessor = node->As<IndexAccessorExpression>()) {
+ if (auto access = state.TakeAccess(accessor->object)) {
+ // X[Y]
+ if (auto* arr = access.type->As<type::Array>()) {
+ auto* offset = state.Mul(arr->Stride(), accessor->index);
+ state.AddAccess(accessor, {
+ access.var,
+ state.Add(access.offset, offset),
+ arr->ElemType(),
+ });
+ continue;
+ }
+ if (auto* vec_ty = access.type->As<type::Vector>()) {
+ auto* offset = state.Mul(vec_ty->type()->Size(), accessor->index);
+ state.AddAccess(accessor, {
+ access.var,
+ state.Add(access.offset, offset),
+ vec_ty->type(),
+ });
+ continue;
+ }
+ if (auto* mat_ty = access.type->As<type::Matrix>()) {
+ auto* offset = state.Mul(mat_ty->ColumnStride(), accessor->index);
+ state.AddAccess(accessor, {
+ access.var,
+ state.Add(access.offset, offset),
+ mat_ty->ColumnType(),
+ });
+ continue;
+ }
+ }
+ }
+
+ if (auto* op = node->As<UnaryOpExpression>()) {
+ if (op->op == UnaryOp::kAddressOf) {
+ // &X
+ if (auto access = state.TakeAccess(op->expr)) {
+ // HLSL does not support pointers, so just take the access from the
+ // reference and place it on the pointer.
+ state.AddAccess(op, access);
+ continue;
+ }
+ }
+ }
+
+ if (auto* assign = node->As<AssignmentStatement>()) {
+ // X = Y
+ // Move the LHS access to a store.
+ if (auto lhs = state.TakeAccess(assign->lhs)) {
+ state.stores.emplace_back(Store{assign, lhs});
+ }
+ }
+
+ if (auto* call_expr = node->As<CallExpression>()) {
+ auto* call = sem.Get(call_expr)->UnwrapMaterialize()->As<sem::Call>();
+ if (auto* builtin = call->Target()->As<sem::Builtin>()) {
+ if (builtin->Type() == builtin::Function::kArrayLength) {
+ // arrayLength(X)
+ // Don't convert X into a load, this builtin actually requires the real pointer.
+ state.TakeAccess(call_expr->args[0]);
+ continue;
+ }
+ if (builtin->IsAtomic()) {
+ if (auto access = state.TakeAccess(call_expr->args[0])) {
+ // atomic___(X)
+ ctx.Replace(call_expr, [=, &ctx, &state] {
+ auto* offset = access.offset->Build(ctx);
+ auto* el_ty = access.type->UnwrapRef()->As<type::Atomic>()->Type();
+ auto buffer = ctx.Clone(access.var->Declaration()->name->symbol);
+ Symbol func = state.AtomicFunc(el_ty, builtin, buffer);
+
+ utils::Vector<const Expression*, 8> args{offset};
+ for (size_t i = 1; i < call_expr->args.Length(); i++) {
+ auto* arg = call_expr->args[i];
+ args.Push(ctx.Clone(arg));
+ }
+ return ctx.dst->Call(func, args);
+ });
+ }
+ }
+ }
+ }
+ }
+
+ // All remaining accesses are loads, transform these into calls to the
+ // corresponding load function
+ // TODO(crbug.com/tint/1784): Use `sem::Load`s instead of maintaining `state.expression_order`.
+ for (auto* expr : state.expression_order) {
+ auto access_it = state.accesses.find(expr);
+ if (access_it == state.accesses.end()) {
+ continue;
+ }
+ BufferAccess access = access_it->second;
+ ctx.Replace(expr, [=, &ctx, &state] {
+ auto* offset = access.offset->Build(ctx);
+ auto* el_ty = access.type->UnwrapRef();
+ auto buffer = ctx.Clone(access.var->Declaration()->name->symbol);
+ Symbol func = state.LoadFunc(el_ty, access.var->AddressSpace(), buffer);
+ return ctx.dst->Call(func, offset);
+ });
+ }
+
+ // And replace all storage and uniform buffer assignments with stores
+ for (auto store : state.stores) {
+ ctx.Replace(store.assignment, [=, &ctx, &state] {
+ auto* offset = store.target.offset->Build(ctx);
+ auto* el_ty = store.target.type->UnwrapRef();
+ auto* value = store.assignment->rhs;
+ auto buffer = ctx.Clone(store.target.var->Declaration()->name->symbol);
+ Symbol func = state.StoreFunc(el_ty, buffer);
+ auto* call = ctx.dst->Call(func, offset, ctx.Clone(value));
+ return ctx.dst->CallStmt(call);
+ });
+ }
+
+ ctx.Clone();
+ return Program(std::move(b));
+}
+
+} // namespace tint::ast::transform
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::Offset);
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::OffsetLiteral);
diff --git a/src/tint/lang/wgsl/ast/transform/decompose_memory_access.h b/src/tint/lang/wgsl/ast/transform/decompose_memory_access.h
new file mode 100644
index 0000000..9b02195
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/decompose_memory_access.h
@@ -0,0 +1,134 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_DECOMPOSE_MEMORY_ACCESS_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_DECOMPOSE_MEMORY_ACCESS_H_
+
+#include <string>
+
+#include "src/tint/lang/wgsl/ast/internal_attribute.h"
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+
+// Forward declaration of the clone context type taken by Intrinsic::Clone()
+namespace tint {
+class CloneContext;
+} // namespace tint
+
+namespace tint::ast::transform {
+
+/// DecomposeMemoryAccess is a transform used to replace storage and uniform buffer accesses with a
+/// combination of load, store or atomic functions on primitive types.
+class DecomposeMemoryAccess final : public utils::Castable<DecomposeMemoryAccess, Transform> {
+ public:
+ /// Intrinsic is an InternalAttribute that's used to decorate a stub function so that HLSL
+ /// code generation can turn calls to that stub into
+ /// `[RW]ByteAddressBuffer.Load[N]()` or `[RW]ByteAddressBuffer.Store[N]()`,
+ /// with a possible cast.
+ class Intrinsic final : public utils::Castable<Intrinsic, InternalAttribute> {
+ public:
+ /// The memory operation the intrinsic stands for: a load, a store, or an atomic operation
+ enum class Op {
+ kLoad,
+ kStore,
+ kAtomicLoad,
+ kAtomicStore,
+ kAtomicAdd,
+ kAtomicSub,
+ kAtomicMax,
+ kAtomicMin,
+ kAtomicAnd,
+ kAtomicOr,
+ kAtomicXor,
+ kAtomicExchange,
+ kAtomicCompareExchangeWeak,
+ };
+
+ /// The scalar or 2-, 3- or 4-element vector type of the data the intrinsic operates on
+ enum class DataType {
+ kU32,
+ kF32,
+ kI32,
+ kF16,
+ kVec2U32,
+ kVec2F32,
+ kVec2I32,
+ kVec2F16,
+ kVec3U32,
+ kVec3F32,
+ kVec3I32,
+ kVec3F16,
+ kVec4U32,
+ kVec4F32,
+ kVec4I32,
+ kVec4F16,
+ };
+
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param o the op of the intrinsic
+ /// @param type the data type of the intrinsic
+ /// @param address_space the address space of the buffer
+ /// @param buffer the storage or uniform buffer identifier
+ Intrinsic(ProgramID pid,
+ NodeID nid,
+ Op o,
+ DataType type,
+ builtin::AddressSpace address_space,
+ const IdentifierExpression* buffer);
+ /// Destructor
+ ~Intrinsic() override;
+
+ /// @return a short description of the internal attribute which will be
+ /// displayed as `@internal(<name>)`
+ std::string InternalName() const override;
+
+ /// Performs a deep clone of this object using the CloneContext `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned object
+ const Intrinsic* Clone(CloneContext* ctx) const override;
+
+ /// @return true if op is atomic
+ bool IsAtomic() const;
+
+ /// @return the buffer that this intrinsic operates on
+ const IdentifierExpression* Buffer() const;
+
+ /// The op of the intrinsic
+ const Op op;
+
+ /// The data type of the intrinsic
+ const DataType type;
+
+ /// The address space of the buffer this intrinsic operates on
+ const builtin::AddressSpace address_space;
+ };
+
+ /// Constructor
+ DecomposeMemoryAccess();
+ /// Destructor
+ ~DecomposeMemoryAccess() override;
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+
+ private:
+ struct State;  // only forward-declared in this header
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_DECOMPOSE_MEMORY_ACCESS_H_
diff --git a/src/tint/lang/wgsl/ast/transform/decompose_memory_access_test.cc b/src/tint/lang/wgsl/ast/transform/decompose_memory_access_test.cc
new file mode 100644
index 0000000..8f77635
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/decompose_memory_access_test.cc
@@ -0,0 +1,3945 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/decompose_memory_access.h"
+
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using DecomposeMemoryAccessTest = TransformTest;  // fixture named by the TEST_F cases below
+
+TEST_F(DecomposeMemoryAccessTest, ShouldRunEmptyModule) {
+    // Empty WGSL source: the transform is expected to report that it should not run.
+    const char* empty_wgsl = R"()";
+    EXPECT_FALSE(ShouldRun<DecomposeMemoryAccess>(empty_wgsl));
+}
+
+TEST_F(DecomposeMemoryAccessTest, ShouldRunStorageBuffer) {
+    const char* storage_wgsl = R"(
+struct Buffer {
+ i : i32,
+};
+@group(0) @binding(0) var<storage, read_write> sb : Buffer;
+)";
+    // Declaring a storage buffer, even with no accesses, is expected to make ShouldRun true.
+    EXPECT_TRUE(ShouldRun<DecomposeMemoryAccess>(storage_wgsl));
+}
+
+TEST_F(DecomposeMemoryAccessTest, ShouldRunUniformBuffer) {
+    const char* uniform_wgsl = R"(
+struct Buffer {
+ i : i32,
+};
+@group(0) @binding(0) var<uniform> ub : Buffer;
+)";
+    // Declaring a uniform buffer, even with no accesses, is expected to make ShouldRun true.
+    EXPECT_TRUE(ShouldRun<DecomposeMemoryAccess>(uniform_wgsl));
+}
+
+TEST_F(DecomposeMemoryAccessTest, SB_BasicLoad) {
+ auto* src = R"(
+enable f16;
+
+struct SB {
+ scalar_f32 : f32,
+ scalar_i32 : i32,
+ scalar_u32 : u32,
+ scalar_f16 : f16,
+ vec2_f32 : vec2<f32>,
+ vec2_i32 : vec2<i32>,
+ vec2_u32 : vec2<u32>,
+ vec2_f16 : vec2<f16>,
+ vec3_f32 : vec3<f32>,
+ vec3_i32 : vec3<i32>,
+ vec3_u32 : vec3<u32>,
+ vec3_f16 : vec3<f16>,
+ vec4_f32 : vec4<f32>,
+ vec4_i32 : vec4<i32>,
+ vec4_u32 : vec4<u32>,
+ vec4_f16 : vec4<f16>,
+ mat2x2_f32 : mat2x2<f32>,
+ mat2x3_f32 : mat2x3<f32>,
+ mat2x4_f32 : mat2x4<f32>,
+ mat3x2_f32 : mat3x2<f32>,
+ mat3x3_f32 : mat3x3<f32>,
+ mat3x4_f32 : mat3x4<f32>,
+ mat4x2_f32 : mat4x2<f32>,
+ mat4x3_f32 : mat4x3<f32>,
+ mat4x4_f32 : mat4x4<f32>,
+ mat2x2_f16 : mat2x2<f16>,
+ mat2x3_f16 : mat2x3<f16>,
+ mat2x4_f16 : mat2x4<f16>,
+ mat3x2_f16 : mat3x2<f16>,
+ mat3x3_f16 : mat3x3<f16>,
+ mat3x4_f16 : mat3x4<f16>,
+ mat4x2_f16 : mat4x2<f16>,
+ mat4x3_f16 : mat4x3<f16>,
+ mat4x4_f16 : mat4x4<f16>,
+ arr2_vec3_f32 : array<vec3<f32>, 2>,
+ arr2_vec3_f16 : array<vec3<f16>, 2>,
+};
+
+@group(0) @binding(0) var<storage, read_write> sb : SB;
+
+@compute @workgroup_size(1)
+fn main() {
+ var scalar_f32 : f32 = sb.scalar_f32;
+ var scalar_i32 : i32 = sb.scalar_i32;
+ var scalar_u32 : u32 = sb.scalar_u32;
+ var scalar_f16 : f16 = sb.scalar_f16;
+ var vec2_f32 : vec2<f32> = sb.vec2_f32;
+ var vec2_i32 : vec2<i32> = sb.vec2_i32;
+ var vec2_u32 : vec2<u32> = sb.vec2_u32;
+ var vec2_f16 : vec2<f16> = sb.vec2_f16;
+ var vec3_f32 : vec3<f32> = sb.vec3_f32;
+ var vec3_i32 : vec3<i32> = sb.vec3_i32;
+ var vec3_u32 : vec3<u32> = sb.vec3_u32;
+ var vec3_f16 : vec3<f16> = sb.vec3_f16;
+ var vec4_f32 : vec4<f32> = sb.vec4_f32;
+ var vec4_i32 : vec4<i32> = sb.vec4_i32;
+ var vec4_u32 : vec4<u32> = sb.vec4_u32;
+ var vec4_f16 : vec4<f16> = sb.vec4_f16;
+ var mat2x2_f32 : mat2x2<f32> = sb.mat2x2_f32;
+ var mat2x3_f32 : mat2x3<f32> = sb.mat2x3_f32;
+ var mat2x4_f32 : mat2x4<f32> = sb.mat2x4_f32;
+ var mat3x2_f32 : mat3x2<f32> = sb.mat3x2_f32;
+ var mat3x3_f32 : mat3x3<f32> = sb.mat3x3_f32;
+ var mat3x4_f32 : mat3x4<f32> = sb.mat3x4_f32;
+ var mat4x2_f32 : mat4x2<f32> = sb.mat4x2_f32;
+ var mat4x3_f32 : mat4x3<f32> = sb.mat4x3_f32;
+ var mat4x4_f32 : mat4x4<f32> = sb.mat4x4_f32;
+ var mat2x2_f16 : mat2x2<f16> = sb.mat2x2_f16;
+ var mat2x3_f16 : mat2x3<f16> = sb.mat2x3_f16;
+ var mat2x4_f16 : mat2x4<f16> = sb.mat2x4_f16;
+ var mat3x2_f16 : mat3x2<f16> = sb.mat3x2_f16;
+ var mat3x3_f16 : mat3x3<f16> = sb.mat3x3_f16;
+ var mat3x4_f16 : mat3x4<f16> = sb.mat3x4_f16;
+ var mat4x2_f16 : mat4x2<f16> = sb.mat4x2_f16;
+ var mat4x3_f16 : mat4x3<f16> = sb.mat4x3_f16;
+ var mat4x4_f16 : mat4x4<f16> = sb.mat4x4_f16;
+ var arr2_vec3_f32 : array<vec3<f32>, 2> = sb.arr2_vec3_f32;
+ var arr2_vec3_f16 : array<vec3<f16>, 2> = sb.arr2_vec3_f16;
+}
+)";
+ // Expected: every `sb.<member>` read becomes a call to a generated `sb_load*` function.
+ auto* expect = R"(
+enable f16;
+
+struct SB {
+ scalar_f32 : f32,
+ scalar_i32 : i32,
+ scalar_u32 : u32,
+ scalar_f16 : f16,
+ vec2_f32 : vec2<f32>,
+ vec2_i32 : vec2<i32>,
+ vec2_u32 : vec2<u32>,
+ vec2_f16 : vec2<f16>,
+ vec3_f32 : vec3<f32>,
+ vec3_i32 : vec3<i32>,
+ vec3_u32 : vec3<u32>,
+ vec3_f16 : vec3<f16>,
+ vec4_f32 : vec4<f32>,
+ vec4_i32 : vec4<i32>,
+ vec4_u32 : vec4<u32>,
+ vec4_f16 : vec4<f16>,
+ mat2x2_f32 : mat2x2<f32>,
+ mat2x3_f32 : mat2x3<f32>,
+ mat2x4_f32 : mat2x4<f32>,
+ mat3x2_f32 : mat3x2<f32>,
+ mat3x3_f32 : mat3x3<f32>,
+ mat3x4_f32 : mat3x4<f32>,
+ mat4x2_f32 : mat4x2<f32>,
+ mat4x3_f32 : mat4x3<f32>,
+ mat4x4_f32 : mat4x4<f32>,
+ mat2x2_f16 : mat2x2<f16>,
+ mat2x3_f16 : mat2x3<f16>,
+ mat2x4_f16 : mat2x4<f16>,
+ mat3x2_f16 : mat3x2<f16>,
+ mat3x3_f16 : mat3x3<f16>,
+ mat3x4_f16 : mat3x4<f16>,
+ mat4x2_f16 : mat4x2<f16>,
+ mat4x3_f16 : mat4x3<f16>,
+ mat4x4_f16 : mat4x4<f16>,
+ arr2_vec3_f32 : array<vec3<f32>, 2>,
+ arr2_vec3_f16 : array<vec3<f16>, 2>,
+}
+
+@group(0) @binding(0) var<storage, read_write> sb : SB;
+
+@internal(intrinsic_load_storage_f32) @internal(disable_validation__function_has_no_body)
+fn sb_load(offset : u32) -> f32
+
+@internal(intrinsic_load_storage_i32) @internal(disable_validation__function_has_no_body)
+fn sb_load_1(offset : u32) -> i32
+
+@internal(intrinsic_load_storage_u32) @internal(disable_validation__function_has_no_body)
+fn sb_load_2(offset : u32) -> u32
+
+@internal(intrinsic_load_storage_f16) @internal(disable_validation__function_has_no_body)
+fn sb_load_3(offset : u32) -> f16
+
+@internal(intrinsic_load_storage_vec2_f32) @internal(disable_validation__function_has_no_body)
+fn sb_load_4(offset : u32) -> vec2<f32>
+
+@internal(intrinsic_load_storage_vec2_i32) @internal(disable_validation__function_has_no_body)
+fn sb_load_5(offset : u32) -> vec2<i32>
+
+@internal(intrinsic_load_storage_vec2_u32) @internal(disable_validation__function_has_no_body)
+fn sb_load_6(offset : u32) -> vec2<u32>
+
+@internal(intrinsic_load_storage_vec2_f16) @internal(disable_validation__function_has_no_body)
+fn sb_load_7(offset : u32) -> vec2<f16>
+
+@internal(intrinsic_load_storage_vec3_f32) @internal(disable_validation__function_has_no_body)
+fn sb_load_8(offset : u32) -> vec3<f32>
+
+@internal(intrinsic_load_storage_vec3_i32) @internal(disable_validation__function_has_no_body)
+fn sb_load_9(offset : u32) -> vec3<i32>
+
+@internal(intrinsic_load_storage_vec3_u32) @internal(disable_validation__function_has_no_body)
+fn sb_load_10(offset : u32) -> vec3<u32>
+
+@internal(intrinsic_load_storage_vec3_f16) @internal(disable_validation__function_has_no_body)
+fn sb_load_11(offset : u32) -> vec3<f16>
+
+@internal(intrinsic_load_storage_vec4_f32) @internal(disable_validation__function_has_no_body)
+fn sb_load_12(offset : u32) -> vec4<f32>
+
+@internal(intrinsic_load_storage_vec4_i32) @internal(disable_validation__function_has_no_body)
+fn sb_load_13(offset : u32) -> vec4<i32>
+
+@internal(intrinsic_load_storage_vec4_u32) @internal(disable_validation__function_has_no_body)
+fn sb_load_14(offset : u32) -> vec4<u32>
+
+@internal(intrinsic_load_storage_vec4_f16) @internal(disable_validation__function_has_no_body)
+fn sb_load_15(offset : u32) -> vec4<f16>
+
+fn sb_load_16(offset : u32) -> mat2x2<f32> {
+ return mat2x2<f32>(sb_load_4((offset + 0u)), sb_load_4((offset + 8u)));
+}
+
+fn sb_load_17(offset : u32) -> mat2x3<f32> {
+ return mat2x3<f32>(sb_load_8((offset + 0u)), sb_load_8((offset + 16u)));
+}
+
+fn sb_load_18(offset : u32) -> mat2x4<f32> {
+ return mat2x4<f32>(sb_load_12((offset + 0u)), sb_load_12((offset + 16u)));
+}
+
+fn sb_load_19(offset : u32) -> mat3x2<f32> {
+ return mat3x2<f32>(sb_load_4((offset + 0u)), sb_load_4((offset + 8u)), sb_load_4((offset + 16u)));
+}
+
+fn sb_load_20(offset : u32) -> mat3x3<f32> {
+ return mat3x3<f32>(sb_load_8((offset + 0u)), sb_load_8((offset + 16u)), sb_load_8((offset + 32u)));
+}
+
+fn sb_load_21(offset : u32) -> mat3x4<f32> {
+ return mat3x4<f32>(sb_load_12((offset + 0u)), sb_load_12((offset + 16u)), sb_load_12((offset + 32u)));
+}
+
+fn sb_load_22(offset : u32) -> mat4x2<f32> {
+ return mat4x2<f32>(sb_load_4((offset + 0u)), sb_load_4((offset + 8u)), sb_load_4((offset + 16u)), sb_load_4((offset + 24u)));
+}
+
+fn sb_load_23(offset : u32) -> mat4x3<f32> {
+ return mat4x3<f32>(sb_load_8((offset + 0u)), sb_load_8((offset + 16u)), sb_load_8((offset + 32u)), sb_load_8((offset + 48u)));
+}
+
+fn sb_load_24(offset : u32) -> mat4x4<f32> {
+ return mat4x4<f32>(sb_load_12((offset + 0u)), sb_load_12((offset + 16u)), sb_load_12((offset + 32u)), sb_load_12((offset + 48u)));
+}
+
+fn sb_load_25(offset : u32) -> mat2x2<f16> {
+ return mat2x2<f16>(sb_load_7((offset + 0u)), sb_load_7((offset + 4u)));
+}
+
+fn sb_load_26(offset : u32) -> mat2x3<f16> {
+ return mat2x3<f16>(sb_load_11((offset + 0u)), sb_load_11((offset + 8u)));
+}
+
+fn sb_load_27(offset : u32) -> mat2x4<f16> {
+ return mat2x4<f16>(sb_load_15((offset + 0u)), sb_load_15((offset + 8u)));
+}
+
+fn sb_load_28(offset : u32) -> mat3x2<f16> {
+ return mat3x2<f16>(sb_load_7((offset + 0u)), sb_load_7((offset + 4u)), sb_load_7((offset + 8u)));
+}
+
+fn sb_load_29(offset : u32) -> mat3x3<f16> {
+ return mat3x3<f16>(sb_load_11((offset + 0u)), sb_load_11((offset + 8u)), sb_load_11((offset + 16u)));
+}
+
+fn sb_load_30(offset : u32) -> mat3x4<f16> {
+ return mat3x4<f16>(sb_load_15((offset + 0u)), sb_load_15((offset + 8u)), sb_load_15((offset + 16u)));
+}
+
+fn sb_load_31(offset : u32) -> mat4x2<f16> {
+ return mat4x2<f16>(sb_load_7((offset + 0u)), sb_load_7((offset + 4u)), sb_load_7((offset + 8u)), sb_load_7((offset + 12u)));
+}
+
+fn sb_load_32(offset : u32) -> mat4x3<f16> {
+ return mat4x3<f16>(sb_load_11((offset + 0u)), sb_load_11((offset + 8u)), sb_load_11((offset + 16u)), sb_load_11((offset + 24u)));
+}
+
+fn sb_load_33(offset : u32) -> mat4x4<f16> {
+ return mat4x4<f16>(sb_load_15((offset + 0u)), sb_load_15((offset + 8u)), sb_load_15((offset + 16u)), sb_load_15((offset + 24u)));
+}
+
+fn sb_load_34(offset : u32) -> array<vec3<f32>, 2u> {
+ var arr : array<vec3<f32>, 2u>;
+ for(var i = 0u; (i < 2u); i = (i + 1u)) {
+ arr[i] = sb_load_8((offset + (i * 16u)));
+ }
+ return arr;
+}
+
+fn sb_load_35(offset : u32) -> array<vec3<f16>, 2u> {
+ var arr_1 : array<vec3<f16>, 2u>;
+ for(var i_1 = 0u; (i_1 < 2u); i_1 = (i_1 + 1u)) {
+ arr_1[i_1] = sb_load_11((offset + (i_1 * 8u)));
+ }
+ return arr_1;
+}
+
+@compute @workgroup_size(1)
+fn main() {
+ var scalar_f32 : f32 = sb_load(0u);
+ var scalar_i32 : i32 = sb_load_1(4u);
+ var scalar_u32 : u32 = sb_load_2(8u);
+ var scalar_f16 : f16 = sb_load_3(12u);
+ var vec2_f32 : vec2<f32> = sb_load_4(16u);
+ var vec2_i32 : vec2<i32> = sb_load_5(24u);
+ var vec2_u32 : vec2<u32> = sb_load_6(32u);
+ var vec2_f16 : vec2<f16> = sb_load_7(40u);
+ var vec3_f32 : vec3<f32> = sb_load_8(48u);
+ var vec3_i32 : vec3<i32> = sb_load_9(64u);
+ var vec3_u32 : vec3<u32> = sb_load_10(80u);
+ var vec3_f16 : vec3<f16> = sb_load_11(96u);
+ var vec4_f32 : vec4<f32> = sb_load_12(112u);
+ var vec4_i32 : vec4<i32> = sb_load_13(128u);
+ var vec4_u32 : vec4<u32> = sb_load_14(144u);
+ var vec4_f16 : vec4<f16> = sb_load_15(160u);
+ var mat2x2_f32 : mat2x2<f32> = sb_load_16(168u);
+ var mat2x3_f32 : mat2x3<f32> = sb_load_17(192u);
+ var mat2x4_f32 : mat2x4<f32> = sb_load_18(224u);
+ var mat3x2_f32 : mat3x2<f32> = sb_load_19(256u);
+ var mat3x3_f32 : mat3x3<f32> = sb_load_20(288u);
+ var mat3x4_f32 : mat3x4<f32> = sb_load_21(336u);
+ var mat4x2_f32 : mat4x2<f32> = sb_load_22(384u);
+ var mat4x3_f32 : mat4x3<f32> = sb_load_23(416u);
+ var mat4x4_f32 : mat4x4<f32> = sb_load_24(480u);
+ var mat2x2_f16 : mat2x2<f16> = sb_load_25(544u);
+ var mat2x3_f16 : mat2x3<f16> = sb_load_26(552u);
+ var mat2x4_f16 : mat2x4<f16> = sb_load_27(568u);
+ var mat3x2_f16 : mat3x2<f16> = sb_load_28(584u);
+ var mat3x3_f16 : mat3x3<f16> = sb_load_29(600u);
+ var mat3x4_f16 : mat3x4<f16> = sb_load_30(624u);
+ var mat4x2_f16 : mat4x2<f16> = sb_load_31(648u);
+ var mat4x3_f16 : mat4x3<f16> = sb_load_32(664u);
+ var mat4x4_f16 : mat4x4<f16> = sb_load_33(696u);
+ var arr2_vec3_f32 : array<vec3<f32>, 2> = sb_load_34(736u);
+ var arr2_vec3_f16 : array<vec3<f16>, 2> = sb_load_35(768u);
+}
+)";
+ // Run the transform on `src` and compare its stringified output against `expect`.
+ auto got = Run<DecomposeMemoryAccess>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeMemoryAccessTest, SB_BasicLoad_OutOfOrder) {
+ auto* src = R"(
+enable f16;
+
+@compute @workgroup_size(1)
+fn main() {
+ var scalar_f32 : f32 = sb.scalar_f32;
+ var scalar_i32 : i32 = sb.scalar_i32;
+ var scalar_u32 : u32 = sb.scalar_u32;
+ var scalar_f16 : f16 = sb.scalar_f16;
+ var vec2_f32 : vec2<f32> = sb.vec2_f32;
+ var vec2_i32 : vec2<i32> = sb.vec2_i32;
+ var vec2_u32 : vec2<u32> = sb.vec2_u32;
+ var vec2_f16 : vec2<f16> = sb.vec2_f16;
+ var vec3_f32 : vec3<f32> = sb.vec3_f32;
+ var vec3_i32 : vec3<i32> = sb.vec3_i32;
+ var vec3_u32 : vec3<u32> = sb.vec3_u32;
+ var vec3_f16 : vec3<f16> = sb.vec3_f16;
+ var vec4_f32 : vec4<f32> = sb.vec4_f32;
+ var vec4_i32 : vec4<i32> = sb.vec4_i32;
+ var vec4_u32 : vec4<u32> = sb.vec4_u32;
+ var vec4_f16 : vec4<f16> = sb.vec4_f16;
+ var mat2x2_f32 : mat2x2<f32> = sb.mat2x2_f32;
+ var mat2x3_f32 : mat2x3<f32> = sb.mat2x3_f32;
+ var mat2x4_f32 : mat2x4<f32> = sb.mat2x4_f32;
+ var mat3x2_f32 : mat3x2<f32> = sb.mat3x2_f32;
+ var mat3x3_f32 : mat3x3<f32> = sb.mat3x3_f32;
+ var mat3x4_f32 : mat3x4<f32> = sb.mat3x4_f32;
+ var mat4x2_f32 : mat4x2<f32> = sb.mat4x2_f32;
+ var mat4x3_f32 : mat4x3<f32> = sb.mat4x3_f32;
+ var mat4x4_f32 : mat4x4<f32> = sb.mat4x4_f32;
+ var mat2x2_f16 : mat2x2<f16> = sb.mat2x2_f16;
+ var mat2x3_f16 : mat2x3<f16> = sb.mat2x3_f16;
+ var mat2x4_f16 : mat2x4<f16> = sb.mat2x4_f16;
+ var mat3x2_f16 : mat3x2<f16> = sb.mat3x2_f16;
+ var mat3x3_f16 : mat3x3<f16> = sb.mat3x3_f16;
+ var mat3x4_f16 : mat3x4<f16> = sb.mat3x4_f16;
+ var mat4x2_f16 : mat4x2<f16> = sb.mat4x2_f16;
+ var mat4x3_f16 : mat4x3<f16> = sb.mat4x3_f16;
+ var mat4x4_f16 : mat4x4<f16> = sb.mat4x4_f16;
+ var arr2_vec3_f32 : array<vec3<f32>, 2> = sb.arr2_vec3_f32;
+ var arr2_vec3_f16 : array<vec3<f16>, 2> = sb.arr2_vec3_f16;
+}
+
+@group(0) @binding(0) var<storage, read_write> sb : SB;
+
+struct SB {
+ scalar_f32 : f32,
+ scalar_i32 : i32,
+ scalar_u32 : u32,
+ scalar_f16 : f16,
+ vec2_f32 : vec2<f32>,
+ vec2_i32 : vec2<i32>,
+ vec2_u32 : vec2<u32>,
+ vec2_f16 : vec2<f16>,
+ vec3_f32 : vec3<f32>,
+ vec3_i32 : vec3<i32>,
+ vec3_u32 : vec3<u32>,
+ vec3_f16 : vec3<f16>,
+ vec4_f32 : vec4<f32>,
+ vec4_i32 : vec4<i32>,
+ vec4_u32 : vec4<u32>,
+ vec4_f16 : vec4<f16>,
+ mat2x2_f32 : mat2x2<f32>,
+ mat2x3_f32 : mat2x3<f32>,
+ mat2x4_f32 : mat2x4<f32>,
+ mat3x2_f32 : mat3x2<f32>,
+ mat3x3_f32 : mat3x3<f32>,
+ mat3x4_f32 : mat3x4<f32>,
+ mat4x2_f32 : mat4x2<f32>,
+ mat4x3_f32 : mat4x3<f32>,
+ mat4x4_f32 : mat4x4<f32>,
+ mat2x2_f16 : mat2x2<f16>,
+ mat2x3_f16 : mat2x3<f16>,
+ mat2x4_f16 : mat2x4<f16>,
+ mat3x2_f16 : mat3x2<f16>,
+ mat3x3_f16 : mat3x3<f16>,
+ mat3x4_f16 : mat3x4<f16>,
+ mat4x2_f16 : mat4x2<f16>,
+ mat4x3_f16 : mat4x3<f16>,
+ mat4x4_f16 : mat4x4<f16>,
+ arr2_vec3_f32 : array<vec3<f32>, 2>,
+ arr2_vec3_f16 : array<vec3<f16>, 2>,
+};
+)";
+ // Expected: source order is kept; the generated load functions are emitted before main().
+ auto* expect = R"(
+enable f16;
+
+@internal(intrinsic_load_storage_f32) @internal(disable_validation__function_has_no_body)
+fn sb_load(offset : u32) -> f32
+
+@internal(intrinsic_load_storage_i32) @internal(disable_validation__function_has_no_body)
+fn sb_load_1(offset : u32) -> i32
+
+@internal(intrinsic_load_storage_u32) @internal(disable_validation__function_has_no_body)
+fn sb_load_2(offset : u32) -> u32
+
+@internal(intrinsic_load_storage_f16) @internal(disable_validation__function_has_no_body)
+fn sb_load_3(offset : u32) -> f16
+
+@internal(intrinsic_load_storage_vec2_f32) @internal(disable_validation__function_has_no_body)
+fn sb_load_4(offset : u32) -> vec2<f32>
+
+@internal(intrinsic_load_storage_vec2_i32) @internal(disable_validation__function_has_no_body)
+fn sb_load_5(offset : u32) -> vec2<i32>
+
+@internal(intrinsic_load_storage_vec2_u32) @internal(disable_validation__function_has_no_body)
+fn sb_load_6(offset : u32) -> vec2<u32>
+
+@internal(intrinsic_load_storage_vec2_f16) @internal(disable_validation__function_has_no_body)
+fn sb_load_7(offset : u32) -> vec2<f16>
+
+@internal(intrinsic_load_storage_vec3_f32) @internal(disable_validation__function_has_no_body)
+fn sb_load_8(offset : u32) -> vec3<f32>
+
+@internal(intrinsic_load_storage_vec3_i32) @internal(disable_validation__function_has_no_body)
+fn sb_load_9(offset : u32) -> vec3<i32>
+
+@internal(intrinsic_load_storage_vec3_u32) @internal(disable_validation__function_has_no_body)
+fn sb_load_10(offset : u32) -> vec3<u32>
+
+@internal(intrinsic_load_storage_vec3_f16) @internal(disable_validation__function_has_no_body)
+fn sb_load_11(offset : u32) -> vec3<f16>
+
+@internal(intrinsic_load_storage_vec4_f32) @internal(disable_validation__function_has_no_body)
+fn sb_load_12(offset : u32) -> vec4<f32>
+
+@internal(intrinsic_load_storage_vec4_i32) @internal(disable_validation__function_has_no_body)
+fn sb_load_13(offset : u32) -> vec4<i32>
+
+@internal(intrinsic_load_storage_vec4_u32) @internal(disable_validation__function_has_no_body)
+fn sb_load_14(offset : u32) -> vec4<u32>
+
+@internal(intrinsic_load_storage_vec4_f16) @internal(disable_validation__function_has_no_body)
+fn sb_load_15(offset : u32) -> vec4<f16>
+
+fn sb_load_16(offset : u32) -> mat2x2<f32> {
+ return mat2x2<f32>(sb_load_4((offset + 0u)), sb_load_4((offset + 8u)));
+}
+
+fn sb_load_17(offset : u32) -> mat2x3<f32> {
+ return mat2x3<f32>(sb_load_8((offset + 0u)), sb_load_8((offset + 16u)));
+}
+
+fn sb_load_18(offset : u32) -> mat2x4<f32> {
+ return mat2x4<f32>(sb_load_12((offset + 0u)), sb_load_12((offset + 16u)));
+}
+
+fn sb_load_19(offset : u32) -> mat3x2<f32> {
+ return mat3x2<f32>(sb_load_4((offset + 0u)), sb_load_4((offset + 8u)), sb_load_4((offset + 16u)));
+}
+
+fn sb_load_20(offset : u32) -> mat3x3<f32> {
+ return mat3x3<f32>(sb_load_8((offset + 0u)), sb_load_8((offset + 16u)), sb_load_8((offset + 32u)));
+}
+
+fn sb_load_21(offset : u32) -> mat3x4<f32> {
+ return mat3x4<f32>(sb_load_12((offset + 0u)), sb_load_12((offset + 16u)), sb_load_12((offset + 32u)));
+}
+
+fn sb_load_22(offset : u32) -> mat4x2<f32> {
+ return mat4x2<f32>(sb_load_4((offset + 0u)), sb_load_4((offset + 8u)), sb_load_4((offset + 16u)), sb_load_4((offset + 24u)));
+}
+
+fn sb_load_23(offset : u32) -> mat4x3<f32> {
+ return mat4x3<f32>(sb_load_8((offset + 0u)), sb_load_8((offset + 16u)), sb_load_8((offset + 32u)), sb_load_8((offset + 48u)));
+}
+
+fn sb_load_24(offset : u32) -> mat4x4<f32> {
+ return mat4x4<f32>(sb_load_12((offset + 0u)), sb_load_12((offset + 16u)), sb_load_12((offset + 32u)), sb_load_12((offset + 48u)));
+}
+
+fn sb_load_25(offset : u32) -> mat2x2<f16> {
+ return mat2x2<f16>(sb_load_7((offset + 0u)), sb_load_7((offset + 4u)));
+}
+
+fn sb_load_26(offset : u32) -> mat2x3<f16> {
+ return mat2x3<f16>(sb_load_11((offset + 0u)), sb_load_11((offset + 8u)));
+}
+
+fn sb_load_27(offset : u32) -> mat2x4<f16> {
+ return mat2x4<f16>(sb_load_15((offset + 0u)), sb_load_15((offset + 8u)));
+}
+
+fn sb_load_28(offset : u32) -> mat3x2<f16> {
+ return mat3x2<f16>(sb_load_7((offset + 0u)), sb_load_7((offset + 4u)), sb_load_7((offset + 8u)));
+}
+
+fn sb_load_29(offset : u32) -> mat3x3<f16> {
+ return mat3x3<f16>(sb_load_11((offset + 0u)), sb_load_11((offset + 8u)), sb_load_11((offset + 16u)));
+}
+
+fn sb_load_30(offset : u32) -> mat3x4<f16> {
+ return mat3x4<f16>(sb_load_15((offset + 0u)), sb_load_15((offset + 8u)), sb_load_15((offset + 16u)));
+}
+
+fn sb_load_31(offset : u32) -> mat4x2<f16> {
+ return mat4x2<f16>(sb_load_7((offset + 0u)), sb_load_7((offset + 4u)), sb_load_7((offset + 8u)), sb_load_7((offset + 12u)));
+}
+
+fn sb_load_32(offset : u32) -> mat4x3<f16> {
+ return mat4x3<f16>(sb_load_11((offset + 0u)), sb_load_11((offset + 8u)), sb_load_11((offset + 16u)), sb_load_11((offset + 24u)));
+}
+
+fn sb_load_33(offset : u32) -> mat4x4<f16> {
+ return mat4x4<f16>(sb_load_15((offset + 0u)), sb_load_15((offset + 8u)), sb_load_15((offset + 16u)), sb_load_15((offset + 24u)));
+}
+
+fn sb_load_34(offset : u32) -> array<vec3<f32>, 2u> {
+ var arr : array<vec3<f32>, 2u>;
+ for(var i = 0u; (i < 2u); i = (i + 1u)) {
+ arr[i] = sb_load_8((offset + (i * 16u)));
+ }
+ return arr;
+}
+
+fn sb_load_35(offset : u32) -> array<vec3<f16>, 2u> {
+ var arr_1 : array<vec3<f16>, 2u>;
+ for(var i_1 = 0u; (i_1 < 2u); i_1 = (i_1 + 1u)) {
+ arr_1[i_1] = sb_load_11((offset + (i_1 * 8u)));
+ }
+ return arr_1;
+}
+
+@compute @workgroup_size(1)
+fn main() {
+ var scalar_f32 : f32 = sb_load(0u);
+ var scalar_i32 : i32 = sb_load_1(4u);
+ var scalar_u32 : u32 = sb_load_2(8u);
+ var scalar_f16 : f16 = sb_load_3(12u);
+ var vec2_f32 : vec2<f32> = sb_load_4(16u);
+ var vec2_i32 : vec2<i32> = sb_load_5(24u);
+ var vec2_u32 : vec2<u32> = sb_load_6(32u);
+ var vec2_f16 : vec2<f16> = sb_load_7(40u);
+ var vec3_f32 : vec3<f32> = sb_load_8(48u);
+ var vec3_i32 : vec3<i32> = sb_load_9(64u);
+ var vec3_u32 : vec3<u32> = sb_load_10(80u);
+ var vec3_f16 : vec3<f16> = sb_load_11(96u);
+ var vec4_f32 : vec4<f32> = sb_load_12(112u);
+ var vec4_i32 : vec4<i32> = sb_load_13(128u);
+ var vec4_u32 : vec4<u32> = sb_load_14(144u);
+ var vec4_f16 : vec4<f16> = sb_load_15(160u);
+ var mat2x2_f32 : mat2x2<f32> = sb_load_16(168u);
+ var mat2x3_f32 : mat2x3<f32> = sb_load_17(192u);
+ var mat2x4_f32 : mat2x4<f32> = sb_load_18(224u);
+ var mat3x2_f32 : mat3x2<f32> = sb_load_19(256u);
+ var mat3x3_f32 : mat3x3<f32> = sb_load_20(288u);
+ var mat3x4_f32 : mat3x4<f32> = sb_load_21(336u);
+ var mat4x2_f32 : mat4x2<f32> = sb_load_22(384u);
+ var mat4x3_f32 : mat4x3<f32> = sb_load_23(416u);
+ var mat4x4_f32 : mat4x4<f32> = sb_load_24(480u);
+ var mat2x2_f16 : mat2x2<f16> = sb_load_25(544u);
+ var mat2x3_f16 : mat2x3<f16> = sb_load_26(552u);
+ var mat2x4_f16 : mat2x4<f16> = sb_load_27(568u);
+ var mat3x2_f16 : mat3x2<f16> = sb_load_28(584u);
+ var mat3x3_f16 : mat3x3<f16> = sb_load_29(600u);
+ var mat3x4_f16 : mat3x4<f16> = sb_load_30(624u);
+ var mat4x2_f16 : mat4x2<f16> = sb_load_31(648u);
+ var mat4x3_f16 : mat4x3<f16> = sb_load_32(664u);
+ var mat4x4_f16 : mat4x4<f16> = sb_load_33(696u);
+ var arr2_vec3_f32 : array<vec3<f32>, 2> = sb_load_34(736u);
+ var arr2_vec3_f16 : array<vec3<f16>, 2> = sb_load_35(768u);
+}
+
+@group(0) @binding(0) var<storage, read_write> sb : SB;
+
+struct SB {
+ scalar_f32 : f32,
+ scalar_i32 : i32,
+ scalar_u32 : u32,
+ scalar_f16 : f16,
+ vec2_f32 : vec2<f32>,
+ vec2_i32 : vec2<i32>,
+ vec2_u32 : vec2<u32>,
+ vec2_f16 : vec2<f16>,
+ vec3_f32 : vec3<f32>,
+ vec3_i32 : vec3<i32>,
+ vec3_u32 : vec3<u32>,
+ vec3_f16 : vec3<f16>,
+ vec4_f32 : vec4<f32>,
+ vec4_i32 : vec4<i32>,
+ vec4_u32 : vec4<u32>,
+ vec4_f16 : vec4<f16>,
+ mat2x2_f32 : mat2x2<f32>,
+ mat2x3_f32 : mat2x3<f32>,
+ mat2x4_f32 : mat2x4<f32>,
+ mat3x2_f32 : mat3x2<f32>,
+ mat3x3_f32 : mat3x3<f32>,
+ mat3x4_f32 : mat3x4<f32>,
+ mat4x2_f32 : mat4x2<f32>,
+ mat4x3_f32 : mat4x3<f32>,
+ mat4x4_f32 : mat4x4<f32>,
+ mat2x2_f16 : mat2x2<f16>,
+ mat2x3_f16 : mat2x3<f16>,
+ mat2x4_f16 : mat2x4<f16>,
+ mat3x2_f16 : mat3x2<f16>,
+ mat3x3_f16 : mat3x3<f16>,
+ mat3x4_f16 : mat3x4<f16>,
+ mat4x2_f16 : mat4x2<f16>,
+ mat4x3_f16 : mat4x3<f16>,
+ mat4x4_f16 : mat4x4<f16>,
+ arr2_vec3_f32 : array<vec3<f32>, 2>,
+ arr2_vec3_f16 : array<vec3<f16>, 2>,
+}
+)";
+ // Run the transform on `src` and compare its stringified output against `expect`.
+ auto got = Run<DecomposeMemoryAccess>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeMemoryAccessTest, UB_BasicLoad) {
+ auto* src = R"(
+enable f16;
+
+struct UB {
+ scalar_f32 : f32,
+ scalar_i32 : i32,
+ scalar_u32 : u32,
+ scalar_f16 : f16,
+ vec2_f32 : vec2<f32>,
+ vec2_i32 : vec2<i32>,
+ vec2_u32 : vec2<u32>,
+ vec2_f16 : vec2<f16>,
+ vec3_f32 : vec3<f32>,
+ vec3_i32 : vec3<i32>,
+ vec3_u32 : vec3<u32>,
+ vec3_f16 : vec3<f16>,
+ vec4_f32 : vec4<f32>,
+ vec4_i32 : vec4<i32>,
+ vec4_u32 : vec4<u32>,
+ vec4_f16 : vec4<f16>,
+ mat2x2_f32 : mat2x2<f32>,
+ mat2x3_f32 : mat2x3<f32>,
+ mat2x4_f32 : mat2x4<f32>,
+ mat3x2_f32 : mat3x2<f32>,
+ mat3x3_f32 : mat3x3<f32>,
+ mat3x4_f32 : mat3x4<f32>,
+ mat4x2_f32 : mat4x2<f32>,
+ mat4x3_f32 : mat4x3<f32>,
+ mat4x4_f32 : mat4x4<f32>,
+ mat2x2_f16 : mat2x2<f16>,
+ mat2x3_f16 : mat2x3<f16>,
+ mat2x4_f16 : mat2x4<f16>,
+ mat3x2_f16 : mat3x2<f16>,
+ mat3x3_f16 : mat3x3<f16>,
+ mat3x4_f16 : mat3x4<f16>,
+ mat4x2_f16 : mat4x2<f16>,
+ mat4x3_f16 : mat4x3<f16>,
+ mat4x4_f16 : mat4x4<f16>,
+ arr2_vec3_f32 : array<vec3<f32>, 2>,
+ arr2_mat4x2_f16 : array<mat4x2<f16>, 2>,
+};
+
+@group(0) @binding(0) var<uniform> ub : UB;
+
+@compute @workgroup_size(1)
+fn main() {
+ var scalar_f32 : f32 = ub.scalar_f32;
+ var scalar_i32 : i32 = ub.scalar_i32;
+ var scalar_u32 : u32 = ub.scalar_u32;
+ var scalar_f16 : f16 = ub.scalar_f16;
+ var vec2_f32 : vec2<f32> = ub.vec2_f32;
+ var vec2_i32 : vec2<i32> = ub.vec2_i32;
+ var vec2_u32 : vec2<u32> = ub.vec2_u32;
+ var vec2_f16 : vec2<f16> = ub.vec2_f16;
+ var vec3_f32 : vec3<f32> = ub.vec3_f32;
+ var vec3_i32 : vec3<i32> = ub.vec3_i32;
+ var vec3_u32 : vec3<u32> = ub.vec3_u32;
+ var vec3_f16 : vec3<f16> = ub.vec3_f16;
+ var vec4_f32 : vec4<f32> = ub.vec4_f32;
+ var vec4_i32 : vec4<i32> = ub.vec4_i32;
+ var vec4_u32 : vec4<u32> = ub.vec4_u32;
+ var vec4_f16 : vec4<f16> = ub.vec4_f16;
+ var mat2x2_f32 : mat2x2<f32> = ub.mat2x2_f32;
+ var mat2x3_f32 : mat2x3<f32> = ub.mat2x3_f32;
+ var mat2x4_f32 : mat2x4<f32> = ub.mat2x4_f32;
+ var mat3x2_f32 : mat3x2<f32> = ub.mat3x2_f32;
+ var mat3x3_f32 : mat3x3<f32> = ub.mat3x3_f32;
+ var mat3x4_f32 : mat3x4<f32> = ub.mat3x4_f32;
+ var mat4x2_f32 : mat4x2<f32> = ub.mat4x2_f32;
+ var mat4x3_f32 : mat4x3<f32> = ub.mat4x3_f32;
+ var mat4x4_f32 : mat4x4<f32> = ub.mat4x4_f32;
+ var mat2x2_f16 : mat2x2<f16> = ub.mat2x2_f16;
+ var mat2x3_f16 : mat2x3<f16> = ub.mat2x3_f16;
+ var mat2x4_f16 : mat2x4<f16> = ub.mat2x4_f16;
+ var mat3x2_f16 : mat3x2<f16> = ub.mat3x2_f16;
+ var mat3x3_f16 : mat3x3<f16> = ub.mat3x3_f16;
+ var mat3x4_f16 : mat3x4<f16> = ub.mat3x4_f16;
+ var mat4x2_f16 : mat4x2<f16> = ub.mat4x2_f16;
+ var mat4x3_f16 : mat4x3<f16> = ub.mat4x3_f16;
+ var mat4x4_f16 : mat4x4<f16> = ub.mat4x4_f16;
+ var arr2_vec3_f32 : array<vec3<f32>, 2> = ub.arr2_vec3_f32;
+ var arr2_mat4x2_f16 : array<mat4x2<f16>, 2> = ub.arr2_mat4x2_f16;
+}
+)";
+
+ auto* expect = R"(
+enable f16;
+
+struct UB {
+ scalar_f32 : f32,
+ scalar_i32 : i32,
+ scalar_u32 : u32,
+ scalar_f16 : f16,
+ vec2_f32 : vec2<f32>,
+ vec2_i32 : vec2<i32>,
+ vec2_u32 : vec2<u32>,
+ vec2_f16 : vec2<f16>,
+ vec3_f32 : vec3<f32>,
+ vec3_i32 : vec3<i32>,
+ vec3_u32 : vec3<u32>,
+ vec3_f16 : vec3<f16>,
+ vec4_f32 : vec4<f32>,
+ vec4_i32 : vec4<i32>,
+ vec4_u32 : vec4<u32>,
+ vec4_f16 : vec4<f16>,
+ mat2x2_f32 : mat2x2<f32>,
+ mat2x3_f32 : mat2x3<f32>,
+ mat2x4_f32 : mat2x4<f32>,
+ mat3x2_f32 : mat3x2<f32>,
+ mat3x3_f32 : mat3x3<f32>,
+ mat3x4_f32 : mat3x4<f32>,
+ mat4x2_f32 : mat4x2<f32>,
+ mat4x3_f32 : mat4x3<f32>,
+ mat4x4_f32 : mat4x4<f32>,
+ mat2x2_f16 : mat2x2<f16>,
+ mat2x3_f16 : mat2x3<f16>,
+ mat2x4_f16 : mat2x4<f16>,
+ mat3x2_f16 : mat3x2<f16>,
+ mat3x3_f16 : mat3x3<f16>,
+ mat3x4_f16 : mat3x4<f16>,
+ mat4x2_f16 : mat4x2<f16>,
+ mat4x3_f16 : mat4x3<f16>,
+ mat4x4_f16 : mat4x4<f16>,
+ arr2_vec3_f32 : array<vec3<f32>, 2>,
+ arr2_mat4x2_f16 : array<mat4x2<f16>, 2>,
+}
+
+@group(0) @binding(0) var<uniform> ub : UB;
+
+@internal(intrinsic_load_uniform_f32) @internal(disable_validation__function_has_no_body)
+fn ub_load(offset : u32) -> f32
+
+@internal(intrinsic_load_uniform_i32) @internal(disable_validation__function_has_no_body)
+fn ub_load_1(offset : u32) -> i32
+
+@internal(intrinsic_load_uniform_u32) @internal(disable_validation__function_has_no_body)
+fn ub_load_2(offset : u32) -> u32
+
+@internal(intrinsic_load_uniform_f16) @internal(disable_validation__function_has_no_body)
+fn ub_load_3(offset : u32) -> f16
+
+@internal(intrinsic_load_uniform_vec2_f32) @internal(disable_validation__function_has_no_body)
+fn ub_load_4(offset : u32) -> vec2<f32>
+
+@internal(intrinsic_load_uniform_vec2_i32) @internal(disable_validation__function_has_no_body)
+fn ub_load_5(offset : u32) -> vec2<i32>
+
+@internal(intrinsic_load_uniform_vec2_u32) @internal(disable_validation__function_has_no_body)
+fn ub_load_6(offset : u32) -> vec2<u32>
+
+@internal(intrinsic_load_uniform_vec2_f16) @internal(disable_validation__function_has_no_body)
+fn ub_load_7(offset : u32) -> vec2<f16>
+
+@internal(intrinsic_load_uniform_vec3_f32) @internal(disable_validation__function_has_no_body)
+fn ub_load_8(offset : u32) -> vec3<f32>
+
+@internal(intrinsic_load_uniform_vec3_i32) @internal(disable_validation__function_has_no_body)
+fn ub_load_9(offset : u32) -> vec3<i32>
+
+@internal(intrinsic_load_uniform_vec3_u32) @internal(disable_validation__function_has_no_body)
+fn ub_load_10(offset : u32) -> vec3<u32>
+
+@internal(intrinsic_load_uniform_vec3_f16) @internal(disable_validation__function_has_no_body)
+fn ub_load_11(offset : u32) -> vec3<f16>
+
+@internal(intrinsic_load_uniform_vec4_f32) @internal(disable_validation__function_has_no_body)
+fn ub_load_12(offset : u32) -> vec4<f32>
+
+@internal(intrinsic_load_uniform_vec4_i32) @internal(disable_validation__function_has_no_body)
+fn ub_load_13(offset : u32) -> vec4<i32>
+
+@internal(intrinsic_load_uniform_vec4_u32) @internal(disable_validation__function_has_no_body)
+fn ub_load_14(offset : u32) -> vec4<u32>
+
+@internal(intrinsic_load_uniform_vec4_f16) @internal(disable_validation__function_has_no_body)
+fn ub_load_15(offset : u32) -> vec4<f16>
+
+fn ub_load_16(offset : u32) -> mat2x2<f32> {
+ return mat2x2<f32>(ub_load_4((offset + 0u)), ub_load_4((offset + 8u)));
+}
+
+fn ub_load_17(offset : u32) -> mat2x3<f32> {
+ return mat2x3<f32>(ub_load_8((offset + 0u)), ub_load_8((offset + 16u)));
+}
+
+fn ub_load_18(offset : u32) -> mat2x4<f32> {
+ return mat2x4<f32>(ub_load_12((offset + 0u)), ub_load_12((offset + 16u)));
+}
+
+fn ub_load_19(offset : u32) -> mat3x2<f32> {
+ return mat3x2<f32>(ub_load_4((offset + 0u)), ub_load_4((offset + 8u)), ub_load_4((offset + 16u)));
+}
+
+fn ub_load_20(offset : u32) -> mat3x3<f32> {
+ return mat3x3<f32>(ub_load_8((offset + 0u)), ub_load_8((offset + 16u)), ub_load_8((offset + 32u)));
+}
+
+fn ub_load_21(offset : u32) -> mat3x4<f32> {
+ return mat3x4<f32>(ub_load_12((offset + 0u)), ub_load_12((offset + 16u)), ub_load_12((offset + 32u)));
+}
+
+fn ub_load_22(offset : u32) -> mat4x2<f32> {
+ return mat4x2<f32>(ub_load_4((offset + 0u)), ub_load_4((offset + 8u)), ub_load_4((offset + 16u)), ub_load_4((offset + 24u)));
+}
+
+fn ub_load_23(offset : u32) -> mat4x3<f32> {
+ return mat4x3<f32>(ub_load_8((offset + 0u)), ub_load_8((offset + 16u)), ub_load_8((offset + 32u)), ub_load_8((offset + 48u)));
+}
+
+fn ub_load_24(offset : u32) -> mat4x4<f32> {
+ return mat4x4<f32>(ub_load_12((offset + 0u)), ub_load_12((offset + 16u)), ub_load_12((offset + 32u)), ub_load_12((offset + 48u)));
+}
+
+fn ub_load_25(offset : u32) -> mat2x2<f16> {
+ return mat2x2<f16>(ub_load_7((offset + 0u)), ub_load_7((offset + 4u)));
+}
+
+fn ub_load_26(offset : u32) -> mat2x3<f16> {
+ return mat2x3<f16>(ub_load_11((offset + 0u)), ub_load_11((offset + 8u)));
+}
+
+fn ub_load_27(offset : u32) -> mat2x4<f16> {
+ return mat2x4<f16>(ub_load_15((offset + 0u)), ub_load_15((offset + 8u)));
+}
+
+fn ub_load_28(offset : u32) -> mat3x2<f16> {
+ return mat3x2<f16>(ub_load_7((offset + 0u)), ub_load_7((offset + 4u)), ub_load_7((offset + 8u)));
+}
+
+fn ub_load_29(offset : u32) -> mat3x3<f16> {
+ return mat3x3<f16>(ub_load_11((offset + 0u)), ub_load_11((offset + 8u)), ub_load_11((offset + 16u)));
+}
+
+fn ub_load_30(offset : u32) -> mat3x4<f16> {
+ return mat3x4<f16>(ub_load_15((offset + 0u)), ub_load_15((offset + 8u)), ub_load_15((offset + 16u)));
+}
+
+fn ub_load_31(offset : u32) -> mat4x2<f16> {
+ return mat4x2<f16>(ub_load_7((offset + 0u)), ub_load_7((offset + 4u)), ub_load_7((offset + 8u)), ub_load_7((offset + 12u)));
+}
+
+fn ub_load_32(offset : u32) -> mat4x3<f16> {
+ return mat4x3<f16>(ub_load_11((offset + 0u)), ub_load_11((offset + 8u)), ub_load_11((offset + 16u)), ub_load_11((offset + 24u)));
+}
+
+fn ub_load_33(offset : u32) -> mat4x4<f16> {
+ return mat4x4<f16>(ub_load_15((offset + 0u)), ub_load_15((offset + 8u)), ub_load_15((offset + 16u)), ub_load_15((offset + 24u)));
+}
+
+fn ub_load_34(offset : u32) -> array<vec3<f32>, 2u> {
+ var arr : array<vec3<f32>, 2u>;
+ for(var i = 0u; (i < 2u); i = (i + 1u)) {
+ arr[i] = ub_load_8((offset + (i * 16u)));
+ }
+ return arr;
+}
+
+fn ub_load_35(offset : u32) -> array<mat4x2<f16>, 2u> {
+ var arr_1 : array<mat4x2<f16>, 2u>;
+ for(var i_1 = 0u; (i_1 < 2u); i_1 = (i_1 + 1u)) {
+ arr_1[i_1] = ub_load_31((offset + (i_1 * 16u)));
+ }
+ return arr_1;
+}
+
+@compute @workgroup_size(1)
+fn main() {
+ var scalar_f32 : f32 = ub_load(0u);
+ var scalar_i32 : i32 = ub_load_1(4u);
+ var scalar_u32 : u32 = ub_load_2(8u);
+ var scalar_f16 : f16 = ub_load_3(12u);
+ var vec2_f32 : vec2<f32> = ub_load_4(16u);
+ var vec2_i32 : vec2<i32> = ub_load_5(24u);
+ var vec2_u32 : vec2<u32> = ub_load_6(32u);
+ var vec2_f16 : vec2<f16> = ub_load_7(40u);
+ var vec3_f32 : vec3<f32> = ub_load_8(48u);
+ var vec3_i32 : vec3<i32> = ub_load_9(64u);
+ var vec3_u32 : vec3<u32> = ub_load_10(80u);
+ var vec3_f16 : vec3<f16> = ub_load_11(96u);
+ var vec4_f32 : vec4<f32> = ub_load_12(112u);
+ var vec4_i32 : vec4<i32> = ub_load_13(128u);
+ var vec4_u32 : vec4<u32> = ub_load_14(144u);
+ var vec4_f16 : vec4<f16> = ub_load_15(160u);
+ var mat2x2_f32 : mat2x2<f32> = ub_load_16(168u);
+ var mat2x3_f32 : mat2x3<f32> = ub_load_17(192u);
+ var mat2x4_f32 : mat2x4<f32> = ub_load_18(224u);
+ var mat3x2_f32 : mat3x2<f32> = ub_load_19(256u);
+ var mat3x3_f32 : mat3x3<f32> = ub_load_20(288u);
+ var mat3x4_f32 : mat3x4<f32> = ub_load_21(336u);
+ var mat4x2_f32 : mat4x2<f32> = ub_load_22(384u);
+ var mat4x3_f32 : mat4x3<f32> = ub_load_23(416u);
+ var mat4x4_f32 : mat4x4<f32> = ub_load_24(480u);
+ var mat2x2_f16 : mat2x2<f16> = ub_load_25(544u);
+ var mat2x3_f16 : mat2x3<f16> = ub_load_26(552u);
+ var mat2x4_f16 : mat2x4<f16> = ub_load_27(568u);
+ var mat3x2_f16 : mat3x2<f16> = ub_load_28(584u);
+ var mat3x3_f16 : mat3x3<f16> = ub_load_29(600u);
+ var mat3x4_f16 : mat3x4<f16> = ub_load_30(624u);
+ var mat4x2_f16 : mat4x2<f16> = ub_load_31(648u);
+ var mat4x3_f16 : mat4x3<f16> = ub_load_32(664u);
+ var mat4x4_f16 : mat4x4<f16> = ub_load_33(696u);
+ var arr2_vec3_f32 : array<vec3<f32>, 2> = ub_load_34(736u);
+ var arr2_mat4x2_f16 : array<mat4x2<f16>, 2> = ub_load_35(768u);
+}
+)";
+
+ auto got = Run<DecomposeMemoryAccess>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeMemoryAccessTest, UB_BasicLoad_OutOfOrder) {  // main() precedes ub and UB
+ auto* src = R"(
+enable f16;
+
+@compute @workgroup_size(1)
+fn main() {
+ var scalar_f32 : f32 = ub.scalar_f32;
+ var scalar_i32 : i32 = ub.scalar_i32;
+ var scalar_u32 : u32 = ub.scalar_u32;
+ var scalar_f16 : f16 = ub.scalar_f16;
+ var vec2_f32 : vec2<f32> = ub.vec2_f32;
+ var vec2_i32 : vec2<i32> = ub.vec2_i32;
+ var vec2_u32 : vec2<u32> = ub.vec2_u32;
+ var vec2_f16 : vec2<f16> = ub.vec2_f16;
+ var vec3_f32 : vec3<f32> = ub.vec3_f32;
+ var vec3_i32 : vec3<i32> = ub.vec3_i32;
+ var vec3_u32 : vec3<u32> = ub.vec3_u32;
+ var vec3_f16 : vec3<f16> = ub.vec3_f16;
+ var vec4_f32 : vec4<f32> = ub.vec4_f32;
+ var vec4_i32 : vec4<i32> = ub.vec4_i32;
+ var vec4_u32 : vec4<u32> = ub.vec4_u32;
+ var vec4_f16 : vec4<f16> = ub.vec4_f16;
+ var mat2x2_f32 : mat2x2<f32> = ub.mat2x2_f32;
+ var mat2x3_f32 : mat2x3<f32> = ub.mat2x3_f32;
+ var mat2x4_f32 : mat2x4<f32> = ub.mat2x4_f32;
+ var mat3x2_f32 : mat3x2<f32> = ub.mat3x2_f32;
+ var mat3x3_f32 : mat3x3<f32> = ub.mat3x3_f32;
+ var mat3x4_f32 : mat3x4<f32> = ub.mat3x4_f32;
+ var mat4x2_f32 : mat4x2<f32> = ub.mat4x2_f32;
+ var mat4x3_f32 : mat4x3<f32> = ub.mat4x3_f32;
+ var mat4x4_f32 : mat4x4<f32> = ub.mat4x4_f32;
+ var mat2x2_f16 : mat2x2<f16> = ub.mat2x2_f16;
+ var mat2x3_f16 : mat2x3<f16> = ub.mat2x3_f16;
+ var mat2x4_f16 : mat2x4<f16> = ub.mat2x4_f16;
+ var mat3x2_f16 : mat3x2<f16> = ub.mat3x2_f16;
+ var mat3x3_f16 : mat3x3<f16> = ub.mat3x3_f16;
+ var mat3x4_f16 : mat3x4<f16> = ub.mat3x4_f16;
+ var mat4x2_f16 : mat4x2<f16> = ub.mat4x2_f16;
+ var mat4x3_f16 : mat4x3<f16> = ub.mat4x3_f16;
+ var mat4x4_f16 : mat4x4<f16> = ub.mat4x4_f16;
+ var arr2_vec3_f32 : array<vec3<f32>, 2> = ub.arr2_vec3_f32;
+ var arr2_mat4x2_f16 : array<mat4x2<f16>, 2> = ub.arr2_mat4x2_f16;
+}
+
+@group(0) @binding(0) var<uniform> ub : UB;
+
+struct UB {
+ scalar_f32 : f32,
+ scalar_i32 : i32,
+ scalar_u32 : u32,
+ scalar_f16 : f16,
+ vec2_f32 : vec2<f32>,
+ vec2_i32 : vec2<i32>,
+ vec2_u32 : vec2<u32>,
+ vec2_f16 : vec2<f16>,
+ vec3_f32 : vec3<f32>,
+ vec3_i32 : vec3<i32>,
+ vec3_u32 : vec3<u32>,
+ vec3_f16 : vec3<f16>,
+ vec4_f32 : vec4<f32>,
+ vec4_i32 : vec4<i32>,
+ vec4_u32 : vec4<u32>,
+ vec4_f16 : vec4<f16>,
+ mat2x2_f32 : mat2x2<f32>,
+ mat2x3_f32 : mat2x3<f32>,
+ mat2x4_f32 : mat2x4<f32>,
+ mat3x2_f32 : mat3x2<f32>,
+ mat3x3_f32 : mat3x3<f32>,
+ mat3x4_f32 : mat3x4<f32>,
+ mat4x2_f32 : mat4x2<f32>,
+ mat4x3_f32 : mat4x3<f32>,
+ mat4x4_f32 : mat4x4<f32>,
+ mat2x2_f16 : mat2x2<f16>,
+ mat2x3_f16 : mat2x3<f16>,
+ mat2x4_f16 : mat2x4<f16>,
+ mat3x2_f16 : mat3x2<f16>,
+ mat3x3_f16 : mat3x3<f16>,
+ mat3x4_f16 : mat3x4<f16>,
+ mat4x2_f16 : mat4x2<f16>,
+ mat4x3_f16 : mat4x3<f16>,
+ mat4x4_f16 : mat4x4<f16>,
+ arr2_vec3_f32 : array<vec3<f32>, 2>,
+ arr2_mat4x2_f16 : array<mat4x2<f16>, 2>,
+};
+)";  // Input: main() reads every member of ub before ub and struct UB are declared.
+
+ auto* expect = R"(
+enable f16;
+
+@internal(intrinsic_load_uniform_f32) @internal(disable_validation__function_has_no_body)
+fn ub_load(offset : u32) -> f32
+
+@internal(intrinsic_load_uniform_i32) @internal(disable_validation__function_has_no_body)
+fn ub_load_1(offset : u32) -> i32
+
+@internal(intrinsic_load_uniform_u32) @internal(disable_validation__function_has_no_body)
+fn ub_load_2(offset : u32) -> u32
+
+@internal(intrinsic_load_uniform_f16) @internal(disable_validation__function_has_no_body)
+fn ub_load_3(offset : u32) -> f16
+
+@internal(intrinsic_load_uniform_vec2_f32) @internal(disable_validation__function_has_no_body)
+fn ub_load_4(offset : u32) -> vec2<f32>
+
+@internal(intrinsic_load_uniform_vec2_i32) @internal(disable_validation__function_has_no_body)
+fn ub_load_5(offset : u32) -> vec2<i32>
+
+@internal(intrinsic_load_uniform_vec2_u32) @internal(disable_validation__function_has_no_body)
+fn ub_load_6(offset : u32) -> vec2<u32>
+
+@internal(intrinsic_load_uniform_vec2_f16) @internal(disable_validation__function_has_no_body)
+fn ub_load_7(offset : u32) -> vec2<f16>
+
+@internal(intrinsic_load_uniform_vec3_f32) @internal(disable_validation__function_has_no_body)
+fn ub_load_8(offset : u32) -> vec3<f32>
+
+@internal(intrinsic_load_uniform_vec3_i32) @internal(disable_validation__function_has_no_body)
+fn ub_load_9(offset : u32) -> vec3<i32>
+
+@internal(intrinsic_load_uniform_vec3_u32) @internal(disable_validation__function_has_no_body)
+fn ub_load_10(offset : u32) -> vec3<u32>
+
+@internal(intrinsic_load_uniform_vec3_f16) @internal(disable_validation__function_has_no_body)
+fn ub_load_11(offset : u32) -> vec3<f16>
+
+@internal(intrinsic_load_uniform_vec4_f32) @internal(disable_validation__function_has_no_body)
+fn ub_load_12(offset : u32) -> vec4<f32>
+
+@internal(intrinsic_load_uniform_vec4_i32) @internal(disable_validation__function_has_no_body)
+fn ub_load_13(offset : u32) -> vec4<i32>
+
+@internal(intrinsic_load_uniform_vec4_u32) @internal(disable_validation__function_has_no_body)
+fn ub_load_14(offset : u32) -> vec4<u32>
+
+@internal(intrinsic_load_uniform_vec4_f16) @internal(disable_validation__function_has_no_body)
+fn ub_load_15(offset : u32) -> vec4<f16>
+
+fn ub_load_16(offset : u32) -> mat2x2<f32> {
+ return mat2x2<f32>(ub_load_4((offset + 0u)), ub_load_4((offset + 8u)));
+}
+
+fn ub_load_17(offset : u32) -> mat2x3<f32> {
+ return mat2x3<f32>(ub_load_8((offset + 0u)), ub_load_8((offset + 16u)));
+}
+
+fn ub_load_18(offset : u32) -> mat2x4<f32> {
+ return mat2x4<f32>(ub_load_12((offset + 0u)), ub_load_12((offset + 16u)));
+}
+
+fn ub_load_19(offset : u32) -> mat3x2<f32> {
+ return mat3x2<f32>(ub_load_4((offset + 0u)), ub_load_4((offset + 8u)), ub_load_4((offset + 16u)));
+}
+
+fn ub_load_20(offset : u32) -> mat3x3<f32> {
+ return mat3x3<f32>(ub_load_8((offset + 0u)), ub_load_8((offset + 16u)), ub_load_8((offset + 32u)));
+}
+
+fn ub_load_21(offset : u32) -> mat3x4<f32> {
+ return mat3x4<f32>(ub_load_12((offset + 0u)), ub_load_12((offset + 16u)), ub_load_12((offset + 32u)));
+}
+
+fn ub_load_22(offset : u32) -> mat4x2<f32> {
+ return mat4x2<f32>(ub_load_4((offset + 0u)), ub_load_4((offset + 8u)), ub_load_4((offset + 16u)), ub_load_4((offset + 24u)));
+}
+
+fn ub_load_23(offset : u32) -> mat4x3<f32> {
+ return mat4x3<f32>(ub_load_8((offset + 0u)), ub_load_8((offset + 16u)), ub_load_8((offset + 32u)), ub_load_8((offset + 48u)));
+}
+
+fn ub_load_24(offset : u32) -> mat4x4<f32> {
+ return mat4x4<f32>(ub_load_12((offset + 0u)), ub_load_12((offset + 16u)), ub_load_12((offset + 32u)), ub_load_12((offset + 48u)));
+}
+
+fn ub_load_25(offset : u32) -> mat2x2<f16> {
+ return mat2x2<f16>(ub_load_7((offset + 0u)), ub_load_7((offset + 4u)));
+}
+
+fn ub_load_26(offset : u32) -> mat2x3<f16> {
+ return mat2x3<f16>(ub_load_11((offset + 0u)), ub_load_11((offset + 8u)));
+}
+
+fn ub_load_27(offset : u32) -> mat2x4<f16> {
+ return mat2x4<f16>(ub_load_15((offset + 0u)), ub_load_15((offset + 8u)));
+}
+
+fn ub_load_28(offset : u32) -> mat3x2<f16> {
+ return mat3x2<f16>(ub_load_7((offset + 0u)), ub_load_7((offset + 4u)), ub_load_7((offset + 8u)));
+}
+
+fn ub_load_29(offset : u32) -> mat3x3<f16> {
+ return mat3x3<f16>(ub_load_11((offset + 0u)), ub_load_11((offset + 8u)), ub_load_11((offset + 16u)));
+}
+
+fn ub_load_30(offset : u32) -> mat3x4<f16> {
+ return mat3x4<f16>(ub_load_15((offset + 0u)), ub_load_15((offset + 8u)), ub_load_15((offset + 16u)));
+}
+
+fn ub_load_31(offset : u32) -> mat4x2<f16> {
+ return mat4x2<f16>(ub_load_7((offset + 0u)), ub_load_7((offset + 4u)), ub_load_7((offset + 8u)), ub_load_7((offset + 12u)));
+}
+
+fn ub_load_32(offset : u32) -> mat4x3<f16> {
+ return mat4x3<f16>(ub_load_11((offset + 0u)), ub_load_11((offset + 8u)), ub_load_11((offset + 16u)), ub_load_11((offset + 24u)));
+}
+
+fn ub_load_33(offset : u32) -> mat4x4<f16> {
+ return mat4x4<f16>(ub_load_15((offset + 0u)), ub_load_15((offset + 8u)), ub_load_15((offset + 16u)), ub_load_15((offset + 24u)));
+}
+
+fn ub_load_34(offset : u32) -> array<vec3<f32>, 2u> {
+ var arr : array<vec3<f32>, 2u>;
+ for(var i = 0u; (i < 2u); i = (i + 1u)) {
+ arr[i] = ub_load_8((offset + (i * 16u)));
+ }
+ return arr;
+}
+
+fn ub_load_35(offset : u32) -> array<mat4x2<f16>, 2u> {
+ var arr_1 : array<mat4x2<f16>, 2u>;
+ for(var i_1 = 0u; (i_1 < 2u); i_1 = (i_1 + 1u)) {
+ arr_1[i_1] = ub_load_31((offset + (i_1 * 16u)));
+ }
+ return arr_1;
+}
+
+@compute @workgroup_size(1)
+fn main() {
+ var scalar_f32 : f32 = ub_load(0u);
+ var scalar_i32 : i32 = ub_load_1(4u);
+ var scalar_u32 : u32 = ub_load_2(8u);
+ var scalar_f16 : f16 = ub_load_3(12u);
+ var vec2_f32 : vec2<f32> = ub_load_4(16u);
+ var vec2_i32 : vec2<i32> = ub_load_5(24u);
+ var vec2_u32 : vec2<u32> = ub_load_6(32u);
+ var vec2_f16 : vec2<f16> = ub_load_7(40u);
+ var vec3_f32 : vec3<f32> = ub_load_8(48u);
+ var vec3_i32 : vec3<i32> = ub_load_9(64u);
+ var vec3_u32 : vec3<u32> = ub_load_10(80u);
+ var vec3_f16 : vec3<f16> = ub_load_11(96u);
+ var vec4_f32 : vec4<f32> = ub_load_12(112u);
+ var vec4_i32 : vec4<i32> = ub_load_13(128u);
+ var vec4_u32 : vec4<u32> = ub_load_14(144u);
+ var vec4_f16 : vec4<f16> = ub_load_15(160u);
+ var mat2x2_f32 : mat2x2<f32> = ub_load_16(168u);
+ var mat2x3_f32 : mat2x3<f32> = ub_load_17(192u);
+ var mat2x4_f32 : mat2x4<f32> = ub_load_18(224u);
+ var mat3x2_f32 : mat3x2<f32> = ub_load_19(256u);
+ var mat3x3_f32 : mat3x3<f32> = ub_load_20(288u);
+ var mat3x4_f32 : mat3x4<f32> = ub_load_21(336u);
+ var mat4x2_f32 : mat4x2<f32> = ub_load_22(384u);
+ var mat4x3_f32 : mat4x3<f32> = ub_load_23(416u);
+ var mat4x4_f32 : mat4x4<f32> = ub_load_24(480u);
+ var mat2x2_f16 : mat2x2<f16> = ub_load_25(544u);
+ var mat2x3_f16 : mat2x3<f16> = ub_load_26(552u);
+ var mat2x4_f16 : mat2x4<f16> = ub_load_27(568u);
+ var mat3x2_f16 : mat3x2<f16> = ub_load_28(584u);
+ var mat3x3_f16 : mat3x3<f16> = ub_load_29(600u);
+ var mat3x4_f16 : mat3x4<f16> = ub_load_30(624u);
+ var mat4x2_f16 : mat4x2<f16> = ub_load_31(648u);
+ var mat4x3_f16 : mat4x3<f16> = ub_load_32(664u);
+ var mat4x4_f16 : mat4x4<f16> = ub_load_33(696u);
+ var arr2_vec3_f32 : array<vec3<f32>, 2> = ub_load_34(736u);
+ var arr2_mat4x2_f16 : array<mat4x2<f16>, 2> = ub_load_35(768u);
+}
+
+@group(0) @binding(0) var<uniform> ub : UB;
+
+struct UB {
+ scalar_f32 : f32,
+ scalar_i32 : i32,
+ scalar_u32 : u32,
+ scalar_f16 : f16,
+ vec2_f32 : vec2<f32>,
+ vec2_i32 : vec2<i32>,
+ vec2_u32 : vec2<u32>,
+ vec2_f16 : vec2<f16>,
+ vec3_f32 : vec3<f32>,
+ vec3_i32 : vec3<i32>,
+ vec3_u32 : vec3<u32>,
+ vec3_f16 : vec3<f16>,
+ vec4_f32 : vec4<f32>,
+ vec4_i32 : vec4<i32>,
+ vec4_u32 : vec4<u32>,
+ vec4_f16 : vec4<f16>,
+ mat2x2_f32 : mat2x2<f32>,
+ mat2x3_f32 : mat2x3<f32>,
+ mat2x4_f32 : mat2x4<f32>,
+ mat3x2_f32 : mat3x2<f32>,
+ mat3x3_f32 : mat3x3<f32>,
+ mat3x4_f32 : mat3x4<f32>,
+ mat4x2_f32 : mat4x2<f32>,
+ mat4x3_f32 : mat4x3<f32>,
+ mat4x4_f32 : mat4x4<f32>,
+ mat2x2_f16 : mat2x2<f16>,
+ mat2x3_f16 : mat2x3<f16>,
+ mat2x4_f16 : mat2x4<f16>,
+ mat3x2_f16 : mat3x2<f16>,
+ mat3x3_f16 : mat3x3<f16>,
+ mat3x4_f16 : mat3x4<f16>,
+ mat4x2_f16 : mat4x2<f16>,
+ mat4x3_f16 : mat4x3<f16>,
+ mat4x4_f16 : mat4x4<f16>,
+ arr2_vec3_f32 : array<vec3<f32>, 2>,
+ arr2_mat4x2_f16 : array<mat4x2<f16>, 2>,
+}
+)";  // Expected: ub_load* helpers come first; each read in main() is a ub_load*(offset) call.
+
+ auto got = Run<DecomposeMemoryAccess>(src);  // Run DecomposeMemoryAccess over the input.
+
+ EXPECT_EQ(expect, str(got));  // Compare the stringified result with the expected text.
+}
+
+TEST_F(DecomposeMemoryAccessTest, SB_BasicStore) {  // Stores to every member of sb.
+ auto* src = R"(
+enable f16;
+
+struct SB {
+ scalar_f32 : f32,
+ scalar_i32 : i32,
+ scalar_u32 : u32,
+ scalar_f16 : f16,
+ vec2_f32 : vec2<f32>,
+ vec2_i32 : vec2<i32>,
+ vec2_u32 : vec2<u32>,
+ vec2_f16 : vec2<f16>,
+ vec3_f32 : vec3<f32>,
+ vec3_i32 : vec3<i32>,
+ vec3_u32 : vec3<u32>,
+ vec3_f16 : vec3<f16>,
+ vec4_f32 : vec4<f32>,
+ vec4_i32 : vec4<i32>,
+ vec4_u32 : vec4<u32>,
+ vec4_f16 : vec4<f16>,
+ mat2x2_f32 : mat2x2<f32>,
+ mat2x3_f32 : mat2x3<f32>,
+ mat2x4_f32 : mat2x4<f32>,
+ mat3x2_f32 : mat3x2<f32>,
+ mat3x3_f32 : mat3x3<f32>,
+ mat3x4_f32 : mat3x4<f32>,
+ mat4x2_f32 : mat4x2<f32>,
+ mat4x3_f32 : mat4x3<f32>,
+ mat4x4_f32 : mat4x4<f32>,
+ mat2x2_f16 : mat2x2<f16>,
+ mat2x3_f16 : mat2x3<f16>,
+ mat2x4_f16 : mat2x4<f16>,
+ mat3x2_f16 : mat3x2<f16>,
+ mat3x3_f16 : mat3x3<f16>,
+ mat3x4_f16 : mat3x4<f16>,
+ mat4x2_f16 : mat4x2<f16>,
+ mat4x3_f16 : mat4x3<f16>,
+ mat4x4_f16 : mat4x4<f16>,
+ arr2_vec3_f32 : array<vec3<f32>, 2>,
+ arr2_mat4x2_f16 : array<mat4x2<f16>, 2>,
+};
+
+@group(0) @binding(0) var<storage, read_write> sb : SB;
+
+@compute @workgroup_size(1)
+fn main() {
+ sb.scalar_f32 = f32();
+ sb.scalar_i32 = i32();
+ sb.scalar_u32 = u32();
+ sb.scalar_f16 = f16();
+ sb.vec2_f32 = vec2<f32>();
+ sb.vec2_i32 = vec2<i32>();
+ sb.vec2_u32 = vec2<u32>();
+ sb.vec2_f16 = vec2<f16>();
+ sb.vec3_f32 = vec3<f32>();
+ sb.vec3_i32 = vec3<i32>();
+ sb.vec3_u32 = vec3<u32>();
+ sb.vec3_f16 = vec3<f16>();
+ sb.vec4_f32 = vec4<f32>();
+ sb.vec4_i32 = vec4<i32>();
+ sb.vec4_u32 = vec4<u32>();
+ sb.vec4_f16 = vec4<f16>();
+ sb.mat2x2_f32 = mat2x2<f32>();
+ sb.mat2x3_f32 = mat2x3<f32>();
+ sb.mat2x4_f32 = mat2x4<f32>();
+ sb.mat3x2_f32 = mat3x2<f32>();
+ sb.mat3x3_f32 = mat3x3<f32>();
+ sb.mat3x4_f32 = mat3x4<f32>();
+ sb.mat4x2_f32 = mat4x2<f32>();
+ sb.mat4x3_f32 = mat4x3<f32>();
+ sb.mat4x4_f32 = mat4x4<f32>();
+ sb.mat2x2_f16 = mat2x2<f16>();
+ sb.mat2x3_f16 = mat2x3<f16>();
+ sb.mat2x4_f16 = mat2x4<f16>();
+ sb.mat3x2_f16 = mat3x2<f16>();
+ sb.mat3x3_f16 = mat3x3<f16>();
+ sb.mat3x4_f16 = mat3x4<f16>();
+ sb.mat4x2_f16 = mat4x2<f16>();
+ sb.mat4x3_f16 = mat4x3<f16>();
+ sb.mat4x4_f16 = mat4x4<f16>();
+ sb.arr2_vec3_f32 = array<vec3<f32>, 2>();
+ sb.arr2_mat4x2_f16 = array<mat4x2<f16>, 2>();
+}
+)";  // Input: main() assigns a T() value to every member of the read_write storage var sb.
+
+ auto* expect = R"(
+enable f16;
+
+struct SB {
+ scalar_f32 : f32,
+ scalar_i32 : i32,
+ scalar_u32 : u32,
+ scalar_f16 : f16,
+ vec2_f32 : vec2<f32>,
+ vec2_i32 : vec2<i32>,
+ vec2_u32 : vec2<u32>,
+ vec2_f16 : vec2<f16>,
+ vec3_f32 : vec3<f32>,
+ vec3_i32 : vec3<i32>,
+ vec3_u32 : vec3<u32>,
+ vec3_f16 : vec3<f16>,
+ vec4_f32 : vec4<f32>,
+ vec4_i32 : vec4<i32>,
+ vec4_u32 : vec4<u32>,
+ vec4_f16 : vec4<f16>,
+ mat2x2_f32 : mat2x2<f32>,
+ mat2x3_f32 : mat2x3<f32>,
+ mat2x4_f32 : mat2x4<f32>,
+ mat3x2_f32 : mat3x2<f32>,
+ mat3x3_f32 : mat3x3<f32>,
+ mat3x4_f32 : mat3x4<f32>,
+ mat4x2_f32 : mat4x2<f32>,
+ mat4x3_f32 : mat4x3<f32>,
+ mat4x4_f32 : mat4x4<f32>,
+ mat2x2_f16 : mat2x2<f16>,
+ mat2x3_f16 : mat2x3<f16>,
+ mat2x4_f16 : mat2x4<f16>,
+ mat3x2_f16 : mat3x2<f16>,
+ mat3x3_f16 : mat3x3<f16>,
+ mat3x4_f16 : mat3x4<f16>,
+ mat4x2_f16 : mat4x2<f16>,
+ mat4x3_f16 : mat4x3<f16>,
+ mat4x4_f16 : mat4x4<f16>,
+ arr2_vec3_f32 : array<vec3<f32>, 2>,
+ arr2_mat4x2_f16 : array<mat4x2<f16>, 2>,
+}
+
+@group(0) @binding(0) var<storage, read_write> sb : SB;
+
+@internal(intrinsic_store_storage_f32) @internal(disable_validation__function_has_no_body)
+fn sb_store(offset : u32, value : f32)
+
+@internal(intrinsic_store_storage_i32) @internal(disable_validation__function_has_no_body)
+fn sb_store_1(offset : u32, value : i32)
+
+@internal(intrinsic_store_storage_u32) @internal(disable_validation__function_has_no_body)
+fn sb_store_2(offset : u32, value : u32)
+
+@internal(intrinsic_store_storage_f16) @internal(disable_validation__function_has_no_body)
+fn sb_store_3(offset : u32, value : f16)
+
+@internal(intrinsic_store_storage_vec2_f32) @internal(disable_validation__function_has_no_body)
+fn sb_store_4(offset : u32, value : vec2<f32>)
+
+@internal(intrinsic_store_storage_vec2_i32) @internal(disable_validation__function_has_no_body)
+fn sb_store_5(offset : u32, value : vec2<i32>)
+
+@internal(intrinsic_store_storage_vec2_u32) @internal(disable_validation__function_has_no_body)
+fn sb_store_6(offset : u32, value : vec2<u32>)
+
+@internal(intrinsic_store_storage_vec2_f16) @internal(disable_validation__function_has_no_body)
+fn sb_store_7(offset : u32, value : vec2<f16>)
+
+@internal(intrinsic_store_storage_vec3_f32) @internal(disable_validation__function_has_no_body)
+fn sb_store_8(offset : u32, value : vec3<f32>)
+
+@internal(intrinsic_store_storage_vec3_i32) @internal(disable_validation__function_has_no_body)
+fn sb_store_9(offset : u32, value : vec3<i32>)
+
+@internal(intrinsic_store_storage_vec3_u32) @internal(disable_validation__function_has_no_body)
+fn sb_store_10(offset : u32, value : vec3<u32>)
+
+@internal(intrinsic_store_storage_vec3_f16) @internal(disable_validation__function_has_no_body)
+fn sb_store_11(offset : u32, value : vec3<f16>)
+
+@internal(intrinsic_store_storage_vec4_f32) @internal(disable_validation__function_has_no_body)
+fn sb_store_12(offset : u32, value : vec4<f32>)
+
+@internal(intrinsic_store_storage_vec4_i32) @internal(disable_validation__function_has_no_body)
+fn sb_store_13(offset : u32, value : vec4<i32>)
+
+@internal(intrinsic_store_storage_vec4_u32) @internal(disable_validation__function_has_no_body)
+fn sb_store_14(offset : u32, value : vec4<u32>)
+
+@internal(intrinsic_store_storage_vec4_f16) @internal(disable_validation__function_has_no_body)
+fn sb_store_15(offset : u32, value : vec4<f16>)
+
+fn sb_store_16(offset : u32, value : mat2x2<f32>) {
+ sb_store_4((offset + 0u), value[0u]);
+ sb_store_4((offset + 8u), value[1u]);
+}
+
+fn sb_store_17(offset : u32, value : mat2x3<f32>) {
+ sb_store_8((offset + 0u), value[0u]);
+ sb_store_8((offset + 16u), value[1u]);
+}
+
+fn sb_store_18(offset : u32, value : mat2x4<f32>) {
+ sb_store_12((offset + 0u), value[0u]);
+ sb_store_12((offset + 16u), value[1u]);
+}
+
+fn sb_store_19(offset : u32, value : mat3x2<f32>) {
+ sb_store_4((offset + 0u), value[0u]);
+ sb_store_4((offset + 8u), value[1u]);
+ sb_store_4((offset + 16u), value[2u]);
+}
+
+fn sb_store_20(offset : u32, value : mat3x3<f32>) {
+ sb_store_8((offset + 0u), value[0u]);
+ sb_store_8((offset + 16u), value[1u]);
+ sb_store_8((offset + 32u), value[2u]);
+}
+
+fn sb_store_21(offset : u32, value : mat3x4<f32>) {
+ sb_store_12((offset + 0u), value[0u]);
+ sb_store_12((offset + 16u), value[1u]);
+ sb_store_12((offset + 32u), value[2u]);
+}
+
+fn sb_store_22(offset : u32, value : mat4x2<f32>) {
+ sb_store_4((offset + 0u), value[0u]);
+ sb_store_4((offset + 8u), value[1u]);
+ sb_store_4((offset + 16u), value[2u]);
+ sb_store_4((offset + 24u), value[3u]);
+}
+
+fn sb_store_23(offset : u32, value : mat4x3<f32>) {
+ sb_store_8((offset + 0u), value[0u]);
+ sb_store_8((offset + 16u), value[1u]);
+ sb_store_8((offset + 32u), value[2u]);
+ sb_store_8((offset + 48u), value[3u]);
+}
+
+fn sb_store_24(offset : u32, value : mat4x4<f32>) {
+ sb_store_12((offset + 0u), value[0u]);
+ sb_store_12((offset + 16u), value[1u]);
+ sb_store_12((offset + 32u), value[2u]);
+ sb_store_12((offset + 48u), value[3u]);
+}
+
+fn sb_store_25(offset : u32, value : mat2x2<f16>) {
+ sb_store_7((offset + 0u), value[0u]);
+ sb_store_7((offset + 4u), value[1u]);
+}
+
+fn sb_store_26(offset : u32, value : mat2x3<f16>) {
+ sb_store_11((offset + 0u), value[0u]);
+ sb_store_11((offset + 8u), value[1u]);
+}
+
+fn sb_store_27(offset : u32, value : mat2x4<f16>) {
+ sb_store_15((offset + 0u), value[0u]);
+ sb_store_15((offset + 8u), value[1u]);
+}
+
+fn sb_store_28(offset : u32, value : mat3x2<f16>) {
+ sb_store_7((offset + 0u), value[0u]);
+ sb_store_7((offset + 4u), value[1u]);
+ sb_store_7((offset + 8u), value[2u]);
+}
+
+fn sb_store_29(offset : u32, value : mat3x3<f16>) {
+ sb_store_11((offset + 0u), value[0u]);
+ sb_store_11((offset + 8u), value[1u]);
+ sb_store_11((offset + 16u), value[2u]);
+}
+
+fn sb_store_30(offset : u32, value : mat3x4<f16>) {
+ sb_store_15((offset + 0u), value[0u]);
+ sb_store_15((offset + 8u), value[1u]);
+ sb_store_15((offset + 16u), value[2u]);
+}
+
+fn sb_store_31(offset : u32, value : mat4x2<f16>) {
+ sb_store_7((offset + 0u), value[0u]);
+ sb_store_7((offset + 4u), value[1u]);
+ sb_store_7((offset + 8u), value[2u]);
+ sb_store_7((offset + 12u), value[3u]);
+}
+
+fn sb_store_32(offset : u32, value : mat4x3<f16>) {
+ sb_store_11((offset + 0u), value[0u]);
+ sb_store_11((offset + 8u), value[1u]);
+ sb_store_11((offset + 16u), value[2u]);
+ sb_store_11((offset + 24u), value[3u]);
+}
+
+fn sb_store_33(offset : u32, value : mat4x4<f16>) {
+ sb_store_15((offset + 0u), value[0u]);
+ sb_store_15((offset + 8u), value[1u]);
+ sb_store_15((offset + 16u), value[2u]);
+ sb_store_15((offset + 24u), value[3u]);
+}
+
+fn sb_store_34(offset : u32, value : array<vec3<f32>, 2u>) {
+ var array_1 = value;
+ for(var i = 0u; (i < 2u); i = (i + 1u)) {
+ sb_store_8((offset + (i * 16u)), array_1[i]);
+ }
+}
+
+fn sb_store_35(offset : u32, value : array<mat4x2<f16>, 2u>) {
+ var array_2 = value;
+ for(var i_1 = 0u; (i_1 < 2u); i_1 = (i_1 + 1u)) {
+ sb_store_31((offset + (i_1 * 16u)), array_2[i_1]);
+ }
+}
+
+@compute @workgroup_size(1)
+fn main() {
+ sb_store(0u, f32());
+ sb_store_1(4u, i32());
+ sb_store_2(8u, u32());
+ sb_store_3(12u, f16());
+ sb_store_4(16u, vec2<f32>());
+ sb_store_5(24u, vec2<i32>());
+ sb_store_6(32u, vec2<u32>());
+ sb_store_7(40u, vec2<f16>());
+ sb_store_8(48u, vec3<f32>());
+ sb_store_9(64u, vec3<i32>());
+ sb_store_10(80u, vec3<u32>());
+ sb_store_11(96u, vec3<f16>());
+ sb_store_12(112u, vec4<f32>());
+ sb_store_13(128u, vec4<i32>());
+ sb_store_14(144u, vec4<u32>());
+ sb_store_15(160u, vec4<f16>());
+ sb_store_16(168u, mat2x2<f32>());
+ sb_store_17(192u, mat2x3<f32>());
+ sb_store_18(224u, mat2x4<f32>());
+ sb_store_19(256u, mat3x2<f32>());
+ sb_store_20(288u, mat3x3<f32>());
+ sb_store_21(336u, mat3x4<f32>());
+ sb_store_22(384u, mat4x2<f32>());
+ sb_store_23(416u, mat4x3<f32>());
+ sb_store_24(480u, mat4x4<f32>());
+ sb_store_25(544u, mat2x2<f16>());
+ sb_store_26(552u, mat2x3<f16>());
+ sb_store_27(568u, mat2x4<f16>());
+ sb_store_28(584u, mat3x2<f16>());
+ sb_store_29(600u, mat3x3<f16>());
+ sb_store_30(624u, mat3x4<f16>());
+ sb_store_31(648u, mat4x2<f16>());
+ sb_store_32(664u, mat4x3<f16>());
+ sb_store_33(696u, mat4x4<f16>());
+ sb_store_34(736u, array<vec3<f32>, 2>());
+ sb_store_35(768u, array<mat4x2<f16>, 2>());
+}
+)";  // Expected: each assignment in main() becomes an sb_store*(offset, value) call.
+
+ auto got = Run<DecomposeMemoryAccess>(src);  // Run DecomposeMemoryAccess over the input.
+
+ EXPECT_EQ(expect, str(got));  // Compare the stringified result with the expected text.
+}
+
+TEST_F(DecomposeMemoryAccessTest, SB_BasicStore_OutOfOrder) {
+ auto* src = R"(
+enable f16;
+
+@compute @workgroup_size(1)
+fn main() {
+ sb.scalar_f32 = f32();
+ sb.scalar_i32 = i32();
+ sb.scalar_u32 = u32();
+ sb.scalar_f16 = f16();
+ sb.vec2_f32 = vec2<f32>();
+ sb.vec2_i32 = vec2<i32>();
+ sb.vec2_u32 = vec2<u32>();
+ sb.vec2_f16 = vec2<f16>();
+ sb.vec3_f32 = vec3<f32>();
+ sb.vec3_i32 = vec3<i32>();
+ sb.vec3_u32 = vec3<u32>();
+ sb.vec3_f16 = vec3<f16>();
+ sb.vec4_f32 = vec4<f32>();
+ sb.vec4_i32 = vec4<i32>();
+ sb.vec4_u32 = vec4<u32>();
+ sb.vec4_f16 = vec4<f16>();
+ sb.mat2x2_f32 = mat2x2<f32>();
+ sb.mat2x3_f32 = mat2x3<f32>();
+ sb.mat2x4_f32 = mat2x4<f32>();
+ sb.mat3x2_f32 = mat3x2<f32>();
+ sb.mat3x3_f32 = mat3x3<f32>();
+ sb.mat3x4_f32 = mat3x4<f32>();
+ sb.mat4x2_f32 = mat4x2<f32>();
+ sb.mat4x3_f32 = mat4x3<f32>();
+ sb.mat4x4_f32 = mat4x4<f32>();
+ sb.mat2x2_f16 = mat2x2<f16>();
+ sb.mat2x3_f16 = mat2x3<f16>();
+ sb.mat2x4_f16 = mat2x4<f16>();
+ sb.mat3x2_f16 = mat3x2<f16>();
+ sb.mat3x3_f16 = mat3x3<f16>();
+ sb.mat3x4_f16 = mat3x4<f16>();
+ sb.mat4x2_f16 = mat4x2<f16>();
+ sb.mat4x3_f16 = mat4x3<f16>();
+ sb.mat4x4_f16 = mat4x4<f16>();
+ sb.arr2_vec3_f32 = array<vec3<f32>, 2>();
+ sb.arr2_mat4x2_f16 = array<mat4x2<f16>, 2>();
+}
+
+@group(0) @binding(0) var<storage, read_write> sb : SB;
+
+struct SB {
+ scalar_f32 : f32,
+ scalar_i32 : i32,
+ scalar_u32 : u32,
+ scalar_f16 : f16,
+ vec2_f32 : vec2<f32>,
+ vec2_i32 : vec2<i32>,
+ vec2_u32 : vec2<u32>,
+ vec2_f16 : vec2<f16>,
+ vec3_f32 : vec3<f32>,
+ vec3_i32 : vec3<i32>,
+ vec3_u32 : vec3<u32>,
+ vec3_f16 : vec3<f16>,
+ vec4_f32 : vec4<f32>,
+ vec4_i32 : vec4<i32>,
+ vec4_u32 : vec4<u32>,
+ vec4_f16 : vec4<f16>,
+ mat2x2_f32 : mat2x2<f32>,
+ mat2x3_f32 : mat2x3<f32>,
+ mat2x4_f32 : mat2x4<f32>,
+ mat3x2_f32 : mat3x2<f32>,
+ mat3x3_f32 : mat3x3<f32>,
+ mat3x4_f32 : mat3x4<f32>,
+ mat4x2_f32 : mat4x2<f32>,
+ mat4x3_f32 : mat4x3<f32>,
+ mat4x4_f32 : mat4x4<f32>,
+ mat2x2_f16 : mat2x2<f16>,
+ mat2x3_f16 : mat2x3<f16>,
+ mat2x4_f16 : mat2x4<f16>,
+ mat3x2_f16 : mat3x2<f16>,
+ mat3x3_f16 : mat3x3<f16>,
+ mat3x4_f16 : mat3x4<f16>,
+ mat4x2_f16 : mat4x2<f16>,
+ mat4x3_f16 : mat4x3<f16>,
+ mat4x4_f16 : mat4x4<f16>,
+ arr2_vec3_f32 : array<vec3<f32>, 2>,
+ arr2_mat4x2_f16 : array<mat4x2<f16>, 2>,
+};
+)";
+
+ auto* expect = R"(
+enable f16;
+
+@internal(intrinsic_store_storage_f32) @internal(disable_validation__function_has_no_body)
+fn sb_store(offset : u32, value : f32)
+
+@internal(intrinsic_store_storage_i32) @internal(disable_validation__function_has_no_body)
+fn sb_store_1(offset : u32, value : i32)
+
+@internal(intrinsic_store_storage_u32) @internal(disable_validation__function_has_no_body)
+fn sb_store_2(offset : u32, value : u32)
+
+@internal(intrinsic_store_storage_f16) @internal(disable_validation__function_has_no_body)
+fn sb_store_3(offset : u32, value : f16)
+
+@internal(intrinsic_store_storage_vec2_f32) @internal(disable_validation__function_has_no_body)
+fn sb_store_4(offset : u32, value : vec2<f32>)
+
+@internal(intrinsic_store_storage_vec2_i32) @internal(disable_validation__function_has_no_body)
+fn sb_store_5(offset : u32, value : vec2<i32>)
+
+@internal(intrinsic_store_storage_vec2_u32) @internal(disable_validation__function_has_no_body)
+fn sb_store_6(offset : u32, value : vec2<u32>)
+
+@internal(intrinsic_store_storage_vec2_f16) @internal(disable_validation__function_has_no_body)
+fn sb_store_7(offset : u32, value : vec2<f16>)
+
+@internal(intrinsic_store_storage_vec3_f32) @internal(disable_validation__function_has_no_body)
+fn sb_store_8(offset : u32, value : vec3<f32>)
+
+@internal(intrinsic_store_storage_vec3_i32) @internal(disable_validation__function_has_no_body)
+fn sb_store_9(offset : u32, value : vec3<i32>)
+
+@internal(intrinsic_store_storage_vec3_u32) @internal(disable_validation__function_has_no_body)
+fn sb_store_10(offset : u32, value : vec3<u32>)
+
+@internal(intrinsic_store_storage_vec3_f16) @internal(disable_validation__function_has_no_body)
+fn sb_store_11(offset : u32, value : vec3<f16>)
+
+@internal(intrinsic_store_storage_vec4_f32) @internal(disable_validation__function_has_no_body)
+fn sb_store_12(offset : u32, value : vec4<f32>)
+
+@internal(intrinsic_store_storage_vec4_i32) @internal(disable_validation__function_has_no_body)
+fn sb_store_13(offset : u32, value : vec4<i32>)
+
+@internal(intrinsic_store_storage_vec4_u32) @internal(disable_validation__function_has_no_body)
+fn sb_store_14(offset : u32, value : vec4<u32>)
+
+@internal(intrinsic_store_storage_vec4_f16) @internal(disable_validation__function_has_no_body)
+fn sb_store_15(offset : u32, value : vec4<f16>)
+
+fn sb_store_16(offset : u32, value : mat2x2<f32>) {
+ sb_store_4((offset + 0u), value[0u]);
+ sb_store_4((offset + 8u), value[1u]);
+}
+
+fn sb_store_17(offset : u32, value : mat2x3<f32>) {
+ sb_store_8((offset + 0u), value[0u]);
+ sb_store_8((offset + 16u), value[1u]);
+}
+
+fn sb_store_18(offset : u32, value : mat2x4<f32>) {
+ sb_store_12((offset + 0u), value[0u]);
+ sb_store_12((offset + 16u), value[1u]);
+}
+
+fn sb_store_19(offset : u32, value : mat3x2<f32>) {
+ sb_store_4((offset + 0u), value[0u]);
+ sb_store_4((offset + 8u), value[1u]);
+ sb_store_4((offset + 16u), value[2u]);
+}
+
+fn sb_store_20(offset : u32, value : mat3x3<f32>) {
+ sb_store_8((offset + 0u), value[0u]);
+ sb_store_8((offset + 16u), value[1u]);
+ sb_store_8((offset + 32u), value[2u]);
+}
+
+fn sb_store_21(offset : u32, value : mat3x4<f32>) {
+ sb_store_12((offset + 0u), value[0u]);
+ sb_store_12((offset + 16u), value[1u]);
+ sb_store_12((offset + 32u), value[2u]);
+}
+
+fn sb_store_22(offset : u32, value : mat4x2<f32>) {
+ sb_store_4((offset + 0u), value[0u]);
+ sb_store_4((offset + 8u), value[1u]);
+ sb_store_4((offset + 16u), value[2u]);
+ sb_store_4((offset + 24u), value[3u]);
+}
+
+fn sb_store_23(offset : u32, value : mat4x3<f32>) {
+ sb_store_8((offset + 0u), value[0u]);
+ sb_store_8((offset + 16u), value[1u]);
+ sb_store_8((offset + 32u), value[2u]);
+ sb_store_8((offset + 48u), value[3u]);
+}
+
+fn sb_store_24(offset : u32, value : mat4x4<f32>) {
+ sb_store_12((offset + 0u), value[0u]);
+ sb_store_12((offset + 16u), value[1u]);
+ sb_store_12((offset + 32u), value[2u]);
+ sb_store_12((offset + 48u), value[3u]);
+}
+
+fn sb_store_25(offset : u32, value : mat2x2<f16>) {
+ sb_store_7((offset + 0u), value[0u]);
+ sb_store_7((offset + 4u), value[1u]);
+}
+
+fn sb_store_26(offset : u32, value : mat2x3<f16>) {
+ sb_store_11((offset + 0u), value[0u]);
+ sb_store_11((offset + 8u), value[1u]);
+}
+
+fn sb_store_27(offset : u32, value : mat2x4<f16>) {
+ sb_store_15((offset + 0u), value[0u]);
+ sb_store_15((offset + 8u), value[1u]);
+}
+
+fn sb_store_28(offset : u32, value : mat3x2<f16>) {
+ sb_store_7((offset + 0u), value[0u]);
+ sb_store_7((offset + 4u), value[1u]);
+ sb_store_7((offset + 8u), value[2u]);
+}
+
+fn sb_store_29(offset : u32, value : mat3x3<f16>) {
+ sb_store_11((offset + 0u), value[0u]);
+ sb_store_11((offset + 8u), value[1u]);
+ sb_store_11((offset + 16u), value[2u]);
+}
+
+fn sb_store_30(offset : u32, value : mat3x4<f16>) {
+ sb_store_15((offset + 0u), value[0u]);
+ sb_store_15((offset + 8u), value[1u]);
+ sb_store_15((offset + 16u), value[2u]);
+}
+
+fn sb_store_31(offset : u32, value : mat4x2<f16>) {
+ sb_store_7((offset + 0u), value[0u]);
+ sb_store_7((offset + 4u), value[1u]);
+ sb_store_7((offset + 8u), value[2u]);
+ sb_store_7((offset + 12u), value[3u]);
+}
+
+fn sb_store_32(offset : u32, value : mat4x3<f16>) {
+ sb_store_11((offset + 0u), value[0u]);
+ sb_store_11((offset + 8u), value[1u]);
+ sb_store_11((offset + 16u), value[2u]);
+ sb_store_11((offset + 24u), value[3u]);
+}
+
+fn sb_store_33(offset : u32, value : mat4x4<f16>) {
+ sb_store_15((offset + 0u), value[0u]);
+ sb_store_15((offset + 8u), value[1u]);
+ sb_store_15((offset + 16u), value[2u]);
+ sb_store_15((offset + 24u), value[3u]);
+}
+
+fn sb_store_34(offset : u32, value : array<vec3<f32>, 2u>) {
+ var array_1 = value;
+ for(var i = 0u; (i < 2u); i = (i + 1u)) {
+ sb_store_8((offset + (i * 16u)), array_1[i]);
+ }
+}
+
+fn sb_store_35(offset : u32, value : array<mat4x2<f16>, 2u>) {
+ var array_2 = value;
+ for(var i_1 = 0u; (i_1 < 2u); i_1 = (i_1 + 1u)) {
+ sb_store_31((offset + (i_1 * 16u)), array_2[i_1]);
+ }
+}
+
+@compute @workgroup_size(1)
+fn main() {
+ sb_store(0u, f32());
+ sb_store_1(4u, i32());
+ sb_store_2(8u, u32());
+ sb_store_3(12u, f16());
+ sb_store_4(16u, vec2<f32>());
+ sb_store_5(24u, vec2<i32>());
+ sb_store_6(32u, vec2<u32>());
+ sb_store_7(40u, vec2<f16>());
+ sb_store_8(48u, vec3<f32>());
+ sb_store_9(64u, vec3<i32>());
+ sb_store_10(80u, vec3<u32>());
+ sb_store_11(96u, vec3<f16>());
+ sb_store_12(112u, vec4<f32>());
+ sb_store_13(128u, vec4<i32>());
+ sb_store_14(144u, vec4<u32>());
+ sb_store_15(160u, vec4<f16>());
+ sb_store_16(168u, mat2x2<f32>());
+ sb_store_17(192u, mat2x3<f32>());
+ sb_store_18(224u, mat2x4<f32>());
+ sb_store_19(256u, mat3x2<f32>());
+ sb_store_20(288u, mat3x3<f32>());
+ sb_store_21(336u, mat3x4<f32>());
+ sb_store_22(384u, mat4x2<f32>());
+ sb_store_23(416u, mat4x3<f32>());
+ sb_store_24(480u, mat4x4<f32>());
+ sb_store_25(544u, mat2x2<f16>());
+ sb_store_26(552u, mat2x3<f16>());
+ sb_store_27(568u, mat2x4<f16>());
+ sb_store_28(584u, mat3x2<f16>());
+ sb_store_29(600u, mat3x3<f16>());
+ sb_store_30(624u, mat3x4<f16>());
+ sb_store_31(648u, mat4x2<f16>());
+ sb_store_32(664u, mat4x3<f16>());
+ sb_store_33(696u, mat4x4<f16>());
+ sb_store_34(736u, array<vec3<f32>, 2>());
+ sb_store_35(768u, array<mat4x2<f16>, 2>());
+}
+
+@group(0) @binding(0) var<storage, read_write> sb : SB;
+
+struct SB {
+ scalar_f32 : f32,
+ scalar_i32 : i32,
+ scalar_u32 : u32,
+ scalar_f16 : f16,
+ vec2_f32 : vec2<f32>,
+ vec2_i32 : vec2<i32>,
+ vec2_u32 : vec2<u32>,
+ vec2_f16 : vec2<f16>,
+ vec3_f32 : vec3<f32>,
+ vec3_i32 : vec3<i32>,
+ vec3_u32 : vec3<u32>,
+ vec3_f16 : vec3<f16>,
+ vec4_f32 : vec4<f32>,
+ vec4_i32 : vec4<i32>,
+ vec4_u32 : vec4<u32>,
+ vec4_f16 : vec4<f16>,
+ mat2x2_f32 : mat2x2<f32>,
+ mat2x3_f32 : mat2x3<f32>,
+ mat2x4_f32 : mat2x4<f32>,
+ mat3x2_f32 : mat3x2<f32>,
+ mat3x3_f32 : mat3x3<f32>,
+ mat3x4_f32 : mat3x4<f32>,
+ mat4x2_f32 : mat4x2<f32>,
+ mat4x3_f32 : mat4x3<f32>,
+ mat4x4_f32 : mat4x4<f32>,
+ mat2x2_f16 : mat2x2<f16>,
+ mat2x3_f16 : mat2x3<f16>,
+ mat2x4_f16 : mat2x4<f16>,
+ mat3x2_f16 : mat3x2<f16>,
+ mat3x3_f16 : mat3x3<f16>,
+ mat3x4_f16 : mat3x4<f16>,
+ mat4x2_f16 : mat4x2<f16>,
+ mat4x3_f16 : mat4x3<f16>,
+ mat4x4_f16 : mat4x4<f16>,
+ arr2_vec3_f32 : array<vec3<f32>, 2>,
+ arr2_mat4x2_f16 : array<mat4x2<f16>, 2>,
+}
+)";
+
+ auto got = Run<DecomposeMemoryAccess>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeMemoryAccessTest, LoadStructure) {  // Whole-struct read of storage buffer `sb` into a local var.
+ auto* src = R"(
+enable f16;
+
+struct SB {
+ scalar_f32 : f32,
+ scalar_i32 : i32,
+ scalar_u32 : u32,
+ scalar_f16 : f16,
+ vec2_f32 : vec2<f32>,
+ vec2_i32 : vec2<i32>,
+ vec2_u32 : vec2<u32>,
+ vec2_f16 : vec2<f16>,
+ vec3_f32 : vec3<f32>,
+ vec3_i32 : vec3<i32>,
+ vec3_u32 : vec3<u32>,
+ vec3_f16 : vec3<f16>,
+ vec4_f32 : vec4<f32>,
+ vec4_i32 : vec4<i32>,
+ vec4_u32 : vec4<u32>,
+ vec4_f16 : vec4<f16>,
+ mat2x2_f32 : mat2x2<f32>,
+ mat2x3_f32 : mat2x3<f32>,
+ mat2x4_f32 : mat2x4<f32>,
+ mat3x2_f32 : mat3x2<f32>,
+ mat3x3_f32 : mat3x3<f32>,
+ mat3x4_f32 : mat3x4<f32>,
+ mat4x2_f32 : mat4x2<f32>,
+ mat4x3_f32 : mat4x3<f32>,
+ mat4x4_f32 : mat4x4<f32>,
+ mat2x2_f16 : mat2x2<f16>,
+ mat2x3_f16 : mat2x3<f16>,
+ mat2x4_f16 : mat2x4<f16>,
+ mat3x2_f16 : mat3x2<f16>,
+ mat3x3_f16 : mat3x3<f16>,
+ mat3x4_f16 : mat3x4<f16>,
+ mat4x2_f16 : mat4x2<f16>,
+ mat4x3_f16 : mat4x3<f16>,
+ mat4x4_f16 : mat4x4<f16>,
+ arr2_vec3_f32 : array<vec3<f32>, 2>,
+ arr2_mat4x2_f16 : array<mat4x2<f16>, 2>,
+};
+
+@group(0) @binding(0) var<storage, read_write> sb : SB;
+
+@compute @workgroup_size(1)
+fn main() {
+ var x : SB = sb;
+}
+)";
+ // Expected text: bodiless intrinsic_load_storage_* functions for scalars/vectors, helpers built on them for matrices, arrays and SB, and main() calling sb_load(0u).
+ auto* expect = R"(
+enable f16;
+
+struct SB {
+ scalar_f32 : f32,
+ scalar_i32 : i32,
+ scalar_u32 : u32,
+ scalar_f16 : f16,
+ vec2_f32 : vec2<f32>,
+ vec2_i32 : vec2<i32>,
+ vec2_u32 : vec2<u32>,
+ vec2_f16 : vec2<f16>,
+ vec3_f32 : vec3<f32>,
+ vec3_i32 : vec3<i32>,
+ vec3_u32 : vec3<u32>,
+ vec3_f16 : vec3<f16>,
+ vec4_f32 : vec4<f32>,
+ vec4_i32 : vec4<i32>,
+ vec4_u32 : vec4<u32>,
+ vec4_f16 : vec4<f16>,
+ mat2x2_f32 : mat2x2<f32>,
+ mat2x3_f32 : mat2x3<f32>,
+ mat2x4_f32 : mat2x4<f32>,
+ mat3x2_f32 : mat3x2<f32>,
+ mat3x3_f32 : mat3x3<f32>,
+ mat3x4_f32 : mat3x4<f32>,
+ mat4x2_f32 : mat4x2<f32>,
+ mat4x3_f32 : mat4x3<f32>,
+ mat4x4_f32 : mat4x4<f32>,
+ mat2x2_f16 : mat2x2<f16>,
+ mat2x3_f16 : mat2x3<f16>,
+ mat2x4_f16 : mat2x4<f16>,
+ mat3x2_f16 : mat3x2<f16>,
+ mat3x3_f16 : mat3x3<f16>,
+ mat3x4_f16 : mat3x4<f16>,
+ mat4x2_f16 : mat4x2<f16>,
+ mat4x3_f16 : mat4x3<f16>,
+ mat4x4_f16 : mat4x4<f16>,
+ arr2_vec3_f32 : array<vec3<f32>, 2>,
+ arr2_mat4x2_f16 : array<mat4x2<f16>, 2>,
+}
+
+@group(0) @binding(0) var<storage, read_write> sb : SB;
+
+@internal(intrinsic_load_storage_f32) @internal(disable_validation__function_has_no_body)
+fn sb_load_1(offset : u32) -> f32
+
+@internal(intrinsic_load_storage_i32) @internal(disable_validation__function_has_no_body)
+fn sb_load_2(offset : u32) -> i32
+
+@internal(intrinsic_load_storage_u32) @internal(disable_validation__function_has_no_body)
+fn sb_load_3(offset : u32) -> u32
+
+@internal(intrinsic_load_storage_f16) @internal(disable_validation__function_has_no_body)
+fn sb_load_4(offset : u32) -> f16
+
+@internal(intrinsic_load_storage_vec2_f32) @internal(disable_validation__function_has_no_body)
+fn sb_load_5(offset : u32) -> vec2<f32>
+
+@internal(intrinsic_load_storage_vec2_i32) @internal(disable_validation__function_has_no_body)
+fn sb_load_6(offset : u32) -> vec2<i32>
+
+@internal(intrinsic_load_storage_vec2_u32) @internal(disable_validation__function_has_no_body)
+fn sb_load_7(offset : u32) -> vec2<u32>
+
+@internal(intrinsic_load_storage_vec2_f16) @internal(disable_validation__function_has_no_body)
+fn sb_load_8(offset : u32) -> vec2<f16>
+
+@internal(intrinsic_load_storage_vec3_f32) @internal(disable_validation__function_has_no_body)
+fn sb_load_9(offset : u32) -> vec3<f32>
+
+@internal(intrinsic_load_storage_vec3_i32) @internal(disable_validation__function_has_no_body)
+fn sb_load_10(offset : u32) -> vec3<i32>
+
+@internal(intrinsic_load_storage_vec3_u32) @internal(disable_validation__function_has_no_body)
+fn sb_load_11(offset : u32) -> vec3<u32>
+
+@internal(intrinsic_load_storage_vec3_f16) @internal(disable_validation__function_has_no_body)
+fn sb_load_12(offset : u32) -> vec3<f16>
+
+@internal(intrinsic_load_storage_vec4_f32) @internal(disable_validation__function_has_no_body)
+fn sb_load_13(offset : u32) -> vec4<f32>
+
+@internal(intrinsic_load_storage_vec4_i32) @internal(disable_validation__function_has_no_body)
+fn sb_load_14(offset : u32) -> vec4<i32>
+
+@internal(intrinsic_load_storage_vec4_u32) @internal(disable_validation__function_has_no_body)
+fn sb_load_15(offset : u32) -> vec4<u32>
+
+@internal(intrinsic_load_storage_vec4_f16) @internal(disable_validation__function_has_no_body)
+fn sb_load_16(offset : u32) -> vec4<f16>
+
+fn sb_load_17(offset : u32) -> mat2x2<f32> {
+ return mat2x2<f32>(sb_load_5((offset + 0u)), sb_load_5((offset + 8u)));
+}
+
+fn sb_load_18(offset : u32) -> mat2x3<f32> {
+ return mat2x3<f32>(sb_load_9((offset + 0u)), sb_load_9((offset + 16u)));
+}
+
+fn sb_load_19(offset : u32) -> mat2x4<f32> {
+ return mat2x4<f32>(sb_load_13((offset + 0u)), sb_load_13((offset + 16u)));
+}
+
+fn sb_load_20(offset : u32) -> mat3x2<f32> {
+ return mat3x2<f32>(sb_load_5((offset + 0u)), sb_load_5((offset + 8u)), sb_load_5((offset + 16u)));
+}
+
+fn sb_load_21(offset : u32) -> mat3x3<f32> {
+ return mat3x3<f32>(sb_load_9((offset + 0u)), sb_load_9((offset + 16u)), sb_load_9((offset + 32u)));
+}
+
+fn sb_load_22(offset : u32) -> mat3x4<f32> {
+ return mat3x4<f32>(sb_load_13((offset + 0u)), sb_load_13((offset + 16u)), sb_load_13((offset + 32u)));
+}
+
+fn sb_load_23(offset : u32) -> mat4x2<f32> {
+ return mat4x2<f32>(sb_load_5((offset + 0u)), sb_load_5((offset + 8u)), sb_load_5((offset + 16u)), sb_load_5((offset + 24u)));
+}
+
+fn sb_load_24(offset : u32) -> mat4x3<f32> {
+ return mat4x3<f32>(sb_load_9((offset + 0u)), sb_load_9((offset + 16u)), sb_load_9((offset + 32u)), sb_load_9((offset + 48u)));
+}
+
+fn sb_load_25(offset : u32) -> mat4x4<f32> {
+ return mat4x4<f32>(sb_load_13((offset + 0u)), sb_load_13((offset + 16u)), sb_load_13((offset + 32u)), sb_load_13((offset + 48u)));
+}
+
+fn sb_load_26(offset : u32) -> mat2x2<f16> {
+ return mat2x2<f16>(sb_load_8((offset + 0u)), sb_load_8((offset + 4u)));
+}
+
+fn sb_load_27(offset : u32) -> mat2x3<f16> {
+ return mat2x3<f16>(sb_load_12((offset + 0u)), sb_load_12((offset + 8u)));
+}
+
+fn sb_load_28(offset : u32) -> mat2x4<f16> {
+ return mat2x4<f16>(sb_load_16((offset + 0u)), sb_load_16((offset + 8u)));
+}
+
+fn sb_load_29(offset : u32) -> mat3x2<f16> {
+ return mat3x2<f16>(sb_load_8((offset + 0u)), sb_load_8((offset + 4u)), sb_load_8((offset + 8u)));
+}
+
+fn sb_load_30(offset : u32) -> mat3x3<f16> {
+ return mat3x3<f16>(sb_load_12((offset + 0u)), sb_load_12((offset + 8u)), sb_load_12((offset + 16u)));
+}
+
+fn sb_load_31(offset : u32) -> mat3x4<f16> {
+ return mat3x4<f16>(sb_load_16((offset + 0u)), sb_load_16((offset + 8u)), sb_load_16((offset + 16u)));
+}
+
+fn sb_load_32(offset : u32) -> mat4x2<f16> {
+ return mat4x2<f16>(sb_load_8((offset + 0u)), sb_load_8((offset + 4u)), sb_load_8((offset + 8u)), sb_load_8((offset + 12u)));
+}
+
+fn sb_load_33(offset : u32) -> mat4x3<f16> {
+ return mat4x3<f16>(sb_load_12((offset + 0u)), sb_load_12((offset + 8u)), sb_load_12((offset + 16u)), sb_load_12((offset + 24u)));
+}
+
+fn sb_load_34(offset : u32) -> mat4x4<f16> {
+ return mat4x4<f16>(sb_load_16((offset + 0u)), sb_load_16((offset + 8u)), sb_load_16((offset + 16u)), sb_load_16((offset + 24u)));
+}
+
+fn sb_load_35(offset : u32) -> array<vec3<f32>, 2u> {
+ var arr : array<vec3<f32>, 2u>;
+ for(var i = 0u; (i < 2u); i = (i + 1u)) {
+ arr[i] = sb_load_9((offset + (i * 16u)));
+ }
+ return arr;
+}
+
+fn sb_load_36(offset : u32) -> array<mat4x2<f16>, 2u> {
+ var arr_1 : array<mat4x2<f16>, 2u>;
+ for(var i_1 = 0u; (i_1 < 2u); i_1 = (i_1 + 1u)) {
+ arr_1[i_1] = sb_load_32((offset + (i_1 * 16u)));
+ }
+ return arr_1;
+}
+
+fn sb_load(offset : u32) -> SB {
+ return SB(sb_load_1((offset + 0u)), sb_load_2((offset + 4u)), sb_load_3((offset + 8u)), sb_load_4((offset + 12u)), sb_load_5((offset + 16u)), sb_load_6((offset + 24u)), sb_load_7((offset + 32u)), sb_load_8((offset + 40u)), sb_load_9((offset + 48u)), sb_load_10((offset + 64u)), sb_load_11((offset + 80u)), sb_load_12((offset + 96u)), sb_load_13((offset + 112u)), sb_load_14((offset + 128u)), sb_load_15((offset + 144u)), sb_load_16((offset + 160u)), sb_load_17((offset + 168u)), sb_load_18((offset + 192u)), sb_load_19((offset + 224u)), sb_load_20((offset + 256u)), sb_load_21((offset + 288u)), sb_load_22((offset + 336u)), sb_load_23((offset + 384u)), sb_load_24((offset + 416u)), sb_load_25((offset + 480u)), sb_load_26((offset + 544u)), sb_load_27((offset + 552u)), sb_load_28((offset + 568u)), sb_load_29((offset + 584u)), sb_load_30((offset + 600u)), sb_load_31((offset + 624u)), sb_load_32((offset + 648u)), sb_load_33((offset + 664u)), sb_load_34((offset + 696u)), sb_load_35((offset + 736u)), sb_load_36((offset + 768u)));
+}
+
+@compute @workgroup_size(1)
+fn main() {
+ var x : SB = sb_load(0u);
+}
+)";
+ // Apply DecomposeMemoryAccess to src, then require str(got) to equal the expected text exactly.
+ auto got = Run<DecomposeMemoryAccess>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeMemoryAccessTest, LoadStructure_OutOfOrder) {  // Whole-struct read of `sb`, with main() declared before `sb` and `SB`.
+ auto* src = R"(
+enable f16;
+
+@compute @workgroup_size(1)
+fn main() {
+ var x : SB = sb;
+}
+
+@group(0) @binding(0) var<storage, read_write> sb : SB;
+
+struct SB {
+ scalar_f32 : f32,
+ scalar_i32 : i32,
+ scalar_u32 : u32,
+ scalar_f16 : f16,
+ vec2_f32 : vec2<f32>,
+ vec2_i32 : vec2<i32>,
+ vec2_u32 : vec2<u32>,
+ vec2_f16 : vec2<f16>,
+ vec3_f32 : vec3<f32>,
+ vec3_i32 : vec3<i32>,
+ vec3_u32 : vec3<u32>,
+ vec3_f16 : vec3<f16>,
+ vec4_f32 : vec4<f32>,
+ vec4_i32 : vec4<i32>,
+ vec4_u32 : vec4<u32>,
+ vec4_f16 : vec4<f16>,
+ mat2x2_f32 : mat2x2<f32>,
+ mat2x3_f32 : mat2x3<f32>,
+ mat2x4_f32 : mat2x4<f32>,
+ mat3x2_f32 : mat3x2<f32>,
+ mat3x3_f32 : mat3x3<f32>,
+ mat3x4_f32 : mat3x4<f32>,
+ mat4x2_f32 : mat4x2<f32>,
+ mat4x3_f32 : mat4x3<f32>,
+ mat4x4_f32 : mat4x4<f32>,
+ mat2x2_f16 : mat2x2<f16>,
+ mat2x3_f16 : mat2x3<f16>,
+ mat2x4_f16 : mat2x4<f16>,
+ mat3x2_f16 : mat3x2<f16>,
+ mat3x3_f16 : mat3x3<f16>,
+ mat3x4_f16 : mat3x4<f16>,
+ mat4x2_f16 : mat4x2<f16>,
+ mat4x3_f16 : mat4x3<f16>,
+ mat4x4_f16 : mat4x4<f16>,
+ arr2_vec3_f32 : array<vec3<f32>, 2>,
+ arr2_mat4x2_f16 : array<mat4x2<f16>, 2>,
+};
+)";
+ // Expected text: the generated sb_load* helpers follow the enable directive; main(), `sb` and `SB` keep their source order after them.
+ auto* expect = R"(
+enable f16;
+
+@internal(intrinsic_load_storage_f32) @internal(disable_validation__function_has_no_body)
+fn sb_load_1(offset : u32) -> f32
+
+@internal(intrinsic_load_storage_i32) @internal(disable_validation__function_has_no_body)
+fn sb_load_2(offset : u32) -> i32
+
+@internal(intrinsic_load_storage_u32) @internal(disable_validation__function_has_no_body)
+fn sb_load_3(offset : u32) -> u32
+
+@internal(intrinsic_load_storage_f16) @internal(disable_validation__function_has_no_body)
+fn sb_load_4(offset : u32) -> f16
+
+@internal(intrinsic_load_storage_vec2_f32) @internal(disable_validation__function_has_no_body)
+fn sb_load_5(offset : u32) -> vec2<f32>
+
+@internal(intrinsic_load_storage_vec2_i32) @internal(disable_validation__function_has_no_body)
+fn sb_load_6(offset : u32) -> vec2<i32>
+
+@internal(intrinsic_load_storage_vec2_u32) @internal(disable_validation__function_has_no_body)
+fn sb_load_7(offset : u32) -> vec2<u32>
+
+@internal(intrinsic_load_storage_vec2_f16) @internal(disable_validation__function_has_no_body)
+fn sb_load_8(offset : u32) -> vec2<f16>
+
+@internal(intrinsic_load_storage_vec3_f32) @internal(disable_validation__function_has_no_body)
+fn sb_load_9(offset : u32) -> vec3<f32>
+
+@internal(intrinsic_load_storage_vec3_i32) @internal(disable_validation__function_has_no_body)
+fn sb_load_10(offset : u32) -> vec3<i32>
+
+@internal(intrinsic_load_storage_vec3_u32) @internal(disable_validation__function_has_no_body)
+fn sb_load_11(offset : u32) -> vec3<u32>
+
+@internal(intrinsic_load_storage_vec3_f16) @internal(disable_validation__function_has_no_body)
+fn sb_load_12(offset : u32) -> vec3<f16>
+
+@internal(intrinsic_load_storage_vec4_f32) @internal(disable_validation__function_has_no_body)
+fn sb_load_13(offset : u32) -> vec4<f32>
+
+@internal(intrinsic_load_storage_vec4_i32) @internal(disable_validation__function_has_no_body)
+fn sb_load_14(offset : u32) -> vec4<i32>
+
+@internal(intrinsic_load_storage_vec4_u32) @internal(disable_validation__function_has_no_body)
+fn sb_load_15(offset : u32) -> vec4<u32>
+
+@internal(intrinsic_load_storage_vec4_f16) @internal(disable_validation__function_has_no_body)
+fn sb_load_16(offset : u32) -> vec4<f16>
+
+fn sb_load_17(offset : u32) -> mat2x2<f32> {
+ return mat2x2<f32>(sb_load_5((offset + 0u)), sb_load_5((offset + 8u)));
+}
+
+fn sb_load_18(offset : u32) -> mat2x3<f32> {
+ return mat2x3<f32>(sb_load_9((offset + 0u)), sb_load_9((offset + 16u)));
+}
+
+fn sb_load_19(offset : u32) -> mat2x4<f32> {
+ return mat2x4<f32>(sb_load_13((offset + 0u)), sb_load_13((offset + 16u)));
+}
+
+fn sb_load_20(offset : u32) -> mat3x2<f32> {
+ return mat3x2<f32>(sb_load_5((offset + 0u)), sb_load_5((offset + 8u)), sb_load_5((offset + 16u)));
+}
+
+fn sb_load_21(offset : u32) -> mat3x3<f32> {
+ return mat3x3<f32>(sb_load_9((offset + 0u)), sb_load_9((offset + 16u)), sb_load_9((offset + 32u)));
+}
+
+fn sb_load_22(offset : u32) -> mat3x4<f32> {
+ return mat3x4<f32>(sb_load_13((offset + 0u)), sb_load_13((offset + 16u)), sb_load_13((offset + 32u)));
+}
+
+fn sb_load_23(offset : u32) -> mat4x2<f32> {
+ return mat4x2<f32>(sb_load_5((offset + 0u)), sb_load_5((offset + 8u)), sb_load_5((offset + 16u)), sb_load_5((offset + 24u)));
+}
+
+fn sb_load_24(offset : u32) -> mat4x3<f32> {
+ return mat4x3<f32>(sb_load_9((offset + 0u)), sb_load_9((offset + 16u)), sb_load_9((offset + 32u)), sb_load_9((offset + 48u)));
+}
+
+fn sb_load_25(offset : u32) -> mat4x4<f32> {
+ return mat4x4<f32>(sb_load_13((offset + 0u)), sb_load_13((offset + 16u)), sb_load_13((offset + 32u)), sb_load_13((offset + 48u)));
+}
+
+fn sb_load_26(offset : u32) -> mat2x2<f16> {
+ return mat2x2<f16>(sb_load_8((offset + 0u)), sb_load_8((offset + 4u)));
+}
+
+fn sb_load_27(offset : u32) -> mat2x3<f16> {
+ return mat2x3<f16>(sb_load_12((offset + 0u)), sb_load_12((offset + 8u)));
+}
+
+fn sb_load_28(offset : u32) -> mat2x4<f16> {
+ return mat2x4<f16>(sb_load_16((offset + 0u)), sb_load_16((offset + 8u)));
+}
+
+fn sb_load_29(offset : u32) -> mat3x2<f16> {
+ return mat3x2<f16>(sb_load_8((offset + 0u)), sb_load_8((offset + 4u)), sb_load_8((offset + 8u)));
+}
+
+fn sb_load_30(offset : u32) -> mat3x3<f16> {
+ return mat3x3<f16>(sb_load_12((offset + 0u)), sb_load_12((offset + 8u)), sb_load_12((offset + 16u)));
+}
+
+fn sb_load_31(offset : u32) -> mat3x4<f16> {
+ return mat3x4<f16>(sb_load_16((offset + 0u)), sb_load_16((offset + 8u)), sb_load_16((offset + 16u)));
+}
+
+fn sb_load_32(offset : u32) -> mat4x2<f16> {
+ return mat4x2<f16>(sb_load_8((offset + 0u)), sb_load_8((offset + 4u)), sb_load_8((offset + 8u)), sb_load_8((offset + 12u)));
+}
+
+fn sb_load_33(offset : u32) -> mat4x3<f16> {
+ return mat4x3<f16>(sb_load_12((offset + 0u)), sb_load_12((offset + 8u)), sb_load_12((offset + 16u)), sb_load_12((offset + 24u)));
+}
+
+fn sb_load_34(offset : u32) -> mat4x4<f16> {
+ return mat4x4<f16>(sb_load_16((offset + 0u)), sb_load_16((offset + 8u)), sb_load_16((offset + 16u)), sb_load_16((offset + 24u)));
+}
+
+fn sb_load_35(offset : u32) -> array<vec3<f32>, 2u> {
+ var arr : array<vec3<f32>, 2u>;
+ for(var i = 0u; (i < 2u); i = (i + 1u)) {
+ arr[i] = sb_load_9((offset + (i * 16u)));
+ }
+ return arr;
+}
+
+fn sb_load_36(offset : u32) -> array<mat4x2<f16>, 2u> {
+ var arr_1 : array<mat4x2<f16>, 2u>;
+ for(var i_1 = 0u; (i_1 < 2u); i_1 = (i_1 + 1u)) {
+ arr_1[i_1] = sb_load_32((offset + (i_1 * 16u)));
+ }
+ return arr_1;
+}
+
+fn sb_load(offset : u32) -> SB {
+ return SB(sb_load_1((offset + 0u)), sb_load_2((offset + 4u)), sb_load_3((offset + 8u)), sb_load_4((offset + 12u)), sb_load_5((offset + 16u)), sb_load_6((offset + 24u)), sb_load_7((offset + 32u)), sb_load_8((offset + 40u)), sb_load_9((offset + 48u)), sb_load_10((offset + 64u)), sb_load_11((offset + 80u)), sb_load_12((offset + 96u)), sb_load_13((offset + 112u)), sb_load_14((offset + 128u)), sb_load_15((offset + 144u)), sb_load_16((offset + 160u)), sb_load_17((offset + 168u)), sb_load_18((offset + 192u)), sb_load_19((offset + 224u)), sb_load_20((offset + 256u)), sb_load_21((offset + 288u)), sb_load_22((offset + 336u)), sb_load_23((offset + 384u)), sb_load_24((offset + 416u)), sb_load_25((offset + 480u)), sb_load_26((offset + 544u)), sb_load_27((offset + 552u)), sb_load_28((offset + 568u)), sb_load_29((offset + 584u)), sb_load_30((offset + 600u)), sb_load_31((offset + 624u)), sb_load_32((offset + 648u)), sb_load_33((offset + 664u)), sb_load_34((offset + 696u)), sb_load_35((offset + 736u)), sb_load_36((offset + 768u)));
+}
+
+@compute @workgroup_size(1)
+fn main() {
+ var x : SB = sb_load(0u);
+}
+
+@group(0) @binding(0) var<storage, read_write> sb : SB;
+
+struct SB {
+ scalar_f32 : f32,
+ scalar_i32 : i32,
+ scalar_u32 : u32,
+ scalar_f16 : f16,
+ vec2_f32 : vec2<f32>,
+ vec2_i32 : vec2<i32>,
+ vec2_u32 : vec2<u32>,
+ vec2_f16 : vec2<f16>,
+ vec3_f32 : vec3<f32>,
+ vec3_i32 : vec3<i32>,
+ vec3_u32 : vec3<u32>,
+ vec3_f16 : vec3<f16>,
+ vec4_f32 : vec4<f32>,
+ vec4_i32 : vec4<i32>,
+ vec4_u32 : vec4<u32>,
+ vec4_f16 : vec4<f16>,
+ mat2x2_f32 : mat2x2<f32>,
+ mat2x3_f32 : mat2x3<f32>,
+ mat2x4_f32 : mat2x4<f32>,
+ mat3x2_f32 : mat3x2<f32>,
+ mat3x3_f32 : mat3x3<f32>,
+ mat3x4_f32 : mat3x4<f32>,
+ mat4x2_f32 : mat4x2<f32>,
+ mat4x3_f32 : mat4x3<f32>,
+ mat4x4_f32 : mat4x4<f32>,
+ mat2x2_f16 : mat2x2<f16>,
+ mat2x3_f16 : mat2x3<f16>,
+ mat2x4_f16 : mat2x4<f16>,
+ mat3x2_f16 : mat3x2<f16>,
+ mat3x3_f16 : mat3x3<f16>,
+ mat3x4_f16 : mat3x4<f16>,
+ mat4x2_f16 : mat4x2<f16>,
+ mat4x3_f16 : mat4x3<f16>,
+ mat4x4_f16 : mat4x4<f16>,
+ arr2_vec3_f32 : array<vec3<f32>, 2>,
+ arr2_mat4x2_f16 : array<mat4x2<f16>, 2>,
+}
+)";
+ // Apply DecomposeMemoryAccess to src, then require str(got) to equal the expected text exactly.
+ auto got = Run<DecomposeMemoryAccess>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeMemoryAccessTest, StoreStructure) {
+ auto* src = R"(
+enable f16;
+
+struct SB {
+ scalar_f32 : f32,
+ scalar_i32 : i32,
+ scalar_u32 : u32,
+ scalar_f16 : f16,
+ vec2_f32 : vec2<f32>,
+ vec2_i32 : vec2<i32>,
+ vec2_u32 : vec2<u32>,
+ vec2_f16 : vec2<f16>,
+ vec3_f32 : vec3<f32>,
+ vec3_i32 : vec3<i32>,
+ vec3_u32 : vec3<u32>,
+ vec3_f16 : vec3<f16>,
+ vec4_f32 : vec4<f32>,
+ vec4_i32 : vec4<i32>,
+ vec4_u32 : vec4<u32>,
+ vec4_f16 : vec4<f16>,
+ mat2x2_f32 : mat2x2<f32>,
+ mat2x3_f32 : mat2x3<f32>,
+ mat2x4_f32 : mat2x4<f32>,
+ mat3x2_f32 : mat3x2<f32>,
+ mat3x3_f32 : mat3x3<f32>,
+ mat3x4_f32 : mat3x4<f32>,
+ mat4x2_f32 : mat4x2<f32>,
+ mat4x3_f32 : mat4x3<f32>,
+ mat4x4_f32 : mat4x4<f32>,
+ mat2x2_f16 : mat2x2<f16>,
+ mat2x3_f16 : mat2x3<f16>,
+ mat2x4_f16 : mat2x4<f16>,
+ mat3x2_f16 : mat3x2<f16>,
+ mat3x3_f16 : mat3x3<f16>,
+ mat3x4_f16 : mat3x4<f16>,
+ mat4x2_f16 : mat4x2<f16>,
+ mat4x3_f16 : mat4x3<f16>,
+ mat4x4_f16 : mat4x4<f16>,
+ arr2_vec3_f32 : array<vec3<f32>, 2>,
+ arr2_mat4x2_f16 : array<mat4x2<f16>, 2>,
+};
+
+@group(0) @binding(0) var<storage, read_write> sb : SB;
+
+@compute @workgroup_size(1)
+fn main() {
+ sb = SB();
+}
+)";
+
+ auto* expect = R"(
+enable f16;
+
+struct SB {
+ scalar_f32 : f32,
+ scalar_i32 : i32,
+ scalar_u32 : u32,
+ scalar_f16 : f16,
+ vec2_f32 : vec2<f32>,
+ vec2_i32 : vec2<i32>,
+ vec2_u32 : vec2<u32>,
+ vec2_f16 : vec2<f16>,
+ vec3_f32 : vec3<f32>,
+ vec3_i32 : vec3<i32>,
+ vec3_u32 : vec3<u32>,
+ vec3_f16 : vec3<f16>,
+ vec4_f32 : vec4<f32>,
+ vec4_i32 : vec4<i32>,
+ vec4_u32 : vec4<u32>,
+ vec4_f16 : vec4<f16>,
+ mat2x2_f32 : mat2x2<f32>,
+ mat2x3_f32 : mat2x3<f32>,
+ mat2x4_f32 : mat2x4<f32>,
+ mat3x2_f32 : mat3x2<f32>,
+ mat3x3_f32 : mat3x3<f32>,
+ mat3x4_f32 : mat3x4<f32>,
+ mat4x2_f32 : mat4x2<f32>,
+ mat4x3_f32 : mat4x3<f32>,
+ mat4x4_f32 : mat4x4<f32>,
+ mat2x2_f16 : mat2x2<f16>,
+ mat2x3_f16 : mat2x3<f16>,
+ mat2x4_f16 : mat2x4<f16>,
+ mat3x2_f16 : mat3x2<f16>,
+ mat3x3_f16 : mat3x3<f16>,
+ mat3x4_f16 : mat3x4<f16>,
+ mat4x2_f16 : mat4x2<f16>,
+ mat4x3_f16 : mat4x3<f16>,
+ mat4x4_f16 : mat4x4<f16>,
+ arr2_vec3_f32 : array<vec3<f32>, 2>,
+ arr2_mat4x2_f16 : array<mat4x2<f16>, 2>,
+}
+
+@group(0) @binding(0) var<storage, read_write> sb : SB;
+
+@internal(intrinsic_store_storage_f32) @internal(disable_validation__function_has_no_body)
+fn sb_store_1(offset : u32, value : f32)
+
+@internal(intrinsic_store_storage_i32) @internal(disable_validation__function_has_no_body)
+fn sb_store_2(offset : u32, value : i32)
+
+@internal(intrinsic_store_storage_u32) @internal(disable_validation__function_has_no_body)
+fn sb_store_3(offset : u32, value : u32)
+
+@internal(intrinsic_store_storage_f16) @internal(disable_validation__function_has_no_body)
+fn sb_store_4(offset : u32, value : f16)
+
+@internal(intrinsic_store_storage_vec2_f32) @internal(disable_validation__function_has_no_body)
+fn sb_store_5(offset : u32, value : vec2<f32>)
+
+@internal(intrinsic_store_storage_vec2_i32) @internal(disable_validation__function_has_no_body)
+fn sb_store_6(offset : u32, value : vec2<i32>)
+
+@internal(intrinsic_store_storage_vec2_u32) @internal(disable_validation__function_has_no_body)
+fn sb_store_7(offset : u32, value : vec2<u32>)
+
+@internal(intrinsic_store_storage_vec2_f16) @internal(disable_validation__function_has_no_body)
+fn sb_store_8(offset : u32, value : vec2<f16>)
+
+@internal(intrinsic_store_storage_vec3_f32) @internal(disable_validation__function_has_no_body)
+fn sb_store_9(offset : u32, value : vec3<f32>)
+
+@internal(intrinsic_store_storage_vec3_i32) @internal(disable_validation__function_has_no_body)
+fn sb_store_10(offset : u32, value : vec3<i32>)
+
+@internal(intrinsic_store_storage_vec3_u32) @internal(disable_validation__function_has_no_body)
+fn sb_store_11(offset : u32, value : vec3<u32>)
+
+@internal(intrinsic_store_storage_vec3_f16) @internal(disable_validation__function_has_no_body)
+fn sb_store_12(offset : u32, value : vec3<f16>)
+
+@internal(intrinsic_store_storage_vec4_f32) @internal(disable_validation__function_has_no_body)
+fn sb_store_13(offset : u32, value : vec4<f32>)
+
+@internal(intrinsic_store_storage_vec4_i32) @internal(disable_validation__function_has_no_body)
+fn sb_store_14(offset : u32, value : vec4<i32>)
+
+@internal(intrinsic_store_storage_vec4_u32) @internal(disable_validation__function_has_no_body)
+fn sb_store_15(offset : u32, value : vec4<u32>)
+
+@internal(intrinsic_store_storage_vec4_f16) @internal(disable_validation__function_has_no_body)
+fn sb_store_16(offset : u32, value : vec4<f16>)
+
+fn sb_store_17(offset : u32, value : mat2x2<f32>) {
+ sb_store_5((offset + 0u), value[0u]);
+ sb_store_5((offset + 8u), value[1u]);
+}
+
+fn sb_store_18(offset : u32, value : mat2x3<f32>) {
+ sb_store_9((offset + 0u), value[0u]);
+ sb_store_9((offset + 16u), value[1u]);
+}
+
+fn sb_store_19(offset : u32, value : mat2x4<f32>) {
+ sb_store_13((offset + 0u), value[0u]);
+ sb_store_13((offset + 16u), value[1u]);
+}
+
+fn sb_store_20(offset : u32, value : mat3x2<f32>) {
+ sb_store_5((offset + 0u), value[0u]);
+ sb_store_5((offset + 8u), value[1u]);
+ sb_store_5((offset + 16u), value[2u]);
+}
+
+fn sb_store_21(offset : u32, value : mat3x3<f32>) {
+ sb_store_9((offset + 0u), value[0u]);
+ sb_store_9((offset + 16u), value[1u]);
+ sb_store_9((offset + 32u), value[2u]);
+}
+
+fn sb_store_22(offset : u32, value : mat3x4<f32>) {
+ sb_store_13((offset + 0u), value[0u]);
+ sb_store_13((offset + 16u), value[1u]);
+ sb_store_13((offset + 32u), value[2u]);
+}
+
+fn sb_store_23(offset : u32, value : mat4x2<f32>) {
+ sb_store_5((offset + 0u), value[0u]);
+ sb_store_5((offset + 8u), value[1u]);
+ sb_store_5((offset + 16u), value[2u]);
+ sb_store_5((offset + 24u), value[3u]);
+}
+
+fn sb_store_24(offset : u32, value : mat4x3<f32>) {
+ sb_store_9((offset + 0u), value[0u]);
+ sb_store_9((offset + 16u), value[1u]);
+ sb_store_9((offset + 32u), value[2u]);
+ sb_store_9((offset + 48u), value[3u]);
+}
+
+fn sb_store_25(offset : u32, value : mat4x4<f32>) {
+ sb_store_13((offset + 0u), value[0u]);
+ sb_store_13((offset + 16u), value[1u]);
+ sb_store_13((offset + 32u), value[2u]);
+ sb_store_13((offset + 48u), value[3u]);
+}
+
+fn sb_store_26(offset : u32, value : mat2x2<f16>) {
+ sb_store_8((offset + 0u), value[0u]);
+ sb_store_8((offset + 4u), value[1u]);
+}
+
+fn sb_store_27(offset : u32, value : mat2x3<f16>) {
+ sb_store_12((offset + 0u), value[0u]);
+ sb_store_12((offset + 8u), value[1u]);
+}
+
+fn sb_store_28(offset : u32, value : mat2x4<f16>) {
+ sb_store_16((offset + 0u), value[0u]);
+ sb_store_16((offset + 8u), value[1u]);
+}
+
+fn sb_store_29(offset : u32, value : mat3x2<f16>) {
+ sb_store_8((offset + 0u), value[0u]);
+ sb_store_8((offset + 4u), value[1u]);
+ sb_store_8((offset + 8u), value[2u]);
+}
+
+fn sb_store_30(offset : u32, value : mat3x3<f16>) {
+ sb_store_12((offset + 0u), value[0u]);
+ sb_store_12((offset + 8u), value[1u]);
+ sb_store_12((offset + 16u), value[2u]);
+}
+
+fn sb_store_31(offset : u32, value : mat3x4<f16>) {
+ sb_store_16((offset + 0u), value[0u]);
+ sb_store_16((offset + 8u), value[1u]);
+ sb_store_16((offset + 16u), value[2u]);
+}
+
+fn sb_store_32(offset : u32, value : mat4x2<f16>) {
+ sb_store_8((offset + 0u), value[0u]);
+ sb_store_8((offset + 4u), value[1u]);
+ sb_store_8((offset + 8u), value[2u]);
+ sb_store_8((offset + 12u), value[3u]);
+}
+
+fn sb_store_33(offset : u32, value : mat4x3<f16>) {
+ sb_store_12((offset + 0u), value[0u]);
+ sb_store_12((offset + 8u), value[1u]);
+ sb_store_12((offset + 16u), value[2u]);
+ sb_store_12((offset + 24u), value[3u]);
+}
+
+fn sb_store_34(offset : u32, value : mat4x4<f16>) {
+ sb_store_16((offset + 0u), value[0u]);
+ sb_store_16((offset + 8u), value[1u]);
+ sb_store_16((offset + 16u), value[2u]);
+ sb_store_16((offset + 24u), value[3u]);
+}
+
+fn sb_store_35(offset : u32, value : array<vec3<f32>, 2u>) {
+ var array_1 = value;
+ for(var i = 0u; (i < 2u); i = (i + 1u)) {
+ sb_store_9((offset + (i * 16u)), array_1[i]);
+ }
+}
+
+fn sb_store_36(offset : u32, value : array<mat4x2<f16>, 2u>) {
+ var array_2 = value;
+ for(var i_1 = 0u; (i_1 < 2u); i_1 = (i_1 + 1u)) {
+ sb_store_32((offset + (i_1 * 16u)), array_2[i_1]);
+ }
+}
+
+fn sb_store(offset : u32, value : SB) {
+ sb_store_1((offset + 0u), value.scalar_f32);
+ sb_store_2((offset + 4u), value.scalar_i32);
+ sb_store_3((offset + 8u), value.scalar_u32);
+ sb_store_4((offset + 12u), value.scalar_f16);
+ sb_store_5((offset + 16u), value.vec2_f32);
+ sb_store_6((offset + 24u), value.vec2_i32);
+ sb_store_7((offset + 32u), value.vec2_u32);
+ sb_store_8((offset + 40u), value.vec2_f16);
+ sb_store_9((offset + 48u), value.vec3_f32);
+ sb_store_10((offset + 64u), value.vec3_i32);
+ sb_store_11((offset + 80u), value.vec3_u32);
+ sb_store_12((offset + 96u), value.vec3_f16);
+ sb_store_13((offset + 112u), value.vec4_f32);
+ sb_store_14((offset + 128u), value.vec4_i32);
+ sb_store_15((offset + 144u), value.vec4_u32);
+ sb_store_16((offset + 160u), value.vec4_f16);
+ sb_store_17((offset + 168u), value.mat2x2_f32);
+ sb_store_18((offset + 192u), value.mat2x3_f32);
+ sb_store_19((offset + 224u), value.mat2x4_f32);
+ sb_store_20((offset + 256u), value.mat3x2_f32);
+ sb_store_21((offset + 288u), value.mat3x3_f32);
+ sb_store_22((offset + 336u), value.mat3x4_f32);
+ sb_store_23((offset + 384u), value.mat4x2_f32);
+ sb_store_24((offset + 416u), value.mat4x3_f32);
+ sb_store_25((offset + 480u), value.mat4x4_f32);
+ sb_store_26((offset + 544u), value.mat2x2_f16);
+ sb_store_27((offset + 552u), value.mat2x3_f16);
+ sb_store_28((offset + 568u), value.mat2x4_f16);
+ sb_store_29((offset + 584u), value.mat3x2_f16);
+ sb_store_30((offset + 600u), value.mat3x3_f16);
+ sb_store_31((offset + 624u), value.mat3x4_f16);
+ sb_store_32((offset + 648u), value.mat4x2_f16);
+ sb_store_33((offset + 664u), value.mat4x3_f16);
+ sb_store_34((offset + 696u), value.mat4x4_f16);
+ sb_store_35((offset + 736u), value.arr2_vec3_f32);
+ sb_store_36((offset + 768u), value.arr2_mat4x2_f16);
+}
+
+@compute @workgroup_size(1)
+fn main() {
+ sb_store(0u, SB());
+}
+)";
+
+ auto got = Run<DecomposeMemoryAccess>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeMemoryAccessTest, StoreStructure_OutOfOrder) {  // whole-struct store; main() comes first
+ auto* src = R"(
+enable f16;
+
+@compute @workgroup_size(1)
+fn main() {
+ sb = SB();
+}
+
+@group(0) @binding(0) var<storage, read_write> sb : SB;
+
+struct SB {
+ scalar_f32 : f32,
+ scalar_i32 : i32,
+ scalar_u32 : u32,
+ scalar_f16 : f16,
+ vec2_f32 : vec2<f32>,
+ vec2_i32 : vec2<i32>,
+ vec2_u32 : vec2<u32>,
+ vec2_f16 : vec2<f16>,
+ vec3_f32 : vec3<f32>,
+ vec3_i32 : vec3<i32>,
+ vec3_u32 : vec3<u32>,
+ vec3_f16 : vec3<f16>,
+ vec4_f32 : vec4<f32>,
+ vec4_i32 : vec4<i32>,
+ vec4_u32 : vec4<u32>,
+ vec4_f16 : vec4<f16>,
+ mat2x2_f32 : mat2x2<f32>,
+ mat2x3_f32 : mat2x3<f32>,
+ mat2x4_f32 : mat2x4<f32>,
+ mat3x2_f32 : mat3x2<f32>,
+ mat3x3_f32 : mat3x3<f32>,
+ mat3x4_f32 : mat3x4<f32>,
+ mat4x2_f32 : mat4x2<f32>,
+ mat4x3_f32 : mat4x3<f32>,
+ mat4x4_f32 : mat4x4<f32>,
+ mat2x2_f16 : mat2x2<f16>,
+ mat2x3_f16 : mat2x3<f16>,
+ mat2x4_f16 : mat2x4<f16>,
+ mat3x2_f16 : mat3x2<f16>,
+ mat3x3_f16 : mat3x3<f16>,
+ mat3x4_f16 : mat3x4<f16>,
+ mat4x2_f16 : mat4x2<f16>,
+ mat4x3_f16 : mat4x3<f16>,
+ mat4x4_f16 : mat4x4<f16>,
+ arr2_vec3_f32 : array<vec3<f32>, 2>,
+ arr2_mat4x2_f16 : array<mat4x2<f16>, 2>,
+};
+)";
+ // Expected: `sb = SB()` becomes sb_store(0u, SB()), backed by store helpers emitted ahead of main().
+ auto* expect = R"(
+enable f16;
+
+@internal(intrinsic_store_storage_f32) @internal(disable_validation__function_has_no_body)
+fn sb_store_1(offset : u32, value : f32)
+
+@internal(intrinsic_store_storage_i32) @internal(disable_validation__function_has_no_body)
+fn sb_store_2(offset : u32, value : i32)
+
+@internal(intrinsic_store_storage_u32) @internal(disable_validation__function_has_no_body)
+fn sb_store_3(offset : u32, value : u32)
+
+@internal(intrinsic_store_storage_f16) @internal(disable_validation__function_has_no_body)
+fn sb_store_4(offset : u32, value : f16)
+
+@internal(intrinsic_store_storage_vec2_f32) @internal(disable_validation__function_has_no_body)
+fn sb_store_5(offset : u32, value : vec2<f32>)
+
+@internal(intrinsic_store_storage_vec2_i32) @internal(disable_validation__function_has_no_body)
+fn sb_store_6(offset : u32, value : vec2<i32>)
+
+@internal(intrinsic_store_storage_vec2_u32) @internal(disable_validation__function_has_no_body)
+fn sb_store_7(offset : u32, value : vec2<u32>)
+
+@internal(intrinsic_store_storage_vec2_f16) @internal(disable_validation__function_has_no_body)
+fn sb_store_8(offset : u32, value : vec2<f16>)
+
+@internal(intrinsic_store_storage_vec3_f32) @internal(disable_validation__function_has_no_body)
+fn sb_store_9(offset : u32, value : vec3<f32>)
+
+@internal(intrinsic_store_storage_vec3_i32) @internal(disable_validation__function_has_no_body)
+fn sb_store_10(offset : u32, value : vec3<i32>)
+
+@internal(intrinsic_store_storage_vec3_u32) @internal(disable_validation__function_has_no_body)
+fn sb_store_11(offset : u32, value : vec3<u32>)
+
+@internal(intrinsic_store_storage_vec3_f16) @internal(disable_validation__function_has_no_body)
+fn sb_store_12(offset : u32, value : vec3<f16>)
+
+@internal(intrinsic_store_storage_vec4_f32) @internal(disable_validation__function_has_no_body)
+fn sb_store_13(offset : u32, value : vec4<f32>)
+
+@internal(intrinsic_store_storage_vec4_i32) @internal(disable_validation__function_has_no_body)
+fn sb_store_14(offset : u32, value : vec4<i32>)
+
+@internal(intrinsic_store_storage_vec4_u32) @internal(disable_validation__function_has_no_body)
+fn sb_store_15(offset : u32, value : vec4<u32>)
+
+@internal(intrinsic_store_storage_vec4_f16) @internal(disable_validation__function_has_no_body)
+fn sb_store_16(offset : u32, value : vec4<f16>)
+
+fn sb_store_17(offset : u32, value : mat2x2<f32>) {
+ sb_store_5((offset + 0u), value[0u]);
+ sb_store_5((offset + 8u), value[1u]);
+}
+
+fn sb_store_18(offset : u32, value : mat2x3<f32>) {
+ sb_store_9((offset + 0u), value[0u]);
+ sb_store_9((offset + 16u), value[1u]);
+}
+
+fn sb_store_19(offset : u32, value : mat2x4<f32>) {
+ sb_store_13((offset + 0u), value[0u]);
+ sb_store_13((offset + 16u), value[1u]);
+}
+
+fn sb_store_20(offset : u32, value : mat3x2<f32>) {
+ sb_store_5((offset + 0u), value[0u]);
+ sb_store_5((offset + 8u), value[1u]);
+ sb_store_5((offset + 16u), value[2u]);
+}
+
+fn sb_store_21(offset : u32, value : mat3x3<f32>) {
+ sb_store_9((offset + 0u), value[0u]);
+ sb_store_9((offset + 16u), value[1u]);
+ sb_store_9((offset + 32u), value[2u]);
+}
+
+fn sb_store_22(offset : u32, value : mat3x4<f32>) {
+ sb_store_13((offset + 0u), value[0u]);
+ sb_store_13((offset + 16u), value[1u]);
+ sb_store_13((offset + 32u), value[2u]);
+}
+
+fn sb_store_23(offset : u32, value : mat4x2<f32>) {
+ sb_store_5((offset + 0u), value[0u]);
+ sb_store_5((offset + 8u), value[1u]);
+ sb_store_5((offset + 16u), value[2u]);
+ sb_store_5((offset + 24u), value[3u]);
+}
+
+fn sb_store_24(offset : u32, value : mat4x3<f32>) {
+ sb_store_9((offset + 0u), value[0u]);
+ sb_store_9((offset + 16u), value[1u]);
+ sb_store_9((offset + 32u), value[2u]);
+ sb_store_9((offset + 48u), value[3u]);
+}
+
+fn sb_store_25(offset : u32, value : mat4x4<f32>) {
+ sb_store_13((offset + 0u), value[0u]);
+ sb_store_13((offset + 16u), value[1u]);
+ sb_store_13((offset + 32u), value[2u]);
+ sb_store_13((offset + 48u), value[3u]);
+}
+
+fn sb_store_26(offset : u32, value : mat2x2<f16>) {
+ sb_store_8((offset + 0u), value[0u]);
+ sb_store_8((offset + 4u), value[1u]);
+}
+
+fn sb_store_27(offset : u32, value : mat2x3<f16>) {
+ sb_store_12((offset + 0u), value[0u]);
+ sb_store_12((offset + 8u), value[1u]);
+}
+
+fn sb_store_28(offset : u32, value : mat2x4<f16>) {
+ sb_store_16((offset + 0u), value[0u]);
+ sb_store_16((offset + 8u), value[1u]);
+}
+
+fn sb_store_29(offset : u32, value : mat3x2<f16>) {
+ sb_store_8((offset + 0u), value[0u]);
+ sb_store_8((offset + 4u), value[1u]);
+ sb_store_8((offset + 8u), value[2u]);
+}
+
+fn sb_store_30(offset : u32, value : mat3x3<f16>) {
+ sb_store_12((offset + 0u), value[0u]);
+ sb_store_12((offset + 8u), value[1u]);
+ sb_store_12((offset + 16u), value[2u]);
+}
+
+fn sb_store_31(offset : u32, value : mat3x4<f16>) {
+ sb_store_16((offset + 0u), value[0u]);
+ sb_store_16((offset + 8u), value[1u]);
+ sb_store_16((offset + 16u), value[2u]);
+}
+
+fn sb_store_32(offset : u32, value : mat4x2<f16>) {
+ sb_store_8((offset + 0u), value[0u]);
+ sb_store_8((offset + 4u), value[1u]);
+ sb_store_8((offset + 8u), value[2u]);
+ sb_store_8((offset + 12u), value[3u]);
+}
+
+fn sb_store_33(offset : u32, value : mat4x3<f16>) {
+ sb_store_12((offset + 0u), value[0u]);
+ sb_store_12((offset + 8u), value[1u]);
+ sb_store_12((offset + 16u), value[2u]);
+ sb_store_12((offset + 24u), value[3u]);
+}
+
+fn sb_store_34(offset : u32, value : mat4x4<f16>) {
+ sb_store_16((offset + 0u), value[0u]);
+ sb_store_16((offset + 8u), value[1u]);
+ sb_store_16((offset + 16u), value[2u]);
+ sb_store_16((offset + 24u), value[3u]);
+}
+
+fn sb_store_35(offset : u32, value : array<vec3<f32>, 2u>) {
+ var array_1 = value;
+ for(var i = 0u; (i < 2u); i = (i + 1u)) {
+ sb_store_9((offset + (i * 16u)), array_1[i]);
+ }
+}
+
+fn sb_store_36(offset : u32, value : array<mat4x2<f16>, 2u>) {
+ var array_2 = value;
+ for(var i_1 = 0u; (i_1 < 2u); i_1 = (i_1 + 1u)) {
+ sb_store_32((offset + (i_1 * 16u)), array_2[i_1]);
+ }
+}
+
+fn sb_store(offset : u32, value : SB) {
+ sb_store_1((offset + 0u), value.scalar_f32);
+ sb_store_2((offset + 4u), value.scalar_i32);
+ sb_store_3((offset + 8u), value.scalar_u32);
+ sb_store_4((offset + 12u), value.scalar_f16);
+ sb_store_5((offset + 16u), value.vec2_f32);
+ sb_store_6((offset + 24u), value.vec2_i32);
+ sb_store_7((offset + 32u), value.vec2_u32);
+ sb_store_8((offset + 40u), value.vec2_f16);
+ sb_store_9((offset + 48u), value.vec3_f32);
+ sb_store_10((offset + 64u), value.vec3_i32);
+ sb_store_11((offset + 80u), value.vec3_u32);
+ sb_store_12((offset + 96u), value.vec3_f16);
+ sb_store_13((offset + 112u), value.vec4_f32);
+ sb_store_14((offset + 128u), value.vec4_i32);
+ sb_store_15((offset + 144u), value.vec4_u32);
+ sb_store_16((offset + 160u), value.vec4_f16);
+ sb_store_17((offset + 168u), value.mat2x2_f32);
+ sb_store_18((offset + 192u), value.mat2x3_f32);
+ sb_store_19((offset + 224u), value.mat2x4_f32);
+ sb_store_20((offset + 256u), value.mat3x2_f32);
+ sb_store_21((offset + 288u), value.mat3x3_f32);
+ sb_store_22((offset + 336u), value.mat3x4_f32);
+ sb_store_23((offset + 384u), value.mat4x2_f32);
+ sb_store_24((offset + 416u), value.mat4x3_f32);
+ sb_store_25((offset + 480u), value.mat4x4_f32);
+ sb_store_26((offset + 544u), value.mat2x2_f16);
+ sb_store_27((offset + 552u), value.mat2x3_f16);
+ sb_store_28((offset + 568u), value.mat2x4_f16);
+ sb_store_29((offset + 584u), value.mat3x2_f16);
+ sb_store_30((offset + 600u), value.mat3x3_f16);
+ sb_store_31((offset + 624u), value.mat3x4_f16);
+ sb_store_32((offset + 648u), value.mat4x2_f16);
+ sb_store_33((offset + 664u), value.mat4x3_f16);
+ sb_store_34((offset + 696u), value.mat4x4_f16);
+ sb_store_35((offset + 736u), value.arr2_vec3_f32);
+ sb_store_36((offset + 768u), value.arr2_mat4x2_f16);
+}
+
+@compute @workgroup_size(1)
+fn main() {
+ sb_store(0u, SB());
+}
+
+@group(0) @binding(0) var<storage, read_write> sb : SB;
+
+struct SB {
+ scalar_f32 : f32,
+ scalar_i32 : i32,
+ scalar_u32 : u32,
+ scalar_f16 : f16,
+ vec2_f32 : vec2<f32>,
+ vec2_i32 : vec2<i32>,
+ vec2_u32 : vec2<u32>,
+ vec2_f16 : vec2<f16>,
+ vec3_f32 : vec3<f32>,
+ vec3_i32 : vec3<i32>,
+ vec3_u32 : vec3<u32>,
+ vec3_f16 : vec3<f16>,
+ vec4_f32 : vec4<f32>,
+ vec4_i32 : vec4<i32>,
+ vec4_u32 : vec4<u32>,
+ vec4_f16 : vec4<f16>,
+ mat2x2_f32 : mat2x2<f32>,
+ mat2x3_f32 : mat2x3<f32>,
+ mat2x4_f32 : mat2x4<f32>,
+ mat3x2_f32 : mat3x2<f32>,
+ mat3x3_f32 : mat3x3<f32>,
+ mat3x4_f32 : mat3x4<f32>,
+ mat4x2_f32 : mat4x2<f32>,
+ mat4x3_f32 : mat4x3<f32>,
+ mat4x4_f32 : mat4x4<f32>,
+ mat2x2_f16 : mat2x2<f16>,
+ mat2x3_f16 : mat2x3<f16>,
+ mat2x4_f16 : mat2x4<f16>,
+ mat3x2_f16 : mat3x2<f16>,
+ mat3x3_f16 : mat3x3<f16>,
+ mat3x4_f16 : mat3x4<f16>,
+ mat4x2_f16 : mat4x2<f16>,
+ mat4x3_f16 : mat4x3<f16>,
+ mat4x4_f16 : mat4x4<f16>,
+ arr2_vec3_f32 : array<vec3<f32>, 2>,
+ arr2_mat4x2_f16 : array<mat4x2<f16>, 2>,
+}
+)";
+
+ auto got = Run<DecomposeMemoryAccess>(src);
+
+ EXPECT_EQ(expect, str(got));  // exact text match; note the output drops the `;` after the struct
+}
+
+TEST_F(DecomposeMemoryAccessTest, ComplexStaticAccessChain) {  // constant indices fold into one offset
+ auto* src = R"(
+// sizeof(S1) == 32
+// alignof(S1) == 16
+struct S1 {
+ a : i32,
+ b : vec3<f32>,
+ c : i32,
+};
+
+// sizeof(S2) == 116
+// alignof(S2) == 16
+struct S2 {
+ a : i32,
+ b : array<S1, 3>,
+ c : i32,
+};
+
+struct SB {
+ @size(128)
+ a : i32,
+ b : array<S2>,
+};
+
+@group(0) @binding(0) var<storage, read_write> sb : SB;
+
+@compute @workgroup_size(1)
+fn main() {
+ var x : f32 = sb.b[4].b[1].b.z;
+}
+)";
+
+ // Byte offsets along sb.b[4].b[1].b.z:
+ // sb.b -> 128, [4] -> 640 (+ 4 * 128), .b -> 656 (+ 16),
+ // [1] -> 688 (+ 1 * 32), .b -> 704 (+ 16), .z -> 712 (+ 2 * 4).
+ // NOTE(review): the `sizeof(S2) == 116` remark in src looks like the size
+ // before rounding up to the 16-byte alignment; the offsets above imply a
+ // 128-byte element stride for array<S2> - confirm before relying on it.
+
+ auto* expect = R"(
+struct S1 {
+ a : i32,
+ b : vec3<f32>,
+ c : i32,
+}
+
+struct S2 {
+ a : i32,
+ b : array<S1, 3>,
+ c : i32,
+}
+
+struct SB {
+ @size(128)
+ a : i32,
+ b : array<S2>,
+}
+
+@group(0) @binding(0) var<storage, read_write> sb : SB;
+
+@internal(intrinsic_load_storage_f32) @internal(disable_validation__function_has_no_body)
+fn sb_load(offset : u32) -> f32
+
+@compute @workgroup_size(1)
+fn main() {
+ var x : f32 = sb_load(712u);
+}
+)";
+
+ auto got = Run<DecomposeMemoryAccess>(src);
+
+ EXPECT_EQ(expect, str(got));  // the `//` remarks in src do not appear in the expected output
+}
+
+TEST_F(DecomposeMemoryAccessTest, ComplexStaticAccessChain_OutOfOrder) {  // same chain, uses before decls
+ auto* src = R"(
+@compute @workgroup_size(1)
+fn main() {
+ var x : f32 = sb.b[4].b[1].b.z;
+}
+
+@group(0) @binding(0) var<storage, read_write> sb : SB;
+
+struct SB {
+ @size(128)
+ a : i32,
+ b : array<S2>,
+};
+
+struct S2 {
+ a : i32,
+ b : array<S1, 3>,
+ c : i32,
+};
+
+struct S1 {
+ a : i32,
+ b : vec3<f32>,
+ c : i32,
+};
+)";
+
+ // Byte offsets along sb.b[4].b[1].b.z:
+ // sb.b starts at 128 (SB.a is declared with @size(128)),
+ // [4] -> 640 (+ 4 * 128),
+ // .b -> 656 (+ 16),
+ // [1] -> 688 (+ 1 * 32),
+ // .b -> 704 (+ 16), and .z -> 712 (+ 2 * 4).
+
+ auto* expect = R"(
+@internal(intrinsic_load_storage_f32) @internal(disable_validation__function_has_no_body)
+fn sb_load(offset : u32) -> f32
+
+@compute @workgroup_size(1)
+fn main() {
+ var x : f32 = sb_load(712u);
+}
+
+@group(0) @binding(0) var<storage, read_write> sb : SB;
+
+struct SB {
+ @size(128)
+ a : i32,
+ b : array<S2>,
+}
+
+struct S2 {
+ a : i32,
+ b : array<S1, 3>,
+ c : i32,
+}
+
+struct S1 {
+ a : i32,
+ b : vec3<f32>,
+ c : i32,
+}
+)";
+
+ auto got = Run<DecomposeMemoryAccess>(src);
+
+ EXPECT_EQ(expect, str(got));  // sb_load is expected ahead of main(); declaration order is otherwise kept
+}
+
+TEST_F(DecomposeMemoryAccessTest, ComplexDynamicAccessChain) {  // runtime indices i, j, k
+ auto* src = R"(
+struct S1 {
+ a : i32,
+ b : vec3<f32>,
+ c : i32,
+};
+
+struct S2 {
+ a : i32,
+ b : array<S1, 3>,
+ c : i32,
+};
+
+struct SB {
+ @size(128)
+ a : i32,
+ b : array<S2>,
+};
+
+@group(0) @binding(0) var<storage, read_write> sb : SB;
+
+@compute @workgroup_size(1)
+fn main() {
+ var i : i32 = 4;
+ var j : u32 = 1u;
+ var k : i32 = 2;
+ var x : f32 = sb.b[i].b[j].b[k];
+}
+)";
+ // Expected: i, j, k remain runtime terms scaled by 128, 32 and 4; the i32 indices are wrapped in u32().
+ auto* expect = R"(
+struct S1 {
+ a : i32,
+ b : vec3<f32>,
+ c : i32,
+}
+
+struct S2 {
+ a : i32,
+ b : array<S1, 3>,
+ c : i32,
+}
+
+struct SB {
+ @size(128)
+ a : i32,
+ b : array<S2>,
+}
+
+@group(0) @binding(0) var<storage, read_write> sb : SB;
+
+@internal(intrinsic_load_storage_f32) @internal(disable_validation__function_has_no_body)
+fn sb_load(offset : u32) -> f32
+
+@compute @workgroup_size(1)
+fn main() {
+ var i : i32 = 4;
+ var j : u32 = 1u;
+ var k : i32 = 2;
+ var x : f32 = sb_load((((((128u + (128u * u32(i))) + 16u) + (32u * j)) + 16u) + (4u * u32(k))));
+}
+)";
+
+ auto got = Run<DecomposeMemoryAccess>(src);
+
+ EXPECT_EQ(expect, str(got));  // exact text match against the expected WGSL
+}
+
+TEST_F(DecomposeMemoryAccessTest, ComplexDynamicAccessChain_OutOfOrder) {  // same chain, uses before decls
+ auto* src = R"(
+@compute @workgroup_size(1)
+fn main() {
+ var i : i32 = 4;
+ var j : u32 = 1u;
+ var k : i32 = 2;
+ var x : f32 = sb.b[i].b[j].b[k];
+}
+
+@group(0) @binding(0) var<storage, read_write> sb : SB;
+
+struct SB {
+ @size(128)
+ a : i32,
+ b : array<S2>
+};
+
+struct S2 {
+ a : i32,
+ b : array<S1, 3>,
+ c : i32,
+};
+
+struct S1 {
+ a : i32,
+ b : vec3<f32>,
+ c : i32,
+};
+)";
+ // Expected: the same offset expression as the in-order test, with sb_load emitted ahead of main().
+ auto* expect = R"(
+@internal(intrinsic_load_storage_f32) @internal(disable_validation__function_has_no_body)
+fn sb_load(offset : u32) -> f32
+
+@compute @workgroup_size(1)
+fn main() {
+ var i : i32 = 4;
+ var j : u32 = 1u;
+ var k : i32 = 2;
+ var x : f32 = sb_load((((((128u + (128u * u32(i))) + 16u) + (32u * j)) + 16u) + (4u * u32(k))));
+}
+
+@group(0) @binding(0) var<storage, read_write> sb : SB;
+
+struct SB {
+ @size(128)
+ a : i32,
+ b : array<S2>,
+}
+
+struct S2 {
+ a : i32,
+ b : array<S1, 3>,
+ c : i32,
+}
+
+struct S1 {
+ a : i32,
+ b : vec3<f32>,
+ c : i32,
+}
+)";
+
+ auto got = Run<DecomposeMemoryAccess>(src);
+
+ EXPECT_EQ(expect, str(got));  // note: expected output has a trailing comma after SB.b that src omits
+}
+
+TEST_F(DecomposeMemoryAccessTest, ComplexDynamicAccessChainWithAliases) {  // member types named via alias
+ auto* src = R"(
+struct S1 {
+ a : i32,
+ b : vec3<f32>,
+ c : i32,
+};
+
+alias A1 = S1;
+
+alias A1_Array = array<S1, 3>;
+
+struct S2 {
+ a : i32,
+ b : A1_Array,
+ c : i32,
+};
+
+alias A2 = S2;
+
+alias A2_Array = array<S2>;
+
+struct SB {
+ @size(128)
+ a : i32,
+ b : A2_Array,
+};
+
+@group(0) @binding(0) var<storage, read_write> sb : SB;
+
+@compute @workgroup_size(1)
+fn main() {
+ var i : i32 = 4;
+ var j : u32 = 1u;
+ var k : i32 = 2;
+ var x : f32 = sb.b[i].b[j].b[k];
+}
+)";
+ // Expected: alias declarations are kept as written and the offset expression matches the alias-free test.
+ auto* expect = R"(
+struct S1 {
+ a : i32,
+ b : vec3<f32>,
+ c : i32,
+}
+
+alias A1 = S1;
+
+alias A1_Array = array<S1, 3>;
+
+struct S2 {
+ a : i32,
+ b : A1_Array,
+ c : i32,
+}
+
+alias A2 = S2;
+
+alias A2_Array = array<S2>;
+
+struct SB {
+ @size(128)
+ a : i32,
+ b : A2_Array,
+}
+
+@group(0) @binding(0) var<storage, read_write> sb : SB;
+
+@internal(intrinsic_load_storage_f32) @internal(disable_validation__function_has_no_body)
+fn sb_load(offset : u32) -> f32
+
+@compute @workgroup_size(1)
+fn main() {
+ var i : i32 = 4;
+ var j : u32 = 1u;
+ var k : i32 = 2;
+ var x : f32 = sb_load((((((128u + (128u * u32(i))) + 16u) + (32u * j)) + 16u) + (4u * u32(k))));
+}
+)";
+
+ auto got = Run<DecomposeMemoryAccess>(src);
+
+ EXPECT_EQ(expect, str(got));  // exact text match against the expected WGSL
+}
+
+TEST_F(DecomposeMemoryAccessTest, ComplexDynamicAccessChainWithAliases_OutOfOrder) {  // uses before decls
+ auto* src = R"(
+@compute @workgroup_size(1)
+fn main() {
+ var i : i32 = 4;
+ var j : u32 = 1u;
+ var k : i32 = 2;
+ var x : f32 = sb.b[i].b[j].b[k];
+}
+
+@group(0) @binding(0) var<storage, read_write> sb : SB;
+
+struct SB {
+ @size(128)
+ a : i32,
+ b : A2_Array,
+};
+
+alias A2_Array = array<S2>;
+
+alias A2 = S2;
+
+struct S2 {
+ a : i32,
+ b : A1_Array,
+ c : i32,
+};
+
+alias A1 = S1;
+
+alias A1_Array = array<S1, 3>;
+
+struct S1 {
+ a : i32,
+ b : vec3<f32>,
+ c : i32,
+};
+)";
+ // Expected: sb_load is emitted ahead of main(); aliases and structs stay in their declared order.
+ auto* expect = R"(
+@internal(intrinsic_load_storage_f32) @internal(disable_validation__function_has_no_body)
+fn sb_load(offset : u32) -> f32
+
+@compute @workgroup_size(1)
+fn main() {
+ var i : i32 = 4;
+ var j : u32 = 1u;
+ var k : i32 = 2;
+ var x : f32 = sb_load((((((128u + (128u * u32(i))) + 16u) + (32u * j)) + 16u) + (4u * u32(k))));
+}
+
+@group(0) @binding(0) var<storage, read_write> sb : SB;
+
+struct SB {
+ @size(128)
+ a : i32,
+ b : A2_Array,
+}
+
+alias A2_Array = array<S2>;
+
+alias A2 = S2;
+
+struct S2 {
+ a : i32,
+ b : A1_Array,
+ c : i32,
+}
+
+alias A1 = S1;
+
+alias A1_Array = array<S1, 3>;
+
+struct S1 {
+ a : i32,
+ b : vec3<f32>,
+ c : i32,
+}
+)";
+
+ auto got = Run<DecomposeMemoryAccess>(src);
+
+ EXPECT_EQ(expect, str(got));  // exact text match against the expected WGSL
+}
+
+TEST_F(DecomposeMemoryAccessTest, StorageBufferAtomics) {  // atomic builtins on storage-buffer members
+ auto* src = R"(
+struct SB {
+ padding : vec4<f32>,
+ a : atomic<i32>,
+ b : atomic<u32>,
+};
+
+@group(0) @binding(0) var<storage, read_write> sb : SB;
+
+@compute @workgroup_size(1)
+fn main() {
+ atomicStore(&sb.a, 123);
+ atomicLoad(&sb.a);
+ atomicAdd(&sb.a, 123);
+ atomicSub(&sb.a, 123);
+ atomicMax(&sb.a, 123);
+ atomicMin(&sb.a, 123);
+ atomicAnd(&sb.a, 123);
+ atomicOr(&sb.a, 123);
+ atomicXor(&sb.a, 123);
+ atomicExchange(&sb.a, 123);
+ atomicCompareExchangeWeak(&sb.a, 123, 345);
+
+ atomicStore(&sb.b, 123u);
+ atomicLoad(&sb.b);
+ atomicAdd(&sb.b, 123u);
+ atomicSub(&sb.b, 123u);
+ atomicMax(&sb.b, 123u);
+ atomicMin(&sb.b, 123u);
+ atomicAnd(&sb.b, 123u);
+ atomicOr(&sb.b, 123u);
+ atomicXor(&sb.b, 123u);
+ atomicExchange(&sb.b, 123u);
+ atomicCompareExchangeWeak(&sb.b, 123u, 345u);
+}
+)";
+ // Expected: calls on sb.a (offset 16) / sb.b (offset 20) go to body-less sbatomic* functions; u32 ones end in _1.
+ auto* expect = R"(
+struct SB {
+ padding : vec4<f32>,
+ a : atomic<i32>,
+ b : atomic<u32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> sb : SB;
+
+@internal(intrinsic_atomic_store_storage_i32) @internal(disable_validation__function_has_no_body)
+fn sbatomicStore(offset : u32, param_1 : i32)
+
+@internal(intrinsic_atomic_load_storage_i32) @internal(disable_validation__function_has_no_body)
+fn sbatomicLoad(offset : u32) -> i32
+
+@internal(intrinsic_atomic_add_storage_i32) @internal(disable_validation__function_has_no_body)
+fn sbatomicAdd(offset : u32, param_1 : i32) -> i32
+
+@internal(intrinsic_atomic_sub_storage_i32) @internal(disable_validation__function_has_no_body)
+fn sbatomicSub(offset : u32, param_1 : i32) -> i32
+
+@internal(intrinsic_atomic_max_storage_i32) @internal(disable_validation__function_has_no_body)
+fn sbatomicMax(offset : u32, param_1 : i32) -> i32
+
+@internal(intrinsic_atomic_min_storage_i32) @internal(disable_validation__function_has_no_body)
+fn sbatomicMin(offset : u32, param_1 : i32) -> i32
+
+@internal(intrinsic_atomic_and_storage_i32) @internal(disable_validation__function_has_no_body)
+fn sbatomicAnd(offset : u32, param_1 : i32) -> i32
+
+@internal(intrinsic_atomic_or_storage_i32) @internal(disable_validation__function_has_no_body)
+fn sbatomicOr(offset : u32, param_1 : i32) -> i32
+
+@internal(intrinsic_atomic_xor_storage_i32) @internal(disable_validation__function_has_no_body)
+fn sbatomicXor(offset : u32, param_1 : i32) -> i32
+
+@internal(intrinsic_atomic_exchange_storage_i32) @internal(disable_validation__function_has_no_body)
+fn sbatomicExchange(offset : u32, param_1 : i32) -> i32
+
+@internal(intrinsic_atomic_compare_exchange_weak_storage_i32) @internal(disable_validation__function_has_no_body)
+fn sbatomicCompareExchangeWeak(offset : u32, param_1 : i32, param_2 : i32) -> __atomic_compare_exchange_result_i32
+
+@internal(intrinsic_atomic_store_storage_u32) @internal(disable_validation__function_has_no_body)
+fn sbatomicStore_1(offset : u32, param_1 : u32)
+
+@internal(intrinsic_atomic_load_storage_u32) @internal(disable_validation__function_has_no_body)
+fn sbatomicLoad_1(offset : u32) -> u32
+
+@internal(intrinsic_atomic_add_storage_u32) @internal(disable_validation__function_has_no_body)
+fn sbatomicAdd_1(offset : u32, param_1 : u32) -> u32
+
+@internal(intrinsic_atomic_sub_storage_u32) @internal(disable_validation__function_has_no_body)
+fn sbatomicSub_1(offset : u32, param_1 : u32) -> u32
+
+@internal(intrinsic_atomic_max_storage_u32) @internal(disable_validation__function_has_no_body)
+fn sbatomicMax_1(offset : u32, param_1 : u32) -> u32
+
+@internal(intrinsic_atomic_min_storage_u32) @internal(disable_validation__function_has_no_body)
+fn sbatomicMin_1(offset : u32, param_1 : u32) -> u32
+
+@internal(intrinsic_atomic_and_storage_u32) @internal(disable_validation__function_has_no_body)
+fn sbatomicAnd_1(offset : u32, param_1 : u32) -> u32
+
+@internal(intrinsic_atomic_or_storage_u32) @internal(disable_validation__function_has_no_body)
+fn sbatomicOr_1(offset : u32, param_1 : u32) -> u32
+
+@internal(intrinsic_atomic_xor_storage_u32) @internal(disable_validation__function_has_no_body)
+fn sbatomicXor_1(offset : u32, param_1 : u32) -> u32
+
+@internal(intrinsic_atomic_exchange_storage_u32) @internal(disable_validation__function_has_no_body)
+fn sbatomicExchange_1(offset : u32, param_1 : u32) -> u32
+
+@internal(intrinsic_atomic_compare_exchange_weak_storage_u32) @internal(disable_validation__function_has_no_body)
+fn sbatomicCompareExchangeWeak_1(offset : u32, param_1 : u32, param_2 : u32) -> __atomic_compare_exchange_result_u32
+
+@compute @workgroup_size(1)
+fn main() {
+ sbatomicStore(16u, 123);
+ sbatomicLoad(16u);
+ sbatomicAdd(16u, 123);
+ sbatomicSub(16u, 123);
+ sbatomicMax(16u, 123);
+ sbatomicMin(16u, 123);
+ sbatomicAnd(16u, 123);
+ sbatomicOr(16u, 123);
+ sbatomicXor(16u, 123);
+ sbatomicExchange(16u, 123);
+ sbatomicCompareExchangeWeak(16u, 123, 345);
+ sbatomicStore_1(20u, 123u);
+ sbatomicLoad_1(20u);
+ sbatomicAdd_1(20u, 123u);
+ sbatomicSub_1(20u, 123u);
+ sbatomicMax_1(20u, 123u);
+ sbatomicMin_1(20u, 123u);
+ sbatomicAnd_1(20u, 123u);
+ sbatomicOr_1(20u, 123u);
+ sbatomicXor_1(20u, 123u);
+ sbatomicExchange_1(20u, 123u);
+ sbatomicCompareExchangeWeak_1(20u, 123u, 345u);
+}
+)";
+
+ auto got = Run<DecomposeMemoryAccess>(src);
+
+ EXPECT_EQ(expect, str(got));  // exact text match; the blank line between the two call groups is not kept
+}
+
+TEST_F(DecomposeMemoryAccessTest, StorageBufferAtomics_OutOfOrder) {  // same calls, main() declared first
+ auto* src = R"(
+@compute @workgroup_size(1)
+fn main() {
+ atomicStore(&sb.a, 123);
+ atomicLoad(&sb.a);
+ atomicAdd(&sb.a, 123);
+ atomicSub(&sb.a, 123);
+ atomicMax(&sb.a, 123);
+ atomicMin(&sb.a, 123);
+ atomicAnd(&sb.a, 123);
+ atomicOr(&sb.a, 123);
+ atomicXor(&sb.a, 123);
+ atomicExchange(&sb.a, 123);
+ atomicCompareExchangeWeak(&sb.a, 123, 345);
+
+ atomicStore(&sb.b, 123u);
+ atomicLoad(&sb.b);
+ atomicAdd(&sb.b, 123u);
+ atomicSub(&sb.b, 123u);
+ atomicMax(&sb.b, 123u);
+ atomicMin(&sb.b, 123u);
+ atomicAnd(&sb.b, 123u);
+ atomicOr(&sb.b, 123u);
+ atomicXor(&sb.b, 123u);
+ atomicExchange(&sb.b, 123u);
+ atomicCompareExchangeWeak(&sb.b, 123u, 345u);
+}
+
+@group(0) @binding(0) var<storage, read_write> sb : SB;
+
+struct SB {
+ padding : vec4<f32>,
+ a : atomic<i32>,
+ b : atomic<u32>,
+};
+)";
+ // Expected: the body-less sbatomic* functions come first, then main(), then sb and SB in their source order.
+ auto* expect = R"(
+@internal(intrinsic_atomic_store_storage_i32) @internal(disable_validation__function_has_no_body)
+fn sbatomicStore(offset : u32, param_1 : i32)
+
+@internal(intrinsic_atomic_load_storage_i32) @internal(disable_validation__function_has_no_body)
+fn sbatomicLoad(offset : u32) -> i32
+
+@internal(intrinsic_atomic_add_storage_i32) @internal(disable_validation__function_has_no_body)
+fn sbatomicAdd(offset : u32, param_1 : i32) -> i32
+
+@internal(intrinsic_atomic_sub_storage_i32) @internal(disable_validation__function_has_no_body)
+fn sbatomicSub(offset : u32, param_1 : i32) -> i32
+
+@internal(intrinsic_atomic_max_storage_i32) @internal(disable_validation__function_has_no_body)
+fn sbatomicMax(offset : u32, param_1 : i32) -> i32
+
+@internal(intrinsic_atomic_min_storage_i32) @internal(disable_validation__function_has_no_body)
+fn sbatomicMin(offset : u32, param_1 : i32) -> i32
+
+@internal(intrinsic_atomic_and_storage_i32) @internal(disable_validation__function_has_no_body)
+fn sbatomicAnd(offset : u32, param_1 : i32) -> i32
+
+@internal(intrinsic_atomic_or_storage_i32) @internal(disable_validation__function_has_no_body)
+fn sbatomicOr(offset : u32, param_1 : i32) -> i32
+
+@internal(intrinsic_atomic_xor_storage_i32) @internal(disable_validation__function_has_no_body)
+fn sbatomicXor(offset : u32, param_1 : i32) -> i32
+
+@internal(intrinsic_atomic_exchange_storage_i32) @internal(disable_validation__function_has_no_body)
+fn sbatomicExchange(offset : u32, param_1 : i32) -> i32
+
+@internal(intrinsic_atomic_compare_exchange_weak_storage_i32) @internal(disable_validation__function_has_no_body)
+fn sbatomicCompareExchangeWeak(offset : u32, param_1 : i32, param_2 : i32) -> __atomic_compare_exchange_result_i32
+
+@internal(intrinsic_atomic_store_storage_u32) @internal(disable_validation__function_has_no_body)
+fn sbatomicStore_1(offset : u32, param_1 : u32)
+
+@internal(intrinsic_atomic_load_storage_u32) @internal(disable_validation__function_has_no_body)
+fn sbatomicLoad_1(offset : u32) -> u32
+
+@internal(intrinsic_atomic_add_storage_u32) @internal(disable_validation__function_has_no_body)
+fn sbatomicAdd_1(offset : u32, param_1 : u32) -> u32
+
+@internal(intrinsic_atomic_sub_storage_u32) @internal(disable_validation__function_has_no_body)
+fn sbatomicSub_1(offset : u32, param_1 : u32) -> u32
+
+@internal(intrinsic_atomic_max_storage_u32) @internal(disable_validation__function_has_no_body)
+fn sbatomicMax_1(offset : u32, param_1 : u32) -> u32
+
+@internal(intrinsic_atomic_min_storage_u32) @internal(disable_validation__function_has_no_body)
+fn sbatomicMin_1(offset : u32, param_1 : u32) -> u32
+
+@internal(intrinsic_atomic_and_storage_u32) @internal(disable_validation__function_has_no_body)
+fn sbatomicAnd_1(offset : u32, param_1 : u32) -> u32
+
+@internal(intrinsic_atomic_or_storage_u32) @internal(disable_validation__function_has_no_body)
+fn sbatomicOr_1(offset : u32, param_1 : u32) -> u32
+
+@internal(intrinsic_atomic_xor_storage_u32) @internal(disable_validation__function_has_no_body)
+fn sbatomicXor_1(offset : u32, param_1 : u32) -> u32
+
+@internal(intrinsic_atomic_exchange_storage_u32) @internal(disable_validation__function_has_no_body)
+fn sbatomicExchange_1(offset : u32, param_1 : u32) -> u32
+
+@internal(intrinsic_atomic_compare_exchange_weak_storage_u32) @internal(disable_validation__function_has_no_body)
+fn sbatomicCompareExchangeWeak_1(offset : u32, param_1 : u32, param_2 : u32) -> __atomic_compare_exchange_result_u32
+
+@compute @workgroup_size(1)
+fn main() {
+ sbatomicStore(16u, 123);
+ sbatomicLoad(16u);
+ sbatomicAdd(16u, 123);
+ sbatomicSub(16u, 123);
+ sbatomicMax(16u, 123);
+ sbatomicMin(16u, 123);
+ sbatomicAnd(16u, 123);
+ sbatomicOr(16u, 123);
+ sbatomicXor(16u, 123);
+ sbatomicExchange(16u, 123);
+ sbatomicCompareExchangeWeak(16u, 123, 345);
+ sbatomicStore_1(20u, 123u);
+ sbatomicLoad_1(20u);
+ sbatomicAdd_1(20u, 123u);
+ sbatomicSub_1(20u, 123u);
+ sbatomicMax_1(20u, 123u);
+ sbatomicMin_1(20u, 123u);
+ sbatomicAnd_1(20u, 123u);
+ sbatomicOr_1(20u, 123u);
+ sbatomicXor_1(20u, 123u);
+ sbatomicExchange_1(20u, 123u);
+ sbatomicCompareExchangeWeak_1(20u, 123u, 345u);
+}
+
+@group(0) @binding(0) var<storage, read_write> sb : SB;
+
+struct SB {
+ padding : vec4<f32>,
+ a : atomic<i32>,
+ b : atomic<u32>,
+}
+)";
+
+ auto got = Run<DecomposeMemoryAccess>(src);
+
+ EXPECT_EQ(expect, str(got));  // exact text match against the expected WGSL
+}
+
+TEST_F(DecomposeMemoryAccessTest, WorkgroupBufferAtomics) {  // atomics on a var<workgroup> struct
+ auto* src = R"(
+struct S {
+ padding : vec4<f32>,
+ a : atomic<i32>,
+ b : atomic<u32>,
+}
+
+var<workgroup> w : S;
+
+@compute @workgroup_size(1)
+fn main() {
+ atomicStore(&(w.a), 123);
+ atomicLoad(&(w.a));
+ atomicAdd(&(w.a), 123);
+ atomicSub(&(w.a), 123);
+ atomicMax(&(w.a), 123);
+ atomicMin(&(w.a), 123);
+ atomicAnd(&(w.a), 123);
+ atomicOr(&(w.a), 123);
+ atomicXor(&(w.a), 123);
+ atomicExchange(&(w.a), 123);
+ atomicCompareExchangeWeak(&(w.a), 123, 345);
+ atomicStore(&(w.b), 123u);
+ atomicLoad(&(w.b));
+ atomicAdd(&(w.b), 123u);
+ atomicSub(&(w.b), 123u);
+ atomicMax(&(w.b), 123u);
+ atomicMin(&(w.b), 123u);
+ atomicAnd(&(w.b), 123u);
+ atomicOr(&(w.b), 123u);
+ atomicXor(&(w.b), 123u);
+ atomicExchange(&(w.b), 123u);
+ atomicCompareExchangeWeak(&(w.b), 123u, 345u);
+}
+)";
+
+ auto* expect = src;  // the output text is expected to equal the input: nothing here is rewritten
+
+ auto got = Run<DecomposeMemoryAccess>(src);
+
+ EXPECT_EQ(expect, str(got));  // compares against src itself, so src is written in the printer's own form
+}
+
+TEST_F(DecomposeMemoryAccessTest, WorkgroupBufferAtomics_OutOfOrder) {  // same calls, main() declared first
+ auto* src = R"(
+@compute @workgroup_size(1)
+fn main() {
+ atomicStore(&(w.a), 123);
+ atomicLoad(&(w.a));
+ atomicAdd(&(w.a), 123);
+ atomicSub(&(w.a), 123);
+ atomicMax(&(w.a), 123);
+ atomicMin(&(w.a), 123);
+ atomicAnd(&(w.a), 123);
+ atomicOr(&(w.a), 123);
+ atomicXor(&(w.a), 123);
+ atomicExchange(&(w.a), 123);
+ atomicCompareExchangeWeak(&(w.a), 123, 345);
+ atomicStore(&(w.b), 123u);
+ atomicLoad(&(w.b));
+ atomicAdd(&(w.b), 123u);
+ atomicSub(&(w.b), 123u);
+ atomicMax(&(w.b), 123u);
+ atomicMin(&(w.b), 123u);
+ atomicAnd(&(w.b), 123u);
+ atomicOr(&(w.b), 123u);
+ atomicXor(&(w.b), 123u);
+ atomicExchange(&(w.b), 123u);
+ atomicCompareExchangeWeak(&(w.b), 123u, 345u);
+}
+
+var<workgroup> w : S;
+
+struct S {
+ padding : vec4<f32>,
+ a : atomic<i32>,
+ b : atomic<u32>,
+}
+)";
+
+ auto* expect = src;  // the output text is expected to equal the input: nothing here is rewritten
+
+ auto got = Run<DecomposeMemoryAccess>(src);
+
+ EXPECT_EQ(expect, str(got));  // compares against src itself, so src is written in the printer's own form
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/decompose_strided_array.cc b/src/tint/lang/wgsl/ast/transform/decompose_strided_array.cc
new file mode 100644
index 0000000..08dec78
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/decompose_strided_array.cc
@@ -0,0 +1,178 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/decompose_strided_array.h"
+
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "src/tint/lang/wgsl/ast/transform/simplify_pointers.h"
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/call.h"
+#include "src/tint/sem/member_accessor_expression.h"
+#include "src/tint/sem/type_expression.h"
+#include "src/tint/sem/value_constructor.h"
+#include "src/tint/sem/value_expression.h"
+#include "src/tint/utils/hash.h"
+#include "src/tint/utils/map.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::DecomposeStridedArray);
+
+namespace tint::ast::transform {
+namespace {
+
+using DecomposedArrays = std::unordered_map<const type::Array*, Symbol>;
+
+bool ShouldRun(const Program* program) {  // True if any templated identifier has @stride.
+ for (auto* node : program->ASTNodes().Objects()) {
+ if (auto* ident = node->As<TemplatedIdentifier>()) {
+ if (GetAttribute<StrideAttribute>(ident->attributes)) {
+ return true;  // One strided identifier is enough to require the transform.
+ }
+ }
+ }
+ return false;  // No StrideAttribute found on any templated identifier.
+}
+
+} // namespace
+
+DecomposeStridedArray::DecomposeStridedArray() = default;
+
+DecomposeStridedArray::~DecomposeStridedArray() = default;
+
+Transform::ApplyResult DecomposeStridedArray::Apply(const Program* src,
+ const DataMap&,
+ DataMap&) const {
+ if (!ShouldRun(src)) {  // No @stride attribute found: nothing to do.
+ return SkipTransform;
+ }
+
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+ const auto& sem = src->Sem();
+
+ static constexpr const char* kMemberName = "el";
+
+ // Maps an array type in the source program to the name of the struct wrapper
+ // type in the target program.
+ std::unordered_map<const type::Array*, Symbol> decomposed;  // Same type as DecomposedArrays.
+
+ // Find and replace all arrays with a @stride attribute with an array that has
+ // the @stride removed. If the source array stride does not match the natural
+ // stride for the array element type, then replace the array element type with
+ // a structure, holding a single field with a @size attribute equal to the
+ // array stride.
+ ctx.ReplaceAll([&](const IdentifierExpression* expr) -> const IdentifierExpression* {
+ auto* ident = expr->identifier->As<TemplatedIdentifier>();
+ if (!ident) {
+ return nullptr;
+ }
+ auto* type_expr = sem.Get<sem::TypeExpression>(expr);
+ if (!type_expr) {
+ return nullptr;
+ }
+ auto* arr = type_expr->Type()->As<type::Array>();
+ if (!arr) {
+ return nullptr;
+ }
+ if (!arr->IsStrideImplicit()) {
+ auto el_ty = utils::GetOrCreate(decomposed, arr, [&] {
+ auto name = b.Symbols().New("strided_arr");
+ auto* member_ty = ctx.Clone(ident->arguments[0]->As<IdentifierExpression>());
+ auto* member = b.Member(kMemberName, Type{member_ty},
+ utils::Vector{
+ b.MemberSize(AInt(arr->Stride())),
+ });
+ b.Structure(name, utils::Vector{member});
+ return name;
+ });
+ if (ident->arguments.Length() > 1) {  // A second template argument is the element count.
+ auto* count = ctx.Clone(ident->arguments[1]);
+ return b.Expr(b.ty.array(b.ty(el_ty), count));
+ } else {
+ return b.Expr(b.ty.array(b.ty(el_ty)));
+ }
+ }
+ if (GetAttribute<StrideAttribute>(ident->attributes)) {
+ // Strip the @stride attribute
+ auto* ty = ctx.Clone(ident->arguments[0]->As<IdentifierExpression>());
+ if (ident->arguments.Length() > 1) {
+ auto* count = ctx.Clone(ident->arguments[1]);
+ return b.Expr(b.ty.array(Type{ty}, count));
+ } else {
+ return b.Expr(b.ty.array(Type{ty}));
+ }
+ }
+ return nullptr;
+ });
+
+ // Find all array index-accessors expressions for arrays that have had their
+ // element changed to a single field structure. These expressions are adjusted
+ // to insert an additional member accessor for the single structure field.
+ // Example: `arr[i]` -> `arr[i].el`
+ ctx.ReplaceAll([&](const IndexAccessorExpression* idx) -> const Expression* {
+ if (auto* ty = src->TypeOf(idx->object)) {
+ if (auto* arr = ty->UnwrapRef()->As<type::Array>()) {
+ if (!arr->IsStrideImplicit()) {
+ auto* expr = ctx.CloneWithoutTransform(idx);
+ return b.MemberAccessor(expr, kMemberName);
+ }
+ }
+ }
+ return nullptr;
+ });
+
+ // Find all constructor expressions for array types that have had their element changed to a
+ // single field structure. These constructors are adjusted to wrap each of the arguments with an
+ // additional initializer for the new element structure type. Example:
+ // `@stride(32) array<i32, 3>(1, 2, 3)`
+ // ->
+ // `array<strided_arr, 3>(strided_arr(1), strided_arr(2), strided_arr(3))`
+ ctx.ReplaceAll([&](const CallExpression* expr) -> const Expression* {
+ if (!expr->args.IsEmpty()) {
+ if (auto* call = sem.Get(expr)->UnwrapMaterialize()->As<sem::Call>()) {
+ if (auto* ctor = call->Target()->As<sem::ValueConstructor>()) {
+ if (auto* arr = ctor->ReturnType()->As<type::Array>()) {
+ // Begin by cloning the array initializer type or name
+ // If this is an unaliased array, this may add a new entry to
+ // decomposed.
+ // If this is an aliased array, decomposed should already be
+ // populated with any strided aliases.
+
+ auto* target = ctx.Clone(expr->target);
+
+ utils::Vector<const Expression*, 8> args;
+ if (auto it = decomposed.find(arr); it != decomposed.end()) {
+ args.Reserve(expr->args.Length());
+ for (auto* arg : expr->args) {
+ args.Push(b.Call(it->second, ctx.Clone(arg)));  // Wrap: `x` -> `strided_arr(x)`.
+ }
+ } else {
+ args = ctx.Clone(expr->args);  // Array was not decomposed: keep arguments as-is.
+ }
+
+ return b.Call(target, std::move(args));
+ }
+ }
+ }
+ }
+ return nullptr;
+ });
+
+ ctx.Clone();
+ return Program(std::move(b));
+}
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/decompose_strided_array.h b/src/tint/lang/wgsl/ast/transform/decompose_strided_array.h
new file mode 100644
index 0000000..2839672
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/decompose_strided_array.h
@@ -0,0 +1,46 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_DECOMPOSE_STRIDED_ARRAY_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_DECOMPOSE_STRIDED_ARRAY_H_
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+
+namespace tint::ast::transform {
+
+/// The DecomposeStridedArray transform replaces arrays with a non-default
+/// `@stride` attribute with an array of structure elements, where the
+/// structure contains a single field with an equivalent `@size` attribute.
+/// `@stride` attributes on arrays that match the default stride are also
+/// removed.
+///
+/// @note Depends on the following transforms to have been run first:
+/// * SimplifyPointers
+class DecomposeStridedArray final : public utils::Castable<DecomposeStridedArray, Transform> {
+ public:
+ /// Constructor
+ DecomposeStridedArray();
+
+ /// Destructor
+ ~DecomposeStridedArray() override;
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_DECOMPOSE_STRIDED_ARRAY_H_
diff --git a/src/tint/lang/wgsl/ast/transform/decompose_strided_array_test.cc b/src/tint/lang/wgsl/ast/transform/decompose_strided_array_test.cc
new file mode 100644
index 0000000..18a1208
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/decompose_strided_array_test.cc
@@ -0,0 +1,741 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/decompose_strided_array.h"
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "src/tint/lang/wgsl/ast/transform/simplify_pointers.h"
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+#include "src/tint/lang/wgsl/ast/transform/unshadow.h"
+#include "src/tint/program_builder.h"
+
+using namespace tint::number_suffixes; // NOLINT
+
+namespace tint::ast::transform {
+namespace {
+
+using DecomposeStridedArrayTest = TransformTest;
+
+TEST_F(DecomposeStridedArrayTest, ShouldRunEmptyModule) {
+ ProgramBuilder b;  // Nothing is added to the program, so it holds no @stride attribute.
+ EXPECT_FALSE(ShouldRun<DecomposeStridedArray>(Program(std::move(b))));
+}
+
+TEST_F(DecomposeStridedArrayTest, ShouldRunNonStridedArray) {
+ // var<private> arr : array<f32, 4u>   (no @stride attribute)
+
+ ProgramBuilder b;
+ b.GlobalVar("arr", b.ty.array<f32, 4u>(), builtin::AddressSpace::kPrivate);
+ EXPECT_FALSE(ShouldRun<DecomposeStridedArray>(Program(std::move(b))));
+}
+
+TEST_F(DecomposeStridedArrayTest, ShouldRunDefaultStridedArray) {
+ // var<private> arr : @stride(4) array<f32, 4u>
+
+ ProgramBuilder b;
+ b.GlobalVar("arr",
+ b.ty.array<f32, 4u>(utils::Vector{
+ b.Stride(4),  // The "default" stride of this test; the attribute is still present.
+ }),
+ builtin::AddressSpace::kPrivate);
+ EXPECT_TRUE(ShouldRun<DecomposeStridedArray>(Program(std::move(b))));
+}
+
+TEST_F(DecomposeStridedArrayTest, ShouldRunExplicitStridedArray) {
+ // var<private> arr : @stride(16) array<f32, 4u>
+
+ ProgramBuilder b;
+ b.GlobalVar("arr",
+ b.ty.array<f32, 4u>(utils::Vector{
+ b.Stride(16),  // A stride other than the 4 used by the "default" test above.
+ }),
+ builtin::AddressSpace::kPrivate);
+ EXPECT_TRUE(ShouldRun<DecomposeStridedArray>(Program(std::move(b))));
+}
+
+TEST_F(DecomposeStridedArrayTest, Empty) {
+ auto* src = R"()";
+ auto* expect = src;  // Empty input: the output is expected to equal the input.
+
+ auto got = Run<DecomposeStridedArray>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeStridedArrayTest, PrivateDefaultStridedArray) {
+ // var<private> arr : @stride(4) array<f32, 4u>
+ //
+ // @compute @workgroup_size(1)
+ // fn f() {
+ // let a : @stride(4) array<f32, 4u> = arr;
+ // let b : f32 = arr[1];
+ // }
+
+ ProgramBuilder b;
+ b.GlobalVar("arr",
+ b.ty.array<f32, 4u>(utils::Vector{
+ b.Stride(4),
+ }),
+ builtin::AddressSpace::kPrivate);
+ b.Func("f", utils::Empty, b.ty.void_(),
+ utils::Vector{
+ b.Decl(b.Let("a",
+ b.ty.array<f32, 4u>(utils::Vector{
+ b.Stride(4),
+ }),
+ b.Expr("arr"))),
+ b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor("arr", 1_i))),
+ },
+ utils::Vector{
+ b.Stage(PipelineStage::kCompute),
+ b.WorkgroupSize(1_i),
+ });
+
+ auto* expect = R"(
+var<private> arr : array<f32, 4u>;
+
+@compute @workgroup_size(1i)
+fn f() {
+ let a : array<f32, 4u> = arr;
+ let b : f32 = arr[1i];
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(Program(std::move(b)));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeStridedArrayTest, PrivateStridedArray) {
+ // var<private> arr : @stride(32) array<f32, 4u>
+ //
+ // @compute @workgroup_size(1)
+ // fn f() {
+ // let a : @stride(32) array<f32, 4u> = arr;
+ // let b : f32 = arr[1];
+ // }
+
+ ProgramBuilder b;
+ b.GlobalVar("arr",
+ b.ty.array<f32, 4u>(utils::Vector{
+ b.Stride(32),
+ }),
+ builtin::AddressSpace::kPrivate);
+ b.Func("f", utils::Empty, b.ty.void_(),
+ utils::Vector{
+ b.Decl(b.Let("a",
+ b.ty.array<f32, 4u>(utils::Vector{
+ b.Stride(32),
+ }),
+ b.Expr("arr"))),
+ b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor("arr", 1_i))),
+ },
+ utils::Vector{
+ b.Stage(PipelineStage::kCompute),
+ b.WorkgroupSize(1_i),
+ });
+
+ auto* expect = R"(
+struct strided_arr {
+ @size(32)
+ el : f32,
+}
+
+var<private> arr : array<strided_arr, 4u>;
+
+@compute @workgroup_size(1i)
+fn f() {
+ let a : array<strided_arr, 4u> = arr;
+ let b : f32 = arr[1i].el;
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(Program(std::move(b)));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeStridedArrayTest, ReadUniformStridedArray) {
+ // struct S {
+ // a : @stride(32) array<f32, 4u>,
+ // };
+ // @group(0) @binding(0) var<uniform> s : S;
+ //
+ // @compute @workgroup_size(1)
+ // fn f() {
+ // let a : @stride(32) array<f32, 4u> = s.a;
+ // let b : f32 = s.a[1];
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure("S", utils::Vector{b.Member("a", b.ty.array<f32, 4u>(utils::Vector{
+ b.Stride(32),
+ }))});
+ b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kUniform, b.Group(0_a), b.Binding(0_a));
+ b.Func("f", utils::Empty, b.ty.void_(),
+ utils::Vector{
+ b.Decl(b.Let("a",
+ b.ty.array<f32, 4u>(utils::Vector{
+ b.Stride(32),
+ }),
+ b.MemberAccessor("s", "a"))),
+ b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i))),
+ },
+ utils::Vector{
+ b.Stage(PipelineStage::kCompute),
+ b.WorkgroupSize(1_i),
+ });
+
+ auto* expect = R"(
+struct strided_arr {
+ @size(32)
+ el : f32,
+}
+
+struct S {
+ a : array<strided_arr, 4u>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+ let a : array<strided_arr, 4u> = s.a;
+ let b : f32 = s.a[1i].el;
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(Program(std::move(b)));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeStridedArrayTest, ReadUniformDefaultStridedArray) {
+ // struct S {
+ // a : @stride(16) array<vec4<f32>, 4u>,
+ // };
+ // @group(0) @binding(0) var<uniform> s : S;
+ //
+ // @compute @workgroup_size(1)
+ // fn f() {
+ // let a : @stride(16) array<vec4<f32>, 4u> = s.a;
+ // let b : f32 = s.a[1][2];
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure("S", utils::Vector{b.Member("a", b.ty.array(b.ty.vec4<f32>(), 4_u,
+ utils::Vector{
+ b.Stride(16),
+ }))});
+ b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kUniform, b.Group(0_a), b.Binding(0_a));
+ b.Func(
+ "f", utils::Empty, b.ty.void_(),
+ utils::Vector{
+ b.Decl(b.Let("a",
+ b.ty.array(b.ty.vec4<f32>(), 4_u,
+ utils::Vector{
+ b.Stride(16),
+ }),
+ b.MemberAccessor("s", "a"))),
+ b.Decl(b.Let("b", b.ty.f32(),
+ b.IndexAccessor(b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i), 2_i))),
+ },
+ utils::Vector{
+ b.Stage(PipelineStage::kCompute),
+ b.WorkgroupSize(1_i),
+ });
+
+ auto* expect =
+ R"(
+struct S {
+ a : array<vec4<f32>, 4u>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+ let a : array<vec4<f32>, 4u> = s.a;
+ let b : f32 = s.a[1i][2i];
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(Program(std::move(b)));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeStridedArrayTest, ReadStorageStridedArray) {
+ // struct S {
+ // a : @stride(32) array<f32, 4u>,
+ // };
+ // @group(0) @binding(0) var<storage> s : S;
+ //
+ // @compute @workgroup_size(1)
+ // fn f() {
+ // let a : @stride(32) array<f32, 4u> = s.a;
+ // let b : f32 = s.a[1];
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure("S", utils::Vector{b.Member("a", b.ty.array<f32, 4u>(utils::Vector{
+ b.Stride(32),
+ }))});
+ b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kStorage, b.Group(0_a), b.Binding(0_a));
+ b.Func("f", utils::Empty, b.ty.void_(),
+ utils::Vector{
+ b.Decl(b.Let("a",
+ b.ty.array<f32, 4u>(utils::Vector{
+ b.Stride(32),
+ }),
+ b.MemberAccessor("s", "a"))),
+ b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i))),
+ },
+ utils::Vector{
+ b.Stage(PipelineStage::kCompute),
+ b.WorkgroupSize(1_i),
+ });
+
+ auto* expect = R"(
+struct strided_arr {
+ @size(32)
+ el : f32,
+}
+
+struct S {
+ a : array<strided_arr, 4u>,
+}
+
+@group(0) @binding(0) var<storage> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+ let a : array<strided_arr, 4u> = s.a;
+ let b : f32 = s.a[1i].el;
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(Program(std::move(b)));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeStridedArrayTest, ReadStorageDefaultStridedArray) {
+ // struct S {
+ // a : @stride(4) array<f32, 4u>,
+ // };
+ // @group(0) @binding(0) var<storage> s : S;
+ //
+ // @compute @workgroup_size(1)
+ // fn f() {
+ // let a : @stride(4) array<f32, 4u> = s.a;
+ // let b : f32 = s.a[1];
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure("S", utils::Vector{b.Member("a", b.ty.array<f32, 4u>(utils::Vector{
+ b.Stride(4),
+ }))});
+ b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kStorage, b.Group(0_a), b.Binding(0_a));
+ b.Func("f", utils::Empty, b.ty.void_(),
+ utils::Vector{
+ b.Decl(b.Let("a",
+ b.ty.array<f32, 4u>(utils::Vector{
+ b.Stride(4),
+ }),
+ b.MemberAccessor("s", "a"))),
+ b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i))),
+ },
+ utils::Vector{
+ b.Stage(PipelineStage::kCompute),
+ b.WorkgroupSize(1_i),
+ });
+
+ auto* expect = R"(
+struct S {
+ a : array<f32, 4u>,
+}
+
+@group(0) @binding(0) var<storage> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+ let a : array<f32, 4u> = s.a;
+ let b : f32 = s.a[1i];
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(Program(std::move(b)));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeStridedArrayTest, WriteStorageStridedArray) {
+ // struct S {
+ // a : @stride(32) array<f32, 4u>,
+ // };
+ // @group(0) @binding(0) var<storage, read_write> s : S;
+ //
+ // @compute @workgroup_size(1)
+ // fn f() {
+ // s.a = @stride(32) array<f32, 4u>();
+ // s.a = @stride(32) array<f32, 4u>(1.0, 2.0, 3.0, 4.0);
+ // s.a[1i] = 5.0;
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure("S", utils::Vector{b.Member("a", b.ty.array<f32, 4u>(utils::Vector{
+ b.Stride(32),
+ }))});
+ b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite,
+ b.Group(0_a), b.Binding(0_a));
+ b.Func("f", utils::Empty, b.ty.void_(),
+ utils::Vector{
+ b.Assign(b.MemberAccessor("s", "a"), b.Call(b.ty.array<f32, 4u>(utils::Vector{
+ b.Stride(32),
+ }))),
+ b.Assign(b.MemberAccessor("s", "a"), b.Call(b.ty.array<f32, 4u>(utils::Vector{
+ b.Stride(32),
+ }),
+ 1_f, 2_f, 3_f, 4_f)),
+ b.Assign(b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i), 5_f),
+ },
+ utils::Vector{
+ b.Stage(PipelineStage::kCompute),
+ b.WorkgroupSize(1_i),
+ });
+
+ auto* expect =
+ R"(
+struct strided_arr {
+ @size(32)
+ el : f32,
+}
+
+struct S {
+ a : array<strided_arr, 4u>,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+ s.a = array<strided_arr, 4u>();
+ s.a = array<strided_arr, 4u>(strided_arr(1.0f), strided_arr(2.0f), strided_arr(3.0f), strided_arr(4.0f));
+ s.a[1i].el = 5.0f;
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(Program(std::move(b)));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeStridedArrayTest, WriteStorageDefaultStridedArray) {
+ // struct S {
+ // a : @stride(4) array<f32, 4u>,
+ // };
+ // @group(0) @binding(0) var<storage, read_write> s : S;
+ //
+ // @compute @workgroup_size(1)
+ // fn f() {
+ // s.a = @stride(4) array<f32, 4u>();
+ // s.a = @stride(4) array<f32, 4u>(1.0, 2.0, 3.0, 4.0);
+ // s.a[1] = 5.0;
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure("S", utils::Vector{
+ b.Member("a", b.ty.array<f32, 4u>(utils::Vector{
+ b.Stride(4),
+ })),
+ });
+ b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite,
+ b.Group(0_a), b.Binding(0_a));
+ b.Func("f", utils::Empty, b.ty.void_(),
+ utils::Vector{
+ b.Assign(b.MemberAccessor("s", "a"), b.Call(b.ty.array<f32, 4u>(utils::Vector{
+ b.Stride(4),
+ }))),
+ b.Assign(b.MemberAccessor("s", "a"), b.Call(b.ty.array<f32, 4u>(utils::Vector{
+ b.Stride(4),
+ }),
+ 1_f, 2_f, 3_f, 4_f)),
+ b.Assign(b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i), 5_f),
+ },
+ utils::Vector{
+ b.Stage(PipelineStage::kCompute),
+ b.WorkgroupSize(1_i),
+ });
+
+ auto* expect =
+ R"(
+struct S {
+ a : array<f32, 4u>,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+ s.a = array<f32, 4u>();
+ s.a = array<f32, 4u>(1.0f, 2.0f, 3.0f, 4.0f);
+ s.a[1i] = 5.0f;
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(Program(std::move(b)));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeStridedArrayTest, ReadWriteViaPointerLets) {
+ // struct S {
+ // a : @stride(32) array<f32, 4u>,
+ // };
+ // @group(0) @binding(0) var<storage, read_write> s : S;
+ //
+ // @compute @workgroup_size(1)
+ // fn f() {
+ // let a = &s.a;
+ // let b = &*&*(a);
+ // let c = *b;
+ // let d = (*b)[1];
+ // (*b) = @stride(32) array<f32, 4u>(1.0, 2.0, 3.0, 4.0);
+ // (*b)[1] = 5.0;
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure("S", utils::Vector{b.Member("a", b.ty.array<f32, 4u>(utils::Vector{
+ b.Stride(32),
+ }))});
+ b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite,
+ b.Group(0_a), b.Binding(0_a));
+ b.Func("f", utils::Empty, b.ty.void_(),
+ utils::Vector{
+ b.Decl(b.Let("a", b.AddressOf(b.MemberAccessor("s", "a")))),
+ b.Decl(b.Let("b", b.AddressOf(b.Deref(b.AddressOf(b.Deref("a")))))),
+ b.Decl(b.Let("c", b.Deref("b"))),
+ b.Decl(b.Let("d", b.IndexAccessor(b.Deref("b"), 1_i))),
+ b.Assign(b.Deref("b"), b.Call(b.ty.array<f32, 4u>(utils::Vector{
+ b.Stride(32),
+ }),
+ 1_f, 2_f, 3_f, 4_f)),
+ b.Assign(b.IndexAccessor(b.Deref("b"), 1_i), 5_f),
+ },
+ utils::Vector{
+ b.Stage(PipelineStage::kCompute),
+ b.WorkgroupSize(1_i),
+ });
+
+ auto* expect =
+ R"(
+struct strided_arr {
+ @size(32)
+ el : f32,
+}
+
+struct S {
+ a : array<strided_arr, 4u>,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+ let c = s.a;
+ let d = s.a[1i].el;
+ s.a = array<strided_arr, 4u>(strided_arr(1.0f), strided_arr(2.0f), strided_arr(3.0f), strided_arr(4.0f));
+ s.a[1i].el = 5.0f;
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(Program(std::move(b)));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeStridedArrayTest, PrivateAliasedStridedArray) {
+ // alias ARR = @stride(32) array<f32, 4u>;
+ // struct S {
+ // a : ARR,
+ // };
+ // @group(0) @binding(0) var<storage, read_write> s : S;
+ //
+ // @compute @workgroup_size(1)
+ // fn f() {
+ // let a : ARR = s.a;
+ // let b : f32 = s.a[1];
+ // s.a = ARR();
+ // s.a = ARR(1.0, 2.0, 3.0, 4.0);
+ // s.a[1] = 5.0;
+ // }
+ ProgramBuilder b;
+ b.Alias("ARR", b.ty.array<f32, 4u>(utils::Vector{
+ b.Stride(32),
+ }));
+ auto* S = b.Structure("S", utils::Vector{b.Member("a", b.ty("ARR"))});
+ b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite,
+ b.Group(0_a), b.Binding(0_a));
+ b.Func("f", utils::Empty, b.ty.void_(),
+ utils::Vector{
+ b.Decl(b.Let("a", b.ty("ARR"), b.MemberAccessor("s", "a"))),
+ b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i))),
+ b.Assign(b.MemberAccessor("s", "a"), b.Call("ARR")),
+ b.Assign(b.MemberAccessor("s", "a"), b.Call("ARR", 1_f, 2_f, 3_f, 4_f)),
+ b.Assign(b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i), 5_f),
+ },
+ utils::Vector{
+ b.Stage(PipelineStage::kCompute),
+ b.WorkgroupSize(1_i),
+ });
+
+ auto* expect = R"(
+struct strided_arr {
+ @size(32)
+ el : f32,
+}
+
+alias ARR = array<strided_arr, 4u>;
+
+struct S {
+ a : ARR,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+ let a : ARR = s.a;
+ let b : f32 = s.a[1i].el;
+ s.a = ARR();
+ s.a = ARR(strided_arr(1.0f), strided_arr(2.0f), strided_arr(3.0f), strided_arr(4.0f));
+ s.a[1i].el = 5.0f;
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(Program(std::move(b)));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeStridedArrayTest, PrivateNestedStridedArray) {
+ // alias ARR_A = @stride(8) array<f32, 2u>;
+ // alias ARR_B = @stride(128) array<@stride(16) array<ARR_A, 3u>, 4u>;
+ // struct S {
+ // a : ARR_B,
+ // };
+ // @group(0) @binding(0) var<storage, read_write> s : S;
+ //
+ // @compute @workgroup_size(1)
+ // fn f() {
+ // let a : ARR_B = s.a;
+ // let b : @stride(16) array<ARR_A, 3u> = s.a[3];
+ // let c : ARR_A = s.a[3][2];
+ // let d : f32 = s.a[3][2][1];
+ // s.a = ARR_B();
+ // s.a[3][2][1] = 5.0;
+ // }
+
+ ProgramBuilder b;
+ b.Alias("ARR_A", b.ty.array<f32, 2>(utils::Vector{
+ b.Stride(8),
+ }));
+ b.Alias("ARR_B", b.ty.array( //
+ b.ty.array(b.ty("ARR_A"), 3_u,
+ utils::Vector{
+ b.Stride(16),
+ }),
+ 4_u,
+ utils::Vector{
+ b.Stride(128),
+ }));
+ auto* S = b.Structure("S", utils::Vector{b.Member("a", b.ty("ARR_B"))});
+ b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite,
+ b.Group(0_a), b.Binding(0_a));
+ b.Func("f", utils::Empty, b.ty.void_(),
+ utils::Vector{
+ b.Decl(b.Let("a", b.ty("ARR_B"), b.MemberAccessor("s", "a"))),
+ b.Decl(b.Let("b",
+ b.ty.array(b.ty("ARR_A"), 3_u,
+ utils::Vector{
+ b.Stride(16),
+ }),
+ b.IndexAccessor( //
+ b.MemberAccessor("s", "a"), //
+ 3_i))),
+ b.Decl(b.Let("c", b.ty("ARR_A"),
+ b.IndexAccessor( //
+ b.IndexAccessor( //
+ b.MemberAccessor("s", "a"), //
+ 3_i),
+ 2_i))),
+ b.Decl(b.Let("d", b.ty.f32(),
+ b.IndexAccessor( //
+ b.IndexAccessor( //
+ b.IndexAccessor( //
+ b.MemberAccessor("s", "a"), //
+ 3_i),
+ 2_i),
+ 1_i))),
+ b.Assign(b.MemberAccessor("s", "a"), b.Call("ARR_B")),
+ b.Assign(b.IndexAccessor( //
+ b.IndexAccessor( //
+ b.IndexAccessor( //
+ b.MemberAccessor("s", "a"), //
+ 3_i),
+ 2_i),
+ 1_i),
+ 5_f),
+ },
+ utils::Vector{
+ b.Stage(PipelineStage::kCompute),
+ b.WorkgroupSize(1_i),
+ });
+
+ auto* expect =
+ R"(
+struct strided_arr {
+ @size(8)
+ el : f32,
+}
+
+alias ARR_A = array<strided_arr, 2u>;
+
+struct strided_arr_1 {
+ @size(128)
+ el : array<ARR_A, 3u>,
+}
+
+alias ARR_B = array<strided_arr_1, 4u>;
+
+struct S {
+ a : ARR_B,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+ let a : ARR_B = s.a;
+ let b : array<ARR_A, 3u> = s.a[3i].el;
+ let c : ARR_A = s.a[3i].el[2i];
+ let d : f32 = s.a[3i].el[2i][1i].el;
+ s.a = ARR_B();
+ s.a[3i].el[2i][1i].el = 5.0f;
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(Program(std::move(b)));
+
+ EXPECT_EQ(expect, str(got));
+}
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/decompose_strided_matrix.cc b/src/tint/lang/wgsl/ast/transform/decompose_strided_matrix.cc
new file mode 100644
index 0000000..fc7a6d4
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/decompose_strided_matrix.cc
@@ -0,0 +1,206 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/decompose_strided_matrix.h"
+
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "src/tint/lang/wgsl/ast/transform/simplify_pointers.h"
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/member_accessor_expression.h"
+#include "src/tint/sem/value_expression.h"
+#include "src/tint/utils/hash.h"
+#include "src/tint/utils/map.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::DecomposeStridedMatrix);
+
+namespace tint::ast::transform {
+namespace {
+
+/// MatrixInfo describes a matrix member with a custom stride
+struct MatrixInfo {
+ /// The stride in bytes between columns of the matrix
+ uint32_t stride = 0;
+ /// The type of the matrix
+ const type::Matrix* matrix = nullptr;
+
+ /// @returns the identifier of an array that holds a vector for each column of the matrix.
+ Type array(ProgramBuilder* b) const {
+ return b->ty.array(b->ty.vec<f32>(matrix->rows()), u32(matrix->columns()),
+ utils::Vector{
+ b->Stride(stride),
+ });
+ }  // NOTE(review): columns are always vec<f32>; confirm non-f32 matrices never reach here.
+
+ /// Equality operator
+ bool operator==(const MatrixInfo& info) const {
+ return stride == info.stride && matrix == info.matrix;
+ }
+ /// Hash function
+ struct Hasher {
+ size_t operator()(const MatrixInfo& t) const { return utils::Hash(t.stride, t.matrix); }
+ };
+};
+
+} // namespace
+
+DecomposeStridedMatrix::DecomposeStridedMatrix() = default;
+
+DecomposeStridedMatrix::~DecomposeStridedMatrix() = default;
+
+Transform::ApplyResult DecomposeStridedMatrix::Apply(const Program* src,
+ const DataMap&,
+ DataMap&) const {
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+
+ // Scan the program for all storage and uniform structure matrix members with
+ // a custom stride attribute. Replace these matrices with an equivalent array,
+ // and populate the `decomposed` map with the members that have been replaced.
+ utils::Hashmap<const type::StructMember*, MatrixInfo, 8> decomposed;
+ for (auto* node : src->ASTNodes().Objects()) {
+ if (auto* str = node->As<Struct>()) {
+ auto* str_ty = src->Sem().Get(str);
+ if (!str_ty->UsedAs(builtin::AddressSpace::kUniform) &&
+ !str_ty->UsedAs(builtin::AddressSpace::kStorage)) {
+ continue;  // Struct is used in neither the uniform nor the storage address space.
+ }
+ for (auto* member : str_ty->Members()) {
+ auto* matrix = member->Type()->As<type::Matrix>();
+ if (!matrix) {
+ continue;
+ }
+ auto* attr = GetAttribute<StrideAttribute>(member->Declaration()->attributes);
+ if (!attr) {
+ continue;
+ }
+ uint32_t stride = attr->stride;
+ if (matrix->ColumnStride() == stride) {
+ continue;  // Attribute equals the matrix's own ColumnStride(): leave the member alone.
+ }
+ // We've got ourselves a struct member of a matrix type with a custom
+ // stride. Replace this with an array of column vectors.
+ MatrixInfo info{stride, matrix};
+ auto* replacement =
+ b.Member(member->Offset(), ctx.Clone(member->Name()), info.array(ctx.dst));
+ ctx.Replace(member->Declaration(), replacement);
+ decomposed.Add(member, info);
+ }
+ }
+ }
+
+ if (decomposed.IsEmpty()) {  // No member was replaced above.
+ return SkipTransform;
+ }
+
+ // For all expressions where a single matrix column vector was indexed, we can
+ // preserve these without calling conversion functions.
+ // Example:
+ // ssbo.mat[2] -> ssbo.mat[2]
+ ctx.ReplaceAll([&](const IndexAccessorExpression* expr) -> const IndexAccessorExpression* {
+ if (auto* access = src->Sem().Get<sem::StructMemberAccess>(expr->object)) {
+ if (decomposed.Contains(access->Member())) {
+ auto* obj = ctx.CloneWithoutTransform(expr->object);
+ auto* idx = ctx.Clone(expr->index);
+ return b.IndexAccessor(obj, idx);
+ }
+ }
+ return nullptr;
+ });
+
+ // For all struct member accesses to the matrix on the LHS of an assignment,
+ // we need to convert the matrix to the array before assigning to the
+ // structure.
+ // Example:
+ // ssbo.mat = mat_to_arr(m)
+ std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> mat_to_arr;
+ ctx.ReplaceAll([&](const AssignmentStatement* stmt) -> const Statement* {
+ if (auto* access = src->Sem().Get<sem::StructMemberAccess>(stmt->lhs)) {
+ if (auto info = decomposed.Find(access->Member())) {
+ auto fn = utils::GetOrCreate(mat_to_arr, *info, [&] {  // One helper per MatrixInfo.
+ auto name =
+ b.Symbols().New("mat" + std::to_string(info->matrix->columns()) + "x" +
+ std::to_string(info->matrix->rows()) + "_stride_" +
+ std::to_string(info->stride) + "_to_arr");
+
+ auto matrix = [&] { return CreateASTTypeFor(ctx, info->matrix); };
+ auto array = [&] { return info->array(ctx.dst); };
+
+ auto mat = b.Sym("m");
+ utils::Vector<const Expression*, 4> columns;
+ for (uint32_t i = 0; i < static_cast<uint32_t>(info->matrix->columns()); i++) {
+ columns.Push(b.IndexAccessor(mat, u32(i)));
+ }
+ b.Func(name,
+ utils::Vector{
+ b.Param(mat, matrix()),
+ },
+ array(),
+ utils::Vector{
+ b.Return(b.Call(array(), columns)),
+ });
+ return name;
+ });
+ auto* lhs = ctx.CloneWithoutTransform(stmt->lhs);
+ auto* rhs = b.Call(fn, ctx.Clone(stmt->rhs));
+ return b.Assign(lhs, rhs);
+ }
+ }
+ return nullptr;
+ });
+
+ // For all other struct member accesses, we need to convert the array to the
+ // matrix type. Example:
+ // m = arr_to_mat(ssbo.mat)
+ std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> arr_to_mat;
+ ctx.ReplaceAll([&](const MemberAccessorExpression* expr) -> const Expression* {
+ if (auto* access = src->Sem().Get(expr)->UnwrapLoad()->As<sem::StructMemberAccess>()) {
+ if (auto info = decomposed.Find(access->Member())) {
+ auto fn = utils::GetOrCreate(arr_to_mat, *info, [&] {  // One helper per MatrixInfo.
+ auto name =
+ b.Symbols().New("arr_to_mat" + std::to_string(info->matrix->columns()) +
+ "x" + std::to_string(info->matrix->rows()) + "_stride_" +
+ std::to_string(info->stride));
+
+ auto matrix = [&] { return CreateASTTypeFor(ctx, info->matrix); };
+ auto array = [&] { return info->array(ctx.dst); };
+
+ auto arr = b.Sym("arr");
+ utils::Vector<const Expression*, 4> columns;
+ for (uint32_t i = 0; i < static_cast<uint32_t>(info->matrix->columns()); i++) {
+ columns.Push(b.IndexAccessor(arr, u32(i)));
+ }
+ b.Func(name,
+ utils::Vector{
+ b.Param(arr, array()),
+ },
+ matrix(),
+ utils::Vector{
+ b.Return(b.Call(matrix(), columns)),
+ });
+ return name;
+ });
+ return b.Call(fn, ctx.CloneWithoutTransform(expr));
+ }
+ }
+ return nullptr;
+ });
+
+ ctx.Clone();
+ return Program(std::move(b));
+}
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/decompose_strided_matrix.h b/src/tint/lang/wgsl/ast/transform/decompose_strided_matrix.h
new file mode 100644
index 0000000..5036313
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/decompose_strided_matrix.h
@@ -0,0 +1,46 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_DECOMPOSE_STRIDED_MATRIX_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_DECOMPOSE_STRIDED_MATRIX_H_
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+
+namespace tint::ast::transform {
+
+/// DecomposeStridedMatrix replaces matrix members of storage or uniform
+/// buffer structures that have a stride attribute with an array of N column
+/// vectors.
+/// This transform is used by the SPIR-V reader to handle the SPIR-V
+/// MatrixStride attribute.
+///
+/// @note Depends on the following transforms to have been run first:
+/// * SimplifyPointers
+class DecomposeStridedMatrix final : public utils::Castable<DecomposeStridedMatrix, Transform> {
+ public:
+ /// Constructor
+ DecomposeStridedMatrix();
+
+ /// Destructor
+ ~DecomposeStridedMatrix() override;
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_DECOMPOSE_STRIDED_MATRIX_H_
diff --git a/src/tint/lang/wgsl/ast/transform/decompose_strided_matrix_test.cc b/src/tint/lang/wgsl/ast/transform/decompose_strided_matrix_test.cc
new file mode 100644
index 0000000..d9dcac7
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/decompose_strided_matrix_test.cc
@@ -0,0 +1,641 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/decompose_strided_matrix.h"
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "src/tint/lang/wgsl/ast/disable_validation_attribute.h"
+#include "src/tint/lang/wgsl/ast/transform/simplify_pointers.h"
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+#include "src/tint/lang/wgsl/ast/transform/unshadow.h"
+#include "src/tint/program_builder.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using namespace tint::builtin::fluent_types; // NOLINT
+using namespace tint::number_suffixes; // NOLINT
+
+using DecomposeStridedMatrixTest = TransformTest;
+
+TEST_F(DecomposeStridedMatrixTest, ShouldRunEmptyModule) {
+ auto* src = R"()";
+
+ EXPECT_FALSE(ShouldRun<DecomposeStridedMatrix>(src));
+}
+
+TEST_F(DecomposeStridedMatrixTest, ShouldRunNonStridedMatrox) {
+ auto* src = R"(
+var<private> m : mat3x2<f32>;
+)";
+
+ EXPECT_FALSE(ShouldRun<DecomposeStridedMatrix>(src));
+}
+
+TEST_F(DecomposeStridedMatrixTest, Empty) {
+ auto* src = R"()";
+ auto* expect = src;
+
+ auto got = Run<DecomposeStridedMatrix>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeStridedMatrixTest, ReadUniformMatrix) {
+ // struct S {
+ // @offset(16) @stride(32)
+ // @internal(ignore_stride_attribute)
+ // m : mat2x2<f32>,
+ // };
+ // @group(0) @binding(0) var<uniform> s : S;
+ //
+ // @compute @workgroup_size(1)
+ // fn f() {
+ // let x : mat2x2<f32> = s.m;
+ // }
+ ProgramBuilder b;
+ auto* S =
+ b.Structure("S", utils::Vector{
+ b.Member("m", b.ty.mat2x2<f32>(),
+ utils::Vector{
+ b.MemberOffset(16_u),
+ b.create<StrideAttribute>(32u),
+ b.Disable(DisabledValidation::kIgnoreStrideAttribute),
+ }),
+ });
+ b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kUniform, b.Group(0_a), b.Binding(0_a));
+ b.Func("f", utils::Empty, b.ty.void_(),
+ utils::Vector{
+ b.Decl(b.Let("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
+ },
+ utils::Vector{
+ b.Stage(PipelineStage::kCompute),
+ b.WorkgroupSize(1_i),
+ });
+
+ auto* expect = R"(
+struct S {
+ @size(16)
+ padding : u32,
+ /* @offset(16) */
+ m : @stride(32) array<vec2<f32>, 2u>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn arr_to_mat2x2_stride_32(arr : @stride(32) array<vec2<f32>, 2u>) -> mat2x2<f32> {
+ return mat2x2<f32>(arr[0u], arr[1u]);
+}
+
+@compute @workgroup_size(1i)
+fn f() {
+ let x : mat2x2<f32> = arr_to_mat2x2_stride_32(s.m);
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(Program(std::move(b)));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeStridedMatrixTest, ReadUniformColumn) {
+ // struct S {
+ // @offset(16) @stride(32)
+ // @internal(ignore_stride_attribute)
+ // m : mat2x2<f32>,
+ // };
+ // @group(0) @binding(0) var<uniform> s : S;
+ //
+ // @compute @workgroup_size(1)
+ // fn f() {
+ // let x : vec2<f32> = s.m[1];
+ // }
+ ProgramBuilder b;
+ auto* S =
+ b.Structure("S", utils::Vector{
+ b.Member("m", b.ty.mat2x2<f32>(),
+ utils::Vector{
+ b.MemberOffset(16_u),
+ b.create<StrideAttribute>(32u),
+ b.Disable(DisabledValidation::kIgnoreStrideAttribute),
+ }),
+ });
+ b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kUniform, b.Group(0_a), b.Binding(0_a));
+ b.Func(
+ "f", utils::Empty, b.ty.void_(),
+ utils::Vector{
+ b.Decl(b.Let("x", b.ty.vec2<f32>(), b.IndexAccessor(b.MemberAccessor("s", "m"), 1_i))),
+ },
+ utils::Vector{
+ b.Stage(PipelineStage::kCompute),
+ b.WorkgroupSize(1_i),
+ });
+
+ auto* expect = R"(
+struct S {
+ @size(16)
+ padding : u32,
+ /* @offset(16) */
+ m : @stride(32) array<vec2<f32>, 2u>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+ let x : vec2<f32> = s.m[1i];
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(Program(std::move(b)));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeStridedMatrixTest, ReadUniformMatrix_DefaultStride) {
+ // struct S {
+ // @offset(16) @stride(8)
+ // @internal(ignore_stride_attribute)
+ // m : mat2x2<f32>,
+ // };
+ // @group(0) @binding(0) var<uniform> s : S;
+ //
+ // @compute @workgroup_size(1)
+ // fn f() {
+ // let x : mat2x2<f32> = s.m;
+ // }
+ ProgramBuilder b;
+ auto* S =
+ b.Structure("S", utils::Vector{
+ b.Member("m", b.ty.mat2x2<f32>(),
+ utils::Vector{
+ b.MemberOffset(16_u),
+ b.create<StrideAttribute>(8u),
+ b.Disable(DisabledValidation::kIgnoreStrideAttribute),
+ }),
+ });
+ b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kUniform, b.Group(0_a), b.Binding(0_a));
+ b.Func("f", utils::Empty, b.ty.void_(),
+ utils::Vector{
+ b.Decl(b.Let("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
+ },
+ utils::Vector{
+ b.Stage(PipelineStage::kCompute),
+ b.WorkgroupSize(1_i),
+ });
+
+ auto* expect = R"(
+struct S {
+ @size(16)
+ padding : u32,
+ /* @offset(16u) */
+ @stride(8) @internal(disable_validation__ignore_stride)
+ m : mat2x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+ let x : mat2x2<f32> = s.m;
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(Program(std::move(b)));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeStridedMatrixTest, ReadStorageMatrix) {
+ // struct S {
+ // @offset(8) @stride(32)
+ // @internal(ignore_stride_attribute)
+ // m : mat2x2<f32>,
+ // };
+ // @group(0) @binding(0) var<storage, read_write> s : S;
+ //
+ // @compute @workgroup_size(1)
+ // fn f() {
+ // let x : mat2x2<f32> = s.m;
+ // }
+ ProgramBuilder b;
+ auto* S =
+ b.Structure("S", utils::Vector{
+ b.Member("m", b.ty.mat2x2<f32>(),
+ utils::Vector{
+ b.MemberOffset(8_u),
+ b.create<StrideAttribute>(32u),
+ b.Disable(DisabledValidation::kIgnoreStrideAttribute),
+ }),
+ });
+ b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite,
+ b.Group(0_a), b.Binding(0_a));
+ b.Func("f", utils::Empty, b.ty.void_(),
+ utils::Vector{
+ b.Decl(b.Let("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
+ },
+ utils::Vector{
+ b.Stage(PipelineStage::kCompute),
+ b.WorkgroupSize(1_i),
+ });
+
+ auto* expect = R"(
+struct S {
+ @size(8)
+ padding : u32,
+ /* @offset(8) */
+ m : @stride(32) array<vec2<f32>, 2u>,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+
+fn arr_to_mat2x2_stride_32(arr : @stride(32) array<vec2<f32>, 2u>) -> mat2x2<f32> {
+ return mat2x2<f32>(arr[0u], arr[1u]);
+}
+
+@compute @workgroup_size(1i)
+fn f() {
+ let x : mat2x2<f32> = arr_to_mat2x2_stride_32(s.m);
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(Program(std::move(b)));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeStridedMatrixTest, ReadStorageColumn) {
+ // struct S {
+ // @offset(16) @stride(32)
+ // @internal(ignore_stride_attribute)
+ // m : mat2x2<f32>,
+ // };
+ // @group(0) @binding(0) var<storage, read_write> s : S;
+ //
+ // @compute @workgroup_size(1)
+ // fn f() {
+ // let x : vec2<f32> = s.m[1];
+ // }
+ ProgramBuilder b;
+ auto* S =
+ b.Structure("S", utils::Vector{
+ b.Member("m", b.ty.mat2x2<f32>(),
+ utils::Vector{
+ b.MemberOffset(16_u),
+ b.create<StrideAttribute>(32u),
+ b.Disable(DisabledValidation::kIgnoreStrideAttribute),
+ }),
+ });
+ b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite,
+ b.Group(0_a), b.Binding(0_a));
+ b.Func(
+ "f", utils::Empty, b.ty.void_(),
+ utils::Vector{
+ b.Decl(b.Let("x", b.ty.vec2<f32>(), b.IndexAccessor(b.MemberAccessor("s", "m"), 1_i))),
+ },
+ utils::Vector{
+ b.Stage(PipelineStage::kCompute),
+ b.WorkgroupSize(1_i),
+ });
+
+ auto* expect = R"(
+struct S {
+ @size(16)
+ padding : u32,
+ /* @offset(16) */
+ m : @stride(32) array<vec2<f32>, 2u>,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+ let x : vec2<f32> = s.m[1i];
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(Program(std::move(b)));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeStridedMatrixTest, WriteStorageMatrix) {
+ // struct S {
+ // @offset(8) @stride(32)
+ // @internal(ignore_stride_attribute)
+ // m : mat2x2<f32>,
+ // };
+ // @group(0) @binding(0) var<storage, read_write> s : S;
+ //
+ // @compute @workgroup_size(1)
+ // fn f() {
+ // s.m = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
+ // }
+ ProgramBuilder b;
+ auto* S =
+ b.Structure("S", utils::Vector{
+ b.Member("m", b.ty.mat2x2<f32>(),
+ utils::Vector{
+ b.MemberOffset(8_u),
+ b.create<StrideAttribute>(32u),
+ b.Disable(DisabledValidation::kIgnoreStrideAttribute),
+ }),
+ });
+ b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite,
+ b.Group(0_a), b.Binding(0_a));
+ b.Func(
+ "f", utils::Empty, b.ty.void_(),
+ utils::Vector{
+ b.Assign(b.MemberAccessor("s", "m"),
+ b.Call<mat2x2<f32>>(b.Call<vec2<f32>>(1_f, 2_f), b.Call<vec2<f32>>(3_f, 4_f))),
+ },
+ utils::Vector{
+ b.Stage(PipelineStage::kCompute),
+ b.WorkgroupSize(1_i),
+ });
+
+ auto* expect = R"(
+struct S {
+ @size(8)
+ padding : u32,
+ /* @offset(8) */
+ m : @stride(32) array<vec2<f32>, 2u>,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+
+fn mat2x2_stride_32_to_arr(m : mat2x2<f32>) -> @stride(32) array<vec2<f32>, 2u> {
+ return @stride(32) array<vec2<f32>, 2u>(m[0u], m[1u]);
+}
+
+@compute @workgroup_size(1i)
+fn f() {
+ s.m = mat2x2_stride_32_to_arr(mat2x2<f32>(vec2<f32>(1.0f, 2.0f), vec2<f32>(3.0f, 4.0f)));
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(Program(std::move(b)));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeStridedMatrixTest, WriteStorageColumn) {
+ // struct S {
+ // @offset(8) @stride(32)
+ // @internal(ignore_stride_attribute)
+ // m : mat2x2<f32>,
+ // };
+ // @group(0) @binding(0) var<storage, read_write> s : S;
+ //
+ // @compute @workgroup_size(1)
+ // fn f() {
+ // s.m[1] = vec2<f32>(1.0, 2.0);
+ // }
+ ProgramBuilder b;
+ auto* S =
+ b.Structure("S", utils::Vector{
+ b.Member("m", b.ty.mat2x2<f32>(),
+ utils::Vector{
+ b.MemberOffset(8_u),
+ b.create<StrideAttribute>(32u),
+ b.Disable(DisabledValidation::kIgnoreStrideAttribute),
+ }),
+ });
+ b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite,
+ b.Group(0_a), b.Binding(0_a));
+ b.Func(
+ "f", utils::Empty, b.ty.void_(),
+ utils::Vector{
+ b.Assign(b.IndexAccessor(b.MemberAccessor("s", "m"), 1_i), b.Call<vec2<f32>>(1_f, 2_f)),
+ },
+ utils::Vector{
+ b.Stage(PipelineStage::kCompute),
+ b.WorkgroupSize(1_i),
+ });
+
+ auto* expect = R"(
+struct S {
+ @size(8)
+ padding : u32,
+ /* @offset(8) */
+ m : @stride(32) array<vec2<f32>, 2u>,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+ s.m[1i] = vec2<f32>(1.0f, 2.0f);
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(Program(std::move(b)));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeStridedMatrixTest, ReadWriteViaPointerLets) {
+ // struct S {
+ // @offset(8) @stride(32)
+ // @internal(ignore_stride_attribute)
+ // m : mat2x2<f32>,
+ // };
+ // @group(0) @binding(0) var<storage, read_write> s : S;
+ //
+ // @compute @workgroup_size(1)
+ // fn f() {
+ // let a = &s.m;
+ // let b = &*&*(a);
+ // let x = *b;
+ // let y = (*b)[1];
+ // let z = x[1];
+ // (*b) = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
+ // (*b)[1] = vec2<f32>(5.0, 6.0);
+ // }
+ ProgramBuilder b;
+ auto* S =
+ b.Structure("S", utils::Vector{
+ b.Member("m", b.ty.mat2x2<f32>(),
+ utils::Vector{
+ b.MemberOffset(8_u),
+ b.create<StrideAttribute>(32u),
+ b.Disable(DisabledValidation::kIgnoreStrideAttribute),
+ }),
+ });
+ b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite,
+ b.Group(0_a), b.Binding(0_a));
+ b.Func("f", utils::Empty, b.ty.void_(),
+ utils::Vector{
+ b.Decl(b.Let("a", b.AddressOf(b.MemberAccessor("s", "m")))),
+ b.Decl(b.Let("b", b.AddressOf(b.Deref(b.AddressOf(b.Deref("a")))))),
+ b.Decl(b.Let("x", b.Deref("b"))),
+ b.Decl(b.Let("y", b.IndexAccessor(b.Deref("b"), 1_i))),
+ b.Decl(b.Let("z", b.IndexAccessor("x", 1_i))),
+ b.Assign(b.Deref("b"), b.Call<mat2x2<f32>>(b.Call<vec2<f32>>(1_f, 2_f),
+ b.Call<vec2<f32>>(3_f, 4_f))),
+ b.Assign(b.IndexAccessor(b.Deref("b"), 1_i), b.Call<vec2<f32>>(5_f, 6_f)),
+ },
+ utils::Vector{
+ b.Stage(PipelineStage::kCompute),
+ b.WorkgroupSize(1_i),
+ });
+
+ auto* expect = R"(
+struct S {
+ @size(8)
+ padding : u32,
+ /* @offset(8) */
+ m : @stride(32) array<vec2<f32>, 2u>,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+
+fn arr_to_mat2x2_stride_32(arr : @stride(32) array<vec2<f32>, 2u>) -> mat2x2<f32> {
+ return mat2x2<f32>(arr[0u], arr[1u]);
+}
+
+fn mat2x2_stride_32_to_arr(m : mat2x2<f32>) -> @stride(32) array<vec2<f32>, 2u> {
+ return @stride(32) array<vec2<f32>, 2u>(m[0u], m[1u]);
+}
+
+@compute @workgroup_size(1i)
+fn f() {
+ let x = arr_to_mat2x2_stride_32(s.m);
+ let y = s.m[1i];
+ let z = x[1i];
+ s.m = mat2x2_stride_32_to_arr(mat2x2<f32>(vec2<f32>(1.0f, 2.0f), vec2<f32>(3.0f, 4.0f)));
+ s.m[1i] = vec2<f32>(5.0f, 6.0f);
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(Program(std::move(b)));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeStridedMatrixTest, ReadPrivateMatrix) {
+ // struct S {
+ // @offset(8) @stride(32)
+ // @internal(ignore_stride_attribute)
+ // m : mat2x2<f32>,
+ // };
+ // var<private> s : S;
+ //
+ // @compute @workgroup_size(1)
+ // fn f() {
+ // let x : mat2x2<f32> = s.m;
+ // }
+ ProgramBuilder b;
+ auto* S =
+ b.Structure("S", utils::Vector{
+ b.Member("m", b.ty.mat2x2<f32>(),
+ utils::Vector{
+ b.MemberOffset(8_u),
+ b.create<StrideAttribute>(32u),
+ b.Disable(DisabledValidation::kIgnoreStrideAttribute),
+ }),
+ });
+ b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kPrivate);
+ b.Func("f", utils::Empty, b.ty.void_(),
+ utils::Vector{
+ b.Decl(b.Let("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
+ },
+ utils::Vector{
+ b.Stage(PipelineStage::kCompute),
+ b.WorkgroupSize(1_i),
+ });
+
+ auto* expect = R"(
+struct S {
+ @size(8)
+ padding : u32,
+ /* @offset(8u) */
+ @stride(32) @internal(disable_validation__ignore_stride)
+ m : mat2x2<f32>,
+}
+
+var<private> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+ let x : mat2x2<f32> = s.m;
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(Program(std::move(b)));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeStridedMatrixTest, WritePrivateMatrix) {
+ // struct S {
+ // @offset(8) @stride(32)
+ // @internal(ignore_stride_attribute)
+ // m : mat2x2<f32>,
+ // };
+ // var<private> s : S;
+ //
+ // @compute @workgroup_size(1)
+ // fn f() {
+ // s.m = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
+ // }
+ ProgramBuilder b;
+ auto* S =
+ b.Structure("S", utils::Vector{
+ b.Member("m", b.ty.mat2x2<f32>(),
+ utils::Vector{
+ b.MemberOffset(8_u),
+ b.create<StrideAttribute>(32u),
+ b.Disable(DisabledValidation::kIgnoreStrideAttribute),
+ }),
+ });
+ b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kPrivate);
+ b.Func(
+ "f", utils::Empty, b.ty.void_(),
+ utils::Vector{
+ b.Assign(b.MemberAccessor("s", "m"),
+ b.Call<mat2x2<f32>>(b.Call<vec2<f32>>(1_f, 2_f), b.Call<vec2<f32>>(3_f, 4_f))),
+ },
+ utils::Vector{
+ b.Stage(PipelineStage::kCompute),
+ b.WorkgroupSize(1_i),
+ });
+
+ auto* expect = R"(
+struct S {
+ @size(8)
+ padding : u32,
+ /* @offset(8u) */
+ @stride(32) @internal(disable_validation__ignore_stride)
+ m : mat2x2<f32>,
+}
+
+var<private> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+ s.m = mat2x2<f32>(vec2<f32>(1.0f, 2.0f), vec2<f32>(3.0f, 4.0f));
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(Program(std::move(b)));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/demote_to_helper.cc b/src/tint/lang/wgsl/ast/transform/demote_to_helper.cc
new file mode 100644
index 0000000..97a6670
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/demote_to_helper.cc
@@ -0,0 +1,248 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/demote_to_helper.h"
+
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+
+#include "src/tint/lang/wgsl/ast/transform/utils/hoist_to_decl_before.h"
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/block_statement.h"
+#include "src/tint/sem/call.h"
+#include "src/tint/sem/function.h"
+#include "src/tint/sem/statement.h"
+#include "src/tint/switch.h"
+#include "src/tint/type/reference.h"
+#include "src/tint/utils/map.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::DemoteToHelper);
+
+using namespace tint::number_suffixes; // NOLINT
+
+namespace tint::ast::transform {
+
+DemoteToHelper::DemoteToHelper() = default;
+
+DemoteToHelper::~DemoteToHelper() = default;
+
+Transform::ApplyResult DemoteToHelper::Apply(const Program* src, const DataMap&, DataMap&) const {
+ auto& sem = src->Sem();
+
+ // Collect the set of functions that need to be processed.
+ // A function needs to be processed if it is reachable by a shader that contains a discard at
+ // any point in its call hierarchy.
+ std::unordered_set<const sem::Function*> functions_to_process;
+ for (auto* func : src->AST().Functions()) {
+ if (!func->IsEntryPoint()) {
+ continue;
+ }
+
+ // Determine whether this entry point and its callees need to be transformed.
+ bool needs_transform = false;
+ if (sem.Get(func)->DiscardStatement()) {
+ needs_transform = true;
+ } else {
+ for (auto* callee : sem.Get(func)->TransitivelyCalledFunctions()) {
+ if (callee->DiscardStatement()) {
+ needs_transform = true;
+ break;
+ }
+ }
+ }
+ if (!needs_transform) {
+ continue;
+ }
+
+ // Process the entry point and its callees.
+ functions_to_process.insert(sem.Get(func));
+ for (auto* callee : sem.Get(func)->TransitivelyCalledFunctions()) {
+ functions_to_process.insert(callee);
+ }
+ }
+
+ if (functions_to_process.empty()) {
+ return SkipTransform;
+ }
+
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+
+ // Create a module-scope flag that indicates whether the current invocation has been discarded.
+ auto flag = b.Symbols().New("tint_discarded");
+ b.GlobalVar(flag, builtin::AddressSpace::kPrivate, b.Expr(false));
+
+ // Replace all discard statements with a statement that marks the invocation as discarded.
+ ctx.ReplaceAll(
+ [&](const DiscardStatement*) -> const Statement* { return b.Assign(flag, b.Expr(true)); });
+
+ // Insert a conditional discard at the end of each entry point that does not end with a return.
+ for (auto* func : functions_to_process) {
+ if (func->Declaration()->IsEntryPoint()) {
+ auto* sem_body = sem.Get(func->Declaration()->body);
+ if (sem_body->Behaviors().Contains(sem::Behavior::kNext)) {
+ ctx.InsertBack(func->Declaration()->body->statements,
+ b.If(flag, b.Block(b.Discard())));
+ }
+ }
+ }
+
+ HoistToDeclBefore hoist_to_decl_before(ctx);
+
+ // Mask all writes to host-visible memory using the discarded flag.
+ // We also insert a discard statement before all return statements in entry points for shaders
+ // that discard.
+ std::unordered_map<const type::Type*, Symbol> atomic_cmpxchg_result_types;
+ for (auto* node : src->ASTNodes().Objects()) {
+ Switch(
+ node,
+
+ // Mask assignments to storage buffer variables.
+ [&](const AssignmentStatement* assign) {
+ // Skip writes in functions that are not called from shaders that discard.
+ auto* func = sem.Get(assign)->Function();
+ if (functions_to_process.count(func) == 0) {
+ return;
+ }
+
+ // Skip phony assignments.
+ if (assign->lhs->Is<PhonyExpression>()) {
+ return;
+ }
+
+ // Skip writes to invocation-private address spaces.
+ auto* ref = sem.GetVal(assign->lhs)->Type()->As<type::Reference>();
+ switch (ref->AddressSpace()) {
+ case builtin::AddressSpace::kStorage:
+ // Need to mask these.
+ break;
+ case builtin::AddressSpace::kFunction:
+ case builtin::AddressSpace::kPrivate:
+ case builtin::AddressSpace::kOut:
+ // Skip these.
+ return;
+ default:
+ TINT_UNREACHABLE(Transform, b.Diagnostics())
+ << "write to unhandled address space: " << ref->AddressSpace();
+ }
+
+ // Mask the assignment using the invocation-discarded flag.
+ ctx.Replace(assign, b.If(b.Not(flag), b.Block(ctx.Clone(assign))));
+ },
+
+ // Mask builtins that write to host-visible memory.
+ [&](const CallExpression* call) {
+ auto* sem_call = sem.Get<sem::Call>(call);
+ auto* stmt = sem_call ? sem_call->Stmt() : nullptr;
+ auto* func = stmt ? stmt->Function() : nullptr;
+ auto* builtin = sem_call ? sem_call->Target()->As<sem::Builtin>() : nullptr;
+ if (functions_to_process.count(func) == 0 || !builtin) {
+ return;
+ }
+
+ if (builtin->Type() == builtin::Function::kTextureStore) {
+ // A call to textureStore() will always be a statement.
+ // Wrap it inside a conditional block.
+ auto* masked_call = b.If(b.Not(flag), b.Block(ctx.Clone(stmt->Declaration())));
+ ctx.Replace(stmt->Declaration(), masked_call);
+ } else if (builtin->IsAtomic() &&
+ builtin->Type() != builtin::Function::kAtomicLoad) {
+ // A call to an atomic builtin can be a statement or an expression.
+ if (auto* call_stmt = stmt->Declaration()->As<CallStatement>();
+ call_stmt && call_stmt->expr == call) {
+ // This call is a statement.
+ // Wrap it inside a conditional block.
+ auto* masked_call = b.If(b.Not(flag), b.Block(ctx.Clone(call_stmt)));
+ ctx.Replace(stmt->Declaration(), masked_call);
+ } else {
+ // This call is an expression.
+ // We transform:
+ // let y = x + atomicAdd(&p, 1);
+ // Into:
+ // var tmp : i32;
+ // if (!tint_discarded) {
+ // tmp = atomicAdd(&p, 1);
+ // }
+ // let y = x + tmp;
+ auto result = b.Sym();
+ Type result_ty;
+ const Statement* masked_call = nullptr;
+ if (builtin->Type() == builtin::Function::kAtomicCompareExchangeWeak) {
+ // Special case for atomicCompareExchangeWeak as we cannot name its
+ // result type. We have to declare an equivalent struct and copy the
+ // original member values over to it.
+
+ // Declare a struct to hold the result values.
+ auto* result_struct = sem_call->Type()->As<type::Struct>();
+ auto* atomic_ty = result_struct->Members()[0]->Type();
+ result_ty = b.ty(
+ utils::GetOrCreate(atomic_cmpxchg_result_types, atomic_ty, [&] {
+ auto name = b.Sym();
+ b.Structure(
+ name,
+ utils::Vector{
+ b.Member("old_value", CreateASTTypeFor(ctx, atomic_ty)),
+ b.Member("exchanged", b.ty.bool_()),
+ });
+ return name;
+ }));
+
+ // Generate the masked call and member-wise copy:
+ // if (!tint_discarded) {
+ // let tmp_result = atomicCompareExchangeWeak(&p, 1, 2);
+ // result.old_value = tmp_result.old_value;
+ // result.exchanged = tmp_result.exchanged;
+ // }
+ auto tmp_result = b.Sym();
+ masked_call =
+ b.If(b.Not(flag),
+ b.Block(utils::Vector{
+ b.Decl(b.Let(tmp_result, ctx.CloneWithoutTransform(call))),
+ b.Assign(b.MemberAccessor(result, "old_value"),
+ b.MemberAccessor(tmp_result, "old_value")),
+ b.Assign(b.MemberAccessor(result, "exchanged"),
+ b.MemberAccessor(tmp_result, "exchanged")),
+ }));
+ } else {
+ result_ty = CreateASTTypeFor(ctx, sem_call->Type());
+ masked_call =
+ b.If(b.Not(flag),
+ b.Block(b.Assign(result, ctx.CloneWithoutTransform(call))));
+ }
+ auto* result_decl = b.Decl(b.Var(result, result_ty));
+ hoist_to_decl_before.Prepare(sem_call);
+ hoist_to_decl_before.InsertBefore(stmt, result_decl);
+ hoist_to_decl_before.InsertBefore(stmt, masked_call);
+ ctx.Replace(call, b.Expr(result));
+ }
+ }
+ },
+
+ // Insert a conditional discard before all return statements in entry points.
+ [&](const ReturnStatement* ret) {
+ auto* func = sem.Get(ret)->Function();
+ if (func->Declaration()->IsEntryPoint() && functions_to_process.count(func)) {
+ auto* discard = b.If(flag, b.Block(b.Discard()));
+ ctx.InsertBefore(sem.Get(ret)->Block()->Declaration()->statements, ret,
+ discard);
+ }
+ });
+ }
+
+ ctx.Clone();
+ return Program(std::move(b));
+}
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/demote_to_helper.h b/src/tint/lang/wgsl/ast/transform/demote_to_helper.h
new file mode 100644
index 0000000..3f1463f
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/demote_to_helper.h
@@ -0,0 +1,47 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_DEMOTE_TO_HELPER_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_DEMOTE_TO_HELPER_H_
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+
+namespace tint::ast::transform {
+
+/// Implement demote-to-helper semantics for discard statements.
+///
+/// For backend targets that implement discard by terminating the invocation, we need to change the
+/// program to ensure that discarding the fragment does not affect uniformity with respect to
+/// derivative operations. We do this by setting a global flag and masking all writes to storage
+/// buffers and textures.
+///
+/// @note Depends on the following transforms to have been run first:
+/// * PromoteSideEffectsToDecl
+/// * ExpandCompoundAssignment
+class DemoteToHelper final : public utils::Castable<DemoteToHelper, Transform> {
+ public:
+ /// Constructor
+ DemoteToHelper();
+ /// Destructor
+ ~DemoteToHelper() override;
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_DEMOTE_TO_HELPER_H_
diff --git a/src/tint/lang/wgsl/ast/transform/demote_to_helper_test.cc b/src/tint/lang/wgsl/ast/transform/demote_to_helper_test.cc
new file mode 100644
index 0000000..526895f
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/demote_to_helper_test.cc
@@ -0,0 +1,1205 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/demote_to_helper.h"
+
+#include <utility>
+
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using DemoteToHelperTest = TransformTest;  // Fixture alias used by the TEST_F cases below.
+
+TEST_F(DemoteToHelperTest, ShouldRunEmptyModule) {
+ auto* src = R"()";
+
+ EXPECT_FALSE(ShouldRun<DemoteToHelper>(src));  // Empty module: nothing to demote.
+}
+
+TEST_F(DemoteToHelperTest, ShouldRunNoDiscard) {
+ auto* src = R"(
+@group(0) @binding(0)
+var<storage, read_write> v : f32;
+
+@fragment
+fn foo() {
+ v = 42;
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<DemoteToHelper>(src));  // No discard statement in the module.
+}
+
+TEST_F(DemoteToHelperTest, ShouldRunDiscardInEntryPoint) {
+ auto* src = R"(
+@group(0) @binding(0)
+var<storage, read_write> v : f32;
+
+@fragment
+fn foo() {
+ discard;
+ v = 42;
+}
+)";
+
+ EXPECT_TRUE(ShouldRun<DemoteToHelper>(src));  // Discard directly in the entry point.
+}
+
+TEST_F(DemoteToHelperTest, ShouldRunDiscardInHelper) {
+ auto* src = R"(
+@group(0) @binding(0)
+var<storage, read_write> v : f32;
+
+fn bar() {
+ discard;
+}
+
+@fragment
+fn foo() {
+ bar();
+ v = 42;
+}
+)";
+
+ EXPECT_TRUE(ShouldRun<DemoteToHelper>(src));  // Discard only in a helper foo() calls.
+}
+
+TEST_F(DemoteToHelperTest, EmptyModule) {
+ auto* src = R"()";
+
+ auto* expect = src;  // An empty module must pass through unchanged.
+
+ auto got = Run<DemoteToHelper>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DemoteToHelperTest, WriteInEntryPoint_DiscardInEntryPoint) {
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+@group(0) @binding(2) var<storage, read_write> v : f32;
+
+@fragment
+fn foo(@location(0) in : f32, @location(1) coord : vec2<f32>) {
+ if (in == 0.0) {
+ discard;
+ }
+ let ret = textureSample(t, s, coord);
+ v = ret.x;
+}
+)";
+
+ auto* expect = R"(
+var<private> tint_discarded = false;
+
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+@group(0) @binding(2) var<storage, read_write> v : f32;
+
+@fragment
+fn foo(@location(0) in : f32, @location(1) coord : vec2<f32>) {
+ if ((in == 0.0)) {
+ tint_discarded = true;
+ }
+ let ret = textureSample(t, s, coord);
+ if (!(tint_discarded)) {
+ v = ret.x;
+ }
+ if (tint_discarded) {
+ discard;
+ }
+}
+)";
+
+ auto got = Run<DemoteToHelper>(src);  // discard -> flag write, real discard at exit
+
+ EXPECT_EQ(expect, str(got));  // The store to `v` must be guarded by the flag.
+}
+
+TEST_F(DemoteToHelperTest, WriteInEntryPoint_DiscardInHelper) {
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+@group(0) @binding(2) var<storage, read_write> v : f32;
+
+fn bar() {
+ discard;
+}
+
+@fragment
+fn foo(@location(0) in : f32, @location(1) coord : vec2<f32>) {
+ if (in == 0.0) {
+ bar();
+ }
+ let ret = textureSample(t, s, coord);
+ v = ret.x;
+}
+)";
+
+ auto* expect = R"(
+var<private> tint_discarded = false;
+
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+@group(0) @binding(2) var<storage, read_write> v : f32;
+
+fn bar() {
+ tint_discarded = true;
+}
+
+@fragment
+fn foo(@location(0) in : f32, @location(1) coord : vec2<f32>) {
+ if ((in == 0.0)) {
+ bar();
+ }
+ let ret = textureSample(t, s, coord);
+ if (!(tint_discarded)) {
+ v = ret.x;
+ }
+ if (tint_discarded) {
+ discard;
+ }
+}
+)";
+
+ auto got = Run<DemoteToHelper>(src);  // bar()'s discard becomes a flag write.
+
+ EXPECT_EQ(expect, str(got));  // foo masks its store and discards at exit.
+}
+
+TEST_F(DemoteToHelperTest, WriteInHelper_DiscardInEntryPoint) {
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+@group(0) @binding(2) var<storage, read_write> v : f32;
+
+fn bar(coord : vec2<f32>) {
+ let ret = textureSample(t, s, coord);
+ v = ret.x;
+}
+
+@fragment
+fn foo(@location(0) in : f32, @location(1) coord : vec2<f32>) {
+ if (in == 0.0) {
+ discard;
+ }
+ bar(coord);
+}
+)";
+
+ auto* expect = R"(
+var<private> tint_discarded = false;
+
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+@group(0) @binding(2) var<storage, read_write> v : f32;
+
+fn bar(coord : vec2<f32>) {
+ let ret = textureSample(t, s, coord);
+ if (!(tint_discarded)) {
+ v = ret.x;
+ }
+}
+
+@fragment
+fn foo(@location(0) in : f32, @location(1) coord : vec2<f32>) {
+ if ((in == 0.0)) {
+ tint_discarded = true;
+ }
+ bar(coord);
+ if (tint_discarded) {
+ discard;
+ }
+}
+)";
+
+ auto got = Run<DemoteToHelper>(src);  // foo's discard becomes a flag write.
+
+ EXPECT_EQ(expect, str(got));  // The store inside bar() is masked on the flag.
+}
+
+TEST_F(DemoteToHelperTest, WriteInHelper_DiscardInHelper) {
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+@group(0) @binding(2) var<storage, read_write> v : f32;
+
+fn bar(in : f32, coord : vec2<f32>) {
+ if (in == 0.0) {
+ discard;
+ }
+ let ret = textureSample(t, s, coord);
+ v = ret.x;
+}
+
+@fragment
+fn foo(@location(0) in : f32, @location(1) coord : vec2<f32>) {
+ bar(in, coord);
+}
+)";
+
+ auto* expect = R"(
+var<private> tint_discarded = false;
+
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+@group(0) @binding(2) var<storage, read_write> v : f32;
+
+fn bar(in : f32, coord : vec2<f32>) {
+ if ((in == 0.0)) {
+ tint_discarded = true;
+ }
+ let ret = textureSample(t, s, coord);
+ if (!(tint_discarded)) {
+ v = ret.x;
+ }
+}
+
+@fragment
+fn foo(@location(0) in : f32, @location(1) coord : vec2<f32>) {
+ bar(in, coord);
+ if (tint_discarded) {
+ discard;
+ }
+}
+)";
+
+ auto got = Run<DemoteToHelper>(src);  // Flag write and masked store both in bar().
+
+ EXPECT_EQ(expect, str(got));  // foo only gains the conditional discard after the call.
+}
+
+TEST_F(DemoteToHelperTest, WriteInEntryPoint_NoDiscard) {
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+@group(0) @binding(2) var<storage, read_write> v : f32;
+
+@fragment
+fn foo(@location(0) in : f32, @location(1) coord : vec2<f32>) {
+ let ret = textureSample(t, s, coord);
+ v = ret.x;
+}
+)";
+
+ auto* expect = src;  // No discard, so the store must stay unmasked.
+
+ auto got = Run<DemoteToHelper>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+// Test that the conditional discard is inserted once, before the return in the nested block, and
+// that no further discard is appended after the block at the end of the function.
+TEST_F(DemoteToHelperTest, EntryPointReturn_NestedInBlock) {
+ auto* src = R"(
+@fragment
+fn foo() {
+ {
+ discard;
+ return;
+ }
+}
+)";
+
+ auto* expect = R"(
+var<private> tint_discarded = false;
+
+@fragment
+fn foo() {
+ {
+ tint_discarded = true;
+ if (tint_discarded) {
+ discard;
+ }
+ return;
+ }
+}
+)";
+
+ auto got = Run<DemoteToHelper>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+// Test that a flag-guarded discard is inserted before every return statement in an entry point
+// that itself contains a discard statement.
+TEST_F(DemoteToHelperTest, EntryPointReturns_DiscardInEntryPoint) {
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+@group(0) @binding(2) var<storage, read_write> v : f32;
+
+@fragment
+fn foo(@location(0) in : f32, @location(1) coord : vec2<f32>) -> @location(0) f32 {
+ if (in == 0.0) {
+ discard;
+ }
+ let ret = textureSample(t, s, coord);
+ if (in < 1.0) {
+ return ret.x;
+ }
+ return 2.0;
+}
+)";
+
+ auto* expect = R"(
+var<private> tint_discarded = false;
+
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+@group(0) @binding(2) var<storage, read_write> v : f32;
+
+@fragment
+fn foo(@location(0) in : f32, @location(1) coord : vec2<f32>) -> @location(0) f32 {
+ if ((in == 0.0)) {
+ tint_discarded = true;
+ }
+ let ret = textureSample(t, s, coord);
+ if ((in < 1.0)) {
+ if (tint_discarded) {
+ discard;
+ }
+ return ret.x;
+ }
+ if (tint_discarded) {
+ discard;
+ }
+ return 2.0;
+}
+)";
+
+ auto got = Run<DemoteToHelper>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+// Test that a flag-guarded discard is inserted before every return statement in an entry point
+// that calls a function that contains a discard.
+TEST_F(DemoteToHelperTest, EntryPointReturns_DiscardInHelper) {
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+@group(0) @binding(2) var<storage, read_write> v : f32;
+
+fn bar() {
+ discard;
+}
+
+@fragment
+fn foo(@location(0) in : f32, @location(1) coord : vec2<f32>) -> @location(0) f32 {
+ if (in == 0.0) {
+ bar();
+ }
+ let ret = textureSample(t, s, coord);
+ if (in < 1.0) {
+ return ret.x;
+ }
+ return 2.0;
+}
+)";
+
+ auto* expect = R"(
+var<private> tint_discarded = false;
+
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+@group(0) @binding(2) var<storage, read_write> v : f32;
+
+fn bar() {
+ tint_discarded = true;
+}
+
+@fragment
+fn foo(@location(0) in : f32, @location(1) coord : vec2<f32>) -> @location(0) f32 {
+ if ((in == 0.0)) {
+ bar();
+ }
+ let ret = textureSample(t, s, coord);
+ if ((in < 1.0)) {
+ if (tint_discarded) {
+ discard;
+ }
+ return ret.x;
+ }
+ if (tint_discarded) {
+ discard;
+ }
+ return 2.0;
+}
+)";
+
+ auto got = Run<DemoteToHelper>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+// Test that returns are left unmodified when neither the entry point nor its callee discards.
+TEST_F(DemoteToHelperTest, EntryPointReturns_NoDiscard) {
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+@group(0) @binding(2) var<storage, read_write> v : f32;
+
+fn bar() {
+}
+
+@fragment
+fn foo(@location(0) in : f32, @location(1) coord : vec2<f32>) -> @location(0) f32 {
+ if ((in == 0.0)) {
+ bar();
+ }
+ let ret = textureSample(t, s, coord);
+ if ((in < 1.0)) {
+ return ret.x;
+ }
+ return 2.0;
+}
+)";
+
+ auto* expect = src;
+
+ auto got = Run<DemoteToHelper>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+// Test that only functions reachable from an entry point that discards are transformed.
+// Functions used only by non-discarding entry points should not have their writes masked, and
+// non-discarding entry points should not get a conditional discard before their returns.
+TEST_F(DemoteToHelperTest, MultipleShaders) {
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+@group(0) @binding(2) var<storage, read_write> v1 : f32;
+
+@group(0) @binding(3) var<storage, read_write> v2 : f32;
+
+fn bar_discard(in : f32, coord : vec2<f32>) -> f32 {
+ let ret = textureSample(t, s, coord);
+ v1 = ret.x * 2.0;
+ return ret.y * 2.0;
+}
+
+fn bar_no_discard(in : f32, coord : vec2<f32>) -> f32 {
+ let ret = textureSample(t, s, coord);
+ v1 = ret.x * 2.0;
+ return ret.y * 2.0;
+}
+
+@fragment
+fn foo_discard(@location(0) in : f32, @location(1) coord : vec2<f32>) {
+ if (in == 0.0) {
+ discard;
+ }
+ let ret = bar_discard(in, coord);
+ v2 = ret;
+}
+
+@fragment
+fn foo_no_discard(@location(0) in : f32, @location(1) coord : vec2<f32>) {
+ let ret = bar_no_discard(in, coord);
+ if (in == 0.0) {
+ return;
+ }
+ v2 = ret;
+}
+)";
+
+ auto* expect = R"(
+var<private> tint_discarded = false;
+
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+@group(0) @binding(2) var<storage, read_write> v1 : f32;
+
+@group(0) @binding(3) var<storage, read_write> v2 : f32;
+
+fn bar_discard(in : f32, coord : vec2<f32>) -> f32 {
+ let ret = textureSample(t, s, coord);
+ if (!(tint_discarded)) {
+ v1 = (ret.x * 2.0);
+ }
+ return (ret.y * 2.0);
+}
+
+fn bar_no_discard(in : f32, coord : vec2<f32>) -> f32 {
+ let ret = textureSample(t, s, coord);
+ v1 = (ret.x * 2.0);
+ return (ret.y * 2.0);
+}
+
+@fragment
+fn foo_discard(@location(0) in : f32, @location(1) coord : vec2<f32>) {
+ if ((in == 0.0)) {
+ tint_discarded = true;
+ }
+ let ret = bar_discard(in, coord);
+ if (!(tint_discarded)) {
+ v2 = ret;
+ }
+ if (tint_discarded) {
+ discard;
+ }
+}
+
+@fragment
+fn foo_no_discard(@location(0) in : f32, @location(1) coord : vec2<f32>) {
+ let ret = bar_no_discard(in, coord);
+ if ((in == 0.0)) {
+ return;
+ }
+ v2 = ret;
+}
+)";
+
+ auto got = Run<DemoteToHelper>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+// Test that writes to function-scope and private variables are left unmasked after a discard.
+TEST_F(DemoteToHelperTest, InvocationPrivateWrites) {
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+var<private> vp : f32;
+
+@fragment
+fn foo(@location(0) in : f32, @location(1) coord : vec2<f32>) {
+ if (in == 0.0) {
+ discard;
+ }
+ let ret = textureSample(t, s, coord);
+ var vf : f32;
+ vf = ret.x;
+ vp = ret.y;
+}
+)";
+
+ auto* expect = R"(
+var<private> tint_discarded = false;
+
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+var<private> vp : f32;
+
+@fragment
+fn foo(@location(0) in : f32, @location(1) coord : vec2<f32>) {
+ if ((in == 0.0)) {
+ tint_discarded = true;
+ }
+ let ret = textureSample(t, s, coord);
+ var vf : f32;
+ vf = ret.x;
+ vp = ret.y;
+ if (tint_discarded) {
+ discard;
+ }
+}
+)";
+
+ auto got = Run<DemoteToHelper>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DemoteToHelperTest, TextureStoreInEntryPoint) {
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+@group(0) @binding(2) var<storage, read_write> v : f32;
+
+@group(0) @binding(3) var t2 : texture_storage_2d<rgba8unorm, write>;
+
+@fragment
+fn foo(@location(0) in : f32, @location(1) coord : vec2<f32>) {
+ if (in == 0.0) {
+ discard;
+ }
+ let ret = textureSample(t, s, coord);
+ textureStore(t2, vec2<u32>(coord), ret);
+}
+)";
+
+ auto* expect = R"(
+var<private> tint_discarded = false;
+
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+@group(0) @binding(2) var<storage, read_write> v : f32;
+
+@group(0) @binding(3) var t2 : texture_storage_2d<rgba8unorm, write>;
+
+@fragment
+fn foo(@location(0) in : f32, @location(1) coord : vec2<f32>) {
+ if ((in == 0.0)) {
+ tint_discarded = true;
+ }
+ let ret = textureSample(t, s, coord);
+ if (!(tint_discarded)) {
+ textureStore(t2, vec2<u32>(coord), ret);
+ }
+ if (tint_discarded) {
+ discard;
+ }
+}
+)";
+
+ auto got = Run<DemoteToHelper>(src);
+
+ EXPECT_EQ(expect, str(got));  // textureStore() must be guarded by the flag.
+}
+
+TEST_F(DemoteToHelperTest, TextureStoreInHelper) {
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+@group(0) @binding(2) var<storage, read_write> v : f32;
+
+@group(0) @binding(3) var t2 : texture_storage_2d<rgba8unorm, write>;
+
+fn bar(coord : vec2<f32>, value : vec4<f32>) {
+ textureStore(t2, vec2<u32>(coord), value);
+}
+
+@fragment
+fn foo(@location(0) in : f32, @location(1) coord : vec2<f32>) {
+ if (in == 0.0) {
+ discard;
+ }
+ let ret = textureSample(t, s, coord);
+ bar(coord, ret);
+}
+)";
+
+ auto* expect = R"(
+var<private> tint_discarded = false;
+
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+@group(0) @binding(2) var<storage, read_write> v : f32;
+
+@group(0) @binding(3) var t2 : texture_storage_2d<rgba8unorm, write>;
+
+fn bar(coord : vec2<f32>, value : vec4<f32>) {
+ if (!(tint_discarded)) {
+ textureStore(t2, vec2<u32>(coord), value);
+ }
+}
+
+@fragment
+fn foo(@location(0) in : f32, @location(1) coord : vec2<f32>) {
+ if ((in == 0.0)) {
+ tint_discarded = true;
+ }
+ let ret = textureSample(t, s, coord);
+ bar(coord, ret);
+ if (tint_discarded) {
+ discard;
+ }
+}
+)";
+
+ auto got = Run<DemoteToHelper>(src);
+
+ EXPECT_EQ(expect, str(got));  // textureStore() in bar() must be guarded by the flag.
+}
+
+TEST_F(DemoteToHelperTest, TextureStore_NoDiscard) {
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+@group(0) @binding(2) var<storage, read_write> v : f32;
+
+@group(0) @binding(3) var t2 : texture_storage_2d<rgba8unorm, write>;
+
+@fragment
+fn foo(@location(0) in : f32, @location(1) coord : vec2<f32>) {
+ let ret = textureSample(t, s, coord);
+ textureStore(t2, vec2<u32>(coord), ret);
+}
+)";
+
+ auto* expect = src;  // No discard, so textureStore() stays unmasked.
+
+ auto got = Run<DemoteToHelper>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DemoteToHelperTest, AtomicStoreInEntryPoint) {
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+@group(0) @binding(2) var<storage, read_write> v : f32;
+
+@group(0) @binding(3) var<storage, read_write> a : atomic<i32>;
+
+@fragment
+fn foo(@location(0) in : f32, @location(1) coord : vec2<f32>) {
+ if (in == 0.0) {
+ discard;
+ }
+ let ret = textureSample(t, s, coord);
+ atomicStore(&a, i32(ret.x));
+}
+)";
+
+ auto* expect = R"(
+var<private> tint_discarded = false;
+
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+@group(0) @binding(2) var<storage, read_write> v : f32;
+
+@group(0) @binding(3) var<storage, read_write> a : atomic<i32>;
+
+@fragment
+fn foo(@location(0) in : f32, @location(1) coord : vec2<f32>) {
+ if ((in == 0.0)) {
+ tint_discarded = true;
+ }
+ let ret = textureSample(t, s, coord);
+ if (!(tint_discarded)) {
+ atomicStore(&(a), i32(ret.x));
+ }
+ if (tint_discarded) {
+ discard;
+ }
+}
+)";
+
+ auto got = Run<DemoteToHelper>(src);
+
+ EXPECT_EQ(expect, str(got));  // atomicStore() must be guarded by the flag.
+}
+
+TEST_F(DemoteToHelperTest, AtomicStoreInHelper) {
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+@group(0) @binding(2) var<storage, read_write> v : f32;
+
+@group(0) @binding(3) var<storage, read_write> a : atomic<i32>;
+
+fn bar(value : vec4<f32>) {
+ atomicStore(&a, i32(value.x));
+}
+
+@fragment
+fn foo(@location(0) in : f32, @location(1) coord : vec2<f32>) {
+ if (in == 0.0) {
+ discard;
+ }
+ let ret = textureSample(t, s, coord);
+ bar(ret);
+}
+)";
+
+ auto* expect = R"(
+var<private> tint_discarded = false;
+
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+@group(0) @binding(2) var<storage, read_write> v : f32;
+
+@group(0) @binding(3) var<storage, read_write> a : atomic<i32>;
+
+fn bar(value : vec4<f32>) {
+ if (!(tint_discarded)) {
+ atomicStore(&(a), i32(value.x));
+ }
+}
+
+@fragment
+fn foo(@location(0) in : f32, @location(1) coord : vec2<f32>) {
+ if ((in == 0.0)) {
+ tint_discarded = true;
+ }
+ let ret = textureSample(t, s, coord);
+ bar(ret);
+ if (tint_discarded) {
+ discard;
+ }
+}
+)";
+
+ auto got = Run<DemoteToHelper>(src);
+
+ EXPECT_EQ(expect, str(got));  // atomicStore() in bar() must be guarded by the flag.
+}
+
+TEST_F(DemoteToHelperTest, AtomicStore_NoDiscard) {
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+@group(0) @binding(2) var<storage, read_write> v : f32;
+
+@group(0) @binding(3) var<storage, read_write> a : atomic<i32>;
+
+@fragment
+fn foo(@location(0) in : f32, @location(1) coord : vec2<f32>) {
+ let ret = textureSample(t, s, coord);
+ atomicStore(&(a), i32(ret.x));
+}
+)";
+
+ auto* expect = src;  // No discard, so atomicStore() stays unmasked.
+
+ auto got = Run<DemoteToHelper>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DemoteToHelperTest, AtomicBuiltinExpression) {
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+@group(0) @binding(2) var<storage, read_write> v : f32;
+
+@group(0) @binding(3) var<storage, read_write> a : atomic<i32>;
+
+@fragment
+fn foo(@location(0) in : f32, @location(1) coord : vec2<f32>) -> @location(0) i32 {
+ if (in == 0.0) {
+ discard;
+ }
+ let v = i32(textureSample(t, s, coord).x);
+ let result = v + atomicAdd(&a, v);
+ return result;
+}
+)";
+
+ auto* expect = R"(
+var<private> tint_discarded = false;
+
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+@group(0) @binding(2) var<storage, read_write> v : f32;
+
+@group(0) @binding(3) var<storage, read_write> a : atomic<i32>;
+
+@fragment
+fn foo(@location(0) in : f32, @location(1) coord : vec2<f32>) -> @location(0) i32 {
+ if ((in == 0.0)) {
+ tint_discarded = true;
+ }
+ let v = i32(textureSample(t, s, coord).x);
+ var tint_symbol : i32;
+ if (!(tint_discarded)) {
+ tint_symbol = atomicAdd(&(a), v);
+ }
+ let result = (v + tint_symbol);
+ if (tint_discarded) {
+ discard;
+ }
+ return result;
+}
+)";
+
+ auto got = Run<DemoteToHelper>(src);  // atomicAdd() result is used in an expression.
+
+ EXPECT_EQ(expect, str(got));  // The call is hoisted into a guarded assignment to a temporary.
+}
+
+TEST_F(DemoteToHelperTest, AtomicBuiltinExpression_InForLoopContinuing) {
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+@group(0) @binding(2) var<storage, read_write> a : atomic<i32>;
+
+@fragment
+fn foo(@location(0) in : f32, @location(1) coord : vec2<f32>) -> @location(0) i32 {
+ if (in == 0.0) {
+ discard;
+ }
+ var result = 0;
+ for (var i = 0; i < 10; i = atomicAdd(&a, 1)) {
+ result += i;
+ }
+ result += i32(textureSample(t, s, coord).x);
+ return result;
+}
+)";
+
+ auto* expect = R"(
+var<private> tint_discarded = false;
+
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+@group(0) @binding(2) var<storage, read_write> a : atomic<i32>;
+
+@fragment
+fn foo(@location(0) in : f32, @location(1) coord : vec2<f32>) -> @location(0) i32 {
+ if ((in == 0.0)) {
+ tint_discarded = true;
+ }
+ var result = 0;
+ {
+ var i = 0;
+ loop {
+ if (!((i < 10))) {
+ break;
+ }
+ {
+ result += i;
+ }
+
+ continuing {
+ var tint_symbol : i32;
+ if (!(tint_discarded)) {
+ tint_symbol = atomicAdd(&(a), 1);
+ }
+ i = tint_symbol;
+ }
+ }
+ }
+ result += i32(textureSample(t, s, coord).x);
+ if (tint_discarded) {
+ discard;
+ }
+ return result;
+}
+)";
+
+ auto got = Run<DemoteToHelper>(src);  // atomicAdd() in the for-loop continuing statement.
+
+ EXPECT_EQ(expect, str(got));  // Loop becomes `loop {}` with the guarded call in continuing.
+}
+
+TEST_F(DemoteToHelperTest, AtomicCompareExchangeWeak) {
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+@group(0) @binding(2) var<storage, read_write> a : atomic<i32>;
+
+@fragment
+fn foo(@location(0) in : f32, @location(1) coord : vec2<f32>) -> @location(0) i32 {
+ if (in == 0.0) {
+ discard;
+ }
+ var result = 0;
+ if (!atomicCompareExchangeWeak(&a, i32(in), 42).exchanged) {
+ let xchg = atomicCompareExchangeWeak(&a, i32(in), 42);
+ result = xchg.old_value;
+ }
+ result += i32(textureSample(t, s, coord).x);
+ return result;
+}
+)";
+
+ auto* expect = R"(
+var<private> tint_discarded = false;
+
+struct tint_symbol_1 {
+ old_value : i32,
+ exchanged : bool,
+}
+
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+@group(0) @binding(2) var<storage, read_write> a : atomic<i32>;
+
+@fragment
+fn foo(@location(0) in : f32, @location(1) coord : vec2<f32>) -> @location(0) i32 {
+ if ((in == 0.0)) {
+ tint_discarded = true;
+ }
+ var result = 0;
+ var tint_symbol : tint_symbol_1;
+ if (!(tint_discarded)) {
+ let tint_symbol_2 = atomicCompareExchangeWeak(&(a), i32(in), 42);
+ tint_symbol.old_value = tint_symbol_2.old_value;
+ tint_symbol.exchanged = tint_symbol_2.exchanged;
+ }
+ if (!(tint_symbol.exchanged)) {
+ var tint_symbol_3 : tint_symbol_1;
+ if (!(tint_discarded)) {
+ let tint_symbol_4 = atomicCompareExchangeWeak(&(a), i32(in), 42);
+ tint_symbol_3.old_value = tint_symbol_4.old_value;
+ tint_symbol_3.exchanged = tint_symbol_4.exchanged;
+ }
+ let xchg = tint_symbol_3;
+ result = xchg.old_value;
+ }
+ result += i32(textureSample(t, s, coord).x);
+ if (tint_discarded) {
+ discard;
+ }
+ return result;
+}
+)";
+
+ auto got = Run<DemoteToHelper>(src);  // Result struct is used in a condition and a let.
+
+ EXPECT_EQ(expect, str(got));  // Each call fills a temp struct member-wise under the flag.
+}
+
+// Test that `atomicLoad()` calls are left unmasked even when the entry point discards.
+TEST_F(DemoteToHelperTest, AtomicLoad) {
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+@group(0) @binding(2) var<storage, read_write> v : f32;
+
+@group(0) @binding(3) var<storage, read_write> a : atomic<i32>;
+
+@fragment
+fn foo(@location(0) in : f32, @location(1) coord : vec2<f32>) -> @location(0) i32 {
+ if (in == 0.0) {
+ discard;
+ }
+ let v = i32(textureSample(t, s, coord).x);
+ let result = v + atomicLoad(&a);
+ return result;
+}
+)";
+
+ auto* expect = R"(
+var<private> tint_discarded = false;
+
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+@group(0) @binding(2) var<storage, read_write> v : f32;
+
+@group(0) @binding(3) var<storage, read_write> a : atomic<i32>;
+
+@fragment
+fn foo(@location(0) in : f32, @location(1) coord : vec2<f32>) -> @location(0) i32 {
+ if ((in == 0.0)) {
+ tint_discarded = true;
+ }
+ let v = i32(textureSample(t, s, coord).x);
+ let result = (v + atomicLoad(&(a)));
+ if (tint_discarded) {
+ discard;
+ }
+ return result;
+}
+)";
+
+ auto got = Run<DemoteToHelper>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DemoteToHelperTest, PhonyAssignment) {
+ auto* src = R"(
+@fragment
+fn foo(@location(0) in : f32) {
+ if (in == 0.0) {
+ discard;
+ }
+ _ = in;
+}
+)";
+
+ auto* expect = R"(
+var<private> tint_discarded = false;
+
+@fragment
+fn foo(@location(0) in : f32) {
+ if ((in == 0.0)) {
+ tint_discarded = true;
+ }
+ _ = in;
+ if (tint_discarded) {
+ discard;
+ }
+}
+)";
+
+ auto got = Run<DemoteToHelper>(src);
+
+ EXPECT_EQ(expect, str(got));  // The phony assignment `_ = in` must not be masked.
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/direct_variable_access.cc b/src/tint/lang/wgsl/ast/transform/direct_variable_access.cc
new file mode 100644
index 0000000..2e39b11
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/direct_variable_access.cc
@@ -0,0 +1,1218 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/direct_variable_access.h"
+
+#include <algorithm>
+#include <string>
+#include <utility>
+
+#include "src/tint/lang/wgsl/ast/transform/utils/hoist_to_decl_before.h"
+#include "src/tint/lang/wgsl/ast/traverse_expressions.h"
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/call.h"
+#include "src/tint/sem/function.h"
+#include "src/tint/sem/index_accessor_expression.h"
+#include "src/tint/sem/member_accessor_expression.h"
+#include "src/tint/sem/module.h"
+#include "src/tint/sem/statement.h"
+#include "src/tint/sem/struct.h"
+#include "src/tint/sem/variable.h"
+#include "src/tint/type/abstract_int.h"
+#include "src/tint/utils/reverse.h"
+#include "src/tint/utils/scoped_assignment.h"
+#include "src/tint/utils/string_stream.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::DirectVariableAccess);
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::DirectVariableAccess::Config);
+
+using namespace tint::number_suffixes; // NOLINT
+
+namespace {
+
+/// AccessRoot describes the root of an AccessShape.
+struct AccessRoot {
+ /// The pointer-unwrapped type of the *transformed* variable.
+ /// This may be different for pointers in 'private' and 'function' address space, as the pointer
+ /// parameter type is to the *base object* instead of the input pointer type.
+ tint::type::Type const* type = nullptr;
+ /// The originating module-scope variable ('private', 'storage', 'uniform', 'workgroup'),
+ /// function-scope variable ('function'), or pointer parameter in the source program.
+ tint::sem::Variable const* variable = nullptr;
+ /// The address space of the variable or pointer type.
+ tint::builtin::AddressSpace address_space = tint::builtin::AddressSpace::kUndefined;
+};
+
+/// Inequality operator for AccessRoot
+bool operator!=(const AccessRoot& a, const AccessRoot& b) {
+ return a.type != b.type || a.variable != b.variable;
+}
+
+/// DynamicIndex is used by DirectVariableAccess::State::AccessOp to indicate an array, matrix or
+/// vector index.
+struct DynamicIndex {
+ /// The index of the expression in DirectVariableAccess::State::AccessChain::dynamic_indices
+ size_t slot = 0;
+};
+
+/// Inequality operator for DynamicIndex
+bool operator!=(const DynamicIndex& a, const DynamicIndex& b) {
+ return a.slot != b.slot;
+}
+
+/// AccessOp describes a single access in an access chain.
+/// The access is one of:
+/// Symbol - a struct member access.
+/// DynamicIndex - a runtime index on an array, matrix column, or vector element.
+using AccessOp = std::variant<tint::Symbol, DynamicIndex>;
+
+/// A vector of AccessOp. Describes the static "path" from a root variable to an element
+/// within the variable. Array accessors index expressions are held externally to the
+/// AccessShape, so AccessShape will be considered equal even if the array, matrix or vector
+/// index values differ.
+///
+/// For example, consider the following:
+///
+/// ```
+/// struct A {
+/// x : array<i32, 8>,
+/// y : u32,
+/// };
+/// struct B {
+/// x : i32,
+/// y : array<A, 4>
+/// };
+/// var<workgroup> C : B;
+/// ```
+///
+/// The following AccessShape would describe the following:
+///
+/// +==============================+===============+=================================+
+/// | AccessShape | Type | Expression |
+/// +==============================+===============+=================================+
+/// | [ Variable 'C', Symbol 'x' ] | i32 | C.x |
+/// +------------------------------+---------------+---------------------------------+
+/// | [ Variable 'C', Symbol 'y' ] | array<A, 4> | C.y |
+/// +------------------------------+---------------+---------------------------------+
+/// | [ Variable 'C', Symbol 'y', | A | C.y[dyn_idx[0]] |
+/// | DynamicIndex ] | | |
+/// +------------------------------+---------------+---------------------------------+
+/// | [ Variable 'C', Symbol 'y', | array<i32, 8> | C.y[dyn_idx[0]].x |
+/// | DynamicIndex, Symbol 'x' ] | | |
+/// +------------------------------+---------------+---------------------------------+
+/// | [ Variable 'C', Symbol 'y', | i32 | C.y[dyn_idx[0]].x[dyn_idx[1]] |
+/// | DynamicIndex, Symbol 'x', | | |
+/// | DynamicIndex ] | | |
+/// +------------------------------+---------------+---------------------------------+
+/// | [ Variable 'C', Symbol 'y', | u32 | C.y[dyn_idx[0]].y |
+/// | DynamicIndex, Symbol 'y' ] | | |
+/// +------------------------------+---------------+---------------------------------+
+///
+/// Where: `dyn_idx` is the AccessChain::dynamic_indices.
+struct AccessShape {
+ /// The root of the shape (the originating variable, its type and address space).
+ AccessRoot root;
+ /// The chain of access ops.
+ tint::utils::Vector<AccessOp, 8> ops;
+
+ /// @returns the number of DynamicIndex operations in #ops.
+ uint32_t NumDynamicIndices() const {
+ uint32_t count = 0;
+ for (auto& op : ops) {
+ if (std::holds_alternative<DynamicIndex>(op)) {
+ count++;
+ }
+ }
+ return count;
+ }
+};
+
+/// Equality operator for AccessShape
+bool operator==(const AccessShape& a, const AccessShape& b) {
+ return !(a.root != b.root) && a.ops == b.ops;
+}
+
+/// Inequality operator for AccessShape
+bool operator!=(const AccessShape& a, const AccessShape& b) {
+ return !(a == b);
+}
+
+/// AccessChain describes a chain of access expressions originating from a variable.
+struct AccessChain : AccessShape {
+ /// The array accessor index expressions. This vector is indexed by the slot of each
+ /// `DynamicIndex` in #ops.
+ tint::utils::Vector<const tint::sem::ValueExpression*, 8> dynamic_indices;
+ /// If true, then this access chain is used as an argument to call a variant.
+ bool used_in_call = false;
+};
+
+} // namespace
+
+namespace tint::utils {
+
+/// Hasher specialization for AccessRoot
+template <>
+struct Hasher<AccessRoot> {
+ /// The hash function for the AccessRoot
+ /// @param d the AccessRoot to hash
+ /// @return the hash for the given AccessRoot
+ size_t operator()(const AccessRoot& d) const { return utils::Hash(d.type, d.variable); }
+};
+
+/// Hasher specialization for DynamicIndex
+template <>
+struct Hasher<DynamicIndex> {
+ /// The hash function for the DynamicIndex
+ /// @param d the DynamicIndex to hash
+ /// @return the hash for the given DynamicIndex
+ size_t operator()(const DynamicIndex& d) const { return utils::Hash(d.slot); }
+};
+
+/// Hasher specialization for AccessShape
+template <>
+struct Hasher<AccessShape> {
+ /// The hash function for the AccessShape
+ /// @param s the AccessShape to hash
+ /// @return the hash for the given AccessShape
+ size_t operator()(const AccessShape& s) const { return utils::Hash(s.root, s.ops); }
+};
+
+} // namespace tint::utils
+
+namespace tint::ast::transform {
+
+/// The PIMPL state for the DirectVariableAccess transform
+struct DirectVariableAccess::State {
+ /// Constructor
+ /// @param src the source Program
+ /// @param options the transform options
+ State(const Program* src, const Options& options)
+ : ctx{&b, src, /* auto_clone_symbols */ true}, opts(options) {}
+
+ /// The main function for the transform.
+ /// @returns the ApplyResult
+ ApplyResult Run() {
+ if (!ctx.src->Sem().Module()->Extensions().Contains(
+ builtin::Extension::kChromiumExperimentalFullPtrParameters)) {
+ // If the 'chromium_experimental_full_ptr_parameters' extension is not enabled, then
+ // there's nothing for this transform to do.
+ return SkipTransform;
+ }
+
+ // Stage 1:
+ // Walk all the expressions of the program, starting with the expression leaves.
+ // Whenever we find an identifier resolving to a var, pointer parameter or pointer let to
+ // another chain, start constructing an access chain. When chains are accessed, these chains
+ // are grown and moved up the expression tree. After this stage, we are left with all the
+ // expression access chains to variables that we may need to transform.
+ for (auto* node : ctx.src->ASTNodes().Objects()) {
+ if (auto* expr = sem.GetVal(node)) {
+ AppendAccessChain(expr);
+ }
+ }
+
+ // Stage 2:
+ // Walk the functions in dependency order, starting with the entry points.
+ // Construct the set of function 'variants' by examining the calls made by each function to
+ // their call target. Each variant holds a map of pointer parameter to access chains, and
+ // will have the pointer parameters replaced with an array of u32s, used to perform the
+ // pointer indexing in the variant.
+ // Function call pointer arguments are replaced with an array of these dynamic indices.
+ auto decls = sem.Module()->DependencyOrderedDeclarations();
+ for (auto* decl : utils::Reverse(decls)) {
+ if (auto* fn = sem.Get<sem::Function>(decl)) {
+ auto* fn_info = FnInfoFor(fn);
+ ProcessFunction(fn, fn_info);
+ TransformFunction(fn, fn_info);
+ }
+ }
+
+ // Stage 3:
+ // Filter out access chains that do not need transforming.
+ // Ensure that chain dynamic index expressions are evaluated once at the correct place
+ ProcessAccessChains();
+
+ // Stage 4:
+ // Replace all the access chain expressions in all functions with reconstructed expression
+ // using the originating global variable, and any dynamic indices passed in to the function
+ // variant.
+ TransformAccessChainExpressions();
+
+ // Stage 5:
+ // Actually kick the clone.
+ CloneState state;
+ clone_state = &state;
+ ctx.Clone();
+ return Program(std::move(*ctx.dst));
+ }
+
+ private:
+ /// Holds symbols of the transformed pointer parameter.
+ /// If both symbols are valid, then #base_ptr and #indices are both program-unique symbols
+ /// derived from the original parameter name.
+ /// If only one symbol is valid, then this is the original parameter symbol.
+ struct PtrParamSymbols {
+ /// The symbol of the base pointer parameter.
+ Symbol base_ptr;
+ /// The symbol of the dynamic indices parameter.
+ Symbol indices;
+ };
+
+ /// FnVariant describes a unique variant of a function, specialized by the AccessShape of the
+ /// pointer arguments - also known as the variant's "signature".
+ ///
+ /// To help understand what a variant is, consider the following WGSL:
+ ///
+ /// ```
+ /// fn F(a : ptr<storage, u32>, b : u32, c : ptr<storage, u32>) {
+ /// return *a + b + *c;
+ /// }
+ ///
+ /// @group(0) @binding(0) var<storage> S0 : u32;
+ /// @group(0) @binding(0) var<storage> S1 : array<u32, 64>;
+ ///
+ /// fn x() {
+ /// F(&S0, 0, &S0); // (A)
+ /// F(&S0, 0, &S0); // (B)
+ /// F(&S1[0], 1, &S0); // (C)
+ /// F(&S1[5], 2, &S0); // (D)
+ /// F(&S1[5], 3, &S1[3]); // (E)
+ /// F(&S1[7], 4, &S1[2]); // (F)
+ /// }
+ /// ```
+ ///
+ /// Given the calls in x(), function F() will have 3 variants:
+ /// (1) F<S0,S0> - called by (A) and (B).
+ /// Note that only 'uniform', 'storage' and 'workgroup' pointer
+ /// parameters are considered for a variant signature, and so
+ /// the argument for parameter 'b' is not included in the
+ /// signature.
+ /// (2) F<S1[dyn_idx],S0> - called by (C) and (D).
+ /// Note that the array index value is external to the
+ /// AccessShape, and so is not part of the variant signature.
+ /// (3) F<S1[dyn_idx],S1[dyn_idx]> - called by (E) and (F).
+ ///
+ /// Each variant of the function will be emitted as a separate function by the transform, and
+ /// would look something like:
+ ///
+ /// ```
+ /// // variant F<S0,S0> (1)
+ /// fn F_S0_S0(b : u32) {
+ /// return S0 + b + S0;
+ /// }
+ ///
+ /// type S1_X = array<u32, 1>;
+ ///
+ /// // variant F<S1[dyn_idx],S0> (2)
+ /// fn F_S1_X_S0(a : S1_X, b : u32) {
+ /// return S1[a[0]] + b + S0;
+ /// }
+ ///
+ /// // variant F<S1[dyn_idx],S1[dyn_idx]> (3)
+ /// fn F_S1_X_S1_X(a : S1_X, b : u32, c : S1_X) {
+ /// return S1[a[0]] + b + S1[c[0]];
+ /// }
+ ///
+ /// @group(0) @binding(0) var<storage> S0 : u32;
+ /// @group(0) @binding(0) var<storage> S1 : array<u32, 64>;
+ ///
+ /// fn x() {
+ /// F_S0_S0(0); // (A)
+ /// F_S0_S0(0); // (B)
+ /// F_S1_X_S0(S1_X(0), 1); // (C)
+ /// F_S1_X_S0(S1_X(5), 2); // (D)
+ /// F_S1_X_S1_X(S1_X(5), 3, S1_X(3)); // (E)
+ /// F_S1_X_S1_X(S1_X(7), 4, S1_X(2)); // (F)
+ /// }
+ /// ```
+ struct FnVariant {
+ /// The signature of the variant is a map of each of the function's 'uniform', 'storage' and
+ /// 'workgroup' pointer parameters to the caller's AccessShape.
+ using Signature = utils::Hashmap<const sem::Parameter*, AccessShape, 4>;
+
+ /// The unique name of the variant.
+ /// The symbol is in the `ctx.dst` program namespace.
+ Symbol name;
+
+ /// A map of direct calls made by this variant to the name of other function variants.
+ utils::Hashmap<const sem::Call*, Symbol, 4> calls;
+
+ /// A map of input program parameter to output parameter symbols.
+ utils::Hashmap<const sem::Parameter*, PtrParamSymbols, 4> ptr_param_symbols;
+
+ /// The declaration order of the variant, in relation to other variants of the same
+ /// function. Used to ensure deterministic ordering of the transform, as map iteration is
+ /// not deterministic between compilers.
+ size_t order = 0;
+ };
+
+ /// FnInfo holds information about a function in the input program.
+ struct FnInfo {
+ /// A map of variant signature to the variant data.
+ utils::Hashmap<FnVariant::Signature, FnVariant, 8> variants;
+ /// A map of expressions that have been hoisted to a 'let' declaration in the function.
+ utils::Hashmap<const sem::ValueExpression*, Symbol, 8> hoisted_exprs;
+
+ /// @returns the variants of the function in a deterministically ordered vector.
+ utils::Vector<std::pair<const FnVariant::Signature*, FnVariant*>, 8> SortedVariants() {
+ utils::Vector<std::pair<const FnVariant::Signature*, FnVariant*>, 8> out;
+ out.Reserve(variants.Count());
+ for (auto it : variants) {
+ out.Push({&it.key, &it.value});
+ }
+ out.Sort([&](auto& va, auto& vb) { return va.second->order < vb.second->order; });
+ return out;
+ }
+ };
+
+ /// The program builder
+ ProgramBuilder b;
+ /// The clone context
+ CloneContext ctx;
+ /// The transform options
+ const Options& opts;
+ /// Alias to the semantic info in ctx.src
+ const sem::Info& sem = ctx.src->Sem();
+ /// Alias to the symbols in ctx.src
+ const SymbolTable& sym = ctx.src->Symbols();
+ /// Map of semantic function to the function info
+ utils::Hashmap<const sem::Function*, FnInfo*, 8> fns;
+ /// Map of AccessShape to the name of a type alias for an array<u32, N> used for the
+ /// dynamic indices of an access chain, passed down as the transformed type of a variant's
+ /// pointer parameter.
+ utils::Hashmap<AccessShape, Symbol, 8> dynamic_index_array_aliases;
+ /// Map of semantic expression to AccessChain
+ utils::Hashmap<const sem::ValueExpression*, AccessChain*, 32> access_chains;
+ /// Allocator for FnInfo
+ utils::BlockAllocator<FnInfo> fn_info_allocator;
+ /// Allocator for AccessChain
+ utils::BlockAllocator<AccessChain> access_chain_allocator;
+ /// Helper used for hoisting expressions to lets
+ HoistToDeclBefore hoist{ctx};
+ /// Map of string to unique symbol (no collisions in output program).
+ utils::Hashmap<std::string, Symbol, 8> unique_symbols;
+
+ /// CloneState holds pointers to the current function, variant and variant's parameters.
+ struct CloneState {
+ /// The current function being cloned
+ FnInfo* current_function = nullptr;
+ /// The current function variant being built
+ FnVariant* current_variant = nullptr;
+ /// The signature of the current function variant being built
+ const FnVariant::Signature* current_variant_sig = nullptr;
+ };
+
+ /// The clone state.
+ /// Only valid during the lifetime of the CloneContext::Clone().
+ CloneState* clone_state = nullptr;
+
+ /// AppendAccessChain creates or extends an existing AccessChain for the given expression,
+ /// modifying the #access_chains map.
+ void AppendAccessChain(const sem::ValueExpression* expr) {
+ // take_chain moves the AccessChain from the expression `from` to the expression `expr`.
+ // Returns nullptr if `from` did not hold an access chain.
+ auto take_chain = [&](const sem::ValueExpression* from) -> AccessChain* {
+ if (auto* chain = AccessChainFor(from)) {
+ access_chains.Remove(from);
+ access_chains.Add(expr, chain);
+ return chain;
+ }
+ return nullptr;
+ };
+
+ Switch(
+ expr,
+ [&](const sem::VariableUser* user) {
+ // Expression resolves to a variable.
+ auto* variable = user->Variable();
+
+ auto create_new_chain = [&] {
+ auto* chain = access_chain_allocator.Create();
+ chain->root.variable = variable;
+ chain->root.type = variable->Type();
+ chain->root.address_space = variable->AddressSpace();
+ if (auto* ptr = chain->root.type->As<type::Pointer>()) {
+ chain->root.address_space = ptr->AddressSpace();
+ }
+ access_chains.Add(expr, chain);
+ };
+
+ Switch(
+ variable->Declaration(),
+ [&](const Var*) {
+ if (variable->AddressSpace() != builtin::AddressSpace::kHandle) {
+ // Start a new access chain for the non-handle 'var' access
+ create_new_chain();
+ }
+ },
+ [&](const Parameter*) {
+ if (variable->Type()->Is<type::Pointer>()) {
+ // Start a new access chain for the pointer parameter access
+ create_new_chain();
+ }
+ },
+ [&](const Let*) {
+ if (variable->Type()->Is<type::Pointer>()) {
+ // variable is a pointer-let.
+ auto* init = sem.GetVal(variable->Declaration()->initializer);
+ // Note: We do not use take_chain() here, as we need to preserve the
+ // AccessChain on the let's initializer, as the let needs its
+ // initializer updated, and the let may be used multiple times. Instead
+ // we copy the initializer's AccessChain into a new AccessChain.
+ if (auto* init_chain = AccessChainFor(init)) {
+ access_chains.Add(expr, access_chain_allocator.Create(*init_chain));
+ }
+ }
+ });
+ },
+ [&](const sem::StructMemberAccess* a) {
+ // Structure member access.
+ // Append the Symbol of the member name to the chain, and move the chain to the
+ // member access expression.
+ if (auto* chain = take_chain(a->Object())) {
+ chain->ops.Push(a->Member()->Name());
+ }
+ },
+ [&](const sem::IndexAccessorExpression* a) {
+ // Array, matrix or vector index.
+ // Store the index expression into AccessChain::dynamic_indices, append a
+ // DynamicIndex to the chain, and move the chain to the index accessor expression.
+ if (auto* chain = take_chain(a->Object())) {
+ chain->ops.Push(DynamicIndex{chain->dynamic_indices.Length()});
+ chain->dynamic_indices.Push(a->Index());
+ }
+ },
+ [&](const sem::ValueExpression* e) {
+ if (auto* unary = e->Declaration()->As<UnaryOpExpression>()) {
+ // Unary op.
+ // If this is a '&' or '*', simply move the chain to the unary op expression.
+ if (unary->op == UnaryOp::kAddressOf || unary->op == UnaryOp::kIndirection) {
+ take_chain(sem.GetVal(unary->expr));
+ }
+ }
+ });
+ }
+
+ /// MaybeHoistDynamicIndices examines the AccessChain::dynamic_indices member of @p chain,
+ /// hoisting all expressions to their own uniquely named 'let' if none of the following are
+ /// true:
+ /// 1. The index expression is a constant value.
+ /// 2. The index expression's statement is the same as @p usage.
+ /// 3. The index expression is an identifier resolving to a 'let', 'const' or parameter, AND
+ /// that identifier resolves to the same variable at @p usage.
+ ///
+ /// A dynamic index will only be hoisted once. The hoisting applies to all variants of the
+ /// function that holds the dynamic index expression.
+ void MaybeHoistDynamicIndices(AccessChain* chain, const sem::Statement* usage) {
+ for (auto& idx : chain->dynamic_indices) {
+ if (idx->ConstantValue()) {
+ // Dynamic index is constant.
+ continue; // Hoisting not required.
+ }
+
+ if (idx->Stmt() == usage) {
+ // The index expression is owned by the statement of usage.
+ continue; // Hoisting not required
+ }
+
+ if (auto* idx_variable_user = idx->UnwrapMaterialize()->As<sem::VariableUser>()) {
+ auto* idx_variable = idx_variable_user->Variable();
+ if (idx_variable->Declaration()->IsAnyOf<Let, Parameter>()) {
+ // Dynamic index is an immutable variable
+ continue; // Hoisting not required.
+ }
+ }
+
+ // The dynamic index needs to be hoisted (if it hasn't been already).
+ auto fn = FnInfoFor(idx->Stmt()->Function());
+ fn->hoisted_exprs.GetOrCreate(idx, [=] {
+ // Create a name for the new 'let'
+ auto name = b.Symbols().New("ptr_index_save");
+ // Insert a new 'let' just above the dynamic index statement.
+ hoist.InsertBefore(idx->Stmt(), [this, idx, name] {
+ return b.Decl(b.Let(name, ctx.CloneWithoutTransform(idx->Declaration())));
+ });
+ return name;
+ });
+ }
+ }
+
+ /// BuildDynamicIndex builds the AST expression node for the dynamic index expression used in an
+ /// AccessChain. This is similar to just cloning the expression, but BuildDynamicIndex()
+ /// also:
+ /// * Collapses constant value index expressions down to the computed value. This acts as a
+ /// constant folding optimization and reduces noise from the transform.
+ /// * Casts the resulting expression to a u32 if @p cast_to_u32 is true, and the expression type
+ /// isn't implicitly usable as a u32. This is to help feed the expression into a
+ /// `array<u32, N>` argument passed to a callee variant function.
+ const Expression* BuildDynamicIndex(const sem::ValueExpression* idx, bool cast_to_u32) {
+ if (auto* val = idx->ConstantValue()) {
+ // Expression evaluated to a constant value. Just emit that constant.
+ return b.Expr(val->ValueAs<AInt>());
+ }
+
+ // Expression is not a constant, clone the expression.
+ // Note: If the dynamic index expression was hoisted to a let, then cloning will return an
+ // identifier expression to the hoisted let.
+ auto* expr = ctx.Clone(idx->Declaration());
+
+ if (cast_to_u32) {
+ // The index may be fed to a dynamic index array<u32, N> argument, so the index
+ // expression may need casting to u32.
+ if (!idx->UnwrapMaterialize()
+ ->Type()
+ ->UnwrapRef()
+ ->IsAnyOf<type::U32, type::AbstractInt>()) {
+ expr = b.Call<u32>(expr);
+ }
+ }
+
+ return expr;
+ }
+
+ /// ProcessFunction scans the direct calls made by the function @p fn, adding new variants to
+ /// the callee functions and transforming the call expression to pass dynamic indices instead of
+ /// true pointers.
+ /// If the function @p fn has pointer parameters that must be transformed to a caller variant,
+ /// and the function is not called, then the function is dropped from the output of the
+ /// transform, as it cannot be generated.
+ /// @note ProcessFunction must be called in dependency order for the program, starting with the
+ /// entry points.
+ void ProcessFunction(const sem::Function* fn, FnInfo* fn_info) {
+ if (fn_info->variants.IsEmpty()) {
+ // Function has no variants pre-generated by callers.
+ if (MustBeCalled(fn)) {
+ // Drop the function, as it wasn't called and cannot be generated.
+ ctx.Remove(ctx.src->AST().GlobalDeclarations(), fn->Declaration());
+ return;
+ }
+
+ // Function was not called. Create a single variant with an empty signature.
+ FnVariant variant;
+ variant.name = ctx.Clone(fn->Declaration()->name->symbol);
+ variant.order = 0; // Unaltered comes first.
+ fn_info->variants.Add(FnVariant::Signature{}, std::move(variant));
+ }
+
+ // Process each of the direct calls made by this function.
+ for (auto* call : fn->DirectCalls()) {
+ ProcessCall(fn_info, call);
+ }
+ }
+
+ /// ProcessCall creates new variants of the callee function by permuting the call for each of
+ /// the variants of @p caller. ProcessCall also registers the clone callback to transform the
+ /// call expression to pass dynamic indices instead of true pointers.
+ void ProcessCall(FnInfo* caller, const sem::Call* call) {
+ auto* target = call->Target()->As<sem::Function>();
+ if (!target) {
+ // Call target is not a user-declared function.
+ return; // Not interested in this call.
+ }
+
+ if (!HasPointerParameter(target)) {
+ return; // Not interested in this call.
+ }
+
+ bool call_needs_transforming = false;
+
+ // Build the call target function variant for each variant of the caller.
+ for (auto caller_variant_it : caller->SortedVariants()) {
+ auto& caller_signature = *caller_variant_it.first;
+ auto& caller_variant = *caller_variant_it.second;
+
+ // Build the target variant's signature.
+ FnVariant::Signature target_signature;
+ for (size_t i = 0; i < call->Arguments().Length(); i++) {
+ const auto* arg = call->Arguments()[i];
+ const auto* param = target->Parameters()[i];
+ const auto* param_ty = param->Type()->As<type::Pointer>();
+ if (!param_ty) {
+ continue; // Parameter type is not a pointer.
+ }
+
+ // Fetch the access chain for the argument.
+ auto* arg_chain = AccessChainFor(arg);
+ if (!arg_chain) {
+ continue; // Argument does not have an access chain
+ }
+
+ // Construct the absolute AccessShape by considering the AccessShape of the caller
+ // variant's argument. This will propagate back through pointer parameters, to the
+ // outermost caller.
+ auto absolute = AbsoluteAccessShape(caller_signature, *arg_chain);
+
+ // If the address space of the root variable of the access chain does not require
+ // transformation, then there's nothing to do.
+ if (!AddressSpaceRequiresTransform(absolute.root.address_space)) {
+ continue;
+ }
+
+ // Record that this chain was used in a function call.
+ // This preserves the chain during the access chain filtering stage.
+ arg_chain->used_in_call = true;
+
+ if (IsPrivateOrFunction(absolute.root.address_space)) {
+ // Pointers in 'private' and 'function' address spaces need to be passed by
+ // pointer argument.
+ absolute.root.variable = param;
+ }
+
+ // Add the parameter's absolute AccessShape to the target's signature.
+ target_signature.Add(param, std::move(absolute));
+ }
+
+ // Construct a new FnVariant if this is the first caller of the target signature
+ auto* target_info = FnInfoFor(target);
+ auto& target_variant = target_info->variants.GetOrCreate(target_signature, [&] {
+ if (target_signature.IsEmpty()) {
+ // Call target does not require any argument changes.
+ FnVariant variant;
+ variant.name = ctx.Clone(target->Declaration()->name->symbol);
+ variant.order = 0; // Unaltered comes first.
+ return variant;
+ }
+
+ // Build an appropriate variant function name.
+ // This is derived from the original function name and the pointer parameter
+ // chains.
+ utils::StringStream ss;
+ ss << target->Declaration()->name->symbol.Name();
+ for (auto* param : target->Parameters()) {
+ if (auto indices = target_signature.Find(param)) {
+ ss << "_" << AccessShapeName(*indices);
+ }
+ }
+
+ // Build the pointer parameter symbols.
+ utils::Hashmap<const sem::Parameter*, PtrParamSymbols, 4> ptr_param_symbols;
+ for (auto param_it : target_signature) {
+ auto* param = param_it.key;
+ auto& shape = param_it.value;
+
+ // Parameter needs replacing with either zero, one or two parameters:
+ // If the parameter is in the 'private' or 'function' address space, then the
+ // originating pointer is always passed down. This always comes first.
+ // If the access chain has dynamic indices, then we create an array<u32, N>
+ // parameter to hold the dynamic indices.
+ bool requires_base_ptr_param = IsPrivateOrFunction(shape.root.address_space);
+ bool requires_indices_param = shape.NumDynamicIndices() > 0;
+
+ PtrParamSymbols symbols;
+ if (requires_base_ptr_param && requires_indices_param) {
+ auto original_name = param->Declaration()->name->symbol;
+ symbols.base_ptr = UniqueSymbolWithSuffix(original_name, "_base");
+ symbols.indices = UniqueSymbolWithSuffix(original_name, "_indices");
+ } else if (requires_base_ptr_param) {
+ symbols.base_ptr = ctx.Clone(param->Declaration()->name->symbol);
+ } else if (requires_indices_param) {
+ symbols.indices = ctx.Clone(param->Declaration()->name->symbol);
+ }
+
+ // Remember this base pointer name.
+ ptr_param_symbols.Add(param, symbols);
+ }
+
+ // Build the variant.
+ FnVariant variant;
+ variant.name = b.Symbols().New(ss.str());
+ variant.order = target_info->variants.Count() + 1;
+ variant.ptr_param_symbols = std::move(ptr_param_symbols);
+ return variant;
+ });
+
+ // Record the call made by caller variant to the target variant.
+ caller_variant.calls.Add(call, target_variant.name);
+ if (!target_signature.IsEmpty()) {
+ // The call expression will need transforming for at least one caller variant.
+ call_needs_transforming = true;
+ }
+ }
+
+ if (call_needs_transforming) {
+ // Register the clone callback to correctly transform the call expression into the
+ // appropriate variant calls.
+ TransformCall(call);
+ }
+ }
+
+ /// @returns true if the address space @p address_space requires transforming given the
+ /// transform's options.
+ bool AddressSpaceRequiresTransform(builtin::AddressSpace address_space) const {
+ switch (address_space) {
+ case builtin::AddressSpace::kUniform:
+ case builtin::AddressSpace::kStorage:
+ case builtin::AddressSpace::kWorkgroup:
+ return true;
+ case builtin::AddressSpace::kPrivate:
+ return opts.transform_private;
+ case builtin::AddressSpace::kFunction:
+ return opts.transform_function;
+ default:
+ return false;
+ }
+ }
+
+ /// @returns the AccessChain for the expression @p expr, or nullptr if the expression does
+ /// not hold an access chain.
+ AccessChain* AccessChainFor(const sem::ValueExpression* expr) const {
+ if (auto chain = access_chains.Find(expr)) {
+ return *chain;
+ }
+ return nullptr;
+ }
+
+ /// @returns the absolute AccessShape for @p shape, by replacing the originating pointer
+ /// parameter with the AccessChain of variant's signature.
+ AccessShape AbsoluteAccessShape(const FnVariant::Signature& signature,
+ const AccessShape& shape) const {
+ if (auto* root_param = shape.root.variable->As<sem::Parameter>()) {
+ if (auto incoming_chain = signature.Find(root_param)) {
+ // Access chain originates from a parameter, which will be transformed into an array
+ // of dynamic indices. Start from the signature's AccessShape for the parameter,
+ // and append each of the chain's ops to it.
+ auto absolute = *incoming_chain;
+ for (auto& op : shape.ops) {
+ absolute.ops.Push(op);
+ }
+ return absolute;
+ }
+ }
+
+ // Chain does not originate from a parameter, so is already absolute.
+ return shape;
+ }
+
+ /// TransformFunction registers the clone callback that transforms the function @p fn into its
+ /// (potentially multiple) variants. While cloning, the callback assigns the current function,
+ /// variant and variant signature to #clone_state, for use by the other clone callbacks.
+ void TransformFunction(const sem::Function* fn, FnInfo* fn_info) {
+ // Register a custom handler for the specific function
+ ctx.Replace(fn->Declaration(), [this, fn, fn_info] {
+ // For the scope of this lambda, assign current_function to fn_info.
+ TINT_SCOPED_ASSIGNMENT(clone_state->current_function, fn_info);
+
+ // This callback expects a single function returned. As we're generating potentially
+ // many variant functions, keep a record of the last created variant, and explicitly add
+ // this to the module if it isn't the last. We'll return the last created variant,
+ // taking the place of the original function.
+ const Function* pending_variant = nullptr;
+
+ // For each variant of fn...
+ for (auto variant_it : fn_info->SortedVariants()) {
+ if (pending_variant) {
+ b.AST().AddFunction(pending_variant);
+ }
+
+ auto& variant_sig = *variant_it.first;
+ auto& variant = *variant_it.second;
+
+ // For the rest of this scope, assign the current variant and variant signature.
+ TINT_SCOPED_ASSIGNMENT(clone_state->current_variant_sig, &variant_sig);
+ TINT_SCOPED_ASSIGNMENT(clone_state->current_variant, &variant);
+
+ // Build the variant's parameters.
+ // Each pointer parameter in the variant's signature is replaced with a base pointer
+ // (if its base_ptr symbol is valid) and / or an array of dynamic indices (if its
+ // indices symbol is valid). If neither symbol is valid, the parameter is dropped.
+ utils::Vector<const Parameter*, 8> params;
+ for (auto* param : fn->Parameters()) {
+ if (auto incoming_shape = variant_sig.Find(param)) {
+ auto& symbols = *variant.ptr_param_symbols.Find(param);
+ if (symbols.base_ptr.IsValid()) {
+ auto base_ptr_ty =
+ b.ty.ptr(incoming_shape->root.address_space,
+ CreateASTTypeFor(ctx, incoming_shape->root.type));
+ params.Push(b.Param(symbols.base_ptr, base_ptr_ty));
+ }
+ if (symbols.indices.IsValid()) {
+ // Variant has dynamic indices for this parameter; pass them as an array.
+ auto dyn_idx_arr_type = DynamicIndexArrayType(*incoming_shape);
+ params.Push(b.Param(symbols.indices, dyn_idx_arr_type));
+ }
+ } else {
+ // Just a regular parameter. Just clone the original parameter.
+ params.Push(ctx.Clone(param->Declaration()));
+ }
+ }
+
+ // Build the variant by cloning the source function. The other clone callbacks will
+ // use clone_state->current_variant and clone_state->current_variant_sig to produce
+ // the variant.
+ auto ret_ty = ctx.Clone(fn->Declaration()->return_type);
+ auto body = ctx.Clone(fn->Declaration()->body);
+ auto attrs = ctx.Clone(fn->Declaration()->attributes);
+ auto ret_attrs = ctx.Clone(fn->Declaration()->return_type_attributes);
+ pending_variant =
+ b.create<Function>(b.Ident(variant.name), std::move(params), ret_ty, body,
+ std::move(attrs), std::move(ret_attrs));
+ }
+
+ return pending_variant;
+ });
+ }
+
+ /// TransformCall registers the clone callback to transform the call expression @p call to call
+ /// the correct target variant, and to replace pointer arguments with an array of dynamic
+ /// indices (preceded by a pointer to the base object for 'private' / 'function' pointers).
+ void TransformCall(const sem::Call* call) {
+ // Register a custom handler for the specific call expression
+ ctx.Replace(call->Declaration(), [this, call] {
+ auto target_variant = clone_state->current_variant->calls.Find(call);
+ if (!target_variant) {
+ // The current variant does not need to transform this call.
+ return ctx.CloneWithoutTransform(call->Declaration());
+ }
+
+ // Build the new call expression's arguments.
+ utils::Vector<const Expression*, 8> new_args;
+ for (size_t arg_idx = 0; arg_idx < call->Arguments().Length(); arg_idx++) {
+ auto* arg = call->Arguments()[arg_idx];
+ auto* param = call->Target()->Parameters()[arg_idx];
+ auto* param_ty = param->Type()->As<type::Pointer>();
+ if (!param_ty) {
+ // Parameter is not a pointer.
+ // Just clone the unaltered argument.
+ new_args.Push(ctx.Clone(arg->Declaration()));
+ continue; // Parameter is not a pointer
+ }
+
+ auto* chain = AccessChainFor(arg);
+ if (!chain) {
+ // No access chain means the argument is not a pointer that needs transforming.
+ // Just clone the unaltered argument.
+ new_args.Push(ctx.Clone(arg->Declaration()));
+ continue;
+ }
+
+ // Construct the absolute AccessShape by considering the AccessShape of the caller
+ // variant's argument. This will propagate back through pointer parameters, to the
+ // outermost caller.
+ auto full_indices = AbsoluteAccessShape(*clone_state->current_variant_sig, *chain);
+
+ // If the parameter is a pointer in the 'private' or 'function' address space, then
+ // we need to pass an additional pointer argument to the base object.
+ if (IsPrivateOrFunction(param_ty->AddressSpace())) {
+ auto* root_expr = BuildAccessRootExpr(chain->root, /* deref */ false);
+ if (!chain->root.variable->Is<sem::Parameter>()) {
+ root_expr = b.AddressOf(root_expr);
+ }
+ new_args.Push(root_expr);
+ }
+
+ // Get or create the dynamic indices array.
+ if (auto dyn_idx_arr_ty = DynamicIndexArrayType(full_indices)) {
+ // Build an array of dynamic indices to pass as the replacement for the pointer.
+ utils::Vector<const Expression*, 8> dyn_idx_args;
+ if (auto* root_param = chain->root.variable->As<sem::Parameter>()) {
+ // Access chain originates from a pointer parameter.
+ if (auto incoming_chain =
+ clone_state->current_variant_sig->Find(root_param)) {
+ auto indices =
+ clone_state->current_variant->ptr_param_symbols.Find(root_param)
+ ->indices;
+
+ // This pointer parameter will have been replaced with an array<u32, N>
+ // holding the variant's dynamic indices for the pointer. Unpack these
+ // directly into the array constructor's arguments.
+ auto N = incoming_chain->NumDynamicIndices();
+ for (uint32_t i = 0; i < N; i++) {
+ dyn_idx_args.Push(b.IndexAccessor(indices, u32(i)));
+ }
+ }
+ }
+ // Pass the dynamic indices of the access chain into the array constructor.
+ for (auto& dyn_idx : chain->dynamic_indices) {
+ dyn_idx_args.Push(BuildDynamicIndex(dyn_idx, /* cast_to_u32 */ true));
+ }
+ // Construct the dynamic index array, and push as an argument.
+ new_args.Push(b.Call(dyn_idx_arr_ty, std::move(dyn_idx_args)));
+ }
+ }
+
+ // Make the call to the target's variant.
+ return b.Call(*target_variant, std::move(new_args));
+ });
+ }
+
+ /// ProcessAccessChains visits every recorded AccessChain, ordered by AST node id, and:
+ /// * Drops the AccessChains of expressions that are neither used as a pointer argument in a
+ /// call, nor originate from a pointer parameter.
+ /// * Hoists the dynamic index expressions of the remaining AccessChains to 'let' statements, to
+ /// prevent multiple evaluation of the expressions, and avoid expressions resolving to
+ /// different variables based on lexical scope.
+ void ProcessAccessChains() {
+ // Collect the chain expressions up front, sorted by the node id of their declaration.
+ auto sorted_exprs = access_chains.Keys();
+ sorted_exprs.Sort([](const auto& lhs, const auto& rhs) {
+ return lhs->Declaration()->node_id.value < rhs->Declaration()->node_id.value;
+ });
+
+ for (auto* chain_expr : sorted_exprs) {
+ auto* entry = *access_chains.Get(chain_expr);
+ if (entry->used_in_call || entry->root.variable->Is<sem::Parameter>()) {
+ // Chain requires transforming.
+ // We need to be careful that the chain does not use expressions with side-effects
+ // which cannot be repeatedly evaluated, so hoist the dynamic index expressions to
+ // their own uniquely named lets (if required).
+ MaybeHoistDynamicIndices(entry, chain_expr->Stmt());
+ continue;
+ }
+
+ // Chain was not used in a function call, and does not originate from a parameter.
+ // This chain does not need transforming. Drop it.
+ access_chains.Remove(chain_expr);
+ }
+ }
+
+ /// TransformAccessChainExpressions registers the clone callback to:
+ /// * Transform all expressions that have an AccessChain (which aren't arguments to function
+ /// calls, these are handled by TransformCall()), into the equivalent expression using a
+ /// module-scope variable.
+ /// * Replace expressions that have been hoisted to a let, with an identifier expression to that
+ /// let.
+ void TransformAccessChainExpressions() {
+ // Register a custom handler for all non-function call expressions
+ ctx.ReplaceAll([this](const Expression* ast_expr) -> const Expression* {
+ if (!clone_state->current_variant) {
+ // Expression does not belong to a function variant.
+ return nullptr; // Just clone the expression.
+ }
+
+ auto* expr = sem.GetVal(ast_expr);
+ if (!expr) {
+ // No semantic node for the expression.
+ return nullptr; // Just clone the expression.
+ }
+
+ // If the expression has been hoisted to a 'let', then replace the expression with an
+ // identifier to the hoisted let.
+ if (auto hoisted = clone_state->current_function->hoisted_exprs.Find(expr)) {
+ return b.Expr(*hoisted);
+ }
+
+ auto* chain = AccessChainFor(expr);
+ if (!chain) {
+ // The expression does not have an AccessChain.
+ return nullptr; // Just clone the expression.
+ }
+
+ auto* root_param = chain->root.variable->As<sem::Parameter>();
+ if (!root_param) {
+ // The expression has an access chain, but does not originate with a pointer
+ // parameter. We don't need to change anything here.
+ return nullptr; // Just clone the expression.
+ }
+
+ auto incoming_shape = clone_state->current_variant_sig->Find(root_param);
+ if (!incoming_shape) {
+ // The root parameter of the access chain is not part of the variant's signature.
+ return nullptr; // Just clone the expression.
+ }
+
+ // Expression holds an access chain to a pointer parameter that needs transforming.
+ // Reconstruct the expression using the variant's incoming shape.
+
+ auto* chain_expr = BuildAccessRootExpr(incoming_shape->root, /* deref */ true);
+
+ // Chain starts with a pointer parameter.
+ // Replace this with the variant's incoming shape. This will bring the expression up to
+ // the incoming pointer.
+ auto indices =
+ clone_state->current_variant->ptr_param_symbols.Find(root_param)->indices;
+ for (auto param_access : incoming_shape->ops) {
+ chain_expr = BuildAccessExpr(chain_expr, param_access, [&](size_t i) {
+ return b.IndexAccessor(indices, AInt(i));
+ });
+ }
+
+ // Now build the expression chain within the function.
+
+ // For each access in the chain (excluding the pointer parameter)...
+ for (auto& op : chain->ops) {
+ chain_expr = BuildAccessExpr(chain_expr, op, [&](size_t i) {
+ return BuildDynamicIndex(chain->dynamic_indices[i], /* cast_to_u32 */ false);
+ });
+ }
+
+ // BuildAccessExpr() always returns a non-pointer.
+ // If the expression we're replacing is a pointer, take the address.
+ if (expr->Type()->Is<type::Pointer>()) {
+ chain_expr = b.AddressOf(chain_expr);
+ }
+
+ return chain_expr;
+ });
+ }
+
+ /// @returns the FnInfo for @p fn, allocating a new FnInfo on the first request for @p fn.
+ FnInfo* FnInfoFor(const sem::Function* fn) {
+ auto allocate_info = [this] { return fn_info_allocator.Create(); };
+ return fns.GetOrCreate(fn, allocate_info);
+ }
+
+ /// @returns the type alias used to hold the dynamic indices for @p shape, declaring the alias
+ /// on the first call for the shape, or a default-constructed Type if it has no dynamic indices.
+ Type DynamicIndexArrayType(const AccessShape& shape) {
+ auto declare_alias = [&]() -> Symbol {
+ const uint32_t index_count = shape.NumDynamicIndices();
+ if (index_count == 0) {
+ return Symbol{};
+ }
+ auto alias_name = b.Symbols().New(AccessShapeName(shape));
+ b.Alias(alias_name, b.ty.array(b.ty.u32(), u32(index_count)));
+ return alias_name;
+ };
+ auto alias = dynamic_index_array_aliases.GetOrCreate(shape, declare_alias);
+ return alias.IsValid() ? b.ty(alias) : Type{};
+ }
+
+ /// @returns a name describing the given shape. The name starts with "F" for roots in the
+ /// 'private' or 'function' address space, or with the name of the root variable otherwise,
+ /// followed by "_X" for each dynamic index op and "_<member>" for each member access op.
+ std::string AccessShapeName(const AccessShape& shape) {
+ utils::StringStream name;
+
+ if (!IsPrivateOrFunction(shape.root.address_space)) {
+ name << shape.root.variable->Declaration()->name->symbol.Name();
+ } else {
+ name << "F";
+ }
+
+ for (auto& access : shape.ops) {
+ name << "_";
+
+ if (std::holds_alternative<DynamicIndex>(access)) {
+ // The op uses a dynamic (runtime-expression) index.
+ name << "X";
+ } else if (auto* member = std::get_if<Symbol>(&access); TINT_LIKELY(member)) {
+ // The op is a member access.
+ name << member->Name();
+ } else {
+ // Neither alternative matched: raise an internal compiler error and stop.
+ TINT_ICE(Transform, b.Diagnostics()) << "unhandled variant for access chain";
+ break;
+ }
+ }
+
+ return name.str();
+ }
+
+ /// Builds an expression to the root of an access, returning the new expression.
+ /// @param root the AccessRoot
+ /// @param deref if true, the returned expression will always be a reference type.
+ const Expression* BuildAccessRootExpr(const AccessRoot& root, bool deref) {
+ // If the root is a parameter that has pointer-parameter symbols in the current variant,
+ // then the root is reached through the variant's base pointer.
+ if (auto* param = root.variable->As<sem::Parameter>()) {
+ if (auto symbols = clone_state->current_variant->ptr_param_symbols.Find(param)) {
+ const Expression* base_ptr = b.Expr(symbols->base_ptr);
+ return deref ? b.Deref(base_ptr) : base_ptr;
+ }
+ }
+
+ // Otherwise the root is an identifier expression of the (cloned) root variable name,
+ // which only needs dereferencing if the variable is of a pointer type.
+ const Expression* ident = b.Expr(ctx.Clone(root.variable->Declaration()->name->symbol));
+ if (deref && root.variable->Type()->Is<type::Pointer>()) {
+ return b.Deref(ident);
+ }
+ return ident;
+ }
+
+ /// Builds a single access in an access chain, returning the new expression.
+ /// The returned expression will always be of a reference type.
+ /// @param expr the input expression
+ /// @param access the access to perform on the current expression
+ /// @param dynamic_index a function that obtains the i'th dynamic index
+ const Expression* BuildAccessExpr(const Expression* expr,
+ const AccessOp& access,
+ std::function<const Expression*(size_t)> dynamic_index) {
+ const auto* dyn_idx = std::get_if<DynamicIndex>(&access);
+ const auto* member = std::get_if<Symbol>(&access);
+
+ if (dyn_idx) {
+ // Dynamic (runtime-expression) index: index with the expression for the op's slot.
+ return b.IndexAccessor(expr, dynamic_index(dyn_idx->slot));
+ }
+ if (TINT_LIKELY(member)) {
+ // Member access.
+ return b.MemberAccessor(expr, ctx.Clone(*member));
+ }
+
+ TINT_ICE(Transform, b.Diagnostics()) << "unhandled variant type for access chain";
+ return nullptr;
+ }
+
+ /// @returns the Symbol registered in #unique_symbols for the name of @p symbol concatenated
+ /// with @p suffix. The Symbol is created with Symbols().New() on the first request for a name.
+ Symbol UniqueSymbolWithSuffix(Symbol symbol, const std::string& suffix) {
+ auto full_name = symbol.Name() + suffix;
+ return unique_symbols.GetOrCreate(full_name, [&] { return b.Symbols().New(full_name); });
+ }
+
+ /// @returns true if at least one parameter of the function @p fn is of a pointer type,
+ /// regardless of the pointer's address space.
+ static bool HasPointerParameter(const sem::Function* fn) {
+ for (auto* parameter : fn->Parameters()) {
+ const auto* as_pointer = parameter->Type()->As<type::Pointer>();
+ if (as_pointer != nullptr) { return true; }
+ }
+ return false;
+ }
+
+ /// @returns true if the function @p fn has at least one pointer parameter in an address space
+ /// that must be replaced. If this function is not called, then the function cannot be sensibly
+ /// generated, and must be stripped.
+ static bool MustBeCalled(const sem::Function* fn) {
+ for (auto* param : fn->Parameters()) {
+ if (auto* ptr = param->Type()->As<type::Pointer>()) {
+ switch (ptr->AddressSpace()) {
+ case builtin::AddressSpace::kUniform:
+ case builtin::AddressSpace::kStorage:
+ case builtin::AddressSpace::kWorkgroup:
+ return true;
+ default:
+ break; // Other address space: keep checking the remaining parameters.
+ }
+ }
+ }
+ return false;
+ }
+
+ /// @returns true if the address space @p sc is either 'private' or 'function'.
+ static bool IsPrivateOrFunction(const builtin::AddressSpace sc) {
+ return !(sc != builtin::AddressSpace::kPrivate && sc != builtin::AddressSpace::kFunction);
+ }
+};
+
+DirectVariableAccess::Config::Config(const Options& opt) : options(opt) {}  // Copies opt.
+
+DirectVariableAccess::Config::~Config() = default;  // Compiler-generated destructor.
+
+DirectVariableAccess::DirectVariableAccess() = default;  // Compiler-generated constructor.
+
+DirectVariableAccess::~DirectVariableAccess() = default;  // Compiler-generated destructor.
+
+Transform::ApplyResult DirectVariableAccess::Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap&) const {
+ // Start from the default options, and override them with the Config input, if provided.
+ Options options;
+ const auto* config = inputs.Get<Config>();
+ if (config != nullptr) { options = config->options; }
+ return State(program, options).Run();
+}
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/direct_variable_access.h b/src/tint/lang/wgsl/ast/transform/direct_variable_access.h
new file mode 100644
index 0000000..eb9dbc6
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/direct_variable_access.h
@@ -0,0 +1,74 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_DIRECT_VARIABLE_ACCESS_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_DIRECT_VARIABLE_ACCESS_H_
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+
+namespace tint::ast::transform {
+
+/// DirectVariableAccess is a transform that allows usage of pointer parameters in the 'storage',
+/// 'uniform' and 'workgroup' address space, and passing of pointers to sub-objects. These pointers
+/// are only allowed by the resolver when the `chromium_experimental_full_ptr_parameters` extension
+/// is enabled.
+///
+/// DirectVariableAccess works by creating specializations of functions that have pointer
+/// parameters, one specialization for each pointer argument's unique access chain 'shape' from a
+/// unique variable. Calls to specialized functions are transformed so that the pointer arguments
+/// are replaced with an array of access-chain indices, and if the pointer is in the 'function' or
+/// 'private' address space, also with a pointer to the root object. For more information, see the
+/// comments in src/tint/lang/wgsl/ast/transform/direct_variable_access.cc.
+///
+/// @note DirectVariableAccess requires the transform::Unshadow transform to have been run first.
+class DirectVariableAccess final : public utils::Castable<DirectVariableAccess, Transform> {
+ public:
+ /// Constructor
+ DirectVariableAccess();
+ /// Destructor
+ ~DirectVariableAccess() override;
+
+ /// Options adjusts the behaviour of the transform.
+ struct Options {
+ /// If true, 'private' sub-object pointer arguments are transformed. Off by default.
+ bool transform_private = false;
+ /// If true, 'function' sub-object pointer arguments are transformed. Off by default.
+ bool transform_function = false;
+ };
+
+ /// Config is consumed by the DirectVariableAccess transform.
+ /// Config specifies the behavior of the transform.
+ struct Config final : public utils::Castable<Config, Data> {
+ /// Constructor
+ /// @param options behavior of the transform
+ explicit Config(const Options& options);
+ /// Destructor
+ ~Config() override;
+
+ /// The transform behavior options
+ const Options options;
+ };
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+
+ private:
+ struct State;
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_DIRECT_VARIABLE_ACCESS_H_
diff --git a/src/tint/lang/wgsl/ast/transform/direct_variable_access_test.cc b/src/tint/lang/wgsl/ast/transform/direct_variable_access_test.cc
new file mode 100644
index 0000000..992a89f
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/direct_variable_access_test.cc
@@ -0,0 +1,2713 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/direct_variable_access.h"
+
+#include <memory>
+#include <utility>
+
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+#include "src/tint/utils/string.h"
+
+namespace tint::ast::transform {
+namespace {
+
+/// @returns a Transform::DataMap holding a DirectVariableAccess::Config whose
+/// DirectVariableAccess::Options::transform_private is enabled.
+static Transform::DataMap EnablePrivate() {
+ DirectVariableAccess::Options options;
+ options.transform_private = true;
+ Transform::DataMap data;
+ data.Add<DirectVariableAccess::Config>(options);
+ return data;
+}
+
+/// @returns a Transform::DataMap holding a DirectVariableAccess::Config whose
+/// DirectVariableAccess::Options::transform_function is enabled.
+static Transform::DataMap EnableFunction() {
+ DirectVariableAccess::Options options;
+ options.transform_function = true;
+ Transform::DataMap data;
+ data.Add<DirectVariableAccess::Config>(options);
+ return data;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// ShouldRun tests
+////////////////////////////////////////////////////////////////////////////////
+namespace should_run {
+
+using DirectVariableAccessShouldRunTest = TransformTest;
+
+TEST_F(DirectVariableAccessShouldRunTest, EmptyModule) {
+ // The transform is expected to report that it should not run on an empty module.
+ const char* input = R"()";
+ EXPECT_FALSE(ShouldRun<DirectVariableAccess>(input));
+}
+
+} // namespace should_run
+
+////////////////////////////////////////////////////////////////////////////////
+// remove uncalled
+////////////////////////////////////////////////////////////////////////////////
+namespace remove_uncalled {
+
+using DirectVariableAccessRemoveUncalledTest = TransformTest;
+
+TEST_F(DirectVariableAccessRemoveUncalledTest, PtrUniform) {
+ const char* input = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+var<private> keep_me = 42;
+
+fn u(pre : i32, p : ptr<uniform, i32>, post : i32) -> i32 {
+ return *(p);
+}
+
+)";
+ // The uncalled function u(), which has a ptr<uniform> parameter, is expected to be removed.
+ const char* expected = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+var<private> keep_me = 42;
+)";
+
+ auto output = Run<DirectVariableAccess>(input);
+
+ EXPECT_EQ(expected, str(output));
+}
+
+TEST_F(DirectVariableAccessRemoveUncalledTest, PtrStorage) {
+ const char* input = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+var<private> keep_me = 42;
+
+fn s(pre : i32, p : ptr<storage, i32>, post : i32) -> i32 {
+ return *(p);
+}
+)";
+ // The uncalled function s(), which has a ptr<storage> parameter, is expected to be removed.
+ const char* expected = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+var<private> keep_me = 42;
+)";
+
+ auto output = Run<DirectVariableAccess>(input);
+
+ EXPECT_EQ(expected, str(output));
+}
+
+TEST_F(DirectVariableAccessRemoveUncalledTest, PtrWorkgroup) {
+ const char* input = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+var<private> keep_me = 42;
+
+fn w(pre : i32, p : ptr<workgroup, i32>, post : i32) -> i32 {
+ return *(p);
+}
+
+)";
+ // The uncalled function w(), which has a ptr<workgroup> parameter, is expected to be removed.
+ const char* expected = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+var<private> keep_me = 42;
+)";
+
+ auto output = Run<DirectVariableAccess>(input);
+
+ EXPECT_EQ(expected, str(output));
+}
+
+TEST_F(DirectVariableAccessRemoveUncalledTest, PtrPrivate_Disabled) {
+ const char* input = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+var<private> keep_me = 42;
+
+fn f(pre : i32, p : ptr<private, i32>, post : i32) -> i32 {
+ return *(p);
+}
+)";
+ // With the default options, the uncalled function f() is expected to be left unchanged.
+ const char* expected = input;
+
+ auto output = Run<DirectVariableAccess>(input);
+
+ EXPECT_EQ(expected, str(output));
+}
+
+TEST_F(DirectVariableAccessRemoveUncalledTest, PtrPrivate_Enabled) {
+ const char* input = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+var<private> keep_me = 42;
+)";
+ // NOTE(review): unlike PtrPrivate_Disabled, this source declares no function with a
+ const char* expected = input;  // ptr<private> parameter - confirm whether one was intended.
+
+ auto output = Run<DirectVariableAccess>(input, EnablePrivate());
+
+ EXPECT_EQ(expected, str(output));
+}
+
+TEST_F(DirectVariableAccessRemoveUncalledTest, PtrFunction_Disabled) {
+ const char* input = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+var<private> keep_me = 42;
+
+fn f(pre : i32, p : ptr<function, i32>, post : i32) -> i32 {
+ return *(p);
+}
+)";
+ // With the default options, the uncalled function f() is expected to be left unchanged.
+ const char* expected = input;
+
+ auto output = Run<DirectVariableAccess>(input);
+
+ EXPECT_EQ(expected, str(output));
+}
+
+TEST_F(DirectVariableAccessRemoveUncalledTest, PtrFunction_Enabled) {
+ const char* input = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+var<private> keep_me = 42;
+)";
+ // NOTE(review): unlike PtrFunction_Disabled, this source declares no function with a
+ const char* expected = input;  // ptr<function> parameter - confirm whether one was intended.
+
+ auto output = Run<DirectVariableAccess>(input, EnableFunction());
+
+ EXPECT_EQ(expected, str(output));
+}
+
+} // namespace remove_uncalled
+
+////////////////////////////////////////////////////////////////////////////////
+// pointer chains
+////////////////////////////////////////////////////////////////////////////////
+namespace pointer_chains_tests {
+
+using DirectVariableAccessPtrChainsTest = TransformTest;
+
+TEST_F(DirectVariableAccessPtrChainsTest, ConstantIndices) {
+ const char* input = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<uniform> U : array<array<array<vec4<i32>, 8>, 8>, 8>;
+
+fn a(pre : i32, p : ptr<uniform, vec4<i32>>, post : i32) -> vec4<i32> {
+ return *p;
+}
+
+fn b() {
+ let p0 = &U;
+ let p1 = &(*p0)[1];
+ let p2 = &(*p1)[1+1];
+ let p3 = &(*p2)[2*2 - 1];
+ a(10, p3, 20);
+}
+
+fn c(p : ptr<uniform, array<array<array<vec4<i32>, 8>, 8>, 8>>) {
+ let p0 = p;
+ let p1 = &(*p0)[1];
+ let p2 = &(*p1)[1+1];
+ let p3 = &(*p2)[2*2 - 1];
+ a(10, p3, 20);
+}
+
+fn d() {
+ c(&U);
+}
+)";
+ // The constant index expressions are expected to be passed as the literals 1, 2, 3.
+ const char* expected = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<uniform> U : array<array<array<vec4<i32>, 8>, 8>, 8>;
+
+alias U_X_X_X = array<u32, 3u>;
+
+fn a_U_X_X_X(pre : i32, p : U_X_X_X, post : i32) -> vec4<i32> {
+ return U[p[0]][p[1]][p[2]];
+}
+
+fn b() {
+ let p0 = &(U);
+ let p1 = &((*(p0))[1]);
+ let p2 = &((*(p1))[(1 + 1)]);
+ let p3 = &((*(p2))[((2 * 2) - 1)]);
+ a_U_X_X_X(10, U_X_X_X(1, 2, 3), 20);
+}
+
+fn c_U() {
+ let p0 = &(U);
+ let p1 = &(U[1]);
+ let p2 = &(U[1][2]);
+ let p3 = &(U[1][2][3]);
+ a_U_X_X_X(10, U_X_X_X(1, 2, 3), 20);
+}
+
+fn d() {
+ c_U();
+}
+)";
+
+ auto output = Run<DirectVariableAccess>(input);
+
+ EXPECT_EQ(expected, str(output));
+}
+
+TEST_F(DirectVariableAccessPtrChainsTest, HoistIndices) {
+ const char* input = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<uniform> U : array<array<array<vec4<i32>, 8>, 8>, 8>;
+
+var<private> i : i32;
+fn first() -> i32 {
+ i++;
+ return i;
+}
+fn second() -> i32 {
+ i++;
+ return i;
+}
+fn third() -> i32 {
+ i++;
+ return i;
+}
+
+fn a(pre : i32, p : ptr<uniform, vec4<i32>>, post : i32) -> vec4<i32> {
+ return *p;
+}
+
+fn b() {
+ let p0 = &U;
+ let p1 = &(*p0)[first()];
+ let p2 = &(*p1)[second()][third()];
+ a(10, p2, 20);
+}
+
+fn c(p : ptr<uniform, array<array<array<vec4<i32>, 8>, 8>, 8>>) {
+ let p0 = &U;
+ let p1 = &(*p0)[first()];
+ let p2 = &(*p1)[second()][third()];
+ a(10, p2, 20);
+}
+
+fn d() {
+ c(&U);
+}
+)";
+ // The index expressions that call functions are expected to be hoisted to ptr_index_save lets.
+ const char* expected = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<uniform> U : array<array<array<vec4<i32>, 8>, 8>, 8>;
+
+var<private> i : i32;
+
+fn first() -> i32 {
+ i++;
+ return i;
+}
+
+fn second() -> i32 {
+ i++;
+ return i;
+}
+
+fn third() -> i32 {
+ i++;
+ return i;
+}
+
+alias U_X_X_X = array<u32, 3u>;
+
+fn a_U_X_X_X(pre : i32, p : U_X_X_X, post : i32) -> vec4<i32> {
+ return U[p[0]][p[1]][p[2]];
+}
+
+fn b() {
+ let p0 = &(U);
+ let ptr_index_save = first();
+ let p1 = &((*(p0))[ptr_index_save]);
+ let ptr_index_save_1 = second();
+ let ptr_index_save_2 = third();
+ let p2 = &((*(p1))[ptr_index_save_1][ptr_index_save_2]);
+ a_U_X_X_X(10, U_X_X_X(u32(ptr_index_save), u32(ptr_index_save_1), u32(ptr_index_save_2)), 20);
+}
+
+fn c_U() {
+ let p0 = &(U);
+ let ptr_index_save_3 = first();
+ let p1 = &((*(p0))[ptr_index_save_3]);
+ let ptr_index_save_4 = second();
+ let ptr_index_save_5 = third();
+ let p2 = &((*(p1))[ptr_index_save_4][ptr_index_save_5]);
+ a_U_X_X_X(10, U_X_X_X(u32(ptr_index_save_3), u32(ptr_index_save_4), u32(ptr_index_save_5)), 20);
+}
+
+fn d() {
+ c_U();
+}
+)";
+ // NOTE(review): c() roots its chain at U rather than at its parameter p - confirm intended.
+ auto output = Run<DirectVariableAccess>(input);
+
+ EXPECT_EQ(expected, str(output));
+}
+
+TEST_F(DirectVariableAccessPtrChainsTest, HoistInForLoopInit) {
+ const char* input = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<uniform> U : array<array<vec4<i32>, 8>, 8>;
+
+var<private> i : i32;
+fn first() -> i32 {
+ i++;
+ return i;
+}
+fn second() -> i32 {
+ i++;
+ return i;
+}
+
+fn a(pre : i32, p : ptr<uniform, vec4<i32>>, post : i32) -> vec4<i32> {
+ return *p;
+}
+
+fn b() {
+ for (let p1 = &U[first()]; true; ) {
+ a(10, &(*p1)[second()], 20);
+ }
+}
+
+fn c(p : ptr<uniform, array<array<vec4<i32>, 8>, 8>>) {
+ for (let p1 = &(*p)[first()]; true; ) {
+ a(10, &(*p1)[second()], 20);
+ }
+}
+
+fn d() {
+ c(&U);
+}
+)";
+ // The initializer's index is expected to be hoisted, with the for-loop emitted as a 'loop'.
+ const char* expected = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<uniform> U : array<array<vec4<i32>, 8>, 8>;
+
+var<private> i : i32;
+
+fn first() -> i32 {
+ i++;
+ return i;
+}
+
+fn second() -> i32 {
+ i++;
+ return i;
+}
+
+alias U_X_X = array<u32, 2u>;
+
+fn a_U_X_X(pre : i32, p : U_X_X, post : i32) -> vec4<i32> {
+ return U[p[0]][p[1]];
+}
+
+fn b() {
+ {
+ let ptr_index_save = first();
+ let p1 = &(U[ptr_index_save]);
+ loop {
+ if (!(true)) {
+ break;
+ }
+ {
+ a_U_X_X(10, U_X_X(u32(ptr_index_save), u32(second())), 20);
+ }
+ }
+ }
+}
+
+fn c_U() {
+ {
+ let ptr_index_save_1 = first();
+ let p1 = &(U[ptr_index_save_1]);
+ loop {
+ if (!(true)) {
+ break;
+ }
+ {
+ a_U_X_X(10, U_X_X(u32(ptr_index_save_1), u32(second())), 20);
+ }
+ }
+ }
+}
+
+fn d() {
+ c_U();
+}
+)";
+
+ auto output = Run<DirectVariableAccess>(input);
+
+ EXPECT_EQ(expected, str(output));
+}
+
+TEST_F(DirectVariableAccessPtrChainsTest, HoistInForLoopCond) {
+ const char* input = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<uniform> U : array<array<vec4<i32>, 8>, 8>;
+
+var<private> i : i32;
+fn first() -> i32 {
+ i++;
+ return i;
+}
+fn second() -> i32 {
+ i++;
+ return i;
+}
+
+fn a(pre : i32, p : ptr<uniform, vec4<i32>>, post : i32) -> vec4<i32> {
+ return *p;
+}
+
+fn b() {
+ let p = &U[first()][second()];
+ for (; a(10, p, 20).x < 4; ) {
+ let body = 1;
+ }
+}
+
+fn c(p : ptr<uniform, array<array<vec4<i32>, 8>, 8>>) {
+ let p2 = &(*p)[first()][second()];
+ for (; a(10, p2, 20).x < 4; ) {
+ let body = 1;
+ }
+}
+
+fn d() {
+ c(&U);
+}
+)";
+ // The indices are expected to be hoisted before the loop and reused in the loop condition.
+ const char* expected = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<uniform> U : array<array<vec4<i32>, 8>, 8>;
+
+var<private> i : i32;
+
+fn first() -> i32 {
+ i++;
+ return i;
+}
+
+fn second() -> i32 {
+ i++;
+ return i;
+}
+
+alias U_X_X = array<u32, 2u>;
+
+fn a_U_X_X(pre : i32, p : U_X_X, post : i32) -> vec4<i32> {
+ return U[p[0]][p[1]];
+}
+
+fn b() {
+ let ptr_index_save = first();
+ let ptr_index_save_1 = second();
+ let p = &(U[ptr_index_save][ptr_index_save_1]);
+ for(; (a_U_X_X(10, U_X_X(u32(ptr_index_save), u32(ptr_index_save_1)), 20).x < 4); ) {
+ let body = 1;
+ }
+}
+
+fn c_U() {
+ let ptr_index_save_2 = first();
+ let ptr_index_save_3 = second();
+ let p2 = &(U[ptr_index_save_2][ptr_index_save_3]);
+ for(; (a_U_X_X(10, U_X_X(u32(ptr_index_save_2), u32(ptr_index_save_3)), 20).x < 4); ) {
+ let body = 1;
+ }
+}
+
+fn d() {
+ c_U();
+}
+)";
+
+ auto output = Run<DirectVariableAccess>(input);
+
+ EXPECT_EQ(expected, str(output));
+}
+
+TEST_F(DirectVariableAccessPtrChainsTest, HoistInForLoopCont) {
+ const char* input = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<uniform> U : array<array<vec4<i32>, 8>, 8>;
+
+var<private> i : i32;
+fn first() -> i32 {
+ i++;
+ return i;
+}
+fn second() -> i32 {
+ i++;
+ return i;
+}
+
+fn a(pre : i32, p : ptr<uniform, vec4<i32>>, post : i32) -> vec4<i32> {
+ return *p;
+}
+
+fn b() {
+ let p = &U[first()][second()];
+ for (var i = 0; i < 3; a(10, p, 20)) {
+ i++;
+ }
+}
+
+fn c(p : ptr<uniform, array<array<vec4<i32>, 8>, 8>>) {
+ let p2 = &(*p)[first()][second()];
+ for (var i = 0; i < 3; a(10, p2, 20)) {
+ i++;
+ }
+}
+
+fn d() {
+ c(&U);
+}
+)";
+ // The indices are expected to be hoisted before the loop and reused in the loop continuing.
+ const char* expected = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<uniform> U : array<array<vec4<i32>, 8>, 8>;
+
+var<private> i : i32;
+
+fn first() -> i32 {
+ i++;
+ return i;
+}
+
+fn second() -> i32 {
+ i++;
+ return i;
+}
+
+alias U_X_X = array<u32, 2u>;
+
+fn a_U_X_X(pre : i32, p : U_X_X, post : i32) -> vec4<i32> {
+ return U[p[0]][p[1]];
+}
+
+fn b() {
+ let ptr_index_save = first();
+ let ptr_index_save_1 = second();
+ let p = &(U[ptr_index_save][ptr_index_save_1]);
+ for(var i = 0; (i < 3); a_U_X_X(10, U_X_X(u32(ptr_index_save), u32(ptr_index_save_1)), 20)) {
+ i++;
+ }
+}
+
+fn c_U() {
+ let ptr_index_save_2 = first();
+ let ptr_index_save_3 = second();
+ let p2 = &(U[ptr_index_save_2][ptr_index_save_3]);
+ for(var i = 0; (i < 3); a_U_X_X(10, U_X_X(u32(ptr_index_save_2), u32(ptr_index_save_3)), 20)) {
+ i++;
+ }
+}
+
+fn d() {
+ c_U();
+}
+)";
+
+ auto output = Run<DirectVariableAccess>(input);
+
+ EXPECT_EQ(expected, str(output));
+}
+
+TEST_F(DirectVariableAccessPtrChainsTest, HoistInWhileCond) {
+ const char* input = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<uniform> U : array<array<vec4<i32>, 8>, 8>;
+
+var<private> i : i32;
+fn first() -> i32 {
+ i++;
+ return i;
+}
+fn second() -> i32 {
+ i++;
+ return i;
+}
+
+fn a(pre : i32, p : ptr<uniform, vec4<i32>>, post : i32) -> vec4<i32> {
+ return *p;
+}
+
+fn b() {
+ let p = &U[first()][second()];
+ while (a(10, p, 20).x < 4) {
+ let body = 1;
+ }
+}
+
+fn c(p : ptr<uniform, array<array<vec4<i32>, 8>, 8>>) {
+ let p2 = &(*p)[first()][second()];
+ while (a(10, p2, 20).x < 4) {
+ let body = 1;
+ }
+}
+
+fn d() {
+ c(&U);
+}
+)";
+ // The indices are expected to be hoisted before the loop and reused in the while condition.
+ const char* expected = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<uniform> U : array<array<vec4<i32>, 8>, 8>;
+
+var<private> i : i32;
+
+fn first() -> i32 {
+ i++;
+ return i;
+}
+
+fn second() -> i32 {
+ i++;
+ return i;
+}
+
+alias U_X_X = array<u32, 2u>;
+
+fn a_U_X_X(pre : i32, p : U_X_X, post : i32) -> vec4<i32> {
+ return U[p[0]][p[1]];
+}
+
+fn b() {
+ let ptr_index_save = first();
+ let ptr_index_save_1 = second();
+ let p = &(U[ptr_index_save][ptr_index_save_1]);
+ while((a_U_X_X(10, U_X_X(u32(ptr_index_save), u32(ptr_index_save_1)), 20).x < 4)) {
+ let body = 1;
+ }
+}
+
+fn c_U() {
+ let ptr_index_save_2 = first();
+ let ptr_index_save_3 = second();
+ let p2 = &(U[ptr_index_save_2][ptr_index_save_3]);
+ while((a_U_X_X(10, U_X_X(u32(ptr_index_save_2), u32(ptr_index_save_3)), 20).x < 4)) {
+ let body = 1;
+ }
+}
+
+fn d() {
+ c_U();
+}
+)";
+
+ auto output = Run<DirectVariableAccess>(input);
+
+ EXPECT_EQ(expected, str(output));
+}
+
+} // namespace pointer_chains_tests
+
+////////////////////////////////////////////////////////////////////////////////
+// 'uniform' address space
+////////////////////////////////////////////////////////////////////////////////
+namespace uniform_as_tests {
+
+using DirectVariableAccessUniformASTest = TransformTest; // Fixture for the TEST_Fs in this namespace.
+
+TEST_F(DirectVariableAccessUniformASTest, Param_ptr_i32_read) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<uniform> U : i32;
+
+fn a(pre : i32, p : ptr<uniform, i32>, post : i32) -> i32 {
+ return *p;
+}
+
+fn b() {
+ a(10, &U, 20);
+}
+)";
+ // Expected: the pointer parameter is removed and the a_U variant reads U directly.
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<uniform> U : i32;
+
+fn a_U(pre : i32, post : i32) -> i32 {
+ return U;
+}
+
+fn b() {
+ a_U(10, 20);
+}
+)";
+
+ auto got = Run<DirectVariableAccess>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DirectVariableAccessUniformASTest, Param_ptr_vec4i32_Via_array_DynamicRead) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<uniform> U : array<vec4<i32>, 8>;
+
+fn a(pre : i32, p : ptr<uniform, vec4<i32>>, post : i32) -> vec4<i32> {
+ return *p;
+}
+
+fn b() {
+ let I = 3;
+ a(10, &U[I], 20);
+}
+)";
+ // Expected: the pointer becomes a U_X index-array parameter; the caller passes u32(I).
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<uniform> U : array<vec4<i32>, 8>;
+
+alias U_X = array<u32, 1u>;
+
+fn a_U_X(pre : i32, p : U_X, post : i32) -> vec4<i32> {
+ return U[p[0]];
+}
+
+fn b() {
+ let I = 3;
+ a_U_X(10, U_X(u32(I)), 20);
+}
+)";
+
+ auto got = Run<DirectVariableAccess>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DirectVariableAccessUniformASTest, CallChaining) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct Inner {
+ mat : mat3x4<f32>,
+};
+
+alias InnerArr = array<Inner, 4>;
+
+struct Outer {
+ arr : InnerArr,
+ mat : mat3x4<f32>,
+};
+
+@group(0) @binding(0) var<uniform> U : Outer;
+
+fn f0(p : ptr<uniform, vec4<f32>>) -> f32 {
+ return (*p).x;
+}
+
+fn f1(p : ptr<uniform, mat3x4<f32>>) -> f32 {
+ var res : f32;
+ {
+ // call f0() with inline usage of p
+ res += f0(&(*p)[1]);
+ }
+ {
+ // call f0() with pointer-let usage of p
+ let p_vec = &(*p)[1];
+ res += f0(p_vec);
+ }
+ {
+ // call f0() with inline usage of U
+ res += f0(&U.arr[2].mat[1]);
+ }
+ {
+ // call f0() with pointer-let usage of U
+ let p_vec = &U.arr[2].mat[1];
+ res += f0(p_vec);
+ }
+ return res;
+}
+
+fn f2(p : ptr<uniform, Inner>) -> f32 {
+ let p_mat = &(*p).mat;
+ return f1(p_mat);
+}
+
+fn f3(p0 : ptr<uniform, InnerArr>, p1 : ptr<uniform, mat3x4<f32>>) -> f32 {
+ let p0_inner = &(*p0)[3];
+ return f2(p0_inner) + f1(p1);
+}
+
+fn f4(p : ptr<uniform, Outer>) -> f32 {
+ return f3(&(*p).arr, &U.mat);
+}
+
+fn b() {
+ f4(&U);
+}
+)";
+ // Expected: no pointer parameters remain; callees access U directly, taking index arrays where needed.
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct Inner {
+ mat : mat3x4<f32>,
+}
+
+alias InnerArr = array<Inner, 4>;
+
+struct Outer {
+ arr : InnerArr,
+ mat : mat3x4<f32>,
+}
+
+@group(0) @binding(0) var<uniform> U : Outer;
+
+alias U_mat_X = array<u32, 1u>;
+
+fn f0_U_mat_X(p : U_mat_X) -> f32 {
+ return U.mat[p[0]].x;
+}
+
+alias U_arr_X_mat_X = array<u32, 2u>;
+
+fn f0_U_arr_X_mat_X(p : U_arr_X_mat_X) -> f32 {
+ return U.arr[p[0]].mat[p[0]].x;
+}
+
+alias U_arr_X_mat_X_1 = array<u32, 2u>;
+
+fn f0_U_arr_X_mat_X_1(p : U_arr_X_mat_X_1) -> f32 {
+ return U.arr[p[0]].mat[p[1]].x;
+}
+
+fn f1_U_mat() -> f32 {
+ var res : f32;
+ {
+ res += f0_U_mat_X(U_mat_X(1));
+ }
+ {
+ let p_vec = &(U.mat[1]);
+ res += f0_U_mat_X(U_mat_X(1));
+ }
+ {
+ res += f0_U_arr_X_mat_X_1(U_arr_X_mat_X_1(2, 1));
+ }
+ {
+ let p_vec = &(U.arr[2].mat[1]);
+ res += f0_U_arr_X_mat_X_1(U_arr_X_mat_X_1(2, 1));
+ }
+ return res;
+}
+
+alias U_arr_X_mat = array<u32, 1u>;
+
+fn f1_U_arr_X_mat(p : U_arr_X_mat) -> f32 {
+ var res : f32;
+ {
+ res += f0_U_arr_X_mat_X(U_arr_X_mat_X(p[0u], 1));
+ }
+ {
+ let p_vec = &(U.arr[p[0]].mat[1]);
+ res += f0_U_arr_X_mat_X(U_arr_X_mat_X(p[0u], 1));
+ }
+ {
+ res += f0_U_arr_X_mat_X_1(U_arr_X_mat_X_1(2, 1));
+ }
+ {
+ let p_vec = &(U.arr[2].mat[1]);
+ res += f0_U_arr_X_mat_X_1(U_arr_X_mat_X_1(2, 1));
+ }
+ return res;
+}
+
+alias U_arr_X = array<u32, 1u>;
+
+fn f2_U_arr_X(p : U_arr_X) -> f32 {
+ let p_mat = &(U.arr[p[0]].mat);
+ return f1_U_arr_X_mat(U_arr_X_mat(p[0u]));
+}
+
+fn f3_U_arr_U_mat() -> f32 {
+ let p0_inner = &(U.arr[3]);
+ return (f2_U_arr_X(U_arr_X(3)) + f1_U_mat());
+}
+
+fn f4_U() -> f32 {
+ return f3_U_arr_U_mat();
+}
+
+fn b() {
+ f4_U();
+}
+)";
+
+ auto got = Run<DirectVariableAccess>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace uniform_as_tests
+
+////////////////////////////////////////////////////////////////////////////////
+// 'storage' address space
+////////////////////////////////////////////////////////////////////////////////
+namespace storage_as_tests {
+
+using DirectVariableAccessStorageASTest = TransformTest; // Fixture for the TEST_Fs in this namespace.
+
+TEST_F(DirectVariableAccessStorageASTest, Param_ptr_i32_Via_struct_read) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct str {
+ i : i32,
+};
+
+@group(0) @binding(0) var<storage> S : str;
+
+fn a(pre : i32, p : ptr<storage, i32>, post : i32) -> i32 {
+ return *p;
+}
+
+fn b() {
+ a(10, &S.i, 20);
+}
+)";
+ // Expected: the pointer parameter is removed and the a_S_i variant reads S.i directly.
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct str {
+ i : i32,
+}
+
+@group(0) @binding(0) var<storage> S : str;
+
+fn a_S_i(pre : i32, post : i32) -> i32 {
+ return S.i;
+}
+
+fn b() {
+ a_S_i(10, 20);
+}
+)";
+
+ auto got = Run<DirectVariableAccess>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DirectVariableAccessStorageASTest, Param_ptr_arr_i32_Via_struct_write) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct str {
+ arr : array<i32, 4>,
+};
+
+@group(0) @binding(0) var<storage, read_write> S : str;
+
+fn a(pre : i32, p : ptr<storage, array<i32, 4>, read_write>, post : i32) {
+ *p = array<i32, 4>();
+}
+
+fn b() {
+ a(10, &S.arr, 20);
+}
+)";
+ // Expected: the pointer parameter is removed and the a_S_arr variant assigns to S.arr directly.
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct str {
+ arr : array<i32, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> S : str;
+
+fn a_S_arr(pre : i32, post : i32) {
+ S.arr = array<i32, 4>();
+}
+
+fn b() {
+ a_S_arr(10, 20);
+}
+)";
+
+ auto got = Run<DirectVariableAccess>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DirectVariableAccessStorageASTest, Param_ptr_vec4i32_Via_array_DynamicWrite) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<storage, read_write> S : array<vec4<i32>, 8>;
+
+fn a(pre : i32, p : ptr<storage, vec4<i32>, read_write>, post : i32) {
+ *p = vec4<i32>();
+}
+
+fn b() {
+ let I = 3;
+ a(10, &S[I], 20);
+}
+)";
+ // Expected: the pointer becomes an S_X index-array parameter; the callee writes S[p[0]].
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<storage, read_write> S : array<vec4<i32>, 8>;
+
+alias S_X = array<u32, 1u>;
+
+fn a_S_X(pre : i32, p : S_X, post : i32) {
+ S[p[0]] = vec4<i32>();
+}
+
+fn b() {
+ let I = 3;
+ a_S_X(10, S_X(u32(I)), 20);
+}
+)";
+
+ auto got = Run<DirectVariableAccess>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DirectVariableAccessStorageASTest, CallChaining) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct Inner {
+ mat : mat3x4<f32>,
+};
+
+alias InnerArr = array<Inner, 4>;
+
+struct Outer {
+ arr : InnerArr,
+ mat : mat3x4<f32>,
+};
+
+@group(0) @binding(0) var<storage> S : Outer;
+
+fn f0(p : ptr<storage, vec4<f32>>) -> f32 {
+ return (*p).x;
+}
+
+fn f1(p : ptr<storage, mat3x4<f32>>) -> f32 {
+ var res : f32;
+ {
+ // call f0() with inline usage of p
+ res += f0(&(*p)[1]);
+ }
+ {
+ // call f0() with pointer-let usage of p
+ let p_vec = &(*p)[1];
+ res += f0(p_vec);
+ }
+ {
+ // call f0() with inline usage of S
+ res += f0(&S.arr[2].mat[1]);
+ }
+ {
+ // call f0() with pointer-let usage of S
+ let p_vec = &S.arr[2].mat[1];
+ res += f0(p_vec);
+ }
+ return res;
+}
+
+fn f2(p : ptr<storage, Inner>) -> f32 {
+ let p_mat = &(*p).mat;
+ return f1(p_mat);
+}
+
+fn f3(p0 : ptr<storage, InnerArr>, p1 : ptr<storage, mat3x4<f32>>) -> f32 {
+ let p0_inner = &(*p0)[3];
+ return f2(p0_inner) + f1(p1);
+}
+
+fn f4(p : ptr<storage, Outer>) -> f32 {
+ return f3(&(*p).arr, &S.mat);
+}
+
+fn b() {
+ f4(&S);
+}
+)";
+ // Expected: no pointer parameters remain; callees access S directly, taking index arrays where needed.
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct Inner {
+ mat : mat3x4<f32>,
+}
+
+alias InnerArr = array<Inner, 4>;
+
+struct Outer {
+ arr : InnerArr,
+ mat : mat3x4<f32>,
+}
+
+@group(0) @binding(0) var<storage> S : Outer;
+
+alias S_mat_X = array<u32, 1u>;
+
+fn f0_S_mat_X(p : S_mat_X) -> f32 {
+ return S.mat[p[0]].x;
+}
+
+alias S_arr_X_mat_X = array<u32, 2u>;
+
+fn f0_S_arr_X_mat_X(p : S_arr_X_mat_X) -> f32 {
+ return S.arr[p[0]].mat[p[0]].x;
+}
+
+alias S_arr_X_mat_X_1 = array<u32, 2u>;
+
+fn f0_S_arr_X_mat_X_1(p : S_arr_X_mat_X_1) -> f32 {
+ return S.arr[p[0]].mat[p[1]].x;
+}
+
+fn f1_S_mat() -> f32 {
+ var res : f32;
+ {
+ res += f0_S_mat_X(S_mat_X(1));
+ }
+ {
+ let p_vec = &(S.mat[1]);
+ res += f0_S_mat_X(S_mat_X(1));
+ }
+ {
+ res += f0_S_arr_X_mat_X_1(S_arr_X_mat_X_1(2, 1));
+ }
+ {
+ let p_vec = &(S.arr[2].mat[1]);
+ res += f0_S_arr_X_mat_X_1(S_arr_X_mat_X_1(2, 1));
+ }
+ return res;
+}
+
+alias S_arr_X_mat = array<u32, 1u>;
+
+fn f1_S_arr_X_mat(p : S_arr_X_mat) -> f32 {
+ var res : f32;
+ {
+ res += f0_S_arr_X_mat_X(S_arr_X_mat_X(p[0u], 1));
+ }
+ {
+ let p_vec = &(S.arr[p[0]].mat[1]);
+ res += f0_S_arr_X_mat_X(S_arr_X_mat_X(p[0u], 1));
+ }
+ {
+ res += f0_S_arr_X_mat_X_1(S_arr_X_mat_X_1(2, 1));
+ }
+ {
+ let p_vec = &(S.arr[2].mat[1]);
+ res += f0_S_arr_X_mat_X_1(S_arr_X_mat_X_1(2, 1));
+ }
+ return res;
+}
+
+alias S_arr_X = array<u32, 1u>;
+
+fn f2_S_arr_X(p : S_arr_X) -> f32 {
+ let p_mat = &(S.arr[p[0]].mat);
+ return f1_S_arr_X_mat(S_arr_X_mat(p[0u]));
+}
+
+fn f3_S_arr_S_mat() -> f32 {
+ let p0_inner = &(S.arr[3]);
+ return (f2_S_arr_X(S_arr_X(3)) + f1_S_mat());
+}
+
+fn f4_S() -> f32 {
+ return f3_S_arr_S_mat();
+}
+
+fn b() {
+ f4_S();
+}
+)";
+
+ auto got = Run<DirectVariableAccess>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace storage_as_tests
+
+////////////////////////////////////////////////////////////////////////////////
+// 'workgroup' address space
+////////////////////////////////////////////////////////////////////////////////
+namespace workgroup_as_tests {
+
+using DirectVariableAccessWorkgroupASTest = TransformTest; // Fixture for the TEST_Fs in this namespace.
+
+TEST_F(DirectVariableAccessWorkgroupASTest, Param_ptr_vec4i32_Via_array_StaticRead) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+var<workgroup> W : array<vec4<i32>, 8>;
+
+fn a(pre : i32, p : ptr<workgroup, vec4<i32>>, post : i32) -> vec4<i32> {
+ return *p;
+}
+
+fn b() {
+ a(10, &W[3], 20);
+}
+)";
+ // Expected: the constant index is still passed through a W_X index array; the callee reads W[p[0]].
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+var<workgroup> W : array<vec4<i32>, 8>;
+
+alias W_X = array<u32, 1u>;
+
+fn a_W_X(pre : i32, p : W_X, post : i32) -> vec4<i32> {
+ return W[p[0]];
+}
+
+fn b() {
+ a_W_X(10, W_X(3), 20);
+}
+)";
+
+ auto got = Run<DirectVariableAccess>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DirectVariableAccessWorkgroupASTest, Param_ptr_vec4i32_Via_array_StaticWrite) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+var<workgroup> W : array<vec4<i32>, 8>;
+
+fn a(pre : i32, p : ptr<workgroup, vec4<i32>>, post : i32) {
+ *p = vec4<i32>();
+}
+
+fn b() {
+ a(10, &W[3], 20);
+}
+)";
+ // Expected: the constant index is still passed through a W_X index array; the callee writes W[p[0]].
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+var<workgroup> W : array<vec4<i32>, 8>;
+
+alias W_X = array<u32, 1u>;
+
+fn a_W_X(pre : i32, p : W_X, post : i32) {
+ W[p[0]] = vec4<i32>();
+}
+
+fn b() {
+ a_W_X(10, W_X(3), 20);
+}
+)";
+
+ auto got = Run<DirectVariableAccess>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DirectVariableAccessWorkgroupASTest, CallChaining) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct Inner {
+ mat : mat3x4<f32>,
+};
+
+alias InnerArr = array<Inner, 4>;
+
+struct Outer {
+ arr : InnerArr,
+ mat : mat3x4<f32>,
+};
+
+var<workgroup> W : Outer;
+
+fn f0(p : ptr<workgroup, vec4<f32>>) -> f32 {
+ return (*p).x;
+}
+
+fn f1(p : ptr<workgroup, mat3x4<f32>>) -> f32 {
+ var res : f32;
+ {
+ // call f0() with inline usage of p
+ res += f0(&(*p)[1]);
+ }
+ {
+ // call f0() with pointer-let usage of p
+ let p_vec = &(*p)[1];
+ res += f0(p_vec);
+ }
+ {
+ // call f0() with inline usage of W
+ res += f0(&W.arr[2].mat[1]);
+ }
+ {
+ // call f0() with pointer-let usage of W
+ let p_vec = &W.arr[2].mat[1];
+ res += f0(p_vec);
+ }
+ return res;
+}
+
+fn f2(p : ptr<workgroup, Inner>) -> f32 {
+ let p_mat = &(*p).mat;
+ return f1(p_mat);
+}
+
+fn f3(p0 : ptr<workgroup, InnerArr>, p1 : ptr<workgroup, mat3x4<f32>>) -> f32 {
+ let p0_inner = &(*p0)[3];
+ return f2(p0_inner) + f1(p1);
+}
+
+fn f4(p : ptr<workgroup, Outer>) -> f32 {
+ return f3(&(*p).arr, &W.mat);
+}
+
+fn b() {
+ f4(&W);
+}
+)";
+ // Expected: no pointer parameters remain; callees access W directly, taking index arrays where needed.
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct Inner {
+ mat : mat3x4<f32>,
+}
+
+alias InnerArr = array<Inner, 4>;
+
+struct Outer {
+ arr : InnerArr,
+ mat : mat3x4<f32>,
+}
+
+var<workgroup> W : Outer;
+
+alias W_mat_X = array<u32, 1u>;
+
+fn f0_W_mat_X(p : W_mat_X) -> f32 {
+ return W.mat[p[0]].x;
+}
+
+alias W_arr_X_mat_X = array<u32, 2u>;
+
+fn f0_W_arr_X_mat_X(p : W_arr_X_mat_X) -> f32 {
+ return W.arr[p[0]].mat[p[0]].x;
+}
+
+alias W_arr_X_mat_X_1 = array<u32, 2u>;
+
+fn f0_W_arr_X_mat_X_1(p : W_arr_X_mat_X_1) -> f32 {
+ return W.arr[p[0]].mat[p[1]].x;
+}
+
+fn f1_W_mat() -> f32 {
+ var res : f32;
+ {
+ res += f0_W_mat_X(W_mat_X(1));
+ }
+ {
+ let p_vec = &(W.mat[1]);
+ res += f0_W_mat_X(W_mat_X(1));
+ }
+ {
+ res += f0_W_arr_X_mat_X_1(W_arr_X_mat_X_1(2, 1));
+ }
+ {
+ let p_vec = &(W.arr[2].mat[1]);
+ res += f0_W_arr_X_mat_X_1(W_arr_X_mat_X_1(2, 1));
+ }
+ return res;
+}
+
+alias W_arr_X_mat = array<u32, 1u>;
+
+fn f1_W_arr_X_mat(p : W_arr_X_mat) -> f32 {
+ var res : f32;
+ {
+ res += f0_W_arr_X_mat_X(W_arr_X_mat_X(p[0u], 1));
+ }
+ {
+ let p_vec = &(W.arr[p[0]].mat[1]);
+ res += f0_W_arr_X_mat_X(W_arr_X_mat_X(p[0u], 1));
+ }
+ {
+ res += f0_W_arr_X_mat_X_1(W_arr_X_mat_X_1(2, 1));
+ }
+ {
+ let p_vec = &(W.arr[2].mat[1]);
+ res += f0_W_arr_X_mat_X_1(W_arr_X_mat_X_1(2, 1));
+ }
+ return res;
+}
+
+alias W_arr_X = array<u32, 1u>;
+
+fn f2_W_arr_X(p : W_arr_X) -> f32 {
+ let p_mat = &(W.arr[p[0]].mat);
+ return f1_W_arr_X_mat(W_arr_X_mat(p[0u]));
+}
+
+fn f3_W_arr_W_mat() -> f32 {
+ let p0_inner = &(W.arr[3]);
+ return (f2_W_arr_X(W_arr_X(3)) + f1_W_mat());
+}
+
+fn f4_W() -> f32 {
+ return f3_W_arr_W_mat();
+}
+
+fn b() {
+ f4_W();
+}
+)";
+
+ auto got = Run<DirectVariableAccess>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace workgroup_as_tests
+
+////////////////////////////////////////////////////////////////////////////////
+// 'private' address space
+////////////////////////////////////////////////////////////////////////////////
+namespace private_as_tests {
+
+using DirectVariableAccessPrivateASTest = TransformTest; // Fixture for the TEST_Fs in this namespace.
+
+TEST_F(DirectVariableAccessPrivateASTest, Enabled_Param_ptr_i32_read) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+fn a(pre : i32, p : ptr<private, i32>, post : i32) -> i32 {
+ return *(p);
+}
+
+var<private> P : i32;
+
+fn b() {
+ a(10, &(P), 20);
+}
+)";
+ // Expected: the callee is renamed a_F but keeps its pointer parameter; the caller still passes &(P).
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+fn a_F(pre : i32, p : ptr<private, i32>, post : i32) -> i32 {
+ return *(p);
+}
+
+var<private> P : i32;
+
+fn b() {
+ a_F(10, &(P), 20);
+}
+)";
+
+ auto got = Run<DirectVariableAccess>(src, EnablePrivate());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DirectVariableAccessPrivateASTest, Enabled_Param_ptr_i32_write) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+fn a(pre : i32, p : ptr<private, i32>, post : i32) {
+ *(p) = 42;
+}
+
+var<private> P : i32;
+
+fn b() {
+ a(10, &(P), 20);
+}
+)";
+ // Expected: the callee is renamed a_F but keeps its pointer parameter; the caller still passes &(P).
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+fn a_F(pre : i32, p : ptr<private, i32>, post : i32) {
+ *(p) = 42;
+}
+
+var<private> P : i32;
+
+fn b() {
+ a_F(10, &(P), 20);
+}
+)";
+
+ auto got = Run<DirectVariableAccess>(src, EnablePrivate());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DirectVariableAccessPrivateASTest, Enabled_Param_ptr_i32_Via_struct_read) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct str {
+ i : i32,
+};
+
+fn a(pre : i32, p : ptr<private, i32>, post : i32) -> i32 {
+ return *p;
+}
+
+var<private> P : str;
+
+fn b() {
+ a(10, &P.i, 20);
+}
+)";
+ // Expected: a_F_i takes a pointer to the whole str and applies the .i member access itself.
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct str {
+ i : i32,
+}
+
+fn a_F_i(pre : i32, p : ptr<private, str>, post : i32) -> i32 {
+ return (*(p)).i;
+}
+
+var<private> P : str;
+
+fn b() {
+ a_F_i(10, &(P), 20);
+}
+)";
+
+ auto got = Run<DirectVariableAccess>(src, EnablePrivate());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DirectVariableAccessPrivateASTest, Disabled_Param_ptr_i32_Via_struct_read) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct str {
+ i : i32,
+}
+
+fn a(pre : i32, p : ptr<private, i32>, post : i32) -> i32 {
+ return *(p);
+}
+
+var<private> P : str;
+
+fn b() {
+ a(10, &(P.i), 20);
+}
+)";
+ // No EnablePrivate() here: the module is expected to be returned unchanged.
+ auto* expect = src;
+
+ auto got = Run<DirectVariableAccess>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DirectVariableAccessPrivateASTest, Enabled_Param_ptr_arr_i32_Via_struct_write) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct str {
+ arr : array<i32, 4>,
+};
+
+fn a(pre : i32, p : ptr<private, array<i32, 4>>, post : i32) {
+ *p = array<i32, 4>();
+}
+
+var<private> P : str;
+
+fn b() {
+ a(10, &P.arr, 20);
+}
+)";
+ // Expected: a_F_arr takes a pointer to the whole str and assigns through its .arr member.
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct str {
+ arr : array<i32, 4>,
+}
+
+fn a_F_arr(pre : i32, p : ptr<private, str>, post : i32) {
+ (*(p)).arr = array<i32, 4>();
+}
+
+var<private> P : str;
+
+fn b() {
+ a_F_arr(10, &(P), 20);
+}
+)";
+
+ auto got = Run<DirectVariableAccess>(src, EnablePrivate());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DirectVariableAccessPrivateASTest, Disabled_Param_ptr_arr_i32_Via_struct_write) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct str {
+ arr : array<i32, 4>,
+}
+
+fn a(pre : i32, p : ptr<private, array<i32, 4>>, post : i32) {
+ *(p) = array<i32, 4>();
+}
+
+var<private> P : str;
+
+fn b() {
+ a(10, &(P.arr), 20);
+}
+)";
+ // No EnablePrivate() here: the module is expected to be returned unchanged.
+ auto* expect = src;
+
+ auto got = Run<DirectVariableAccess>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DirectVariableAccessPrivateASTest, Enabled_Param_ptr_i32_mixed) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct str {
+ i : i32,
+};
+
+fn a(pre : i32, p : ptr<private, i32>, post : i32) -> i32 {
+ return *p;
+}
+
+var<private> Pi : i32;
+var<private> Ps : str;
+var<private> Pa : array<i32, 4>;
+
+fn b() {
+ a(10, &Pi, 20);
+ a(30, &Ps.i, 40);
+ a(50, &Pa[2], 60);
+}
+)";
+ // Expected: one variant of a() per call-site shape: a_F (whole), a_F_i (member), a_F_X (base + indices).
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct str {
+ i : i32,
+}
+
+fn a_F(pre : i32, p : ptr<private, i32>, post : i32) -> i32 {
+ return *(p);
+}
+
+fn a_F_i(pre : i32, p : ptr<private, str>, post : i32) -> i32 {
+ return (*(p)).i;
+}
+
+alias F_X = array<u32, 1u>;
+
+fn a_F_X(pre : i32, p_base : ptr<private, array<i32, 4u>>, p_indices : F_X, post : i32) -> i32 {
+ return (*(p_base))[p_indices[0]];
+}
+
+var<private> Pi : i32;
+
+var<private> Ps : str;
+
+var<private> Pa : array<i32, 4>;
+
+alias F_X_1 = array<u32, 1u>;
+
+fn b() {
+ a_F(10, &(Pi), 20);
+ a_F_i(30, &(Ps), 40);
+ a_F_X(50, &(Pa), F_X_1(2), 60);
+}
+)";
+
+ auto got = Run<DirectVariableAccess>(src, EnablePrivate());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DirectVariableAccessPrivateASTest, Disabled_Param_ptr_i32_mixed) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct str {
+ i : i32,
+}
+
+fn a(pre : i32, p : ptr<private, i32>, post : i32) -> i32 {
+ return *(p);
+}
+
+var<private> Pi : i32;
+
+var<private> Ps : str;
+
+var<private> Pa : array<i32, 4>;
+
+fn b() {
+ a(10, &(Pi), 20);
+ a(10, &(Ps.i), 20);
+ a(10, &(Pa[2]), 20);
+}
+)";
+ // No EnablePrivate() here: the module is expected to be returned unchanged.
+ auto* expect = src;
+
+ auto got = Run<DirectVariableAccess>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DirectVariableAccessPrivateASTest, Enabled_CallChaining) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct Inner {
+ mat : mat3x4<f32>,
+};
+
+alias InnerArr = array<Inner, 4>;
+
+struct Outer {
+ arr : InnerArr,
+ mat : mat3x4<f32>,
+};
+
+var<private> P : Outer;
+
+fn f0(p : ptr<private, vec4<f32>>) -> f32 {
+ return (*p).x;
+}
+
+fn f1(p : ptr<private, mat3x4<f32>>) -> f32 {
+ var res : f32;
+ {
+ // call f0() with inline usage of p
+ res += f0(&(*p)[1]);
+ }
+ {
+ // call f0() with pointer-let usage of p
+ let p_vec = &(*p)[1];
+ res += f0(p_vec);
+ }
+ {
+ // call f0() with inline usage of P
+ res += f0(&P.arr[2].mat[1]);
+ }
+ {
+ // call f0() with pointer-let usage of P
+ let p_vec = &P.arr[2].mat[1];
+ res += f0(p_vec);
+ }
+ return res;
+}
+
+fn f2(p : ptr<private, Inner>) -> f32 {
+ let p_mat = &(*p).mat;
+ return f1(p_mat);
+}
+
+fn f3(p0 : ptr<private, InnerArr>, p1 : ptr<private, mat3x4<f32>>) -> f32 {
+ let p0_inner = &(*p0)[3];
+ return f2(p0_inner) + f1(p1);
+}
+
+fn f4(p : ptr<private, Outer>) -> f32 {
+ return f3(&(*p).arr, &P.mat);
+}
+
+fn b() {
+ f4(&P);
+}
+)";
+ // Expected: every pointer parameter becomes a pointer to the root Outer, plus index arrays where needed.
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct Inner {
+ mat : mat3x4<f32>,
+}
+
+alias InnerArr = array<Inner, 4>;
+
+struct Outer {
+ arr : InnerArr,
+ mat : mat3x4<f32>,
+}
+
+var<private> P : Outer;
+
+alias F_mat_X = array<u32, 1u>;
+
+fn f0_F_mat_X(p_base : ptr<private, Outer>, p_indices : F_mat_X) -> f32 {
+ return (*(p_base)).mat[p_indices[0]].x;
+}
+
+alias F_arr_X_mat_X = array<u32, 2u>;
+
+fn f0_F_arr_X_mat_X(p_base : ptr<private, Outer>, p_indices : F_arr_X_mat_X) -> f32 {
+ return (*(p_base)).arr[p_indices[0]].mat[p_indices[0]].x;
+}
+
+alias F_arr_X_mat_X_1 = array<u32, 2u>;
+
+fn f0_F_arr_X_mat_X_1(p_base : ptr<private, Outer>, p_indices : F_arr_X_mat_X_1) -> f32 {
+ return (*(p_base)).arr[p_indices[0]].mat[p_indices[1]].x;
+}
+
+alias F_mat_X_1 = array<u32, 1u>;
+
+alias F_arr_X_mat_X_2 = array<u32, 2u>;
+
+fn f1_F_mat(p : ptr<private, Outer>) -> f32 {
+ var res : f32;
+ {
+ res += f0_F_mat_X(p, F_mat_X_1(1));
+ }
+ {
+ let p_vec = &((*(p)).mat[1]);
+ res += f0_F_mat_X(p, F_mat_X_1(1));
+ }
+ {
+ res += f0_F_arr_X_mat_X_1(&(P), F_arr_X_mat_X_2(2, 1));
+ }
+ {
+ let p_vec = &(P.arr[2].mat[1]);
+ res += f0_F_arr_X_mat_X_1(&(P), F_arr_X_mat_X_2(2, 1));
+ }
+ return res;
+}
+
+alias F_arr_X_mat = array<u32, 1u>;
+
+alias F_arr_X_mat_X_3 = array<u32, 2u>;
+
+fn f1_F_arr_X_mat(p_base : ptr<private, Outer>, p_indices : F_arr_X_mat) -> f32 {
+ var res : f32;
+ {
+ res += f0_F_arr_X_mat_X(p_base, F_arr_X_mat_X_3(p_indices[0u], 1));
+ }
+ {
+ let p_vec = &((*(p_base)).arr[p_indices[0]].mat[1]);
+ res += f0_F_arr_X_mat_X(p_base, F_arr_X_mat_X_3(p_indices[0u], 1));
+ }
+ {
+ res += f0_F_arr_X_mat_X_1(&(P), F_arr_X_mat_X_2(2, 1));
+ }
+ {
+ let p_vec = &(P.arr[2].mat[1]);
+ res += f0_F_arr_X_mat_X_1(&(P), F_arr_X_mat_X_2(2, 1));
+ }
+ return res;
+}
+
+alias F_arr_X = array<u32, 1u>;
+
+alias F_arr_X_mat_1 = array<u32, 1u>;
+
+fn f2_F_arr_X(p_base : ptr<private, Outer>, p_indices : F_arr_X) -> f32 {
+ let p_mat = &((*(p_base)).arr[p_indices[0]].mat);
+ return f1_F_arr_X_mat(p_base, F_arr_X_mat_1(p_indices[0u]));
+}
+
+alias F_arr_X_1 = array<u32, 1u>;
+
+fn f3_F_arr_F_mat(p0 : ptr<private, Outer>, p1 : ptr<private, Outer>) -> f32 {
+ let p0_inner = &((*(p0)).arr[3]);
+ return (f2_F_arr_X(p0, F_arr_X_1(3)) + f1_F_mat(p1));
+}
+
+fn f4_F(p : ptr<private, Outer>) -> f32 {
+ return f3_F_arr_F_mat(p, &(P));
+}
+
+fn b() {
+ f4_F(&(P));
+}
+)";
+
+ auto got = Run<DirectVariableAccess>(src, EnablePrivate());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DirectVariableAccessPrivateASTest, Disabled_CallChaining) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct Inner {
+ mat : mat3x4<f32>,
+}
+
+alias InnerArr = array<Inner, 4>;
+
+struct Outer {
+ arr : InnerArr,
+ mat : mat3x4<f32>,
+}
+
+var<private> P : Outer;
+
+fn f0(p : ptr<private, vec4<f32>>) -> f32 {
+ return (*(p)).x;
+}
+
+fn f1(p : ptr<private, mat3x4<f32>>) -> f32 {
+ var res : f32;
+ {
+ res += f0(&((*(p))[1]));
+ }
+ {
+ let p_vec = &((*(p))[1]);
+ res += f0(p_vec);
+ }
+ {
+ res += f0(&(P.arr[2].mat[1]));
+ }
+ {
+ let p_vec = &(P.arr[2].mat[1]);
+ res += f0(p_vec);
+ }
+ return res;
+}
+
+fn f2(p : ptr<private, Inner>) -> f32 {
+ let p_mat = &((*(p)).mat);
+ return f1(p_mat);
+}
+
+fn f3(p0 : ptr<private, InnerArr>, p1 : ptr<private, mat3x4<f32>>) -> f32 {
+ let p0_inner = &((*(p0))[3]);
+ return (f2(p0_inner) + f1(p1));
+}
+
+fn f4(p : ptr<private, Outer>) -> f32 {
+ return f3(&((*(p)).arr), &(P.mat));
+}
+
+fn b() {
+ f4(&(P));
+}
+)";
+ // No EnablePrivate() here: the module is expected to be returned unchanged.
+ auto* expect = src;
+
+ auto got = Run<DirectVariableAccess>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace private_as_tests
+
+////////////////////////////////////////////////////////////////////////////////
+// 'function' address space
+////////////////////////////////////////////////////////////////////////////////
+namespace function_as_tests {
+
+using DirectVariableAccessFunctionASTest = TransformTest; // Fixture for the TEST_Fs in this namespace.
+
+TEST_F(DirectVariableAccessFunctionASTest, Enabled_LocalPtr) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+fn f() {
+ var v : i32;
+ let p : ptr<function, i32> = &(v);
+ var x : i32 = *(p);
+}
+)";
+ // A local pointer-let that is never passed as an argument must survive the transform untouched.
+ auto* expect = src; // Nothing changes
+ // EnableFunction() makes this the "Enabled" case its name claims; without it only the disabled path runs.
+ auto got = Run<DirectVariableAccess>(src, EnableFunction());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DirectVariableAccessFunctionASTest, Enabled_Param_ptr_i32_read) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+fn a(pre : i32, p : ptr<function, i32>, post : i32) -> i32 {
+ return *(p);
+}
+
+fn b() {
+ var F : i32;
+ a(10, &(F), 20);
+}
+)";
+ // Expected: the callee is renamed a_F but keeps its pointer parameter; the caller still passes &(F).
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+fn a_F(pre : i32, p : ptr<function, i32>, post : i32) -> i32 {
+ return *(p);
+}
+
+fn b() {
+ var F : i32;
+ a_F(10, &(F), 20);
+}
+)";
+
+ auto got = Run<DirectVariableAccess>(src, EnableFunction());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DirectVariableAccessFunctionASTest, Enabled_Param_ptr_i32_write) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+fn a(pre : i32, p : ptr<function, i32>, post : i32) {
+ *(p) = 42;
+}
+
+fn b() {
+ var F : i32;
+ a(10, &(F), 20);
+}
+)";
+ // Expected: the callee is renamed a_F but keeps its pointer parameter; the caller still passes &(F).
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+fn a_F(pre : i32, p : ptr<function, i32>, post : i32) {
+ *(p) = 42;
+}
+
+fn b() {
+ var F : i32;
+ a_F(10, &(F), 20);
+}
+)";
+
+ auto got = Run<DirectVariableAccess>(src, EnableFunction());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DirectVariableAccessFunctionASTest, Enabled_Param_ptr_i32_Via_struct_read) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct str {
+ i : i32,
+};
+
+fn a(pre : i32, p : ptr<function, i32>, post : i32) -> i32 {
+ return *p;
+}
+
+fn b() {
+ var F : str;
+ a(10, &F.i, 20);
+}
+)";
+ // Expected: a_F_i takes a pointer to the whole str and applies the .i member access itself.
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct str {
+ i : i32,
+}
+
+fn a_F_i(pre : i32, p : ptr<function, str>, post : i32) -> i32 {
+ return (*(p)).i;
+}
+
+fn b() {
+ var F : str;
+ a_F_i(10, &(F), 20);
+}
+)";
+
+ auto got = Run<DirectVariableAccess>(src, EnableFunction());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DirectVariableAccessFunctionASTest, Enabled_Param_ptr_arr_i32_Via_struct_write) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct str {
+ arr : array<i32, 4>,
+};
+
+fn a(pre : i32, p : ptr<function, array<i32, 4>>, post : i32) {
+ *p = array<i32, 4>();
+}
+
+fn b() {
+ var F : str;
+ a(10, &F.arr, 20);
+}
+)";
+ // Expected: a_F_arr takes a pointer to the whole str and assigns through its .arr member.
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct str {
+ arr : array<i32, 4>,
+}
+
+fn a_F_arr(pre : i32, p : ptr<function, str>, post : i32) {
+ (*(p)).arr = array<i32, 4>();
+}
+
+fn b() {
+ var F : str;
+ a_F_arr(10, &(F), 20);
+}
+)";
+
+ auto got = Run<DirectVariableAccess>(src, EnableFunction());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DirectVariableAccessFunctionASTest, Enabled_Param_ptr_i32_mixed) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct str {
+ i : i32,
+};
+
+fn a(pre : i32, p : ptr<function, i32>, post : i32) -> i32 {
+ return *p;
+}
+
+fn b() {
+ var Fi : i32;
+ var Fs : str;
+ var Fa : array<i32, 4>;
+
+ a(10, &Fi, 20);
+ a(30, &Fs.i, 40);
+ a(50, &Fa[2], 60);
+}
+)";
+ // Expected: one variant of a() per call-site shape: a_F (whole), a_F_i (member), a_F_X (base + indices).
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct str {
+ i : i32,
+}
+
+fn a_F(pre : i32, p : ptr<function, i32>, post : i32) -> i32 {
+ return *(p);
+}
+
+fn a_F_i(pre : i32, p : ptr<function, str>, post : i32) -> i32 {
+ return (*(p)).i;
+}
+
+alias F_X = array<u32, 1u>;
+
+fn a_F_X(pre : i32, p_base : ptr<function, array<i32, 4u>>, p_indices : F_X, post : i32) -> i32 {
+ return (*(p_base))[p_indices[0]];
+}
+
+alias F_X_1 = array<u32, 1u>;
+
+fn b() {
+ var Fi : i32;
+ var Fs : str;
+ var Fa : array<i32, 4>;
+ a_F(10, &(Fi), 20);
+ a_F_i(30, &(Fs), 40);
+ a_F_X(50, &(Fa), F_X_1(2), 60);
+}
+)";
+
+ auto got = Run<DirectVariableAccess>(src, EnableFunction());
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DirectVariableAccessFunctionASTest, Disabled_Param_ptr_i32_Via_struct_read) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct str {
+ i : i32,
+}
+
+fn a(pre : i32, p : ptr<function, i32>, post : i32) -> i32 {
+ return *(p);
+}
+
+fn b() {
+ var F : str;
+ a(10, &(F.i), 20);
+}
+)";
+ // No EnableFunction() here: the module is expected to be returned unchanged.
+ auto* expect = src;
+
+ auto got = Run<DirectVariableAccess>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DirectVariableAccessFunctionASTest, Disabled_Param_ptr_arr_i32_Via_struct_write) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct str {
+ arr : array<i32, 4>,
+}
+
+fn a(pre : i32, p : ptr<function, array<i32, 4>>, post : i32) {
+ *(p) = array<i32, 4>();
+}
+
+fn b() {
+ var F : str;
+ a(10, &(F.arr), 20);
+}
+)";
+ // No EnableFunction() here: the module is expected to be returned unchanged.
+ auto* expect = src;
+
+ auto got = Run<DirectVariableAccess>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace function_as_tests
+
+////////////////////////////////////////////////////////////////////////////////
+// complex tests
+////////////////////////////////////////////////////////////////////////////////
+namespace complex_tests {
+
+using DirectVariableAccessComplexTest = TransformTest;
+
+TEST_F(DirectVariableAccessComplexTest, Param_ptr_mixed_vec4i32_ViaMultiple) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct str {
+ i : vec4<i32>,
+};
+
+@group(0) @binding(0) var<uniform> U : vec4<i32>;
+@group(0) @binding(1) var<uniform> U_str : str;
+@group(0) @binding(2) var<uniform> U_arr : array<vec4<i32>, 8>;
+@group(0) @binding(3) var<uniform> U_arr_arr : array<array<vec4<i32>, 8>, 4>;
+
+@group(1) @binding(0) var<storage> S : vec4<i32>;
+@group(1) @binding(1) var<storage> S_str : str;
+@group(1) @binding(2) var<storage> S_arr : array<vec4<i32>, 8>;
+@group(1) @binding(3) var<storage> S_arr_arr : array<array<vec4<i32>, 8>, 4>;
+
+ var<workgroup> W : vec4<i32>;
+ var<workgroup> W_str : str;
+ var<workgroup> W_arr : array<vec4<i32>, 8>;
+ var<workgroup> W_arr_arr : array<array<vec4<i32>, 8>, 4>;
+
+fn fn_u(p : ptr<uniform, vec4<i32>>) -> vec4<i32> {
+ return *p;
+}
+
+fn fn_s(p : ptr<storage, vec4<i32>>) -> vec4<i32> {
+ return *p;
+}
+
+fn fn_w(p : ptr<workgroup, vec4<i32>>) -> vec4<i32> {
+ return *p;
+}
+
+fn b() {
+ let I = 3;
+ let J = 4;
+
+ let u = fn_u(&U);
+ let u_str = fn_u(&U_str.i);
+ let u_arr0 = fn_u(&U_arr[0]);
+ let u_arr1 = fn_u(&U_arr[1]);
+ let u_arrI = fn_u(&U_arr[I]);
+ let u_arr1_arr0 = fn_u(&U_arr_arr[1][0]);
+ let u_arr2_arrI = fn_u(&U_arr_arr[2][I]);
+ let u_arrI_arr2 = fn_u(&U_arr_arr[I][2]);
+ let u_arrI_arrJ = fn_u(&U_arr_arr[I][J]);
+
+ let s = fn_s(&S);
+ let s_str = fn_s(&S_str.i);
+ let s_arr0 = fn_s(&S_arr[0]);
+ let s_arr1 = fn_s(&S_arr[1]);
+ let s_arrI = fn_s(&S_arr[I]);
+ let s_arr1_arr0 = fn_s(&S_arr_arr[1][0]);
+ let s_arr2_arrI = fn_s(&S_arr_arr[2][I]);
+ let s_arrI_arr2 = fn_s(&S_arr_arr[I][2]);
+ let s_arrI_arrJ = fn_s(&S_arr_arr[I][J]);
+
+ let w = fn_w(&W);
+ let w_str = fn_w(&W_str.i);
+ let w_arr0 = fn_w(&W_arr[0]);
+ let w_arr1 = fn_w(&W_arr[1]);
+ let w_arrI = fn_w(&W_arr[I]);
+ let w_arr1_arr0 = fn_w(&W_arr_arr[1][0]);
+ let w_arr2_arrI = fn_w(&W_arr_arr[2][I]);
+ let w_arrI_arr2 = fn_w(&W_arr_arr[I][2]);
+ let w_arrI_arrJ = fn_w(&W_arr_arr[I][J]);
+}
+)";
+
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct str {
+ i : vec4<i32>,
+}
+
+@group(0) @binding(0) var<uniform> U : vec4<i32>;
+
+@group(0) @binding(1) var<uniform> U_str : str;
+
+@group(0) @binding(2) var<uniform> U_arr : array<vec4<i32>, 8>;
+
+@group(0) @binding(3) var<uniform> U_arr_arr : array<array<vec4<i32>, 8>, 4>;
+
+@group(1) @binding(0) var<storage> S : vec4<i32>;
+
+@group(1) @binding(1) var<storage> S_str : str;
+
+@group(1) @binding(2) var<storage> S_arr : array<vec4<i32>, 8>;
+
+@group(1) @binding(3) var<storage> S_arr_arr : array<array<vec4<i32>, 8>, 4>;
+
+var<workgroup> W : vec4<i32>;
+
+var<workgroup> W_str : str;
+
+var<workgroup> W_arr : array<vec4<i32>, 8>;
+
+var<workgroup> W_arr_arr : array<array<vec4<i32>, 8>, 4>;
+
+fn fn_u_U() -> vec4<i32> {
+ return U;
+}
+
+fn fn_u_U_str_i() -> vec4<i32> {
+ return U_str.i;
+}
+
+alias U_arr_X = array<u32, 1u>;
+
+fn fn_u_U_arr_X(p : U_arr_X) -> vec4<i32> {
+ return U_arr[p[0]];
+}
+
+alias U_arr_arr_X_X = array<u32, 2u>;
+
+fn fn_u_U_arr_arr_X_X(p : U_arr_arr_X_X) -> vec4<i32> {
+ return U_arr_arr[p[0]][p[1]];
+}
+
+fn fn_s_S() -> vec4<i32> {
+ return S;
+}
+
+fn fn_s_S_str_i() -> vec4<i32> {
+ return S_str.i;
+}
+
+alias S_arr_X = array<u32, 1u>;
+
+fn fn_s_S_arr_X(p : S_arr_X) -> vec4<i32> {
+ return S_arr[p[0]];
+}
+
+alias S_arr_arr_X_X = array<u32, 2u>;
+
+fn fn_s_S_arr_arr_X_X(p : S_arr_arr_X_X) -> vec4<i32> {
+ return S_arr_arr[p[0]][p[1]];
+}
+
+fn fn_w_W() -> vec4<i32> {
+ return W;
+}
+
+fn fn_w_W_str_i() -> vec4<i32> {
+ return W_str.i;
+}
+
+alias W_arr_X = array<u32, 1u>;
+
+fn fn_w_W_arr_X(p : W_arr_X) -> vec4<i32> {
+ return W_arr[p[0]];
+}
+
+alias W_arr_arr_X_X = array<u32, 2u>;
+
+fn fn_w_W_arr_arr_X_X(p : W_arr_arr_X_X) -> vec4<i32> {
+ return W_arr_arr[p[0]][p[1]];
+}
+
+fn b() {
+ let I = 3;
+ let J = 4;
+ let u = fn_u_U();
+ let u_str = fn_u_U_str_i();
+ let u_arr0 = fn_u_U_arr_X(U_arr_X(0));
+ let u_arr1 = fn_u_U_arr_X(U_arr_X(1));
+ let u_arrI = fn_u_U_arr_X(U_arr_X(u32(I)));
+ let u_arr1_arr0 = fn_u_U_arr_arr_X_X(U_arr_arr_X_X(1, 0));
+ let u_arr2_arrI = fn_u_U_arr_arr_X_X(U_arr_arr_X_X(2, u32(I)));
+ let u_arrI_arr2 = fn_u_U_arr_arr_X_X(U_arr_arr_X_X(u32(I), 2));
+ let u_arrI_arrJ = fn_u_U_arr_arr_X_X(U_arr_arr_X_X(u32(I), u32(J)));
+ let s = fn_s_S();
+ let s_str = fn_s_S_str_i();
+ let s_arr0 = fn_s_S_arr_X(S_arr_X(0));
+ let s_arr1 = fn_s_S_arr_X(S_arr_X(1));
+ let s_arrI = fn_s_S_arr_X(S_arr_X(u32(I)));
+ let s_arr1_arr0 = fn_s_S_arr_arr_X_X(S_arr_arr_X_X(1, 0));
+ let s_arr2_arrI = fn_s_S_arr_arr_X_X(S_arr_arr_X_X(2, u32(I)));
+ let s_arrI_arr2 = fn_s_S_arr_arr_X_X(S_arr_arr_X_X(u32(I), 2));
+ let s_arrI_arrJ = fn_s_S_arr_arr_X_X(S_arr_arr_X_X(u32(I), u32(J)));
+ let w = fn_w_W();
+ let w_str = fn_w_W_str_i();
+ let w_arr0 = fn_w_W_arr_X(W_arr_X(0));
+ let w_arr1 = fn_w_W_arr_X(W_arr_X(1));
+ let w_arrI = fn_w_W_arr_X(W_arr_X(u32(I)));
+ let w_arr1_arr0 = fn_w_W_arr_arr_X_X(W_arr_arr_X_X(1, 0));
+ let w_arr2_arrI = fn_w_W_arr_arr_X_X(W_arr_arr_X_X(2, u32(I)));
+ let w_arrI_arr2 = fn_w_W_arr_arr_X_X(W_arr_arr_X_X(u32(I), 2));
+ let w_arrI_arrJ = fn_w_W_arr_arr_X_X(W_arr_arr_X_X(u32(I), u32(J)));
+}
+)";
+
+ auto got = Run<DirectVariableAccess>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DirectVariableAccessComplexTest, Indexing) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<storage> S : array<array<array<array<i32, 9>, 9>, 9>, 50>;
+
+fn a(i : i32) -> i32 { return i; }
+
+fn b(p : ptr<storage, array<array<array<i32, 9>, 9>, 9>>) -> i32 {
+ return (*p) [ a( (*p)[0][1][2] )]
+ [ a( (*p)[a(3)][4][5] )]
+ [ a( (*p)[6][a(7)][8] )];
+}
+
+fn c() {
+ let v = b(&S[42]);
+}
+)";
+
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<storage> S : array<array<array<array<i32, 9>, 9>, 9>, 50>;
+
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+alias S_X = array<u32, 1u>;
+
+fn b_S_X(p : S_X) -> i32 {
+ return S[p[0]][a(S[p[0]][0][1][2])][a(S[p[0]][a(3)][4][5])][a(S[p[0]][6][a(7)][8])];
+}
+
+fn c() {
+ let v = b_S_X(S_X(42));
+}
+)";
+
+ auto got = Run<DirectVariableAccess>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DirectVariableAccessComplexTest, IndexingInPtrCall) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<storage> S : array<array<array<array<i32, 9>, 9>, 9>, 50>;
+
+fn a(pre : i32, i : ptr<storage, i32>, post : i32) -> i32 {
+ return *i;
+}
+
+fn b(p : ptr<storage, array<array<array<i32, 9>, 9>, 9>>) -> i32 {
+ return a(10, &(*p)[ a( 20, &(*p)[0][1][2], 30 )]
+ [ a( 40, &(*p)[3][4][5], 50 )]
+ [ a( 60, &(*p)[6][7][8], 70 )], 80);
+}
+
+fn c() {
+ let v = b(&S[42]);
+}
+)";
+
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<storage> S : array<array<array<array<i32, 9>, 9>, 9>, 50>;
+
+alias S_X_X_X_X = array<u32, 4u>;
+
+fn a_S_X_X_X_X(pre : i32, i : S_X_X_X_X, post : i32) -> i32 {
+ return S[i[0]][i[0]][i[1]][i[2]];
+}
+
+alias S_X = array<u32, 1u>;
+
+fn b_S_X(p : S_X) -> i32 {
+ return a_S_X_X_X_X(10, S_X_X_X_X(p[0u], u32(a_S_X_X_X_X(20, S_X_X_X_X(p[0u], 0, 1, 2), 30)), u32(a_S_X_X_X_X(40, S_X_X_X_X(p[0u], 3, 4, 5), 50)), u32(a_S_X_X_X_X(60, S_X_X_X_X(p[0u], 6, 7, 8), 70))), 80);
+}
+
+fn c() {
+ let v = b_S_X(S_X(42));
+}
+)";
+
+ auto got = Run<DirectVariableAccess>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DirectVariableAccessComplexTest, IndexingDualPointers) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<storage> S : array<array<array<i32, 9>, 9>, 50>;
+@group(0) @binding(0) var<uniform> U : array<array<array<vec4<i32>, 9>, 9>, 50>;
+
+fn a(i : i32) -> i32 { return i; }
+
+fn b(s : ptr<storage, array<array<i32, 9>, 9>>,
+ u : ptr<uniform, array<array<vec4<i32>, 9>, 9>>) -> i32 {
+ return (*s) [ a( (*u)[0][1].x )]
+ [ a( (*u)[a(3)][4].y )];
+}
+
+fn c() {
+ let v = b(&S[42], &U[24]);
+}
+)";
+
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<storage> S : array<array<array<i32, 9>, 9>, 50>;
+
+@group(0) @binding(0) var<uniform> U : array<array<array<vec4<i32>, 9>, 9>, 50>;
+
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+alias S_X = array<u32, 1u>;
+
+alias U_X = array<u32, 1u>;
+
+fn b_S_X_U_X(s : S_X, u : U_X) -> i32 {
+ return S[s[0]][a(U[u[0]][0][1].x)][a(U[u[0]][a(3)][4].y)];
+}
+
+fn c() {
+ let v = b_S_X_U_X(S_X(42), U_X(24));
+}
+)";
+
+ auto got = Run<DirectVariableAccess>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace complex_tests
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/disable_uniformity_analysis.cc b/src/tint/lang/wgsl/ast/transform/disable_uniformity_analysis.cc
new file mode 100644
index 0000000..d02d9de
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/disable_uniformity_analysis.cc
@@ -0,0 +1,46 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/disable_uniformity_analysis.h"
+
+#include <utility>
+
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/module.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::DisableUniformityAnalysis);
+
+namespace tint::ast::transform {
+
+DisableUniformityAnalysis::DisableUniformityAnalysis() = default;
+
+DisableUniformityAnalysis::~DisableUniformityAnalysis() = default;
+
+Transform::ApplyResult DisableUniformityAnalysis::Apply(const Program* src,
+                                                        const DataMap&,
+                                                        DataMap&) const {
+    if (src->Sem().Module()->Extensions().Contains(
+            builtin::Extension::kChromiumDisableUniformityAnalysis)) {
+        return SkipTransform;  // Extension is already enabled; nothing to add.
+    }
+
+    ProgramBuilder b;
+    CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+    b.Enable(builtin::Extension::kChromiumDisableUniformityAnalysis);  // Added before cloning.
+
+    ctx.Clone();  // Clone the source program into `b`.
+    return Program(std::move(b));
+}
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/disable_uniformity_analysis.h b/src/tint/lang/wgsl/ast/transform/disable_uniformity_analysis.h
new file mode 100644
index 0000000..4205acd
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/disable_uniformity_analysis.h
@@ -0,0 +1,39 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_DISABLE_UNIFORMITY_ANALYSIS_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_DISABLE_UNIFORMITY_ANALYSIS_H_
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+
+namespace tint::ast::transform {
+
+/// Disables uniformity analysis by enabling the `chromium_disable_uniformity_analysis` extension.
+class DisableUniformityAnalysis final
+    : public utils::Castable<DisableUniformityAnalysis, Transform> {
+  public:
+    /// Constructor
+    DisableUniformityAnalysis();
+    /// Destructor
+    ~DisableUniformityAnalysis() override;
+
+    /// @copydoc Transform::Apply
+    ApplyResult Apply(const Program* program,
+                      const DataMap& inputs,
+                      DataMap& outputs) const override;
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_DISABLE_UNIFORMITY_ANALYSIS_H_
diff --git a/src/tint/lang/wgsl/ast/transform/disable_uniformity_analysis_test.cc b/src/tint/lang/wgsl/ast/transform/disable_uniformity_analysis_test.cc
new file mode 100644
index 0000000..2791562
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/disable_uniformity_analysis_test.cc
@@ -0,0 +1,73 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/disable_uniformity_analysis.h"
+
+#include <string>
+#include <utility>
+
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using DisableUniformityAnalysisTest = TransformTest;
+
+TEST_F(DisableUniformityAnalysisTest, ShouldRunEmptyModule) {
+    auto* src = R"()";  // Empty module: the extension is not enabled.
+
+    EXPECT_TRUE(ShouldRun<DisableUniformityAnalysis>(src));
+}
+
+TEST_F(DisableUniformityAnalysisTest, ShouldRunExtensionAlreadyPresent) {
+    auto* src = R"(
+enable chromium_disable_uniformity_analysis;
+)";
+
+    EXPECT_FALSE(ShouldRun<DisableUniformityAnalysis>(src));  // Already enabled: skipped.
+}
+
+TEST_F(DisableUniformityAnalysisTest, EmptyModule) {
+    auto* src = R"()";  // Empty module.
+
+    auto* expect = R"(
+enable chromium_disable_uniformity_analysis;
+)";
+
+    auto got = Run<DisableUniformityAnalysis>(src);
+
+    EXPECT_EQ(expect, str(got));  // Output holds only the added enable directive.
+}
+
+TEST_F(DisableUniformityAnalysisTest, NonEmptyModule) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read> global : i32;
+
+@compute @workgroup_size(64)
+fn main() {
+ if ((global == 42)) {
+ workgroupBarrier();
+ }
+}
+)";
+
+ auto expect = "\nenable chromium_disable_uniformity_analysis;\n" + std::string(src);
+
+ auto got = Run<DisableUniformityAnalysis>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/expand_compound_assignment.cc b/src/tint/lang/wgsl/ast/transform/expand_compound_assignment.cc
new file mode 100644
index 0000000..b9e17a4
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/expand_compound_assignment.cc
@@ -0,0 +1,187 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/expand_compound_assignment.h"
+
+#include <utility>
+
+#include "src/tint/lang/wgsl/ast/compound_assignment_statement.h"
+#include "src/tint/lang/wgsl/ast/increment_decrement_statement.h"
+#include "src/tint/lang/wgsl/ast/transform/utils/hoist_to_decl_before.h"
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/block_statement.h"
+#include "src/tint/sem/for_loop_statement.h"
+#include "src/tint/sem/statement.h"
+#include "src/tint/sem/value_expression.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::ExpandCompoundAssignment);
+
+using namespace tint::number_suffixes; // NOLINT
+
+namespace tint::ast::transform {
+
+namespace {
+
+bool ShouldRun(const Program* program) {
+    for (auto* node : program->ASTNodes().Objects()) {
+        if (node->IsAnyOf<CompoundAssignmentStatement, IncrementDecrementStatement>()) {
+            return true;  // At least one statement needs expanding.
+        }
+    }
+    return false;  // Nothing for this transform to rewrite.
+}
+
+} // namespace
+
+/// PIMPL state for the transform
+struct ExpandCompoundAssignment::State {
+    /// Constructor
+    /// @param context the clone context
+    explicit State(CloneContext& context) : ctx(context), b(*ctx.dst), hoist_to_decl_before(ctx) {}
+
+    /// Replace `stmt` with a regular assignment statement of the form:
+    ///     lhs = lhs op rhs
+    /// The LHS expression will only be evaluated once, and any side effects will
+    /// be hoisted to `let` declarations above the assignment statement.
+    /// @param stmt the statement to replace
+    /// @param lhs the lhs expression from the source statement
+    /// @param rhs the rhs expression in the destination module
+    /// @param op the binary operator
+    void Expand(const Statement* stmt, const Expression* lhs, const Expression* rhs, BinaryOp op) {
+        // Helper function to create the new LHS expression. This will be called
+        // twice when building the non-compound assignment statement, so must
+        // not produce expressions that cause side effects.
+        std::function<const Expression*()> new_lhs;
+
+        // Helper function to hoist a `let` holding a pointer to `expr`; returns its symbol.
+        auto hoist_pointer_to = [&](const Expression* expr) {
+            auto name = b.Sym();
+            auto* ptr = b.AddressOf(ctx.Clone(expr));
+            auto* decl = b.Decl(b.Let(name, ptr));
+            hoist_to_decl_before.InsertBefore(ctx.src->Sem().Get(stmt), decl);
+            return name;
+        };
+
+        // Helper function to hoist `expr` to a let declaration; returns its symbol.
+        auto hoist_expr_to_let = [&](const Expression* expr) {
+            auto name = b.Sym();
+            auto* decl = b.Decl(b.Let(name, ctx.Clone(expr)));
+            hoist_to_decl_before.InsertBefore(ctx.src->Sem().Get(stmt), decl);
+            return name;
+        };
+
+        // Helper function that returns `true` if `expr` is a vector or a reference to a vector.
+        auto is_vec = [&](const Expression* expr) {
+            if (auto* val_expr = ctx.src->Sem().GetVal(expr)) {
+                return val_expr->Type()->UnwrapRef()->Is<type::Vector>();
+            }
+            return false;  // No value-expression semantic node for `expr`.
+        };
+
+        // Hoist the LHS expression subtree into local constants to produce a new
+        // LHS that we can evaluate twice.
+        // We need to special case compound assignments to vector components since
+        // we cannot take the address of a vector component.
+        auto* index_accessor = lhs->As<IndexAccessorExpression>();
+        auto* member_accessor = lhs->As<MemberAccessorExpression>();
+        if (lhs->Is<IdentifierExpression>() ||
+            (member_accessor && member_accessor->object->Is<IdentifierExpression>())) {
+            // This is the simple case with no side effects, so we can just use the
+            // original LHS expression directly.
+            // Before:
+            //     foo.bar += rhs;
+            // After:
+            //     foo.bar = foo.bar + rhs;
+            new_lhs = [&] { return ctx.Clone(lhs); };
+        } else if (index_accessor && is_vec(index_accessor->object)) {
+            // This is the case for vector component via an index accessor. We need
+            // to capture a pointer to the vector and also the index value.
+            // Before:
+            //     v[idx()] += rhs;
+            // After:
+            //     let vec_ptr = &v;
+            //     let index = idx();
+            //     (*vec_ptr)[index] = (*vec_ptr)[index] + rhs;
+            auto lhs_ptr = hoist_pointer_to(index_accessor->object);
+            auto index = hoist_expr_to_let(index_accessor->index);
+            new_lhs = [&, lhs_ptr, index] { return b.IndexAccessor(b.Deref(lhs_ptr), index); };
+        } else if (member_accessor && is_vec(member_accessor->object)) {
+            // This is the case for vector component via a member accessor. We just
+            // need to capture a pointer to the vector.
+            // Before:
+            //     a[idx()].y += rhs;
+            // After:
+            //     let vec_ptr = &a[idx()];
+            //     (*vec_ptr).y = (*vec_ptr).y + rhs;
+            auto lhs_ptr = hoist_pointer_to(member_accessor->object);
+            new_lhs = [&, lhs_ptr] {
+                return b.MemberAccessor(b.Deref(lhs_ptr), ctx.Clone(member_accessor->member));
+            };
+        } else {
+            // For all other statements that may have side-effecting expressions, we
+            // just need to capture a pointer to the whole LHS.
+            // Before:
+            //     a[idx()] += rhs;
+            // After:
+            //     let lhs_ptr = &a[idx()];
+            //     (*lhs_ptr) = (*lhs_ptr) + rhs;
+            auto lhs_ptr = hoist_pointer_to(lhs);  // Copied into new_lhs: it is called after this scope.
+            new_lhs = [&, lhs_ptr] { return b.Deref(lhs_ptr); };
+        }
+
+        // Replace the statement with a regular assignment statement.
+        auto* value = b.create<BinaryExpression>(op, new_lhs(), rhs);  // `lhs op rhs`
+        ctx.Replace(stmt, b.Assign(new_lhs(), value));  // Second new_lhs(): assignment target.
+    }
+
+  private:
+    /// The clone context.
+    CloneContext& ctx;
+
+    /// The program builder.
+    ProgramBuilder& b;
+
+    /// The HoistToDeclBefore helper instance.
+    HoistToDeclBefore hoist_to_decl_before;
+};
+
+ExpandCompoundAssignment::ExpandCompoundAssignment() = default;
+
+ExpandCompoundAssignment::~ExpandCompoundAssignment() = default;
+
+Transform::ApplyResult ExpandCompoundAssignment::Apply(const Program* src,
+                                                       const DataMap&,
+                                                       DataMap&) const {
+    if (!ShouldRun(src)) {
+        return SkipTransform;  // No compound assignment or inc/dec statements.
+    }
+
+    ProgramBuilder b;
+    CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+    State state(ctx);
+    for (auto* node : src->ASTNodes().Objects()) {
+        if (auto* assign = node->As<CompoundAssignmentStatement>()) {
+            state.Expand(assign, assign->lhs, ctx.Clone(assign->rhs), assign->op);
+        } else if (auto* inc_dec = node->As<IncrementDecrementStatement>()) {
+            // `i++` becomes `i = i + 1` and `i--` becomes `i = i - 1`.
+            auto op = inc_dec->increment ? BinaryOp::kAdd : BinaryOp::kSubtract;
+            state.Expand(inc_dec, inc_dec->lhs, ctx.dst->Expr(1_a), op);
+        }
+    }
+
+    ctx.Clone();  // Clone `src` into `b`, after Expand() has registered its replacements.
+    return Program(std::move(b));
+}
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/expand_compound_assignment.h b/src/tint/lang/wgsl/ast/transform/expand_compound_assignment.h
new file mode 100644
index 0000000..ff43727
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/expand_compound_assignment.h
@@ -0,0 +1,59 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_EXPAND_COMPOUND_ASSIGNMENT_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_EXPAND_COMPOUND_ASSIGNMENT_H_
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+
+namespace tint::ast::transform {
+
+/// Converts compound assignment statements to regular assignment statements,
+/// hoisting the LHS expression if necessary.
+///
+/// Before:
+/// ```
+///   a += 1;
+///   vector_array[foo()][bar()] *= 2.0;
+/// ```
+///
+/// After:
+/// ```
+///   a = a + 1;
+///   let _vec = &vector_array[foo()];
+///   let _idx = bar();
+///   (*_vec)[_idx] = (*_vec)[_idx] * 2.0;
+/// ```
+///
+/// This transform also handles increment and decrement statements in the same
+/// manner, by replacing `i++` with `i = i + 1` and `i--` with `i = i - 1`.
+class ExpandCompoundAssignment final : public utils::Castable<ExpandCompoundAssignment, Transform> {
+  public:
+    /// Constructor
+    ExpandCompoundAssignment();
+    /// Destructor
+    ~ExpandCompoundAssignment() override;
+
+    /// @copydoc Transform::Apply
+    ApplyResult Apply(const Program* program,
+                      const DataMap& inputs,
+                      DataMap& outputs) const override;
+
+  private:
+    struct State;
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_EXPAND_COMPOUND_ASSIGNMENT_H_
diff --git a/src/tint/lang/wgsl/ast/transform/expand_compound_assignment_test.cc b/src/tint/lang/wgsl/ast/transform/expand_compound_assignment_test.cc
new file mode 100644
index 0000000..0c2f880
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/expand_compound_assignment_test.cc
@@ -0,0 +1,747 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/expand_compound_assignment.h"
+
+#include <utility>
+
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using ExpandCompoundAssignmentTest = TransformTest;
+
+TEST_F(ExpandCompoundAssignmentTest, ShouldRunEmptyModule) {
+    auto* src = R"()";  // Empty module: no compound assignment or inc/dec statements.
+
+    EXPECT_FALSE(ShouldRun<ExpandCompoundAssignment>(src));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, ShouldRunHasCompoundAssignment) {
+ auto* src = R"(
+fn foo() {
+ var v : i32;
+ v += 1;
+}
+)";
+
+ EXPECT_TRUE(ShouldRun<ExpandCompoundAssignment>(src));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, ShouldRunHasIncrementDecrement) {
+ auto* src = R"(
+fn foo() {
+ var v : i32;
+ v++;
+}
+)";
+
+ EXPECT_TRUE(ShouldRun<ExpandCompoundAssignment>(src));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, Basic) {
+ auto* src = R"(
+fn main() {
+ var v : i32;
+ v += 1;
+}
+)";
+
+ auto* expect = R"(
+fn main() {
+ var v : i32;
+ v = (v + 1);
+}
+)";
+
+ auto got = Run<ExpandCompoundAssignment>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, LhsPointer) {
+ auto* src = R"(
+fn main() {
+ var v : i32;
+ let p = &v;
+ *p += 1;
+}
+)";
+
+ auto* expect = R"(
+fn main() {
+ var v : i32;
+ let p = &(v);
+ let tint_symbol = &(*(p));
+ *(tint_symbol) = (*(tint_symbol) + 1);
+}
+)";
+
+ auto got = Run<ExpandCompoundAssignment>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, LhsStructMember) {
+ auto* src = R"(
+struct S {
+ m : f32,
+}
+
+fn main() {
+ var s : S;
+ s.m += 1.0;
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ m : f32,
+}
+
+fn main() {
+ var s : S;
+ s.m = (s.m + 1.0);
+}
+)";
+
+ auto got = Run<ExpandCompoundAssignment>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, LhsArrayElement) {
+ auto* src = R"(
+var<private> a : array<i32, 4>;
+
+fn idx() -> i32 {
+ a[1] = 42;
+ return 1;
+}
+
+fn main() {
+ a[idx()] += 1;
+}
+)";
+
+ auto* expect = R"(
+var<private> a : array<i32, 4>;
+
+fn idx() -> i32 {
+ a[1] = 42;
+ return 1;
+}
+
+fn main() {
+ let tint_symbol = &(a[idx()]);
+ *(tint_symbol) = (*(tint_symbol) + 1);
+}
+)";
+
+ auto got = Run<ExpandCompoundAssignment>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, LhsVectorComponent_ArrayAccessor) {
+ auto* src = R"(
+var<private> v : vec4<i32>;
+
+fn idx() -> i32 {
+ v.y = 42;
+ return 1;
+}
+
+fn main() {
+ v[idx()] += 1;
+}
+)";
+
+ auto* expect = R"(
+var<private> v : vec4<i32>;
+
+fn idx() -> i32 {
+ v.y = 42;
+ return 1;
+}
+
+fn main() {
+ let tint_symbol = &(v);
+ let tint_symbol_1 = idx();
+ (*(tint_symbol))[tint_symbol_1] = ((*(tint_symbol))[tint_symbol_1] + 1);
+}
+)";
+
+ auto got = Run<ExpandCompoundAssignment>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, LhsVectorComponent_MemberAccessor) {
+ auto* src = R"(
+fn main() {
+ var v : vec4<i32>;
+ v.y += 1;
+}
+)";
+
+ auto* expect = R"(
+fn main() {
+ var v : vec4<i32>;
+ v.y = (v.y + 1);
+}
+)";
+
+ auto got = Run<ExpandCompoundAssignment>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, LhsMatrixColumn) {
+ auto* src = R"(
+var<private> m : mat4x4<f32>;
+
+fn idx() -> i32 {
+ m[0].y = 42.0;
+ return 1;
+}
+
+fn main() {
+ m[idx()] += 1.0;
+}
+)";
+
+ auto* expect = R"(
+var<private> m : mat4x4<f32>;
+
+fn idx() -> i32 {
+ m[0].y = 42.0;
+ return 1;
+}
+
+fn main() {
+ let tint_symbol = &(m[idx()]);
+ *(tint_symbol) = (*(tint_symbol) + 1.0);
+}
+)";
+
+ auto got = Run<ExpandCompoundAssignment>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, LhsMatrixElement) {
+ auto* src = R"(
+var<private> m : mat4x4<f32>;
+
+fn idx1() -> i32 {
+ m[0].y = 42.0;
+ return 1;
+}
+
+fn idx2() -> i32 {
+ m[1].z = 42.0;
+ return 1;
+}
+
+fn main() {
+ m[idx1()][idx2()] += 1.0;
+}
+)";
+
+ auto* expect = R"(
+var<private> m : mat4x4<f32>;
+
+fn idx1() -> i32 {
+ m[0].y = 42.0;
+ return 1;
+}
+
+fn idx2() -> i32 {
+ m[1].z = 42.0;
+ return 1;
+}
+
+fn main() {
+ let tint_symbol = &(m[idx1()]);
+ let tint_symbol_1 = idx2();
+ (*(tint_symbol))[tint_symbol_1] = ((*(tint_symbol))[tint_symbol_1] + 1.0);
+}
+)";
+
+ auto got = Run<ExpandCompoundAssignment>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, LhsMultipleSideEffects) {
+ auto* src = R"(
+struct S {
+ a : array<vec4<f32>, 3>,
+}
+
+@group(0) @binding(0) var<storage, read_write> buffer : array<S>;
+
+var<private> p : i32;
+
+fn idx1() -> i32 {
+ p += 1;
+ return 3;
+}
+
+fn idx2() -> i32 {
+ p *= 3;
+ return 2;
+}
+
+fn idx3() -> i32 {
+ p -= 2;
+ return 1;
+}
+
+fn main() {
+ buffer[idx1()].a[idx2()][idx3()] += 1.0;
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ a : array<vec4<f32>, 3>,
+}
+
+@group(0) @binding(0) var<storage, read_write> buffer : array<S>;
+
+var<private> p : i32;
+
+fn idx1() -> i32 {
+ p = (p + 1);
+ return 3;
+}
+
+fn idx2() -> i32 {
+ p = (p * 3);
+ return 2;
+}
+
+fn idx3() -> i32 {
+ p = (p - 2);
+ return 1;
+}
+
+fn main() {
+ let tint_symbol = &(buffer[idx1()].a[idx2()]);
+ let tint_symbol_1 = idx3();
+ (*(tint_symbol))[tint_symbol_1] = ((*(tint_symbol))[tint_symbol_1] + 1.0);
+}
+)";
+
+ auto got = Run<ExpandCompoundAssignment>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, ForLoopInit) {
+ auto* src = R"(
+var<private> a : array<vec4<i32>, 4>;
+
+var<private> p : i32;
+
+fn idx1() -> i32 {
+ p = (p + 1);
+ return 3;
+}
+
+fn idx2() -> i32 {
+ p = (p * 3);
+ return 2;
+}
+
+fn main() {
+ for (a[idx1()][idx2()] += 1; ; ) {
+ break;
+ }
+}
+)";
+
+ auto* expect = R"(
+var<private> a : array<vec4<i32>, 4>;
+
+var<private> p : i32;
+
+fn idx1() -> i32 {
+ p = (p + 1);
+ return 3;
+}
+
+fn idx2() -> i32 {
+ p = (p * 3);
+ return 2;
+}
+
+fn main() {
+ {
+ let tint_symbol = &(a[idx1()]);
+ let tint_symbol_1 = idx2();
+ (*(tint_symbol))[tint_symbol_1] = ((*(tint_symbol))[tint_symbol_1] + 1);
+ loop {
+ {
+ break;
+ }
+ }
+ }
+}
+)";
+
+ auto got = Run<ExpandCompoundAssignment>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, ForLoopCont) {
+ auto* src = R"(
+var<private> a : array<vec4<i32>, 4>;
+
+var<private> p : i32;
+
+fn idx1() -> i32 {
+ p = (p + 1);
+ return 3;
+}
+
+fn idx2() -> i32 {
+ p = (p * 3);
+ return 2;
+}
+
+fn main() {
+ for (; ; a[idx1()][idx2()] += 1) {
+ break;
+ }
+}
+)";
+
+ auto* expect = R"(
+var<private> a : array<vec4<i32>, 4>;
+
+var<private> p : i32;
+
+fn idx1() -> i32 {
+ p = (p + 1);
+ return 3;
+}
+
+fn idx2() -> i32 {
+ p = (p * 3);
+ return 2;
+}
+
+fn main() {
+ loop {
+ {
+ break;
+ }
+
+ continuing {
+ let tint_symbol = &(a[idx1()]);
+ let tint_symbol_1 = idx2();
+ (*(tint_symbol))[tint_symbol_1] = ((*(tint_symbol))[tint_symbol_1] + 1);
+ }
+ }
+}
+)";
+
+ auto got = Run<ExpandCompoundAssignment>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, Increment_I32) {
+ auto* src = R"(
+fn main() {
+ var v : i32;
+ v++;
+}
+)";
+
+ auto* expect = R"(
+fn main() {
+ var v : i32;
+ v = (v + 1);
+}
+)";
+
+ auto got = Run<ExpandCompoundAssignment>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, Increment_U32) {
+ auto* src = R"(
+fn main() {
+ var v : u32;
+ v++;
+}
+)";
+
+ auto* expect = R"(
+fn main() {
+ var v : u32;
+ v = (v + 1);
+}
+)";
+
+ auto got = Run<ExpandCompoundAssignment>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, Decrement_I32) {
+ auto* src = R"(
+fn main() {
+ var v : i32;
+ v--;
+}
+)";
+
+ auto* expect = R"(
+fn main() {
+ var v : i32;
+ v = (v - 1);
+}
+)";
+
+ auto got = Run<ExpandCompoundAssignment>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, Decrement_U32) {
+ auto* src = R"(
+fn main() {
+ var v : u32;
+ v--;
+}
+)";
+
+ auto* expect = R"(
+fn main() {
+ var v : u32;
+ v = (v - 1);
+}
+)";
+
+ auto got = Run<ExpandCompoundAssignment>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, Increment_LhsPointer) {
+ auto* src = R"(
+fn main() {
+ var v : i32;
+ let p = &v;
+ *p++;
+}
+)";
+
+ auto* expect = R"(
+fn main() {
+ var v : i32;
+ let p = &(v);
+ let tint_symbol = &(*(p));
+ *(tint_symbol) = (*(tint_symbol) + 1);
+}
+)";
+
+ auto got = Run<ExpandCompoundAssignment>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, Increment_LhsStructMember) {
+ auto* src = R"(
+struct S {
+ m : i32,
+}
+
+fn main() {
+ var s : S;
+ s.m++;
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ m : i32,
+}
+
+fn main() {
+ var s : S;
+ s.m = (s.m + 1);
+}
+)";
+
+ auto got = Run<ExpandCompoundAssignment>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, Increment_LhsArrayElement) {
+ auto* src = R"(
+var<private> a : array<i32, 4>;
+
+fn idx() -> i32 {
+ a[1] = 42;
+ return 1;
+}
+
+fn main() {
+ a[idx()]++;
+}
+)";
+
+ auto* expect = R"(
+var<private> a : array<i32, 4>;
+
+fn idx() -> i32 {
+ a[1] = 42;
+ return 1;
+}
+
+fn main() {
+ let tint_symbol = &(a[idx()]);
+ *(tint_symbol) = (*(tint_symbol) + 1);
+}
+)";
+
+ auto got = Run<ExpandCompoundAssignment>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, Increment_LhsVectorComponent_ArrayAccessor) {
+ auto* src = R"(
+var<private> v : vec4<i32>;
+
+fn idx() -> i32 {
+ v.y = 42;
+ return 1;
+}
+
+fn main() {
+ v[idx()]++;
+}
+)";
+
+ auto* expect = R"(
+var<private> v : vec4<i32>;
+
+fn idx() -> i32 {
+ v.y = 42;
+ return 1;
+}
+
+fn main() {
+ let tint_symbol = &(v);
+ let tint_symbol_1 = idx();
+ (*(tint_symbol))[tint_symbol_1] = ((*(tint_symbol))[tint_symbol_1] + 1);
+}
+)";
+
+ auto got = Run<ExpandCompoundAssignment>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, Increment_LhsVectorComponent_MemberAccessor) {
+ auto* src = R"(
+fn main() {
+ var v : vec4<i32>;
+ v.y++;
+}
+)";
+
+ auto* expect = R"(
+fn main() {
+ var v : vec4<i32>;
+ v.y = (v.y + 1);
+}
+)";
+
+ auto got = Run<ExpandCompoundAssignment>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, Increment_ForLoopCont) {
+ auto* src = R"(
+var<private> a : array<vec4<i32>, 4>;
+
+var<private> p : i32;
+
+fn idx1() -> i32 {
+ p = (p + 1);
+ return 3;
+}
+
+fn idx2() -> i32 {
+ p = (p * 3);
+ return 2;
+}
+
+fn main() {
+ for (; ; a[idx1()][idx2()]++) {
+ break;
+ }
+}
+)";
+
+ auto* expect = R"(
+var<private> a : array<vec4<i32>, 4>;
+
+var<private> p : i32;
+
+fn idx1() -> i32 {
+ p = (p + 1);
+ return 3;
+}
+
+fn idx2() -> i32 {
+ p = (p * 3);
+ return 2;
+}
+
+fn main() {
+ loop {
+ {
+ break;
+ }
+
+ continuing {
+ let tint_symbol = &(a[idx1()]);
+ let tint_symbol_1 = idx2();
+ (*(tint_symbol))[tint_symbol_1] = ((*(tint_symbol))[tint_symbol_1] + 1);
+ }
+ }
+}
+)";
+
+ auto got = Run<ExpandCompoundAssignment>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/first_index_offset.cc b/src/tint/lang/wgsl/ast/transform/first_index_offset.cc
new file mode 100644
index 0000000..c570a6d
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/first_index_offset.cc
@@ -0,0 +1,170 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/first_index_offset.h"
+
+#include <memory>
+#include <unordered_map>
+#include <utility>
+
+#include "src/tint/builtin/builtin_value.h"
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/function.h"
+#include "src/tint/sem/member_accessor_expression.h"
+#include "src/tint/sem/struct.h"
+#include "src/tint/sem/variable.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::FirstIndexOffset);
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::FirstIndexOffset::BindingPoint);
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::FirstIndexOffset::Data);
+
+namespace tint::ast::transform {
+namespace {
+
+// Uniform buffer member names
+constexpr char kFirstVertexName[] = "first_vertex_index";
+constexpr char kFirstInstanceName[] = "first_instance_index";
+
+bool ShouldRun(const Program* program) {  // True if any function has the vertex pipeline stage.
+ for (auto* fn : program->AST().Functions()) {
+ if (fn->PipelineStage() == PipelineStage::kVertex) {
+ return true;
+ }
+ }
+ return false;  // No vertex-stage function: Apply() skips the transform in this case.
+}
+
+} // namespace
+
+FirstIndexOffset::BindingPoint::BindingPoint() = default;  // binding and group stay at 0.
+FirstIndexOffset::BindingPoint::BindingPoint(uint32_t b, uint32_t g) : binding(b), group(g) {}
+FirstIndexOffset::BindingPoint::~BindingPoint() = default;
+
+FirstIndexOffset::Data::Data(bool has_vtx_index, bool has_inst_index)
+ : has_vertex_index(has_vtx_index), has_instance_index(has_inst_index) {}  // Stored as const.
+FirstIndexOffset::Data::Data(const Data&) = default;
+FirstIndexOffset::Data::~Data() = default;
+
+FirstIndexOffset::FirstIndexOffset() = default;
+FirstIndexOffset::~FirstIndexOffset() = default;
+
+Transform::ApplyResult FirstIndexOffset::Apply(const Program* src,
+ const DataMap& inputs,
+ DataMap& outputs) const {
+ if (!ShouldRun(src)) {
+ return SkipTransform;  // No vertex-stage function, so no Data is added to `outputs`.
+ }
+
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+
+ // Uniform buffer binding point: member defaults, overridden by a BindingPoint input if present
+ uint32_t ub_binding = binding_;
+ uint32_t ub_group = group_;
+ if (auto* binding_point = inputs.Get<BindingPoint>()) {
+ ub_binding = binding_point->binding;
+ ub_group = binding_point->group;
+ }
+
+ // Maps each builtin variable / struct member to the name of its offset member in the buffer
+ std::unordered_map<const sem::Variable*, const char*> builtin_vars;
+ std::unordered_map<const type::StructMember*, const char*> builtin_members;
+
+ bool has_vertex_index = false;
+ bool has_instance_index = false;
+
+ // Traverse the AST scanning for declarations of variables (includes parameters) and
+ // structure members that carry a vertex_index or instance_index builtin attribute.
+ for (auto* node : ctx.src->ASTNodes().Objects()) {
+ if (auto* var = node->As<Variable>()) {
+ for (auto* attr : var->attributes) {
+ if (auto* builtin_attr = attr->As<BuiltinAttribute>()) {
+ builtin::BuiltinValue builtin = src->Sem().Get(builtin_attr)->Value();
+ if (builtin == builtin::BuiltinValue::kVertexIndex) {
+ auto* sem_var = ctx.src->Sem().Get(var);
+ builtin_vars.emplace(sem_var, kFirstVertexName);
+ has_vertex_index = true;
+ }
+ if (builtin == builtin::BuiltinValue::kInstanceIndex) {
+ auto* sem_var = ctx.src->Sem().Get(var);
+ builtin_vars.emplace(sem_var, kFirstInstanceName);
+ has_instance_index = true;
+ }
+ }
+ }
+ }
+ if (auto* member = node->As<StructMember>()) {
+ for (auto* attr : member->attributes) {
+ if (auto* builtin_attr = attr->As<BuiltinAttribute>()) {
+ builtin::BuiltinValue builtin = src->Sem().Get(builtin_attr)->Value();
+ if (builtin == builtin::BuiltinValue::kVertexIndex) {
+ auto* sem_mem = ctx.src->Sem().Get(member);
+ builtin_members.emplace(sem_mem, kFirstVertexName);
+ has_vertex_index = true;
+ }
+ if (builtin == builtin::BuiltinValue::kInstanceIndex) {
+ auto* sem_mem = ctx.src->Sem().Get(member);
+ builtin_members.emplace(sem_mem, kFirstInstanceName);
+ has_instance_index = true;
+ }
+ }
+ }
+ }
+ }
+
+ if (has_vertex_index || has_instance_index) {
+ // Declare the uniform buffer struct; both u32 members are added even if only one is used
+ utils::Vector<const StructMember*, 8> members;
+ members.Push(b.Member(kFirstVertexName, b.ty.u32()));
+ members.Push(b.Member(kFirstInstanceName, b.ty.u32()));
+ auto* struct_ = b.Structure(b.Sym(), std::move(members));
+
+ // Create a global to hold the uniform buffer
+ Symbol buffer_name = b.Sym();
+ b.GlobalVar(buffer_name, b.ty.Of(struct_), builtin::AddressSpace::kUniform,
+ utils::Vector{
+ b.Binding(AInt(ub_binding)),
+ b.Group(AInt(ub_group)),
+ });
+
+ // Replace each read of a recorded builtin with `builtin + <buffer>.<offset member>`
+ ctx.ReplaceAll([=, &ctx](const Expression* expr) -> const Expression* {
+ if (auto* sem = ctx.src->Sem().GetVal(expr)) {
+ if (auto* user = sem->UnwrapLoad()->As<sem::VariableUser>()) {
+ auto it = builtin_vars.find(user->Variable());
+ if (it != builtin_vars.end()) {
+ return ctx.dst->Add(ctx.CloneWithoutTransform(expr),
+ ctx.dst->MemberAccessor(buffer_name, it->second));
+ }
+ }
+ if (auto* access = sem->As<sem::StructMemberAccess>()) {
+ auto it = builtin_members.find(access->Member());
+ if (it != builtin_members.end()) {
+ return ctx.dst->Add(ctx.CloneWithoutTransform(expr),
+ ctx.dst->MemberAccessor(buffer_name, it->second));
+ }
+ }
+ }
+ // Not interested in this expression. Just clone.
+ return nullptr;
+ });
+ }
+
+ outputs.Add<Data>(has_vertex_index, has_instance_index);  // Added even if both are false.
+
+ ctx.Clone();
+ return Program(std::move(b));
+}
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/first_index_offset.h b/src/tint/lang/wgsl/ast/transform/first_index_offset.h
new file mode 100644
index 0000000..68e5500
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/first_index_offset.h
@@ -0,0 +1,120 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_FIRST_INDEX_OFFSET_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_FIRST_INDEX_OFFSET_H_
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+
+namespace tint::ast::transform {
+
+/// Adds firstVertex/Instance (injected via root constants) to
+/// vertex/instance index builtins.
+///
+/// This transform assumes that Name transform has been run before.
+///
+/// Unlike other APIs, D3D always starts vertex and instance numbering at 0,
+/// regardless of the firstVertex/Instance value specified. This transformer
+/// adds the value of firstVertex/Instance to each builtin. This action is
+/// performed by adding a new constant equal to original builtin +
+/// firstVertex/Instance to each function that references one of these builtins.
+///
+/// Note that D3D does not have any semantics for firstVertex/Instance.
+/// Therefore, these values must be passed to the shader.
+///
+/// Before:
+/// ```
+/// @builtin(vertex_index) var<in> vert_idx : u32;
+/// fn func() -> u32 {
+/// return vert_idx;
+/// }
+/// ```
+///
+/// After:
+/// ```
+/// struct TintFirstIndexOffsetData {
+/// tint_first_vertex_index : u32;
+/// tint_first_instance_index : u32;
+/// };
+/// @builtin(vertex_index) var<in> tint_first_index_offset_vert_idx : u32;
+/// @binding(N) @group(M) var<uniform> tint_first_index_data :
+/// TintFirstIndexOffsetData;
+/// fn func() -> u32 {
+/// const vert_idx = (tint_first_index_offset_vert_idx +
+/// tint_first_index_data.tint_first_vertex_index);
+/// return vert_idx;
+/// }
+/// ```
+/// NOTE(review): text/example above look stale - Apply() rewrites builtin reads in place; confirm.
+class FirstIndexOffset final : public utils::Castable<FirstIndexOffset, Transform> {
+ public:
+ /// BindingPoint is consumed by the FirstIndexOffset transform.
+ /// BindingPoint specifies the binding point of the first index uniform
+ /// buffer.
+ struct BindingPoint final : public utils::Castable<BindingPoint, Transform::Data> {
+ /// Constructor
+ BindingPoint();
+
+ /// Constructor
+ /// @param b the binding index
+ /// @param g the binding group
+ BindingPoint(uint32_t b, uint32_t g);
+
+ /// Destructor
+ ~BindingPoint() override;
+
+ /// `@binding()` for the first vertex / first instance uniform buffer
+ uint32_t binding = 0;
+ /// `@group()` for the first vertex / first instance uniform buffer
+ uint32_t group = 0;
+ };
+
+ /// Data is outputted by the FirstIndexOffset transform.
+ /// Data records which of the vertex_index / instance_index builtins the shader uses.
+ struct Data final : public utils::Castable<Data, Transform::Data> {
+ /// Constructor
+ /// @param has_vtx_index True if the shader uses vertex_index
+ /// @param has_inst_index True if the shader uses instance_index
+ Data(bool has_vtx_index, bool has_inst_index);
+
+ /// Copy constructor
+ Data(const Data&);
+
+ /// Destructor
+ ~Data() override;
+
+ /// True if the shader uses vertex_index
+ const bool has_vertex_index;
+ /// True if the shader uses instance_index
+ const bool has_instance_index;
+ };
+
+ /// Constructor
+ FirstIndexOffset();
+ /// Destructor
+ ~FirstIndexOffset() override;
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+
+ private:
+ uint32_t binding_ = 0;  // Fallback `@binding()` when no BindingPoint input is supplied.
+ uint32_t group_ = 0;  // Fallback `@group()` when no BindingPoint input is supplied.
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_FIRST_INDEX_OFFSET_H_
diff --git a/src/tint/lang/wgsl/ast/transform/first_index_offset_test.cc b/src/tint/lang/wgsl/ast/transform/first_index_offset_test.cc
new file mode 100644
index 0000000..c186921
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/first_index_offset_test.cc
@@ -0,0 +1,632 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/first_index_offset.h"
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using FirstIndexOffsetTest = TransformTest;
+
+TEST_F(FirstIndexOffsetTest, ShouldRunEmptyModule) {
+ auto* src = R"()";
+
+ EXPECT_FALSE(ShouldRun<FirstIndexOffset>(src));
+}
+
+TEST_F(FirstIndexOffsetTest, ShouldRunFragmentStage) {
+ auto* src = R"(
+@fragment
+fn entry() {
+ return;
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<FirstIndexOffset>(src));
+}
+
+TEST_F(FirstIndexOffsetTest, ShouldRunVertexStage) {
+ auto* src = R"(
+@vertex
+fn entry() -> @builtin(position) vec4<f32> {
+ return vec4<f32>();
+}
+)";
+
+ EXPECT_TRUE(ShouldRun<FirstIndexOffset>(src));
+}
+
+TEST_F(FirstIndexOffsetTest, EmptyModule) {
+ auto* src = "";
+ auto* expect = "";
+
+ Transform::DataMap config;
+ config.Add<FirstIndexOffset::BindingPoint>(0, 0);
+ auto got = Run<FirstIndexOffset>(src, std::move(config));
+
+ EXPECT_EQ(expect, str(got));
+
+ auto* data = got.data.Get<FirstIndexOffset::Data>();
+
+ EXPECT_EQ(data, nullptr);  // No vertex stage: the transform is skipped and emits no Data.
+}
+
+TEST_F(FirstIndexOffsetTest, BasicVertexShader) {
+ auto* src = R"(
+@vertex
+fn entry() -> @builtin(position) vec4<f32> {
+ return vec4<f32>();
+}
+)";
+ auto* expect = src;  // No index builtins are declared, so the module must be unchanged.
+
+ Transform::DataMap config;
+ config.Add<FirstIndexOffset::BindingPoint>(0, 0);
+ auto got = Run<FirstIndexOffset>(src, std::move(config));
+
+ EXPECT_EQ(expect, str(got));
+
+ auto* data = got.data.Get<FirstIndexOffset::Data>();
+
+ ASSERT_NE(data, nullptr);  // A vertex stage exists, so the transform ran and emitted Data.
+ EXPECT_EQ(data->has_vertex_index, false);
+ EXPECT_EQ(data->has_instance_index, false);
+}
+
+TEST_F(FirstIndexOffsetTest, BasicModuleVertexIndex) {
+ auto* src = R"(
+fn test(vert_idx : u32) -> u32 {
+ return vert_idx;
+}
+
+@vertex
+fn entry(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
+ test(vert_idx);
+ return vec4<f32>();
+}
+)";
+
+ auto* expect = R"(
+struct tint_symbol {
+ first_vertex_index : u32,
+ first_instance_index : u32,
+}
+
+@binding(1) @group(2) var<uniform> tint_symbol_1 : tint_symbol;
+
+fn test(vert_idx : u32) -> u32 {
+ return vert_idx;
+}
+
+@vertex
+fn entry(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
+ test((vert_idx + tint_symbol_1.first_vertex_index));
+ return vec4<f32>();
+}
+)";
+
+ Transform::DataMap config;
+ config.Add<FirstIndexOffset::BindingPoint>(1, 2);
+ auto got = Run<FirstIndexOffset>(src, std::move(config));
+
+ EXPECT_EQ(expect, str(got));
+
+ auto* data = got.data.Get<FirstIndexOffset::Data>();
+
+ ASSERT_NE(data, nullptr);
+ EXPECT_EQ(data->has_vertex_index, true);
+ EXPECT_EQ(data->has_instance_index, false);
+}
+
+TEST_F(FirstIndexOffsetTest, BasicModuleVertexIndex_OutOfOrder) {
+ auto* src = R"(
+@vertex
+fn entry(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
+ test(vert_idx);
+ return vec4<f32>();
+}
+
+fn test(vert_idx : u32) -> u32 {
+ return vert_idx;
+}
+)";
+
+ auto* expect = R"(
+struct tint_symbol {
+ first_vertex_index : u32,
+ first_instance_index : u32,
+}
+
+@binding(1) @group(2) var<uniform> tint_symbol_1 : tint_symbol;
+
+@vertex
+fn entry(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
+ test((vert_idx + tint_symbol_1.first_vertex_index));
+ return vec4<f32>();
+}
+
+fn test(vert_idx : u32) -> u32 {
+ return vert_idx;
+}
+)";
+
+ Transform::DataMap config;
+ config.Add<FirstIndexOffset::BindingPoint>(1, 2);
+ auto got = Run<FirstIndexOffset>(src, std::move(config));
+
+ EXPECT_EQ(expect, str(got));
+
+ auto* data = got.data.Get<FirstIndexOffset::Data>();
+
+ ASSERT_NE(data, nullptr);
+ EXPECT_EQ(data->has_vertex_index, true);
+ EXPECT_EQ(data->has_instance_index, false);
+}
+
+TEST_F(FirstIndexOffsetTest, BasicModuleInstanceIndex) {
+ auto* src = R"(
+fn test(inst_idx : u32) -> u32 {
+ return inst_idx;
+}
+
+@vertex
+fn entry(@builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
+ test(inst_idx);
+ return vec4<f32>();
+}
+)";
+
+ auto* expect = R"(
+struct tint_symbol {
+ first_vertex_index : u32,
+ first_instance_index : u32,
+}
+
+@binding(1) @group(7) var<uniform> tint_symbol_1 : tint_symbol;
+
+fn test(inst_idx : u32) -> u32 {
+ return inst_idx;
+}
+
+@vertex
+fn entry(@builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
+ test((inst_idx + tint_symbol_1.first_instance_index));
+ return vec4<f32>();
+}
+)";
+
+ Transform::DataMap config;
+ config.Add<FirstIndexOffset::BindingPoint>(1, 7);
+ auto got = Run<FirstIndexOffset>(src, std::move(config));
+
+ EXPECT_EQ(expect, str(got));
+
+ auto* data = got.data.Get<FirstIndexOffset::Data>();
+
+ ASSERT_NE(data, nullptr);
+ EXPECT_EQ(data->has_vertex_index, false);
+ EXPECT_EQ(data->has_instance_index, true);
+}
+
+TEST_F(FirstIndexOffsetTest, BasicModuleInstanceIndex_OutOfOrder) {
+ auto* src = R"(
+@vertex
+fn entry(@builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
+ test(inst_idx);
+ return vec4<f32>();
+}
+
+fn test(inst_idx : u32) -> u32 {
+ return inst_idx;
+}
+)";
+
+ auto* expect = R"(
+struct tint_symbol {
+ first_vertex_index : u32,
+ first_instance_index : u32,
+}
+
+@binding(1) @group(7) var<uniform> tint_symbol_1 : tint_symbol;
+
+@vertex
+fn entry(@builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
+ test((inst_idx + tint_symbol_1.first_instance_index));
+ return vec4<f32>();
+}
+
+fn test(inst_idx : u32) -> u32 {
+ return inst_idx;
+}
+)";
+
+ Transform::DataMap config;
+ config.Add<FirstIndexOffset::BindingPoint>(1, 7);
+ auto got = Run<FirstIndexOffset>(src, std::move(config));
+
+ EXPECT_EQ(expect, str(got));
+
+ auto* data = got.data.Get<FirstIndexOffset::Data>();
+
+ ASSERT_NE(data, nullptr);
+ EXPECT_EQ(data->has_vertex_index, false);
+ EXPECT_EQ(data->has_instance_index, true);
+}
+
+TEST_F(FirstIndexOffsetTest, BasicModuleBothIndex) {
+ auto* src = R"(
+fn test(instance_idx : u32, vert_idx : u32) -> u32 {
+ return instance_idx + vert_idx;
+}
+
+struct Inputs {
+ @builtin(instance_index) instance_idx : u32,
+ @builtin(vertex_index) vert_idx : u32,
+};
+
+@vertex
+fn entry(inputs : Inputs) -> @builtin(position) vec4<f32> {
+ test(inputs.instance_idx, inputs.vert_idx);
+ return vec4<f32>();
+}
+)";
+
+ auto* expect = R"(
+struct tint_symbol {
+ first_vertex_index : u32,
+ first_instance_index : u32,
+}
+
+@binding(1) @group(2) var<uniform> tint_symbol_1 : tint_symbol;
+
+fn test(instance_idx : u32, vert_idx : u32) -> u32 {
+ return (instance_idx + vert_idx);
+}
+
+struct Inputs {
+ @builtin(instance_index)
+ instance_idx : u32,
+ @builtin(vertex_index)
+ vert_idx : u32,
+}
+
+@vertex
+fn entry(inputs : Inputs) -> @builtin(position) vec4<f32> {
+ test((inputs.instance_idx + tint_symbol_1.first_instance_index), (inputs.vert_idx + tint_symbol_1.first_vertex_index));
+ return vec4<f32>();
+}
+)";
+
+ Transform::DataMap config;
+ config.Add<FirstIndexOffset::BindingPoint>(1, 2);
+ auto got = Run<FirstIndexOffset>(src, std::move(config));
+
+ EXPECT_EQ(expect, str(got));
+
+ auto* data = got.data.Get<FirstIndexOffset::Data>();
+
+ ASSERT_NE(data, nullptr);
+ EXPECT_EQ(data->has_vertex_index, true);
+ EXPECT_EQ(data->has_instance_index, true);
+}
+
+TEST_F(FirstIndexOffsetTest, BasicModuleBothIndex_OutOfOrder) {
+ auto* src = R"(
+@vertex
+fn entry(inputs : Inputs) -> @builtin(position) vec4<f32> {
+ test(inputs.instance_idx, inputs.vert_idx);
+ return vec4<f32>();
+}
+
+struct Inputs {
+ @builtin(instance_index) instance_idx : u32,
+ @builtin(vertex_index) vert_idx : u32,
+};
+
+fn test(instance_idx : u32, vert_idx : u32) -> u32 {
+ return instance_idx + vert_idx;
+}
+)";
+
+ auto* expect = R"(
+struct tint_symbol {
+ first_vertex_index : u32,
+ first_instance_index : u32,
+}
+
+@binding(1) @group(2) var<uniform> tint_symbol_1 : tint_symbol;
+
+@vertex
+fn entry(inputs : Inputs) -> @builtin(position) vec4<f32> {
+ test((inputs.instance_idx + tint_symbol_1.first_instance_index), (inputs.vert_idx + tint_symbol_1.first_vertex_index));
+ return vec4<f32>();
+}
+
+struct Inputs {
+ @builtin(instance_index)
+ instance_idx : u32,
+ @builtin(vertex_index)
+ vert_idx : u32,
+}
+
+fn test(instance_idx : u32, vert_idx : u32) -> u32 {
+ return (instance_idx + vert_idx);
+}
+)";
+
+ Transform::DataMap config;
+ config.Add<FirstIndexOffset::BindingPoint>(1, 2);
+ auto got = Run<FirstIndexOffset>(src, std::move(config));
+
+ EXPECT_EQ(expect, str(got));
+
+ auto* data = got.data.Get<FirstIndexOffset::Data>();
+
+ ASSERT_NE(data, nullptr);
+ EXPECT_EQ(data->has_vertex_index, true);
+ EXPECT_EQ(data->has_instance_index, true);
+}
+
+TEST_F(FirstIndexOffsetTest, NestedCalls) {
+ auto* src = R"(
+fn func1(vert_idx : u32) -> u32 {
+ return vert_idx;
+}
+
+fn func2(vert_idx : u32) -> u32 {
+ return func1(vert_idx);
+}
+
+@vertex
+fn entry(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
+ func2(vert_idx);
+ return vec4<f32>();
+}
+)";
+
+ auto* expect = R"(
+struct tint_symbol {
+ first_vertex_index : u32,
+ first_instance_index : u32,
+}
+
+@binding(1) @group(2) var<uniform> tint_symbol_1 : tint_symbol;
+
+fn func1(vert_idx : u32) -> u32 {
+ return vert_idx;
+}
+
+fn func2(vert_idx : u32) -> u32 {
+ return func1(vert_idx);
+}
+
+@vertex
+fn entry(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
+ func2((vert_idx + tint_symbol_1.first_vertex_index));
+ return vec4<f32>();
+}
+)";
+
+ Transform::DataMap config;
+ config.Add<FirstIndexOffset::BindingPoint>(1, 2);
+ auto got = Run<FirstIndexOffset>(src, std::move(config));
+
+ EXPECT_EQ(expect, str(got));
+
+ auto* data = got.data.Get<FirstIndexOffset::Data>();
+
+ ASSERT_NE(data, nullptr);
+ EXPECT_EQ(data->has_vertex_index, true);
+ EXPECT_EQ(data->has_instance_index, false);
+}
+
+TEST_F(FirstIndexOffsetTest, NestedCalls_OutOfOrder) {
+ auto* src = R"(
+@vertex
+fn entry(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
+ func2(vert_idx);
+ return vec4<f32>();
+}
+
+fn func2(vert_idx : u32) -> u32 {
+ return func1(vert_idx);
+}
+
+fn func1(vert_idx : u32) -> u32 {
+ return vert_idx;
+}
+)";
+
+ auto* expect = R"(
+struct tint_symbol {
+ first_vertex_index : u32,
+ first_instance_index : u32,
+}
+
+@binding(1) @group(2) var<uniform> tint_symbol_1 : tint_symbol;
+
+@vertex
+fn entry(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
+ func2((vert_idx + tint_symbol_1.first_vertex_index));
+ return vec4<f32>();
+}
+
+fn func2(vert_idx : u32) -> u32 {
+ return func1(vert_idx);
+}
+
+fn func1(vert_idx : u32) -> u32 {
+ return vert_idx;
+}
+)";
+
+ Transform::DataMap config;
+ config.Add<FirstIndexOffset::BindingPoint>(1, 2);
+ auto got = Run<FirstIndexOffset>(src, std::move(config));
+
+ EXPECT_EQ(expect, str(got));
+
+ auto* data = got.data.Get<FirstIndexOffset::Data>();
+
+ ASSERT_NE(data, nullptr);
+ EXPECT_EQ(data->has_vertex_index, true);
+ EXPECT_EQ(data->has_instance_index, false);
+}
+
+TEST_F(FirstIndexOffsetTest, MultipleEntryPoints) {
+ auto* src = R"(
+fn func(i : u32) -> u32 {
+ return i;
+}
+
+@vertex
+fn entry_a(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
+ func(vert_idx);
+ return vec4<f32>();
+}
+
+@vertex
+fn entry_b(@builtin(vertex_index) vert_idx : u32, @builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
+ func(vert_idx + inst_idx);
+ return vec4<f32>();
+}
+
+@vertex
+fn entry_c(@builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
+ func(inst_idx);
+ return vec4<f32>();
+}
+)";
+
+ auto* expect = R"(
+struct tint_symbol {
+ first_vertex_index : u32,
+ first_instance_index : u32,
+}
+
+@binding(1) @group(2) var<uniform> tint_symbol_1 : tint_symbol;
+
+fn func(i : u32) -> u32 {
+ return i;
+}
+
+@vertex
+fn entry_a(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
+ func((vert_idx + tint_symbol_1.first_vertex_index));
+ return vec4<f32>();
+}
+
+@vertex
+fn entry_b(@builtin(vertex_index) vert_idx : u32, @builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
+ func(((vert_idx + tint_symbol_1.first_vertex_index) + (inst_idx + tint_symbol_1.first_instance_index)));
+ return vec4<f32>();
+}
+
+@vertex
+fn entry_c(@builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
+ func((inst_idx + tint_symbol_1.first_instance_index));
+ return vec4<f32>();
+}
+)";
+
+ Transform::DataMap config;
+ config.Add<FirstIndexOffset::BindingPoint>(1, 2);
+ auto got = Run<FirstIndexOffset>(src, std::move(config));
+
+ EXPECT_EQ(expect, str(got));
+
+ auto* data = got.data.Get<FirstIndexOffset::Data>();
+
+ ASSERT_NE(data, nullptr);
+ EXPECT_EQ(data->has_vertex_index, true);
+ EXPECT_EQ(data->has_instance_index, true);
+}
+
+TEST_F(FirstIndexOffsetTest, MultipleEntryPoints_OutOfOrder) {
+ auto* src = R"(
+@vertex
+fn entry_a(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
+ func(vert_idx);
+ return vec4<f32>();
+}
+
+@vertex
+fn entry_b(@builtin(vertex_index) vert_idx : u32, @builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
+ func(vert_idx + inst_idx);
+ return vec4<f32>();
+}
+
+@vertex
+fn entry_c(@builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
+ func(inst_idx);
+ return vec4<f32>();
+}
+
+fn func(i : u32) -> u32 {
+ return i;
+}
+)";
+
+ auto* expect = R"(
+struct tint_symbol {
+ first_vertex_index : u32,
+ first_instance_index : u32,
+}
+
+@binding(1) @group(2) var<uniform> tint_symbol_1 : tint_symbol;
+
+@vertex
+fn entry_a(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
+ func((vert_idx + tint_symbol_1.first_vertex_index));
+ return vec4<f32>();
+}
+
+@vertex
+fn entry_b(@builtin(vertex_index) vert_idx : u32, @builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
+ func(((vert_idx + tint_symbol_1.first_vertex_index) + (inst_idx + tint_symbol_1.first_instance_index)));
+ return vec4<f32>();
+}
+
+@vertex
+fn entry_c(@builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
+ func((inst_idx + tint_symbol_1.first_instance_index));
+ return vec4<f32>();
+}
+
+fn func(i : u32) -> u32 {
+ return i;
+}
+)";
+
+ Transform::DataMap config;
+ config.Add<FirstIndexOffset::BindingPoint>(1, 2);
+ auto got = Run<FirstIndexOffset>(src, std::move(config));
+
+ EXPECT_EQ(expect, str(got));
+
+ auto* data = got.data.Get<FirstIndexOffset::Data>();
+
+ ASSERT_NE(data, nullptr);
+ EXPECT_EQ(data->has_vertex_index, true);
+ EXPECT_EQ(data->has_instance_index, true);
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/fold_trivial_lets.cc b/src/tint/lang/wgsl/ast/transform/fold_trivial_lets.cc
new file mode 100644
index 0000000..dbaac2a
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/fold_trivial_lets.cc
@@ -0,0 +1,157 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/fold_trivial_lets.h"
+
+#include <utility>
+
+#include "src/tint/lang/wgsl/ast/traverse_expressions.h"
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/value_expression.h"
+#include "src/tint/utils/hashmap.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::FoldTrivialLets);
+
+namespace tint::ast::transform {
+
+/// PIMPL state for the transform. One instance is built per Apply() call.
+struct FoldTrivialLets::State {
+ /// The source program (input, never modified)
+ const Program* const src;
+ /// The target program builder that receives the cloned output
+ ProgramBuilder b;
+ /// The clone context that copies `src` into `b`
+ CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
+ /// The semantic info of the source program.
+ const sem::Info& sem = {ctx.src->Sem()};
+
+ /// Constructor
+ /// @param program the source program
+ explicit State(const Program* program) : src(program) {}
+
+ /// Scans the statements directly inside a block and folds eligible lets into what follows.
+ /// @param block the block whose statements are scanned
+ void ProcessBlock(const BlockStatement* block) {
+ // PendingLet describes a let declaration that might be inlined.
+ struct PendingLet {
+ // The let declaration.
+ const VariableDeclStatement* decl = nullptr;
+ // The number of uses that have not yet been inlined.
+ size_t remaining_uses = 0;
+ };
+
+ // A map from semantic variables to their PendingLet descriptors.
+ utils::Hashmap<const sem::Variable*, PendingLet, 16> pending_lets;
+
+ // Helper that replaces each use of a pending let found in `expr` with the let's initializer.
+ auto fold_lets = [&](const Expression* expr) {
+ TraverseExpressions(expr, b.Diagnostics(), [&](const IdentifierExpression* ident) {
+ if (auto* user = sem.Get<sem::VariableUser>(ident)) {
+ auto itr = pending_lets.Find(user->Variable());
+ if (itr) {
+ TINT_ASSERT(Transform, itr->remaining_uses > 0);
+
+ // We found a reference to a pending let, so replace it with the inlined
+ // initializer expression.
+ ctx.Replace(ident, ctx.Clone(itr->decl->variable->initializer));
+
+ // Decrement the remaining uses count and remove the let declaration if this
+ // was the last remaining use.
+ if (--itr->remaining_uses == 0) {
+ ctx.Remove(block->statements, itr->decl);
+ }
+ }
+ }
+ return TraverseAction::Descend;
+ });
+ };
+
+ // Loop over all statements in the block.
+ for (auto* stmt : block->statements) {
+ // Check for a let declaration.
+ if (auto* decl = stmt->As<VariableDeclStatement>()) {
+ if (auto* let = decl->variable->As<Let>()) {
+ // If the initializer doesn't have side effects, we might be able to inline it.
+ if (!sem.GetVal(let->initializer)->HasSideEffects()) { //
+ auto num_users = sem.Get(let)->Users().Length();
+ if (let->initializer->Is<IdentifierExpression>()) {
+ // The initializer is a single identifier expression.
+ // We can fold it into multiple uses in the next non-let statement.
+ // We also fold previous pending lets into this one, but only if
+ // it's only used once (to avoid duplicating potentially complex
+ // expressions).
+ if (num_users == 1) {
+ fold_lets(let->initializer);
+ }
+ pending_lets.Add(sem.Get(let), PendingLet{decl, num_users});
+ } else {
+ // The initializer is something more complex, so we only want to inline
+ // it if it's only used once.
+ // We also fold previous pending lets into this one.
+ fold_lets(let->initializer);
+ if (num_users == 1) {
+ pending_lets.Add(sem.Get(let), PendingLet{decl, 1});
+ }
+ }
+ continue;  // Skip the Clear() below so pending lets survive a run of lets.
+ }
+ }
+ }
+
+ // Fold pending let declarations into a select few places that are frequently generated
+ // by the SPIR-V reader.
+ if (auto* assign = stmt->As<AssignmentStatement>()) {
+ // We can fold into the RHS of an assignment statement if the RHS and LHS
+ // expressions have no side effects.
+ if (!sem.GetVal(assign->lhs)->HasSideEffects() &&
+ !sem.GetVal(assign->rhs)->HasSideEffects()) {
+ fold_lets(assign->rhs);
+ }
+ } else if (auto* ifelse = stmt->As<IfStatement>()) {
+ // We can fold into the condition of an if statement if the condition expression has
+ // no side effects.
+ if (!sem.GetVal(ifelse->condition)->HasSideEffects()) {
+ fold_lets(ifelse->condition);
+ }
+ }
+
+ // Clear any remaining pending lets.
+ // We do not try to fold lets beyond the first non-let statement.
+ pending_lets.Clear();
+ }
+ }
+
+ /// Runs the transform over every block statement of the source program.
+ /// @returns the new program
+ ApplyResult Run() {
+ // Process all blocks in the module.
+ for (auto* node : src->ASTNodes().Objects()) {
+ if (auto* block = node->As<BlockStatement>()) {
+ ProcessBlock(block);
+ }
+ }
+ ctx.Clone();  // Clone into `b`; the Replace()/Remove() calls registered above apply here.
+ return Program(std::move(b));
+ }
+};
+
+FoldTrivialLets::FoldTrivialLets() = default;
+
+FoldTrivialLets::~FoldTrivialLets() = default;
+
+Transform::ApplyResult FoldTrivialLets::Apply(const Program* src, const DataMap&, DataMap&) const {
+ return State(src).Run();  // All work happens in the PIMPL State; both data maps are unused.
+}
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/fold_trivial_lets.h b/src/tint/lang/wgsl/ast/transform/fold_trivial_lets.h
new file mode 100644
index 0000000..2ce8f12
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/fold_trivial_lets.h
@@ -0,0 +1,44 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_FOLD_TRIVIAL_LETS_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_FOLD_TRIVIAL_LETS_H_
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+
+namespace tint::ast::transform {
+
+/// FoldTrivialLets is a transform that inlines the initializer of a let declaration when that
+/// initializer is just an identifier expression, or when the let is only used once. This is used
+/// to clean up unnecessary let declarations created by the SPIR-V reader.
+class FoldTrivialLets final : public utils::Castable<FoldTrivialLets, Transform> {
+ public:
+ /// Constructor
+ FoldTrivialLets();
+
+ /// Destructor
+ ~FoldTrivialLets() override;
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+
+ private:
+ struct State;  // PIMPL state; only declared here.
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_FOLD_TRIVIAL_LETS_H_
diff --git a/src/tint/lang/wgsl/ast/transform/fold_trivial_lets_test.cc b/src/tint/lang/wgsl/ast/transform/fold_trivial_lets_test.cc
new file mode 100644
index 0000000..9d9357f
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/fold_trivial_lets_test.cc
@@ -0,0 +1,286 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/fold_trivial_lets.h"
+
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using FoldTrivialLetsTest = TransformTest;
+
+TEST_F(FoldTrivialLetsTest, Fold_IdentInitializer_AssignRHS) {
+ auto* src = R"(
+fn f() {
+ var v = 42;
+ let x = v;
+ v = (x + 1);
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ var v = 42;
+ v = (v + 1);
+}
+)";
+
+ auto got = Run<FoldTrivialLets>(src);  // `let x = v` is folded into the assignment RHS.
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(FoldTrivialLetsTest, Fold_IdentInitializer_IfCondition) {
+ auto* src = R"(
+fn f() {
+ var v = 42;
+ let x = v;
+ if (x > 0) {
+ v = 0;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ var v = 42;
+ if ((v > 0)) {
+ v = 0;
+ }
+}
+)";
+
+ auto got = Run<FoldTrivialLets>(src);  // `let x = v` is folded into the if condition.
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(FoldTrivialLetsTest, NoFold_IdentInitializer_StoreBeforeUse) {
+ auto* src = R"(
+fn f() {
+ var v = 42;
+ let x = v;
+ v = 0;
+ v = (x + 1);
+}
+)";
+
+ auto* expect = src;  // `v = 0` sits between the let and its use, so nothing is folded.
+
+ auto got = Run<FoldTrivialLets>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(FoldTrivialLetsTest, NoFold_IdentInitializer_SideEffectsInUseExpression) {
+ auto* src = R"(
+var<private> v = 42;
+
+fn g() -> i32 {
+ v = 0;
+ return 1;
+}
+
+fn f() {
+ let x = v;
+ v = (g() + x);
+}
+)";
+
+ auto* expect = src;  // The RHS calls g(), which writes `v`, so nothing is folded.
+
+ auto got = Run<FoldTrivialLets>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(FoldTrivialLetsTest, Fold_IdentInitializer_MultiUse) {
+ auto* src = R"(
+fn f() {
+ var v = 42;
+ let x = v;
+ v = (x + x);
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ var v = 42;
+ v = (v + v);
+}
+)";
+
+ auto got = Run<FoldTrivialLets>(src);  // Both uses of `x` are replaced by `v`.
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(FoldTrivialLetsTest, Fold_IdentInitializer_MultiUse_OnlySomeInlineable) {
+ auto* src = R"(
+fn f() {
+ var v = 42;
+ let x = v;
+ v = (x + x);
+ if (x > 0) {
+ v = 0;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ var v = 42;
+ let x = v;
+ v = (v + v);
+ if ((x > 0)) {
+ v = 0;
+ }
+}
+)";
+
+ auto got = Run<FoldTrivialLets>(src);  // The later if still uses `x`, so the let is kept.
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(FoldTrivialLetsTest, Fold_ComplexInitializer_SingleUse) {
+ auto* src = R"(
+fn f(idx : i32) {
+ var v = array<vec4i, 4>();
+ let x = v[idx].y;
+ v[0].x = (x + 1);
+}
+)";
+
+ auto* expect = R"(
+fn f(idx : i32) {
+ var v = array<vec4i, 4>();
+ v[0].x = (v[idx].y + 1);
+}
+)";
+
+ auto got = Run<FoldTrivialLets>(src);  // Single-use `x` is replaced by `v[idx].y`.
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(FoldTrivialLetsTest, NoFold_ComplexInitializer_SingleUseWithSideEffects) {
+ auto* src = R"(
+var<private> i = 0;
+
+fn bar() -> i32 {
+ i++;
+ return i;
+}
+
+fn f() -> i32 {
+ var v = array<vec4i, 4>();
+ let x = v[bar()].y;
+ return (i + x);
+}
+)";
+
+ auto* expect = src;  // The initializer calls bar(), which mutates `i`, so nothing is folded.
+
+ auto got = Run<FoldTrivialLets>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(FoldTrivialLetsTest, NoFold_ComplexInitializer_MultiUse) {
+ auto* src = R"(
+fn f(idx : i32) {
+ var v = array<vec4i, 4>();
+ let x = v[idx].y;
+ v[0].x = (x + x);
+}
+)";
+
+ auto* expect = src;  // A complex initializer that is used twice is not duplicated.
+
+ auto got = Run<FoldTrivialLets>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(FoldTrivialLetsTest, Fold_ComplexInitializer_SingleUseViaSimpleLet) {
+ auto* src = R"(
+fn f(a : i32, b : i32, c : i32) -> i32 {
+ let x = ((a * b) + c);
+ let y = x;
+ return y;
+}
+)";
+
+ auto* expect = R"(
+fn f(a : i32, b : i32, c : i32) -> i32 {
+ let y = ((a * b) + c);
+ return y;
+}
+)";
+
+ auto got = Run<FoldTrivialLets>(src);  // `x` is folded into the initializer of `y`.
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(FoldTrivialLetsTest, Fold_ComplexInitializer_SingleUseViaSimpleLetUsedTwice) {
+ auto* src = R"(
+fn f(a : i32, b : i32, c : i32) -> i32 {
+ let x = (a * b) + c;
+ let y = x;
+ let z = y + y;
+ return z;
+}
+)";
+
+ auto* expect = R"(
+fn f(a : i32, b : i32, c : i32) -> i32 {
+ let x = ((a * b) + c);
+ let z = (x + x);
+ return z;
+}
+)";
+
+ auto got = Run<FoldTrivialLets>(src);  // `y` is folded into `z`; `x` is kept as `y` has two uses.
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(FoldTrivialLetsTest, Fold_ComplexInitializer_MultiUseUseDifferentLets) {
+ auto* src = R"(
+fn f(a : i32, b : i32, c : i32) -> i32 {
+ let x = (a * b) + c;
+ let y = x;
+ let z = x + y;
+ return z;
+}
+)";
+
+ auto* expect = R"(
+fn f(a : i32, b : i32, c : i32) -> i32 {
+ let x = ((a * b) + c);
+ let z = (x + x);
+ return z;
+}
+)";
+
+ auto got = Run<FoldTrivialLets>(src);  // `y` is folded into `z`; `x` has two uses and is kept.
+
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/for_loop_to_loop.cc b/src/tint/lang/wgsl/ast/transform/for_loop_to_loop.cc
new file mode 100644
index 0000000..830703f
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/for_loop_to_loop.cc
@@ -0,0 +1,85 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/for_loop_to_loop.h"
+
+#include <utility>
+
+#include "src/tint/lang/wgsl/ast/break_statement.h"
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::ForLoopToLoop);
+
+namespace tint::ast::transform {
+namespace {
+
+bool ShouldRun(const Program* program) {  // True if the program has any for-loop statement.
+ for (auto* node : program->ASTNodes().Objects()) {
+ if (node->Is<ForLoopStatement>()) {
+ return true;
+ }
+ }
+ return false;
+}
+
+} // namespace
+
+ForLoopToLoop::ForLoopToLoop() = default;
+
+ForLoopToLoop::~ForLoopToLoop() = default;
+
+Transform::ApplyResult ForLoopToLoop::Apply(const Program* src, const DataMap&, DataMap&) const {
+ if (!ShouldRun(src)) {
+ return SkipTransform;  // No for-loop statements to convert.
+ }
+
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+
+ ctx.ReplaceAll([&](const ForLoopStatement* for_loop) -> const Statement* {
+ utils::Vector<const Statement*, 8> stmts;  // Statements of the new loop body.
+ if (auto* cond = for_loop->condition) {
+ // !condition
+ auto* not_cond = b.Not(ctx.Clone(cond));
+
+ // { break; }
+ auto* break_body = b.Block(b.Break());
+
+ // if (!condition) { break; }
+ stmts.Push(b.If(not_cond, break_body));
+ }
+ for (auto* stmt : for_loop->body->statements) {  // Then the original body, in order.
+ stmts.Push(ctx.Clone(stmt));
+ }
+
+ const BlockStatement* continuing = nullptr;  // Stays null if there is no continuing statement.
+ if (auto* cont = for_loop->continuing) {
+ continuing = b.Block(ctx.Clone(cont));
+ }
+
+ auto* body = b.Block(stmts);
+ auto* loop = b.Loop(body, continuing);
+
+ if (auto* init = for_loop->initializer) {
+ return b.Block(ctx.Clone(init), loop);  // Wrap the initializer and the loop in a new block.
+ }
+
+ return loop;
+ });
+
+ ctx.Clone();
+ return Program(std::move(b));
+}
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/for_loop_to_loop.h b/src/tint/lang/wgsl/ast/transform/for_loop_to_loop.h
new file mode 100644
index 0000000..8d2ff51
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/for_loop_to_loop.h
@@ -0,0 +1,40 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_FOR_LOOP_TO_LOOP_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_FOR_LOOP_TO_LOOP_H_
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+
+namespace tint::ast::transform {
+
+/// ForLoopToLoop is a Transform that converts each for-loop statement into an equivalent loop
+/// statement. This is required by the SPIR-V writer.
+class ForLoopToLoop final : public utils::Castable<ForLoopToLoop, Transform> {
+ public:
+ /// Constructor
+ ForLoopToLoop();
+
+ /// Destructor
+ ~ForLoopToLoop() override;
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_FOR_LOOP_TO_LOOP_H_
diff --git a/src/tint/lang/wgsl/ast/transform/for_loop_to_loop_test.cc b/src/tint/lang/wgsl/ast/transform/for_loop_to_loop_test.cc
new file mode 100644
index 0000000..2beadb3
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/for_loop_to_loop_test.cc
@@ -0,0 +1,372 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/for_loop_to_loop.h"
+
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using ForLoopToLoopTest = TransformTest;
+
+TEST_F(ForLoopToLoopTest, ShouldRunEmptyModule) {
+ auto* src = R"()";
+
+ EXPECT_FALSE(ShouldRun<ForLoopToLoop>(src));  // An empty module contains no for-loop.
+}
+
+TEST_F(ForLoopToLoopTest, ShouldRunHasForLoop) {
+ auto* src = R"(
+fn f() {
+ for (;;) {
+ break;
+ }
+}
+)";
+
+ EXPECT_TRUE(ShouldRun<ForLoopToLoop>(src));  // A single for-loop is enough to run.
+}
+
+TEST_F(ForLoopToLoopTest, EmptyModule) {
+ auto* src = "";
+ auto* expect = src;  // An empty module must come out unchanged.
+
+ auto got = Run<ForLoopToLoop>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+// Test a for loop with no initializer, condition or continuing statement.
+TEST_F(ForLoopToLoopTest, Empty) {
+ auto* src = R"(
+fn f() {
+ for (;;) {
+ break;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ loop {
+ break;
+ }
+}
+)";
+
+ auto got = Run<ForLoopToLoop>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+// Test that the body of a for loop (here a return statement) is carried over into the loop.
+TEST_F(ForLoopToLoopTest, Body) {
+ auto* src = R"(
+fn f() {
+ for (;;) {
+ return;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ loop {
+ return;
+ }
+}
+)";
+
+ auto got = Run<ForLoopToLoop>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+// Test a for loop declaring a variable in the initializer statement (wrapped in a new block).
+TEST_F(ForLoopToLoopTest, InitializerStatementDecl) {
+ auto* src = R"(
+fn f() {
+ for (var i: i32;;) {
+ break;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ {
+ var i : i32;
+ loop {
+ break;
+ }
+ }
+}
+)";
+
+ auto got = Run<ForLoopToLoop>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+// Test a for loop that both declares and initializes a variable in the initializer
+// statement.
+TEST_F(ForLoopToLoopTest, InitializerStatementDeclEqual) {
+ auto* src = R"(
+fn f() {
+ for (var i: i32 = 0;;) {
+ break;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ {
+ var i : i32 = 0;
+ loop {
+ break;
+ }
+ }
+}
+)";
+
+ auto got = Run<ForLoopToLoop>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+// Test a for loop declaring a `let` (immutable value) in the initializer statement.
+TEST_F(ForLoopToLoopTest, InitializerStatementConstDecl) {
+ auto* src = R"(
+fn f() {
+ for (let i: i32 = 0;;) {
+ break;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ {
+ let i : i32 = 0;
+ loop {
+ break;
+ }
+ }
+}
+)";
+
+ auto got = Run<ForLoopToLoop>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+// Test a for loop assigning to an already-declared variable in the initializer statement.
+TEST_F(ForLoopToLoopTest, InitializerStatementAssignment) {
+ auto* src = R"(
+fn f() {
+ var i: i32;
+ for (i = 0;;) {
+ break;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ var i : i32;
+ {
+ i = 0;
+ loop {
+ break;
+ }
+ }
+}
+)";
+
+ auto got = Run<ForLoopToLoop>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+// Test a for loop whose initializer statement is a function call.
+TEST_F(ForLoopToLoopTest, InitializerStatementFuncCall) {
+ auto* src = R"(
+fn a(x : i32, y : i32) {
+}
+
+fn f() {
+ var b : i32;
+ var c : i32;
+ for (a(b,c);;) {
+ break;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn a(x : i32, y : i32) {
+}
+
+fn f() {
+ var b : i32;
+ var c : i32;
+ {
+ a(b, c);
+ loop {
+ break;
+ }
+ }
+}
+)";
+
+ auto got = Run<ForLoopToLoop>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+// Test a for loop with a condition, which becomes `if (!(condition)) { break; }` in the loop.
+TEST_F(ForLoopToLoopTest, BreakCondition) {
+ auto* src = R"(
+fn f() {
+ for (; 0 == 1;) {
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ loop {
+ if (!((0 == 1))) {
+ break;
+ }
+ }
+}
+)";
+
+ auto got = Run<ForLoopToLoop>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+// Test a for loop whose continuing statement is an assignment (moved into a continuing block).
+TEST_F(ForLoopToLoopTest, ContinuingAssignment) {
+ auto* src = R"(
+fn f() {
+ var x: i32;
+ for (;;x = 2) {
+ break;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ var x : i32;
+ loop {
+ break;
+
+ continuing {
+ x = 2;
+ }
+ }
+}
+)";
+
+ auto got = Run<ForLoopToLoop>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+// Test a for loop whose continuing statement is a function call.
+TEST_F(ForLoopToLoopTest, ContinuingFuncCall) {
+ auto* src = R"(
+fn a(x : i32, y : i32) {
+}
+
+fn f() {
+ var b : i32;
+ var c : i32;
+ for (;;a(b,c)) {
+ break;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn a(x : i32, y : i32) {
+}
+
+fn f() {
+ var b : i32;
+ var c : i32;
+ loop {
+ break;
+
+ continuing {
+ a(b, c);
+ }
+ }
+}
+)";
+
+ auto got = Run<ForLoopToLoop>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+// Test a for loop with an initializer, a condition, a continuing statement and a non-empty body.
+TEST_F(ForLoopToLoopTest, All) {
+ auto* src = R"(
+fn f() {
+ var a : i32;
+ for(var i : i32 = 0; i < 4; i = i + 1) {
+ if (a == 0) {
+ continue;
+ }
+ a = a + 2;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ var a : i32;
+ {
+ var i : i32 = 0;
+ loop {
+ if (!((i < 4))) {
+ break;
+ }
+ if ((a == 0)) {
+ continue;
+ }
+ a = (a + 2);
+
+ continuing {
+ i = (i + 1);
+ }
+ }
+ }
+}
+)";
+
+ auto got = Run<ForLoopToLoop>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/localize_struct_array_assignment.cc b/src/tint/lang/wgsl/ast/transform/localize_struct_array_assignment.cc
new file mode 100644
index 0000000..ccb6095
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/localize_struct_array_assignment.cc
@@ -0,0 +1,223 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/localize_struct_array_assignment.h"
+
+#include <unordered_map>
+#include <utility>
+
+#include "src/tint/lang/wgsl/ast/assignment_statement.h"
+#include "src/tint/lang/wgsl/ast/transform/simplify_pointers.h"
+#include "src/tint/lang/wgsl/ast/traverse_expressions.h"
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/member_accessor_expression.h"
+#include "src/tint/sem/statement.h"
+#include "src/tint/sem/value_expression.h"
+#include "src/tint/sem/variable.h"
+#include "src/tint/type/reference.h"
+#include "src/tint/utils/scoped_assignment.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::LocalizeStructArrayAssignment);
+
+namespace tint::ast::transform {
+
+/// PIMPL state for the transform. One instance is built per Apply() call.
+struct LocalizeStructArrayAssignment::State {
+ /// Constructor
+ /// @param program the source program
+ explicit State(const Program* program) : src(program) {}
+
+ /// Runs the transform
+ /// @returns the new program or SkipTransform if the transform is not required
+ ApplyResult Run() {
+ struct Shared {
+ bool process_nested_nodes = false;  // Set to true only while cloning a matched LHS.
+ utils::Vector<const Statement*, 4> insert_before_stmts;  // Emitted before the assignment.
+ utils::Vector<const Statement*, 4> insert_after_stmts;  // Emitted after the assignment.
+ } s;
+
+ bool made_changes = false;
+
+ for (auto* node : ctx.src->ASTNodes().Objects()) {
+ if (auto* assign_stmt = node->As<AssignmentStatement>()) {
+ // Process if it's an assignment statement to a dynamically indexed array
+ // within a struct in a function or private address space variable. This
+ // specific use-case is what FXC fails to compile with:
+ // error X3500: array reference cannot be used as an l-value; not natively
+ // addressable
+ if (!ContainsStructArrayIndex(assign_stmt->lhs)) {
+ continue;
+ }
+ auto og = GetOriginatingTypeAndAddressSpace(assign_stmt);
+ if (!(og.first->Is<type::Struct>() &&
+ (og.second == builtin::AddressSpace::kFunction ||
+ og.second == builtin::AddressSpace::kPrivate))) {
+ continue;
+ }
+
+ ctx.Replace(assign_stmt, [&, assign_stmt] {
+ // Reset shared state for this assignment statement
+ s = Shared{};
+
+ const Expression* new_lhs = nullptr;
+ {
+ TINT_SCOPED_ASSIGNMENT(s.process_nested_nodes, true);
+ new_lhs = ctx.Clone(assign_stmt->lhs);
+ }
+
+ auto* new_assign_stmt = b.Assign(new_lhs, ctx.Clone(assign_stmt->rhs));
+
+ // Combine insert_before_stmts + new_assign_stmt + insert_after_stmts into
+ // a block and return it
+ auto stmts = std::move(s.insert_before_stmts);
+ stmts.Reserve(1 + s.insert_after_stmts.Length());
+ stmts.Push(new_assign_stmt);
+ for (auto* stmt : s.insert_after_stmts) {
+ stmts.Push(stmt);
+ }
+
+ return b.Block(std::move(stmts));
+ });
+
+ made_changes = true;
+ }
+ }
+
+ if (!made_changes) {
+ return SkipTransform;
+ }
+
+ ctx.ReplaceAll([&](const IndexAccessorExpression* index_access) -> const Expression* {
+ if (!s.process_nested_nodes) {
+ return nullptr;  // Not cloning a matched LHS: leave this expression to the default clone.
+ }
+
+ // Indexing a member access expr?
+ auto* mem_access = index_access->object->As<MemberAccessorExpression>();
+ if (!mem_access) {
+ return nullptr;
+ }
+
+ // Process any nested IndexAccessorExpressions
+ mem_access = ctx.Clone(mem_access);
+
+ // Store the address of the member access into a let as we need to read
+ // the value twice e.g. let tint_symbol = &(s.a1);
+ auto mem_access_ptr = b.Sym();
+ s.insert_before_stmts.Push(b.Decl(b.Let(mem_access_ptr, b.AddressOf(mem_access))));
+
+ // Disable further transforms when cloning
+ TINT_SCOPED_ASSIGNMENT(s.process_nested_nodes, false);
+
+ // Copy entire array out of struct into local temp var
+ // e.g. var tint_symbol_1 = *(tint_symbol);
+ auto tmp_var = b.Sym();
+ s.insert_before_stmts.Push(b.Decl(b.Var(tmp_var, b.Deref(mem_access_ptr))));
+
+ // Replace input index_access with a clone of itself, but with its
+ // .object replaced by the new temp var. This is returned from this
+ // function to modify the original assignment statement. e.g.
+ // tint_symbol_1[uniforms.i]
+ auto* new_index_access = b.IndexAccessor(tmp_var, ctx.Clone(index_access->index));
+
+ // Assign temp var back to the array; prepended so write-backs run in reverse copy order.
+ // e.g. *(tint_symbol) = tint_symbol_1;
+ auto* assign_rhs_to_temp = b.Assign(b.Deref(mem_access_ptr), tmp_var);
+ {
+ utils::Vector<const Statement*, 8> stmts{assign_rhs_to_temp};
+ for (auto* stmt : s.insert_after_stmts) {
+ stmts.Push(stmt);
+ }
+ s.insert_after_stmts = std::move(stmts);
+ }
+
+ return new_index_access;
+ });
+
+ ctx.Clone();
+ return Program(std::move(b));
+ }
+
+ private:
+ /// The source program
+ const Program* const src;
+ /// The target program builder
+ ProgramBuilder b;
+ /// The clone context that copies `src` into `b`
+ CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
+
+ /// Returns true if `expr` contains an index accessor expression that indexes a
+ /// structure member of array type with a non-constant index.
+ bool ContainsStructArrayIndex(const Expression* expr) {
+ bool result = false;
+ TraverseExpressions(expr, b.Diagnostics(), [&](const IndexAccessorExpression* ia) {
+ // Indexing using a runtime value?
+ auto* idx_sem = src->Sem().GetVal(ia->index);
+ if (!idx_sem->ConstantValue()) {
+ // Indexing a member access expr?
+ if (auto* ma = ia->object->As<MemberAccessorExpression>()) {
+ // That accesses an array?
+ if (src->TypeOf(ma)->UnwrapRef()->Is<type::Array>()) {
+ result = true;
+ return TraverseAction::Stop;
+ }
+ }
+ }
+ return TraverseAction::Descend;
+ });
+
+ return result;
+ }
+
+ // Returns the store type and address space of the originating variable of the lhs
+ // of the assignment statement. Raises an ICE if they cannot be determined.
+ // See https://www.w3.org/TR/WGSL/#originating-variable-section
+ std::pair<const type::Type*, builtin::AddressSpace> GetOriginatingTypeAndAddressSpace(
+ const AssignmentStatement* assign_stmt) {
+ auto* root_ident = src->Sem().GetVal(assign_stmt->lhs)->RootIdentifier();
+ if (TINT_UNLIKELY(!root_ident)) {
+ TINT_ICE(Transform, b.Diagnostics())
+ << "Unable to determine originating variable for lhs of assignment "
+ "statement";
+ return {};  // NOTE(review): Run() dereferences .first unchecked - confirm the ICE aborts.
+ }
+
+ return Switch(
+ root_ident->Type(), //
+ [&](const type::Reference* ref) {
+ return std::make_pair(ref->StoreType(), ref->AddressSpace());
+ },
+ [&](const type::Pointer* ptr) {
+ return std::make_pair(ptr->StoreType(), ptr->AddressSpace());
+ },
+ [&](Default) {
+ TINT_ICE(Transform, b.Diagnostics())
+ << "Expecting to find variable of type pointer or reference on lhs "
+ "of assignment statement";
+ return std::pair<const type::Type*, builtin::AddressSpace>{};
+ });
+ }
+};
+
+LocalizeStructArrayAssignment::LocalizeStructArrayAssignment() = default;
+
+LocalizeStructArrayAssignment::~LocalizeStructArrayAssignment() = default;
+
+Transform::ApplyResult LocalizeStructArrayAssignment::Apply(const Program* src,
+ const DataMap&,
+ DataMap&) const {
+ return State{src}.Run();  // All work happens in the PIMPL State; both data maps are unused.
+}
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/localize_struct_array_assignment.h b/src/tint/lang/wgsl/ast/transform/localize_struct_array_assignment.h
new file mode 100644
index 0000000..0b62934
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/localize_struct_array_assignment.h
@@ -0,0 +1,50 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_LOCALIZE_STRUCT_ARRAY_ASSIGNMENT_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_LOCALIZE_STRUCT_ARRAY_ASSIGNMENT_H_
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+
+namespace tint::ast::transform {
+
+/// This transform replaces assignments to dynamically-indexed fixed-size arrays
+/// in structs on shader-local variables with code that copies the array to a
+/// temporary local variable, assigns to the local variable, and copies the
+/// array back. This is to work around FXC's compilation failure for these cases
+/// (see crbug.com/tint/1206).
+///
+/// @note Depends on the following transforms to have been run first:
+/// * SimplifyPointers
+class LocalizeStructArrayAssignment final
+ : public utils::Castable<LocalizeStructArrayAssignment, Transform> {
+ public:
+ /// Constructor
+ LocalizeStructArrayAssignment();
+
+ /// Destructor
+ ~LocalizeStructArrayAssignment() override;
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+
+ private:
+ struct State;  // PIMPL state; only declared here.
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_LOCALIZE_STRUCT_ARRAY_ASSIGNMENT_H_
diff --git a/src/tint/lang/wgsl/ast/transform/localize_struct_array_assignment_test.cc b/src/tint/lang/wgsl/ast/transform/localize_struct_array_assignment_test.cc
new file mode 100644
index 0000000..268d85d
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/localize_struct_array_assignment_test.cc
@@ -0,0 +1,847 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/localize_struct_array_assignment.h"
+#include "src/tint/lang/wgsl/ast/transform/simplify_pointers.h"
+#include "src/tint/lang/wgsl/ast/transform/unshadow.h"
+
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using LocalizeStructArrayAssignmentTest = TransformTest;
+
+TEST_F(LocalizeStructArrayAssignmentTest, EmptyModule) {
+ auto* src = R"()";
+ EXPECT_FALSE(ShouldRun<LocalizeStructArrayAssignment>(src));
+}
+
+TEST_F(LocalizeStructArrayAssignmentTest, StructArray) {
+ auto* src = R"(
+struct Uniforms {
+ i : u32,
+};
+
+struct InnerS {
+ v : i32,
+};
+
+struct OuterS {
+ a1 : array<InnerS, 8>,
+};
+
+@group(1) @binding(4) var<uniform> uniforms : Uniforms;
+
+@compute @workgroup_size(1)
+fn main() {
+ var v : InnerS;
+ var s1 : OuterS;
+ s1.a1[uniforms.i] = v;
+}
+)";
+
+ auto* expect = R"(
+struct Uniforms {
+ i : u32,
+}
+
+struct InnerS {
+ v : i32,
+}
+
+struct OuterS {
+ a1 : array<InnerS, 8>,
+}
+
+@group(1) @binding(4) var<uniform> uniforms : Uniforms;
+
+@compute @workgroup_size(1)
+fn main() {
+ var v : InnerS;
+ var s1 : OuterS;
+ {
+ let tint_symbol = &(s1.a1);
+ var tint_symbol_1 = *(tint_symbol);
+ tint_symbol_1[uniforms.i] = v;
+ *(tint_symbol) = tint_symbol_1;
+ }
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(LocalizeStructArrayAssignmentTest, StructArray_OutOfOrder) {
+ auto* src = R"(
+@compute @workgroup_size(1)
+fn main() {
+ var v : InnerS;
+ var s1 : OuterS;
+ s1.a1[uniforms.i] = v;
+}
+
+@group(1) @binding(4) var<uniform> uniforms : Uniforms;
+
+struct OuterS {
+ a1 : array<InnerS, 8>,
+};
+
+struct InnerS {
+ v : i32,
+};
+
+struct Uniforms {
+ i : u32,
+};
+)";
+
+ auto* expect = R"(
+@compute @workgroup_size(1)
+fn main() {
+ var v : InnerS;
+ var s1 : OuterS;
+ {
+ let tint_symbol = &(s1.a1);
+ var tint_symbol_1 = *(tint_symbol);
+ tint_symbol_1[uniforms.i] = v;
+ *(tint_symbol) = tint_symbol_1;
+ }
+}
+
+@group(1) @binding(4) var<uniform> uniforms : Uniforms;
+
+struct OuterS {
+ a1 : array<InnerS, 8>,
+}
+
+struct InnerS {
+ v : i32,
+}
+
+struct Uniforms {
+ i : u32,
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(LocalizeStructArrayAssignmentTest, StructStructArray) {
+ auto* src = R"(
+struct Uniforms {
+ i : u32,
+};
+
+struct InnerS {
+ v : i32,
+};
+
+struct S1 {
+ a : array<InnerS, 8>,
+};
+
+struct OuterS {
+ s2 : S1,
+};
+
+@group(1) @binding(4) var<uniform> uniforms : Uniforms;
+
+@compute @workgroup_size(1)
+fn main() {
+ var v : InnerS;
+ var s1 : OuterS;
+ s1.s2.a[uniforms.i] = v;
+}
+)";
+
+ auto* expect = R"(
+struct Uniforms {
+ i : u32,
+}
+
+struct InnerS {
+ v : i32,
+}
+
+struct S1 {
+ a : array<InnerS, 8>,
+}
+
+struct OuterS {
+ s2 : S1,
+}
+
+@group(1) @binding(4) var<uniform> uniforms : Uniforms;
+
+@compute @workgroup_size(1)
+fn main() {
+ var v : InnerS;
+ var s1 : OuterS;
+ {
+ let tint_symbol = &(s1.s2.a);
+ var tint_symbol_1 = *(tint_symbol);
+ tint_symbol_1[uniforms.i] = v;
+ *(tint_symbol) = tint_symbol_1;
+ }
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(LocalizeStructArrayAssignmentTest, StructStructArray_OutOfOrder) {
+ auto* src = R"(
+@compute @workgroup_size(1)
+fn main() {
+ var v : InnerS;
+ var s1 : OuterS;
+ s1.s2.a[uniforms.i] = v;
+}
+
+@group(1) @binding(4) var<uniform> uniforms : Uniforms;
+
+struct OuterS {
+ s2 : S1,
+};
+
+struct S1 {
+ a : array<InnerS, 8>,
+};
+
+struct InnerS {
+ v : i32,
+};
+
+struct Uniforms {
+ i : u32,
+};
+)";
+
+ auto* expect = R"(
+@compute @workgroup_size(1)
+fn main() {
+ var v : InnerS;
+ var s1 : OuterS;
+ {
+ let tint_symbol = &(s1.s2.a);
+ var tint_symbol_1 = *(tint_symbol);
+ tint_symbol_1[uniforms.i] = v;
+ *(tint_symbol) = tint_symbol_1;
+ }
+}
+
+@group(1) @binding(4) var<uniform> uniforms : Uniforms;
+
+struct OuterS {
+ s2 : S1,
+}
+
+struct S1 {
+ a : array<InnerS, 8>,
+}
+
+struct InnerS {
+ v : i32,
+}
+
+struct Uniforms {
+ i : u32,
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(LocalizeStructArrayAssignmentTest, StructArrayArray) {
+ auto* src = R"(
+struct Uniforms {
+ i : u32,
+ j : u32,
+};
+
+struct InnerS {
+ v : i32,
+};
+
+struct OuterS {
+ a1 : array<array<InnerS, 8>, 8>,
+};
+
+@group(1) @binding(4) var<uniform> uniforms : Uniforms;
+
+@compute @workgroup_size(1)
+fn main() {
+ var v : InnerS;
+ var s1 : OuterS;
+ s1.a1[uniforms.i][uniforms.j] = v;
+}
+)";
+
+ auto* expect = R"(
+struct Uniforms {
+ i : u32,
+ j : u32,
+}
+
+struct InnerS {
+ v : i32,
+}
+
+struct OuterS {
+ a1 : array<array<InnerS, 8>, 8>,
+}
+
+@group(1) @binding(4) var<uniform> uniforms : Uniforms;
+
+@compute @workgroup_size(1)
+fn main() {
+ var v : InnerS;
+ var s1 : OuterS;
+ {
+ let tint_symbol = &(s1.a1);
+ var tint_symbol_1 = *(tint_symbol);
+ tint_symbol_1[uniforms.i][uniforms.j] = v;
+ *(tint_symbol) = tint_symbol_1;
+ }
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(LocalizeStructArrayAssignmentTest, StructArrayStruct) {
+ auto* src = R"(
+struct Uniforms {
+ i : u32,
+};
+
+struct InnerS {
+ v : i32,
+};
+
+struct S1 {
+ s2 : InnerS,
+};
+
+struct OuterS {
+ a1 : array<S1, 8>,
+};
+
+@group(1) @binding(4) var<uniform> uniforms : Uniforms;
+
+@compute @workgroup_size(1)
+fn main() {
+ var v : InnerS;
+ var s1 : OuterS;
+ s1.a1[uniforms.i].s2 = v;
+}
+)";
+
+ auto* expect = R"(
+struct Uniforms {
+ i : u32,
+}
+
+struct InnerS {
+ v : i32,
+}
+
+struct S1 {
+ s2 : InnerS,
+}
+
+struct OuterS {
+ a1 : array<S1, 8>,
+}
+
+@group(1) @binding(4) var<uniform> uniforms : Uniforms;
+
+@compute @workgroup_size(1)
+fn main() {
+ var v : InnerS;
+ var s1 : OuterS;
+ {
+ let tint_symbol = &(s1.a1);
+ var tint_symbol_1 = *(tint_symbol);
+ tint_symbol_1[uniforms.i].s2 = v;
+ *(tint_symbol) = tint_symbol_1;
+ }
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(LocalizeStructArrayAssignmentTest, StructArrayStructArray) {
+ auto* src = R"(
+struct Uniforms {
+ i : u32,
+ j : u32,
+};
+
+struct InnerS {
+ v : i32,
+};
+
+struct S1 {
+ a2 : array<InnerS, 8>,
+};
+
+struct OuterS {
+ a1 : array<S1, 8>,
+};
+
+@group(1) @binding(4) var<uniform> uniforms : Uniforms;
+
+@compute @workgroup_size(1)
+fn main() {
+ var v : InnerS;
+ var s : OuterS;
+ s.a1[uniforms.i].a2[uniforms.j] = v;
+}
+)";
+
+ auto* expect = R"(
+struct Uniforms {
+ i : u32,
+ j : u32,
+}
+
+struct InnerS {
+ v : i32,
+}
+
+struct S1 {
+ a2 : array<InnerS, 8>,
+}
+
+struct OuterS {
+ a1 : array<S1, 8>,
+}
+
+@group(1) @binding(4) var<uniform> uniforms : Uniforms;
+
+@compute @workgroup_size(1)
+fn main() {
+ var v : InnerS;
+ var s : OuterS;
+ {
+ let tint_symbol = &(s.a1);
+ var tint_symbol_1 = *(tint_symbol);
+ let tint_symbol_2 = &(tint_symbol_1[uniforms.i].a2);
+ var tint_symbol_3 = *(tint_symbol_2);
+ tint_symbol_3[uniforms.j] = v;
+ *(tint_symbol_2) = tint_symbol_3;
+ *(tint_symbol) = tint_symbol_1;
+ }
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(LocalizeStructArrayAssignmentTest, IndexingWithSideEffectFunc) {
+ auto* src = R"(
+struct Uniforms {
+ i : u32,
+ j : u32,
+};
+
+struct InnerS {
+ v : i32,
+};
+
+struct S1 {
+ a2 : array<InnerS, 8>,
+};
+
+struct OuterS {
+ a1 : array<S1, 8>,
+};
+
+var<private> nextIndex : u32;
+fn getNextIndex() -> u32 {
+ nextIndex = nextIndex + 1u;
+ return nextIndex;
+}
+
+@group(1) @binding(4) var<uniform> uniforms : Uniforms;
+
+@compute @workgroup_size(1)
+fn main() {
+ var v : InnerS;
+ var s : OuterS;
+ s.a1[getNextIndex()].a2[uniforms.j] = v;
+}
+)";
+
+ auto* expect = R"(
+struct Uniforms {
+ i : u32,
+ j : u32,
+}
+
+struct InnerS {
+ v : i32,
+}
+
+struct S1 {
+ a2 : array<InnerS, 8>,
+}
+
+struct OuterS {
+ a1 : array<S1, 8>,
+}
+
+var<private> nextIndex : u32;
+
+fn getNextIndex() -> u32 {
+ nextIndex = (nextIndex + 1u);
+ return nextIndex;
+}
+
+@group(1) @binding(4) var<uniform> uniforms : Uniforms;
+
+@compute @workgroup_size(1)
+fn main() {
+ var v : InnerS;
+ var s : OuterS;
+ {
+ let tint_symbol = &(s.a1);
+ var tint_symbol_1 = *(tint_symbol);
+ let tint_symbol_2 = &(tint_symbol_1[getNextIndex()].a2);
+ var tint_symbol_3 = *(tint_symbol_2);
+ tint_symbol_3[uniforms.j] = v;
+ *(tint_symbol_2) = tint_symbol_3;
+ *(tint_symbol) = tint_symbol_1;
+ }
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(LocalizeStructArrayAssignmentTest, IndexingWithSideEffectFunc_OutOfOrder) {
+ auto* src = R"(
+@compute @workgroup_size(1)
+fn main() {
+ var v : InnerS;
+ var s : OuterS;
+ s.a1[getNextIndex()].a2[uniforms.j] = v;
+}
+
+@group(1) @binding(4) var<uniform> uniforms : Uniforms;
+
+struct Uniforms {
+ i : u32,
+ j : u32,
+};
+
+var<private> nextIndex : u32;
+fn getNextIndex() -> u32 {
+ nextIndex = nextIndex + 1u;
+ return nextIndex;
+}
+
+struct OuterS {
+ a1 : array<S1, 8>,
+};
+
+struct S1 {
+ a2 : array<InnerS, 8>,
+};
+
+struct InnerS {
+ v : i32,
+};
+)";
+
+ auto* expect = R"(
+@compute @workgroup_size(1)
+fn main() {
+ var v : InnerS;
+ var s : OuterS;
+ {
+ let tint_symbol = &(s.a1);
+ var tint_symbol_1 = *(tint_symbol);
+ let tint_symbol_2 = &(tint_symbol_1[getNextIndex()].a2);
+ var tint_symbol_3 = *(tint_symbol_2);
+ tint_symbol_3[uniforms.j] = v;
+ *(tint_symbol_2) = tint_symbol_3;
+ *(tint_symbol) = tint_symbol_1;
+ }
+}
+
+@group(1) @binding(4) var<uniform> uniforms : Uniforms;
+
+struct Uniforms {
+ i : u32,
+ j : u32,
+}
+
+var<private> nextIndex : u32;
+
+fn getNextIndex() -> u32 {
+ nextIndex = (nextIndex + 1u);
+ return nextIndex;
+}
+
+struct OuterS {
+ a1 : array<S1, 8>,
+}
+
+struct S1 {
+ a2 : array<InnerS, 8>,
+}
+
+struct InnerS {
+ v : i32,
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(LocalizeStructArrayAssignmentTest, ViaPointerArg) {
+ auto* src = R"(
+struct Uniforms {
+ i : u32,
+};
+struct InnerS {
+ v : i32,
+};
+struct OuterS {
+ a1 : array<InnerS, 8>,
+};
+@group(1) @binding(4) var<uniform> uniforms : Uniforms;
+
+fn f(p : ptr<function, OuterS>) {
+ var v : InnerS;
+ (*p).a1[uniforms.i] = v;
+}
+
+@compute @workgroup_size(1)
+fn main() {
+ var s1 : OuterS;
+ f(&s1);
+}
+)";
+
+ auto* expect = R"(
+struct Uniforms {
+ i : u32,
+}
+
+struct InnerS {
+ v : i32,
+}
+
+struct OuterS {
+ a1 : array<InnerS, 8>,
+}
+
+@group(1) @binding(4) var<uniform> uniforms : Uniforms;
+
+fn f(p : ptr<function, OuterS>) {
+ var v : InnerS;
+ {
+ let tint_symbol = &((*(p)).a1);
+ var tint_symbol_1 = *(tint_symbol);
+ tint_symbol_1[uniforms.i] = v;
+ *(tint_symbol) = tint_symbol_1;
+ }
+}
+
+@compute @workgroup_size(1)
+fn main() {
+ var s1 : OuterS;
+ f(&(s1));
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(LocalizeStructArrayAssignmentTest, ViaPointerArg_OutOfOrder) {
+ auto* src = R"(
+@compute @workgroup_size(1)
+fn main() {
+ var s1 : OuterS;
+ f(&s1);
+}
+
+fn f(p : ptr<function, OuterS>) {
+ var v : InnerS;
+ (*p).a1[uniforms.i] = v;
+}
+
+struct InnerS {
+ v : i32,
+};
+struct OuterS {
+ a1 : array<InnerS, 8>,
+};
+
+@group(1) @binding(4) var<uniform> uniforms : Uniforms;
+
+struct Uniforms {
+ i : u32,
+};
+)";
+
+ auto* expect = R"(
+@compute @workgroup_size(1)
+fn main() {
+ var s1 : OuterS;
+ f(&(s1));
+}
+
+fn f(p : ptr<function, OuterS>) {
+ var v : InnerS;
+ {
+ let tint_symbol = &((*(p)).a1);
+ var tint_symbol_1 = *(tint_symbol);
+ tint_symbol_1[uniforms.i] = v;
+ *(tint_symbol) = tint_symbol_1;
+ }
+}
+
+struct InnerS {
+ v : i32,
+}
+
+struct OuterS {
+ a1 : array<InnerS, 8>,
+}
+
+@group(1) @binding(4) var<uniform> uniforms : Uniforms;
+
+struct Uniforms {
+ i : u32,
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(LocalizeStructArrayAssignmentTest, ViaPointerVar) {
+ auto* src = R"(
+struct Uniforms {
+ i : u32,
+};
+
+struct InnerS {
+ v : i32,
+};
+
+struct OuterS {
+ a1 : array<InnerS, 8>,
+};
+
+@group(1) @binding(4) var<uniform> uniforms : Uniforms;
+
+fn f(p : ptr<function, InnerS>, v : InnerS) {
+ *(p) = v;
+}
+
+@compute @workgroup_size(1)
+fn main() {
+ var v : InnerS;
+ var s1 : OuterS;
+ let p = &(s1.a1[uniforms.i]);
+ *(p) = v;
+}
+)";
+
+ auto* expect = R"(
+struct Uniforms {
+ i : u32,
+}
+
+struct InnerS {
+ v : i32,
+}
+
+struct OuterS {
+ a1 : array<InnerS, 8>,
+}
+
+@group(1) @binding(4) var<uniform> uniforms : Uniforms;
+
+fn f(p : ptr<function, InnerS>, v : InnerS) {
+ *(p) = v;
+}
+
+@compute @workgroup_size(1)
+fn main() {
+ var v : InnerS;
+ var s1 : OuterS;
+ let p_save = uniforms.i;
+ {
+ let tint_symbol = &(s1.a1);
+ var tint_symbol_1 = *(tint_symbol);
+ tint_symbol_1[p_save] = v;
+ *(tint_symbol) = tint_symbol_1;
+ }
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(LocalizeStructArrayAssignmentTest, VectorAssignment) {
+ auto* src = R"(
+struct Uniforms {
+ i : u32,
+}
+
+struct OuterS {
+ a1 : array<u32, 8>,
+}
+
+@group(1) @binding(4) var<uniform> uniforms : Uniforms;
+
+fn f(i : u32) -> u32 {
+ return (i + 1u);
+}
+
+@compute @workgroup_size(1)
+fn main() {
+ var s1 : OuterS;
+ var v : vec3<f32>;
+ v[s1.a1[uniforms.i]] = 1.0;
+ v[f(s1.a1[uniforms.i])] = 1.0;
+}
+)";
+
+ // Transform does nothing here as we're not actually assigning to the array in
+ // the struct.
+ EXPECT_FALSE(ShouldRun<LocalizeStructArrayAssignment>(src));
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/merge_return.cc b/src/tint/lang/wgsl/ast/transform/merge_return.cc
new file mode 100644
index 0000000..d522300
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/merge_return.cc
@@ -0,0 +1,241 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/merge_return.h"
+
+#include <utility>
+
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/statement.h"
+#include "src/tint/switch.h"
+#include "src/tint/utils/scoped_assignment.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::MergeReturn);
+
+using namespace tint::number_suffixes; // NOLINT
+
+namespace tint::ast::transform {
+
+namespace {
+
+/// Returns `true` if the semantic behaviors of `stmt` contain `behavior`.
+bool HasBehavior(const Program* program, const Statement* stmt, sem::Behavior behavior) {
+ return program->Sem().Get(stmt)->Behaviors().Contains(behavior);
+}
+
+/// Returns `true` if `func` needs to be transformed.
+bool NeedsTransform(const Program* program, const Function* func) {
+ // Entry points and intrinsic declarations never need transforming.
+ if (func->IsEntryPoint() || func->body == nullptr) {
+ return false;
+ }
+
+ // Avoid transforming functions that only have a single exit point.
+ // TODO(jrprice): Alternatively, use the uniformity analysis to decide which
+ // functions need to be transformed.
+ for (auto* s : func->body->statements) {
+ // Find the first statement that contains the Return behavior.
+ if (HasBehavior(program, s, sem::Behavior::kReturn)) {
+ // If this statement is itself a return, it will be the only exit point,
+ // so no need to apply the transform to the function.
+ if (s->Is<ReturnStatement>()) {
+ return false;
+ } else {
+ // Apply the transform in all other cases.
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
+} // namespace
+
+MergeReturn::MergeReturn() = default;
+
+MergeReturn::~MergeReturn() = default;
+
+namespace {
+
+/// Internal class used during the transform.
+class State {
+  private:
+    /// The clone context.
+    CloneContext& ctx;
+
+    /// The program builder.
+    ProgramBuilder& b;
+
+    /// The function being processed.
+    const Function* function;
+
+    /// The symbol for the return flag variable.
+    Symbol flag;
+
+    /// The symbol for the return value variable.
+    Symbol retval;
+
+    /// Tracks whether we are currently inside a loop or switch statement.
+    bool is_in_loop_or_switch = false;
+
+  public:
+    /// Constructor
+    /// @param context the clone context
+    /// @param func the function being processed
+    State(CloneContext& context, const Function* func)
+        : ctx(context), b(*ctx.dst), function(func) {}
+
+    /// Process a statement (recursively), replacing the return statements that it contains.
+    void ProcessStatement(const Statement* stmt) {
+        if (stmt == nullptr || !HasBehavior(ctx.src, stmt, sem::Behavior::kReturn)) {
+            return;
+        }
+
+        Switch(
+            stmt, [&](const BlockStatement* block) { ProcessBlock(block); },
+            [&](const CaseStatement* c) { ProcessStatement(c->body); },
+            [&](const ForLoopStatement* f) {
+                TINT_SCOPED_ASSIGNMENT(is_in_loop_or_switch, true);
+                ProcessStatement(f->body);
+            },
+            [&](const IfStatement* i) {
+                ProcessStatement(i->body);
+                ProcessStatement(i->else_statement);
+            },
+            [&](const LoopStatement* l) {
+                TINT_SCOPED_ASSIGNMENT(is_in_loop_or_switch, true);
+                ProcessStatement(l->body);
+            },
+            [&](const ReturnStatement* r) {
+                utils::Vector<const Statement*, 3> stmts;
+                // Set the return flag to signal that we have hit a return.
+                stmts.Push(b.Assign(b.Expr(flag), true));
+                if (r->value) {
+                    // Set the return value if necessary.
+                    stmts.Push(b.Assign(b.Expr(retval), ctx.Clone(r->value)));
+                }
+                if (is_in_loop_or_switch) {
+                    // If we are in a loop or switch statement, break out of it.
+                    stmts.Push(b.Break());
+                }
+                ctx.Replace(r, b.Block(std::move(stmts)));
+            },
+            [&](const SwitchStatement* s) {
+                TINT_SCOPED_ASSIGNMENT(is_in_loop_or_switch, true);
+                for (auto* c : s->body) {
+                    ProcessStatement(c);
+                }
+            },
+            [&](const WhileStatement* w) {
+                TINT_SCOPED_ASSIGNMENT(is_in_loop_or_switch, true);
+                ProcessStatement(w->body);
+            },
+            [&](Default) { TINT_ICE(Transform, b.Diagnostics()) << "unhandled statement type"; });
+    }
+
+    /// Rebuild the contents of `block`, guarding statements that follow a potential return.
+    /// @param block the block statement to process
+    void ProcessBlock(const BlockStatement* block) {
+        // Conditionals may be introduced around statements that follow one with the `Return`
+        // behavior, so build a stack of statement lists for the (potentially nested) blocks.
+        utils::Vector<utils::Vector<const Statement*, 8>, 8> new_stmts({{}});
+
+        // Insert variables for the return flag and return value at the top of the function.
+        if (block == function->body) {
+            flag = b.Symbols().New("tint_return_flag");
+            new_stmts[0].Push(b.Decl(b.Var(flag, b.ty.bool_())));
+
+            if (function->return_type) {
+                retval = b.Symbols().New("tint_return_value");
+                new_stmts[0].Push(b.Decl(b.Var(retval, ctx.Clone(function->return_type))));
+            }
+        }
+
+        for (auto* s : block->statements) {
+            // Process the statement and add it to the current block.
+            ProcessStatement(s);
+            new_stmts.Back().Push(ctx.Clone(s));
+
+            // If the statement is or contains a return, make sure that any statements that follow
+            // it do not get executed if the return flag has been set.
+            if (HasBehavior(ctx.src, s, sem::Behavior::kReturn)) {
+                if (is_in_loop_or_switch) {
+                    // We're in a loop/switch, so the return has been rewritten to a `break`.
+                    // After a nested loop/switch/while statement we need to `break` again.
+                    if (s->IsAnyOf<LoopStatement, ForLoopStatement, SwitchStatement,
+                                   WhileStatement>()) {
+                        // If the loop only has the 'Return' behavior, we can just unconditionally
+                        // break. Otherwise check the return flag.
+                        if (HasBehavior(ctx.src, s, sem::Behavior::kNext)) {
+                            new_stmts.Back().Push(
+                                b.If(b.Expr(flag), b.Block(utils::Vector{b.Break()})));
+                        } else {
+                            new_stmts.Back().Push(b.Break());
+                        }
+                    }
+                } else {
+                    // Subsequent statements go in a new list, to be wrapped in a conditional.
+                    new_stmts.Push({});
+                }
+            }
+        }
+
+        // Descend the stack of new block statements, wrapping them in conditionals.
+        while (new_stmts.Length() > 1) {
+            const IfStatement* i = nullptr;
+            if (new_stmts.Back().Length() > 0) {
+                i = b.If(b.Not(b.Expr(flag)), b.Block(new_stmts.Back()));
+            }
+            new_stmts.Pop();
+            if (i) {
+                new_stmts.Back().Push(i);
+            }
+        }
+
+        // Insert the final return statement at the end of the function body.
+        if (block == function->body && retval.IsValid()) {
+            new_stmts[0].Push(b.Return(b.Expr(retval)));
+        }
+
+        ctx.Replace(block, b.Block(new_stmts[0]));
+    }
+};
+
+} // namespace
+
+Transform::ApplyResult MergeReturn::Apply(const Program* src, const DataMap&, DataMap&) const {
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+
+ bool made_changes = false;
+
+ for (auto* func : ctx.src->AST().Functions()) {
+ if (!NeedsTransform(ctx.src, func)) {
+ continue;
+ }
+
+ State state(ctx, func);
+ state.ProcessStatement(func->body);
+ made_changes = true;
+ }
+
+ if (!made_changes) {
+ return SkipTransform;
+ }
+
+ ctx.Clone();
+ return Program(std::move(b));
+}
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/merge_return.h b/src/tint/lang/wgsl/ast/transform/merge_return.h
new file mode 100644
index 0000000..3d295b9
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/merge_return.h
@@ -0,0 +1,38 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_MERGE_RETURN_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_MERGE_RETURN_H_
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+
+namespace tint::ast::transform {
+
+/// Merge return statements into a single return at the end of the function.
+class MergeReturn final : public utils::Castable<MergeReturn, Transform> {
+ public:
+ /// Constructor
+ MergeReturn();
+ /// Destructor
+ ~MergeReturn() override;
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_MERGE_RETURN_H_
diff --git a/src/tint/lang/wgsl/ast/transform/merge_return_test.cc b/src/tint/lang/wgsl/ast/transform/merge_return_test.cc
new file mode 100644
index 0000000..ad492e6
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/merge_return_test.cc
@@ -0,0 +1,860 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/merge_return.h"
+
+#include <utility>
+
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using MergeReturnTest = TransformTest;
+
+TEST_F(MergeReturnTest, ShouldRunEmptyModule) {
+ auto* src = R"()";
+
+ EXPECT_FALSE(ShouldRun<MergeReturn>(src));
+}
+
+TEST_F(MergeReturnTest, ShouldRunOnlyEntryPoints) {
+ auto* src = R"(
+@fragment
+fn foo() {}
+
+@compute @workgroup_size(1)
+fn bar() {}
+)";
+
+ EXPECT_FALSE(ShouldRun<MergeReturn>(src));
+}
+
+TEST_F(MergeReturnTest, ShouldRunUserFuncNoReturn) {
+ auto* src = R"(
+fn foo() {}
+
+@compute @workgroup_size(1)
+fn bar() {
+ foo();
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<MergeReturn>(src));
+}
+
+TEST_F(MergeReturnTest, ShouldRunUserFuncSingleReturn) {
+ auto* src = R"(
+fn foo() -> i32 {
+ return 42;
+}
+
+@compute @workgroup_size(1)
+fn bar() {
+ foo();
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<MergeReturn>(src));
+}
+
+TEST_F(MergeReturnTest, ShouldRunUserFuncMultipleReturn) {
+ auto* src = R"(
+fn foo(x : i32) -> i32 {
+ if (x == 0) {
+ return 42;
+ }
+ return 0;
+}
+
+@compute @workgroup_size(1)
+fn bar() {
+ foo(7);
+}
+)";
+
+ EXPECT_TRUE(ShouldRun<MergeReturn>(src));
+}
+
+TEST_F(MergeReturnTest, EmptyModule) {
+ auto* src = R"()";
+
+ auto* expect = src;
+
+ auto got = Run<MergeReturn>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(MergeReturnTest, SingleExitPoint) {
+ auto* src = R"(
+fn foo(x : i32) -> i32 {
+ var retval = 0;
+ if ((x == 0)) {
+ retval = 42;
+ } else {
+ retval = 7;
+ }
+ return retval;
+}
+)";
+
+ auto* expect = src;
+
+ auto got = Run<MergeReturn>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(MergeReturnTest, Basic) {
+ auto* src = R"(
+var<private> non_uniform_global : i32;
+
+fn foo() {
+ if (non_uniform_global == 0) {
+ return;
+ } else {
+ return;
+ }
+}
+)";
+
+ auto* expect = R"(
+var<private> non_uniform_global : i32;
+
+fn foo() {
+ var tint_return_flag : bool;
+ if ((non_uniform_global == 0)) {
+ {
+ tint_return_flag = true;
+ }
+ } else {
+ {
+ tint_return_flag = true;
+ }
+ }
+}
+)";
+
+ auto got = Run<MergeReturn>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(MergeReturnTest, BasicRetval) {
+ auto* src = R"(
+var<private> non_uniform_global : i32;
+
+fn foo() -> i32 {
+ if (non_uniform_global == 0) {
+ return 1;
+ } else {
+ return 2;
+ }
+}
+)";
+
+ auto* expect = R"(
+var<private> non_uniform_global : i32;
+
+fn foo() -> i32 {
+ var tint_return_flag : bool;
+ var tint_return_value : i32;
+ if ((non_uniform_global == 0)) {
+ {
+ tint_return_flag = true;
+ tint_return_value = 1;
+ }
+ } else {
+ {
+ tint_return_flag = true;
+ tint_return_value = 2;
+ }
+ }
+ return tint_return_value;
+}
+)";
+
+ auto got = Run<MergeReturn>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(MergeReturnTest, CodeAfterIf) {
+ auto* src = R"(
+var<private> non_uniform_global : i32;
+
+var<private> bar : i32;
+
+fn foo() -> i32 {
+ if (non_uniform_global == 0) {
+ return 1;
+ }
+ bar = 42;
+ return 2;
+}
+)";
+
+ auto* expect = R"(
+var<private> non_uniform_global : i32;
+
+var<private> bar : i32;
+
+fn foo() -> i32 {
+ var tint_return_flag : bool;
+ var tint_return_value : i32;
+ if ((non_uniform_global == 0)) {
+ {
+ tint_return_flag = true;
+ tint_return_value = 1;
+ }
+ }
+ if (!(tint_return_flag)) {
+ bar = 42;
+ {
+ tint_return_flag = true;
+ tint_return_value = 2;
+ }
+ }
+ return tint_return_value;
+}
+)";
+
+ auto got = Run<MergeReturn>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(MergeReturnTest, SequenceOfConditionalReturns) {
+ auto* src = R"(
+var<private> non_uniform_global : i32;
+
+var<private> bar : i32;
+
+fn foo() -> i32 {
+ if (non_uniform_global == 0) {
+ return 1;
+ }
+ bar = 42;
+ if (non_uniform_global == 1) {
+ return 2;
+ }
+ bar = 43;
+ if (non_uniform_global == 2) {
+ return 3;
+ }
+ bar = 44;
+ if (non_uniform_global == 3) {
+ return 4;
+ }
+ bar = 45;
+ return 0;
+}
+)";
+
+ auto* expect = R"(
+var<private> non_uniform_global : i32;
+
+var<private> bar : i32;
+
+fn foo() -> i32 {
+ var tint_return_flag : bool;
+ var tint_return_value : i32;
+ if ((non_uniform_global == 0)) {
+ {
+ tint_return_flag = true;
+ tint_return_value = 1;
+ }
+ }
+ if (!(tint_return_flag)) {
+ bar = 42;
+ if ((non_uniform_global == 1)) {
+ {
+ tint_return_flag = true;
+ tint_return_value = 2;
+ }
+ }
+ if (!(tint_return_flag)) {
+ bar = 43;
+ if ((non_uniform_global == 2)) {
+ {
+ tint_return_flag = true;
+ tint_return_value = 3;
+ }
+ }
+ if (!(tint_return_flag)) {
+ bar = 44;
+ if ((non_uniform_global == 3)) {
+ {
+ tint_return_flag = true;
+ tint_return_value = 4;
+ }
+ }
+ if (!(tint_return_flag)) {
+ bar = 45;
+ {
+ tint_return_flag = true;
+ tint_return_value = 0;
+ }
+ }
+ }
+ }
+ }
+ return tint_return_value;
+}
+)";
+
+ auto got = Run<MergeReturn>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(MergeReturnTest, NestedConditionals) {
+ auto* src = R"(
+var<private> non_uniform_global : i32;
+
+var<private> bar : i32;
+
+fn foo() -> i32 {
+ if (non_uniform_global == 0) {
+ if (true) {
+ return 1;
+ } else {
+ return 2;
+ }
+ } else {
+ if (true) {
+ bar = 42;
+ } else if (false) {
+ return 3;
+ } else if (true) {
+ return 4;
+ }
+ return 5;
+ }
+}
+)";
+
+ auto* expect = R"(
+var<private> non_uniform_global : i32;
+
+var<private> bar : i32;
+
+fn foo() -> i32 {
+ var tint_return_flag : bool;
+ var tint_return_value : i32;
+ if ((non_uniform_global == 0)) {
+ if (true) {
+ {
+ tint_return_flag = true;
+ tint_return_value = 1;
+ }
+ } else {
+ {
+ tint_return_flag = true;
+ tint_return_value = 2;
+ }
+ }
+ } else {
+ if (true) {
+ bar = 42;
+ } else if (false) {
+ {
+ tint_return_flag = true;
+ tint_return_value = 3;
+ }
+ } else if (true) {
+ {
+ tint_return_flag = true;
+ tint_return_value = 4;
+ }
+ }
+ if (!(tint_return_flag)) {
+ {
+ tint_return_flag = true;
+ tint_return_value = 5;
+ }
+ }
+ }
+ return tint_return_value;
+}
+)";
+
+ auto got = Run<MergeReturn>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(MergeReturnTest, Loop) {
+ auto* src = R"(
+var<private> non_uniform_global : i32;
+
+var<private> bar : i32;
+
+fn foo() -> i32 {
+ var i = 0;
+ loop {
+ bar = i;
+ if (non_uniform_global == 0) {
+ return non_uniform_global;
+ }
+ if (i >= 10) {
+ break;
+ }
+ }
+ return 0;
+}
+)";
+
+ auto* expect = R"(
+var<private> non_uniform_global : i32;
+
+var<private> bar : i32;
+
+fn foo() -> i32 {
+ var tint_return_flag : bool;
+ var tint_return_value : i32;
+ var i = 0;
+ loop {
+ bar = i;
+ if ((non_uniform_global == 0)) {
+ {
+ tint_return_flag = true;
+ tint_return_value = non_uniform_global;
+ break;
+ }
+ }
+ if ((i >= 10)) {
+ break;
+ }
+ }
+ if (!(tint_return_flag)) {
+ {
+ tint_return_flag = true;
+ tint_return_value = 0;
+ }
+ }
+ return tint_return_value;
+}
+)";
+
+ auto got = Run<MergeReturn>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(MergeReturnTest, NestedLoop) {  // A return in an inner loop must also break the outer loop.
+ auto* src = R"(
+var<private> non_uniform_global : i32;
+
+var<private> bar : i32;
+
+fn foo() -> i32 {
+ var i = 0;
+ loop {
+ loop {
+ bar = i;
+ if (non_uniform_global == 0) {
+ return non_uniform_global;
+ }
+ if (i >= 10) {
+ break;
+ }
+ }
+ bar = 10;
+ }
+ return 0;
+}
+)";
+
+ auto* expect = R"(
+var<private> non_uniform_global : i32;
+
+var<private> bar : i32;
+
+fn foo() -> i32 {
+ var tint_return_flag : bool;
+ var tint_return_value : i32;
+ var i = 0;
+ loop {
+ loop {
+ bar = i;
+ if ((non_uniform_global == 0)) {
+ {
+ tint_return_flag = true;
+ tint_return_value = non_uniform_global;
+ break;
+ }
+ }
+ if ((i >= 10)) {
+ break;
+ }
+ }
+ if (tint_return_flag) {
+ break;
+ }
+ bar = 10;
+ }
+ if (!(tint_return_flag)) {
+ {
+ tint_return_flag = true;
+ tint_return_value = 0;
+ }
+ }
+ return tint_return_value;
+}
+)";
+
+ auto got = Run<MergeReturn>(src);
+
+ EXPECT_EQ(expect, str(got));  // Expects a flag check plus `break` right after the inner loop.
+}
+
+TEST_F(MergeReturnTest, NestedLoopAlwaysReturns) {  // The inner loop returns unconditionally.
+ auto* src = R"(
+var<private> non_uniform_global : i32;
+
+var<private> bar : i32;
+
+fn foo() -> i32 {
+ var i = 0;
+ loop {
+ bar = i;
+ loop {
+ return non_uniform_global;
+ }
+ if (i == 10) {
+ break;
+ }
+ }
+ return 0;
+}
+)";
+
+ auto* expect = R"(
+var<private> non_uniform_global : i32;
+
+var<private> bar : i32;
+
+fn foo() -> i32 {
+ var tint_return_flag : bool;
+ var tint_return_value : i32;
+ var i = 0;
+ loop {
+ bar = i;
+ loop {
+ {
+ tint_return_flag = true;
+ tint_return_value = non_uniform_global;
+ break;
+ }
+ }
+ break;
+ if ((i == 10)) {
+ break;
+ }
+ }
+ if (!(tint_return_flag)) {
+ {
+ tint_return_flag = true;
+ tint_return_value = 0;
+ }
+ }
+ return tint_return_value;
+}
+)";
+
+ auto got = Run<MergeReturn>(src);
+
+ EXPECT_EQ(expect, str(got));  // Expects an unconditional `break` (no flag check) after the inner loop.
+}
+
+TEST_F(MergeReturnTest, ForLoop) {  // A return inside a for-loop body becomes flag/value stores plus `break`.
+ auto* src = R"(
+var<private> non_uniform_global : array<i32, 10>;
+
+fn foo() -> i32 {
+ for (var i = 0; i < 10; i++) {
+ if (non_uniform_global[i] == 0) {
+ return i;
+ }
+ }
+ return 0;
+}
+)";
+
+ auto* expect = R"(
+var<private> non_uniform_global : array<i32, 10>;
+
+fn foo() -> i32 {
+ var tint_return_flag : bool;
+ var tint_return_value : i32;
+ for(var i = 0; (i < 10); i++) {
+ if ((non_uniform_global[i] == 0)) {
+ {
+ tint_return_flag = true;
+ tint_return_value = i;
+ break;
+ }
+ }
+ }
+ if (!(tint_return_flag)) {
+ {
+ tint_return_flag = true;
+ tint_return_value = 0;
+ }
+ }
+ return tint_return_value;
+}
+)";
+
+ auto got = Run<MergeReturn>(src);
+
+ EXPECT_EQ(expect, str(got));  // The for-loop header is expected to be kept as-is.
+}
+
+TEST_F(MergeReturnTest, WhileLoop) {  // A return inside a while-loop body becomes flag/value stores plus `break`.
+ auto* src = R"(
+var<private> non_uniform_global : array<i32, 10>;
+
+fn foo() -> i32 {
+ var i = 0;
+ while (i < 10) {
+ if (non_uniform_global[i] == 0) {
+ return i;
+ }
+ i++;
+ }
+ return 0;
+}
+)";
+
+ auto* expect = R"(
+var<private> non_uniform_global : array<i32, 10>;
+
+fn foo() -> i32 {
+ var tint_return_flag : bool;
+ var tint_return_value : i32;
+ var i = 0;
+ while((i < 10)) {
+ if ((non_uniform_global[i] == 0)) {
+ {
+ tint_return_flag = true;
+ tint_return_value = i;
+ break;
+ }
+ }
+ i++;
+ }
+ if (!(tint_return_flag)) {
+ {
+ tint_return_flag = true;
+ tint_return_value = 0;
+ }
+ }
+ return tint_return_value;
+}
+)";
+
+ auto got = Run<MergeReturn>(src);
+
+ EXPECT_EQ(expect, str(got));  // The statement after the `if` (`i++`) is expected to stay unguarded.
+}
+
+TEST_F(MergeReturnTest, Switch) {  // Returns inside switch cases become flag/value stores plus `break`.
+ auto* src = R"(
+var<private> non_uniform_global : i32;
+
+var<private> bar : i32;
+
+fn foo() -> i32 {
+ switch (non_uniform_global) {
+ case 0 {
+ return 0;
+ }
+ case 1 {
+ if (false) {
+ return 1;
+ }
+ bar = 6;
+ }
+ case 2 {
+ bar = 5;
+ }
+ default {
+ return 42;
+ }
+ }
+ return 0;
+}
+)";
+
+ auto* expect = R"(
+var<private> non_uniform_global : i32;
+
+var<private> bar : i32;
+
+fn foo() -> i32 {
+ var tint_return_flag : bool;
+ var tint_return_value : i32;
+ switch(non_uniform_global) {
+ case 0: {
+ {
+ tint_return_flag = true;
+ tint_return_value = 0;
+ break;
+ }
+ }
+ case 1: {
+ if (false) {
+ {
+ tint_return_flag = true;
+ tint_return_value = 1;
+ break;
+ }
+ }
+ bar = 6;
+ }
+ case 2: {
+ bar = 5;
+ }
+ default: {
+ {
+ tint_return_flag = true;
+ tint_return_value = 42;
+ break;
+ }
+ }
+ }
+ if (!(tint_return_flag)) {
+ {
+ tint_return_flag = true;
+ tint_return_value = 0;
+ }
+ }
+ return tint_return_value;
+}
+)";
+
+ auto got = Run<MergeReturn>(src);
+
+ EXPECT_EQ(expect, str(got));  // The return after the switch is expected to be guarded by the flag.
+}
+
+TEST_F(MergeReturnTest, ComplexNesting) {  // Returns nested in for > switch > if > loop (with a continuing block).
+ auto* src = R"(
+var<private> non_uniform_global : i32;
+
+var<private> bar : i32;
+
+fn foo() -> i32 {
+ for (var i = 0; i < 10; i++) {
+ bar = i;
+
+ switch (non_uniform_global) {
+ case 0 {
+ return 0;
+ }
+ case 1 {
+ if (false) {
+ loop {
+ if (i == non_uniform_global) {
+ return 1;
+ }
+ continuing {
+ bar++;
+ break if (i == 7);
+ }
+ }
+ }
+ bar = 0;
+ }
+ default {
+ return 42;
+ }
+ }
+ }
+ return 0;
+}
+)";
+
+ auto* expect = R"(
+var<private> non_uniform_global : i32;
+
+var<private> bar : i32;
+
+fn foo() -> i32 {
+ var tint_return_flag : bool;
+ var tint_return_value : i32;
+ for(var i = 0; (i < 10); i++) {
+ bar = i;
+ switch(non_uniform_global) {
+ case 0: {
+ {
+ tint_return_flag = true;
+ tint_return_value = 0;
+ break;
+ }
+ }
+ case 1: {
+ if (false) {
+ loop {
+ if ((i == non_uniform_global)) {
+ {
+ tint_return_flag = true;
+ tint_return_value = 1;
+ break;
+ }
+ }
+
+ continuing {
+ bar++;
+ break if (i == 7);
+ }
+ }
+ if (tint_return_flag) {
+ break;
+ }
+ }
+ bar = 0;
+ }
+ default: {
+ {
+ tint_return_flag = true;
+ tint_return_value = 42;
+ break;
+ }
+ }
+ }
+ if (tint_return_flag) {
+ break;
+ }
+ }
+ if (!(tint_return_flag)) {
+ {
+ tint_return_flag = true;
+ tint_return_value = 0;
+ }
+ }
+ return tint_return_value;
+}
+)";
+
+ auto got = Run<MergeReturn>(src);
+
+ EXPECT_EQ(expect, str(got));  // Expects flag checks after the inner loop and after the switch.
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/module_scope_var_to_entry_point_param.cc b/src/tint/lang/wgsl/ast/transform/module_scope_var_to_entry_point_param.cc
new file mode 100644
index 0000000..96658d8
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/module_scope_var_to_entry_point_param.cc
@@ -0,0 +1,606 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/module_scope_var_to_entry_point_param.h"
+
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "src/tint/lang/wgsl/ast/disable_validation_attribute.h"
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/call.h"
+#include "src/tint/sem/function.h"
+#include "src/tint/sem/module.h"
+#include "src/tint/sem/statement.h"
+#include "src/tint/sem/variable.h"
+#include "src/tint/utils/string.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::ModuleScopeVarToEntryPointParam);
+
+using namespace tint::builtin::fluent_types; // NOLINT
+
+namespace tint::ast::transform {
+namespace {
+
+using StructMemberList = utils::Vector<const StructMember*, 8>;
+
+// The name of the single struct member used when a runtime-sized array is wrapped in a structure.
+const char* kWrappedArrayMemberName = "arr";
+
+bool ShouldRun(const Program* program) {  // True if any global declaration is a `Variable`.
+ for (auto* decl : program->AST().GlobalDeclarations()) {
+ if (decl->Is<Variable>()) {
+ return true;
+ }
+ }
+ return false;  // No module-scope variables were found.
+}
+
+// Returns `true` if `type` is a matrix, or an array or struct that (recursively) contains one.
+bool ContainsMatrix(const type::Type* type) {
+ type = type->UnwrapRef();  // Look through a reference to the store type.
+ if (type->Is<type::Matrix>()) {
+ return true;
+ } else if (auto* ary = type->As<type::Array>()) {
+ return ContainsMatrix(ary->ElemType());
+ } else if (auto* str = type->As<type::Struct>()) {
+ for (auto* member : str->Members()) {
+ if (ContainsMatrix(member->Type())) {
+ return true;
+ }
+ }
+ }
+ return false;  // Any other type, or a struct with no matrix in any member.
+}
+} // namespace
+
+/// PIMPL state for the transform
+struct ModuleScopeVarToEntryPointParam::State {
+ /// The clone context.
+ CloneContext& ctx;
+
+ /// Constructor. Only stores the reference to the clone context; Process() does the work.
+ /// @param context the clone context
+ explicit State(CloneContext& context) : ctx(context) {}
+
+ /// Clone any struct types that are contained in `ty` (including `ty` itself),
+ /// and add them to the global declarations now, so that they precede new global
+ /// declarations that need to reference them. Each struct is cloned at most once.
+ /// @param ty the type to clone
+ void CloneStructTypes(const type::Type* ty) {
+ if (auto* str = ty->As<sem::Struct>()) {
+ if (!cloned_structs_.emplace(str).second) {
+ // The struct has already been cloned.
+ return;
+ }
+
+ // Recurse into members first, so that nested structs are declared before this one.
+ for (auto* member : str->Members()) {
+ CloneStructTypes(member->Type());
+ }
+
+ // Clone the struct and add it to the global declaration list.
+ // Remove the old declaration.
+ auto* ast_str = str->Declaration();
+ ctx.dst->AST().AddTypeDecl(ctx.Clone(ast_str));
+ ctx.Remove(ctx.src->AST().GlobalDeclarations(), ast_str);
+ } else if (auto* arr = ty->As<type::Array>()) {
+ CloneStructTypes(arr->ElemType());  // Arrays may hold structs: look through to the element type.
+ }
+ }
+
+ /// Process a variable `var` that is referenced in the entry point function `func`.
+ /// This will redeclare the variable as a parameter or function-scope variable, possibly as a pointer.
+ /// Some workgroup variables will be redeclared as a member inside a workgroup structure.
+ /// @param func the entry point function
+ /// @param var the variable
+ /// @param new_var_symbol the symbol to use for the replacement
+ /// @param workgroup_param helper function to get a symbol to a workgroup struct parameter
+ /// @param workgroup_parameter_members reference to the list of workgroup struct members
+ /// @param is_pointer output, set to true if the replacement is a pointer (never cleared here)
+ /// @param is_wrapped output, set to true if the replacement is wrapped in a struct (never cleared here)
+ void ProcessVariableInEntryPoint(const Function* func,
+ const sem::Variable* var,
+ Symbol new_var_symbol,
+ std::function<Symbol()> workgroup_param,
+ StructMemberList& workgroup_parameter_members,
+ bool& is_pointer,
+ bool& is_wrapped) {
+ auto* ty = var->Type()->UnwrapRef();
+
+ // Helper to create an AST node for the store type of the variable.
+ auto store_type = [&] { return CreateASTTypeFor(ctx, ty); };
+
+ builtin::AddressSpace sc = var->AddressSpace();
+ switch (sc) {
+ case builtin::AddressSpace::kHandle: {
+ // For a texture or sampler variable, redeclare it as an entry point parameter.
+ // Disable entry point parameter validation.
+ auto* disable_validation =
+ ctx.dst->Disable(DisabledValidation::kEntryPointParameter);
+ auto attrs = ctx.Clone(var->Declaration()->attributes);
+ attrs.Push(disable_validation);
+ auto* param = ctx.dst->Param(new_var_symbol, store_type(), attrs);
+ ctx.InsertFront(func->params, param);
+
+ break;
+ }
+ case builtin::AddressSpace::kStorage:
+ case builtin::AddressSpace::kUniform: {
+ // Variables in the storage and uniform address spaces are redeclared as entry
+ // point parameters with a pointer type.
+ auto attributes = ctx.Clone(var->Declaration()->attributes);
+ attributes.Push(ctx.dst->Disable(DisabledValidation::kEntryPointParameter));
+ attributes.Push(ctx.dst->Disable(DisabledValidation::kIgnoreAddressSpace));
+
+ auto param_type = store_type();
+ if (auto* arr = ty->As<type::Array>();
+ arr && arr->Count()->Is<type::RuntimeArrayCount>()) {
+ // Wrap runtime-sized arrays in structures, so that we can declare pointers to
+ // them. Ideally we'd just emit the array itself as a pointer, but this is not
+ // representable in Tint's AST.
+ CloneStructTypes(ty);
+ auto* wrapper = ctx.dst->Structure(
+ ctx.dst->Sym(), utils::Vector{
+ ctx.dst->Member(kWrappedArrayMemberName, param_type),
+ });
+ param_type = ctx.dst->ty.Of(wrapper);
+ is_wrapped = true;
+ }
+
+ param_type = sc == builtin::AddressSpace::kStorage
+ ? ctx.dst->ty.ptr(sc, param_type, var->Access())
+ : ctx.dst->ty.ptr(sc, param_type);  // Only storage pointers carry an access mode.
+ auto* param = ctx.dst->Param(new_var_symbol, param_type, attributes);
+ ctx.InsertFront(func->params, param);
+ is_pointer = true;
+
+ break;
+ }
+ case builtin::AddressSpace::kWorkgroup: {
+ if (ContainsMatrix(var->Type())) {
+ // Due to a bug in the MSL compiler, we use a threadgroup memory argument for
+ // any workgroup allocation that contains a matrix. See crbug.com/tint/938.
+ // TODO(jrprice): Do this for all other workgroup variables too.
+
+ // Create a member in the workgroup parameter struct.
+ auto member = ctx.Clone(var->Declaration()->name->symbol);
+ workgroup_parameter_members.Push(ctx.dst->Member(member, store_type()));
+ CloneStructTypes(var->Type()->UnwrapRef());
+
+ // Create a function-scope variable that is a pointer to the member.
+ auto* member_ptr = ctx.dst->AddressOf(
+ ctx.dst->MemberAccessor(ctx.dst->Deref(workgroup_param()), member));
+ auto* local_var = ctx.dst->Let(
+ new_var_symbol, ctx.dst->ty.ptr<workgroup>(store_type()), member_ptr);
+ ctx.InsertFront(func->body->statements, ctx.dst->Decl(local_var));
+ is_pointer = true;
+ } else {
+ auto* disable_validation =
+ ctx.dst->Disable(DisabledValidation::kIgnoreAddressSpace);
+ auto* initializer = ctx.Clone(var->Declaration()->initializer);
+ auto* local_var = ctx.dst->Var(new_var_symbol, store_type(), sc, initializer,
+ utils::Vector{disable_validation});
+ ctx.InsertFront(func->body->statements, ctx.dst->Decl(local_var));  // Not a pointer.
+ }
+ break;
+ }
+ case builtin::AddressSpace::kPushConstant: {
+ ctx.dst->Diagnostics().add_error(
+ diag::System::Transform,
+ "unhandled module-scope address space (" + utils::ToString(sc) + ")");
+ break;  // No replacement is declared for push constants.
+ }
+ default: {  // Includes kPrivate, which Process() handles itself before calling this.
+ TINT_ICE(Transform, ctx.dst->Diagnostics())
+ << "unhandled module-scope address space (" << sc << ")";
+ break;
+ }
+ }
+ }
+
+ /// Process a variable `var` that is referenced in the user-defined function `func`.
+ /// This will redeclare the variable as a function parameter, possibly as a pointer.
+ /// @param func the user-defined function
+ /// @param var the variable
+ /// @param new_var_symbol the symbol to use for the replacement
+ /// @param is_pointer output, set to true if the replacement is a pointer (left untouched otherwise)
+ void ProcessVariableInUserFunction(const Function* func,
+ const sem::Variable* var,
+ Symbol new_var_symbol,
+ bool& is_pointer) {
+ auto* ty = var->Type()->UnwrapRef();
+ auto param_type = CreateASTTypeFor(ctx, ty);
+ auto sc = var->AddressSpace();
+ switch (sc) {
+ case builtin::AddressSpace::kPrivate:
+ // Private variables are passed all together in a struct, so no parameter is added here.
+ return;
+ case builtin::AddressSpace::kStorage:
+ case builtin::AddressSpace::kUniform:
+ case builtin::AddressSpace::kHandle:
+ case builtin::AddressSpace::kWorkgroup:
+ break;
+ case builtin::AddressSpace::kPushConstant: {
+ ctx.dst->Diagnostics().add_error(
+ diag::System::Transform,
+ "unhandled module-scope address space (" + utils::ToString(sc) + ")");
+ break;  // NOTE: after reporting the error, control still reaches the parameter creation below.
+ }
+ default: {
+ TINT_ICE(Transform, ctx.dst->Diagnostics())
+ << "unhandled module-scope address space (" << sc << ")";
+ }  // NOTE: no return here either; execution continues after the switch.
+ }
+
+ // Use a pointer for non-handle types.
+ utils::Vector<const Attribute*, 2> attributes;
+ if (!ty->is_handle()) {
+ param_type = sc == builtin::AddressSpace::kStorage
+ ? ctx.dst->ty.ptr(sc, param_type, var->Access())
+ : ctx.dst->ty.ptr(sc, param_type);  // Only storage pointers carry an access mode.
+ is_pointer = true;
+
+ // Disable validation of the parameter's address space and of arguments passed to it.
+ attributes.Push(ctx.dst->Disable(DisabledValidation::kIgnoreAddressSpace));
+ attributes.Push(ctx.dst->Disable(DisabledValidation::kIgnoreInvalidPointerArgument));
+ }
+
+ // Redeclare the variable as a parameter, appended after the existing parameters.
+ ctx.InsertBack(func->params,
+ ctx.dst->Param(new_var_symbol, param_type, std::move(attributes)));
+ }
+
+ /// Replace all uses of `var` in `func` with references to `new_var`.
+ /// @param func the function
+ /// @param var the variable to replace
+ /// @param new_var the symbol to use for replacement
+ /// @param is_pointer true if `new_var` is a pointer to the new variable
+ /// @param member_name if valid, the name of the struct member that holds this variable
+ void ReplaceUsesInFunction(const Function* func,
+ const sem::Variable* var,
+ Symbol new_var,
+ bool is_pointer,
+ Symbol member_name) {
+ for (auto* user : var->Users()) {
+ if (user->Stmt()->Function()->Declaration() == func) {  // Skip uses in other functions.
+ const Expression* expr = ctx.dst->Expr(new_var);
+ if (is_pointer) {
+ // If this identifier is used by an address-of operator, just remove the
+ // address-of instead of adding a deref, since we already have a pointer.
+ // This shortcut is not taken when the variable lives in a struct member.
+ auto* ident = user->Declaration()->As<IdentifierExpression>();
+ if (ident_to_address_of_.count(ident) && !member_name.IsValid()) {
+ ctx.Replace(ident_to_address_of_[ident], expr);
+ continue;
+ }
+
+ expr = ctx.dst->Deref(expr);
+ }
+ if (member_name.IsValid()) {
+ // Get the member from the containing structure.
+ expr = ctx.dst->MemberAccessor(expr, member_name);
+ }
+ ctx.Replace(user->Declaration(), expr);
+ }
+ }
+ }
+
+ /// Process the module, redeclaring module-scope variables in entry points and passing them to callees.
+ void Process() {
+ // Predetermine the list of function calls that need to be replaced.
+ using CallList = utils::Vector<const CallExpression*, 8>;
+ std::unordered_map<const Function*, CallList> calls_to_replace;
+
+ utils::Vector<const Function*, 8> functions_to_process;
+
+ // Collect private variables into a single structure.
+ StructMemberList private_struct_members;
+ utils::Vector<std::function<const AssignmentStatement*()>, 4> private_initializers;
+ std::unordered_set<const Function*> uses_privates;
+
+ // Build a list of functions that transitively reference any module-scope variables.
+ for (auto* decl : ctx.src->Sem().Module()->DependencyOrderedDeclarations()) {
+ if (auto* var = decl->As<Var>()) {
+ auto* sem_var = ctx.src->Sem().Get(var);
+ if (sem_var->AddressSpace() == builtin::AddressSpace::kPrivate) {
+ // Create a member in the private variable struct.
+ auto* ty = sem_var->Type()->UnwrapRef();
+ auto name = ctx.Clone(var->name->symbol);
+ private_struct_members.Push(ctx.dst->Member(name, CreateASTTypeFor(ctx, ty)));
+ CloneStructTypes(ty);
+
+ // Defer creating the initializer assignment: it is built per entry point that uses privates.
+ if (var->initializer) {
+ private_initializers.Push([&, name, var] {
+ return ctx.dst->Assign(
+ ctx.dst->MemberAccessor(PrivateStructVariableName(), name),
+ ctx.Clone(var->initializer));
+ });
+ }
+ }
+ continue;
+ }
+
+ auto* func_ast = decl->As<Function>();
+ if (!func_ast) {
+ continue;
+ }
+
+ auto* func_sem = ctx.src->Sem().Get(func_ast);
+
+ bool needs_processing = false;
+ for (auto* var : func_sem->TransitivelyReferencedGlobals()) {
+ if (var->AddressSpace() != builtin::AddressSpace::kUndefined) {
+ if (var->AddressSpace() == builtin::AddressSpace::kPrivate) {
+ uses_privates.insert(func_ast);
+ }
+ needs_processing = true;
+ }
+ }
+ if (needs_processing) {
+ functions_to_process.Push(func_ast);
+
+ // Find all of the calls to this function that will need to be replaced.
+ for (auto* call : func_sem->CallSites()) {
+ calls_to_replace[call->Stmt()->Function()->Declaration()].Push(
+ call->Declaration());
+ }
+ }
+ }
+
+ if (!private_struct_members.IsEmpty()) {
+ // Create the private variable struct.
+ ctx.dst->Structure(PrivateStructName(), std::move(private_struct_members));
+ // Passing a pointer to a private variable will now involve passing a pointer to the
+ // member of a structure, so enable the extension that allows this.
+ ctx.dst->Enable(builtin::Extension::kChromiumExperimentalFullPtrParameters);
+ }
+
+ // Build a list of `&ident` expressions. We'll use this later to avoid generating
+ // expressions of the form `&*ident`, which break WGSL validation rules when this expression
+ // is passed to a function.
+ // TODO(jrprice): We should add support for bidirectional SEM tree traversal so that we can
+ // do this on the fly instead.
+ for (auto* node : ctx.src->ASTNodes().Objects()) {
+ auto* address_of = node->As<UnaryOpExpression>();
+ if (!address_of || address_of->op != UnaryOp::kAddressOf) {
+ continue;
+ }
+ if (auto* ident = address_of->expr->As<IdentifierExpression>()) {
+ ident_to_address_of_[ident] = address_of;
+ }
+ }
+
+ for (auto* func_ast : functions_to_process) {
+ auto* func_sem = ctx.src->Sem().Get(func_ast);
+ bool is_entry_point = func_ast->IsEntryPoint();
+ bool needs_pointer_aliasing = false;
+
+ // Map module-scope variables onto their replacement.
+ struct NewVar {
+ Symbol symbol;
+ bool is_pointer;
+ bool is_wrapped;
+ };
+ std::unordered_map<const sem::Variable*, NewVar> var_to_newvar;
+
+ // We aggregate all workgroup variables into a struct to avoid hitting MSL's limit for
+ // threadgroup memory arguments.
+ Symbol workgroup_parameter_symbol;
+ StructMemberList workgroup_parameter_members;
+ auto workgroup_param = [&] {
+ if (!workgroup_parameter_symbol.IsValid()) {
+ workgroup_parameter_symbol = ctx.dst->Sym();
+ }
+ return workgroup_parameter_symbol;
+ };
+
+ // If this function references any private variables, it needs to take the private
+ // variable struct as a parameter (or declare it, if it is an entry point function).
+ if (uses_privates.count(func_ast)) {
+ if (is_entry_point) {
+ // Create a local declaration for the private variable struct.
+ auto* var =
+ ctx.dst->Var(PrivateStructVariableName(), ctx.dst->ty(PrivateStructName()),
+ builtin::AddressSpace::kPrivate,
+ utils::Vector{
+ ctx.dst->Disable(DisabledValidation::kIgnoreAddressSpace),
+ });
+ ctx.InsertFront(func_ast->body->statements, ctx.dst->Decl(var));
+
+ // Initialize the members of that struct with the original initializers.
+ for (const auto& init : private_initializers) {
+ ctx.InsertFront(func_ast->body->statements, init());
+ }
+ } else {
+ // Create a parameter that is a pointer to the private variable struct.
+ auto ptr = ctx.dst->ty.ptr<private_>(ctx.dst->ty(PrivateStructName()));
+ auto* param = ctx.dst->Param(PrivateStructVariableName(), ptr);
+ ctx.InsertBack(func_ast->params, param);
+ }
+ }
+
+ // Process and redeclare all variables referenced by the function.
+ for (auto* var : func_sem->TransitivelyReferencedGlobals()) {
+ if (var->AddressSpace() == builtin::AddressSpace::kUndefined) {
+ continue;
+ }
+ if (var->AddressSpace() == builtin::AddressSpace::kPrivate) {
+ // Private variables are collected into a single struct that is passed by
+ // pointer (handled above), so we just need to replace the uses here.
+ ReplaceUsesInFunction(func_ast, var, PrivateStructVariableName(),
+ /* is_pointer */ !is_entry_point,
+ ctx.Clone(var->Declaration()->name->symbol));
+ continue;
+ }
+
+ // This is the symbol for the variable that replaces the module-scope var.
+ auto new_var_symbol = ctx.dst->Sym();
+
+ // Track whether the new variable is a pointer or not.
+ bool is_pointer = false;
+
+ // Track whether the new variable was wrapped in a struct or not.
+ bool is_wrapped = false;
+
+ // Process the variable to redeclare it as a parameter or local variable.
+ if (is_entry_point) {
+ ProcessVariableInEntryPoint(func_ast, var, new_var_symbol, workgroup_param,
+ workgroup_parameter_members, is_pointer,
+ is_wrapped);
+ } else {
+ ProcessVariableInUserFunction(func_ast, var, new_var_symbol, is_pointer);
+ if (var->AddressSpace() == builtin::AddressSpace::kWorkgroup) {
+ needs_pointer_aliasing = true;
+ }
+ }
+
+ // Record the replacement symbol.
+ var_to_newvar[var] = {new_var_symbol, is_pointer, is_wrapped};
+
+ // Replace all uses of the module-scope variable.
+ ReplaceUsesInFunction(
+ func_ast, var, new_var_symbol, is_pointer,
+ is_wrapped ? ctx.dst->Sym(kWrappedArrayMemberName) : Symbol());
+ }
+
+ // Allow pointer aliasing if needed.
+ if (needs_pointer_aliasing) {
+ ctx.InsertBack(func_ast->attributes,
+ ctx.dst->Disable(DisabledValidation::kIgnorePointerAliasing));
+ }
+
+ if (!workgroup_parameter_members.IsEmpty()) {
+ // Create the workgroup memory parameter.
+ // The parameter is a struct that contains members for each workgroup variable.
+ auto* str =
+ ctx.dst->Structure(ctx.dst->Sym(), std::move(workgroup_parameter_members));
+ auto param_type = ctx.dst->ty.ptr(workgroup, ctx.dst->ty.Of(str));
+ auto* param =
+ ctx.dst->Param(workgroup_param(), param_type,
+ utils::Vector{
+ ctx.dst->Disable(DisabledValidation::kEntryPointParameter),
+ ctx.dst->Disable(DisabledValidation::kIgnoreAddressSpace),
+ });
+ ctx.InsertFront(func_ast->params, param);
+ }
+
+ // Pass the variables as pointers to any functions that need them.
+ for (auto* call : calls_to_replace[func_ast]) {
+ auto* call_sem = ctx.src->Sem().Get(call)->Unwrap()->As<sem::Call>();
+ auto* target_sem = call_sem->Target()->As<sem::Function>();
+
+ // Pass the private variable struct pointer if needed.
+ if (uses_privates.count(target_sem->Declaration())) {
+ const Expression* arg = ctx.dst->Expr(PrivateStructVariableName());
+ if (is_entry_point) {
+ arg = ctx.dst->AddressOf(arg);
+ }
+ ctx.InsertBack(call->args, arg);
+ }
+
+ // Add new arguments for any variables that are needed by the callee.
+ // For entry points, pass non-handle types as pointers.
+ for (auto* target_var : target_sem->TransitivelyReferencedGlobals()) {
+ auto sc = target_var->AddressSpace();
+ if (sc == builtin::AddressSpace::kUndefined) {
+ continue;
+ }
+
+ auto it = var_to_newvar.find(target_var);
+ if (it == var_to_newvar.end()) {
+ // No replacement was created for this function.
+ continue;
+ }
+
+ auto new_var = it->second;
+ bool is_handle = target_var->Type()->UnwrapRef()->is_handle();
+ const Expression* arg = ctx.dst->Expr(new_var.symbol);
+ if (new_var.is_wrapped) {
+ // The variable is wrapped in a struct, so we need to pass a pointer to the
+ // struct member instead.
+ arg = ctx.dst->AddressOf(
+ ctx.dst->MemberAccessor(ctx.dst->Deref(arg), kWrappedArrayMemberName));
+ } else if (is_entry_point && !is_handle && !new_var.is_pointer) {
+ // We need to pass a pointer and we don't already have one, so take
+ // the address of the new variable.
+ arg = ctx.dst->AddressOf(arg);
+ }
+ ctx.InsertBack(call->args, arg);
+ }
+ }
+ }
+
+ // Now remove all module-scope variables with these address spaces.
+ for (auto* var_ast : ctx.src->AST().GlobalVariables()) {
+ auto* var_sem = ctx.src->Sem().Get(var_ast);
+ if (var_sem->AddressSpace() != builtin::AddressSpace::kUndefined) {
+ ctx.Remove(ctx.src->AST().GlobalDeclarations(), var_ast);
+ }
+ }
+ }
+
+ /// @returns the name of the structure that contains all of the module-scope private variables (created on first call)
+ Symbol PrivateStructName() {
+ if (!private_struct_name.IsValid()) {
+ private_struct_name = ctx.dst->Symbols().New("tint_private_vars_struct");
+ }
+ return private_struct_name;  // The same symbol is returned on every later call.
+ }
+
+ /// @returns the name of the variable that contains all of the module-scope private variables (created on first call)
+ Symbol PrivateStructVariableName() {
+ if (!private_struct_variable_name.IsValid()) {
+ private_struct_variable_name = ctx.dst->Symbols().New("tint_private_vars");
+ }
+ return private_struct_variable_name;  // The same symbol is returned on every later call.
+ }
+
+ private:
+ // The structures that have already been cloned by this transform.
+ std::unordered_set<const sem::Struct*> cloned_structs_;
+
+ // Map from identifier expression to the address-of expression that uses it.
+ std::unordered_map<const IdentifierExpression*, const UnaryOpExpression*> ident_to_address_of_;
+
+ // The name of the structure that contains all the module-scope private variables.
+ Symbol private_struct_name;
+
+ // The name of the structure variable that contains all the module-scope private variables.
+ Symbol private_struct_variable_name;
+};
+
+ModuleScopeVarToEntryPointParam::ModuleScopeVarToEntryPointParam() = default;
+
+ModuleScopeVarToEntryPointParam::~ModuleScopeVarToEntryPointParam() = default;
+
+Transform::ApplyResult ModuleScopeVarToEntryPointParam::Apply(const Program* src,
+ const DataMap&,
+ DataMap&) const {
+ if (!ShouldRun(src)) {  // No module-scope variables, so there is nothing to rewrite.
+ return SkipTransform;
+ }
+
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+ State state{ctx};
+ state.Process();  // Registers replacements, insertions and removals on ctx.
+
+ ctx.Clone();  // presumably clones src into b with those registered edits applied - confirm
+ return Program(std::move(b));
+}
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/module_scope_var_to_entry_point_param.h b/src/tint/lang/wgsl/ast/transform/module_scope_var_to_entry_point_param.h
new file mode 100644
index 0000000..4136aa6
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/module_scope_var_to_entry_point_param.h
@@ -0,0 +1,83 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_MODULE_SCOPE_VAR_TO_ENTRY_POINT_PARAM_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_MODULE_SCOPE_VAR_TO_ENTRY_POINT_PARAM_H_
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+
+namespace tint::ast::transform {
+
+/// Move module-scope variables into the entry point as parameters.
+///
+/// MSL does not allow module-scope variables to have any address space other
+/// than `constant`. This transform moves all module-scope declarations into the
+/// entry point function (either as parameters or function-scope variables) and
+/// then passes them as pointer parameters to any function that references them.
+///
+/// Since WGSL does not allow entry point parameters or function-scope variables
+/// to have these address spaces, we annotate the new variable declarations
+/// with an attribute that bypasses that validation rule.
+///
+/// Before:
+/// ```
+/// struct S {
+/// f : f32,
+/// }
+/// @binding(0) @group(0)
+/// var<storage, read> s : S;
+/// var<private> p : f32 = 2.0;
+///
+/// fn foo() {
+/// p = p + s.f;
+/// }
+///
+/// @compute @workgroup_size(1)
+/// fn main() {
+/// foo();
+/// }
+/// ```
+///
+/// After (simplified - private variables are actually grouped into a single struct passed by pointer):
+/// ```
+/// fn foo(p : ptr<private, f32>, sptr : ptr<storage, S, read>) {
+/// *p = *p + (*sptr).f;
+/// }
+///
+/// @compute @workgroup_size(1)
+/// fn main(sptr : ptr<storage, S, read>) {
+/// var<private> p : f32 = 2.0;
+/// foo(&p, sptr);
+/// }
+/// ```
+class ModuleScopeVarToEntryPointParam final
+ : public utils::Castable<ModuleScopeVarToEntryPointParam, Transform> {
+ public:
+ /// Constructor
+ ModuleScopeVarToEntryPointParam();
+ /// Destructor
+ ~ModuleScopeVarToEntryPointParam() override;
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+
+ private:
+ struct State;  // PIMPL state, defined in the .cc file.
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_MODULE_SCOPE_VAR_TO_ENTRY_POINT_PARAM_H_
diff --git a/src/tint/lang/wgsl/ast/transform/module_scope_var_to_entry_point_param_test.cc b/src/tint/lang/wgsl/ast/transform/module_scope_var_to_entry_point_param_test.cc
new file mode 100644
index 0000000..2689bb0
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/module_scope_var_to_entry_point_param_test.cc
@@ -0,0 +1,1372 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/module_scope_var_to_entry_point_param.h"
+
+#include <utility>
+
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using ModuleScopeVarToEntryPointParamTest = TransformTest;
+
+TEST_F(ModuleScopeVarToEntryPointParamTest, ShouldRunEmptyModule) {
+ auto* src = R"()";
+
+ EXPECT_FALSE(ShouldRun<ModuleScopeVarToEntryPointParam>(src));
+}
+
+TEST_F(ModuleScopeVarToEntryPointParamTest, ShouldRunHasGlobal) {
+ auto* src = R"(
+var<private> v : i32;
+)";
+
+ EXPECT_TRUE(ShouldRun<ModuleScopeVarToEntryPointParam>(src));
+}
+
+TEST_F(ModuleScopeVarToEntryPointParamTest, Basic) {
+ auto* src = R"(
+var<private> p : f32;
+var<workgroup> w : f32;
+
+@compute @workgroup_size(1)
+fn main() {
+ w = p;
+}
+)";
+
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct tint_private_vars_struct {
+ p : f32,
+}
+
+@compute @workgroup_size(1)
+fn main() {
+ @internal(disable_validation__ignore_address_space) var<private> tint_private_vars : tint_private_vars_struct;
+ @internal(disable_validation__ignore_address_space) var<workgroup> tint_symbol : f32;
+ tint_symbol = tint_private_vars.p;
+}
+)";
+
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ModuleScopeVarToEntryPointParamTest, Basic_OutOfOrder) {
+ auto* src = R"(
+@compute @workgroup_size(1)
+fn main() {
+ w = p;
+}
+
+var<workgroup> w : f32;
+var<private> p : f32;
+)";
+
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct tint_private_vars_struct {
+ p : f32,
+}
+
+@compute @workgroup_size(1)
+fn main() {
+ @internal(disable_validation__ignore_address_space) var<private> tint_private_vars : tint_private_vars_struct;
+ @internal(disable_validation__ignore_address_space) var<workgroup> tint_symbol : f32;
+ tint_symbol = tint_private_vars.p;
+}
+)";
+
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ModuleScopeVarToEntryPointParamTest, FunctionCalls) {
+ auto* src = R"(
+var<private> p : f32;
+var<workgroup> w : f32;
+
+fn no_uses() {
+}
+
+fn zoo() {
+ p = p * 2.0;
+}
+
+fn bar(a : f32, b : f32) {
+ p = a;
+ w = b;
+ zoo();
+}
+
+fn foo(a : f32) {
+ let b : f32 = 2.0;
+ bar(a, b);
+ no_uses();
+}
+
+@compute @workgroup_size(1)
+fn main() {
+ foo(1.0);
+}
+)";
+
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct tint_private_vars_struct {
+ p : f32,
+}
+
+fn no_uses() {
+}
+
+fn zoo(tint_private_vars : ptr<private, tint_private_vars_struct>) {
+ (*(tint_private_vars)).p = ((*(tint_private_vars)).p * 2.0);
+}
+
+@internal(disable_validation__ignore_pointer_aliasing)
+fn bar(a : f32, b : f32, tint_private_vars : ptr<private, tint_private_vars_struct>, @internal(disable_validation__ignore_address_space) @internal(disable_validation__ignore_invalid_pointer_argument) tint_symbol : ptr<workgroup, f32>) {
+ (*(tint_private_vars)).p = a;
+ *(tint_symbol) = b;
+ zoo(tint_private_vars);
+}
+
+@internal(disable_validation__ignore_pointer_aliasing)
+fn foo(a : f32, tint_private_vars : ptr<private, tint_private_vars_struct>, @internal(disable_validation__ignore_address_space) @internal(disable_validation__ignore_invalid_pointer_argument) tint_symbol_1 : ptr<workgroup, f32>) {
+ let b : f32 = 2.0;
+ bar(a, b, tint_private_vars, tint_symbol_1);
+ no_uses();
+}
+
+@compute @workgroup_size(1)
+fn main() {
+ @internal(disable_validation__ignore_address_space) var<private> tint_private_vars : tint_private_vars_struct;
+ @internal(disable_validation__ignore_address_space) var<workgroup> tint_symbol_2 : f32;
+ foo(1.0, &(tint_private_vars), &(tint_symbol_2));
+}
+)";
+
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ModuleScopeVarToEntryPointParamTest, FunctionCalls_OutOfOrder) {
+ auto* src = R"(
+@compute @workgroup_size(1)
+fn main() {
+ foo(1.0);
+}
+
+fn foo(a : f32) {
+ let b : f32 = 2.0;
+ bar(a, b);
+ no_uses();
+}
+
+fn no_uses() {
+}
+
+fn bar(a : f32, b : f32) {
+ p = a;
+ w = b;
+ zoo();
+}
+
+fn zoo() {
+ p = p * 2.0;
+}
+
+var<private> p : f32;
+var<workgroup> w : f32;
+)";
+
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct tint_private_vars_struct {
+ p : f32,
+}
+
+@compute @workgroup_size(1)
+fn main() {
+ @internal(disable_validation__ignore_address_space) var<private> tint_private_vars : tint_private_vars_struct;
+ @internal(disable_validation__ignore_address_space) var<workgroup> tint_symbol_2 : f32;
+ foo(1.0, &(tint_private_vars), &(tint_symbol_2));
+}
+
+@internal(disable_validation__ignore_pointer_aliasing)
+fn foo(a : f32, tint_private_vars : ptr<private, tint_private_vars_struct>, @internal(disable_validation__ignore_address_space) @internal(disable_validation__ignore_invalid_pointer_argument) tint_symbol_1 : ptr<workgroup, f32>) {
+ let b : f32 = 2.0;
+ bar(a, b, tint_private_vars, tint_symbol_1);
+ no_uses();
+}
+
+fn no_uses() {
+}
+
+@internal(disable_validation__ignore_pointer_aliasing)
+fn bar(a : f32, b : f32, tint_private_vars : ptr<private, tint_private_vars_struct>, @internal(disable_validation__ignore_address_space) @internal(disable_validation__ignore_invalid_pointer_argument) tint_symbol : ptr<workgroup, f32>) {
+ (*(tint_private_vars)).p = a;
+ *(tint_symbol) = b;
+ zoo(tint_private_vars);
+}
+
+fn zoo(tint_private_vars : ptr<private, tint_private_vars_struct>) {
+ (*(tint_private_vars)).p = ((*(tint_private_vars)).p * 2.0);
+}
+)";
+
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ModuleScopeVarToEntryPointParamTest, Initializers) {
+ auto* src = R"(
+var<private> a : f32 = 1.0;
+var<private> b : f32 = f32();
+
+@compute @workgroup_size(1)
+fn main() {
+ let x : f32 = a + b;
+}
+)";
+
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct tint_private_vars_struct {
+ a : f32,
+ b : f32,
+}
+
+@compute @workgroup_size(1)
+fn main() {
+ @internal(disable_validation__ignore_address_space) var<private> tint_private_vars : tint_private_vars_struct;
+ tint_private_vars.a = 1.0;
+ tint_private_vars.b = f32();
+ let x : f32 = (tint_private_vars.a + tint_private_vars.b);
+}
+)";
+
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ModuleScopeVarToEntryPointParamTest, Initializers_OutOfOrder) {
+ auto* src = R"(
+@compute @workgroup_size(1)
+fn main() {
+ let x : f32 = a + b;
+}
+
+var<private> b : f32 = f32();
+var<private> a : f32 = 1.0;
+)";
+
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct tint_private_vars_struct {
+ a : f32,
+ b : f32,
+}
+
+@compute @workgroup_size(1)
+fn main() {
+ @internal(disable_validation__ignore_address_space) var<private> tint_private_vars : tint_private_vars_struct;
+ tint_private_vars.a = 1.0;
+ tint_private_vars.b = f32();
+ let x : f32 = (tint_private_vars.a + tint_private_vars.b);
+}
+)";
+
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ModuleScopeVarToEntryPointParamTest, Pointers) {
+ auto* src = R"(
+var<private> p : f32;
+var<workgroup> w : f32;
+
+@compute @workgroup_size(1)
+fn main() {
+ let p_ptr : ptr<private, f32> = &p;
+ let w_ptr : ptr<workgroup, f32> = &w;
+ let x : f32 = *p_ptr + *w_ptr;
+ *p_ptr = x;
+}
+)";
+
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct tint_private_vars_struct {
+ p : f32,
+}
+
+@compute @workgroup_size(1)
+fn main() {
+ @internal(disable_validation__ignore_address_space) var<private> tint_private_vars : tint_private_vars_struct;
+ @internal(disable_validation__ignore_address_space) var<workgroup> tint_symbol : f32;
+ let p_ptr : ptr<private, f32> = &(tint_private_vars.p);
+ let w_ptr : ptr<workgroup, f32> = &(tint_symbol);
+ let x : f32 = (*(p_ptr) + *(w_ptr));
+ *(p_ptr) = x;
+}
+)";
+
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ModuleScopeVarToEntryPointParamTest, Pointers_OutOfOrder) {
+ auto* src = R"(
+@compute @workgroup_size(1)
+fn main() {
+ let p_ptr : ptr<private, f32> = &p;
+ let w_ptr : ptr<workgroup, f32> = &w;
+ let x : f32 = *p_ptr + *w_ptr;
+ *p_ptr = x;
+}
+
+var<workgroup> w : f32;
+var<private> p : f32;
+)";
+
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct tint_private_vars_struct {
+ p : f32,
+}
+
+@compute @workgroup_size(1)
+fn main() {
+ @internal(disable_validation__ignore_address_space) var<private> tint_private_vars : tint_private_vars_struct;
+ @internal(disable_validation__ignore_address_space) var<workgroup> tint_symbol : f32;
+ let p_ptr : ptr<private, f32> = &(tint_private_vars.p);
+ let w_ptr : ptr<workgroup, f32> = &(tint_symbol);
+ let x : f32 = (*(p_ptr) + *(w_ptr));
+ *(p_ptr) = x;
+}
+)";
+
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+// TODO(crbug.com/tint/1758): Requires support for workgroup pointer parameters, which is
+// unsupported until WGSL 1.1
+TEST_F(ModuleScopeVarToEntryPointParamTest, DISABLED_FoldAddressOfDeref) {
+ auto* src = R"(
+var<workgroup> v : f32;
+
+fn bar(p : ptr<workgroup, f32>) {
+ (*p) = 0.0;
+}
+
+fn foo() {
+ bar(&v);
+}
+
+@compute @workgroup_size(1)
+fn main() {
+ foo();
+}
+)";
+
+ auto* expect = R"(
+fn bar(p : ptr<workgroup, f32>) {
+ *(p) = 0.0;
+}
+
+fn foo(@internal(disable_validation__ignore_address_space) @internal(disable_validation__ignore_invalid_pointer_argument) tint_symbol : ptr<workgroup, f32>) {
+ bar(tint_symbol);
+}
+
+@compute @workgroup_size(1)
+fn main() {
+ @internal(disable_validation__ignore_address_space) var<workgroup> tint_symbol_1 : f32;
+ foo(&(tint_symbol_1));
+}
+)";
+
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+// TODO(crbug.com/tint/1758): Requires support for workgroup pointer parameters, which is
+// unsupported until WGSL 1.1
+TEST_F(ModuleScopeVarToEntryPointParamTest, DISABLED_FoldAddressOfDeref_OutOfOrder) {
+ auto* src = R"(
+@compute @workgroup_size(1)
+fn main() {
+ foo();
+}
+
+fn foo() {
+ bar(&v);
+}
+
+fn bar(p : ptr<workgroup, f32>) {
+ (*p) = 0.0;
+}
+
+var<workgroup> v : f32;
+)";
+
+ auto* expect = R"(
+@compute @workgroup_size(1)
+fn main() {
+ @internal(disable_validation__ignore_address_space) var<workgroup> tint_symbol_1 : f32;
+ foo(&(tint_symbol_1));
+}
+
+fn foo(@internal(disable_validation__ignore_address_space) @internal(disable_validation__ignore_invalid_pointer_argument) tint_symbol : ptr<workgroup, f32>) {
+ bar(tint_symbol);
+}
+
+fn bar(p : ptr<workgroup, f32>) {
+ *(p) = 0.0;
+}
+)";
+
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ModuleScopeVarToEntryPointParamTest, Buffers_Basic) {
+ auto* src = R"(
+struct S {
+ a : f32,
+};
+
+@group(0) @binding(0)
+var<uniform> u : S;
+@group(0) @binding(1)
+var<storage> s : S;
+
+@compute @workgroup_size(1)
+fn main() {
+ _ = u;
+ _ = s;
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ a : f32,
+}
+
+@compute @workgroup_size(1)
+fn main(@group(0) @binding(0) @internal(disable_validation__entry_point_parameter) @internal(disable_validation__ignore_address_space) tint_symbol : ptr<uniform, S>, @group(0) @binding(1) @internal(disable_validation__entry_point_parameter) @internal(disable_validation__ignore_address_space) tint_symbol_1 : ptr<storage, S, read>) {
+ _ = *(tint_symbol);
+ _ = *(tint_symbol_1);
+}
+)";
+
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ModuleScopeVarToEntryPointParamTest, Buffers_Basic_OutOfOrder) {
+ auto* src = R"(
+@compute @workgroup_size(1)
+fn main() {
+ _ = u;
+ _ = s;
+}
+
+@group(0) @binding(0) var<uniform> u : S;
+@group(0) @binding(1) var<storage> s : S;
+
+struct S {
+ a : f32,
+};
+
+)";
+
+ auto* expect = R"(
+@compute @workgroup_size(1)
+fn main(@group(0) @binding(0) @internal(disable_validation__entry_point_parameter) @internal(disable_validation__ignore_address_space) tint_symbol : ptr<uniform, S>, @group(0) @binding(1) @internal(disable_validation__entry_point_parameter) @internal(disable_validation__ignore_address_space) tint_symbol_1 : ptr<storage, S, read>) {
+ _ = *(tint_symbol);
+ _ = *(tint_symbol_1);
+}
+
+struct S {
+ a : f32,
+}
+)";
+
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ModuleScopeVarToEntryPointParamTest, Buffer_RuntimeArray) {
+ auto* src = R"(
+@group(0) @binding(0)
+var<storage> buffer : array<f32>;
+
+@compute @workgroup_size(1)
+fn main() {
+ _ = buffer[0];
+}
+)";
+
+ auto* expect = R"(
+struct tint_symbol_1 {
+ arr : array<f32>,
+}
+
+@compute @workgroup_size(1)
+fn main(@group(0) @binding(0) @internal(disable_validation__entry_point_parameter) @internal(disable_validation__ignore_address_space) tint_symbol : ptr<storage, tint_symbol_1, read>) {
+ _ = (*(tint_symbol)).arr[0];
+}
+)";
+
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ModuleScopeVarToEntryPointParamTest, Buffer_RuntimeArray_OutOfOrder) {
+ auto* src = R"(
+@compute @workgroup_size(1)
+fn main() {
+ _ = buffer[0];
+}
+
+@group(0) @binding(0)
+var<storage> buffer : array<f32>;
+)";
+
+ auto* expect = R"(
+struct tint_symbol_1 {
+ arr : array<f32>,
+}
+
+@compute @workgroup_size(1)
+fn main(@group(0) @binding(0) @internal(disable_validation__entry_point_parameter) @internal(disable_validation__ignore_address_space) tint_symbol : ptr<storage, tint_symbol_1, read>) {
+ _ = (*(tint_symbol)).arr[0];
+}
+)";
+
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ModuleScopeVarToEntryPointParamTest, Buffer_RuntimeArrayInsideFunction) {
+ auto* src = R"(
+@group(0) @binding(0)
+var<storage> buffer : array<f32>;
+
+fn foo() {
+ _ = buffer[0];
+}
+
+@compute @workgroup_size(1)
+fn main() {
+ foo();
+}
+)";
+
+ auto* expect = R"(
+struct tint_symbol_2 {
+ arr : array<f32>,
+}
+
+fn foo(@internal(disable_validation__ignore_address_space) @internal(disable_validation__ignore_invalid_pointer_argument) tint_symbol : ptr<storage, array<f32>, read>) {
+ _ = (*(tint_symbol))[0];
+}
+
+@compute @workgroup_size(1)
+fn main(@group(0) @binding(0) @internal(disable_validation__entry_point_parameter) @internal(disable_validation__ignore_address_space) tint_symbol_1 : ptr<storage, tint_symbol_2, read>) {
+ foo(&((*(tint_symbol_1)).arr));
+}
+)";
+
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ModuleScopeVarToEntryPointParamTest, Buffer_RuntimeArrayInsideFunction_OutOfOrder) {
+ auto* src = R"(
+@compute @workgroup_size(1)
+fn main() {
+ foo();
+}
+
+fn foo() {
+ _ = buffer[0];
+}
+
+@group(0) @binding(0) var<storage> buffer : array<f32>;
+)";
+
+ auto* expect = R"(
+struct tint_symbol_2 {
+ arr : array<f32>,
+}
+
+@compute @workgroup_size(1)
+fn main(@group(0) @binding(0) @internal(disable_validation__entry_point_parameter) @internal(disable_validation__ignore_address_space) tint_symbol_1 : ptr<storage, tint_symbol_2, read>) {
+ foo(&((*(tint_symbol_1)).arr));
+}
+
+fn foo(@internal(disable_validation__ignore_address_space) @internal(disable_validation__ignore_invalid_pointer_argument) tint_symbol : ptr<storage, array<f32>, read>) {
+ _ = (*(tint_symbol))[0];
+}
+)";
+
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ModuleScopeVarToEntryPointParamTest, Buffer_RuntimeArray_Alias) {
+ auto* src = R"(
+alias myarray = array<f32>;
+
+@group(0) @binding(0)
+var<storage> buffer : myarray;
+
+@compute @workgroup_size(1)
+fn main() {
+ _ = buffer[0];
+}
+)";
+
+ auto* expect = R"(
+struct tint_symbol_1 {
+ arr : array<f32>,
+}
+
+alias myarray = array<f32>;
+
+@compute @workgroup_size(1)
+fn main(@group(0) @binding(0) @internal(disable_validation__entry_point_parameter) @internal(disable_validation__ignore_address_space) tint_symbol : ptr<storage, tint_symbol_1, read>) {
+ _ = (*(tint_symbol)).arr[0];
+}
+)";
+
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ModuleScopeVarToEntryPointParamTest, Buffer_RuntimeArray_Alias_OutOfOrder) {
+ auto* src = R"(
+@compute @workgroup_size(1)
+fn main() {
+ _ = buffer[0];
+}
+
+@group(0) @binding(0) var<storage> buffer : myarray;
+
+alias myarray = array<f32>;
+)";
+
+ auto* expect = R"(
+struct tint_symbol_1 {
+ arr : array<f32>,
+}
+
+@compute @workgroup_size(1)
+fn main(@group(0) @binding(0) @internal(disable_validation__entry_point_parameter) @internal(disable_validation__ignore_address_space) tint_symbol : ptr<storage, tint_symbol_1, read>) {
+ _ = (*(tint_symbol)).arr[0];
+}
+
+alias myarray = array<f32>;
+)";
+
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ModuleScopeVarToEntryPointParamTest, Buffer_ArrayOfStruct) {
+ auto* src = R"(
+struct S {
+ f : f32,
+};
+
+@group(0) @binding(0)
+var<storage> buffer : array<S>;
+
+@compute @workgroup_size(1)
+fn main() {
+ _ = buffer[0];
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ f : f32,
+}
+
+struct tint_symbol_1 {
+ arr : array<S>,
+}
+
+@compute @workgroup_size(1)
+fn main(@group(0) @binding(0) @internal(disable_validation__entry_point_parameter) @internal(disable_validation__ignore_address_space) tint_symbol : ptr<storage, tint_symbol_1, read>) {
+ _ = (*(tint_symbol)).arr[0];
+}
+)";
+
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ModuleScopeVarToEntryPointParamTest, Buffer_ArrayOfStruct_OutOfOrder) {
+ auto* src = R"(
+@compute @workgroup_size(1)
+fn main() {
+ _ = buffer[0];
+}
+
+@group(0) @binding(0) var<storage> buffer : array<S>;
+
+struct S {
+ f : f32,
+};
+)";
+
+ auto* expect = R"(
+struct S {
+ f : f32,
+}
+
+struct tint_symbol_1 {
+ arr : array<S>,
+}
+
+@compute @workgroup_size(1)
+fn main(@group(0) @binding(0) @internal(disable_validation__entry_point_parameter) @internal(disable_validation__ignore_address_space) tint_symbol : ptr<storage, tint_symbol_1, read>) {
+ _ = (*(tint_symbol)).arr[0];
+}
+)";
+
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ModuleScopeVarToEntryPointParamTest, Buffers_FunctionCalls) {
+ auto* src = R"(
+struct S {
+ a : f32,
+};
+
+@group(0) @binding(0)
+var<uniform> u : S;
+@group(0) @binding(1)
+var<storage> s : S;
+
+fn no_uses() {
+}
+
+fn bar(a : f32, b : f32) {
+ _ = u;
+ _ = s;
+}
+
+fn foo(a : f32) {
+ let b : f32 = 2.0;
+ _ = u;
+ bar(a, b);
+ no_uses();
+}
+
+@compute @workgroup_size(1)
+fn main() {
+ foo(1.0);
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ a : f32,
+}
+
+fn no_uses() {
+}
+
+fn bar(a : f32, b : f32, @internal(disable_validation__ignore_address_space) @internal(disable_validation__ignore_invalid_pointer_argument) tint_symbol : ptr<uniform, S>, @internal(disable_validation__ignore_address_space) @internal(disable_validation__ignore_invalid_pointer_argument) tint_symbol_1 : ptr<storage, S, read>) {
+ _ = *(tint_symbol);
+ _ = *(tint_symbol_1);
+}
+
+fn foo(a : f32, @internal(disable_validation__ignore_address_space) @internal(disable_validation__ignore_invalid_pointer_argument) tint_symbol_2 : ptr<uniform, S>, @internal(disable_validation__ignore_address_space) @internal(disable_validation__ignore_invalid_pointer_argument) tint_symbol_3 : ptr<storage, S, read>) {
+ let b : f32 = 2.0;
+ _ = *(tint_symbol_2);
+ bar(a, b, tint_symbol_2, tint_symbol_3);
+ no_uses();
+}
+
+@compute @workgroup_size(1)
+fn main(@group(0) @binding(0) @internal(disable_validation__entry_point_parameter) @internal(disable_validation__ignore_address_space) tint_symbol_4 : ptr<uniform, S>, @group(0) @binding(1) @internal(disable_validation__entry_point_parameter) @internal(disable_validation__ignore_address_space) tint_symbol_5 : ptr<storage, S, read>) {
+ foo(1.0, tint_symbol_4, tint_symbol_5);
+}
+)";
+
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ModuleScopeVarToEntryPointParamTest, Buffers_FunctionCalls_OutOfOrder) {
+ auto* src = R"(
+@compute @workgroup_size(1)
+fn main() {
+ foo(1.0);
+}
+
+fn foo(a : f32) {
+ let b : f32 = 2.0;
+ _ = u;
+ bar(a, b);
+ no_uses();
+}
+
+fn no_uses() {
+}
+
+fn bar(a : f32, b : f32) {
+ _ = u;
+ _ = s;
+}
+
+struct S {
+ a : f32,
+};
+
+@group(0) @binding(0)
+var<uniform> u : S;
+@group(0) @binding(1)
+var<storage> s : S;
+)";
+
+ auto* expect = R"(
+@compute @workgroup_size(1)
+fn main(@group(0) @binding(0) @internal(disable_validation__entry_point_parameter) @internal(disable_validation__ignore_address_space) tint_symbol_4 : ptr<uniform, S>, @group(0) @binding(1) @internal(disable_validation__entry_point_parameter) @internal(disable_validation__ignore_address_space) tint_symbol_5 : ptr<storage, S, read>) {
+ foo(1.0, tint_symbol_4, tint_symbol_5);
+}
+
+fn foo(a : f32, @internal(disable_validation__ignore_address_space) @internal(disable_validation__ignore_invalid_pointer_argument) tint_symbol_2 : ptr<uniform, S>, @internal(disable_validation__ignore_address_space) @internal(disable_validation__ignore_invalid_pointer_argument) tint_symbol_3 : ptr<storage, S, read>) {
+ let b : f32 = 2.0;
+ _ = *(tint_symbol_2);
+ bar(a, b, tint_symbol_2, tint_symbol_3);
+ no_uses();
+}
+
+fn no_uses() {
+}
+
+fn bar(a : f32, b : f32, @internal(disable_validation__ignore_address_space) @internal(disable_validation__ignore_invalid_pointer_argument) tint_symbol : ptr<uniform, S>, @internal(disable_validation__ignore_address_space) @internal(disable_validation__ignore_invalid_pointer_argument) tint_symbol_1 : ptr<storage, S, read>) {
+ _ = *(tint_symbol);
+ _ = *(tint_symbol_1);
+}
+
+struct S {
+ a : f32,
+}
+)";
+
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ModuleScopeVarToEntryPointParamTest, HandleTypes_Basic) {
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+@group(0) @binding(1) var s : sampler;
+
+@compute @workgroup_size(1)
+fn main() {
+ _ = t;
+ _ = s;
+}
+)";
+
+ auto* expect = R"(
+@compute @workgroup_size(1)
+fn main(@group(0) @binding(0) @internal(disable_validation__entry_point_parameter) tint_symbol : texture_2d<f32>, @group(0) @binding(1) @internal(disable_validation__entry_point_parameter) tint_symbol_1 : sampler) {
+ _ = tint_symbol;
+ _ = tint_symbol_1;
+}
+)";
+
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ModuleScopeVarToEntryPointParamTest, HandleTypes_FunctionCalls) {
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+@group(0) @binding(1) var s : sampler;
+
+fn no_uses() {
+}
+
+fn bar(a : f32, b : f32) {
+ _ = t;
+ _ = s;
+}
+
+fn foo(a : f32) {
+ let b : f32 = 2.0;
+ _ = t;
+ bar(a, b);
+ no_uses();
+}
+
+@compute @workgroup_size(1)
+fn main() {
+ foo(1.0);
+}
+)";
+
+ auto* expect = R"(
+fn no_uses() {
+}
+
+fn bar(a : f32, b : f32, tint_symbol : texture_2d<f32>, tint_symbol_1 : sampler) {
+ _ = tint_symbol;
+ _ = tint_symbol_1;
+}
+
+fn foo(a : f32, tint_symbol_2 : texture_2d<f32>, tint_symbol_3 : sampler) {
+ let b : f32 = 2.0;
+ _ = tint_symbol_2;
+ bar(a, b, tint_symbol_2, tint_symbol_3);
+ no_uses();
+}
+
+@compute @workgroup_size(1)
+fn main(@group(0) @binding(0) @internal(disable_validation__entry_point_parameter) tint_symbol_4 : texture_2d<f32>, @group(0) @binding(1) @internal(disable_validation__entry_point_parameter) tint_symbol_5 : sampler) {
+ foo(1.0, tint_symbol_4, tint_symbol_5);
+}
+)";
+
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ModuleScopeVarToEntryPointParamTest, HandleTypes_FunctionCalls_OutOfOrder) {
+ auto* src = R"(
+@compute @workgroup_size(1)
+fn main() {
+ foo(1.0);
+}
+
+fn foo(a : f32) {
+ let b : f32 = 2.0;
+ _ = t;
+ bar(a, b);
+ no_uses();
+}
+
+fn no_uses() {
+}
+
+fn bar(a : f32, b : f32) {
+ _ = t;
+ _ = s;
+}
+
+@group(0) @binding(0) var t : texture_2d<f32>;
+@group(0) @binding(1) var s : sampler;
+)";
+
+ auto* expect = R"(
+@compute @workgroup_size(1)
+fn main(@group(0) @binding(0) @internal(disable_validation__entry_point_parameter) tint_symbol_4 : texture_2d<f32>, @group(0) @binding(1) @internal(disable_validation__entry_point_parameter) tint_symbol_5 : sampler) {
+ foo(1.0, tint_symbol_4, tint_symbol_5);
+}
+
+fn foo(a : f32, tint_symbol_2 : texture_2d<f32>, tint_symbol_3 : sampler) {
+ let b : f32 = 2.0;
+ _ = tint_symbol_2;
+ bar(a, b, tint_symbol_2, tint_symbol_3);
+ no_uses();
+}
+
+fn no_uses() {
+}
+
+fn bar(a : f32, b : f32, tint_symbol : texture_2d<f32>, tint_symbol_1 : sampler) {
+ _ = tint_symbol;
+ _ = tint_symbol_1;
+}
+)";
+
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ModuleScopeVarToEntryPointParamTest, Matrix) {
+ auto* src = R"(
+var<workgroup> m : mat2x2<f32>;
+
+@compute @workgroup_size(1)
+fn main() {
+ let x = m;
+}
+)";
+
+ auto* expect = R"(
+struct tint_symbol_2 {
+ m : mat2x2<f32>,
+}
+
+@compute @workgroup_size(1)
+fn main(@internal(disable_validation__entry_point_parameter) @internal(disable_validation__ignore_address_space) tint_symbol_1 : ptr<workgroup, tint_symbol_2>) {
+ let tint_symbol : ptr<workgroup, mat2x2<f32>> = &((*(tint_symbol_1)).m);
+ let x = *(tint_symbol);
+}
+)";
+
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ModuleScopeVarToEntryPointParamTest, NestedMatrix) {
+ auto* src = R"(
+struct S1 {
+ m : mat2x2<f32>,
+};
+struct S2 {
+ s : S1,
+};
+var<workgroup> m : array<S2, 4>;
+
+@compute @workgroup_size(1)
+fn main() {
+ let x = m;
+}
+)";
+
+ auto* expect = R"(
+struct S1 {
+ m : mat2x2<f32>,
+}
+
+struct S2 {
+ s : S1,
+}
+
+struct tint_symbol_2 {
+ m : array<S2, 4u>,
+}
+
+@compute @workgroup_size(1)
+fn main(@internal(disable_validation__entry_point_parameter) @internal(disable_validation__ignore_address_space) tint_symbol_1 : ptr<workgroup, tint_symbol_2>) {
+ let tint_symbol : ptr<workgroup, array<S2, 4u>> = &((*(tint_symbol_1)).m);
+ let x = *(tint_symbol);
+}
+)";
+
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+// Test that we do not duplicate a struct type used by multiple workgroup
+// variables that are promoted to threadgroup memory arguments.
+TEST_F(ModuleScopeVarToEntryPointParamTest, DuplicateThreadgroupArgumentTypes) {
+ auto* src = R"(
+struct S {
+ m : mat2x2<f32>,
+};
+
+var<workgroup> a : S;
+
+var<workgroup> b : S;
+
+@compute @workgroup_size(1)
+fn main() {
+ let x = a;
+ let y = b;
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ m : mat2x2<f32>,
+}
+
+struct tint_symbol_3 {
+ a : S,
+ b : S,
+}
+
+@compute @workgroup_size(1)
+fn main(@internal(disable_validation__entry_point_parameter) @internal(disable_validation__ignore_address_space) tint_symbol_1 : ptr<workgroup, tint_symbol_3>) {
+ let tint_symbol : ptr<workgroup, S> = &((*(tint_symbol_1)).a);
+ let tint_symbol_2 : ptr<workgroup, S> = &((*(tint_symbol_1)).b);
+ let x = *(tint_symbol);
+ let y = *(tint_symbol_2);
+}
+)";
+
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+// Test that we do not duplicate a struct type used by multiple workgroup variables that are
+// promoted to threadgroup memory arguments, when the declarations appear after their uses.
+TEST_F(ModuleScopeVarToEntryPointParamTest, DuplicateThreadgroupArgumentTypes_OutOfOrder) {
+ auto* src = R"(
+@compute @workgroup_size(1)
+fn main() {
+ let x = a;
+ let y = b;
+}
+
+var<workgroup> a : S;
+var<workgroup> b : S;
+
+struct S {
+ m : mat2x2<f32>,
+};
+)";
+
+ auto* expect = R"(
+struct S {
+ m : mat2x2<f32>,
+}
+
+struct tint_symbol_3 {
+ a : S,
+ b : S,
+}
+
+@compute @workgroup_size(1)
+fn main(@internal(disable_validation__entry_point_parameter) @internal(disable_validation__ignore_address_space) tint_symbol_1 : ptr<workgroup, tint_symbol_3>) {
+ let tint_symbol : ptr<workgroup, S> = &((*(tint_symbol_1)).a);
+ let tint_symbol_2 : ptr<workgroup, S> = &((*(tint_symbol_1)).b);
+ let x = *(tint_symbol);
+ let y = *(tint_symbol_2);
+}
+)";
+
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ModuleScopeVarToEntryPointParamTest, UnusedVariables) {
+ auto* src = R"(
+struct S {
+ a : f32,
+};
+
+var<private> p : f32;
+var<workgroup> w : f32;
+var<private> p_with_init : f32 = 42;
+
+@group(0) @binding(0)
+var<uniform> ub : S;
+@group(0) @binding(1)
+var<storage> sb : S;
+
+@group(0) @binding(2) var t : texture_2d<f32>;
+@group(0) @binding(3) var s : sampler;
+
+@compute @workgroup_size(1)
+fn main() {
+}
+)";
+
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct tint_private_vars_struct {
+ p : f32,
+ p_with_init : f32,
+}
+
+struct S {
+ a : f32,
+}
+
+@compute @workgroup_size(1)
+fn main() {
+}
+)";
+
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ModuleScopeVarToEntryPointParamTest, MultiplePrivateVariables) {
+ auto* src = R"(
+struct S {
+ a : f32,
+ b : f32,
+ c : f32,
+}
+
+var<private> a : f32;
+var<private> b : f32 = 42;
+var<private> c : S = S(1, 2, 3);
+var<private> d : S;
+var<private> unused : f32;
+
+fn foo(x : f32) -> f32 {
+ return (a + b + c.a + d.c) * x;
+}
+
+@compute @workgroup_size(1)
+fn main() {
+ _ = foo(1.0);
+}
+)";
+
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct S {
+ a : f32,
+ b : f32,
+ c : f32,
+}
+
+struct tint_private_vars_struct {
+ a : f32,
+ b : f32,
+ c : S,
+ d : S,
+ unused : f32,
+}
+
+fn foo(x : f32, tint_private_vars : ptr<private, tint_private_vars_struct>) -> f32 {
+ return (((((*(tint_private_vars)).a + (*(tint_private_vars)).b) + (*(tint_private_vars)).c.a) + (*(tint_private_vars)).d.c) * x);
+}
+
+@compute @workgroup_size(1)
+fn main() {
+ @internal(disable_validation__ignore_address_space) var<private> tint_private_vars : tint_private_vars_struct;
+ tint_private_vars.b = 42;
+ tint_private_vars.c = S(1, 2, 3);
+ _ = foo(1.0, &(tint_private_vars));
+}
+)";
+
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ModuleScopeVarToEntryPointParamTest, MultiplePrivateVariables_OutOfOrder) {
+ auto* src = R"(
+var<private> a : f32;
+var<private> c : S = S(1, 2, 3);
+var<private> unused : f32;
+
+@compute @workgroup_size(1)
+fn main() {
+ _ = foo(1.0);
+}
+
+fn foo(x : f32) -> f32 {
+ return (a + b + c.a + d.c) * x;
+}
+
+var<private> b : f32 = 42;
+
+struct S {
+ a : f32,
+ b : f32,
+ c : f32,
+}
+
+var<private> d : S;
+)";
+
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct S {
+ a : f32,
+ b : f32,
+ c : f32,
+}
+
+struct tint_private_vars_struct {
+ a : f32,
+ c : S,
+ unused : f32,
+ b : f32,
+ d : S,
+}
+
+@compute @workgroup_size(1)
+fn main() {
+ @internal(disable_validation__ignore_address_space) var<private> tint_private_vars : tint_private_vars_struct;
+ tint_private_vars.c = S(1, 2, 3);
+ tint_private_vars.b = 42;
+ _ = foo(1.0, &(tint_private_vars));
+}
+
+fn foo(x : f32, tint_private_vars : ptr<private, tint_private_vars_struct>) -> f32 {
+ return (((((*(tint_private_vars)).a + (*(tint_private_vars)).b) + (*(tint_private_vars)).c.a) + (*(tint_private_vars)).d.c) * x);
+}
+)";
+
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ModuleScopeVarToEntryPointParamTest, EmtpyModule) {
+ auto* src = "";
+
+ auto got = Run<ModuleScopeVarToEntryPointParam>(src);
+
+ EXPECT_EQ(src, str(got));
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/multiplanar_external_texture.cc b/src/tint/lang/wgsl/ast/transform/multiplanar_external_texture.cc
new file mode 100644
index 0000000..90f00a4
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/multiplanar_external_texture.cc
@@ -0,0 +1,529 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/multiplanar_external_texture.h"
+
+#include <string>
+#include <vector>
+
+#include "src/tint/lang/wgsl/ast/function.h"
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/call.h"
+#include "src/tint/sem/function.h"
+#include "src/tint/sem/variable.h"
+#include "src/tint/type/texture_dimension.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::MultiplanarExternalTexture);
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::MultiplanarExternalTexture::NewBindingPoints);
+
+namespace tint::ast::transform {
+namespace {
+
+using namespace tint::builtin::fluent_types; // NOLINT
+using namespace tint::number_suffixes; // NOLINT
+
+bool ShouldRun(const Program* program) {
+ auto ext = program->Types().Find<type::ExternalTexture>();
+ return ext != nullptr;
+}
+
+/// This struct stores symbols for new bindings created as a result of transforming a
+/// texture_external instance.
+struct NewBindingSymbols {
+ Symbol params;
+ Symbol plane_0;
+ Symbol plane_1;
+};
+} // namespace
+
+/// PIMPL state for the transform
+struct MultiplanarExternalTexture::State {
+ /// The clone context.
+ CloneContext& ctx;
+
+ /// ProgramBuilder for the context
+ ProgramBuilder& b;
+
+ /// Destination binding locations for the expanded texture_external provided
+ /// as input into the transform.
+ const NewBindingPoints* new_binding_points;
+
+ /// Symbol for the GammaTransferParams
+ Symbol gamma_transfer_struct_sym;
+
+ /// Symbol for the ExternalTextureParams struct
+ Symbol params_struct_sym;
+
+ /// Symbols for the generated textureLoadExternal functions, keyed by the call target
+ utils::Hashmap<const sem::CallTarget*, Symbol, 2> texture_load_external_fns;
+
+ /// Symbol for the textureSampleExternal function
+ Symbol texture_sample_external_sym;
+
+ /// Symbol for the textureSampleExternalDEPRECATED function
+ Symbol texture_sample_external_deprecated_sym;
+
+ /// Symbol for the gammaCorrection function
+ Symbol gamma_correction_sym;
+
+ /// Storage for new bindings that have been created corresponding to an original
+ /// texture_external binding.
+ std::unordered_map<const sem::Variable*, NewBindingSymbols> new_binding_symbols;
+
+ /// Constructor
+ /// @param context the clone context
+ /// @param newBindingPoints the input destination binding locations for the
+ /// expanded texture_external
+ State(CloneContext& context, const NewBindingPoints* newBindingPoints)
+ : ctx(context), b(*context.dst), new_binding_points(newBindingPoints) {}
+
+ /// Processes the module: expands texture_external globals, function parameters and calls.
+ void Process() {
+ auto& sem = ctx.src->Sem();
+
+ // For each texture_external binding, we replace it with a texture_2d<f32> binding and
+ // create two additional bindings (one texture_2d<f32> to represent the secondary plane and
+ // one uniform buffer for the ExternalTextureParams struct).
+ for (auto* global : ctx.src->AST().GlobalVariables()) {
+ auto* sem_var = sem.Get<sem::GlobalVariable>(global);
+ if (!sem_var->Type()->UnwrapRef()->Is<type::ExternalTexture>()) {
+ continue;
+ }
+
+ // If the attributes are empty, then this must be a texture_external passed as a
+ // function parameter. These variables are transformed elsewhere.
+ if (global->attributes.IsEmpty()) {
+ continue;
+ }
+
+ // If we find a texture_external binding, we know we must emit the ExternalTextureParams
+ // struct.
+ if (!params_struct_sym.IsValid()) {
+ createExtTexParamsStructs();
+ }
+
+ // The binding points for the newly introduced bindings must have been provided to this
+ // transform. We fetch the new binding points by providing the original texture_external
+ // binding points into the passed map.
+ sem::BindingPoint bp = *sem_var->BindingPoint();
+
+ BindingsMap::const_iterator it = new_binding_points->bindings_map.find(bp);
+ if (it == new_binding_points->bindings_map.end()) {
+ b.Diagnostics().add_error(
+ diag::System::Transform,
+ "missing new binding points for texture_external at binding {" +
+ std::to_string(bp.group) + "," + std::to_string(bp.binding) + "}");
+ continue;
+ }
+
+ BindingPoints bps = it->second;
+
+ // Symbols for the newly created bindings must be saved so they can be passed as
+ // parameters later. They are stored in a map keyed by the semantic variable of the
+ // texture_external binding. The params buffer is typed with params_struct_sym (not the
+ // literal struct name) as Symbols().New() may have renamed the struct to avoid a clash.
+ auto& syms = new_binding_symbols[sem_var];
+ syms.plane_0 = ctx.Clone(global->name->symbol);
+ syms.plane_1 = b.Symbols().New("ext_tex_plane_1");
+ b.GlobalVar(syms.plane_1, b.ty.sampled_texture(type::TextureDimension::k2d, b.ty.f32()),
+ b.Group(AInt(bps.plane_1.group)), b.Binding(AInt(bps.plane_1.binding)));
+ syms.params = b.Symbols().New("ext_tex_params");
+ b.GlobalVar(syms.params, b.ty(params_struct_sym), builtin::AddressSpace::kUniform,
+ b.Group(AInt(bps.params.group)), b.Binding(AInt(bps.params.binding)));
+
+ // Replace the original texture_external binding with a texture_2d<f32> binding.
+ auto cloned_attributes = ctx.Clone(global->attributes);
+ const Expression* cloned_initializer = ctx.Clone(global->initializer);
+
+ auto* replacement =
+ b.Var(syms.plane_0, b.ty.sampled_texture(type::TextureDimension::k2d, b.ty.f32()),
+ cloned_initializer, cloned_attributes);
+ ctx.Replace(global, replacement);
+ }
+
+ // We must update all the texture_external parameters for user declared functions.
+ for (auto* fn : ctx.src->AST().Functions()) {
+ for (const Variable* param : fn->params) {
+ if (auto* sem_var = sem.Get(param)) {
+ if (!sem_var->Type()->UnwrapRef()->Is<type::ExternalTexture>()) {
+ continue;
+ }
+ // If we find a texture_external, we must ensure the ExternalTextureParams
+ // struct exists.
+ if (!params_struct_sym.IsValid()) {
+ createExtTexParamsStructs();
+ }
+ // When a texture_external is found, we insert all components of the
+ // texture_external into the parameter list. We must also place the new symbols
+ // into the transform state so they can be used when transforming function
+ // calls.
+ auto& syms = new_binding_symbols[sem_var];
+ syms.plane_0 = ctx.Clone(param->name->symbol);
+ syms.plane_1 = b.Symbols().New("ext_tex_plane_1");
+ syms.params = b.Symbols().New("ext_tex_params");
+ auto tex2d_f32 = [&] {
+ return b.ty.sampled_texture(type::TextureDimension::k2d, b.ty.f32());
+ };
+ ctx.Replace(param, b.Param(syms.plane_0, tex2d_f32()));
+ ctx.InsertAfter(fn->params, param, b.Param(syms.plane_1, tex2d_f32()));
+ ctx.InsertAfter(fn->params, param,
+ b.Param(syms.params, b.ty(params_struct_sym)));
+ }
+ }
+ }
+
+ // Transform the external texture builtin calls into calls to the external texture
+ // functions.
+ ctx.ReplaceAll([&](const CallExpression* expr) -> const CallExpression* {
+ auto* call = sem.Get(expr)->UnwrapMaterialize()->As<sem::Call>();
+ auto* builtin = call->Target()->As<sem::Builtin>();
+
+ if (builtin && !builtin->Parameters().IsEmpty() &&
+ builtin->Parameters()[0]->Type()->Is<type::ExternalTexture>() &&
+ builtin->Type() != builtin::Function::kTextureDimensions) {
+ if (auto* var_user =
+ sem.GetVal(expr->args[0])->UnwrapLoad()->As<sem::VariableUser>()) {
+ auto it = new_binding_symbols.find(var_user->Variable());
+ if (it == new_binding_symbols.end()) {
+ // If valid new binding locations were not provided earlier, we would have
+ // been unable to create these symbols. An error message was emitted
+ // earlier, so just return early to avoid internal compiler errors and
+ // retain a clean error message.
+ return nullptr;
+ }
+ auto& syms = it->second;
+
+ switch (builtin->Type()) {
+ case builtin::Function::kTextureLoad:
+ return createTextureLoad(call, syms);
+ case builtin::Function::kTextureSampleBaseClampToEdge:
+ return createTextureSampleBaseClampToEdge(expr, syms);
+ default:
+ break;
+ }
+ }
+ } else if (call->Target()->Is<sem::Function>()) {
+ // The call expression may be to a user-defined function that contains a
+ // texture_external parameter. These need to be expanded out to multiple plane
+ // textures and the texture parameters structure.
+ for (auto* arg : expr->args) {
+ if (auto* var_user = sem.GetVal(arg)->UnwrapLoad()->As<sem::VariableUser>()) {
+ // Check if a parameter is a texture_external by trying to find
+ // it in the transform state.
+ auto it = new_binding_symbols.find(var_user->Variable());
+ if (it != new_binding_symbols.end()) {
+ auto& syms = it->second;
+ // When we find a texture_external, we must unpack it into its
+ // components.
+ ctx.Replace(arg, b.Expr(syms.plane_0));
+ ctx.InsertAfter(expr->args, arg, b.Expr(syms.plane_1));
+ ctx.InsertAfter(expr->args, arg, b.Expr(syms.params));
+ }
+ }
+ }
+ }
+
+ return nullptr;
+ });
+ }
+
+ /// Creates the GammaTransferParams and ExternalTextureParams structs used by the transform.
+ void createExtTexParamsStructs() {
+ // Create GammaTransferParams struct.
+ utils::Vector gamma_transfer_member_list{
+ b.Member("G", b.ty.f32()), b.Member("A", b.ty.f32()), b.Member("B", b.ty.f32()),
+ b.Member("C", b.ty.f32()), b.Member("D", b.ty.f32()), b.Member("E", b.ty.f32()),
+ b.Member("F", b.ty.f32()), b.Member("padding", b.ty.u32())};
+
+ gamma_transfer_struct_sym = b.Symbols().New("GammaTransferParams");
+
+ b.Structure(gamma_transfer_struct_sym, gamma_transfer_member_list);
+
+ // Create ExternalTextureParams struct; gamma members reference the struct by symbol.
+ utils::Vector ext_tex_params_member_list{
+ b.Member("numPlanes", b.ty.u32()),
+ b.Member("doYuvToRgbConversionOnly", b.ty.u32()),
+ b.Member("yuvToRgbConversionMatrix", b.ty.mat3x4<f32>()),
+ b.Member("gammaDecodeParams", b.ty(gamma_transfer_struct_sym)),
+ b.Member("gammaEncodeParams", b.ty(gamma_transfer_struct_sym)),
+ b.Member("gamutConversionMatrix", b.ty.mat3x3<f32>()),
+ b.Member("coordTransformationMatrix", b.ty.mat3x2<f32>()),
+ };
+
+ params_struct_sym = b.Symbols().New("ExternalTextureParams");
+
+ b.Structure(params_struct_sym, ext_tex_params_member_list);
+ }
+
+ /// Creates the gammaCorrection function and stores its symbol in gamma_correction_sym.
+ /// Callers only invoke this while gamma_correction_sym is still invalid.
+ void createGammaCorrectionFn() {
+ gamma_correction_sym = b.Symbols().New("gammaCorrection");
+
+ b.Func(gamma_correction_sym,
+ utils::Vector{
+ b.Param("v", b.ty.vec3<f32>()),
+ b.Param("params", b.ty(gamma_transfer_struct_sym)),
+ },
+ b.ty.vec3<f32>(),
+ utils::Vector{
+ // let cond = abs(v) < vec3(params.D);
+ b.Decl(b.Let("cond",
+ b.LessThan(b.Call("abs", "v"),
+ b.Call<vec3<f32>>(b.MemberAccessor("params", "D"))))),
+ // let t = sign(v) * ((params.C * abs(v)) + params.F);
+ b.Decl(b.Let(
+ "t", b.Mul(b.Call("sign", "v"),
+ b.Add(b.Mul(b.MemberAccessor("params", "C"), b.Call("abs", "v")),
+ b.MemberAccessor("params", "F"))))),
+ // let f = sign(v) * (pow(((params.A * abs(v)) + params.B),
+ // vec3(params.G)) + params.E);
+ b.Decl(b.Let(
+ "f", b.Mul(b.Call("sign", "v"),
+ b.Add(b.Call("pow",
+ b.Add(b.Mul(b.MemberAccessor("params", "A"),
+ b.Call("abs", "v")),
+ b.MemberAccessor("params", "B")),
+ b.Call<vec3<f32>>(b.MemberAccessor("params", "G"))),
+ b.MemberAccessor("params", "E"))))),
+ // return select(f, t, cond);
+ b.Return(b.Call("select", "f", "t", "cond")),
+ });
+ }
+
+ /// Builds the statements forming the body of a generated external texture function. The
+ /// gammaCorrection helper is called via gamma_correction_sym, so it must already be valid.
+ /// @param call_type determines which function body to generate
+ /// @returns a statement list that makes up the body of the chosen function
+ auto buildTextureBuiltinBody(builtin::Function call_type) {
+ utils::Vector<const Statement*, 16> stmts;
+ const CallExpression* single_plane_call = nullptr;
+ const CallExpression* plane_0_call = nullptr;
+ const CallExpression* plane_1_call = nullptr;
+ switch (call_type) {
+ case builtin::Function::kTextureSampleBaseClampToEdge:
+ stmts.Push(b.Decl(b.Let(
+ "modifiedCoords", b.Mul(b.MemberAccessor("params", "coordTransformationMatrix"),
+ b.Call<vec3<f32>>("coord", 1_a)))));
+
+ stmts.Push(b.Decl(
+ b.Let("plane0_dims",
+ b.Call(b.ty.vec2<f32>(), b.Call("textureDimensions", "plane0", 0_a)))));
+ stmts.Push(b.Decl(
+ b.Let("plane0_half_texel", b.Div(b.Call<vec2<f32>>(0.5_a), "plane0_dims"))));
+ stmts.Push(b.Decl(
+ b.Let("plane0_clamped", b.Call("clamp", "modifiedCoords", "plane0_half_texel",
+ b.Sub(1_a, "plane0_half_texel")))));
+ stmts.Push(b.Decl(
+ b.Let("plane1_dims",
+ b.Call(b.ty.vec2<f32>(), b.Call("textureDimensions", "plane1", 0_a)))));
+ stmts.Push(b.Decl(
+ b.Let("plane1_half_texel", b.Div(b.Call<vec2<f32>>(0.5_a), "plane1_dims"))));
+ stmts.Push(b.Decl(
+ b.Let("plane1_clamped", b.Call("clamp", "modifiedCoords", "plane1_half_texel",
+ b.Sub(1_a, "plane1_half_texel")))));
+
+ // textureSampleLevel(plane0, smp, plane0_clamped, 0);
+ single_plane_call =
+ b.Call("textureSampleLevel", "plane0", "smp", "plane0_clamped", 0_a);
+ // textureSampleLevel(plane0, smp, plane0_clamped, 0);
+ plane_0_call = b.Call("textureSampleLevel", "plane0", "smp", "plane0_clamped", 0_a);
+ // textureSampleLevel(plane1, smp, plane1_clamped, 0);
+ plane_1_call = b.Call("textureSampleLevel", "plane1", "smp", "plane1_clamped", 0_a);
+ break;
+ case builtin::Function::kTextureLoad:
+ // textureLoad(plane0, coord, 0);
+ single_plane_call = b.Call("textureLoad", "plane0", "coord", 0_a);
+ // textureLoad(plane0, coord, 0);
+ plane_0_call = b.Call("textureLoad", "plane0", "coord", 0_a);
+ // let coord1 = coord >> 1;
+ stmts.Push(b.Decl(b.Let("coord1", b.Shr("coord", b.Call<vec2<u32>>(1_a)))));
+ // textureLoad(plane1, coord1, 0);
+ plane_1_call = b.Call("textureLoad", "plane1", "coord1", 0_a);
+ break;
+ default:
+ TINT_ICE(Transform, b.Diagnostics()) << "unhandled builtin: " << call_type;
+ }
+
+ // var color: vec3<f32>;
+ stmts.Push(b.Decl(b.Var("color", b.ty.vec3(b.ty.f32()))));
+
+ // if ((params.numPlanes == 1u))
+ stmts.Push(
+ b.If(b.Equal(b.MemberAccessor("params", "numPlanes"), b.Expr(1_a)),
+ b.Block(
+ // color = <single_plane_call>.rgb;
+ b.Assign("color", b.MemberAccessor(single_plane_call, "rgb"))),
+ b.Else(b.Block(
+ // color = vec4<f32>(plane_0_call.r, plane_1_call.rg, 1.0) *
+ // params.yuvToRgbConversionMatrix;
+ b.Assign("color",
+ b.Mul(b.Call<vec4<f32>>(b.MemberAccessor(plane_0_call, "r"),
+ b.MemberAccessor(plane_1_call, "rg"), 1_a),
+ b.MemberAccessor("params", "yuvToRgbConversionMatrix")))))));
+
+ // if (params.doYuvToRgbConversionOnly == 0u)
+ stmts.Push(
+ b.If(b.Equal(b.MemberAccessor("params", "doYuvToRgbConversionOnly"), b.Expr(0_a)),
+ b.Block(
+ // color = gammaCorrection(color, params.gammaDecodeParams);
+ b.Assign("color", b.Call(gamma_correction_sym, "color",
+ b.MemberAccessor("params", "gammaDecodeParams"))),
+ // color = (params.gamutConversionMatrix * color);
+ b.Assign("color",
+ b.Mul(b.MemberAccessor("params", "gamutConversionMatrix"), "color")),
+ // color = gammaCorrection(color, params.gammaEncodeParams);
+ b.Assign("color", b.Call(gamma_correction_sym, "color",
+ b.MemberAccessor("params", "gammaEncodeParams"))))));
+
+ // return vec4<f32>(color, 1.f);
+ stmts.Push(b.Return(b.Call<vec4<f32>>("color", 1_a)));
+
+ return stmts;
+ }
+
+ /// Creates the textureSampleExternal function if needed and returns a call expression to it.
+ /// @param expr the call expression being transformed
+ /// @param syms the expanded symbols to be used in the new call
+ /// @returns a call expression to textureSampleExternal
+ const CallExpression* createTextureSampleBaseClampToEdge(const CallExpression* expr,
+ NewBindingSymbols syms) {
+ const Expression* plane_0_binding_param = ctx.Clone(expr->args[0]);
+
+ if (TINT_UNLIKELY(expr->args.Length() != 3)) {
+ TINT_ICE(Transform, b.Diagnostics())
+ << "expected textureSampleBaseClampToEdge call with a "
+ "texture_external to have 3 parameters, found "
+ << expr->args.Length() << " parameters";
+ }
+
+ // TextureSampleExternal calls the gammaCorrection function, so ensure it
+ // exists.
+ if (!gamma_correction_sym.IsValid()) {
+ createGammaCorrectionFn();
+ }
+
+ if (!texture_sample_external_sym.IsValid()) {
+ texture_sample_external_sym = b.Symbols().New("textureSampleExternal");
+
+ // Emit the textureSampleExternal function.
+ b.Func(texture_sample_external_sym,
+ utils::Vector{
+ b.Param("plane0",
+ b.ty.sampled_texture(type::TextureDimension::k2d, b.ty.f32())),
+ b.Param("plane1",
+ b.ty.sampled_texture(type::TextureDimension::k2d, b.ty.f32())),
+ b.Param("smp", b.ty.sampler(type::SamplerKind::kSampler)),
+ b.Param("coord", b.ty.vec2(b.ty.f32())),
+ b.Param("params", b.ty(params_struct_sym)),
+ },
+ b.ty.vec4(b.ty.f32()),
+ buildTextureBuiltinBody(builtin::Function::kTextureSampleBaseClampToEdge));
+ }
+
+ return b.Call(texture_sample_external_sym, utils::Vector{
+ plane_0_binding_param,
+ b.Expr(syms.plane_1),
+ ctx.Clone(expr->args[1]),
+ ctx.Clone(expr->args[2]),
+ b.Expr(syms.params),
+ });
+ }
+
+ /// Creates the textureLoadExternal function if needed and returns a call expression to it.
+ /// @param call the call expression being transformed
+ /// @param syms the expanded symbols to be used in the new call
+ /// @returns a call expression to textureLoadExternal
+ const CallExpression* createTextureLoad(const sem::Call* call, NewBindingSymbols syms) {
+ if (TINT_UNLIKELY(call->Arguments().Length() != 2)) {
+ TINT_ICE(Transform, b.Diagnostics())
+ << "expected textureLoad call with a texture_external to have 2 arguments, found "
+ << call->Arguments().Length() << " arguments";
+ }
+
+ auto& args = call->Arguments();
+
+ // TextureLoadExternal calls the gammaCorrection function, so ensure it exists.
+ if (!gamma_correction_sym.IsValid()) {
+ createGammaCorrectionFn();
+ }
+
+ auto texture_load_external_sym = texture_load_external_fns.GetOrCreate(call->Target(), [&] {
+ auto& sig = call->Target()->Signature();
+ auto* coord_ty = sig.Parameter(sem::ParameterUsage::kCoords)->Type();
+
+ auto name = b.Symbols().New("textureLoadExternal");
+
+ // Emit the textureLoadExternal() function.
+ b.Func(name,
+ utils::Vector{
+ b.Param("plane0",
+ b.ty.sampled_texture(type::TextureDimension::k2d, b.ty.f32())),
+ b.Param("plane1",
+ b.ty.sampled_texture(type::TextureDimension::k2d, b.ty.f32())),
+ b.Param("coord", CreateASTTypeFor(ctx, coord_ty)),
+ b.Param("params", b.ty(params_struct_sym)),
+ },
+ b.ty.vec4(b.ty.f32()), //
+ buildTextureBuiltinBody(builtin::Function::kTextureLoad));
+
+ return name;
+ });
+
+ auto plane_0_binding_arg = ctx.Clone(args[0]->Declaration());
+
+ return b.Call(texture_load_external_sym, plane_0_binding_arg, syms.plane_1,
+ ctx.Clone(args[1]->Declaration()), syms.params);
+ }
+};
+
+MultiplanarExternalTexture::NewBindingPoints::NewBindingPoints(BindingsMap inputBindingsMap)
+ : bindings_map(std::move(inputBindingsMap)) {}
+
+MultiplanarExternalTexture::NewBindingPoints::~NewBindingPoints() = default;
+
+MultiplanarExternalTexture::MultiplanarExternalTexture() = default;
+MultiplanarExternalTexture::~MultiplanarExternalTexture() = default;
+
+// Within this transform, an instance of a texture_external binding is unpacked into two
+// texture_2d<f32> bindings representing two possible planes of a single texture and a uniform
+// buffer binding representing a struct of parameters. Calls to texture builtins that contain a
+// texture_external parameter will be transformed into a newly generated version of the function,
+// which can perform the desired operation on a single RGBA plane or on separate Y and UV planes.
+Transform::ApplyResult MultiplanarExternalTexture::Apply(const Program* src,
+ const DataMap& inputs,
+ DataMap&) const {
+ auto* new_binding_points = inputs.Get<NewBindingPoints>();
+
+ if (!ShouldRun(src)) {
+ return SkipTransform;
+ }
+
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+ if (!new_binding_points) {
+ b.Diagnostics().add_error(diag::System::Transform, "missing new binding point data for " +
+ std::string(TypeInfo().name));
+ return Program(std::move(b));
+ }
+
+ State state(ctx, new_binding_points);
+
+ state.Process();
+
+ ctx.Clone();
+ return Program(std::move(b));
+}
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/multiplanar_external_texture.h b/src/tint/lang/wgsl/ast/transform/multiplanar_external_texture.h
new file mode 100644
index 0000000..e2b0d80
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/multiplanar_external_texture.h
@@ -0,0 +1,83 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_MULTIPLANAR_EXTERNAL_TEXTURE_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_MULTIPLANAR_EXTERNAL_TEXTURE_H_
+
+#include <unordered_map>
+#include <utility>
+
+#include "src/tint/builtin/function.h"
+#include "src/tint/lang/wgsl/ast/struct_member.h"
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+#include "src/tint/sem/binding_point.h"
+#include "src/tint/sem/external_texture.h"
+
+namespace tint::ast::transform {
+
+/// Within the MultiplanarExternalTexture transform, each instance of a
+/// texture_external binding is unpacked into two texture_2d<f32> bindings
+/// representing two possible planes of a texture and a uniform buffer binding
+/// representing a struct of parameters. Calls to textureLoad or
+/// textureSampleBaseClampToEdge that contain a texture_external parameter will be
+/// transformed into a newly generated version of the function, which can
+/// perform the desired operation on a single RGBA plane or on separate Y and UV
+/// planes, and do colorspace conversions including yuv->rgb conversion, gamma
+/// decoding, gamut conversion, and gamma encoding steps. Specifically
+/// for BT.709 to SRGB conversion, it takes the fast path only doing the yuv->rgb
+/// step and skipping all other steps.
+class MultiplanarExternalTexture final
+ : public utils::Castable<MultiplanarExternalTexture, Transform> {
+ public:
+ /// This struct identifies the binding groups and locations for new bindings to
+ /// use when transforming a texture_external instance.
+ using BindingPoints = sem::external_texture::BindingPoints;
+
+ /// BindingsMap is a map where the key is the binding location of a
+ /// texture_external and the value is a struct containing the desired
+ /// locations for new bindings expanded from the texture_external instance.
+ using BindingsMap = sem::external_texture::BindingsMap;
+
+ /// NewBindingPoints is consumed by the MultiplanarExternalTexture transform.
+ /// Data holds information about location of each texture_external binding and
+ /// which binding slots it should expand into.
+ struct NewBindingPoints final : public utils::Castable<NewBindingPoints, Data> {
+ /// Constructor
+ /// @param bm a map to the new binding slots to use.
+ explicit NewBindingPoints(BindingsMap bm);
+
+ /// Destructor
+ ~NewBindingPoints() override;
+
+ /// A map of new binding points to use.
+ const BindingsMap bindings_map;
+ };
+
+ /// Constructor
+ MultiplanarExternalTexture();
+ /// Destructor
+ ~MultiplanarExternalTexture() override;
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+
+ private:
+ struct State;
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_MULTIPLANAR_EXTERNAL_TEXTURE_H_
diff --git a/src/tint/lang/wgsl/ast/transform/multiplanar_external_texture_test.cc b/src/tint/lang/wgsl/ast/transform/multiplanar_external_texture_test.cc
new file mode 100644
index 0000000..cbe48e4
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/multiplanar_external_texture_test.cc
@@ -0,0 +1,1816 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/multiplanar_external_texture.h"
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using MultiplanarExternalTextureTest = TransformTest;
+
+TEST_F(MultiplanarExternalTextureTest, ShouldRunEmptyModule) {
+ auto* src = R"()";
+
+ Transform::DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(
+ MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}});
+
+ EXPECT_FALSE(ShouldRun<MultiplanarExternalTexture>(src, data));
+}
+
+TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureAlias) {
+ auto* src = R"(
+alias ET = texture_external;
+)";
+
+ Transform::DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(
+ MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}});
+
+ EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src, data));
+}
+TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureGlobal) {
+ auto* src = R"(
+@group(0) @binding(0) var ext_tex : texture_external;
+)";
+
+ Transform::DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(
+ MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}});
+
+ EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src, data));
+}
+
+TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureParam) {
+ auto* src = R"(
+fn f(ext_tex : texture_external) {}
+)";
+
+ Transform::DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(
+ MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}});
+
+ EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src, data));
+}
+
+// Running the transform without passing in data for the new bindings should result in an error.
+TEST_F(MultiplanarExternalTextureTest, ErrorNoPassedData_SampleBaseClampToEdge) {
+ auto* src = R"(
+@group(0) @binding(0) var s : sampler;
+@group(0) @binding(1) var ext_tex : texture_external;
+
+@fragment
+fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
+ return textureSampleBaseClampToEdge(ext_tex, s, coord.xy);
+}
+)";
+ auto* expect =
+ "error: missing new binding point data for "
+ "tint::ast::transform::MultiplanarExternalTexture";
+
+ auto got = Run<MultiplanarExternalTexture>(src);
+ EXPECT_EQ(expect, str(got));
+}
+
+// Running the transform with incorrect binding data should result in an error.
+TEST_F(MultiplanarExternalTextureTest, ErrorIncorrectBindingPont_SampleBaseClampToEdge) {
+ auto* src = R"(
+@group(0) @binding(0) var s : sampler;
+@group(0) @binding(1) var ext_tex : texture_external;
+
+@fragment
+fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
+ return textureSampleBaseClampToEdge(ext_tex, s, coord.xy);
+}
+)";
+
+ auto* expect = R"(error: missing new binding points for texture_external at binding {0,1})";
+
+ Transform::DataMap data;
+ // This bindings map specifies 0,0 as the location of the texture_external,
+ // which is incorrect.
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(
+ MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}});
+ auto got = Run<MultiplanarExternalTexture>(src, data);
+ EXPECT_EQ(expect, str(got));
+}
+
+// Tests that the transform works with a textureDimensions call.
+TEST_F(MultiplanarExternalTextureTest, Dimensions) {
+ auto* src = R"(
+@group(0) @binding(0) var ext_tex : texture_external;
+
+@fragment
+fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
+ var dim : vec2<u32>;
+ dim = textureDimensions(ext_tex);
+ return vec4<f32>(0.0, 0.0, 0.0, 0.0);
+}
+)";
+
+ auto* expect = R"(
+struct GammaTransferParams {
+ G : f32,
+ A : f32,
+ B : f32,
+ C : f32,
+ D : f32,
+ E : f32,
+ F : f32,
+ padding : u32,
+}
+
+struct ExternalTextureParams {
+ numPlanes : u32,
+ doYuvToRgbConversionOnly : u32,
+ yuvToRgbConversionMatrix : mat3x4<f32>,
+ gammaDecodeParams : GammaTransferParams,
+ gammaEncodeParams : GammaTransferParams,
+ gamutConversionMatrix : mat3x3<f32>,
+ coordTransformationMatrix : mat3x2<f32>,
+}
+
+@group(0) @binding(1) var ext_tex_plane_1 : texture_2d<f32>;
+
+@group(0) @binding(2) var<uniform> ext_tex_params : ExternalTextureParams;
+
+@group(0) @binding(0) var ext_tex : texture_2d<f32>;
+
+@fragment
+fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
+ var dim : vec2<u32>;
+ dim = textureDimensions(ext_tex);
+ return vec4<f32>(0.0, 0.0, 0.0, 0.0);
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(
+ MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}});
+ auto got = Run<MultiplanarExternalTexture>(src, data);
+ EXPECT_EQ(expect, str(got));
+}
+
+// Tests textureDimensions when the texture_external global is declared after its use.
+TEST_F(MultiplanarExternalTextureTest, Dimensions_OutOfOrder) {
+ auto* src = R"(
+@fragment
+fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
+ var dim : vec2<u32>;
+ dim = textureDimensions(ext_tex);
+ return vec4<f32>(0.0, 0.0, 0.0, 0.0);
+}
+
+@group(0) @binding(0) var ext_tex : texture_external;
+)";
+
+ auto* expect = R"(
+struct GammaTransferParams {
+ G : f32,
+ A : f32,
+ B : f32,
+ C : f32,
+ D : f32,
+ E : f32,
+ F : f32,
+ padding : u32,
+}
+
+struct ExternalTextureParams {
+ numPlanes : u32,
+ doYuvToRgbConversionOnly : u32,
+ yuvToRgbConversionMatrix : mat3x4<f32>,
+ gammaDecodeParams : GammaTransferParams,
+ gammaEncodeParams : GammaTransferParams,
+ gamutConversionMatrix : mat3x3<f32>,
+ coordTransformationMatrix : mat3x2<f32>,
+}
+
+@group(0) @binding(1) var ext_tex_plane_1 : texture_2d<f32>;
+
+@group(0) @binding(2) var<uniform> ext_tex_params : ExternalTextureParams;
+
+@fragment
+fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
+ var dim : vec2<u32>;
+ dim = textureDimensions(ext_tex);
+ return vec4<f32>(0.0, 0.0, 0.0, 0.0);
+}
+
+@group(0) @binding(0) var ext_tex : texture_2d<f32>;
+)";
+
+ Transform::DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(
+ MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}});
+ auto got = Run<MultiplanarExternalTexture>(src, data);
+ EXPECT_EQ(expect, str(got));
+}
+
+// Test that the transform works with a textureSampleBaseClampToEdge call.
+TEST_F(MultiplanarExternalTextureTest, BasicTextureSampleBaseClampToEdge) {
+ auto* src = R"(
+@group(0) @binding(0) var s : sampler;
+@group(0) @binding(1) var ext_tex : texture_external;
+
+@fragment
+fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
+ return textureSampleBaseClampToEdge(ext_tex, s, coord.xy);
+}
+)";
+
+ auto* expect = R"(
+struct GammaTransferParams {
+ G : f32,
+ A : f32,
+ B : f32,
+ C : f32,
+ D : f32,
+ E : f32,
+ F : f32,
+ padding : u32,
+}
+
+struct ExternalTextureParams {
+ numPlanes : u32,
+ doYuvToRgbConversionOnly : u32,
+ yuvToRgbConversionMatrix : mat3x4<f32>,
+ gammaDecodeParams : GammaTransferParams,
+ gammaEncodeParams : GammaTransferParams,
+ gamutConversionMatrix : mat3x3<f32>,
+ coordTransformationMatrix : mat3x2<f32>,
+}
+
+@group(0) @binding(2) var ext_tex_plane_1 : texture_2d<f32>;
+
+@group(0) @binding(3) var<uniform> ext_tex_params : ExternalTextureParams;
+
+@group(0) @binding(0) var s : sampler;
+
+@group(0) @binding(1) var ext_tex : texture_2d<f32>;
+
+fn gammaCorrection(v : vec3<f32>, params : GammaTransferParams) -> vec3<f32> {
+ let cond = (abs(v) < vec3<f32>(params.D));
+ let t = (sign(v) * ((params.C * abs(v)) + params.F));
+ let f = (sign(v) * (pow(((params.A * abs(v)) + params.B), vec3<f32>(params.G)) + params.E));
+ return select(f, t, cond);
+}
+
+fn textureSampleExternal(plane0 : texture_2d<f32>, plane1 : texture_2d<f32>, smp : sampler, coord : vec2<f32>, params : ExternalTextureParams) -> vec4<f32> {
+ let modifiedCoords = (params.coordTransformationMatrix * vec3<f32>(coord, 1));
+ let plane0_dims = vec2<f32>(textureDimensions(plane0, 0));
+ let plane0_half_texel = (vec2<f32>(0.5) / plane0_dims);
+ let plane0_clamped = clamp(modifiedCoords, plane0_half_texel, (1 - plane0_half_texel));
+ let plane1_dims = vec2<f32>(textureDimensions(plane1, 0));
+ let plane1_half_texel = (vec2<f32>(0.5) / plane1_dims);
+ let plane1_clamped = clamp(modifiedCoords, plane1_half_texel, (1 - plane1_half_texel));
+ var color : vec3<f32>;
+ if ((params.numPlanes == 1)) {
+ color = textureSampleLevel(plane0, smp, plane0_clamped, 0).rgb;
+ } else {
+ color = (vec4<f32>(textureSampleLevel(plane0, smp, plane0_clamped, 0).r, textureSampleLevel(plane1, smp, plane1_clamped, 0).rg, 1) * params.yuvToRgbConversionMatrix);
+ }
+ if ((params.doYuvToRgbConversionOnly == 0)) {
+ color = gammaCorrection(color, params.gammaDecodeParams);
+ color = (params.gamutConversionMatrix * color);
+ color = gammaCorrection(color, params.gammaEncodeParams);
+ }
+ return vec4<f32>(color, 1);
+}
+
+@fragment
+fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
+ return textureSampleExternal(ext_tex, ext_tex_plane_1, s, coord.xy, ext_tex_params);
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(
+ MultiplanarExternalTexture::BindingsMap{{{0, 1}, {{0, 2}, {0, 3}}}});
+ auto got = Run<MultiplanarExternalTexture>(src, data);
+ EXPECT_EQ(expect, str(got));
+}
+
+// Tests the transform on a textureSampleBaseClampToEdge call that precedes its global variables.
+TEST_F(MultiplanarExternalTextureTest, BasicTextureSampleBaseClampToEdge_OutOfOrder) {
+ auto* src = R"(
+@fragment
+fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
+ return textureSampleBaseClampToEdge(ext_tex, s, coord.xy);
+}
+
+@group(0) @binding(1) var ext_tex : texture_external;
+@group(0) @binding(0) var s : sampler;
+)";
+
+ auto* expect = R"(
+struct GammaTransferParams {
+ G : f32,
+ A : f32,
+ B : f32,
+ C : f32,
+ D : f32,
+ E : f32,
+ F : f32,
+ padding : u32,
+}
+
+struct ExternalTextureParams {
+ numPlanes : u32,
+ doYuvToRgbConversionOnly : u32,
+ yuvToRgbConversionMatrix : mat3x4<f32>,
+ gammaDecodeParams : GammaTransferParams,
+ gammaEncodeParams : GammaTransferParams,
+ gamutConversionMatrix : mat3x3<f32>,
+ coordTransformationMatrix : mat3x2<f32>,
+}
+
+@group(0) @binding(2) var ext_tex_plane_1 : texture_2d<f32>;
+
+@group(0) @binding(3) var<uniform> ext_tex_params : ExternalTextureParams;
+
+fn gammaCorrection(v : vec3<f32>, params : GammaTransferParams) -> vec3<f32> {
+ let cond = (abs(v) < vec3<f32>(params.D));
+ let t = (sign(v) * ((params.C * abs(v)) + params.F));
+ let f = (sign(v) * (pow(((params.A * abs(v)) + params.B), vec3<f32>(params.G)) + params.E));
+ return select(f, t, cond);
+}
+
+fn textureSampleExternal(plane0 : texture_2d<f32>, plane1 : texture_2d<f32>, smp : sampler, coord : vec2<f32>, params : ExternalTextureParams) -> vec4<f32> {
+ let modifiedCoords = (params.coordTransformationMatrix * vec3<f32>(coord, 1));
+ let plane0_dims = vec2<f32>(textureDimensions(plane0, 0));
+ let plane0_half_texel = (vec2<f32>(0.5) / plane0_dims);
+ let plane0_clamped = clamp(modifiedCoords, plane0_half_texel, (1 - plane0_half_texel));
+ let plane1_dims = vec2<f32>(textureDimensions(plane1, 0));
+ let plane1_half_texel = (vec2<f32>(0.5) / plane1_dims);
+ let plane1_clamped = clamp(modifiedCoords, plane1_half_texel, (1 - plane1_half_texel));
+ var color : vec3<f32>;
+ if ((params.numPlanes == 1)) {
+ color = textureSampleLevel(plane0, smp, plane0_clamped, 0).rgb;
+ } else {
+ color = (vec4<f32>(textureSampleLevel(plane0, smp, plane0_clamped, 0).r, textureSampleLevel(plane1, smp, plane1_clamped, 0).rg, 1) * params.yuvToRgbConversionMatrix);
+ }
+ if ((params.doYuvToRgbConversionOnly == 0)) {
+ color = gammaCorrection(color, params.gammaDecodeParams);
+ color = (params.gamutConversionMatrix * color);
+ color = gammaCorrection(color, params.gammaEncodeParams);
+ }
+ return vec4<f32>(color, 1);
+}
+
+@fragment
+fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
+ return textureSampleExternal(ext_tex, ext_tex_plane_1, s, coord.xy, ext_tex_params);
+}
+
+@group(0) @binding(1) var ext_tex : texture_2d<f32>;
+
+@group(0) @binding(0) var s : sampler;
+)";
+
+ Transform::DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(
+ MultiplanarExternalTexture::BindingsMap{{{0, 1}, {{0, 2}, {0, 3}}}});
+ auto got = Run<MultiplanarExternalTexture>(src, data);
+ EXPECT_EQ(expect, str(got));
+}
+
+// Tests that the transform works with textureLoad calls using signed and unsigned coordinates.
+TEST_F(MultiplanarExternalTextureTest, BasicTextureLoad) {
+ auto* src = R"(
+@group(0) @binding(0) var ext_tex : texture_external;
+
+@fragment
+fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
+ var val_signed = textureLoad(ext_tex, vec2<i32>(1));
+ var val_unsigned = textureLoad(ext_tex, vec2<u32>(1));
+ return val_signed + val_unsigned;
+}
+)";
+
+ auto* expect = R"(
+struct GammaTransferParams {
+ G : f32,
+ A : f32,
+ B : f32,
+ C : f32,
+ D : f32,
+ E : f32,
+ F : f32,
+ padding : u32,
+}
+
+struct ExternalTextureParams {
+ numPlanes : u32,
+ doYuvToRgbConversionOnly : u32,
+ yuvToRgbConversionMatrix : mat3x4<f32>,
+ gammaDecodeParams : GammaTransferParams,
+ gammaEncodeParams : GammaTransferParams,
+ gamutConversionMatrix : mat3x3<f32>,
+ coordTransformationMatrix : mat3x2<f32>,
+}
+
+@group(0) @binding(1) var ext_tex_plane_1 : texture_2d<f32>;
+
+@group(0) @binding(2) var<uniform> ext_tex_params : ExternalTextureParams;
+
+@group(0) @binding(0) var ext_tex : texture_2d<f32>;
+
+fn gammaCorrection(v : vec3<f32>, params : GammaTransferParams) -> vec3<f32> {
+ let cond = (abs(v) < vec3<f32>(params.D));
+ let t = (sign(v) * ((params.C * abs(v)) + params.F));
+ let f = (sign(v) * (pow(((params.A * abs(v)) + params.B), vec3<f32>(params.G)) + params.E));
+ return select(f, t, cond);
+}
+
+fn textureLoadExternal(plane0 : texture_2d<f32>, plane1 : texture_2d<f32>, coord : vec2<i32>, params : ExternalTextureParams) -> vec4<f32> {
+ let coord1 = (coord >> vec2<u32>(1));
+ var color : vec3<f32>;
+ if ((params.numPlanes == 1)) {
+ color = textureLoad(plane0, coord, 0).rgb;
+ } else {
+ color = (vec4<f32>(textureLoad(plane0, coord, 0).r, textureLoad(plane1, coord1, 0).rg, 1) * params.yuvToRgbConversionMatrix);
+ }
+ if ((params.doYuvToRgbConversionOnly == 0)) {
+ color = gammaCorrection(color, params.gammaDecodeParams);
+ color = (params.gamutConversionMatrix * color);
+ color = gammaCorrection(color, params.gammaEncodeParams);
+ }
+ return vec4<f32>(color, 1);
+}
+
+fn textureLoadExternal_1(plane0 : texture_2d<f32>, plane1 : texture_2d<f32>, coord : vec2<u32>, params : ExternalTextureParams) -> vec4<f32> {
+ let coord1 = (coord >> vec2<u32>(1));
+ var color : vec3<f32>;
+ if ((params.numPlanes == 1)) {
+ color = textureLoad(plane0, coord, 0).rgb;
+ } else {
+ color = (vec4<f32>(textureLoad(plane0, coord, 0).r, textureLoad(plane1, coord1, 0).rg, 1) * params.yuvToRgbConversionMatrix);
+ }
+ if ((params.doYuvToRgbConversionOnly == 0)) {
+ color = gammaCorrection(color, params.gammaDecodeParams);
+ color = (params.gamutConversionMatrix * color);
+ color = gammaCorrection(color, params.gammaEncodeParams);
+ }
+ return vec4<f32>(color, 1);
+}
+
+@fragment
+fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
+ var val_signed = textureLoadExternal(ext_tex, ext_tex_plane_1, vec2<i32>(1), ext_tex_params);
+ var val_unsigned = textureLoadExternal_1(ext_tex, ext_tex_plane_1, vec2<u32>(1), ext_tex_params);
+ return (val_signed + val_unsigned);
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(
+ MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}});
+ auto got = Run<MultiplanarExternalTexture>(src, data);
+ EXPECT_EQ(expect, str(got));
+}
+
+// Tests that the transform works with textureLoad calls that precede the texture declaration.
+TEST_F(MultiplanarExternalTextureTest, BasicTextureLoad_OutOfOrder) {
+ auto* src = R"(
+@fragment
+fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
+ var val_signed = textureLoad(ext_tex, vec2<i32>(1));
+ var val_unsigned = textureLoad(ext_tex, vec2<u32>(1));
+ return val_signed + val_unsigned;
+}
+
+@group(0) @binding(0) var ext_tex : texture_external;
+)";
+
+ auto* expect = R"(
+struct GammaTransferParams {
+ G : f32,
+ A : f32,
+ B : f32,
+ C : f32,
+ D : f32,
+ E : f32,
+ F : f32,
+ padding : u32,
+}
+
+struct ExternalTextureParams {
+ numPlanes : u32,
+ doYuvToRgbConversionOnly : u32,
+ yuvToRgbConversionMatrix : mat3x4<f32>,
+ gammaDecodeParams : GammaTransferParams,
+ gammaEncodeParams : GammaTransferParams,
+ gamutConversionMatrix : mat3x3<f32>,
+ coordTransformationMatrix : mat3x2<f32>,
+}
+
+@group(0) @binding(1) var ext_tex_plane_1 : texture_2d<f32>;
+
+@group(0) @binding(2) var<uniform> ext_tex_params : ExternalTextureParams;
+
+fn gammaCorrection(v : vec3<f32>, params : GammaTransferParams) -> vec3<f32> {
+ let cond = (abs(v) < vec3<f32>(params.D));
+ let t = (sign(v) * ((params.C * abs(v)) + params.F));
+ let f = (sign(v) * (pow(((params.A * abs(v)) + params.B), vec3<f32>(params.G)) + params.E));
+ return select(f, t, cond);
+}
+
+fn textureLoadExternal(plane0 : texture_2d<f32>, plane1 : texture_2d<f32>, coord : vec2<i32>, params : ExternalTextureParams) -> vec4<f32> {
+ let coord1 = (coord >> vec2<u32>(1));
+ var color : vec3<f32>;
+ if ((params.numPlanes == 1)) {
+ color = textureLoad(plane0, coord, 0).rgb;
+ } else {
+ color = (vec4<f32>(textureLoad(plane0, coord, 0).r, textureLoad(plane1, coord1, 0).rg, 1) * params.yuvToRgbConversionMatrix);
+ }
+ if ((params.doYuvToRgbConversionOnly == 0)) {
+ color = gammaCorrection(color, params.gammaDecodeParams);
+ color = (params.gamutConversionMatrix * color);
+ color = gammaCorrection(color, params.gammaEncodeParams);
+ }
+ return vec4<f32>(color, 1);
+}
+
+fn textureLoadExternal_1(plane0 : texture_2d<f32>, plane1 : texture_2d<f32>, coord : vec2<u32>, params : ExternalTextureParams) -> vec4<f32> {
+ let coord1 = (coord >> vec2<u32>(1));
+ var color : vec3<f32>;
+ if ((params.numPlanes == 1)) {
+ color = textureLoad(plane0, coord, 0).rgb;
+ } else {
+ color = (vec4<f32>(textureLoad(plane0, coord, 0).r, textureLoad(plane1, coord1, 0).rg, 1) * params.yuvToRgbConversionMatrix);
+ }
+ if ((params.doYuvToRgbConversionOnly == 0)) {
+ color = gammaCorrection(color, params.gammaDecodeParams);
+ color = (params.gamutConversionMatrix * color);
+ color = gammaCorrection(color, params.gammaEncodeParams);
+ }
+ return vec4<f32>(color, 1);
+}
+
+@fragment
+fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
+ var val_signed = textureLoadExternal(ext_tex, ext_tex_plane_1, vec2<i32>(1), ext_tex_params);
+ var val_unsigned = textureLoadExternal_1(ext_tex, ext_tex_plane_1, vec2<u32>(1), ext_tex_params);
+ return (val_signed + val_unsigned);
+}
+
+@group(0) @binding(0) var ext_tex : texture_2d<f32>;
+)";
+
+ Transform::DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(
+ MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}});
+ auto got = Run<MultiplanarExternalTexture>(src, data);
+ EXPECT_EQ(expect, str(got));
+}
+
+// Tests that the transform works with both a textureSampleBaseClampToEdge and textureLoad call.
+TEST_F(MultiplanarExternalTextureTest, TextureSampleBaseClampToEdgeAndTextureLoad) {
+ auto* src = R"(
+@group(0) @binding(0) var s : sampler;
+@group(0) @binding(1) var ext_tex : texture_external;
+
+@fragment
+fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
+ return textureSampleBaseClampToEdge(ext_tex, s, coord.xy) + textureLoad(ext_tex, vec2<i32>(1, 1));
+}
+)";
+
+ auto* expect = R"(
+struct GammaTransferParams {
+ G : f32,
+ A : f32,
+ B : f32,
+ C : f32,
+ D : f32,
+ E : f32,
+ F : f32,
+ padding : u32,
+}
+
+struct ExternalTextureParams {
+ numPlanes : u32,
+ doYuvToRgbConversionOnly : u32,
+ yuvToRgbConversionMatrix : mat3x4<f32>,
+ gammaDecodeParams : GammaTransferParams,
+ gammaEncodeParams : GammaTransferParams,
+ gamutConversionMatrix : mat3x3<f32>,
+ coordTransformationMatrix : mat3x2<f32>,
+}
+
+@group(0) @binding(2) var ext_tex_plane_1 : texture_2d<f32>;
+
+@group(0) @binding(3) var<uniform> ext_tex_params : ExternalTextureParams;
+
+@group(0) @binding(0) var s : sampler;
+
+@group(0) @binding(1) var ext_tex : texture_2d<f32>;
+
+fn gammaCorrection(v : vec3<f32>, params : GammaTransferParams) -> vec3<f32> {
+ let cond = (abs(v) < vec3<f32>(params.D));
+ let t = (sign(v) * ((params.C * abs(v)) + params.F));
+ let f = (sign(v) * (pow(((params.A * abs(v)) + params.B), vec3<f32>(params.G)) + params.E));
+ return select(f, t, cond);
+}
+
+fn textureSampleExternal(plane0 : texture_2d<f32>, plane1 : texture_2d<f32>, smp : sampler, coord : vec2<f32>, params : ExternalTextureParams) -> vec4<f32> {
+ let modifiedCoords = (params.coordTransformationMatrix * vec3<f32>(coord, 1));
+ let plane0_dims = vec2<f32>(textureDimensions(plane0, 0));
+ let plane0_half_texel = (vec2<f32>(0.5) / plane0_dims);
+ let plane0_clamped = clamp(modifiedCoords, plane0_half_texel, (1 - plane0_half_texel));
+ let plane1_dims = vec2<f32>(textureDimensions(plane1, 0));
+ let plane1_half_texel = (vec2<f32>(0.5) / plane1_dims);
+ let plane1_clamped = clamp(modifiedCoords, plane1_half_texel, (1 - plane1_half_texel));
+ var color : vec3<f32>;
+ if ((params.numPlanes == 1)) {
+ color = textureSampleLevel(plane0, smp, plane0_clamped, 0).rgb;
+ } else {
+ color = (vec4<f32>(textureSampleLevel(plane0, smp, plane0_clamped, 0).r, textureSampleLevel(plane1, smp, plane1_clamped, 0).rg, 1) * params.yuvToRgbConversionMatrix);
+ }
+ if ((params.doYuvToRgbConversionOnly == 0)) {
+ color = gammaCorrection(color, params.gammaDecodeParams);
+ color = (params.gamutConversionMatrix * color);
+ color = gammaCorrection(color, params.gammaEncodeParams);
+ }
+ return vec4<f32>(color, 1);
+}
+
+fn textureLoadExternal(plane0 : texture_2d<f32>, plane1 : texture_2d<f32>, coord : vec2<i32>, params : ExternalTextureParams) -> vec4<f32> {
+ let coord1 = (coord >> vec2<u32>(1));
+ var color : vec3<f32>;
+ if ((params.numPlanes == 1)) {
+ color = textureLoad(plane0, coord, 0).rgb;
+ } else {
+ color = (vec4<f32>(textureLoad(plane0, coord, 0).r, textureLoad(plane1, coord1, 0).rg, 1) * params.yuvToRgbConversionMatrix);
+ }
+ if ((params.doYuvToRgbConversionOnly == 0)) {
+ color = gammaCorrection(color, params.gammaDecodeParams);
+ color = (params.gamutConversionMatrix * color);
+ color = gammaCorrection(color, params.gammaEncodeParams);
+ }
+ return vec4<f32>(color, 1);
+}
+
+@fragment
+fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
+ return (textureSampleExternal(ext_tex, ext_tex_plane_1, s, coord.xy, ext_tex_params) + textureLoadExternal(ext_tex, ext_tex_plane_1, vec2<i32>(1, 1), ext_tex_params));
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(
+ MultiplanarExternalTexture::BindingsMap{{{0, 1}, {{0, 2}, {0, 3}}}});
+ auto got = Run<MultiplanarExternalTexture>(src, data);
+ EXPECT_EQ(expect, str(got));
+}
+
+// Tests textureSampleBaseClampToEdge and textureLoad calls that precede the global variables.
+TEST_F(MultiplanarExternalTextureTest, TextureSampleBaseClampToEdgeAndTextureLoad_OutOfOrder) {
+ auto* src = R"(
+@fragment
+fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
+ return textureSampleBaseClampToEdge(ext_tex, s, coord.xy) + textureLoad(ext_tex, vec2<i32>(1, 1));
+}
+
+@group(0) @binding(0) var s : sampler;
+@group(0) @binding(1) var ext_tex : texture_external;
+)";
+
+ auto* expect = R"(
+struct GammaTransferParams {
+ G : f32,
+ A : f32,
+ B : f32,
+ C : f32,
+ D : f32,
+ E : f32,
+ F : f32,
+ padding : u32,
+}
+
+struct ExternalTextureParams {
+ numPlanes : u32,
+ doYuvToRgbConversionOnly : u32,
+ yuvToRgbConversionMatrix : mat3x4<f32>,
+ gammaDecodeParams : GammaTransferParams,
+ gammaEncodeParams : GammaTransferParams,
+ gamutConversionMatrix : mat3x3<f32>,
+ coordTransformationMatrix : mat3x2<f32>,
+}
+
+@group(0) @binding(2) var ext_tex_plane_1 : texture_2d<f32>;
+
+@group(0) @binding(3) var<uniform> ext_tex_params : ExternalTextureParams;
+
+fn gammaCorrection(v : vec3<f32>, params : GammaTransferParams) -> vec3<f32> {
+ let cond = (abs(v) < vec3<f32>(params.D));
+ let t = (sign(v) * ((params.C * abs(v)) + params.F));
+ let f = (sign(v) * (pow(((params.A * abs(v)) + params.B), vec3<f32>(params.G)) + params.E));
+ return select(f, t, cond);
+}
+
+fn textureSampleExternal(plane0 : texture_2d<f32>, plane1 : texture_2d<f32>, smp : sampler, coord : vec2<f32>, params : ExternalTextureParams) -> vec4<f32> {
+ let modifiedCoords = (params.coordTransformationMatrix * vec3<f32>(coord, 1));
+ let plane0_dims = vec2<f32>(textureDimensions(plane0, 0));
+ let plane0_half_texel = (vec2<f32>(0.5) / plane0_dims);
+ let plane0_clamped = clamp(modifiedCoords, plane0_half_texel, (1 - plane0_half_texel));
+ let plane1_dims = vec2<f32>(textureDimensions(plane1, 0));
+ let plane1_half_texel = (vec2<f32>(0.5) / plane1_dims);
+ let plane1_clamped = clamp(modifiedCoords, plane1_half_texel, (1 - plane1_half_texel));
+ var color : vec3<f32>;
+ if ((params.numPlanes == 1)) {
+ color = textureSampleLevel(plane0, smp, plane0_clamped, 0).rgb;
+ } else {
+ color = (vec4<f32>(textureSampleLevel(plane0, smp, plane0_clamped, 0).r, textureSampleLevel(plane1, smp, plane1_clamped, 0).rg, 1) * params.yuvToRgbConversionMatrix);
+ }
+ if ((params.doYuvToRgbConversionOnly == 0)) {
+ color = gammaCorrection(color, params.gammaDecodeParams);
+ color = (params.gamutConversionMatrix * color);
+ color = gammaCorrection(color, params.gammaEncodeParams);
+ }
+ return vec4<f32>(color, 1);
+}
+
+fn textureLoadExternal(plane0 : texture_2d<f32>, plane1 : texture_2d<f32>, coord : vec2<i32>, params : ExternalTextureParams) -> vec4<f32> {
+ let coord1 = (coord >> vec2<u32>(1));
+ var color : vec3<f32>;
+ if ((params.numPlanes == 1)) {
+ color = textureLoad(plane0, coord, 0).rgb;
+ } else {
+ color = (vec4<f32>(textureLoad(plane0, coord, 0).r, textureLoad(plane1, coord1, 0).rg, 1) * params.yuvToRgbConversionMatrix);
+ }
+ if ((params.doYuvToRgbConversionOnly == 0)) {
+ color = gammaCorrection(color, params.gammaDecodeParams);
+ color = (params.gamutConversionMatrix * color);
+ color = gammaCorrection(color, params.gammaEncodeParams);
+ }
+ return vec4<f32>(color, 1);
+}
+
+@fragment
+fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
+ return (textureSampleExternal(ext_tex, ext_tex_plane_1, s, coord.xy, ext_tex_params) + textureLoadExternal(ext_tex, ext_tex_plane_1, vec2<i32>(1, 1), ext_tex_params));
+}
+
+@group(0) @binding(0) var s : sampler;
+
+@group(0) @binding(1) var ext_tex : texture_2d<f32>;
+)";
+
+ Transform::DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(
+ MultiplanarExternalTexture::BindingsMap{{{0, 1}, {{0, 2}, {0, 3}}}});
+ auto got = Run<MultiplanarExternalTexture>(src, data);
+ EXPECT_EQ(expect, str(got));
+}
+
+// Tests that the transform works with many instances of texture_external.
+TEST_F(MultiplanarExternalTextureTest, ManyTextureSampleBaseClampToEdge) {
+ auto* src = R"(
+@group(0) @binding(0) var s : sampler;
+@group(0) @binding(1) var ext_tex : texture_external;
+@group(0) @binding(2) var ext_tex_1 : texture_external;
+@group(0) @binding(3) var ext_tex_2 : texture_external;
+@group(1) @binding(0) var ext_tex_3 : texture_external;
+
+@fragment
+fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
+ return textureSampleBaseClampToEdge(ext_tex, s, coord.xy) +
+ textureSampleBaseClampToEdge(ext_tex_1, s, coord.xy) +
+ textureSampleBaseClampToEdge(ext_tex_2, s, coord.xy) +
+ textureSampleBaseClampToEdge(ext_tex_3, s, coord.xy);
+}
+)";
+
+ auto* expect = R"(
+struct GammaTransferParams {
+ G : f32,
+ A : f32,
+ B : f32,
+ C : f32,
+ D : f32,
+ E : f32,
+ F : f32,
+ padding : u32,
+}
+
+struct ExternalTextureParams {
+ numPlanes : u32,
+ doYuvToRgbConversionOnly : u32,
+ yuvToRgbConversionMatrix : mat3x4<f32>,
+ gammaDecodeParams : GammaTransferParams,
+ gammaEncodeParams : GammaTransferParams,
+ gamutConversionMatrix : mat3x3<f32>,
+ coordTransformationMatrix : mat3x2<f32>,
+}
+
+@group(0) @binding(4) var ext_tex_plane_1 : texture_2d<f32>;
+
+@group(0) @binding(5) var<uniform> ext_tex_params : ExternalTextureParams;
+
+@group(0) @binding(6) var ext_tex_plane_1_1 : texture_2d<f32>;
+
+@group(0) @binding(7) var<uniform> ext_tex_params_1 : ExternalTextureParams;
+
+@group(0) @binding(8) var ext_tex_plane_1_2 : texture_2d<f32>;
+
+@group(0) @binding(9) var<uniform> ext_tex_params_2 : ExternalTextureParams;
+
+@group(1) @binding(1) var ext_tex_plane_1_3 : texture_2d<f32>;
+
+@group(1) @binding(2) var<uniform> ext_tex_params_3 : ExternalTextureParams;
+
+@group(0) @binding(0) var s : sampler;
+
+@group(0) @binding(1) var ext_tex : texture_2d<f32>;
+
+@group(0) @binding(2) var ext_tex_1 : texture_2d<f32>;
+
+@group(0) @binding(3) var ext_tex_2 : texture_2d<f32>;
+
+@group(1) @binding(0) var ext_tex_3 : texture_2d<f32>;
+
+fn gammaCorrection(v : vec3<f32>, params : GammaTransferParams) -> vec3<f32> {
+ let cond = (abs(v) < vec3<f32>(params.D));
+ let t = (sign(v) * ((params.C * abs(v)) + params.F));
+ let f = (sign(v) * (pow(((params.A * abs(v)) + params.B), vec3<f32>(params.G)) + params.E));
+ return select(f, t, cond);
+}
+
+fn textureSampleExternal(plane0 : texture_2d<f32>, plane1 : texture_2d<f32>, smp : sampler, coord : vec2<f32>, params : ExternalTextureParams) -> vec4<f32> {
+ let modifiedCoords = (params.coordTransformationMatrix * vec3<f32>(coord, 1));
+ let plane0_dims = vec2<f32>(textureDimensions(plane0, 0));
+ let plane0_half_texel = (vec2<f32>(0.5) / plane0_dims);
+ let plane0_clamped = clamp(modifiedCoords, plane0_half_texel, (1 - plane0_half_texel));
+ let plane1_dims = vec2<f32>(textureDimensions(plane1, 0));
+ let plane1_half_texel = (vec2<f32>(0.5) / plane1_dims);
+ let plane1_clamped = clamp(modifiedCoords, plane1_half_texel, (1 - plane1_half_texel));
+ var color : vec3<f32>;
+ if ((params.numPlanes == 1)) {
+ color = textureSampleLevel(plane0, smp, plane0_clamped, 0).rgb;
+ } else {
+ color = (vec4<f32>(textureSampleLevel(plane0, smp, plane0_clamped, 0).r, textureSampleLevel(plane1, smp, plane1_clamped, 0).rg, 1) * params.yuvToRgbConversionMatrix);
+ }
+ if ((params.doYuvToRgbConversionOnly == 0)) {
+ color = gammaCorrection(color, params.gammaDecodeParams);
+ color = (params.gamutConversionMatrix * color);
+ color = gammaCorrection(color, params.gammaEncodeParams);
+ }
+ return vec4<f32>(color, 1);
+}
+
+@fragment
+fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
+ return (((textureSampleExternal(ext_tex, ext_tex_plane_1, s, coord.xy, ext_tex_params) + textureSampleExternal(ext_tex_1, ext_tex_plane_1_1, s, coord.xy, ext_tex_params_1)) + textureSampleExternal(ext_tex_2, ext_tex_plane_1_2, s, coord.xy, ext_tex_params_2)) + textureSampleExternal(ext_tex_3, ext_tex_plane_1_3, s, coord.xy, ext_tex_params_3));
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(MultiplanarExternalTexture::BindingsMap{
+ {{0, 1}, {{0, 4}, {0, 5}}},
+ {{0, 2}, {{0, 6}, {0, 7}}},
+ {{0, 3}, {{0, 8}, {0, 9}}},
+ {{1, 0}, {{1, 1}, {1, 2}}},
+ });
+ auto got = Run<MultiplanarExternalTexture>(src, data);
+ EXPECT_EQ(expect, str(got));
+}
+
+// Tests that the texture_external passed as a function parameter produces the
+// correct output.
+TEST_F(MultiplanarExternalTextureTest, ExternalTexturePassedAsParam) {
+ auto* src = R"(
+fn f(t : texture_external, s : sampler) {
+ _ = textureSampleBaseClampToEdge(t, s, vec2<f32>(1.0, 2.0));
+}
+
+@group(0) @binding(0) var ext_tex : texture_external;
+@group(0) @binding(1) var smp : sampler;
+
+@fragment
+fn main() {
+ f(ext_tex, smp);
+}
+)";
+
+ auto* expect = R"(
+struct GammaTransferParams {
+ G : f32,
+ A : f32,
+ B : f32,
+ C : f32,
+ D : f32,
+ E : f32,
+ F : f32,
+ padding : u32,
+}
+
+struct ExternalTextureParams {
+ numPlanes : u32,
+ doYuvToRgbConversionOnly : u32,
+ yuvToRgbConversionMatrix : mat3x4<f32>,
+ gammaDecodeParams : GammaTransferParams,
+ gammaEncodeParams : GammaTransferParams,
+ gamutConversionMatrix : mat3x3<f32>,
+ coordTransformationMatrix : mat3x2<f32>,
+}
+
+@group(0) @binding(2) var ext_tex_plane_1 : texture_2d<f32>;
+
+@group(0) @binding(3) var<uniform> ext_tex_params : ExternalTextureParams;
+
+fn gammaCorrection(v : vec3<f32>, params : GammaTransferParams) -> vec3<f32> {
+ let cond = (abs(v) < vec3<f32>(params.D));
+ let t = (sign(v) * ((params.C * abs(v)) + params.F));
+ let f = (sign(v) * (pow(((params.A * abs(v)) + params.B), vec3<f32>(params.G)) + params.E));
+ return select(f, t, cond);
+}
+
+fn textureSampleExternal(plane0 : texture_2d<f32>, plane1 : texture_2d<f32>, smp : sampler, coord : vec2<f32>, params : ExternalTextureParams) -> vec4<f32> {
+ let modifiedCoords = (params.coordTransformationMatrix * vec3<f32>(coord, 1));
+ let plane0_dims = vec2<f32>(textureDimensions(plane0, 0));
+ let plane0_half_texel = (vec2<f32>(0.5) / plane0_dims);
+ let plane0_clamped = clamp(modifiedCoords, plane0_half_texel, (1 - plane0_half_texel));
+ let plane1_dims = vec2<f32>(textureDimensions(plane1, 0));
+ let plane1_half_texel = (vec2<f32>(0.5) / plane1_dims);
+ let plane1_clamped = clamp(modifiedCoords, plane1_half_texel, (1 - plane1_half_texel));
+ var color : vec3<f32>;
+ if ((params.numPlanes == 1)) {
+ color = textureSampleLevel(plane0, smp, plane0_clamped, 0).rgb;
+ } else {
+ color = (vec4<f32>(textureSampleLevel(plane0, smp, plane0_clamped, 0).r, textureSampleLevel(plane1, smp, plane1_clamped, 0).rg, 1) * params.yuvToRgbConversionMatrix);
+ }
+ if ((params.doYuvToRgbConversionOnly == 0)) {
+ color = gammaCorrection(color, params.gammaDecodeParams);
+ color = (params.gamutConversionMatrix * color);
+ color = gammaCorrection(color, params.gammaEncodeParams);
+ }
+ return vec4<f32>(color, 1);
+}
+
+fn f(t : texture_2d<f32>, ext_tex_plane_1_1 : texture_2d<f32>, ext_tex_params_1 : ExternalTextureParams, s : sampler) {
+ _ = textureSampleExternal(t, ext_tex_plane_1_1, s, vec2<f32>(1.0, 2.0), ext_tex_params_1);
+}
+
+@group(0) @binding(0) var ext_tex : texture_2d<f32>;
+
+@group(0) @binding(1) var smp : sampler;
+
+@fragment
+fn main() {
+ f(ext_tex, ext_tex_plane_1, ext_tex_params, smp);
+}
+)";
+ Transform::DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(MultiplanarExternalTexture::BindingsMap{
+ {{0, 0}, {{0, 2}, {0, 3}}},
+ });
+ auto got = Run<MultiplanarExternalTexture>(src, data);
+ EXPECT_EQ(expect, str(got));
+}
+
+// Tests that the texture_external passed as a function parameter produces the
+// correct output when the entry point is declared before the function and globals.
+TEST_F(MultiplanarExternalTextureTest, ExternalTexturePassedAsParam_OutOfOrder) {
+ auto* src = R"(
+@fragment
+fn main() {
+ f(ext_tex, smp);
+}
+
+fn f(t : texture_external, s : sampler) {
+ _ = textureSampleBaseClampToEdge(t, s, vec2<f32>(1.0, 2.0));
+}
+
+@group(0) @binding(0) var ext_tex : texture_external;
+@group(0) @binding(1) var smp : sampler;
+)";
+
+ auto* expect = R"(
+struct GammaTransferParams {
+ G : f32,
+ A : f32,
+ B : f32,
+ C : f32,
+ D : f32,
+ E : f32,
+ F : f32,
+ padding : u32,
+}
+
+struct ExternalTextureParams {
+ numPlanes : u32,
+ doYuvToRgbConversionOnly : u32,
+ yuvToRgbConversionMatrix : mat3x4<f32>,
+ gammaDecodeParams : GammaTransferParams,
+ gammaEncodeParams : GammaTransferParams,
+ gamutConversionMatrix : mat3x3<f32>,
+ coordTransformationMatrix : mat3x2<f32>,
+}
+
+@group(0) @binding(2) var ext_tex_plane_1 : texture_2d<f32>;
+
+@group(0) @binding(3) var<uniform> ext_tex_params : ExternalTextureParams;
+
+@fragment
+fn main() {
+ f(ext_tex, ext_tex_plane_1, ext_tex_params, smp);
+}
+
+fn gammaCorrection(v : vec3<f32>, params : GammaTransferParams) -> vec3<f32> {
+ let cond = (abs(v) < vec3<f32>(params.D));
+ let t = (sign(v) * ((params.C * abs(v)) + params.F));
+ let f = (sign(v) * (pow(((params.A * abs(v)) + params.B), vec3<f32>(params.G)) + params.E));
+ return select(f, t, cond);
+}
+
+fn textureSampleExternal(plane0 : texture_2d<f32>, plane1 : texture_2d<f32>, smp : sampler, coord : vec2<f32>, params : ExternalTextureParams) -> vec4<f32> {
+ let modifiedCoords = (params.coordTransformationMatrix * vec3<f32>(coord, 1));
+ let plane0_dims = vec2<f32>(textureDimensions(plane0, 0));
+ let plane0_half_texel = (vec2<f32>(0.5) / plane0_dims);
+ let plane0_clamped = clamp(modifiedCoords, plane0_half_texel, (1 - plane0_half_texel));
+ let plane1_dims = vec2<f32>(textureDimensions(plane1, 0));
+ let plane1_half_texel = (vec2<f32>(0.5) / plane1_dims);
+ let plane1_clamped = clamp(modifiedCoords, plane1_half_texel, (1 - plane1_half_texel));
+ var color : vec3<f32>;
+ if ((params.numPlanes == 1)) {
+ color = textureSampleLevel(plane0, smp, plane0_clamped, 0).rgb;
+ } else {
+ color = (vec4<f32>(textureSampleLevel(plane0, smp, plane0_clamped, 0).r, textureSampleLevel(plane1, smp, plane1_clamped, 0).rg, 1) * params.yuvToRgbConversionMatrix);
+ }
+ if ((params.doYuvToRgbConversionOnly == 0)) {
+ color = gammaCorrection(color, params.gammaDecodeParams);
+ color = (params.gamutConversionMatrix * color);
+ color = gammaCorrection(color, params.gammaEncodeParams);
+ }
+ return vec4<f32>(color, 1);
+}
+
+fn f(t : texture_2d<f32>, ext_tex_plane_1_1 : texture_2d<f32>, ext_tex_params_1 : ExternalTextureParams, s : sampler) {
+ _ = textureSampleExternal(t, ext_tex_plane_1_1, s, vec2<f32>(1.0, 2.0), ext_tex_params_1);
+}
+
+@group(0) @binding(0) var ext_tex : texture_2d<f32>;
+
+@group(0) @binding(1) var smp : sampler;
+)";
+ Transform::DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(MultiplanarExternalTexture::BindingsMap{
+ {{0, 0}, {{0, 2}, {0, 3}}},
+ });
+ auto got = Run<MultiplanarExternalTexture>(src, data);
+ EXPECT_EQ(expect, str(got));
+}
+
+// Tests that the texture_external passed as a parameter not in the first
+// position produces the correct output.
+TEST_F(MultiplanarExternalTextureTest, ExternalTexturePassedAsSecondParam) {
+ auto* src = R"(
+fn f(s : sampler, t : texture_external) {
+ _ = textureSampleBaseClampToEdge(t, s, vec2<f32>(1.0, 2.0));
+}
+
+@group(0) @binding(0) var ext_tex : texture_external;
+@group(0) @binding(1) var smp : sampler;
+
+@fragment
+fn main() {
+ f(smp, ext_tex);
+}
+)";
+
+ auto* expect = R"(
+struct GammaTransferParams {
+ G : f32,
+ A : f32,
+ B : f32,
+ C : f32,
+ D : f32,
+ E : f32,
+ F : f32,
+ padding : u32,
+}
+
+struct ExternalTextureParams {
+ numPlanes : u32,
+ doYuvToRgbConversionOnly : u32,
+ yuvToRgbConversionMatrix : mat3x4<f32>,
+ gammaDecodeParams : GammaTransferParams,
+ gammaEncodeParams : GammaTransferParams,
+ gamutConversionMatrix : mat3x3<f32>,
+ coordTransformationMatrix : mat3x2<f32>,
+}
+
+@group(0) @binding(2) var ext_tex_plane_1 : texture_2d<f32>;
+
+@group(0) @binding(3) var<uniform> ext_tex_params : ExternalTextureParams;
+
+fn gammaCorrection(v : vec3<f32>, params : GammaTransferParams) -> vec3<f32> {
+ let cond = (abs(v) < vec3<f32>(params.D));
+ let t = (sign(v) * ((params.C * abs(v)) + params.F));
+ let f = (sign(v) * (pow(((params.A * abs(v)) + params.B), vec3<f32>(params.G)) + params.E));
+ return select(f, t, cond);
+}
+
+fn textureSampleExternal(plane0 : texture_2d<f32>, plane1 : texture_2d<f32>, smp : sampler, coord : vec2<f32>, params : ExternalTextureParams) -> vec4<f32> {
+ let modifiedCoords = (params.coordTransformationMatrix * vec3<f32>(coord, 1));
+ let plane0_dims = vec2<f32>(textureDimensions(plane0, 0));
+ let plane0_half_texel = (vec2<f32>(0.5) / plane0_dims);
+ let plane0_clamped = clamp(modifiedCoords, plane0_half_texel, (1 - plane0_half_texel));
+ let plane1_dims = vec2<f32>(textureDimensions(plane1, 0));
+ let plane1_half_texel = (vec2<f32>(0.5) / plane1_dims);
+ let plane1_clamped = clamp(modifiedCoords, plane1_half_texel, (1 - plane1_half_texel));
+ var color : vec3<f32>;
+ if ((params.numPlanes == 1)) {
+ color = textureSampleLevel(plane0, smp, plane0_clamped, 0).rgb;
+ } else {
+ color = (vec4<f32>(textureSampleLevel(plane0, smp, plane0_clamped, 0).r, textureSampleLevel(plane1, smp, plane1_clamped, 0).rg, 1) * params.yuvToRgbConversionMatrix);
+ }
+ if ((params.doYuvToRgbConversionOnly == 0)) {
+ color = gammaCorrection(color, params.gammaDecodeParams);
+ color = (params.gamutConversionMatrix * color);
+ color = gammaCorrection(color, params.gammaEncodeParams);
+ }
+ return vec4<f32>(color, 1);
+}
+
+fn f(s : sampler, t : texture_2d<f32>, ext_tex_plane_1_1 : texture_2d<f32>, ext_tex_params_1 : ExternalTextureParams) {
+ _ = textureSampleExternal(t, ext_tex_plane_1_1, s, vec2<f32>(1.0, 2.0), ext_tex_params_1);
+}
+
+@group(0) @binding(0) var ext_tex : texture_2d<f32>;
+
+@group(0) @binding(1) var smp : sampler;
+
+@fragment
+fn main() {
+ f(smp, ext_tex, ext_tex_plane_1, ext_tex_params);
+}
+)";
+ Transform::DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(MultiplanarExternalTexture::BindingsMap{
+ {{0, 0}, {{0, 2}, {0, 3}}},
+ });
+ auto got = Run<MultiplanarExternalTexture>(src, data);
+ EXPECT_EQ(expect, str(got));
+}
+
+// Tests that multiple texture_external params passed to a function produce the
+// correct output.
+TEST_F(MultiplanarExternalTextureTest, ExternalTexturePassedAsParamMultiple) {
+ auto* src = R"(
+fn f(t : texture_external, s : sampler, t2 : texture_external) {
+ _ = textureSampleBaseClampToEdge(t, s, vec2<f32>(1.0, 2.0));
+ _ = textureSampleBaseClampToEdge(t2, s, vec2<f32>(1.0, 2.0));
+}
+
+@group(0) @binding(0) var ext_tex : texture_external;
+@group(0) @binding(1) var smp : sampler;
+@group(0) @binding(2) var ext_tex2 : texture_external;
+
+@fragment
+fn main() {
+ f(ext_tex, smp, ext_tex2);
+}
+)";
+
+ auto* expect = R"(
+struct GammaTransferParams {
+ G : f32,
+ A : f32,
+ B : f32,
+ C : f32,
+ D : f32,
+ E : f32,
+ F : f32,
+ padding : u32,
+}
+
+struct ExternalTextureParams {
+ numPlanes : u32,
+ doYuvToRgbConversionOnly : u32,
+ yuvToRgbConversionMatrix : mat3x4<f32>,
+ gammaDecodeParams : GammaTransferParams,
+ gammaEncodeParams : GammaTransferParams,
+ gamutConversionMatrix : mat3x3<f32>,
+ coordTransformationMatrix : mat3x2<f32>,
+}
+
+@group(0) @binding(3) var ext_tex_plane_1 : texture_2d<f32>;
+
+@group(0) @binding(4) var<uniform> ext_tex_params : ExternalTextureParams;
+
+@group(0) @binding(5) var ext_tex_plane_1_1 : texture_2d<f32>;
+
+@group(0) @binding(6) var<uniform> ext_tex_params_1 : ExternalTextureParams;
+
+fn gammaCorrection(v : vec3<f32>, params : GammaTransferParams) -> vec3<f32> {
+ let cond = (abs(v) < vec3<f32>(params.D));
+ let t = (sign(v) * ((params.C * abs(v)) + params.F));
+ let f = (sign(v) * (pow(((params.A * abs(v)) + params.B), vec3<f32>(params.G)) + params.E));
+ return select(f, t, cond);
+}
+
+fn textureSampleExternal(plane0 : texture_2d<f32>, plane1 : texture_2d<f32>, smp : sampler, coord : vec2<f32>, params : ExternalTextureParams) -> vec4<f32> {
+ let modifiedCoords = (params.coordTransformationMatrix * vec3<f32>(coord, 1));
+ let plane0_dims = vec2<f32>(textureDimensions(plane0, 0));
+ let plane0_half_texel = (vec2<f32>(0.5) / plane0_dims);
+ let plane0_clamped = clamp(modifiedCoords, plane0_half_texel, (1 - plane0_half_texel));
+ let plane1_dims = vec2<f32>(textureDimensions(plane1, 0));
+ let plane1_half_texel = (vec2<f32>(0.5) / plane1_dims);
+ let plane1_clamped = clamp(modifiedCoords, plane1_half_texel, (1 - plane1_half_texel));
+ var color : vec3<f32>;
+ if ((params.numPlanes == 1)) {
+ color = textureSampleLevel(plane0, smp, plane0_clamped, 0).rgb;
+ } else {
+ color = (vec4<f32>(textureSampleLevel(plane0, smp, plane0_clamped, 0).r, textureSampleLevel(plane1, smp, plane1_clamped, 0).rg, 1) * params.yuvToRgbConversionMatrix);
+ }
+ if ((params.doYuvToRgbConversionOnly == 0)) {
+ color = gammaCorrection(color, params.gammaDecodeParams);
+ color = (params.gamutConversionMatrix * color);
+ color = gammaCorrection(color, params.gammaEncodeParams);
+ }
+ return vec4<f32>(color, 1);
+}
+
+fn f(t : texture_2d<f32>, ext_tex_plane_1_2 : texture_2d<f32>, ext_tex_params_2 : ExternalTextureParams, s : sampler, t2 : texture_2d<f32>, ext_tex_plane_1_3 : texture_2d<f32>, ext_tex_params_3 : ExternalTextureParams) {
+ _ = textureSampleExternal(t, ext_tex_plane_1_2, s, vec2<f32>(1.0, 2.0), ext_tex_params_2);
+ _ = textureSampleExternal(t2, ext_tex_plane_1_3, s, vec2<f32>(1.0, 2.0), ext_tex_params_3);
+}
+
+@group(0) @binding(0) var ext_tex : texture_2d<f32>;
+
+@group(0) @binding(1) var smp : sampler;
+
+@group(0) @binding(2) var ext_tex2 : texture_2d<f32>;
+
+@fragment
+fn main() {
+ f(ext_tex, ext_tex_plane_1, ext_tex_params, smp, ext_tex2, ext_tex_plane_1_1, ext_tex_params_1);
+}
+)";
+ Transform::DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(MultiplanarExternalTexture::BindingsMap{
+ {{0, 0}, {{0, 3}, {0, 4}}},
+ {{0, 2}, {{0, 5}, {0, 6}}},
+ });
+ auto got = Run<MultiplanarExternalTexture>(src, data);
+ EXPECT_EQ(expect, str(got));
+}
+
+// Tests that multiple texture_external params passed to a function produce the
+// correct output when the entry point is declared before the function and globals.
+TEST_F(MultiplanarExternalTextureTest, ExternalTexturePassedAsParamMultiple_OutOfOrder) {
+ auto* src = R"(
+@fragment
+fn main() {
+ f(ext_tex, smp, ext_tex2);
+}
+
+fn f(t : texture_external, s : sampler, t2 : texture_external) {
+ _ = textureSampleBaseClampToEdge(t, s, vec2<f32>(1.0, 2.0));
+ _ = textureSampleBaseClampToEdge(t2, s, vec2<f32>(1.0, 2.0));
+}
+
+@group(0) @binding(0) var ext_tex : texture_external;
+@group(0) @binding(1) var smp : sampler;
+@group(0) @binding(2) var ext_tex2 : texture_external;
+
+)";
+
+ auto* expect = R"(
+struct GammaTransferParams {
+ G : f32,
+ A : f32,
+ B : f32,
+ C : f32,
+ D : f32,
+ E : f32,
+ F : f32,
+ padding : u32,
+}
+
+struct ExternalTextureParams {
+ numPlanes : u32,
+ doYuvToRgbConversionOnly : u32,
+ yuvToRgbConversionMatrix : mat3x4<f32>,
+ gammaDecodeParams : GammaTransferParams,
+ gammaEncodeParams : GammaTransferParams,
+ gamutConversionMatrix : mat3x3<f32>,
+ coordTransformationMatrix : mat3x2<f32>,
+}
+
+@group(0) @binding(3) var ext_tex_plane_1 : texture_2d<f32>;
+
+@group(0) @binding(4) var<uniform> ext_tex_params : ExternalTextureParams;
+
+@group(0) @binding(5) var ext_tex_plane_1_1 : texture_2d<f32>;
+
+@group(0) @binding(6) var<uniform> ext_tex_params_1 : ExternalTextureParams;
+
+@fragment
+fn main() {
+ f(ext_tex, ext_tex_plane_1, ext_tex_params, smp, ext_tex2, ext_tex_plane_1_1, ext_tex_params_1);
+}
+
+fn gammaCorrection(v : vec3<f32>, params : GammaTransferParams) -> vec3<f32> {
+ let cond = (abs(v) < vec3<f32>(params.D));
+ let t = (sign(v) * ((params.C * abs(v)) + params.F));
+ let f = (sign(v) * (pow(((params.A * abs(v)) + params.B), vec3<f32>(params.G)) + params.E));
+ return select(f, t, cond);
+}
+
+fn textureSampleExternal(plane0 : texture_2d<f32>, plane1 : texture_2d<f32>, smp : sampler, coord : vec2<f32>, params : ExternalTextureParams) -> vec4<f32> {
+ let modifiedCoords = (params.coordTransformationMatrix * vec3<f32>(coord, 1));
+ let plane0_dims = vec2<f32>(textureDimensions(plane0, 0));
+ let plane0_half_texel = (vec2<f32>(0.5) / plane0_dims);
+ let plane0_clamped = clamp(modifiedCoords, plane0_half_texel, (1 - plane0_half_texel));
+ let plane1_dims = vec2<f32>(textureDimensions(plane1, 0));
+ let plane1_half_texel = (vec2<f32>(0.5) / plane1_dims);
+ let plane1_clamped = clamp(modifiedCoords, plane1_half_texel, (1 - plane1_half_texel));
+ var color : vec3<f32>;
+ if ((params.numPlanes == 1)) {
+ color = textureSampleLevel(plane0, smp, plane0_clamped, 0).rgb;
+ } else {
+ color = (vec4<f32>(textureSampleLevel(plane0, smp, plane0_clamped, 0).r, textureSampleLevel(plane1, smp, plane1_clamped, 0).rg, 1) * params.yuvToRgbConversionMatrix);
+ }
+ if ((params.doYuvToRgbConversionOnly == 0)) {
+ color = gammaCorrection(color, params.gammaDecodeParams);
+ color = (params.gamutConversionMatrix * color);
+ color = gammaCorrection(color, params.gammaEncodeParams);
+ }
+ return vec4<f32>(color, 1);
+}
+
+fn f(t : texture_2d<f32>, ext_tex_plane_1_2 : texture_2d<f32>, ext_tex_params_2 : ExternalTextureParams, s : sampler, t2 : texture_2d<f32>, ext_tex_plane_1_3 : texture_2d<f32>, ext_tex_params_3 : ExternalTextureParams) {
+ _ = textureSampleExternal(t, ext_tex_plane_1_2, s, vec2<f32>(1.0, 2.0), ext_tex_params_2);
+ _ = textureSampleExternal(t2, ext_tex_plane_1_3, s, vec2<f32>(1.0, 2.0), ext_tex_params_3);
+}
+
+@group(0) @binding(0) var ext_tex : texture_2d<f32>;
+
+@group(0) @binding(1) var smp : sampler;
+
+@group(0) @binding(2) var ext_tex2 : texture_2d<f32>;
+)";
+ Transform::DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(MultiplanarExternalTexture::BindingsMap{
+ {{0, 0}, {{0, 3}, {0, 4}}},
+ {{0, 2}, {{0, 5}, {0, 6}}},
+ });
+ auto got = Run<MultiplanarExternalTexture>(src, data);
+ EXPECT_EQ(expect, str(got));
+}
+
+// Tests that a texture_external passed as a parameter through nested function
+// calls produces the correct output.
+TEST_F(MultiplanarExternalTextureTest, ExternalTexturePassedAsParamNested) {
+ auto* src = R"(
+fn nested(t : texture_external, s : sampler) {
+ _ = textureSampleBaseClampToEdge(t, s, vec2<f32>(1.0, 2.0));
+}
+
+fn f(t : texture_external, s : sampler) {
+ nested(t, s);
+}
+
+@group(0) @binding(0) var ext_tex : texture_external;
+@group(0) @binding(1) var smp : sampler;
+
+@fragment
+fn main() {
+ f(ext_tex, smp);
+}
+)";
+
+ auto* expect = R"(
+struct GammaTransferParams {
+ G : f32,
+ A : f32,
+ B : f32,
+ C : f32,
+ D : f32,
+ E : f32,
+ F : f32,
+ padding : u32,
+}
+
+struct ExternalTextureParams {
+ numPlanes : u32,
+ doYuvToRgbConversionOnly : u32,
+ yuvToRgbConversionMatrix : mat3x4<f32>,
+ gammaDecodeParams : GammaTransferParams,
+ gammaEncodeParams : GammaTransferParams,
+ gamutConversionMatrix : mat3x3<f32>,
+ coordTransformationMatrix : mat3x2<f32>,
+}
+
+@group(0) @binding(2) var ext_tex_plane_1 : texture_2d<f32>;
+
+@group(0) @binding(3) var<uniform> ext_tex_params : ExternalTextureParams;
+
+fn gammaCorrection(v : vec3<f32>, params : GammaTransferParams) -> vec3<f32> {
+ let cond = (abs(v) < vec3<f32>(params.D));
+ let t = (sign(v) * ((params.C * abs(v)) + params.F));
+ let f = (sign(v) * (pow(((params.A * abs(v)) + params.B), vec3<f32>(params.G)) + params.E));
+ return select(f, t, cond);
+}
+
+fn textureSampleExternal(plane0 : texture_2d<f32>, plane1 : texture_2d<f32>, smp : sampler, coord : vec2<f32>, params : ExternalTextureParams) -> vec4<f32> {
+ let modifiedCoords = (params.coordTransformationMatrix * vec3<f32>(coord, 1));
+ let plane0_dims = vec2<f32>(textureDimensions(plane0, 0));
+ let plane0_half_texel = (vec2<f32>(0.5) / plane0_dims);
+ let plane0_clamped = clamp(modifiedCoords, plane0_half_texel, (1 - plane0_half_texel));
+ let plane1_dims = vec2<f32>(textureDimensions(plane1, 0));
+ let plane1_half_texel = (vec2<f32>(0.5) / plane1_dims);
+ let plane1_clamped = clamp(modifiedCoords, plane1_half_texel, (1 - plane1_half_texel));
+ var color : vec3<f32>;
+ if ((params.numPlanes == 1)) {
+ color = textureSampleLevel(plane0, smp, plane0_clamped, 0).rgb;
+ } else {
+ color = (vec4<f32>(textureSampleLevel(plane0, smp, plane0_clamped, 0).r, textureSampleLevel(plane1, smp, plane1_clamped, 0).rg, 1) * params.yuvToRgbConversionMatrix);
+ }
+ if ((params.doYuvToRgbConversionOnly == 0)) {
+ color = gammaCorrection(color, params.gammaDecodeParams);
+ color = (params.gamutConversionMatrix * color);
+ color = gammaCorrection(color, params.gammaEncodeParams);
+ }
+ return vec4<f32>(color, 1);
+}
+
+fn nested(t : texture_2d<f32>, ext_tex_plane_1_1 : texture_2d<f32>, ext_tex_params_1 : ExternalTextureParams, s : sampler) {
+ _ = textureSampleExternal(t, ext_tex_plane_1_1, s, vec2<f32>(1.0, 2.0), ext_tex_params_1);
+}
+
+fn f(t : texture_2d<f32>, ext_tex_plane_1_2 : texture_2d<f32>, ext_tex_params_2 : ExternalTextureParams, s : sampler) {
+ nested(t, ext_tex_plane_1_2, ext_tex_params_2, s);
+}
+
+@group(0) @binding(0) var ext_tex : texture_2d<f32>;
+
+@group(0) @binding(1) var smp : sampler;
+
+@fragment
+fn main() {
+ f(ext_tex, ext_tex_plane_1, ext_tex_params, smp);
+}
+)";
+ Transform::DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(MultiplanarExternalTexture::BindingsMap{
+ {{0, 0}, {{0, 2}, {0, 3}}},
+ });
+ auto got = Run<MultiplanarExternalTexture>(src, data);
+ EXPECT_EQ(expect, str(got));
+}
+
+// Tests that a texture_external passed as a parameter through nested function calls produces
+// the correct output. NOTE(review): src matches the in-order test; confirm the intended order.
+TEST_F(MultiplanarExternalTextureTest, ExternalTexturePassedAsParamNested_OutOfOrder) {
+ auto* src = R"(
+fn nested(t : texture_external, s : sampler) {
+ _ = textureSampleBaseClampToEdge(t, s, vec2<f32>(1.0, 2.0));
+}
+
+fn f(t : texture_external, s : sampler) {
+ nested(t, s);
+}
+
+@group(0) @binding(0) var ext_tex : texture_external;
+@group(0) @binding(1) var smp : sampler;
+
+@fragment
+fn main() {
+ f(ext_tex, smp);
+}
+)";
+
+ auto* expect = R"(
+struct GammaTransferParams {
+ G : f32,
+ A : f32,
+ B : f32,
+ C : f32,
+ D : f32,
+ E : f32,
+ F : f32,
+ padding : u32,
+}
+
+struct ExternalTextureParams {
+ numPlanes : u32,
+ doYuvToRgbConversionOnly : u32,
+ yuvToRgbConversionMatrix : mat3x4<f32>,
+ gammaDecodeParams : GammaTransferParams,
+ gammaEncodeParams : GammaTransferParams,
+ gamutConversionMatrix : mat3x3<f32>,
+ coordTransformationMatrix : mat3x2<f32>,
+}
+
+@group(0) @binding(2) var ext_tex_plane_1 : texture_2d<f32>;
+
+@group(0) @binding(3) var<uniform> ext_tex_params : ExternalTextureParams;
+
+fn gammaCorrection(v : vec3<f32>, params : GammaTransferParams) -> vec3<f32> {
+ let cond = (abs(v) < vec3<f32>(params.D));
+ let t = (sign(v) * ((params.C * abs(v)) + params.F));
+ let f = (sign(v) * (pow(((params.A * abs(v)) + params.B), vec3<f32>(params.G)) + params.E));
+ return select(f, t, cond);
+}
+
+fn textureSampleExternal(plane0 : texture_2d<f32>, plane1 : texture_2d<f32>, smp : sampler, coord : vec2<f32>, params : ExternalTextureParams) -> vec4<f32> {
+ let modifiedCoords = (params.coordTransformationMatrix * vec3<f32>(coord, 1));
+ let plane0_dims = vec2<f32>(textureDimensions(plane0, 0));
+ let plane0_half_texel = (vec2<f32>(0.5) / plane0_dims);
+ let plane0_clamped = clamp(modifiedCoords, plane0_half_texel, (1 - plane0_half_texel));
+ let plane1_dims = vec2<f32>(textureDimensions(plane1, 0));
+ let plane1_half_texel = (vec2<f32>(0.5) / plane1_dims);
+ let plane1_clamped = clamp(modifiedCoords, plane1_half_texel, (1 - plane1_half_texel));
+ var color : vec3<f32>;
+ if ((params.numPlanes == 1)) {
+ color = textureSampleLevel(plane0, smp, plane0_clamped, 0).rgb;
+ } else {
+ color = (vec4<f32>(textureSampleLevel(plane0, smp, plane0_clamped, 0).r, textureSampleLevel(plane1, smp, plane1_clamped, 0).rg, 1) * params.yuvToRgbConversionMatrix);
+ }
+ if ((params.doYuvToRgbConversionOnly == 0)) {
+ color = gammaCorrection(color, params.gammaDecodeParams);
+ color = (params.gamutConversionMatrix * color);
+ color = gammaCorrection(color, params.gammaEncodeParams);
+ }
+ return vec4<f32>(color, 1);
+}
+
+fn nested(t : texture_2d<f32>, ext_tex_plane_1_1 : texture_2d<f32>, ext_tex_params_1 : ExternalTextureParams, s : sampler) {
+ _ = textureSampleExternal(t, ext_tex_plane_1_1, s, vec2<f32>(1.0, 2.0), ext_tex_params_1);
+}
+
+fn f(t : texture_2d<f32>, ext_tex_plane_1_2 : texture_2d<f32>, ext_tex_params_2 : ExternalTextureParams, s : sampler) {
+ nested(t, ext_tex_plane_1_2, ext_tex_params_2, s);
+}
+
+@group(0) @binding(0) var ext_tex : texture_2d<f32>;
+
+@group(0) @binding(1) var smp : sampler;
+
+@fragment
+fn main() {
+ f(ext_tex, ext_tex_plane_1, ext_tex_params, smp);
+}
+)";
+ Transform::DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(MultiplanarExternalTexture::BindingsMap{
+ {{0, 0}, {{0, 2}, {0, 3}}},
+ });
+ auto got = Run<MultiplanarExternalTexture>(src, data);
+ EXPECT_EQ(expect, str(got));
+}
+
+// Tests that the transform works with a function using an external texture,
+// even if there's no external texture declared at module scope.
+TEST_F(MultiplanarExternalTextureTest, ExternalTexturePassedAsParamWithoutGlobalDecl) {
+ auto* src = R"(
+fn f(ext_tex : texture_external) -> vec2<u32> {
+ return textureDimensions(ext_tex);
+}
+)";
+
+ auto* expect = R"(
+struct GammaTransferParams {
+ G : f32,
+ A : f32,
+ B : f32,
+ C : f32,
+ D : f32,
+ E : f32,
+ F : f32,
+ padding : u32,
+}
+
+struct ExternalTextureParams {
+ numPlanes : u32,
+ doYuvToRgbConversionOnly : u32,
+ yuvToRgbConversionMatrix : mat3x4<f32>,
+ gammaDecodeParams : GammaTransferParams,
+ gammaEncodeParams : GammaTransferParams,
+ gamutConversionMatrix : mat3x3<f32>,
+ coordTransformationMatrix : mat3x2<f32>,
+}
+
+fn f(ext_tex : texture_2d<f32>, ext_tex_plane_1 : texture_2d<f32>, ext_tex_params : ExternalTextureParams) -> vec2<u32> {
+ return textureDimensions(ext_tex);
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(
+ MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}});
+ auto got = Run<MultiplanarExternalTexture>(src, data);
+ EXPECT_EQ(expect, str(got));
+}
+
+// Tests that the transform handles aliases to external textures.
+TEST_F(MultiplanarExternalTextureTest, ExternalTextureAlias) {
+ auto* src = R"(
+alias ET = texture_external;
+
+fn f(t : ET, s : sampler) {
+ _ = textureSampleBaseClampToEdge(t, s, vec2<f32>(1.0, 2.0));
+}
+
+@group(0) @binding(0) var ext_tex : ET;
+@group(0) @binding(1) var smp : sampler;
+
+@fragment
+fn main() {
+ f(ext_tex, smp);
+}
+)";
+
+ auto* expect = R"(
+struct GammaTransferParams {
+ G : f32,
+ A : f32,
+ B : f32,
+ C : f32,
+ D : f32,
+ E : f32,
+ F : f32,
+ padding : u32,
+}
+
+struct ExternalTextureParams {
+ numPlanes : u32,
+ doYuvToRgbConversionOnly : u32,
+ yuvToRgbConversionMatrix : mat3x4<f32>,
+ gammaDecodeParams : GammaTransferParams,
+ gammaEncodeParams : GammaTransferParams,
+ gamutConversionMatrix : mat3x3<f32>,
+ coordTransformationMatrix : mat3x2<f32>,
+}
+
+@group(0) @binding(2) var ext_tex_plane_1 : texture_2d<f32>;
+
+@group(0) @binding(3) var<uniform> ext_tex_params : ExternalTextureParams;
+
+alias ET = texture_external;
+
+fn gammaCorrection(v : vec3<f32>, params : GammaTransferParams) -> vec3<f32> {
+ let cond = (abs(v) < vec3<f32>(params.D));
+ let t = (sign(v) * ((params.C * abs(v)) + params.F));
+ let f = (sign(v) * (pow(((params.A * abs(v)) + params.B), vec3<f32>(params.G)) + params.E));
+ return select(f, t, cond);
+}
+
+fn textureSampleExternal(plane0 : texture_2d<f32>, plane1 : texture_2d<f32>, smp : sampler, coord : vec2<f32>, params : ExternalTextureParams) -> vec4<f32> {
+ let modifiedCoords = (params.coordTransformationMatrix * vec3<f32>(coord, 1));
+ let plane0_dims = vec2<f32>(textureDimensions(plane0, 0));
+ let plane0_half_texel = (vec2<f32>(0.5) / plane0_dims);
+ let plane0_clamped = clamp(modifiedCoords, plane0_half_texel, (1 - plane0_half_texel));
+ let plane1_dims = vec2<f32>(textureDimensions(plane1, 0));
+ let plane1_half_texel = (vec2<f32>(0.5) / plane1_dims);
+ let plane1_clamped = clamp(modifiedCoords, plane1_half_texel, (1 - plane1_half_texel));
+ var color : vec3<f32>;
+ if ((params.numPlanes == 1)) {
+ color = textureSampleLevel(plane0, smp, plane0_clamped, 0).rgb;
+ } else {
+ color = (vec4<f32>(textureSampleLevel(plane0, smp, plane0_clamped, 0).r, textureSampleLevel(plane1, smp, plane1_clamped, 0).rg, 1) * params.yuvToRgbConversionMatrix);
+ }
+ if ((params.doYuvToRgbConversionOnly == 0)) {
+ color = gammaCorrection(color, params.gammaDecodeParams);
+ color = (params.gamutConversionMatrix * color);
+ color = gammaCorrection(color, params.gammaEncodeParams);
+ }
+ return vec4<f32>(color, 1);
+}
+
+fn f(t : texture_2d<f32>, ext_tex_plane_1_1 : texture_2d<f32>, ext_tex_params_1 : ExternalTextureParams, s : sampler) {
+ _ = textureSampleExternal(t, ext_tex_plane_1_1, s, vec2<f32>(1.0, 2.0), ext_tex_params_1);
+}
+
+@group(0) @binding(0) var ext_tex : texture_2d<f32>;
+
+@group(0) @binding(1) var smp : sampler;
+
+@fragment
+fn main() {
+ f(ext_tex, ext_tex_plane_1, ext_tex_params, smp);
+}
+)";
+ Transform::DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(MultiplanarExternalTexture::BindingsMap{
+ {{0, 0}, {{0, 2}, {0, 3}}},
+ });
+ auto got = Run<MultiplanarExternalTexture>(src, data);
+ EXPECT_EQ(expect, str(got));
+}
+
+// Tests that the transform handles aliases to external textures declared after their use.
+TEST_F(MultiplanarExternalTextureTest, ExternalTextureAlias_OutOfOrder) {
+ auto* src = R"(
+@fragment
+fn main() {
+ f(ext_tex, smp);
+}
+
+fn f(t : ET, s : sampler) {
+ _ = textureSampleBaseClampToEdge(t, s, vec2<f32>(1.0, 2.0));
+}
+
+@group(0) @binding(0) var ext_tex : ET;
+@group(0) @binding(1) var smp : sampler;
+
+alias ET = texture_external;
+)";
+
+ auto* expect = R"(
+struct GammaTransferParams {
+ G : f32,
+ A : f32,
+ B : f32,
+ C : f32,
+ D : f32,
+ E : f32,
+ F : f32,
+ padding : u32,
+}
+
+struct ExternalTextureParams {
+ numPlanes : u32,
+ doYuvToRgbConversionOnly : u32,
+ yuvToRgbConversionMatrix : mat3x4<f32>,
+ gammaDecodeParams : GammaTransferParams,
+ gammaEncodeParams : GammaTransferParams,
+ gamutConversionMatrix : mat3x3<f32>,
+ coordTransformationMatrix : mat3x2<f32>,
+}
+
+@group(0) @binding(2) var ext_tex_plane_1 : texture_2d<f32>;
+
+@group(0) @binding(3) var<uniform> ext_tex_params : ExternalTextureParams;
+
+@fragment
+fn main() {
+ f(ext_tex, ext_tex_plane_1, ext_tex_params, smp);
+}
+
+fn gammaCorrection(v : vec3<f32>, params : GammaTransferParams) -> vec3<f32> {
+ let cond = (abs(v) < vec3<f32>(params.D));
+ let t = (sign(v) * ((params.C * abs(v)) + params.F));
+ let f = (sign(v) * (pow(((params.A * abs(v)) + params.B), vec3<f32>(params.G)) + params.E));
+ return select(f, t, cond);
+}
+
+fn textureSampleExternal(plane0 : texture_2d<f32>, plane1 : texture_2d<f32>, smp : sampler, coord : vec2<f32>, params : ExternalTextureParams) -> vec4<f32> {
+ let modifiedCoords = (params.coordTransformationMatrix * vec3<f32>(coord, 1));
+ let plane0_dims = vec2<f32>(textureDimensions(plane0, 0));
+ let plane0_half_texel = (vec2<f32>(0.5) / plane0_dims);
+ let plane0_clamped = clamp(modifiedCoords, plane0_half_texel, (1 - plane0_half_texel));
+ let plane1_dims = vec2<f32>(textureDimensions(plane1, 0));
+ let plane1_half_texel = (vec2<f32>(0.5) / plane1_dims);
+ let plane1_clamped = clamp(modifiedCoords, plane1_half_texel, (1 - plane1_half_texel));
+ var color : vec3<f32>;
+ if ((params.numPlanes == 1)) {
+ color = textureSampleLevel(plane0, smp, plane0_clamped, 0).rgb;
+ } else {
+ color = (vec4<f32>(textureSampleLevel(plane0, smp, plane0_clamped, 0).r, textureSampleLevel(plane1, smp, plane1_clamped, 0).rg, 1) * params.yuvToRgbConversionMatrix);
+ }
+ if ((params.doYuvToRgbConversionOnly == 0)) {
+ color = gammaCorrection(color, params.gammaDecodeParams);
+ color = (params.gamutConversionMatrix * color);
+ color = gammaCorrection(color, params.gammaEncodeParams);
+ }
+ return vec4<f32>(color, 1);
+}
+
+fn f(t : texture_2d<f32>, ext_tex_plane_1_1 : texture_2d<f32>, ext_tex_params_1 : ExternalTextureParams, s : sampler) {
+ _ = textureSampleExternal(t, ext_tex_plane_1_1, s, vec2<f32>(1.0, 2.0), ext_tex_params_1);
+}
+
+@group(0) @binding(0) var ext_tex : texture_2d<f32>;
+
+@group(0) @binding(1) var smp : sampler;
+
+alias ET = texture_external;
+)";
+ Transform::DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(MultiplanarExternalTexture::BindingsMap{
+ {{0, 0}, {{0, 2}, {0, 3}}},
+ });
+ auto got = Run<MultiplanarExternalTexture>(src, data);
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/num_workgroups_from_uniform.cc b/src/tint/lang/wgsl/ast/transform/num_workgroups_from_uniform.cc
new file mode 100644
index 0000000..6360631
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/num_workgroups_from_uniform.cc
@@ -0,0 +1,193 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/num_workgroups_from_uniform.h"
+
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include <utility>
+
+#include "src/tint/builtin/builtin_value.h"
+#include "src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.h"
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/function.h"
+#include "src/tint/utils/hash.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::NumWorkgroupsFromUniform);
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::NumWorkgroupsFromUniform::Config);
+
+namespace tint::ast::transform {
+namespace {
+
+/// @returns true if `program` has a builtin attribute that resolves to num_workgroups
+bool ShouldRun(const Program* program) {
+    for (auto* node : program->ASTNodes().Objects()) {
+        auto* attr = node->As<BuiltinAttribute>();
+        if (attr && program->Sem().Get(attr)->Value() == builtin::BuiltinValue::kNumWorkgroups) {
+            return true;
+        }
+    }
+    return false;
+}
+
+/// Accessor identifies a `param.member` accessor expression by the symbols of the parameter
+/// and of the struct member holding the num_workgroups builtin.
+struct Accessor {
+    Symbol param;   // symbol of the entry point parameter
+    Symbol member;  // symbol of the struct member
+
+    /// @returns true if both symbols of `rhs` match this accessor's
+    bool operator==(const Accessor& rhs) const {
+        return (param == rhs.param) && (member == rhs.member);
+    }
+    /// Hasher combines both symbols so Accessor can key an unordered container
+    struct Hasher {
+        size_t operator()(const Accessor& key) const { return utils::Hash(key.param, key.member); }
+    };
+};
+
+} // namespace
+
+NumWorkgroupsFromUniform::NumWorkgroupsFromUniform() = default;
+NumWorkgroupsFromUniform::~NumWorkgroupsFromUniform() = default;
+
+Transform::ApplyResult NumWorkgroupsFromUniform::Apply(const Program* src,
+ const DataMap& inputs,
+ DataMap&) const {
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+
+ auto* cfg = inputs.Get<Config>();
+ if (cfg == nullptr) {
+ b.Diagnostics().add_error(diag::System::Transform,
+ "missing transform data for " + std::string(TypeInfo().name));
+ return Program(std::move(b));
+ }
+
+ if (!ShouldRun(src)) {
+ return SkipTransform;
+ }
+
+ const char* kNumWorkgroupsMemberName = "num_workgroups";
+
+ // Find all entry point parameters that declare the num_workgroups builtin.
+ std::unordered_set<Accessor, Accessor::Hasher> to_replace;
+ for (auto* func : src->AST().Functions()) {
+ // Skip non-compute functions: num_workgroups is only valid for compute stages.
+ if (func->PipelineStage() != PipelineStage::kCompute) {
+ continue;
+ }
+
+ for (auto* param : src->Sem().Get(func)->Parameters()) {
+ // Because the CanonicalizeEntryPointIO transform has been run, builtins
+ // will only appear as struct members.
+ auto* str = param->Type()->As<sem::Struct>();
+ if (!str) {
+ continue;
+ }
+
+ for (auto* member : str->Members()) {
+ if (member->Attributes().builtin != builtin::BuiltinValue::kNumWorkgroups) {
+ continue;
+ }
+
+ // Capture the symbols that would be used to access this member, which
+ // we will replace later. We currently have no way to get from the
+ // parameter directly to the member accessor expressions that use it.
+ to_replace.insert({param->Declaration()->name->symbol, member->Name()});
+
+ // Remove the struct member.
+ // The CanonicalizeEntryPointIO transform will have generated this
+ // struct uniquely for this particular entry point, so we know that
+ // there will be no other uses of this struct in the module and that we
+ // can safely modify it here.
+ ctx.Remove(str->Declaration()->members, member->Declaration());
+
+ // If this is the only member, remove the struct and parameter too.
+ if (str->Members().Length() == 1) {
+ ctx.Remove(func->params, param->Declaration());
+ ctx.Remove(src->AST().GlobalDeclarations(), str->Declaration());
+ }
+ }
+ }
+ }
+
+ // Get (or create, on first call) the uniform buffer that will receive the
+ // number of workgroups.
+ const Variable* num_workgroups_ubo = nullptr;
+ auto get_ubo = [&] {
+ if (!num_workgroups_ubo) {
+ auto* num_workgroups_struct =
+ b.Structure(b.Sym(), utils::Vector{
+ b.Member(kNumWorkgroupsMemberName, b.ty.vec3(b.ty.u32())),
+ });
+
+ uint32_t group, binding;
+ if (cfg->ubo_binding.has_value()) {
+ // If cfg->ubo_binding holds a value, use the specified binding point.
+ group = cfg->ubo_binding->group;
+ binding = cfg->ubo_binding->binding;
+ } else {
+ // If cfg->ubo_binding holds no value, use binding 0 of the group after the largest
+ // group used by any global variable, or group 0 if no resource is bound.
+ group = 0;
+
+ for (auto* global : src->AST().GlobalVariables()) {
+ auto* global_sem = src->Sem().Get<sem::GlobalVariable>(global);
+ if (auto bp = global_sem->BindingPoint()) {
+ if (bp->group >= group) {
+ group = bp->group + 1;
+ }
+ }
+ }
+
+ binding = 0;
+ }
+
+ num_workgroups_ubo = b.GlobalVar(b.Sym(), b.ty.Of(num_workgroups_struct),
+ builtin::AddressSpace::kUniform, b.Group(AInt(group)),
+ b.Binding(AInt(binding)));
+ }
+ return num_workgroups_ubo;
+ };
+
+ // Now replace every member accessor recorded in `to_replace` with an access of the
+ // num_workgroups member of the uniform buffer (created lazily by get_ubo()).
+ for (auto* node : src->ASTNodes().Objects()) {
+ auto* accessor = node->As<MemberAccessorExpression>();
+ if (!accessor) {
+ continue;
+ }
+ auto* ident = accessor->object->As<IdentifierExpression>();
+ if (!ident) {
+ continue;
+ }
+
+ if (to_replace.count({ident->identifier->symbol, accessor->member->symbol})) {
+ ctx.Replace(accessor,
+ b.MemberAccessor(get_ubo()->name->symbol, kNumWorkgroupsMemberName));
+ }
+ }
+
+ ctx.Clone();
+ return Program(std::move(b));
+}
+
+NumWorkgroupsFromUniform::Config::Config(std::optional<sem::BindingPoint> ubo_bp)
+ : ubo_binding(ubo_bp) {}
+NumWorkgroupsFromUniform::Config::Config(const Config&) = default;
+NumWorkgroupsFromUniform::Config::~Config() = default;
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/num_workgroups_from_uniform.h b/src/tint/lang/wgsl/ast/transform/num_workgroups_from_uniform.h
new file mode 100644
index 0000000..947ece5
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/num_workgroups_from_uniform.h
@@ -0,0 +1,83 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_NUM_WORKGROUPS_FROM_UNIFORM_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_NUM_WORKGROUPS_FROM_UNIFORM_H_
+
+#include <optional>
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+#include "src/tint/sem/binding_point.h"
+
+// Forward declarations
+namespace tint {
+class CloneContext;
+} // namespace tint
+
+namespace tint::ast::transform {
+
+/// NumWorkgroupsFromUniform is a transform that implements the `num_workgroups`
+/// builtin by loading it from a uniform buffer.
+///
+/// The generated uniform buffer will have the form (names are illustrative):
+/// ```
+/// struct num_workgroups_struct {
+/// num_workgroups : vec3<u32>,
+/// }
+///
+/// @group(0) @binding(0)
+/// var<uniform> num_workgroups_ubo : num_workgroups_struct;
+/// ```
+/// The binding group and number used for this uniform buffer are taken from the
+/// `Config` transform input, or chosen automatically if it holds no binding point.
+///
+/// @note Depends on the following transforms to have been run first:
+/// * CanonicalizeEntryPointIO
+class NumWorkgroupsFromUniform final : public utils::Castable<NumWorkgroupsFromUniform, Transform> {
+ public:
+ /// Constructor
+ NumWorkgroupsFromUniform();
+ /// Destructor
+ ~NumWorkgroupsFromUniform() override;
+
+ /// Configuration options for the NumWorkgroupsFromUniform transform.
+ struct Config final : public utils::Castable<Config, Data> {
+ /// Constructor
+ /// @param ubo_bp the binding point to use for the generated uniform buffer. If ubo_bp
+ /// contains no value, a free binding point will be used to ensure the generated program is
+ /// valid. Specifically, binding 0 of the largest used group plus 1 is used if at least one
+ /// resource is bound, otherwise group 0 binding 0 is used.
+ explicit Config(std::optional<sem::BindingPoint> ubo_bp);
+
+ /// Copy constructor
+ Config(const Config&);
+
+ /// Destructor
+ ~Config() override;
+
+ /// The binding point to use for the generated uniform buffer. If this holds no value,
+ /// a free binding point will be used. Specifically, binding 0 of the largest used group
+ /// plus 1 is used if at least one resource is bound, otherwise group 0 binding 0 is used.
+ std::optional<sem::BindingPoint> ubo_binding;
+ };
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_NUM_WORKGROUPS_FROM_UNIFORM_H_
diff --git a/src/tint/lang/wgsl/ast/transform/num_workgroups_from_uniform_test.cc b/src/tint/lang/wgsl/ast/transform/num_workgroups_from_uniform_test.cc
new file mode 100644
index 0000000..b547355
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/num_workgroups_from_uniform_test.cc
@@ -0,0 +1,696 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/num_workgroups_from_uniform.h"
+
+#include <utility>
+
+#include "src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.h"
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+#include "src/tint/lang/wgsl/ast/transform/unshadow.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using NumWorkgroupsFromUniformTest = TransformTest;
+
+TEST_F(NumWorkgroupsFromUniformTest, ShouldRunEmptyModule) {  // Empty module: expect no run.
+ auto* src = R"()";
+
+ Transform::DataMap data;
+ data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
+ EXPECT_FALSE(ShouldRun<NumWorkgroupsFromUniform>(src, data));
+}
+
+TEST_F(NumWorkgroupsFromUniformTest, ShouldRunHasNumWorkgroups) {  // Builtin used: expect run.
+ auto* src = R"(
+@compute @workgroup_size(1)
+fn main(@builtin(num_workgroups) num_wgs : vec3<u32>) {
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
+ EXPECT_TRUE(ShouldRun<NumWorkgroupsFromUniform>(src, data));
+}
+
+TEST_F(NumWorkgroupsFromUniformTest, Error_MissingTransformData) {  // No Config added => error.
+ auto* src = R"(
+@compute @workgroup_size(1)
+fn main(@builtin(num_workgroups) num_wgs : vec3<u32>) {
+}
+)";
+
+ auto* expect =
+ "error: missing transform data for tint::ast::transform::NumWorkgroupsFromUniform";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(src, data);
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(NumWorkgroupsFromUniformTest, Basic) {  // Builtin is a direct entry-point parameter.
+ auto* src = R"(
+@compute @workgroup_size(1)
+fn main(@builtin(num_workgroups) num_wgs : vec3<u32>) {
+ let groups_x = num_wgs.x;
+ let groups_y = num_wgs.y;
+ let groups_z = num_wgs.z;
+}
+)";
+
+ auto* expect = R"(
+struct tint_symbol_2 {
+ num_workgroups : vec3<u32>,
+}
+
+@group(0) @binding(30) var<uniform> tint_symbol_3 : tint_symbol_2;
+
+fn main_inner(num_wgs : vec3<u32>) {
+ let groups_x = num_wgs.x;
+ let groups_y = num_wgs.y;
+ let groups_z = num_wgs.z;
+}
+
+@compute @workgroup_size(1)
+fn main() {
+ main_inner(tint_symbol_3.num_workgroups);
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(src, data);
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(NumWorkgroupsFromUniformTest, StructOnlyMember) {  // Builtin is the struct's only member.
+ auto* src = R"(
+struct Builtins {
+ @builtin(num_workgroups) num_wgs : vec3<u32>,
+};
+
+@compute @workgroup_size(1)
+fn main(in : Builtins) {
+ let groups_x = in.num_wgs.x;
+ let groups_y = in.num_wgs.y;
+ let groups_z = in.num_wgs.z;
+}
+)";
+
+ auto* expect = R"(
+struct tint_symbol_2 {
+ num_workgroups : vec3<u32>,
+}
+
+@group(0) @binding(30) var<uniform> tint_symbol_3 : tint_symbol_2;
+
+struct Builtins {
+ num_wgs : vec3<u32>,
+}
+
+fn main_inner(in : Builtins) {
+ let groups_x = in.num_wgs.x;
+ let groups_y = in.num_wgs.y;
+ let groups_z = in.num_wgs.z;
+}
+
+@compute @workgroup_size(1)
+fn main() {
+ main_inner(Builtins(tint_symbol_3.num_workgroups));
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(src, data);
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(NumWorkgroupsFromUniformTest, StructOnlyMember_OutOfOrder) {  // Struct declared after use.
+ auto* src = R"(
+@compute @workgroup_size(1)
+fn main(in : Builtins) {
+ let groups_x = in.num_wgs.x;
+ let groups_y = in.num_wgs.y;
+ let groups_z = in.num_wgs.z;
+}
+
+struct Builtins {
+ @builtin(num_workgroups) num_wgs : vec3<u32>,
+};
+)";
+
+ auto* expect = R"(
+struct tint_symbol_2 {
+ num_workgroups : vec3<u32>,
+}
+
+@group(0) @binding(30) var<uniform> tint_symbol_3 : tint_symbol_2;
+
+fn main_inner(in : Builtins) {
+ let groups_x = in.num_wgs.x;
+ let groups_y = in.num_wgs.y;
+ let groups_z = in.num_wgs.z;
+}
+
+@compute @workgroup_size(1)
+fn main() {
+ main_inner(Builtins(tint_symbol_3.num_workgroups));
+}
+
+struct Builtins {
+ num_wgs : vec3<u32>,
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(src, data);
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(NumWorkgroupsFromUniformTest, StructMultipleMembers) {  // Builtin among other builtins.
+ auto* src = R"(
+struct Builtins {
+ @builtin(global_invocation_id) gid : vec3<u32>,
+ @builtin(num_workgroups) num_wgs : vec3<u32>,
+ @builtin(workgroup_id) wgid : vec3<u32>,
+};
+
+@compute @workgroup_size(1)
+fn main(in : Builtins) {
+ let groups_x = in.num_wgs.x;
+ let groups_y = in.num_wgs.y;
+ let groups_z = in.num_wgs.z;
+}
+)";
+
+ auto* expect = R"(
+struct tint_symbol_2 {
+ num_workgroups : vec3<u32>,
+}
+
+@group(0) @binding(30) var<uniform> tint_symbol_3 : tint_symbol_2;
+
+struct Builtins {
+ gid : vec3<u32>,
+ num_wgs : vec3<u32>,
+ wgid : vec3<u32>,
+}
+
+struct tint_symbol_1 {
+ @builtin(global_invocation_id)
+ gid : vec3<u32>,
+ @builtin(workgroup_id)
+ wgid : vec3<u32>,
+}
+
+fn main_inner(in : Builtins) {
+ let groups_x = in.num_wgs.x;
+ let groups_y = in.num_wgs.y;
+ let groups_z = in.num_wgs.z;
+}
+
+@compute @workgroup_size(1)
+fn main(tint_symbol : tint_symbol_1) {
+ main_inner(Builtins(tint_symbol.gid, tint_symbol_3.num_workgroups, tint_symbol.wgid));
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(src, data);
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(NumWorkgroupsFromUniformTest, StructMultipleMembers_OutOfOrder) {  // Struct after use.
+ auto* src = R"(
+@compute @workgroup_size(1)
+fn main(in : Builtins) {
+ let groups_x = in.num_wgs.x;
+ let groups_y = in.num_wgs.y;
+ let groups_z = in.num_wgs.z;
+}
+
+struct Builtins {
+ @builtin(global_invocation_id) gid : vec3<u32>,
+ @builtin(num_workgroups) num_wgs : vec3<u32>,
+ @builtin(workgroup_id) wgid : vec3<u32>,
+};
+
+)";
+
+ auto* expect = R"(
+struct tint_symbol_2 {
+ num_workgroups : vec3<u32>,
+}
+
+@group(0) @binding(30) var<uniform> tint_symbol_3 : tint_symbol_2;
+
+struct tint_symbol_1 {
+ @builtin(global_invocation_id)
+ gid : vec3<u32>,
+ @builtin(workgroup_id)
+ wgid : vec3<u32>,
+}
+
+fn main_inner(in : Builtins) {
+ let groups_x = in.num_wgs.x;
+ let groups_y = in.num_wgs.y;
+ let groups_z = in.num_wgs.z;
+}
+
+@compute @workgroup_size(1)
+fn main(tint_symbol : tint_symbol_1) {
+ main_inner(Builtins(tint_symbol.gid, tint_symbol_3.num_workgroups, tint_symbol.wgid));
+}
+
+struct Builtins {
+ gid : vec3<u32>,
+ num_wgs : vec3<u32>,
+ wgid : vec3<u32>,
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(src, data);
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(NumWorkgroupsFromUniformTest, MultipleEntryPoints) {  // All entry points share one uniform.
+ auto* src = R"(
+struct Builtins1 {
+ @builtin(num_workgroups) num_wgs : vec3<u32>,
+};
+
+struct Builtins2 {
+ @builtin(global_invocation_id) gid : vec3<u32>,
+ @builtin(num_workgroups) num_wgs : vec3<u32>,
+ @builtin(workgroup_id) wgid : vec3<u32>,
+};
+
+@compute @workgroup_size(1)
+fn main1(in : Builtins1) {
+ let groups_x = in.num_wgs.x;
+ let groups_y = in.num_wgs.y;
+ let groups_z = in.num_wgs.z;
+}
+
+@compute @workgroup_size(1)
+fn main2(in : Builtins2) {
+ let groups_x = in.num_wgs.x;
+ let groups_y = in.num_wgs.y;
+ let groups_z = in.num_wgs.z;
+}
+
+@compute @workgroup_size(1)
+fn main3(@builtin(num_workgroups) num_wgs : vec3<u32>) {
+ let groups_x = num_wgs.x;
+ let groups_y = num_wgs.y;
+ let groups_z = num_wgs.z;
+}
+)";
+
+ auto* expect = R"(
+struct tint_symbol_6 {
+ num_workgroups : vec3<u32>,
+}
+
+@group(0) @binding(30) var<uniform> tint_symbol_7 : tint_symbol_6;
+
+struct Builtins1 {
+ num_wgs : vec3<u32>,
+}
+
+struct Builtins2 {
+ gid : vec3<u32>,
+ num_wgs : vec3<u32>,
+ wgid : vec3<u32>,
+}
+
+fn main1_inner(in : Builtins1) {
+ let groups_x = in.num_wgs.x;
+ let groups_y = in.num_wgs.y;
+ let groups_z = in.num_wgs.z;
+}
+
+@compute @workgroup_size(1)
+fn main1() {
+ main1_inner(Builtins1(tint_symbol_7.num_workgroups));
+}
+
+struct tint_symbol_3 {
+ @builtin(global_invocation_id)
+ gid : vec3<u32>,
+ @builtin(workgroup_id)
+ wgid : vec3<u32>,
+}
+
+fn main2_inner(in : Builtins2) {
+ let groups_x = in.num_wgs.x;
+ let groups_y = in.num_wgs.y;
+ let groups_z = in.num_wgs.z;
+}
+
+@compute @workgroup_size(1)
+fn main2(tint_symbol_2 : tint_symbol_3) {
+ main2_inner(Builtins2(tint_symbol_2.gid, tint_symbol_7.num_workgroups, tint_symbol_2.wgid));
+}
+
+fn main3_inner(num_wgs : vec3<u32>) {
+ let groups_x = num_wgs.x;
+ let groups_y = num_wgs.y;
+ let groups_z = num_wgs.z;
+}
+
+@compute @workgroup_size(1)
+fn main3() {
+ main3_inner(tint_symbol_7.num_workgroups);
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(src, data);
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(NumWorkgroupsFromUniformTest, NoUsages) {  // No num_workgroups builtin: no uniform added.
+ auto* src = R"(
+struct Builtins {
+ @builtin(global_invocation_id) gid : vec3<u32>,
+ @builtin(workgroup_id) wgid : vec3<u32>,
+};
+
+@compute @workgroup_size(1)
+fn main(in : Builtins) {
+}
+)";
+
+ auto* expect = R"(
+struct Builtins {
+ gid : vec3<u32>,
+ wgid : vec3<u32>,
+}
+
+struct tint_symbol_1 {
+ @builtin(global_invocation_id)
+ gid : vec3<u32>,
+ @builtin(workgroup_id)
+ wgid : vec3<u32>,
+}
+
+fn main_inner(in : Builtins) {
+}
+
+@compute @workgroup_size(1)
+fn main(tint_symbol : tint_symbol_1) {
+ main_inner(Builtins(tint_symbol.gid, tint_symbol.wgid));
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(src, data);
+ EXPECT_EQ(expect, str(got));
+}
+
+// Test that group 0 binding 0 is used when the program binds no resources and no binding point is
+// specified in NumWorkgroupsFromUniform::Config.
+TEST_F(NumWorkgroupsFromUniformTest, UnspecifiedBindingPoint_NoResourceBound) {
+ auto* src = R"(
+struct Builtins1 {
+ @builtin(num_workgroups) num_wgs : vec3<u32>,
+};
+
+struct Builtins2 {
+ @builtin(global_invocation_id) gid : vec3<u32>,
+ @builtin(num_workgroups) num_wgs : vec3<u32>,
+ @builtin(workgroup_id) wgid : vec3<u32>,
+};
+
+@compute @workgroup_size(1)
+fn main1(in : Builtins1) {
+ let groups_x = in.num_wgs.x;
+ let groups_y = in.num_wgs.y;
+ let groups_z = in.num_wgs.z;
+}
+
+@compute @workgroup_size(1)
+fn main2(in : Builtins2) {
+ let groups_x = in.num_wgs.x;
+ let groups_y = in.num_wgs.y;
+ let groups_z = in.num_wgs.z;
+}
+
+@compute @workgroup_size(1)
+fn main3(@builtin(num_workgroups) num_wgs : vec3<u32>) {
+ let groups_x = num_wgs.x;
+ let groups_y = num_wgs.y;
+ let groups_z = num_wgs.z;
+}
+)";
+
+ auto* expect = R"(
+struct tint_symbol_6 {
+ num_workgroups : vec3<u32>,
+}
+
+@group(0) @binding(0) var<uniform> tint_symbol_7 : tint_symbol_6;
+
+struct Builtins1 {
+ num_wgs : vec3<u32>,
+}
+
+struct Builtins2 {
+ gid : vec3<u32>,
+ num_wgs : vec3<u32>,
+ wgid : vec3<u32>,
+}
+
+fn main1_inner(in : Builtins1) {
+ let groups_x = in.num_wgs.x;
+ let groups_y = in.num_wgs.y;
+ let groups_z = in.num_wgs.z;
+}
+
+@compute @workgroup_size(1)
+fn main1() {
+ main1_inner(Builtins1(tint_symbol_7.num_workgroups));
+}
+
+struct tint_symbol_3 {
+ @builtin(global_invocation_id)
+ gid : vec3<u32>,
+ @builtin(workgroup_id)
+ wgid : vec3<u32>,
+}
+
+fn main2_inner(in : Builtins2) {
+ let groups_x = in.num_wgs.x;
+ let groups_y = in.num_wgs.y;
+ let groups_z = in.num_wgs.z;
+}
+
+@compute @workgroup_size(1)
+fn main2(tint_symbol_2 : tint_symbol_3) {
+ main2_inner(Builtins2(tint_symbol_2.gid, tint_symbol_7.num_workgroups, tint_symbol_2.wgid));
+}
+
+fn main3_inner(num_wgs : vec3<u32>) {
+ let groups_x = num_wgs.x;
+ let groups_y = num_wgs.y;
+ let groups_z = num_wgs.z;
+}
+
+@compute @workgroup_size(1)
+fn main3() {
+ main3_inner(tint_symbol_7.num_workgroups);
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ // Leave the binding point unspecified; nothing is bound, so group 0 binding 0 is expected.
+ data.Add<NumWorkgroupsFromUniform::Config>(std::nullopt);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(src, data);
+ EXPECT_EQ(expect, str(got));
+}
+
+// Test that binding 0 of (largest used group + 1) is used when at least one resource is bound in
+// the program and no binding point is specified in NumWorkgroupsFromUniform::Config.
+TEST_F(NumWorkgroupsFromUniformTest, UnspecifiedBindingPoint_MultipleResourceBound) {
+ auto* src = R"(
+struct Builtins1 {
+ @builtin(num_workgroups) num_wgs : vec3<u32>,
+};
+
+struct Builtins2 {
+ @builtin(global_invocation_id) gid : vec3<u32>,
+ @builtin(num_workgroups) num_wgs : vec3<u32>,
+ @builtin(workgroup_id) wgid : vec3<u32>,
+};
+
+struct S0 {
+ @size(4)
+ m0 : u32,
+ m1 : array<u32>,
+};
+
+struct S1 {
+ @size(4)
+ m0 : u32,
+ m1 : array<u32, 6>,
+};
+
+@group(0) @binding(0) var g2 : texture_2d<f32>;
+@group(1) @binding(0) var g3 : texture_depth_2d;
+@group(1) @binding(1) var g4 : texture_storage_2d<rg32float, write>;
+@group(3) @binding(0) var g5 : texture_depth_cube_array;
+@group(4) @binding(0) var g6 : texture_external;
+
+@group(0) @binding(1) var<storage, read_write> g8 : S0;
+@group(1) @binding(3) var<storage, read> g9 : S0;
+@group(3) @binding(2) var<storage, read_write> g10 : S0;
+
+@compute @workgroup_size(1)
+fn main1(in : Builtins1) {
+ let groups_x = in.num_wgs.x;
+ let groups_y = in.num_wgs.y;
+ let groups_z = in.num_wgs.z;
+ g8.m0 = 1u;
+}
+
+@compute @workgroup_size(1)
+fn main2(in : Builtins2) {
+ let groups_x = in.num_wgs.x;
+ let groups_y = in.num_wgs.y;
+ let groups_z = in.num_wgs.z;
+}
+
+@compute @workgroup_size(1)
+fn main3(@builtin(num_workgroups) num_wgs : vec3<u32>) {
+ let groups_x = num_wgs.x;
+ let groups_y = num_wgs.y;
+ let groups_z = num_wgs.z;
+}
+)";
+
+ auto* expect = R"(
+struct tint_symbol_6 {
+ num_workgroups : vec3<u32>,
+}
+
+@group(5) @binding(0) var<uniform> tint_symbol_7 : tint_symbol_6;
+
+struct Builtins1 {
+ num_wgs : vec3<u32>,
+}
+
+struct Builtins2 {
+ gid : vec3<u32>,
+ num_wgs : vec3<u32>,
+ wgid : vec3<u32>,
+}
+
+struct S0 {
+ @size(4)
+ m0 : u32,
+ m1 : array<u32>,
+}
+
+struct S1 {
+ @size(4)
+ m0 : u32,
+ m1 : array<u32, 6>,
+}
+
+@group(0) @binding(0) var g2 : texture_2d<f32>;
+
+@group(1) @binding(0) var g3 : texture_depth_2d;
+
+@group(1) @binding(1) var g4 : texture_storage_2d<rg32float, write>;
+
+@group(3) @binding(0) var g5 : texture_depth_cube_array;
+
+@group(4) @binding(0) var g6 : texture_external;
+
+@group(0) @binding(1) var<storage, read_write> g8 : S0;
+
+@group(1) @binding(3) var<storage, read> g9 : S0;
+
+@group(3) @binding(2) var<storage, read_write> g10 : S0;
+
+fn main1_inner(in : Builtins1) {
+ let groups_x = in.num_wgs.x;
+ let groups_y = in.num_wgs.y;
+ let groups_z = in.num_wgs.z;
+ g8.m0 = 1u;
+}
+
+@compute @workgroup_size(1)
+fn main1() {
+ main1_inner(Builtins1(tint_symbol_7.num_workgroups));
+}
+
+struct tint_symbol_3 {
+ @builtin(global_invocation_id)
+ gid : vec3<u32>,
+ @builtin(workgroup_id)
+ wgid : vec3<u32>,
+}
+
+fn main2_inner(in : Builtins2) {
+ let groups_x = in.num_wgs.x;
+ let groups_y = in.num_wgs.y;
+ let groups_z = in.num_wgs.z;
+}
+
+@compute @workgroup_size(1)
+fn main2(tint_symbol_2 : tint_symbol_3) {
+ main2_inner(Builtins2(tint_symbol_2.gid, tint_symbol_7.num_workgroups, tint_symbol_2.wgid));
+}
+
+fn main3_inner(num_wgs : vec3<u32>) {
+ let groups_x = num_wgs.x;
+ let groups_y = num_wgs.y;
+ let groups_z = num_wgs.z;
+}
+
+@compute @workgroup_size(1)
+fn main3() {
+ main3_inner(tint_symbol_7.num_workgroups);
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ // Leave the binding point unspecified; groups up to 4 are in use, so group 5 is expected.
+ data.Add<NumWorkgroupsFromUniform::Config>(std::nullopt);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(src, data);
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/packed_vec3.cc b/src/tint/lang/wgsl/ast/transform/packed_vec3.cc
new file mode 100644
index 0000000..da02d5b
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/packed_vec3.cc
@@ -0,0 +1,521 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/packed_vec3.h"
+
+#include <algorithm>
+#include <string>
+#include <utility>
+
+#include "src/tint/builtin/builtin.h"
+#include "src/tint/lang/wgsl/ast/assignment_statement.h"
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/array_count.h"
+#include "src/tint/sem/index_accessor_expression.h"
+#include "src/tint/sem/load.h"
+#include "src/tint/sem/statement.h"
+#include "src/tint/sem/type_expression.h"
+#include "src/tint/sem/variable.h"
+#include "src/tint/switch.h"
+#include "src/tint/type/array.h"
+#include "src/tint/type/reference.h"
+#include "src/tint/type/vector.h"
+#include "src/tint/utils/hashmap.h"
+#include "src/tint/utils/hashset.h"
+#include "src/tint/utils/vector.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::PackedVec3);
+
+using namespace tint::number_suffixes; // NOLINT
+
+namespace tint::ast::transform {
+
+/// PIMPL state for the transform
+struct PackedVec3::State {
+ /// Constructor
+ /// @param program the source program
+ explicit State(const Program* program) : src(program) {}
+
+ /// The name of the struct member used when wrapping packed vec3 types.
+ static constexpr const char* kStructMemberName = "elements";
+
+ /// The names of the structures used to wrap packed vec3 types.
+ utils::Hashmap<const type::Type*, Symbol, 4> packed_vec3_wrapper_struct_names;
+
+ /// A cache of host-shareable structures that have been rewritten.
+ utils::Hashmap<const type::Type*, Symbol, 4> rewritten_structs;
+
+ /// A map from type to the name of a helper function used to pack that type.
+ utils::Hashmap<const type::Type*, Symbol, 4> pack_helpers;
+
+ /// A map from type to the name of a helper function used to unpack that type.
+ utils::Hashmap<const type::Type*, Symbol, 4> unpack_helpers;
+
+ /// @param ty the type to test
+ /// @returns true if `ty` is a vector type whose Width() is 3, false otherwise
+ bool IsVec3(const type::Type* ty) {
+ if (auto* vec = ty->As<type::Vector>()) {
+ if (vec->Width() == 3) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ /// @param ty the type to test
+ /// @returns true if `ty` is a vec3, or a matrix, array or struct that recursively contains one
+ bool ContainsVec3(const type::Type* ty) {
+ return Switch(
+ ty, //
+ [&](const type::Vector* vec) { return IsVec3(vec); },
+ [&](const type::Matrix* mat) { return ContainsVec3(mat->ColumnType()); },
+ [&](const type::Array* arr) { return ContainsVec3(arr->ElemType()); },
+ [&](const type::Struct* str) {
+ for (auto* member : str->Members()) {
+ if (ContainsVec3(member->Type())) {
+ return true;
+ }
+ }
+ return false;
+ });  // NOTE(review): result for other types depends on Switch's no-match value - confirm.
+ }
+
+ /// Create a `__packed_vec3` type with the same element type as `ty`.
+ /// @param ty a three-element vector type (checked with TINT_ASSERT)
+ /// @returns the new AST type, built from builtin::Builtin::kPackedVec3 and the element type
+ Type MakePackedVec3(const type::Type* ty) {
+ auto* vec = ty->As<type::Vector>();
+ TINT_ASSERT(Transform, vec != nullptr && vec->Width() == 3);
+ return b.ty(builtin::Builtin::kPackedVec3, CreateASTTypeFor(ctx, vec->type()));
+ }
+
+ /// Recursively rewrite a type using `__packed_vec3`, if needed.
+ /// When used as an array element type, the `__packed_vec3` type will be wrapped in a structure
+ /// and given an `@align()` attribute to give it alignment it needs to yield the correct array
+ /// element stride. For vec3 types used in structures directly, the `@align()` attribute is
+ /// placed on the containing structure instead. Matrices with three rows become arrays of
+ /// columns, and use the aligned wrapper struct for the column type.
+ /// @param ty the type to rewrite
+ /// @param array_element `true` if this is being called for the element of an array
+ /// @returns the new AST type, or an empty Type if rewriting was not necessary
+ Type RewriteType(const type::Type* ty, bool array_element = false) {
+ return Switch(
+ ty,
+ [&](const type::Vector* vec) -> Type {
+ if (IsVec3(vec)) {
+ if (array_element) {
+ // Create a struct with a single `__packed_vec3` member.
+ // Give the struct member the same alignment as the original unpacked vec3
+ // type, to avoid changing the array element stride.
+ return b.ty(packed_vec3_wrapper_struct_names.GetOrCreate(vec, [&] {
+ auto name = b.Symbols().New(
+ "tint_packed_vec3_" + vec->type()->FriendlyName() +
+ "_array_element");  // This branch is only taken for array elements.
+ auto* member =
+ b.Member(kStructMemberName, MakePackedVec3(vec),
+ utils::Vector{b.MemberAlign(AInt(vec->Align()))});
+ b.Structure(b.Ident(name), utils::Vector{member}, utils::Empty);
+ return name;
+ }));
+ } else {
+ return MakePackedVec3(vec);
+ }
+ }
+ return {};
+ },
+ [&](const type::Matrix* mat) -> Type {
+ // Rewrite the matrix as an array of columns that use the aligned wrapper struct.
+ auto new_col_type = RewriteType(mat->ColumnType(), /* array_element */ true);
+ if (new_col_type) {
+ return b.ty.array(new_col_type, u32(mat->columns()));
+ }
+ return {};
+ },
+ [&](const type::Array* arr) -> Type {
+ // Rewrite the array with the modified element type.
+ auto new_type = RewriteType(arr->ElemType(), /* array_element */ true);
+ if (new_type) {
+ utils::Vector<const Attribute*, 1> attrs;
+ if (arr->Count()->Is<type::RuntimeArrayCount>()) {
+ return b.ty.array(new_type, std::move(attrs));
+ } else if (auto count = arr->ConstantCount()) {
+ return b.ty.array(new_type, u32(count.value()), std::move(attrs));
+ } else {
+ TINT_ICE(Transform, b.Diagnostics())
+ << type::Array::kErrExpectedConstantCount;
+ return {};
+ }
+ }
+ return {};
+ },
+ [&](const type::Struct* str) -> Type {
+ if (ContainsVec3(str)) {
+ auto name = rewritten_structs.GetOrCreate(str, [&] {
+ utils::Vector<const StructMember*, 4> members;
+ for (auto* member : str->Members()) {
+ // If the member type contains a vec3, rewrite it.
+ auto new_type = RewriteType(member->Type());
+ if (new_type) {
+ // Copy the member attributes.
+ bool needs_align = true;
+ utils::Vector<const Attribute*, 4> attributes;
+ if (auto* sem_mem = member->As<sem::StructMember>()) {
+ for (auto* attr : sem_mem->Declaration()->attributes) {
+ if (attr->IsAnyOf<StructMemberAlignAttribute,
+ StructMemberOffsetAttribute>()) {
+ needs_align = false;
+ }
+ attributes.Push(ctx.Clone(attr));
+ }
+ }
+ // If the alignment wasn't already specified, add an attribute to
+ // make sure that we don't alter the alignment when using the packed
+ // vector type.
+ if (needs_align) {
+ attributes.Push(b.MemberAlign(AInt(member->Align())));
+ }
+ members.Push(b.Member(ctx.Clone(member->Name()), new_type,
+ std::move(attributes)));
+ } else {
+ // No vec3s, just clone the member as is.
+ if (auto* sem_mem = member->As<sem::StructMember>()) {
+ members.Push(ctx.Clone(sem_mem->Declaration()));
+ } else {  // `new_type` is empty here, so rebuild from the member's own type.
+ members.Push(b.Member(ctx.Clone(member->Name()),
+ CreateASTTypeFor(ctx, member->Type()), utils::Empty));
+ }
+ }
+ }
+ // Create the new structure.
+ auto struct_name =
+ b.Symbols().New(str->Name().Name() + "_tint_packed_vec3");
+ b.Structure(struct_name, std::move(members));
+ return struct_name;
+ });
+ return b.ty(name);
+ }
+ return {};
+ });
+ }
+
+ /// Create a helper function to recursively pack or unpack a composite that contains vec3 types.
+ /// @param name_prefix the prefix from which the helper function's unique name is generated
+ /// @param ty the composite type to pack or unpack
+ /// @param pack_or_unpack_element a function that packs or unpacks an element with a given type
+ /// @param in_type a function that creates an AST type for the input type
+ /// @param out_type a function that creates an AST type for the output type
+ /// @returns the name of the helper function
+ Symbol MakePackUnpackHelper(
+ const char* name_prefix,
+ const type::Type* ty,
+ const std::function<const Expression*(const Expression*, const type::Type*)>&
+ pack_or_unpack_element,
+ const std::function<Type()>& in_type,
+ const std::function<Type()>& out_type) {
+ // Allocate a variable to hold the return value of the function.
+ utils::Vector<const Statement*, 4> statements;
+ statements.Push(b.Decl(b.Var("result", out_type())));
+
+ // Helper that generates a loop to copy and pack/unpack elements of an array to the result:
+ // for (var i : u32; i < num_elements; i = i + 1) {
+ // result[i] = pack_or_unpack_element(in[i]);
+ // }
+ auto copy_array_elements = [&](uint32_t num_elements, const type::Type* element_type) {
+ // Generate an expression for packing or unpacking an element of the array.
+ auto* element = pack_or_unpack_element(b.IndexAccessor("in", "i"), element_type);
+ statements.Push(b.For( //
+ b.Decl(b.Var("i", b.ty.u32())), //
+ b.LessThan("i", u32(num_elements)), //
+ b.Assign("i", b.Add("i", 1_a)), //
+ b.Block(utils::Vector{
+ b.Assign(b.IndexAccessor("result", "i"), element),
+ })));
+ };
+
+ // Copy the elements of the value over to the result.
+ Switch(
+ ty,
+ [&](const type::Array* arr) {
+ TINT_ASSERT(Transform, arr->ConstantCount());
+ copy_array_elements(arr->ConstantCount().value(), arr->ElemType());
+ },
+ [&](const type::Matrix* mat) {
+ copy_array_elements(mat->columns(), mat->ColumnType());
+ },
+ [&](const type::Struct* str) {
+ // Copy the struct members over one at a time, packing/unpacking as necessary.
+ for (auto* member : str->Members()) {
+ const Expression* element =
+ b.MemberAccessor("in", b.Ident(ctx.Clone(member->Name())));
+ if (ContainsVec3(member->Type())) {
+ element = pack_or_unpack_element(element, member->Type());
+ }
+ statements.Push(b.Assign(
+ b.MemberAccessor("result", b.Ident(ctx.Clone(member->Name()))), element));
+ }
+ });
+
+ // Return the result.
+ statements.Push(b.Return("result"));
+
+ // Create the function and return its name.
+ auto name = b.Symbols().New(name_prefix);
+ b.Func(name, utils::Vector{b.Param("in", in_type())}, out_type(), std::move(statements));
+ return name;
+ }
+
+ /// Unpack the composite value `expr` to the unpacked type `ty`. If `ty` is a matrix, this will
+ /// produce a regular matNx3 value from an array of packed column vectors.
+ /// @param expr the composite value expression to unpack
+ /// @param ty the unpacked type
+ /// @returns a call to the unpack helper for `ty` (looked up in `unpack_helpers`) with `expr`
+ const Expression* UnpackComposite(const Expression* expr, const type::Type* ty) {
+ auto helper = unpack_helpers.GetOrCreate(ty, [&] {
+ return MakePackUnpackHelper(
+ "tint_unpack_vec3_in_composite", ty,
+ [&](const Expression* element,
+ const type::Type* element_type) -> const Expression* {
+ if (element_type->Is<type::Vector>()) {
+ // Unpack a `__packed_vec3` by casting it to a regular vec3.
+ // If it is an array element, extract the vector from the wrapper struct.
+ if (element->Is<IndexAccessorExpression>()) {
+ element = b.MemberAccessor(element, kStructMemberName);
+ }
+ return b.Call(CreateASTTypeFor(ctx, element_type), element);
+ } else {
+ return UnpackComposite(element, element_type);
+ }
+ },
+ [&] { return RewriteType(ty); }, //
+ [&] { return CreateASTTypeFor(ctx, ty); });
+ });
+ return b.Call(helper, expr);
+ }
+
+ /// Converts the unpacked composite `expr` of type `ty` to its packed form by calling a helper
+ /// function that is looked up, or created on first use, in `pack_helpers`, keyed by `ty`.
+ /// @param expr the unpacked composite value expression
+ /// @param ty the unpacked type
+ /// @returns a call to the helper, which evaluates to the packed value
+ const Expression* PackComposite(const Expression* expr, const type::Type* ty) {
+     auto pack_fn = pack_helpers.GetOrCreate(ty, [&] {
+         return MakePackUnpackHelper(
+             "tint_pack_vec3_in_composite", ty,
+             [&](const Expression* element,
+                 const type::Type* element_type) -> const Expression* {
+                 if (!element_type->Is<type::Vector>()) {
+                     // Nested composite: pack it with the helper for its own type.
+                     return PackComposite(element, element_type);
+                 }
+                 // Cast the vector element to a packed_vec3, then construct the wrapper
+                 // struct around it if it is an array element.
+                 auto* packed_vec = b.Call(MakePackedVec3(element_type), element);
+                 if (element->Is<IndexAccessorExpression>()) {
+                     packed_vec = b.Call(RewriteType(element_type, true), packed_vec);
+                 }
+                 return packed_vec;
+             },
+             [&] { return CreateASTTypeFor(ctx, ty); },  //
+             [&] { return RewriteType(ty); });
+     });
+     return b.Call(pack_fn, expr);
+ }
+
+ /// @returns true if any host-shareable module-scope variable has a type that contains a vec3
+ bool ShouldRun() {
+     // The transform is only necessary when the type of a uniform or storage buffer variable
+     // contains a vec3, so inspect every module-scope variable declaration.
+     for (auto* global : src->AST().GlobalVariables()) {
+         auto* sem_var = sem.Get<sem::GlobalVariable>(global);
+         const bool host_shareable = sem_var && builtin::IsHostShareable(sem_var->AddressSpace());
+         if (host_shareable && ContainsVec3(sem_var->Type()->UnwrapRef())) {
+             return true;
+         }
+     }
+     return false;
+ }
+
+ /// Runs the transform, replacing host-shareable vec3 types with packed types and fixing up uses
+ /// @returns the new program or SkipTransform if the transform is not required
+ ApplyResult Run() {
+     if (!ShouldRun()) {
+         return SkipTransform;
+     }
+
+     // Changing the types of certain structure members can trigger stricter layout validation
+     // rules for the uniform address space. In particular, replacing 16-bit matrices with arrays
+     // violates the requirement that the array element stride is a multiple of 16 bytes, and
+     // replacing vec3s with a structure violates the requirement that there must be at least 16
+     // bytes from the start of a structure to the start of the next member.
+     // Disable these validation rules using an internal extension, as MSL does not have these
+     // restrictions.
+     b.Enable(builtin::Extension::kChromiumInternalRelaxedUniformLayout);
+
+     // Track expressions that need to be packed or unpacked.
+     utils::Hashset<const sem::ValueExpression*, 8> to_pack;
+     utils::Hashset<const sem::ValueExpression*, 8> to_unpack;
+
+     // Replace vec3 types in host-shareable address spaces with `__packed_vec3` types, and
+     // collect expressions that need to be converted to or from values that use the
+     // `__packed_vec3` type.
+     for (auto* node : ctx.src->ASTNodes().Objects()) {
+         Switch(
+             sem.Get(node),
+             [&](const sem::TypeExpression* type) {
+                 // Rewrite pointers to types that contain vec3s.
+                 auto* ptr = type->Type()->As<type::Pointer>();
+                 if (ptr && builtin::IsHostShareable(ptr->AddressSpace())) {
+                     auto new_store_type = RewriteType(ptr->StoreType());
+                     if (new_store_type) {
+                         auto access = ptr->AddressSpace() == builtin::AddressSpace::kStorage
+                                           ? ptr->Access()
+                                           : builtin::Access::kUndefined;
+                         auto new_ptr_type =
+                             b.ty.ptr(ptr->AddressSpace(), new_store_type, access);
+                         ctx.Replace(node, new_ptr_type.expr);
+                     }
+                 }
+             },
+             [&](const sem::Variable* var) {
+                 if (!builtin::IsHostShareable(var->AddressSpace())) {
+                     return;
+                 }
+
+                 // Rewrite the var type, if it contains vec3s.
+                 auto new_store_type = RewriteType(var->Type()->UnwrapRef());
+                 if (new_store_type) {
+                     ctx.Replace(var->Declaration()->type.expr, new_store_type.expr);
+                 }
+             },
+             [&](const sem::Statement* stmt) {
+                 // Pack the RHS of assignment statements that are writing to packed types.
+                 if (auto* assign = stmt->Declaration()->As<AssignmentStatement>()) {
+                     auto* lhs = sem.GetVal(assign->lhs);
+                     auto* rhs = sem.GetVal(assign->rhs);
+                     auto* lhs_ref = lhs ? lhs->Type()->As<type::Reference>() : nullptr;
+                     if (!rhs || !lhs_ref || !ContainsVec3(rhs->Type()) ||
+                         !builtin::IsHostShareable(lhs_ref->AddressSpace())) {
+                         // Skip assignments whose LHS is not a reference (e.g. phony), whose
+                         // address space is not host-shareable, or that have no vec3 types.
+                         return;
+                     }
+
+                     // Pack the RHS expression.
+                     if (to_unpack.Contains(rhs)) {
+                         // The expression will already be packed, so skip the pending unpack.
+                         to_unpack.Remove(rhs);
+
+                         // If the expression produces a vec3 from an array element, extract
+                         // the packed vector from the wrapper struct.
+                         if (IsVec3(rhs->Type()) &&
+                             rhs->UnwrapLoad()->Is<sem::IndexAccessorExpression>()) {
+                             ctx.Replace(rhs->Declaration(),
+                                         b.MemberAccessor(ctx.Clone(rhs->Declaration()),
+                                                          kStructMemberName));
+                         }
+                     } else {
+                         to_pack.Add(rhs);
+                     }
+                 }
+             },
+             [&](const sem::Load* load) {
+                 // Unpack loads of types that contain vec3s in host-shareable address spaces.
+                 if (ContainsVec3(load->Type()) &&
+                     builtin::IsHostShareable(load->ReferenceType()->AddressSpace())) {
+                     to_unpack.Add(load);
+                 }
+             },
+             [&](const sem::IndexAccessorExpression* accessor) {
+                 // If the expression produces a reference to a vec3 in a host-shareable address
+                 // space from an array element, extract the packed vector from the wrapper
+                 // struct.
+                 if (auto* ref = accessor->Type()->As<type::Reference>()) {
+                     if (IsVec3(ref->StoreType()) &&
+                         builtin::IsHostShareable(ref->AddressSpace())) {
+                         ctx.Replace(node, b.MemberAccessor(ctx.Clone(accessor->Declaration()),
+                                                            kStructMemberName));
+                     }
+                 }
+             });
+     }
+
+     // Sort the pending pack/unpack operations by AST node ID to make the order deterministic.
+     auto to_unpack_sorted = to_unpack.Vector();
+     auto to_pack_sorted = to_pack.Vector();
+     auto pred = [&](auto* expr_a, auto* expr_b) {
+         return expr_a->Declaration()->node_id < expr_b->Declaration()->node_id;
+     };
+     to_unpack_sorted.Sort(pred);
+     to_pack_sorted.Sort(pred);
+
+     // Apply all of the pending unpack operations that we have collected.
+     for (auto* expr : to_unpack_sorted) {
+         TINT_ASSERT(Transform, ContainsVec3(expr->Type()));
+         auto* packed = ctx.Clone(expr->Declaration());
+         const Expression* unpacked = nullptr;
+         if (IsVec3(expr->Type())) {
+             if (expr->UnwrapLoad()->Is<sem::IndexAccessorExpression>()) {
+                 // If we are unpacking a vec3 from an array element, extract the vector from the
+                 // wrapper struct.
+                 packed = b.MemberAccessor(packed, kStructMemberName);
+             }
+             // Cast the packed vector to a regular vec3.
+             unpacked = b.Call(CreateASTTypeFor(ctx, expr->Type()), packed);
+         } else {
+             // Use a helper function to unpack an array or matrix.
+             unpacked = UnpackComposite(packed, expr->Type());
+         }
+         TINT_ASSERT(Transform, unpacked != nullptr);
+         ctx.Replace(expr->Declaration(), unpacked);
+     }
+
+     // Apply all of the pending pack operations that we have collected.
+     for (auto* expr : to_pack_sorted) {
+         TINT_ASSERT(Transform, ContainsVec3(expr->Type()));
+         auto* unpacked = ctx.Clone(expr->Declaration());
+         const Expression* packed = nullptr;
+         if (IsVec3(expr->Type())) {
+             // Cast the regular vec3 to a packed vector type.
+             packed = b.Call(MakePackedVec3(expr->Type()), unpacked);
+         } else {
+             // Use a helper function to pack an array or matrix.
+             packed = PackComposite(unpacked, expr->Type());
+         }
+         TINT_ASSERT(Transform, packed != nullptr);
+         ctx.Replace(expr->Declaration(), packed);
+     }
+
+     ctx.Clone();
+     return Program(std::move(b));
+ }
+
+ private:
+ /// The source program
+ const Program* const src;
+ /// The target program builder
+ ProgramBuilder b;
+ /// The clone context, built from `b` and `src`, so it must be declared after both of them
+ CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
+ /// Alias to the semantic info in ctx.src; initialized from `ctx`, so it must follow `ctx`
+ const sem::Info& sem = ctx.src->Sem();
+};
+
+PackedVec3::PackedVec3() = default;  // Defaulted; Apply() builds a fresh State for each run.
+PackedVec3::~PackedVec3() = default;
+
+Transform::ApplyResult PackedVec3::Apply(const Program* src, const DataMap&, DataMap&) const {
+ return State{src}.Run();  // The input/output DataMaps are unused; State does all the work.
+}
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/packed_vec3.h b/src/tint/lang/wgsl/ast/transform/packed_vec3.h
new file mode 100644
index 0000000..e03f2c7
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/packed_vec3.h
@@ -0,0 +1,59 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_PACKED_VEC3_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_PACKED_VEC3_H_
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+
+namespace tint::ast::transform {
+
+/// A transform to be used by the MSL backend which will:
+/// * Replace `vec3<T>` types with an internal `__packed_vec3` type when they are used in
+/// host-shareable address spaces.
+/// * Wrap generated `__packed_vec3` types in a structure when they are used in arrays, so that we
+/// ensure that the array has the correct element stride.
+/// * Create separate versions of structures that contain `vec3<T>` types when they are used in
+/// host-shareable memory, to avoid modifying uses in other address spaces.
+/// * Rewrite matrix types that have three rows into arrays of column vectors.
+/// * Insert calls to helper functions to convert expressions that use these types to or from the
+/// regular vec3 types when accessing host-shareable memory.
+/// * Cast all direct (not sub-accessed) loads of these packed vectors to the 'unpacked' vec3<T>
+/// type before usage.
+///
+/// This transform is necessary in order to emit vec3 types with the correct size (so that scalars
+/// can follow them in structures), and also to ensure that padding bytes are preserved when writing
+/// to a vec3, an array of vec3 elements, or a matrix with vec3 column type.
+///
+/// @note Depends on the following transforms to have been run first:
+/// * ExpandCompoundAssignment
+class PackedVec3 final : public utils::Castable<PackedVec3, Transform> {
+ public:
+ /// Constructor
+ PackedVec3();
+ /// Destructor
+ ~PackedVec3() override;
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+
+ private:
+ struct State;  // Implementation state; only forward-declared in this header.
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_PACKED_VEC3_H_
diff --git a/src/tint/lang/wgsl/ast/transform/packed_vec3_test.cc b/src/tint/lang/wgsl/ast/transform/packed_vec3_test.cc
new file mode 100644
index 0000000..0461f26
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/packed_vec3_test.cc
@@ -0,0 +1,8095 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/packed_vec3.h"
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "src/tint/lang/wgsl/ast/module.h"
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/struct.h"
+#include "src/tint/sem/variable.h"
+#include "src/tint/type/array.h"
+#include "src/tint/utils/string.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using PackedVec3Test = TransformTest;  // Fixture name used by the TEST_F cases in this file.
+
+TEST_F(PackedVec3Test, ShouldRun_EmptyModule) {
+ auto* src = R"()";
+ // An empty module declares no variables at all, so nothing needs packing.
+ EXPECT_FALSE(ShouldRun<PackedVec3>(src));
+}
+
+TEST_F(PackedVec3Test, ShouldRun_NoHostShareableVec3s) {
+ auto* src = R"(
+struct S {
+ v : vec3<f32>,
+ m : mat3x3<f32>,
+ a : array<vec3<f32>, 4>,
+}
+
+var<private> p_s : S;
+var<private> p_v : vec3<f32>;
+var<private> p_m : mat3x3<f32>;
+var<private> p_a : array<vec3<f32>, 4>;
+
+var<workgroup> w_s : S;
+var<workgroup> w_v : vec3<f32>;
+var<workgroup> w_m : mat3x3<f32>;
+var<workgroup> w_a : array<vec3<f32>, 4>;
+
+fn f() {
+ var f_s : S;
+ var f_v : vec3<f32>;
+ var f_m : mat3x3<f32>;
+ var f_a : array<vec3<f32>, 4>;
+}
+)";
+ // vec3s used only by private, workgroup and function variables must not trigger the transform.
+ EXPECT_FALSE(ShouldRun<PackedVec3>(src));
+}
+
+TEST_F(PackedVec3Test, ShouldRun_Vec4Vec2) {
+ auto* src = R"(
+struct S {
+ v4 : vec4<f32>,
+ v2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> Ps : S; // Host sharable
+@group(0) @binding(1) var<uniform> Pv4 : vec4<f32>; // Host sharable
+@group(0) @binding(2) var<uniform> Pv2 : vec2<f32>; // Host sharable
+)";
+ // Host-shareable vec2 and vec4 types must not trigger the transform.
+ EXPECT_FALSE(ShouldRun<PackedVec3>(src));
+}
+
+TEST_F(PackedVec3Test, ShouldRun_OtherMatrices) {
+ auto* src = R"(
+struct S {
+ m2x2 : mat2x2<f32>,
+ m2x4 : mat2x4<f32>,
+ m3x2 : mat3x2<f32>,
+ m3x4 : mat3x4<f32>,
+ m4x2 : mat4x2<f32>,
+ m4x4 : mat4x4<f32>,
+}
+
+@group(0) @binding(0) var<uniform> Ps : S; // Host sharable
+@group(0) @binding(1) var<uniform> Pm2x2 : mat2x2<f32>; // Host sharable
+@group(0) @binding(2) var<uniform> Pm2x4 : mat2x4<f32>; // Host sharable
+@group(0) @binding(3) var<uniform> Pm3x2 : mat3x2<f32>; // Host sharable
+@group(0) @binding(4) var<uniform> Pm3x4 : mat3x4<f32>; // Host sharable
+@group(0) @binding(5) var<uniform> Pm4x2 : mat4x2<f32>; // Host sharable
+@group(0) @binding(6) var<uniform> Pm4x4 : mat4x4<f32>; // Host sharable
+)";
+ // None of these matrix shapes has three rows (vec3 columns), so none needs packing.
+ EXPECT_FALSE(ShouldRun<PackedVec3>(src));
+}
+
+TEST_F(PackedVec3Test, ShouldRun_ArrayOfNonVec3) {
+ auto* src = R"(
+struct S {
+ arr_v : array<vec2<f32>, 4>,
+ arr_m : array<mat3x2<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage> Ps : S; // Host sharable
+@group(0) @binding(1) var<storage> Parr_v : array<vec2<f32>, 4>; // Host sharable
+@group(0) @binding(2) var<storage> Parr_m : array<mat3x2<f32>, 4>; // Host sharable
+)";
+ // Storage arrays of vec2 and mat3x2 elements contain no vec3, so none needs packing.
+ EXPECT_FALSE(ShouldRun<PackedVec3>(src));
+}
+
+TEST_F(PackedVec3Test, ShouldRun_HostSharable_Vec3) {
+ auto* src = R"(
+@group(0) @binding(0) var<uniform> P : vec3<f32>; // Host sharable
+)";
+ // A vec3 used directly as a uniform variable requires the transform.
+ EXPECT_TRUE(ShouldRun<PackedVec3>(src));
+}
+
+TEST_F(PackedVec3Test, ShouldRun_HostSharable_Mat3x3) {
+ auto* src = R"(
+@group(0) @binding(0) var<uniform> P : mat3x3<f32>; // Host sharable
+)";
+ // A mat3x3 used directly as a uniform variable requires the transform.
+ EXPECT_TRUE(ShouldRun<PackedVec3>(src));
+}
+
+TEST_F(PackedVec3Test, ShouldRun_HostSharable_ArrayOfVec3) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage> P : array<vec3<f32>>; // Host sharable
+)";
+ // A runtime-sized storage array of vec3 requires the transform.
+ EXPECT_TRUE(ShouldRun<PackedVec3>(src));
+}
+
+TEST_F(PackedVec3Test, ShouldRun_HostSharable_ArrayOfMat3x3) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage> P : array<mat3x3<f32>>; // Host sharable
+)";
+ // A runtime-sized storage array of mat3x3 requires the transform.
+ EXPECT_TRUE(ShouldRun<PackedVec3>(src));
+}
+
+TEST_F(PackedVec3Test, ShouldRun_HostSharableStruct_Vec3) {
+ auto* src = R"(
+struct S {
+ v : vec3<f32>,
+}
+
+@group(0) @binding(0) var<uniform> P : S; // Host sharable
+)";
+ // A vec3 member of a struct used as a uniform variable requires the transform.
+ EXPECT_TRUE(ShouldRun<PackedVec3>(src));
+}
+
+TEST_F(PackedVec3Test, ShouldRun_HostSharableStruct_Mat3x3) {
+ auto* src = R"(
+struct S {
+ m : mat3x3<f32>,
+}
+
+@group(0) @binding(0) var<uniform> P : S; // Host sharable
+)";
+ // A mat3x3 member of a struct used as a uniform variable requires the transform.
+ EXPECT_TRUE(ShouldRun<PackedVec3>(src));
+}
+
+TEST_F(PackedVec3Test, ShouldRun_HostSharableStruct_ArrayOfVec3) {
+ auto* src = R"(
+struct S {
+ a : array<vec3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<uniform> P : S; // Host sharable
+)";
+ // An array-of-vec3 member of a struct used as a uniform variable requires the transform.
+ EXPECT_TRUE(ShouldRun<PackedVec3>(src));
+}
+
+TEST_F(PackedVec3Test, ShouldRun_HostSharableStruct_ArrayOfMat3x3) {
+ auto* src = R"(
+struct S {
+ a : array<mat3x3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<uniform> P : S; // Host sharable
+)";
+ // An array-of-mat3x3 member of a struct used as a uniform variable requires the transform.
+ EXPECT_TRUE(ShouldRun<PackedVec3>(src));
+}
+
+TEST_F(PackedVec3Test, Vec3_ReadVector) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage> v : vec3<f32>;
+
+fn f() {
+ let x = v;
+}
+)";
+ // The var becomes __packed_vec3 and the whole-vector load is cast back to vec3.
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+@group(0) @binding(0) var<storage> v : __packed_vec3<f32>;
+
+fn f() {
+ let x = vec3<f32>(v);
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, Vec3_ReadComponent_MemberAccessChain) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage> v : vec3<f32>;
+
+fn f() {
+ let x = v.yz.x;
+}
+)";
+ // The vector is unpacked with a vec3 cast before the swizzle chain is applied.
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+@group(0) @binding(0) var<storage> v : __packed_vec3<f32>;
+
+fn f() {
+ let x = vec3<f32>(v).yz.x;
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, Vec3_ReadComponent_IndexAccessor) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage> v : vec3<f32>;
+
+fn f() {
+ let x = v[1];
+}
+)";
+ // Indexing a single component reads the packed vector directly, with no cast.
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+@group(0) @binding(0) var<storage> v : __packed_vec3<f32>;
+
+fn f() {
+ let x = v[1];
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, Vec3_WriteVector_ValueRHS) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read_write> v : vec3<f32>;
+
+fn f() {
+ v = vec3(1.23);
+}
+)";
+ // A vec3 value stored to the packed var is cast to __packed_vec3 first.
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+@group(0) @binding(0) var<storage, read_write> v : __packed_vec3<f32>;
+
+fn f() {
+ v = __packed_vec3<f32>(vec3(1.22999999999999998224));
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, Vec3_WriteVector_RefRHS) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read_write> v : vec3<f32>;
+@group(0) @binding(1) var<uniform> in : vec3<f32>;
+
+fn f() {
+ v = in;
+}
+)";
+ // Copying between two packed vars needs neither a pack nor an unpack cast.
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+@group(0) @binding(0) var<storage, read_write> v : __packed_vec3<f32>;
+
+@group(0) @binding(1) var<uniform> in : __packed_vec3<f32>;
+
+fn f() {
+ v = in;
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, Vec3_WriteComponent_MemberAccessor) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read_write> v : vec3<f32>;
+
+fn f() {
+ v.y = 1.23;
+}
+)";
+ // Writing one component through a member accessor needs no pack cast.
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+@group(0) @binding(0) var<storage, read_write> v : __packed_vec3<f32>;
+
+fn f() {
+ v.y = 1.22999999999999998224;
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, Vec3_WriteComponent_IndexAccessor) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read_write> v : vec3<f32>;
+
+fn f() {
+ v[1] = 1.23;
+}
+)";
+ // Writing one component through an index accessor needs no pack cast.
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+@group(0) @binding(0) var<storage, read_write> v : __packed_vec3<f32>;
+
+fn f() {
+ v[1] = 1.22999999999999998224;
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, ArrayOfVec3_ReadArray) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage> arr : array<vec3<f32>, 4>;
+
+fn f() {
+ let x = arr;
+}
+)";
+ // Array elements become wrapper structs; a whole-array load goes through an unpack helper.
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+fn tint_unpack_vec3_in_composite(in : array<tint_packed_vec3_f32_array_element, 4u>) -> array<vec3<f32>, 4u> {
+ var result : array<vec3<f32>, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = vec3<f32>(in[i].elements);
+ }
+ return result;
+}
+
+@group(0) @binding(0) var<storage> arr : array<tint_packed_vec3_f32_array_element, 4u>;
+
+fn f() {
+ let x = tint_unpack_vec3_in_composite(arr);
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, ArrayOfVec3_ReadVector) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage> arr : array<vec3<f32>, 4>;
+
+fn f() {
+ let x = arr[0];
+}
+)";
+ // An element load reads `.elements` from the wrapper struct and casts it to vec3.
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage> arr : array<tint_packed_vec3_f32_array_element, 4u>;
+
+fn f() {
+ let x = vec3<f32>(arr[0].elements);
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, ArrayOfVec3_ReadComponent_MemberAccessor) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage> arr : array<vec3<f32>, 4>;
+
+fn f() {
+ let x = arr[0].y;
+}
+)";
+ // A component read goes through `.elements` of the wrapper struct, with no vec3 cast.
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage> arr : array<tint_packed_vec3_f32_array_element, 4u>;
+
+fn f() {
+ let x = arr[0].elements.y;
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, ArrayOfVec3_ReadComponent_IndexAccessor) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage> arr : array<vec3<f32>, 4>;
+
+fn f() {
+ let x = arr[0][1];
+}
+)";
+ // An indexed component read goes through `.elements` of the wrapper struct, with no cast.
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage> arr : array<tint_packed_vec3_f32_array_element, 4u>;
+
+fn f() {
+ let x = arr[0].elements[1];
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, ArrayOfVec3_WriteArray_ValueRHS) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read_write> arr : array<vec3<f32>, 2>;
+
+fn f() {
+ arr = array(vec3(1.5, 2.5, 3.5), vec3(4.5, 5.5, 6.5));
+}
+)";
+ // A whole-array value is converted by a pack helper before being stored.
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+fn tint_pack_vec3_in_composite(in : array<vec3<f32>, 2u>) -> array<tint_packed_vec3_f32_array_element, 2u> {
+ var result : array<tint_packed_vec3_f32_array_element, 2u>;
+ for(var i : u32; (i < 2u); i = (i + 1)) {
+ result[i] = tint_packed_vec3_f32_array_element(__packed_vec3<f32>(in[i]));
+ }
+ return result;
+}
+
+@group(0) @binding(0) var<storage, read_write> arr : array<tint_packed_vec3_f32_array_element, 2u>;
+
+fn f() {
+ arr = tint_pack_vec3_in_composite(array(vec3(1.5, 2.5, 3.5), vec3(4.5, 5.5, 6.5)));
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, ArrayOfVec3_WriteArray_RefRHS) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read_write> arr : array<vec3<f32>, 2>;
+@group(0) @binding(1) var<uniform> in : array<vec3<f32>, 2>;
+
+fn f() {
+ arr = in;
+}
+)";
+ // Copying between two packed arrays needs neither a pack nor an unpack helper.
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> arr : array<tint_packed_vec3_f32_array_element, 2u>;
+
+@group(0) @binding(1) var<uniform> in : array<tint_packed_vec3_f32_array_element, 2u>;
+
+fn f() {
+ arr = in;
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, ArrayOfVec3_WriteVector_ValueRHS) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read_write> arr : array<vec3<f32>, 4>;
+
+fn f() {
+ arr[0] = vec3(1.23);
+}
+)";
+ // The store targets `.elements` of the wrapper struct; the value is cast to __packed_vec3.
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> arr : array<tint_packed_vec3_f32_array_element, 4u>;
+
+fn f() {
+ arr[0].elements = __packed_vec3<f32>(vec3(1.22999999999999998224));
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, ArrayOfVec3_WriteVector_RefRHS) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read_write> arr : array<vec3<f32>, 4>;
+@group(0) @binding(1) var<uniform> in_arr : array<vec3<f32>, 4>;
+@group(0) @binding(2) var<uniform> in_vec : vec3<f32>;
+
+fn f() {
+ arr[0] = in_arr[0];
+ arr[1] = in_vec;
+}
+)";
+ // Packed sources are stored straight into `.elements`, with no pack or unpack casts.
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> arr : array<tint_packed_vec3_f32_array_element, 4u>;
+
+@group(0) @binding(1) var<uniform> in_arr : array<tint_packed_vec3_f32_array_element, 4u>;
+
+@group(0) @binding(2) var<uniform> in_vec : __packed_vec3<f32>;
+
+fn f() {
+ arr[0].elements = in_arr[0].elements;
+ arr[1].elements = in_vec;
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, ArrayOfVec3_WriteComponent_MemberAccessor) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read_write> arr : array<vec3<f32>, 4>;
+
+fn f() {
+ arr[0].y = 1.23;
+}
+)";
+ // A component write goes through `.elements` of the wrapper struct, with no pack cast.
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> arr : array<tint_packed_vec3_f32_array_element, 4u>;
+
+fn f() {
+ arr[0].elements.y = 1.22999999999999998224;
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, ArrayOfVec3_WriteComponent_IndexAccessor) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read_write> arr : array<vec3<f32>, 4>;
+
+fn f() {
+ arr[0][1] = 1.23;
+}
+)";
+ // An indexed component write goes through `.elements` of the wrapper struct.
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> arr : array<tint_packed_vec3_f32_array_element, 4u>;
+
+fn f() {
+ arr[0].elements[1] = 1.22999999999999998224;
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, Matrix_ReadMatrix) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage> m : mat3x3<f32>;
+
+fn f() {
+ let x = m;
+}
+)";
+ // The matrix becomes an array of wrapped packed columns; a whole load uses an unpack helper.
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+fn tint_unpack_vec3_in_composite(in : array<tint_packed_vec3_f32_array_element, 3u>) -> mat3x3<f32> {
+ var result : mat3x3<f32>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ result[i] = vec3<f32>(in[i].elements);
+ }
+ return result;
+}
+
+@group(0) @binding(0) var<storage> m : array<tint_packed_vec3_f32_array_element, 3u>;
+
+fn f() {
+ let x = tint_unpack_vec3_in_composite(m);
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, Matrix_ReadColumn) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage> m : mat3x3<f32>;
+
+fn f() {
+ let x = m[1];
+}
+)";
+ // A column load reads `.elements` from the wrapper struct and casts it to vec3.
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage> m : array<tint_packed_vec3_f32_array_element, 3u>;
+
+fn f() {
+ let x = vec3<f32>(m[1].elements);
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, Matrix_ReadComponent_MemberAccessChain) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage> m : mat3x3<f32>;
+
+fn f() {
+ let x = m[1].yz.x;
+}
+)";
+ // The column is unpacked with a vec3 cast before the swizzle chain is applied.
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage> m : array<tint_packed_vec3_f32_array_element, 3u>;
+
+fn f() {
+ let x = vec3<f32>(m[1].elements).yz.x;
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, Matrix_ReadComponent_IndexAccessor) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage> m : mat3x3<f32>;
+
+fn f() {
+ let x = m[2][1];
+}
+)";
+ // A component read goes through `.elements` of the column's wrapper struct, with no cast.
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage> m : array<tint_packed_vec3_f32_array_element, 3u>;
+
+fn f() {
+ let x = m[2].elements[1];
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, Matrix_WriteMatrix_ValueRHS) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read_write> m : mat3x3<f32>;
+
+fn f() {
+ m = mat3x3(1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5);
+}
+)";
+ // A whole-matrix value is converted to an array of packed columns by a pack helper.
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+fn tint_pack_vec3_in_composite(in : mat3x3<f32>) -> array<tint_packed_vec3_f32_array_element, 3u> {
+ var result : array<tint_packed_vec3_f32_array_element, 3u>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ result[i] = tint_packed_vec3_f32_array_element(__packed_vec3<f32>(in[i]));
+ }
+ return result;
+}
+
+@group(0) @binding(0) var<storage, read_write> m : array<tint_packed_vec3_f32_array_element, 3u>;
+
+fn f() {
+ m = tint_pack_vec3_in_composite(mat3x3(1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5));
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, Matrix_WriteMatrix_RefRHS) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read_write> m : mat3x3<f32>;
+@group(0) @binding(1) var<uniform> in : mat3x3<f32>;
+
+fn f() {
+ m = in;
+}
+)";
+ // Copying between two packed matrices needs neither a pack nor an unpack helper.
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> m : array<tint_packed_vec3_f32_array_element, 3u>;
+
+@group(0) @binding(1) var<uniform> in : array<tint_packed_vec3_f32_array_element, 3u>;
+
+fn f() {
+ m = in;
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, Matrix_WriteColumn_ValueRHS) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read_write> m : mat3x3<f32>;
+
+fn f() {
+ m[1] = vec3(1.23);
+}
+)";
+ // The store targets `.elements` of the column's wrapper struct, with a __packed_vec3 cast.
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> m : array<tint_packed_vec3_f32_array_element, 3u>;
+
+fn f() {
+ m[1].elements = __packed_vec3<f32>(vec3(1.22999999999999998224));
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, Matrix_WriteColumn_RefRHS) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read_write> m : mat3x3<f32>;
+@group(0) @binding(1) var<uniform> in_mat : mat3x3<f32>;
+@group(0) @binding(1) var<uniform> in_vec : vec3<f32>;
+
+fn f() {
+ m[0] = in_mat[0];
+ m[1] = in_vec;
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> m : array<tint_packed_vec3_f32_array_element, 3u>;
+
+@group(0) @binding(1) var<uniform> in_mat : array<tint_packed_vec3_f32_array_element, 3u>;
+
+@group(0) @binding(1) var<uniform> in_vec : __packed_vec3<f32>;
+
+fn f() {
+ m[0].elements = in_mat[0].elements;
+ m[1].elements = in_vec;
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, Matrix_WriteComponent_MemberAccessor) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read_write> m : mat3x3<f32>;
+
+fn f() {
+ m[1].y = 1.23;
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> m : array<tint_packed_vec3_f32_array_element, 3u>;
+
+fn f() {
+ m[1].elements.y = 1.22999999999999998224;
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, Matrix_WriteComponent_IndexAccessor) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read_write> m : mat3x3<f32>;
+
+fn f() {
+ m[1][2] = 1.23;
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> m : array<tint_packed_vec3_f32_array_element, 3u>;
+
+fn f() {
+ m[1].elements[2] = 1.22999999999999998224;
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, ArrayOfMatrix_ReadArray) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage> arr : array<mat3x3<f32>, 4>;
+
+fn f() {
+ let x = arr;
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+fn tint_unpack_vec3_in_composite(in : array<tint_packed_vec3_f32_array_element, 3u>) -> mat3x3<f32> {
+ var result : mat3x3<f32>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ result[i] = vec3<f32>(in[i].elements);
+ }
+ return result;
+}
+
+fn tint_unpack_vec3_in_composite_1(in : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>) -> array<mat3x3<f32>, 4u> {
+ var result : array<mat3x3<f32>, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = tint_unpack_vec3_in_composite(in[i]);
+ }
+ return result;
+}
+
+@group(0) @binding(0) var<storage> arr : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>;
+
+fn f() {
+ let x = tint_unpack_vec3_in_composite_1(arr);
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, ArrayOfMatrix_ReadMatrix) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage> arr : array<mat3x3<f32>, 4>;
+
+fn f() {
+ let x = arr[0];
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+fn tint_unpack_vec3_in_composite(in : array<tint_packed_vec3_f32_array_element, 3u>) -> mat3x3<f32> {
+ var result : mat3x3<f32>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ result[i] = vec3<f32>(in[i].elements);
+ }
+ return result;
+}
+
+@group(0) @binding(0) var<storage> arr : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>;
+
+fn f() {
+ let x = tint_unpack_vec3_in_composite(arr[0]);
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, ArrayOfMatrix_ReadColumn) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage> arr : array<mat3x3<f32>, 4>;
+
+fn f() {
+ let x = arr[0][1];
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage> arr : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>;
+
+fn f() {
+ let x = vec3<f32>(arr[0][1].elements);
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, ArrayOfMatrix_ReadComponent_MemberAccessor) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage> arr : array<mat3x3<f32>, 4>;
+
+fn f() {
+ let x = arr[0][1].y;
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage> arr : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>;
+
+fn f() {
+ let x = arr[0][1].elements.y;
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, ArrayOfMatrix_ReadComponent_IndexAccessor) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage> arr : array<mat3x3<f32>, 4>;
+
+fn f() {
+ let x = arr[0][1][2];
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage> arr : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>;
+
+fn f() {
+ let x = arr[0][1].elements[2];
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, ArrayOfMatrix_WriteArray_ValueRHS) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read_write> arr : array<mat3x3<f32>, 2>;
+
+fn f() {
+ arr = array(mat3x3<f32>(), mat3x3(1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5));
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+fn tint_pack_vec3_in_composite(in : mat3x3<f32>) -> array<tint_packed_vec3_f32_array_element, 3u> {
+ var result : array<tint_packed_vec3_f32_array_element, 3u>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ result[i] = tint_packed_vec3_f32_array_element(__packed_vec3<f32>(in[i]));
+ }
+ return result;
+}
+
+fn tint_pack_vec3_in_composite_1(in : array<mat3x3<f32>, 2u>) -> array<array<tint_packed_vec3_f32_array_element, 3u>, 2u> {
+ var result : array<array<tint_packed_vec3_f32_array_element, 3u>, 2u>;
+ for(var i : u32; (i < 2u); i = (i + 1)) {
+ result[i] = tint_pack_vec3_in_composite(in[i]);
+ }
+ return result;
+}
+
+@group(0) @binding(0) var<storage, read_write> arr : array<array<tint_packed_vec3_f32_array_element, 3u>, 2u>;
+
+fn f() {
+ arr = tint_pack_vec3_in_composite_1(array(mat3x3<f32>(), mat3x3(1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5)));
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, ArrayOfMatrix_WriteArray_RefRHS) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read_write> arr : array<mat3x3<f32>, 2>;
+@group(0) @binding(1) var<uniform> in : array<mat3x3<f32>, 2>;
+
+fn f() {
+ arr = in;
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> arr : array<array<tint_packed_vec3_f32_array_element, 3u>, 2u>;
+
+@group(0) @binding(1) var<uniform> in : array<array<tint_packed_vec3_f32_array_element, 3u>, 2u>;
+
+fn f() {
+ arr = in;
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, ArrayOfMatrix_WriteMatrix_ValueRHS) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read_write> arr : array<mat3x3<f32>, 4>;
+
+fn f() {
+ arr[0] = mat3x3(1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9);
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+fn tint_pack_vec3_in_composite(in : mat3x3<f32>) -> array<tint_packed_vec3_f32_array_element, 3u> {
+ var result : array<tint_packed_vec3_f32_array_element, 3u>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ result[i] = tint_packed_vec3_f32_array_element(__packed_vec3<f32>(in[i]));
+ }
+ return result;
+}
+
+@group(0) @binding(0) var<storage, read_write> arr : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>;
+
+fn f() {
+ arr[0] = tint_pack_vec3_in_composite(mat3x3(1.10000000000000008882, 2.20000000000000017764, 3.29999999999999982236, 4.40000000000000035527, 5.5, 6.59999999999999964473, 7.70000000000000017764, 8.80000000000000071054, 9.90000000000000035527));
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, ArrayOfMatrix_WriteMatrix_RefRHS) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read_write> arr : array<mat3x3<f32>, 4>;
+@group(0) @binding(1) var<uniform> in_arr : array<mat3x3<f32>, 4>;
+@group(0) @binding(2) var<uniform> in_mat : mat3x3<f32>;
+
+fn f() {
+ arr[0] = in_arr[0];
+ arr[1] = in_mat;
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> arr : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>;
+
+@group(0) @binding(1) var<uniform> in_arr : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>;
+
+@group(0) @binding(2) var<uniform> in_mat : array<tint_packed_vec3_f32_array_element, 3u>;
+
+fn f() {
+ arr[0] = in_arr[0];
+ arr[1] = in_mat;
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, ArrayOfMatrix_WriteVector_ValueRHS) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read_write> arr : array<mat3x3<f32>, 4>;
+
+fn f() {
+ arr[0][1] = vec3(1.23);
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> arr : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>;
+
+fn f() {
+ arr[0][1].elements = __packed_vec3<f32>(vec3(1.22999999999999998224));
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, ArrayOfMatrix_WriteVector_RefRHS) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read_write> arr : array<mat3x3<f32>, 4>;
+@group(0) @binding(1) var<uniform> in_arr : array<mat3x3<f32>, 4>;
+@group(0) @binding(2) var<uniform> in_mat : mat3x3<f32>;
+@group(0) @binding(3) var<uniform> in_vec : vec3<f32>;
+
+fn f() {
+ arr[0][0] = arr[0][1];
+ arr[0][1] = in_mat[2];
+ arr[0][2] = in_vec;
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> arr : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>;
+
+@group(0) @binding(1) var<uniform> in_arr : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>;
+
+@group(0) @binding(2) var<uniform> in_mat : array<tint_packed_vec3_f32_array_element, 3u>;
+
+@group(0) @binding(3) var<uniform> in_vec : __packed_vec3<f32>;
+
+fn f() {
+ arr[0][0].elements = arr[0][1].elements;
+ arr[0][1].elements = in_mat[2].elements;
+ arr[0][2].elements = in_vec;
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, ArrayOfMatrix_WriteComponent_MemberAccessor) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read_write> arr : array<mat3x3<f32>, 4>;
+
+fn f() {
+ arr[0][1].y = 1.23;
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> arr : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>;
+
+fn f() {
+ arr[0][1].elements.y = 1.22999999999999998224;
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, ArrayOfMatrix_WriteComponent_IndexAccessor) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read_write> arr : array<mat3x3<f32>, 4>;
+
+fn f() {
+ arr[0][1][2] = 1.23;
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> arr : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>;
+
+fn f() {
+ arr[0][1].elements[2] = 1.22999999999999998224;
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, StructMember_Vec3_ReadStruct) {
+ auto* src = R"(
+struct S {
+ v : vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage> P : S;
+
+fn f() {
+ let x = P;
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ v : __packed_vec3<f32>,
+}
+
+fn tint_unpack_vec3_in_composite(in : S_tint_packed_vec3) -> S {
+ var result : S;
+ result.v = vec3<f32>(in.v);
+ return result;
+}
+
+struct S {
+ v : vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage> P : S_tint_packed_vec3;
+
+fn f() {
+ let x = tint_unpack_vec3_in_composite(P);
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, StructMember_Vec3_ReadVector) {
+ auto* src = R"(
+struct S {
+ v : vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage> P : S;
+
+fn f() {
+ let x = P.v;
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ v : __packed_vec3<f32>,
+}
+
+struct S {
+ v : vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage> P : S_tint_packed_vec3;
+
+fn f() {
+ let x = vec3<f32>(P.v);
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, StructMember_Vec3_ReadComponent_MemberAccessChain) {
+ auto* src = R"(
+struct S {
+ v : vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage> P : S;
+
+fn f() {
+ let x = P.v.yz.x;
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ v : __packed_vec3<f32>,
+}
+
+struct S {
+ v : vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage> P : S_tint_packed_vec3;
+
+fn f() {
+ let x = vec3<f32>(P.v).yz.x;
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, StructMember_Vec3_ReadComponent_IndexAccessor) {
+ auto* src = R"(
+struct S {
+ v : vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage> P : S;
+
+fn f() {
+ let x = P.v[1];
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ v : __packed_vec3<f32>,
+}
+
+struct S {
+ v : vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage> P : S_tint_packed_vec3;
+
+fn f() {
+ let x = P.v[1];
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, StructMember_Vec3_WriteStruct_ValueRHS) {
+ auto* src = R"(
+struct S {
+ v : vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S;
+
+fn f() {
+ P = S(vec3(1.23));
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ v : __packed_vec3<f32>,
+}
+
+fn tint_pack_vec3_in_composite(in : S) -> S_tint_packed_vec3 {
+ var result : S_tint_packed_vec3;
+ result.v = __packed_vec3<f32>(in.v);
+ return result;
+}
+
+struct S {
+ v : vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S_tint_packed_vec3;
+
+fn f() {
+ P = tint_pack_vec3_in_composite(S(vec3(1.22999999999999998224)));
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, StructMember_Vec3_WriteStruct_RefRHS) {
+ auto* src = R"(
+struct S {
+ v : vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S;
+@group(0) @binding(1) var<uniform> in : S;
+
+fn f() {
+ P = in;
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ v : __packed_vec3<f32>,
+}
+
+struct S {
+ v : vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S_tint_packed_vec3;
+
+@group(0) @binding(1) var<uniform> in : S_tint_packed_vec3;
+
+fn f() {
+ P = in;
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, StructMember_Vec3_WriteVector_ValueRHS) {
+ auto* src = R"(
+struct S {
+ v : vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S;
+
+fn f() {
+ P.v = vec3(1.23);
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ v : __packed_vec3<f32>,
+}
+
+struct S {
+ v : vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S_tint_packed_vec3;
+
+fn f() {
+ P.v = __packed_vec3<f32>(vec3(1.22999999999999998224));
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, StructMember_Vec3_WriteVector_RefRHS) {
+ auto* src = R"(
+struct S {
+ v1 : vec3<f32>,
+ v2 : vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S;
+@group(0) @binding(1) var<uniform> in_str : S;
+@group(0) @binding(2) var<uniform> in_vec : vec3<f32>;
+
+fn f() {
+ P.v1 = in_str.v1;
+ P.v2 = in_vec;
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ v1 : __packed_vec3<f32>,
+ @align(16)
+ v2 : __packed_vec3<f32>,
+}
+
+struct S {
+ v1 : vec3<f32>,
+ v2 : vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S_tint_packed_vec3;
+
+@group(0) @binding(1) var<uniform> in_str : S_tint_packed_vec3;
+
+@group(0) @binding(2) var<uniform> in_vec : __packed_vec3<f32>;
+
+fn f() {
+ P.v1 = in_str.v1;
+ P.v2 = in_vec;
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, StructMember_Vec3_WriteComponent_MemberAccessor) {
+ auto* src = R"(
+struct S {
+ v : vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S;
+
+fn f() {
+ P.v.y = 1.23;
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ v : __packed_vec3<f32>,
+}
+
+struct S {
+ v : vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S_tint_packed_vec3;
+
+fn f() {
+ P.v.y = 1.22999999999999998224;
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, StructMember_Vec3_WriteComponent_IndexAccessor) {
+ auto* src = R"(
+struct S {
+ v : vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S;
+
+fn f() {
+ P.v[1] = 1.23;
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ v : __packed_vec3<f32>,
+}
+
+struct S {
+ v : vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S_tint_packed_vec3;
+
+fn f() {
+ P.v[1] = 1.22999999999999998224;
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, StructMember_ArrayOfVec3_ReadStruct) {
+ auto* src = R"(
+struct S {
+ arr : array<vec3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage> P : S;
+
+fn f() {
+ let x = P;
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ arr : array<tint_packed_vec3_f32_array_element, 4u>,
+}
+
+fn tint_unpack_vec3_in_composite(in : array<tint_packed_vec3_f32_array_element, 4u>) -> array<vec3<f32>, 4u> {
+ var result : array<vec3<f32>, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = vec3<f32>(in[i].elements);
+ }
+ return result;
+}
+
+fn tint_unpack_vec3_in_composite_1(in : S_tint_packed_vec3) -> S {
+ var result : S;
+ result.arr = tint_unpack_vec3_in_composite(in.arr);
+ return result;
+}
+
+struct S {
+ arr : array<vec3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage> P : S_tint_packed_vec3;
+
+fn f() {
+ let x = tint_unpack_vec3_in_composite_1(P);
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, StructMember_ArrayOfVec3_ReadArray) {
+ auto* src = R"(
+struct S {
+ arr : array<vec3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage> P : S;
+
+fn f() {
+ let x = P.arr;
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ arr : array<tint_packed_vec3_f32_array_element, 4u>,
+}
+
+fn tint_unpack_vec3_in_composite(in : array<tint_packed_vec3_f32_array_element, 4u>) -> array<vec3<f32>, 4u> {
+ var result : array<vec3<f32>, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = vec3<f32>(in[i].elements);
+ }
+ return result;
+}
+
+struct S {
+ arr : array<vec3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage> P : S_tint_packed_vec3;
+
+fn f() {
+ let x = tint_unpack_vec3_in_composite(P.arr);
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, StructMember_ArrayOfVec3_ReadVector) {
+ auto* src = R"(
+struct S {
+ arr : array<vec3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage> P : S;
+
+fn f() {
+ let x = P.arr[0];
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ arr : array<tint_packed_vec3_f32_array_element, 4u>,
+}
+
+struct S {
+ arr : array<vec3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage> P : S_tint_packed_vec3;
+
+fn f() {
+ let x = vec3<f32>(P.arr[0].elements);
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, StructMember_ArrayOfVec3_ReadComponent_MemberAccessor) {
+ auto* src = R"(
+struct S {
+ arr : array<vec3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage> P : S;
+
+fn f() {
+ let x = P.arr[0].y;
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ arr : array<tint_packed_vec3_f32_array_element, 4u>,
+}
+
+struct S {
+ arr : array<vec3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage> P : S_tint_packed_vec3;
+
+fn f() {
+ let x = P.arr[0].elements.y;
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, StructMember_ArrayOfVec3_ReadComponent_IndexAccessor) {
+ auto* src = R"(
+struct S {
+ arr : array<vec3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage> P : S;
+
+fn f() {
+ let x = P.arr[0][1];
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ arr : array<tint_packed_vec3_f32_array_element, 4u>,
+}
+
+struct S {
+ arr : array<vec3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage> P : S_tint_packed_vec3;
+
+fn f() {
+ let x = P.arr[0].elements[1];
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, StructMember_ArrayOfVec3_WriteStruct_ValueRHS) {
+ auto* src = R"(
+struct S {
+ arr : array<vec3<f32>, 2>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S;
+
+fn f() {
+ P = S(array(vec3(1.5, 4.5, 7.5), vec3(9.5, 6.5, 3.5)));
+}
+)";
+
+ auto* expect =
+ R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ arr : array<tint_packed_vec3_f32_array_element, 2u>,
+}
+
+fn tint_pack_vec3_in_composite(in : array<vec3<f32>, 2u>) -> array<tint_packed_vec3_f32_array_element, 2u> {
+ var result : array<tint_packed_vec3_f32_array_element, 2u>;
+ for(var i : u32; (i < 2u); i = (i + 1)) {
+ result[i] = tint_packed_vec3_f32_array_element(__packed_vec3<f32>(in[i]));
+ }
+ return result;
+}
+
+fn tint_pack_vec3_in_composite_1(in : S) -> S_tint_packed_vec3 {
+ var result : S_tint_packed_vec3;
+ result.arr = tint_pack_vec3_in_composite(in.arr);
+ return result;
+}
+
+struct S {
+ arr : array<vec3<f32>, 2>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S_tint_packed_vec3;
+
+fn f() {
+ P = tint_pack_vec3_in_composite_1(S(array(vec3(1.5, 4.5, 7.5), vec3(9.5, 6.5, 3.5))));
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, StructMember_ArrayOfVec3_WriteStruct_RefRHS) {
+ auto* src = R"(
+struct S {
+ arr : array<vec3<f32>, 2>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S;
+@group(0) @binding(1) var<uniform> in : S;
+
+fn f() {
+ P = in;
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ arr : array<tint_packed_vec3_f32_array_element, 2u>,
+}
+
+struct S {
+ arr : array<vec3<f32>, 2>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S_tint_packed_vec3;
+
+@group(0) @binding(1) var<uniform> in : S_tint_packed_vec3;
+
+fn f() {
+ P = in;
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, StructMember_ArrayOfVec3_WriteArray_ValueRHS) {
+ auto* src = R"(
+struct S {
+ arr : array<vec3<f32>, 2>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S;
+
+fn f() {
+ P.arr = array(vec3(1.5, 4.5, 7.5), vec3(9.5, 6.5, 3.5));
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ arr : array<tint_packed_vec3_f32_array_element, 2u>,
+}
+
+fn tint_pack_vec3_in_composite(in : array<vec3<f32>, 2u>) -> array<tint_packed_vec3_f32_array_element, 2u> {
+ var result : array<tint_packed_vec3_f32_array_element, 2u>;
+ for(var i : u32; (i < 2u); i = (i + 1)) {
+ result[i] = tint_packed_vec3_f32_array_element(__packed_vec3<f32>(in[i]));
+ }
+ return result;
+}
+
+struct S {
+ arr : array<vec3<f32>, 2>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S_tint_packed_vec3;
+
+fn f() {
+ P.arr = tint_pack_vec3_in_composite(array(vec3(1.5, 4.5, 7.5), vec3(9.5, 6.5, 3.5)));
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, StructMember_ArrayOfVec3_WriteArray_RefRHS) {
+ auto* src = R"(
+struct S {
+ arr1 : array<vec3<f32>, 2>,
+ arr2 : array<vec3<f32>, 2>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S;
+@group(0) @binding(1) var<uniform> in_str : S;
+@group(0) @binding(2) var<uniform> in_arr : array<vec3<f32>, 2>;
+
+fn f() {
+ P.arr1 = in_str.arr1;
+ P.arr2 = in_arr;
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ arr1 : array<tint_packed_vec3_f32_array_element, 2u>,
+ @align(16)
+ arr2 : array<tint_packed_vec3_f32_array_element, 2u>,
+}
+
+struct S {
+ arr1 : array<vec3<f32>, 2>,
+ arr2 : array<vec3<f32>, 2>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S_tint_packed_vec3;
+
+@group(0) @binding(1) var<uniform> in_str : S_tint_packed_vec3;
+
+@group(0) @binding(2) var<uniform> in_arr : array<tint_packed_vec3_f32_array_element, 2u>;
+
+fn f() {
+ P.arr1 = in_str.arr1;
+ P.arr2 = in_arr;
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, StructMember_ArrayOfVec3_WriteVector_ValueRHS) {
+ auto* src = R"(
+struct S {
+ arr : array<vec3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S;
+
+fn f() {
+ P.arr[0] = vec3(1.23);
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ arr : array<tint_packed_vec3_f32_array_element, 4u>,
+}
+
+struct S {
+ arr : array<vec3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S_tint_packed_vec3;
+
+fn f() {
+ P.arr[0].elements = __packed_vec3<f32>(vec3(1.22999999999999998224));
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, StructMember_ArrayOfVec3_WriteVector_RefRHS) {
+    // Copies vec3 values from uniform variables (struct member, array, bare vec3) into array
+    // elements. Both sides should stay packed: `.elements` to `.elements`, or `in_vec` as-is.
+    const char* const wgsl_in = R"(
+struct S {
+ arr : array<vec3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S;
+@group(0) @binding(1) var<uniform> in_str : S;
+@group(0) @binding(2) var<uniform> in_arr : array<vec3<f32>, 4>;
+@group(0) @binding(3) var<uniform> in_vec : vec3<f32>;
+
+fn f() {
+ P.arr[0] = in_str.arr[0];
+ P.arr[1] = in_arr[1];
+ P.arr[2] = in_vec;
+}
+)";
+    const char* const wgsl_want = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ arr : array<tint_packed_vec3_f32_array_element, 4u>,
+}
+
+struct S {
+ arr : array<vec3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S_tint_packed_vec3;
+
+@group(0) @binding(1) var<uniform> in_str : S_tint_packed_vec3;
+
+@group(0) @binding(2) var<uniform> in_arr : array<tint_packed_vec3_f32_array_element, 4u>;
+
+@group(0) @binding(3) var<uniform> in_vec : __packed_vec3<f32>;
+
+fn f() {
+ P.arr[0].elements = in_str.arr[0].elements;
+ P.arr[1].elements = in_arr[1].elements;
+ P.arr[2].elements = in_vec;
+}
+)";
+
+    Transform::DataMap xform_data;
+    auto xform_out = Run<PackedVec3>(wgsl_in, xform_data);
+    EXPECT_EQ(wgsl_want, str(xform_out));
+}
+
+TEST_F(PackedVec3Test, StructMember_ArrayOfVec3_WriteComponent_MemberAccessor) {
+    // Writes the `.y` component of an array<vec3<f32>> element through a member accessor.
+    // The expected output only inserts `.elements` ahead of the component access.
+    const char* const wgsl_in = R"(
+struct S {
+ arr : array<vec3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S;
+
+fn f() {
+ P.arr[0].y = 1.23;
+}
+)";
+    const char* const wgsl_want = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ arr : array<tint_packed_vec3_f32_array_element, 4u>,
+}
+
+struct S {
+ arr : array<vec3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S_tint_packed_vec3;
+
+fn f() {
+ P.arr[0].elements.y = 1.22999999999999998224;
+}
+)";
+
+    Transform::DataMap xform_data;
+    auto xform_out = Run<PackedVec3>(wgsl_in, xform_data);
+    EXPECT_EQ(wgsl_want, str(xform_out));
+}
+
+TEST_F(PackedVec3Test, StructMember_ArrayOfVec3_WriteComponent_IndexAccessor) {
+    // Writes component [1] of an array<vec3<f32>> element through an index accessor.
+    // The expected output only inserts `.elements` ahead of the component index.
+    const char* const wgsl_in = R"(
+struct S {
+ arr : array<vec3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S;
+
+fn f() {
+ P.arr[0][1] = 1.23;
+}
+)";
+    const char* const wgsl_want = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ arr : array<tint_packed_vec3_f32_array_element, 4u>,
+}
+
+struct S {
+ arr : array<vec3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S_tint_packed_vec3;
+
+fn f() {
+ P.arr[0].elements[1] = 1.22999999999999998224;
+}
+)";
+
+    Transform::DataMap xform_data;
+    auto xform_out = Run<PackedVec3>(wgsl_in, xform_data);
+    EXPECT_EQ(wgsl_want, str(xform_out));
+}
+
+TEST_F(PackedVec3Test, StructMember_Matrix_ReadStruct) {
+    // Loads a whole storage struct holding a mat3x3<f32> member into a `let`.
+    // The expected output unpacks it via generated helpers for the matrix and for the struct.
+    const char* const wgsl_in = R"(
+struct S {
+ m : mat3x3<f32>,
+}
+
+@group(0) @binding(0) var<storage> P : S;
+
+fn f() {
+ let x = P;
+}
+)";
+    const char* const wgsl_want = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ m : array<tint_packed_vec3_f32_array_element, 3u>,
+}
+
+fn tint_unpack_vec3_in_composite(in : array<tint_packed_vec3_f32_array_element, 3u>) -> mat3x3<f32> {
+ var result : mat3x3<f32>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ result[i] = vec3<f32>(in[i].elements);
+ }
+ return result;
+}
+
+fn tint_unpack_vec3_in_composite_1(in : S_tint_packed_vec3) -> S {
+ var result : S;
+ result.m = tint_unpack_vec3_in_composite(in.m);
+ return result;
+}
+
+struct S {
+ m : mat3x3<f32>,
+}
+
+@group(0) @binding(0) var<storage> P : S_tint_packed_vec3;
+
+fn f() {
+ let x = tint_unpack_vec3_in_composite_1(P);
+}
+)";
+
+    Transform::DataMap xform_data;
+    auto xform_out = Run<PackedVec3>(wgsl_in, xform_data);
+    EXPECT_EQ(wgsl_want, str(xform_out));
+}
+
+TEST_F(PackedVec3Test, StructMember_Matrix_ReadMatrix) {
+    // Loads the mat3x3<f32> member of a storage struct into a `let`.
+    // The expected output calls a single generated unpack helper on `P.m`.
+    const char* const wgsl_in = R"(
+struct S {
+ m : mat3x3<f32>,
+}
+
+@group(0) @binding(0) var<storage> P : S;
+
+fn f() {
+ let x = P.m;
+}
+)";
+    const char* const wgsl_want = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ m : array<tint_packed_vec3_f32_array_element, 3u>,
+}
+
+fn tint_unpack_vec3_in_composite(in : array<tint_packed_vec3_f32_array_element, 3u>) -> mat3x3<f32> {
+ var result : mat3x3<f32>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ result[i] = vec3<f32>(in[i].elements);
+ }
+ return result;
+}
+
+struct S {
+ m : mat3x3<f32>,
+}
+
+@group(0) @binding(0) var<storage> P : S_tint_packed_vec3;
+
+fn f() {
+ let x = tint_unpack_vec3_in_composite(P.m);
+}
+)";
+
+    Transform::DataMap xform_data;
+    auto xform_out = Run<PackedVec3>(wgsl_in, xform_data);
+    EXPECT_EQ(wgsl_want, str(xform_out));
+}
+
+TEST_F(PackedVec3Test, StructMember_Matrix_ReadColumn) {
+    // Reads column 1 of a mat3x3<f32> struct member held in storage.
+    // The expected output converts `P.m[1].elements` with vec3<f32>() and emits no helper.
+    const char* const wgsl_in = R"(
+struct S {
+ m : mat3x3<f32>,
+}
+
+@group(0) @binding(0) var<storage> P : S;
+
+fn f() {
+ let x = P.m[1];
+}
+)";
+    const char* const wgsl_want = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ m : array<tint_packed_vec3_f32_array_element, 3u>,
+}
+
+struct S {
+ m : mat3x3<f32>,
+}
+
+@group(0) @binding(0) var<storage> P : S_tint_packed_vec3;
+
+fn f() {
+ let x = vec3<f32>(P.m[1].elements);
+}
+)";
+
+    Transform::DataMap xform_data;
+    auto xform_out = Run<PackedVec3>(wgsl_in, xform_data);
+    EXPECT_EQ(wgsl_want, str(xform_out));
+}
+
+TEST_F(PackedVec3Test, StructMember_Matrix_ReadComponent_MemberAccessChain) {
+    // Reads `P.m[1].yz.x`, i.e. a swizzle followed by a component access on a matrix column.
+    // The expected output converts the column with vec3<f32>() before the `.yz.x` chain.
+    const char* const wgsl_in = R"(
+struct S {
+ m : mat3x3<f32>,
+}
+
+@group(0) @binding(0) var<storage> P : S;
+
+fn f() {
+ let x = P.m[1].yz.x;
+}
+)";
+    const char* const wgsl_want = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ m : array<tint_packed_vec3_f32_array_element, 3u>,
+}
+
+struct S {
+ m : mat3x3<f32>,
+}
+
+@group(0) @binding(0) var<storage> P : S_tint_packed_vec3;
+
+fn f() {
+ let x = vec3<f32>(P.m[1].elements).yz.x;
+}
+)";
+
+    Transform::DataMap xform_data;
+    auto xform_out = Run<PackedVec3>(wgsl_in, xform_data);
+    EXPECT_EQ(wgsl_want, str(xform_out));
+}
+
+TEST_F(PackedVec3Test, StructMember_Matrix_ReadComponent_IndexAccessor) {
+    // Reads a single scalar `P.m[2][1]` from a mat3x3<f32> struct member held in storage.
+    // The expected output indexes `.elements` directly, without a vec3<f32>() conversion.
+    const char* const wgsl_in = R"(
+struct S {
+ m : mat3x3<f32>,
+}
+
+@group(0) @binding(0) var<storage> P : S;
+
+fn f() {
+ let x = P.m[2][1];
+}
+)";
+    const char* const wgsl_want = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ m : array<tint_packed_vec3_f32_array_element, 3u>,
+}
+
+struct S {
+ m : mat3x3<f32>,
+}
+
+@group(0) @binding(0) var<storage> P : S_tint_packed_vec3;
+
+fn f() {
+ let x = P.m[2].elements[1];
+}
+)";
+
+    Transform::DataMap xform_data;
+    auto xform_out = Run<PackedVec3>(wgsl_in, xform_data);
+    EXPECT_EQ(wgsl_want, str(xform_out));
+}
+
+TEST_F(PackedVec3Test, StructMember_Matrix_WriteStruct_ValueRHS) {
+    // Assigns a constructed S (holding a mat3x3) to the storage variable as a whole.
+    // The expected output packs the value via generated helpers for the matrix and the struct.
+    const char* const wgsl_in = R"(
+struct S {
+ m : mat3x3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S;
+
+fn f() {
+ P = S(mat3x3(1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5));
+}
+)";
+
+    const char* const wgsl_want = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ m : array<tint_packed_vec3_f32_array_element, 3u>,
+}
+
+fn tint_pack_vec3_in_composite(in : mat3x3<f32>) -> array<tint_packed_vec3_f32_array_element, 3u> {
+ var result : array<tint_packed_vec3_f32_array_element, 3u>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ result[i] = tint_packed_vec3_f32_array_element(__packed_vec3<f32>(in[i]));
+ }
+ return result;
+}
+
+fn tint_pack_vec3_in_composite_1(in : S) -> S_tint_packed_vec3 {
+ var result : S_tint_packed_vec3;
+ result.m = tint_pack_vec3_in_composite(in.m);
+ return result;
+}
+
+struct S {
+ m : mat3x3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S_tint_packed_vec3;
+
+fn f() {
+ P = tint_pack_vec3_in_composite_1(S(mat3x3(1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5)));
+}
+)";
+
+    Transform::DataMap xform_data;
+    auto xform_out = Run<PackedVec3>(wgsl_in, xform_data);
+    EXPECT_EQ(wgsl_want, str(xform_out));
+}
+
+TEST_F(PackedVec3Test, StructMember_Matrix_WriteStruct_RefRHS) {
+    // Assigns a uniform struct variable to a storage struct variable of the same type.
+    // Both variables should be retyped to the packed struct, leaving `P = in;` untouched.
+    const char* const wgsl_in = R"(
+struct S {
+ m : mat3x3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S;
+@group(0) @binding(1) var<uniform> in : S;
+
+fn f() {
+ P = in;
+}
+)";
+    const char* const wgsl_want = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ m : array<tint_packed_vec3_f32_array_element, 3u>,
+}
+
+struct S {
+ m : mat3x3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S_tint_packed_vec3;
+
+@group(0) @binding(1) var<uniform> in : S_tint_packed_vec3;
+
+fn f() {
+ P = in;
+}
+)";
+
+    Transform::DataMap xform_data;
+    auto xform_out = Run<PackedVec3>(wgsl_in, xform_data);
+    EXPECT_EQ(wgsl_want, str(xform_out));
+}
+
+TEST_F(PackedVec3Test, StructMember_Matrix_WriteMatrix_ValueRHS) {
+    // Assigns a constructed mat3x3 to the matrix member of a storage struct.
+    // The expected output wraps the value in the generated tint_pack_vec3_in_composite helper.
+    const char* const wgsl_in = R"(
+struct S {
+ m : mat3x3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S;
+
+fn f() {
+ P.m = mat3x3(1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5);
+}
+)";
+    const char* const wgsl_want = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ m : array<tint_packed_vec3_f32_array_element, 3u>,
+}
+
+fn tint_pack_vec3_in_composite(in : mat3x3<f32>) -> array<tint_packed_vec3_f32_array_element, 3u> {
+ var result : array<tint_packed_vec3_f32_array_element, 3u>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ result[i] = tint_packed_vec3_f32_array_element(__packed_vec3<f32>(in[i]));
+ }
+ return result;
+}
+
+struct S {
+ m : mat3x3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S_tint_packed_vec3;
+
+fn f() {
+ P.m = tint_pack_vec3_in_composite(mat3x3(1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5));
+}
+)";
+
+    Transform::DataMap xform_data;
+    auto xform_out = Run<PackedVec3>(wgsl_in, xform_data);
+    EXPECT_EQ(wgsl_want, str(xform_out));
+}
+
+TEST_F(PackedVec3Test, StructMember_Matrix_WriteMatrix_RefRHS) {
+    // Copies matrices from uniform variables (a struct member and a bare mat3x3) into storage.
+    // All variables should be retyped to packed forms, leaving the assignments unchanged.
+    const char* const wgsl_in = R"(
+struct S {
+ m1 : mat3x3<f32>,
+ m2 : mat3x3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S;
+@group(0) @binding(1) var<uniform> in_str : S;
+@group(0) @binding(2) var<uniform> in_mat : mat3x3<f32>;
+
+fn f() {
+ P.m1 = in_str.m1;
+ P.m2 = in_mat;
+}
+)";
+    const char* const wgsl_want = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ m1 : array<tint_packed_vec3_f32_array_element, 3u>,
+ @align(16)
+ m2 : array<tint_packed_vec3_f32_array_element, 3u>,
+}
+
+struct S {
+ m1 : mat3x3<f32>,
+ m2 : mat3x3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S_tint_packed_vec3;
+
+@group(0) @binding(1) var<uniform> in_str : S_tint_packed_vec3;
+
+@group(0) @binding(2) var<uniform> in_mat : array<tint_packed_vec3_f32_array_element, 3u>;
+
+fn f() {
+ P.m1 = in_str.m1;
+ P.m2 = in_mat;
+}
+)";
+
+    Transform::DataMap xform_data;
+    auto xform_out = Run<PackedVec3>(wgsl_in, xform_data);
+    EXPECT_EQ(wgsl_want, str(xform_out));
+}
+
+TEST_F(PackedVec3Test, StructMember_Matrix_WriteColumn_ValueRHS) {
+    // Assigns a constructed vec3 to column 1 of a mat3x3<f32> struct member in storage.
+    // The write should target `.elements`, with the value wrapped in __packed_vec3<f32>().
+    const char* const wgsl_in = R"(
+struct S {
+ m : mat3x3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S;
+
+fn f() {
+ P.m[1] = vec3(1.23);
+}
+)";
+    const char* const wgsl_want = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ m : array<tint_packed_vec3_f32_array_element, 3u>,
+}
+
+struct S {
+ m : mat3x3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S_tint_packed_vec3;
+
+fn f() {
+ P.m[1].elements = __packed_vec3<f32>(vec3(1.22999999999999998224));
+}
+)";
+
+    Transform::DataMap xform_data;
+    auto xform_out = Run<PackedVec3>(wgsl_in, xform_data);
+    EXPECT_EQ(wgsl_want, str(xform_out));
+}
+
+TEST_F(PackedVec3Test, StructMember_Matrix_WriteColumn_RefRHS) {
+    // Copies columns from uniform variables (struct member, matrix, bare vec3) into a matrix.
+    // Both sides should stay packed: `.elements` to `.elements`, or `in_vec` as-is.
+    const char* const wgsl_in = R"(
+struct S {
+ m : mat3x3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S;
+@group(0) @binding(1) var<uniform> in_str : S;
+@group(0) @binding(2) var<uniform> in_mat : mat3x3<f32>;
+@group(0) @binding(3) var<uniform> in_vec : vec3<f32>;
+
+fn f() {
+ P.m[0] = in_str.m[0];
+ P.m[1] = in_mat[1];
+ P.m[2] = in_vec;
+}
+)";
+    const char* const wgsl_want = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ m : array<tint_packed_vec3_f32_array_element, 3u>,
+}
+
+struct S {
+ m : mat3x3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S_tint_packed_vec3;
+
+@group(0) @binding(1) var<uniform> in_str : S_tint_packed_vec3;
+
+@group(0) @binding(2) var<uniform> in_mat : array<tint_packed_vec3_f32_array_element, 3u>;
+
+@group(0) @binding(3) var<uniform> in_vec : __packed_vec3<f32>;
+
+fn f() {
+ P.m[0].elements = in_str.m[0].elements;
+ P.m[1].elements = in_mat[1].elements;
+ P.m[2].elements = in_vec;
+}
+)";
+
+    Transform::DataMap xform_data;
+    auto xform_out = Run<PackedVec3>(wgsl_in, xform_data);
+    EXPECT_EQ(wgsl_want, str(xform_out));
+}
+
+TEST_F(PackedVec3Test, StructMember_Matrix_WriteComponent_MemberAccessor) {
+    // Writes the `.y` component of matrix column 1 through a member accessor.
+    // The expected output only inserts `.elements` ahead of the component access.
+    const char* const wgsl_in = R"(
+struct S {
+ m : mat3x3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S;
+
+fn f() {
+ P.m[1].y = 1.23;
+}
+)";
+    const char* const wgsl_want = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ m : array<tint_packed_vec3_f32_array_element, 3u>,
+}
+
+struct S {
+ m : mat3x3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S_tint_packed_vec3;
+
+fn f() {
+ P.m[1].elements.y = 1.22999999999999998224;
+}
+)";
+
+    Transform::DataMap xform_data;
+    auto xform_out = Run<PackedVec3>(wgsl_in, xform_data);
+    EXPECT_EQ(wgsl_want, str(xform_out));
+}
+
+TEST_F(PackedVec3Test, StructMember_Matrix_WriteComponent_IndexAccessor) {
+    // Writes component [2] of matrix column 1 through an index accessor.
+    // The expected output only inserts `.elements` ahead of the component index.
+    const char* const wgsl_in = R"(
+struct S {
+ m : mat3x3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S;
+
+fn f() {
+ P.m[1][2] = 1.23;
+}
+)";
+    const char* const wgsl_want = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ m : array<tint_packed_vec3_f32_array_element, 3u>,
+}
+
+struct S {
+ m : mat3x3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S_tint_packed_vec3;
+
+fn f() {
+ P.m[1].elements[2] = 1.22999999999999998224;
+}
+)";
+
+    Transform::DataMap xform_data;
+    auto xform_out = Run<PackedVec3>(wgsl_in, xform_data);
+    EXPECT_EQ(wgsl_want, str(xform_out));
+}
+
+TEST_F(PackedVec3Test, StructMember_ArrayOfMatrix_ReadStruct) {
+    // Loads a whole storage struct holding an array<mat3x3<f32>, 4> member into a `let`.
+    // The expected output chains three generated unpack helpers: matrix, array, then struct.
+    const char* const wgsl_in = R"(
+struct S {
+ arr : array<mat3x3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage> P : S;
+
+fn f() {
+ let x = P;
+}
+)";
+    const char* const wgsl_want = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ arr : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>,
+}
+
+fn tint_unpack_vec3_in_composite(in : array<tint_packed_vec3_f32_array_element, 3u>) -> mat3x3<f32> {
+ var result : mat3x3<f32>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ result[i] = vec3<f32>(in[i].elements);
+ }
+ return result;
+}
+
+fn tint_unpack_vec3_in_composite_1(in : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>) -> array<mat3x3<f32>, 4u> {
+ var result : array<mat3x3<f32>, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = tint_unpack_vec3_in_composite(in[i]);
+ }
+ return result;
+}
+
+fn tint_unpack_vec3_in_composite_2(in : S_tint_packed_vec3) -> S {
+ var result : S;
+ result.arr = tint_unpack_vec3_in_composite_1(in.arr);
+ return result;
+}
+
+struct S {
+ arr : array<mat3x3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage> P : S_tint_packed_vec3;
+
+fn f() {
+ let x = tint_unpack_vec3_in_composite_2(P);
+}
+)";
+
+    Transform::DataMap xform_data;
+    auto xform_out = Run<PackedVec3>(wgsl_in, xform_data);
+    EXPECT_EQ(wgsl_want, str(xform_out));
+}
+
+TEST_F(PackedVec3Test, StructMember_ArrayOfMatrix_ReadArray) {
+    // Loads the array<mat3x3<f32>, 4> member of a storage struct into a `let`.
+    // The expected output uses two generated unpack helpers: one per matrix, one for the array.
+    const char* const wgsl_in = R"(
+struct S {
+ arr : array<mat3x3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage> P : S;
+
+fn f() {
+ let x = P.arr;
+}
+)";
+    const char* const wgsl_want = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ arr : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>,
+}
+
+fn tint_unpack_vec3_in_composite(in : array<tint_packed_vec3_f32_array_element, 3u>) -> mat3x3<f32> {
+ var result : mat3x3<f32>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ result[i] = vec3<f32>(in[i].elements);
+ }
+ return result;
+}
+
+fn tint_unpack_vec3_in_composite_1(in : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>) -> array<mat3x3<f32>, 4u> {
+ var result : array<mat3x3<f32>, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = tint_unpack_vec3_in_composite(in[i]);
+ }
+ return result;
+}
+
+struct S {
+ arr : array<mat3x3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage> P : S_tint_packed_vec3;
+
+fn f() {
+ let x = tint_unpack_vec3_in_composite_1(P.arr);
+}
+)";
+
+    Transform::DataMap xform_data;
+    auto xform_out = Run<PackedVec3>(wgsl_in, xform_data);
+    EXPECT_EQ(wgsl_want, str(xform_out));
+}
+
+TEST_F(PackedVec3Test, StructMember_ArrayOfMatrix_ReadMatrix) {
+    // Loads a single matrix `P.arr[0]` out of an array<mat3x3<f32>, 4u> struct member.
+    // The expected output calls only the per-matrix unpack helper on that element.
+    const char* const wgsl_in = R"(
+struct S {
+ arr : array<mat3x3<f32>, 4u>,
+}
+
+@group(0) @binding(0) var<storage> P : S;
+
+fn f() {
+ let x = P.arr[0];
+}
+)";
+    const char* const wgsl_want = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ arr : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>,
+}
+
+fn tint_unpack_vec3_in_composite(in : array<tint_packed_vec3_f32_array_element, 3u>) -> mat3x3<f32> {
+ var result : mat3x3<f32>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ result[i] = vec3<f32>(in[i].elements);
+ }
+ return result;
+}
+
+struct S {
+ arr : array<mat3x3<f32>, 4u>,
+}
+
+@group(0) @binding(0) var<storage> P : S_tint_packed_vec3;
+
+fn f() {
+ let x = tint_unpack_vec3_in_composite(P.arr[0]);
+}
+)";
+
+    Transform::DataMap xform_data;
+    auto xform_out = Run<PackedVec3>(wgsl_in, xform_data);
+    EXPECT_EQ(wgsl_want, str(xform_out));
+}
+
+TEST_F(PackedVec3Test, StructMember_ArrayOfMatrix_ReadColumn) {
+    // Reads one column `P.arr[0][1]` from an array-of-matrix struct member held in storage.
+    // The expected output converts `.elements` with vec3<f32>() and emits no helper function.
+    const char* const wgsl_in = R"(
+struct S {
+ arr : array<mat3x3<f32>, 4u>,
+}
+
+@group(0) @binding(0) var<storage> P : S;
+
+fn f() {
+ let x = P.arr[0][1];
+}
+)";
+    const char* const wgsl_want = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ arr : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>,
+}
+
+struct S {
+ arr : array<mat3x3<f32>, 4u>,
+}
+
+@group(0) @binding(0) var<storage> P : S_tint_packed_vec3;
+
+fn f() {
+ let x = vec3<f32>(P.arr[0][1].elements);
+}
+)";
+
+    Transform::DataMap xform_data;
+    auto xform_out = Run<PackedVec3>(wgsl_in, xform_data);
+    EXPECT_EQ(wgsl_want, str(xform_out));
+}
+
+TEST_F(PackedVec3Test, StructMember_ArrayOfMatrix_ReadComponent_MemberAccessor) {
+    // Reads the `.y` component of a column inside an array-of-matrix struct member.
+    // The expected output reads `.elements.y` directly, without a vec3<f32>() conversion.
+    const char* const wgsl_in = R"(
+struct S {
+ arr : array<mat3x3<f32>, 4u>,
+}
+
+@group(0) @binding(0) var<storage> P : S;
+
+fn f() {
+ let x = P.arr[0][1].y;
+}
+)";
+    const char* const wgsl_want = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ arr : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>,
+}
+
+struct S {
+ arr : array<mat3x3<f32>, 4u>,
+}
+
+@group(0) @binding(0) var<storage> P : S_tint_packed_vec3;
+
+fn f() {
+ let x = P.arr[0][1].elements.y;
+}
+)";
+
+    Transform::DataMap xform_data;
+    auto xform_out = Run<PackedVec3>(wgsl_in, xform_data);
+    EXPECT_EQ(wgsl_want, str(xform_out));
+}
+
+TEST_F(PackedVec3Test, StructMember_ArrayOfMatrix_ReadComponent_IndexAccessor) {
+    // Reads a single scalar `P.arr[0][1][2]` from an array-of-matrix struct member.
+    // The expected output indexes `.elements` directly, without a vec3<f32>() conversion.
+    const char* const wgsl_in = R"(
+struct S {
+ arr : array<mat3x3<f32>, 4u>,
+}
+
+@group(0) @binding(0) var<storage> P : S;
+
+fn f() {
+ let x = P.arr[0][1][2];
+}
+)";
+    const char* const wgsl_want = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ arr : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>,
+}
+
+struct S {
+ arr : array<mat3x3<f32>, 4u>,
+}
+
+@group(0) @binding(0) var<storage> P : S_tint_packed_vec3;
+
+fn f() {
+ let x = P.arr[0][1].elements[2];
+}
+)";
+
+    Transform::DataMap xform_data;
+    auto xform_out = Run<PackedVec3>(wgsl_in, xform_data);
+    EXPECT_EQ(wgsl_want, str(xform_out));
+}
+
+TEST_F(PackedVec3Test, StructMember_ArrayOfMatrix_WriteStruct_ValueRHS) {
+    // Assigns a constructed S (holding an array of two mat3x3) to the storage variable.
+    // The expected output chains three generated pack helpers: matrix, array, then struct.
+    const char* const wgsl_in = R"(
+struct S {
+ arr : array<mat3x3<f32>, 2>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S;
+
+fn f() {
+ P = S(array(mat3x3<f32>(), mat3x3(1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5)));
+}
+)";
+
+    const char* const wgsl_want = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ arr : array<array<tint_packed_vec3_f32_array_element, 3u>, 2u>,
+}
+
+fn tint_pack_vec3_in_composite(in : mat3x3<f32>) -> array<tint_packed_vec3_f32_array_element, 3u> {
+ var result : array<tint_packed_vec3_f32_array_element, 3u>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ result[i] = tint_packed_vec3_f32_array_element(__packed_vec3<f32>(in[i]));
+ }
+ return result;
+}
+
+fn tint_pack_vec3_in_composite_1(in : array<mat3x3<f32>, 2u>) -> array<array<tint_packed_vec3_f32_array_element, 3u>, 2u> {
+ var result : array<array<tint_packed_vec3_f32_array_element, 3u>, 2u>;
+ for(var i : u32; (i < 2u); i = (i + 1)) {
+ result[i] = tint_pack_vec3_in_composite(in[i]);
+ }
+ return result;
+}
+
+fn tint_pack_vec3_in_composite_2(in : S) -> S_tint_packed_vec3 {
+ var result : S_tint_packed_vec3;
+ result.arr = tint_pack_vec3_in_composite_1(in.arr);
+ return result;
+}
+
+struct S {
+ arr : array<mat3x3<f32>, 2>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S_tint_packed_vec3;
+
+fn f() {
+ P = tint_pack_vec3_in_composite_2(S(array(mat3x3<f32>(), mat3x3(1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5))));
+}
+)";
+
+    Transform::DataMap xform_data;
+    auto xform_out = Run<PackedVec3>(wgsl_in, xform_data);
+    EXPECT_EQ(wgsl_want, str(xform_out));
+}
+
+TEST_F(PackedVec3Test, StructMember_ArrayOfMatrix_WriteStruct_RefRHS) {
+    // Assigns a uniform struct (with an array-of-matrix member) to a storage struct variable.
+    // Both variables should be retyped to the packed struct, leaving `P = in;` untouched.
+    const char* const wgsl_in = R"(
+struct S {
+ arr : array<mat3x3<f32>, 2>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S;
+@group(0) @binding(1) var<uniform> in : S;
+
+fn f() {
+ P = in;
+}
+)";
+    const char* const wgsl_want = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ arr : array<array<tint_packed_vec3_f32_array_element, 3u>, 2u>,
+}
+
+struct S {
+ arr : array<mat3x3<f32>, 2>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S_tint_packed_vec3;
+
+@group(0) @binding(1) var<uniform> in : S_tint_packed_vec3;
+
+fn f() {
+ P = in;
+}
+)";
+
+    Transform::DataMap xform_data;
+    auto xform_out = Run<PackedVec3>(wgsl_in, xform_data);
+    EXPECT_EQ(wgsl_want, str(xform_out));
+}
+
+TEST_F(PackedVec3Test, StructMember_ArrayOfMatrix_WriteArray_ValueRHS) {
+    // Assigns a constructed array of two mat3x3 to the array member of a storage struct.
+    // The expected output uses two generated pack helpers: one per matrix, one for the array.
+    const char* const wgsl_in = R"(
+struct S {
+ arr : array<mat3x3<f32>, 2>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S;
+
+fn f() {
+ P.arr = array(mat3x3<f32>(), mat3x3(1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5));
+}
+)";
+    const char* const wgsl_want = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ arr : array<array<tint_packed_vec3_f32_array_element, 3u>, 2u>,
+}
+
+fn tint_pack_vec3_in_composite(in : mat3x3<f32>) -> array<tint_packed_vec3_f32_array_element, 3u> {
+ var result : array<tint_packed_vec3_f32_array_element, 3u>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ result[i] = tint_packed_vec3_f32_array_element(__packed_vec3<f32>(in[i]));
+ }
+ return result;
+}
+
+fn tint_pack_vec3_in_composite_1(in : array<mat3x3<f32>, 2u>) -> array<array<tint_packed_vec3_f32_array_element, 3u>, 2u> {
+ var result : array<array<tint_packed_vec3_f32_array_element, 3u>, 2u>;
+ for(var i : u32; (i < 2u); i = (i + 1)) {
+ result[i] = tint_pack_vec3_in_composite(in[i]);
+ }
+ return result;
+}
+
+struct S {
+ arr : array<mat3x3<f32>, 2>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S_tint_packed_vec3;
+
+fn f() {
+ P.arr = tint_pack_vec3_in_composite_1(array(mat3x3<f32>(), mat3x3(1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5)));
+}
+)";
+
+    Transform::DataMap xform_data;
+    auto xform_out = Run<PackedVec3>(wgsl_in, xform_data);
+    EXPECT_EQ(wgsl_want, str(xform_out));
+}
+
+TEST_F(PackedVec3Test, StructMember_ArrayOfMatrix_WriteArray_RefRHS) {
+    // Copies matrix arrays from uniform variables (a struct member and a bare array) to storage.
+    // All variables should be retyped to packed forms, leaving the assignments unchanged.
+    const char* const wgsl_in = R"(
+struct S {
+ arr1 : array<mat3x3<f32>, 2>,
+ arr2 : array<mat3x3<f32>, 2>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S;
+@group(0) @binding(1) var<uniform> in_str : S;
+@group(0) @binding(2) var<uniform> in_arr : array<mat3x3<f32>, 2>;
+
+fn f() {
+ P.arr1 = in_str.arr1;
+ P.arr2 = in_arr;
+}
+)";
+    const char* const wgsl_want = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ arr1 : array<array<tint_packed_vec3_f32_array_element, 3u>, 2u>,
+ @align(16)
+ arr2 : array<array<tint_packed_vec3_f32_array_element, 3u>, 2u>,
+}
+
+struct S {
+ arr1 : array<mat3x3<f32>, 2>,
+ arr2 : array<mat3x3<f32>, 2>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S_tint_packed_vec3;
+
+@group(0) @binding(1) var<uniform> in_str : S_tint_packed_vec3;
+
+@group(0) @binding(2) var<uniform> in_arr : array<array<tint_packed_vec3_f32_array_element, 3u>, 2u>;
+
+fn f() {
+ P.arr1 = in_str.arr1;
+ P.arr2 = in_arr;
+}
+)";
+
+    Transform::DataMap xform_data;
+    auto xform_out = Run<PackedVec3>(wgsl_in, xform_data);
+    EXPECT_EQ(wgsl_want, str(xform_out));
+}
+
+TEST_F(PackedVec3Test, StructMember_ArrayOfMatrix_WriteMatrix_ValueRHS) {
+    // Assigns a constructed mat3x3 to one element of an array-of-matrix struct member.
+    // The expected output wraps the value in the generated tint_pack_vec3_in_composite helper.
+    const char* const wgsl_in = R"(
+struct S {
+ arr : array<mat3x3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S;
+
+fn f() {
+ P.arr[0] = mat3x3(1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5);
+}
+)";
+    const char* const wgsl_want = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ arr : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>,
+}
+
+fn tint_pack_vec3_in_composite(in : mat3x3<f32>) -> array<tint_packed_vec3_f32_array_element, 3u> {
+ var result : array<tint_packed_vec3_f32_array_element, 3u>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ result[i] = tint_packed_vec3_f32_array_element(__packed_vec3<f32>(in[i]));
+ }
+ return result;
+}
+
+struct S {
+ arr : array<mat3x3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S_tint_packed_vec3;
+
+fn f() {
+ P.arr[0] = tint_pack_vec3_in_composite(mat3x3(1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5));
+}
+)";
+
+    Transform::DataMap xform_data;
+    auto xform_out = Run<PackedVec3>(wgsl_in, xform_data);
+    EXPECT_EQ(wgsl_want, str(xform_out));
+}
+
+TEST_F(PackedVec3Test, StructMember_ArrayOfMatrix_WriteMatrix_RefRHS) {
+    // Copies matrices from uniform variables (struct member, array, bare mat3x3) into elements.
+    // All variables should be retyped to packed forms, leaving the assignments unchanged.
+    const char* const wgsl_in = R"(
+struct S {
+ arr : array<mat3x3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S;
+@group(0) @binding(1) var<uniform> in_str : S;
+@group(0) @binding(2) var<uniform> in_arr : array<mat3x3<f32>, 4>;
+@group(0) @binding(3) var<uniform> in_mat : mat3x3<f32>;
+
+fn f() {
+ P.arr[0] = in_str.arr[0];
+ P.arr[1] = in_arr[1];
+ P.arr[2] = in_mat;
+}
+)";
+    const char* const wgsl_want = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ arr : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>,
+}
+
+struct S {
+ arr : array<mat3x3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S_tint_packed_vec3;
+
+@group(0) @binding(1) var<uniform> in_str : S_tint_packed_vec3;
+
+@group(0) @binding(2) var<uniform> in_arr : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>;
+
+@group(0) @binding(3) var<uniform> in_mat : array<tint_packed_vec3_f32_array_element, 3u>;
+
+fn f() {
+ P.arr[0] = in_str.arr[0];
+ P.arr[1] = in_arr[1];
+ P.arr[2] = in_mat;
+}
+)";
+
+    Transform::DataMap xform_data;
+    auto xform_out = Run<PackedVec3>(wgsl_in, xform_data);
+    EXPECT_EQ(wgsl_want, str(xform_out));
+}
+
+TEST_F(PackedVec3Test, StructMember_ArrayOfMatrix_WriteVector_ValueRHS) {
+    // Assigns a constructed vec3 to one column of a matrix inside an array struct member.
+    // The write should target `.elements`, with the value wrapped in __packed_vec3<f32>().
+    const char* const wgsl_in = R"(
+struct S {
+ arr : array<mat3x3<f32>, 4u>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S;
+
+fn f() {
+ P.arr[0][1] = vec3(1.23);
+}
+)";
+    const char* const wgsl_want = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ arr : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>,
+}
+
+struct S {
+ arr : array<mat3x3<f32>, 4u>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S_tint_packed_vec3;
+
+fn f() {
+ P.arr[0][1].elements = __packed_vec3<f32>(vec3(1.22999999999999998224));
+}
+)";
+
+    Transform::DataMap xform_data;
+    auto xform_out = Run<PackedVec3>(wgsl_in, xform_data);
+    EXPECT_EQ(wgsl_want, str(xform_out));
+}
+
+TEST_F(PackedVec3Test, StructMember_ArrayOfMatrix_WriteVector_RefRHS) {
+    // Copies columns from uniform variables (struct, array, matrix, bare vec3) into storage.
+    // Both sides should stay packed: `.elements` to `.elements`, or `in_vec` as-is.
+    const char* const wgsl_in = R"(
+struct S {
+ arr : array<mat3x3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S;
+@group(0) @binding(1) var<uniform> in_str : S;
+@group(0) @binding(2) var<uniform> in_arr : array<mat3x3<f32>, 4>;
+@group(0) @binding(3) var<uniform> in_mat : mat3x3<f32>;
+@group(0) @binding(4) var<uniform> in_vec : vec3<f32>;
+
+fn f() {
+ P.arr[0][0] = in_str.arr[0][1];
+ P.arr[1][1] = in_arr[3][2];
+ P.arr[2][2] = in_mat[1];
+ P.arr[3][0] = in_vec;
+}
+)";
+    const char* const wgsl_want = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ arr : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>,
+}
+
+struct S {
+ arr : array<mat3x3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S_tint_packed_vec3;
+
+@group(0) @binding(1) var<uniform> in_str : S_tint_packed_vec3;
+
+@group(0) @binding(2) var<uniform> in_arr : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>;
+
+@group(0) @binding(3) var<uniform> in_mat : array<tint_packed_vec3_f32_array_element, 3u>;
+
+@group(0) @binding(4) var<uniform> in_vec : __packed_vec3<f32>;
+
+fn f() {
+ P.arr[0][0].elements = in_str.arr[0][1].elements;
+ P.arr[1][1].elements = in_arr[3][2].elements;
+ P.arr[2][2].elements = in_mat[1].elements;
+ P.arr[3][0].elements = in_vec;
+}
+)";
+
+    Transform::DataMap xform_data;
+    auto xform_out = Run<PackedVec3>(wgsl_in, xform_data);
+    EXPECT_EQ(wgsl_want, str(xform_out));
+}
+
+TEST_F(PackedVec3Test, StructMember_ArrayOfMatrix_WriteComponent_MemberAccessor) {
+    // Writes the `.y` component of a column inside an array-of-matrix struct member.
+    // The expected output only inserts `.elements` ahead of the component access.
+    const char* const wgsl_in = R"(
+struct S {
+ arr : array<mat3x3<f32>, 4u>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S;
+
+fn f() {
+ P.arr[0][1].y = 1.23;
+}
+)";
+    const char* const wgsl_want = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ arr : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>,
+}
+
+struct S {
+ arr : array<mat3x3<f32>, 4u>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S_tint_packed_vec3;
+
+fn f() {
+ P.arr[0][1].elements.y = 1.22999999999999998224;
+}
+)";
+
+    Transform::DataMap xform_data;
+    auto xform_out = Run<PackedVec3>(wgsl_in, xform_data);
+    EXPECT_EQ(wgsl_want, str(xform_out));
+}
+
+TEST_F(PackedVec3Test, StructMember_ArrayOfMatrix_WriteComponent_IndexAccessor) {
+    // Writes component [2] of a column inside an array-of-matrix struct member.
+    // The expected output only inserts `.elements` ahead of the component index.
+    const char* const wgsl_in = R"(
+struct S {
+ arr : array<mat3x3<f32>, 4u>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S;
+
+fn f() {
+ P.arr[0][1][2] = 1.23;
+}
+)";
+    const char* const wgsl_want = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ arr : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>,
+}
+
+struct S {
+ arr : array<mat3x3<f32>, 4u>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S_tint_packed_vec3;
+
+fn f() {
+ P.arr[0][1].elements[2] = 1.22999999999999998224;
+}
+)";
+
+    Transform::DataMap xform_data;
+    auto xform_out = Run<PackedVec3>(wgsl_in, xform_data);
+    EXPECT_EQ(wgsl_want, str(xform_out));
+}
+
+TEST_F(PackedVec3Test, StructMember_ExistingMemberAttributes) {
+ auto* src = R"(
+struct S {
+ @align(32) @size(32) v : vec3<f32>,
+ @align(64) @size(64) arr : array<vec3<f32>, 4>,
+ @align(128) @size(128) x : u32,
+}
+
+@group(0) @binding(0) var<uniform> P : S;
+
+fn f() {
+ let x = P.v[0];
+}
+)";
+ // The explicit @align/@size attributes are carried over unchanged onto the packed struct's members.
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(32) @size(32)
+ v : __packed_vec3<f32>,
+ @align(64) @size(64)
+ arr : array<tint_packed_vec3_f32_array_element, 4u>,
+ @align(128) @size(128)
+ x : u32,
+}
+
+struct S {
+ @align(32) @size(32)
+ v : vec3<f32>,
+ @align(64) @size(64)
+ arr : array<vec3<f32>, 4>,
+ @align(128) @size(128)
+ x : u32,
+}
+
+@group(0) @binding(0) var<uniform> P : S_tint_packed_vec3;
+
+fn f() {
+ let x = P.v[0];
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, StructMember_ExistingMemberAttributes_SizeMatchesUnpackedVec3) {
+ // Test that the type we replace a vec3 with is not larger than it should be.
+ auto* src = R"(
+struct S {
+ @size(12) v : vec3<f32>,
+ @size(64) arr : array<vec3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<uniform> P : S;
+
+fn f() {
+ let x = P.v[0];
+}
+)";
+ // The explicit @size values are kept as written, and @align(16) is added to the rewritten members.
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @size(12) @align(16)
+ v : __packed_vec3<f32>,
+ @size(64) @align(16)
+ arr : array<tint_packed_vec3_f32_array_element, 4u>,
+}
+
+struct S {
+ @size(12)
+ v : vec3<f32>,
+ @size(64)
+ arr : array<vec3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<uniform> P : S_tint_packed_vec3;
+
+fn f() {
+ let x = P.v[0];
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, StructMember_ExistingMemberAttributes_AlignTooSmall) {
+ // Test that we add an @align() attribute when the new alignment of the packed vec3 struct would
+ // be too small.
+ auto* src = R"(
+struct S {
+ a : u32,
+ v : vec3<f32>,
+ b : u32,
+ arr : array<vec3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<uniform> P : S;
+
+fn f() {
+ let x = P.v[0];
+}
+)";
+ // Only the vec3-based members `v` and `arr` gain @align(16); the u32 members are left untouched.
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ a : u32,
+ @align(16)
+ v : __packed_vec3<f32>,
+ b : u32,
+ @align(16)
+ arr : array<tint_packed_vec3_f32_array_element, 4u>,
+}
+
+struct S {
+ a : u32,
+ v : vec3<f32>,
+ b : u32,
+ arr : array<vec3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<uniform> P : S_tint_packed_vec3;
+
+fn f() {
+ let x = P.v[0];
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, StructMember_ExistingMemberAttributes_ExplicitOffset) {
+ // Test that we do not add an @align attribute if @offset is present.
+
+ // struct S {
+ // a : u32,
+ // @offset(32) v : vec3<f32>,
+ // b : u32,
+ // @offset(128) arr : array<vec3<f32>, 4>,
+ // }
+ //
+ // @group(0) @binding(0) var<storage> P : S;
+ ProgramBuilder b;
+ b.Structure("S", utils::Vector{
+ b.Member("a", b.ty.u32()),
+ b.Member("v", b.ty.vec3<f32>(), utils::Vector{b.MemberOffset(AInt(32))}),
+ b.Member("b", b.ty.u32()),
+ b.Member("arr", b.ty.array(b.ty.vec3<f32>(), b.Expr(AInt(4))),
+ utils::Vector{b.MemberOffset(AInt(128))}),
+ });
+ b.GlobalVar("P", builtin::AddressSpace::kStorage, b.ty("S"),
+ utils::Vector{b.Group(AInt(0)), b.Binding(AInt(0))});
+ Program src(std::move(b));
+
+ auto* expect =
+ R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ a : u32,
+ @size(28)
+ padding : u32,
+ /* @offset(32) */
+ v : __packed_vec3<f32>,
+ b : u32,
+ @size(80)
+ padding_1 : u32,
+ /* @offset(128) */
+ arr : array<tint_packed_vec3_f32_array_element, 4u>,
+}
+
+struct S {
+ a : u32,
+ @size(16)
+ padding_2 : u32,
+ /* @offset(32) */
+ v : vec3<f32>,
+ b : u32,
+ @size(80)
+ padding_3 : u32,
+ /* @offset(128) */
+ arr : array<vec3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage> P : S_tint_packed_vec3;
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(std::move(src), data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, StructValueConstructor_ViaIndexAccessor) {
+ auto* src = R"(
+struct S {
+ a : vec3<f32>,
+ b : vec3<f32>,
+ arr : array<vec3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage> s : S;
+
+fn f() {
+ let value_arr : array<vec3<f32>, 4> = array<vec3<f32>, 4>();
+ let x = S(value_arr[0], s.arr[0], value_arr);
+}
+)";
+ // Only the constructor argument read from the storage buffer is wrapped in a vec3<f32>() conversion.
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ a : __packed_vec3<f32>,
+ @align(16)
+ b : __packed_vec3<f32>,
+ @align(16)
+ arr : array<tint_packed_vec3_f32_array_element, 4u>,
+}
+
+struct S {
+ a : vec3<f32>,
+ b : vec3<f32>,
+ arr : array<vec3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage> s : S_tint_packed_vec3;
+
+fn f() {
+ let value_arr : array<vec3<f32>, 4> = array<vec3<f32>, 4>();
+ let x = S(value_arr[0], vec3<f32>(s.arr[0].elements), value_arr);
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, WrapperStructLayout_MixedUsage) {
+ // Test the layout of the generated wrapper struct(s) when vec3s are used in both structures and
+ // arrays.
+ auto* src = R"(
+struct S {
+ v : vec3<f32>,
+ a : u32,
+}
+
+@group(0) @binding(0) var<storage, read_write> str : S;
+@group(0) @binding(1) var<storage, read_write> arr : array<vec3<f32>, 4>;
+
+fn main() {
+ str.v = arr[0];
+ arr[1] = str.v;
+}
+)";
+ // The struct member becomes a packed vec3 directly; only the array elements use the wrapper struct.
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ v : __packed_vec3<f32>,
+ a : u32,
+}
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S {
+ v : vec3<f32>,
+ a : u32,
+}
+
+@group(0) @binding(0) var<storage, read_write> str : S_tint_packed_vec3;
+
+@group(0) @binding(1) var<storage, read_write> arr : array<tint_packed_vec3_f32_array_element, 4u>;
+
+fn main() {
+ str.v = arr[0].elements;
+ arr[1].elements = str.v;
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+ // Look up the two module-scope variables of the transformed program to inspect their types.
+ auto& vars = got.program.AST().GlobalVariables();
+ ASSERT_EQ(vars.Length(), 2u);
+
+ {
+ // Check the layout of the struct type of "str".
+ // The first member should have an alignment of 16 bytes, a size of 12 bytes, and the second
+ // member should have an offset of 12 bytes.
+ auto* sem_str = got.program.Sem().Get(vars[0]);
+ auto* str_ty = sem_str->Type()->UnwrapRef()->As<type::Struct>();
+ ASSERT_NE(str_ty, nullptr);
+ ASSERT_EQ(str_ty->Members().Length(), 2u);
+ EXPECT_EQ(str_ty->Members()[0]->Align(), 16u);
+ EXPECT_EQ(str_ty->Members()[0]->Size(), 12u);
+ EXPECT_EQ(str_ty->Members()[1]->Offset(), 12u);
+ }
+
+ {
+ // Check the layout of the array type of "arr".
+ // The element stride should be 16 bytes.
+ auto* sem_arr = got.program.Sem().Get(vars[1]);
+ auto* arr_ty = sem_arr->Type()->UnwrapRef()->As<type::Array>();
+ ASSERT_NE(arr_ty, nullptr);
+ EXPECT_EQ(arr_ty->Stride(), 16u);
+ }
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, PackUnpackStructWithNonVec3Members) {
+ auto* src = R"(
+struct S {
+ v : vec3<f32>,
+ arr : array<vec3<f32>, 4>,
+ a : u32,
+ b : vec4<f32>,
+ c : array<vec4<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S;
+
+fn f() {
+ let x = P;
+ P = x;
+}
+)";
+ // Whole-struct loads/stores go through generated helpers; non-vec3 members are copied as-is.
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ v : __packed_vec3<f32>,
+ @align(16)
+ arr : array<tint_packed_vec3_f32_array_element, 4u>,
+ a : u32,
+ b : vec4<f32>,
+ c : array<vec4<f32>, 4>,
+}
+
+fn tint_unpack_vec3_in_composite(in : array<tint_packed_vec3_f32_array_element, 4u>) -> array<vec3<f32>, 4u> {
+ var result : array<vec3<f32>, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = vec3<f32>(in[i].elements);
+ }
+ return result;
+}
+
+fn tint_unpack_vec3_in_composite_1(in : S_tint_packed_vec3) -> S {
+ var result : S;
+ result.v = vec3<f32>(in.v);
+ result.arr = tint_unpack_vec3_in_composite(in.arr);
+ result.a = in.a;
+ result.b = in.b;
+ result.c = in.c;
+ return result;
+}
+
+fn tint_pack_vec3_in_composite(in : array<vec3<f32>, 4u>) -> array<tint_packed_vec3_f32_array_element, 4u> {
+ var result : array<tint_packed_vec3_f32_array_element, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = tint_packed_vec3_f32_array_element(__packed_vec3<f32>(in[i]));
+ }
+ return result;
+}
+
+fn tint_pack_vec3_in_composite_1(in : S) -> S_tint_packed_vec3 {
+ var result : S_tint_packed_vec3;
+ result.v = __packed_vec3<f32>(in.v);
+ result.arr = tint_pack_vec3_in_composite(in.arr);
+ result.a = in.a;
+ result.b = in.b;
+ result.c = in.c;
+ return result;
+}
+
+struct S {
+ v : vec3<f32>,
+ arr : array<vec3<f32>, 4>,
+ a : u32,
+ b : vec4<f32>,
+ c : array<vec4<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S_tint_packed_vec3;
+
+fn f() {
+ let x = tint_unpack_vec3_in_composite_1(P);
+ P = tint_pack_vec3_in_composite_1(x);
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, Struct_ShaderIO) {
+ // Test that we do not modify structures that are used for shader IO.
+ auto* src = R"(
+struct S1 {
+ @location(0) v : vec3<f32>,
+}
+
+struct S2 {
+ @location(0) v : vec3<f32>,
+ @builtin(position) pos : vec4<f32>,
+}
+
+@vertex
+fn main(s1 : S1) -> S2 {
+ let v : vec3<f32> = s1.v;
+ var s2 : S2;
+ s2.v = v;
+ return s2;
+}
+)";
+ // Same program as the input apart from attribute line breaks; no `enable` directive is added.
+ auto* expect = R"(
+struct S1 {
+ @location(0)
+ v : vec3<f32>,
+}
+
+struct S2 {
+ @location(0)
+ v : vec3<f32>,
+ @builtin(position)
+ pos : vec4<f32>,
+}
+
+@vertex
+fn main(s1 : S1) -> S2 {
+ let v : vec3<f32> = s1.v;
+ var s2 : S2;
+ s2.v = v;
+ return s2;
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, ModfReturnStruct) {
+ // Test that we do not try to modify accessors on the anonymous structure returned by modf.
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read_write> output : vec3<f32>;
+
+const values = array(modf(vec3(1.0, 2.0, 3.0)).fract);
+
+@compute @workgroup_size(1)
+fn main() {
+ output = values[0];
+}
+)";
+ // Only the store to the storage variable gains a packing conversion; `.fract` is left untouched.
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+@group(0) @binding(0) var<storage, read_write> output : __packed_vec3<f32>;
+
+const values = array(modf(vec3(1.0, 2.0, 3.0)).fract);
+
+@compute @workgroup_size(1)
+fn main() {
+ output = __packed_vec3<f32>(values[0]);
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, ModfReturnStruct_PointerToMember) {
+ // Test that we can pass a pointer to the vec3 member of the modf return struct to a function
+ // parameter to which we also pass a pointer to a vec3 member on a host-shareable struct.
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct S {
+ v : vec3<f32>
+}
+
+@group(0) @binding(0) var<storage, read_write> output : S;
+
+fn foo(p : ptr<function, vec3<f32>>) {
+ (*p) = vec3(1, 2, 3);
+}
+
+@compute @workgroup_size(1)
+fn main() {
+ var f : S;
+ var modf_ret = modf(vec3(1.0, 2.0, 3.0));
+ foo(&f.v);
+ foo(&modf_ret.fract);
+ output.v = modf_ret.fract;
+}
+)";
+ // `foo` keeps its vec3<f32> pointer parameter; only the write to `output.v` gains a packing cast.
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+enable chromium_experimental_full_ptr_parameters;
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ v : __packed_vec3<f32>,
+}
+
+struct S {
+ v : vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> output : S_tint_packed_vec3;
+
+fn foo(p : ptr<function, vec3<f32>>) {
+ *(p) = vec3(1, 2, 3);
+}
+
+@compute @workgroup_size(1)
+fn main() {
+ var f : S;
+ var modf_ret = modf(vec3(1.0, 2.0, 3.0));
+ foo(&(f.v));
+ foo(&(modf_ret.fract));
+ output.v = __packed_vec3<f32>(modf_ret.fract);
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, MultipleStructMembers) {
+ auto* src = R"(
+struct S {
+ v2_a : vec2<f32>,
+ v3_a : vec3<f32>,
+ v4_a : vec4<f32>,
+ v2_b : vec2<f32>,
+ v3_b : vec3<f32>,
+ v4_b : vec4<f32>,
+ v2_arr : array<vec2<f32>, 4>,
+ v3_arr : array<vec3<f32>, 4>,
+ v4_arr : array<vec4<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage> P : S;
+
+fn f() {
+ let v2_a = P.v2_a;
+ let v3_a = P.v3_a;
+ let v4_a = P.v4_a;
+ let v2_b = P.v2_b;
+ let v3_b = P.v3_b;
+ let v4_b = P.v4_b;
+ let v2_arr : array<vec2<f32>, 4> = P.v2_arr;
+ let v3_arr : array<vec3<f32>, 4> = P.v3_arr;
+ let v4_arr : array<vec4<f32>, 4> = P.v4_arr;
+}
+)";
+ // Only the vec3 members and the vec3 array are rewritten; vec2/vec4 members and loads are unchanged.
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ v2_a : vec2<f32>,
+ @align(16)
+ v3_a : __packed_vec3<f32>,
+ v4_a : vec4<f32>,
+ v2_b : vec2<f32>,
+ @align(16)
+ v3_b : __packed_vec3<f32>,
+ v4_b : vec4<f32>,
+ v2_arr : array<vec2<f32>, 4>,
+ @align(16)
+ v3_arr : array<tint_packed_vec3_f32_array_element, 4u>,
+ v4_arr : array<vec4<f32>, 4>,
+}
+
+fn tint_unpack_vec3_in_composite(in : array<tint_packed_vec3_f32_array_element, 4u>) -> array<vec3<f32>, 4u> {
+ var result : array<vec3<f32>, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = vec3<f32>(in[i].elements);
+ }
+ return result;
+}
+
+struct S {
+ v2_a : vec2<f32>,
+ v3_a : vec3<f32>,
+ v4_a : vec4<f32>,
+ v2_b : vec2<f32>,
+ v3_b : vec3<f32>,
+ v4_b : vec4<f32>,
+ v2_arr : array<vec2<f32>, 4>,
+ v3_arr : array<vec3<f32>, 4>,
+ v4_arr : array<vec4<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage> P : S_tint_packed_vec3;
+
+fn f() {
+ let v2_a = P.v2_a;
+ let v3_a = vec3<f32>(P.v3_a);
+ let v4_a = P.v4_a;
+ let v2_b = P.v2_b;
+ let v3_b = vec3<f32>(P.v3_b);
+ let v4_b = P.v4_b;
+ let v2_arr : array<vec2<f32>, 4> = P.v2_arr;
+ let v3_arr : array<vec3<f32>, 4> = tint_unpack_vec3_in_composite(P.v3_arr);
+ let v4_arr : array<vec4<f32>, 4> = P.v4_arr;
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, Vec3Pointers) {
+ auto* src = R"(
+struct S {
+ v : vec3<f32>,
+ m : mat3x3<f32>,
+ arr_v : array<vec3<f32>, 4>,
+ arr_m : array<mat3x3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> v : vec3<f32>;
+@group(0) @binding(1) var<storage, read_write> arr_v : array<vec3<f32>, 4>;
+@group(0) @binding(2) var<storage, read_write> m : mat3x3<f32>;
+@group(0) @binding(3) var<storage, read_write> arr_m : array<mat3x3<f32>, 4>;
+@group(0) @binding(4) var<storage, read_write> str : S;
+
+fn f() {
+ let p_v = &v;
+ let v = *p_v;
+ *p_v = v;
+
+ let p_arr_v = &arr_v[0];
+ let arr_v = *p_arr_v;
+ *p_arr_v = arr_v;
+
+ let p_m = &m[0];
+ let m = *p_m;
+ *p_m = m;
+
+ let p_arr_m = &arr_m[0][1];
+ let arr_m = *p_arr_m;
+ *p_arr_m = arr_m;
+
+ let p_str_v = &str.v;
+ let str_v = *p_str_v;
+ *p_str_v = str_v;
+
+ let p_str_arr_v = &str.arr_v[0];
+ let str_arr_v = *p_str_arr_v;
+ *p_str_arr_v = str_arr_v;
+
+ let p_str_m = &str.m[0];
+ let str_m = *p_str_m;
+ *p_str_m = str_m;
+
+ let p_str_arr_m = &str.arr_m[0][1];
+ let str_arr_m = *p_str_arr_m;
+ *p_str_arr_m = str_arr_m;
+}
+)";
+ // The pointers end up referring to packed vectors, so each load/store through them is converted.
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ v : __packed_vec3<f32>,
+ @align(16)
+ m : array<tint_packed_vec3_f32_array_element, 3u>,
+ @align(16)
+ arr_v : array<tint_packed_vec3_f32_array_element, 4u>,
+ @align(16)
+ arr_m : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>,
+}
+
+struct S {
+ v : vec3<f32>,
+ m : mat3x3<f32>,
+ arr_v : array<vec3<f32>, 4>,
+ arr_m : array<mat3x3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> v : __packed_vec3<f32>;
+
+@group(0) @binding(1) var<storage, read_write> arr_v : array<tint_packed_vec3_f32_array_element, 4u>;
+
+@group(0) @binding(2) var<storage, read_write> m : array<tint_packed_vec3_f32_array_element, 3u>;
+
+@group(0) @binding(3) var<storage, read_write> arr_m : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>;
+
+@group(0) @binding(4) var<storage, read_write> str : S_tint_packed_vec3;
+
+fn f() {
+ let p_v = &(v);
+ let v = vec3<f32>(*(p_v));
+ *(p_v) = __packed_vec3<f32>(v);
+ let p_arr_v = &(arr_v[0].elements);
+ let arr_v = vec3<f32>(*(p_arr_v));
+ *(p_arr_v) = __packed_vec3<f32>(arr_v);
+ let p_m = &(m[0].elements);
+ let m = vec3<f32>(*(p_m));
+ *(p_m) = __packed_vec3<f32>(m);
+ let p_arr_m = &(arr_m[0][1].elements);
+ let arr_m = vec3<f32>(*(p_arr_m));
+ *(p_arr_m) = __packed_vec3<f32>(arr_m);
+ let p_str_v = &(str.v);
+ let str_v = vec3<f32>(*(p_str_v));
+ *(p_str_v) = __packed_vec3<f32>(str_v);
+ let p_str_arr_v = &(str.arr_v[0].elements);
+ let str_arr_v = vec3<f32>(*(p_str_arr_v));
+ *(p_str_arr_v) = __packed_vec3<f32>(str_arr_v);
+ let p_str_m = &(str.m[0].elements);
+ let str_m = vec3<f32>(*(p_str_m));
+ *(p_str_m) = __packed_vec3<f32>(str_m);
+ let p_str_arr_m = &(str.arr_m[0][1].elements);
+ let str_arr_m = vec3<f32>(*(p_str_arr_m));
+ *(p_str_arr_m) = __packed_vec3<f32>(str_arr_m);
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, MatrixPointers) {
+ auto* src = R"(
+struct S {
+ m : mat3x3<f32>,
+ arr_m : array<mat3x3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> m : mat3x3<f32>;
+@group(0) @binding(1) var<storage, read_write> arr_m : array<mat3x3<f32>, 4>;
+@group(0) @binding(2) var<storage, read_write> str : S;
+
+fn f() {
+ let p_m = &m;
+ let m = *p_m;
+ *p_m = m;
+
+ let p_arr_m = &arr_m[0];
+ let arr_m = *p_arr_m;
+ *p_arr_m = arr_m;
+
+ let p_str_m = &str.m;
+ let str_m = *p_str_m;
+ *p_str_m = str_m;
+
+ let p_str_arr_m = &str.arr_m[0];
+ let str_arr_m = *p_str_arr_m;
+ *p_str_arr_m = str_arr_m;
+}
+)";
+ // Loads and stores of whole matrices through pointers call the generated unpack/pack helpers.
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ m : array<tint_packed_vec3_f32_array_element, 3u>,
+ @align(16)
+ arr_m : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>,
+}
+
+fn tint_unpack_vec3_in_composite(in : array<tint_packed_vec3_f32_array_element, 3u>) -> mat3x3<f32> {
+ var result : mat3x3<f32>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ result[i] = vec3<f32>(in[i].elements);
+ }
+ return result;
+}
+
+fn tint_pack_vec3_in_composite(in : mat3x3<f32>) -> array<tint_packed_vec3_f32_array_element, 3u> {
+ var result : array<tint_packed_vec3_f32_array_element, 3u>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ result[i] = tint_packed_vec3_f32_array_element(__packed_vec3<f32>(in[i]));
+ }
+ return result;
+}
+
+struct S {
+ m : mat3x3<f32>,
+ arr_m : array<mat3x3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> m : array<tint_packed_vec3_f32_array_element, 3u>;
+
+@group(0) @binding(1) var<storage, read_write> arr_m : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>;
+
+@group(0) @binding(2) var<storage, read_write> str : S_tint_packed_vec3;
+
+fn f() {
+ let p_m = &(m);
+ let m = tint_unpack_vec3_in_composite(*(p_m));
+ *(p_m) = tint_pack_vec3_in_composite(m);
+ let p_arr_m = &(arr_m[0]);
+ let arr_m = tint_unpack_vec3_in_composite(*(p_arr_m));
+ *(p_arr_m) = tint_pack_vec3_in_composite(arr_m);
+ let p_str_m = &(str.m);
+ let str_m = tint_unpack_vec3_in_composite(*(p_str_m));
+ *(p_str_m) = tint_pack_vec3_in_composite(str_m);
+ let p_str_arr_m = &(str.arr_m[0]);
+ let str_arr_m = tint_unpack_vec3_in_composite(*(p_str_arr_m));
+ *(p_str_arr_m) = tint_pack_vec3_in_composite(str_arr_m);
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, ArrayOfVec3Pointers) {
+ auto* src = R"(
+struct S {
+ arr_v : array<vec3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> arr_v : array<vec3<f32>, 4>;
+@group(0) @binding(1) var<storage, read_write> str : S;
+
+fn f() {
+ let p_arr_v = &arr_v;
+ let arr_v = *p_arr_v;
+ *p_arr_v = arr_v;
+
+ let p_str_arr_v = &str.arr_v;
+ let str_arr_v = *p_str_arr_v;
+ *p_str_arr_v = str_arr_v;
+}
+)";
+ // Loads and stores of whole vec3 arrays through pointers call the generated unpack/pack helpers.
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ arr_v : array<tint_packed_vec3_f32_array_element, 4u>,
+}
+
+fn tint_unpack_vec3_in_composite(in : array<tint_packed_vec3_f32_array_element, 4u>) -> array<vec3<f32>, 4u> {
+ var result : array<vec3<f32>, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = vec3<f32>(in[i].elements);
+ }
+ return result;
+}
+
+fn tint_pack_vec3_in_composite(in : array<vec3<f32>, 4u>) -> array<tint_packed_vec3_f32_array_element, 4u> {
+ var result : array<tint_packed_vec3_f32_array_element, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = tint_packed_vec3_f32_array_element(__packed_vec3<f32>(in[i]));
+ }
+ return result;
+}
+
+struct S {
+ arr_v : array<vec3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> arr_v : array<tint_packed_vec3_f32_array_element, 4u>;
+
+@group(0) @binding(1) var<storage, read_write> str : S_tint_packed_vec3;
+
+fn f() {
+ let p_arr_v = &(arr_v);
+ let arr_v = tint_unpack_vec3_in_composite(*(p_arr_v));
+ *(p_arr_v) = tint_pack_vec3_in_composite(arr_v);
+ let p_str_arr_v = &(str.arr_v);
+ let str_arr_v = tint_unpack_vec3_in_composite(*(p_str_arr_v));
+ *(p_str_arr_v) = tint_pack_vec3_in_composite(str_arr_v);
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, ArrayOfMatrixPointers) {
+ auto* src = R"(
+struct S {
+ arr_m : array<mat3x3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> arr_m : array<mat3x3<f32>, 4>;
+@group(0) @binding(1) var<storage, read_write> str : S;
+
+fn f() {
+ let p_arr_m = &arr_m;
+ let arr_m = *p_arr_m;
+ *p_arr_m = arr_m;
+
+ let p_str_arr_m = &str.arr_m;
+ let str_arr_m = *p_str_arr_m;
+ *p_str_arr_m = str_arr_m;
+}
+)";
+ // The array-level helpers loop over the elements and call the per-matrix unpack/pack helpers.
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ arr_m : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>,
+}
+
+fn tint_unpack_vec3_in_composite(in : array<tint_packed_vec3_f32_array_element, 3u>) -> mat3x3<f32> {
+ var result : mat3x3<f32>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ result[i] = vec3<f32>(in[i].elements);
+ }
+ return result;
+}
+
+fn tint_unpack_vec3_in_composite_1(in : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>) -> array<mat3x3<f32>, 4u> {
+ var result : array<mat3x3<f32>, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = tint_unpack_vec3_in_composite(in[i]);
+ }
+ return result;
+}
+
+fn tint_pack_vec3_in_composite(in : mat3x3<f32>) -> array<tint_packed_vec3_f32_array_element, 3u> {
+ var result : array<tint_packed_vec3_f32_array_element, 3u>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ result[i] = tint_packed_vec3_f32_array_element(__packed_vec3<f32>(in[i]));
+ }
+ return result;
+}
+
+fn tint_pack_vec3_in_composite_1(in : array<mat3x3<f32>, 4u>) -> array<array<tint_packed_vec3_f32_array_element, 3u>, 4u> {
+ var result : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = tint_pack_vec3_in_composite(in[i]);
+ }
+ return result;
+}
+
+struct S {
+ arr_m : array<mat3x3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> arr_m : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>;
+
+@group(0) @binding(1) var<storage, read_write> str : S_tint_packed_vec3;
+
+fn f() {
+ let p_arr_m = &(arr_m);
+ let arr_m = tint_unpack_vec3_in_composite_1(*(p_arr_m));
+ *(p_arr_m) = tint_pack_vec3_in_composite_1(arr_m);
+ let p_str_arr_m = &(str.arr_m);
+ let str_arr_m = tint_unpack_vec3_in_composite_1(*(p_str_arr_m));
+ *(p_str_arr_m) = tint_pack_vec3_in_composite_1(str_arr_m);
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, StructPointers) {
+ auto* src = R"(
+struct S {
+ v : vec3<f32>,
+ m : mat3x3<f32>,
+ arr_v : array<vec3<f32>, 4>,
+ arr_m : array<mat3x3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> str : S;
+
+fn f() {
+ let p_str = &str;
+ let str = *p_str;
+ *p_str = str;
+}
+)";
+ // The struct-level helpers convert member by member, reusing the matrix and array helpers.
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ v : __packed_vec3<f32>,
+ @align(16)
+ m : array<tint_packed_vec3_f32_array_element, 3u>,
+ @align(16)
+ arr_v : array<tint_packed_vec3_f32_array_element, 4u>,
+ @align(16)
+ arr_m : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>,
+}
+
+fn tint_unpack_vec3_in_composite(in : array<tint_packed_vec3_f32_array_element, 3u>) -> mat3x3<f32> {
+ var result : mat3x3<f32>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ result[i] = vec3<f32>(in[i].elements);
+ }
+ return result;
+}
+
+fn tint_unpack_vec3_in_composite_1(in : array<tint_packed_vec3_f32_array_element, 4u>) -> array<vec3<f32>, 4u> {
+ var result : array<vec3<f32>, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = vec3<f32>(in[i].elements);
+ }
+ return result;
+}
+
+fn tint_unpack_vec3_in_composite_2(in : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>) -> array<mat3x3<f32>, 4u> {
+ var result : array<mat3x3<f32>, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = tint_unpack_vec3_in_composite(in[i]);
+ }
+ return result;
+}
+
+fn tint_unpack_vec3_in_composite_3(in : S_tint_packed_vec3) -> S {
+ var result : S;
+ result.v = vec3<f32>(in.v);
+ result.m = tint_unpack_vec3_in_composite(in.m);
+ result.arr_v = tint_unpack_vec3_in_composite_1(in.arr_v);
+ result.arr_m = tint_unpack_vec3_in_composite_2(in.arr_m);
+ return result;
+}
+
+fn tint_pack_vec3_in_composite(in : mat3x3<f32>) -> array<tint_packed_vec3_f32_array_element, 3u> {
+ var result : array<tint_packed_vec3_f32_array_element, 3u>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ result[i] = tint_packed_vec3_f32_array_element(__packed_vec3<f32>(in[i]));
+ }
+ return result;
+}
+
+fn tint_pack_vec3_in_composite_1(in : array<vec3<f32>, 4u>) -> array<tint_packed_vec3_f32_array_element, 4u> {
+ var result : array<tint_packed_vec3_f32_array_element, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = tint_packed_vec3_f32_array_element(__packed_vec3<f32>(in[i]));
+ }
+ return result;
+}
+
+fn tint_pack_vec3_in_composite_2(in : array<mat3x3<f32>, 4u>) -> array<array<tint_packed_vec3_f32_array_element, 3u>, 4u> {
+ var result : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = tint_pack_vec3_in_composite(in[i]);
+ }
+ return result;
+}
+
+fn tint_pack_vec3_in_composite_3(in : S) -> S_tint_packed_vec3 {
+ var result : S_tint_packed_vec3;
+ result.v = __packed_vec3<f32>(in.v);
+ result.m = tint_pack_vec3_in_composite(in.m);
+ result.arr_v = tint_pack_vec3_in_composite_1(in.arr_v);
+ result.arr_m = tint_pack_vec3_in_composite_2(in.arr_m);
+ return result;
+}
+
+struct S {
+ v : vec3<f32>,
+ m : mat3x3<f32>,
+ arr_v : array<vec3<f32>, 4>,
+ arr_m : array<mat3x3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> str : S_tint_packed_vec3;
+
+fn f() {
+ let p_str = &(str);
+ let str = tint_unpack_vec3_in_composite_3(*(p_str));
+ *(p_str) = tint_pack_vec3_in_composite_3(str);
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, VectorPointerParameters) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct S {
+ v : vec3<f32>,
+ m : mat3x3<f32>,
+ arr_v : array<vec3<f32>, 4>,
+ arr_m : array<mat3x3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> v : vec3<f32>;
+@group(0) @binding(1) var<storage, read_write> arr_v : array<vec3<f32>, 4>;
+@group(0) @binding(2) var<storage, read_write> m : mat3x3<f32>;
+@group(0) @binding(3) var<storage, read_write> arr_m : array<mat3x3<f32>, 4>;
+@group(0) @binding(4) var<storage, read_write> str : S;
+
+fn load(p : ptr<storage, vec3<f32>, read_write>) -> vec3<f32> {
+ return *p;
+}
+
+fn store(p : ptr<storage, vec3<f32>, read_write>) {
+ *p = vec3(1, 2, 3);
+}
+
+fn f() {
+ load(&v);
+ store(&v);
+ load(&arr_v[0]);
+ store(&arr_v[0]);
+ load(&m[0]);
+ store(&m[0]);
+ load(&arr_m[0][1]);
+ store(&arr_m[0][1]);
+ load(&str.v);
+ store(&str.v);
+ load(&str.arr_v[0]);
+ store(&str.arr_v[0]);
+ load(&str.m[0]);
+ store(&str.m[0]);
+ load(&str.arr_m[0][1]);
+ store(&str.arr_m[0][1]);
+}
+)";
+ // Pointer parameter types change to __packed_vec3, and array/matrix call sites pass `.elements`.
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+enable chromium_experimental_full_ptr_parameters;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ v : __packed_vec3<f32>,
+ @align(16)
+ m : array<tint_packed_vec3_f32_array_element, 3u>,
+ @align(16)
+ arr_v : array<tint_packed_vec3_f32_array_element, 4u>,
+ @align(16)
+ arr_m : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>,
+}
+
+struct S {
+ v : vec3<f32>,
+ m : mat3x3<f32>,
+ arr_v : array<vec3<f32>, 4>,
+ arr_m : array<mat3x3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> v : __packed_vec3<f32>;
+
+@group(0) @binding(1) var<storage, read_write> arr_v : array<tint_packed_vec3_f32_array_element, 4u>;
+
+@group(0) @binding(2) var<storage, read_write> m : array<tint_packed_vec3_f32_array_element, 3u>;
+
+@group(0) @binding(3) var<storage, read_write> arr_m : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>;
+
+@group(0) @binding(4) var<storage, read_write> str : S_tint_packed_vec3;
+
+fn load(p : ptr<storage, __packed_vec3<f32>, read_write>) -> vec3<f32> {
+ return vec3<f32>(*(p));
+}
+
+fn store(p : ptr<storage, __packed_vec3<f32>, read_write>) {
+ *(p) = __packed_vec3<f32>(vec3(1, 2, 3));
+}
+
+fn f() {
+ load(&(v));
+ store(&(v));
+ load(&(arr_v[0].elements));
+ store(&(arr_v[0].elements));
+ load(&(m[0].elements));
+ store(&(m[0].elements));
+ load(&(arr_m[0][1].elements));
+ store(&(arr_m[0][1].elements));
+ load(&(str.v));
+ store(&(str.v));
+ load(&(str.arr_v[0].elements));
+ store(&(str.arr_v[0].elements));
+ load(&(str.m[0].elements));
+ store(&(str.m[0].elements));
+ load(&(str.arr_m[0][1].elements));
+ store(&(str.arr_m[0][1].elements));
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, MatrixPointerParameters) {  // Pointers to storage mat3x3<f32> values passed as function arguments.
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct S {
+ m : mat3x3<f32>,
+ arr_m : array<mat3x3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> m : mat3x3<f32>;
+@group(0) @binding(1) var<storage, read_write> arr_m : array<mat3x3<f32>, 4>;
+@group(0) @binding(2) var<storage, read_write> str : S;
+
+fn load(p : ptr<storage, mat3x3<f32>, read_write>) -> mat3x3<f32> {
+ return *p;
+}
+
+fn store(p : ptr<storage, mat3x3<f32>, read_write>) {
+ *p = mat3x3(1, 2, 3, 4, 5, 6, 7, 8, 9);
+}
+
+fn f() {
+ load(&m);
+ store(&m);
+ load(&arr_m[0]);
+ store(&arr_m[0]);
+ load(&str.m);
+ store(&str.m);
+ load(&str.arr_m[0]);
+ store(&str.arr_m[0]);
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+enable chromium_experimental_full_ptr_parameters;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ m : array<tint_packed_vec3_f32_array_element, 3u>,
+ @align(16)
+ arr_m : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>,
+}
+
+fn tint_unpack_vec3_in_composite(in : array<tint_packed_vec3_f32_array_element, 3u>) -> mat3x3<f32> {
+ var result : mat3x3<f32>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ result[i] = vec3<f32>(in[i].elements);
+ }
+ return result;
+}
+
+fn tint_pack_vec3_in_composite(in : mat3x3<f32>) -> array<tint_packed_vec3_f32_array_element, 3u> {
+ var result : array<tint_packed_vec3_f32_array_element, 3u>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ result[i] = tint_packed_vec3_f32_array_element(__packed_vec3<f32>(in[i]));
+ }
+ return result;
+}
+
+struct S {
+ m : mat3x3<f32>,
+ arr_m : array<mat3x3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> m : array<tint_packed_vec3_f32_array_element, 3u>;
+
+@group(0) @binding(1) var<storage, read_write> arr_m : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>;
+
+@group(0) @binding(2) var<storage, read_write> str : S_tint_packed_vec3;
+
+fn load(p : ptr<storage, array<tint_packed_vec3_f32_array_element, 3u>, read_write>) -> mat3x3<f32> {
+ return tint_unpack_vec3_in_composite(*(p));
+}
+
+fn store(p : ptr<storage, array<tint_packed_vec3_f32_array_element, 3u>, read_write>) {
+ *(p) = tint_pack_vec3_in_composite(mat3x3(1, 2, 3, 4, 5, 6, 7, 8, 9));
+}
+
+fn f() {
+ load(&(m));
+ store(&(m));
+ load(&(arr_m[0]));
+ store(&(arr_m[0]));
+ load(&(str.m));
+ store(&(str.m));
+ load(&(str.arr_m[0]));
+ store(&(str.arr_m[0]));
+}
+)";  // Expected: pointer parameters retyped to the packed array form; *p routed through the pack/unpack helpers.
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, ArrayOfVectorPointerParameters) {  // Pointers to storage array<vec3<f32>, 4> values passed as function arguments.
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct S {
+ arr_v : array<vec3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> arr_v : array<vec3<f32>, 4>;
+@group(0) @binding(1) var<storage, read_write> str : S;
+
+fn load(p : ptr<storage, array<vec3<f32>, 4>, read_write>) -> array<vec3<f32>, 4> {
+ return *p;
+}
+
+fn store(p : ptr<storage, array<vec3<f32>, 4>, read_write>) {
+ *p = array(vec3(1.0), vec3(2.0), vec3(3.0), vec3(4.0));
+}
+
+fn f() {
+ load(&arr_v);
+ store(&arr_v);
+ load(&str.arr_v);
+ store(&str.arr_v);
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+enable chromium_experimental_full_ptr_parameters;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ arr_v : array<tint_packed_vec3_f32_array_element, 4u>,
+}
+
+fn tint_unpack_vec3_in_composite(in : array<tint_packed_vec3_f32_array_element, 4u>) -> array<vec3<f32>, 4u> {
+ var result : array<vec3<f32>, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = vec3<f32>(in[i].elements);
+ }
+ return result;
+}
+
+fn tint_pack_vec3_in_composite(in : array<vec3<f32>, 4u>) -> array<tint_packed_vec3_f32_array_element, 4u> {
+ var result : array<tint_packed_vec3_f32_array_element, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = tint_packed_vec3_f32_array_element(__packed_vec3<f32>(in[i]));
+ }
+ return result;
+}
+
+struct S {
+ arr_v : array<vec3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> arr_v : array<tint_packed_vec3_f32_array_element, 4u>;
+
+@group(0) @binding(1) var<storage, read_write> str : S_tint_packed_vec3;
+
+fn load(p : ptr<storage, array<tint_packed_vec3_f32_array_element, 4u>, read_write>) -> array<vec3<f32>, 4> {
+ return tint_unpack_vec3_in_composite(*(p));
+}
+
+fn store(p : ptr<storage, array<tint_packed_vec3_f32_array_element, 4u>, read_write>) {
+ *(p) = tint_pack_vec3_in_composite(array(vec3(1.0), vec3(2.0), vec3(3.0), vec3(4.0)));
+}
+
+fn f() {
+ load(&(arr_v));
+ store(&(arr_v));
+ load(&(str.arr_v));
+ store(&(str.arr_v));
+}
+)";  // Expected: whole-array loads/stores through the pointer go via the 4-element unpack/pack helpers.
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, ArrayOfMatrixPointerParameters) {  // Pointers to storage array<mat3x3<f32>, 4> values passed as function arguments.
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct S {
+ arr_m : array<mat3x3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> arr_m : array<mat3x3<f32>, 4>;
+@group(0) @binding(1) var<storage, read_write> str : S;
+
+fn load(p : ptr<storage, array<mat3x3<f32>, 4>, read_write>) -> array<mat3x3<f32>, 4> {
+ return *p;
+}
+
+fn store(p : ptr<storage, array<mat3x3<f32>, 4>, read_write>) {
+ *p = array(mat3x3<f32>(), mat3x3<f32>(), mat3x3<f32>(), mat3x3<f32>());
+}
+
+fn f() {
+ load(&arr_m);
+ store(&arr_m);
+ load(&str.arr_m);
+ store(&str.arr_m);
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+enable chromium_experimental_full_ptr_parameters;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ arr_m : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>,
+}
+
+fn tint_unpack_vec3_in_composite(in : array<tint_packed_vec3_f32_array_element, 3u>) -> mat3x3<f32> {
+ var result : mat3x3<f32>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ result[i] = vec3<f32>(in[i].elements);
+ }
+ return result;
+}
+
+fn tint_unpack_vec3_in_composite_1(in : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>) -> array<mat3x3<f32>, 4u> {
+ var result : array<mat3x3<f32>, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = tint_unpack_vec3_in_composite(in[i]);
+ }
+ return result;
+}
+
+fn tint_pack_vec3_in_composite(in : mat3x3<f32>) -> array<tint_packed_vec3_f32_array_element, 3u> {
+ var result : array<tint_packed_vec3_f32_array_element, 3u>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ result[i] = tint_packed_vec3_f32_array_element(__packed_vec3<f32>(in[i]));
+ }
+ return result;
+}
+
+fn tint_pack_vec3_in_composite_1(in : array<mat3x3<f32>, 4u>) -> array<array<tint_packed_vec3_f32_array_element, 3u>, 4u> {
+ var result : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = tint_pack_vec3_in_composite(in[i]);
+ }
+ return result;
+}
+
+struct S {
+ arr_m : array<mat3x3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> arr_m : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>;
+
+@group(0) @binding(1) var<storage, read_write> str : S_tint_packed_vec3;
+
+fn load(p : ptr<storage, array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>, read_write>) -> array<mat3x3<f32>, 4> {
+ return tint_unpack_vec3_in_composite_1(*(p));
+}
+
+fn store(p : ptr<storage, array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>, read_write>) {
+ *(p) = tint_pack_vec3_in_composite_1(array(mat3x3<f32>(), mat3x3<f32>(), mat3x3<f32>(), mat3x3<f32>()));
+}
+
+fn f() {
+ load(&(arr_m));
+ store(&(arr_m));
+ load(&(str.arr_m));
+ store(&(str.arr_m));
+}
+)";  // Expected: the outer-array helpers (_1) call the per-matrix helpers for each element.
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, StructPointerParameters) {  // Pointer to a whole storage struct with vec3-based members passed as a function argument.
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct S {
+ v : vec3<f32>,
+ m : mat3x3<f32>,
+ arr_v : array<vec3<f32>, 4>,
+ arr_m : array<mat3x3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> str : S;
+
+fn load(p : ptr<storage, S, read_write>) -> S {
+ return *p;
+}
+
+fn store(p : ptr<storage, S, read_write>) {
+ *p = S();
+}
+
+fn f() {
+ load(&str);
+ store(&str);
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+enable chromium_experimental_full_ptr_parameters;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ v : __packed_vec3<f32>,
+ @align(16)
+ m : array<tint_packed_vec3_f32_array_element, 3u>,
+ @align(16)
+ arr_v : array<tint_packed_vec3_f32_array_element, 4u>,
+ @align(16)
+ arr_m : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>,
+}
+
+fn tint_unpack_vec3_in_composite(in : array<tint_packed_vec3_f32_array_element, 3u>) -> mat3x3<f32> {
+ var result : mat3x3<f32>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ result[i] = vec3<f32>(in[i].elements);
+ }
+ return result;
+}
+
+fn tint_unpack_vec3_in_composite_1(in : array<tint_packed_vec3_f32_array_element, 4u>) -> array<vec3<f32>, 4u> {
+ var result : array<vec3<f32>, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = vec3<f32>(in[i].elements);
+ }
+ return result;
+}
+
+fn tint_unpack_vec3_in_composite_2(in : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>) -> array<mat3x3<f32>, 4u> {
+ var result : array<mat3x3<f32>, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = tint_unpack_vec3_in_composite(in[i]);
+ }
+ return result;
+}
+
+fn tint_unpack_vec3_in_composite_3(in : S_tint_packed_vec3) -> S {
+ var result : S;
+ result.v = vec3<f32>(in.v);
+ result.m = tint_unpack_vec3_in_composite(in.m);
+ result.arr_v = tint_unpack_vec3_in_composite_1(in.arr_v);
+ result.arr_m = tint_unpack_vec3_in_composite_2(in.arr_m);
+ return result;
+}
+
+fn tint_pack_vec3_in_composite(in : mat3x3<f32>) -> array<tint_packed_vec3_f32_array_element, 3u> {
+ var result : array<tint_packed_vec3_f32_array_element, 3u>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ result[i] = tint_packed_vec3_f32_array_element(__packed_vec3<f32>(in[i]));
+ }
+ return result;
+}
+
+fn tint_pack_vec3_in_composite_1(in : array<vec3<f32>, 4u>) -> array<tint_packed_vec3_f32_array_element, 4u> {
+ var result : array<tint_packed_vec3_f32_array_element, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = tint_packed_vec3_f32_array_element(__packed_vec3<f32>(in[i]));
+ }
+ return result;
+}
+
+fn tint_pack_vec3_in_composite_2(in : array<mat3x3<f32>, 4u>) -> array<array<tint_packed_vec3_f32_array_element, 3u>, 4u> {
+ var result : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = tint_pack_vec3_in_composite(in[i]);
+ }
+ return result;
+}
+
+fn tint_pack_vec3_in_composite_3(in : S) -> S_tint_packed_vec3 {
+ var result : S_tint_packed_vec3;
+ result.v = __packed_vec3<f32>(in.v);
+ result.m = tint_pack_vec3_in_composite(in.m);
+ result.arr_v = tint_pack_vec3_in_composite_1(in.arr_v);
+ result.arr_m = tint_pack_vec3_in_composite_2(in.arr_m);
+ return result;
+}
+
+struct S {
+ v : vec3<f32>,
+ m : mat3x3<f32>,
+ arr_v : array<vec3<f32>, 4>,
+ arr_m : array<mat3x3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> str : S_tint_packed_vec3;
+
+fn load(p : ptr<storage, S_tint_packed_vec3, read_write>) -> S {
+ return tint_unpack_vec3_in_composite_3(*(p));
+}
+
+fn store(p : ptr<storage, S_tint_packed_vec3, read_write>) {
+ *(p) = tint_pack_vec3_in_composite_3(S());
+}
+
+fn f() {
+ load(&(str));
+ store(&(str));
+}
+)";  // Expected: pointee becomes S_tint_packed_vec3; the struct helpers (_3) convert member by member.
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, MixedAddressSpace_Struct) {  // Struct S is used both by a storage variable and by a function-scope variable.
+ auto* src = R"(
+struct S {
+ v : vec3<f32>,
+ arr : array<vec3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S;
+
+fn f() {
+ var f : S;
+ let v = f.v;
+ let arr = f.arr;
+ P = f;
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ v : __packed_vec3<f32>,
+ @align(16)
+ arr : array<tint_packed_vec3_f32_array_element, 4u>,
+}
+
+fn tint_pack_vec3_in_composite(in : array<vec3<f32>, 4u>) -> array<tint_packed_vec3_f32_array_element, 4u> {
+ var result : array<tint_packed_vec3_f32_array_element, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = tint_packed_vec3_f32_array_element(__packed_vec3<f32>(in[i]));
+ }
+ return result;
+}
+
+fn tint_pack_vec3_in_composite_1(in : S) -> S_tint_packed_vec3 {
+ var result : S_tint_packed_vec3;
+ result.v = __packed_vec3<f32>(in.v);
+ result.arr = tint_pack_vec3_in_composite(in.arr);
+ return result;
+}
+
+struct S {
+ v : vec3<f32>,
+ arr : array<vec3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S_tint_packed_vec3;
+
+fn f() {
+ var f : S;
+ let v = f.v;
+ let arr = f.arr;
+ P = tint_pack_vec3_in_composite_1(f);
+}
+)";  // Expected: only P is retyped to S_tint_packed_vec3; var f stays S and is packed when assigned to P.
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, MixedAddressSpace_NestedStruct) {  // Struct nested in another struct, shared by a storage variable and a function-scope variable.
+ auto* src = R"(
+struct S {
+ v : vec3<f32>,
+ arr : array<vec3<f32>, 4>,
+}
+
+struct Outer {
+ inner : S,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : Outer;
+
+fn f() {
+ var f : Outer;
+ let v = f.inner.v;
+ let arr = f.inner.arr;
+ P = f;
+ P.inner = f.inner;
+ P.inner.v = f.inner.v;
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ v : __packed_vec3<f32>,
+ @align(16)
+ arr : array<tint_packed_vec3_f32_array_element, 4u>,
+}
+
+struct Outer_tint_packed_vec3 {
+ @align(16)
+ inner : S_tint_packed_vec3,
+}
+
+fn tint_pack_vec3_in_composite(in : array<vec3<f32>, 4u>) -> array<tint_packed_vec3_f32_array_element, 4u> {
+ var result : array<tint_packed_vec3_f32_array_element, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = tint_packed_vec3_f32_array_element(__packed_vec3<f32>(in[i]));
+ }
+ return result;
+}
+
+fn tint_pack_vec3_in_composite_1(in : S) -> S_tint_packed_vec3 {
+ var result : S_tint_packed_vec3;
+ result.v = __packed_vec3<f32>(in.v);
+ result.arr = tint_pack_vec3_in_composite(in.arr);
+ return result;
+}
+
+fn tint_pack_vec3_in_composite_2(in : Outer) -> Outer_tint_packed_vec3 {
+ var result : Outer_tint_packed_vec3;
+ result.inner = tint_pack_vec3_in_composite_1(in.inner);
+ return result;
+}
+
+struct S {
+ v : vec3<f32>,
+ arr : array<vec3<f32>, 4>,
+}
+
+struct Outer {
+ inner : S,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : Outer_tint_packed_vec3;
+
+fn f() {
+ var f : Outer;
+ let v = f.inner.v;
+ let arr = f.inner.arr;
+ P = tint_pack_vec3_in_composite_2(f);
+ P.inner = tint_pack_vec3_in_composite_1(f.inner);
+ P.inner.v = __packed_vec3<f32>(f.inner.v);
+}
+)";  // Expected: each store into P packs at the assigned level (Outer, S, or a single vec3).
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, MixedAddressSpace_AnotherStructNotShared) {
+ // Test that we can pass pointers to members of both shared and non-shared structs to the
+ // same function.
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct S {
+ v : vec3<f32>,
+ arr : array<vec3<f32>, 4>,
+}
+
+struct NotShared {
+ v : vec3<f32>,
+ arr : array<vec3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S;
+
+fn g(p : ptr<function, vec3<f32>>) -> vec3<f32> {
+ return *p;
+}
+
+fn f() {
+ var f1 : S;
+ var f2 : NotShared;
+ g(&f1.v);
+ g(&f1.arr[0]);
+ g(&f2.v);
+ g(&f2.arr[0]);
+ P = f1;
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+enable chromium_experimental_full_ptr_parameters;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ v : __packed_vec3<f32>,
+ @align(16)
+ arr : array<tint_packed_vec3_f32_array_element, 4u>,
+}
+
+fn tint_pack_vec3_in_composite(in : array<vec3<f32>, 4u>) -> array<tint_packed_vec3_f32_array_element, 4u> {
+ var result : array<tint_packed_vec3_f32_array_element, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = tint_packed_vec3_f32_array_element(__packed_vec3<f32>(in[i]));
+ }
+ return result;
+}
+
+fn tint_pack_vec3_in_composite_1(in : S) -> S_tint_packed_vec3 {
+ var result : S_tint_packed_vec3;
+ result.v = __packed_vec3<f32>(in.v);
+ result.arr = tint_pack_vec3_in_composite(in.arr);
+ return result;
+}
+
+struct S {
+ v : vec3<f32>,
+ arr : array<vec3<f32>, 4>,
+}
+
+struct NotShared {
+ v : vec3<f32>,
+ arr : array<vec3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> P : S_tint_packed_vec3;
+
+fn g(p : ptr<function, vec3<f32>>) -> vec3<f32> {
+ return *(p);
+}
+
+fn f() {
+ var f1 : S;
+ var f2 : NotShared;
+ g(&(f1.v));
+ g(&(f1.arr[0]));
+ g(&(f2.v));
+ g(&(f2.arr[0]));
+ P = tint_pack_vec3_in_composite_1(f1);
+}
+)";  // Expected: g keeps its ptr<function, vec3<f32>> parameter; only the store to P packs f1.
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, MixedAddressSpace_InitFromLoad_ExplicitVarType) {  // Explicitly typed function vars initialized by loading from a storage struct.
+ auto* src = R"(
+struct S {
+ v : vec3<f32>,
+ m : mat3x3<f32>,
+ arr_v : array<vec3<f32>, 4>,
+ arr_m : array<mat3x3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage> P : S;
+
+fn f() {
+ var f1 : S = P;
+ var f2 : vec3<f32> = P.v;
+ var f3 : mat3x3<f32> = P.m;
+ var f4 : array<vec3<f32>, 4> = P.arr_v;
+ var f5 : array<mat3x3<f32>, 4> = P.arr_m;
+ let v_1 = f1.v;
+ let v_2 = f2;
+ let v_3 = f3[0];
+ let v_4 = f4[1];
+ let v_5 = f5[2][2];
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ v : __packed_vec3<f32>,
+ @align(16)
+ m : array<tint_packed_vec3_f32_array_element, 3u>,
+ @align(16)
+ arr_v : array<tint_packed_vec3_f32_array_element, 4u>,
+ @align(16)
+ arr_m : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>,
+}
+
+fn tint_unpack_vec3_in_composite(in : array<tint_packed_vec3_f32_array_element, 3u>) -> mat3x3<f32> {
+ var result : mat3x3<f32>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ result[i] = vec3<f32>(in[i].elements);
+ }
+ return result;
+}
+
+fn tint_unpack_vec3_in_composite_1(in : array<tint_packed_vec3_f32_array_element, 4u>) -> array<vec3<f32>, 4u> {
+ var result : array<vec3<f32>, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = vec3<f32>(in[i].elements);
+ }
+ return result;
+}
+
+fn tint_unpack_vec3_in_composite_2(in : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>) -> array<mat3x3<f32>, 4u> {
+ var result : array<mat3x3<f32>, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = tint_unpack_vec3_in_composite(in[i]);
+ }
+ return result;
+}
+
+fn tint_unpack_vec3_in_composite_3(in : S_tint_packed_vec3) -> S {
+ var result : S;
+ result.v = vec3<f32>(in.v);
+ result.m = tint_unpack_vec3_in_composite(in.m);
+ result.arr_v = tint_unpack_vec3_in_composite_1(in.arr_v);
+ result.arr_m = tint_unpack_vec3_in_composite_2(in.arr_m);
+ return result;
+}
+
+struct S {
+ v : vec3<f32>,
+ m : mat3x3<f32>,
+ arr_v : array<vec3<f32>, 4>,
+ arr_m : array<mat3x3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage> P : S_tint_packed_vec3;
+
+fn f() {
+ var f1 : S = tint_unpack_vec3_in_composite_3(P);
+ var f2 : vec3<f32> = vec3<f32>(P.v);
+ var f3 : mat3x3<f32> = tint_unpack_vec3_in_composite(P.m);
+ var f4 : array<vec3<f32>, 4> = tint_unpack_vec3_in_composite_1(P.arr_v);
+ var f5 : array<mat3x3<f32>, 4> = tint_unpack_vec3_in_composite_2(P.arr_m);
+ let v_1 = f1.v;
+ let v_2 = f2;
+ let v_3 = f3[0];
+ let v_4 = f4[1];
+ let v_5 = f5[2][2];
+}
+)";  // Expected: every initializer is unpacked; the declared var types are left as written.
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, MixedAddressSpace_InitFromLoad_InferredVarType) {  // Type-inferred function vars initialized by loading from a storage struct.
+ auto* src = R"(
+struct S {
+ v : vec3<f32>,
+ m : mat3x3<f32>,
+ arr_v : array<vec3<f32>, 4>,
+ arr_m : array<mat3x3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage> P : S;
+
+fn f() {
+ var f1 = P;
+ var f2 = P.v;
+ var f3 = P.m;
+ var f4 = P.arr_v;
+ var f5 = P.arr_m;
+ let v_1 = f1.v;
+ let v_2 = f2;
+ let v_3 = f3[0];
+ let v_4 = f4[1];
+ let v_5 = f5[2][2];
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ v : __packed_vec3<f32>,
+ @align(16)
+ m : array<tint_packed_vec3_f32_array_element, 3u>,
+ @align(16)
+ arr_v : array<tint_packed_vec3_f32_array_element, 4u>,
+ @align(16)
+ arr_m : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>,
+}
+
+fn tint_unpack_vec3_in_composite(in : array<tint_packed_vec3_f32_array_element, 3u>) -> mat3x3<f32> {
+ var result : mat3x3<f32>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ result[i] = vec3<f32>(in[i].elements);
+ }
+ return result;
+}
+
+fn tint_unpack_vec3_in_composite_1(in : array<tint_packed_vec3_f32_array_element, 4u>) -> array<vec3<f32>, 4u> {
+ var result : array<vec3<f32>, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = vec3<f32>(in[i].elements);
+ }
+ return result;
+}
+
+fn tint_unpack_vec3_in_composite_2(in : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>) -> array<mat3x3<f32>, 4u> {
+ var result : array<mat3x3<f32>, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = tint_unpack_vec3_in_composite(in[i]);
+ }
+ return result;
+}
+
+fn tint_unpack_vec3_in_composite_3(in : S_tint_packed_vec3) -> S {
+ var result : S;
+ result.v = vec3<f32>(in.v);
+ result.m = tint_unpack_vec3_in_composite(in.m);
+ result.arr_v = tint_unpack_vec3_in_composite_1(in.arr_v);
+ result.arr_m = tint_unpack_vec3_in_composite_2(in.arr_m);
+ return result;
+}
+
+struct S {
+ v : vec3<f32>,
+ m : mat3x3<f32>,
+ arr_v : array<vec3<f32>, 4>,
+ arr_m : array<mat3x3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage> P : S_tint_packed_vec3;
+
+fn f() {
+ var f1 = tint_unpack_vec3_in_composite_3(P);
+ var f2 = vec3<f32>(P.v);
+ var f3 = tint_unpack_vec3_in_composite(P.m);
+ var f4 = tint_unpack_vec3_in_composite_1(P.arr_v);
+ var f5 = tint_unpack_vec3_in_composite_2(P.arr_m);
+ let v_1 = f1.v;
+ let v_2 = f2;
+ let v_3 = f3[0];
+ let v_4 = f4[1];
+ let v_5 = f5[2][2];
+}
+)";  // Expected: every initializer is wrapped in an unpacking conversion.
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, MixedAddressSpace_InitFromValue_ExplicitVarType) {  // Explicitly typed function vars initialized from constructed values (no storage loads).
+ auto* src = R"(
+struct S {
+ v : vec3<f32>,
+ m : mat3x3<f32>,
+ arr_v : array<vec3<f32>, 4>,
+ arr_m : array<mat3x3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage> P : S;
+
+fn f() {
+ var f1 : S = S();
+ var f2 : vec3<f32> = vec3<f32>();
+ var f3 : mat3x3<f32> = mat3x3<f32>();
+ var f4 : array<vec3<f32>, 4> = array(vec3<f32>(), vec3<f32>(), vec3<f32>(), vec3<f32>());
+ var f5 : array<mat3x3<f32>, 4> = array(mat3x3<f32>(), mat3x3<f32>(), mat3x3<f32>(), mat3x3<f32>());
+ let v_1 = f1.v;
+ let v_2 = f2;
+ let v_3 = f3[0];
+ let v_4 = f4[1];
+ let v_5 = f5[2][2];
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ v : __packed_vec3<f32>,
+ @align(16)
+ m : array<tint_packed_vec3_f32_array_element, 3u>,
+ @align(16)
+ arr_v : array<tint_packed_vec3_f32_array_element, 4u>,
+ @align(16)
+ arr_m : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>,
+}
+
+struct S {
+ v : vec3<f32>,
+ m : mat3x3<f32>,
+ arr_v : array<vec3<f32>, 4>,
+ arr_m : array<mat3x3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage> P : S_tint_packed_vec3;
+
+fn f() {
+ var f1 : S = S();
+ var f2 : vec3<f32> = vec3<f32>();
+ var f3 : mat3x3<f32> = mat3x3<f32>();
+ var f4 : array<vec3<f32>, 4> = array(vec3<f32>(), vec3<f32>(), vec3<f32>(), vec3<f32>());
+ var f5 : array<mat3x3<f32>, 4> = array(mat3x3<f32>(), mat3x3<f32>(), mat3x3<f32>(), mat3x3<f32>());
+ let v_1 = f1.v;
+ let v_2 = f2;
+ let v_3 = f3[0];
+ let v_4 = f4[1];
+ let v_5 = f5[2][2];
+}
+)";  // Expected: f() is unchanged and no pack/unpack helpers are emitted; only P is retyped.
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, MixedAddressSpace_InitFromValue_InferredVarType) {  // Type-inferred function vars initialized from constructed values (no storage loads).
+ auto* src = R"(
+struct S {
+ v : vec3<f32>,
+ m : mat3x3<f32>,
+ arr_v : array<vec3<f32>, 4>,
+ arr_m : array<mat3x3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage> P : S;
+
+fn f() {
+ var f1 = S();
+ var f2 = vec3<f32>();
+ var f3 = mat3x3<f32>();
+ var f4 = array(vec3<f32>(), vec3<f32>(), vec3<f32>(), vec3<f32>());
+ var f5 = array(mat3x3<f32>(), mat3x3<f32>(), mat3x3<f32>(), mat3x3<f32>());
+ let v_1 = f1.v;
+ let v_2 = f2;
+ let v_3 = f3[0];
+ let v_4 = f4[1];
+ let v_5 = f5[2][2];
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ v : __packed_vec3<f32>,
+ @align(16)
+ m : array<tint_packed_vec3_f32_array_element, 3u>,
+ @align(16)
+ arr_v : array<tint_packed_vec3_f32_array_element, 4u>,
+ @align(16)
+ arr_m : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>,
+}
+
+struct S {
+ v : vec3<f32>,
+ m : mat3x3<f32>,
+ arr_v : array<vec3<f32>, 4>,
+ arr_m : array<mat3x3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage> P : S_tint_packed_vec3;
+
+fn f() {
+ var f1 = S();
+ var f2 = vec3<f32>();
+ var f3 = mat3x3<f32>();
+ var f4 = array(vec3<f32>(), vec3<f32>(), vec3<f32>(), vec3<f32>());
+ var f5 = array(mat3x3<f32>(), mat3x3<f32>(), mat3x3<f32>(), mat3x3<f32>());
+ let v_1 = f1.v;
+ let v_2 = f2;
+ let v_3 = f3[0];
+ let v_4 = f4[1];
+ let v_5 = f5[2][2];
+}
+)";  // Expected: f() is unchanged and no pack/unpack helpers are emitted; only P is retyped.
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, MixedAddressSpace_Pointers_Function) {  // Pointers into function-scope variables while S is also used in storage.
+ auto* src = R"(
+struct S {
+ v : vec3<f32>,
+ m : mat3x3<f32>,
+}
+
+@group(0) @binding(0) var<storage> P : S;
+
+fn f() {
+ var f1 : S = P;
+ var f2 : vec3<f32> = P.v;
+ var f3 : array<vec3<f32>, 4>;
+ var f4 : mat3x3<f32> = P.m;
+ let pv_1 : ptr<function, vec3<f32>> = &f1.v;
+ let pv_2 : ptr<function, vec3<f32>> = &f2;
+ let pv_3 : ptr<function, vec3<f32>> = &f3[0];
+ let pv_4 : ptr<function, mat3x3<f32>> = &f1.m;
+ let pv_5 : ptr<function, mat3x3<f32>> = &f4;
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ v : __packed_vec3<f32>,
+ @align(16)
+ m : array<tint_packed_vec3_f32_array_element, 3u>,
+}
+
+fn tint_unpack_vec3_in_composite(in : array<tint_packed_vec3_f32_array_element, 3u>) -> mat3x3<f32> {
+ var result : mat3x3<f32>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ result[i] = vec3<f32>(in[i].elements);
+ }
+ return result;
+}
+
+fn tint_unpack_vec3_in_composite_1(in : S_tint_packed_vec3) -> S {
+ var result : S;
+ result.v = vec3<f32>(in.v);
+ result.m = tint_unpack_vec3_in_composite(in.m);
+ return result;
+}
+
+struct S {
+ v : vec3<f32>,
+ m : mat3x3<f32>,
+}
+
+@group(0) @binding(0) var<storage> P : S_tint_packed_vec3;
+
+fn f() {
+ var f1 : S = tint_unpack_vec3_in_composite_1(P);
+ var f2 : vec3<f32> = vec3<f32>(P.v);
+ var f3 : array<vec3<f32>, 4>;
+ var f4 : mat3x3<f32> = tint_unpack_vec3_in_composite(P.m);
+ let pv_1 : ptr<function, vec3<f32>> = &(f1.v);
+ let pv_2 : ptr<function, vec3<f32>> = &(f2);
+ let pv_3 : ptr<function, vec3<f32>> = &(f3[0]);
+ let pv_4 : ptr<function, mat3x3<f32>> = &(f1.m);
+ let pv_5 : ptr<function, mat3x3<f32>> = &(f4);
+}
+)";  // Expected: the ptr<function, ...> let types are left unpacked; only the loads from P are converted.
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, MixedAddressSpace_Pointers_Private) {  // Pointers into private variables while S is also used in storage.
+ auto* src = R"(
+struct S {
+ v : vec3<f32>,
+ m : mat3x3<f32>,
+}
+
+@group(0) @binding(0) var<storage> P : S;
+
+var<private> p1 : S;
+var<private> p2 : vec3<f32>;
+var<private> p3 : array<vec3<f32>, 4>;
+var<private> p4 : mat3x3<f32>;
+
+fn f() {
+ let pv_1 : ptr<private, vec3<f32>> = &p1.v;
+ let pv_2 : ptr<private, vec3<f32>> = &p2;
+ let pv_3 : ptr<private, vec3<f32>> = &p3[0];
+ let pv_4 : ptr<private, mat3x3<f32>> = &p1.m;
+ let pv_5 : ptr<private, mat3x3<f32>> = &p4;
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ v : __packed_vec3<f32>,
+ @align(16)
+ m : array<tint_packed_vec3_f32_array_element, 3u>,
+}
+
+struct S {
+ v : vec3<f32>,
+ m : mat3x3<f32>,
+}
+
+@group(0) @binding(0) var<storage> P : S_tint_packed_vec3;
+
+var<private> p1 : S;
+
+var<private> p2 : vec3<f32>;
+
+var<private> p3 : array<vec3<f32>, 4>;
+
+var<private> p4 : mat3x3<f32>;
+
+fn f() {
+ let pv_1 : ptr<private, vec3<f32>> = &(p1.v);
+ let pv_2 : ptr<private, vec3<f32>> = &(p2);
+ let pv_3 : ptr<private, vec3<f32>> = &(p3[0]);
+ let pv_4 : ptr<private, mat3x3<f32>> = &(p1.m);
+ let pv_5 : ptr<private, mat3x3<f32>> = &(p4);
+}
+)";  // Expected: private vars and pointer types keep their original types; no helpers are emitted.
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, MixedAddressSpace_Pointers_Workgroup) {  // Pointers into workgroup variables while S is also used in storage.
+ auto* src = R"(
+struct S {
+ v : vec3<f32>,
+ m : mat3x3<f32>,
+}
+
+@group(0) @binding(0) var<storage> P : S;
+
+var<workgroup> w1 : S;
+var<workgroup> w2 : vec3<f32>;
+var<workgroup> w3 : array<vec3<f32>, 4>;
+var<workgroup> w4 : mat3x3<f32>;
+
+fn f() {
+ let pv_1 : ptr<workgroup, vec3<f32>> = &w1.v;
+ let pv_2 : ptr<workgroup, vec3<f32>> = &w2;
+ let pv_3 : ptr<workgroup, vec3<f32>> = &w3[0];
+ let pv_4 : ptr<workgroup, mat3x3<f32>> = &w1.m;
+ let pv_5 : ptr<workgroup, mat3x3<f32>> = &w4;
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ v : __packed_vec3<f32>,
+ @align(16)
+ m : array<tint_packed_vec3_f32_array_element, 3u>,
+}
+
+struct S {
+ v : vec3<f32>,
+ m : mat3x3<f32>,
+}
+
+@group(0) @binding(0) var<storage> P : S_tint_packed_vec3;
+
+var<workgroup> w1 : S;
+
+var<workgroup> w2 : vec3<f32>;
+
+var<workgroup> w3 : array<vec3<f32>, 4>;
+
+var<workgroup> w4 : mat3x3<f32>;
+
+fn f() {
+ let pv_1 : ptr<workgroup, vec3<f32>> = &(w1.v);
+ let pv_2 : ptr<workgroup, vec3<f32>> = &(w2);
+ let pv_3 : ptr<workgroup, vec3<f32>> = &(w3[0]);
+ let pv_4 : ptr<workgroup, mat3x3<f32>> = &(w1.m);
+ let pv_5 : ptr<workgroup, mat3x3<f32>> = &(w4);
+}
+)";  // Expected: workgroup vars and pointer types keep their original types; no helpers are emitted.
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, MixedAddressSpace_PointerParameters) {  // Function-address-space pointers passed as arguments while S is also used in storage.
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct S {
+ v : vec3<f32>,
+ m : mat3x3<f32>,
+}
+
+@group(0) @binding(0) var<storage> P : S;
+
+fn g_v(p : ptr<function, vec3<f32>>) -> vec3<f32> {
+ return *p;
+}
+
+fn g_m(p : ptr<function, mat3x3<f32>>) -> mat3x3<f32> {
+ return *p;
+}
+
+fn f() {
+ var f1 : S = P;
+ var f2 : vec3<f32> = P.v;
+ var f3 : array<vec3<f32>, 4>;
+ var f4 : mat3x3<f32> = P.m;
+ g_v(&f1.v);
+ g_v(&f2);
+ g_v(&f3[0]);
+ g_m(&f1.m);
+ g_m(&f4);
+}
+)";
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+enable chromium_experimental_full_ptr_parameters;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ v : __packed_vec3<f32>,
+ @align(16)
+ m : array<tint_packed_vec3_f32_array_element, 3u>,
+}
+
+fn tint_unpack_vec3_in_composite(in : array<tint_packed_vec3_f32_array_element, 3u>) -> mat3x3<f32> {
+ var result : mat3x3<f32>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ result[i] = vec3<f32>(in[i].elements);
+ }
+ return result;
+}
+
+fn tint_unpack_vec3_in_composite_1(in : S_tint_packed_vec3) -> S {
+ var result : S;
+ result.v = vec3<f32>(in.v);
+ result.m = tint_unpack_vec3_in_composite(in.m);
+ return result;
+}
+
+struct S {
+ v : vec3<f32>,
+ m : mat3x3<f32>,
+}
+
+@group(0) @binding(0) var<storage> P : S_tint_packed_vec3;
+
+fn g_v(p : ptr<function, vec3<f32>>) -> vec3<f32> {
+ return *(p);
+}
+
+fn g_m(p : ptr<function, mat3x3<f32>>) -> mat3x3<f32> {
+ return *(p);
+}
+
+fn f() {
+ var f1 : S = tint_unpack_vec3_in_composite_1(P);
+ var f2 : vec3<f32> = vec3<f32>(P.v);
+ var f3 : array<vec3<f32>, 4>;
+ var f4 : mat3x3<f32> = tint_unpack_vec3_in_composite(P.m);
+ g_v(&(f1.v));
+ g_v(&(f2));
+ g_v(&(f3[0]));
+ g_m(&(f1.m));
+ g_m(&(f4));
+}
+)";  // Expected: g_v and g_m keep their parameter types; only the loads from P are converted.
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, WriteVec3Swizzle_FromRef) {  // Store a swizzle of a storage vec3 back into itself: expect unpack before the swizzle and re-pack for the store.
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read_write> v : vec3<f32>;
+
+fn f() {
+ v = v.zyx;
+}
+)";  // end of the input WGSL
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+@group(0) @binding(0) var<storage, read_write> v : __packed_vec3<f32>;
+
+fn f() {
+ v = __packed_vec3<f32>(vec3<f32>(v).zyx);
+}
+)";  // end of the expected output
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);  // apply the PackedVec3 transform to src
+
+ EXPECT_EQ(expect, str(got));  // exact text comparison against the golden output
+}
+
+TEST_F(PackedVec3Test, WriteVec3Swizzle_FromValue) {  // Store a swizzle of a value-constructed vec3: expect only the packing conversion around the right-hand side.
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read_write> v : vec3<f32>;
+
+fn f() {
+ v = vec3f(1, 2, 3).zyx;
+}
+)";  // end of the input WGSL
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+@group(0) @binding(0) var<storage, read_write> v : __packed_vec3<f32>;
+
+fn f() {
+ v = __packed_vec3<f32>(vec3f(1, 2, 3).zyx);
+}
+)";  // end of the expected output
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);  // apply the PackedVec3 transform to src
+
+ EXPECT_EQ(expect, str(got));  // exact text comparison against the golden output
+}
+
+TEST_F(PackedVec3Test, WriteVec3Component_FromPackedValueIndexAccessor) {  // Single-component store from a function result: expect the assignment text unchanged, only the buffer's struct type swapped.
+ auto* src = R"(
+struct S {
+ v : vec3<f32>
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+
+fn g() -> S {
+ return S();
+}
+
+fn f() {
+ s.v[0] = g().v[1];
+}
+)";  // end of the input WGSL
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ v : __packed_vec3<f32>,
+}
+
+struct S {
+ v : vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S_tint_packed_vec3;
+
+fn g() -> S {
+ return S();
+}
+
+fn f() {
+ s.v[0] = g().v[1];
+}
+)";  // end of the expected output: no pack/unpack calls are inserted in f()
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);  // apply the PackedVec3 transform to src
+
+ EXPECT_EQ(expect, str(got));  // exact text comparison against the golden output
+}
+
+TEST_F(PackedVec3Test, ExtractVec3FromStructValueExpression) {  // Reads of `.v` from a value-constructed S() are expected to stay as-is; only the whole-struct store to `buffer` is packed.
+ auto* src = R"(
+struct S {
+ v : vec3<f32>
+}
+
+@group(0) @binding(0) var<storage, read_write> buffer : S;
+
+fn f() {
+ var v_var : vec3<f32> = S().v;
+ let v_let : vec3<f32> = S().v;
+ v_var = S().v;
+ v_var = S().v * 2.0;
+ buffer = S(S().v);
+}
+)";  // end of the input WGSL
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ v : __packed_vec3<f32>,
+}
+
+fn tint_pack_vec3_in_composite(in : S) -> S_tint_packed_vec3 {
+ var result : S_tint_packed_vec3;
+ result.v = __packed_vec3<f32>(in.v);
+ return result;
+}
+
+struct S {
+ v : vec3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> buffer : S_tint_packed_vec3;
+
+fn f() {
+ var v_var : vec3<f32> = S().v;
+ let v_let : vec3<f32> = S().v;
+ v_var = S().v;
+ v_var = (S().v * 2.0);
+ buffer = tint_pack_vec3_in_composite(S(S().v));
+}
+)";  // end of the expected output
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);  // apply the PackedVec3 transform to src
+
+ EXPECT_EQ(expect, str(got));  // exact text comparison against the golden output
+}
+
+TEST_F(PackedVec3Test, ExtractArrayOfVec3FromStructValueExpression) {  // Same shape as the vec3 case but with an array<vec3<f32>, 4> member: only the store to `buffer` goes through pack helpers.
+ auto* src = R"(
+struct S {
+ arr : array<vec3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> buffer : S;
+
+fn f() {
+ var arr_var : array<vec3<f32>, 4> = S().arr;
+ let arr_let : array<vec3<f32>, 4> = S().arr;
+ arr_var = S().arr;
+ arr_var[0] = S().arr[0] * 2.0;
+ buffer = S(S().arr);
+}
+)";  // end of the input WGSL
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ arr : array<tint_packed_vec3_f32_array_element, 4u>,
+}
+
+fn tint_pack_vec3_in_composite(in : array<vec3<f32>, 4u>) -> array<tint_packed_vec3_f32_array_element, 4u> {
+ var result : array<tint_packed_vec3_f32_array_element, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = tint_packed_vec3_f32_array_element(__packed_vec3<f32>(in[i]));
+ }
+ return result;
+}
+
+fn tint_pack_vec3_in_composite_1(in : S) -> S_tint_packed_vec3 {
+ var result : S_tint_packed_vec3;
+ result.arr = tint_pack_vec3_in_composite(in.arr);
+ return result;
+}
+
+struct S {
+ arr : array<vec3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> buffer : S_tint_packed_vec3;
+
+fn f() {
+ var arr_var : array<vec3<f32>, 4> = S().arr;
+ let arr_let : array<vec3<f32>, 4> = S().arr;
+ arr_var = S().arr;
+ arr_var[0] = (S().arr[0] * 2.0);
+ buffer = tint_pack_vec3_in_composite_1(S(S().arr));
+}
+)";  // end of the expected output: one pack helper for the array, one for the struct
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);  // apply the PackedVec3 transform to src
+
+ EXPECT_EQ(expect, str(got));  // exact text comparison against the golden output
+}
+
+TEST_F(PackedVec3Test, ExtractNestedArrayFromStructValueExpression) {  // Nested array-of-array-of-vec3 member: value reads stay as-is, the store to `buffer` chains one pack helper per nesting level.
+ auto* src = R"(
+struct S {
+ arr : array<array<vec3<f32>, 4>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> buffer : S;
+
+fn f() {
+ var arr_var : array<array<vec3<f32>, 4>, 4> = S().arr;
+ var inner_var : array<vec3<f32>, 4> = S().arr[0];
+ let arr_let : array<array<vec3<f32>, 4>, 4> = S().arr;
+ arr_var = S().arr;
+ inner_var = S().arr[0];
+ arr_var[0][0] = S().arr[0][0] * 2.0;
+ buffer = S(S().arr);
+}
+)";  // end of the input WGSL
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ arr : array<array<tint_packed_vec3_f32_array_element, 4u>, 4u>,
+}
+
+fn tint_pack_vec3_in_composite(in : array<vec3<f32>, 4u>) -> array<tint_packed_vec3_f32_array_element, 4u> {
+ var result : array<tint_packed_vec3_f32_array_element, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = tint_packed_vec3_f32_array_element(__packed_vec3<f32>(in[i]));
+ }
+ return result;
+}
+
+fn tint_pack_vec3_in_composite_1(in : array<array<vec3<f32>, 4u>, 4u>) -> array<array<tint_packed_vec3_f32_array_element, 4u>, 4u> {
+ var result : array<array<tint_packed_vec3_f32_array_element, 4u>, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = tint_pack_vec3_in_composite(in[i]);
+ }
+ return result;
+}
+
+fn tint_pack_vec3_in_composite_2(in : S) -> S_tint_packed_vec3 {
+ var result : S_tint_packed_vec3;
+ result.arr = tint_pack_vec3_in_composite_1(in.arr);
+ return result;
+}
+
+struct S {
+ arr : array<array<vec3<f32>, 4>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> buffer : S_tint_packed_vec3;
+
+fn f() {
+ var arr_var : array<array<vec3<f32>, 4>, 4> = S().arr;
+ var inner_var : array<vec3<f32>, 4> = S().arr[0];
+ let arr_let : array<array<vec3<f32>, 4>, 4> = S().arr;
+ arr_var = S().arr;
+ inner_var = S().arr[0];
+ arr_var[0][0] = (S().arr[0][0] * 2.0);
+ buffer = tint_pack_vec3_in_composite_2(S(S().arr));
+}
+)";  // end of the expected output
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);  // apply the PackedVec3 transform to src
+
+ EXPECT_EQ(expect, str(got));  // exact text comparison against the golden output
+}
+
+TEST_F(PackedVec3Test, ExtractMatrixFromStructValueExpression) {  // mat3x3 member: value reads stay as-is; in storage the matrix is expected to become an array of 3 packed column wrappers.
+ auto* src = R"(
+struct S {
+ m : mat3x3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> buffer : S;
+
+fn f() {
+ var m_var : mat3x3<f32> = S().m;
+ let m_let : mat3x3<f32> = S().m;
+ m_var = S().m;
+ m_var = S().m * 2.0;
+ buffer = S(S().m);
+}
+)";  // end of the input WGSL
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ m : array<tint_packed_vec3_f32_array_element, 3u>,
+}
+
+fn tint_pack_vec3_in_composite(in : mat3x3<f32>) -> array<tint_packed_vec3_f32_array_element, 3u> {
+ var result : array<tint_packed_vec3_f32_array_element, 3u>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ result[i] = tint_packed_vec3_f32_array_element(__packed_vec3<f32>(in[i]));
+ }
+ return result;
+}
+
+fn tint_pack_vec3_in_composite_1(in : S) -> S_tint_packed_vec3 {
+ var result : S_tint_packed_vec3;
+ result.m = tint_pack_vec3_in_composite(in.m);
+ return result;
+}
+
+struct S {
+ m : mat3x3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> buffer : S_tint_packed_vec3;
+
+fn f() {
+ var m_var : mat3x3<f32> = S().m;
+ let m_let : mat3x3<f32> = S().m;
+ m_var = S().m;
+ m_var = (S().m * 2.0);
+ buffer = tint_pack_vec3_in_composite_1(S(S().m));
+}
+)";  // end of the expected output
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);  // apply the PackedVec3 transform to src
+
+ EXPECT_EQ(expect, str(got));  // exact text comparison against the golden output
+}
+
+TEST_F(PackedVec3Test, ExtractArrayOfMatrixFromStructValueExpression) {  // array<mat3x3<f32>, 4> member: value reads stay as-is; the store to `buffer` chains matrix, array and struct pack helpers.
+ auto* src = R"(
+struct S {
+ arr : array<mat3x3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> buffer : S;
+
+fn f() {
+ var arr_var : array<mat3x3<f32>, 4> = S().arr;
+ let arr_let : array<mat3x3<f32>, 4> = S().arr;
+ arr_var = S().arr;
+ arr_var[0] = S().arr[0] * 2.0;
+ buffer = S(S().arr);
+}
+)";  // end of the input WGSL
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ arr : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>,
+}
+
+fn tint_pack_vec3_in_composite(in : mat3x3<f32>) -> array<tint_packed_vec3_f32_array_element, 3u> {
+ var result : array<tint_packed_vec3_f32_array_element, 3u>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ result[i] = tint_packed_vec3_f32_array_element(__packed_vec3<f32>(in[i]));
+ }
+ return result;
+}
+
+fn tint_pack_vec3_in_composite_1(in : array<mat3x3<f32>, 4u>) -> array<array<tint_packed_vec3_f32_array_element, 3u>, 4u> {
+ var result : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = tint_pack_vec3_in_composite(in[i]);
+ }
+ return result;
+}
+
+fn tint_pack_vec3_in_composite_2(in : S) -> S_tint_packed_vec3 {
+ var result : S_tint_packed_vec3;
+ result.arr = tint_pack_vec3_in_composite_1(in.arr);
+ return result;
+}
+
+struct S {
+ arr : array<mat3x3<f32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> buffer : S_tint_packed_vec3;
+
+fn f() {
+ var arr_var : array<mat3x3<f32>, 4> = S().arr;
+ let arr_let : array<mat3x3<f32>, 4> = S().arr;
+ arr_var = S().arr;
+ arr_var[0] = (S().arr[0] * 2.0);
+ buffer = tint_pack_vec3_in_composite_2(S(S().arr));
+}
+)";  // end of the expected output
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);  // apply the PackedVec3 transform to src
+
+ EXPECT_EQ(expect, str(got));  // exact text comparison against the golden output
+}
+
+TEST_F(PackedVec3Test, NestedArrays_Let) {  // `let` loads from every nesting level of a storage array<S, 4>: expect an unpack helper or vec3 conversion per level, scalars read via `.elements`.
+ auto* src = R"(
+struct S {
+ arr_v : array<array<vec3<f32>, 4>, 4>,
+ arr_m : array<array<mat3x3<f32>, 4>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> arr_s : array<S, 4>;
+
+fn f() {
+ let full_let : array<S, 4> = arr_s;
+ let struct_let : S = arr_s[0];
+ let outer_arr_v_let : array<array<vec3<f32>, 4>, 4> = arr_s[0].arr_v;
+ let inner_arr_v_let : array<vec3<f32>, 4> = arr_s[0].arr_v[1];
+ let v_let : vec3<f32> = arr_s[0].arr_v[1][2];
+ let v_element_let : f32 = arr_s[0].arr_v[1][2].y;
+ let outer_arr_m_let : array<array<mat3x3<f32>, 4>, 4> = arr_s[0].arr_m;
+ let inner_arr_m_let : array<mat3x3<f32>, 4> = arr_s[0].arr_m[1];
+ let m_let : mat3x3<f32> = arr_s[0].arr_m[1][2];
+ let m_col_let : vec3<f32> = arr_s[0].arr_m[1][2][0];
+ let m_element_let : f32 = arr_s[0].arr_m[1][2][0].y;
+}
+)";  // end of the input WGSL
+
+ auto* expect = R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ arr_v : array<array<tint_packed_vec3_f32_array_element, 4u>, 4u>,
+ @align(16)
+ arr_m : array<array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>, 4u>,
+}
+
+fn tint_unpack_vec3_in_composite(in : array<tint_packed_vec3_f32_array_element, 4u>) -> array<vec3<f32>, 4u> {
+ var result : array<vec3<f32>, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = vec3<f32>(in[i].elements);
+ }
+ return result;
+}
+
+fn tint_unpack_vec3_in_composite_1(in : array<array<tint_packed_vec3_f32_array_element, 4u>, 4u>) -> array<array<vec3<f32>, 4u>, 4u> {
+ var result : array<array<vec3<f32>, 4u>, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = tint_unpack_vec3_in_composite(in[i]);
+ }
+ return result;
+}
+
+fn tint_unpack_vec3_in_composite_2(in : array<tint_packed_vec3_f32_array_element, 3u>) -> mat3x3<f32> {
+ var result : mat3x3<f32>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ result[i] = vec3<f32>(in[i].elements);
+ }
+ return result;
+}
+
+fn tint_unpack_vec3_in_composite_3(in : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>) -> array<mat3x3<f32>, 4u> {
+ var result : array<mat3x3<f32>, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = tint_unpack_vec3_in_composite_2(in[i]);
+ }
+ return result;
+}
+
+fn tint_unpack_vec3_in_composite_4(in : array<array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>, 4u>) -> array<array<mat3x3<f32>, 4u>, 4u> {
+ var result : array<array<mat3x3<f32>, 4u>, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = tint_unpack_vec3_in_composite_3(in[i]);
+ }
+ return result;
+}
+
+fn tint_unpack_vec3_in_composite_5(in : S_tint_packed_vec3) -> S {
+ var result : S;
+ result.arr_v = tint_unpack_vec3_in_composite_1(in.arr_v);
+ result.arr_m = tint_unpack_vec3_in_composite_4(in.arr_m);
+ return result;
+}
+
+fn tint_unpack_vec3_in_composite_6(in : array<S_tint_packed_vec3, 4u>) -> array<S, 4u> {
+ var result : array<S, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = tint_unpack_vec3_in_composite_5(in[i]);
+ }
+ return result;
+}
+
+struct S {
+ arr_v : array<array<vec3<f32>, 4>, 4>,
+ arr_m : array<array<mat3x3<f32>, 4>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> arr_s : array<S_tint_packed_vec3, 4u>;
+
+fn f() {
+ let full_let : array<S, 4> = tint_unpack_vec3_in_composite_6(arr_s);
+ let struct_let : S = tint_unpack_vec3_in_composite_5(arr_s[0]);
+ let outer_arr_v_let : array<array<vec3<f32>, 4>, 4> = tint_unpack_vec3_in_composite_1(arr_s[0].arr_v);
+ let inner_arr_v_let : array<vec3<f32>, 4> = tint_unpack_vec3_in_composite(arr_s[0].arr_v[1]);
+ let v_let : vec3<f32> = vec3<f32>(arr_s[0].arr_v[1][2].elements);
+ let v_element_let : f32 = arr_s[0].arr_v[1][2].elements.y;
+ let outer_arr_m_let : array<array<mat3x3<f32>, 4>, 4> = tint_unpack_vec3_in_composite_4(arr_s[0].arr_m);
+ let inner_arr_m_let : array<mat3x3<f32>, 4> = tint_unpack_vec3_in_composite_3(arr_s[0].arr_m[1]);
+ let m_let : mat3x3<f32> = tint_unpack_vec3_in_composite_2(arr_s[0].arr_m[1][2]);
+ let m_col_let : vec3<f32> = vec3<f32>(arr_s[0].arr_m[1][2][0].elements);
+ let m_element_let : f32 = arr_s[0].arr_m[1][2][0].elements.y;
+}
+)";  // end of the expected output
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);  // apply the PackedVec3 transform to src
+
+ EXPECT_EQ(expect, str(got));  // exact text comparison against the golden output
+}
+
+TEST_F(PackedVec3Test, NestedArrays_VarInit) {  // Same loads as NestedArrays_Let, but used as `var` initializers; the expected helpers and conversions are identical.
+ auto* src = R"(
+struct S {
+ arr_v : array<array<vec3<f32>, 4>, 4>,
+ arr_m : array<array<mat3x3<f32>, 4>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> arr_s : array<S, 4>;
+
+fn f() {
+ var full_var : array<S, 4> = arr_s;
+ var struct_var : S = arr_s[0];
+ var outer_arr_v_var : array<array<vec3<f32>, 4>, 4> = arr_s[0].arr_v;
+ var inner_arr_v_var : array<vec3<f32>, 4> = arr_s[0].arr_v[1];
+ var v_var : vec3<f32> = arr_s[0].arr_v[1][2];
+ var v_element_var : f32 = arr_s[0].arr_v[1][2].y;
+ var outer_arr_m_var : array<array<mat3x3<f32>, 4>, 4> = arr_s[0].arr_m;
+ var inner_arr_m_var : array<mat3x3<f32>, 4> = arr_s[0].arr_m[1];
+ var m_var : mat3x3<f32> = arr_s[0].arr_m[1][2];
+ var m_col_var : vec3<f32> = arr_s[0].arr_m[1][2][0];
+ var m_element_var : f32 = arr_s[0].arr_m[1][2][0].y;
+}
+)";  // end of the input WGSL
+
+ auto* expect =
+ R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ arr_v : array<array<tint_packed_vec3_f32_array_element, 4u>, 4u>,
+ @align(16)
+ arr_m : array<array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>, 4u>,
+}
+
+fn tint_unpack_vec3_in_composite(in : array<tint_packed_vec3_f32_array_element, 4u>) -> array<vec3<f32>, 4u> {
+ var result : array<vec3<f32>, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = vec3<f32>(in[i].elements);
+ }
+ return result;
+}
+
+fn tint_unpack_vec3_in_composite_1(in : array<array<tint_packed_vec3_f32_array_element, 4u>, 4u>) -> array<array<vec3<f32>, 4u>, 4u> {
+ var result : array<array<vec3<f32>, 4u>, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = tint_unpack_vec3_in_composite(in[i]);
+ }
+ return result;
+}
+
+fn tint_unpack_vec3_in_composite_2(in : array<tint_packed_vec3_f32_array_element, 3u>) -> mat3x3<f32> {
+ var result : mat3x3<f32>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ result[i] = vec3<f32>(in[i].elements);
+ }
+ return result;
+}
+
+fn tint_unpack_vec3_in_composite_3(in : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>) -> array<mat3x3<f32>, 4u> {
+ var result : array<mat3x3<f32>, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = tint_unpack_vec3_in_composite_2(in[i]);
+ }
+ return result;
+}
+
+fn tint_unpack_vec3_in_composite_4(in : array<array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>, 4u>) -> array<array<mat3x3<f32>, 4u>, 4u> {
+ var result : array<array<mat3x3<f32>, 4u>, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = tint_unpack_vec3_in_composite_3(in[i]);
+ }
+ return result;
+}
+
+fn tint_unpack_vec3_in_composite_5(in : S_tint_packed_vec3) -> S {
+ var result : S;
+ result.arr_v = tint_unpack_vec3_in_composite_1(in.arr_v);
+ result.arr_m = tint_unpack_vec3_in_composite_4(in.arr_m);
+ return result;
+}
+
+fn tint_unpack_vec3_in_composite_6(in : array<S_tint_packed_vec3, 4u>) -> array<S, 4u> {
+ var result : array<S, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = tint_unpack_vec3_in_composite_5(in[i]);
+ }
+ return result;
+}
+
+struct S {
+ arr_v : array<array<vec3<f32>, 4>, 4>,
+ arr_m : array<array<mat3x3<f32>, 4>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> arr_s : array<S_tint_packed_vec3, 4u>;
+
+fn f() {
+ var full_var : array<S, 4> = tint_unpack_vec3_in_composite_6(arr_s);
+ var struct_var : S = tint_unpack_vec3_in_composite_5(arr_s[0]);
+ var outer_arr_v_var : array<array<vec3<f32>, 4>, 4> = tint_unpack_vec3_in_composite_1(arr_s[0].arr_v);
+ var inner_arr_v_var : array<vec3<f32>, 4> = tint_unpack_vec3_in_composite(arr_s[0].arr_v[1]);
+ var v_var : vec3<f32> = vec3<f32>(arr_s[0].arr_v[1][2].elements);
+ var v_element_var : f32 = arr_s[0].arr_v[1][2].elements.y;
+ var outer_arr_m_var : array<array<mat3x3<f32>, 4>, 4> = tint_unpack_vec3_in_composite_4(arr_s[0].arr_m);
+ var inner_arr_m_var : array<mat3x3<f32>, 4> = tint_unpack_vec3_in_composite_3(arr_s[0].arr_m[1]);
+ var m_var : mat3x3<f32> = tint_unpack_vec3_in_composite_2(arr_s[0].arr_m[1][2]);
+ var m_col_var : vec3<f32> = vec3<f32>(arr_s[0].arr_m[1][2][0].elements);
+ var m_element_var : f32 = arr_s[0].arr_m[1][2][0].elements.y;
+}
+)";  // end of the expected output
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);  // apply the PackedVec3 transform to src
+
+ EXPECT_EQ(expect, str(got));  // exact text comparison against the golden output
+}
+
+TEST_F(PackedVec3Test, NestedArrays_VarAssignment) {  // Same loads as NestedArrays_Let, but assigned to previously declared vars; the expected output drops the blank line between declarations and assignments.
+ auto* src = R"(
+struct S {
+ arr_v : array<array<vec3<f32>, 4>, 4>,
+ arr_m : array<array<mat3x3<f32>, 4>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> arr_s : array<S, 4>;
+
+fn f() {
+ var full_var : array<S, 4>;
+ var struct_var : S;
+ var outer_arr_v_var : array<array<vec3<f32>, 4>, 4>;
+ var inner_arr_v_var : array<vec3<f32>, 4>;
+ var v_var : vec3<f32>;
+ var v_element_var : f32;
+ var outer_arr_m_var : array<array<mat3x3<f32>, 4>, 4>;
+ var inner_arr_m_var : array<mat3x3<f32>, 4>;
+ var m_var : mat3x3<f32>;
+ var m_col_var : vec3<f32>;
+ var m_element_var : f32;
+
+ full_var = arr_s;
+ struct_var = arr_s[0];
+ outer_arr_v_var = arr_s[0].arr_v;
+ inner_arr_v_var = arr_s[0].arr_v[1];
+ v_var = arr_s[0].arr_v[1][2];
+ v_element_var = arr_s[0].arr_v[1][2].y;
+ outer_arr_m_var = arr_s[0].arr_m;
+ inner_arr_m_var = arr_s[0].arr_m[1];
+ m_var = arr_s[0].arr_m[1][2];
+ m_col_var = arr_s[0].arr_m[1][2][0];
+ m_element_var = arr_s[0].arr_m[1][2][0].y;
+}
+)";  // end of the input WGSL
+
+ auto* expect =
+ R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ arr_v : array<array<tint_packed_vec3_f32_array_element, 4u>, 4u>,
+ @align(16)
+ arr_m : array<array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>, 4u>,
+}
+
+fn tint_unpack_vec3_in_composite(in : array<tint_packed_vec3_f32_array_element, 4u>) -> array<vec3<f32>, 4u> {
+ var result : array<vec3<f32>, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = vec3<f32>(in[i].elements);
+ }
+ return result;
+}
+
+fn tint_unpack_vec3_in_composite_1(in : array<array<tint_packed_vec3_f32_array_element, 4u>, 4u>) -> array<array<vec3<f32>, 4u>, 4u> {
+ var result : array<array<vec3<f32>, 4u>, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = tint_unpack_vec3_in_composite(in[i]);
+ }
+ return result;
+}
+
+fn tint_unpack_vec3_in_composite_2(in : array<tint_packed_vec3_f32_array_element, 3u>) -> mat3x3<f32> {
+ var result : mat3x3<f32>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ result[i] = vec3<f32>(in[i].elements);
+ }
+ return result;
+}
+
+fn tint_unpack_vec3_in_composite_3(in : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>) -> array<mat3x3<f32>, 4u> {
+ var result : array<mat3x3<f32>, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = tint_unpack_vec3_in_composite_2(in[i]);
+ }
+ return result;
+}
+
+fn tint_unpack_vec3_in_composite_4(in : array<array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>, 4u>) -> array<array<mat3x3<f32>, 4u>, 4u> {
+ var result : array<array<mat3x3<f32>, 4u>, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = tint_unpack_vec3_in_composite_3(in[i]);
+ }
+ return result;
+}
+
+fn tint_unpack_vec3_in_composite_5(in : S_tint_packed_vec3) -> S {
+ var result : S;
+ result.arr_v = tint_unpack_vec3_in_composite_1(in.arr_v);
+ result.arr_m = tint_unpack_vec3_in_composite_4(in.arr_m);
+ return result;
+}
+
+fn tint_unpack_vec3_in_composite_6(in : array<S_tint_packed_vec3, 4u>) -> array<S, 4u> {
+ var result : array<S, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = tint_unpack_vec3_in_composite_5(in[i]);
+ }
+ return result;
+}
+
+struct S {
+ arr_v : array<array<vec3<f32>, 4>, 4>,
+ arr_m : array<array<mat3x3<f32>, 4>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> arr_s : array<S_tint_packed_vec3, 4u>;
+
+fn f() {
+ var full_var : array<S, 4>;
+ var struct_var : S;
+ var outer_arr_v_var : array<array<vec3<f32>, 4>, 4>;
+ var inner_arr_v_var : array<vec3<f32>, 4>;
+ var v_var : vec3<f32>;
+ var v_element_var : f32;
+ var outer_arr_m_var : array<array<mat3x3<f32>, 4>, 4>;
+ var inner_arr_m_var : array<mat3x3<f32>, 4>;
+ var m_var : mat3x3<f32>;
+ var m_col_var : vec3<f32>;
+ var m_element_var : f32;
+ full_var = tint_unpack_vec3_in_composite_6(arr_s);
+ struct_var = tint_unpack_vec3_in_composite_5(arr_s[0]);
+ outer_arr_v_var = tint_unpack_vec3_in_composite_1(arr_s[0].arr_v);
+ inner_arr_v_var = tint_unpack_vec3_in_composite(arr_s[0].arr_v[1]);
+ v_var = vec3<f32>(arr_s[0].arr_v[1][2].elements);
+ v_element_var = arr_s[0].arr_v[1][2].elements.y;
+ outer_arr_m_var = tint_unpack_vec3_in_composite_4(arr_s[0].arr_m);
+ inner_arr_m_var = tint_unpack_vec3_in_composite_3(arr_s[0].arr_m[1]);
+ m_var = tint_unpack_vec3_in_composite_2(arr_s[0].arr_m[1][2]);
+ m_col_var = vec3<f32>(arr_s[0].arr_m[1][2][0].elements);
+ m_element_var = arr_s[0].arr_m[1][2][0].elements.y;
+}
+)";  // end of the expected output
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);  // apply the PackedVec3 transform to src
+
+ EXPECT_EQ(expect, str(got));  // exact text comparison against the golden output
+}
+
+TEST_F(PackedVec3Test, RuntimeSizedArray) {  // Runtime-sized array<vec3<f32>>, both top-level and as a struct member: expect wrapper-struct elements and a direct `.elements` to `.elements` copy.
+ auto* src = R"(
+struct S {
+ arr : array<vec3<f32>>,
+}
+
+@group(0) @binding(0) var<storage, read_write> arr_v : array<vec3<f32>>;
+@group(0) @binding(1) var<storage, read_write> s : S;
+
+fn main() {
+ s.arr[0] = arr_v[0];
+}
+)";  // end of the input WGSL
+
+ auto* expect =
+ R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ arr : array<tint_packed_vec3_f32_array_element>,
+}
+
+struct S {
+ arr : array<vec3<f32>>,
+}
+
+@group(0) @binding(0) var<storage, read_write> arr_v : array<tint_packed_vec3_f32_array_element>;
+
+@group(0) @binding(1) var<storage, read_write> s : S_tint_packed_vec3;
+
+fn main() {
+ s.arr[0].elements = arr_v[0].elements;
+}
+)";  // end of the expected output: no pack/unpack conversion appears in the assignment
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);  // apply the PackedVec3 transform to src
+
+ EXPECT_EQ(expect, str(got));  // exact text comparison against the golden output
+}
+
+TEST_F(PackedVec3Test, Mat3x3_F16_Uniform) {
+ // Test that array element alignment validation rules do not trigger when we rewrite an f16
+ // matrix into an array of vec3s in uniform storage.
+ auto* src = R"(
+enable f16;
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<uniform> m : mat3x3<f16>;
+
+fn g(p : ptr<uniform, mat3x3<f16>>) -> vec3<f16> {
+ return (*p)[0] + vec3<f16>(1);
+}
+
+fn f() {
+ let v = g(&m);
+}
+)";  // end of the input WGSL
+
+ auto* expect =
+ R"(
+enable chromium_internal_relaxed_uniform_layout;
+enable f16;
+enable chromium_experimental_full_ptr_parameters;
+
+struct tint_packed_vec3_f16_array_element {
+ @align(8)
+ elements : __packed_vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> m : array<tint_packed_vec3_f16_array_element, 3u>;
+
+fn g(p : ptr<uniform, array<tint_packed_vec3_f16_array_element, 3u>>) -> vec3<f16> {
+ return (vec3<f16>((*(p))[0].elements) + vec3<f16>(1));
+}
+
+fn f() {
+ let v = g(&(m));
+}
+)";  // end of the expected output: the f16 wrapper element uses @align(8), and g's pointer parameter type is rewritten too
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);  // apply the PackedVec3 transform to src
+
+ EXPECT_EQ(expect, str(got));  // exact text comparison against the golden output
+}
+
+TEST_F(PackedVec3Test, MultipleComponentTypes_StructMembers) {  // One vec3 member per component type (f32, f16, i32, u32): expect a matching packed member and unpack conversion for each.
+ auto* src = R"(
+enable f16;
+
+struct S {
+ f : vec3f,
+ h : vec3h,
+ i : vec3i,
+ u : vec3u,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+
+fn f() {
+ let f = s.f;
+ let h = s.h;
+ let i = s.i;
+ let u = s.u;
+}
+)";  // end of the input WGSL
+
+ auto* expect =
+ R"(
+enable chromium_internal_relaxed_uniform_layout;
+enable f16;
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ f : __packed_vec3<f32>,
+ @align(8)
+ h : __packed_vec3<f16>,
+ @align(16)
+ i : __packed_vec3<i32>,
+ @align(16)
+ u : __packed_vec3<u32>,
+}
+
+struct S {
+ f : vec3f,
+ h : vec3h,
+ i : vec3i,
+ u : vec3u,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S_tint_packed_vec3;
+
+fn f() {
+ let f = vec3<f32>(s.f);
+ let h = vec3<f16>(s.h);
+ let i = vec3<i32>(s.i);
+ let u = vec3<u32>(s.u);
+}
+)";  // end of the expected output: the f16 member gets @align(8), the others @align(16)
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);  // apply the PackedVec3 transform to src
+
+ EXPECT_EQ(expect, str(got));  // exact text comparison against the golden output
+}
+
+TEST_F(PackedVec3Test, MultipleComponentTypes_ArrayElement) {  // One runtime-sized vec3 array per component type: expect a distinct array-element wrapper struct for each of f32, f16, i32, u32.
+ auto* src = R"(
+enable f16;
+
+@group(0) @binding(0) var<storage, read_write> arr_f : array<vec3f>;
+@group(0) @binding(1) var<storage, read_write> arr_h : array<vec3h>;
+@group(0) @binding(2) var<storage, read_write> arr_i : array<vec3i>;
+@group(0) @binding(3) var<storage, read_write> arr_u : array<vec3u>;
+
+fn main() {
+ let f = arr_f[0];
+ let h = arr_h[0];
+ let i = arr_i[0];
+ let u = arr_u[0];
+}
+)";  // end of the input WGSL
+
+ auto* expect =
+ R"(
+enable chromium_internal_relaxed_uniform_layout;
+enable f16;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct tint_packed_vec3_f16_array_element {
+ @align(8)
+ elements : __packed_vec3<f16>,
+}
+
+struct tint_packed_vec3_i32_array_element {
+ @align(16)
+ elements : __packed_vec3<i32>,
+}
+
+struct tint_packed_vec3_u32_array_element {
+ @align(16)
+ elements : __packed_vec3<u32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> arr_f : array<tint_packed_vec3_f32_array_element>;
+
+@group(0) @binding(1) var<storage, read_write> arr_h : array<tint_packed_vec3_f16_array_element>;
+
+@group(0) @binding(2) var<storage, read_write> arr_i : array<tint_packed_vec3_i32_array_element>;
+
+@group(0) @binding(3) var<storage, read_write> arr_u : array<tint_packed_vec3_u32_array_element>;
+
+fn main() {
+ let f = vec3<f32>(arr_f[0].elements);
+ let h = vec3<f16>(arr_h[0].elements);
+ let i = vec3<i32>(arr_i[0].elements);
+ let u = vec3<u32>(arr_u[0].elements);
+}
+)";  // end of the expected output
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);  // apply the PackedVec3 transform to src
+
+ EXPECT_EQ(expect, str(got));  // exact text comparison against the golden output
+}
+
+TEST_F(PackedVec3Test, Arithmetic_FromRef) {  // Arithmetic operands read directly from storage buffers: expect every vec3 or matrix operand to be unpacked before the operator is applied.
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read_write> buffer_v : vec3<f32>;
+@group(0) @binding(1) var<storage, read_write> buffer_m : mat3x3<f32>;
+@group(0) @binding(2) var<storage, read_write> buffer_arr_v : array<vec3<f32>, 4>;
+@group(0) @binding(3) var<storage, read_write> buffer_arr_m : array<mat3x3<f32>, 4>;
+@group(0) @binding(4) var<storage, read_write> buffer_nested_arr_v : array<array<vec3<f32>, 4>, 4>;
+
+fn f() {
+ var v : vec3<f32> = buffer_v * 2;
+ v = -v;
+ v = buffer_m * v;
+ v = buffer_m[0] + v;
+ v = buffer_arr_v[0] + v;
+ v = buffer_arr_m[0] * v;
+ v = buffer_arr_m[0][1] + v;
+ v = buffer_nested_arr_v[0][0] + v;
+}
+)";  // end of the input WGSL
+
+ auto* expect =
+ R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+fn tint_unpack_vec3_in_composite(in : array<tint_packed_vec3_f32_array_element, 3u>) -> mat3x3<f32> {
+ var result : mat3x3<f32>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ result[i] = vec3<f32>(in[i].elements);
+ }
+ return result;
+}
+
+@group(0) @binding(0) var<storage, read_write> buffer_v : __packed_vec3<f32>;
+
+@group(0) @binding(1) var<storage, read_write> buffer_m : array<tint_packed_vec3_f32_array_element, 3u>;
+
+@group(0) @binding(2) var<storage, read_write> buffer_arr_v : array<tint_packed_vec3_f32_array_element, 4u>;
+
+@group(0) @binding(3) var<storage, read_write> buffer_arr_m : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>;
+
+@group(0) @binding(4) var<storage, read_write> buffer_nested_arr_v : array<array<tint_packed_vec3_f32_array_element, 4u>, 4u>;
+
+fn f() {
+ var v : vec3<f32> = (vec3<f32>(buffer_v) * 2);
+ v = -(v);
+ v = (tint_unpack_vec3_in_composite(buffer_m) * v);
+ v = (vec3<f32>(buffer_m[0].elements) + v);
+ v = (vec3<f32>(buffer_arr_v[0].elements) + v);
+ v = (tint_unpack_vec3_in_composite(buffer_arr_m[0]) * v);
+ v = (vec3<f32>(buffer_arr_m[0][1].elements) + v);
+ v = (vec3<f32>(buffer_nested_arr_v[0][0].elements) + v);
+}
+)";  // end of the expected output: a single matrix unpack helper is shared by buffer_m and buffer_arr_m[0]
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);  // apply the PackedVec3 transform to src
+
+ EXPECT_EQ(expect, str(got));  // exact text comparison against the golden output
+}
+
+TEST_F(PackedVec3Test, Arithmetic_FromValue) {  // Arithmetic on value-constructed operands: expect only the initial load of buffer_v to gain a conversion.
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read_write> buffer_v : vec3f;
+
+fn f() {
+ var v : vec3f = buffer_v;
+ v = -vec3f(1, 2, 3);
+ v = mat3x3f() * v;
+ v = mat3x3f()[0] + v;
+ v = array<vec3f, 4>()[0] + v;
+ v = array<mat3x3f, 4>()[0] * v;
+ v = array<mat3x3f, 4>()[0][1] + v;
+ v = array<array<vec3f, 4>, 4>()[0][0] + v;
+}
+)";  // end of the input WGSL
+
+ auto* expect =
+ R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+@group(0) @binding(0) var<storage, read_write> buffer_v : __packed_vec3<f32>;
+
+fn f() {
+ var v : vec3f = vec3<f32>(buffer_v);
+ v = -(vec3f(1, 2, 3));
+ v = (mat3x3f() * v);
+ v = (mat3x3f()[0] + v);
+ v = (array<vec3f, 4>()[0] + v);
+ v = (array<mat3x3f, 4>()[0] * v);
+ v = (array<mat3x3f, 4>()[0][1] + v);
+ v = (array<array<vec3f, 4>, 4>()[0][0] + v);
+}
+)";  // end of the expected output: the remaining statements differ from the input only by added parentheses
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);  // apply the PackedVec3 transform to src
+
+ EXPECT_EQ(expect, str(got));  // exact text comparison against the golden output
+}
+
+TEST_F(PackedVec3Test, Arithmetic_FromRefStruct) {  // Operands are read through a storage struct.
+ auto* src = R"(
+struct S {
+ v : vec3<f32>,
+ m : mat3x3<f32>,
+ arr_v : array<vec3<f32>, 4>,
+ arr_m : array<mat3x3<f32>, 4>,
+ nested_arr_v : array<array<vec3<f32>, 4>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> buffer : S;
+
+fn f() {
+ var v : vec3<f32> = buffer.v * 2;
+ v = -v;
+ v = buffer.m * v;
+ v = buffer.m[0] + v;
+ v = buffer.arr_v[0] + v;
+ v = buffer.arr_m[0] * v;
+ v = buffer.arr_m[0][1] + v;
+ v = buffer.nested_arr_v[0][0] + v;
+}
+)";
+
+ auto* expect =
+ R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ v : __packed_vec3<f32>,
+ @align(16)
+ m : array<tint_packed_vec3_f32_array_element, 3u>,
+ @align(16)
+ arr_v : array<tint_packed_vec3_f32_array_element, 4u>,
+ @align(16)
+ arr_m : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>,
+ @align(16)
+ nested_arr_v : array<array<tint_packed_vec3_f32_array_element, 4u>, 4u>,
+}
+
+fn tint_unpack_vec3_in_composite(in : array<tint_packed_vec3_f32_array_element, 3u>) -> mat3x3<f32> {
+ var result : mat3x3<f32>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ result[i] = vec3<f32>(in[i].elements);
+ }
+ return result;
+}
+
+struct S {
+ v : vec3<f32>,
+ m : mat3x3<f32>,
+ arr_v : array<vec3<f32>, 4>,
+ arr_m : array<mat3x3<f32>, 4>,
+ nested_arr_v : array<array<vec3<f32>, 4>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> buffer : S_tint_packed_vec3;
+
+fn f() {
+ var v : vec3<f32> = (vec3<f32>(buffer.v) * 2);
+ v = -(v);
+ v = (tint_unpack_vec3_in_composite(buffer.m) * v);
+ v = (vec3<f32>(buffer.m[0].elements) + v);
+ v = (vec3<f32>(buffer.arr_v[0].elements) + v);
+ v = (tint_unpack_vec3_in_composite(buffer.arr_m[0]) * v);
+ v = (vec3<f32>(buffer.arr_m[0][1].elements) + v);
+ v = (vec3<f32>(buffer.nested_arr_v[0][0].elements) + v);
+}
+)";  // Vectors are converted inline; whole matrices go through the unpack helper.
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, Arithmetic_FromValueStruct) {  // Operands come from S() values.
+ auto* src = R"(
+struct S {
+ v : vec3<f32>,
+ m : mat3x3<f32>,
+ arr_v : array<vec3<f32>, 4>,
+ arr_m : array<mat3x3<f32>, 4>,
+ nested_arr_v : array<array<vec3<f32>, 4>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> buffer : S;
+
+fn f() {
+ var v : vec3<f32> = S().v;
+ v = -S().v;
+ v = S().m * v;
+ v = S().m[0] + v;
+ v = S().arr_v[0] + v;
+ v = S().arr_m[0] * v;
+ v = S().arr_m[0][1] + v;
+ v = S().nested_arr_v[0][0] + v;
+}
+)";
+
+ auto* expect =
+ R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ v : __packed_vec3<f32>,
+ @align(16)
+ m : array<tint_packed_vec3_f32_array_element, 3u>,
+ @align(16)
+ arr_v : array<tint_packed_vec3_f32_array_element, 4u>,
+ @align(16)
+ arr_m : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>,
+ @align(16)
+ nested_arr_v : array<array<tint_packed_vec3_f32_array_element, 4u>, 4u>,
+}
+
+struct S {
+ v : vec3<f32>,
+ m : mat3x3<f32>,
+ arr_v : array<vec3<f32>, 4>,
+ arr_m : array<mat3x3<f32>, 4>,
+ nested_arr_v : array<array<vec3<f32>, 4>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> buffer : S_tint_packed_vec3;
+
+fn f() {
+ var v : vec3<f32> = S().v;
+ v = -(S().v);
+ v = (S().m * v);
+ v = (S().m[0] + v);
+ v = (S().arr_v[0] + v);
+ v = (S().arr_m[0] * v);
+ v = (S().arr_m[0][1] + v);
+ v = (S().nested_arr_v[0][0] + v);
+}
+)";  // Only the type of `buffer` changes; the S() expressions get no conversions.
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, Aliases) {  // Array types are reached through `alias` declarations.
+ auto* src = R"(
+alias VecArray = array<vec3<f32>, 4>;
+alias MatArray = array<mat3x3<f32>, 4>;
+alias NestedArray = array<VecArray, 4>;
+
+struct S {
+ v : VecArray,
+ m : MatArray,
+ n : NestedArray,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+@group(0) @binding(1) var<storage, read_write> arr_v : VecArray;
+@group(0) @binding(2) var<storage, read_write> arr_m : MatArray;
+@group(0) @binding(3) var<storage, read_write> arr_n : NestedArray;
+
+fn g(p : ptr<function, VecArray>) {
+}
+
+fn f() {
+ var f_arr_v : VecArray = s.v;
+ g(&f_arr_v);
+
+ arr_v = s.v;
+ arr_m[0] = s.m[0];
+ arr_n[1][2] = s.n[1][2];
+}
+)";
+
+ auto* expect =
+ R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct tint_packed_vec3_f32_array_element {
+ @align(16)
+ elements : __packed_vec3<f32>,
+}
+
+struct S_tint_packed_vec3 {
+ @align(16)
+ v : array<tint_packed_vec3_f32_array_element, 4u>,
+ @align(16)
+ m : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>,
+ @align(16)
+ n : array<array<tint_packed_vec3_f32_array_element, 4u>, 4u>,
+}
+
+fn tint_unpack_vec3_in_composite(in : array<tint_packed_vec3_f32_array_element, 4u>) -> array<vec3<f32>, 4u> {
+ var result : array<vec3<f32>, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ result[i] = vec3<f32>(in[i].elements);
+ }
+ return result;
+}
+
+alias VecArray = array<vec3<f32>, 4>;
+
+alias MatArray = array<mat3x3<f32>, 4>;
+
+alias NestedArray = array<VecArray, 4>;
+
+struct S {
+ v : VecArray,
+ m : MatArray,
+ n : NestedArray,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S_tint_packed_vec3;
+
+@group(0) @binding(1) var<storage, read_write> arr_v : array<tint_packed_vec3_f32_array_element, 4u>;
+
+@group(0) @binding(2) var<storage, read_write> arr_m : array<array<tint_packed_vec3_f32_array_element, 3u>, 4u>;
+
+@group(0) @binding(3) var<storage, read_write> arr_n : array<array<tint_packed_vec3_f32_array_element, 4u>, 4u>;
+
+fn g(p : ptr<function, VecArray>) {
+}
+
+fn f() {
+ var f_arr_v : VecArray = tint_unpack_vec3_in_composite(s.v);
+ g(&(f_arr_v));
+ arr_v = s.v;
+ arr_m[0] = s.m[0];
+ arr_n[1][2].elements = s.n[1][2].elements;
+}
+)";  // Storage declarations spell out the packed arrays; the aliases themselves are kept.
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PackedVec3Test, Vec3Bool) {
+ // Make sure that we don't rewrite vec3<bool> types, as the `packed_bool<n>` types are reserved
+ // in MSL and might not be supported everywhere.
+ auto* src = R"(
+struct S {
+ vf : vec3<f32>,
+ af : array<vec3<f32>, 4>,
+ vb : vec3<bool>,
+ ab : array<vec3<bool>, 4>,
+}
+
+// Create a vec3 storage buffer so that the transform is not skipped.
+@group(0) @binding(0) var<storage, read_write> buffer : vec3<f32>;
+
+fn f() {
+ var f : S;
+ f.vf = buffer;
+ buffer = f.af[0];
+}
+)";
+
+ auto* expect =
+ R"(
+enable chromium_internal_relaxed_uniform_layout;
+
+struct S {
+ vf : vec3<f32>,
+ af : array<vec3<f32>, 4>,
+ vb : vec3<bool>,
+ ab : array<vec3<bool>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> buffer : __packed_vec3<f32>;
+
+fn f() {
+ var f : S;
+ f.vf = vec3<f32>(buffer);
+ buffer = __packed_vec3<f32>(f.af[0]);
+}
+)";  // S is emitted unchanged and no packed copy of it is generated.
+
+ Transform::DataMap data;
+ auto got = Run<PackedVec3>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/pad_structs.cc b/src/tint/lang/wgsl/ast/transform/pad_structs.cc
new file mode 100644
index 0000000..d276b15
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/pad_structs.cc
@@ -0,0 +1,159 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/pad_structs.h"
+
+#include <string>
+#include <unordered_map>
+#include <utility>
+
+#include "src/tint/lang/wgsl/ast/disable_validation_attribute.h"
+#include "src/tint/lang/wgsl/ast/parameter.h"
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/call.h"
+#include "src/tint/sem/module.h"
+#include "src/tint/sem/value_constructor.h"
+
+using namespace tint::number_suffixes; // NOLINT
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::PadStructs);
+
+namespace tint::ast::transform {
+
+namespace {
+
+// Appends bytes / 4 uniquely named u32 "pad" members to new_members and padding_members.
+void CreatePadding(utils::Vector<const StructMember*, 8>* new_members,
+                   utils::Hashset<const StructMember*, 8>* padding_members,
+                   ProgramBuilder* b, uint32_t bytes) {
+    const size_t num_words = bytes / 4u;
+    padding_members->Reserve(num_words);
+    new_members->Reserve(num_words);
+    for (size_t remaining = num_words; remaining != 0; --remaining) {
+        auto pad_name = b->Symbols().New("pad");
+        auto* pad = b->Member(pad_name, b->ty.u32());
+        padding_members->Add(pad);
+        new_members->Push(pad);
+    }
+}
+
+} // namespace
+
+PadStructs::PadStructs() = default;  // The transform declares no data members to initialize.
+
+PadStructs::~PadStructs() = default;  // Nothing to release.
+
+// Rewrites each host-shareable struct so that implicit layout padding becomes explicit u32 `pad`
+// members, and inserts matching `0u` arguments into non-empty constructor calls of those structs.
+Transform::ApplyResult PadStructs::Apply(const Program* src, const DataMap&, DataMap&) const {
+    ProgramBuilder b;
+    CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+    auto& sem = src->Sem();
+
+    std::unordered_map<const Struct*, const Struct*> replaced_structs;
+    utils::Hashset<const StructMember*, 8> padding_members;
+
+    ctx.ReplaceAll([&](const Struct* ast_str) -> const Struct* {
+        auto* str = sem.Get(ast_str);
+        if (!str || !str->IsHostShareable()) {
+            return nullptr;
+        }
+        uint32_t offset = 0;
+        size_t num_source_members = 0;
+        bool has_runtime_sized_array = false;
+        utils::Vector<const StructMember*, 8> new_members;
+        for (auto* mem : str->Members()) {
+            auto name = mem->Name().Name();
+            if (offset < mem->Offset()) {
+                CreatePadding(&new_members, &padding_members, ctx.dst, mem->Offset() - offset);
+                offset = mem->Offset();
+            }
+            auto* ty = mem->Type();
+            auto type = CreateASTTypeFor(ctx, ty);
+            new_members.Push(b.Member(name, type));
+            ++num_source_members;
+            uint32_t size = ty->Size();
+            if (ty->Is<type::Struct>() && str->UsedAs(builtin::AddressSpace::kUniform)) {
+                // std140 structs should be padded out to 16 bytes.
+                size = utils::RoundUp(16u, size);
+            } else if (auto* array_ty = ty->As<type::Array>()) {
+                if (array_ty->Count()->Is<type::RuntimeArrayCount>()) {
+                    has_runtime_sized_array = true;
+                }
+            }
+            offset += size;
+        }
+
+        // Add any required padding after the last member, if it's not a runtime-sized array.
+        uint32_t struct_size = str->Size();
+        if (str->UsedAs(builtin::AddressSpace::kUniform)) {
+            struct_size = utils::RoundUp(16u, struct_size);
+        }
+        if (offset < struct_size && !has_runtime_sized_array) {
+            CreatePadding(&new_members, &padding_members, ctx.dst, struct_size - offset);
+        }
+
+        // Only structs that actually gained padding members need the member-limit validation
+        // disabled; `padding_members` is shared by all structs, so it cannot answer that.
+        utils::Vector<const Attribute*, 1> struct_attribs;
+        if (new_members.Length() > num_source_members) {
+            struct_attribs = utils::Vector{b.Disable(DisabledValidation::kIgnoreStructMemberLimit)};
+        }
+
+        auto* new_struct = b.create<Struct>(ctx.Clone(ast_str->name), std::move(new_members),
+                                            std::move(struct_attribs));
+        replaced_structs[ast_str] = new_struct;
+        return new_struct;
+    });
+
+    ctx.ReplaceAll([&](const CallExpression* ast_call) -> const CallExpression* {
+        if (ast_call->args.Length() == 0) {
+            return nullptr;  // Calls without arguments are left untouched.
+        }
+        auto* call = sem.Get<sem::Call>(ast_call);
+        if (!call) {
+            return nullptr;
+        }
+        auto* cons = call->Target()->As<sem::ValueConstructor>();
+        if (!cons) {
+            return nullptr;
+        }
+        auto* str = cons->ReturnType()->As<sem::Struct>();
+        if (!str) {
+            return nullptr;
+        }
+        // Look up without operator[], which would insert a null entry for every other struct.
+        auto replaced = replaced_structs.find(str->Declaration());
+        if (replaced == replaced_structs.end()) {
+            return nullptr;
+        }
+        // Interleave the original arguments with a zero for every padding member.
+        utils::Vector<const Expression*, 8> new_args;
+        auto* arg = ast_call->args.begin();
+        for (auto* member : replaced->second->members) {
+            if (padding_members.Contains(member)) {
+                new_args.Push(b.Expr(0_u));
+            } else {
+                new_args.Push(ctx.Clone(*arg));
+                arg++;
+            }
+        }
+        return b.Call(CreateASTTypeFor(ctx, str), new_args);
+    });
+
+    ctx.Clone();
+    return Program(std::move(b));
+}
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/pad_structs.h b/src/tint/lang/wgsl/ast/transform/pad_structs.h
new file mode 100644
index 0000000..64d42d1
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/pad_structs.h
@@ -0,0 +1,43 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_PAD_STRUCTS_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_PAD_STRUCTS_H_
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+
+namespace tint::ast::transform {
+
+/// This transform turns all explicit alignment and sizing into padding
+/// members of structs. This is required for GLSL ES, since it does not support
+/// the offset= decoration.
+///
+/// @note This transform requires the CanonicalizeEntryPointIO transform to have been run first.
+class PadStructs final : public utils::Castable<PadStructs, Transform> {
+ public:
+ /// Constructor
+ PadStructs();
+
+ /// Destructor
+ ~PadStructs() override;
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_PAD_STRUCTS_H_
diff --git a/src/tint/lang/wgsl/ast/transform/pad_structs_test.cc b/src/tint/lang/wgsl/ast/transform/pad_structs_test.cc
new file mode 100644
index 0000000..03c8baa
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/pad_structs_test.cc
@@ -0,0 +1,608 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/pad_structs.h"
+
+#include <memory>
+#include <utility>
+
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using PadStructsTest = TransformTest;  // Fixture name used by the TEST_F cases in this file.
+
+TEST_F(PadStructsTest, EmptyModule) {  // An empty module must pass through unchanged.
+ auto* src = "";
+ auto* expect = src;
+
+ Transform::DataMap data;
+ auto got = Run<PadStructs>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PadStructsTest, Uniform) {  // A 4-byte uniform struct gains three u32 pads (16 bytes).
+ auto* src = R"(
+struct S {
+ x : i32,
+}
+
+@group(0) @binding(0) var<uniform> u : S;
+
+fn main() {
+ let x = u.x;
+}
+)";
+ auto* expect = R"(
+@internal(disable_validation__ignore_struct_member)
+struct S {
+ x : i32,
+ pad : u32,
+ pad_1 : u32,
+ pad_2 : u32,
+}
+
+@group(0) @binding(0) var<uniform> u : S;
+
+fn main() {
+ let x = u.x;
+}
+)";
+ Transform::DataMap data;
+ auto got = Run<PadStructs>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PadStructsTest, Size) {  // @size(12) on x turns into two u32 pads between x and y.
+ auto* src = R"(
+struct S {
+ @size(12)
+ x : i32,
+ y : i32,
+}
+
+@group(0) @binding(0) var<uniform> u : S;
+
+fn main() {
+ let x = u.x;
+}
+)";
+ auto* expect = R"(
+@internal(disable_validation__ignore_struct_member)
+struct S {
+ x : i32,
+ pad : u32,
+ pad_1 : u32,
+ y : i32,
+}
+
+@group(0) @binding(0) var<uniform> u : S;
+
+fn main() {
+ let x = u.x;
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PadStructs>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PadStructsTest, SizeUniformAndPrivate) {  // S is padded although a private var uses it too.
+ auto* src = R"(
+struct S {
+ @size(12)
+ x : i32,
+ y : i32,
+}
+
+@group(0) @binding(0) var<uniform> u : S;
+
+var<private> p : S;
+
+fn main() {
+ p.x = u.x;
+}
+)";
+ auto* expect = R"(
+@internal(disable_validation__ignore_struct_member)
+struct S {
+ x : i32,
+ pad : u32,
+ pad_1 : u32,
+ y : i32,
+}
+
+@group(0) @binding(0) var<uniform> u : S;
+
+var<private> p : S;
+
+fn main() {
+ p.x = u.x;
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PadStructs>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PadStructsTest, SizeStorageAndPrivate) {  // A storage-buffer use also triggers padding of S.
+ auto* src = R"(
+struct S {
+ @size(12)
+ x : i32,
+ y : i32,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+
+var<private> p : S;
+
+fn main() {
+ p.x = 123;
+ s.x = p.x;
+}
+)";
+ auto* expect = R"(
+@internal(disable_validation__ignore_struct_member)
+struct S {
+ x : i32,
+ pad : u32,
+ pad_1 : u32,
+ y : i32,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+
+var<private> p : S;
+
+fn main() {
+ p.x = 123;
+ s.x = p.x;
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PadStructs>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PadStructsTest, SizeUniformAndStorage) {  // One padded S serves both buffer variables.
+ auto* src = R"(
+struct S {
+ @size(12)
+ x : i32,
+ y : i32,
+}
+
+@group(0) @binding(0) var<uniform> u : S;
+
+@group(0) @binding(1) var<storage, read_write> s : S;
+
+fn main() {
+ s.x = u.x;
+}
+)";
+ auto* expect = R"(
+@internal(disable_validation__ignore_struct_member)
+struct S {
+ x : i32,
+ pad : u32,
+ pad_1 : u32,
+ y : i32,
+}
+
+@group(0) @binding(0) var<uniform> u : S;
+
+@group(0) @binding(1) var<storage, read_write> s : S;
+
+fn main() {
+ s.x = u.x;
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PadStructs>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PadStructsTest, SizePrivateOnly) {
+ // Structs that are not host-visible should have no explicit padding.
+ auto* src = R"(
+struct S {
+ @size(12)
+ x : i32,
+ y : i32,
+}
+
+var<private> p : S;
+
+fn main() {
+ p.x = 123;
+}
+)";
+ auto* expect = R"(
+struct S {
+ @size(12)
+ x : i32,
+ y : i32,
+}
+
+var<private> p : S;
+
+fn main() {
+ p.x = 123;
+}
+)";  // The @size(12) attribute is expected to survive as written.
+
+ Transform::DataMap data;
+ auto got = Run<PadStructs>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PadStructsTest, AlignUniformAndPrivate) {  // @align(16) on b: three pads before and after b.
+ auto* src = R"(
+struct S {
+ a : i32,
+ @align(16)
+ b : i32,
+}
+
+@group(0) @binding(0) var<uniform> u : S;
+
+var<private> p : S;
+
+fn main() {
+ p.a = u.b;
+ p.b = u.a;
+}
+)";
+ auto* expect = R"(
+@internal(disable_validation__ignore_struct_member)
+struct S {
+ a : i32,
+ pad : u32,
+ pad_1 : u32,
+ pad_2 : u32,
+ b : i32,
+ pad_3 : u32,
+ pad_4 : u32,
+ pad_5 : u32,
+}
+
+@group(0) @binding(0) var<uniform> u : S;
+
+var<private> p : S;
+
+fn main() {
+ p.a = u.b;
+ p.b = u.a;
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PadStructs>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PadStructsTest, AlignStorageAndPrivate) {  // @align(16) becomes pads for a storage-used S.
+ auto* src = R"(
+struct S {
+ a : i32,
+ @align(16)
+ b : i32,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+
+var<private> p : S;
+
+fn main() {
+ p.a = 123;
+ p.b = 321;
+ s.a = p.b;
+ s.b = p.a;
+}
+)";
+ auto* expect = R"(
+@internal(disable_validation__ignore_struct_member)
+struct S {
+ a : i32,
+ pad : u32,
+ pad_1 : u32,
+ pad_2 : u32,
+ b : i32,
+ pad_3 : u32,
+ pad_4 : u32,
+ pad_5 : u32,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+
+var<private> p : S;
+
+fn main() {
+ p.a = 123;
+ p.b = 321;
+ s.a = p.b;
+ s.b = p.a;
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PadStructs>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PadStructsTest, AlignUniformAndStorage) {  // Uniform and storage uses share one padded S.
+ auto* src = R"(
+struct S {
+ a : i32,
+ @align(16)
+ b : i32,
+}
+
+@group(0) @binding(0) var<uniform> u : S;
+
+@group(0) @binding(1) var<storage, read_write> s : S;
+
+fn main() {
+ s.a = u.b;
+ s.b = u.a;
+}
+)";
+ auto* expect = R"(
+@internal(disable_validation__ignore_struct_member)
+struct S {
+ a : i32,
+ pad : u32,
+ pad_1 : u32,
+ pad_2 : u32,
+ b : i32,
+ pad_3 : u32,
+ pad_4 : u32,
+ pad_5 : u32,
+}
+
+@group(0) @binding(0) var<uniform> u : S;
+
+@group(0) @binding(1) var<storage, read_write> s : S;
+
+fn main() {
+ s.a = u.b;
+ s.b = u.a;
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<PadStructs>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PadStructsTest, AlignPrivateOnly) {
+ // Structs that are not host-visible should have no explicit padding.
+ auto* src = R"(
+struct S {
+ a : i32,
+ @align(16)
+ b : i32,
+}
+
+var<private> p : S;
+
+fn main() {
+ p.a = 123;
+ p.b = 321;
+}
+)";
+ auto* expect = R"(
+struct S {
+ a : i32,
+ @align(16)
+ b : i32,
+}
+
+var<private> p : S;
+
+fn main() {
+ p.a = 123;
+ p.b = 321;
+}
+)";  // The @align(16) attribute is expected to survive as written.
+
+ Transform::DataMap data;
+ auto got = Run<PadStructs>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PadStructsTest, LastMemberRuntimeSizeArray) {
+ // Structs with runtime-sized arrays should not be padded after the
+ // last member.
+ auto* src = R"(
+struct T {
+ a : f32,
+ b : i32,
+}
+
+struct S {
+ a : vec4<f32>,
+ b : array<T>,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+
+fn main() {
+ s.b[0] = T(1.0f, 23);
+}
+)";
+ auto* expect = R"(
+struct T {
+ a : f32,
+ b : i32,
+}
+
+struct S {
+ a : vec4<f32>,
+ b : array<T>,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+
+fn main() {
+ s.b[0] = T(1.0f, 23);
+}
+)";  // Neither T nor S gains members, and the T(...) call keeps its two arguments.
+
+ Transform::DataMap data;
+ auto got = Run<PadStructs>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PadStructsTest, LastMemberFixedSizeArray) {
+ // Structs without runtime-sized arrays should be padded after the last
+ // member.
+ auto* src = R"(
+struct T {
+ a : f32,
+ b : i32,
+}
+
+struct S {
+ a : vec4<f32>,
+ b : array<T, 1u>,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+
+fn main() {
+ s.b[0] = T(1.0f, 23);
+}
+)";
+ auto* expect = R"(
+struct T {
+ a : f32,
+ b : i32,
+}
+
+@internal(disable_validation__ignore_struct_member)
+struct S {
+ a : vec4<f32>,
+ b : array<T, 1u>,
+ pad : u32,
+ pad_1 : u32,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+
+fn main() {
+ s.b[0] = T(1.0f, 23);
+}
+)";  // S gains two trailing pads; T is emitted without pads or the internal attribute.
+
+ Transform::DataMap data;
+ auto got = Run<PadStructs>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PadStructsTest, Initializer) {
+ // Calls to an initializer of a padded struct must be modified to initialize the padding.
+ auto* src = R"(
+struct S {
+ a : f32,
+ @align(8)
+ b : i32,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+
+fn main() {
+ s = S(1.0f, 2);
+}
+)";
+ auto* expect = R"(
+@internal(disable_validation__ignore_struct_member)
+struct S {
+ a : f32,
+ pad : u32,
+ b : i32,
+ pad_1 : u32,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+
+fn main() {
+ s = S(1.0f, 0u, 2, 0u);
+}
+)";  // A 0u argument is inserted at the position of each pad member.
+
+ Transform::DataMap data;
+ auto got = Run<PadStructs>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PadStructsTest, InitializerZeroArgs) {
+ // Calls to a zero-argument initializer of a padded struct should not be modified.
+ auto* src = R"(
+struct S {
+ a : f32,
+ @align(8)
+ b : i32,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+
+fn main() {
+ s = S();
+}
+)";
+ auto* expect = R"(
+@internal(disable_validation__ignore_struct_member)
+struct S {
+ a : f32,
+ pad : u32,
+ b : i32,
+ pad_1 : u32,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+
+fn main() {
+ s = S();
+}
+)";  // S itself is still padded; only the S() call is left as written.
+
+ Transform::DataMap data;
+ auto got = Run<PadStructs>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/preserve_padding.cc b/src/tint/lang/wgsl/ast/transform/preserve_padding.cc
new file mode 100644
index 0000000..ac48194
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/preserve_padding.cc
@@ -0,0 +1,247 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/preserve_padding.h"
+
+#include <unordered_set>
+#include <utility>
+
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/struct.h"
+#include "src/tint/switch.h"
+#include "src/tint/type/reference.h"
+#include "src/tint/utils/map.h"
+#include "src/tint/utils/vector.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::PreservePadding);
+
+namespace tint::ast::transform {
+
+using namespace tint::builtin::fluent_types; // NOLINT
+using namespace tint::number_suffixes; // NOLINT
+
+PreservePadding::PreservePadding() = default;  // All working state lives in State, built per run.
+
+PreservePadding::~PreservePadding() = default;  // Nothing to release.
+
+/// The PIMPL state for the PreservePadding transform
+struct PreservePadding::State {
+ /// Constructor
+ /// @param src the source Program
+ explicit State(const Program* src) : ctx{&b, src, /* auto_clone_symbols */ true} {}
+
+ /// The main function for the transform.
+ /// @returns the ApplyResult
+ ApplyResult Run() {
+ // Gather a list of assignments that need to be transformed.
+ std::unordered_set<const AssignmentStatement*> assignments_to_transform;
+ for (auto* node : ctx.src->ASTNodes().Objects()) {
+ Switch(
+ node, //
+ [&](const AssignmentStatement* assign) {
+ auto* ty = sem.GetVal(assign->lhs)->Type();  // NOTE(review): runs before the phony check; confirm GetVal() is non-null for a phony LHS
+ if (assign->lhs->Is<PhonyExpression>()) {
+ // Ignore phony assignment.
+ return;
+ }
+ if (ty->As<type::Reference>()->AddressSpace() !=
+ builtin::AddressSpace::kStorage) {
+ // We only care about assignments that write to variables in the storage
+ // address space, as nothing else is host-visible.
+ return;
+ }
+ if (HasPadding(ty->UnwrapRef())) {
+ // The assigned type has padding bytes, so we need to decompose the writes.
+ assignments_to_transform.insert(assign);
+ }
+ },
+ [&](const Enable* enable) {
+ // Check if the full pointer parameters extension is already enabled.
+ if (enable->HasExtension(
+ builtin::Extension::kChromiumExperimentalFullPtrParameters)) {
+ ext_enabled = true;
+ }
+ });
+ }
+ if (assignments_to_transform.empty()) {
+ return SkipTransform;
+ }
+
+ // Replace all assignments that include padding with decomposed versions.
+ ctx.ReplaceAll([&](const AssignmentStatement* assign) -> const Statement* {
+ if (!assignments_to_transform.count(assign)) {
+ return nullptr;
+ }
+ auto* ty = sem.GetVal(assign->lhs)->Type()->UnwrapRef();
+ return MakeAssignment(ty, ctx.Clone(assign->lhs), ctx.Clone(assign->rhs));
+ });
+
+ ctx.Clone();
+ return Program(std::move(b));
+ }
+
+ /// Create a statement that will perform the assignment `lhs = rhs`, creating and using helper
+ /// functions to decompose the assignment into element-wise copies if needed.
+ /// @param ty the type of the assignment
+ /// @param lhs the lhs expression (in the destination program)
+ /// @param rhs the rhs expression (in the destination program)
+ /// @returns the statement that performs the assignment
+ const Statement* MakeAssignment(const type::Type* ty,
+ const Expression* lhs,
+ const Expression* rhs) {
+ if (!HasPadding(ty)) {
+ // No padding - use a regular assignment.
+ return b.Assign(lhs, rhs);
+ }
+
+ // Call (and create if necessary) a helper function that assigns a composite using the
+ // statements in `body`. The helper will have the form:
+ // fn assign_helper_T(dest : ptr<storage, T, read_write>, value : T) {
+ // <body>
+ // }
+ // It will be called by passing a pointer to the original LHS:
+ // assign_helper_T(&lhs, rhs);
+ //
+ // Since this requires passing pointers to the storage address space, this will also enable
+ // the chromium_experimental_full_ptr_parameters extension.
+ const char* kDestParamName = "dest";
+ const char* kValueParamName = "value";
+ auto call_helper = [&](auto&& body) {
+ EnableExtension();
+ auto helper = helpers.GetOrCreate(ty, [&] {  // `body` is only invoked when no helper exists for `ty` yet.
+ auto helper_name = b.Symbols().New("assign_and_preserve_padding");
+ utils::Vector<const Parameter*, 2> params = {
+ b.Param(kDestParamName,
+ b.ty.ptr<storage, read_write>(CreateASTTypeFor(ctx, ty))),
+ b.Param(kValueParamName, CreateASTTypeFor(ctx, ty)),
+ };
+ b.Func(helper_name, params, b.ty.void_(), body());
+ return helper_name;
+ });
+ return b.CallStmt(b.Call(helper, b.AddressOf(lhs), rhs));
+ };
+
+ return Switch(
+ ty, //
+ [&](const type::Array* arr) {
+ // Call a helper function that uses a loop to assign each element separately.
+ return call_helper([&] {
+ utils::Vector<const Statement*, 8> body;
+ auto* idx = b.Var("i", b.Expr(0_u));
+ body.Push(  // NOTE(review): ConstantCount().value() assumes a fixed-size array; confirm
+ b.For(b.Decl(idx), b.LessThan(idx, u32(arr->ConstantCount().value())),
+ b.Assign(idx, b.Add(idx, 1_u)),
+ b.Block(MakeAssignment(arr->ElemType(),
+ b.IndexAccessor(b.Deref(kDestParamName), idx),
+ b.IndexAccessor(kValueParamName, idx)))));
+ return body;
+ });
+ },
+ [&](const type::Matrix* mat) {
+ // Call a helper function that assigns each column separately.
+ return call_helper([&] {
+ utils::Vector<const Statement*, 4> body;
+ for (uint32_t i = 0; i < mat->columns(); i++) {
+ body.Push(MakeAssignment(mat->ColumnType(),
+ b.IndexAccessor(b.Deref(kDestParamName), u32(i)),
+ b.IndexAccessor(kValueParamName, u32(i))));
+ }
+ return body;
+ });
+ },
+ [&](const type::Struct* str) {
+ // Call a helper function that assigns each member separately.
+ return call_helper([&] {
+ utils::Vector<const Statement*, 8> body;
+ for (auto member : str->Members()) {
+ auto name = member->Name().Name();
+ body.Push(MakeAssignment(member->Type(),
+ b.MemberAccessor(b.Deref(kDestParamName), name),
+ b.MemberAccessor(kValueParamName, name)));
+ }
+ return body;
+ });
+ },
+ [&](Default) {
+ TINT_ICE(Transform, b.Diagnostics()) << "unhandled type with padding";
+ return nullptr;
+ });
+ }
+
+ /// Checks if a type contains padding bytes.
+ /// @param ty the type to check
+ /// @returns true if `ty` (or any of its contained types) have padding bytes
+ bool HasPadding(const type::Type* ty) {
+ return Switch(
+ ty, //
+ [&](const type::Array* arr) {
+ auto* elem_ty = arr->ElemType();
+ if (elem_ty->Size() % elem_ty->Align() > 0) {  // Element size is not a multiple of its alignment.
+ return true;
+ }
+ return HasPadding(elem_ty);
+ },
+ [&](const type::Matrix* mat) {
+ auto* col_ty = mat->ColumnType();
+ if (mat->ColumnStride() > col_ty->Size()) {  // Each column is followed by unused bytes.
+ return true;
+ }
+ return HasPadding(col_ty);
+ },
+ [&](const type::Struct* str) {
+ uint32_t current_offset = 0;
+ for (auto* member : str->Members()) {
+ if (member->Offset() > current_offset) {  // Gap in front of this member.
+ return true;
+ }
+ if (HasPadding(member->Type())) {
+ return true;
+ }
+ current_offset += member->Type()->Size();
+ }
+ return (current_offset < str->Size());  // Gap after the last member.
+ },
+ [&](Default) { return false; });
+ }
+
+ /// Enable the full pointer parameters extension, if we have not already done so.
+ void EnableExtension() {
+ if (!ext_enabled) {
+ b.Enable(builtin::Extension::kChromiumExperimentalFullPtrParameters);
+ ext_enabled = true;
+ }
+ }
+
+ private:
+ /// The program builder
+ ProgramBuilder b;
+ /// The clone context
+ CloneContext ctx;
+ /// Alias to the semantic info in ctx.src
+ const sem::Info& sem = ctx.src->Sem();
+ /// Alias to the symbols in ctx.src (not referenced by any member function of State)
+ const SymbolTable& sym = ctx.src->Symbols();
+ /// Flag to track whether we have already enabled the full pointer parameters extension.
+ bool ext_enabled = false;
+ /// Map of semantic types to their assignment helper functions.
+ utils::Hashmap<const type::Type*, Symbol, 8> helpers;
+};
+
+Transform::ApplyResult PreservePadding::Apply(const Program* program,
+                                              const DataMap&, DataMap&) const {
+    State state{program};  // All of the work is done by the per-run PIMPL state.
+    return state.Run();
+}
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/preserve_padding.h b/src/tint/lang/wgsl/ast/transform/preserve_padding.h
new file mode 100644
index 0000000..4100013
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/preserve_padding.h
@@ -0,0 +1,47 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_PRESERVE_PADDING_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_PRESERVE_PADDING_H_
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+
+namespace tint::ast::transform {
+
+/// Decompose assignments of whole structure, array and matrix types to preserve padding bytes.
+///
+/// WGSL states that memory operations on structures and arrays will not access padding bytes. To
+/// avoid overwriting padding bytes when writing to buffers, this transform decomposes those
+/// assignments into element-wise assignments via helper functions.
+///
+/// @note Assumes that the DirectVariableTransform will be run afterwards for backends that need it.
+class PreservePadding final : public utils::Castable<PreservePadding, Transform> {
+ public:
+ /// Constructor
+ PreservePadding();
+ /// Destructor
+ ~PreservePadding() override;
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+
+ private:
+ struct State;  // implementation detail, defined alongside Apply() in the .cc file
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_PRESERVE_PADDING_H_
diff --git a/src/tint/lang/wgsl/ast/transform/preserve_padding_test.cc b/src/tint/lang/wgsl/ast/transform/preserve_padding_test.cc
new file mode 100644
index 0000000..85919e1
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/preserve_padding_test.cc
@@ -0,0 +1,779 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/preserve_padding.h"
+
+#include <utility>
+
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using PreservePaddingTest = TransformTest;
+
+TEST_F(PreservePaddingTest, ShouldRun_EmptyModule) {
+ auto* src = R"()";
+
+ EXPECT_FALSE(ShouldRun<PreservePadding>(src));
+}
+
+TEST_F(PreservePaddingTest, ShouldRun_NonStructVec3) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read_write> v : vec3<u32>;
+
+@compute @workgroup_size(1)
+fn foo() {
+ v = vec3<u32>();
+}
+ )";
+
+ EXPECT_FALSE(ShouldRun<PreservePadding>(src));
+}
+
+TEST_F(PreservePaddingTest, ShouldRun_StructWithoutPadding) {
+ auto* src = R"(
+struct S {
+ a : u32,
+ b : u32,
+ c : u32,
+ d : u32,
+ e : vec3<u32>,
+ f : u32,
+}
+
+@group(0) @binding(0) var<storage, read_write> v : S;
+
+@compute @workgroup_size(1)
+fn foo() {
+ v = S();
+}
+ )";
+
+ EXPECT_FALSE(ShouldRun<PreservePadding>(src));
+}
+
+TEST_F(PreservePaddingTest, ShouldRun_ArrayWithoutPadding) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read_write> v : array<vec4<u32>, 4>;
+
+@compute @workgroup_size(1)
+fn foo() {
+ v = array<vec4<u32>, 4>();
+}
+ )";
+
+ EXPECT_FALSE(ShouldRun<PreservePadding>(src));
+}
+
+TEST_F(PreservePaddingTest, EmptyModule) {
+ auto* src = R"()";
+
+ auto* expect = src;
+
+ auto got = Run<PreservePadding>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PreservePaddingTest, StructTrailingPadding) {
+ auto* src = R"(
+struct S {
+ a : u32,
+ b : u32,
+ c : u32,
+ d : u32,
+ e : vec3<u32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> v : S;
+
+@compute @workgroup_size(1)
+fn foo() {
+ v = S();
+}
+)";
+
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct S {
+ a : u32,
+ b : u32,
+ c : u32,
+ d : u32,
+ e : vec3<u32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> v : S;
+
+fn assign_and_preserve_padding(dest : ptr<storage, S, read_write>, value : S) {
+ (*(dest)).a = value.a;
+ (*(dest)).b = value.b;
+ (*(dest)).c = value.c;
+ (*(dest)).d = value.d;
+ (*(dest)).e = value.e;
+}
+
+@compute @workgroup_size(1)
+fn foo() {
+ assign_and_preserve_padding(&(v), S());
+}
+)";
+
+ auto got = Run<PreservePadding>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PreservePaddingTest, StructInternalPadding) {
+ auto* src = R"(
+struct S {
+ a : u32,
+ b : vec4<u32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> v : S;
+
+@compute @workgroup_size(1)
+fn foo() {
+ v = S();
+}
+)";
+
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct S {
+ a : u32,
+ b : vec4<u32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> v : S;
+
+fn assign_and_preserve_padding(dest : ptr<storage, S, read_write>, value : S) {
+ (*(dest)).a = value.a;
+ (*(dest)).b = value.b;
+}
+
+@compute @workgroup_size(1)
+fn foo() {
+ assign_and_preserve_padding(&(v), S());
+}
+)";
+
+ auto got = Run<PreservePadding>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PreservePaddingTest, StructExplicitSize_TrailingPadding) {
+ auto* src = R"(
+struct S {
+ @size(16) a : u32,
+}
+
+@group(0) @binding(0) var<storage, read_write> v : S;
+
+@compute @workgroup_size(1)
+fn foo() {
+ v = S();
+}
+)";
+
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct S {
+ @size(16)
+ a : u32,
+}
+
+@group(0) @binding(0) var<storage, read_write> v : S;
+
+fn assign_and_preserve_padding(dest : ptr<storage, S, read_write>, value : S) {
+ (*(dest)).a = value.a;
+}
+
+@compute @workgroup_size(1)
+fn foo() {
+ assign_and_preserve_padding(&(v), S());
+}
+)";
+
+ auto got = Run<PreservePadding>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PreservePaddingTest, StructExplicitSize_InternalPadding) {
+ auto* src = R"(
+struct S {
+ @size(16) a : u32,
+ b : u32,
+}
+
+@group(0) @binding(0) var<storage, read_write> v : S;
+
+@compute @workgroup_size(1)
+fn foo() {
+ v = S();
+}
+)";
+
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct S {
+ @size(16)
+ a : u32,
+ b : u32,
+}
+
+@group(0) @binding(0) var<storage, read_write> v : S;
+
+fn assign_and_preserve_padding(dest : ptr<storage, S, read_write>, value : S) {
+ (*(dest)).a = value.a;
+ (*(dest)).b = value.b;
+}
+
+@compute @workgroup_size(1)
+fn foo() {
+ assign_and_preserve_padding(&(v), S());
+}
+)";
+
+ auto got = Run<PreservePadding>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PreservePaddingTest, NestedStructs) {
+ auto* src = R"(
+struct S1 {
+ a1 : u32,
+ b1 : vec3<u32>,
+ c1 : u32,
+}
+
+struct S2 {
+ a2 : u32,
+ b2 : S1,
+ c2 : S1,
+}
+
+struct S3 {
+ a3 : S1,
+ b3 : S2,
+ c3 : S2,
+}
+
+@group(0) @binding(0) var<storage, read_write> v : S3;
+
+@compute @workgroup_size(1)
+fn foo() {
+ v = S3();
+}
+)";
+
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct S1 {
+ a1 : u32,
+ b1 : vec3<u32>,
+ c1 : u32,
+}
+
+struct S2 {
+ a2 : u32,
+ b2 : S1,
+ c2 : S1,
+}
+
+struct S3 {
+ a3 : S1,
+ b3 : S2,
+ c3 : S2,
+}
+
+@group(0) @binding(0) var<storage, read_write> v : S3;
+
+fn assign_and_preserve_padding_1(dest : ptr<storage, S1, read_write>, value : S1) {
+ (*(dest)).a1 = value.a1;
+ (*(dest)).b1 = value.b1;
+ (*(dest)).c1 = value.c1;
+}
+
+fn assign_and_preserve_padding_2(dest : ptr<storage, S2, read_write>, value : S2) {
+ (*(dest)).a2 = value.a2;
+ assign_and_preserve_padding_1(&((*(dest)).b2), value.b2);
+ assign_and_preserve_padding_1(&((*(dest)).c2), value.c2);
+}
+
+fn assign_and_preserve_padding(dest : ptr<storage, S3, read_write>, value : S3) {
+ assign_and_preserve_padding_1(&((*(dest)).a3), value.a3);
+ assign_and_preserve_padding_2(&((*(dest)).b3), value.b3);
+ assign_and_preserve_padding_2(&((*(dest)).c3), value.c3);
+}
+
+@compute @workgroup_size(1)
+fn foo() {
+ assign_and_preserve_padding(&(v), S3());
+}
+)";
+
+ auto got = Run<PreservePadding>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PreservePaddingTest, ArrayOfVec3) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read_write> v : array<vec3<u32>, 4>;
+
+@compute @workgroup_size(1)
+fn foo() {
+ v = array<vec3<u32>, 4>();
+}
+)";
+
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<storage, read_write> v : array<vec3<u32>, 4>;
+
+fn assign_and_preserve_padding(dest : ptr<storage, array<vec3<u32>, 4u>, read_write>, value : array<vec3<u32>, 4u>) {
+ for(var i = 0u; (i < 4u); i = (i + 1u)) {
+ (*(dest))[i] = value[i];
+ }
+}
+
+@compute @workgroup_size(1)
+fn foo() {
+ assign_and_preserve_padding(&(v), array<vec3<u32>, 4>());
+}
+)";
+
+ auto got = Run<PreservePadding>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PreservePaddingTest, ArrayOfArray) {
+ auto* src = R"(
+alias Array = array<array<vec3<u32>, 4>, 3>;
+
+@group(0) @binding(0) var<storage, read_write> v : Array;
+
+@compute @workgroup_size(1)
+fn foo() {
+ v = Array();
+}
+)";
+
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+alias Array = array<array<vec3<u32>, 4>, 3>;
+
+@group(0) @binding(0) var<storage, read_write> v : Array;
+
+fn assign_and_preserve_padding_1(dest : ptr<storage, array<vec3<u32>, 4u>, read_write>, value : array<vec3<u32>, 4u>) {
+ for(var i = 0u; (i < 4u); i = (i + 1u)) {
+ (*(dest))[i] = value[i];
+ }
+}
+
+fn assign_and_preserve_padding(dest : ptr<storage, array<array<vec3<u32>, 4u>, 3u>, read_write>, value : array<array<vec3<u32>, 4u>, 3u>) {
+ for(var i = 0u; (i < 3u); i = (i + 1u)) {
+ assign_and_preserve_padding_1(&((*(dest))[i]), value[i]);
+ }
+}
+
+@compute @workgroup_size(1)
+fn foo() {
+ assign_and_preserve_padding(&(v), Array());
+}
+)";
+
+ auto got = Run<PreservePadding>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PreservePaddingTest, ArrayOfStructOfArray) {
+ auto* src = R"(
+struct S {
+ a : u32,
+ b : array<vec3<u32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> v : array<S, 3>;
+
+@compute @workgroup_size(1)
+fn foo() {
+ v = array<S, 3>();
+}
+)";
+
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct S {
+ a : u32,
+ b : array<vec3<u32>, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> v : array<S, 3>;
+
+fn assign_and_preserve_padding_2(dest : ptr<storage, array<vec3<u32>, 4u>, read_write>, value : array<vec3<u32>, 4u>) {
+ for(var i = 0u; (i < 4u); i = (i + 1u)) {
+ (*(dest))[i] = value[i];
+ }
+}
+
+fn assign_and_preserve_padding_1(dest : ptr<storage, S, read_write>, value : S) {
+ (*(dest)).a = value.a;
+ assign_and_preserve_padding_2(&((*(dest)).b), value.b);
+}
+
+fn assign_and_preserve_padding(dest : ptr<storage, array<S, 3u>, read_write>, value : array<S, 3u>) {
+ for(var i = 0u; (i < 3u); i = (i + 1u)) {
+ assign_and_preserve_padding_1(&((*(dest))[i]), value[i]);
+ }
+}
+
+@compute @workgroup_size(1)
+fn foo() {
+ assign_and_preserve_padding(&(v), array<S, 3>());
+}
+)";
+
+ auto got = Run<PreservePadding>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PreservePaddingTest, Mat3x3) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read_write> m : mat3x3<f32>;
+
+@compute @workgroup_size(1)
+fn foo() {
+ m = mat3x3<f32>();
+}
+)";
+
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<storage, read_write> m : mat3x3<f32>;
+
+fn assign_and_preserve_padding(dest : ptr<storage, mat3x3<f32>, read_write>, value : mat3x3<f32>) {
+ (*(dest))[0u] = value[0u];
+ (*(dest))[1u] = value[1u];
+ (*(dest))[2u] = value[2u];
+}
+
+@compute @workgroup_size(1)
+fn foo() {
+ assign_and_preserve_padding(&(m), mat3x3<f32>());
+}
+)";
+
+ auto got = Run<PreservePadding>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PreservePaddingTest, Mat3x3_InStruct) {
+ auto* src = R"(
+struct S {
+ a : u32,
+ m : mat3x3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> buffer : S;
+
+@compute @workgroup_size(1)
+fn foo() {
+ buffer = S();
+}
+)";
+
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct S {
+ a : u32,
+ m : mat3x3<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> buffer : S;
+
+fn assign_and_preserve_padding_1(dest : ptr<storage, mat3x3<f32>, read_write>, value : mat3x3<f32>) {
+ (*(dest))[0u] = value[0u];
+ (*(dest))[1u] = value[1u];
+ (*(dest))[2u] = value[2u];
+}
+
+fn assign_and_preserve_padding(dest : ptr<storage, S, read_write>, value : S) {
+ (*(dest)).a = value.a;
+ assign_and_preserve_padding_1(&((*(dest)).m), value.m);
+}
+
+@compute @workgroup_size(1)
+fn foo() {
+ assign_and_preserve_padding(&(buffer), S());
+}
+)";
+
+ auto got = Run<PreservePadding>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PreservePaddingTest, ArrayOfMat3x3) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read_write> arr_m : array<mat3x3<f32>, 4>;
+
+@compute @workgroup_size(1)
+fn foo() {
+ arr_m = array<mat3x3<f32>, 4>();
+ arr_m[0] = mat3x3<f32>();
+}
+)";
+
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<storage, read_write> arr_m : array<mat3x3<f32>, 4>;
+
+fn assign_and_preserve_padding_1(dest : ptr<storage, mat3x3<f32>, read_write>, value : mat3x3<f32>) {
+ (*(dest))[0u] = value[0u];
+ (*(dest))[1u] = value[1u];
+ (*(dest))[2u] = value[2u];
+}
+
+fn assign_and_preserve_padding(dest : ptr<storage, array<mat3x3<f32>, 4u>, read_write>, value : array<mat3x3<f32>, 4u>) {
+ for(var i = 0u; (i < 4u); i = (i + 1u)) {
+ assign_and_preserve_padding_1(&((*(dest))[i]), value[i]);
+ }
+}
+
+@compute @workgroup_size(1)
+fn foo() {
+ assign_and_preserve_padding(&(arr_m), array<mat3x3<f32>, 4>());
+ assign_and_preserve_padding_1(&(arr_m[0]), mat3x3<f32>());
+}
+)";
+
+ auto got = Run<PreservePadding>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PreservePaddingTest, NoModify_Vec3) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read_write> v : vec3<u32>;
+
+@compute @workgroup_size(1)
+fn foo() {
+ v = vec3<u32>();
+}
+)";
+
+ auto* expect = src;
+
+ auto got = Run<PreservePadding>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PreservePaddingTest, AvoidDuplicateEnables) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct S {
+ @size(16) a : u32,
+}
+
+@group(0) @binding(0) var<storage, read_write> v : S;
+
+@compute @workgroup_size(1)
+fn foo() {
+ v = S();
+}
+)";
+
+ auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct S {
+ @size(16)
+ a : u32,
+}
+
+@group(0) @binding(0) var<storage, read_write> v : S;
+
+fn assign_and_preserve_padding(dest : ptr<storage, S, read_write>, value : S) {
+ (*(dest)).a = value.a;
+}
+
+@compute @workgroup_size(1)
+fn foo() {
+ assign_and_preserve_padding(&(v), S());
+}
+)";
+
+ auto got = Run<PreservePadding>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PreservePaddingTest, NoModify_StructNoPadding) {
+ auto* src = R"(
+struct S {
+ a : u32,
+ b : u32,
+ c : u32,
+ d : u32,
+ e : vec4<u32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> v : S;
+
+@compute @workgroup_size(1)
+fn foo() {
+ v = S();
+}
+)";
+
+ auto* expect = src;
+
+ auto got = Run<PreservePadding>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PreservePaddingTest, NoModify_ArrayNoPadding) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read_write> v : array<vec4<u32>, 4>;
+
+@compute @workgroup_size(1)
+fn foo() {
+ v = array<vec4<u32>, 4>();
+}
+)";
+
+ auto* expect = src;
+
+ auto got = Run<PreservePadding>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PreservePaddingTest, NoModify_ArrayOfStructNoPadding) {
+ auto* src = R"(
+struct S {
+ a : u32,
+ b : u32,
+ c : u32,
+ d : u32,
+ e : vec4<u32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> v : array<S, 4>;
+
+@compute @workgroup_size(1)
+fn foo() {
+ v = array<S, 4>();
+}
+)";
+
+ auto* expect = src;
+
+ auto got = Run<PreservePadding>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PreservePaddingTest, NoModify_Workgroup) {
+ auto* src = R"(
+struct S {
+ a : u32,
+ b : vec3<u32>,
+}
+
+var<workgroup> v : S;
+
+@compute @workgroup_size(1)
+fn foo() {
+ v = S();
+}
+)";
+
+ auto* expect = src;
+
+ auto got = Run<PreservePadding>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PreservePaddingTest, NoModify_Private) {
+ auto* src = R"(
+struct S {
+ a : u32,
+ b : vec3<u32>,
+}
+
+var<private> v : S;
+
+@compute @workgroup_size(1)
+fn foo() {
+ v = S();
+}
+)";
+
+ auto* expect = src;
+
+ auto got = Run<PreservePadding>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PreservePaddingTest, NoModify_Function) {
+ auto* src = R"(
+struct S {
+ a : u32,
+ b : vec3<u32>,
+}
+
+@compute @workgroup_size(1)
+fn foo() {
+ var<function> v : S;
+ v = S();
+}
+)";
+
+ auto* expect = src;
+
+ auto got = Run<PreservePadding>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/promote_initializers_to_let.cc b/src/tint/lang/wgsl/ast/transform/promote_initializers_to_let.cc
new file mode 100644
index 0000000..4569d99
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/promote_initializers_to_let.cc
@@ -0,0 +1,154 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/promote_initializers_to_let.h"
+
+#include <utility>
+
+#include "src/tint/lang/wgsl/ast/transform/utils/hoist_to_decl_before.h"
+#include "src/tint/lang/wgsl/ast/traverse_expressions.h"
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/call.h"
+#include "src/tint/sem/statement.h"
+#include "src/tint/sem/value_constructor.h"
+#include "src/tint/type/struct.h"
+#include "src/tint/utils/hashset.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::PromoteInitializersToLet);
+
+namespace tint::ast::transform {
+
+PromoteInitializersToLet::PromoteInitializersToLet() = default;
+
+PromoteInitializersToLet::~PromoteInitializersToLet() = default;
+
+Transform::ApplyResult PromoteInitializersToLet::Apply(const Program* src,
+ const DataMap&,  // inputs (unused)
+ DataMap&) const {  // outputs (unused)
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+
+ // Returns true if the expression should be hoisted to a new let statement before the
+ // expression's statement.
+ auto should_hoist = [&](const sem::ValueExpression* expr) {
+ if (!expr->Type()->IsAnyOf<type::Array, type::Struct>()) {
+ // We only care about array and struct initializers
+ return false;
+ }
+
+ // Check whether the expression is an array or structure constructor
+ {
+ // Follow const-chains back through the initializers of the referenced variables
+ auto* root_expr = expr;
+ if (expr->Stage() == sem::EvaluationStage::kConstant) {
+ if (expr->Type()->HoldsAbstract()) {
+ // Do not hoist expressions that are not materialized, as doing so would cause
+ // premature materialization.
+ return false;
+ }
+ while (auto* user = root_expr->UnwrapMaterialize()->As<sem::VariableUser>()) {
+ root_expr = user->Variable()->Initializer();
+ }
+ }
+
+ auto* ctor = root_expr->UnwrapMaterialize()->As<sem::Call>();
+ if (!ctor || !ctor->Target()->Is<sem::ValueConstructor>()) {
+ // Root expression is not a value constructor. Not interested in this.
+ return false;
+ }
+ }
+
+ if (auto* src_var_decl = expr->Stmt()->Declaration()->As<VariableDeclStatement>()) {
+ if (src_var_decl->variable->initializer == expr->Declaration()) {
+ // This statement is just a variable declaration with the initializer as the
+ // initializer value. This is what we're attempting to transform to, and so
+ // ignore.
+ return false;
+ }
+ }
+
+ return true;
+ };
+
+ // A list of expressions that should be hoisted.
+ utils::Vector<const sem::ValueExpression*, 32> to_hoist;
+ // A set of expressions that are constant, which _may_ need to be hoisted.
+ utils::Hashset<const Expression*, 32> const_chains;
+
+ // Walk the AST nodes. This order guarantees that leaf-expressions are visited first.
+ for (auto* node : src->ASTNodes().Objects()) {
+ if (auto* sem = src->Sem().GetVal(node)) {
+ auto* stmt = sem->Stmt();
+ if (!stmt) {
+ // Expression is outside of a statement. This usually means the expression is part
+ // of a global (module-scope) constant declaration. These must be constexpr, and so
+ // cannot contain the type of expressions that must be sanitized.
+ continue;
+ }
+
+ if (sem->Stage() == sem::EvaluationStage::kConstant) {
+ // Expression is constant. We only need to hoist expressions if they're the
+ // outermost constant expression in a chain. Remove the immediate child nodes of the
+ // expression from const_chains, and add this expression to the const_chains. As we
+ // visit leaf-expressions first, this means the content of const_chains only
+ // contains the outer-most constant expressions.
+ auto* expr = sem->Declaration();
+ bool ok = TraverseExpressions(expr, b.Diagnostics(), [&](const Expression* child) {
+ const_chains.Remove(child);
+ return child == expr ? TraverseAction::Descend : TraverseAction::Skip;
+ });
+ if (!ok) {
+ return Program(std::move(b));  // traversal reported failure; bail out early
+ }
+ const_chains.Add(expr);
+ } else if (should_hoist(sem)) {
+ to_hoist.Push(sem);
+ }
+ }
+ }
+
+ // After walking the full AST, const_chains only contains the outer-most constant expressions.
+ // Check if any of these need hoisting, and append those to to_hoist.
+ for (auto* expr : const_chains) {
+ if (auto* sem = src->Sem().GetVal(expr); should_hoist(sem)) {
+ to_hoist.Push(sem);
+ }
+ }
+
+ if (to_hoist.IsEmpty()) {
+ // Nothing to do. Skip.
+ return SkipTransform;
+ }
+
+ // The order of to_hoist is currently undefined. Sort by AST node id, which will make this
+ // deterministic.
+ to_hoist.Sort([&](auto* expr_a, auto* expr_b) {
+ return expr_a->Declaration()->node_id < expr_b->Declaration()->node_id;
+ });
+
+ // Hoist all the expressions in to_hoist to a 'let' variable, declared just before the
+ // statement of usage.
+ HoistToDeclBefore hoist_to_decl_before(ctx);
+ for (auto* expr : to_hoist) {
+ if (!hoist_to_decl_before.Add(expr, expr->Declaration(),
+ HoistToDeclBefore::VariableKind::kLet)) {
+ return Program(std::move(b));  // hoisting failed; return without cloning
+ }
+ }
+
+ ctx.Clone();  // clone src into b
+ return Program(std::move(b));
+}
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/promote_initializers_to_let.h b/src/tint/lang/wgsl/ast/transform/promote_initializers_to_let.h
new file mode 100644
index 0000000..bbc473b
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/promote_initializers_to_let.h
@@ -0,0 +1,44 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_PROMOTE_INITIALIZERS_TO_LET_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_PROMOTE_INITIALIZERS_TO_LET_H_
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+
+namespace tint::ast::transform {
+
+/// A transform that hoists array and structure initializers, and identifiers resolving to a
+/// 'const' array to a 'let' variable, declared just before the statement of usage.
+/// This transform is used by backends that do not support expressions that operate on an immediate
+/// array or structure. For example, the following is not immediately expressible for HLSL:
+/// `array<i32, 2>(1, 2)[0]`
+/// @see crbug.com/tint/406
+class PromoteInitializersToLet final : public utils::Castable<PromoteInitializersToLet, Transform> {
+ public:
+ /// Constructor
+ PromoteInitializersToLet();
+
+ /// Destructor
+ ~PromoteInitializersToLet() override;
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_PROMOTE_INITIALIZERS_TO_LET_H_
diff --git a/src/tint/lang/wgsl/ast/transform/promote_initializers_to_let_test.cc b/src/tint/lang/wgsl/ast/transform/promote_initializers_to_let_test.cc
new file mode 100644
index 0000000..57c919c
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/promote_initializers_to_let_test.cc
@@ -0,0 +1,1388 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/promote_initializers_to_let.h"
+
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using PromoteInitializersToLetTest = TransformTest;
+
+TEST_F(PromoteInitializersToLetTest, EmptyModule) {
+ auto* src = "";
+ auto* expect = "";
+
+ auto got = Run<PromoteInitializersToLet>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteInitializersToLetTest, BasicConstArray) {
+ auto* src = R"(
+fn f() {
+ const f0 = 1.0;
+ const f1 = 2.0;
+ const f2 = 3.0;
+ const f3 = 4.0;
+ var i = array<f32, 4u>(f0, f1, f2, f3)[2];
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<PromoteInitializersToLet>(src));
+}
+
+TEST_F(PromoteInitializersToLetTest, BasicRuntimeArray) {
+ auto* src = R"(
+fn f() {
+ var f0 = 1.0;
+ var f1 = 2.0;
+ var f2 = 3.0;
+ var f3 = 4.0;
+ var i = array<f32, 4u>(f0, f1, f2, f3)[2];
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ var f0 = 1.0;
+ var f1 = 2.0;
+ var f2 = 3.0;
+ var f3 = 4.0;
+ let tint_symbol : array<f32, 4u> = array<f32, 4u>(f0, f1, f2, f3);
+ var i = tint_symbol[2];
+}
+)";
+
+ auto got = Run<PromoteInitializersToLet>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteInitializersToLetTest, BasicConstStruct) {
+ auto* src = R"(
+struct S {
+ a : i32,
+ b : f32,
+ c : vec3<f32>,
+};
+
+fn f() {
+ var x = S(1, 2.0, vec3<f32>()).b;
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<PromoteInitializersToLet>(src));
+}
+
+TEST_F(PromoteInitializersToLetTest, BasicRuntimeStruct) {
+ auto* src = R"(
+struct S {
+ a : i32,
+ b : f32,
+ c : vec3<f32>,
+};
+
+fn f() {
+ let runtime_value = 1;
+ var x = S(runtime_value, 2.0, vec3<f32>()).b;
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ a : i32,
+ b : f32,
+ c : vec3<f32>,
+}
+
+fn f() {
+ let runtime_value = 1;
+ let tint_symbol : S = S(runtime_value, 2.0, vec3<f32>());
+ var x = tint_symbol.b;
+}
+)";
+
+ auto got = Run<PromoteInitializersToLet>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteInitializersToLetTest, BasicStruct_OutOfOrder) {
+ auto* src = R"(
+fn f() {
+ let runtime_value = 1;
+ var x = S(runtime_value, 2.0, vec3<f32>()).b;
+}
+
+struct S {
+ a : i32,
+ b : f32,
+ c : vec3<f32>,
+};
+)";
+
+ auto* expect = R"(
+fn f() {
+ let runtime_value = 1;
+ let tint_symbol : S = S(runtime_value, 2.0, vec3<f32>());
+ var x = tint_symbol.b;
+}
+
+struct S {
+ a : i32,
+ b : f32,
+ c : vec3<f32>,
+}
+)";
+
+ auto got = Run<PromoteInitializersToLet>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteInitializersToLetTest, GlobalConstBasicArray) {
+ auto* src = R"(
+const f0 = 1.0;
+
+const f1 = 2.0;
+
+const C = array<f32, 2u>(f0, f1);
+
+fn f() {
+ var f0 = 100.0;
+ var f1 = 100.0;
+ var i = C[1]; // Not hoisted, as the final const value is not an array
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<PromoteInitializersToLet>(src));
+}
+
+TEST_F(PromoteInitializersToLetTest, GlobalConstBasicArray_OutOfOrder) {
+ auto* src = R"(
+fn f() {
+ var f0 = 100.0;
+ var f1 = 100.0;
+ var i = C[1];
+}
+
+const C = array<f32, 2u>(f0, f1);
+
+const f0 = 1.0;
+
+const f1 = 2.0;
+)";
+
+ EXPECT_FALSE(ShouldRun<PromoteInitializersToLet>(src));
+}
+
+TEST_F(PromoteInitializersToLetTest, GlobalConstArrayDynamicIndex) {
+ auto* src = R"(
+const TRI_VERTICES = array(
+ vec4(0., 0., 0., 1.),
+ vec4(0., 1., 0., 1.),
+ vec4(1., 1., 0., 1.),
+);
+
+@vertex
+fn vs_main(@builtin(vertex_index) in_vertex_index: u32) -> @builtin(position) vec4<f32> {
+ // note: TRI_VERTICES requires a materialize before the dynamic index.
+ return TRI_VERTICES[in_vertex_index];
+}
+)";
+
+ auto* expect = R"(
+const TRI_VERTICES = array(vec4(0.0, 0.0, 0.0, 1.0), vec4(0.0, 1.0, 0.0, 1.0), vec4(1.0, 1.0, 0.0, 1.0));
+
+@vertex
+fn vs_main(@builtin(vertex_index) in_vertex_index : u32) -> @builtin(position) vec4<f32> {
+ let tint_symbol : array<vec4<f32>, 3u> = TRI_VERTICES;
+ return tint_symbol[in_vertex_index];
+}
+)";
+
+ auto got = Run<PromoteInitializersToLet>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteInitializersToLetTest, LocalConstBasicArray) {
+ auto* src = R"(
+fn f() {
+ const f0 = 1.0;
+ const f1 = 2.0;
+ const C = array<f32, 2u>(f0, f1);
+ var i = C[1]; // Not hoisted, as the final const value is not an array
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<PromoteInitializersToLet>(src));
+}
+
+TEST_F(PromoteInitializersToLetTest, LocalConstBasicArrayRuntimeIndex) {
+ auto* src = R"(
+fn f() {
+ const f0 = 1.0;
+ const f1 = 2.0;
+ const C = array<f32, 2u>(f0, f1);
+ let runtime_value = 1;
+ var i = C[runtime_value];
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ const f0 = 1.0;
+ const f1 = 2.0;
+ const C = array<f32, 2u>(f0, f1);
+ let runtime_value = 1;
+ let tint_symbol : array<f32, 2u> = C;
+ var i = tint_symbol[runtime_value];
+}
+)";
+
+ auto got = Run<PromoteInitializersToLet>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteInitializersToLetTest, ArrayInForLoopInit) {
+ auto* src = R"(
+fn f() {
+ var insert_after = 1;
+ for(var i = array<f32, 4u>(0.0, 1.0, 2.0, 3.0)[insert_after]; ; ) {
+ break;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ var insert_after = 1;
+ {
+ let tint_symbol : array<f32, 4u> = array<f32, 4u>(0.0, 1.0, 2.0, 3.0);
+ var i = tint_symbol[insert_after];
+ loop {
+ {
+ break;
+ }
+ }
+ }
+}
+)";
+
+ auto got = Run<PromoteInitializersToLet>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteInitializersToLetTest, LocalConstArrayInForLoopInit) {  // like ArrayInForLoopInit, but the runtime-indexed array is a function-scope const
+ auto* src = R"(
+fn f() {
+ const arr = array<f32, 4u>(0.0, 1.0, 2.0, 3.0);
+ let runtime_value = 1;
+ var insert_after = 1;
+ for(var i = arr[runtime_value]; ; ) {
+ break;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ const arr = array<f32, 4u>(0.0, 1.0, 2.0, 3.0);
+ let runtime_value = 1;
+ var insert_after = 1;
+ {
+ let tint_symbol : array<f32, 4u> = arr;
+ var i = tint_symbol[runtime_value];
+ loop {
+ {
+ break;
+ }
+ }
+ }
+}
+)";
+
+ auto got = Run<PromoteInitializersToLet>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteInitializersToLetTest, GlobalConstArrayInForLoopInit) {  // like ArrayInForLoopInit, but the runtime-indexed array is a module-scope const
+ auto* src = R"(
+const arr = array<f32, 4u>(0.0, 1.0, 2.0, 3.0);
+
+fn f() {
+ let runtime_value = 1;
+ var insert_after = 1;
+ for(var i = arr[runtime_value]; ; ) {
+ break;
+ }
+}
+)";
+
+ auto* expect = R"(
+const arr = array<f32, 4u>(0.0, 1.0, 2.0, 3.0);
+
+fn f() {
+ let runtime_value = 1;
+ var insert_after = 1;
+ {
+ let tint_symbol : array<f32, 4u> = arr;
+ var i = tint_symbol[runtime_value];
+ loop {
+ {
+ break;
+ }
+ }
+ }
+}
+)";
+
+ auto got = Run<PromoteInitializersToLet>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteInitializersToLetTest, StructInForLoopInit) {  // struct constructor passed as a call argument in a for-loop initializer: expects it bound to a let first
+ auto* src = R"(
+struct S {
+ a : i32,
+ b : f32,
+ c : vec3<f32>,
+};
+
+fn get_b_runtime(s : S) -> f32 {
+ return s.b;
+}
+
+fn f() {
+ var insert_after = 1;
+ for(var x = get_b_runtime(S(1, 2.0, vec3<f32>())); ; ) {
+ break;
+ }
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ a : i32,
+ b : f32,
+ c : vec3<f32>,
+}
+
+fn get_b_runtime(s : S) -> f32 {
+ return s.b;
+}
+
+fn f() {
+ var insert_after = 1;
+ {
+ let tint_symbol : S = S(1, 2.0, vec3<f32>());
+ var x = get_b_runtime(tint_symbol);
+ loop {
+ {
+ break;
+ }
+ }
+ }
+}
+)";
+
+ auto got = Run<PromoteInitializersToLet>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteInitializersToLetTest, StructInForLoopInit_OutOfOrder) {  // as StructInForLoopInit, with f, get_b_runtime and S declared in reverse order
+ auto* src = R"(
+fn f() {
+ var insert_after = 1;
+ for(var x = get_b_runtime(S(1, 2.0, vec3<f32>())); ; ) {
+ break;
+ }
+}
+
+fn get_b_runtime(s : S) -> f32 {
+ return s.b;
+}
+
+struct S {
+ a : i32,
+ b : f32,
+ c : vec3<f32>,
+};
+)";
+
+ auto* expect = R"(
+fn f() {
+ var insert_after = 1;
+ {
+ let tint_symbol : S = S(1, 2.0, vec3<f32>());
+ var x = get_b_runtime(tint_symbol);
+ loop {
+ {
+ break;
+ }
+ }
+ }
+}
+
+fn get_b_runtime(s : S) -> f32 {
+ return s.b;
+}
+
+struct S {
+ a : i32,
+ b : f32,
+ c : vec3<f32>,
+}
+)";
+
+ auto got = Run<PromoteInitializersToLet>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteInitializersToLetTest, ArrayInForLoopCond) {  // array constructor in a for-loop condition: expects the let at the top of the loop, ahead of the break check
+ auto* src = R"(
+fn f() {
+ var f = 1.0;
+ for(; f == array<f32, 1u>(f)[0]; f = f + 1.0) {
+ var marker = 1;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ var f = 1.0;
+ loop {
+ let tint_symbol : array<f32, 1u> = array<f32, 1u>(f);
+ if (!((f == tint_symbol[0]))) {
+ break;
+ }
+ {
+ var marker = 1;
+ }
+
+ continuing {
+ f = (f + 1.0);
+ }
+ }
+}
+)";
+
+ auto got = Run<PromoteInitializersToLet>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteInitializersToLetTest, LocalConstArrayInForLoopCond) {  // runtime-indexed function-scope const array in a for-loop condition
+ auto* src = R"(
+fn f() {
+ let runtime_value = 0;
+ const f = 1.0;
+ const arr = array<f32, 1u>(f);
+ for(var i = f; i == arr[runtime_value]; i = i + 1.0) {
+ var marker = 1;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ let runtime_value = 0;
+ const f = 1.0;
+ const arr = array<f32, 1u>(f);
+ {
+ var i = f;
+ loop {
+ let tint_symbol : array<f32, 1u> = arr;
+ if (!((i == tint_symbol[runtime_value]))) {
+ break;
+ }
+ {
+ var marker = 1;
+ }
+
+ continuing {
+ i = (i + 1.0);
+ }
+ }
+ }
+}
+)";
+
+ auto got = Run<PromoteInitializersToLet>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteInitializersToLetTest, GlobalConstArrayInForLoopCond) {  // runtime-indexed module-scope const array in a for-loop condition
+ auto* src = R"(
+const f = 1.0;
+
+const arr = array<f32, 1u>(f);
+
+fn F() {
+ let runtime_value = 0;
+ for(var i = f; i == arr[runtime_value]; i = i + 1.0) {
+ var marker = 1;
+ }
+}
+)";
+
+ auto* expect = R"(
+const f = 1.0;
+
+const arr = array<f32, 1u>(f);
+
+fn F() {
+ let runtime_value = 0;
+ {
+ var i = f;
+ loop {
+ let tint_symbol : array<f32, 1u> = arr;
+ if (!((i == tint_symbol[runtime_value]))) {
+ break;
+ }
+ {
+ var marker = 1;
+ }
+
+ continuing {
+ i = (i + 1.0);
+ }
+ }
+ }
+}
+)";
+
+ auto got = Run<PromoteInitializersToLet>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteInitializersToLetTest, ArrayInForLoopCont) {  // array constructor in a for-loop continuing expression: expects the let inside the continuing block
+ auto* src = R"(
+fn f() {
+ let runtime_value = 0;
+ var f = 0.0;
+ for(; f < 10.0; f = f + array<f32, 1u>(1.0)[runtime_value]) {
+ var marker = 1;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ let runtime_value = 0;
+ var f = 0.0;
+ loop {
+ if (!((f < 10.0))) {
+ break;
+ }
+ {
+ var marker = 1;
+ }
+
+ continuing {
+ let tint_symbol : array<f32, 1u> = array<f32, 1u>(1.0);
+ f = (f + tint_symbol[runtime_value]);
+ }
+ }
+}
+)";
+
+ auto got = Run<PromoteInitializersToLet>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteInitializersToLetTest, LocalConstArrayInForLoopCont) {  // runtime-indexed function-scope const array in a for-loop continuing expression
+ auto* src = R"(
+fn f() {
+ let runtime_value = 0;
+ const arr = array<f32, 1u>(1.0);
+ var f = 0.0;
+ for(; f < 10.0; f = f + arr[runtime_value]) {
+ var marker = 1;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ let runtime_value = 0;
+ const arr = array<f32, 1u>(1.0);
+ var f = 0.0;
+ loop {
+ if (!((f < 10.0))) {
+ break;
+ }
+ {
+ var marker = 1;
+ }
+
+ continuing {
+ let tint_symbol : array<f32, 1u> = arr;
+ f = (f + tint_symbol[runtime_value]);
+ }
+ }
+}
+)";
+
+ auto got = Run<PromoteInitializersToLet>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteInitializersToLetTest, GlobalConstArrayInForLoopCont) {  // runtime-indexed module-scope const array in a for-loop continuing expression
+ auto* src = R"(
+const arr = array<f32, 1u>(1.0);
+
+fn f() {
+ let runtime_value = 0;
+ var f = 0.0;
+ for(; f < 10.0; f = f + arr[runtime_value]) {
+ var marker = 1;
+ }
+}
+)";
+
+ auto* expect = R"(
+const arr = array<f32, 1u>(1.0);
+
+fn f() {
+ let runtime_value = 0;
+ var f = 0.0;
+ loop {
+ if (!((f < 10.0))) {
+ break;
+ }
+ {
+ var marker = 1;
+ }
+
+ continuing {
+ let tint_symbol : array<f32, 1u> = arr;
+ f = (f + tint_symbol[runtime_value]);
+ }
+ }
+}
+)";
+
+ auto got = Run<PromoteInitializersToLet>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteInitializersToLetTest, ArrayInForLoopInitCondCont) {  // array constructors in initializer, condition and continuing at once: expects three distinct lets
+ auto* src = R"(
+fn f() {
+ let runtime_value = 0;
+ for(var f = array<f32, 1u>(0.0)[runtime_value];
+ f < array<f32, 1u>(1.0)[runtime_value];
+ f = f + array<f32, 1u>(2.0)[runtime_value]) {
+ var marker = 1;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ let runtime_value = 0;
+ {
+ let tint_symbol : array<f32, 1u> = array<f32, 1u>(0.0);
+ var f = tint_symbol[runtime_value];
+ loop {
+ let tint_symbol_1 : array<f32, 1u> = array<f32, 1u>(1.0);
+ if (!((f < tint_symbol_1[runtime_value]))) {
+ break;
+ }
+ {
+ var marker = 1;
+ }
+
+ continuing {
+ let tint_symbol_2 : array<f32, 1u> = array<f32, 1u>(2.0);
+ f = (f + tint_symbol_2[runtime_value]);
+ }
+ }
+ }
+}
+)";
+
+ auto got = Run<PromoteInitializersToLet>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteInitializersToLetTest, LocalConstArrayInForLoopInitCondCont) {  // three function-scope const arrays, one each in initializer, condition and continuing
+ auto* src = R"(
+fn f() {
+ let runtime_value = 0;
+ const arr_a = array<f32, 1u>(0.0);
+ const arr_b = array<f32, 1u>(1.0);
+ const arr_c = array<f32, 1u>(2.0);
+ for(var f = arr_a[runtime_value]; f < arr_b[runtime_value]; f = f + arr_c[runtime_value]) {
+ var marker = 1;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ let runtime_value = 0;
+ const arr_a = array<f32, 1u>(0.0);
+ const arr_b = array<f32, 1u>(1.0);
+ const arr_c = array<f32, 1u>(2.0);
+ {
+ let tint_symbol : array<f32, 1u> = arr_a;
+ var f = tint_symbol[runtime_value];
+ loop {
+ let tint_symbol_1 : array<f32, 1u> = arr_b;
+ if (!((f < tint_symbol_1[runtime_value]))) {
+ break;
+ }
+ {
+ var marker = 1;
+ }
+
+ continuing {
+ let tint_symbol_2 : array<f32, 1u> = arr_c;
+ f = (f + tint_symbol_2[runtime_value]);
+ }
+ }
+ }
+}
+)";
+
+ auto got = Run<PromoteInitializersToLet>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteInitializersToLetTest, ArrayInElseIf) {  // array constructor in an else-if condition: expects 'else if' split into 'else { let; if }'
+ auto* src = R"(
+fn f() {
+ var f = 1.0;
+ if (true) {
+ var marker = 0;
+ } else if (f == array<f32, 2u>(f, f)[0]) {
+ var marker = 1;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ var f = 1.0;
+ if (true) {
+ var marker = 0;
+ } else {
+ let tint_symbol : array<f32, 2u> = array<f32, 2u>(f, f);
+ if ((f == tint_symbol[0])) {
+ var marker = 1;
+ }
+ }
+}
+)";
+
+ auto got = Run<PromoteInitializersToLet>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteInitializersToLetTest, ArrayInElseIfChain) {  // only the else-if arms whose condition builds an array are split; the other arms keep their 'else if' form
+ auto* src = R"(
+fn f() {
+ var f = 1.0;
+ if (true) {
+ var marker = 0;
+ } else if (true) {
+ var marker = 1;
+ } else if (f == array<f32, 2u>(f, f)[0]) {
+ var marker = 2;
+ } else if (f == array<f32, 2u>(f, f)[1]) {
+ var marker = 3;
+ } else if (true) {
+ var marker = 4;
+ } else {
+ var marker = 5;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ var f = 1.0;
+ if (true) {
+ var marker = 0;
+ } else if (true) {
+ var marker = 1;
+ } else {
+ let tint_symbol : array<f32, 2u> = array<f32, 2u>(f, f);
+ if ((f == tint_symbol[0])) {
+ var marker = 2;
+ } else {
+ let tint_symbol_1 : array<f32, 2u> = array<f32, 2u>(f, f);
+ if ((f == tint_symbol_1[1])) {
+ var marker = 3;
+ } else if (true) {
+ var marker = 4;
+ } else {
+ var marker = 5;
+ }
+ }
+ }
+}
+)";
+
+ auto got = Run<PromoteInitializersToLet>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteInitializersToLetTest, LocalConstArrayInElseIfChain) {  // as ArrayInElseIfChain, with a runtime-indexed function-scope const array
+ auto* src = R"(
+fn f() {
+ let runtime_value = 0;
+ const f = 1.0;
+ const arr = array<f32, 2u>(f, f);
+ if (true) {
+ var marker = 0;
+ } else if (true) {
+ var marker = 1;
+ } else if (f == arr[runtime_value]) {
+ var marker = 2;
+ } else if (f == arr[runtime_value + 1]) {
+ var marker = 3;
+ } else if (true) {
+ var marker = 4;
+ } else {
+ var marker = 5;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ let runtime_value = 0;
+ const f = 1.0;
+ const arr = array<f32, 2u>(f, f);
+ if (true) {
+ var marker = 0;
+ } else if (true) {
+ var marker = 1;
+ } else {
+ let tint_symbol : array<f32, 2u> = arr;
+ if ((f == tint_symbol[runtime_value])) {
+ var marker = 2;
+ } else {
+ let tint_symbol_1 : array<f32, 2u> = arr;
+ if ((f == tint_symbol_1[(runtime_value + 1)])) {
+ var marker = 3;
+ } else if (true) {
+ var marker = 4;
+ } else {
+ var marker = 5;
+ }
+ }
+ }
+}
+)";
+
+ auto got = Run<PromoteInitializersToLet>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteInitializersToLetTest, GlobalConstArrayInElseIfChain) {  // as ArrayInElseIfChain, with a runtime-indexed module-scope const array
+ auto* src = R"(
+const f = 1.0;
+
+const arr = array<f32, 2u>(f, f);
+
+fn F() {
+ let runtime_value = 0;
+ if (true) {
+ var marker = 0;
+ } else if (true) {
+ var marker = 1;
+ } else if (f == arr[runtime_value]) {
+ var marker = 2;
+ } else if (f == arr[runtime_value + 1]) {
+ var marker = 3;
+ } else if (true) {
+ var marker = 4;
+ } else {
+ var marker = 5;
+ }
+}
+)";
+
+ auto* expect = R"(
+const f = 1.0;
+
+const arr = array<f32, 2u>(f, f);
+
+fn F() {
+ let runtime_value = 0;
+ if (true) {
+ var marker = 0;
+ } else if (true) {
+ var marker = 1;
+ } else {
+ let tint_symbol : array<f32, 2u> = arr;
+ if ((f == tint_symbol[runtime_value])) {
+ var marker = 2;
+ } else {
+ let tint_symbol_1 : array<f32, 2u> = arr;
+ if ((f == tint_symbol_1[(runtime_value + 1)])) {
+ var marker = 3;
+ } else if (true) {
+ var marker = 4;
+ } else {
+ var marker = 5;
+ }
+ }
+ }
+}
+)";
+
+ auto got = Run<PromoteInitializersToLet>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteInitializersToLetTest, ArrayInArrayArrayConstIndex) {  // constant indices into a nested array constructor: nothing to promote
+ auto* src = R"(
+fn f() {
+ var i = array<array<f32, 2u>, 2u>(array<f32, 2u>(1.0, 2.0), array<f32, 2u>(3.0, 4.0))[0][1];
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<PromoteInitializersToLet>(src));  // the transform must report it has no work to do
+}
+
+TEST_F(PromoteInitializersToLetTest, ArrayInArrayArrayRuntimeIndex) {  // runtime indices into a nested array constructor: expects one let holding the whole outer array
+ auto* src = R"(
+fn f() {
+ let runtime_value = 1;
+ var i = array<array<f32, 2u>, 2u>(array<f32, 2u>(1.0, 2.0), array<f32, 2u>(3.0, 4.0))[runtime_value][runtime_value + 1];
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ let runtime_value = 1;
+ let tint_symbol : array<array<f32, 2u>, 2u> = array<array<f32, 2u>, 2u>(array<f32, 2u>(1.0, 2.0), array<f32, 2u>(3.0, 4.0));
+ var i = tint_symbol[runtime_value][(runtime_value + 1)];
+}
+)";
+
+ auto got = Run<PromoteInitializersToLet>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteInitializersToLetTest, LocalConstArrayInArrayArrayConstIndex) {  // constant indices into a const array-of-arrays: nothing to promote
+ auto* src = R"(
+fn f() {
+ const arr_0 = array<f32, 2u>(1.0, 2.0);
+ const arr_1 = array<f32, 2u>(3.0, 4.0);
+ const arr_2 = array<array<f32, 2u>, 2u>(arr_0, arr_1);
+ var i = arr_2[0][1];
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<PromoteInitializersToLet>(src));  // the transform must report it has no work to do
+}
+
+TEST_F(PromoteInitializersToLetTest, LocalConstArrayInArrayArrayRuntimeIndex) {  // runtime indices into a function-scope const array-of-arrays: expects one let holding arr_2
+ auto* src = R"(
+fn f() {
+ let runtime_value = 1;
+ const arr_0 = array<f32, 2u>(1.0, 2.0);
+ const arr_1 = array<f32, 2u>(3.0, 4.0);
+ const arr_2 = array<array<f32, 2u>, 2u>(arr_0, arr_1);
+ var i = arr_2[runtime_value][runtime_value + 1];
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ let runtime_value = 1;
+ const arr_0 = array<f32, 2u>(1.0, 2.0);
+ const arr_1 = array<f32, 2u>(3.0, 4.0);
+ const arr_2 = array<array<f32, 2u>, 2u>(arr_0, arr_1);
+ let tint_symbol : array<array<f32, 2u>, 2u> = arr_2;
+ var i = tint_symbol[runtime_value][(runtime_value + 1)];
+}
+)";
+
+ auto got = Run<PromoteInitializersToLet>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteInitializersToLetTest, GlobalConstArrayInArrayArray) {  // runtime indices into a module-scope const array-of-arrays: expects one let holding arr_2
+ auto* src = R"(
+const arr_0 = array<f32, 2u>(1.0, 2.0);
+
+const arr_1 = array<f32, 2u>(3.0, 4.0);
+
+const arr_2 = array<array<f32, 2u>, 2u>(arr_0, arr_1);
+
+fn f() {
+ let runtime_value = 1;
+ var i = arr_2[runtime_value][runtime_value + 1];
+}
+)";
+
+ auto* expect = R"(
+const arr_0 = array<f32, 2u>(1.0, 2.0);
+
+const arr_1 = array<f32, 2u>(3.0, 4.0);
+
+const arr_2 = array<array<f32, 2u>, 2u>(arr_0, arr_1);
+
+fn f() {
+ let runtime_value = 1;
+ let tint_symbol : array<array<f32, 2u>, 2u> = arr_2;
+ var i = tint_symbol[runtime_value][(runtime_value + 1)];
+}
+)";
+
+ auto got = Run<PromoteInitializersToLet>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteInitializersToLetTest, StructNested) {  // nested struct constructors as a call argument: expects a single let for the outermost S3 value
+ auto* src = R"(
+struct S1 {
+ a : i32,
+};
+
+struct S2 {
+ a : i32,
+ b : S1,
+ c : i32,
+};
+
+struct S3 {
+ a : S2,
+};
+
+fn get_a(s : S3) -> S2 {
+ return s.a;
+}
+
+fn f() {
+ var x = get_a(S3(S2(1, S1(2), 3))).b.a;
+}
+)";
+
+ auto* expect = R"(
+struct S1 {
+ a : i32,
+}
+
+struct S2 {
+ a : i32,
+ b : S1,
+ c : i32,
+}
+
+struct S3 {
+ a : S2,
+}
+
+fn get_a(s : S3) -> S2 {
+ return s.a;
+}
+
+fn f() {
+ let tint_symbol : S3 = S3(S2(1, S1(2), 3));
+ var x = get_a(tint_symbol).b.a;
+}
+)";
+
+ auto got = Run<PromoteInitializersToLet>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteInitializersToLetTest, Mixed) {  // struct wrapping an array of structs as a call argument: expects a single let for the S2 value
+ auto* src = R"(
+struct S1 {
+ a : i32,
+};
+
+struct S2 {
+ a : array<S1, 3u>,
+};
+
+fn get_a(s : S2) -> array<S1, 3u> {
+ return s.a;
+}
+
+fn f() {
+ var x = get_a(S2(array<S1, 3u>(S1(1), S1(2), S1(3))))[1].a;
+}
+)";
+
+ auto* expect = R"(
+struct S1 {
+ a : i32,
+}
+
+struct S2 {
+ a : array<S1, 3u>,
+}
+
+fn get_a(s : S2) -> array<S1, 3u> {
+ return s.a;
+}
+
+fn f() {
+ let tint_symbol : S2 = S2(array<S1, 3u>(S1(1), S1(2), S1(3)));
+ var x = get_a(tint_symbol)[1].a;
+}
+)";
+
+ auto got = Run<PromoteInitializersToLet>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteInitializersToLetTest, Mixed_OutOfOrder) {  // as Mixed, with f, get_a, S2 and S1 declared in reverse order
+ auto* src = R"(
+fn f() {
+ var x = get_a(S2(array<S1, 3u>(S1(1), S1(2), S1(3))))[1].a;
+}
+
+fn get_a(s : S2) -> array<S1, 3u> {
+ return s.a;
+}
+
+struct S2 {
+ a : array<S1, 3u>,
+};
+
+struct S1 {
+ a : i32,
+};
+)";
+
+ auto* expect = R"(
+fn f() {
+ let tint_symbol : S2 = S2(array<S1, 3u>(S1(1), S1(2), S1(3)));
+ var x = get_a(tint_symbol)[1].a;
+}
+
+fn get_a(s : S2) -> array<S1, 3u> {
+ return s.a;
+}
+
+struct S2 {
+ a : array<S1, 3u>,
+}
+
+struct S1 {
+ a : i32,
+}
+)";
+
+ auto got = Run<PromoteInitializersToLet>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteInitializersToLetTest, NoChangeOnVarDecl) {  // constructors used directly as var / const initializers are left alone
+ auto* src = R"(
+alias F = f32;
+
+fn f() {
+ var local_arr = array<f32, 4u>(0.0, 1.0, 2.0, 3.0);
+ var local_str = F(3.0);
+}
+
+const module_arr : array<f32, 4u> = array<f32, 4u>(0.0, 1.0, 2.0, 3.0);
+
+const module_str : F = F(2.0);
+)";
+
+ auto* expect = src;  // output must be textually identical to the input
+
+ auto got = Run<PromoteInitializersToLet>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteInitializersToLetTest, NoChangeOnVarDecl_OutOfOrder) {  // as NoChangeOnVarDecl, with the alias and module-scope consts declared after their use
+ auto* src = R"(
+fn f() {
+ var local_arr = array<f32, 4u>(0.0, 1.0, 2.0, 3.0);
+ var local_str = F(3.0);
+}
+
+const module_str : F = F(2.0);
+
+alias F = f32;
+
+const module_arr : array<f32, 4u> = array<f32, 4u>(0.0, 1.0, 2.0, 3.0);
+)";
+
+ auto* expect = src;  // output must be textually identical to the input
+
+ auto got = Run<PromoteInitializersToLet>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteInitializersToLetTest, ForLoopShadowing) {  // loop body re-declares 'i'; in the expected output that 'var i' stays inside its own block, apart from the hoisted lets
+ auto* src = R"(
+fn X() {
+ var i = 10;
+ for(var f = 0; f < 10; f = f + array<i32, 1u>(i)[0]) {
+ var i = 20;
+ }
+}
+
+fn Y() {
+ var i = 10;
+ for(var f = 0; f < array<i32, 1u>(i)[0]; f = f + 1) {
+ var i = 20;
+ }
+}
+
+fn Z() {
+ var i = 10;
+ for(var f = array<i32, 1u>(i)[0]; f < 10; f = f + 1) {
+ var i = 20;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn X() {
+ var i = 10;
+ {
+ var f = 0;
+ loop {
+ if (!((f < 10))) {
+ break;
+ }
+ {
+ var i = 20;
+ }
+
+ continuing {
+ let tint_symbol : array<i32, 1u> = array<i32, 1u>(i);
+ f = (f + tint_symbol[0]);
+ }
+ }
+ }
+}
+
+fn Y() {
+ var i = 10;
+ {
+ var f = 0;
+ loop {
+ let tint_symbol_1 : array<i32, 1u> = array<i32, 1u>(i);
+ if (!((f < tint_symbol_1[0]))) {
+ break;
+ }
+ {
+ var i = 20;
+ }
+
+ continuing {
+ f = (f + 1);
+ }
+ }
+ }
+}
+
+fn Z() {
+ var i = 10;
+ {
+ let tint_symbol_2 : array<i32, 1u> = array<i32, 1u>(i);
+ var f = tint_symbol_2[0];
+ loop {
+ if (!((f < 10))) {
+ break;
+ }
+ {
+ var i = 20;
+ }
+
+ continuing {
+ f = (f + 1);
+ }
+ }
+ }
+}
+)";
+
+ auto got = Run<PromoteInitializersToLet>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteInitializersToLetTest, AssignAbstractArray) {
+ // Hoisting an abstract-typed array must still give the let the concrete type: array<f32, 4u> here.
+ auto* src = R"(
+fn f() {
+ var arr : array<f32, 4>;
+ arr = array(1, 2, 3, 4);
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ var arr : array<f32, 4>;
+ let tint_symbol : array<f32, 4u> = array(1, 2, 3, 4);
+ arr = tint_symbol;
+}
+)";
+
+ auto got = Run<PromoteInitializersToLet>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteInitializersToLetTest, AssignAbstractArray_ToPhony) {
+ // An abstract array expression on the RHS of a phony assignment must not be hoisted,
+ // because its type is never materialized; the transform should therefore not run at all.
+ auto* src = R"(
+fn f() {
+ _ = array(1, 2, 3, 4);
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<PromoteInitializersToLet>(src));
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/promote_side_effects_to_decl.cc b/src/tint/lang/wgsl/ast/transform/promote_side_effects_to_decl.cc
new file mode 100644
index 0000000..6adde0d
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/promote_side_effects_to_decl.cc
@@ -0,0 +1,679 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/promote_side_effects_to_decl.h"
+
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "src/tint/lang/wgsl/ast/transform/utils/get_insertion_point.h"
+#include "src/tint/lang/wgsl/ast/transform/utils/hoist_to_decl_before.h"
+#include "src/tint/lang/wgsl/ast/traverse_expressions.h"
+#include "src/tint/sem/block_statement.h"
+#include "src/tint/sem/call.h"
+#include "src/tint/sem/for_loop_statement.h"
+#include "src/tint/sem/if_statement.h"
+#include "src/tint/sem/member_accessor_expression.h"
+#include "src/tint/sem/variable.h"
+#include "src/tint/sem/while_statement.h"
+#include "src/tint/transform/manager.h"
+#include "src/tint/utils/scoped_assignment.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::PromoteSideEffectsToDecl);
+
+namespace tint::ast::transform {
+namespace {
+
+// Base class holding the references that every per-pass state object in this file needs.
+class StateBase {
+ protected:
+ CloneContext& ctx;  // clone context this state operates on (held by reference, not owned)
+ ProgramBuilder& b;  // destination builder, taken from ctx.dst
+ const sem::Info& sem;  // semantic info of the source program, taken from ctx.src
+
+ explicit StateBase(CloneContext& ctx_in)
+ : ctx(ctx_in), b(*ctx_in.dst), sem(ctx_in.src->Sem()) {}
+};
+
+// First pass: converts side-effecting for-loops to loops and else-ifs to
+// else { if }s, so that the next pass, DecomposeSideEffects, can insert
+// hoisted expressions above their current location.
+struct SimplifySideEffectStatements : tint::utils::Castable<PromoteSideEffectsToDecl, Transform> {  // NOTE(review): CRTP argument names PromoteSideEffectsToDecl rather than this struct - confirm intended
+ ApplyResult Apply(const Program* src, const DataMap& inputs, DataMap& outputs) const override;
+};
+
+Transform::ApplyResult SimplifySideEffectStatements::Apply(const Program* src,
+                                                           const DataMap&,
+                                                           DataMap&) const {
+    ProgramBuilder b;
+    CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+    HoistToDeclBefore hoister(ctx);
+
+    // Register every side-effecting value expression of the source program with the hoist helper.
+    bool prepared_any = false;
+    for (auto* ast_node : ctx.src->ASTNodes().Objects()) {
+        auto* value = src->Sem().GetVal(ast_node);
+        if (value == nullptr || !value->HasSideEffects()) {
+            continue;  // not a value expression, or one without side-effects
+        }
+        hoister.Prepare(value);
+        prepared_any = true;
+    }
+
+    // Nothing was registered, so report that this pass has no work to do.
+    if (!prepared_any) {
+        return SkipTransform;
+    }
+
+    // Clone the source program into `b` and return the resulting program.
+    ctx.Clone();
+    return Program(std::move(b));
+}
+
+// Second pass: decomposes side-effecting expressions to ensure order of
+// evaluation. This covers both breaking down logical binary expressions for
+// short-circuit evaluation and hoisting expressions into preceding declarations.
+struct DecomposeSideEffects : tint::utils::Castable<PromoteSideEffectsToDecl, Transform> {  // NOTE(review): CRTP argument names PromoteSideEffectsToDecl rather than this struct - confirm intended
+ class CollectHoistsState;  // step 1: collects the set of expressions that need hoisting
+ class DecomposeState;  // step 2: rewrites the AST using that set
+ ApplyResult Apply(const Program* src, const DataMap& inputs, DataMap& outputs) const override;
+};
+
+// CollectHoistsState traverses the AST top-down, identifying which expressions
+// need to be hoisted to ensure order of evaluation, both those that give
+// side-effects, as well as those that receive, and returns a set of these
+// expressions.
+using ToHoistSet = std::unordered_set<const Expression*>;
+class DecomposeSideEffects::CollectHoistsState : public StateBase {
+ // Expressions to hoist because they either cause or receive side-effects.
+ ToHoistSet to_hoist;
+
+ // Used to mark expressions as not or no longer having side-effects.
+ std::unordered_set<const Expression*> no_side_effects;
+
+ // Returns true if `expr` has side-effects. Unlike invoking
+ // sem::ValueExpression::HasSideEffects(), this function takes into account whether
+ // `expr` has been hoisted, returning false in that case. Furthermore, it
+ // returns the correct result on parent expression nodes by traversing the
+ // expression tree; negative results are memoized in `no_side_effects`.
+ bool HasSideEffects(const Expression* expr) {
+ if (no_side_effects.count(expr)) {
+ return false;
+ }
+
+ return Switch(
+ expr, [&](const CallExpression* e) -> bool { return sem.Get(e)->HasSideEffects(); },
+ [&](const BinaryExpression* e) {
+ if (HasSideEffects(e->lhs) || HasSideEffects(e->rhs)) {
+ return true;
+ }
+ no_side_effects.insert(e);
+ return false;
+ },
+ [&](const IndexAccessorExpression* e) {
+ if (HasSideEffects(e->object) || HasSideEffects(e->index)) {
+ return true;
+ }
+ no_side_effects.insert(e);
+ return false;
+ },
+ [&](const MemberAccessorExpression* e) {
+ if (HasSideEffects(e->object)) {
+ return true;
+ }
+ no_side_effects.insert(e);
+ return false;
+ },
+ [&](const BitcastExpression* e) { //
+ if (HasSideEffects(e->expr)) {
+ return true;
+ }
+ no_side_effects.insert(e);
+ return false;
+ },
+
+ [&](const UnaryOpExpression* e) { //
+ if (HasSideEffects(e->expr)) {
+ return true;
+ }
+ no_side_effects.insert(e);
+ return false;
+ },
+ [&](const IdentifierExpression* e) {
+ no_side_effects.insert(e);
+ return false;
+ },
+ [&](const LiteralExpression* e) {
+ no_side_effects.insert(e);
+ return false;
+ },
+ [&](const PhonyExpression* e) {
+ no_side_effects.insert(e);
+ return false;
+ },
+ [&](Default) {
+ TINT_ICE(Transform, b.Diagnostics()) << "Unhandled expression type";
+ return false;
+ });
+ }
+
+ // Adds `e` to `to_hoist` for hoisting to a let later on, and marks it as side-effect free.
+ void Hoist(const Expression* e) {
+ no_side_effects.insert(e);
+ to_hoist.emplace(e);
+ }
+
+ // Hoists any expressions in `maybe_hoist` and clears it
+ template <size_t N>
+ void Flush(tint::utils::Vector<const Expression*, N>& maybe_hoist) {
+ for (auto* m : maybe_hoist) {
+ Hoist(m);
+ }
+ maybe_hoist.Clear();
+ }
+
+ // Recursive function that processes expressions for side-effects. It
+ // traverses the expression tree child before parent, left-to-right. Each call
+ // returns whether the input expression should maybe be hoisted, allowing the
+ // parent node to decide whether to hoist or not. Generally:
+ // * When 'true' is returned, the expression is added to the maybe_hoist list.
+ // * When a side-effecting expression is met, we flush the expressions in the
+ // maybe_hoist list, as they are potentially receivers of the side-effects.
+ // * For index and member accessor expressions, special care is taken to not
+ // over-hoist the lhs expressions, as these may be chained to refer to a
+ // single memory location.
+ template <size_t N>
+ bool ProcessExpression(const Expression* expr,
+ tint::utils::Vector<const Expression*, N>& maybe_hoist) {
+ auto process = [&](const Expression* e) -> bool {
+ return ProcessExpression(e, maybe_hoist);
+ };
+
+ auto default_process = [&](const Expression* e) {
+ auto maybe = process(e);
+ if (maybe) {
+ maybe_hoist.Push(e);
+ }
+ if (HasSideEffects(e)) {
+ Flush(maybe_hoist);
+ }
+ return false;
+ };
+
+ auto binary_process = [&](const Expression* lhs, const Expression* rhs) {
+ // If neither side causes side-effects, but at least one receives them,
+ // let parent node hoist. This avoids over-hoisting side-effect receivers
+ // of compound binary expressions (e.g. for "((a && b) && c) && f()", we
+ // don't want to hoist each of "a", "b", and "c" separately, but want to
+ // hoist "((a && b) && c)".
+ if (!HasSideEffects(lhs) && !HasSideEffects(rhs)) {
+ auto lhs_maybe = process(lhs);
+ auto rhs_maybe = process(rhs);
+ if (lhs_maybe || rhs_maybe) {
+ return true;
+ }
+ return false;
+ }
+
+ default_process(lhs);
+ default_process(rhs);
+ return false;
+ };
+
+ auto accessor_process = [&](const Expression* lhs, const Expression* rhs = nullptr) {
+ auto maybe = process(lhs);
+ // If lhs is a variable, let the parent node hoist; otherwise flush it right
+ // away. This is to avoid over-hoisting the lhs of accessor chains (e.g.
+ // for "v[a][b][c] + g()" we want to hoist all of "v[a][b][c]", not "t1 =
+ // v[a]", then "t2 = t1[b]" then "t3 = t2[c]").
+ if (maybe && HasSideEffects(lhs)) {
+ maybe_hoist.Push(lhs);
+ Flush(maybe_hoist);
+ maybe = false;
+ }
+ if (rhs) {
+ default_process(rhs);
+ }
+ return maybe;
+ };
+
+ return Switch(
+ expr,
+ [&](const CallExpression* e) -> bool {
+ // We eagerly flush any variables in maybe_hoist for the current
+ // call expression. Then we scope maybe_hoist to the processing of
+ // the call args. This ensures that given: g(c, a(0), d) we hoist
+ // 'c' because of 'a(0)', but not 'd' because there's no need, since
+ // the call to g() will be hoisted if necessary.
+ if (HasSideEffects(e)) {
+ Flush(maybe_hoist);
+ }
+
+ TINT_SCOPED_ASSIGNMENT(maybe_hoist, {});
+ for (auto* a : e->args) {
+ default_process(a);
+ }
+
+ // Always hoist this call, even if it has no side-effects to ensure
+ // left-to-right order of evaluation.
+ // E.g. for "no_side_effects() + side_effects()", we want to hoist
+ // no_side_effects() first.
+ return true;
+ },
+ [&](const IdentifierExpression* e) {
+ if (auto* sem_e = sem.GetVal(e)) {
+ if (auto* var_user = sem_e->UnwrapLoad()->As<sem::VariableUser>()) {
+ // Don't hoist constants.
+ if (var_user->ConstantValue()) {
+ return false;
+ }
+ // Don't hoist read-only variables as they cannot receive side-effects.
+ if (var_user->Variable()->Access() == builtin::Access::kRead) {
+ return false;
+ }
+ // Don't hoist textures / samplers as they can't be placed into a let, nor
+ // can they have side effects.
+ if (var_user->Variable()->Type()->IsAnyOf<type::Texture, type::Sampler>()) {
+ return false;
+ }
+ return true;
+ }
+ }
+ return false;
+ },
+ [&](const BinaryExpression* e) {
+ if (e->IsLogical() && HasSideEffects(e)) {
+ // Don't hoist children of logical binary expressions with
+ // side-effects. These will be handled by DecomposeState.
+ process(e->lhs);
+ process(e->rhs);
+ return false;
+ }
+ return binary_process(e->lhs, e->rhs);
+ },
+ [&](const BitcastExpression* e) { //
+ return process(e->expr);
+ },
+ [&](const UnaryOpExpression* e) { //
+ auto r = process(e->expr);
+ // Don't hoist address-of expressions.
+ // E.g. for "g(&b, a(0))", we hoist "a(0)" only.
+ if (e->op == UnaryOp::kAddressOf) {
+ return false;
+ }
+ return r;
+ },
+ [&](const IndexAccessorExpression* e) { return accessor_process(e->object, e->index); },
+ [&](const MemberAccessorExpression* e) { return accessor_process(e->object); },
+ [&](const LiteralExpression*) {
+ // Leaf
+ return false;
+ },
+ [&](const PhonyExpression*) {
+ // Leaf
+ return false;
+ },
+ [&](Default) {
+ TINT_ICE(Transform, b.Diagnostics()) << "Unhandled expression type";
+ return false;
+ });
+ }
+
+ // Starts the recursive processing of a statement's expression(s) to hoist side-effects to lets.
+ void ProcessExpression(const Expression* expr) {
+ if (!expr) {
+ return;
+ }
+
+ tint::utils::Vector<const Expression*, 8> maybe_hoist;
+ ProcessExpression(expr, maybe_hoist);
+ }
+
+ public:
+ explicit CollectHoistsState(CloneContext& ctx_in) : StateBase(ctx_in) {}
+
+ ToHoistSet Run() {
+ // Traverse all statements, recursively processing their expression tree(s)
+ // to hoist side-effects to lets.
+ for (auto* node : ctx.src->ASTNodes().Objects()) {
+ auto* stmt = node->As<Statement>();
+ if (!stmt) {
+ continue;
+ }
+
+ Switch(
+ stmt, //
+ [&](const AssignmentStatement* s) {
+ tint::utils::Vector<const Expression*, 8> maybe_hoist;  // shared by lhs and rhs so that side-effects on the rhs can flush lhs candidates
+ ProcessExpression(s->lhs, maybe_hoist);
+ ProcessExpression(s->rhs, maybe_hoist);
+ },
+ [&](const CallStatement* s) { //
+ ProcessExpression(s->expr);
+ },
+ [&](const ForLoopStatement* s) { ProcessExpression(s->condition); },  // NOTE(review): only the condition is visited here; presumably init/continuing are reached as statements of their own - confirm
+ [&](const WhileStatement* s) { ProcessExpression(s->condition); },
+ [&](const IfStatement* s) { //
+ ProcessExpression(s->condition);
+ },
+ [&](const ReturnStatement* s) { //
+ ProcessExpression(s->value);
+ },
+ [&](const SwitchStatement* s) { ProcessExpression(s->condition); },
+ [&](const VariableDeclStatement* s) {
+ ProcessExpression(s->variable->initializer);
+ });
+ }
+
+ return std::move(to_hoist);  // to_hoist is a member, so the explicit move is needed to avoid copying the set
+ }
+};
+
+// DecomposeState performs the actual transforming of the AST to ensure order of
+// evaluation, using the set of expressions to hoist collected by
+// CollectHoistsState.
+class DecomposeSideEffects::DecomposeState : public StateBase {
+ ToHoistSet to_hoist;
+
+    // True when `binary_expr` is a logical operator with a side-effecting lhs or rhs (needs decomposing).
+    bool IsLogicalWithSideEffects(const BinaryExpression* binary_expr) {
+        if (!binary_expr->IsLogical()) { return false; }  // non-logical operators are never decomposed
+        return sem.GetVal(binary_expr->lhs)->HasSideEffects() || sem.GetVal(binary_expr->rhs)->HasSideEffects();
+    }
+
+ // Recursively decomposes `expr`: statements to emit first go to `curr_stmts`; returns the replacement.
+ template <size_t N>
+ const Expression* Decompose(const Expression* expr,
+ tint::utils::Vector<const Statement*, N>* curr_stmts) {
+ // Recurse using whatever `curr_stmts` points at when called (captured by reference).
+ auto decompose = [&](auto& e) { return Decompose(e, curr_stmts); };
+
+ // Clones `e`. If `e` is in to_hoist: pushes `let <sym> = <clone>` and returns an identifier for <sym>.
+ auto clone_maybe_hoisted = [&](const Expression* e) -> const Expression* {
+ if (to_hoist.count(e)) {
+ auto name = b.Symbols().New();
+ auto* ty = sem.GetVal(e)->Type();
+ auto* v = b.Let(name, Transform::CreateASTTypeFor(ctx, ty), ctx.Clone(e));
+ auto* decl = b.Decl(v);
+ curr_stmts->Push(decl);
+ return b.Expr(name);
+ }
+ return ctx.Clone(e);
+ };
+
+ return Switch(
+ expr,
+ [&](const BinaryExpression* bin_expr) -> const Expression* {
+ if (!IsLogicalWithSideEffects(bin_expr)) {
+ // No short-circuit needed: decompose lhs then rhs, then clone (or hoist) the binary expr
+ ctx.Replace(bin_expr->lhs, decompose(bin_expr->lhs));
+ ctx.Replace(bin_expr->rhs, decompose(bin_expr->rhs));
+ return clone_maybe_hoisted(bin_expr);
+ }
+
+ // Decompose into ifs to implement short-circuiting
+ // For example, 'let r = a && b' becomes:
+ //
+ // var temp = a;
+ // if (temp) {
+ // temp = b;
+ // }
+ // let r = temp;
+ //
+ // and similarly, 'let r = a || b' becomes:
+ //
+ // var temp = a;
+ // if (!temp) {
+ // temp = b;
+ // }
+ // let r = temp;
+ //
+ // Further, compound logical binary expressions are also handled
+ // recursively, for example, 'let r = (a && (b && c))' becomes:
+ //
+ // var temp = a;
+ // if (temp) {
+ // var temp2 = b;
+ // if (temp2) {
+ // temp2 = c;
+ // }
+ // temp = temp2;
+ // }
+ // let r = temp;
+
+ auto name = b.Sym();
+ curr_stmts->Push(b.Decl(b.Var(name, decompose(bin_expr->lhs))));  // var <name> = lhs;
+
+ const Expression* if_cond = nullptr;
+ if (bin_expr->IsLogicalOr()) {
+ if_cond = b.Not(name);  // logical-or: rhs only runs when lhs was false
+ } else {
+ if_cond = b.Expr(name);  // otherwise: rhs only runs when lhs was true
+ }
+
+ const BlockStatement* if_body = nullptr;
+ {
+ tint::utils::Vector<const Statement*, N> stmts;
+ TINT_SCOPED_ASSIGNMENT(curr_stmts, &stmts);  // presumably restores curr_stmts at scope exit
+ auto* new_rhs = decompose(bin_expr->rhs);
+ curr_stmts->Push(b.Assign(name, new_rhs));  // <name> = rhs;
+ if_body = b.Block(std::move(*curr_stmts));
+ }
+
+ curr_stmts->Push(b.If(if_cond, if_body));
+
+ return b.Expr(name);
+ },
+ [&](const IndexAccessorExpression* idx) {
+ ctx.Replace(idx->object, decompose(idx->object));
+ ctx.Replace(idx->index, decompose(idx->index));
+ return clone_maybe_hoisted(idx);
+ },
+ [&](const BitcastExpression* bitcast) {
+ ctx.Replace(bitcast->expr, decompose(bitcast->expr));
+ return clone_maybe_hoisted(bitcast);
+ },
+ [&](const CallExpression* call) {
+ for (auto* a : call->args) {
+ ctx.Replace(a, decompose(a));
+ }
+ return clone_maybe_hoisted(call);
+ },
+ [&](const MemberAccessorExpression* member) {
+ ctx.Replace(member->object, decompose(member->object));
+ return clone_maybe_hoisted(member);
+ },
+ [&](const UnaryOpExpression* unary) {
+ ctx.Replace(unary->expr, decompose(unary->expr));
+ return clone_maybe_hoisted(unary);
+ },
+ [&](const LiteralExpression* lit) {
+ return clone_maybe_hoisted(lit); // Leaf: no operands to recurse into (may still be hoisted)
+ },
+ [&](const IdentifierExpression* id) {
+ return clone_maybe_hoisted(id); // Leaf: no operands to recurse into (may still be hoisted)
+ },
+ [&](const PhonyExpression* phony) {
+ return clone_maybe_hoisted(phony); // Leaf: no operands to recurse into (may still be hoisted)
+ },
+ [&](Default) {
+ TINT_ICE(AST, b.Diagnostics())
+ << "unhandled expression type: " << expr->TypeInfo().name;
+ return nullptr;
+ });
+ }
+
+ // Registers every statement of `stmts` for insertion ahead of `stmt`; no-op when `stmts` is empty.
+ template <size_t N>
+ void InsertBefore(tint::utils::Vector<const Statement*, N>& stmts, const Statement* stmt) {
+ if (!stmts.IsEmpty()) {
+ auto [sem_block, before] = utils::GetInsertionPoint(ctx, stmt);
+ for (auto* hoisted : stmts) {
+ ctx.InsertBefore(sem_block->Declaration()->statements, before, hoisted);
+ }
+ }
+ }
+
+ // Decomposes the side-effecting expressions of `stmt`, inserting hoisted statements before it.
+ // Returns the replacement statement, or nullptr when `stmt` needs no replacement.
+ const Statement* DecomposeStatement(const Statement* stmt) {
+ return Switch(
+ stmt,
+ [&](const AssignmentStatement* s) -> const Statement* {
+ if (!sem.GetVal(s->lhs)->HasSideEffects() &&
+ !sem.GetVal(s->rhs)->HasSideEffects()) {
+ return nullptr;
+ }
+ // lhs before rhs
+ tint::utils::Vector<const Statement*, 8> stmts;
+ ctx.Replace(s->lhs, Decompose(s->lhs, &stmts));
+ ctx.Replace(s->rhs, Decompose(s->rhs, &stmts));
+ InsertBefore(stmts, s);
+ return ctx.CloneWithoutTransform(s);
+ },
+ [&](const CallStatement* s) -> const Statement* {
+ if (!sem.Get(s->expr)->HasSideEffects()) {
+ return nullptr;
+ }
+ tint::utils::Vector<const Statement*, 8> stmts;
+ ctx.Replace(s->expr, Decompose(s->expr, &stmts));
+ InsertBefore(stmts, s);
+ return ctx.CloneWithoutTransform(s);
+ },
+ [&](const ForLoopStatement* s) -> const Statement* {
+ if (!s->condition || !sem.GetVal(s->condition)->HasSideEffects()) {  // condition is optional
+ return nullptr;
+ }
+ tint::utils::Vector<const Statement*, 8> stmts;
+ ctx.Replace(s->condition, Decompose(s->condition, &stmts));
+ InsertBefore(stmts, s);
+ return ctx.CloneWithoutTransform(s);
+ },
+ [&](const WhileStatement* s) -> const Statement* {
+ if (!sem.GetVal(s->condition)->HasSideEffects()) {
+ return nullptr;
+ }
+ tint::utils::Vector<const Statement*, 8> stmts;
+ ctx.Replace(s->condition, Decompose(s->condition, &stmts));
+ InsertBefore(stmts, s);
+ return ctx.CloneWithoutTransform(s);
+ },
+ [&](const IfStatement* s) -> const Statement* {
+ if (!sem.GetVal(s->condition)->HasSideEffects()) {
+ return nullptr;
+ }
+ tint::utils::Vector<const Statement*, 8> stmts;
+ ctx.Replace(s->condition, Decompose(s->condition, &stmts));
+ InsertBefore(stmts, s);
+ return ctx.CloneWithoutTransform(s);
+ },
+ [&](const ReturnStatement* s) -> const Statement* {
+ if (!s->value || !sem.GetVal(s->value)->HasSideEffects()) {  // value is optional
+ return nullptr;
+ }
+ tint::utils::Vector<const Statement*, 8> stmts;
+ ctx.Replace(s->value, Decompose(s->value, &stmts));
+ InsertBefore(stmts, s);
+ return ctx.CloneWithoutTransform(s);
+ },
+ [&](const SwitchStatement* s) -> const Statement* {
+ if (!sem.GetVal(s->condition)->HasSideEffects()) {  // same guard as if/while: nothing to hoist
+ return nullptr;
+ }
+ tint::utils::Vector<const Statement*, 8> stmts;
+ ctx.Replace(s->condition, Decompose(s->condition, &stmts));
+ InsertBefore(stmts, s);
+ return ctx.CloneWithoutTransform(s);
+ },
+ [&](const VariableDeclStatement* s) -> const Statement* {
+ auto* var = s->variable;
+ if (!var->initializer || !sem.GetVal(var->initializer)->HasSideEffects()) {
+ return nullptr;
+ }
+ tint::utils::Vector<const Statement*, 8> stmts;
+ ctx.Replace(var->initializer, Decompose(var->initializer, &stmts));
+ InsertBefore(stmts, s);
+ return b.Decl(ctx.CloneWithoutTransform(var));  // rebuilt around the cloned variable
+ },
+ [](Default) -> const Statement* {
+ // Other statement types don't have expressions
+ return nullptr;
+ });
+ }
+
+ public:
+ explicit DecomposeState(CloneContext& ctx_in, ToHoistSet to_hoist_in)
+ : StateBase(ctx_in), to_hoist(std::move(to_hoist_in)) {}  // the hoist set is moved in, not copied
+
+ void Run() {
+ // Register a callback for every BlockStatement: it walks the block's statements,
+ // registering replacements and ctx.InsertBefore() insertions of hoisted declarations.
+ ctx.ReplaceAll([&](const BlockStatement* block) -> const Statement* {
+ for (auto* stmt : block->statements) {
+ if (auto* new_stmt = DecomposeStatement(stmt)) {
+ ctx.Replace(stmt, new_stmt);
+ }
+
+ // Handle for loops, as they are the only other AST node that
+ // contains statements outside of BlockStatements.
+ if (auto* fl = stmt->As<ForLoopStatement>()) {
+ if (auto* new_stmt = DecomposeStatement(fl->initializer)) {  // NOTE(review): initializer may be absent — confirm Switch accepts nullptr
+ ctx.Replace(fl->initializer, new_stmt);
+ }
+ if (auto* new_stmt = DecomposeStatement(fl->continuing)) {  // NOTE(review): same for continuing
+ ctx.Replace(fl->continuing, new_stmt);
+ }
+ }
+ }
+ return nullptr;  // the block itself is never replaced here (presumably falls back to a normal clone)
+ });
+ }
+};
+
+Transform::ApplyResult DecomposeSideEffects::Apply(const Program* src,
+ const DataMap&,
+ DataMap&) const {
+ ProgramBuilder builder;
+ CloneContext clone_ctx{&builder, src, /* auto_clone_symbols */ true};
+
+ // Pass 1: gather the set of expressions that have to be hoisted.
+ CollectHoistsState collector{clone_ctx};
+ auto hoist_set = collector.Run();
+
+ // Pass 2: register the rewrites. Keep `decomposer` alive until Clone(): its Run() callback captures by reference.
+ DecomposeState decomposer{clone_ctx, std::move(hoist_set)};
+ decomposer.Run();
+
+ clone_ctx.Clone();
+ return Program(std::move(builder));
+}
+
+} // namespace
+
+PromoteSideEffectsToDecl::PromoteSideEffectsToDecl() = default;  // out-of-line definition of the declared ctor
+PromoteSideEffectsToDecl::~PromoteSideEffectsToDecl() = default;  // out-of-line definition of the declared dtor
+
+Transform::ApplyResult PromoteSideEffectsToDecl::Apply(const Program* src,
+ const DataMap& inputs,
+ DataMap& outputs) const {
+ tint::transform::Manager pipeline;  // holds the two sub-transforms, added in this order
+ pipeline.Add<SimplifySideEffectStatements>();
+ pipeline.Add<DecomposeSideEffects>();
+ return pipeline.Run(src, inputs, outputs);
+}
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/promote_side_effects_to_decl.h b/src/tint/lang/wgsl/ast/transform/promote_side_effects_to_decl.h
new file mode 100644
index 0000000..25f48a8
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/promote_side_effects_to_decl.h
@@ -0,0 +1,42 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_PROMOTE_SIDE_EFFECTS_TO_DECL_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_PROMOTE_SIDE_EFFECTS_TO_DECL_H_
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+
+namespace tint::ast::transform {
+
+/// A transform that hoists expressions with side-effects into `let` declarations
+/// placed before the statement that uses them, ensuring left-to-right order of
+/// evaluation. Logical binary expressions with side-effecting operands become a
+/// `var` plus an `if`, so that short-circuit evaluation is respected.
+class PromoteSideEffectsToDecl final : public utils::Castable<PromoteSideEffectsToDecl, Transform> {
+ public:
+ /// Constructor
+ PromoteSideEffectsToDecl();
+
+ /// Destructor
+ ~PromoteSideEffectsToDecl() override;
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_PROMOTE_SIDE_EFFECTS_TO_DECL_H_
diff --git a/src/tint/lang/wgsl/ast/transform/promote_side_effects_to_decl_test.cc b/src/tint/lang/wgsl/ast/transform/promote_side_effects_to_decl_test.cc
new file mode 100644
index 0000000..175a91e
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/promote_side_effects_to_decl_test.cc
@@ -0,0 +1,4100 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/promote_side_effects_to_decl.h"
+
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using PromoteSideEffectsToDeclTest = TransformTest;
+
+TEST_F(PromoteSideEffectsToDeclTest, EmptyModule) {
+ auto* src = "";
+ auto* expect = "";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Unary_Arith_SE) {
+ auto* src = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ let r = -(a(0));
+}
+)";
+
+ auto* expect = src;  // output must equal the input: the lone call a(0) is not hoisted
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_BothSE) {
+ auto* src = R"(
+fn a() -> i32 {
+ return 1;
+}
+
+fn b() -> i32 {
+ return 1;
+}
+
+fn f() {
+ let r = a() + b();
+}
+)";
+
+ auto* expect = R"(
+fn a() -> i32 {
+ return 1;
+}
+
+fn b() -> i32 {
+ return 1;
+}
+
+fn f() {
+ let tint_symbol : i32 = a();
+ let tint_symbol_1 : i32 = b();
+ let r = (tint_symbol + tint_symbol_1);
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_LeftSE) {
+ auto* src = R"(
+fn a() -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = 1;
+ let r = a() + b;
+}
+)";
+
+ auto* expect = R"(
+fn a() -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = 1;
+ let tint_symbol : i32 = a();
+ let r = (tint_symbol + b);
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_RightSE) {
+ auto* src = R"(
+fn a() -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = 1;
+ let r = b + a();
+}
+)";
+
+ auto* expect = R"(
+fn a() -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = 1;
+ let tint_symbol : i32 = b;
+ let tint_symbol_1 : i32 = a();
+ let r = (tint_symbol + tint_symbol_1);
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_LeftmostSE) {
+ auto* src = R"(
+fn a() -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = 1;
+ var c = 1;
+ var d = 1;
+ let r = a() + b + c + d;
+}
+)";
+
+ auto* expect = R"(
+fn a() -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = 1;
+ var c = 1;
+ var d = 1;
+ let tint_symbol : i32 = a();
+ let r = (((tint_symbol + b) + c) + d);
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_RightmostSE) {
+ auto* src = R"(
+fn a() -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = 1;
+ var c = 1;
+ var d = 1;
+ let r = b + c + d + a();
+}
+)";
+
+ auto* expect = R"(
+fn a() -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = 1;
+ var c = 1;
+ var d = 1;
+ let tint_symbol : i32 = ((b + c) + d);
+ let tint_symbol_1 : i32 = a();
+ let r = (tint_symbol + tint_symbol_1);
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_MiddleSE) {
+ auto* src = R"(
+fn a() -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = 1;
+ var c = 1;
+ var d = 1;
+ var e = 1;
+ let r = b + c + a() + d + e;
+}
+)";
+
+ auto* expect = R"(
+fn a() -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = 1;
+ var c = 1;
+ var d = 1;
+ var e = 1;
+ let tint_symbol : i32 = (b + c);
+ let tint_symbol_1 : i32 = a();
+ let r = (((tint_symbol + tint_symbol_1) + d) + e);
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_ThreeSE) {
+ auto* src = R"(
+fn a(v : i32) -> i32 {
+ return v;
+}
+
+fn f() {
+ let r = a(0) + a(1) + a(2);
+}
+)";
+
+ auto* expect = R"(
+fn a(v : i32) -> i32 {
+ return v;
+}
+
+fn f() {
+ let tint_symbol : i32 = a(0);
+ let tint_symbol_1 : i32 = a(1);
+ let tint_symbol_2 : i32 = a(2);
+ let r = ((tint_symbol + tint_symbol_1) + tint_symbol_2);
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_Constants_NoRecvSE) {
+ auto* src = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ let r = ((((1 + a(0)) - 2) + 3) - 4);
+}
+)";
+
+ auto* expect = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ let tint_symbol : i32 = a(0);
+ let r = ((((1 + tint_symbol) - 2) + 3) - 4);
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_Constants_RecvSE) {
+ auto* src = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = 1;
+ let r = a(0) + 1 + b + a(1) + 2;
+}
+)";
+
+ auto* expect = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = 1;
+ let tint_symbol : i32 = a(0);
+ let tint_symbol_1 : i32 = b;
+ let tint_symbol_2 : i32 = a(1);
+ let r = ((((tint_symbol + 1) + tint_symbol_1) + tint_symbol_2) + 2);
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_Constants_ConstAndSEAndVar) {
+ auto* src = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn main() {
+ var b = 1;
+ var c = 1;
+ let r = 1 + a(0) + b;
+}
+)";
+
+ auto* expect = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn main() {
+ var b = 1;
+ var c = 1;
+ let tint_symbol : i32 = a(0);
+ let r = ((1 + tint_symbol) + b);
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_Constants_VarAndSEAndConst) {
+ auto* src = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn main() {
+ var b = 1;
+ let r = b + a(0) + 1;
+}
+)";
+
+ auto* expect = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn main() {
+ var b = 1;
+ let tint_symbol : i32 = b;
+ let tint_symbol_1 : i32 = a(0);
+ let r = ((tint_symbol + tint_symbol_1) + 1);
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_Constants_SEAndVarAndConstAndVar) {
+ auto* src = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn main() {
+ var b = 1;
+ var c = 1;
+ let r = a(0) + b + 1 + c;
+}
+)";
+
+ auto* expect = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn main() {
+ var b = 1;
+ var c = 1;
+ let tint_symbol : i32 = a(0);
+ let r = (((tint_symbol + b) + 1) + c);
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_Builtins_WithSE) {
+ auto* src = R"(
+struct SB {
+ a : atomic<i32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> sb : SB;
+
+fn f() {
+ var b = 0;
+ let r = atomicAdd(&sb.a, 123) + b;
+}
+)";
+
+ auto* expect = R"(
+struct SB {
+ a : atomic<i32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> sb : SB;
+
+fn f() {
+ var b = 0;
+ let tint_symbol : i32 = atomicAdd(&(sb.a), 123);
+ let r = (tint_symbol + b);
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_Builtins_NoSEAndVar) {
+ auto* src = R"(
+struct SB {
+ a : atomic<i32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> sb : SB;
+
+fn f() {
+ var b = 0;
+ let r = (atomicLoad(&(sb.a)) + b);
+}
+)";
+
+ auto* expect = src;  // output must equal the input: per the test name, atomicLoad counts as no side effect
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_Builtins_NoSEAndSE) {
+ auto* src = R"(
+struct SB {
+ a : atomic<i32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> sb : SB;
+
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = 0;
+ let r = (atomicLoad(&(sb.a)) + a(0));
+}
+)";
+
+ auto* expect = R"(
+struct SB {
+ a : atomic<i32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> sb : SB;
+
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = 0;
+ let tint_symbol : i32 = atomicLoad(&(sb.a));
+ let tint_symbol_1 : i32 = a(0);
+ let r = (tint_symbol + tint_symbol_1);
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_Vector_RightSE) {
+ auto* src = R"(
+fn a() -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b : vec3<i32>;
+ var c : i32;
+ let r = b[c] + a();
+}
+)";
+
+ auto* expect = R"(
+fn a() -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b : vec3<i32>;
+ var c : i32;
+ let tint_symbol : i32 = c;
+ let tint_symbol_1 : i32 = b[tint_symbol];
+ let tint_symbol_2 : i32 = a();
+ let r = (tint_symbol_1 + tint_symbol_2);
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_InCall) {
+ auto* src = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn g(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var b = 1;
+ let r = g(a(0)) - g(b + a(1));
+}
+)";
+
+ auto* expect = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn g(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var b = 1;
+ let tint_symbol : i32 = a(0);
+ let tint_symbol_1 : i32 = g(tint_symbol);
+ let tint_symbol_2 : i32 = b;
+ let tint_symbol_3 : i32 = a(1);
+ let tint_symbol_4 : i32 = g((tint_symbol_2 + tint_symbol_3));
+ let r = (tint_symbol_1 - tint_symbol_4);
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_InTypeInit) {
+ auto* src = R"(
+
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = 1;
+ let r = i32(a(0)) + i32(a(1) + b) - i32(a(2) - a(3));
+}
+)";
+
+ auto* expect = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = 1;
+ let tint_symbol : i32 = a(0);
+ let tint_symbol_1 : i32 = i32(tint_symbol);
+ let tint_symbol_2 : i32 = a(1);
+ let tint_symbol_3 : i32 = i32((tint_symbol_2 + b));
+ let tint_symbol_4 : i32 = a(2);
+ let tint_symbol_5 : i32 = a(3);
+ let tint_symbol_6 : i32 = i32((tint_symbol_4 - tint_symbol_5));
+ let r = ((tint_symbol_1 + tint_symbol_3) - tint_symbol_6);
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_InTypeConversion) {
+ auto* src = R"(
+
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = 1u;
+ let r = u32(a(0)) + u32(a(1)) - b;
+}
+)";
+
+ auto* expect = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = 1u;
+ let tint_symbol : i32 = a(0);
+ let tint_symbol_1 : u32 = u32(tint_symbol);
+ let tint_symbol_2 : i32 = a(1);
+ let tint_symbol_3 : u32 = u32(tint_symbol_2);
+ let r = ((tint_symbol_1 + tint_symbol_3) - b);
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_InIntrinsic) {
+ auto* src = R"(
+
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var b = 1;
+ let r = abs(a(0)) + abs(a(1) + b) - abs(a(2) + a(3));
+}
+)";
+
+ auto* expect = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var b = 1;
+ let tint_symbol : i32 = a(0);
+ let tint_symbol_1 : i32 = abs(tint_symbol);
+ let tint_symbol_2 : i32 = a(1);
+ let tint_symbol_3 : i32 = abs((tint_symbol_2 + b));
+ let tint_symbol_4 : i32 = a(2);
+ let tint_symbol_5 : i32 = a(3);
+ let tint_symbol_6 : i32 = abs((tint_symbol_4 + tint_symbol_5));
+ let r = ((tint_symbol_1 + tint_symbol_3) - tint_symbol_6);
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_InMemberAccessor) {
+ auto* src = R"(
+
+struct S {
+ v : i32,
+}
+
+fn a(i : i32) -> S {
+ return S();
+}
+
+fn f() {
+ var b = 1;
+ let r = a(0).v + b + a(1).v;
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ v : i32,
+}
+
+fn a(i : i32) -> S {
+ return S();
+}
+
+fn f() {
+ var b = 1;
+ let tint_symbol : S = a(0);
+ let tint_symbol_1 : i32 = b;
+ let tint_symbol_2 : S = a(1);
+ let r = ((tint_symbol.v + tint_symbol_1) + tint_symbol_2.v);
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_InUnary) {
+ auto* src = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var b = 1;
+ let r = -a(0) + -(b + a(1));
+}
+)";
+
+ auto* expect = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var b = 1;
+ let tint_symbol : i32 = -(a(0));
+ let tint_symbol_1 : i32 = b;
+ let tint_symbol_2 : i32 = a(1);
+ let r = (tint_symbol + -((tint_symbol_1 + tint_symbol_2)));
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_InBitcast) {
+ auto* src = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var b = 1;
+ let r = bitcast<u32>(a(0) + a(1));
+}
+)";
+
+ auto* expect = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var b = 1;
+ let tint_symbol : i32 = a(0);
+ let tint_symbol_1 : i32 = a(1);
+ let r = bitcast<u32>((tint_symbol + tint_symbol_1));
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_InForLoopInit) {
+ auto* src = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var b = 1;
+ for(var r = a(0) + b; ; ) {
+ var marker = 0;
+ break;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var b = 1;
+ {
+ let tint_symbol : i32 = a(0);
+ var r = (tint_symbol + b);
+ loop {
+ {
+ var marker = 0;
+ break;
+ }
+ }
+ }
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_InForLoopCond) {
+ auto* src = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var b = 1;
+ for(; a(0) + b > 0;) {
+ var marker = 0;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var b = 1;
+ loop {
+ let tint_symbol : i32 = a(0);
+ if (!(((tint_symbol + b) > 0))) {
+ break;
+ }
+ {
+ var marker = 0;
+ }
+ }
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_InForLoopCont) {
+ auto* src = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var b = 1;
+ var r = 0;
+ for(; ; r = a(0) + b) {
+ var marker = 0;
+ break;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var b = 1;
+ var r = 0;
+ loop {
+ {
+ var marker = 0;
+ break;
+ }
+
+ continuing {
+ let tint_symbol : i32 = a(0);
+ r = (tint_symbol + b);
+ }
+ }
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_InForLoopInitCondCont) {
+ auto* src = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var b = 1;
+ var c = 2;
+ var d = 3;
+ var r = 0;
+ for(var r = a(0) + b; a(1) + c > 0; r = a(2) + d) {
+ var marker = 0;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var b = 1;
+ var c = 2;
+ var d = 3;
+ var r = 0;
+ {
+ let tint_symbol : i32 = a(0);
+ var r = (tint_symbol + b);
+ loop {
+ let tint_symbol_1 : i32 = a(1);
+ if (!(((tint_symbol_1 + c) > 0))) {
+ break;
+ }
+ {
+ var marker = 0;
+ }
+
+ continuing {
+ let tint_symbol_2 : i32 = a(2);
+ r = (tint_symbol_2 + d);
+ }
+ }
+ }
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_InWhileCond) {
+ auto* src = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var b = 1;
+ while(a(0) + b > 0) {
+ var marker = 0;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var b = 1;
+ loop {
+ let tint_symbol : i32 = a(0);
+ if (!(((tint_symbol + b) > 0))) {
+ break;
+ }
+ {
+ var marker = 0;
+ }
+ }
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_InElseIf) {
+ auto* src = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var b = 1;
+ if (true) {
+ var marker = 0;
+ } else if (a(0) + b > 0) {
+ var marker = 1;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var b = 1;
+ if (true) {
+ var marker = 0;
+ } else {
+ let tint_symbol : i32 = a(0);
+ if (((tint_symbol + b) > 0)) {
+ var marker = 1;
+ }
+ }
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_InElseIfChain) {
+ auto* src = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var b = 1;
+ if (true) {
+ var marker = 0;
+ } else if (true) {
+ var marker = 1;
+ } else if (a(0) + b > 0) {
+ var marker = 2;
+ } else if (a(1) + a(2) > 0) {
+ var marker = 3;
+ } else if (true) {
+ var marker = 4;
+ } else {
+ var marker = 5;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var b = 1;
+ if (true) {
+ var marker = 0;
+ } else if (true) {
+ var marker = 1;
+ } else {
+ let tint_symbol : i32 = a(0);
+ if (((tint_symbol + b) > 0)) {
+ var marker = 2;
+ } else {
+ let tint_symbol_1 : i32 = a(1);
+ let tint_symbol_2 : i32 = a(2);
+ if (((tint_symbol_1 + tint_symbol_2) > 0)) {
+ var marker = 3;
+ } else if (true) {
+ var marker = 4;
+ } else {
+ var marker = 5;
+ }
+ }
+ }
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_InReturn) {
+ auto* src = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() -> i32 {
+ var b = 1;
+ return b + a(0);
+}
+)";
+
+ auto* expect = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() -> i32 {
+ var b = 1;
+ let tint_symbol : i32 = b;
+ let tint_symbol_1 : i32 = a(0);
+ return (tint_symbol + tint_symbol_1);
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_InSwitch) {
+ auto* src = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var b = 1;
+ switch (b + a(0)) {
+ default: {
+ }
+ }
+}
+)";
+
+ auto* expect = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var b = 1;
+ let tint_symbol : i32 = b;
+ let tint_symbol_1 : i32 = a(0);
+ switch((tint_symbol + tint_symbol_1)) {
+ default: {
+ }
+ }
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_LeftSE) {
+ auto* src = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ var b = true;
+ let r = a(0) && b;
+}
+)";
+
+ auto* expect = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ var b = true;
+ var tint_symbol = a(0);
+ if (tint_symbol) {
+ tint_symbol = b;
+ }
+ let r = tint_symbol;
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_RightSE) {
+ auto* src = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ var b = true;
+ let r = b && a(0);
+}
+)";
+
+ auto* expect = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ var b = true;
+ var tint_symbol = b;
+ if (tint_symbol) {
+ tint_symbol = a(0);
+ }
+ let r = tint_symbol;
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_BothSE) {
+ auto* src = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ let r = a(0) && a(1);
+}
+)";
+
+ auto* expect = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ var tint_symbol = a(0);
+ if (tint_symbol) {
+ tint_symbol = a(1);
+ }
+ let r = tint_symbol;
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_LeftmostSE) {
+ auto* src = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ var b = true;
+ var c = true;
+ var d = true;
+ let r = a(0) && b && c && d;
+}
+)";
+
+ auto* expect = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ var b = true;
+ var c = true;
+ var d = true;
+ var tint_symbol_2 = a(0);
+ if (tint_symbol_2) {
+ tint_symbol_2 = b;
+ }
+ var tint_symbol_1 = tint_symbol_2;
+ if (tint_symbol_1) {
+ tint_symbol_1 = c;
+ }
+ var tint_symbol = tint_symbol_1;
+ if (tint_symbol) {
+ tint_symbol = d;
+ }
+ let r = tint_symbol;
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_RightmostSE) {
+ // `b && c && d && a(0)`: the call-free prefix stays one expression; only the last `&&` splits.
+ const char* input = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ var b = true;
+ var c = true;
+ var d = true;
+ let r = b && c && d && a(0);
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ var b = true;
+ var c = true;
+ var d = true;
+ var tint_symbol = ((b && c) && d);
+ if (tint_symbol) {
+ tint_symbol = a(0);
+ }
+ let r = tint_symbol;
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_MiddleSE) {
+ // Call in the middle: `b && c` stays intact, each `&&` from the call onwards gets a var/if.
+ const char* input = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ var b = true;
+ var c = true;
+ var d = true;
+ let r = b && c && a(0) && c && d;
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ var b = true;
+ var c = true;
+ var d = true;
+ var tint_symbol_2 = (b && c);
+ if (tint_symbol_2) {
+ tint_symbol_2 = a(0);
+ }
+ var tint_symbol_1 = tint_symbol_2;
+ if (tint_symbol_1) {
+ tint_symbol_1 = c;
+ }
+ var tint_symbol = tint_symbol_1;
+ if (tint_symbol) {
+ tint_symbol = d;
+ }
+ let r = tint_symbol;
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_Constants_NoRecvSE) {
+ // Literal operands mixed with calls: each `&&` in the chain still gets its own var/if pair.
+ const char* input = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ let r = true && a(0) && false && a(1);
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ var tint_symbol_2 = true;
+ if (tint_symbol_2) {
+ tint_symbol_2 = a(0);
+ }
+ var tint_symbol_1 = tint_symbol_2;
+ if (tint_symbol_1) {
+ tint_symbol_1 = false;
+ }
+ var tint_symbol = tint_symbol_1;
+ if (tint_symbol) {
+ tint_symbol = a(1);
+ }
+ let r = tint_symbol;
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_Constants_RecvSE) {
+ // Chain led by a variable: `b && true` stays whole, later operands each get a var/if pair.
+ const char* input = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ var b = true;
+ let r = b && true && a(0) && false && a(1);
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ var b = true;
+ var tint_symbol_2 = (b && true);
+ if (tint_symbol_2) {
+ tint_symbol_2 = a(0);
+ }
+ var tint_symbol_1 = tint_symbol_2;
+ if (tint_symbol_1) {
+ tint_symbol_1 = false;
+ }
+ var tint_symbol = tint_symbol_1;
+ if (tint_symbol) {
+ tint_symbol = a(1);
+ }
+ let r = tint_symbol;
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_Constants_ConstAndSEAndVar) {
+ // `true && a(0) && b`: the literal seeds the first var, then the call and `b` are guarded.
+ const char* input = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn main() {
+ var b = true;
+ var c = true;
+ let r = true && a(0) && b;
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn main() {
+ var b = true;
+ var c = true;
+ var tint_symbol_1 = true;
+ if (tint_symbol_1) {
+ tint_symbol_1 = a(0);
+ }
+ var tint_symbol = tint_symbol_1;
+ if (tint_symbol) {
+ tint_symbol = b;
+ }
+ let r = tint_symbol;
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_Constants_VarAndSEAndConst) {
+ // `b && a(0) && true`: the trailing literal is still assigned under its own `if`.
+ const char* input = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn main() {
+ var b = true;
+ let r = b && a(0) && true;
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn main() {
+ var b = true;
+ var tint_symbol_1 = b;
+ if (tint_symbol_1) {
+ tint_symbol_1 = a(0);
+ }
+ var tint_symbol = tint_symbol_1;
+ if (tint_symbol) {
+ tint_symbol = true;
+ }
+ let r = tint_symbol;
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_Constants_SEAndVarAndConstAndVar) {
+ // `a(0) && b && true && c`: three var/if pairs, one per `&&`, evaluated left to right.
+ const char* input = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn main() {
+ var b = true;
+ var c = true;
+ let r = a(0) && b && true && c;
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn main() {
+ var b = true;
+ var c = true;
+ var tint_symbol_2 = a(0);
+ if (tint_symbol_2) {
+ tint_symbol_2 = b;
+ }
+ var tint_symbol_1 = tint_symbol_2;
+ if (tint_symbol_1) {
+ tint_symbol_1 = true;
+ }
+ var tint_symbol = tint_symbol_1;
+ if (tint_symbol) {
+ tint_symbol = c;
+ }
+ let r = tint_symbol;
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_MixedSE) {
+ // Mix of `&&` and `||`: `||` guards with `if (!(...))` and nests the lowered `&&` chain.
+ const char* input = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ var b = true;
+ var c = true;
+ var d = true;
+ let r = (b && a(0)) || (c && a(1) && c && d) || a(2);
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ var b = true;
+ var c = true;
+ var d = true;
+ var tint_symbol_2 = b;
+ if (tint_symbol_2) {
+ tint_symbol_2 = a(0);
+ }
+ var tint_symbol_1 = tint_symbol_2;
+ if (!(tint_symbol_1)) {
+ var tint_symbol_5 = c;
+ if (tint_symbol_5) {
+ tint_symbol_5 = a(1);
+ }
+ var tint_symbol_4 = tint_symbol_5;
+ if (tint_symbol_4) {
+ tint_symbol_4 = c;
+ }
+ var tint_symbol_3 = tint_symbol_4;
+ if (tint_symbol_3) {
+ tint_symbol_3 = d;
+ }
+ tint_symbol_1 = tint_symbol_3;
+ }
+ var tint_symbol = tint_symbol_1;
+ if (!(tint_symbol)) {
+ tint_symbol = a(2);
+ }
+ let r = tint_symbol;
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_NestedAnds) {
+ // Right-nested `&&`: each parenthesised operand is lowered inside the previous `if` body.
+ const char* input = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ let r = a(0) && (a(1) && (a(2) && (a(3) && (a(4) && a(5)))));
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ var tint_symbol = a(0);
+ if (tint_symbol) {
+ var tint_symbol_1 = a(1);
+ if (tint_symbol_1) {
+ var tint_symbol_2 = a(2);
+ if (tint_symbol_2) {
+ var tint_symbol_3 = a(3);
+ if (tint_symbol_3) {
+ var tint_symbol_4 = a(4);
+ if (tint_symbol_4) {
+ tint_symbol_4 = a(5);
+ }
+ tint_symbol_3 = tint_symbol_4;
+ }
+ tint_symbol_2 = tint_symbol_3;
+ }
+ tint_symbol_1 = tint_symbol_2;
+ }
+ tint_symbol = tint_symbol_1;
+ }
+ let r = tint_symbol;
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_NestedOrs) {
+ // Right-nested `||`: each parenthesised operand is lowered inside a negated `if` guard.
+ const char* input = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ let r = a(0) || (a(1) || (a(2) || (a(3) || (a(4) || a(5)))));
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ var tint_symbol = a(0);
+ if (!(tint_symbol)) {
+ var tint_symbol_1 = a(1);
+ if (!(tint_symbol_1)) {
+ var tint_symbol_2 = a(2);
+ if (!(tint_symbol_2)) {
+ var tint_symbol_3 = a(3);
+ if (!(tint_symbol_3)) {
+ var tint_symbol_4 = a(4);
+ if (!(tint_symbol_4)) {
+ tint_symbol_4 = a(5);
+ }
+ tint_symbol_3 = tint_symbol_4;
+ }
+ tint_symbol_2 = tint_symbol_3;
+ }
+ tint_symbol_1 = tint_symbol_2;
+ }
+ tint_symbol = tint_symbol_1;
+ }
+ let r = tint_symbol;
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_MultipleStatements) {
+ // Two separate statements are lowered independently, with distinct temporaries.
+ const char* input = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ var b = true;
+ let r1 = b && a(0);
+ let r2 = a(1) || b;
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ var b = true;
+ var tint_symbol = b;
+ if (tint_symbol) {
+ tint_symbol = a(0);
+ }
+ let r1 = tint_symbol;
+ var tint_symbol_1 = a(1);
+ if (!(tint_symbol_1)) {
+ tint_symbol_1 = b;
+ }
+ let r2 = tint_symbol_1;
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_InCall) {
+ // Logical expression as a call argument: lowered first, then the var is passed to `g`.
+ const char* input = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn g(v : bool) -> bool {
+ return v;
+}
+
+fn f() {
+ var b = true;
+ g(b && a(1));
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn g(v : bool) -> bool {
+ return v;
+}
+
+fn f() {
+ var b = true;
+ var tint_symbol = b;
+ if (tint_symbol) {
+ tint_symbol = a(1);
+ }
+ g(tint_symbol);
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_InTypeInit) {
+ // Logical expressions wrapped in `bool(...)`: inner chains are lowered before the wrap.
+ const char* input = R"(
+
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ var b = true;
+ let r = (bool(a(0)) && bool(a(1) && b)) || bool(a(2) && a(3));
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ var b = true;
+ let tint_symbol_2 : bool = a(0);
+ var tint_symbol_1 = bool(tint_symbol_2);
+ if (tint_symbol_1) {
+ var tint_symbol_3 = a(1);
+ if (tint_symbol_3) {
+ tint_symbol_3 = b;
+ }
+ tint_symbol_1 = bool(tint_symbol_3);
+ }
+ var tint_symbol = tint_symbol_1;
+ if (!(tint_symbol)) {
+ var tint_symbol_4 = a(2);
+ if (tint_symbol_4) {
+ tint_symbol_4 = a(3);
+ }
+ tint_symbol = bool(tint_symbol_4);
+ }
+ let r = tint_symbol;
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_InTypeConversion) {
+ // `bool(a(n))` with `a` returning i32: each call is hoisted to a typed `let` first.
+ const char* input = R"(
+
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = true;
+ let r = (bool(a(0)) && bool(a(1))) || b;
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = true;
+ let tint_symbol_2 : i32 = a(0);
+ var tint_symbol_1 = bool(tint_symbol_2);
+ if (tint_symbol_1) {
+ let tint_symbol_3 : i32 = a(1);
+ tint_symbol_1 = bool(tint_symbol_3);
+ }
+ var tint_symbol = tint_symbol_1;
+ if (!(tint_symbol)) {
+ tint_symbol = b;
+ }
+ let r = tint_symbol;
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+// Logical binary expressions whose operands contain non-logical binary
+// expressions (here `==`) must be processed as well.
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_OfNonLogical) {
+ const char* input = R"(
+
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = 1;
+ let r = bool(a(0) == b) && bool(a(1) == b);
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = 1;
+ let tint_symbol_1 : i32 = a(0);
+ var tint_symbol = bool((tint_symbol_1 == b));
+ if (tint_symbol) {
+ let tint_symbol_2 : i32 = a(1);
+ tint_symbol = bool((tint_symbol_2 == b));
+ }
+ let r = tint_symbol;
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_InIntrinsic) {
+ // Logical expressions inside `all(...)`: inner chains are lowered before `all` is applied.
+ const char* input = R"(
+
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ var b = true;
+ let r = (all(a(0)) && all(a(1) && b)) || all(a(2) && a(3));
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ var b = true;
+ let tint_symbol_2 : bool = a(0);
+ var tint_symbol_1 = all(tint_symbol_2);
+ if (tint_symbol_1) {
+ var tint_symbol_3 = a(1);
+ if (tint_symbol_3) {
+ tint_symbol_3 = b;
+ }
+ tint_symbol_1 = all(tint_symbol_3);
+ }
+ var tint_symbol = tint_symbol_1;
+ if (!(tint_symbol)) {
+ var tint_symbol_4 = a(2);
+ if (tint_symbol_4) {
+ tint_symbol_4 = a(3);
+ }
+ tint_symbol = all(tint_symbol_4);
+ }
+ let r = tint_symbol;
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_InMemberAccessor) {
+ // Operands are `.v` accesses on call results: each call is hoisted to a `let` of type S.
+ const char* input = R"(
+
+struct S {
+ v : bool,
+}
+
+fn a(i : i32) -> S {
+ return S();
+}
+
+fn f() {
+ var b = true;
+ let r = a(0).v && b && a(1).v;
+}
+)";
+
+ const char* want = R"(
+struct S {
+ v : bool,
+}
+
+fn a(i : i32) -> S {
+ return S();
+}
+
+fn f() {
+ var b = true;
+ let tint_symbol_2 : S = a(0);
+ var tint_symbol_1 = tint_symbol_2.v;
+ if (tint_symbol_1) {
+ tint_symbol_1 = b;
+ }
+ var tint_symbol = tint_symbol_1;
+ if (tint_symbol) {
+ let tint_symbol_3 : S = a(1);
+ tint_symbol = tint_symbol_3.v;
+ }
+ let r = tint_symbol;
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_InUnary) {
+ // `!(b || a(1))`: the `||` is lowered first and the `!` is applied to the resulting var.
+ const char* input = R"(
+
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ var b = true;
+ let r = !(b || a(1));
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ var b = true;
+ var tint_symbol = b;
+ if (!(tint_symbol)) {
+ tint_symbol = a(1);
+ }
+ let r = !(tint_symbol);
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_InBitcast) {
+ // `&&` nested in `bitcast<u32>(i32(...))`: only the logical part is pulled out.
+ const char* input = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ var b = true;
+ let r = bitcast<u32>(i32(a(0) && a(1)));
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ var b = true;
+ var tint_symbol = a(0);
+ if (tint_symbol) {
+ tint_symbol = a(1);
+ }
+ let r = bitcast<u32>(i32(tint_symbol));
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_InForLoopInit) {
+ // `&&` in a for-loop initializer: expected output is a block holding the decls and a `loop`.
+ const char* input = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ var b = true;
+ for(var r = a(0) && b; ; ) {
+ var marker = 0;
+ break;
+ }
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ var b = true;
+ {
+ var tint_symbol = a(0);
+ if (tint_symbol) {
+ tint_symbol = b;
+ }
+ var r = tint_symbol;
+ loop {
+ {
+ var marker = 0;
+ break;
+ }
+ }
+ }
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_InForLoopCond) {
+ // `&&` in a for-loop condition: expected output is a `loop` that re-evaluates it and breaks.
+ const char* input = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ var b = true;
+ for(; a(0) && b;) {
+ var marker = 0;
+ }
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ var b = true;
+ loop {
+ var tint_symbol = a(0);
+ if (tint_symbol) {
+ tint_symbol = b;
+ }
+ if (!(tint_symbol)) {
+ break;
+ }
+ {
+ var marker = 0;
+ }
+ }
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_InForLoopCont) {
+ // `&&` in a for-loop continuing statement: the var/if pair lands in a `continuing` block.
+ const char* input = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ var b = true;
+ var r = true;
+ for(; ; r = a(0) && b) {
+ var marker = 0;
+ break;
+ }
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ var b = true;
+ var r = true;
+ loop {
+ {
+ var marker = 0;
+ break;
+ }
+
+ continuing {
+ var tint_symbol = a(0);
+ if (tint_symbol) {
+ tint_symbol = b;
+ }
+ r = tint_symbol;
+ }
+ }
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_InForLoopInitCondCont) {
+ // `&&` in all three for-loop clauses: init, break check and `continuing` are each lowered.
+ const char* input = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ var b = true;
+ var c = true;
+ var d = true;
+ var r = true;
+ for(var r = a(0) && b; a(1) && c; r = a(2) && d) {
+ var marker = 0;
+ }
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ var b = true;
+ var c = true;
+ var d = true;
+ var r = true;
+ {
+ var tint_symbol = a(0);
+ if (tint_symbol) {
+ tint_symbol = b;
+ }
+ var r = tint_symbol;
+ loop {
+ var tint_symbol_1 = a(1);
+ if (tint_symbol_1) {
+ tint_symbol_1 = c;
+ }
+ if (!(tint_symbol_1)) {
+ break;
+ }
+ {
+ var marker = 0;
+ }
+
+ continuing {
+ var tint_symbol_2 = a(2);
+ if (tint_symbol_2) {
+ tint_symbol_2 = d;
+ }
+ r = tint_symbol_2;
+ }
+ }
+ }
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_InWhileCond) {
+ // `&&` in a `while` condition: expected output is a `loop` that re-evaluates it and breaks.
+ const char* input = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ var b = true;
+ while(a(0) && b) {
+ var marker = 0;
+ }
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ var b = true;
+ loop {
+ var tint_symbol = a(0);
+ if (tint_symbol) {
+ tint_symbol = b;
+ }
+ if (!(tint_symbol)) {
+ break;
+ }
+ {
+ var marker = 0;
+ }
+ }
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_InElseIf) {
+ // `&&` in an `else if` condition: the `else if` becomes `else { <decls> if (...) { } }`.
+ const char* input = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ var b = true;
+ if (true) {
+ var marker = 0;
+ } else if (a(0) && b) {
+ var marker = 1;
+ }
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ var b = true;
+ if (true) {
+ var marker = 0;
+ } else {
+ var tint_symbol = a(0);
+ if (tint_symbol) {
+ tint_symbol = b;
+ }
+ if (tint_symbol) {
+ var marker = 1;
+ }
+ }
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_InElseIfChain) {
+ // Only the `else if` clauses containing calls are split open; the others stay chained.
+ const char* input = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ var b = true;
+ if (true) {
+ var marker = 0;
+ } else if (true) {
+ var marker = 1;
+ } else if (a(0) && b) {
+ var marker = 2;
+ } else if (a(1) && a(2)) {
+ var marker = 3;
+ } else if (true) {
+ var marker = 4;
+ } else {
+ var marker = 5;
+ }
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ var b = true;
+ if (true) {
+ var marker = 0;
+ } else if (true) {
+ var marker = 1;
+ } else {
+ var tint_symbol = a(0);
+ if (tint_symbol) {
+ tint_symbol = b;
+ }
+ if (tint_symbol) {
+ var marker = 2;
+ } else {
+ var tint_symbol_1 = a(1);
+ if (tint_symbol_1) {
+ tint_symbol_1 = a(2);
+ }
+ if (tint_symbol_1) {
+ var marker = 3;
+ } else if (true) {
+ var marker = 4;
+ } else {
+ var marker = 5;
+ }
+ }
+ }
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Call_NoSE) {
+ // No call among the arguments of `g`, so the module is expected to come back unchanged.
+ const char* input = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn g(a : i32, b : i32, c : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = 1;
+ var c = 1;
+ let r = g(b, c, 3);
+}
+)";
+
+ const char* want = input;
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Call_OneSE) {
+ // Only the first argument is a call; it alone is hoisted into a `let`.
+ const char* input = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn g(a : i32, b : i32, c : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = 1;
+ let r = g(a(0), b, 3);
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn g(a : i32, b : i32, c : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = 1;
+ let tint_symbol : i32 = a(0);
+ let r = g(tint_symbol, b, 3);
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Call_AllSE) {
+ // All three arguments are calls; each is hoisted to a `let`, in argument order.
+ const char* input = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn g(a : i32, b : i32, c : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ let r = g(a(0), a(1), a(2));
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn g(a : i32, b : i32, c : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ let tint_symbol : i32 = a(0);
+ let tint_symbol_1 : i32 = a(1);
+ let tint_symbol_2 : i32 = a(2);
+ let r = g(tint_symbol, tint_symbol_1, tint_symbol_2);
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Call_MiddleNotSE) {
+ // `b` sits between two call arguments and is expected to be hoisted to a `let` as well.
+ const char* input = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn g(a : i32, b : i32, c : i32) -> i32 {
+ return 1;
+}
+
+
+fn f() {
+ var b = 1;
+ let r = g(a(0), b, a(1));
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn g(a : i32, b : i32, c : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = 1;
+ let tint_symbol : i32 = a(0);
+ let tint_symbol_1 : i32 = b;
+ let tint_symbol_2 : i32 = a(1);
+ let r = g(tint_symbol, tint_symbol_1, tint_symbol_2);
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Call_InBinary) {
+ // Calls nested in `+`: operands read before a later call are hoisted, trailing `d` is not.
+ const char* input = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn g(i : i32, j : i32, k : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = 0;
+ var c = 0;
+ var d = 0;
+ let r = b + g(c, a(0), d) + a(1);
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn g(i : i32, j : i32, k : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = 0;
+ var c = 0;
+ var d = 0;
+ let tint_symbol : i32 = b;
+ let tint_symbol_1 : i32 = c;
+ let tint_symbol_2 : i32 = a(0);
+ let tint_symbol_3 : i32 = g(tint_symbol_1, tint_symbol_2, d);
+ let tint_symbol_4 : i32 = a(1);
+ let r = ((tint_symbol + tint_symbol_3) + tint_symbol_4);
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, IndexAccessor_2D_LeftSE) {
+ // `b[a(0)][c]`: the call in the first index is hoisted; the trailing `c` index is left alone.
+ const char* input = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = array<array<i32, 10>, 10>();
+ var c = 1;
+ var r = b[a(0)][c];
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = array<array<i32, 10>, 10>();
+ var c = 1;
+ let tint_symbol : i32 = a(0);
+ var r = b[tint_symbol][c];
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+// Exercises `b[c][a(0)]`: only the second index is a call, but `c` is read before that call,
+// so both indices are expected to be hoisted into typed `let`s ahead of the `var r` statement.
+TEST_F(PromoteSideEffectsToDeclTest, IndexAccessor_2D_RightSE) {
+ auto* src = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = array<array<i32, 10>, 10>();
+ var c = 1;
+ var r = b[c][a(0)];
+}
+)";
+
+ auto* expect = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = array<array<i32, 10>, 10>();
+ var c = 1;
+ let tint_symbol : i32 = c;
+ let tint_symbol_1 : i32 = a(0);
+ var r = b[tint_symbol][tint_symbol_1];
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, IndexAccessor_2D_BothSE) {
+ // `b[a(0)][a(1)]`: both index calls are hoisted to `let`s, first index first.
+ const char* input = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = array<array<i32, 10>, 10>();
+ var r = b[a(0)][a(1)];
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = array<array<i32, 10>, 10>();
+ let tint_symbol : i32 = a(0);
+ let tint_symbol_1 : i32 = a(1);
+ var r = b[tint_symbol][tint_symbol_1];
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Assignment_ToPhony) {
+ // `_ = a(0);` is expected to pass through the transform unchanged.
+ const char* input = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ _ = a(0);
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ _ = a(0);
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Assignment_ToArray1D) {
+ // `b[a(0)] = a(1)`: the index call is hoisted; the right-hand side call stays in place.
+ const char* input = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = array<i32, 10>();
+ b[a(0)] = a(1);
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = array<i32, 10>();
+ let tint_symbol : i32 = a(0);
+ b[tint_symbol] = a(1);
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Assignment_ToArray2D) {
+ // `b[a(0)][a(1)] = a(2)`: both index calls are hoisted; the right-hand side stays in place.
+ const char* input = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = array<array<i32, 10>, 10>();
+ b[a(0)][a(1)] = a(2);
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = array<array<i32, 10>, 10>();
+ let tint_symbol : i32 = a(0);
+ let tint_symbol_1 : i32 = a(1);
+ b[tint_symbol][tint_symbol_1] = a(2);
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Assignment_ToArray3D) {
+ // Three index calls on the left-hand side are hoisted in order; the RHS call stays in place.
+ const char* input = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = array<array<array<i32, 10>, 10>, 10>();
+ b[a(0)][a(1)][a(2)] = a(3);
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = array<array<array<i32, 10>, 10>, 10>();
+ let tint_symbol : i32 = a(0);
+ let tint_symbol_1 : i32 = a(1);
+ let tint_symbol_2 : i32 = a(2);
+ b[tint_symbol][tint_symbol_1][tint_symbol_2] = a(3);
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Assignment_ToArray_FromArray) {
+ // Index calls on both sides of `=` are hoisted, the left-hand side's call first.
+ const char* input = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = array<i32, 3>();
+ var d = array<array<i32, 3>, 3>();
+ var a_1 = 0;
+ b[a(2)] = d[a(0)][a_1];
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = array<i32, 3>();
+ var d = array<array<i32, 3>, 3>();
+ var a_1 = 0;
+ let tint_symbol : i32 = a(2);
+ let tint_symbol_1 : i32 = a(0);
+ b[tint_symbol] = d[tint_symbol_1][a_1];
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Assignment_ToVec_BothSE) {
+ // Vector element store `b[a(0)] = a(1)`: only the index call is hoisted.
+ const char* input = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = vec3<i32>();
+ b[a(0)] = a(1);
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = vec3<i32>();
+ let tint_symbol : i32 = a(0);
+ b[tint_symbol] = a(1);
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Assignment_ToVec_LeftSE) {
+ // Vector element store with a call only in the index: the index is hoisted to a `let`.
+ const char* input = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = vec3<i32>();
+ var c = 0;
+ b[a(0)] = c;
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = vec3<i32>();
+ var c = 0;
+ let tint_symbol : i32 = a(0);
+ b[tint_symbol] = c;
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Assignment_ToVec_RightSE) {
+ // Vector element store with a call only on the RHS: the variable index `c` is hoisted.
+ const char* input = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = vec3<i32>();
+ var c = 0;
+ b[c] = a(0);
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ var b = vec3<i32>();
+ var c = 0;
+ let tint_symbol : i32 = c;
+ b[tint_symbol] = a(0);
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, TypeInitializer_Struct) {
+ // Struct initializer `S(a(0), a(1), a(2))`: every argument call is hoisted, in order.
+ const char* input = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+struct S {
+ x : i32,
+ y : i32,
+ z : i32,
+}
+
+fn f() {
+ var r = S(a(0), a(1), a(2));
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+struct S {
+ x : i32,
+ y : i32,
+ z : i32,
+}
+
+fn f() {
+ let tint_symbol : i32 = a(0);
+ let tint_symbol_1 : i32 = a(1);
+ let tint_symbol_2 : i32 = a(2);
+ var r = S(tint_symbol, tint_symbol_1, tint_symbol_2);
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, TypeInitializer_Array1D) {
+ // Array initializer with three call arguments: every argument is hoisted, in order.
+ const char* input = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ var r = array<i32, 3>(a(0), a(1), a(2));
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ let tint_symbol : i32 = a(0);
+ let tint_symbol_1 : i32 = a(1);
+ let tint_symbol_2 : i32 = a(2);
+ var r = array<i32, 3>(tint_symbol, tint_symbol_1, tint_symbol_2);
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, TypeInitializer_Array2D) {
+ // Nested array initializers: inner calls are hoisted first, then each inner array value.
+ const char* input = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ var r = array<array<i32, 2>, 2>(array<i32, 2>(a(0), a(1)), array<i32, 2>(a(2), a(3)));
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> i32 {
+ return 1;
+}
+
+fn f() {
+ let tint_symbol : i32 = a(0);
+ let tint_symbol_1 : i32 = a(1);
+ let tint_symbol_2 : array<i32, 2u> = array<i32, 2>(tint_symbol, tint_symbol_1);
+ let tint_symbol_3 : i32 = a(2);
+ let tint_symbol_4 : i32 = a(3);
+ let tint_symbol_5 : array<i32, 2u> = array<i32, 2>(tint_symbol_3, tint_symbol_4);
+ var r = array<array<i32, 2>, 2>(tint_symbol_2, tint_symbol_5);
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, MemberAccessor_Vec) {
+ // `a(0).x + a(1).y` on vec3 results: calls are hoisted, member accesses stay in the sum.
+ const char* input = R"(
+fn a(i : i32) -> vec3<i32> {
+ return vec3<i32>();
+}
+
+fn f() {
+ var r = a(0).x + a(1).y;
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> vec3<i32> {
+ return vec3<i32>();
+}
+
+fn f() {
+ let tint_symbol : vec3<i32> = a(0);
+ let tint_symbol_1 : vec3<i32> = a(1);
+ var r = (tint_symbol.x + tint_symbol_1.y);
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, MemberAccessor_Struct) {
+ // `a(0).x + a(1).y` on struct results: calls are hoisted, member accesses stay in the sum.
+ const char* input = R"(
+struct S {
+ x : i32,
+ y : i32,
+}
+
+fn a(i : i32) -> S {
+ return S();
+}
+
+fn f() {
+ var r = a(0).x + a(1).y;
+}
+)";
+
+ const char* want = R"(
+struct S {
+ x : i32,
+ y : i32,
+}
+
+fn a(i : i32) -> S {
+ return S();
+}
+
+fn f() {
+ let tint_symbol : S = a(0);
+ let tint_symbol_1 : S = a(1);
+ var r = (tint_symbol.x + tint_symbol_1.y);
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, MemberAccessor_Struct_Mixed) {
+ // Calls, variables and an indexed member in one sum: trailing `l` and `m` are not hoisted.
+ const char* input = R"(
+struct S {
+ x : i32,
+ y : i32,
+ arr : array<i32, 10>,
+}
+
+fn a(i : i32) -> S {
+ return S();
+}
+
+fn b(i : i32) -> i32 {
+ return 0;
+}
+
+fn f() {
+ var i = 0;
+ var j = 0;
+ var k = 0;
+ var l = 0;
+ var m = 0;
+ var r = a(0).x + i + a(1).y + j + a(2).arr[k + b(3) + l] + m;
+}
+)";
+
+ const char* want = R"(
+struct S {
+ x : i32,
+ y : i32,
+ arr : array<i32, 10>,
+}
+
+fn a(i : i32) -> S {
+ return S();
+}
+
+fn b(i : i32) -> i32 {
+ return 0;
+}
+
+fn f() {
+ var i = 0;
+ var j = 0;
+ var k = 0;
+ var l = 0;
+ var m = 0;
+ let tint_symbol : S = a(0);
+ let tint_symbol_1 : i32 = i;
+ let tint_symbol_2 : S = a(1);
+ let tint_symbol_3 : i32 = j;
+ let tint_symbol_4 : S = a(2);
+ let tint_symbol_5 : i32 = k;
+ let tint_symbol_6 : i32 = b(3);
+ var r = (((((tint_symbol.x + tint_symbol_1) + tint_symbol_2.y) + tint_symbol_3) + tint_symbol_4.arr[((tint_symbol_5 + tint_symbol_6) + l)]) + m);
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, IndexAccessor_Plus_SE) {
+ // `v[0] + a(0)`: the array read on the left is hoisted ahead of the call.
+ const char* input = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var v = array<i32, 10>();
+ let r = v[0] + a(0);
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var v = array<i32, 10>();
+ let tint_symbol : i32 = v[0];
+ let tint_symbol_1 : i32 = a(0);
+ let r = (tint_symbol + tint_symbol_1);
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, IndexAccessor_Of_SE) {
+ // `v[a(0)]`: the index call is hoisted to a `let` and the access uses that `let`.
+ const char* input = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var v = array<i32, 10>();
+ let r = v[a(0)];
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var v = array<i32, 10>();
+ let tint_symbol : i32 = a(0);
+ let r = v[tint_symbol];
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, IndexAccessor2_Of_LeftSE) {
+ // `v[a(0)][0]`: the call in the first index is hoisted; the literal index is untouched.
+ const char* input = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var v = array<array<i32, 10>, 10>();
+ let r = v[a(0)][0];
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var v = array<array<i32, 10>, 10>();
+ let tint_symbol : i32 = a(0);
+ let r = v[tint_symbol][0];
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, IndexAccessor2_Of_RightSE) {
+ // `v[0][a(0)]`: the call in the second index is hoisted; the literal index is untouched.
+ const char* input = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var v = array<array<i32, 10>, 10>();
+ let r = v[0][a(0)];
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var v = array<array<i32, 10>, 10>();
+ let tint_symbol : i32 = a(0);
+ let r = v[0][tint_symbol];
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, IndexAccessor2_Of_SEAndVar) {
+ // `v[a(0)][b]`: the call in the first index is hoisted; the trailing `b` index is not.
+ const char* input = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var v = array<array<i32, 10>, 10>();
+ var b : i32;
+ let r = v[a(0)][b];
+}
+)";
+
+ const char* want = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var v = array<array<i32, 10>, 10>();
+ var b : i32;
+ let tint_symbol : i32 = a(0);
+ let r = v[tint_symbol][b];
+}
+)";
+
+ auto transformed = Run<PromoteSideEffectsToDecl>(input);
+ EXPECT_EQ(want, str(transformed));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, IndexAccessor2_Of_VarAndSE) {
+ auto* src = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var v = array<array<i32, 10>, 10>();
+ var b : i32;
+ let r = v[b][a(0)];
+}
+)";
+
+ auto* expect = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var v = array<array<i32, 10>, 10>();
+ var b : i32;
+ let tint_symbol : i32 = b;
+ let tint_symbol_1 : i32 = a(0);
+ let r = v[tint_symbol][tint_symbol_1];
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, IndexAccessorOfVar_Plus_SE) {
+ auto* src = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var v = array<i32, 10>();
+ var b = 0;
+ let r = v[b] + a(0);
+}
+)";
+
+ auto* expect = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var v = array<i32, 10>();
+ var b = 0;
+ let tint_symbol : i32 = b;
+ let tint_symbol_1 : i32 = v[tint_symbol];
+ let tint_symbol_2 : i32 = a(0);
+ let r = (tint_symbol_1 + tint_symbol_2);
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, IndexAccessor_Plus_IndexAccessorOfSE) {
+ auto* src = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var v = array<i32, 10>();
+ let r = v[0] + v[a(0)];
+}
+)";
+
+ auto* expect = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var v = array<i32, 10>();
+ let tint_symbol : i32 = v[0];
+ let tint_symbol_1 : i32 = a(0);
+ let r = (tint_symbol + v[tint_symbol_1]);
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, AssignTo_IndexAccessorOfIndexAccessorOfSE) {
+ auto* src = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var v = array<i32, 10>();
+ var w = array<i32, 10>();
+ v[w[a(0)]] = 1;
+}
+)";
+
+ auto* expect = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var v = array<i32, 10>();
+ var w = array<i32, 10>();
+ let tint_symbol : i32 = a(0);
+ v[w[tint_symbol]] = 1;
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, AssignTo_IndexAccessorOfIndexAccessorOfLiteralPlusSE) {
+ auto* src = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var v = array<i32, 10>();
+ var w = array<i32, 10>();
+ v[w[0] + a(0)] = 1;
+}
+)";
+
+ auto* expect = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var v = array<i32, 10>();
+ var w = array<i32, 10>();
+ let tint_symbol : i32 = w[0];
+ let tint_symbol_1 : i32 = a(0);
+ v[(tint_symbol + tint_symbol_1)] = 1;
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest,
+ AssignTo_IndexAccessorOfIndexAccessorOfLiteralPlusIndexAccessorOfSE) {
+ auto* src = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var v = array<i32, 10>();
+ var w = array<i32, 10>();
+ v[w[0] + w[a(0)]] = 1;
+}
+)";
+
+ auto* expect = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var v = array<i32, 10>();
+ var w = array<i32, 10>();
+ let tint_symbol : i32 = w[0];
+ let tint_symbol_1 : i32 = a(0);
+ v[(tint_symbol + w[tint_symbol_1])] = 1;
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, IndexAccessorOfLhsSERhsSE) {
+ auto* src = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn b() -> array<i32, 10> {
+ return array<i32, 10>();
+}
+
+fn f() {
+ let r = b()[a(0)];
+}
+)";
+
+ auto* expect = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn b() -> array<i32, 10> {
+ return array<i32, 10>();
+}
+
+fn f() {
+ let tint_symbol : array<i32, 10u> = b();
+ let tint_symbol_1 : i32 = a(0);
+ let r = tint_symbol[tint_symbol_1];
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, BinaryIndexAccessorOfLhsSERhsSE) {
+ auto* src = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn b() -> array<i32, 10> {
+ return array<i32, 10>();
+}
+
+fn f() {
+ let r = b()[0] + a(0);
+}
+)";
+
+ auto* expect = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn b() -> array<i32, 10> {
+ return array<i32, 10>();
+}
+
+fn f() {
+ let tint_symbol : array<i32, 10u> = b();
+ let tint_symbol_1 : i32 = a(0);
+ let r = (tint_symbol[0] + tint_symbol_1);
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, BinaryMemberAccessorPlusSE) {
+ // bclayton@'s example:
+ // https://dawn-review.googlesource.com/c/tint/+/78620/6..8/src/transform/promote_side_effects_to_decl.cc#b490
+ auto* src = R"(
+fn modify_vec(p : ptr<function, vec4<i32>>) -> i32 {
+ (*p).x = 42;
+ return 0;
+}
+
+fn f() {
+ var v = vec4<i32>();
+ let l = v.x + modify_vec(&v);
+ // l should be 0, not 42
+}
+)";
+
+ auto* expect = R"(
+fn modify_vec(p : ptr<function, vec4<i32>>) -> i32 {
+ (*(p)).x = 42;
+ return 0;
+}
+
+fn f() {
+ var v = vec4<i32>();
+ let tint_symbol : i32 = v.x;
+ let tint_symbol_1 : i32 = modify_vec(&(v));
+ let l = (tint_symbol + tint_symbol_1);
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Call_ReadOnlyArgAndSE) {
+ // Make sure that read-only args don't get hoisted (tex and samp)
+ auto* src = R"(
+@group(1) @binding(1) var tex: texture_2d_array<u32>;
+@group(1) @binding(2) var samp: sampler;
+
+fn get_uv() -> vec2<f32> {
+ return vec2<f32>(1.0, 2.0);
+}
+
+fn f() {
+ let r = textureGather(1, tex, samp, get_uv(), 1);
+}
+)";
+
+ auto* expect = R"(
+@group(1) @binding(1) var tex : texture_2d_array<u32>;
+
+@group(1) @binding(2) var samp : sampler;
+
+fn get_uv() -> vec2<f32> {
+ return vec2<f32>(1.0, 2.0);
+}
+
+fn f() {
+ let tint_symbol : vec2<f32> = get_uv();
+ let r = textureGather(1, tex, samp, tint_symbol, 1);
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Call_PtrArgAndSE) {
+ // Make sure that pointer args don't get hoisted (&b), while the side-effecting arg a(0) does
+ auto* src = R"(
+
+var<private> b : i32 = 0;
+
+fn a(i : i32) -> i32 {
+ b = 42;
+ return 0;
+}
+
+fn g(i : ptr<private, i32>, j : i32) -> i32 {
+ return *i;
+}
+
+fn f() {
+ // a(0) should be hoisted, but not &b
+ let r = g(&b, a(0));
+ // r should be 42
+}
+)";
+
+ auto* expect = R"(
+var<private> b : i32 = 0;
+
+fn a(i : i32) -> i32 {
+ b = 42;
+ return 0;
+}
+
+fn g(i : ptr<private, i32>, j : i32) -> i32 {
+ return *(i);
+}
+
+fn f() {
+ let tint_symbol : i32 = a(0);
+ let r = g(&(b), tint_symbol);
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, TypeInit_VarPlusI32InitPlusVar) {
+ auto* src = R"(
+fn f() {
+ var b = 0;
+ var c = 0;
+ var d = 0;
+ let r = ((b + i32(c)) + d);
+}
+)";
+
+ auto* expect = src;
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Mixed_ArithPlusLogical) {
+ auto* src = R"(
+fn a(i : i32) -> i32 {
+ return 0;
+}
+
+fn b(i : i32) -> bool {
+ return true;
+}
+
+fn g(i : bool) -> i32 {
+ return 0;
+}
+
+fn f() {
+ let r = a(0) + g(b(1) && b(2));
+}
+)";
+
+ auto* expect = R"(
+fn a(i : i32) -> i32 {
+ return 0;
+}
+
+fn b(i : i32) -> bool {
+ return true;
+}
+
+fn g(i : bool) -> i32 {
+ return 0;
+}
+
+fn f() {
+ let tint_symbol : i32 = a(0);
+ var tint_symbol_1 = b(1);
+ if (tint_symbol_1) {
+ tint_symbol_1 = b(2);
+ }
+ let tint_symbol_2 : i32 = g(tint_symbol_1);
+ let r = (tint_symbol + tint_symbol_2);
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Mixed_LogicalPlusArith) {
+ auto* src = R"(
+fn a(i : i32) -> i32 {
+ return 0;
+}
+
+fn b(i : i32) -> bool {
+ return true;
+}
+
+fn g(i : bool) -> i32 {
+ return 0;
+}
+
+fn f() {
+ let r = g(b(0) && b(1)) + a(2);
+}
+)";
+
+ auto* expect = R"(
+fn a(i : i32) -> i32 {
+ return 0;
+}
+
+fn b(i : i32) -> bool {
+ return true;
+}
+
+fn g(i : bool) -> i32 {
+ return 0;
+}
+
+fn f() {
+ var tint_symbol = b(0);
+ if (tint_symbol) {
+ tint_symbol = b(1);
+ }
+ let tint_symbol_1 : i32 = g(tint_symbol);
+ let tint_symbol_2 : i32 = a(2);
+ let r = (tint_symbol_1 + tint_symbol_2);
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Mixed_ArithAndLogicalArgs) {
+ auto* src = R"(
+fn a(i : i32) -> i32 {
+ return 0;
+}
+
+fn b(i : i32) -> bool {
+ return true;
+}
+
+fn g(i : i32, j : bool) -> i32 {
+ return 0;
+}
+
+fn f() {
+ let r = g(a(0), b(1) && b(2));
+}
+)";
+
+ auto* expect = R"(
+fn a(i : i32) -> i32 {
+ return 0;
+}
+
+fn b(i : i32) -> bool {
+ return true;
+}
+
+fn g(i : i32, j : bool) -> i32 {
+ return 0;
+}
+
+fn f() {
+ let tint_symbol : i32 = a(0);
+ var tint_symbol_1 = b(1);
+ if (tint_symbol_1) {
+ tint_symbol_1 = b(2);
+ }
+ let r = g(tint_symbol, tint_symbol_1);
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Mixed_LogicalAndArithArgs) {
+ auto* src = R"(
+fn a(i : i32) -> i32 {
+ return 0;
+}
+
+fn b(i : i32) -> bool {
+ return true;
+}
+
+fn g(i : bool, j : i32) -> i32 {
+ return 0;
+}
+
+fn f() {
+ let r = g(b(0) && b(1), a(2));
+}
+)";
+
+ auto* expect = R"(
+fn a(i : i32) -> i32 {
+ return 0;
+}
+
+fn b(i : i32) -> bool {
+ return true;
+}
+
+fn g(i : bool, j : i32) -> i32 {
+ return 0;
+}
+
+fn f() {
+ var tint_symbol = b(0);
+ if (tint_symbol) {
+ tint_symbol = b(1);
+ }
+ let tint_symbol_1 : i32 = a(2);
+ let r = g(tint_symbol, tint_symbol_1);
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Mixed_Complex) {
+ auto* src = R"(
+fn a(i : i32) -> i32 {
+ return 0;
+}
+
+fn b(i : i32) -> bool {
+ return true;
+}
+
+fn c(i : bool) -> i32 {
+ return 0;
+}
+
+fn g(i : bool, j : i32) -> i32 {
+ return 0;
+}
+
+fn f() {
+ let r = a(0) + g(b(1) && b(a(2) + a(3)), a(4)) + a(5);
+}
+)";
+
+ auto* expect = R"(
+fn a(i : i32) -> i32 {
+ return 0;
+}
+
+fn b(i : i32) -> bool {
+ return true;
+}
+
+fn c(i : bool) -> i32 {
+ return 0;
+}
+
+fn g(i : bool, j : i32) -> i32 {
+ return 0;
+}
+
+fn f() {
+ let tint_symbol : i32 = a(0);
+ var tint_symbol_1 = b(1);
+ if (tint_symbol_1) {
+ let tint_symbol_2 : i32 = a(2);
+ let tint_symbol_3 : i32 = a(3);
+ tint_symbol_1 = b((tint_symbol_2 + tint_symbol_3));
+ }
+ let tint_symbol_4 : i32 = a(4);
+ let tint_symbol_5 : i32 = g(tint_symbol_1, tint_symbol_4);
+ let tint_symbol_6 : i32 = a(5);
+ let r = ((tint_symbol + tint_symbol_5) + tint_symbol_6);
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, TextureSamplerParameter) {
+ auto* src = R"(
+@group(0) @binding(0) var T : texture_2d<f32>;
+@group(0) @binding(1) var S : sampler;
+
+var<private> P : vec2<f32>;
+fn side_effects() -> vec2<f32> {
+ P += vec2(1.0);
+ return P;
+}
+
+fn f(t : texture_2d<f32>, s : sampler) -> vec4<f32> {
+ return textureSample(t, s, side_effects());
+}
+
+fn m() -> vec4<f32>{
+ return f(T, S);
+}
+)";
+
+ auto* expect = R"(
+@group(0) @binding(0) var T : texture_2d<f32>;
+
+@group(0) @binding(1) var S : sampler;
+
+var<private> P : vec2<f32>;
+
+fn side_effects() -> vec2<f32> {
+ P += vec2(1.0);
+ return P;
+}
+
+fn f(t : texture_2d<f32>, s : sampler) -> vec4<f32> {
+ let tint_symbol : vec2<f32> = side_effects();
+ return textureSample(t, s, tint_symbol);
+}
+
+fn m() -> vec4<f32> {
+ return f(T, S);
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, BuiltinReturnType) {
+ auto* src = R"(
+fn X(a:vec2f, b:vec2f) {
+}
+
+fn Y() -> vec2f { return vec2f(); }
+
+fn f() {
+ var v: vec2f;
+ X(vec2(), v); // okay
+ X(vec2(), Y()); // errors
+}
+)";
+
+ auto* expect = R"(
+fn X(a : vec2f, b : vec2f) {
+}
+
+fn Y() -> vec2f {
+ return vec2f();
+}
+
+fn f() {
+ var v : vec2f;
+ X(vec2(), v);
+ let tint_symbol : vec2<f32> = vec2();
+ let tint_symbol_1 : vec2<f32> = Y();
+ X(tint_symbol, tint_symbol_1);
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Bug1963) {
+ auto* src = R"(
+fn X(a:vec2f, b:vec2f) {
+}
+
+fn Y() -> vec2f { return vec2f(); }
+
+fn f() {
+ var v: vec2f;
+ X(vec2(), v); // okay
+ X(vec2(), Y()); // errors
+}
+)";
+
+ auto* expect = R"(
+fn X(a : vec2f, b : vec2f) {
+}
+
+fn Y() -> vec2f {
+ return vec2f();
+}
+
+fn f() {
+ var v : vec2f;
+ X(vec2(), v);
+ let tint_symbol : vec2<f32> = vec2();
+ let tint_symbol_1 : vec2<f32> = Y();
+ X(tint_symbol, tint_symbol_1);
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/remove_continue_in_switch.cc b/src/tint/lang/wgsl/ast/transform/remove_continue_in_switch.cc
new file mode 100644
index 0000000..2558ba4
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/remove_continue_in_switch.cc
@@ -0,0 +1,133 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/remove_continue_in_switch.h"
+
+#include <string>
+#include <unordered_map>
+#include <utility>
+
+#include "src/tint/lang/wgsl/ast/continue_statement.h"
+#include "src/tint/lang/wgsl/ast/switch_statement.h"
+#include "src/tint/lang/wgsl/ast/transform/utils/get_insertion_point.h"
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/block_statement.h"
+#include "src/tint/sem/for_loop_statement.h"
+#include "src/tint/sem/loop_statement.h"
+#include "src/tint/sem/switch_statement.h"
+#include "src/tint/sem/while_statement.h"
+#include "src/tint/utils/map.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::RemoveContinueInSwitch);
+
+namespace tint::ast::transform {
+
+/// PIMPL state for the transform
+struct RemoveContinueInSwitch::State {
+ /// Constructor
+ /// @param program the source program
+ explicit State(const Program* program) : src(program) {}
+
+ /// Runs the transform over every continue statement of the source program
+ /// @returns the new program or SkipTransform if the transform is not required
+ ApplyResult Run() {
+ bool made_changes = false;
+
+ for (auto* node : src->ASTNodes().Objects()) {
+ auto* cont = node->As<ContinueStatement>();
+ if (!cont) {
+ continue;
+ }
+
+ // Skip unless the nearest enclosing switch / loop statement is a switch
+ auto* switch_stmt = GetParentSwitchInLoop(sem, cont);
+ if (!switch_stmt) {
+ continue;
+ }
+
+ made_changes = true;
+
+ auto cont_var_name =
+ tint::utils::GetOrCreate(switch_to_cont_var_name, switch_stmt, [&] {
+ // First continue seen for this switch: create and insert
+ // 'var tint_continue : bool = false;' before the switch.
+ auto var_name = b.Symbols().New("tint_continue");
+ auto* decl = b.Decl(b.Var(var_name, b.ty.bool_(), b.Expr(false)));
+ auto ip = utils::GetInsertionPoint(ctx, switch_stmt);
+ ctx.InsertBefore(ip.first->Declaration()->statements, ip.second, decl);
+
+ // Create and insert 'if (tint_continue) { continue; }' after the
+ // switch, at the same insertion point.
+ auto* if_stmt = b.If(b.Expr(var_name), b.Block(b.Continue()));
+ ctx.InsertAfter(ip.first->Declaration()->statements, ip.second, if_stmt);
+
+ // Return the new var name so it is cached for this switch
+ return var_name;
+ });
+
+ // Replace 'continue;' with '{ tint_continue = true; break; }'
+ auto* new_stmt = b.Block( //
+ b.Assign(b.Expr(cont_var_name), true), //
+ b.Break());
+
+ ctx.Replace(cont, new_stmt);
+ }
+
+ if (!made_changes) {
+ return SkipTransform;
+ }
+
+ ctx.Clone();
+ return Program(std::move(b));
+ }
+
+ private:
+ /// The source program
+ const Program* const src;
+ /// The target program builder
+ ProgramBuilder b;
+ /// The clone context
+ CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
+ /// Alias to src->sem
+ const sem::Info& sem = src->Sem();
+
+ // Map of switch statement to the symbol of its 'tint_continue' variable.
+ std::unordered_map<const SwitchStatement*, Symbol> switch_to_cont_var_name;
+
+ // Returns the switch statement if the nearest enclosing switch / loop statement of `cont`
+ // is a switch. Returns nullptr if that parent is a loop, or if there is no such parent.
+ static const SwitchStatement* GetParentSwitchInLoop(const sem::Info& sem,
+ const ContinueStatement* cont) {
+ // Find the first parent that is a switch, loop block, for-loop or while statement
+ auto* sem_stmt = sem.Get(cont);
+ auto* sem_parent = sem_stmt->FindFirstParent<sem::SwitchStatement, sem::LoopBlockStatement,
+ sem::ForLoopStatement, sem::WhileStatement>();
+ if (!sem_parent) {
+ return nullptr;
+ }
+ return sem_parent->Declaration()->As<SwitchStatement>();
+ }
+};
+
+RemoveContinueInSwitch::RemoveContinueInSwitch() = default;
+RemoveContinueInSwitch::~RemoveContinueInSwitch() = default;
+
+Transform::ApplyResult RemoveContinueInSwitch::Apply(const Program* src,
+ const DataMap&,
+ DataMap&) const {
+ State state(src);
+ return state.Run();
+}
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/remove_continue_in_switch.h b/src/tint/lang/wgsl/ast/transform/remove_continue_in_switch.h
new file mode 100644
index 0000000..ad78617
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/remove_continue_in_switch.h
@@ -0,0 +1,45 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_REMOVE_CONTINUE_IN_SWITCH_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_REMOVE_CONTINUE_IN_SWITCH_H_
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+
+namespace tint::ast::transform {
+
+/// This transform replaces continue statements in switch cases with setting a
+/// bool variable, and checking if the variable is set after the switch to
+/// continue. It is necessary to work around FXC "error X3708: continue cannot
+/// be used in a switch". See crbug.com/tint/1080.
+class RemoveContinueInSwitch final : public utils::Castable<RemoveContinueInSwitch, Transform> {
+ public:
+ /// Constructor
+ RemoveContinueInSwitch();
+
+ /// Destructor
+ ~RemoveContinueInSwitch() override;
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+
+ private:
+ struct State;
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_REMOVE_CONTINUE_IN_SWITCH_H_
diff --git a/src/tint/lang/wgsl/ast/transform/remove_continue_in_switch_test.cc b/src/tint/lang/wgsl/ast/transform/remove_continue_in_switch_test.cc
new file mode 100644
index 0000000..8b88195
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/remove_continue_in_switch_test.cc
@@ -0,0 +1,617 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/remove_continue_in_switch.h"
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using RemoveContinueInSwitchTest = TransformTest;
+
+TEST_F(RemoveContinueInSwitchTest, ShouldRun_True) {
+ auto* src = R"(
+fn f() {
+ var i = 0;
+ loop {
+ switch(i) {
+ case 0: {
+ continue;
+ break;
+ }
+ default: {
+ break;
+ }
+ }
+ break;
+ }
+}
+)";
+
+ EXPECT_TRUE(ShouldRun<RemoveContinueInSwitch>(src));
+}
+
+TEST_F(RemoveContinueInSwitchTest, ShouldRunEmptyModule_False) {
+ auto* src = "";
+
+ EXPECT_FALSE(ShouldRun<RemoveContinueInSwitch>(src));
+}
+
+TEST_F(RemoveContinueInSwitchTest, ShouldRunContinueNotInSwitch_False) {
+ auto* src = R"(
+fn f() {
+ var i = 0;
+ loop {
+ switch(i) {
+ case 0: {
+ break;
+ }
+ default: {
+ break;
+ }
+ }
+
+ if (true) {
+ continue;
+ }
+ break;
+ }
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<RemoveContinueInSwitch>(src));
+}
+
+TEST_F(RemoveContinueInSwitchTest, ShouldRunContinueInLoopInSwitch_False) {
+ auto* src = R"(
+fn f() {
+ var i = 0;
+ switch(i) {
+ case 0: {
+ loop {
+ if (true) {
+ continue;
+ }
+ break;
+ }
+ break;
+ }
+ default: {
+ break;
+ }
+ }
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<RemoveContinueInSwitch>(src));
+}
+
+TEST_F(RemoveContinueInSwitchTest, EmptyModule) {
+ auto* src = "";
+ auto* expect = src;
+
+ Transform::DataMap data;
+ auto got = Run<RemoveContinueInSwitch>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(RemoveContinueInSwitchTest, SingleContinue) {
+ auto* src = R"(
+fn f() {
+ var i = 0;
+ loop {
+ let marker1 = 0;
+ switch(i) {
+ case 0: {
+ continue;
+ break;
+ }
+ default: {
+ break;
+ }
+ }
+ let marker2 = 0;
+ break;
+
+ continuing {
+ let marker3 = 0;
+ }
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ var i = 0;
+ loop {
+ let marker1 = 0;
+ var tint_continue : bool = false;
+ switch(i) {
+ case 0: {
+ {
+ tint_continue = true;
+ break;
+ }
+ break;
+ }
+ default: {
+ break;
+ }
+ }
+ if (tint_continue) {
+ continue;
+ }
+ let marker2 = 0;
+ break;
+
+ continuing {
+ let marker3 = 0;
+ }
+ }
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<RemoveContinueInSwitch>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(RemoveContinueInSwitchTest, MultipleContinues) {
+ auto* src = R"(
+fn f() {
+ var i = 0;
+ loop {
+ let marker1 = 0;
+ switch(i) {
+ case 0: {
+ continue;
+ break;
+ }
+ case 1: {
+ continue;
+ break;
+ }
+ case 2: {
+ continue;
+ break;
+ }
+ default: {
+ break;
+ }
+ }
+ let marker2 = 0;
+ break;
+
+ continuing {
+ let marker3 = 0;
+ }
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ var i = 0;
+ loop {
+ let marker1 = 0;
+ var tint_continue : bool = false;
+ switch(i) {
+ case 0: {
+ {
+ tint_continue = true;
+ break;
+ }
+ break;
+ }
+ case 1: {
+ {
+ tint_continue = true;
+ break;
+ }
+ break;
+ }
+ case 2: {
+ {
+ tint_continue = true;
+ break;
+ }
+ break;
+ }
+ default: {
+ break;
+ }
+ }
+ if (tint_continue) {
+ continue;
+ }
+ let marker2 = 0;
+ break;
+
+ continuing {
+ let marker3 = 0;
+ }
+ }
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<RemoveContinueInSwitch>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(RemoveContinueInSwitchTest, MultipleSwitch) {
+ auto* src = R"(
+fn f() {
+ var i = 0;
+ loop {
+ let marker1 = 0;
+ switch(i) {
+ case 0: {
+ continue;
+ break;
+ }
+ default: {
+ break;
+ }
+ }
+ let marker2 = 0;
+
+ let marker3 = 0;
+ switch(i) {
+ case 0: {
+ continue;
+ break;
+ }
+ default: {
+ break;
+ }
+ }
+ let marker4 = 0;
+
+ break;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ var i = 0;
+ loop {
+ let marker1 = 0;
+ var tint_continue : bool = false;
+ switch(i) {
+ case 0: {
+ {
+ tint_continue = true;
+ break;
+ }
+ break;
+ }
+ default: {
+ break;
+ }
+ }
+ if (tint_continue) {
+ continue;
+ }
+ let marker2 = 0;
+ let marker3 = 0;
+ var tint_continue_1 : bool = false;
+ switch(i) {
+ case 0: {
+ {
+ tint_continue_1 = true;
+ break;
+ }
+ break;
+ }
+ default: {
+ break;
+ }
+ }
+ if (tint_continue_1) {
+ continue;
+ }
+ let marker4 = 0;
+ break;
+ }
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<RemoveContinueInSwitch>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(RemoveContinueInSwitchTest, NestedLoopSwitch) {
+ auto* src = R"(
+fn f() {
+ var i = 0;
+ loop {
+ let marker1 = 0;
+ switch(i) {
+ case 0: {
+ var j = 0;
+ loop {
+ let marker3 = 0;
+ switch(j) {
+ case 0: {
+ continue;
+ break;
+ }
+ default: {
+ break;
+ }
+ }
+ let marker4 = 0;
+ break;
+ }
+ continue;
+ break;
+ }
+ default: {
+ break;
+ }
+ }
+ let marker2 = 0;
+ break;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ var i = 0;
+ loop {
+ let marker1 = 0;
+ var tint_continue_1 : bool = false;
+ switch(i) {
+ case 0: {
+ var j = 0;
+ loop {
+ let marker3 = 0;
+ var tint_continue : bool = false;
+ switch(j) {
+ case 0: {
+ {
+ tint_continue = true;
+ break;
+ }
+ break;
+ }
+ default: {
+ break;
+ }
+ }
+ if (tint_continue) {
+ continue;
+ }
+ let marker4 = 0;
+ break;
+ }
+ {
+ tint_continue_1 = true;
+ break;
+ }
+ break;
+ }
+ default: {
+ break;
+ }
+ }
+ if (tint_continue_1) {
+ continue;
+ }
+ let marker2 = 0;
+ break;
+ }
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<RemoveContinueInSwitch>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(RemoveContinueInSwitchTest, ExtraScopes) {
+ auto* src = R"(
+fn f() {
+ var i = 0;
+ var a = true;
+ var b = true;
+ var c = true;
+ var d = true;
+ loop {
+ if (a) {
+ if (b) {
+ let marker1 = 0;
+ switch(i) {
+ case 0: {
+ if (c) {
+ if (d) {
+ continue;
+ }
+ }
+ break;
+ }
+ default: {
+ break;
+ }
+ }
+ let marker2 = 0;
+ break;
+ }
+ }
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ var i = 0;
+ var a = true;
+ var b = true;
+ var c = true;
+ var d = true;
+ loop {
+ if (a) {
+ if (b) {
+ let marker1 = 0;
+ var tint_continue : bool = false;
+ switch(i) {
+ case 0: {
+ if (c) {
+ if (d) {
+ {
+ tint_continue = true;
+ break;
+ }
+ }
+ }
+ break;
+ }
+ default: {
+ break;
+ }
+ }
+ if (tint_continue) {
+ continue;
+ }
+ let marker2 = 0;
+ break;
+ }
+ }
+ }
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<RemoveContinueInSwitch>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(RemoveContinueInSwitchTest, ForLoop) {
+ auto* src = R"(
+fn f() {
+ for (var i = 0; i < 4; i = i + 1) {
+ let marker1 = 0;
+ switch(i) {
+ case 0: {
+ continue;
+ break;
+ }
+ default: {
+ break;
+ }
+ }
+ let marker2 = 0;
+ break;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ for(var i = 0; (i < 4); i = (i + 1)) {
+ let marker1 = 0;
+ var tint_continue : bool = false;
+ switch(i) {
+ case 0: {
+ {
+ tint_continue = true;
+ break;
+ }
+ break;
+ }
+ default: {
+ break;
+ }
+ }
+ if (tint_continue) {
+ continue;
+ }
+ let marker2 = 0;
+ break;
+ }
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<RemoveContinueInSwitch>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(RemoveContinueInSwitchTest, While) {
+ auto* src = R"(
+fn f() {
+ var i = 0;
+ while (i < 4) {
+ let marker1 = 0;
+ switch(i) {
+ case 0: {
+ continue;
+ break;
+ }
+ default: {
+ break;
+ }
+ }
+ let marker2 = 0;
+ break;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ var i = 0;
+ while((i < 4)) {
+ let marker1 = 0;
+ var tint_continue : bool = false;
+ switch(i) {
+ case 0: {
+ {
+ tint_continue = true;
+ break;
+ }
+ break;
+ }
+ default: {
+ break;
+ }
+ }
+ if (tint_continue) {
+ continue;
+ }
+ let marker2 = 0;
+ break;
+ }
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<RemoveContinueInSwitch>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/remove_phonies.cc b/src/tint/lang/wgsl/ast/transform/remove_phonies.cc
new file mode 100644
index 0000000..290eeee
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/remove_phonies.cc
@@ -0,0 +1,156 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/remove_phonies.h"
+
+#include <memory>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "src/tint/lang/wgsl/ast/traverse_expressions.h"
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/block_statement.h"
+#include "src/tint/sem/function.h"
+#include "src/tint/sem/statement.h"
+#include "src/tint/sem/variable.h"
+#include "src/tint/utils/map.h"
+#include "src/tint/utils/scoped_assignment.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::RemovePhonies);
+
+namespace tint::ast::transform {
+namespace {
+
+using SinkSignature = std::vector<const type::Type*>;
+
+} // namespace
+
+RemovePhonies::RemovePhonies() = default;
+
+RemovePhonies::~RemovePhonies() = default;
+
+Transform::ApplyResult RemovePhonies::Apply(const Program* src, const DataMap&, DataMap&) const {
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+
+ auto& sem = src->Sem();
+
+ utils::Hashmap<SinkSignature, Symbol, 8, utils::Hasher<SinkSignature>> sinks;
+
+ bool made_changes = false;
+ for (auto* node : src->ASTNodes().Objects()) {
+ Switch(
+ node,
+ [&](const AssignmentStatement* stmt) {
+ if (stmt->lhs->Is<PhonyExpression>()) {
+ made_changes = true;
+
+ std::vector<const Expression*> side_effects;
+ if (!TraverseExpressions(
+ stmt->rhs, b.Diagnostics(), [&](const CallExpression* expr) {
+ // CallExpression may map to a function or builtin call
+ // (both may have side-effects), or a value constructor or value
+ // conversion (both do not have side effects).
+ auto* call = sem.Get<sem::Call>(expr);
+ if (!call) {
+ // Semantic node must be a Materialize, in which case the
+ // expression was creation-time (compile time), so could not
+ // have side effects. Just skip.
+ return TraverseAction::Skip;
+ }
+ if (call->Target()->IsAnyOf<sem::Function, sem::Builtin>() &&
+ call->HasSideEffects()) {
+ side_effects.push_back(expr);
+ return TraverseAction::Skip;
+ }
+ return TraverseAction::Descend;
+ })) {
+ return;
+ }
+
+ if (side_effects.empty()) {
+ // Phony assignment with no side effects.
+ // Just remove it.
+ RemoveStatement(ctx, stmt);
+ return;
+ }
+
+ if (side_effects.size() == 1) {
+ if (auto* call_expr = side_effects[0]->As<CallExpression>()) {
+ // Phony assignment with single call side effect.
+ auto* call = sem.Get(call_expr)->Unwrap()->As<sem::Call>();
+ if (call->Target()->MustUse()) {
+ // Replace phony assignment with a declaration of a uniquely named let.
+ ctx.Replace<Statement>(stmt, [&, call_expr] { //
+ auto name = b.Symbols().New("tint_phony");
+ auto* rhs = ctx.Clone(call_expr);
+ return b.Decl(b.Let(name, rhs));
+ });
+ } else {
+ // Replace phony assignment with call statement.
+ ctx.Replace(stmt, [&, call_expr] { //
+ return b.CallStmt(ctx.Clone(call_expr));
+ });
+ }
+ return;
+ }
+ }
+
+ // Phony assignment with multiple side effects.
+ // Generate a call to a placeholder function with the side
+ // effects as arguments.
+ ctx.Replace(stmt, [&, side_effects] {
+ SinkSignature sig;
+ for (auto* arg : side_effects) {
+ sig.push_back(sem.GetVal(arg)->Type()->UnwrapRef());
+ }
+ auto sink = sinks.GetOrCreate(sig, [&] {
+ auto name = b.Symbols().New("phony_sink");
+ utils::Vector<const Parameter*, 8> params;
+ for (auto* ty : sig) {
+ auto ast_ty = CreateASTTypeFor(ctx, ty);
+ params.Push(b.Param("p" + std::to_string(params.Length()), ast_ty));
+ }
+ b.Func(name, params, b.ty.void_(), {});
+ return name;
+ });
+ utils::Vector<const Expression*, 8> args;
+ for (auto* arg : side_effects) {
+ args.Push(ctx.Clone(arg));
+ }
+ return b.CallStmt(b.Call(sink, args));
+ });
+ }
+ },
+ [&](const CallStatement* stmt) {
+ // Remove call statements whose expression has a constant value and no side effects.
+ // TODO(crbug.com/tint/1637): Remove if `stmt->expr` has no side-effects.
+ auto* sem_expr = sem.Get(stmt->expr);
+ if ((sem_expr->ConstantValue() != nullptr) && !sem_expr->HasSideEffects()) {
+ made_changes = true;
+ ctx.Remove(sem.Get(stmt)->Block()->Declaration()->statements, stmt);
+ }
+ });
+ }
+
+ if (!made_changes) {
+ return SkipTransform;
+ }
+
+ ctx.Clone();
+ return Program(std::move(b));
+}
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/remove_phonies.h b/src/tint/lang/wgsl/ast/transform/remove_phonies.h
new file mode 100644
index 0000000..5b5c0f1
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/remove_phonies.h
@@ -0,0 +1,44 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_REMOVE_PHONIES_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_REMOVE_PHONIES_H_
+
+#include <string>
+#include <unordered_map>
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+
+namespace tint::ast::transform {
+
+/// RemovePhonies is a Transform that removes all phony-assignment statements,
+/// while preserving function call expressions in the RHS of the assignment that
+/// may have side-effects. It also removes calls to builtins that return a constant value.
+class RemovePhonies final : public utils::Castable<RemovePhonies, Transform> {
+ public:
+ /// Constructor
+ RemovePhonies();
+
+ /// Destructor
+ ~RemovePhonies() override;
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_REMOVE_PHONIES_H_
diff --git a/src/tint/lang/wgsl/ast/transform/remove_phonies_test.cc b/src/tint/lang/wgsl/ast/transform/remove_phonies_test.cc
new file mode 100644
index 0000000..210dcda
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/remove_phonies_test.cc
@@ -0,0 +1,487 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/remove_phonies.h"
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using RemovePhoniesTest = TransformTest;
+
+TEST_F(RemovePhoniesTest, ShouldRunEmptyModule) {
+ auto* src = R"()";
+
+ EXPECT_FALSE(ShouldRun<RemovePhonies>(src));
+}
+
+TEST_F(RemovePhoniesTest, ShouldRunHasPhony) {
+ auto* src = R"(
+fn f() {
+ _ = 1;
+}
+)";
+
+ EXPECT_TRUE(ShouldRun<RemovePhonies>(src));
+}
+
+TEST_F(RemovePhoniesTest, EmptyModule) {
+ auto* src = "";
+ auto* expect = "";
+
+ auto got = Run<RemovePhonies>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(RemovePhoniesTest, NoSideEffects) {
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+fn f() {
+ var v : i32;
+ _ = &v;
+ _ = 1;
+ _ = 1 + 2;
+ _ = t;
+ _ = u32(3.0);
+ _ = f32(i32(4u));
+ _ = vec2<f32>(5.0);
+ _ = vec3<i32>(6, 7, 8);
+ _ = mat2x2<f32>(9.0, 10.0, 11.0, 12.0);
+ _ = atan2(1.0, 2.0);
+ _ = clamp(1.0, 2.0, 3.0);
+}
+)";
+
+ auto* expect = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+fn f() {
+ var v : i32;
+}
+)";
+
+ auto got = Run<RemovePhonies>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(RemovePhoniesTest, SingleSideEffects) {
+ auto* src = R"(
+fn neg(a : i32) -> i32 {
+ return -(a);
+}
+
+fn add(a : i32, b : i32) -> i32 {
+ return (a + b);
+}
+
+fn f() {
+ _ = neg(1);
+ _ = add(2, 3);
+ _ = add(neg(4), neg(5));
+ _ = u32(neg(6));
+ _ = f32(add(7, 8));
+ _ = vec2<f32>(f32(neg(9)));
+ _ = vec3<i32>(1, neg(10), 3);
+ _ = mat2x2<f32>(1.0, f32(add(11, 12)), 3.0, 4.0);
+}
+)";
+
+ auto* expect = R"(
+fn neg(a : i32) -> i32 {
+ return -(a);
+}
+
+fn add(a : i32, b : i32) -> i32 {
+ return (a + b);
+}
+
+fn f() {
+ neg(1);
+ add(2, 3);
+ add(neg(4), neg(5));
+ neg(6);
+ add(7, 8);
+ neg(9);
+ neg(10);
+ add(11, 12);
+}
+)";
+
+ auto got = Run<RemovePhonies>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(RemovePhoniesTest, SingleSideEffectsMustUse) {
+ auto* src = R"(
+@must_use
+fn neg(a : i32) -> i32 {
+ return -(a);
+}
+
+@must_use
+fn add(a : i32, b : i32) -> i32 {
+ return (a + b);
+}
+
+fn f() {
+ let tint_phony = 42;
+ _ = neg(1);
+ _ = 100 + add(2, 3) + 200;
+ _ = add(neg(4), neg(5)) + 6;
+ _ = u32(neg(6));
+ _ = f32(add(7, 8));
+ _ = vec2<f32>(f32(neg(9)));
+ _ = vec3<i32>(1, neg(10), 3);
+ _ = mat2x2<f32>(1.0, f32(add(11, 12)), 3.0, 4.0);
+}
+)";
+
+ auto* expect = R"(
+@must_use
+fn neg(a : i32) -> i32 {
+ return -(a);
+}
+
+@must_use
+fn add(a : i32, b : i32) -> i32 {
+ return (a + b);
+}
+
+fn f() {
+ let tint_phony = 42;
+ let tint_phony_1 = neg(1);
+ let tint_phony_2 = add(2, 3);
+ let tint_phony_3 = add(neg(4), neg(5));
+ let tint_phony_4 = neg(6);
+ let tint_phony_5 = add(7, 8);
+ let tint_phony_6 = neg(9);
+ let tint_phony_7 = neg(10);
+ let tint_phony_8 = add(11, 12);
+}
+)";
+
+ auto got = Run<RemovePhonies>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(RemovePhoniesTest, SingleSideEffects_OutOfOrder) {
+ auto* src = R"(
+fn f() {
+ _ = neg(1);
+ _ = add(2, 3);
+ _ = add(neg(4), neg(5));
+ _ = u32(neg(6));
+ _ = f32(add(7, 8));
+ _ = vec2<f32>(f32(neg(9)));
+ _ = vec3<i32>(1, neg(10), 3);
+ _ = mat2x2<f32>(1.0, f32(add(11, 12)), 3.0, 4.0);
+}
+
+fn add(a : i32, b : i32) -> i32 {
+ return (a + b);
+}
+
+fn neg(a : i32) -> i32 {
+ return -(a);
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ neg(1);
+ add(2, 3);
+ add(neg(4), neg(5));
+ neg(6);
+ add(7, 8);
+ neg(9);
+ neg(10);
+ add(11, 12);
+}
+
+fn add(a : i32, b : i32) -> i32 {
+ return (a + b);
+}
+
+fn neg(a : i32) -> i32 {
+ return -(a);
+}
+)";
+
+ auto got = Run<RemovePhonies>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(RemovePhoniesTest, MultipleSideEffects) {
+ auto* src = R"(
+fn neg(a : i32) -> i32 {
+ return -(a);
+}
+
+@must_use
+fn add(a : i32, b : i32) -> i32 {
+ return (a + b);
+}
+
+fn xor(a : u32, b : u32) -> u32 {
+ return (a ^ b);
+}
+
+fn f() {
+ _ = (1 + add(2 + add(3, 4), 5)) * add(6, 7) * neg(8);
+ _ = add(9, neg(10)) + neg(11);
+ _ = xor(12u, 13u) + xor(14u, 15u);
+ _ = neg(16) / neg(17) + add(18, 19);
+}
+)";
+
+ auto* expect = R"(
+fn neg(a : i32) -> i32 {
+ return -(a);
+}
+
+@must_use
+fn add(a : i32, b : i32) -> i32 {
+ return (a + b);
+}
+
+fn xor(a : u32, b : u32) -> u32 {
+ return (a ^ b);
+}
+
+fn phony_sink(p0 : i32, p1 : i32, p2 : i32) {
+}
+
+fn phony_sink_1(p0 : i32, p1 : i32) {
+}
+
+fn phony_sink_2(p0 : u32, p1 : u32) {
+}
+
+fn f() {
+ phony_sink(add((2 + add(3, 4)), 5), add(6, 7), neg(8));
+ phony_sink_1(add(9, neg(10)), neg(11));
+ phony_sink_2(xor(12u, 13u), xor(14u, 15u));
+ phony_sink(neg(16), neg(17), add(18, 19));
+}
+)";
+
+ auto got = Run<RemovePhonies>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(RemovePhoniesTest, MultipleSideEffects_OutOfOrder) {
+ auto* src = R"(
+fn f() {
+ _ = (1 + add(2 + add(3, 4), 5)) * add(6, 7) * neg(8);
+ _ = add(9, neg(10)) + neg(11);
+ _ = xor(12u, 13u) + xor(14u, 15u);
+ _ = neg(16) / neg(17) + add(18, 19);
+}
+
+fn neg(a : i32) -> i32 {
+ return -(a);
+}
+
+fn add(a : i32, b : i32) -> i32 {
+ return (a + b);
+}
+
+fn xor(a : u32, b : u32) -> u32 {
+ return (a ^ b);
+}
+)";
+
+ auto* expect = R"(
+fn phony_sink(p0 : i32, p1 : i32, p2 : i32) {
+}
+
+fn phony_sink_1(p0 : i32, p1 : i32) {
+}
+
+fn phony_sink_2(p0 : u32, p1 : u32) {
+}
+
+fn f() {
+ phony_sink(add((2 + add(3, 4)), 5), add(6, 7), neg(8));
+ phony_sink_1(add(9, neg(10)), neg(11));
+ phony_sink_2(xor(12u, 13u), xor(14u, 15u));
+ phony_sink(neg(16), neg(17), add(18, 19));
+}
+
+fn neg(a : i32) -> i32 {
+ return -(a);
+}
+
+fn add(a : i32, b : i32) -> i32 {
+ return (a + b);
+}
+
+fn xor(a : u32, b : u32) -> u32 {
+ return (a ^ b);
+}
+)";
+
+ auto got = Run<RemovePhonies>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(RemovePhoniesTest, ForLoop) {
+ auto* src = R"(
+struct S {
+ arr : array<i32>,
+};
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+
+fn x() -> i32 {
+ return 0;
+}
+
+fn y() -> i32 {
+ return 0;
+}
+
+fn z() -> i32 {
+ return 0;
+}
+
+fn f() {
+ for (_ = &s.arr; ;_ = &s.arr) {
+ break;
+ }
+ for (_ = x(); ;_ = y() + z()) {
+ break;
+ }
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ arr : array<i32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+
+fn x() -> i32 {
+ return 0;
+}
+
+fn y() -> i32 {
+ return 0;
+}
+
+fn z() -> i32 {
+ return 0;
+}
+
+fn phony_sink(p0 : i32, p1 : i32) {
+}
+
+fn f() {
+ for(; ; ) {
+ break;
+ }
+ for(x(); ; phony_sink(y(), z())) {
+ break;
+ }
+}
+)";
+
+ auto got = Run<RemovePhonies>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(RemovePhoniesTest, ForLoop_OutOfOrder) {
+ auto* src = R"(
+fn f() {
+ for (_ = &s.arr; ;_ = &s.arr) {
+ break;
+ }
+ for (_ = x(); ;_ = y() + z()) {
+ break;
+ }
+}
+
+fn x() -> i32 {
+ return 0;
+}
+
+fn y() -> i32 {
+ return 0;
+}
+
+fn z() -> i32 {
+ return 0;
+}
+
+struct S {
+ arr : array<i32>,
+};
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+)";
+
+ auto* expect = R"(
+fn phony_sink(p0 : i32, p1 : i32) {
+}
+
+fn f() {
+ for(; ; ) {
+ break;
+ }
+ for(x(); ; phony_sink(y(), z())) {
+ break;
+ }
+}
+
+fn x() -> i32 {
+ return 0;
+}
+
+fn y() -> i32 {
+ return 0;
+}
+
+fn z() -> i32 {
+ return 0;
+}
+
+struct S {
+ arr : array<i32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+)";
+
+ auto got = Run<RemovePhonies>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/remove_unreachable_statements.cc b/src/tint/lang/wgsl/ast/transform/remove_unreachable_statements.cc
new file mode 100644
index 0000000..50b9dfa
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/remove_unreachable_statements.cc
@@ -0,0 +1,63 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/remove_unreachable_statements.h"
+
+#include <memory>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "src/tint/lang/wgsl/ast/traverse_expressions.h"
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/block_statement.h"
+#include "src/tint/sem/function.h"
+#include "src/tint/sem/statement.h"
+#include "src/tint/sem/variable.h"
+#include "src/tint/utils/map.h"
+#include "src/tint/utils/scoped_assignment.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::RemoveUnreachableStatements);
+
+namespace tint::ast::transform {
+
+RemoveUnreachableStatements::RemoveUnreachableStatements() = default;
+
+RemoveUnreachableStatements::~RemoveUnreachableStatements() = default;
+
+Transform::ApplyResult RemoveUnreachableStatements::Apply(const Program* src,
+                                                          const DataMap&,
+                                                          DataMap&) const {
+    ProgramBuilder b;
+    CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+
+    bool removed_any = false;  // Set once at least one statement has been dropped.
+    for (auto* ast_node : src->ASTNodes().Objects()) {
+        auto* sem_stmt = src->Sem().Get<sem::Statement>(ast_node);
+        if (sem_stmt == nullptr || sem_stmt->IsReachable()) {
+            continue;  // No semantic statement for this node, or it is reachable.
+        }
+        RemoveStatement(ctx, sem_stmt->Declaration());
+        removed_any = true;
+    }
+
+    if (!removed_any) {
+        return SkipTransform;  // Nothing was removed, so skip cloning the program.
+    }
+
+    ctx.Clone();
+    return Program(std::move(b));
+}
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/remove_unreachable_statements.h b/src/tint/lang/wgsl/ast/transform/remove_unreachable_statements.h
new file mode 100644
index 0000000..c5131df
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/remove_unreachable_statements.h
@@ -0,0 +1,44 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_REMOVE_UNREACHABLE_STATEMENTS_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_REMOVE_UNREACHABLE_STATEMENTS_H_
+
+#include <string>
+#include <unordered_map>
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+
+namespace tint::ast::transform {
+
+/// RemoveUnreachableStatements is a Transform that removes all statements
+/// marked as unreachable.
+class RemoveUnreachableStatements final
+ : public utils::Castable<RemoveUnreachableStatements, Transform> {
+ public:
+ /// Constructor
+ RemoveUnreachableStatements();
+
+ /// Destructor
+ ~RemoveUnreachableStatements() override;
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_REMOVE_UNREACHABLE_STATEMENTS_H_
diff --git a/src/tint/lang/wgsl/ast/transform/remove_unreachable_statements_test.cc b/src/tint/lang/wgsl/ast/transform/remove_unreachable_statements_test.cc
new file mode 100644
index 0000000..cc88c61
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/remove_unreachable_statements_test.cc
@@ -0,0 +1,255 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/remove_unreachable_statements.h"
+
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using RemoveUnreachableStatementsTest = TransformTest;
+
+TEST_F(RemoveUnreachableStatementsTest, ShouldRunEmptyModule) {
+ auto* src = R"()";
+
+ EXPECT_FALSE(ShouldRun<RemoveUnreachableStatements>(src));
+}
+
+TEST_F(RemoveUnreachableStatementsTest, ShouldRunHasNoUnreachable) {
+ auto* src = R"(
+fn f() {
+ if (true) {
+ var x = 1;
+ }
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<RemoveUnreachableStatements>(src));
+}
+
+TEST_F(RemoveUnreachableStatementsTest, ShouldRunHasUnreachable) {
+ auto* src = R"(
+fn f() {
+ return;
+ if (true) {
+ var x = 1;
+ }
+}
+)";
+
+ EXPECT_TRUE(ShouldRun<RemoveUnreachableStatements>(src));
+}
+
+TEST_F(RemoveUnreachableStatementsTest, EmptyModule) {
+ auto* src = "";
+ auto* expect = "";
+
+ auto got = Run<RemoveUnreachableStatements>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(RemoveUnreachableStatementsTest, Return) {
+ auto* src = R"(
+fn f() {
+ return;
+ var remove_me = 1;
+ if (true) {
+ var remove_me_too = 1;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ return;
+}
+)";
+
+ auto got = Run<RemoveUnreachableStatements>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(RemoveUnreachableStatementsTest, NestedReturn) {
+ auto* src = R"(
+fn f() {
+ {
+ {
+ return;
+ }
+ }
+ var remove_me = 1;
+ if (true) {
+ var remove_me_too = 1;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ {
+ {
+ return;
+ }
+ }
+}
+)";
+
+ auto got = Run<RemoveUnreachableStatements>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+// Discard has "demote-to-helper" semantics, and so code following a discard statement is not
+// considered unreachable.
+TEST_F(RemoveUnreachableStatementsTest, Discard) {
+ auto* src = R"(
+fn f() {
+ discard;
+ var preserve_me = 1;
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ discard;
+ var preserve_me = 1;
+}
+)";
+
+ auto got = Run<RemoveUnreachableStatements>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(RemoveUnreachableStatementsTest, IfReturn) {
+ auto* src = R"(
+fn f() {
+ if (true) {
+ return;
+ }
+ var preserve_me = 1;
+ if (true) {
+ var preserve_me_too = 1;
+ }
+}
+)";
+
+ auto* expect = src;
+
+ auto got = Run<RemoveUnreachableStatements>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(RemoveUnreachableStatementsTest, IfElseReturn) {
+ auto* src = R"(
+fn f() {
+ if (true) {
+ } else {
+ return;
+ }
+ var preserve_me = 1;
+ if (true) {
+ var preserve_me_too = 1;
+ }
+}
+)";
+
+ auto* expect = src;
+
+ auto got = Run<RemoveUnreachableStatements>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(RemoveUnreachableStatementsTest, LoopWithConditionalBreak) {
+ auto* src = R"(
+fn f() {
+ loop {
+ var a = 1;
+ if (true) {
+ break;
+ }
+
+ continuing {
+ var b = 2;
+ }
+ }
+ var preserve_me = 1;
+ if (true) {
+ var preserve_me_too = 1;
+ }
+}
+)";
+
+ auto* expect = src;
+
+ auto got = Run<RemoveUnreachableStatements>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(RemoveUnreachableStatementsTest, LoopWithConditionalBreakInContinuing) {
+ auto* src = R"(
+fn f() {
+ loop {
+
+ continuing {
+ break if true;
+ }
+ }
+ var preserve_me = 1;
+ if (true) {
+ var preserve_me_too = 1;
+ }
+}
+)";
+
+ auto* expect = src;
+
+ auto got = Run<RemoveUnreachableStatements>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(RemoveUnreachableStatementsTest, SwitchCaseReturnDefaultBreak) {
+ auto* src = R"(
+fn f() {
+ switch(1) {
+ case 0: {
+ return;
+ }
+ default: {
+ break;
+ }
+ }
+ var preserve_me = 1;
+ if (true) {
+ var preserve_me_too = 1;
+ }
+}
+)";
+
+ auto* expect = src;
+
+ auto got = Run<RemoveUnreachableStatements>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/renamer.cc b/src/tint/lang/wgsl/ast/transform/renamer.cc
new file mode 100644
index 0000000..34f796c
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/renamer.cc
@@ -0,0 +1,1404 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/renamer.h"
+
+#include <memory>
+#include <utility>
+
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/builtin_enum_expression.h"
+#include "src/tint/sem/call.h"
+#include "src/tint/sem/member_accessor_expression.h"
+#include "src/tint/sem/type_expression.h"
+#include "src/tint/sem/value_constructor.h"
+#include "src/tint/sem/value_conversion.h"
+#include "src/tint/switch.h"
+#include "src/tint/utils/unicode.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::Renamer);
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::Renamer::Data);
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::Renamer::Config);
+
+namespace tint::ast::transform {
+
+namespace {
+
+// This list is used for a binary search and must be kept in sorted order.
+const char* const kReservedKeywordsGLSL[] = {
+ "abs",
+ "acos",
+ "acosh",
+ "active",
+ "all",
+ "any",
+ "asin",
+ "asinh",
+ "asm",
+ "atan",
+ "atanh",
+ "atomicAdd",
+ "atomicAnd",
+ "atomicCompSwap",
+ "atomicCounter",
+ "atomicCounterDecrement",
+ "atomicCounterIncrement",
+ "atomicExchange",
+ "atomicMax",
+ "atomicMin",
+ "atomicOr",
+ "atomicXor",
+ "atomic_uint",
+ "attribute",
+ "barrier",
+ "bitCount",
+ "bitfieldExtract",
+ "bitfieldInsert",
+ "bitfieldReverse",
+ "bool",
+ "break",
+ "buffer",
+ "bvec2",
+ "bvec3",
+ "bvec4",
+ "case",
+ "cast",
+ "ceil",
+ "centroid",
+ "clamp",
+ "class",
+ "coherent",
+ "common",
+ "const",
+ "continue",
+ "cos",
+ "cosh",
+ "cross",
+ "dFdx",
+ "dFdy",
+ "default",
+ "degrees",
+ "determinant",
+ "discard",
+ "distance",
+ "dmat2",
+ "dmat2x2",
+ "dmat2x3",
+ "dmat2x4",
+ "dmat3",
+ "dmat3x2",
+ "dmat3x3",
+ "dmat3x4",
+ "dmat4",
+ "dmat4x2",
+ "dmat4x3",
+ "dmat4x4",
+ "do",
+ "dot",
+ "double",
+ "dvec2",
+ "dvec3",
+ "dvec4",
+ "else",
+ "enum",
+ "equal",
+ "exp",
+ "exp2",
+ "extern",
+ "external",
+ "faceforward",
+ "false",
+ "filter",
+ "findLSB",
+ "findMSB",
+ "fixed",
+ "flat",
+ "float",
+ "floatBitsToInt",
+ "floatBitsToUint",
+ "floor",
+ "for",
+ "fract",
+ "frexp",
+ "fvec2",
+ "fvec3",
+ "fvec4",
+ "fwidth",
+ "gl_BaseInstance",
+ "gl_BaseVertex",
+ "gl_ClipDistance",
+ "gl_DepthRangeParameters",
+ "gl_DrawID",
+ "gl_FragCoord",
+ "gl_FragDepth",
+ "gl_FrontFacing",
+ "gl_GlobalInvocationID",
+ "gl_InstanceID",
+ "gl_LocalInvocationID",
+ "gl_LocalInvocationIndex",
+ "gl_NumSamples",
+ "gl_NumWorkGroups",
+ "gl_PerVertex",
+ "gl_PointCoord",
+ "gl_PointSize",
+ "gl_Position",
+ "gl_PrimitiveID",
+ "gl_SampleID",
+ "gl_SampleMask",
+ "gl_SampleMaskIn",
+ "gl_SamplePosition",
+ "gl_VertexID",
+ "gl_WorkGroupID",
+ "gl_WorkGroupSize",
+ "goto",
+ "greaterThan",
+ "greaterThanEqual",
+ "groupMemoryBarrier",
+ "half",
+ "highp",
+ "hvec2",
+ "hvec3",
+ "hvec4",
+ "if",
+ "iimage1D",
+ "iimage1DArray",
+ "iimage2D",
+ "iimage2DArray",
+ "iimage2DMS",
+ "iimage2DMSArray",
+ "iimage2DRect",
+ "iimage3D",
+ "iimageBuffer",
+ "iimageCube",
+ "iimageCubeArray",
+ "image1D",
+ "image1DArray",
+ "image2D",
+ "image2DArray",
+ "image2DMS",
+ "image2DMSArray",
+ "image2DRect",
+ "image3D",
+ "imageBuffer",
+ "imageCube",
+ "imageCubeArray",
+ "imageLoad",
+ "imageSize",
+ "imageStore",
+ "imulExtended",
+ "in",
+ "inline",
+ "inout",
+ "input",
+ "int",
+ "intBitsToFloat",
+ "interface",
+ "invariant",
+ "inverse",
+ "inversesqrt",
+ "isampler1D",
+ "isampler1DArray",
+ "isampler2D",
+ "isampler2DArray",
+ "isampler2DMS",
+ "isampler2DMSArray",
+ "isampler2DRect",
+ "isampler3D",
+ "isamplerBuffer",
+ "isamplerCube",
+ "isamplerCubeArray",
+ "isinf",
+ "isnan",
+ "ivec2",
+ "ivec3",
+ "ivec4",
+ "layout",
+ "ldexp",
+ "length",
+ "lessThan",
+ "lessThanEqual",
+ "log",
+ "log2",
+ "long",
+ "lowp",
+ "main",
+ "mat2",
+ "mat2x2",
+ "mat2x3",
+ "mat2x4",
+ "mat3",
+ "mat3x2",
+ "mat3x3",
+ "mat3x4",
+ "mat4",
+ "mat4x2",
+ "mat4x3",
+ "mat4x4",
+ "matrixCompMult",
+ "max",
+ "mediump",
+ "memoryBarrier",
+ "memoryBarrierAtomicCounter",
+ "memoryBarrierBuffer",
+ "memoryBarrierImage",
+ "memoryBarrierShared",
+ "min",
+ "mix",
+ "mod",
+ "modf",
+ "namespace",
+ "noinline",
+ "noperspective",
+ "normalize",
+ "not",
+ "notEqual",
+ "out",
+ "outerProduct",
+ "output",
+ "packHalf2x16",
+ "packSnorm2x16",
+ "packSnorm4x8",
+ "packUnorm2x16",
+ "packUnorm4x8",
+ "partition",
+ "patch",
+ "pow",
+ "precise",
+ "precision",
+ "public",
+ "radians",
+ "readonly",
+ "reflect",
+ "refract",
+ "resource",
+ "restrict",
+ "return",
+ "round",
+ "roundEven",
+ "sample",
+ "sampler1D",
+ "sampler1DArray",
+ "sampler1DArrayShadow",
+ "sampler1DShadow",
+ "sampler2D",
+ "sampler2DArray",
+ "sampler2DArrayShadow",
+ "sampler2DMS",
+ "sampler2DMSArray",
+ "sampler2DRect",
+ "sampler2DRectShadow",
+ "sampler2DShadow",
+ "sampler3D",
+ "sampler3DRect",
+ "samplerBuffer",
+ "samplerCube",
+ "samplerCubeArray",
+ "samplerCubeArrayShadow",
+ "samplerCubeShadow",
+ "shared",
+ "short",
+ "sign",
+ "sin",
+ "sinh",
+ "sizeof",
+ "smooth",
+ "smoothstep",
+ "sqrt",
+ "static",
+ "step",
+ "struct",
+ "subroutine",
+ "superp",
+ "switch",
+ "tan",
+ "tanh",
+ "template",
+ "texelFetch",
+ "texelFetchOffset",
+ "texture",
+ "textureGather",
+ "textureGatherOffset",
+ "textureGrad",
+ "textureGradOffset",
+ "textureLod",
+ "textureLodOffset",
+ "textureOffset",
+ "textureProj",
+ "textureProjGrad",
+ "textureProjGradOffset",
+ "textureProjLod",
+ "textureProjLodOffset",
+ "textureProjOffset",
+ "textureSize",
+ "this",
+ "transpose",
+ "true",
+ "trunc",
+ "typedef",
+ "uaddCarry",
+ "uimage1D",
+ "uimage1DArray",
+ "uimage2D",
+ "uimage2DArray",
+ "uimage2DMS",
+ "uimage2DMSArray",
+ "uimage2DRect",
+ "uimage3D",
+ "uimageBuffer",
+ "uimageCube",
+ "uimageCubeArray",
+ "uint",
+ "uintBitsToFloat",
+ "umulExtended",
+ "uniform",
+ "union",
+ "unpackHalf2x16",
+ "unpackSnorm2x16",
+ "unpackSnorm4x8",
+ "unpackUnorm2x16",
+ "unpackUnorm4x8",
+ "unsigned",
+ "usampler1D",
+ "usampler1DArray",
+ "usampler2D",
+ "usampler2DArray",
+ "usampler2DMS",
+ "usampler2DMSArray",
+ "usampler2DRect",
+ "usampler3D",
+ "usamplerBuffer",
+ "usamplerCube",
+ "usamplerCubeArray",
+ "using",
+ "usubBorrow",
+ "uvec2",
+ "uvec3",
+ "uvec4",
+ "varying",
+ "vec2",
+ "vec3",
+ "vec4",
+ "void",
+ "volatile",
+ "while",
+ "writeonly",
+};
+
+// This list is used for a binary search and must be kept in sorted order.
+const char* const kReservedKeywordsHLSL[] = {
+ "AddressU",
+ "AddressV",
+ "AddressW",
+ "AllMemoryBarrier",
+ "AllMemoryBarrierWithGroupSync",
+ "AppendStructuredBuffer",
+ "BINORMAL",
+ "BLENDINDICES",
+ "BLENDWEIGHT",
+ "BlendState",
+ "BorderColor",
+ "Buffer",
+ "ByteAddressBuffer",
+ "COLOR",
+ "CheckAccessFullyMapped",
+ "ComparisonFunc",
+ "CompileShader",
+ "ComputeShader",
+ "ConsumeStructuredBuffer",
+ "D3DCOLORtoUBYTE4",
+ "DEPTH",
+ "DepthStencilState",
+ "DepthStencilView",
+ "DeviceMemoryBarrier",
+ "DeviceMemroyBarrierWithGroupSync",
+ "DomainShader",
+ "EvaluateAttributeAtCentroid",
+ "EvaluateAttributeAtSample",
+ "EvaluateAttributeSnapped",
+ "FOG",
+ "Filter",
+ "GeometryShader",
+ "GetRenderTargetSampleCount",
+ "GetRenderTargetSamplePosition",
+ "GroupMemoryBarrier",
+ "GroupMemroyBarrierWithGroupSync",
+ "Hullshader",
+ "InputPatch",
+ "InterlockedAdd",
+ "InterlockedAnd",
+ "InterlockedCompareExchange",
+ "InterlockedCompareStore",
+ "InterlockedExchange",
+ "InterlockedMax",
+ "InterlockedMin",
+ "InterlockedOr",
+ "InterlockedXor",
+ "LineStream",
+ "MaxAnisotropy",
+ "MaxLOD",
+ "MinLOD",
+ "MipLODBias",
+ "NORMAL",
+ "NULL",
+ "Normal",
+ "OutputPatch",
+ "POSITION",
+ "POSITIONT",
+ "PSIZE",
+ "PixelShader",
+ "PointStream",
+ "Process2DQuadTessFactorsAvg",
+ "Process2DQuadTessFactorsMax",
+ "Process2DQuadTessFactorsMin",
+ "ProcessIsolineTessFactors",
+ "ProcessQuadTessFactorsAvg",
+ "ProcessQuadTessFactorsMax",
+ "ProcessQuadTessFactorsMin",
+ "ProcessTriTessFactorsAvg",
+ "ProcessTriTessFactorsMax",
+ "ProcessTriTessFactorsMin",
+ "RWBuffer",
+ "RWByteAddressBuffer",
+ "RWStructuredBuffer",
+ "RWTexture1D",
+ "RWTexture1DArray",
+ "RWTexture2D",
+ "RWTexture2DArray",
+ "RWTexture3D",
+ "RasterizerState",
+ "RenderTargetView",
+ "SV_ClipDistance",
+ "SV_Coverage",
+ "SV_CullDistance",
+ "SV_Depth",
+ "SV_DepthGreaterEqual",
+ "SV_DepthLessEqual",
+ "SV_DispatchThreadID",
+ "SV_DomainLocation",
+ "SV_GSInstanceID",
+ "SV_GroupID",
+ "SV_GroupIndex",
+ "SV_GroupThreadID",
+ "SV_InnerCoverage",
+ "SV_InsideTessFactor",
+ "SV_InstanceID",
+ "SV_IsFrontFace",
+ "SV_OutputControlPointID",
+ "SV_Position",
+ "SV_PrimitiveID",
+ "SV_RenderTargetArrayIndex",
+ "SV_SampleIndex",
+ "SV_StencilRef",
+ "SV_Target",
+ "SV_TessFactor",
+ "SV_VertexArrayIndex",
+ "SV_VertexID",
+ "Sampler",
+ "Sampler1D",
+ "Sampler2D",
+ "Sampler3D",
+ "SamplerCUBE",
+ "SamplerComparisonState",
+ "SamplerState",
+ "StructuredBuffer",
+ "TANGENT",
+ "TESSFACTOR",
+ "TEXCOORD",
+ "Texcoord",
+ "Texture",
+ "Texture1D",
+ "Texture1DArray",
+ "Texture2D",
+ "Texture2DArray",
+ "Texture2DMS",
+ "Texture2DMSArray",
+ "Texture3D",
+ "TextureCube",
+ "TextureCubeArray",
+ "TriangleStream",
+ "VFACE",
+ "VPOS",
+ "VertexShader",
+ "abort",
+ "allow_uav_condition",
+ "asdouble",
+ "asfloat",
+ "asint",
+ "asm",
+ "asm_fragment",
+ "asuint",
+ "auto",
+ "bool",
+ "bool1",
+ "bool1x1",
+ "bool1x2",
+ "bool1x3",
+ "bool1x4",
+ "bool2",
+ "bool2x1",
+ "bool2x2",
+ "bool2x3",
+ "bool2x4",
+ "bool3",
+ "bool3x1",
+ "bool3x2",
+ "bool3x3",
+ "bool3x4",
+ "bool4",
+ "bool4x1",
+ "bool4x2",
+ "bool4x3",
+ "bool4x4",
+ "branch",
+ "break",
+ "call",
+ "case",
+ "catch",
+ "cbuffer",
+ "centroid",
+ "char",
+ "class",
+ "clip",
+ "column_major",
+ "compile",
+ "compile_fragment",
+ "const",
+ "const_cast",
+ "continue",
+ "countbits",
+ "ddx",
+ "ddx_coarse",
+ "ddx_fine",
+ "ddy",
+ "ddy_coarse",
+ "ddy_fine",
+ "default",
+ "degrees",
+ "delete",
+ "discard",
+ "do",
+ "double",
+ "double1",
+ "double1x1",
+ "double1x2",
+ "double1x3",
+ "double1x4",
+ "double2",
+ "double2x1",
+ "double2x2",
+ "double2x3",
+ "double2x4",
+ "double3",
+ "double3x1",
+ "double3x2",
+ "double3x3",
+ "double3x4",
+ "double4",
+ "double4x1",
+ "double4x2",
+ "double4x3",
+ "double4x4",
+ "dst",
+ "dword",
+ "dword1",
+ "dword1x1",
+ "dword1x2",
+ "dword1x3",
+ "dword1x4",
+ "dword2",
+ "dword2x1",
+ "dword2x2",
+ "dword2x3",
+ "dword2x4",
+ "dword3",
+ "dword3x1",
+ "dword3x2",
+ "dword3x3",
+ "dword3x4",
+ "dword4",
+ "dword4x1",
+ "dword4x2",
+ "dword4x3",
+ "dword4x4",
+ "dynamic_cast",
+ "else",
+ "enum",
+ "errorf",
+ "explicit",
+ "export",
+ "extern",
+ "f16to32",
+ "f32tof16",
+ "false",
+ "fastopt",
+ "firstbithigh",
+ "firstbitlow",
+ "flatten",
+ "float",
+ "float1",
+ "float1x1",
+ "float1x2",
+ "float1x3",
+ "float1x4",
+ "float2",
+ "float2x1",
+ "float2x2",
+ "float2x3",
+ "float2x4",
+ "float3",
+ "float3x1",
+ "float3x2",
+ "float3x3",
+ "float3x4",
+ "float4",
+ "float4x1",
+ "float4x2",
+ "float4x3",
+ "float4x4",
+ "fmod",
+ "for",
+ "forcecase",
+ "frac",
+ "friend",
+ "fxgroup",
+ "goto",
+ "groupshared",
+ "half",
+ "half1",
+ "half1x1",
+ "half1x2",
+ "half1x3",
+ "half1x4",
+ "half2",
+ "half2x1",
+ "half2x2",
+ "half2x3",
+ "half2x4",
+ "half3",
+ "half3x1",
+ "half3x2",
+ "half3x3",
+ "half3x4",
+ "half4",
+ "half4x1",
+ "half4x2",
+ "half4x3",
+ "half4x4",
+ "if",
+ "in",
+ "inline",
+ "inout",
+ "int",
+ "int1",
+ "int1x1",
+ "int1x2",
+ "int1x3",
+ "int1x4",
+ "int2",
+ "int2x1",
+ "int2x2",
+ "int2x3",
+ "int2x4",
+ "int3",
+ "int3x1",
+ "int3x2",
+ "int3x3",
+ "int3x4",
+ "int4",
+ "int4x1",
+ "int4x2",
+ "int4x3",
+ "int4x4",
+ "interface",
+ "isfinite",
+ "isinf",
+ "isnan",
+ "lerp",
+ "line",
+ "lineadj",
+ "linear",
+ "lit",
+ "log10",
+ "long",
+ "loop",
+ "mad",
+ "matrix",
+ "min10float",
+ "min10float1",
+ "min10float1x1",
+ "min10float1x2",
+ "min10float1x3",
+ "min10float1x4",
+ "min10float2",
+ "min10float2x1",
+ "min10float2x2",
+ "min10float2x3",
+ "min10float2x4",
+ "min10float3",
+ "min10float3x1",
+ "min10float3x2",
+ "min10float3x3",
+ "min10float3x4",
+ "min10float4",
+ "min10float4x1",
+ "min10float4x2",
+ "min10float4x3",
+ "min10float4x4",
+ "min12int",
+ "min12int1",
+ "min12int1x1",
+ "min12int1x2",
+ "min12int1x3",
+ "min12int1x4",
+ "min12int2",
+ "min12int2x1",
+ "min12int2x2",
+ "min12int2x3",
+ "min12int2x4",
+ "min12int3",
+ "min12int3x1",
+ "min12int3x2",
+ "min12int3x3",
+ "min12int3x4",
+ "min12int4",
+ "min12int4x1",
+ "min12int4x2",
+ "min12int4x3",
+ "min12int4x4",
+ "min16float",
+ "min16float1",
+ "min16float1x1",
+ "min16float1x2",
+ "min16float1x3",
+ "min16float1x4",
+ "min16float2",
+ "min16float2x1",
+ "min16float2x2",
+ "min16float2x3",
+ "min16float2x4",
+ "min16float3",
+ "min16float3x1",
+ "min16float3x2",
+ "min16float3x3",
+ "min16float3x4",
+ "min16float4",
+ "min16float4x1",
+ "min16float4x2",
+ "min16float4x3",
+ "min16float4x4",
+ "min16int",
+ "min16int1",
+ "min16int1x1",
+ "min16int1x2",
+ "min16int1x3",
+ "min16int1x4",
+ "min16int2",
+ "min16int2x1",
+ "min16int2x2",
+ "min16int2x3",
+ "min16int2x4",
+ "min16int3",
+ "min16int3x1",
+ "min16int3x2",
+ "min16int3x3",
+ "min16int3x4",
+ "min16int4",
+ "min16int4x1",
+ "min16int4x2",
+ "min16int4x3",
+ "min16int4x4",
+ "min16uint",
+ "min16uint1",
+ "min16uint1x1",
+ "min16uint1x2",
+ "min16uint1x3",
+ "min16uint1x4",
+ "min16uint2",
+ "min16uint2x1",
+ "min16uint2x2",
+ "min16uint2x3",
+ "min16uint2x4",
+ "min16uint3",
+ "min16uint3x1",
+ "min16uint3x2",
+ "min16uint3x3",
+ "min16uint3x4",
+ "min16uint4",
+ "min16uint4x1",
+ "min16uint4x2",
+ "min16uint4x3",
+ "min16uint4x4",
+ "msad4",
+ "mul",
+ "mutable",
+ "namespace",
+ "new",
+ "nointerpolation",
+ "noise",
+ "noperspective",
+ "numthreads",
+ "operator",
+ "out",
+ "packoffset",
+ "pass",
+ "pixelfragment",
+ "pixelshader",
+ "point",
+ "precise",
+ "printf",
+ "private",
+ "protected",
+ "public",
+ "radians",
+ "rcp",
+ "refract",
+ "register",
+ "reinterpret_cast",
+ "return",
+ "row_major",
+ "rsqrt",
+ "sample",
+ "sampler",
+ "sampler1D",
+ "sampler2D",
+ "sampler3D",
+ "samplerCUBE",
+ "sampler_state",
+ "saturate",
+ "shared",
+ "short",
+ "signed",
+ "sincos",
+ "sizeof",
+ "snorm",
+ "stateblock",
+ "stateblock_state",
+ "static",
+ "static_cast",
+ "string",
+ "struct",
+ "switch",
+ "tbuffer",
+ "technique",
+ "technique10",
+ "technique11",
+ "template",
+ "tex1D",
+ "tex1Dbias",
+ "tex1Dgrad",
+ "tex1Dlod",
+ "tex1Dproj",
+ "tex2D",
+ "tex2Dbias",
+ "tex2Dgrad",
+ "tex2Dlod",
+ "tex2Dproj",
+ "tex3D",
+ "tex3Dbias",
+ "tex3Dgrad",
+ "tex3Dlod",
+ "tex3Dproj",
+ "texCUBE",
+ "texCUBEbias",
+ "texCUBEgrad",
+ "texCUBElod",
+ "texCUBEproj",
+ "texture",
+ "texture1D",
+ "texture1DArray",
+ "texture2D",
+ "texture2DArray",
+ "texture2DMS",
+ "texture2DMSArray",
+ "texture3D",
+ "textureCube",
+ "textureCubeArray",
+ "this",
+ "throw",
+ "transpose",
+ "triangle",
+ "triangleadj",
+ "true",
+ "try",
+ "typedef",
+ "typename",
+ "uint",
+ "uint1",
+ "uint1x1",
+ "uint1x2",
+ "uint1x3",
+ "uint1x4",
+ "uint2",
+ "uint2x1",
+ "uint2x2",
+ "uint2x3",
+ "uint2x4",
+ "uint3",
+ "uint3x1",
+ "uint3x2",
+ "uint3x3",
+ "uint3x4",
+ "uint4",
+ "uint4x1",
+ "uint4x2",
+ "uint4x3",
+ "uint4x4",
+ "uniform",
+ "union",
+ "unorm",
+ "unroll",
+ "unsigned",
+ "using",
+ "vector",
+ "vertexfragment",
+ "vertexshader",
+ "virtual",
+ "void",
+ "volatile",
+ "while",
+};
+
+// MSL reserved words. Binary-searched, so keep in byte-wise sorted order (0-9 < A-Z < '_' < a-z).
+const char* const kReservedKeywordsMSL[] = {
+ "HUGE_VALF",
+ "HUGE_VALH",
+ "INFINITY",
+ "MAXFLOAT",
+ "MAXHALF",
+ "M_1_PI_F",
+ "M_1_PI_H",
+ "M_2_PI_F",
+ "M_2_PI_H",
+ "M_2_SQRTPI_F",
+ "M_2_SQRTPI_H",
+ "M_E_F",
+ "M_E_H",
+ "M_LN10_F",
+ "M_LN10_H",
+ "M_LN2_F",
+ "M_LN2_H",
+ "M_LOG10E_F",
+ "M_LOG10E_H",
+ "M_LOG2E_F",
+ "M_LOG2E_H",
+ "M_PI_2_F",
+ "M_PI_2_H",
+ "M_PI_4_F",
+ "M_PI_4_H",
+ "M_PI_F",
+ "M_PI_H",
+ "M_SQRT1_2_F",
+ "M_SQRT1_2_H",
+ "M_SQRT2_F",
+ "M_SQRT2_H",
+ "NAN",
+ "access",
+ "alignas",
+ "alignof",
+ "and",
+ "and_eq",
+ "array",
+ "array_ref",
+ "as_type",
+ "asm",
+ "atomic",
+ "atomic_bool",
+ "atomic_int",
+ "atomic_uint",
+ "auto",
+ "bitand",
+ "bitor",
+ "bool",
+ "bool2",
+ "bool3",
+ "bool4",
+ "break",
+ "buffer",
+ "case",
+ "catch",
+ "char",
+ "char16_t",
+ "char2",
+ "char3",
+ "char32_t",
+ "char4",
+ "class",
+ "compl",
+ "const",
+ "const_cast",
+ "const_reference",
+ "constant",
+ "constexpr",
+ "continue",
+ "decltype",
+ "default",
+ "delete",
+ "depth2d",
+ "depth2d_array",
+ "depth2d_ms",
+ "depth2d_ms_array",
+ "depthcube",
+ "depthcube_array",
+ "device",
+ "discard_fragment",
+ "do",
+ "double",
+ "dynamic_cast",
+ "else",
+ "enum",
+ "explicit",
+ "extern",
+ "false",
+ "final",
+ "float",
+ "float2",
+ "float2x2",
+ "float2x3",
+ "float2x4",
+ "float3",
+ "float3x2",
+ "float3x3",
+ "float3x4",
+ "float4",
+ "float4x2",
+ "float4x3",
+ "float4x4",
+ "for",
+ "fragment",
+ "friend",
+ "goto",
+ "half",
+ "half2",
+ "half2x2",
+ "half2x3",
+ "half2x4",
+ "half3",
+ "half3x2",
+ "half3x3",
+ "half3x4",
+ "half4",
+ "half4x2",
+ "half4x3",
+ "half4x4",
+ "if",
+ "imageblock",
+ "infinity",
+ "inline",
+ "int",
+ "int16_t",
+ "int2",
+ "int3",
+ "int32_t",
+ "int4",
+ "int64_t",
+ "int8_t",
+ "kernel",
+ "long",
+ "long2",
+ "long3",
+ "long4",
+ "main",
+ "matrix",
+ "metal",
+ "mutable",
+ "namespace",
+ "new",
+ "noexcept",
+ "not",
+ "not_eq",
+ "nullptr",
+ "operator",
+ "or",
+ "or_eq",
+ "override",
+ "packed_bool2",
+ "packed_bool3",
+ "packed_bool4",
+ "packed_char2",
+ "packed_char3",
+ "packed_char4",
+ "packed_float2",
+ "packed_float3",
+ "packed_float4",
+ "packed_half2",
+ "packed_half3",
+ "packed_half4",
+ "packed_int2",
+ "packed_int3",
+ "packed_int4",
+ "packed_short2",
+ "packed_short3",
+ "packed_short4",
+ "packed_uchar2",
+ "packed_uchar3",
+ "packed_uchar4",
+ "packed_uint2",
+ "packed_uint3",
+ "packed_uint4",
+ "packed_ushort2",
+ "packed_ushort3",
+ "packed_ushort4",
+ "patch_control_point",
+ "private",
+ "protected",
+ "ptrdiff_t",
+ "public",
+ "r16snorm",
+ "r16unorm",
+ "r8unorm",
+ "reference",
+ "register",
+ "reinterpret_cast",
+ "return",
+ "rg11b10f",
+ "rg16snorm",
+ "rg16unorm",
+ "rg8snorm",
+ "rg8unorm",
+ "rgb10a2",
+ "rgb9e5",
+ "rgba16snorm",
+ "rgba16unorm",
+ "rgba8snorm",
+ "rgba8unorm",
+ "sampler",
+ "short",
+ "short2",
+ "short3",
+ "short4",
+ "signed",
+ "size_t",
+ "sizeof",
+ "srgba8unorm",
+ "static",
+ "static_assert",
+ "static_cast",
+ "struct",
+ "switch",
+ "template",
+ "texture",
+ "texture1d",
+ "texture1d_array",
+ "texture2d",
+ "texture2d_array",
+ "texture2d_ms",
+ "texture2d_ms_array",
+ "texture3d",
+ "texture_buffer",
+ "texturecube",
+ "texturecube_array",
+ "this",
+ "thread",
+ "thread_local",
+ "threadgroup",
+ "threadgroup_imageblock",
+ "throw",
+ "true",
+ "try",
+ "typedef",
+ "typeid",
+ "typename",
+ "uchar",
+ "uchar2",
+ "uchar3",
+ "uchar4",
+ "uint",
+ "uint16_t",
+ "uint2",
+ "uint3",
+ "uint32_t",
+ "uint4",
+ "uint64_t",
+ "uint8_t",
+ "ulong2",
+ "ulong3",
+ "ulong4",
+ "uniform",
+ "union",
+ "unsigned",
+ "ushort",
+ "ushort2",
+ "ushort3",
+ "ushort4",
+ "using",
+ "vec",
+ "vertex",
+ "virtual",
+ "void",
+ "volatile",
+ "wchar_t",
+ "while",
+ "xor",
+ "xor_eq",
+};
+
+} // namespace
+
+Renamer::Data::Data(Remappings&& remaps) : remappings(std::move(remaps)) {}  // steals the map
+Renamer::Data::Data(const Data&) = default;
+Renamer::Data::~Data() = default;
+
+Renamer::Config::Config(Target tgt, bool keep) : target(tgt), preserve_unicode(keep) {}
+Renamer::Config::Config(const Config&) = default;
+Renamer::Config::~Config() = default;
+
+Renamer::Renamer() = default;
+Renamer::~Renamer() = default;
+
+Transform::ApplyResult Renamer::Apply(const Program* src,
+ const DataMap& inputs,
+ DataMap& outputs) const {
+ utils::Hashset<Symbol, 16> global_decls; // Symbols of the type declarations found in src.
+ for (auto* decl : src->AST().TypeDecls()) {
+ global_decls.Add(decl->name->symbol);
+ }
+
+ // Identifiers that must keep their original symbol; collected by the AST walk below.
+ utils::Hashset<const Identifier*, 16> preserved_identifiers;
+
+ for (auto* node : src->ASTNodes().Objects()) {
+ auto preserve_if_builtin_type = [&](const Identifier* ident) {
+ if (!global_decls.Contains(ident->symbol)) { // Not a type declared in src.
+ preserved_identifiers.Add(ident);
+ }
+ };
+
+ Switch(
+ node,
+ [&](const MemberAccessorExpression* accessor) { // Swizzles, builtin struct members.
+ auto* sem = src->Sem().Get(accessor)->UnwrapLoad();
+ if (sem->Is<sem::Swizzle>()) {
+ preserved_identifiers.Add(accessor->member);
+ } else if (auto* str_expr = src->Sem().GetVal(accessor->object)) {
+ if (auto* ty = str_expr->Type()->UnwrapRef()->As<type::Struct>()) {
+ if (!ty->Is<sem::Struct>()) { // Builtin structure
+ preserved_identifiers.Add(accessor->member);
+ }
+ }
+ }
+ },
+ [&](const DiagnosticAttribute* diagnostic) { // Rule name, plus category if present.
+ if (auto* category = diagnostic->control.rule_name->category) {
+ preserved_identifiers.Add(category);
+ }
+ preserved_identifiers.Add(diagnostic->control.rule_name->name);
+ },
+ [&](const DiagnosticDirective* diagnostic) { // Same handling as the attribute form.
+ if (auto* category = diagnostic->control.rule_name->category) {
+ preserved_identifiers.Add(category);
+ }
+ preserved_identifiers.Add(diagnostic->control.rule_name->name);
+ },
+ [&](const IdentifierExpression* expr) { // Builtin enumerators and builtin type names.
+ Switch(
+ src->Sem().Get(expr), //
+ [&](const sem::BuiltinEnumExpressionBase*) {
+ preserved_identifiers.Add(expr->identifier);
+ },
+ [&](const sem::TypeExpression*) {
+ preserve_if_builtin_type(expr->identifier);
+ });
+ },
+ [&](const CallExpression* call) { // Builtin calls; conversions/constructors of builtins.
+ Switch(
+ src->Sem().Get(call)->UnwrapMaterialize()->As<sem::Call>()->Target(),
+ [&](const sem::Builtin*) {
+ preserved_identifiers.Add(call->target->identifier);
+ },
+ [&](const sem::ValueConversion*) {
+ preserve_if_builtin_type(call->target->identifier);
+ },
+ [&](const sem::ValueConstructor*) {
+ preserve_if_builtin_type(call->target->identifier);
+ });
+ });
+ }
+
+ Target target = Target::kAll; // Defaults, used when no Config is present in inputs.
+ bool preserve_unicode = false;
+
+ if (auto* cfg = inputs.Get<Config>()) {
+ target = cfg->target;
+ preserve_unicode = cfg->preserve_unicode;
+ }
+
+ // Returns true if the symbol should be renamed based on the input configuration settings.
+ auto should_rename = [&](Symbol symbol) {
+ if (target == Target::kAll) { // preserve_unicode is not consulted for kAll.
+ return true;
+ }
+ auto name = symbol.Name();
+ if (!utils::utf8::IsASCII(name)) {
+ // name is non-ascii. All of the backend keywords are ascii, so rename if we're not
+ // preserving unicode symbols.
+ return !preserve_unicode;
+ }
+ switch (target) {
+ case Target::kGlslKeywords:
+ return std::binary_search(kReservedKeywordsGLSL,
+ kReservedKeywordsGLSL +
+ sizeof(kReservedKeywordsGLSL) / sizeof(const char*),
+ name) ||
+ name.compare(0, 3, "gl_") == 0 || name.find("__") != std::string::npos; // GLSL extras
+ case Target::kHlslKeywords:
+ return std::binary_search(
+ kReservedKeywordsHLSL,
+ kReservedKeywordsHLSL + sizeof(kReservedKeywordsHLSL) / sizeof(const char*),
+ name);
+ case Target::kMslKeywords:
+ return std::binary_search(
+ kReservedKeywordsMSL,
+ kReservedKeywordsMSL + sizeof(kReservedKeywordsMSL) / sizeof(const char*),
+ name);
+ default:
+ break;
+ }
+ return true; // Any other target value: rename.
+ };
+
+ utils::Hashmap<Symbol, Symbol, 32> remappings; // Original symbol -> replacement symbol.
+
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ false};
+
+ ctx.ReplaceAll([&](const Identifier* ident) -> const Identifier* {
+ const auto symbol = ident->symbol;
+ if (preserved_identifiers.Contains(ident) || !should_rename(symbol)) {
+ return nullptr; // Preserve symbol
+ }
+
+ // Create a replacement for this symbol, if we haven't already.
+ auto replacement = remappings.GetOrCreate(symbol, [&] { return b.Symbols().New(); });
+
+ // Reconstruct the identifier, keeping template arguments when there are any.
+ if (auto* tmpl_ident = ident->As<TemplatedIdentifier>()) {
+ auto args = ctx.Clone(tmpl_ident->arguments);
+ return ctx.dst->create<TemplatedIdentifier>(ctx.Clone(ident->source), replacement,
+ std::move(args), utils::Empty);
+ }
+ return ctx.dst->create<Identifier>(ctx.Clone(ident->source), replacement);
+ });
+
+ ctx.Clone();
+
+ Data::Remappings out; // Original name -> new name, handed to the caller through outputs.
+ for (auto it : remappings) {
+ out[it.key.Name()] = it.value.Name();
+ }
+ outputs.Add<Data>(std::move(out));
+
+ return Program(std::move(b));
+}
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/renamer.h b/src/tint/lang/wgsl/ast/transform/renamer.h
new file mode 100644
index 0000000..e6f4afd
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/renamer.h
@@ -0,0 +1,96 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_RENAMER_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_RENAMER_H_
+
+#include <string>
+#include <unordered_map>
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+
+namespace tint::ast::transform {
+
+/// Renamer is a Transform that renames symbols in a program (see Config for which ones).
+class Renamer final : public utils::Castable<Renamer, Transform> {
+ public:
+ /// Data is outputted by the Renamer transform.
+ /// Data holds the mapping from each renamed symbol's old name to its new name.
+ struct Data final : public utils::Castable<Data, tint::transform::Data> {
+ /// Remappings is a map of old symbol name to new symbol name
+ using Remappings = std::unordered_map<std::string, std::string>;
+
+ /// Constructor
+ /// @param remappings the symbol remappings
+ explicit Data(Remappings&& remappings);
+
+ /// Copy constructor
+ Data(const Data&);
+
+ /// Destructor
+ ~Data() override;
+
+ /// A map of old symbol name to new symbol name
+ const Remappings remappings;
+ };
+
+ /// Target enumerates the sets of symbols the Renamer can be asked to rename
+ enum class Target {
+ /// Rename every symbol.
+ kAll,
+ /// Only rename symbols reserved in GLSL (keywords, "gl_" prefix, or containing "__").
+ kGlslKeywords,
+ /// Only rename symbols that are reserved keywords in HLSL.
+ kHlslKeywords,
+ /// Only rename symbols that are reserved keywords in MSL.
+ kMslKeywords,
+ };
+
+ /// Optional configuration options for the transform.
+ /// If omitted, then the renamer will use Target::kAll.
+ struct Config final : public utils::Castable<Config, tint::transform::Data> {
+ /// Constructor
+ /// @param tgt the targets to rename
+ /// @param keep_unicode if true, symbols with non-ascii code-points are not
+ /// renamed. Has no effect with Target::kAll, which renames every symbol.
+ explicit Config(Target tgt, bool keep_unicode = false);
+
+ /// Copy constructor
+ Config(const Config&);
+
+ /// Destructor
+ ~Config() override;
+
+ /// The targets to rename
+ Target const target = Target::kAll;
+
+ /// If true, non-ascii symbols are kept as-is (ignored when target is Target::kAll).
+ bool preserve_unicode = false;
+ };
+
+ /// Constructor. Configuration is read from the Config in the inputs passed to Apply().
+ Renamer();
+
+ /// Destructor
+ ~Renamer() override;
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_RENAMER_H_
diff --git a/src/tint/lang/wgsl/ast/transform/renamer_test.cc b/src/tint/lang/wgsl/ast/transform/renamer_test.cc
new file mode 100644
index 0000000..7cfd18d
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/renamer_test.cc
@@ -0,0 +1,2061 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/renamer.h"
+
+#include <memory>
+#include <unordered_set>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "src/tint/builtin/builtin.h"
+#include "src/tint/builtin/texel_format.h"
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+#include "src/tint/utils/string.h"
+
+namespace tint::ast::transform {
+namespace {
+
+constexpr const char kUnicodeIdentifier[] = // UTF-8 bytes of "𝖎𝖉𝖊𝖓𝖙𝖎𝖋𝖎𝖊𝖗123" (10 x 4 bytes + "123")
+ "\xf0\x9d\x96\x8e\xf0\x9d\x96\x89\xf0\x9d\x96\x8a\xf0\x9d\x96\x93"
+ "\xf0\x9d\x96\x99\xf0\x9d\x96\x8e\xf0\x9d\x96\x8b\xf0\x9d\x96\x8e"
+ "\xf0\x9d\x96\x8a\xf0\x9d\x96\x97\x31\x32\x33";
+
+using ::testing::ContainerEq;
+
+using RenamerTest = TransformTest;
+
+TEST_F(RenamerTest, EmptyModule) {  // an empty program yields an empty remapping table
+    auto* src = "";
+    auto* expect = "";
+
+    auto got = Run<Renamer>(src);
+
+    EXPECT_EQ(expect, str(got));
+
+    auto* data = got.data.Get<Renamer::Data>();
+    ASSERT_NE(data, nullptr);  // fail the test instead of dereferencing a null Data below
+    ASSERT_EQ(data->remappings.size(), 0u);
+}
+
+TEST_F(RenamerTest, BasicModuleVertexIndex) {  // user symbols renamed, builtin values kept
+    const char* src = R"(
+fn test(vert_idx : u32) -> u32 {
+ return vert_idx;
+}
+
+@vertex
+fn entry(@builtin(vertex_index) vert_idx : u32
+ ) -> @builtin(position) vec4<f32> {
+ _ = test(vert_idx);
+ return vec4<f32>();
+}
+)";
+
+    const char* expect = R"(
+fn tint_symbol(tint_symbol_1 : u32) -> u32 {
+ return tint_symbol_1;
+}
+
+@vertex
+fn tint_symbol_2(@builtin(vertex_index) tint_symbol_1 : u32) -> @builtin(position) vec4<f32> {
+ _ = tint_symbol(tint_symbol_1);
+ return vec4<f32>();
+}
+)";
+
+    auto result = Run<Renamer>(src);
+
+    EXPECT_EQ(expect, str(result));
+
+    const auto* renamer_data = result.data.Get<Renamer::Data>();
+
+    ASSERT_NE(renamer_data, nullptr);
+    const Renamer::Data::Remappings want = {
+        {"vert_idx", "tint_symbol_1"},
+        {"test", "tint_symbol"},
+        {"entry", "tint_symbol_2"},
+    };
+    EXPECT_THAT(renamer_data->remappings, ContainerEq(want));
+}
+
+TEST_F(RenamerTest, PreserveSwizzles) {  // swizzles survive even when same-named vars exist
+    const char* src = R"(
+@vertex
+fn entry() -> @builtin(position) vec4<f32> {
+ var v : vec4<f32>;
+ var rgba : f32;
+ var xyzw : f32;
+ var z : f32;
+ return v.zyxw + v.rgab * v.z;
+}
+)";
+
+    const char* expect = R"(
+@vertex
+fn tint_symbol() -> @builtin(position) vec4<f32> {
+ var tint_symbol_1 : vec4<f32>;
+ var tint_symbol_2 : f32;
+ var tint_symbol_3 : f32;
+ var tint_symbol_4 : f32;
+ return (tint_symbol_1.zyxw + (tint_symbol_1.rgab * tint_symbol_1.z));
+}
+)";
+
+    auto result = Run<Renamer>(src);
+
+    EXPECT_EQ(expect, str(result));
+
+    const auto* renamer_data = result.data.Get<Renamer::Data>();
+
+    ASSERT_NE(renamer_data, nullptr);
+    const Renamer::Data::Remappings want = {
+        {"entry", "tint_symbol"},  {"v", "tint_symbol_1"}, {"rgba", "tint_symbol_2"},
+        {"xyzw", "tint_symbol_3"}, {"z", "tint_symbol_4"},
+    };
+    EXPECT_THAT(renamer_data->remappings, ContainerEq(want));
+}
+
+TEST_F(RenamerTest, PreserveBuiltins) {  // builtin function names (abs) are left alone
+    const char* src = R"(
+@vertex
+fn entry() -> @builtin(position) vec4<f32> {
+ var blah : vec4<f32>;
+ return abs(blah);
+}
+)";
+
+    const char* expect = R"(
+@vertex
+fn tint_symbol() -> @builtin(position) vec4<f32> {
+ var tint_symbol_1 : vec4<f32>;
+ return abs(tint_symbol_1);
+}
+)";
+
+    auto result = Run<Renamer>(src);
+
+    EXPECT_EQ(expect, str(result));
+
+    const auto* renamer_data = result.data.Get<Renamer::Data>();
+
+    ASSERT_NE(renamer_data, nullptr);
+    const Renamer::Data::Remappings want = {
+        {"entry", "tint_symbol"},
+        {"blah", "tint_symbol_1"},
+    };
+    EXPECT_THAT(renamer_data->remappings, ContainerEq(want));
+}
+
+TEST_F(RenamerTest, PreserveBuiltinTypes) {  // members of modf/frexp results are left alone
+    const char* src = R"(
+@compute @workgroup_size(1)
+fn entry() {
+ var a = modf(1.0).whole;
+ var b = modf(1.0).fract;
+ var c = frexp(1.0).fract;
+ var d = frexp(1.0).exp;
+}
+)";
+
+    const char* expect = R"(
+@compute @workgroup_size(1)
+fn tint_symbol() {
+ var tint_symbol_1 = modf(1.0).whole;
+ var tint_symbol_2 = modf(1.0).fract;
+ var tint_symbol_3 = frexp(1.0).fract;
+ var tint_symbol_4 = frexp(1.0).exp;
+}
+)";
+
+    auto result = Run<Renamer>(src);
+
+    EXPECT_EQ(expect, str(result));
+
+    const auto* renamer_data = result.data.Get<Renamer::Data>();
+
+    ASSERT_NE(renamer_data, nullptr);
+    const Renamer::Data::Remappings want = {
+        {"entry", "tint_symbol"}, {"a", "tint_symbol_1"}, {"b", "tint_symbol_2"},
+        {"c", "tint_symbol_3"},   {"d", "tint_symbol_4"},
+    };
+    EXPECT_THAT(renamer_data->remappings, ContainerEq(want));
+}
+
+TEST_F(RenamerTest, PreserveCoreDiagnosticRuleName) {  // diagnostic rule names are kept
+    const char* src = R"(
+diagnostic(off, chromium.unreachable_code);
+
+@diagnostic(off, derivative_uniformity)
+@fragment
+fn entry(@location(0) value : f32) -> @location(0) f32 {
+ if (value > 0) {
+ return dpdx(value);
+ return 0.0;
+ }
+ return 1.0;
+}
+)";
+
+    const char* expect = R"(
+diagnostic(off, chromium.unreachable_code);
+
+@diagnostic(off, derivative_uniformity) @fragment
+fn tint_symbol(@location(0) tint_symbol_1 : f32) -> @location(0) f32 {
+ if ((tint_symbol_1 > 0)) {
+ return dpdx(tint_symbol_1);
+ return 0.0;
+ }
+ return 1.0;
+}
+)";
+
+    auto result = Run<Renamer>(src);
+
+    EXPECT_EQ(expect, str(result));
+
+    const auto* renamer_data = result.data.Get<Renamer::Data>();
+
+    ASSERT_NE(renamer_data, nullptr);
+    const Renamer::Data::Remappings want = {
+        {"entry", "tint_symbol"},
+        {"value", "tint_symbol_1"},
+    };
+    EXPECT_THAT(renamer_data->remappings, ContainerEq(want));
+}
+
+TEST_F(RenamerTest, PreserveUnicode) {  // keyword target + preserve_unicode: no change
+    std::string src = R"(
+@fragment
+fn frag_main() {
+ var )" + std::string(kUnicodeIdentifier) +
+                      R"( : i32;
+}
+)";
+
+    std::string expect = src;
+
+    Transform::DataMap config;
+    config.Add<Renamer::Config>(Renamer::Target::kMslKeywords,
+                                /* preserve_unicode */ true);
+    auto result = Run<Renamer>(src, config);
+
+    EXPECT_EQ(expect, str(result));
+}
+
+TEST_F(RenamerTest, PreserveUnicodeRenameAll) {  // kAll renames despite preserve_unicode
+    std::string src = R"(
+@fragment
+fn frag_main() {
+ var )" + std::string(kUnicodeIdentifier) +
+                      R"( : i32;
+}
+)";
+
+    const char* expect = R"(
+@fragment
+fn tint_symbol() {
+ var tint_symbol_1 : i32;
+}
+)";
+
+    Transform::DataMap config;
+    config.Add<Renamer::Config>(Renamer::Target::kAll,
+                                /* preserve_unicode */ true);
+    auto result = Run<Renamer>(src, config);
+
+    EXPECT_EQ(expect, str(result));
+}
+
+TEST_F(RenamerTest, AttemptSymbolCollision) {  // inputs already named like generated symbols
+    const char* src = R"(
+@vertex
+fn entry() -> @builtin(position) vec4<f32> {
+ var tint_symbol : vec4<f32>;
+ var tint_symbol_2 : vec4<f32>;
+ var tint_symbol_4 : vec4<f32>;
+ return tint_symbol + tint_symbol_2 + tint_symbol_4;
+}
+)";
+
+    const char* expect = R"(
+@vertex
+fn tint_symbol() -> @builtin(position) vec4<f32> {
+ var tint_symbol_1 : vec4<f32>;
+ var tint_symbol_2 : vec4<f32>;
+ var tint_symbol_3 : vec4<f32>;
+ return ((tint_symbol_1 + tint_symbol_2) + tint_symbol_3);
+}
+)";
+
+    auto result = Run<Renamer>(src);
+
+    EXPECT_EQ(expect, str(result));
+
+    const auto* renamer_data = result.data.Get<Renamer::Data>();
+
+    ASSERT_NE(renamer_data, nullptr);
+    const Renamer::Data::Remappings want = {
+        {"entry", "tint_symbol"},
+        {"tint_symbol", "tint_symbol_1"},
+        {"tint_symbol_2", "tint_symbol_2"},
+        {"tint_symbol_4", "tint_symbol_3"},
+    };
+    EXPECT_THAT(renamer_data->remappings, ContainerEq(want));
+}
+
+TEST_F(RenamerTest, PreserveTexelFormatAndAccess) {  // rgba8unorm / write stay untouched
+    const char* src = R"(
+@group(0) @binding(0) var texture : texture_storage_2d<rgba8unorm, write>;
+
+fn f() {
+ var dims = textureDimensions(texture);
+}
+)";
+
+    const char* expect = R"(
+@group(0) @binding(0) var tint_symbol : texture_storage_2d<rgba8unorm, write>;
+
+fn tint_symbol_1() {
+ var tint_symbol_2 = textureDimensions(tint_symbol);
+}
+)";
+
+    auto result = Run<Renamer>(src);
+
+    EXPECT_EQ(expect, str(result));
+}
+
+TEST_F(RenamerTest, PreserveAddressSpace) {  // the address space name (private) is kept
+    const char* src = R"(
+var<private> p : i32;
+
+fn f() {
+ var v = p;
+}
+)";
+
+    const char* expect = R"(
+var<private> tint_symbol : i32;
+
+fn tint_symbol_1() {
+ var tint_symbol_2 = tint_symbol;
+}
+)";
+
+    auto result = Run<Renamer>(src);
+
+    EXPECT_EQ(expect, str(result));
+}
+using RenamerTestGlsl = TransformTestWithParam<std::string>;
+using RenamerTestHlsl = TransformTestWithParam<std::string>;
+using RenamerTestMsl = TransformTestWithParam<std::string>;
+
+TEST_P(RenamerTestGlsl, Keywords) {  // a variable named after a GLSL reserved word is renamed
+    const std::string reserved = GetParam();
+
+    std::string src = R"(
+@fragment
+fn frag_main() {
+ var )" + reserved +
+                      R"( : i32;
+}
+)";
+
+    const char* expect = R"(
+@fragment
+fn frag_main() {
+ var tint_symbol : i32;
+}
+)";
+
+    Transform::DataMap config;
+    config.Add<Renamer::Config>(Renamer::Target::kGlslKeywords,
+                                /* preserve_unicode */ false);
+    auto result = Run<Renamer>(src, config);
+
+    EXPECT_EQ(expect, str(result));
+}
+
+TEST_P(RenamerTestHlsl, Keywords) {  // a variable named after an HLSL reserved word is renamed
+    const std::string reserved = GetParam();
+
+    std::string src = R"(
+@fragment
+fn frag_main() {
+ var )" + reserved +
+                      R"( : i32;
+}
+)";
+
+    const char* expect = R"(
+@fragment
+fn frag_main() {
+ var tint_symbol : i32;
+}
+)";
+
+    Transform::DataMap config;
+    config.Add<Renamer::Config>(Renamer::Target::kHlslKeywords,
+                                /* preserve_unicode */ false);
+    auto result = Run<Renamer>(src, config);
+
+    EXPECT_EQ(expect, str(result));
+}
+
+TEST_P(RenamerTestMsl, Keywords) {  // a variable named after an MSL reserved word is renamed
+    const std::string reserved = GetParam();
+
+    std::string src = R"(
+@fragment
+fn frag_main() {
+ var )" + reserved +
+                      R"( : i32;
+}
+)";
+
+    const char* expect = R"(
+@fragment
+fn frag_main() {
+ var tint_symbol : i32;
+}
+)";
+
+    Transform::DataMap config;
+    config.Add<Renamer::Config>(Renamer::Target::kMslKeywords,
+                                /* preserve_unicode */ false);
+    auto result = Run<Renamer>(src, config);
+
+    EXPECT_EQ(expect, str(result));
+}
+
+INSTANTIATE_TEST_SUITE_P(RenamerTestGlsl,
+ RenamerTestGlsl,
+ testing::Values( // "active", // Also reserved in WGSL
+ // "asm", // WGSL keyword
+ "atomic_uint",
+ // "attribute", // Also reserved in WGSL
+ // "bool", // WGSL keyword
+ // "break", // WGSL keyword
+ "buffer",
+ "bvec2",
+ "bvec3",
+ "bvec4",
+ // "case", // WGSL keyword
+ // "cast", // Also reserved in WGSL
+ "centroid",
+ // "class", // Also reserved in WGSL
+ // "coherent", // Also reserved in WGSL
+ // "common", // Also reserved in WGSL
+ // "const", // WGSL keyword
+ // "continue", // WGSL keyword
+ // "default", // WGSL keyword
+ // "discard", // WGSL keyword
+ "dmat2",
+ "dmat2x2",
+ "dmat2x3",
+ "dmat2x4",
+ "dmat3",
+ "dmat3x2",
+ "dmat3x3",
+ "dmat3x4",
+ "dmat4",
+ "dmat4x2",
+ "dmat4x3",
+ "dmat4x4",
+ // "do", // WGSL keyword
+ "double",
+ "dvec2",
+ "dvec3",
+ "dvec4",
+ // "else", // WGSL keyword
+ // "enum", // WGSL keyword
+ // "extern", // Also reserved in WGSL
+ // "external", // Also reserved in WGSL
+ // "false", // WGSL keyword
+ // "filter", // Also reserved in WGSL
+ "fixed",
+ "flat",
+ "float",
+ // "for", // WGSL keyword
+ "fvec2",
+ "fvec3",
+ "fvec4",
+ "gl_BaseInstance",
+ "gl_BaseVertex",
+ "gl_ClipDistance",
+ "gl_DepthRangeParameters",
+ "gl_DrawID",
+ "gl_FragCoord",
+ "gl_FragDepth",
+ "gl_FrontFacing",
+ "gl_GlobalInvocationID",
+ "gl_InstanceID",
+ "gl_LocalInvocationID",
+ "gl_LocalInvocationIndex",
+ "gl_NumSamples",
+ "gl_NumWorkGroups",
+ "gl_PerVertex",
+ "gl_PointCoord",
+ "gl_PointSize",
+ "gl_Position",
+ "gl_PrimitiveID",
+ "gl_SampleID",
+ "gl_SampleMask",
+ "gl_SampleMaskIn",
+ "gl_SamplePosition",
+ "gl_VertexID",
+ "gl_WorkGroupID",
+ "gl_WorkGroupSize",
+ // "goto", // Also reserved in WGSL
+ "half",
+ // "highp", // Also reserved in WGSL
+ "hvec2",
+ "hvec3",
+ "hvec4",
+ // "if", // WGSL keyword
+ "iimage1D",
+ "iimage1DArray",
+ "iimage2D",
+ "iimage2DArray",
+ "iimage2DMS",
+ "iimage2DMSArray",
+ "iimage2DRect",
+ "iimage3D",
+ "iimageBuffer",
+ "iimageCube",
+ "iimageCubeArray",
+ "image1D",
+ "image1DArray",
+ "image2D",
+ "image2DArray",
+ "image2DMS",
+ "image2DMSArray",
+ "image2DRect",
+ "image3D",
+ "imageBuffer",
+ "imageCube",
+ "imageCubeArray",
+ "in",
+ // "inline", // Also reserved in WGSL
+ // "inout", // Also reserved in WGSL
+ "input",
+ "int",
+ // "interface", // Also reserved in WGSL
+ // "invariant", // Also reserved in WGSL
+ "isampler1D",
+ "isampler1DArray",
+ "isampler2D",
+ "isampler2DArray",
+ "isampler2DMS",
+ "isampler2DMSArray",
+ "isampler2DRect",
+ "isampler3D",
+ "isamplerBuffer",
+ "isamplerCube",
+ "isamplerCubeArray",
+ "ivec2",
+ "ivec3",
+ "ivec4",
+ // "layout", // Also reserved in WGSL
+ "long",
+ // "lowp", // Also reserved in WGSL
+ // "mat2x2", // WGSL keyword
+ // "mat2x3", // WGSL keyword
+ // "mat2x4", // WGSL keyword
+ // "mat2", // NOTE(review): disabled without a stated reason - confirm
+ "mat3",
+ // "mat3x2", // WGSL keyword
+ // "mat3x3", // WGSL keyword
+ // "mat3x4", // WGSL keyword
+ "mat4",
+ // "mat4x2", // WGSL keyword
+ // "mat4x3", // WGSL keyword
+ // "mat4x4", // WGSL keyword
+ // "mediump", // Also reserved in WGSL
+ // "namespace", // Also reserved in WGSL
+ // "noinline", // Also reserved in WGSL
+ // "noperspective", // Also reserved in WGSL
+ "out",
+ "output",
+ // "partition", // Also reserved in WGSL
+ // "patch", // Also reserved in WGSL
+ // "precise", // Also reserved in WGSL
+ // "precision", // Also reserved in WGSL
+ // "public", // Also reserved in WGSL
+ // "readonly", // Also reserved in WGSL
+ // "resource", // Also reserved in WGSL
+ // "restrict", // Also reserved in WGSL
+ // "return", // WGSL keyword
+ "sample",
+ "sampler1D",
+ "sampler1DArray",
+ "sampler1DArrayShadow",
+ "sampler1DShadow",
+ "sampler2D",
+ "sampler2DArray",
+ "sampler2DArrayShadow",
+ "sampler2DMS",
+ "sampler2DMSArray",
+ "sampler2DRect",
+ "sampler2DRectShadow",
+ "sampler2DShadow",
+ "sampler3D",
+ "sampler3DRect",
+ "samplerBuffer",
+ "samplerCube",
+ "samplerCubeArray",
+ "samplerCubeArrayShadow",
+ "samplerCubeShadow",
+ // "shared", // Also reserved in WGSL
+ "short",
+ // "sizeof", // Also reserved in WGSL
+ // "smooth", // Also reserved in WGSL
+ // "static", // Also reserved in WGSL
+ // "struct", // WGSL keyword
+ // "subroutine", // Also reserved in WGSL
+ "superp",
+ // "switch", // WGSL keyword
+ // "template", // Also reserved in WGSL
+ // "this", // Also reserved in WGSL
+ // "true", // WGSL keyword
+ // "typedef", // WGSL keyword
+ "uimage1D",
+ "uimage1DArray",
+ "uimage2D",
+ "uimage2DArray",
+ "uimage2DMS",
+ "uimage2DMSArray",
+ "uimage2DRect",
+ "uimage3D",
+ "uimageBuffer",
+ "uimageCube",
+ "uimageCubeArray",
+ "uint",
+ // "uniform", // WGSL keyword
+ // "union", // Also reserved in WGSL
+ "unsigned",
+ "usampler1D",
+ "usampler1DArray",
+ "usampler2D",
+ "usampler2DArray",
+ "usampler2DMS",
+ "usampler2DMSArray",
+ "usampler2DRect",
+ "usampler3D",
+ "usamplerBuffer",
+ "usamplerCube",
+ "usamplerCubeArray",
+ // "using", // WGSL keyword
+ "uvec2",
+ "uvec3",
+ "uvec4",
+ // "varying", // Also reserved in WGSL
+ // "vec2", // WGSL keyword
+ // "vec3", // WGSL keyword
+ // "vec4", // WGSL keyword
+ // "void", // WGSL keyword
+ // "volatile", // Also reserved in WGSL
+ // "while", // WGSL keyword
+ // "writeonly", // Also reserved in WGSL
+ kUnicodeIdentifier));
+
+INSTANTIATE_TEST_SUITE_P(RenamerTestHlsl,
+ RenamerTestHlsl,
+ testing::Values("AddressU",
+ "AddressV",
+ "AddressW",
+ "AllMemoryBarrier",
+ "AllMemoryBarrierWithGroupSync",
+ "AppendStructuredBuffer",
+ "BINORMAL",
+ "BLENDINDICES",
+ "BLENDWEIGHT",
+ "BlendState",
+ "BorderColor",
+ "Buffer",
+ "ByteAddressBuffer",
+ "COLOR",
+ "CheckAccessFullyMapped",
+ "ComparisonFunc",
+ // "CompileShader", // Also reserved in WGSL
+ // "ComputeShader", // Also reserved in WGSL
+ "ConsumeStructuredBuffer",
+ "D3DCOLORtoUBYTE4",
+ "DEPTH",
+ "DepthStencilState",
+ "DepthStencilView",
+ "DeviceMemoryBarrier",
+ "DeviceMemroyBarrierWithGroupSync",
+ // "DomainShader", // Also reserved in WGSL
+ "EvaluateAttributeAtCentroid",
+ "EvaluateAttributeAtSample",
+ "EvaluateAttributeSnapped",
+ "FOG",
+ "Filter",
+ // "GeometryShader", // Also reserved in WGSL
+ "GetRenderTargetSampleCount",
+ "GetRenderTargetSamplePosition",
+ "GroupMemoryBarrier",
+ "GroupMemroyBarrierWithGroupSync",
+ // "Hullshader", // Also reserved in WGSL
+ "InputPatch",
+ "InterlockedAdd",
+ "InterlockedAnd",
+ "InterlockedCompareExchange",
+ "InterlockedCompareStore",
+ "InterlockedExchange",
+ "InterlockedMax",
+ "InterlockedMin",
+ "InterlockedOr",
+ "InterlockedXor",
+ "LineStream",
+ "MaxAnisotropy",
+ "MaxLOD",
+ "MinLOD",
+ "MipLODBias",
+ "NORMAL",
+ // "NULL", // Also reserved in WGSL
+ "Normal",
+ "OutputPatch",
+ "POSITION",
+ "POSITIONT",
+ "PSIZE",
+ "PixelShader",
+ "PointStream",
+ "Process2DQuadTessFactorsAvg",
+ "Process2DQuadTessFactorsMax",
+ "Process2DQuadTessFactorsMin",
+ "ProcessIsolineTessFactors",
+ "ProcessQuadTessFactorsAvg",
+ "ProcessQuadTessFactorsMax",
+ "ProcessQuadTessFactorsMin",
+ "ProcessTriTessFactorsAvg",
+ "ProcessTriTessFactorsMax",
+ "ProcessTriTessFactorsMin",
+ "RWBuffer",
+ "RWByteAddressBuffer",
+ "RWStructuredBuffer",
+ "RWTexture1D",
+ "RWTexture1DArray",
+ "RWTexture2D",
+ "RWTexture2DArray",
+ "RWTexture3D",
+ "RasterizerState",
+ "RenderTargetView",
+ "SV_ClipDistance",
+ "SV_Coverage",
+ "SV_CullDistance",
+ "SV_Depth",
+ "SV_DepthGreaterEqual",
+ "SV_DepthLessEqual",
+ "SV_DispatchThreadID",
+ "SV_DomainLocation",
+ "SV_GSInstanceID",
+ "SV_GroupID",
+ "SV_GroupIndex",
+ "SV_GroupThreadID",
+ "SV_InnerCoverage",
+ "SV_InsideTessFactor",
+ "SV_InstanceID",
+ "SV_IsFrontFace",
+ "SV_OutputControlPointID",
+ "SV_Position",
+ "SV_PrimitiveID",
+ "SV_RenderTargetArrayIndex",
+ "SV_SampleIndex",
+ "SV_StencilRef",
+ "SV_Target",
+ "SV_TessFactor",
+ "SV_VertexArrayIndex",
+ "SV_VertexID",
+ "Sampler",
+ "Sampler1D",
+ "Sampler2D",
+ "Sampler3D",
+ "SamplerCUBE",
+ "SamplerComparisonState",
+ "SamplerState",
+ "StructuredBuffer",
+ "TANGENT",
+ "TESSFACTOR",
+ "TEXCOORD",
+ "Texcoord",
+ "Texture",
+ "Texture1D",
+ "Texture1DArray",
+ "Texture2D",
+ "Texture2DArray",
+ "Texture2DMS",
+ "Texture2DMSArray",
+ "Texture3D",
+ "TextureCube",
+ "TextureCubeArray",
+ "TriangleStream",
+ "VFACE",
+ "VPOS",
+ "VertexShader",
+ "abort",
+ // "abs", // WGSL builtin
+ // "acos", // WGSL builtin
+ // "all", // WGSL builtin
+ "allow_uav_condition",
+ // "any", // WGSL builtin
+ "asdouble",
+ "asfloat",
+ // "asin", // WGSL builtin
+ "asint",
+ // "asm", // WGSL keyword
+ // "asm_fragment", // Also reserved in WGSL
+ "asuint",
+ // "atan", // WGSL builtin
+ // "atan2", // WGSL builtin
+ // "auto", // Also reserved in WGSL
+ // "bool", // WGSL keyword
+ "bool1",
+ "bool1x1",
+ "bool1x2",
+ "bool1x3",
+ "bool1x4",
+ "bool2",
+ "bool2x1",
+ "bool2x2",
+ "bool2x3",
+ "bool2x4",
+ "bool3",
+ "bool3x1",
+ "bool3x2",
+ "bool3x3",
+ "bool3x4",
+ "bool4",
+ "bool4x1",
+ "bool4x2",
+ "bool4x3",
+ "bool4x4",
+ "branch",
+ // "break", // WGSL keyword
+ // "call", // WGSL builtin
+ // "case", // WGSL keyword
+ // "catch", // Also reserved in WGSL
+ "cbuffer",
+ // "ceil", // WGSL builtin
+ "centroid",
+ "char",
+ // "clamp", // WGSL builtin
+ // "class", // Also reserved in WGSL
+ "clip",
+ // "column_major", // Also reserved in WGSL
+ // "compile", // Also reserved in WGSL
+ // "compile_fragment", // Also reserved in WGSL
+ // "const", // WGSL keyword
+ // "const_cast", // Also reserved in WGSL
+ // "continue", // WGSL keyword
+ // "cos", // WGSL builtin
+ // "cosh", // WGSL builtin
+ "countbits",
+ // "cross", // WGSL builtin
+ "ddx",
+ "ddx_coarse",
+ "ddx_fine",
+ "ddy",
+ "ddy_coarse",
+ "ddy_fine",
+ // "default", // WGSL keyword
+ "degrees",
+ // "delete", // Also reserved in WGSL
+ // "determinant", // WGSL builtin
+ // "discard", // WGSL keyword
+ // "distance", // WGSL builtin
+ // "do", // WGSL keyword
+ // "dot", // WGSL builtin
+ "double",
+ "double1",
+ "double1x1",
+ "double1x2",
+ "double1x3",
+ "double1x4",
+ "double2",
+ "double2x1",
+ "double2x2",
+ "double2x3",
+ "double2x4",
+ "double3",
+ "double3x1",
+ "double3x2",
+ "double3x3",
+ "double3x4",
+ "double4",
+ "double4x1",
+ "double4x2",
+ "double4x3",
+ "double4x4",
+ "dst",
+ "dword",
+ "dword1",
+ "dword1x1",
+ "dword1x2",
+ "dword1x3",
+ "dword1x4",
+ "dword2",
+ "dword2x1",
+ "dword2x2",
+ "dword2x3",
+ "dword2x4",
+ "dword3",
+ "dword3x1",
+ "dword3x2",
+ "dword3x3",
+ "dword3x4",
+ "dword4",
+ "dword4x1",
+ "dword4x2",
+ "dword4x3",
+ "dword4x4",
+ // "dynamic_cast", // Also reserved in WGSL
+ // "else", // WGSL keyword
+ // "enum", // WGSL keyword
+ "errorf",
+ // "exp", // WGSL builtin
+ // "exp2", // WGSL builtin
+ // "explicit", // Also reserved in WGSL
+ // "export", // Also reserved in WGSL
+ // "extern", // Also reserved in WGSL
+ "f16to32",
+ "f32tof16",
+ // "faceforward", // WGSL builtin
+ // "false", // WGSL keyword
+ "fastopt",
+ "firstbithigh",
+ "firstbitlow",
+ "flatten",
+ "float",
+ "float1",
+ "float1x1",
+ "float1x2",
+ "float1x3",
+ "float1x4",
+ "float2",
+ "float2x1",
+ "float2x2",
+ "float2x3",
+ "float2x4",
+ "float3",
+ "float3x1",
+ "float3x2",
+ "float3x3",
+ "float3x4",
+ "float4",
+ "float4x1",
+ "float4x2",
+ "float4x3",
+ "float4x4",
+ // "floor", // WGSL builtin
+ // "fma", // WGSL builtin
+ "fmod",
+ // "for", // WGSL keyword
+ "forcecase",
+ "frac",
+ // "frexp", // WGSL builtin
+ // "friend", // Also reserved in WGSL
+ // "fwidth", // WGSL builtin
+ // "fxgroup", // Also reserved in WGSL
+ // "goto", // Also reserved in WGSL
+ // "groupshared", // Also reserved in WGSL
+ "half",
+ "half1",
+ "half1x1",
+ "half1x2",
+ "half1x3",
+ "half1x4",
+ "half2",
+ "half2x1",
+ "half2x2",
+ "half2x3",
+ "half2x4",
+ "half3",
+ "half3x1",
+ "half3x2",
+ "half3x3",
+ "half3x4",
+ "half4",
+ "half4x1",
+ "half4x2",
+ "half4x3",
+ "half4x4",
+ // "if", // WGSL keyword
+ // "in", // WGSL keyword
+ // "inline", // Also reserved in WGSL
+ // "inout", // Also reserved in WGSL
+ "int",
+ "int1",
+ "int1x1",
+ "int1x2",
+ "int1x3",
+ "int1x4",
+ "int2",
+ "int2x1",
+ "int2x2",
+ "int2x3",
+ "int2x4",
+ "int3",
+ "int3x1",
+ "int3x2",
+ "int3x3",
+ "int3x4",
+ "int4",
+ "int4x1",
+ "int4x2",
+ "int4x3",
+ "int4x4",
+ // "interface", // Also reserved in WGSL
+ "isfinite",
+ "isinf",
+ "isnan",
+ // "ldexp", // WGSL builtin
+ // "length", // WGSL builtin
+ "lerp",
+ // "line", // Also reserved in WGSL
+ // "lineadj", // Also reserved in WGSL
+ "linear",
+ "lit",
+ // "log", // WGSL builtin
+ "log10",
+ // "log2", // WGSL builtin
+ "long",
+ // "loop", // WGSL keyword
+ "mad",
+ "matrix",
+ // "max", // WGSL builtin
+ // "min", // WGSL builtin
+ "min10float",
+ "min10float1",
+ "min10float1x1",
+ "min10float1x2",
+ "min10float1x3",
+ "min10float1x4",
+ "min10float2",
+ "min10float2x1",
+ "min10float2x2",
+ "min10float2x3",
+ "min10float2x4",
+ "min10float3",
+ "min10float3x1",
+ "min10float3x2",
+ "min10float3x3",
+ "min10float3x4",
+ "min10float4",
+ "min10float4x1",
+ "min10float4x2",
+ "min10float4x3",
+ "min10float4x4",
+ "min12int",
+ "min12int1",
+ "min12int1x1",
+ "min12int1x2",
+ "min12int1x3",
+ "min12int1x4",
+ "min12int2",
+ "min12int2x1",
+ "min12int2x2",
+ "min12int2x3",
+ "min12int2x4",
+ "min12int3",
+ "min12int3x1",
+ "min12int3x2",
+ "min12int3x3",
+ "min12int3x4",
+ "min12int4",
+ "min12int4x1",
+ "min12int4x2",
+ "min12int4x3",
+ "min12int4x4",
+ "min16float",
+ "min16float1",
+ "min16float1x1",
+ "min16float1x2",
+ "min16float1x3",
+ "min16float1x4",
+ "min16float2",
+ "min16float2x1",
+ "min16float2x2",
+ "min16float2x3",
+ "min16float2x4",
+ "min16float3",
+ "min16float3x1",
+ "min16float3x2",
+ "min16float3x3",
+ "min16float3x4",
+ "min16float4",
+ "min16float4x1",
+ "min16float4x2",
+ "min16float4x3",
+ "min16float4x4",
+ "min16int",
+ "min16int1",
+ "min16int1x1",
+ "min16int1x2",
+ "min16int1x3",
+ "min16int1x4",
+ "min16int2",
+ "min16int2x1",
+ "min16int2x2",
+ "min16int2x3",
+ "min16int2x4",
+ "min16int3",
+ "min16int3x1",
+ "min16int3x2",
+ "min16int3x3",
+ "min16int3x4",
+ "min16int4",
+ "min16int4x1",
+ "min16int4x2",
+ "min16int4x3",
+ "min16int4x4",
+ "min16uint",
+ "min16uint1",
+ "min16uint1x1",
+ "min16uint1x2",
+ "min16uint1x3",
+ "min16uint1x4",
+ "min16uint2",
+ "min16uint2x1",
+ "min16uint2x2",
+ "min16uint2x3",
+ "min16uint2x4",
+ "min16uint3",
+ "min16uint3x1",
+ "min16uint3x2",
+ "min16uint3x3",
+ "min16uint3x4",
+ "min16uint4",
+ "min16uint4x1",
+ "min16uint4x2",
+ "min16uint4x3",
+ "min16uint4x4",
+ // "modf", // WGSL builtin
+ "msad4",
+ "mul",
+ // "mutable", // Also reserved in WGSL
+ // "namespace", // Also reserved in WGSL
+ // "new", // Also reserved in WGSL
+ // "nointerpolation", // Also reserved in WGSL
+ "noise",
+ // "noperspective", // Also reserved in WGSL
+ // "normalize", // WGSL builtin
+ "numthreads",
+ // "operator", // Also reserved in WGSL
+ // "out", // WGSL keyword
+ // "packoffset", // Also reserved in WGSL
+ // "pass", // Also reserved in WGSL
+ // "pixelfragment", // Also reserved in WGSL
+ "pixelshader",
+ // "point", // Also reserved in WGSL
+ // "pow", // WGSL builtin
+ // "precise", // Also reserved in WGSL
+ "printf",
+ // "private", // WGSL keyword
+ // "protected", // Also reserved in WGSL
+ // "public", // Also reserved in WGSL
+ "radians",
+ "rcp",
+ // "reflect", // WGSL builtin
+ "refract",
+ // "register", // Also reserved in WGSL
+ // "reinterpret_cast", // Also reserved in WGSL
+ // "return", // WGSL keyword
+ // "reversebits", // WGSL builtin
+ // "round", // WGSL builtin
+ "row_major",
+ "rsqrt",
+ "sample",
+ "sampler1D",
+ "sampler2D",
+ "sampler3D",
+ "samplerCUBE",
+ "sampler_state",
+ "saturate",
+ // "shared", // Also reserved in WGSL
+ "short",
+ // "sign", // WGSL builtin
+ // "signed", // Also reserved in WGSL
+ // "sin", // WGSL builtin
+ "sincos",
+ // "sinh", // WGSL builtin
+ // "sizeof", // Also reserved in WGSL
+ // "smoothstep", // WGSL builtin
+ // "snorm", // Also reserved in WGSL
+ // "sqrt", // WGSL builtin
+ "stateblock",
+ "stateblock_state",
+ // "static", // Also reserved in WGSL
+ // "static_cast", // Also reserved in WGSL
+ // "step", // WGSL builtin
+ "string",
+ // "struct", // WGSL keyword
+ // "switch", // WGSL keyword
+ // "tan", // WGSL builtin
+ // "tanh", // WGSL builtin
+ "tbuffer",
+ "technique",
+ "technique10",
+ "technique11",
+ // "template", // Also reserved in WGSL
+ "tex1D",
+ "tex1Dbias",
+ "tex1Dgrad",
+ "tex1Dlod",
+ "tex1Dproj",
+ "tex2D",
+ "tex2Dbias",
+ "tex2Dgrad",
+ "tex2Dlod",
+ "tex2Dproj",
+ "tex3D",
+ "tex3Dbias",
+ "tex3Dgrad",
+ "tex3Dlod",
+ "tex3Dproj",
+ "texCUBE",
+ "texCUBEbias",
+ "texCUBEgrad",
+ "texCUBElod",
+ "texCUBEproj",
+ "texture",
+ "texture1D",
+ "texture1DArray",
+ "texture2D",
+ "texture2DArray",
+ "texture2DMS",
+ "texture2DMSArray",
+ "texture3D",
+ "textureCube",
+ "textureCubeArray",
+ // "this", // Also reserved in WGSL
+ // "throw", // Also reserved in WGSL
+ "transpose",
+ "triangle",
+ "triangleadj",
+ // "true", // WGSL keyword
+ // "trunc", // WGSL builtin
+ // "try", // Also reserved in WGSL
+ // "typedef", // WGSL keyword
+ // "typename", // Also reserved in WGSL
+ "uint",
+ "uint1",
+ "uint1x1",
+ "uint1x2",
+ "uint1x3",
+ "uint1x4",
+ "uint2",
+ "uint2x1",
+ "uint2x2",
+ "uint2x3",
+ "uint2x4",
+ "uint3",
+ "uint3x1",
+ "uint3x2",
+ "uint3x3",
+ "uint3x4",
+ "uint4",
+ "uint4x1",
+ "uint4x2",
+ "uint4x3",
+ "uint4x4",
+ // "uniform", // WGSL keyword
+ // "union", // Also reserved in WGSL
+ // "unorm", // Also reserved in WGSL
+ "unroll",
+ "unsigned",
+ // "using", // WGSL reserved keyword
+ "vector",
+ "vertexfragment",
+ "vertexshader",
+ // "virtual", // Also reserved in WGSL
+ // "void", // WGSL keyword
+ // "volatile", // Also reserved in WGSL
+ // "while", // WGSL keyword
+ kUnicodeIdentifier));
+
+INSTANTIATE_TEST_SUITE_P(
+ RenamerTestMsl,
+ RenamerTestMsl,
+ testing::Values(
+ // c++14 spec
+ // "alignas", // Also reserved in WGSL
+ // "alignof", // Also reserved in WGSL
+ "and",
+ "and_eq",
+ // "asm", // Also reserved in WGSL
+ // "auto", // Also reserved in WGSL
+ "bitand",
+ "bitor",
+ // "bool", // Also used in WGSL
+ // "break", // Also used in WGSL
+ // "case", // Also used in WGSL
+ // "catch", // Also reserved in WGSL
+ "char",
+ "char16_t",
+ "char32_t",
+ // "class", // Also reserved in WGSL
+ "compl",
+ // "const", // Also used in WGSL
+ // "const_cast", // Also reserved in WGSL
+ // "constexpr", // Also reserved in WGSL
+ // "continue", // Also used in WGSL
+ // "decltype", // Also reserved in WGSL
+ // "default", // Also used in WGSL
+ // "delete", // Also reserved in WGSL
+ // "do", // Also used in WGSL
+ "double",
+ // "dynamic_cast", // Also reserved in WGSL
+ // "else", // Also used in WGSL
+ // "enum", // Also used in WGSL
+ // "explicit", // Also reserved in WGSL
+ // "extern", // Also reserved in WGSL
+ // "false", // Also used in WGSL
+ // "final", // Also reserved in WGSL
+ "float",
+ // "for", // Also used in WGSL
+ // "friend", // Also reserved in WGSL
+ // "goto", // Also reserved in WGSL
+ // "if", // Also used in WGSL
+ // "inline", // Also reserved in WGSL
+ "int",
+ "long",
+ // "mutable", // Also reserved in WGSL
+ // "namespace", // Also reserved in WGSL
+ // "new", // Also reserved in WGSL
+ // "noexcept", // Also reserved in WGSL
+ "not",
+ "not_eq",
+ // "nullptr", // Also reserved in WGSL
+ // "operator", // Also reserved in WGSL
+ "or",
+ "or_eq",
+ // "override", // Also used in WGSL
+ // "private", // Also used in WGSL
+ // "protected", // Also reserved in WGSL
+ // "public", // Also reserved in WGSL
+ // "register", // Also reserved in WGSL
+ // "reinterpret_cast", // Also reserved in WGSL
+ // "return", // Also used in WGSL
+ "short",
+ // "signed", // Also reserved in WGSL
+ // "sizeof", // Also reserved in WGSL
+ // "static", // Also reserved in WGSL
+ // "static_assert", // Also reserved in WGSL
+ // "static_cast", // Also reserved in WGSL
+ // "struct", // Also used in WGSL
+ // "switch", // Also used in WGSL
+ // "template", // Also reserved in WGSL
+ // "this", // Also reserved in WGSL
+ // "thread_local", // Also reserved in WGSL
+ // "throw", // Also reserved in WGSL
+ // "true", // Also used in WGSL
+ // "try", // Also reserved in WGSL
+ // "typedef", // Also used in WGSL
+ // "typeid", // Also reserved in WGSL
+ // "typename", // Also reserved in WGSL
+ // "union", // Also reserved in WGSL
+ "unsigned",
+ // "using", // WGSL reserved keyword
+ // "virtual", // Also reserved in WGSL
+ // "void", // Also used in WGSL
+ // "volatile", // Also reserved in WGSL
+ "wchar_t",
+ // "while", // WGSL reserved keyword
+ "xor",
+ "xor_eq",
+
+ // Metal Spec
+ "access",
+ // "array", // Also used in WGSL
+ "array_ref",
+ "as_type",
+ // "atomic", // Also used in WGSL
+ "atomic_bool",
+ "atomic_int",
+ "atomic_uint",
+ "bool2",
+ "bool3",
+ "bool4",
+ "buffer",
+ "char2",
+ "char3",
+ "char4",
+ "const_reference",
+ "constant",
+ "depth2d",
+ "depth2d_array",
+ "depth2d_ms",
+ "depth2d_ms_array",
+ "depthcube",
+ "depthcube_array",
+ "device",
+ "discard_fragment",
+ "float2",
+ "float2x2",
+ "float2x3",
+ "float2x4",
+ "float3",
+ "float3x2",
+ "float3x3",
+ "float3x4",
+ "float4",
+ "float4x2",
+ "float4x3",
+ "float4x4",
+ "fragment",
+ "half",
+ "half2",
+ "half2x2",
+ "half2x3",
+ "half2x4",
+ "half3",
+ "half3x2",
+ "half3x3",
+ "half3x4",
+ "half4",
+ "half4x2",
+ "half4x3",
+ "half4x4",
+ "imageblock",
+ "int16_t",
+ "int2",
+ "int3",
+ "int32_t",
+ "int4",
+ "int64_t",
+ "int8_t",
+ "kernel",
+ "long2",
+ "long3",
+ "long4",
+ "main", // No functions called main
+ "matrix",
+ "metal", // The namespace
+ "packed_bool2",
+ "packed_bool3",
+ "packed_bool4",
+ "packed_char2",
+ "packed_char3",
+ "packed_char4",
+ "packed_float2",
+ "packed_float3",
+ "packed_float4",
+ "packed_half2",
+ "packed_half3",
+ "packed_half4",
+ "packed_int2",
+ "packed_int3",
+ "packed_int4",
+ "packed_short2",
+ "packed_short3",
+ "packed_short4",
+ "packed_uchar2",
+ "packed_uchar3",
+ "packed_uchar4",
+ "packed_uint2",
+ "packed_uint3",
+ "packed_uint4",
+ "packed_ushort2",
+ "packed_ushort3",
+ "packed_ushort4",
+ "patch_control_point",
+ "ptrdiff_t",
+ "r16snorm",
+ "r16unorm",
+ "r8unorm",
+ "reference",
+ "rg11b10f",
+ "rg16snorm",
+ "rg16unorm",
+ "rg8snorm",
+ "rg8unorm",
+ "rgb10a2",
+ "rgb9e5",
+ "rgba16snorm",
+ "rgba16unorm",
+ "rgba8snorm",
+ "rgba8unorm",
+ // "sampler", // Also used in WGSL
+ "short2",
+ "short3",
+ "short4",
+ "size_t",
+ "srgba8unorm",
+ "texture",
+ "texture1d",
+ "texture1d_array",
+ "texture2d",
+ "texture2d_array",
+ "texture2d_ms",
+ "texture2d_ms_array",
+ "texture3d",
+ "texture_buffer",
+ "texturecube",
+ "texturecube_array",
+ "thread",
+ "threadgroup",
+ "threadgroup_imageblock",
+ "uchar",
+ "uchar2",
+ "uchar3",
+ "uchar4",
+ "uint",
+ "uint16_t",
+ "uint2",
+ "uint3",
+ "uint32_t",
+ "uint4",
+ "uint64_t",
+ "uint8_t",
+ "ulong2",
+ "ulong3",
+ "ulong4",
+ // "uniform", // Also used in WGSL
+ "ushort",
+ "ushort2",
+ "ushort3",
+ "ushort4",
+ // "vec", // WGSL reserved keyword
+ "vertex",
+
+ // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
+ // Table 6.5. Constants for single-precision floating-point math
+ // functions
+ "MAXFLOAT",
+ "HUGE_VALF",
+ "INFINITY",
+ "infinity",
+ "NAN",
+ "M_E_F",
+ "M_LOG2E_F",
+ "M_LOG10E_F",
+ "M_LN2_F",
+ "M_LN10_F",
+ "M_PI_F",
+ "M_PI_2_F",
+ "M_PI_4_F",
+ "M_1_PI_F",
+ "M_2_PI_F",
+ "M_2_SQRTPI_F",
+ "M_SQRT2_F",
+ "M_SQRT1_2_F",
+ "MAXHALF",
+ "HUGE_VALH",
+ "M_E_H",
+ "M_LOG2E_H",
+ "M_LOG10E_H",
+ "M_LN2_H",
+ "M_LN10_H",
+ "M_PI_H",
+ "M_PI_2_H",
+ "M_PI_4_H",
+ "M_1_PI_H",
+ "M_2_PI_H",
+ "M_2_SQRTPI_H",
+ "M_SQRT2_H",
+ "M_SQRT1_2_H",
+ // "while", // WGSL keyword
+ kUnicodeIdentifier));
+/// @returns the WGSL type text used for builtin type @p name: templated builtins get concrete arguments, other names are returned unchanged
+std::string ExpandBuiltinType(std::string_view name) {
+ if (name == "array") {  // 'array' and 'atomic' are expanded with fixed i32-based template arguments
+ return "array<i32, 4>";
+ }
+ if (name == "atomic") {
+ return "atomic<i32>";
+ }
+ if (name == "mat2x2") {  // matCxR: the bare and 'f'-suffixed names map to <f32>, the 'h'-suffixed name to <f16>
+ return "mat2x2<f32>";
+ }
+ if (name == "mat2x2f") {
+ return "mat2x2<f32>";
+ }
+ if (name == "mat2x2h") {
+ return "mat2x2<f16>";
+ }
+ if (name == "mat2x3") {
+ return "mat2x3<f32>";
+ }
+ if (name == "mat2x3f") {
+ return "mat2x3<f32>";
+ }
+ if (name == "mat2x3h") {
+ return "mat2x3<f16>";
+ }
+ if (name == "mat2x4") {
+ return "mat2x4<f32>";
+ }
+ if (name == "mat2x4f") {
+ return "mat2x4<f32>";
+ }
+ if (name == "mat2x4h") {
+ return "mat2x4<f16>";
+ }
+ if (name == "mat3x2") {
+ return "mat3x2<f32>";
+ }
+ if (name == "mat3x2f") {
+ return "mat3x2<f32>";
+ }
+ if (name == "mat3x2h") {
+ return "mat3x2<f16>";
+ }
+ if (name == "mat3x3") {
+ return "mat3x3<f32>";
+ }
+ if (name == "mat3x3f") {
+ return "mat3x3<f32>";
+ }
+ if (name == "mat3x3h") {
+ return "mat3x3<f16>";
+ }
+ if (name == "mat3x4") {
+ return "mat3x4<f32>";
+ }
+ if (name == "mat3x4f") {
+ return "mat3x4<f32>";
+ }
+ if (name == "mat3x4h") {
+ return "mat3x4<f16>";
+ }
+ if (name == "mat4x2") {
+ return "mat4x2<f32>";
+ }
+ if (name == "mat4x2f") {
+ return "mat4x2<f32>";
+ }
+ if (name == "mat4x2h") {
+ return "mat4x2<f16>";
+ }
+ if (name == "mat4x3") {
+ return "mat4x3<f32>";
+ }
+ if (name == "mat4x3f") {
+ return "mat4x3<f32>";
+ }
+ if (name == "mat4x3h") {
+ return "mat4x3<f16>";
+ }
+ if (name == "mat4x4") {
+ return "mat4x4<f32>";
+ }
+ if (name == "mat4x4f") {
+ return "mat4x4<f32>";
+ }
+ if (name == "mat4x4h") {
+ return "mat4x4<f16>";
+ }
+ if (name == "ptr") {  // 'ptr' is expanded to a function-address-space pointer to i32
+ return "ptr<function, i32>";
+ }
+ if (name == "vec2") {  // vecN: bare and 'f' map to <f32>, 'h' to <f16>, 'i' to <i32>, 'u' to <u32>
+ return "vec2<f32>";
+ }
+ if (name == "vec2f") {
+ return "vec2<f32>";
+ }
+ if (name == "vec2h") {
+ return "vec2<f16>";
+ }
+ if (name == "vec2i") {
+ return "vec2<i32>";
+ }
+ if (name == "vec2u") {
+ return "vec2<u32>";
+ }
+ if (name == "vec3") {
+ return "vec3<f32>";
+ }
+ if (name == "vec3f") {
+ return "vec3<f32>";
+ }
+ if (name == "vec3h") {
+ return "vec3<f16>";
+ }
+ if (name == "vec3i") {
+ return "vec3<i32>";
+ }
+ if (name == "vec3u") {
+ return "vec3<u32>";
+ }
+ if (name == "vec4") {
+ return "vec4<f32>";
+ }
+ if (name == "vec4f") {
+ return "vec4<f32>";
+ }
+ if (name == "vec4h") {
+ return "vec4<f16>";
+ }
+ if (name == "vec4i") {
+ return "vec4<i32>";
+ }
+ if (name == "vec4u") {
+ return "vec4<u32>";
+ }
+ return std::string(name);  // no expansion listed for this name: use it verbatim
+}
+/// @returns the entries of builtin::kBuiltinStrings except "ptr", "atomic" and names prefixed with "sampler", "texture" or "__"
+std::vector<const char*> ConstructableTypes() {
+ std::vector<const char*> out;
+ for (auto* ty : builtin::kBuiltinStrings) {
+ std::string_view type(ty);  // view of the C string so it can be compared with ==
+ if (type != "ptr" && type != "atomic" && !utils::HasPrefix(type, "sampler") &&
+ !utils::HasPrefix(type, "texture") && !utils::HasPrefix(type, "__")) {
+ out.push_back(ty);  // kept names feed RenamerBuiltinTypeTest, whose cases spell '$type()'
+ }
+ }
+ return out;
+}
+/// Fixture for the Renamer tests that are parameterized on a builtin type name (a C string read with GetParam()).
+using RenamerBuiltinTypeTest = TransformTestWithParam<const char*>;
+// User-declared identifiers are renamed while the builtin type text '$type' stays unchanged in every type position.
+TEST_P(RenamerBuiltinTypeTest, PreserveTypeUsage) {
+ auto expand = [&](const char* source) {  // fills the '$type' placeholder with the expanded parameter type
+ return utils::ReplaceAll(source, "$type", ExpandBuiltinType(GetParam()));
+ };
+ // Input: '$type' as parameter, return, const/let/var and struct member type; 'enable f16' covers the <f16> expansions.
+ auto src = expand(R"(
+enable f16;
+
+fn x(v : $type) -> $type {
+ const a : $type = $type();
+ let b : $type = a;
+ var c : $type = b;
+ return c;
+}
+
+struct y {
+ a : $type,
+}
+)");
+ // Expected: only declared names change; the local 'a' and the member 'a' both appear as tint_symbol_2.
+ auto expect = expand(R"(
+enable f16;
+
+fn tint_symbol(tint_symbol_1 : $type) -> $type {
+ const tint_symbol_2 : $type = $type();
+ let tint_symbol_3 : $type = tint_symbol_2;
+ var tint_symbol_4 : $type = tint_symbol_3;
+ return tint_symbol_4;
+}
+
+struct tint_symbol_5 {
+ tint_symbol_2 : $type,
+}
+)");
+
+ auto got = Run<Renamer>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+TEST_P(RenamerBuiltinTypeTest, PreserveTypeInitializer) {  // '$type()' used as an initializer must keep the builtin name
+ auto expand = [&](const char* source) {  // fills the '$type' placeholder with the expanded parameter type
+ return utils::ReplaceAll(source, "$type", ExpandBuiltinType(GetParam()));
+ };
+ // Input: a fragment entry point declaring one variable initialized with '$type()'.
+ auto src = expand(R"(
+enable f16;
+
+@fragment
+fn f() {
+ var v : $type = $type();
+}
+)");
+ // Expected: the function and the variable are renamed; both occurrences of '$type' are unchanged.
+ auto expect = expand(R"(
+enable f16;
+
+@fragment
+fn tint_symbol() {
+ var tint_symbol_1 : $type = $type();
+}
+)");
+
+ auto got = Run<Renamer>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+// '$type($type())' (a value conversion from the same type) must keep the builtin type name in both positions.
+TEST_P(RenamerBuiltinTypeTest, PreserveTypeConversion) {
+ if (std::string_view(GetParam()) == "array") {
+ return; // Cannot value convert arrays.
+ }
+
+ auto expand = [&](const char* source) {  // fills the '$type' placeholder with the expanded parameter type
+ return utils::ReplaceAll(source, "$type", ExpandBuiltinType(GetParam()));
+ };
+ // Input: a fragment entry point whose variable is initialized by converting a zero value of '$type'.
+ auto src = expand(R"(
+enable f16;
+
+@fragment
+fn f() {
+ var v : $type = $type($type());
+}
+)");
+ // Expected: the function and the variable are renamed; all three occurrences of '$type' are unchanged.
+ auto expect = expand(R"(
+enable f16;
+
+@fragment
+fn tint_symbol() {
+ var tint_symbol_1 : $type = $type($type());
+}
+)");
+
+ auto got = Run<Renamer>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+// Fixed (non-parameterized) case: the 'array<f32, 2>' type expression must be left unchanged by the renamer.
+TEST_F(RenamerBuiltinTypeTest, PreserveTypeExpression) {  // NOTE(review): TEST_F on the parameterized fixture; GetParam() is not used
+ auto src = R"(
+enable f16;
+
+@fragment
+fn f() {
+ var v : array<f32, 2> = array<f32, 2>();
+}
+)";
+ // Expected: only 'f' and 'v' are renamed.
+ auto expect = R"(
+enable f16;
+
+@fragment
+fn tint_symbol() {
+ var tint_symbol_1 : array<f32, 2> = array<f32, 2>();
+}
+)";
+
+ auto got = Run<Renamer>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+// A user alias declared with a builtin type's name shadows the builtin: the alias and its use must both be renamed.
+TEST_P(RenamerBuiltinTypeTest, RenameShadowedByAlias) {
+ auto expand = [&](const char* source) {
+ std::string_view ty = GetParam();
+ auto out = utils::ReplaceAll(source, "$name", ty);  // '$name' is the raw (unexpanded) builtin name
+ out = utils::ReplaceAll(out, "$type", ExpandBuiltinType(ty));  // NOTE(review): no template in this test contains '$type'; looks like a no-op
+ out = utils::ReplaceAll(out, "$other_type", ty == "i32" ? "u32" : "i32");  // aliased type always differs from '$name'
+ return out;
+ };
+ // Input: 'alias $name = ...' followed by a '$name()' call that refers to the alias.
+ auto src = expand(R"(
+alias $name = $other_type;
+
+@fragment
+fn f() {
+ var v : $other_type = $name();
+}
+)");
+ // Expected: the alias becomes tint_symbol at its declaration and at its use; '$other_type' is unchanged.
+ auto expect = expand(R"(
+alias tint_symbol = $other_type;
+
+@fragment
+fn tint_symbol_1() {
+ var tint_symbol_2 : $other_type = tint_symbol();
+}
+)");
+
+ auto got = Run<Renamer>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+// A user struct declared with a builtin type's name shadows the builtin: the struct, its member and their uses are renamed.
+TEST_P(RenamerBuiltinTypeTest, RenameShadowedByStruct) {
+ auto expand = [&](const char* source) {
+ std::string_view ty = GetParam();
+ auto out = utils::ReplaceAll(source, "$name", ty);  // '$name' is the raw (unexpanded) builtin name
+ out = utils::ReplaceAll(out, "$type", ExpandBuiltinType(ty));  // NOTE(review): no template in this test contains '$type'; looks like a no-op
+ out = utils::ReplaceAll(out, "$other_type", ty == "i32" ? "u32" : "i32");  // member type always differs from '$name'
+ return out;
+ };
+ // Input: 'struct $name', a '$name()' construction and a read of its member 'i'.
+ auto src = expand(R"(
+struct $name {
+ i : $other_type,
+}
+
+@fragment
+fn f() {
+ var a = $name();
+ var b = a.i;
+}
+)");
+ // Expected: struct -> tint_symbol and member -> tint_symbol_1, consistently at declaration and at use.
+ auto expect = expand(R"(
+struct tint_symbol {
+ tint_symbol_1 : $other_type,
+}
+
+@fragment
+fn tint_symbol_2() {
+ var tint_symbol_3 = tint_symbol();
+ var tint_symbol_4 = tint_symbol_3.tint_symbol_1;
+}
+)");
+
+ auto got = Run<Renamer>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+// Instantiates every TEST_P of RenamerBuiltinTypeTest once per type name returned by ConstructableTypes().
+INSTANTIATE_TEST_SUITE_P(RenamerBuiltinTypeTest,
+ RenamerBuiltinTypeTest,
+ testing::ValuesIn(ConstructableTypes()));
+
+/// @return the builtin, address space, texel format and access strings, minus the prefix-filtered entries noted below
+std::vector<const char*> Identifiers() {
+ std::vector<const char*> out;
+ for (auto* ident : builtin::kBuiltinStrings) {
+ if (!utils::HasPrefix(ident, "__")) {  // entries starting with a double underscore are skipped
+ out.push_back(ident);
+ }
+ }
+ for (auto* ident : builtin::kAddressSpaceStrings) {
+ if (!utils::HasPrefix(ident, "_")) {  // note: a single underscore is the filter for this table
+ out.push_back(ident);
+ }
+ }
+ for (auto* ident : builtin::kTexelFormatStrings) {  // texel formats are taken unfiltered
+ out.push_back(ident);
+ }
+ for (auto* ident : builtin::kAccessStrings) {  // access modes are taken unfiltered
+ out.push_back(ident);
+ }
+ return out;
+}
+/// Fixture for the Renamer tests that are parameterized on a builtin identifier (a C string read with GetParam()).
+using RenamerBuiltinIdentifierTest = TransformTestWithParam<const char*>;
+// A module-scope const named after a builtin identifier must be renamed at its declaration and at its use.
+TEST_P(RenamerBuiltinIdentifierTest, GlobalConstName) {
+ auto expand = [&](const char* source) {  // fills the '$name' placeholder with the parameter identifier
+ return utils::ReplaceAll(source, "$name", GetParam());
+ };
+ // Input: 'const $name' at module scope, read inside function 'f'.
+ auto src = expand(R"(
+const $name = 42;
+
+fn f() {
+ const v = $name;
+}
+)");
+ // Expected: no '$name' remains; the const is tint_symbol in both places.
+ auto expect = expand(R"(
+const tint_symbol = 42;
+
+fn tint_symbol_1() {
+ const tint_symbol_2 = tint_symbol;
+}
+)");
+
+ auto got = Run<Renamer>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+// A function-scope var named after a builtin identifier must be renamed.
+TEST_P(RenamerBuiltinIdentifierTest, LocalVarName) {
+ auto expand = [&](const char* source) {  // fills the '$name' placeholder with the parameter identifier
+ return utils::ReplaceAll(source, "$name", GetParam());
+ };
+ // Input: 'var $name' declared inside function 'f'.
+ auto src = expand(R"(
+fn f() {
+ var $name = 42;
+}
+)");
+ // Expected: the function is tint_symbol and the variable is tint_symbol_1.
+ auto expect = expand(R"(
+fn tint_symbol() {
+ var tint_symbol_1 = 42;
+}
+)");
+
+ auto got = Run<Renamer>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+// A user function named after a builtin identifier must be renamed at its declaration and at its call site.
+TEST_P(RenamerBuiltinIdentifierTest, FunctionName) {
+ auto expand = [&](const char* source) {  // fills the '$name' placeholder with the parameter identifier
+ return utils::ReplaceAll(source, "$name", GetParam());
+ };
+ // Input: 'fn $name()' and a call to it from 'f'.
+ auto src = expand(R"(
+fn $name() {
+}
+
+fn f() {
+ $name();
+}
+)");
+ // Expected: the declaration and the call both use tint_symbol.
+ auto expect = expand(R"(
+fn tint_symbol() {
+}
+
+fn tint_symbol_1() {
+ tint_symbol();
+}
+)");
+
+ auto got = Run<Renamer>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+// A user struct named after a builtin identifier must be renamed at its declaration and where it is constructed.
+TEST_P(RenamerBuiltinIdentifierTest, StructName) {
+ auto expand = [&](const char* source) {
+ std::string_view name = GetParam();
+ auto out = utils::ReplaceAll(source, "$name", name);
+ return utils::ReplaceAll(out, "$other_type", name == "i32" ? "u32" : "i32");  // member type always differs from '$name'
+ };
+ // Input: 'struct $name' with one member, constructed inside 'f'.
+ auto src = expand(R"(
+struct $name {
+ i : $other_type,
+}
+
+fn f() {
+ var x = $name();
+}
+)");
+ // Expected: struct -> tint_symbol, member -> tint_symbol_1; '$other_type' is unchanged.
+ auto expect = expand(R"(
+struct tint_symbol {
+ tint_symbol_1 : $other_type,
+}
+
+fn tint_symbol_2() {
+ var tint_symbol_3 = tint_symbol();
+}
+)");
+
+ auto got = Run<Renamer>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+// Instantiates every TEST_P of RenamerBuiltinIdentifierTest once per identifier returned by Identifiers().
+INSTANTIATE_TEST_SUITE_P(RenamerBuiltinIdentifierTest,
+ RenamerBuiltinIdentifierTest,
+ testing::ValuesIn(Identifiers()));
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/robustness.cc b/src/tint/lang/wgsl/ast/transform/robustness.cc
new file mode 100644
index 0000000..c5b7bbb
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/robustness.cc
@@ -0,0 +1,737 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/robustness.h"
+
+#include <algorithm>
+#include <limits>
+#include <utility>
+
+#include "src/tint/lang/wgsl/ast/transform/utils/hoist_to_decl_before.h"
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/block_statement.h"
+#include "src/tint/sem/builtin.h"
+#include "src/tint/sem/call.h"
+#include "src/tint/sem/function.h"
+#include "src/tint/sem/index_accessor_expression.h"
+#include "src/tint/sem/load.h"
+#include "src/tint/sem/member_accessor_expression.h"
+#include "src/tint/sem/statement.h"
+#include "src/tint/sem/value_expression.h"
+#include "src/tint/switch.h"
+#include "src/tint/type/reference.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::Robustness);
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::Robustness::Config);
+
+namespace tint::ast::transform {
+
+using namespace tint::builtin::fluent_types; // NOLINT
+using namespace tint::number_suffixes; // NOLINT
+
+/// PIMPL state for the transform
+struct Robustness::State {
+ /// Constructor. Builds the state for a single run of the transform over @p p.
+ /// @param p the source program
+ /// @param c the transform config, which is moved into the state
+ State(const Program* p, Config&& c) : src(p), cfg(std::move(c)) {}
+
+ /// Runs the transform
+ /// @returns the new, transformed program (this implementation always clones and returns one)
+ ApplyResult Run() {
+ if (HasAction(Action::kPredicate)) {
+ AddPredicateParameters();
+ }
+
+ // Walk all the AST nodes in the module, starting with the leaf nodes.
+ // The most deeply nested expressions will come first.
+ for (auto* node : ctx.src->ASTNodes().Objects()) {
+ Switch(
+ node, //
+ [&](const IndexAccessorExpression* e) {
+ // obj[idx]
+ // Array, matrix and vector indexing may require robustness transformation.
+ auto* expr = sem.Get(e)->Unwrap()->As<sem::IndexAccessorExpression>();
+ if (IsIgnoredResourceBinding(expr->Object()->RootIdentifier())) {
+ return;
+ }
+ if (cfg.disable_runtime_sized_array_index_clamping &&
+ IsIndexAccessingRuntimeSizedArray(expr)) {
+ // Ensure the index is always u32, as using a negative index is undefined
+ // behavior in SPIR-V.
+ auto* idx = CastToU32(expr->Index());
+ ctx.Replace(expr->Declaration()->index, idx);
+ return;
+ }
+ switch (ActionFor(expr)) {
+ case Action::kPredicate:
+ PredicateIndexAccessor(expr);
+ break;
+ case Action::kClamp:
+ ClampIndexAccessor(expr);
+ break;
+ case Action::kIgnore:
+ break;
+ }
+ },
+ [&](const IdentifierExpression* e) {
+ // Identifiers may resolve to pointer lets, which may be predicated.
+ // Inspect.
+ if (auto* user = sem.Get<sem::VariableUser>(e)) {
+ auto* v = user->Variable();
+ if (v->Type()->Is<type::Pointer>()) {
+ // Propagate predicate from pointer
+ if (auto pred = predicates.Get(v->Declaration()->initializer)) {
+ predicates.Add(e, *pred);
+ }
+ }
+ }
+ },
+ [&](const AccessorExpression* e) {
+ // obj.member
+ // Propagate the predication from the object to this expression.
+ if (auto pred = predicates.Get(e->object)) {
+ predicates.Add(e, *pred);
+ }
+ },
+ [&](const UnaryOpExpression* e) {
+ // Includes address-of, or indirection
+ // Propagate the predication from the inner expression to this expression.
+ if (auto pred = predicates.Get(e->expr)) {
+ predicates.Add(e, *pred);
+ }
+ },
+ [&](const AssignmentStatement* s) {
+ if (auto pred = predicates.Get(s->lhs)) {
+ // Assignment target is predicated
+ // Replace statement with condition on the predicate
+ ctx.Replace(s, b.If(*pred, b.Block(ctx.Clone(s))));
+ }
+ },
+ [&](const CompoundAssignmentStatement* s) {
+ if (auto pred = predicates.Get(s->lhs)) {
+ // Compound assignment target is predicated
+ // Replace statement with condition on the predicate
+ ctx.Replace(s, b.If(*pred, b.Block(ctx.Clone(s))));
+ }
+ },
+ [&](const IncrementDecrementStatement* s) {
+ if (auto pred = predicates.Get(s->lhs)) {
+ // Increment / decrement target is predicated
+ // Replace statement with condition on the predicate
+ ctx.Replace(s, b.If(*pred, b.Block(ctx.Clone(s))));
+ }
+ },
+ [&](const CallExpression* e) {
+ if (auto* call = sem.Get<sem::Call>(e)) {
+ Switch(
+ call->Target(), //
+ [&](const sem::Builtin* builtin) {
+ // Calls to builtins may require robustness transformation.
+ // Inspect.
+ if (builtin->IsTexture()) {
+ switch (cfg.texture_action) {
+ case Action::kPredicate:
+ PredicateTextureBuiltin(call, builtin);
+ break;
+ case Action::kClamp:
+ ClampTextureBuiltin(call, builtin);
+ break;
+ case Action::kIgnore:
+ break;
+ }
+ } else {
+ MaybePredicateNonTextureBuiltin(call, builtin);
+ }
+ },
+ [&](const sem::Function* fn) {
+ // Calls to user functions may require passing additional predicate
+ // arguments.
+ InsertPredicateArguments(call, fn);
+ });
+ }
+ });
+
+ // Check whether the node is an expression that:
+ // * Has a predicate
+ // * Is of neither a pointer nor a reference type
+ // If the above is true, then we need to predicate evaluation of this expression by
+ // replacing `expr` with `predicated_expr` and injecting the following above the
+ // expression's statement:
+ //
+ // var predicated_expr : expr_ty;
+ // if (predicate) {
+ // predicated_expr = expr;
+ // }
+ //
+ if (auto* expr = node->As<Expression>()) {
+ if (auto pred = predicates.Get(expr)) {
+ // Expression is predicated
+ auto* sem_expr = sem.GetVal(expr);
+ if (!sem_expr->Type()->IsAnyOf<type::Reference, type::Pointer>()) {
+ auto pred_load = b.Symbols().New("predicated_expr");
+ auto ty = CreateASTTypeFor(ctx, sem_expr->Type());
+ hoist.InsertBefore(sem_expr->Stmt(), b.Decl(b.Var(pred_load, ty)));
+ hoist.InsertBefore(
+ sem_expr->Stmt(),
+ b.If(*pred, b.Block(b.Assign(pred_load, ctx.Clone(expr)))));
+ ctx.Replace(expr, b.Expr(pred_load));
+
+ // The predication has been consumed for this expression.
+ // Don't predicate expressions that use this expression.
+ predicates.Remove(expr);
+ }
+ }
+ }
+ }
+
+ ctx.Clone();
+ return Program(std::move(b));
+ }
+
+ private:
+ /// The source program
+ const Program* const src;
+ /// The transform's config
+ Config cfg;
+ /// The target program builder
+ ProgramBuilder b{};
+ /// The clone context, cloning from `src` into `b` with symbols automatically cloned
+ CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
+ /// Helper for hoisting declarations
+ HoistToDeclBefore hoist{ctx};
+ /// Alias to the source program's semantic info
+ const sem::Info& sem = ctx.src->Sem();
+ /// Map of source expression to the symbol of the predicate condition that guards it
+ utils::Hashmap<const Expression*, Symbol, 32> predicates{};
+
+ /// @return the `u32` typed expression that represents the maximum indexable value for the index
+ /// accessor @p expr, or nullptr if there is no robustness limit for this expression.
+ const Expression* DynamicLimitFor(const sem::IndexAccessorExpression* expr) {
+ auto* obj_type = expr->Object()->Type();
+ return Switch(
+ obj_type->UnwrapRef(), //
+ [&](const type::Vector* vec) -> const Expression* {
+ if (expr->Index()->ConstantValue() || expr->Index()->Is<sem::Swizzle>()) {
+ // Index and size are constant.
+ // Validation will have rejected any OOB accesses.
+ return nullptr;
+ }
+ return b.Expr(u32(vec->Width() - 1u));
+ },
+ [&](const type::Matrix* mat) -> const Expression* {
+ if (expr->Index()->ConstantValue()) {
+ // Index and size are constant.
+ // Validation will have rejected any OOB accesses.
+ return nullptr;
+ }
+ return b.Expr(u32(mat->columns() - 1u));
+ },
+ [&](const type::Array* arr) -> const Expression* {
+ if (arr->Count()->Is<type::RuntimeArrayCount>()) {
+ // Size is unknown until runtime.
+ // Must clamp, even if the index is constant.
+
+ auto* arr_ptr = b.AddressOf(ctx.Clone(expr->Object()->Declaration()));
+ return b.Sub(b.Call(builtin::Function::kArrayLength, arr_ptr), 1_u);  // arrayLength(&arr) - 1u
+ }
+ if (auto count = arr->ConstantCount()) {
+ if (expr->Index()->ConstantValue()) {
+ // Index and size are constant.
+ // Validation will have rejected any OOB accesses.
+ return nullptr;
+ }
+ return b.Expr(u32(count.value() - 1u));
+ }
+ // Note: Don't be tempted to use the array override variable as an expression here,
+ // the name might be shadowed!
+ b.Diagnostics().add_error(diag::System::Transform,
+ type::Array::kErrExpectedConstantCount);
+ return nullptr;
+ },
+ [&](Default) -> const Expression* {
+ TINT_ICE(Transform, b.Diagnostics())
+ << "unhandled object type in robustness of array index: "
+ << obj_type->UnwrapRef()->FriendlyName();
+ return nullptr;
+ });
+ }
+
+ /// Transforms the program to insert an additional `bool` predicate parameter after each user
+ /// function parameter that has a pointer type in an address space that has predicate action.
+ void AddPredicateParameters() {
+ for (auto* fn : src->AST().Functions()) {
+ for (auto* param : fn->params) {
+ auto* sem_param = sem.Get(param);
+ if (auto* ptr = sem_param->Type()->As<type::Pointer>()) {
+ if (ActionFor(ptr->AddressSpace()) == Action::kPredicate) {
+ auto name = b.Symbols().New(param->name->symbol.Name() + "_predicate");
+ ctx.InsertAfter(fn->params, param, b.Param(name, b.ty.bool_()));  // placed directly after `param`
+
+ // Associate the pointer parameter expressions with the predicate.
+ for (auto* user : sem_param->Users()) {
+ predicates.Add(user->Declaration(), name);
+ }
+ }
+ }
+ }
+ }
+ }
+
+ /// Transforms call expressions to user functions, inserting additional predicate arguments
+ /// after all pointer parameters with a type in an address space that has predicate action.
+ void InsertPredicateArguments(const sem::Call* call, const sem::Function* fn) {
+ auto* expr = call->Declaration();
+ for (size_t i = 0; i < fn->Parameters().Length(); i++) {
+ auto* param = fn->Parameters()[i];
+ if (auto* ptr = param->Type()->As<type::Pointer>()) {
+ if (ActionFor(ptr->AddressSpace()) == Action::kPredicate) {
+ auto* arg = expr->args[i];
+ if (auto predicate = predicates.Get(arg)) {
+ ctx.InsertAfter(expr->args, arg, b.Expr(*predicate));  // forward the argument's predicate
+ } else {
+ ctx.InsertAfter(expr->args, arg, b.Expr(true));  // no recorded predicate: pass `true`
+ }
+ }
+ }
+ }
+ }
+
+ /// Applies predication to the index on an array, vector or matrix.
+ /// @param expr the index accessor expression.
+ void PredicateIndexAccessor(const sem::IndexAccessorExpression* expr) {
+ auto* obj = expr->Object()->Declaration();
+ auto* idx = expr->Index()->Declaration();
+ auto* max = DynamicLimitFor(expr);
+ if (!max) {
+ // Robustness is not required.
+ // Just propagate the predicate (if any) from the object.
+ if (auto pred = predicates.Get(obj)) {
+ predicates.Add(expr->Declaration(), *pred);
+ }
+ return;
+ }
+
+ auto* stmt = expr->Stmt();
+ auto obj_pred = *predicates.GetOrZero(obj);  // may be an invalid Symbol; checked with IsValid() below
+
+ auto idx_let = b.Symbols().New("index");
+ auto pred = b.Symbols().New("predicate");
+
+ hoist.InsertBefore(stmt, b.Decl(b.Let(idx_let, ctx.Clone(idx))));  // let index = idx;
+ ctx.Replace(idx, b.Expr(idx_let));
+
+ auto* cond = b.LessThanEqual(b.Call<u32>(b.Expr(idx_let)), max);  // u32(index) <= max
+ if (obj_pred.IsValid()) {
+ cond = b.And(b.Expr(obj_pred), cond);  // also require the object's own predicate
+ }
+ hoist.InsertBefore(stmt, b.Decl(b.Let(pred, cond)));  // let predicate = cond;
+
+ predicates.Add(expr->Declaration(), pred);
+ }
+
+ /// Applies bounds clamping to the index on an array, vector or matrix.
+ /// @param expr the index accessor expression.
+ void ClampIndexAccessor(const sem::IndexAccessorExpression* expr) {
+ auto* max = DynamicLimitFor(expr);
+ if (!max) {
+ return; // robustness is not required
+ }
+
+ auto* expr_sem = expr->Unwrap()->As<sem::IndexAccessorExpression>();
+ auto idx = CastToU32(expr_sem->Index());
+ auto* clamped_idx = b.Call(builtin::Function::kMin, idx, max);  // min(idx as u32, max)
+ ctx.Replace(expr->Declaration()->index, clamped_idx);
+ }
+
+ /// Applies predication to the non-texture builtin call @p call, if any of its arguments are predicated.
+ void MaybePredicateNonTextureBuiltin(const sem::Call* call, const sem::Builtin* builtin) {
+ // Gather the predications for the builtin arguments
+ const Expression* predicate = nullptr;
+ for (auto* arg : call->Declaration()->args) {
+ if (auto pred = predicates.Get(arg)) {
+ predicate = And(predicate, b.Expr(*pred));
+ }
+ }
+
+ if (predicate) {
+ if (builtin->Type() == builtin::Function::kWorkgroupUniformLoad) {
+ // https://www.w3.org/TR/WGSL/#workgroupUniformLoad-builtin:
+ // "Executes a control barrier synchronization function that affects memory and
+ // atomic operations in the workgroup address space."
+ // Because the call acts like a control barrier, we need to make sure that we still
+ // trigger a workgroup barrier if the predicate fails.
+ PredicateCall(call, predicate,
+ b.Block(b.CallStmt(b.Call(builtin::Function::kWorkgroupBarrier))));
+ } else {
+ PredicateCall(call, predicate);
+ }
+ }
+ }
+
+ /// Applies predication to texture builtins, based on whether the coordinates, array index and
+ /// level arguments are all in bounds.
+ void PredicateTextureBuiltin(const sem::Call* call, const sem::Builtin* builtin) {
+ if (!TextureBuiltinNeedsRobustness(builtin->Type())) {
+ return;
+ }
+
+ auto* expr = call->Declaration();
+ auto* stmt = call->Stmt();
+
+ // Indices of the mandatory texture and coords parameters, and the optional
+ // array and level parameters.
+ auto& signature = builtin->Signature();
+ auto texture_arg_idx = signature.IndexOf(sem::ParameterUsage::kTexture);
+ auto coords_arg_idx = signature.IndexOf(sem::ParameterUsage::kCoords);
+ auto array_arg_idx = signature.IndexOf(sem::ParameterUsage::kArrayIndex);
+ auto level_arg_idx = signature.IndexOf(sem::ParameterUsage::kLevel);
+
+ auto* texture_arg = expr->args[static_cast<size_t>(texture_arg_idx)];
+
+ // Build the builtin predicate from the arguments
+ const Expression* predicate = nullptr;
+
+ Symbol level_idx, num_levels;
+ if (level_arg_idx >= 0) {
+ auto* param = builtin->Parameters()[static_cast<size_t>(level_arg_idx)];
+ if (param->Type()->is_integer_scalar()) {
+ // let level_idx = u32(level-arg);
+ level_idx = b.Symbols().New("level_idx");
+ auto* arg = expr->args[static_cast<size_t>(level_arg_idx)];
+ hoist.InsertBefore(stmt,
+ b.Decl(b.Let(level_idx, CastToUnsigned(ctx.Clone(arg), 1u))));
+
+ // let num_levels = textureNumLevels(texture-arg);
+ num_levels = b.Symbols().New("num_levels");
+ hoist.InsertBefore(
+ stmt, b.Decl(b.Let(num_levels, b.Call(builtin::Function::kTextureNumLevels,
+ ctx.Clone(texture_arg)))));
+
+ // predicate: level_idx < num_levels
+ predicate = And(predicate, b.LessThan(level_idx, num_levels));
+
+ // Replace the level argument with `level_idx`
+ ctx.Replace(arg, b.Expr(level_idx));
+ }
+ }
+
+ Symbol coords;
+ if (coords_arg_idx >= 0) {
+ auto* param = builtin->Parameters()[static_cast<size_t>(coords_arg_idx)];
+ if (param->Type()->is_integer_scalar_or_vector()) {
+ // let coords = u32(coords-arg), or vecN<u32>(coords-arg) for vector coordinates
+ coords = b.Symbols().New("coords");
+ auto* arg = expr->args[static_cast<size_t>(coords_arg_idx)];
+ hoist.InsertBefore(stmt,
+ b.Decl(b.Let(coords, CastToUnsigned(b.Expr(ctx.Clone(arg)),
+ WidthOf(param->Type())))));
+
+ // predicate: all(coords < textureDimensions(texture))
+ auto* dimensions =
+ level_idx.IsValid()
+ ? b.Call(builtin::Function::kTextureDimensions, ctx.Clone(texture_arg),
+ b.Call(builtin::Function::kMin, b.Expr(level_idx),
+ b.Sub(num_levels, 1_a)))
+ : b.Call(builtin::Function::kTextureDimensions, ctx.Clone(texture_arg));
+ predicate =
+ And(predicate, b.Call(builtin::Function::kAll, b.LessThan(coords, dimensions)));
+
+ // Replace the coords argument with `coords`
+ ctx.Replace(arg, b.Expr(coords));
+ }
+ }
+
+ if (array_arg_idx >= 0) {
+ // let array_idx = u32(array-arg)
+ auto* arg = expr->args[static_cast<size_t>(array_arg_idx)];
+ auto* num_layers = b.Call(builtin::Function::kTextureNumLayers, ctx.Clone(texture_arg));
+ auto array_idx = b.Symbols().New("array_idx");
+ hoist.InsertBefore(stmt, b.Decl(b.Let(array_idx, CastToUnsigned(ctx.Clone(arg), 1u))));
+
+ // predicate: array_idx < textureNumLayers(texture)
+ predicate = And(predicate, b.LessThan(array_idx, num_layers));
+
+ // Replace the array index argument with `array_idx`
+ ctx.Replace(arg, b.Expr(array_idx));
+ }
+
+ if (predicate) {
+ PredicateCall(call, predicate);
+ }
+ }
+
+ /// Applies bounds clamping to the coordinates, array index and level arguments of the texture
+ /// builtin.
+ void ClampTextureBuiltin(const sem::Call* call, const sem::Builtin* builtin) {
+ if (!TextureBuiltinNeedsRobustness(builtin->Type())) {
+ return;
+ }
+
+ auto* expr = call->Declaration();
+ auto* stmt = call->Stmt();
+
+ // Indices of the mandatory texture and coords parameters, and the optional
+ // array and level parameters.
+ auto& signature = builtin->Signature();
+ auto texture_arg_idx = signature.IndexOf(sem::ParameterUsage::kTexture);
+ auto coords_arg_idx = signature.IndexOf(sem::ParameterUsage::kCoords);
+ auto array_arg_idx = signature.IndexOf(sem::ParameterUsage::kArrayIndex);
+ auto level_arg_idx = signature.IndexOf(sem::ParameterUsage::kLevel);
+
+ auto* texture_arg = expr->args[static_cast<size_t>(texture_arg_idx)];
+
+ // If the level is provided, then we need to clamp this. As the level is used by
+ // textureDimensions() and the texture[Load|Store]() calls, we need to clamp both usages.
+ Symbol level_idx;
+ if (level_arg_idx >= 0) {
+ const auto* param = builtin->Parameters()[static_cast<size_t>(level_arg_idx)];
+ if (param->Type()->is_integer_scalar()) {
+ const auto* arg = expr->args[static_cast<size_t>(level_arg_idx)];
+ level_idx = b.Symbols().New("level_idx");
+ const auto* num_levels =
+ b.Call(builtin::Function::kTextureNumLevels, ctx.Clone(texture_arg));
+ const auto* max = b.Sub(num_levels, 1_a);  // textureNumLevels(texture) - 1
+ hoist.InsertBefore(
+ stmt, b.Decl(b.Let(level_idx, b.Call(builtin::Function::kMin,
+ b.Call<u32>(ctx.Clone(arg)), max))));
+ ctx.Replace(arg, b.Expr(level_idx));
+ }
+ }
+
+ // Clamp the coordinates argument
+ if (coords_arg_idx >= 0) {
+ const auto* param = builtin->Parameters()[static_cast<size_t>(coords_arg_idx)];
+ if (param->Type()->is_integer_scalar_or_vector()) {
+ auto* arg = expr->args[static_cast<size_t>(coords_arg_idx)];
+ const auto width = WidthOf(param->Type());
+ const auto* dimensions =
+ level_idx.IsValid()
+ ? b.Call(builtin::Function::kTextureDimensions, ctx.Clone(texture_arg),
+ level_idx)
+ : b.Call(builtin::Function::kTextureDimensions, ctx.Clone(texture_arg));
+
+ // dimensions is u32 or vecN<u32>
+ const auto* unsigned_max = b.Sub(dimensions, ScalarOrVec(b.Expr(1_a), width));
+ if (param->Type()->is_signed_integer_scalar_or_vector()) {
+ const auto* zero = ScalarOrVec(b.Expr(0_a), width);
+ const auto* signed_max = CastToSigned(unsigned_max, width);
+ ctx.Replace(
+ arg, b.Call(builtin::Function::kClamp, ctx.Clone(arg), zero, signed_max));
+ } else {
+ ctx.Replace(arg, b.Call(builtin::Function::kMin, ctx.Clone(arg), unsigned_max));
+ }
+ }
+ }
+
+ // Clamp the array_index argument, if provided
+ if (array_arg_idx >= 0) {
+ auto* param = builtin->Parameters()[static_cast<size_t>(array_arg_idx)];
+ auto* arg = expr->args[static_cast<size_t>(array_arg_idx)];
+ auto* num_layers = b.Call(builtin::Function::kTextureNumLayers, ctx.Clone(texture_arg));
+
+ const auto* unsigned_max = b.Sub(num_layers, 1_a);  // textureNumLayers(texture) - 1
+ if (param->Type()->is_signed_integer_scalar()) {
+ const auto* signed_max = CastToSigned(unsigned_max, 1u);
+ ctx.Replace(arg,
+ b.Call(builtin::Function::kClamp, ctx.Clone(arg), 0_a, signed_max));
+ } else {
+ ctx.Replace(arg, b.Call(builtin::Function::kMin, ctx.Clone(arg), unsigned_max));
+ }
+ }
+ }
+
+ /// @param type builtin type
+ /// @returns true if the given builtin is a texture function that requires predication or
+ /// clamping of arguments (textureLoad, textureStore or textureDimensions).
+ bool TextureBuiltinNeedsRobustness(builtin::Function type) {
+ return type == builtin::Function::kTextureLoad ||
+ type == builtin::Function::kTextureStore ||
+ type == builtin::Function::kTextureDimensions;
+ }
+
+ /// @returns a bitwise and of the two expressions, the other expression if one is null, or nullptr if both are null.
+ const Expression* And(const Expression* lhs, const Expression* rhs) {
+ if (lhs && rhs) {
+ return b.And(lhs, rhs);
+ }
+ if (lhs) {
+ return lhs;
+ }
+ return rhs;  // lhs is null here; rhs may also be null
+ }
+
+ /// Transforms a call statement or expression so that the expression is predicated by @p
+ /// predicate.
+ /// @param else_stmt the optional block to execute if the predicate is false (nullptr by default)
+ void PredicateCall(const sem::Call* call,
+ const Expression* predicate,
+ const BlockStatement* else_stmt = nullptr) {
+ auto* expr = call->Declaration();
+ auto* stmt = call->Stmt();
+ auto* call_stmt = stmt->Declaration()->As<CallStatement>();
+ if (call_stmt && call_stmt->expr == expr) {
+ // The call is the whole statement.
+ // Wrap the statement in an if-statement with the predicate condition.
+ hoist.Replace(stmt, b.If(predicate, b.Block(ctx.Clone(stmt->Declaration())),
+ ProgramBuilder::ElseStmt(else_stmt)));
+ } else {
+ // Emit the following before the expression's statement:
+ // var predicated_value : return-type;
+ // if (predicate) {
+ // predicated_value = call(...);
+ // }
+ auto value = b.Symbols().New("predicated_value");
+ hoist.InsertBefore(stmt, b.Decl(b.Var(value, CreateASTTypeFor(ctx, call->Type()))));
+ hoist.InsertBefore(stmt, b.If(predicate, b.Block(b.Assign(value, ctx.Clone(expr))),
+ ProgramBuilder::ElseStmt(else_stmt)));
+
+ // Replace the call expression with `predicated_value`
+ ctx.Replace(expr, b.Expr(value));
+ }
+ }
+
+ /// @returns true if @p action is configured for textures or for any address space (cfg.value_action is not checked)
+ bool HasAction(Action action) const {
+ return action == cfg.function_action || //
+ action == cfg.texture_action || //
+ action == cfg.private_action || //
+ action == cfg.push_constant_action || //
+ action == cfg.storage_action || //
+ action == cfg.uniform_action || //
+ action == cfg.workgroup_action;
+ }
+
+ /// @returns the robustness action to perform for an OOB access with the expression @p expr
+ Action ActionFor(const sem::ValueExpression* expr) {
+ return Switch(
+ expr->Type(), //
+ [&](const type::Reference* t) { return ActionFor(t->AddressSpace()); },  // reference: per address space
+ [&](Default) { return cfg.value_action; });  // any other type: treated as a value
+ }
+
+ /// @returns the robustness action to perform for an OOB access in the address space @p
+ /// address_space
+ Action ActionFor(builtin::AddressSpace address_space) {
+ switch (address_space) {
+ case builtin::AddressSpace::kFunction:
+ return cfg.function_action;
+ case builtin::AddressSpace::kHandle:  // the handle address space uses the texture action
+ return cfg.texture_action;
+ case builtin::AddressSpace::kPrivate:
+ return cfg.private_action;
+ case builtin::AddressSpace::kPushConstant:
+ return cfg.push_constant_action;
+ case builtin::AddressSpace::kStorage:
+ return cfg.storage_action;
+ case builtin::AddressSpace::kUniform:
+ return cfg.uniform_action;
+ case builtin::AddressSpace::kWorkgroup:
+ return cfg.workgroup_action;
+ default:
+ break;
+ }
+ TINT_UNREACHABLE(Transform, b.Diagnostics()) << "unhandled address space" << address_space;  // NOTE(review): no separator between the message and the address space
+ return Action::kDefault;
+ }
+
+ /// @returns the vector width of @p ty, or 1 if @p ty is not a vector
+ static uint32_t WidthOf(const type::Type* ty) {
+ if (auto* vec = ty->As<type::Vector>()) {
+ return vec->Width();
+ }
+ return 1u;  // not a vector
+ }
+
+ /// @returns a vector type with element type @p scalar and width @p width if @p width > 1, otherwise @p scalar
+ Type ScalarOrVecTy(Type scalar, uint32_t width) const {
+ if (width > 1) {
+ return b.ty.vec(scalar, width);
+ }
+ return scalar;
+ }
+
+ /// @returns a vector constructed with the scalar expression @p scalar if @p width > 1,
+ /// otherwise returns @p scalar.
+ const Expression* ScalarOrVec(const Expression* scalar, uint32_t width) {
+ if (width > 1) {
+ return b.Call(b.ty.vec<Infer>(width), scalar);  // vecN(scalar), with the element type inferred
+ }
+ return scalar;
+ }
+
+ /// @returns a call expression converting @p val to a `vecN<i32>`, where `N` is @p width, or to
+ /// i32 if @p width is 1.
+ const CallExpression* CastToSigned(const Expression* val, uint32_t width) {
+ return b.Call(ScalarOrVecTy(b.ty.i32(), width), val);
+ }
+
+ /// @returns a call expression converting @p val to a `vecN<u32>`, where `N` is @p width, or to
+ /// u32 if @p width is 1.
+ const CallExpression* CastToUnsigned(const Expression* val, uint32_t width) {
+ return b.Call(ScalarOrVecTy(b.ty.u32(), width), val);
+ }
+
+ /// @returns true if the variable represents a resource binding that should be ignored in the
+ /// robustness check.
+ /// TODO(tint:1890): make this function work with unrestricted pointer parameters. Note that this
+ /// depends on transform::DirectVariableAccess to have been run first.
+ bool IsIgnoredResourceBinding(const sem::Variable* variable) const {
+ auto* globalVariable = utils::As<sem::GlobalVariable>(variable);
+ if (globalVariable == nullptr) {
+ return false;  // not a module-scope variable
+ }
+ if (!globalVariable->BindingPoint().has_value()) {
+ return false;  // no binding point, so it cannot be listed in cfg.bindings_ignored
+ }
+ sem::BindingPoint bindingPoint = *globalVariable->BindingPoint();
+ return cfg.bindings_ignored.find(bindingPoint) != cfg.bindings_ignored.cend();
+ }
+
+ /// @returns true if the object indexed by @p expr is, after unwrapping any reference, a runtime-sized array.
+ bool IsIndexAccessingRuntimeSizedArray(const sem::IndexAccessorExpression* expr) {
+ auto* array_type = expr->Object()->Type()->UnwrapRef()->As<type::Array>();
+ return array_type != nullptr && array_type->Count()->Is<type::RuntimeArrayCount>();
+ }
+
+ /// @returns a clone of expr->Declaration() if it is an unsigned integer scalar, or
+ /// a clone of expr->Declaration() wrapped in a `u32(...)` conversion otherwise.
+ const ast::Expression* CastToU32(const sem::ValueExpression* expr) {
+ auto* idx = ctx.Clone(expr->Declaration());
+ if (expr->Type()->is_unsigned_integer_scalar()) {
+ return idx;
+ }
+ return b.Call<u32>(idx); // u32(idx)
+ }
+};
+
+Robustness::Config::Config() = default;  // special members are declared in robustness.h and defaulted here
+Robustness::Config::Config(const Config&) = default;
+Robustness::Config::~Config() = default;
+Robustness::Config& Robustness::Config::operator=(const Config&) = default;
+
+Robustness::Robustness() = default;
+Robustness::~Robustness() = default;
+
+Transform::ApplyResult Robustness::Apply(const Program* src,
+ const DataMap& inputs,
+ DataMap&) const {
+ Config cfg;  // default-constructed config is used when `inputs` holds no Config
+ if (auto* cfg_data = inputs.Get<Config>()) {
+ cfg = *cfg_data;  // copy the caller-supplied config
+ }
+
+ return State{src, std::move(cfg)}.Run();
+}
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/robustness.h b/src/tint/lang/wgsl/ast/transform/robustness.h
new file mode 100644
index 0000000..8f12e04
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/robustness.h
@@ -0,0 +1,107 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_ROBUSTNESS_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_ROBUSTNESS_H_
+
+#include <unordered_set>
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+#include "src/tint/sem/binding_point.h"
+
+namespace tint::ast::transform {
+
+/// This transform is responsible for ensuring that all out of bounds accesses are prevented,
+/// either by conditioning the access (predication) or through clamping of the index to keep the
+/// access in bounds.
+/// @note Robustness must come after:
+/// * PromoteSideEffectsToDecl as Robustness requires side-effecting expressions to be hoisted
+/// to their own statements.
+/// Robustness must come before:
+/// * BuiltinPolyfill as 'clamp' and binary operators may need to be polyfilled.
+/// * CanonicalizeEntryPointIO as the transform does not support the 'in' and 'out' address
+/// spaces.
+class Robustness final : public utils::Castable<Robustness, Transform> {
+ public:
+ /// Robustness action for out-of-bounds indexing.
+ enum class Action {
+ /// Do nothing to prevent the out-of-bounds action.
+ kIgnore,
+ /// Clamp the index to be within bounds.
+ kClamp,
+ /// Do not execute the read or write if the index is out-of-bounds.
+ kPredicate,
+
+ /// The default action
+ kDefault = kClamp,
+ };
+
+ /// Configuration options for the transform
+ struct Config final : public utils::Castable<Config, Data> {
+ /// Constructor
+ Config();
+
+ /// Copy constructor
+ Config(const Config&);
+
+ /// Destructor
+ ~Config() override;
+
+ /// Assignment operator
+ /// @returns this Config
+ Config& operator=(const Config&);
+
+ /// Robustness action for indexing of values (expressions that are not references)
+ Action value_action = Action::kDefault;
+
+ /// Robustness action for non-sampling texture operations
+ Action texture_action = Action::kDefault;
+
+ /// Robustness action for variables in the 'function' address space
+ Action function_action = Action::kDefault;
+ /// Robustness action for variables in the 'private' address space
+ Action private_action = Action::kDefault;
+ /// Robustness action for variables in the 'push_constant' address space
+ Action push_constant_action = Action::kDefault;
+ /// Robustness action for variables in the 'storage' address space
+ Action storage_action = Action::kDefault;
+ /// Robustness action for variables in the 'uniform' address space
+ Action uniform_action = Action::kDefault;
+ /// Robustness action for variables in the 'workgroup' address space
+ Action workgroup_action = Action::kDefault;
+
+ /// Bindings to which Action::kIgnore is always applied
+ std::unordered_set<tint::sem::BindingPoint> bindings_ignored;
+
+ /// If true, index clamping on runtime-sized arrays is disabled in the robustness transform
+ bool disable_runtime_sized_array_index_clamping = false;
+ };
+
+ /// Constructor
+ Robustness();
+ /// Destructor
+ ~Robustness() override;
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+
+ private:
+ struct State;
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_ROBUSTNESS_H_
diff --git a/src/tint/lang/wgsl/ast/transform/robustness_test.cc b/src/tint/lang/wgsl/ast/transform/robustness_test.cc
new file mode 100644
index 0000000..7bef166
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/robustness_test.cc
@@ -0,0 +1,5631 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/robustness.h"
+
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+
+namespace tint::ast::transform {
+
+static std::ostream& operator<<(std::ostream& out, Robustness::Action action) {
+    // Writes the lower-case name of `action` to `out`.
+    // Falls back to "unknown" when no case below matches.
+    const char* name = "unknown";
+    switch (action) {
+        case Robustness::Action::kIgnore: name = "ignore"; break;
+        case Robustness::Action::kClamp: name = "clamp"; break;
+        case Robustness::Action::kPredicate: name = "predicate"; break;
+    }
+    return out << name;
+}
+
+namespace {
+
+Transform::DataMap Config(Robustness::Action action,
+                          bool disable_runtime_sized_array_index_clamping = false) {
+    // Builds transform inputs in which values, textures and every address space
+    // all use the same `action`.
+    Robustness::Config options;
+    options.disable_runtime_sized_array_index_clamping = disable_runtime_sized_array_index_clamping;
+    options.value_action = action;
+    options.texture_action = action;
+    options.function_action = options.private_action = action;
+    options.push_constant_action = options.storage_action = action;
+    options.uniform_action = options.workgroup_action = action;
+
+    Transform::DataMap inputs;
+    inputs.Add<Robustness::Config>(options);
+    return inputs;
+}
+
+const char* Expect(Robustness::Action action,
+                   const char* expect_ignore,
+                   const char* expect_clamp,
+                   const char* expect_predicate) {
+    // Returns the expectation string that corresponds to `action`.
+    if (action == Robustness::Action::kIgnore) {
+        return expect_ignore;
+    }
+    if (action == Robustness::Action::kClamp) {
+        return expect_clamp;
+    }
+    const bool is_predicate = (action == Robustness::Action::kPredicate);
+    return is_predicate ? expect_predicate : "<invalid action>";
+}
+
+using RobustnessTest = TransformTestWithParam<Robustness::Action>;  // param: action under test
+
+////////////////////////////////////////////////////////////////////////////////
+// Constant sized array
+////////////////////////////////////////////////////////////////////////////////
+
+TEST_P(RobustnessTest, Read_ConstantSizedArrayVal_IndexWithLiteral) {
+    const char* wgsl = R"(
+fn f() {
+ var b : f32 = array<f32, 3>()[1i];
+}
+)";
+    // Literal index 1i on an array value: the output must equal the input for every action.
+    const char* want = wgsl;
+
+    auto output = Run<Robustness>(wgsl, Config(GetParam()));
+
+    EXPECT_EQ(want, str(output));
+}
+
+TEST_P(RobustnessTest, Read_ConstantSizedArrayVal_IndexWithConst) {
+    const char* wgsl = R"(
+const c : u32 = 1u;
+
+fn f() {
+ let b : f32 = array<f32, 3>()[c];
+}
+)";
+    // Index is the module-scope const `c`: the output must equal the input for every action.
+    const char* want = wgsl;
+
+    auto output = Run<Robustness>(wgsl, Config(GetParam()));
+
+    EXPECT_EQ(want, str(output));
+}
+
+TEST_P(RobustnessTest, Read_ConstantSizedArrayVal_IndexWithLet) {
+    const char* wgsl = R"(
+fn f() {
+ let l : u32 = 1u;
+ let b : f32 = array<f32, 3>()[l];
+}
+)";
+    // Index is a `let`: clamp wraps it in min(), predicate guards the read with a bounds check.
+    const char* want = Expect(GetParam(),
+                              /* ignore */ wgsl,
+                              /* clamp */ R"(
+fn f() {
+ let l : u32 = 1u;
+ let b : f32 = array<f32, 3>()[min(l, 2u)];
+}
+)",
+                              /* predicate */ R"(
+fn f() {
+ let l : u32 = 1u;
+ let index = l;
+ let predicate = (u32(index) <= 2u);
+ var predicated_expr : f32;
+ if (predicate) {
+ predicated_expr = array<f32, 3>()[index];
+ }
+ let b : f32 = predicated_expr;
+}
+)");
+
+    auto output = Run<Robustness>(wgsl, Config(GetParam()));
+
+    EXPECT_EQ(want, str(output));
+}
+
+TEST_P(RobustnessTest, Read_ConstantSizedArrayVal_IndexWithRuntimeArrayIndex) {
+    const char* wgsl = R"(
+var<private> i : u32;
+
+fn f() {
+ let a = array<f32, 3>();
+ let b = array<i32, 5>();
+ var c : f32 = a[b[i]];
+}
+)";
+    // Both the inner b[i] access and the outer a[...] access are expected to be guarded.
+    const char* want = Expect(GetParam(),
+                              /* ignore */ wgsl,
+                              /* clamp */ R"(
+var<private> i : u32;
+
+fn f() {
+ let a = array<f32, 3>();
+ let b = array<i32, 5>();
+ var c : f32 = a[min(u32(b[min(i, 4u)]), 2u)];
+}
+)",
+                              /* predicate */ R"(
+var<private> i : u32;
+
+fn f() {
+ let a = array<f32, 3>();
+ let b = array<i32, 5>();
+ let index = i;
+ let predicate = (u32(index) <= 4u);
+ var predicated_expr : i32;
+ if (predicate) {
+ predicated_expr = b[index];
+ }
+ let index_1 = predicated_expr;
+ let predicate_1 = (u32(index_1) <= 2u);
+ var predicated_expr_1 : f32;
+ if (predicate_1) {
+ predicated_expr_1 = a[index_1];
+ }
+ var c : f32 = predicated_expr_1;
+}
+)");
+
+    auto output = Run<Robustness>(wgsl, Config(GetParam()));
+
+    EXPECT_EQ(want, str(output));
+}
+
+TEST_P(RobustnessTest, Read_ConstantSizedArrayVal_IndexWithRuntimeExpression) {
+    const char* wgsl = R"(
+var<private> c : i32;
+
+fn f() {
+ var b : f32 = array<f32, 3>()[((c + 2) - 3)];
+}
+)";
+    // Index is an i32 arithmetic expression: clamp casts it to u32 before applying min().
+    const char* want = Expect(GetParam(),
+                              /* ignore */ wgsl,
+                              /* clamp */ R"(
+var<private> c : i32;
+
+fn f() {
+ var b : f32 = array<f32, 3>()[min(u32(((c + 2) - 3)), 2u)];
+}
+)",
+                              /* predicate */ R"(
+var<private> c : i32;
+
+fn f() {
+ let index = ((c + 2) - 3);
+ let predicate = (u32(index) <= 2u);
+ var predicated_expr : f32;
+ if (predicate) {
+ predicated_expr = array<f32, 3>()[index];
+ }
+ var b : f32 = predicated_expr;
+}
+)");
+
+    auto output = Run<Robustness>(wgsl, Config(GetParam()));
+
+    EXPECT_EQ(want, str(output));
+}
+
+TEST_P(RobustnessTest, Read_NestedConstantSizedArraysVal_IndexWithRuntimeExpressions) {
+    const char* wgsl = R"(
+var<private> x : i32;
+
+var<private> y : i32;
+
+var<private> z : i32;
+
+fn f() {
+ let a = array<array<array<f32, 1>, 2>, 3>();
+ var r = a[x][y][z];
+}
+)";
+    // Array value: each of the three index levels gets its own clamp or its own predicated read.
+    const char* want = Expect(GetParam(),
+                              /* ignore */ wgsl,
+                              /* clamp */ R"(
+var<private> x : i32;
+
+var<private> y : i32;
+
+var<private> z : i32;
+
+fn f() {
+ let a = array<array<array<f32, 1>, 2>, 3>();
+ var r = a[min(u32(x), 2u)][min(u32(y), 1u)][min(u32(z), 0u)];
+}
+)",
+                              /* predicate */ R"(
+var<private> x : i32;
+
+var<private> y : i32;
+
+var<private> z : i32;
+
+fn f() {
+ let a = array<array<array<f32, 1>, 2>, 3>();
+ let index = x;
+ let predicate = (u32(index) <= 2u);
+ var predicated_expr : array<array<f32, 1u>, 2u>;
+ if (predicate) {
+ predicated_expr = a[index];
+ }
+ let index_1 = y;
+ let predicate_1 = (u32(index_1) <= 1u);
+ var predicated_expr_1 : array<f32, 1u>;
+ if (predicate_1) {
+ predicated_expr_1 = predicated_expr[index_1];
+ }
+ let index_2 = z;
+ let predicate_2 = (u32(index_2) <= 0u);
+ var predicated_expr_2 : f32;
+ if (predicate_2) {
+ predicated_expr_2 = predicated_expr_1[index_2];
+ }
+ var r = predicated_expr_2;
+}
+)");
+
+    auto output = Run<Robustness>(wgsl, Config(GetParam()));
+
+    EXPECT_EQ(want, str(output));
+}
+
+TEST_P(RobustnessTest, Read_ConstantSizedArrayVal_IndexWithOverride) {
+    const char* wgsl = R"(
+@id(1300) override idx : i32;
+
+fn f() {
+ let a = array<f32, 4>();
+ var b : f32 = a[idx];
+}
+)";
+    // Index is an override: clamp and predicate both bound it by the array's last index, 3u.
+    const char* want = Expect(GetParam(),
+                              /* ignore */ wgsl,
+                              /* clamp */ R"(
+@id(1300) override idx : i32;
+
+fn f() {
+ let a = array<f32, 4>();
+ var b : f32 = a[min(u32(idx), 3u)];
+}
+)",
+                              /* predicate */ R"(
+@id(1300) override idx : i32;
+
+fn f() {
+ let a = array<f32, 4>();
+ let index = idx;
+ let predicate = (u32(index) <= 3u);
+ var predicated_expr : f32;
+ if (predicate) {
+ predicated_expr = a[index];
+ }
+ var b : f32 = predicated_expr;
+}
+)");
+
+    auto output = Run<Robustness>(wgsl, Config(GetParam()));
+    EXPECT_EQ(want, str(output));
+}
+
+TEST_P(RobustnessTest, Read_ConstantSizedArrayRef_IndexWithLiteral) {
+    const char* wgsl = R"(
+var<private> a : array<f32, 3>;
+
+fn f() {
+ var b : f32 = a[1i];
+}
+)";
+    // Literal index 1i on a private array: the output must equal the input for every action.
+    const char* want = wgsl;
+
+    auto output = Run<Robustness>(wgsl, Config(GetParam()));
+
+    EXPECT_EQ(want, str(output));
+}
+
+TEST_P(RobustnessTest, Read_ConstantSizedArrayRef_IndexWithConst) {
+    const char* wgsl = R"(
+var<private> a : array<f32, 3>;
+
+const c : u32 = 1u;
+
+fn f() {
+ let b : f32 = a[c];
+}
+)";
+    // Index is the module-scope const `c`: the output must equal the input for every action.
+    const char* want = wgsl;
+
+    auto output = Run<Robustness>(wgsl, Config(GetParam()));
+
+    EXPECT_EQ(want, str(output));
+}
+
+TEST_P(RobustnessTest, Read_ConstantSizedArrayRef_IndexWithLet) {
+    const char* wgsl = R"(
+var<private> a : array<f32, 3>;
+
+fn f() {
+ let l : u32 = 1u;
+ let b : f32 = a[l];
+}
+)";
+    // Private array indexed by a `let`: clamp uses min(l, 2u), predicate guards the read.
+    const char* want = Expect(GetParam(),
+                              /* ignore */ wgsl,
+                              /* clamp */ R"(
+var<private> a : array<f32, 3>;
+
+fn f() {
+ let l : u32 = 1u;
+ let b : f32 = a[min(l, 2u)];
+}
+)",
+                              /* predicate */ R"(
+var<private> a : array<f32, 3>;
+
+fn f() {
+ let l : u32 = 1u;
+ let index = l;
+ let predicate = (u32(index) <= 2u);
+ var predicated_expr : f32;
+ if (predicate) {
+ predicated_expr = a[index];
+ }
+ let b : f32 = predicated_expr;
+}
+)");
+
+    auto output = Run<Robustness>(wgsl, Config(GetParam()));
+
+    EXPECT_EQ(want, str(output));
+}
+
+TEST_P(RobustnessTest, Read_ConstantSizedArrayRef_IndexWithRuntimeArrayIndex) {
+    const char* wgsl = R"(
+var<private> a : array<f32, 3>;
+
+var<private> b : array<i32, 5>;
+
+var<private> i : u32;
+
+fn f() {
+ var c : f32 = a[b[i]];
+}
+)";
+    // Private arrays: both the inner b[i] access and the outer a[...] access are guarded.
+    const char* want = Expect(GetParam(),
+                              /* ignore */ wgsl,
+                              /* clamp */ R"(
+var<private> a : array<f32, 3>;
+
+var<private> b : array<i32, 5>;
+
+var<private> i : u32;
+
+fn f() {
+ var c : f32 = a[min(u32(b[min(i, 4u)]), 2u)];
+}
+)",
+                              /* predicate */ R"(
+var<private> a : array<f32, 3>;
+
+var<private> b : array<i32, 5>;
+
+var<private> i : u32;
+
+fn f() {
+ let index = i;
+ let predicate = (u32(index) <= 4u);
+ var predicated_expr : i32;
+ if (predicate) {
+ predicated_expr = b[index];
+ }
+ let index_1 = predicated_expr;
+ let predicate_1 = (u32(index_1) <= 2u);
+ var predicated_expr_1 : f32;
+ if (predicate_1) {
+ predicated_expr_1 = a[index_1];
+ }
+ var c : f32 = predicated_expr_1;
+}
+)");
+
+    auto output = Run<Robustness>(wgsl, Config(GetParam()));
+
+    EXPECT_EQ(want, str(output));
+}
+
+TEST_P(RobustnessTest, Read_ConstantSizedArrayRef_IndexWithRuntimeExpression) {
+    const char* wgsl = R"(
+var<private> a : array<f32, 3>;
+
+var<private> c : i32;
+
+fn f() {
+ var b : f32 = a[((c + 2) - 3)];
+}
+)";
+    // Private array indexed by an i32 expression: clamp casts to u32 before applying min().
+    const char* want = Expect(GetParam(),
+                              /* ignore */ wgsl,
+                              /* clamp */ R"(
+var<private> a : array<f32, 3>;
+
+var<private> c : i32;
+
+fn f() {
+ var b : f32 = a[min(u32(((c + 2) - 3)), 2u)];
+}
+)",
+                              /* predicate */ R"(
+var<private> a : array<f32, 3>;
+
+var<private> c : i32;
+
+fn f() {
+ let index = ((c + 2) - 3);
+ let predicate = (u32(index) <= 2u);
+ var predicated_expr : f32;
+ if (predicate) {
+ predicated_expr = a[index];
+ }
+ var b : f32 = predicated_expr;
+}
+)");
+
+    auto output = Run<Robustness>(wgsl, Config(GetParam()));
+
+    EXPECT_EQ(want, str(output));
+}
+
+TEST_P(RobustnessTest, Read_NestedConstantSizedArraysRef_IndexWithRuntimeExpressions) {
+    const char* wgsl = R"(
+var<private> a : array<array<array<f32, 1>, 2>, 3>;
+
+var<private> x : i32;
+
+var<private> y : i32;
+
+var<private> z : i32;
+
+fn f() {
+ var r = a[x][y][z];
+}
+)";
+    // predicate: the three bounds checks are and-ed into one predicate guarding a single read.
+    const char* want = Expect(GetParam(),
+                              /* ignore */ wgsl,
+                              /* clamp */ R"(
+var<private> a : array<array<array<f32, 1>, 2>, 3>;
+
+var<private> x : i32;
+
+var<private> y : i32;
+
+var<private> z : i32;
+
+fn f() {
+ var r = a[min(u32(x), 2u)][min(u32(y), 1u)][min(u32(z), 0u)];
+}
+)",
+                              /* predicate */ R"(
+var<private> a : array<array<array<f32, 1>, 2>, 3>;
+
+var<private> x : i32;
+
+var<private> y : i32;
+
+var<private> z : i32;
+
+fn f() {
+ let index = x;
+ let predicate = (u32(index) <= 2u);
+ let index_1 = y;
+ let predicate_1 = (predicate & (u32(index_1) <= 1u));
+ let index_2 = z;
+ let predicate_2 = (predicate_1 & (u32(index_2) <= 0u));
+ var predicated_expr : f32;
+ if (predicate_2) {
+ predicated_expr = a[index][index_1][index_2];
+ }
+ var r = predicated_expr;
+}
+)");
+
+    auto output = Run<Robustness>(wgsl, Config(GetParam()));
+
+    EXPECT_EQ(want, str(output));
+}
+
+TEST_P(RobustnessTest, Read_ConstantSizedArrayRef_IndexWithOverride) {
+    const char* wgsl = R"(
+@id(1300) override idx : i32;
+
+fn f() {
+ var a : array<f32, 4>;
+ var b : f32 = a[idx];
+}
+)";
+    // Function-scope array indexed by an override: both modes bound the index by 3u.
+    const char* want = Expect(GetParam(),
+                              /* ignore */ wgsl,
+                              /* clamp */ R"(
+@id(1300) override idx : i32;
+
+fn f() {
+ var a : array<f32, 4>;
+ var b : f32 = a[min(u32(idx), 3u)];
+}
+)",
+                              /* predicate */ R"(
+@id(1300) override idx : i32;
+
+fn f() {
+ var a : array<f32, 4>;
+ let index = idx;
+ let predicate = (u32(index) <= 3u);
+ var predicated_expr : f32;
+ if (predicate) {
+ predicated_expr = a[index];
+ }
+ var b : f32 = predicated_expr;
+}
+)");
+
+    auto output = Run<Robustness>(wgsl, Config(GetParam()));
+    EXPECT_EQ(want, str(output));
+}
+
+TEST_P(RobustnessTest, Read_ConstantSizedArrayPtr_IndexWithLet) {
+    const char* wgsl = R"(
+var<private> a : array<f32, 3>;
+
+fn f() {
+ let l : u32 = 1u;
+ let p = &(a[l]);
+ let f : f32 = *(p);
+}
+)";
+    // predicate: the bounds check precedes the pointer; the guard wraps the load through it.
+    const char* want = Expect(GetParam(),
+                              /* ignore */ wgsl,
+                              /* clamp */ R"(
+var<private> a : array<f32, 3>;
+
+fn f() {
+ let l : u32 = 1u;
+ let p = &(a[min(l, 2u)]);
+ let f : f32 = *(p);
+}
+)",
+                              /* predicate */ R"(
+var<private> a : array<f32, 3>;
+
+fn f() {
+ let l : u32 = 1u;
+ let index = l;
+ let predicate = (u32(index) <= 2u);
+ let p = &(a[index]);
+ var predicated_expr : f32;
+ if (predicate) {
+ predicated_expr = *(p);
+ }
+ let f : f32 = predicated_expr;
+}
+)");
+
+    auto output = Run<Robustness>(wgsl, Config(GetParam()));
+
+    EXPECT_EQ(want, str(output));
+}
+
+TEST_P(RobustnessTest, Read_ConstantSizedArrayPtr_IndexWithRuntimeArrayIndex) {
+    const char* wgsl = R"(
+var<private> a : array<f32, 3>;
+
+var<private> b : array<i32, 5>;
+
+var<private> i : u32;
+
+fn f() {
+ let pa = &(a);
+ let pb = &(b);
+ let p0 = &((*(pb))[i]);
+ let p1 = &(a[*(p0)]);
+ var x : f32 = *(p1);
+}
+)";
+    // Pointer chain: the load of *(p0) used as an index and the load of *(p1) are both guarded.
+    const char* want = Expect(GetParam(),
+                              /* ignore */ wgsl,
+                              /* clamp */ R"(
+var<private> a : array<f32, 3>;
+
+var<private> b : array<i32, 5>;
+
+var<private> i : u32;
+
+fn f() {
+ let pa = &(a);
+ let pb = &(b);
+ let p0 = &((*(pb))[min(i, 4u)]);
+ let p1 = &(a[min(u32(*(p0)), 2u)]);
+ var x : f32 = *(p1);
+}
+)",
+                              /* predicate */ R"(
+var<private> a : array<f32, 3>;
+
+var<private> b : array<i32, 5>;
+
+var<private> i : u32;
+
+fn f() {
+ let pa = &(a);
+ let pb = &(b);
+ let index = i;
+ let predicate = (u32(index) <= 4u);
+ let p0 = &((*(pb))[index]);
+ var predicated_expr : i32;
+ if (predicate) {
+ predicated_expr = *(p0);
+ }
+ let index_1 = predicated_expr;
+ let predicate_1 = (u32(index_1) <= 2u);
+ let p1 = &(a[index_1]);
+ var predicated_expr_1 : f32;
+ if (predicate_1) {
+ predicated_expr_1 = *(p1);
+ }
+ var x : f32 = predicated_expr_1;
+}
+)");
+
+    auto output = Run<Robustness>(wgsl, Config(GetParam()));
+
+    EXPECT_EQ(want, str(output));
+}
+
+TEST_P(RobustnessTest, Read_NestedConstantSizedArraysPtr_IndexWithRuntimeExpressions) {
+    const char* wgsl = R"(
+var<private> a : array<array<array<f32, 1>, 2>, 3>;
+
+var<private> x : i32;
+
+var<private> y : i32;
+
+var<private> z : i32;
+
+fn f() {
+ let p0 = &(a);
+ let p1 = &((*(p0))[x]);
+ let p2 = &((*(p1))[y]);
+ let p3 = &((*(p2))[z]);
+ var r = *(p3);
+}
+)";
+    // predicate: checks accumulate along the pointer chain and guard only the final load.
+    const char* want = Expect(GetParam(),
+                              /* ignore */ wgsl,
+                              /* clamp */ R"(
+var<private> a : array<array<array<f32, 1>, 2>, 3>;
+
+var<private> x : i32;
+
+var<private> y : i32;
+
+var<private> z : i32;
+
+fn f() {
+ let p0 = &(a);
+ let p1 = &((*(p0))[min(u32(x), 2u)]);
+ let p2 = &((*(p1))[min(u32(y), 1u)]);
+ let p3 = &((*(p2))[min(u32(z), 0u)]);
+ var r = *(p3);
+}
+)",
+                              /* predicate */ R"(
+var<private> a : array<array<array<f32, 1>, 2>, 3>;
+
+var<private> x : i32;
+
+var<private> y : i32;
+
+var<private> z : i32;
+
+fn f() {
+ let p0 = &(a);
+ let index = x;
+ let predicate = (u32(index) <= 2u);
+ let p1 = &((*(p0))[index]);
+ let index_1 = y;
+ let predicate_1 = (predicate & (u32(index_1) <= 1u));
+ let p2 = &((*(p1))[index_1]);
+ let index_2 = z;
+ let predicate_2 = (predicate_1 & (u32(index_2) <= 0u));
+ let p3 = &((*(p2))[index_2]);
+ var predicated_expr : f32;
+ if (predicate_2) {
+ predicated_expr = *(p3);
+ }
+ var r = predicated_expr;
+}
+)");
+
+    auto output = Run<Robustness>(wgsl, Config(GetParam()));
+
+    EXPECT_EQ(want, str(output));
+}
+
+TEST_P(RobustnessTest, Read_NestedConstantSizedArrays_MixedAccess) {
+    const char* wgsl = R"(
+var<private> a : array<array<array<f32, 1>, 2>, 3>;
+
+var<private> x : i32;
+
+const y = 1;
+
+override z : i32;
+
+fn f() {
+ let p0 = &(a);
+ let p1 = &((*(p0))[x]);
+ let p2 = &((*(p1))[y]);
+ let p3 = &((*(p2))[z]);
+ var r = *(p3);
+}
+)";
+    // The const index `y` is left alone; only the runtime `x` and the override `z` are guarded.
+    const char* want = Expect(GetParam(),
+                              /* ignore */ wgsl,
+                              /* clamp */ R"(
+var<private> a : array<array<array<f32, 1>, 2>, 3>;
+
+var<private> x : i32;
+
+const y = 1;
+
+override z : i32;
+
+fn f() {
+ let p0 = &(a);
+ let p1 = &((*(p0))[min(u32(x), 2u)]);
+ let p2 = &((*(p1))[y]);
+ let p3 = &((*(p2))[min(u32(z), 0u)]);
+ var r = *(p3);
+}
+)",
+                              /* predicate */ R"(
+var<private> a : array<array<array<f32, 1>, 2>, 3>;
+
+var<private> x : i32;
+
+const y = 1;
+
+override z : i32;
+
+fn f() {
+ let p0 = &(a);
+ let index = x;
+ let predicate = (u32(index) <= 2u);
+ let p1 = &((*(p0))[index]);
+ let p2 = &((*(p1))[y]);
+ let index_1 = z;
+ let predicate_1 = (predicate & (u32(index_1) <= 0u));
+ let p3 = &((*(p2))[index_1]);
+ var predicated_expr : f32;
+ if (predicate_1) {
+ predicated_expr = *(p3);
+ }
+ var r = predicated_expr;
+}
+)");
+
+    auto output = Run<Robustness>(wgsl, Config(GetParam()));
+
+    EXPECT_EQ(want, str(output));
+}
+
+TEST_P(RobustnessTest, Assign_ConstantSizedArray_IndexWithLet) {
+    const char* wgsl = R"(
+var<private> a : array<f32, 3>;
+
+fn f() {
+ let l : u32 = 1u;
+ a[l] = 42.0f;
+}
+)";
+    // Write through an index: predicate wraps the whole assignment in an `if` on the check.
+    const char* want = Expect(GetParam(),
+                              /* ignore */ wgsl,
+                              /* clamp */ R"(
+var<private> a : array<f32, 3>;
+
+fn f() {
+ let l : u32 = 1u;
+ a[min(l, 2u)] = 42.0f;
+}
+)",
+                              /* predicate */ R"(
+var<private> a : array<f32, 3>;
+
+fn f() {
+ let l : u32 = 1u;
+ let index = l;
+ let predicate = (u32(index) <= 2u);
+ if (predicate) {
+ a[index] = 42.0f;
+ }
+}
+)");
+
+    auto output = Run<Robustness>(wgsl, Config(GetParam()));
+
+    EXPECT_EQ(want, str(output));
+}
+
+TEST_P(RobustnessTest, Assign_ConstantSizedArrayPtr_IndexWithLet) {
+    const char* wgsl = R"(
+var<private> a : array<f32, 3>;
+
+fn f() {
+ let l : u32 = 1u;
+ let p = &(a[l]);
+ *(p) = 42.0f;
+}
+)";
+    // Write through a pointer: predicate guards the store, not the pointer's creation.
+    const char* want = Expect(GetParam(),
+                              /* ignore */ wgsl,
+                              /* clamp */ R"(
+var<private> a : array<f32, 3>;
+
+fn f() {
+ let l : u32 = 1u;
+ let p = &(a[min(l, 2u)]);
+ *(p) = 42.0f;
+}
+)",
+                              /* predicate */ R"(
+var<private> a : array<f32, 3>;
+
+fn f() {
+ let l : u32 = 1u;
+ let index = l;
+ let predicate = (u32(index) <= 2u);
+ let p = &(a[index]);
+ if (predicate) {
+ *(p) = 42.0f;
+ }
+}
+)");
+
+    auto output = Run<Robustness>(wgsl, Config(GetParam()));
+
+    EXPECT_EQ(want, str(output));
+}
+
+TEST_P(RobustnessTest, Assign_ConstantSizedArrayPtr_IndexWithRuntimeArrayIndex) {
+    const char* wgsl = R"(
+var<private> a : array<f32, 3>;
+
+var<private> b : array<i32, 5>;
+
+var<private> i : u32;
+
+fn f() {
+ let pa = &(a);
+ let pb = &(b);
+ let p0 = &((*(pb))[i]);
+ let p1 = &(a[*(p0)]);
+ *(p1) = 42.0f;
+}
+)";
+    // The index load *(p0) is guarded as a read; the final store through p1 is guarded too.
+    const char* want = Expect(GetParam(),
+                              /* ignore */ wgsl,
+                              /* clamp */ R"(
+var<private> a : array<f32, 3>;
+
+var<private> b : array<i32, 5>;
+
+var<private> i : u32;
+
+fn f() {
+ let pa = &(a);
+ let pb = &(b);
+ let p0 = &((*(pb))[min(i, 4u)]);
+ let p1 = &(a[min(u32(*(p0)), 2u)]);
+ *(p1) = 42.0f;
+}
+)",
+                              /* predicate */ R"(
+var<private> a : array<f32, 3>;
+
+var<private> b : array<i32, 5>;
+
+var<private> i : u32;
+
+fn f() {
+ let pa = &(a);
+ let pb = &(b);
+ let index = i;
+ let predicate = (u32(index) <= 4u);
+ let p0 = &((*(pb))[index]);
+ var predicated_expr : i32;
+ if (predicate) {
+ predicated_expr = *(p0);
+ }
+ let index_1 = predicated_expr;
+ let predicate_1 = (u32(index_1) <= 2u);
+ let p1 = &(a[index_1]);
+ if (predicate_1) {
+ *(p1) = 42.0f;
+ }
+}
+)");
+
+    auto output = Run<Robustness>(wgsl, Config(GetParam()));
+
+    EXPECT_EQ(want, str(output));
+}
+
+TEST_P(RobustnessTest, Assign_NestedConstantSizedArraysPtr_IndexWithRuntimeExpressions) {
+    const char* wgsl = R"(
+var<private> a : array<array<array<f32, 1>, 2>, 3>;
+
+var<private> x : i32;
+
+var<private> y : i32;
+
+var<private> z : i32;
+
+fn f() {
+ let p0 = &(a);
+ let p1 = &((*(p0))[x]);
+ let p2 = &((*(p1))[y]);
+ let p3 = &((*(p2))[z]);
+ *(p3) = 42.0f;
+}
+)";
+    // predicate: checks accumulate along the pointer chain and guard only the final store.
+    const char* want = Expect(GetParam(),
+                              /* ignore */ wgsl,
+                              /* clamp */ R"(
+var<private> a : array<array<array<f32, 1>, 2>, 3>;
+
+var<private> x : i32;
+
+var<private> y : i32;
+
+var<private> z : i32;
+
+fn f() {
+ let p0 = &(a);
+ let p1 = &((*(p0))[min(u32(x), 2u)]);
+ let p2 = &((*(p1))[min(u32(y), 1u)]);
+ let p3 = &((*(p2))[min(u32(z), 0u)]);
+ *(p3) = 42.0f;
+}
+)",
+                              /* predicate */ R"(
+var<private> a : array<array<array<f32, 1>, 2>, 3>;
+
+var<private> x : i32;
+
+var<private> y : i32;
+
+var<private> z : i32;
+
+fn f() {
+ let p0 = &(a);
+ let index = x;
+ let predicate = (u32(index) <= 2u);
+ let p1 = &((*(p0))[index]);
+ let index_1 = y;
+ let predicate_1 = (predicate & (u32(index_1) <= 1u));
+ let p2 = &((*(p1))[index_1]);
+ let index_2 = z;
+ let predicate_2 = (predicate_1 & (u32(index_2) <= 0u));
+ let p3 = &((*(p2))[index_2]);
+ if (predicate_2) {
+ *(p3) = 42.0f;
+ }
+}
+)");
+
+    auto output = Run<Robustness>(wgsl, Config(GetParam()));
+
+    EXPECT_EQ(want, str(output));
+}
+
+TEST_P(RobustnessTest, Assign_NestedConstantSizedArrays_MixedAccess) {
+    const char* wgsl = R"(
+var<private> a : array<array<array<f32, 1>, 2>, 3>;
+
+var<private> x : i32;
+
+const y = 1;
+
+override z : i32;
+
+fn f() {
+ let p0 = &(a);
+ let p1 = &((*(p0))[x]);
+ let p2 = &((*(p1))[y]);
+ let p3 = &((*(p2))[z]);
+ *(p3) = 42.0f;
+}
+)";
+    // Store variant: the const index `y` is left alone; only `x` and `z` are guarded.
+    const char* want = Expect(GetParam(),
+                              /* ignore */ wgsl,
+                              /* clamp */ R"(
+var<private> a : array<array<array<f32, 1>, 2>, 3>;
+
+var<private> x : i32;
+
+const y = 1;
+
+override z : i32;
+
+fn f() {
+ let p0 = &(a);
+ let p1 = &((*(p0))[min(u32(x), 2u)]);
+ let p2 = &((*(p1))[y]);
+ let p3 = &((*(p2))[min(u32(z), 0u)]);
+ *(p3) = 42.0f;
+}
+)",
+                              /* predicate */ R"(
+var<private> a : array<array<array<f32, 1>, 2>, 3>;
+
+var<private> x : i32;
+
+const y = 1;
+
+override z : i32;
+
+fn f() {
+ let p0 = &(a);
+ let index = x;
+ let predicate = (u32(index) <= 2u);
+ let p1 = &((*(p0))[index]);
+ let p2 = &((*(p1))[y]);
+ let index_1 = z;
+ let predicate_1 = (predicate & (u32(index_1) <= 0u));
+ let p3 = &((*(p2))[index_1]);
+ if (predicate_1) {
+ *(p3) = 42.0f;
+ }
+}
+)");
+
+    auto output = Run<Robustness>(wgsl, Config(GetParam()));
+
+    EXPECT_EQ(want, str(output));
+}
+
+TEST_P(RobustnessTest, CompoundAssign_ConstantSizedArray_IndexWithLet) {
+    const char* wgsl = R"(
+var<private> a : array<f32, 3>;
+
+fn f() {
+ let l : u32 = 1u;
+ a[l] += 42.0f;
+}
+)";
+    // Compound assignment `+=`: predicate wraps the whole statement in an `if` on the check.
+    const char* want = Expect(GetParam(),
+                              /* ignore */ wgsl,
+                              /* clamp */ R"(
+var<private> a : array<f32, 3>;
+
+fn f() {
+ let l : u32 = 1u;
+ a[min(l, 2u)] += 42.0f;
+}
+)",
+                              /* predicate */ R"(
+var<private> a : array<f32, 3>;
+
+fn f() {
+ let l : u32 = 1u;
+ let index = l;
+ let predicate = (u32(index) <= 2u);
+ if (predicate) {
+ a[index] += 42.0f;
+ }
+}
+)");
+
+    auto output = Run<Robustness>(wgsl, Config(GetParam()));
+
+    EXPECT_EQ(want, str(output));
+}
+
+TEST_P(RobustnessTest, Increment_ConstantSizedArray_IndexWithLet) {
+    const char* wgsl = R"(
+var<private> a : array<i32, 3>;
+
+fn f() {
+ let l : u32 = 1u;
+ a[l]++;
+}
+)";
+    // Increment `++`: predicate wraps the whole statement in an `if` on the bounds check.
+    const char* want = Expect(GetParam(),
+                              /* ignore */ wgsl,
+                              /* clamp */ R"(
+var<private> a : array<i32, 3>;
+
+fn f() {
+ let l : u32 = 1u;
+ a[min(l, 2u)]++;
+}
+)",
+                              /* predicate */ R"(
+var<private> a : array<i32, 3>;
+
+fn f() {
+ let l : u32 = 1u;
+ let index = l;
+ let predicate = (u32(index) <= 2u);
+ if (predicate) {
+ a[index]++;
+ }
+}
+)");
+
+    auto output = Run<Robustness>(wgsl, Config(GetParam()));
+
+    EXPECT_EQ(want, str(output));
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Runtime sized array
+////////////////////////////////////////////////////////////////////////////////
+
+TEST_P(RobustnessTest, Read_RuntimeArray_IndexWithLiteral) {
+    const char* wgsl = R"(
+struct S {
+ a : f32,
+ b : array<f32>,
+}
+
+@group(0) @binding(0) var<storage, read> s : S;
+
+fn f() {
+ var d : f32 = s.b[25];
+}
+)";
+    // Runtime-sized array: even a literal index is bounded, by (arrayLength(&(s.b)) - 1u).
+    const char* want = Expect(GetParam(),
+                              /* ignore */ wgsl,
+                              /* clamp */ R"(
+struct S {
+ a : f32,
+ b : array<f32>,
+}
+
+@group(0) @binding(0) var<storage, read> s : S;
+
+fn f() {
+ var d : f32 = s.b[min(u32(25), (arrayLength(&(s.b)) - 1u))];
+}
+)",
+                              /* predicate */ R"(
+struct S {
+ a : f32,
+ b : array<f32>,
+}
+
+@group(0) @binding(0) var<storage, read> s : S;
+
+fn f() {
+ let index = 25;
+ let predicate = (u32(index) <= (arrayLength(&(s.b)) - 1u));
+ var predicated_expr : f32;
+ if (predicate) {
+ predicated_expr = s.b[index];
+ }
+ var d : f32 = predicated_expr;
+}
+)");
+
+    auto output = Run<Robustness>(wgsl, Config(GetParam()));
+
+    EXPECT_EQ(want, str(output));
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Vector
+////////////////////////////////////////////////////////////////////////////////
+
+TEST_P(RobustnessTest, Read_Vector_IndexWithLiteral) {
+    const char* wgsl = R"(
+var<private> a : vec3<f32>;
+
+fn f() {
+ var b : f32 = a[1i];
+}
+)";
+    // Literal index 1i on a vec3: the output must equal the input for every action.
+    const char* want = wgsl;
+    auto output = Run<Robustness>(wgsl, Config(GetParam()));
+
+    EXPECT_EQ(want, str(output));
+}
+
+TEST_P(RobustnessTest, Read_Vector_IndexWithConst) {
+    const char* wgsl = R"(
+var<private> a : vec3<f32>;
+
+fn f() {
+ const i = 1;
+ var b : f32 = a[i];
+}
+)";
+    // Index is a function-scope const: the output must equal the input for every action.
+    const char* want = wgsl;
+    auto output = Run<Robustness>(wgsl, Config(GetParam()));
+
+    EXPECT_EQ(want, str(output));
+}
+
+TEST_P(RobustnessTest, Read_Vector_IndexWithLet) {
+    const char* wgsl = R"(
+fn f() {
+ let i = 99;
+ let v = vec4<f32>()[i];
+}
+)";
+    // vec4 value indexed by a `let`: the index is bounded by 3u in clamp and predicate modes.
+    const char* want = Expect(GetParam(),
+                              /* ignore */ wgsl,
+                              /* clamp */ R"(
+fn f() {
+ let i = 99;
+ let v = vec4<f32>()[min(u32(i), 3u)];
+}
+)",
+                              /* predicate */ R"(
+fn f() {
+ let i = 99;
+ let index = i;
+ let predicate = (u32(index) <= 3u);
+ var predicated_expr : f32;
+ if (predicate) {
+ predicated_expr = vec4<f32>()[index];
+ }
+ let v = predicated_expr;
+}
+)");
+
+    auto output = Run<Robustness>(wgsl, Config(GetParam()));
+
+    EXPECT_EQ(want, str(output));
+}
+
+TEST_P(RobustnessTest, Read_Vector_IndexWithRuntimeExpression) {
+    const char* wgsl = R"(
+var<private> a : vec3<f32>;
+
+var<private> c : i32;
+
+fn f() {
+ var b : f32 = a[((c + 2) - 3)];
+}
+)";
+    // vec3 indexed by an i32 expression: the index is bounded by 2u in clamp and predicate modes.
+    const char* want = Expect(GetParam(),
+                              /* ignore */ wgsl,
+                              /* clamp */ R"(
+var<private> a : vec3<f32>;
+
+var<private> c : i32;
+
+fn f() {
+ var b : f32 = a[min(u32(((c + 2) - 3)), 2u)];
+}
+)",
+                              /* predicate */ R"(
+var<private> a : vec3<f32>;
+
+var<private> c : i32;
+
+fn f() {
+ let index = ((c + 2) - 3);
+ let predicate = (u32(index) <= 2u);
+ var predicated_expr : f32;
+ if (predicate) {
+ predicated_expr = a[index];
+ }
+ var b : f32 = predicated_expr;
+}
+)");
+
+    auto output = Run<Robustness>(wgsl, Config(GetParam()));
+
+    EXPECT_EQ(want, str(output));
+}
+
+TEST_P(RobustnessTest, Read_Vector_SwizzleIndexWithGlobalVar) {
+    const char* wgsl = R"(
+var<private> a : vec3<f32>;
+
+var<private> c : i32;
+
+fn f() {
+ var b : f32 = a.xy[c];
+}
+)";
+    // Index into the two-element swizzle a.xy: the upper bound is 1u, not the vec3's 2u.
+    const char* want = Expect(GetParam(),
+                              /* ignore */ wgsl,
+                              /* clamp */ R"(
+var<private> a : vec3<f32>;
+
+var<private> c : i32;
+
+fn f() {
+ var b : f32 = a.xy[min(u32(c), 1u)];
+}
+)",
+                              /* predicate */ R"(
+var<private> a : vec3<f32>;
+
+var<private> c : i32;
+
+fn f() {
+ let index = c;
+ let predicate = (u32(index) <= 1u);
+ var predicated_expr : f32;
+ if (predicate) {
+ predicated_expr = a.xy[index];
+ }
+ var b : f32 = predicated_expr;
+}
+)");
+
+    auto output = Run<Robustness>(wgsl, Config(GetParam()));
+
+    EXPECT_EQ(want, str(output));
+}
+
+TEST_P(RobustnessTest, Read_Vector_SwizzleIndexWithRuntimeExpression) {
+    const char* wgsl = R"(
+var<private> a : vec3<f32>;
+
+var<private> c : i32;
+
+fn f() {
+ var b : f32 = a.xy[((c + 2) - 3)];
+}
+)";
+    // Swizzle a.xy indexed by an i32 expression: the upper bound is 1u.
+    const char* want = Expect(GetParam(),
+                              /* ignore */ wgsl,
+                              /* clamp */ R"(
+var<private> a : vec3<f32>;
+
+var<private> c : i32;
+
+fn f() {
+ var b : f32 = a.xy[min(u32(((c + 2) - 3)), 1u)];
+}
+)",
+                              /* predicate */ R"(
+var<private> a : vec3<f32>;
+
+var<private> c : i32;
+
+fn f() {
+ let index = ((c + 2) - 3);
+ let predicate = (u32(index) <= 1u);
+ var predicated_expr : f32;
+ if (predicate) {
+ predicated_expr = a.xy[index];
+ }
+ var b : f32 = predicated_expr;
+}
+)");
+
+    auto output = Run<Robustness>(wgsl, Config(GetParam()));
+
+    EXPECT_EQ(want, str(output));
+}
+
+TEST_P(RobustnessTest, Read_Vector_IndexWithOverride) {
+    const char* wgsl = R"(
+@id(1300) override idx : i32;
+
+fn f() {
+ var a : vec3<f32>;
+ var b : f32 = a[idx];
+}
+)";
+    // vec3 indexed by an override: the index is bounded by 2u in clamp and predicate modes.
+    const char* want = Expect(GetParam(),
+                              /* ignore */ wgsl,
+                              /* clamp */ R"(
+@id(1300) override idx : i32;
+
+fn f() {
+ var a : vec3<f32>;
+ var b : f32 = a[min(u32(idx), 2u)];
+}
+)",
+                              /* predicate */ R"(
+@id(1300) override idx : i32;
+
+fn f() {
+ var a : vec3<f32>;
+ let index = idx;
+ let predicate = (u32(index) <= 2u);
+ var predicated_expr : f32;
+ if (predicate) {
+ predicated_expr = a[index];
+ }
+ var b : f32 = predicated_expr;
+}
+)");
+
+    auto output = Run<Robustness>(wgsl, Config(GetParam()));
+    EXPECT_EQ(want, str(output));
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Matrix
+////////////////////////////////////////////////////////////////////////////////
+
+TEST_P(RobustnessTest, Read_MatrixRef_IndexingWithLiterals) {
+    const char* wgsl = R"(
+var<private> a : mat3x2<f32>;
+
+fn f() {
+ var b : f32 = a[2i][1i];
+}
+)";
+    // Both matrix indices are literals: the output must equal the input for every action.
+    const char* want = wgsl;
+
+    auto output = Run<Robustness>(wgsl, Config(GetParam()));
+
+    EXPECT_EQ(want, str(output));
+}
+
+TEST_P(RobustnessTest, Read_MatrixRef_IndexWithRuntimeExpressionThenLiteral) {
+    const char* wgsl = R"(
+var<private> a : mat3x2<f32>;
+
+var<private> c : i32;
+
+fn f() {
+ var b : f32 = a[((c + 2) - 3)][1];
+}
+)";
+    // Only the first (runtime) index is guarded, by 2u; the literal second index is left as is.
+    const char* want = Expect(GetParam(),
+                              /* ignore */ wgsl,
+                              /* clamp */ R"(
+var<private> a : mat3x2<f32>;
+
+var<private> c : i32;
+
+fn f() {
+ var b : f32 = a[min(u32(((c + 2) - 3)), 2u)][1];
+}
+)",
+                              /* predicate */ R"(
+var<private> a : mat3x2<f32>;
+
+var<private> c : i32;
+
+fn f() {
+ let index = ((c + 2) - 3);
+ let predicate = (u32(index) <= 2u);
+ var predicated_expr : f32;
+ if (predicate) {
+ predicated_expr = a[index][1];
+ }
+ var b : f32 = predicated_expr;
+}
+)");
+
+    auto output = Run<Robustness>(wgsl, Config(GetParam()));
+
+    EXPECT_EQ(want, str(output));
+}
+
+TEST_P(RobustnessTest, Read_MatrixRef_IndexWithLiteralThenRuntimeExpression) {
+    const char* wgsl = R"(
+var<private> a : mat3x2<f32>;
+
+var<private> c : i32;
+
+fn f() {
+ var b : f32 = a[1][((c + 2) - 3)];
+}
+)";
+    // Only the second (runtime) index is guarded, by 1u; the literal first index is left as is.
+    const char* want = Expect(GetParam(),
+                              /* ignore */ wgsl,
+                              /* clamp */ R"(
+var<private> a : mat3x2<f32>;
+
+var<private> c : i32;
+
+fn f() {
+ var b : f32 = a[1][min(u32(((c + 2) - 3)), 1u)];
+}
+)",
+                              /* predicate */ R"(
+var<private> a : mat3x2<f32>;
+
+var<private> c : i32;
+
+fn f() {
+ let index = ((c + 2) - 3);
+ let predicate = (u32(index) <= 1u);
+ var predicated_expr : f32;
+ if (predicate) {
+ predicated_expr = a[1][index];
+ }
+ var b : f32 = predicated_expr;
+}
+)");
+
+    auto output = Run<Robustness>(wgsl, Config(GetParam()));
+
+    EXPECT_EQ(want, str(output));
+}
+
+TEST_P(RobustnessTest, Read_MatrixRef_IndexWithOverrideThenLiteral) {
+    const char* wgsl = R"(
+@id(1300) override idx : i32;
+
+fn f() {
+ var a : mat3x2<f32>;
+ var b : f32 = a[idx][1];
+}
+)";
+    // Only the first (override) index is guarded, by 2u; the literal second index is left as is.
+    const char* want = Expect(GetParam(),
+                              /* ignore */ wgsl,
+                              /* clamp */ R"(
+@id(1300) override idx : i32;
+
+fn f() {
+ var a : mat3x2<f32>;
+ var b : f32 = a[min(u32(idx), 2u)][1];
+}
+)",
+                              /* predicate */ R"(
+@id(1300) override idx : i32;
+
+fn f() {
+ var a : mat3x2<f32>;
+ let index = idx;
+ let predicate = (u32(index) <= 2u);
+ var predicated_expr : f32;
+ if (predicate) {
+ predicated_expr = a[index][1];
+ }
+ var b : f32 = predicated_expr;
+}
+)");
+
+    auto output = Run<Robustness>(wgsl, Config(GetParam()));
+    EXPECT_EQ(want, str(output));
+}
+
+TEST_P(RobustnessTest, Read_MatrixRef_IndexWithLetThenSwizzle) {
+    const char* wgsl = R"(
+fn f() {
+ let i = 1;
+ var m = mat3x2<f32>();
+ var v = m[i].yx;
+}
+)";
+    // predicate: the m[index] read is predicated first; .yx is then applied to the result.
+    const char* want = Expect(GetParam(),
+                              /* ignore */ wgsl,
+                              /* clamp */ R"(
+fn f() {
+ let i = 1;
+ var m = mat3x2<f32>();
+ var v = m[min(u32(i), 2u)].yx;
+}
+)",
+                              /* predicate */ R"(
+fn f() {
+ let i = 1;
+ var m = mat3x2<f32>();
+ let index = i;
+ let predicate = (u32(index) <= 2u);
+ var predicated_expr : vec2<f32>;
+ if (predicate) {
+ predicated_expr = m[index];
+ }
+ var v = predicated_expr.yx;
+}
+)");
+
+    auto output = Run<Robustness>(wgsl, Config(GetParam()));
+    EXPECT_EQ(want, str(output));
+}
+
+TEST_P(RobustnessTest, Read_MatrixRef_IndexWithLiteralThenOverride) {  // Matrix read: literal column, override row index.
+ auto* src = R"(
+@id(1300) override idx : i32;
+
+fn f() {
+ var a : mat3x2<f32>;
+ var b : f32 = a[1][idx];
+}
+)";
+ // clamp: row index becomes min(u32(idx), 1u); predicate: read is guarded by (u32(index) <= 1u).
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ R"(
+@id(1300) override idx : i32;
+
+fn f() {
+ var a : mat3x2<f32>;
+ var b : f32 = a[1][min(u32(idx), 1u)];
+}
+)",
+ /* predicate */ R"(
+@id(1300) override idx : i32;
+
+fn f() {
+ var a : mat3x2<f32>;
+ let index = idx;
+ let predicate = (u32(index) <= 1u);
+ var predicated_expr : f32;
+ if (predicate) {
+ predicated_expr = a[1][index];
+ }
+ var b : f32 = predicated_expr;
+}
+)");
+ // Apply the Robustness transform using the Config built from the test parameter.
+ auto got = Run<Robustness>(src, Config(GetParam()));
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, Assign_Matrix_IndexWithLet) {  // Matrix column store indexed by a 'let'.
+ auto* src = R"(
+var<private> m : mat3x4f;
+
+fn f() {
+ let c = 1;
+ m[c] = vec4f(1);
+}
+)";
+ // clamp: index becomes min(u32(c), 2u); predicate: the assignment is wrapped in an if on (u32(index) <= 2u).
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ R"(
+var<private> m : mat3x4f;
+
+fn f() {
+ let c = 1;
+ m[min(u32(c), 2u)] = vec4f(1);
+}
+)",
+ /* predicate */ R"(
+var<private> m : mat3x4f;
+
+fn f() {
+ let c = 1;
+ let index = c;
+ let predicate = (u32(index) <= 2u);
+ if (predicate) {
+ m[index] = vec4f(1);
+ }
+}
+)");
+ // Apply the Robustness transform using the Config built from the test parameter.
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, CompoundAssign_Matrix_IndexWithLet) {  // Matrix column compound-assign indexed by a 'let'.
+ auto* src = R"(
+var<private> m : mat3x4f;
+
+fn f() {
+ let c = 1;
+ m[c] += vec4f(1);
+}
+)";
+ // clamp: index becomes min(u32(c), 2u); predicate: the '+=' is wrapped in an if on (u32(index) <= 2u).
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ R"(
+var<private> m : mat3x4f;
+
+fn f() {
+ let c = 1;
+ m[min(u32(c), 2u)] += vec4f(1);
+}
+)",
+ /* predicate */ R"(
+var<private> m : mat3x4f;
+
+fn f() {
+ let c = 1;
+ let index = c;
+ let predicate = (u32(index) <= 2u);
+ if (predicate) {
+ m[index] += vec4f(1);
+ }
+}
+)");
+ // Apply the Robustness transform using the Config built from the test parameter.
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Texture
+////////////////////////////////////////////////////////////////////////////////
+
+TEST_P(RobustnessTest, TextureDimensions) {  // textureDimensions() without a level argument.
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+fn dimensions() {
+ let l = textureDimensions(t);
+}
+)";
+ // The output is expected to equal the input for every test parameter.
+ auto* expect = src;
+ // Apply the Robustness transform using the Config built from the test parameter.
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, TextureDimensions_Level) {  // textureDimensions() with i32 and u32 level arguments.
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+fn dimensions_signed(level : i32) {
+ let l = textureDimensions(t, level);
+}
+
+fn dimensions_unsigned(level : u32) {
+ let l = textureDimensions(t, level);
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ dimensions_signed(i32(non_uniform.x));
+ dimensions_unsigned(u32(non_uniform.x));
+}
+)";
+ // clamp: level is clamped to (textureNumLevels(t) - 1); predicate: call is guarded by (level_idx < num_levels).
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+fn dimensions_signed(level : i32) {
+ let level_idx = min(u32(level), (textureNumLevels(t) - 1));
+ let l = textureDimensions(t, level_idx);
+}
+
+fn dimensions_unsigned(level : u32) {
+ let level_idx_1 = min(u32(level), (textureNumLevels(t) - 1));
+ let l = textureDimensions(t, level_idx_1);
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ dimensions_signed(i32(non_uniform.x));
+ dimensions_unsigned(u32(non_uniform.x));
+}
+)",
+ /* predicate */ R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+fn dimensions_signed(level : i32) {
+ let level_idx = u32(level);
+ let num_levels = textureNumLevels(t);
+ var predicated_value : vec2<u32>;
+ if ((level_idx < num_levels)) {
+ predicated_value = textureDimensions(t, level_idx);
+ }
+ let l = predicated_value;
+}
+
+fn dimensions_unsigned(level : u32) {
+ let level_idx_1 = u32(level);
+ let num_levels_1 = textureNumLevels(t);
+ var predicated_value_1 : vec2<u32>;
+ if ((level_idx_1 < num_levels_1)) {
+ predicated_value_1 = textureDimensions(t, level_idx_1);
+ }
+ let l = predicated_value_1;
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ dimensions_signed(i32(non_uniform.x));
+ dimensions_unsigned(u32(non_uniform.x));
+}
+)");
+ // Apply the Robustness transform using the Config built from the test parameter.
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, TextureGather) {  // textureGather() on a non-array texture.
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+fn gather(coords : vec2f) {
+ let l = textureGather(0, t, s, coords);
+}
+)";
+ // The output is expected to equal the input for every test parameter.
+ auto* expect = src;
+ // Apply the Robustness transform using the Config built from the test parameter.
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, TextureGather_Array) {  // textureGather() with i32 and u32 array indices.
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d_array<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+fn gather_signed(coords : vec2f, array : i32) {
+ let l = textureGather(1, t, s, coords, array);
+}
+
+fn gather_unsigned(coords : vec2f, array : u32) {
+ let l = textureGather(1, t, s, coords, array);
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ gather_signed(non_uniform.xy, i32(non_uniform.x));
+ gather_unsigned(non_uniform.xy, u32(non_uniform.x));
+}
+)";
+ // The output is expected to equal the input for every test parameter.
+ auto* expect = src;
+ // Apply the Robustness transform using the Config built from the test parameter.
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, TextureGatherCompare) {  // textureGatherCompare() on a non-array depth texture.
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_depth_2d;
+
+@group(0) @binding(1) var s : sampler_comparison;
+
+fn gather(coords : vec2f, depth_ref : f32) {
+ let l = textureGatherCompare(t, s, coords, depth_ref);
+}
+)";
+ // The output is expected to equal the input for every test parameter.
+ auto* expect = src;
+ // Apply the Robustness transform using the Config built from the test parameter.
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, TextureGatherCompare_Array) {  // textureGatherCompare() with i32 and u32 array indices.
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_depth_2d_array;
+
+@group(0) @binding(1) var s : sampler_comparison;
+
+fn gather_signed(coords : vec2f, array : i32, depth_ref : f32) {
+ let l = textureGatherCompare(t, s, coords, array, depth_ref);
+}
+
+fn gather_unsigned(coords : vec2f, array : u32, depth_ref : f32) {
+ let l = textureGatherCompare(t, s, coords, array, depth_ref);
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ gather_signed(non_uniform.xy, i32(non_uniform.x), non_uniform.x);
+ gather_unsigned(non_uniform.xy, u32(non_uniform.x), non_uniform.x);
+}
+)";
+ // The output is expected to equal the input for every test parameter.
+ auto* expect = src;
+ // Apply the Robustness transform using the Config built from the test parameter.
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, TextureLoad_1D) {  // textureLoad() on texture_1d with signed/unsigned/mixed args.
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_1d<f32>;
+
+fn load_signed(coords : i32, level : i32) {
+ let l = textureLoad(t, coords, level);
+}
+
+fn load_unsigned(coords : u32, level : u32) {
+ let l = textureLoad(t, coords, level);
+}
+
+fn load_mixed(coords : i32, level : u32) {
+ let l = textureLoad(t, coords, level);
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ load_signed(i32(non_uniform.x), i32(non_uniform.x));
+ load_unsigned(u32(non_uniform.x), u32(non_uniform.x));
+ load_mixed(i32(non_uniform.x), u32(non_uniform.x));
+}
+)";
+ // clamp: level and coords are clamped; predicate: the load is guarded by level and coords bounds checks.
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ R"(
+@group(0) @binding(0) var t : texture_1d<f32>;
+
+fn load_signed(coords : i32, level : i32) {
+ let level_idx = min(u32(level), (textureNumLevels(t) - 1));
+ let l = textureLoad(t, clamp(coords, 0, i32((textureDimensions(t, level_idx) - 1))), level_idx);
+}
+
+fn load_unsigned(coords : u32, level : u32) {
+ let level_idx_1 = min(u32(level), (textureNumLevels(t) - 1));
+ let l = textureLoad(t, min(coords, (textureDimensions(t, level_idx_1) - 1)), level_idx_1);
+}
+
+fn load_mixed(coords : i32, level : u32) {
+ let level_idx_2 = min(u32(level), (textureNumLevels(t) - 1));
+ let l = textureLoad(t, clamp(coords, 0, i32((textureDimensions(t, level_idx_2) - 1))), level_idx_2);
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ load_signed(i32(non_uniform.x), i32(non_uniform.x));
+ load_unsigned(u32(non_uniform.x), u32(non_uniform.x));
+ load_mixed(i32(non_uniform.x), u32(non_uniform.x));
+}
+)",
+ /* predicate */ R"(
+@group(0) @binding(0) var t : texture_1d<f32>;
+
+fn load_signed(coords : i32, level : i32) {
+ let level_idx = u32(level);
+ let num_levels = textureNumLevels(t);
+ let coords_1 = u32(coords);
+ var predicated_value : vec4<f32>;
+ if (((level_idx < num_levels) & all((coords_1 < textureDimensions(t, min(level_idx, (num_levels - 1))))))) {
+ predicated_value = textureLoad(t, coords_1, level_idx);
+ }
+ let l = predicated_value;
+}
+
+fn load_unsigned(coords : u32, level : u32) {
+ let level_idx_1 = u32(level);
+ let num_levels_1 = textureNumLevels(t);
+ let coords_2 = u32(coords);
+ var predicated_value_1 : vec4<f32>;
+ if (((level_idx_1 < num_levels_1) & all((coords_2 < textureDimensions(t, min(level_idx_1, (num_levels_1 - 1))))))) {
+ predicated_value_1 = textureLoad(t, coords_2, level_idx_1);
+ }
+ let l = predicated_value_1;
+}
+
+fn load_mixed(coords : i32, level : u32) {
+ let level_idx_2 = u32(level);
+ let num_levels_2 = textureNumLevels(t);
+ let coords_3 = u32(coords);
+ var predicated_value_2 : vec4<f32>;
+ if (((level_idx_2 < num_levels_2) & all((coords_3 < textureDimensions(t, min(level_idx_2, (num_levels_2 - 1))))))) {
+ predicated_value_2 = textureLoad(t, coords_3, level_idx_2);
+ }
+ let l = predicated_value_2;
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ load_signed(i32(non_uniform.x), i32(non_uniform.x));
+ load_unsigned(u32(non_uniform.x), u32(non_uniform.x));
+ load_mixed(i32(non_uniform.x), u32(non_uniform.x));
+}
+)");
+ // Apply the Robustness transform using the Config built from the test parameter.
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, TextureLoad_2D) {  // textureLoad() on texture_2d with signed/unsigned/mixed args.
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+fn load_signed(coords : vec2i, level : i32) {
+ let l = textureLoad(t, coords, level);
+}
+
+fn load_unsigned(coords : vec2u, level : u32) {
+ let l = textureLoad(t, coords, level);
+}
+
+fn load_mixed(coords : vec2u, level : i32) {
+ let l = textureLoad(t, coords, level);
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ load_signed(vec2i(non_uniform.xy), i32(non_uniform.x));
+ load_unsigned(vec2u(non_uniform.xy), u32(non_uniform.x));
+ load_mixed(vec2u(non_uniform.xy), i32(non_uniform.x));
+}
+)";
+ // clamp: level and coords are clamped; predicate: the load is guarded by level and coords bounds checks.
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+fn load_signed(coords : vec2i, level : i32) {
+ let level_idx = min(u32(level), (textureNumLevels(t) - 1));
+ let l = textureLoad(t, clamp(coords, vec2(0), vec2<i32>((textureDimensions(t, level_idx) - vec2(1)))), level_idx);
+}
+
+fn load_unsigned(coords : vec2u, level : u32) {
+ let level_idx_1 = min(u32(level), (textureNumLevels(t) - 1));
+ let l = textureLoad(t, min(coords, (textureDimensions(t, level_idx_1) - vec2(1))), level_idx_1);
+}
+
+fn load_mixed(coords : vec2u, level : i32) {
+ let level_idx_2 = min(u32(level), (textureNumLevels(t) - 1));
+ let l = textureLoad(t, min(coords, (textureDimensions(t, level_idx_2) - vec2(1))), level_idx_2);
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ load_signed(vec2i(non_uniform.xy), i32(non_uniform.x));
+ load_unsigned(vec2u(non_uniform.xy), u32(non_uniform.x));
+ load_mixed(vec2u(non_uniform.xy), i32(non_uniform.x));
+}
+)",
+ /* predicate */ R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+fn load_signed(coords : vec2i, level : i32) {
+ let level_idx = u32(level);
+ let num_levels = textureNumLevels(t);
+ let coords_1 = vec2<u32>(coords);
+ var predicated_value : vec4<f32>;
+ if (((level_idx < num_levels) & all((coords_1 < textureDimensions(t, min(level_idx, (num_levels - 1))))))) {
+ predicated_value = textureLoad(t, coords_1, level_idx);
+ }
+ let l = predicated_value;
+}
+
+fn load_unsigned(coords : vec2u, level : u32) {
+ let level_idx_1 = u32(level);
+ let num_levels_1 = textureNumLevels(t);
+ let coords_2 = vec2<u32>(coords);
+ var predicated_value_1 : vec4<f32>;
+ if (((level_idx_1 < num_levels_1) & all((coords_2 < textureDimensions(t, min(level_idx_1, (num_levels_1 - 1))))))) {
+ predicated_value_1 = textureLoad(t, coords_2, level_idx_1);
+ }
+ let l = predicated_value_1;
+}
+
+fn load_mixed(coords : vec2u, level : i32) {
+ let level_idx_2 = u32(level);
+ let num_levels_2 = textureNumLevels(t);
+ let coords_3 = vec2<u32>(coords);
+ var predicated_value_2 : vec4<f32>;
+ if (((level_idx_2 < num_levels_2) & all((coords_3 < textureDimensions(t, min(level_idx_2, (num_levels_2 - 1))))))) {
+ predicated_value_2 = textureLoad(t, coords_3, level_idx_2);
+ }
+ let l = predicated_value_2;
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ load_signed(vec2i(non_uniform.xy), i32(non_uniform.x));
+ load_unsigned(vec2u(non_uniform.xy), u32(non_uniform.x));
+ load_mixed(vec2u(non_uniform.xy), i32(non_uniform.x));
+}
+)");
+ // Apply the Robustness transform using the Config built from the test parameter.
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, TextureLoad_2DArray) {  // textureLoad() on texture_2d_array with signed/unsigned/mixed args.
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d_array<f32>;
+
+fn load_signed(coords : vec2i, array : i32, level : i32) {
+ let l = textureLoad(t, coords, array, level);
+}
+
+fn load_unsigned(coords : vec2u, array : u32, level : u32) {
+ let l = textureLoad(t, coords, array, level);
+}
+
+fn load_mixed(coords : vec2u, array : i32, level : u32) {
+ let l = textureLoad(t, coords, array, level);
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ load_signed(vec2i(non_uniform.xy), i32(non_uniform.x), i32(non_uniform.x));
+ load_unsigned(vec2u(non_uniform.xy), u32(non_uniform.x), u32(non_uniform.x));
+ load_mixed(vec2u(non_uniform.xy), i32(non_uniform.x), u32(non_uniform.x));
+}
+)";
+ // clamp: level, coords and array index are clamped; predicate: the load is guarded by level, coords and layer checks.
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ R"(
+@group(0) @binding(0) var t : texture_2d_array<f32>;
+
+fn load_signed(coords : vec2i, array : i32, level : i32) {
+ let level_idx = min(u32(level), (textureNumLevels(t) - 1));
+ let l = textureLoad(t, clamp(coords, vec2(0), vec2<i32>((textureDimensions(t, level_idx) - vec2(1)))), clamp(array, 0, i32((textureNumLayers(t) - 1))), level_idx);
+}
+
+fn load_unsigned(coords : vec2u, array : u32, level : u32) {
+ let level_idx_1 = min(u32(level), (textureNumLevels(t) - 1));
+ let l = textureLoad(t, min(coords, (textureDimensions(t, level_idx_1) - vec2(1))), min(array, (textureNumLayers(t) - 1)), level_idx_1);
+}
+
+fn load_mixed(coords : vec2u, array : i32, level : u32) {
+ let level_idx_2 = min(u32(level), (textureNumLevels(t) - 1));
+ let l = textureLoad(t, min(coords, (textureDimensions(t, level_idx_2) - vec2(1))), clamp(array, 0, i32((textureNumLayers(t) - 1))), level_idx_2);
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ load_signed(vec2i(non_uniform.xy), i32(non_uniform.x), i32(non_uniform.x));
+ load_unsigned(vec2u(non_uniform.xy), u32(non_uniform.x), u32(non_uniform.x));
+ load_mixed(vec2u(non_uniform.xy), i32(non_uniform.x), u32(non_uniform.x));
+}
+)",
+ /* predicate */ R"(
+@group(0) @binding(0) var t : texture_2d_array<f32>;
+
+fn load_signed(coords : vec2i, array : i32, level : i32) {
+ let level_idx = u32(level);
+ let num_levels = textureNumLevels(t);
+ let coords_1 = vec2<u32>(coords);
+ let array_idx = u32(array);
+ var predicated_value : vec4<f32>;
+ if ((((level_idx < num_levels) & all((coords_1 < textureDimensions(t, min(level_idx, (num_levels - 1)))))) & (array_idx < textureNumLayers(t)))) {
+ predicated_value = textureLoad(t, coords_1, array_idx, level_idx);
+ }
+ let l = predicated_value;
+}
+
+fn load_unsigned(coords : vec2u, array : u32, level : u32) {
+ let level_idx_1 = u32(level);
+ let num_levels_1 = textureNumLevels(t);
+ let coords_2 = vec2<u32>(coords);
+ let array_idx_1 = u32(array);
+ var predicated_value_1 : vec4<f32>;
+ if ((((level_idx_1 < num_levels_1) & all((coords_2 < textureDimensions(t, min(level_idx_1, (num_levels_1 - 1)))))) & (array_idx_1 < textureNumLayers(t)))) {
+ predicated_value_1 = textureLoad(t, coords_2, array_idx_1, level_idx_1);
+ }
+ let l = predicated_value_1;
+}
+
+fn load_mixed(coords : vec2u, array : i32, level : u32) {
+ let level_idx_2 = u32(level);
+ let num_levels_2 = textureNumLevels(t);
+ let coords_3 = vec2<u32>(coords);
+ let array_idx_2 = u32(array);
+ var predicated_value_2 : vec4<f32>;
+ if ((((level_idx_2 < num_levels_2) & all((coords_3 < textureDimensions(t, min(level_idx_2, (num_levels_2 - 1)))))) & (array_idx_2 < textureNumLayers(t)))) {
+ predicated_value_2 = textureLoad(t, coords_3, array_idx_2, level_idx_2);
+ }
+ let l = predicated_value_2;
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ load_signed(vec2i(non_uniform.xy), i32(non_uniform.x), i32(non_uniform.x));
+ load_unsigned(vec2u(non_uniform.xy), u32(non_uniform.x), u32(non_uniform.x));
+ load_mixed(vec2u(non_uniform.xy), i32(non_uniform.x), u32(non_uniform.x));
+}
+)");
+ // Apply the Robustness transform using the Config built from the test parameter.
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, TextureLoad_3D) {  // textureLoad() on texture_3d with signed/unsigned/mixed args.
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_3d<f32>;
+
+fn load_signed(coords : vec3i, level : i32) {
+ let l = textureLoad(t, coords, level);
+}
+
+fn load_unsigned(coords : vec3u, level : u32) {
+ let l = textureLoad(t, coords, level);
+}
+
+fn load_mixed(coords : vec3u, level : i32) {
+ let l = textureLoad(t, coords, level);
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ load_signed(vec3i(non_uniform.xyz), i32(non_uniform.x));
+ load_unsigned(vec3u(non_uniform.xyz), u32(non_uniform.x));
+ load_mixed(vec3u(non_uniform.xyz), i32(non_uniform.x));
+}
+)";
+ // clamp: level and coords are clamped; predicate: the load is guarded by level and coords bounds checks.
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ R"(
+@group(0) @binding(0) var t : texture_3d<f32>;
+
+fn load_signed(coords : vec3i, level : i32) {
+ let level_idx = min(u32(level), (textureNumLevels(t) - 1));
+ let l = textureLoad(t, clamp(coords, vec3(0), vec3<i32>((textureDimensions(t, level_idx) - vec3(1)))), level_idx);
+}
+
+fn load_unsigned(coords : vec3u, level : u32) {
+ let level_idx_1 = min(u32(level), (textureNumLevels(t) - 1));
+ let l = textureLoad(t, min(coords, (textureDimensions(t, level_idx_1) - vec3(1))), level_idx_1);
+}
+
+fn load_mixed(coords : vec3u, level : i32) {
+ let level_idx_2 = min(u32(level), (textureNumLevels(t) - 1));
+ let l = textureLoad(t, min(coords, (textureDimensions(t, level_idx_2) - vec3(1))), level_idx_2);
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ load_signed(vec3i(non_uniform.xyz), i32(non_uniform.x));
+ load_unsigned(vec3u(non_uniform.xyz), u32(non_uniform.x));
+ load_mixed(vec3u(non_uniform.xyz), i32(non_uniform.x));
+}
+)",
+ /* predicate */ R"(
+@group(0) @binding(0) var t : texture_3d<f32>;
+
+fn load_signed(coords : vec3i, level : i32) {
+ let level_idx = u32(level);
+ let num_levels = textureNumLevels(t);
+ let coords_1 = vec3<u32>(coords);
+ var predicated_value : vec4<f32>;
+ if (((level_idx < num_levels) & all((coords_1 < textureDimensions(t, min(level_idx, (num_levels - 1))))))) {
+ predicated_value = textureLoad(t, coords_1, level_idx);
+ }
+ let l = predicated_value;
+}
+
+fn load_unsigned(coords : vec3u, level : u32) {
+ let level_idx_1 = u32(level);
+ let num_levels_1 = textureNumLevels(t);
+ let coords_2 = vec3<u32>(coords);
+ var predicated_value_1 : vec4<f32>;
+ if (((level_idx_1 < num_levels_1) & all((coords_2 < textureDimensions(t, min(level_idx_1, (num_levels_1 - 1))))))) {
+ predicated_value_1 = textureLoad(t, coords_2, level_idx_1);
+ }
+ let l = predicated_value_1;
+}
+
+fn load_mixed(coords : vec3u, level : i32) {
+ let level_idx_2 = u32(level);
+ let num_levels_2 = textureNumLevels(t);
+ let coords_3 = vec3<u32>(coords);
+ var predicated_value_2 : vec4<f32>;
+ if (((level_idx_2 < num_levels_2) & all((coords_3 < textureDimensions(t, min(level_idx_2, (num_levels_2 - 1))))))) {
+ predicated_value_2 = textureLoad(t, coords_3, level_idx_2);
+ }
+ let l = predicated_value_2;
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ load_signed(vec3i(non_uniform.xyz), i32(non_uniform.x));
+ load_unsigned(vec3u(non_uniform.xyz), u32(non_uniform.x));
+ load_mixed(vec3u(non_uniform.xyz), i32(non_uniform.x));
+}
+)");
+ // Apply the Robustness transform using the Config built from the test parameter.
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, TextureLoad_Multisampled2D) {  // textureLoad() on texture_multisampled_2d.
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_multisampled_2d<f32>;
+
+fn load_signed(coords : vec2i, sample : i32) {
+ let l = textureLoad(t, coords, sample);
+}
+
+fn load_unsigned(coords : vec2u, sample : u32) {
+ let l = textureLoad(t, coords, sample);
+}
+
+fn load_mixed(coords : vec2i, sample : u32) {
+ let l = textureLoad(t, coords, sample);
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ load_signed(vec2i(non_uniform.xy), i32(non_uniform.x));
+ load_unsigned(vec2u(non_uniform.xy), u32(non_uniform.x));
+ load_mixed(vec2i(non_uniform.xy), u32(non_uniform.x));
+}
+)";
+ // clamp: only coords are clamped (sample is passed through); predicate: the load is guarded by a coords check only.
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ R"(
+@group(0) @binding(0) var t : texture_multisampled_2d<f32>;
+
+fn load_signed(coords : vec2i, sample : i32) {
+ let l = textureLoad(t, clamp(coords, vec2(0), vec2<i32>((textureDimensions(t) - vec2(1)))), sample);
+}
+
+fn load_unsigned(coords : vec2u, sample : u32) {
+ let l = textureLoad(t, min(coords, (textureDimensions(t) - vec2(1))), sample);
+}
+
+fn load_mixed(coords : vec2i, sample : u32) {
+ let l = textureLoad(t, clamp(coords, vec2(0), vec2<i32>((textureDimensions(t) - vec2(1)))), sample);
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ load_signed(vec2i(non_uniform.xy), i32(non_uniform.x));
+ load_unsigned(vec2u(non_uniform.xy), u32(non_uniform.x));
+ load_mixed(vec2i(non_uniform.xy), u32(non_uniform.x));
+}
+)",
+ /* predicate */ R"(
+@group(0) @binding(0) var t : texture_multisampled_2d<f32>;
+
+fn load_signed(coords : vec2i, sample : i32) {
+ let coords_1 = vec2<u32>(coords);
+ var predicated_value : vec4<f32>;
+ if (all((coords_1 < textureDimensions(t)))) {
+ predicated_value = textureLoad(t, coords_1, sample);
+ }
+ let l = predicated_value;
+}
+
+fn load_unsigned(coords : vec2u, sample : u32) {
+ let coords_2 = vec2<u32>(coords);
+ var predicated_value_1 : vec4<f32>;
+ if (all((coords_2 < textureDimensions(t)))) {
+ predicated_value_1 = textureLoad(t, coords_2, sample);
+ }
+ let l = predicated_value_1;
+}
+
+fn load_mixed(coords : vec2i, sample : u32) {
+ let coords_3 = vec2<u32>(coords);
+ var predicated_value_2 : vec4<f32>;
+ if (all((coords_3 < textureDimensions(t)))) {
+ predicated_value_2 = textureLoad(t, coords_3, sample);
+ }
+ let l = predicated_value_2;
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ load_signed(vec2i(non_uniform.xy), i32(non_uniform.x));
+ load_unsigned(vec2u(non_uniform.xy), u32(non_uniform.x));
+ load_mixed(vec2i(non_uniform.xy), u32(non_uniform.x));
+}
+)");
+ // Apply the Robustness transform using the Config built from the test parameter.
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, TextureLoad_Depth2D) {  // textureLoad() on texture_depth_2d with signed/unsigned/mixed args.
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_depth_2d;
+
+fn load_signed(coords : vec2i, level : i32) {
+ let l = textureLoad(t, coords, level);
+}
+
+fn load_unsigned(coords : vec2u, level : u32) {
+ let l = textureLoad(t, coords, level);
+}
+
+fn load_mixed(coords : vec2i, level : u32) {
+ let l = textureLoad(t, coords, level);
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ load_signed(vec2i(non_uniform.xy), i32(non_uniform.x));
+ load_unsigned(vec2u(non_uniform.xy), u32(non_uniform.x));
+ load_mixed(vec2i(non_uniform.xy), u32(non_uniform.x));
+}
+)";
+ // clamp: level and coords are clamped; predicate: the f32 load is guarded by level and coords bounds checks.
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ R"(
+@group(0) @binding(0) var t : texture_depth_2d;
+
+fn load_signed(coords : vec2i, level : i32) {
+ let level_idx = min(u32(level), (textureNumLevels(t) - 1));
+ let l = textureLoad(t, clamp(coords, vec2(0), vec2<i32>((textureDimensions(t, level_idx) - vec2(1)))), level_idx);
+}
+
+fn load_unsigned(coords : vec2u, level : u32) {
+ let level_idx_1 = min(u32(level), (textureNumLevels(t) - 1));
+ let l = textureLoad(t, min(coords, (textureDimensions(t, level_idx_1) - vec2(1))), level_idx_1);
+}
+
+fn load_mixed(coords : vec2i, level : u32) {
+ let level_idx_2 = min(u32(level), (textureNumLevels(t) - 1));
+ let l = textureLoad(t, clamp(coords, vec2(0), vec2<i32>((textureDimensions(t, level_idx_2) - vec2(1)))), level_idx_2);
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ load_signed(vec2i(non_uniform.xy), i32(non_uniform.x));
+ load_unsigned(vec2u(non_uniform.xy), u32(non_uniform.x));
+ load_mixed(vec2i(non_uniform.xy), u32(non_uniform.x));
+}
+)",
+ /* predicate */ R"(
+@group(0) @binding(0) var t : texture_depth_2d;
+
+fn load_signed(coords : vec2i, level : i32) {
+ let level_idx = u32(level);
+ let num_levels = textureNumLevels(t);
+ let coords_1 = vec2<u32>(coords);
+ var predicated_value : f32;
+ if (((level_idx < num_levels) & all((coords_1 < textureDimensions(t, min(level_idx, (num_levels - 1))))))) {
+ predicated_value = textureLoad(t, coords_1, level_idx);
+ }
+ let l = predicated_value;
+}
+
+fn load_unsigned(coords : vec2u, level : u32) {
+ let level_idx_1 = u32(level);
+ let num_levels_1 = textureNumLevels(t);
+ let coords_2 = vec2<u32>(coords);
+ var predicated_value_1 : f32;
+ if (((level_idx_1 < num_levels_1) & all((coords_2 < textureDimensions(t, min(level_idx_1, (num_levels_1 - 1))))))) {
+ predicated_value_1 = textureLoad(t, coords_2, level_idx_1);
+ }
+ let l = predicated_value_1;
+}
+
+fn load_mixed(coords : vec2i, level : u32) {
+ let level_idx_2 = u32(level);
+ let num_levels_2 = textureNumLevels(t);
+ let coords_3 = vec2<u32>(coords);
+ var predicated_value_2 : f32;
+ if (((level_idx_2 < num_levels_2) & all((coords_3 < textureDimensions(t, min(level_idx_2, (num_levels_2 - 1))))))) {
+ predicated_value_2 = textureLoad(t, coords_3, level_idx_2);
+ }
+ let l = predicated_value_2;
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ load_signed(vec2i(non_uniform.xy), i32(non_uniform.x));
+ load_unsigned(vec2u(non_uniform.xy), u32(non_uniform.x));
+ load_mixed(vec2i(non_uniform.xy), u32(non_uniform.x));
+}
+)");
+ // Apply the Robustness transform using the Config built from the test parameter.
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, TextureLoad_Depth2DArray) {  // textureLoad() on texture_depth_2d_array.
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_depth_2d_array;
+
+fn load_signed(coords : vec2i, array : i32, level : i32) {
+ let l = textureLoad(t, coords, array, level);
+}
+
+fn load_unsigned(coords : vec2u, array : u32, level : u32) {
+ let l = textureLoad(t, coords, array, level);
+}
+
+fn load_mixed(coords : vec2u, array : i32, level : u32) {
+ let l = textureLoad(t, coords, array, level);
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ load_signed(vec2i(non_uniform.xy), i32(non_uniform.x), i32(non_uniform.x));
+ load_unsigned(vec2u(non_uniform.xy), u32(non_uniform.x), u32(non_uniform.x));
+ load_mixed(vec2u(non_uniform.xy), i32(non_uniform.x), u32(non_uniform.x));
+}
+)";
+ // clamp: level, coords and array index are clamped; predicate: the load is guarded by level, coords and layer checks.
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ R"(
+@group(0) @binding(0) var t : texture_depth_2d_array;
+
+fn load_signed(coords : vec2i, array : i32, level : i32) {
+ let level_idx = min(u32(level), (textureNumLevels(t) - 1));
+ let l = textureLoad(t, clamp(coords, vec2(0), vec2<i32>((textureDimensions(t, level_idx) - vec2(1)))), clamp(array, 0, i32((textureNumLayers(t) - 1))), level_idx);
+}
+
+fn load_unsigned(coords : vec2u, array : u32, level : u32) {
+ let level_idx_1 = min(u32(level), (textureNumLevels(t) - 1));
+ let l = textureLoad(t, min(coords, (textureDimensions(t, level_idx_1) - vec2(1))), min(array, (textureNumLayers(t) - 1)), level_idx_1);
+}
+
+fn load_mixed(coords : vec2u, array : i32, level : u32) {
+ let level_idx_2 = min(u32(level), (textureNumLevels(t) - 1));
+ let l = textureLoad(t, min(coords, (textureDimensions(t, level_idx_2) - vec2(1))), clamp(array, 0, i32((textureNumLayers(t) - 1))), level_idx_2);
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ load_signed(vec2i(non_uniform.xy), i32(non_uniform.x), i32(non_uniform.x));
+ load_unsigned(vec2u(non_uniform.xy), u32(non_uniform.x), u32(non_uniform.x));
+ load_mixed(vec2u(non_uniform.xy), i32(non_uniform.x), u32(non_uniform.x));
+}
+)",
+ /* predicate */ R"(
+@group(0) @binding(0) var t : texture_depth_2d_array;
+
+fn load_signed(coords : vec2i, array : i32, level : i32) {
+ let level_idx = u32(level);
+ let num_levels = textureNumLevels(t);
+ let coords_1 = vec2<u32>(coords);
+ let array_idx = u32(array);
+ var predicated_value : f32;
+ if ((((level_idx < num_levels) & all((coords_1 < textureDimensions(t, min(level_idx, (num_levels - 1)))))) & (array_idx < textureNumLayers(t)))) {
+ predicated_value = textureLoad(t, coords_1, array_idx, level_idx);
+ }
+ let l = predicated_value;
+}
+
+fn load_unsigned(coords : vec2u, array : u32, level : u32) {
+ let level_idx_1 = u32(level);
+ let num_levels_1 = textureNumLevels(t);
+ let coords_2 = vec2<u32>(coords);
+ let array_idx_1 = u32(array);
+ var predicated_value_1 : f32;
+ if ((((level_idx_1 < num_levels_1) & all((coords_2 < textureDimensions(t, min(level_idx_1, (num_levels_1 - 1)))))) & (array_idx_1 < textureNumLayers(t)))) {
+ predicated_value_1 = textureLoad(t, coords_2, array_idx_1, level_idx_1);
+ }
+ let l = predicated_value_1;
+}
+
+fn load_mixed(coords : vec2u, array : i32, level : u32) {
+ let level_idx_2 = u32(level);
+ let num_levels_2 = textureNumLevels(t);
+ let coords_3 = vec2<u32>(coords);
+ let array_idx_2 = u32(array);
+ var predicated_value_2 : f32;
+ if ((((level_idx_2 < num_levels_2) & all((coords_3 < textureDimensions(t, min(level_idx_2, (num_levels_2 - 1)))))) & (array_idx_2 < textureNumLayers(t)))) {
+ predicated_value_2 = textureLoad(t, coords_3, array_idx_2, level_idx_2);
+ }
+ let l = predicated_value_2;
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ load_signed(vec2i(non_uniform.xy), i32(non_uniform.x), i32(non_uniform.x));
+ load_unsigned(vec2u(non_uniform.xy), u32(non_uniform.x), u32(non_uniform.x));
+ load_mixed(vec2u(non_uniform.xy), i32(non_uniform.x), u32(non_uniform.x));
+}
+)");
+ // Apply the Robustness transform using the Config built from the test parameter.
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, TextureLoad_External) {  // textureLoad() on texture_external with vec2i and vec2u coords.
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_external;
+
+fn load_signed(coords : vec2i) {
+ let l = textureLoad(t, coords);
+}
+
+fn load_unsigned(coords : vec2u) {
+ let l = textureLoad(t, coords);
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ load_signed(vec2i(non_uniform.xy));
+ load_unsigned(vec2u(non_uniform.xy));
+}
+)";
+ // clamp: coords are clamped against (textureDimensions(t) - vec2(1)); predicate: the load is guarded by a coords check.
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ R"(
+@group(0) @binding(0) var t : texture_external;
+
+fn load_signed(coords : vec2i) {
+ let l = textureLoad(t, clamp(coords, vec2(0), vec2<i32>((textureDimensions(t) - vec2(1)))));
+}
+
+fn load_unsigned(coords : vec2u) {
+ let l = textureLoad(t, min(coords, (textureDimensions(t) - vec2(1))));
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ load_signed(vec2i(non_uniform.xy));
+ load_unsigned(vec2u(non_uniform.xy));
+}
+)",
+ /* predicate */ R"(
+@group(0) @binding(0) var t : texture_external;
+
+fn load_signed(coords : vec2i) {
+ let coords_1 = vec2<u32>(coords);
+ var predicated_value : vec4<f32>;
+ if (all((coords_1 < textureDimensions(t)))) {
+ predicated_value = textureLoad(t, coords_1);
+ }
+ let l = predicated_value;
+}
+
+fn load_unsigned(coords : vec2u) {
+ let coords_2 = vec2<u32>(coords);
+ var predicated_value_1 : vec4<f32>;
+ if (all((coords_2 < textureDimensions(t)))) {
+ predicated_value_1 = textureLoad(t, coords_2);
+ }
+ let l = predicated_value_1;
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ load_signed(vec2i(non_uniform.xy));
+ load_unsigned(vec2u(non_uniform.xy));
+}
+)");
+ // Apply the Robustness transform using the Config built from the test parameter.
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, TextureNumLayers) {  // textureNumLayers(): output expected to equal input
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_depth_2d_array;
+
+fn num_layers(coords : vec2f, depth_ref : f32) {
+ let l = textureNumLayers(t);
+}
+)";
+ // `expect` aliases `src`, so the output text must match the input for every GetParam() value.
+ auto* expect = src;
+
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, TextureNumLevels) {  // textureNumLevels(): output expected to equal input
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_depth_2d;
+
+fn num_levels(coords : vec2f, depth_ref : f32) {
+ let l = textureNumLevels(t);
+}
+)";
+ // `expect` aliases `src`, so the output text must match the input for every GetParam() value.
+ auto* expect = src;
+
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, TextureNumSamples) {  // textureNumSamples(): output expected to equal input
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_depth_multisampled_2d;
+
+fn num_levels(coords : vec2f, depth_ref : f32) {
+ let l = textureNumSamples(t);
+}
+)";
+ // `expect` aliases `src`, so the output text must match the input for every GetParam() value.
+ auto* expect = src;
+ // NOTE(review): the WGSL helper is named `num_levels` but calls textureNumSamples(); looks copy-pasted.
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, TextureSample) {  // textureSample(): output expected to equal input
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+fn sample(coords : vec2f) {
+ let l = textureSample(t, s, coords);
+}
+)";
+ // `expect` aliases `src`, so the output text must match the input for every GetParam() value.
+ auto* expect = src;
+
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, TextureSample_Offset) {  // textureSample() with a const offset argument
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+fn sample(coords : vec2f) {
+ const offset = vec2i(1);
+ let l = textureSample(t, s, coords, offset);
+}
+)";
+ // `expect` aliases `src`, so the output text must match the input for every GetParam() value.
+ auto* expect = src;
+
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, TextureSample_ArrayIndex) {  // textureSample() with i32 / u32 array index
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d_array<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+fn sample_signed(coords : vec2f, array : i32) {
+ let l = textureSample(t, s, coords, array);
+}
+
+fn sample_unsigned(coords : vec2f, array : u32) {
+ let l = textureSample(t, s, coords, array);
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ sample_signed(non_uniform.xy, i32(non_uniform.x));
+ sample_unsigned(non_uniform.xy, u32(non_uniform.x));
+}
+)";
+ // `expect` aliases `src`: even with a runtime array index the output text must match the input.
+ auto* expect = src;
+
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, TextureSampleBias) {  // textureSampleBias(): output expected to equal input
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+fn sample_bias(coords : vec2f, bias : f32) {
+ let l = textureSampleBias(t, s, coords, bias);
+}
+)";
+ // `expect` aliases `src`, so the output text must match the input for every GetParam() value.
+ auto* expect = src;
+
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, TextureSampleBias_Offset) {  // textureSampleBias() with a const offset argument
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+fn sample_bias(coords : vec2f, bias : f32) {
+ const offset = vec2i(1);
+ let l = textureSampleBias(t, s, coords, bias, offset);
+}
+)";
+ // `expect` aliases `src`, so the output text must match the input for every GetParam() value.
+ auto* expect = src;
+
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, TextureSampleBias_ArrayIndex) {  // textureSampleBias() with i32 / u32 array index
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d_array<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+fn sample_bias_signed(coords : vec2f, array : i32, bias : f32) {
+ let l = textureSampleBias(t, s, coords, array, bias);
+}
+
+fn sample_bias_unsigned(coords : vec2f, array : u32, bias : f32) {
+ let l = textureSampleBias(t, s, coords, array, bias);
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ sample_bias_signed(non_uniform.xy, i32(non_uniform.x), non_uniform.x);
+ sample_bias_unsigned(non_uniform.xy, u32(non_uniform.x), non_uniform.x);
+}
+)";
+ // `expect` aliases `src`: even with a runtime array index the output text must match the input.
+ auto* expect = src;
+
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, TextureSampleCompare) {  // textureSampleCompare(): output expected to equal input
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_depth_2d;
+
+@group(0) @binding(1) var s : sampler_comparison;
+
+fn sample_compare(coords : vec2f, depth_ref : f32) {
+ let l = textureSampleCompare(t, s, coords, depth_ref);
+}
+)";
+ // `expect` aliases `src`, so the output text must match the input for every GetParam() value.
+ auto* expect = src;
+
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, TextureSampleCompare_Offset) {  // textureSampleCompare() with a const offset
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_depth_2d;
+
+@group(0) @binding(1) var s : sampler_comparison;
+
+fn sample_compare(coords : vec2f, depth_ref : f32) {
+ const offset = vec2i(1);
+ let l = textureSampleCompare(t, s, coords, depth_ref, offset);
+}
+)";
+ // `expect` aliases `src`, so the output text must match the input for every GetParam() value.
+ auto* expect = src;
+
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, TextureSampleCompare_ArrayIndex) {  // textureSampleCompare() with i32 / u32 array index
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_depth_2d_array;
+
+@group(0) @binding(1) var s : sampler_comparison;
+
+fn sample_compare_signed(coords : vec2f, array : i32, depth_ref : f32) {
+ let l = textureSampleCompare(t, s, coords, array, depth_ref);
+}
+
+fn sample_compare_unsigned(coords : vec2f, array : u32, depth_ref : f32) {
+ let l = textureSampleCompare(t, s, coords, array, depth_ref);
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ sample_compare_signed(non_uniform.xy, i32(non_uniform.x), non_uniform.x);
+ sample_compare_unsigned(non_uniform.xy, u32(non_uniform.x), non_uniform.x);
+}
+)";
+ // `expect` aliases `src`: even with a runtime array index the output text must match the input.
+ auto* expect = src;
+
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, TextureSampleCompareLevel) {  // textureSampleCompareLevel(): output equals input
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_depth_2d;
+
+@group(0) @binding(1) var s : sampler_comparison;
+
+fn sample_compare_level(coords : vec2f, depth_ref : f32) {
+ let l = textureSampleCompareLevel(t, s, coords, depth_ref);
+}
+)";
+ // `expect` aliases `src`, so the output text must match the input for every GetParam() value.
+ auto* expect = src;
+
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, TextureSampleCompareLevel_Offset) {  // textureSampleCompareLevel() with a const offset
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_depth_2d;
+
+@group(0) @binding(1) var s : sampler_comparison;
+
+fn sample_compare_level(coords : vec2f, depth_ref : f32) {
+ const offset = vec2i(1);
+ let l = textureSampleCompareLevel(t, s, coords, depth_ref, offset);
+}
+)";
+ // `expect` aliases `src`, so the output text must match the input for every GetParam() value.
+ auto* expect = src;
+
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, TextureSampleCompareLevel_ArrayIndex) {  // i32 / u32 array index variants
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_depth_2d_array;
+
+@group(0) @binding(1) var s : sampler_comparison;
+
+fn sample_compare_level_signed(coords : vec2f, array : i32, depth_ref : f32) {
+ let l = textureSampleCompareLevel(t, s, coords, array, depth_ref);
+}
+
+fn sample_compare_level_unsigned(coords : vec2f, array : u32, depth_ref : f32) {
+ let l = textureSampleCompareLevel(t, s, coords, array, depth_ref);
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ sample_compare_level_signed(non_uniform.xy, i32(non_uniform.x), non_uniform.x);
+ sample_compare_level_unsigned(non_uniform.xy, u32(non_uniform.x), non_uniform.x);
+}
+)";
+ // `expect` aliases `src`: even with a runtime array index the output text must match the input.
+ auto* expect = src;
+
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, TextureSampleGrad) {  // textureSampleGrad(): output expected to equal input
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+fn sample_compare_level(coords : vec2f, ddx : vec2f, ddy : vec2f) {
+ let l = textureSampleGrad(t, s, coords, ddx, ddy);
+}
+)";
+ // `expect` aliases `src`, so the output text must match the input for every GetParam() value.
+ auto* expect = src;
+ // NOTE(review): the WGSL helper is named `sample_compare_level` but calls textureSampleGrad(); looks copy-pasted.
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, TextureSampleGrad_Offset) {  // textureSampleGrad() with a const offset argument
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+fn sample_compare_level(coords : vec2f, ddx : vec2f, ddy : vec2f) {
+ const offset = vec2i(1);
+ let l = textureSampleGrad(t, s, coords, ddx, ddy, offset);
+}
+)";
+ // `expect` aliases `src`, so the output text must match the input for every GetParam() value.
+ auto* expect = src;
+ // NOTE(review): the WGSL helper is named `sample_compare_level` but calls textureSampleGrad(); looks copy-pasted.
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, TextureSampleGrad_ArrayIndex) {  // textureSampleGrad() with i32 / u32 array index
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d_array<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+fn sample_grad_signed(coords : vec2f, array : i32, ddx : vec2f, ddy : vec2f) {
+ let l = textureSampleGrad(t, s, coords, array, ddx, ddy);
+}
+
+fn sample_grad_unsigned(coords : vec2f, array : u32, ddx : vec2f, ddy : vec2f) {
+ let l = textureSampleGrad(t, s, coords, array, ddx, ddy);
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ sample_grad_signed(non_uniform.xy, i32(non_uniform.x), non_uniform.xy, non_uniform.xy);
+ sample_grad_unsigned(non_uniform.xy, u32(non_uniform.x), non_uniform.xy, non_uniform.xy);
+}
+)";
+ // `expect` aliases `src`: even with a runtime array index the output text must match the input.
+ auto* expect = src;
+
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, TextureSampleLevel) {  // textureSampleLevel(): output expected to equal input
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+fn sample_compare_level(coords : vec2f, level : f32) {
+ let l = textureSampleLevel(t, s, coords, level);
+}
+)";
+ // `expect` aliases `src`, so the output text must match the input for every GetParam() value.
+ auto* expect = src;
+ // NOTE(review): the WGSL helper is named `sample_compare_level` but calls textureSampleLevel(); looks copy-pasted.
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, TextureSampleLevel_Offset) {  // textureSampleLevel() with a const offset argument
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+fn sample_compare_level(coords : vec2f, level : f32) {
+ const offset = vec2i(1);
+ let l = textureSampleLevel(t, s, coords, level, offset);
+}
+)";
+ // `expect` aliases `src`, so the output text must match the input for every GetParam() value.
+ auto* expect = src;
+ // NOTE(review): the WGSL helper is named `sample_compare_level` but calls textureSampleLevel(); looks copy-pasted.
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, TextureSampleLevel_ArrayIndex) {  // textureSampleLevel() with i32 / u32 array index
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d_array<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+fn sample_compare_level_signed(coords : vec2f, array : i32, level : f32) {
+ let l = textureSampleLevel(t, s, coords, array, level);
+}
+
+fn sample_compare_level_unsigned(coords : vec2f, array : u32, level : f32) {
+ let l = textureSampleLevel(t, s, coords, array, level);
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ sample_compare_level_signed(non_uniform.xy, i32(non_uniform.x), non_uniform.x);
+ sample_compare_level_unsigned(non_uniform.xy, u32(non_uniform.x), non_uniform.x);
+}
+)";
+ // `expect` aliases `src`: even with a runtime array index the output text must match the input.
+ auto* expect = src;
+ // NOTE(review): helpers are named `sample_compare_level_*` but call textureSampleLevel(); looks copy-pasted.
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, TextureSampleBaseClampToEdge) {  // textureSampleBaseClampToEdge(): output equals input
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_external;
+
+@group(0) @binding(1) var s : sampler;
+
+fn sample_base_clamp_to_edge(coords : vec2f) {
+ let l = textureSampleBaseClampToEdge(t, s, coords);
+}
+)";
+ // `expect` aliases `src`, so the output text must match the input for every GetParam() value.
+ auto* expect = src;
+
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, TextureStore_1D) {  // textureStore() on texture_storage_1d, i32 / u32 coords
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_storage_1d<rgba8sint, write>;
+
+fn store_signed(coords : i32, value : vec4i) {
+ textureStore(t, coords, value);
+}
+
+fn store_unsigned(coords : u32, value : vec4i) {
+ textureStore(t, coords, value);
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ store_signed(i32(non_uniform.x), vec4i(non_uniform));
+ store_unsigned(u32(non_uniform.x), vec4i(non_uniform));
+}
+)";
+ // Goldens: clamp bounds coords by (textureDimensions(t) - 1); predicate wraps the store in `if (all(coords < dims))`.
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ R"(
+@group(0) @binding(0) var t : texture_storage_1d<rgba8sint, write>;
+
+fn store_signed(coords : i32, value : vec4i) {
+ textureStore(t, clamp(coords, 0, i32((textureDimensions(t) - 1))), value);
+}
+
+fn store_unsigned(coords : u32, value : vec4i) {
+ textureStore(t, min(coords, (textureDimensions(t) - 1)), value);
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ store_signed(i32(non_uniform.x), vec4i(non_uniform));
+ store_unsigned(u32(non_uniform.x), vec4i(non_uniform));
+}
+)",
+ /* predicate */ R"(
+@group(0) @binding(0) var t : texture_storage_1d<rgba8sint, write>;
+
+fn store_signed(coords : i32, value : vec4i) {
+ let coords_1 = u32(coords);
+ if (all((coords_1 < textureDimensions(t)))) {
+ textureStore(t, coords_1, value);
+ }
+}
+
+fn store_unsigned(coords : u32, value : vec4i) {
+ let coords_2 = u32(coords);
+ if (all((coords_2 < textureDimensions(t)))) {
+ textureStore(t, coords_2, value);
+ }
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ store_signed(i32(non_uniform.x), vec4i(non_uniform));
+ store_unsigned(u32(non_uniform.x), vec4i(non_uniform));
+}
+)");
+ // Run Robustness on `src` with the config built from the test parameter, then compare text.
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, TextureStore_2D) {  // textureStore() on texture_storage_2d, vec2i / vec2u coords
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_storage_2d<rgba8sint, write>;
+
+fn store_signed(coords : vec2i, value : vec4i) {
+ textureStore(t, coords, value);
+}
+
+fn store_unsigned(coords : vec2u, value : vec4i) {
+ textureStore(t, coords, value);
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ store_signed(vec2i(non_uniform.xy), vec4i(non_uniform));
+ store_unsigned(vec2u(non_uniform.xy), vec4i(non_uniform));
+}
+)";
+ // Goldens: clamp bounds coords by (textureDimensions(t) - vec2(1)); predicate wraps the store in `if (all(coords < dims))`.
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ R"(
+@group(0) @binding(0) var t : texture_storage_2d<rgba8sint, write>;
+
+fn store_signed(coords : vec2i, value : vec4i) {
+ textureStore(t, clamp(coords, vec2(0), vec2<i32>((textureDimensions(t) - vec2(1)))), value);
+}
+
+fn store_unsigned(coords : vec2u, value : vec4i) {
+ textureStore(t, min(coords, (textureDimensions(t) - vec2(1))), value);
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ store_signed(vec2i(non_uniform.xy), vec4i(non_uniform));
+ store_unsigned(vec2u(non_uniform.xy), vec4i(non_uniform));
+}
+)",
+ /* predicate */ R"(
+@group(0) @binding(0) var t : texture_storage_2d<rgba8sint, write>;
+
+fn store_signed(coords : vec2i, value : vec4i) {
+ let coords_1 = vec2<u32>(coords);
+ if (all((coords_1 < textureDimensions(t)))) {
+ textureStore(t, coords_1, value);
+ }
+}
+
+fn store_unsigned(coords : vec2u, value : vec4i) {
+ let coords_2 = vec2<u32>(coords);
+ if (all((coords_2 < textureDimensions(t)))) {
+ textureStore(t, coords_2, value);
+ }
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ store_signed(vec2i(non_uniform.xy), vec4i(non_uniform));
+ store_unsigned(vec2u(non_uniform.xy), vec4i(non_uniform));
+}
+)");
+ // Run Robustness on `src` with the config built from the test parameter, then compare text.
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, TextureStore_2DArray) {  // textureStore() with coords plus an array index
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_storage_2d_array<rgba8sint, write>;
+
+fn store_signed(coords : vec2i, array : i32, value : vec4i) {
+ textureStore(t, coords, array, value);
+}
+
+fn store_unsigned(coords : vec2u, array : u32, value : vec4i) {
+ textureStore(t, coords, array, value);
+}
+
+fn store_mixed(coords : vec2i, array : u32, value : vec4i) {
+ textureStore(t, coords, array, value);
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ store_signed(vec2i(non_uniform.xy), i32(non_uniform.x), vec4i(non_uniform));
+ store_unsigned(vec2u(non_uniform.xy), u32(non_uniform.x), vec4i(non_uniform));
+ store_mixed(vec2i(non_uniform.xy), u32(non_uniform.x), vec4i(non_uniform));
+}
+)";
+ // Goldens bound coords by textureDimensions(t) and the array index by textureNumLayers(t), in both clamp and predicate form.
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ R"(
+@group(0) @binding(0) var t : texture_storage_2d_array<rgba8sint, write>;
+
+fn store_signed(coords : vec2i, array : i32, value : vec4i) {
+ textureStore(t, clamp(coords, vec2(0), vec2<i32>((textureDimensions(t) - vec2(1)))), clamp(array, 0, i32((textureNumLayers(t) - 1))), value);
+}
+
+fn store_unsigned(coords : vec2u, array : u32, value : vec4i) {
+ textureStore(t, min(coords, (textureDimensions(t) - vec2(1))), min(array, (textureNumLayers(t) - 1)), value);
+}
+
+fn store_mixed(coords : vec2i, array : u32, value : vec4i) {
+ textureStore(t, clamp(coords, vec2(0), vec2<i32>((textureDimensions(t) - vec2(1)))), min(array, (textureNumLayers(t) - 1)), value);
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ store_signed(vec2i(non_uniform.xy), i32(non_uniform.x), vec4i(non_uniform));
+ store_unsigned(vec2u(non_uniform.xy), u32(non_uniform.x), vec4i(non_uniform));
+ store_mixed(vec2i(non_uniform.xy), u32(non_uniform.x), vec4i(non_uniform));
+}
+)",
+ /* predicate */ R"(
+@group(0) @binding(0) var t : texture_storage_2d_array<rgba8sint, write>;
+
+fn store_signed(coords : vec2i, array : i32, value : vec4i) {
+ let coords_1 = vec2<u32>(coords);
+ let array_idx = u32(array);
+ if ((all((coords_1 < textureDimensions(t))) & (array_idx < textureNumLayers(t)))) {
+ textureStore(t, coords_1, array_idx, value);
+ }
+}
+
+fn store_unsigned(coords : vec2u, array : u32, value : vec4i) {
+ let coords_2 = vec2<u32>(coords);
+ let array_idx_1 = u32(array);
+ if ((all((coords_2 < textureDimensions(t))) & (array_idx_1 < textureNumLayers(t)))) {
+ textureStore(t, coords_2, array_idx_1, value);
+ }
+}
+
+fn store_mixed(coords : vec2i, array : u32, value : vec4i) {
+ let coords_3 = vec2<u32>(coords);
+ let array_idx_2 = u32(array);
+ if ((all((coords_3 < textureDimensions(t))) & (array_idx_2 < textureNumLayers(t)))) {
+ textureStore(t, coords_3, array_idx_2, value);
+ }
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ store_signed(vec2i(non_uniform.xy), i32(non_uniform.x), vec4i(non_uniform));
+ store_unsigned(vec2u(non_uniform.xy), u32(non_uniform.x), vec4i(non_uniform));
+ store_mixed(vec2i(non_uniform.xy), u32(non_uniform.x), vec4i(non_uniform));
+}
+)");
+ // Run Robustness on `src` with the config built from the test parameter, then compare text.
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, TextureStore_3D) {  // textureStore() on texture_storage_3d, vec3i / vec3u coords
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_storage_3d<rgba8sint, write>;
+
+fn store_signed(coords : vec3i, value : vec4i) {
+ textureStore(t, coords, value);
+}
+
+fn store_unsigned(coords : vec3u, value : vec4i) {
+ textureStore(t, coords, value);
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ store_signed(vec3i(non_uniform.xyz), vec4i(non_uniform));
+ store_unsigned(vec3u(non_uniform.xyz), vec4i(non_uniform));
+}
+)";
+ // Goldens: clamp bounds coords by (textureDimensions(t) - vec3(1)); predicate wraps the store in `if (all(coords < dims))`.
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ R"(
+@group(0) @binding(0) var t : texture_storage_3d<rgba8sint, write>;
+
+fn store_signed(coords : vec3i, value : vec4i) {
+ textureStore(t, clamp(coords, vec3(0), vec3<i32>((textureDimensions(t) - vec3(1)))), value);
+}
+
+fn store_unsigned(coords : vec3u, value : vec4i) {
+ textureStore(t, min(coords, (textureDimensions(t) - vec3(1))), value);
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ store_signed(vec3i(non_uniform.xyz), vec4i(non_uniform));
+ store_unsigned(vec3u(non_uniform.xyz), vec4i(non_uniform));
+}
+)",
+ /* predicate */ R"(
+@group(0) @binding(0) var t : texture_storage_3d<rgba8sint, write>;
+
+fn store_signed(coords : vec3i, value : vec4i) {
+ let coords_1 = vec3<u32>(coords);
+ if (all((coords_1 < textureDimensions(t)))) {
+ textureStore(t, coords_1, value);
+ }
+}
+
+fn store_unsigned(coords : vec3u, value : vec4i) {
+ let coords_2 = vec3<u32>(coords);
+ if (all((coords_2 < textureDimensions(t)))) {
+ textureStore(t, coords_2, value);
+ }
+}
+
+@fragment
+fn main(@builtin(position) non_uniform : vec4f) {
+ store_signed(vec3i(non_uniform.xyz), vec4i(non_uniform));
+ store_unsigned(vec3u(non_uniform.xyz), vec4i(non_uniform));
+}
+)");
+ // Run Robustness on `src` with the config built from the test parameter, then compare text.
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Other
+////////////////////////////////////////////////////////////////////////////////
+TEST_P(RobustnessTest, ShadowedVariable) {  // inner `a` (5 elements) shadows outer `a` (3 elements)
+ auto* src = R"(
+fn f() {
+ var a : array<f32, 3>;
+ var i : u32;
+ {
+ var a : array<f32, 5>;
+ var b : f32 = a[i];
+ }
+ var c : f32 = a[i];
+}
+)";
+ // Goldens bound the inner access by 4u and the outer access by 2u, i.e. each index uses the array it resolves to.
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ R"(
+fn f() {
+ var a : array<f32, 3>;
+ var i : u32;
+ {
+ var a : array<f32, 5>;
+ var b : f32 = a[min(i, 4u)];
+ }
+ var c : f32 = a[min(i, 2u)];
+}
+)",
+ /* predicate */ R"(
+fn f() {
+ var a : array<f32, 3>;
+ var i : u32;
+ {
+ var a : array<f32, 5>;
+ let index = i;
+ let predicate = (u32(index) <= 4u);
+ var predicated_expr : f32;
+ if (predicate) {
+ predicated_expr = a[index];
+ }
+ var b : f32 = predicated_expr;
+ }
+ let index_1 = i;
+ let predicate_1 = (u32(index_1) <= 2u);
+ var predicated_expr_1 : f32;
+ if (predicate_1) {
+ predicated_expr_1 = a[index_1];
+ }
+ var c : f32 = predicated_expr_1;
+}
+)");
+ // Run Robustness on `src` with the config built from the test parameter, then compare text.
+ auto got = Run<Robustness>(src, Config(GetParam()));
+ EXPECT_EQ(expect, str(got));
+}
+
+// Check that existing uses of min() and arrayLength() do not get renamed.
+TEST_P(RobustnessTest, DontRenameSymbols) {
+ auto* src = R"(
+struct S {
+ a : f32,
+ b : array<f32>,
+}
+
+@group(0) @binding(0) var<storage, read> s : S;
+
+const c : u32 = 1u;
+
+fn f() {
+ let b : f32 = s.b[c];
+ let x : i32 = min(1, 2);
+ let y : u32 = arrayLength(&(s.b));
+}
+)";
+ // Goldens keep the user's min() / arrayLength() lines verbatim; the inserted bound is (arrayLength(&(s.b)) - 1u).
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ R"(
+struct S {
+ a : f32,
+ b : array<f32>,
+}
+
+@group(0) @binding(0) var<storage, read> s : S;
+
+const c : u32 = 1u;
+
+fn f() {
+ let b : f32 = s.b[min(c, (arrayLength(&(s.b)) - 1u))];
+ let x : i32 = min(1, 2);
+ let y : u32 = arrayLength(&(s.b));
+}
+)",
+ /* predicate */ R"(
+struct S {
+ a : f32,
+ b : array<f32>,
+}
+
+@group(0) @binding(0) var<storage, read> s : S;
+
+const c : u32 = 1u;
+
+fn f() {
+ let index = c;
+ let predicate = (u32(index) <= (arrayLength(&(s.b)) - 1u));
+ var predicated_expr : f32;
+ if (predicate) {
+ predicated_expr = s.b[index];
+ }
+ let b : f32 = predicated_expr;
+ let x : i32 = min(1, 2);
+ let y : u32 = arrayLength(&(s.b));
+}
+)");
+ // Run Robustness on `src` with the config built from the test parameter, then compare text.
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, WorkgroupOverrideCount) {  // workgroup array sized by an `override`
+ auto* src = R"(
+override N = 123;
+
+var<workgroup> w : array<f32, N>;
+
+fn f() {
+ var b : f32 = w[1i];
+}
+)";
+ // The clamp and predicate goldens are an error message rather than WGSL; the ignore golden is `src`.
+ auto* expect =
+ Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */
+ R"(error: array size is an override-expression, when expected a constant-expression.
+Was the SubstituteOverride transform run?)",
+ /* predicate */
+ R"(error: array size is an override-expression, when expected a constant-expression.
+Was the SubstituteOverride transform run?)");
+ // Run Robustness on `src` with the config built from the test parameter, then compare text.
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// atomic predication
+////////////////////////////////////////////////////////////////////////////////
+
+TEST_P(RobustnessTest, AtomicLoad) {  // atomicLoad() through an indexed storage array of atomics
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read_write> a : array<atomic<i32>, 4>;
+
+fn f() {
+ let i = 0;
+ let r = atomicLoad(&(a[i]));
+}
+)";
+ // Goldens: clamp indexes with min(u32(i), 3u); predicate guards the whole atomicLoad() call and reads `predicated_value`.
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ R"(
+@group(0) @binding(0) var<storage, read_write> a : array<atomic<i32>, 4>;
+
+fn f() {
+ let i = 0;
+ let r = atomicLoad(&(a[min(u32(i), 3u)]));
+}
+)",
+ /* predicate */ R"(
+@group(0) @binding(0) var<storage, read_write> a : array<atomic<i32>, 4>;
+
+fn f() {
+ let i = 0;
+ let index = i;
+ let predicate = (u32(index) <= 3u);
+ var predicated_value : i32;
+ if (predicate) {
+ predicated_value = atomicLoad(&(a[index]));
+ }
+ let r = predicated_value;
+}
+)");
+ // Run Robustness on `src` with the config built from the test parameter, then compare text.
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, AtomicAnd) {  // atomicAnd() call statement on an indexed storage array
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read_write> a : array<atomic<i32>, 4>;
+
+fn f() {
+ let i = 0;
+ atomicAnd(&(a[i]), 10);
+}
+)";
+ // Goldens: clamp indexes with min(u32(i), 3u); predicate wraps the atomicAnd() statement in `if (predicate)`.
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ R"(
+@group(0) @binding(0) var<storage, read_write> a : array<atomic<i32>, 4>;
+
+fn f() {
+ let i = 0;
+ atomicAnd(&(a[min(u32(i), 3u)]), 10);
+}
+)",
+ /* predicate */ R"(
+@group(0) @binding(0) var<storage, read_write> a : array<atomic<i32>, 4>;
+
+fn f() {
+ let i = 0;
+ let index = i;
+ let predicate = (u32(index) <= 3u);
+ if (predicate) {
+ atomicAnd(&(a[index]), 10);
+ }
+}
+)");
+ // Run Robustness on `src` with the config built from the test parameter, then compare text.
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// workgroupUniformLoad
+////////////////////////////////////////////////////////////////////////////////
+
+TEST_P(RobustnessTest, WorkgroupUniformLoad) {  // workgroupUniformLoad() through an indexed pointer
+ auto* src = R"(
+var<workgroup> a : array<u32, 32>;
+
+@compute @workgroup_size(1)
+fn f() {
+ let i = 1;
+ let p = &(a[i]);
+ let v = workgroupUniformLoad(p);
+}
+)";
+ // Predicate golden: the guarded load gets an `else` branch that calls workgroupBarrier(); clamp uses min(u32(i), 31u).
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ R"(
+var<workgroup> a : array<u32, 32>;
+
+@compute @workgroup_size(1)
+fn f() {
+ let i = 1;
+ let p = &(a[min(u32(i), 31u)]);
+ let v = workgroupUniformLoad(p);
+}
+)",
+ /* predicate */ R"(
+var<workgroup> a : array<u32, 32>;
+
+@compute @workgroup_size(1)
+fn f() {
+ let i = 1;
+ let index = i;
+ let predicate = (u32(index) <= 31u);
+ let p = &(a[index]);
+ var predicated_value : u32;
+ if (predicate) {
+ predicated_value = workgroupUniformLoad(p);
+ } else {
+ workgroupBarrier();
+ }
+ let v = predicated_value;
+}
+)");
+ // Run Robustness on `src` with the config built from the test parameter, then compare text.
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Usage in loops
+////////////////////////////////////////////////////////////////////////////////
+
+TEST_P(RobustnessTest, ArrayVal_ForLoopInit) {  // array index inside a for-loop initializer
+ auto* src = R"(
+fn f() {
+ let a = array<i32, 3>();
+ var v = 1;
+ for(var i = a[v]; (i < 3); i++) {
+ var in_loop = 42;
+ }
+}
+)";
+ // Clamp golden keeps the `for`; predicate golden uses a block + `loop`, with the index check ahead of the loop.
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ R"(
+fn f() {
+ let a = array<i32, 3>();
+ var v = 1;
+ for(var i = a[min(u32(v), 2u)]; (i < 3); i++) {
+ var in_loop = 42;
+ }
+}
+)",
+ /* predicate */ R"(
+fn f() {
+ let a = array<i32, 3>();
+ var v = 1;
+ {
+ let index = v;
+ let predicate = (u32(index) <= 2u);
+ var predicated_expr : i32;
+ if (predicate) {
+ predicated_expr = a[index];
+ }
+ var i = predicated_expr;
+ loop {
+ if (!((i < 3))) {
+ break;
+ }
+ {
+ var in_loop = 42;
+ }
+
+ continuing {
+ i++;
+ }
+ }
+ }
+}
+)");
+ // Run Robustness on `src` with the config built from the test parameter, then compare text.
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, ArrayVal_ForLoopCond) {  // array index inside a for-loop condition
+ auto* src = R"(
+fn f() {
+ let a = array<i32, 3>();
+ var v = 1;
+ for(var i = 0; (i < a[v]); i++) {
+ var in_loop = 42;
+ }
+}
+)";
+ // Clamp golden keeps the `for`; predicate golden uses `loop`, with the index check at the top of the body before the break test.
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ R"(
+fn f() {
+ let a = array<i32, 3>();
+ var v = 1;
+ for(var i = 0; (i < a[min(u32(v), 2u)]); i++) {
+ var in_loop = 42;
+ }
+}
+)",
+ /* predicate */ R"(
+fn f() {
+ let a = array<i32, 3>();
+ var v = 1;
+ {
+ var i = 0;
+ loop {
+ let index = v;
+ let predicate = (u32(index) <= 2u);
+ var predicated_expr : i32;
+ if (predicate) {
+ predicated_expr = a[index];
+ }
+ if (!((i < predicated_expr))) {
+ break;
+ }
+ {
+ var in_loop = 42;
+ }
+
+ continuing {
+ i++;
+ }
+ }
+ }
+}
+)");
+ // Run Robustness on `src` with the config built from the test parameter, then compare text.
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, ArrayVal_ForLoopCont) {  // array index inside a for-loop continuing expression
+ auto* src = R"(
+fn f() {
+ let a = array<i32, 3>();
+ for(var i = 0; (i < 5); i = a[i]) {
+ var in_loop = 42;
+ }
+}
+)";
+ // Clamp golden keeps the `for`; predicate golden uses `loop`, with the index check inside the `continuing` block.
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ R"(
+fn f() {
+ let a = array<i32, 3>();
+ for(var i = 0; (i < 5); i = a[min(u32(i), 2u)]) {
+ var in_loop = 42;
+ }
+}
+)",
+ /* predicate */ R"(
+fn f() {
+ let a = array<i32, 3>();
+ {
+ var i = 0;
+ loop {
+ if (!((i < 5))) {
+ break;
+ }
+ {
+ var in_loop = 42;
+ }
+
+ continuing {
+ let index = i;
+ let predicate = (u32(index) <= 2u);
+ var predicated_expr : i32;
+ if (predicate) {
+ predicated_expr = a[index];
+ }
+ i = predicated_expr;
+ }
+ }
+ }
+}
+)");
+ // Run Robustness on `src` with the config built from the test parameter, then compare text.
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, TextureLoad_ForLoopInit) {  // textureLoad() inside a for-loop initializer
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d<i32>;
+
+fn f() {
+ var coords = vec2(1);
+ var level = 1;
+ for(var i = textureLoad(t, coords, level).x; (i < 3); i++) {
+ var in_loop = 42;
+ }
+}
+)";
+ // Both the clamp and predicate goldens use a block + `loop`, with the `level_idx` let placed ahead of the loop.
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ R"(
+@group(0) @binding(0) var t : texture_2d<i32>;
+
+fn f() {
+ var coords = vec2(1);
+ var level = 1;
+ {
+ let level_idx = min(u32(level), (textureNumLevels(t) - 1));
+ var i = textureLoad(t, clamp(coords, vec2(0), vec2<i32>((textureDimensions(t, level_idx) - vec2(1)))), level_idx).x;
+ loop {
+ if (!((i < 3))) {
+ break;
+ }
+ {
+ var in_loop = 42;
+ }
+
+ continuing {
+ i++;
+ }
+ }
+ }
+}
+)",
+ /* predicate */ R"(
+@group(0) @binding(0) var t : texture_2d<i32>;
+
+fn f() {
+ var coords = vec2(1);
+ var level = 1;
+ {
+ let level_idx = u32(level);
+ let num_levels = textureNumLevels(t);
+ let coords_1 = vec2<u32>(coords);
+ var predicated_value : vec4<i32>;
+ if (((level_idx < num_levels) & all((coords_1 < textureDimensions(t, min(level_idx, (num_levels - 1))))))) {
+ predicated_value = textureLoad(t, coords_1, level_idx);
+ }
+ var i = predicated_value.x;
+ loop {
+ if (!((i < 3))) {
+ break;
+ }
+ {
+ var in_loop = 42;
+ }
+
+ continuing {
+ i++;
+ }
+ }
+ }
+}
+)");
+ // Run Robustness on `src` with the config built from the test parameter, then compare text.
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, TextureLoad_ForLoopCond) {  // textureLoad() inside a for-loop condition
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d<i32>;
+
+fn f() {
+ var coords = vec2(1);
+ var level = 1;
+ for(var i = 0; (i < textureLoad(t, coords, level).x); i++) {
+ var in_loop = 42;
+ }
+}
+)";
+ // Both the clamp and predicate goldens use `loop`, with the `level_idx` let at the top of the body before the break test.
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ R"(
+@group(0) @binding(0) var t : texture_2d<i32>;
+
+fn f() {
+ var coords = vec2(1);
+ var level = 1;
+ {
+ var i = 0;
+ loop {
+ let level_idx = min(u32(level), (textureNumLevels(t) - 1));
+ if (!((i < textureLoad(t, clamp(coords, vec2(0), vec2<i32>((textureDimensions(t, level_idx) - vec2(1)))), level_idx).x))) {
+ break;
+ }
+ {
+ var in_loop = 42;
+ }
+
+ continuing {
+ i++;
+ }
+ }
+ }
+}
+)",
+ /* predicate */ R"(
+@group(0) @binding(0) var t : texture_2d<i32>;
+
+fn f() {
+ var coords = vec2(1);
+ var level = 1;
+ {
+ var i = 0;
+ loop {
+ let level_idx = u32(level);
+ let num_levels = textureNumLevels(t);
+ let coords_1 = vec2<u32>(coords);
+ var predicated_value : vec4<i32>;
+ if (((level_idx < num_levels) & all((coords_1 < textureDimensions(t, min(level_idx, (num_levels - 1))))))) {
+ predicated_value = textureLoad(t, coords_1, level_idx);
+ }
+ if (!((i < predicated_value.x))) {
+ break;
+ }
+ {
+ var in_loop = 42;
+ }
+
+ continuing {
+ i++;
+ }
+ }
+ }
+}
+)");
+ // Run Robustness on `src` with the config built from the test parameter, then compare text.
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, TextureLoad_ForLoopCont) {  // textureLoad() inside a for-loop continuing expression
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d<i32>;
+
+fn f() {
+ var level = 1;
+ for(var i = 0; (i < 5); i = textureLoad(t, vec2(i), i).x) {
+ var in_loop = 42;
+ }
+}
+)";
+ // Both goldens use `loop` with the bounds logic inside `continuing`. NOTE(review): `var level` is never read in `src`.
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ R"(
+@group(0) @binding(0) var t : texture_2d<i32>;
+
+fn f() {
+ var level = 1;
+ {
+ var i = 0;
+ loop {
+ if (!((i < 5))) {
+ break;
+ }
+ {
+ var in_loop = 42;
+ }
+
+ continuing {
+ let level_idx = min(u32(i), (textureNumLevels(t) - 1));
+ i = textureLoad(t, clamp(vec2(i), vec2(0), vec2<i32>((textureDimensions(t, level_idx) - vec2(1)))), level_idx).x;
+ }
+ }
+ }
+}
+)",
+ /* predicate */ R"(
+@group(0) @binding(0) var t : texture_2d<i32>;
+
+fn f() {
+ var level = 1;
+ {
+ var i = 0;
+ loop {
+ if (!((i < 5))) {
+ break;
+ }
+ {
+ var in_loop = 42;
+ }
+
+ continuing {
+ let level_idx = u32(i);
+ let num_levels = textureNumLevels(t);
+ let coords = vec2<u32>(vec2(i));
+ var predicated_value : vec4<i32>;
+ if (((level_idx < num_levels) & all((coords < textureDimensions(t, min(level_idx, (num_levels - 1))))))) {
+ predicated_value = textureLoad(t, coords, level_idx);
+ }
+ i = predicated_value.x;
+ }
+ }
+ }
+}
+)");
+ // Run Robustness on `src` with the config built from the test parameter, then compare text.
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, TextureStore_ForLoopInit) {
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_storage_1d<rgba8sint, write>;
+
+fn f() {
+ var coords = 1;
+ var value = vec4(1);
+ var i = 0;
+ for(textureStore(t, coords, value); (i < 3); i++) {
+ var in_loop = 42;
+ }
+}
+)";
+
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ R"(
+@group(0) @binding(0) var t : texture_storage_1d<rgba8sint, write>;
+
+fn f() {
+ var coords = 1;
+ var value = vec4(1);
+ var i = 0;
+ for(textureStore(t, clamp(coords, 0, i32((textureDimensions(t) - 1))), value); (i < 3); i++) {
+ var in_loop = 42;
+ }
+}
+)",
+ /* predicate */ R"(
+@group(0) @binding(0) var t : texture_storage_1d<rgba8sint, write>;
+
+fn f() {
+ var coords = 1;
+ var value = vec4(1);
+ var i = 0;
+ {
+ let coords_1 = u32(coords);
+ if (all((coords_1 < textureDimensions(t)))) {
+ textureStore(t, coords_1, value);
+ }
+ loop {
+ if (!((i < 3))) {
+ break;
+ }
+ {
+ var in_loop = 42;
+ }
+
+ continuing {
+ i++;
+ }
+ }
+ }
+}
+)");
+
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, TextureStore_ForLoopCont) {
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_storage_1d<rgba8sint, write>;
+
+fn f() {
+ var level = 1;
+ var value = vec4(1);
+ for(var i = 0; (i < 3); textureStore(t, i, value)) {
+ var in_loop = 42;
+ }
+}
+)";
+
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ R"(
+@group(0) @binding(0) var t : texture_storage_1d<rgba8sint, write>;
+
+fn f() {
+ var level = 1;
+ var value = vec4(1);
+ for(var i = 0; (i < 3); textureStore(t, clamp(i, 0, i32((textureDimensions(t) - 1))), value)) {
+ var in_loop = 42;
+ }
+}
+)",
+ /* predicate */ R"(
+@group(0) @binding(0) var t : texture_storage_1d<rgba8sint, write>;
+
+fn f() {
+ var level = 1;
+ var value = vec4(1);
+ {
+ var i = 0;
+ loop {
+ if (!((i < 3))) {
+ break;
+ }
+ {
+ var in_loop = 42;
+ }
+
+ continuing {
+ let coords = u32(i);
+ if (all((coords < textureDimensions(t)))) {
+ textureStore(t, coords, value);
+ }
+ }
+ }
+ }
+}
+)");
+
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, AtomicXor_ForLoopInit) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read_write> a : array<atomic<i32>, 4>;
+
+fn f() {
+ var i = 0;
+ for(atomicXor(&(a[i]), 4); (i < 3); i++) {
+ var in_loop = 42;
+ }
+}
+)";
+
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ R"(
+@group(0) @binding(0) var<storage, read_write> a : array<atomic<i32>, 4>;
+
+fn f() {
+ var i = 0;
+ for(atomicXor(&(a[min(u32(i), 3u)]), 4); (i < 3); i++) {
+ var in_loop = 42;
+ }
+}
+)",
+ /* predicate */ R"(
+@group(0) @binding(0) var<storage, read_write> a : array<atomic<i32>, 4>;
+
+fn f() {
+ var i = 0;
+ {
+ let index = i;
+ let predicate = (u32(index) <= 3u);
+ if (predicate) {
+ atomicXor(&(a[index]), 4);
+ }
+ loop {
+ if (!((i < 3))) {
+ break;
+ }
+ {
+ var in_loop = 42;
+ }
+
+ continuing {
+ i++;
+ }
+ }
+ }
+}
+)");
+
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, AtomicXor_ForLoopCond) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read_write> a : array<atomic<i32>, 4>;
+
+fn f() {
+ for(var i = 0; (i < atomicXor(&(a[i]), 4)); i++) {
+ var in_loop = 42;
+ }
+}
+)";
+
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ R"(
+@group(0) @binding(0) var<storage, read_write> a : array<atomic<i32>, 4>;
+
+fn f() {
+ for(var i = 0; (i < atomicXor(&(a[min(u32(i), 3u)]), 4)); i++) {
+ var in_loop = 42;
+ }
+}
+)",
+ /* predicate */ R"(
+@group(0) @binding(0) var<storage, read_write> a : array<atomic<i32>, 4>;
+
+fn f() {
+ {
+ var i = 0;
+ loop {
+ let index = i;
+ let predicate = (u32(index) <= 3u);
+ var predicated_value : i32;
+ if (predicate) {
+ predicated_value = atomicXor(&(a[index]), 4);
+ }
+ if (!((i < predicated_value))) {
+ break;
+ }
+ {
+ var in_loop = 42;
+ }
+
+ continuing {
+ i++;
+ }
+ }
+ }
+}
+)");
+
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, AtomicXor_ForLoopCont) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read_write> a : array<atomic<i32>, 4>;
+
+fn f() {
+ for(var i = 0; (i < 5); i = atomicXor(&(a[i]), 4)) {
+ var in_loop = 42;
+ }
+}
+)";
+
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ R"(
+@group(0) @binding(0) var<storage, read_write> a : array<atomic<i32>, 4>;
+
+fn f() {
+ for(var i = 0; (i < 5); i = atomicXor(&(a[min(u32(i), 3u)]), 4)) {
+ var in_loop = 42;
+ }
+}
+)",
+ /* predicate */ R"(
+@group(0) @binding(0) var<storage, read_write> a : array<atomic<i32>, 4>;
+
+fn f() {
+ {
+ var i = 0;
+ loop {
+ if (!((i < 5))) {
+ break;
+ }
+ {
+ var in_loop = 42;
+ }
+
+ continuing {
+ let index = i;
+ let predicate = (u32(index) <= 3u);
+ var predicated_value : i32;
+ if (predicate) {
+ predicated_value = atomicXor(&(a[index]), 4);
+ }
+ i = predicated_value;
+ }
+ }
+ }
+}
+)");
+
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// User pointer parameters
+////////////////////////////////////////////////////////////////////////////////
+
+TEST_P(RobustnessTest, Read_PrivatePointerParameter_IndexWithConstant) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+var<private> a : array<i32, 4>;
+
+fn x(pre : i32, p : ptr<private, i32>, post : i32) -> i32 {
+ return ((pre + *(p)) + post);
+}
+
+fn y(pre : i32, p : ptr<private, i32>, post : i32) -> i32 {
+ return x(pre, p, post);
+}
+
+fn z() {
+ y(1, &(a[1]), 2);
+}
+)";
+
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ src,
+ /* predicate */ R"(
+enable chromium_experimental_full_ptr_parameters;
+
+var<private> a : array<i32, 4>;
+
+fn x(pre : i32, p : ptr<private, i32>, p_predicate : bool, post : i32) -> i32 {
+ var predicated_expr : i32;
+ if (p_predicate) {
+ predicated_expr = *(p);
+ }
+ return ((pre + predicated_expr) + post);
+}
+
+fn y(pre : i32, p : ptr<private, i32>, p_predicate_1 : bool, post : i32) -> i32 {
+ return x(pre, p, p_predicate_1, post);
+}
+
+fn z() {
+ y(1, &(a[1]), true, 2);
+}
+)");
+
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, Read_WorkgroupPointerParameter_IndexWithConstant) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+var<workgroup> a : array<i32, 4>;
+
+fn x(pre : i32, p : ptr<workgroup, i32>, post : i32) -> i32 {
+ return ((pre + *(p)) + post);
+}
+
+fn y(pre : i32, p : ptr<workgroup, i32>, post : i32) -> i32 {
+ return x(pre, p, post);
+}
+
+fn z() {
+ y(1, &(a[1]), 2);
+}
+)";
+
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ src,
+ /* predicate */ R"(
+enable chromium_experimental_full_ptr_parameters;
+
+var<workgroup> a : array<i32, 4>;
+
+fn x(pre : i32, p : ptr<workgroup, i32>, p_predicate : bool, post : i32) -> i32 {
+ var predicated_expr : i32;
+ if (p_predicate) {
+ predicated_expr = *(p);
+ }
+ return ((pre + predicated_expr) + post);
+}
+
+fn y(pre : i32, p : ptr<workgroup, i32>, p_predicate_1 : bool, post : i32) -> i32 {
+ return x(pre, p, p_predicate_1, post);
+}
+
+fn z() {
+ y(1, &(a[1]), true, 2);
+}
+)");
+
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, Read_UniformPointerParameter_IndexWithConstant) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<uniform> a : array<vec4i, 4>;
+
+fn x(pre : vec4i, p : ptr<uniform, vec4i>, post : vec4i) -> vec4i {
+ return ((pre + *(p)) + post);
+}
+
+fn y(pre : vec4i, p : ptr<uniform, vec4i>, post : vec4i) -> vec4i {
+ return x(pre, p, post);
+}
+
+fn z() {
+ y(vec4(1), &(a[1]), vec4(2));
+}
+)";
+
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ src,
+ /* predicate */ R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<uniform> a : array<vec4i, 4>;
+
+fn x(pre : vec4i, p : ptr<uniform, vec4i>, p_predicate : bool, post : vec4i) -> vec4i {
+ var predicated_expr : vec4<i32>;
+ if (p_predicate) {
+ predicated_expr = *(p);
+ }
+ return ((pre + predicated_expr) + post);
+}
+
+fn y(pre : vec4i, p : ptr<uniform, vec4i>, p_predicate_1 : bool, post : vec4i) -> vec4i {
+ return x(pre, p, p_predicate_1, post);
+}
+
+fn z() {
+ y(vec4(1), &(a[1]), true, vec4(2));
+}
+)");
+
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, Read_StoragePointerParameter_IndexWithConstant) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<storage> a : array<vec4i, 4>;
+
+fn x(pre : vec4i, p : ptr<storage, vec4i>, post : vec4i) -> vec4i {
+ return ((pre + *(p)) + post);
+}
+
+fn y(pre : vec4i, p : ptr<storage, vec4i>, post : vec4i) -> vec4i {
+ return x(pre, p, post);
+}
+
+fn z() {
+ y(vec4(1), &(a[1]), vec4(2));
+}
+)";
+
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ src,
+ /* predicate */ R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<storage> a : array<vec4i, 4>;
+
+fn x(pre : vec4i, p : ptr<storage, vec4i>, p_predicate : bool, post : vec4i) -> vec4i {
+ var predicated_expr : vec4<i32>;
+ if (p_predicate) {
+ predicated_expr = *(p);
+ }
+ return ((pre + predicated_expr) + post);
+}
+
+fn y(pre : vec4i, p : ptr<storage, vec4i>, p_predicate_1 : bool, post : vec4i) -> vec4i {
+ return x(pre, p, p_predicate_1, post);
+}
+
+fn z() {
+ y(vec4(1), &(a[1]), true, vec4(2));
+}
+)");
+
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, Read_FunctionPointerParameter_IndexWithConstant) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+fn x(pre : i32, p : ptr<function, i32>, post : i32) -> i32 {
+ return ((pre + *(p)) + post);
+}
+
+fn y(pre : i32, p : ptr<function, i32>, post : i32) -> i32 {
+ return x(pre, p, post);
+}
+
+fn z() {
+ var a : array<i32, 4>;
+ y(1, &(a[1]), 2);
+}
+)";
+
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ src,
+ /* predicate */ R"(
+enable chromium_experimental_full_ptr_parameters;
+
+fn x(pre : i32, p : ptr<function, i32>, p_predicate : bool, post : i32) -> i32 {
+ var predicated_expr : i32;
+ if (p_predicate) {
+ predicated_expr = *(p);
+ }
+ return ((pre + predicated_expr) + post);
+}
+
+fn y(pre : i32, p : ptr<function, i32>, p_predicate_1 : bool, post : i32) -> i32 {
+ return x(pre, p, p_predicate_1, post);
+}
+
+fn z() {
+ var a : array<i32, 4>;
+ y(1, &(a[1]), true, 2);
+}
+)");
+
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, Read_PrivatePointerParameter_IndexWithLet) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+var<private> a : array<i32, 4>;
+
+fn x(pre : i32, p : ptr<private, i32>, post : i32) -> i32 {
+ return ((pre + *(p)) + post);
+}
+
+fn y(pre : i32, p : ptr<private, i32>, post : i32) -> i32 {
+ return x(pre, p, post);
+}
+
+fn z() {
+ let i = 0;
+ y(1, &(a[i]), 2);
+}
+)";
+
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ R"(
+enable chromium_experimental_full_ptr_parameters;
+
+var<private> a : array<i32, 4>;
+
+fn x(pre : i32, p : ptr<private, i32>, post : i32) -> i32 {
+ return ((pre + *(p)) + post);
+}
+
+fn y(pre : i32, p : ptr<private, i32>, post : i32) -> i32 {
+ return x(pre, p, post);
+}
+
+fn z() {
+ let i = 0;
+ y(1, &(a[min(u32(i), 3u)]), 2);
+}
+)",
+ /* predicate */ R"(
+enable chromium_experimental_full_ptr_parameters;
+
+var<private> a : array<i32, 4>;
+
+fn x(pre : i32, p : ptr<private, i32>, p_predicate : bool, post : i32) -> i32 {
+ var predicated_expr : i32;
+ if (p_predicate) {
+ predicated_expr = *(p);
+ }
+ return ((pre + predicated_expr) + post);
+}
+
+fn y(pre : i32, p : ptr<private, i32>, p_predicate_1 : bool, post : i32) -> i32 {
+ return x(pre, p, p_predicate_1, post);
+}
+
+fn z() {
+ let i = 0;
+ let index = i;
+ let predicate = (u32(index) <= 3u);
+ y(1, &(a[index]), predicate, 2);
+}
+)");
+
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, Read_WorkgroupPointerParameter_IndexWithLet) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+var<workgroup> a : array<i32, 4>;
+
+fn x(pre : i32, p : ptr<workgroup, i32>, post : i32) -> i32 {
+ return ((pre + *(p)) + post);
+}
+
+fn y(pre : i32, p : ptr<workgroup, i32>, post : i32) -> i32 {
+ return x(pre, p, post);
+}
+
+fn z() {
+ let i = 0;
+ y(1, &(a[i]), 2);
+}
+)";
+
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ R"(
+enable chromium_experimental_full_ptr_parameters;
+
+var<workgroup> a : array<i32, 4>;
+
+fn x(pre : i32, p : ptr<workgroup, i32>, post : i32) -> i32 {
+ return ((pre + *(p)) + post);
+}
+
+fn y(pre : i32, p : ptr<workgroup, i32>, post : i32) -> i32 {
+ return x(pre, p, post);
+}
+
+fn z() {
+ let i = 0;
+ y(1, &(a[min(u32(i), 3u)]), 2);
+}
+)",
+ /* predicate */ R"(
+enable chromium_experimental_full_ptr_parameters;
+
+var<workgroup> a : array<i32, 4>;
+
+fn x(pre : i32, p : ptr<workgroup, i32>, p_predicate : bool, post : i32) -> i32 {
+ var predicated_expr : i32;
+ if (p_predicate) {
+ predicated_expr = *(p);
+ }
+ return ((pre + predicated_expr) + post);
+}
+
+fn y(pre : i32, p : ptr<workgroup, i32>, p_predicate_1 : bool, post : i32) -> i32 {
+ return x(pre, p, p_predicate_1, post);
+}
+
+fn z() {
+ let i = 0;
+ let index = i;
+ let predicate = (u32(index) <= 3u);
+ y(1, &(a[index]), predicate, 2);
+}
+)");
+
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, Read_UniformPointerParameter_IndexWithLet) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<uniform> a : array<vec4i, 4>;
+
+fn x(pre : vec4i, p : ptr<uniform, vec4i>, post : vec4i) -> vec4i {
+ return ((pre + *(p)) + post);
+}
+
+fn y(pre : vec4i, p : ptr<uniform, vec4i>, post : vec4i) -> vec4i {
+ return x(pre, p, post);
+}
+
+fn z() {
+ let i = 0;
+ y(vec4(1), &(a[i]), vec4(2));
+}
+)";
+
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<uniform> a : array<vec4i, 4>;
+
+fn x(pre : vec4i, p : ptr<uniform, vec4i>, post : vec4i) -> vec4i {
+ return ((pre + *(p)) + post);
+}
+
+fn y(pre : vec4i, p : ptr<uniform, vec4i>, post : vec4i) -> vec4i {
+ return x(pre, p, post);
+}
+
+fn z() {
+ let i = 0;
+ y(vec4(1), &(a[min(u32(i), 3u)]), vec4(2));
+}
+)",
+ /* predicate */ R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<uniform> a : array<vec4i, 4>;
+
+fn x(pre : vec4i, p : ptr<uniform, vec4i>, p_predicate : bool, post : vec4i) -> vec4i {
+ var predicated_expr : vec4<i32>;
+ if (p_predicate) {
+ predicated_expr = *(p);
+ }
+ return ((pre + predicated_expr) + post);
+}
+
+fn y(pre : vec4i, p : ptr<uniform, vec4i>, p_predicate_1 : bool, post : vec4i) -> vec4i {
+ return x(pre, p, p_predicate_1, post);
+}
+
+fn z() {
+ let i = 0;
+ let index = i;
+ let predicate = (u32(index) <= 3u);
+ y(vec4(1), &(a[index]), predicate, vec4(2));
+}
+)");
+
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, Read_StoragePointerParameter_IndexWithLet) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<storage> a : array<vec4i, 4>;
+
+fn x(pre : vec4i, p : ptr<storage, vec4i>, post : vec4i) -> vec4i {
+ return ((pre + *(p)) + post);
+}
+
+fn y(pre : vec4i, p : ptr<storage, vec4i>, post : vec4i) -> vec4i {
+ return x(pre, p, post);
+}
+
+fn z() {
+ let i = 0;
+ y(vec4(1), &(a[i]), vec4(2));
+}
+)";
+
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<storage> a : array<vec4i, 4>;
+
+fn x(pre : vec4i, p : ptr<storage, vec4i>, post : vec4i) -> vec4i {
+ return ((pre + *(p)) + post);
+}
+
+fn y(pre : vec4i, p : ptr<storage, vec4i>, post : vec4i) -> vec4i {
+ return x(pre, p, post);
+}
+
+fn z() {
+ let i = 0;
+ y(vec4(1), &(a[min(u32(i), 3u)]), vec4(2));
+}
+)",
+ /* predicate */ R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<storage> a : array<vec4i, 4>;
+
+fn x(pre : vec4i, p : ptr<storage, vec4i>, p_predicate : bool, post : vec4i) -> vec4i {
+ var predicated_expr : vec4<i32>;
+ if (p_predicate) {
+ predicated_expr = *(p);
+ }
+ return ((pre + predicated_expr) + post);
+}
+
+fn y(pre : vec4i, p : ptr<storage, vec4i>, p_predicate_1 : bool, post : vec4i) -> vec4i {
+ return x(pre, p, p_predicate_1, post);
+}
+
+fn z() {
+ let i = 0;
+ let index = i;
+ let predicate = (u32(index) <= 3u);
+ y(vec4(1), &(a[index]), predicate, vec4(2));
+}
+)");
+
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, Read_FunctionPointerParameter_IndexWithLet) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+fn x(pre : i32, p : ptr<function, i32>, post : i32) -> i32 {
+ return ((pre + *(p)) + post);
+}
+
+fn y(pre : i32, p : ptr<function, i32>, post : i32) -> i32 {
+ return x(pre, p, post);
+}
+
+fn z() {
+ var a : array<i32, 4>;
+ let i = 0;
+ y(1, &(a[i]), 2);
+}
+)";
+
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ R"(
+enable chromium_experimental_full_ptr_parameters;
+
+fn x(pre : i32, p : ptr<function, i32>, post : i32) -> i32 {
+ return ((pre + *(p)) + post);
+}
+
+fn y(pre : i32, p : ptr<function, i32>, post : i32) -> i32 {
+ return x(pre, p, post);
+}
+
+fn z() {
+ var a : array<i32, 4>;
+ let i = 0;
+ y(1, &(a[min(u32(i), 3u)]), 2);
+}
+)",
+ /* predicate */ R"(
+enable chromium_experimental_full_ptr_parameters;
+
+fn x(pre : i32, p : ptr<function, i32>, p_predicate : bool, post : i32) -> i32 {
+ var predicated_expr : i32;
+ if (p_predicate) {
+ predicated_expr = *(p);
+ }
+ return ((pre + predicated_expr) + post);
+}
+
+fn y(pre : i32, p : ptr<function, i32>, p_predicate_1 : bool, post : i32) -> i32 {
+ return x(pre, p, p_predicate_1, post);
+}
+
+fn z() {
+ var a : array<i32, 4>;
+ let i = 0;
+ let index = i;
+ let predicate = (u32(index) <= 3u);
+ y(1, &(a[index]), predicate, 2);
+}
+)");
+
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, Write_PrivatePointerParameter_IndexWithConstant) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+var<private> a : array<i32, 4>;
+
+fn x(pre : i32, p : ptr<private, i32>, post : i32) {
+ *(p) = (pre + post);
+}
+
+fn y(pre : i32, p : ptr<private, i32>, post : i32) {
+ x(pre, p, post);
+}
+
+fn z() {
+ y(1, &(a[1]), 2);
+}
+)";
+
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ src,
+ /* predicate */ R"(
+enable chromium_experimental_full_ptr_parameters;
+
+var<private> a : array<i32, 4>;
+
+fn x(pre : i32, p : ptr<private, i32>, p_predicate : bool, post : i32) {
+ if (p_predicate) {
+ *(p) = (pre + post);
+ }
+}
+
+fn y(pre : i32, p : ptr<private, i32>, p_predicate_1 : bool, post : i32) {
+ x(pre, p, p_predicate_1, post);
+}
+
+fn z() {
+ y(1, &(a[1]), true, 2);
+}
+)");
+
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, Write_WorkgroupPointerParameter_IndexWithConstant) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+var<workgroup> a : array<i32, 4>;
+
+fn x(pre : i32, p : ptr<workgroup, i32>, post : i32) {
+ *(p) = (pre + post);
+}
+
+fn y(pre : i32, p : ptr<workgroup, i32>, post : i32) {
+ x(pre, p, post);
+}
+
+fn z() {
+ y(1, &(a[1]), 2);
+}
+)";
+
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ src,
+ /* predicate */ R"(
+enable chromium_experimental_full_ptr_parameters;
+
+var<workgroup> a : array<i32, 4>;
+
+fn x(pre : i32, p : ptr<workgroup, i32>, p_predicate : bool, post : i32) {
+ if (p_predicate) {
+ *(p) = (pre + post);
+ }
+}
+
+fn y(pre : i32, p : ptr<workgroup, i32>, p_predicate_1 : bool, post : i32) {
+ x(pre, p, p_predicate_1, post);
+}
+
+fn z() {
+ y(1, &(a[1]), true, 2);
+}
+)");
+
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, Write_StoragePointerParameter_IndexWithConstant) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<storage, read_write> a : array<i32, 4>;
+
+fn x(pre : i32, p : ptr<storage, i32, read_write>, post : i32) {
+ *(p) = (pre + post);
+}
+
+fn y(pre : i32, p : ptr<storage, i32, read_write>, post : i32) {
+ x(pre, p, post);
+}
+
+fn z() {
+ y(1, &(a[1]), 2);
+}
+)";
+
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ src,
+ /* predicate */ R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<storage, read_write> a : array<i32, 4>;
+
+fn x(pre : i32, p : ptr<storage, i32, read_write>, p_predicate : bool, post : i32) {
+ if (p_predicate) {
+ *(p) = (pre + post);
+ }
+}
+
+fn y(pre : i32, p : ptr<storage, i32, read_write>, p_predicate_1 : bool, post : i32) {
+ x(pre, p, p_predicate_1, post);
+}
+
+fn z() {
+ y(1, &(a[1]), true, 2);
+}
+)");
+
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, Write_FunctionPointerParameter_IndexWithConstant) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+fn x(pre : i32, p : ptr<function, i32>, post : i32) {
+ *(p) = (pre + post);
+}
+
+fn y(pre : i32, p : ptr<function, i32>, post : i32) {
+ x(pre, p, post);
+}
+
+fn z() {
+ var a : array<i32, 4>;
+ y(1, &(a[1]), 2);
+}
+)";
+
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ src,
+ /* predicate */ R"(
+enable chromium_experimental_full_ptr_parameters;
+
+fn x(pre : i32, p : ptr<function, i32>, p_predicate : bool, post : i32) {
+ if (p_predicate) {
+ *(p) = (pre + post);
+ }
+}
+
+fn y(pre : i32, p : ptr<function, i32>, p_predicate_1 : bool, post : i32) {
+ x(pre, p, p_predicate_1, post);
+}
+
+fn z() {
+ var a : array<i32, 4>;
+ y(1, &(a[1]), true, 2);
+}
+)");
+
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, Write_PrivatePointerParameter_IndexWithLet) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+var<private> a : array<i32, 4>;
+
+fn x(pre : i32, p : ptr<private, i32>, post : i32) {
+ *(p) = (pre + post);
+}
+
+fn y(pre : i32, p : ptr<private, i32>, post : i32) {
+ x(pre, p, post);
+}
+
+fn z() {
+ let i = 0;
+ y(1, &(a[i]), 2);
+}
+)";
+
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ R"(
+enable chromium_experimental_full_ptr_parameters;
+
+var<private> a : array<i32, 4>;
+
+fn x(pre : i32, p : ptr<private, i32>, post : i32) {
+ *(p) = (pre + post);
+}
+
+fn y(pre : i32, p : ptr<private, i32>, post : i32) {
+ x(pre, p, post);
+}
+
+fn z() {
+ let i = 0;
+ y(1, &(a[min(u32(i), 3u)]), 2);
+}
+)",
+ /* predicate */ R"(
+enable chromium_experimental_full_ptr_parameters;
+
+var<private> a : array<i32, 4>;
+
+fn x(pre : i32, p : ptr<private, i32>, p_predicate : bool, post : i32) {
+ if (p_predicate) {
+ *(p) = (pre + post);
+ }
+}
+
+fn y(pre : i32, p : ptr<private, i32>, p_predicate_1 : bool, post : i32) {
+ x(pre, p, p_predicate_1, post);
+}
+
+fn z() {
+ let i = 0;
+ let index = i;
+ let predicate = (u32(index) <= 3u);
+ y(1, &(a[index]), predicate, 2);
+}
+)");
+
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, Write_WorkgroupPointerParameter_IndexWithLet) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+var<workgroup> a : array<i32, 4>;
+
+fn x(pre : i32, p : ptr<workgroup, i32>, post : i32) {
+ *(p) = (pre + post);
+}
+
+fn y(pre : i32, p : ptr<workgroup, i32>, post : i32) {
+ x(pre, p, post);
+}
+
+fn z() {
+ let i = 0;
+ y(1, &(a[i]), 2);
+}
+)";
+
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ R"(
+enable chromium_experimental_full_ptr_parameters;
+
+var<workgroup> a : array<i32, 4>;
+
+fn x(pre : i32, p : ptr<workgroup, i32>, post : i32) {
+ *(p) = (pre + post);
+}
+
+fn y(pre : i32, p : ptr<workgroup, i32>, post : i32) {
+ x(pre, p, post);
+}
+
+fn z() {
+ let i = 0;
+ y(1, &(a[min(u32(i), 3u)]), 2);
+}
+)",
+ /* predicate */ R"(
+enable chromium_experimental_full_ptr_parameters;
+
+var<workgroup> a : array<i32, 4>;
+
+fn x(pre : i32, p : ptr<workgroup, i32>, p_predicate : bool, post : i32) {
+ if (p_predicate) {
+ *(p) = (pre + post);
+ }
+}
+
+fn y(pre : i32, p : ptr<workgroup, i32>, p_predicate_1 : bool, post : i32) {
+ x(pre, p, p_predicate_1, post);
+}
+
+fn z() {
+ let i = 0;
+ let index = i;
+ let predicate = (u32(index) <= 3u);
+ y(1, &(a[index]), predicate, 2);
+}
+)");
+
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, Write_StoragePointerParameter_IndexWithLet) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<storage, read_write> a : array<i32, 4>;
+
+fn x(pre : i32, p : ptr<storage, i32, read_write>, post : i32) {
+ *(p) = (pre + post);
+}
+
+fn y(pre : i32, p : ptr<storage, i32, read_write>, post : i32) {
+ x(pre, p, post);
+}
+
+fn z() {
+ let i = 0;
+ y(1, &(a[i]), 2);
+}
+)";
+
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<storage, read_write> a : array<i32, 4>;
+
+fn x(pre : i32, p : ptr<storage, i32, read_write>, post : i32) {
+ *(p) = (pre + post);
+}
+
+fn y(pre : i32, p : ptr<storage, i32, read_write>, post : i32) {
+ x(pre, p, post);
+}
+
+fn z() {
+ let i = 0;
+ y(1, &(a[min(u32(i), 3u)]), 2);
+}
+)",
+ /* predicate */ R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<storage, read_write> a : array<i32, 4>;
+
+fn x(pre : i32, p : ptr<storage, i32, read_write>, p_predicate : bool, post : i32) {
+ if (p_predicate) {
+ *(p) = (pre + post);
+ }
+}
+
+fn y(pre : i32, p : ptr<storage, i32, read_write>, p_predicate_1 : bool, post : i32) {
+ x(pre, p, p_predicate_1, post);
+}
+
+fn z() {
+ let i = 0;
+ let index = i;
+ let predicate = (u32(index) <= 3u);
+ y(1, &(a[index]), predicate, 2);
+}
+)");
+
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, Write_FunctionPointerParameter_IndexWithLet) {
+ auto* src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+fn x(pre : i32, p : ptr<function, i32>, post : i32) {
+ *(p) = (pre + post);
+}
+
+fn y(pre : i32, p : ptr<function, i32>, post : i32) {
+ x(pre, p, post);
+}
+
+fn z() {
+ var a : array<i32, 4>;
+ let i = 0;
+ y(1, &(a[i]), 2);
+}
+)";
+
+ auto* expect = Expect(GetParam(),
+ /* ignore */ src,
+ /* clamp */ R"(
+enable chromium_experimental_full_ptr_parameters;
+
+fn x(pre : i32, p : ptr<function, i32>, post : i32) {
+ *(p) = (pre + post);
+}
+
+fn y(pre : i32, p : ptr<function, i32>, post : i32) {
+ x(pre, p, post);
+}
+
+fn z() {
+ var a : array<i32, 4>;
+ let i = 0;
+ y(1, &(a[min(u32(i), 3u)]), 2);
+}
+)",
+ /* predicate */ R"(
+enable chromium_experimental_full_ptr_parameters;
+
+fn x(pre : i32, p : ptr<function, i32>, p_predicate : bool, post : i32) {
+ if (p_predicate) {
+ *(p) = (pre + post);
+ }
+}
+
+fn y(pre : i32, p : ptr<function, i32>, p_predicate_1 : bool, post : i32) {
+ x(pre, p, p_predicate_1, post);
+}
+
+fn z() {
+ var a : array<i32, 4>;
+ let i = 0;
+ let index = i;
+ let predicate = (u32(index) <= 3u);
+ y(1, &(a[index]), predicate, 2);
+}
+)");
+
+ auto got = Run<Robustness>(src, Config(GetParam()));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// disable_runtime_sized_array_index_clamping == true
+////////////////////////////////////////////////////////////////////////////////
+
+TEST_P(RobustnessTest, Read_disable_unsized_array_index_clamping_i32) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read> s : array<f32>;
+
+fn f() {
+ let i = 25i;
+ var d : f32 = s[i];
+}
+)";
+
+ auto* expect = R"(
+@group(0) @binding(0) var<storage, read> s : array<f32>;
+
+fn f() {
+ let i = 25i;
+ var d : f32 = s[u32(i)];
+}
+)";
+
+ auto got = Run<Robustness>(src, Config(GetParam(), true));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, Read_disable_unsized_array_index_clamping_u32) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read> s : array<f32>;
+
+fn f() {
+ let i = 25u;
+ var d : f32 = s[i];
+}
+)";
+
+ auto* expect = src;
+
+ auto got = Run<Robustness>(src, Config(GetParam(), true));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, Read_disable_unsized_array_index_clamping_abstract_int) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read> s : array<f32>;
+
+fn f() {
+ var d : f32 = s[25];
+}
+)";
+
+ auto* expect = R"(
+@group(0) @binding(0) var<storage, read> s : array<f32>;
+
+fn f() {
+ var d : f32 = s[u32(25)];
+}
+)";
+
+ auto got = Run<Robustness>(src, Config(GetParam(), true));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, Assign_disable_unsized_array_index_clamping_i32) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read_write> s : array<f32>;
+
+fn f() {
+ let i = 25i;
+ s[i] = 0.5f;
+}
+)";
+
+ auto* expect = R"(
+@group(0) @binding(0) var<storage, read_write> s : array<f32>;
+
+fn f() {
+ let i = 25i;
+ s[u32(i)] = 0.5f;
+}
+)";
+
+ auto got = Run<Robustness>(src, Config(GetParam(), true));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, Assign_disable_unsized_array_index_clamping_u32) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read_write> s : array<f32>;
+
+fn f() {
+ let i = 25u;
+ s[i] = 0.5f;
+}
+)";
+
+ auto* expect = src;
+
+ auto got = Run<Robustness>(src, Config(GetParam(), true));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(RobustnessTest, Assign_disable_unsized_array_index_clamping_abstract_int) {
+ auto* src = R"(
+@group(0) @binding(0) var<storage, read_write> s : array<f32>;
+
+fn f() {
+ s[25] = 0.5f;
+}
+)";
+
+ auto* expect = R"(
+@group(0) @binding(0) var<storage, read_write> s : array<f32>;
+
+fn f() {
+ s[u32(25)] = 0.5f;
+}
+)";
+
+ auto got = Run<Robustness>(src, Config(GetParam(), true));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+INSTANTIATE_TEST_SUITE_P(,
+ RobustnessTest,
+ testing::Values(Robustness::Action::kIgnore,
+ Robustness::Action::kClamp,
+ Robustness::Action::kPredicate));
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/simplify_pointers.cc b/src/tint/lang/wgsl/ast/transform/simplify_pointers.cc
new file mode 100644
index 0000000..8f37331
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/simplify_pointers.cc
@@ -0,0 +1,264 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/simplify_pointers.h"
+
+#include <memory>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "src/tint/lang/wgsl/ast/transform/unshadow.h"
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/block_statement.h"
+#include "src/tint/sem/function.h"
+#include "src/tint/sem/statement.h"
+#include "src/tint/sem/variable.h"
+#include "src/tint/switch.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::SimplifyPointers);
+
+namespace tint::ast::transform {
+
+namespace {
+
+/// PointerOp describes the net indirection (`*`) or address-of (`&`) action
+/// applied to an expression.
+struct PointerOp {
+ /// Positive: Number of times the `expr` was dereferenced (*expr)
+ /// Negative: Number of times the `expr` was 'addressed-of' (&expr)
+ /// Zero: no pointer op on `expr`
+ int indirections = 0;
+ /// The expression being operated on
+ const Expression* expr = nullptr;
+};
+
+} // namespace
+
+/// PIMPL state for the transform
+struct SimplifyPointers::State {
+ /// The source program
+ const Program* const src;
+ /// The target program builder
+ ProgramBuilder b;
+ /// The clone context (declared after `src` and `b`, which its initializer uses)
+ CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
+
+ /// Constructor
+ /// @param program the source program
+ explicit State(const Program* program) : src(program) {}
+
+ /// Traverses the expression `expr` looking for non-literal array indexing
+ /// expressions that would affect the computed address of a pointer
+ /// expression. The function-like argument `cb` is called for each found.
+ /// @param expr the expression to traverse
+ /// @param cb a function-like object with the signature
+ /// `void(const Expression*)`, which is called for each array index
+ /// expression
+ template <typename F>
+ static void CollectSavedArrayIndices(const Expression* expr, F&& cb) {
+ if (auto* a = expr->As<IndexAccessorExpression>()) {
+ CollectSavedArrayIndices(a->object, cb);  // Visit the indexed object before the index.
+ if (!a->index->Is<LiteralExpression>()) {
+ cb(a->index);
+ }
+ return;
+ }
+
+ if (auto* m = expr->As<MemberAccessorExpression>()) {
+ CollectSavedArrayIndices(m->object, cb);
+ return;
+ }
+
+ if (auto* u = expr->As<UnaryOpExpression>()) {
+ CollectSavedArrayIndices(u->expr, cb);
+ return;
+ }
+
+ // Note: Other Expression types can be safely ignored as they cannot be
+ // used to generate a reference or pointer.
+ // See https://gpuweb.github.io/gpuweb/wgsl/#forming-references-and-pointers
+ }
+
+ /// Reduce walks the expression chain, collapsing all address-of and indirection ops into a
+ /// PointerOp, following pointer-typed local `let`s through to their initializers.
+ /// @param in the expression to walk
+ /// @returns the reduced PointerOp
+ PointerOp Reduce(const Expression* in) const {
+ PointerOp op{0, in};
+ while (true) {
+ if (auto* unary = op.expr->As<UnaryOpExpression>()) {
+ switch (unary->op) {
+ case UnaryOp::kIndirection:
+ op.indirections++;
+ op.expr = unary->expr;
+ continue;
+ case UnaryOp::kAddressOf:
+ op.indirections--;
+ op.expr = unary->expr;
+ continue;
+ default:
+ break;  // Other unary ops: leave the switch and run the checks below.
+ }
+ }
+ if (auto* user = ctx.src->Sem().Get<sem::VariableUser>(op.expr)) {
+ auto* var = user->Variable();
+ if (var->Is<sem::LocalVariable>() && //
+ var->Declaration()->Is<Let>() && //
+ var->Type()->Is<type::Pointer>()) {
+ op.expr = var->Declaration()->initializer;
+ continue;
+ }
+ }
+ return op;
+ }
+ }
+
+ /// Runs the transform
+ /// @returns the new program or SkipTransform if the transform is not required
+ ApplyResult Run() {
+ // A map of saved expressions to their saved variable name
+ utils::Hashmap<const Expression*, Symbol, 8> saved_vars;
+
+ bool needs_transform = false;
+ for (auto* ty : ctx.src->Types()) {
+ if (ty->Is<type::Pointer>()) {
+ // Program contains pointers which need removing.
+ needs_transform = true;
+ break;
+ }
+ }
+
+ // Find all the pointer-typed `let` declarations.
+ // Note that these must be function-scoped, as module-scoped `let`s are not
+ // permitted.
+ for (auto* node : ctx.src->ASTNodes().Objects()) {
+ Switch(
+ node, //
+ [&](const VariableDeclStatement* let) {
+ if (!let->variable->Is<Let>()) {
+ return; // Not a `let` declaration. Ignore.
+ }
+
+ auto* var = ctx.src->Sem().Get(let->variable);
+ if (!var->Type()->Is<type::Pointer>()) {
+ return; // Not a pointer type. Ignore.
+ }
+
+ // We're dealing with a pointer-typed `let` declaration.
+
+ // Scan the initializer expression for array index expressions that need
+ // to be hoisted to temporary "saved" variables.
+ utils::Vector<const VariableDeclStatement*, 8> saved;
+ CollectSavedArrayIndices(
+ var->Declaration()->initializer, [&](const Expression* idx_expr) {
+ // We have a sub-expression that needs to be saved.
+ // Create a new variable
+ auto saved_name = ctx.dst->Symbols().New(
+ var->Declaration()->name->symbol.Name() + "_save");
+ auto* decl =
+ ctx.dst->Decl(ctx.dst->Let(saved_name, ctx.Clone(idx_expr)));
+ saved.Push(decl);
+ // Record the substitution of `idx_expr` to the saved variable
+ // with the symbol `saved_name`. This will be used by the
+ // ReplaceAll() handler below.
+ saved_vars.Add(idx_expr, saved_name);
+ });
+
+ // Find the place to insert the saved declarations.
+ // Special care needs to be made for lets declared as the initializer
+ // part of for-loops. In this case the block will hold the for-loop
+ // statement, not the let.
+ if (!saved.IsEmpty()) {
+ auto* stmt = ctx.src->Sem().Get(let);
+ auto* block = stmt->Block();
+ // Find the statement owned by the block (either the let decl or a
+ // for-loop)
+ while (block != stmt->Parent()) {
+ stmt = stmt->Parent();
+ }
+ // Declare the stored variables just before stmt. Order here is
+ // important as order-of-operations needs to be preserved.
+ // CollectSavedArrayIndices() visits the LHS of an index accessor
+ // before the index expression.
+ for (auto* decl : saved) {
+ // Note that repeated calls to InsertBefore() with the same `before`
+ // argument will result in nodes being inserted in the order the
+ // calls are made (last call is inserted last).
+ ctx.InsertBefore(block->Declaration()->statements, stmt->Declaration(),
+ decl);
+ }
+ }
+
+ // As the original `let` declaration will be fully inlined, there's no
+ // need for the original declaration to exist. Remove it.
+ RemoveStatement(ctx, let);
+ },
+ [&](const UnaryOpExpression* op) {
+ if (op->op == UnaryOp::kAddressOf) {
+ // An address-of operator is used, so there may be pointers that can be
+ // inlined: the transform must run.
+ needs_transform = true;
+ }
+ });
+ }
+
+ if (!needs_transform) {
+ return SkipTransform;
+ }
+
+ // Register the Expression transform handler.
+ // This performs two different transformations:
+ // * Identifiers that resolve to the pointer-typed `let` declarations are
+ // replaced with the recursively inlined initializer expression for the
+ // `let` declaration.
+ // * Sub-expressions inside the pointer-typed `let` initializer expression
+ // that have been hoisted to a saved variable are replaced with the saved
+ // variable identifier.
+ ctx.ReplaceAll([&](const Expression* expr) -> const Expression* {
+ // Look to see if we need to swap this Expression with a saved variable.
+ if (auto saved_var = saved_vars.Find(expr)) {
+ return ctx.dst->Expr(*saved_var);
+ }
+
+ // Reduce the expression, folding away chains of address-of / indirections
+ auto op = Reduce(expr);
+
+ // Clone the reduced root expression
+ expr = ctx.CloneWithoutTransform(op.expr);
+
+ // And reapply the minimum number of address-of / indirections
+ for (int i = 0; i < op.indirections; i++) {
+ expr = ctx.dst->Deref(expr);
+ }
+ for (int i = 0; i > op.indirections; i--) {
+ expr = ctx.dst->AddressOf(expr);
+ }
+ return expr;
+ });
+
+ ctx.Clone();
+ return Program(std::move(b));
+ }
+};
+
+SimplifyPointers::SimplifyPointers() = default;  // Defined out-of-line, defaulted.
+
+SimplifyPointers::~SimplifyPointers() = default;  // Defined out-of-line, defaulted.
+
+Transform::ApplyResult SimplifyPointers::Apply(const Program* src, const DataMap&, DataMap&) const {
+ return State(src).Run();  // All work happens in State; both DataMaps are unused.
+}
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/simplify_pointers.h b/src/tint/lang/wgsl/ast/transform/simplify_pointers.h
new file mode 100644
index 0000000..b44e4a4
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/simplify_pointers.h
@@ -0,0 +1,53 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_SIMPLIFY_POINTERS_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_SIMPLIFY_POINTERS_H_
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+
+namespace tint::ast::transform {
+
+/// SimplifyPointers is a Transform that moves all usage of function-scope
+/// `let` statements of a pointer type into their places of usage, while also
+/// simplifying any chains of address-of or indirection operators.
+///
+/// Parameters of a pointer type are not adjusted.
+///
+/// Note: SimplifyPointers does not operate on module-scope `let`s, as these
+/// cannot be pointers: https://gpuweb.github.io/gpuweb/wgsl/#module-constants
+/// `A module-scope let-declared constant must be of constructible type.`
+///
+/// @note Depends on the following transforms to have been run first:
+/// * Unshadow
+class SimplifyPointers final : public utils::Castable<SimplifyPointers, Transform> {
+ public:
+ /// Constructor
+ SimplifyPointers();
+
+ /// Destructor
+ ~SimplifyPointers() override;
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+
+ private:
+ struct State;  // PIMPL state, defined in the .cc file.
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_SIMPLIFY_POINTERS_H_
diff --git a/src/tint/lang/wgsl/ast/transform/simplify_pointers_test.cc b/src/tint/lang/wgsl/ast/transform/simplify_pointers_test.cc
new file mode 100644
index 0000000..7801548
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/simplify_pointers_test.cc
@@ -0,0 +1,395 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/simplify_pointers.h"
+
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+#include "src/tint/lang/wgsl/ast/transform/unshadow.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using SimplifyPointersTest = TransformTest;  // Fixture for the TEST_F cases in this file.
+
+TEST_F(SimplifyPointersTest, EmptyModule) {
+ auto* src = "";
+
+ EXPECT_FALSE(ShouldRun<SimplifyPointers>(src));  // Empty input: nothing to do.
+}
+
+TEST_F(SimplifyPointersTest, FoldPointer) {
+ auto* src = R"(
+fn f() {
+ var v : i32;
+ let p : ptr<function, i32> = &v;
+ let x : i32 = *p;
+}
+)";  // Input: `x` reads through `p`, a pointer-typed `let` bound to `&v`.
+
+ auto* expect = R"(
+fn f() {
+ var v : i32;
+ let x : i32 = v;
+}
+)";  // Expected: `p` is removed and `*p` is replaced by `v`.
+
+ auto got = Run<Unshadow, SimplifyPointers>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SimplifyPointersTest, UnusedPointerParam) {
+ auto* src = R"(
+fn f(p : ptr<function, i32>) {
+}
+)";
+
+ auto* expect = src;  // Expected: the pointer parameter is left untouched.
+
+ auto got = Run<Unshadow, SimplifyPointers>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SimplifyPointersTest, ReferencedPointerParam) {
+ auto* src = R"(
+fn f(p : ptr<function, i32>) {
+ let x = p;
+}
+)";  // Input: a pointer-typed `let` initialized from a pointer parameter.
+
+ auto* expect = R"(
+fn f(p : ptr<function, i32>) {
+}
+)";  // Expected: the `let x` is removed; the parameter is kept.
+
+ auto got = Run<Unshadow, SimplifyPointers>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SimplifyPointersTest, AddressOfDeref) {
+ auto* src = R"(
+fn f() {
+ var v : i32;
+ let p : ptr<function, i32> = &(v);
+ let x : ptr<function, i32> = &(*(p));
+ let y : ptr<function, i32> = &(*(&(*(p))));
+ let z : ptr<function, i32> = &(*(&(*(&(*(&(*(p))))))));
+ var a = *x;
+ var b = *y;
+ var c = *z;
+}
+)";  // Input: pointer `let`s built from nested `&(*(...))` chains over `p`.
+
+ auto* expect = R"(
+fn f() {
+ var v : i32;
+ var a = v;
+ var b = v;
+ var c = v;
+}
+)";  // Expected: all pointer `let`s are removed; a, b and c read `v` directly.
+
+ auto got = Run<Unshadow, SimplifyPointers>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SimplifyPointersTest, DerefAddressOf) {
+ auto* src = R"(
+fn f() {
+ var v : i32;
+ let x : i32 = *(&(v));
+ let y : i32 = *(&(*(&(v))));
+ let z : i32 = *(&(*(&(*(&(*(&(v))))))));
+}
+)";  // Input: value-typed `let`s initialized from nested `*(&(...))` chains.
+
+ auto* expect = R"(
+fn f() {
+ var v : i32;
+ let x : i32 = v;
+ let y : i32 = v;
+ let z : i32 = v;
+}
+)";  // Expected: each chain collapses to a direct read of `v`.
+
+ auto got = Run<Unshadow, SimplifyPointers>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SimplifyPointersTest, ComplexChain) {
+ auto* src = R"(
+fn f() {
+ var a : array<mat4x4<f32>, 4>;
+ let ap : ptr<function, array<mat4x4<f32>, 4>> = &a;
+ let mp : ptr<function, mat4x4<f32>> = &(*ap)[3];
+ let vp : ptr<function, vec4<f32>> = &(*mp)[2];
+ let v : vec4<f32> = *vp;
+}
+)";  // Input: each pointer `let` is derived from the previous one.
+
+ auto* expect = R"(
+fn f() {
+ var a : array<mat4x4<f32>, 4>;
+ let v : vec4<f32> = a[3][2];
+}
+)";  // Expected: ap, mp and vp are inlined into the single access `a[3][2]`.
+
+ auto got = Run<Unshadow, SimplifyPointers>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SimplifyPointersTest, SavedVars) {
+ auto* src = R"(
+struct S {
+ i : i32,
+};
+
+fn arr() {
+ var a : array<S, 2>;
+ var i : i32 = 0;
+ var j : i32 = 0;
+ let p : ptr<function, i32> = &a[i + j].i;
+ i = 2;
+ *p = 4;
+}
+
+fn matrix() {
+ var m : mat3x3<f32>;
+ var i : i32 = 0;
+ var j : i32 = 0;
+ let p : ptr<function, vec3<f32>> = &m[i + j];
+ i = 2;
+ *p = vec3<f32>(4.0, 5.0, 6.0);
+}
+)";  // Input: `i` is reassigned between taking the pointer and storing through it.
+
+ auto* expect = R"(
+struct S {
+ i : i32,
+}
+
+fn arr() {
+ var a : array<S, 2>;
+ var i : i32 = 0;
+ var j : i32 = 0;
+ let p_save = (i + j);
+ i = 2;
+ a[p_save].i = 4;
+}
+
+fn matrix() {
+ var m : mat3x3<f32>;
+ var i : i32 = 0;
+ var j : i32 = 0;
+ let p_save_1 = (i + j);
+ i = 2;
+ m[p_save_1] = vec3<f32>(4.0, 5.0, 6.0);
+}
+)";  // Expected: `i + j` is saved in a `let` where `p` was declared, before `i` changes.
+
+ auto got = Run<Unshadow, SimplifyPointers>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SimplifyPointersTest, DontSaveLiterals) {
+ auto* src = R"(
+fn f() {
+ var arr : array<i32, 2>;
+ let p1 : ptr<function, i32> = &arr[1];
+ *p1 = 4;
+}
+)";  // Input: the pointer's index expression is a literal.
+
+ auto* expect = R"(
+fn f() {
+ var arr : array<i32, 2>;
+ arr[1] = 4;
+}
+)";  // Expected: the literal index is inlined directly; no saved `let` is emitted.
+
+ auto got = Run<Unshadow, SimplifyPointers>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SimplifyPointersTest, SavedVarsChain) {
+ auto* src = R"(
+fn f() {
+ var arr : array<array<i32, 2>, 2>;
+ let i : i32 = 0;
+ let j : i32 = 1;
+ let p : ptr<function, array<i32, 2>> = &arr[i];
+ let q : ptr<function, i32> = &(*p)[j];
+ *q = 12;
+}
+)";  // Input: `q` is derived from `p`; both use non-literal indices.
+
+ auto* expect = R"(
+fn f() {
+ var arr : array<array<i32, 2>, 2>;
+ let i : i32 = 0;
+ let j : i32 = 1;
+ let p_save = i;
+ let q_save = j;
+ arr[p_save][q_save] = 12;
+}
+)";  // Expected: one saved index per pointer `let`, in declaration order.
+
+ auto got = Run<Unshadow, SimplifyPointers>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SimplifyPointersTest, ForLoopInit) {
+ auto* src = R"(
+fn foo() -> i32 {
+ return 1;
+}
+
+@fragment
+fn main() {
+ var arr = array<f32, 4>();
+ for (let a = &arr[foo()]; ;) {
+ let x = *a;
+ break;
+ }
+}
+)";  // Input: the pointer `let` is the initializer of a for-loop.
+
+ auto* expect = R"(
+fn foo() -> i32 {
+ return 1;
+}
+
+@fragment
+fn main() {
+ var arr = array<f32, 4>();
+ let a_save = foo();
+ for(; ; ) {
+ let x = arr[a_save];
+ break;
+ }
+}
+)";  // Expected: the saved index is placed before the `for`; its initializer is empty.
+
+ auto got = Run<Unshadow, SimplifyPointers>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SimplifyPointersTest, MultiSavedVarsInSinglePtrLetExpr) {
+ auto* src = R"(
+fn x() -> i32 {
+ return 1;
+}
+
+fn y() -> i32 {
+ return 1;
+}
+
+fn z() -> i32 {
+ return 1;
+}
+
+struct Inner {
+ a : array<i32, 2>,
+};
+
+struct Outer {
+ a : array<Inner, 2>,
+};
+
+fn f() {
+ var arr : array<Outer, 2>;
+ let p : ptr<function, i32> = &arr[x()].a[y()].a[z()];
+ *p = 1;
+ *p = 2;
+}
+)";  // Input: one pointer `let` with three call-expression indices, used twice.
+
+ auto* expect = R"(
+fn x() -> i32 {
+ return 1;
+}
+
+fn y() -> i32 {
+ return 1;
+}
+
+fn z() -> i32 {
+ return 1;
+}
+
+struct Inner {
+ a : array<i32, 2>,
+}
+
+struct Outer {
+ a : array<Inner, 2>,
+}
+
+fn f() {
+ var arr : array<Outer, 2>;
+ let p_save = x();
+ let p_save_1 = y();
+ let p_save_2 = z();
+ arr[p_save].a[p_save_1].a[p_save_2] = 1;
+ arr[p_save].a[p_save_1].a[p_save_2] = 2;
+}
+)";  // Expected: x(), y(), z() are each saved once, in order, and reused by both stores.
+
+ auto got = Run<Unshadow, SimplifyPointers>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SimplifyPointersTest, ShadowPointer) {
+ auto* src = R"(
+var<private> a : array<i32, 2>;
+
+@compute @workgroup_size(1)
+fn main() {
+ let x = &a;
+ var a : i32 = (*x)[0];
+ {
+ var a : i32 = (*x)[1];
+ }
+}
+)";  // Input: locals named `a` shadow the global `a` that `x` points at.
+
+ auto* expect = R"(
+var<private> a : array<i32, 2>;
+
+@compute @workgroup_size(1)
+fn main() {
+ var a_1 : i32 = a[0];
+ {
+ var a_2 : i32 = a[1];
+ }
+}
+)";  // Expected: the locals are renamed (a_1, a_2), so the inlined `a` is the global.
+
+ auto got = Run<Unshadow, SimplifyPointers>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/single_entry_point.cc b/src/tint/lang/wgsl/ast/transform/single_entry_point.cc
new file mode 100644
index 0000000..3cd47cc
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/single_entry_point.cc
@@ -0,0 +1,137 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/single_entry_point.h"
+
+#include <unordered_set>
+#include <utility>
+
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/function.h"
+#include "src/tint/sem/variable.h"
+#include "src/tint/switch.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::SingleEntryPoint);
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::SingleEntryPoint::Config);
+
+namespace tint::ast::transform {
+
+SingleEntryPoint::SingleEntryPoint() = default;  // Defined out-of-line, defaulted.
+
+SingleEntryPoint::~SingleEntryPoint() = default;  // Defined out-of-line, defaulted.
+
+Transform::ApplyResult SingleEntryPoint::Apply(const Program* src,
+ const DataMap& inputs,
+ DataMap&) const {
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+
+ auto* cfg = inputs.Get<Config>();
+ if (cfg == nullptr) {
+ b.Diagnostics().add_error(diag::System::Transform,
+ "missing transform data for " + std::string(TypeInfo().name));
+ return Program(std::move(b));  // The returned program carries the error diagnostic.
+ }
+
+ // Find the target entry point.
+ const Function* entry_point = nullptr;
+ for (auto* f : src->AST().Functions()) {
+ if (!f->IsEntryPoint()) {
+ continue;  // A non-entry-point function never matches, even if its name does.
+ }
+ if (f->name->symbol.Name() == cfg->entry_point_name) {
+ entry_point = f;
+ break;
+ }
+ }
+ if (entry_point == nullptr) {
+ b.Diagnostics().add_error(diag::System::Transform,
+ "entry point '" + cfg->entry_point_name + "' not found");
+ return Program(std::move(b));  // The returned program carries the error diagnostic.
+ }
+
+ auto& sem = src->Sem();
+ auto& referenced_vars = sem.Get(entry_point)->TransitivelyReferencedGlobals();
+
+ // Clone any module-scope variables, types, and functions that are statically referenced by the
+ // target entry point.
+ for (auto* decl : src->AST().GlobalDeclarations()) {
+ Switch(
+ decl, //
+ [&](const TypeDecl* ty) {
+ // Strip aliases that reference unused override declarations.
+ if (auto* arr = sem.Get(ty)->As<type::Array>()) {
+ auto* refs = sem.TransitivelyReferencedOverrides(arr);
+ if (refs) {
+ for (auto* o : *refs) {
+ if (!referenced_vars.Contains(o)) {
+ return;  // Depends on an override the entry point does not use: drop it.
+ }
+ }
+ }
+ }
+
+ // TODO(jrprice): Strip other unused types.
+ b.AST().AddTypeDecl(ctx.Clone(ty));
+ },
+ [&](const Override* override) {
+ if (referenced_vars.Contains(sem.Get(override))) {
+ if (!HasAttribute<IdAttribute>(override->attributes)) {
+ // If the override doesn't already have an @id() attribute, add one
+ // with its allocated ID so that it won't be affected by other
+ // stripped away overrides
+ auto* global = sem.Get(override);
+ const auto* id = b.Id(global->OverrideId());
+ ctx.InsertFront(override->attributes, id);
+ }
+ b.AST().AddGlobalVariable(ctx.Clone(override));
+ }
+ },
+ [&](const Var* var) {
+ if (referenced_vars.Contains(sem.Get<sem::GlobalVariable>(var))) {
+ b.AST().AddGlobalVariable(ctx.Clone(var));
+ }
+ },
+ [&](const Const* c) {
+ // Always keep 'const' declarations, as these can be used by attributes and array
+ // sizes, which are not tracked as transitively used by functions. They also don't
+ // typically get emitted by the backend unless they're actually used.
+ b.AST().AddGlobalVariable(ctx.Clone(c));
+ },
+ [&](const Function* func) {
+ if (sem.Get(func)->HasAncestorEntryPoint(entry_point->name->symbol)) {
+ b.AST().AddFunction(ctx.Clone(func));
+ }
+ },
+ [&](const Enable* ext) { b.AST().AddEnable(ctx.Clone(ext)); },
+ [&](const DiagnosticDirective* d) { b.AST().AddDiagnosticDirective(ctx.Clone(d)); },
+ [&](Default) {
+ TINT_UNREACHABLE(Transform, b.Diagnostics())
+ << "unhandled global declaration: " << decl->TypeInfo().name;
+ });
+ }
+
+ // Clone the entry point.
+ b.AST().AddFunction(ctx.Clone(entry_point));
+
+ return Program(std::move(b));
+}
+
+SingleEntryPoint::Config::Config(std::string entry_point) : entry_point_name(entry_point) {}  // copy
+
+SingleEntryPoint::Config::Config(const Config&) = default;  // Copy constructor
+SingleEntryPoint::Config::~Config() = default;  // Destructor
+SingleEntryPoint::Config& SingleEntryPoint::Config::operator=(const Config&) = default;  // Copy-assign
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/single_entry_point.h b/src/tint/lang/wgsl/ast/transform/single_entry_point.h
new file mode 100644
index 0000000..3a130c1
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/single_entry_point.h
@@ -0,0 +1,64 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_SINGLE_ENTRY_POINT_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_SINGLE_ENTRY_POINT_H_
+
+#include <string>
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+
+namespace tint::ast::transform {
+
+/// Strip all but one entry point from a module.
+///
+/// All module-scope variables, types, and functions that are not used by the
+/// target entry point will also be removed.
+class SingleEntryPoint final : public utils::Castable<SingleEntryPoint, Transform> {
+ public:
+ /// Configuration options for the transform
+ struct Config final : public utils::Castable<Config, Data> {
+ /// Constructor
+ /// @param entry_point the name of the entry point to keep
+ explicit Config(std::string entry_point = "");
+
+ /// Copy constructor
+ Config(const Config&);
+
+ /// Destructor
+ ~Config() override;
+
+ /// Assignment operator
+ /// @returns this Config
+ Config& operator=(const Config&);
+
+ /// The name of the entry point to keep.
+ std::string entry_point_name;
+ };
+
+ /// Constructor
+ SingleEntryPoint();
+
+ /// Destructor
+ ~SingleEntryPoint() override;
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_SINGLE_ENTRY_POINT_H_
diff --git a/src/tint/lang/wgsl/ast/transform/single_entry_point_test.cc b/src/tint/lang/wgsl/ast/transform/single_entry_point_test.cc
new file mode 100644
index 0000000..783a722
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/single_entry_point_test.cc
@@ -0,0 +1,635 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/single_entry_point.h"
+
+#include <utility>
+
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using SingleEntryPointTest = TransformTest;  // Fixture for the TEST_F cases in this file.
+
+TEST_F(SingleEntryPointTest, Error_MissingTransformData) {
+ auto* src = "";
+
+ auto* expect = "error: missing transform data for tint::ast::transform::SingleEntryPoint";
+
+ auto got = Run<SingleEntryPoint>(src);  // Run without supplying a Config.
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SingleEntryPointTest, Error_NoEntryPoints) {
+ auto* src = "";  // Empty module: there is no entry point to find.
+
+ auto* expect = "error: entry point 'main' not found";
+
+ Transform::DataMap data;
+ data.Add<SingleEntryPoint::Config>("main");
+ auto got = Run<SingleEntryPoint>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SingleEntryPointTest, Error_InvalidEntryPoint) {
+ auto* src = R"(
+@vertex
+fn main() -> @builtin(position) vec4<f32> {
+ return vec4<f32>();
+}
+)";
+
+ auto* expect = "error: entry point '_' not found";
+
+ SingleEntryPoint::Config cfg("_");  // No function in `src` has this name.
+
+ Transform::DataMap data;
+ data.Add<SingleEntryPoint::Config>(cfg);
+ auto got = Run<SingleEntryPoint>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SingleEntryPointTest, Error_NotAnEntryPoint) {
+ auto* src = R"(
+fn foo() {}
+
+@fragment
+fn main() {}
+)";
+
+ auto* expect = "error: entry point 'foo' not found";
+
+ SingleEntryPoint::Config cfg("foo");  // `foo` exists in `src` but has no stage attribute.
+
+ Transform::DataMap data;
+ data.Add<SingleEntryPoint::Config>(cfg);
+ auto got = Run<SingleEntryPoint>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SingleEntryPointTest, SingleEntryPoint) {
+ auto* src = R"(
+@compute @workgroup_size(1)
+fn main() {
+}
+)";
+
+ SingleEntryPoint::Config cfg("main");
+
+ Transform::DataMap data;
+ data.Add<SingleEntryPoint::Config>(cfg);
+ auto got = Run<SingleEntryPoint>(src, data);
+
+ EXPECT_EQ(src, str(got));  // The only entry point is the target: output equals input.
+}
+
+TEST_F(SingleEntryPointTest, MultipleEntryPoints) {
+ auto* src = R"(
+@vertex
+fn vert_main() -> @builtin(position) vec4<f32> {
+ return vec4<f32>();
+}
+
+@fragment
+fn frag_main() {
+}
+
+@compute @workgroup_size(1)
+fn comp_main1() {
+}
+
+@compute @workgroup_size(1)
+fn comp_main2() {
+}
+)";  // Input: four entry points across three stages.
+
+ auto* expect = R"(
+@compute @workgroup_size(1)
+fn comp_main1() {
+}
+)";  // Expected: only comp_main1 remains.
+
+ SingleEntryPoint::Config cfg("comp_main1");
+
+ Transform::DataMap data;
+ data.Add<SingleEntryPoint::Config>(cfg);
+ auto got = Run<SingleEntryPoint>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SingleEntryPointTest, GlobalVariables) {
+ auto* src = R"(
+var<private> a : f32;
+
+var<private> b : f32;
+
+var<private> c : f32;
+
+var<private> d : f32;
+
+@vertex
+fn vert_main() -> @builtin(position) vec4<f32> {
+ a = 0.0;
+ return vec4<f32>();
+}
+
+@fragment
+fn frag_main() {
+ b = 0.0;
+}
+
+@compute @workgroup_size(1)
+fn comp_main1() {
+ c = 0.0;
+}
+
+@compute @workgroup_size(1)
+fn comp_main2() {
+ d = 0.0;
+}
+)";  // Input: each entry point writes a different private variable.
+
+ auto* expect = R"(
+var<private> c : f32;
+
+@compute @workgroup_size(1)
+fn comp_main1() {
+ c = 0.0;
+}
+)";  // Expected: only `c`, the variable comp_main1 uses, is kept.
+
+ SingleEntryPoint::Config cfg("comp_main1");
+
+ Transform::DataMap data;
+ data.Add<SingleEntryPoint::Config>(cfg);
+ auto got = Run<SingleEntryPoint>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SingleEntryPointTest, GlobalConstants) {
+ auto* src = R"(
+const a : f32 = 1.0;
+
+const b : f32 = 1.0;
+
+const c : f32 = 1.0;
+
+const d : f32 = 1.0;
+
+@vertex
+fn vert_main() -> @builtin(position) vec4<f32> {
+ let local_a : f32 = a;
+ return vec4<f32>();
+}
+
+@fragment
+fn frag_main() {
+ let local_b : f32 = b;
+}
+
+@compute @workgroup_size(1)
+fn comp_main1() {
+ let local_c : f32 = c;
+}
+
+@compute @workgroup_size(1)
+fn comp_main2() {
+ let local_d : f32 = d;
+}
+)";  // Input: each entry point reads a different module-scope `const`.
+
+ auto* expect = R"(
+const a : f32 = 1.0;
+
+const b : f32 = 1.0;
+
+const c : f32 = 1.0;
+
+const d : f32 = 1.0;
+
+@compute @workgroup_size(1)
+fn comp_main1() {
+ let local_c : f32 = c;
+}
+)";  // Expected: all four `const`s are kept, including those comp_main1 never uses.
+
+ SingleEntryPoint::Config cfg("comp_main1");
+
+ Transform::DataMap data;
+ data.Add<SingleEntryPoint::Config>(cfg);
+ auto got = Run<SingleEntryPoint>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SingleEntryPointTest, WorkgroupSizeConstPreserved) {
+ auto* src = R"(
+const size : i32 = 1;
+
+@compute @workgroup_size(size)
+fn main() {
+}
+)";
+
+ auto* expect = src;  // The `const` used only by @workgroup_size must be kept.
+
+ SingleEntryPoint::Config cfg("main");
+
+ Transform::DataMap data;
+ data.Add<SingleEntryPoint::Config>(cfg);
+ auto got = Run<SingleEntryPoint>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SingleEntryPointTest, OverridableConstants) {
+ auto* src = R"(
+@id(1001) override c1 : u32 = 1u;
+ override c2 : u32 = 1u;
+@id(0) override c3 : u32 = 1u;
+@id(9999) override c4 : u32 = 1u;
+
+@compute @workgroup_size(1)
+fn comp_main1() {
+ let local_d = c1;
+}
+
+@compute @workgroup_size(1)
+fn comp_main2() {
+ let local_d = c2;
+}
+
+@compute @workgroup_size(1)
+fn comp_main3() {
+ let local_d = c3;
+}
+
+@compute @workgroup_size(1)
+fn comp_main4() {
+ let local_d = c4;
+}
+
+@compute @workgroup_size(1)
+fn comp_main5() {
+ let local_d = 1u;
+}
+)";  // Input: comp_main1..4 each read one override; comp_main5 reads none.
+
+ {
+ SingleEntryPoint::Config cfg("comp_main1");
+ auto* expect = R"(
+@id(1001) override c1 : u32 = 1u;
+
+@compute @workgroup_size(1)
+fn comp_main1() {
+ let local_d = c1;
+}
+)";
+ Transform::DataMap data;
+ data.Add<SingleEntryPoint::Config>(cfg);
+ auto got = Run<SingleEntryPoint>(src, data);
+ EXPECT_EQ(expect, str(got));
+ }
+
+ {
+ SingleEntryPoint::Config cfg("comp_main2");
+ // c2 has no @id in the source; the output gives it an explicit @id(1),
+ // which must not be affected by the other overrides being stripped away
+ auto* expect = R"(
+@id(1) override c2 : u32 = 1u;
+
+@compute @workgroup_size(1)
+fn comp_main2() {
+ let local_d = c2;
+}
+)";
+ Transform::DataMap data;
+ data.Add<SingleEntryPoint::Config>(cfg);
+ auto got = Run<SingleEntryPoint>(src, data);
+ EXPECT_EQ(expect, str(got));
+ }
+
+ {
+ SingleEntryPoint::Config cfg("comp_main3");
+ auto* expect = R"(
+@id(0) override c3 : u32 = 1u;
+
+@compute @workgroup_size(1)
+fn comp_main3() {
+ let local_d = c3;
+}
+)";
+ Transform::DataMap data;
+ data.Add<SingleEntryPoint::Config>(cfg);
+ auto got = Run<SingleEntryPoint>(src, data);
+ EXPECT_EQ(expect, str(got));
+ }
+
+ {
+ SingleEntryPoint::Config cfg("comp_main4");
+ auto* expect = R"(
+@id(9999) override c4 : u32 = 1u;
+
+@compute @workgroup_size(1)
+fn comp_main4() {
+ let local_d = c4;
+}
+)";
+ Transform::DataMap data;
+ data.Add<SingleEntryPoint::Config>(cfg);
+ auto got = Run<SingleEntryPoint>(src, data);
+ EXPECT_EQ(expect, str(got));
+ }
+
+ {
+ SingleEntryPoint::Config cfg("comp_main5");
+ auto* expect = R"(
+@compute @workgroup_size(1)
+fn comp_main5() {
+ let local_d = 1u;
+}
+)";  // Expected: no override is referenced, so none is kept.
+ Transform::DataMap data;
+ data.Add<SingleEntryPoint::Config>(cfg);
+ auto got = Run<SingleEntryPoint>(src, data);
+ EXPECT_EQ(expect, str(got));
+ }
+}
+
+TEST_F(SingleEntryPointTest, OverridableConstants_TransitiveUses) {
+ // Overrides reached only through another override's initializer, an array count or a
+ // workgroup_size expression must not be stripped away.
+ auto* src = R"(
+@id(0) override c0 : u32;
+
+@id(1) override c1 : u32 = (2 * c0);
+
+@id(2) override c2 : u32;
+
+@id(3) override c3 : u32 = (2 * c2);
+
+@id(4) override c4 : u32;
+
+@id(5) override c5 : u32 = (2 * c4);
+
+alias arr_ty = array<i32, (2 * c5)>;
+
+var<workgroup> arr : arr_ty;
+
+@compute @workgroup_size(1, 1, (2 * c3))
+fn main() {
+ let local_d = c1;
+ arr[0] = 42;
+}
+)";
+
+ auto* expect = src;  // Everything is reachable from main(): output must equal input.
+
+ SingleEntryPoint::Config cfg("main");
+ Transform::DataMap data;
+ data.Add<SingleEntryPoint::Config>(cfg);
+ auto got = Run<SingleEntryPoint>(src, data);
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SingleEntryPointTest, OverridableConstants_UnusedAliasForOverrideSizedArray) {
+ // An alias (and the variable using it) that references an override unused by the kept
+ // entry point must be stripped together with that override.
+ auto* src = R"(
+@id(0) override c0 : u32;
+
+// This is all unused by the target entry point.
+@id(1) override c1 : u32;
+alias arr_ty = array<i32, c1>;
+var<workgroup> arr : arr_ty;
+
+@compute @workgroup_size(64)
+fn unused() {
+ arr[0] = 42;
+}
+
+@compute @workgroup_size(64)
+fn main() {
+ let local_d = c0;
+}
+)";
+
+ auto* expect = R"(
+@id(0) override c0 : u32;
+
+@compute @workgroup_size(64)
+fn main() {
+ let local_d = c0;
+}
+)";
+
+ SingleEntryPoint::Config cfg("main");  // Keep main(); unused() and its globals go away.
+ Transform::DataMap data;
+ data.Add<SingleEntryPoint::Config>(cfg);
+ auto got = Run<SingleEntryPoint>(src, data);
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SingleEntryPointTest, CalledFunctions) {  // Only functions reachable by calls are kept.
+ auto* src = R"(
+fn inner1() {
+}
+
+fn inner2() {
+}
+
+fn inner_shared() {
+}
+
+fn outer1() {
+ inner1();
+ inner_shared();
+}
+
+fn outer2() {
+ inner2();
+ inner_shared();
+}
+
+@compute @workgroup_size(1)
+fn comp_main1() {
+ outer1();
+}
+
+@compute @workgroup_size(1)
+fn comp_main2() {
+ outer2();
+}
+)";
+
+ auto* expect = R"(
+fn inner1() {
+}
+
+fn inner_shared() {
+}
+
+fn outer1() {
+ inner1();
+ inner_shared();
+}
+
+@compute @workgroup_size(1)
+fn comp_main1() {
+ outer1();
+}
+)";
+
+ SingleEntryPoint::Config cfg("comp_main1");  // inner2, outer2 and comp_main2 are dropped.
+
+ Transform::DataMap data;
+ data.Add<SingleEntryPoint::Config>(cfg);
+ auto got = Run<SingleEntryPoint>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SingleEntryPointTest, GlobalsReferencedByCalledFunctions) {  // Globals follow callees.
+ auto* src = R"(
+var<private> inner1_var : f32;
+
+var<private> inner2_var : f32;
+
+var<private> inner_shared_var : f32;
+
+var<private> outer1_var : f32;
+
+var<private> outer2_var : f32;
+
+fn inner1() {
+ inner1_var = 0.0;
+}
+
+fn inner2() {
+ inner2_var = 0.0;
+}
+
+fn inner_shared() {
+ inner_shared_var = 0.0;
+}
+
+fn outer1() {
+ inner1();
+ inner_shared();
+ outer1_var = 0.0;
+}
+
+fn outer2() {
+ inner2();
+ inner_shared();
+ outer2_var = 0.0;
+}
+
+@compute @workgroup_size(1)
+fn comp_main1() {
+ outer1();
+}
+
+@compute @workgroup_size(1)
+fn comp_main2() {
+ outer2();
+}
+)";
+
+ auto* expect = R"(
+var<private> inner1_var : f32;
+
+var<private> inner_shared_var : f32;
+
+var<private> outer1_var : f32;
+
+fn inner1() {
+ inner1_var = 0.0;
+}
+
+fn inner_shared() {
+ inner_shared_var = 0.0;
+}
+
+fn outer1() {
+ inner1();
+ inner_shared();
+ outer1_var = 0.0;
+}
+
+@compute @workgroup_size(1)
+fn comp_main1() {
+ outer1();
+}
+)";
+
+ SingleEntryPoint::Config cfg("comp_main1");  // inner2_var and outer2_var become unused.
+
+ Transform::DataMap data;
+ data.Add<SingleEntryPoint::Config>(cfg);
+ auto got = Run<SingleEntryPoint>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SingleEntryPointTest, GlobalConstUsedAsArraySize) {
+ // Regression test, see crbug.com/tint/1598. MY_SIZE is only used as an alias' array count.
+ auto* src = R"(
+const MY_SIZE = 5u;
+
+alias Arr = array<i32, MY_SIZE>;
+
+@fragment
+fn main() {
+}
+)";
+
+ auto* expect = src;  // Both the const and the alias are expected to be preserved.
+
+ SingleEntryPoint::Config cfg("main");
+
+ Transform::DataMap data;
+ data.Add<SingleEntryPoint::Config>(cfg);
+ auto got = Run<SingleEntryPoint>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SingleEntryPointTest, Directives) {
+ // Make sure that `enable` and `diagnostic` directives are preserved.
+ auto* src = R"(
+enable f16;
+diagnostic(off, derivative_uniformity);
+
+@compute @workgroup_size(1)
+fn main() {
+}
+)";
+
+ SingleEntryPoint::Config cfg("main");
+
+ Transform::DataMap data;
+ data.Add<SingleEntryPoint::Config>(cfg);
+ auto got = Run<SingleEntryPoint>(src, data);
+
+ EXPECT_EQ(src, str(got));  // Output is compared against the unmodified input.
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/spirv_atomic.cc b/src/tint/lang/wgsl/ast/transform/spirv_atomic.cc
new file mode 100644
index 0000000..5eb8de8
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/spirv_atomic.cc
@@ -0,0 +1,312 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/spirv_atomic.h"
+
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/block_statement.h"
+#include "src/tint/sem/function.h"
+#include "src/tint/sem/index_accessor_expression.h"
+#include "src/tint/sem/member_accessor_expression.h"
+#include "src/tint/sem/statement.h"
+#include "src/tint/switch.h"
+#include "src/tint/type/reference.h"
+#include "src/tint/utils/map.h"
+#include "src/tint/utils/unique_vector.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::SpirvAtomic);
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::SpirvAtomic::Stub);
+
+namespace tint::ast::transform {
+
+using namespace tint::number_suffixes; // NOLINT
+
+/// PIMPL state for the transform
+struct SpirvAtomic::State {
+ private:
+ /// A struct that has been forked because a subset of members were made atomic.
+ struct ForkedStruct {
+ Symbol name;  // Symbol of the forked "<name>_atomic" structure.
+ std::unordered_set<size_t> atomic_members;  // Indices of members to make atomic.
+ };
+
+ /// The source program
+ const Program* const src;
+ /// The target program builder
+ ProgramBuilder b;
+ /// The clone context
+ CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
+ std::unordered_map<const type::Struct*, ForkedStruct> forked_structs;  // By source struct.
+ std::unordered_set<const sem::Variable*> atomic_variables;  // Variables retyped as atomic.
+ utils::UniqueVector<const sem::ValueExpression*, 8> atomic_expressions;  // Worklist.
+
+ public:
+ /// Constructor
+ /// @param program the source program
+ explicit State(const Program* program) : src(program) {}
+
+ /// Runs the transform
+ /// @returns the new program or SkipTransform if the transform is not required
+ ApplyResult Run() {
+ bool made_changes = false;
+
+ // Look for stub functions generated by the SPIR-V reader, which are used as placeholders
+ // for atomic builtin calls.
+ for (auto* fn : ctx.src->AST().Functions()) {
+ if (auto* stub = GetAttribute<Stub>(fn->attributes)) {
+ auto* sem = ctx.src->Sem().Get(fn);
+
+ for (auto* call : sem->CallSites()) {
+ // The first argument is always the atomic.
+ // The stub passes this by value, whereas the builtin wants a pointer.
+ // Take the address of the atomic argument.
+ auto& args = call->Declaration()->args;
+ auto out_args = ctx.Clone(args);
+ out_args[0] = b.AddressOf(out_args[0]);
+
+ // Replace all callsites of this stub to a call to the real builtin
+ if (stub->builtin == builtin::Function::kAtomicCompareExchangeWeak) {
+ // atomicCompareExchangeWeak returns a struct, so insert a call to it above
+ // the current statement, and replace the current call with the struct's
+ // `old_value` member.
+ auto* block = call->Stmt()->Block()->Declaration();
+ auto old_value = b.Symbols().New("old_value");
+ auto old_value_decl = b.Decl(b.Let(
+ old_value, b.MemberAccessor(
+ b.Call(builtin::str(stub->builtin), std::move(out_args)),
+ "old_value")));
+ ctx.InsertBefore(block->statements, call->Stmt()->Declaration(),
+ old_value_decl);
+ ctx.Replace(call->Declaration(), b.Expr(old_value));
+ } else {
+ ctx.Replace(call->Declaration(),
+ b.Call(builtin::str(stub->builtin), std::move(out_args)));
+ }
+
+ // Keep track of this expression. We'll need to modify the root identifier /
+ // structure to be atomic.
+ atomic_expressions.Add(ctx.src->Sem().GetVal(args[0]));
+ }
+
+ // Remove the stub from the output program
+ ctx.Remove(ctx.src->AST().GlobalDeclarations(), fn);
+ made_changes = true;  // Set for every stub found, even one with no call sites.
+ }
+ }
+
+ if (!made_changes) {
+ return SkipTransform;
+ }
+
+ // Transform all variables and structure members that were used in atomic operations as
+ // atomic types. This propagates up originating expression chains.
+ ProcessAtomicExpressions();
+
+ // If we need to change structure members, then fork them.
+ if (!forked_structs.empty()) {
+ ctx.ReplaceAll([&](const Struct* str) {
+ // Is `str` a structure we need to fork?
+ auto* str_ty = ctx.src->Sem().Get(str);
+ if (auto it = forked_structs.find(str_ty); it != forked_structs.end()) {
+ const auto& forked = it->second;
+
+ // Re-create the structure swapping in the atomic-flavoured members
+ utils::Vector<const StructMember*, 8> members;
+ members.Reserve(str->members.Length());
+ for (size_t i = 0; i < str->members.Length(); i++) {
+ auto* member = str->members[i];
+ if (forked.atomic_members.count(i)) {
+ auto type = AtomicTypeFor(ctx.src->Sem().Get(member)->Type());
+ auto name = member->name->symbol.Name();
+ members.Push(b.Member(name, type, ctx.Clone(member->attributes)));
+ } else {
+ members.Push(ctx.Clone(member));
+ }
+ }
+ b.Structure(forked.name, std::move(members));  // Declares the forked struct.
+ }
+ return nullptr;  // No replacement is returned for `str` itself.
+ });
+ }
+
+ // Replace assignments and decls from atomic variables with atomicLoads, and assignments to
+ // atomic variables with atomicStores.
+ ReplaceLoadsAndStores();
+
+ ctx.Clone();  // Perform the clone, applying the replacements registered above.
+ return Program(std::move(b));
+ }
+
+ private:
+ ForkedStruct& Fork(const type::Struct* str) {  // Gets or creates the fork record for `str`.
+ auto& forked = forked_structs[str];
+ if (!forked.name.IsValid()) {  // First request: allocate the "<name>_atomic" symbol.
+ forked.name = b.Symbols().New(str->Name().Name() + "_atomic");
+ }
+ return forked;
+ }
+
+ void ProcessAtomicExpressions() {  // Walks each atomic expression back towards its root.
+ for (size_t i = 0; i < atomic_expressions.Length(); i++) {  // Grows while iterating.
+ Switch(
+ atomic_expressions[i]->UnwrapLoad(), //
+ [&](const sem::VariableUser* user) {
+ auto* v = user->Variable()->Declaration();
+ if (v->type && atomic_variables.emplace(user->Variable()).second) {
+ ctx.Replace(v->type.expr, b.Expr(AtomicTypeFor(user->Variable()->Type())));
+ }
+ if (auto* ctor = user->Variable()->Initializer()) {
+ atomic_expressions.Add(ctor);  // Keep following through the initializer.
+ }
+ },
+ [&](const sem::StructMemberAccess* access) {
+ // Fork the struct (the first time) and mark member(s) that need to be made
+ // atomic.
+ auto* member = access->Member();
+ Fork(member->Struct()).atomic_members.emplace(member->Index());
+ atomic_expressions.Add(access->Object());
+ },
+ [&](const sem::IndexAccessorExpression* index) {
+ atomic_expressions.Add(index->Object());
+ },
+ [&](const sem::ValueExpression* e) {
+ if (auto* unary = e->Declaration()->As<UnaryOpExpression>()) {
+ atomic_expressions.Add(ctx.src->Sem().GetVal(unary->expr));
+ }
+ });
+ }
+ }
+
+ Type AtomicTypeFor(const type::Type* ty) {  // Builds the atomic-flavoured AST type for `ty`.
+ return Switch(
+ ty, //
+ [&](const type::I32*) { return b.ty.atomic(CreateASTTypeFor(ctx, ty)); },
+ [&](const type::U32*) { return b.ty.atomic(CreateASTTypeFor(ctx, ty)); },
+ [&](const type::Struct* str) { return b.ty(Fork(str).name); },
+ [&](const type::Array* arr) {
+ if (arr->Count()->Is<type::RuntimeArrayCount>()) {
+ return b.ty.array(AtomicTypeFor(arr->ElemType()));
+ }
+ auto count = arr->ConstantCount();
+ if (!count) {
+ ctx.dst->Diagnostics().add_error(
+ diag::System::Transform,
+ "the SpirvAtomic transform does not currently support array counts that "
+ "use override values");
+ count = 1;  // Placeholder so that count.value() below is valid after the error.
+ }
+ return b.ty.array(AtomicTypeFor(arr->ElemType()), u32(count.value()));
+ },
+ [&](const type::Pointer* ptr) {
+ return b.ty.ptr(ptr->AddressSpace(), AtomicTypeFor(ptr->StoreType()),
+ ptr->Access());
+ },
+ [&](const type::Reference* ref) { return AtomicTypeFor(ref->StoreType()); },
+ [&](Default) {
+ TINT_ICE(Transform, b.Diagnostics()) << "unhandled type: " << ty->FriendlyName();
+ return Type{};
+ });
+ }
+
+ void ReplaceLoadsAndStores() {
+ // Returns true if 'e' is a reference to an atomic variable or struct member
+ auto is_ref_to_atomic_var = [&](const sem::ValueExpression* e) {
+ if (tint::Is<type::Reference>(e->Type()) && e->RootIdentifier() &&
+ (atomic_variables.count(e->RootIdentifier()) != 0)) {
+ // If it's a struct member, make sure it's one we marked as atomic
+ if (auto* ma = e->As<sem::StructMemberAccess>()) {
+ auto it = forked_structs.find(ma->Member()->Struct());
+ if (it != forked_structs.end()) {
+ auto& forked = it->second;
+ return forked.atomic_members.count(ma->Member()->Index()) != 0;
+ }
+ }
+ return true;  // Not a member access, or a member of a struct that was not forked.
+ }
+ return false;
+ };
+
+ // Look for loads and stores via assignments and decls of atomic variables we've collected
+ // so far, and replace them with atomicLoad and atomicStore.
+ for (auto* atomic_var : atomic_variables) {
+ for (auto* vu : atomic_var->Users()) {
+ Switch(
+ vu->Stmt()->Declaration(),
+ [&](const AssignmentStatement* assign) {
+ auto* sem_lhs = ctx.src->Sem().GetVal(assign->lhs);
+ if (is_ref_to_atomic_var(sem_lhs)) {
+ ctx.Replace(assign, [=] {
+ auto* lhs = ctx.CloneWithoutTransform(assign->lhs);
+ auto* rhs = ctx.CloneWithoutTransform(assign->rhs);
+ auto* call = b.Call(builtin::str(builtin::Function::kAtomicStore),
+ b.AddressOf(lhs), rhs);
+ return b.CallStmt(call);
+ });
+ return;  // Store handled; the right-hand side is not checked for a load.
+ }
+
+ auto sem_rhs = ctx.src->Sem().GetVal(assign->rhs);
+ if (is_ref_to_atomic_var(sem_rhs->UnwrapLoad())) {
+ ctx.Replace(assign->rhs, [=] {
+ auto* rhs = ctx.CloneWithoutTransform(assign->rhs);
+ return b.Call(builtin::str(builtin::Function::kAtomicLoad),
+ b.AddressOf(rhs));
+ });
+ return;
+ }
+ },
+ [&](const VariableDeclStatement* decl) {
+ auto* var = decl->variable;
+ if (auto* sem_init = ctx.src->Sem().GetVal(var->initializer)) {
+ if (is_ref_to_atomic_var(sem_init->UnwrapLoad())) {
+ ctx.Replace(var->initializer, [=] {
+ auto* rhs = ctx.CloneWithoutTransform(var->initializer);
+ return b.Call(builtin::str(builtin::Function::kAtomicLoad),
+ b.AddressOf(rhs));
+ });
+ return;
+ }
+ }
+ });
+ }
+ }
+ }
+};
+
+SpirvAtomic::SpirvAtomic() = default;  // Compiler-generated constructor.
+SpirvAtomic::~SpirvAtomic() = default;  // Compiler-generated destructor.
+
+SpirvAtomic::Stub::Stub(ProgramID pid, NodeID nid, builtin::Function b)
+ : Base(pid, nid, utils::Empty), builtin(b) {}  // Records the builtin this stub stands for.
+SpirvAtomic::Stub::~Stub() = default;  // Compiler-generated destructor.
+std::string SpirvAtomic::Stub::InternalName() const {  // Embeds the builtin's name.
+ return "@internal(spirv-atomic " + std::string(builtin::str(builtin)) + ")";
+}
+
+const SpirvAtomic::Stub* SpirvAtomic::Stub::Clone(CloneContext* ctx) const {  // Same builtin.
+ return ctx->dst->ASTNodes().Create<SpirvAtomic::Stub>(ctx->dst->ID(),
+ ctx->dst->AllocateNodeID(), builtin);
+}
+
+Transform::ApplyResult SpirvAtomic::Apply(const Program* src, const DataMap&, DataMap&) const {
+ return State{src}.Run();  // The input and output data maps are unused by this transform.
+}
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/spirv_atomic.h b/src/tint/lang/wgsl/ast/transform/spirv_atomic.h
new file mode 100644
index 0000000..6fbcd66
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/spirv_atomic.h
@@ -0,0 +1,77 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_SPIRV_ATOMIC_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_SPIRV_ATOMIC_H_
+
+#include <string>
+
+#include "src/tint/builtin/function.h"
+#include "src/tint/lang/wgsl/ast/internal_attribute.h"
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+
+// Forward declarations
+namespace tint {
+class CloneContext;
+} // namespace tint
+
+namespace tint::ast::transform {
+
+/// SpirvAtomic is a transform that replaces calls to stub functions created by the SPIR-V reader
+/// with calls to the WGSL atomic builtin. It also makes sure to replace variable declarations that
+/// are the target of the atomic operations with an atomic declaration of the same type. For
+/// structs, it creates a copy of the original struct with atomic members.
+class SpirvAtomic final : public utils::Castable<SpirvAtomic, Transform> {
+ public:
+ /// Constructor
+ SpirvAtomic();
+ /// Destructor
+ ~SpirvAtomic() override;
+
+ /// Stub is an attribute applied to stub SPIR-V reader generated functions that need to be
+ /// translated to an atomic builtin.
+ class Stub final : public utils::Castable<Stub, InternalAttribute> {
+ public:
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param builtin the atomic builtin this stub represents
+ Stub(ProgramID pid, NodeID nid, builtin::Function builtin);
+ /// Destructor
+ ~Stub() override;
+
+ /// @return a short description of the internal attribute which will be
+ /// displayed as `@internal(<name>)`
+ std::string InternalName() const override;
+
+ /// Performs a deep clone of this object using the CloneContext `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned object
+ const Stub* Clone(CloneContext* ctx) const override;
+
+ /// The atomic builtin function that calls to the stubbed function are replaced with
+ const builtin::Function builtin;
+ };
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+
+ private:
+ struct State;
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_SPIRV_ATOMIC_H_
diff --git a/src/tint/lang/wgsl/ast/transform/spirv_atomic_test.cc b/src/tint/lang/wgsl/ast/transform/spirv_atomic_test.cc
new file mode 100644
index 0000000..f03912e
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/spirv_atomic_test.cc
@@ -0,0 +1,1380 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/spirv_atomic.h"
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+#include "src/tint/lang/wgsl/reader/parser_impl.h"
+
+using namespace tint::number_suffixes; // NOLINT
+
+namespace tint::ast::transform {
+namespace {
+
+class SpirvAtomicTest : public TransformTest {
+ public:
+ Output Run(std::string in) {  // Parses `in`, appends the stub functions, runs SpirvAtomic.
+ auto file = std::make_unique<Source::File>("test", std::move(in));
+ auto parser = reader::wgsl::ParserImpl(file.get());
+ parser.Parse();
+
+ auto& b = parser.builder();
+
+ builtin::Function two_params[] = {  // Builtins whose stubs take two parameters.
+ builtin::Function::kAtomicExchange, builtin::Function::kAtomicAdd,
+ builtin::Function::kAtomicSub, builtin::Function::kAtomicMin,
+ builtin::Function::kAtomicMax, builtin::Function::kAtomicAnd,
+ builtin::Function::kAtomicOr, builtin::Function::kAtomicXor,
+ };
+ for (auto& a : two_params) {  // Declares "stub_<builtin>_u32" and "stub_<builtin>_i32".
+ b.Func(std::string{"stub_"} + builtin::str(a) + "_u32",
+ utils::Vector{
+ b.Param("p0", b.ty.u32()),
+ b.Param("p1", b.ty.u32()),
+ },
+ b.ty.u32(),
+ utils::Vector{
+ b.Return(0_u),
+ },
+ utils::Vector{
+ b.ASTNodes().Create<SpirvAtomic::Stub>(b.ID(), b.AllocateNodeID(), a),
+ });
+ b.Func(std::string{"stub_"} + builtin::str(a) + "_i32",
+ utils::Vector{
+ b.Param("p0", b.ty.i32()),
+ b.Param("p1", b.ty.i32()),
+ },
+ b.ty.i32(),
+ utils::Vector{
+ b.Return(0_i),
+ },
+ utils::Vector{
+ b.ASTNodes().Create<SpirvAtomic::Stub>(b.ID(), b.AllocateNodeID(), a),
+ });
+ }
+
+ b.Func("stub_atomicLoad_u32",  // Load stubs take only the atomic.
+ utils::Vector{
+ b.Param("p0", b.ty.u32()),
+ },
+ b.ty.u32(),
+ utils::Vector{
+ b.Return(0_u),
+ },
+ utils::Vector{
+ b.ASTNodes().Create<SpirvAtomic::Stub>(b.ID(), b.AllocateNodeID(),
+ builtin::Function::kAtomicLoad),
+ });
+ b.Func("stub_atomicLoad_i32",
+ utils::Vector{
+ b.Param("p0", b.ty.i32()),
+ },
+ b.ty.i32(),
+ utils::Vector{
+ b.Return(0_i),
+ },
+ utils::Vector{
+ b.ASTNodes().Create<SpirvAtomic::Stub>(b.ID(), b.AllocateNodeID(),
+ builtin::Function::kAtomicLoad),
+ });
+
+ b.Func("stub_atomicStore_u32",  // Store stubs return void and have an empty body.
+ utils::Vector{
+ b.Param("p0", b.ty.u32()),
+ b.Param("p1", b.ty.u32()),
+ },
+ b.ty.void_(), utils::Empty,
+ utils::Vector{
+ b.ASTNodes().Create<SpirvAtomic::Stub>(b.ID(), b.AllocateNodeID(),
+ builtin::Function::kAtomicStore),
+ });
+ b.Func("stub_atomicStore_i32",
+ utils::Vector{
+ b.Param("p0", b.ty.i32()),
+ b.Param("p1", b.ty.i32()),
+ },
+ b.ty.void_(), utils::Empty,
+ utils::Vector{
+ b.ASTNodes().Create<SpirvAtomic::Stub>(b.ID(), b.AllocateNodeID(),
+ builtin::Function::kAtomicStore),
+ });
+
+ b.Func("stub_atomic_compare_exchange_weak_u32",  // Compare-exchange stubs take 3 params.
+ utils::Vector{
+ b.Param("p0", b.ty.u32()),
+ b.Param("p1", b.ty.u32()),
+ b.Param("p2", b.ty.u32()),
+ },
+ b.ty.u32(),
+ utils::Vector{
+ b.Return(0_u),
+ },
+ utils::Vector{
+ b.ASTNodes().Create<SpirvAtomic::Stub>(
+ b.ID(), b.AllocateNodeID(), builtin::Function::kAtomicCompareExchangeWeak),
+ });
+ b.Func("stub_atomic_compare_exchange_weak_i32",
+ utils::Vector{b.Param("p0", b.ty.i32()), b.Param("p1", b.ty.i32()),
+ b.Param("p2", b.ty.i32())},
+ b.ty.i32(),
+ utils::Vector{
+ b.Return(0_i),
+ },
+ utils::Vector{
+ b.ASTNodes().Create<SpirvAtomic::Stub>(
+ b.ID(), b.AllocateNodeID(), builtin::Function::kAtomicCompareExchangeWeak),
+ });
+
+ // Keep the Source::File alive after Run() returns; the parser was given a raw pointer to it
+ files_.emplace_back(std::move(file));
+
+ return TransformTest::Run<SpirvAtomic>(Program(std::move(b)));
+ }
+
+ private:
+ std::vector<std::unique_ptr<Source::File>> files_;  // Owns the files parsed by Run().
+};
+
+TEST_F(SpirvAtomicTest, StripUnusedBuiltins) {  // No stub is called by the shader.
+ auto* src = R"(
+fn f() {
+}
+)";
+
+ auto* expect = src;  // The stub functions appended by Run() must not appear in the output.
+
+ auto got = Run(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvAtomicTest, ArrayOfU32) {  // The array's element type becomes atomic<u32>.
+ auto* src = R"(
+var<workgroup> wg : array<u32, 4>;
+
+fn f() {
+ stub_atomicStore_u32(wg[1], 1u);
+}
+)";
+
+ auto* expect = R"(
+var<workgroup> wg : array<atomic<u32>, 4u>;
+
+fn f() {
+ atomicStore(&(wg[1]), 1u);
+}
+)";
+
+ auto got = Run(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvAtomicTest, ArraysOfU32) {  // Only the innermost element type becomes atomic.
+ auto* src = R"(
+var<workgroup> wg : array<array<array<u32, 1>, 2>, 3>;
+
+fn f() {
+ stub_atomicStore_u32(wg[2][1][0], 1u);
+}
+)";
+
+ auto* expect = R"(
+var<workgroup> wg : array<array<array<atomic<u32>, 1u>, 2u>, 3u>;
+
+fn f() {
+ atomicStore(&(wg[2][1][0]), 1u);
+}
+)";
+
+ auto got = Run(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvAtomicTest, AliasedArraysOfU32) {  // Aliases are kept; the var's type is expanded.
+ auto* src = R"(
+alias A0 = u32;
+
+alias A1 = array<A0, 1>;
+
+alias A2 = array<A1, 2>;
+
+alias A3 = array<A2, 3>;
+
+var<workgroup> wg : A3;
+
+fn f() {
+ stub_atomicStore_u32(wg[2][1][0], 1u);
+}
+)";
+
+ auto* expect = R"(
+alias A0 = u32;
+
+alias A1 = array<A0, 1>;
+
+alias A2 = array<A1, 2>;
+
+alias A3 = array<A2, 3>;
+
+var<workgroup> wg : array<array<array<atomic<u32>, 1u>, 2u>, 3u>;
+
+fn f() {
+ atomicStore(&(wg[2][1][0]), 1u);
+}
+)";
+
+ auto got = Run(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvAtomicTest, FlatStructSingleAtomic) {  // S is forked to S_atomic; S is kept too.
+ auto* src = R"(
+struct S {
+ a : u32,
+}
+
+var<workgroup> wg : S;
+
+fn f() {
+ stub_atomicStore_u32(wg.a, 1u);
+}
+)";
+
+ auto* expect = R"(
+struct S_atomic {
+ a : atomic<u32>,
+}
+
+struct S {
+ a : u32,
+}
+
+var<workgroup> wg : S_atomic;
+
+fn f() {
+ atomicStore(&(wg.a), 1u);
+}
+)";
+
+ auto got = Run(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvAtomicTest, FlatStructMultipleAtomic) {  // Two members made atomic in one fork.
+ auto* src = R"(
+struct S {
+ a : u32,
+ b : i32,
+}
+
+var<workgroup> wg : S;
+
+fn f1() {
+ stub_atomicStore_u32(wg.a, 1u);
+}
+
+fn f2() {
+ stub_atomicStore_i32(wg.b, 2i);
+}
+)";
+
+ auto* expect = R"(
+struct S_atomic {
+ a : atomic<u32>,
+ b : atomic<i32>,
+}
+
+struct S {
+ a : u32,
+ b : i32,
+}
+
+var<workgroup> wg : S_atomic;
+
+fn f1() {
+ atomicStore(&(wg.a), 1u);
+}
+
+fn f2() {
+ atomicStore(&(wg.b), 2i);
+}
+)";
+
+ auto got = Run(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvAtomicTest, NestedStruct) {  // Every struct on the access chain is forked.
+ auto* src = R"(
+struct S0 {
+ a : u32,
+ b : i32,
+ c : u32,
+}
+
+struct S1 {
+ a : i32,
+ b : u32,
+ c : S0,
+}
+
+struct S2 {
+ a : i32,
+ b : S1,
+ c : u32,
+}
+
+var<workgroup> wg : S2;
+
+fn f() {
+ stub_atomicStore_u32(wg.b.c.a, 1u);
+}
+)";
+
+ auto* expect = R"(
+struct S0_atomic {
+ a : atomic<u32>,
+ b : i32,
+ c : u32,
+}
+
+struct S0 {
+ a : u32,
+ b : i32,
+ c : u32,
+}
+
+struct S1_atomic {
+ a : i32,
+ b : u32,
+ c : S0_atomic,
+}
+
+struct S1 {
+ a : i32,
+ b : u32,
+ c : S0,
+}
+
+struct S2_atomic {
+ a : i32,
+ b : S1_atomic,
+ c : u32,
+}
+
+struct S2 {
+ a : i32,
+ b : S1,
+ c : u32,
+}
+
+var<workgroup> wg : S2_atomic;
+
+fn f() {
+ atomicStore(&(wg.b.c.a), 1u);
+}
+)";
+
+ auto got = Run(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvAtomicTest, ArrayOfStruct) {  // Runtime-sized array whose element struct is forked.
+ auto* src = R"(
+struct S {
+ a : u32,
+ b : i32,
+ c : u32,
+}
+
+@group(0) @binding(1) var<storage, read_write> arr : array<S>;
+
+fn f() {
+ stub_atomicStore_i32(arr[4].b, 1i);
+}
+)";
+
+ auto* expect = R"(
+struct S_atomic {
+ a : u32,
+ b : atomic<i32>,
+ c : u32,
+}
+
+struct S {
+ a : u32,
+ b : i32,
+ c : u32,
+}
+
+@group(0) @binding(1) var<storage, read_write> arr : array<S_atomic>;
+
+fn f() {
+ atomicStore(&(arr[4].b), 1i);
+}
+)";
+
+ auto got = Run(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvAtomicTest, StructOfArray) {  // Runtime-sized array member gets atomic elements.
+ auto* src = R"(
+struct S {
+ a : array<i32>,
+}
+
+@group(0) @binding(1) var<storage, read_write> s : S;
+
+fn f() {
+ stub_atomicStore_i32(s.a[4], 1i);
+}
+)";
+
+ auto* expect = R"(
+struct S_atomic {
+ a : array<atomic<i32>>,
+}
+
+struct S {
+ a : array<i32>,
+}
+
+@group(0) @binding(1) var<storage, read_write> s : S_atomic;
+
+fn f() {
+ atomicStore(&(s.a[4]), 1i);
+}
+)";
+
+ auto got = Run(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvAtomicTest, ViaPtrLet) {  // Atomic reached through pointer lets; p1 is retyped.
+ auto* src = R"(
+struct S {
+ i : i32,
+}
+
+@group(0) @binding(1) var<storage, read_write> s : S;
+
+fn f() {
+ let p0 = &(s);
+ let p1 : ptr<storage, i32, read_write> = &((*(p0)).i);
+ stub_atomicStore_i32(*p1, 1i);
+}
+)";
+
+ auto* expect =
+ R"(
+struct S_atomic {
+ i : atomic<i32>,
+}
+
+struct S {
+ i : i32,
+}
+
+@group(0) @binding(1) var<storage, read_write> s : S_atomic;
+
+fn f() {
+ let p0 = &(s);
+ let p1 : ptr<storage, atomic<i32>, read_write> = &((*(p0)).i);
+ atomicStore(&(*(p1)), 1i);
+}
+)";
+
+ auto got = Run(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvAtomicTest, StructIsolatedMixedUsage) {  // Non-atomic uses of S stay untouched.
+ auto* src = R"(
+struct S {
+ i : i32,
+}
+
+@group(0) @binding(1) var<storage, read_write> s : S;
+
+fn f() {
+ stub_atomicStore_i32(s.i, 1i);
+}
+
+fn another_usage() {
+ var s : S;
+ let x : i32 = s.i;
+ s.i = 3i;
+}
+)";
+
+ auto* expect =
+ R"(
+struct S_atomic {
+ i : atomic<i32>,
+}
+
+struct S {
+ i : i32,
+}
+
+@group(0) @binding(1) var<storage, read_write> s : S_atomic;
+
+fn f() {
+ atomicStore(&(s.i), 1i);
+}
+
+fn another_usage() {
+ var s : S;
+ let x : i32 = s.i;
+ s.i = 3i;
+}
+)";
+
+ auto got = Run(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvAtomicTest, AtomicLoad) {  // u32 and i32, workgroup and storage variables.
+ auto* src = R"(
+var<workgroup> wg_u32 : u32;
+var<workgroup> wg_i32 : i32;
+@group(0) @binding(0) var<storage, read_write> sg_u32 : u32;
+@group(0) @binding(1) var<storage, read_write> sg_i32 : i32;
+
+fn f() {
+ {let r = stub_atomicLoad_u32(wg_u32);}
+ {let r = stub_atomicLoad_i32(wg_i32);}
+ {let r = stub_atomicLoad_u32(sg_u32);}
+ {let r = stub_atomicLoad_i32(sg_i32);}
+}
+)";
+
+ auto* expect =
+ R"(
+var<workgroup> wg_u32 : atomic<u32>;
+
+var<workgroup> wg_i32 : atomic<i32>;
+
+@group(0) @binding(0) var<storage, read_write> sg_u32 : atomic<u32>;
+
+@group(0) @binding(1) var<storage, read_write> sg_i32 : atomic<i32>;
+
+fn f() {
+ {
+ let r = atomicLoad(&(wg_u32));
+ }
+ {
+ let r = atomicLoad(&(wg_i32));
+ }
+ {
+ let r = atomicLoad(&(sg_u32));
+ }
+ {
+ let r = atomicLoad(&(sg_i32));
+ }
+}
+)";
+
+ auto got = Run(src);
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvAtomicTest, AtomicExchange) {  // u32 and i32, workgroup and storage variables.
+ auto* src = R"(
+var<workgroup> wg_u32 : u32;
+var<workgroup> wg_i32 : i32;
+@group(0) @binding(0) var<storage, read_write> sg_u32 : u32;
+@group(0) @binding(1) var<storage, read_write> sg_i32 : i32;
+
+fn f() {
+ {let r = stub_atomicExchange_u32(wg_u32, 123u);}
+ {let r = stub_atomicExchange_i32(wg_i32, 123i);}
+ {let r = stub_atomicExchange_u32(sg_u32, 123u);}
+ {let r = stub_atomicExchange_i32(sg_i32, 123i);}
+}
+)";
+
+ auto* expect =
+ R"(
+var<workgroup> wg_u32 : atomic<u32>;
+
+var<workgroup> wg_i32 : atomic<i32>;
+
+@group(0) @binding(0) var<storage, read_write> sg_u32 : atomic<u32>;
+
+@group(0) @binding(1) var<storage, read_write> sg_i32 : atomic<i32>;
+
+fn f() {
+ {
+ let r = atomicExchange(&(wg_u32), 123u);
+ }
+ {
+ let r = atomicExchange(&(wg_i32), 123i);
+ }
+ {
+ let r = atomicExchange(&(sg_u32), 123u);
+ }
+ {
+ let r = atomicExchange(&(sg_i32), 123i);
+ }
+}
+)";
+
+ auto got = Run(src);
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvAtomicTest, AtomicAdd) {  // u32 and i32, workgroup and storage variables.
+ auto* src = R"(
+var<workgroup> wg_u32 : u32;
+var<workgroup> wg_i32 : i32;
+@group(0) @binding(0) var<storage, read_write> sg_u32 : u32;
+@group(0) @binding(1) var<storage, read_write> sg_i32 : i32;
+
+fn f() {
+ {let r = stub_atomicAdd_u32(wg_u32, 123u);}
+ {let r = stub_atomicAdd_i32(wg_i32, 123i);}
+ {let r = stub_atomicAdd_u32(sg_u32, 123u);}
+ {let r = stub_atomicAdd_i32(sg_i32, 123i);}
+}
+)";
+
+ auto* expect =
+ R"(
+var<workgroup> wg_u32 : atomic<u32>;
+
+var<workgroup> wg_i32 : atomic<i32>;
+
+@group(0) @binding(0) var<storage, read_write> sg_u32 : atomic<u32>;
+
+@group(0) @binding(1) var<storage, read_write> sg_i32 : atomic<i32>;
+
+fn f() {
+ {
+ let r = atomicAdd(&(wg_u32), 123u);
+ }
+ {
+ let r = atomicAdd(&(wg_i32), 123i);
+ }
+ {
+ let r = atomicAdd(&(sg_u32), 123u);
+ }
+ {
+ let r = atomicAdd(&(sg_i32), 123i);
+ }
+}
+)";
+
+ auto got = Run(src);
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvAtomicTest, AtomicSub) {  // u32 and i32, workgroup and storage variables.
+ auto* src = R"(
+var<workgroup> wg_u32 : u32;
+var<workgroup> wg_i32 : i32;
+@group(0) @binding(0) var<storage, read_write> sg_u32 : u32;
+@group(0) @binding(1) var<storage, read_write> sg_i32 : i32;
+
+fn f() {
+ {let r = stub_atomicSub_u32(wg_u32, 123u);}
+ {let r = stub_atomicSub_i32(wg_i32, 123i);}
+ {let r = stub_atomicSub_u32(sg_u32, 123u);}
+ {let r = stub_atomicSub_i32(sg_i32, 123i);}
+}
+)";
+
+ auto* expect =
+ R"(
+var<workgroup> wg_u32 : atomic<u32>;
+
+var<workgroup> wg_i32 : atomic<i32>;
+
+@group(0) @binding(0) var<storage, read_write> sg_u32 : atomic<u32>;
+
+@group(0) @binding(1) var<storage, read_write> sg_i32 : atomic<i32>;
+
+fn f() {
+ {
+ let r = atomicSub(&(wg_u32), 123u);
+ }
+ {
+ let r = atomicSub(&(wg_i32), 123i);
+ }
+ {
+ let r = atomicSub(&(sg_u32), 123u);
+ }
+ {
+ let r = atomicSub(&(sg_i32), 123i);
+ }
+}
+)";
+
+ auto got = Run(src);
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvAtomicTest, AtomicMin) {
+ auto* src = R"(
+var<workgroup> wg_u32 : u32;
+var<workgroup> wg_i32 : i32;
+@group(0) @binding(0) var<storage, read_write> sg_u32 : u32;
+@group(0) @binding(1) var<storage, read_write> sg_i32 : i32;
+
+fn f() {
+ {let r = stub_atomicMin_u32(wg_u32, 123u);}
+ {let r = stub_atomicMin_i32(wg_i32, 123i);}
+ {let r = stub_atomicMin_u32(sg_u32, 123u);}
+ {let r = stub_atomicMin_i32(sg_i32, 123i);}
+}
+)";
+
+ auto* expect =
+ R"(
+var<workgroup> wg_u32 : atomic<u32>;
+
+var<workgroup> wg_i32 : atomic<i32>;
+
+@group(0) @binding(0) var<storage, read_write> sg_u32 : atomic<u32>;
+
+@group(0) @binding(1) var<storage, read_write> sg_i32 : atomic<i32>;
+
+fn f() {
+ {
+ let r = atomicMin(&(wg_u32), 123u);
+ }
+ {
+ let r = atomicMin(&(wg_i32), 123i);
+ }
+ {
+ let r = atomicMin(&(sg_u32), 123u);
+ }
+ {
+ let r = atomicMin(&(sg_i32), 123i);
+ }
+}
+)";
+
+ auto got = Run(src);
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvAtomicTest, AtomicMax) {
+ auto* src = R"(
+var<workgroup> wg_u32 : u32;
+var<workgroup> wg_i32 : i32;
+@group(0) @binding(0) var<storage, read_write> sg_u32 : u32;
+@group(0) @binding(1) var<storage, read_write> sg_i32 : i32;
+
+fn f() {
+ {let r = stub_atomicMax_u32(wg_u32, 123u);}
+ {let r = stub_atomicMax_i32(wg_i32, 123i);}
+ {let r = stub_atomicMax_u32(sg_u32, 123u);}
+ {let r = stub_atomicMax_i32(sg_i32, 123i);}
+}
+)";
+
+ auto* expect =
+ R"(
+var<workgroup> wg_u32 : atomic<u32>;
+
+var<workgroup> wg_i32 : atomic<i32>;
+
+@group(0) @binding(0) var<storage, read_write> sg_u32 : atomic<u32>;
+
+@group(0) @binding(1) var<storage, read_write> sg_i32 : atomic<i32>;
+
+fn f() {
+ {
+ let r = atomicMax(&(wg_u32), 123u);
+ }
+ {
+ let r = atomicMax(&(wg_i32), 123i);
+ }
+ {
+ let r = atomicMax(&(sg_u32), 123u);
+ }
+ {
+ let r = atomicMax(&(sg_i32), 123i);
+ }
+}
+)";
+
+ auto got = Run(src);
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvAtomicTest, AtomicAnd) {
+ auto* src = R"(
+var<workgroup> wg_u32 : u32;
+var<workgroup> wg_i32 : i32;
+@group(0) @binding(0) var<storage, read_write> sg_u32 : u32;
+@group(0) @binding(1) var<storage, read_write> sg_i32 : i32;
+
+fn f() {
+ {let r = stub_atomicAnd_u32(wg_u32, 123u);}
+ {let r = stub_atomicAnd_i32(wg_i32, 123i);}
+ {let r = stub_atomicAnd_u32(sg_u32, 123u);}
+ {let r = stub_atomicAnd_i32(sg_i32, 123i);}
+}
+)";
+
+ auto* expect =
+ R"(
+var<workgroup> wg_u32 : atomic<u32>;
+
+var<workgroup> wg_i32 : atomic<i32>;
+
+@group(0) @binding(0) var<storage, read_write> sg_u32 : atomic<u32>;
+
+@group(0) @binding(1) var<storage, read_write> sg_i32 : atomic<i32>;
+
+fn f() {
+ {
+ let r = atomicAnd(&(wg_u32), 123u);
+ }
+ {
+ let r = atomicAnd(&(wg_i32), 123i);
+ }
+ {
+ let r = atomicAnd(&(sg_u32), 123u);
+ }
+ {
+ let r = atomicAnd(&(sg_i32), 123i);
+ }
+}
+)";
+
+ auto got = Run(src);
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvAtomicTest, AtomicOr) {
+ auto* src = R"(
+var<workgroup> wg_u32 : u32;
+var<workgroup> wg_i32 : i32;
+@group(0) @binding(0) var<storage, read_write> sg_u32 : u32;
+@group(0) @binding(1) var<storage, read_write> sg_i32 : i32;
+
+fn f() {
+ {let r = stub_atomicOr_u32(wg_u32, 123u);}
+ {let r = stub_atomicOr_i32(wg_i32, 123i);}
+ {let r = stub_atomicOr_u32(sg_u32, 123u);}
+ {let r = stub_atomicOr_i32(sg_i32, 123i);}
+}
+)";
+
+ auto* expect =
+ R"(
+var<workgroup> wg_u32 : atomic<u32>;
+
+var<workgroup> wg_i32 : atomic<i32>;
+
+@group(0) @binding(0) var<storage, read_write> sg_u32 : atomic<u32>;
+
+@group(0) @binding(1) var<storage, read_write> sg_i32 : atomic<i32>;
+
+fn f() {
+ {
+ let r = atomicOr(&(wg_u32), 123u);
+ }
+ {
+ let r = atomicOr(&(wg_i32), 123i);
+ }
+ {
+ let r = atomicOr(&(sg_u32), 123u);
+ }
+ {
+ let r = atomicOr(&(sg_i32), 123i);
+ }
+}
+)";
+
+ auto got = Run(src);
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvAtomicTest, AtomicXor) {
+ auto* src = R"(
+var<workgroup> wg_u32 : u32;
+var<workgroup> wg_i32 : i32;
+@group(0) @binding(0) var<storage, read_write> sg_u32 : u32;
+@group(0) @binding(1) var<storage, read_write> sg_i32 : i32;
+
+fn f() {
+ {let r = stub_atomicXor_u32(wg_u32, 123u);}
+ {let r = stub_atomicXor_i32(wg_i32, 123i);}
+ {let r = stub_atomicXor_u32(sg_u32, 123u);}
+ {let r = stub_atomicXor_i32(sg_i32, 123i);}
+}
+)";
+
+ auto* expect =
+ R"(
+var<workgroup> wg_u32 : atomic<u32>;
+
+var<workgroup> wg_i32 : atomic<i32>;
+
+@group(0) @binding(0) var<storage, read_write> sg_u32 : atomic<u32>;
+
+@group(0) @binding(1) var<storage, read_write> sg_i32 : atomic<i32>;
+
+fn f() {
+ {
+ let r = atomicXor(&(wg_u32), 123u);
+ }
+ {
+ let r = atomicXor(&(wg_i32), 123i);
+ }
+ {
+ let r = atomicXor(&(sg_u32), 123u);
+ }
+ {
+ let r = atomicXor(&(sg_i32), 123i);
+ }
+}
+)";
+
+ auto got = Run(src);
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvAtomicTest, AtomicCompareExchangeWeak) {
+ auto* src = R"(
+var<workgroup> wg_u32 : u32;
+var<workgroup> wg_i32 : i32;
+@group(0) @binding(0) var<storage, read_write> sg_u32 : u32;
+@group(0) @binding(1) var<storage, read_write> sg_i32 : i32;
+
+fn f() {
+ {let r = stub_atomic_compare_exchange_weak_u32(wg_u32, 123u, 456u);}
+ {let r = stub_atomic_compare_exchange_weak_i32(wg_i32, 123i, 456i);}
+ {let r = stub_atomic_compare_exchange_weak_u32(sg_u32, 123u, 456u);}
+ {let r = stub_atomic_compare_exchange_weak_i32(sg_i32, 123i, 456i);}
+}
+)";
+
+ auto* expect =
+ R"(
+var<workgroup> wg_u32 : atomic<u32>;
+
+var<workgroup> wg_i32 : atomic<i32>;
+
+@group(0) @binding(0) var<storage, read_write> sg_u32 : atomic<u32>;
+
+@group(0) @binding(1) var<storage, read_write> sg_i32 : atomic<i32>;
+
+fn f() {
+ {
+ let old_value = atomicCompareExchangeWeak(&(wg_u32), 123u, 456u).old_value;
+ let r = old_value;
+ }
+ {
+ let old_value_2 = atomicCompareExchangeWeak(&(wg_i32), 123i, 456i).old_value;
+ let r = old_value_2;
+ }
+ {
+ let old_value_1 = atomicCompareExchangeWeak(&(sg_u32), 123u, 456u).old_value;
+ let r = old_value_1;
+ }
+ {
+ let old_value_3 = atomicCompareExchangeWeak(&(sg_i32), 123i, 456i).old_value;
+ let r = old_value_3;
+ }
+}
+)";
+
+ auto got = Run(src);
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvAtomicTest, ReplaceAssignsAndDecls_Scaler) {
+ auto* src = R"(
+var<workgroup> wg : u32;
+
+fn f() {
+ stub_atomicAdd_u32(wg, 1u);
+
+ wg = 0u;
+ let a = wg;
+ var b : u32;
+ b = wg;
+}
+)";
+
+ auto* expect = R"(
+var<workgroup> wg : atomic<u32>;
+
+fn f() {
+ atomicAdd(&(wg), 1u);
+ atomicStore(&(wg), 0u);
+ let a = atomicLoad(&(wg));
+ var b : u32;
+ b = atomicLoad(&(wg));
+}
+)";
+
+ auto got = Run(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvAtomicTest, ReplaceAssignsAndDecls_Struct) {
+ auto* src = R"(
+struct S {
+ a : u32,
+}
+
+var<workgroup> wg : S;
+
+fn f() {
+ stub_atomicAdd_u32(wg.a, 1u);
+
+ wg.a = 0u;
+ let a = wg.a;
+ var b : u32;
+ b = wg.a;
+}
+)";
+
+ auto* expect = R"(
+struct S_atomic {
+ a : atomic<u32>,
+}
+
+struct S {
+ a : u32,
+}
+
+var<workgroup> wg : S_atomic;
+
+fn f() {
+ atomicAdd(&(wg.a), 1u);
+ atomicStore(&(wg.a), 0u);
+ let a = atomicLoad(&(wg.a));
+ var b : u32;
+ b = atomicLoad(&(wg.a));
+}
+)";
+
+ auto got = Run(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvAtomicTest, ReplaceAssignsAndDecls_NestedStruct) {
+ auto* src = R"(
+struct S0 {
+ a : u32,
+}
+
+struct S1 {
+ s0 : S0
+}
+
+var<workgroup> wg : S1;
+
+fn f() {
+ stub_atomicAdd_u32(wg.s0.a, 1u);
+
+ wg.s0.a = 0u;
+ let a = wg.s0.a;
+ var b : u32;
+ b = wg.s0.a;
+}
+)";
+
+ auto* expect = R"(
+struct S0_atomic {
+ a : atomic<u32>,
+}
+
+struct S0 {
+ a : u32,
+}
+
+struct S1_atomic {
+ s0 : S0_atomic,
+}
+
+struct S1 {
+ s0 : S0,
+}
+
+var<workgroup> wg : S1_atomic;
+
+fn f() {
+ atomicAdd(&(wg.s0.a), 1u);
+ atomicStore(&(wg.s0.a), 0u);
+ let a = atomicLoad(&(wg.s0.a));
+ var b : u32;
+ b = atomicLoad(&(wg.s0.a));
+}
+)";
+
+ auto got = Run(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvAtomicTest, ReplaceAssignsAndDecls_StructMultipleAtomics) {
+ auto* src = R"(
+struct S {
+ a : u32,
+ b : u32,
+ c : u32,
+}
+
+var<workgroup> wg : S;
+
+fn f() {
+ stub_atomicAdd_u32(wg.a, 1u);
+ stub_atomicAdd_u32(wg.b, 1u);
+
+ wg.a = 0u;
+ let a = wg.a;
+ var b : u32;
+ b = wg.a;
+
+ wg.b = 0u;
+ let c = wg.b;
+ var d : u32;
+ d = wg.b;
+
+ wg.c = 0u;
+ let e = wg.c;
+ var f : u32;
+ f = wg.c;
+}
+)";
+
+ auto* expect = R"(
+struct S_atomic {
+ a : atomic<u32>,
+ b : atomic<u32>,
+ c : u32,
+}
+
+struct S {
+ a : u32,
+ b : u32,
+ c : u32,
+}
+
+var<workgroup> wg : S_atomic;
+
+fn f() {
+ atomicAdd(&(wg.a), 1u);
+ atomicAdd(&(wg.b), 1u);
+ atomicStore(&(wg.a), 0u);
+ let a = atomicLoad(&(wg.a));
+ var b : u32;
+ b = atomicLoad(&(wg.a));
+ atomicStore(&(wg.b), 0u);
+ let c = atomicLoad(&(wg.b));
+ var d : u32;
+ d = atomicLoad(&(wg.b));
+ wg.c = 0u;
+ let e = wg.c;
+ var f : u32;
+ f = wg.c;
+}
+)";
+
+ auto got = Run(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvAtomicTest, ReplaceAssignsAndDecls_ArrayOfScalar) {
+ auto* src = R"(
+var<workgroup> wg : array<u32, 4>;
+
+fn f() {
+ stub_atomicAdd_u32(wg[1], 1u);
+
+ wg[1] = 0u;
+ let a = wg[1];
+ var b : u32;
+ b = wg[1];
+}
+)";
+
+ auto* expect = R"(
+var<workgroup> wg : array<atomic<u32>, 4u>;
+
+fn f() {
+ atomicAdd(&(wg[1]), 1u);
+ atomicStore(&(wg[1]), 0u);
+ let a = atomicLoad(&(wg[1]));
+ var b : u32;
+ b = atomicLoad(&(wg[1]));
+}
+)";
+
+ auto got = Run(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvAtomicTest, ReplaceAssignsAndDecls_ArrayOfStruct) {
+ auto* src = R"(
+struct S {
+ a : u32,
+}
+
+var<workgroup> wg : array<S, 4>;
+
+fn f() {
+ stub_atomicAdd_u32(wg[1].a, 1u);
+
+ wg[1].a = 0u;
+ let a = wg[1].a;
+ var b : u32;
+ b = wg[1].a;
+}
+)";
+
+ auto* expect = R"(
+struct S_atomic {
+ a : atomic<u32>,
+}
+
+struct S {
+ a : u32,
+}
+
+var<workgroup> wg : array<S_atomic, 4u>;
+
+fn f() {
+ atomicAdd(&(wg[1].a), 1u);
+ atomicStore(&(wg[1].a), 0u);
+ let a = atomicLoad(&(wg[1].a));
+ var b : u32;
+ b = atomicLoad(&(wg[1].a));
+}
+)";
+
+ auto got = Run(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvAtomicTest, ReplaceAssignsAndDecls_StructOfArray) {
+ auto* src = R"(
+struct S {
+ a : array<u32>,
+}
+
+@group(0) @binding(1) var<storage, read_write> s : S;
+
+fn f() {
+ stub_atomicAdd_u32(s.a[4], 1u);
+
+ s.a[4] = 0u;
+ let a = s.a[4];
+ var b : u32;
+ b = s.a[4];
+}
+)";
+
+ auto* expect = R"(
+struct S_atomic {
+ a : array<atomic<u32>>,
+}
+
+struct S {
+ a : array<u32>,
+}
+
+@group(0) @binding(1) var<storage, read_write> s : S_atomic;
+
+fn f() {
+ atomicAdd(&(s.a[4]), 1u);
+ atomicStore(&(s.a[4]), 0u);
+ let a = atomicLoad(&(s.a[4]));
+ var b : u32;
+ b = atomicLoad(&(s.a[4]));
+}
+)";
+
+ auto got = Run(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvAtomicTest, ReplaceAssignsAndDecls_ViaPtrLet) {
+ auto* src = R"(
+struct S {
+ i : u32,
+}
+
+@group(0) @binding(1) var<storage, read_write> s : S;
+
+fn f() {
+ let p0 = &(s);
+ let p1 : ptr<storage, u32, read_write> = &((*(p0)).i);
+ stub_atomicAdd_u32(*p1, 1u);
+
+ *p1 = 0u;
+ let a = *p1;
+ var b : u32;
+ b = *p1;
+}
+)";
+
+ auto* expect = R"(
+struct S_atomic {
+ i : atomic<u32>,
+}
+
+struct S {
+ i : u32,
+}
+
+@group(0) @binding(1) var<storage, read_write> s : S_atomic;
+
+fn f() {
+ let p0 = &(s);
+ let p1 : ptr<storage, atomic<u32>, read_write> = &((*(p0)).i);
+ atomicAdd(&(*(p1)), 1u);
+ atomicStore(&(*(p1)), 0u);
+ let a = atomicLoad(&(*(p1)));
+ var b : u32;
+ b = atomicLoad(&(*(p1)));
+}
+)";
+
+ auto got = Run(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/std140.cc b/src/tint/lang/wgsl/ast/transform/std140.cc
new file mode 100644
index 0000000..3b12c15
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/std140.cc
@@ -0,0 +1,1169 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/std140.h"
+
+#include <algorithm>
+#include <string>
+#include <utility>
+#include <variant>
+
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/index_accessor_expression.h"
+#include "src/tint/sem/member_accessor_expression.h"
+#include "src/tint/sem/module.h"
+#include "src/tint/sem/struct.h"
+#include "src/tint/sem/variable.h"
+#include "src/tint/switch.h"
+#include "src/tint/utils/compiler_macros.h"
+#include "src/tint/utils/hashmap.h"
+#include "src/tint/utils/transform.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::Std140);
+
+using namespace tint::number_suffixes; // NOLINT
+
+namespace {
+
+/// UniformVariable is used by Std140::State::AccessIndex to indicate the root uniform variable
+struct UniformVariable {};
+
+/// Inequality operator for UniformVariable. Always false: the struct holds no state to compare.
+bool operator!=(const UniformVariable&, const UniformVariable&) {
+ return false;
+}
+
+/// DynamicIndex is used by Std140::State::AccessIndex to indicate a runtime-expression index
+struct DynamicIndex {
+ size_t slot; // The index of the expression in Std140::State::AccessChain::dynamic_indices
+};
+
+/// Inequality operator for DynamicIndex. Two indices are unequal when their slots differ.
+bool operator!=(const DynamicIndex& a, const DynamicIndex& b) {
+ return a.slot != b.slot;
+}
+
+} // namespace
+
+namespace tint::utils {
+
+/// Hasher specialization for UniformVariable
+template <>
+struct Hasher<UniformVariable> {
+ /// The hash function for the UniformVariable
+ /// @return the hash for the given UniformVariable, which is always 0
+ size_t operator()(const UniformVariable&) const { return 0; }
+};
+
+/// Hasher specialization for DynamicIndex
+template <>
+struct Hasher<DynamicIndex> {
+ /// The hash function for the DynamicIndex
+ /// @param d the DynamicIndex to hash
+ /// @return the hash for the given DynamicIndex
+ size_t operator()(const DynamicIndex& d) const { return utils::Hash(d.slot); }
+};
+
+} // namespace tint::utils
+
+namespace tint::ast::transform {
+
+/// PIMPL state for the transform
+struct Std140::State {
+ /// Constructor
+ /// @param program the source program
+ explicit State(const Program* program) : src(program) {}
+
+ /// Runs the transform
+ /// @returns the new program or SkipTransform if the transform is not required
+ ApplyResult Run() {
+ if (!ShouldRun()) {
+ // Transform is not required
+ return SkipTransform;
+ }
+
+ // Begin by creating forked types for any type that is used as a uniform buffer, that
+ // either directly or transitively contains a matrix that needs splitting for std140 layout.
+ ForkTypes();
+
+ // Next, replace all the uniform variables to use the forked types.
+ ReplaceUniformVarTypes();
+
+ // Finally, replace all expression chains that used the authored types with those that
+ // correctly use the forked types.
+ ctx.ReplaceAll([&](const Expression* expr) -> const Expression* {
+ if (auto access = AccessChainFor(expr)) {
+ if (!access->std140_mat_idx.has_value()) {
+ // loading a std140 type, which is not a whole or partial decomposed matrix
+ return LoadWithConvert(access.value());
+ }
+ if (!access->IsMatrixSubset() || // loading a whole matrix
+ std::holds_alternative<DynamicIndex>(
+ access->indices[*access->std140_mat_idx + 1])) {
+ // Whole object or matrix is loaded, or the matrix column is indexed with a
+ // non-constant index. Build a helper function to load the expression chain.
+ return LoadMatrixWithFn(access.value());
+ }
+ // Matrix column is statically indexed. Can be emitted as an inline expression.
+ return LoadSubMatrixInline(access.value());
+ }
+ // Expression isn't an access to a std140-layout uniform buffer.
+ // Just clone.
+ return nullptr;
+ });
+
+ ctx.Clone();
+ return Program(std::move(b));
+ }
+
+ /// @returns true if this transform should be run for the given program
+ bool ShouldRun() const {
+ // Returns true if the type needs to be forked for std140 usage.
+ auto needs_fork = [&](const type::Type* ty) {
+ while (auto* arr = ty->As<type::Array>()) {
+ ty = arr->ElemType();
+ }
+ if (auto* mat = ty->As<type::Matrix>()) {
+ if (MatrixNeedsDecomposing(mat)) {
+ return true;
+ }
+ }
+ return false;
+ };
+
+ // Scan structures for members that need forking
+ for (auto* ty : src->Types()) {
+ if (auto* str = ty->As<type::Struct>()) {
+ if (str->UsedAs(builtin::AddressSpace::kUniform)) {
+ for (auto* member : str->Members()) {
+ if (needs_fork(member->Type())) {
+ return true;
+ }
+ }
+ }
+ }
+ }
+
+ // Scan uniform variables that have types that need forking
+ for (auto* decl : src->AST().GlobalVariables()) {
+ auto* global = src->Sem().Get(decl);
+ if (global->AddressSpace() == builtin::AddressSpace::kUniform) {
+ if (needs_fork(global->Type()->UnwrapRef())) {
+ return true;
+ }
+ }
+ }
+
+ // If we reach here, no uniform variables use a type that needs forking for std140 layout
+ return false;
+ }
+
+ private:
+ /// Swizzle describes a vector swizzle
+ using Swizzle = utils::Vector<uint32_t, 4>;
+
+ /// AccessIndex describes a single access in an access chain.
+ /// The access is one of:
+ /// UniformVariable - the root uniform variable.
+ /// u32 - a static index on a struct, array index, matrix column or vector element.
+ /// DynamicIndex - a runtime index on an array, matrix column, or vector element.
+ /// Swizzle - a static vector swizzle.
+ using AccessIndex = std::variant<UniformVariable, u32, DynamicIndex, Swizzle>;
+
+ /// A vector of AccessIndex.
+ using AccessIndices = utils::Vector<AccessIndex, 8>;
+
+ /// A key used to cache load functions for an access chain.
+ struct LoadFnKey {
+ /// The root uniform buffer variable for the access chain.
+ const sem::GlobalVariable* var;
+
+ /// The chain of access indices.
+ AccessIndices indices;
+
+ /// Hash function for LoadFnKey.
+ struct Hasher {
+ /// @param fn the LoadFnKey to hash
+ /// @return the hash for the given LoadFnKey
+ size_t operator()(const LoadFnKey& fn) const { return utils::Hash(fn.var, fn.indices); }
+ };
+
+ /// Equality operator. Keys are equal when both the variable and the indices are equal.
+ bool operator==(const LoadFnKey& other) const {
+ return var == other.var && indices == other.indices;
+ }
+ };
+
+ /// The source program
+ const Program* const src;
+ /// The target program builder
+ ProgramBuilder b;
+ /// The clone context
+ CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
+ /// Alias to the semantic info in src
+ const sem::Info& sem = src->Sem();
+ /// Alias to the symbols in src
+ const SymbolTable& sym = src->Symbols();
+
+ /// Map of load function signature, to the generated function
+ utils::Hashmap<LoadFnKey, Symbol, 8, LoadFnKey::Hasher> load_fns;
+
+ /// Map of std140-forked type to converter function name
+ utils::Hashmap<const type::Type*, Symbol, 8> conv_fns;
+
+ /// Uniform variables that have been modified to use a std140 type
+ utils::Hashset<const sem::Variable*, 8> std140_uniforms;
+
+ /// Map of original structure to the name of its 'std140' forked structure
+ utils::Hashmap<const type::Struct*, Symbol, 8> std140_structs;
+
+ /// Map of structure member in src of a matrix type, to list of decomposed column
+ /// members in ctx.dst.
+ utils::Hashmap<const type::StructMember*, utils::Vector<const StructMember*, 4>, 8>
+ std140_mat_members;
+
+ /// Describes a matrix that has been forked to a std140-structure holding the decomposed column
+ /// vectors of the matrix.
+ struct Std140Matrix {
+ /// The decomposed structure name (in ctx.dst)
+ Symbol name;
+ /// The column vector structure member names (in ctx.dst)
+ utils::Vector<Symbol, 4> columns;
+ };
+
+ /// Map of matrix type in src, to decomposed column structure in ctx.dst.
+ utils::Hashmap<const type::Matrix*, Std140Matrix, 8> std140_mats;
+
+ /// AccessChain describes a chain of access expressions to a uniform buffer variable.
+ struct AccessChain {
+ /// The uniform buffer variable.
+ const sem::GlobalVariable* var;
+ /// The chain of access indices, starting with the first access on #var.
+ AccessIndices indices;
+ /// The runtime-evaluated expressions. This vector is indexed by the DynamicIndex::slot
+ utils::Vector<const sem::ValueExpression*, 8> dynamic_indices;
+ /// The type of the std140-decomposed matrix being accessed.
+ /// May be nullptr if the chain does not pass through a std140-decomposed matrix.
+ const type::Matrix* std140_mat_ty = nullptr;
+ /// The index in #indices of the access that resolves to the std140-decomposed matrix.
+ /// May hold no value if the chain does not pass through a std140-decomposed matrix.
+ std::optional<size_t> std140_mat_idx;
+
+ /// @returns true if the access chain is to part of (not the whole) std140-decomposed matrix
+ bool IsMatrixSubset() const {
+ return std140_mat_idx.has_value() && (std140_mat_idx.value() + 1 != indices.Length());
+ }
+ };
+
+ /// @returns true if the given matrix needs decomposing to column vectors for std140 layout.
+ /// Std140 layout requires the matrix column stride to be 16, otherwise decomposing is needed.
+ static bool MatrixNeedsDecomposing(const type::Matrix* mat) {
+ return mat->ColumnStride() != 16;
+ }
+
+ /// ForkTypes walks the user-declared types in dependency order, forking structures that are
+ /// used as uniform buffers which (transitively) use matrices that need std140 decomposition to
+ /// column vectors. Populates the #std140_mat_members map, #std140_structs map and #std140_mats
+ /// map (via Std140Type()).
+ void ForkTypes() {
+ // For each module scope declaration...
+ for (auto* global : src->Sem().Module()->DependencyOrderedDeclarations()) {
+ // Check to see if this is a structure used by a uniform buffer...
+ auto* str = sem.Get<sem::Struct>(global);
+ if (str && str->UsedAs(builtin::AddressSpace::kUniform)) {
+ // Should this uniform buffer be forked for std140 usage?
+ bool fork_std140 = false;
+ utils::Vector<const StructMember*, 8> members;
+ for (auto* member : str->Members()) {
+ if (auto* mat = member->Type()->As<type::Matrix>()) {
+ // Is this member a matrix that needs decomposition for std140-layout?
+ if (MatrixNeedsDecomposing(mat)) {
+ // Structure member of matrix type needs decomposition.
+ fork_std140 = true;
+ // Replace the member with column vectors.
+ const auto name_prefix = PrefixForUniqueNames(
+ str->Declaration(), member->Name(), mat->columns());
+
+ // Build a struct member for each column of the matrix
+ auto column_members = DecomposedMatrixStructMembers(
+ mat, name_prefix, member->Align(), member->Size());
+
+ // Add the member to the forked structure
+ for (auto* column_member : column_members) {
+ members.Push(column_member);
+ }
+ // Record that this matrix member was replaced with the N column
+ // members.
+ std140_mat_members.Add(member, std::move(column_members));
+
+ continue; // Next member
+ }
+ } else if (auto std140_ty = Std140Type(member->Type())) {
+ // Member is of a type that requires forking for std140-layout
+ fork_std140 = true;
+ auto attrs = ctx.Clone(member->Declaration()->attributes);
+ members.Push(b.Member(member->Name().Name(), std140_ty, std::move(attrs)));
+ continue; // Next member
+ }
+
+ // Nothing special about this member.
+ // Push the member in src to members without first cloning. We'll replace this
+ // with a cloned member once we know whether we need to fork the structure or
+ // not.
+ members.Push(member->Declaration());
+ }
+
+ // Did any of the members require forking the structure?
+ if (fork_std140) {
+ // Clone any members that have not already been cloned.
+ for (auto& member : members) {
+ if (member->program_id == src->ID()) {
+ member = ctx.Clone(member);
+ }
+ }
+ // Create a new forked structure, and insert it just under the original
+ // structure.
+ auto name = b.Symbols().New(str->Name().Name() + "_std140");
+ auto* std140 = b.create<Struct>(b.Ident(name), std::move(members),
+ ctx.Clone(str->Declaration()->attributes));
+ ctx.InsertAfter(src->AST().GlobalDeclarations(), global, std140);
+ std140_structs.Add(str, name);
+ }
+ }
+ }
+ }
+
+ /// Walks the global variables, replacing the type of those that are a uniform buffer with a
+ /// type that has been forked for std140-layout.
+ /// Populates the #std140_uniforms set.
+ void ReplaceUniformVarTypes() {
+ for (auto* global : src->AST().GlobalVariables()) {
+ if (auto* var = global->As<Var>()) {
+ auto* v = sem.Get(var);
+ if (v->AddressSpace() == builtin::AddressSpace::kUniform) {
+ if (auto std140_ty = Std140Type(v->Type()->UnwrapRef())) {
+ ctx.Replace(global->type.expr, b.Expr(std140_ty));
+ std140_uniforms.Add(v);
+ }
+ }
+ }
+ }
+ }
+
+ /// @returns a unique structure member prefix for the splitting of a matrix member into @p count
+ /// column vector members. The new members must be suffixed with a zero-based index ranging from
+ /// `[0..count)`.
+ /// @param str the structure that will hold the uniquely named member.
+ /// @param unsuffixed the common name prefix to use for the new members.
+ /// @param count the number of members that need to be created.
+ std::string PrefixForUniqueNames(const Struct* str, Symbol unsuffixed, uint32_t count) const {
+ auto prefix = unsuffixed.Name();
+ // Keep on inserting '_' between the unsuffixed name and the suffix numbers until the name
+ // is unique.
+ while (true) {
+ prefix += "_";
+
+ utils::Hashset<std::string, 4> strings;
+ for (uint32_t i = 0; i < count; i++) {
+ strings.Add(prefix + std::to_string(i));
+ }
+
+ bool unique = true;
+ for (auto* member : str->members) {
+ // The member name must be unique over the entire set of `count` suffixed names.
+ if (strings.Contains(member->name->symbol.Name())) {
+ unique = false;
+ break;
+ }
+ }
+
+ if (unique) {
+ return prefix;
+ }
+ }
+ }
+
+ /// @returns a new, forked std140 AST type for the corresponding non-forked semantic type.
+ /// If the semantic type is not split for std140-layout, then an empty Type{} is returned.
+ /// @note will construct new std140 structures to hold decomposed matrices, populating
+ /// #std140_mats.
+ Type Std140Type(const type::Type* ty) {
+ return Switch(
+ ty, //
+ [&](const type::Struct* str) {
+ if (auto std140 = std140_structs.Find(str)) {
+ return b.ty(*std140);
+ }
+ return Type{};
+ },
+ [&](const type::Matrix* mat) {
+ if (MatrixNeedsDecomposing(mat)) {
+ auto std140_mat = std140_mats.GetOrCreate(mat, [&] {
+ auto name = b.Symbols().New("mat" + std::to_string(mat->columns()) + "x" +
+ std::to_string(mat->rows()) + "_" +
+ mat->type()->FriendlyName());
+ auto members =
+ DecomposedMatrixStructMembers(mat, "col", mat->Align(), mat->Size());
+ b.Structure(name, members);
+ return Std140Matrix{
+ name,
+ utils::Transform(members,
+ [&](auto* member) { return member->name->symbol; }),
+ };
+ });
+ return b.ty(std140_mat.name);
+ }
+ return Type{};
+ },
+ [&](const type::Array* arr) {
+ if (auto std140 = Std140Type(arr->ElemType())) {
+ utils::Vector<const Attribute*, 1> attrs;
+ if (!arr->IsStrideImplicit()) {
+ attrs.Push(b.create<StrideAttribute>(arr->Stride()));
+ }
+ auto count = arr->ConstantCount();
+ if (TINT_UNLIKELY(!count)) {
+ // Non-constant counts should not be possible:
+ // * Override-expression counts can only be applied to workgroup arrays, and
+ // this method only handles types transitively used as uniform buffers.
+ // * Runtime-sized arrays cannot be used in uniform buffers.
+ TINT_ICE(Transform, b.Diagnostics())
+ << "unexpected non-constant array count";
+ count = 1;
+ }
+ return b.ty.array(std140, b.Expr(u32(count.value())), std::move(attrs));
+ }
+ return Type{};
+ });
+ }
+
+ /// @param mat the matrix to decompose (in src)
+ /// @param name_prefix the name prefix to apply to each of the returned column vector members.
+ /// @param align the alignment in bytes of the matrix.
+ /// @param size the size in bytes of the matrix.
+ /// @returns a vector of decomposed matrix column vectors as structure members (in ctx.dst).
+ utils::Vector<const StructMember*, 4> DecomposedMatrixStructMembers(
+ const type::Matrix* mat,
+ const std::string& name_prefix,
+ uint32_t align,
+ uint32_t size) {
+ // One structure member is emitted per matrix column.
+ const auto num_columns = mat->columns();
+ // Build a struct member for each column of the matrix
+ utils::Vector<const StructMember*, 4> out;
+ for (uint32_t i = 0; i < num_columns; i++) {
+ utils::Vector<const Attribute*, 1> attributes;
+ if ((i == 0) && mat->Align() != align) {
+ // The matrix was @align() annotated with a larger alignment
+ // than the natural alignment for the matrix. This extra padding
+ // needs to be applied to the first column vector.
+ attributes.Push(b.MemberAlign(i32(align)));
+ }
+ if ((i == num_columns - 1) && mat->Size() != size) {
+ // The matrix was @size() annotated with a larger size than the
+ // natural size for the matrix. This extra padding needs to be
+ // applied to the last column vector.
+ attributes.Push(
+ b.MemberSize(AInt(size - mat->ColumnType()->Align() * (num_columns - 1))));
+ }
+
+ // Build the member
+ const auto col_name = name_prefix + std::to_string(i);
+ const auto col_ty = CreateASTTypeFor(ctx, mat->ColumnType());
+ const auto* col_member = b.Member(col_name, col_ty, std::move(attributes));
+ // Collect the column member for the returned vector
+ out.Push(col_member);
+ }
+ return out;
+ }
+
+ /// Walks the @p ast_expr from its outer-most access inwards, building an AccessChain.
+ /// @returns an AccessChain if the expression is an access to a std140-forked uniform buffer,
+ /// otherwise returns std::nullopt (also when @p ast_expr has no semantic value, or on error).
+ std::optional<AccessChain> AccessChainFor(const Expression* ast_expr) {
+ auto* expr = sem.GetVal(ast_expr);
+ if (!expr) {
+ return std::nullopt;
+ }
+
+ AccessChain access;
+
+ // Start by looking at the root identifier. This must be a std140-forked uniform buffer.
+ access.var = tint::As<sem::GlobalVariable>(expr->RootIdentifier());
+ if (!access.var || !std140_uniforms.Contains(access.var)) {
+ // Not a std140-forked uniform buffer access chain.
+ return std::nullopt;
+ }
+
+ // Walk from the outer-most expression, inwards towards the root identifier.
+ while (true) {
+ enum class Action { kStop, kContinue, kError };
+ Action action = Switch(
+ expr->Unwrap(), //
+ [&](const sem::VariableUser* user) {
+ if (user->Variable() == access.var) {
+ // Walked all the way to the root identifier. We're done traversing.
+ access.indices.Push(UniformVariable{});
+ return Action::kStop;
+ }
+ if (TINT_LIKELY(user->Variable()->Type()->Is<type::Pointer>())) {
+ // Found a pointer. As the root identifier is a uniform buffer variable,
+ // this must be a pointer-let. Continue traversing from the let
+ // initializer.
+ expr = user->Variable()->Initializer();
+ return Action::kContinue;
+ }
+ TINT_ICE(Transform, b.Diagnostics())
+ << "unexpected variable found walking access chain: "
+ << user->Variable()->Declaration()->name->symbol.Name();
+ return Action::kError;
+ },
+ [&](const sem::StructMemberAccess* a) {
+ // Is this a std140 decomposed matrix?
+ if (std140_mat_members.Contains(a->Member())) {
+ // Record this on the access.
+ access.std140_mat_idx = access.indices.Length();
+ access.std140_mat_ty = expr->Type()->UnwrapRef()->As<type::Matrix>();
+ }
+ // Structure member accesses are always statically indexed.
+ access.indices.Push(u32(a->Member()->Index()));
+ expr = a->Object();
+ return Action::kContinue;
+ },
+ [&](const sem::IndexAccessorExpression* a) {
+ // Array, matrix or vector index. Constant indices are stored by value, others by slot.
+ if (auto* val = a->Index()->ConstantValue()) {
+ access.indices.Push(val->ValueAs<u32>());
+ } else {
+ access.indices.Push(DynamicIndex{access.dynamic_indices.Length()});
+ access.dynamic_indices.Push(a->Index());
+ }
+ expr = a->Object();
+
+ // Is the object a std140 decomposed matrix?
+ if (auto* mat = expr->Type()->UnwrapRef()->As<type::Matrix>()) {
+ if (std140_mats.Contains(mat)) {
+ // Record this on the access.
+ access.std140_mat_idx = access.indices.Length();
+ access.std140_mat_ty = mat;
+ }
+ }
+ return Action::kContinue;
+ },
+ [&](const sem::Swizzle* s) {
+ // Vector swizzle. A single-element swizzle is recorded as a static u32 index.
+ if (s->Indices().Length() == 1) {
+ access.indices.Push(u32(s->Indices()[0]));
+ } else {
+ access.indices.Push(s->Indices());
+ }
+ expr = s->Object();
+ return Action::kContinue;
+ },
+ [&](const sem::ValueExpression* e) {
+ // Walk past indirection and address-of unary ops.
+ return Switch(e->Declaration(), //
+ [&](const UnaryOpExpression* u) {
+ switch (u->op) {
+ case UnaryOp::kAddressOf:
+ case UnaryOp::kIndirection:
+ expr = sem.GetVal(u->expr);
+ return Action::kContinue;
+ default:
+ TINT_ICE(Transform, b.Diagnostics())
+ << "unhandled unary op for access chain: "
+ << u->op;
+ return Action::kError;
+ }
+ });
+ },
+ [&](Default) {
+ TINT_ICE(Transform, b.Diagnostics())
+ << "unhandled expression type for access chain\n"
+ << "AST: " << expr->Declaration()->TypeInfo().name << "\n"
+ << "SEM: " << expr->TypeInfo().name;
+ return Action::kError;
+ });
+
+ switch (action) {
+ case Action::kContinue:
+ continue;
+ case Action::kStop:
+ break;
+ case Action::kError:
+ return std::nullopt;
+ }
+
+ break;
+ }
+
+ // As the access walked from RHS to LHS, the last index operation applies to the source
+ // variable. We want this the other way around, so reverse the arrays and fix indices.
+ std::reverse(access.indices.begin(), access.indices.end());
+ std::reverse(access.dynamic_indices.begin(), access.dynamic_indices.end());
+ if (access.std140_mat_idx.has_value()) {
+ access.std140_mat_idx = access.indices.Length() - *access.std140_mat_idx - 1;
+ }
+ for (auto& index : access.indices) {
+ if (auto* dyn_idx = std::get_if<DynamicIndex>(&index)) {
+ dyn_idx->slot = access.dynamic_indices.Length() - dyn_idx->slot - 1;
+ }
+ }
+
+ return access;
+ }
+
+ /// @returns a name suffix for a std140 -> non-std140 conversion function based on the type
+ /// being converted: the struct name, "arr<N>_<elem>", "mat<C>x<R>_<elem>", "f32" or "f16".
+ const std::string ConvertSuffix(const type::Type* ty) {
+ return Switch(
+ ty, //
+ [&](const type::Struct* str) { return str->Name().Name(); },
+ [&](const type::Array* arr) {
+ auto count = arr->ConstantCount();
+ if (TINT_UNLIKELY(!count)) {
+ // Non-constant counts should not be possible:
+ // * Override-expression counts can only be applied to workgroup arrays, and
+ // this method only handles types transitively used as uniform buffers.
+ // * Runtime-sized arrays cannot be used in uniform buffers.
+ TINT_ICE(Transform, b.Diagnostics()) << "unexpected non-constant array count";
+ count = 1;
+ }
+ return "arr" + std::to_string(count.value()) + "_" + ConvertSuffix(arr->ElemType());
+ },
+ [&](const type::Matrix* mat) {
+ return "mat" + std::to_string(mat->columns()) + "x" + std::to_string(mat->rows()) +
+ "_" + ConvertSuffix(mat->type());
+ },
+ [&](const type::F32*) { return "f32"; }, //
+ [&](const type::F16*) { return "f16"; },
+ [&](Default) {
+ TINT_ICE(Transform, b.Diagnostics())
+ << "unhandled type for conversion name: " << ty->FriendlyName();
+ return "";
+ });
+ }
+
+ /// Generates and returns an expression that loads the value from a std140 uniform buffer by
+ /// applying each index of the chain in order, then converting the result via Convert().
+ /// @param chain the access chain from a uniform buffer to the value to load.
+ const Expression* LoadWithConvert(const AccessChain& chain) {
+ const Expression* expr = nullptr;
+ const type::Type* ty = nullptr;
+ auto dynamic_index = [&](size_t idx) {
+ return ctx.Clone(chain.dynamic_indices[idx]->Declaration());
+ };
+ for (size_t i = 0; i < chain.indices.Length(); i++) {
+ auto [new_expr, new_ty, _] = BuildAccessExpr(expr, ty, chain, i, dynamic_index);
+ expr = new_expr;
+ ty = new_ty;
+ }
+ return Convert(ty, expr);
+ }
+
+ /// Generates and returns an expression that converts the expression @p expr of the
+ /// std140-forked type to the type @p ty. If @p expr is not a std140-forked type, then Convert()
+ /// will simply return @p expr.
+ /// @returns the converted value expression.
+ const Expression* Convert(const type::Type* ty, const Expression* expr) {
+ // Get an existing, or create a new function for converting the std140 type to ty.
+ auto fn = conv_fns.GetOrCreate(ty, [&] {
+ auto std140_ty = Std140Type(ty);
+ if (!std140_ty) {
+ // ty was not forked for std140. Return an invalid Symbol, which is checked for below.
+ return Symbol{};
+ }
+
+ // The converter function takes a single argument of the std140 type.
+ auto* param = b.Param("val", std140_ty);
+
+ utils::Vector<const Statement*, 3> stmts;
+
+ Switch(
+ ty, //
+ [&](const type::Struct* str) {
+ // Convert each of the structure members using either a converter function
+ // call, or by reassembling a std140 matrix from column vector members.
+ utils::Vector<const Expression*, 8> args;
+ for (auto* member : str->Members()) {
+ if (auto col_members = std140_mat_members.Find(member)) {
+ // std140 decomposed matrix. Reassemble.
+ auto mat_ty = CreateASTTypeFor(ctx, member->Type());
+ auto mat_args =
+ utils::Transform(*col_members, [&](const StructMember* m) {
+ return b.MemberAccessor(param, m->name->symbol);
+ });
+ args.Push(b.Call(mat_ty, std::move(mat_args)));
+ } else {
+ // Convert the member (Convert() returns it unchanged if it was not forked).
+ args.Push(Convert(member->Type(),
+ b.MemberAccessor(param, member->Name().Name())));
+ }
+ }
+ stmts.Push(b.Return(b.Call(CreateASTTypeFor(ctx, ty), std::move(args))));
+ }, //
+ [&](const type::Matrix* mat) {
+ // Reassemble a std140 matrix from the structure of column vector members.
+ auto std140_mat = std140_mats.Get(mat);
+ if (TINT_LIKELY(std140_mat)) {
+ utils::Vector<const Expression*, 8> args;
+ // std140 decomposed matrix. Reassemble. NOTE(review): `args` above is unused here.
+ auto mat_ty = CreateASTTypeFor(ctx, mat);
+ auto mat_args = utils::Transform(std140_mat->columns, [&](Symbol name) {
+ return b.MemberAccessor(param, name);
+ });
+ stmts.Push(b.Return(b.Call(mat_ty, std::move(mat_args))));
+ } else {
+ TINT_ICE(Transform, b.Diagnostics())
+ << "failed to find std140 matrix info for: " << ty->FriendlyName();
+ }
+ }, //
+ [&](const type::Array* arr) {
+ // Converting an array. Create a function var for the converted array, and
+ // loop over the input elements, converting each and assigning the result to
+ // the local array.
+ auto* var = b.Var("arr", CreateASTTypeFor(ctx, ty));
+ auto* i = b.Var("i", b.ty.u32());
+ auto* dst_el = b.IndexAccessor(var, i);
+ auto* src_el = Convert(arr->ElemType(), b.IndexAccessor(param, i));
+ auto count = arr->ConstantCount();
+ if (TINT_UNLIKELY(!count)) {
+ // Non-constant counts should not be possible:
+ // * Override-expression counts can only be applied to workgroup arrays, and
+ // this method only handles types transitively used as uniform buffers.
+ // * Runtime-sized arrays cannot be used in uniform buffers.
+ TINT_ICE(Transform, b.Diagnostics())
+ << "unexpected non-constant array count";
+ count = 1;
+ }
+ stmts.Push(b.Decl(var));
+ stmts.Push(b.For(b.Decl(i), //
+ b.LessThan(i, u32(count.value())), //
+ b.Assign(i, b.Add(i, 1_a)), //
+ b.Block(b.Assign(dst_el, src_el))));
+ stmts.Push(b.Return(var));
+ },
+ [&](Default) {
+ TINT_ICE(Transform, b.Diagnostics())
+ << "unhandled type for conversion: " << ty->FriendlyName();
+ });
+
+ // Generate the function, named "conv_" followed by ConvertSuffix(ty).
+ auto ret_ty = CreateASTTypeFor(ctx, ty);
+ auto fn_sym = b.Symbols().New("conv_" + ConvertSuffix(ty));
+ b.Func(fn_sym, utils::Vector{param}, ret_ty, std::move(stmts));
+ return fn_sym;
+ });
+
+ if (!fn.IsValid()) {
+ // Not a std140 type, nothing to convert.
+ return expr;
+ }
+
+ // Call the helper, passing the std140 value as the single argument.
+ return b.Call(fn, utils::Vector{expr});
+ }
+
+ /// Loads a part of, or a whole std140-decomposed matrix from a uniform buffer, using a helper
+ /// function which will be generated if it hasn't been already.
+ /// @param access the access chain from the uniform buffer to either the whole matrix or part of
+ /// the matrix (column, column-swizzle, or element).
+ /// @returns the loaded value expression.
+ const Expression* LoadMatrixWithFn(const AccessChain& access) {
+ // Get an existing, or create a new function for loading the uniform buffer value.
+ // This function is keyed off the uniform buffer variable and the access chain indices.
+ auto fn = load_fns.GetOrCreate(LoadFnKey{access.var, access.indices}, [&] {
+ if (access.IsMatrixSubset()) {
+ // Access chain passes through the matrix, but ends either at a column vector,
+ // column swizzle, or element.
+ return BuildLoadPartialMatrixFn(access);
+ }
+ // Access is to the whole matrix.
+ return BuildLoadWholeMatrixFn(access);
+ });
+
+ // Build the arguments: each dynamic index expression is cloned and wrapped in a u32 call.
+ auto args = utils::Transform(access.dynamic_indices, [&](const sem::ValueExpression* e) {
+ return b.Call<u32>(ctx.Clone(e->Declaration()));
+ });
+
+ // Call the helper with the dynamic indices as arguments.
+ return b.Call(fn, std::move(args));
+ }
+
+ /// Loads a part of a std140-decomposed matrix from a uniform buffer, inline (without calling a
+ /// helper function).
+ /// @param chain the access chain from the uniform buffer to part of the matrix (column,
+ /// column-swizzle, or element).
+ /// @note The matrix column must be statically indexed to use this method.
+ /// @returns the loaded value expression.
+ const Expression* LoadSubMatrixInline(const AccessChain& chain) {
+ // Method for generating dynamic index expressions.
+ // As this is inline, we can just clone the expression.
+ auto dynamic_index = [&](size_t idx) {
+ return ctx.Clone(chain.dynamic_indices[idx]->Declaration());
+ };
+
+ const Expression* expr = nullptr;
+ const type::Type* ty = nullptr;
+
+ // Build the expression up to, but not including the matrix member.
+ auto std140_mat_idx = *chain.std140_mat_idx;
+ for (size_t i = 0; i < std140_mat_idx; i++) {
+ auto [new_expr, new_ty, _] = BuildAccessExpr(expr, ty, chain, i, dynamic_index);
+ expr = new_expr;
+ ty = new_ty;
+ }
+
+ // Access is to the std140 decomposed matrix.
+ // As this is accessing only part of the matrix, we just need to pick the right column
+ // vector member. The column index is read as a static u32 from the chain.
+ auto column_idx = std::get<u32>(chain.indices[std140_mat_idx + 1]);
+ if (auto* str = tint::As<type::Struct>(ty)) {
+ // Structure member matrix. The columns are decomposed into the structure.
+ auto mat_member_idx = std::get<u32>(chain.indices[std140_mat_idx]);
+ auto* mat_member = str->Members()[mat_member_idx];
+ auto mat_columns = *std140_mat_members.Get(mat_member);
+ expr = b.MemberAccessor(expr, mat_columns[column_idx]->name->symbol);
+ ty = mat_member->Type()->As<type::Matrix>()->ColumnType();
+ } else {
+ // Non-structure-member matrix. The columns are decomposed into a new, bespoke std140
+ // structure.
+ auto [new_expr, new_ty, _] =
+ BuildAccessExpr(expr, ty, chain, std140_mat_idx, dynamic_index);
+ expr = new_expr;
+ ty = new_ty;
+ auto* mat = ty->As<type::Matrix>();
+ auto std140_mat = std140_mats.Get(ty->As<type::Matrix>());
+ expr = b.MemberAccessor(expr, std140_mat->columns[column_idx]);
+ ty = mat->ColumnType();
+ }
+
+ // Build any remaining accesses into the column (the indices after the column index).
+ for (size_t i = std140_mat_idx + 2; i < chain.indices.Length(); i++) {
+ auto [new_expr, new_ty, _] = BuildAccessExpr(expr, ty, chain, i, dynamic_index);
+ expr = new_expr;
+ ty = new_ty;
+ }
+ return expr;
+ }
+
+ /// Generates a function to load part of a std140-decomposed matrix from a uniform buffer.
+ /// The generated function will have a parameter per dynamic (runtime-evaluated) index in the
+ /// access chain.
+ /// The generated function uses a WGSL switch statement to dynamically select the decomposed
+ /// matrix column.
+ /// @param chain the access chain from the uniform buffer to part of the matrix (column,
+ /// column-swizzle, or element).
+ /// @note The matrix column must be dynamically indexed to use this method.
+ /// @returns the generated function name.
+ Symbol BuildLoadPartialMatrixFn(const AccessChain& chain) {
+ // Build the dynamic index parameters: one u32 parameter named "p<i>" per dynamic index.
+ auto dynamic_index_params = utils::Transform(chain.dynamic_indices, [&](auto*, size_t i) {
+ return b.Param("p" + std::to_string(i), b.ty.u32());
+ });
+ // Method for generating dynamic index expressions.
+ // These are passed in as arguments to the function.
+ auto dynamic_index = [&](size_t idx) {
+ return b.Expr(dynamic_index_params[idx]->name->symbol);
+ };
+
+ // Fetch the access chain indices of the matrix access and the parameter index that
+ // holds the matrix column index.
+ auto std140_mat_idx = *chain.std140_mat_idx;
+ auto column_param_idx = std::get<DynamicIndex>(chain.indices[std140_mat_idx + 1]).slot;
+
+ // Begin building the function name. This is extended with logic in the loop below
+ // (when column_idx == 0).
+ std::string name = "load";
+
+ // The switch cases: one per matrix column, plus a default case appended after the loop.
+ utils::Vector<const CaseStatement*, 4> cases;
+
+ // The function return type, taken from the expression built for column 0.
+ const type::Type* ret_ty = nullptr;
+
+ // Build switch() cases for each column of the matrix
+ auto num_columns = chain.std140_mat_ty->columns();
+ for (uint32_t column_idx = 0; column_idx < num_columns; column_idx++) {
+ const Expression* expr = nullptr;
+ const type::Type* ty = nullptr;
+
+ // Build the expression up to, but not including the matrix
+ for (size_t i = 0; i < std140_mat_idx; i++) {
+ auto [new_expr, new_ty, access_name] =
+ BuildAccessExpr(expr, ty, chain, i, dynamic_index);
+ expr = new_expr;
+ ty = new_ty;
+ if (column_idx == 0) {
+ name += "_" + access_name;
+ }
+ }
+
+ if (auto* str = tint::As<type::Struct>(ty)) {
+ // Structure member matrix. The columns are decomposed into the structure.
+ auto mat_member_idx = std::get<u32>(chain.indices[std140_mat_idx]);
+ auto* mat_member = str->Members()[mat_member_idx];
+ if (column_idx == 0) {
+ name +=
+ "_" + mat_member->Name().Name() + "_p" + std::to_string(column_param_idx);
+ }
+ auto mat_columns = *std140_mat_members.Get(mat_member);
+ expr = b.MemberAccessor(expr, mat_columns[column_idx]->name->symbol);
+ ty = mat_member->Type()->As<type::Matrix>()->ColumnType();
+ } else {
+ // Non-structure-member matrix. The columns are decomposed into a new, bespoke
+ // std140 structure.
+ auto [new_expr, new_ty, mat_name] =
+ BuildAccessExpr(expr, ty, chain, std140_mat_idx, dynamic_index);
+ expr = new_expr;
+ ty = new_ty;
+ if (column_idx == 0) {
+ name += "_" + mat_name + "_p" + std::to_string(column_param_idx);
+ }
+ auto* mat = ty->As<type::Matrix>();
+ auto std140_mat = std140_mats.Get(ty->As<type::Matrix>());
+ expr = b.MemberAccessor(expr, std140_mat->columns[column_idx]);
+ ty = mat->ColumnType();
+ }
+
+ // Build the rest of the expression, skipping over the column index.
+ for (size_t i = std140_mat_idx + 2; i < chain.indices.Length(); i++) {
+ auto [new_expr, new_ty, access_name] =
+ BuildAccessExpr(expr, ty, chain, i, dynamic_index);
+ expr = new_expr;
+ ty = new_ty;
+ if (column_idx == 0) {
+ name += "_" + access_name;
+ }
+ }
+
+ if (column_idx == 0) {
+ ret_ty = ty;
+ }
+
+ auto* case_sel = b.CaseSelector(b.Expr(u32(column_idx)));
+ auto* case_body = b.Block(utils::Vector{b.Return(expr)});
+ cases.Push(b.Case(case_sel, case_body));
+ }
+
+ // Build the default case (required in WGSL).
+ // This just returns a zero value of the return type, as the index must be out of
+ // bounds.
+ cases.Push(b.DefaultCase(b.Block(b.Return(b.Call(CreateASTTypeFor(ctx, ret_ty))))));
+
+ auto* column_selector = dynamic_index(column_param_idx);
+ auto* stmt = b.Switch(column_selector, std::move(cases));
+
+ auto fn_sym = b.Symbols().New(name);
+ b.Func(fn_sym, std::move(dynamic_index_params), CreateASTTypeFor(ctx, ret_ty),
+ utils::Vector{stmt});
+ return fn_sym;
+ }
+
+ /// Generates a function to load a whole std140-decomposed matrix from a uniform buffer.
+ /// The generated function will have a parameter per dynamic (runtime-evaluated) index in the
+ /// access chain.
+ /// @param chain the access chain from the uniform buffer to the whole std140-decomposed
+ /// matrix.
+ /// @returns the generated function name.
+ Symbol BuildLoadWholeMatrixFn(const AccessChain& chain) {
+ // Build the dynamic index parameters: one u32 parameter named "p<i>" per dynamic index.
+ auto dynamic_index_params = utils::Transform(chain.dynamic_indices, [&](auto*, size_t i) {
+ return b.Param("p" + std::to_string(i), b.ty.u32());
+ });
+ // Method for generating dynamic index expressions.
+ // These are passed in as arguments to the function.
+ auto dynamic_index = [&](size_t idx) {
+ return b.Expr(dynamic_index_params[idx]->name->symbol);
+ };
+
+ const Expression* expr = nullptr;
+ const type::Type* ty = nullptr;
+ std::string name = "load";
+
+ // Build the expression up to, but not including the matrix member.
+ auto std140_mat_idx = *chain.std140_mat_idx;
+ for (size_t i = 0; i < std140_mat_idx; i++) {
+ auto [new_expr, new_ty, access_name] =
+ BuildAccessExpr(expr, ty, chain, i, dynamic_index);
+ expr = new_expr;
+ ty = new_ty;
+ name += "_" + access_name;
+ }
+
+ utils::Vector<const Statement*, 2> stmts;
+
+ // Create a temporary pointer to the structure that holds the matrix columns
+ auto* let = b.Let("s", b.AddressOf(expr));
+ stmts.Push(b.Decl(let));
+
+ utils::Vector<const MemberAccessorExpression*, 4> columns;
+ if (auto* str = tint::As<type::Struct>(ty)) {
+ // Structure member matrix. The columns are decomposed into the structure.
+ auto mat_member_idx = std::get<u32>(chain.indices[std140_mat_idx]);
+ auto* mat_member = str->Members()[mat_member_idx];
+ auto mat_columns = *std140_mat_members.Get(mat_member);
+ columns = utils::Transform(mat_columns, [&](auto* column_member) {
+ return b.MemberAccessor(b.Deref(let), column_member->name->symbol);
+ });
+ ty = mat_member->Type();
+ name += "_" + mat_member->Name().Name();
+ } else {
+ // Non-structure-member matrix. The columns are decomposed into a new, bespoke
+ // std140 structure. NOTE(review): `ty` is not updated from `new_ty` before the cast below (LoadSubMatrixInline does update it) - confirm.
+ auto [new_expr, new_ty, mat_name] =
+ BuildAccessExpr(expr, ty, chain, std140_mat_idx, dynamic_index);
+ expr = new_expr;
+ auto* mat = ty->As<type::Matrix>();
+ auto std140_mat = std140_mats.Get(ty->As<type::Matrix>());
+ columns = utils::Transform(std140_mat->columns, [&](auto column_name) {
+ return b.MemberAccessor(b.Deref(let), column_name);
+ });
+ ty = mat;
+ name += "_" + mat_name;
+ }
+
+ // Reconstruct the matrix from the columns, using the matrix type recorded on the chain.
+ expr = b.Call(CreateASTTypeFor(ctx, chain.std140_mat_ty), std::move(columns));
+
+ // Have the function return the constructed matrix
+ stmts.Push(b.Return(expr));
+
+ // Build the function
+ auto ret_ty = CreateASTTypeFor(ctx, ty);
+ auto fn_sym = b.Symbols().New(name);
+ b.Func(fn_sym, std::move(dynamic_index_params), ret_ty, std::move(stmts));
+ return fn_sym;
+ }
+
+ /// Return type of BuildAccessExpr(): the built expression, its type and a name segment.
+ struct ExprTypeName {
+ /// The new, post-access expression
+ const Expression* expr;
+ /// The type of #expr
+ const type::Type* type;
+ /// A name segment which can be used to build sensible names for helper functions
+ std::string name;
+ };
+
+ /// Builds a single access in an access chain.
+ /// @param lhs the expression to index using @p access
+ /// @param ty the type of the expression @p lhs
+ /// @param chain the access chain; the entry at @p index is the access performed on @p lhs
+ /// @param dynamic_index a function that returns the expression for the i'th dynamic index
+ /// @returns a ExprTypeName which holds the new expression, new type and a name segment which
+ /// can be used for creating helper function names.
+ ExprTypeName BuildAccessExpr(const Expression* lhs,
+ const type::Type* ty,
+ const AccessChain& chain,
+ size_t index,
+ std::function<const Expression*(size_t)> dynamic_index) {
+ auto& access = chain.indices[index];
+
+ if (std::get_if<UniformVariable>(&access)) {
+ const auto symbol = chain.var->Declaration()->name->symbol;
+ const auto* expr = b.Expr(ctx.Clone(symbol));
+ const auto name = symbol.Name();
+ ty = chain.var->Type()->UnwrapRef();
+ return {expr, ty, name};
+ }
+
+ if (auto* dyn_idx = std::get_if<DynamicIndex>(&access)) {
+ /// The access uses a dynamic (runtime-expression) index. The name segment is "p<slot>".
+ auto name = "p" + std::to_string(dyn_idx->slot);
+ return Switch(
+ ty, //
+ [&](const type::Array* arr) -> ExprTypeName {
+ auto* idx = dynamic_index(dyn_idx->slot);
+ auto* expr = b.IndexAccessor(lhs, idx);
+ return {expr, arr->ElemType(), name};
+ }, //
+ [&](const type::Matrix* mat) -> ExprTypeName {
+ auto* idx = dynamic_index(dyn_idx->slot);
+ auto* expr = b.IndexAccessor(lhs, idx);
+ return {expr, mat->ColumnType(), name};
+ }, //
+ [&](const type::Vector* vec) -> ExprTypeName {
+ auto* idx = dynamic_index(dyn_idx->slot);
+ auto* expr = b.IndexAccessor(lhs, idx);
+ return {expr, vec->type(), name};
+ }, //
+ [&](Default) -> ExprTypeName {
+ TINT_ICE(Transform, b.Diagnostics())
+ << "unhandled type for access chain: " << ty->FriendlyName();
+ return {};
+ });
+ }
+ if (auto* swizzle = std::get_if<Swizzle>(&access)) {
+ /// The access is a vector swizzle. The name segment is the swizzle string (e.g. "xy").
+ return Switch(
+ ty, //
+ [&](const type::Vector* vec) -> ExprTypeName {
+ static const char xyzw[] = {'x', 'y', 'z', 'w'};
+ std::string rhs;
+ for (auto el : *swizzle) {
+ rhs += xyzw[el];
+ }
+ auto swizzle_ty = src->Types().Find<type::Vector>(
+ vec->type(), static_cast<uint32_t>(swizzle->Length()));
+ auto* expr = b.MemberAccessor(lhs, rhs);
+ return {expr, swizzle_ty, rhs};
+ }, //
+ [&](Default) -> ExprTypeName {
+ TINT_ICE(Transform, b.Diagnostics())
+ << "unhandled type for access chain: " << ty->FriendlyName();
+ return {};
+ });
+ }
+ /// The access is a static index. The name segment is the member name or the index value.
+ auto idx = std::get<u32>(access);
+ return Switch(
+ ty, //
+ [&](const type::Struct* str) -> ExprTypeName {
+ auto* member = str->Members()[idx];
+ auto member_name = member->Name().Name();
+ auto* expr = b.MemberAccessor(lhs, member_name);
+ ty = member->Type();
+ return {expr, ty, member_name};
+ }, //
+ [&](const type::Array* arr) -> ExprTypeName {
+ auto* expr = b.IndexAccessor(lhs, idx);
+ return {expr, arr->ElemType(), std::to_string(idx)};
+ }, //
+ [&](const type::Matrix* mat) -> ExprTypeName {
+ auto* expr = b.IndexAccessor(lhs, idx);
+ return {expr, mat->ColumnType(), std::to_string(idx)};
+ }, //
+ [&](const type::Vector* vec) -> ExprTypeName {
+ auto* expr = b.IndexAccessor(lhs, idx);
+ return {expr, vec->type(), std::to_string(idx)};
+ }, //
+ [&](Default) -> ExprTypeName {
+ TINT_ICE(Transform, b.Diagnostics())
+ << "unhandled type for access chain: " << ty->FriendlyName();
+ return {};
+ });
+ }
+};
+
+Std140::Std140() = default;
+
+Std140::~Std140() = default;
+
+Transform::ApplyResult Std140::Apply(const Program* src, const DataMap&, DataMap&) const {
+ return State(src).Run();
+}
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/std140.h b/src/tint/lang/wgsl/ast/transform/std140.h
new file mode 100644
index 0000000..78a7cf0
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/std140.h
@@ -0,0 +1,49 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_STD140_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_STD140_H_
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+
+namespace tint::ast::transform {
+
+/// Std140 is a transform that forks types used in the uniform address space that contain
+/// `matNx2<f32>` matrices into `N`x`vec2<f32>` column vectors, and `matNxM<f16>` matrices into
+/// `N`x`vecM<f16>` column vectors. Types that transitively use these forked types are also forked.
+/// `var<uniform>` variables will use these forked types, and expressions loading from these
+/// variables will do appropriate conversions to the regular WGSL types. As `matNx2<f32>` and
+/// `matNxM<f16>` matrices are the only types that violate std140-layout, this transformation is
+/// sufficient to have any WGSL structure be std140-layout conformant.
+///
+/// @note This transform requires the PromoteSideEffectsToDecl transform to have been run first.
+class Std140 final : public utils::Castable<Std140, Transform> {
+ public:
+ /// Constructor
+ Std140();
+ /// Destructor
+ ~Std140() override;
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+
+ private:
+ struct State;
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_STD140_H_
diff --git a/src/tint/lang/wgsl/ast/transform/std140_exhaustive_test.cc b/src/tint/lang/wgsl/ast/transform/std140_exhaustive_test.cc
new file mode 100644
index 0000000..f05ee31
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/std140_exhaustive_test.cc
@@ -0,0 +1,4890 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/std140.h"
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+#include "src/tint/utils/string.h"
+
+namespace tint::ast::transform {
+namespace {
+
+enum class MatrixType { f32, f16 };  // Scalar element type of a matrix under test.
+
+struct MatrixCase {  // One matrix type under test, with helpers to build WGSL snippets for it.
+ uint32_t columns;  // Number of columns; the first number printed by Shape().
+ uint32_t rows;  // Number of rows; also the width of each column vector.
+ MatrixType type;  // Scalar element type.
+ // Size of one scalar element: 2 for f16, otherwise 4.
+ size_t ElementSize() const { return type == MatrixType::f16 ? 2 : 4; }
+ // Alignment of one column vector; a 3-row column is aligned like a 4-row column.
+ size_t ColumnVectorAlign() const { return (rows == 3 ? 4 : rows) * ElementSize(); }
+ // True when the column vector alignment is not 16; the tests expect such matrices to be forked.
+ bool NotStd140Compatible() const { return ColumnVectorAlign() != 16; }
+
+ // Returns whether this matrix type can be used as the element type of an array in a uniform
+ // buffer, i.e. whether the array stride is a multiple of 16.
+ bool CanBeUsedAsUniformArrayElememts() const {
+ const size_t array_stride = columns * ColumnVectorAlign();
+ return (array_stride % 16 == 0);
+ }
+ // Matrix dimensions as "<columns>x<rows>", e.g. "4x3".
+ std::string Shape() const { return std::to_string(columns) + "x" + std::to_string(rows); }
+ // WGSL name of the element type: "f16" or "f32".
+ std::string ElementType() const { return type == MatrixType::f16 ? "f16" : "f32"; }
+ // Full WGSL matrix type name, e.g. "mat4x3<f32>".
+ std::string Mat() const { return "mat" + Shape() + "<" + ElementType() + ">"; }
+ // WGSL type name of a single column vector, e.g. "vec3<f32>".
+ std::string ColumnVector() const {
+ return "vec" + std::to_string(rows) + "<" + (type == MatrixType::f32 ? "f32" : "f16") + ">";
+ }
+ // Swizzle substituted for "${swizzle}"; empty if rows is not 2, 3 or 4.
+ std::string ColumnVectorSwizzle() const {
+ switch (rows) {
+ case 2:
+ return "yx";
+ case 3:
+ return "yzx";
+ case 4:
+ return "wzxy";
+ }
+ return "";
+ }
+
+ // For each column, replaces "${col_id_for_tmpl}" in `tmpl` with the column index, and joins the
+ // resulting strings with `seperator`. If `tmpl_for_last_column` is not empty, it is used
+ // instead of `tmpl` for the last column.
+ std::string JoinTemplatedStringForEachMatrixColumn(
+ std::string tmpl,
+ std::string seperator,
+ std::string tmpl_for_last_column = "") const {
+ std::string result;
+ if (tmpl_for_last_column.size() == 0) {  // No dedicated template: reuse `tmpl`.
+ tmpl_for_last_column = tmpl;
+ }
+ for (size_t c = 0; c < columns - 1; c++) {  // Every column except the last one.
+ if (c > 0) {
+ result += seperator;
+ }
+ std::string string_for_current_column =
+ utils::ReplaceAll(tmpl, "${col_id_for_tmpl}", std::to_string(c));
+ result += string_for_current_column;
+ }
+ result += seperator;  // Unconditional: with a single column the result starts with it.
+ std::string string_for_last_column = utils::ReplaceAll(
+ tmpl_for_last_column, "${col_id_for_tmpl}", std::to_string(columns - 1));
+ result += string_for_last_column;
+ return result;
+ }
+ // One struct member declaration "<name><col> : <column vector type>," per column, one per line.
+ std::string ExpendedColumnVectors(uint32_t leading_space, std::string name) const {
+ std::string space(leading_space, ' ');
+ return JoinTemplatedStringForEachMatrixColumn(
+ space + name + "${col_id_for_tmpl} : " + ColumnVector() + ",", "\n");
+ }
+ // The per-column names "<name><col>" joined with `seperator`, without any type.
+ std::string ExpendedColumnVectorsInline(std::string name, std::string seperator) const {
+ return JoinTemplatedStringForEachMatrixColumn(name + "${col_id_for_tmpl}", seperator);
+ }
+ // Like ExpendedColumnVectors(), but the last member is preceded by an "@size(<last_size>)" line.
+ std::string ExpendedColumnVectorsWithLastSize(uint32_t leading_space,
+ std::string name,
+ uint32_t last_size) const {
+ std::string space(leading_space, ' ');
+ return JoinTemplatedStringForEachMatrixColumn(
+ space + name + "${col_id_for_tmpl} : " + ColumnVector() + ",", "\n",
+ space + "@size(" + std::to_string(last_size) + ")\n" + space + name +
+ "${col_id_for_tmpl} : " + ColumnVector() + ",");
+ }
+
+ // Replace user-given fields and predefined fields in a given string `str`.
+ // First, for each pair of strings in `replacement_pairs`, replace all occurrences of the first
+ // string of the pair with the second string. Then replace the predefined fields with the
+ // matrix information. E.g. for a matrix mat4x3<f32>, this replaces "${mat}" with "mat4x3<f32>",
+ // "${shape}" with "4x3", "${elem_type}" with "f32", "${col_vector_type}" with "vec3<f32>", and
+ // "${swizzle}" with "yzx".
+ std::string ReplaceFieldsInString(
+ std::string str,
+ std::initializer_list<std::pair<std::string, std::string>> replacement_pairs = {}) const {
+ for (auto& replace : replacement_pairs) {  // User pairs first, so they may contain fields.
+ str = utils::ReplaceAll(str, replace.first, replace.second);
+ }
+ str = utils::ReplaceAll(str, "${mat}", Mat());
+ str = utils::ReplaceAll(str, "${shape}", Shape());
+ str = utils::ReplaceAll(str, "${elem_type}", ElementType());
+ str = utils::ReplaceAll(str, "${col_vector_type}", ColumnVector());
+ str = utils::ReplaceAll(str, "${swizzle}", ColumnVectorSwizzle());
+ return str;
+ }
+};
+
+inline std::ostream& operator<<(std::ostream& out, const MatrixCase& matrix_case) {
+ return out << matrix_case.Mat();  // Streams the WGSL matrix type name, e.g. "mat4x3<f32>".
+}
+
+using Std140Test_Matrix = TransformTestWithParam<MatrixCase>;  // Parameterized by MatrixCase.
+
+TEST_P(Std140Test_Matrix, SingleStructMatUniform) {
+ auto matrix = GetParam();
+ // A uniform struct whose only member is the matrix; nothing loads from it.
+ std::string src = R"(
+enable f16;
+
+struct S {
+ m : ${mat},
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+)";
+ src = matrix.ReplaceFieldsInString(src);
+ // Expected for forked matrices: S is kept and S_std140 gets one `m_<col>` member per column.
+ std::string expect;
+ if (matrix.NotStd140Compatible()) {
+ expect = R"(
+enable f16;
+
+struct S {
+ m : ${mat},
+}
+
+struct S_std140 {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+)";
+ expect = matrix.ReplaceFieldsInString(
+ expect, {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "m_")}});
+
+ } else {
+ expect = src;  // Already std140-compatible: the output must equal the input.
+ }
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(Std140Test_Matrix, CustomAlign) {
+ auto matrix = GetParam();
+ // The matrix member carries @align(128) and sits between two i32 members.
+ std::string src = R"(
+enable f16;
+
+struct S {
+ before : i32,
+ @align(128)
+ m : ${mat},
+ after : i32,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+)";
+ src = matrix.ReplaceFieldsInString(src);
+ // Expected for forked matrices: `@align(128i)` is emitted ahead of the first column member.
+ std::string expect;
+ if (matrix.NotStd140Compatible()) {
+ expect = R"(
+enable f16;
+
+struct S {
+ before : i32,
+ @align(128)
+ m : ${mat},
+ after : i32,
+}
+
+struct S_std140 {
+ before : i32,
+ @align(128i)
+${col_vectors}
+ after : i32,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+)";
+ expect = matrix.ReplaceFieldsInString(
+ expect, {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "m_")}});
+ } else {
+ expect = src;  // Already std140-compatible: the output must equal the input.
+ }
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(Std140Test_Matrix, CustomSizeMat) {
+ auto matrix = GetParam();
+ // The matrix member carries @size(128) and sits between two i32 members.
+ std::string src = R"(
+enable f16;
+
+struct S {
+ before : i32,
+ @size(128)
+ m : ${mat},
+ after : i32,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+)";
+ src = matrix.ReplaceFieldsInString(src);
+ // Expected for forked matrices: only the last column member gets an explicit @size.
+ std::string expect;
+ if (matrix.NotStd140Compatible()) {
+ uint32_t last_size =
+ 128 - static_cast<uint32_t>(matrix.ColumnVectorAlign() * (matrix.columns - 1));
+ // last_size is what is left of 128 after (columns - 1) columns of ColumnVectorAlign() each.
+ expect = R"(
+enable f16;
+
+struct S {
+ before : i32,
+ @size(128)
+ m : ${mat},
+ after : i32,
+}
+
+struct S_std140 {
+ before : i32,
+${col_vectors}
+ after : i32,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+)";
+ expect = matrix.ReplaceFieldsInString(
+ expect,
+ {{"${col_vectors}", matrix.ExpendedColumnVectorsWithLastSize(2, "m_", last_size)}});
+ } else {
+ expect = src;  // Already std140-compatible: the output must equal the input.
+ }
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(Std140Test_Matrix, CustomAlignAndSize) {
+ auto matrix = GetParam();
+ // The matrix member carries both @align(128) and @size(128).
+ std::string src = R"(
+enable f16;
+
+struct S {
+ before : i32,
+ @align(128) @size(128)
+ m : ${mat},
+ after : i32,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+)";
+ src = matrix.ReplaceFieldsInString(src);
+ // Expected for forked matrices: @align(128i) ahead of the first column, @size on the last one.
+ std::string expect;
+ if (matrix.NotStd140Compatible()) {
+ uint32_t last_size =
+ 128 - static_cast<uint32_t>(matrix.ColumnVectorAlign() * (matrix.columns - 1));
+ // last_size is what is left of 128 after (columns - 1) columns of ColumnVectorAlign() each.
+ expect = R"(
+enable f16;
+
+struct S {
+ before : i32,
+ @align(128) @size(128)
+ m : ${mat},
+ after : i32,
+}
+
+struct S_std140 {
+ before : i32,
+ @align(128i)
+${col_vectors}
+ after : i32,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+)";
+ expect = matrix.ReplaceFieldsInString(
+ expect,
+ {{"${col_vectors}", matrix.ExpendedColumnVectorsWithLastSize(2, "m_", last_size)}});
+ } else {
+ expect = src;  // Already std140-compatible: the output must equal the input.
+ }
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(Std140Test_Matrix, MatrixUsageInForLoop) {
+ auto matrix = GetParam();
+ // Matrix elements are read in the initializer, condition and continuing part of a for loop.
+ std::string src = R"(
+enable f16;
+
+struct S {
+ m : ${mat},
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ for(var i = u32(s.m[0][0]); (i < u32(s.m[i][1])); i += u32(s.m[1][i])) {
+ }
+}
+)";
+ src = matrix.ReplaceFieldsInString(src);
+ // Expected: constant column indices become `m_<col>`; the dynamic one calls load_s_m_p0_1().
+ std::string expect;
+ if (matrix.NotStd140Compatible()) {
+ expect = R"(
+enable f16;
+
+struct S {
+ m : ${mat},
+}
+
+struct S_std140 {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn load_s_m_p0_1(p0 : u32) -> ${elem_type} {
+ switch(p0) {
+${col_table}
+ default: {
+ return ${elem_type}();
+ }
+ }
+}
+
+fn f() {
+ for(var i = u32(s.m_0[0u]); (i < u32(load_s_m_p0_1(u32(i)))); i += u32(s.m_1[i])) {
+ }
+}
+)";
+
+ // col_table holds the switch cases for all column indices.
+ // Example for a matrix having 2 columns:
+ // case 0u: {
+ // return s.m_0[1u];
+ // }
+ // case 1u: {
+ // return s.m_1[1u];
+ // }
+ std::string col_table = matrix.JoinTemplatedStringForEachMatrixColumn( //
+ R"( case ${col_id_for_tmpl}u: {
+ return s.m_${col_id_for_tmpl}[1u];
+ })",
+ "\n");
+ expect = matrix.ReplaceFieldsInString(
+ expect, {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "m_")},
+ {"${col_table}", col_table}});
+ } else {
+ expect = src;  // Already std140-compatible: the output must equal the input.
+ }
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(Std140Test_Matrix, MatUniform_LoadMatrix) {
+ auto matrix = GetParam();
+ // The uniform is a bare matrix (no wrapping struct) and is loaded as a whole.
+ std::string src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> m : ${mat};
+
+fn f() {
+ let l = m;
+}
+)";
+ src = matrix.ReplaceFieldsInString(src);
+ // Expected: a struct mat<shape>_<elem> with `col<N>` members, and the load wrapped in conv_*().
+ std::string expect;
+ if (matrix.NotStd140Compatible()) {
+ expect = R"(
+enable f16;
+
+struct mat${shape}_${elem_type} {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> m : mat${shape}_${elem_type};
+
+fn conv_mat${shape}_${elem_type}(val : mat${shape}_${elem_type}) -> ${mat} {
+ return ${mat}(${col_vectors_inline});
+}
+
+fn f() {
+ let l = conv_mat${shape}_${elem_type}(m);
+}
+)";
+ expect = matrix.ReplaceFieldsInString(
+ expect,
+ {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "col")},
+ {"${col_vectors_inline}", matrix.ExpendedColumnVectorsInline("val.col", ", ")}});
+ } else {
+ expect = src;  // Already std140-compatible: the output must equal the input.
+ }
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(Std140Test_Matrix, MatUniform_LoadColumn_ConstIndex) {
+ auto matrix = GetParam();
+ // One column of a bare matrix uniform is loaded with a constant index.
+ std::string tmpl_src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : ${mat};
+
+fn f() {
+ let l = a[${cloumn_index}];
+}
+)";
+ tmpl_src = matrix.ReplaceFieldsInString(tmpl_src);
+ // Expected for forked matrices: the index becomes a direct `a.col<index>` member access.
+ std::string tmpl_expect;
+ if (matrix.NotStd140Compatible()) {
+ tmpl_expect = R"(
+enable f16;
+
+struct mat${shape}_${elem_type} {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> a : mat${shape}_${elem_type};
+
+fn f() {
+ let l = a.col${cloumn_index};
+}
+)";
+ tmpl_expect = matrix.ReplaceFieldsInString(
+ tmpl_expect, {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "col")}});
+ } else {
+ tmpl_expect = tmpl_src;  // Already std140-compatible: the output must equal the input.
+ }
+ // Run once per column; "${cloumn_index}" (sic) is spelled the same as in the templates.
+ for (uint32_t col = 0; col < matrix.columns; col++) {
+ std::string src = utils::ReplaceAll(tmpl_src, "${cloumn_index}", std::to_string(col));
+ std::string expect = utils::ReplaceAll(tmpl_expect, "${cloumn_index}", std::to_string(col));
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got)) << "accessing col " << col;
+ }
+}
+
+TEST_P(Std140Test_Matrix, MatUniform_LoadColumn_VariableIndex) {
+ auto matrix = GetParam();
+ // One column of a bare matrix uniform is loaded using the `let` value I as index.
+ std::string src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : ${mat};
+
+fn f() {
+ let I = 1;
+ let l = a[I];
+}
+)";
+ src = matrix.ReplaceFieldsInString(src);
+ // Expected for forked matrices: the load calls load_a_p0(), a switch over the column index.
+ std::string expect;
+ if (matrix.NotStd140Compatible()) {
+ expect = R"(
+enable f16;
+
+struct mat${shape}_${elem_type} {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> a : mat${shape}_${elem_type};
+
+fn load_a_p0(p0 : u32) -> ${col_vector_type} {
+ switch(p0) {
+${col_table}
+ default: {
+ return ${col_vector_type}();
+ }
+ }
+}
+
+fn f() {
+ let I = 1;
+ let l = load_a_p0(u32(I));
+}
+)";
+ // col_table holds the switch cases for all column indices.
+ // Example for a matrix having 2 columns:
+ // case 0u: {
+ // return a.col0;
+ // }
+ // case 1u: {
+ // return a.col1;
+ // }
+ std::string col_table = matrix.JoinTemplatedStringForEachMatrixColumn( //
+ R"( case ${col_id_for_tmpl}u: {
+ return a.col${col_id_for_tmpl};
+ })",
+ "\n");
+ expect = matrix.ReplaceFieldsInString(
+ expect, {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "col")},
+ {"${col_table}", col_table}});
+ } else {
+ expect = src;  // Already std140-compatible: the output must equal the input.
+ }
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(Std140Test_Matrix, MatUniform_LoadColumnSwizzle_ConstIndex) {
+ auto matrix = GetParam();
+ // A constant-indexed column of a bare matrix uniform is swizzled.
+ std::string tmpl_src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : ${mat};
+
+fn f() {
+ let l = a[${cloumn_index}].${swizzle};
+}
+)";
+ tmpl_src = matrix.ReplaceFieldsInString(tmpl_src);
+ // Expected for forked matrices: the swizzle is applied directly to `a.col<index>`.
+ std::string tmpl_expect;
+ if (matrix.NotStd140Compatible()) {
+ tmpl_expect = R"(
+enable f16;
+
+struct mat${shape}_${elem_type} {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> a : mat${shape}_${elem_type};
+
+fn f() {
+ let l = a.col${cloumn_index}.${swizzle};
+}
+)";
+ tmpl_expect = matrix.ReplaceFieldsInString(
+ tmpl_expect, {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "col")}});
+ } else {
+ tmpl_expect = tmpl_src;  // Already std140-compatible: the output must equal the input.
+ }
+ // Run once per column; "${cloumn_index}" (sic) is spelled the same as in the templates.
+ for (uint32_t col = 0; col < matrix.columns; col++) {
+ std::string src = utils::ReplaceAll(tmpl_src, "${cloumn_index}", std::to_string(col));
+ std::string expect = utils::ReplaceAll(tmpl_expect, "${cloumn_index}", std::to_string(col));
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got)) << "accessing col " << col;
+ }
+}
+
+TEST_P(Std140Test_Matrix, MatUniform_LoadColumnSwizzle_VariableIndex) {
+ auto matrix = GetParam();
+ // A column selected by the `let` value I is swizzled.
+ std::string src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : ${mat};
+
+fn f() {
+ let I = 1;
+ let l = a[I].${swizzle};
+}
+)";
+ src = matrix.ReplaceFieldsInString(src);
+ // Expected for forked matrices: a load_a_p0_<swizzle>() helper whose cases apply the swizzle.
+ std::string expect;
+ if (matrix.NotStd140Compatible()) {
+ expect = R"(
+enable f16;
+
+struct mat${shape}_${elem_type} {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> a : mat${shape}_${elem_type};
+
+fn load_a_p0_${swizzle}(p0 : u32) -> ${col_vector_type} {
+ switch(p0) {
+${col_table}
+ default: {
+ return ${col_vector_type}();
+ }
+ }
+}
+
+fn f() {
+ let I = 1;
+ let l = load_a_p0_${swizzle}(u32(I));
+}
+)";
+ // col_table holds the switch cases for all column indices.
+ // Example for a matrix having 2 columns:
+ // case 0u: {
+ // return a.col0.${swizzle};
+ // }
+ // case 1u: {
+ // return a.col1.${swizzle};
+ // }
+ std::string col_table = matrix.JoinTemplatedStringForEachMatrixColumn( //
+ R"( case ${col_id_for_tmpl}u: {
+ return a.col${col_id_for_tmpl}.${swizzle};
+ })",
+ "\n");
+ expect = matrix.ReplaceFieldsInString(
+ expect, {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "col")},
+ {"${col_table}", col_table}});
+ } else {
+ expect = src;  // Already std140-compatible: the output must equal the input.
+ }
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(Std140Test_Matrix, MatUniform_LoadScalar_ConstColumnIndex_ConstRowIndex) {
+ auto matrix = GetParam();
+ // A single scalar of a bare matrix uniform is loaded with constant column and row indices.
+ std::string tmpl_src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : ${mat};
+
+fn f() {
+ let l = a[${col_index}][${row_index}];
+}
+)";
+ tmpl_src = matrix.ReplaceFieldsInString(tmpl_src);
+ // Expected for forked matrices: `a.col<col>[<row>u]`, with no helper function.
+ std::string tmpl_expect;
+ if (matrix.NotStd140Compatible()) {
+ tmpl_expect = R"(
+enable f16;
+
+struct mat${shape}_${elem_type} {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> a : mat${shape}_${elem_type};
+
+fn f() {
+ let l = a.col${col_index}[${row_index}u];
+}
+)";
+ tmpl_expect = matrix.ReplaceFieldsInString(
+ tmpl_expect, {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "col")}});
+ } else {
+ tmpl_expect = tmpl_src;  // Already std140-compatible: the output must equal the input.
+ }
+ // Instantiate and run the templates for every (column, row) pair.
+ for (uint32_t col = 0; col < matrix.columns; col++) {
+ for (uint32_t row = 0; row < matrix.rows; row++) {
+ std::string src = utils::ReplaceAll(tmpl_src, "${col_index}", std::to_string(col));
+ src = utils::ReplaceAll(src, "${row_index}", std::to_string(row));
+ std::string expect =
+ utils::ReplaceAll(tmpl_expect, "${col_index}", std::to_string(col));
+ expect = utils::ReplaceAll(expect, "${row_index}", std::to_string(row));
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got)) << "accessing col " << col << " row " << row;
+ }
+ }
+}
+
+TEST_P(Std140Test_Matrix, MatUniform_LoadScalar_VariableColumnIndex_ConstRowIndex) {
+ auto matrix = GetParam();
+ // A scalar is loaded with the `let` value I as column index and a constant row index.
+ std::string tmpl_src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : ${mat};
+
+fn f() {
+ let I = 0;
+ let l = a[I][${row_index}];
+}
+)";
+ tmpl_src = matrix.ReplaceFieldsInString(tmpl_src);
+ // Expected for forked matrices: a load_a_p0_<row>() helper that switches over the column.
+ std::string tmpl_expect;
+ if (matrix.NotStd140Compatible()) {
+ tmpl_expect = R"(
+enable f16;
+
+struct mat${shape}_${elem_type} {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> a : mat${shape}_${elem_type};
+
+fn load_a_p0_${row_index}(p0 : u32) -> ${elem_type} {
+ switch(p0) {
+${col_table}
+ default: {
+ return ${elem_type}();
+ }
+ }
+}
+
+fn f() {
+ let I = 0;
+ let l = load_a_p0_${row_index}(u32(I));
+}
+)";
+ // col_table holds the switch cases for all column indices.
+ // Example for a matrix having 2 columns:
+ // case 0u: {
+ // return a.col0[${row_index}u];
+ // }
+ // case 1u: {
+ // return a.col1[${row_index}u];
+ // }
+ std::string col_table = matrix.JoinTemplatedStringForEachMatrixColumn( //
+ R"( case ${col_id_for_tmpl}u: {
+ return a.col${col_id_for_tmpl}[${row_index}u];
+ })",
+ "\n");
+ tmpl_expect = matrix.ReplaceFieldsInString(
+ tmpl_expect, {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "col")},
+ {"${col_table}", col_table}});
+ } else {
+ tmpl_expect = tmpl_src;  // Already std140-compatible: the output must equal the input.
+ }
+ // "${row_index}" is still present in both templates; substitute it and run once per row.
+ for (uint32_t row = 0; row < matrix.rows; row++) {
+ std::string src = utils::ReplaceAll(tmpl_src, "${row_index}", std::to_string(row));
+ std::string expect = utils::ReplaceAll(tmpl_expect, "${row_index}", std::to_string(row));
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got)) << "accessing row " << row;
+ }
+}
+
+TEST_P(Std140Test_Matrix, MatUniform_LoadScalar_ConstColumnIndex_VariableRowIndex) {
+ auto matrix = GetParam();
+ // A scalar is loaded with a constant column index and the `let` value I as row index.
+ std::string tmpl_src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : ${mat};
+
+fn f() {
+ let I = 0;
+ let l = a[${col_index}][I];
+}
+)";
+ tmpl_src = matrix.ReplaceFieldsInString(tmpl_src);
+ // Expected for forked matrices: `a.col<col>[I]`, with no helper function.
+ std::string tmpl_expect;
+ if (matrix.NotStd140Compatible()) {
+ tmpl_expect = R"(
+enable f16;
+
+struct mat${shape}_${elem_type} {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> a : mat${shape}_${elem_type};
+
+fn f() {
+ let I = 0;
+ let l = a.col${col_index}[I];
+}
+)";
+ tmpl_expect = matrix.ReplaceFieldsInString(
+ tmpl_expect, {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "col")}});
+ } else {
+ tmpl_expect = tmpl_src;  // Already std140-compatible: the output must equal the input.
+ }
+ // Instantiate and run the templates once per column.
+ for (uint32_t col = 0; col < matrix.columns; col++) {
+ std::string src = utils::ReplaceAll(tmpl_src, "${col_index}", std::to_string(col));
+ std::string expect = utils::ReplaceAll(tmpl_expect, "${col_index}", std::to_string(col));
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got)) << "accessing col " << col;
+ }
+}
+
+TEST_P(Std140Test_Matrix, MatUniform_LoadScalar_VariableColumnIndex_VariableRowIndex) {
+ auto matrix = GetParam();
+ // A scalar is loaded using the `let` value I as both column and row index.
+ std::string src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : ${mat};
+
+fn f() {
+ let I = 0;
+ let l = a[I][I];
+}
+)";
+ src = matrix.ReplaceFieldsInString(src);
+ // Expected for forked matrices: load_a_p0_p1() switches over p0 and indexes the column by p1.
+ std::string expect;
+ if (matrix.NotStd140Compatible()) {
+ expect = R"(
+enable f16;
+
+struct mat${shape}_${elem_type} {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> a : mat${shape}_${elem_type};
+
+fn load_a_p0_p1(p0 : u32, p1 : u32) -> ${elem_type} {
+ switch(p0) {
+${col_table}
+ default: {
+ return ${elem_type}();
+ }
+ }
+}
+
+fn f() {
+ let I = 0;
+ let l = load_a_p0_p1(u32(I), u32(I));
+}
+)";
+
+ // col_table holds the switch cases for all column indices.
+ // Example for a matrix having 2 columns:
+ // case 0u: {
+ // return a.col0[p1];
+ // }
+ // case 1u: {
+ // return a.col1[p1];
+ // }
+ std::string col_table = matrix.JoinTemplatedStringForEachMatrixColumn( //
+ R"( case ${col_id_for_tmpl}u: {
+ return a.col${col_id_for_tmpl}[p1];
+ })",
+ "\n");
+ expect = matrix.ReplaceFieldsInString(
+ expect, {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "col")},
+ {"${col_table}", col_table}});
+ } else {
+ expect = src;  // Already std140-compatible: the output must equal the input.
+ }
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(Std140Test_Matrix, StructMatUniform_NameCollision) {
+ auto matrix = GetParam();
+ // S already has a member called `m_1`, the name a column member of `m` would otherwise get.
+ std::string src = R"(
+enable f16;
+
+struct S {
+ m_1 : i32,
+ m : ${mat},
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+)";
+ src = matrix.ReplaceFieldsInString(src);
+ // Expected for forked matrices: the column members use the prefix `m__` instead of `m_`.
+ std::string expect;
+ if (matrix.NotStd140Compatible()) {
+ expect = R"(
+enable f16;
+
+struct S {
+ m_1 : i32,
+ m : ${mat},
+}
+
+struct S_std140 {
+ m_1 : i32,
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+)";
+ expect = matrix.ReplaceFieldsInString(
+ expect, {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "m__")}});
+ } else {
+ expect = src;  // Already std140-compatible: the output must equal the input.
+ }
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(Std140Test_Matrix, StructMatUniform_LoadStruct) {
+ auto matrix = GetParam();
+ // The whole uniform struct, which contains the matrix, is loaded.
+ std::string src = R"(
+enable f16;
+
+struct S {
+ m : ${mat},
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let l = s;
+}
+)";
+ src = matrix.ReplaceFieldsInString(src);
+ // Expected for forked matrices: the load is wrapped in conv_S(), which rebuilds S from S_std140.
+ std::string expect;
+ if (matrix.NotStd140Compatible()) {
+ expect = R"(
+enable f16;
+
+struct S {
+ m : ${mat},
+}
+
+struct S_std140 {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn conv_S(val : S_std140) -> S {
+ return S(${mat}(${col_vectors_inline}));
+}
+
+fn f() {
+ let l = conv_S(s);
+}
+)";
+ expect = matrix.ReplaceFieldsInString(
+ expect,
+ {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "m_")},
+ {"${col_vectors_inline}", matrix.ExpendedColumnVectorsInline("val.m_", ", ")}});
+ } else {
+ expect = src;  // Already std140-compatible: the output must equal the input.
+ }
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(Std140Test_Matrix, StructMatUniform_LoadMatrix) {
+ auto matrix = GetParam();
+ // The matrix member of a uniform struct is loaded as a whole.
+ std::string src = R"(
+enable f16;
+
+struct S {
+ m : ${mat},
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let l = s.m;
+}
+)";
+ src = matrix.ReplaceFieldsInString(src);
+ // Expected for forked matrices: load_s_m() rebuilds the matrix from the `m_<col>` members.
+ std::string expect;
+ if (matrix.NotStd140Compatible()) {
+ expect = R"(
+enable f16;
+
+struct S {
+ m : ${mat},
+}
+
+struct S_std140 {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn load_s_m() -> ${mat} {
+ let s = &(s);
+ return ${mat}(${col_vectors_inline});
+}
+
+fn f() {
+ let l = load_s_m();
+}
+)";
+ expect = matrix.ReplaceFieldsInString(
+ expect,
+ {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "m_")},
+ {"${col_vectors_inline}", matrix.ExpendedColumnVectorsInline("(*(s)).m_", ", ")}});
+ } else {
+ expect = src;  // Already std140-compatible: the output must equal the input.
+ }
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(Std140Test_Matrix, StructMatUniform_LoadColumn_ConstIndex) {
+ auto matrix = GetParam();
+ // One column of a struct's matrix member is loaded with a constant index.
+ std::string tmpl_src = R"(
+enable f16;
+
+struct S {
+ m : ${mat},
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let l = s.m[${col_index}];
+}
+)";
+ tmpl_src = matrix.ReplaceFieldsInString(tmpl_src);
+ // Expected for forked matrices: the index becomes a direct `s.m_<index>` member access.
+ std::string tmpl_expect;
+ if (matrix.NotStd140Compatible()) {
+ tmpl_expect = R"(
+enable f16;
+
+struct S {
+ m : ${mat},
+}
+
+struct S_std140 {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn f() {
+ let l = s.m_${col_index};
+}
+)";
+ tmpl_expect = matrix.ReplaceFieldsInString(
+ tmpl_expect, {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "m_")}});
+ } else {
+ tmpl_expect = tmpl_src;  // Already std140-compatible: the output must equal the input.
+ }
+ // Instantiate and run the templates once per column.
+ for (uint32_t col = 0; col < matrix.columns; col++) {
+ std::string src = utils::ReplaceAll(tmpl_src, "${col_index}", std::to_string(col));
+ std::string expect = utils::ReplaceAll(tmpl_expect, "${col_index}", std::to_string(col));
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got)) << "accessing col " << col;
+ }
+}
+
+TEST_P(Std140Test_Matrix, StructMatUniform_LoadColumn_VariableIndex) {
+ auto matrix = GetParam();
+ // One column of a struct's matrix member is loaded using the `let` value I as index.
+ std::string src = R"(
+enable f16;
+
+struct S {
+ m : ${mat},
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let I = 0;
+ let l = s.m[I];
+}
+)";
+ src = matrix.ReplaceFieldsInString(src);
+ // Expected for forked matrices: the load calls load_s_m_p0(), a switch over the column index.
+ std::string expect;
+ if (matrix.NotStd140Compatible()) {
+ expect = R"(
+enable f16;
+
+struct S {
+ m : ${mat},
+}
+
+struct S_std140 {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn load_s_m_p0(p0 : u32) -> ${col_vector_type} {
+ switch(p0) {
+${col_table}
+ default: {
+ return ${col_vector_type}();
+ }
+ }
+}
+
+fn f() {
+ let I = 0;
+ let l = load_s_m_p0(u32(I));
+}
+)";
+ // Note: the explicit "${col_vector_type}" pair passed below duplicates a predefined field.
+ // col_table holds the switch cases for all column indices.
+ // Example for a matrix having 2 columns:
+ // case 0u: {
+ // return s.m_0;
+ // }
+ // case 1u: {
+ // return s.m_1;
+ // }
+ std::string col_table = matrix.JoinTemplatedStringForEachMatrixColumn( //
+ R"( case ${col_id_for_tmpl}u: {
+ return s.m_${col_id_for_tmpl};
+ })",
+ "\n");
+ expect = matrix.ReplaceFieldsInString(
+ expect, {{"${col_vector_type}", matrix.ColumnVector()},
+ {"${col_vectors}", matrix.ExpendedColumnVectors(2, "m_")},
+ {"${col_table}", col_table}});
+ } else {
+ expect = src;  // Already std140-compatible: the output must equal the input.
+ }
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(Std140Test_Matrix, StructMatUniform_LoadScalar_ConstColumnIndex_ConstRowIndex) {
+ auto matrix = GetParam();
+ // A scalar of a struct's matrix member is loaded with constant column and row indices.
+ std::string tmpl_src = R"(
+enable f16;
+
+struct S {
+ m : ${mat},
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let l = s.m[${col_index}][${row_index}];
+}
+)";
+ tmpl_src = matrix.ReplaceFieldsInString(tmpl_src);
+ // Expected for forked matrices: `s.m_<col>[<row>u]`, with no helper function.
+ std::string tmpl_expect;
+ if (matrix.NotStd140Compatible()) {
+ tmpl_expect = R"(
+enable f16;
+
+struct S {
+ m : ${mat},
+}
+
+struct S_std140 {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn f() {
+ let l = s.m_${col_index}[${row_index}u];
+}
+)";
+ tmpl_expect = matrix.ReplaceFieldsInString(
+ tmpl_expect, {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "m_")}});
+ } else {
+ tmpl_expect = tmpl_src;  // Already std140-compatible: the output must equal the input.
+ }
+ // Instantiate and run the templates for every (column, row) pair.
+ for (uint32_t col = 0; col < matrix.columns; col++) {
+ for (uint32_t row = 0; row < matrix.rows; row++) {
+ std::string src = utils::ReplaceAll(tmpl_src, "${col_index}", std::to_string(col));
+ src = utils::ReplaceAll(src, "${row_index}", std::to_string(row));
+ std::string expect =
+ utils::ReplaceAll(tmpl_expect, "${col_index}", std::to_string(col));
+ expect = utils::ReplaceAll(expect, "${row_index}", std::to_string(row));
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got)) << "accessing col " << col << " row " << row;
+ }
+ }
+}
+
+TEST_P(Std140Test_Matrix, StructMatUniform_LoadScalar_VariableColumnIndex_ConstRowIndex) {
+ auto matrix = GetParam();
+ // A scalar is loaded with the `let` value I as column index and a constant row index.
+ std::string tmpl_src = R"(
+enable f16;
+
+struct S {
+ m : ${mat},
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let I = 0;
+ let l = s.m[I][${row_index}];
+}
+)";
+ tmpl_src = matrix.ReplaceFieldsInString(tmpl_src);
+ // Expected for forked matrices: a load_s_m_p0_<row>() helper that switches over the column.
+ std::string tmpl_expect;
+ if (matrix.NotStd140Compatible()) {
+ tmpl_expect = R"(
+enable f16;
+
+struct S {
+ m : ${mat},
+}
+
+struct S_std140 {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn load_s_m_p0_${row_index}(p0 : u32) -> ${elem_type} {
+ switch(p0) {
+${col_table}
+ default: {
+ return ${elem_type}();
+ }
+ }
+}
+
+fn f() {
+ let I = 0;
+ let l = load_s_m_p0_${row_index}(u32(I));
+}
+)";
+ // col_table holds the switch cases for all column indices.
+ // Example for a matrix having 2 columns:
+ // case 0u: {
+ // return s.m_0[${row_index}u];
+ // }
+ // case 1u: {
+ // return s.m_1[${row_index}u];
+ // }
+ std::string col_table = matrix.JoinTemplatedStringForEachMatrixColumn( //
+ R"( case ${col_id_for_tmpl}u: {
+ return s.m_${col_id_for_tmpl}[${row_index}u];
+ })",
+ "\n");
+ tmpl_expect = matrix.ReplaceFieldsInString(
+ tmpl_expect, {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "m_")},
+ {"${col_table}", col_table}});
+ } else {
+ tmpl_expect = tmpl_src;  // Already std140-compatible: the output must equal the input.
+ }
+ // "${row_index}" is still present in both templates; substitute it and run once per row.
+ for (uint32_t row = 0; row < matrix.rows; row++) {
+ std::string src = utils::ReplaceAll(tmpl_src, "${row_index}", std::to_string(row));
+ std::string expect = utils::ReplaceAll(tmpl_expect, "${row_index}", std::to_string(row));
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got)) << "accessing row " << row;
+ }
+}
+
+TEST_P(Std140Test_Matrix, StructMatUniform_LoadScalar_ConstColumnIndex_VariableRowIndex) {
+ auto matrix = GetParam();
+ // A scalar is loaded with a constant column index and the `let` value I as row index.
+ std::string tmpl_src = R"(
+enable f16;
+
+struct S {
+ m : ${mat},
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let I = 0;
+ let l = s.m[${col_index}][I];
+}
+)";
+ tmpl_src = matrix.ReplaceFieldsInString(tmpl_src);
+ // Expected for forked matrices: `s.m_<col>[I]`, with no helper function.
+ std::string tmpl_expect;
+ if (matrix.NotStd140Compatible()) {
+ tmpl_expect = R"(
+enable f16;
+
+struct S {
+ m : ${mat},
+}
+
+struct S_std140 {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn f() {
+ let I = 0;
+ let l = s.m_${col_index}[I];
+}
+)";
+ tmpl_expect = matrix.ReplaceFieldsInString(
+ tmpl_expect, {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "m_")}});
+ } else {
+ tmpl_expect = tmpl_src;  // Already std140-compatible: the output must equal the input.
+ }
+ // Instantiate and run the templates once per column.
+ for (uint32_t col = 0; col < matrix.columns; col++) {
+ std::string src = utils::ReplaceAll(tmpl_src, "${col_index}", std::to_string(col));
+ std::string expect = utils::ReplaceAll(tmpl_expect, "${col_index}", std::to_string(col));
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got)) << "accessing col " << col;
+ }
+}
+
+TEST_P(Std140Test_Matrix, StructMatUniform_LoadScalar_VariableColumnIndex_VariableRowIndex) {
+ auto matrix = GetParam();
+ // A scalar is loaded using the `let` value I as both column and row index.
+ std::string src = R"(
+enable f16;
+
+struct S {
+ m : ${mat},
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let I = 0;
+ let l = s.m[I][I];
+}
+)";
+ src = matrix.ReplaceFieldsInString(src);
+ // Expected for forked matrices: load_s_m_p0_p1() switches over p0 and indexes the column by p1.
+ std::string expect;
+ if (matrix.NotStd140Compatible()) {
+ expect = R"(
+enable f16;
+
+struct S {
+ m : ${mat},
+}
+
+struct S_std140 {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn load_s_m_p0_p1(p0 : u32, p1 : u32) -> ${elem_type} {
+ switch(p0) {
+${col_table}
+ default: {
+ return ${elem_type}();
+ }
+ }
+}
+
+fn f() {
+ let I = 0;
+ let l = load_s_m_p0_p1(u32(I), u32(I));
+}
+)";
+
+ // col_table holds the switch cases for all column indices.
+ // Example for a matrix having 2 columns:
+ // case 0u: {
+ // return s.m_0[p1];
+ // }
+ // case 1u: {
+ // return s.m_1[p1];
+ // }
+ std::string col_table = matrix.JoinTemplatedStringForEachMatrixColumn( //
+ R"( case ${col_id_for_tmpl}u: {
+ return s.m_${col_id_for_tmpl}[p1];
+ })",
+ "\n");
+ expect = matrix.ReplaceFieldsInString(
+ expect, {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "m_")},
+ {"${col_table}", col_table}});
+ } else {
+ expect = src;  // Already std140-compatible: the output must equal the input.
+ }
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(Std140Test_Matrix, ArrayStructMatUniform_LoadArray) {
+ auto matrix = GetParam();
+ // The uniform is an array of 3 structs, each holding a @size(64) matrix; the array is loaded.
+ std::string src = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : ${mat},
+}
+
+@group(0) @binding(0) var<uniform> a : array<S, 3>;
+
+fn f() {
+ let l = a;
+}
+)";
+ src = matrix.ReplaceFieldsInString(src);
+ // Expected for forked matrices: conv_arr3_S() converts each element with conv_S() in a loop.
+ std::string expect;
+ if (matrix.NotStd140Compatible()) {
+ expect = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : ${mat},
+}
+
+struct S_std140 {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> a : array<S_std140, 3u>;
+
+fn conv_S(val : S_std140) -> S {
+ return S(${mat}(${col_vectors_inline}));
+}
+
+fn conv_arr3_S(val : array<S_std140, 3u>) -> array<S, 3u> {
+ var arr : array<S, 3u>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ arr[i] = conv_S(val[i]);
+ }
+ return arr;
+}
+
+fn f() {
+ let l = conv_arr3_S(a);
+}
+)";
+ uint32_t last_size =  // What is left of 64 after the first (columns - 1) columns.
+ 64 - static_cast<uint32_t>(matrix.ColumnVectorAlign() * (matrix.columns - 1));
+ expect = matrix.ReplaceFieldsInString(
+ expect,
+ {{"${col_vectors}", matrix.ExpendedColumnVectorsWithLastSize(2, "m_", last_size)},
+ {"${col_vectors_inline}", matrix.ExpendedColumnVectorsInline("val.m_", ", ")}});
+ } else {
+ expect = src;  // Already std140-compatible: the output must equal the input.
+ }
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(Std140Test_Matrix, ArrayStructMatUniform_LoadStruct_ConstIndex) {
+    auto param = GetParam();
+    // Input template: one struct is loaded from the uniform array at a literal ${array_index}.
+    std::string input_tmpl = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : ${mat},
+}
+
+@group(0) @binding(0) var<uniform> a : array<S, 3>;
+
+fn f() {
+ let l = a[${array_index}];
+}
+)";
+    input_tmpl = param.ReplaceFieldsInString(input_tmpl);
+
+    std::string expected_tmpl;
+    if (!param.NotStd140Compatible()) {
+        expected_tmpl = input_tmpl;  // Nothing to rewrite: the module must come back unchanged.
+    } else {
+        expected_tmpl = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : ${mat},
+}
+
+struct S_std140 {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> a : array<S_std140, 3u>;
+
+fn conv_S(val : S_std140) -> S {
+ return S(${mat}(${col_vectors_inline}));
+}
+
+fn f() {
+ let l = conv_S(a[${array_index}u]);
+}
+)";
+        uint32_t tail_size =  // 64 mirrors the @size(64) on member m.
+            64 - static_cast<uint32_t>(param.ColumnVectorAlign() * (param.columns - 1));
+        expected_tmpl = param.ReplaceFieldsInString(
+            expected_tmpl,
+            {{"${col_vectors}", param.ExpendedColumnVectorsWithLastSize(2, "m_", tail_size)},
+             {"${col_vectors_inline}", param.ExpendedColumnVectorsInline("val.m_", ", ")}});
+    }
+
+    for (uint32_t idx = 0; idx != 3; ++idx) {
+        std::string input =
+            utils::ReplaceAll(input_tmpl, "${array_index}", std::to_string(idx));
+        std::string expected =
+            utils::ReplaceAll(expected_tmpl, "${array_index}", std::to_string(idx));
+
+        auto output = Run<Std140>(input);
+
+        EXPECT_EQ(expected, str(output)) << "accessing array element " << idx;
+    }
+}
+
+TEST_P(Std140Test_Matrix, ArrayStructMatUniform_LoadStruct_VariableIndex) {
+    auto param = GetParam();
+    // Input: one struct is loaded from the uniform array using an index held in a `let`.
+    std::string input = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : ${mat},
+}
+
+@group(0) @binding(0) var<uniform> a : array<S, 3>;
+
+fn f() {
+ let I = 1;
+ let l = a[I];
+}
+)";
+    input = param.ReplaceFieldsInString(input);
+
+    std::string expected;
+    if (!param.NotStd140Compatible()) {
+        expected = input;  // Nothing to rewrite: the module must come back unchanged.
+    } else {
+        expected = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : ${mat},
+}
+
+struct S_std140 {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> a : array<S_std140, 3u>;
+
+fn conv_S(val : S_std140) -> S {
+ return S(${mat}(${col_vectors_inline}));
+}
+
+fn f() {
+ let I = 1;
+ let l = conv_S(a[I]);
+}
+)";
+        uint32_t tail_size =  // 64 mirrors the @size(64) on member m.
+            64 - static_cast<uint32_t>(param.ColumnVectorAlign() * (param.columns - 1));
+        expected = param.ReplaceFieldsInString(
+            expected,
+            {{"${col_vectors}", param.ExpendedColumnVectorsWithLastSize(2, "m_", tail_size)},
+             {"${col_vectors_inline}", param.ExpendedColumnVectorsInline("val.m_", ", ")}});
+    }
+
+    auto output = Run<Std140>(input);
+
+    EXPECT_EQ(expected, str(output));
+}
+
+TEST_P(Std140Test_Matrix, ArrayStructMatUniform_LoadMatrix_ConstArrayIndex) {
+    auto param = GetParam();
+    // Input template: the matrix member of the element at a literal ${array_index} is loaded.
+    std::string input_tmpl = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : ${mat},
+}
+
+@group(0) @binding(0) var<uniform> a : array<S, 3>;
+
+fn f() {
+ let l = a[${array_index}].m;
+}
+)";
+    input_tmpl = param.ReplaceFieldsInString(input_tmpl);
+
+    std::string expected_tmpl;
+    if (!param.NotStd140Compatible()) {
+        expected_tmpl = input_tmpl;  // Nothing to rewrite: the module must come back unchanged.
+    } else {
+        expected_tmpl = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : ${mat},
+}
+
+struct S_std140 {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> a : array<S_std140, 3u>;
+
+fn load_a_${array_index}_m() -> ${mat} {
+ let s = &(a[${array_index}u]);
+ return ${mat}(${col_vectors_inline});
+}
+
+fn f() {
+ let l = load_a_${array_index}_m();
+}
+)";
+        uint32_t tail_size =  // 64 mirrors the @size(64) on member m.
+            64 - static_cast<uint32_t>(param.ColumnVectorAlign() * (param.columns - 1));
+        expected_tmpl = param.ReplaceFieldsInString(
+            expected_tmpl,
+            {{"${col_vectors}", param.ExpendedColumnVectorsWithLastSize(2, "m_", tail_size)},
+             {"${col_vectors_inline}", param.ExpendedColumnVectorsInline("(*(s)).m_", ", ")}});
+    }
+
+    for (uint32_t idx = 0; idx != 3; ++idx) {
+        std::string input =
+            utils::ReplaceAll(input_tmpl, "${array_index}", std::to_string(idx));
+        std::string expected =
+            utils::ReplaceAll(expected_tmpl, "${array_index}", std::to_string(idx));
+
+        auto output = Run<Std140>(input);
+
+        EXPECT_EQ(expected, str(output)) << "accessing array element " << idx;
+    }
+}
+
+TEST_P(Std140Test_Matrix, ArrayStructMatUniform_LoadMatrix_VariableArrayIndex) {
+    auto param = GetParam();
+    // Input: the matrix member is loaded from the array element selected by a `let` index.
+    std::string input = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : ${mat},
+}
+
+@group(0) @binding(0) var<uniform> a : array<S, 3>;
+
+fn f() {
+ let I = 1;
+ let l = a[I].m;
+}
+)";
+    input = param.ReplaceFieldsInString(input);
+
+    std::string expected;
+    if (!param.NotStd140Compatible()) {
+        expected = input;  // Nothing to rewrite: the module must come back unchanged.
+    } else {
+        expected = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : ${mat},
+}
+
+struct S_std140 {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> a : array<S_std140, 3u>;
+
+fn load_a_p0_m(p0 : u32) -> ${mat} {
+ let s = &(a[p0]);
+ return ${mat}(${col_vectors_inline});
+}
+
+fn f() {
+ let I = 1;
+ let l = load_a_p0_m(u32(I));
+}
+)";
+        uint32_t tail_size =  // 64 mirrors the @size(64) on member m.
+            64 - static_cast<uint32_t>(param.ColumnVectorAlign() * (param.columns - 1));
+        expected = param.ReplaceFieldsInString(
+            expected,
+            {{"${col_vectors}", param.ExpendedColumnVectorsWithLastSize(2, "m_", tail_size)},
+             {"${col_vectors_inline}", param.ExpendedColumnVectorsInline("(*(s)).m_", ", ")}});
+    }
+
+    auto output = Run<Std140>(input);
+
+    EXPECT_EQ(expected, str(output));
+}
+
+TEST_P(Std140Test_Matrix, ArrayStructMatUniform_LoadColumn_ConstArrayIndex_ConstColumnIndex) {
+    auto matrix = GetParam();
+    // Input template: a column is loaded using literal ${array_index} and ${column_index} values.
+    std::string tmpl_src = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : ${mat},
+}
+
+@group(0) @binding(0) var<uniform> a : array<S, 3>;
+
+fn f() {
+ let l = a[${array_index}].m[${column_index}];
+}
+)";
+    tmpl_src = matrix.ReplaceFieldsInString(tmpl_src);
+
+    std::string tmpl_expect;
+    if (matrix.NotStd140Compatible()) {
+        tmpl_expect = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : ${mat},
+}
+
+struct S_std140 {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> a : array<S_std140, 3u>;
+
+fn f() {
+ let l = a[${array_index}u].m_${column_index};
+}
+)";
+        uint32_t last_size =  // 64 mirrors the @size(64) on member m.
+            64 - static_cast<uint32_t>(matrix.ColumnVectorAlign() * (matrix.columns - 1));
+        tmpl_expect = matrix.ReplaceFieldsInString(
+            tmpl_expect,
+            {{"${col_vectors}", matrix.ExpendedColumnVectorsWithLastSize(2, "m_", last_size)}});
+    } else {
+        tmpl_expect = tmpl_src;  // Nothing to rewrite: the module must come back unchanged.
+    }
+
+    for (uint32_t array_index = 0; array_index < 3; array_index++) {
+        for (uint32_t col = 0; col < matrix.columns; col++) {
+            std::string src =
+                utils::ReplaceAll(tmpl_src, "${array_index}", std::to_string(array_index));
+            src = utils::ReplaceAll(src, "${column_index}", std::to_string(col));
+            std::string expect =
+                utils::ReplaceAll(tmpl_expect, "${array_index}", std::to_string(array_index));
+            expect = utils::ReplaceAll(expect, "${column_index}", std::to_string(col));
+
+            auto got = Run<Std140>(src);
+
+            EXPECT_EQ(expect, str(got))
+                << "accessing array element " << array_index << " col " << col;
+        }
+    }
+}
+
+TEST_P(Std140Test_Matrix, ArrayStructMatUniform_LoadColumn_VariableArrayIndex_ConstColumnIndex) {
+    auto matrix = GetParam();
+    // Input template: a column at literal ${column_index} is loaded from the element at `let` I.
+    std::string tmpl_src = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : ${mat},
+}
+
+@group(0) @binding(0) var<uniform> a : array<S, 3>;
+
+fn f() {
+ let I = 1;
+ let l = a[I].m[${column_index}];
+}
+)";
+    tmpl_src = matrix.ReplaceFieldsInString(tmpl_src);
+
+    std::string tmpl_expect;
+    if (matrix.NotStd140Compatible()) {
+        tmpl_expect = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : ${mat},
+}
+
+struct S_std140 {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> a : array<S_std140, 3u>;
+
+fn f() {
+ let I = 1;
+ let l = a[I].m_${column_index};
+}
+)";
+        uint32_t last_size =  // 64 mirrors the @size(64) on member m.
+            64 - static_cast<uint32_t>(matrix.ColumnVectorAlign() * (matrix.columns - 1));
+        tmpl_expect = matrix.ReplaceFieldsInString(
+            tmpl_expect,
+            {{"${col_vectors}", matrix.ExpendedColumnVectorsWithLastSize(2, "m_", last_size)}});
+    } else {
+        tmpl_expect = tmpl_src;  // Nothing to rewrite: the module must come back unchanged.
+    }
+
+    for (uint32_t col = 0; col < matrix.columns; col++) {
+        std::string src = utils::ReplaceAll(tmpl_src, "${column_index}", std::to_string(col));
+        std::string expect = utils::ReplaceAll(tmpl_expect, "${column_index}", std::to_string(col));
+
+        auto got = Run<Std140>(src);
+
+        EXPECT_EQ(expect, str(got)) << "accessing col " << col;
+    }
+}
+
+TEST_P(Std140Test_Matrix, ArrayStructMatUniform_LoadColumn_ConstArrayIndex_VariableColumnIndex) {
+    auto param = GetParam();
+    // Input template: a column selected by `let` I is loaded from the element at ${array_index}.
+    std::string input_tmpl = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : ${mat},
+}
+
+@group(0) @binding(0) var<uniform> a : array<S, 3>;
+
+fn f() {
+ let I = 1;
+ let l = a[${array_index}].m[I];
+}
+)";
+    input_tmpl = param.ReplaceFieldsInString(input_tmpl);
+
+    std::string expected_tmpl;
+    if (!param.NotStd140Compatible()) {
+        expected_tmpl = input_tmpl;  // Nothing to rewrite: the module must come back unchanged.
+    } else {
+        expected_tmpl = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : ${mat},
+}
+
+struct S_std140 {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> a : array<S_std140, 3u>;
+
+fn load_a_${array_index}_m_p0(p0 : u32) -> ${col_vector_type} {
+ switch(p0) {
+${col_table}
+ default: {
+ return ${col_vector_type}();
+ }
+ }
+}
+
+fn f() {
+ let I = 1;
+ let l = load_a_${array_index}_m_p0(u32(I));
+}
+)";
+        // switch_cases holds one switch case per matrix column.
+        // For a matrix with 2 columns it expands to:
+        // case 0u: {
+        // return a[${array_index}u].m_0;
+        // }
+        // case 1u: {
+        // return a[${array_index}u].m_1;
+        // }
+        std::string switch_cases = param.JoinTemplatedStringForEachMatrixColumn(  //
+            R"( case ${col_id_for_tmpl}u: {
+ return a[${array_index}u].m_${col_id_for_tmpl};
+ })",
+            "\n");
+        uint32_t tail_size =  // 64 mirrors the @size(64) on member m.
+            64 - static_cast<uint32_t>(param.ColumnVectorAlign() * (param.columns - 1));
+        expected_tmpl = param.ReplaceFieldsInString(
+            expected_tmpl,
+            {{"${col_vectors}", param.ExpendedColumnVectorsWithLastSize(2, "m_", tail_size)},
+             {"${col_table}", switch_cases}});
+    }
+
+    for (uint32_t idx = 0; idx != 3; ++idx) {
+        std::string input =
+            utils::ReplaceAll(input_tmpl, "${array_index}", std::to_string(idx));
+        std::string expected =
+            utils::ReplaceAll(expected_tmpl, "${array_index}", std::to_string(idx));
+
+        auto output = Run<Std140>(input);
+
+        EXPECT_EQ(expected, str(output)) << "accessing array element " << idx;
+    }
+}
+
+TEST_P(Std140Test_Matrix, ArrayStructMatUniform_LoadColumn_VariableArrayIndex_VariableColumnIndex) {
+    auto param = GetParam();
+    // Input: a column is loaded with both the array index and the column index held in a `let`.
+    std::string input = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : ${mat},
+}
+
+@group(0) @binding(0) var<uniform> a : array<S, 3>;
+
+fn f() {
+ let I = 1;
+ let l = a[I].m[I];
+}
+)";
+    input = param.ReplaceFieldsInString(input);
+
+    std::string expected;
+    if (!param.NotStd140Compatible()) {
+        expected = input;  // Nothing to rewrite: the module must come back unchanged.
+    } else {
+        expected = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : ${mat},
+}
+
+struct S_std140 {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> a : array<S_std140, 3u>;
+
+fn load_a_p0_m_p1(p0 : u32, p1 : u32) -> ${col_vector_type} {
+ switch(p1) {
+${col_table}
+ default: {
+ return ${col_vector_type}();
+ }
+ }
+}
+
+fn f() {
+ let I = 1;
+ let l = load_a_p0_m_p1(u32(I), u32(I));
+}
+)";
+        // switch_cases holds one switch case per matrix column.
+        // For a matrix with 2 columns it expands to:
+        // case 0u: {
+        // return a[p0].m_0;
+        // }
+        // case 1u: {
+        // return a[p0].m_1;
+        // }
+        std::string switch_cases = param.JoinTemplatedStringForEachMatrixColumn(  //
+            R"( case ${col_id_for_tmpl}u: {
+ return a[p0].m_${col_id_for_tmpl};
+ })",
+            "\n");
+        uint32_t tail_size =  // 64 mirrors the @size(64) on member m.
+            64 - static_cast<uint32_t>(param.ColumnVectorAlign() * (param.columns - 1));
+        expected = param.ReplaceFieldsInString(
+            expected,
+            {{"${col_vectors}", param.ExpendedColumnVectorsWithLastSize(2, "m_", tail_size)},
+             {"${col_table}", switch_cases}});
+    }
+
+    auto output = Run<Std140>(input);
+
+    EXPECT_EQ(expected, str(output));
+}
+
+TEST_P(Std140Test_Matrix, ArrayStructArrayStructMatUniform_Loads) {
+    auto param = GetParam();
+    // Input: loads at every nesting level of array<Outer> -> array<Inner> -> matrix -> column.
+    std::string input = R"(
+enable f16;
+
+struct Inner {
+ @size(64)
+ m : ${mat},
+}
+
+struct Outer {
+ a : array<Inner, 4>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<Outer, 4>;
+
+fn f() {
+ let I = 1;
+ let J = 2;
+ let K = 0;
+ let l_a : array<Outer, 4> = a;
+ let l_a_1 : Outer = a[1];
+ let l_a_I : Outer = a[I];
+ let l_a_2_a : array<Inner, 4> = a[2].a;
+ let l_a_I_a : array<Inner, 4> = a[I].a;
+ let l_a_3_a_1 : Inner = a[3].a[1];
+ let l_a_3_a_I : Inner = a[3].a[I];
+ let l_a_I_a_1 : Inner = a[I].a[1];
+ let l_a_I_a_J : Inner = a[I].a[J];
+ let l_a_0_a_2_m : ${mat} = a[0].a[2].m;
+ let l_a_0_a_I_m : ${mat} = a[0].a[I].m;
+ let l_a_I_a_2_m : ${mat} = a[I].a[2].m;
+ let l_a_I_a_J_m : ${mat} = a[I].a[J].m;
+ let l_a_1_a_3_m_0 : ${col_vector_type} = a[1].a[3].m[0];
+ let l_a_I_a_J_m_K : ${col_vector_type} = a[I].a[J].m[K];
+ let l_a_2_a_0_m_1_0 : ${elem_type} = a[2].a[0].m[1][0];
+ let l_a_I_a_J_m_K_I : ${elem_type} = a[I].a[J].m[K][I];
+}
+)";
+    input = param.ReplaceFieldsInString(input);
+
+    std::string expected;
+    if (!param.NotStd140Compatible()) {
+        expected = input;  // Nothing to rewrite: the module must come back unchanged.
+    } else {
+        expected = R"(
+enable f16;
+
+struct Inner {
+ @size(64)
+ m : ${mat},
+}
+
+struct Inner_std140 {
+${col_vectors}
+}
+
+struct Outer {
+ a : array<Inner, 4>,
+}
+
+struct Outer_std140 {
+ a : array<Inner_std140, 4u>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<Outer_std140, 4u>;
+
+fn conv_Inner(val : Inner_std140) -> Inner {
+ return Inner(${mat}(${col_vectors_inline_conv_Inner}));
+}
+
+fn conv_arr4_Inner(val : array<Inner_std140, 4u>) -> array<Inner, 4u> {
+ var arr : array<Inner, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ arr[i] = conv_Inner(val[i]);
+ }
+ return arr;
+}
+
+fn conv_Outer(val : Outer_std140) -> Outer {
+ return Outer(conv_arr4_Inner(val.a));
+}
+
+fn conv_arr4_Outer(val : array<Outer_std140, 4u>) -> array<Outer, 4u> {
+ var arr : array<Outer, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ arr[i] = conv_Outer(val[i]);
+ }
+ return arr;
+}
+
+fn load_a_0_a_2_m() -> ${mat} {
+ let s = &(a[0u].a[2u]);
+ return ${mat}(${col_vectors_inline_load_matrix});
+}
+
+fn load_a_0_a_p0_m(p0 : u32) -> ${mat} {
+ let s = &(a[0u].a[p0]);
+ return ${mat}(${col_vectors_inline_load_matrix});
+}
+
+fn load_a_p0_a_2_m(p0 : u32) -> ${mat} {
+ let s = &(a[p0].a[2u]);
+ return ${mat}(${col_vectors_inline_load_matrix});
+}
+
+fn load_a_p0_a_p1_m(p0 : u32, p1 : u32) -> ${mat} {
+ let s = &(a[p0].a[p1]);
+ return ${mat}(${col_vectors_inline_load_matrix});
+}
+
+fn load_a_p0_a_p1_m_p2(p0 : u32, p1 : u32, p2 : u32) -> ${col_vector_type} {
+ switch(p2) {
+${col_table_load_column}
+ default: {
+ return ${col_vector_type}();
+ }
+ }
+}
+
+fn load_a_p0_a_p1_m_p2_p3(p0 : u32, p1 : u32, p2 : u32, p3 : u32) -> ${elem_type} {
+ switch(p2) {
+${col_table_load_element}
+ default: {
+ return ${elem_type}();
+ }
+ }
+}
+
+fn f() {
+ let I = 1;
+ let J = 2;
+ let K = 0;
+ let l_a : array<Outer, 4> = conv_arr4_Outer(a);
+ let l_a_1 : Outer = conv_Outer(a[1u]);
+ let l_a_I : Outer = conv_Outer(a[I]);
+ let l_a_2_a : array<Inner, 4> = conv_arr4_Inner(a[2u].a);
+ let l_a_I_a : array<Inner, 4> = conv_arr4_Inner(a[I].a);
+ let l_a_3_a_1 : Inner = conv_Inner(a[3u].a[1u]);
+ let l_a_3_a_I : Inner = conv_Inner(a[3u].a[I]);
+ let l_a_I_a_1 : Inner = conv_Inner(a[I].a[1u]);
+ let l_a_I_a_J : Inner = conv_Inner(a[I].a[J]);
+ let l_a_0_a_2_m : ${mat} = load_a_0_a_2_m();
+ let l_a_0_a_I_m : ${mat} = load_a_0_a_p0_m(u32(I));
+ let l_a_I_a_2_m : ${mat} = load_a_p0_a_2_m(u32(I));
+ let l_a_I_a_J_m : ${mat} = load_a_p0_a_p1_m(u32(I), u32(J));
+ let l_a_1_a_3_m_0 : ${col_vector_type} = a[1u].a[3u].m_0;
+ let l_a_I_a_J_m_K : ${col_vector_type} = load_a_p0_a_p1_m_p2(u32(I), u32(J), u32(K));
+ let l_a_2_a_0_m_1_0 : ${elem_type} = a[2u].a[0u].m_1[0u];
+ let l_a_I_a_J_m_K_I : ${elem_type} = load_a_p0_a_p1_m_p2_p3(u32(I), u32(J), u32(K), u32(I));
+}
+)";
+        std::string column_cases = param.JoinTemplatedStringForEachMatrixColumn(  //
+            R"( case ${col_id_for_tmpl}u: {
+ return a[p0].a[p1].m_${col_id_for_tmpl};
+ })",
+            "\n");
+        std::string element_cases = param.JoinTemplatedStringForEachMatrixColumn(  //
+            R"( case ${col_id_for_tmpl}u: {
+ return a[p0].a[p1].m_${col_id_for_tmpl}[p3];
+ })",
+            "\n");
+        uint32_t tail_size =  // 64 mirrors the @size(64) on member m.
+            64 - static_cast<uint32_t>(param.ColumnVectorAlign() * (param.columns - 1));
+        expected = param.ReplaceFieldsInString(
+            expected,
+            {{"${col_vectors}", param.ExpendedColumnVectorsWithLastSize(2, "m_", tail_size)},
+             {"${col_vectors_inline_conv_Inner}",
+              param.ExpendedColumnVectorsInline("val.m_", ", ")},
+             {"${col_vectors_inline_load_matrix}",
+              param.ExpendedColumnVectorsInline("(*(s)).m_", ", ")},
+             {"${col_table_load_column}", column_cases},
+             {"${col_table_load_element}", element_cases}});
+    }
+
+    auto output = Run<Std140>(input);
+
+    EXPECT_EQ(expected, str(output));
+}
+
+TEST_P(Std140Test_Matrix, ArrayStructArrayStructMatUniform_LoadsViaPtrs) {
+    auto matrix = GetParam();
+    // Input: the nested array/struct/matrix is read through chains of pointer `let`s (p_*).
+    std::string src = R"(
+enable f16;
+
+struct Inner {
+ @size(64)
+ m : ${mat},
+}
+
+struct Outer {
+ a : array<Inner, 4>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<Outer, 4>;
+
+fn f() {
+ let I = 1;
+ let J = 2;
+ let K = 0;
+ let p_a = &(a);
+ let p_a_3 = &((*(p_a))[3]);
+ let p_a_I = &((*(p_a))[I]);
+ let p_a_3_a = &((*(p_a_3)).a);
+ let p_a_I_a = &((*(p_a_I)).a);
+ let p_a_3_a_2 = &((*(p_a_3_a))[2]);
+ let p_a_3_a_I = &((*(p_a_3_a))[I]);
+ let p_a_I_a_2 = &((*(p_a_I_a))[2]);
+ let p_a_I_a_J = &((*(p_a_I_a))[J]);
+ let p_a_3_a_2_m = &((*(p_a_3_a_2)).m);
+ let p_a_3_a_I_m = &((*(p_a_3_a_I)).m);
+ let p_a_I_a_2_m = &((*(p_a_I_a_2)).m);
+ let p_a_I_a_J_m = &((*(p_a_I_a_J)).m);
+ let p_a_3_a_2_m_1 = &((*(p_a_3_a_2_m))[1]);
+ let p_a_I_a_J_m_K = &((*(p_a_I_a_J_m))[K]);
+ let l_a : array<Outer, 4> = *(p_a);
+ let l_a_3 : Outer = *(p_a_3);
+ let l_a_I : Outer = *(p_a_I);
+ let l_a_3_a : array<Inner, 4> = *(p_a_3_a);
+ let l_a_I_a : array<Inner, 4> = *(p_a_I_a);
+ let l_a_3_a_2 : Inner = *(p_a_3_a_2);
+ let l_a_3_a_I : Inner = *(p_a_3_a_I);
+ let l_a_I_a_2 : Inner = *(p_a_I_a_2);
+ let l_a_I_a_J : Inner = *(p_a_I_a_J);
+ let l_a_3_a_2_m : ${mat} = *(p_a_3_a_2_m);
+ let l_a_3_a_I_m : ${mat} = *(p_a_3_a_I_m);
+ let l_a_I_a_2_m : ${mat} = *(p_a_I_a_2_m);
+ let l_a_I_a_J_m : ${mat} = *(p_a_I_a_J_m);
+ let l_a_3_a_2_m_1 : ${col_vector_type} = *(p_a_3_a_2_m_1);
+ let l_a_I_a_J_m_K : ${col_vector_type} = *(p_a_I_a_J_m_K);
+ let l_a_3_a_2_m_1_0 : ${elem_type} = (*(p_a_3_a_2_m_1))[0];
+ let l_a_I_a_J_m_K_I : ${elem_type} = (*(p_a_I_a_J_m_K))[I];
+}
+)";
+    src = matrix.ReplaceFieldsInString(src);
+
+    std::string expect;
+    if (matrix.NotStd140Compatible()) {
+        expect = R"(
+enable f16;
+
+struct Inner {
+ @size(64)
+ m : ${mat},
+}
+
+struct Inner_std140 {
+${col_vectors}
+}
+
+struct Outer {
+ a : array<Inner, 4>,
+}
+
+struct Outer_std140 {
+ a : array<Inner_std140, 4u>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<Outer_std140, 4u>;
+
+fn conv_Inner(val : Inner_std140) -> Inner {
+ return Inner(${mat}(${col_vectors_inline_conv_Inner}));
+}
+
+fn conv_arr4_Inner(val : array<Inner_std140, 4u>) -> array<Inner, 4u> {
+ var arr : array<Inner, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ arr[i] = conv_Inner(val[i]);
+ }
+ return arr;
+}
+
+fn conv_Outer(val : Outer_std140) -> Outer {
+ return Outer(conv_arr4_Inner(val.a));
+}
+
+fn conv_arr4_Outer(val : array<Outer_std140, 4u>) -> array<Outer, 4u> {
+ var arr : array<Outer, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ arr[i] = conv_Outer(val[i]);
+ }
+ return arr;
+}
+
+fn load_a_3_a_2_m() -> ${mat} {
+ let s = &(a[3u].a[2u]);
+ return ${mat}(${col_vectors_inline_load_matrix});
+}
+
+fn load_a_3_a_p0_m(p0 : u32) -> ${mat} {
+ let s = &(a[3u].a[p0]);
+ return ${mat}(${col_vectors_inline_load_matrix});
+}
+
+fn load_a_p0_a_2_m(p0 : u32) -> ${mat} {
+ let s = &(a[p0].a[2u]);
+ return ${mat}(${col_vectors_inline_load_matrix});
+}
+
+fn load_a_p0_a_p1_m(p0 : u32, p1 : u32) -> ${mat} {
+ let s = &(a[p0].a[p1]);
+ return ${mat}(${col_vectors_inline_load_matrix});
+}
+
+fn load_a_p0_a_p1_m_p2(p0 : u32, p1 : u32, p2 : u32) -> ${col_vector_type} {
+ switch(p2) {
+${col_table_load_column}
+ default: {
+ return ${col_vector_type}();
+ }
+ }
+}
+
+fn load_a_p0_a_p1_m_p2_p3(p0 : u32, p1 : u32, p2 : u32, p3 : u32) -> ${elem_type} {
+ switch(p2) {
+${col_table_load_element}
+ default: {
+ return ${elem_type}();
+ }
+ }
+}
+
+fn f() {
+ let I = 1;
+ let J = 2;
+ let K = 0;
+ let p_a = conv_arr4_Outer(a);
+ let p_a_3 = conv_Outer(a[3u]);
+ let p_a_I = conv_Outer(a[I]);
+ let p_a_3_a = conv_arr4_Inner(a[3u].a);
+ let p_a_I_a = conv_arr4_Inner(a[I].a);
+ let p_a_3_a_2 = conv_Inner(a[3u].a[2u]);
+ let p_a_3_a_I = conv_Inner(a[3u].a[I]);
+ let p_a_I_a_2 = conv_Inner(a[I].a[2u]);
+ let p_a_I_a_J = conv_Inner(a[I].a[J]);
+ let p_a_3_a_2_m = load_a_3_a_2_m();
+ let p_a_3_a_I_m = load_a_3_a_p0_m(u32(I));
+ let p_a_I_a_2_m = load_a_p0_a_2_m(u32(I));
+ let p_a_I_a_J_m = load_a_p0_a_p1_m(u32(I), u32(J));
+ let p_a_3_a_2_m_1 = a[3u].a[2u].m_1;
+ let p_a_I_a_J_m_K = load_a_p0_a_p1_m_p2(u32(I), u32(J), u32(K));
+ let l_a : array<Outer, 4> = conv_arr4_Outer(a);
+ let l_a_3 : Outer = conv_Outer(a[3u]);
+ let l_a_I : Outer = conv_Outer(a[I]);
+ let l_a_3_a : array<Inner, 4> = conv_arr4_Inner(a[3u].a);
+ let l_a_I_a : array<Inner, 4> = conv_arr4_Inner(a[I].a);
+ let l_a_3_a_2 : Inner = conv_Inner(a[3u].a[2u]);
+ let l_a_3_a_I : Inner = conv_Inner(a[3u].a[I]);
+ let l_a_I_a_2 : Inner = conv_Inner(a[I].a[2u]);
+ let l_a_I_a_J : Inner = conv_Inner(a[I].a[J]);
+ let l_a_3_a_2_m : ${mat} = load_a_3_a_2_m();
+ let l_a_3_a_I_m : ${mat} = load_a_3_a_p0_m(u32(I));
+ let l_a_I_a_2_m : ${mat} = load_a_p0_a_2_m(u32(I));
+ let l_a_I_a_J_m : ${mat} = load_a_p0_a_p1_m(u32(I), u32(J));
+ let l_a_3_a_2_m_1 : ${col_vector_type} = a[3u].a[2u].m_1;
+ let l_a_I_a_J_m_K : ${col_vector_type} = load_a_p0_a_p1_m_p2(u32(I), u32(J), u32(K));
+ let l_a_3_a_2_m_1_0 : ${elem_type} = a[3u].a[2u].m_1[0u];
+ let l_a_I_a_J_m_K_I : ${elem_type} = load_a_p0_a_p1_m_p2_p3(u32(I), u32(J), u32(K), u32(I));
+}
+)";
+        std::string col_table_load_column = matrix.JoinTemplatedStringForEachMatrixColumn(  //
+            R"( case ${col_id_for_tmpl}u: {
+ return a[p0].a[p1].m_${col_id_for_tmpl};
+ })",
+            "\n");
+        std::string col_table_load_element = matrix.JoinTemplatedStringForEachMatrixColumn(  //
+            R"( case ${col_id_for_tmpl}u: {
+ return a[p0].a[p1].m_${col_id_for_tmpl}[p3];
+ })",
+            "\n");
+        uint32_t last_size =  // 64 mirrors the @size(64) on member m.
+            64 - static_cast<uint32_t>(matrix.ColumnVectorAlign() * (matrix.columns - 1));
+        expect = matrix.ReplaceFieldsInString(
+            expect,
+            {{"${col_vectors}", matrix.ExpendedColumnVectorsWithLastSize(2, "m_", last_size)},
+             {"${col_vectors_inline_conv_Inner}",
+              matrix.ExpendedColumnVectorsInline("val.m_", ", ")},
+             {"${col_vectors_inline_load_matrix}",
+              matrix.ExpendedColumnVectorsInline("(*(s)).m_", ", ")},
+             {"${col_table_load_column}", col_table_load_column},
+             {"${col_table_load_element}", col_table_load_element}});
+    } else {
+        expect = src;  // Nothing to rewrite: the module must come back unchanged.
+    }
+
+    auto got = Run<Std140>(src);
+
+    EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(Std140Test_Matrix, ArrayStructMatUniform_CopyArray_UniformToStorage) {
+    auto param = GetParam();
+    // Input: the whole uniform array is assigned to a read_write storage array.
+    std::string input = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : ${mat},
+}
+
+@group(0) @binding(0) var<uniform> u : array<S, 4>;
+
+@group(0) @binding(1) var<storage, read_write> s : array<S, 4>;
+
+fn f() {
+ s = u;
+}
+)";
+    input = param.ReplaceFieldsInString(input);
+
+    std::string expected;
+    if (!param.NotStd140Compatible()) {
+        expected = input;  // Nothing to rewrite: the module must come back unchanged.
+    } else {
+        expected = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : ${mat},
+}
+
+struct S_std140 {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> u : array<S_std140, 4u>;
+
+@group(0) @binding(1) var<storage, read_write> s : array<S, 4>;
+
+fn conv_S(val : S_std140) -> S {
+ return S(${mat}(${col_vectors_inline}));
+}
+
+fn conv_arr4_S(val : array<S_std140, 4u>) -> array<S, 4u> {
+ var arr : array<S, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ arr[i] = conv_S(val[i]);
+ }
+ return arr;
+}
+
+fn f() {
+ s = conv_arr4_S(u);
+}
+)";
+        uint32_t tail_size =  // 64 mirrors the @size(64) on member m.
+            64 - static_cast<uint32_t>(param.ColumnVectorAlign() * (param.columns - 1));
+        expected = param.ReplaceFieldsInString(
+            expected,
+            {{"${col_vectors}", param.ExpendedColumnVectorsWithLastSize(2, "m_", tail_size)},
+             {"${col_vectors_inline}", param.ExpendedColumnVectorsInline("val.m_", ", ")}});
+    }
+
+    auto output = Run<Std140>(input);
+
+    EXPECT_EQ(expected, str(output));
+}
+
+TEST_P(Std140Test_Matrix, ArrayStructMatUniform_CopyStruct_UniformToWorkgroup) {
+    auto param = GetParam();
+    // Input: one struct element of the uniform array is assigned to a workgroup array element.
+    std::string input = R"(
+enable f16;
+
+struct S {
+ v : vec4<i32>,
+ @size(64)
+ m : ${mat},
+}
+
+@group(0) @binding(0) var<uniform> u : array<S, 4>;
+
+var<workgroup> w : array<S, 4>;
+
+fn f() {
+ w[0] = u[1];
+}
+)";
+    input = param.ReplaceFieldsInString(input);
+
+    std::string expected;
+    if (!param.NotStd140Compatible()) {
+        expected = input;  // Nothing to rewrite: the module must come back unchanged.
+    } else {
+        expected = R"(
+enable f16;
+
+struct S {
+ v : vec4<i32>,
+ @size(64)
+ m : ${mat},
+}
+
+struct S_std140 {
+ v : vec4<i32>,
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> u : array<S_std140, 4u>;
+
+var<workgroup> w : array<S, 4>;
+
+fn conv_S(val : S_std140) -> S {
+ return S(val.v, ${mat}(${col_vectors_inline}));
+}
+
+fn f() {
+ w[0] = conv_S(u[1u]);
+}
+)";
+        uint32_t tail_size =  // 64 mirrors the @size(64) on member m.
+            64 - static_cast<uint32_t>(param.ColumnVectorAlign() * (param.columns - 1));
+        expected = param.ReplaceFieldsInString(
+            expected,
+            {{"${col_vectors}", param.ExpendedColumnVectorsWithLastSize(2, "m_", tail_size)},
+             {"${col_vectors_inline}", param.ExpendedColumnVectorsInline("val.m_", ", ")}});
+    }
+
+    auto output = Run<Std140>(input);
+
+    EXPECT_EQ(expected, str(output));
+}
+
+TEST_P(Std140Test_Matrix, ArrayStructMatUniform_CopyMatrix_UniformToPrivate) {
+    auto param = GetParam();
+    // Input: a matrix member of a uniform array element is assigned into a private array element.
+    std::string input = R"(
+enable f16;
+
+struct S {
+ v : vec4<i32>,
+ @size(64)
+ m : ${mat},
+}
+
+@group(0) @binding(0) var<uniform> u : array<S, 3>;
+
+var<private> p : array<S, 4>;
+
+fn f() {
+ p[2].m = u[1].m;
+}
+)";
+    input = param.ReplaceFieldsInString(input);
+
+    std::string expected;
+    if (!param.NotStd140Compatible()) {
+        expected = input;  // Nothing to rewrite: the module must come back unchanged.
+    } else {
+        expected = R"(
+enable f16;
+
+struct S {
+ v : vec4<i32>,
+ @size(64)
+ m : ${mat},
+}
+
+struct S_std140 {
+ v : vec4<i32>,
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> u : array<S_std140, 3u>;
+
+var<private> p : array<S, 4>;
+
+fn load_u_1_m() -> ${mat} {
+ let s = &(u[1u]);
+ return ${mat}(${col_vectors_inline});
+}
+
+fn f() {
+ p[2].m = load_u_1_m();
+}
+)";
+        uint32_t tail_size =  // 64 mirrors the @size(64) on member m.
+            64 - static_cast<uint32_t>(param.ColumnVectorAlign() * (param.columns - 1));
+        expected = param.ReplaceFieldsInString(
+            expected,
+            {{"${col_vectors}", param.ExpendedColumnVectorsWithLastSize(2, "m_", tail_size)},
+             {"${col_vectors_inline}", param.ExpendedColumnVectorsInline("(*(s)).m_", ", ")}});
+    }
+
+    auto output = Run<Std140>(input);
+
+    EXPECT_EQ(expected, str(output));
+}
+
+TEST_P(Std140Test_Matrix, ArrayStructMatUniform_CopyColumn_UniformToStorage) {
+    auto param = GetParam();
+    // Input: one matrix column from the uniform array is assigned to a storage matrix column.
+    std::string input = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : ${mat},
+}
+
+@group(0) @binding(0) var<uniform> u : array<S, 3>;
+
+@group(0) @binding(1) var<storage, read_write> s : array<S, 4>;
+
+fn f() {
+ s[3].m[1] = u[2].m[0];
+}
+)";
+    input = param.ReplaceFieldsInString(input);
+
+    std::string expected;
+    if (!param.NotStd140Compatible()) {
+        expected = input;  // Nothing to rewrite: the module must come back unchanged.
+    } else {
+        expected = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : ${mat},
+}
+
+struct S_std140 {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> u : array<S_std140, 3u>;
+
+@group(0) @binding(1) var<storage, read_write> s : array<S, 4>;
+
+fn f() {
+ s[3].m[1] = u[2u].m_0;
+}
+)";
+        uint32_t tail_size =  // 64 mirrors the @size(64) on member m.
+            64 - static_cast<uint32_t>(param.ColumnVectorAlign() * (param.columns - 1));
+        expected = param.ReplaceFieldsInString(
+            expected,
+            {{"${col_vectors}", param.ExpendedColumnVectorsWithLastSize(2, "m_", tail_size)}});
+    }
+
+    auto output = Run<Std140>(input);
+
+    EXPECT_EQ(expected, str(output));
+}
+
+TEST_P(Std140Test_Matrix, ArrayStructMatUniform_CopyColumnSwizzle_UniformToWorkgroup) {
+    auto param = GetParam();
+    // Input: a twice-swizzled uniform matrix column is assigned to a workgroup matrix column.
+    std::string input = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : ${mat},
+}
+
+@group(0) @binding(0) var<uniform> u : array<S, 4>;
+
+var<workgroup> w : array<S, 4>;
+
+fn f() {
+ w[3].m[1] = u[2].m[0].${swizzle}.${swizzle};
+}
+)";
+    input = param.ReplaceFieldsInString(input);
+
+    std::string expected;
+    if (!param.NotStd140Compatible()) {
+        expected = input;  // Nothing to rewrite: the module must come back unchanged.
+    } else {
+        expected = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : ${mat},
+}
+
+struct S_std140 {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> u : array<S_std140, 4u>;
+
+var<workgroup> w : array<S, 4>;
+
+fn f() {
+ w[3].m[1] = u[2u].m_0.${swizzle}.${swizzle};
+}
+)";
+        uint32_t tail_size =  // 64 mirrors the @size(64) on member m.
+            64 - static_cast<uint32_t>(param.ColumnVectorAlign() * (param.columns - 1));
+        expected = param.ReplaceFieldsInString(
+            expected,
+            {{"${col_vectors}", param.ExpendedColumnVectorsWithLastSize(2, "m_", tail_size)}});
+    }
+
+    auto output = Run<Std140>(input);
+
+    EXPECT_EQ(expected, str(output));
+}
+
+TEST_P(Std140Test_Matrix, ArrayStructMatUniform_CopyScalar_UniformToPrivate) {
+    auto matrix = GetParam();
+    // Input: a single matrix element from the uniform array is assigned to a private element.
+    std::string src = R"(
+enable f16;
+
+struct S {
+ v : vec4<i32>,
+ @size(64)
+ m : ${mat},
+}
+
+@group(0) @binding(0) var<uniform> u : array<S, 3>;
+
+var<private> p : array<S, 4>;
+
+fn f() {
+ p[3].m[1].x = u[2].m[0].y;
+}
+)";
+    src = matrix.ReplaceFieldsInString(src);
+
+    std::string expect;
+    if (matrix.NotStd140Compatible()) {
+        expect = R"(
+enable f16;
+
+struct S {
+ v : vec4<i32>,
+ @size(64)
+ m : ${mat},
+}
+
+struct S_std140 {
+ v : vec4<i32>,
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> u : array<S_std140, 3u>;
+
+var<private> p : array<S, 4>;
+
+fn f() {
+ p[3].m[1].x = u[2u].m_0[1u];
+}
+)";
+        // Only ${col_vectors} appears in the expectation above; 64 mirrors the @size(64) on m.
+        uint32_t last_size =
+            64 - static_cast<uint32_t>(matrix.ColumnVectorAlign() * (matrix.columns - 1));
+        expect = matrix.ReplaceFieldsInString(
+            expect,
+            {{"${col_vectors}", matrix.ExpendedColumnVectorsWithLastSize(2, "m_", last_size)}});
+    } else {
+        expect = src;  // Nothing to rewrite: the module must come back unchanged.
+    }
+
+    auto got = Run<Std140>(src);
+
+    EXPECT_EQ(expect, str(got));
+}
+
+INSTANTIATE_TEST_SUITE_P(,  // Empty first argument: no instantiation-name prefix.
+ Std140Test_Matrix,
+ ::testing::ValuesIn(std::vector<MatrixCase>{  // NOTE(review): entries look like {columns, rows, element type} - confirm against MatrixCase.
+ {2, 2, MatrixType::f32},  // f32: every pairing of 2..4 with 2..4.
+ {2, 3, MatrixType::f32},
+ {2, 4, MatrixType::f32},
+ {3, 2, MatrixType::f32},
+ {3, 3, MatrixType::f32},
+ {3, 4, MatrixType::f32},
+ {4, 2, MatrixType::f32},
+ {4, 3, MatrixType::f32},
+ {4, 4, MatrixType::f32},
+ {2, 2, MatrixType::f16},  // f16: the same nine shapes.
+ {2, 3, MatrixType::f16},
+ {2, 4, MatrixType::f16},
+ {3, 2, MatrixType::f16},
+ {3, 3, MatrixType::f16},
+ {3, 4, MatrixType::f16},
+ {4, 2, MatrixType::f16},
+ {4, 3, MatrixType::f16},
+ {4, 4, MatrixType::f16},
+ }));
+// Fixture alias for the TEST_P cases below that declare uniform arrays whose elements are matrices.
+using Std140Test_MatrixArray = TransformTestWithParam<MatrixCase>;
+
+TEST_P(Std140Test_MatrixArray, ArrayMatUniform_LoadArray) {
+    auto param = GetParam();
+    // Bail out early for parameter combinations the helper reports as unusable here.
+    if (!param.CanBeUsedAsUniformArrayElememts()) {
+        // This permutation is invalid, skip the test.
+        return;
+    }
+
+    std::string input = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : array<${mat}, 3>;
+
+fn f() {
+ let l = a;
+}
+)";
+    input = param.ReplaceFieldsInString(input);
+
+    std::string expected;
+    if (!param.NotStd140Compatible()) {
+        expected = input;  // Nothing to rewrite: the module must come back unchanged.
+    } else {
+        expected = R"(
+enable f16;
+
+struct mat${shape}_${elem_type} {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> a : array<mat${shape}_${elem_type}, 3u>;
+
+fn conv_mat${shape}_${elem_type}(val : mat${shape}_${elem_type}) -> ${mat} {
+ return ${mat}(${col_vectors_inline});
+}
+
+fn conv_arr3_mat${shape}_${elem_type}(val : array<mat${shape}_${elem_type}, 3u>) -> array<${mat}, 3u> {
+ var arr : array<${mat}, 3u>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ arr[i] = conv_mat${shape}_${elem_type}(val[i]);
+ }
+ return arr;
+}
+
+fn f() {
+ let l = conv_arr3_mat${shape}_${elem_type}(a);
+}
+)";
+        expected = param.ReplaceFieldsInString(
+            expected,
+            {{"${col_vectors}", param.ExpendedColumnVectors(2, "col")},
+             {"${col_vectors_inline}", param.ExpendedColumnVectorsInline("val.col", ", ")}});
+    }
+
+    auto output = Run<Std140>(input);
+
+    EXPECT_EQ(expected, str(output));
+}
+
+TEST_P(Std140Test_MatrixArray, ArrayMatUniform_LoadMatrix_ConstArrayIndex) {
+ auto matrix = GetParam();
+
+ if (!matrix.CanBeUsedAsUniformArrayElememts()) {
+ // This matrix type is not valid as a uniform array element, so skip the test.
+ return;
+ }
+ // WGSL template: loads one matrix from the uniform array using a constant index.
+ std::string tmpl_src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : array<${mat}, 3>;
+
+fn f() {
+ let l = a[${array_index}];
+}
+)";
+ tmpl_src = matrix.ReplaceFieldsInString(tmpl_src);
+ // Expected: struct-of-columns elements plus a conv_mat* call, or src unchanged if compatible.
+ std::string tmpl_expect;
+ if (matrix.NotStd140Compatible()) {
+ tmpl_expect = R"(
+enable f16;
+
+struct mat${shape}_${elem_type} {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> a : array<mat${shape}_${elem_type}, 3u>;
+
+fn conv_mat${shape}_${elem_type}(val : mat${shape}_${elem_type}) -> ${mat} {
+ return ${mat}(${col_vectors_inline});
+}
+
+fn f() {
+ let l = conv_mat${shape}_${elem_type}(a[${array_index}u]);
+}
+)";
+ tmpl_expect = matrix.ReplaceFieldsInString(
+ tmpl_expect,
+ {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "col")},
+ {"${col_vectors_inline}", matrix.ExpendedColumnVectorsInline("val.col", ", ")}});
+ } else {
+ tmpl_expect = tmpl_src;
+ }
+ // Instantiate both templates for each of the three array elements.
+ for (uint32_t array_index = 0; array_index < 3; array_index++) {
+ std::string src =
+ utils::ReplaceAll(tmpl_src, "${array_index}", std::to_string(array_index));
+ std::string expect =
+ utils::ReplaceAll(tmpl_expect, "${array_index}", std::to_string(array_index));
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got)) << "accessing array element " << array_index;
+ }
+}
+
+TEST_P(Std140Test_MatrixArray, ArrayMatUniform_LoadMatrix_VariableArrayIndex) {
+ auto matrix = GetParam();
+
+ if (!matrix.CanBeUsedAsUniformArrayElememts()) {
+ // This matrix type is not valid as a uniform array element, so skip the test.
+ return;
+ }
+ // WGSL under test: loads one matrix from the uniform array using the let-declared index I.
+ std::string src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : array<${mat}, 3>;
+
+fn f() {
+ let I = 1;
+ let l = a[I];
+}
+)";
+ src = matrix.ReplaceFieldsInString(src);
+ // Expected: struct-of-columns elements plus a conv_mat* call, or src unchanged if compatible.
+ std::string expect;
+ if (matrix.NotStd140Compatible()) {
+ expect = R"(
+enable f16;
+
+struct mat${shape}_${elem_type} {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> a : array<mat${shape}_${elem_type}, 3u>;
+
+fn conv_mat${shape}_${elem_type}(val : mat${shape}_${elem_type}) -> ${mat} {
+ return ${mat}(${col_vectors_inline});
+}
+
+fn f() {
+ let I = 1;
+ let l = conv_mat${shape}_${elem_type}(a[I]);
+}
+)";
+ expect = matrix.ReplaceFieldsInString(
+ expect,
+ {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "col")},
+ {"${col_vectors_inline}", matrix.ExpendedColumnVectorsInline("val.col", ", ")}});
+ } else {
+ expect = src;
+ }
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(Std140Test_MatrixArray, ArrayMatUniform_LoadColumn_ConstArrayIndex_ConstColumnIndex) {
+ auto matrix = GetParam();
+
+ if (!matrix.CanBeUsedAsUniformArrayElememts()) {
+ // This matrix type is not valid as a uniform array element, so skip the test.
+ return;
+ }
+ // WGSL template: loads one matrix column using constant array and column indices.
+ std::string tmpl_src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : array<${mat}, 3>;
+
+fn f() {
+ let l = a[${array_index}][${cloumn_index}];
+}
+)";
+ tmpl_src = matrix.ReplaceFieldsInString(tmpl_src);
+ // Expected: a direct col<N> member access on the struct-of-columns element, or src unchanged.
+ std::string tmpl_expect;
+ if (matrix.NotStd140Compatible()) {
+ tmpl_expect = R"(
+enable f16;
+
+struct mat${shape}_${elem_type} {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> a : array<mat${shape}_${elem_type}, 3u>;
+
+fn f() {
+ let l = a[${array_index}u].col${cloumn_index};
+}
+)";
+ tmpl_expect = matrix.ReplaceFieldsInString(
+ tmpl_expect, {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "col")}});
+ } else {
+ tmpl_expect = tmpl_src;
+ }
+ // Instantiate both templates for every array element and every matrix column.
+ for (uint32_t array_index = 0; array_index < 3; array_index++) {
+ for (uint32_t col = 0; col < matrix.columns; col++) {
+ std::string src =
+ utils::ReplaceAll(tmpl_src, "${array_index}", std::to_string(array_index));
+ src = utils::ReplaceAll(src, "${cloumn_index}", std::to_string(col));
+ std::string expect =
+ utils::ReplaceAll(tmpl_expect, "${array_index}", std::to_string(array_index));
+ expect = utils::ReplaceAll(expect, "${cloumn_index}", std::to_string(col));
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got))
+ << "accessing array element " << array_index << " col " << col;
+ }
+ }
+}
+
+TEST_P(Std140Test_MatrixArray, ArrayMatUniform_LoadColumn_VariableArrayIndex_ConstColumnIndex) {
+ auto matrix = GetParam();
+
+ if (!matrix.CanBeUsedAsUniformArrayElememts()) {
+ // This matrix type is not valid as a uniform array element, so skip the test.
+ return;
+ }
+ // WGSL template: loads one column using the let-declared array index I and a constant column.
+ std::string tmpl_src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : array<${mat}, 3>;
+
+fn f() {
+ let I = 1;
+ let l = a[I][${cloumn_index}];
+}
+)";
+ tmpl_src = matrix.ReplaceFieldsInString(tmpl_src);
+ // Expected: a direct col<N> member access on the struct-of-columns element, or src unchanged.
+ std::string tmpl_expect;
+ if (matrix.NotStd140Compatible()) {
+ tmpl_expect = R"(
+enable f16;
+
+struct mat${shape}_${elem_type} {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> a : array<mat${shape}_${elem_type}, 3u>;
+
+fn f() {
+ let I = 1;
+ let l = a[I].col${cloumn_index};
+}
+)";
+ tmpl_expect = matrix.ReplaceFieldsInString(
+ tmpl_expect, {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "col")}});
+ } else {
+ tmpl_expect = tmpl_src;
+ }
+ // Instantiate both templates for every matrix column.
+ for (uint32_t col = 0; col < matrix.columns; col++) {
+ std::string src = utils::ReplaceAll(tmpl_src, "${cloumn_index}", std::to_string(col));
+ std::string expect = utils::ReplaceAll(tmpl_expect, "${cloumn_index}", std::to_string(col));
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got)) << "accessing col " << col;
+ }
+}
+
+TEST_P(Std140Test_MatrixArray, ArrayMatUniform_LoadColumn_ConstArrayIndex_VariableColumnIndex) {
+ auto matrix = GetParam();
+
+ if (!matrix.CanBeUsedAsUniformArrayElememts()) {
+ // This matrix type is not valid as a uniform array element, so skip the test.
+ return;
+ }
+ // WGSL template: loads one column using a constant array index and let column index I.
+ std::string tmpl_src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : array<${mat}, 3>;
+
+fn f() {
+ let I = 1;
+ let l = a[${array_index}][I];
+}
+)";
+ tmpl_src = matrix.ReplaceFieldsInString(tmpl_src);
+ // Expected: a load_a_<N>_p0 helper that switches over the column, or src unchanged.
+ std::string tmpl_expect;
+ if (matrix.NotStd140Compatible()) {
+ tmpl_expect = R"(
+enable f16;
+
+struct mat${shape}_${elem_type} {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> a : array<mat${shape}_${elem_type}, 3u>;
+
+fn load_a_${array_index}_p0(p0 : u32) -> ${col_vector_type} {
+ switch(p0) {
+${col_table}
+ default: {
+ return ${col_vector_type}();
+ }
+ }
+}
+
+fn f() {
+ let I = 1;
+ let l = load_a_${array_index}_p0(u32(I));
+}
+)";
+ // col_table holds one switch case per matrix column.
+ // Example for a matrix with 2 columns:
+ // case 0u: {
+ // return a[${array_index}u].col0;
+ // }
+ // case 1u: {
+ // return a[${array_index}u].col1;
+ // }
+ std::string col_table = matrix.JoinTemplatedStringForEachMatrixColumn( //
+ R"( case ${col_id_for_tmpl}u: {
+ return a[${array_index}u].col${col_id_for_tmpl};
+ })",
+ "\n");
+ tmpl_expect = matrix.ReplaceFieldsInString(
+ tmpl_expect, {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "col")},
+ {"${col_table}", col_table}});
+ } else {
+ tmpl_expect = tmpl_src;
+ }
+ // Instantiate both templates for each of the three array elements.
+ for (uint32_t array_index = 0; array_index < 3; array_index++) {
+ std::string src =
+ utils::ReplaceAll(tmpl_src, "${array_index}", std::to_string(array_index));
+ std::string expect =
+ utils::ReplaceAll(tmpl_expect, "${array_index}", std::to_string(array_index));
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got)) << "accessing array element " << array_index;
+ }
+}
+
+TEST_P(Std140Test_MatrixArray, ArrayMatUniform_LoadColumn_VariableArrayIndex_VariableColumnIndex) {
+ auto matrix = GetParam();
+
+ if (!matrix.CanBeUsedAsUniformArrayElememts()) {
+ // This matrix type is not valid as a uniform array element, so skip the test.
+ return;
+ }
+ // WGSL under test: loads one column with both the array and the column indexed by let I.
+ std::string src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : array<${mat}, 3>;
+
+fn f() {
+ let I = 1;
+ let l = a[I][I];
+}
+)";
+ src = matrix.ReplaceFieldsInString(src);
+ // Expected: a load_a_p0_p1 helper that switches over the column, or src unchanged.
+ std::string expect;
+ if (matrix.NotStd140Compatible()) {
+ expect = R"(
+enable f16;
+
+struct mat${shape}_${elem_type} {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> a : array<mat${shape}_${elem_type}, 3u>;
+
+fn load_a_p0_p1(p0 : u32, p1 : u32) -> ${col_vector_type} {
+ switch(p1) {
+${col_table}
+ default: {
+ return ${col_vector_type}();
+ }
+ }
+}
+
+fn f() {
+ let I = 1;
+ let l = load_a_p0_p1(u32(I), u32(I));
+}
+)";
+ // col_table holds one switch case per matrix column.
+ // Example for a matrix with 2 columns:
+ // case 0u: {
+ // return a[p0].col0;
+ // }
+ // case 1u: {
+ // return a[p0].col1;
+ // }
+ std::string col_table = matrix.JoinTemplatedStringForEachMatrixColumn( //
+ R"( case ${col_id_for_tmpl}u: {
+ return a[p0].col${col_id_for_tmpl};
+ })",
+ "\n");
+ expect = matrix.ReplaceFieldsInString(
+ expect, {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "col")},
+ {"${col_table}", col_table}});
+ } else {
+ expect = src;
+ }
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(Std140Test_MatrixArray, StructArrayMatUniform_LoadStruct) {
+ auto matrix = GetParam();
+
+ if (!matrix.CanBeUsedAsUniformArrayElememts()) {
+ // This matrix type is not valid as a uniform array element, so skip the test.
+ return;
+ }
+ // WGSL under test: loads a whole uniform struct holding an array of three matrices.
+ std::string src = R"(
+enable f16;
+
+struct S {
+ a : array<${mat}, 3>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let l = s;
+}
+)";
+ src = matrix.ReplaceFieldsInString(src);
+ // Expected: an S_std140 struct plus conv_* helpers rebuilding S, or src unchanged.
+ std::string expect;
+ if (matrix.NotStd140Compatible()) {
+ expect = R"(
+enable f16;
+
+struct mat${shape}_${elem_type} {
+${col_vectors}
+}
+
+struct S {
+ a : array<${mat}, 3>,
+}
+
+struct S_std140 {
+ a : array<mat${shape}_${elem_type}, 3u>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn conv_mat${shape}_${elem_type}(val : mat${shape}_${elem_type}) -> ${mat} {
+ return ${mat}(${col_vectors_inline});
+}
+
+fn conv_arr3_mat${shape}_${elem_type}(val : array<mat${shape}_${elem_type}, 3u>) -> array<${mat}, 3u> {
+ var arr : array<${mat}, 3u>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ arr[i] = conv_mat${shape}_${elem_type}(val[i]);
+ }
+ return arr;
+}
+
+fn conv_S(val : S_std140) -> S {
+ return S(conv_arr3_mat${shape}_${elem_type}(val.a));
+}
+
+fn f() {
+ let l = conv_S(s);
+}
+)";
+ expect = matrix.ReplaceFieldsInString(
+ expect,
+ {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "col")},
+ {"${col_vectors_inline}", matrix.ExpendedColumnVectorsInline("val.col", ", ")}});
+ } else {
+ expect = src;
+ }
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(Std140Test_MatrixArray, StructArrayMatUniform_LoadArray) {
+ auto matrix = GetParam();
+
+ if (!matrix.CanBeUsedAsUniformArrayElememts()) {
+ // This matrix type is not valid as a uniform array element, so skip the test.
+ return;
+ }
+ // WGSL under test: loads the matrix-array member of a uniform struct.
+ std::string src = R"(
+enable f16;
+
+struct S {
+ a : array<${mat}, 3>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let l = s.a;
+}
+)";
+ src = matrix.ReplaceFieldsInString(src);
+ // Expected: an S_std140 struct plus conv_* helpers rebuilding the array, or src unchanged.
+ std::string expect;
+ if (matrix.NotStd140Compatible()) {
+ expect = R"(
+enable f16;
+
+struct mat${shape}_${elem_type} {
+${col_vectors}
+}
+
+struct S {
+ a : array<${mat}, 3>,
+}
+
+struct S_std140 {
+ a : array<mat${shape}_${elem_type}, 3u>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn conv_mat${shape}_${elem_type}(val : mat${shape}_${elem_type}) -> ${mat} {
+ return ${mat}(${col_vectors_inline});
+}
+
+fn conv_arr3_mat${shape}_${elem_type}(val : array<mat${shape}_${elem_type}, 3u>) -> array<${mat}, 3u> {
+ var arr : array<${mat}, 3u>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ arr[i] = conv_mat${shape}_${elem_type}(val[i]);
+ }
+ return arr;
+}
+
+fn f() {
+ let l = conv_arr3_mat${shape}_${elem_type}(s.a);
+}
+)";
+ expect = matrix.ReplaceFieldsInString(
+ expect,
+ {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "col")},
+ {"${col_vectors_inline}", matrix.ExpendedColumnVectorsInline("val.col", ", ")}});
+ } else {
+ expect = src;
+ }
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(Std140Test_MatrixArray, StructArrayMatUniform_LoadMatrix_ConstArrayIndex) {
+ auto matrix = GetParam();
+
+ if (!matrix.CanBeUsedAsUniformArrayElememts()) {
+ // This matrix type is not valid as a uniform array element, so skip the test.
+ return;
+ }
+ // WGSL template: loads one matrix from the struct's array member using a constant index.
+ std::string tmpl_src = R"(
+enable f16;
+
+struct S {
+ a : array<${mat}, 3>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let l = s.a[${array_index}];
+}
+)";
+ tmpl_src = matrix.ReplaceFieldsInString(tmpl_src);
+ // Expected: an S_std140 struct plus a conv_mat* call on the element, or src unchanged.
+ std::string tmpl_expect;
+ if (matrix.NotStd140Compatible()) {
+ tmpl_expect = R"(
+enable f16;
+
+struct mat${shape}_${elem_type} {
+${col_vectors}
+}
+
+struct S {
+ a : array<${mat}, 3>,
+}
+
+struct S_std140 {
+ a : array<mat${shape}_${elem_type}, 3u>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn conv_mat${shape}_${elem_type}(val : mat${shape}_${elem_type}) -> ${mat} {
+ return ${mat}(${col_vectors_inline});
+}
+
+fn f() {
+ let l = conv_mat${shape}_${elem_type}(s.a[${array_index}u]);
+}
+)";
+ tmpl_expect = matrix.ReplaceFieldsInString(
+ tmpl_expect,
+ {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "col")},
+ {"${col_vectors_inline}", matrix.ExpendedColumnVectorsInline("val.col", ", ")}});
+ } else {
+ tmpl_expect = tmpl_src;
+ }
+ // Instantiate both templates for each of the three array elements.
+ for (uint32_t array_index = 0; array_index < 3; array_index++) {
+ std::string src =
+ utils::ReplaceAll(tmpl_src, "${array_index}", std::to_string(array_index));
+ std::string expect =
+ utils::ReplaceAll(tmpl_expect, "${array_index}", std::to_string(array_index));
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got)) << "accessing array element " << array_index;
+ }
+}
+
+TEST_P(Std140Test_MatrixArray, StructArrayMatUniform_LoadMatrix_VariableArrayIndex) {
+ auto matrix = GetParam();
+
+ if (!matrix.CanBeUsedAsUniformArrayElememts()) {
+ // This matrix type is not valid as a uniform array element, so skip the test.
+ return;
+ }
+ // WGSL under test: loads one matrix from the struct's array member using let index I.
+ std::string src = R"(
+enable f16;
+
+struct S {
+ a : array<${mat}, 3>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let I = 1;
+ let l = s.a[I];
+}
+)";
+ src = matrix.ReplaceFieldsInString(src);
+ // Expected: an S_std140 struct plus a conv_mat* call on the element, or src unchanged.
+ std::string expect;
+ if (matrix.NotStd140Compatible()) {
+ expect = R"(
+enable f16;
+
+struct mat${shape}_${elem_type} {
+${col_vectors}
+}
+
+struct S {
+ a : array<${mat}, 3>,
+}
+
+struct S_std140 {
+ a : array<mat${shape}_${elem_type}, 3u>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn conv_mat${shape}_${elem_type}(val : mat${shape}_${elem_type}) -> ${mat} {
+ return ${mat}(${col_vectors_inline});
+}
+
+fn f() {
+ let I = 1;
+ let l = conv_mat${shape}_${elem_type}(s.a[I]);
+}
+)";
+ expect = matrix.ReplaceFieldsInString(
+ expect,
+ {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "col")},
+ {"${col_vectors_inline}", matrix.ExpendedColumnVectorsInline("val.col", ", ")}});
+ } else {
+ expect = src;
+ }
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(Std140Test_MatrixArray, StructArrayMatUniform_LoadColumn_ConstArrayIndex_ConstColumnIndex) {
+ auto matrix = GetParam();
+
+ if (!matrix.CanBeUsedAsUniformArrayElememts()) {
+ // This matrix type is not valid as a uniform array element, so skip the test.
+ return;
+ }
+ // WGSL template: loads one column of a struct's array member using constant indices.
+ std::string tmpl_src = R"(
+enable f16;
+
+struct S {
+ a : array<${mat}, 3>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let l = s.a[${array_index}][${cloumn_index}];
+}
+)";
+ tmpl_src = matrix.ReplaceFieldsInString(tmpl_src);
+ // Expected: an S_std140 struct and a direct col<N> member access, or src unchanged.
+ std::string tmpl_expect;
+ if (matrix.NotStd140Compatible()) {
+ tmpl_expect = R"(
+enable f16;
+
+struct mat${shape}_${elem_type} {
+${col_vectors}
+}
+
+struct S {
+ a : array<${mat}, 3>,
+}
+
+struct S_std140 {
+ a : array<mat${shape}_${elem_type}, 3u>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn f() {
+ let l = s.a[${array_index}u].col${cloumn_index};
+}
+)";
+ tmpl_expect = matrix.ReplaceFieldsInString(
+ tmpl_expect, {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "col")}});
+ } else {
+ tmpl_expect = tmpl_src;
+ }
+ // Instantiate both templates for every array element and every matrix column.
+ for (uint32_t array_index = 0; array_index < 3; array_index++) {
+ for (uint32_t col = 0; col < matrix.columns; col++) {
+ std::string src =
+ utils::ReplaceAll(tmpl_src, "${array_index}", std::to_string(array_index));
+ src = utils::ReplaceAll(src, "${cloumn_index}", std::to_string(col));
+ std::string expect =
+ utils::ReplaceAll(tmpl_expect, "${array_index}", std::to_string(array_index));
+ expect = utils::ReplaceAll(expect, "${cloumn_index}", std::to_string(col));
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got))
+ << "accessing array element " << array_index << " col " << col;
+ }
+ }
+}
+
+TEST_P(Std140Test_MatrixArray,
+ StructArrayMatUniform_LoadColumn_VariableArrayIndex_ConstColumnIndex) {
+ auto matrix = GetParam();
+
+ if (!matrix.CanBeUsedAsUniformArrayElememts()) {
+ // This matrix type is not valid as a uniform array element, so skip the test.
+ return;
+ }
+ // WGSL template: loads one column using let array index I and a constant column index.
+ std::string tmpl_src = R"(
+enable f16;
+
+struct S {
+ a : array<${mat}, 3>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let I = 1;
+ let l = s.a[I][${cloumn_index}];
+}
+)";
+ tmpl_src = matrix.ReplaceFieldsInString(tmpl_src);
+ // Expected: an S_std140 struct and a direct col<N> member access, or src unchanged.
+ std::string tmpl_expect;
+ if (matrix.NotStd140Compatible()) {
+ tmpl_expect = R"(
+enable f16;
+
+struct mat${shape}_${elem_type} {
+${col_vectors}
+}
+
+struct S {
+ a : array<${mat}, 3>,
+}
+
+struct S_std140 {
+ a : array<mat${shape}_${elem_type}, 3u>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn f() {
+ let I = 1;
+ let l = s.a[I].col${cloumn_index};
+}
+)";
+ tmpl_expect = matrix.ReplaceFieldsInString(
+ tmpl_expect, {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "col")}});
+ } else {
+ tmpl_expect = tmpl_src;
+ }
+ // Instantiate both templates for every matrix column.
+ for (uint32_t col = 0; col < matrix.columns; col++) {
+ std::string src = utils::ReplaceAll(tmpl_src, "${cloumn_index}", std::to_string(col));
+ std::string expect = utils::ReplaceAll(tmpl_expect, "${cloumn_index}", std::to_string(col));
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got)) << "accessing col " << col;
+ }
+}
+
+TEST_P(Std140Test_MatrixArray,
+ StructArrayMatUniform_LoadColumn_ConstArrayIndex_VariableColumnIndex) {
+ auto matrix = GetParam();
+
+ if (!matrix.CanBeUsedAsUniformArrayElememts()) {
+ // This matrix type is not valid as a uniform array element, so skip the test.
+ return;
+ }
+ // WGSL template: loads one column using a constant array index and let column index I.
+ std::string tmpl_src = R"(
+enable f16;
+
+struct S {
+ a : array<${mat}, 3>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let I = 1;
+ let l = s.a[${array_index}][I];
+}
+)";
+ tmpl_src = matrix.ReplaceFieldsInString(tmpl_src);
+ // Expected: S_std140 plus a load_s_a_<N>_p0 switch helper, or src unchanged.
+ std::string tmpl_expect;
+ if (matrix.NotStd140Compatible()) {
+ tmpl_expect = R"(
+enable f16;
+
+struct mat${shape}_${elem_type} {
+${col_vectors}
+}
+
+struct S {
+ a : array<${mat}, 3>,
+}
+
+struct S_std140 {
+ a : array<mat${shape}_${elem_type}, 3u>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn load_s_a_${array_index}_p0(p0 : u32) -> ${col_vector_type} {
+ switch(p0) {
+${col_table}
+ default: {
+ return ${col_vector_type}();
+ }
+ }
+}
+
+fn f() {
+ let I = 1;
+ let l = load_s_a_${array_index}_p0(u32(I));
+}
+)";
+ // col_table holds one switch case per matrix column.
+ // Example for a matrix with 2 columns:
+ // case 0u: {
+ // return s.a[${array_index}u].col0;
+ // }
+ // case 1u: {
+ // return s.a[${array_index}u].col1;
+ // }
+ std::string col_table = matrix.JoinTemplatedStringForEachMatrixColumn( //
+ R"( case ${col_id_for_tmpl}u: {
+ return s.a[${array_index}u].col${col_id_for_tmpl};
+ })",
+ "\n");
+ tmpl_expect = matrix.ReplaceFieldsInString(
+ tmpl_expect, {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "col")},
+ {"${col_table}", col_table}});
+ } else {
+ tmpl_expect = tmpl_src;
+ }
+ // Instantiate both templates for each of the three array elements.
+ for (uint32_t array_index = 0; array_index < 3; array_index++) {
+ std::string src =
+ utils::ReplaceAll(tmpl_src, "${array_index}", std::to_string(array_index));
+ std::string expect =
+ utils::ReplaceAll(tmpl_expect, "${array_index}", std::to_string(array_index));
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got)) << "accessing array element " << array_index;
+ }
+}
+
+TEST_P(Std140Test_MatrixArray,
+ StructArrayMatUniform_LoadColumn_VariableArrayIndex_VariableColumnIndex) {
+ auto matrix = GetParam();
+
+ if (!matrix.CanBeUsedAsUniformArrayElememts()) {
+ // This matrix type is not valid as a uniform array element, so skip the test.
+ return;
+ }
+ // WGSL under test: loads one column with both array and column indexed by let I.
+ std::string src = R"(
+enable f16;
+
+struct S {
+ a : array<${mat}, 3>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let I = 1;
+ let l = s.a[I][I];
+}
+)";
+ src = matrix.ReplaceFieldsInString(src);
+ // Expected: S_std140 plus a load_s_a_p0_p1 switch helper, or src unchanged.
+ std::string expect;
+ if (matrix.NotStd140Compatible()) {
+ expect = R"(
+enable f16;
+
+struct mat${shape}_${elem_type} {
+${col_vectors}
+}
+
+struct S {
+ a : array<${mat}, 3>,
+}
+
+struct S_std140 {
+ a : array<mat${shape}_${elem_type}, 3u>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn load_s_a_p0_p1(p0 : u32, p1 : u32) -> ${col_vector_type} {
+ switch(p1) {
+${col_table}
+ default: {
+ return ${col_vector_type}();
+ }
+ }
+}
+
+fn f() {
+ let I = 1;
+ let l = load_s_a_p0_p1(u32(I), u32(I));
+}
+)";
+ // col_table holds one switch case per matrix column.
+ // Example for a matrix with 2 columns:
+ // case 0u: {
+ // return s.a[p0].col0;
+ // }
+ // case 1u: {
+ // return s.a[p0].col1;
+ // }
+ std::string col_table = matrix.JoinTemplatedStringForEachMatrixColumn( //
+ R"( case ${col_id_for_tmpl}u: {
+ return s.a[p0].col${col_id_for_tmpl};
+ })",
+ "\n");
+ expect = matrix.ReplaceFieldsInString(
+ expect, {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "col")},
+ {"${col_table}", col_table}});
+ } else {
+ expect = src;
+ }
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(Std140Test_MatrixArray, ArrayArrayMatUniform_LoadArrays) {
+ auto matrix = GetParam();
+
+ if (!matrix.CanBeUsedAsUniformArrayElememts()) {
+ // This matrix type is not valid as a uniform array element, so skip the test.
+ return;
+ }
+ // WGSL under test: loads a whole uniform 4-element array of 3-element matrix arrays.
+ std::string src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : array<array<${mat}, 3>, 4>;
+
+fn f() {
+ let l = a;
+}
+)";
+ src = matrix.ReplaceFieldsInString(src);
+ // Expected: nested conv_arr4_arr3_* / conv_arr3_* / conv_mat* helpers, or src unchanged.
+ std::string expect;
+ if (matrix.NotStd140Compatible()) {
+ expect = R"(
+enable f16;
+
+struct mat${shape}_${elem_type} {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> a : array<array<mat${shape}_${elem_type}, 3u>, 4u>;
+
+fn conv_mat${shape}_${elem_type}(val : mat${shape}_${elem_type}) -> ${mat} {
+ return ${mat}(${col_vectors_inline});
+}
+
+fn conv_arr3_mat${shape}_${elem_type}(val : array<mat${shape}_${elem_type}, 3u>) -> array<${mat}, 3u> {
+ var arr : array<${mat}, 3u>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ arr[i] = conv_mat${shape}_${elem_type}(val[i]);
+ }
+ return arr;
+}
+
+fn conv_arr4_arr3_mat${shape}_${elem_type}(val : array<array<mat${shape}_${elem_type}, 3u>, 4u>) -> array<array<${mat}, 3u>, 4u> {
+ var arr : array<array<${mat}, 3u>, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ arr[i] = conv_arr3_mat${shape}_${elem_type}(val[i]);
+ }
+ return arr;
+}
+
+fn f() {
+ let l = conv_arr4_arr3_mat${shape}_${elem_type}(a);
+}
+)";
+ expect = matrix.ReplaceFieldsInString(
+ expect,
+ {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "col")},
+ {"${col_vectors_inline}", matrix.ExpendedColumnVectorsInline("val.col", ", ")}});
+ } else {
+ expect = src;
+ }
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(Std140Test_MatrixArray, ArrayArrayMatUniform_LoadArray_ConstOuterArrayIndex) {
+ auto matrix = GetParam();
+
+ if (!matrix.CanBeUsedAsUniformArrayElememts()) {
+ // This matrix type is not valid as a uniform array element, so skip the test.
+ return;
+ }
+ // WGSL template: loads one inner array using a constant outer index.
+ std::string tmpl_src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : array<array<${mat}, 3>, 4>;
+
+fn f() {
+ let l = a[${outer_array_index}];
+}
+)";
+ tmpl_src = matrix.ReplaceFieldsInString(tmpl_src);
+ // Expected: a conv_arr3_* call on the indexed inner array, or src unchanged.
+ std::string tmpl_expect;
+ if (matrix.NotStd140Compatible()) {
+ tmpl_expect = R"(
+enable f16;
+
+struct mat${shape}_${elem_type} {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> a : array<array<mat${shape}_${elem_type}, 3u>, 4u>;
+
+fn conv_mat${shape}_${elem_type}(val : mat${shape}_${elem_type}) -> ${mat} {
+ return ${mat}(${col_vectors_inline});
+}
+
+fn conv_arr3_mat${shape}_${elem_type}(val : array<mat${shape}_${elem_type}, 3u>) -> array<${mat}, 3u> {
+ var arr : array<${mat}, 3u>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ arr[i] = conv_mat${shape}_${elem_type}(val[i]);
+ }
+ return arr;
+}
+
+fn f() {
+ let l = conv_arr3_mat${shape}_${elem_type}(a[${outer_array_index}u]);
+}
+)";
+ tmpl_expect = matrix.ReplaceFieldsInString(
+ tmpl_expect,
+ {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "col")},
+ {"${col_vectors_inline}", matrix.ExpendedColumnVectorsInline("val.col", ", ")}});
+ } else {
+ tmpl_expect = tmpl_src;
+ }
+ // Instantiate both templates for each of the four outer array elements.
+ for (uint32_t outer = 0; outer < 4; outer++) {
+ std::string src =
+ utils::ReplaceAll(tmpl_src, "${outer_array_index}", std::to_string(outer));
+ std::string expect =
+ utils::ReplaceAll(tmpl_expect, "${outer_array_index}", std::to_string(outer));
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got)) << "accessing array element " << outer;
+ }
+}
+
+TEST_P(Std140Test_MatrixArray, ArrayArrayMatUniform_LoadArray_VariableOuterArrayIndex) {
+ auto matrix = GetParam();
+
+ if (!matrix.CanBeUsedAsUniformArrayElememts()) {
+ // This matrix type is not valid as a uniform array element, so skip the test.
+ return;
+ }
+ // WGSL under test: loads one inner array using the let-declared outer index I.
+ std::string src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : array<array<${mat}, 3>, 4>;
+
+fn f() {
+ let I = 1;
+ let l = a[I];
+}
+)";
+ src = matrix.ReplaceFieldsInString(src);
+ // Expected: a conv_arr3_* call on the indexed inner array, or src unchanged.
+ std::string expect;
+ if (matrix.NotStd140Compatible()) {
+ expect = R"(
+enable f16;
+
+struct mat${shape}_${elem_type} {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> a : array<array<mat${shape}_${elem_type}, 3u>, 4u>;
+
+fn conv_mat${shape}_${elem_type}(val : mat${shape}_${elem_type}) -> ${mat} {
+ return ${mat}(${col_vectors_inline});
+}
+
+fn conv_arr3_mat${shape}_${elem_type}(val : array<mat${shape}_${elem_type}, 3u>) -> array<${mat}, 3u> {
+ var arr : array<${mat}, 3u>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ arr[i] = conv_mat${shape}_${elem_type}(val[i]);
+ }
+ return arr;
+}
+
+fn f() {
+ let I = 1;
+ let l = conv_arr3_mat${shape}_${elem_type}(a[I]);
+}
+)";
+ expect = matrix.ReplaceFieldsInString(
+ expect,
+ {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "col")},
+ {"${col_vectors_inline}", matrix.ExpendedColumnVectorsInline("val.col", ", ")}});
+ } else {
+ expect = src;
+ }
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(Std140Test_MatrixArray,
+ ArrayArrayMatUniform_LoadMatrix_ConstOuterArrayIndex_ConstInnerArrayIndex) {
+ auto matrix = GetParam();
+
+ if (!matrix.CanBeUsedAsUniformArrayElememts()) {
+ // This matrix type is not valid as a uniform array element, so skip the test.
+ return;
+ }
+ // WGSL template: loads one matrix using constant outer and inner array indices.
+ std::string tmpl_src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : array<array<${mat}, 3>, 4>;
+
+fn f() {
+ let l = a[${outer_array_index}][${inner_array_index}];
+}
+)";
+ tmpl_src = matrix.ReplaceFieldsInString(tmpl_src);
+ // Expected: a conv_mat* call on the doubly-indexed element, or src unchanged.
+ std::string tmpl_expect;
+ if (matrix.NotStd140Compatible()) {
+ tmpl_expect = R"(
+enable f16;
+
+struct mat${shape}_${elem_type} {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> a : array<array<mat${shape}_${elem_type}, 3u>, 4u>;
+
+fn conv_mat${shape}_${elem_type}(val : mat${shape}_${elem_type}) -> ${mat} {
+ return ${mat}(${col_vectors_inline});
+}
+
+fn f() {
+ let l = conv_mat${shape}_${elem_type}(a[${outer_array_index}u][${inner_array_index}u]);
+}
+)";
+ tmpl_expect = matrix.ReplaceFieldsInString(
+ tmpl_expect,
+ {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "col")},
+ {"${col_vectors_inline}", matrix.ExpendedColumnVectorsInline("val.col", ", ")}});
+ } else {
+ tmpl_expect = tmpl_src;
+ }
+ // Instantiate both templates for every outer/inner index combination.
+ for (uint32_t outer = 0; outer < 4; outer++) {
+ for (uint32_t inner = 0; inner < 3; inner++) {
+ std::string src =
+ utils::ReplaceAll(tmpl_src, "${outer_array_index}", std::to_string(outer));
+ src = utils::ReplaceAll(src, "${inner_array_index}", std::to_string(inner));
+ std::string expect =
+ utils::ReplaceAll(tmpl_expect, "${outer_array_index}", std::to_string(outer));
+ expect = utils::ReplaceAll(expect, "${inner_array_index}", std::to_string(inner));
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got))
+ << "accessing array element [" << outer << "][" << inner << "]";
+ }
+ }
+}
+
+TEST_P(Std140Test_MatrixArray,
+ ArrayArrayMatUniform_LoadMatrix_ConstOuterArrayIndex_VariableInnerArrayIndex) {
+ auto matrix = GetParam();
+
+ if (!matrix.CanBeUsedAsUniformArrayElememts()) {
+ // This matrix type is not valid as a uniform array element, so skip the test.
+ return;
+ }
+ // WGSL template: loads one matrix using a constant outer index and let inner index I.
+ std::string tmpl_src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : array<array<${mat}, 3>, 4>;
+
+fn f() {
+ let I = 1;
+ let l = a[${outer_array_index}][I];
+}
+)";
+ tmpl_src = matrix.ReplaceFieldsInString(tmpl_src);
+ // Expected: a conv_mat* call on the doubly-indexed element, or src unchanged.
+ std::string tmpl_expect;
+ if (matrix.NotStd140Compatible()) {
+ tmpl_expect = R"(
+enable f16;
+
+struct mat${shape}_${elem_type} {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> a : array<array<mat${shape}_${elem_type}, 3u>, 4u>;
+
+fn conv_mat${shape}_${elem_type}(val : mat${shape}_${elem_type}) -> ${mat} {
+ return ${mat}(${col_vectors_inline});
+}
+
+fn f() {
+ let I = 1;
+ let l = conv_mat${shape}_${elem_type}(a[${outer_array_index}u][I]);
+}
+)";
+ tmpl_expect = matrix.ReplaceFieldsInString(
+ tmpl_expect,
+ {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "col")},
+ {"${col_vectors_inline}", matrix.ExpendedColumnVectorsInline("val.col", ", ")}});
+ } else {
+ tmpl_expect = tmpl_src;
+ }
+ // Instantiate both templates for each of the four outer array elements.
+ for (uint32_t outer = 0; outer < 4; outer++) {
+ std::string src =
+ utils::ReplaceAll(tmpl_src, "${outer_array_index}", std::to_string(outer));
+ std::string expect =
+ utils::ReplaceAll(tmpl_expect, "${outer_array_index}", std::to_string(outer));
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got)) << "accessing array element [" << outer << "][I]";
+ }
+}
+
+TEST_P(Std140Test_MatrixArray,
+ ArrayArrayMatUniform_LoadMatrix_VariableOuterArrayIndex_ConstInnerArrayIndex) {
+ auto matrix = GetParam();
+
+ if (!matrix.CanBeUsedAsUniformArrayElememts()) {
+ // This matrix type is not valid as a uniform array element, so skip the test.
+ return;
+ }
+ // WGSL template: loads one matrix using let outer index I and a constant inner index.
+ std::string tmpl_src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : array<array<${mat}, 3>, 4>;
+
+fn f() {
+ let I = 1;
+ let l = a[I][${inner_array_index}];
+}
+)";
+ tmpl_src = matrix.ReplaceFieldsInString(tmpl_src);
+ // Expected: a conv_mat* call on the doubly-indexed element, or src unchanged.
+ std::string tmpl_expect;
+ if (matrix.NotStd140Compatible()) {
+ tmpl_expect = R"(
+enable f16;
+
+struct mat${shape}_${elem_type} {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> a : array<array<mat${shape}_${elem_type}, 3u>, 4u>;
+
+fn conv_mat${shape}_${elem_type}(val : mat${shape}_${elem_type}) -> ${mat} {
+ return ${mat}(${col_vectors_inline});
+}
+
+fn f() {
+ let I = 1;
+ let l = conv_mat${shape}_${elem_type}(a[I][${inner_array_index}u]);
+}
+)";
+ tmpl_expect = matrix.ReplaceFieldsInString(
+ tmpl_expect,
+ {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "col")},
+ {"${col_vectors_inline}", matrix.ExpendedColumnVectorsInline("val.col", ", ")}});
+ } else {
+ tmpl_expect = tmpl_src;
+ }
+ // Instantiate both templates for each of the three inner array elements.
+ for (uint32_t inner = 0; inner < 3; inner++) {
+ std::string src =
+ utils::ReplaceAll(tmpl_src, "${inner_array_index}", std::to_string(inner));
+ std::string expect =
+ utils::ReplaceAll(tmpl_expect, "${inner_array_index}", std::to_string(inner));
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got)) << "accessing array element [I][" << inner << "]";
+ }
+}
+
+TEST_P(Std140Test_MatrixArray,
+ ArrayArrayMatUniform_LoadMatrix_VariableOuterArrayIndex_VariableInnerArrayIndex) {
+ auto matrix = GetParam();
+
+ if (!matrix.CanBeUsedAsUniformArrayElememts()) {
+ // This permutation is invalid, skip the test.
+ return;
+ }
+
+ std::string src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : array<array<${mat}, 3>, 4>;
+
+fn f() {
+ let I = 1;
+ let l = a[I][I];
+}
+)";
+ src = matrix.ReplaceFieldsInString(src);
+
+ std::string expect;
+ if (matrix.NotStd140Compatible()) {
+ expect = R"(
+enable f16;
+
+struct mat${shape}_${elem_type} {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> a : array<array<mat${shape}_${elem_type}, 3u>, 4u>;
+
+fn conv_mat${shape}_${elem_type}(val : mat${shape}_${elem_type}) -> ${mat} {
+ return ${mat}(${col_vectors_inline});
+}
+
+fn f() {
+ let I = 1;
+ let l = conv_mat${shape}_${elem_type}(a[I][I]);
+}
+)";
+ expect = matrix.ReplaceFieldsInString(
+ expect,
+ {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "col")},
+ {"${col_vectors_inline}", matrix.ExpendedColumnVectorsInline("val.col", ", ")}});
+ } else {
+ expect = src;
+ }
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(Std140Test_MatrixArray,
+ ArrayArrayMatUniform_LoadColumn_ConstOuterArrayIndex_ConstInnerArrayIndex_ConstColumnIndex) {
+ auto matrix = GetParam();
+
+ if (!matrix.CanBeUsedAsUniformArrayElememts()) {
+ // The matrix cannot be a uniform array element, so this permutation is invalid: skip it.
+ return;
+ }
+
+ std::string tmpl_src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : array<array<${mat}, 3>, 4>;
+
+fn f() {
+ let l = a[${outer_array_index}][${inner_array_index}][${column_index}];
+}
+)";
+ tmpl_src = matrix.ReplaceFieldsInString(tmpl_src);
+
+ std::string tmpl_expect;
+ if (matrix.NotStd140Compatible()) {
+ tmpl_expect = R"(
+enable f16;
+
+struct mat${shape}_${elem_type} {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> a : array<array<mat${shape}_${elem_type}, 3u>, 4u>;
+
+fn f() {
+ let l = a[${outer_array_index}u][${inner_array_index}u].col${column_index};
+}
+)";
+ tmpl_expect = matrix.ReplaceFieldsInString(
+ tmpl_expect,
+ // The expected output has no conv_* helper, so only ${col_vectors} needs substituting.
+ {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "col")}});
+ } else {
+ tmpl_expect = tmpl_src;
+ }
+
+ for (uint32_t outer = 0; outer < 4; outer++) {
+ for (uint32_t inner = 0; inner < 3; inner++) {
+ for (uint32_t col = 0; col < matrix.columns; col++) {
+ std::string src =
+ utils::ReplaceAll(tmpl_src, "${outer_array_index}", std::to_string(outer));
+ src = utils::ReplaceAll(src, "${inner_array_index}", std::to_string(inner));
+ src = utils::ReplaceAll(src, "${column_index}", std::to_string(col));
+ std::string expect =
+ utils::ReplaceAll(tmpl_expect, "${outer_array_index}", std::to_string(outer));
+ expect = utils::ReplaceAll(expect, "${inner_array_index}", std::to_string(inner));
+ expect = utils::ReplaceAll(expect, "${column_index}", std::to_string(col));
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got))
+ << "accessing array element [" << outer << "][" << inner << "] col " << col;
+ }
+ }
+ }
+}
+
+TEST_P(
+ Std140Test_MatrixArray,
+ ArrayArrayMatUniform_LoadColumn_ConstOuterArrayIndex_ConstInnerArrayIndex_VariableColumnIndex) {
+ auto matrix = GetParam();
+
+ if (!matrix.CanBeUsedAsUniformArrayElememts()) {
+ // This permutation is invalid, skip the test.
+ return;
+ }
+
+ std::string tmpl_src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : array<array<${mat}, 3>, 4>;
+
+fn f() {
+ let I = 1;
+ let l = a[${outer_array_index}][${inner_array_index}][I];
+}
+)";
+ tmpl_src = matrix.ReplaceFieldsInString(tmpl_src);
+
+ std::string tmpl_expect;
+ if (matrix.NotStd140Compatible()) {
+ tmpl_expect = R"(
+enable f16;
+
+struct mat${shape}_${elem_type} {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> a : array<array<mat${shape}_${elem_type}, 3u>, 4u>;
+
+fn load_a_${outer_array_index}_${inner_array_index}_p0(p0 : u32) -> ${col_vector_type} {
+ switch(p0) {
+${col_table}
+ default: {
+ return ${col_vector_type}();
+ }
+ }
+}
+
+fn f() {
+ let I = 1;
+ let l = load_a_${outer_array_index}_${inner_array_index}_p0(u32(I));
+}
+)";
+ // col_table holds the switch cases, one for each column index.
+ // Example for a matrix with 2 columns:
+ // case 0u: {
+ // return a[${outer_array_index}u][${inner_array_index}u].col0;
+ // }
+ // case 1u: {
+ // return a[${outer_array_index}u][${inner_array_index}u].col1;
+ // }
+ std::string col_table = matrix.JoinTemplatedStringForEachMatrixColumn( //
+ R"( case ${col_id_for_tmpl}u: {
+ return a[${outer_array_index}u][${inner_array_index}u].col${col_id_for_tmpl};
+ })",
+ "\n");
+ tmpl_expect = matrix.ReplaceFieldsInString(
+ tmpl_expect, {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "col")},
+ {"${col_table}", col_table}});
+ } else {
+ tmpl_expect = tmpl_src;
+ }
+
+ for (uint32_t outer = 0; outer < 4; outer++) {
+ for (uint32_t inner = 0; inner < 3; inner++) {
+ std::string src =
+ utils::ReplaceAll(tmpl_src, "${outer_array_index}", std::to_string(outer));
+ src = utils::ReplaceAll(src, "${inner_array_index}", std::to_string(inner));
+ std::string expect =
+ utils::ReplaceAll(tmpl_expect, "${outer_array_index}", std::to_string(outer));
+ expect = utils::ReplaceAll(expect, "${inner_array_index}", std::to_string(inner));
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got))
+ << "accessing array element [" << outer << "][" << inner << "]";
+ }
+ }
+}
+
+TEST_P(
+ Std140Test_MatrixArray,
+ ArrayArrayMatUniform_LoadColumn_ConstOuterArrayIndex_VariableInnerArrayIndex_ConstColumnIndex) {
+ auto matrix = GetParam();
+
+ if (!matrix.CanBeUsedAsUniformArrayElememts()) {
+ // This permutation is invalid, skip the test.
+ return;
+ }
+
+ std::string tmpl_src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : array<array<${mat}, 3>, 4>;
+
+fn f() {
+ let I = 1;
+ let l = a[${outer_array_index}][I][${column_index}];
+}
+)";
+ tmpl_src = matrix.ReplaceFieldsInString(tmpl_src);
+
+ std::string tmpl_expect;
+ if (matrix.NotStd140Compatible()) {
+ tmpl_expect = R"(
+enable f16;
+
+struct mat${shape}_${elem_type} {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> a : array<array<mat${shape}_${elem_type}, 3u>, 4u>;
+
+fn f() {
+ let I = 1;
+ let l = a[${outer_array_index}u][I].col${column_index};
+}
+)";
+ tmpl_expect = matrix.ReplaceFieldsInString(
+ tmpl_expect, {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "col")}});
+ } else {
+ tmpl_expect = tmpl_src;
+ }
+
+ for (uint32_t outer = 0; outer < 4; outer++) {
+ for (uint32_t col = 0; col < matrix.columns; col++) {
+ std::string src =
+ utils::ReplaceAll(tmpl_src, "${outer_array_index}", std::to_string(outer));
+ src = utils::ReplaceAll(src, "${column_index}", std::to_string(col));
+ std::string expect =
+ utils::ReplaceAll(tmpl_expect, "${outer_array_index}", std::to_string(outer));
+ expect = utils::ReplaceAll(expect, "${column_index}", std::to_string(col));
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got))
+ << "accessing array element [" << outer << "][I] col " << col;
+ }
+ }
+}
+
+TEST_P(
+ Std140Test_MatrixArray,
+ ArrayArrayMatUniform_LoadColumn_ConstOuterArrayIndex_VariableInnerArrayIndex_VariableColumnIndex) {
+ auto matrix = GetParam();
+
+ if (!matrix.CanBeUsedAsUniformArrayElememts()) {
+ // This permutation is invalid, skip the test.
+ return;
+ }
+
+ std::string tmpl_src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : array<array<${mat}, 3>, 4>;
+
+fn f() {
+ let I = 1;
+ let J = 2;
+ let l = a[${outer_array_index}][I][J];
+}
+)";
+ tmpl_src = matrix.ReplaceFieldsInString(tmpl_src);
+
+ std::string tmpl_expect;
+ if (matrix.NotStd140Compatible()) {
+ tmpl_expect = R"(
+enable f16;
+
+struct mat${shape}_${elem_type} {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> a : array<array<mat${shape}_${elem_type}, 3u>, 4u>;
+
+fn load_a_${outer_array_index}_p0_p1(p0 : u32, p1 : u32) -> ${col_vector_type} {
+ switch(p1) {
+${col_table}
+ default: {
+ return ${col_vector_type}();
+ }
+ }
+}
+
+fn f() {
+ let I = 1;
+ let J = 2;
+ let l = load_a_${outer_array_index}_p0_p1(u32(I), u32(J));
+}
+)";
+ // col_table holds the switch cases, one for each column index.
+ // Example for a matrix with 2 columns:
+ // case 0u: {
+ // return a[${outer_array_index}u][p0].col0;
+ // }
+ // case 1u: {
+ // return a[${outer_array_index}u][p0].col1;
+ // }
+ std::string col_table = matrix.JoinTemplatedStringForEachMatrixColumn( //
+ R"( case ${col_id_for_tmpl}u: {
+ return a[${outer_array_index}u][p0].col${col_id_for_tmpl};
+ })",
+ "\n");
+ tmpl_expect = matrix.ReplaceFieldsInString(
+ tmpl_expect, {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "col")},
+ {"${col_table}", col_table}});
+ } else {
+ tmpl_expect = tmpl_src;
+ }
+
+ for (uint32_t outer = 0; outer < 4; outer++) {
+ std::string src =
+ utils::ReplaceAll(tmpl_src, "${outer_array_index}", std::to_string(outer));
+ std::string expect =
+ utils::ReplaceAll(tmpl_expect, "${outer_array_index}", std::to_string(outer));
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got)) << "accessing array element [" << outer << "][I]";
+ }
+}
+
+TEST_P(
+ Std140Test_MatrixArray,
+ ArrayArrayMatUniform_LoadColumn_VariableOuterArrayIndex_ConstInnerArrayIndex_ConstColumnIndex) {
+ auto matrix = GetParam();
+
+ if (!matrix.CanBeUsedAsUniformArrayElememts()) {
+ // This permutation is invalid, skip the test.
+ return;
+ }
+
+ std::string tmpl_src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : array<array<${mat}, 3>, 4>;
+
+fn f() {
+ let I = 1;
+ let l = a[I][${inner_array_index}][${column_index}];
+}
+)";
+ tmpl_src = matrix.ReplaceFieldsInString(tmpl_src);
+
+ std::string tmpl_expect;
+ if (matrix.NotStd140Compatible()) {
+ tmpl_expect = R"(
+enable f16;
+
+struct mat${shape}_${elem_type} {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> a : array<array<mat${shape}_${elem_type}, 3u>, 4u>;
+
+fn f() {
+ let I = 1;
+ let l = a[I][${inner_array_index}u].col${column_index};
+}
+)";
+ tmpl_expect = matrix.ReplaceFieldsInString(
+ tmpl_expect, {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "col")}});
+ } else {
+ tmpl_expect = tmpl_src;
+ }
+
+ for (uint32_t inner = 0; inner < 3; inner++) {
+ for (uint32_t col = 0; col < matrix.columns; col++) {
+ std::string src =
+ utils::ReplaceAll(tmpl_src, "${inner_array_index}", std::to_string(inner));
+ src = utils::ReplaceAll(src, "${column_index}", std::to_string(col));
+ std::string expect =
+ utils::ReplaceAll(tmpl_expect, "${inner_array_index}", std::to_string(inner));
+ expect = utils::ReplaceAll(expect, "${column_index}", std::to_string(col));
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got))
+ << "accessing array element [I][" << inner << "] col " << col;
+ }
+ }
+}
+
+TEST_P(
+ Std140Test_MatrixArray,
+ ArrayArrayMatUniform_LoadColumn_VariableOuterArrayIndex_ConstInnerArrayIndex_VariableColumnIndex) {
+ auto matrix = GetParam();
+
+ if (!matrix.CanBeUsedAsUniformArrayElememts()) {
+ // This permutation is invalid, skip the test.
+ return;
+ }
+
+ std::string tmpl_src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : array<array<${mat}, 3>, 4>;
+
+fn f() {
+ let I = 1;
+ let J = 2;
+ let l = a[I][${inner_array_index}][J];
+}
+)";
+ tmpl_src = matrix.ReplaceFieldsInString(tmpl_src);
+
+ std::string tmpl_expect;
+ if (matrix.NotStd140Compatible()) {
+ tmpl_expect = R"(
+enable f16;
+
+struct mat${shape}_${elem_type} {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> a : array<array<mat${shape}_${elem_type}, 3u>, 4u>;
+
+fn load_a_p0_${inner_array_index}_p1(p0 : u32, p1 : u32) -> ${col_vector_type} {
+ switch(p1) {
+${col_table}
+ default: {
+ return ${col_vector_type}();
+ }
+ }
+}
+
+fn f() {
+ let I = 1;
+ let J = 2;
+ let l = load_a_p0_${inner_array_index}_p1(u32(I), u32(J));
+}
+)";
+ // col_table holds the switch cases, one for each column index.
+ // Example for a matrix with 2 columns:
+ // case 0u: {
+ // return a[p0][${inner_array_index}u].col0;
+ // }
+ // case 1u: {
+ // return a[p0][${inner_array_index}u].col1;
+ // }
+ std::string col_table = matrix.JoinTemplatedStringForEachMatrixColumn( //
+ R"( case ${col_id_for_tmpl}u: {
+ return a[p0][${inner_array_index}u].col${col_id_for_tmpl};
+ })",
+ "\n");
+ tmpl_expect = matrix.ReplaceFieldsInString(
+ tmpl_expect, {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "col")},
+ {"${col_table}", col_table}});
+ } else {
+ tmpl_expect = tmpl_src;
+ }
+
+ for (uint32_t inner = 0; inner < 3; inner++) {
+ std::string src =
+ utils::ReplaceAll(tmpl_src, "${inner_array_index}", std::to_string(inner));
+ std::string expect =
+ utils::ReplaceAll(tmpl_expect, "${inner_array_index}", std::to_string(inner));
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got)) << "accessing array element [I][" << inner << "]";
+ }
+}
+
+TEST_P(
+ Std140Test_MatrixArray,
+ ArrayArrayMatUniform_LoadColumn_VariableOuterArrayIndex_VariableInnerArrayIndex_ConstColumnIndex) {
+ auto matrix = GetParam();
+
+ if (!matrix.CanBeUsedAsUniformArrayElememts()) {
+ // This permutation is invalid, skip the test.
+ return;
+ }
+
+ std::string tmpl_src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : array<array<${mat}, 3>, 4>;
+
+fn f() {
+ let I = 1;
+ let J = 2;
+ let l = a[I][J][${column_index}];
+}
+)";
+ tmpl_src = matrix.ReplaceFieldsInString(tmpl_src);
+
+ std::string tmpl_expect;
+ if (matrix.NotStd140Compatible()) {
+ tmpl_expect = R"(
+enable f16;
+
+struct mat${shape}_${elem_type} {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> a : array<array<mat${shape}_${elem_type}, 3u>, 4u>;
+
+fn f() {
+ let I = 1;
+ let J = 2;
+ let l = a[I][J].col${column_index};
+}
+)";
+ tmpl_expect = matrix.ReplaceFieldsInString(
+ tmpl_expect, {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "col")}});
+ } else {
+ tmpl_expect = tmpl_src;
+ }
+
+ for (uint32_t col = 0; col < matrix.columns; col++) {
+ std::string src = utils::ReplaceAll(tmpl_src, "${column_index}", std::to_string(col));
+ std::string expect = utils::ReplaceAll(tmpl_expect, "${column_index}", std::to_string(col));
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got)) << "accessing array element [I][J] col " << col;
+ }
+}
+
+TEST_P(
+ Std140Test_MatrixArray,
+ ArrayArrayMatUniform_LoadColumn_VariableOuterArrayIndex_VariableInnerArrayIndex_VariableColumnIndex) {
+ auto matrix = GetParam();
+
+ if (!matrix.CanBeUsedAsUniformArrayElememts()) {
+ // This permutation is invalid, skip the test.
+ return;
+ }
+
+ std::string src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : array<array<${mat}, 3>, 4>;
+
+fn f() {
+ let I = 0;
+ let J = 1;
+ let K = 2;
+ let l = a[I][J][K];
+}
+)";
+ src = matrix.ReplaceFieldsInString(src);
+
+ std::string expect;
+ if (matrix.NotStd140Compatible()) {
+ expect = R"(
+enable f16;
+
+struct mat${shape}_${elem_type} {
+${col_vectors}
+}
+
+@group(0) @binding(0) var<uniform> a : array<array<mat${shape}_${elem_type}, 3u>, 4u>;
+
+fn load_a_p0_p1_p2(p0 : u32, p1 : u32, p2 : u32) -> ${col_vector_type} {
+ switch(p2) {
+${col_table}
+ default: {
+ return ${col_vector_type}();
+ }
+ }
+}
+
+fn f() {
+ let I = 0;
+ let J = 1;
+ let K = 2;
+ let l = load_a_p0_p1_p2(u32(I), u32(J), u32(K));
+}
+)";
+ // col_table holds the switch cases, one for each column index.
+ // Example for a matrix with 2 columns:
+ // case 0u: {
+ // return a[p0][p1].col0;
+ // }
+ // case 1u: {
+ // return a[p0][p1].col1;
+ // }
+ std::string col_table = matrix.JoinTemplatedStringForEachMatrixColumn( //
+ R"( case ${col_id_for_tmpl}u: {
+ return a[p0][p1].col${col_id_for_tmpl};
+ })",
+ "\n");
+ expect = matrix.ReplaceFieldsInString(
+ expect, {{"${col_vectors}", matrix.ExpendedColumnVectors(2, "col")},
+ {"${col_table}", col_table}});
+ } else {
+ expect = src;
+ }
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+INSTANTIATE_TEST_SUITE_P(,
+ Std140Test_MatrixArray,
+ ::testing::ValuesIn(std::vector<MatrixCase>{
+ {2, 2, MatrixType::f32},
+ {2, 3, MatrixType::f32},
+ {2, 4, MatrixType::f32},
+ {3, 2, MatrixType::f32},
+ {3, 3, MatrixType::f32},
+ {3, 4, MatrixType::f32},
+ {4, 2, MatrixType::f32},
+ {4, 3, MatrixType::f32},
+ {4, 4, MatrixType::f32},
+ {2, 2, MatrixType::f16},
+ {2, 3, MatrixType::f16},
+ {2, 4, MatrixType::f16},
+ {3, 2, MatrixType::f16},
+ {3, 3, MatrixType::f16},
+ {3, 4, MatrixType::f16},
+ {4, 2, MatrixType::f16},
+ {4, 3, MatrixType::f16},
+ {4, 4, MatrixType::f16},
+ }));
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/std140_f16_test.cc b/src/tint/lang/wgsl/ast/transform/std140_f16_test.cc
new file mode 100644
index 0000000..50f5fe5
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/std140_f16_test.cc
@@ -0,0 +1,3596 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/std140.h"
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+#include "src/tint/utils/string.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using Std140Test_F16 = TransformTest;
+
+TEST_F(Std140Test_F16, StructMatricesUniform) {
+ auto* src = R"(
+enable f16;
+
+struct S2x2F16 {
+ m : mat2x2<f16>,
+}
+struct S3x2F16 {
+ m : mat3x2<f16>,
+}
+struct S4x2F16 {
+ m : mat4x2<f16>,
+}
+struct S2x3F16 {
+ m : mat2x3<f16>,
+}
+struct S3x3F16 {
+ m : mat3x3<f16>,
+}
+struct S4x3F16 {
+ m : mat4x3<f16>,
+}
+struct S2x4F16 {
+ m : mat2x4<f16>,
+}
+struct S3x4F16 {
+ m : mat3x4<f16>,
+}
+struct S4x4F16 {
+ m : mat4x4<f16>,
+}
+
+@group(2) @binding(2) var<uniform> s2x2f16 : S2x2F16;
+@group(3) @binding(2) var<uniform> s3x2f16 : S3x2F16;
+@group(4) @binding(2) var<uniform> s4x2f16 : S4x2F16;
+@group(2) @binding(3) var<uniform> s2x3f16 : S2x3F16;
+@group(3) @binding(3) var<uniform> s3x3f16 : S3x3F16;
+@group(4) @binding(3) var<uniform> s4x3f16 : S4x3F16;
+@group(2) @binding(4) var<uniform> s2x4f16 : S2x4F16;
+@group(3) @binding(4) var<uniform> s3x4f16 : S3x4F16;
+@group(4) @binding(4) var<uniform> s4x4f16 : S4x4F16;
+)";
+
+ auto* expect = R"(
+enable f16;
+
+struct S2x2F16 {
+ m : mat2x2<f16>,
+}
+
+struct S2x2F16_std140 {
+ m_0 : vec2<f16>,
+ m_1 : vec2<f16>,
+}
+
+struct S3x2F16 {
+ m : mat3x2<f16>,
+}
+
+struct S3x2F16_std140 {
+ m_0 : vec2<f16>,
+ m_1 : vec2<f16>,
+ m_2 : vec2<f16>,
+}
+
+struct S4x2F16 {
+ m : mat4x2<f16>,
+}
+
+struct S4x2F16_std140 {
+ m_0 : vec2<f16>,
+ m_1 : vec2<f16>,
+ m_2 : vec2<f16>,
+ m_3 : vec2<f16>,
+}
+
+struct S2x3F16 {
+ m : mat2x3<f16>,
+}
+
+struct S2x3F16_std140 {
+ m_0 : vec3<f16>,
+ m_1 : vec3<f16>,
+}
+
+struct S3x3F16 {
+ m : mat3x3<f16>,
+}
+
+struct S3x3F16_std140 {
+ m_0 : vec3<f16>,
+ m_1 : vec3<f16>,
+ m_2 : vec3<f16>,
+}
+
+struct S4x3F16 {
+ m : mat4x3<f16>,
+}
+
+struct S4x3F16_std140 {
+ m_0 : vec3<f16>,
+ m_1 : vec3<f16>,
+ m_2 : vec3<f16>,
+ m_3 : vec3<f16>,
+}
+
+struct S2x4F16 {
+ m : mat2x4<f16>,
+}
+
+struct S2x4F16_std140 {
+ m_0 : vec4<f16>,
+ m_1 : vec4<f16>,
+}
+
+struct S3x4F16 {
+ m : mat3x4<f16>,
+}
+
+struct S3x4F16_std140 {
+ m_0 : vec4<f16>,
+ m_1 : vec4<f16>,
+ m_2 : vec4<f16>,
+}
+
+struct S4x4F16 {
+ m : mat4x4<f16>,
+}
+
+struct S4x4F16_std140 {
+ m_0 : vec4<f16>,
+ m_1 : vec4<f16>,
+ m_2 : vec4<f16>,
+ m_3 : vec4<f16>,
+}
+
+@group(2) @binding(2) var<uniform> s2x2f16 : S2x2F16_std140;
+
+@group(3) @binding(2) var<uniform> s3x2f16 : S3x2F16_std140;
+
+@group(4) @binding(2) var<uniform> s4x2f16 : S4x2F16_std140;
+
+@group(2) @binding(3) var<uniform> s2x3f16 : S2x3F16_std140;
+
+@group(3) @binding(3) var<uniform> s3x3f16 : S3x3F16_std140;
+
+@group(4) @binding(3) var<uniform> s4x3f16 : S4x3F16_std140;
+
+@group(2) @binding(4) var<uniform> s2x4f16 : S2x4F16_std140;
+
+@group(3) @binding(4) var<uniform> s3x4f16 : S3x4F16_std140;
+
+@group(4) @binding(4) var<uniform> s4x4f16 : S4x4F16_std140;
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+// The following tests only cover `mat2x3<f16>`, with every constant column index set to 1, row
+// index to 0, inner array index to 2, and outer array index to 3. For exhaustive tests, i.e. tests
+// on all matrix shapes and different valid constant indices, see std140_exhaustive_test.cc.
+
+TEST_F(Std140Test_F16, SingleStructMatUniform_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+struct S {
+ m : mat2x3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+)";
+
+ auto* expect = R"(
+enable f16;
+
+struct S {
+ m : mat2x3<f16>,
+}
+
+struct S_std140 {
+ m_0 : vec3<f16>,
+ m_1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16, CustomAlign_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+struct S {
+ before : i32,
+ @align(128)
+ m : mat2x3<f16>,
+ after : i32,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+)";
+
+ auto* expect = R"(
+enable f16;
+
+struct S {
+ before : i32,
+ @align(128)
+ m : mat2x3<f16>,
+ after : i32,
+}
+
+struct S_std140 {
+ before : i32,
+ @align(128i)
+ m_0 : vec3<f16>,
+ m_1 : vec3<f16>,
+ after : i32,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16, CustomSizeMat_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+struct S {
+ before : i32,
+ @size(128)
+ m : mat2x3<f16>,
+ after : i32,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+)";
+
+ auto* expect = R"(
+enable f16;
+
+struct S {
+ before : i32,
+ @size(128)
+ m : mat2x3<f16>,
+ after : i32,
+}
+
+struct S_std140 {
+ before : i32,
+ m_0 : vec3<f16>,
+ @size(120)
+ m_1 : vec3<f16>,
+ after : i32,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16, CustomAlignAndSize_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+struct S {
+ before : i32,
+ @align(128) @size(128)
+ m : mat2x3<f16>,
+ after : i32,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+)";
+
+ auto* expect = R"(
+enable f16;
+
+struct S {
+ before : i32,
+ @align(128) @size(128)
+ m : mat2x3<f16>,
+ after : i32,
+}
+
+struct S_std140 {
+ before : i32,
+ @align(128i)
+ m_0 : vec3<f16>,
+ @size(120)
+ m_1 : vec3<f16>,
+ after : i32,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16, MatrixUsageInForLoop_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+struct S {
+ m : mat2x3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ for(var i = u32(s.m[0][0]); (i < u32(s.m[i][1])); i += u32(s.m[1][i])) {
+ }
+}
+)";
+
+ auto* expect = R"(
+enable f16;
+
+struct S {
+ m : mat2x3<f16>,
+}
+
+struct S_std140 {
+ m_0 : vec3<f16>,
+ m_1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn load_s_m_p0_1(p0 : u32) -> f16 {
+ switch(p0) {
+ case 0u: {
+ return s.m_0[1u];
+ }
+ case 1u: {
+ return s.m_1[1u];
+ }
+ default: {
+ return f16();
+ }
+ }
+}
+
+fn f() {
+ for(var i = u32(s.m_0[0u]); (i < u32(load_s_m_p0_1(u32(i)))); i += u32(s.m_1[i])) {
+ }
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16, MatUniform_LoadMatrix_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> m : mat2x3<f16>;
+
+fn f() {
+ let l = m;
+}
+)";
+
+ auto* expect = R"(
+enable f16;
+
+struct mat2x3_f16 {
+ col0 : vec3<f16>,
+ col1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> m : mat2x3_f16;
+
+fn conv_mat2x3_f16(val : mat2x3_f16) -> mat2x3<f16> {
+ return mat2x3<f16>(val.col0, val.col1);
+}
+
+fn f() {
+ let l = conv_mat2x3_f16(m);
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16, MatUniform_LoadColumn_ConstIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : mat2x3<f16>;
+
+fn f() {
+ let l = a[1];
+}
+)";
+
+ auto* expect = R"(
+enable f16;
+
+struct mat2x3_f16 {
+ col0 : vec3<f16>,
+ col1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : mat2x3_f16;
+
+fn f() {
+ let l = a.col1;
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16, MatUniform_LoadColumn_VariableIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : mat2x3<f16>;
+
+fn f() {
+ let I = 1;
+ let l = a[I];
+}
+)";
+
+ auto* expect = R"(
+enable f16;
+
+struct mat2x3_f16 {
+ col0 : vec3<f16>,
+ col1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : mat2x3_f16;
+
+fn load_a_p0(p0 : u32) -> vec3<f16> {
+ switch(p0) {
+ case 0u: {
+ return a.col0;
+ }
+ case 1u: {
+ return a.col1;
+ }
+ default: {
+ return vec3<f16>();
+ }
+ }
+}
+
+fn f() {
+ let I = 1;
+ let l = load_a_p0(u32(I));
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16, MatUniform_LoadColumnSwizzle_ConstIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : mat2x3<f16>;
+
+fn f() {
+ let l = a[1].yzx;
+}
+)";
+
+ auto* expect = R"(
+enable f16;
+
+struct mat2x3_f16 {
+ col0 : vec3<f16>,
+ col1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : mat2x3_f16;
+
+fn f() {
+ let l = a.col1.yzx;
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16, MatUniform_LoadColumnSwizzle_VariableIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : mat2x3<f16>;
+
+fn f() {
+ let I = 1;
+ let l = a[I].yzx;
+}
+)";
+
+ auto* expect = R"(
+enable f16;
+
+struct mat2x3_f16 {
+ col0 : vec3<f16>,
+ col1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : mat2x3_f16;
+
+fn load_a_p0_yzx(p0 : u32) -> vec3<f16> {
+ switch(p0) {
+ case 0u: {
+ return a.col0.yzx;
+ }
+ case 1u: {
+ return a.col1.yzx;
+ }
+ default: {
+ return vec3<f16>();
+ }
+ }
+}
+
+fn f() {
+ let I = 1;
+ let l = load_a_p0_yzx(u32(I));
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16, MatUniform_LoadScalar_ConstColumnIndex_ConstRowIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : mat2x3<f16>;
+
+fn f() {
+ let l = a[1][0];
+}
+)";
+
+ auto* expect = R"(
+enable f16;
+
+struct mat2x3_f16 {
+ col0 : vec3<f16>,
+ col1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : mat2x3_f16;
+
+fn f() {
+ let l = a.col1[0u];
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16, MatUniform_LoadScalar_VariableColumnIndex_ConstRowIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : mat2x3<f16>;
+
+fn f() {
+ let I = 0;
+ let l = a[I][0];
+}
+)";
+
+ auto* expect = R"(
+enable f16;
+
+struct mat2x3_f16 {
+ col0 : vec3<f16>,
+ col1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : mat2x3_f16;
+
+fn load_a_p0_0(p0 : u32) -> f16 {
+ switch(p0) {
+ case 0u: {
+ return a.col0[0u];
+ }
+ case 1u: {
+ return a.col1[0u];
+ }
+ default: {
+ return f16();
+ }
+ }
+}
+
+fn f() {
+ let I = 0;
+ let l = load_a_p0_0(u32(I));
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16, MatUniform_LoadScalar_ConstColumnIndex_VariableRowIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : mat2x3<f16>;
+
+fn f() {
+ let I = 0;
+ let l = a[1][I];
+}
+)";
+
+ auto* expect = R"(
+enable f16;
+
+struct mat2x3_f16 {
+ col0 : vec3<f16>,
+ col1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : mat2x3_f16;
+
+fn f() {
+ let I = 0;
+ let l = a.col1[I];
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16, MatUniform_LoadScalar_VariableColumnIndex_VariableRowIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : mat2x3<f16>;
+
+fn f() {
+ let I = 0;
+ let l = a[I][I];
+}
+)";
+
+ auto* expect = R"(
+enable f16;
+
+struct mat2x3_f16 {
+ col0 : vec3<f16>,
+ col1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : mat2x3_f16;
+
+fn load_a_p0_p1(p0 : u32, p1 : u32) -> f16 {
+ switch(p0) {
+ case 0u: {
+ return a.col0[p1];
+ }
+ case 1u: {
+ return a.col1[p1];
+ }
+ default: {
+ return f16();
+ }
+ }
+}
+
+fn f() {
+ let I = 0;
+ let l = load_a_p0_p1(u32(I), u32(I));
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16, StructMatUniform_NameCollision_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+struct S {
+ m_1 : i32,
+ m : mat2x3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+)";
+
+ auto* expect = R"(
+enable f16;
+
+struct S {
+ m_1 : i32,
+ m : mat2x3<f16>,
+}
+
+struct S_std140 {
+ m_1 : i32,
+ m__0 : vec3<f16>,
+ m__1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16, StructMatUniform_LoadStruct_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+struct S {
+ m : mat2x3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let l = s;
+}
+)";
+
+ auto* expect = R"(
+enable f16;
+
+struct S {
+ m : mat2x3<f16>,
+}
+
+struct S_std140 {
+ m_0 : vec3<f16>,
+ m_1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn conv_S(val : S_std140) -> S {
+ return S(mat2x3<f16>(val.m_0, val.m_1));
+}
+
+fn f() {
+ let l = conv_S(s);
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16, StructMatUniform_LoadMatrix_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+struct S {
+ m : mat2x3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let l = s.m;
+}
+)";
+
+ auto* expect = R"(
+enable f16;
+
+struct S {
+ m : mat2x3<f16>,
+}
+
+struct S_std140 {
+ m_0 : vec3<f16>,
+ m_1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn load_s_m() -> mat2x3<f16> {
+ let s = &(s);
+ return mat2x3<f16>((*(s)).m_0, (*(s)).m_1);
+}
+
+fn f() {
+ let l = load_s_m();
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16, StructMatUniform_LoadColumn_ConstIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+struct S {
+ m : mat2x3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let l = s.m[1];
+}
+)";
+
+ auto* expect = R"(
+enable f16;
+
+struct S {
+ m : mat2x3<f16>,
+}
+
+struct S_std140 {
+ m_0 : vec3<f16>,
+ m_1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn f() {
+ let l = s.m_1;
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16, StructMatUniform_LoadColumn_VariableIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+struct S {
+ m : mat2x3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let I = 0;
+ let l = s.m[I];
+}
+)";
+ // Expected: the dynamic column index is routed through load_s_m_p0(), which switches over the column.
+ auto* expect = R"(
+enable f16;
+
+struct S {
+ m : mat2x3<f16>,
+}
+
+struct S_std140 {
+ m_0 : vec3<f16>,
+ m_1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn load_s_m_p0(p0 : u32) -> vec3<f16> {
+ switch(p0) {
+ case 0u: {
+ return s.m_0;
+ }
+ case 1u: {
+ return s.m_1;
+ }
+ default: {
+ return vec3<f16>();
+ }
+ }
+}
+
+fn f() {
+ let I = 0;
+ let l = load_s_m_p0(u32(I));
+}
+)";
+ // Run the Std140 transform on the input WGSL.
+ auto got = Run<Std140>(src);
+ // The stringified result must match the expected WGSL exactly.
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16, StructMatUniform_LoadScalar_ConstColumnIndex_ConstRowIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+struct S {
+ m : mat2x3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let l = s.m[1][0];
+}
+)";
+ // Expected: s.m[1][0] becomes s.m_1[0u], with no helper function.
+ auto* expect = R"(
+enable f16;
+
+struct S {
+ m : mat2x3<f16>,
+}
+
+struct S_std140 {
+ m_0 : vec3<f16>,
+ m_1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn f() {
+ let l = s.m_1[0u];
+}
+)";
+ // Run the Std140 transform on the input WGSL.
+ auto got = Run<Std140>(src);
+ // The stringified result must match the expected WGSL exactly.
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16, StructMatUniform_LoadScalar_VariableColumnIndex_ConstRowIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+struct S {
+ m : mat2x3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let I = 0;
+ let l = s.m[I][0];
+}
+)";
+ // Expected: load_s_m_p0_0() switches on the column and reads element 0u of the chosen column vector.
+ auto* expect = R"(
+enable f16;
+
+struct S {
+ m : mat2x3<f16>,
+}
+
+struct S_std140 {
+ m_0 : vec3<f16>,
+ m_1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn load_s_m_p0_0(p0 : u32) -> f16 {
+ switch(p0) {
+ case 0u: {
+ return s.m_0[0u];
+ }
+ case 1u: {
+ return s.m_1[0u];
+ }
+ default: {
+ return f16();
+ }
+ }
+}
+
+fn f() {
+ let I = 0;
+ let l = load_s_m_p0_0(u32(I));
+}
+)";
+ // Run the Std140 transform on the input WGSL.
+ auto got = Run<Std140>(src);
+ // The stringified result must match the expected WGSL exactly.
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16, StructMatUniform_LoadScalar_ConstColumnIndex_VariableRowIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+struct S {
+ m : mat2x3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let I = 0;
+ let l = s.m[1][I];
+}
+)";
+ // Expected: s.m[1][I] becomes s.m_1[I]; the variable row index is left as-is and no helper is emitted.
+ auto* expect = R"(
+enable f16;
+
+struct S {
+ m : mat2x3<f16>,
+}
+
+struct S_std140 {
+ m_0 : vec3<f16>,
+ m_1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn f() {
+ let I = 0;
+ let l = s.m_1[I];
+}
+)";
+ // Run the Std140 transform on the input WGSL.
+ auto got = Run<Std140>(src);
+ // The stringified result must match the expected WGSL exactly.
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16, StructMatUniform_LoadScalar_VariableColumnIndex_VariableRowIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+struct S {
+ m : mat2x3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let I = 0;
+ let l = s.m[I][I];
+}
+)";
+ // Expected: load_s_m_p0_p1() switches on the column (p0) and indexes the chosen column with p1.
+ auto* expect = R"(
+enable f16;
+
+struct S {
+ m : mat2x3<f16>,
+}
+
+struct S_std140 {
+ m_0 : vec3<f16>,
+ m_1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn load_s_m_p0_p1(p0 : u32, p1 : u32) -> f16 {
+ switch(p0) {
+ case 0u: {
+ return s.m_0[p1];
+ }
+ case 1u: {
+ return s.m_1[p1];
+ }
+ default: {
+ return f16();
+ }
+ }
+}
+
+fn f() {
+ let I = 0;
+ let l = load_s_m_p0_p1(u32(I), u32(I));
+}
+)";
+ // Run the Std140 transform on the input WGSL.
+ auto got = Run<Std140>(src);
+ // The stringified result must match the expected WGSL exactly.
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16, ArrayStructMatUniform_LoadArray_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : mat2x3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S, 3>;
+
+fn f() {
+ let l = a;
+}
+)";
+ // Expected: @size(64) on m turns into @size(56) on m_1, and the array load goes through conv_arr3_S().
+ auto* expect = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : mat2x3<f16>,
+}
+
+struct S_std140 {
+ m_0 : vec3<f16>,
+ @size(56)
+ m_1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S_std140, 3u>;
+
+fn conv_S(val : S_std140) -> S {
+ return S(mat2x3<f16>(val.m_0, val.m_1));
+}
+
+fn conv_arr3_S(val : array<S_std140, 3u>) -> array<S, 3u> {
+ var arr : array<S, 3u>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ arr[i] = conv_S(val[i]);
+ }
+ return arr;
+}
+
+fn f() {
+ let l = conv_arr3_S(a);
+}
+)";
+ // Run the Std140 transform on the input WGSL.
+ auto got = Run<Std140>(src);
+ // The stringified result must match the expected WGSL exactly.
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16, ArrayStructMatUniform_LoadStruct_ConstIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : mat2x3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S, 3>;
+
+fn f() {
+ let l = a[2];
+}
+)";
+ // Expected: a[2] becomes conv_S(a[2u]); only the per-struct converter is emitted.
+ auto* expect = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : mat2x3<f16>,
+}
+
+struct S_std140 {
+ m_0 : vec3<f16>,
+ @size(56)
+ m_1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S_std140, 3u>;
+
+fn conv_S(val : S_std140) -> S {
+ return S(mat2x3<f16>(val.m_0, val.m_1));
+}
+
+fn f() {
+ let l = conv_S(a[2u]);
+}
+)";
+ // Run the Std140 transform on the input WGSL.
+ auto got = Run<Std140>(src);
+ // The stringified result must match the expected WGSL exactly.
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16, ArrayStructMatUniform_LoadStruct_VariableIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : mat2x3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S, 3>;
+
+fn f() {
+ let I = 1;
+ let l = a[I];
+}
+)";
+ // Expected: a[I] becomes conv_S(a[I]); the variable index is passed through unchanged.
+ auto* expect = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : mat2x3<f16>,
+}
+
+struct S_std140 {
+ m_0 : vec3<f16>,
+ @size(56)
+ m_1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S_std140, 3u>;
+
+fn conv_S(val : S_std140) -> S {
+ return S(mat2x3<f16>(val.m_0, val.m_1));
+}
+
+fn f() {
+ let I = 1;
+ let l = conv_S(a[I]);
+}
+)";
+ // Run the Std140 transform on the input WGSL.
+ auto got = Run<Std140>(src);
+ // The stringified result must match the expected WGSL exactly.
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16, ArrayStructMatUniform_LoadMatrix_ConstArrayIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : mat2x3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S, 3>;
+
+fn f() {
+ let l = a[2].m;
+}
+)";
+ // Expected: a[2].m becomes load_a_2_m(), which has the constant array index 2u baked in.
+ auto* expect = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : mat2x3<f16>,
+}
+
+struct S_std140 {
+ m_0 : vec3<f16>,
+ @size(56)
+ m_1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S_std140, 3u>;
+
+fn load_a_2_m() -> mat2x3<f16> {
+ let s = &(a[2u]);
+ return mat2x3<f16>((*(s)).m_0, (*(s)).m_1);
+}
+
+fn f() {
+ let l = load_a_2_m();
+}
+)";
+ // Run the Std140 transform on the input WGSL.
+ auto got = Run<Std140>(src);
+ // The stringified result must match the expected WGSL exactly.
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16, ArrayStructMatUniform_LoadMatrix_VariableArrayIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : mat2x3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S, 3>;
+
+fn f() {
+ let I = 1;
+ let l = a[I].m;
+}
+)";
+ // Expected: a[I].m becomes load_a_p0_m(u32(I)); the array index is passed as the u32 parameter p0.
+ auto* expect = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : mat2x3<f16>,
+}
+
+struct S_std140 {
+ m_0 : vec3<f16>,
+ @size(56)
+ m_1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S_std140, 3u>;
+
+fn load_a_p0_m(p0 : u32) -> mat2x3<f16> {
+ let s = &(a[p0]);
+ return mat2x3<f16>((*(s)).m_0, (*(s)).m_1);
+}
+
+fn f() {
+ let I = 1;
+ let l = load_a_p0_m(u32(I));
+}
+)";
+ // Run the Std140 transform on the input WGSL.
+ auto got = Run<Std140>(src);
+ // The stringified result must match the expected WGSL exactly.
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16,
+ ArrayStructMatUniform_LoadColumn_ConstArrayIndex_ConstColumnIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : mat2x3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S, 3>;
+
+fn f() {
+ let l = a[2].m[1];
+}
+)";
+ // Expected: a[2].m[1] becomes a[2u].m_1, with no helper function.
+ auto* expect = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : mat2x3<f16>,
+}
+
+struct S_std140 {
+ m_0 : vec3<f16>,
+ @size(56)
+ m_1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S_std140, 3u>;
+
+fn f() {
+ let l = a[2u].m_1;
+}
+)";
+ // Run the Std140 transform on the input WGSL.
+ auto got = Run<Std140>(src);
+ // The stringified result must match the expected WGSL exactly.
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16,
+ ArrayStructMatUniform_LoadColumn_VariableArrayIndex_ConstColumnIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : mat2x3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S, 3>;
+
+fn f() {
+ let I = 1;
+ let l = a[I].m[1];
+}
+)";
+ // Expected: a[I].m[1] becomes a[I].m_1, with no helper function.
+ auto* expect = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : mat2x3<f16>,
+}
+
+struct S_std140 {
+ m_0 : vec3<f16>,
+ @size(56)
+ m_1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S_std140, 3u>;
+
+fn f() {
+ let I = 1;
+ let l = a[I].m_1;
+}
+)";
+ // Run the Std140 transform on the input WGSL.
+ auto got = Run<Std140>(src);
+ // The stringified result must match the expected WGSL exactly.
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16,
+ ArrayStructMatUniform_LoadColumn_ConstArrayIndex_VariableColumnIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : mat2x3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S, 3>;
+
+fn f() {
+ let I = 1;
+ let l = a[2].m[I];
+}
+)";
+ // Expected: load_a_2_m_p0() switches on the column index while the array index 2u stays baked in.
+ auto* expect = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : mat2x3<f16>,
+}
+
+struct S_std140 {
+ m_0 : vec3<f16>,
+ @size(56)
+ m_1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S_std140, 3u>;
+
+fn load_a_2_m_p0(p0 : u32) -> vec3<f16> {
+ switch(p0) {
+ case 0u: {
+ return a[2u].m_0;
+ }
+ case 1u: {
+ return a[2u].m_1;
+ }
+ default: {
+ return vec3<f16>();
+ }
+ }
+}
+
+fn f() {
+ let I = 1;
+ let l = load_a_2_m_p0(u32(I));
+}
+)";
+ // Run the Std140 transform on the input WGSL.
+ auto got = Run<Std140>(src);
+ // The stringified result must match the expected WGSL exactly.
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16,
+ ArrayStructMatUniform_LoadColumn_VariableArrayIndex_VariableColumnIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : mat2x3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S, 3>;
+
+fn f() {
+ let I = 1;
+ let l = a[I].m[I];
+}
+)";
+ // Expected: load_a_p0_m_p1() takes the array index as p0 and switches on the column index p1.
+ auto* expect = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : mat2x3<f16>,
+}
+
+struct S_std140 {
+ m_0 : vec3<f16>,
+ @size(56)
+ m_1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S_std140, 3u>;
+
+fn load_a_p0_m_p1(p0 : u32, p1 : u32) -> vec3<f16> {
+ switch(p1) {
+ case 0u: {
+ return a[p0].m_0;
+ }
+ case 1u: {
+ return a[p0].m_1;
+ }
+ default: {
+ return vec3<f16>();
+ }
+ }
+}
+
+fn f() {
+ let I = 1;
+ let l = load_a_p0_m_p1(u32(I), u32(I));
+}
+)";
+ // Run the Std140 transform on the input WGSL.
+ auto got = Run<Std140>(src);
+ // The stringified result must match the expected WGSL exactly.
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16, ArrayStructArrayStructMatUniform_Loads_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+struct Inner {
+ @size(64)
+ m : mat2x3<f16>,
+}
+
+struct Outer {
+ a : array<Inner, 4>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<Outer, 4>;
+
+fn f() {
+ let I = 1;
+ let J = 2;
+ let K = 0;
+ let l_a : array<Outer, 4> = a;
+ let l_a_1 : Outer = a[1];
+ let l_a_I : Outer = a[I];
+ let l_a_2_a : array<Inner, 4> = a[2].a;
+ let l_a_I_a : array<Inner, 4> = a[I].a;
+ let l_a_3_a_1 : Inner = a[3].a[1];
+ let l_a_3_a_I : Inner = a[3].a[I];
+ let l_a_I_a_1 : Inner = a[I].a[1];
+ let l_a_I_a_J : Inner = a[I].a[J];
+ let l_a_0_a_2_m : mat2x3<f16> = a[0].a[2].m;
+ let l_a_0_a_I_m : mat2x3<f16> = a[0].a[I].m;
+ let l_a_I_a_2_m : mat2x3<f16> = a[I].a[2].m;
+ let l_a_I_a_J_m : mat2x3<f16> = a[I].a[J].m;
+ let l_a_1_a_3_m_0 : vec3<f16> = a[1].a[3].m[0];
+ let l_a_I_a_J_m_K : vec3<f16> = a[I].a[J].m[K];
+ let l_a_2_a_0_m_1_0 : f16 = a[2].a[0].m[1][0];
+ let l_a_I_a_J_m_K_I : f16 = a[I].a[J].m[K][I];
+}
+)";
+ // Expected: whole struct/array loads use conv_*(), matrix and dynamic-column loads use load_*(), constant column accesses are rewritten inline.
+ auto* expect = R"(
+enable f16;
+
+struct Inner {
+ @size(64)
+ m : mat2x3<f16>,
+}
+
+struct Inner_std140 {
+ m_0 : vec3<f16>,
+ @size(56)
+ m_1 : vec3<f16>,
+}
+
+struct Outer {
+ a : array<Inner, 4>,
+}
+
+struct Outer_std140 {
+ a : array<Inner_std140, 4u>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<Outer_std140, 4u>;
+
+fn conv_Inner(val : Inner_std140) -> Inner {
+ return Inner(mat2x3<f16>(val.m_0, val.m_1));
+}
+
+fn conv_arr4_Inner(val : array<Inner_std140, 4u>) -> array<Inner, 4u> {
+ var arr : array<Inner, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ arr[i] = conv_Inner(val[i]);
+ }
+ return arr;
+}
+
+fn conv_Outer(val : Outer_std140) -> Outer {
+ return Outer(conv_arr4_Inner(val.a));
+}
+
+fn conv_arr4_Outer(val : array<Outer_std140, 4u>) -> array<Outer, 4u> {
+ var arr : array<Outer, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ arr[i] = conv_Outer(val[i]);
+ }
+ return arr;
+}
+
+fn load_a_0_a_2_m() -> mat2x3<f16> {
+ let s = &(a[0u].a[2u]);
+ return mat2x3<f16>((*(s)).m_0, (*(s)).m_1);
+}
+
+fn load_a_0_a_p0_m(p0 : u32) -> mat2x3<f16> {
+ let s = &(a[0u].a[p0]);
+ return mat2x3<f16>((*(s)).m_0, (*(s)).m_1);
+}
+
+fn load_a_p0_a_2_m(p0 : u32) -> mat2x3<f16> {
+ let s = &(a[p0].a[2u]);
+ return mat2x3<f16>((*(s)).m_0, (*(s)).m_1);
+}
+
+fn load_a_p0_a_p1_m(p0 : u32, p1 : u32) -> mat2x3<f16> {
+ let s = &(a[p0].a[p1]);
+ return mat2x3<f16>((*(s)).m_0, (*(s)).m_1);
+}
+
+fn load_a_p0_a_p1_m_p2(p0 : u32, p1 : u32, p2 : u32) -> vec3<f16> {
+ switch(p2) {
+ case 0u: {
+ return a[p0].a[p1].m_0;
+ }
+ case 1u: {
+ return a[p0].a[p1].m_1;
+ }
+ default: {
+ return vec3<f16>();
+ }
+ }
+}
+
+fn load_a_p0_a_p1_m_p2_p3(p0 : u32, p1 : u32, p2 : u32, p3 : u32) -> f16 {
+ switch(p2) {
+ case 0u: {
+ return a[p0].a[p1].m_0[p3];
+ }
+ case 1u: {
+ return a[p0].a[p1].m_1[p3];
+ }
+ default: {
+ return f16();
+ }
+ }
+}
+
+fn f() {
+ let I = 1;
+ let J = 2;
+ let K = 0;
+ let l_a : array<Outer, 4> = conv_arr4_Outer(a);
+ let l_a_1 : Outer = conv_Outer(a[1u]);
+ let l_a_I : Outer = conv_Outer(a[I]);
+ let l_a_2_a : array<Inner, 4> = conv_arr4_Inner(a[2u].a);
+ let l_a_I_a : array<Inner, 4> = conv_arr4_Inner(a[I].a);
+ let l_a_3_a_1 : Inner = conv_Inner(a[3u].a[1u]);
+ let l_a_3_a_I : Inner = conv_Inner(a[3u].a[I]);
+ let l_a_I_a_1 : Inner = conv_Inner(a[I].a[1u]);
+ let l_a_I_a_J : Inner = conv_Inner(a[I].a[J]);
+ let l_a_0_a_2_m : mat2x3<f16> = load_a_0_a_2_m();
+ let l_a_0_a_I_m : mat2x3<f16> = load_a_0_a_p0_m(u32(I));
+ let l_a_I_a_2_m : mat2x3<f16> = load_a_p0_a_2_m(u32(I));
+ let l_a_I_a_J_m : mat2x3<f16> = load_a_p0_a_p1_m(u32(I), u32(J));
+ let l_a_1_a_3_m_0 : vec3<f16> = a[1u].a[3u].m_0;
+ let l_a_I_a_J_m_K : vec3<f16> = load_a_p0_a_p1_m_p2(u32(I), u32(J), u32(K));
+ let l_a_2_a_0_m_1_0 : f16 = a[2u].a[0u].m_1[0u];
+ let l_a_I_a_J_m_K_I : f16 = load_a_p0_a_p1_m_p2_p3(u32(I), u32(J), u32(K), u32(I));
+}
+)";
+ // Run the Std140 transform on the input WGSL.
+ auto got = Run<Std140>(src);
+ // The stringified result must match the expected WGSL exactly.
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16, ArrayStructArrayStructMatUniform_LoadsViaPtrs_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+struct Inner {
+ @size(64)
+ m : mat2x3<f16>,
+}
+
+struct Outer {
+ a : array<Inner, 4>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<Outer, 4>;
+
+fn f() {
+ let I = 1;
+ let J = 2;
+ let K = 0;
+ let p_a = &(a);
+ let p_a_3 = &((*(p_a))[3]);
+ let p_a_I = &((*(p_a))[I]);
+ let p_a_3_a = &((*(p_a_3)).a);
+ let p_a_I_a = &((*(p_a_I)).a);
+ let p_a_3_a_2 = &((*(p_a_3_a))[2]);
+ let p_a_3_a_I = &((*(p_a_3_a))[I]);
+ let p_a_I_a_2 = &((*(p_a_I_a))[2]);
+ let p_a_I_a_J = &((*(p_a_I_a))[J]);
+ let p_a_3_a_2_m = &((*(p_a_3_a_2)).m);
+ let p_a_3_a_I_m = &((*(p_a_3_a_I)).m);
+ let p_a_I_a_2_m = &((*(p_a_I_a_2)).m);
+ let p_a_I_a_J_m = &((*(p_a_I_a_J)).m);
+ let p_a_3_a_2_m_1 = &((*(p_a_3_a_2_m))[1]);
+ let p_a_I_a_J_m_K = &((*(p_a_I_a_J_m))[K]);
+ let l_a : array<Outer, 4> = *(p_a);
+ let l_a_3 : Outer = *(p_a_3);
+ let l_a_I : Outer = *(p_a_I);
+ let l_a_3_a : array<Inner, 4> = *(p_a_3_a);
+ let l_a_I_a : array<Inner, 4> = *(p_a_I_a);
+ let l_a_3_a_2 : Inner = *(p_a_3_a_2);
+ let l_a_3_a_I : Inner = *(p_a_3_a_I);
+ let l_a_I_a_2 : Inner = *(p_a_I_a_2);
+ let l_a_I_a_J : Inner = *(p_a_I_a_J);
+ let l_a_3_a_2_m : mat2x3<f16> = *(p_a_3_a_2_m);
+ let l_a_3_a_I_m : mat2x3<f16> = *(p_a_3_a_I_m);
+ let l_a_I_a_2_m : mat2x3<f16> = *(p_a_I_a_2_m);
+ let l_a_I_a_J_m : mat2x3<f16> = *(p_a_I_a_J_m);
+ let l_a_3_a_2_m_1 : vec3<f16> = *(p_a_3_a_2_m_1);
+ let l_a_I_a_J_m_K : vec3<f16> = *(p_a_I_a_J_m_K);
+ let l_a_2_a_0_m_1_0 : f16 = (*(p_a_3_a_2_m_1))[0];
+ let l_a_I_a_J_m_K_I : f16 = (*(p_a_I_a_J_m_K))[I];
+}
+)";
+ // Expected: the pointer lets are replaced by converted/loaded values and each dereference is re-expanded to the full access chain from 'a'.
+ auto* expect = R"(
+enable f16;
+
+struct Inner {
+ @size(64)
+ m : mat2x3<f16>,
+}
+
+struct Inner_std140 {
+ m_0 : vec3<f16>,
+ @size(56)
+ m_1 : vec3<f16>,
+}
+
+struct Outer {
+ a : array<Inner, 4>,
+}
+
+struct Outer_std140 {
+ a : array<Inner_std140, 4u>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<Outer_std140, 4u>;
+
+fn conv_Inner(val : Inner_std140) -> Inner {
+ return Inner(mat2x3<f16>(val.m_0, val.m_1));
+}
+
+fn conv_arr4_Inner(val : array<Inner_std140, 4u>) -> array<Inner, 4u> {
+ var arr : array<Inner, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ arr[i] = conv_Inner(val[i]);
+ }
+ return arr;
+}
+
+fn conv_Outer(val : Outer_std140) -> Outer {
+ return Outer(conv_arr4_Inner(val.a));
+}
+
+fn conv_arr4_Outer(val : array<Outer_std140, 4u>) -> array<Outer, 4u> {
+ var arr : array<Outer, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ arr[i] = conv_Outer(val[i]);
+ }
+ return arr;
+}
+
+fn load_a_3_a_2_m() -> mat2x3<f16> {
+ let s = &(a[3u].a[2u]);
+ return mat2x3<f16>((*(s)).m_0, (*(s)).m_1);
+}
+
+fn load_a_3_a_p0_m(p0 : u32) -> mat2x3<f16> {
+ let s = &(a[3u].a[p0]);
+ return mat2x3<f16>((*(s)).m_0, (*(s)).m_1);
+}
+
+fn load_a_p0_a_2_m(p0 : u32) -> mat2x3<f16> {
+ let s = &(a[p0].a[2u]);
+ return mat2x3<f16>((*(s)).m_0, (*(s)).m_1);
+}
+
+fn load_a_p0_a_p1_m(p0 : u32, p1 : u32) -> mat2x3<f16> {
+ let s = &(a[p0].a[p1]);
+ return mat2x3<f16>((*(s)).m_0, (*(s)).m_1);
+}
+
+fn load_a_p0_a_p1_m_p2(p0 : u32, p1 : u32, p2 : u32) -> vec3<f16> {
+ switch(p2) {
+ case 0u: {
+ return a[p0].a[p1].m_0;
+ }
+ case 1u: {
+ return a[p0].a[p1].m_1;
+ }
+ default: {
+ return vec3<f16>();
+ }
+ }
+}
+
+fn load_a_p0_a_p1_m_p2_p3(p0 : u32, p1 : u32, p2 : u32, p3 : u32) -> f16 {
+ switch(p2) {
+ case 0u: {
+ return a[p0].a[p1].m_0[p3];
+ }
+ case 1u: {
+ return a[p0].a[p1].m_1[p3];
+ }
+ default: {
+ return f16();
+ }
+ }
+}
+
+fn f() {
+ let I = 1;
+ let J = 2;
+ let K = 0;
+ let p_a = conv_arr4_Outer(a);
+ let p_a_3 = conv_Outer(a[3u]);
+ let p_a_I = conv_Outer(a[I]);
+ let p_a_3_a = conv_arr4_Inner(a[3u].a);
+ let p_a_I_a = conv_arr4_Inner(a[I].a);
+ let p_a_3_a_2 = conv_Inner(a[3u].a[2u]);
+ let p_a_3_a_I = conv_Inner(a[3u].a[I]);
+ let p_a_I_a_2 = conv_Inner(a[I].a[2u]);
+ let p_a_I_a_J = conv_Inner(a[I].a[J]);
+ let p_a_3_a_2_m = load_a_3_a_2_m();
+ let p_a_3_a_I_m = load_a_3_a_p0_m(u32(I));
+ let p_a_I_a_2_m = load_a_p0_a_2_m(u32(I));
+ let p_a_I_a_J_m = load_a_p0_a_p1_m(u32(I), u32(J));
+ let p_a_3_a_2_m_1 = a[3u].a[2u].m_1;
+ let p_a_I_a_J_m_K = load_a_p0_a_p1_m_p2(u32(I), u32(J), u32(K));
+ let l_a : array<Outer, 4> = conv_arr4_Outer(a);
+ let l_a_3 : Outer = conv_Outer(a[3u]);
+ let l_a_I : Outer = conv_Outer(a[I]);
+ let l_a_3_a : array<Inner, 4> = conv_arr4_Inner(a[3u].a);
+ let l_a_I_a : array<Inner, 4> = conv_arr4_Inner(a[I].a);
+ let l_a_3_a_2 : Inner = conv_Inner(a[3u].a[2u]);
+ let l_a_3_a_I : Inner = conv_Inner(a[3u].a[I]);
+ let l_a_I_a_2 : Inner = conv_Inner(a[I].a[2u]);
+ let l_a_I_a_J : Inner = conv_Inner(a[I].a[J]);
+ let l_a_3_a_2_m : mat2x3<f16> = load_a_3_a_2_m();
+ let l_a_3_a_I_m : mat2x3<f16> = load_a_3_a_p0_m(u32(I));
+ let l_a_I_a_2_m : mat2x3<f16> = load_a_p0_a_2_m(u32(I));
+ let l_a_I_a_J_m : mat2x3<f16> = load_a_p0_a_p1_m(u32(I), u32(J));
+ let l_a_3_a_2_m_1 : vec3<f16> = a[3u].a[2u].m_1;
+ let l_a_I_a_J_m_K : vec3<f16> = load_a_p0_a_p1_m_p2(u32(I), u32(J), u32(K));
+ let l_a_2_a_0_m_1_0 : f16 = a[3u].a[2u].m_1[0u];
+ let l_a_I_a_J_m_K_I : f16 = load_a_p0_a_p1_m_p2_p3(u32(I), u32(J), u32(K), u32(I));
+}
+)";
+ // Run the Std140 transform on the input WGSL.
+ auto got = Run<Std140>(src);
+ // The stringified result must match the expected WGSL exactly.
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16, ArrayStructMatUniform_CopyArray_UniformToStorage_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : mat2x3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> u : array<S, 4>;
+
+@group(0) @binding(1) var<storage, read_write> s : array<S, 4>;
+
+fn f() {
+ s = u;
+}
+)";
+ // Expected: only the uniform 'u' is retyped to array<S_std140, 4u>; storage 's' keeps array<S, 4> and is assigned conv_arr4_S(u).
+ auto* expect = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : mat2x3<f16>,
+}
+
+struct S_std140 {
+ m_0 : vec3<f16>,
+ @size(56)
+ m_1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> u : array<S_std140, 4u>;
+
+@group(0) @binding(1) var<storage, read_write> s : array<S, 4>;
+
+fn conv_S(val : S_std140) -> S {
+ return S(mat2x3<f16>(val.m_0, val.m_1));
+}
+
+fn conv_arr4_S(val : array<S_std140, 4u>) -> array<S, 4u> {
+ var arr : array<S, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ arr[i] = conv_S(val[i]);
+ }
+ return arr;
+}
+
+fn f() {
+ s = conv_arr4_S(u);
+}
+)";
+ // Run the Std140 transform on the input WGSL.
+ auto got = Run<Std140>(src);
+ // The stringified result must match the expected WGSL exactly.
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16, ArrayStructMatUniform_CopyStruct_UniformToWorkgroup_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+struct S {
+ v : vec4<i32>,
+ @size(64)
+ m : mat2x3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> u : array<S, 4>;
+
+var<workgroup> w : array<S, 4>;
+
+fn f() {
+ w[0] = u[1];
+}
+)";
+ // Expected: w[0] is assigned conv_S(u[1u]); conv_S copies v directly and rebuilds m from m_0 and m_1.
+ auto* expect = R"(
+enable f16;
+
+struct S {
+ v : vec4<i32>,
+ @size(64)
+ m : mat2x3<f16>,
+}
+
+struct S_std140 {
+ v : vec4<i32>,
+ m_0 : vec3<f16>,
+ @size(56)
+ m_1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> u : array<S_std140, 4u>;
+
+var<workgroup> w : array<S, 4>;
+
+fn conv_S(val : S_std140) -> S {
+ return S(val.v, mat2x3<f16>(val.m_0, val.m_1));
+}
+
+fn f() {
+ w[0] = conv_S(u[1u]);
+}
+)";
+ // Run the Std140 transform on the input WGSL.
+ auto got = Run<Std140>(src);
+ // The stringified result must match the expected WGSL exactly.
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16, ArrayStructMatUniform_CopyMatrix_UniformToPrivate_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+struct S {
+ v : vec4<i32>,
+ @size(64)
+ m : mat2x3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> u : array<S, 3>;
+
+var<private> p : array<S, 4>;
+
+fn f() {
+ p[2].m = u[1].m;
+}
+)";
+ // Expected: p[2].m is assigned load_u_1_m(); the private variable p keeps its original type.
+ auto* expect = R"(
+enable f16;
+
+struct S {
+ v : vec4<i32>,
+ @size(64)
+ m : mat2x3<f16>,
+}
+
+struct S_std140 {
+ v : vec4<i32>,
+ m_0 : vec3<f16>,
+ @size(56)
+ m_1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> u : array<S_std140, 3u>;
+
+var<private> p : array<S, 4>;
+
+fn load_u_1_m() -> mat2x3<f16> {
+ let s = &(u[1u]);
+ return mat2x3<f16>((*(s)).m_0, (*(s)).m_1);
+}
+
+fn f() {
+ p[2].m = load_u_1_m();
+}
+)";
+ // Run the Std140 transform on the input WGSL.
+ auto got = Run<Std140>(src);
+ // The stringified result must match the expected WGSL exactly.
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16, ArrayStructMatUniform_CopyColumn_UniformToStorage_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : mat2x3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> u : array<S, 3>;
+
+@group(0) @binding(1) var<storage, read_write> s : array<S, 4>;
+
+fn f() {
+ s[3].m[1] = u[2].m[0];
+}
+)";
+ // Expected: the right-hand side u[2].m[0] becomes u[2u].m_0; the storage-side s[3].m[1] is untouched.
+ auto* expect = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : mat2x3<f16>,
+}
+
+struct S_std140 {
+ m_0 : vec3<f16>,
+ @size(56)
+ m_1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> u : array<S_std140, 3u>;
+
+@group(0) @binding(1) var<storage, read_write> s : array<S, 4>;
+
+fn f() {
+ s[3].m[1] = u[2u].m_0;
+}
+)";
+ // Run the Std140 transform on the input WGSL.
+ auto got = Run<Std140>(src);
+ // The stringified result must match the expected WGSL exactly.
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16, ArrayStructMatUniform_CopyColumnSwizzle_UniformToWorkgroup_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : mat2x3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> u : array<S, 4>;
+
+var<workgroup> w : array<S, 4>;
+
+fn f() {
+ w[3].m[1] = u[2].m[0].yzx.yzx;
+}
+)";
+ // Expected: the column access becomes u[2u].m_0 and the trailing .yzx.yzx swizzles are preserved.
+ auto* expect = R"(
+enable f16;
+
+struct S {
+ @size(64)
+ m : mat2x3<f16>,
+}
+
+struct S_std140 {
+ m_0 : vec3<f16>,
+ @size(56)
+ m_1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> u : array<S_std140, 4u>;
+
+var<workgroup> w : array<S, 4>;
+
+fn f() {
+ w[3].m[1] = u[2u].m_0.yzx.yzx;
+}
+)";
+ // Run the Std140 transform on the input WGSL.
+ auto got = Run<Std140>(src);
+ // The stringified result must match the expected WGSL exactly.
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16, ArrayStructMatUniform_CopyScalar_UniformToPrivate_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+struct S {
+ v : vec4<i32>,
+ @size(64)
+ m : mat2x3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> u : array<S, 3>;
+
+var<private> p : array<S, 4>;
+
+fn f() {
+ p[3].m[1].x = u[2].m[0].y;
+}
+)";
+ // Expected: u[2].m[0].y becomes u[2u].m_0[1u]; the private-side p[3].m[1].x is untouched.
+ auto* expect = R"(
+enable f16;
+
+struct S {
+ v : vec4<i32>,
+ @size(64)
+ m : mat2x3<f16>,
+}
+
+struct S_std140 {
+ v : vec4<i32>,
+ m_0 : vec3<f16>,
+ @size(56)
+ m_1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> u : array<S_std140, 3u>;
+
+var<private> p : array<S, 4>;
+
+fn f() {
+ p[3].m[1].x = u[2u].m_0[1u];
+}
+)";
+ // Run the Std140 transform on the input WGSL.
+ auto got = Run<Std140>(src);
+ // The stringified result must match the expected WGSL exactly.
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16, ArrayMatUniform_LoadArray_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : array<mat2x3<f16>, 3>;
+
+fn f() {
+ let l = a;
+}
+)";
+ // Expected: the element type becomes the generated struct mat2x3_f16 (col0, col1) and the load goes through conv_arr3_mat2x3_f16().
+ auto* expect = R"(
+enable f16;
+
+struct mat2x3_f16 {
+ col0 : vec3<f16>,
+ col1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<mat2x3_f16, 3u>;
+
+fn conv_mat2x3_f16(val : mat2x3_f16) -> mat2x3<f16> {
+ return mat2x3<f16>(val.col0, val.col1);
+}
+
+fn conv_arr3_mat2x3_f16(val : array<mat2x3_f16, 3u>) -> array<mat2x3<f16>, 3u> {
+ var arr : array<mat2x3<f16>, 3u>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ arr[i] = conv_mat2x3_f16(val[i]);
+ }
+ return arr;
+}
+
+fn f() {
+ let l = conv_arr3_mat2x3_f16(a);
+}
+)";
+ // Run the Std140 transform on the input WGSL.
+ auto got = Run<Std140>(src);
+ // The stringified result must match the expected WGSL exactly.
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16, ArrayMatUniform_LoadMatrix_ConstArrayIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : array<mat2x3<f16>, 3>;
+
+fn f() {
+ let l = a[2];
+}
+)";
+ // Expected: a[2] becomes conv_mat2x3_f16(a[2u]).
+ auto* expect = R"(
+enable f16;
+
+struct mat2x3_f16 {
+ col0 : vec3<f16>,
+ col1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<mat2x3_f16, 3u>;
+
+fn conv_mat2x3_f16(val : mat2x3_f16) -> mat2x3<f16> {
+ return mat2x3<f16>(val.col0, val.col1);
+}
+
+fn f() {
+ let l = conv_mat2x3_f16(a[2u]);
+}
+)";
+ // Run the Std140 transform on the input WGSL.
+ auto got = Run<Std140>(src);
+ // The stringified result must match the expected WGSL exactly.
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16, ArrayMatUniform_LoadMatrix_VariableArrayIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : array<mat2x3<f16>, 3>;
+
+fn f() {
+ let I = 1;
+ let l = a[I];
+}
+)";
+ // Expected: a[I] becomes conv_mat2x3_f16(a[I]); the variable index is passed through unchanged.
+ auto* expect = R"(
+enable f16;
+
+struct mat2x3_f16 {
+ col0 : vec3<f16>,
+ col1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<mat2x3_f16, 3u>;
+
+fn conv_mat2x3_f16(val : mat2x3_f16) -> mat2x3<f16> {
+ return mat2x3<f16>(val.col0, val.col1);
+}
+
+fn f() {
+ let I = 1;
+ let l = conv_mat2x3_f16(a[I]);
+}
+)";
+ // Run the Std140 transform on the input WGSL.
+ auto got = Run<Std140>(src);
+ // The stringified result must match the expected WGSL exactly.
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16, ArrayMatUniform_LoadColumn_ConstArrayIndex_ConstColumnIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : array<mat2x3<f16>, 3>;
+
+fn f() {
+ let l = a[2][1];
+}
+)";
+ // Expected: a[2][1] becomes a[2u].col1, with no helper function.
+ auto* expect = R"(
+enable f16;
+
+struct mat2x3_f16 {
+ col0 : vec3<f16>,
+ col1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<mat2x3_f16, 3u>;
+
+fn f() {
+ let l = a[2u].col1;
+}
+)";
+ // Run the Std140 transform on the input WGSL.
+ auto got = Run<Std140>(src);
+ // The stringified result must match the expected WGSL exactly.
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16, ArrayMatUniform_LoadColumn_VariableArrayIndex_ConstColumnIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : array<mat2x3<f16>, 3>;
+
+fn f() {
+ let I = 1;
+ let l = a[I][1];
+}
+)";
+ // Expected: a[I][1] becomes a[I].col1, with no helper function.
+ auto* expect = R"(
+enable f16;
+
+struct mat2x3_f16 {
+ col0 : vec3<f16>,
+ col1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<mat2x3_f16, 3u>;
+
+fn f() {
+ let I = 1;
+ let l = a[I].col1;
+}
+)";
+ // Run the Std140 transform on the input WGSL.
+ auto got = Run<Std140>(src);
+ // The stringified result must match the expected WGSL exactly.
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16, ArrayMatUniform_LoadColumn_ConstArrayIndex_VariableColumnIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : array<mat2x3<f16>, 3>;
+
+fn f() {
+ let I = 1;
+ let l = a[2][I];
+}
+)";
+ // Expected: load_a_2_p0() switches on the column index and reads col0/col1 of a[2u].
+ auto* expect = R"(
+enable f16;
+
+struct mat2x3_f16 {
+ col0 : vec3<f16>,
+ col1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<mat2x3_f16, 3u>;
+
+fn load_a_2_p0(p0 : u32) -> vec3<f16> {
+ switch(p0) {
+ case 0u: {
+ return a[2u].col0;
+ }
+ case 1u: {
+ return a[2u].col1;
+ }
+ default: {
+ return vec3<f16>();
+ }
+ }
+}
+
+fn f() {
+ let I = 1;
+ let l = load_a_2_p0(u32(I));
+}
+)";
+ // Run the Std140 transform on the input WGSL.
+ auto got = Run<Std140>(src);
+ // The stringified result must match the expected WGSL exactly.
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16,
+ ArrayMatUniform_LoadColumn_VariableArrayIndex_VariableColumnIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : array<mat2x3<f16>, 3>;
+
+fn f() {
+ let I = 1;
+ let l = a[I][I];
+}
+)";
+ // Expected: load_a_p0_p1() takes the array index as p0 and switches on the column index p1.
+ auto* expect = R"(
+enable f16;
+
+struct mat2x3_f16 {
+ col0 : vec3<f16>,
+ col1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<mat2x3_f16, 3u>;
+
+fn load_a_p0_p1(p0 : u32, p1 : u32) -> vec3<f16> {
+ switch(p1) {
+ case 0u: {
+ return a[p0].col0;
+ }
+ case 1u: {
+ return a[p0].col1;
+ }
+ default: {
+ return vec3<f16>();
+ }
+ }
+}
+
+fn f() {
+ let I = 1;
+ let l = load_a_p0_p1(u32(I), u32(I));
+}
+)";
+ // Run the Std140 transform on the input WGSL.
+ auto got = Run<Std140>(src);
+ // The stringified result must match the expected WGSL exactly.
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16, StructArrayMatUniform_LoadStruct_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+struct S {
+ a : array<mat2x3<f16>, 3>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let l = s;
+}
+)";
+ // Expected: S_std140 holds array<mat2x3_f16, 3u> and conv_S() converts it via conv_arr3_mat2x3_f16().
+ auto* expect = R"(
+enable f16;
+
+struct mat2x3_f16 {
+ col0 : vec3<f16>,
+ col1 : vec3<f16>,
+}
+
+struct S {
+ a : array<mat2x3<f16>, 3>,
+}
+
+struct S_std140 {
+ a : array<mat2x3_f16, 3u>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn conv_mat2x3_f16(val : mat2x3_f16) -> mat2x3<f16> {
+ return mat2x3<f16>(val.col0, val.col1);
+}
+
+fn conv_arr3_mat2x3_f16(val : array<mat2x3_f16, 3u>) -> array<mat2x3<f16>, 3u> {
+ var arr : array<mat2x3<f16>, 3u>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ arr[i] = conv_mat2x3_f16(val[i]);
+ }
+ return arr;
+}
+
+fn conv_S(val : S_std140) -> S {
+ return S(conv_arr3_mat2x3_f16(val.a));
+}
+
+fn f() {
+ let l = conv_S(s);
+}
+)";
+ // Run the Std140 transform on the input WGSL.
+ auto got = Run<Std140>(src);
+ // The stringified result must match the expected WGSL exactly.
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16, StructArrayMatUniform_LoadArray_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+struct S {
+ a : array<mat2x3<f16>, 3>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let l = s.a;
+}
+)";
+ // Expected: s.a is loaded through conv_arr3_mat2x3_f16(s.a); no conv_S() is emitted.
+ auto* expect = R"(
+enable f16;
+
+struct mat2x3_f16 {
+ col0 : vec3<f16>,
+ col1 : vec3<f16>,
+}
+
+struct S {
+ a : array<mat2x3<f16>, 3>,
+}
+
+struct S_std140 {
+ a : array<mat2x3_f16, 3u>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn conv_mat2x3_f16(val : mat2x3_f16) -> mat2x3<f16> {
+ return mat2x3<f16>(val.col0, val.col1);
+}
+
+fn conv_arr3_mat2x3_f16(val : array<mat2x3_f16, 3u>) -> array<mat2x3<f16>, 3u> {
+ var arr : array<mat2x3<f16>, 3u>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ arr[i] = conv_mat2x3_f16(val[i]);
+ }
+ return arr;
+}
+
+fn f() {
+ let l = conv_arr3_mat2x3_f16(s.a);
+}
+)";
+ // Run the Std140 transform on the input WGSL.
+ auto got = Run<Std140>(src);
+ // The stringified result must match the expected WGSL exactly.
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16, StructArrayMatUniform_LoadMatrix_ConstArrayIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+struct S {
+ a : array<mat2x3<f16>, 3>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let l = s.a[2];
+}
+)";
+ // Expected: s.a[2] becomes conv_mat2x3_f16(s.a[2u]).
+ auto* expect = R"(
+enable f16;
+
+struct mat2x3_f16 {
+ col0 : vec3<f16>,
+ col1 : vec3<f16>,
+}
+
+struct S {
+ a : array<mat2x3<f16>, 3>,
+}
+
+struct S_std140 {
+ a : array<mat2x3_f16, 3u>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn conv_mat2x3_f16(val : mat2x3_f16) -> mat2x3<f16> {
+ return mat2x3<f16>(val.col0, val.col1);
+}
+
+fn f() {
+ let l = conv_mat2x3_f16(s.a[2u]);
+}
+)";
+ // Run the Std140 transform on the input WGSL.
+ auto got = Run<Std140>(src);
+ // The stringified result must match the expected WGSL exactly.
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16, StructArrayMatUniform_LoadMatrix_VariableArrayIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+struct S {
+ a : array<mat2x3<f16>, 3>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let I = 1;
+ let l = s.a[I];
+}
+)";
+ // Expected: s.a[I] becomes conv_mat2x3_f16(s.a[I]); the variable index is passed through unchanged.
+ auto* expect = R"(
+enable f16;
+
+struct mat2x3_f16 {
+ col0 : vec3<f16>,
+ col1 : vec3<f16>,
+}
+
+struct S {
+ a : array<mat2x3<f16>, 3>,
+}
+
+struct S_std140 {
+ a : array<mat2x3_f16, 3u>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn conv_mat2x3_f16(val : mat2x3_f16) -> mat2x3<f16> {
+ return mat2x3<f16>(val.col0, val.col1);
+}
+
+fn f() {
+ let I = 1;
+ let l = conv_mat2x3_f16(s.a[I]);
+}
+)";
+ // Run the Std140 transform on the input WGSL.
+ auto got = Run<Std140>(src);
+ // The stringified result must match the expected WGSL exactly.
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16,
+ StructArrayMatUniform_LoadColumn_ConstArrayIndex_ConstColumnIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+struct S {
+ a : array<mat2x3<f16>, 3>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let l = s.a[2][1];
+}
+)";
+ // Expected: s.a[2][1] becomes s.a[2u].col1, with no helper function.
+ auto* expect = R"(
+enable f16;
+
+struct mat2x3_f16 {
+ col0 : vec3<f16>,
+ col1 : vec3<f16>,
+}
+
+struct S {
+ a : array<mat2x3<f16>, 3>,
+}
+
+struct S_std140 {
+ a : array<mat2x3_f16, 3u>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn f() {
+ let l = s.a[2u].col1;
+}
+)";
+ // Run the Std140 transform on the input WGSL.
+ auto got = Run<Std140>(src);
+ // The stringified result must match the expected WGSL exactly.
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F16,
+ StructArrayMatUniform_LoadColumn_VariableArrayIndex_ConstColumnIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+struct S {
+ a : array<mat2x3<f16>, 3>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let I = 1;
+ let l = s.a[I][1];
+}
+)";
+
+ auto* expect = R"(
+enable f16;
+
+struct mat2x3_f16 {
+ col0 : vec3<f16>,
+ col1 : vec3<f16>,
+}
+
+struct S {
+ a : array<mat2x3<f16>, 3>,
+}
+
+struct S_std140 {
+ a : array<mat2x3_f16, 3u>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn f() {
+ let I = 1;
+ let l = s.a[I].col1;
+}
+)";
+
+ auto got = Run<Std140>(src);  // Input: `let`-bound array index, constant column index.
+
+ EXPECT_EQ(expect, str(got));  // Becomes `s.a[I].col1`; no helper is emitted.
+}
+
+TEST_F(Std140Test_F16,
+ StructArrayMatUniform_LoadColumn_ConstArrayIndex_VariableColumnIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+struct S {
+ a : array<mat2x3<f16>, 3>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let I = 1;
+ let l = s.a[2][I];
+}
+)";
+
+ auto* expect = R"(
+enable f16;
+
+struct mat2x3_f16 {
+ col0 : vec3<f16>,
+ col1 : vec3<f16>,
+}
+
+struct S {
+ a : array<mat2x3<f16>, 3>,
+}
+
+struct S_std140 {
+ a : array<mat2x3_f16, 3u>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn load_s_a_2_p0(p0 : u32) -> vec3<f16> {
+ switch(p0) {
+ case 0u: {
+ return s.a[2u].col0;
+ }
+ case 1u: {
+ return s.a[2u].col1;
+ }
+ default: {
+ return vec3<f16>();
+ }
+ }
+}
+
+fn f() {
+ let I = 1;
+ let l = load_s_a_2_p0(u32(I));
+}
+)";
+
+ auto got = Run<Std140>(src);  // Input: constant array index, `let`-bound column index.
+
+ EXPECT_EQ(expect, str(got));  // Column pick becomes a switch helper taking the index as u32.
+}
+
+TEST_F(Std140Test_F16,
+ StructArrayMatUniform_LoadColumn_VariableArrayIndex_VariableColumnIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+struct S {
+ a : array<mat2x3<f16>, 3>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let I = 1;
+ let l = s.a[I][I];
+}
+)";
+
+ auto* expect = R"(
+enable f16;
+
+struct mat2x3_f16 {
+ col0 : vec3<f16>,
+ col1 : vec3<f16>,
+}
+
+struct S {
+ a : array<mat2x3<f16>, 3>,
+}
+
+struct S_std140 {
+ a : array<mat2x3_f16, 3u>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn load_s_a_p0_p1(p0 : u32, p1 : u32) -> vec3<f16> {
+ switch(p1) {
+ case 0u: {
+ return s.a[p0].col0;
+ }
+ case 1u: {
+ return s.a[p0].col1;
+ }
+ default: {
+ return vec3<f16>();
+ }
+ }
+}
+
+fn f() {
+ let I = 1;
+ let l = load_s_a_p0_p1(u32(I), u32(I));
+}
+)";
+
+ auto got = Run<Std140>(src);  // Input: array and column index are both `let`-bound.
+
+ EXPECT_EQ(expect, str(got));  // Helper takes both indices and switches on the column (p1).
+}
+
+TEST_F(Std140Test_F16, ArrayArrayMatUniform_LoadArrays_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : array<array<mat2x3<f16>, 3>, 4>;
+
+fn f() {
+ let l = a;
+}
+)";
+
+ auto* expect = R"(
+enable f16;
+
+struct mat2x3_f16 {
+ col0 : vec3<f16>,
+ col1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<array<mat2x3_f16, 3u>, 4u>;
+
+fn conv_mat2x3_f16(val : mat2x3_f16) -> mat2x3<f16> {
+ return mat2x3<f16>(val.col0, val.col1);
+}
+
+fn conv_arr3_mat2x3_f16(val : array<mat2x3_f16, 3u>) -> array<mat2x3<f16>, 3u> {
+ var arr : array<mat2x3<f16>, 3u>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ arr[i] = conv_mat2x3_f16(val[i]);
+ }
+ return arr;
+}
+
+fn conv_arr4_arr3_mat2x3_f16(val : array<array<mat2x3_f16, 3u>, 4u>) -> array<array<mat2x3<f16>, 3u>, 4u> {
+ var arr : array<array<mat2x3<f16>, 3u>, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ arr[i] = conv_arr3_mat2x3_f16(val[i]);
+ }
+ return arr;
+}
+
+fn f() {
+ let l = conv_arr4_arr3_mat2x3_f16(a);
+}
+)";
+
+ auto got = Run<Std140>(src);  // Input loads the whole nested array of matrices.
+
+ EXPECT_EQ(expect, str(got));  // Expect chained conv helpers: arr4 -> arr3 -> matrix.
+}
+
+TEST_F(Std140Test_F16, ArrayArrayMatUniform_LoadArray_ConstOuterArrayIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : array<array<mat2x3<f16>, 3>, 4>;
+
+fn f() {
+ let l = a[3];
+}
+)";
+
+ auto* expect = R"(
+enable f16;
+
+struct mat2x3_f16 {
+ col0 : vec3<f16>,
+ col1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<array<mat2x3_f16, 3u>, 4u>;
+
+fn conv_mat2x3_f16(val : mat2x3_f16) -> mat2x3<f16> {
+ return mat2x3<f16>(val.col0, val.col1);
+}
+
+fn conv_arr3_mat2x3_f16(val : array<mat2x3_f16, 3u>) -> array<mat2x3<f16>, 3u> {
+ var arr : array<mat2x3<f16>, 3u>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ arr[i] = conv_mat2x3_f16(val[i]);
+ }
+ return arr;
+}
+
+fn f() {
+ let l = conv_arr3_mat2x3_f16(a[3u]);
+}
+)";
+
+ auto got = Run<Std140>(src);  // Input loads one inner array via a constant outer index.
+
+ EXPECT_EQ(expect, str(got));  // Only the arr3 and matrix conv helpers are emitted.
+}
+
+TEST_F(Std140Test_F16, ArrayArrayMatUniform_LoadArray_VariableOuterArrayIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : array<array<mat2x3<f16>, 3>, 4>;
+
+fn f() {
+ let I = 1;
+ let l = a[I];
+}
+)";
+
+ auto* expect = R"(
+enable f16;
+
+struct mat2x3_f16 {
+ col0 : vec3<f16>,
+ col1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<array<mat2x3_f16, 3u>, 4u>;
+
+fn conv_mat2x3_f16(val : mat2x3_f16) -> mat2x3<f16> {
+ return mat2x3<f16>(val.col0, val.col1);
+}
+
+fn conv_arr3_mat2x3_f16(val : array<mat2x3_f16, 3u>) -> array<mat2x3<f16>, 3u> {
+ var arr : array<mat2x3<f16>, 3u>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ arr[i] = conv_mat2x3_f16(val[i]);
+ }
+ return arr;
+}
+
+fn f() {
+ let I = 1;
+ let l = conv_arr3_mat2x3_f16(a[I]);
+}
+)";
+
+ auto got = Run<Std140>(src);  // Input loads one inner array via a `let`-bound outer index.
+
+ EXPECT_EQ(expect, str(got));  // `a[I]` is passed unchanged to conv_arr3_mat2x3_f16().
+}
+
+TEST_F(Std140Test_F16,
+ ArrayArrayMatUniform_LoadMatrix_ConstOuterArrayIndex_ConstInnerArrayIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : array<array<mat2x3<f16>, 3>, 4>;
+
+fn f() {
+ let l = a[3][2];
+}
+)";
+
+ auto* expect = R"(
+enable f16;
+
+struct mat2x3_f16 {
+ col0 : vec3<f16>,
+ col1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<array<mat2x3_f16, 3u>, 4u>;
+
+fn conv_mat2x3_f16(val : mat2x3_f16) -> mat2x3<f16> {
+ return mat2x3<f16>(val.col0, val.col1);
+}
+
+fn f() {
+ let l = conv_mat2x3_f16(a[3u][2u]);
+}
+)";
+
+ auto got = Run<Std140>(src);  // Input loads one matrix; both array indices are constant.
+
+ EXPECT_EQ(expect, str(got));  // Indices become u32 literals inside conv_mat2x3_f16().
+}
+
+TEST_F(Std140Test_F16,
+ ArrayArrayMatUniform_LoadMatrix_ConstOuterArrayIndex_VariableInnerArrayIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : array<array<mat2x3<f16>, 3>, 4>;
+
+fn f() {
+ let I = 1;
+ let l = a[3][I];
+}
+)";
+
+ auto* expect = R"(
+enable f16;
+
+struct mat2x3_f16 {
+ col0 : vec3<f16>,
+ col1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<array<mat2x3_f16, 3u>, 4u>;
+
+fn conv_mat2x3_f16(val : mat2x3_f16) -> mat2x3<f16> {
+ return mat2x3<f16>(val.col0, val.col1);
+}
+
+fn f() {
+ let I = 1;
+ let l = conv_mat2x3_f16(a[3u][I]);
+}
+)";
+
+ auto got = Run<Std140>(src);  // Input: constant outer index, `let`-bound inner index.
+
+ EXPECT_EQ(expect, str(got));  // Expect `conv_mat2x3_f16(a[3u][I])`.
+}
+
+TEST_F(Std140Test_F16,
+ ArrayArrayMatUniform_LoadMatrix_VariableOuterArrayIndex_ConstInnerArrayIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : array<array<mat2x3<f16>, 3>, 4>;
+
+fn f() {
+ let I = 1;
+ let l = a[I][2];
+}
+)";
+
+ auto* expect = R"(
+enable f16;
+
+struct mat2x3_f16 {
+ col0 : vec3<f16>,
+ col1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<array<mat2x3_f16, 3u>, 4u>;
+
+fn conv_mat2x3_f16(val : mat2x3_f16) -> mat2x3<f16> {
+ return mat2x3<f16>(val.col0, val.col1);
+}
+
+fn f() {
+ let I = 1;
+ let l = conv_mat2x3_f16(a[I][2u]);
+}
+)";
+
+ auto got = Run<Std140>(src);  // Input: `let`-bound outer index, constant inner index.
+
+ EXPECT_EQ(expect, str(got));  // Expect `conv_mat2x3_f16(a[I][2u])`.
+}
+
+TEST_F(Std140Test_F16,
+ ArrayArrayMatUniform_LoadMatrix_VariableOuterArrayIndex_VariableInnerArrayIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : array<array<mat2x3<f16>, 3>, 4>;
+
+fn f() {
+ let I = 1;
+ let l = a[I][I];
+}
+)";
+
+ auto* expect = R"(
+enable f16;
+
+struct mat2x3_f16 {
+ col0 : vec3<f16>,
+ col1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<array<mat2x3_f16, 3u>, 4u>;
+
+fn conv_mat2x3_f16(val : mat2x3_f16) -> mat2x3<f16> {
+ return mat2x3<f16>(val.col0, val.col1);
+}
+
+fn f() {
+ let I = 1;
+ let l = conv_mat2x3_f16(a[I][I]);
+}
+)";
+
+ auto got = Run<Std140>(src);  // Input: both array indices are `let`-bound.
+
+ EXPECT_EQ(expect, str(got));  // Expect `conv_mat2x3_f16(a[I][I])`; no load helper needed.
+}
+
+TEST_F(
+ Std140Test_F16,
+ ArrayArrayMatUniform_LoadColumn_ConstOuterArrayIndex_ConstInnerArrayIndex_ConstColumnIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : array<array<mat2x3<f16>, 3>, 4>;
+
+fn f() {
+ let l = a[3][2][1];
+}
+)";
+
+ auto* expect = R"(
+enable f16;
+
+struct mat2x3_f16 {
+ col0 : vec3<f16>,
+ col1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<array<mat2x3_f16, 3u>, 4u>;
+
+fn f() {
+ let l = a[3u][2u].col1;
+}
+)";
+
+ auto got = Run<Std140>(src);  // Input loads a column; all three indices are constant.
+
+ EXPECT_EQ(expect, str(got));  // Expect the direct access `a[3u][2u].col1`.
+}
+
+TEST_F(
+ Std140Test_F16,
+ ArrayArrayMatUniform_LoadColumn_ConstOuterArrayIndex_ConstInnerArrayIndex_VariableColumnIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : array<array<mat2x3<f16>, 3>, 4>;
+
+fn f() {
+ let I = 1;
+ let l = a[3][2][I];
+}
+)";
+
+ auto* expect = R"(
+enable f16;
+
+struct mat2x3_f16 {
+ col0 : vec3<f16>,
+ col1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<array<mat2x3_f16, 3u>, 4u>;
+
+fn load_a_3_2_p0(p0 : u32) -> vec3<f16> {
+ switch(p0) {
+ case 0u: {
+ return a[3u][2u].col0;
+ }
+ case 1u: {
+ return a[3u][2u].col1;
+ }
+ default: {
+ return vec3<f16>();
+ }
+ }
+}
+
+fn f() {
+ let I = 1;
+ let l = load_a_3_2_p0(u32(I));
+}
+)";
+
+ auto got = Run<Std140>(src);  // Input: constant array indices, `let`-bound column index.
+
+ EXPECT_EQ(expect, str(got));  // Constant indices are baked into the helper; column is p0.
+}
+
+TEST_F(
+ Std140Test_F16,
+ ArrayArrayMatUniform_LoadColumn_ConstOuterArrayIndex_VariableInnerArrayIndex_ConstColumnIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : array<array<mat2x3<f16>, 3>, 4>;
+
+fn f() {
+ let I = 1;
+ let l = a[3][I][1];
+}
+)";
+
+ auto* expect = R"(
+enable f16;
+
+struct mat2x3_f16 {
+ col0 : vec3<f16>,
+ col1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<array<mat2x3_f16, 3u>, 4u>;
+
+fn f() {
+ let I = 1;
+ let l = a[3u][I].col1;
+}
+)";
+
+ auto got = Run<Std140>(src);  // Input: only the inner array index is `let`-bound.
+
+ EXPECT_EQ(expect, str(got));  // Constant column index yields `a[3u][I].col1`; no helper.
+}
+
+TEST_F(
+ Std140Test_F16,
+ ArrayArrayMatUniform_LoadColumn_ConstOuterArrayIndex_VariableInnerArrayIndex_VariableColumnIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : array<array<mat2x3<f16>, 3>, 4>;
+
+fn f() {
+ let I = 1;
+ let J = 2;
+ let l = a[3][I][J];
+}
+)";
+
+ auto* expect = R"(
+enable f16;
+
+struct mat2x3_f16 {
+ col0 : vec3<f16>,
+ col1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<array<mat2x3_f16, 3u>, 4u>;
+
+fn load_a_3_p0_p1(p0 : u32, p1 : u32) -> vec3<f16> {
+ switch(p1) {
+ case 0u: {
+ return a[3u][p0].col0;
+ }
+ case 1u: {
+ return a[3u][p0].col1;
+ }
+ default: {
+ return vec3<f16>();
+ }
+ }
+}
+
+fn f() {
+ let I = 1;
+ let J = 2;
+ let l = load_a_3_p0_p1(u32(I), u32(J));
+}
+)";
+
+ auto got = Run<Std140>(src);  // Input: inner array index and column index are `let`-bound.
+
+ EXPECT_EQ(expect, str(got));  // Helper indexes with p0 and switches on the column p1.
+}
+
+TEST_F(
+ Std140Test_F16,
+ ArrayArrayMatUniform_LoadColumn_VariableOuterArrayIndex_ConstInnerArrayIndex_ConstColumnIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : array<array<mat2x3<f16>, 3>, 4>;
+
+fn f() {
+ let I = 1;
+ let l = a[I][2][1];
+}
+)";
+
+ auto* expect = R"(
+enable f16;
+
+struct mat2x3_f16 {
+ col0 : vec3<f16>,
+ col1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<array<mat2x3_f16, 3u>, 4u>;
+
+fn f() {
+ let I = 1;
+ let l = a[I][2u].col1;
+}
+)";
+
+ auto got = Run<Std140>(src);  // Input: only the outer array index is `let`-bound.
+
+ EXPECT_EQ(expect, str(got));  // Constant column index yields `a[I][2u].col1`; no helper.
+}
+
+TEST_F(
+ Std140Test_F16,
+ ArrayArrayMatUniform_LoadColumn_VariableOuterArrayIndex_ConstInnerArrayIndex_VariableColumnIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : array<array<mat2x3<f16>, 3>, 4>;
+
+fn f() {
+ let I = 1;
+ let J = 2;
+ let l = a[I][2][J];
+}
+)";
+
+ auto* expect = R"(
+enable f16;
+
+struct mat2x3_f16 {
+ col0 : vec3<f16>,
+ col1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<array<mat2x3_f16, 3u>, 4u>;
+
+fn load_a_p0_2_p1(p0 : u32, p1 : u32) -> vec3<f16> {
+ switch(p1) {
+ case 0u: {
+ return a[p0][2u].col0;
+ }
+ case 1u: {
+ return a[p0][2u].col1;
+ }
+ default: {
+ return vec3<f16>();
+ }
+ }
+}
+
+fn f() {
+ let I = 1;
+ let J = 2;
+ let l = load_a_p0_2_p1(u32(I), u32(J));
+}
+)";
+
+ auto got = Run<Std140>(src);  // Input: outer array index and column index are `let`-bound.
+
+ EXPECT_EQ(expect, str(got));  // Helper keeps the constant inner index `2u`; switches on p1.
+}
+
+TEST_F(
+ Std140Test_F16,
+ ArrayArrayMatUniform_LoadColumn_VariableOuterArrayIndex_VariableInnerArrayIndex_ConstColumnIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : array<array<mat2x3<f16>, 3>, 4>;
+
+fn f() {
+ let I = 1;
+ let J = 2;
+ let l = a[I][J][1];
+}
+)";
+
+ auto* expect = R"(
+enable f16;
+
+struct mat2x3_f16 {
+ col0 : vec3<f16>,
+ col1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<array<mat2x3_f16, 3u>, 4u>;
+
+fn f() {
+ let I = 1;
+ let J = 2;
+ let l = a[I][J].col1;
+}
+)";
+
+ auto got = Run<Std140>(src);  // Input: both array indices `let`-bound, constant column index.
+
+ EXPECT_EQ(expect, str(got));  // Expect the direct access `a[I][J].col1`; no helper.
+}
+
+TEST_F(
+ Std140Test_F16,
+ ArrayArrayMatUniform_LoadColumn_VariableOuterArrayIndex_VariableInnerArrayIndex_VariableColumnIndex_Mat2x3F16) {
+ auto* src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> a : array<array<mat2x3<f16>, 3>, 4>;
+
+fn f() {
+ let I = 0;
+ let J = 1;
+ let K = 2;
+ let l = a[I][J][K];
+}
+)";
+
+ auto* expect = R"(
+enable f16;
+
+struct mat2x3_f16 {
+ col0 : vec3<f16>,
+ col1 : vec3<f16>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<array<mat2x3_f16, 3u>, 4u>;
+
+fn load_a_p0_p1_p2(p0 : u32, p1 : u32, p2 : u32) -> vec3<f16> {
+ switch(p2) {
+ case 0u: {
+ return a[p0][p1].col0;
+ }
+ case 1u: {
+ return a[p0][p1].col1;
+ }
+ default: {
+ return vec3<f16>();
+ }
+ }
+}
+
+fn f() {
+ let I = 0;
+ let J = 1;
+ let K = 2;
+ let l = load_a_p0_p1_p2(u32(I), u32(J), u32(K));
+}
+)";
+
+ auto got = Run<Std140>(src);  // Input: all three indices are `let`-bound.
+
+ EXPECT_EQ(expect, str(got));  // Helper takes all three indices and switches on column p2.
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/std140_f32_test.cc b/src/tint/lang/wgsl/ast/transform/std140_f32_test.cc
new file mode 100644
index 0000000..f46d33a
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/std140_f32_test.cc
@@ -0,0 +1,3359 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/std140.h"
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+#include "src/tint/utils/string.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using Std140Test_F32 = TransformTest;  // Fixture name used by the TEST_F cases below.
+
+TEST_F(Std140Test_F32, StructMatricesUniform) {
+ auto* src = R"(
+struct S2x2F32 {
+ m : mat2x2<f32>,
+}
+struct S3x2F32 {
+ m : mat3x2<f32>,
+}
+struct S4x2F32 {
+ m : mat4x2<f32>,
+}
+struct S2x3F32 {
+ m : mat2x3<f32>,
+}
+struct S3x3F32 {
+ m : mat3x3<f32>,
+}
+struct S4x3F32 {
+ m : mat4x3<f32>,
+}
+struct S2x4F32 {
+ m : mat2x4<f32>,
+}
+struct S3x4F32 {
+ m : mat3x4<f32>,
+}
+struct S4x4F32 {
+ m : mat4x4<f32>,
+}
+
+@group(2) @binding(2) var<uniform> s2x2f32 : S2x2F32;
+@group(3) @binding(2) var<uniform> s3x2f32 : S3x2F32;
+@group(4) @binding(2) var<uniform> s4x2f32 : S4x2F32;
+@group(2) @binding(3) var<uniform> s2x3f32 : S2x3F32;
+@group(3) @binding(3) var<uniform> s3x3f32 : S3x3F32;
+@group(4) @binding(3) var<uniform> s4x3f32 : S4x3F32;
+@group(2) @binding(4) var<uniform> s2x4f32 : S2x4F32;
+@group(3) @binding(4) var<uniform> s3x4f32 : S3x4F32;
+@group(4) @binding(4) var<uniform> s4x4f32 : S4x4F32;
+)";
+
+ auto* expect = R"(
+struct S2x2F32 {
+ m : mat2x2<f32>,
+}
+
+struct S2x2F32_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+}
+
+struct S3x2F32 {
+ m : mat3x2<f32>,
+}
+
+struct S3x2F32_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ m_2 : vec2<f32>,
+}
+
+struct S4x2F32 {
+ m : mat4x2<f32>,
+}
+
+struct S4x2F32_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ m_2 : vec2<f32>,
+ m_3 : vec2<f32>,
+}
+
+struct S2x3F32 {
+ m : mat2x3<f32>,
+}
+
+struct S3x3F32 {
+ m : mat3x3<f32>,
+}
+
+struct S4x3F32 {
+ m : mat4x3<f32>,
+}
+
+struct S2x4F32 {
+ m : mat2x4<f32>,
+}
+
+struct S3x4F32 {
+ m : mat3x4<f32>,
+}
+
+struct S4x4F32 {
+ m : mat4x4<f32>,
+}
+
+@group(2) @binding(2) var<uniform> s2x2f32 : S2x2F32_std140;
+
+@group(3) @binding(2) var<uniform> s3x2f32 : S3x2F32_std140;
+
+@group(4) @binding(2) var<uniform> s4x2f32 : S4x2F32_std140;
+
+@group(2) @binding(3) var<uniform> s2x3f32 : S2x3F32;
+
+@group(3) @binding(3) var<uniform> s3x3f32 : S3x3F32;
+
+@group(4) @binding(3) var<uniform> s4x3f32 : S4x3F32;
+
+@group(2) @binding(4) var<uniform> s2x4f32 : S2x4F32;
+
+@group(3) @binding(4) var<uniform> s3x4f32 : S3x4F32;
+
+@group(4) @binding(4) var<uniform> s4x4f32 : S4x4F32;
+)";
+
+ auto got = Run<Std140>(src);  // Input: one uniform struct per f32 matrix shape, 2x2 to 4x4.
+
+ EXPECT_EQ(expect, str(got));  // Only the matNx2 structs get `_std140` forms; others unchanged.
+}
+
+// The following tests only cover `mat2x2<f32>` for matrices used as an array element type and
+// `mat3x2<f32>` otherwise, and use constant column index 1, row index 0, inner array index 2, and
+// outer array index 3 throughout. For exhaustive tests, i.e. tests covering all matrix shapes and
+// different valid constant indices, please refer to std140_exhaustive_test.cc.
+
+TEST_F(Std140Test_F32, SingleStructMatUniform_Mat3x2F32) {
+ auto* src = R"(
+struct S {
+ m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+)";
+
+ auto* expect = R"(
+struct S {
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ m_2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+)";
+
+ auto got = Run<Std140>(src);  // Input: a uniform struct with one mat3x2<f32> member, no uses.
+
+ EXPECT_EQ(expect, str(got));  // `m` is split into vec2 members m_0..m_2 in S_std140.
+}
+
+TEST_F(Std140Test_F32, CustomAlign_Mat3x2F32) {
+ auto* src = R"(
+struct S {
+ before : i32,
+ @align(128)
+ m : mat3x2<f32>,
+ after : i32,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+)";
+
+ auto* expect = R"(
+struct S {
+ before : i32,
+ @align(128)
+ m : mat3x2<f32>,
+ after : i32,
+}
+
+struct S_std140 {
+ before : i32,
+ @align(128i)
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ m_2 : vec2<f32>,
+ after : i32,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+)";
+
+ auto got = Run<Std140>(src);  // Input: the matrix member carries `@align(128)`.
+
+ EXPECT_EQ(expect, str(got));  // The alignment is expected on the first column (m_0) only.
+}
+
+TEST_F(Std140Test_F32, CustomSizeMat_Mat3x2F32) {
+ auto* src = R"(
+struct S {
+ before : i32,
+ @size(128)
+ m : mat3x2<f32>,
+ after : i32,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+)";
+
+ auto* expect = R"(
+struct S {
+ before : i32,
+ @size(128)
+ m : mat3x2<f32>,
+ after : i32,
+}
+
+struct S_std140 {
+ before : i32,
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ @size(112)
+ m_2 : vec2<f32>,
+ after : i32,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+)";
+
+ auto got = Run<Std140>(src);  // Input: the matrix member carries `@size(128)`.
+
+ EXPECT_EQ(expect, str(got));  // The last column (m_2) is expected to carry `@size(112)`.
+}
+
+TEST_F(Std140Test_F32, CustomAlignAndSize_Mat3x2F32) {
+ auto* src = R"(
+struct S {
+ before : i32,
+ @align(128) @size(128)
+ m : mat3x2<f32>,
+ after : i32,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+)";
+
+ auto* expect = R"(
+struct S {
+ before : i32,
+ @align(128) @size(128)
+ m : mat3x2<f32>,
+ after : i32,
+}
+
+struct S_std140 {
+ before : i32,
+ @align(128i)
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ @size(112)
+ m_2 : vec2<f32>,
+ after : i32,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+)";
+
+ auto got = Run<Std140>(src);  // Input: the matrix member carries both @align and @size.
+
+ EXPECT_EQ(expect, str(got));  // Expect `@align(128i)` on m_0 and `@size(112)` on m_2.
+}
+
+TEST_F(Std140Test_F32, MatrixUsageInForLoop_Mat3x2F32) {
+ auto* src = R"(
+struct S {
+ m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ for(var i = u32(s.m[0][0]); (i < u32(s.m[i][1])); i += u32(s.m[1][i])) {
+ }
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ m_2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn load_s_m_p0_1(p0 : u32) -> f32 {
+ switch(p0) {
+ case 0u: {
+ return s.m_0[1u];
+ }
+ case 1u: {
+ return s.m_1[1u];
+ }
+ case 2u: {
+ return s.m_2[1u];
+ }
+ default: {
+ return f32();
+ }
+ }
+}
+
+fn f() {
+ for(var i = u32(s.m_0[0u]); (i < u32(load_s_m_p0_1(u32(i)))); i += u32(s.m_1[i])) {
+ }
+}
+)";
+
+ auto got = Run<Std140>(src);  // Input reads the matrix in all three for-loop clauses.
+
+ EXPECT_EQ(expect, str(got));  // Only the condition (dynamic column `i`) needs a load helper.
+}
+
+TEST_F(Std140Test_F32, MatUniform_LoadMatrix_Mat3x2F32) {
+ auto* src = R"(
+@group(0) @binding(0) var<uniform> m : mat3x2<f32>;
+
+fn f() {
+ let l = m;
+}
+)";
+
+ auto* expect = R"(
+struct mat3x2_f32 {
+ col0 : vec2<f32>,
+ col1 : vec2<f32>,
+ col2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> m : mat3x2_f32;
+
+fn conv_mat3x2_f32(val : mat3x2_f32) -> mat3x2<f32> {
+ return mat3x2<f32>(val.col0, val.col1, val.col2);
+}
+
+fn f() {
+ let l = conv_mat3x2_f32(m);
+}
+)";
+
+ auto got = Run<Std140>(src);  // Input: the uniform variable itself is a mat3x2<f32>.
+
+ EXPECT_EQ(expect, str(got));  // Variable becomes struct mat3x2_f32; load goes through conv.
+}
+
+TEST_F(Std140Test_F32, MatUniform_LoadColumn_ConstIndex_Mat3x2F32) {
+ auto* src = R"(
+@group(0) @binding(0) var<uniform> a : mat3x2<f32>;
+
+fn f() {
+ let l = a[1];
+}
+)";
+
+ auto* expect = R"(
+struct mat3x2_f32 {
+ col0 : vec2<f32>,
+ col1 : vec2<f32>,
+ col2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : mat3x2_f32;
+
+fn f() {
+ let l = a.col1;
+}
+)";
+
+ auto got = Run<Std140>(src);  // Input loads column 1 of a matrix uniform by constant index.
+
+ EXPECT_EQ(expect, str(got));  // Expect the member access `a.col1`.
+}
+
+TEST_F(Std140Test_F32, MatUniform_LoadColumn_VariableIndex_Mat3x2F32) {
+ auto* src = R"(
+@group(0) @binding(0) var<uniform> a : mat3x2<f32>;
+
+fn f() {
+ let I = 1;
+ let l = a[I];
+}
+)";
+
+ auto* expect = R"(
+struct mat3x2_f32 {
+ col0 : vec2<f32>,
+ col1 : vec2<f32>,
+ col2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : mat3x2_f32;
+
+fn load_a_p0(p0 : u32) -> vec2<f32> {
+ switch(p0) {
+ case 0u: {
+ return a.col0;
+ }
+ case 1u: {
+ return a.col1;
+ }
+ case 2u: {
+ return a.col2;
+ }
+ default: {
+ return vec2<f32>();
+ }
+ }
+}
+
+fn f() {
+ let I = 1;
+ let l = load_a_p0(u32(I));
+}
+)";
+
+ auto got = Run<Std140>(src);  // Input loads a column by a `let`-bound index.
+
+ EXPECT_EQ(expect, str(got));  // Expect a three-case switch helper with a zero-value default.
+}
+
+TEST_F(Std140Test_F32, MatUniform_LoadColumnSwizzle_ConstIndex_Mat3x2F32) {
+ auto* src = R"(
+@group(0) @binding(0) var<uniform> a : mat3x2<f32>;
+
+fn f() {
+ let l = a[1].yx;
+}
+)";
+
+ auto* expect = R"(
+struct mat3x2_f32 {
+ col0 : vec2<f32>,
+ col1 : vec2<f32>,
+ col2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : mat3x2_f32;
+
+fn f() {
+ let l = a.col1.yx;
+}
+)";
+
+ auto got = Run<Std140>(src);  // Input swizzles (`.yx`) a column chosen by constant index.
+
+ EXPECT_EQ(expect, str(got));  // Expect `a.col1.yx`; the swizzle is preserved.
+}
+
+TEST_F(Std140Test_F32, MatUniform_LoadColumnSwizzle_VariableIndex_Mat3x2F32) {
+ auto* src = R"(
+@group(0) @binding(0) var<uniform> a : mat3x2<f32>;
+
+fn f() {
+ let I = 1;
+ let l = a[I].yx;
+}
+)";
+
+ auto* expect = R"(
+struct mat3x2_f32 {
+ col0 : vec2<f32>,
+ col1 : vec2<f32>,
+ col2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : mat3x2_f32;
+
+fn load_a_p0_yx(p0 : u32) -> vec2<f32> {
+ switch(p0) {
+ case 0u: {
+ return a.col0.yx;
+ }
+ case 1u: {
+ return a.col1.yx;
+ }
+ case 2u: {
+ return a.col2.yx;
+ }
+ default: {
+ return vec2<f32>();
+ }
+ }
+}
+
+fn f() {
+ let I = 1;
+ let l = load_a_p0_yx(u32(I));
+}
+)";
+
+ auto got = Run<Std140>(src);  // Input swizzles (`.yx`) a column chosen by `let`-bound index.
+
+ EXPECT_EQ(expect, str(got));  // The swizzle is folded into each case of the switch helper.
+}
+
+TEST_F(Std140Test_F32, MatUniform_LoadScalar_ConstColumnIndex_ConstRowIndex_Mat3x2F32) {
+ auto* src = R"(
+@group(0) @binding(0) var<uniform> a : mat3x2<f32>;
+
+fn f() {
+ let l = a[1][0];
+}
+)";
+
+ auto* expect = R"(
+struct mat3x2_f32 {
+ col0 : vec2<f32>,
+ col1 : vec2<f32>,
+ col2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : mat3x2_f32;
+
+fn f() {
+ let l = a.col1[0u];
+}
+)";
+
+ auto got = Run<Std140>(src);  // Input loads one scalar; column and row index are constant.
+
+ EXPECT_EQ(expect, str(got));  // Expect `a.col1[0u]`.
+}
+
+TEST_F(Std140Test_F32, MatUniform_LoadScalar_VariableColumnIndex_ConstRowIndex_Mat3x2F32) {
+ auto* src = R"(
+@group(0) @binding(0) var<uniform> a : mat3x2<f32>;
+
+fn f() {
+ let I = 0;
+ let l = a[I][0];
+}
+)";
+
+ auto* expect = R"(
+struct mat3x2_f32 {
+ col0 : vec2<f32>,
+ col1 : vec2<f32>,
+ col2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : mat3x2_f32;
+
+fn load_a_p0_0(p0 : u32) -> f32 {
+ switch(p0) {
+ case 0u: {
+ return a.col0[0u];
+ }
+ case 1u: {
+ return a.col1[0u];
+ }
+ case 2u: {
+ return a.col2[0u];
+ }
+ default: {
+ return f32();
+ }
+ }
+}
+
+fn f() {
+ let I = 0;
+ let l = load_a_p0_0(u32(I));
+}
+)";
+
+ auto got = Run<Std140>(src);  // Input: `let`-bound column index, constant row index.
+
+ EXPECT_EQ(expect, str(got));  // The constant row `[0u]` is baked into each switch case.
+}
+
+TEST_F(Std140Test_F32, MatUniform_LoadScalar_ConstColumnIndex_VariableRowIndex_Mat3x2F32) {
+ auto* src = R"(
+@group(0) @binding(0) var<uniform> a : mat3x2<f32>;
+
+fn f() {
+ let I = 0;
+ let l = a[1][I];
+}
+)";
+
+ auto* expect = R"(
+struct mat3x2_f32 {
+ col0 : vec2<f32>,
+ col1 : vec2<f32>,
+ col2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : mat3x2_f32;
+
+fn f() {
+ let I = 0;
+ let l = a.col1[I];
+}
+)";
+
+ auto got = Run<Std140>(src);  // Input: constant column index, `let`-bound row index.
+
+ EXPECT_EQ(expect, str(got));  // Expect `a.col1[I]`; a dynamic row index needs no helper.
+}
+
+TEST_F(Std140Test_F32, MatUniform_LoadScalar_VariableColumnIndex_VariableRowIndex_Mat3x2F32) {
+ auto* src = R"(
+@group(0) @binding(0) var<uniform> a : mat3x2<f32>;
+
+fn f() {
+ let I = 0;
+ let l = a[I][I];
+}
+)";
+
+ auto* expect = R"(
+struct mat3x2_f32 {
+ col0 : vec2<f32>,
+ col1 : vec2<f32>,
+ col2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : mat3x2_f32;
+
+fn load_a_p0_p1(p0 : u32, p1 : u32) -> f32 {
+ switch(p0) {
+ case 0u: {
+ return a.col0[p1];
+ }
+ case 1u: {
+ return a.col1[p1];
+ }
+ case 2u: {
+ return a.col2[p1];
+ }
+ default: {
+ return f32();
+ }
+ }
+}
+
+fn f() {
+ let I = 0;
+ let l = load_a_p0_p1(u32(I), u32(I));
+}
+)";
+
+ auto got = Run<Std140>(src);  // Input: column and row index are both `let`-bound.
+
+ EXPECT_EQ(expect, str(got));  // Helper switches on the column (p0) and indexes rows with p1.
+}
+
+TEST_F(Std140Test_F32, StructMatUniform_NameCollision_Mat3x2F32) {
+ auto* src = R"(
+struct S {
+ m_1 : i32,
+ m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+)";
+
+ auto* expect = R"(
+struct S {
+ m_1 : i32,
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ m_1 : i32,
+ m__0 : vec2<f32>,
+ m__1 : vec2<f32>,
+ m__2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+)";
+
+ auto got = Run<Std140>(src);  // Input: existing member `m_1` clashes with a column name.
+
+ EXPECT_EQ(expect, str(got));  // Columns are expected as m__0..m__2 to avoid the clash.
+}
+
+TEST_F(Std140Test_F32, StructMatUniform_LoadStruct_Mat3x2F32) {
+ auto* src = R"(
+struct S {
+ m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let l = s;
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ m_2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn conv_S(val : S_std140) -> S {
+ return S(mat3x2<f32>(val.m_0, val.m_1, val.m_2));
+}
+
+fn f() {
+ let l = conv_S(s);
+}
+)";
+
+ auto got = Run<Std140>(src);  // Input loads the whole uniform struct.
+
+ EXPECT_EQ(expect, str(got));  // Expect conv_S() rebuilding S from the split columns.
+}
+
+TEST_F(Std140Test_F32, StructMatUniform_LoadMatrix_Mat3x2F32) {
+ auto* src = R"(
+struct S {
+ m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let l = s.m;
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ m_2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn load_s_m() -> mat3x2<f32> {
+ let s = &(s);
+ return mat3x2<f32>((*(s)).m_0, (*(s)).m_1, (*(s)).m_2);
+}
+
+fn f() {
+ let l = load_s_m();
+}
+)";
+
+ auto got = Run<Std140>(src);  // Input loads the matrix member of the uniform struct.
+
+ EXPECT_EQ(expect, str(got));  // Expect load_s_m() reassembling the matrix from m_0..m_2.
+}
+
+TEST_F(Std140Test_F32, StructMatUniform_LoadColumn_ConstIndex_Mat3x2F32) {
+ auto* src = R"(
+struct S {
+ m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let l = s.m[1];
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ m_2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn f() {
+ let l = s.m_1;
+}
+)";
+
+ auto got = Run<Std140>(src);  // Input loads column 1 of the struct's matrix member.
+
+ EXPECT_EQ(expect, str(got));  // Expect the direct member access `s.m_1`.
+}
+
+TEST_F(Std140Test_F32, StructMatUniform_LoadColumn_VariableIndex_Mat3x2F32) {
+ auto* src = R"(
+struct S {
+ m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let I = 0;
+ let l = s.m[I];
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ m_2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn load_s_m_p0(p0 : u32) -> vec2<f32> {
+ switch(p0) {
+ case 0u: {
+ return s.m_0;
+ }
+ case 1u: {
+ return s.m_1;
+ }
+ case 2u: {
+ return s.m_2;
+ }
+ default: {
+ return vec2<f32>();
+ }
+ }
+}
+
+fn f() {
+ let I = 0;
+ let l = load_s_m_p0(u32(I));
+}
+)";
+
+ auto got = Run<Std140>(src);  // Input loads a column of the member by `let`-bound index.
+
+ EXPECT_EQ(expect, str(got));  // Expect a switch helper selecting among s.m_0..s.m_2.
+}
+
+TEST_F(Std140Test_F32, StructMatUniform_LoadScalar_ConstColumnIndex_ConstRowIndex_Mat3x2F32) {
+ auto* src = R"(
+struct S {
+ m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let l = s.m[1][0];
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ m_2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn f() {
+ let l = s.m_1[0u];
+}
+)";
+
+ auto got = Run<Std140>(src);  // Input loads one scalar; column and row index are constant.
+
+ EXPECT_EQ(expect, str(got));  // Expect `s.m_1[0u]`.
+}
+
+TEST_F(Std140Test_F32, StructMatUniform_LoadScalar_VariableColumnIndex_ConstRowIndex_Mat3x2F32) {
+ auto* src = R"(
+struct S {
+ m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let I = 0;
+ let l = s.m[I][0];
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ m_2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn load_s_m_p0_0(p0 : u32) -> f32 {
+ switch(p0) {
+ case 0u: {
+ return s.m_0[0u];
+ }
+ case 1u: {
+ return s.m_1[0u];
+ }
+ case 2u: {
+ return s.m_2[0u];
+ }
+ default: {
+ return f32();
+ }
+ }
+}
+
+fn f() {
+ let I = 0;
+ let l = load_s_m_p0_0(u32(I));
+}
+)";
+
+ auto got = Run<Std140>(src);  // Input: `let`-bound column index, constant row index.
+
+ EXPECT_EQ(expect, str(got));  // The constant row `[0u]` is baked into each switch case.
+}
+
+TEST_F(Std140Test_F32, StructMatUniform_LoadScalar_ConstColumnIndex_VariableRowIndex_Mat3x2F32) {
+ auto* src = R"(
+struct S {
+ m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let I = 0;
+ let l = s.m[1][I];
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ m_2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn f() {
+ let I = 0;
+ let l = s.m_1[I];
+}
+)";
+
+ auto got = Run<Std140>(src);  // Input: constant column index, `let`-bound row index.
+
+ EXPECT_EQ(expect, str(got));  // Expect `s.m_1[I]`; a dynamic row index needs no helper.
+}
+
+TEST_F(Std140Test_F32, StructMatUniform_LoadScalar_VariableColumnIndex_VariableRowIndex_Mat3x2F32) {
+ auto* src = R"(
+struct S {
+ m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let I = 0;
+ let l = s.m[I][I];
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ m_2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn load_s_m_p0_p1(p0 : u32, p1 : u32) -> f32 {
+ switch(p0) {
+ case 0u: {
+ return s.m_0[p1];
+ }
+ case 1u: {
+ return s.m_1[p1];
+ }
+ case 2u: {
+ return s.m_2[p1];
+ }
+ default: {
+ return f32();
+ }
+ }
+}
+
+fn f() {
+ let I = 0;
+ let l = load_s_m_p0_p1(u32(I), u32(I));
+}
+)";
+
+ auto got = Run<Std140>(src);  // Input: column and row index are both `let`-bound.
+
+ EXPECT_EQ(expect, str(got));  // Helper switches on the column (p0) and indexes rows with p1.
+}
+
+TEST_F(Std140Test_F32, ArrayStructMatUniform_LoadArray_Mat3x2F32) {
+ auto* src = R"(
+struct S {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S, 3>;
+
+fn f() {
+ let l = a;
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ @size(48)
+ m_2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S_std140, 3u>;
+
+fn conv_S(val : S_std140) -> S {
+ return S(mat3x2<f32>(val.m_0, val.m_1, val.m_2));
+}
+
+fn conv_arr3_S(val : array<S_std140, 3u>) -> array<S, 3u> {
+ var arr : array<S, 3u>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ arr[i] = conv_S(val[i]);
+ }
+ return arr;
+}
+
+fn f() {
+ let l = conv_arr3_S(a);
+}
+)";
+
+ auto got = Run<Std140>(src);  // Input loads a whole uniform array of matrix-holding structs.
+
+ EXPECT_EQ(expect, str(got));  // Expect conv_arr3_S() looping conv_S(); m_2 gets `@size(48)`.
+}
+
+// Loading one struct element a[2] (constant index) from a uniform array of structs.
+TEST_F(Std140Test_F32, ArrayStructMatUniform_LoadStruct_ConstIndex_Mat3x2F32) {
+ const char* wgsl = R"(
+struct S {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S, 3>;
+
+fn f() {
+ let l = a[2];
+}
+)";
+ // Only conv_S is emitted and the constant index becomes 2u.
+ const char* want = R"(
+struct S {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ @size(48)
+ m_2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S_std140, 3u>;
+
+fn conv_S(val : S_std140) -> S {
+ return S(mat3x2<f32>(val.m_0, val.m_1, val.m_2));
+}
+
+fn f() {
+ let l = conv_S(a[2u]);
+}
+)";
+
+ auto result = Run<Std140>(wgsl);
+ EXPECT_EQ(want, str(result));
+}
+
+// Loading one struct element a[I] (variable index) from a uniform array of structs.
+TEST_F(Std140Test_F32, ArrayStructMatUniform_LoadStruct_VariableIndex_Mat3x2F32) {
+ const char* wgsl = R"(
+struct S {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S, 3>;
+
+fn f() {
+ let I = 1;
+ let l = a[I];
+}
+)";
+ // conv_S is applied to a[I]; the index expression is left as written.
+ const char* want = R"(
+struct S {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ @size(48)
+ m_2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S_std140, 3u>;
+
+fn conv_S(val : S_std140) -> S {
+ return S(mat3x2<f32>(val.m_0, val.m_1, val.m_2));
+}
+
+fn f() {
+ let I = 1;
+ let l = conv_S(a[I]);
+}
+)";
+
+ auto result = Run<Std140>(wgsl);
+ EXPECT_EQ(want, str(result));
+}
+
+// Loading the matrix member a[2].m with a constant array index.
+TEST_F(Std140Test_F32, ArrayStructMatUniform_LoadMatrix_ConstArrayIndex_Mat3x2F32) {
+ const char* wgsl = R"(
+struct S {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S, 3>;
+
+fn f() {
+ let l = a[2].m;
+}
+)";
+ // A parameterless helper load_a_2_m rebuilds the mat3x2<f32> from the split columns.
+ const char* want = R"(
+struct S {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ @size(48)
+ m_2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S_std140, 3u>;
+
+fn load_a_2_m() -> mat3x2<f32> {
+ let s = &(a[2u]);
+ return mat3x2<f32>((*(s)).m_0, (*(s)).m_1, (*(s)).m_2);
+}
+
+fn f() {
+ let l = load_a_2_m();
+}
+)";
+
+ auto result = Run<Std140>(wgsl);
+ EXPECT_EQ(want, str(result));
+}
+
+// Loading the matrix member a[I].m with a variable array index.
+TEST_F(Std140Test_F32, ArrayStructMatUniform_LoadMatrix_VariableArrayIndex_Mat3x2F32) {
+ const char* wgsl = R"(
+struct S {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S, 3>;
+
+fn f() {
+ let I = 1;
+ let l = a[I].m;
+}
+)";
+ // The helper load_a_p0_m takes the array index as a u32 parameter.
+ const char* want = R"(
+struct S {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ @size(48)
+ m_2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S_std140, 3u>;
+
+fn load_a_p0_m(p0 : u32) -> mat3x2<f32> {
+ let s = &(a[p0]);
+ return mat3x2<f32>((*(s)).m_0, (*(s)).m_1, (*(s)).m_2);
+}
+
+fn f() {
+ let I = 1;
+ let l = load_a_p0_m(u32(I));
+}
+)";
+
+ auto result = Run<Std140>(wgsl);
+ EXPECT_EQ(want, str(result));
+}
+
+// Loading column a[2].m[1] with constant array and column indices.
+TEST_F(Std140Test_F32,
+ ArrayStructMatUniform_LoadColumn_ConstArrayIndex_ConstColumnIndex_Mat3x2F32) {
+ const char* wgsl = R"(
+struct S {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S, 3>;
+
+fn f() {
+ let l = a[2].m[1];
+}
+)";
+ // The access maps straight onto a[2u].m_1 with no helper.
+ const char* want = R"(
+struct S {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ @size(48)
+ m_2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S_std140, 3u>;
+
+fn f() {
+ let l = a[2u].m_1;
+}
+)";
+
+ auto result = Run<Std140>(wgsl);
+ EXPECT_EQ(want, str(result));
+}
+
+// Loading column a[I].m[1] with a variable array index and a constant column index.
+TEST_F(Std140Test_F32,
+ ArrayStructMatUniform_LoadColumn_VariableArrayIndex_ConstColumnIndex_Mat3x2F32) {
+ const char* wgsl = R"(
+struct S {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S, 3>;
+
+fn f() {
+ let I = 1;
+ let l = a[I].m[1];
+}
+)";
+ // The access maps straight onto a[I].m_1 with no helper.
+ const char* want = R"(
+struct S {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ @size(48)
+ m_2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S_std140, 3u>;
+
+fn f() {
+ let I = 1;
+ let l = a[I].m_1;
+}
+)";
+
+ auto result = Run<Std140>(wgsl);
+ EXPECT_EQ(want, str(result));
+}
+
+// Loading column a[2].m[I] with a constant array index and a variable column index.
+TEST_F(Std140Test_F32,
+ ArrayStructMatUniform_LoadColumn_ConstArrayIndex_VariableColumnIndex_Mat3x2F32) {
+ const char* wgsl = R"(
+struct S {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S, 3>;
+
+fn f() {
+ let I = 1;
+ let l = a[2].m[I];
+}
+)";
+ // A helper switches on the column index and returns vec2<f32>() in the default case.
+ const char* want = R"(
+struct S {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ @size(48)
+ m_2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S_std140, 3u>;
+
+fn load_a_2_m_p0(p0 : u32) -> vec2<f32> {
+ switch(p0) {
+ case 0u: {
+ return a[2u].m_0;
+ }
+ case 1u: {
+ return a[2u].m_1;
+ }
+ case 2u: {
+ return a[2u].m_2;
+ }
+ default: {
+ return vec2<f32>();
+ }
+ }
+}
+
+fn f() {
+ let I = 1;
+ let l = load_a_2_m_p0(u32(I));
+}
+)";
+
+ auto result = Run<Std140>(wgsl);
+ EXPECT_EQ(want, str(result));
+}
+
+// Loading column a[I].m[I] with variable array and column indices.
+TEST_F(Std140Test_F32,
+ ArrayStructMatUniform_LoadColumn_VariableArrayIndex_VariableColumnIndex_Mat3x2F32) {
+ const char* wgsl = R"(
+struct S {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S, 3>;
+
+fn f() {
+ let I = 1;
+ let l = a[I].m[I];
+}
+)";
+ // The helper takes both indices and switches on the column index p1.
+ const char* want = R"(
+struct S {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ @size(48)
+ m_2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S_std140, 3u>;
+
+fn load_a_p0_m_p1(p0 : u32, p1 : u32) -> vec2<f32> {
+ switch(p1) {
+ case 0u: {
+ return a[p0].m_0;
+ }
+ case 1u: {
+ return a[p0].m_1;
+ }
+ case 2u: {
+ return a[p0].m_2;
+ }
+ default: {
+ return vec2<f32>();
+ }
+ }
+}
+
+fn f() {
+ let I = 1;
+ let l = load_a_p0_m_p1(u32(I), u32(I));
+}
+)";
+
+ auto result = Run<Std140>(wgsl);
+ EXPECT_EQ(want, str(result));
+}
+
+// Value, matrix, column and scalar loads through array<Outer> -> array<Inner> -> mat3x2<f32>.
+TEST_F(Std140Test_F32, ArrayStructArrayStructMatUniform_Loads_Mat3x2F32) {
+ const char* wgsl = R"(
+struct Inner {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+struct Outer {
+ a : array<Inner, 4>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<Outer, 4>;
+
+fn f() {
+ let I = 1;
+ let J = 2;
+ let K = 0;
+ let l_a : array<Outer, 4> = a;
+ let l_a_1 : Outer = a[1];
+ let l_a_I : Outer = a[I];
+ let l_a_2_a : array<Inner, 4> = a[2].a;
+ let l_a_I_a : array<Inner, 4> = a[I].a;
+ let l_a_3_a_1 : Inner = a[3].a[1];
+ let l_a_3_a_I : Inner = a[3].a[I];
+ let l_a_I_a_1 : Inner = a[I].a[1];
+ let l_a_I_a_J : Inner = a[I].a[J];
+ let l_a_0_a_2_m : mat3x2<f32> = a[0].a[2].m;
+ let l_a_0_a_I_m : mat3x2<f32> = a[0].a[I].m;
+ let l_a_I_a_2_m : mat3x2<f32> = a[I].a[2].m;
+ let l_a_I_a_J_m : mat3x2<f32> = a[I].a[J].m;
+ let l_a_1_a_3_m_0 : vec2<f32> = a[1].a[3].m[0];
+ let l_a_I_a_J_m_K : vec2<f32> = a[I].a[J].m[K];
+ let l_a_2_a_0_m_1_0 : f32 = a[2].a[0].m[1][0];
+ let l_a_I_a_J_m_K_I : f32 = a[I].a[J].m[K][I];
+}
+)";
+ // Struct/array loads use conv_* helpers; matrix and dynamic-column loads use load_* helpers.
+ const char* want = R"(
+struct Inner {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+struct Inner_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ @size(48)
+ m_2 : vec2<f32>,
+}
+
+struct Outer {
+ a : array<Inner, 4>,
+}
+
+struct Outer_std140 {
+ a : array<Inner_std140, 4u>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<Outer_std140, 4u>;
+
+fn conv_Inner(val : Inner_std140) -> Inner {
+ return Inner(mat3x2<f32>(val.m_0, val.m_1, val.m_2));
+}
+
+fn conv_arr4_Inner(val : array<Inner_std140, 4u>) -> array<Inner, 4u> {
+ var arr : array<Inner, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ arr[i] = conv_Inner(val[i]);
+ }
+ return arr;
+}
+
+fn conv_Outer(val : Outer_std140) -> Outer {
+ return Outer(conv_arr4_Inner(val.a));
+}
+
+fn conv_arr4_Outer(val : array<Outer_std140, 4u>) -> array<Outer, 4u> {
+ var arr : array<Outer, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ arr[i] = conv_Outer(val[i]);
+ }
+ return arr;
+}
+
+fn load_a_0_a_2_m() -> mat3x2<f32> {
+ let s = &(a[0u].a[2u]);
+ return mat3x2<f32>((*(s)).m_0, (*(s)).m_1, (*(s)).m_2);
+}
+
+fn load_a_0_a_p0_m(p0 : u32) -> mat3x2<f32> {
+ let s = &(a[0u].a[p0]);
+ return mat3x2<f32>((*(s)).m_0, (*(s)).m_1, (*(s)).m_2);
+}
+
+fn load_a_p0_a_2_m(p0 : u32) -> mat3x2<f32> {
+ let s = &(a[p0].a[2u]);
+ return mat3x2<f32>((*(s)).m_0, (*(s)).m_1, (*(s)).m_2);
+}
+
+fn load_a_p0_a_p1_m(p0 : u32, p1 : u32) -> mat3x2<f32> {
+ let s = &(a[p0].a[p1]);
+ return mat3x2<f32>((*(s)).m_0, (*(s)).m_1, (*(s)).m_2);
+}
+
+fn load_a_p0_a_p1_m_p2(p0 : u32, p1 : u32, p2 : u32) -> vec2<f32> {
+ switch(p2) {
+ case 0u: {
+ return a[p0].a[p1].m_0;
+ }
+ case 1u: {
+ return a[p0].a[p1].m_1;
+ }
+ case 2u: {
+ return a[p0].a[p1].m_2;
+ }
+ default: {
+ return vec2<f32>();
+ }
+ }
+}
+
+fn load_a_p0_a_p1_m_p2_p3(p0 : u32, p1 : u32, p2 : u32, p3 : u32) -> f32 {
+ switch(p2) {
+ case 0u: {
+ return a[p0].a[p1].m_0[p3];
+ }
+ case 1u: {
+ return a[p0].a[p1].m_1[p3];
+ }
+ case 2u: {
+ return a[p0].a[p1].m_2[p3];
+ }
+ default: {
+ return f32();
+ }
+ }
+}
+
+fn f() {
+ let I = 1;
+ let J = 2;
+ let K = 0;
+ let l_a : array<Outer, 4> = conv_arr4_Outer(a);
+ let l_a_1 : Outer = conv_Outer(a[1u]);
+ let l_a_I : Outer = conv_Outer(a[I]);
+ let l_a_2_a : array<Inner, 4> = conv_arr4_Inner(a[2u].a);
+ let l_a_I_a : array<Inner, 4> = conv_arr4_Inner(a[I].a);
+ let l_a_3_a_1 : Inner = conv_Inner(a[3u].a[1u]);
+ let l_a_3_a_I : Inner = conv_Inner(a[3u].a[I]);
+ let l_a_I_a_1 : Inner = conv_Inner(a[I].a[1u]);
+ let l_a_I_a_J : Inner = conv_Inner(a[I].a[J]);
+ let l_a_0_a_2_m : mat3x2<f32> = load_a_0_a_2_m();
+ let l_a_0_a_I_m : mat3x2<f32> = load_a_0_a_p0_m(u32(I));
+ let l_a_I_a_2_m : mat3x2<f32> = load_a_p0_a_2_m(u32(I));
+ let l_a_I_a_J_m : mat3x2<f32> = load_a_p0_a_p1_m(u32(I), u32(J));
+ let l_a_1_a_3_m_0 : vec2<f32> = a[1u].a[3u].m_0;
+ let l_a_I_a_J_m_K : vec2<f32> = load_a_p0_a_p1_m_p2(u32(I), u32(J), u32(K));
+ let l_a_2_a_0_m_1_0 : f32 = a[2u].a[0u].m_1[0u];
+ let l_a_I_a_J_m_K_I : f32 = load_a_p0_a_p1_m_p2_p3(u32(I), u32(J), u32(K), u32(I));
+}
+)";
+
+ auto result = Run<Std140>(wgsl);
+ EXPECT_EQ(want, str(result));
+}
+
+// Loads through chains of pointer lets into array<Outer> -> array<Inner> -> mat3x2<f32>.
+TEST_F(Std140Test_F32, ArrayStructArrayStructMatUniform_LoadsViaPtrs_Mat3x2F32) {
+ const char* wgsl = R"(
+struct Inner {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+struct Outer {
+ a : array<Inner, 4>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<Outer, 4>;
+
+fn f() {
+ let I = 1;
+ let J = 2;
+ let K = 0;
+ let p_a = &(a);
+ let p_a_3 = &((*(p_a))[3]);
+ let p_a_I = &((*(p_a))[I]);
+ let p_a_3_a = &((*(p_a_3)).a);
+ let p_a_I_a = &((*(p_a_I)).a);
+ let p_a_3_a_2 = &((*(p_a_3_a))[2]);
+ let p_a_3_a_I = &((*(p_a_3_a))[I]);
+ let p_a_I_a_2 = &((*(p_a_I_a))[2]);
+ let p_a_I_a_J = &((*(p_a_I_a))[J]);
+ let p_a_3_a_2_m = &((*(p_a_3_a_2)).m);
+ let p_a_3_a_I_m = &((*(p_a_3_a_I)).m);
+ let p_a_I_a_2_m = &((*(p_a_I_a_2)).m);
+ let p_a_I_a_J_m = &((*(p_a_I_a_J)).m);
+ let p_a_3_a_2_m_1 = &((*(p_a_3_a_2_m))[1]);
+ let p_a_I_a_J_m_K = &((*(p_a_I_a_J_m))[K]);
+ let l_a : array<Outer, 4> = *(p_a);
+ let l_a_3 : Outer = *(p_a_3);
+ let l_a_I : Outer = *(p_a_I);
+ let l_a_3_a : array<Inner, 4> = *(p_a_3_a);
+ let l_a_I_a : array<Inner, 4> = *(p_a_I_a);
+ let l_a_3_a_2 : Inner = *(p_a_3_a_2);
+ let l_a_3_a_I : Inner = *(p_a_3_a_I);
+ let l_a_I_a_2 : Inner = *(p_a_I_a_2);
+ let l_a_I_a_J : Inner = *(p_a_I_a_J);
+ let l_a_3_a_2_m : mat3x2<f32> = *(p_a_3_a_2_m);
+ let l_a_3_a_I_m : mat3x2<f32> = *(p_a_3_a_I_m);
+ let l_a_I_a_2_m : mat3x2<f32> = *(p_a_I_a_2_m);
+ let l_a_I_a_J_m : mat3x2<f32> = *(p_a_I_a_J_m);
+ let l_a_3_a_2_m_1 : vec2<f32> = *(p_a_3_a_2_m_1);
+ let l_a_I_a_J_m_K : vec2<f32> = *(p_a_I_a_J_m_K);
+ let l_a_3_a_2_m_1_0 : f32 = (*(p_a_3_a_2_m_1))[0];
+ let l_a_I_a_J_m_K_I : f32 = (*(p_a_I_a_J_m_K))[I];
+}
+)";
+ // In the expected output each p_* let holds the converted/loaded value rather than a pointer.
+ const char* want = R"(
+struct Inner {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+struct Inner_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ @size(48)
+ m_2 : vec2<f32>,
+}
+
+struct Outer {
+ a : array<Inner, 4>,
+}
+
+struct Outer_std140 {
+ a : array<Inner_std140, 4u>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<Outer_std140, 4u>;
+
+fn conv_Inner(val : Inner_std140) -> Inner {
+ return Inner(mat3x2<f32>(val.m_0, val.m_1, val.m_2));
+}
+
+fn conv_arr4_Inner(val : array<Inner_std140, 4u>) -> array<Inner, 4u> {
+ var arr : array<Inner, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ arr[i] = conv_Inner(val[i]);
+ }
+ return arr;
+}
+
+fn conv_Outer(val : Outer_std140) -> Outer {
+ return Outer(conv_arr4_Inner(val.a));
+}
+
+fn conv_arr4_Outer(val : array<Outer_std140, 4u>) -> array<Outer, 4u> {
+ var arr : array<Outer, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ arr[i] = conv_Outer(val[i]);
+ }
+ return arr;
+}
+
+fn load_a_3_a_2_m() -> mat3x2<f32> {
+ let s = &(a[3u].a[2u]);
+ return mat3x2<f32>((*(s)).m_0, (*(s)).m_1, (*(s)).m_2);
+}
+
+fn load_a_3_a_p0_m(p0 : u32) -> mat3x2<f32> {
+ let s = &(a[3u].a[p0]);
+ return mat3x2<f32>((*(s)).m_0, (*(s)).m_1, (*(s)).m_2);
+}
+
+fn load_a_p0_a_2_m(p0 : u32) -> mat3x2<f32> {
+ let s = &(a[p0].a[2u]);
+ return mat3x2<f32>((*(s)).m_0, (*(s)).m_1, (*(s)).m_2);
+}
+
+fn load_a_p0_a_p1_m(p0 : u32, p1 : u32) -> mat3x2<f32> {
+ let s = &(a[p0].a[p1]);
+ return mat3x2<f32>((*(s)).m_0, (*(s)).m_1, (*(s)).m_2);
+}
+
+fn load_a_p0_a_p1_m_p2(p0 : u32, p1 : u32, p2 : u32) -> vec2<f32> {
+ switch(p2) {
+ case 0u: {
+ return a[p0].a[p1].m_0;
+ }
+ case 1u: {
+ return a[p0].a[p1].m_1;
+ }
+ case 2u: {
+ return a[p0].a[p1].m_2;
+ }
+ default: {
+ return vec2<f32>();
+ }
+ }
+}
+
+fn load_a_p0_a_p1_m_p2_p3(p0 : u32, p1 : u32, p2 : u32, p3 : u32) -> f32 {
+ switch(p2) {
+ case 0u: {
+ return a[p0].a[p1].m_0[p3];
+ }
+ case 1u: {
+ return a[p0].a[p1].m_1[p3];
+ }
+ case 2u: {
+ return a[p0].a[p1].m_2[p3];
+ }
+ default: {
+ return f32();
+ }
+ }
+}
+
+fn f() {
+ let I = 1;
+ let J = 2;
+ let K = 0;
+ let p_a = conv_arr4_Outer(a);
+ let p_a_3 = conv_Outer(a[3u]);
+ let p_a_I = conv_Outer(a[I]);
+ let p_a_3_a = conv_arr4_Inner(a[3u].a);
+ let p_a_I_a = conv_arr4_Inner(a[I].a);
+ let p_a_3_a_2 = conv_Inner(a[3u].a[2u]);
+ let p_a_3_a_I = conv_Inner(a[3u].a[I]);
+ let p_a_I_a_2 = conv_Inner(a[I].a[2u]);
+ let p_a_I_a_J = conv_Inner(a[I].a[J]);
+ let p_a_3_a_2_m = load_a_3_a_2_m();
+ let p_a_3_a_I_m = load_a_3_a_p0_m(u32(I));
+ let p_a_I_a_2_m = load_a_p0_a_2_m(u32(I));
+ let p_a_I_a_J_m = load_a_p0_a_p1_m(u32(I), u32(J));
+ let p_a_3_a_2_m_1 = a[3u].a[2u].m_1;
+ let p_a_I_a_J_m_K = load_a_p0_a_p1_m_p2(u32(I), u32(J), u32(K));
+ let l_a : array<Outer, 4> = conv_arr4_Outer(a);
+ let l_a_3 : Outer = conv_Outer(a[3u]);
+ let l_a_I : Outer = conv_Outer(a[I]);
+ let l_a_3_a : array<Inner, 4> = conv_arr4_Inner(a[3u].a);
+ let l_a_I_a : array<Inner, 4> = conv_arr4_Inner(a[I].a);
+ let l_a_3_a_2 : Inner = conv_Inner(a[3u].a[2u]);
+ let l_a_3_a_I : Inner = conv_Inner(a[3u].a[I]);
+ let l_a_I_a_2 : Inner = conv_Inner(a[I].a[2u]);
+ let l_a_I_a_J : Inner = conv_Inner(a[I].a[J]);
+ let l_a_3_a_2_m : mat3x2<f32> = load_a_3_a_2_m();
+ let l_a_3_a_I_m : mat3x2<f32> = load_a_3_a_p0_m(u32(I));
+ let l_a_I_a_2_m : mat3x2<f32> = load_a_p0_a_2_m(u32(I));
+ let l_a_I_a_J_m : mat3x2<f32> = load_a_p0_a_p1_m(u32(I), u32(J));
+ let l_a_3_a_2_m_1 : vec2<f32> = a[3u].a[2u].m_1;
+ let l_a_I_a_J_m_K : vec2<f32> = load_a_p0_a_p1_m_p2(u32(I), u32(J), u32(K));
+ let l_a_3_a_2_m_1_0 : f32 = a[3u].a[2u].m_1[0u];
+ let l_a_I_a_J_m_K_I : f32 = load_a_p0_a_p1_m_p2_p3(u32(I), u32(J), u32(K), u32(I));
+}
+)";
+
+ auto result = Run<Std140>(wgsl);
+ EXPECT_EQ(want, str(result));
+}
+
+// Copying a whole uniform array of structs into a read_write storage array (s = u).
+TEST_F(Std140Test_F32, ArrayStructMatUniform_CopyArray_UniformToStorage_Mat3x2F32) {
+ const char* wgsl = R"(
+struct S {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> u : array<S, 4>;
+
+@group(0) @binding(1) var<storage, read_write> s : array<S, 4>;
+
+fn f() {
+ s = u;
+}
+)";
+ // Only the uniform variable is retyped; the right-hand side goes through conv_arr4_S.
+ const char* want = R"(
+struct S {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ @size(48)
+ m_2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> u : array<S_std140, 4u>;
+
+@group(0) @binding(1) var<storage, read_write> s : array<S, 4>;
+
+fn conv_S(val : S_std140) -> S {
+ return S(mat3x2<f32>(val.m_0, val.m_1, val.m_2));
+}
+
+fn conv_arr4_S(val : array<S_std140, 4u>) -> array<S, 4u> {
+ var arr : array<S, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ arr[i] = conv_S(val[i]);
+ }
+ return arr;
+}
+
+fn f() {
+ s = conv_arr4_S(u);
+}
+)";
+
+ auto result = Run<Std140>(wgsl);
+ EXPECT_EQ(want, str(result));
+}
+
+// Copying one struct element from a uniform array into a workgroup array (w[0] = u[1]).
+TEST_F(Std140Test_F32, ArrayStructMatUniform_CopyStruct_UniformToWorkgroup_Mat3x2F32) {
+ const char* wgsl = R"(
+struct S {
+ v : vec4<i32>,
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> u : array<S, 4>;
+
+var<workgroup> w : array<S, 4>;
+
+fn f() {
+ w[0] = u[1];
+}
+)";
+ // conv_S forwards the non-matrix member v unchanged and rebuilds m from its columns.
+ const char* want = R"(
+struct S {
+ v : vec4<i32>,
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ v : vec4<i32>,
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ @size(48)
+ m_2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> u : array<S_std140, 4u>;
+
+var<workgroup> w : array<S, 4>;
+
+fn conv_S(val : S_std140) -> S {
+ return S(val.v, mat3x2<f32>(val.m_0, val.m_1, val.m_2));
+}
+
+fn f() {
+ w[0] = conv_S(u[1u]);
+}
+)";
+
+ auto result = Run<Std140>(wgsl);
+ EXPECT_EQ(want, str(result));
+}
+
+// Copying a matrix member from a uniform array element into a private array element.
+TEST_F(Std140Test_F32, ArrayStructMatUniform_CopyMatrix_UniformToPrivate_Mat3x2F32) {
+ const char* wgsl = R"(
+struct S {
+ v : vec4<i32>,
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> u : array<S, 3>;
+
+var<private> p : array<S, 4>;
+
+fn f() {
+ p[2].m = u[1].m;
+}
+)";
+ // The right-hand side becomes load_u_1_m(); the private destination is untouched.
+ const char* want = R"(
+struct S {
+ v : vec4<i32>,
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ v : vec4<i32>,
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ @size(48)
+ m_2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> u : array<S_std140, 3u>;
+
+var<private> p : array<S, 4>;
+
+fn load_u_1_m() -> mat3x2<f32> {
+ let s = &(u[1u]);
+ return mat3x2<f32>((*(s)).m_0, (*(s)).m_1, (*(s)).m_2);
+}
+
+fn f() {
+ p[2].m = load_u_1_m();
+}
+)";
+
+ auto result = Run<Std140>(wgsl);
+ EXPECT_EQ(want, str(result));
+}
+
+// Copying a matrix column from a uniform array element into a storage array element.
+TEST_F(Std140Test_F32, ArrayStructMatUniform_CopyColumn_UniformToStorage_Mat3x2F32) {
+ const char* wgsl = R"(
+struct S {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> u : array<S, 3>;
+
+@group(0) @binding(1) var<storage, read_write> s : array<S, 4>;
+
+fn f() {
+ s[3].m[1] = u[2].m[0];
+}
+)";
+ // The source column u[2].m[0] becomes u[2u].m_0; the storage destination is untouched.
+ const char* want = R"(
+struct S {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ @size(48)
+ m_2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> u : array<S_std140, 3u>;
+
+@group(0) @binding(1) var<storage, read_write> s : array<S, 4>;
+
+fn f() {
+ s[3].m[1] = u[2u].m_0;
+}
+)";
+
+ auto result = Run<Std140>(wgsl);
+ EXPECT_EQ(want, str(result));
+}
+
+// Copying a double-swizzled matrix column from a uniform array into a workgroup array.
+TEST_F(Std140Test_F32, ArrayStructMatUniform_CopyColumnSwizzle_UniformToWorkgroup_Mat3x2F32) {
+ const char* wgsl = R"(
+struct S {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> u : array<S, 4>;
+
+var<workgroup> w : array<S, 4>;
+
+fn f() {
+ w[3].m[1] = u[2].m[0].yx.yx;
+}
+)";
+ // The swizzles are kept as written after the decomposed column u[2u].m_0.
+ const char* want = R"(
+struct S {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ @size(48)
+ m_2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> u : array<S_std140, 4u>;
+
+var<workgroup> w : array<S, 4>;
+
+fn f() {
+ w[3].m[1] = u[2u].m_0.yx.yx;
+}
+)";
+
+ auto result = Run<Std140>(wgsl);
+ EXPECT_EQ(want, str(result));
+}
+
+// Copying one matrix scalar from a uniform array element into a private array element.
+TEST_F(Std140Test_F32, ArrayStructMatUniform_CopyScalar_UniformToPrivate_Mat3x2F32) {
+ const char* wgsl = R"(
+struct S {
+ v : vec4<i32>,
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> u : array<S, 3>;
+
+var<private> p : array<S, 4>;
+
+fn f() {
+ p[3].m[1].x = u[2].m[0].y;
+}
+)";
+ // The source swizzle .y on column 0 is expected as the index m_0[1u].
+ const char* want = R"(
+struct S {
+ v : vec4<i32>,
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ v : vec4<i32>,
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ @size(48)
+ m_2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> u : array<S_std140, 3u>;
+
+var<private> p : array<S, 4>;
+
+fn f() {
+ p[3].m[1].x = u[2u].m_0[1u];
+}
+)";
+
+ auto result = Run<Std140>(wgsl);
+ EXPECT_EQ(want, str(result));
+}
+
+// Loading a whole uniform array<mat2x2<f32>, 3>.
+TEST_F(Std140Test_F32, ArrayMatUniform_LoadArray_Mat2x2F32) {
+ const char* wgsl = R"(
+@group(0) @binding(0) var<uniform> a : array<mat2x2<f32>, 3>;
+
+fn f() {
+ let l = a;
+}
+)";
+ // Elements become a mat2x2_f32 struct of columns; conv_arr3_mat2x2_f32 rebuilds the array.
+ const char* want = R"(
+struct mat2x2_f32 {
+ col0 : vec2<f32>,
+ col1 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<mat2x2_f32, 3u>;
+
+fn conv_mat2x2_f32(val : mat2x2_f32) -> mat2x2<f32> {
+ return mat2x2<f32>(val.col0, val.col1);
+}
+
+fn conv_arr3_mat2x2_f32(val : array<mat2x2_f32, 3u>) -> array<mat2x2<f32>, 3u> {
+ var arr : array<mat2x2<f32>, 3u>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ arr[i] = conv_mat2x2_f32(val[i]);
+ }
+ return arr;
+}
+
+fn f() {
+ let l = conv_arr3_mat2x2_f32(a);
+}
+)";
+
+ auto result = Run<Std140>(wgsl);
+ EXPECT_EQ(want, str(result));
+}
+
+// Loading matrix a[2] (constant index) from a uniform array of mat2x2<f32>.
+TEST_F(Std140Test_F32, ArrayMatUniform_LoadMatrix_ConstArrayIndex_Mat2x2F32) {
+ const char* wgsl = R"(
+@group(0) @binding(0) var<uniform> a : array<mat2x2<f32>, 3>;
+
+fn f() {
+ let l = a[2];
+}
+)";
+ // conv_mat2x2_f32 rebuilds the matrix from the col0/col1 members.
+ const char* want = R"(
+struct mat2x2_f32 {
+ col0 : vec2<f32>,
+ col1 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<mat2x2_f32, 3u>;
+
+fn conv_mat2x2_f32(val : mat2x2_f32) -> mat2x2<f32> {
+ return mat2x2<f32>(val.col0, val.col1);
+}
+
+fn f() {
+ let l = conv_mat2x2_f32(a[2u]);
+}
+)";
+
+ auto result = Run<Std140>(wgsl);
+ EXPECT_EQ(want, str(result));
+}
+
+// Loading matrix a[I] (variable index) from a uniform array of mat2x2<f32>.
+TEST_F(Std140Test_F32, ArrayMatUniform_LoadMatrix_VariableArrayIndex_Mat2x2F32) {
+ const char* wgsl = R"(
+@group(0) @binding(0) var<uniform> a : array<mat2x2<f32>, 3>;
+
+fn f() {
+ let I = 1;
+ let l = a[I];
+}
+)";
+ // conv_mat2x2_f32 is applied to a[I]; the index expression is left as written.
+ const char* want = R"(
+struct mat2x2_f32 {
+ col0 : vec2<f32>,
+ col1 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<mat2x2_f32, 3u>;
+
+fn conv_mat2x2_f32(val : mat2x2_f32) -> mat2x2<f32> {
+ return mat2x2<f32>(val.col0, val.col1);
+}
+
+fn f() {
+ let I = 1;
+ let l = conv_mat2x2_f32(a[I]);
+}
+)";
+
+ auto result = Run<Std140>(wgsl);
+ EXPECT_EQ(want, str(result));
+}
+
+// Loading column a[2][1] with constant array and column indices.
+TEST_F(Std140Test_F32, ArrayMatUniform_LoadColumn_ConstArrayIndex_ConstColumnIndex_Mat2x2F32) {
+ const char* wgsl = R"(
+@group(0) @binding(0) var<uniform> a : array<mat2x2<f32>, 3>;
+
+fn f() {
+ let l = a[2][1];
+}
+)";
+ // The access maps straight onto a[2u].col1 with no helper.
+ const char* want = R"(
+struct mat2x2_f32 {
+ col0 : vec2<f32>,
+ col1 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<mat2x2_f32, 3u>;
+
+fn f() {
+ let l = a[2u].col1;
+}
+)";
+
+ auto result = Run<Std140>(wgsl);
+ EXPECT_EQ(want, str(result));
+}
+
+// Loading column a[I][1] with a variable array index and a constant column index.
+TEST_F(Std140Test_F32, ArrayMatUniform_LoadColumn_VariableArrayIndex_ConstColumnIndex_Mat2x2F32) {
+ const char* wgsl = R"(
+@group(0) @binding(0) var<uniform> a : array<mat2x2<f32>, 3>;
+
+fn f() {
+ let I = 1;
+ let l = a[I][1];
+}
+)";
+ // The access maps straight onto a[I].col1 with no helper.
+ const char* want = R"(
+struct mat2x2_f32 {
+ col0 : vec2<f32>,
+ col1 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<mat2x2_f32, 3u>;
+
+fn f() {
+ let I = 1;
+ let l = a[I].col1;
+}
+)";
+
+ auto result = Run<Std140>(wgsl);
+ EXPECT_EQ(want, str(result));
+}
+
+// Loading column a[2][I] with a constant array index and a variable column index.
+TEST_F(Std140Test_F32, ArrayMatUniform_LoadColumn_ConstArrayIndex_VariableColumnIndex_Mat2x2F32) {
+ const char* wgsl = R"(
+@group(0) @binding(0) var<uniform> a : array<mat2x2<f32>, 3>;
+
+fn f() {
+ let I = 1;
+ let l = a[2][I];
+}
+)";
+ // A helper switches over both columns and returns vec2<f32>() in the default case.
+ const char* want = R"(
+struct mat2x2_f32 {
+ col0 : vec2<f32>,
+ col1 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<mat2x2_f32, 3u>;
+
+fn load_a_2_p0(p0 : u32) -> vec2<f32> {
+ switch(p0) {
+ case 0u: {
+ return a[2u].col0;
+ }
+ case 1u: {
+ return a[2u].col1;
+ }
+ default: {
+ return vec2<f32>();
+ }
+ }
+}
+
+fn f() {
+ let I = 1;
+ let l = load_a_2_p0(u32(I));
+}
+)";
+
+ auto result = Run<Std140>(wgsl);
+ EXPECT_EQ(want, str(result));
+}
+
+// Loading column a[I][I] with variable array and column indices.
+TEST_F(Std140Test_F32,
+ ArrayMatUniform_LoadColumn_VariableArrayIndex_VariableColumnIndex_Mat2x2F32) {
+ const char* wgsl = R"(
+@group(0) @binding(0) var<uniform> a : array<mat2x2<f32>, 3>;
+
+fn f() {
+ let I = 1;
+ let l = a[I][I];
+}
+)";
+ // The helper takes both indices and switches on the column index p1.
+ const char* want = R"(
+struct mat2x2_f32 {
+ col0 : vec2<f32>,
+ col1 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<mat2x2_f32, 3u>;
+
+fn load_a_p0_p1(p0 : u32, p1 : u32) -> vec2<f32> {
+ switch(p1) {
+ case 0u: {
+ return a[p0].col0;
+ }
+ case 1u: {
+ return a[p0].col1;
+ }
+ default: {
+ return vec2<f32>();
+ }
+ }
+}
+
+fn f() {
+ let I = 1;
+ let l = load_a_p0_p1(u32(I), u32(I));
+}
+)";
+
+ auto result = Run<Std140>(wgsl);
+ EXPECT_EQ(want, str(result));
+}
+
+// Loading a whole uniform struct whose member is array<mat2x2<f32>, 3>.
+TEST_F(Std140Test_F32, StructArrayMatUniform_LoadStruct_Mat2x2F32) {
+ const char* wgsl = R"(
+struct S {
+ a : array<mat2x2<f32>, 3>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let l = s;
+}
+)";
+ // conv_S rebuilds the struct via conv_arr3_mat2x2_f32 and conv_mat2x2_f32.
+ const char* want = R"(
+struct mat2x2_f32 {
+ col0 : vec2<f32>,
+ col1 : vec2<f32>,
+}
+
+struct S {
+ a : array<mat2x2<f32>, 3>,
+}
+
+struct S_std140 {
+ a : array<mat2x2_f32, 3u>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn conv_mat2x2_f32(val : mat2x2_f32) -> mat2x2<f32> {
+ return mat2x2<f32>(val.col0, val.col1);
+}
+
+fn conv_arr3_mat2x2_f32(val : array<mat2x2_f32, 3u>) -> array<mat2x2<f32>, 3u> {
+ var arr : array<mat2x2<f32>, 3u>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ arr[i] = conv_mat2x2_f32(val[i]);
+ }
+ return arr;
+}
+
+fn conv_S(val : S_std140) -> S {
+ return S(conv_arr3_mat2x2_f32(val.a));
+}
+
+fn f() {
+ let l = conv_S(s);
+}
+)";
+
+ auto result = Run<Std140>(wgsl);
+ EXPECT_EQ(want, str(result));
+}
+
+// Loading the array member s.a of a uniform struct.
+TEST_F(Std140Test_F32, StructArrayMatUniform_LoadArray_Mat2x2F32) {
+ const char* wgsl = R"(
+struct S {
+ a : array<mat2x2<f32>, 3>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let l = s.a;
+}
+)";
+ // conv_arr3_mat2x2_f32 is applied to s.a; no conv_S helper is emitted.
+ const char* want = R"(
+struct mat2x2_f32 {
+ col0 : vec2<f32>,
+ col1 : vec2<f32>,
+}
+
+struct S {
+ a : array<mat2x2<f32>, 3>,
+}
+
+struct S_std140 {
+ a : array<mat2x2_f32, 3u>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn conv_mat2x2_f32(val : mat2x2_f32) -> mat2x2<f32> {
+ return mat2x2<f32>(val.col0, val.col1);
+}
+
+fn conv_arr3_mat2x2_f32(val : array<mat2x2_f32, 3u>) -> array<mat2x2<f32>, 3u> {
+ var arr : array<mat2x2<f32>, 3u>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ arr[i] = conv_mat2x2_f32(val[i]);
+ }
+ return arr;
+}
+
+fn f() {
+ let l = conv_arr3_mat2x2_f32(s.a);
+}
+)";
+
+ auto result = Run<Std140>(wgsl);
+ EXPECT_EQ(want, str(result));
+}
+
+// Loading matrix s.a[2] (constant index) from an array member of a uniform struct.
+TEST_F(Std140Test_F32, StructArrayMatUniform_LoadMatrix_ConstArrayIndex_Mat2x2F32) {
+ const char* wgsl = R"(
+struct S {
+ a : array<mat2x2<f32>, 3>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let l = s.a[2];
+}
+)";
+ // conv_mat2x2_f32 rebuilds the matrix from s.a[2u].
+ const char* want = R"(
+struct mat2x2_f32 {
+ col0 : vec2<f32>,
+ col1 : vec2<f32>,
+}
+
+struct S {
+ a : array<mat2x2<f32>, 3>,
+}
+
+struct S_std140 {
+ a : array<mat2x2_f32, 3u>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn conv_mat2x2_f32(val : mat2x2_f32) -> mat2x2<f32> {
+ return mat2x2<f32>(val.col0, val.col1);
+}
+
+fn f() {
+ let l = conv_mat2x2_f32(s.a[2u]);
+}
+)";
+
+ auto result = Run<Std140>(wgsl);
+ EXPECT_EQ(want, str(result));
+}
+
+// Loading matrix s.a[I] (variable index) from an array member of a uniform struct.
+TEST_F(Std140Test_F32, StructArrayMatUniform_LoadMatrix_VariableArrayIndex_Mat2x2F32) {
+ const char* wgsl = R"(
+struct S {
+ a : array<mat2x2<f32>, 3>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let I = 1;
+ let l = s.a[I];
+}
+)";
+ // conv_mat2x2_f32 is applied to s.a[I]; the index expression is left as written.
+ const char* want = R"(
+struct mat2x2_f32 {
+ col0 : vec2<f32>,
+ col1 : vec2<f32>,
+}
+
+struct S {
+ a : array<mat2x2<f32>, 3>,
+}
+
+struct S_std140 {
+ a : array<mat2x2_f32, 3u>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn conv_mat2x2_f32(val : mat2x2_f32) -> mat2x2<f32> {
+ return mat2x2<f32>(val.col0, val.col1);
+}
+
+fn f() {
+ let I = 1;
+ let l = conv_mat2x2_f32(s.a[I]);
+}
+)";
+
+ auto result = Run<Std140>(wgsl);
+ EXPECT_EQ(want, str(result));
+}
+
+// Loading column s.a[2][1] with constant array and column indices.
+TEST_F(Std140Test_F32,
+ StructArrayMatUniform_LoadColumn_ConstArrayIndex_ConstColumnIndex_Mat2x2F32) {
+ const char* wgsl = R"(
+struct S {
+ a : array<mat2x2<f32>, 3>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let l = s.a[2][1];
+}
+)";
+ // The access maps straight onto s.a[2u].col1 with no helper.
+ const char* want = R"(
+struct mat2x2_f32 {
+ col0 : vec2<f32>,
+ col1 : vec2<f32>,
+}
+
+struct S {
+ a : array<mat2x2<f32>, 3>,
+}
+
+struct S_std140 {
+ a : array<mat2x2_f32, 3u>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn f() {
+ let l = s.a[2u].col1;
+}
+)";
+
+ auto result = Run<Std140>(wgsl);
+ EXPECT_EQ(want, str(result));
+}
+
+// Loading column s.a[I][1] with a variable array index and a constant column index.
+TEST_F(Std140Test_F32,
+ StructArrayMatUniform_LoadColumn_VariableArrayIndex_ConstColumnIndex_Mat2x2F32) {
+ const char* wgsl = R"(
+struct S {
+ a : array<mat2x2<f32>, 3>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let I = 1;
+ let l = s.a[I][1];
+}
+)";
+ // The access maps straight onto s.a[I].col1 with no helper.
+ const char* want = R"(
+struct mat2x2_f32 {
+ col0 : vec2<f32>,
+ col1 : vec2<f32>,
+}
+
+struct S {
+ a : array<mat2x2<f32>, 3>,
+}
+
+struct S_std140 {
+ a : array<mat2x2_f32, 3u>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn f() {
+ let I = 1;
+ let l = s.a[I].col1;
+}
+)";
+
+ auto result = Run<Std140>(wgsl);
+ EXPECT_EQ(want, str(result));
+}
+
+// Loading column s.a[2][I] with a constant array index and a variable column index.
+TEST_F(Std140Test_F32,
+ StructArrayMatUniform_LoadColumn_ConstArrayIndex_VariableColumnIndex_Mat2x2F32) {
+ const char* wgsl = R"(
+struct S {
+ a : array<mat2x2<f32>, 3>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let I = 1;
+ let l = s.a[2][I];
+}
+)";
+ // A helper switches over both columns and returns vec2<f32>() in the default case.
+ const char* want = R"(
+struct mat2x2_f32 {
+ col0 : vec2<f32>,
+ col1 : vec2<f32>,
+}
+
+struct S {
+ a : array<mat2x2<f32>, 3>,
+}
+
+struct S_std140 {
+ a : array<mat2x2_f32, 3u>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn load_s_a_2_p0(p0 : u32) -> vec2<f32> {
+ switch(p0) {
+ case 0u: {
+ return s.a[2u].col0;
+ }
+ case 1u: {
+ return s.a[2u].col1;
+ }
+ default: {
+ return vec2<f32>();
+ }
+ }
+}
+
+fn f() {
+ let I = 1;
+ let l = load_s_a_2_p0(u32(I));
+}
+)";
+
+ auto result = Run<Std140>(wgsl);
+ EXPECT_EQ(want, str(result));
+}
+
+// Loading column s.a[I][I] with variable array and column indices.
+TEST_F(Std140Test_F32,
+ StructArrayMatUniform_LoadColumn_VariableArrayIndex_VariableColumnIndex_Mat2x2F32) {
+ const char* wgsl = R"(
+struct S {
+ a : array<mat2x2<f32>, 3>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let I = 1;
+ let l = s.a[I][I];
+}
+)";
+ // The helper takes both indices and switches on the column index p1.
+ const char* want = R"(
+struct mat2x2_f32 {
+ col0 : vec2<f32>,
+ col1 : vec2<f32>,
+}
+
+struct S {
+ a : array<mat2x2<f32>, 3>,
+}
+
+struct S_std140 {
+ a : array<mat2x2_f32, 3u>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn load_s_a_p0_p1(p0 : u32, p1 : u32) -> vec2<f32> {
+ switch(p1) {
+ case 0u: {
+ return s.a[p0].col0;
+ }
+ case 1u: {
+ return s.a[p0].col1;
+ }
+ default: {
+ return vec2<f32>();
+ }
+ }
+}
+
+fn f() {
+ let I = 1;
+ let l = load_s_a_p0_p1(u32(I), u32(I));
+}
+)";
+
+ auto result = Run<Std140>(wgsl);
+ EXPECT_EQ(want, str(result));
+}
+
+// Loading a whole uniform array<array<mat2x2<f32>, 3>, 4>.
+TEST_F(Std140Test_F32, ArrayArrayMatUniform_LoadArrays_Mat2x2F32) {
+ const char* wgsl = R"(
+@group(0) @binding(0) var<uniform> a : array<array<mat2x2<f32>, 3>, 4>;
+
+fn f() {
+ let l = a;
+}
+)";
+ // One conversion helper is emitted per nesting level: matrix, inner array, outer array.
+ const char* want = R"(
+struct mat2x2_f32 {
+ col0 : vec2<f32>,
+ col1 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<array<mat2x2_f32, 3u>, 4u>;
+
+fn conv_mat2x2_f32(val : mat2x2_f32) -> mat2x2<f32> {
+ return mat2x2<f32>(val.col0, val.col1);
+}
+
+fn conv_arr3_mat2x2_f32(val : array<mat2x2_f32, 3u>) -> array<mat2x2<f32>, 3u> {
+ var arr : array<mat2x2<f32>, 3u>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ arr[i] = conv_mat2x2_f32(val[i]);
+ }
+ return arr;
+}
+
+fn conv_arr4_arr3_mat2x2_f32(val : array<array<mat2x2_f32, 3u>, 4u>) -> array<array<mat2x2<f32>, 3u>, 4u> {
+ var arr : array<array<mat2x2<f32>, 3u>, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ arr[i] = conv_arr3_mat2x2_f32(val[i]);
+ }
+ return arr;
+}
+
+fn f() {
+ let l = conv_arr4_arr3_mat2x2_f32(a);
+}
+)";
+
+ auto result = Run<Std140>(wgsl);
+ EXPECT_EQ(want, str(result));
+}
+
+TEST_F(Std140Test_F32, ArrayArrayMatUniform_LoadArray_ConstOuterArrayIndex_Mat2x2F32) {
+ auto* src = R"(
+@group(0) @binding(0) var<uniform> a : array<array<mat2x2<f32>, 3>, 4>;
+
+fn f() {
+ let l = a[3];
+}
+)";
+
+ auto* expect = R"(
+struct mat2x2_f32 {
+ col0 : vec2<f32>,
+ col1 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<array<mat2x2_f32, 3u>, 4u>;
+
+fn conv_mat2x2_f32(val : mat2x2_f32) -> mat2x2<f32> {
+ return mat2x2<f32>(val.col0, val.col1);
+}
+
+fn conv_arr3_mat2x2_f32(val : array<mat2x2_f32, 3u>) -> array<mat2x2<f32>, 3u> {
+ var arr : array<mat2x2<f32>, 3u>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ arr[i] = conv_mat2x2_f32(val[i]);
+ }
+ return arr;
+}
+
+fn f() {
+ let l = conv_arr3_mat2x2_f32(a[3u]);
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F32, ArrayArrayMatUniform_LoadArray_VariableOuterArrayIndex_Mat2x2F32) {
+ auto* src = R"(
+@group(0) @binding(0) var<uniform> a : array<array<mat2x2<f32>, 3>, 4>;
+
+fn f() {
+ let I = 1;
+ let l = a[I];
+}
+)";
+
+ auto* expect = R"(
+struct mat2x2_f32 {
+ col0 : vec2<f32>,
+ col1 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<array<mat2x2_f32, 3u>, 4u>;
+
+fn conv_mat2x2_f32(val : mat2x2_f32) -> mat2x2<f32> {
+ return mat2x2<f32>(val.col0, val.col1);
+}
+
+fn conv_arr3_mat2x2_f32(val : array<mat2x2_f32, 3u>) -> array<mat2x2<f32>, 3u> {
+ var arr : array<mat2x2<f32>, 3u>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ arr[i] = conv_mat2x2_f32(val[i]);
+ }
+ return arr;
+}
+
+fn f() {
+ let I = 1;
+ let l = conv_arr3_mat2x2_f32(a[I]);
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F32,
+ ArrayArrayMatUniform_LoadMatrix_ConstOuterArrayIndex_ConstInnerArrayIndex_Mat2x2F32) {
+ auto* src = R"(
+@group(0) @binding(0) var<uniform> a : array<array<mat2x2<f32>, 3>, 4>;
+
+fn f() {
+ let l = a[3][2];
+}
+)";
+
+ auto* expect = R"(
+struct mat2x2_f32 {
+ col0 : vec2<f32>,
+ col1 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<array<mat2x2_f32, 3u>, 4u>;
+
+fn conv_mat2x2_f32(val : mat2x2_f32) -> mat2x2<f32> {
+ return mat2x2<f32>(val.col0, val.col1);
+}
+
+fn f() {
+ let l = conv_mat2x2_f32(a[3u][2u]);
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F32,
+ ArrayArrayMatUniform_LoadMatrix_ConstOuterArrayIndex_VariableInnerArrayIndex_Mat2x2F32) {
+ auto* src = R"(
+@group(0) @binding(0) var<uniform> a : array<array<mat2x2<f32>, 3>, 4>;
+
+fn f() {
+ let I = 1;
+ let l = a[3][I];
+}
+)";
+
+ auto* expect = R"(
+struct mat2x2_f32 {
+ col0 : vec2<f32>,
+ col1 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<array<mat2x2_f32, 3u>, 4u>;
+
+fn conv_mat2x2_f32(val : mat2x2_f32) -> mat2x2<f32> {
+ return mat2x2<f32>(val.col0, val.col1);
+}
+
+fn f() {
+ let I = 1;
+ let l = conv_mat2x2_f32(a[3u][I]);
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F32,
+ ArrayArrayMatUniform_LoadMatrix_VariableOuterArrayIndex_ConstInnerArrayIndex_Mat2x2F32) {
+ auto* src = R"(
+@group(0) @binding(0) var<uniform> a : array<array<mat2x2<f32>, 3>, 4>;
+
+fn f() {
+ let I = 1;
+ let l = a[I][2];
+}
+)";
+
+ auto* expect = R"(
+struct mat2x2_f32 {
+ col0 : vec2<f32>,
+ col1 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<array<mat2x2_f32, 3u>, 4u>;
+
+fn conv_mat2x2_f32(val : mat2x2_f32) -> mat2x2<f32> {
+ return mat2x2<f32>(val.col0, val.col1);
+}
+
+fn f() {
+ let I = 1;
+ let l = conv_mat2x2_f32(a[I][2u]);
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test_F32,
+ ArrayArrayMatUniform_LoadMatrix_VariableOuterArrayIndex_VariableInnerArrayIndex_Mat2x2F32) {
+ auto* src = R"(
+@group(0) @binding(0) var<uniform> a : array<array<mat2x2<f32>, 3>, 4>;
+
+fn f() {
+ let I = 1;
+ let l = a[I][I];
+}
+)";
+
+ auto* expect = R"(
+struct mat2x2_f32 {
+ col0 : vec2<f32>,
+ col1 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<array<mat2x2_f32, 3u>, 4u>;
+
+fn conv_mat2x2_f32(val : mat2x2_f32) -> mat2x2<f32> {
+ return mat2x2<f32>(val.col0, val.col1);
+}
+
+fn f() {
+ let I = 1;
+ let l = conv_mat2x2_f32(a[I][I]);
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(
+ Std140Test_F32,
+ ArrayArrayMatUniform_LoadColumn_ConstOuterArrayIndex_ConstInnerArrayIndex_ConstColumnIndex_Mat2x2F32) {
+ auto* src = R"(
+@group(0) @binding(0) var<uniform> a : array<array<mat2x2<f32>, 3>, 4>;
+
+fn f() {
+ let l = a[3][2][1];
+}
+)";
+
+ auto* expect = R"(
+struct mat2x2_f32 {
+ col0 : vec2<f32>,
+ col1 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<array<mat2x2_f32, 3u>, 4u>;
+
+fn f() {
+ let l = a[3u][2u].col1;
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(
+ Std140Test_F32,
+ ArrayArrayMatUniform_LoadColumn_ConstOuterArrayIndex_ConstInnerArrayIndex_VariableColumnIndex_Mat2x2F32) {
+ auto* src = R"(
+@group(0) @binding(0) var<uniform> a : array<array<mat2x2<f32>, 3>, 4>;
+
+fn f() {
+ let I = 1;
+ let l = a[3][2][I];
+}
+)";
+
+ auto* expect = R"(
+struct mat2x2_f32 {
+ col0 : vec2<f32>,
+ col1 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<array<mat2x2_f32, 3u>, 4u>;
+
+fn load_a_3_2_p0(p0 : u32) -> vec2<f32> {
+ switch(p0) {
+ case 0u: {
+ return a[3u][2u].col0;
+ }
+ case 1u: {
+ return a[3u][2u].col1;
+ }
+ default: {
+ return vec2<f32>();
+ }
+ }
+}
+
+fn f() {
+ let I = 1;
+ let l = load_a_3_2_p0(u32(I));
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(
+ Std140Test_F32,
+ ArrayArrayMatUniform_LoadColumn_ConstOuterArrayIndex_VariableInnerArrayIndex_ConstColumnIndex_Mat2x2F32) {
+ auto* src = R"(
+@group(0) @binding(0) var<uniform> a : array<array<mat2x2<f32>, 3>, 4>;
+
+fn f() {
+ let I = 1;
+ let l = a[3][I][1];
+}
+)";
+
+ auto* expect = R"(
+struct mat2x2_f32 {
+ col0 : vec2<f32>,
+ col1 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<array<mat2x2_f32, 3u>, 4u>;
+
+fn f() {
+ let I = 1;
+ let l = a[3u][I].col1;
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(
+ Std140Test_F32,
+ ArrayArrayMatUniform_LoadColumn_ConstOuterArrayIndex_VariableInnerArrayIndex_VariableColumnIndex_Mat2x2F32) {
+ auto* src = R"(
+@group(0) @binding(0) var<uniform> a : array<array<mat2x2<f32>, 3>, 4>;
+
+fn f() {
+ let I = 1;
+ let J = 2;
+ let l = a[3][I][J];
+}
+)";
+
+ auto* expect = R"(
+struct mat2x2_f32 {
+ col0 : vec2<f32>,
+ col1 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<array<mat2x2_f32, 3u>, 4u>;
+
+fn load_a_3_p0_p1(p0 : u32, p1 : u32) -> vec2<f32> {
+ switch(p1) {
+ case 0u: {
+ return a[3u][p0].col0;
+ }
+ case 1u: {
+ return a[3u][p0].col1;
+ }
+ default: {
+ return vec2<f32>();
+ }
+ }
+}
+
+fn f() {
+ let I = 1;
+ let J = 2;
+ let l = load_a_3_p0_p1(u32(I), u32(J));
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(
+ Std140Test_F32,
+ ArrayArrayMatUniform_LoadColumn_VariableOuterArrayIndex_ConstInnerArrayIndex_ConstColumnIndex_Mat2x2F32) {
+ auto* src = R"(
+@group(0) @binding(0) var<uniform> a : array<array<mat2x2<f32>, 3>, 4>;
+
+fn f() {
+ let I = 1;
+ let l = a[I][2][1];
+}
+)";
+
+ auto* expect = R"(
+struct mat2x2_f32 {
+ col0 : vec2<f32>,
+ col1 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<array<mat2x2_f32, 3u>, 4u>;
+
+fn f() {
+ let I = 1;
+ let l = a[I][2u].col1;
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(
+ Std140Test_F32,
+ ArrayArrayMatUniform_LoadColumn_VariableOuterArrayIndex_ConstInnerArrayIndex_VariableColumnIndex_Mat2x2F32) {
+ auto* src = R"(
+@group(0) @binding(0) var<uniform> a : array<array<mat2x2<f32>, 3>, 4>;
+
+fn f() {
+ let I = 1;
+ let J = 2;
+ let l = a[I][2][J];
+}
+)";
+
+ auto* expect = R"(
+struct mat2x2_f32 {
+ col0 : vec2<f32>,
+ col1 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<array<mat2x2_f32, 3u>, 4u>;
+
+fn load_a_p0_2_p1(p0 : u32, p1 : u32) -> vec2<f32> {
+ switch(p1) {
+ case 0u: {
+ return a[p0][2u].col0;
+ }
+ case 1u: {
+ return a[p0][2u].col1;
+ }
+ default: {
+ return vec2<f32>();
+ }
+ }
+}
+
+fn f() {
+ let I = 1;
+ let J = 2;
+ let l = load_a_p0_2_p1(u32(I), u32(J));
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(
+ Std140Test_F32,
+ ArrayArrayMatUniform_LoadColumn_VariableOuterArrayIndex_VariableInnerArrayIndex_ConstColumnIndex_Mat2x2F32) {
+ auto* src = R"(
+@group(0) @binding(0) var<uniform> a : array<array<mat2x2<f32>, 3>, 4>;
+
+fn f() {
+ let I = 1;
+ let J = 2;
+ let l = a[I][J][1];
+}
+)";
+
+ auto* expect = R"(
+struct mat2x2_f32 {
+ col0 : vec2<f32>,
+ col1 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<array<mat2x2_f32, 3u>, 4u>;
+
+fn f() {
+ let I = 1;
+ let J = 2;
+ let l = a[I][J].col1;
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(
+ Std140Test_F32,
+ ArrayArrayMatUniform_LoadColumn_VariableOuterArrayIndex_VariableInnerArrayIndex_VariableColumnIndex_Mat2x2F32) {
+ auto* src = R"(
+@group(0) @binding(0) var<uniform> a : array<array<mat2x2<f32>, 3>, 4>;
+
+fn f() {
+ let I = 0;
+ let J = 1;
+ let K = 2;
+ let l = a[I][J][K];
+}
+)";
+
+ auto* expect = R"(
+struct mat2x2_f32 {
+ col0 : vec2<f32>,
+ col1 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<array<mat2x2_f32, 3u>, 4u>;
+
+fn load_a_p0_p1_p2(p0 : u32, p1 : u32, p2 : u32) -> vec2<f32> {
+ switch(p2) {
+ case 0u: {
+ return a[p0][p1].col0;
+ }
+ case 1u: {
+ return a[p0][p1].col1;
+ }
+ default: {
+ return vec2<f32>();
+ }
+ }
+}
+
+fn f() {
+ let I = 0;
+ let J = 1;
+ let K = 2;
+ let l = load_a_p0_p1_p2(u32(I), u32(J), u32(K));
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/std140_test.cc b/src/tint/lang/wgsl/ast/transform/std140_test.cc
new file mode 100644
index 0000000..5d203b7
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/std140_test.cc
@@ -0,0 +1,200 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/std140.h"
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+#include "src/tint/utils/string.h"
+
+// This file contains the should-run tests and a trivial empty-module test for the Std140
+// transform. For readable tests of the transform's output, see std140_f32_test.cc for f32
+// matrices and std140_f16_test.cc for f16 matrices. For exhaustive tests that run the Std140
+// transform on all shapes of both f32 and f16 matrices, looping over all valid literal indices
+// where required, see std140_exhaustive_test.cc.
+
+namespace tint::ast::transform {
+namespace {
+
+using Std140Test = TransformTest;
+
+TEST_F(Std140Test, ShouldRunEmptyModule) {
+ auto* src = R"()";
+
+ EXPECT_FALSE(ShouldRun<Std140>(src));
+}
+
+TEST_F(Std140Test, ShouldRunStructMat2x2F32Unused) {
+ auto* src = R"(
+struct Unused {
+ m : mat2x2<f32>,
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<Std140>(src));
+}
+
+TEST_F(Std140Test, ShouldRunStructMat2x2F16Unused) {
+ auto* src = R"(
+enable f16;
+
+struct Unused {
+ m : mat2x2<f16>,
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<Std140>(src));
+}
+
+enum class MatrixType { f32, f16 };
+
+struct MatrixCase {
+ uint32_t columns;
+ uint32_t rows;
+ MatrixType type;
+
+ size_t ElementSize() const { return type == MatrixType::f16 ? 2 : 4; }
+
+ size_t ColumnVectorAlign() const { return (rows == 3 ? 4 : rows) * ElementSize(); }
+
+ bool NotStd140Compatible() const { return ColumnVectorAlign() != 16; }
+
+ // Returns true if this matrix type can be used as the element type of an array in a uniform
+ // buffer, i.e. the array stride is a multiple of 16.
+ bool CanBeUsedAsUniformArrayElememts() const {
+ const size_t arrayStride = columns * ColumnVectorAlign();
+ return (arrayStride % 16 == 0);
+ }
+
+ std::string Shape() const { return std::to_string(columns) + "x" + std::to_string(rows); }
+
+ std::string ElementType() const { return type == MatrixType::f16 ? "f16" : "f32"; }
+
+ std::string Mat() const { return "mat" + Shape() + "<" + ElementType() + ">"; }
+
+ // Replaces the predefined field `${mat}` with the matrix type. E.g. for a matrix mat4x3<f32>,
+ // "${mat}" is replaced with "mat4x3<f32>".
+ std::string ReplaceMatInString(std::string str) const {
+ str = utils::ReplaceAll(str, "${mat}", Mat());
+ return str;
+ }
+};
+
+inline std::ostream& operator<<(std::ostream& os, const MatrixCase& c) {
+ return os << c.Mat();
+}
+
+using Std140TestShouldRun = TransformTestWithParam<MatrixCase>;
+
+TEST_P(Std140TestShouldRun, StructStorage) {
+ std::string src = R"(
+enable f16;
+
+struct S {
+ m : ${mat},
+}
+
+@group(0) @binding(0) var<storage> s : S;
+)";
+
+ src = GetParam().ReplaceMatInString(src);
+
+ EXPECT_FALSE(ShouldRun<Std140>(src));
+}
+
+TEST_P(Std140TestShouldRun, StructUniform) {
+ std::string src = R"(
+enable f16;
+
+struct S {
+ m : ${mat},
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+)";
+
+ src = GetParam().ReplaceMatInString(src);
+
+ EXPECT_EQ(ShouldRun<Std140>(src), GetParam().NotStd140Compatible());
+}
+
+TEST_P(Std140TestShouldRun, ArrayStorage) {
+ std::string src = R"(
+enable f16;
+
+@group(0) @binding(0) var<storage> s : array<${mat}, 2>;
+)";
+
+ src = GetParam().ReplaceMatInString(src);
+
+ EXPECT_FALSE(ShouldRun<Std140>(src));
+}
+
+TEST_P(Std140TestShouldRun, ArrayUniform) {
+ auto matrix = GetParam();
+
+ if (!matrix.CanBeUsedAsUniformArrayElememts()) {
+ // This permutation is invalid, skip the test.
+ return;
+ }
+
+ std::string src = R"(
+enable f16;
+
+@group(0) @binding(0) var<uniform> s : array<${mat}, 2>;
+)";
+
+ src = GetParam().ReplaceMatInString(src);
+
+ EXPECT_EQ(ShouldRun<Std140>(src), matrix.NotStd140Compatible());
+}
+
+INSTANTIATE_TEST_SUITE_P(Std140TestShouldRun,
+ Std140TestShouldRun,
+ ::testing::ValuesIn(std::vector<MatrixCase>{
+ {2, 2, MatrixType::f32},
+ {2, 3, MatrixType::f32},
+ {2, 4, MatrixType::f32},
+ {3, 2, MatrixType::f32},
+ {3, 3, MatrixType::f32},
+ {3, 4, MatrixType::f32},
+ {4, 2, MatrixType::f32},
+ {4, 3, MatrixType::f32},
+ {4, 4, MatrixType::f32},
+ {2, 2, MatrixType::f16},
+ {2, 3, MatrixType::f16},
+ {2, 4, MatrixType::f16},
+ {3, 2, MatrixType::f16},
+ {3, 3, MatrixType::f16},
+ {3, 4, MatrixType::f16},
+ {4, 2, MatrixType::f16},
+ {4, 3, MatrixType::f16},
+ {4, 4, MatrixType::f16},
+ }));
+
+TEST_F(Std140Test, EmptyModule) {
+ auto* src = R"()";
+
+ auto* expect = src;
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/substitute_override.cc b/src/tint/lang/wgsl/ast/transform/substitute_override.cc
new file mode 100644
index 0000000..1b97062
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/substitute_override.cc
@@ -0,0 +1,130 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/substitute_override.h"
+
+#include <functional>
+#include <utility>
+
+#include "src/tint/builtin/function.h"
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/builtin.h"
+#include "src/tint/sem/index_accessor_expression.h"
+#include "src/tint/sem/variable.h"
+#include "src/tint/switch.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::SubstituteOverride);
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::SubstituteOverride::Config);
+
+namespace tint::ast::transform {
+namespace {
+
+bool ShouldRun(const Program* program) {
+ for (auto* node : program->AST().GlobalVariables()) {
+ if (node->Is<Override>()) {
+ return true;
+ }
+ }
+ return false;
+}
+
+} // namespace
+
+SubstituteOverride::SubstituteOverride() = default;
+
+SubstituteOverride::~SubstituteOverride() = default;
+
+Transform::ApplyResult SubstituteOverride::Apply(const Program* src,
+ const DataMap& config,
+ DataMap&) const {
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+
+ const auto* data = config.Get<Config>();
+ if (!data) {
+ b.Diagnostics().add_error(diag::System::Transform, "Missing override substitution data");
+ return Program(std::move(b));
+ }
+
+ if (!ShouldRun(ctx.src)) {
+ return SkipTransform;
+ }
+
+ ctx.ReplaceAll([&](const Override* w) -> const Const* {
+ auto* sem = ctx.src->Sem().Get(w);
+
+ auto source = ctx.Clone(w->source);
+ auto sym = ctx.Clone(w->name->symbol);
+ Type ty = w->type ? ctx.Clone(w->type) : Type{};
+
+ // Look up the replacement value. If none was provided, clone the override node as a const.
+ auto iter = data->map.find(sem->OverrideId());
+ if (iter == data->map.end()) {
+ if (!w->initializer) {
+ b.Diagnostics().add_error(
+ diag::System::Transform,
+ "Initializer not provided for override, and override not overridden.");
+ return nullptr;
+ }
+ return b.Const(source, sym, ty, ctx.Clone(w->initializer));
+ }
+
+ auto value = iter->second;
+ auto* ctor = Switch(
+ sem->Type(),
+ [&](const type::Bool*) { return b.Expr(!std::equal_to<double>()(value, 0.0)); },
+ [&](const type::I32*) { return b.Expr(i32(value)); },
+ [&](const type::U32*) { return b.Expr(u32(value)); },
+ [&](const type::F32*) { return b.Expr(f32(value)); },
+ [&](const type::F16*) { return b.Expr(f16(value)); });
+
+ if (!ctor) {
+ b.Diagnostics().add_error(diag::System::Transform,
+ "Failed to create override-expression");
+ return nullptr;
+ }
+
+ return b.Const(source, sym, ty, ctor);
+ });
+
+ // Ensure that objects that are indexed with an override-expression are materialized.
+ // If the object is not materialized, and the 'override' variable is turned to a 'const', the
+ // resulting type of the index may change. See: crbug.com/tint/1697.
+ ctx.ReplaceAll([&](const IndexAccessorExpression* expr) -> const IndexAccessorExpression* {
+ if (auto* sem = src->Sem().Get(expr)) {
+ if (auto* access = sem->UnwrapMaterialize()->As<sem::IndexAccessorExpression>()) {
+ if (access->Object()->UnwrapMaterialize()->Type()->HoldsAbstract() &&
+ access->Index()->Stage() == sem::EvaluationStage::kOverride) {
+ auto* obj = b.Call(builtin::str(builtin::Function::kTintMaterialize),
+ ctx.Clone(expr->object));
+ return b.IndexAccessor(obj, ctx.Clone(expr->index));
+ }
+ }
+ }
+ return nullptr;
+ });
+
+ ctx.Clone();
+ return Program(std::move(b));
+}
+
+SubstituteOverride::Config::Config() = default;
+
+SubstituteOverride::Config::Config(const Config&) = default;
+
+SubstituteOverride::Config::~Config() = default;
+
+SubstituteOverride::Config& SubstituteOverride::Config::operator=(const Config&) = default;
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/substitute_override.h b/src/tint/lang/wgsl/ast/transform/substitute_override.h
new file mode 100644
index 0000000..bd918c2e
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/substitute_override.h
@@ -0,0 +1,86 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_SUBSTITUTE_OVERRIDE_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_SUBSTITUTE_OVERRIDE_H_
+
+#include <string>
+#include <unordered_map>
+
+#include "tint/override_id.h"
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+#include "src/tint/reflection.h"
+
+namespace tint::ast::transform {
+
+/// A transform that replaces overrides with the constant values provided.
+///
+/// # Example
+/// ```
+/// override width: f32;
+/// @id(1) override height: i32 = 4;
+/// override depth = 1i;
+/// ```
+///
+/// When transformed with `width` -> 1, `1` -> 22, `depth` -> 42
+///
+/// ```
+/// const width: f32 = 1f;
+/// const height: i32 = 22i;
+/// const depth = 42i;
+/// ```
+///
+/// @see crbug.com/tint/1582
+class SubstituteOverride final : public utils::Castable<SubstituteOverride, Transform> {
+ public:
+ /// Configuration options for the transform
+ struct Config final : public utils::Castable<Config, Data> {
+ /// Constructor
+ Config();
+
+ /// Copy constructor
+ Config(const Config&);
+
+ /// Destructor
+ ~Config() override;
+
+ /// Assignment operator
+ /// @returns this Config
+ Config& operator=(const Config&);
+
+ /// The map of override identifier to the override value.
+ /// The value is always a double coming into the transform and will be
+ /// converted to the correct type through an initializer.
+ std::unordered_map<OverrideId, double> map;
+
+ /// Reflect the fields of this class so that it can be used by tint::ForeachField()
+ TINT_REFLECT(map);
+ };
+
+ /// Constructor
+ SubstituteOverride();
+
+ /// Destructor
+ ~SubstituteOverride() override;
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_SUBSTITUTE_OVERRIDE_H_
diff --git a/src/tint/lang/wgsl/ast/transform/substitute_override_test.cc b/src/tint/lang/wgsl/ast/transform/substitute_override_test.cc
new file mode 100644
index 0000000..8b1513c
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/substitute_override_test.cc
@@ -0,0 +1,293 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/substitute_override.h"
+
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using SubstituteOverrideTest = TransformTest;
+
+TEST_F(SubstituteOverrideTest, Error_NoData) {
+ auto* src = R"(
+override width: i32;
+@vertex
+fn main() -> @builtin(position) vec4<f32> {
+ return vec4<f32>();
+}
+)";
+
+ auto* expect = "error: Missing override substitution data";
+
+ Transform::DataMap data;
+ auto got = Run<SubstituteOverride>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SubstituteOverrideTest, Error_NoOverrideValue) {
+ auto* src = R"(
+override width: i32;
+@vertex
+fn main() -> @builtin(position) vec4<f32> {
+ return vec4<f32>();
+}
+)";
+
+ auto* expect = "error: Initializer not provided for override, and override not overridden.";
+
+ SubstituteOverride::Config cfg;
+ Transform::DataMap data;
+ data.Add<SubstituteOverride::Config>(cfg);
+
+ auto got = Run<SubstituteOverride>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SubstituteOverrideTest, Module_NoOverrides) {
+ auto* src = R"(
+@vertex
+fn main() -> @builtin(position) vec4<f32> {
+ return vec4<f32>();
+}
+)";
+
+ auto* expect = R"(
+@vertex
+fn main() -> @builtin(position) vec4<f32> {
+ return vec4<f32>();
+}
+)";
+
+ SubstituteOverride::Config cfg;
+
+ Transform::DataMap data;
+ data.Add<SubstituteOverride::Config>(cfg);
+ auto got = Run<SubstituteOverride>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SubstituteOverrideTest, ImplicitId) {
+ auto* src = R"(
+enable f16;
+
+override i_width: i32;
+override i_height = 1i;
+
+override f_width: f32;
+override f_height = 1.f;
+
+override h_width: f16;
+override h_height = 1.h;
+
+override b_width: bool;
+override b_height = true;
+
+override o_width = 2i;
+
+@vertex
+fn main() -> @builtin(position) vec4<f32> {
+ return vec4<f32>();
+}
+)";
+
+ auto* expect = R"(
+enable f16;
+
+const i_width : i32 = 42i;
+
+const i_height = 11i;
+
+const f_width : f32 = 22.299999237060546875f;
+
+const f_height = 12.3999996185302734375f;
+
+const h_width : f16 = 9.3984375h;
+
+const h_height = 3.3984375h;
+
+const b_width : bool = true;
+
+const b_height = false;
+
+const o_width = 2i;
+
+@vertex
+fn main() -> @builtin(position) vec4<f32> {
+ return vec4<f32>();
+}
+)";
+
+ SubstituteOverride::Config cfg;
+ cfg.map.insert({OverrideId{0}, 42.0});
+ cfg.map.insert({OverrideId{1}, 11.0});
+ cfg.map.insert({OverrideId{2}, 22.3});
+ cfg.map.insert({OverrideId{3}, 12.4});
+ cfg.map.insert({OverrideId{4}, 9.4});
+ cfg.map.insert({OverrideId{5}, 3.4});
+ cfg.map.insert({OverrideId{6}, 1.0});
+ cfg.map.insert({OverrideId{7}, 0.0});
+
+ Transform::DataMap data;
+ data.Add<SubstituteOverride::Config>(cfg);
+ auto got = Run<SubstituteOverride>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SubstituteOverrideTest, ExplicitId) {
+ auto* src = R"(
+enable f16;
+
+@id(0) override i_width: i32;
+@id(10) override i_height = 1i;
+
+@id(1) override f_width: f32;
+@id(9) override f_height = 1.f;
+
+@id(2) override h_width: f16;
+@id(8) override h_height = 1.h;
+
+@id(3) override b_width: bool;
+@id(7) override b_height = true;
+
+@id(5) override o_width = 2i;
+
+@vertex
+fn main() -> @builtin(position) vec4<f32> {
+ return vec4<f32>();
+}
+)";
+
+ auto* expect = R"(
+enable f16;
+
+const i_width : i32 = 42i;
+
+const i_height = 11i;
+
+const f_width : f32 = 22.299999237060546875f;
+
+const f_height = 12.3999996185302734375f;
+
+const h_width : f16 = 9.3984375h;
+
+const h_height = 3.3984375h;
+
+const b_width : bool = true;
+
+const b_height = false;
+
+const o_width = 13i;
+
+@vertex
+fn main() -> @builtin(position) vec4<f32> {
+ return vec4<f32>();
+}
+)";
+
+ SubstituteOverride::Config cfg;
+ cfg.map.insert({OverrideId{0}, 42.0});
+ cfg.map.insert({OverrideId{10}, 11.0});
+ cfg.map.insert({OverrideId{1}, 22.3});
+ cfg.map.insert({OverrideId{9}, 12.4});
+ cfg.map.insert({OverrideId{2}, 9.4});
+ cfg.map.insert({OverrideId{8}, 3.4});
+ cfg.map.insert({OverrideId{3}, 1.0});
+ cfg.map.insert({OverrideId{7}, 0.0});
+ cfg.map.insert({OverrideId{5}, 13});
+
+ Transform::DataMap data;
+ data.Add<SubstituteOverride::Config>(cfg);
+ auto got = Run<SubstituteOverride>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SubstituteOverrideTest, Identifier_Expression) {
+ auto* src = R"(
+override i_height = ~2i;
+
+@vertex
+fn main() -> @builtin(position) vec4<f32> {
+ return vec4<f32>();
+}
+)";
+
+ auto* expect = R"(
+const i_height = 11i;
+
+@vertex
+fn main() -> @builtin(position) vec4<f32> {
+ return vec4<f32>();
+}
+)";
+
+ SubstituteOverride::Config cfg;
+ cfg.map.insert({OverrideId{0}, 11.0});
+
+ Transform::DataMap data;
+ data.Add<SubstituteOverride::Config>(cfg);
+ auto got = Run<SubstituteOverride>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SubstituteOverrideTest, IndexMaterialization) {
+ auto* src = R"(
+override O = 0; // Try switching to 'const'
+
+fn f() {
+ const smaller_than_any_f32 = 1e-50;
+ const large_float = 1e27;
+ // When O is an override, the outer index value is not constant, so the
+ // value is not calculated at shader-creation time, and does not error.
+ //
+ // When O is a const, and 'smaller_than_any_f32' *is not* materialized, the
+ // outer index value will evaluate to 10000, resulting in an out-of-bounds
+ // error.
+ //
+ // When O is a const, and 'smaller_than_any_f32' *is* materialized, the
+ // materialization of 'smaller_than_any_f32' to f32 will evaluate to zero,
+ // and so the outer index value will be zero, and we get no error.
+ _ = vec2(0)[i32(vec2(smaller_than_any_f32)[O]*large_float*large_float)];
+}
+)";
+
+ auto* expect = R"(
+const O = 0i;
+
+fn f() {
+ const smaller_than_any_f32 = 1e-50;
+ const large_float = 1000000000000000013287555072.0;
+ _ = _tint_materialize(vec2(0))[i32(((_tint_materialize(vec2(smaller_than_any_f32))[O] * large_float) * large_float))];
+}
+)";
+
+ SubstituteOverride::Config cfg;
+ cfg.map.insert({OverrideId{0}, 0.0});
+
+ Transform::DataMap data;
+ data.Add<SubstituteOverride::Config>(cfg);
+ auto got = Run<SubstituteOverride>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/test_helper.h b/src/tint/lang/wgsl/ast/transform/test_helper.h
new file mode 100644
index 0000000..bddfd55
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/test_helper.h
@@ -0,0 +1,168 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_TEST_HELPER_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_TEST_HELPER_H_
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+#include "src/tint/lang/wgsl/ast_writer/generator.h"
+#include "src/tint/lang/wgsl/reader/parser.h"
+#include "src/tint/transform/manager.h"
+
+namespace tint::ast::transform {
+
+/// @param program the program to get an output WGSL string from
+/// @returns the formatted diagnostics if `program` is invalid, a "WGSL writer failed" message if
+/// generation fails, otherwise the WGSL stripped of trailing newlines and wrapped in one '\n' per side.
+inline std::string str(const Program& program) {
+    diag::Formatter::Style style;
+    style.print_newline_at_end = false;
+
+    if (!program.IsValid()) {
+        return diag::Formatter(style).format(program.Diagnostics());
+    }
+
+    writer::wgsl::Options options;
+    auto result = writer::wgsl::Generate(&program, options);
+    if (!result.success) {
+        return "WGSL writer failed:\n" + result.error;
+    }
+
+    auto res = result.wgsl;
+    if (res.empty()) {
+        return res;
+    }
+    // Strip trailing newlines; stop once empty, as back() on an empty string is undefined.
+    while (!res.empty() && res.back() == '\n') {
+        res.pop_back();
+    }
+    if (res.empty()) {
+        return res;
+    }
+    return "\n" + res + "\n";
+}
+
+/// Test fixture base (deriving from the gtest class BASE) with helpers to parse WGSL and run transforms
+template <typename BASE>
+class TransformTestBase : public BASE {
+ public:
+ /// Transforms and returns the WGSL source `in`, transformed using `transform`.
+ /// NOTE(review): forwards to a Run() taking a std::vector of transforms, which this class does not declare - confirm.
+ /// @param transform the transform to apply
+ /// @param in the input WGSL source
+ /// @param data the optional Transform::DataMap to pass to Transform::Run()
+ /// @return the transformed output
+ Output Run(std::string in,
+ std::unique_ptr<transform::Transform> transform,
+ const tint::transform::DataMap& data = {}) {
+ std::vector<std::unique_ptr<transform::Transform>> transforms;
+ transforms.emplace_back(std::move(transform));
+ return Run(std::move(in), std::move(transforms), data);
+ }
+
+ /// Parses the WGSL source `in` and runs one default-constructed instance of each
+ /// `TRANSFORMS` type over it. The parsed source file is retained in `files_`.
+ /// @param in the input WGSL source
+ /// @param data the optional Transform::DataMap to pass to Transform::Run()
+ /// @return the transformed output
+ template <typename... TRANSFORMS>
+ Output Run(std::string in, const tint::transform::DataMap& data = {}) {
+ auto file = std::make_unique<Source::File>("test", in);
+ auto program = reader::wgsl::Parse(file.get());
+
+ // Keep this pointer alive after Transform() returns
+ files_.emplace_back(std::move(file));
+
+ return Run<TRANSFORMS...>(std::move(program), data);
+ }
+
+ /// Runs one default-constructed instance of each `TRANSFORMS` type over `program`
+ /// through a transform Manager. An invalid `program` is returned untransformed.
+ /// @param program the input Program
+ /// @param data the optional Transform::DataMap to pass to Transform::Run()
+ /// @return the transformed output
+ template <typename... TRANSFORMS>
+ Output Run(Program&& program, const tint::transform::DataMap& data = {}) {
+ if (!program.IsValid()) {
+ return Output(std::move(program));
+ }
+
+ tint::transform::Manager manager;
+ tint::transform::DataMap outputs;
+ for (auto* transform_ptr : std::initializer_list<Transform*>{new TRANSFORMS()...}) {
+ manager.append(std::unique_ptr<Transform>(transform_ptr));  // wrap the raw pointer in an owner
+ }
+ auto result = manager.Run(&program, data, outputs);
+ return {std::move(result), std::move(outputs)};
+ }
+
+ /// @param program the input program
+ /// @param data the optional Transform::DataMap to pass to Transform::Apply()
+ /// @return true if TRANSFORM::Apply() produced a program; false if it skipped or `program` is invalid
+ template <typename TRANSFORM>
+ bool ShouldRun(Program&& program, const tint::transform::DataMap& data = {}) {
+ if (!program.IsValid()) {
+ ADD_FAILURE() << "ShouldRun() called with invalid program: "
+ << program.Diagnostics().str();
+ return false;
+ }
+
+ const Transform& t = TRANSFORM();  // temporary's lifetime is extended by the const reference
+
+ tint::transform::DataMap outputs;
+ auto result = t.Apply(&program, data, outputs);
+ if (!result) {
+ return false;
+ }
+ if (!result->IsValid()) {
+ ADD_FAILURE() << "Apply() called by ShouldRun() returned errors: "
+ << result->Diagnostics().str();
+ return true;  // the transform did run; its errors were reported as a test failure above
+ }
+ return result.has_value();  // always true here: the empty case returned earlier
+ }
+
+ /// @param in the input WGSL source
+ /// @param data the optional Transform::DataMap to pass to Transform::Apply()
+ /// @return true if the transform should be run for the given input.
+ template <typename TRANSFORM>
+ bool ShouldRun(std::string in, const tint::transform::DataMap& data = {}) {
+ auto file = std::make_unique<Source::File>("test", in);
+ auto program = reader::wgsl::Parse(file.get());
+ return ShouldRun<TRANSFORM>(std::move(program), data);
+ }
+
+ /// @param output the output of the transform
+ /// @returns the output program as a WGSL string, or an error string if the
+ /// program is not valid.
+ std::string str(const Output& output) { return transform::str(output.program); }
+
+ private:
+ std::vector<std::unique_ptr<Source::File>> files_;  // source files parsed by Run(std::string, ...)
+};
+
+using TransformTest = TransformTestBase<testing::Test>;
+
+template <typename T>
+using TransformTestWithParam = TransformTestBase<testing::TestWithParam<T>>;
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_TEST_HELPER_H_
diff --git a/src/tint/lang/wgsl/ast/transform/texture_1d_to_2d.cc b/src/tint/lang/wgsl/ast/transform/texture_1d_to_2d.cc
new file mode 100644
index 0000000..4a82a5c
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/texture_1d_to_2d.cc
@@ -0,0 +1,192 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/texture_1d_to_2d.h"
+
+#include <utility>
+
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/function.h"
+#include "src/tint/sem/statement.h"
+#include "src/tint/sem/type_expression.h"
+#include "src/tint/switch.h"
+#include "src/tint/type/texture_dimension.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::Texture1DTo2D);
+
+using namespace tint::number_suffixes; // NOLINT
+
+namespace tint::ast::transform {
+
+namespace {
+
+bool ShouldRun(const Program* program) {  // true if `program` uses any 1D texture
+ for (auto* fn : program->AST().Functions()) {
+ if (auto* sem_fn = program->Sem().Get(fn)) {
+ for (auto* builtin : sem_fn->DirectlyCalledBuiltins()) {  // 1) builtins with a 1D texture parameter
+ const auto& signature = builtin->Signature();
+ auto texture = signature.Parameter(sem::ParameterUsage::kTexture);
+ if (texture) {
+ auto* tex = texture->Type()->As<type::Texture>();  // NOTE(review): no null check - assumes a texture type
+ if (tex->dim() == type::TextureDimension::k1d) {
+ return true;
+ }
+ }
+ }
+ }
+ }
+ for (auto* var : program->AST().GlobalVariables()) {  // 2) globals of 1D sampled / storage texture type
+ if (Switch(
+ program->Sem().Get(var)->Type()->UnwrapRef(),
+ [&](const type::SampledTexture* tex) {
+ return tex->dim() == type::TextureDimension::k1d;
+ },
+ [&](const type::StorageTexture* storage_tex) {
+ return storage_tex->dim() == type::TextureDimension::k1d;
+ })) {
+ return true;
+ }
+ }
+ return false;
+}
+
+} // namespace
+
+/// PIMPL state for the transform: holds the source program, the builder and the clone context
+struct Texture1DTo2D::State {
+ /// The source program
+ const Program* const src;
+ /// The target program builder
+ ProgramBuilder b;
+ /// The clone context
+ CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
+
+ /// Constructor
+ /// @param program the source program
+ explicit State(const Program* program) : src(program) {}
+
+ /// Runs the transform
+ /// @returns the new program or SkipTransform if the transform is not required
+ ApplyResult Run() {
+ auto& sem = src->Sem();
+
+ if (!ShouldRun(ctx.src)) {
+ return SkipTransform;
+ }
+
+ auto create_var = [&](const Variable* v, Type type) -> const Variable* {  // re-declares `v` with `type`
+ if (v->As<Parameter>()) {
+ return ctx.dst->Param(ctx.Clone(v->name->symbol), type, ctx.Clone(v->attributes));
+ } else {
+ return ctx.dst->Var(ctx.Clone(v->name->symbol), type, ctx.Clone(v->attributes));
+ }
+ };
+
+ ctx.ReplaceAll([&](const Variable* v) -> const Variable* {  // retype 1D texture variables as 2D
+ const Variable* r = Switch(
+ sem.Get(v)->Type()->UnwrapRef(),
+ [&](const type::SampledTexture* tex) -> const Variable* {
+ if (tex->dim() == type::TextureDimension::k1d) {
+ auto type = ctx.dst->ty.sampled_texture(type::TextureDimension::k2d,
+ CreateASTTypeFor(ctx, tex->type()));
+ return create_var(v, type);
+ } else {
+ return nullptr;  // presumably nullptr means "no replacement" - confirm against CloneContext
+ }
+ },
+ [&](const type::StorageTexture* storage_tex) -> const Variable* {
+ if (storage_tex->dim() == type::TextureDimension::k1d) {
+ auto type = ctx.dst->ty.storage_texture(type::TextureDimension::k2d,
+ storage_tex->texel_format(),
+ storage_tex->access());
+ return create_var(v, type);
+ } else {
+ return nullptr;
+ }
+ },
+ [](Default) { return nullptr; });
+ return r;
+ });
+
+ ctx.ReplaceAll([&](const CallExpression* c) -> const Expression* {  // patch builtin calls on 1D textures
+ auto* call = sem.Get(c)->UnwrapMaterialize()->As<sem::Call>();
+ if (!call) {
+ return nullptr;
+ }
+ auto* builtin = call->Target()->As<sem::Builtin>();
+ if (!builtin) {
+ return nullptr;
+ }
+ const auto& signature = builtin->Signature();
+ auto* texture = signature.Parameter(sem::ParameterUsage::kTexture);
+ if (!texture) {
+ return nullptr;
+ }
+ auto* tex = texture->Type()->As<type::Texture>();  // NOTE(review): no null check - assumes a texture type
+ if (tex->dim() != type::TextureDimension::k1d) {
+ return nullptr;
+ }
+
+ if (builtin->Type() == builtin::Function::kTextureDimensions) {
+ // If this textureDimensions() call is in a CallStatement, we can leave it
+ // unmodified since the return value will be dropped on the floor anyway.
+ if (call->Stmt()->Declaration()->Is<CallStatement>()) {
+ return nullptr;
+ }
+ auto* new_call = ctx.CloneWithoutTransform(c);
+ return ctx.dst->MemberAccessor(new_call, "x");  // keep only the x component of the result
+ }
+
+ auto coords_index = signature.IndexOf(sem::ParameterUsage::kCoords);
+ if (coords_index == -1) {  // builtin has no coords parameter: nothing to widen
+ return nullptr;
+ }
+
+ utils::Vector<const Expression*, 8> args;
+ int index = 0;
+ for (auto* arg : c->args) {
+ if (index == coords_index) {  // widen the scalar coordinate to a 2-element vector
+ auto* ctype = call->Arguments()[static_cast<size_t>(coords_index)]->Type();
+ auto* coords = c->args[static_cast<size_t>(coords_index)];
+
+ const LiteralExpression* half = nullptr;  // the added second (y) component
+ if (ctype->is_integer_scalar()) {
+ half = ctx.dst->Expr(0_a);  // integer coordinate: y = 0
+ } else {
+ half = ctx.dst->Expr(0.5_a);  // otherwise y = 0.5 (presumably the centre of the single row)
+ }
+ args.Push(
+ ctx.dst->vec(CreateASTTypeFor(ctx, ctype), 2u, ctx.Clone(coords), half));
+ } else {
+ args.Push(ctx.Clone(arg));
+ }
+ index++;
+ }
+ return ctx.dst->Call(ctx.Clone(c->target), args);
+ });
+
+ ctx.Clone();
+ return Program(std::move(b));
+ }
+};
+
+Texture1DTo2D::Texture1DTo2D() = default;
+
+Texture1DTo2D::~Texture1DTo2D() = default;
+
+Transform::ApplyResult Texture1DTo2D::Apply(const Program* src, const DataMap&, DataMap&) const {
+ return State(src).Run();  // the DataMap inputs / outputs are unused by this transform
+}
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/texture_1d_to_2d.h b/src/tint/lang/wgsl/ast/transform/texture_1d_to_2d.h
new file mode 100644
index 0000000..32122cf
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/texture_1d_to_2d.h
@@ -0,0 +1,43 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_TEXTURE_1D_TO_2D_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_TEXTURE_1D_TO_2D_H_
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+
+namespace tint::ast::transform {
+
+/// This transform converts all 1D texture types and accesses to 2D.
+/// This is required for GLSL ES, which does not support 1D textures.
+class Texture1DTo2D final : public utils::Castable<Texture1DTo2D, Transform> {
+ public:
+ /// Constructor
+ Texture1DTo2D();
+
+ /// Destructor
+ ~Texture1DTo2D() override;
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+
+ private:
+ struct State;  // per-run implementation state, defined in texture_1d_to_2d.cc
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_TEXTURE_1D_TO_2D_H_
diff --git a/src/tint/lang/wgsl/ast/transform/texture_1d_to_2d_test.cc b/src/tint/lang/wgsl/ast/transform/texture_1d_to_2d_test.cc
new file mode 100644
index 0000000..52b17e2
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/texture_1d_to_2d_test.cc
@@ -0,0 +1,300 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/texture_1d_to_2d.h"
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using Texture1DTo2DTest = TransformTest;
+
+TEST_F(Texture1DTo2DTest, EmptyModule) {  // nothing to convert: the transform must be skipped
+ auto* src = "";
+
+ Transform::DataMap data;
+ EXPECT_FALSE(ShouldRun<Texture1DTo2D>(src, data));
+}
+
+TEST_F(Texture1DTo2DTest, Global1DDecl) {  // a texture_1d global is re-declared as texture_2d
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_1d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+)";
+ auto* expect = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+)";
+
+ Transform::DataMap data;
+ auto got = Run<Texture1DTo2D>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Texture1DTo2DTest, Global1DDeclAndSample) {  // float coord becomes vec2<f32>(coord, 0.5)
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_1d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+fn main() -> vec4<f32> {
+ return textureSample(t, s, 0.5);
+}
+)";
+ auto* expect = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+fn main() -> vec4<f32> {
+ return textureSample(t, s, vec2<f32>(0.5, 0.5));
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<Texture1DTo2D>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Texture1DTo2DTest, Global1DDeclAndLoad) {  // integer coord becomes vec2<i32>(coord, 0)
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_1d<f32>;
+
+fn main() -> vec4<f32> {
+ return textureLoad(t, 1, 0);
+}
+)";
+ auto* expect = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+fn main() -> vec4<f32> {
+ return textureLoad(t, vec2<i32>(1, 0), 0);
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<Texture1DTo2D>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Texture1DTo2DTest, Global1DDeclAndTextureDimensions) {  // result is narrowed with '.x'
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_1d<f32>;
+
+fn main() -> u32 {
+ return textureDimensions(t);
+}
+)";
+ auto* expect = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+fn main() -> u32 {
+ return textureDimensions(t).x;
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<Texture1DTo2D>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Texture1DTo2DTest, Global1DDeclAndTextureNumLevels) {  // no coords: call is unchanged
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_1d<f32>;
+
+fn main() -> u32 {
+ return textureNumLevels(t);
+}
+)";
+ auto* expect = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+fn main() -> u32 {
+ return textureNumLevels(t);
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<Texture1DTo2D>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Texture1DTo2DTest, Global1DDeclAndTextureDimensionsInCallStmt) {  // NOTE(review): see below
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_1d<f32>;
+
+fn main() {
+ _ = textureDimensions(t);
+}
+)";
+ auto* expect = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+fn main() {
+ _ = textureDimensions(t).x;
+}
+)";
+
+ Transform::DataMap data;  // NOTE(review): source is a phony assignment, not a call statement - confirm intent
+ auto got = Run<Texture1DTo2D>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Texture1DTo2DTest, GlobalStorage1DDecl) {  // 1D storage textures are converted too
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_storage_1d<r32float, write>;
+)";
+ auto* expect = R"(
+@group(0) @binding(0) var t : texture_storage_2d<r32float, write>;
+)";
+
+ Transform::DataMap data;
+ auto got = Run<Texture1DTo2D>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Texture1DTo2DTest, Global2DDeclAndSample) {  // only 2D textures: the transform must be skipped
+ auto* src = R"(
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
+fn main() -> vec4<f32> {
+ return textureSample(t, s, vec2<f32>(0.5, 1.5));
+}
+)";
+
+ Transform::DataMap data;
+ EXPECT_FALSE(ShouldRun<Texture1DTo2D>(src, data));
+}
+
+TEST_F(Texture1DTo2DTest, PrivateIntNoop) {  // non-texture global: the transform must be skipped
+ auto* src = R"(
+var<private> i : i32;
+)";
+
+ Transform::DataMap data;
+ EXPECT_FALSE(ShouldRun<Texture1DTo2D>(src, data));
+}
+
+TEST_F(Texture1DTo2DTest, GlobalMatrixNoop) {  // uniform matrix global: the transform must be skipped
+ auto* src = R"(
+@group(0) @binding(0) var<uniform> m : mat2x2<f32>;
+)";
+
+ Transform::DataMap data;
+ EXPECT_FALSE(ShouldRun<Texture1DTo2D>(src, data));
+}
+
+TEST_F(Texture1DTo2DTest, Texture1DFuncParam) {  // 1D texture function parameters are retyped as well
+ auto* src = R"(
+@group(0) @binding(0) var tex : texture_1d<f32>;
+
+@group(0) @binding(1) var samp : sampler;
+
+fn f(t : texture_1d<f32>, s : sampler) -> vec4<f32> {
+ return textureSample(t, s, 0.7);
+}
+
+fn main() -> vec4<f32> {
+ return f(tex, samp);
+}
+)";
+ auto* expect = R"(
+@group(0) @binding(0) var tex : texture_2d<f32>;
+
+@group(0) @binding(1) var samp : sampler;
+
+fn f(t : texture_2d<f32>, s : sampler) -> vec4<f32> {
+ return textureSample(t, s, vec2<f32>(0.69999999999999995559, 0.5));
+}
+
+fn main() -> vec4<f32> {
+ return f(tex, samp);
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<Texture1DTo2D>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Texture1DTo2DTest, TextureStorage1DFuncParam) {  // textureStore coord gains y = 0
+ auto* src = R"(
+@group(0) @binding(0) var tex : texture_storage_1d<rgba8unorm, write>;
+
+fn f(t : texture_storage_1d<rgba8unorm, write>) {
+ textureStore(t, 3, vec4<f32>(42.0, 21.0, 84.0, 10.5));
+}
+
+fn main() {
+ f(tex);
+}
+)";
+
+ auto* expect = R"(
+@group(0) @binding(0) var tex : texture_storage_2d<rgba8unorm, write>;
+
+fn f(t : texture_storage_2d<rgba8unorm, write>) {
+ textureStore(t, vec2<i32>(3, 0), vec4<f32>(42.0, 21.0, 84.0, 10.5));
+}
+
+fn main() {
+ f(tex);
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<Texture1DTo2D>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Texture1DTo2DTest, TextureAndNonTextureBuiltin) {  // sin() beside a texture call is untouched
+ auto* src = R"(
+@group(0) @binding(0) var tex : texture_1d<i32>;
+
+fn d() {
+ _ = textureLoad(tex, 1, 0);
+ let l = sin(3.0);
+}
+)";
+
+ auto* expect = R"(
+@group(0) @binding(0) var tex : texture_2d<i32>;
+
+fn d() {
+ _ = textureLoad(tex, vec2<i32>(1, 0), 0);
+ let l = sin(3.0);
+}
+)";
+
+ Transform::DataMap data;
+ auto got = Run<Texture1DTo2D>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/transform.cc b/src/tint/lang/wgsl/ast/transform/transform.cc
new file mode 100644
index 0000000..5bb87d7
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/transform.cc
@@ -0,0 +1,179 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+
+#include <algorithm>
+#include <string>
+
+#include "src/tint/builtin/builtin.h"
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/block_statement.h"
+#include "src/tint/sem/for_loop_statement.h"
+#include "src/tint/sem/variable.h"
+#include "src/tint/type/atomic.h"
+#include "src/tint/type/depth_multisampled_texture.h"
+#include "src/tint/type/reference.h"
+#include "src/tint/type/sampler.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::Transform);
+
+namespace tint::ast::transform {
+
+Output::Output() = default;
+Output::Output(Program&& p) : program(std::move(p)) {}  // moves `p` into the `program` member
+Transform::Transform() = default;
+Transform::~Transform() = default;
+
+Output Transform::Run(const Program* src, const DataMap& data /* = {} */) const {
+ Output output;
+ if (auto program = Apply(src, data, output.data)) {  // Apply() returned a transformed program
+ output.program = std::move(program.value());
+ } else {  // Apply() returned no program: output a plain clone of `src` instead
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+ ctx.Clone();
+ output.program = Program(std::move(b));
+ }
+ return output;
+}
+
+void Transform::RemoveStatement(CloneContext& ctx, const Statement* stmt) {
+ auto* sem = ctx.src->Sem().Get(stmt);
+ if (auto* block = tint::As<sem::BlockStatement>(sem->Parent())) {  // stmt is directly inside a block
+ ctx.Remove(block->Declaration()->statements, stmt);
+ return;
+ }
+ if (TINT_LIKELY(tint::Is<sem::ForLoopStatement>(sem->Parent()))) {  // stmt belongs to a for-loop
+ ctx.Replace(stmt, static_cast<Expression*>(nullptr));  // replaced with null rather than list-removed
+ return;
+ }
+ TINT_ICE(Transform, ctx.dst->Diagnostics())  // any other parent kind is reported as an internal error
+ << "unable to remove statement from parent of type " << sem->TypeInfo().name;
+}
+
+Type Transform::CreateASTTypeFor(CloneContext& ctx, const type::Type* ty) {
+ if (ty->Is<type::Void>()) {  // void maps to an empty Type
+ return Type{};
+ }
+ if (ty->Is<type::I32>()) {
+ return ctx.dst->ty.i32();
+ }
+ if (ty->Is<type::U32>()) {
+ return ctx.dst->ty.u32();
+ }
+ if (ty->Is<type::F16>()) {
+ return ctx.dst->ty.f16();
+ }
+ if (ty->Is<type::F32>()) {
+ return ctx.dst->ty.f32();
+ }
+ if (ty->Is<type::Bool>()) {
+ return ctx.dst->ty.bool_();
+ }
+ if (auto* m = ty->As<type::Matrix>()) {
+ auto el = CreateASTTypeFor(ctx, m->type());
+ return ctx.dst->ty.mat(el, m->columns(), m->rows());
+ }
+ if (auto* v = ty->As<type::Vector>()) {
+ auto el = CreateASTTypeFor(ctx, v->type());
+ if (v->Packed()) {  // packed vectors are asserted to be 3-wide and use the packed-vec3 builtin type
+ TINT_ASSERT(Transform, v->Width() == 3u);
+ return ctx.dst->ty(builtin::Builtin::kPackedVec3, el);
+ } else {
+ return ctx.dst->ty.vec(el, v->Width());
+ }
+ }
+ if (auto* a = ty->As<type::Array>()) {
+ auto el = CreateASTTypeFor(ctx, a->ElemType());
+ utils::Vector<const Attribute*, 1> attrs;
+ if (!a->IsStrideImplicit()) {  // only an explicit stride needs a stride attribute
+ attrs.Push(ctx.dst->create<StrideAttribute>(a->Stride()));
+ }
+ if (a->Count()->Is<type::RuntimeArrayCount>()) {
+ return ctx.dst->ty.array(el, std::move(attrs));
+ }
+ if (auto* override = a->Count()->As<sem::NamedOverrideArrayCount>()) {
+ auto* count = ctx.Clone(override->variable->Declaration());
+ return ctx.dst->ty.array(el, count, std::move(attrs));
+ }
+ if (auto* override = a->Count()->As<sem::UnnamedOverrideArrayCount>()) {
+ // If the array count is an unnamed (complex) override expression, then it's not safe to
+ // redeclare this type as we'd end up with two types that would not compare equal.
+ // See crbug.com/tint/1764.
+ // Look for a type alias for this array.
+ for (auto* type_decl : ctx.src->AST().TypeDecls()) {
+ if (auto* alias = type_decl->As<Alias>()) {
+ if (ty == ctx.src->Sem().Get(alias)) {
+ // Alias found. Use the alias name to ensure types compare equal.
+ return ctx.dst->ty(ctx.Clone(alias->name->symbol));
+ }
+ }
+ }
+ // Array is not aliased. Rebuild the array.
+ auto* count = ctx.Clone(override->expr->Declaration());
+ return ctx.dst->ty.array(el, count, std::move(attrs));
+ }
+ auto count = a->ConstantCount();
+ if (TINT_UNLIKELY(!count)) {
+ TINT_ICE(Transform, ctx.dst->Diagnostics()) << type::Array::kErrExpectedConstantCount;
+ return ctx.dst->ty.array(el, u32(1), std::move(attrs));  // fallback after the ICE: 1-element array
+ }
+ return ctx.dst->ty.array(el, u32(count.value()), std::move(attrs));
+ }
+ if (auto* s = ty->As<type::Struct>()) {
+ return ctx.dst->ty(ctx.Clone(s->Name()));  // structs are referred to by their (cloned) name
+ }
+ if (auto* s = ty->As<type::Reference>()) {
+ return CreateASTTypeFor(ctx, s->StoreType());  // references map to the type they store
+ }
+ if (auto* a = ty->As<type::Atomic>()) {
+ return ctx.dst->ty.atomic(CreateASTTypeFor(ctx, a->Type()));
+ }
+ if (auto* t = ty->As<type::DepthTexture>()) {
+ return ctx.dst->ty.depth_texture(t->dim());
+ }
+ if (auto* t = ty->As<type::DepthMultisampledTexture>()) {
+ return ctx.dst->ty.depth_multisampled_texture(t->dim());
+ }
+ if (ty->Is<type::ExternalTexture>()) {
+ return ctx.dst->ty.external_texture();
+ }
+ if (auto* t = ty->As<type::MultisampledTexture>()) {
+ return ctx.dst->ty.multisampled_texture(t->dim(), CreateASTTypeFor(ctx, t->type()));
+ }
+ if (auto* t = ty->As<type::SampledTexture>()) {
+ return ctx.dst->ty.sampled_texture(t->dim(), CreateASTTypeFor(ctx, t->type()));
+ }
+ if (auto* t = ty->As<type::StorageTexture>()) {
+ return ctx.dst->ty.storage_texture(t->dim(), t->texel_format(), t->access());
+ }
+ if (auto* s = ty->As<type::Sampler>()) {
+ return ctx.dst->ty.sampler(s->kind());
+ }
+ if (auto* p = ty->As<type::Pointer>()) {
+ // Note: type::Pointer always has an inferred access, but WGSL only allows an explicit
+ // access in the 'storage' address space.
+ auto address_space = p->AddressSpace();
+ auto access = address_space == builtin::AddressSpace::kStorage
+ ? p->Access()
+ : builtin::Access::kUndefined;
+ return ctx.dst->ty.ptr(address_space, CreateASTTypeFor(ctx, p->StoreType()), access);
+ }
+ TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())  // no case above matched `ty`
+ << "Unhandled type: " << ty->TypeInfo().name;
+ return Type{};
+}
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/transform.h b/src/tint/lang/wgsl/ast/transform/transform.h
new file mode 100644
index 0000000..0b37aee
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/transform.h
@@ -0,0 +1,100 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_TRANSFORM_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_TRANSFORM_H_
+
+#include "src/tint/transform/transform.h"
+
+#include <utility>
+
+#include "src/tint/program.h"
+#include "src/tint/utils/castable.h"
+
+namespace tint::ast::transform {
+
+/// The return type of Transform::Run(): the resulting program plus any extra output data
+class Output {
+ public:
+ /// Constructor
+ Output();
+
+ /// Constructor
+ /// @param program the program to move into this Output
+ explicit Output(Program&& program);
+
+ /// Constructor
+ /// @param program_ the program to move into this Output
+ /// @param data_ a variadic list of additional data unique_ptrs produced by
+ /// the transform; forwarded to the constructor of `data`
+ template <typename... DATA>
+ Output(Program&& program_, DATA... data_)
+ : program(std::move(program_)), data(std::forward<DATA>(data_)...) {}
+
+ /// The transformed program. May be empty on error.
+ Program program;
+
+ /// Extra output generated by the transforms.
+ tint::transform::DataMap data;
+};
+
+/// Interface for Program transforms
+class Transform : public utils::Castable<Transform, tint::transform::Transform> {
+ public:
+ /// Constructor
+ Transform();
+ /// Destructor
+ ~Transform() override;
+
+ /// Runs the transform on @p program, returning the transformation result or a clone of
+ /// @p program.
+ /// @param program the source program to transform
+ /// @param data optional extra transform-specific input data
+ /// @returns the transformation result
+ Output Run(const Program* program, const DataMap& data = {}) const;
+
+ /// The return value of Apply().
+ /// If SkipTransform (std::nullopt), then the transform did not need to be run.
+ using ApplyResult = std::optional<Program>;
+
+ /// Value returned from Apply() to indicate that the transform does not need to be run
+ static inline constexpr std::nullopt_t SkipTransform = std::nullopt;
+
+ /// Runs the transform on `program`, returning the transformed program, if any.
+ /// @param program the input program
+ /// @param inputs optional extra transform-specific input data
+ /// @param outputs optional extra transform-specific output data
+ /// @returns a transformed program, or std::nullopt if the transform didn't need to run.
+ virtual ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const = 0;
+
+ /// CreateASTTypeFor constructs new Type that reconstructs the semantic type `ty`.
+ /// @param ctx the clone context
+ /// @param ty the semantic type to reconstruct
+ /// @returns a Type that when resolved, will produce the semantic type `ty`.
+ static Type CreateASTTypeFor(CloneContext& ctx, const type::Type* ty);
+
+ protected:
+ /// Removes the statement `stmt` from the transformed program.
+ /// RemoveStatement handles edge cases, like statements in the initializer and
+ /// continuing of for-loops.
+ /// @param ctx the clone context
+ /// @param stmt the statement to remove when the program is cloned
+ static void RemoveStatement(CloneContext& ctx, const Statement* stmt);
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_TRANSFORM_H_
diff --git a/src/tint/lang/wgsl/ast/transform/transform_test.cc b/src/tint/lang/wgsl/ast/transform/transform_test.cc
new file mode 100644
index 0000000..3c24b35
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/transform_test.cc
@@ -0,0 +1,147 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+
+#include "src/tint/clone_context.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+#include "src/tint/program_builder.h"
+
+#include "gtest/gtest.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using namespace tint::number_suffixes; // NOLINT
+
+// Inherit from Transform so we have access to protected methods
+struct CreateASTTypeForTest : public testing::Test, public Transform {
+ ApplyResult Apply(const Program*, const DataMap&, DataMap&) const override {
+ return SkipTransform;
+ }
+
+ Type create(std::function<type::Type*(ProgramBuilder&)> create_sem_type) {
+ ProgramBuilder sem_type_builder;
+ auto* sem_type = create_sem_type(sem_type_builder);
+ Program program(std::move(sem_type_builder));
+ CloneContext ctx(&ast_type_builder, &program, false);
+ return CreateASTTypeFor(ctx, sem_type);
+ }
+
+ ProgramBuilder ast_type_builder;
+};
+
+TEST_F(CreateASTTypeForTest, Basic) {
+ auto check = [&](Type ty, const char* expect) { CheckIdentifier(ty->identifier, expect); };
+
+ check(create([](ProgramBuilder& b) { return b.create<type::I32>(); }), "i32");
+ check(create([](ProgramBuilder& b) { return b.create<type::U32>(); }), "u32");
+ check(create([](ProgramBuilder& b) { return b.create<type::F32>(); }), "f32");
+ check(create([](ProgramBuilder& b) { return b.create<type::Bool>(); }), "bool");
+ EXPECT_EQ(create([](ProgramBuilder& b) { return b.create<type::Void>(); }), nullptr);
+}
+
+TEST_F(CreateASTTypeForTest, Matrix) {
+ auto mat = create([](ProgramBuilder& b) {
+ auto* column_type = b.create<type::Vector>(b.create<type::F32>(), 2u);
+ return b.create<type::Matrix>(column_type, 3u);
+ });
+
+ CheckIdentifier(mat, Template("mat3x2", "f32"));
+}
+
+TEST_F(CreateASTTypeForTest, Vector) {
+ auto vec =
+ create([](ProgramBuilder& b) { return b.create<type::Vector>(b.create<type::F32>(), 2u); });
+
+ CheckIdentifier(vec, Template("vec2", "f32"));
+}
+
+TEST_F(CreateASTTypeForTest, ArrayImplicitStride) {
+ auto arr = create([](ProgramBuilder& b) {
+ return b.create<type::Array>(b.create<type::F32>(), b.create<type::ConstantArrayCount>(2u),
+ 4u, 4u, 32u, 32u);
+ });
+
+ CheckIdentifier(arr, Template("array", "f32", 2_u));
+ auto* tmpl_attr = arr->identifier->As<TemplatedIdentifier>();
+ ASSERT_NE(tmpl_attr, nullptr);
+ EXPECT_TRUE(tmpl_attr->attributes.IsEmpty());
+}
+
+TEST_F(CreateASTTypeForTest, ArrayNonImplicitStride) {
+ auto arr = create([](ProgramBuilder& b) {
+ return b.create<type::Array>(b.create<type::F32>(), b.create<type::ConstantArrayCount>(2u),
+ 4u, 4u, 64u, 32u);
+ });
+ CheckIdentifier(arr, Template("array", "f32", 2_u));
+ auto* tmpl_attr = arr->identifier->As<TemplatedIdentifier>();
+ ASSERT_NE(tmpl_attr, nullptr);
+ ASSERT_EQ(tmpl_attr->attributes.Length(), 1u);
+ ASSERT_TRUE(tmpl_attr->attributes[0]->Is<StrideAttribute>());
+ ASSERT_EQ(tmpl_attr->attributes[0]->As<StrideAttribute>()->stride, 64u);
+}
+
+// crbug.com/tint/1764
+TEST_F(CreateASTTypeForTest, AliasedArrayWithComplexOverrideLength) {
+ // override O = 123;
+ // type A = array<i32, O*2>;
+ //
+ // var<workgroup> W : A;
+ //
+ ProgramBuilder b;
+ auto* arr_len = b.Mul("O", 2_a);
+ b.Override("O", b.Expr(123_a));
+ auto* alias = b.Alias("A", b.ty.array(b.ty.i32(), arr_len));
+
+ Program program(std::move(b));
+
+ auto* arr_ty = program.Sem().Get(alias);
+
+ CloneContext ctx(&ast_type_builder, &program, false);
+ auto ast_ty = CreateASTTypeFor(ctx, arr_ty);
+ CheckIdentifier(ast_ty, "A");
+}
+
+TEST_F(CreateASTTypeForTest, Struct) {
+ auto str = create([](ProgramBuilder& b) {
+ auto* decl = b.Structure("S", {});
+ return b.create<sem::Struct>(decl, decl->name->symbol, utils::Empty, 4u /* align */,
+ 4u /* size */, 4u /* size_no_padding */);
+ });
+
+ CheckIdentifier(str, "S");
+}
+
+TEST_F(CreateASTTypeForTest, PrivatePointer) {
+ auto ptr = create([](ProgramBuilder& b) {
+ return b.create<type::Pointer>(builtin::AddressSpace::kPrivate, b.create<type::I32>(),
+ builtin::Access::kReadWrite);
+ });
+
+ CheckIdentifier(ptr, Template("ptr", "private", "i32"));
+}
+
+TEST_F(CreateASTTypeForTest, StorageReadWritePointer) {
+ auto ptr = create([](ProgramBuilder& b) {
+ return b.create<type::Pointer>(builtin::AddressSpace::kStorage, b.create<type::I32>(),
+ builtin::Access::kReadWrite);
+ });
+
+ CheckIdentifier(ptr, Template("ptr", "storage", "i32", "read_write"));
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/truncate_interstage_variables.cc b/src/tint/lang/wgsl/ast/transform/truncate_interstage_variables.cc
new file mode 100644
index 0000000..e63f471
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/truncate_interstage_variables.cc
@@ -0,0 +1,194 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/truncate_interstage_variables.h"
+
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/call.h"
+#include "src/tint/sem/function.h"
+#include "src/tint/sem/member_accessor_expression.h"
+#include "src/tint/sem/statement.h"
+#include "src/tint/sem/variable.h"
+#include "src/tint/utils/unicode.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::TruncateInterstageVariables);
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::TruncateInterstageVariables::Config);
+
+namespace tint::ast::transform {
+
+namespace {
+
+struct TruncatedStructAndConverter {
+ /// The symbol of the truncated structure.
+ Symbol truncated_struct;
+ /// The symbol of the helper function that takes the original structure as a single argument and
+ /// returns the truncated structure type.
+ Symbol truncate_fn;
+};
+
+} // anonymous namespace
+
+TruncateInterstageVariables::TruncateInterstageVariables() = default;
+TruncateInterstageVariables::~TruncateInterstageVariables() = default;
+
+Transform::ApplyResult TruncateInterstageVariables::Apply(const Program* src,
+ const DataMap& config,
+ DataMap&) const {
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+
+ const auto* data = config.Get<Config>();
+ if (data == nullptr) {
+ b.Diagnostics().add_error(
+ diag::System::Transform,
+ "missing transform data for " +
+ std::string(utils::TypeInfo::Of<TruncateInterstageVariables>().name));
+ return Program(std::move(b));
+ }
+
+ auto& sem = ctx.src->Sem();
+
+ bool should_run = false;
+
+ utils::Hashmap<const sem::Function*, Symbol, 4u> entry_point_functions_to_truncate_functions;
+ utils::Hashmap<const sem::Struct*, TruncatedStructAndConverter, 4u>
+ old_shader_io_structs_to_new_struct_and_truncate_functions;
+
+ for (auto* func_ast : ctx.src->AST().Functions()) {
+ if (!func_ast->IsEntryPoint()) {
+ continue;
+ }
+
+ if (func_ast->PipelineStage() != PipelineStage::kVertex) {
+ // Currently only the vertex stage can have interstage output variables that need
+ // truncating.
+ continue;
+ }
+
+ auto* func_sem = sem.Get(func_ast);
+ auto* str = func_sem->ReturnType()->As<sem::Struct>();
+
+ // This transform is run after the CanonicalizeEntryPointIO transform,
+ // so it is guaranteed that entry point outputs are already grouped in a struct.
+ if (TINT_UNLIKELY(!str)) {
+ TINT_ICE(Transform, ctx.dst->Diagnostics())
+ << "Entrypoint function return type is non-struct.\n"
+ << "TruncateInterstageVariables transform needs to run after "
+ "CanonicalizeEntryPointIO transform.";
+ continue;
+ }
+
+ // A prepass to check if any interstage variable locations in the entry point needs
+ // truncating. If not we don't really need to handle this entry point.
+ utils::Hashset<const sem::StructMember*, 16u> omit_members;
+
+ for (auto* member : str->Members()) {
+ if (auto location = member->Attributes().location) {
+ if (!data->interstage_locations.test(location.value())) {
+ omit_members.Add(member);
+ }
+ }
+ }
+
+ if (omit_members.IsEmpty()) {
+ continue;
+ }
+
+ // Now we are sure the transform needs to be run.
+ should_run = true;
+
+ // Get or create a new truncated struct/truncate function for the interstage inputs &
+ // outputs.
+ auto entry =
+ old_shader_io_structs_to_new_struct_and_truncate_functions.GetOrCreate(str, [&] {
+ auto new_struct_sym = b.Symbols().New();
+
+ utils::Vector<const StructMember*, 20> truncated_members;
+ utils::Vector<const Expression*, 20> initializer_exprs;
+
+ for (auto* member : str->Members()) {
+ if (omit_members.Contains(member)) {
+ continue;
+ }
+
+ truncated_members.Push(ctx.Clone(member->Declaration()));
+ initializer_exprs.Push(b.MemberAccessor("io", ctx.Clone(member->Name())));
+ }
+
+ // Create the new shader io struct.
+ b.Structure(new_struct_sym, std::move(truncated_members));
+
+ // Create the mapping function to truncate the shader io.
+ auto mapping_fn_sym = b.Symbols().New("truncate_shader_output");
+ b.Func(mapping_fn_sym,
+ utils::Vector{b.Param("io", ctx.Clone(func_ast->return_type))},
+ b.ty(new_struct_sym),
+ utils::Vector{
+ b.Return(b.Call(new_struct_sym, std::move(initializer_exprs))),
+ });
+ return TruncatedStructAndConverter{new_struct_sym, mapping_fn_sym};
+ });
+
+ ctx.Replace(func_ast->return_type.expr, b.Expr(entry.truncated_struct));
+
+ entry_point_functions_to_truncate_functions.Add(func_sem, entry.truncate_fn);
+ }
+
+ if (!should_run) {
+ return SkipTransform;
+ }
+
+ // Replace return statements with new truncated shader IO struct
+ ctx.ReplaceAll([&](const ReturnStatement* return_statement) -> const ReturnStatement* {
+ auto* return_sem = sem.Get(return_statement);
+ if (auto mapping_fn_sym =
+ entry_point_functions_to_truncate_functions.Find(return_sem->Function())) {
+ return b.Return(return_statement->source,
+ b.Call(*mapping_fn_sym, ctx.Clone(return_statement->value)));
+ }
+ return nullptr;
+ });
+
+ // Remove IO attributes from old shader IO struct which is not used as entry point output
+ // anymore.
+ for (auto it : old_shader_io_structs_to_new_struct_and_truncate_functions) {
+ const Struct* struct_ty = it.key->Declaration();
+ for (auto* member : struct_ty->members) {
+ for (auto* attr : member->attributes) {
+ if (attr->IsAnyOf<BuiltinAttribute, LocationAttribute, InterpolateAttribute,
+ InvariantAttribute>()) {
+ ctx.Remove(member->attributes, attr);
+ }
+ }
+ }
+ }
+
+ ctx.Clone();
+ return Program(std::move(b));
+}
+
+TruncateInterstageVariables::Config::Config() = default;
+
+TruncateInterstageVariables::Config::Config(const Config&) = default;
+
+TruncateInterstageVariables::Config::~Config() = default;
+
+TruncateInterstageVariables::Config& TruncateInterstageVariables::Config::operator=(const Config&) =
+ default;
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/truncate_interstage_variables.h b/src/tint/lang/wgsl/ast/transform/truncate_interstage_variables.h
new file mode 100644
index 0000000..2d91e67
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/truncate_interstage_variables.h
@@ -0,0 +1,131 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_TRUNCATE_INTERSTAGE_VARIABLES_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_TRUNCATE_INTERSTAGE_VARIABLES_H_
+
+#include <bitset>
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+#include "src/tint/sem/binding_point.h"
+
+namespace tint::ast::transform {
+
+/// TruncateInterstageVariables is a transform that truncates interstage variables.
+/// It must be run after CanonicalizeEntryPointIO which guarantees all interstage variables of
+/// a given entry point are grouped into one shader IO struct.
+/// It replaces `original shader IO struct` with a `new wrapper struct` containing builtin IOs
+/// and user-defined IO whose locations are marked in the interstage_locations bitset from the
+/// config. The return statements of `original shader IO struct` are wrapped by a mapping function
+/// that initializes the members of `new wrapper struct` with values from `original shader IO
+/// struct`. IO attributes of members in `original shader IO struct` are removed, other attributes
+/// are preserved.
+///
+/// For example:
+///
+/// ```
+/// struct ShaderIO {
+/// @builtin(position) @invariant pos: vec4<f32>,
+/// @location(1) f_1: f32,
+/// @location(3) @align(16) f_3: f32,
+/// @location(5) @interpolate(flat) @align(16) @size(16) f_5: u32,
+/// }
+/// @vertex
+/// fn f() -> ShaderIO {
+/// var io: ShaderIO;
+/// io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0);
+/// io.f_1 = 1.0;
+/// io.f_3 = io.f_1 + 3.0;
+/// io.f_5 = 1u;
+/// return io;
+/// }
+/// ```
+///
+/// With config.interstage_locations[3] and [5] set to true, is transformed to:
+///
+/// ```
+/// struct tint_symbol {
+/// @builtin(position) @invariant
+/// pos : vec4<f32>,
+/// @location(3) @align(16)
+/// f_3 : f32,
+/// @location(5) @interpolate(flat) @align(16) @size(16)
+/// f_5 : u32,
+/// }
+///
+/// fn truncate_shader_output(io : ShaderIO) -> tint_symbol {
+/// return tint_symbol(io.pos, io.f_3, io.f_5);
+/// }
+///
+/// struct ShaderIO {
+/// pos : vec4<f32>,
+/// f_1 : f32,
+/// @align(16)
+/// f_3 : f32,
+/// @align(16) @size(16)
+/// f_5 : u32,
+/// }
+///
+/// @vertex
+/// fn f() -> tint_symbol {
+/// var io : ShaderIO;
+/// io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0);
+/// io.f_1 = 1.0;
+/// io.f_3 = (io.f_1 + 3.0);
+/// io.f_5 = 1u;
+/// return truncate_shader_output(io);
+/// }
+/// ```
+///
+class TruncateInterstageVariables final
+ : public utils::Castable<TruncateInterstageVariables, Transform> {
+ public:
+ /// Configuration options for the transform
+ struct Config final : public utils::Castable<Config, Data> {
+ /// Constructor
+ Config();
+
+ /// Copy constructor
+ Config(const Config&);
+
+ /// Destructor
+ ~Config() override;
+
+ /// Assignment operator
+ /// @returns this Config
+ Config& operator=(const Config&);
+
+ /// Indicate which interstage io locations are actually used by the later stage.
+ /// There can be at most 16 user defined interstage variables with locations.
+ std::bitset<16> interstage_locations;
+
+ /// Reflect the fields of this class so that it can be used by tint::ForeachField()
+ TINT_REFLECT(interstage_locations);
+ };
+
+ /// Constructor. The configuration is read from the input DataMap passed to Apply().
+ TruncateInterstageVariables();
+
+ /// Destructor
+ ~TruncateInterstageVariables() override;
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_TRUNCATE_INTERSTAGE_VARIABLES_H_
diff --git a/src/tint/lang/wgsl/ast/transform/truncate_interstage_variables_test.cc b/src/tint/lang/wgsl/ast/transform/truncate_interstage_variables_test.cc
new file mode 100644
index 0000000..2a8c689
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/truncate_interstage_variables_test.cc
@@ -0,0 +1,598 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/truncate_interstage_variables.h"
+#include "src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.h"
+
+#include "gmock/gmock.h"
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using ::testing::ContainerEq;
+
+using TruncateInterstageVariablesTest = TransformTest;
+
+TEST_F(TruncateInterstageVariablesTest, ShouldRunVertex) {
+ auto* src = R"(
+struct ShaderIO {
+ @builtin(position) pos: vec4<f32>,
+ @location(0) f_0: f32,
+ @location(2) f_2: f32,
+}
+@vertex
+fn f() -> ShaderIO {
+ var io: ShaderIO;
+ io.f_0 = 1.0;
+ io.f_2 = io.f_2 + 3.0;
+ return io;
+}
+)";
+
+ {
+ auto* expect =
+ "error: missing transform data for tint::ast::transform::TruncateInterstageVariables";
+ auto got = Run<TruncateInterstageVariables>(src);
+ EXPECT_EQ(expect, str(got));
+ }
+
+ {
+ // Empty interstage_locations: truncate all interstage variables, should run
+ TruncateInterstageVariables::Config cfg;
+ Transform::DataMap data;
+ data.Add<TruncateInterstageVariables::Config>(cfg);
+ EXPECT_TRUE(ShouldRun<TruncateInterstageVariables>(src, data));
+ }
+
+ {
+ // All existing interstage_locations are marked: should not run
+ TruncateInterstageVariables::Config cfg;
+ cfg.interstage_locations[0] = true;
+ cfg.interstage_locations[2] = true;
+ Transform::DataMap data;
+ data.Add<TruncateInterstageVariables::Config>(cfg);
+ EXPECT_FALSE(ShouldRun<TruncateInterstageVariables>(src, data));
+ }
+
+ {
+ // Partial interstage_locations are marked: should run
+ TruncateInterstageVariables::Config cfg;
+ cfg.interstage_locations[2] = true;
+ Transform::DataMap data;
+ data.Add<TruncateInterstageVariables::Config>(cfg);
+ EXPECT_TRUE(ShouldRun<TruncateInterstageVariables>(src, data));
+ }
+}
+
+TEST_F(TruncateInterstageVariablesTest, ShouldRunFragment) {
+ auto* src = R"(
+struct ShaderIO {
+ @location(0) f_0: f32,
+ @location(2) f_2: f32,
+}
+@fragment
+fn f(io: ShaderIO) -> @location(1) vec4<f32> {
+ return vec4<f32>(io.f_0, io.f_2, 0.0, 1.0);
+}
+)";
+
+ TruncateInterstageVariables::Config cfg;
+ cfg.interstage_locations[2] = true;
+
+ Transform::DataMap data;
+ data.Add<TruncateInterstageVariables::Config>(cfg);
+
+ EXPECT_FALSE(ShouldRun<TruncateInterstageVariables>(src, data));
+}
+
+// Test that this transform should run after CanonicalizeEntryPointIO, where shader io is already
+// grouped into a struct.
+TEST_F(TruncateInterstageVariablesTest, ShouldRunAfterCanonicalizeEntryPointIO) {
+ auto* src = R"(
+@vertex
+fn f() -> @builtin(position) vec4<f32> {
+ return vec4<f32>(1.0, 1.0, 1.0, 1.0);
+}
+)";
+
+ TruncateInterstageVariables::Config cfg;
+ cfg.interstage_locations[0] = true;
+ Transform::DataMap data;
+ data.Add<TruncateInterstageVariables::Config>(cfg);
+
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ auto got = Run<CanonicalizeEntryPointIO>(src, data);
+
+ // Inevitably an entry point can write only one variable if it is not using a struct,
+ // so the truncate won't run.
+ EXPECT_FALSE(ShouldRun<TruncateInterstageVariables>(str(got), data));
+}
+
+TEST_F(TruncateInterstageVariablesTest, BasicVertexTrimLocationInMid) {
+ auto* src = R"(
+struct ShaderIO {
+ @builtin(position) pos: vec4<f32>,
+ @location(1) f_1: f32,
+ @location(3) f_3: f32,
+}
+@vertex
+fn f() -> ShaderIO {
+ var io: ShaderIO;
+ io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0);
+ io.f_1 = 1.0;
+ io.f_3 = io.f_1 + 3.0;
+ return io;
+}
+)";
+
+ auto* expect = R"(
+struct tint_symbol {
+ @builtin(position)
+ pos : vec4<f32>,
+ @location(1)
+ f_1 : f32,
+}
+
+fn truncate_shader_output(io : ShaderIO) -> tint_symbol {
+ return tint_symbol(io.pos, io.f_1);
+}
+
+struct ShaderIO {
+ pos : vec4<f32>,
+ f_1 : f32,
+ f_3 : f32,
+}
+
+@vertex
+fn f() -> tint_symbol {
+ var io : ShaderIO;
+ io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0);
+ io.f_1 = 1.0;
+ io.f_3 = (io.f_1 + 3.0);
+ return truncate_shader_output(io);
+}
+)";
+
+ TruncateInterstageVariables::Config cfg;
+ // fragment has input at @location(1)
+ cfg.interstage_locations[1] = true;
+
+ Transform::DataMap data;
+ data.Add<TruncateInterstageVariables::Config>(cfg);
+
+ auto got = Run<TruncateInterstageVariables>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(TruncateInterstageVariablesTest, BasicVertexTrimLocationAtEnd) {
+ auto* src = R"(
+struct ShaderIO {
+ @builtin(position) pos: vec4<f32>,
+ @location(1) f_1: f32,
+ @location(3) f_3: f32,
+}
+@vertex
+fn f() -> ShaderIO {
+ var io: ShaderIO;
+ io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0);
+ io.f_1 = 1.0;
+ io.f_3 = io.f_1 + 3.0;
+ return io;
+}
+)";
+
+ auto* expect = R"(
+struct tint_symbol {
+ @builtin(position)
+ pos : vec4<f32>,
+ @location(3)
+ f_3 : f32,
+}
+
+fn truncate_shader_output(io : ShaderIO) -> tint_symbol {
+ return tint_symbol(io.pos, io.f_3);
+}
+
+struct ShaderIO {
+ pos : vec4<f32>,
+ f_1 : f32,
+ f_3 : f32,
+}
+
+@vertex
+fn f() -> tint_symbol {
+ var io : ShaderIO;
+ io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0);
+ io.f_1 = 1.0;
+ io.f_3 = (io.f_1 + 3.0);
+ return truncate_shader_output(io);
+}
+)";
+
+ TruncateInterstageVariables::Config cfg;
+ // fragment has input at @location(3)
+ cfg.interstage_locations[3] = true;
+
+ Transform::DataMap data;
+ data.Add<TruncateInterstageVariables::Config>(cfg);
+
+ auto got = Run<TruncateInterstageVariables>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(TruncateInterstageVariablesTest, TruncateAllLocations) {
+ auto* src = R"(
+struct ShaderIO {
+ @builtin(position) pos: vec4<f32>,
+ @location(1) f_1: f32,
+ @location(3) f_3: f32,
+}
+@vertex
+fn f() -> ShaderIO {
+ var io: ShaderIO;
+ io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0);
+ io.f_1 = 1.0;
+ io.f_3 = io.f_1 + 3.0;
+ return io;
+}
+)";
+
+ {
+ auto* expect = R"(
+struct tint_symbol {
+ @builtin(position)
+ pos : vec4<f32>,
+}
+
+fn truncate_shader_output(io : ShaderIO) -> tint_symbol {
+ return tint_symbol(io.pos);
+}
+
+struct ShaderIO {
+ pos : vec4<f32>,
+ f_1 : f32,
+ f_3 : f32,
+}
+
+@vertex
+fn f() -> tint_symbol {
+ var io : ShaderIO;
+ io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0);
+ io.f_1 = 1.0;
+ io.f_3 = (io.f_1 + 3.0);
+ return truncate_shader_output(io);
+}
+)";
+
+ TruncateInterstageVariables::Config cfg;
+ Transform::DataMap data;
+ data.Add<TruncateInterstageVariables::Config>(cfg);
+
+ auto got = Run<TruncateInterstageVariables>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+ }
+}
+
+// Test that the transform only removes IO attributes and preserve other attributes in the old
+// Shader IO struct.
+TEST_F(TruncateInterstageVariablesTest, RemoveIOAttributes) {
+ auto* src = R"(
+struct ShaderIO {
+ @builtin(position) @invariant pos: vec4<f32>,
+ @location(1) f_1: f32,
+ @location(3) @align(16) f_3: f32,
+ @location(5) @interpolate(flat) @align(16) @size(16) f_5: u32,
+}
+@vertex
+fn f() -> ShaderIO {
+ var io: ShaderIO;
+ io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0);
+ io.f_1 = 1.0;
+ io.f_3 = io.f_1 + 3.0;
+ io.f_5 = 1u;
+ return io;
+}
+)";
+
+ {
+ auto* expect = R"(
+struct tint_symbol {
+ @builtin(position) @invariant
+ pos : vec4<f32>,
+ @location(3) @align(16)
+ f_3 : f32,
+ @location(5) @interpolate(flat) @align(16) @size(16)
+ f_5 : u32,
+}
+
+fn truncate_shader_output(io : ShaderIO) -> tint_symbol {
+ return tint_symbol(io.pos, io.f_3, io.f_5);
+}
+
+struct ShaderIO {
+ pos : vec4<f32>,
+ f_1 : f32,
+ @align(16)
+ f_3 : f32,
+ @align(16) @size(16)
+ f_5 : u32,
+}
+
+@vertex
+fn f() -> tint_symbol {
+ var io : ShaderIO;
+ io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0);
+ io.f_1 = 1.0;
+ io.f_3 = (io.f_1 + 3.0);
+ io.f_5 = 1u;
+ return truncate_shader_output(io);
+}
+)";
+
+ TruncateInterstageVariables::Config cfg;
+ // @location(1) is intentionally left unmarked to make sure the transform runs.
+ cfg.interstage_locations[3] = true;
+ cfg.interstage_locations[5] = true;
+
+ Transform::DataMap data;
+ data.Add<TruncateInterstageVariables::Config>(cfg);
+
+ auto got = Run<TruncateInterstageVariables>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+ }
+}
+
+TEST_F(TruncateInterstageVariablesTest, MultipleEntryPointsSharingStruct) {
+ auto* src = R"(
+struct ShaderIO {
+ @builtin(position) pos: vec4<f32>,
+ @location(1) f_1: f32,
+ @location(3) f_3: f32,
+ @location(5) f_5: f32,
+}
+
+@vertex
+fn f1() -> ShaderIO {
+ var io: ShaderIO;
+ io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0);
+ io.f_1 = 1.0;
+ io.f_3 = 1.0;
+ return io;
+}
+
+@vertex
+fn f2() -> ShaderIO {
+ var io: ShaderIO;
+ io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0);
+ io.f_5 = 2.0;
+ return io;
+}
+)";
+
+ auto* expect = R"(
+struct tint_symbol {
+ @builtin(position)
+ pos : vec4<f32>,
+ @location(3)
+ f_3 : f32,
+}
+
+fn truncate_shader_output(io : ShaderIO) -> tint_symbol {
+ return tint_symbol(io.pos, io.f_3);
+}
+
+struct ShaderIO {
+ pos : vec4<f32>,
+ f_1 : f32,
+ f_3 : f32,
+ f_5 : f32,
+}
+
+@vertex
+fn f1() -> tint_symbol {
+ var io : ShaderIO;
+ io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0);
+ io.f_1 = 1.0;
+ io.f_3 = 1.0;
+ return truncate_shader_output(io);
+}
+
+@vertex
+fn f2() -> tint_symbol {
+ var io : ShaderIO;
+ io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0);
+ io.f_5 = 2.0;
+ return truncate_shader_output(io);
+}
+)";
+
+ TruncateInterstageVariables::Config cfg;
+ // fragment has input at @location(3)
+ cfg.interstage_locations[3] = true;
+
+ Transform::DataMap data;
+ data.Add<TruncateInterstageVariables::Config>(cfg);
+
+ auto got = Run<TruncateInterstageVariables>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(TruncateInterstageVariablesTest, MultipleEntryPoints) {
+ auto* src = R"(
+struct ShaderIO1 {
+ @builtin(position) pos: vec4<f32>,
+ @location(1) f_1: f32,
+ @location(3) f_3: f32,
+ @location(5) f_5: f32,
+}
+
+@vertex
+fn f1() -> ShaderIO1 {
+ var io: ShaderIO1;
+ io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0);
+ io.f_1 = 1.0;
+ io.f_3 = 1.0;
+ return io;
+}
+
+struct ShaderIO2 {
+ @builtin(position) pos: vec4<f32>,
+ @location(2) f_2: f32,
+}
+
+@vertex
+fn f2() -> ShaderIO2 {
+ var io: ShaderIO2;
+ io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0);
+ io.f_2 = 2.0;
+ return io;
+}
+)";
+
+ auto* expect = R"(
+struct tint_symbol {
+ @builtin(position)
+ pos : vec4<f32>,
+ @location(3)
+ f_3 : f32,
+}
+
+fn truncate_shader_output(io : ShaderIO1) -> tint_symbol {
+ return tint_symbol(io.pos, io.f_3);
+}
+
+struct tint_symbol_1 {
+ @builtin(position)
+ pos : vec4<f32>,
+}
+
+fn truncate_shader_output_1(io : ShaderIO2) -> tint_symbol_1 {
+ return tint_symbol_1(io.pos);
+}
+
+struct ShaderIO1 {
+ pos : vec4<f32>,
+ f_1 : f32,
+ f_3 : f32,
+ f_5 : f32,
+}
+
+@vertex
+fn f1() -> tint_symbol {
+ var io : ShaderIO1;
+ io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0);
+ io.f_1 = 1.0;
+ io.f_3 = 1.0;
+ return truncate_shader_output(io);
+}
+
+struct ShaderIO2 {
+ pos : vec4<f32>,
+ f_2 : f32,
+}
+
+@vertex
+fn f2() -> tint_symbol_1 {
+ var io : ShaderIO2;
+ io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0);
+ io.f_2 = 2.0;
+ return truncate_shader_output_1(io);
+}
+)";
+
+ TruncateInterstageVariables::Config cfg;
+ // fragment has input at @location(3)
+ cfg.interstage_locations[3] = true;
+
+ Transform::DataMap data;
+ data.Add<TruncateInterstageVariables::Config>(cfg);
+
+ auto got = Run<TruncateInterstageVariables>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(TruncateInterstageVariablesTest, MultipleReturnStatements) {
+ auto* src = R"(
+struct ShaderIO {
+ @builtin(position) pos: vec4<f32>,
+ @location(1) f_1: f32,
+ @location(3) f_3: f32,
+}
+@vertex
+fn f(@builtin(vertex_index) vid: u32) -> ShaderIO {
+ var io: ShaderIO;
+ if (vid > 10u) {
+ io.f_1 = 2.0;
+ return io;
+ }
+ io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0);
+ io.f_1 = 1.0;
+ io.f_3 = io.f_1 + 3.0;
+ return io;
+}
+)";
+
+ auto* expect = R"(
+struct tint_symbol {
+ @builtin(position)
+ pos : vec4<f32>,
+ @location(3)
+ f_3 : f32,
+}
+
+fn truncate_shader_output(io : ShaderIO) -> tint_symbol {
+ return tint_symbol(io.pos, io.f_3);
+}
+
+struct ShaderIO {
+ pos : vec4<f32>,
+ f_1 : f32,
+ f_3 : f32,
+}
+
+@vertex
+fn f(@builtin(vertex_index) vid : u32) -> tint_symbol {
+ var io : ShaderIO;
+ if ((vid > 10u)) {
+ io.f_1 = 2.0;
+ return truncate_shader_output(io);
+ }
+ io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0);
+ io.f_1 = 1.0;
+ io.f_3 = (io.f_1 + 3.0);
+ return truncate_shader_output(io);
+}
+)";
+
+ TruncateInterstageVariables::Config cfg;
+ // fragment has input at @location(3)
+ cfg.interstage_locations[3] = true;
+
+ Transform::DataMap data;
+ data.Add<TruncateInterstageVariables::Config>(cfg);
+
+ auto got = Run<TruncateInterstageVariables>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/unshadow.cc b/src/tint/lang/wgsl/ast/transform/unshadow.cc
new file mode 100644
index 0000000..bf804d1
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/unshadow.cc
@@ -0,0 +1,130 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/unshadow.h"
+
+#include <memory>
+#include <unordered_map>
+#include <utility>
+
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/block_statement.h"
+#include "src/tint/sem/function.h"
+#include "src/tint/sem/statement.h"
+#include "src/tint/sem/variable.h"
+#include "src/tint/switch.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::Unshadow);
+
+namespace tint::ast::transform {
+
+/// PIMPL state for the transform
+struct Unshadow::State {
+ /// The source program
+ const Program* const src;
+ /// The target program builder
+ ProgramBuilder b;
+ /// The clone context
+ CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
+
+ /// Constructor
+ /// @param program the source program
+ explicit State(const Program* program) : src(program) {}
+
+ /// Runs the transform
+ /// @returns the new program or SkipTransform if the transform is not required
+ Transform::ApplyResult Run() {
+ auto& sem = src->Sem();
+
+ // Maps a variable to its new name.
+ utils::Hashmap<const sem::Variable*, Symbol, 8> renamed_to;
+
+ auto rename = [&](const sem::Variable* v) -> const Variable* {
+ auto* decl = v->Declaration();
+ auto name = decl->name->symbol.Name();
+ auto symbol = b.Symbols().New(name);
+ renamed_to.Add(v, symbol);
+
+ auto source = ctx.Clone(decl->source);
+ auto type = decl->type ? ctx.Clone(decl->type) : Type{};
+ auto* initializer = ctx.Clone(decl->initializer);
+ auto attributes = ctx.Clone(decl->attributes);
+ return Switch(
+ decl, //
+ [&](const Var* var) {
+ return b.Var(source, symbol, type, var->declared_address_space,
+ var->declared_access, initializer, attributes);
+ },
+ [&](const Let*) { return b.Let(source, symbol, type, initializer, attributes); },
+ [&](const Const*) {
+ return b.Const(source, symbol, type, initializer, attributes);
+ },
+ [&](const Parameter*) { //
+ return b.Param(source, symbol, type, attributes);
+ },
+ [&](Default) {
+ TINT_ICE(Transform, b.Diagnostics())
+ << "unexpected variable type: " << decl->TypeInfo().name;
+ return nullptr;
+ });
+ };
+
+ bool made_changes = false;
+
+ for (auto* node : ctx.src->SemNodes().Objects()) {
+ Switch(
+ node, //
+ [&](const sem::LocalVariable* local) {
+ if (local->Shadows()) {
+ ctx.Replace(local->Declaration(), [&, local] { return rename(local); });
+ made_changes = true;
+ }
+ },
+ [&](const sem::Parameter* param) {
+ if (param->Shadows()) {
+ ctx.Replace(param->Declaration(), [&, param] { return rename(param); });
+ made_changes = true;
+ }
+ });
+ }
+
+ if (!made_changes) {
+ return SkipTransform;
+ }
+
+ ctx.ReplaceAll([&](const IdentifierExpression* ident) -> const IdentifierExpression* {
+ if (auto* sem_ident = sem.GetVal(ident)) {
+ if (auto* user = sem_ident->Unwrap()->As<sem::VariableUser>()) {
+ if (auto renamed = renamed_to.Find(user->Variable())) {
+ return b.Expr(*renamed);
+ }
+ }
+ }
+ return nullptr;
+ });
+
+ ctx.Clone();
+ return Program(std::move(b));
+ }
+};
+
+Unshadow::Unshadow() = default;
+
+Unshadow::~Unshadow() = default;
+
+Transform::ApplyResult Unshadow::Apply(const Program* src, const DataMap&, DataMap&) const {
+ return State(src).Run();
+}
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/unshadow.h b/src/tint/lang/wgsl/ast/transform/unshadow.h
new file mode 100644
index 0000000..966140d
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/unshadow.h
@@ -0,0 +1,42 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_UNSHADOW_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_UNSHADOW_H_
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+
+namespace tint::ast::transform {
+
+/// Unshadow is a Transform that renames any variables that shadow another variable.
+class Unshadow final : public utils::Castable<Unshadow, Transform> {
+ public:
+ /// Constructor
+ Unshadow();
+
+ /// Destructor
+ ~Unshadow() override;
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+
+ private:
+ struct State;
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_UNSHADOW_H_
diff --git a/src/tint/lang/wgsl/ast/transform/unshadow_test.cc b/src/tint/lang/wgsl/ast/transform/unshadow_test.cc
new file mode 100644
index 0000000..a87f6b7
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/unshadow_test.cc
@@ -0,0 +1,803 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/unshadow.h"
+
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using UnshadowTest = TransformTest;
+
+TEST_F(UnshadowTest, EmptyModule) {
+ auto* src = "";
+
+ EXPECT_FALSE(ShouldRun<Unshadow>(src));
+}
+
+TEST_F(UnshadowTest, Noop) {
+ auto* src = R"(
+var<private> a : i32;
+
+const b : i32 = 1;
+
+fn F(c : i32) {
+ var d : i32;
+ let e : i32 = 1;
+ const f : i32 = 2;
+ {
+ var g : i32;
+ let h : i32 = 1;
+ const i : i32 = 2;
+ }
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<Unshadow>(src));
+}
+
+TEST_F(UnshadowTest, LocalShadowsAlias) {
+ auto* src = R"(
+alias a = i32;
+
+fn X() {
+ var a = false;
+}
+
+fn Y() {
+ let a = true;
+}
+
+fn Z() {
+ const a = true;
+}
+)";
+
+ auto* expect = R"(
+alias a = i32;
+
+fn X() {
+ var a_1 = false;
+}
+
+fn Y() {
+ let a_2 = true;
+}
+
+fn Z() {
+ const a_3 = true;
+}
+)";
+
+ auto got = Run<Unshadow>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(UnshadowTest, LocalShadowsAlias_OutOfOrder) {
+ auto* src = R"(
+fn X() {
+ var a = false;
+}
+
+fn Y() {
+ let a = true;
+}
+
+fn Z() {
+ const a = true;
+}
+
+alias a = i32;
+)";
+
+ auto* expect = R"(
+fn X() {
+ var a_1 = false;
+}
+
+fn Y() {
+ let a_2 = true;
+}
+
+fn Z() {
+ const a_3 = true;
+}
+
+alias a = i32;
+)";
+
+ auto got = Run<Unshadow>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(UnshadowTest, LocalShadowsStruct) {
+ auto* src = R"(
+struct a {
+ m : i32,
+};
+
+fn X() {
+ var a = true;
+}
+
+fn Y() {
+ let a = false;
+}
+
+fn Z() {
+ const a = false;
+}
+)";
+
+ auto* expect = R"(
+struct a {
+ m : i32,
+}
+
+fn X() {
+ var a_1 = true;
+}
+
+fn Y() {
+ let a_2 = false;
+}
+
+fn Z() {
+ const a_3 = false;
+}
+)";
+
+ auto got = Run<Unshadow>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(UnshadowTest, LocalShadowsStruct_OutOfOrder) {
+ auto* src = R"(
+fn X() {
+ var a = true;
+}
+
+fn Y() {
+ let a = false;
+}
+
+fn Z() {
+ const a = false;
+}
+
+struct a {
+ m : i32,
+};
+
+)";
+
+ auto* expect = R"(
+fn X() {
+ var a_1 = true;
+}
+
+fn Y() {
+ let a_2 = false;
+}
+
+fn Z() {
+ const a_3 = false;
+}
+
+struct a {
+ m : i32,
+}
+)";
+
+ auto got = Run<Unshadow>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(UnshadowTest, LocalShadowsFunction) {
+ auto* src = R"(
+fn a() {
+ var a = true;
+ var b = false;
+ var c = true;
+}
+
+fn b() {
+ let a = true;
+ let b = false;
+ let c = true;
+}
+
+fn c() {
+ const a = true;
+ const b = false;
+ const c = true;
+}
+)";
+
+ auto* expect = R"(
+fn a() {
+ var a_1 = true;
+ var b_1 = false;
+ var c_1 = true;
+}
+
+fn b() {
+ let a_2 = true;
+ let b_2 = false;
+ let c_2 = true;
+}
+
+fn c() {
+ const a_3 = true;
+ const b_3 = false;
+ const c_3 = true;
+}
+)";
+
+ auto got = Run<Unshadow>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(UnshadowTest, LocalShadowsFunction_OutOfOrder) {
+ auto* src = R"(
+fn b() {
+ let a = true;
+ let b = false;
+ let c = true;
+}
+
+fn a() {
+ var a = true;
+ var b = false;
+ var c = true;
+}
+
+fn c() {
+ const a = true;
+ const b = false;
+ const c = true;
+}
+)";
+
+ auto* expect = R"(
+fn b() {
+ let a_1 = true;
+ let b_1 = false;
+ let c_1 = true;
+}
+
+fn a() {
+ var a_2 = true;
+ var b_2 = false;
+ var c_2 = true;
+}
+
+fn c() {
+ const a_3 = true;
+ const b_3 = false;
+ const c_3 = true;
+}
+)";
+
+ auto got = Run<Unshadow>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(UnshadowTest, LocalShadowsGlobalVar) {
+ auto* src = R"(
+var<private> a : i32;
+
+fn X() {
+ var a = (a == 123);
+}
+
+fn Y() {
+ let a = (a == 321);
+}
+
+fn Z() {
+ const a = 321;
+}
+)";
+
+ auto* expect = R"(
+var<private> a : i32;
+
+fn X() {
+ var a_1 = (a == 123);
+}
+
+fn Y() {
+ let a_2 = (a == 321);
+}
+
+fn Z() {
+ const a_3 = 321;
+}
+)";
+
+ auto got = Run<Unshadow>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(UnshadowTest, LocalShadowsGlobalVar_OutOfOrder) {
+ auto* src = R"(
+fn X() {
+ var a = (a == 123);
+}
+
+fn Y() {
+ let a = (a == 321);
+}
+
+fn Z() {
+ const a = 321;
+}
+
+var<private> a : i32;
+)";
+
+ auto* expect = R"(
+fn X() {
+ var a_1 = (a == 123);
+}
+
+fn Y() {
+ let a_2 = (a == 321);
+}
+
+fn Z() {
+ const a_3 = 321;
+}
+
+var<private> a : i32;
+)";
+
+ auto got = Run<Unshadow>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(UnshadowTest, LocalShadowsGlobalConst) {
+ auto* src = R"(
+const a : i32 = 1;
+
+fn X() {
+ var a = (a == 123);
+}
+
+fn Y() {
+ let a = (a == 321);
+}
+
+fn Z() {
+ const a = 321;
+}
+)";
+
+ auto* expect = R"(
+const a : i32 = 1;
+
+fn X() {
+ var a_1 = (a == 123);
+}
+
+fn Y() {
+ let a_2 = (a == 321);
+}
+
+fn Z() {
+ const a_3 = 321;
+}
+)";
+
+ auto got = Run<Unshadow>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(UnshadowTest, LocalShadowsGlobalConst_OutOfOrder) {
+ auto* src = R"(
+fn X() {
+ var a = (a == 123);
+}
+
+fn Y() {
+ let a = (a == 321);
+}
+
+fn Z() {
+ const a = a;
+}
+
+const a : i32 = 1;
+)";
+
+ auto* expect = R"(
+fn X() {
+ var a_1 = (a == 123);
+}
+
+fn Y() {
+ let a_2 = (a == 321);
+}
+
+fn Z() {
+ const a_3 = a;
+}
+
+const a : i32 = 1;
+)";
+
+ auto got = Run<Unshadow>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(UnshadowTest, LocalShadowsLocalVar) {
+ auto* src = R"(
+fn X() {
+ var a : i32;
+ {
+ var a = (a == 123);
+ }
+ {
+ let a = (a == 321);
+ }
+ {
+ const a = 321;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn X() {
+ var a : i32;
+ {
+ var a_1 = (a == 123);
+ }
+ {
+ let a_2 = (a == 321);
+ }
+ {
+ const a_3 = 321;
+ }
+}
+)";
+
+ auto got = Run<Unshadow>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(UnshadowTest, LocalShadowsLocalLet) {
+ auto* src = R"(
+fn X() {
+ let a = 1;
+ {
+ var a = (a == 123);
+ }
+ {
+ let a = (a == 321);
+ }
+ {
+ const a = 321;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn X() {
+ let a = 1;
+ {
+ var a_1 = (a == 123);
+ }
+ {
+ let a_2 = (a == 321);
+ }
+ {
+ const a_3 = 321;
+ }
+}
+)";
+
+ auto got = Run<Unshadow>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(UnshadowTest, LocalShadowsLocalConst) {
+ auto* src = R"(
+fn X() {
+ const a = 1;
+ {
+ var a = (a == 123);
+ }
+ {
+ let a = (a == 321);
+ }
+ {
+ const a = a;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn X() {
+ const a = 1;
+ {
+ var a_1 = (a == 123);
+ }
+ {
+ let a_2 = (a == 321);
+ }
+ {
+ const a_3 = a;
+ }
+}
+)";
+
+ auto got = Run<Unshadow>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(UnshadowTest, LocalShadowsParam) {
+ auto* src = R"(
+fn F(a : i32) {
+ {
+ var a = (a == 123);
+ }
+ {
+ let a = (a == 321);
+ }
+ {
+ const a = 321;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn F(a : i32) {
+ {
+ var a_1 = (a == 123);
+ }
+ {
+ let a_2 = (a == 321);
+ }
+ {
+ const a_3 = 321;
+ }
+}
+)";
+
+ auto got = Run<Unshadow>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(UnshadowTest, ParamShadowsFunction) {
+ auto* src = R"(
+fn a(a : i32) {
+ {
+ var a = (a == 123);
+ }
+ {
+ let a = (a == 321);
+ }
+ {
+ const a = 321;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn a(a_1 : i32) {
+ {
+ var a_2 = (a_1 == 123);
+ }
+ {
+ let a_3 = (a_1 == 321);
+ }
+ {
+ const a_4 = 321;
+ }
+}
+)";
+
+ auto got = Run<Unshadow>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(UnshadowTest, ParamShadowsGlobalVar) {
+ auto* src = R"(
+var<private> a : i32;
+
+fn F(a : bool) {
+}
+)";
+
+ auto* expect = R"(
+var<private> a : i32;
+
+fn F(a_1 : bool) {
+}
+)";
+
+ auto got = Run<Unshadow>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(UnshadowTest, ParamShadowsGlobalConst) {
+ auto* src = R"(
+const a : i32 = 1;
+
+fn F(a : bool) {
+}
+)";
+
+ auto* expect = R"(
+const a : i32 = 1;
+
+fn F(a_1 : bool) {
+}
+)";
+
+ auto got = Run<Unshadow>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(UnshadowTest, ParamShadowsGlobalConst_OutOfOrder) {
+ auto* src = R"(
+fn F(a : bool) {
+}
+
+const a : i32 = 1;
+)";
+
+ auto* expect = R"(
+fn F(a_1 : bool) {
+}
+
+const a : i32 = 1;
+)";
+
+ auto got = Run<Unshadow>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(UnshadowTest, ParamShadowsAlias) {
+ auto* src = R"(
+alias a = i32;
+
+fn F(a : a) {
+ {
+ var a = (a == 123);
+ }
+ {
+ let a = (a == 321);
+ }
+}
+)";
+
+ auto* expect = R"(
+alias a = i32;
+
+fn F(a_1 : a) {
+ {
+ var a_2 = (a_1 == 123);
+ }
+ {
+ let a_3 = (a_1 == 321);
+ }
+}
+)";
+
+ auto got = Run<Unshadow>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(UnshadowTest, ParamShadowsAlias_OutOfOrder) {
+ auto* src = R"(
+fn F(a : a) {
+ {
+ var a = (a == 123);
+ }
+ {
+ let a = (a == 321);
+ }
+}
+
+alias a = i32;
+)";
+
+ auto* expect = R"(
+fn F(a_1 : a) {
+ {
+ var a_2 = (a_1 == 123);
+ }
+ {
+ let a_3 = (a_1 == 321);
+ }
+}
+
+alias a = i32;
+)";
+
+ auto got = Run<Unshadow>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(UnshadowTest, RenamedVarHasUsers) {
+ auto* src = R"(
+fn F() {
+ var a : bool;
+ {
+ var a : i32;
+ var b = a + 1;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn F() {
+ var a : bool;
+ {
+ var a_1 : i32;
+ var b = (a_1 + 1);
+ }
+}
+)";
+
+ auto got = Run<Unshadow>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(UnshadowTest, RenamedAbstractConstHasUsers) {
+ auto* src = R"(
+fn v() {
+ const v = 1;
+ let x = v;
+}
+)";
+
+ auto* expect = R"(
+fn v() {
+ const v_1 = 1;
+ let x = v_1;
+}
+)";
+
+ auto got = Run<Unshadow>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/utils/get_insertion_point.cc b/src/tint/lang/wgsl/ast/transform/utils/get_insertion_point.cc
new file mode 100644
index 0000000..32ba34b
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/utils/get_insertion_point.cc
@@ -0,0 +1,58 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/utils/get_insertion_point.h"
+#include "src/tint/debug.h"
+#include "src/tint/diagnostic/diagnostic.h"
+#include "src/tint/sem/for_loop_statement.h"
+#include "src/tint/switch.h"
+
+namespace tint::ast::transform::utils {
+
+InsertionPoint GetInsertionPoint(CloneContext& ctx, const Statement* stmt) {
+ auto& sem = ctx.src->Sem();
+ auto& diag = ctx.dst->Diagnostics();
+ using RetType = std::pair<const sem::BlockStatement*, const Statement*>;
+
+ if (auto* sem_stmt = sem.Get(stmt)) {
+ auto* parent = sem_stmt->Parent();
+ return Switch(
+ parent,
+ [&](const sem::BlockStatement* block) -> RetType {
+ // Common case, can insert in the current block above/below the input
+ // statement.
+ return {block, stmt};
+ },
+ [&](const sem::ForLoopStatement* fl) -> RetType {
+ // `stmt` is either the for loop initializer or the continuing
+ // statement of a for-loop.
+ if (fl->Declaration()->initializer == stmt) {
+ // For loop init, can insert above the for loop itself.
+ return {fl->Block(), fl->Declaration()};
+ }
+
+ // Cannot insert before or after the continuing statement: return a {null, null} pair.
+ return {};
+ },
+ [&](Default) -> RetType {
+ TINT_ICE(Transform, diag) << "expected parent of statement to be "
+ "either a block or for loop";
+ return {};
+ });
+ }
+
+ return {};  // `stmt` has no semantic statement: no insertion point ({null, null}).
+}
+
+} // namespace tint::ast::transform::utils
diff --git a/src/tint/lang/wgsl/ast/transform/utils/get_insertion_point.h b/src/tint/lang/wgsl/ast/transform/utils/get_insertion_point.h
new file mode 100644
index 0000000..687e155
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/utils/get_insertion_point.h
@@ -0,0 +1,39 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_UTILS_GET_INSERTION_POINT_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_UTILS_GET_INSERTION_POINT_H_
+
+#include <utility>
+
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/block_statement.h"
+
+namespace tint::ast::transform::utils {
+
+/// InsertionPoint is a pair of the block (`first`) within which, and the
+/// statement (`second`) before or after which to insert.
+using InsertionPoint = std::pair<const sem::BlockStatement*, const Statement*>;
+
+/// For the input statement, returns the block and statement within that
+/// block to insert before/after. If `stmt` is a for-loop continuing statement,
+/// the function returns {nullptr, nullptr} as we cannot insert before/after it.
+/// @param ctx the clone context
+/// @param stmt the statement to insert before or after
+/// @return the insertion point
+InsertionPoint GetInsertionPoint(CloneContext& ctx, const Statement* stmt);
+
+} // namespace tint::ast::transform::utils
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_UTILS_GET_INSERTION_POINT_H_
diff --git a/src/tint/lang/wgsl/ast/transform/utils/get_insertion_point_test.cc b/src/tint/lang/wgsl/ast/transform/utils/get_insertion_point_test.cc
new file mode 100644
index 0000000..8ec261e
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/utils/get_insertion_point_test.cc
@@ -0,0 +1,96 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <utility>
+
+#include "gtest/gtest-spi.h"
+#include "src/tint/debug.h"
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+#include "src/tint/lang/wgsl/ast/transform/utils/get_insertion_point.h"
+#include "src/tint/program_builder.h"
+
+using namespace tint::number_suffixes; // NOLINT
+
+namespace tint::ast::transform {
+namespace {
+
+using GetInsertionPointTest = ::testing::Test;
+
+TEST_F(GetInsertionPointTest, Block) {
+ // fn f() {
+ // var a = 1i;
+ // }
+ ProgramBuilder b;
+ auto* expr = b.Expr(1_i);
+ auto* var = b.Decl(b.Var("a", expr));
+ auto* block = b.Block(var);
+ b.Func("f", tint::utils::Empty, b.ty.void_(), tint::utils::Vector{block});
+
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
+
+ // Can insert in block containing the variable, above or below the input
+ // statement.
+ auto ip = utils::GetInsertionPoint(ctx, var);
+ ASSERT_EQ(ip.first->Declaration(), block);
+ ASSERT_EQ(ip.second, var);
+}
+
+TEST_F(GetInsertionPointTest, ForLoopInit) {
+ // fn f() {
+ // for(var a = 1i; true; ) {
+ // }
+ // }
+ ProgramBuilder b;
+ auto* expr = b.Expr(1_i);
+ auto* var = b.Decl(b.Var("a", expr));
+ auto* fl = b.For(var, b.Expr(true), nullptr, b.Block());
+ auto* func_block = b.Block(fl);
+ b.Func("f", tint::utils::Empty, b.ty.void_(), tint::utils::Vector{func_block});
+
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
+
+ // Can insert in block containing for-loop above the for-loop itself.
+ auto ip = utils::GetInsertionPoint(ctx, var);
+ ASSERT_EQ(ip.first->Declaration(), func_block);
+ ASSERT_EQ(ip.second, fl);
+}
+
+TEST_F(GetInsertionPointTest, ForLoopCont_Invalid) {
+ // fn f() {
+ // for(; true; var a = 1i) {
+ // }
+ // }
+ ProgramBuilder b;
+ auto* expr = b.Expr(1_i);
+ auto* var = b.Decl(b.Var("a", expr));
+ auto* s = b.For({}, b.Expr(true), var, b.Block());
+ b.Func("f", tint::utils::Empty, b.ty.void_(), tint::utils::Vector{s});
+
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
+
+ // Can't insert before/after a for-loop continuing statement (would need to be
+ // converted to loop).
+ auto ip = utils::GetInsertionPoint(ctx, var);
+ ASSERT_EQ(ip.first, nullptr);
+ ASSERT_EQ(ip.second, nullptr);
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/utils/hoist_to_decl_before.cc b/src/tint/lang/wgsl/ast/transform/utils/hoist_to_decl_before.cc
new file mode 100644
index 0000000..799c790
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/utils/hoist_to_decl_before.cc
@@ -0,0 +1,442 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/utils/hoist_to_decl_before.h"
+
+#include <utility>
+
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/block_statement.h"
+#include "src/tint/sem/for_loop_statement.h"
+#include "src/tint/sem/if_statement.h"
+#include "src/tint/sem/variable.h"
+#include "src/tint/sem/while_statement.h"
+#include "src/tint/type/reference.h"
+#include "src/tint/utils/hashmap.h"
+#include "src/tint/utils/reverse.h"
+#include "src/tint/utils/transform.h"
+
+namespace tint::ast::transform {
+
+/// Private implementation of HoistToDeclBefore transform
+struct HoistToDeclBefore::State {
+ /// Constructor
+ /// @param ctx_in the clone context
+ explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst) {}
+
+ /// @copydoc HoistToDeclBefore::Add()
+ bool Add(const sem::ValueExpression* before_expr,
+ const Expression* expr,
+ VariableKind kind,
+ const char* decl_name) {
+ auto name = b.Symbols().New(decl_name);
+
+ switch (kind) {
+ case VariableKind::kLet: {
+ auto* ty = ctx.src->Sem().GetVal(expr)->Type();
+ TINT_ASSERT(Transform, !ty->HoldsAbstract());
+ auto builder = [this, expr, name, ty] {
+ return b.Decl(b.Let(name, Transform::CreateASTTypeFor(ctx, ty),
+ ctx.CloneWithoutTransform(expr)));
+ };
+ if (!InsertBeforeImpl(before_expr->Stmt(), std::move(builder))) {
+ return false;
+ }
+ break;
+ }
+
+ case VariableKind::kVar: {
+ auto* ty = ctx.src->Sem().GetVal(expr)->Type();
+ TINT_ASSERT(Transform, !ty->HoldsAbstract());
+ auto builder = [this, expr, name, ty] {
+ return b.Decl(b.Var(name, Transform::CreateASTTypeFor(ctx, ty),
+ ctx.CloneWithoutTransform(expr)));
+ };
+ if (!InsertBeforeImpl(before_expr->Stmt(), std::move(builder))) {
+ return false;
+ }
+ break;
+ }
+
+ case VariableKind::kConst: {
+ auto builder = [this, expr, name] {
+ return b.Decl(b.Const(name, ctx.CloneWithoutTransform(expr)));
+ };
+ if (!InsertBeforeImpl(before_expr->Stmt(), std::move(builder))) {
+ return false;
+ }
+ break;
+ }
+ }
+
+ // Replace the source expression with a reference to the hoisted declaration.
+ ctx.Replace(expr, b.Expr(name));
+ return true;
+ }
+
+ /// @copydoc HoistToDeclBefore::InsertBefore(const sem::Statement*, const Statement*)
+ bool InsertBefore(const sem::Statement* before_stmt, const Statement* stmt) {
+ if (stmt) {
+ auto builder = [stmt] { return stmt; };
+ return InsertBeforeImpl(before_stmt, std::move(builder));
+ }
+ return InsertBeforeImpl(before_stmt, Decompose{});
+ }
+
+ /// @copydoc HoistToDeclBefore::InsertBefore(const sem::Statement*, const StmtBuilder&)
+ bool InsertBefore(const sem::Statement* before_stmt, const StmtBuilder& builder) {
+ return InsertBeforeImpl(before_stmt, std::move(builder));  // NOTE(review): `builder` is a const&, so std::move yields a const rvalue (no actual move) - confirm intent
+ }
+
+ /// @copydoc HoistToDeclBefore::Replace(const sem::Statement* what, const Statement* with)
+ bool Replace(const sem::Statement* what, const Statement* with) {
+ auto builder = [with] { return with; };
+ return Replace(what, std::move(builder));
+ }
+
+ /// @copydoc HoistToDeclBefore::Replace(const sem::Statement* what, const StmtBuilder& with)
+ bool Replace(const sem::Statement* what, const StmtBuilder& with) {
+ if (!InsertBeforeImpl(what, Decompose{})) {
+ return false;
+ }
+ ctx.Replace(what->Declaration(), with);
+ return true;
+ }
+
+ /// @copydoc HoistToDeclBefore::Prepare()
+ bool Prepare(const sem::ValueExpression* before_expr) {
+ return InsertBefore(before_expr->Stmt(), nullptr);
+ }
+
+ private:
+ CloneContext& ctx;
+ ProgramBuilder& b;
+
+ /// Holds information about a for-loop or while-loop that needs to be decomposed
+ /// into a loop, so that statements can be inserted before the loop's
+ /// initializer, condition expression or continuing statement.
+ struct LoopInfo {
+ utils::Vector<StmtBuilder, 8> init_decls;  // built ahead of the for-loop initializer
+ utils::Vector<StmtBuilder, 8> cond_decls;  // built at the top of the loop body, before the condition check
+ utils::Vector<StmtBuilder, 8> cont_decls;  // built at the start of the continuing block
+ };
+
+ /// Info for each else-if that needs decomposing
+ struct ElseIfInfo {
+ /// Decls to insert before condition
+ utils::Vector<StmtBuilder, 8> cond_decls;
+ };
+
+ /// For-loops that need to be decomposed to loops.
+ utils::Hashmap<const sem::ForLoopStatement*, LoopInfo, 4> for_loops;
+
+ /// Whiles that need to be decomposed to loops.
+ utils::Hashmap<const sem::WhileStatement*, LoopInfo, 4> while_loops;
+
+ /// 'else if' statements that need to be decomposed to 'else {if}'
+ utils::Hashmap<const IfStatement*, ElseIfInfo, 4> else_ifs;
+
+ template <size_t N>
+ static auto Build(const utils::Vector<StmtBuilder, N>& builders) {
+ return utils::Transform(builders, [&](auto& builder) { return builder(); });
+ }
+
+ /// @returns a new LoopInfo reference for the given @p for_loop.
+ /// @note if this is the first call to this method, then RegisterForLoopTransform() is
+ /// automatically called.
+ /// @warning the returned reference is invalid if this is called a second time, or the
+ /// #for_loops map is mutated.
+ auto ForLoop(const sem::ForLoopStatement* for_loop) {
+ if (for_loops.IsEmpty()) {
+ RegisterForLoopTransform();
+ }
+ return for_loops.GetOrZero(for_loop);
+ }
+
+ /// @returns a new LoopInfo reference for the given @p while_loop.
+ /// @note if this is the first call to this method, then RegisterWhileLoopTransform() is
+ /// automatically called.
+ /// @warning the returned reference is invalid if this is called a second time, or the
+ /// #while_loops map is mutated.
+ auto WhileLoop(const sem::WhileStatement* while_loop) {
+ if (while_loops.IsEmpty()) {
+ RegisterWhileLoopTransform();
+ }
+ return while_loops.GetOrZero(while_loop);
+ }
+
+ /// @returns a new ElseIfInfo reference for the given @p else_if.
+ /// @note if this is the first call to this method, then RegisterElseIfTransform() is
+ /// automatically called.
+ /// @warning the returned reference is invalid if this is called a second time, or the
+ /// #else_ifs map is mutated.
+ auto ElseIf(const IfStatement* else_if) {
+ if (else_ifs.IsEmpty()) {
+ RegisterElseIfTransform();
+ }
+ return else_ifs.GetOrZero(else_if);
+ }
+
+ /// Registers the handler for transforming for-loops based on the content of the #for_loops map.
+ void RegisterForLoopTransform() const {
+ ctx.ReplaceAll([&](const ForLoopStatement* stmt) -> const Statement* {
+ auto& sem = ctx.src->Sem();
+
+ if (auto* fl = sem.Get(stmt)) {
+ if (auto info = for_loops.Find(fl)) {
+ auto* for_loop = fl->Declaration();
+ // For-loop needs to be decomposed to a loop.
+ // Build the loop body's statements.
+ // Start with any let declarations for the conditional expression.
+ auto body_stmts = Build(info->cond_decls);
+ // If the for-loop has a condition, emit this next as:
+ // if (!cond) { break; }
+ if (auto* cond = for_loop->condition) {
+ // !condition
+ auto* not_cond =
+ b.create<UnaryOpExpression>(UnaryOp::kNot, ctx.Clone(cond));
+ // { break; }
+ auto* break_body = b.Block(b.create<BreakStatement>());
+ // if (!condition) { break; }
+ body_stmts.Push(b.If(not_cond, break_body));
+ }
+ // Next emit the for-loop body
+ body_stmts.Push(ctx.Clone(for_loop->body));
+
+ // Create the continuing block if there was one.
+ const BlockStatement* continuing = nullptr;
+ if (auto* cont = for_loop->continuing) {
+ // Continuing block starts with any let declarations used by
+ // the continuing.
+ auto cont_stmts = Build(info->cont_decls);
+ cont_stmts.Push(ctx.Clone(cont));
+ continuing = b.Block(cont_stmts);
+ }
+
+ auto* body = b.Block(body_stmts);
+ auto* loop = b.Loop(body, continuing);
+
+ // If the loop has no initializer statements, then we're done.
+ // Otherwise, wrap loop with another block, prefixed with the initializer
+ // statements
+ if (!info->init_decls.IsEmpty() || for_loop->initializer) {
+ auto stmts = Build(info->init_decls);
+ if (auto* init = for_loop->initializer) {
+ stmts.Push(ctx.Clone(init));
+ }
+ stmts.Push(loop);
+ return b.Block(std::move(stmts));
+ }
+ return loop;
+ }
+ }
+ return nullptr;
+ });
+ }
+
+ /// Registers the handler for transforming while-loops based on the content of the #while_loops
+ /// map.
+ void RegisterWhileLoopTransform() const {
+ // At least one while needs to be transformed into a loop.
+ ctx.ReplaceAll([&](const WhileStatement* stmt) -> const Statement* {
+ auto& sem = ctx.src->Sem();
+
+ if (auto* w = sem.Get(stmt)) {
+ if (auto info = while_loops.Find(w)) {
+ auto* while_loop = w->Declaration();
+ // While needs to be decomposed to a loop.
+ // Build the loop body's statements.
+ // Start with any statements hoisted before the conditional
+ // expression (each builder is invoked to produce its statement).
+ auto body_stmts = utils::Transform(info->cond_decls,
+ [&](auto& builder) { return builder(); });
+ // Emit the condition as:
+ // if (!cond) { break; }
+ auto* cond = while_loop->condition;
+ // !condition
+ auto* not_cond = b.Not(ctx.Clone(cond));
+ // { break; }
+ auto* break_body = b.Block(b.Break());
+ // if (!condition) { break; }
+ body_stmts.Push(b.If(not_cond, break_body));
+
+ // Next emit the body
+ body_stmts.Push(ctx.Clone(while_loop->body));
+
+ const BlockStatement* continuing = nullptr;
+
+ auto* body = b.Block(body_stmts);
+ auto* loop = b.Loop(body, continuing);
+ return loop;
+ }
+ }
+ return nullptr;
+ });
+ }
+
+ /// Registers the handler for transforming if-statements based on the content of the #else_ifs
+ /// map.
+ void RegisterElseIfTransform() const {
+ // Decompose 'else-if' statements into 'else { if }' blocks.
+ ctx.ReplaceAll([&](const IfStatement* stmt) -> const Statement* {
+ if (auto info = else_ifs.Find(stmt)) {
+ // Build the else block's body statements, starting with any statements
+ // hoisted before the conditional expression.
+ auto body_stmts = Build(info->cond_decls);
+
+ // Move the 'else-if' into the new `else` block as a plain 'if'.
+ auto* cond = ctx.Clone(stmt->condition);
+ auto* body = ctx.Clone(stmt->body);
+ auto* new_if = b.If(cond, body, b.Else(ctx.Clone(stmt->else_statement)));
+ body_stmts.Push(new_if);
+
+ // Replace the 'else-if' with the new 'else' block.
+ return b.Block(body_stmts);
+ }
+ return nullptr;
+ });
+ }
+
+ /// A type used to signal to InsertBeforeImpl that no insertion should take place - instead flow
+ /// control statements should just be decomposed.
+ struct Decompose {};
+
+ template <typename BUILDER>
+ bool InsertBeforeImpl(const sem::Statement* before_stmt, BUILDER&& builder) {
+ (void)builder; // Avoid 'unused parameter' warning due to 'if constexpr'
+
+ auto* ip = before_stmt->Declaration();
+
+ auto* else_if = before_stmt->As<sem::IfStatement>();
+ if (else_if && else_if->Parent()->Is<sem::IfStatement>()) {
+ // Insertion point is an 'else if' condition.
+ // Need to convert 'else if' to 'else { if }'.
+ auto else_if_info = ElseIf(else_if->Declaration());
+
+ // Index the map to decompose this else-if, even when BUILDER is Decompose.
+ auto& decls = else_if_info->cond_decls;
+ if constexpr (!std::is_same_v<BUILDER, Decompose>) {
+ decls.Push(std::forward<BUILDER>(builder));
+ }
+ return true;
+ }
+
+ if (auto* fl = before_stmt->As<sem::ForLoopStatement>()) {
+ // Insertion point is a for-loop condition.
+ // For-loop needs to be decomposed to a loop.
+
+ // Index the map to decompose this for-loop, even when BUILDER is Decompose.
+ auto& decls = ForLoop(fl)->cond_decls;
+ if constexpr (!std::is_same_v<BUILDER, Decompose>) {
+ decls.Push(std::forward<BUILDER>(builder));
+ }
+ return true;
+ }
+
+ if (auto* w = before_stmt->As<sem::WhileStatement>()) {
+ // Insertion point is a while condition.
+ // While needs to be decomposed to a loop.
+
+ // Index the map to decompose this while, even when BUILDER is Decompose.
+ auto& decls = WhileLoop(w)->cond_decls;
+ if constexpr (!std::is_same_v<BUILDER, Decompose>) {
+ decls.Push(std::forward<BUILDER>(builder));
+ }
+ return true;
+ }
+
+ auto* parent = before_stmt->Parent(); // The statement's parent
+ if (auto* block = parent->As<sem::BlockStatement>()) {
+ // Insert point sits in a block. Simple case.
+ // Insert the new statement before `ip` in the block's statement list.
+ if constexpr (!std::is_same_v<BUILDER, Decompose>) {
+ ctx.InsertBefore(block->Declaration()->statements, ip,
+ std::forward<BUILDER>(builder));
+ }
+ return true;
+ }
+
+ auto* fl = parent->As<sem::ForLoopStatement>();
+ if (TINT_LIKELY(fl)) {
+ // Insertion point is a for-loop initializer or continuing statement.
+ // These require special care.
+ if (fl->Declaration()->initializer == ip) {
+ // Insertion point is a for-loop initializer.
+ // For-loop needs to be decomposed to a loop.
+
+ // Index the map to decompose this for-loop, even when BUILDER is Decompose.
+ auto& decls = ForLoop(fl)->init_decls;
+ if constexpr (!std::is_same_v<BUILDER, Decompose>) {
+ decls.Push(std::forward<BUILDER>(builder));
+ }
+
+ return true;
+ }
+
+ if (TINT_LIKELY(fl->Declaration()->continuing == ip)) {
+ // Insertion point is a for-loop continuing statement.
+ // For-loop needs to be decomposed to a loop.
+
+ // Index the map to decompose this for-loop, even when BUILDER is Decompose.
+ auto& decls = ForLoop(fl)->cont_decls;
+ if constexpr (!std::is_same_v<BUILDER, Decompose>) {
+ decls.Push(std::forward<BUILDER>(builder));
+ }
+
+ return true;
+ }
+
+ TINT_ICE(Transform, b.Diagnostics()) << "unhandled use of expression in for-loop";
+ return false;
+ }
+
+ TINT_ICE(Transform, b.Diagnostics())
+ << "unhandled expression parent statement type: " << parent->TypeInfo().name;
+ return false;
+ }
+};
+
+HoistToDeclBefore::HoistToDeclBefore(CloneContext& ctx) : state_(std::make_unique<State>(ctx)) {}
+
+HoistToDeclBefore::~HoistToDeclBefore() {}
+
+bool HoistToDeclBefore::Add(const sem::ValueExpression* before_expr,
+ const Expression* expr,
+ VariableKind kind,
+ const char* decl_name) {
+ return state_->Add(before_expr, expr, kind, decl_name);
+}
+
+bool HoistToDeclBefore::InsertBefore(const sem::Statement* before_stmt, const Statement* stmt) {
+ return state_->InsertBefore(before_stmt, stmt);
+}
+
+bool HoistToDeclBefore::InsertBefore(const sem::Statement* before_stmt,
+ const StmtBuilder& builder) {
+ return state_->InsertBefore(before_stmt, builder);
+}
+
+bool HoistToDeclBefore::Replace(const sem::Statement* what, const Statement* with) {
+ return state_->Replace(what, with);
+}
+
+bool HoistToDeclBefore::Replace(const sem::Statement* what, const StmtBuilder& with) {
+ return state_->Replace(what, with);
+}
+
+bool HoistToDeclBefore::Prepare(const sem::ValueExpression* before_expr) {
+ return state_->Prepare(before_expr);
+}
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/utils/hoist_to_decl_before.h b/src/tint/lang/wgsl/ast/transform/utils/hoist_to_decl_before.h
new file mode 100644
index 0000000..7c55042
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/utils/hoist_to_decl_before.h
@@ -0,0 +1,107 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_UTILS_HOIST_TO_DECL_BEFORE_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_UTILS_HOIST_TO_DECL_BEFORE_H_
+
+#include <functional>
+#include <memory>
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+#include "src/tint/sem/value_expression.h"
+
+namespace tint::ast::transform {
+
+/// Utility class that can be used to hoist expressions before other
+/// expressions, possibly converting 'for-loop's and 'while's to 'loop's and
+/// 'else-if's to 'else {if}'s.
+class HoistToDeclBefore {
+ public:
+ /// Constructor
+ /// @param ctx the clone context
+ explicit HoistToDeclBefore(CloneContext& ctx);
+
+ /// Destructor
+ ~HoistToDeclBefore();
+
+ /// StmtBuilder is a builder of an AST statement
+ using StmtBuilder = std::function<const Statement*()>;
+
+ /// VariableKind is either a var, let or const
+ enum class VariableKind {
+ kVar,
+ kLet,
+ kConst,
+ };
+
+ /// Hoists @p expr to a `var`, `let` or `const` with optional `decl_name`, inserting it
+ /// before @p before_expr.
+ /// @param before_expr expression to insert `expr` before
+ /// @param expr expression to hoist
+ /// @param kind variable kind to hoist to
+ /// @param decl_name optional name to use for the variable/constant name
+ /// @return true on success
+ bool Add(const sem::ValueExpression* before_expr,
+ const Expression* expr,
+ VariableKind kind,
+ const char* decl_name = "");
+
+ /// Inserts @p stmt before @p before_stmt, possibly converting 'for-loop's to 'loop's if
+ /// necessary.
+ /// @warning If the container of @p before_stmt is cloned multiple times, then the resolver will
+ /// ICE as the same statement cannot be shared.
+ /// @param before_stmt statement to insert @p stmt before
+ /// @param stmt statement to insert
+ /// @return true on success
+ bool InsertBefore(const sem::Statement* before_stmt, const Statement* stmt);
+
+ /// Inserts the returned statement of @p builder before @p before_stmt, possibly converting
+ /// 'for-loop's to 'loop's if necessary.
+ /// @note If the container of @p before_stmt is cloned multiple times, then @p builder will be
+ /// called for each clone.
+ /// @param before_stmt the preceding statement that the statement of @p builder will be inserted
+ /// before
+ /// @param builder the statement builder used to create the new statement
+ /// @return true on success
+ bool InsertBefore(const sem::Statement* before_stmt, const StmtBuilder& builder);
+
+ /// Replaces the statement @p what with the statement @p with, possibly converting 'for-loop's
+ /// to 'loop's if necessary.
+ /// @param what the statement to replace
+ /// @param with the replacement statement
+ /// @return true on success
+ bool Replace(const sem::Statement* what, const Statement* with);
+
+ /// Replaces the statement @p what with the statement returned by @p with, possibly converting
+ /// 'for-loop's to 'loop's if necessary.
+ /// @param what the statement to replace
+ /// @param with the replacement statement builder
+ /// @return true on success
+ bool Replace(const sem::Statement* what, const StmtBuilder& with);
+
+ /// Use to signal that we plan on hoisting a decl before `before_expr`. This
+ /// will convert 'for-loop's to 'loop's and 'else-if's to 'else {if}'s if
+ /// needed.
+ /// @param before_expr expression we would hoist a decl before
+ /// @return true on success
+ bool Prepare(const sem::ValueExpression* before_expr);
+
+ private:
+ struct State;
+ std::unique_ptr<State> state_;
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_UTILS_HOIST_TO_DECL_BEFORE_H_
diff --git a/src/tint/lang/wgsl/ast/transform/utils/hoist_to_decl_before_test.cc b/src/tint/lang/wgsl/ast/transform/utils/hoist_to_decl_before_test.cc
new file mode 100644
index 0000000..f2c1bbc
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/utils/hoist_to_decl_before_test.cc
@@ -0,0 +1,1140 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <utility>
+
+#include "gtest/gtest-spi.h"
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+#include "src/tint/lang/wgsl/ast/transform/utils/hoist_to_decl_before.h"
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/if_statement.h"
+#include "src/tint/sem/index_accessor_expression.h"
+#include "src/tint/sem/statement.h"
+
+using namespace tint::number_suffixes; // NOLINT
+
+namespace tint::ast::transform {
+namespace {
+
+using HoistToDeclBeforeTest = ::testing::Test;
+
+TEST_F(HoistToDeclBeforeTest, VarInit) {
+ // fn f() {
+ // var a = 1i;
+ // }
+ ProgramBuilder b;
+ auto* expr = b.Expr(1_i);
+ auto* var = b.Decl(b.Var("a", expr));
+ b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{var});
+
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
+
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* sem_expr = ctx.src->Sem().Get(expr);
+ hoistToDeclBefore.Add(sem_expr, expr, HoistToDeclBefore::VariableKind::kLet);
+
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
+
+ auto* expect = R"(
+fn f() {
+ let tint_symbol : i32 = 1i;
+ var a = tint_symbol;
+}
+)";
+
+ EXPECT_EQ(expect, str(cloned));
+}
+
+TEST_F(HoistToDeclBeforeTest, ForLoopInit) {
+ // fn f() {
+ // for(var a = 1i; true; ) {
+ // }
+ // }
+ ProgramBuilder b;
+ auto* expr = b.Expr(1_i);
+ auto* s = b.For(b.Decl(b.Var("a", expr)), b.Expr(true), nullptr, b.Block());
+ b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{s});
+
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
+
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* sem_expr = ctx.src->Sem().Get(expr);
+ hoistToDeclBefore.Add(sem_expr, expr, HoistToDeclBefore::VariableKind::kVar);
+
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
+
+ auto* expect = R"(
+fn f() {
+ {
+ var tint_symbol : i32 = 1i;
+ var a = tint_symbol;
+ loop {
+ if (!(true)) {
+ break;
+ }
+ {
+ }
+ }
+ }
+}
+)";
+
+ EXPECT_EQ(expect, str(cloned));
+}
+
+TEST_F(HoistToDeclBeforeTest, ForLoopCond) {
+ // fn f() {
+ // const a = true;
+ // for(; a; ) {
+ // }
+ // }
+ ProgramBuilder b;
+ auto* var = b.Decl(b.Const("a", b.Expr(true)));
+ auto* expr = b.Expr("a");
+ auto* s = b.For(nullptr, expr, nullptr, b.Block());
+ b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{var, s});
+
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
+
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* sem_expr = ctx.src->Sem().GetVal(expr);
+ hoistToDeclBefore.Add(sem_expr, expr, HoistToDeclBefore::VariableKind::kConst);
+
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
+
+ auto* expect = R"(
+fn f() {
+ const a = true;
+ loop {
+ const tint_symbol = a;
+ if (!(tint_symbol)) {
+ break;
+ }
+ {
+ }
+ }
+}
+)";
+
+ EXPECT_EQ(expect, str(cloned));
+}
+
+TEST_F(HoistToDeclBeforeTest, ForLoopCont) {
+ // fn f() {
+ // for(; true; var a = 1i) {
+ // }
+ // }
+ ProgramBuilder b;
+ auto* expr = b.Expr(1_i);
+ auto* s = b.For(nullptr, b.Expr(true), b.Decl(b.Var("a", expr)), b.Block());
+ b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{s});
+
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
+
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* sem_expr = ctx.src->Sem().Get(expr);
+ hoistToDeclBefore.Add(sem_expr, expr, HoistToDeclBefore::VariableKind::kLet);
+
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
+
+ auto* expect = R"(
+fn f() {
+ loop {
+ if (!(true)) {
+ break;
+ }
+ {
+ }
+
+ continuing {
+ let tint_symbol : i32 = 1i;
+ var a = tint_symbol;
+ }
+ }
+}
+)";
+
+ EXPECT_EQ(expect, str(cloned));
+}
+
+TEST_F(HoistToDeclBeforeTest, WhileCond) {
+ // fn f() {
+ // var a : bool;
+ // while(a) {
+ // }
+ // }
+ ProgramBuilder b;
+ auto* var = b.Decl(b.Var("a", b.ty.bool_()));
+ auto* expr = b.Expr("a");
+ auto* s = b.While(expr, b.Block());
+ b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{var, s});
+
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
+
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* sem_expr = ctx.src->Sem().GetVal(expr);
+ hoistToDeclBefore.Add(sem_expr, expr, HoistToDeclBefore::VariableKind::kVar);
+
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
+
+ auto* expect = R"(
+fn f() {
+ var a : bool;
+ loop {
+ var tint_symbol : bool = a;
+ if (!(tint_symbol)) {
+ break;
+ }
+ {
+ }
+ }
+}
+)";
+
+ EXPECT_EQ(expect, str(cloned));
+}
+
+TEST_F(HoistToDeclBeforeTest, ElseIf) {
+ // fn f() {
+ // const a = true;
+ // if (true) {
+ // } else if (a) {
+ // } else {
+ // }
+ // }
+ ProgramBuilder b;
+ auto* var = b.Decl(b.Const("a", b.Expr(true)));
+ auto* expr = b.Expr("a");
+ auto* s = b.If(b.Expr(true), b.Block(), //
+ b.Else(b.If(expr, b.Block(), //
+ b.Else(b.Block()))));
+ b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{var, s});
+
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
+
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* sem_expr = ctx.src->Sem().GetVal(expr);
+ hoistToDeclBefore.Add(sem_expr, expr, HoistToDeclBefore::VariableKind::kConst);
+
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
+
+ auto* expect = R"(
+fn f() {
+ const a = true;
+ if (true) {
+ } else {
+ const tint_symbol = a;
+ if (tint_symbol) {
+ } else {
+ }
+ }
+}
+)";
+
+ EXPECT_EQ(expect, str(cloned));
+}
+
+TEST_F(HoistToDeclBeforeTest, Array1D) {
+ // fn f() {
+ // var a : array<i32, 10>;
+ // var b = a[0];
+ // }
+ ProgramBuilder b;
+ auto* var1 = b.Decl(b.Var("a", b.ty.array<i32, 10>()));
+ auto* expr = b.IndexAccessor("a", 0_i);
+ auto* var2 = b.Decl(b.Var("b", expr));
+ b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{var1, var2});
+
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
+
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* sem_expr = ctx.src->Sem().Get(expr);
+ hoistToDeclBefore.Add(sem_expr, expr, HoistToDeclBefore::VariableKind::kLet);
+
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
+
+ auto* expect = R"(
+fn f() {
+ var a : array<i32, 10u>;
+ let tint_symbol : i32 = a[0i];
+ var b = tint_symbol;
+}
+)";
+
+ EXPECT_EQ(expect, str(cloned));
+}
+
+TEST_F(HoistToDeclBeforeTest, Array2D) {
+ // fn f() {
+ // var a : array<array<i32, 10>, 10>;
+ // var b = a[0][0];
+ // }
+ ProgramBuilder b;
+
+ auto* var1 = b.Decl(b.Var("a", b.ty.array(b.ty.array<i32, 10>(), 10_i)));
+ auto* expr = b.IndexAccessor(b.IndexAccessor("a", 0_i), 0_i);
+ auto* var2 = b.Decl(b.Var("b", expr));
+ b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{var1, var2});
+
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
+
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* sem_expr = ctx.src->Sem().Get(expr);
+ hoistToDeclBefore.Add(sem_expr, expr, HoistToDeclBefore::VariableKind::kVar);
+
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
+
+ auto* expect = R"(
+fn f() {
+ var a : array<array<i32, 10u>, 10i>;
+ var tint_symbol : i32 = a[0i][0i];
+ var b = tint_symbol;
+}
+)";
+
+ EXPECT_EQ(expect, str(cloned));
+}
+
+TEST_F(HoistToDeclBeforeTest, Prepare_ForLoopCond) {
+ // fn f() {
+ // var a : bool;
+ // for(; a; ) {
+ // }
+ // }
+ ProgramBuilder b;
+ auto* var = b.Decl(b.Var("a", b.ty.bool_()));
+ auto* expr = b.Expr("a");
+ auto* s = b.For(nullptr, expr, nullptr, b.Block());
+ b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{var, s});
+
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
+
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* sem_expr = ctx.src->Sem().GetVal(expr);
+ hoistToDeclBefore.Prepare(sem_expr);
+
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
+
+ auto* expect = R"(
+fn f() {
+ var a : bool;
+ loop {
+ if (!(a)) {
+ break;
+ }
+ {
+ }
+ }
+}
+)";
+
+ EXPECT_EQ(expect, str(cloned));
+}
+
+TEST_F(HoistToDeclBeforeTest, Prepare_ForLoopCont) {
+ // fn f() {
+ // for(; true; var a = 1i) {
+ // }
+ // }
+ ProgramBuilder b;
+ auto* expr = b.Expr(1_i);
+ auto* s = b.For(nullptr, b.Expr(true), b.Decl(b.Var("a", expr)), b.Block());
+ b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{s});
+
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
+
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* sem_expr = ctx.src->Sem().Get(expr);
+ hoistToDeclBefore.Prepare(sem_expr);
+
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
+
+ auto* expect = R"(
+fn f() {
+ loop {
+ if (!(true)) {
+ break;
+ }
+ {
+ }
+
+ continuing {
+ var a = 1i;
+ }
+ }
+}
+)";
+
+ EXPECT_EQ(expect, str(cloned));
+}
+
+TEST_F(HoistToDeclBeforeTest, Prepare_ElseIf) {
+ // fn f() {
+ // var a : bool;
+ // if (true) {
+ // } else if (a) {
+ // } else {
+ // }
+ // }
+ ProgramBuilder b;
+ auto* var = b.Decl(b.Var("a", b.ty.bool_()));
+ auto* expr = b.Expr("a");
+ auto* s = b.If(b.Expr(true), b.Block(), //
+ b.Else(b.If(expr, b.Block(), //
+ b.Else(b.Block()))));
+ b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{var, s});
+
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
+
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* sem_expr = ctx.src->Sem().GetVal(expr);
+ hoistToDeclBefore.Prepare(sem_expr);
+
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
+
+ auto* expect = R"(
+fn f() {
+ var a : bool;
+ if (true) {
+ } else {
+ if (a) {
+ } else {
+ }
+ }
+}
+)";
+
+ EXPECT_EQ(expect, str(cloned));
+}
+
+TEST_F(HoistToDeclBeforeTest, InsertBefore_Block) {
+ // fn foo() {
+ // }
+ // fn f() {
+ // var a = 1i;
+ // }
+ ProgramBuilder b;
+ b.Func("foo", utils::Empty, b.ty.void_(), utils::Empty);
+ auto* var = b.Decl(b.Var("a", b.Expr(1_i)));
+ b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{var});
+
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
+
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* before_stmt = ctx.src->Sem().Get(var);
+ auto* new_stmt = ctx.dst->CallStmt(ctx.dst->Call("foo"));
+ hoistToDeclBefore.InsertBefore(before_stmt, new_stmt);
+
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
+
+ auto* expect = R"(
+fn foo() {
+}
+
+fn f() {
+ foo();
+ var a = 1i;
+}
+)";
+
+ EXPECT_EQ(expect, str(cloned));
+}
+
+TEST_F(HoistToDeclBeforeTest, InsertBefore_Block_Function) {
+ // fn foo() {
+ // }
+ // fn f() {
+ // var a = 1i;
+ // }
+ ProgramBuilder b;
+ b.Func("foo", utils::Empty, b.ty.void_(), utils::Empty);
+ auto* var = b.Decl(b.Var("a", b.Expr(1_i)));
+ b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{var});
+
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
+
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* before_stmt = ctx.src->Sem().Get(var);
+ hoistToDeclBefore.InsertBefore(before_stmt,
+ [&] { return ctx.dst->CallStmt(ctx.dst->Call("foo")); });
+
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
+
+ auto* expect = R"(
+fn foo() {
+}
+
+fn f() {
+ foo();
+ var a = 1i;
+}
+)";
+
+ EXPECT_EQ(expect, str(cloned));
+}
+
+TEST_F(HoistToDeclBeforeTest, InsertBefore_ForLoopInit) {
+ // fn foo() {
+ // }
+ // fn f() {
+ // for(var a = 1i; true;) {
+ // }
+ // }
+ ProgramBuilder b;
+ b.Func("foo", utils::Empty, b.ty.void_(), utils::Empty);
+ auto* var = b.Decl(b.Var("a", b.Expr(1_i)));
+ auto* s = b.For(var, b.Expr(true), nullptr, b.Block());
+ b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{s});
+
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
+
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* before_stmt = ctx.src->Sem().Get(var);
+ auto* new_stmt = ctx.dst->CallStmt(ctx.dst->Call("foo"));
+ hoistToDeclBefore.InsertBefore(before_stmt, new_stmt);
+
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
+
+ auto* expect = R"(
+fn foo() {
+}
+
+fn f() {
+ {
+ foo();
+ var a = 1i;
+ loop {
+ if (!(true)) {
+ break;
+ }
+ {
+ }
+ }
+ }
+}
+)";
+
+ EXPECT_EQ(expect, str(cloned));
+}
+
+TEST_F(HoistToDeclBeforeTest, InsertBefore_ForLoopInit_Function) {
+ // fn foo() {
+ // }
+ // fn f() {
+ // for(var a = 1i; true;) {
+ // }
+ // }
+ ProgramBuilder b;
+ b.Func("foo", utils::Empty, b.ty.void_(), utils::Empty);
+ auto* var = b.Decl(b.Var("a", b.Expr(1_i)));
+ auto* s = b.For(var, b.Expr(true), nullptr, b.Block());
+ b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{s});
+
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
+
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* before_stmt = ctx.src->Sem().Get(var);
+ hoistToDeclBefore.InsertBefore(before_stmt,
+ [&] { return ctx.dst->CallStmt(ctx.dst->Call("foo")); });
+
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
+
+ auto* expect = R"(
+fn foo() {
+}
+
+fn f() {
+ {
+ foo();
+ var a = 1i;
+ loop {
+ if (!(true)) {
+ break;
+ }
+ {
+ }
+ }
+ }
+}
+)";
+
+ EXPECT_EQ(expect, str(cloned));
+}
+
+TEST_F(HoistToDeclBeforeTest, InsertBefore_ForLoopCont) {
+ // fn foo() {
+ // }
+ // fn f() {
+ // var a = 1i;
+ // for(; true; a+=1i) {
+ // }
+ // }
+ ProgramBuilder b;
+ b.Func("foo", utils::Empty, b.ty.void_(), utils::Empty);
+ auto* var = b.Decl(b.Var("a", b.Expr(1_i)));
+ auto* cont = b.CompoundAssign("a", b.Expr(1_i), BinaryOp::kAdd);
+ auto* s = b.For(nullptr, b.Expr(true), cont, b.Block());
+ b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{var, s});
+
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
+
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* before_stmt = ctx.src->Sem().Get(cont->As<Statement>());
+ auto* new_stmt = ctx.dst->CallStmt(ctx.dst->Call("foo"));
+ hoistToDeclBefore.InsertBefore(before_stmt, new_stmt);
+
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
+
+ auto* expect = R"(
+fn foo() {
+}
+
+fn f() {
+ var a = 1i;
+ loop {
+ if (!(true)) {
+ break;
+ }
+ {
+ }
+
+ continuing {
+ foo();
+ a += 1i;
+ }
+ }
+}
+)";
+
+ EXPECT_EQ(expect, str(cloned));
+}
+
+TEST_F(HoistToDeclBeforeTest, InsertBefore_ForLoopCont_Function) {
+ // fn foo() {
+ // }
+ // fn f() {
+ // var a = 1i;
+ // for(; true; a+=1i) {
+ // }
+ // }
+ ProgramBuilder b;
+ b.Func("foo", utils::Empty, b.ty.void_(), utils::Empty);
+ auto* var = b.Decl(b.Var("a", b.Expr(1_i)));
+ auto* cont = b.CompoundAssign("a", b.Expr(1_i), BinaryOp::kAdd);
+ auto* s = b.For(nullptr, b.Expr(true), cont, b.Block());
+ b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{var, s});
+
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
+
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* before_stmt = ctx.src->Sem().Get(cont->As<Statement>());
+ hoistToDeclBefore.InsertBefore(before_stmt,
+ [&] { return ctx.dst->CallStmt(ctx.dst->Call("foo")); });
+
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
+
+ auto* expect = R"(
+fn foo() {
+}
+
+fn f() {
+ var a = 1i;
+ loop {
+ if (!(true)) {
+ break;
+ }
+ {
+ }
+
+ continuing {
+ foo();
+ a += 1i;
+ }
+ }
+}
+)";
+
+ EXPECT_EQ(expect, str(cloned));
+}
+
+TEST_F(HoistToDeclBeforeTest, InsertBefore_ElseIf) {
+ // fn foo() {
+ // }
+ // fn f() {
+ // var a : bool;
+ // if (true) {
+ // } else if (a) {
+ // } else {
+ // }
+ // }
+ ProgramBuilder b;
+ b.Func("foo", utils::Empty, b.ty.void_(), utils::Empty);
+ auto* var = b.Decl(b.Var("a", b.ty.bool_()));
+ auto* elseif = b.If(b.Expr("a"), b.Block(), b.Else(b.Block()));
+ auto* s = b.If(b.Expr(true), b.Block(), //
+ b.Else(elseif));
+ b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{var, s});
+
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
+
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* before_stmt = ctx.src->Sem().Get(elseif);
+ auto* new_stmt = ctx.dst->CallStmt(ctx.dst->Call("foo"));
+ hoistToDeclBefore.InsertBefore(before_stmt, new_stmt);
+
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
+
+ auto* expect = R"(
+fn foo() {
+}
+
+fn f() {
+ var a : bool;
+ if (true) {
+ } else {
+ foo();
+ if (a) {
+ } else {
+ }
+ }
+}
+)";
+
+ EXPECT_EQ(expect, str(cloned));
+}
+
+TEST_F(HoistToDeclBeforeTest, InsertBefore_ElseIf_Function) {
+ // fn foo() {
+ // }
+ // fn f() {
+ // var a : bool;
+ // if (true) {
+ // } else if (a) {
+ // } else {
+ // }
+ // }
+ ProgramBuilder b;
+ b.Func("foo", utils::Empty, b.ty.void_(), utils::Empty);
+ auto* var = b.Decl(b.Var("a", b.ty.bool_()));
+ auto* elseif = b.If(b.Expr("a"), b.Block(), b.Else(b.Block()));
+ auto* s = b.If(b.Expr(true), b.Block(), //
+ b.Else(elseif));
+ b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{var, s});
+
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
+
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* before_stmt = ctx.src->Sem().Get(elseif);
+ hoistToDeclBefore.InsertBefore(before_stmt,
+ [&] { return ctx.dst->CallStmt(ctx.dst->Call("foo")); });
+
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
+
+ auto* expect = R"(
+fn foo() {
+}
+
+fn f() {
+ var a : bool;
+ if (true) {
+ } else {
+ foo();
+ if (a) {
+ } else {
+ }
+ }
+}
+)";
+
+ EXPECT_EQ(expect, str(cloned));
+}
+
+TEST_F(HoistToDeclBeforeTest, AbstractArray_ToLet) {
+ // fn f() {
+ // var a : array<f32, 1> = array(1);
+ // }
+ ProgramBuilder b;
+ auto* expr = b.Call(b.ty("array"), b.Expr(1_a));
+ auto* var = b.Decl(b.Var("a", b.ty.array(b.ty.f32(), 1_a), expr));
+ b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{var});
+
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
+
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* sem_expr = ctx.src->Sem().Get(expr);
+ hoistToDeclBefore.Add(sem_expr, expr, HoistToDeclBefore::VariableKind::kLet);
+
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
+
+ auto* expect = R"(
+fn f() {
+ let tint_symbol : array<f32, 1u> = array(1);
+ var a : array<f32, 1> = tint_symbol;
+}
+)";
+
+ EXPECT_EQ(expect, str(cloned));
+}
+
+TEST_F(HoistToDeclBeforeTest, AbstractArray_ToVar) {
+ // fn f() {
+ // var a : array<f32, 1> = array(1);
+ // }
+ ProgramBuilder b;
+ auto* expr = b.Call(b.ty("array"), b.Expr(1_a));
+ auto* var = b.Decl(b.Var("a", b.ty.array(b.ty.f32(), 1_a), expr));
+ b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{var});
+
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
+
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* sem_expr = ctx.src->Sem().Get(expr);
+ hoistToDeclBefore.Add(sem_expr, expr, HoistToDeclBefore::VariableKind::kVar);
+
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
+
+ auto* expect = R"(
+fn f() {
+ var tint_symbol : array<f32, 1u> = array(1);
+ var a : array<f32, 1> = tint_symbol;
+}
+)";
+
+ EXPECT_EQ(expect, str(cloned));
+}
+
+TEST_F(HoistToDeclBeforeTest, Replace_Block) {
+ // fn foo() {
+ // }
+ // fn f() {
+ // var a = 1i;
+ // }
+ ProgramBuilder b;
+ b.Func("foo", utils::Empty, b.ty.void_(), utils::Empty);
+ auto* var = b.Decl(b.Var("a", b.Expr(1_i)));
+ b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{var});
+
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
+
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* target_stmt = ctx.src->Sem().Get(var);
+ auto* new_stmt = ctx.dst->CallStmt(ctx.dst->Call("foo"));
+ hoistToDeclBefore.Replace(target_stmt, new_stmt);
+
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
+
+ auto* expect = R"(
+fn foo() {
+}
+
+fn f() {
+ foo();
+}
+)";
+
+ EXPECT_EQ(expect, str(cloned));
+}
+
+TEST_F(HoistToDeclBeforeTest, Replace_Block_Function) {  // Replacement given as a callback.
+ // fn foo() {
+ // }
+ // fn f() {
+ // var a = 1i;
+ // }
+ ProgramBuilder b;
+ b.Func("foo", utils::Empty, b.ty.void_(), utils::Empty);
+ auto* var = b.Decl(b.Var("a", b.Expr(1_i)));
+ b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{var});
+
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
+
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* target_stmt = ctx.src->Sem().Get(var);  // looked up in the source program
+ hoistToDeclBefore.Replace(target_stmt, [&] { return ctx.dst->CallStmt(ctx.dst->Call("foo")); });
+
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
+
+ auto* expect = R"(
+fn foo() {
+}
+
+fn f() {
+ foo();
+}
+)";
+
+ EXPECT_EQ(expect, str(cloned));  // `var a = 1i;` must have become `foo();`
+}
+
+TEST_F(HoistToDeclBeforeTest, Replace_ForLoopInit) {  // Replace a for-loop initializer.
+ // fn foo() {
+ // }
+ // fn f() {
+ // for(var a = 1i; true;) {
+ // }
+ // }
+ ProgramBuilder b;
+ b.Func("foo", utils::Empty, b.ty.void_(), utils::Empty);
+ auto* var = b.Decl(b.Var("a", b.Expr(1_i)));
+ auto* s = b.For(var, b.Expr(true), nullptr, b.Block());  // `var` is the initializer
+ b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{s});
+
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
+
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* target_stmt = ctx.src->Sem().Get(var);
+ auto* new_stmt = ctx.dst->CallStmt(ctx.dst->Call("foo"));
+ hoistToDeclBefore.Replace(target_stmt, new_stmt);
+
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
+
+ auto* expect = R"(
+fn foo() {
+}
+
+fn f() {
+ {
+ foo();
+ loop {
+ if (!(true)) {
+ break;
+ }
+ {
+ }
+ }
+ }
+}
+)";
+
+ EXPECT_EQ(expect, str(cloned));  // expected: for-loop emitted as a block holding foo() + loop
+}
+
+TEST_F(HoistToDeclBeforeTest, Replace_ForLoopInit_Function) {  // As above, via a callback.
+ // fn foo() {
+ // }
+ // fn f() {
+ // for(var a = 1i; true;) {
+ // }
+ // }
+ ProgramBuilder b;
+ b.Func("foo", utils::Empty, b.ty.void_(), utils::Empty);
+ auto* var = b.Decl(b.Var("a", b.Expr(1_i)));
+ auto* s = b.For(var, b.Expr(true), nullptr, b.Block());  // `var` is the initializer
+ b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{s});
+
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
+
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* target_stmt = ctx.src->Sem().Get(var);
+ hoistToDeclBefore.Replace(target_stmt, [&] { return ctx.dst->CallStmt(ctx.dst->Call("foo")); });
+
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
+
+ auto* expect = R"(
+fn foo() {
+}
+
+fn f() {
+ {
+ foo();
+ loop {
+ if (!(true)) {
+ break;
+ }
+ {
+ }
+ }
+ }
+}
+)";
+
+ EXPECT_EQ(expect, str(cloned));  // expected: for-loop emitted as a block holding foo() + loop
+}
+
+TEST_F(HoistToDeclBeforeTest, Replace_ForLoopCont) {  // Replace a for-loop continuing statement.
+ // fn foo() {
+ // }
+ // fn f() {
+ // var a = 1i;
+ // for(; true; a+=1i) {
+ // }
+ // }
+ ProgramBuilder b;
+ b.Func("foo", utils::Empty, b.ty.void_(), utils::Empty);
+ auto* var = b.Decl(b.Var("a", b.Expr(1_i)));
+ auto* cont = b.CompoundAssign("a", b.Expr(1_i), BinaryOp::kAdd);  // `a += 1i`
+ auto* s = b.For(nullptr, b.Expr(true), cont, b.Block());  // no initializer; `cont` is the continuing
+ b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{var, s});
+
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
+
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* target_stmt = ctx.src->Sem().Get(cont->As<Statement>());
+ auto* new_stmt = ctx.dst->CallStmt(ctx.dst->Call("foo"));
+ hoistToDeclBefore.Replace(target_stmt, new_stmt);
+
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
+
+ auto* expect = R"(
+fn foo() {
+}
+
+fn f() {
+ var a = 1i;
+ loop {
+ if (!(true)) {
+ break;
+ }
+ {
+ }
+
+ continuing {
+ foo();
+ }
+ }
+}
+)";
+
+ EXPECT_EQ(expect, str(cloned));  // expected: foo() ends up in the `continuing` block
+}
+
+TEST_F(HoistToDeclBeforeTest, Replace_ForLoopCont_Function) {  // As above, via a callback.
+ // fn foo() {
+ // }
+ // fn f() {
+ // var a = 1i;
+ // for(; true; a+=1i) {
+ // }
+ // }
+ ProgramBuilder b;
+ b.Func("foo", utils::Empty, b.ty.void_(), utils::Empty);
+ auto* var = b.Decl(b.Var("a", b.Expr(1_i)));
+ auto* cont = b.CompoundAssign("a", b.Expr(1_i), BinaryOp::kAdd);  // `a += 1i`
+ auto* s = b.For(nullptr, b.Expr(true), cont, b.Block());  // no initializer; `cont` is the continuing
+ b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{var, s});
+
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
+
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* target_stmt = ctx.src->Sem().Get(cont->As<Statement>());
+ hoistToDeclBefore.Replace(target_stmt, [&] { return ctx.dst->CallStmt(ctx.dst->Call("foo")); });
+
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
+
+ auto* expect = R"(
+fn foo() {
+}
+
+fn f() {
+ var a = 1i;
+ loop {
+ if (!(true)) {
+ break;
+ }
+ {
+ }
+
+ continuing {
+ foo();
+ }
+ }
+}
+)";
+
+ EXPECT_EQ(expect, str(cloned));  // expected: foo() ends up in the `continuing` block
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/var_for_dynamic_index.cc b/src/tint/lang/wgsl/ast/transform/var_for_dynamic_index.cc
new file mode 100644
index 0000000..d3302d9
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/var_for_dynamic_index.cc
@@ -0,0 +1,80 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/var_for_dynamic_index.h"
+
+#include <utility>
+
+#include "src/tint/lang/wgsl/ast/transform/utils/hoist_to_decl_before.h"
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::VarForDynamicIndex);
+
+namespace tint::ast::transform {
+
+VarForDynamicIndex::VarForDynamicIndex() = default;  // declared in var_for_dynamic_index.h
+
+VarForDynamicIndex::~VarForDynamicIndex() = default;  // declared `override` in the header
+
+Transform::ApplyResult VarForDynamicIndex::Apply(const Program* src,  // program to transform
+ const DataMap&,  // inputs: unused
+ DataMap&) const {  // outputs: unused
+ ProgramBuilder b;  // receives the cloned, transformed program
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+
+ HoistToDeclBefore hoist_to_decl_before(ctx);
+
+ // Extracts array and matrix values that are dynamically indexed to a
+ // temporary `var` local that is then indexed. Returns Add()'s result, or true when skipped.
+ auto dynamic_index_to_var = [&](const IndexAccessorExpression* access_expr) {
+ auto* index_expr = access_expr->index;
+ auto* object_expr = access_expr->object;
+ auto& sem = src->Sem();
+
+ if (sem.GetVal(index_expr)->ConstantValue()) {
+ // Index expression resolves to a compile time value.
+ // As this isn't a dynamic index, we can ignore this.
+ return true;
+ }
+
+ auto* indexed = sem.GetVal(object_expr);
+ if (!indexed->Type()->IsAnyOf<type::Array, type::Matrix>()) {
+ // We only care about arrays and matrices.
+ return true;
+ }
+
+ // TODO(bclayton): group multiple accesses in the same object.
+ // e.g. arr[i] + arr[i+1] // Don't create two vars for this
+ return hoist_to_decl_before.Add(indexed, object_expr, HoistToDeclBefore::VariableKind::kVar,
+ "var_for_index");
+ };
+
+ bool index_accessor_found = false;
+ for (auto* node : src->ASTNodes().Objects()) {
+ if (auto* access_expr = node->As<IndexAccessorExpression>()) {
+ if (!dynamic_index_to_var(access_expr)) {
+ return Program(std::move(b));  // hoist failed: return early, ctx.Clone() is never run
+ }
+ index_accessor_found = true;  // set for every index accessor, even those skipped above
+ }
+ }
+ if (!index_accessor_found) {
+ return SkipTransform;  // the program has no index accessors at all
+ }
+
+ ctx.Clone();
+ return Program(std::move(b));
+}
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/var_for_dynamic_index.h b/src/tint/lang/wgsl/ast/transform/var_for_dynamic_index.h
new file mode 100644
index 0000000..3d66520
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/var_for_dynamic_index.h
@@ -0,0 +1,42 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_VAR_FOR_DYNAMIC_INDEX_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_VAR_FOR_DYNAMIC_INDEX_H_
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+
+namespace tint::ast::transform {
+
+/// A transform that extracts array and matrix values that are dynamically
+/// indexed to a temporary `var` local before performing the index. This
+/// transform is used by the SPIR-V writer as there is no SPIR-V instruction
+/// that can dynamically index a non-pointer composite. Constant indices are left untouched.
+class VarForDynamicIndex final : public utils::Castable<VarForDynamicIndex, Transform> {
+ public:
+ /// Constructor. The transform takes no configuration options.
+ VarForDynamicIndex();
+
+ /// Destructor
+ ~VarForDynamicIndex() override;
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_VAR_FOR_DYNAMIC_INDEX_H_
diff --git a/src/tint/lang/wgsl/ast/transform/var_for_dynamic_index_test.cc b/src/tint/lang/wgsl/ast/transform/var_for_dynamic_index_test.cc
new file mode 100644
index 0000000..1a79f73
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/var_for_dynamic_index_test.cc
@@ -0,0 +1,563 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/var_for_dynamic_index.h"
+#include "src/tint/lang/wgsl/ast/transform/for_loop_to_loop.h"
+
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using VarForDynamicIndexTest = TransformTest;
+
+TEST_F(VarForDynamicIndexTest, EmptyModule) {  // An empty module yields empty output.
+ auto* src = "";
+ auto* expect = "";
+
+ auto got = Run<ForLoopToLoop, VarForDynamicIndex>(src);  // NOTE(review): also runs ForLoopToLoop
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VarForDynamicIndexTest, ArrayIndexDynamic) {  // `let` array indexed by a `var`.
+ auto* src = R"(
+fn f() {
+ var i : i32;
+ let p = array<i32, 4>(1, 2, 3, 4);
+ let x = p[i];
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ var i : i32;
+ let p = array<i32, 4>(1, 2, 3, 4);
+ var var_for_index : array<i32, 4u> = p;
+ let x = var_for_index[i];
+}
+)";
+
+ Transform::DataMap data;  // left empty
+ auto got = Run<VarForDynamicIndex>(src, data);
+
+ EXPECT_EQ(expect, str(got));  // `p` is copied into `var_for_index` before being indexed
+}
+
+TEST_F(VarForDynamicIndexTest, MatrixIndexDynamic) {  // `let` matrix indexed by a `var`.
+ auto* src = R"(
+fn f() {
+ var i : i32;
+ let p = mat2x2(1.0, 2.0, 3.0, 4.0);
+ let x = p[i];
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ var i : i32;
+ let p = mat2x2(1.0, 2.0, 3.0, 4.0);
+ var var_for_index : mat2x2<f32> = p;
+ let x = var_for_index[i];
+}
+)";
+
+ Transform::DataMap data;  // left empty
+ auto got = Run<VarForDynamicIndex>(src, data);
+
+ EXPECT_EQ(expect, str(got));  // `p` is copied into `var_for_index` before being indexed
+}
+
+TEST_F(VarForDynamicIndexTest, ArrayIndexDynamicChain) {  // Two chained dynamic indices.
+ auto* src = R"(
+fn f() {
+ var i : i32;
+ var j : i32;
+ let p = array<array<i32, 2>, 2>(array<i32, 2>(1, 2), array<i32, 2>(3, 4));
+ let x = p[i][j];
+}
+)";
+
+ // TODO(bclayton): Optimize this case:
+ // This output is not as efficient as it could be.
+ // We only actually need to hoist the inner-most array to a `var`
+ // (`var_for_index`), as later indexing operations will be working with
+ // references, not values.
+ auto* expect = R"(
+fn f() {
+ var i : i32;
+ var j : i32;
+ let p = array<array<i32, 2>, 2>(array<i32, 2>(1, 2), array<i32, 2>(3, 4));
+ var var_for_index : array<array<i32, 2u>, 2u> = p;
+ var var_for_index_1 : array<i32, 2u> = var_for_index[i];
+ let x = var_for_index_1[j];
+}
+)";
+
+ Transform::DataMap data;  // left empty
+ auto got = Run<VarForDynamicIndex>(src, data);
+
+ EXPECT_EQ(expect, str(got));  // currently one hoisted `var` per index operation
+}
+
+TEST_F(VarForDynamicIndexTest, ArrayIndexInForLoopInit) {  // Dynamic index in a for-loop init.
+ auto* src = R"(
+fn f() {
+ var i : i32;
+ let p = array<array<i32, 2>, 2>(array<i32, 2>(1, 2), array<i32, 2>(3, 4));
+ for(let x = p[i]; ; ) {
+ break;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ var i : i32;
+ let p = array<array<i32, 2>, 2>(array<i32, 2>(1, 2), array<i32, 2>(3, 4));
+ {
+ var var_for_index : array<array<i32, 2u>, 2u> = p;
+ let x = var_for_index[i];
+ loop {
+ {
+ break;
+ }
+ }
+ }
+}
+)";
+
+ Transform::DataMap data;  // left empty
+ auto got = Run<VarForDynamicIndex>(src, data);
+
+ EXPECT_EQ(expect, str(got));  // expected: a block holding the hoisted var, `let x` and a loop
+}
+
+TEST_F(VarForDynamicIndexTest, MatrixIndexInForLoopInit) {  // Dynamic index in a for-loop init.
+ auto* src = R"(
+fn f() {
+ var i : i32;
+ let p = mat2x2(1.0, 2.0, 3.0, 4.0);
+ for(let x = p[i]; ; ) {
+ break;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ var i : i32;
+ let p = mat2x2(1.0, 2.0, 3.0, 4.0);
+ {
+ var var_for_index : mat2x2<f32> = p;
+ let x = var_for_index[i];
+ loop {
+ {
+ break;
+ }
+ }
+ }
+}
+)";
+
+ Transform::DataMap data;  // left empty
+ auto got = Run<VarForDynamicIndex>(src, data);
+
+ EXPECT_EQ(expect, str(got));  // expected: a block holding the hoisted var, `let x` and a loop
+}
+
+TEST_F(VarForDynamicIndexTest, ArrayIndexInForLoopCond) {  // Dynamic index in a for-loop cond.
+ auto* src = R"(
+fn f() {
+ var i : i32;
+ let p = array<i32, 2>(1, 2);
+ for(; p[i] < 3; ) {
+ break;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ var i : i32;
+ let p = array<i32, 2>(1, 2);
+ loop {
+ var var_for_index : array<i32, 2u> = p;
+ if (!((var_for_index[i] < 3))) {
+ break;
+ }
+ {
+ break;
+ }
+ }
+}
+)";
+
+ Transform::DataMap data;  // left empty
+ auto got = Run<VarForDynamicIndex>(src, data);
+
+ EXPECT_EQ(expect, str(got));  // expected: the var is declared inside the loop, before the test
+}
+
+TEST_F(VarForDynamicIndexTest, MatrixIndexInForLoopCond) {  // Dynamic index in a for-loop cond.
+ auto* src = R"(
+fn f() {
+ var i : i32;
+ let p = mat2x2(1.0, 2.0, 3.0, 4.0);
+ for(; p[i].x < 3.0; ) {
+ break;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ var i : i32;
+ let p = mat2x2(1.0, 2.0, 3.0, 4.0);
+ loop {
+ var var_for_index : mat2x2<f32> = p;
+ if (!((var_for_index[i].x < 3.0))) {
+ break;
+ }
+ {
+ break;
+ }
+ }
+}
+)";
+
+ Transform::DataMap data;  // left empty
+ auto got = Run<VarForDynamicIndex>(src, data);
+
+ EXPECT_EQ(expect, str(got));  // expected: the var is declared inside the loop, before the test
+}
+
+TEST_F(VarForDynamicIndexTest, MatrixIndexInForLoopCondWithNestedIndex) {  // Cond + body index.
+ auto* src = R"(
+fn f() {
+ var i : i32;
+ let p = mat2x2(1.0, 2.0, 3.0, 4.0);
+ for(; p[i].x < 3.0; ) {
+ if (p[i].x < 1.0) {
+ var marker = 1;
+ }
+ break;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ var i : i32;
+ let p = mat2x2(1.0, 2.0, 3.0, 4.0);
+ loop {
+ var var_for_index : mat2x2<f32> = p;
+ if (!((var_for_index[i].x < 3.0))) {
+ break;
+ }
+ {
+ var var_for_index_1 : mat2x2<f32> = p;
+ if ((var_for_index_1[i].x < 1.0)) {
+ var marker = 1;
+ }
+ break;
+ }
+ }
+}
+)";
+
+ Transform::DataMap data;  // left empty
+ auto got = Run<VarForDynamicIndex>(src, data);
+
+ EXPECT_EQ(expect, str(got));  // expected: separate vars for the condition and the body index
+}
+
+TEST_F(VarForDynamicIndexTest, ArrayIndexInElseIf) {  // Dynamic index in an else-if condition.
+ auto* src = R"(
+fn f() {
+ var i : i32;
+ let p = array<i32, 2>(1, 2);
+ if (false) {
+ var marker = 0;
+ } else if (p[i] < 3) {
+ var marker = 1;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ var i : i32;
+ let p = array<i32, 2>(1, 2);
+ if (false) {
+ var marker = 0;
+ } else {
+ var var_for_index : array<i32, 2u> = p;
+ if ((var_for_index[i] < 3)) {
+ var marker = 1;
+ }
+ }
+}
+)";
+
+ Transform::DataMap data;  // left empty
+ auto got = Run<VarForDynamicIndex>(src, data);
+
+ EXPECT_EQ(expect, str(got));  // expected: `else if` becomes `else { var ...; if ... }`
+}
+
+TEST_F(VarForDynamicIndexTest, ArrayIndexInElseIfChain) {  // Dynamic indices mid else-if chain.
+ auto* src = R"(
+fn f() {
+ var i : i32;
+ let p = array<i32, 2>(1, 2);
+ if (true) {
+ var marker = 0;
+ } else if (true) {
+ var marker = 1;
+ } else if (p[i] < 3) {
+ var marker = 2;
+ } else if (p[i] < 4) {
+ var marker = 3;
+ } else if (true) {
+ var marker = 4;
+ } else {
+ var marker = 5;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ var i : i32;
+ let p = array<i32, 2>(1, 2);
+ if (true) {
+ var marker = 0;
+ } else if (true) {
+ var marker = 1;
+ } else {
+ var var_for_index : array<i32, 2u> = p;
+ if ((var_for_index[i] < 3)) {
+ var marker = 2;
+ } else {
+ var var_for_index_1 : array<i32, 2u> = p;
+ if ((var_for_index_1[i] < 4)) {
+ var marker = 3;
+ } else if (true) {
+ var marker = 4;
+ } else {
+ var marker = 5;
+ }
+ }
+ }
+}
+)";
+
+ Transform::DataMap data;  // left empty
+ auto got = Run<VarForDynamicIndex>(src, data);
+
+ EXPECT_EQ(expect, str(got));  // expected: only the two indexing `else if`s are restructured
+}
+
+TEST_F(VarForDynamicIndexTest, MatrixIndexInElseIf) {  // Dynamic index in an else-if condition.
+ auto* src = R"(
+fn f() {
+ var i : i32;
+ let p = mat2x2(1.0, 2.0, 3.0, 4.0);
+ if (false) {
+ var marker_if = 1;
+ } else if (p[i].x < 3.0) {
+ var marker_else_if = 1;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ var i : i32;
+ let p = mat2x2(1.0, 2.0, 3.0, 4.0);
+ if (false) {
+ var marker_if = 1;
+ } else {
+ var var_for_index : mat2x2<f32> = p;
+ if ((var_for_index[i].x < 3.0)) {
+ var marker_else_if = 1;
+ }
+ }
+}
+)";
+
+ Transform::DataMap data;  // left empty
+ auto got = Run<VarForDynamicIndex>(src, data);
+
+ EXPECT_EQ(expect, str(got));  // expected: `else if` becomes `else { var ...; if ... }`
+}
+
+TEST_F(VarForDynamicIndexTest, MatrixIndexInElseIfChain) {  // Dynamic indices mid else-if chain.
+ auto* src = R"(
+fn f() {
+ var i : i32;
+ let p = mat2x2(1.0, 2.0, 3.0, 4.0);
+ if (true) {
+ var marker = 0;
+ } else if (true) {
+ var marker = 1;
+ } else if (p[i].x < 3.0) {
+ var marker = 2;
+ } else if (p[i].y < 3.0) {
+ var marker = 3;
+ } else if (true) {
+ var marker = 4;
+ } else {
+ var marker = 5;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ var i : i32;
+ let p = mat2x2(1.0, 2.0, 3.0, 4.0);
+ if (true) {
+ var marker = 0;
+ } else if (true) {
+ var marker = 1;
+ } else {
+ var var_for_index : mat2x2<f32> = p;
+ if ((var_for_index[i].x < 3.0)) {
+ var marker = 2;
+ } else {
+ var var_for_index_1 : mat2x2<f32> = p;
+ if ((var_for_index_1[i].y < 3.0)) {
+ var marker = 3;
+ } else if (true) {
+ var marker = 4;
+ } else {
+ var marker = 5;
+ }
+ }
+ }
+}
+)";
+
+ Transform::DataMap data;  // left empty
+ auto got = Run<VarForDynamicIndex>(src, data);
+
+ EXPECT_EQ(expect, str(got));  // expected: only the two indexing `else if`s are restructured
+}
+
+TEST_F(VarForDynamicIndexTest, ArrayIndexLiteral) {  // Literal index: nothing is hoisted.
+ auto* src = R"(
+fn f() {
+ let p = array<i32, 4>(1, 2, 3, 4);
+ let x = p[1];
+}
+)";
+
+ auto* expect = src;  // output must equal the input
+
+ Transform::DataMap data;
+ auto got = Run<VarForDynamicIndex>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VarForDynamicIndexTest, MatrixIndexLiteral) {  // Literal index: nothing is hoisted.
+ auto* src = R"(
+fn f() {
+ let p = mat2x2(1.0, 2.0, 3.0, 4.0);
+ let x = p[1];
+}
+)";
+
+ auto* expect = src;  // output must equal the input
+
+ Transform::DataMap data;
+ auto got = Run<VarForDynamicIndex>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VarForDynamicIndexTest, ArrayIndexConstantLet) {  // Index is `let c = 1`.
+ auto* src = R"(
+fn f() {
+ let p = array<i32, 4>(1, 2, 3, 4);
+ let c = 1;
+ var var_for_index = p;
+ let x = var_for_index[c];
+}
+)";
+
+ auto* expect = src;  // output must equal the input
+
+ Transform::DataMap data;
+ auto got = Run<VarForDynamicIndex>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VarForDynamicIndexTest, MatrixIndexConstantLet) {  // Index is `let c = 1`.
+ auto* src = R"(
+fn f() {
+ let p = mat2x2(1.0, 2.0, 3.0, 4.0);
+ let c = 1;
+ var var_for_index = p;
+ let x = var_for_index[c];
+}
+)";
+
+ auto* expect = src;  // output must equal the input
+
+ Transform::DataMap data;
+ auto got = Run<VarForDynamicIndex>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VarForDynamicIndexTest, ArrayIndexLiteralChain) {  // Chained literal indices.
+ auto* src = R"(
+fn f() {
+ let a = array<i32, 2>(1, 2);
+ let b = array<i32, 2>(3, 4);
+ let p = array<array<i32, 2>, 2>(a, b);
+ let x = p[0][1];
+}
+)";
+
+ auto* expect = src;  // output must equal the input
+
+ Transform::DataMap data;
+ auto got = Run<VarForDynamicIndex>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VarForDynamicIndexTest, MatrixIndexLiteralChain) {  // Chained literal indices.
+ auto* src = R"(
+fn f() {
+ let p = mat2x2(1.0, 2.0, 3.0, 4.0);
+ let x = p[0][1];
+}
+)";
+
+ auto* expect = src;  // output must equal the input
+
+ Transform::DataMap data;
+ auto got = Run<VarForDynamicIndex>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/vectorize_matrix_conversions.cc b/src/tint/lang/wgsl/ast/transform/vectorize_matrix_conversions.cc
new file mode 100644
index 0000000..511cfbc
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/vectorize_matrix_conversions.cc
@@ -0,0 +1,149 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/vectorize_matrix_conversions.h"
+
+#include <tuple>
+#include <unordered_map>
+#include <utility>
+
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/call.h"
+#include "src/tint/sem/value_conversion.h"
+#include "src/tint/sem/value_expression.h"
+#include "src/tint/type/abstract_numeric.h"
+#include "src/tint/utils/hash.h"
+#include "src/tint/utils/map.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::VectorizeMatrixConversions);
+
+namespace tint::ast::transform {
+
+namespace {
+
+bool ShouldRun(const Program* program) {  // True if any matrix conversion call is found.
+ for (auto* node : program->ASTNodes().Objects()) {
+ if (auto* sem = program->Sem().GetVal(node)) {
+ if (auto* call = sem->UnwrapMaterialize()->As<sem::Call>()) {
+ if (call->Target()->Is<sem::ValueConversion>() &&
+ call->Type()->Is<type::Matrix>()) {  // a value conversion that yields a matrix
+ auto& args = call->Arguments();
+ if (args.Length() == 1 && args[0]->Type()->UnwrapRef()->is_float_matrix()) {
+ return true;  // single float-matrix argument
+ }
+ }
+ }
+ }
+ }
+ return false;  // no such call anywhere in the program
+}
+
+} // namespace
+
+VectorizeMatrixConversions::VectorizeMatrixConversions() = default;  // declared in the .h
+
+VectorizeMatrixConversions::~VectorizeMatrixConversions() = default;  // `override` in the .h
+
+Transform::ApplyResult VectorizeMatrixConversions::Apply(const Program* src,
+ const DataMap&,  // inputs: unused
+ DataMap&) const {  // outputs: unused
+ if (!ShouldRun(src)) {
+ return SkipTransform;
+ }
+
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+
+ using HelperFunctionKey =
+ utils::UnorderedKeyWrapper<std::tuple<const type::Matrix*, const type::Matrix*>>;
+
+ std::unordered_map<HelperFunctionKey, Symbol> matrix_convs;  // (src, dst) type -> helper name
+
+ ctx.ReplaceAll([&](const CallExpression* expr) -> const CallExpression* {
+ auto* call = src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>();
+ auto* ty_conv = call->Target()->As<sem::ValueConversion>();
+ if (!ty_conv) {
+ return nullptr;  // not a value conversion
+ }
+ auto* dst_type = call->Type()->As<type::Matrix>();
+ if (!dst_type) {
+ return nullptr;  // result is not a matrix
+ }
+
+ auto& args = call->Arguments();
+ if (args.Length() != 1) {
+ return nullptr;
+ }
+
+ auto& matrix = args[0];
+
+ auto* src_type = matrix->Type()->UnwrapRef()->As<type::Matrix>();
+ if (!src_type) {
+ return nullptr;  // argument is not a matrix
+ }
+
+ // The source and destination type of a matrix conversion must have the same shape.
+ if (TINT_UNLIKELY(!(src_type->rows() == dst_type->rows() &&
+ src_type->columns() == dst_type->columns()))) {
+ TINT_ICE(Transform, b.Diagnostics())
+ << "source and destination matrix has different shape in matrix conversion";
+ return nullptr;
+ }
+
+ auto build_vectorized_conversion_expression = [&](auto&& src_expression_builder) {
+ utils::Vector<const Expression*, 4> columns;
+ for (uint32_t c = 0; c < dst_type->columns(); c++) {
+ auto* src_matrix_expr = src_expression_builder();  // fresh source expression per column
+ auto* src_column_expr = b.IndexAccessor(src_matrix_expr, b.Expr(tint::AInt(c)));
+ columns.Push(
+ b.Call(CreateASTTypeFor(ctx, dst_type->ColumnType()), src_column_expr));
+ }
+ return b.Call(CreateASTTypeFor(ctx, dst_type), columns);
+ };
+
+ // Replace the matrix conversion with column vector conversions and a matrix construction.
+ if (!matrix->HasSideEffects()) {
+ // Simply use the argument's declaration if it has no side effects.
+ return build_vectorized_conversion_expression([&] { //
+ return ctx.Clone(matrix->Declaration());
+ });
+ } else {
+ // If it has side effects, use a helper function so the argument is only cloned once.
+ auto fn =
+ utils::GetOrCreate(matrix_convs, HelperFunctionKey{{src_type, dst_type}}, [&] {
+ auto name = b.Symbols().New(
+ "convert_mat" + std::to_string(src_type->columns()) + "x" +
+ std::to_string(src_type->rows()) + "_" + src_type->type()->FriendlyName() +
+ "_" + dst_type->type()->FriendlyName());
+ b.Func(name,
+ utils::Vector{
+ b.Param("value", CreateASTTypeFor(ctx, src_type)),
+ },
+ CreateASTTypeFor(ctx, dst_type),
+ utils::Vector{
+ b.Return(build_vectorized_conversion_expression([&] { //
+ return b.Expr("value");
+ })),
+ });
+ return name;
+ });
+ return b.Call(fn, ctx.Clone(args[0]->Declaration()));
+ }
+ });
+
+ ctx.Clone();
+ return Program(std::move(b));
+}
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/vectorize_matrix_conversions.h b/src/tint/lang/wgsl/ast/transform/vectorize_matrix_conversions.h
new file mode 100644
index 0000000..b61777a
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/vectorize_matrix_conversions.h
@@ -0,0 +1,40 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_VECTORIZE_MATRIX_CONVERSIONS_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_VECTORIZE_MATRIX_CONVERSIONS_H_
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+
+namespace tint::ast::transform {
+
+/// A transform that rewrites matrix conversions (between f32 and f16 matrices) column by column.
+class VectorizeMatrixConversions final
+ : public utils::Castable<VectorizeMatrixConversions, Transform> {
+ public:
+ /// Constructor. The transform takes no configuration options.
+ VectorizeMatrixConversions();
+
+ /// Destructor
+ ~VectorizeMatrixConversions() override;
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_VECTORIZE_MATRIX_CONVERSIONS_H_
diff --git a/src/tint/lang/wgsl/ast/transform/vectorize_matrix_conversions_test.cc b/src/tint/lang/wgsl/ast/transform/vectorize_matrix_conversions_test.cc
new file mode 100644
index 0000000..c876461
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/vectorize_matrix_conversions_test.cc
@@ -0,0 +1,411 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/vectorize_matrix_conversions.h"
+
+#include <string>
+#include <utility>
+
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+#include "src/tint/utils/string.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using VectorizeMatrixConversionsTest = TransformTestWithParam<std::pair<uint32_t, uint32_t>>;
+
+TEST_F(VectorizeMatrixConversionsTest, ShouldRunEmptyModule) {
+ auto* src = R"()";  // empty module: contains no matrix conversion
+
+ EXPECT_FALSE(ShouldRun<VectorizeMatrixConversions>(src));
+}
+
+// Test that VectorizeMatrixConversions transforms the matCxR<f32> to matCxR<f16> conversion as
+// expected.
+//
+// Example input:
+//
+// enable f16;
+//
+// @fragment
+// fn main() {
+// let m = mat3x2<f32>(vec2<f32>(0.0, 1.0), vec2<f32>(2.0, 3.0), vec2<f32>(4.0, 5.0));
+// let n : mat3x2<f16> = mat3x2<f16>(m);
+// }
+//
+// Example output:
+//
+// enable f16;
+//
+// @fragment
+// fn main() {
+// let m = mat3x2<f32>(vec2<f32>(0.0, 1.0), vec2<f32>(2.0, 3.0), vec2<f32>(4.0, 5.0));
+// let n : mat3x2<f16> = mat3x2<f16>(vec2<f16>(m[0]), vec2<f16>(m[1]), vec2<f16>(m[2]));
+// }
+TEST_P(VectorizeMatrixConversionsTest, Conversion_F32ToF16) {
+ uint32_t cols = GetParam().first;
+ uint32_t rows = GetParam().second;
+ std::string src_mat_type = "mat" + std::to_string(cols) + "x" + std::to_string(rows) + "<f32>";
+ std::string src_vec_type = "vec" + std::to_string(rows) + "<f32>";
+ std::string dst_mat_type = "mat" + std::to_string(cols) + "x" + std::to_string(rows) + "<f16>";
+ std::string dst_vec_type = "vec" + std::to_string(rows) + "<f16>";
+ std::string vector_values;  // source matrix columns, elements numbered 0.0, 1.0, 2.0, ...
+ for (uint32_t c = 0; c < cols; c++) {
+ if (c > 0) {
+ vector_values += ", ";
+ }
+ vector_values += src_vec_type + "(";
+ for (uint32_t r = 0; r < rows; r++) {
+ if (r > 0) {
+ vector_values += ", ";
+ }
+ auto value = std::to_string(c * rows + r) + ".0";
+ vector_values += value;
+ }
+ vector_values += ")";
+ }
+
+ std::string vectorized_args = "";  // expected arguments: dst_vec(m[0]), dst_vec(m[1]), ...
+ for (uint32_t c = 0; c < cols; c++) {
+ if (c > 0) {
+ vectorized_args += ", ";
+ }
+ vectorized_args += dst_vec_type + "(m[" + std::to_string(c) + "])";
+ }
+
+ std::string tmpl = R"(
+enable f16;
+
+@fragment
+fn main() {
+ let m = ${src_mat_type}(${values});
+ let n : ${dst_mat_type} = ${dst_mat_type}(${args});
+}
+)";
+ tmpl = utils::ReplaceAll(tmpl, "${src_mat_type}", src_mat_type);
+ tmpl = utils::ReplaceAll(tmpl, "${dst_mat_type}", dst_mat_type);
+ tmpl = utils::ReplaceAll(tmpl, "${values}", vector_values);
+ auto src = utils::ReplaceAll(tmpl, "${args}", "m");  // input: whole-matrix conversion
+ auto expect = utils::ReplaceAll(tmpl, "${args}", vectorized_args);  // output: per-column form
+
+ EXPECT_TRUE(ShouldRun<VectorizeMatrixConversions>(src));
+
+ auto got = Run<VectorizeMatrixConversions>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+// Test that VectorizeMatrixConversions transforms the matCxR<f16> to matCxR<f32> conversion as
+// expected.
+//
+// Example input:
+//
+// enable f16;
+//
+// @fragment
+// fn main() {
+// let m = mat3x2<f16>(vec2<f16>(0.0, 1.0), vec2<f16>(2.0, 3.0), vec2<f16>(4.0, 5.0));
+// let n : mat3x2<f32> = mat3x2<f32>(m);
+// }
+//
+// Example output:
+//
+// enable f16;
+//
+// @fragment
+// fn main() {
+// let m = mat3x2<f16>(vec2<f16>(0.0, 1.0), vec2<f16>(2.0, 3.0), vec2<f16>(4.0, 5.0));
+// let n : mat3x2<f32> = mat3x2<f32>(vec2<f32>(m[0]), vec2<f32>(m[1]), vec2<f32>(m[2]));
+// }
+TEST_P(VectorizeMatrixConversionsTest, Conversion_F16ToF32) {
+ uint32_t cols = GetParam().first;
+ uint32_t rows = GetParam().second;
+ std::string src_mat_type = "mat" + std::to_string(cols) + "x" + std::to_string(rows) + "<f16>";
+ std::string src_vec_type = "vec" + std::to_string(rows) + "<f16>";
+ std::string dst_mat_type = "mat" + std::to_string(cols) + "x" + std::to_string(rows) + "<f32>";
+ std::string dst_vec_type = "vec" + std::to_string(rows) + "<f32>";
+ std::string vector_values;  // source matrix columns, elements numbered 0.0, 1.0, 2.0, ...
+ for (uint32_t c = 0; c < cols; c++) {
+ if (c > 0) {
+ vector_values += ", ";
+ }
+ vector_values += src_vec_type + "(";
+ for (uint32_t r = 0; r < rows; r++) {
+ if (r > 0) {
+ vector_values += ", ";
+ }
+ auto value = std::to_string(c * rows + r) + ".0";
+ vector_values += value;
+ }
+ vector_values += ")";
+ }
+
+ std::string vectorized_args = "";  // expected arguments: dst_vec(m[0]), dst_vec(m[1]), ...
+ for (uint32_t c = 0; c < cols; c++) {
+ if (c > 0) {
+ vectorized_args += ", ";
+ }
+ vectorized_args += dst_vec_type + "(m[" + std::to_string(c) + "])";
+ }
+
+ std::string tmpl = R"(
+enable f16;
+
+@fragment
+fn main() {
+ let m = ${src_mat_type}(${values});
+ let n : ${dst_mat_type} = ${dst_mat_type}(${args});
+}
+)";
+ tmpl = utils::ReplaceAll(tmpl, "${src_mat_type}", src_mat_type);
+ tmpl = utils::ReplaceAll(tmpl, "${dst_mat_type}", dst_mat_type);
+ tmpl = utils::ReplaceAll(tmpl, "${values}", vector_values);
+ auto src = utils::ReplaceAll(tmpl, "${args}", "m");  // input: whole-matrix conversion
+ auto expect = utils::ReplaceAll(tmpl, "${args}", vectorized_args);  // output: per-column form
+
+ EXPECT_TRUE(ShouldRun<VectorizeMatrixConversions>(src));
+
+ auto got = Run<VectorizeMatrixConversions>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+// Test that the VectorizeMatrixConversions transform generates helper functions for conversions
+// whose input expression has side effects.
+//
+// Example input:
+//
+// enable f16;
+//
+// var<private> i : i32 = 0;
+//
+// fn mat_f32() -> mat2x2<f32> {
+// i = (i + 1);
+// return mat2x2<f32>(vec2<f32>(f32(i), f32(i)), vec2<f32>(f32(i), f32(i)));
+// }
+//
+// fn mat_f16() -> mat2x2<f16> {
+// i = (i + 1);
+// return mat2x2<f16>(vec2<f16>(f16(i), f16(i)), vec2<f16>(f16(i), f16(i)));
+// }
+//
+// @fragment
+// fn main() {
+// let m32 : mat2x2<f32> = mat2x2<f32>(mat_f16());
+// let m16 : mat2x2<f16> = mat2x2<f16>(mat_f32());
+// }
+//
+// Example output:
+//
+// enable f16;
+//
+// var<private> i : i32 = 0;
+//
+// fn mat_f32() -> mat2x2<f32> {
+// i = (i + 1);
+// return mat2x2<f32>(vec2<f32>(f32(i), f32(i)), vec2<f32>(f32(i), f32(i)));
+// }
+//
+// fn mat_f16() -> mat2x2<f16> {
+// i = (i + 1);
+// return mat2x2<f16>(vec2<f16>(f16(i), f16(i)), vec2<f16>(f16(i), f16(i)));
+// }
+//
+// fn convert_mat2x2_f16_f32(value : mat2x2<f16>) -> mat2x2<f32> {
+// return mat2x2<f32>(vec2<f32>(value[0]), vec2<f32>(value[1]));
+// }
+//
+// fn convert_mat2x2_f32_f16(value : mat2x2<f32>) -> mat2x2<f16> {
+// return mat2x2<f16>(vec2<f16>(value[0]), vec2<f16>(value[1]));
+// }
+//
+// @fragment
+// fn main() {
+// let m32 : mat2x2<f32> = convert_mat2x2_f16_f32(mat_f16());
+// let m16 : mat2x2<f16> = convert_mat2x2_f32_f16(mat_f32());
+// }
+TEST_P(VectorizeMatrixConversionsTest, Conversion_WithSideEffect) {
+ uint32_t cols = GetParam().first;
+ uint32_t rows = GetParam().second;
+ std::string mat_shape = "mat" + std::to_string(cols) + "x" + std::to_string(rows);
+ std::string f32_mat_type = mat_shape + "<f32>";
+ std::string f32_vec_type = "vec" + std::to_string(rows) + "<f32>";
+ std::string f16_mat_type = mat_shape + "<f16>";
+ std::string f16_vec_type = "vec" + std::to_string(rows) + "<f16>";
+ std::string f32_vector_values;  // columns returned by mat_f32(): every element is f32(i)
+ std::string f16_vector_values;  // columns returned by mat_f16(): every element is f16(i)
+ for (uint32_t c = 0; c < cols; c++) {
+ if (c > 0) {
+ f32_vector_values += ", ";
+ f16_vector_values += ", ";
+ }
+ f32_vector_values += f32_vec_type + "(";
+ f16_vector_values += f16_vec_type + "(";
+ for (uint32_t r = 0; r < rows; r++) {
+ if (r > 0) {
+ f32_vector_values += ", ";
+ f16_vector_values += ", ";
+ }
+ f32_vector_values += "f32(i)";
+ f16_vector_values += "f16(i)";
+ }
+ f32_vector_values += ")";
+ f16_vector_values += ")";
+ }
+
+ std::string f32_vectorized_args = "";  // body of the f16 -> f32 helper: f32_vec(value[c]), ...
+ std::string f16_vectorized_args = "";  // body of the f32 -> f16 helper: f16_vec(value[c]), ...
+ for (uint32_t c = 0; c < cols; c++) {
+ if (c > 0) {
+ f32_vectorized_args += ", ";
+ f16_vectorized_args += ", ";
+ }
+ f32_vectorized_args += f32_vec_type + "(value[" + std::to_string(c) + "])";
+ f16_vectorized_args += f16_vec_type + "(value[" + std::to_string(c) + "])";
+ }
+
+ std::string tmpl = R"(
+enable f16;
+
+var<private> i : i32 = 0;
+
+fn mat_f32() -> ${f32_mat_type} {
+ i = (i + 1);
+ return ${f32_mat_type}(${f32_values});
+}
+
+fn mat_f16() -> ${f16_mat_type} {
+ i = (i + 1);
+ return ${f16_mat_type}(${f16_values});
+}
+${helper_function}
+@fragment
+fn main() {
+ let m32 : ${f32_mat_type} = ${f32_matrix_conversion};
+ let m16 : ${f16_mat_type} = ${f16_matrix_conversion};
+}
+)";
+ tmpl = utils::ReplaceAll(tmpl, "${f32_values}", f32_vector_values);
+ tmpl = utils::ReplaceAll(tmpl, "${f16_values}", f16_vector_values);
+ auto src = utils::ReplaceAll(tmpl, "${f32_matrix_conversion}", "${f32_mat_type}(mat_f16())");
+ src = utils::ReplaceAll(src, "${f16_matrix_conversion}", "${f16_mat_type}(mat_f32())");
+ src = utils::ReplaceAll(src, "${helper_function}", "");  // the input has no helper functions
+ src = utils::ReplaceAll(src, "${f32_mat_type}", f32_mat_type);  // also resolves the two above
+ src = utils::ReplaceAll(src, "${f16_mat_type}", f16_mat_type);
+
+ auto helper_function = std::string(R"(
+fn convert_${mat_shape}_f16_f32(value : ${f16_mat_type}) -> ${f32_mat_type} {
+ return ${f32_mat_type}(${f32_vectorized_args});
+}
+
+fn convert_${mat_shape}_f32_f16(value : ${f32_mat_type}) -> ${f16_mat_type} {
+ return ${f16_mat_type}(${f16_vectorized_args});
+}
+)");
+ auto expect = utils::ReplaceAll(tmpl, "${helper_function}", helper_function);
+ expect = utils::ReplaceAll(expect, "${f32_mat_type}", f32_mat_type);
+ expect = utils::ReplaceAll(expect, "${f16_mat_type}", f16_mat_type);
+ expect = utils::ReplaceAll(expect, "${f32_matrix_conversion}",
+ "convert_${mat_shape}_f16_f32(mat_f16())");
+ expect = utils::ReplaceAll(expect, "${f16_matrix_conversion}",
+ "convert_${mat_shape}_f32_f16(mat_f32())");
+ expect = utils::ReplaceAll(expect, "${mat_shape}", mat_shape);  // helper names and call sites
+ expect = utils::ReplaceAll(expect, "${f32_vectorized_args}", f32_vectorized_args);
+ expect = utils::ReplaceAll(expect, "${f16_vectorized_args}", f16_vectorized_args);
+
+ EXPECT_TRUE(ShouldRun<VectorizeMatrixConversions>(src));
+
+ auto got = Run<VectorizeMatrixConversions>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+// Test that VectorizeMatrixConversions transform will not run for a matrix initializer.
+TEST_P(VectorizeMatrixConversionsTest, NonConversion_InitializerFromVectors) {
+ uint32_t cols = GetParam().first;
+ uint32_t rows = GetParam().second;
+ std::string mat_type = "mat" + std::to_string(cols) + "x" + std::to_string(rows) + "<f32>";
+ std::string vec_type = "vec" + std::to_string(rows) + "<f32>";
+ std::string columns;  // one argument-less vector constructor per matrix column
+ for (uint32_t c = 0; c < cols; c++) {
+ if (c > 0) {
+ columns += ", ";
+ }
+ columns += vec_type + "()";
+ }
+
+ std::string tmpl = R"(
+@fragment
+fn main() {
+ let m = ${matrix}(${columns});
+}
+)";
+ tmpl = utils::ReplaceAll(tmpl, "${matrix}", mat_type);
+ auto src = utils::ReplaceAll(tmpl, "${columns}", columns);
+ auto expect = src;  // the transform must leave the module unchanged
+
+ EXPECT_FALSE(ShouldRun<VectorizeMatrixConversions>(src));
+
+ auto got = Run<VectorizeMatrixConversions>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+// Test that VectorizeMatrixConversions transform will not run for an identity matrix initializer,
+// which also takes a single matrix as input.
+TEST_P(VectorizeMatrixConversionsTest, NonConversion_IdentityInitializer) {
+ uint32_t cols = GetParam().first;
+ uint32_t rows = GetParam().second;
+ std::string mat_type = "mat" + std::to_string(cols) + "x" + std::to_string(rows) + "<f32>";
+ std::string vec_type = "vec" + std::to_string(rows) + "<f32>";
+ std::string columns;  // one argument-less vector constructor per matrix column
+ for (uint32_t c = 0; c < cols; c++) {
+ if (c > 0) {
+ columns += ", ";
+ }
+ columns += vec_type + "()";
+ }
+
+ std::string tmpl = R"(
+@fragment
+fn main() {
+ let m = ${matrix}(${columns});
+ let n : ${matrix} = ${matrix}(m);
+}
+)";
+ tmpl = utils::ReplaceAll(tmpl, "${matrix}", mat_type);
+ auto src = utils::ReplaceAll(tmpl, "${columns}", columns);
+ auto expect = src;  // same source and destination type: the module must stay unchanged
+
+ EXPECT_FALSE(ShouldRun<VectorizeMatrixConversions>(src));
+
+ auto got = Run<VectorizeMatrixConversions>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+INSTANTIATE_TEST_SUITE_P(VectorizeMatrixConversionsTest,
+ VectorizeMatrixConversionsTest,
+ testing::Values(std::make_pair(2, 2),  // each pair is (columns, rows)
+ std::make_pair(2, 3),
+ std::make_pair(2, 4),
+ std::make_pair(3, 2),
+ std::make_pair(3, 3),
+ std::make_pair(3, 4),
+ std::make_pair(4, 2),
+ std::make_pair(4, 3),
+ std::make_pair(4, 4)));
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/vectorize_scalar_matrix_initializers.cc b/src/tint/lang/wgsl/ast/transform/vectorize_scalar_matrix_initializers.cc
new file mode 100644
index 0000000..763e3da
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/vectorize_scalar_matrix_initializers.cc
@@ -0,0 +1,146 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/vectorize_scalar_matrix_initializers.h"
+
+#include <unordered_map>
+#include <utility>
+
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/call.h"
+#include "src/tint/sem/value_constructor.h"
+#include "src/tint/sem/value_expression.h"
+#include "src/tint/type/abstract_numeric.h"
+#include "src/tint/utils/map.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::VectorizeScalarMatrixInitializers);
+
+namespace tint::ast::transform {
+namespace {
+
+bool ShouldRun(const Program* program) {  // true iff a matrix constructor's first arg is scalar
+    for (auto* ast_node : program->ASTNodes().Objects()) {
+        auto* ctor_call = program->Sem().Get<sem::Call>(ast_node);
+        if (!ctor_call || !ctor_call->Target()->Is<sem::ValueConstructor>()) {
+            continue;  // no semantic call for this node, or not a value constructor call
+        }
+        if (auto& inputs = ctor_call->Arguments(); ctor_call->Type()->Is<type::Matrix>() &&
+            !inputs.IsEmpty() && inputs[0]->Type()->UnwrapRef()->Is<type::Scalar>()) {
+            return true;
+        }
+    }
+    return false;
+}
+
+} // namespace
+
+VectorizeScalarMatrixInitializers::VectorizeScalarMatrixInitializers() = default;
+
+VectorizeScalarMatrixInitializers::~VectorizeScalarMatrixInitializers() = default;
+
+Transform::ApplyResult VectorizeScalarMatrixInitializers::Apply(const Program* src,
+ const DataMap&,
+ DataMap&) const {
+ if (!ShouldRun(src)) {
+ return SkipTransform;  // no matrix initializer with a scalar first argument was found
+ }
+
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+
+ std::unordered_map<const type::Matrix*, Symbol> scalar_inits;  // matrix type -> helper name
+
+ ctx.ReplaceAll([&](const CallExpression* expr) -> const CallExpression* {
+ auto* call = src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>();
+ auto* ty_init = call->Target()->As<sem::ValueConstructor>();
+ if (!ty_init) {
+ return nullptr;
+ }
+ auto* mat_type = call->Type()->As<type::Matrix>();
+ if (!mat_type) {
+ return nullptr;
+ }
+
+ auto& args = call->Arguments();
+ if (args.IsEmpty()) {
+ return nullptr;
+ }
+
+ // If the argument type is a matrix, then this is an identity / conversion initializer.
+ // If the argument type is a vector, then the arguments are already column vectors.
+ // If the argument type is abstract, then this is a const-expression and there's no need to
+ // adjust this, as it'll be constant folded by the backend.
+ if (args[0]
+ ->Type()
+ ->UnwrapRef()
+ ->IsAnyOf<type::Matrix, type::Vector, type::AbstractNumeric>()) {
+ return nullptr;
+ }
+
+ // Constructs a matrix using vector columns, with the elements constructed using the
+ // 'element(uint32_t c, uint32_t r)' callback.
+ auto build_mat = [&](auto&& element) {
+ utils::Vector<const Expression*, 4> columns;
+ for (uint32_t c = 0; c < mat_type->columns(); c++) {
+ utils::Vector<const Expression*, 4> row_values;  // the elements of column 'c'
+ for (uint32_t r = 0; r < mat_type->rows(); r++) {
+ row_values.Push(element(c, r));
+ }
+
+ // Construct the column vector.
+ columns.Push(b.vec(CreateASTTypeFor(ctx, mat_type->type()), mat_type->rows(),
+ std::move(row_values)));
+ }
+ return b.Call(CreateASTTypeFor(ctx, mat_type), columns);
+ };
+
+ if (args.Length() == 1) {
+ // Generate a helper function for constructing the matrix.
+ // This is done to ensure that the single argument value is only evaluated once, and
+ // with the correct expression evaluation order.
+ auto fn = utils::GetOrCreate(scalar_inits, mat_type, [&] {
+ auto name = b.Symbols().New("build_mat" + std::to_string(mat_type->columns()) +
+ "x" + std::to_string(mat_type->rows()));
+ b.Func(name,
+ utils::Vector{
+ // Single scalar parameter
+ b.Param("value", CreateASTTypeFor(ctx, mat_type->type())),
+ },
+ CreateASTTypeFor(ctx, mat_type),
+ utils::Vector{
+ b.Return(build_mat([&](uint32_t, uint32_t) { //
+ return b.Expr("value");
+ })),
+ });
+ return name;
+ });
+ return b.Call(fn, ctx.Clone(args[0]->Declaration()));
+ }
+
+ if (TINT_LIKELY(args.Length() == mat_type->columns() * mat_type->rows())) {
+ return build_mat([&](uint32_t c, uint32_t r) {  // one scalar argument per element
+ return ctx.Clone(args[c * mat_type->rows() + r]->Declaration());
+ });
+ }
+
+ TINT_ICE(Transform, b.Diagnostics())
+ << "matrix initializer has unexpected number of arguments";
+ return nullptr;
+ });
+
+ ctx.Clone();
+ return Program(std::move(b));
+}
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/vectorize_scalar_matrix_initializers.h b/src/tint/lang/wgsl/ast/transform/vectorize_scalar_matrix_initializers.h
new file mode 100644
index 0000000..519975d
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/vectorize_scalar_matrix_initializers.h
@@ -0,0 +1,40 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_VECTORIZE_SCALAR_MATRIX_INITIALIZERS_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_VECTORIZE_SCALAR_MATRIX_INITIALIZERS_H_
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+
+namespace tint::ast::transform {
+
+/// A transform that rewrites matrix initializers taking scalar arguments into column-vector form.
+class VectorizeScalarMatrixInitializers final
+ : public utils::Castable<VectorizeScalarMatrixInitializers, Transform> {
+ public:
+ /// Constructor
+ VectorizeScalarMatrixInitializers();
+
+ /// Destructor
+ ~VectorizeScalarMatrixInitializers() override;
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_VECTORIZE_SCALAR_MATRIX_INITIALIZERS_H_
diff --git a/src/tint/lang/wgsl/ast/transform/vectorize_scalar_matrix_initializers_test.cc b/src/tint/lang/wgsl/ast/transform/vectorize_scalar_matrix_initializers_test.cc
new file mode 100644
index 0000000..26a8670
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/vectorize_scalar_matrix_initializers_test.cc
@@ -0,0 +1,162 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/vectorize_scalar_matrix_initializers.h"
+
+#include <string>
+#include <utility>
+
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+#include "src/tint/utils/string.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using VectorizeScalarMatrixInitializersTest = TransformTestWithParam<std::pair<uint32_t, uint32_t>>;
+
+TEST_F(VectorizeScalarMatrixInitializersTest, ShouldRunEmptyModule) {
+ auto* src = R"()";  // empty module: contains no matrix initializer
+
+ EXPECT_FALSE(ShouldRun<VectorizeScalarMatrixInitializers>(src));
+}
+
+TEST_P(VectorizeScalarMatrixInitializersTest, MultipleScalars) {
+ uint32_t cols = GetParam().first;
+ uint32_t rows = GetParam().second;
+ std::string mat_type = "mat" + std::to_string(cols) + "x" + std::to_string(rows) + "<f32>";
+ std::string vec_type = "vec" + std::to_string(rows) + "<f32>";
+ std::string scalar_values;  // input arguments: flat list 0.0, 1.0, 2.0, ...
+ std::string vector_values;  // expected arguments: the same values grouped into column vectors
+ for (uint32_t c = 0; c < cols; c++) {
+ if (c > 0) {
+ vector_values += ", ";
+ scalar_values += ", ";
+ }
+ vector_values += vec_type + "(";
+ for (uint32_t r = 0; r < rows; r++) {
+ if (r > 0) {
+ scalar_values += ", ";
+ vector_values += ", ";
+ }
+ auto value = std::to_string(c * rows + r) + ".0";
+ scalar_values += value;
+ vector_values += value;
+ }
+ vector_values += ")";
+ }
+
+ std::string tmpl = R"(
+@fragment
+fn main() {
+ let m = ${matrix}(${values});
+}
+)";
+ tmpl = utils::ReplaceAll(tmpl, "${matrix}", mat_type);
+ auto src = utils::ReplaceAll(tmpl, "${values}", scalar_values);
+ auto expect = utils::ReplaceAll(tmpl, "${values}", vector_values);
+
+ EXPECT_TRUE(ShouldRun<VectorizeScalarMatrixInitializers>(src));
+
+ auto got = Run<VectorizeScalarMatrixInitializers>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(VectorizeScalarMatrixInitializersTest, MultipleScalarsReference) {
+ uint32_t cols = GetParam().first;
+ uint32_t rows = GetParam().second;
+ std::string mat_type = "mat" + std::to_string(cols) + "x" + std::to_string(rows) + "<f32>";
+ std::string vec_type = "vec" + std::to_string(rows) + "<f32>";
+ std::string scalar_values;  // input arguments: flat list of v[k] accesses
+ std::string vector_values;  // expected arguments: the same accesses grouped into columns
+ for (uint32_t c = 0; c < cols; c++) {
+ if (c > 0) {
+ vector_values += ", ";
+ scalar_values += ", ";
+ }
+ vector_values += vec_type + "(";
+ for (uint32_t r = 0; r < rows; r++) {
+ if (r > 0) {
+ scalar_values += ", ";
+ vector_values += ", ";
+ }
+ auto value = "v[" + std::to_string((c * rows + r) % 4) + "]";  // index wraps within vec4
+ scalar_values += value;
+ vector_values += value;
+ }
+ vector_values += ")";
+ }
+
+ std::string tmpl = R"(
+@fragment
+fn main() {
+ let v = vec4<f32>(1.0, 2.0, 3.0, 8.0);
+ let m = ${matrix}(${values});
+}
+)";
+ tmpl = utils::ReplaceAll(tmpl, "${matrix}", mat_type);
+ auto src = utils::ReplaceAll(tmpl, "${values}", scalar_values);
+ auto expect = utils::ReplaceAll(tmpl, "${values}", vector_values);
+
+ EXPECT_TRUE(ShouldRun<VectorizeScalarMatrixInitializers>(src));
+
+ auto got = Run<VectorizeScalarMatrixInitializers>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(VectorizeScalarMatrixInitializersTest, NonScalarInitializers) {
+ uint32_t cols = GetParam().first;
+ uint32_t rows = GetParam().second;
+ std::string mat_type = "mat" + std::to_string(cols) + "x" + std::to_string(rows) + "<f32>";
+ std::string vec_type = "vec" + std::to_string(rows) + "<f32>";
+ std::string columns;  // one argument-less vector constructor per matrix column
+ for (uint32_t c = 0; c < cols; c++) {
+ if (c > 0) {
+ columns += ", ";
+ }
+ columns += vec_type + "()";
+ }
+
+ std::string tmpl = R"(
+@fragment
+fn main() {
+ let m = ${matrix}(${columns});
+}
+)";
+ tmpl = utils::ReplaceAll(tmpl, "${matrix}", mat_type);
+ auto src = utils::ReplaceAll(tmpl, "${columns}", columns);
+ auto expect = src;  // arguments are already vectors: the module must stay unchanged
+
+ EXPECT_FALSE(ShouldRun<VectorizeScalarMatrixInitializers>(src));
+
+ auto got = Run<VectorizeScalarMatrixInitializers>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+INSTANTIATE_TEST_SUITE_P(VectorizeScalarMatrixInitializersTest,
+ VectorizeScalarMatrixInitializersTest,
+ testing::Values(std::make_pair(2, 2),  // each pair is (columns, rows)
+ std::make_pair(2, 3),
+ std::make_pair(2, 4),
+ std::make_pair(3, 2),
+ std::make_pair(3, 3),
+ std::make_pair(3, 4),
+ std::make_pair(4, 2),
+ std::make_pair(4, 3),
+ std::make_pair(4, 4)));
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/vertex_pulling.cc b/src/tint/lang/wgsl/ast/transform/vertex_pulling.cc
new file mode 100644
index 0000000..3857d16
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/vertex_pulling.cc
@@ -0,0 +1,983 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/vertex_pulling.h"
+
+#include <algorithm>
+#include <utility>
+
+#include "src/tint/builtin/builtin_value.h"
+#include "src/tint/lang/wgsl/ast/assignment_statement.h"
+#include "src/tint/lang/wgsl/ast/bitcast_expression.h"
+#include "src/tint/lang/wgsl/ast/variable_decl_statement.h"
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/variable.h"
+#include "src/tint/switch.h"
+#include "src/tint/utils/compiler_macros.h"
+#include "src/tint/utils/map.h"
+#include "src/tint/utils/math.h"
+#include "src/tint/utils/string_stream.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::VertexPulling);
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::VertexPulling::Config);
+
+namespace tint::ast::transform {
+
+using namespace tint::builtin::fluent_types; // NOLINT
+using namespace tint::number_suffixes; // NOLINT
+
+namespace {
+
+/// The base WGSL type of a component.
+/// The format type is either this type or a vector of this type.
+enum class BaseWGSLType {
+ kInvalid,
+ kU32,
+ kI32,
+ kF32,
+ kF16,
+};
+
+/// The data type of a vertex format.
+/// The format type is either this type or a vector of this type.
+enum class VertexDataType {
+ kInvalid,
+ kUInt, // unsigned int
+ kSInt, // signed int
+ kFloat, // unsigned normalized, signed normalized, and float
+};
+
+/// Writes the VertexFormat to the stream.
+/// @param out the stream to write to
+/// @param format the VertexFormat to write
+/// @returns out so calls can be chained
+utils::StringStream& operator<<(utils::StringStream& out, VertexFormat format) {
+ switch (format) {
+ case VertexFormat::kUint8x2:
+ return out << "uint8x2";
+ case VertexFormat::kUint8x4:
+ return out << "uint8x4";
+ case VertexFormat::kSint8x2:
+ return out << "sint8x2";
+ case VertexFormat::kSint8x4:
+ return out << "sint8x4";
+ case VertexFormat::kUnorm8x2:
+ return out << "unorm8x2";
+ case VertexFormat::kUnorm8x4:
+ return out << "unorm8x4";
+ case VertexFormat::kSnorm8x2:
+ return out << "snorm8x2";
+ case VertexFormat::kSnorm8x4:
+ return out << "snorm8x4";
+ case VertexFormat::kUint16x2:
+ return out << "uint16x2";
+ case VertexFormat::kUint16x4:
+ return out << "uint16x4";
+ case VertexFormat::kSint16x2:
+ return out << "sint16x2";
+ case VertexFormat::kSint16x4:
+ return out << "sint16x4";
+ case VertexFormat::kUnorm16x2:
+ return out << "unorm16x2";
+ case VertexFormat::kUnorm16x4:
+ return out << "unorm16x4";
+ case VertexFormat::kSnorm16x2:
+ return out << "snorm16x2";
+ case VertexFormat::kSnorm16x4:
+ return out << "snorm16x4";
+ case VertexFormat::kFloat16x2:
+ return out << "float16x2";
+ case VertexFormat::kFloat16x4:
+ return out << "float16x4";
+ case VertexFormat::kFloat32:
+ return out << "float32";
+ case VertexFormat::kFloat32x2:
+ return out << "float32x2";
+ case VertexFormat::kFloat32x3:
+ return out << "float32x3";
+ case VertexFormat::kFloat32x4:
+ return out << "float32x4";
+ case VertexFormat::kUint32:
+ return out << "uint32";
+ case VertexFormat::kUint32x2:
+ return out << "uint32x2";
+ case VertexFormat::kUint32x3:
+ return out << "uint32x3";
+ case VertexFormat::kUint32x4:
+ return out << "uint32x4";
+ case VertexFormat::kSint32:
+ return out << "sint32";
+ case VertexFormat::kSint32x2:
+ return out << "sint32x2";
+ case VertexFormat::kSint32x3:
+ return out << "sint32x3";
+ case VertexFormat::kSint32x4:
+ return out << "sint32x4";
+ }
+ return out << "<unknown>";
+}
+
+/// Type information of a vertex input attribute.
+struct AttributeWGSLType {
+ BaseWGSLType base_type;
+ uint32_t width; // 1 for scalar, 2+ for a vector
+};
+
+/// Type information of a vertex format.
+struct VertexFormatType {
+ VertexDataType base_type;
+ uint32_t width; // 1 for scalar, 2+ for a vector
+};
+
+// Check if base types match between the WGSL variable and the vertex format
+bool IsTypeCompatible(AttributeWGSLType wgslType, VertexFormatType vertexFormatType) {
+ switch (wgslType.base_type) {
+ case BaseWGSLType::kF32:
+ case BaseWGSLType::kF16:
+ return (vertexFormatType.base_type == VertexDataType::kFloat);
+ case BaseWGSLType::kU32:
+ return (vertexFormatType.base_type == VertexDataType::kUInt);
+ case BaseWGSLType::kI32:
+ return (vertexFormatType.base_type == VertexDataType::kSInt);
+ default:
+ return false;
+ }
+}
+
+AttributeWGSLType WGSLTypeOf(const type::Type* ty) {
+ return Switch(
+ ty,
+ [](const type::I32*) -> AttributeWGSLType {
+ return {BaseWGSLType::kI32, 1};
+ },
+ [](const type::U32*) -> AttributeWGSLType {
+ return {BaseWGSLType::kU32, 1};
+ },
+ [](const type::F32*) -> AttributeWGSLType {
+ return {BaseWGSLType::kF32, 1};
+ },
+ [](const type::F16*) -> AttributeWGSLType {
+ return {BaseWGSLType::kF16, 1};
+ },
+ [](const type::Vector* vec) -> AttributeWGSLType {
+ return {WGSLTypeOf(vec->type()).base_type, vec->Width()};
+ },
+ [](Default) -> AttributeWGSLType {
+ return {BaseWGSLType::kInvalid, 0};
+ });
+}
+
+VertexFormatType VertexFormatTypeOf(VertexFormat format) {
+ switch (format) {
+ case VertexFormat::kUint32:
+ return {VertexDataType::kUInt, 1};
+ case VertexFormat::kUint8x2:
+ case VertexFormat::kUint16x2:
+ case VertexFormat::kUint32x2:
+ return {VertexDataType::kUInt, 2};
+ case VertexFormat::kUint32x3:
+ return {VertexDataType::kUInt, 3};
+ case VertexFormat::kUint8x4:
+ case VertexFormat::kUint16x4:
+ case VertexFormat::kUint32x4:
+ return {VertexDataType::kUInt, 4};
+ case VertexFormat::kSint32:
+ return {VertexDataType::kSInt, 1};
+ case VertexFormat::kSint8x2:
+ case VertexFormat::kSint16x2:
+ case VertexFormat::kSint32x2:
+ return {VertexDataType::kSInt, 2};
+ case VertexFormat::kSint32x3:
+ return {VertexDataType::kSInt, 3};
+ case VertexFormat::kSint8x4:
+ case VertexFormat::kSint16x4:
+ case VertexFormat::kSint32x4:
+ return {VertexDataType::kSInt, 4};
+ case VertexFormat::kFloat32:
+ return {VertexDataType::kFloat, 1};
+ case VertexFormat::kUnorm8x2:
+ case VertexFormat::kSnorm8x2:
+ case VertexFormat::kUnorm16x2:
+ case VertexFormat::kSnorm16x2:
+ case VertexFormat::kFloat16x2:
+ case VertexFormat::kFloat32x2:
+ return {VertexDataType::kFloat, 2};
+ case VertexFormat::kFloat32x3:
+ return {VertexDataType::kFloat, 3};
+ case VertexFormat::kUnorm8x4:
+ case VertexFormat::kSnorm8x4:
+ case VertexFormat::kUnorm16x4:
+ case VertexFormat::kSnorm16x4:
+ case VertexFormat::kFloat16x4:
+ case VertexFormat::kFloat32x4:
+ return {VertexDataType::kFloat, 4};
+ }
+ return {VertexDataType::kInvalid, 0};
+}
+
+} // namespace
+
+/// PIMPL state for the transform
+struct VertexPulling::State {
+ /// Constructor
+ /// @param program the source program
+ /// @param c the VertexPulling config
+ State(const Program* program, const VertexPulling::Config& c) : src(program), cfg(c) {}
+
+ /// Runs the transform
+ /// @returns the new program; on failure it is built from a builder carrying the error
+ ApplyResult Run() {
+ // Find entry point
+ const Function* func = nullptr;
+ for (auto* fn : src->AST().Functions()) {
+ if (fn->PipelineStage() == PipelineStage::kVertex) {
+ if (func != nullptr) {
+ b.Diagnostics().add_error(
+ diag::System::Transform,
+ "VertexPulling found more than one vertex entry point");
+ return Program(std::move(b));
+ }
+ func = fn;
+ }
+ }
+ if (func == nullptr) {
+ b.Diagnostics().add_error(diag::System::Transform,
+ "Vertex stage entry point not found");
+ return Program(std::move(b));
+ }
+
+ AddVertexStorageBuffers();
+ Process(func);
+
+ ctx.Clone();
+ return Program(std::move(b));
+ }
+
+ private:
+ /// LocationReplacement describes a Variable replacement for a location input.
+ struct LocationReplacement {
+ /// The variable to replace in the source Program
+ Variable* from;
+ /// The replacement to use in the target ProgramBuilder
+ Variable* to;
+ };
+
+ /// LocationInfo describes an input location
+ struct LocationInfo {
+ /// A builder that builds the expression that resolves to the (transformed) input location
+ std::function<const Expression*()> expr;
+ /// The store type of the location variable
+ const type::Type* type;
+ };
+
+ /// The source program
+ const Program* const src;
+ /// The transform config
+ VertexPulling::Config const cfg;
+ /// The target program builder
+ ProgramBuilder b;
+ /// The clone context
+ CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
+ std::unordered_map<uint32_t, LocationInfo> location_info;
+ std::function<const Expression*()> vertex_index_expr = nullptr;
+ std::function<const Expression*()> instance_index_expr = nullptr;
+ Symbol pulling_position_name;
+ Symbol struct_buffer_name;
+ std::unordered_map<uint32_t, Symbol> vertex_buffer_names;
+ utils::Vector<const Parameter*, 8> new_function_parameters;
+
+ /// Generate the vertex buffer binding name
+ /// @param index index to append to buffer name
+ Symbol GetVertexBufferName(uint32_t index) {
+ return utils::GetOrCreate(vertex_buffer_names, index, [&] {
+ static const char kVertexBufferNamePrefix[] = "tint_pulling_vertex_buffer_";
+ return b.Symbols().New(kVertexBufferNamePrefix + std::to_string(index));
+ });
+ }
+
+ /// Lazily generates the structure buffer symbol
+ Symbol GetStructBufferName() {
+ if (!struct_buffer_name.IsValid()) {
+ static const char kStructBufferName[] = "tint_vertex_data";
+ struct_buffer_name = b.Symbols().New(kStructBufferName);
+ }
+ return struct_buffer_name;
+ }
+
+ /// Adds storage buffer decorated variables for the vertex buffers
+ void AddVertexStorageBuffers() {
+ // Creating the struct type
+ static const char kStructName[] = "TintVertexData";
+ auto* struct_type = b.Structure(b.Symbols().New(kStructName),
+ utils::Vector{
+ b.Member(GetStructBufferName(), b.ty.array<u32>()),
+ });
+ for (uint32_t i = 0; i < cfg.vertex_state.size(); ++i) {
+ // The decorated variable with struct type
+ b.GlobalVar(GetVertexBufferName(i), b.ty.Of(struct_type),
+ builtin::AddressSpace::kStorage, builtin::Access::kRead, b.Binding(AInt(i)),
+ b.Group(AInt(cfg.pulling_group)));
+ }
+ }
+
+ /// Creates and returns the assignment to the variables from the buffers
+ const BlockStatement* CreateVertexPullingPreamble() {
+ // Assign by looking at the vertex descriptor to find attributes with
+ // matching location.
+
+ utils::Vector<const Statement*, 8> stmts;
+
+ for (uint32_t buffer_idx = 0; buffer_idx < cfg.vertex_state.size(); ++buffer_idx) {
+ const VertexBufferLayoutDescriptor& buffer_layout = cfg.vertex_state[buffer_idx];
+
+ if ((buffer_layout.array_stride & 3) != 0) {
+ b.Diagnostics().add_error(
+ diag::System::Transform,
+ "WebGPU requires that vertex stride must be a multiple of 4 bytes, "
+ "but VertexPulling array stride for buffer " +
+ std::to_string(buffer_idx) + " was " +
+ std::to_string(buffer_layout.array_stride) + " bytes");
+ return nullptr;
+ }
+
+ auto* index_expr = buffer_layout.step_mode == VertexStepMode::kVertex
+ ? vertex_index_expr()
+ : instance_index_expr();
+
+ // buffer_array_base is the base array offset for all the vertex
+ // attributes. These are units of uint (4 bytes).
+ auto buffer_array_base =
+ b.Symbols().New("buffer_array_base_" + std::to_string(buffer_idx));
+
+ auto* attribute_offset = index_expr;
+ if (buffer_layout.array_stride != 4) {
+ attribute_offset = b.Mul(index_expr, u32(buffer_layout.array_stride / 4u));
+ }
+
+ // let buffer_array_base_n = <attribute_offset>
+ stmts.Push(b.Decl(b.Let(buffer_array_base, attribute_offset)));
+
+ for (const VertexAttributeDescriptor& attribute_desc : buffer_layout.attributes) {
+ auto it = location_info.find(attribute_desc.shader_location);
+ if (it == location_info.end()) {
+ continue;
+ }
+ auto& var = it->second;
+
+ // Data type of the target WGSL variable
+ auto var_dt = WGSLTypeOf(var.type);
+ // Data type of the vertex stream attribute
+ auto fmt_dt = VertexFormatTypeOf(attribute_desc.format);
+
+ // Base types must match between the vertex stream and the WGSL variable
+ if (!IsTypeCompatible(var_dt, fmt_dt)) {
+ utils::StringStream err;
+ err << "VertexAttributeDescriptor for location "
+ << std::to_string(attribute_desc.shader_location) << " has format "
+ << attribute_desc.format << " but shader expects "
+ << var.type->FriendlyName();
+ b.Diagnostics().add_error(diag::System::Transform, err.str());
+ return nullptr;
+ }
+
+ // Load the attribute value according to vertex format and convert the element type
+ // of result to match target WGSL variable. The result of `Fetch` should be of WGSL
+ // types `f32`, `i32`, `u32`, and their vectors, while WGSL variable can be of
+ // `f16`.
+ auto* fetch = Fetch(buffer_array_base, attribute_desc.offset, buffer_idx,
+ attribute_desc.format);
+ // Convert the fetched scalar/vector if WGSL variable is of `f16` types
+ if (var_dt.base_type == BaseWGSLType::kF16) {
+ // f16 type with the same element count as the loaded data (scalar or vector)
+ Type loaded_data_target_type;
+ if (fmt_dt.width == 1) {
+ loaded_data_target_type = b.ty.f16();
+ } else {
+ loaded_data_target_type = b.ty.vec(b.ty.f16(), fmt_dt.width);
+ }
+
+ fetch = b.Call(loaded_data_target_type, fetch);
+ }
+
+ // The attribute value may not be of the desired vector width. If it is not, we'll
+ // need to either reduce the width with a swizzle, or append 0's and / or a 1.
+ auto* value = fetch;
+ if (var_dt.width < fmt_dt.width) {
+ // WGSL variable vector width is smaller than the loaded vector width
+ switch (var_dt.width) {
+ case 1:
+ value = b.MemberAccessor(fetch, "x");
+ break;
+ case 2:
+ value = b.MemberAccessor(fetch, "xy");
+ break;
+ case 3:
+ value = b.MemberAccessor(fetch, "xyz");
+ break;
+ default:
+ TINT_UNREACHABLE(Transform, b.Diagnostics()) << var_dt.width;
+ return nullptr;
+ }
+ } else if (var_dt.width > fmt_dt.width) {
+ // WGSL variable vector width is wider than the loaded vector width, do padding.
+
+ // The components of result vector variable, initialized with type-converted
+ // loaded data vector.
+ utils::Vector<const Expression*, 8> values{fetch};
+
+ // Add padding elements. The result must be of vector types of signed/unsigned
+ // integer or float, so use the abstract integer or abstract float value to do
+ // padding.
+ for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) {
+ if (var_dt.base_type == BaseWGSLType::kI32 ||
+ var_dt.base_type == BaseWGSLType::kU32) {
+ values.Push(b.Expr((i == 3) ? 1_a : 0_a));
+ } else {
+ values.Push(b.Expr((i == 3) ? 1.0_a : 0.0_a));
+ }
+ }
+
+ value = b.Call(CreateASTTypeFor(ctx, var.type), values);
+ }
+
+ // Assign the value to the WGSL variable
+ stmts.Push(b.Assign(var.expr(), value));
+ }
+ }
+
+ if (stmts.IsEmpty()) {
+ return nullptr;
+ }
+
+ return b.Block(std::move(stmts));
+ }
+
+ /// Generates an expression reading a specific vertex format from a buffer. Any vertex format of
+ /// signed normalized, unsigned normalized, or float will result in `f32` or `vecN<f32>` WGSL
+ /// type.
+ /// @param array_base the symbol of the variable holding the base array offset
+ /// of the vertex array (each index is 4-bytes).
+ /// @param offset the byte offset of the data, relative to `array_base`
+ /// @param buffer the index of the vertex buffer
+ /// @param format the vertex format to read
+ const Expression* Fetch(Symbol array_base,
+ uint32_t offset,
+ uint32_t buffer,
+ VertexFormat format) {
+ // Returns a u32 loaded from buffer_base + offset.
+ auto load_u32 = [&] {
+ return LoadPrimitive(array_base, offset, buffer, VertexFormat::kUint32);
+ };
+
+ // Returns a i32 loaded from buffer_base + offset.
+ auto load_i32 = [&] { return b.Bitcast<i32>(load_u32()); };
+
+ // Returns a u32 loaded from buffer_base + offset + 4.
+ auto load_next_u32 = [&] {
+ return LoadPrimitive(array_base, offset + 4, buffer, VertexFormat::kUint32);
+ };
+
+ // Returns a i32 loaded from buffer_base + offset + 4.
+ auto load_next_i32 = [&] { return b.Bitcast<i32>(load_next_u32()); };
+
+ // Returns a u16 loaded from offset, packed in the high 16 bits of a u32.
+ // The low 16 bits are 0.
+ // `offset` may have any byte alignment: the aligned u32 containing it is loaded,
+ // plus the following u32 when `offset & 3` is 3 and the value straddles both.
+ auto load_u16_h = [&] {
+ auto low_u32_offset = offset & ~3u;
+ auto* low_u32 =
+ LoadPrimitive(array_base, low_u32_offset, buffer, VertexFormat::kUint32);
+ switch (offset & 3) {
+ case 0:
+ return b.Shl(low_u32, 16_u);
+ case 1:
+ return b.And(b.Shl(low_u32, 8_u), 0xffff0000_u);
+ case 2:
+ return b.And(low_u32, 0xffff0000_u);
+ default: { // 3:
+ auto* high_u32 = LoadPrimitive(array_base, low_u32_offset + 4, buffer,
+ VertexFormat::kUint32);
+ auto* shr = b.Shr(low_u32, 8_u);
+ auto* shl = b.Shl(high_u32, 24_u);
+ return b.And(b.Or(shl, shr), 0xffff0000_u);
+ }
+ }
+ };
+
+ // Returns a u16 loaded from offset, packed in the low 16 bits of a u32.
+ // The high 16 bits are 0.
+ auto load_u16_l = [&] {
+ auto low_u32_offset = offset & ~3u;
+ auto* low_u32 =
+ LoadPrimitive(array_base, low_u32_offset, buffer, VertexFormat::kUint32);
+ switch (offset & 3) {
+ case 0:
+ return b.And(low_u32, 0xffff_u);
+ case 1:
+ return b.And(b.Shr(low_u32, 8_u), 0xffff_u);
+ case 2:
+ return b.Shr(low_u32, 16_u);
+ default: { // 3:
+ auto* high_u32 = LoadPrimitive(array_base, low_u32_offset + 4, buffer,
+ VertexFormat::kUint32);
+ auto* shr = b.Shr(low_u32, 24_u);
+ auto* shl = b.Shl(high_u32, 8_u);
+ return b.And(b.Or(shl, shr), 0xffff_u);
+ }
+ }
+ };
+
+ // Returns a i16 loaded from offset, packed in the high 16 bits of a u32.
+ // The low 16 bits are 0.
+ auto load_i16_h = [&] { return b.Bitcast<i32>(load_u16_h()); };
+
+ // Assumptions are made that alignment must be at least as large as the size
+ // of a single component.
+ switch (format) {
+ // Basic primitives
+ case VertexFormat::kUint32:
+ case VertexFormat::kSint32:
+ case VertexFormat::kFloat32:
+ return LoadPrimitive(array_base, offset, buffer, format);
+
+ // Vectors of basic primitives
+ case VertexFormat::kUint32x2:
+ return LoadVec(array_base, offset, buffer, 4, b.ty.u32(), VertexFormat::kUint32, 2);
+ case VertexFormat::kUint32x3:
+ return LoadVec(array_base, offset, buffer, 4, b.ty.u32(), VertexFormat::kUint32, 3);
+ case VertexFormat::kUint32x4:
+ return LoadVec(array_base, offset, buffer, 4, b.ty.u32(), VertexFormat::kUint32, 4);
+ case VertexFormat::kSint32x2:
+ return LoadVec(array_base, offset, buffer, 4, b.ty.i32(), VertexFormat::kSint32, 2);
+ case VertexFormat::kSint32x3:
+ return LoadVec(array_base, offset, buffer, 4, b.ty.i32(), VertexFormat::kSint32, 3);
+ case VertexFormat::kSint32x4:
+ return LoadVec(array_base, offset, buffer, 4, b.ty.i32(), VertexFormat::kSint32, 4);
+ case VertexFormat::kFloat32x2:
+ return LoadVec(array_base, offset, buffer, 4, b.ty.f32(), VertexFormat::kFloat32,
+ 2);
+ case VertexFormat::kFloat32x3:
+ return LoadVec(array_base, offset, buffer, 4, b.ty.f32(), VertexFormat::kFloat32,
+ 3);
+ case VertexFormat::kFloat32x4:
+ return LoadVec(array_base, offset, buffer, 4, b.ty.f32(), VertexFormat::kFloat32,
+ 4);
+
+ case VertexFormat::kUint8x2: {
+ // yyxx0000, yyxx0000
+ auto* u16s = b.Call<vec2<u32>>(load_u16_h());
+ // xx000000, yyxx0000
+ auto* shl = b.Shl(u16s, b.Call<vec2<u32>>(8_u, 0_u));
+ // 000000xx, 000000yy
+ return b.Shr(shl, b.Call<vec2<u32>>(24_u));
+ }
+ case VertexFormat::kUint8x4: {
+ // wwzzyyxx, wwzzyyxx, wwzzyyxx, wwzzyyxx
+ auto* u32s = b.Call<vec4<u32>>(load_u32());
+ // xx000000, yyxx0000, zzyyxx00, wwzzyyxx
+ auto* shl = b.Shl(u32s, b.Call<vec4<u32>>(24_u, 16_u, 8_u, 0_u));
+ // 000000xx, 000000yy, 000000zz, 000000ww
+ return b.Shr(shl, b.Call<vec4<u32>>(24_u));
+ }
+ case VertexFormat::kUint16x2: {
+ // yyyyxxxx, yyyyxxxx
+ auto* u32s = b.Call<vec2<u32>>(load_u32());
+ // xxxx0000, yyyyxxxx
+ auto* shl = b.Shl(u32s, b.Call<vec2<u32>>(16_u, 0_u));
+ // 0000xxxx, 0000yyyy
+ return b.Shr(shl, b.Call<vec2<u32>>(16_u));
+ }
+ case VertexFormat::kUint16x4: {
+ // yyyyxxxx, wwwwzzzz
+ auto* u32s = b.Call<vec2<u32>>(load_u32(), load_next_u32());
+ // yyyyxxxx, yyyyxxxx, wwwwzzzz, wwwwzzzz
+ auto* xxyy = b.MemberAccessor(u32s, "xxyy");
+ // xxxx0000, yyyyxxxx, zzzz0000, wwwwzzzz
+ auto* shl = b.Shl(xxyy, b.Call<vec4<u32>>(16_u, 0_u, 16_u, 0_u));
+ // 0000xxxx, 0000yyyy, 0000zzzz, 0000wwww
+ return b.Shr(shl, b.Call<vec4<u32>>(16_u));
+ }
+ case VertexFormat::kSint8x2: {
+ // yyxx0000, yyxx0000
+ auto* i16s = b.Call<vec2<i32>>(load_i16_h());
+ // xx000000, yyxx0000
+ auto* shl = b.Shl(i16s, b.Call<vec2<u32>>(8_u, 0_u));
+ // ssssssxx, ssssssyy
+ return b.Shr(shl, b.Call<vec2<u32>>(24_u));
+ }
+ case VertexFormat::kSint8x4: {
+ // wwzzyyxx, wwzzyyxx, wwzzyyxx, wwzzyyxx
+ auto* i32s = b.Call<vec4<i32>>(load_i32());
+ // xx000000, yyxx0000, zzyyxx00, wwzzyyxx
+ auto* shl = b.Shl(i32s, b.Call<vec4<u32>>(24_u, 16_u, 8_u, 0_u));
+ // ssssssxx, ssssssyy, sssssszz, ssssssww
+ return b.Shr(shl, b.Call<vec4<u32>>(24_u));
+ }
+ case VertexFormat::kSint16x2: {
+ // yyyyxxxx, yyyyxxxx
+ auto* i32s = b.Call<vec2<i32>>(load_i32());
+ // xxxx0000, yyyyxxxx
+ auto* shl = b.Shl(i32s, b.Call<vec2<u32>>(16_u, 0_u));
+ // ssssxxxx, ssssyyyy
+ return b.Shr(shl, b.Call<vec2<u32>>(16_u));
+ }
+ case VertexFormat::kSint16x4: {
+ // yyyyxxxx, wwwwzzzz
+ auto* i32s = b.Call<vec2<i32>>(load_i32(), load_next_i32());
+ // yyyyxxxx, yyyyxxxx, wwwwzzzz, wwwwzzzz
+ auto* xxyy = b.MemberAccessor(i32s, "xxyy");
+ // xxxx0000, yyyyxxxx, zzzz0000, wwwwzzzz
+ auto* shl = b.Shl(xxyy, b.Call<vec4<u32>>(16_u, 0_u, 16_u, 0_u));
+ // ssssxxxx, ssssyyyy, sssszzzz, sssswwww
+ return b.Shr(shl, b.Call<vec4<u32>>(16_u));
+ }
+ case VertexFormat::kUnorm8x2:
+ return b.MemberAccessor(b.Call("unpack4x8unorm", load_u16_l()), "xy");
+ case VertexFormat::kSnorm8x2:
+ return b.MemberAccessor(b.Call("unpack4x8snorm", load_u16_l()), "xy");
+ case VertexFormat::kUnorm8x4:
+ return b.Call("unpack4x8unorm", load_u32());
+ case VertexFormat::kSnorm8x4:
+ return b.Call("unpack4x8snorm", load_u32());
+ case VertexFormat::kUnorm16x2:
+ return b.Call("unpack2x16unorm", load_u32());
+ case VertexFormat::kSnorm16x2:
+ return b.Call("unpack2x16snorm", load_u32());
+ case VertexFormat::kFloat16x2:
+ return b.Call("unpack2x16float", load_u32());
+ case VertexFormat::kUnorm16x4:
+ return b.Call<vec4<f32>>(b.Call("unpack2x16unorm", load_u32()),
+ b.Call("unpack2x16unorm", load_next_u32()));
+ case VertexFormat::kSnorm16x4:
+ return b.Call<vec4<f32>>(b.Call("unpack2x16snorm", load_u32()),
+ b.Call("unpack2x16snorm", load_next_u32()));
+ case VertexFormat::kFloat16x4:
+ return b.Call<vec4<f32>>(b.Call("unpack2x16float", load_u32()),
+ b.Call("unpack2x16float", load_next_u32()));
+ }
+
+ TINT_UNREACHABLE(Transform, b.Diagnostics()) << "format " << static_cast<int>(format);
+ return nullptr;
+ }
+
+ /// Generates an expression reading a basic type (u32, i32, f32) from a vertex
+ /// buffer. Offsets that are not 4-byte aligned are read as two u32 loads combined.
+ /// @param array_base the symbol of the variable holding the base array offset
+ /// of the vertex array (each index is 4-bytes).
+ /// @param offset the byte offset of the data, relative to `array_base`
+ /// @param buffer the index of the vertex buffer
+ /// @param format VertexFormat::kUint32, VertexFormat::kSint32 or
+ /// VertexFormat::kFloat32
+ const Expression* LoadPrimitive(Symbol array_base,
+ uint32_t offset,
+ uint32_t buffer,
+ VertexFormat format) {
+ const Expression* u = nullptr;
+ if ((offset & 3) == 0) {
+ // Aligned load.
+
+ const ast ::Expression* index = nullptr;
+ if (offset > 0) {
+ index = b.Add(array_base, u32(offset / 4));
+ } else {
+ index = b.Expr(array_base);
+ }
+ u = b.IndexAccessor(
+ b.MemberAccessor(GetVertexBufferName(buffer), GetStructBufferName()), index);
+
+ } else {
+ // Unaligned load
+ uint32_t offset_aligned = offset & ~3u;
+ auto* low = LoadPrimitive(array_base, offset_aligned, buffer, VertexFormat::kUint32);
+ auto* high =
+ LoadPrimitive(array_base, offset_aligned + 4u, buffer, VertexFormat::kUint32);
+
+ uint32_t shift = 8u * (offset & 3u);
+
+ auto* low_shr = b.Shr(low, u32(shift));
+ auto* high_shl = b.Shl(high, u32(32u - shift));
+ u = b.Or(low_shr, high_shl);
+ }
+
+ switch (format) {
+ case VertexFormat::kUint32:
+ return u;
+ case VertexFormat::kSint32:
+ return b.Bitcast(b.ty.i32(), u);
+ case VertexFormat::kFloat32:
+ return b.Bitcast(b.ty.f32(), u);
+ default:
+ break;
+ }
+ TINT_UNREACHABLE(Transform, b.Diagnostics())
+ << "invalid format for LoadPrimitive" << static_cast<int>(format);
+ return nullptr;
+ }
+
+ /// Generates an expression reading a vec2/3/4 from a vertex buffer.
+ /// @param array_base the symbol of the variable holding the base array offset
+ /// of the vertex array (each index is 4-bytes).
+ /// @param offset the byte offset of the data, relative to `array_base`
+ /// @param buffer the index of the vertex buffer
+ /// @param element_stride stride between elements, in bytes
+ /// @param base_type underlying AST type
+ /// @param base_format underlying vertex format
+ /// @param count how many elements the vector has
+ const Expression* LoadVec(Symbol array_base,
+ uint32_t offset,
+ uint32_t buffer,
+ uint32_t element_stride,
+ Type base_type,
+ VertexFormat base_format,
+ uint32_t count) {
+ utils::Vector<const Expression*, 8> expr_list;
+ for (uint32_t i = 0; i < count; ++i) {
+ // Offset read position by element_stride for each component
+ uint32_t primitive_offset = offset + element_stride * i;
+ expr_list.Push(LoadPrimitive(array_base, primitive_offset, buffer, base_format));
+ }
+
+ return b.Call(b.ty.vec(base_type, count), std::move(expr_list));
+ }
+
+ /// Process a non-struct entry point parameter.
+ /// Generate function-scope variables for location parameters, and record
+ /// vertex_index and instance_index builtins if present.
+ /// @param func the entry point function
+ /// @param param the parameter to process
+ void ProcessNonStructParameter(const Function* func, const Parameter* param) {
+ if (HasAttribute<LocationAttribute>(param->attributes)) {
+ // Create a function-scope variable to replace the parameter.
+ auto func_var_sym = ctx.Clone(param->name->symbol);
+ auto func_var_type = ctx.Clone(param->type);
+ auto* func_var = b.Var(func_var_sym, func_var_type);
+ ctx.InsertFront(func->body->statements, b.Decl(func_var));
+ // Capture mapping from location to the new variable.
+ LocationInfo info;
+ info.expr = [this, func_var] { return b.Expr(func_var); };
+
+ auto* sem = src->Sem().Get<sem::Parameter>(param);
+ info.type = sem->Type();
+
+ if (TINT_UNLIKELY(!sem->Location().has_value())) {
+ TINT_ICE(Transform, b.Diagnostics()) << "Location missing value";
+ return;
+ }
+ location_info[sem->Location().value()] = info;
+ } else {
+ auto* builtin_attr = GetAttribute<BuiltinAttribute>(param->attributes);
+ if (TINT_UNLIKELY(!builtin_attr)) {
+ TINT_ICE(Transform, b.Diagnostics()) << "Invalid entry point parameter";
+ return;
+ }
+ auto builtin = src->Sem().Get(builtin_attr)->Value();
+ // Check for existing vertex_index and instance_index builtins.
+ if (builtin == builtin::BuiltinValue::kVertexIndex) {
+ vertex_index_expr = [this, param] {
+ return b.Expr(ctx.Clone(param->name->symbol));
+ };
+ } else if (builtin == builtin::BuiltinValue::kInstanceIndex) {
+ instance_index_expr = [this, param] {
+ return b.Expr(ctx.Clone(param->name->symbol));
+ };
+ }
+ new_function_parameters.Push(ctx.Clone(param));
+ }
+ }
+
+ /// Process a struct entry point parameter.
+ /// If the struct has members with location attributes, push the parameter to
+ /// a function-scope variable and create a new struct parameter without those
+ /// attributes. Record expressions for members that are vertex_index and
+ /// instance_index builtins.
+ /// @param func the entry point function
+ /// @param param the parameter to process
+ /// @param struct_ty the structure type
+ void ProcessStructParameter(const Function* func,
+ const Parameter* param,
+ const Struct* struct_ty) {
+ auto param_sym = ctx.Clone(param->name->symbol);
+
+ // Process the struct members.
+ bool has_locations = false;
+ utils::Vector<const StructMember*, 8> members_to_clone;
+ for (auto* member : struct_ty->members) {
+ auto member_sym = ctx.Clone(member->name->symbol);
+ std::function<const Expression*()> member_expr = [this, param_sym, member_sym] {
+ return b.MemberAccessor(param_sym, member_sym);
+ };
+
+ if (HasAttribute<LocationAttribute>(member->attributes)) {
+ // Capture mapping from location to struct member.
+ LocationInfo info;
+ info.expr = member_expr;
+
+ auto* sem = src->Sem().Get(member);
+ info.type = sem->Type();
+
+ TINT_ASSERT(Transform, sem->Attributes().location.has_value());
+ location_info[sem->Attributes().location.value()] = info;
+ has_locations = true;
+ } else {
+ auto* builtin_attr = GetAttribute<BuiltinAttribute>(member->attributes);
+ if (TINT_UNLIKELY(!builtin_attr)) {
+ TINT_ICE(Transform, b.Diagnostics()) << "Invalid entry point parameter";
+ return;
+ }
+ auto builtin = src->Sem().Get(builtin_attr)->Value();
+ // Check for existing vertex_index and instance_index builtins.
+ if (builtin == builtin::BuiltinValue::kVertexIndex) {
+ vertex_index_expr = member_expr;
+ } else if (builtin == builtin::BuiltinValue::kInstanceIndex) {
+ instance_index_expr = member_expr;
+ }
+ members_to_clone.Push(member);
+ }
+ }
+
+ if (!has_locations) {
+ // Nothing to do.
+ new_function_parameters.Push(ctx.Clone(param));
+ return;
+ }
+
+ // Create a function-scope variable to replace the parameter.
+ auto* func_var = b.Var(param_sym, ctx.Clone(param->type));
+ ctx.InsertFront(func->body->statements, b.Decl(func_var));
+
+ if (!members_to_clone.IsEmpty()) {
+ // Create a new struct without the location attributes.
+ utils::Vector<const StructMember*, 8> new_members;
+ for (auto* member : members_to_clone) {
+ auto member_name = ctx.Clone(member->name);
+ auto member_type = ctx.Clone(member->type);
+ auto member_attrs = ctx.Clone(member->attributes);
+ new_members.Push(b.Member(member_name, member_type, std::move(member_attrs)));
+ }
+ auto* new_struct = b.Structure(b.Sym(), new_members);
+
+ // Create a new function parameter with this struct.
+ auto* new_param = b.Param(b.Sym(), b.ty.Of(new_struct));
+ new_function_parameters.Push(new_param);
+
+ // Copy values from the new parameter to the function-scope variable.
+ for (auto* member : members_to_clone) {
+ auto member_name = ctx.Clone(member->name->symbol);
+ ctx.InsertFront(func->body->statements,
+ b.Assign(b.MemberAccessor(func_var, member_name),
+ b.MemberAccessor(new_param, member_name)));
+ }
+ }
+ }
+
+ /// Process an entry point function.
+ /// @param func the entry point function
+ void Process(const Function* func) {
+ if (func->body->Empty()) {
+ return;
+ }
+
+ // Process entry point parameters.
+ for (auto* param : func->params) {
+ auto* sem = src->Sem().Get(param);
+ if (auto* str = sem->Type()->As<sem::Struct>()) {
+ ProcessStructParameter(func, param, str->Declaration());
+ } else {
+ ProcessNonStructParameter(func, param);
+ }
+ }
+
+ // Insert new parameters for vertex_index and instance_index if needed.
+ if (!vertex_index_expr) {
+ for (const VertexBufferLayoutDescriptor& layout : cfg.vertex_state) {
+ if (layout.step_mode == VertexStepMode::kVertex) {
+ auto name = b.Symbols().New("tint_pulling_vertex_index");
+ new_function_parameters.Push(
+ b.Param(name, b.ty.u32(),
+ utils::Vector{b.Builtin(builtin::BuiltinValue::kVertexIndex)}));
+ vertex_index_expr = [this, name] { return b.Expr(name); };
+ break;
+ }
+ }
+ }
+ if (!instance_index_expr) {
+ for (const VertexBufferLayoutDescriptor& layout : cfg.vertex_state) {
+ if (layout.step_mode == VertexStepMode::kInstance) {
+ auto name = b.Symbols().New("tint_pulling_instance_index");
+ new_function_parameters.Push(
+ b.Param(name, b.ty.u32(),
+ utils::Vector{b.Builtin(builtin::BuiltinValue::kInstanceIndex)}));
+ instance_index_expr = [this, name] { return b.Expr(name); };
+ break;
+ }
+ }
+ }
+
+ // Generate vertex pulling preamble.
+ if (auto* block = CreateVertexPullingPreamble()) {
+ ctx.InsertFront(func->body->statements, block);
+ }
+
+ // Rewrite the function header with the new parameters.
+ auto func_sym = ctx.Clone(func->name->symbol);
+ auto ret_type = ctx.Clone(func->return_type);
+ auto* body = ctx.Clone(func->body);
+ auto attrs = ctx.Clone(func->attributes);
+ auto ret_attrs = ctx.Clone(func->return_type_attributes);
+ auto* new_func =
+ b.create<Function>(func->source, b.Ident(func_sym), new_function_parameters, ret_type,
+ body, std::move(attrs), std::move(ret_attrs));
+ ctx.Replace(func, new_func);
+ }
+};
+
+VertexPulling::VertexPulling() = default;
+VertexPulling::~VertexPulling() = default;
+
+Transform::ApplyResult VertexPulling::Apply(const Program* src,
+ const DataMap& inputs,
+ DataMap&) const {
+ auto cfg = cfg_;
+ if (auto* cfg_data = inputs.Get<Config>()) {
+ cfg = *cfg_data;
+ }
+
+ return State{src, cfg}.Run();
+}
+
+VertexPulling::Config::Config() = default;
+VertexPulling::Config::Config(const Config&) = default;
+VertexPulling::Config::~Config() = default;
+VertexPulling::Config& VertexPulling::Config::operator=(const Config&) = default;
+
+VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor() = default;
+
+VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor(
+ uint32_t in_array_stride,
+ VertexStepMode in_step_mode,
+ std::vector<VertexAttributeDescriptor> in_attributes)
+ : array_stride(in_array_stride),
+ step_mode(in_step_mode),
+ attributes(std::move(in_attributes)) {}
+
+VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor(
+ const VertexBufferLayoutDescriptor& other) = default;
+
+VertexBufferLayoutDescriptor& VertexBufferLayoutDescriptor::operator=(
+ const VertexBufferLayoutDescriptor& other) = default;
+
+VertexBufferLayoutDescriptor::~VertexBufferLayoutDescriptor() = default;
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/vertex_pulling.h b/src/tint/lang/wgsl/ast/transform/vertex_pulling.h
new file mode 100644
index 0000000..b10ea62
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/vertex_pulling.h
@@ -0,0 +1,187 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_VERTEX_PULLING_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_VERTEX_PULLING_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+#include "src/tint/reflection.h"
+
+namespace tint::ast::transform {
+
+/// Describes the format of data in a vertex buffer
+enum class VertexFormat {
+ kUint8x2, // uint8x2
+ kUint8x4, // uint8x4
+ kSint8x2, // sint8x2
+ kSint8x4, // sint8x4
+ kUnorm8x2, // unorm8x2
+ kUnorm8x4, // unorm8x4
+ kSnorm8x2, // snorm8x2
+ kSnorm8x4, // snorm8x4
+ kUint16x2, // uint16x2
+ kUint16x4, // uint16x4
+ kSint16x2, // sint16x2
+ kSint16x4, // sint16x4
+ kUnorm16x2, // unorm16x2
+ kUnorm16x4, // unorm16x4
+ kSnorm16x2, // snorm16x2
+ kSnorm16x4, // snorm16x4
+ kFloat16x2, // float16x2
+ kFloat16x4, // float16x4
+ kFloat32, // float32
+ kFloat32x2, // float32x2
+ kFloat32x3, // float32x3
+ kFloat32x4, // float32x4
+ kUint32, // uint32
+ kUint32x2, // uint32x2
+ kUint32x3, // uint32x3
+ kUint32x4, // uint32x4
+ kSint32, // sint32
+ kSint32x2, // sint32x2
+ kSint32x3, // sint32x3
+ kSint32x4, // sint32x4
+
+ kLastEntry = kSint32x4,
+};
+
+/// Describes whether a vertex attribute increments with the vertex index or
+/// with the instance index
+enum class VertexStepMode { kVertex, kInstance, kLastEntry = kInstance };
+
+/// Describes a vertex attribute within a buffer
+struct VertexAttributeDescriptor {
+ /// The format of the attribute data as stored in the buffer
+ VertexFormat format;
+ /// The byte offset of the attribute from the start of each array stride element
+ uint32_t offset;
+ /// The `@location` of the shader input that receives the attribute
+ uint32_t shader_location;
+
+ /// Reflect the fields of this class so that it can be used by tint::ForeachField()
+ TINT_REFLECT(format, offset, shader_location);
+};
+
+/// Describes a buffer containing multiple vertex attributes
+struct VertexBufferLayoutDescriptor {
+ /// Constructor
+ VertexBufferLayoutDescriptor();
+ /// Constructor
+ /// @param in_array_stride the array stride, in bytes, of the buffer
+ /// @param in_step_mode the step mode of the buffer
+ /// @param in_attributes the vertex attributes held in the buffer
+ VertexBufferLayoutDescriptor(uint32_t in_array_stride,
+ VertexStepMode in_step_mode,
+ std::vector<VertexAttributeDescriptor> in_attributes);
+ /// Copy constructor
+ /// @param other the struct to copy
+ VertexBufferLayoutDescriptor(const VertexBufferLayoutDescriptor& other);
+
+ /// Assignment operator
+ /// @param other the struct to copy
+ /// @returns this struct
+ VertexBufferLayoutDescriptor& operator=(const VertexBufferLayoutDescriptor& other);
+ /// Destructor
+ ~VertexBufferLayoutDescriptor();
+
+ /// The array stride, in bytes, used in the buffer
+ uint32_t array_stride = 0u;
+ /// The input step mode used
+ VertexStepMode step_mode = VertexStepMode::kVertex;
+ /// The vertex attributes
+ std::vector<VertexAttributeDescriptor> attributes;
+
+ /// Reflect the fields of this class so that it can be used by tint::ForeachField()
+ TINT_REFLECT(array_stride, step_mode, attributes);
+};
+
+/// Describes vertex state, which consists of many buffers containing vertex
+/// attributes
+using VertexStateDescriptor = std::vector<VertexBufferLayoutDescriptor>;
+
+/// Converts a program to use vertex pulling
+///
+/// Vertex inputs are entry point parameters, or struct members, with a location
+/// attribute. This transform converts those to variables assigned from storage
+/// buffers instead. The intention is to allow vertex input to rely on a storage
+/// buffer clamping pass for out of bounds reads. We bind the storage buffers as
+/// arrays of u32, so any read to byte position `p` will actually need to read
+/// position `p / 4`, since `sizeof(u32) == 4`.
+///
+/// `VertexFormat` represents the input type of the attribute. This isn't
+/// related to the type of the variable in the shader. For example,
+/// `VertexFormat::kFloat16x2` tells us that the buffer will contain `f16`
+/// elements, to be read as vec2. In the shader, a user would make a `vec2<f32>`
+/// to be able to use them. The conversion between `f16` and `f32` will need to
+/// be handled by us (using unpack functions).
+///
+/// To be clear, the buffer data may arrive as types the shader does not declare
+/// (for example 8-bit or 16-bit components packed into `u32` words). We need to
+/// convert these into the types the shader declares, such as `f32` and `u32`,
+/// for the shader to use.
+///
+/// The SingleEntryPoint transform must have run before VertexPulling.
+class VertexPulling final : public utils::Castable<VertexPulling, Transform> {
+ public:
+ /// Configuration options for the transform
+ struct Config final : public utils::Castable<Config, Data> {
+ /// Constructor
+ Config();
+
+ /// Copy constructor
+ Config(const Config&);
+
+ /// Destructor
+ ~Config() override;
+
+ /// Assignment operator
+ /// @returns this Config
+ Config& operator=(const Config&);
+
+ /// The vertex state descriptor, containing info about attributes
+ VertexStateDescriptor vertex_state;
+
+ /// The "group" we will put all our vertex buffers into (as storage buffers)
+ /// Default to 4 as it is past the limits of user-accessible groups
+ uint32_t pulling_group = 4u;
+
+ /// Reflect the fields of this class so that it can be used by tint::ForeachField()
+ TINT_REFLECT(vertex_state, pulling_group);
+ };
+
+ /// Constructor
+ VertexPulling();
+
+ /// Destructor
+ ~VertexPulling() override;
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+
+ private:
+ struct State;
+
+ Config cfg_;
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_VERTEX_PULLING_H_
diff --git a/src/tint/lang/wgsl/ast/transform/vertex_pulling_test.cc b/src/tint/lang/wgsl/ast/transform/vertex_pulling_test.cc
new file mode 100644
index 0000000..16e1a1f
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/vertex_pulling_test.cc
@@ -0,0 +1,2326 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/vertex_pulling.h"
+
+#include <utility>
+
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using VertexPullingTest = TransformTest;
+
+TEST_F(VertexPullingTest, Error_NoEntryPoint) {
+ auto* src = "";
+
+ auto* expect = "error: Vertex stage entry point not found";
+
+ Transform::DataMap data;
+ data.Add<VertexPulling::Config>();
+ auto got = Run<VertexPulling>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VertexPullingTest, Error_MultipleEntryPoint) {
+ auto* src = R"(
+@vertex
+fn main() -> @builtin(position) vec4<f32> {
+ return vec4<f32>();
+}
+@vertex
+fn main2() -> @builtin(position) vec4<f32> {
+ return vec4<f32>();
+}
+)";
+
+ auto* expect = "error: VertexPulling found more than one vertex entry point";
+
+ VertexPulling::Config cfg;
+
+ Transform::DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VertexPullingTest, Error_EntryPointWrongStage) {
+ auto* src = R"(
+@fragment
+fn main() {}
+)";
+
+ auto* expect = "error: Vertex stage entry point not found";
+
+ VertexPulling::Config cfg;
+
+ Transform::DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VertexPullingTest, Error_BadStride) {
+ auto* src = R"(
+@vertex
+fn main(@location(0) var_a : f32) -> @builtin(position) vec4<f32> {
+ return vec4<f32>(var_a, 0.0, 0.0, 1.0);
+}
+)";
+
+ auto* expect =
+ "error: WebGPU requires that vertex stride must be a multiple of 4 "
+ "bytes, but VertexPulling array stride for buffer 0 was 15 bytes";
+
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{{15, VertexStepMode::kVertex, {{VertexFormat::kFloat32, 0, 0}}}}};
+
+ Transform::DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VertexPullingTest, BasicModule) {
+ auto* src = R"(
+@vertex
+fn main() -> @builtin(position) vec4<f32> {
+ return vec4<f32>();
+}
+)";
+
+ auto* expect = R"(
+struct TintVertexData {
+ tint_vertex_data : array<u32>,
+}
+
+@vertex
+fn main() -> @builtin(position) vec4<f32> {
+ return vec4<f32>();
+}
+)";
+
+ VertexPulling::Config cfg;
+
+ Transform::DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VertexPullingTest, OneAttribute) {
+ auto* src = R"(
+@vertex
+fn main(@location(0) var_a : f32) -> @builtin(position) vec4<f32> {
+ return vec4<f32>(var_a, 0.0, 0.0, 1.0);
+}
+)";
+
+ auto* expect = R"(
+struct TintVertexData {
+ tint_vertex_data : array<u32>,
+}
+
+@binding(0) @group(4) var<storage, read> tint_pulling_vertex_buffer_0 : TintVertexData;
+
+@vertex
+fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(position) vec4<f32> {
+ var var_a : f32;
+ {
+ let buffer_array_base_0 = tint_pulling_vertex_index;
+ var_a = bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[buffer_array_base_0]);
+ }
+ return vec4<f32>(var_a, 0.0, 0.0, 1.0);
+}
+)";
+
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{{4, VertexStepMode::kVertex, {{VertexFormat::kFloat32, 0, 0}}}}};
+
+ Transform::DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VertexPullingTest, OneInstancedAttribute) {
+ auto* src = R"(
+@vertex
+fn main(@location(0) var_a : f32) -> @builtin(position) vec4<f32> {
+ return vec4<f32>(var_a, 0.0, 0.0, 1.0);
+}
+)";
+
+ auto* expect = R"(
+struct TintVertexData {
+ tint_vertex_data : array<u32>,
+}
+
+@binding(0) @group(4) var<storage, read> tint_pulling_vertex_buffer_0 : TintVertexData;
+
+@vertex
+fn main(@builtin(instance_index) tint_pulling_instance_index : u32) -> @builtin(position) vec4<f32> {
+ var var_a : f32;
+ {
+ let buffer_array_base_0 = tint_pulling_instance_index;
+ var_a = bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[buffer_array_base_0]);
+ }
+ return vec4<f32>(var_a, 0.0, 0.0, 1.0);
+}
+)";
+
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{{4, VertexStepMode::kInstance, {{VertexFormat::kFloat32, 0, 0}}}}};
+
+ Transform::DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VertexPullingTest, OneAttributeDifferentOutputSet) {
+ auto* src = R"(
+@vertex
+fn main(@location(0) var_a : f32) -> @builtin(position) vec4<f32> {
+ return vec4<f32>(var_a, 0.0, 0.0, 1.0);
+}
+)";
+
+ auto* expect = R"(
+struct TintVertexData {
+ tint_vertex_data : array<u32>,
+}
+
+@binding(0) @group(5) var<storage, read> tint_pulling_vertex_buffer_0 : TintVertexData;
+
+@vertex
+fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(position) vec4<f32> {
+ var var_a : f32;
+ {
+ let buffer_array_base_0 = tint_pulling_vertex_index;
+ var_a = bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[buffer_array_base_0]);
+ }
+ return vec4<f32>(var_a, 0.0, 0.0, 1.0);
+}
+)";
+
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{{4, VertexStepMode::kVertex, {{VertexFormat::kFloat32, 0, 0}}}}};
+ cfg.pulling_group = 5;
+
+ Transform::DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VertexPullingTest, OneAttribute_Struct) {
+ auto* src = R"(
+struct Inputs {
+ @location(0) var_a : f32,
+};
+
+@vertex
+fn main(inputs : Inputs) -> @builtin(position) vec4<f32> {
+ return vec4<f32>(inputs.var_a, 0.0, 0.0, 1.0);
+}
+)";
+
+ auto* expect = R"(
+struct TintVertexData {
+ tint_vertex_data : array<u32>,
+}
+
+@binding(0) @group(4) var<storage, read> tint_pulling_vertex_buffer_0 : TintVertexData;
+
+struct Inputs {
+ @location(0)
+ var_a : f32,
+}
+
+@vertex
+fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(position) vec4<f32> {
+ var inputs : Inputs;
+ {
+ let buffer_array_base_0 = tint_pulling_vertex_index;
+ inputs.var_a = bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[buffer_array_base_0]);
+ }
+ return vec4<f32>(inputs.var_a, 0.0, 0.0, 1.0);
+}
+)";
+
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{{4, VertexStepMode::kVertex, {{VertexFormat::kFloat32, 0, 0}}}}};
+
+ Transform::DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+// We expect the transform to use existing builtin variables if it finds them
+TEST_F(VertexPullingTest, ExistingVertexIndexAndInstanceIndex) {
+ auto* src = R"(
+@vertex
+fn main(@location(0) var_a : f32,
+ @location(1) var_b : f32,
+ @builtin(vertex_index) custom_vertex_index : u32,
+ @builtin(instance_index) custom_instance_index : u32
+ ) -> @builtin(position) vec4<f32> {
+ return vec4<f32>(var_a, var_b, 0.0, 1.0);
+}
+)";
+
+ auto* expect = R"(
+struct TintVertexData {
+ tint_vertex_data : array<u32>,
+}
+
+@binding(0) @group(4) var<storage, read> tint_pulling_vertex_buffer_0 : TintVertexData;
+
+@binding(1) @group(4) var<storage, read> tint_pulling_vertex_buffer_1 : TintVertexData;
+
+@vertex
+fn main(@builtin(vertex_index) custom_vertex_index : u32, @builtin(instance_index) custom_instance_index : u32) -> @builtin(position) vec4<f32> {
+ var var_a : f32;
+ var var_b : f32;
+ {
+ let buffer_array_base_0 = custom_vertex_index;
+ var_a = bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[buffer_array_base_0]);
+ let buffer_array_base_1 = custom_instance_index;
+ var_b = bitcast<f32>(tint_pulling_vertex_buffer_1.tint_vertex_data[buffer_array_base_1]);
+ }
+ return vec4<f32>(var_a, var_b, 0.0, 1.0);
+}
+)";
+
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{
+ {
+ 4,
+ VertexStepMode::kVertex,
+ {{VertexFormat::kFloat32, 0, 0}},
+ },
+ {
+ 4,
+ VertexStepMode::kInstance,
+ {{VertexFormat::kFloat32, 0, 1}},
+ },
+ }};
+
+ Transform::DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VertexPullingTest, ExistingVertexIndexAndInstanceIndex_Struct) {
+ auto* src = R"(
+struct Inputs {
+ @location(0) var_a : f32,
+ @location(1) var_b : f32,
+ @builtin(vertex_index) custom_vertex_index : u32,
+ @builtin(instance_index) custom_instance_index : u32,
+};
+
+@vertex
+fn main(inputs : Inputs) -> @builtin(position) vec4<f32> {
+ return vec4<f32>(inputs.var_a, inputs.var_b, 0.0, 1.0);
+}
+)";
+
+ auto* expect = R"(
+struct TintVertexData {
+ tint_vertex_data : array<u32>,
+}
+
+@binding(0) @group(4) var<storage, read> tint_pulling_vertex_buffer_0 : TintVertexData;
+
+@binding(1) @group(4) var<storage, read> tint_pulling_vertex_buffer_1 : TintVertexData;
+
+struct tint_symbol {
+ @builtin(vertex_index)
+ custom_vertex_index : u32,
+ @builtin(instance_index)
+ custom_instance_index : u32,
+}
+
+struct Inputs {
+ @location(0)
+ var_a : f32,
+ @location(1)
+ var_b : f32,
+ @builtin(vertex_index)
+ custom_vertex_index : u32,
+ @builtin(instance_index)
+ custom_instance_index : u32,
+}
+
+@vertex
+fn main(tint_symbol_1 : tint_symbol) -> @builtin(position) vec4<f32> {
+ var inputs : Inputs;
+ inputs.custom_vertex_index = tint_symbol_1.custom_vertex_index;
+ inputs.custom_instance_index = tint_symbol_1.custom_instance_index;
+ {
+ let buffer_array_base_0 = inputs.custom_vertex_index;
+ inputs.var_a = bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[buffer_array_base_0]);
+ let buffer_array_base_1 = inputs.custom_instance_index;
+ inputs.var_b = bitcast<f32>(tint_pulling_vertex_buffer_1.tint_vertex_data[buffer_array_base_1]);
+ }
+ return vec4<f32>(inputs.var_a, inputs.var_b, 0.0, 1.0);
+}
+)";
+
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{
+ {
+ 4,
+ VertexStepMode::kVertex,
+ {{VertexFormat::kFloat32, 0, 0}},
+ },
+ {
+ 4,
+ VertexStepMode::kInstance,
+ {{VertexFormat::kFloat32, 0, 1}},
+ },
+ }};
+
+ Transform::DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VertexPullingTest, ExistingVertexIndexAndInstanceIndex_Struct_OutOfOrder) {
+ auto* src = R"(
+@vertex
+fn main(inputs : Inputs) -> @builtin(position) vec4<f32> {
+ return vec4<f32>(inputs.var_a, inputs.var_b, 0.0, 1.0);
+}
+
+struct Inputs {
+ @location(0) var_a : f32,
+ @location(1) var_b : f32,
+ @builtin(vertex_index) custom_vertex_index : u32,
+ @builtin(instance_index) custom_instance_index : u32,
+};
+)";
+
+ auto* expect = R"(
+struct TintVertexData {
+ tint_vertex_data : array<u32>,
+}
+
+@binding(0) @group(4) var<storage, read> tint_pulling_vertex_buffer_0 : TintVertexData;
+
+@binding(1) @group(4) var<storage, read> tint_pulling_vertex_buffer_1 : TintVertexData;
+
+struct tint_symbol {
+ @builtin(vertex_index)
+ custom_vertex_index : u32,
+ @builtin(instance_index)
+ custom_instance_index : u32,
+}
+
+@vertex
+fn main(tint_symbol_1 : tint_symbol) -> @builtin(position) vec4<f32> {
+ var inputs : Inputs;
+ inputs.custom_vertex_index = tint_symbol_1.custom_vertex_index;
+ inputs.custom_instance_index = tint_symbol_1.custom_instance_index;
+ {
+ let buffer_array_base_0 = inputs.custom_vertex_index;
+ inputs.var_a = bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[buffer_array_base_0]);
+ let buffer_array_base_1 = inputs.custom_instance_index;
+ inputs.var_b = bitcast<f32>(tint_pulling_vertex_buffer_1.tint_vertex_data[buffer_array_base_1]);
+ }
+ return vec4<f32>(inputs.var_a, inputs.var_b, 0.0, 1.0);
+}
+
+struct Inputs {
+ @location(0)
+ var_a : f32,
+ @location(1)
+ var_b : f32,
+ @builtin(vertex_index)
+ custom_vertex_index : u32,
+ @builtin(instance_index)
+ custom_instance_index : u32,
+}
+)";
+
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{
+ {
+ 4,
+ VertexStepMode::kVertex,
+ {{VertexFormat::kFloat32, 0, 0}},
+ },
+ {
+ 4,
+ VertexStepMode::kInstance,
+ {{VertexFormat::kFloat32, 0, 1}},
+ },
+ }};
+
+ Transform::DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VertexPullingTest, ExistingVertexIndexAndInstanceIndex_SeparateStruct) {
+ auto* src = R"(
+struct Inputs {
+ @location(0) var_a : f32,
+ @location(1) var_b : f32,
+};
+
+struct Indices {
+ @builtin(vertex_index) custom_vertex_index : u32,
+ @builtin(instance_index) custom_instance_index : u32,
+};
+
+@vertex
+fn main(inputs : Inputs, indices : Indices) -> @builtin(position) vec4<f32> {
+ return vec4<f32>(inputs.var_a, inputs.var_b, 0.0, 1.0);
+}
+)";
+
+ auto* expect = R"(
+struct TintVertexData {
+ tint_vertex_data : array<u32>,
+}
+
+@binding(0) @group(4) var<storage, read> tint_pulling_vertex_buffer_0 : TintVertexData;
+
+@binding(1) @group(4) var<storage, read> tint_pulling_vertex_buffer_1 : TintVertexData;
+
+struct Inputs {
+ @location(0)
+ var_a : f32,
+ @location(1)
+ var_b : f32,
+}
+
+struct Indices {
+ @builtin(vertex_index)
+ custom_vertex_index : u32,
+ @builtin(instance_index)
+ custom_instance_index : u32,
+}
+
+@vertex
+fn main(indices : Indices) -> @builtin(position) vec4<f32> {
+ var inputs : Inputs;
+ {
+ let buffer_array_base_0 = indices.custom_vertex_index;
+ inputs.var_a = bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[buffer_array_base_0]);
+ let buffer_array_base_1 = indices.custom_instance_index;
+ inputs.var_b = bitcast<f32>(tint_pulling_vertex_buffer_1.tint_vertex_data[buffer_array_base_1]);
+ }
+ return vec4<f32>(inputs.var_a, inputs.var_b, 0.0, 1.0);
+}
+)";
+
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{
+ {
+ 4,
+ VertexStepMode::kVertex,
+ {{VertexFormat::kFloat32, 0, 0}},
+ },
+ {
+ 4,
+ VertexStepMode::kInstance,
+ {{VertexFormat::kFloat32, 0, 1}},
+ },
+ }};
+
+ Transform::DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VertexPullingTest, ExistingVertexIndexAndInstanceIndex_SeparateStruct_OutOfOrder) {
+ auto* src = R"(
+@vertex
+fn main(inputs : Inputs, indices : Indices) -> @builtin(position) vec4<f32> {
+ return vec4<f32>(inputs.var_a, inputs.var_b, 0.0, 1.0);
+}
+
+struct Inputs {
+ @location(0) var_a : f32,
+ @location(1) var_b : f32,
+};
+
+struct Indices {
+ @builtin(vertex_index) custom_vertex_index : u32,
+ @builtin(instance_index) custom_instance_index : u32,
+};
+)";
+
+ auto* expect = R"(
+struct TintVertexData {
+ tint_vertex_data : array<u32>,
+}
+
+@binding(0) @group(4) var<storage, read> tint_pulling_vertex_buffer_0 : TintVertexData;
+
+@binding(1) @group(4) var<storage, read> tint_pulling_vertex_buffer_1 : TintVertexData;
+
+@vertex
+fn main(indices : Indices) -> @builtin(position) vec4<f32> {
+ var inputs : Inputs;
+ {
+ let buffer_array_base_0 = indices.custom_vertex_index;
+ inputs.var_a = bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[buffer_array_base_0]);
+ let buffer_array_base_1 = indices.custom_instance_index;
+ inputs.var_b = bitcast<f32>(tint_pulling_vertex_buffer_1.tint_vertex_data[buffer_array_base_1]);
+ }
+ return vec4<f32>(inputs.var_a, inputs.var_b, 0.0, 1.0);
+}
+
+struct Inputs {
+ @location(0)
+ var_a : f32,
+ @location(1)
+ var_b : f32,
+}
+
+struct Indices {
+ @builtin(vertex_index)
+ custom_vertex_index : u32,
+ @builtin(instance_index)
+ custom_instance_index : u32,
+}
+)";
+
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{
+ {
+ 4,
+ VertexStepMode::kVertex,
+ {{VertexFormat::kFloat32, 0, 0}},
+ },
+ {
+ 4,
+ VertexStepMode::kInstance,
+ {{VertexFormat::kFloat32, 0, 1}},
+ },
+ }};
+
+ Transform::DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VertexPullingTest, TwoAttributesSameBuffer) {
+ auto* src = R"(
+@vertex
+fn main(@location(0) var_a : f32,
+ @location(1) var_b : vec4<f32>) -> @builtin(position) vec4<f32> {
+ return vec4<f32>();
+}
+)";
+
+ auto* expect = R"(
+struct TintVertexData {
+ tint_vertex_data : array<u32>,
+}
+
+@binding(0) @group(4) var<storage, read> tint_pulling_vertex_buffer_0 : TintVertexData;
+
+@vertex
+fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(position) vec4<f32> {
+ var var_a : f32;
+ var var_b : vec4<f32>;
+ {
+ let buffer_array_base_0 = (tint_pulling_vertex_index * 4u);
+ var_a = bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[buffer_array_base_0]);
+ var_b = vec4<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[buffer_array_base_0]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 1u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 2u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 3u)]));
+ }
+ return vec4<f32>();
+}
+)";
+
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{{16,
+ VertexStepMode::kVertex,
+ {{VertexFormat::kFloat32, 0, 0}, {VertexFormat::kFloat32x4, 0, 1}}}}};
+
+ Transform::DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VertexPullingTest, FloatVectorAttributes_F32) {
+ auto* src = R"(
+@vertex
+fn main(@location(0) var_a : vec2<f32>,
+ @location(1) var_b : vec3<f32>,
+ @location(2) var_c : vec4<f32>
+ ) -> @builtin(position) vec4<f32> {
+ return vec4<f32>();
+}
+)";
+
+ auto* expect = R"(
+struct TintVertexData {
+ tint_vertex_data : array<u32>,
+}
+
+@binding(0) @group(4) var<storage, read> tint_pulling_vertex_buffer_0 : TintVertexData;
+
+@binding(1) @group(4) var<storage, read> tint_pulling_vertex_buffer_1 : TintVertexData;
+
+@binding(2) @group(4) var<storage, read> tint_pulling_vertex_buffer_2 : TintVertexData;
+
+@vertex
+fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(position) vec4<f32> {
+ var var_a : vec2<f32>;
+ var var_b : vec3<f32>;
+ var var_c : vec4<f32>;
+ {
+ let buffer_array_base_0 = (tint_pulling_vertex_index * 2u);
+ var_a = vec2<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[buffer_array_base_0]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 1u)]));
+ let buffer_array_base_1 = (tint_pulling_vertex_index * 3u);
+ var_b = vec3<f32>(bitcast<f32>(tint_pulling_vertex_buffer_1.tint_vertex_data[buffer_array_base_1]), bitcast<f32>(tint_pulling_vertex_buffer_1.tint_vertex_data[(buffer_array_base_1 + 1u)]), bitcast<f32>(tint_pulling_vertex_buffer_1.tint_vertex_data[(buffer_array_base_1 + 2u)]));
+ let buffer_array_base_2 = (tint_pulling_vertex_index * 4u);
+ var_c = vec4<f32>(bitcast<f32>(tint_pulling_vertex_buffer_2.tint_vertex_data[buffer_array_base_2]), bitcast<f32>(tint_pulling_vertex_buffer_2.tint_vertex_data[(buffer_array_base_2 + 1u)]), bitcast<f32>(tint_pulling_vertex_buffer_2.tint_vertex_data[(buffer_array_base_2 + 2u)]), bitcast<f32>(tint_pulling_vertex_buffer_2.tint_vertex_data[(buffer_array_base_2 + 3u)]));
+ }
+ return vec4<f32>();
+}
+)";
+
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{
+ {8, VertexStepMode::kVertex, {{VertexFormat::kFloat32x2, 0, 0}}},
+ {12, VertexStepMode::kVertex, {{VertexFormat::kFloat32x3, 0, 1}}},
+ {16, VertexStepMode::kVertex, {{VertexFormat::kFloat32x4, 0, 2}}},
+ }};
+
+ Transform::DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VertexPullingTest, FloatVectorAttributes_F16) {
+ auto* src = R"(
+enable f16;
+
+@vertex
+fn main(@location(0) var_a : vec2<f16>,
+ @location(1) var_b : vec3<f16>,
+ @location(2) var_c : vec4<f16>
+ ) -> @builtin(position) vec4<f32> {
+ return vec4<f32>();
+}
+)";
+
+ auto* expect = R"(
+enable f16;
+
+struct TintVertexData {
+ tint_vertex_data : array<u32>,
+}
+
+@binding(0) @group(4) var<storage, read> tint_pulling_vertex_buffer_0 : TintVertexData;
+
+@binding(1) @group(4) var<storage, read> tint_pulling_vertex_buffer_1 : TintVertexData;
+
+@binding(2) @group(4) var<storage, read> tint_pulling_vertex_buffer_2 : TintVertexData;
+
+@vertex
+fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(position) vec4<f32> {
+ var var_a : vec2<f16>;
+ var var_b : vec3<f16>;
+ var var_c : vec4<f16>;
+ {
+ let buffer_array_base_0 = (tint_pulling_vertex_index * 2u);
+ var_a = vec2<f16>(vec2<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[buffer_array_base_0]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 1u)])));
+ let buffer_array_base_1 = (tint_pulling_vertex_index * 3u);
+ var_b = vec3<f16>(vec3<f32>(bitcast<f32>(tint_pulling_vertex_buffer_1.tint_vertex_data[buffer_array_base_1]), bitcast<f32>(tint_pulling_vertex_buffer_1.tint_vertex_data[(buffer_array_base_1 + 1u)]), bitcast<f32>(tint_pulling_vertex_buffer_1.tint_vertex_data[(buffer_array_base_1 + 2u)])));
+ let buffer_array_base_2 = (tint_pulling_vertex_index * 4u);
+ var_c = vec4<f16>(vec4<f32>(unpack2x16float(tint_pulling_vertex_buffer_2.tint_vertex_data[buffer_array_base_2]), unpack2x16float(tint_pulling_vertex_buffer_2.tint_vertex_data[(buffer_array_base_2 + 1u)])));
+ }
+ return vec4<f32>();
+}
+)";
+
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{
+ {8, VertexStepMode::kVertex, {{VertexFormat::kFloat32x2, 0, 0}}},
+ {12, VertexStepMode::kVertex, {{VertexFormat::kFloat32x3, 0, 1}}},
+ {16, VertexStepMode::kVertex, {{VertexFormat::kFloat16x4, 0, 2}}},
+ }};
+
+ Transform::DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VertexPullingTest, AttemptSymbolCollision) {
+ auto* src = R"(
+@vertex
+fn main(@location(0) var_a : f32,
+ @location(1) var_b : vec4<f32>) -> @builtin(position) vec4<f32> {
+ var tint_pulling_vertex_index : i32;
+ var tint_pulling_vertex_buffer_0 : i32;
+ var tint_vertex_data : i32;
+ var tint_pulling_pos : i32;
+ return vec4<f32>();
+}
+)";
+
+ auto* expect = R"(
+struct TintVertexData {
+ tint_vertex_data_1 : array<u32>,
+}
+
+@binding(0) @group(4) var<storage, read> tint_pulling_vertex_buffer_0_1 : TintVertexData;
+
+@vertex
+fn main(@builtin(vertex_index) tint_pulling_vertex_index_1 : u32) -> @builtin(position) vec4<f32> {
+ var var_a : f32;
+ var var_b : vec4<f32>;
+ {
+ let buffer_array_base_0 = (tint_pulling_vertex_index_1 * 4u);
+ var_a = bitcast<f32>(tint_pulling_vertex_buffer_0_1.tint_vertex_data_1[buffer_array_base_0]);
+ var_b = vec4<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0_1.tint_vertex_data_1[buffer_array_base_0]), bitcast<f32>(tint_pulling_vertex_buffer_0_1.tint_vertex_data_1[(buffer_array_base_0 + 1u)]), bitcast<f32>(tint_pulling_vertex_buffer_0_1.tint_vertex_data_1[(buffer_array_base_0 + 2u)]), bitcast<f32>(tint_pulling_vertex_buffer_0_1.tint_vertex_data_1[(buffer_array_base_0 + 3u)]));
+ }
+ var tint_pulling_vertex_index : i32;
+ var tint_pulling_vertex_buffer_0 : i32;
+ var tint_vertex_data : i32;
+ var tint_pulling_pos : i32;
+ return vec4<f32>();
+}
+)";
+
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{{16,
+ VertexStepMode::kVertex,
+ {{VertexFormat::kFloat32, 0, 0}, {VertexFormat::kFloat32x4, 0, 1}}}}};
+
+ Transform::DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, std::move(data));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VertexPullingTest, FormatsAligned_SInt) {  // Signed-integer vertex formats; every attribute is configured with 64.
+ auto* src = R"(
+@vertex
+fn main(
+ @location(0) sint8x2 : vec2<i32>,
+ @location(1) sint8x4 : vec4<i32>,
+ @location(2) sint16x2 : vec2<i32>,
+ @location(3) sint16x4 : vec4<i32>,
+ @location(4) sint32 : i32,
+ @location(5) sint32x2 : vec2<i32>,
+ @location(6) sint32x3 : vec3<i32>,
+ @location(7) sint32x4 : vec4<i32>
+ ) -> @builtin(position) vec4<f32> {
+ return vec4<f32>(0.0, 0.0, 0.0, 1.0);
+}
+)";
+ // Golden output: each @location parameter becomes a function-scope var assigned from tint_pulling_vertex_buffer_0.
+ auto* expect = R"(
+struct TintVertexData {
+ tint_vertex_data : array<u32>,
+}
+
+@binding(0) @group(4) var<storage, read> tint_pulling_vertex_buffer_0 : TintVertexData;
+
+@vertex
+fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(position) vec4<f32> {
+ var sint8x2 : vec2<i32>;
+ var sint8x4 : vec4<i32>;
+ var sint16x2 : vec2<i32>;
+ var sint16x4 : vec4<i32>;
+ var sint32 : i32;
+ var sint32x2 : vec2<i32>;
+ var sint32x3 : vec3<i32>;
+ var sint32x4 : vec4<i32>;
+ {
+ let buffer_array_base_0 = (tint_pulling_vertex_index * 64u);
+ sint8x2 = ((vec2<i32>(bitcast<i32>((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 16u))) << vec2<u32>(8u, 0u)) >> vec2<u32>(24u));
+ sint8x4 = ((vec4<i32>(bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)])) << vec4<u32>(24u, 16u, 8u, 0u)) >> vec4<u32>(24u));
+ sint16x2 = ((vec2<i32>(bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)])) << vec2<u32>(16u, 0u)) >> vec2<u32>(16u));
+ sint16x4 = ((vec2<i32>(bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)])).xxyy << vec4<u32>(16u, 0u, 16u, 0u)) >> vec4<u32>(16u));
+ sint32 = bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]);
+ sint32x2 = vec2<i32>(bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]));
+ sint32x3 = vec3<i32>(bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]), bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)]));
+ sint32x4 = vec4<i32>(bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]), bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)]), bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 19u)]));
+ }
+ return vec4<f32>(0.0, 0.0, 0.0, 1.0);
+}
+)";
+ // Single vertex buffer entry. The leading 256 is presumably the byte stride (expected base is vertex_index * 64u) - TODO confirm.
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{{256,
+ VertexStepMode::kVertex,
+ {
+ {VertexFormat::kSint8x2, 64, 0},  // {format, 64 (presumably byte offset; expected word index is 16u), @location index}
+ {VertexFormat::kSint8x4, 64, 1},
+ {VertexFormat::kSint16x2, 64, 2},
+ {VertexFormat::kSint16x4, 64, 3},
+ {VertexFormat::kSint32, 64, 4},
+ {VertexFormat::kSint32x2, 64, 5},
+ {VertexFormat::kSint32x3, 64, 6},
+ {VertexFormat::kSint32x4, 64, 7},
+ }}}};
+ // Hand the config to the transform through a DataMap and run VertexPulling on the source.
+ Transform::DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
+ // str(got) must equal the golden text above.
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VertexPullingTest, FormatsAligned_UInt) {  // Unsigned-integer vertex formats; every attribute is configured with 64.
+ auto* src = R"(
+@vertex
+fn main(
+ @location(0) uint8x2 : vec2<u32>,
+ @location(1) uint8x4 : vec4<u32>,
+ @location(2) uint16x2 : vec2<u32>,
+ @location(3) uint16x4 : vec4<u32>,
+ @location(4) uint32 : u32,
+ @location(5) uint32x2 : vec2<u32>,
+ @location(6) uint32x3 : vec3<u32>,
+ @location(7) uint32x4 : vec4<u32>,
+ ) -> @builtin(position) vec4<f32> {
+ return vec4<f32>(0.0, 0.0, 0.0, 1.0);
+}
+)";
+ // Golden output: the loaded u32 words are used directly (no bitcast appears in the expected text).
+ auto* expect = R"(
+struct TintVertexData {
+ tint_vertex_data : array<u32>,
+}
+
+@binding(0) @group(4) var<storage, read> tint_pulling_vertex_buffer_0 : TintVertexData;
+
+@vertex
+fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(position) vec4<f32> {
+ var uint8x2 : vec2<u32>;
+ var uint8x4 : vec4<u32>;
+ var uint16x2 : vec2<u32>;
+ var uint16x4 : vec4<u32>;
+ var uint32 : u32;
+ var uint32x2 : vec2<u32>;
+ var uint32x3 : vec3<u32>;
+ var uint32x4 : vec4<u32>;
+ {
+ let buffer_array_base_0 = (tint_pulling_vertex_index * 64u);
+ uint8x2 = ((vec2<u32>((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 16u)) << vec2<u32>(8u, 0u)) >> vec2<u32>(24u));
+ uint8x4 = ((vec4<u32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]) << vec4<u32>(24u, 16u, 8u, 0u)) >> vec4<u32>(24u));
+ uint16x2 = ((vec2<u32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]) << vec2<u32>(16u, 0u)) >> vec2<u32>(16u));
+ uint16x4 = ((vec2<u32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)], tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]).xxyy << vec4<u32>(16u, 0u, 16u, 0u)) >> vec4<u32>(16u));
+ uint32 = tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)];
+ uint32x2 = vec2<u32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)], tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]);
+ uint32x3 = vec3<u32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)], tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)], tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)]);
+ uint32x4 = vec4<u32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)], tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)], tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)], tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 19u)]);
+ }
+ return vec4<f32>(0.0, 0.0, 0.0, 1.0);
+}
+)";
+ // Single vertex buffer entry. The leading 256 is presumably the byte stride (expected base is vertex_index * 64u) - TODO confirm.
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{{256,
+ VertexStepMode::kVertex,
+ {
+ {VertexFormat::kUint8x2, 64, 0},  // {format, 64 (presumably byte offset; expected word index is 16u), @location index}
+ {VertexFormat::kUint8x4, 64, 1},
+ {VertexFormat::kUint16x2, 64, 2},
+ {VertexFormat::kUint16x4, 64, 3},
+ {VertexFormat::kUint32, 64, 4},
+ {VertexFormat::kUint32x2, 64, 5},
+ {VertexFormat::kUint32x3, 64, 6},
+ {VertexFormat::kUint32x4, 64, 7},
+ }}}};
+ // Hand the config to the transform through a DataMap and run VertexPulling on the source.
+ Transform::DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
+ // str(got) must equal the golden text above.
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VertexPullingTest, FormatsAligned_Float_F32) {  // Normalized and float vertex formats read into f32 shader inputs; every attribute is configured with 64.
+ auto* src = R"(
+@vertex
+fn main(
+ @location(0) unorm8x2 : vec2<f32>,
+ @location(1) unorm8x4 : vec4<f32>,
+ @location(2) snorm8x2 : vec2<f32>,
+ @location(3) snorm8x4 : vec4<f32>,
+ @location(4) unorm16x2 : vec2<f32>,
+ @location(5) unorm16x4 : vec4<f32>,
+ @location(6) snorm16x2 : vec2<f32>,
+ @location(7) snorm16x4 : vec4<f32>,
+ @location(8) float16x2 : vec2<f32>,
+ @location(9) float16x4 : vec4<f32>,
+ @location(10) float32 : f32,
+ @location(11) float32x2 : vec2<f32>,
+ @location(12) float32x3 : vec3<f32>,
+ @location(13) float32x4 : vec4<f32>,
+ ) -> @builtin(position) vec4<f32> {
+ return vec4<f32>(0.0, 0.0, 0.0, 1.0);
+}
+)";
+ // Golden output: 8/16-bit formats go through the unpack4x8*/unpack2x16* builtins, 32-bit floats through bitcast<f32>.
+ auto* expect = R"(
+struct TintVertexData {
+ tint_vertex_data : array<u32>,
+}
+
+@binding(0) @group(4) var<storage, read> tint_pulling_vertex_buffer_0 : TintVertexData;
+
+@vertex
+fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(position) vec4<f32> {
+ var unorm8x2 : vec2<f32>;
+ var unorm8x4 : vec4<f32>;
+ var snorm8x2 : vec2<f32>;
+ var snorm8x4 : vec4<f32>;
+ var unorm16x2 : vec2<f32>;
+ var unorm16x4 : vec4<f32>;
+ var snorm16x2 : vec2<f32>;
+ var snorm16x4 : vec4<f32>;
+ var float16x2 : vec2<f32>;
+ var float16x4 : vec4<f32>;
+ var float32 : f32;
+ var float32x2 : vec2<f32>;
+ var float32x3 : vec3<f32>;
+ var float32x4 : vec4<f32>;
+ {
+ let buffer_array_base_0 = (tint_pulling_vertex_index * 64u);
+ unorm8x2 = unpack4x8unorm((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] & 65535u)).xy;
+ unorm8x4 = unpack4x8unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]);
+ snorm8x2 = unpack4x8snorm((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] & 65535u)).xy;
+ snorm8x4 = unpack4x8snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]);
+ unorm16x2 = unpack2x16unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]);
+ unorm16x4 = vec4<f32>(unpack2x16unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), unpack2x16unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]));
+ snorm16x2 = unpack2x16snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]);
+ snorm16x4 = vec4<f32>(unpack2x16snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), unpack2x16snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]));
+ float16x2 = unpack2x16float(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]);
+ float16x4 = vec4<f32>(unpack2x16float(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), unpack2x16float(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]));
+ float32 = bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]);
+ float32x2 = vec2<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]));
+ float32x3 = vec3<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)]));
+ float32x4 = vec4<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 19u)]));
+ }
+ return vec4<f32>(0.0, 0.0, 0.0, 1.0);
+}
+)";
+ // Single vertex buffer entry. The leading 256 is presumably the byte stride (expected base is vertex_index * 64u) - TODO confirm.
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{{256,
+ VertexStepMode::kVertex,
+ {
+ {VertexFormat::kUnorm8x2, 64, 0},  // {format, 64 (presumably byte offset; expected word index is 16u), @location index}
+ {VertexFormat::kUnorm8x4, 64, 1},
+ {VertexFormat::kSnorm8x2, 64, 2},
+ {VertexFormat::kSnorm8x4, 64, 3},
+ {VertexFormat::kUnorm16x2, 64, 4},
+ {VertexFormat::kUnorm16x4, 64, 5},
+ {VertexFormat::kSnorm16x2, 64, 6},
+ {VertexFormat::kSnorm16x4, 64, 7},
+ {VertexFormat::kFloat16x2, 64, 8},
+ {VertexFormat::kFloat16x4, 64, 9},
+ {VertexFormat::kFloat32, 64, 10},
+ {VertexFormat::kFloat32x2, 64, 11},
+ {VertexFormat::kFloat32x3, 64, 12},
+ {VertexFormat::kFloat32x4, 64, 13},
+ }}}};
+ // Hand the config to the transform through a DataMap and run VertexPulling on the source.
+ Transform::DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
+ // str(got) must equal the golden text above.
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VertexPullingTest, FormatsAligned_Float_F16) {  // Same formats as the F32 variant, but the shader inputs are f16 typed; every attribute is configured with 64.
+ auto* src = R"(
+enable f16;
+
+@vertex
+fn main(
+ @location(0) unorm8x2 : vec2<f16>,
+ @location(1) unorm8x4 : vec4<f16>,
+ @location(2) snorm8x2 : vec2<f16>,
+ @location(3) snorm8x4 : vec4<f16>,
+ @location(4) unorm16x2 : vec2<f16>,
+ @location(5) unorm16x4 : vec4<f16>,
+ @location(6) snorm16x2 : vec2<f16>,
+ @location(7) snorm16x4 : vec4<f16>,
+ @location(8) float16x2 : vec2<f16>,
+ @location(9) float16x4 : vec4<f16>,
+ @location(10) float32 : f16,
+ @location(11) float32x2 : vec2<f16>,
+ @location(12) float32x3 : vec3<f16>,
+ @location(13) float32x4 : vec4<f16>,
+ ) -> @builtin(position) vec4<f32> {
+ return vec4<f32>(0.0, 0.0, 0.0, 1.0);
+}
+)";
+ // Golden output: each value is loaded as in the f32 case and then wrapped in an f16 / vecN<f16> conversion.
+ auto* expect = R"(
+enable f16;
+
+struct TintVertexData {
+ tint_vertex_data : array<u32>,
+}
+
+@binding(0) @group(4) var<storage, read> tint_pulling_vertex_buffer_0 : TintVertexData;
+
+@vertex
+fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(position) vec4<f32> {
+ var unorm8x2 : vec2<f16>;
+ var unorm8x4 : vec4<f16>;
+ var snorm8x2 : vec2<f16>;
+ var snorm8x4 : vec4<f16>;
+ var unorm16x2 : vec2<f16>;
+ var unorm16x4 : vec4<f16>;
+ var snorm16x2 : vec2<f16>;
+ var snorm16x4 : vec4<f16>;
+ var float16x2 : vec2<f16>;
+ var float16x4 : vec4<f16>;
+ var float32 : f16;
+ var float32x2 : vec2<f16>;
+ var float32x3 : vec3<f16>;
+ var float32x4 : vec4<f16>;
+ {
+ let buffer_array_base_0 = (tint_pulling_vertex_index * 64u);
+ unorm8x2 = vec2<f16>(unpack4x8unorm((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] & 65535u)).xy);
+ unorm8x4 = vec4<f16>(unpack4x8unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]));
+ snorm8x2 = vec2<f16>(unpack4x8snorm((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] & 65535u)).xy);
+ snorm8x4 = vec4<f16>(unpack4x8snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]));
+ unorm16x2 = vec2<f16>(unpack2x16unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]));
+ unorm16x4 = vec4<f16>(vec4<f32>(unpack2x16unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), unpack2x16unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)])));
+ snorm16x2 = vec2<f16>(unpack2x16snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]));
+ snorm16x4 = vec4<f16>(vec4<f32>(unpack2x16snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), unpack2x16snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)])));
+ float16x2 = vec2<f16>(unpack2x16float(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]));
+ float16x4 = vec4<f16>(vec4<f32>(unpack2x16float(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), unpack2x16float(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)])));
+ float32 = f16(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]));
+ float32x2 = vec2<f16>(vec2<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)])));
+ float32x3 = vec3<f16>(vec3<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)])));
+ float32x4 = vec4<f16>(vec4<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 19u)])));
+ }
+ return vec4<f32>(0.0, 0.0, 0.0, 1.0);
+}
+)";
+ // Single vertex buffer entry. The leading 256 is presumably the byte stride (expected base is vertex_index * 64u) - TODO confirm.
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{{256,
+ VertexStepMode::kVertex,
+ {
+ {VertexFormat::kUnorm8x2, 64, 0},  // {format, 64 (presumably byte offset; expected word index is 16u), @location index}
+ {VertexFormat::kUnorm8x4, 64, 1},
+ {VertexFormat::kSnorm8x2, 64, 2},
+ {VertexFormat::kSnorm8x4, 64, 3},
+ {VertexFormat::kUnorm16x2, 64, 4},
+ {VertexFormat::kUnorm16x4, 64, 5},
+ {VertexFormat::kSnorm16x2, 64, 6},
+ {VertexFormat::kSnorm16x4, 64, 7},
+ {VertexFormat::kFloat16x2, 64, 8},
+ {VertexFormat::kFloat16x4, 64, 9},
+ {VertexFormat::kFloat32, 64, 10},
+ {VertexFormat::kFloat32x2, 64, 11},
+ {VertexFormat::kFloat32x3, 64, 12},
+ {VertexFormat::kFloat32x4, 64, 13},
+ }}}};
+ // Hand the config to the transform through a DataMap and run VertexPulling on the source.
+ Transform::DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
+ // str(got) must equal the golden text above.
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VertexPullingTest, FormatsUnaligned_SInt) {  // Signed-integer vertex formats; every attribute is configured with 63 (not a multiple of 4).
+ auto* src = R"(
+@vertex
+fn main(
+ @location(0) sint8x2 : vec2<i32>,
+ @location(1) sint8x4 : vec4<i32>,
+ @location(2) sint16x2 : vec2<i32>,
+ @location(3) sint16x4 : vec4<i32>,
+ @location(4) sint32 : i32,
+ @location(5) sint32x2 : vec2<i32>,
+ @location(6) sint32x3 : vec3<i32>,
+ @location(7) sint32x4 : vec4<i32>
+ ) -> @builtin(position) vec4<f32> {
+ return vec4<f32>(0.0, 0.0, 0.0, 1.0);
+}
+)";
+ // Golden output: each value is stitched from two adjacent u32 words (15u and 16u onwards) with shift/or before the bitcast.
+ auto* expect = R"(
+struct TintVertexData {
+ tint_vertex_data : array<u32>,
+}
+
+@binding(0) @group(4) var<storage, read> tint_pulling_vertex_buffer_0 : TintVertexData;
+
+@vertex
+fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(position) vec4<f32> {
+ var sint8x2 : vec2<i32>;
+ var sint8x4 : vec4<i32>;
+ var sint16x2 : vec2<i32>;
+ var sint16x4 : vec4<i32>;
+ var sint32 : i32;
+ var sint32x2 : vec2<i32>;
+ var sint32x3 : vec3<i32>;
+ var sint32x4 : vec4<i32>;
+ {
+ let buffer_array_base_0 = (tint_pulling_vertex_index * 64u);
+ sint8x2 = ((vec2<i32>(bitcast<i32>((((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 8u)) & 4294901760u))) << vec2<u32>(8u, 0u)) >> vec2<u32>(24u));
+ sint8x4 = ((vec4<i32>(bitcast<i32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u)))) << vec4<u32>(24u, 16u, 8u, 0u)) >> vec4<u32>(24u));
+ sint16x2 = ((vec2<i32>(bitcast<i32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u)))) << vec2<u32>(16u, 0u)) >> vec2<u32>(16u));
+ sint16x4 = ((vec2<i32>(bitcast<i32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u))), bitcast<i32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)] << 8u)))).xxyy << vec4<u32>(16u, 0u, 16u, 0u)) >> vec4<u32>(16u));
+ sint32 = bitcast<i32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u)));
+ sint32x2 = vec2<i32>(bitcast<i32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u))), bitcast<i32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)] << 8u))));
+ sint32x3 = vec3<i32>(bitcast<i32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u))), bitcast<i32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)] << 8u))), bitcast<i32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)] << 8u))));
+ sint32x4 = vec4<i32>(bitcast<i32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u))), bitcast<i32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)] << 8u))), bitcast<i32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)] << 8u))), bitcast<i32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 19u)] << 8u))));
+ }
+ return vec4<f32>(0.0, 0.0, 0.0, 1.0);
+}
+)";
+ // Single vertex buffer entry. The leading 256 is presumably the byte stride (expected base is vertex_index * 64u) - TODO confirm.
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{{256,
+ VertexStepMode::kVertex,
+ {
+ {VertexFormat::kSint8x2, 63, 0},  // {format, 63 (presumably byte offset; expected loads straddle words 15u/16u), @location index}
+ {VertexFormat::kSint8x4, 63, 1},
+ {VertexFormat::kSint16x2, 63, 2},
+ {VertexFormat::kSint16x4, 63, 3},
+ {VertexFormat::kSint32, 63, 4},
+ {VertexFormat::kSint32x2, 63, 5},  // uses 63 like its siblings so the unaligned vec2<i32> load is covered
+ {VertexFormat::kSint32x3, 63, 6},
+ {VertexFormat::kSint32x4, 63, 7},
+ }}}};
+ // Hand the config to the transform through a DataMap and run VertexPulling on the source.
+ Transform::DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
+ // str(got) must equal the golden text above.
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VertexPullingTest, FormatsUnaligned_UInt) {  // Unsigned-integer vertex formats; every attribute is configured with 63 (not a multiple of 4).
+ auto* src = R"(
+@vertex
+fn main(
+ @location(0) uint8x2 : vec2<u32>,
+ @location(1) uint8x4 : vec4<u32>,
+ @location(2) uint16x2 : vec2<u32>,
+ @location(3) uint16x4 : vec4<u32>,
+ @location(4) uint32 : u32,
+ @location(5) uint32x2 : vec2<u32>,
+ @location(6) uint32x3 : vec3<u32>,
+ @location(7) uint32x4 : vec4<u32>,
+ ) -> @builtin(position) vec4<f32> {
+ return vec4<f32>(0.0, 0.0, 0.0, 1.0);
+}
+)";
+ // Golden output: each value is stitched from two adjacent u32 words (15u and 16u onwards) with shift/or.
+ auto* expect = R"(
+struct TintVertexData {
+ tint_vertex_data : array<u32>,
+}
+
+@binding(0) @group(4) var<storage, read> tint_pulling_vertex_buffer_0 : TintVertexData;
+
+@vertex
+fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(position) vec4<f32> {
+ var uint8x2 : vec2<u32>;
+ var uint8x4 : vec4<u32>;
+ var uint16x2 : vec2<u32>;
+ var uint16x4 : vec4<u32>;
+ var uint32 : u32;
+ var uint32x2 : vec2<u32>;
+ var uint32x3 : vec3<u32>;
+ var uint32x4 : vec4<u32>;
+ {
+ let buffer_array_base_0 = (tint_pulling_vertex_index * 64u);
+ uint8x2 = ((vec2<u32>((((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 8u)) & 4294901760u)) << vec2<u32>(8u, 0u)) >> vec2<u32>(24u));
+ uint8x4 = ((vec4<u32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u))) << vec4<u32>(24u, 16u, 8u, 0u)) >> vec4<u32>(24u));
+ uint16x2 = ((vec2<u32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u))) << vec2<u32>(16u, 0u)) >> vec2<u32>(16u));
+ uint16x4 = ((vec2<u32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u)), ((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)] << 8u))).xxyy << vec4<u32>(16u, 0u, 16u, 0u)) >> vec4<u32>(16u));
+ uint32 = ((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u));
+ uint32x2 = vec2<u32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u)), ((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)] << 8u)));
+ uint32x3 = vec3<u32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u)), ((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)] << 8u)), ((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)] << 8u)));
+ uint32x4 = vec4<u32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u)), ((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)] << 8u)), ((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)] << 8u)), ((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 19u)] << 8u)));
+ }
+ return vec4<f32>(0.0, 0.0, 0.0, 1.0);
+}
+)";
+ // Single vertex buffer entry. The leading 256 is presumably the byte stride (expected base is vertex_index * 64u) - TODO confirm.
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{{256,
+ VertexStepMode::kVertex,
+ {
+ {VertexFormat::kUint8x2, 63, 0},  // {format, 63 (presumably byte offset; expected loads straddle words 15u/16u), @location index}
+ {VertexFormat::kUint8x4, 63, 1},
+ {VertexFormat::kUint16x2, 63, 2},
+ {VertexFormat::kUint16x4, 63, 3},
+ {VertexFormat::kUint32, 63, 4},
+ {VertexFormat::kUint32x2, 63, 5},
+ {VertexFormat::kUint32x3, 63, 6},
+ {VertexFormat::kUint32x4, 63, 7},
+ }}}};
+ // Hand the config to the transform through a DataMap and run VertexPulling on the source.
+ Transform::DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
+ // str(got) must equal the golden text above.
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VertexPullingTest, FormatsUnaligned_Float_F32) {  // Normalized and float vertex formats read into f32 inputs; every attribute is configured with 63 (not a multiple of 4).
+ auto* src = R"(
+@vertex
+fn main(
+ @location(0) unorm8x2 : vec2<f32>,
+ @location(1) unorm8x4 : vec4<f32>,
+ @location(2) snorm8x2 : vec2<f32>,
+ @location(3) snorm8x4 : vec4<f32>,
+ @location(4) unorm16x2 : vec2<f32>,
+ @location(5) unorm16x4 : vec4<f32>,
+ @location(6) snorm16x2 : vec2<f32>,
+ @location(7) snorm16x4 : vec4<f32>,
+ @location(8) float16x2 : vec2<f32>,
+ @location(9) float16x4 : vec4<f32>,
+ @location(10) float32 : f32,
+ @location(11) float32x2 : vec2<f32>,
+ @location(12) float32x3 : vec3<f32>,
+ @location(13) float32x4 : vec4<f32>,
+ ) -> @builtin(position) vec4<f32> {
+ return vec4<f32>(0.0, 0.0, 0.0, 1.0);
+}
+)";
+ // Golden output: each u32 fed to unpack*/bitcast<f32> is first stitched from two adjacent words (15u and 16u onwards) with shift/or.
+ auto* expect = R"(
+struct TintVertexData {
+ tint_vertex_data : array<u32>,
+}
+
+@binding(0) @group(4) var<storage, read> tint_pulling_vertex_buffer_0 : TintVertexData;
+
+@vertex
+fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(position) vec4<f32> {
+ var unorm8x2 : vec2<f32>;
+ var unorm8x4 : vec4<f32>;
+ var snorm8x2 : vec2<f32>;
+ var snorm8x4 : vec4<f32>;
+ var unorm16x2 : vec2<f32>;
+ var unorm16x4 : vec4<f32>;
+ var snorm16x2 : vec2<f32>;
+ var snorm16x4 : vec4<f32>;
+ var float16x2 : vec2<f32>;
+ var float16x4 : vec4<f32>;
+ var float32 : f32;
+ var float32x2 : vec2<f32>;
+ var float32x3 : vec3<f32>;
+ var float32x4 : vec4<f32>;
+ {
+ let buffer_array_base_0 = (tint_pulling_vertex_index * 64u);
+ unorm8x2 = unpack4x8unorm((((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u)) & 65535u)).xy;
+ unorm8x4 = unpack4x8unorm(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u)));
+ snorm8x2 = unpack4x8snorm((((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u)) & 65535u)).xy;
+ snorm8x4 = unpack4x8snorm(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u)));
+ unorm16x2 = unpack2x16unorm(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u)));
+ unorm16x4 = vec4<f32>(unpack2x16unorm(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u))), unpack2x16unorm(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)] << 8u))));
+ snorm16x2 = unpack2x16snorm(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u)));
+ snorm16x4 = vec4<f32>(unpack2x16snorm(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u))), unpack2x16snorm(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)] << 8u))));
+ float16x2 = unpack2x16float(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u)));
+ float16x4 = vec4<f32>(unpack2x16float(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u))), unpack2x16float(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)] << 8u))));
+ float32 = bitcast<f32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u)));
+ float32x2 = vec2<f32>(bitcast<f32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u))), bitcast<f32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)] << 8u))));
+ float32x3 = vec3<f32>(bitcast<f32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u))), bitcast<f32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)] << 8u))), bitcast<f32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)] << 8u))));
+ float32x4 = vec4<f32>(bitcast<f32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u))), bitcast<f32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)] << 8u))), bitcast<f32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)] << 8u))), bitcast<f32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 19u)] << 8u))));
+ }
+ return vec4<f32>(0.0, 0.0, 0.0, 1.0);
+}
+)";
+ // Single vertex buffer entry. The leading 256 is presumably the byte stride (expected base is vertex_index * 64u) - TODO confirm.
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{{256,
+ VertexStepMode::kVertex,
+ {
+ {VertexFormat::kUnorm8x2, 63, 0},  // {format, 63 (presumably byte offset; expected loads straddle words 15u/16u), @location index}
+ {VertexFormat::kUnorm8x4, 63, 1},
+ {VertexFormat::kSnorm8x2, 63, 2},
+ {VertexFormat::kSnorm8x4, 63, 3},
+ {VertexFormat::kUnorm16x2, 63, 4},
+ {VertexFormat::kUnorm16x4, 63, 5},
+ {VertexFormat::kSnorm16x2, 63, 6},
+ {VertexFormat::kSnorm16x4, 63, 7},
+ {VertexFormat::kFloat16x2, 63, 8},
+ {VertexFormat::kFloat16x4, 63, 9},
+ {VertexFormat::kFloat32, 63, 10},
+ {VertexFormat::kFloat32x2, 63, 11},
+ {VertexFormat::kFloat32x3, 63, 12},
+ {VertexFormat::kFloat32x4, 63, 13},
+ }}}};
+ // Hand the config to the transform through a DataMap and run VertexPulling on the source.
+ Transform::DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
+ // str(got) must equal the golden text above.
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VertexPullingTest, FormatsUnaligned_Float_F16) {  // f16 shader inputs fed by attribute entries that all use 63 in the config below.
+ auto* src = R"(
+enable f16;
+
+@vertex
+fn main(
+ @location(0) unorm8x2 : vec2<f16>,
+ @location(1) unorm8x4 : vec4<f16>,
+ @location(2) snorm8x2 : vec2<f16>,
+ @location(3) snorm8x4 : vec4<f16>,
+ @location(4) unorm16x2 : vec2<f16>,
+ @location(5) unorm16x4 : vec4<f16>,
+ @location(6) snorm16x2 : vec2<f16>,
+ @location(7) snorm16x4 : vec4<f16>,
+ @location(8) float16x2 : vec2<f16>,
+ @location(9) float16x4 : vec4<f16>,
+ @location(10) float32 : f16,
+ @location(11) float32x2 : vec2<f16>,
+ @location(12) float32x3 : vec3<f16>,
+ @location(13) float32x4 : vec4<f16>,
+ ) -> @builtin(position) vec4<f32> {
+ return vec4<f32>(0.0, 0.0, 0.0, 1.0);
+}
+)";
+ // Expected output: the inputs become function-scope vars; each load ORs two adjacent u32 words (`>> 24u` | `<< 8u`) before unpacking/bitcasting and converting to f16.
+ auto* expect = R"(
+enable f16;
+
+struct TintVertexData {
+ tint_vertex_data : array<u32>,
+}
+
+@binding(0) @group(4) var<storage, read> tint_pulling_vertex_buffer_0 : TintVertexData;
+
+@vertex
+fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(position) vec4<f32> {
+ var unorm8x2 : vec2<f16>;
+ var unorm8x4 : vec4<f16>;
+ var snorm8x2 : vec2<f16>;
+ var snorm8x4 : vec4<f16>;
+ var unorm16x2 : vec2<f16>;
+ var unorm16x4 : vec4<f16>;
+ var snorm16x2 : vec2<f16>;
+ var snorm16x4 : vec4<f16>;
+ var float16x2 : vec2<f16>;
+ var float16x4 : vec4<f16>;
+ var float32 : f16;
+ var float32x2 : vec2<f16>;
+ var float32x3 : vec3<f16>;
+ var float32x4 : vec4<f16>;
+ {
+ let buffer_array_base_0 = (tint_pulling_vertex_index * 64u);
+ unorm8x2 = vec2<f16>(unpack4x8unorm((((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u)) & 65535u)).xy);
+ unorm8x4 = vec4<f16>(unpack4x8unorm(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u))));
+ snorm8x2 = vec2<f16>(unpack4x8snorm((((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u)) & 65535u)).xy);
+ snorm8x4 = vec4<f16>(unpack4x8snorm(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u))));
+ unorm16x2 = vec2<f16>(unpack2x16unorm(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u))));
+ unorm16x4 = vec4<f16>(vec4<f32>(unpack2x16unorm(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u))), unpack2x16unorm(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)] << 8u)))));
+ snorm16x2 = vec2<f16>(unpack2x16snorm(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u))));
+ snorm16x4 = vec4<f16>(vec4<f32>(unpack2x16snorm(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u))), unpack2x16snorm(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)] << 8u)))));
+ float16x2 = vec2<f16>(unpack2x16float(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u))));
+ float16x4 = vec4<f16>(vec4<f32>(unpack2x16float(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u))), unpack2x16float(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)] << 8u)))));
+ float32 = f16(bitcast<f32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u))));
+ float32x2 = vec2<f16>(vec2<f32>(bitcast<f32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u))), bitcast<f32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)] << 8u)))));
+ float32x3 = vec3<f16>(vec3<f32>(bitcast<f32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u))), bitcast<f32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)] << 8u))), bitcast<f32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)] << 8u)))));
+ float32x4 = vec4<f16>(vec4<f32>(bitcast<f32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 15u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 8u))), bitcast<f32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)] << 8u))), bitcast<f32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)] << 8u))), bitcast<f32>(((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)] >> 24u) | (tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 19u)] << 8u)))));
+ }
+ return vec4<f32>(0.0, 0.0, 0.0, 1.0);
+}
+)";
+ // One vertex buffer. 256 is presumably the byte stride (256 / 4 == the `* 64u` above); entries look like {format, byte offset, @location} with 63 == word 15, byte 3 - TODO confirm field meanings.
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{{256,
+ VertexStepMode::kVertex,
+ {
+ {VertexFormat::kUnorm8x2, 63, 0},
+ {VertexFormat::kUnorm8x4, 63, 1},
+ {VertexFormat::kSnorm8x2, 63, 2},
+ {VertexFormat::kSnorm8x4, 63, 3},
+ {VertexFormat::kUnorm16x2, 63, 4},
+ {VertexFormat::kUnorm16x4, 63, 5},
+ {VertexFormat::kSnorm16x2, 63, 6},
+ {VertexFormat::kSnorm16x4, 63, 7},
+ {VertexFormat::kFloat16x2, 63, 8},
+ {VertexFormat::kFloat16x4, 63, 9},
+ {VertexFormat::kFloat32, 63, 10},
+ {VertexFormat::kFloat32x2, 63, 11},
+ {VertexFormat::kFloat32x3, 63, 12},
+ {VertexFormat::kFloat32x4, 63, 13},
+ }}}};
+ // Pass the config to the transform through the DataMap, run it, then compare the stringified result with the golden text.
+ Transform::DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VertexPullingTest, FormatsWithVectorsResized_Padding_SInt) {  // i32 shader inputs with more components than the configured sint format.
+ auto* src = R"(
+@vertex
+fn main(
+ @location(0) vec3_sint8x2 : vec3<i32>,
+ @location(1) vec4_sint8x2 : vec4<i32>,
+ @location(2) vec3_sint16x2 : vec3<i32>,
+ @location(3) vec4_sint16x2 : vec4<i32>,
+ @location(4) vec2_sint32 : vec2<i32>,
+ @location(5) vec3_sint32 : vec3<i32>,
+ @location(6) vec4_sint32 : vec4<i32>,
+ @location(7) vec3_sint32x2 : vec3<i32>,
+ @location(8) vec4_sint32x2 : vec4<i32>,
+ @location(9) vec4_sint32x3 : vec4<i32>,
+ ) -> @builtin(position) vec4<f32> {
+ return vec4<f32>(0.0, 0.0, 0.0, 1.0);
+}
+)";
+ // Expected output: the loaded components are wrapped in a wider constructor; added components are 0, except a fourth component, which is 1.
+ auto* expect = R"(
+struct TintVertexData {
+ tint_vertex_data : array<u32>,
+}
+
+@binding(0) @group(4) var<storage, read> tint_pulling_vertex_buffer_0 : TintVertexData;
+
+@vertex
+fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(position) vec4<f32> {
+ var vec3_sint8x2 : vec3<i32>;
+ var vec4_sint8x2 : vec4<i32>;
+ var vec3_sint16x2 : vec3<i32>;
+ var vec4_sint16x2 : vec4<i32>;
+ var vec2_sint32 : vec2<i32>;
+ var vec3_sint32 : vec3<i32>;
+ var vec4_sint32 : vec4<i32>;
+ var vec3_sint32x2 : vec3<i32>;
+ var vec4_sint32x2 : vec4<i32>;
+ var vec4_sint32x3 : vec4<i32>;
+ {
+ let buffer_array_base_0 = (tint_pulling_vertex_index * 64u);
+ vec3_sint8x2 = vec3<i32>(((vec2<i32>(bitcast<i32>((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 16u))) << vec2<u32>(8u, 0u)) >> vec2<u32>(24u)), 0);
+ vec4_sint8x2 = vec4<i32>(((vec2<i32>(bitcast<i32>((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 16u))) << vec2<u32>(8u, 0u)) >> vec2<u32>(24u)), 0, 1);
+ vec3_sint16x2 = vec3<i32>(((vec2<i32>(bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)])) << vec2<u32>(16u, 0u)) >> vec2<u32>(16u)), 0);
+ vec4_sint16x2 = vec4<i32>(((vec2<i32>(bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)])) << vec2<u32>(16u, 0u)) >> vec2<u32>(16u)), 0, 1);
+ vec2_sint32 = vec2<i32>(bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), 0);
+ vec3_sint32 = vec3<i32>(bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), 0, 0);
+ vec4_sint32 = vec4<i32>(bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), 0, 0, 1);
+ vec3_sint32x2 = vec3<i32>(vec2<i32>(bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)])), 0);
+ vec4_sint32x2 = vec4<i32>(vec2<i32>(bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)])), 0, 1);
+ vec4_sint32x3 = vec4<i32>(vec3<i32>(bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]), bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)])), 1);
+ }
+ return vec4<f32>(0.0, 0.0, 0.0, 1.0);
+}
+)";
+ // One vertex buffer. 256 is presumably the byte stride (256 / 4 == the `* 64u` above); entries look like {format, byte offset, @location} with 64 == word 16 - TODO confirm field meanings.
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{{256,
+ VertexStepMode::kVertex,
+ {
+ {VertexFormat::kSint8x2, 64, 0},
+ {VertexFormat::kSint8x2, 64, 1},
+ {VertexFormat::kSint16x2, 64, 2},
+ {VertexFormat::kSint16x2, 64, 3},
+ {VertexFormat::kSint32, 64, 4},
+ {VertexFormat::kSint32, 64, 5},
+ {VertexFormat::kSint32, 64, 6},
+ {VertexFormat::kSint32x2, 64, 7},
+ {VertexFormat::kSint32x2, 64, 8},
+ {VertexFormat::kSint32x3, 64, 9},
+ }}}};
+ // Pass the config to the transform through the DataMap, run it, then compare the stringified result with the golden text.
+ Transform::DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VertexPullingTest, FormatsWithVectorsResized_Padding_UInt) {  // u32 shader inputs with more components than the configured uint format.
+ auto* src = R"(
+@vertex
+fn main(
+ @location(0) vec3_uint8x2 : vec3<u32>,
+ @location(1) vec4_uint8x2 : vec4<u32>,
+ @location(2) vec3_uint16x2 : vec3<u32>,
+ @location(3) vec4_uint16x2 : vec4<u32>,
+ @location(4) vec2_uint32 : vec2<u32>,
+ @location(5) vec3_uint32 : vec3<u32>,
+ @location(6) vec4_uint32 : vec4<u32>,
+ @location(7) vec3_uint32x2 : vec3<u32>,
+ @location(8) vec4_uint32x2 : vec4<u32>,
+ @location(9) vec4_uint32x3 : vec4<u32>,
+ ) -> @builtin(position) vec4<f32> {
+ return vec4<f32>(0.0, 0.0, 0.0, 1.0);
+}
+)";
+ // Expected output: the loaded components are wrapped in a wider constructor; added components are 0, except a fourth component, which is 1.
+ auto* expect = R"(
+struct TintVertexData {
+ tint_vertex_data : array<u32>,
+}
+
+@binding(0) @group(4) var<storage, read> tint_pulling_vertex_buffer_0 : TintVertexData;
+
+@vertex
+fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(position) vec4<f32> {
+ var vec3_uint8x2 : vec3<u32>;
+ var vec4_uint8x2 : vec4<u32>;
+ var vec3_uint16x2 : vec3<u32>;
+ var vec4_uint16x2 : vec4<u32>;
+ var vec2_uint32 : vec2<u32>;
+ var vec3_uint32 : vec3<u32>;
+ var vec4_uint32 : vec4<u32>;
+ var vec3_uint32x2 : vec3<u32>;
+ var vec4_uint32x2 : vec4<u32>;
+ var vec4_uint32x3 : vec4<u32>;
+ {
+ let buffer_array_base_0 = (tint_pulling_vertex_index * 64u);
+ vec3_uint8x2 = vec3<u32>(((vec2<u32>((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 16u)) << vec2<u32>(8u, 0u)) >> vec2<u32>(24u)), 0);
+ vec4_uint8x2 = vec4<u32>(((vec2<u32>((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 16u)) << vec2<u32>(8u, 0u)) >> vec2<u32>(24u)), 0, 1);
+ vec3_uint16x2 = vec3<u32>(((vec2<u32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]) << vec2<u32>(16u, 0u)) >> vec2<u32>(16u)), 0);
+ vec4_uint16x2 = vec4<u32>(((vec2<u32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]) << vec2<u32>(16u, 0u)) >> vec2<u32>(16u)), 0, 1);
+ vec2_uint32 = vec2<u32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)], 0);
+ vec3_uint32 = vec3<u32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)], 0, 0);
+ vec4_uint32 = vec4<u32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)], 0, 0, 1);
+ vec3_uint32x2 = vec3<u32>(vec2<u32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)], tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]), 0);
+ vec4_uint32x2 = vec4<u32>(vec2<u32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)], tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]), 0, 1);
+ vec4_uint32x3 = vec4<u32>(vec3<u32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)], tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)], tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)]), 1);
+ }
+ return vec4<f32>(0.0, 0.0, 0.0, 1.0);
+}
+)";
+ // One vertex buffer. 256 is presumably the byte stride (256 / 4 == the `* 64u` above); entries look like {format, byte offset, @location} with 64 == word 16 - TODO confirm field meanings.
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{{256,
+ VertexStepMode::kVertex,
+ {
+ {VertexFormat::kUint8x2, 64, 0},
+ {VertexFormat::kUint8x2, 64, 1},
+ {VertexFormat::kUint16x2, 64, 2},
+ {VertexFormat::kUint16x2, 64, 3},
+ {VertexFormat::kUint32, 64, 4},
+ {VertexFormat::kUint32, 64, 5},
+ {VertexFormat::kUint32, 64, 6},
+ {VertexFormat::kUint32x2, 64, 7},
+ {VertexFormat::kUint32x2, 64, 8},
+ {VertexFormat::kUint32x3, 64, 9},
+ }}}};
+ // Pass the config to the transform through the DataMap, run it, then compare the stringified result with the golden text.
+ Transform::DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VertexPullingTest, FormatsWithVectorsResized_Padding_Float_F32) {  // f32 shader inputs with more components than the configured format.
+ auto* src = R"(
+@vertex
+fn main(
+ @location(0) vec3_unorm8x2 : vec3<f32>,
+ @location(1) vec4_unorm8x2 : vec4<f32>,
+ @location(2) vec3_snorm8x2 : vec3<f32>,
+ @location(3) vec4_snorm8x2 : vec4<f32>,
+ @location(4) vec3_unorm16x2 : vec3<f32>,
+ @location(5) vec4_unorm16x2 : vec4<f32>,
+ @location(6) vec3_snorm16x2 : vec3<f32>,
+ @location(7) vec4_snorm16x2 : vec4<f32>,
+ @location(8) vec3_float16x2 : vec3<f32>,
+ @location(9) vec4_float16x2 : vec4<f32>,
+ @location(10) vec2_float32 : vec2<f32>,
+ @location(11) vec3_float32 : vec3<f32>,
+ @location(12) vec4_float32 : vec4<f32>,
+ @location(13) vec3_float32x2 : vec3<f32>,
+ @location(14) vec4_float32x2 : vec4<f32>,
+ @location(15) vec4_float32x3 : vec4<f32>,
+ ) -> @builtin(position) vec4<f32> {
+ return vec4<f32>(0.0, 0.0, 0.0, 1.0);
+}
+)";
+ // Expected output: the loaded components are wrapped in a wider constructor; added components are 0.0, except a fourth component, which is 1.0.
+ auto* expect = R"(
+struct TintVertexData {
+ tint_vertex_data : array<u32>,
+}
+
+@binding(0) @group(4) var<storage, read> tint_pulling_vertex_buffer_0 : TintVertexData;
+
+@vertex
+fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(position) vec4<f32> {
+ var vec3_unorm8x2 : vec3<f32>;
+ var vec4_unorm8x2 : vec4<f32>;
+ var vec3_snorm8x2 : vec3<f32>;
+ var vec4_snorm8x2 : vec4<f32>;
+ var vec3_unorm16x2 : vec3<f32>;
+ var vec4_unorm16x2 : vec4<f32>;
+ var vec3_snorm16x2 : vec3<f32>;
+ var vec4_snorm16x2 : vec4<f32>;
+ var vec3_float16x2 : vec3<f32>;
+ var vec4_float16x2 : vec4<f32>;
+ var vec2_float32 : vec2<f32>;
+ var vec3_float32 : vec3<f32>;
+ var vec4_float32 : vec4<f32>;
+ var vec3_float32x2 : vec3<f32>;
+ var vec4_float32x2 : vec4<f32>;
+ var vec4_float32x3 : vec4<f32>;
+ {
+ let buffer_array_base_0 = (tint_pulling_vertex_index * 64u);
+ vec3_unorm8x2 = vec3<f32>(unpack4x8unorm((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] & 65535u)).xy, 0.0);
+ vec4_unorm8x2 = vec4<f32>(unpack4x8unorm((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] & 65535u)).xy, 0.0, 1.0);
+ vec3_snorm8x2 = vec3<f32>(unpack4x8snorm((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] & 65535u)).xy, 0.0);
+ vec4_snorm8x2 = vec4<f32>(unpack4x8snorm((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] & 65535u)).xy, 0.0, 1.0);
+ vec3_unorm16x2 = vec3<f32>(unpack2x16unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), 0.0);
+ vec4_unorm16x2 = vec4<f32>(unpack2x16unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), 0.0, 1.0);
+ vec3_snorm16x2 = vec3<f32>(unpack2x16snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), 0.0);
+ vec4_snorm16x2 = vec4<f32>(unpack2x16snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), 0.0, 1.0);
+ vec3_float16x2 = vec3<f32>(unpack2x16float(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), 0.0);
+ vec4_float16x2 = vec4<f32>(unpack2x16float(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), 0.0, 1.0);
+ vec2_float32 = vec2<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), 0.0);
+ vec3_float32 = vec3<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), 0.0, 0.0);
+ vec4_float32 = vec4<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), 0.0, 0.0, 1.0);
+ vec3_float32x2 = vec3<f32>(vec2<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)])), 0.0);
+ vec4_float32x2 = vec4<f32>(vec2<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)])), 0.0, 1.0);
+ vec4_float32x3 = vec4<f32>(vec3<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)])), 1.0);
+ }
+ return vec4<f32>(0.0, 0.0, 0.0, 1.0);
+}
+)";
+ // One vertex buffer. 256 is presumably the byte stride (256 / 4 == the `* 64u` above); entries look like {format, byte offset, @location} with 64 == word 16 - TODO confirm field meanings.
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{{256,
+ VertexStepMode::kVertex,
+ {
+ {VertexFormat::kUnorm8x2, 64, 0},
+ {VertexFormat::kUnorm8x2, 64, 1},
+ {VertexFormat::kSnorm8x2, 64, 2},
+ {VertexFormat::kSnorm8x2, 64, 3},
+ {VertexFormat::kUnorm16x2, 64, 4},
+ {VertexFormat::kUnorm16x2, 64, 5},
+ {VertexFormat::kSnorm16x2, 64, 6},
+ {VertexFormat::kSnorm16x2, 64, 7},
+ {VertexFormat::kFloat16x2, 64, 8},
+ {VertexFormat::kFloat16x2, 64, 9},
+ {VertexFormat::kFloat32, 64, 10},
+ {VertexFormat::kFloat32, 64, 11},
+ {VertexFormat::kFloat32, 64, 12},
+ {VertexFormat::kFloat32x2, 64, 13},
+ {VertexFormat::kFloat32x2, 64, 14},
+ {VertexFormat::kFloat32x3, 64, 15},
+ }}}};
+ // Pass the config to the transform through the DataMap, run it, then compare the stringified result with the golden text.
+ Transform::DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VertexPullingTest, FormatsWithVectorsResized_Padding_Float_F16) {  // f16 shader inputs with more components than the configured format.
+ auto* src = R"(
+enable f16;
+
+@vertex
+fn main(
+ @location(0) vec3_unorm8x2 : vec3<f16>,
+ @location(1) vec4_unorm8x2 : vec4<f16>,
+ @location(2) vec3_snorm8x2 : vec3<f16>,
+ @location(3) vec4_snorm8x2 : vec4<f16>,
+ @location(4) vec3_unorm16x2 : vec3<f16>,
+ @location(5) vec4_unorm16x2 : vec4<f16>,
+ @location(6) vec3_snorm16x2 : vec3<f16>,
+ @location(7) vec4_snorm16x2 : vec4<f16>,
+ @location(8) vec3_float16x2 : vec3<f16>,
+ @location(9) vec4_float16x2 : vec4<f16>,
+ @location(10) vec2_float32 : vec2<f16>,
+ @location(11) vec3_float32 : vec3<f16>,
+ @location(12) vec4_float32 : vec4<f16>,
+ @location(13) vec3_float32x2 : vec3<f16>,
+ @location(14) vec4_float32x2 : vec4<f16>,
+ @location(15) vec4_float32x3 : vec4<f16>,
+ ) -> @builtin(position) vec4<f32> {
+ return vec4<f32>(0.0, 0.0, 0.0, 1.0);
+}
+)";
+ // Expected output: the loaded value is first converted to f16 (`f16(...)` / `vecN<f16>(...)`), then widened with 0.0 for added components and 1.0 for a fourth component.
+ auto* expect = R"(
+enable f16;
+
+struct TintVertexData {
+ tint_vertex_data : array<u32>,
+}
+
+@binding(0) @group(4) var<storage, read> tint_pulling_vertex_buffer_0 : TintVertexData;
+
+@vertex
+fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(position) vec4<f32> {
+ var vec3_unorm8x2 : vec3<f16>;
+ var vec4_unorm8x2 : vec4<f16>;
+ var vec3_snorm8x2 : vec3<f16>;
+ var vec4_snorm8x2 : vec4<f16>;
+ var vec3_unorm16x2 : vec3<f16>;
+ var vec4_unorm16x2 : vec4<f16>;
+ var vec3_snorm16x2 : vec3<f16>;
+ var vec4_snorm16x2 : vec4<f16>;
+ var vec3_float16x2 : vec3<f16>;
+ var vec4_float16x2 : vec4<f16>;
+ var vec2_float32 : vec2<f16>;
+ var vec3_float32 : vec3<f16>;
+ var vec4_float32 : vec4<f16>;
+ var vec3_float32x2 : vec3<f16>;
+ var vec4_float32x2 : vec4<f16>;
+ var vec4_float32x3 : vec4<f16>;
+ {
+ let buffer_array_base_0 = (tint_pulling_vertex_index * 64u);
+ vec3_unorm8x2 = vec3<f16>(vec2<f16>(unpack4x8unorm((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] & 65535u)).xy), 0.0);
+ vec4_unorm8x2 = vec4<f16>(vec2<f16>(unpack4x8unorm((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] & 65535u)).xy), 0.0, 1.0);
+ vec3_snorm8x2 = vec3<f16>(vec2<f16>(unpack4x8snorm((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] & 65535u)).xy), 0.0);
+ vec4_snorm8x2 = vec4<f16>(vec2<f16>(unpack4x8snorm((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] & 65535u)).xy), 0.0, 1.0);
+ vec3_unorm16x2 = vec3<f16>(vec2<f16>(unpack2x16unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)])), 0.0);
+ vec4_unorm16x2 = vec4<f16>(vec2<f16>(unpack2x16unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)])), 0.0, 1.0);
+ vec3_snorm16x2 = vec3<f16>(vec2<f16>(unpack2x16snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)])), 0.0);
+ vec4_snorm16x2 = vec4<f16>(vec2<f16>(unpack2x16snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)])), 0.0, 1.0);
+ vec3_float16x2 = vec3<f16>(vec2<f16>(unpack2x16float(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)])), 0.0);
+ vec4_float16x2 = vec4<f16>(vec2<f16>(unpack2x16float(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)])), 0.0, 1.0);
+ vec2_float32 = vec2<f16>(f16(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)])), 0.0);
+ vec3_float32 = vec3<f16>(f16(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)])), 0.0, 0.0);
+ vec4_float32 = vec4<f16>(f16(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)])), 0.0, 0.0, 1.0);
+ vec3_float32x2 = vec3<f16>(vec2<f16>(vec2<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]))), 0.0);
+ vec4_float32x2 = vec4<f16>(vec2<f16>(vec2<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]))), 0.0, 1.0);
+ vec4_float32x3 = vec4<f16>(vec3<f16>(vec3<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)]))), 1.0);
+ }
+ return vec4<f32>(0.0, 0.0, 0.0, 1.0);
+}
+)";
+ // One vertex buffer. 256 is presumably the byte stride (256 / 4 == the `* 64u` above); entries look like {format, byte offset, @location} with 64 == word 16 - TODO confirm field meanings.
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{{256,
+ VertexStepMode::kVertex,
+ {
+ {VertexFormat::kUnorm8x2, 64, 0},
+ {VertexFormat::kUnorm8x2, 64, 1},
+ {VertexFormat::kSnorm8x2, 64, 2},
+ {VertexFormat::kSnorm8x2, 64, 3},
+ {VertexFormat::kUnorm16x2, 64, 4},
+ {VertexFormat::kUnorm16x2, 64, 5},
+ {VertexFormat::kSnorm16x2, 64, 6},
+ {VertexFormat::kSnorm16x2, 64, 7},
+ {VertexFormat::kFloat16x2, 64, 8},
+ {VertexFormat::kFloat16x2, 64, 9},
+ {VertexFormat::kFloat32, 64, 10},
+ {VertexFormat::kFloat32, 64, 11},
+ {VertexFormat::kFloat32, 64, 12},
+ {VertexFormat::kFloat32x2, 64, 13},
+ {VertexFormat::kFloat32x2, 64, 14},
+ {VertexFormat::kFloat32x3, 64, 15},
+ }}}};
+ // Pass the config to the transform through the DataMap, run it, then compare the stringified result with the golden text.
+ Transform::DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VertexPullingTest, FormatsWithVectorsResized_Shrinking_SInt) {  // i32 shader inputs with fewer components than the configured sint format.
+ auto* src = R"(
+@vertex
+fn main(
+ @location(0) sclr_sint8x2 : i32 ,
+ @location(1) sclr_sint8x4 : i32 ,
+ @location(2) vec2_sint8x4 : vec2<i32>,
+ @location(3) vec3_sint8x4 : vec3<i32>,
+ @location(4) sclr_sint16x2 : i32 ,
+ @location(5) sclr_sint16x4 : i32 ,
+ @location(6) vec2_sint16x4 : vec2<i32>,
+ @location(7) vec3_sint16x4 : vec3<i32>,
+ @location(8) sclr_sint32x2 : i32 ,
+ @location(9) sclr_sint32x3 : i32 ,
+ @location(10) vec2_sint32x3 : vec2<i32>,
+ @location(11) sclr_sint32x4 : i32 ,
+ @location(12) vec2_sint32x4 : vec2<i32>,
+ @location(13) vec3_sint32x4 : vec3<i32>,
+ ) -> @builtin(position) vec4<f32> {
+ return vec4<f32>(0.0, 0.0, 0.0, 1.0);
+}
+)";
+ // Expected output: the full-width format value is built first and then narrowed with a `.x`, `.xy` or `.xyz` swizzle to the declared input type.
+ auto* expect = R"(
+struct TintVertexData {
+ tint_vertex_data : array<u32>,
+}
+
+@binding(0) @group(4) var<storage, read> tint_pulling_vertex_buffer_0 : TintVertexData;
+
+@vertex
+fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(position) vec4<f32> {
+ var sclr_sint8x2 : i32;
+ var sclr_sint8x4 : i32;
+ var vec2_sint8x4 : vec2<i32>;
+ var vec3_sint8x4 : vec3<i32>;
+ var sclr_sint16x2 : i32;
+ var sclr_sint16x4 : i32;
+ var vec2_sint16x4 : vec2<i32>;
+ var vec3_sint16x4 : vec3<i32>;
+ var sclr_sint32x2 : i32;
+ var sclr_sint32x3 : i32;
+ var vec2_sint32x3 : vec2<i32>;
+ var sclr_sint32x4 : i32;
+ var vec2_sint32x4 : vec2<i32>;
+ var vec3_sint32x4 : vec3<i32>;
+ {
+ let buffer_array_base_0 = (tint_pulling_vertex_index * 64u);
+ sclr_sint8x2 = (((vec2<i32>(bitcast<i32>((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 16u))) << vec2<u32>(8u, 0u)) >> vec2<u32>(24u))).x;
+ sclr_sint8x4 = (((vec4<i32>(bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)])) << vec4<u32>(24u, 16u, 8u, 0u)) >> vec4<u32>(24u))).x;
+ vec2_sint8x4 = (((vec4<i32>(bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)])) << vec4<u32>(24u, 16u, 8u, 0u)) >> vec4<u32>(24u))).xy;
+ vec3_sint8x4 = (((vec4<i32>(bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)])) << vec4<u32>(24u, 16u, 8u, 0u)) >> vec4<u32>(24u))).xyz;
+ sclr_sint16x2 = (((vec2<i32>(bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)])) << vec2<u32>(16u, 0u)) >> vec2<u32>(16u))).x;
+ sclr_sint16x4 = (((vec2<i32>(bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)])).xxyy << vec4<u32>(16u, 0u, 16u, 0u)) >> vec4<u32>(16u))).x;
+ vec2_sint16x4 = (((vec2<i32>(bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)])).xxyy << vec4<u32>(16u, 0u, 16u, 0u)) >> vec4<u32>(16u))).xy;
+ vec3_sint16x4 = (((vec2<i32>(bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)])).xxyy << vec4<u32>(16u, 0u, 16u, 0u)) >> vec4<u32>(16u))).xyz;
+ sclr_sint32x2 = vec2<i32>(bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)])).x;
+ sclr_sint32x3 = vec3<i32>(bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]), bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)])).x;
+ vec2_sint32x3 = vec3<i32>(bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]), bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)])).xy;
+ sclr_sint32x4 = vec4<i32>(bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]), bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)]), bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 19u)])).x;
+ vec2_sint32x4 = vec4<i32>(bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]), bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)]), bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 19u)])).xy;
+ vec3_sint32x4 = vec4<i32>(bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]), bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)]), bitcast<i32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 19u)])).xyz;
+ }
+ return vec4<f32>(0.0, 0.0, 0.0, 1.0);
+}
+)";
+ // One vertex buffer. 256 is presumably the byte stride (256 / 4 == the `* 64u` above); entries look like {format, byte offset, @location} with 64 == word 16 - TODO confirm field meanings.
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{{256,
+ VertexStepMode::kVertex,
+ {
+ {VertexFormat::kSint8x2, 64, 0},
+ {VertexFormat::kSint8x4, 64, 1},
+ {VertexFormat::kSint8x4, 64, 2},
+ {VertexFormat::kSint8x4, 64, 3},
+ {VertexFormat::kSint16x2, 64, 4},
+ {VertexFormat::kSint16x4, 64, 5},
+ {VertexFormat::kSint16x4, 64, 6},
+ {VertexFormat::kSint16x4, 64, 7},
+ {VertexFormat::kSint32x2, 64, 8},
+ {VertexFormat::kSint32x3, 64, 9},
+ {VertexFormat::kSint32x3, 64, 10},
+ {VertexFormat::kSint32x4, 64, 11},
+ {VertexFormat::kSint32x4, 64, 12},
+ {VertexFormat::kSint32x4, 64, 13},
+ }}}};
+ // Pass the config to the transform through the DataMap, run it, then compare the stringified result with the golden text.
+ Transform::DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VertexPullingTest, FormatsWithVectorsResized_Shrinking_UInt) {
+ auto* src = R"(
+@vertex
+fn main(
+ @location(0) sclr_uint8x2 : u32 ,
+ @location(1) sclr_uint8x4 : u32 ,
+ @location(2) vec2_uint8x4 : vec2<u32>,
+ @location(3) vec3_uint8x4 : vec3<u32>,
+ @location(4) sclr_uint16x2 : u32 ,
+ @location(5) sclr_uint16x4 : u32 ,
+ @location(6) vec2_uint16x4 : vec2<u32>,
+ @location(7) vec3_uint16x4 : vec3<u32>,
+ @location(8) sclr_uint32x2 : u32 ,
+ @location(9) sclr_uint32x3 : u32 ,
+ @location(10) vec2_uint32x3 : vec2<u32>,
+ @location(11) sclr_uint32x4 : u32 ,
+ @location(12) vec2_uint32x4 : vec2<u32>,
+ @location(13) vec3_uint32x4 : vec3<u32>,
+ ) -> @builtin(position) vec4<f32> {
+ return vec4<f32>(0.0, 0.0, 0.0, 1.0);
+}
+)";
+
+ auto* expect = R"(
+struct TintVertexData {
+ tint_vertex_data : array<u32>,
+}
+
+@binding(0) @group(4) var<storage, read> tint_pulling_vertex_buffer_0 : TintVertexData;
+
+@vertex
+fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(position) vec4<f32> {
+ var sclr_uint8x2 : u32;
+ var sclr_uint8x4 : u32;
+ var vec2_uint8x4 : vec2<u32>;
+ var vec3_uint8x4 : vec3<u32>;
+ var sclr_uint16x2 : u32;
+ var sclr_uint16x4 : u32;
+ var vec2_uint16x4 : vec2<u32>;
+ var vec3_uint16x4 : vec3<u32>;
+ var sclr_uint32x2 : u32;
+ var sclr_uint32x3 : u32;
+ var vec2_uint32x3 : vec2<u32>;
+ var sclr_uint32x4 : u32;
+ var vec2_uint32x4 : vec2<u32>;
+ var vec3_uint32x4 : vec3<u32>;
+ {
+ let buffer_array_base_0 = (tint_pulling_vertex_index * 64u);
+ sclr_uint8x2 = (((vec2<u32>((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] << 16u)) << vec2<u32>(8u, 0u)) >> vec2<u32>(24u))).x;
+ sclr_uint8x4 = (((vec4<u32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]) << vec4<u32>(24u, 16u, 8u, 0u)) >> vec4<u32>(24u))).x;
+ vec2_uint8x4 = (((vec4<u32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]) << vec4<u32>(24u, 16u, 8u, 0u)) >> vec4<u32>(24u))).xy;
+ vec3_uint8x4 = (((vec4<u32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]) << vec4<u32>(24u, 16u, 8u, 0u)) >> vec4<u32>(24u))).xyz;
+ sclr_uint16x2 = (((vec2<u32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]) << vec2<u32>(16u, 0u)) >> vec2<u32>(16u))).x;
+ sclr_uint16x4 = (((vec2<u32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)], tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]).xxyy << vec4<u32>(16u, 0u, 16u, 0u)) >> vec4<u32>(16u))).x;
+ vec2_uint16x4 = (((vec2<u32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)], tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]).xxyy << vec4<u32>(16u, 0u, 16u, 0u)) >> vec4<u32>(16u))).xy;
+ vec3_uint16x4 = (((vec2<u32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)], tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]).xxyy << vec4<u32>(16u, 0u, 16u, 0u)) >> vec4<u32>(16u))).xyz;
+ sclr_uint32x2 = vec2<u32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)], tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]).x;
+ sclr_uint32x3 = vec3<u32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)], tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)], tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)]).x;
+ vec2_uint32x3 = vec3<u32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)], tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)], tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)]).xy;
+ sclr_uint32x4 = vec4<u32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)], tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)], tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)], tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 19u)]).x;
+ vec2_uint32x4 = vec4<u32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)], tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)], tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)], tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 19u)]).xy;
+ vec3_uint32x4 = vec4<u32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)], tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)], tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)], tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 19u)]).xyz;
+ }
+ return vec4<f32>(0.0, 0.0, 0.0, 1.0);
+}
+)";
+
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{{256,
+ VertexStepMode::kVertex,
+ {
+ {VertexFormat::kUint8x2, 64, 0},
+ {VertexFormat::kUint8x4, 64, 1},
+ {VertexFormat::kUint8x4, 64, 2},
+ {VertexFormat::kUint8x4, 64, 3},
+ {VertexFormat::kUint16x2, 64, 4},
+ {VertexFormat::kUint16x4, 64, 5},
+ {VertexFormat::kUint16x4, 64, 6},
+ {VertexFormat::kUint16x4, 64, 7},
+ {VertexFormat::kUint32x2, 64, 8},
+ {VertexFormat::kUint32x3, 64, 9},
+ {VertexFormat::kUint32x3, 64, 10},
+ {VertexFormat::kUint32x4, 64, 11},
+ {VertexFormat::kUint32x4, 64, 12},
+ {VertexFormat::kUint32x4, 64, 13},
+ }}}};
+
+ Transform::DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VertexPullingTest, FormatsWithVectorsResized_Shrinking_Float_F32) {
+ auto* src = R"(
+@vertex
+fn main(
+ @location(0) sclr_unorm8x2 : f32 ,
+ @location(1) sclr_unorm8x4 : f32 ,
+ @location(2) vec2_unorm8x4 : vec2<f32>,
+ @location(3) vec3_unorm8x4 : vec3<f32>,
+ @location(4) sclr_snorm8x2 : f32 ,
+ @location(5) sclr_snorm8x4 : f32 ,
+ @location(6) vec2_snorm8x4 : vec2<f32>,
+ @location(7) vec3_snorm8x4 : vec3<f32>,
+ @location(8) sclr_unorm16x2 : f32 ,
+ @location(9) sclr_unorm16x4 : f32 ,
+ @location(10) vec2_unorm16x4 : vec2<f32>,
+ @location(11) vec3_unorm16x4 : vec3<f32>,
+ @location(12) sclr_snorm16x2 : f32 ,
+ @location(13) sclr_snorm16x4 : f32 ,
+ @location(14) vec2_snorm16x4 : vec2<f32>,
+ @location(15) vec3_snorm16x4 : vec3<f32>,
+ @location(16) sclr_float16x2 : f32 ,
+ @location(17) sclr_float16x4 : f32 ,
+ @location(18) vec2_float16x4 : vec2<f32>,
+ @location(19) vec3_float16x4 : vec3<f32>,
+ @location(20) sclr_float32x2 : f32 ,
+ @location(21) sclr_float32x3 : f32 ,
+ @location(22) vec2_float32x3 : vec2<f32>,
+ @location(23) sclr_float32x4 : f32 ,
+ @location(24) vec2_float32x4 : vec2<f32>,
+ @location(25) vec3_float32x4 : vec3<f32>,
+ ) -> @builtin(position) vec4<f32> {
+ return vec4<f32>(0.0, 0.0, 0.0, 1.0);
+}
+)";
+
+ auto* expect = R"(
+struct TintVertexData {
+ tint_vertex_data : array<u32>,
+}
+
+@binding(0) @group(4) var<storage, read> tint_pulling_vertex_buffer_0 : TintVertexData;
+
+@vertex
+fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(position) vec4<f32> {
+ var sclr_unorm8x2 : f32;
+ var sclr_unorm8x4 : f32;
+ var vec2_unorm8x4 : vec2<f32>;
+ var vec3_unorm8x4 : vec3<f32>;
+ var sclr_snorm8x2 : f32;
+ var sclr_snorm8x4 : f32;
+ var vec2_snorm8x4 : vec2<f32>;
+ var vec3_snorm8x4 : vec3<f32>;
+ var sclr_unorm16x2 : f32;
+ var sclr_unorm16x4 : f32;
+ var vec2_unorm16x4 : vec2<f32>;
+ var vec3_unorm16x4 : vec3<f32>;
+ var sclr_snorm16x2 : f32;
+ var sclr_snorm16x4 : f32;
+ var vec2_snorm16x4 : vec2<f32>;
+ var vec3_snorm16x4 : vec3<f32>;
+ var sclr_float16x2 : f32;
+ var sclr_float16x4 : f32;
+ var vec2_float16x4 : vec2<f32>;
+ var vec3_float16x4 : vec3<f32>;
+ var sclr_float32x2 : f32;
+ var sclr_float32x3 : f32;
+ var vec2_float32x3 : vec2<f32>;
+ var sclr_float32x4 : f32;
+ var vec2_float32x4 : vec2<f32>;
+ var vec3_float32x4 : vec3<f32>;
+ {
+ let buffer_array_base_0 = (tint_pulling_vertex_index * 64u);
+ sclr_unorm8x2 = unpack4x8unorm((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] & 65535u)).xy.x;
+ sclr_unorm8x4 = unpack4x8unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]).x;
+ vec2_unorm8x4 = unpack4x8unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]).xy;
+ vec3_unorm8x4 = unpack4x8unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]).xyz;
+ sclr_snorm8x2 = unpack4x8snorm((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] & 65535u)).xy.x;
+ sclr_snorm8x4 = unpack4x8snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]).x;
+ vec2_snorm8x4 = unpack4x8snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]).xy;
+ vec3_snorm8x4 = unpack4x8snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]).xyz;
+ sclr_unorm16x2 = unpack2x16unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]).x;
+ sclr_unorm16x4 = vec4<f32>(unpack2x16unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), unpack2x16unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)])).x;
+ vec2_unorm16x4 = vec4<f32>(unpack2x16unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), unpack2x16unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)])).xy;
+ vec3_unorm16x4 = vec4<f32>(unpack2x16unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), unpack2x16unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)])).xyz;
+ sclr_snorm16x2 = unpack2x16snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]).x;
+ sclr_snorm16x4 = vec4<f32>(unpack2x16snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), unpack2x16snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)])).x;
+ vec2_snorm16x4 = vec4<f32>(unpack2x16snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), unpack2x16snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)])).xy;
+ vec3_snorm16x4 = vec4<f32>(unpack2x16snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), unpack2x16snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)])).xyz;
+ sclr_float16x2 = unpack2x16float(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]).x;
+ sclr_float16x4 = vec4<f32>(unpack2x16float(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), unpack2x16float(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)])).x;
+ vec2_float16x4 = vec4<f32>(unpack2x16float(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), unpack2x16float(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)])).xy;
+ vec3_float16x4 = vec4<f32>(unpack2x16float(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), unpack2x16float(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)])).xyz;
+ sclr_float32x2 = vec2<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)])).x;
+ sclr_float32x3 = vec3<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)])).x;
+ vec2_float32x3 = vec3<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)])).xy;
+ sclr_float32x4 = vec4<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 19u)])).x;
+ vec2_float32x4 = vec4<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 19u)])).xy;
+ vec3_float32x4 = vec4<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 19u)])).xyz;
+ }
+ return vec4<f32>(0.0, 0.0, 0.0, 1.0);
+}
+)";
+
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {
+ {{256,
+ VertexStepMode::kVertex,
+ {
+ {VertexFormat::kUnorm8x2, 64, 0}, {VertexFormat::kUnorm8x4, 64, 1},
+ {VertexFormat::kUnorm8x4, 64, 2}, {VertexFormat::kUnorm8x4, 64, 3},
+ {VertexFormat::kSnorm8x2, 64, 4}, {VertexFormat::kSnorm8x4, 64, 5},
+ {VertexFormat::kSnorm8x4, 64, 6}, {VertexFormat::kSnorm8x4, 64, 7},
+ {VertexFormat::kUnorm16x2, 64, 8}, {VertexFormat::kUnorm16x4, 64, 9},
+ {VertexFormat::kUnorm16x4, 64, 10}, {VertexFormat::kUnorm16x4, 64, 11},
+ {VertexFormat::kSnorm16x2, 64, 12}, {VertexFormat::kSnorm16x4, 64, 13},
+ {VertexFormat::kSnorm16x4, 64, 14}, {VertexFormat::kSnorm16x4, 64, 15},
+ {VertexFormat::kFloat16x2, 64, 16}, {VertexFormat::kFloat16x4, 64, 17},
+ {VertexFormat::kFloat16x4, 64, 18}, {VertexFormat::kFloat16x4, 64, 19},
+ {VertexFormat::kFloat32x2, 64, 20}, {VertexFormat::kFloat32x3, 64, 21},
+ {VertexFormat::kFloat32x3, 64, 22}, {VertexFormat::kFloat32x4, 64, 23},
+ {VertexFormat::kFloat32x4, 64, 24}, {VertexFormat::kFloat32x4, 64, 25},
+ }}}};
+
+ Transform::DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VertexPullingTest, FormatsWithVectorsResized_Shrinking_Float_F16) {
+ auto* src = R"(
+enable f16;
+
+@vertex
+fn main(
+ @location(0) sclr_unorm8x2 : f16 ,
+ @location(1) sclr_unorm8x4 : f16 ,
+ @location(2) vec2_unorm8x4 : vec2<f16>,
+ @location(3) vec3_unorm8x4 : vec3<f16>,
+ @location(4) sclr_snorm8x2 : f16 ,
+ @location(5) sclr_snorm8x4 : f16 ,
+ @location(6) vec2_snorm8x4 : vec2<f16>,
+ @location(7) vec3_snorm8x4 : vec3<f16>,
+ @location(8) sclr_unorm16x2 : f16 ,
+ @location(9) sclr_unorm16x4 : f16 ,
+ @location(10) vec2_unorm16x4 : vec2<f16>,
+ @location(11) vec3_unorm16x4 : vec3<f16>,
+ @location(12) sclr_snorm16x2 : f16 ,
+ @location(13) sclr_snorm16x4 : f16 ,
+ @location(14) vec2_snorm16x4 : vec2<f16>,
+ @location(15) vec3_snorm16x4 : vec3<f16>,
+ @location(16) sclr_float16x2 : f16 ,
+ @location(17) sclr_float16x4 : f16 ,
+ @location(18) vec2_float16x4 : vec2<f16>,
+ @location(19) vec3_float16x4 : vec3<f16>,
+ @location(20) sclr_float32x2 : f16 ,
+ @location(21) sclr_float32x3 : f16 ,
+ @location(22) vec2_float32x3 : vec2<f16>,
+ @location(23) sclr_float32x4 : f16 ,
+ @location(24) vec2_float32x4 : vec2<f16>,
+ @location(25) vec3_float32x4 : vec3<f16>,
+ ) -> @builtin(position) vec4<f32> {
+ return vec4<f32>(0.0, 0.0, 0.0, 1.0);
+}
+)";
+
+ auto* expect = R"(
+enable f16;
+
+struct TintVertexData {
+ tint_vertex_data : array<u32>,
+}
+
+@binding(0) @group(4) var<storage, read> tint_pulling_vertex_buffer_0 : TintVertexData;
+
+@vertex
+fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(position) vec4<f32> {
+ var sclr_unorm8x2 : f16;
+ var sclr_unorm8x4 : f16;
+ var vec2_unorm8x4 : vec2<f16>;
+ var vec3_unorm8x4 : vec3<f16>;
+ var sclr_snorm8x2 : f16;
+ var sclr_snorm8x4 : f16;
+ var vec2_snorm8x4 : vec2<f16>;
+ var vec3_snorm8x4 : vec3<f16>;
+ var sclr_unorm16x2 : f16;
+ var sclr_unorm16x4 : f16;
+ var vec2_unorm16x4 : vec2<f16>;
+ var vec3_unorm16x4 : vec3<f16>;
+ var sclr_snorm16x2 : f16;
+ var sclr_snorm16x4 : f16;
+ var vec2_snorm16x4 : vec2<f16>;
+ var vec3_snorm16x4 : vec3<f16>;
+ var sclr_float16x2 : f16;
+ var sclr_float16x4 : f16;
+ var vec2_float16x4 : vec2<f16>;
+ var vec3_float16x4 : vec3<f16>;
+ var sclr_float32x2 : f16;
+ var sclr_float32x3 : f16;
+ var vec2_float32x3 : vec2<f16>;
+ var sclr_float32x4 : f16;
+ var vec2_float32x4 : vec2<f16>;
+ var vec3_float32x4 : vec3<f16>;
+ {
+ let buffer_array_base_0 = (tint_pulling_vertex_index * 64u);
+ sclr_unorm8x2 = vec2<f16>(unpack4x8unorm((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] & 65535u)).xy).x;
+ sclr_unorm8x4 = vec4<f16>(unpack4x8unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)])).x;
+ vec2_unorm8x4 = vec4<f16>(unpack4x8unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)])).xy;
+ vec3_unorm8x4 = vec4<f16>(unpack4x8unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)])).xyz;
+ sclr_snorm8x2 = vec2<f16>(unpack4x8snorm((tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)] & 65535u)).xy).x;
+ sclr_snorm8x4 = vec4<f16>(unpack4x8snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)])).x;
+ vec2_snorm8x4 = vec4<f16>(unpack4x8snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)])).xy;
+ vec3_snorm8x4 = vec4<f16>(unpack4x8snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)])).xyz;
+ sclr_unorm16x2 = vec2<f16>(unpack2x16unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)])).x;
+ sclr_unorm16x4 = vec4<f16>(vec4<f32>(unpack2x16unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), unpack2x16unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]))).x;
+ vec2_unorm16x4 = vec4<f16>(vec4<f32>(unpack2x16unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), unpack2x16unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]))).xy;
+ vec3_unorm16x4 = vec4<f16>(vec4<f32>(unpack2x16unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), unpack2x16unorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]))).xyz;
+ sclr_snorm16x2 = vec2<f16>(unpack2x16snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)])).x;
+ sclr_snorm16x4 = vec4<f16>(vec4<f32>(unpack2x16snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), unpack2x16snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]))).x;
+ vec2_snorm16x4 = vec4<f16>(vec4<f32>(unpack2x16snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), unpack2x16snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]))).xy;
+ vec3_snorm16x4 = vec4<f16>(vec4<f32>(unpack2x16snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), unpack2x16snorm(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]))).xyz;
+ sclr_float16x2 = vec2<f16>(unpack2x16float(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)])).x;
+ sclr_float16x4 = vec4<f16>(vec4<f32>(unpack2x16float(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), unpack2x16float(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]))).x;
+ vec2_float16x4 = vec4<f16>(vec4<f32>(unpack2x16float(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), unpack2x16float(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]))).xy;
+ vec3_float16x4 = vec4<f16>(vec4<f32>(unpack2x16float(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), unpack2x16float(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]))).xyz;
+ sclr_float32x2 = vec2<f16>(vec2<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]))).x;
+ sclr_float32x3 = vec3<f16>(vec3<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)]))).x;
+ vec2_float32x3 = vec3<f16>(vec3<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)]))).xy;
+ sclr_float32x4 = vec4<f16>(vec4<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 19u)]))).x;
+ vec2_float32x4 = vec4<f16>(vec4<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 19u)]))).xy;
+ vec3_float32x4 = vec4<f16>(vec4<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 16u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 17u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 18u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(buffer_array_base_0 + 19u)]))).xyz;
+ }
+ return vec4<f32>(0.0, 0.0, 0.0, 1.0);
+}
+)";
+
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {
+ {{256,
+ VertexStepMode::kVertex,
+ {
+ {VertexFormat::kUnorm8x2, 64, 0}, {VertexFormat::kUnorm8x4, 64, 1},
+ {VertexFormat::kUnorm8x4, 64, 2}, {VertexFormat::kUnorm8x4, 64, 3},
+ {VertexFormat::kSnorm8x2, 64, 4}, {VertexFormat::kSnorm8x4, 64, 5},
+ {VertexFormat::kSnorm8x4, 64, 6}, {VertexFormat::kSnorm8x4, 64, 7},
+ {VertexFormat::kUnorm16x2, 64, 8}, {VertexFormat::kUnorm16x4, 64, 9},
+ {VertexFormat::kUnorm16x4, 64, 10}, {VertexFormat::kUnorm16x4, 64, 11},
+ {VertexFormat::kSnorm16x2, 64, 12}, {VertexFormat::kSnorm16x4, 64, 13},
+ {VertexFormat::kSnorm16x4, 64, 14}, {VertexFormat::kSnorm16x4, 64, 15},
+ {VertexFormat::kFloat16x2, 64, 16}, {VertexFormat::kFloat16x4, 64, 17},
+ {VertexFormat::kFloat16x4, 64, 18}, {VertexFormat::kFloat16x4, 64, 19},
+ {VertexFormat::kFloat32x2, 64, 20}, {VertexFormat::kFloat32x3, 64, 21},
+ {VertexFormat::kFloat32x3, 64, 22}, {VertexFormat::kFloat32x4, 64, 23},
+ {VertexFormat::kFloat32x4, 64, 24}, {VertexFormat::kFloat32x4, 64, 25},
+ }}}};
+
+ Transform::DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/while_to_loop.cc b/src/tint/lang/wgsl/ast/transform/while_to_loop.cc
new file mode 100644
index 0000000..8f85172
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/while_to_loop.cc
@@ -0,0 +1,79 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/while_to_loop.h"
+
+#include <utility>
+
+#include "src/tint/lang/wgsl/ast/break_statement.h"
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::WhileToLoop);
+
+namespace tint::ast::transform {
+namespace {
+
+bool ShouldRun(const Program* program) {
+ for (auto* node : program->ASTNodes().Objects()) {
+ if (node->Is<WhileStatement>()) {
+ return true;
+ }
+ }
+ return false;
+}
+
+} // namespace
+
+WhileToLoop::WhileToLoop() = default;
+
+WhileToLoop::~WhileToLoop() = default;
+
+Transform::ApplyResult WhileToLoop::Apply(const Program* src, const DataMap&, DataMap&) const {
+ if (!ShouldRun(src)) {
+ return SkipTransform;
+ }
+
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+
+ ctx.ReplaceAll([&](const WhileStatement* w) -> const Statement* {
+ utils::Vector<const Statement*, 16> stmts;
+ auto* cond = w->condition;
+
+ // !condition
+ auto* not_cond = b.Not(ctx.Clone(cond));
+
+ // { break; }
+ auto* break_body = b.Block(b.Break());
+
+ // if (!condition) { break; }
+ stmts.Push(b.If(not_cond, break_body));
+
+ for (auto* stmt : w->body->statements) {
+ stmts.Push(ctx.Clone(stmt));
+ }
+
+ const BlockStatement* continuing = nullptr;
+
+ auto* body = b.Block(stmts);
+ auto* loop = b.Loop(body, continuing);
+
+ return loop;
+ });
+
+ ctx.Clone();
+ return Program(std::move(b));
+}
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/while_to_loop.h b/src/tint/lang/wgsl/ast/transform/while_to_loop.h
new file mode 100644
index 0000000..8171912
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/while_to_loop.h
@@ -0,0 +1,40 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_WHILE_TO_LOOP_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_WHILE_TO_LOOP_H_
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+
+namespace tint::ast::transform {
+
+/// WhileToLoop is a Transform that converts a while statement into a loop
+/// statement. This is required by the SPIR-V writer.
+class WhileToLoop final : public utils::Castable<WhileToLoop, Transform> {
+ public:
+ /// Constructor
+ WhileToLoop();
+
+ /// Destructor
+ ~WhileToLoop() override;
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_WHILE_TO_LOOP_H_
diff --git a/src/tint/lang/wgsl/ast/transform/while_to_loop_test.cc b/src/tint/lang/wgsl/ast/transform/while_to_loop_test.cc
new file mode 100644
index 0000000..c8b5e46
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/while_to_loop_test.cc
@@ -0,0 +1,129 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/while_to_loop.h"
+
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using WhileToLoopTest = TransformTest;
+
+TEST_F(WhileToLoopTest, ShouldRunEmptyModule) {
+ auto* src = R"()";
+
+ EXPECT_FALSE(ShouldRun<WhileToLoop>(src));
+}
+
+TEST_F(WhileToLoopTest, ShouldRunHasWhile) {
+ auto* src = R"(
+fn f() {
+ while (true) {
+ break;
+ }
+}
+)";
+
+ EXPECT_TRUE(ShouldRun<WhileToLoop>(src));
+}
+
+TEST_F(WhileToLoopTest, EmptyModule) {
+ auto* src = "";
+ auto* expect = src;
+
+ auto got = Run<WhileToLoop>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+// Test a while loop whose body contains only a break statement.
+TEST_F(WhileToLoopTest, Empty) {
+ auto* src = R"(
+fn f() {
+ while (true) {
+ break;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ loop {
+ if (!(true)) {
+ break;
+ }
+ break;
+ }
+}
+)";
+
+ auto got = Run<WhileToLoop>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+// Test a while loop with a non-empty body.
+TEST_F(WhileToLoopTest, Body) {
+ auto* src = R"(
+fn f() {
+ while (true) {
+ discard;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ loop {
+ if (!(true)) {
+ break;
+ }
+ discard;
+ }
+}
+)";
+
+ auto got = Run<WhileToLoop>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+// Test a while loop with a comparison condition and an empty body.
+TEST_F(WhileToLoopTest, BreakCondition) {
+ auto* src = R"(
+fn f() {
+ while (0 == 1) {
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ loop {
+ if (!((0 == 1))) {
+ break;
+ }
+ }
+}
+)";
+
+ auto got = Run<WhileToLoop>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/zero_init_workgroup_memory.cc b/src/tint/lang/wgsl/ast/transform/zero_init_workgroup_memory.cc
new file mode 100644
index 0000000..a6edc6d
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/zero_init_workgroup_memory.cc
@@ -0,0 +1,483 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/zero_init_workgroup_memory.h"
+
+#include <algorithm>
+#include <map>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "src/tint/builtin/builtin_value.h"
+#include "src/tint/lang/wgsl/ast/workgroup_attribute.h"
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/function.h"
+#include "src/tint/sem/variable.h"
+#include "src/tint/type/atomic.h"
+#include "src/tint/utils/map.h"
+#include "src/tint/utils/unique_vector.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::ZeroInitWorkgroupMemory);
+
+namespace tint::ast::transform {
+namespace {
+
+// Returns true if any global variable of the program is a `Var` whose
+// semantic address space is `workgroup`; otherwise returns false.
+bool ShouldRun(const Program* program) {
+ for (auto* decl : program->AST().GlobalVariables()) {
+ auto* var = decl->As<Var>();
+ if (var && program->Sem().Get(var)->AddressSpace() == builtin::AddressSpace::kWorkgroup) {
+ return true;
+ }
+ }
+ return false;
+}
+
+} // namespace
+
+using StatementList = utils::Vector<const Statement*, 8>;
+
+/// PIMPL state for the transform
+struct ZeroInitWorkgroupMemory::State {
+ /// The clone context
+ CloneContext& ctx;
+
+ /// An alias to *ctx.dst
+ ProgramBuilder& b = *ctx.dst;
+
+ /// The constant size of the workgroup. If 0, then #workgroup_size_expr should
+ /// be used instead.
+ uint32_t workgroup_size_const = 0;
+ /// The size of the workgroup as an expression generator. Use if
+ /// #workgroup_size_const is 0.
+ std::function<const ast::Expression*()> workgroup_size_expr;
+
+ /// ArrayIndex represents a function on the local invocation index, of
+ /// the form: `array_index = (local_invocation_index % modulo) / division`
+ struct ArrayIndex {
+ /// The RHS of the modulus part of the expression
+ uint32_t modulo = 1;
+ /// The RHS of the division part of the expression
+ uint32_t division = 1;
+
+ /// Equality operator
+ /// @param i the ArrayIndex to compare to this ArrayIndex
+ /// @returns true if `i` and this ArrayIndex are equal
+ bool operator==(const ArrayIndex& i) const {
+ return modulo == i.modulo && division == i.division;
+ }
+
+ /// Hash function for the ArrayIndex type
+ struct Hasher {
+ /// @param i the ArrayIndex to calculate a hash for
+ /// @returns the hash value for the ArrayIndex `i`
+ size_t operator()(const ArrayIndex& i) const {
+ return utils::Hash(i.modulo, i.division);
+ }
+ };
+ };
+
+ /// A list of unique ArrayIndex
+ using ArrayIndices = utils::UniqueVector<ArrayIndex, 4, ArrayIndex::Hasher>;
+
+ /// Expression holds information about an expression that is being built for a
+ /// statement that will zero workgroup values.
+ struct Expression {
+ /// The AST expression node
+ const ast::Expression* expr = nullptr;
+ /// The number of iterations required to zero the value
+ uint32_t num_iterations = 0;
+ /// All array indices used by this expression
+ ArrayIndices array_indices;
+
+ /// @returns true if the expr is not null (null usually indicates a failure)
+ operator bool() const { return expr != nullptr; }
+ };
+
+ /// Statement holds information about a statement that will zero workgroup
+ /// values.
+ struct Statement {
+ /// The AST statement node
+ const ast::Statement* stmt;
+ /// The number of iterations required to zero the value
+ uint32_t num_iterations;
+ /// All array indices used by this statement
+ ArrayIndices array_indices;
+ };
+
+ /// All statements that zero workgroup memory
+ std::vector<Statement> statements;
+
+ /// A map of ArrayIndex to the name reserved for the `let` declaration of that
+ /// index.
+ std::unordered_map<ArrayIndex, Symbol, ArrayIndex::Hasher> array_index_names;
+
+ /// Constructor
+ /// @param c the CloneContext used for the transform
+ explicit State(CloneContext& c) : ctx(c) {}
+
+ /// Run inserts the workgroup memory zero-initialization logic, followed by a workgroup
+ /// barrier, at the top of the given function. No-op if no zeroing statements are needed.
+ /// @param fn a compute shader entry point function
+ void Run(const Function* fn) {
+ auto& sem = ctx.src->Sem();
+
+ CalculateWorkgroupSize(GetAttribute<WorkgroupAttribute>(fn->attributes));
+
+ // Generate a list of statements to zero initialize each of the
+ // workgroup storage variables used by `fn`. This will populate #statements.
+ auto* func = sem.Get(fn);
+ for (auto* var : func->TransitivelyReferencedGlobals()) {
+ if (var->AddressSpace() == builtin::AddressSpace::kWorkgroup) {
+ auto get_expr = [&](uint32_t num_values) {
+ auto var_name = ctx.Clone(var->Declaration()->name->symbol);
+ return Expression{b.Expr(var_name), num_values, ArrayIndices{}};
+ };
+ if (!BuildZeroingStatements(var->Type()->UnwrapRef(), get_expr)) {
+ return;
+ }
+ }
+ }
+
+ if (statements.empty()) {
+ return; // No workgroup variables to initialize.
+ }
+
+ // Scan the entry point for an existing local_invocation_index builtin,
+ // either as a direct parameter or as a member of a struct parameter.
+ std::function<const ast::Expression*()> local_index;
+ for (auto* param : fn->params) {
+ if (auto* builtin_attr = GetAttribute<BuiltinAttribute>(param->attributes)) {
+ auto builtin = sem.Get(builtin_attr)->Value();
+ if (builtin == builtin::BuiltinValue::kLocalInvocationIndex) {
+ local_index = [=] { return b.Expr(ctx.Clone(param->name->symbol)); };
+ break;
+ }
+ }
+
+ if (auto* str = sem.Get(param)->Type()->As<type::Struct>()) {
+ for (auto* member : str->Members()) {
+ if (member->Attributes().builtin ==
+ builtin::BuiltinValue::kLocalInvocationIndex) {
+ local_index = [=] {
+ auto* param_expr = b.Expr(ctx.Clone(param->name->symbol));
+ auto member_name = ctx.Clone(member->Name());
+ return b.MemberAccessor(param_expr, member_name);
+ };
+ break;
+ }
+ }
+ }
+ }
+ if (!local_index) {
+ // No existing local index parameter. Append one to the entry point.
+ auto param_name = b.Symbols().New("local_invocation_index");
+ auto* local_invocation_index = b.Builtin(builtin::BuiltinValue::kLocalInvocationIndex);
+ auto* param = b.Param(param_name, b.ty.u32(), utils::Vector{local_invocation_index});
+ ctx.InsertBack(fn->params, param);
+ local_index = [=] { return b.Expr(param->name->symbol); };
+ }
+
+ // Take the zeroing statements and bin them by the number of iterations
+ // required to zero the workgroup data. We then emit these in blocks,
+ // possibly wrapped in if-statements or for-loops.
+ std::unordered_map<uint32_t, std::vector<Statement>> stmts_by_num_iterations;
+ std::vector<uint32_t> num_sorted_iterations;
+ for (auto& s : statements) {
+ auto& stmts = stmts_by_num_iterations[s.num_iterations];
+ if (stmts.empty()) {
+ num_sorted_iterations.emplace_back(s.num_iterations);
+ }
+ stmts.emplace_back(s);
+ }
+ std::sort(num_sorted_iterations.begin(), num_sorted_iterations.end());
+
+ // Loop over the statements, grouped by num_iterations, in ascending order.
+ for (auto num_iterations : num_sorted_iterations) {
+ auto& stmts = stmts_by_num_iterations[num_iterations];
+
+ // Gather all the array indices used by all the statements in the block.
+ ArrayIndices array_indices;
+ for (auto& s : stmts) {
+ for (auto& idx : s.array_indices) {
+ array_indices.Add(idx);
+ }
+ }
+
+ // Determine the block type used to emit these statements.
+
+ if (workgroup_size_const == 0 || num_iterations > workgroup_size_const) {
+ // Either the workgroup size is dynamic, or smaller than num_iterations.
+ // In either case, we need to generate a for loop to ensure we
+ // initialize all the array elements.
+ //
+ // for (var idx : u32 = local_index;
+ // idx < num_iterations;
+ // idx += workgroup_size) {
+ // ...
+ // }
+ auto idx = b.Symbols().New("idx");
+ auto* init = b.Decl(b.Var(idx, b.ty.u32(), local_index()));
+ auto* cond = b.create<BinaryExpression>(BinaryOp::kLessThan, b.Expr(idx),
+ b.Expr(u32(num_iterations)));
+ auto* cont = b.Assign(
+ idx, b.Add(idx, workgroup_size_const ? b.Expr(u32(workgroup_size_const))
+ : workgroup_size_expr()));
+
+ auto block =
+ DeclareArrayIndices(num_iterations, array_indices, [&] { return b.Expr(idx); });
+ for (auto& s : stmts) {
+ block.Push(s.stmt);
+ }
+ auto* for_loop = b.For(init, cond, cont, b.Block(block));
+ ctx.InsertFront(fn->body->statements, for_loop);
+ } else if (num_iterations < workgroup_size_const) {
+ // Workgroup size is a known constant, but is greater than
+ // num_iterations. Emit an if statement:
+ //
+ // if (local_index < num_iterations) {
+ // ...
+ // }
+ auto* cond = b.create<BinaryExpression>(BinaryOp::kLessThan, local_index(),
+ b.Expr(u32(num_iterations)));
+ auto block = DeclareArrayIndices(num_iterations, array_indices,
+ [&] { return b.Expr(local_index()); });
+ for (auto& s : stmts) {
+ block.Push(s.stmt);
+ }
+ auto* if_stmt = b.If(cond, b.Block(block));
+ ctx.InsertFront(fn->body->statements, if_stmt);
+ } else {
+ // Workgroup size exactly equals num_iterations.
+ // No need for any conditionals. Just emit a basic block:
+ //
+ // {
+ // ...
+ // }
+ auto block = DeclareArrayIndices(num_iterations, array_indices,
+ [&] { return b.Expr(local_index()); });
+ for (auto& s : stmts) {
+ block.Push(s.stmt);
+ }
+ ctx.InsertFront(fn->body->statements, b.Block(block));
+ }
+ }
+
+ // Emit a single workgroup barrier after all of the zero initialization blocks.
+ ctx.InsertFront(fn->body->statements, b.CallStmt(b.Call("workgroupBarrier")));
+ }
+
+ /// BuildZeroingExpr is a function that builds a sub-expression used to zero
+ /// workgroup values. `num_values` is the number of elements that the
+ /// expression will be used to zero. Returns the expression.
+ using BuildZeroingExpr = std::function<Expression(uint32_t num_values)>;
+
+ /// BuildZeroingStatements() generates the statements required to zero
+ /// initialize the workgroup storage expression of type `ty`, appending them to #statements.
+ /// @param ty the expression type
+ /// @param get_expr a function that builds the AST nodes for the expression.
+ /// @returns true on success, false on failure
+ [[nodiscard]] bool BuildZeroingStatements(const type::Type* ty,
+ const BuildZeroingExpr& get_expr) {
+ if (CanTriviallyZero(ty)) {
+ auto var = get_expr(1u);
+ if (!var) {
+ return false;
+ }
+ auto* zero_init = b.Call(CreateASTTypeFor(ctx, ty));
+ statements.emplace_back(
+ Statement{b.Assign(var.expr, zero_init), var.num_iterations, var.array_indices});
+ return true;
+ }
+
+ if (auto* atomic = ty->As<type::Atomic>()) {
+ auto* zero_init = b.Call(CreateASTTypeFor(ctx, atomic->Type()));
+ auto expr = get_expr(1u);
+ if (!expr) {
+ return false;
+ }
+ auto* store = b.Call("atomicStore", b.AddressOf(expr.expr), zero_init);
+ statements.emplace_back(
+ Statement{b.CallStmt(store), expr.num_iterations, expr.array_indices});
+ return true;
+ }
+
+ if (auto* str = ty->As<type::Struct>()) {
+ for (auto* member : str->Members()) {
+ auto name = ctx.Clone(member->Name());
+ auto get_member = [&](uint32_t num_values) {
+ auto s = get_expr(num_values);
+ if (!s) {
+ return Expression{}; // error
+ }
+ return Expression{b.MemberAccessor(s.expr, name), s.num_iterations,
+ s.array_indices};
+ };
+ if (!BuildZeroingStatements(member->Type(), get_member)) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ if (auto* arr = ty->As<type::Array>()) {
+ auto get_el = [&](uint32_t num_values) {
+ // num_values is the number of values to zero for the element type.
+ // The number of iterations required to zero the array and its elements is:
+ // `num_values * count`, where `count` is the array's constant element count.
+ // The index for this array is:
+ // `(idx % modulo) / division`
+ auto count = arr->ConstantCount();
+ if (!count) {
+ ctx.dst->Diagnostics().add_error(diag::System::Transform,
+ type::Array::kErrExpectedConstantCount);
+ return Expression{}; // error
+ }
+ auto modulo = num_values * count.value();
+ auto division = num_values;
+ auto a = get_expr(modulo);
+ if (!a) {
+ return Expression{}; // error
+ }
+ auto array_indices = a.array_indices;
+ array_indices.Add(ArrayIndex{modulo, division});
+ auto index = utils::GetOrCreate(array_index_names, ArrayIndex{modulo, division},
+ [&] { return b.Symbols().New("i"); });
+ return Expression{b.IndexAccessor(a.expr, index), a.num_iterations, array_indices};
+ };
+ return BuildZeroingStatements(arr->ElemType(), get_el);
+ }
+
+ TINT_UNREACHABLE(Transform, b.Diagnostics())
+ << "could not zero workgroup type: " << ty->FriendlyName();
+ return false;
+ }
+
+ /// DeclareArrayIndices returns a list of statements that contain the `let`
+ /// declarations for all of the ArrayIndices.
+ /// @param num_iterations the number of iterations for the block
+ /// @param array_indices the list of array indices to generate `let`
+ /// declarations for
+ /// @param iteration a function that returns the index of the current
+ /// iteration.
+ /// @returns the list of `let` statements that declare the array indices
+ StatementList DeclareArrayIndices(uint32_t num_iterations,
+ const ArrayIndices& array_indices,
+ const std::function<const ast::Expression*()>& iteration) {
+ StatementList stmts;
+ // Each `let` is `(iteration % modulo) / division`; the modulo and division are elided when no-ops.
+ for (auto index : array_indices) {
+ auto name = array_index_names.at(index);
+ auto* mod = (num_iterations > index.modulo)
+ ? b.create<BinaryExpression>(BinaryOp::kModulo, iteration(),
+ b.Expr(u32(index.modulo)))
+ : iteration();
+ auto* div = (index.division != 1u) ? b.Div(mod, u32(index.division)) : mod;
+ auto* decl = b.Decl(b.Let(name, b.ty.u32(), div));
+ stmts.Push(decl);
+ }
+ return stmts;
+ }
+
+ /// CalculateWorkgroupSize initializes the members #workgroup_size_const and
+ /// #workgroup_size_expr with the linear workgroup size.
+ /// @param attr the workgroup attribute applied to the entry point function
+ void CalculateWorkgroupSize(const WorkgroupAttribute* attr) {
+ // Start from a constant size of 1 with no dynamic (non-constant) expression.
+ workgroup_size_const = 1u;
+ workgroup_size_expr = nullptr;
+ for (auto* expr : attr->Values()) {
+ if (!expr) {
+ continue;
+ }
+ auto* sem = ctx.src->Sem().GetVal(expr);
+ if (auto* c = sem->ConstantValue()) {
+ workgroup_size_const *= c->ValueAs<AInt>();
+ continue;
+ }
+ // Constant value could not be found. Build expression instead.
+ workgroup_size_expr = [this, expr, size = workgroup_size_expr] {
+ auto* e = ctx.Clone(expr);
+ if (ctx.src->TypeOf(expr)->UnwrapRef()->Is<type::I32>()) {
+ e = b.Call<u32>(e);
+ }
+ return size ? b.Mul(size(), e) : e;
+ };
+ }
+ if (workgroup_size_expr) {
+ if (workgroup_size_const != 1) {
+ // Fold workgroup_size_const in to workgroup_size_expr
+ workgroup_size_expr = [this, const_size = workgroup_size_const,
+ expr_size = workgroup_size_expr] {
+ // i32 dimensions were wrapped in u32() above, so the constant is folded in as u32.
+ return b.Mul(expr_size(), u32(const_size));
+ };
+ }
+ // Indicate that workgroup_size_expr should be used instead of the
+ // constant.
+ workgroup_size_const = 0;
+ }
+ }
+
+ /// @returns true if a variable with store type `ty` can be zeroed with a single assignment
+ /// of a value constructor that takes no operands. A false result means the zeroing has to
+ /// be decomposed into several sub-initializations instead. Structures are inspected
+ /// recursively, member by member.
+ /// @param ty the type to inspect
+ bool CanTriviallyZero(const type::Type* ty) {
+ // Atomics and arrays are never zeroed with a single assignment.
+ if (ty->Is<type::Atomic>() || ty->Is<type::Array>()) {
+ return false;
+ }
+ // A structure is only trivially zeroable when every one of its members is.
+ // Any other type skips this check entirely.
+ if (auto* structure = ty->As<type::Struct>()) {
+ for (auto* field : structure->Members()) {
+ if (!CanTriviallyZero(field->Type())) {
+ return false;
+ }
+ }
+ }
+ // True for all other storable types
+ return true;
+ }
+};
+
+ZeroInitWorkgroupMemory::ZeroInitWorkgroupMemory() = default;
+
+ZeroInitWorkgroupMemory::~ZeroInitWorkgroupMemory() = default;
+
+Transform::ApplyResult ZeroInitWorkgroupMemory::Apply(const Program* src,
+ const DataMap&,
+ DataMap&) const {
+ // Skip the transform when the program has no workgroup variables.
+ if (!ShouldRun(src)) {
+ return SkipTransform;
+ }
+ ProgramBuilder builder;
+ CloneContext ctx{&builder, src, /* auto_clone_symbols */ true};
+ // Run a fresh State over every compute-stage function of the source program.
+ for (auto* func : src->AST().Functions()) {
+ if (func->PipelineStage() != PipelineStage::kCompute) {
+ continue;
+ }
+ State{ctx}.Run(func);
+ }
+ ctx.Clone();
+ return Program(std::move(builder));
+}
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/zero_init_workgroup_memory.h b/src/tint/lang/wgsl/ast/transform/zero_init_workgroup_memory.h
new file mode 100644
index 0000000..9d69b9e
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/zero_init_workgroup_memory.h
@@ -0,0 +1,44 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_ZERO_INIT_WORKGROUP_MEMORY_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_ZERO_INIT_WORKGROUP_MEMORY_H_
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+
+namespace tint::ast::transform {
+
+/// ZeroInitWorkgroupMemory is a transform that injects code at the top of entry
+/// points to zero-initialize workgroup memory used by that entry point (and all
+/// transitive functions called by that entry point)
+class ZeroInitWorkgroupMemory final : public utils::Castable<ZeroInitWorkgroupMemory, Transform> {
+ public:
+ /// Constructor
+ ZeroInitWorkgroupMemory();
+
+ /// Destructor
+ ~ZeroInitWorkgroupMemory() override;
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+
+ private:
+ struct State;
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_ZERO_INIT_WORKGROUP_MEMORY_H_
diff --git a/src/tint/lang/wgsl/ast/transform/zero_init_workgroup_memory_test.cc b/src/tint/lang/wgsl/ast/transform/zero_init_workgroup_memory_test.cc
new file mode 100644
index 0000000..54a9854
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/zero_init_workgroup_memory_test.cc
@@ -0,0 +1,1390 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/zero_init_workgroup_memory.h"
+
+#include <utility>
+
+#include "src/tint/lang/wgsl/ast/transform/test_helper.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using ZeroInitWorkgroupMemoryTest = TransformTest;
+
+TEST_F(ZeroInitWorkgroupMemoryTest, ShouldRunEmptyModule) {
+ auto* src = R"()";
+
+ EXPECT_FALSE(ShouldRun<ZeroInitWorkgroupMemory>(src));
+}
+
+TEST_F(ZeroInitWorkgroupMemoryTest, ShouldRunHasNoWorkgroupVars) {
+ auto* src = R"(
+var<private> v : i32;
+)";
+
+ EXPECT_FALSE(ShouldRun<ZeroInitWorkgroupMemory>(src));
+}
+
+TEST_F(ZeroInitWorkgroupMemoryTest, ShouldRunHasWorkgroupVars) {
+ auto* src = R"(
+var<workgroup> a : i32;
+)";
+
+ EXPECT_TRUE(ShouldRun<ZeroInitWorkgroupMemory>(src));
+}
+
+TEST_F(ZeroInitWorkgroupMemoryTest, EmptyModule) {
+ auto* src = "";
+ auto* expect = src;
+
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ZeroInitWorkgroupMemoryTest, NoWorkgroupVars) {
+ auto* src = R"(
+var<private> v : i32;
+
+fn f() {
+ v = 1;
+}
+)";
+ auto* expect = src;
+
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ZeroInitWorkgroupMemoryTest, UnreferencedWorkgroupVars) {
+ auto* src = R"(
+var<workgroup> a : i32;
+
+var<workgroup> b : i32;
+
+var<workgroup> c : i32;
+
+fn unreferenced() {
+ b = c;
+}
+
+@compute @workgroup_size(1)
+fn f() {
+}
+)";
+ auto* expect = src;
+
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ZeroInitWorkgroupMemoryTest, UnreferencedWorkgroupVars_OutOfOrder) {
+ auto* src = R"(
+@compute @workgroup_size(1)
+fn f() {
+}
+
+fn unreferenced() {
+ b = c;
+}
+
+var<workgroup> a : i32;
+
+var<workgroup> b : i32;
+
+var<workgroup> c : i32;
+)";
+ auto* expect = src;
+
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ZeroInitWorkgroupMemoryTest, SingleWorkgroupVar_ExistingLocalIndex) {
+ auto* src = R"(
+var<workgroup> v : i32;
+
+@compute @workgroup_size(1)
+fn f(@builtin(local_invocation_index) local_idx : u32) {
+ _ = v; // Initialization should be inserted above this statement
+}
+)";
+ auto* expect = R"(
+var<workgroup> v : i32;
+
+@compute @workgroup_size(1)
+fn f(@builtin(local_invocation_index) local_idx : u32) {
+ {
+ v = i32();
+ }
+ workgroupBarrier();
+ _ = v;
+}
+)";
+
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ZeroInitWorkgroupMemoryTest, SingleWorkgroupVar_ExistingLocalIndex_OutOfOrder) {
+ auto* src = R"(
+@compute @workgroup_size(1)
+fn f(@builtin(local_invocation_index) local_idx : u32) {
+ _ = v; // Initialization should be inserted above this statement
+}
+
+var<workgroup> v : i32;
+)";
+ auto* expect = R"(
+@compute @workgroup_size(1)
+fn f(@builtin(local_invocation_index) local_idx : u32) {
+ {
+ v = i32();
+ }
+ workgroupBarrier();
+ _ = v;
+}
+
+var<workgroup> v : i32;
+)";
+
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ZeroInitWorkgroupMemoryTest, SingleWorkgroupVar_ExistingLocalIndexInStruct) {
+ auto* src = R"(
+var<workgroup> v : i32;
+
+struct Params {
+ @builtin(local_invocation_index) local_idx : u32,
+};
+
+@compute @workgroup_size(1)
+fn f(params : Params) {
+ _ = v; // Initialization should be inserted above this statement
+}
+)";
+ auto* expect = R"(
+var<workgroup> v : i32;
+
+struct Params {
+ @builtin(local_invocation_index)
+ local_idx : u32,
+}
+
+@compute @workgroup_size(1)
+fn f(params : Params) {
+ {
+ v = i32();
+ }
+ workgroupBarrier();
+ _ = v;
+}
+)";
+
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ZeroInitWorkgroupMemoryTest, SingleWorkgroupVar_ExistingLocalIndexInStruct_OutOfOrder) {
+ auto* src = R"(
+@compute @workgroup_size(1)
+fn f(params : Params) {
+ _ = v; // Initialization should be inserted above this statement
+}
+
+struct Params {
+ @builtin(local_invocation_index) local_idx : u32,
+};
+
+var<workgroup> v : i32;
+)";
+ auto* expect = R"(
+@compute @workgroup_size(1)
+fn f(params : Params) {
+ {
+ v = i32();
+ }
+ workgroupBarrier();
+ _ = v;
+}
+
+struct Params {
+ @builtin(local_invocation_index)
+ local_idx : u32,
+}
+
+var<workgroup> v : i32;
+)";
+
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ZeroInitWorkgroupMemoryTest, SingleWorkgroupVar_InjectedLocalIndex) {
+ auto* src = R"(
+var<workgroup> v : i32;
+
+@compute @workgroup_size(1)
+fn f() {
+ _ = v; // Initialization should be inserted above this statement
+}
+)";
+ auto* expect = R"(
+var<workgroup> v : i32;
+
+@compute @workgroup_size(1)
+fn f(@builtin(local_invocation_index) local_invocation_index : u32) {
+ {
+ v = i32();
+ }
+ workgroupBarrier();
+ _ = v;
+}
+)";
+
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ZeroInitWorkgroupMemoryTest, SingleWorkgroupVar_InjectedLocalIndex_OutOfOrder) {
+ auto* src = R"(
+@compute @workgroup_size(1)
+fn f() {
+ _ = v; // Initialization should be inserted above this statement
+}
+
+var<workgroup> v : i32;
+)";
+ auto* expect = R"(
+@compute @workgroup_size(1)
+fn f(@builtin(local_invocation_index) local_invocation_index : u32) {
+ {
+ v = i32();
+ }
+ workgroupBarrier();
+ _ = v;
+}
+
+var<workgroup> v : i32;
+)";
+
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ZeroInitWorkgroupMemoryTest, MultipleWorkgroupVar_ExistingLocalIndex_Size1) {
+ auto* src = R"(
+struct S {
+ x : i32,
+ y : array<i32, 8>,
+};
+
+var<workgroup> a : i32;
+
+var<workgroup> b : S;
+
+var<workgroup> c : array<S, 32>;
+
+@compute @workgroup_size(1)
+fn f(@builtin(local_invocation_index) local_idx : u32) {
+ _ = a; // Initialization should be inserted above this statement
+ _ = b;
+ _ = c;
+}
+)";
+ auto* expect = R"(
+struct S {
+ x : i32,
+ y : array<i32, 8>,
+}
+
+var<workgroup> a : i32;
+
+var<workgroup> b : S;
+
+var<workgroup> c : array<S, 32>;
+
+@compute @workgroup_size(1)
+fn f(@builtin(local_invocation_index) local_idx : u32) {
+ {
+ a = i32();
+ b.x = i32();
+ }
+ for(var idx : u32 = local_idx; (idx < 8u); idx = (idx + 1u)) {
+ let i : u32 = idx;
+ b.y[i] = i32();
+ }
+ for(var idx_1 : u32 = local_idx; (idx_1 < 32u); idx_1 = (idx_1 + 1u)) {
+ let i_1 : u32 = idx_1;
+ c[i_1].x = i32();
+ }
+ for(var idx_2 : u32 = local_idx; (idx_2 < 256u); idx_2 = (idx_2 + 1u)) {
+ let i_2 : u32 = (idx_2 / 8u);
+ let i : u32 = (idx_2 % 8u);
+ c[i_2].y[i] = i32();
+ }
+ workgroupBarrier();
+ _ = a;
+ _ = b;
+ _ = c;
+}
+)";
+
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ZeroInitWorkgroupMemoryTest, MultipleWorkgroupVar_ExistingLocalIndex_Size1_OutOfOrder) {
+ auto* src = R"(
+@compute @workgroup_size(1)
+fn f(@builtin(local_invocation_index) local_idx : u32) {
+ _ = a; // Initialization should be inserted above this statement
+ _ = b;
+ _ = c;
+}
+
+var<workgroup> a : i32;
+
+var<workgroup> b : S;
+
+var<workgroup> c : array<S, 32>;
+
+struct S {
+ x : i32,
+ y : array<i32, 8>,
+};
+)";
+ auto* expect = R"(
+@compute @workgroup_size(1)
+fn f(@builtin(local_invocation_index) local_idx : u32) {
+ {
+ a = i32();
+ b.x = i32();
+ }
+ for(var idx : u32 = local_idx; (idx < 8u); idx = (idx + 1u)) {
+ let i : u32 = idx;
+ b.y[i] = i32();
+ }
+ for(var idx_1 : u32 = local_idx; (idx_1 < 32u); idx_1 = (idx_1 + 1u)) {
+ let i_1 : u32 = idx_1;
+ c[i_1].x = i32();
+ }
+ for(var idx_2 : u32 = local_idx; (idx_2 < 256u); idx_2 = (idx_2 + 1u)) {
+ let i_2 : u32 = (idx_2 / 8u);
+ let i : u32 = (idx_2 % 8u);
+ c[i_2].y[i] = i32();
+ }
+ workgroupBarrier();
+ _ = a;
+ _ = b;
+ _ = c;
+}
+
+var<workgroup> a : i32;
+
+var<workgroup> b : S;
+
+var<workgroup> c : array<S, 32>;
+
+struct S {
+ x : i32,
+ y : array<i32, 8>,
+}
+)";
+
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ZeroInitWorkgroupMemoryTest, MultipleWorkgroupVar_ExistingLocalIndex_Size_2_3) {
+ auto* src = R"(
+struct S {
+ x : i32,
+ y : array<i32, 8>,
+};
+
+var<workgroup> a : i32;
+
+var<workgroup> b : S;
+
+var<workgroup> c : array<S, 32>;
+
+@compute @workgroup_size(2, 3)
+fn f(@builtin(local_invocation_index) local_idx : u32) {
+ _ = a; // Initialization should be inserted above this statement
+ _ = b;
+ _ = c;
+}
+)";
+ auto* expect = R"(
+struct S {
+ x : i32,
+ y : array<i32, 8>,
+}
+
+var<workgroup> a : i32;
+
+var<workgroup> b : S;
+
+var<workgroup> c : array<S, 32>;
+
+@compute @workgroup_size(2, 3)
+fn f(@builtin(local_invocation_index) local_idx : u32) {
+ if ((local_idx < 1u)) {
+ a = i32();
+ b.x = i32();
+ }
+ for(var idx : u32 = local_idx; (idx < 8u); idx = (idx + 6u)) {
+ let i : u32 = idx;
+ b.y[i] = i32();
+ }
+ for(var idx_1 : u32 = local_idx; (idx_1 < 32u); idx_1 = (idx_1 + 6u)) {
+ let i_1 : u32 = idx_1;
+ c[i_1].x = i32();
+ }
+ for(var idx_2 : u32 = local_idx; (idx_2 < 256u); idx_2 = (idx_2 + 6u)) {
+ let i_2 : u32 = (idx_2 / 8u);
+ let i : u32 = (idx_2 % 8u);
+ c[i_2].y[i] = i32();
+ }
+ workgroupBarrier();
+ _ = a;
+ _ = b;
+ _ = c;
+}
+)";
+
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ZeroInitWorkgroupMemoryTest, MultipleWorkgroupVar_ExistingLocalIndex_Size_2_3_X) {
+ auto* src = R"(
+struct S {
+ x : i32,
+ y : array<i32, 8>,
+};
+
+var<workgroup> a : i32;
+
+var<workgroup> b : S;
+
+var<workgroup> c : array<S, 32>;
+
+@id(1) override X : i32;
+
+@compute @workgroup_size(2, 3, X)
+fn f(@builtin(local_invocation_index) local_idx : u32) {
+ _ = a; // Initialization should be inserted above this statement
+ _ = b;
+ _ = c;
+}
+)";
+ auto* expect =
+ R"(
+struct S {
+ x : i32,
+ y : array<i32, 8>,
+}
+
+var<workgroup> a : i32;
+
+var<workgroup> b : S;
+
+var<workgroup> c : array<S, 32>;
+
+@id(1) override X : i32;
+
+@compute @workgroup_size(2, 3, X)
+fn f(@builtin(local_invocation_index) local_idx : u32) {
+ for(var idx : u32 = local_idx; (idx < 1u); idx = (idx + (u32(X) * 6u))) {
+ a = i32();
+ b.x = i32();
+ }
+ for(var idx_1 : u32 = local_idx; (idx_1 < 8u); idx_1 = (idx_1 + (u32(X) * 6u))) {
+ let i : u32 = idx_1;
+ b.y[i] = i32();
+ }
+ for(var idx_2 : u32 = local_idx; (idx_2 < 32u); idx_2 = (idx_2 + (u32(X) * 6u))) {
+ let i_1 : u32 = idx_2;
+ c[i_1].x = i32();
+ }
+ for(var idx_3 : u32 = local_idx; (idx_3 < 256u); idx_3 = (idx_3 + (u32(X) * 6u))) {
+ let i_2 : u32 = (idx_3 / 8u);
+ let i : u32 = (idx_3 % 8u);
+ c[i_2].y[i] = i32();
+ }
+ workgroupBarrier();
+ _ = a;
+ _ = b;
+ _ = c;
+}
+)";
+
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ZeroInitWorkgroupMemoryTest, MultipleWorkgroupVar_ExistingLocalIndex_Size_5u_X_10u) {
+ auto* src = R"(
+struct S {
+ x : array<array<i32, 8>, 10>,
+ y : array<i32, 8>,
+ z : array<array<array<i32, 8>, 10>, 20>,
+};
+
+var<workgroup> a : i32;
+
+var<workgroup> b : S;
+
+var<workgroup> c : array<S, 32>;
+
+@id(1) override X : u32;
+
+@compute @workgroup_size(5u, X, 10u)
+fn f(@builtin(local_invocation_index) local_idx : u32) {
+ _ = a; // Initialization should be inserted above this statement
+ _ = b;
+ _ = c;
+}
+)";
+ auto* expect =
+ R"(
+struct S {
+ x : array<array<i32, 8>, 10>,
+ y : array<i32, 8>,
+ z : array<array<array<i32, 8>, 10>, 20>,
+}
+
+var<workgroup> a : i32;
+
+var<workgroup> b : S;
+
+var<workgroup> c : array<S, 32>;
+
+@id(1) override X : u32;
+
+@compute @workgroup_size(5u, X, 10u)
+fn f(@builtin(local_invocation_index) local_idx : u32) {
+ for(var idx : u32 = local_idx; (idx < 1u); idx = (idx + (X * 50u))) {
+ a = i32();
+ }
+ for(var idx_1 : u32 = local_idx; (idx_1 < 8u); idx_1 = (idx_1 + (X * 50u))) {
+ let i_1 : u32 = idx_1;
+ b.y[i_1] = i32();
+ }
+ for(var idx_2 : u32 = local_idx; (idx_2 < 80u); idx_2 = (idx_2 + (X * 50u))) {
+ let i : u32 = (idx_2 / 8u);
+ let i_1 : u32 = (idx_2 % 8u);
+ b.x[i][i_1] = i32();
+ }
+ for(var idx_3 : u32 = local_idx; (idx_3 < 256u); idx_3 = (idx_3 + (X * 50u))) {
+ let i_4 : u32 = (idx_3 / 8u);
+ let i_1 : u32 = (idx_3 % 8u);
+ c[i_4].y[i_1] = i32();
+ }
+ for(var idx_4 : u32 = local_idx; (idx_4 < 1600u); idx_4 = (idx_4 + (X * 50u))) {
+ let i_2 : u32 = (idx_4 / 80u);
+ let i : u32 = ((idx_4 % 80u) / 8u);
+ let i_1 : u32 = (idx_4 % 8u);
+ b.z[i_2][i][i_1] = i32();
+ }
+ for(var idx_5 : u32 = local_idx; (idx_5 < 2560u); idx_5 = (idx_5 + (X * 50u))) {
+ let i_3 : u32 = (idx_5 / 80u);
+ let i : u32 = ((idx_5 % 80u) / 8u);
+ let i_1 : u32 = (idx_5 % 8u);
+ c[i_3].x[i][i_1] = i32();
+ }
+ for(var idx_6 : u32 = local_idx; (idx_6 < 51200u); idx_6 = (idx_6 + (X * 50u))) {
+ let i_5 : u32 = (idx_6 / 1600u);
+ let i_2 : u32 = ((idx_6 % 1600u) / 80u);
+ let i : u32 = ((idx_6 % 80u) / 8u);
+ let i_1 : u32 = (idx_6 % 8u);
+ c[i_5].z[i_2][i][i_1] = i32();
+ }
+ workgroupBarrier();
+ _ = a;
+ _ = b;
+ _ = c;
+}
+)";
+
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ZeroInitWorkgroupMemoryTest, MultipleWorkgroupVar_InjectedLocalIndex) {
+ auto* src = R"(
+struct S {
+ x : i32,
+ y : array<i32, 8>,
+};
+
+var<workgroup> a : i32;
+
+var<workgroup> b : S;
+
+var<workgroup> c : array<S, 32>;
+
+@compute @workgroup_size(1)
+fn f(@builtin(local_invocation_id) local_invocation_id : vec3<u32>) {
+ _ = a; // Initialization should be inserted above this statement
+ _ = b;
+ _ = c;
+}
+)";
+ auto* expect = R"(
+struct S {
+ x : i32,
+ y : array<i32, 8>,
+}
+
+var<workgroup> a : i32;
+
+var<workgroup> b : S;
+
+var<workgroup> c : array<S, 32>;
+
+@compute @workgroup_size(1)
+fn f(@builtin(local_invocation_id) local_invocation_id : vec3<u32>, @builtin(local_invocation_index) local_invocation_index : u32) {
+ {
+ a = i32();
+ b.x = i32();
+ }
+ for(var idx : u32 = local_invocation_index; (idx < 8u); idx = (idx + 1u)) {
+ let i : u32 = idx;
+ b.y[i] = i32();
+ }
+ for(var idx_1 : u32 = local_invocation_index; (idx_1 < 32u); idx_1 = (idx_1 + 1u)) {
+ let i_1 : u32 = idx_1;
+ c[i_1].x = i32();
+ }
+ for(var idx_2 : u32 = local_invocation_index; (idx_2 < 256u); idx_2 = (idx_2 + 1u)) {
+ let i_2 : u32 = (idx_2 / 8u);
+ let i : u32 = (idx_2 % 8u);
+ c[i_2].y[i] = i32();
+ }
+ workgroupBarrier();
+ _ = a;
+ _ = b;
+ _ = c;
+}
+)";
+
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ZeroInitWorkgroupMemoryTest, MultipleWorkgroupVar_InjectedLocalIndex_OutOfOrder) {
+ auto* src = R"(
+@compute @workgroup_size(1)
+fn f(@builtin(local_invocation_id) local_invocation_id : vec3<u32>) {
+ _ = a; // Initialization should be inserted above this statement
+ _ = b;
+ _ = c;
+}
+
+var<workgroup> a : i32;
+
+var<workgroup> b : S;
+
+var<workgroup> c : array<S, 32>;
+
+struct S {
+ x : i32,
+ y : array<i32, 8>,
+};
+)";
+ auto* expect = R"(
+@compute @workgroup_size(1)
+fn f(@builtin(local_invocation_id) local_invocation_id : vec3<u32>, @builtin(local_invocation_index) local_invocation_index : u32) {
+ {
+ a = i32();
+ b.x = i32();
+ }
+ for(var idx : u32 = local_invocation_index; (idx < 8u); idx = (idx + 1u)) {
+ let i : u32 = idx;
+ b.y[i] = i32();
+ }
+ for(var idx_1 : u32 = local_invocation_index; (idx_1 < 32u); idx_1 = (idx_1 + 1u)) {
+ let i_1 : u32 = idx_1;
+ c[i_1].x = i32();
+ }
+ for(var idx_2 : u32 = local_invocation_index; (idx_2 < 256u); idx_2 = (idx_2 + 1u)) {
+ let i_2 : u32 = (idx_2 / 8u);
+ let i : u32 = (idx_2 % 8u);
+ c[i_2].y[i] = i32();
+ }
+ workgroupBarrier();
+ _ = a;
+ _ = b;
+ _ = c;
+}
+
+var<workgroup> a : i32;
+
+var<workgroup> b : S;
+
+var<workgroup> c : array<S, 32>;
+
+struct S {
+ x : i32,
+ y : array<i32, 8>,
+}
+)";
+
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ZeroInitWorkgroupMemoryTest, MultipleWorkgroupVar_MultipleEntryPoints) {
+ auto* src = R"(
+struct S {
+ x : i32,
+ y : array<i32, 8>,
+};
+
+var<workgroup> a : i32;
+
+var<workgroup> b : S;
+
+var<workgroup> c : array<S, 32>;
+
+@compute @workgroup_size(1)
+fn f1() {
+ _ = a; // Initialization should be inserted above this statement
+ _ = c;
+}
+
+@compute @workgroup_size(1, 2, 3)
+fn f2(@builtin(local_invocation_id) local_invocation_id : vec3<u32>) {
+ _ = b; // Initialization should be inserted above this statement
+}
+
+@compute @workgroup_size(4, 5, 6)
+fn f3() {
+ _ = c; // Initialization should be inserted above this statement
+ _ = a;
+}
+)";
+ auto* expect = R"(
+struct S {
+ x : i32,
+ y : array<i32, 8>,
+}
+
+var<workgroup> a : i32;
+
+var<workgroup> b : S;
+
+var<workgroup> c : array<S, 32>;
+
+@compute @workgroup_size(1)
+fn f1(@builtin(local_invocation_index) local_invocation_index : u32) {
+ {
+ a = i32();
+ }
+ for(var idx : u32 = local_invocation_index; (idx < 32u); idx = (idx + 1u)) {
+ let i : u32 = idx;
+ c[i].x = i32();
+ }
+ for(var idx_1 : u32 = local_invocation_index; (idx_1 < 256u); idx_1 = (idx_1 + 1u)) {
+ let i_1 : u32 = (idx_1 / 8u);
+ let i_2 : u32 = (idx_1 % 8u);
+ c[i_1].y[i_2] = i32();
+ }
+ workgroupBarrier();
+ _ = a;
+ _ = c;
+}
+
+@compute @workgroup_size(1, 2, 3)
+fn f2(@builtin(local_invocation_id) local_invocation_id : vec3<u32>, @builtin(local_invocation_index) local_invocation_index_1 : u32) {
+ if ((local_invocation_index_1 < 1u)) {
+ b.x = i32();
+ }
+ for(var idx_2 : u32 = local_invocation_index_1; (idx_2 < 8u); idx_2 = (idx_2 + 6u)) {
+ let i_3 : u32 = idx_2;
+ b.y[i_3] = i32();
+ }
+ workgroupBarrier();
+ _ = b;
+}
+
+@compute @workgroup_size(4, 5, 6)
+fn f3(@builtin(local_invocation_index) local_invocation_index_2 : u32) {
+ if ((local_invocation_index_2 < 1u)) {
+ a = i32();
+ }
+ if ((local_invocation_index_2 < 32u)) {
+ let i_4 : u32 = local_invocation_index_2;
+ c[i_4].x = i32();
+ }
+ for(var idx_3 : u32 = local_invocation_index_2; (idx_3 < 256u); idx_3 = (idx_3 + 120u)) {
+ let i_5 : u32 = (idx_3 / 8u);
+ let i_6 : u32 = (idx_3 % 8u);
+ c[i_5].y[i_6] = i32();
+ }
+ workgroupBarrier();
+ _ = c;
+ _ = a;
+}
+)";
+
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ZeroInitWorkgroupMemoryTest, MultipleWorkgroupVar_MultipleEntryPoints_OutOfOrder) {
+ auto* src = R"(
+@compute @workgroup_size(1)
+fn f1() {
+ _ = a; // Initialization should be inserted above this statement
+ _ = c;
+}
+
+@compute @workgroup_size(1, 2, 3)
+fn f2(@builtin(local_invocation_id) local_invocation_id : vec3<u32>) {
+ _ = b; // Initialization should be inserted above this statement
+}
+
+@compute @workgroup_size(4, 5, 6)
+fn f3() {
+ _ = c; // Initialization should be inserted above this statement
+ _ = a;
+}
+
+var<workgroup> a : i32;
+
+var<workgroup> b : S;
+
+var<workgroup> c : array<S, 32>;
+
+struct S {
+ x : i32,
+ y : array<i32, 8>,
+};
+)";
+ auto* expect = R"(
+@compute @workgroup_size(1)
+fn f1(@builtin(local_invocation_index) local_invocation_index : u32) {
+ {
+ a = i32();
+ }
+ for(var idx : u32 = local_invocation_index; (idx < 32u); idx = (idx + 1u)) {
+ let i : u32 = idx;
+ c[i].x = i32();
+ }
+ for(var idx_1 : u32 = local_invocation_index; (idx_1 < 256u); idx_1 = (idx_1 + 1u)) {
+ let i_1 : u32 = (idx_1 / 8u);
+ let i_2 : u32 = (idx_1 % 8u);
+ c[i_1].y[i_2] = i32();
+ }
+ workgroupBarrier();
+ _ = a;
+ _ = c;
+}
+
+@compute @workgroup_size(1, 2, 3)
+fn f2(@builtin(local_invocation_id) local_invocation_id : vec3<u32>, @builtin(local_invocation_index) local_invocation_index_1 : u32) {
+ if ((local_invocation_index_1 < 1u)) {
+ b.x = i32();
+ }
+ for(var idx_2 : u32 = local_invocation_index_1; (idx_2 < 8u); idx_2 = (idx_2 + 6u)) {
+ let i_3 : u32 = idx_2;
+ b.y[i_3] = i32();
+ }
+ workgroupBarrier();
+ _ = b;
+}
+
+@compute @workgroup_size(4, 5, 6)
+fn f3(@builtin(local_invocation_index) local_invocation_index_2 : u32) {
+ if ((local_invocation_index_2 < 1u)) {
+ a = i32();
+ }
+ if ((local_invocation_index_2 < 32u)) {
+ let i_4 : u32 = local_invocation_index_2;
+ c[i_4].x = i32();
+ }
+ for(var idx_3 : u32 = local_invocation_index_2; (idx_3 < 256u); idx_3 = (idx_3 + 120u)) {
+ let i_5 : u32 = (idx_3 / 8u);
+ let i_6 : u32 = (idx_3 % 8u);
+ c[i_5].y[i_6] = i32();
+ }
+ workgroupBarrier();
+ _ = c;
+ _ = a;
+}
+
+var<workgroup> a : i32;
+
+var<workgroup> b : S;
+
+var<workgroup> c : array<S, 32>;
+
+struct S {
+ x : i32,
+ y : array<i32, 8>,
+}
+)";
+
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ZeroInitWorkgroupMemoryTest, TransitiveUsage) {
+ auto* src = R"(
+var<workgroup> v : i32;
+
+fn use_v() {
+ _ = v;
+}
+
+fn call_use_v() {
+ use_v();
+}
+
+@compute @workgroup_size(1)
+fn f(@builtin(local_invocation_index) local_idx : u32) {
+ call_use_v(); // Initialization should be inserted above this statement
+}
+)";
+ auto* expect = R"(
+var<workgroup> v : i32;
+
+fn use_v() {
+ _ = v;
+}
+
+fn call_use_v() {
+ use_v();
+}
+
+@compute @workgroup_size(1)
+fn f(@builtin(local_invocation_index) local_idx : u32) {
+ {
+ v = i32();
+ }
+ workgroupBarrier();
+ call_use_v();
+}
+)";
+
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ZeroInitWorkgroupMemoryTest, TransitiveUsage_OutOfOrder) {
+ auto* src = R"(
+@compute @workgroup_size(1)
+fn f(@builtin(local_invocation_index) local_idx : u32) {
+ call_use_v(); // Initialization should be inserted above this statement
+}
+
+fn call_use_v() {
+ use_v();
+}
+
+fn use_v() {
+ _ = v;
+}
+
+var<workgroup> v : i32;
+)";
+ auto* expect = R"(
+@compute @workgroup_size(1)
+fn f(@builtin(local_invocation_index) local_idx : u32) {
+ {
+ v = i32();
+ }
+ workgroupBarrier();
+ call_use_v();
+}
+
+fn call_use_v() {
+ use_v();
+}
+
+fn use_v() {
+ _ = v;
+}
+
+var<workgroup> v : i32;
+)";
+
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ZeroInitWorkgroupMemoryTest, WorkgroupAtomics) {
+ auto* src = R"(
+var<workgroup> i : atomic<i32>;
+var<workgroup> u : atomic<u32>;
+
+@compute @workgroup_size(1)
+fn f() {
+ atomicLoad(&(i)); // Initialization should be inserted above this statement
+ atomicLoad(&(u));
+}
+)";
+ auto* expect = R"(
+var<workgroup> i : atomic<i32>;
+
+var<workgroup> u : atomic<u32>;
+
+@compute @workgroup_size(1)
+fn f(@builtin(local_invocation_index) local_invocation_index : u32) {
+ {
+ atomicStore(&(i), i32());
+ atomicStore(&(u), u32());
+ }
+ workgroupBarrier();
+ atomicLoad(&(i));
+ atomicLoad(&(u));
+}
+)";
+
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ZeroInitWorkgroupMemoryTest, WorkgroupAtomics_OutOfOrder) {
+ auto* src = R"(
+@compute @workgroup_size(1)
+fn f() {
+ atomicLoad(&(i)); // Initialization should be inserted above this statement
+ atomicLoad(&(u));
+}
+
+var<workgroup> i : atomic<i32>;
+var<workgroup> u : atomic<u32>;
+)";
+ auto* expect = R"(
+@compute @workgroup_size(1)
+fn f(@builtin(local_invocation_index) local_invocation_index : u32) {
+ {
+ atomicStore(&(i), i32());
+ atomicStore(&(u), u32());
+ }
+ workgroupBarrier();
+ atomicLoad(&(i));
+ atomicLoad(&(u));
+}
+
+var<workgroup> i : atomic<i32>;
+
+var<workgroup> u : atomic<u32>;
+)";
+
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ZeroInitWorkgroupMemoryTest, WorkgroupStructOfAtomics) {
+ auto* src = R"(
+struct S {
+ a : i32,
+ i : atomic<i32>,
+ b : f32,
+ u : atomic<u32>,
+ c : u32,
+};
+
+var<workgroup> w : S;
+
+@compute @workgroup_size(1)
+fn f() {
+ _ = w.a; // Initialization should be inserted above this statement
+}
+)";
+ auto* expect = R"(
+struct S {
+ a : i32,
+ i : atomic<i32>,
+ b : f32,
+ u : atomic<u32>,
+ c : u32,
+}
+
+var<workgroup> w : S;
+
+@compute @workgroup_size(1)
+fn f(@builtin(local_invocation_index) local_invocation_index : u32) {
+ {
+ w.a = i32();
+ atomicStore(&(w.i), i32());
+ w.b = f32();
+ atomicStore(&(w.u), u32());
+ w.c = u32();
+ }
+ workgroupBarrier();
+ _ = w.a;
+}
+)";
+
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ZeroInitWorkgroupMemoryTest, WorkgroupStructOfAtomics_OutOfOrder) {
+ auto* src = R"(
+@compute @workgroup_size(1)
+fn f() {
+ _ = w.a; // Initialization should be inserted above this statement
+}
+
+var<workgroup> w : S;
+
+struct S {
+ a : i32,
+ i : atomic<i32>,
+ b : f32,
+ u : atomic<u32>,
+ c : u32,
+};
+)";
+ auto* expect = R"(
+@compute @workgroup_size(1)
+fn f(@builtin(local_invocation_index) local_invocation_index : u32) {
+ {
+ w.a = i32();
+ atomicStore(&(w.i), i32());
+ w.b = f32();
+ atomicStore(&(w.u), u32());
+ w.c = u32();
+ }
+ workgroupBarrier();
+ _ = w.a;
+}
+
+var<workgroup> w : S;
+
+struct S {
+ a : i32,
+ i : atomic<i32>,
+ b : f32,
+ u : atomic<u32>,
+ c : u32,
+}
+)";
+
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ZeroInitWorkgroupMemoryTest, WorkgroupArrayOfAtomics) {
+ auto* src = R"(
+var<workgroup> w : array<atomic<u32>, 4>;
+
+@compute @workgroup_size(1)
+fn f() {
+ atomicLoad(&w[0]); // Initialization should be inserted above this statement
+}
+)";
+ auto* expect = R"(
+var<workgroup> w : array<atomic<u32>, 4>;
+
+@compute @workgroup_size(1)
+fn f(@builtin(local_invocation_index) local_invocation_index : u32) {
+ for(var idx : u32 = local_invocation_index; (idx < 4u); idx = (idx + 1u)) {
+ let i : u32 = idx;
+ atomicStore(&(w[i]), u32());
+ }
+ workgroupBarrier();
+ atomicLoad(&(w[0]));
+}
+)";
+
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ZeroInitWorkgroupMemoryTest, WorkgroupArrayOfAtomics_OutOfOrder) {
+ auto* src = R"(
+@compute @workgroup_size(1)
+fn f() {
+ atomicLoad(&w[0]); // Initialization should be inserted above this statement
+}
+
+var<workgroup> w : array<atomic<u32>, 4>;
+)";
+ auto* expect = R"(
+@compute @workgroup_size(1)
+fn f(@builtin(local_invocation_index) local_invocation_index : u32) {
+ for(var idx : u32 = local_invocation_index; (idx < 4u); idx = (idx + 1u)) {
+ let i : u32 = idx;
+ atomicStore(&(w[i]), u32());
+ }
+ workgroupBarrier();
+ atomicLoad(&(w[0]));
+}
+
+var<workgroup> w : array<atomic<u32>, 4>;
+)";
+
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ZeroInitWorkgroupMemoryTest, WorkgroupArrayOfStructOfAtomics) {
+ auto* src = R"(
+struct S {
+ a : i32,
+ i : atomic<i32>,
+ b : f32,
+ u : atomic<u32>,
+ c : u32,
+};
+
+var<workgroup> w : array<S, 4>;
+
+@compute @workgroup_size(1)
+fn f() {
+ _ = w[0].a; // Initialization should be inserted above this statement
+}
+)";
+ auto* expect = R"(
+struct S {
+ a : i32,
+ i : atomic<i32>,
+ b : f32,
+ u : atomic<u32>,
+ c : u32,
+}
+
+var<workgroup> w : array<S, 4>;
+
+@compute @workgroup_size(1)
+fn f(@builtin(local_invocation_index) local_invocation_index : u32) {
+ for(var idx : u32 = local_invocation_index; (idx < 4u); idx = (idx + 1u)) {
+ let i_1 : u32 = idx;
+ w[i_1].a = i32();
+ atomicStore(&(w[i_1].i), i32());
+ w[i_1].b = f32();
+ atomicStore(&(w[i_1].u), u32());
+ w[i_1].c = u32();
+ }
+ workgroupBarrier();
+ _ = w[0].a;
+}
+)";
+
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ZeroInitWorkgroupMemoryTest, WorkgroupArrayOfStructOfAtomics_OutOfOrder) {
+ auto* src = R"(
+@compute @workgroup_size(1)
+fn f() {
+ _ = w[0].a; // Initialization should be inserted above this statement
+}
+
+var<workgroup> w : array<S, 4>;
+
+struct S {
+ a : i32,
+ i : atomic<i32>,
+ b : f32,
+ u : atomic<u32>,
+ c : u32,
+};
+)";
+ auto* expect = R"(
+@compute @workgroup_size(1)
+fn f(@builtin(local_invocation_index) local_invocation_index : u32) {
+ for(var idx : u32 = local_invocation_index; (idx < 4u); idx = (idx + 1u)) {
+ let i_1 : u32 = idx;
+ w[i_1].a = i32();
+ atomicStore(&(w[i_1].i), i32());
+ w[i_1].b = f32();
+ atomicStore(&(w[i_1].u), u32());
+ w[i_1].c = u32();
+ }
+ workgroupBarrier();
+ _ = w[0].a;
+}
+
+var<workgroup> w : array<S, 4>;
+
+struct S {
+ a : i32,
+ i : atomic<i32>,
+ b : f32,
+ u : atomic<u32>,
+ c : u32,
+}
+)";
+
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ZeroInitWorkgroupMemoryTest, ArrayWithOverrideCount) {
+ auto* src =
+ R"(override O = 123;
+alias A = array<i32, O*2>;
+
+var<workgroup> W : A;
+
+@compute @workgroup_size(1)
+fn main() {
+ let p : ptr<workgroup, A> = &W;
+ (*p)[0] = 42;
+}
+)";
+
+ auto* expect =
+ R"(error: array size is an override-expression, when expected a constant-expression.
+Was the SubstituteOverride transform run?)";
+
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/traverse_expressions.h b/src/tint/lang/wgsl/ast/traverse_expressions.h
new file mode 100644
index 0000000..0ee10bb
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/traverse_expressions.h
@@ -0,0 +1,164 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRAVERSE_EXPRESSIONS_H_
+#define SRC_TINT_LANG_WGSL_AST_TRAVERSE_EXPRESSIONS_H_
+
+#include <vector>
+
+#include "src/tint/lang/wgsl/ast/binary_expression.h"
+#include "src/tint/lang/wgsl/ast/bitcast_expression.h"
+#include "src/tint/lang/wgsl/ast/call_expression.h"
+#include "src/tint/lang/wgsl/ast/index_accessor_expression.h"
+#include "src/tint/lang/wgsl/ast/literal_expression.h"
+#include "src/tint/lang/wgsl/ast/member_accessor_expression.h"
+#include "src/tint/lang/wgsl/ast/phony_expression.h"
+#include "src/tint/lang/wgsl/ast/unary_op_expression.h"
+#include "src/tint/switch.h"
+#include "src/tint/utils/compiler_macros.h"
+#include "src/tint/utils/reverse.h"
+#include "src/tint/utils/vector.h"
+
+namespace tint::ast {
+
+/// The action for TraverseExpressions() to perform after calling the callback
+/// function on an expression.
+enum class TraverseAction {
+ /// Stop traversal immediately. TraverseExpressions() then returns true.
+ Stop,
+ /// Continue traversal into the children of this expression.
+ Descend,
+ /// Do not descend into this expression. Traversal continues with the remaining expressions.
+ Skip,
+};
+
+/// The order in which TraverseExpressions() visits the children of an expression
+enum class TraverseOrder {
+ /// Children are visited from left to right (e.g. a binary expression's lhs before its rhs)
+ LeftToRight,
+ /// Children are visited from right to left (e.g. a binary expression's rhs before its lhs)
+ RightToLeft,
+};
+
+/// TraverseExpressions performs a depth-first, pre-order (root first) traversal of the expression
+/// nodes from `root`, calling `callback` for each visited expression that is of the callback's
+/// parameter type. Other expressions are not passed to `callback`, but are still descended into.
+/// @param root the root expression node
+/// @param diags the diagnostics used to raise an ICE on an unhandled expression type
+/// @param callback the callback function. Must be of the signature:
+///        `TraverseAction(const T* expr)` or `TraverseAction(const T* expr, size_t depth)` where T
+///        is an Expression type. `depth` is 0 for `root` and one more for each level below it.
+/// @return true on success (including an early TraverseAction::Stop), false on error
+template <TraverseOrder ORDER = TraverseOrder::LeftToRight, typename CALLBACK>
+bool TraverseExpressions(const Expression* root, diag::List& diags, CALLBACK&& callback) {
+    using EXPR_TYPE = std::remove_pointer_t<utils::traits::ParameterType<CALLBACK, 0>>;
+    constexpr static bool kHasDepthArg =
+        utils::traits::SignatureOfT<CALLBACK>::parameter_count == 2;
+
+    struct Pending {
+        const Expression* expr;
+        size_t depth;
+    };
+
+    utils::Vector<Pending, 32> to_visit{{root, 0}};
+
+    auto push_single = [&](const Expression* expr, size_t depth) { to_visit.Push({expr, depth}); };
+    auto push_pair = [&](const Expression* left, const Expression* right, size_t depth) {
+        if constexpr (ORDER == TraverseOrder::LeftToRight) {
+            to_visit.Push({right, depth});
+            to_visit.Push({left, depth});  // Pushed last, so popped (visited) first
+        } else {
+            to_visit.Push({left, depth});
+            to_visit.Push({right, depth});  // Pushed last, so popped (visited) first
+        }
+    };
+    auto push_list = [&](utils::VectorRef<const Expression*> exprs, size_t depth) {
+        if constexpr (ORDER == TraverseOrder::LeftToRight) {
+            for (auto* expr : utils::Reverse(exprs)) {  // First element is pushed last
+                to_visit.Push({expr, depth});
+            }
+        } else {
+            for (auto* expr : exprs) {  // Last element is pushed last
+                to_visit.Push({expr, depth});
+            }
+        }
+    };
+
+    while (!to_visit.IsEmpty()) {
+        auto p = to_visit.Pop();
+        const Expression* expr = p.expr;
+
+        if (auto* filtered = expr->template As<EXPR_TYPE>()) {
+            TraverseAction result;
+            if constexpr (kHasDepthArg) {
+                result = callback(filtered, p.depth);
+            } else {
+                result = callback(filtered);
+            }
+
+            switch (result) {
+                case TraverseAction::Stop:
+                    return true;  // An early stop is not an error
+                case TraverseAction::Skip:
+                    continue;  // Do not push the children of `expr`
+                case TraverseAction::Descend:
+                    break;  // Continue below to push the children of `expr`
+            }
+        }
+
+        bool ok = Switch(
+            expr,
+            [&](const IndexAccessorExpression* idx) {
+                push_pair(idx->object, idx->index, p.depth + 1);
+                return true;
+            },
+            [&](const BinaryExpression* bin_op) {
+                push_pair(bin_op->lhs, bin_op->rhs, p.depth + 1);
+                return true;
+            },
+            [&](const BitcastExpression* bitcast) {
+                push_single(bitcast->expr, p.depth + 1);
+                return true;
+            },
+            [&](const CallExpression* call) {
+                push_list(call->args, p.depth + 1);
+                return true;
+            },
+            [&](const MemberAccessorExpression* member) {
+                push_single(member->object, p.depth + 1);
+                return true;
+            },
+            [&](const UnaryOpExpression* unary) {
+                push_single(unary->expr, p.depth + 1);
+                return true;
+            },
+            [&](Default) {
+                if (TINT_LIKELY((expr->IsAnyOf<LiteralExpression, IdentifierExpression,
+                                               PhonyExpression>()))) {
+                    return true;  // Leaf expression
+                }
+                TINT_ICE(AST, diags)
+                    << "unhandled expression type: " << (expr ? expr->TypeInfo().name : "<null>");
+                return false;
+            });
+        if (!ok) {
+            return false;
+        }
+    }
+    return true;
+}
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRAVERSE_EXPRESSIONS_H_
diff --git a/src/tint/lang/wgsl/ast/traverse_expressions_test.cc b/src/tint/lang/wgsl/ast/traverse_expressions_test.cc
new file mode 100644
index 0000000..1a0607b
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/traverse_expressions_test.cc
@@ -0,0 +1,248 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/traverse_expressions.h"
+#include "gmock/gmock.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+using ::testing::ElementsAre;
+
+using namespace tint::number_suffixes; // NOLINT
+
+namespace tint::ast {
+namespace {
+
+using TraverseExpressionsTest = TestHelper;
+
+TEST_F(TraverseExpressionsTest, DescendIndexAccessor) {
+ std::vector<const Expression*> e = {Expr(1_i), Expr(1_i), Expr(1_i), Expr(1_i)};
+ std::vector<const Expression*> i = {IndexAccessor(e[0], e[1]), IndexAccessor(e[2], e[3])};
+ auto* root = IndexAccessor(i[0], i[1]);
+ {
+ std::vector<const Expression*> l2r;
+ TraverseExpressions<TraverseOrder::LeftToRight>(root, Diagnostics(),
+ [&](const Expression* expr) {
+ l2r.push_back(expr);
+ return TraverseAction::Descend;
+ });
+ EXPECT_THAT(l2r, ElementsAre(root, i[0], e[0], e[1], i[1], e[2], e[3]));
+ }
+ {
+ std::vector<const Expression*> r2l;
+ TraverseExpressions<TraverseOrder::RightToLeft>(root, Diagnostics(),
+ [&](const Expression* expr) {
+ r2l.push_back(expr);
+ return TraverseAction::Descend;
+ });
+ EXPECT_THAT(r2l, ElementsAre(root, i[1], e[3], e[2], i[0], e[1], e[0]));
+ }
+}
+
+TEST_F(TraverseExpressionsTest, DescendBinaryExpression) {
+ std::vector<const Expression*> e = {Expr(1_i), Expr(1_i), Expr(1_i), Expr(1_i)};
+ std::vector<const Expression*> i = {Add(e[0], e[1]), Sub(e[2], e[3])};
+ auto* root = Mul(i[0], i[1]);
+ {
+ std::vector<const Expression*> l2r;
+ TraverseExpressions<TraverseOrder::LeftToRight>(root, Diagnostics(),
+ [&](const Expression* expr) {
+ l2r.push_back(expr);
+ return TraverseAction::Descend;
+ });
+ EXPECT_THAT(l2r, ElementsAre(root, i[0], e[0], e[1], i[1], e[2], e[3]));
+ }
+ {
+ std::vector<const Expression*> r2l;
+ TraverseExpressions<TraverseOrder::RightToLeft>(root, Diagnostics(),
+ [&](const Expression* expr) {
+ r2l.push_back(expr);
+ return TraverseAction::Descend;
+ });
+ EXPECT_THAT(r2l, ElementsAre(root, i[1], e[3], e[2], i[0], e[1], e[0]));
+ }
+}
+
+TEST_F(TraverseExpressionsTest, Depth) {
+    std::vector<const Expression*> e = {Expr(1_i), Expr(1_i), Expr(1_i), Expr(1_i)};
+    std::vector<const Expression*> i = {Add(e[0], e[1]), Sub(e[2], e[3])};
+    auto* root = Mul(i[0], i[1]);
+
+    // Expected depths, in left-to-right pre-order: root, i[0], e[0], e[1], i[1], e[2], e[3].
+    // Compared as a whole, so that a missing or extra callback invocation fails the test.
+    const std::vector<size_t> expected = {0, 1, 2, 2, 1, 2, 2};
+    std::vector<size_t> depths;
+    TraverseExpressions<TraverseOrder::LeftToRight>(  //
+        root, Diagnostics(), [&](const Expression*, size_t depth) {
+            depths.push_back(depth);
+            return TraverseAction::Descend;
+        });
+    EXPECT_EQ(depths, expected);
+}
+
+TEST_F(TraverseExpressionsTest, DescendBitcastExpression) {
+ auto* e = Expr(1_i);
+ auto* b0 = Bitcast<i32>(e);
+ auto* b1 = Bitcast<i32>(b0);
+ auto* b2 = Bitcast<i32>(b1);
+ auto* root = Bitcast<i32>(b2);
+ {
+ utils::Vector<const Expression*, 8> l2r;
+ TraverseExpressions<TraverseOrder::LeftToRight>(root, Diagnostics(),
+ [&](const Expression* expr) {
+ l2r.Push(expr);
+ return TraverseAction::Descend;
+ });
+ EXPECT_THAT(l2r, ElementsAre(root, b2, b1, b0, e));
+ }
+ {
+ utils::Vector<const Expression*, 8> r2l;
+ TraverseExpressions<TraverseOrder::RightToLeft>(root, Diagnostics(),
+ [&](const Expression* expr) {
+ r2l.Push(expr);
+ return TraverseAction::Descend;
+ });
+ EXPECT_THAT(r2l, ElementsAre(root, b2, b1, b0, e));
+ }
+}
+
+TEST_F(TraverseExpressionsTest, DescendCallExpression) {
+ utils::Vector e{Expr(1_i), Expr(1_i), Expr(1_i), Expr(1_i)};
+ utils::Vector c{Call("a", e[0], e[1]), Call("b", e[2], e[3])};
+ auto* root = Call("c", c[0], c[1]);
+ {
+ utils::Vector<const Expression*, 8> l2r;
+ TraverseExpressions<TraverseOrder::LeftToRight>(root, Diagnostics(),
+ [&](const Expression* expr) {
+ l2r.Push(expr);
+ return TraverseAction::Descend;
+ });
+ EXPECT_THAT(l2r, ElementsAre(root, c[0], e[0], e[1], c[1], e[2], e[3]));
+ }
+ {
+ utils::Vector<const Expression*, 8> r2l;
+ TraverseExpressions<TraverseOrder::RightToLeft>(root, Diagnostics(),
+ [&](const Expression* expr) {
+ r2l.Push(expr);
+ return TraverseAction::Descend;
+ });
+ EXPECT_THAT(r2l, ElementsAre(root, c[1], e[3], e[2], c[0], e[1], e[0]));
+ }
+}
+
+TEST_F(TraverseExpressionsTest, DescendMemberAccessorExpression) {
+ auto* e = Expr(1_i);
+ auto* m = MemberAccessor(e, "a");
+ auto* root = MemberAccessor(m, "b");
+ {
+ std::vector<const Expression*> l2r;
+ TraverseExpressions<TraverseOrder::LeftToRight>(root, Diagnostics(),
+ [&](const Expression* expr) {
+ l2r.push_back(expr);
+ return TraverseAction::Descend;
+ });
+ EXPECT_THAT(l2r, ElementsAre(root, m, e));
+ }
+ {
+ std::vector<const Expression*> r2l;
+ TraverseExpressions<TraverseOrder::RightToLeft>(root, Diagnostics(),
+ [&](const Expression* expr) {
+ r2l.push_back(expr);
+ return TraverseAction::Descend;
+ });
+ EXPECT_THAT(r2l, ElementsAre(root, m, e));
+ }
+}
+
+TEST_F(TraverseExpressionsTest, DescendMemberIndexExpression) {
+ auto* a = Expr("a");
+ auto* b = Expr("b");
+ auto* c = IndexAccessor(a, b);
+ auto* d = Expr("d");
+ auto* e = Expr("e");
+ auto* f = IndexAccessor(d, e);
+ auto* root = IndexAccessor(c, f);
+ {
+ std::vector<const Expression*> l2r;
+ TraverseExpressions<TraverseOrder::LeftToRight>(root, Diagnostics(),
+ [&](const Expression* expr) {
+ l2r.push_back(expr);
+ return TraverseAction::Descend;
+ });
+ EXPECT_THAT(l2r, ElementsAre(root, c, a, b, f, d, e));
+ }
+ {
+ std::vector<const Expression*> r2l;
+ TraverseExpressions<TraverseOrder::RightToLeft>(root, Diagnostics(),
+ [&](const Expression* expr) {
+ r2l.push_back(expr);
+ return TraverseAction::Descend;
+ });
+ EXPECT_THAT(r2l, ElementsAre(root, f, e, d, c, b, a));
+ }
+}
+
+TEST_F(TraverseExpressionsTest, DescendUnaryExpression) {
+ auto* e = Expr(1_i);
+ auto* u0 = AddressOf(e);
+ auto* u1 = Deref(u0);
+ auto* u2 = AddressOf(u1);
+ auto* root = Deref(u2);
+ {
+ std::vector<const Expression*> l2r;
+ TraverseExpressions<TraverseOrder::LeftToRight>(root, Diagnostics(),
+ [&](const Expression* expr) {
+ l2r.push_back(expr);
+ return TraverseAction::Descend;
+ });
+ EXPECT_THAT(l2r, ElementsAre(root, u2, u1, u0, e));
+ }
+ {
+ std::vector<const Expression*> r2l;
+ TraverseExpressions<TraverseOrder::RightToLeft>(root, Diagnostics(),
+ [&](const Expression* expr) {
+ r2l.push_back(expr);
+ return TraverseAction::Descend;
+ });
+ EXPECT_THAT(r2l, ElementsAre(root, u2, u1, u0, e));
+ }
+}
+
+TEST_F(TraverseExpressionsTest, Skip) {
+ std::vector<const Expression*> e = {Expr(1_i), Expr(1_i), Expr(1_i), Expr(1_i)};
+ std::vector<const Expression*> i = {IndexAccessor(e[0], e[1]), IndexAccessor(e[2], e[3])};
+ auto* root = IndexAccessor(i[0], i[1]);
+ std::vector<const Expression*> order;
+ TraverseExpressions<TraverseOrder::LeftToRight>(
+ root, Diagnostics(), [&](const Expression* expr) {
+ order.push_back(expr);
+ return expr == i[0] ? TraverseAction::Skip : TraverseAction::Descend;
+ });
+ EXPECT_THAT(order, ElementsAre(root, i[0], i[1], e[2], e[3]));
+}
+
+TEST_F(TraverseExpressionsTest, Stop) {
+ std::vector<const Expression*> e = {Expr(1_i), Expr(1_i), Expr(1_i), Expr(1_i)};
+ std::vector<const Expression*> i = {IndexAccessor(e[0], e[1]), IndexAccessor(e[2], e[3])};
+ auto* root = IndexAccessor(i[0], i[1]);
+ std::vector<const Expression*> order;
+ TraverseExpressions<TraverseOrder::LeftToRight>(
+ root, Diagnostics(), [&](const Expression* expr) {
+ order.push_back(expr);
+ return expr == i[0] ? TraverseAction::Stop : TraverseAction::Descend;
+ });
+ EXPECT_THAT(order, ElementsAre(root, i[0]));
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/type.cc b/src/tint/lang/wgsl/ast/type.cc
new file mode 100644
index 0000000..69d28eb
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/type.cc
@@ -0,0 +1,24 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/type.h"
+#include "src/tint/lang/wgsl/ast/identifier_expression.h"
+
+namespace tint {
+
+ProgramID ProgramIDOf(ast::Type type) {
+ return ProgramIDOf(type.expr);
+}
+
+} // namespace tint
diff --git a/src/tint/lang/wgsl/ast/type.h b/src/tint/lang/wgsl/ast/type.h
new file mode 100644
index 0000000..4015081
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/type.h
@@ -0,0 +1,52 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TYPE_H_
+#define SRC_TINT_LANG_WGSL_AST_TYPE_H_
+
+#include "src/tint/program_id.h"
+
+// Forward declarations
+namespace tint::ast {
+class IdentifierExpression;
+} // namespace tint::ast
+
+namespace tint::ast {
+
+/// Type is a thin wrapper around an IdentifierExpression, to help statically disambiguate known
+/// type expressions from other expressions.
+struct Type {
+ /// The type expression. Defaults to nullptr.
+ const IdentifierExpression* expr = nullptr;
+
+ /// Indirection operator for accessing the type's expression
+ /// @return #expr, which may be nullptr
+ const IdentifierExpression* operator->() const { return expr; }
+
+ /// Implicit conversion operator to the type's expression
+ /// @return #expr, which may be nullptr
+ operator const IdentifierExpression*() const { return expr; }
+};
+
+} // namespace tint::ast
+
+namespace tint {
+
+/// @param type an AST type
+/// @returns the ProgramID of the given AST type.
+ProgramID ProgramIDOf(ast::Type type);
+
+} // namespace tint
+
+#endif // SRC_TINT_LANG_WGSL_AST_TYPE_H_
diff --git a/src/tint/lang/wgsl/ast/type_decl.cc b/src/tint/lang/wgsl/ast/type_decl.cc
new file mode 100644
index 0000000..5bd9d26
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/type_decl.cc
@@ -0,0 +1,33 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/type_decl.h"
+
+#include "src/tint/lang/wgsl/ast/templated_identifier.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::TypeDecl);
+
+namespace tint::ast {
+
+TypeDecl::TypeDecl(ProgramID pid, NodeID nid, const Source& src, const Identifier* n)
+ : Base(pid, nid, src), name(n) {
+ TINT_ASSERT(AST, name);  // a type declaration must be named
+ if (name) {  // guards the dereference below should execution continue past the assert
+ TINT_ASSERT(AST, !name->Is<TemplatedIdentifier>());  // the name must not be templated
+ }
+}
+
+TypeDecl::~TypeDecl() = default;
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/type_decl.h b/src/tint/lang/wgsl/ast/type_decl.h
new file mode 100644
index 0000000..1414370
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/type_decl.h
@@ -0,0 +1,46 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TYPE_DECL_H_
+#define SRC_TINT_LANG_WGSL_AST_TYPE_DECL_H_
+
+#include "src/tint/lang/wgsl/ast/node.h"
+
+// Forward declarations
+namespace tint::ast {
+class Identifier;
+} // namespace tint::ast
+
+namespace tint::ast {
+
+/// The base class for type declarations.
+class TypeDecl : public utils::Castable<TypeDecl, Node> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ /// @param name the name of the type; must be non-null and not a TemplatedIdentifier
+ TypeDecl(ProgramID pid, NodeID nid, const Source& src, const Identifier* name);
+
+ /// Destructor
+ ~TypeDecl() override;
+
+ /// The name of the type declaration
+ const Identifier* const name;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_TYPE_DECL_H_
diff --git a/src/tint/lang/wgsl/ast/unary_op.cc b/src/tint/lang/wgsl/ast/unary_op.cc
new file mode 100644
index 0000000..288e19c
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/unary_op.cc
@@ -0,0 +1,45 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/unary_op.h"
+
+namespace tint::ast {
+
+utils::StringStream& operator<<(utils::StringStream& out, UnaryOp mod) {
+ switch (mod) {  // no default case: a value outside the enumerators writes nothing
+ case UnaryOp::kAddressOf: {
+ out << "address-of";
+ break;
+ }
+ case UnaryOp::kComplement: {
+ out << "complement";
+ break;
+ }
+ case UnaryOp::kIndirection: {
+ out << "indirection";
+ break;
+ }
+ case UnaryOp::kNegation: {
+ out << "negation";
+ break;
+ }
+ case UnaryOp::kNot: {
+ out << "not";
+ break;
+ }
+ }
+ return out;  // returned so that calls can be chained
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/unary_op.h b/src/tint/lang/wgsl/ast/unary_op.h
new file mode 100644
index 0000000..1e817fc
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/unary_op.h
@@ -0,0 +1,38 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_UNARY_OP_H_
+#define SRC_TINT_LANG_WGSL_AST_UNARY_OP_H_
+
+#include "src/tint/utils/string_stream.h"
+
+namespace tint::ast {
+
+/// The operator applied by a unary op expression
+enum class UnaryOp {
+ kAddressOf, // &EXPR
+ kComplement, // ~EXPR
+ kIndirection, // *EXPR
+ kNegation, // -EXPR
+ kNot, // !EXPR
+};
+
+/// @param out the stream to write to
+/// @param mod the UnaryOp whose descriptive name is written, e.g. "address-of" or "not"
+/// @return the stream so calls can be chained
+utils::StringStream& operator<<(utils::StringStream& out, UnaryOp mod);
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_UNARY_OP_H_
diff --git a/src/tint/lang/wgsl/ast/unary_op_expression.cc b/src/tint/lang/wgsl/ast/unary_op_expression.cc
new file mode 100644
index 0000000..2f795c4
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/unary_op_expression.cc
@@ -0,0 +1,42 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/unary_op_expression.h"
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::UnaryOpExpression);
+
+namespace tint::ast {
+
+UnaryOpExpression::UnaryOpExpression(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ UnaryOp o,
+ const Expression* e)
+ : Base(pid, nid, src), op(o), expr(e) {
+ TINT_ASSERT(AST, expr);  // the operand is required
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, expr, program_id);  // operand from this program
+}
+
+UnaryOpExpression::~UnaryOpExpression() = default;
+
+const UnaryOpExpression* UnaryOpExpression::Clone(CloneContext* ctx) const {
+ // Clone arguments as separate statements (not inline in create()) for deterministic ordering
+ auto src = ctx->Clone(source);
+ auto* e = ctx->Clone(expr);
+ return ctx->dst->create<UnaryOpExpression>(src, op, e);  // op is passed through unchanged
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/unary_op_expression.h b/src/tint/lang/wgsl/ast/unary_op_expression.h
new file mode 100644
index 0000000..22d70f5
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/unary_op_expression.h
@@ -0,0 +1,56 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_UNARY_OP_EXPRESSION_H_
+#define SRC_TINT_LANG_WGSL_AST_UNARY_OP_EXPRESSION_H_
+
+#include "src/tint/lang/wgsl/ast/expression.h"
+#include "src/tint/lang/wgsl/ast/unary_op.h"
+
+namespace tint::ast {
+
+/// A unary op expression: a UnaryOp applied to a single operand expression
+class UnaryOpExpression final : public utils::Castable<UnaryOpExpression, Expression> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param source the unary op expression source
+ /// @param op the unary operator
+ /// @param expr the operand expression; must be non-null
+ UnaryOpExpression(ProgramID pid,
+ NodeID nid,
+ const Source& source,
+ UnaryOp op,
+ const Expression* expr);
+
+ /// Destructor
+ ~UnaryOpExpression() override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const UnaryOpExpression* Clone(CloneContext* ctx) const override;
+
+ /// The unary operator
+ const UnaryOp op;
+
+ /// The operand expression
+ const Expression* const expr;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_UNARY_OP_EXPRESSION_H_
diff --git a/src/tint/lang/wgsl/ast/unary_op_expression_test.cc b/src/tint/lang/wgsl/ast/unary_op_expression_test.cc
new file mode 100644
index 0000000..82b9764
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/unary_op_expression_test.cc
@@ -0,0 +1,67 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/unary_op_expression.h"
+
+#include "gtest/gtest-spi.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+namespace tint::ast {
+namespace {
+
+using UnaryOpExpressionTest = TestHelper;
+
+TEST_F(UnaryOpExpressionTest, Creation) {  // op and operand are stored as passed
+ auto* ident = Expr("ident");
+
+ auto* u = create<UnaryOpExpression>(UnaryOp::kNot, ident);
+ EXPECT_EQ(u->op, UnaryOp::kNot);
+ EXPECT_EQ(u->expr, ident);
+}
+
+TEST_F(UnaryOpExpressionTest, Creation_WithSource) {  // the given Source is kept on the node
+ auto* ident = Expr("ident");
+ auto* u = create<UnaryOpExpression>(Source{Source::Location{20, 2}}, UnaryOp::kNot, ident);
+ auto src = u->source;
+ EXPECT_EQ(src.range.begin.line, 20u);
+ EXPECT_EQ(src.range.begin.column, 2u);
+}
+
+TEST_F(UnaryOpExpressionTest, IsUnaryOp) {  // Is<>() reports the node's own type
+ auto* ident = Expr("ident");
+ auto* u = create<UnaryOpExpression>(UnaryOp::kNot, ident);
+ EXPECT_TRUE(u->Is<UnaryOpExpression>());
+}
+
+TEST_F(UnaryOpExpressionTest, Assert_Null_Expression) {  // null operand must fail
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b;
+ b.create<UnaryOpExpression>(UnaryOp::kNot, nullptr);
+ },
+ "internal compiler error");
+}
+
+TEST_F(UnaryOpExpressionTest, Assert_DifferentProgramID_Expression) {  // foreign operand must fail
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.create<UnaryOpExpression>(UnaryOp::kNot, b2.Expr(true));
+ },
+ "internal compiler error");
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/var.cc b/src/tint/lang/wgsl/ast/var.cc
new file mode 100644
index 0000000..6f3faca
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/var.cc
@@ -0,0 +1,53 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/var.h"
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::Var);
+
+namespace tint::ast {
+
+Var::Var(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const Identifier* n,
+ Type ty,
+ const Expression* address_space,  // not asserted non-null here
+ const Expression* access,  // not asserted non-null here
+ const Expression* init,
+ utils::VectorRef<const Attribute*> attrs)
+ : Base(pid, nid, src, n, ty, init, std::move(attrs)),  // name, type, init, attrs go to Variable
+ declared_address_space(address_space),
+ declared_access(access) {}
+
+Var::~Var() = default;
+
+const char* Var::Kind() const {
+ return "var";  // kind name for this declaration, usable in diagnostics
+}
+
+const Var* Var::Clone(CloneContext* ctx) const {
+ auto src = ctx->Clone(source);  // each child is cloned in its own statement, before create()
+ auto* n = ctx->Clone(name);
+ auto ty = ctx->Clone(type);
+ auto* address_space = ctx->Clone(declared_address_space);
+ auto* access = ctx->Clone(declared_access);
+ auto* init = ctx->Clone(initializer);
+ auto attrs = ctx->Clone(attributes);
+ return ctx->dst->create<Var>(src, n, ty, address_space, access, init, std::move(attrs));
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/var.h b/src/tint/lang/wgsl/ast/var.h
new file mode 100644
index 0000000..d7a6770
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/var.h
@@ -0,0 +1,88 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_VAR_H_
+#define SRC_TINT_LANG_WGSL_AST_VAR_H_
+
+#include <utility>
+#include <vector>
+
+#include "src/tint/lang/wgsl/ast/variable.h"
+
+namespace tint::ast {
+
+/// A "var" declaration is a name for typed storage.
+///
+/// Examples:
+///
+/// ```
+/// // Declared outside a function, i.e. at module scope, requires
+/// // an address space.
+/// var<workgroup> width : i32; // no initializer
+/// var<private> height : i32 = 3; // with initializer
+///
+/// // A variable declared inside a function doesn't take an address space,
+/// // and maps to SPIR-V Function storage.
+/// var computed_depth : i32;
+/// var area : i32 = compute_area(width, height);
+/// ```
+///
+/// @see https://www.w3.org/TR/WGSL/#var-decls
+class Var final : public utils::Castable<Var, Variable> {
+ public:
+ /// Create a 'var' variable
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param source the variable source
+ /// @param name the variable name
+ /// @param type the declared variable type
+ /// @param declared_address_space the declared address space expression
+ /// @param declared_access the declared access control expression
+ /// @param initializer the initializer expression
+ /// @param attributes the variable attributes
+ Var(ProgramID pid,
+ NodeID nid,
+ const Source& source,
+ const Identifier* name,
+ Type type,
+ const Expression* declared_address_space,
+ const Expression* declared_access,
+ const Expression* initializer,
+ utils::VectorRef<const Attribute*> attributes);
+
+ /// Destructor
+ ~Var() override;
+
+ /// @returns "var"
+ const char* Kind() const override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const Var* Clone(CloneContext* ctx) const override;
+
+ /// The declared address space expression, as passed to the constructor
+ const Expression* const declared_address_space = nullptr;
+
+ /// The declared access control expression, as passed to the constructor
+ const Expression* const declared_access = nullptr;
+};
+
+/// A list of pointers to `var` declarations
+using VarList = std::vector<const Var*>;
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_VAR_H_
diff --git a/src/tint/lang/wgsl/ast/variable.cc b/src/tint/lang/wgsl/ast/variable.cc
new file mode 100644
index 0000000..d0eb80d
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/variable.cc
@@ -0,0 +1,41 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/variable.h"
+#include "src/tint/lang/wgsl/ast/binding_attribute.h"
+#include "src/tint/lang/wgsl/ast/group_attribute.h"
+#include "src/tint/lang/wgsl/ast/templated_identifier.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::Variable);
+
+namespace tint::ast {
+
+Variable::Variable(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const Identifier* n,
+ Type ty,
+ const Expression* init,
+ utils::VectorRef<const Attribute*> attrs)
+ : Base(pid, nid, src), name(n), type(ty), initializer(init), attributes(std::move(attrs)) {
+ TINT_ASSERT(AST, name);  // every variable must be named
+ if (name) {  // guards the dereference below should execution continue past the assert
+ TINT_ASSERT(AST, !name->Is<TemplatedIdentifier>());  // the name must not be templated
+ }
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, initializer, program_id);  // init is optional
+}
+
+Variable::~Variable() = default;
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/variable.h b/src/tint/lang/wgsl/ast/variable.h
new file mode 100644
index 0000000..84393c52
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/variable.h
@@ -0,0 +1,91 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_VARIABLE_H_
+#define SRC_TINT_LANG_WGSL_AST_VARIABLE_H_
+
+#include <utility>
+#include <vector>
+
+#include "src/tint/builtin/access.h"
+#include "src/tint/builtin/address_space.h"
+#include "src/tint/lang/wgsl/ast/attribute.h"
+#include "src/tint/lang/wgsl/ast/binding_attribute.h"
+#include "src/tint/lang/wgsl/ast/expression.h"
+#include "src/tint/lang/wgsl/ast/group_attribute.h"
+#include "src/tint/lang/wgsl/ast/type.h"
+
+// Forward declarations
+namespace tint::ast {
+class Identifier;
+class LocationAttribute;
+} // namespace tint::ast
+
+namespace tint::ast {
+
+/// Variable is the base class for Var, Let, Const, Override and Parameter.
+///
+/// An instance of this class represents one of five constructs in WGSL: "var" declaration, "let"
+/// declaration, "override" declaration, "const" declaration, or formal parameter to a function.
+///
+/// @see https://www.w3.org/TR/WGSL/#value-decls
+class Variable : public utils::Castable<Variable, Node> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the variable source
+ /// @param name the variable name; must be non-null and not a TemplatedIdentifier
+ /// @param type the declared variable type, null if the type is to be inferred
+ /// @param initializer the initializer expression, or nullptr if there is none
+ /// @param attributes the variable attributes
+ Variable(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const Identifier* name,
+ Type type,
+ const Expression* initializer,
+ utils::VectorRef<const Attribute*> attributes);
+
+ /// Destructor
+ ~Variable() override;
+
+ /// @returns true if the variable has both group and binding attributes
+ bool HasBindingPoint() const {
+ return HasAttribute<BindingAttribute>(attributes) &&
+ HasAttribute<GroupAttribute>(attributes);
+ }
+
+ /// @returns the kind of the variable, which can be used in diagnostics
+ /// e.g. "var", "let", "const", etc
+ virtual const char* Kind() const = 0;
+
+ /// The variable name
+ const Identifier* const name;
+
+ /// The declared variable type. This is null if the type is inferred, e.g.:
+ /// let f = 1.0;
+ /// var i = 1;
+ const Type type;
+
+ /// The initializer expression or nullptr if none set
+ const Expression* const initializer;
+
+ /// The attributes attached to this variable
+ const utils::Vector<const Attribute*, 2> attributes;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_VARIABLE_H_
diff --git a/src/tint/lang/wgsl/ast/variable_decl_statement.cc b/src/tint/lang/wgsl/ast/variable_decl_statement.cc
new file mode 100644
index 0000000..d83b87b
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/variable_decl_statement.cc
@@ -0,0 +1,41 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/variable_decl_statement.h"
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::VariableDeclStatement);
+
+namespace tint::ast {
+
+VariableDeclStatement::VariableDeclStatement(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const Variable* var)
+ : Base(pid, nid, src), variable(var) {
+ TINT_ASSERT(AST, variable);  // the declared variable is required
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, variable, program_id);  // variable from this program
+}
+
+VariableDeclStatement::~VariableDeclStatement() = default;
+
+const VariableDeclStatement* VariableDeclStatement::Clone(CloneContext* ctx) const {
+ // Clone arguments as separate statements (not inline in create()) for deterministic ordering
+ auto src = ctx->Clone(source);
+ auto* var = ctx->Clone(variable);
+ return ctx->dst->create<VariableDeclStatement>(src, var);
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/variable_decl_statement.h b/src/tint/lang/wgsl/ast/variable_decl_statement.h
new file mode 100644
index 0000000..e9b796d
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/variable_decl_statement.h
@@ -0,0 +1,51 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_VARIABLE_DECL_STATEMENT_H_
+#define SRC_TINT_LANG_WGSL_AST_VARIABLE_DECL_STATEMENT_H_
+
+#include "src/tint/lang/wgsl/ast/statement.h"
+#include "src/tint/lang/wgsl/ast/variable.h"
+
+namespace tint::ast {
+
+/// A variable declaration statement: a Statement node that holds a single Variable
+class VariableDeclStatement final : public utils::Castable<VariableDeclStatement, Statement> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param source the variable statement source
+ /// @param variable the variable being declared; must be non-null
+ VariableDeclStatement(ProgramID pid,
+ NodeID nid,
+ const Source& source,
+ const Variable* variable);
+
+ /// Destructor
+ ~VariableDeclStatement() override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const VariableDeclStatement* Clone(CloneContext* ctx) const override;
+
+ /// The variable declared by this statement
+ const Variable* const variable;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_VARIABLE_DECL_STATEMENT_H_
diff --git a/src/tint/lang/wgsl/ast/variable_decl_statement_test.cc b/src/tint/lang/wgsl/ast/variable_decl_statement_test.cc
new file mode 100644
index 0000000..df8a033
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/variable_decl_statement_test.cc
@@ -0,0 +1,68 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/variable_decl_statement.h"
+
+#include "gtest/gtest-spi.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+namespace tint::ast {
+namespace {
+
+using VariableDeclStatementTest = TestHelper;
+
+TEST_F(VariableDeclStatementTest, Creation) {  // the variable pointer is stored as passed
+ auto* var = Var("a", ty.f32());
+
+ auto* stmt = create<VariableDeclStatement>(var);
+ EXPECT_EQ(stmt->variable, var);
+}
+
+TEST_F(VariableDeclStatementTest, Creation_WithSource) {  // the given Source is kept on the node
+ auto* var = Var("a", ty.f32());
+
+ auto* stmt = create<VariableDeclStatement>(Source{Source::Location{20, 2}}, var);
+ auto src = stmt->source;
+ EXPECT_EQ(src.range.begin.line, 20u);
+ EXPECT_EQ(src.range.begin.column, 2u);
+}
+
+TEST_F(VariableDeclStatementTest, IsVariableDecl) {  // Is<>() reports the node's own type
+ auto* var = Var("a", ty.f32());
+
+ auto* stmt = create<VariableDeclStatement>(var);
+ EXPECT_TRUE(stmt->Is<VariableDeclStatement>());
+}
+
+TEST_F(VariableDeclStatementTest, Assert_Null_Variable) {  // null variable must fail
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b;
+ b.create<VariableDeclStatement>(nullptr);
+ },
+ "internal compiler error");
+}
+
+TEST_F(VariableDeclStatementTest, Assert_DifferentProgramID_Variable) {  // foreign variable fails
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.create<VariableDeclStatement>(b2.Var("a", b2.ty.f32()));
+ },
+ "internal compiler error");
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/variable_test.cc b/src/tint/lang/wgsl/ast/variable_test.cc
new file mode 100644
index 0000000..08ca16c
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/variable_test.cc
@@ -0,0 +1,132 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gtest/gtest-spi.h"
+
+#include "src/tint/builtin/builtin_value.h"
+#include "src/tint/lang/wgsl/ast/id_attribute.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+using namespace tint::number_suffixes; // NOLINT
+
+namespace tint::ast {
+namespace {
+
+using VariableTest = TestHelper;
+
+TEST_F(VariableTest, Creation) {  // no Source given: all source positions are expected to be 0
+ auto* v = Var("my_var", ty.i32(), builtin::AddressSpace::kFunction);
+
+ CheckIdentifier(v->name, "my_var");
+ CheckIdentifier(v->declared_address_space, "function");
+ EXPECT_EQ(v->declared_access, nullptr);
+ CheckIdentifier(v->type, "i32");
+ EXPECT_EQ(v->source.range.begin.line, 0u);
+ EXPECT_EQ(v->source.range.begin.column, 0u);
+ EXPECT_EQ(v->source.range.end.line, 0u);
+ EXPECT_EQ(v->source.range.end.column, 0u);
+}
+
+TEST_F(VariableTest, CreationWithSource) {  // the given Source range is kept on the node
+ auto* v = Var(Source{Source::Range{Source::Location{27, 4}, Source::Location{27, 5}}}, "i",
+ ty.f32(), builtin::AddressSpace::kPrivate, utils::Empty);
+
+ CheckIdentifier(v->name, "i");
+ CheckIdentifier(v->declared_address_space, "private");
+ CheckIdentifier(v->type, "f32");
+ EXPECT_EQ(v->source.range.begin.line, 27u);
+ EXPECT_EQ(v->source.range.begin.column, 4u);
+ EXPECT_EQ(v->source.range.end.line, 27u);
+ EXPECT_EQ(v->source.range.end.column, 5u);
+}
+
+TEST_F(VariableTest, CreationEmpty) {  // NOTE: despite the name, name/type/address space are set
+ auto* v = Var(Source{Source::Range{Source::Location{27, 4}, Source::Location{27, 7}}}, "a_var",
+ ty.i32(), builtin::AddressSpace::kWorkgroup, utils::Empty);
+
+ CheckIdentifier(v->name, "a_var");
+ CheckIdentifier(v->declared_address_space, "workgroup");
+ CheckIdentifier(v->type, "i32");
+ EXPECT_EQ(v->source.range.begin.line, 27u);
+ EXPECT_EQ(v->source.range.begin.column, 4u);
+ EXPECT_EQ(v->source.range.end.line, 27u);
+ EXPECT_EQ(v->source.range.end.column, 7u);
+}
+
+TEST_F(VariableTest, Assert_Null_Name) {  // a null name must fail
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b;
+ b.Var(static_cast<Identifier*>(nullptr), b.ty.i32());
+ },
+ "internal compiler error");
+}
+
+TEST_F(VariableTest, Assert_DifferentProgramID_Symbol) {  // a symbol from another builder fails
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.Var(b2.Sym("x"), b1.ty.f32());
+ },
+ "internal compiler error");
+}
+
+TEST_F(VariableTest, Assert_DifferentProgramID_Initializer) {  // a foreign initializer fails
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.Var("x", b1.ty.f32(), b2.Expr(1.2_f));
+ },
+ "internal compiler error");
+}
+
+TEST_F(VariableTest, WithAttributes) {  // attributes passed to Var() are found on the node
+ auto* var = Var("my_var", ty.i32(), builtin::AddressSpace::kFunction, Location(1_u),
+ Builtin(builtin::BuiltinValue::kPosition), Id(1200_u));
+
+ auto& attributes = var->attributes;
+ EXPECT_TRUE(ast::HasAttribute<LocationAttribute>(attributes));
+ EXPECT_TRUE(ast::HasAttribute<BuiltinAttribute>(attributes));
+ EXPECT_TRUE(ast::HasAttribute<IdAttribute>(attributes));
+
+ auto* location = GetAttribute<LocationAttribute>(attributes);
+ ASSERT_NE(nullptr, location);
+ ASSERT_NE(nullptr, location->expr);
+ EXPECT_TRUE(location->expr->Is<IntLiteralExpression>());
+}
+
+TEST_F(VariableTest, HasBindingPoint_BothProvided) {  // binding + group => true
+ auto* var = Var("my_var", ty.i32(), builtin::AddressSpace::kFunction, Binding(2_a), Group(1_a));
+ EXPECT_TRUE(var->HasBindingPoint());
+}
+
+TEST_F(VariableTest, HasBindingPoint_NeitherProvided) {  // no attributes => false
+ auto* var = Var("my_var", ty.i32(), builtin::AddressSpace::kFunction, utils::Empty);
+ EXPECT_FALSE(var->HasBindingPoint());
+}
+
+TEST_F(VariableTest, HasBindingPoint_MissingGroupAttribute) {  // binding only => false
+ auto* var = Var("my_var", ty.i32(), builtin::AddressSpace::kFunction, Binding(2_a));
+ EXPECT_FALSE(var->HasBindingPoint());
+}
+
+TEST_F(VariableTest, HasBindingPoint_MissingBindingAttribute) {  // group only => false
+ auto* var = Var("my_var", ty.i32(), builtin::AddressSpace::kFunction, Group(1_a));
+ EXPECT_FALSE(var->HasBindingPoint());
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/while_statement.cc b/src/tint/lang/wgsl/ast/while_statement.cc
new file mode 100644
index 0000000..dec91ac
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/while_statement.cc
@@ -0,0 +1,55 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/while_statement.h"
+
+#include <utility>
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::WhileStatement);
+
+namespace tint::ast {
+
+WhileStatement::WhileStatement(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const Expression* cond,
+ const BlockStatement* b,
+ utils::VectorRef<const ast::Attribute*> attrs)
+ : Base(pid, nid, src), condition(cond), body(b), attributes(std::move(attrs)) {
+ TINT_ASSERT(AST, cond);  // A while statement requires a condition.
+ TINT_ASSERT(AST, body);  // `body` is the member initialized from `b`; it is required too.
+ // Children must belong to the same program as this node.
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, condition, program_id);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, body, program_id);
+ for (auto* attr : attributes) {
+ TINT_ASSERT(AST, attr);  // Null attributes are not permitted.
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, attr, program_id);
+ }
+}
+
+WhileStatement::~WhileStatement() = default;
+
+const WhileStatement* WhileStatement::Clone(CloneContext* ctx) const {
+ // Clone children before the create() call: C++ leaves argument evaluation order unspecified.
+ auto src = ctx->Clone(source);
+
+ auto* cond = ctx->Clone(condition);
+ auto* b = ctx->Clone(body);
+ auto attrs = ctx->Clone(attributes);
+ return ctx->dst->create<WhileStatement>(src, cond, b, std::move(attrs));
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/while_statement.h b/src/tint/lang/wgsl/ast/while_statement.h
new file mode 100644
index 0000000..9d824c7
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/while_statement.h
@@ -0,0 +1,62 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_WHILE_STATEMENT_H_
+#define SRC_TINT_LANG_WGSL_AST_WHILE_STATEMENT_H_
+
+#include "src/tint/lang/wgsl/ast/block_statement.h"
+
+namespace tint::ast {
+
+class Expression;
+
+/// A while loop statement
+class WhileStatement final : public utils::Castable<WhileStatement, Statement> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param source the while loop statement source
+ /// @param condition the loop condition expression (must not be null)
+ /// @param body the loop body (must not be null)
+ /// @param attributes the while statement attributes
+ WhileStatement(ProgramID pid,
+ NodeID nid,
+ const Source& source,
+ const Expression* condition,
+ const BlockStatement* body,
+ utils::VectorRef<const ast::Attribute*> attributes);
+
+ /// Destructor
+ ~WhileStatement() override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const WhileStatement* Clone(CloneContext* ctx) const override;
+
+ /// The condition expression
+ const Expression* const condition;
+
+ /// The loop body block
+ const BlockStatement* const body;
+
+ /// The attribute list
+ const utils::Vector<const Attribute*, 1> attributes;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_WHILE_STATEMENT_H_
diff --git a/src/tint/lang/wgsl/ast/while_statement_test.cc b/src/tint/lang/wgsl/ast/while_statement_test.cc
new file mode 100644
index 0000000..f22edc9
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/while_statement_test.cc
@@ -0,0 +1,96 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gmock/gmock.h"
+#include "gtest/gtest-spi.h"
+#include "src/tint/lang/wgsl/ast/binary_expression.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+using namespace tint::number_suffixes; // NOLINT
+
+namespace tint::ast {
+namespace {
+
+using WhileStatementTest = TestHelper;
+
+TEST_F(WhileStatementTest, Creation) {
+ auto* cond = create<BinaryExpression>(BinaryOp::kLessThan, Expr("i"), Expr(5_u));
+ auto* body = Block(Return());
+ auto* l = While(cond, body);
+
+ EXPECT_EQ(l->condition, cond);
+ EXPECT_EQ(l->body, body);
+}
+
+TEST_F(WhileStatementTest, Creation_WithSource) {
+ auto* cond = create<BinaryExpression>(BinaryOp::kLessThan, Expr("i"), Expr(5_u));
+ auto* body = Block(Return());
+ auto* l = While(Source{{20u, 2u}}, cond, body);
+ auto src = l->source;
+ EXPECT_EQ(src.range.begin.line, 20u);
+ EXPECT_EQ(src.range.begin.column, 2u);
+}
+
+TEST_F(WhileStatementTest, Creation_WithAttributes) {
+ auto* attr1 = DiagnosticAttribute(builtin::DiagnosticSeverity::kOff, "foo");
+ auto* attr2 = DiagnosticAttribute(builtin::DiagnosticSeverity::kOff, "bar");
+ auto* cond = create<BinaryExpression>(BinaryOp::kLessThan, Expr("i"), Expr(5_u));
+ auto* body = Block(Return());
+ auto* l = While(cond, body, utils::Vector{attr1, attr2});
+
+ EXPECT_THAT(l->attributes, testing::ElementsAre(attr1, attr2));
+}
+
+TEST_F(WhileStatementTest, Assert_Null_Cond) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b;
+ auto* body = b.Block();
+ b.While(nullptr, body);
+ },
+ "internal compiler error");
+}
+
+TEST_F(WhileStatementTest, Assert_Null_Body) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b;
+ auto* cond = b.create<BinaryExpression>(BinaryOp::kLessThan, b.Expr("i"), b.Expr(5_u));
+ b.While(cond, nullptr);
+ },
+ "internal compiler error");
+}
+
+TEST_F(WhileStatementTest, Assert_DifferentProgramID_Condition) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.While(b2.Expr(true), b1.Block());
+ },
+ "internal compiler error");
+}
+
+TEST_F(WhileStatementTest, Assert_DifferentProgramID_Body) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.While(b1.Expr(true), b2.Block());
+ },
+ "internal compiler error");
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/workgroup_attribute.cc b/src/tint/lang/wgsl/ast/workgroup_attribute.cc
new file mode 100644
index 0000000..936f8e4
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/workgroup_attribute.cc
@@ -0,0 +1,48 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/workgroup_attribute.h"
+
+#include <string>
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::WorkgroupAttribute);
+
+namespace tint::ast {
+
+WorkgroupAttribute::WorkgroupAttribute(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const Expression* x_,
+ const Expression* y_,
+ const Expression* z_)
+ : Base(pid, nid, src), x(x_), y(y_), z(z_) {}  // y_ and z_ are optional and may be null.
+
+WorkgroupAttribute::~WorkgroupAttribute() = default;
+
+std::string WorkgroupAttribute::Name() const {
+ return "workgroup_size";
+}
+
+const WorkgroupAttribute* WorkgroupAttribute::Clone(CloneContext* ctx) const {
+ // Clone children before the create() call: C++ leaves argument evaluation order unspecified.
+ auto src = ctx->Clone(source);
+ auto* x_ = ctx->Clone(x);
+ auto* y_ = ctx->Clone(y);
+ auto* z_ = ctx->Clone(z);
+ return ctx->dst->create<WorkgroupAttribute>(src, x_, y_, z_);
+}
+
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast/workgroup_attribute.h b/src/tint/lang/wgsl/ast/workgroup_attribute.h
new file mode 100644
index 0000000..593a6bc
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/workgroup_attribute.h
@@ -0,0 +1,71 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_WORKGROUP_ATTRIBUTE_H_
+#define SRC_TINT_LANG_WGSL_AST_WORKGROUP_ATTRIBUTE_H_
+
+#include <array>
+#include <string>
+
+#include "src/tint/lang/wgsl/ast/attribute.h"
+
+// Forward declarations
+namespace tint::ast {
+class Expression;
+} // namespace tint::ast
+
+namespace tint::ast {
+
+/// A workgroup size attribute, spelled `workgroup_size` in WGSL
+class WorkgroupAttribute final : public utils::Castable<WorkgroupAttribute, Attribute> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ /// @param x the workgroup x dimension expression
+ /// @param y the optional workgroup y dimension expression
+ /// @param z the optional workgroup z dimension expression
+ WorkgroupAttribute(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const Expression* x,
+ const Expression* y = nullptr,
+ const Expression* z = nullptr);
+
+ ~WorkgroupAttribute() override;
+
+ /// @returns the x, y and z dimension expressions, in that order; y and z may be null
+ std::array<const Expression*, 3> Values() const { return {x, y, z}; }
+
+ /// @returns the WGSL name for the attribute
+ std::string Name() const override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const WorkgroupAttribute* Clone(CloneContext* ctx) const override;
+
+ /// The workgroup x dimension expression.
+ const Expression* const x;
+ /// The optional workgroup y dimension. May be null.
+ const Expression* const y = nullptr;
+ /// The optional workgroup z dimension. May be null.
+ const Expression* const z = nullptr;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_LANG_WGSL_AST_WORKGROUP_ATTRIBUTE_H_
diff --git a/src/tint/lang/wgsl/ast/workgroup_attribute_test.cc b/src/tint/lang/wgsl/ast/workgroup_attribute_test.cc
new file mode 100644
index 0000000..0dd8729
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/workgroup_attribute_test.cc
@@ -0,0 +1,80 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/workgroup_attribute.h"
+
+#include "src/tint/lang/wgsl/ast/stage_attribute.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+
+using namespace tint::number_suffixes; // NOLINT
+
+namespace tint::ast {
+namespace {
+
+using WorkgroupAttributeTest = TestHelper;
+
+TEST_F(WorkgroupAttributeTest, Creation_1param) {
+ auto* d = WorkgroupSize(2_i);
+ auto values = d->Values();
+
+ ASSERT_TRUE(values[0]->Is<IntLiteralExpression>());
+ EXPECT_EQ(values[0]->As<IntLiteralExpression>()->value, 2);
+
+ EXPECT_EQ(values[1], nullptr);
+ EXPECT_EQ(values[2], nullptr);
+}
+TEST_F(WorkgroupAttributeTest, Creation_2param) {
+ auto* d = WorkgroupSize(2_i, 4_i);
+ auto values = d->Values();
+
+ ASSERT_TRUE(values[0]->Is<IntLiteralExpression>());
+ EXPECT_EQ(values[0]->As<IntLiteralExpression>()->value, 2);
+
+ ASSERT_TRUE(values[1]->Is<IntLiteralExpression>());
+ EXPECT_EQ(values[1]->As<IntLiteralExpression>()->value, 4);
+
+ EXPECT_EQ(values[2], nullptr);
+}
+
+TEST_F(WorkgroupAttributeTest, Creation_3param) {
+ auto* d = WorkgroupSize(2_i, 4_i, 6_i);
+ auto values = d->Values();
+
+ ASSERT_TRUE(values[0]->Is<IntLiteralExpression>());
+ EXPECT_EQ(values[0]->As<IntLiteralExpression>()->value, 2);
+
+ ASSERT_TRUE(values[1]->Is<IntLiteralExpression>());
+ EXPECT_EQ(values[1]->As<IntLiteralExpression>()->value, 4);
+
+ ASSERT_TRUE(values[2]->Is<IntLiteralExpression>());
+ EXPECT_EQ(values[2]->As<IntLiteralExpression>()->value, 6);
+}
+
+TEST_F(WorkgroupAttributeTest, Creation_WithIdentifier) {
+ auto* d = WorkgroupSize(2_i, 4_i, "depth");
+ auto values = d->Values();
+
+ ASSERT_TRUE(values[0]->Is<IntLiteralExpression>());
+ EXPECT_EQ(values[0]->As<IntLiteralExpression>()->value, 2);
+
+ ASSERT_TRUE(values[1]->Is<IntLiteralExpression>());
+ EXPECT_EQ(values[1]->As<IntLiteralExpression>()->value, 4);
+
+ auto* z_ident = As<IdentifierExpression>(values[2]);
+ ASSERT_TRUE(z_ident);
+ EXPECT_EQ(z_ident->identifier->symbol.Name(), "depth");
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/lang/wgsl/ast_writer/generator_impl.cc b/src/tint/lang/wgsl/ast_writer/generator_impl.cc
index 130dec9..a046225b 100644
--- a/src/tint/lang/wgsl/ast_writer/generator_impl.cc
+++ b/src/tint/lang/wgsl/ast_writer/generator_impl.cc
@@ -16,22 +16,22 @@
#include <algorithm>
-#include "src/tint/ast/alias.h"
-#include "src/tint/ast/bool_literal_expression.h"
-#include "src/tint/ast/call_statement.h"
-#include "src/tint/ast/float_literal_expression.h"
-#include "src/tint/ast/id_attribute.h"
-#include "src/tint/ast/internal_attribute.h"
-#include "src/tint/ast/interpolate_attribute.h"
-#include "src/tint/ast/invariant_attribute.h"
-#include "src/tint/ast/module.h"
-#include "src/tint/ast/stage_attribute.h"
-#include "src/tint/ast/stride_attribute.h"
-#include "src/tint/ast/struct_member_align_attribute.h"
-#include "src/tint/ast/struct_member_offset_attribute.h"
-#include "src/tint/ast/struct_member_size_attribute.h"
-#include "src/tint/ast/variable_decl_statement.h"
-#include "src/tint/ast/workgroup_attribute.h"
+#include "src/tint/lang/wgsl/ast/alias.h"
+#include "src/tint/lang/wgsl/ast/bool_literal_expression.h"
+#include "src/tint/lang/wgsl/ast/call_statement.h"
+#include "src/tint/lang/wgsl/ast/float_literal_expression.h"
+#include "src/tint/lang/wgsl/ast/id_attribute.h"
+#include "src/tint/lang/wgsl/ast/internal_attribute.h"
+#include "src/tint/lang/wgsl/ast/interpolate_attribute.h"
+#include "src/tint/lang/wgsl/ast/invariant_attribute.h"
+#include "src/tint/lang/wgsl/ast/module.h"
+#include "src/tint/lang/wgsl/ast/stage_attribute.h"
+#include "src/tint/lang/wgsl/ast/stride_attribute.h"
+#include "src/tint/lang/wgsl/ast/struct_member_align_attribute.h"
+#include "src/tint/lang/wgsl/ast/struct_member_offset_attribute.h"
+#include "src/tint/lang/wgsl/ast/struct_member_size_attribute.h"
+#include "src/tint/lang/wgsl/ast/variable_decl_statement.h"
+#include "src/tint/lang/wgsl/ast/workgroup_attribute.h"
#include "src/tint/sem/struct.h"
#include "src/tint/sem/switch_statement.h"
#include "src/tint/switch.h"
diff --git a/src/tint/lang/wgsl/ast_writer/generator_impl.h b/src/tint/lang/wgsl/ast_writer/generator_impl.h
index a393412..fe98186 100644
--- a/src/tint/lang/wgsl/ast_writer/generator_impl.h
+++ b/src/tint/lang/wgsl/ast_writer/generator_impl.h
@@ -17,22 +17,22 @@
#include <string>
-#include "src/tint/ast/assignment_statement.h"
-#include "src/tint/ast/binary_expression.h"
-#include "src/tint/ast/bitcast_expression.h"
-#include "src/tint/ast/break_if_statement.h"
-#include "src/tint/ast/break_statement.h"
-#include "src/tint/ast/compound_assignment_statement.h"
-#include "src/tint/ast/continue_statement.h"
-#include "src/tint/ast/discard_statement.h"
-#include "src/tint/ast/for_loop_statement.h"
-#include "src/tint/ast/if_statement.h"
-#include "src/tint/ast/index_accessor_expression.h"
-#include "src/tint/ast/loop_statement.h"
-#include "src/tint/ast/member_accessor_expression.h"
-#include "src/tint/ast/return_statement.h"
-#include "src/tint/ast/switch_statement.h"
-#include "src/tint/ast/unary_op_expression.h"
+#include "src/tint/lang/wgsl/ast/assignment_statement.h"
+#include "src/tint/lang/wgsl/ast/binary_expression.h"
+#include "src/tint/lang/wgsl/ast/bitcast_expression.h"
+#include "src/tint/lang/wgsl/ast/break_if_statement.h"
+#include "src/tint/lang/wgsl/ast/break_statement.h"
+#include "src/tint/lang/wgsl/ast/compound_assignment_statement.h"
+#include "src/tint/lang/wgsl/ast/continue_statement.h"
+#include "src/tint/lang/wgsl/ast/discard_statement.h"
+#include "src/tint/lang/wgsl/ast/for_loop_statement.h"
+#include "src/tint/lang/wgsl/ast/if_statement.h"
+#include "src/tint/lang/wgsl/ast/index_accessor_expression.h"
+#include "src/tint/lang/wgsl/ast/loop_statement.h"
+#include "src/tint/lang/wgsl/ast/member_accessor_expression.h"
+#include "src/tint/lang/wgsl/ast/return_statement.h"
+#include "src/tint/lang/wgsl/ast/switch_statement.h"
+#include "src/tint/lang/wgsl/ast/unary_op_expression.h"
#include "src/tint/program.h"
#include "src/tint/sem/struct.h"
#include "src/tint/utils/string_stream.h"
diff --git a/src/tint/lang/wgsl/ast_writer/generator_impl_call_test.cc b/src/tint/lang/wgsl/ast_writer/generator_impl_call_test.cc
index 21faae0..2471962 100644
--- a/src/tint/lang/wgsl/ast_writer/generator_impl_call_test.cc
+++ b/src/tint/lang/wgsl/ast_writer/generator_impl_call_test.cc
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/call_statement.h"
+#include "src/tint/lang/wgsl/ast/call_statement.h"
#include "src/tint/lang/wgsl/ast_writer/test_helper.h"
#include "src/tint/utils/string_stream.h"
diff --git a/src/tint/lang/wgsl/ast_writer/generator_impl_function_test.cc b/src/tint/lang/wgsl/ast_writer/generator_impl_function_test.cc
index 3121c31..200e60c 100644
--- a/src/tint/lang/wgsl/ast_writer/generator_impl_function_test.cc
+++ b/src/tint/lang/wgsl/ast_writer/generator_impl_function_test.cc
@@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/stage_attribute.h"
-#include "src/tint/ast/variable_decl_statement.h"
-#include "src/tint/ast/workgroup_attribute.h"
#include "src/tint/builtin/builtin_value.h"
+#include "src/tint/lang/wgsl/ast/stage_attribute.h"
+#include "src/tint/lang/wgsl/ast/variable_decl_statement.h"
+#include "src/tint/lang/wgsl/ast/workgroup_attribute.h"
#include "src/tint/lang/wgsl/ast_writer/test_helper.h"
#include "gmock/gmock.h"
diff --git a/src/tint/lang/wgsl/ast_writer/generator_impl_global_decl_test.cc b/src/tint/lang/wgsl/ast_writer/generator_impl_global_decl_test.cc
index dd20231..75161a2 100644
--- a/src/tint/lang/wgsl/ast_writer/generator_impl_global_decl_test.cc
+++ b/src/tint/lang/wgsl/ast_writer/generator_impl_global_decl_test.cc
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/stage_attribute.h"
-#include "src/tint/ast/variable_decl_statement.h"
+#include "src/tint/lang/wgsl/ast/stage_attribute.h"
+#include "src/tint/lang/wgsl/ast/variable_decl_statement.h"
#include "src/tint/lang/wgsl/ast_writer/test_helper.h"
#include "src/tint/type/sampled_texture.h"
#include "src/tint/type/texture_dimension.h"
diff --git a/src/tint/lang/wgsl/ast_writer/generator_impl_variable_decl_statement_test.cc b/src/tint/lang/wgsl/ast_writer/generator_impl_variable_decl_statement_test.cc
index bc66f1e..a681d10 100644
--- a/src/tint/lang/wgsl/ast_writer/generator_impl_variable_decl_statement_test.cc
+++ b/src/tint/lang/wgsl/ast_writer/generator_impl_variable_decl_statement_test.cc
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/variable_decl_statement.h"
+#include "src/tint/lang/wgsl/ast/variable_decl_statement.h"
#include "src/tint/lang/wgsl/ast_writer/test_helper.h"
#include "gmock/gmock.h"
diff --git a/src/tint/lang/wgsl/reader/parser_impl.cc b/src/tint/lang/wgsl/reader/parser_impl.cc
index 00d8c91..5138e61 100644
--- a/src/tint/lang/wgsl/reader/parser_impl.cc
+++ b/src/tint/lang/wgsl/reader/parser_impl.cc
@@ -16,25 +16,25 @@
#include <limits>
-#include "src/tint/ast/assignment_statement.h"
-#include "src/tint/ast/bitcast_expression.h"
-#include "src/tint/ast/break_if_statement.h"
-#include "src/tint/ast/break_statement.h"
-#include "src/tint/ast/call_statement.h"
-#include "src/tint/ast/continue_statement.h"
-#include "src/tint/ast/discard_statement.h"
-#include "src/tint/ast/id_attribute.h"
-#include "src/tint/ast/if_statement.h"
-#include "src/tint/ast/increment_decrement_statement.h"
-#include "src/tint/ast/invariant_attribute.h"
-#include "src/tint/ast/loop_statement.h"
-#include "src/tint/ast/return_statement.h"
-#include "src/tint/ast/stage_attribute.h"
-#include "src/tint/ast/switch_statement.h"
-#include "src/tint/ast/unary_op_expression.h"
-#include "src/tint/ast/variable_decl_statement.h"
-#include "src/tint/ast/workgroup_attribute.h"
#include "src/tint/builtin/attribute.h"
+#include "src/tint/lang/wgsl/ast/assignment_statement.h"
+#include "src/tint/lang/wgsl/ast/bitcast_expression.h"
+#include "src/tint/lang/wgsl/ast/break_if_statement.h"
+#include "src/tint/lang/wgsl/ast/break_statement.h"
+#include "src/tint/lang/wgsl/ast/call_statement.h"
+#include "src/tint/lang/wgsl/ast/continue_statement.h"
+#include "src/tint/lang/wgsl/ast/discard_statement.h"
+#include "src/tint/lang/wgsl/ast/id_attribute.h"
+#include "src/tint/lang/wgsl/ast/if_statement.h"
+#include "src/tint/lang/wgsl/ast/increment_decrement_statement.h"
+#include "src/tint/lang/wgsl/ast/invariant_attribute.h"
+#include "src/tint/lang/wgsl/ast/loop_statement.h"
+#include "src/tint/lang/wgsl/ast/return_statement.h"
+#include "src/tint/lang/wgsl/ast/stage_attribute.h"
+#include "src/tint/lang/wgsl/ast/switch_statement.h"
+#include "src/tint/lang/wgsl/ast/unary_op_expression.h"
+#include "src/tint/lang/wgsl/ast/variable_decl_statement.h"
+#include "src/tint/lang/wgsl/ast/workgroup_attribute.h"
#include "src/tint/lang/wgsl/reader/classify_template_args.h"
#include "src/tint/lang/wgsl/reader/lexer.h"
#include "src/tint/type/depth_texture.h"
diff --git a/src/tint/lang/wgsl/reader/parser_impl_break_stmt_test.cc b/src/tint/lang/wgsl/reader/parser_impl_break_stmt_test.cc
index 1282434..ff3c4bd 100644
--- a/src/tint/lang/wgsl/reader/parser_impl_break_stmt_test.cc
+++ b/src/tint/lang/wgsl/reader/parser_impl_break_stmt_test.cc
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/break_statement.h"
+#include "src/tint/lang/wgsl/ast/break_statement.h"
#include "src/tint/lang/wgsl/reader/parser_impl_test_helper.h"
namespace tint::reader::wgsl {
diff --git a/src/tint/lang/wgsl/reader/parser_impl_call_stmt_test.cc b/src/tint/lang/wgsl/reader/parser_impl_call_stmt_test.cc
index e35c6b5..e3058d6 100644
--- a/src/tint/lang/wgsl/reader/parser_impl_call_stmt_test.cc
+++ b/src/tint/lang/wgsl/reader/parser_impl_call_stmt_test.cc
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/call_statement.h"
-#include "src/tint/ast/test_helper.h"
+#include "src/tint/lang/wgsl/ast/call_statement.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
#include "src/tint/lang/wgsl/reader/parser_impl_test_helper.h"
namespace tint::reader::wgsl {
diff --git a/src/tint/lang/wgsl/reader/parser_impl_compound_stmt_test.cc b/src/tint/lang/wgsl/reader/parser_impl_compound_stmt_test.cc
index 21f7baa..6d984c11 100644
--- a/src/tint/lang/wgsl/reader/parser_impl_compound_stmt_test.cc
+++ b/src/tint/lang/wgsl/reader/parser_impl_compound_stmt_test.cc
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/discard_statement.h"
+#include "src/tint/lang/wgsl/ast/discard_statement.h"
#include "src/tint/lang/wgsl/reader/parser_impl_test_helper.h"
namespace tint::reader::wgsl {
diff --git a/src/tint/lang/wgsl/reader/parser_impl_continue_stmt_test.cc b/src/tint/lang/wgsl/reader/parser_impl_continue_stmt_test.cc
index 1c3b2ac..ea03a99 100644
--- a/src/tint/lang/wgsl/reader/parser_impl_continue_stmt_test.cc
+++ b/src/tint/lang/wgsl/reader/parser_impl_continue_stmt_test.cc
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/continue_statement.h"
+#include "src/tint/lang/wgsl/ast/continue_statement.h"
#include "src/tint/lang/wgsl/reader/parser_impl_test_helper.h"
namespace tint::reader::wgsl {
diff --git a/src/tint/lang/wgsl/reader/parser_impl_continuing_stmt_test.cc b/src/tint/lang/wgsl/reader/parser_impl_continuing_stmt_test.cc
index fd8ceb0..de24790 100644
--- a/src/tint/lang/wgsl/reader/parser_impl_continuing_stmt_test.cc
+++ b/src/tint/lang/wgsl/reader/parser_impl_continuing_stmt_test.cc
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/discard_statement.h"
+#include "src/tint/lang/wgsl/ast/discard_statement.h"
#include "src/tint/lang/wgsl/reader/parser_impl_test_helper.h"
namespace tint::reader::wgsl {
diff --git a/src/tint/lang/wgsl/reader/parser_impl_diagnostic_attribute_test.cc b/src/tint/lang/wgsl/reader/parser_impl_diagnostic_attribute_test.cc
index 06c7a28..1af4402 100644
--- a/src/tint/lang/wgsl/reader/parser_impl_diagnostic_attribute_test.cc
+++ b/src/tint/lang/wgsl/reader/parser_impl_diagnostic_attribute_test.cc
@@ -14,8 +14,8 @@
#include "src/tint/lang/wgsl/reader/parser_impl_test_helper.h"
-#include "src/tint/ast/diagnostic_control.h"
-#include "src/tint/ast/test_helper.h"
+#include "src/tint/lang/wgsl/ast/diagnostic_control.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
namespace tint::reader::wgsl {
namespace {
diff --git a/src/tint/lang/wgsl/reader/parser_impl_diagnostic_control_test.cc b/src/tint/lang/wgsl/reader/parser_impl_diagnostic_control_test.cc
index d5ece9a..09ec3cb 100644
--- a/src/tint/lang/wgsl/reader/parser_impl_diagnostic_control_test.cc
+++ b/src/tint/lang/wgsl/reader/parser_impl_diagnostic_control_test.cc
@@ -14,8 +14,8 @@
#include "src/tint/lang/wgsl/reader/parser_impl_test_helper.h"
-#include "src/tint/ast/diagnostic_control.h"
-#include "src/tint/ast/test_helper.h"
+#include "src/tint/lang/wgsl/ast/diagnostic_control.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
namespace tint::reader::wgsl {
namespace {
diff --git a/src/tint/lang/wgsl/reader/parser_impl_diagnostic_directive_test.cc b/src/tint/lang/wgsl/reader/parser_impl_diagnostic_directive_test.cc
index 4e4c566..c5d95e0 100644
--- a/src/tint/lang/wgsl/reader/parser_impl_diagnostic_directive_test.cc
+++ b/src/tint/lang/wgsl/reader/parser_impl_diagnostic_directive_test.cc
@@ -14,8 +14,8 @@
#include "src/tint/lang/wgsl/reader/parser_impl_test_helper.h"
-#include "src/tint/ast/diagnostic_control.h"
-#include "src/tint/ast/test_helper.h"
+#include "src/tint/lang/wgsl/ast/diagnostic_control.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
namespace tint::reader::wgsl {
namespace {
diff --git a/src/tint/lang/wgsl/reader/parser_impl_enable_directive_test.cc b/src/tint/lang/wgsl/reader/parser_impl_enable_directive_test.cc
index 2039f90..675e70b 100644
--- a/src/tint/lang/wgsl/reader/parser_impl_enable_directive_test.cc
+++ b/src/tint/lang/wgsl/reader/parser_impl_enable_directive_test.cc
@@ -14,7 +14,7 @@
#include "src/tint/lang/wgsl/reader/parser_impl_test_helper.h"
-#include "src/tint/ast/enable.h"
+#include "src/tint/lang/wgsl/ast/enable.h"
namespace tint::reader::wgsl {
namespace {
diff --git a/src/tint/lang/wgsl/reader/parser_impl_for_stmt_test.cc b/src/tint/lang/wgsl/reader/parser_impl_for_stmt_test.cc
index a4b1b5a..ecd06e3 100644
--- a/src/tint/lang/wgsl/reader/parser_impl_for_stmt_test.cc
+++ b/src/tint/lang/wgsl/reader/parser_impl_for_stmt_test.cc
@@ -14,7 +14,7 @@
#include "src/tint/lang/wgsl/reader/parser_impl_test_helper.h"
-#include "src/tint/ast/discard_statement.h"
+#include "src/tint/lang/wgsl/ast/discard_statement.h"
namespace tint::reader::wgsl {
namespace {
diff --git a/src/tint/lang/wgsl/reader/parser_impl_function_attribute_list_test.cc b/src/tint/lang/wgsl/reader/parser_impl_function_attribute_list_test.cc
index c647c89..5ac213e 100644
--- a/src/tint/lang/wgsl/reader/parser_impl_function_attribute_list_test.cc
+++ b/src/tint/lang/wgsl/reader/parser_impl_function_attribute_list_test.cc
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/workgroup_attribute.h"
+#include "src/tint/lang/wgsl/ast/workgroup_attribute.h"
#include "src/tint/lang/wgsl/reader/parser_impl_test_helper.h"
namespace tint::reader::wgsl {
diff --git a/src/tint/lang/wgsl/reader/parser_impl_function_attribute_test.cc b/src/tint/lang/wgsl/reader/parser_impl_function_attribute_test.cc
index c2fb866..3e54dc1 100644
--- a/src/tint/lang/wgsl/reader/parser_impl_function_attribute_test.cc
+++ b/src/tint/lang/wgsl/reader/parser_impl_function_attribute_test.cc
@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/stage_attribute.h"
-#include "src/tint/ast/test_helper.h"
-#include "src/tint/ast/workgroup_attribute.h"
+#include "src/tint/lang/wgsl/ast/stage_attribute.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+#include "src/tint/lang/wgsl/ast/workgroup_attribute.h"
#include "src/tint/lang/wgsl/reader/parser_impl_test_helper.h"
namespace tint::reader::wgsl {
diff --git a/src/tint/lang/wgsl/reader/parser_impl_function_decl_test.cc b/src/tint/lang/wgsl/reader/parser_impl_function_decl_test.cc
index b84bb17..2aca45c 100644
--- a/src/tint/lang/wgsl/reader/parser_impl_function_decl_test.cc
+++ b/src/tint/lang/wgsl/reader/parser_impl_function_decl_test.cc
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/test_helper.h"
-#include "src/tint/ast/workgroup_attribute.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
+#include "src/tint/lang/wgsl/ast/workgroup_attribute.h"
#include "src/tint/lang/wgsl/reader/parser_impl_test_helper.h"
#include "src/tint/utils/string.h"
diff --git a/src/tint/lang/wgsl/reader/parser_impl_function_header_test.cc b/src/tint/lang/wgsl/reader/parser_impl_function_header_test.cc
index 8fbd9a4..e043e68 100644
--- a/src/tint/lang/wgsl/reader/parser_impl_function_header_test.cc
+++ b/src/tint/lang/wgsl/reader/parser_impl_function_header_test.cc
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/test_helper.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
#include "src/tint/lang/wgsl/reader/parser_impl_test_helper.h"
namespace tint::reader::wgsl {
diff --git a/src/tint/lang/wgsl/reader/parser_impl_global_constant_decl_test.cc b/src/tint/lang/wgsl/reader/parser_impl_global_constant_decl_test.cc
index fe1d800..696058f 100644
--- a/src/tint/lang/wgsl/reader/parser_impl_global_constant_decl_test.cc
+++ b/src/tint/lang/wgsl/reader/parser_impl_global_constant_decl_test.cc
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/id_attribute.h"
-#include "src/tint/ast/test_helper.h"
+#include "src/tint/lang/wgsl/ast/id_attribute.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
#include "src/tint/lang/wgsl/reader/parser_impl_test_helper.h"
namespace tint::reader::wgsl {
diff --git a/src/tint/lang/wgsl/reader/parser_impl_global_decl_test.cc b/src/tint/lang/wgsl/reader/parser_impl_global_decl_test.cc
index 2111a75..0e19d15 100644
--- a/src/tint/lang/wgsl/reader/parser_impl_global_decl_test.cc
+++ b/src/tint/lang/wgsl/reader/parser_impl_global_decl_test.cc
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/test_helper.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
#include "src/tint/lang/wgsl/reader/parser_impl_test_helper.h"
namespace tint::reader::wgsl {
diff --git a/src/tint/lang/wgsl/reader/parser_impl_global_variable_decl_test.cc b/src/tint/lang/wgsl/reader/parser_impl_global_variable_decl_test.cc
index e71212b..45cb31f 100644
--- a/src/tint/lang/wgsl/reader/parser_impl_global_variable_decl_test.cc
+++ b/src/tint/lang/wgsl/reader/parser_impl_global_variable_decl_test.cc
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/test_helper.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
#include "src/tint/lang/wgsl/reader/parser_impl_test_helper.h"
namespace tint::reader::wgsl {
diff --git a/src/tint/lang/wgsl/reader/parser_impl_loop_stmt_test.cc b/src/tint/lang/wgsl/reader/parser_impl_loop_stmt_test.cc
index e605a0b..839dc5c 100644
--- a/src/tint/lang/wgsl/reader/parser_impl_loop_stmt_test.cc
+++ b/src/tint/lang/wgsl/reader/parser_impl_loop_stmt_test.cc
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/discard_statement.h"
+#include "src/tint/lang/wgsl/ast/discard_statement.h"
#include "src/tint/lang/wgsl/reader/parser_impl_test_helper.h"
namespace tint::reader::wgsl {
diff --git a/src/tint/lang/wgsl/reader/parser_impl_param_list_test.cc b/src/tint/lang/wgsl/reader/parser_impl_param_list_test.cc
index 001f74b..2318a13 100644
--- a/src/tint/lang/wgsl/reader/parser_impl_param_list_test.cc
+++ b/src/tint/lang/wgsl/reader/parser_impl_param_list_test.cc
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/test_helper.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
#include "src/tint/lang/wgsl/reader/parser_impl_test_helper.h"
namespace tint::reader::wgsl {
diff --git a/src/tint/lang/wgsl/reader/parser_impl_primary_expression_test.cc b/src/tint/lang/wgsl/reader/parser_impl_primary_expression_test.cc
index 5a3ec5a..a7862a3 100644
--- a/src/tint/lang/wgsl/reader/parser_impl_primary_expression_test.cc
+++ b/src/tint/lang/wgsl/reader/parser_impl_primary_expression_test.cc
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/bitcast_expression.h"
-#include "src/tint/ast/test_helper.h"
+#include "src/tint/lang/wgsl/ast/bitcast_expression.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
#include "src/tint/lang/wgsl/reader/parser_impl_test_helper.h"
namespace tint::reader::wgsl {
diff --git a/src/tint/lang/wgsl/reader/parser_impl_singular_expression_test.cc b/src/tint/lang/wgsl/reader/parser_impl_singular_expression_test.cc
index c242bab..18c8f3a 100644
--- a/src/tint/lang/wgsl/reader/parser_impl_singular_expression_test.cc
+++ b/src/tint/lang/wgsl/reader/parser_impl_singular_expression_test.cc
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/test_helper.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
#include "src/tint/lang/wgsl/reader/parser_impl_test_helper.h"
namespace tint::reader::wgsl {
diff --git a/src/tint/lang/wgsl/reader/parser_impl_statement_test.cc b/src/tint/lang/wgsl/reader/parser_impl_statement_test.cc
index 25ebc27..d0ff3e9 100644
--- a/src/tint/lang/wgsl/reader/parser_impl_statement_test.cc
+++ b/src/tint/lang/wgsl/reader/parser_impl_statement_test.cc
@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/break_statement.h"
-#include "src/tint/ast/continue_statement.h"
-#include "src/tint/ast/discard_statement.h"
+#include "src/tint/lang/wgsl/ast/break_statement.h"
+#include "src/tint/lang/wgsl/ast/continue_statement.h"
+#include "src/tint/lang/wgsl/ast/discard_statement.h"
#include "src/tint/lang/wgsl/reader/parser_impl_test_helper.h"
namespace tint::reader::wgsl {
diff --git a/src/tint/lang/wgsl/reader/parser_impl_statements_test.cc b/src/tint/lang/wgsl/reader/parser_impl_statements_test.cc
index 800337c..385aa9c 100644
--- a/src/tint/lang/wgsl/reader/parser_impl_statements_test.cc
+++ b/src/tint/lang/wgsl/reader/parser_impl_statements_test.cc
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/discard_statement.h"
+#include "src/tint/lang/wgsl/ast/discard_statement.h"
#include "src/tint/lang/wgsl/reader/parser_impl_test_helper.h"
namespace tint::reader::wgsl {
diff --git a/src/tint/lang/wgsl/reader/parser_impl_struct_attribute_decl_test.cc b/src/tint/lang/wgsl/reader/parser_impl_struct_attribute_decl_test.cc
index f8addb4..a33f632 100644
--- a/src/tint/lang/wgsl/reader/parser_impl_struct_attribute_decl_test.cc
+++ b/src/tint/lang/wgsl/reader/parser_impl_struct_attribute_decl_test.cc
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/invariant_attribute.h"
+#include "src/tint/lang/wgsl/ast/invariant_attribute.h"
#include "src/tint/lang/wgsl/reader/parser_impl_test_helper.h"
namespace tint::reader::wgsl {
diff --git a/src/tint/lang/wgsl/reader/parser_impl_struct_body_decl_test.cc b/src/tint/lang/wgsl/reader/parser_impl_struct_body_decl_test.cc
index 977904d..ba36f9a 100644
--- a/src/tint/lang/wgsl/reader/parser_impl_struct_body_decl_test.cc
+++ b/src/tint/lang/wgsl/reader/parser_impl_struct_body_decl_test.cc
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/test_helper.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
#include "src/tint/lang/wgsl/reader/parser_impl_test_helper.h"
namespace tint::reader::wgsl {
diff --git a/src/tint/lang/wgsl/reader/parser_impl_struct_member_test.cc b/src/tint/lang/wgsl/reader/parser_impl_struct_member_test.cc
index 4dce83a..6f4da06 100644
--- a/src/tint/lang/wgsl/reader/parser_impl_struct_member_test.cc
+++ b/src/tint/lang/wgsl/reader/parser_impl_struct_member_test.cc
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/test_helper.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
#include "src/tint/lang/wgsl/reader/parser_impl_test_helper.h"
namespace tint::reader::wgsl {
diff --git a/src/tint/lang/wgsl/reader/parser_impl_type_alias_test.cc b/src/tint/lang/wgsl/reader/parser_impl_type_alias_test.cc
index ad8874e..c8b8595 100644
--- a/src/tint/lang/wgsl/reader/parser_impl_type_alias_test.cc
+++ b/src/tint/lang/wgsl/reader/parser_impl_type_alias_test.cc
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/test_helper.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
#include "src/tint/lang/wgsl/reader/parser_impl_test_helper.h"
namespace tint::reader::wgsl {
diff --git a/src/tint/lang/wgsl/reader/parser_impl_type_decl_test.cc b/src/tint/lang/wgsl/reader/parser_impl_type_decl_test.cc
index ffcbed6..02c43f2 100644
--- a/src/tint/lang/wgsl/reader/parser_impl_type_decl_test.cc
+++ b/src/tint/lang/wgsl/reader/parser_impl_type_decl_test.cc
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/alias.h"
-#include "src/tint/ast/test_helper.h"
+#include "src/tint/lang/wgsl/ast/alias.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
#include "src/tint/lang/wgsl/reader/parser_impl_test_helper.h"
#include "src/tint/type/sampled_texture.h"
diff --git a/src/tint/lang/wgsl/reader/parser_impl_unary_expression_test.cc b/src/tint/lang/wgsl/reader/parser_impl_unary_expression_test.cc
index 7517a37..0f5201f 100644
--- a/src/tint/lang/wgsl/reader/parser_impl_unary_expression_test.cc
+++ b/src/tint/lang/wgsl/reader/parser_impl_unary_expression_test.cc
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/unary_op_expression.h"
+#include "src/tint/lang/wgsl/ast/unary_op_expression.h"
#include "src/tint/lang/wgsl/reader/parser_impl_test_helper.h"
namespace tint::reader::wgsl {
diff --git a/src/tint/lang/wgsl/reader/parser_impl_variable_attribute_list_test.cc b/src/tint/lang/wgsl/reader/parser_impl_variable_attribute_list_test.cc
index 1ea68c7..a888506 100644
--- a/src/tint/lang/wgsl/reader/parser_impl_variable_attribute_list_test.cc
+++ b/src/tint/lang/wgsl/reader/parser_impl_variable_attribute_list_test.cc
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/test_helper.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
#include "src/tint/lang/wgsl/reader/parser_impl_test_helper.h"
namespace tint::reader::wgsl {
diff --git a/src/tint/lang/wgsl/reader/parser_impl_variable_attribute_test.cc b/src/tint/lang/wgsl/reader/parser_impl_variable_attribute_test.cc
index d18209e..6d59ed6 100644
--- a/src/tint/lang/wgsl/reader/parser_impl_variable_attribute_test.cc
+++ b/src/tint/lang/wgsl/reader/parser_impl_variable_attribute_test.cc
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/test_helper.h"
#include "src/tint/builtin/builtin_value.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
#include "src/tint/lang/wgsl/reader/parser_impl_test_helper.h"
namespace tint::reader::wgsl {
diff --git a/src/tint/lang/wgsl/reader/parser_impl_variable_decl_test.cc b/src/tint/lang/wgsl/reader/parser_impl_variable_decl_test.cc
index c722334..8448452 100644
--- a/src/tint/lang/wgsl/reader/parser_impl_variable_decl_test.cc
+++ b/src/tint/lang/wgsl/reader/parser_impl_variable_decl_test.cc
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/test_helper.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
#include "src/tint/lang/wgsl/reader/parser_impl_test_helper.h"
namespace tint::reader::wgsl {
diff --git a/src/tint/lang/wgsl/reader/parser_impl_variable_ident_decl_test.cc b/src/tint/lang/wgsl/reader/parser_impl_variable_ident_decl_test.cc
index 2b44124..0891ca4 100644
--- a/src/tint/lang/wgsl/reader/parser_impl_variable_ident_decl_test.cc
+++ b/src/tint/lang/wgsl/reader/parser_impl_variable_ident_decl_test.cc
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/test_helper.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
#include "src/tint/lang/wgsl/reader/parser_impl_test_helper.h"
namespace tint::reader::wgsl {
diff --git a/src/tint/lang/wgsl/reader/parser_impl_variable_qualifier_test.cc b/src/tint/lang/wgsl/reader/parser_impl_variable_qualifier_test.cc
index 21e52eb..ccfc78c 100644
--- a/src/tint/lang/wgsl/reader/parser_impl_variable_qualifier_test.cc
+++ b/src/tint/lang/wgsl/reader/parser_impl_variable_qualifier_test.cc
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/test_helper.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
#include "src/tint/lang/wgsl/reader/parser_impl_test_helper.h"
namespace tint::reader::wgsl {
diff --git a/src/tint/lang/wgsl/reader/parser_impl_variable_stmt_test.cc b/src/tint/lang/wgsl/reader/parser_impl_variable_stmt_test.cc
index 3217d99..06b1f74 100644
--- a/src/tint/lang/wgsl/reader/parser_impl_variable_stmt_test.cc
+++ b/src/tint/lang/wgsl/reader/parser_impl_variable_stmt_test.cc
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ast/test_helper.h"
+#include "src/tint/lang/wgsl/ast/test_helper.h"
#include "src/tint/lang/wgsl/reader/parser_impl_test_helper.h"
namespace tint::reader::wgsl {
diff --git a/src/tint/lang/wgsl/reader/parser_impl_while_stmt_test.cc b/src/tint/lang/wgsl/reader/parser_impl_while_stmt_test.cc
index 7a859ae..f38669a 100644
--- a/src/tint/lang/wgsl/reader/parser_impl_while_stmt_test.cc
+++ b/src/tint/lang/wgsl/reader/parser_impl_while_stmt_test.cc
@@ -14,7 +14,7 @@
#include "src/tint/lang/wgsl/reader/parser_impl_test_helper.h"
-#include "src/tint/ast/discard_statement.h"
+#include "src/tint/lang/wgsl/ast/discard_statement.h"
namespace tint::reader::wgsl {
namespace {
diff --git a/src/tint/lang/wgsl/reader/parser_test.cc b/src/tint/lang/wgsl/reader/parser_test.cc
index 4cb45c8..12c37f9 100644
--- a/src/tint/lang/wgsl/reader/parser_test.cc
+++ b/src/tint/lang/wgsl/reader/parser_test.cc
@@ -16,7 +16,7 @@
#include "gtest/gtest.h"
-#include "src/tint/ast/module.h"
+#include "src/tint/lang/wgsl/ast/module.h"
namespace tint::reader::wgsl {
namespace {
diff --git a/src/tint/lang/wgsl/syntax_tree_writer/generator_impl.cc b/src/tint/lang/wgsl/syntax_tree_writer/generator_impl.cc
index bc23479..a861e66 100644
--- a/src/tint/lang/wgsl/syntax_tree_writer/generator_impl.cc
+++ b/src/tint/lang/wgsl/syntax_tree_writer/generator_impl.cc
@@ -16,22 +16,22 @@
#include <algorithm>
-#include "src/tint/ast/alias.h"
-#include "src/tint/ast/bool_literal_expression.h"
-#include "src/tint/ast/call_statement.h"
-#include "src/tint/ast/float_literal_expression.h"
-#include "src/tint/ast/id_attribute.h"
-#include "src/tint/ast/internal_attribute.h"
-#include "src/tint/ast/interpolate_attribute.h"
-#include "src/tint/ast/invariant_attribute.h"
-#include "src/tint/ast/module.h"
-#include "src/tint/ast/stage_attribute.h"
-#include "src/tint/ast/stride_attribute.h"
-#include "src/tint/ast/struct_member_align_attribute.h"
-#include "src/tint/ast/struct_member_offset_attribute.h"
-#include "src/tint/ast/struct_member_size_attribute.h"
-#include "src/tint/ast/variable_decl_statement.h"
-#include "src/tint/ast/workgroup_attribute.h"
+#include "src/tint/lang/wgsl/ast/alias.h"
+#include "src/tint/lang/wgsl/ast/bool_literal_expression.h"
+#include "src/tint/lang/wgsl/ast/call_statement.h"
+#include "src/tint/lang/wgsl/ast/float_literal_expression.h"
+#include "src/tint/lang/wgsl/ast/id_attribute.h"
+#include "src/tint/lang/wgsl/ast/internal_attribute.h"
+#include "src/tint/lang/wgsl/ast/interpolate_attribute.h"
+#include "src/tint/lang/wgsl/ast/invariant_attribute.h"
+#include "src/tint/lang/wgsl/ast/module.h"
+#include "src/tint/lang/wgsl/ast/stage_attribute.h"
+#include "src/tint/lang/wgsl/ast/stride_attribute.h"
+#include "src/tint/lang/wgsl/ast/struct_member_align_attribute.h"
+#include "src/tint/lang/wgsl/ast/struct_member_offset_attribute.h"
+#include "src/tint/lang/wgsl/ast/struct_member_size_attribute.h"
+#include "src/tint/lang/wgsl/ast/variable_decl_statement.h"
+#include "src/tint/lang/wgsl/ast/workgroup_attribute.h"
#include "src/tint/sem/struct.h"
#include "src/tint/sem/switch_statement.h"
#include "src/tint/switch.h"
diff --git a/src/tint/lang/wgsl/syntax_tree_writer/generator_impl.h b/src/tint/lang/wgsl/syntax_tree_writer/generator_impl.h
index 26914fe..871ae7f 100644
--- a/src/tint/lang/wgsl/syntax_tree_writer/generator_impl.h
+++ b/src/tint/lang/wgsl/syntax_tree_writer/generator_impl.h
@@ -17,22 +17,22 @@
#include <string>
-#include "src/tint/ast/assignment_statement.h"
-#include "src/tint/ast/binary_expression.h"
-#include "src/tint/ast/bitcast_expression.h"
-#include "src/tint/ast/break_if_statement.h"
-#include "src/tint/ast/break_statement.h"
-#include "src/tint/ast/compound_assignment_statement.h"
-#include "src/tint/ast/continue_statement.h"
-#include "src/tint/ast/discard_statement.h"
-#include "src/tint/ast/for_loop_statement.h"
-#include "src/tint/ast/if_statement.h"
-#include "src/tint/ast/index_accessor_expression.h"
-#include "src/tint/ast/loop_statement.h"
-#include "src/tint/ast/member_accessor_expression.h"
-#include "src/tint/ast/return_statement.h"
-#include "src/tint/ast/switch_statement.h"
-#include "src/tint/ast/unary_op_expression.h"
+#include "src/tint/lang/wgsl/ast/assignment_statement.h"
+#include "src/tint/lang/wgsl/ast/binary_expression.h"
+#include "src/tint/lang/wgsl/ast/bitcast_expression.h"
+#include "src/tint/lang/wgsl/ast/break_if_statement.h"
+#include "src/tint/lang/wgsl/ast/break_statement.h"
+#include "src/tint/lang/wgsl/ast/compound_assignment_statement.h"
+#include "src/tint/lang/wgsl/ast/continue_statement.h"
+#include "src/tint/lang/wgsl/ast/discard_statement.h"
+#include "src/tint/lang/wgsl/ast/for_loop_statement.h"
+#include "src/tint/lang/wgsl/ast/if_statement.h"
+#include "src/tint/lang/wgsl/ast/index_accessor_expression.h"
+#include "src/tint/lang/wgsl/ast/loop_statement.h"
+#include "src/tint/lang/wgsl/ast/member_accessor_expression.h"
+#include "src/tint/lang/wgsl/ast/return_statement.h"
+#include "src/tint/lang/wgsl/ast/switch_statement.h"
+#include "src/tint/lang/wgsl/ast/unary_op_expression.h"
#include "src/tint/program.h"
#include "src/tint/sem/struct.h"
#include "src/tint/utils/string_stream.h"