Implement inter-stage variable matching rules - Part II

This patch implements the inter-stage variable matching rules for
the interpolation attributes ('interpolation type' and
'interpolation sampling'). The WebGPU spec requires the interpolation
attributes to match between vertex outputs and fragment inputs with
the same location assignment within the same pipeline.

BUG=dawn:802
TEST=dawn_unittests

Change-Id: Ied38d68f73868c30b0392954683963a801e3f3aa
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/60160
Commit-Queue: Jiawei Shao <jiawei.shao@intel.com>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Austin Eng <enga@chromium.org>
diff --git a/src/dawn_native/RenderPipeline.cpp b/src/dawn_native/RenderPipeline.cpp
index 7eff93d..9b18105 100644
--- a/src/dawn_native/RenderPipeline.cpp
+++ b/src/dawn_native/RenderPipeline.cpp
@@ -351,6 +351,13 @@
                 if (vertexOutputInfo.componentCount != fragmentInputInfo.componentCount) {
                     return DAWN_VALIDATION_ERROR(generateErrorString("componentCount", i));
                 }
+                if (vertexOutputInfo.interpolationType != fragmentInputInfo.interpolationType) {
+                    return DAWN_VALIDATION_ERROR(generateErrorString("interpolation type", i));
+                }
+                if (vertexOutputInfo.interpolationSampling !=
+                    fragmentInputInfo.interpolationSampling) {
+                    return DAWN_VALIDATION_ERROR(generateErrorString("interpolation sampling", i));
+                }
             }
 
             return {};
diff --git a/src/dawn_native/ShaderModule.cpp b/src/dawn_native/ShaderModule.cpp
index 2fe2f5c..79a9514 100644
--- a/src/dawn_native/ShaderModule.cpp
+++ b/src/dawn_native/ShaderModule.cpp
@@ -370,6 +370,46 @@
             }
         }
 
+        // Converts the interpolation type reported by the Tint inspector into Dawn's
+        // InterpolationType. Tint's kUnknown is rejected with a validation error.
+        ResultOrError<InterpolationType> TintInterpolationTypeToInterpolationType(
+            tint::inspector::InterpolationType type) {
+            switch (type) {
+                case tint::inspector::InterpolationType::kPerspective:
+                    return InterpolationType::Perspective;
+                case tint::inspector::InterpolationType::kLinear:
+                    return InterpolationType::Linear;
+                case tint::inspector::InterpolationType::kFlat:
+                    return InterpolationType::Flat;
+                case tint::inspector::InterpolationType::kUnknown:
+                    return DAWN_VALIDATION_ERROR(
+                        "Attempted to convert 'Unknown' interpolation type from Tint");
+            }
+            // All cases above return; never fall off the end of this non-void function.
+            UNREACHABLE();
+        }
+
+        // Converts the interpolation sampling reported by the Tint inspector into Dawn's
+        // InterpolationSampling. Tint's kUnknown is rejected with a validation error.
+        ResultOrError<InterpolationSampling> TintInterpolationSamplingToInterpolationSamplingType(
+            tint::inspector::InterpolationSampling sampling) {
+            switch (sampling) {
+                case tint::inspector::InterpolationSampling::kNone:
+                    return InterpolationSampling::None;
+                case tint::inspector::InterpolationSampling::kCenter:
+                    return InterpolationSampling::Center;
+                case tint::inspector::InterpolationSampling::kCentroid:
+                    return InterpolationSampling::Centroid;
+                case tint::inspector::InterpolationSampling::kSample:
+                    return InterpolationSampling::Sample;
+                case tint::inspector::InterpolationSampling::kUnknown:
+                    return DAWN_VALIDATION_ERROR(
+                        "Attempted to convert 'Unknown' interpolation sampling type from Tint");
+            }
+            // All cases above return; never fall off the end of this non-void function.
+            UNREACHABLE();
+        }
+
         MaybeError ValidateSpirv(const uint32_t* code, uint32_t codeSize) {
             spvtools::SpirvTools spirvTools(SPV_ENV_VULKAN_1_1);
 
@@ -1041,6 +1073,13 @@
                         DAWN_TRY_ASSIGN(metadata->interStageVariables[location].componentCount,
                                         TintCompositionTypeToInterStageComponentCount(
                                             output_var.composition_type));
+                        DAWN_TRY_ASSIGN(metadata->interStageVariables[location].interpolationType,
+                                        TintInterpolationTypeToInterpolationType(
+                                            output_var.interpolation_type));
+                        DAWN_TRY_ASSIGN(
+                            metadata->interStageVariables[location].interpolationSampling,
+                            TintInterpolationSamplingToInterpolationSamplingType(
+                                output_var.interpolation_sampling));
                     }
                 }
 
@@ -1063,6 +1102,13 @@
                         DAWN_TRY_ASSIGN(metadata->interStageVariables[location].componentCount,
                                         TintCompositionTypeToInterStageComponentCount(
                                             input_var.composition_type));
