Implement render pipeline vertex format base type validation.

Bug: dawn:1008

Change-Id: I04d1ff1d46c1106147a8c50415c989db5789cbfc
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/59031
Auto-Submit: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Jiawei Shao <jiawei.shao@intel.com>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Corentin Wallez <cwallez@chromium.org>
diff --git a/src/dawn_native/RenderPipeline.cpp b/src/dawn_native/RenderPipeline.cpp
index 3f59535..7392f1d 100644
--- a/src/dawn_native/RenderPipeline.cpp
+++ b/src/dawn_native/RenderPipeline.cpp
@@ -28,16 +28,19 @@
     // Helper functions
     namespace {
 
-        MaybeError ValidateVertexAttribute(DeviceBase* device,
-                                           const VertexAttribute* attribute,
-                                           uint64_t vertexBufferStride,
-                                           std::bitset<kMaxVertexAttributes>* attributesSetMask) {
+        MaybeError ValidateVertexAttribute(
+            DeviceBase* device,
+            const VertexAttribute* attribute,
+            const EntryPointMetadata& metadata,
+            uint64_t vertexBufferStride,
+            ityp::bitset<VertexAttributeLocation, kMaxVertexAttributes>* attributesSetMask) {
             DAWN_TRY(ValidateVertexFormat(attribute->format));
             const VertexFormatInfo& formatInfo = GetVertexFormatInfo(attribute->format);
 
             if (attribute->shaderLocation >= kMaxVertexAttributes) {
                 return DAWN_VALIDATION_ERROR("Setting attribute out of bounds");
             }
+            VertexAttributeLocation location(static_cast<uint8_t>(attribute->shaderLocation));
 
             // No underflow is possible because the max vertex format size is smaller than
             // kMaxVertexBufferArrayStride.
@@ -59,18 +62,25 @@
                     "Attribute offset needs to be a multiple of the size format's components");
             }
 
-            if ((*attributesSetMask)[attribute->shaderLocation]) {
+            if (metadata.usedVertexInputs[location] &&
+                formatInfo.baseType != metadata.vertexInputBaseTypes[location]) {
+                return DAWN_VALIDATION_ERROR(
+                    "Attribute base type must match the base type in the shader.");
+            }
+
+            if ((*attributesSetMask)[location]) {
                 return DAWN_VALIDATION_ERROR("Setting already set attribute");
             }
 
-            attributesSetMask->set(attribute->shaderLocation);
+            attributesSetMask->set(location);
             return {};
         }
 
         MaybeError ValidateVertexBufferLayout(
             DeviceBase* device,
             const VertexBufferLayout* buffer,
-            std::bitset<kMaxVertexAttributes>* attributesSetMask) {
+            const EntryPointMetadata& metadata,
+            ityp::bitset<VertexAttributeLocation, kMaxVertexAttributes>* attributesSetMask) {
             DAWN_TRY(ValidateInputStepMode(buffer->stepMode));
             if (buffer->arrayStride > kMaxVertexBufferArrayStride) {
                 return DAWN_VALIDATION_ERROR("Setting arrayStride out of bounds");
@@ -82,7 +92,7 @@
             }
 
             for (uint32_t i = 0; i < buffer->attributeCount; ++i) {
-                DAWN_TRY(ValidateVertexAttribute(device, &buffer->attributes[i],
+                DAWN_TRY(ValidateVertexAttribute(device, &buffer->attributes[i], metadata,
                                                  buffer->arrayStride, attributesSetMask));
             }
 
@@ -100,10 +110,15 @@
                 return DAWN_VALIDATION_ERROR("Vertex buffer count exceeds maximum");
             }
 
-            std::bitset<kMaxVertexAttributes> attributesSetMask;
+            DAWN_TRY(ValidateProgrammableStage(device, descriptor->module, descriptor->entryPoint,
+                                               layout, SingleShaderStage::Vertex));
+            const EntryPointMetadata& vertexMetadata =
+                descriptor->module->GetEntryPoint(descriptor->entryPoint);
+
+            ityp::bitset<VertexAttributeLocation, kMaxVertexAttributes> attributesSetMask;
             uint32_t totalAttributesNum = 0;
             for (uint32_t i = 0; i < descriptor->bufferCount; ++i) {
-                DAWN_TRY(ValidateVertexBufferLayout(device, &descriptor->buffers[i],
+                DAWN_TRY(ValidateVertexBufferLayout(device, &descriptor->buffers[i], vertexMetadata,
                                                     &attributesSetMask));
                 totalAttributesNum += descriptor->buffers[i].attributeCount;
             }
@@ -114,11 +129,7 @@
             // attribute number never exceed kMaxVertexAttributes.
             ASSERT(totalAttributesNum <= kMaxVertexAttributes);
 
-            DAWN_TRY(ValidateProgrammableStage(device, descriptor->module, descriptor->entryPoint,
-                                               layout, SingleShaderStage::Vertex));
-            const EntryPointMetadata& vertexMetadata =
-                descriptor->module->GetEntryPoint(descriptor->entryPoint);
-            if (!IsSubset(vertexMetadata.usedVertexAttributes, attributesSetMask)) {
+            if (!IsSubset(vertexMetadata.usedVertexInputs, attributesSetMask)) {
                 return DAWN_VALIDATION_ERROR(
                     "Pipeline vertex stage uses vertex buffers not in the vertex state");
             }
diff --git a/src/dawn_native/ShaderModule.cpp b/src/dawn_native/ShaderModule.cpp
index 9bd3b7e..55b771d 100644
--- a/src/dawn_native/ShaderModule.cpp
+++ b/src/dawn_native/ShaderModule.cpp
@@ -295,6 +295,21 @@
             }
         }
 
+        ResultOrError<VertexFormatBaseType> TintComponentTypeToVertexFormatBaseType(
+            tint::inspector::ComponentType type) {
+            switch (type) {
+                case tint::inspector::ComponentType::kFloat:
+                    return VertexFormatBaseType::Float;
+                case tint::inspector::ComponentType::kSInt:
+                    return VertexFormatBaseType::Sint;
+                case tint::inspector::ComponentType::kUInt:
+                    return VertexFormatBaseType::Uint;
+                case tint::inspector::ComponentType::kUnknown:
+                    return DAWN_VALIDATION_ERROR(
+                        "Attempted to convert 'Unknown' component type from Tint");
+            }
+        }
+
         ResultOrError<wgpu::BufferBindingType> TintResourceTypeToBufferBindingType(
             tint::inspector::ResourceBinding::ResourceType resource_type) {
             switch (resource_type) {
@@ -811,13 +826,19 @@
                         return DAWN_VALIDATION_ERROR(
                             "Unable to find Location decoration for Vertex input");
                     }
-                    uint32_t location = compiler.get_decoration(attrib.id, spv::DecorationLocation);
+                    uint32_t unsanitizedLocation =
+                        compiler.get_decoration(attrib.id, spv::DecorationLocation);
 
-                    if (location >= kMaxVertexAttributes) {
+                    if (unsanitizedLocation >= kMaxVertexAttributes) {
                         return DAWN_VALIDATION_ERROR("Attribute location over limits in the SPIRV");
                     }
+                    VertexAttributeLocation location(static_cast<uint8_t>(unsanitizedLocation));
 
-                    metadata->usedVertexAttributes.set(location);
+                    spirv_cross::SPIRType::BaseType inputBaseType =
+                        compiler.get_type(attrib.base_type_id).basetype;
+                    metadata->vertexInputBaseTypes[location] =
+                        SpirvBaseTypeToVertexFormatBaseType(inputBaseType);
+                    metadata->usedVertexInputs.set(location);
                 }
 
                 // Without a location qualifier on vertex outputs, spirv_cross::CompilerMSL gives
@@ -846,6 +867,7 @@
                     }
                     uint32_t unsanitizedAttachment =
                         compiler.get_decoration(fragmentOutput.id, spv::DecorationLocation);
+
                     if (unsanitizedAttachment >= kMaxColorAttachments) {
                         return DAWN_VALIDATION_ERROR(
                             "Fragment output index must be less than max number of color "
@@ -958,13 +980,17 @@
                             return DAWN_VALIDATION_ERROR(
                                 "Need Location decoration on Vertex input");
                         }
-                        uint32_t location = input_var.location_decoration;
-                        if (DAWN_UNLIKELY(location >= kMaxVertexAttributes)) {
+                        uint32_t unsanitizedLocation = input_var.location_decoration;
+                        if (DAWN_UNLIKELY(unsanitizedLocation >= kMaxVertexAttributes)) {
                             std::stringstream ss;
-                            ss << "Attribute location (" << location << ") over limits";
+                            ss << "Attribute location (" << unsanitizedLocation << ") over limits";
                             return DAWN_VALIDATION_ERROR(ss.str());
                         }
-                        metadata->usedVertexAttributes.set(location);
+                        VertexAttributeLocation location(static_cast<uint8_t>(unsanitizedLocation));
+                        DAWN_TRY_ASSIGN(
+                            metadata->vertexInputBaseTypes[location],
+                            TintComponentTypeToVertexFormatBaseType(input_var.component_type));
+                        metadata->usedVertexInputs.set(location);
                     }
 
                     for (const auto& output_var : entryPoint.output_variables) {
diff --git a/src/dawn_native/ShaderModule.h b/src/dawn_native/ShaderModule.h
index 2717042..da948e3 100644
--- a/src/dawn_native/ShaderModule.h
+++ b/src/dawn_native/ShaderModule.h
@@ -25,6 +25,7 @@
 #include "dawn_native/Forward.h"
 #include "dawn_native/IntegerTypes.h"
 #include "dawn_native/PerStage.h"
+#include "dawn_native/VertexFormat.h"
 #include "dawn_native/dawn_platform.h"
 
 #include <bitset>
@@ -147,7 +148,9 @@
         std::vector<SamplerTexturePair> samplerTexturePairs;
 
         // The set of vertex attributes this entryPoint uses.
-        std::bitset<kMaxVertexAttributes> usedVertexAttributes;
+        ityp::array<VertexAttributeLocation, VertexFormatBaseType, kMaxVertexAttributes>
+            vertexInputBaseTypes;
+        ityp::bitset<VertexAttributeLocation, kMaxVertexAttributes> usedVertexInputs;
 
         // An array to record the basic types (float, int and uint) of the fragment shader outputs.
         ityp::array<ColorAttachmentIndex, wgpu::TextureComponentType, kMaxColorAttachments>
diff --git a/src/dawn_native/SpirvUtils.cpp b/src/dawn_native/SpirvUtils.cpp
index 9472508..01749de 100644
--- a/src/dawn_native/SpirvUtils.cpp
+++ b/src/dawn_native/SpirvUtils.cpp
@@ -161,4 +161,18 @@
         }
     }
 
+    VertexFormatBaseType SpirvBaseTypeToVertexFormatBaseType(
+        spirv_cross::SPIRType::BaseType spirvBaseType) {
+        switch (spirvBaseType) {
+            case spirv_cross::SPIRType::Float:
+                return VertexFormatBaseType::Float;
+            case spirv_cross::SPIRType::Int:
+                return VertexFormatBaseType::Sint;
+            case spirv_cross::SPIRType::UInt:
+                return VertexFormatBaseType::Uint;
+            default:
+                UNREACHABLE();
+        }
+    }
+
 }  // namespace dawn_native
diff --git a/src/dawn_native/SpirvUtils.h b/src/dawn_native/SpirvUtils.h
index 158b165..ff356df 100644
--- a/src/dawn_native/SpirvUtils.h
+++ b/src/dawn_native/SpirvUtils.h
@@ -20,6 +20,7 @@
 
 #include "dawn_native/Format.h"
 #include "dawn_native/PerStage.h"
+#include "dawn_native/VertexFormat.h"
 #include "dawn_native/dawn_platform.h"
 
 #include <spirv_cross.hpp>
@@ -41,6 +42,10 @@
         spirv_cross::SPIRType::BaseType spirvBaseType);
     SampleTypeBit SpirvBaseTypeToSampleTypeBit(spirv_cross::SPIRType::BaseType spirvBaseType);
 
+    // Returns the VertexFormatBaseType corresponding to the SPIRV base type.
+    VertexFormatBaseType SpirvBaseTypeToVertexFormatBaseType(
+        spirv_cross::SPIRType::BaseType spirvBaseType);
+
 }  // namespace dawn_native
 
 #endif  // DAWNNATIVE_SPIRV_UTILS_H_
diff --git a/src/dawn_native/metal/ShaderModuleMTL.mm b/src/dawn_native/metal/ShaderModuleMTL.mm
index d6510e1..03652c6 100644
--- a/src/dawn_native/metal/ShaderModuleMTL.mm
+++ b/src/dawn_native/metal/ShaderModuleMTL.mm
@@ -356,7 +356,7 @@
         out->function = AcquireNSPRef([*library newFunctionWithName:name.Get()]);
 
         if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) &&
