Add Metal vertex pulling behind a flag
Implements vertex pulling on the Metal backend, hidden behind a flag
until ready for use (we are missing support for more complicated vertex
input types).
Bug: dawn:480
Change-Id: I38028b80673693ebf21309ad5336561fb99f40dc
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/26522
Commit-Queue: Idan Raiter <idanr@google.com>
Reviewed-by: Austin Eng <enga@chromium.org>
diff --git a/src/dawn_native/ShaderModule.cpp b/src/dawn_native/ShaderModule.cpp
index 40ed372..4fdd708 100644
--- a/src/dawn_native/ShaderModule.cpp
+++ b/src/dawn_native/ShaderModule.cpp
@@ -295,6 +295,82 @@
<< " binding " << static_cast<uint32_t>(binding);
return ostream.str();
}
+
+#ifdef DAWN_ENABLE_WGSL
+ tint::ast::transform::VertexFormat ToTintVertexFormat(wgpu::VertexFormat format) {
+ switch (format) {
+ case wgpu::VertexFormat::UChar2:
+ return tint::ast::transform::VertexFormat::kVec2U8;
+ case wgpu::VertexFormat::UChar4:
+ return tint::ast::transform::VertexFormat::kVec4U8;
+ case wgpu::VertexFormat::Char2:
+ return tint::ast::transform::VertexFormat::kVec2I8;
+ case wgpu::VertexFormat::Char4:
+ return tint::ast::transform::VertexFormat::kVec4I8;
+ case wgpu::VertexFormat::UChar2Norm:
+ return tint::ast::transform::VertexFormat::kVec2U8Norm;
+ case wgpu::VertexFormat::UChar4Norm:
+ return tint::ast::transform::VertexFormat::kVec4U8Norm;
+ case wgpu::VertexFormat::Char2Norm:
+ return tint::ast::transform::VertexFormat::kVec2I8Norm;
+ case wgpu::VertexFormat::Char4Norm:
+ return tint::ast::transform::VertexFormat::kVec4I8Norm;
+ case wgpu::VertexFormat::UShort2:
+ return tint::ast::transform::VertexFormat::kVec2U16;
+ case wgpu::VertexFormat::UShort4:
+ return tint::ast::transform::VertexFormat::kVec4U16;
+ case wgpu::VertexFormat::Short2:
+ return tint::ast::transform::VertexFormat::kVec2I16;
+ case wgpu::VertexFormat::Short4:
+ return tint::ast::transform::VertexFormat::kVec4I16;
+ case wgpu::VertexFormat::UShort2Norm:
+ return tint::ast::transform::VertexFormat::kVec2U16Norm;
+ case wgpu::VertexFormat::UShort4Norm:
+ return tint::ast::transform::VertexFormat::kVec4U16Norm;
+ case wgpu::VertexFormat::Short2Norm:
+ return tint::ast::transform::VertexFormat::kVec2I16Norm;
+ case wgpu::VertexFormat::Short4Norm:
+ return tint::ast::transform::VertexFormat::kVec4I16Norm;
+ case wgpu::VertexFormat::Half2:
+ return tint::ast::transform::VertexFormat::kVec2F16;
+ case wgpu::VertexFormat::Half4:
+ return tint::ast::transform::VertexFormat::kVec4F16;
+ case wgpu::VertexFormat::Float:
+ return tint::ast::transform::VertexFormat::kF32;
+ case wgpu::VertexFormat::Float2:
+ return tint::ast::transform::VertexFormat::kVec2F32;
+ case wgpu::VertexFormat::Float3:
+ return tint::ast::transform::VertexFormat::kVec3F32;
+ case wgpu::VertexFormat::Float4:
+ return tint::ast::transform::VertexFormat::kVec4F32;
+ case wgpu::VertexFormat::UInt:
+ return tint::ast::transform::VertexFormat::kU32;
+ case wgpu::VertexFormat::UInt2:
+ return tint::ast::transform::VertexFormat::kVec2U32;
+ case wgpu::VertexFormat::UInt3:
+ return tint::ast::transform::VertexFormat::kVec3U32;
+ case wgpu::VertexFormat::UInt4:
+ return tint::ast::transform::VertexFormat::kVec4U32;
+ case wgpu::VertexFormat::Int:
+ return tint::ast::transform::VertexFormat::kI32;
+ case wgpu::VertexFormat::Int2:
+ return tint::ast::transform::VertexFormat::kVec2I32;
+ case wgpu::VertexFormat::Int3:
+ return tint::ast::transform::VertexFormat::kVec3I32;
+ case wgpu::VertexFormat::Int4:
+ return tint::ast::transform::VertexFormat::kVec4I32;
+ }
+ }
+
+ tint::ast::transform::InputStepMode ToTintInputStepMode(wgpu::InputStepMode mode) {
+ switch (mode) {
+ case wgpu::InputStepMode::Vertex:
+ return tint::ast::transform::InputStepMode::kVertex;
+ case wgpu::InputStepMode::Instance:
+ return tint::ast::transform::InputStepMode::kInstance;
+ }
+ }
+#endif
} // anonymous namespace
MaybeError ValidateSpirv(DeviceBase*, const uint32_t* code, uint32_t codeSize) {
@@ -400,6 +476,75 @@
std::vector<uint32_t> spirv = generator.result();
return std::move(spirv);
}
+
+ ResultOrError<std::vector<uint32_t>> ConvertWGSLToSPIRVWithPulling(
+ const char* source,
+ const VertexStateDescriptor& vertexState,
+ const std::string& entryPoint,
+ uint32_t pullingBufferBindingSet) {
+ std::ostringstream errorStream;
+ errorStream << "Tint WGSL->SPIR-V failure:" << std::endl;
+
+ tint::Context context;
+ tint::reader::wgsl::Parser parser(&context, source);
+
+ // TODO: This is a duplicate parse with ValidateWGSL, need to store
+ // state between calls to avoid this.
+ if (!parser.Parse()) {
+ errorStream << "Parser: " << parser.error() << std::endl;
+ return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
+ }
+
+ tint::ast::Module module = parser.module();
+ if (!module.IsValid()) {
+ errorStream << "Invalid module generated..." << std::endl;
+ return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
+ }
+
+ tint::ast::transform::VertexPullingTransform transform(&context, &module);
+ auto state = std::make_unique<tint::ast::transform::VertexStateDescriptor>();
+ for (uint32_t i = 0; i < vertexState.vertexBufferCount; ++i) {
+ auto& vertexBuffer = vertexState.vertexBuffers[i];
+ tint::ast::transform::VertexBufferLayoutDescriptor layout;
+ layout.array_stride = vertexBuffer.arrayStride;
+ layout.step_mode = ToTintInputStepMode(vertexBuffer.stepMode);
+
+ for (uint32_t j = 0; j < vertexBuffer.attributeCount; ++j) {
+ auto& attribute = vertexBuffer.attributes[j];
+ tint::ast::transform::VertexAttributeDescriptor attr;
+ attr.format = ToTintVertexFormat(attribute.format);
+ attr.offset = attribute.offset;
+ attr.shader_location = attribute.shaderLocation;
+
+ layout.attributes.push_back(std::move(attr));
+ }
+
+ state->vertex_buffers.push_back(std::move(layout));
+ }
+ transform.SetVertexState(std::move(state));
+ transform.SetEntryPoint(entryPoint);
+ transform.SetPullingBufferBindingSet(pullingBufferBindingSet);
+
+ if (!transform.Run()) {
+ errorStream << "Vertex pulling transform: " << transform.GetError();
+ return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
+ }
+
+ tint::TypeDeterminer type_determiner(&context, &module);
+ if (!type_determiner.Determine()) {
+ errorStream << "Type Determination: " << type_determiner.error();
+ return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
+ }
+
+ tint::writer::spirv::Generator generator(std::move(module));
+ if (!generator.Generate()) {
+ errorStream << "Generator: " << generator.error() << std::endl;
+ return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
+ }
+
+ std::vector<uint32_t> spirv = generator.result();
+ return std::move(spirv);
+ }
#endif // DAWN_ENABLE_WGSL
MaybeError ValidateShaderModuleDescriptor(DeviceBase* device,
@@ -1094,10 +1239,22 @@
return mSpirv;
}
+#ifdef DAWN_ENABLE_WGSL
+ ResultOrError<std::vector<uint32_t>> ShaderModuleBase::GeneratePullingSpirv(
+ const VertexStateDescriptor& vertexState,
+ const std::string& entryPoint,
+ uint32_t pullingBufferBindingSet) const {
+ return ConvertWGSLToSPIRVWithPulling(mWgsl.c_str(), vertexState, entryPoint,
+ pullingBufferBindingSet);
+ }
+#endif
+
shaderc_spvc::CompileOptions ShaderModuleBase::GetCompileOptions() const {
shaderc_spvc::CompileOptions options;
options.SetValidate(GetDevice()->IsValidationEnabled());
options.SetRobustBufferAccessPass(GetDevice()->IsRobustnessEnabled());
+ options.SetSourceEnvironment(shaderc_target_env_vulkan, shaderc_env_version_vulkan_1_1);
+ options.SetTargetEnvironment(shaderc_target_env_vulkan, shaderc_env_version_vulkan_1_1);
return options;
}
diff --git a/src/dawn_native/ShaderModule.h b/src/dawn_native/ShaderModule.h
index 4e771aa..336551e 100644
--- a/src/dawn_native/ShaderModule.h
+++ b/src/dawn_native/ShaderModule.h
@@ -91,6 +91,13 @@
shaderc_spvc::Context* GetContext();
const std::vector<uint32_t>& GetSpirv() const;
+#ifdef DAWN_ENABLE_WGSL
+ ResultOrError<std::vector<uint32_t>> GeneratePullingSpirv(
+ const VertexStateDescriptor& vertexState,
+ const std::string& entryPoint,
+ uint32_t pullingBufferBindingSet) const;
+#endif
+
protected:
static MaybeError CheckSpvcSuccess(shaderc_spvc_status status, const char* error_msg);
shaderc_spvc::CompileOptions GetCompileOptions() const;
diff --git a/src/dawn_native/Toggles.cpp b/src/dawn_native/Toggles.cpp
index 7f40ebd..d54dc09 100644
--- a/src/dawn_native/Toggles.cpp
+++ b/src/dawn_native/Toggles.cpp
@@ -138,7 +138,11 @@
"Clear buffers on their first use. This is a temporary toggle only for the "
"development of buffer lazy initialization and will be removed after buffer lazy "
"initialization is completely implemented.",
- "https://crbug.com/dawn/414"}}}};
+ "https://crbug.com/dawn/414"}},
+ {Toggle::MetalEnableVertexPulling,
+ {"metal_enable_vertex_pulling",
+ "Uses vertex pulling to protect out-of-bounds reads on Metal",
+ "https://crbug.com/dawn/480"}}}};
} // anonymous namespace
diff --git a/src/dawn_native/Toggles.h b/src/dawn_native/Toggles.h
index f9d66a2..5186cb5 100644
--- a/src/dawn_native/Toggles.h
+++ b/src/dawn_native/Toggles.h
@@ -44,6 +44,7 @@
UseDXC,
DisableRobustness,
LazyClearBufferOnFirstUse,
+ MetalEnableVertexPulling,
EnumCount,
InvalidEnum = EnumCount,
diff --git a/src/dawn_native/metal/CommandBufferMTL.mm b/src/dawn_native/metal/CommandBufferMTL.mm
index 7390b92..351090aa 100644
--- a/src/dawn_native/metal/CommandBufferMTL.mm
+++ b/src/dawn_native/metal/CommandBufferMTL.mm
@@ -263,7 +263,9 @@
// MSL code generated by SPIRV-Cross expects.
PerStage<std::array<uint32_t, kGenericMetalBufferSlots>> data;
- void Apply(id<MTLRenderCommandEncoder> render, RenderPipeline* pipeline) {
+ void Apply(id<MTLRenderCommandEncoder> render,
+ RenderPipeline* pipeline,
+ bool enableVertexPulling) {
wgpu::ShaderStage stagesToApply =
dirtyStages & pipeline->GetStagesRequiringStorageBufferLength();
@@ -274,6 +276,11 @@
if (stagesToApply & wgpu::ShaderStage::Vertex) {
uint32_t bufferCount = ToBackend(pipeline->GetLayout())
->GetBufferBindingCount(SingleShaderStage::Vertex);
+
+ if (enableVertexPulling) {
+ bufferCount += pipeline->GetVertexStateDescriptor()->vertexBufferCount;
+ }
+
[render setVertexBytes:data[SingleShaderStage::Vertex].data()
length:sizeof(uint32_t) * bufferCount
atIndex:kBufferLengthBufferSlot];
@@ -483,10 +490,17 @@
// all the relevant state.
class VertexBufferTracker {
public:
+ explicit VertexBufferTracker(StorageBufferLengthTracker* lengthTracker)
+ : mLengthTracker(lengthTracker) {
+ }
+
void OnSetVertexBuffer(uint32_t slot, Buffer* buffer, uint64_t offset) {
mVertexBuffers[slot] = buffer->GetMTLBuffer();
mVertexBufferOffsets[slot] = offset;
+ ASSERT(buffer->GetSize() < std::numeric_limits<uint32_t>::max());
+ mVertexBufferBindingSizes[slot] = static_cast<uint32_t>(buffer->GetSize() - offset);
+
// Use 64 bit masks and make sure there are no shift UB
static_assert(kMaxVertexBuffers <= 8 * sizeof(unsigned long long) - 1, "");
mDirtyVertexBuffers |= 1ull << slot;
@@ -499,13 +513,22 @@
mDirtyVertexBuffers |= pipeline->GetVertexBufferSlotsUsed();
}
- void Apply(id<MTLRenderCommandEncoder> encoder, RenderPipeline* pipeline) {
+ void Apply(id<MTLRenderCommandEncoder> encoder,
+ RenderPipeline* pipeline,
+ bool enableVertexPulling) {
std::bitset<kMaxVertexBuffers> vertexBuffersToApply =
mDirtyVertexBuffers & pipeline->GetVertexBufferSlotsUsed();
for (uint32_t dawnIndex : IterateBitSet(vertexBuffersToApply)) {
uint32_t metalIndex = pipeline->GetMtlVertexBufferIndex(dawnIndex);
+ if (enableVertexPulling) {
+ // Insert lengths for vertex buffers bound as storage buffers
+ mLengthTracker->data[SingleShaderStage::Vertex][metalIndex] =
+ mVertexBufferBindingSizes[dawnIndex];
+ mLengthTracker->dirtyStages |= wgpu::ShaderStage::Vertex;
+ }
+
[encoder setVertexBuffers:&mVertexBuffers[dawnIndex]
offsets:&mVertexBufferOffsets[dawnIndex]
withRange:NSMakeRange(metalIndex, 1)];
@@ -519,6 +542,9 @@
std::bitset<kMaxVertexBuffers> mDirtyVertexBuffers;
std::array<id<MTLBuffer>, kMaxVertexBuffers> mVertexBuffers;
std::array<NSUInteger, kMaxVertexBuffers> mVertexBufferOffsets;
+ std::array<uint32_t, kMaxVertexBuffers> mVertexBufferBindingSizes;
+
+ StorageBufferLengthTracker* mLengthTracker;
};
} // anonymous namespace
@@ -949,11 +975,12 @@
MTLRenderPassDescriptor* mtlRenderPass,
uint32_t width,
uint32_t height) {
+ bool enableVertexPulling = GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling);
RenderPipeline* lastPipeline = nullptr;
id<MTLBuffer> indexBuffer = nil;
uint32_t indexBufferBaseOffset = 0;
- VertexBufferTracker vertexBuffers;
StorageBufferLengthTracker storageBufferLengths = {};
+ VertexBufferTracker vertexBuffers(&storageBufferLengths);
BindGroupTracker bindGroups(&storageBufferLengths);
id<MTLRenderCommandEncoder> encoder = commandContext->BeginRender(mtlRenderPass);
@@ -963,9 +990,9 @@
case Command::Draw: {
DrawCmd* draw = iter->NextCommand<DrawCmd>();
- vertexBuffers.Apply(encoder, lastPipeline);
+ vertexBuffers.Apply(encoder, lastPipeline, enableVertexPulling);
bindGroups.Apply(encoder);
- storageBufferLengths.Apply(encoder, lastPipeline);
+ storageBufferLengths.Apply(encoder, lastPipeline, enableVertexPulling);
// The instance count must be non-zero, otherwise no-op
if (draw->instanceCount != 0) {
@@ -991,9 +1018,9 @@
size_t formatSize =
IndexFormatSize(lastPipeline->GetVertexStateDescriptor()->indexFormat);
- vertexBuffers.Apply(encoder, lastPipeline);
+ vertexBuffers.Apply(encoder, lastPipeline, enableVertexPulling);
bindGroups.Apply(encoder);
- storageBufferLengths.Apply(encoder, lastPipeline);
+ storageBufferLengths.Apply(encoder, lastPipeline, enableVertexPulling);
// The index and instance count must be non-zero, otherwise no-op
if (draw->indexCount != 0 && draw->instanceCount != 0) {
@@ -1025,9 +1052,9 @@
case Command::DrawIndirect: {
DrawIndirectCmd* draw = iter->NextCommand<DrawIndirectCmd>();
- vertexBuffers.Apply(encoder, lastPipeline);
+ vertexBuffers.Apply(encoder, lastPipeline, enableVertexPulling);
bindGroups.Apply(encoder);
- storageBufferLengths.Apply(encoder, lastPipeline);
+ storageBufferLengths.Apply(encoder, lastPipeline, enableVertexPulling);
Buffer* buffer = ToBackend(draw->indirectBuffer.Get());
id<MTLBuffer> indirectBuffer = buffer->GetMTLBuffer();
@@ -1040,9 +1067,9 @@
case Command::DrawIndexedIndirect: {
DrawIndirectCmd* draw = iter->NextCommand<DrawIndirectCmd>();
- vertexBuffers.Apply(encoder, lastPipeline);
+ vertexBuffers.Apply(encoder, lastPipeline, enableVertexPulling);
bindGroups.Apply(encoder);
- storageBufferLengths.Apply(encoder, lastPipeline);
+ storageBufferLengths.Apply(encoder, lastPipeline, enableVertexPulling);
Buffer* buffer = ToBackend(draw->indirectBuffer.Get());
id<MTLBuffer> indirectBuffer = buffer->GetMTLBuffer();
diff --git a/src/dawn_native/metal/DeviceMTL.mm b/src/dawn_native/metal/DeviceMTL.mm
index a09f9270..0867639 100644
--- a/src/dawn_native/metal/DeviceMTL.mm
+++ b/src/dawn_native/metal/DeviceMTL.mm
@@ -61,6 +61,11 @@
MaybeError Device::Initialize() {
InitTogglesFromDriver();
+
+ if (!IsRobustnessEnabled() || !IsToggleEnabled(Toggle::UseSpvc)) {
+ ForceSetToggle(Toggle::MetalEnableVertexPulling, false);
+ }
+
mCommandQueue = [mMtlDevice newCommandQueue];
return DeviceBase::Initialize(new Queue(this));
diff --git a/src/dawn_native/metal/RenderPipelineMTL.mm b/src/dawn_native/metal/RenderPipelineMTL.mm
index 25c7869..1e9efe1 100644
--- a/src/dawn_native/metal/RenderPipelineMTL.mm
+++ b/src/dawn_native/metal/RenderPipelineMTL.mm
@@ -329,11 +329,24 @@
MTLRenderPipelineDescriptor* descriptorMTL = [MTLRenderPipelineDescriptor new];
+ // TODO: MakeVertexDesc should be const in the future, so we don't need to call it here when
+ // vertex pulling is enabled
+ MTLVertexDescriptor* vertexDesc = MakeVertexDesc();
+ descriptorMTL.vertexDescriptor = vertexDesc;
+ [vertexDesc release];
+
+ if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling)) {
+ // Calling MakeVertexDesc first is important since it sets indices for packed bindings
+ MTLVertexDescriptor* emptyVertexDesc = [MTLVertexDescriptor new];
+ descriptorMTL.vertexDescriptor = emptyVertexDesc;
+ [emptyVertexDesc release];
+ }
+
ShaderModule* vertexModule = ToBackend(descriptor->vertexStage.module);
const char* vertexEntryPoint = descriptor->vertexStage.entryPoint;
ShaderModule::MetalFunctionData vertexData;
DAWN_TRY(vertexModule->GetFunction(vertexEntryPoint, SingleShaderStage::Vertex,
- ToBackend(GetLayout()), &vertexData));
+ ToBackend(GetLayout()), &vertexData, 0xFFFFFFFF, this));
descriptorMTL.vertexFunction = vertexData.function;
if (vertexData.needsStorageBufferLength) {
@@ -377,11 +390,6 @@
}
descriptorMTL.inputPrimitiveTopology = MTLInputPrimitiveTopology(GetPrimitiveTopology());
-
- MTLVertexDescriptor* vertexDesc = MakeVertexDesc();
- descriptorMTL.vertexDescriptor = vertexDesc;
- [vertexDesc release];
-
descriptorMTL.sampleCount = GetSampleCount();
descriptorMTL.alphaToCoverageEnabled = descriptor->alphaToCoverageEnabled;
diff --git a/src/dawn_native/metal/ShaderModuleMTL.h b/src/dawn_native/metal/ShaderModuleMTL.h
index 7460727..9193b25 100644
--- a/src/dawn_native/metal/ShaderModuleMTL.h
+++ b/src/dawn_native/metal/ShaderModuleMTL.h
@@ -29,6 +29,7 @@
class Device;
class PipelineLayout;
+ class RenderPipeline;
class ShaderModule final : public ShaderModuleBase {
public:
@@ -47,7 +48,8 @@
SingleShaderStage functionStage,
const PipelineLayout* layout,
MetalFunctionData* out,
- uint32_t sampleMask = 0xFFFFFFFF);
+ uint32_t sampleMask = 0xFFFFFFFF,
+ const RenderPipeline* renderPipeline = nullptr);
private:
ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
diff --git a/src/dawn_native/metal/ShaderModuleMTL.mm b/src/dawn_native/metal/ShaderModuleMTL.mm
index e49cc95..208612e 100644
--- a/src/dawn_native/metal/ShaderModuleMTL.mm
+++ b/src/dawn_native/metal/ShaderModuleMTL.mm
@@ -17,6 +17,7 @@
#include "dawn_native/BindGroupLayout.h"
#include "dawn_native/metal/DeviceMTL.h"
#include "dawn_native/metal/PipelineLayoutMTL.h"
+#include "dawn_native/metal/RenderPipelineMTL.h"
#include <spirv_msl.hpp>
@@ -92,10 +93,24 @@
SingleShaderStage functionStage,
const PipelineLayout* layout,
ShaderModule::MetalFunctionData* out,
- uint32_t sampleMask) {
+ uint32_t sampleMask,
+ const RenderPipeline* renderPipeline) {
ASSERT(!IsError());
ASSERT(out);
- const std::vector<uint32_t>& spirv = GetSpirv();
+ const std::vector<uint32_t>* spirv = &GetSpirv();
+
+#ifdef DAWN_ENABLE_WGSL
+ // Use set 4 since it is higher than any set users can currently access
+ static const uint32_t kPullingBufferBindingSet = 4;
+ std::vector<uint32_t> pullingSpirv;
+ if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) &&
+ functionStage == SingleShaderStage::Vertex) {
+ DAWN_TRY_ASSIGN(pullingSpirv,
+ GeneratePullingSpirv(*renderPipeline->GetVertexStateDescriptor(),
+ functionName, kPullingBufferBindingSet));
+ spirv = &pullingSpirv;
+ }
+#endif
std::unique_ptr<spirv_cross::CompilerMSL> compilerImpl;
spirv_cross::CompilerMSL* compiler;
@@ -103,7 +118,7 @@
// Initializing the compiler is needed every call, because this method uses reflection
// to mutate the compiler's IR.
DAWN_TRY(
- CheckSpvcSuccess(mSpvcContext.InitializeForMsl(spirv.data(), spirv.size(),
+ CheckSpvcSuccess(mSpvcContext.InitializeForMsl(spirv->data(), spirv->size(),
GetMSLCompileOptions(sampleMask)),
"Unable to initialize instance of spvc"));
DAWN_TRY(CheckSpvcSuccess(mSpvcContext.GetCompiler(reinterpret_cast<void**>(&compiler)),
@@ -126,7 +141,7 @@
options_msl.additional_fixed_sample_mask = sampleMask;
- compilerImpl = std::make_unique<spirv_cross::CompilerMSL>(spirv);
+ compilerImpl = std::make_unique<spirv_cross::CompilerMSL>(*spirv);
compiler = compilerImpl.get();
compiler->set_msl_options(options_msl);
}
@@ -172,6 +187,22 @@
}
}
+ // Add vertex buffers bound as storage buffers
+ if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) &&
+ functionStage == SingleShaderStage::Vertex) {
+ for (uint32_t dawnIndex : IterateBitSet(renderPipeline->GetVertexBufferSlotsUsed())) {
+ uint32_t metalIndex = renderPipeline->GetMtlVertexBufferIndex(dawnIndex);
+
+ shaderc_spvc_msl_resource_binding mslBinding;
+ mslBinding.stage = ToSpvcExecutionModel(SingleShaderStage::Vertex);
+ mslBinding.desc_set = kPullingBufferBindingSet;
+ mslBinding.binding = dawnIndex;
+ mslBinding.msl_buffer = metalIndex;
+ DAWN_TRY(CheckSpvcSuccess(mSpvcContext.AddMSLResourceBinding(mslBinding),
+ "Unable to add MSL Resource Binding"));
+ }
+ }
+
{
if (GetDevice()->IsToggleEnabled(Toggle::UseSpvc)) {
shaderc_spvc_execution_model executionModel = ToSpvcExecutionModel(functionStage);
@@ -245,6 +276,11 @@
out->needsStorageBufferLength = compiler->needs_buffer_size_buffer();
}
+ if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) &&
+ functionStage == SingleShaderStage::Vertex && GetUsedVertexAttributes().any()) {
+ out->needsStorageBufferLength = true;
+ }
+
return {};
}
diff --git a/src/tests/BUILD.gn b/src/tests/BUILD.gn
index 956f3d4..7817d8a 100644
--- a/src/tests/BUILD.gn
+++ b/src/tests/BUILD.gn
@@ -337,6 +337,10 @@
frameworks = [ "IOSurface.framework" ]
}
+ if (dawn_enable_wgsl) {
+ sources += [ "end2end/VertexBufferRobustnessTests.cpp" ]
+ }
+
if (dawn_enable_opengl) {
assert(dawn_supports_glfw_for_windowing)
}
diff --git a/src/tests/end2end/VertexBufferRobustnessTests.cpp b/src/tests/end2end/VertexBufferRobustnessTests.cpp
new file mode 100644
index 0000000..59d090b
--- /dev/null
+++ b/src/tests/end2end/VertexBufferRobustnessTests.cpp
@@ -0,0 +1,231 @@
+// Copyright 2020 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "common/Assert.h"
+#include "common/Constants.h"
+#include "common/Math.h"
+#include "tests/DawnTest.h"
+#include "utils/ComboRenderPipelineDescriptor.h"
+#include "utils/WGPUHelpers.h"
+
+// Vertex buffer robustness tests: verify that clamping is applied to vertex attributes. This
+// happens on backends where vertex pulling is enabled, such as Metal.
+
+class VertexBufferRobustnessTest : public DawnTest {
+ protected:
+ void SetUp() override {
+ DawnTest::SetUp();
+ // SPVC must be used currently, since we rely on the robustness pass in it
+ DAWN_SKIP_TEST_IF(!IsSpvcBeingUsed());
+ }
+
+ // Creates a vertex module that tests an expression with given attributes. If successful, the
+ // point drawn would be moved out of the viewport. On failure, the point is kept inside the
+ // viewport.
+ wgpu::ShaderModule CreateVertexModule(const std::string& attributes,
+ const std::string& successExpression) {
+ return utils::CreateShaderModuleFromWGSL(device, (R"(
+ entry_point vertex as "main" = vtx_main;
+
+ )" + attributes + R"(
+ [[builtin position]] var<out> Position : vec4<f32>;
+
+ fn vtx_main() -> void {
+ if ()" + successExpression + R"() {
+ # Success case, move the vertex out of the viewport
+ Position = vec4<f32>(-10.0, 0.0, 0.0, 1.0);
+ } else {
+ # Failure case, move the vertex inside the viewport
+ Position = vec4<f32>(0.0, 0.0, 0.0, 1.0);
+ }
+ return;
+ }
+ )")
+ .c_str());
+ }
+
+ // Runs the test; |expectation| is true when the success expression is expected to hold
+ void DoTest(const std::string& attributes,
+ const std::string& successExpression,
+ utils::ComboVertexStateDescriptor vertexState,
+ wgpu::Buffer vertexBuffer,
+ uint64_t bufferOffset,
+ bool expectation) {
+ wgpu::ShaderModule vsModule = CreateVertexModule(attributes, successExpression);
+ wgpu::ShaderModule fsModule = utils::CreateShaderModuleFromWGSL(device, R"(
+ entry_point fragment as "main" = frag_main;
+ [[location 0]] var<out> outColor : vec4<f32>;
+ fn frag_main() -> void {
+ outColor = vec4<f32>(1.0, 1.0, 1.0, 1.0);
+ return;
+ }
+ )");
+
+ utils::BasicRenderPass renderPass = utils::CreateBasicRenderPass(device, 1, 1);
+
+ utils::ComboRenderPipelineDescriptor descriptor(device);
+ descriptor.vertexStage.module = vsModule;
+ descriptor.cFragmentStage.module = fsModule;
+ descriptor.primitiveTopology = wgpu::PrimitiveTopology::PointList;
+ descriptor.cVertexState = std::move(vertexState);
+ descriptor.cColorStates[0].format = renderPass.colorFormat;
+ renderPass.renderPassInfo.cColorAttachments[0].clearColor = {0, 0, 0, 1};
+
+ wgpu::RenderPipeline pipeline = device.CreateRenderPipeline(&descriptor);
+
+ wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+ wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&renderPass.renderPassInfo);
+ pass.SetPipeline(pipeline);
+ pass.SetVertexBuffer(0, vertexBuffer, bufferOffset);
+ pass.Draw(1000);
+ pass.EndPass();
+
+ wgpu::CommandBuffer commands = encoder.Finish();
+ queue.Submit(1, &commands);
+
+ RGBA8 noOutput(0, 0, 0, 255);
+ RGBA8 someOutput(255, 255, 255, 255);
+ EXPECT_PIXEL_RGBA8_EQ(expectation ? noOutput : someOutput, renderPass.color, 0, 0);
+ }
+};
+
+TEST_P(VertexBufferRobustnessTest, DetectInvalidValues) {
+ utils::ComboVertexStateDescriptor vertexState;
+ vertexState.vertexBufferCount = 1;
+ vertexState.cVertexBuffers[0].arrayStride = sizeof(float);
+ vertexState.cVertexBuffers[0].attributeCount = 1;
+ vertexState.cAttributes[0].format = wgpu::VertexFormat::Float;
+ vertexState.cAttributes[0].offset = 0;
+ vertexState.cAttributes[0].shaderLocation = 0;
+
+ // Bind at an offset of 0, so we see 111.0, leading to failure
+ float kVertices[] = {111.0, 473.0, 473.0};
+ wgpu::Buffer vertexBuffer = utils::CreateBufferFromData(device, kVertices, sizeof(kVertices),
+ wgpu::BufferUsage::Vertex);
+
+ DoTest("[[location 0]] var<in> a : f32;", "a == 473.0", std::move(vertexState), vertexBuffer, 0,
+ false);
+}
+
+TEST_P(VertexBufferRobustnessTest, FloatClamp) {
+ utils::ComboVertexStateDescriptor vertexState;
+ vertexState.vertexBufferCount = 1;
+ vertexState.cVertexBuffers[0].arrayStride = sizeof(float);
+ vertexState.cVertexBuffers[0].attributeCount = 1;
+ vertexState.cAttributes[0].format = wgpu::VertexFormat::Float;
+ vertexState.cAttributes[0].offset = 0;
+ vertexState.cAttributes[0].shaderLocation = 0;
+
+ // Bind at an offset of 4, so we clamp to only values containing 473.0
+ float kVertices[] = {111.0, 473.0, 473.0};
+ wgpu::Buffer vertexBuffer = utils::CreateBufferFromData(device, kVertices, sizeof(kVertices),
+ wgpu::BufferUsage::Vertex);
+
+ DoTest("[[location 0]] var<in> a : f32;", "a == 473.0", std::move(vertexState), vertexBuffer, 4,
+ true);
+}
+
+TEST_P(VertexBufferRobustnessTest, IntClamp) {
+ utils::ComboVertexStateDescriptor vertexState;
+ vertexState.vertexBufferCount = 1;
+ vertexState.cVertexBuffers[0].arrayStride = sizeof(int32_t);
+ vertexState.cVertexBuffers[0].attributeCount = 1;
+ vertexState.cAttributes[0].format = wgpu::VertexFormat::Int;
+ vertexState.cAttributes[0].offset = 0;
+ vertexState.cAttributes[0].shaderLocation = 0;
+
+ // Bind at an offset of 4, so we clamp to only values containing 473
+ int32_t kVertices[] = {111, 473, 473};
+ wgpu::Buffer vertexBuffer = utils::CreateBufferFromData(device, kVertices, sizeof(kVertices),
+ wgpu::BufferUsage::Vertex);
+
+ DoTest("[[location 0]] var<in> a : i32;", "a == 473", std::move(vertexState), vertexBuffer, 4,
+ true);
+}
+
+TEST_P(VertexBufferRobustnessTest, UIntClamp) {
+ utils::ComboVertexStateDescriptor vertexState;
+ vertexState.vertexBufferCount = 1;
+ vertexState.cVertexBuffers[0].arrayStride = sizeof(uint32_t);
+ vertexState.cVertexBuffers[0].attributeCount = 1;
+ vertexState.cAttributes[0].format = wgpu::VertexFormat::UInt;
+ vertexState.cAttributes[0].offset = 0;
+ vertexState.cAttributes[0].shaderLocation = 0;
+
+ // Bind at an offset of 4, so we clamp to only values containing 473
+ uint32_t kVertices[] = {111, 473, 473};
+ wgpu::Buffer vertexBuffer = utils::CreateBufferFromData(device, kVertices, sizeof(kVertices),
+ wgpu::BufferUsage::Vertex);
+
+ DoTest("[[location 0]] var<in> a : u32;", "a == 473", std::move(vertexState), vertexBuffer, 4,
+ true);
+}
+
+TEST_P(VertexBufferRobustnessTest, Float2Clamp) {
+ utils::ComboVertexStateDescriptor vertexState;
+ vertexState.vertexBufferCount = 1;
+ vertexState.cVertexBuffers[0].arrayStride = sizeof(float) * 2;
+ vertexState.cVertexBuffers[0].attributeCount = 1;
+ vertexState.cAttributes[0].format = wgpu::VertexFormat::Float2;
+ vertexState.cAttributes[0].offset = 0;
+ vertexState.cAttributes[0].shaderLocation = 0;
+
+ // Bind at an offset of 8, so we clamp to only values containing 473.0
+ float kVertices[] = {111.0, 111.0, 473.0, 473.0};
+ wgpu::Buffer vertexBuffer = utils::CreateBufferFromData(device, kVertices, sizeof(kVertices),
+ wgpu::BufferUsage::Vertex);
+
+ DoTest("[[location 0]] var<in> a : vec2<f32>;", "a[0] == 473.0 && a[1] == 473.0",
+ std::move(vertexState), vertexBuffer, 8, true);
+}
+
+TEST_P(VertexBufferRobustnessTest, Float3Clamp) {
+ utils::ComboVertexStateDescriptor vertexState;
+ vertexState.vertexBufferCount = 1;
+ vertexState.cVertexBuffers[0].arrayStride = sizeof(float) * 3;
+ vertexState.cVertexBuffers[0].attributeCount = 1;
+ vertexState.cAttributes[0].format = wgpu::VertexFormat::Float3;
+ vertexState.cAttributes[0].offset = 0;
+ vertexState.cAttributes[0].shaderLocation = 0;
+
+ // Bind at an offset of 12, so we clamp to only values containing 473.0
+ float kVertices[] = {111.0, 111.0, 111.0, 473.0, 473.0, 473.0};
+ wgpu::Buffer vertexBuffer = utils::CreateBufferFromData(device, kVertices, sizeof(kVertices),
+ wgpu::BufferUsage::Vertex);
+
+ DoTest("[[location 0]] var<in> a : vec3<f32>;",
+ "a[0] == 473.0 && a[1] == 473.0 && a[2] == 473.0", std::move(vertexState), vertexBuffer,
+ 12, true);
+}
+
+TEST_P(VertexBufferRobustnessTest, Float4Clamp) {
+ utils::ComboVertexStateDescriptor vertexState;
+ vertexState.vertexBufferCount = 1;
+ vertexState.cVertexBuffers[0].arrayStride = sizeof(float) * 4;
+ vertexState.cVertexBuffers[0].attributeCount = 1;
+ vertexState.cAttributes[0].format = wgpu::VertexFormat::Float4;
+ vertexState.cAttributes[0].offset = 0;
+ vertexState.cAttributes[0].shaderLocation = 0;
+
+ // Bind at an offset of 16, so we clamp to only values containing 473.0
+ float kVertices[] = {111.0, 111.0, 111.0, 111.0, 473.0, 473.0, 473.0, 473.0};
+ wgpu::Buffer vertexBuffer = utils::CreateBufferFromData(device, kVertices, sizeof(kVertices),
+ wgpu::BufferUsage::Vertex);
+
+ DoTest("[[location 0]] var<in> a : vec4<f32>;",
+ "a[0] == 473.0 && a[1] == 473.0 && a[2] == 473.0 && a[3] == 473.0",
+ std::move(vertexState), vertexBuffer, 16, true);
+}
+
+DAWN_INSTANTIATE_TEST(VertexBufferRobustnessTest, MetalBackend({"metal_enable_vertex_pulling"}));
diff --git a/src/utils/WGPUHelpers.cpp b/src/utils/WGPUHelpers.cpp
index b9a51bb..94f34c8 100644
--- a/src/utils/WGPUHelpers.cpp
+++ b/src/utils/WGPUHelpers.cpp
@@ -144,6 +144,14 @@
return CreateShaderModuleFromResult(device, result);
}
+ wgpu::ShaderModule CreateShaderModuleFromWGSL(const wgpu::Device& device, const char* source) {
+ wgpu::ShaderModuleWGSLDescriptor wgslDesc;
+ wgslDesc.source = source;
+ wgpu::ShaderModuleDescriptor descriptor;
+ descriptor.nextInChain = &wgslDesc;
+ return device.CreateShaderModule(&descriptor);
+ }
+
std::vector<uint32_t> CompileGLSLToSpirv(SingleShaderStage stage, const char* source) {
shaderc_shader_kind kind = ShadercShaderKind(stage);
diff --git a/src/utils/WGPUHelpers.h b/src/utils/WGPUHelpers.h
index 03e618c..d0a66fd 100644
--- a/src/utils/WGPUHelpers.h
+++ b/src/utils/WGPUHelpers.h
@@ -34,6 +34,8 @@
SingleShaderStage stage,
const char* source);
wgpu::ShaderModule CreateShaderModuleFromASM(const wgpu::Device& device, const char* source);
+ wgpu::ShaderModule CreateShaderModuleFromWGSL(const wgpu::Device& device, const char* source);
+
std::vector<uint32_t> CompileGLSLToSpirv(SingleShaderStage stage, const char* source);
wgpu::Buffer CreateBufferFromData(const wgpu::Device& device,