Vulkan: Add `emitPointSize` into `TransformedShaderModuleCacheKey`
This patch adds `emitPointSize` into `TransformedShaderModuleCacheKey`
because having `PointSize` or not should generate different SPIRV
code.
Fixed: chromium:341282611
Test: dawn_end2end_tests
Change-Id: I1395c4b51648e1c37776937a4beebc7b38565300
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/191770
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Jiawei Shao <jiawei.shao@intel.com>
diff --git a/src/dawn/native/vulkan/ShaderModuleVk.cpp b/src/dawn/native/vulkan/ShaderModuleVk.cpp
index 1e06207..5fbdd3c 100644
--- a/src/dawn/native/vulkan/ShaderModuleVk.cpp
+++ b/src/dawn/native/vulkan/ShaderModuleVk.cpp
@@ -78,13 +78,16 @@
if (maxSubgroupSizeForFullSubgroups != other.maxSubgroupSizeForFullSubgroups) {
return false;
}
+ if (emitPointSize != other.emitPointSize) {
+ return false;
+ }
return true;
}
size_t TransformedShaderModuleCacheKeyHashFunc::operator()(
const TransformedShaderModuleCacheKey& key) const {
size_t hash = 0;
- HashCombine(&hash, key.layoutPtr, key.entryPoint);
+ HashCombine(&hash, key.layoutPtr, key.entryPoint, key.emitPointSize);
for (const auto& entry : key.constants) {
HashCombine(&hash, entry.first, entry.second);
}
@@ -218,10 +221,13 @@
ScopedTintICEHandler scopedICEHandler(GetDevice());
- // Check to see if we have the handle and spirv cached already.
+ // Check to see if we have the handle and spirv cached already
+ // TODO(chromium:345359083): Improve the computation of the cache key. For example, it isn't
+ // ideal to use `reinterpret_cast<uintptr_t>(layout)` as the layout may be freed and
+ // reallocated during the runtime.
auto cacheKey = TransformedShaderModuleCacheKey{
reinterpret_cast<uintptr_t>(layout), programmableStage.entryPoint.c_str(),
- programmableStage.constants, maxSubgroupSizeForFullSubgroups};
+ programmableStage.constants, maxSubgroupSizeForFullSubgroups, emitPointSize};
auto handleAndSpirv = mTransformedShaderModuleCache->Find(cacheKey);
if (handleAndSpirv.has_value()) {
return std::move(*handleAndSpirv);
diff --git a/src/dawn/native/vulkan/ShaderModuleVk.h b/src/dawn/native/vulkan/ShaderModuleVk.h
index f87c682..37fe631 100644
--- a/src/dawn/native/vulkan/ShaderModuleVk.h
+++ b/src/dawn/native/vulkan/ShaderModuleVk.h
@@ -51,6 +51,7 @@
std::string entryPoint;
PipelineConstantEntries constants;
std::optional<uint32_t> maxSubgroupSizeForFullSubgroups;
+ bool emitPointSize;
bool operator==(const TransformedShaderModuleCacheKey& other) const;
};
diff --git a/src/dawn/tests/end2end/ShaderTests.cpp b/src/dawn/tests/end2end/ShaderTests.cpp
index 84c6dc4..87986da 100644
--- a/src/dawn/tests/end2end/ShaderTests.cpp
+++ b/src/dawn/tests/end2end/ShaderTests.cpp
@@ -2509,6 +2509,54 @@
EXPECT_BUFFER_U32_RANGE_EQ(expected.data(), output, 0, expected.size());
}
+// A regression test for chromium:341282611. Test that Vulkan shader module cache should take the
+// primitive type into account because for `PointList` we should generate `PointSize` in the SPIRV
+// of the vertex shader.
+TEST_P(ShaderTests, SameShaderModuleToRenderPointAndNonPoint) {
+ std::string shader = R"(
+@vertex
+fn vs_main() -> @builtin(position) vec4f {
+ return vec4f(0.0, 0.0, 0.0, 1.0);
+}
+@fragment
+fn fs_main() -> @location(0) vec4f {
+ return vec4f(0.0, 0.0, 0.0, 1.0);
+}
+)";
+ wgpu::PipelineLayoutDescriptor layoutDesc = {};
+ layoutDesc.bindGroupLayoutCount = 0;
+ wgpu::PipelineLayout layout = device.CreatePipelineLayout(&layoutDesc);
+
+ wgpu::ShaderModule shaderModule = utils::CreateShaderModule(device, shader.c_str());
+ utils::ComboRenderPipelineDescriptor rpDesc;
+ rpDesc.vertex.module = shaderModule;
+ rpDesc.vertex.entryPoint = "vs_main";
+ rpDesc.layout = layout;
+ rpDesc.cFragment.module = shaderModule;
+ rpDesc.cFragment.entryPoint = "fs_main";
+
+ utils::BasicRenderPass renderPass = utils::CreateBasicRenderPass(device, 64, 64);
+ rpDesc.cTargets[0].format = renderPass.colorFormat;
+
+ wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+ wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&renderPass.renderPassInfo);
+ {
+ rpDesc.primitive.topology = wgpu::PrimitiveTopology::TriangleList;
+ wgpu::RenderPipeline pipeline = device.CreateRenderPipeline(&rpDesc);
+ pass.SetPipeline(pipeline);
+ pass.Draw(3);
+ }
+ {
+ rpDesc.primitive.topology = wgpu::PrimitiveTopology::PointList;
+ wgpu::RenderPipeline pipeline = device.CreateRenderPipeline(&rpDesc);
+ pass.SetPipeline(pipeline);
+ pass.Draw(1);
+ }
+ pass.End();
+ wgpu::CommandBuffer commandBuffer = encoder.Finish();
+ queue.Submit(1, &commandBuffer);
+}
+
DAWN_INSTANTIATE_TEST(ShaderTests,
D3D11Backend(),
D3D12Backend(),