Make the validation on inter-stage shader variables match latest WebGPU SPEC

This patch updates the validations on the inter-stage shader variables to
match the latest WebGPU SPEC (in chapter "validating-inter-stage-interfaces").

With this patch the below validation tests in WebGPU CTS will pass:
- render_pipeline,inter_stage:max_shader_variable_location:*
- render_pipeline,inter_stage:max_components_count,*

Fixed: dawn:1448
Test: dawn_unittests
Change-Id: I3e4d98f03ec18e5d1642a4d7ecd3eed1b7ae04d0
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/102104
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Jiawei Shao <jiawei.shao@intel.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/dawn/common/Constants.h b/src/dawn/common/Constants.h
index a5d20ab..4ee5423 100644
--- a/src/dawn/common/Constants.h
+++ b/src/dawn/common/Constants.h
@@ -25,7 +25,7 @@
 static constexpr uint8_t kMaxColorAttachments = 8u;
 static constexpr uint32_t kTextureBytesPerRowAlignment = 256u;
 static constexpr uint32_t kMaxInterStageShaderComponents = 60u;
-static constexpr uint32_t kMaxInterStageShaderVariables = kMaxInterStageShaderComponents / 4;
+static constexpr uint32_t kMaxInterStageShaderVariables = 16u;
 
 // Per stage limits
 static constexpr uint32_t kMaxSampledTexturesPerShaderStage = 16;
