Compat: Make vertex_index/instance_index take an attribute

Bug: dawn:2042
Change-Id: I5441648bb1387e959b181f5171ba1a46f5d6f73a
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/150703
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Austin Eng <enga@chromium.org>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: Gregg Tavares <gman@chromium.org>
diff --git a/src/dawn/native/RenderPipeline.cpp b/src/dawn/native/RenderPipeline.cpp
index bd75e8d..014ecdc 100644
--- a/src/dawn/native/RenderPipeline.cpp
+++ b/src/dawn/native/RenderPipeline.cpp
@@ -151,6 +151,18 @@
         totalAttributesNum += descriptor->buffers[i].attributeCount;
     }
 
+    if (device->IsCompatibilityMode() &&
+        (vertexMetadata.usesVertexIndex || vertexMetadata.usesInstanceIndex)) {
+        uint32_t totalEffectiveAttributesNum = totalAttributesNum +
+                                               (vertexMetadata.usesVertexIndex ? 1 : 0) +
+                                               (vertexMetadata.usesInstanceIndex ? 1 : 0);
+        DAWN_INVALID_IF(totalEffectiveAttributesNum > limits.v1.maxVertexAttributes,
+                        "Attribute count (%u) exceeds the maximum number of attributes (%u) as "
+                        "@builtin(vertex_index) and @builtin(instance_index) each use an attribute "
+                        "in compatibility mode.",
+                        totalEffectiveAttributesNum, limits.v1.maxVertexAttributes);
+    }
+
     // Every vertex attribute has a member called shaderLocation, and there are some
     // requirements for shaderLocation: 1) >=0, 2) values are different across different
     // attributes, 3) can't exceed kMaxVertexAttributes. So it can ensure that total
diff --git a/src/dawn/native/ShaderModule.cpp b/src/dawn/native/ShaderModule.cpp
index 3ec62fd..77b6f5d 100644
--- a/src/dawn/native/ShaderModule.cpp
+++ b/src/dawn/native/ShaderModule.cpp
@@ -634,6 +634,9 @@
         DelayedInvalidIf(totalInterStageShaderComponents > maxInterStageShaderComponents,
                          "Total vertex output components count (%u) exceeds the maximum (%u).",
                          totalInterStageShaderComponents, maxInterStageShaderComponents);
+
+        metadata->usesVertexIndex = entryPoint.vertex_index_used;
+        metadata->usesInstanceIndex = entryPoint.instance_index_used;
     }
 
     if (metadata->stage == SingleShaderStage::Fragment) {
diff --git a/src/dawn/native/ShaderModule.h b/src/dawn/native/ShaderModule.h
index 995c426..6f5ff7e 100644
--- a/src/dawn/native/ShaderModule.h
+++ b/src/dawn/native/ShaderModule.h
@@ -247,6 +247,8 @@
 
     bool usesNumWorkgroups = false;
     bool usesFragDepth = false;
+    bool usesVertexIndex = false;
+    bool usesInstanceIndex = false;
     // Used at render pipeline validation.
     bool usesSampleMaskOutput = false;
 };
diff --git a/src/dawn/tests/unittests/validation/CompatValidationTests.cpp b/src/dawn/tests/unittests/validation/CompatValidationTests.cpp
index 100ee8d..2075ae9 100644
--- a/src/dawn/tests/unittests/validation/CompatValidationTests.cpp
+++ b/src/dawn/tests/unittests/validation/CompatValidationTests.cpp
@@ -16,6 +16,8 @@
 #include <string>
 #include <vector>
 
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
 #include "dawn/tests/unittests/validation/ValidationTest.h"
 #include "dawn/utils/ComboRenderPipelineDescriptor.h"
 #include "dawn/utils/WGPUHelpers.h"
@@ -995,5 +997,95 @@
         ASSERT_DEVICE_ERROR(encoder.Finish(), testing::HasSubstr("cannot be used"));
     }
 }
