Implement primitive topology in OpenGL, Metal, and D3D12 backends
diff --git a/src/backend/d3d12/CommandBufferD3D12.cpp b/src/backend/d3d12/CommandBufferD3D12.cpp
index 7f5e6a8..c0dac98 100644
--- a/src/backend/d3d12/CommandBufferD3D12.cpp
+++ b/src/backend/d3d12/CommandBufferD3D12.cpp
@@ -399,7 +399,6 @@
{
DrawArraysCmd* draw = commands.NextCommand<DrawArraysCmd>();
- commandList->IASetPrimitiveTopology(D3D_PRIMITIVE_TOPOLOGY_TRIANGLELIST);
commandList->DrawInstanced(
draw->vertexCount,
draw->instanceCount,
@@ -413,7 +412,6 @@
{
DrawElementsCmd* draw = commands.NextCommand<DrawElementsCmd>();
- commandList->IASetPrimitiveTopology(D3D_PRIMITIVE_TOPOLOGY_TRIANGLELIST);
commandList->DrawIndexedInstanced(
draw->indexCount,
draw->instanceCount,
@@ -464,6 +462,7 @@
commandList->SetGraphicsRootSignature(layout->GetRootSignature().Get());
commandList->SetPipelineState(pipeline->GetPipelineState().Get());
+ commandList->IASetPrimitiveTopology(pipeline->GetD3D12PrimitiveTopology());
bindingTracker.SetInheritedBindGroups(commandList, lastLayout, layout);
diff --git a/src/backend/d3d12/RenderPipelineD3D12.cpp b/src/backend/d3d12/RenderPipelineD3D12.cpp
index 5dfa2fd..60ef3d1 100644
--- a/src/backend/d3d12/RenderPipelineD3D12.cpp
+++ b/src/backend/d3d12/RenderPipelineD3D12.cpp
@@ -25,8 +25,27 @@
namespace backend {
namespace d3d12 {
+ namespace {
+ D3D12_PRIMITIVE_TOPOLOGY D3D12PrimitiveTopology(nxt::PrimitiveTopology primitiveTopology) {
+ switch (primitiveTopology) {
+ case nxt::PrimitiveTopology::Point:
+ return D3D_PRIMITIVE_TOPOLOGY_POINTLIST;
+ case nxt::PrimitiveTopology::Line:
+ return D3D_PRIMITIVE_TOPOLOGY_LINELIST;
+ case nxt::PrimitiveTopology::LineStrip:
+ return D3D_PRIMITIVE_TOPOLOGY_LINESTRIP;
+ case nxt::PrimitiveTopology::Triangle:
+ return D3D_PRIMITIVE_TOPOLOGY_TRIANGLELIST;
+ case nxt::PrimitiveTopology::TriangleStrip:
+ return D3D_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP;
+ default:
+ UNREACHABLE();
+ }
+ }
+ }
+
RenderPipeline::RenderPipeline(RenderPipelineBuilder* builder)
- : RenderPipelineBase(builder) {
+ : RenderPipelineBase(builder), d3d12PrimitiveTopology(D3D12PrimitiveTopology(GetPrimitiveTopology())) {
uint32_t compileFlags = 0;
#if defined(_DEBUG)
// Enable better shader debugging with the graphics debugging tools.
@@ -132,6 +151,10 @@
ASSERT_SUCCESS(device->GetD3D12Device()->CreateGraphicsPipelineState(&descriptor, IID_PPV_ARGS(&pipelineState)));
}
+ D3D12_PRIMITIVE_TOPOLOGY RenderPipeline::GetD3D12PrimitiveTopology() const {
+ return d3d12PrimitiveTopology;
+ }
+
ComPtr<ID3D12PipelineState> RenderPipeline::GetPipelineState() {
return pipelineState;
}
diff --git a/src/backend/d3d12/RenderPipelineD3D12.h b/src/backend/d3d12/RenderPipelineD3D12.h
index 30b0dc3..81cdf7c 100644
--- a/src/backend/d3d12/RenderPipelineD3D12.h
+++ b/src/backend/d3d12/RenderPipelineD3D12.h
@@ -26,9 +26,11 @@
public:
RenderPipeline(RenderPipelineBuilder* builder);
+ D3D12_PRIMITIVE_TOPOLOGY GetD3D12PrimitiveTopology() const;
ComPtr<ID3D12PipelineState> GetPipelineState();
private:
+ D3D12_PRIMITIVE_TOPOLOGY d3d12PrimitiveTopology;
ComPtr<ID3D12PipelineState> pipelineState;
};
diff --git a/src/backend/metal/CommandBufferMTL.mm b/src/backend/metal/CommandBufferMTL.mm
index ddac071..25203cc 100644
--- a/src/backend/metal/CommandBufferMTL.mm
+++ b/src/backend/metal/CommandBufferMTL.mm
@@ -279,7 +279,7 @@
ASSERT(encoders.render);
[encoders.render
- drawPrimitives:MTLPrimitiveTypeTriangle
+ drawPrimitives:lastRenderPipeline->GetMTLPrimitiveTopology()
vertexStart:draw->firstVertex
vertexCount:draw->vertexCount
instanceCount:draw->instanceCount
@@ -293,7 +293,7 @@
ASSERT(encoders.render);
[encoders.render
- drawIndexedPrimitives:MTLPrimitiveTypeTriangle
+ drawIndexedPrimitives:lastRenderPipeline->GetMTLPrimitiveTopology()
indexCount:draw->indexCount
indexType:indexType
indexBuffer:indexBuffer
diff --git a/src/backend/metal/RenderPipelineMTL.h b/src/backend/metal/RenderPipelineMTL.h
index d2ba1f7..33f4830 100644
--- a/src/backend/metal/RenderPipelineMTL.h
+++ b/src/backend/metal/RenderPipelineMTL.h
@@ -27,9 +27,12 @@
RenderPipeline(RenderPipelineBuilder* builder);
~RenderPipeline();
+ MTLPrimitiveType GetMTLPrimitiveTopology() const;
+
void Encode(id<MTLRenderCommandEncoder> encoder);
private:
+ MTLPrimitiveType mtlPrimitiveTopology;
id<MTLRenderPipelineState> mtlRenderPipelineState = nil;
};
diff --git a/src/backend/metal/RenderPipelineMTL.mm b/src/backend/metal/RenderPipelineMTL.mm
index 15fce68..a2e2f82 100644
--- a/src/backend/metal/RenderPipelineMTL.mm
+++ b/src/backend/metal/RenderPipelineMTL.mm
@@ -23,8 +23,25 @@
namespace backend {
namespace metal {
+ namespace {
+ MTLPrimitiveType MTLPrimitiveTopology(nxt::PrimitiveTopology primitiveTopology) {
+ switch (primitiveTopology) {
+ case nxt::PrimitiveTopology::Point:
+ return MTLPrimitiveTypePoint;
+ case nxt::PrimitiveTopology::Line:
+ return MTLPrimitiveTypeLine;
+ case nxt::PrimitiveTopology::LineStrip:
+ return MTLPrimitiveTypeLineStrip;
+ case nxt::PrimitiveTopology::Triangle:
+ return MTLPrimitiveTypeTriangle;
+ case nxt::PrimitiveTopology::TriangleStrip:
+ return MTLPrimitiveTypeTriangleStrip;
+ }
+ }
+ }
+
RenderPipeline::RenderPipeline(RenderPipelineBuilder* builder)
- : RenderPipelineBase(builder) {
+ : RenderPipelineBase(builder), mtlPrimitiveTopology(MTLPrimitiveTopology(GetPrimitiveTopology())) {
auto mtlDevice = ToBackend(builder->GetDevice())->GetMTLDevice();
@@ -73,6 +90,10 @@
[mtlRenderPipelineState release];
}
+ MTLPrimitiveType RenderPipeline::GetMTLPrimitiveTopology() const {
+ return mtlPrimitiveTopology;
+ }
+
void RenderPipeline::Encode(id<MTLRenderCommandEncoder> encoder) {
[encoder setRenderPipelineState:mtlRenderPipelineState];
}
diff --git a/src/backend/opengl/CommandBufferGL.cpp b/src/backend/opengl/CommandBufferGL.cpp
index 48be807..9241b73 100644
--- a/src/backend/opengl/CommandBufferGL.cpp
+++ b/src/backend/opengl/CommandBufferGL.cpp
@@ -236,11 +236,11 @@
{
DrawArraysCmd* draw = commands.NextCommand<DrawArraysCmd>();
if (draw->firstInstance > 0) {
- glDrawArraysInstancedBaseInstance(GL_TRIANGLES,
+ glDrawArraysInstancedBaseInstance(lastRenderPipeline->GetGLPrimitiveTopology(),
draw->firstVertex, draw->vertexCount, draw->instanceCount, draw->firstInstance);
} else {
// This branch is only needed on OpenGL < 4.2
- glDrawArraysInstanced(GL_TRIANGLES,
+ glDrawArraysInstanced(lastRenderPipeline->GetGLPrimitiveTopology(),
draw->firstVertex, draw->vertexCount, draw->instanceCount);
}
}
@@ -253,13 +253,13 @@
GLenum formatType = IndexFormatType(indexBufferFormat);
if (draw->firstInstance > 0) {
- glDrawElementsInstancedBaseInstance(GL_TRIANGLES,
+ glDrawElementsInstancedBaseInstance(lastRenderPipeline->GetGLPrimitiveTopology(),
draw->indexCount, formatType,
reinterpret_cast<void*>(draw->firstIndex * formatSize + indexBufferOffset),
draw->instanceCount, draw->firstInstance);
} else {
// This branch is only needed on OpenGL < 4.2
- glDrawElementsInstanced(GL_TRIANGLES,
+ glDrawElementsInstanced(lastRenderPipeline->GetGLPrimitiveTopology(),
draw->indexCount, formatType,
reinterpret_cast<void*>(draw->firstIndex * formatSize + indexBufferOffset),
draw->instanceCount);
diff --git a/src/backend/opengl/RenderPipelineGL.cpp b/src/backend/opengl/RenderPipelineGL.cpp
index 1492b4f..8992a84 100644
--- a/src/backend/opengl/RenderPipelineGL.cpp
+++ b/src/backend/opengl/RenderPipelineGL.cpp
@@ -21,8 +21,32 @@
namespace backend {
namespace opengl {
+ namespace {
+ GLenum GLPrimitiveTopology(nxt::PrimitiveTopology primitiveTopology) {
+ switch (primitiveTopology) {
+ case nxt::PrimitiveTopology::Point:
+ return GL_POINTS;
+ case nxt::PrimitiveTopology::Line:
+ return GL_LINES;
+ case nxt::PrimitiveTopology::LineStrip:
+ return GL_LINE_STRIP;
+ case nxt::PrimitiveTopology::Triangle:
+ return GL_TRIANGLES;
+ case nxt::PrimitiveTopology::TriangleStrip:
+ return GL_TRIANGLE_STRIP;
+ default:
+ UNREACHABLE();
+ }
+ }
+ }
+
RenderPipeline::RenderPipeline(RenderPipelineBuilder* builder)
- : RenderPipelineBase(builder), PipelineGL(this, builder) {
+ : RenderPipelineBase(builder), PipelineGL(this, builder),
+ glPrimitiveTopology(GLPrimitiveTopology(GetPrimitiveTopology())) {
+ }
+
+ GLenum RenderPipeline::GetGLPrimitiveTopology() const {
+ return glPrimitiveTopology;
}
void RenderPipeline::ApplyNow(PersistentPipelineState &persistentPipelineState) {
diff --git a/src/backend/opengl/RenderPipelineGL.h b/src/backend/opengl/RenderPipelineGL.h
index 3c5edde..e16ce25 100644
--- a/src/backend/opengl/RenderPipelineGL.h
+++ b/src/backend/opengl/RenderPipelineGL.h
@@ -32,7 +32,12 @@
public:
RenderPipeline(RenderPipelineBuilder* builder);
+ GLenum GetGLPrimitiveTopology() const;
+
void ApplyNow(PersistentPipelineState &persistentPipelineState);
+
+ private:
+ GLenum glPrimitiveTopology;
};
}