diff --git a/src/dawn/native/RenderPipeline.cpp b/src/dawn/native/RenderPipeline.cpp
index 553896f..033231b 100644
--- a/src/dawn/native/RenderPipeline.cpp
+++ b/src/dawn/native/RenderPipeline.cpp
@@ -115,12 +115,15 @@
 
 MaybeError ValidateVertexState(DeviceBase* device,
                                const VertexState* descriptor,
-                               const PipelineLayoutBase* layout) {
+                               const PipelineLayoutBase* layout,
+                               wgpu::PrimitiveTopology primitiveTopology) {
     DAWN_INVALID_IF(descriptor->nextInChain != nullptr, "nextInChain must be nullptr.");
 
-    DAWN_INVALID_IF(descriptor->bufferCount > kMaxVertexBuffers,
+    const CombinedLimits& limits = device->GetLimits();
+
+    DAWN_INVALID_IF(descriptor->bufferCount > limits.v1.maxVertexBuffers,
                     "Vertex buffer count (%u) exceeds the maximum number of vertex buffers (%u).",
-                    descriptor->bufferCount, kMaxVertexBuffers);
+                    descriptor->bufferCount, limits.v1.maxVertexBuffers);
 
     DAWN_TRY_CONTEXT(ValidateProgrammableStage(device, descriptor->module, descriptor->entryPoint,
                                                descriptor->constantCount, descriptor->constants,
@@ -129,6 +132,15 @@
                      descriptor->entryPoint);
     const EntryPointMetadata& vertexMetadata =
         descriptor->module->GetEntryPoint(descriptor->entryPoint);
+    if (primitiveTopology == wgpu::PrimitiveTopology::PointList) {
+        DAWN_INVALID_IF(
+            vertexMetadata.totalInterStageShaderComponents + 1 >
+                limits.v1.maxInterStageShaderComponents,
+            "Total vertex output components count (%u) exceeds the maximum (%u) when primitive "
+            "topology is %s as another component is implicitly used for the point size.",
+            vertexMetadata.totalInterStageShaderComponents,
+            limits.v1.maxInterStageShaderComponents - 1, primitiveTopology);
+    }
 
     ityp::bitset<VertexAttributeLocation, kMaxVertexAttributes> attributesSetMask;
     uint32_t totalAttributesNum = 0;
@@ -433,7 +445,8 @@
         DAWN_TRY(device->ValidateObject(descriptor->layout));
     }
 
-    DAWN_TRY_CONTEXT(ValidateVertexState(device, &descriptor->vertex, descriptor->layout),
+    DAWN_TRY_CONTEXT(ValidateVertexState(device, &descriptor->vertex, descriptor->layout,
+                                         descriptor->primitive.topology),
                      "validating vertex state.");
 
     DAWN_TRY_CONTEXT(ValidatePrimitiveState(device, &descriptor->primitive),
diff --git a/src/dawn/native/ShaderModule.cpp b/src/dawn/native/ShaderModule.cpp
index f244a50..2eeb6b2 100644
--- a/src/dawn/native/ShaderModule.cpp
+++ b/src/dawn/native/ShaderModule.cpp
@@ -510,8 +510,6 @@
     const DeviceBase* device,
     tint::inspector::Inspector* inspector,
     const tint::inspector::EntryPoint& entryPoint) {
-    constexpr uint32_t kMaxInterStageShaderLocation = kMaxInterStageShaderVariables - 1;
-
     std::unique_ptr<EntryPointMetadata> metadata = std::make_unique<EntryPointMetadata>();
 
     // Returns the invalid argument, and if it is true additionally store the formatted
@@ -572,13 +570,17 @@
         metadata->usesNumWorkgroups = entryPoint.num_workgroups_used;
     }
 
+    const CombinedLimits& limits = device->GetLimits();
+    const uint32_t maxVertexAttributes = limits.v1.maxVertexAttributes;
+    const uint32_t maxInterStageShaderVariables = limits.v1.maxInterStageShaderVariables;
+    const uint32_t maxInterStageShaderComponents = limits.v1.maxInterStageShaderComponents;
     if (metadata->stage == SingleShaderStage::Vertex) {
         for (const auto& inputVar : entryPoint.input_variables) {
             uint32_t unsanitizedLocation = inputVar.location_decoration;
-            if (DelayedInvalidIf(unsanitizedLocation >= kMaxVertexAttributes,
+            if (DelayedInvalidIf(unsanitizedLocation >= maxVertexAttributes,
                                  "Vertex input variable \"%s\" has a location (%u) that "
                                  "exceeds the maximum (%u)",
-                                 inputVar.name, unsanitizedLocation, kMaxVertexAttributes)) {
+                                 inputVar.name, unsanitizedLocation, maxVertexAttributes)) {
                 continue;
             }
 
@@ -588,9 +590,7 @@
             metadata->usedVertexInputs.set(location);
         }
 
-        // [[position]] must be declared in a vertex shader but is not exposed as an
-        // output variable by Tint so we directly add its components to the total.
-        uint32_t totalInterStageShaderComponents = 4;
+        uint32_t totalInterStageShaderComponents = 0;
         for (const auto& outputVar : entryPoint.output_variables) {
             EntryPointMetadata::InterStageVariableInfo variable;
             DAWN_TRY_ASSIGN(variable.baseType,
@@ -605,10 +605,10 @@
             totalInterStageShaderComponents += variable.componentCount;
 
             uint32_t location = outputVar.location_decoration;
-            if (DelayedInvalidIf(location > kMaxInterStageShaderLocation,
+            if (DelayedInvalidIf(location >= maxInterStageShaderVariables,
                                  "Vertex output variable \"%s\" has a location (%u) that "
-                                 "exceeds the maximum (%u).",
-                                 outputVar.name, location, kMaxInterStageShaderLocation)) {
+                                 "is greater than or equal to (%u).",
+                                 outputVar.name, location, maxInterStageShaderVariables)) {
                 continue;
             }
 
@@ -616,9 +616,10 @@
             metadata->interStageVariables[location] = variable;
         }
 
-        DelayedInvalidIf(totalInterStageShaderComponents > kMaxInterStageShaderComponents,
+        metadata->totalInterStageShaderComponents = totalInterStageShaderComponents;
+        DelayedInvalidIf(totalInterStageShaderComponents > maxInterStageShaderComponents,
                          "Total vertex output components count (%u) exceeds the maximum (%u).",
-                         totalInterStageShaderComponents, kMaxInterStageShaderComponents);
+                         totalInterStageShaderComponents, maxInterStageShaderComponents);
     }
 
     if (metadata->stage == SingleShaderStage::Fragment) {
@@ -637,10 +638,10 @@
             totalInterStageShaderComponents += variable.componentCount;
 
             uint32_t location = inputVar.location_decoration;
-            if (DelayedInvalidIf(location > kMaxInterStageShaderLocation,
+            if (DelayedInvalidIf(location >= maxInterStageShaderVariables,
                                  "Fragment input variable \"%s\" has a location (%u) that "
-                                 "exceeds the maximum (%u).",
-                                 inputVar.name, location, kMaxInterStageShaderLocation)) {
+                                 "is greater than or equal to (%u).",
+                                 inputVar.name, location, maxInterStageShaderVariables)) {
                 continue;
             }
 
@@ -658,15 +659,13 @@
         if (entryPoint.sample_index_used) {
             totalInterStageShaderComponents += 1;
         }
-        if (entryPoint.input_position_used) {
-            totalInterStageShaderComponents += 4;
-        }
 
-        DelayedInvalidIf(totalInterStageShaderComponents > kMaxInterStageShaderComponents,
+        metadata->totalInterStageShaderComponents = totalInterStageShaderComponents;
+        DelayedInvalidIf(totalInterStageShaderComponents > maxInterStageShaderComponents,
                          "Total fragment input components count (%u) exceeds the maximum (%u).",
-                         totalInterStageShaderComponents, kMaxInterStageShaderComponents);
+                         totalInterStageShaderComponents, maxInterStageShaderComponents);
 
-        uint32_t maxColorAttachments = device->GetLimits().v1.maxColorAttachments;
+        uint32_t maxColorAttachments = limits.v1.maxColorAttachments;
         for (const auto& outputVar : entryPoint.output_variables) {
             EntryPointMetadata::FragmentOutputVariableInfo variable;
             DAWN_TRY_ASSIGN(variable.baseType,
diff --git a/src/dawn/native/ShaderModule.h b/src/dawn/native/ShaderModule.h
index b04e9ed..0c829f5 100644
--- a/src/dawn/native/ShaderModule.h
+++ b/src/dawn/native/ShaderModule.h
@@ -210,6 +210,7 @@
     // inputs and outputs in one shader stage.
     std::bitset<kMaxInterStageShaderVariables> usedInterStageVariables;
     std::array<InterStageVariableInfo, kMaxInterStageShaderVariables> interStageVariables;
+    uint32_t totalInterStageShaderComponents;
 
     // The shader stage for this binding.
     SingleShaderStage stage;
diff --git a/src/dawn/tests/unittests/validation/ShaderModuleValidationTests.cpp b/src/dawn/tests/unittests/validation/ShaderModuleValidationTests.cpp
index 1520bb8..d7bf350 100644
--- a/src/dawn/tests/unittests/validation/ShaderModuleValidationTests.cpp
+++ b/src/dawn/tests/unittests/validation/ShaderModuleValidationTests.cpp
@@ -288,21 +288,21 @@
         }
     };
 
-    constexpr uint32_t kMaxInterShaderIOLocation = kMaxInterStageShaderComponents / 4 - 1;
+    // It is allowed to create a shader module with the maximum active vertex output location ==
+    // (kMaxInterStageShaderVariables - 1);
+    CheckTestPipeline(true, kMaxInterStageShaderVariables - 1, wgpu::ShaderStage::Vertex);
 
-    // It is allowed to create a shader module with the maximum active vertex output location == 14;
-    CheckTestPipeline(true, kMaxInterShaderIOLocation, wgpu::ShaderStage::Vertex);
-
-    // It isn't allowed to create a shader module with the maximum active vertex output location >
-    // 14;
-    CheckTestPipeline(false, kMaxInterShaderIOLocation + 1, wgpu::ShaderStage::Vertex);
+    // It isn't allowed to create a shader module with the maximum active vertex output location ==
+    // kMaxInterStageShaderVariables;
+    CheckTestPipeline(false, kMaxInterStageShaderVariables, wgpu::ShaderStage::Vertex);
 
     // It is allowed to create a shader module with the maximum active fragment input location ==
-    // 14;
-    CheckTestPipeline(true, kMaxInterShaderIOLocation, wgpu::ShaderStage::Fragment);
+    // (kMaxInterStageShaderVariables - 1);
+    CheckTestPipeline(true, kMaxInterStageShaderVariables - 1, wgpu::ShaderStage::Fragment);
 
-    // It is allowed to create a shader module with the maximum active vertex output location > 14;
-    CheckTestPipeline(false, kMaxInterShaderIOLocation + 1, wgpu::ShaderStage::Fragment);
+    // It isn't allowed to create a shader module with the maximum active vertex output location ==
+    // kMaxInterStageShaderVariables;
+    CheckTestPipeline(false, kMaxInterStageShaderVariables, wgpu::ShaderStage::Fragment);
 }
 
 // Validate the maximum number of total inter-stage user-defined variable component count and
@@ -311,7 +311,8 @@
     auto CheckTestPipeline = [&](bool success,
                                  uint32_t totalUserDefinedInterStageShaderComponentCount,
                                  wgpu::ShaderStage failingShaderStage,
-                                 const char* extraBuiltInDeclarations = "") {
+                                 const char* extraBuiltInDeclarations = "",
+                                 bool usePointListAsPrimitiveType = false) {
         // Build the ShaderIO struct containing totalUserDefinedInterStageShaderComponentCount
         // components. Components are added in two parts, a bunch of vec4s, then one additional
         // variable for the remaining components.
@@ -347,11 +348,20 @@
         // string "failingVertex" or "failingFragment" in the error message.
         utils::ComboRenderPipelineDescriptor pDesc;
         pDesc.cTargets[0].format = wgpu::TextureFormat::RGBA8Unorm;
+        if (usePointListAsPrimitiveType) {
+            pDesc.primitive.topology = wgpu::PrimitiveTopology::PointList;
+        } else {
+            pDesc.primitive.topology = wgpu::PrimitiveTopology::TriangleList;
+        }
 
         const char* errorMatcher = nullptr;
         switch (failingShaderStage) {
             case wgpu::ShaderStage::Vertex: {
-                errorMatcher = "failingVertex";
+                if (usePointListAsPrimitiveType) {
+                    errorMatcher = "PointList";
+                } else {
+                    errorMatcher = "failingVertex";
+                }
                 pDesc.vertex.entryPoint = "failingVertex";
                 pDesc.vertex.module = utils::CreateShaderModule(device, (ioStruct + R"(
                     @vertex fn failingVertex() -> ShaderIO {
@@ -408,20 +418,28 @@
         CheckTestPipeline(false, kMaxInterStageShaderComponents + 1, wgpu::ShaderStage::Fragment);
     }
 
-    // @builtin(position) should be counted into the maximum inter-stage component count.
-    // Note that in vertex shader we always have @position so we don't need to specify it
-    // again in the parameter "builtInDeclarations" of generateShaderForTest().
+    // Verify the total user-defined vertex output component count must be less than
+    // kMaxInterStageShaderComponents.
     {
-        CheckTestPipeline(true, kMaxInterStageShaderComponents - 4, wgpu::ShaderStage::Vertex);
-        CheckTestPipeline(false, kMaxInterStageShaderComponents - 3, wgpu::ShaderStage::Vertex);
+        CheckTestPipeline(true, kMaxInterStageShaderComponents, wgpu::ShaderStage::Vertex);
+        CheckTestPipeline(false, kMaxInterStageShaderComponents + 1, wgpu::ShaderStage::Vertex);
     }
 
-    // @builtin(position) in fragment shaders should be counted into the maximum inter-stage
+    // Verify the total user-defined vertex output component count must be less than
+    // (kMaxInterStageShaderComponents - 1) when the primitive topology is PointList.
+    {
+        constexpr bool kUsePointListAsPrimitiveTopology = true;
+        const char* kExtraBuiltins = "";
+        CheckTestPipeline(true, kMaxInterStageShaderComponents - 1, wgpu::ShaderStage::Vertex,
+                          kExtraBuiltins, kUsePointListAsPrimitiveTopology);
+        CheckTestPipeline(false, kMaxInterStageShaderComponents, wgpu::ShaderStage::Vertex,
+                          kExtraBuiltins, kUsePointListAsPrimitiveTopology);
+    }
+
+    // @builtin(position) in fragment shaders shouldn't be counted into the maximum inter-stage
     // component count.
     {
-        CheckTestPipeline(true, kMaxInterStageShaderComponents - 4, wgpu::ShaderStage::Fragment,
-                          "@builtin(position) fragCoord : vec4<f32>,");
-        CheckTestPipeline(false, kMaxInterStageShaderComponents - 3, wgpu::ShaderStage::Fragment,
+        CheckTestPipeline(true, kMaxInterStageShaderComponents, wgpu::ShaderStage::Fragment,
                           "@builtin(position) fragCoord : vec4<f32>,");
     }
 
diff --git a/webgpu-cts/expectations.txt b/webgpu-cts/expectations.txt
index 7830d43..bba7c8d 100644
--- a/webgpu-cts/expectations.txt
+++ b/webgpu-cts/expectations.txt
@@ -542,16 +542,6 @@
 ################################################################################
 # API validation failures
 ################################################################################
-crbug.com/dawn/0000 webgpu:api,validation,render_pipeline,inter_stage:max_components_count,input:isAsync=false;numScalarDelta=-3;useExtraBuiltinInputs=true [ Failure ]
-crbug.com/dawn/0000 webgpu:api,validation,render_pipeline,inter_stage:max_components_count,input:isAsync=false;numScalarDelta=0;useExtraBuiltinInputs=false [ Failure ]
-crbug.com/dawn/0000 webgpu:api,validation,render_pipeline,inter_stage:max_components_count,input:isAsync=true;numScalarDelta=-3;useExtraBuiltinInputs=true [ Failure ]
-crbug.com/dawn/0000 webgpu:api,validation,render_pipeline,inter_stage:max_components_count,input:isAsync=true;numScalarDelta=0;useExtraBuiltinInputs=false [ Failure ]
-crbug.com/dawn/0000 webgpu:api,validation,render_pipeline,inter_stage:max_components_count,output:isAsync=false;numScalarDelta=-1;topology="point-list" [ Failure ]
-crbug.com/dawn/0000 webgpu:api,validation,render_pipeline,inter_stage:max_components_count,output:isAsync=false;numScalarDelta=0;topology="triangle-list" [ Failure ]
-crbug.com/dawn/0000 webgpu:api,validation,render_pipeline,inter_stage:max_components_count,output:isAsync=true;numScalarDelta=-1;topology="point-list" [ Failure ]
-crbug.com/dawn/0000 webgpu:api,validation,render_pipeline,inter_stage:max_components_count,output:isAsync=true;numScalarDelta=0;topology="triangle-list" [ Failure ]
-crbug.com/dawn/0000 webgpu:api,validation,render_pipeline,inter_stage:max_shader_variable_location:isAsync=false;locationDelta=-1 [ Failure ]
-crbug.com/dawn/0000 webgpu:api,validation,render_pipeline,inter_stage:max_shader_variable_location:isAsync=true;locationDelta=-1 [ Failure ]
 crbug.com/dawn/0000 webgpu:api,validation,shader_module,entry_point:compute:isAsync=false;shaderModuleEntryPoint="main";stageEntryPoint="main%5Cu0000" [ Failure ]
 crbug.com/dawn/0000 webgpu:api,validation,shader_module,entry_point:compute:isAsync=false;shaderModuleEntryPoint="main";stageEntryPoint="main%5Cu0000a" [ Failure ]
 crbug.com/dawn/0000 webgpu:api,validation,shader_module,entry_point:compute:isAsync=true;shaderModuleEntryPoint="main";stageEntryPoint="main%5Cu0000" [ Failure ]