Add missing locks to indirect draw validation & timestamp query.

Also add new multithreaded tests to verify them.

Bug: dawn:1662
Change-Id: I58ebe265edf58e0c4eb5d9337d3441a6bb972ed4
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/126781
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Quyen Le <lehoangquyen@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/dawn/native/BlitBufferToDepthStencil.cpp b/src/dawn/native/BlitBufferToDepthStencil.cpp
index 7ec11c6..f4db65e 100644
--- a/src/dawn/native/BlitBufferToDepthStencil.cpp
+++ b/src/dawn/native/BlitBufferToDepthStencil.cpp
@@ -223,6 +223,7 @@
                                  TextureBase* dataTexture,
                                  const TextureCopy& dst,
                                  const Extent3D& copyExtent) {
+    ASSERT(device->IsLockedByCurrentThreadIfNeeded());
     ASSERT(dst.texture->GetFormat().format == wgpu::TextureFormat::Depth16Unorm);
     ASSERT(dataTexture->GetFormat().format == wgpu::TextureFormat::RG8Uint);
 
@@ -297,7 +298,7 @@
         RenderPassDescriptor rpDesc = {};
         rpDesc.depthStencilAttachment = &dsAttachment;
 
-        Ref<RenderPassEncoder> pass = AcquireRef(commandEncoder->APIBeginRenderPass(&rpDesc));
+        Ref<RenderPassEncoder> pass = commandEncoder->BeginRenderPass(&rpDesc);
         // Bind the resources.
         pass->APISetBindGroup(0, bindGroup.Get());
         // Discard all fragments outside the copy region.
@@ -307,7 +308,7 @@
         pass->APISetPipeline(pipeline.Get());
         pass->APIDraw(3, 1, 0, 0);
 
-        pass->APIEnd();
+        pass->End();
     }
     return {};
 }
@@ -317,6 +318,7 @@
                            TextureBase* dataTexture,
                            const TextureCopy& dst,
                            const Extent3D& copyExtent) {
+    ASSERT(device->IsLockedByCurrentThreadIfNeeded());
     const Format& format = dst.texture->GetFormat();
     ASSERT(dst.aspect == Aspect::Stencil);
 
@@ -415,7 +417,7 @@
         RenderPassDescriptor rpDesc = {};
         rpDesc.depthStencilAttachment = &dsAttachment;
 
-        Ref<RenderPassEncoder> pass = AcquireRef(commandEncoder->APIBeginRenderPass(&rpDesc));
+        Ref<RenderPassEncoder> pass = commandEncoder->BeginRenderPass(&rpDesc);
         // Bind the resources.
         pass->APISetBindGroup(0, bindGroup.Get());
         // Discard all fragments outside the copy region.
@@ -438,7 +440,7 @@
             // since WebGPU doesn't have push constants.
             pass->APIDraw(3, 1, 0, 1u << bit);
         }
-        pass->APIEnd();
+        pass->End();
     }
     return {};
 }
diff --git a/src/dawn/native/BlitDepthToDepth.cpp b/src/dawn/native/BlitDepthToDepth.cpp
index 792690d..8293099 100644
--- a/src/dawn/native/BlitDepthToDepth.cpp
+++ b/src/dawn/native/BlitDepthToDepth.cpp
@@ -97,6 +97,7 @@
                             const TextureCopy& src,
                             const TextureCopy& dst,
                             const Extent3D& copyExtent) {
+    ASSERT(device->IsLockedByCurrentThreadIfNeeded());
     // ASSERT that the texture have depth and are not multisampled.
     ASSERT(src.texture->GetFormat().HasDepth());
     ASSERT(dst.texture->GetFormat().HasDepth());
@@ -216,11 +217,11 @@
         rpDesc.depthStencilAttachment = &dsAttachment;
 
         // Draw to perform the blit.
-        Ref<RenderPassEncoder> pass = AcquireRef(commandEncoder->APIBeginRenderPass(&rpDesc));
+        Ref<RenderPassEncoder> pass = commandEncoder->BeginRenderPass(&rpDesc);
         pass->APISetBindGroup(0, bindGroup.Get());
         pass->APISetPipeline(pipeline.Get());
         pass->APIDraw(3, 1, 0, 0);
-        pass->APIEnd();
+        pass->End();
     }
 
     return {};
