Implement render pipeline vertex format base type validation.
Bug: dawn:1008
Change-Id: I04d1ff1d46c1106147a8c50415c989db5789cbfc
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/59031
Auto-Submit: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Jiawei Shao <jiawei.shao@intel.com>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Corentin Wallez <cwallez@chromium.org>
diff --git a/src/dawn_native/RenderPipeline.cpp b/src/dawn_native/RenderPipeline.cpp
index 3f59535..7392f1d 100644
--- a/src/dawn_native/RenderPipeline.cpp
+++ b/src/dawn_native/RenderPipeline.cpp
@@ -28,16 +28,19 @@
// Helper functions
namespace {
- MaybeError ValidateVertexAttribute(DeviceBase* device,
- const VertexAttribute* attribute,
- uint64_t vertexBufferStride,
- std::bitset<kMaxVertexAttributes>* attributesSetMask) {
+ MaybeError ValidateVertexAttribute(
+ DeviceBase* device,
+ const VertexAttribute* attribute,
+ const EntryPointMetadata& metadata,
+ uint64_t vertexBufferStride,
+ ityp::bitset<VertexAttributeLocation, kMaxVertexAttributes>* attributesSetMask) {
DAWN_TRY(ValidateVertexFormat(attribute->format));
const VertexFormatInfo& formatInfo = GetVertexFormatInfo(attribute->format);
if (attribute->shaderLocation >= kMaxVertexAttributes) {
return DAWN_VALIDATION_ERROR("Setting attribute out of bounds");
}
+ VertexAttributeLocation location(static_cast<uint8_t>(attribute->shaderLocation));
// No underflow is possible because the max vertex format size is smaller than
// kMaxVertexBufferArrayStride.
@@ -59,18 +62,25 @@
"Attribute offset needs to be a multiple of the size format's components");
}
- if ((*attributesSetMask)[attribute->shaderLocation]) {
+ if (metadata.usedVertexInputs[location] &&
+ formatInfo.baseType != metadata.vertexInputBaseTypes[location]) {
+ return DAWN_VALIDATION_ERROR(
+ "Attribute base type must match the base type in the shader.");
+ }
+
+ if ((*attributesSetMask)[location]) {
return DAWN_VALIDATION_ERROR("Setting already set attribute");
}
- attributesSetMask->set(attribute->shaderLocation);
+ attributesSetMask->set(location);
return {};
}
MaybeError ValidateVertexBufferLayout(
DeviceBase* device,
const VertexBufferLayout* buffer,
- std::bitset<kMaxVertexAttributes>* attributesSetMask) {
+ const EntryPointMetadata& metadata,
+ ityp::bitset<VertexAttributeLocation, kMaxVertexAttributes>* attributesSetMask) {
DAWN_TRY(ValidateInputStepMode(buffer->stepMode));
if (buffer->arrayStride > kMaxVertexBufferArrayStride) {
return DAWN_VALIDATION_ERROR("Setting arrayStride out of bounds");
@@ -82,7 +92,7 @@
}
for (uint32_t i = 0; i < buffer->attributeCount; ++i) {
- DAWN_TRY(ValidateVertexAttribute(device, &buffer->attributes[i],
+ DAWN_TRY(ValidateVertexAttribute(device, &buffer->attributes[i], metadata,
buffer->arrayStride, attributesSetMask));
}
@@ -100,10 +110,15 @@
return DAWN_VALIDATION_ERROR("Vertex buffer count exceeds maximum");
}
- std::bitset<kMaxVertexAttributes> attributesSetMask;
+ DAWN_TRY(ValidateProgrammableStage(device, descriptor->module, descriptor->entryPoint,
+ layout, SingleShaderStage::Vertex));
+ const EntryPointMetadata& vertexMetadata =
+ descriptor->module->GetEntryPoint(descriptor->entryPoint);
+
+ ityp::bitset<VertexAttributeLocation, kMaxVertexAttributes> attributesSetMask;
uint32_t totalAttributesNum = 0;
for (uint32_t i = 0; i < descriptor->bufferCount; ++i) {
- DAWN_TRY(ValidateVertexBufferLayout(device, &descriptor->buffers[i],
+ DAWN_TRY(ValidateVertexBufferLayout(device, &descriptor->buffers[i], vertexMetadata,
&attributesSetMask));
totalAttributesNum += descriptor->buffers[i].attributeCount;
}
@@ -114,11 +129,7 @@
// attribute number never exceed kMaxVertexAttributes.
ASSERT(totalAttributesNum <= kMaxVertexAttributes);
- DAWN_TRY(ValidateProgrammableStage(device, descriptor->module, descriptor->entryPoint,
- layout, SingleShaderStage::Vertex));
- const EntryPointMetadata& vertexMetadata =
- descriptor->module->GetEntryPoint(descriptor->entryPoint);
- if (!IsSubset(vertexMetadata.usedVertexAttributes, attributesSetMask)) {
+ if (!IsSubset(vertexMetadata.usedVertexInputs, attributesSetMask)) {
return DAWN_VALIDATION_ERROR(
"Pipeline vertex stage uses vertex buffers not in the vertex state");
}
diff --git a/src/dawn_native/ShaderModule.cpp b/src/dawn_native/ShaderModule.cpp
index 9bd3b7e..55b771d 100644
--- a/src/dawn_native/ShaderModule.cpp
+++ b/src/dawn_native/ShaderModule.cpp
@@ -295,6 +295,21 @@
}
}
+ ResultOrError<VertexFormatBaseType> TintComponentTypeToVertexFormatBaseType(
+ tint::inspector::ComponentType type) {
+ switch (type) {
+ case tint::inspector::ComponentType::kFloat:
+ return VertexFormatBaseType::Float;
+ case tint::inspector::ComponentType::kSInt:
+ return VertexFormatBaseType::Sint;
+ case tint::inspector::ComponentType::kUInt:
+ return VertexFormatBaseType::Uint;
+ case tint::inspector::ComponentType::kUnknown:
+ return DAWN_VALIDATION_ERROR(
+ "Attempted to convert 'Unknown' component type from Tint");
+ }
+ }
+
ResultOrError<wgpu::BufferBindingType> TintResourceTypeToBufferBindingType(
tint::inspector::ResourceBinding::ResourceType resource_type) {
switch (resource_type) {
@@ -811,13 +826,19 @@
return DAWN_VALIDATION_ERROR(
"Unable to find Location decoration for Vertex input");
}
- uint32_t location = compiler.get_decoration(attrib.id, spv::DecorationLocation);
+ uint32_t unsanitizedLocation =
+ compiler.get_decoration(attrib.id, spv::DecorationLocation);
- if (location >= kMaxVertexAttributes) {
+ if (unsanitizedLocation >= kMaxVertexAttributes) {
return DAWN_VALIDATION_ERROR("Attribute location over limits in the SPIRV");
}
+ VertexAttributeLocation location(static_cast<uint8_t>(unsanitizedLocation));
- metadata->usedVertexAttributes.set(location);
+ spirv_cross::SPIRType::BaseType inputBaseType =
+ compiler.get_type(attrib.base_type_id).basetype;
+ metadata->vertexInputBaseTypes[location] =
+ SpirvBaseTypeToVertexFormatBaseType(inputBaseType);
+ metadata->usedVertexInputs.set(location);
}
// Without a location qualifier on vertex outputs, spirv_cross::CompilerMSL gives
@@ -846,6 +867,7 @@
}
uint32_t unsanitizedAttachment =
compiler.get_decoration(fragmentOutput.id, spv::DecorationLocation);
+
if (unsanitizedAttachment >= kMaxColorAttachments) {
return DAWN_VALIDATION_ERROR(
"Fragment output index must be less than max number of color "
@@ -958,13 +980,17 @@
return DAWN_VALIDATION_ERROR(
"Need Location decoration on Vertex input");
}
- uint32_t location = input_var.location_decoration;
- if (DAWN_UNLIKELY(location >= kMaxVertexAttributes)) {
+ uint32_t unsanitizedLocation = input_var.location_decoration;
+ if (DAWN_UNLIKELY(unsanitizedLocation >= kMaxVertexAttributes)) {
std::stringstream ss;
- ss << "Attribute location (" << location << ") over limits";
+ ss << "Attribute location (" << unsanitizedLocation << ") over limits";
return DAWN_VALIDATION_ERROR(ss.str());
}
- metadata->usedVertexAttributes.set(location);
+ VertexAttributeLocation location(static_cast<uint8_t>(unsanitizedLocation));
+ DAWN_TRY_ASSIGN(
+ metadata->vertexInputBaseTypes[location],
+ TintComponentTypeToVertexFormatBaseType(input_var.component_type));
+ metadata->usedVertexInputs.set(location);
}
for (const auto& output_var : entryPoint.output_variables) {
diff --git a/src/dawn_native/ShaderModule.h b/src/dawn_native/ShaderModule.h
index 2717042..da948e3 100644
--- a/src/dawn_native/ShaderModule.h
+++ b/src/dawn_native/ShaderModule.h
@@ -25,6 +25,7 @@
#include "dawn_native/Forward.h"
#include "dawn_native/IntegerTypes.h"
#include "dawn_native/PerStage.h"
+#include "dawn_native/VertexFormat.h"
#include "dawn_native/dawn_platform.h"
#include <bitset>
@@ -147,7 +148,9 @@
std::vector<SamplerTexturePair> samplerTexturePairs;
// The set of vertex attributes this entryPoint uses.
- std::bitset<kMaxVertexAttributes> usedVertexAttributes;
+ ityp::array<VertexAttributeLocation, VertexFormatBaseType, kMaxVertexAttributes>
+ vertexInputBaseTypes;
+ ityp::bitset<VertexAttributeLocation, kMaxVertexAttributes> usedVertexInputs;
// An array to record the basic types (float, int and uint) of the fragment shader outputs.
ityp::array<ColorAttachmentIndex, wgpu::TextureComponentType, kMaxColorAttachments>
diff --git a/src/dawn_native/SpirvUtils.cpp b/src/dawn_native/SpirvUtils.cpp
index 9472508..01749de 100644
--- a/src/dawn_native/SpirvUtils.cpp
+++ b/src/dawn_native/SpirvUtils.cpp
@@ -161,4 +161,18 @@
}
}
+ VertexFormatBaseType SpirvBaseTypeToVertexFormatBaseType(
+ spirv_cross::SPIRType::BaseType spirvBaseType) {
+ switch (spirvBaseType) {
+ case spirv_cross::SPIRType::Float:
+ return VertexFormatBaseType::Float;
+ case spirv_cross::SPIRType::Int:
+ return VertexFormatBaseType::Sint;
+ case spirv_cross::SPIRType::UInt:
+ return VertexFormatBaseType::Uint;
+ default:
+ UNREACHABLE();
+ }
+ }
+
} // namespace dawn_native
diff --git a/src/dawn_native/SpirvUtils.h b/src/dawn_native/SpirvUtils.h
index 158b165..ff356df 100644
--- a/src/dawn_native/SpirvUtils.h
+++ b/src/dawn_native/SpirvUtils.h
@@ -20,6 +20,7 @@
#include "dawn_native/Format.h"
#include "dawn_native/PerStage.h"
+#include "dawn_native/VertexFormat.h"
#include "dawn_native/dawn_platform.h"
#include <spirv_cross.hpp>
@@ -41,6 +42,10 @@
spirv_cross::SPIRType::BaseType spirvBaseType);
SampleTypeBit SpirvBaseTypeToSampleTypeBit(spirv_cross::SPIRType::BaseType spirvBaseType);
+ // Returns the VertexFormatBaseType corresponding to the SPIRV base type.
+ VertexFormatBaseType SpirvBaseTypeToVertexFormatBaseType(
+ spirv_cross::SPIRType::BaseType spirvBaseType);
+
} // namespace dawn_native
#endif // DAWNNATIVE_SPIRV_UTILS_H_
diff --git a/src/dawn_native/metal/ShaderModuleMTL.mm b/src/dawn_native/metal/ShaderModuleMTL.mm
index d6510e1..03652c6 100644
--- a/src/dawn_native/metal/ShaderModuleMTL.mm
+++ b/src/dawn_native/metal/ShaderModuleMTL.mm
@@ -356,7 +356,7 @@
out->function = AcquireNSPRef([*library newFunctionWithName:name.Get()]);
if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) &&
- GetEntryPoint(entryPointName).usedVertexAttributes.any()) {
+ GetEntryPoint(entryPointName).usedVertexInputs.any()) {
out->needsStorageBufferLength = true;
}
diff --git a/src/tests/unittests/validation/VertexStateValidationTests.cpp b/src/tests/unittests/validation/VertexStateValidationTests.cpp
index 9ebea1e..974dacb 100644
--- a/src/tests/unittests/validation/VertexStateValidationTests.cpp
+++ b/src/tests/unittests/validation/VertexStateValidationTests.cpp
@@ -306,7 +306,7 @@
state.cAttributes[0].offset = 2;
CreatePipeline(true, state, kDummyVertexShader);
- state.cAttributes[0].format = wgpu::VertexFormat::Uint8x2;
+ state.cAttributes[0].format = wgpu::VertexFormat::Unorm8x2;
state.cAttributes[0].offset = 1;
CreatePipeline(true, state, kDummyVertexShader);
@@ -338,3 +338,80 @@
state.cAttributes[0].format = wgpu::VertexFormat::Float32x4;
CreatePipeline(false, state, kDummyVertexShader);
}
+
+// Check that the vertex format base type must match the shader's variable base type.
+TEST_F(VertexStateTest, BaseTypeMatching) {
+ auto DoTest = [&](wgpu::VertexFormat format, std::string shaderType, bool success) {
+ utils::ComboVertexStateDescriptor state;
+ state.vertexBufferCount = 1;
+ state.cVertexBuffers[0].arrayStride = 16;
+ state.cVertexBuffers[0].attributeCount = 1;
+ state.cAttributes[0].format = format;
+
+ std::string shader = "[[stage(vertex)]] fn main([[location(0)]] attrib : " + shaderType +
+ R"() -> [[builtin(position)]] vec4<f32> {
+ return vec4<f32>(0.0, 0.0, 0.0, 0.0);
+ })";
+
+ CreatePipeline(success, state, shader.c_str());
+ };
+
+ // Test that a float format is compatible only with f32 base type.
+ DoTest(wgpu::VertexFormat::Float32, "f32", true);
+ DoTest(wgpu::VertexFormat::Float32, "i32", false);
+ DoTest(wgpu::VertexFormat::Float32, "u32", false);
+
+ // Test that an unorm format is compatible only with f32.
+ DoTest(wgpu::VertexFormat::Unorm16x2, "f32", true);
+ DoTest(wgpu::VertexFormat::Unorm16x2, "i32", false);
+ DoTest(wgpu::VertexFormat::Unorm16x2, "u32", false);
+
+ // Test that an snorm format is compatible only with f32.
+ DoTest(wgpu::VertexFormat::Snorm16x4, "f32", true);
+ DoTest(wgpu::VertexFormat::Snorm16x4, "i32", false);
+ DoTest(wgpu::VertexFormat::Snorm16x4, "u32", false);
+
+ // Test that an uint format is compatible only with u32.
+ DoTest(wgpu::VertexFormat::Uint32x3, "f32", false);
+ DoTest(wgpu::VertexFormat::Uint32x3, "i32", false);
+ DoTest(wgpu::VertexFormat::Uint32x3, "u32", true);
+
+ // Test that an sint format is compatible only with u32.
+ DoTest(wgpu::VertexFormat::Sint8x4, "f32", false);
+ DoTest(wgpu::VertexFormat::Sint8x4, "i32", true);
+ DoTest(wgpu::VertexFormat::Sint8x4, "u32", false);
+
+ // Test that formats are compatible with any width of vectors.
+ DoTest(wgpu::VertexFormat::Float32, "f32", true);
+ DoTest(wgpu::VertexFormat::Float32, "vec2<f32>", true);
+ DoTest(wgpu::VertexFormat::Float32, "vec3<f32>", true);
+ DoTest(wgpu::VertexFormat::Float32, "vec4<f32>", true);
+
+ DoTest(wgpu::VertexFormat::Float32x4, "f32", true);
+ DoTest(wgpu::VertexFormat::Float32x4, "vec2<f32>", true);
+ DoTest(wgpu::VertexFormat::Float32x4, "vec3<f32>", true);
+ DoTest(wgpu::VertexFormat::Float32x4, "vec4<f32>", true);
+}
+
+// Check that we only check base type compatibility for vertex inputs the shader uses.
+TEST_F(VertexStateTest, BaseTypeMatchingForInexistentInput) {
+ auto DoTest = [&](wgpu::VertexFormat format) {
+ utils::ComboVertexStateDescriptor state;
+ state.vertexBufferCount = 1;
+ state.cVertexBuffers[0].arrayStride = 16;
+ state.cVertexBuffers[0].attributeCount = 1;
+ state.cAttributes[0].format = format;
+
+ std::string shader = R"([[stage(vertex)]] fn main() -> [[builtin(position)]] vec4<f32> {
+ return vec4<f32>(0.0, 0.0, 0.0, 0.0);
+ })";
+
+ CreatePipeline(true, state, shader.c_str());
+ };
+
+ DoTest(wgpu::VertexFormat::Float32);
+ DoTest(wgpu::VertexFormat::Unorm16x2);
+ DoTest(wgpu::VertexFormat::Snorm16x4);
+ DoTest(wgpu::VertexFormat::Uint8x4);
+ DoTest(wgpu::VertexFormat::Sint32x2);
+}