Remove global device lock from CreateShaderModule/Pipeline
This allows multithreaded creation of these objects without
acquiring the global device lock. Note that creating objects
while the device is being destroyed is still not thread-safe.
Tested with TSAN on Linux and verified to be data-race free.
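
For illustration, a minimal sketch of the usage this enables, using the
public webgpu_cpp.h API (assumes a valid wgpu::Device; the helper below is
hypothetical and not part of this change):

    #include <thread>
    #include <vector>
    #include <webgpu/webgpu_cpp.h>

    // Each thread creates its own shader module; no external lock around
    // CreateShaderModule is needed after this change.
    void CreateModulesConcurrently(const wgpu::Device& device) {
        std::vector<std::thread> threads;
        for (int i = 0; i < 4; ++i) {
            threads.emplace_back([&device] {
                wgpu::ShaderModuleWGSLDescriptor wgslDesc;
                wgslDesc.code = "@compute @workgroup_size(1) fn main() {}";
                wgpu::ShaderModuleDescriptor desc;
                desc.nextInChain = &wgslDesc;
                wgpu::ShaderModule module = device.CreateShaderModule(&desc);
            });
        }
        for (std::thread& t : threads) {
            t.join();
        }
    }
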
Bug: dawn:1662
Change-Id: Ia715fbccad2392f6643919426ef7e4fc5c24166a
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/175206
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Loko Kung <lokokung@google.com>
Commit-Queue: Austin Eng <enga@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/dawn/dawn.json b/src/dawn/dawn.json
index 27cd49c..7dbbaa4 100644
--- a/src/dawn/dawn.json
+++ b/src/dawn/dawn.json
@@ -1197,6 +1197,7 @@
},
{
"name": "create compute pipeline",
+ "no autolock": true,
"returns": "compute pipeline",
"args": [
{"name": "descriptor", "type": "compute pipeline descriptor", "annotation": "const*"}
@@ -1276,6 +1277,7 @@
},
{
"name": "create render pipeline",
+ "no autolock": true,
"returns": "render pipeline",
"args": [
{"name": "descriptor", "type": "render pipeline descriptor", "annotation": "const*"}
@@ -1290,6 +1292,7 @@
},
{
"name": "create shader module",
+ "no autolock": true,
"returns": "shader module",
"args": [
{"name": "descriptor", "type": "shader module descriptor", "annotation": "const*"}
diff --git a/src/dawn/native/Device.cpp b/src/dawn/native/Device.cpp
index 304786b..d64df1d 100644
--- a/src/dawn/native/Device.cpp
+++ b/src/dawn/native/Device.cpp
@@ -965,14 +965,12 @@
Ref<ComputePipelineBase> DeviceBase::AddOrGetCachedComputePipeline(
Ref<ComputePipelineBase> computePipeline) {
- DAWN_ASSERT(IsLockedByCurrentThreadIfNeeded());
auto [pipeline, _] = mCaches->computePipelines.Insert(computePipeline.Get());
return std::move(pipeline);
}
Ref<RenderPipelineBase> DeviceBase::AddOrGetCachedRenderPipeline(
Ref<RenderPipelineBase> renderPipeline) {
- DAWN_ASSERT(IsLockedByCurrentThreadIfNeeded());
auto [pipeline, _] = mCaches->renderPipelines.Insert(renderPipeline.Get());
return std::move(pipeline);
}
@@ -1066,7 +1064,7 @@
ResultOrError<Ref<ShaderModuleBase>> DeviceBase::GetOrCreateShaderModule(
const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
ShaderModuleParseResult* parseResult,
- OwnedCompilationMessages* compilationMessages) {
+ std::unique_ptr<OwnedCompilationMessages>* compilationMessages) {
DAWN_ASSERT(parseResult != nullptr);
ShaderModuleBase blueprint(this, descriptor, ApiObjectBase::kUntrackedByDevice);
@@ -1076,18 +1074,19 @@
return GetOrCreate(
mCaches->shaderModules, &blueprint, [&]() -> ResultOrError<Ref<ShaderModuleBase>> {
+ auto* unownedMessages = compilationMessages ? compilationMessages->get() : nullptr;
if (!parseResult->HasParsedShader()) {
                // We skip the parse on creation if validation isn't enabled, which lets us quickly
                // look up in the cache without validating and parsing. We need the parsed module
// now.
DAWN_ASSERT(!IsValidationEnabled());
- DAWN_TRY(ValidateAndParseShaderModule(this, descriptor, parseResult,
- compilationMessages));
+ DAWN_TRY(
+ ValidateAndParseShaderModule(this, descriptor, parseResult, unownedMessages));
}
auto resultOrError = [&]() -> ResultOrError<Ref<ShaderModuleBase>> {
SCOPED_DAWN_HISTOGRAM_TIMER_MICROS(GetPlatform(), "CreateShaderModuleUS");
- return CreateShaderModuleImpl(descriptor, parseResult, compilationMessages);
+ return CreateShaderModuleImpl(descriptor, parseResult, unownedMessages);
}();
DAWN_HISTOGRAM_BOOLEAN(GetPlatform(), "CreateShaderModuleSuccess",
resultOrError.IsSuccess());
@@ -1095,6 +1094,11 @@
Ref<ShaderModuleBase> result;
DAWN_TRY_ASSIGN(result, std::move(resultOrError));
result->SetContentHash(blueprintHash);
+ // Inject compilation messages now, as another thread may get a cache hit and query them
+            // immediately after insertion into the cache.
+ if (compilationMessages) {
+ result->InjectCompilationMessages(std::move(*compilationMessages));
+ }
return result;
});
}
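
The injection ordering above is the key invariant: a standalone sketch of the
same populate-before-publish pattern, with a hypothetical mutex-guarded cache
(not Dawn's actual cache types):

    #include <memory>
    #include <mutex>
    #include <string>
    #include <unordered_map>

    // A module must be fully initialized, including its compilation
    // messages, before it becomes visible through the cache, because a
    // cache hit on another thread may read the messages immediately.
    struct Module {
        std::string messages;
    };

    class ModuleCache {
      public:
        std::shared_ptr<Module> GetOrCreate(const std::string& key, std::string messages) {
            std::lock_guard<std::mutex> lock(mMutex);
            auto it = mEntries.find(key);
            if (it != mEntries.end()) {
                return it->second;  // Cache hit: messages are already populated.
            }
            auto module = std::make_shared<Module>();
            module->messages = std::move(messages);  // Populate before publishing.
            mEntries.emplace(key, module);
            return module;
        }

      private:
        std::mutex mMutex;
        std::unordered_map<std::string, std::shared_ptr<Module>> mEntries;
    };
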
@@ -1171,6 +1175,7 @@
}
// Error case continues below, and will acquire the device lock for
// thread-safe error handling.
+ // TODO(dawn:1662): Make error handling thread-safe.
}
auto deviceLock(GetScopedLock());
@@ -1199,8 +1204,16 @@
TRACE_EVENT1(GetPlatform(), General, "DeviceBase::APICreateComputePipeline", "label",
utils::GetLabelForTrace(descriptor->label));
+ auto resultOrError = CreateComputePipeline(descriptor);
+ if (resultOrError.IsSuccess()) {
+ return ReturnToAPI(resultOrError.AcquireSuccess());
+ }
+
+ // Acquire the device lock for error handling.
+ // TODO(dawn:1662): Make error handling thread-safe.
+ auto deviceLock(GetScopedLock());
Ref<ComputePipelineBase> result;
- if (ConsumedError(CreateComputePipeline(descriptor), &result, InternalErrorType::Internal,
+ if (ConsumedError(std::move(resultOrError), &result, InternalErrorType::Internal,
"calling %s.CreateComputePipeline(%s).", this, descriptor)) {
result = ComputePipelineBase::MakeError(this, descriptor ? descriptor->label : nullptr);
}
@@ -1322,8 +1335,16 @@
TRACE_EVENT1(GetPlatform(), General, "DeviceBase::APICreateRenderPipeline", "label",
utils::GetLabelForTrace(descriptor->label));
+ auto resultOrError = CreateRenderPipeline(descriptor);
+ if (resultOrError.IsSuccess()) {
+ return ReturnToAPI(resultOrError.AcquireSuccess());
+ }
+
+ // Acquire the device lock for error handling.
+ // TODO(dawn:1662): Make error handling thread-safe.
+ auto deviceLock(GetScopedLock());
Ref<RenderPipelineBase> result;
- if (ConsumedError(CreateRenderPipeline(descriptor), &result, InternalErrorType::Internal,
+ if (ConsumedError(std::move(resultOrError), &result, InternalErrorType::Internal,
"calling %s.CreateRenderPipeline(%s).", this, descriptor)) {
result = RenderPipelineBase::MakeError(this, descriptor ? descriptor->label : nullptr);
}
@@ -1333,20 +1354,28 @@
TRACE_EVENT1(GetPlatform(), General, "DeviceBase::APICreateShaderModule", "label",
utils::GetLabelForTrace(descriptor->label));
- Ref<ShaderModuleBase> result;
std::unique_ptr<OwnedCompilationMessages> compilationMessages(
std::make_unique<OwnedCompilationMessages>());
- if (ConsumedError(CreateShaderModule(descriptor, compilationMessages.get()), &result,
- "calling %s.CreateShaderModule(%s).", this, descriptor)) {
+ auto resultOrError = CreateShaderModule(descriptor, &compilationMessages);
+ if (resultOrError.IsSuccess()) {
+ Ref<ShaderModuleBase> result = resultOrError.AcquireSuccess();
+ EmitCompilationLog(result.Get());
+ return ReturnToAPI(std::move(result));
+ }
+
+ // Acquire the device lock for error handling.
+ auto deviceLock(GetScopedLock());
+ Ref<ShaderModuleBase> result;
+ if (ConsumedError(std::move(resultOrError), &result, "calling %s.CreateShaderModule(%s).", this,
+ descriptor)) {
DAWN_ASSERT(result == nullptr);
result = ShaderModuleBase::MakeError(this, descriptor ? descriptor->label : nullptr);
+ // Emit Tint errors and warnings for the error shader module.
+ // Also move the compilation messages to the shader module so the application can later
+        // retrieve them with GetCompilationInfo.
+ result->InjectCompilationMessages(std::move(compilationMessages));
}
- // Emit Tint errors and warnings after all operations are finished even if any of them is a
- // failure and result in an error shader module. Also move the compilation messages to the
- // shader module so the application can later retrieve it with GetCompilationInfo.
- result->InjectCompilationMessages(std::move(compilationMessages));
EmitCompilationLog(result.Get());
-
return ReturnToAPI(std::move(result));
}
@@ -1663,12 +1692,15 @@
    // Limit the number of compilation errors emitted to avoid spamming the devtools console.
constexpr uint32_t kCompilationLogSpamLimit = 20;
- if (mEmittedCompilationLogCount > kCompilationLogSpamLimit) {
+ if (mEmittedCompilationLogCount.load(std::memory_order_acquire) > kCompilationLogSpamLimit) {
return;
}
- mEmittedCompilationLogCount++;
- if (mEmittedCompilationLogCount == kCompilationLogSpamLimit) {
+ if (mEmittedCompilationLogCount.fetch_add(1, std::memory_order_acq_rel) ==
+ kCompilationLogSpamLimit - 1) {
+        // Note: with multiple threads emitting logs, this warning may not be the
+        // exact last message, but it will still be emitted near the limit, which
+        // is acceptable.
return EmitLog(WGPULoggingType_Warning,
"Reached the WGSL compilation log warning limit. To see all the compilation "
"logs, query them directly on the ShaderModule objects.");
@@ -2035,7 +2067,7 @@
ResultOrError<Ref<ShaderModuleBase>> DeviceBase::CreateShaderModule(
const ShaderModuleDescriptor* descriptor,
- OwnedCompilationMessages* compilationMessages) {
+ std::unique_ptr<OwnedCompilationMessages>* compilationMessages) {
DAWN_TRY(ValidateIsAlive());
// CreateShaderModule can be called from inside dawn_native. If that's the case handle the
@@ -2047,9 +2079,10 @@
if (IsValidationEnabled()) {
DAWN_TRY_ASSIGN_CONTEXT(unpacked, ValidateAndUnpack(descriptor),
"validating and unpacking %s", descriptor);
- DAWN_TRY_CONTEXT(
- ValidateAndParseShaderModule(this, unpacked, &parseResult, compilationMessages),
- "validating %s", descriptor);
+ DAWN_TRY_CONTEXT(ValidateAndParseShaderModule(
+ this, unpacked, &parseResult,
+ compilationMessages ? compilationMessages->get() : nullptr),
+ "validating %s", descriptor);
} else {
unpacked = Unpack(descriptor);
}
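
The compilation-log limit above is guarded by an atomic counter instead of the
device lock; a self-contained sketch of the same rate-limiting pattern follows
(illustrative names, not Dawn code):

    #include <atomic>
    #include <cstdint>
    #include <cstdio>

    constexpr uint32_t kLimit = 20;
    std::atomic<uint32_t> gEmittedCount{0};

    // Allow roughly kLimit messages across all threads, then emit a final
    // notice exactly once: only one thread observes the fetch_add that
    // returns kLimit - 1.
    void EmitLimited(const char* message) {
        if (gEmittedCount.load(std::memory_order_acquire) > kLimit) {
            return;  // Limit already exceeded; drop the message.
        }
        if (gEmittedCount.fetch_add(1, std::memory_order_acq_rel) == kLimit - 1) {
            std::puts("Message limit reached; further messages suppressed.");
            return;
        }
        std::puts(message);
    }
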
diff --git a/src/dawn/native/Device.h b/src/dawn/native/Device.h
index 986dabd..95c8dcd 100644
--- a/src/dawn/native/Device.h
+++ b/src/dawn/native/Device.h
@@ -225,7 +225,7 @@
ResultOrError<Ref<ShaderModuleBase>> GetOrCreateShaderModule(
const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
ShaderModuleParseResult* parseResult,
- OwnedCompilationMessages* compilationMessages);
+ std::unique_ptr<OwnedCompilationMessages>* compilationMessages);
Ref<AttachmentState> GetOrCreateAttachmentState(AttachmentState* blueprint);
Ref<AttachmentState> GetOrCreateAttachmentState(
@@ -265,7 +265,7 @@
ResultOrError<Ref<SamplerBase>> CreateSampler(const SamplerDescriptor* descriptor = nullptr);
ResultOrError<Ref<ShaderModuleBase>> CreateShaderModule(
const ShaderModuleDescriptor* descriptor,
- OwnedCompilationMessages* compilationMessages = nullptr);
+ std::unique_ptr<OwnedCompilationMessages>* compilationMessages = nullptr);
ResultOrError<Ref<SwapChainBase>> CreateSwapChain(Surface* surface,
const SwapChainDescriptor* descriptor);
ResultOrError<Ref<TextureBase>> CreateTexture(const TextureDescriptor* rawDescriptor);
@@ -611,7 +611,7 @@
struct DeprecationWarnings;
std::unique_ptr<DeprecationWarnings> mDeprecationWarnings;
- uint32_t mEmittedCompilationLogCount = 0;
+ std::atomic<uint32_t> mEmittedCompilationLogCount = 0;
absl::flat_hash_set<std::string> mWarnings;
diff --git a/src/dawn/tests/end2end/MultithreadTests.cpp b/src/dawn/tests/end2end/MultithreadTests.cpp
index d1ddeef..cd73a02 100644
--- a/src/dawn/tests/end2end/MultithreadTests.cpp
+++ b/src/dawn/tests/end2end/MultithreadTests.cpp
@@ -236,6 +236,45 @@
});
}
+// Test CreateShaderModule on multiple threads. Cache hits should share compilation warnings.
+TEST_P(MultithreadTests, CreateShaderModuleInParallel) {
+ constexpr uint32_t kCacheHitFactor = 4; // 4 threads will create the same shader module.
+
+ std::vector<std::string> shaderSources(10);
+ std::vector<wgpu::ShaderModule> shaderModules(shaderSources.size() * kCacheHitFactor);
+
+ std::string shader = R"(@fragment
+ fn main(@location(0) x : f32) {
+ return;
+ return;
+ };)";
+ for (uint32_t i = 0; i < shaderSources.size(); ++i) {
+ // Insert newlines to make the shaders unique.
+ shader = "\n" + shader;
+ shaderSources[i] = shader;
+ }
+
+ // Create shader modules in parallel.
+ utils::RunInParallel(static_cast<uint32_t>(shaderModules.size()), [&](uint32_t index) {
+ uint32_t sourceIndex = index / kCacheHitFactor;
+ shaderModules[index] =
+ utils::CreateShaderModule(device, shaderSources[sourceIndex].c_str());
+ });
+
+ // Check that the compilation info is correct for every shader module.
+ for (uint32_t index = 0; index < shaderModules.size(); ++index) {
+ uint32_t sourceIndex = index / kCacheHitFactor;
+ shaderModules[index].GetCompilationInfo(
+ [](WGPUCompilationInfoRequestStatus, const WGPUCompilationInfo* info, void* userdata) {
+ for (size_t i = 0; i < info->messageCount; ++i) {
+ EXPECT_THAT(info->messages[i].message, testing::HasSubstr("unreachable"));
+ EXPECT_EQ(info->messages[i].lineNum, 5u + *static_cast<uint32_t*>(userdata));
+ }
+ },
+ &sourceIndex);
+ }
+}
+
// Test CreateComputePipelineAsync on multiple threads.
TEST_P(MultithreadTests, CreateComputePipelineAsyncInParallel) {
// TODO(crbug.com/dawn/1766): TSAN reported race conditions in NVIDIA's vk driver.
@@ -324,6 +363,72 @@
}
}
+// Test CreateComputePipeline on multiple threads.
+TEST_P(MultithreadTests, CreateComputePipelineInParallel) {
+ // TODO(crbug.com/dawn/1766): TSAN reported race conditions in NVIDIA's vk driver.
+ DAWN_SUPPRESS_TEST_IF(IsVulkan() && IsNvidia() && IsTsan());
+
+ std::vector<wgpu::ComputePipeline> pipelines(10);
+ std::vector<std::string> shaderSources(pipelines.size());
+ std::vector<uint32_t> expectedValues(shaderSources.size());
+
+ for (uint32_t i = 0; i < pipelines.size(); ++i) {
+ expectedValues[i] = i + 1;
+
+ std::ostringstream ss;
+ ss << R"(
+ struct SSBO {
+ value : u32
+ }
+ @group(0) @binding(0) var<storage, read_write> ssbo : SSBO;
+
+ @compute @workgroup_size(1) fn main() {
+ ssbo.value =
+ )";
+ ss << expectedValues[i];
+ ss << ";}";
+
+ shaderSources[i] = ss.str();
+ }
+
+    // Create pipelines in parallel.
+ utils::RunInParallel(static_cast<uint32_t>(pipelines.size()), [&](uint32_t index) {
+ wgpu::ComputePipelineDescriptor csDesc;
+ csDesc.compute.module = utils::CreateShaderModule(device, shaderSources[index].c_str());
+ pipelines[index] = device.CreateComputePipeline(&csDesc);
+ });
+
+    // Verify each pipeline executes correctly.
+ for (uint32_t i = 0; i < pipelines.size(); ++i) {
+ wgpu::Buffer ssbo =
+ CreateBuffer(sizeof(uint32_t), wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc);
+
+ wgpu::CommandBuffer commands;
+ {
+ wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+ wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
+
+ ASSERT_NE(nullptr, pipelines[i].Get());
+ wgpu::BindGroup bindGroup =
+ utils::MakeBindGroup(device, pipelines[i].GetBindGroupLayout(0),
+ {
+ {0, ssbo, 0, sizeof(uint32_t)},
+ });
+ pass.SetBindGroup(0, bindGroup);
+ pass.SetPipeline(pipelines[i]);
+
+ pass.DispatchWorkgroups(1);
+ pass.End();
+
+ commands = encoder.Finish();
+ }
+
+ queue.Submit(1, &commands);
+
+ EXPECT_BUFFER_U32_EQ(expectedValues[i], ssbo, 0);
+ }
+}
+
// Test CreateRenderPipelineAsync on multiple threads.
TEST_P(MultithreadTests, CreateRenderPipelineAsyncInParallel) {
// TODO(crbug.com/dawn/1766): TSAN reported race conditions in NVIDIA's vk driver.
@@ -426,6 +531,87 @@
}
}
+// Test CreateRenderPipeline on multiple threads.
+TEST_P(MultithreadTests, CreateRenderPipelineInParallel) {
+ // TODO(crbug.com/dawn/1766): TSAN reported race conditions in NVIDIA's vk driver.
+ DAWN_SUPPRESS_TEST_IF(IsVulkan() && IsNvidia() && IsTsan());
+
+ constexpr uint32_t kNumThreads = 10;
+ constexpr wgpu::TextureFormat kRenderAttachmentFormat = wgpu::TextureFormat::RGBA8Unorm;
+ constexpr uint8_t kColorStep = 250 / kNumThreads;
+
+ std::vector<wgpu::RenderPipeline> pipelines(kNumThreads);
+ std::vector<std::string> fragmentShaderSources(kNumThreads);
+ std::vector<utils::RGBA8> minExpectedValues(kNumThreads);
+ std::vector<utils::RGBA8> maxExpectedValues(kNumThreads);
+
+ for (uint32_t i = 0; i < kNumThreads; ++i) {
+        // Due to floating-point precision, compare against a min/max range
+        // rather than exact expected values.
+ auto expectedGreen = kColorStep * i;
+ minExpectedValues[i] =
+ utils::RGBA8(0, expectedGreen == 0 ? 0 : (expectedGreen - 2), 0, 255);
+ maxExpectedValues[i] =
+ utils::RGBA8(0, expectedGreen == 255 ? 255 : (expectedGreen + 2), 0, 255);
+
+ std::ostringstream ss;
+ ss << R"(
+ @fragment fn main() -> @location(0) vec4f {
+ return vec4f(0.0,
+ )";
+ ss << expectedGreen / 255.0;
+ ss << ", 0.0, 1.0);}";
+
+ fragmentShaderSources[i] = ss.str();
+ }
+
+    // Create pipelines in parallel.
+ utils::RunInParallel(kNumThreads, [&](uint32_t index) {
+ utils::ComboRenderPipelineDescriptor renderPipelineDescriptor;
+ wgpu::ShaderModule vsModule = utils::CreateShaderModule(device, R"(
+ @vertex fn main() -> @builtin(position) vec4f {
+ return vec4f(0.0, 0.0, 0.0, 1.0);
+ })");
+ wgpu::ShaderModule fsModule =
+ utils::CreateShaderModule(device, fragmentShaderSources[index].c_str());
+ renderPipelineDescriptor.vertex.module = vsModule;
+ renderPipelineDescriptor.cFragment.module = fsModule;
+ renderPipelineDescriptor.cTargets[0].format = kRenderAttachmentFormat;
+ renderPipelineDescriptor.primitive.topology = wgpu::PrimitiveTopology::PointList;
+
+ pipelines[index] = device.CreateRenderPipeline(&renderPipelineDescriptor);
+ });
+
+    // Verify each pipeline executes correctly.
+ for (uint32_t i = 0; i < pipelines.size(); ++i) {
+ wgpu::Texture outputTexture =
+ CreateTexture(1, 1, kRenderAttachmentFormat,
+ wgpu::TextureUsage::RenderAttachment | wgpu::TextureUsage::CopySrc);
+
+ utils::ComboRenderPassDescriptor renderPassDescriptor({outputTexture.CreateView()});
+ renderPassDescriptor.cColorAttachments[0].loadOp = wgpu::LoadOp::Clear;
+ renderPassDescriptor.cColorAttachments[0].clearValue = {1.f, 0.f, 0.f, 1.f};
+
+ wgpu::CommandBuffer commands;
+ {
+ wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+ wgpu::RenderPassEncoder renderPassEncoder =
+ encoder.BeginRenderPass(&renderPassDescriptor);
+
+ ASSERT_NE(nullptr, pipelines[i].Get());
+
+ renderPassEncoder.SetPipeline(pipelines[i]);
+ renderPassEncoder.Draw(1);
+ renderPassEncoder.End();
+ commands = encoder.Finish();
+ }
+
+ queue.Submit(1, &commands);
+
+ EXPECT_PIXEL_RGBA8_BETWEEN(minExpectedValues[i], maxExpectedValues[i], outputTexture, 0, 0);
+ }
+}
+
class MultithreadCachingTests : public MultithreadTests {
protected:
wgpu::ShaderModule CreateComputeShaderModule() const {