Implement 'baseVertex' in drawIndexed() on D3D12, Metal and Vulkan

This patch adds the support of the parameter 'baseVertex' of drawIndexed
on D3D12, Metal and Vulkan back-ends.

BUG=dawn:51
TEST=dawn_end2end_tests

Change-Id: Ibd25884ad2abceaaed744d74c4ee6b0ae6b3fa1b
Reviewed-on: https://dawn-review.googlesource.com/c/3221
Commit-Queue: Jiawei Shao <jiawei.shao@intel.com>
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
diff --git a/BUILD.gn b/BUILD.gn
index 195204c..885fa0b 100644
--- a/BUILD.gn
+++ b/BUILD.gn
@@ -884,7 +884,7 @@
     "src/tests/end2end/ComputeCopyStorageBufferTests.cpp",
     "src/tests/end2end/CopyTests.cpp",
     "src/tests/end2end/DepthStencilStateTests.cpp",
-    "src/tests/end2end/DrawElementsTests.cpp",
+    "src/tests/end2end/DrawIndexedTests.cpp",
     "src/tests/end2end/FenceTests.cpp",
     "src/tests/end2end/IndexFormatTests.cpp",
     "src/tests/end2end/InputStateTests.cpp",
diff --git a/dawn.json b/dawn.json
index 017b60b..ade8a303 100644
--- a/dawn.json
+++ b/dawn.json
@@ -839,6 +839,7 @@
                     {"name": "index count", "type": "uint32_t"},
                     {"name": "instance count", "type": "uint32_t"},
                     {"name": "first index", "type": "uint32_t"},
+                    {"name": "base vertex", "type": "uint32_t"},
                     {"name": "first instance", "type": "uint32_t"}
                 ]
             },
diff --git a/examples/CppHelloTriangle.cpp b/examples/CppHelloTriangle.cpp
index 03f629c..9c2cb20 100644
--- a/examples/CppHelloTriangle.cpp
+++ b/examples/CppHelloTriangle.cpp
@@ -164,7 +164,7 @@
         pass.SetBindGroup(0, bindGroup);
         pass.SetVertexBuffers(0, 1, &vertexBuffer, vertexBufferOffsets);
         pass.SetIndexBuffer(indexBuffer, 0);
-        pass.DrawIndexed(3, 1, 0, 0);
+        pass.DrawIndexed(3, 1, 0, 0, 0);
         pass.EndPass();
     }
 
diff --git a/examples/CubeReflection.cpp b/examples/CubeReflection.cpp
index c0d66fa..5dd6e91 100644
--- a/examples/CubeReflection.cpp
+++ b/examples/CubeReflection.cpp
@@ -278,18 +278,18 @@
         pass.SetBindGroup(0, bindGroup[0]);
         pass.SetVertexBuffers(0, 1, &vertexBuffer, vertexBufferOffsets);
         pass.SetIndexBuffer(indexBuffer, 0);
-        pass.DrawIndexed(36, 1, 0, 0);
+        pass.DrawIndexed(36, 1, 0, 0, 0);
 
         pass.SetStencilReference(0x1);
         pass.SetRenderPipeline(planePipeline);
         pass.SetBindGroup(0, bindGroup[0]);
         pass.SetVertexBuffers(0, 1, &planeBuffer, vertexBufferOffsets);
-        pass.DrawIndexed(6, 1, 0, 0);
+        pass.DrawIndexed(6, 1, 0, 0, 0);
 
         pass.SetRenderPipeline(reflectionPipeline);
         pass.SetVertexBuffers(0, 1, &vertexBuffer, vertexBufferOffsets);
         pass.SetBindGroup(0, bindGroup[1]);
-        pass.DrawIndexed(36, 1, 0, 0);
+        pass.DrawIndexed(36, 1, 0, 0, 0);
 
         pass.EndPass();
     }
diff --git a/examples/glTFViewer/glTFViewer.cpp b/examples/glTFViewer/glTFViewer.cpp
index 61dd6a5..5ba0597 100644
--- a/examples/glTFViewer/glTFViewer.cpp
+++ b/examples/glTFViewer/glTFViewer.cpp
@@ -537,7 +537,7 @@
                 }
                 const auto& oIndicesBuffer = buffers.at(iIndices.bufferView);
                 pass.SetIndexBuffer(oIndicesBuffer, static_cast<uint32_t>(iIndices.byteOffset));