diff --git a/src/dawn/native/CommandEncoder.cpp b/src/dawn/native/CommandEncoder.cpp
index 8850157..f74e0b3 100644
--- a/src/dawn/native/CommandEncoder.cpp
+++ b/src/dawn/native/CommandEncoder.cpp
@@ -1560,6 +1560,11 @@
             // Encode internal compute pipeline for timestamp query
             if (querySet->GetQueryType() == wgpu::QueryType::Timestamp &&
                 !GetDevice()->IsToggleEnabled(Toggle::DisableTimestampQueryConversion)) {
+                // The below function might create new resources. Need to lock the Device.
+                // TODO(crbug.com/dawn/1618): In future, all temp resources should be created at
+                // Command Submit time, so the locking would be removed from here at that point.
+                auto deviceLock(GetDevice()->GetScopedLock());
+
                 DAWN_TRY(EncodeTimestampsToNanosecondsConversion(
                     this, querySet, firstQuery, queryCount, destination, destinationOffset));
             }
diff --git a/src/dawn/native/CopyTextureForBrowserHelper.cpp b/src/dawn/native/CopyTextureForBrowserHelper.cpp
index f22ce31..0013ba8 100644
--- a/src/dawn/native/CopyTextureForBrowserHelper.cpp
+++ b/src/dawn/native/CopyTextureForBrowserHelper.cpp
@@ -583,7 +583,7 @@
     passEncoder->APISetViewport(destination->origin.x, destination->origin.y, copySize->width,
                                 copySize->height, 0.0, 1.0);
     passEncoder->APIDraw(3);
-    passEncoder->APIEnd();
+    passEncoder->End();
 
     // Finsh encoding.
     Ref<CommandBufferBase> commandBuffer;
diff --git a/src/dawn/native/EncodingContext.cpp b/src/dawn/native/EncodingContext.cpp
index 0682a5a..eb95466 100644
--- a/src/dawn/native/EncodingContext.cpp
+++ b/src/dawn/native/EncodingContext.cpp
@@ -133,9 +133,20 @@
         // Note: If encoding validation commands fails, no commands should be in mPendingCommands,
         //       so swap back the renderCommands to ensure that they are not leaked.
         CommandAllocator renderCommands = std::move(mPendingCommands);
-        DAWN_TRY_WITH_CLEANUP(EncodeIndirectDrawValidationCommands(
-                                  mDevice, commandEncoder, &usageTracker, &indirectDrawMetadata),
-                              { mPendingCommands = std::move(renderCommands); });
+
+        // The below function might create new resources. Device must already be locked, either
+        // via RenderPassEncoder::APIEnd() or by an internal caller of RenderPassEncoder::End().
+        // TODO(crbug.com/dawn/1618): In future, all temp resources should be created at
+        // Command Submit time, so the locking would be removed from here at that point.
+        {
+            ASSERT(mDevice->IsLockedByCurrentThreadIfNeeded());
+
+            DAWN_TRY_WITH_CLEANUP(
+                EncodeIndirectDrawValidationCommands(mDevice, commandEncoder, &usageTracker,
+                                                     &indirectDrawMetadata),
+                { mPendingCommands = std::move(renderCommands); });
+        }
+
         CommitCommands(std::move(mPendingCommands));
         CommitCommands(std::move(renderCommands));
     }
diff --git a/src/dawn/native/IndirectDrawValidationEncoder.cpp b/src/dawn/native/IndirectDrawValidationEncoder.cpp
index a94cfa5..95851b8 100644
--- a/src/dawn/native/IndirectDrawValidationEncoder.cpp
+++ b/src/dawn/native/IndirectDrawValidationEncoder.cpp
@@ -212,8 +212,6 @@
         computePipelineDescriptor.compute.module = store->renderValidationShader.Get();
         computePipelineDescriptor.compute.entryPoint = "main";
 