+
+class CompatMaxVertexAttributesTest : public CompatValidationTest {
+  protected:
+    void TestMaxVertexAttributes(bool usesVertexIndex, bool usesInstanceIndex) {
+        wgpu::SupportedLimits limits;
+        device.GetLimits(&limits);
+
+        uint32_t maxAttributes = limits.limits.maxVertexAttributes;
+        uint32_t numAttributesUsedByBuiltins =
+            (usesVertexIndex ? 1 : 0) + (usesInstanceIndex ? 1 : 0);
+
+        TestAttributes(maxAttributes - numAttributesUsedByBuiltins, usesVertexIndex,
+                       usesInstanceIndex, true);
+        if (usesVertexIndex || usesInstanceIndex) {
+            TestAttributes(maxAttributes - numAttributesUsedByBuiltins + 1, usesVertexIndex,
+                           usesInstanceIndex, false);
+        }
+    }
+
+    void TestAttributes(uint32_t numAttributes,
+                        bool usesVertexIndex,
+                        bool usesInstanceIndex,
+                        bool expectSuccess) {
+        std::vector<std::string> inputs;
+        std::vector<std::string> outputs;
+
+        utils::ComboRenderPipelineDescriptor descriptor;
+        descriptor.layout = {};
+        descriptor.vertex.entryPoint = "vs";
+        descriptor.vertex.bufferCount = 1;
+        descriptor.cFragment.entryPoint = "fs";
+        descriptor.cBuffers[0].arrayStride = 16;
+        descriptor.cBuffers[0].attributeCount = numAttributes;
+
+        for (uint32_t i = 0; i < numAttributes; ++i) {
+            inputs.push_back(absl::StrFormat("@location(%u) v%u: vec4f", i, i));
+            outputs.push_back(absl::StrFormat("v%u", i));
+            descriptor.cAttributes[i].format = wgpu::VertexFormat::Float32x4;
+            descriptor.cAttributes[i].shaderLocation = i;
+        }
+
+        if (usesVertexIndex) {
+            inputs.push_back("@builtin(vertex_index) vNdx: u32");
+            outputs.push_back("vec4f(f32(vNdx))");
+        }
+
+        if (usesInstanceIndex) {
+            inputs.push_back("@builtin(instance_index) iNdx: u32");
+            outputs.push_back("vec4f(f32(iNdx))");
+        }
+
+        auto wgsl = absl::StrFormat(R"(
+            @fragment fn fs() -> @location(0) vec4f {
+                return vec4f(1);
+            }
+            @vertex fn vs(%s) -> @builtin(position) vec4f {
+                return %s;
+            }
+            )",
+                                    absl::StrJoin(inputs, ", "), absl::StrJoin(outputs, " + "));
+
+        wgpu::ShaderModule module = utils::CreateShaderModule(device, wgsl.c_str());
+        descriptor.vertex.module = module;
+        descriptor.cFragment.module = module;
+
+        if (expectSuccess) {
+            device.CreateRenderPipeline(&descriptor);
+        } else {
+            ASSERT_DEVICE_ERROR(device.CreateRenderPipeline(&descriptor),
+                                testing::HasSubstr("compat"));
+        }
+    }
+};
+
+TEST_F(CompatMaxVertexAttributesTest, CanUseMaxVertexAttributes) {
+    TestMaxVertexAttributes(false, false);
+}
+
+TEST_F(CompatMaxVertexAttributesTest, VertexIndexTakesAnAttribute) {
+    TestMaxVertexAttributes(true, false);
+}
+
+TEST_F(CompatMaxVertexAttributesTest, InstanceIndexTakesAnAttribute) {
+    TestMaxVertexAttributes(false, true);
+}
+
+TEST_F(CompatMaxVertexAttributesTest, VertexAndInstanceIndexEachTakeAnAttribute) {
+    TestMaxVertexAttributes(true, true);
+}
+
 }  // anonymous namespace
 }  // namespace dawn
diff --git a/src/tint/lang/wgsl/inspector/entry_point.h b/src/tint/lang/wgsl/inspector/entry_point.h
index d379867..8c3161d 100644
--- a/src/tint/lang/wgsl/inspector/entry_point.h
+++ b/src/tint/lang/wgsl/inspector/entry_point.h
@@ -163,6 +163,10 @@
     bool num_workgroups_used = false;
     /// Does the entry point use the frag_depth builtin
     bool frag_depth_used = false;
+    /// Does the entry point use the vertex_index builtin
+    bool vertex_index_used = false;
+    /// Does the entry point use the instance_index builtin
+    bool instance_index_used = false;
 };
 
 }  // namespace tint::inspector
diff --git a/src/tint/lang/wgsl/inspector/inspector.cc b/src/tint/lang/wgsl/inspector/inspector.cc
index 64811df..8a6b0d1 100644
--- a/src/tint/lang/wgsl/inspector/inspector.cc
+++ b/src/tint/lang/wgsl/inspector/inspector.cc
@@ -178,6 +178,10 @@
             core::BuiltinValue::kSampleMask, param->Type(), param->Declaration()->attributes);
         entry_point.num_workgroups_used |= ContainsBuiltin(
             core::BuiltinValue::kNumWorkgroups, param->Type(), param->Declaration()->attributes);
+        entry_point.vertex_index_used |= ContainsBuiltin(
+            core::BuiltinValue::kVertexIndex, param->Type(), param->Declaration()->attributes);
+        entry_point.instance_index_used |= ContainsBuiltin(
+            core::BuiltinValue::kInstanceIndex, param->Type(), param->Declaration()->attributes);
     }
 
     if (!sem->ReturnType()->Is<core::type::Void>()) {