-                pass.DrawIndexed(static_cast<uint32_t>(iIndices.count), 1, 0, 0);
+                pass.DrawIndexed(static_cast<uint32_t>(iIndices.count), 1, 0, 0, 0);
             } else {
                 // DrawArrays
                 pass.Draw(vertexCount, 1, 0, 0);
diff --git a/src/dawn_native/Commands.h b/src/dawn_native/Commands.h
index 62a3bf6..37b2297 100644
--- a/src/dawn_native/Commands.h
+++ b/src/dawn_native/Commands.h
@@ -103,6 +103,7 @@
         uint32_t indexCount;
         uint32_t instanceCount;
         uint32_t firstIndex;
+        uint32_t baseVertex;
         uint32_t firstInstance;
     };
 
diff --git a/src/dawn_native/RenderPassEncoder.cpp b/src/dawn_native/RenderPassEncoder.cpp
index 2a8912d..e61e8b5 100644
--- a/src/dawn_native/RenderPassEncoder.cpp
+++ b/src/dawn_native/RenderPassEncoder.cpp
@@ -48,6 +48,7 @@
     void RenderPassEncoderBase::DrawIndexed(uint32_t indexCount,
                                             uint32_t instanceCount,
                                             uint32_t firstIndex,
+                                            uint32_t baseVertex,
                                             uint32_t firstInstance) {
         if (mTopLevelBuilder->ConsumedError(ValidateCanRecordCommands())) {
             return;
@@ -58,6 +59,7 @@
         draw->indexCount = indexCount;
         draw->instanceCount = instanceCount;
         draw->firstIndex = firstIndex;
+        draw->baseVertex = baseVertex;
         draw->firstInstance = firstInstance;
     }
 
diff --git a/src/dawn_native/RenderPassEncoder.h b/src/dawn_native/RenderPassEncoder.h
index d8b618f..806cdc8 100644
--- a/src/dawn_native/RenderPassEncoder.h
+++ b/src/dawn_native/RenderPassEncoder.h
@@ -37,6 +37,7 @@
         void DrawIndexed(uint32_t vertexCount,
                          uint32_t instanceCount,
                          uint32_t firstIndex,
+                         uint32_t baseVertex,
                          uint32_t firstInstance);
 
         void SetRenderPipeline(RenderPipelineBase* pipeline);
diff --git a/src/dawn_native/d3d12/CommandBufferD3D12.cpp b/src/dawn_native/d3d12/CommandBufferD3D12.cpp
index 37bce73..5db2375 100644
--- a/src/dawn_native/d3d12/CommandBufferD3D12.cpp
+++ b/src/dawn_native/d3d12/CommandBufferD3D12.cpp
@@ -550,7 +550,8 @@
                 case Command::DrawIndexed: {
                     DrawIndexedCmd* draw = mCommands.NextCommand<DrawIndexedCmd>();
                     commandList->DrawIndexedInstanced(draw->indexCount, draw->instanceCount,
-                                                      draw->firstIndex, 0, draw->firstInstance);
+                                                      draw->firstIndex, draw->baseVertex,
+                                                      draw->firstInstance);
                 } break;
 
                 case Command::SetRenderPipeline: {
diff --git a/src/dawn_native/metal/CommandBufferMTL.mm b/src/dawn_native/metal/CommandBufferMTL.mm
index 7726fcc..1b2bb8e 100644
--- a/src/dawn_native/metal/CommandBufferMTL.mm
+++ b/src/dawn_native/metal/CommandBufferMTL.mm
@@ -424,7 +424,7 @@
                                   indexBuffer:indexBuffer
                             indexBufferOffset:indexBufferBaseOffset + draw->firstIndex * formatSize
                                 instanceCount:draw->instanceCount
-                                   baseVertex:0
+                                   baseVertex:draw->baseVertex
                                  baseInstance:draw->firstInstance];
                 } break;
 
diff --git a/src/dawn_native/vulkan/CommandBufferVk.cpp b/src/dawn_native/vulkan/CommandBufferVk.cpp
index ab165b0..74e7964 100644
--- a/src/dawn_native/vulkan/CommandBufferVk.cpp
+++ b/src/dawn_native/vulkan/CommandBufferVk.cpp
@@ -332,9 +332,9 @@
                     DrawIndexedCmd* draw = mCommands.NextCommand<DrawIndexedCmd>();
 
                     descriptorSets.Flush(device, commands, VK_PIPELINE_BIND_POINT_GRAPHICS);
