Remove global device lock from CreateShaderModule/Pipeline

This allows multithreaded creation of these objects without
acquiring the global device lock. Note that creating objects
while the device is being destroyed is still not thread-safe.

Tested on Linux TSAN to be data-race free.

Bug: dawn:1662
Change-Id: Ia715fbccad2392f6643919426ef7e4fc5c24166a
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/175206
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Loko Kung <lokokung@google.com>
Commit-Queue: Austin Eng <enga@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/dawn/dawn.json b/src/dawn/dawn.json
index 27cd49c..7dbbaa4 100644
--- a/src/dawn/dawn.json
+++ b/src/dawn/dawn.json
@@ -1197,6 +1197,7 @@
             },
             {
                 "name": "create compute pipeline",
+                "no autolock": true,
                 "returns": "compute pipeline",
                 "args": [
                     {"name": "descriptor", "type": "compute pipeline descriptor", "annotation": "const*"}
@@ -1276,6 +1277,7 @@
             },
             {
                 "name": "create render pipeline",
+                "no autolock": true,
                 "returns": "render pipeline",
                 "args": [
                     {"name": "descriptor", "type": "render pipeline descriptor", "annotation": "const*"}
@@ -1290,6 +1292,7 @@
             },
             {
                 "name": "create shader module",
+                "no autolock": true,
                 "returns": "shader module",
                 "args": [
                     {"name": "descriptor", "type": "shader module descriptor", "annotation": "const*"}
diff --git a/src/dawn/native/Device.cpp b/src/dawn/native/Device.cpp
index 304786b..d64df1d 100644
--- a/src/dawn/native/Device.cpp
+++ b/src/dawn/native/Device.cpp
@@ -965,14 +965,12 @@
 
 Ref<ComputePipelineBase> DeviceBase::AddOrGetCachedComputePipeline(
     Ref<ComputePipelineBase> computePipeline) {
-    DAWN_ASSERT(IsLockedByCurrentThreadIfNeeded());
     auto [pipeline, _] = mCaches->computePipelines.Insert(computePipeline.Get());
     return std::move(pipeline);
 }
 
 Ref<RenderPipelineBase> DeviceBase::AddOrGetCachedRenderPipeline(
     Ref<RenderPipelineBase> renderPipeline) {
-    DAWN_ASSERT(IsLockedByCurrentThreadIfNeeded());
     auto [pipeline, _] = mCaches->renderPipelines.Insert(renderPipeline.Get());
     return std::move(pipeline);
 }
@@ -1066,7 +1064,7 @@
 ResultOrError<Ref<ShaderModuleBase>> DeviceBase::GetOrCreateShaderModule(
     const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
     ShaderModuleParseResult* parseResult,
-    OwnedCompilationMessages* compilationMessages) {
+    std::unique_ptr<OwnedCompilationMessages>* compilationMessages) {
     DAWN_ASSERT(parseResult != nullptr);
 
     ShaderModuleBase blueprint(this, descriptor, ApiObjectBase::kUntrackedByDevice);
@@ -1076,18 +1074,19 @@
 
     return GetOrCreate(
         mCaches->shaderModules, &blueprint, [&]() -> ResultOrError<Ref<ShaderModuleBase>> {
+            auto* unownedMessages = compilationMessages ? compilationMessages->get() : nullptr;
             if (!parseResult->HasParsedShader()) {
                 // We skip the parse on creation if validation isn't enabled which let's us quickly
                 // lookup in the cache without validating and parsing. We need the parsed module
                 // now.
                 DAWN_ASSERT(!IsValidationEnabled());
-                DAWN_TRY(ValidateAndParseShaderModule(this, descriptor, parseResult,
-                                                      compilationMessages));
+                DAWN_TRY(
+                    ValidateAndParseShaderModule(this, descriptor, parseResult, unownedMessages));
             }
 
             auto resultOrError = [&]() -> ResultOrError<Ref<ShaderModuleBase>> {
                 SCOPED_DAWN_HISTOGRAM_TIMER_MICROS(GetPlatform(), "CreateShaderModuleUS");
-                return CreateShaderModuleImpl(descriptor, parseResult, compilationMessages);
+                return CreateShaderModuleImpl(descriptor, parseResult, unownedMessages);
             }();
             DAWN_HISTOGRAM_BOOLEAN(GetPlatform(), "CreateShaderModuleSuccess",
                                    resultOrError.IsSuccess());
@@ -1095,6 +1094,11 @@
             Ref<ShaderModuleBase> result;
             DAWN_TRY_ASSIGN(result, std::move(resultOrError));
             result->SetContentHash(blueprintHash);
+            // Inject compilation messages now, as another thread may get a cache hit and query them
+            // immediately after insert into the cache.
+            if (compilationMessages) {
+                result->InjectCompilationMessages(std::move(*compilationMessages));
+            }
             return result;
         });
 }
@@ -1171,6 +1175,7 @@
         }
         // Error case continues below, and will acquire the device lock for
         // thread-safe error handling.
+        // TODO(dawn:1662): Make error handling thread-safe.
     }
 
     auto deviceLock(GetScopedLock());
@@ -1199,8 +1204,16 @@
     TRACE_EVENT1(GetPlatform(), General, "DeviceBase::APICreateComputePipeline", "label",
                  utils::GetLabelForTrace(descriptor->label));
 
+    auto resultOrError = CreateComputePipeline(descriptor);
+    if (resultOrError.IsSuccess()) {
+        return ReturnToAPI(resultOrError.AcquireSuccess());
+    }
+
+    // Acquire the device lock for error handling.
+    // TODO(dawn:1662): Make error handling thread-safe.
+    auto deviceLock(GetScopedLock());
     Ref<ComputePipelineBase> result;
-    if (ConsumedError(CreateComputePipeline(descriptor), &result, InternalErrorType::Internal,
+    if (ConsumedError(std::move(resultOrError), &result, InternalErrorType::Internal,
                       "calling %s.CreateComputePipeline(%s).", this, descriptor)) {
         result = ComputePipelineBase::MakeError(this, descriptor ? descriptor->label : nullptr);
     }
@@ -1322,8 +1335,16 @@
     TRACE_EVENT1(GetPlatform(), General, "DeviceBase::APICreateRenderPipeline", "label",
                  utils::GetLabelForTrace(descriptor->label));
 
+    auto resultOrError = CreateRenderPipeline(descriptor);
+    if (resultOrError.IsSuccess()) {
+        return ReturnToAPI(resultOrError.AcquireSuccess());
+    }
+
+    // Acquire the device lock for error handling.
+    // TODO(dawn:1662): Make error handling thread-safe.
+    auto deviceLock(GetScopedLock());
     Ref<RenderPipelineBase> result;
-    if (ConsumedError(CreateRenderPipeline(descriptor), &result, InternalErrorType::Internal,
+    if (ConsumedError(std::move(resultOrError), &result, InternalErrorType::Internal,
                       "calling %s.CreateRenderPipeline(%s).", this, descriptor)) {
         result = RenderPipelineBase::MakeError(this, descriptor ? descriptor->label : nullptr);
     }
@@ -1333,20 +1354,28 @@
     TRACE_EVENT1(GetPlatform(), General, "DeviceBase::APICreateShaderModule", "label",
                  utils::GetLabelForTrace(descriptor->label));
 
-    Ref<ShaderModuleBase> result;
     std::unique_ptr<OwnedCompilationMessages> compilationMessages(
         std::make_unique<OwnedCompilationMessages>());
-    if (ConsumedError(CreateShaderModule(descriptor, compilationMessages.get()), &result,
-                      "calling %s.CreateShaderModule(%s).", this, descriptor)) {
+    auto resultOrError = CreateShaderModule(descriptor, &compilationMessages);
+    if (resultOrError.IsSuccess()) {
+        Ref<ShaderModuleBase> result = resultOrError.AcquireSuccess();
+        EmitCompilationLog(result.Get());
+        return ReturnToAPI(std::move(result));
+    }
+
+    // Acquire the device lock for error handling.
+    auto deviceLock(GetScopedLock());
+    Ref<ShaderModuleBase> result;
+    if (ConsumedError(std::move(resultOrError), &result, "calling %s.CreateShaderModule(%s).", this,
+                      descriptor)) {
         DAWN_ASSERT(result == nullptr);
         result = ShaderModuleBase::MakeError(this, descriptor ? descriptor->label : nullptr);
+        // Emit Tint errors and warnings for the error shader module.
+        // Also move the compilation messages to the shader module so the application can later
+        // retrieve it with GetCompilationInfo.
+        result->InjectCompilationMessages(std::move(compilationMessages));
     }
-    // Emit Tint errors and warnings after all operations are finished even if any of them is a
-    // failure and result in an error shader module. Also move the compilation messages to the
-    // shader module so the application can later retrieve it with GetCompilationInfo.
-    result->InjectCompilationMessages(std::move(compilationMessages));
     EmitCompilationLog(result.Get());
-
     return ReturnToAPI(std::move(result));
 }
 
@@ -1663,12 +1692,15 @@
 
     // Limit the number of compilation error emitted to avoid spamming the devtools console hard.
     constexpr uint32_t kCompilationLogSpamLimit = 20;
-    if (mEmittedCompilationLogCount > kCompilationLogSpamLimit) {
+    if (mEmittedCompilationLogCount.load(std::memory_order_acquire) > kCompilationLogSpamLimit) {
         return;
     }
 
-    mEmittedCompilationLogCount++;
-    if (mEmittedCompilationLogCount == kCompilationLogSpamLimit) {
+    if (mEmittedCompilationLogCount.fetch_add(1, std::memory_order_acq_rel) ==
+        kCompilationLogSpamLimit - 1) {
+        // Note: if there are multiple threads emitting logs, this may not actually be the exact
+        // last message. This is probably not a huge problem since this message will be emitted
+        // somewhere near the end.
         return EmitLog(WGPULoggingType_Warning,
                        "Reached the WGSL compilation log warning limit. To see all the compilation "
                        "logs, query them directly on the ShaderModule objects.");
@@ -2035,7 +2067,7 @@
 
 ResultOrError<Ref<ShaderModuleBase>> DeviceBase::CreateShaderModule(
     const ShaderModuleDescriptor* descriptor,
-    OwnedCompilationMessages* compilationMessages) {
+    std::unique_ptr<OwnedCompilationMessages>* compilationMessages) {
     DAWN_TRY(ValidateIsAlive());
 
     // CreateShaderModule can be called from inside dawn_native. If that's the case handle the
@@ -2047,9 +2079,10 @@
     if (IsValidationEnabled()) {
         DAWN_TRY_ASSIGN_CONTEXT(unpacked, ValidateAndUnpack(descriptor),
                                 "validating and unpacking %s", descriptor);
-        DAWN_TRY_CONTEXT(
-            ValidateAndParseShaderModule(this, unpacked, &parseResult, compilationMessages),
-            "validating %s", descriptor);
+        DAWN_TRY_CONTEXT(ValidateAndParseShaderModule(
+                             this, unpacked, &parseResult,
+                             compilationMessages ? compilationMessages->get() : nullptr),
+                         "validating %s", descriptor);
     } else {
         unpacked = Unpack(descriptor);
     }
diff --git a/src/dawn/native/Device.h b/src/dawn/native/Device.h
index 986dabd..95c8dcd 100644
--- a/src/dawn/native/Device.h
+++ b/src/dawn/native/Device.h
@@ -225,7 +225,7 @@
     ResultOrError<Ref<ShaderModuleBase>> GetOrCreateShaderModule(
         const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
         ShaderModuleParseResult* parseResult,
-        OwnedCompilationMessages* compilationMessages);
+        std::unique_ptr<OwnedCompilationMessages>* compilationMessages);
 
     Ref<AttachmentState> GetOrCreateAttachmentState(AttachmentState* blueprint);
     Ref<AttachmentState> GetOrCreateAttachmentState(
@@ -265,7 +265,7 @@
     ResultOrError<Ref<SamplerBase>> CreateSampler(const SamplerDescriptor* descriptor = nullptr);
     ResultOrError<Ref<ShaderModuleBase>> CreateShaderModule(
         const ShaderModuleDescriptor* descriptor,
-        OwnedCompilationMessages* compilationMessages = nullptr);
+        std::unique_ptr<OwnedCompilationMessages>* compilationMessages = nullptr);
     ResultOrError<Ref<SwapChainBase>> CreateSwapChain(Surface* surface,
                                                       const SwapChainDescriptor* descriptor);
     ResultOrError<Ref<TextureBase>> CreateTexture(const TextureDescriptor* rawDescriptor);
@@ -611,7 +611,7 @@
 
     struct DeprecationWarnings;
     std::unique_ptr<DeprecationWarnings> mDeprecationWarnings;
-    uint32_t mEmittedCompilationLogCount = 0;
+    std::atomic<uint32_t> mEmittedCompilationLogCount = 0;
 
     absl::flat_hash_set<std::string> mWarnings;
 
diff --git a/src/dawn/tests/end2end/MultithreadTests.cpp b/src/dawn/tests/end2end/MultithreadTests.cpp
index d1ddeef..cd73a02 100644
--- a/src/dawn/tests/end2end/MultithreadTests.cpp
+++ b/src/dawn/tests/end2end/MultithreadTests.cpp
@@ -236,6 +236,45 @@
     });
 }
 
+// Test CreateShaderModule on multiple threads. Cache hits should share compilation warnings.
+TEST_P(MultithreadTests, CreateShaderModuleInParallel) {
+    constexpr uint32_t kCacheHitFactor = 4;  // 4 threads will create the same shader module.
+
+    std::vector<std::string> shaderSources(10);
+    std::vector<wgpu::ShaderModule> shaderModules(shaderSources.size() * kCacheHitFactor);
+
+    std::string shader = R"(@fragment
+    fn main(@location(0) x : f32) {
+        return;
+        return;
+    };)";
+    for (uint32_t i = 0; i < shaderSources.size(); ++i) {
+        // Insert newlines to make the shaders unique.
+        shader = "\n" + shader;
+        shaderSources[i] = shader;
+    }
+
+    // Create shader modules in parallel.
+    utils::RunInParallel(static_cast<uint32_t>(shaderModules.size()), [&](uint32_t index) {
+        uint32_t sourceIndex = index / kCacheHitFactor;
+        shaderModules[index] =
+            utils::CreateShaderModule(device, shaderSources[sourceIndex].c_str());
+    });
+
+    // Check that the compilation info is correct for every shader module.
+    for (uint32_t index = 0; index < shaderModules.size(); ++index) {
+        uint32_t sourceIndex = index / kCacheHitFactor;
+        shaderModules[index].GetCompilationInfo(
+            [](WGPUCompilationInfoRequestStatus, const WGPUCompilationInfo* info, void* userdata) {
+                for (size_t i = 0; i < info->messageCount; ++i) {
+                    EXPECT_THAT(info->messages[i].message, testing::HasSubstr("unreachable"));
+                    EXPECT_EQ(info->messages[i].lineNum, 5u + *static_cast<uint32_t*>(userdata));
+                }
+            },
+            &sourceIndex);
+    }
+}
+
 // Test CreateComputePipelineAsync on multiple threads.
 TEST_P(MultithreadTests, CreateComputePipelineAsyncInParallel) {
     // TODO(crbug.com/dawn/1766): TSAN reported race conditions in NVIDIA's vk driver.
@@ -324,6 +363,72 @@
     }
 }
 
+// Test CreateComputePipeline on multiple threads.
+TEST_P(MultithreadTests, CreateComputePipelineInParallel) {
+    // TODO(crbug.com/dawn/1766): TSAN reported race conditions in NVIDIA's vk driver.
+    DAWN_SUPPRESS_TEST_IF(IsVulkan() && IsNvidia() && IsTsan());
+
+    std::vector<wgpu::ComputePipeline> pipelines(10);
+    std::vector<std::string> shaderSources(pipelines.size());
+    std::vector<uint32_t> expectedValues(shaderSources.size());
+
+    for (uint32_t i = 0; i < pipelines.size(); ++i) {
+        expectedValues[i] = i + 1;
+
+        std::ostringstream ss;
+        ss << R"(
+        struct SSBO {
+            value : u32
+        }
+        @group(0) @binding(0) var<storage, read_write> ssbo : SSBO;
+
+        @compute @workgroup_size(1) fn main() {
+            ssbo.value =
+        )";
+        ss << expectedValues[i];
+        ss << ";}";
+
+        shaderSources[i] = ss.str();
+    }
+
+    // Create pipelines in parallel
+    utils::RunInParallel(static_cast<uint32_t>(pipelines.size()), [&](uint32_t index) {
+        wgpu::ComputePipelineDescriptor csDesc;
+        csDesc.compute.module = utils::CreateShaderModule(device, shaderSources[index].c_str());
+        pipelines[index] = device.CreateComputePipeline(&csDesc);
+    });
+
+    // Verify pipelines' executions
+    for (uint32_t i = 0; i < pipelines.size(); ++i) {
+        wgpu::Buffer ssbo =
+            CreateBuffer(sizeof(uint32_t), wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc);
+
+        wgpu::CommandBuffer commands;
+        {
+            wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+            wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
+
+            ASSERT_NE(nullptr, pipelines[i].Get());
+            wgpu::BindGroup bindGroup =
+                utils::MakeBindGroup(device, pipelines[i].GetBindGroupLayout(0),
+                                     {
+                                         {0, ssbo, 0, sizeof(uint32_t)},
+                                     });
+            pass.SetBindGroup(0, bindGroup);
+            pass.SetPipeline(pipelines[i]);
+
+            pass.DispatchWorkgroups(1);
+            pass.End();
+
+            commands = encoder.Finish();
+        }
+
+        queue.Submit(1, &commands);
+
+        EXPECT_BUFFER_U32_EQ(expectedValues[i], ssbo, 0);
+    }
+}
+
 // Test CreateRenderPipelineAsync on multiple threads.
 TEST_P(MultithreadTests, CreateRenderPipelineAsyncInParallel) {
     // TODO(crbug.com/dawn/1766): TSAN reported race conditions in NVIDIA's vk driver.
@@ -426,6 +531,87 @@
     }
 }
 
+// Test CreateRenderPipeline on multiple threads.
+TEST_P(MultithreadTests, CreateRenderPipelineInParallel) {
+    // TODO(crbug.com/dawn/1766): TSAN reported race conditions in NVIDIA's vk driver.
+    DAWN_SUPPRESS_TEST_IF(IsVulkan() && IsNvidia() && IsTsan());
+
+    constexpr uint32_t kNumThreads = 10;
+    constexpr wgpu::TextureFormat kRenderAttachmentFormat = wgpu::TextureFormat::RGBA8Unorm;
+    constexpr uint8_t kColorStep = 250 / kNumThreads;
+
+    std::vector<wgpu::RenderPipeline> pipelines(kNumThreads);
+    std::vector<std::string> fragmentShaderSources(kNumThreads);
+    std::vector<utils::RGBA8> minExpectedValues(kNumThreads);
+    std::vector<utils::RGBA8> maxExpectedValues(kNumThreads);
+
+    for (uint32_t i = 0; i < kNumThreads; ++i) {
+        // Due to floating point precision, we need to use min & max values to compare the
+        // expectations.
+        auto expectedGreen = kColorStep * i;
+        minExpectedValues[i] =
+            utils::RGBA8(0, expectedGreen == 0 ? 0 : (expectedGreen - 2), 0, 255);
+        maxExpectedValues[i] =
+            utils::RGBA8(0, expectedGreen == 255 ? 255 : (expectedGreen + 2), 0, 255);
+
+        std::ostringstream ss;
+        ss << R"(
+        @fragment fn main() -> @location(0) vec4f {
+            return vec4f(0.0,
+        )";
+        ss << expectedGreen / 255.0;
+        ss << ", 0.0, 1.0);}";
+
+        fragmentShaderSources[i] = ss.str();
+    }
+
+    // Create pipelines in parallel
+    utils::RunInParallel(kNumThreads, [&](uint32_t index) {
+        utils::ComboRenderPipelineDescriptor renderPipelineDescriptor;
+        wgpu::ShaderModule vsModule = utils::CreateShaderModule(device, R"(
+        @vertex fn main() -> @builtin(position) vec4f {
+            return vec4f(0.0, 0.0, 0.0, 1.0);
+        })");
+        wgpu::ShaderModule fsModule =
+            utils::CreateShaderModule(device, fragmentShaderSources[index].c_str());
+        renderPipelineDescriptor.vertex.module = vsModule;
+        renderPipelineDescriptor.cFragment.module = fsModule;
+        renderPipelineDescriptor.cTargets[0].format = kRenderAttachmentFormat;
+        renderPipelineDescriptor.primitive.topology = wgpu::PrimitiveTopology::PointList;
+
+        pipelines[index] = device.CreateRenderPipeline(&renderPipelineDescriptor);
+    });
+
+    // Verify pipelines' executions
+    for (uint32_t i = 0; i < pipelines.size(); ++i) {
+        wgpu::Texture outputTexture =
+            CreateTexture(1, 1, kRenderAttachmentFormat,
+                          wgpu::TextureUsage::RenderAttachment | wgpu::TextureUsage::CopySrc);
+
+        utils::ComboRenderPassDescriptor renderPassDescriptor({outputTexture.CreateView()});
+        renderPassDescriptor.cColorAttachments[0].loadOp = wgpu::LoadOp::Clear;
+        renderPassDescriptor.cColorAttachments[0].clearValue = {1.f, 0.f, 0.f, 1.f};
+
+        wgpu::CommandBuffer commands;
+        {
+            wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+            wgpu::RenderPassEncoder renderPassEncoder =
+                encoder.BeginRenderPass(&renderPassDescriptor);
+
+            ASSERT_NE(nullptr, pipelines[i].Get());
+
+            renderPassEncoder.SetPipeline(pipelines[i]);
+            renderPassEncoder.Draw(1);
+            renderPassEncoder.End();
+            commands = encoder.Finish();
+        }
+
+        queue.Submit(1, &commands);
+
+        EXPECT_PIXEL_RGBA8_BETWEEN(minExpectedValues[i], maxExpectedValues[i], outputTexture, 0, 0);
+    }
+}
+
 class MultithreadCachingTests : public MultithreadTests {
   protected:
     wgpu::ShaderModule CreateComputeShaderModule() const {