-            GetEntryPoint(entryPointName).usedVertexAttributes.any()) {
+            GetEntryPoint(entryPointName).usedVertexInputs.any()) {
             out->needsStorageBufferLength = true;
         }
 
diff --git a/src/tests/unittests/validation/VertexStateValidationTests.cpp b/src/tests/unittests/validation/VertexStateValidationTests.cpp
index 9ebea1e..974dacb 100644
--- a/src/tests/unittests/validation/VertexStateValidationTests.cpp
+++ b/src/tests/unittests/validation/VertexStateValidationTests.cpp
@@ -306,7 +306,7 @@
     state.cAttributes[0].offset = 2;
     CreatePipeline(true, state, kDummyVertexShader);
 
-    state.cAttributes[0].format = wgpu::VertexFormat::Uint8x2;
+    state.cAttributes[0].format = wgpu::VertexFormat::Unorm8x2;
     state.cAttributes[0].offset = 1;
     CreatePipeline(true, state, kDummyVertexShader);
 
@@ -338,3 +338,80 @@
     state.cAttributes[0].format = wgpu::VertexFormat::Float32x4;
     CreatePipeline(false, state, kDummyVertexShader);
 }
+
+// Check that the vertex format base type must match the shader's variable base type.
+TEST_F(VertexStateTest, BaseTypeMatching) {
+    auto DoTest = [&](wgpu::VertexFormat format, std::string shaderType, bool success) {
+        utils::ComboVertexStateDescriptor state;
+        state.vertexBufferCount = 1;
+        state.cVertexBuffers[0].arrayStride = 16;
+        state.cVertexBuffers[0].attributeCount = 1;
+        state.cAttributes[0].format = format;
+
+        std::string shader = "[[stage(vertex)]] fn main([[location(0)]] attrib : " + shaderType +
+                             R"() -> [[builtin(position)]] vec4<f32> {
+            return vec4<f32>(0.0, 0.0, 0.0, 0.0);
+        })";
+
+        CreatePipeline(success, state, shader.c_str());
+    };
+
+    // Test that a float format is compatible only with f32 base type.
+    DoTest(wgpu::VertexFormat::Float32, "f32", true);
+    DoTest(wgpu::VertexFormat::Float32, "i32", false);
+    DoTest(wgpu::VertexFormat::Float32, "u32", false);
+
+    // Test that an unorm format is compatible only with f32.
+    DoTest(wgpu::VertexFormat::Unorm16x2, "f32", true);
+    DoTest(wgpu::VertexFormat::Unorm16x2, "i32", false);
+    DoTest(wgpu::VertexFormat::Unorm16x2, "u32", false);
+
+    // Test that an snorm format is compatible only with f32.
+    DoTest(wgpu::VertexFormat::Snorm16x4, "f32", true);
+    DoTest(wgpu::VertexFormat::Snorm16x4, "i32", false);
+    DoTest(wgpu::VertexFormat::Snorm16x4, "u32", false);
+
+    // Test that an uint format is compatible only with u32.
+    DoTest(wgpu::VertexFormat::Uint32x3, "f32", false);
+    DoTest(wgpu::VertexFormat::Uint32x3, "i32", false);
+    DoTest(wgpu::VertexFormat::Uint32x3, "u32", true);
+
+    // Test that an sint format is compatible only with u32.
+    DoTest(wgpu::VertexFormat::Sint8x4, "f32", false);
+    DoTest(wgpu::VertexFormat::Sint8x4, "i32", true);
+    DoTest(wgpu::VertexFormat::Sint8x4, "u32", false);
+
+    // Test that formats are compatible with any width of vectors.
+    DoTest(wgpu::VertexFormat::Float32, "f32", true);
+    DoTest(wgpu::VertexFormat::Float32, "vec2<f32>", true);
+    DoTest(wgpu::VertexFormat::Float32, "vec3<f32>", true);
+    DoTest(wgpu::VertexFormat::Float32, "vec4<f32>", true);
+
+    DoTest(wgpu::VertexFormat::Float32x4, "f32", true);
+    DoTest(wgpu::VertexFormat::Float32x4, "vec2<f32>", true);
+    DoTest(wgpu::VertexFormat::Float32x4, "vec3<f32>", true);
+    DoTest(wgpu::VertexFormat::Float32x4, "vec4<f32>", true);
+}
+
+// Check that we only check base type compatibility for vertex inputs the shader uses.
+TEST_F(VertexStateTest, BaseTypeMatchingForInexistentInput) {
+    auto DoTest = [&](wgpu::VertexFormat format) {
+        utils::ComboVertexStateDescriptor state;
+        state.vertexBufferCount = 1;
+        state.cVertexBuffers[0].arrayStride = 16;
+        state.cVertexBuffers[0].attributeCount = 1;
+        state.cAttributes[0].format = format;
+
+        std::string shader = R"([[stage(vertex)]] fn main() -> [[builtin(position)]] vec4<f32> {
+            return vec4<f32>(0.0, 0.0, 0.0, 0.0);
+        })";
+
+        CreatePipeline(true, state, shader.c_str());
+    };
+
+    DoTest(wgpu::VertexFormat::Float32);
+    DoTest(wgpu::VertexFormat::Unorm16x2);
+    DoTest(wgpu::VertexFormat::Snorm16x4);
+    DoTest(wgpu::VertexFormat::Uint8x4);
+    DoTest(wgpu::VertexFormat::Sint32x2);
+}