-                    uint32_t vertexOffset = 0;
                     device->fn.CmdDrawIndexed(commands, draw->indexCount, draw->instanceCount,
-                                              draw->firstIndex, vertexOffset, draw->firstInstance);
+                                              draw->firstIndex, draw->baseVertex,
+                                              draw->firstInstance);
                 } break;
 
                 case Command::SetBindGroup: {
diff --git a/src/tests/end2end/DrawElementsTests.cpp b/src/tests/end2end/DrawIndexedTests.cpp
similarity index 70%
rename from src/tests/end2end/DrawElementsTests.cpp
rename to src/tests/end2end/DrawIndexedTests.cpp
index 906ad7e..e618713 100644
--- a/src/tests/end2end/DrawElementsTests.cpp
+++ b/src/tests/end2end/DrawIndexedTests.cpp
@@ -19,7 +19,7 @@
 
 constexpr uint32_t kRTSize = 4;
 
-class DrawElementsTest : public DawnTest {
+class DrawIndexedTest : public DawnTest {
     protected:
         void SetUp() override {
             DawnTest::SetUp();
@@ -59,10 +59,17 @@
             pipeline = device.CreateRenderPipeline(&descriptor);
 
             vertexBuffer = utils::CreateBufferFromData<float>(device, dawn::BufferUsageBit::Vertex, {
+                // First quad: the first 3 vertices represent the bottom left triangle
                 -1.0f, -1.0f, 0.0f, 1.0f,
                  1.0f,  1.0f, 0.0f, 1.0f,
                 -1.0f,  1.0f, 0.0f, 1.0f,
-                 1.0f, -1.0f, 0.0f, 1.0f
+                 1.0f, -1.0f, 0.0f, 1.0f,
+
+                 // Second quad: the first 3 vertices represent the top right triangle
+                -1.0f, -1.0f, 0.0f, 1.0f,
+                 1.0f,  1.0f, 0.0f, 1.0f,
+                 1.0f, -1.0f, 0.0f, 1.0f,
+                -1.0f,  1.0f, 0.0f, 1.0f
             });
             indexBuffer = utils::CreateBufferFromData<uint32_t>(device, dawn::BufferUsageBit::Index, {
                 0, 1, 2, 0, 3, 1
@@ -75,7 +82,8 @@
         dawn::Buffer indexBuffer;
 
         void Test(uint32_t indexCount, uint32_t instanceCount, uint32_t firstIndex,
-                  uint32_t firstInstance, RGBA8 bottomLeftExpected, RGBA8 topRightExpected) {
+                  uint32_t baseVertex, uint32_t firstInstance, RGBA8 bottomLeftExpected,
+                  RGBA8 topRightExpected) {
             uint32_t zeroOffset = 0;
             dawn::CommandBufferBuilder builder = device.CreateCommandBufferBuilder();
             {
@@ -83,7 +91,7 @@
                 pass.SetRenderPipeline(pipeline);
                 pass.SetVertexBuffers(0, 1, &vertexBuffer, &zeroOffset);
                 pass.SetIndexBuffer(indexBuffer, 0);
-                pass.DrawIndexed(indexCount, instanceCount, firstIndex, firstInstance);
+                pass.DrawIndexed(indexCount, instanceCount, firstIndex, baseVertex, firstInstance);
                 pass.EndPass();
             }
 
@@ -95,20 +103,34 @@
         }
 };
 
-// The most basic DrawElements triangle draw.
-TEST_P(DrawElementsTest, Uint32) {
+// The most basic DrawIndexed triangle draw.
+TEST_P(DrawIndexedTest, Uint32) {
 
     RGBA8 filled(0, 255, 0, 255);
     RGBA8 notFilled(0, 0, 0, 0);
 
     // Test a draw with no indices.
-    Test(0, 0, 0, 0, notFilled, notFilled);
-    // Test a draw with only the first 3 indices (bottom left triangle)
-    Test(3, 1, 0, 0, filled, notFilled);
-    // Test a draw with only the last 3 indices (top right triangle)
-    Test(3, 1, 3, 0, notFilled, filled);
+    Test(0, 0, 0, 0, 0, notFilled, notFilled);
+    // Test a draw with only the first 3 indices of the first quad (bottom left triangle)
+    Test(3, 1, 0, 0, 0, filled, notFilled);
+    // Test a draw with only the last 3 indices of the first quad (top right triangle)
+    Test(3, 1, 3, 0, 0, notFilled, filled);
     // Test a draw with all 6 indices (both triangles).
-    Test(6, 1, 0, 0, filled, filled);
+    Test(6, 1, 0, 0, 0, filled, filled);
 }
 
-DAWN_INSTANTIATE_TEST(DrawElementsTest, D3D12Backend, MetalBackend, OpenGLBackend, VulkanBackend)
+// Test the parameter 'baseVertex' of DrawIndexed() works.
+TEST_P(DrawIndexedTest, BaseVertex) {
+    // TODO(jiawei.shao@intel.com): enable 'baseVertex' on OpenGL back-ends
+    DAWN_SKIP_TEST_IF(IsOpenGL());
+
+    RGBA8 filled(0, 255, 0, 255);
+    RGBA8 notFilled(0, 0, 0, 0);
+
+    // Test a draw with only the first 3 indices of the second quad (top right triangle)
+    Test(3, 1, 0, 4, 0, notFilled, filled);
+    // Test a draw with only the last 3 indices of the second quad (bottom left triangle)
+    Test(3, 1, 3, 4, 0, filled, notFilled);
+}
+
+DAWN_INSTANTIATE_TEST(DrawIndexedTest, D3D12Backend, MetalBackend, OpenGLBackend, VulkanBackend)
diff --git a/src/tests/end2end/IndexFormatTests.cpp b/src/tests/end2end/IndexFormatTests.cpp
index 4fd9a96..980b0da 100644
--- a/src/tests/end2end/IndexFormatTests.cpp
+++ b/src/tests/end2end/IndexFormatTests.cpp
@@ -87,7 +87,7 @@
         pass.SetRenderPipeline(pipeline);
         pass.SetVertexBuffers(0, 1, &vertexBuffer, &zeroOffset);
         pass.SetIndexBuffer(indexBuffer, 0);
-        pass.DrawIndexed(3, 1, 0, 0);
+        pass.DrawIndexed(3, 1, 0, 0, 0);
         pass.EndPass();
     }
 
@@ -118,7 +118,7 @@
         pass.SetRenderPipeline(pipeline);
         pass.SetVertexBuffers(0, 1, &vertexBuffer, &zeroOffset);
         pass.SetIndexBuffer(indexBuffer, 0);
-        pass.DrawIndexed(3, 1, 0, 0);
+        pass.DrawIndexed(3, 1, 0, 0, 0);
         pass.EndPass();
     }
 
@@ -162,7 +162,7 @@
         pass.SetRenderPipeline(pipeline);
         pass.SetVertexBuffers(0, 1, &vertexBuffer, &zeroOffset);
         pass.SetIndexBuffer(indexBuffer, 0);
-        pass.DrawIndexed(7, 1, 0, 0);
+        pass.DrawIndexed(7, 1, 0, 0, 0);
         pass.EndPass();
     }
 
@@ -196,7 +196,7 @@
         pass.SetRenderPipeline(pipeline);
         pass.SetVertexBuffers(0, 1, &vertexBuffer, &zeroOffset);
         pass.SetIndexBuffer(indexBuffer, 0);
-        pass.DrawIndexed(7, 1, 0, 0);
+        pass.DrawIndexed(7, 1, 0, 0, 0);
         pass.EndPass();
     }
 
@@ -236,7 +236,7 @@
         pass.SetVertexBuffers(0, 1, &vertexBuffer, &zeroOffset);
         pass.SetIndexBuffer(indexBuffer, 0);
         pass.SetRenderPipeline(pipeline32);
-        pass.DrawIndexed(3, 1, 0, 0);
+        pass.DrawIndexed(3, 1, 0, 0, 0);
         pass.EndPass();
     }
 
@@ -270,7 +270,7 @@
         pass.SetIndexBuffer(indexBuffer, 0);
         pass.SetRenderPipeline(pipeline);
         pass.SetVertexBuffers(0, 1, &vertexBuffer, &zeroOffset);
-        pass.DrawIndexed(3, 1, 0, 0);
+        pass.DrawIndexed(3, 1, 0, 0, 0);
         pass.EndPass();
     }