[tint][ir][val] Check type of input_attachment_index

Adds in a new capability to allow SPIRV backend to opt-out of this
check, so that spirv.image can also be used.

Fixes: 414333405
Change-Id: Icfc88f3408a0bbee7f4816bb475e406b1c1e3272
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/240055
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Commit-Queue: Ryan Harrison <rharrison@chromium.org>
Auto-Submit: Ryan Harrison <rharrison@chromium.org>
diff --git a/src/tint/lang/core/ir/transform/std140.h b/src/tint/lang/core/ir/transform/std140.h
index 31f718a..c63d4e7 100644
--- a/src/tint/lang/core/ir/transform/std140.h
+++ b/src/tint/lang/core/ir/transform/std140.h
@@ -41,6 +41,7 @@
 /// The capabilities that the transform can support.
 const core::ir::Capabilities kStd140Capabilities{
     core::ir::Capability::kAllowHandleVarsWithoutBindings,
+    core::ir::Capability::kAllowAnyInputAttachmentIndexType,
 };
 
 /// Std140 is a transform that rewrites matrix types in the uniform address space to conform to
diff --git a/src/tint/lang/core/ir/validator.cc b/src/tint/lang/core/ir/validator.cc
index 3c8fcc1..c759052 100644
--- a/src/tint/lang/core/ir/validator.cc
+++ b/src/tint/lang/core/ir/validator.cc
@@ -2604,10 +2604,17 @@
         }
     }
 
-    // Check that non-handle variables don't have @input_attachment_index set
-    if (var->InputAttachmentIndex().has_value() && mv->AddressSpace() != AddressSpace::kHandle) {
-        AddError(var) << "'@input_attachment_index' is not valid for non-handle var";
-        return;
+    if (var->InputAttachmentIndex().has_value()) {
+        if (mv->AddressSpace() != AddressSpace::kHandle) {
+            AddError(var) << "'@input_attachment_index' is not valid for non-handle var";
+            return;
+        }
+        if (!capabilities_.Contains(Capability::kAllowAnyInputAttachmentIndexType) &&
+            !mv->UnwrapPtrOrRef()->Is<core::type::InputAttachment>()) {
+            AddError(var)
+                << "'@input_attachment_index' is only valid for 'input_attachment' type var";
+            return;
+        }
     }
 
     if (var->Block() == mod_.root_block) {
diff --git a/src/tint/lang/core/ir/validator.h b/src/tint/lang/core/ir/validator.h
index 2ef66dc..1ec3087 100644
--- a/src/tint/lang/core/ir/validator.h
+++ b/src/tint/lang/core/ir/validator.h
@@ -64,6 +64,9 @@
     kAllowPhonyInstructions,
     /// Allows lets to have any type, used by MSL backend for module scoped vars
     kAllowAnyLetType,
+    /// Allows input_attachment_index to be associated with any type, used by
+    /// SPIRV backend for spirv.image.
+    kAllowAnyInputAttachmentIndexType,
 };
 
 /// Capabilities is a set of Capability
diff --git a/src/tint/lang/core/ir/validator_value_test.cc b/src/tint/lang/core/ir/validator_value_test.cc
index 1b3abb2..d16b491 100644
--- a/src/tint/lang/core/ir/validator_value_test.cc
+++ b/src/tint/lang/core/ir/validator_value_test.cc
@@ -609,6 +609,38 @@
 )")) << res.Failure();
 }
 