-        // This will create new resource so we have to lock the device.
-        auto deviceLock(device->GetScopedLock());
         DAWN_TRY_ASSIGN(store->renderValidationPipeline,
                         device->CreateComputePipeline(&computePipelineDescriptor));
     }
@@ -241,6 +239,7 @@
                                                 CommandEncoder* commandEncoder,
                                                 RenderPassResourceUsageTracker* usageTracker,
                                                 IndirectDrawMetadata* indirectDrawMetadata) {
+    ASSERT(device->IsLockedByCurrentThreadIfNeeded());
     // Since encoding validation commands may create new objects, verify that the device is alive.
     // TODO(dawn:1199): This check is obsolete if device loss causes device.destroy().
     //   - This function only happens within the context of a TryEncode which would catch the
diff --git a/src/dawn/native/QueryHelper.cpp b/src/dawn/native/QueryHelper.cpp
index 4b7cce6..925e3be 100644
--- a/src/dawn/native/QueryHelper.cpp
+++ b/src/dawn/native/QueryHelper.cpp
@@ -187,6 +187,7 @@
                                                 BufferBase* availability,
                                                 BufferBase* params) {
     DeviceBase* device = encoder->GetDevice();
+    ASSERT(device->IsLockedByCurrentThreadIfNeeded());
 
     ComputePipelineBase* pipeline;
     DAWN_TRY_ASSIGN(pipeline, GetOrCreateTimestampComputePipeline(device));
diff --git a/src/dawn/native/RenderPassEncoder.cpp b/src/dawn/native/RenderPassEncoder.cpp
index 5b3e409..aaba03f 100644
--- a/src/dawn/native/RenderPassEncoder.cpp
+++ b/src/dawn/native/RenderPassEncoder.cpp
@@ -136,6 +136,14 @@
 }
 
 void RenderPassEncoder::APIEnd() {
+    // The encoding context might create additional resources, so we need to lock the device.
+    auto deviceLock(GetDevice()->GetScopedLock());
+    End();
+}
+
+void RenderPassEncoder::End() {
+    ASSERT(GetDevice()->IsLockedByCurrentThreadIfNeeded());
+
     if (mEnded && IsValidationEnabled()) {
         GetDevice()->HandleError(DAWN_VALIDATION_ERROR("%s was already ended.", this));
         return;
diff --git a/src/dawn/native/RenderPassEncoder.h b/src/dawn/native/RenderPassEncoder.h
index 6328852..592869d 100644
--- a/src/dawn/native/RenderPassEncoder.h
+++ b/src/dawn/native/RenderPassEncoder.h
@@ -44,6 +44,8 @@
 
     ObjectType GetType() const override;
 
+    // NOTE: this will lock the device internally. To avoid deadlock when the device is already
+    // locked, use End() instead.
     void APIEnd();
     void APIEndPass();  // TODO(dawn:1286): Remove after deprecation period.
 
@@ -63,6 +65,10 @@
 
     void APIWriteTimestamp(QuerySetBase* querySet, uint32_t queryIndex);
 
+    // Internal code that already locked the device should call this method instead of
+    // APIEnd() to avoid the device being locked again.
+    void End();
+
   protected:
     RenderPassEncoder(DeviceBase* device,
                       const RenderPassDescriptor* descriptor,
diff --git a/src/dawn/tests/end2end/MultithreadTests.cpp b/src/dawn/tests/end2end/MultithreadTests.cpp
index fdd0fae..9f88a6a 100644
--- a/src/dawn/tests/end2end/MultithreadTests.cpp
+++ b/src/dawn/tests/end2end/MultithreadTests.cpp
@@ -20,11 +20,20 @@
 
 #include "dawn/common/Constants.h"
 #include "dawn/common/Math.h"
+#include "dawn/common/Mutex.h"
 #include "dawn/tests/DawnTest.h"
+#include "dawn/utils/ComboRenderPipelineDescriptor.h"
 #include "dawn/utils/TestUtils.h"
 #include "dawn/utils/TextureUtils.h"
 #include "dawn/utils/WGPUHelpers.h"
 
+#define LOCKED_CMD(CMD)                   \
+    do {                                  \
+        dawn::Mutex::AutoLock lk(&mutex); \
+        CMD;                              \
+    } while (0)
+
+namespace {
 class MultithreadTests : public DawnTest {
   protected:
     std::vector<wgpu::FeatureName> GetRequiredFeatures() override {
@@ -78,6 +87,8 @@
             thread->join();
         }
     }
+
+    dawn::Mutex mutex;
 };
 
 class MultithreadEncodingTests : public MultithreadTests {};
@@ -171,6 +182,212 @@
     }
 }
 
+class MultithreadDrawIndexedIndirectTests : public MultithreadTests {
+  protected:
+    void SetUp() override {
+        MultithreadTests::SetUp();
+
+        wgpu::ShaderModule vsModule = utils::CreateShaderModule(device, R"(
+            @vertex
+            fn main(@location(0) pos : vec4f) -> @builtin(position) vec4f {
+                return pos;
+            })");
+
+        wgpu::ShaderModule fsModule = utils::CreateShaderModule(device, R"(
+            @fragment fn main() -> @location(0) vec4f {
+                return vec4f(0.0, 1.0, 0.0, 1.0);
+            })");
+
+        utils::ComboRenderPipelineDescriptor descriptor;
+        descriptor.vertex.module = vsModule;
+        descriptor.cFragment.module = fsModule;
+        descriptor.primitive.topology = wgpu::PrimitiveTopology::TriangleStrip;
+        descriptor.primitive.stripIndexFormat = wgpu::IndexFormat::Uint32;
+        descriptor.vertex.bufferCount = 1;
+        descriptor.cBuffers[0].arrayStride = 4 * sizeof(float);
+        descriptor.cBuffers[0].attributeCount = 1;
+        descriptor.cAttributes[0].format = wgpu::VertexFormat::Float32x4;
+        descriptor.cTargets[0].format = utils::BasicRenderPass::kDefaultColorFormat;
+
+        pipeline = device.CreateRenderPipeline(&descriptor);
+
+        vertexBuffer = utils::CreateBufferFromData<float>(
+            device, wgpu::BufferUsage::Vertex,
+            {// First quad: the first 3 vertices represent the bottom left triangle
+             -1.0f, 1.0f, 0.0f, 1.0f, 1.0f, -1.0f, 0.0f, 1.0f, -1.0f, -1.0f, 0.0f, 1.0f, 1.0f, 1.0f,
+             0.0f, 1.0f,
+
+             // Second quad: the first 3 vertices represent the top right triangle
+             -1.0f, 1.0f, 0.0f, 1.0f, 1.0f, -1.0f, 0.0f, 1.0f, 1.0f, 1.0f, 0.0f, 1.0f, -1.0f, -1.0f,
+             0.0f, 1.0f});
+    }
+
+    void Test(std::initializer_list<uint32_t> bufferList,
+              uint64_t indexOffset,
+              uint64_t indirectOffset,
+              utils::RGBA8 bottomLeftExpected,
+              utils::RGBA8 topRightExpected) {
+        utils::BasicRenderPass renderPass = utils::CreateBasicRenderPass(
+            device, kRTSize, kRTSize, utils::BasicRenderPass::kDefaultColorFormat);
+        wgpu::Buffer indexBuffer =
+            CreateIndexBuffer({0, 1, 2, 0, 3, 1,
+                               // The indices below are added to test negative baseVertex
+                               0 + 4, 1 + 4, 2 + 4, 0 + 4, 3 + 4, 1 + 4});
+        TestDraw(
+            renderPass, bottomLeftExpected, topRightExpected,
+            EncodeDrawCommands(bufferList, indexBuffer, indexOffset, indirectOffset, renderPass));
+    }
+
+  private:
+    wgpu::Buffer CreateIndirectBuffer(std::initializer_list<uint32_t> indirectParamList) {
+        return utils::CreateBufferFromData<uint32_t>(
+            device, wgpu::BufferUsage::Indirect | wgpu::BufferUsage::Storage, indirectParamList);
+    }
+
+    wgpu::Buffer CreateIndexBuffer(std::initializer_list<uint32_t> indexList) {
+        return utils::CreateBufferFromData<uint32_t>(device, wgpu::BufferUsage::Index, indexList);
+    }
+
+    wgpu::CommandBuffer EncodeDrawCommands(std::initializer_list<uint32_t> bufferList,
+                                           wgpu::Buffer indexBuffer,
+                                           uint64_t indexOffset,
+                                           uint64_t indirectOffset,
+                                           const utils::BasicRenderPass& renderPass) {
+        wgpu::Buffer indirectBuffer = CreateIndirectBuffer(bufferList);
+
+        wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+        {
+            wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&renderPass.renderPassInfo);
+            pass.SetPipeline(pipeline);
+            pass.SetVertexBuffer(0, vertexBuffer);
+            pass.SetIndexBuffer(indexBuffer, wgpu::IndexFormat::Uint32, indexOffset);
+            pass.DrawIndexedIndirect(indirectBuffer, indirectOffset);
+            pass.End();
+        }
+
+        return encoder.Finish();
+    }
+
+    void TestDraw(const utils::BasicRenderPass& renderPass,
+                  utils::RGBA8 bottomLeftExpected,
+                  utils::RGBA8 topRightExpected,
+                  wgpu::CommandBuffer commands) {
+        queue.Submit(1, &commands);
+
+        LOCKED_CMD(EXPECT_PIXEL_RGBA8_EQ(bottomLeftExpected, renderPass.color, 1, 3));
+        LOCKED_CMD(EXPECT_PIXEL_RGBA8_EQ(topRightExpected, renderPass.color, 3, 1));
+    }
+
+    wgpu::RenderPipeline pipeline;
+    wgpu::Buffer vertexBuffer;
+    static constexpr uint32_t kRTSize = 4;
+};
+
+// Test indirect draws with offsets on multiple threads.
+TEST_P(MultithreadDrawIndexedIndirectTests, IndirectOffsetInParallel) {
+    // TODO(crbug.com/dawn/789): Test is failing after a roll on SwANGLE on Windows only.
+    DAWN_SUPPRESS_TEST_IF(IsANGLE() && IsWindows());
+
+    // TODO(crbug.com/dawn/1292): Some Intel OpenGL drivers don't seem to like
+    // the offsets that Tint/GLSL produces.
+    DAWN_SUPPRESS_TEST_IF(IsIntel() && IsOpenGL() && IsLinux());
+
+    utils::RGBA8 filled(0, 255, 0, 255);
+    utils::RGBA8 notFilled(0, 0, 0, 0);
+
+    RunInParallel(10, [=](uint32_t) {
+        // Test an offset draw call, with indirect buffer containing 2 calls:
+        // 1) first 3 indices of the second quad (top right triangle)
+        // 2) last 3 indices of the second quad
+
+        // Test #1 (no offset)
+        Test({3, 1, 0, 4, 0, 3, 1, 3, 4, 0}, 0, 0, notFilled, filled);
+
+        // Offset to draw #2
+        Test({3, 1, 0, 4, 0, 3, 1, 3, 4, 0}, 0, 5 * sizeof(uint32_t), filled, notFilled);
+    });
+}
+
+class TimestampExpectation : public detail::Expectation {
+  public:
+    ~TimestampExpectation() override = default;
+
+    // Expect every timestamp result to be greater than 0.
+    testing::AssertionResult Check(const void* data, size_t size) override {
+        ASSERT(size % sizeof(uint64_t) == 0);
+        const uint64_t* timestamps = static_cast<const uint64_t*>(data);
+        for (size_t i = 0; i < size / sizeof(uint64_t); i++) {
+            if (timestamps[i] == 0) {
+                return testing::AssertionFailure()
+                       << "Expected data[" << i << "] to be greater than 0." << std::endl;
+            }
+        }
+
+        return testing::AssertionSuccess();
+    }
+};
+
+class MultithreadTimestampQueryTests : public MultithreadTests {
+  protected:
+    void SetUp() override {
+        MultithreadTests::SetUp();
+
+        // Skip all tests if timestamp feature is not supported
+        DAWN_TEST_UNSUPPORTED_IF(!SupportsFeatures({wgpu::FeatureName::TimestampQuery}));
+    }
+
+    std::vector<wgpu::FeatureName> GetRequiredFeatures() override {
+        std::vector<wgpu::FeatureName> requiredFeatures = MultithreadTests::GetRequiredFeatures();
+        if (SupportsFeatures({wgpu::FeatureName::TimestampQuery})) {
+            requiredFeatures.push_back(wgpu::FeatureName::TimestampQuery);
+        }
+        return requiredFeatures;
+    }
+
+    wgpu::QuerySet CreateQuerySetForTimestamp(uint32_t queryCount) {
+        wgpu::QuerySetDescriptor descriptor;
+        descriptor.count = queryCount;
+        descriptor.type = wgpu::QueryType::Timestamp;
+        return device.CreateQuerySet(&descriptor);
+    }
+
+    wgpu::Buffer CreateResolveBuffer(uint64_t size) {
+        return CreateBuffer(size, /*usage=*/wgpu::BufferUsage::QueryResolve |
+                                      wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst);
+    }
+};
+
+// Test resolving timestamp queries on multiple threads. ResolveQuerySet() will create temp
+// resources internally, so we need to make sure that their creation is thread safe.
+TEST_P(MultithreadTimestampQueryTests, ResolveQuerySets_InParallel) {
+    constexpr uint32_t kQueryCount = 2;
+    constexpr uint32_t kNumThreads = 10;
+
+    std::vector<wgpu::QuerySet> querySets(kNumThreads);
+    std::vector<wgpu::Buffer> destinations(kNumThreads);
+
+    for (size_t i = 0; i < kNumThreads; ++i) {
+        querySets[i] = CreateQuerySetForTimestamp(kQueryCount);
+        destinations[i] = CreateResolveBuffer(kQueryCount * sizeof(uint64_t));
+    }
+
+    RunInParallel(kNumThreads, [&](uint32_t index) {
+        const auto& querySet = querySets[index];
+        const auto& destination = destinations[index];
+        wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+        encoder.WriteTimestamp(querySet, 0);
+        encoder.WriteTimestamp(querySet, 1);
+        encoder.ResolveQuerySet(querySet, 0, kQueryCount, destination, 0);
+        wgpu::CommandBuffer commands = encoder.Finish();
+        queue.Submit(1, &commands);
+
+        LOCKED_CMD(EXPECT_BUFFER(destination, 0, kQueryCount * sizeof(uint64_t),
+                                 new TimestampExpectation));
+    });
+}
+
+}  // namespace
+
 DAWN_INSTANTIATE_TEST(MultithreadEncodingTests,
                       D3D11Backend(),
                       D3D12Backend(),
@@ -178,3 +395,19 @@
                       OpenGLBackend(),
                       OpenGLESBackend(),
                       VulkanBackend());
+
+DAWN_INSTANTIATE_TEST(MultithreadDrawIndexedIndirectTests,
+                      D3D11Backend(),
+                      D3D12Backend(),
+                      MetalBackend(),
+                      OpenGLBackend(),
+                      OpenGLESBackend(),
+                      VulkanBackend());
+
+DAWN_INSTANTIATE_TEST(MultithreadTimestampQueryTests,
+                      D3D11Backend(),
+                      D3D12Backend(),
+                      MetalBackend(),
+                      OpenGLBackend(),
+                      OpenGLESBackend(),
+                      VulkanBackend());