[ir] Add Robustness transform
Inserts clamping for access, load_vector_element, and
store_vector_element instructions. Handles config options for
disabling clamping for each individual address space.
Clamping for texture builtins will be added in a future CL, as will
handling for the config options that disable clamping in a more
fine-grained manner.
Bug: tint:1718
Change-Id: I2d9cadeb1a53a0dc70efa80f15fe46ecfe1cf191
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/150621
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/lang/core/ir/transform/BUILD.bazel b/src/tint/lang/core/ir/transform/BUILD.bazel
index f950fe6..7216985 100644
--- a/src/tint/lang/core/ir/transform/BUILD.bazel
+++ b/src/tint/lang/core/ir/transform/BUILD.bazel
@@ -34,6 +34,7 @@
"builtin_polyfill.cc",
"demote_to_helper.cc",
"multiplanar_external_texture.cc",
+ "robustness.cc",
"shader_io.cc",
"std140.cc",
],
@@ -46,6 +47,7 @@
"builtin_polyfill.h",
"demote_to_helper.h",
"multiplanar_external_texture.h",
+ "robustness.h",
"shader_io.h",
"std140.h",
],
@@ -88,6 +90,7 @@
"demote_to_helper_test.cc",
"helper_test.h",
"multiplanar_external_texture_test.cc",
+ "robustness_test.cc",
"std140_test.cc",
],
deps = [
diff --git a/src/tint/lang/core/ir/transform/BUILD.cmake b/src/tint/lang/core/ir/transform/BUILD.cmake
index efa9256..f474b48 100644
--- a/src/tint/lang/core/ir/transform/BUILD.cmake
+++ b/src/tint/lang/core/ir/transform/BUILD.cmake
@@ -42,6 +42,8 @@
lang/core/ir/transform/demote_to_helper.h
lang/core/ir/transform/multiplanar_external_texture.cc
lang/core/ir/transform/multiplanar_external_texture.h
+ lang/core/ir/transform/robustness.cc
+ lang/core/ir/transform/robustness.h
lang/core/ir/transform/shader_io.cc
lang/core/ir/transform/shader_io.h
lang/core/ir/transform/std140.cc
@@ -86,6 +88,7 @@
lang/core/ir/transform/demote_to_helper_test.cc
lang/core/ir/transform/helper_test.h
lang/core/ir/transform/multiplanar_external_texture_test.cc
+ lang/core/ir/transform/robustness_test.cc
lang/core/ir/transform/std140_test.cc
)
diff --git a/src/tint/lang/core/ir/transform/BUILD.gn b/src/tint/lang/core/ir/transform/BUILD.gn
index ef5d690..82e462b 100644
--- a/src/tint/lang/core/ir/transform/BUILD.gn
+++ b/src/tint/lang/core/ir/transform/BUILD.gn
@@ -47,6 +47,8 @@
"demote_to_helper.h",
"multiplanar_external_texture.cc",
"multiplanar_external_texture.h",
+ "robustness.cc",
+ "robustness.h",
"shader_io.cc",
"shader_io.h",
"std140.cc",
@@ -89,6 +91,7 @@
"demote_to_helper_test.cc",
"helper_test.h",
"multiplanar_external_texture_test.cc",
+ "robustness_test.cc",
"std140_test.cc",
]
deps = [
diff --git a/src/tint/lang/core/ir/transform/robustness.cc b/src/tint/lang/core/ir/transform/robustness.cc
new file mode 100644
index 0000000..43c25ff3
--- /dev/null
+++ b/src/tint/lang/core/ir/transform/robustness.cc
@@ -0,0 +1,237 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/core/ir/transform/robustness.h"
+
+#include <algorithm>
+#include <utility>
+
+#include "src/tint/lang/core/ir/builder.h"
+#include "src/tint/lang/core/ir/module.h"
+#include "src/tint/lang/core/ir/validator.h"
+
+using namespace tint::core::fluent_types; // NOLINT
+using namespace tint::core::number_suffixes; // NOLINT
+
+namespace tint::core::ir::transform {
+
+namespace {
+
+/// PIMPL state for the transform.
+struct State {
+ /// The robustness config.
+ const RobustnessConfig& config;
+
+ /// The IR module.
+ Module* ir = nullptr;
+
+ /// The IR builder.
+ Builder b{*ir};
+
+ /// The type manager.
+ core::type::Manager& ty{ir->Types()};
+
+ /// Process the module.
+ void Process() {
+ // Find the access instructions that may need to be clamped.
+ Vector<ir::Access*, 64> accesses;
+ Vector<ir::LoadVectorElement*, 64> vector_loads;
+ Vector<ir::StoreVectorElement*, 64> vector_stores;
+ for (auto* inst : ir->instructions.Objects()) {
+ if (inst->Alive()) {
+ tint::Switch(
+ inst, //
+ [&](ir::Access* access) {
+ // Check if accesses into this object should be clamped.
+ auto* ptr = access->Object()->Type()->As<type::Pointer>();
+ if (ptr) {
+ if (ShouldClamp(ptr->AddressSpace())) {
+ accesses.Push(access);
+ }
+ } else {
+ if (config.clamp_value) {
+ accesses.Push(access);
+ }
+ }
+ },
+ [&](ir::LoadVectorElement* lve) {
+ // Check if loads from this address space should be clamped.
+ auto* ptr = lve->From()->Type()->As<type::Pointer>();
+ if (ShouldClamp(ptr->AddressSpace())) {
+ vector_loads.Push(lve);
+ }
+ },
+ [&](ir::StoreVectorElement* sve) {
+ // Check if stores to this address space should be clamped.
+ auto* ptr = sve->To()->Type()->As<type::Pointer>();
+ if (ShouldClamp(ptr->AddressSpace())) {
+ vector_stores.Push(sve);
+ }
+ });
+ }
+ }
+
+ // Clamp access indices.
+ for (auto* access : accesses) {
+ b.InsertBefore(access, [&] { //
+ ClampAccessIndices(access);
+ });
+ }
+
+ // Clamp load-vector-element instructions.
+ for (auto* lve : vector_loads) {
+ auto* vec = lve->From()->Type()->UnwrapPtr()->As<type::Vector>();
+ b.InsertBefore(lve, [&] { //
+ ClampOperand(lve, LoadVectorElement::kIndexOperandOffset,
+ b.Constant(u32(vec->Width() - 1u)));
+ });
+ }
+
+ // Clamp store-vector-element instructions.
+ for (auto* sve : vector_stores) {
+ auto* vec = sve->To()->Type()->UnwrapPtr()->As<type::Vector>();
+ b.InsertBefore(sve, [&] { //
+ ClampOperand(sve, StoreVectorElement::kIndexOperandOffset,
+ b.Constant(u32(vec->Width() - 1u)));
+ });
+ }
+
+ // TODO(jrprice): Handle texture builtins.
+ // TODO(jrprice): Handle config.bindings_ignored.
+ // TODO(jrprice): Handle config.disable_runtime_sized_array_index_clamping.
+ }
+
+ /// Check if clamping should be applied to a particular address space.
+ /// @param addrspace the address space to check
+ /// @returns true if pointer accesses in @p param addrspace should be clamped
+ bool ShouldClamp(AddressSpace addrspace) {
+ switch (addrspace) {
+ case AddressSpace::kFunction:
+ return config.clamp_function;
+ case AddressSpace::kPrivate:
+ return config.clamp_private;
+ case AddressSpace::kPushConstant:
+ return config.clamp_push_constant;
+ case AddressSpace::kStorage:
+ return config.clamp_storage;
+ case AddressSpace::kUniform:
+ return config.clamp_uniform;
+ case AddressSpace::kWorkgroup:
+ return config.clamp_workgroup;
+ case AddressSpace::kUndefined:
+ case AddressSpace::kPixelLocal:
+ case AddressSpace::kHandle:
+ case AddressSpace::kIn:
+ case AddressSpace::kOut:
+ return false;
+ }
+ return false;
+ }
+
+ /// Clamp operand @p op_idx of @p inst to ensure it is within @p limit.
+ /// @param inst the instruction
+ /// @param op_idx the index of the operand that should be clamped
+ /// @param limit the limit to clamp to
+ void ClampOperand(ir::Instruction* inst, size_t op_idx, ir::Value* limit) {
+ auto* idx = inst->Operands()[op_idx];
+ auto* const_idx = idx->As<ir::Constant>();
+ auto* const_limit = limit->As<ir::Constant>();
+
+ ir::Value* clamped_idx = nullptr;
+ if (const_idx && const_limit) {
+ // Generate a new constant index that is clamped to the limit.
+ clamped_idx = b.Constant(u32(std::min(const_idx->Value()->ValueAs<uint32_t>(),
+ const_limit->Value()->ValueAs<uint32_t>())));
+ } else {
+ // Convert the index to u32 if needed.
+ if (idx->Type()->is_signed_integer_scalar()) {
+ idx = b.Convert(ty.u32(), idx)->Result();
+ }
+
+ // Clamp it to the dynamic limit.
+ clamped_idx = b.Call(ty.u32(), core::Function::kMin, idx, limit)->Result();
+ }
+
+ // Replace the index operand with the clamped version.
+ inst->SetOperand(op_idx, clamped_idx);
+ }
+
+ /// Clamp the indices of an access instruction to ensure they are within the limits of the types
+ /// that they are indexing into.
+ /// @param access the access instruction
+ void ClampAccessIndices(ir::Access* access) {
+ auto* type = access->Object()->Type()->UnwrapPtr();
+ auto indices = access->Indices();
+ for (size_t i = 0; i < indices.Length(); i++) {
+ auto* idx = indices[i];
+ auto* const_idx = idx->As<ir::Constant>();
+
+ // Determine the limit of the type being indexed into.
+ auto limit = tint::Switch(
+ type, //
+ [&](const type::Vector* vec) -> ir::Value* {
+ return b.Constant(u32(vec->Width() - 1u));
+ },
+ [&](const type::Matrix* mat) -> ir::Value* {
+ return b.Constant(u32(mat->columns() - 1u));
+ },
+ [&](const type::Array* arr) -> ir::Value* {
+ if (arr->ConstantCount()) {
+ return b.Constant(u32(arr->ConstantCount().value() - 1u));
+ }
+ TINT_ASSERT_OR_RETURN_VALUE(arr->Count()->Is<type::RuntimeArrayCount>(),
+ nullptr);
+
+ auto* object = access->Object();
+ if (i > 0) {
+ // Generate a pointer to the runtime-sized array if it isn't the base of
+ // this access instruction.
+ auto* base_ptr = object->Type()->As<type::Pointer>();
+ TINT_ASSERT_OR_RETURN_VALUE(base_ptr != nullptr, nullptr);
+ TINT_ASSERT_OR_RETURN_VALUE(i == 1, nullptr);
+ auto* arr_ptr = ty.ptr(base_ptr->AddressSpace(), arr, base_ptr->Access());
+ object = b.Access(arr_ptr, object, indices[0])->Result();
+ }
+
+ // Use the `arrayLength` builtin to get the limit of a runtime-sized array.
+ auto* length = b.Call(ty.u32(), core::Function::kArrayLength, object);
+ return b.Subtract(ty.u32(), length, b.Constant(1_u))->Result();
+ });
+
+ // If there's a dynamic limit that needs enforced, clamp the index operand.
+ if (limit) {
+ ClampOperand(access, ir::Access::kIndicesOperandOffset + i, limit);
+ }
+
+ // Get the type that this index produces.
+ type = const_idx ? type->Element(const_idx->Value()->ValueAs<u32>())
+ : type->Elements().type;
+ }
+ }
+};
+
+} // namespace
+
+Result<SuccessType, std::string> Robustness(Module* ir, const RobustnessConfig& config) {
+ auto result = ValidateAndDumpIfNeeded(*ir, "Robustness transform");
+ if (!result) {
+ return result;
+ }
+
+ State{config, ir}.Process();
+
+ return Success;
+}
+
+} // namespace tint::core::ir::transform
diff --git a/src/tint/lang/core/ir/transform/robustness.h b/src/tint/lang/core/ir/transform/robustness.h
new file mode 100644
index 0000000..82d0452
--- /dev/null
+++ b/src/tint/lang/core/ir/transform/robustness.h
@@ -0,0 +1,67 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_CORE_IR_TRANSFORM_ROBUSTNESS_H_
+#define SRC_TINT_LANG_CORE_IR_TRANSFORM_ROBUSTNESS_H_
+
+#include <string>
+#include <unordered_set>
+
+#include "src/tint/api/common/binding_point.h"
+#include "src/tint/utils/result/result.h"
+
+// Forward declarations.
+namespace tint::core::ir {
+class Module;
+}
+
+namespace tint::core::ir::transform {
+
+/// Configuration options that control when to clamp accesses.
+struct RobustnessConfig {
+ /// Should non-pointer accesses be clamped?
+ bool clamp_value = true;
+
+ /// Should texture accesses be clamped?
+ bool clamp_texture = true;
+
+ /// Should accesses to pointers with the 'function' address space be clamped?
+ bool clamp_function = true;
+ /// Should accesses to pointers with the 'private' address space be clamped?
+ bool clamp_private = true;
+ /// Should accesses to pointers with the 'push_constant' address space be clamped?
+ bool clamp_push_constant = true;
+ /// Should accesses to pointers with the 'storage' address space be clamped?
+ bool clamp_storage = true;
+ /// Should accesses to pointers with the 'uniform' address space be clamped?
+ bool clamp_uniform = true;
+ /// Should accesses to pointers with the 'workgroup' address space be clamped?
+ bool clamp_workgroup = true;
+
+ /// Bindings that should always be ignored.
+ std::unordered_set<tint::BindingPoint> bindings_ignored;
+
+ /// Should the transform skip index clamping on runtime-sized arrays?
+ bool disable_runtime_sized_array_index_clamping = false;
+};
+
+/// Robustness is a transform that prevents out-of-bounds memory accesses.
+/// @param module the module to transform
+/// @param config the robustness configuration
+/// @returns an error string on failure
+Result<SuccessType, std::string> Robustness(Module* module, const RobustnessConfig& config);
+
+} // namespace tint::core::ir::transform
+
+#endif // SRC_TINT_LANG_CORE_IR_TRANSFORM_ROBUSTNESS_H_
diff --git a/src/tint/lang/core/ir/transform/robustness_test.cc b/src/tint/lang/core/ir/transform/robustness_test.cc
new file mode 100644
index 0000000..6bda7fd
--- /dev/null
+++ b/src/tint/lang/core/ir/transform/robustness_test.cc
@@ -0,0 +1,1874 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/core/ir/transform/robustness.h"
+
+#include <utility>
+
+#include "src/tint/lang/core/ir/transform/helper_test.h"
+#include "src/tint/lang/core/type/array.h"
+#include "src/tint/lang/core/type/matrix.h"
+#include "src/tint/lang/core/type/pointer.h"
+#include "src/tint/lang/core/type/struct.h"
+#include "src/tint/lang/core/type/vector.h"
+
+namespace tint::core::ir::transform {
+namespace {
+
+using namespace tint::core::fluent_types; // NOLINT
+using namespace tint::core::number_suffixes; // NOLINT
+
+using IR_RobustnessTest = TransformTestWithParam<bool>;
+
+////////////////////////////////////////////////////////////////
+// These tests use the function address space.
+// Test clamping of vectors, matrices, and fixed-size arrays.
+// Test indices that are const, const-via-let, and dynamic.
+// Test signed vs unsigned indices.
+////////////////////////////////////////////////////////////////
+
+TEST_P(IR_RobustnessTest, VectorLoad_ConstIndex) {
+ auto* func = b.Function("foo", ty.u32());
+ b.Append(func->Block(), [&] {
+ auto* vec = b.Var("vec", ty.ptr(function, ty.vec4<u32>()));
+ auto* load = b.LoadVectorElement(vec, b.Constant(5_u));
+ b.Return(func, load);
+ });
+
+ auto* src = R"(
+%foo = func():u32 -> %b1 {
+ %b1 = block {
+ %vec:ptr<function, vec4<u32>, read_write> = var
+ %3:u32 = load_vector_element %vec, 5u
+ ret %3
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func():u32 -> %b1 {
+ %b1 = block {
+ %vec:ptr<function, vec4<u32>, read_write> = var
+ %3:u32 = load_vector_element %vec, 3u
+ ret %3
+ }
+}
+)";
+
+ RobustnessConfig cfg;
+ cfg.clamp_function = GetParam();
+ Run(Robustness, cfg);
+
+ EXPECT_EQ(GetParam() ? expect : src, str());
+}
+
+TEST_P(IR_RobustnessTest, VectorLoad_ConstIndexViaLet) {
+ auto* func = b.Function("foo", ty.u32());
+ b.Append(func->Block(), [&] {
+ auto* vec = b.Var("vec", ty.ptr(function, ty.vec4<u32>()));
+ auto* idx = b.Let("idx", b.Constant(5_u));
+ auto* load = b.LoadVectorElement(vec, idx);
+ b.Return(func, load);
+ });
+
+ auto* src = R"(
+%foo = func():u32 -> %b1 {
+ %b1 = block {
+ %vec:ptr<function, vec4<u32>, read_write> = var
+ %idx:u32 = let 5u
+ %4:u32 = load_vector_element %vec, %idx
+ ret %4
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func():u32 -> %b1 {
+ %b1 = block {
+ %vec:ptr<function, vec4<u32>, read_write> = var
+ %idx:u32 = let 5u
+ %4:u32 = min %idx, 3u
+ %5:u32 = load_vector_element %vec, %4
+ ret %5
+ }
+}
+)";
+
+ RobustnessConfig cfg;
+ cfg.clamp_function = GetParam();
+ Run(Robustness, cfg);
+
+ EXPECT_EQ(GetParam() ? expect : src, str());
+}
+
+TEST_P(IR_RobustnessTest, VectorLoad_DynamicIndex) {
+ auto* func = b.Function("foo", ty.u32());
+ auto* idx = b.FunctionParam("idx", ty.u32());
+ func->SetParams({idx});
+ b.Append(func->Block(), [&] {
+ auto* vec = b.Var("vec", ty.ptr(function, ty.vec4<u32>()));
+ auto* load = b.LoadVectorElement(vec, idx);
+ b.Return(func, load);
+ });
+
+ auto* src = R"(
+%foo = func(%idx:u32):u32 -> %b1 {
+ %b1 = block {
+ %vec:ptr<function, vec4<u32>, read_write> = var
+ %4:u32 = load_vector_element %vec, %idx
+ ret %4
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%idx:u32):u32 -> %b1 {
+ %b1 = block {
+ %vec:ptr<function, vec4<u32>, read_write> = var
+ %4:u32 = min %idx, 3u
+ %5:u32 = load_vector_element %vec, %4
+ ret %5
+ }
+}
+)";
+
+ RobustnessConfig cfg;
+ cfg.clamp_function = GetParam();
+ Run(Robustness, cfg);
+
+ EXPECT_EQ(GetParam() ? expect : src, str());
+}
+
+TEST_P(IR_RobustnessTest, VectorLoad_DynamicIndex_Signed) {
+ auto* func = b.Function("foo", ty.u32());
+ auto* idx = b.FunctionParam("idx", ty.i32());
+ func->SetParams({idx});
+ b.Append(func->Block(), [&] {
+ auto* vec = b.Var("vec", ty.ptr(function, ty.vec4<u32>()));
+ auto* load = b.LoadVectorElement(vec, idx);
+ b.Return(func, load);
+ });
+
+ auto* src = R"(
+%foo = func(%idx:i32):u32 -> %b1 {
+ %b1 = block {
+ %vec:ptr<function, vec4<u32>, read_write> = var
+ %4:u32 = load_vector_element %vec, %idx
+ ret %4
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%idx:i32):u32 -> %b1 {
+ %b1 = block {
+ %vec:ptr<function, vec4<u32>, read_write> = var
+ %4:u32 = convert %idx
+ %5:u32 = min %4, 3u
+ %6:u32 = load_vector_element %vec, %5
+ ret %6
+ }
+}
+)";
+
+ RobustnessConfig cfg;
+ cfg.clamp_function = GetParam();
+ Run(Robustness, cfg);
+
+ EXPECT_EQ(GetParam() ? expect : src, str());
+}
+
+TEST_P(IR_RobustnessTest, VectorStore_ConstIndex) {
+ auto* func = b.Function("foo", ty.void_());
+ b.Append(func->Block(), [&] {
+ auto* vec = b.Var("vec", ty.ptr(function, ty.vec4<u32>()));
+ b.StoreVectorElement(vec, b.Constant(5_u), b.Constant(0_u));
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%foo = func():void -> %b1 {
+ %b1 = block {
+ %vec:ptr<function, vec4<u32>, read_write> = var
+ store_vector_element %vec, 5u, 0u
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func():void -> %b1 {
+ %b1 = block {
+ %vec:ptr<function, vec4<u32>, read_write> = var
+ store_vector_element %vec, 3u, 0u
+ ret
+ }
+}
+)";
+
+ RobustnessConfig cfg;
+ cfg.clamp_function = GetParam();
+ Run(Robustness, cfg);
+
+ EXPECT_EQ(GetParam() ? expect : src, str());
+}
+
+TEST_P(IR_RobustnessTest, VectorStore_ConstIndexViaLet) {
+ auto* func = b.Function("foo", ty.void_());
+ b.Append(func->Block(), [&] {
+ auto* vec = b.Var("vec", ty.ptr(function, ty.vec4<u32>()));
+ auto* idx = b.Let("idx", b.Constant(5_u));
+ b.StoreVectorElement(vec, idx, b.Constant(0_u));
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%foo = func():void -> %b1 {
+ %b1 = block {
+ %vec:ptr<function, vec4<u32>, read_write> = var
+ %idx:u32 = let 5u
+ store_vector_element %vec, %idx, 0u
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func():void -> %b1 {
+ %b1 = block {
+ %vec:ptr<function, vec4<u32>, read_write> = var
+ %idx:u32 = let 5u
+ %4:u32 = min %idx, 3u
+ store_vector_element %vec, %4, 0u
+ ret
+ }
+}
+)";
+
+ RobustnessConfig cfg;
+ cfg.clamp_function = GetParam();
+ Run(Robustness, cfg);
+
+ EXPECT_EQ(GetParam() ? expect : src, str());
+}
+
+TEST_P(IR_RobustnessTest, VectorStore_DynamicIndex) {
+ auto* func = b.Function("foo", ty.void_());
+ auto* idx = b.FunctionParam("idx", ty.u32());
+ func->SetParams({idx});
+ b.Append(func->Block(), [&] {
+ auto* vec = b.Var("vec", ty.ptr(function, ty.vec4<u32>()));
+ b.StoreVectorElement(vec, idx, b.Constant(0_u));
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%foo = func(%idx:u32):void -> %b1 {
+ %b1 = block {
+ %vec:ptr<function, vec4<u32>, read_write> = var
+ store_vector_element %vec, %idx, 0u
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%idx:u32):void -> %b1 {
+ %b1 = block {
+ %vec:ptr<function, vec4<u32>, read_write> = var
+ %4:u32 = min %idx, 3u
+ store_vector_element %vec, %4, 0u
+ ret
+ }
+}
+)";
+
+ RobustnessConfig cfg;
+ cfg.clamp_function = GetParam();
+ Run(Robustness, cfg);
+
+ EXPECT_EQ(GetParam() ? expect : src, str());
+}
+
+TEST_P(IR_RobustnessTest, VectorStore_DynamicIndex_Signed) {
+ auto* func = b.Function("foo", ty.void_());
+ auto* idx = b.FunctionParam("idx", ty.i32());
+ func->SetParams({idx});
+ b.Append(func->Block(), [&] {
+ auto* vec = b.Var("vec", ty.ptr(function, ty.vec4<u32>()));
+ b.StoreVectorElement(vec, idx, b.Constant(0_u));
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%foo = func(%idx:i32):void -> %b1 {
+ %b1 = block {
+ %vec:ptr<function, vec4<u32>, read_write> = var
+ store_vector_element %vec, %idx, 0u
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%idx:i32):void -> %b1 {
+ %b1 = block {
+ %vec:ptr<function, vec4<u32>, read_write> = var
+ %4:u32 = convert %idx
+ %5:u32 = min %4, 3u
+ store_vector_element %vec, %5, 0u
+ ret
+ }
+}
+)";
+
+ RobustnessConfig cfg;
+ cfg.clamp_function = GetParam();
+ Run(Robustness, cfg);
+
+ EXPECT_EQ(GetParam() ? expect : src, str());
+}
+
+TEST_P(IR_RobustnessTest, Matrix_ConstIndex) {
+ auto* func = b.Function("foo", ty.vec4<f32>());
+ b.Append(func->Block(), [&] {
+ auto* mat = b.Var("mat", ty.ptr(function, ty.mat4x4<f32>()));
+ auto* access = b.Access(ty.ptr(function, ty.vec4<f32>()), mat, b.Constant(2_u));
+ auto* load = b.Load(access);
+ b.Return(func, load);
+ });
+
+ auto* src = R"(
+%foo = func():vec4<f32> -> %b1 {
+ %b1 = block {
+ %mat:ptr<function, mat4x4<f32>, read_write> = var
+ %3:ptr<function, vec4<f32>, read_write> = access %mat, 2u
+ %4:vec4<f32> = load %3
+ ret %4
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = src;
+
+ RobustnessConfig cfg;
+ cfg.clamp_function = GetParam();
+ Run(Robustness, cfg);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_P(IR_RobustnessTest, Matrix_ConstIndexViaLet) {
+ auto* func = b.Function("foo", ty.vec4<f32>());
+ b.Append(func->Block(), [&] {
+ auto* mat = b.Var("mat", ty.ptr(function, ty.mat4x4<f32>()));
+ auto* idx = b.Let("idx", b.Constant(2_u));
+ auto* access = b.Access(ty.ptr(function, ty.vec4<f32>()), mat, idx);
+ auto* load = b.Load(access);
+ b.Return(func, load);
+ });
+
+ auto* src = R"(
+%foo = func():vec4<f32> -> %b1 {
+ %b1 = block {
+ %mat:ptr<function, mat4x4<f32>, read_write> = var
+ %idx:u32 = let 2u
+ %4:ptr<function, vec4<f32>, read_write> = access %mat, %idx
+ %5:vec4<f32> = load %4
+ ret %5
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func():vec4<f32> -> %b1 {
+ %b1 = block {
+ %mat:ptr<function, mat4x4<f32>, read_write> = var
+ %idx:u32 = let 2u
+ %4:u32 = min %idx, 3u
+ %5:ptr<function, vec4<f32>, read_write> = access %mat, %4
+ %6:vec4<f32> = load %5
+ ret %6
+ }
+}
+)";
+
+ RobustnessConfig cfg;
+ cfg.clamp_function = GetParam();
+ Run(Robustness, cfg);
+
+ EXPECT_EQ(GetParam() ? expect : src, str());
+}
+
+TEST_P(IR_RobustnessTest, Matrix_DynamicIndex) {
+ auto* func = b.Function("foo", ty.vec4<f32>());
+ auto* idx = b.FunctionParam("idx", ty.u32());
+ func->SetParams({idx});
+ b.Append(func->Block(), [&] {
+ auto* mat = b.Var("mat", ty.ptr(function, ty.mat4x4<f32>()));
+ auto* access = b.Access(ty.ptr(function, ty.vec4<f32>()), mat, idx);
+ auto* load = b.Load(access);
+ b.Return(func, load);
+ });
+
+ auto* src = R"(
+%foo = func(%idx:u32):vec4<f32> -> %b1 {
+ %b1 = block {
+ %mat:ptr<function, mat4x4<f32>, read_write> = var
+ %4:ptr<function, vec4<f32>, read_write> = access %mat, %idx
+ %5:vec4<f32> = load %4
+ ret %5
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%idx:u32):vec4<f32> -> %b1 {
+ %b1 = block {
+ %mat:ptr<function, mat4x4<f32>, read_write> = var
+ %4:u32 = min %idx, 3u
+ %5:ptr<function, vec4<f32>, read_write> = access %mat, %4
+ %6:vec4<f32> = load %5
+ ret %6
+ }
+}
+)";
+
+ RobustnessConfig cfg;
+ cfg.clamp_function = GetParam();
+ Run(Robustness, cfg);
+
+ EXPECT_EQ(GetParam() ? expect : src, str());
+}
+
+TEST_P(IR_RobustnessTest, Matrix_DynamicIndex_Signed) {
+ auto* func = b.Function("foo", ty.vec4<f32>());
+ auto* idx = b.FunctionParam("idx", ty.i32());
+ func->SetParams({idx});
+ b.Append(func->Block(), [&] {
+ auto* mat = b.Var("mat", ty.ptr(function, ty.mat4x4<f32>()));
+ auto* access = b.Access(ty.ptr(function, ty.vec4<f32>()), mat, idx);
+ auto* load = b.Load(access);
+ b.Return(func, load);
+ });
+
+ auto* src = R"(
+%foo = func(%idx:i32):vec4<f32> -> %b1 {
+ %b1 = block {
+ %mat:ptr<function, mat4x4<f32>, read_write> = var
+ %4:ptr<function, vec4<f32>, read_write> = access %mat, %idx
+ %5:vec4<f32> = load %4
+ ret %5
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%idx:i32):vec4<f32> -> %b1 {
+ %b1 = block {
+ %mat:ptr<function, mat4x4<f32>, read_write> = var
+ %4:u32 = convert %idx
+ %5:u32 = min %4, 3u
+ %6:ptr<function, vec4<f32>, read_write> = access %mat, %5
+ %7:vec4<f32> = load %6
+ ret %7
+ }
+}
+)";
+
+ RobustnessConfig cfg;
+ cfg.clamp_function = GetParam();
+ Run(Robustness, cfg);
+
+ EXPECT_EQ(GetParam() ? expect : src, str());
+}
+
+TEST_P(IR_RobustnessTest, Array_ConstSize_ConstIndex) {
+ auto* func = b.Function("foo", ty.u32());
+ b.Append(func->Block(), [&] {
+ auto* arr = b.Var("arr", ty.ptr(function, ty.array<u32, 4>()));
+ auto* access = b.Access(ty.ptr<function, u32>(), arr, b.Constant(2_u));
+ auto* load = b.Load(access);
+ b.Return(func, load);
+ });
+
+ auto* src = R"(
+%foo = func():u32 -> %b1 {
+ %b1 = block {
+ %arr:ptr<function, array<u32, 4>, read_write> = var
+ %3:ptr<function, u32, read_write> = access %arr, 2u
+ %4:u32 = load %3
+ ret %4
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = src;
+
+ RobustnessConfig cfg;
+ cfg.clamp_function = GetParam();
+ Run(Robustness, cfg);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_P(IR_RobustnessTest, Array_ConstSize_ConstIndexViaLet) {
+ auto* func = b.Function("foo", ty.u32());
+ b.Append(func->Block(), [&] {
+ auto* arr = b.Var("arr", ty.ptr(function, ty.array<u32, 4>()));
+ auto* idx = b.Let("idx", b.Constant(2_u));
+ auto* access = b.Access(ty.ptr<function, u32>(), arr, idx);
+ auto* load = b.Load(access);
+ b.Return(func, load);
+ });
+
+ auto* src = R"(
+%foo = func():u32 -> %b1 {
+ %b1 = block {
+ %arr:ptr<function, array<u32, 4>, read_write> = var
+ %idx:u32 = let 2u
+ %4:ptr<function, u32, read_write> = access %arr, %idx
+ %5:u32 = load %4
+ ret %5
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func():u32 -> %b1 {
+ %b1 = block {
+ %arr:ptr<function, array<u32, 4>, read_write> = var
+ %idx:u32 = let 2u
+ %4:u32 = min %idx, 3u
+ %5:ptr<function, u32, read_write> = access %arr, %4
+ %6:u32 = load %5
+ ret %6
+ }
+}
+)";
+
+ RobustnessConfig cfg;
+ cfg.clamp_function = GetParam();
+ Run(Robustness, cfg);
+
+ EXPECT_EQ(GetParam() ? expect : src, str());
+}
+
+TEST_P(IR_RobustnessTest, Array_ConstSize_DynamicIndex) {
+ auto* func = b.Function("foo", ty.u32());
+ auto* idx = b.FunctionParam("idx", ty.u32());
+ func->SetParams({idx});
+ b.Append(func->Block(), [&] {
+ auto* arr = b.Var("arr", ty.ptr(function, ty.array<u32, 4>()));
+ auto* access = b.Access(ty.ptr<function, u32>(), arr, idx);
+ auto* load = b.Load(access);
+ b.Return(func, load);
+ });
+
+ auto* src = R"(
+%foo = func(%idx:u32):u32 -> %b1 {
+ %b1 = block {
+ %arr:ptr<function, array<u32, 4>, read_write> = var
+ %4:ptr<function, u32, read_write> = access %arr, %idx
+ %5:u32 = load %4
+ ret %5
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%idx:u32):u32 -> %b1 {
+ %b1 = block {
+ %arr:ptr<function, array<u32, 4>, read_write> = var
+ %4:u32 = min %idx, 3u
+ %5:ptr<function, u32, read_write> = access %arr, %4
+ %6:u32 = load %5
+ ret %6
+ }
+}
+)";
+
+ RobustnessConfig cfg;
+ cfg.clamp_function = GetParam();
+ Run(Robustness, cfg);
+
+ EXPECT_EQ(GetParam() ? expect : src, str());
+}
+
+TEST_P(IR_RobustnessTest, Array_ConstSize_DynamicIndex_Signed) {
+ auto* func = b.Function("foo", ty.u32());
+ auto* idx = b.FunctionParam("idx", ty.i32());
+ func->SetParams({idx});
+ b.Append(func->Block(), [&] {
+ auto* arr = b.Var("arr", ty.ptr(function, ty.array<u32, 4>()));
+ auto* access = b.Access(ty.ptr<function, u32>(), arr, idx);
+ auto* load = b.Load(access);
+ b.Return(func, load);
+ });
+
+ auto* src = R"(
+%foo = func(%idx:i32):u32 -> %b1 {
+ %b1 = block {
+ %arr:ptr<function, array<u32, 4>, read_write> = var
+ %4:ptr<function, u32, read_write> = access %arr, %idx
+ %5:u32 = load %4
+ ret %5
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%idx:i32):u32 -> %b1 {
+ %b1 = block {
+ %arr:ptr<function, array<u32, 4>, read_write> = var
+ %4:u32 = convert %idx
+ %5:u32 = min %4, 3u
+ %6:ptr<function, u32, read_write> = access %arr, %5
+ %7:u32 = load %6
+ ret %7
+ }
+}
+)";
+
+ RobustnessConfig cfg;
+ cfg.clamp_function = GetParam();
+ Run(Robustness, cfg);
+
+ EXPECT_EQ(GetParam() ? expect : src, str());
+}
+
+TEST_P(IR_RobustnessTest, NestedArrays) {
+ auto* func = b.Function("foo", ty.u32());
+ auto* idx1 = b.FunctionParam("idx1", ty.u32());
+ auto* idx2 = b.FunctionParam("idx2", ty.u32());
+ auto* idx3 = b.FunctionParam("idx3", ty.u32());
+ auto* idx4 = b.FunctionParam("idx4", ty.u32());
+ func->SetParams({idx1, idx2, idx3, idx4});
+ b.Append(func->Block(), [&] {
+ auto* arr = b.Var(
+ "arr", ty.ptr(function, ty.array(ty.array(ty.array(ty.array(ty.u32(), 4), 5), 6), 7)));
+ auto* access = b.Access(ty.ptr<function, u32>(), arr, idx1, idx2, idx3, idx4);
+ auto* load = b.Load(access);
+ b.Return(func, load);
+ });
+
+ auto* src = R"(
+%foo = func(%idx1:u32, %idx2:u32, %idx3:u32, %idx4:u32):u32 -> %b1 {
+ %b1 = block {
+ %arr:ptr<function, array<array<array<array<u32, 4>, 5>, 6>, 7>, read_write> = var
+ %7:ptr<function, u32, read_write> = access %arr, %idx1, %idx2, %idx3, %idx4
+ %8:u32 = load %7
+ ret %8
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%idx1:u32, %idx2:u32, %idx3:u32, %idx4:u32):u32 -> %b1 {
+ %b1 = block {
+ %arr:ptr<function, array<array<array<array<u32, 4>, 5>, 6>, 7>, read_write> = var
+ %7:u32 = min %idx1, 6u
+ %8:u32 = min %idx2, 5u
+ %9:u32 = min %idx3, 4u
+ %10:u32 = min %idx4, 3u
+ %11:ptr<function, u32, read_write> = access %arr, %7, %8, %9, %10
+ %12:u32 = load %11
+ ret %12
+ }
+}
+)";
+
+ RobustnessConfig cfg;
+ cfg.clamp_function = GetParam();
+ Run(Robustness, cfg);
+
+ EXPECT_EQ(GetParam() ? expect : src, str());
+}
+
+TEST_P(IR_RobustnessTest, NestedMixedTypes) {
+ auto* structure = ty.Struct(mod.symbols.Register("structure"),
+ {
+ {mod.symbols.Register("arr"), ty.array(ty.mat3x4<f32>(), 4)},
+ });
+ auto* func = b.Function("foo", ty.vec4<f32>());
+ auto* idx1 = b.FunctionParam("idx1", ty.u32());
+ auto* idx2 = b.FunctionParam("idx2", ty.u32());
+ auto* idx3 = b.FunctionParam("idx3", ty.u32());
+ func->SetParams({idx1, idx2, idx3});
+ b.Append(func->Block(), [&] {
+ auto* arr = b.Var("arr", ty.ptr(function, ty.array(structure, 8)));
+ auto* access =
+ b.Access(ty.ptr<function, vec4<f32>>(), arr, idx1, b.Constant(0_u), idx2, idx3);
+ auto* load = b.Load(access);
+ b.Return(func, load);
+ });
+
+ auto* src = R"(
+structure = struct @align(16) {
+ arr:array<mat3x4<f32>, 4> @offset(0)
+}
+
+%foo = func(%idx1:u32, %idx2:u32, %idx3:u32):vec4<f32> -> %b1 {
+ %b1 = block {
+ %arr:ptr<function, array<structure, 8>, read_write> = var
+ %6:ptr<function, vec4<f32>, read_write> = access %arr, %idx1, 0u, %idx2, %idx3
+ %7:vec4<f32> = load %6
+ ret %7
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+structure = struct @align(16) {
+ arr:array<mat3x4<f32>, 4> @offset(0)
+}
+
+%foo = func(%idx1:u32, %idx2:u32, %idx3:u32):vec4<f32> -> %b1 {
+ %b1 = block {
+ %arr:ptr<function, array<structure, 8>, read_write> = var
+ %6:u32 = min %idx1, 7u
+ %7:u32 = min %idx2, 3u
+ %8:u32 = min %idx3, 2u
+ %9:ptr<function, vec4<f32>, read_write> = access %arr, %6, 0u, %7, %8
+ %10:vec4<f32> = load %9
+ ret %10
+ }
+}
+)";
+
+ RobustnessConfig cfg;
+ cfg.clamp_function = GetParam();
+ Run(Robustness, cfg);
+
+ EXPECT_EQ(GetParam() ? expect : src, str());
+}
+
+////////////////////////////////////////////////////////////////
+// Test the clamp toggles for every other address space.
+////////////////////////////////////////////////////////////////
+
+TEST_P(IR_RobustnessTest, Private_LoadVectorElement) {
+ auto* vec = b.Var("vec", ty.ptr(private_, ty.vec4<u32>()));
+ b.RootBlock()->Append(vec);
+
+ auto* func = b.Function("foo", ty.u32());
+ auto* idx = b.FunctionParam("idx", ty.u32());
+ func->SetParams({idx});
+ b.Append(func->Block(), [&] {
+ auto* load = b.LoadVectorElement(vec, idx);
+ b.Return(func, load);
+ });
+
+ auto* src = R"(
+%b1 = block { # root
+ %vec:ptr<private, vec4<u32>, read_write> = var
+}
+
+%foo = func(%idx:u32):u32 -> %b2 {
+ %b2 = block {
+ %4:u32 = load_vector_element %vec, %idx
+ ret %4
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%b1 = block { # root
+ %vec:ptr<private, vec4<u32>, read_write> = var
+}
+
+%foo = func(%idx:u32):u32 -> %b2 {
+ %b2 = block {
+ %4:u32 = min %idx, 3u
+ %5:u32 = load_vector_element %vec, %4
+ ret %5
+ }
+}
+)";
+
+ RobustnessConfig cfg;
+ cfg.clamp_private = GetParam();
+ Run(Robustness, cfg);
+
+ EXPECT_EQ(GetParam() ? expect : src, str());
+}
+
+TEST_P(IR_RobustnessTest, Private_StoreVectorElement) {
+ auto* vec = b.Var("vec", ty.ptr(private_, ty.vec4<u32>()));
+ b.RootBlock()->Append(vec);
+
+ auto* func = b.Function("foo", ty.void_());
+ auto* idx = b.FunctionParam("idx", ty.u32());
+ func->SetParams({idx});
+ b.Append(func->Block(), [&] {
+ b.StoreVectorElement(vec, idx, b.Constant(0_u));
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%b1 = block { # root
+ %vec:ptr<private, vec4<u32>, read_write> = var
+}
+
+%foo = func(%idx:u32):void -> %b2 {
+ %b2 = block {
+ store_vector_element %vec, %idx, 0u
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%b1 = block { # root
+ %vec:ptr<private, vec4<u32>, read_write> = var
+}
+
+%foo = func(%idx:u32):void -> %b2 {
+ %b2 = block {
+ %4:u32 = min %idx, 3u
+ store_vector_element %vec, %4, 0u
+ ret
+ }
+}
+)";
+
+ RobustnessConfig cfg;
+ cfg.clamp_private = GetParam();
+ Run(Robustness, cfg);
+
+ EXPECT_EQ(GetParam() ? expect : src, str());
+}
+
+TEST_P(IR_RobustnessTest, Private_Access) {
+ auto* arr = b.Var("arr", ty.ptr(private_, ty.array<u32, 4>()));
+ b.RootBlock()->Append(arr);
+
+ auto* func = b.Function("foo", ty.u32());
+ auto* idx = b.FunctionParam("idx", ty.u32());
+ func->SetParams({idx});
+ b.Append(func->Block(), [&] {
+ auto* access = b.Access(ty.ptr<private_, u32>(), arr, idx);
+ auto* load = b.Load(access);
+ b.Return(func, load);
+ });
+
+ auto* src = R"(
+%b1 = block { # root
+ %arr:ptr<private, array<u32, 4>, read_write> = var
+}
+
+%foo = func(%idx:u32):u32 -> %b2 {
+ %b2 = block {
+ %4:ptr<private, u32, read_write> = access %arr, %idx
+ %5:u32 = load %4
+ ret %5
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%b1 = block { # root
+ %arr:ptr<private, array<u32, 4>, read_write> = var
+}
+
+%foo = func(%idx:u32):u32 -> %b2 {
+ %b2 = block {
+ %4:u32 = min %idx, 3u
+ %5:ptr<private, u32, read_write> = access %arr, %4
+ %6:u32 = load %5
+ ret %6
+ }
+}
+)";
+
+ RobustnessConfig cfg;
+ cfg.clamp_private = GetParam();
+ Run(Robustness, cfg);
+
+ EXPECT_EQ(GetParam() ? expect : src, str());
+}
+
+TEST_P(IR_RobustnessTest, PushConstant_LoadVectorElement) {
+ auto* vec = b.Var("vec", ty.ptr(push_constant, ty.vec4<u32>()));
+ b.RootBlock()->Append(vec);
+
+ auto* func = b.Function("foo", ty.u32());
+ auto* idx = b.FunctionParam("idx", ty.u32());
+ func->SetParams({idx});
+ b.Append(func->Block(), [&] {
+ auto* load = b.LoadVectorElement(vec, idx);
+ b.Return(func, load);
+ });
+
+ auto* src = R"(
+%b1 = block { # root
+ %vec:ptr<push_constant, vec4<u32>, read_write> = var
+}
+
+%foo = func(%idx:u32):u32 -> %b2 {
+ %b2 = block {
+ %4:u32 = load_vector_element %vec, %idx
+ ret %4
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%b1 = block { # root
+ %vec:ptr<push_constant, vec4<u32>, read_write> = var
+}
+
+%foo = func(%idx:u32):u32 -> %b2 {
+ %b2 = block {
+ %4:u32 = min %idx, 3u
+ %5:u32 = load_vector_element %vec, %4
+ ret %5
+ }
+}
+)";
+
+ RobustnessConfig cfg;
+ cfg.clamp_push_constant = GetParam();
+ Run(Robustness, cfg);
+
+ EXPECT_EQ(GetParam() ? expect : src, str());
+}
+
+TEST_P(IR_RobustnessTest, PushConstant_StoreVectorElement) {
+ auto* vec = b.Var("vec", ty.ptr(push_constant, ty.vec4<u32>()));
+ b.RootBlock()->Append(vec);
+
+ auto* func = b.Function("foo", ty.void_());
+ auto* idx = b.FunctionParam("idx", ty.u32());
+ func->SetParams({idx});
+ b.Append(func->Block(), [&] {
+ b.StoreVectorElement(vec, idx, b.Constant(0_u));
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%b1 = block { # root
+ %vec:ptr<push_constant, vec4<u32>, read_write> = var
+}
+
+%foo = func(%idx:u32):void -> %b2 {
+ %b2 = block {
+ store_vector_element %vec, %idx, 0u
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%b1 = block { # root
+ %vec:ptr<push_constant, vec4<u32>, read_write> = var
+}
+
+%foo = func(%idx:u32):void -> %b2 {
+ %b2 = block {
+ %4:u32 = min %idx, 3u
+ store_vector_element %vec, %4, 0u
+ ret
+ }
+}
+)";
+
+ RobustnessConfig cfg;
+ cfg.clamp_push_constant = GetParam();
+ Run(Robustness, cfg);
+
+ EXPECT_EQ(GetParam() ? expect : src, str());
+}
+
+TEST_P(IR_RobustnessTest, PushConstant_Access) {
+ auto* arr = b.Var("arr", ty.ptr(push_constant, ty.array<u32, 4>()));
+ b.RootBlock()->Append(arr);
+
+ auto* func = b.Function("foo", ty.u32());
+ auto* idx = b.FunctionParam("idx", ty.u32());
+ func->SetParams({idx});
+ b.Append(func->Block(), [&] {
+ auto* access = b.Access(ty.ptr<push_constant, u32>(), arr, idx);
+ auto* load = b.Load(access);
+ b.Return(func, load);
+ });
+
+ auto* src = R"(
+%b1 = block { # root
+ %arr:ptr<push_constant, array<u32, 4>, read_write> = var
+}
+
+%foo = func(%idx:u32):u32 -> %b2 {
+ %b2 = block {
+ %4:ptr<push_constant, u32, read_write> = access %arr, %idx
+ %5:u32 = load %4
+ ret %5
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%b1 = block { # root
+ %arr:ptr<push_constant, array<u32, 4>, read_write> = var
+}
+
+%foo = func(%idx:u32):u32 -> %b2 {
+ %b2 = block {
+ %4:u32 = min %idx, 3u
+ %5:ptr<push_constant, u32, read_write> = access %arr, %4
+ %6:u32 = load %5
+ ret %6
+ }
+}
+)";
+
+ RobustnessConfig cfg;
+ cfg.clamp_push_constant = GetParam();
+ Run(Robustness, cfg);
+
+ EXPECT_EQ(GetParam() ? expect : src, str());
+}
+
+TEST_P(IR_RobustnessTest, Storage_LoadVectorElement) {
+ auto* vec = b.Var("vec", ty.ptr(storage, ty.vec4<u32>()));
+ vec->SetBindingPoint(0, 0);
+ b.RootBlock()->Append(vec);
+
+ auto* func = b.Function("foo", ty.u32());
+ auto* idx = b.FunctionParam("idx", ty.u32());
+ func->SetParams({idx});
+ b.Append(func->Block(), [&] {
+ auto* load = b.LoadVectorElement(vec, idx);
+ b.Return(func, load);
+ });
+
+ auto* src = R"(
+%b1 = block { # root
+ %vec:ptr<storage, vec4<u32>, read_write> = var @binding_point(0, 0)
+}
+
+%foo = func(%idx:u32):u32 -> %b2 {
+ %b2 = block {
+ %4:u32 = load_vector_element %vec, %idx
+ ret %4
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%b1 = block { # root
+ %vec:ptr<storage, vec4<u32>, read_write> = var @binding_point(0, 0)
+}
+
+%foo = func(%idx:u32):u32 -> %b2 {
+ %b2 = block {
+ %4:u32 = min %idx, 3u
+ %5:u32 = load_vector_element %vec, %4
+ ret %5
+ }
+}
+)";
+
+ RobustnessConfig cfg;
+ cfg.clamp_storage = GetParam();
+ Run(Robustness, cfg);
+
+ EXPECT_EQ(GetParam() ? expect : src, str());
+}
+
+TEST_P(IR_RobustnessTest, Storage_StoreVectorElement) {
+ auto* vec = b.Var("vec", ty.ptr(storage, ty.vec4<u32>()));
+ vec->SetBindingPoint(0, 0);
+ b.RootBlock()->Append(vec);
+
+ auto* func = b.Function("foo", ty.void_());
+ auto* idx = b.FunctionParam("idx", ty.u32());
+ func->SetParams({idx});
+ b.Append(func->Block(), [&] {
+ b.StoreVectorElement(vec, idx, b.Constant(0_u));
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%b1 = block { # root
+ %vec:ptr<storage, vec4<u32>, read_write> = var @binding_point(0, 0)
+}
+
+%foo = func(%idx:u32):void -> %b2 {
+ %b2 = block {
+ store_vector_element %vec, %idx, 0u
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%b1 = block { # root
+ %vec:ptr<storage, vec4<u32>, read_write> = var @binding_point(0, 0)
+}
+
+%foo = func(%idx:u32):void -> %b2 {
+ %b2 = block {
+ %4:u32 = min %idx, 3u
+ store_vector_element %vec, %4, 0u
+ ret
+ }
+}
+)";
+
+ RobustnessConfig cfg;
+ cfg.clamp_storage = GetParam();
+ Run(Robustness, cfg);
+
+ EXPECT_EQ(GetParam() ? expect : src, str());
+}
+
+TEST_P(IR_RobustnessTest, Storage_Access) {
+ auto* arr = b.Var("arr", ty.ptr(storage, ty.array<u32, 4>()));
+ arr->SetBindingPoint(0, 0);
+ b.RootBlock()->Append(arr);
+
+ auto* func = b.Function("foo", ty.u32());
+ auto* idx = b.FunctionParam("idx", ty.u32());
+ func->SetParams({idx});
+ b.Append(func->Block(), [&] {
+ auto* access = b.Access(ty.ptr<storage, u32>(), arr, idx);
+ auto* load = b.Load(access);
+ b.Return(func, load);
+ });
+
+ auto* src = R"(
+%b1 = block { # root
+ %arr:ptr<storage, array<u32, 4>, read_write> = var @binding_point(0, 0)
+}
+
+%foo = func(%idx:u32):u32 -> %b2 {
+ %b2 = block {
+ %4:ptr<storage, u32, read_write> = access %arr, %idx
+ %5:u32 = load %4
+ ret %5
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%b1 = block { # root
+ %arr:ptr<storage, array<u32, 4>, read_write> = var @binding_point(0, 0)
+}
+
+%foo = func(%idx:u32):u32 -> %b2 {
+ %b2 = block {
+ %4:u32 = min %idx, 3u
+ %5:ptr<storage, u32, read_write> = access %arr, %4
+ %6:u32 = load %5
+ ret %6
+ }
+}
+)";
+
+ RobustnessConfig cfg;
+ cfg.clamp_storage = GetParam();
+ Run(Robustness, cfg);
+
+ EXPECT_EQ(GetParam() ? expect : src, str());
+}
+
+TEST_P(IR_RobustnessTest, Unifom_LoadVectorElement) {
+ auto* vec = b.Var("vec", ty.ptr(uniform, ty.vec4<u32>()));
+ vec->SetBindingPoint(0, 0);
+ b.RootBlock()->Append(vec);
+
+ auto* func = b.Function("foo", ty.u32());
+ auto* idx = b.FunctionParam("idx", ty.u32());
+ func->SetParams({idx});
+ b.Append(func->Block(), [&] {
+ auto* load = b.LoadVectorElement(vec, idx);
+ b.Return(func, load);
+ });
+
+ auto* src = R"(
+%b1 = block { # root
+ %vec:ptr<uniform, vec4<u32>, read_write> = var @binding_point(0, 0)
+}
+
+%foo = func(%idx:u32):u32 -> %b2 {
+ %b2 = block {
+ %4:u32 = load_vector_element %vec, %idx
+ ret %4
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%b1 = block { # root
+ %vec:ptr<uniform, vec4<u32>, read_write> = var @binding_point(0, 0)
+}
+
+%foo = func(%idx:u32):u32 -> %b2 {
+ %b2 = block {
+ %4:u32 = min %idx, 3u
+ %5:u32 = load_vector_element %vec, %4
+ ret %5
+ }
+}
+)";
+
+ RobustnessConfig cfg;
+ cfg.clamp_uniform = GetParam();
+ Run(Robustness, cfg);
+
+ EXPECT_EQ(GetParam() ? expect : src, str());
+}
+
+TEST_P(IR_RobustnessTest, Unifom_StoreVectorElement) {
+ auto* vec = b.Var("vec", ty.ptr(uniform, ty.vec4<u32>()));
+ vec->SetBindingPoint(0, 0);
+ b.RootBlock()->Append(vec);
+
+ auto* func = b.Function("foo", ty.void_());
+ auto* idx = b.FunctionParam("idx", ty.u32());
+ func->SetParams({idx});
+ b.Append(func->Block(), [&] {
+ b.StoreVectorElement(vec, idx, b.Constant(0_u));
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%b1 = block { # root
+ %vec:ptr<uniform, vec4<u32>, read_write> = var @binding_point(0, 0)
+}
+
+%foo = func(%idx:u32):void -> %b2 {
+ %b2 = block {
+ store_vector_element %vec, %idx, 0u
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%b1 = block { # root
+ %vec:ptr<uniform, vec4<u32>, read_write> = var @binding_point(0, 0)
+}
+
+%foo = func(%idx:u32):void -> %b2 {
+ %b2 = block {
+ %4:u32 = min %idx, 3u
+ store_vector_element %vec, %4, 0u
+ ret
+ }
+}
+)";
+
+ RobustnessConfig cfg;
+ cfg.clamp_uniform = GetParam();
+ Run(Robustness, cfg);
+
+ EXPECT_EQ(GetParam() ? expect : src, str());
+}
+
+TEST_P(IR_RobustnessTest, Unifom_Access) {
+ auto* arr = b.Var("arr", ty.ptr(uniform, ty.array<u32, 4>()));
+ arr->SetBindingPoint(0, 0);
+ b.RootBlock()->Append(arr);
+
+ auto* func = b.Function("foo", ty.u32());
+ auto* idx = b.FunctionParam("idx", ty.u32());
+ func->SetParams({idx});
+ b.Append(func->Block(), [&] {
+ auto* access = b.Access(ty.ptr<uniform, u32>(), arr, idx);
+ auto* load = b.Load(access);
+ b.Return(func, load);
+ });
+
+ auto* src = R"(
+%b1 = block { # root
+ %arr:ptr<uniform, array<u32, 4>, read_write> = var @binding_point(0, 0)
+}
+
+%foo = func(%idx:u32):u32 -> %b2 {
+ %b2 = block {
+ %4:ptr<uniform, u32, read_write> = access %arr, %idx
+ %5:u32 = load %4
+ ret %5
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%b1 = block { # root
+ %arr:ptr<uniform, array<u32, 4>, read_write> = var @binding_point(0, 0)
+}
+
+%foo = func(%idx:u32):u32 -> %b2 {
+ %b2 = block {
+ %4:u32 = min %idx, 3u
+ %5:ptr<uniform, u32, read_write> = access %arr, %4
+ %6:u32 = load %5
+ ret %6
+ }
+}
+)";
+
+ RobustnessConfig cfg;
+ cfg.clamp_uniform = GetParam();
+ Run(Robustness, cfg);
+
+ EXPECT_EQ(GetParam() ? expect : src, str());
+}
+
+TEST_P(IR_RobustnessTest, Workgroup_LoadVectorElement) {
+ auto* vec = b.Var("vec", ty.ptr(workgroup, ty.vec4<u32>()));
+ b.RootBlock()->Append(vec);
+
+ auto* func = b.Function("foo", ty.u32());
+ auto* idx = b.FunctionParam("idx", ty.u32());
+ func->SetParams({idx});
+ b.Append(func->Block(), [&] {
+ auto* load = b.LoadVectorElement(vec, idx);
+ b.Return(func, load);
+ });
+
+ auto* src = R"(
+%b1 = block { # root
+ %vec:ptr<workgroup, vec4<u32>, read_write> = var
+}
+
+%foo = func(%idx:u32):u32 -> %b2 {
+ %b2 = block {
+ %4:u32 = load_vector_element %vec, %idx
+ ret %4
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%b1 = block { # root
+ %vec:ptr<workgroup, vec4<u32>, read_write> = var
+}
+
+%foo = func(%idx:u32):u32 -> %b2 {
+ %b2 = block {
+ %4:u32 = min %idx, 3u
+ %5:u32 = load_vector_element %vec, %4
+ ret %5
+ }
+}
+)";
+
+ RobustnessConfig cfg;
+ cfg.clamp_workgroup = GetParam();
+ Run(Robustness, cfg);
+
+ EXPECT_EQ(GetParam() ? expect : src, str());
+}
+
+TEST_P(IR_RobustnessTest, Workgroup_StoreVectorElement) {
+ auto* vec = b.Var("vec", ty.ptr(workgroup, ty.vec4<u32>()));
+ b.RootBlock()->Append(vec);
+
+ auto* func = b.Function("foo", ty.void_());
+ auto* idx = b.FunctionParam("idx", ty.u32());
+ func->SetParams({idx});
+ b.Append(func->Block(), [&] {
+ b.StoreVectorElement(vec, idx, b.Constant(0_u));
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%b1 = block { # root
+ %vec:ptr<workgroup, vec4<u32>, read_write> = var
+}
+
+%foo = func(%idx:u32):void -> %b2 {
+ %b2 = block {
+ store_vector_element %vec, %idx, 0u
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%b1 = block { # root
+ %vec:ptr<workgroup, vec4<u32>, read_write> = var
+}
+
+%foo = func(%idx:u32):void -> %b2 {
+ %b2 = block {
+ %4:u32 = min %idx, 3u
+ store_vector_element %vec, %4, 0u
+ ret
+ }
+}
+)";
+
+ RobustnessConfig cfg;
+ cfg.clamp_workgroup = GetParam();
+ Run(Robustness, cfg);
+
+ EXPECT_EQ(GetParam() ? expect : src, str());
+}
+
+TEST_P(IR_RobustnessTest, Workgroup_Access) {
+ auto* arr = b.Var("arr", ty.ptr(workgroup, ty.array<u32, 4>()));
+ b.RootBlock()->Append(arr);
+
+ auto* func = b.Function("foo", ty.u32());
+ auto* idx = b.FunctionParam("idx", ty.u32());
+ func->SetParams({idx});
+ b.Append(func->Block(), [&] {
+ auto* access = b.Access(ty.ptr<workgroup, u32>(), arr, idx);
+ auto* load = b.Load(access);
+ b.Return(func, load);
+ });
+
+ auto* src = R"(
+%b1 = block { # root
+ %arr:ptr<workgroup, array<u32, 4>, read_write> = var
+}
+
+%foo = func(%idx:u32):u32 -> %b2 {
+ %b2 = block {
+ %4:ptr<workgroup, u32, read_write> = access %arr, %idx
+ %5:u32 = load %4
+ ret %5
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%b1 = block { # root
+ %arr:ptr<workgroup, array<u32, 4>, read_write> = var
+}
+
+%foo = func(%idx:u32):u32 -> %b2 {
+ %b2 = block {
+ %4:u32 = min %idx, 3u
+ %5:ptr<workgroup, u32, read_write> = access %arr, %4
+ %6:u32 = load %5
+ ret %6
+ }
+}
+)";
+
+ RobustnessConfig cfg;
+ cfg.clamp_workgroup = GetParam();
+ Run(Robustness, cfg);
+
+ EXPECT_EQ(GetParam() ? expect : src, str());
+}
+
+////////////////////////////////////////////////////////////////
+// Test clamping non-pointer values.
+////////////////////////////////////////////////////////////////
+
+TEST_P(IR_RobustnessTest, ConstantVector_DynamicIndex) {
+ auto* func = b.Function("foo", ty.u32());
+ auto* idx = b.FunctionParam("idx", ty.u32());
+ func->SetParams({idx});
+ b.Append(func->Block(), [&] {
+ auto* vec = mod.constant_values.Composite(ty.vec4<u32>(), Vector{
+ mod.constant_values.Get(1_u),
+ mod.constant_values.Get(2_u),
+ mod.constant_values.Get(3_u),
+ mod.constant_values.Get(4_u),
+ });
+ auto* element = b.Access(ty.u32(), b.Constant(vec), idx);
+ b.Return(func, element);
+ });
+
+ auto* src = R"(
+%foo = func(%idx:u32):u32 -> %b1 {
+ %b1 = block {
+ %3:u32 = access vec4<u32>(1u, 2u, 3u, 4u), %idx
+ ret %3
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%idx:u32):u32 -> %b1 {
+ %b1 = block {
+ %3:u32 = min %idx, 3u
+ %4:u32 = access vec4<u32>(1u, 2u, 3u, 4u), %3
+ ret %4
+ }
+}
+)";
+
+ RobustnessConfig cfg;
+ cfg.clamp_value = GetParam();
+ Run(Robustness, cfg);
+
+ EXPECT_EQ(GetParam() ? expect : src, str());
+}
+
+TEST_P(IR_RobustnessTest, ConstantArray_DynamicIndex) {
+ auto* func = b.Function("foo", ty.u32());
+ auto* idx = b.FunctionParam("idx", ty.u32());
+ func->SetParams({idx});
+ b.Append(func->Block(), [&] {
+ auto* arr =
+ mod.constant_values.Composite(ty.array<u32, 4>(), Vector{
+ mod.constant_values.Get(1_u),
+ mod.constant_values.Get(2_u),
+ mod.constant_values.Get(3_u),
+ mod.constant_values.Get(4_u),
+ });
+ auto* element = b.Access(ty.u32(), b.Constant(arr), idx);
+ b.Return(func, element);
+ });
+
+ auto* src = R"(
+%foo = func(%idx:u32):u32 -> %b1 {
+ %b1 = block {
+ %3:u32 = access array<u32, 4>(1u, 2u, 3u, 4u), %idx
+ ret %3
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%idx:u32):u32 -> %b1 {
+ %b1 = block {
+ %3:u32 = min %idx, 3u
+ %4:u32 = access array<u32, 4>(1u, 2u, 3u, 4u), %3
+ ret %4
+ }
+}
+)";
+
+ RobustnessConfig cfg;
+ cfg.clamp_value = GetParam();
+ Run(Robustness, cfg);
+
+ EXPECT_EQ(GetParam() ? expect : src, str());
+}
+
+TEST_P(IR_RobustnessTest, ParamValueArray_DynamicIndex) {
+ auto* func = b.Function("foo", ty.u32());
+ auto* arr = b.FunctionParam("arr", ty.array<u32, 4>());
+ auto* idx = b.FunctionParam("idx", ty.u32());
+ func->SetParams({arr, idx});
+ b.Append(func->Block(), [&] {
+ auto* element = b.Access(ty.u32(), arr, idx);
+ b.Return(func, element);
+ });
+
+ auto* src = R"(
+%foo = func(%arr:array<u32, 4>, %idx:u32):u32 -> %b1 {
+ %b1 = block {
+ %4:u32 = access %arr, %idx
+ ret %4
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%arr:array<u32, 4>, %idx:u32):u32 -> %b1 {
+ %b1 = block {
+ %4:u32 = min %idx, 3u
+ %5:u32 = access %arr, %4
+ ret %5
+ }
+}
+)";
+
+ RobustnessConfig cfg;
+ cfg.clamp_value = GetParam();
+ Run(Robustness, cfg);
+
+ EXPECT_EQ(GetParam() ? expect : src, str());
+}
+
+INSTANTIATE_TEST_SUITE_P(, IR_RobustnessTest, testing::Values(false, true));
+
+////////////////////////////////////////////////////////////////
+// Test clamping non-pointer arrays.
+////////////////////////////////////////////////////////////////
+
+TEST_P(IR_RobustnessTest, RuntimeSizedArray_ConstIndex) {
+ auto* arr = b.Var("arr", ty.ptr(storage, ty.array<u32>()));
+ arr->SetBindingPoint(0, 0);
+ b.RootBlock()->Append(arr);
+
+ auto* func = b.Function("foo", ty.u32());
+ b.Append(func->Block(), [&] {
+ auto* access = b.Access(ty.ptr<storage, u32>(), arr, b.Constant(42_u));
+ auto* load = b.Load(access);
+ b.Return(func, load);
+ });
+
+ auto* src = R"(
+%b1 = block { # root
+ %arr:ptr<storage, array<u32>, read_write> = var @binding_point(0, 0)
+}
+
+%foo = func():u32 -> %b2 {
+ %b2 = block {
+ %3:ptr<storage, u32, read_write> = access %arr, 42u
+ %4:u32 = load %3
+ ret %4
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%b1 = block { # root
+ %arr:ptr<storage, array<u32>, read_write> = var @binding_point(0, 0)
+}
+
+%foo = func():u32 -> %b2 {
+ %b2 = block {
+ %3:u32 = arrayLength %arr
+ %4:u32 = sub %3, 1u
+ %5:u32 = min 42u, %4
+ %6:ptr<storage, u32, read_write> = access %arr, %5
+ %7:u32 = load %6
+ ret %7
+ }
+}
+)";
+
+ RobustnessConfig cfg;
+ cfg.clamp_storage = GetParam();
+ Run(Robustness, cfg);
+
+ EXPECT_EQ(GetParam() ? expect : src, str());
+}
+
+TEST_P(IR_RobustnessTest, RuntimeSizedArray_DynamicIndex) {
+ auto* arr = b.Var("arr", ty.ptr(storage, ty.array<u32>()));
+ arr->SetBindingPoint(0, 0);
+ b.RootBlock()->Append(arr);
+
+ auto* func = b.Function("foo", ty.u32());
+ auto* idx = b.FunctionParam("idx", ty.u32());
+ func->SetParams({idx});
+ b.Append(func->Block(), [&] {
+ auto* access = b.Access(ty.ptr<storage, u32>(), arr, idx);
+ auto* load = b.Load(access);
+ b.Return(func, load);
+ });
+
+ auto* src = R"(
+%b1 = block { # root
+ %arr:ptr<storage, array<u32>, read_write> = var @binding_point(0, 0)
+}
+
+%foo = func(%idx:u32):u32 -> %b2 {
+ %b2 = block {
+ %4:ptr<storage, u32, read_write> = access %arr, %idx
+ %5:u32 = load %4
+ ret %5
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%b1 = block { # root
+ %arr:ptr<storage, array<u32>, read_write> = var @binding_point(0, 0)
+}
+
+%foo = func(%idx:u32):u32 -> %b2 {
+ %b2 = block {
+ %4:u32 = arrayLength %arr
+ %5:u32 = sub %4, 1u
+ %6:u32 = min %idx, %5
+ %7:ptr<storage, u32, read_write> = access %arr, %6
+ %8:u32 = load %7
+ ret %8
+ }
+}
+)";
+
+ RobustnessConfig cfg;
+ cfg.clamp_storage = GetParam();
+ Run(Robustness, cfg);
+
+ EXPECT_EQ(GetParam() ? expect : src, str());
+}
+
+TEST_P(IR_RobustnessTest, RuntimeSizedArray_InStruct_ConstIndex) {
+ auto* structure = ty.Struct(mod.symbols.Register("structure"),
+ {
+ {mod.symbols.Register("arr"), ty.array<u32>()},
+ });
+
+ auto* buffer = b.Var("buffer", ty.ptr(storage, structure));
+ buffer->SetBindingPoint(0, 0);
+ b.RootBlock()->Append(buffer);
+
+ auto* func = b.Function("foo", ty.u32());
+ b.Append(func->Block(), [&] {
+ auto* access = b.Access(ty.ptr<storage, u32>(), buffer, b.Constant(0_u), b.Constant(42_u));
+ auto* load = b.Load(access);
+ b.Return(func, load);
+ });
+
+ auto* src = R"(
+structure = struct @align(4) {
+ arr:array<u32> @offset(0)
+}
+
+%b1 = block { # root
+ %buffer:ptr<storage, structure, read_write> = var @binding_point(0, 0)
+}
+
+%foo = func():u32 -> %b2 {
+ %b2 = block {
+ %3:ptr<storage, u32, read_write> = access %buffer, 0u, 42u
+ %4:u32 = load %3
+ ret %4
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+structure = struct @align(4) {
+ arr:array<u32> @offset(0)
+}
+
+%b1 = block { # root
+ %buffer:ptr<storage, structure, read_write> = var @binding_point(0, 0)
+}
+
+%foo = func():u32 -> %b2 {
+ %b2 = block {
+ %3:ptr<storage, array<u32>, read_write> = access %buffer, 0u
+ %4:u32 = arrayLength %3
+ %5:u32 = sub %4, 1u
+ %6:u32 = min 42u, %5
+ %7:ptr<storage, u32, read_write> = access %buffer, 0u, %6
+ %8:u32 = load %7
+ ret %8
+ }
+}
+)";
+
+ RobustnessConfig cfg;
+ cfg.clamp_storage = GetParam();
+ Run(Robustness, cfg);
+
+ EXPECT_EQ(GetParam() ? expect : src, str());
+}
+
+TEST_P(IR_RobustnessTest, RuntimeSizedArray_InStruct_DynamicIndex) {
+ auto* structure = ty.Struct(mod.symbols.Register("structure"),
+ {
+ {mod.symbols.Register("arr"), ty.array<u32>()},
+ });
+
+ auto* buffer = b.Var("buffer", ty.ptr(storage, structure));
+ buffer->SetBindingPoint(0, 0);
+ b.RootBlock()->Append(buffer);
+
+ auto* func = b.Function("foo", ty.u32());
+ auto* idx = b.FunctionParam("idx", ty.u32());
+ func->SetParams({idx});
+ b.Append(func->Block(), [&] {
+ auto* access = b.Access(ty.ptr<storage, u32>(), buffer, b.Constant(0_u), idx);
+ auto* load = b.Load(access);
+ b.Return(func, load);
+ });
+
+ auto* src = R"(
+structure = struct @align(4) {
+ arr:array<u32> @offset(0)
+}
+
+%b1 = block { # root
+ %buffer:ptr<storage, structure, read_write> = var @binding_point(0, 0)
+}
+
+%foo = func(%idx:u32):u32 -> %b2 {
+ %b2 = block {
+ %4:ptr<storage, u32, read_write> = access %buffer, 0u, %idx
+ %5:u32 = load %4
+ ret %5
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+structure = struct @align(4) {
+ arr:array<u32> @offset(0)
+}
+
+%b1 = block { # root
+ %buffer:ptr<storage, structure, read_write> = var @binding_point(0, 0)
+}
+
+%foo = func(%idx:u32):u32 -> %b2 {
+ %b2 = block {
+ %4:ptr<storage, array<u32>, read_write> = access %buffer, 0u
+ %5:u32 = arrayLength %4
+ %6:u32 = sub %5, 1u
+ %7:u32 = min %idx, %6
+ %8:ptr<storage, u32, read_write> = access %buffer, 0u, %7
+ %9:u32 = load %8
+ ret %9
+ }
+}
+)";
+
+ RobustnessConfig cfg;
+ cfg.clamp_storage = GetParam();
+ Run(Robustness, cfg);
+
+ EXPECT_EQ(GetParam() ? expect : src, str());
+}
+
+} // namespace
+} // namespace tint::core::ir::transform
diff --git a/src/tint/lang/spirv/writer/access_test.cc b/src/tint/lang/spirv/writer/access_test.cc
index cbfa1bb..ffdd185 100644
--- a/src/tint/lang/spirv/writer/access_test.cc
+++ b/src/tint/lang/spirv/writer/access_test.cc
@@ -59,7 +59,11 @@
});
ASSERT_TRUE(Generate()) << Error() << output_;
- EXPECT_INST("%result = OpAccessChain %_ptr_Function_int %arr %idx");
+ EXPECT_INST(R"(
+ %13 = OpBitcast %uint %idx
+ %14 = OpExtInst %uint %15 UMin %13 %uint_3
+ %result = OpAccessChain %_ptr_Function_int %arr %14
+)");
}
TEST_F(SpirvWriterTest, Access_Matrix_Value_ConstantIndex) {
@@ -110,9 +114,15 @@
});
ASSERT_TRUE(Generate()) << Error() << output_;
- EXPECT_INST("%result_vector = OpAccessChain %_ptr_Function_v2float %mat %idx");
- EXPECT_INST("%15 = OpAccessChain %_ptr_Function_float %result_vector %idx");
- EXPECT_INST("%result_scalar = OpLoad %float %15");
+ EXPECT_INST(R"(
+ %14 = OpBitcast %uint %idx
+ %15 = OpExtInst %uint %16 UMin %14 %uint_1
+%result_vector = OpAccessChain %_ptr_Function_v2float %mat %15
+ %20 = OpBitcast %uint %idx
+ %21 = OpExtInst %uint %16 UMin %20 %uint_1
+ %22 = OpAccessChain %_ptr_Function_float %result_vector %21
+%result_scalar = OpLoad %float %22
+)");
}
TEST_F(SpirvWriterTest, Access_Vector_Value_ConstantIndex) {
@@ -141,7 +151,11 @@
});
ASSERT_TRUE(Generate()) << Error() << output_;
- EXPECT_INST("%result = OpVectorExtractDynamic %int %vec %idx");
+ EXPECT_INST(R"(
+ %10 = OpBitcast %uint %idx
+ %11 = OpExtInst %uint %12 UMin %10 %uint_3
+ %result = OpVectorExtractDynamic %int %vec %11
+)");
}
TEST_F(SpirvWriterTest, Access_NestedVector_Value_DynamicIndex) {
@@ -156,8 +170,12 @@
});
ASSERT_TRUE(Generate()) << Error() << output_;
- EXPECT_INST("%14 = OpCompositeExtract %v4int %arr 1 2");
- EXPECT_INST("%result = OpVectorExtractDynamic %int %14 %idx");
+ EXPECT_INST(R"(
+ %13 = OpBitcast %uint %idx
+ %14 = OpExtInst %uint %15 UMin %13 %uint_3
+ %18 = OpCompositeExtract %v4int %arr 1 2
+ %result = OpVectorExtractDynamic %int %18 %14
+)");
}
TEST_F(SpirvWriterTest, Access_Struct_Value_ConstantIndex) {
@@ -229,8 +247,12 @@
});
ASSERT_TRUE(Generate()) << Error() << output_;
- EXPECT_INST("%11 = OpAccessChain %_ptr_Function_int %vec %idx");
- EXPECT_INST("%result = OpLoad %int %11");
+ EXPECT_INST(R"(
+ %12 = OpBitcast %uint %idx
+ %13 = OpExtInst %uint %14 UMin %12 %uint_3
+ %16 = OpAccessChain %_ptr_Function_int %vec %13
+ %result = OpLoad %int %16
+)");
}
TEST_F(SpirvWriterTest, StoreVectorElement_ConstantIndex) {
@@ -257,8 +279,12 @@
});
ASSERT_TRUE(Generate()) << Error() << output_;
- EXPECT_INST("%11 = OpAccessChain %_ptr_Function_int %vec %idx");
- EXPECT_INST("OpStore %11 %int_42");
+ EXPECT_INST(R"(
+ %12 = OpBitcast %uint %idx
+ %13 = OpExtInst %uint %14 UMin %12 %uint_3
+ %16 = OpAccessChain %_ptr_Function_int %vec %13
+ OpStore %16 %int_42
+)");
}
} // namespace
diff --git a/src/tint/lang/spirv/writer/raise/raise.cc b/src/tint/lang/spirv/writer/raise/raise.cc
index f429783..71edb73 100644
--- a/src/tint/lang/spirv/writer/raise/raise.cc
+++ b/src/tint/lang/spirv/writer/raise/raise.cc
@@ -23,6 +23,7 @@
#include "src/tint/lang/core/ir/transform/builtin_polyfill.h"
#include "src/tint/lang/core/ir/transform/demote_to_helper.h"
#include "src/tint/lang/core/ir/transform/multiplanar_external_texture.h"
+#include "src/tint/lang/core/ir/transform/robustness.h"
#include "src/tint/lang/core/ir/transform/std140.h"
#include "src/tint/lang/spirv/writer/raise/builtin_polyfill.h"
#include "src/tint/lang/spirv/writer/raise/expand_implicit_splats.h"
@@ -56,6 +57,16 @@
core_polyfills.texture_sample_base_clamp_to_edge_2d_f32 = true;
RUN_TRANSFORM(core::ir::transform::BuiltinPolyfill, module, core_polyfills);
+ if (!options.disable_robustness) {
+ core::ir::transform::RobustnessConfig config;
+ if (options.disable_image_robustness) {
+ config.clamp_texture = false;
+ }
+ config.disable_runtime_sized_array_index_clamping =
+ options.disable_runtime_sized_array_index_clamping;
+ RUN_TRANSFORM(core::ir::transform::Robustness, module, config);
+ }
+
RUN_TRANSFORM(core::ir::transform::MultiplanarExternalTexture, module,
options.external_texture_options);