Vulkan: Add `emitPointSize` into `TransformedShaderModuleCacheKey`

This patch adds `emitPointSize` into `TransformedShaderModuleCacheKey`
because having `PointSize` or not should generate different SPIRV
code.

Fixed: chromium:341282611
Test: dawn_end2end_tests
Change-Id: I1395c4b51648e1c37776937a4beebc7b38565300
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/191770
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Jiawei Shao <jiawei.shao@intel.com>
diff --git a/src/dawn/native/vulkan/ShaderModuleVk.cpp b/src/dawn/native/vulkan/ShaderModuleVk.cpp
index 1e06207..5fbdd3c 100644
--- a/src/dawn/native/vulkan/ShaderModuleVk.cpp
+++ b/src/dawn/native/vulkan/ShaderModuleVk.cpp
@@ -78,13 +78,16 @@
     if (maxSubgroupSizeForFullSubgroups != other.maxSubgroupSizeForFullSubgroups) {
         return false;
     }
+    if (emitPointSize != other.emitPointSize) {
+        return false;
+    }
     return true;
 }
 
 size_t TransformedShaderModuleCacheKeyHashFunc::operator()(
     const TransformedShaderModuleCacheKey& key) const {
     size_t hash = 0;
-    HashCombine(&hash, key.layoutPtr, key.entryPoint);
+    HashCombine(&hash, key.layoutPtr, key.entryPoint, key.emitPointSize);
     for (const auto& entry : key.constants) {
         HashCombine(&hash, entry.first, entry.second);
     }
@@ -218,10 +221,13 @@
 
     ScopedTintICEHandler scopedICEHandler(GetDevice());
 
-    // Check to see if we have the handle and spirv cached already.
+    // Check to see if we have the handle and spirv cached already
+    // TODO(chromium:345359083): Improve the computation of the cache key. For example, it isn't
+    // ideal to use `reinterpret_cast<uintptr_t>(layout)` as the layout may be freed and
+    // reallocated during the runtime.
     auto cacheKey = TransformedShaderModuleCacheKey{
         reinterpret_cast<uintptr_t>(layout), programmableStage.entryPoint.c_str(),
-        programmableStage.constants, maxSubgroupSizeForFullSubgroups};
+        programmableStage.constants, maxSubgroupSizeForFullSubgroups, emitPointSize};
     auto handleAndSpirv = mTransformedShaderModuleCache->Find(cacheKey);
     if (handleAndSpirv.has_value()) {
         return std::move(*handleAndSpirv);
diff --git a/src/dawn/native/vulkan/ShaderModuleVk.h b/src/dawn/native/vulkan/ShaderModuleVk.h
index f87c682..37fe631 100644
--- a/src/dawn/native/vulkan/ShaderModuleVk.h
+++ b/src/dawn/native/vulkan/ShaderModuleVk.h
@@ -51,6 +51,7 @@
     std::string entryPoint;
     PipelineConstantEntries constants;
     std::optional<uint32_t> maxSubgroupSizeForFullSubgroups;
+    bool emitPointSize;
 
     bool operator==(const TransformedShaderModuleCacheKey& other) const;
 };
diff --git a/src/dawn/tests/end2end/ShaderTests.cpp b/src/dawn/tests/end2end/ShaderTests.cpp
index 84c6dc4..87986da 100644
--- a/src/dawn/tests/end2end/ShaderTests.cpp
+++ b/src/dawn/tests/end2end/ShaderTests.cpp
@@ -2509,6 +2509,54 @@
     EXPECT_BUFFER_U32_RANGE_EQ(expected.data(), output, 0, expected.size());
 }
 
+// A regression test for chromium:341282611. Test that Vulkan shader module cache should take the
+// primitive type into account because for `PointList` we should generate `PointSize` in the SPIRV
+// of the vertex shader.
+TEST_P(ShaderTests, SameShaderModuleToRenderPointAndNonPoint) {
+    std::string shader = R"(
+@vertex
+fn vs_main() -> @builtin(position) vec4f {
+    return vec4f(0.0, 0.0, 0.0, 1.0);
+}
+@fragment
+fn fs_main() -> @location(0) vec4f {
+    return vec4f(0.0, 0.0, 0.0, 1.0);
+}
+)";
+    wgpu::PipelineLayoutDescriptor layoutDesc = {};
+    layoutDesc.bindGroupLayoutCount = 0;
+    wgpu::PipelineLayout layout = device.CreatePipelineLayout(&layoutDesc);
+
+    wgpu::ShaderModule shaderModule = utils::CreateShaderModule(device, shader.c_str());
+    utils::ComboRenderPipelineDescriptor rpDesc;
+    rpDesc.vertex.module = shaderModule;
+    rpDesc.vertex.entryPoint = "vs_main";
+    rpDesc.layout = layout;
+    rpDesc.cFragment.module = shaderModule;
+    rpDesc.cFragment.entryPoint = "fs_main";
+
+    utils::BasicRenderPass renderPass = utils::CreateBasicRenderPass(device, 64, 64);
+    rpDesc.cTargets[0].format = renderPass.colorFormat;
+
+    wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+    wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&renderPass.renderPassInfo);
+    {
+        rpDesc.primitive.topology = wgpu::PrimitiveTopology::TriangleList;
+        wgpu::RenderPipeline pipeline = device.CreateRenderPipeline(&rpDesc);
+        pass.SetPipeline(pipeline);
+        pass.Draw(3);
+    }
+    {
+        rpDesc.primitive.topology = wgpu::PrimitiveTopology::PointList;
+        wgpu::RenderPipeline pipeline = device.CreateRenderPipeline(&rpDesc);
+        pass.SetPipeline(pipeline);
+        pass.Draw(1);
+    }
+    pass.End();
+    wgpu::CommandBuffer commandBuffer = encoder.Finish();
+    queue.Submit(1, &commandBuffer);
+}
+
 DAWN_INSTANTIATE_TEST(ShaderTests,
                       D3D11Backend(),
                       D3D12Backend(),