+                        DAWN_TRY_ASSIGN(
+                            metadata->interStageVariables[location].interpolationType,
+                            TintInterpolationTypeToInterpolationType(input_var.interpolation_type));
+                        DAWN_TRY_ASSIGN(
+                            metadata->interStageVariables[location].interpolationSampling,
+                            TintInterpolationSamplingToInterpolationSamplingType(
+                                input_var.interpolation_sampling));
                     }
 
                     for (const auto& output_var : entryPoint.output_variables) {
diff --git a/src/dawn_native/ShaderModule.h b/src/dawn_native/ShaderModule.h
index 2e50153..613033a 100644
--- a/src/dawn_native/ShaderModule.h
+++ b/src/dawn_native/ShaderModule.h
@@ -60,6 +60,23 @@
         Float,
     };
 
+    // Interpolation type of an inter-stage variable, converted from the type that Tint's
+    // inspector reports for the variable's WGSL interpolate() attribute.
+    enum class InterpolationType {
+        Perspective,  // interpolate(perspective); WGSL's default when no attribute is given.
+        Linear,       // interpolate(linear)
+        Flat,         // interpolate(flat)
+    };
+
+    // Interpolation sampling of an inter-stage variable, converted from Tint's inspector.
+    // Sampling only applies to perspective and linear interpolation.
+    enum class InterpolationSampling {
+        None,      // No sampling qualifier.
+        Center,    // center; WGSL's default for perspective/linear when omitted.
+        Centroid,  // centroid
+        Sample,    // sample
+    };
+
     using PipelineLayoutEntryPointPair = std::pair<PipelineLayoutBase*, std::string>;
     struct PipelineLayoutEntryPointPairHashFunc {
         size_t operator()(const PipelineLayoutEntryPointPair& pair) const;
@@ -169,6 +182,8 @@
         struct InterStageVariableInfo {
             InterStageComponentType baseType;
             uint32_t componentCount;
+            InterpolationType interpolationType;          // Must match between stages.
+            InterpolationSampling interpolationSampling;  // Must match between stages.
         };
         // Now that we only support vertex and fragment stages, there can't be both inter-stage
         // inputs and outputs in one shader stage.
diff --git a/src/tests/unittests/validation/RenderPipelineValidationTests.cpp b/src/tests/unittests/validation/RenderPipelineValidationTests.cpp
index 628a37a..d6e60ed 100644
--- a/src/tests/unittests/validation/RenderPipelineValidationTests.cpp
+++ b/src/tests/unittests/validation/RenderPipelineValidationTests.cpp
@@ -1087,3 +1087,147 @@
         }
     }
 }