+TEST_F(IR_ValidatorTest, Var_InputAttachementIndex_NonHandle) {
+    auto* v = b.Var(ty.ptr(AddressSpace::kPrivate, ty.f32(), read_write));
+    v->SetInputAttachmentIndex(0);
+    mod.root_block->Append(v);
+
+    auto res = ir::Validate(mod);
+    ASSERT_NE(res, Success);
+    EXPECT_THAT(res.Failure().reason,
+                testing::HasSubstr(
+                    R"(:2:38 error: var: '@input_attachment_index' is not valid for non-handle var
+  %1:ptr<private, f32, read_write> = var undef @input_attachment_index(0)
+                                     ^^^
+)")) << res.Failure();
+}
+
+TEST_F(IR_ValidatorTest, Var_InputAttachementIndex_WrongType) {
+    auto* v = b.Var(ty.ptr(AddressSpace::kHandle, ty.f32(), read_write));
+    v->SetBindingPoint(0, 0);
+    v->SetInputAttachmentIndex(0);
+    mod.root_block->Append(v);
+
+    auto res = ir::Validate(mod);
+    ASSERT_NE(res, Success);
+    EXPECT_THAT(
+        res.Failure().reason,
+        testing::HasSubstr(
+            R"(:2:37 error: var: '@input_attachment_index' is only valid for 'input_attachment' type var
+  %1:ptr<handle, f32, read_write> = var undef @binding_point(0, 0) @input_attachment_index(0)
+                                    ^^^
+)")) << res.Failure();
+}
+
 TEST_F(IR_ValidatorTest, Let_NullResult) {
     auto* v = mod.CreateInstruction<ir::Let>(nullptr, b.Constant(1_i));
 
diff --git a/src/tint/lang/spirv/writer/printer/printer.cc b/src/tint/lang/spirv/writer/printer/printer.cc
index e8534b9..2013fe3 100644
--- a/src/tint/lang/spirv/writer/printer/printer.cc
+++ b/src/tint/lang/spirv/writer/printer/printer.cc
@@ -299,7 +299,7 @@
 
     /// Builds the SPIR-V from the IR
     Result<SuccessType> Generate() {
-        auto valid = core::ir::ValidateAndDumpIfNeeded(ir_, "spirv.Printer");
+        auto valid = core::ir::ValidateAndDumpIfNeeded(ir_, "spirv.Printer", kPrinterCapabilities);
         if (valid != Success) {
             return valid.Failure();
         }
diff --git a/src/tint/lang/spirv/writer/printer/printer.h b/src/tint/lang/spirv/writer/printer/printer.h
index 8c6cbfc..6fc55d0 100644
--- a/src/tint/lang/spirv/writer/printer/printer.h
+++ b/src/tint/lang/spirv/writer/printer/printer.h
@@ -28,6 +28,7 @@
 #ifndef SRC_TINT_LANG_SPIRV_WRITER_PRINTER_PRINTER_H_
 #define SRC_TINT_LANG_SPIRV_WRITER_PRINTER_PRINTER_H_
 
+#include "src/tint/lang/core/ir/validator.h"
 #include "src/tint/lang/spirv/writer/common/options.h"
 #include "src/tint/lang/spirv/writer/common/output.h"
 #include "src/tint/utils/result.h"
@@ -39,6 +40,11 @@
 
 namespace tint::spirv::writer {
 
+// The capabilities that might be needed due to raising.
+const core::ir::Capabilities kPrinterCapabilities{
+    core::ir::Capability::kAllowAnyInputAttachmentIndexType,
+};
+
 /// @returns the generated SPIR-V instructions on success, or failure
 /// @param module the Tint IR module to generate
 /// @param options the printer options
diff --git a/src/tint/lang/spirv/writer/raise/expand_implicit_splats.cc b/src/tint/lang/spirv/writer/raise/expand_implicit_splats.cc
index 2ec5a5f..a45c01c 100644
--- a/src/tint/lang/spirv/writer/raise/expand_implicit_splats.cc
+++ b/src/tint/lang/spirv/writer/raise/expand_implicit_splats.cc
@@ -135,7 +135,8 @@
 }  // namespace
 
 Result<SuccessType> ExpandImplicitSplats(core::ir::Module& ir) {
-    auto result = ValidateAndDumpIfNeeded(ir, "spirv.ExpandImplicitSplats");
+    auto result = ValidateAndDumpIfNeeded(ir, "spirv.ExpandImplicitSplats",
+                                          kExpandImplicitSplatsCapabilities);
     if (result != Success) {
         return result.Failure();
     }
diff --git a/src/tint/lang/spirv/writer/raise/expand_implicit_splats.h b/src/tint/lang/spirv/writer/raise/expand_implicit_splats.h
index ccc67d8..8d387c3 100644
--- a/src/tint/lang/spirv/writer/raise/expand_implicit_splats.h
+++ b/src/tint/lang/spirv/writer/raise/expand_implicit_splats.h
@@ -28,6 +28,7 @@
 #ifndef SRC_TINT_LANG_SPIRV_WRITER_RAISE_EXPAND_IMPLICIT_SPLATS_H_
 #define SRC_TINT_LANG_SPIRV_WRITER_RAISE_EXPAND_IMPLICIT_SPLATS_H_
 
+#include "src/tint/lang/core/ir/validator.h"
 #include "src/tint/utils/result.h"
 
 // Forward declarations.
@@ -37,6 +38,11 @@
 
 namespace tint::spirv::writer::raise {
 
+/// The capabilities that the transform can support.
+const core::ir::Capabilities kExpandImplicitSplatsCapabilities{
+    core::ir::Capability::kAllowAnyInputAttachmentIndexType,
+};
+
 /// ExpandImplicitSplats is a transform that expands implicit vector splat operands in construct
 /// instructions and binary instructions where not supported by SPIR-V.
 /// @param module the module to transform
diff --git a/src/tint/lang/spirv/writer/raise/fork_explicit_layout_types.cc b/src/tint/lang/spirv/writer/raise/fork_explicit_layout_types.cc
index cb8588d..da23c2d 100644
--- a/src/tint/lang/spirv/writer/raise/fork_explicit_layout_types.cc
+++ b/src/tint/lang/spirv/writer/raise/fork_explicit_layout_types.cc
@@ -357,7 +357,8 @@
 }  // namespace
 
 Result<SuccessType> ForkExplicitLayoutTypes(core::ir::Module& ir) {
-    auto result = ValidateAndDumpIfNeeded(ir, "spirv.ForkExplicitLayoutTypes");
+    auto result = ValidateAndDumpIfNeeded(ir, "spirv.ForkExplicitLayoutTypes",
+                                          kForkExplicitLayoutTypesCapabilities);
     if (result != Success) {
         return result;
     }
diff --git a/src/tint/lang/spirv/writer/raise/fork_explicit_layout_types.h b/src/tint/lang/spirv/writer/raise/fork_explicit_layout_types.h
index 678bfa3..5a4742e 100644
--- a/src/tint/lang/spirv/writer/raise/fork_explicit_layout_types.h
+++ b/src/tint/lang/spirv/writer/raise/fork_explicit_layout_types.h
@@ -28,6 +28,7 @@
 #ifndef SRC_TINT_LANG_SPIRV_WRITER_RAISE_FORK_EXPLICIT_LAYOUT_TYPES_H_
 #define SRC_TINT_LANG_SPIRV_WRITER_RAISE_FORK_EXPLICIT_LAYOUT_TYPES_H_
 
+#include "src/tint/lang/core/ir/validator.h"
 #include "src/tint/utils/result.h"
 
 // Forward declarations.
@@ -37,6 +38,11 @@
 
 namespace tint::spirv::writer::raise {
 
+/// The capabilities that the transform can support.
+const core::ir::Capabilities kForkExplicitLayoutTypesCapabilities{
+    core::ir::Capability::kAllowAnyInputAttachmentIndexType,
+};
+
 /// ForkExplicitLayoutTypes is a transform that forks array and structures types that are shared
 /// between address spaces that require explicit layout in SPIR-V and those that cannot have them.
 ///
diff --git a/src/tint/lang/spirv/writer/raise/handle_matrix_arithmetic.cc b/src/tint/lang/spirv/writer/raise/handle_matrix_arithmetic.cc
index 2a8c327..d485361 100644
--- a/src/tint/lang/spirv/writer/raise/handle_matrix_arithmetic.cc
+++ b/src/tint/lang/spirv/writer/raise/handle_matrix_arithmetic.cc
@@ -163,7 +163,8 @@
 }  // namespace
 
 Result<SuccessType> HandleMatrixArithmetic(core::ir::Module& ir) {
-    auto result = ValidateAndDumpIfNeeded(ir, "spirv.HandleMatrixArithmetic");
+    auto result = ValidateAndDumpIfNeeded(ir, "spirv.HandleMatrixArithmetic",
+                                          kHandleMatrixArithmeticCapabilities);
     if (result != Success) {
         return result.Failure();
     }
diff --git a/src/tint/lang/spirv/writer/raise/handle_matrix_arithmetic.h b/src/tint/lang/spirv/writer/raise/handle_matrix_arithmetic.h
index dc0ab72..d1fa97f 100644
--- a/src/tint/lang/spirv/writer/raise/handle_matrix_arithmetic.h
+++ b/src/tint/lang/spirv/writer/raise/handle_matrix_arithmetic.h
@@ -28,6 +28,7 @@
 #ifndef SRC_TINT_LANG_SPIRV_WRITER_RAISE_HANDLE_MATRIX_ARITHMETIC_H_
 #define SRC_TINT_LANG_SPIRV_WRITER_RAISE_HANDLE_MATRIX_ARITHMETIC_H_
 
+#include "src/tint/lang/core/ir/validator.h"
 #include "src/tint/utils/result.h"
 
 // Forward declarations.
@@ -37,6 +38,11 @@
 
 namespace tint::spirv::writer::raise {
 
+// The capabilities that the transform can support.
+const core::ir::Capabilities kHandleMatrixArithmeticCapabilities{
+    core::ir::Capability::kAllowAnyInputAttachmentIndexType,
+};
+
 /// HandleMatrixArithmetic is a transform that converts arithmetic instruction that use matrix into
 /// SPIR-V intrinsics or polyfills.
 /// @param module the module to transform
diff --git a/src/tint/lang/spirv/writer/raise/merge_return.cc b/src/tint/lang/spirv/writer/raise/merge_return.cc
index 20a9076..84e2720 100644
--- a/src/tint/lang/spirv/writer/raise/merge_return.cc
+++ b/src/tint/lang/spirv/writer/raise/merge_return.cc
@@ -43,6 +43,11 @@
 
 namespace {
 
+// The capabilities that the transform can support.
+const core::ir::Capabilities kMergeReturnCapabilities{
+    core::ir::Capability::kAllowAnyInputAttachmentIndexType,
+};
+
 /// PIMPL state for the transform, for a single function.
 struct State {
     /// The IR module.
@@ -322,7 +327,7 @@
 }  // namespace
 
 Result<SuccessType> MergeReturn(core::ir::Module& ir) {
-    auto result = ValidateAndDumpIfNeeded(ir, "spirv.MergeReturn");
+    auto result = ValidateAndDumpIfNeeded(ir, "spirv.MergeReturn", kMergeReturnCapabilities);
     if (result != Success) {
         return result;
     }
diff --git a/src/tint/lang/spirv/writer/raise/raise.cc b/src/tint/lang/spirv/writer/raise/raise.cc
index 0fc09ba..0e4177f 100644
--- a/src/tint/lang/spirv/writer/raise/raise.cc
+++ b/src/tint/lang/spirv/writer/raise/raise.cc
@@ -166,7 +166,9 @@
     }
 
     RUN_TRANSFORM(raise::BuiltinPolyfill, module, options.use_vulkan_memory_model);
+
     RUN_TRANSFORM(raise::ExpandImplicitSplats, module);
+    // kAllowAnyInputAttachmentIndexType required after ExpandImplicitSplats
     RUN_TRANSFORM(raise::HandleMatrixArithmetic, module);
     RUN_TRANSFORM(raise::MergeReturn, module);
     RUN_TRANSFORM(raise::RemoveUnreachableInLoopContinuing, module);
diff --git a/src/tint/lang/spirv/writer/raise/remove_unreachable_in_loop_continuing.cc b/src/tint/lang/spirv/writer/raise/remove_unreachable_in_loop_continuing.cc
index 5f534b4..0848748 100644
--- a/src/tint/lang/spirv/writer/raise/remove_unreachable_in_loop_continuing.cc
+++ b/src/tint/lang/spirv/writer/raise/remove_unreachable_in_loop_continuing.cc
@@ -94,7 +94,8 @@
 }  // namespace
 
 Result<SuccessType> RemoveUnreachableInLoopContinuing(core::ir::Module& ir) {
-    auto result = ValidateAndDumpIfNeeded(ir, "spirv.RemoveUnreachableInLoopContinuing");
+    auto result = ValidateAndDumpIfNeeded(ir, "spirv.RemoveUnreachableInLoopContinuing",
+                                          kRemoveUnreachableInLoopContinuingCapabilities);
     if (result != Success) {
         return result;
     }
diff --git a/src/tint/lang/spirv/writer/raise/remove_unreachable_in_loop_continuing.h b/src/tint/lang/spirv/writer/raise/remove_unreachable_in_loop_continuing.h
index 31742fd..5c04733 100644
--- a/src/tint/lang/spirv/writer/raise/remove_unreachable_in_loop_continuing.h
+++ b/src/tint/lang/spirv/writer/raise/remove_unreachable_in_loop_continuing.h
@@ -28,6 +28,7 @@
 #ifndef SRC_TINT_LANG_SPIRV_WRITER_RAISE_REMOVE_UNREACHABLE_IN_LOOP_CONTINUING_H_
 #define SRC_TINT_LANG_SPIRV_WRITER_RAISE_REMOVE_UNREACHABLE_IN_LOOP_CONTINUING_H_
 
+#include "src/tint/lang/core/ir/validator.h"
 #include "src/tint/utils/result.h"
 
 // Forward declarations.
@@ -37,6 +38,10 @@
 
 namespace tint::spirv::writer::raise {
 
+// The capabilities that the transform can support.
+const core::ir::Capabilities kRemoveUnreachableInLoopContinuingCapabilities{
+    core::ir::Capability::kAllowAnyInputAttachmentIndexType,
+};
 /// RemoveUnreachableInLoopContinuing is a transform that replaces unreachable statements that are
 /// nested inside a loop continuing block, as SPIR-V's structured control flow rules prohibit this.
 /// @param module the module to transform
diff --git a/src/tint/lang/spirv/writer/raise/shader_io.cc b/src/tint/lang/spirv/writer/raise/shader_io.cc
index d69a346..4f5865a 100644
--- a/src/tint/lang/spirv/writer/raise/shader_io.cc
+++ b/src/tint/lang/spirv/writer/raise/shader_io.cc
@@ -207,7 +207,7 @@
 }  // namespace
 
 Result<SuccessType> ShaderIO(core::ir::Module& ir, const ShaderIOConfig& config) {
-    auto result = ValidateAndDumpIfNeeded(ir, "spirv.ShaderIO");
+    auto result = ValidateAndDumpIfNeeded(ir, "spirv.ShaderIO", kShaderIOCapabilities);
     if (result != Success) {
         return result;
     }
diff --git a/src/tint/lang/spirv/writer/raise/shader_io.h b/src/tint/lang/spirv/writer/raise/shader_io.h
index e8c42f4..9ffa76c 100644
--- a/src/tint/lang/spirv/writer/raise/shader_io.h
+++ b/src/tint/lang/spirv/writer/raise/shader_io.h
@@ -29,6 +29,7 @@
 #define SRC_TINT_LANG_SPIRV_WRITER_RAISE_SHADER_IO_H_
 
 #include "src/tint/lang/core/ir/transform/prepare_push_constants.h"
+#include "src/tint/lang/core/ir/validator.h"
 #include "src/tint/lang/spirv/writer/common/options.h"
 #include "src/tint/utils/result.h"
 
@@ -39,6 +40,11 @@
 
 namespace tint::spirv::writer::raise {
 
+/// The capabilities that the transform can support.
+const core::ir::Capabilities kShaderIOCapabilities{
+    core::ir::Capability::kAllowAnyInputAttachmentIndexType,
+};
+
 /// ShaderIOConfig describes the set of configuration options for the ShaderIO transform.
 struct ShaderIOConfig {
     /// push constant layout information
diff --git a/src/tint/lang/spirv/writer/raise/var_for_dynamic_index.cc b/src/tint/lang/spirv/writer/raise/var_for_dynamic_index.cc
index ef3e231..cc5a6a7 100644
--- a/src/tint/lang/spirv/writer/raise/var_for_dynamic_index.cc
+++ b/src/tint/lang/spirv/writer/raise/var_for_dynamic_index.cc
@@ -249,7 +249,7 @@
 }  // namespace
 
 Result<SuccessType> VarForDynamicIndex(core::ir::Module& ir) {
-    auto result = ValidateAndDumpIfNeeded(ir, "spirv.VarForDynamicIndex");
+    auto result = ValidateAndDumpIfNeeded(ir, "spirv.VarForDynamicIndex", kVarForDynamicIndex);
     if (result != Success) {
         return result;
     }
diff --git a/src/tint/lang/spirv/writer/raise/var_for_dynamic_index.h b/src/tint/lang/spirv/writer/raise/var_for_dynamic_index.h
index e3e1b8c..5f9363e 100644
--- a/src/tint/lang/spirv/writer/raise/var_for_dynamic_index.h
+++ b/src/tint/lang/spirv/writer/raise/var_for_dynamic_index.h
@@ -28,6 +28,7 @@
 #ifndef SRC_TINT_LANG_SPIRV_WRITER_RAISE_VAR_FOR_DYNAMIC_INDEX_H_
 #define SRC_TINT_LANG_SPIRV_WRITER_RAISE_VAR_FOR_DYNAMIC_INDEX_H_
 
+#include "src/tint/lang/core/ir/validator.h"
 #include "src/tint/utils/result.h"
 
 // Forward declarations.
@@ -37,6 +38,11 @@
 
 namespace tint::spirv::writer::raise {
 
+/// The capabilities that the transform can support.
+const core::ir::Capabilities kVarForDynamicIndex{
+    core::ir::Capability::kAllowAnyInputAttachmentIndexType,
+};
+
 /// VarForDynamicIndex is a transform that copies array and matrix values that are dynamically
 /// indexed to a temporary local `var` before performing the index. This transform is used by the
 /// SPIR-V writer as there is no SPIR-V instruction that can dynamically index a non-pointer
diff --git a/src/tint/lang/spirv/writer/writer_fuzz.cc b/src/tint/lang/spirv/writer/writer_fuzz.cc
index 94fa09a..5a90dec 100644
--- a/src/tint/lang/spirv/writer/writer_fuzz.cc
+++ b/src/tint/lang/spirv/writer/writer_fuzz.cc
@@ -62,4 +62,6 @@
 }  // namespace
 }  // namespace tint::spirv::writer
 
-TINT_IR_MODULE_FUZZER(tint::spirv::writer::IRFuzzer, tint::core::ir::Capabilities{});
+TINT_IR_MODULE_FUZZER(tint::spirv::writer::IRFuzzer,
+                      tint::core::ir::Capabilities{
+                          tint::core::ir::Capability::kAllowAnyInputAttachmentIndexType});