+
+// Tests that creating a render pipeline fails when the interpolation attribute of a vertex stage
+// output variable doesn't match the interpolation attribute of the fragment stage input variable
+// at the same location, and succeeds when both resolve to the same type and sampling.
+TEST_F(InterStageVariableMatchingValidationTest, DifferentInterpolationAttributeAtSameLocation) {
+    // Test-local stand-ins for the WGSL interpolation attribute values. 'None' means that part of
+    // the attribute is omitted from the generated WGSL.
+    enum class InterpolationType : uint8_t {
+        None = 0,
+        Perspective,
+        Linear,
+        Flat,
+        Count,
+    };
+    enum class InterpolationSampling : uint8_t {
+        None = 0,
+        Center,
+        Centroid,
+        Sample,
+        Count,
+    };
+    constexpr std::array<const char*, static_cast<size_t>(InterpolationType::Count)>
+        kInterpolationTypeString = {{"", "perspective", "linear", "flat"}};
+    constexpr std::array<const char*, static_cast<size_t>(InterpolationSampling::Count)>
+        kInterpolationSamplingString = {{"", "center", "centroid", "sample"}};
+
+    struct InterpolationAttribute {
+        InterpolationType interpolationType;
+        InterpolationSampling interpolationSampling;
+    };
+
+    // Interpolation sampling is not used with flat interpolation.
+    constexpr std::array<InterpolationAttribute, 10> validInterpolationAttributes = {{
+        {InterpolationType::None, InterpolationSampling::None},
+        {InterpolationType::Flat, InterpolationSampling::None},
+        {InterpolationType::Linear, InterpolationSampling::None},
+        {InterpolationType::Linear, InterpolationSampling::Center},
+        {InterpolationType::Linear, InterpolationSampling::Centroid},
+        {InterpolationType::Linear, InterpolationSampling::Sample},
+        {InterpolationType::Perspective, InterpolationSampling::None},
+        {InterpolationType::Perspective, InterpolationSampling::Center},
+        {InterpolationType::Perspective, InterpolationSampling::Centroid},
+        {InterpolationType::Perspective, InterpolationSampling::Sample},
+    }};
+
+    // Build one vertex module and one fragment module per attribute combination so that every
+    // (vertex, fragment) pairing can be checked below.
+    std::vector<wgpu::ShaderModule> vertexModules(validInterpolationAttributes.size());
+    std::vector<wgpu::ShaderModule> fragmentModules(validInterpolationAttributes.size());
+    for (uint32_t i = 0; i < validInterpolationAttributes.size(); ++i) {
+        std::string interfaceDeclaration;
+        {
+            const auto& interpolationAttribute = validInterpolationAttributes[i];
+            std::ostringstream sstream;
+            sstream << "struct A { [[location(0)";
+            if (interpolationAttribute.interpolationType != InterpolationType::None) {
+                sstream << ", interpolate("
+                        << kInterpolationTypeString[static_cast<uint8_t>(
+                               interpolationAttribute.interpolationType)];
+                if (interpolationAttribute.interpolationSampling != InterpolationSampling::None) {
+                    sstream << ", "
+                            << kInterpolationSamplingString[static_cast<uint8_t>(
+                                   interpolationAttribute.interpolationSampling)];
+                }
+                sstream << ")";
+            }
+            sstream << " ]] a : vec4<f32>;\n";
+            interfaceDeclaration = sstream.str();
+        }
+        {
+            std::ostringstream vertexStream;
+            vertexStream << interfaceDeclaration << R"(
+                    [[builtin(position)]] pos: vec4<f32>;
+                };
+                [[stage(vertex)]] fn main() -> A {
+                    var vertexOut: A;
+                    vertexOut.pos = vec4<f32>(0.0, 0.0, 0.0, 1.0);
+                    return vertexOut;
+                })";
+            vertexModules[i] = utils::CreateShaderModule(device, vertexStream.str().c_str());
+        }
+        {
+            std::ostringstream fragmentStream;
+            fragmentStream << interfaceDeclaration << R"(
+                };
+                [[stage(fragment)]] fn main(fragmentIn: A) -> [[location(0)]] vec4<f32> {
+                    return fragmentIn.a;
+                })";
+            fragmentModules[i] = utils::CreateShaderModule(device, fragmentStream.str().c_str());
+        }
+    }
+
+    auto GetAppliedInterpolationAttribute = [](const InterpolationAttribute& attribute) {
+        InterpolationAttribute appliedAttribute = {attribute.interpolationType,
+                                                   attribute.interpolationSampling};
+        switch (attribute.interpolationType) {
+            // If the interpolation attribute is not specified, then
+            // [[interpolate(perspective, center)]] or [[interpolate(perspective)]] is assumed.
+            case InterpolationType::None:
+                appliedAttribute.interpolationType = InterpolationType::Perspective;
+                appliedAttribute.interpolationSampling = InterpolationSampling::Center;
+                break;
+
+            // If the interpolation type is perspective or linear, and the interpolation
+            // sampling is not specified, then 'center' is assumed.
+            case InterpolationType::Perspective:
+            case InterpolationType::Linear:
+                if (appliedAttribute.interpolationSampling == InterpolationSampling::None) {
+                    appliedAttribute.interpolationSampling = InterpolationSampling::Center;
+                }
+                break;
+
+            case InterpolationType::Flat:
+                break;
+            default:
+                UNREACHABLE();
+        }
+        return appliedAttribute;
+    };
+
+    auto InterpolationAttributeMatch = [GetAppliedInterpolationAttribute](
+                                           const InterpolationAttribute& attribute1,
+                                           const InterpolationAttribute& attribute2) {
+        InterpolationAttribute appliedAttribute1 = GetAppliedInterpolationAttribute(attribute1);
+        InterpolationAttribute appliedAttribute2 = GetAppliedInterpolationAttribute(attribute2);
+
+        return appliedAttribute1.interpolationType == appliedAttribute2.interpolationType &&
+               appliedAttribute1.interpolationSampling == appliedAttribute2.interpolationSampling;
+    };
+
+    // Creating the pipeline must succeed exactly when the resolved attributes match.
+    for (uint32_t vertexModuleIndex = 0; vertexModuleIndex < validInterpolationAttributes.size();
+         ++vertexModuleIndex) {
+        const wgpu::ShaderModule& vertexModule = vertexModules[vertexModuleIndex];
+        for (uint32_t fragmentModuleIndex = 0;
+             fragmentModuleIndex < validInterpolationAttributes.size(); ++fragmentModuleIndex) {
+            const wgpu::ShaderModule& fragmentModule = fragmentModules[fragmentModuleIndex];
+            const bool shouldSucceed =
+                InterpolationAttributeMatch(validInterpolationAttributes[vertexModuleIndex],
+                                            validInterpolationAttributes[fragmentModuleIndex]);
+            CheckCreatingRenderPipeline(vertexModule, fragmentModule, shouldSucceed);
+        }
+    }
+}