Add flag to disable derivative_uniformity for SPIR-V ingestion

Add a chained struct for Dawn-specific options for SPIR-V ingestion to
contain this new flag.

Bug: tint:1890
Change-Id: I1332ff20c91f29a84c21550a37f11bc7d9c956ce
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/118421
Reviewed-by: Austin Eng <enga@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: James Price <jrprice@google.com>
diff --git a/dawn.json b/dawn.json
index de13f0d..03cf47d 100644
--- a/dawn.json
+++ b/dawn.json
@@ -2426,6 +2426,15 @@
             {"name": "code", "type": "char", "annotation": "const*", "length": "strlen", "tags": ["upstream"]}
         ]
     },
+    "dawn shader module SPIRV options descriptor": {
+        "category": "structure",
+        "chained": "in",
+        "chain roots": ["shader module descriptor"],
+        "tags": ["dawn"],
+        "members": [
+            {"name": "allow non uniform derivatives", "type": "bool", "default": "false"}
+        ]
+    },
     "shader stage": {
         "category": "bitmask",
         "values": [
@@ -2619,7 +2628,8 @@
             {"value": 1005, "name": "dawn cache device descriptor", "tags": ["dawn", "native"]},
             {"value": 1006, "name": "dawn adapter properties power preference", "tags": ["dawn", "native"]},
             {"value": 1007, "name": "dawn buffer descriptor error info from wire client", "tags": ["dawn"]},
-            {"value": 1008, "name": "dawn toggles descriptor", "tags": ["dawn", "native"]}
+            {"value": 1008, "name": "dawn toggles descriptor", "tags": ["dawn", "native"]},
+            {"value": 1009, "name": "dawn shader module SPIRV options descriptor", "tags": ["dawn"]}
         ]
     },
     "texture": {
diff --git a/src/dawn/native/ShaderModule.cpp b/src/dawn/native/ShaderModule.cpp
index c2ea13e..7db19a2 100644
--- a/src/dawn/native/ShaderModule.cpp
+++ b/src/dawn/native/ShaderModule.cpp
@@ -315,8 +315,13 @@
 
 #if TINT_BUILD_SPV_READER
 ResultOrError<tint::Program> ParseSPIRV(const std::vector<uint32_t>& spirv,
-                                        OwnedCompilationMessages* outMessages) {
-    tint::Program program = tint::reader::spirv::Parse(spirv);
+                                        OwnedCompilationMessages* outMessages,
+                                        const DawnShaderModuleSPIRVOptionsDescriptor* optionsDesc) {
+    tint::reader::spirv::Options options;
+    if (optionsDesc) {
+        options.allow_non_uniform_derivatives = optionsDesc->allowNonUniformDerivatives;
+    }
+    tint::Program program = tint::reader::spirv::Parse(spirv, options);
     if (outMessages != nullptr) {
         DAWN_TRY(outMessages->AddMessages(program.Diagnostics()));
     }
@@ -905,10 +910,13 @@
     DAWN_INVALID_IF(chainedDescriptor == nullptr,
                     "Shader module descriptor missing chained descriptor");
 
-// For now only a single WGSL (or SPIRV, if enabled) subdescriptor is allowed.
+// A WGSL (or SPIR-V, if enabled) subdescriptor is required, and a Dawn-specific SPIR-V options
+// descriptor is allowed when using SPIR-V.
 #if TINT_BUILD_SPV_READER
-    DAWN_TRY(ValidateSingleSType(chainedDescriptor, wgpu::SType::ShaderModuleSPIRVDescriptor,
-                                 wgpu::SType::ShaderModuleWGSLDescriptor));
+    DAWN_TRY(ValidateSTypes(
+        chainedDescriptor,
+        {{wgpu::SType::ShaderModuleSPIRVDescriptor, wgpu::SType::ShaderModuleWGSLDescriptor},
+         {wgpu::SType::DawnShaderModuleSPIRVOptionsDescriptor}}));
 #else
     DAWN_TRY(ValidateSingleSType(chainedDescriptor, wgpu::SType::ShaderModuleWGSLDescriptor));
 #endif
@@ -918,10 +926,19 @@
     const ShaderModuleWGSLDescriptor* wgslDesc = nullptr;
     FindInChain(chainedDescriptor, &wgslDesc);
 
+    const DawnShaderModuleSPIRVOptionsDescriptor* spirvOptions = nullptr;
+    FindInChain(chainedDescriptor, &spirvOptions);
+
+    DAWN_INVALID_IF(wgslDesc != nullptr && spirvOptions != nullptr,
+                    "SPIR-V options descriptor not valid with WGSL descriptor");
+
 #if TINT_BUILD_SPV_READER
     const ShaderModuleSPIRVDescriptor* spirvDesc = nullptr;
     FindInChain(chainedDescriptor, &spirvDesc);
 
+    DAWN_INVALID_IF(spirvOptions != nullptr && spirvDesc == nullptr,
+                    "SPIR-V options descriptor can only be used with SPIR-V input");
+
     // We have a temporary toggle to force the SPIRV ingestion to go through a WGSL
     // intermediate step. It is done by switching the spirvDesc for a wgslDesc below.
     ShaderModuleWGSLDescriptor newWgslDesc;
@@ -930,7 +947,7 @@
 #if TINT_BUILD_WGSL_WRITER
         std::vector<uint32_t> spirv(spirvDesc->code, spirvDesc->code + spirvDesc->codeSize);
         tint::Program program;
-        DAWN_TRY_ASSIGN(program, ParseSPIRV(spirv, outMessages));
+        DAWN_TRY_ASSIGN(program, ParseSPIRV(spirv, outMessages, spirvOptions));
 
         tint::writer::wgsl::Options options;
         auto result = tint::writer::wgsl::Generate(&program, options);
@@ -953,7 +970,7 @@
 
         std::vector<uint32_t> spirv(spirvDesc->code, spirvDesc->code + spirvDesc->codeSize);
         tint::Program program;
-        DAWN_TRY_ASSIGN(program, ParseSPIRV(spirv, outMessages));
+        DAWN_TRY_ASSIGN(program, ParseSPIRV(spirv, outMessages, spirvOptions));
         parseResult->tintProgram = std::make_unique<tint::Program>(std::move(program));
 
         return {};
diff --git a/src/dawn/tests/unittests/validation/ShaderModuleValidationTests.cpp b/src/dawn/tests/unittests/validation/ShaderModuleValidationTests.cpp
index 0f8e4f7..9efb35f 100644
--- a/src/dawn/tests/unittests/validation/ShaderModuleValidationTests.cpp
+++ b/src/dawn/tests/unittests/validation/ShaderModuleValidationTests.cpp
@@ -137,6 +137,56 @@
 
     ASSERT_DEVICE_ERROR(utils::CreateShaderModuleFromASM(device, shader));
 }
+
+const char* kShaderWithNonUniformDerivative = R"(
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %foo "foo" %x
+               OpExecutionMode %foo OriginUpperLeft
+               OpDecorate %x Location 0
+      %float = OpTypeFloat 32
+%_ptr_Input_float = OpTypePointer Input %float
+          %x = OpVariable %_ptr_Input_float Input
+       %void = OpTypeVoid
+    %float_0 = OpConstantNull %float
+       %bool = OpTypeBool
+  %func_type = OpTypeFunction %void
+        %foo = OpFunction %void None %func_type
+  %foo_start = OpLabel
+    %x_value = OpLoad %float %x
+  %condition = OpFOrdGreaterThan %bool %x_value %float_0
+               OpSelectionMerge %merge None
+               OpBranchConditional %condition %true_branch %merge
+%true_branch = OpLabel
+     %result = OpDPdx %float %x_value
+               OpBranch %merge
+      %merge = OpLabel
+               OpReturn
+               OpFunctionEnd)";
+
+// Test that creating a module with a SPIR-V shader that has a uniformity violation fails when no
+// SPIR-V options descriptor is used.
+TEST_F(ShaderModuleValidationTest, NonUniformDerivatives_NoOptions) {
+    ASSERT_DEVICE_ERROR(utils::CreateShaderModuleFromASM(device, kShaderWithNonUniformDerivative));
+}
+
+// Test that creating a module with a SPIR-V shader that has a uniformity violation fails when
+// passing a SPIR-V options descriptor with the `allowNonUniformDerivatives` flag set to `false`.
+TEST_F(ShaderModuleValidationTest, NonUniformDerivatives_FlagSetToFalse) {
+    wgpu::DawnShaderModuleSPIRVOptionsDescriptor spirv_options_desc = {};
+    spirv_options_desc.allowNonUniformDerivatives = false;
+    ASSERT_DEVICE_ERROR(utils::CreateShaderModuleFromASM(device, kShaderWithNonUniformDerivative,
+                                                         &spirv_options_desc));
+}
+
+// Test that creating a module with a SPIR-V shader that has a uniformity violation succeeds when
+// passing a SPIR-V options descriptor with the `allowNonUniformDerivatives` flag set to `true`.
+TEST_F(ShaderModuleValidationTest, NonUniformDerivatives_FlagSetToTrue) {
+    wgpu::DawnShaderModuleSPIRVOptionsDescriptor spirv_options_desc = {};
+    spirv_options_desc.allowNonUniformDerivatives = true;
+    utils::CreateShaderModuleFromASM(device, kShaderWithNonUniformDerivative, &spirv_options_desc);
+}
+
 #endif  // TINT_BUILD_SPV_READER
 
 // Test that it is invalid to create a shader module with no chained descriptor. (It must be
@@ -146,6 +196,47 @@
     ASSERT_DEVICE_ERROR(device.CreateShaderModule(&desc));
 }
 
+// Test that it is invalid to create a shader module that uses both the WGSL descriptor and the
+// SPIRV descriptor.
+TEST_F(ShaderModuleValidationTest, MultipleChainedDescriptor_WgslAndSpirv) {
+    uint32_t code = 42;
+    wgpu::ShaderModuleDescriptor desc = {};
+    wgpu::ShaderModuleSPIRVDescriptor spirv_desc = {};
+    spirv_desc.code = &code;
+    spirv_desc.codeSize = 1;
+    wgpu::ShaderModuleWGSLDescriptor wgsl_desc = {};
+    wgsl_desc.source = "";
+    wgsl_desc.nextInChain = &spirv_desc;
+    desc.nextInChain = &wgsl_desc;
+    ASSERT_DEVICE_ERROR(device.CreateShaderModule(&desc),
+                        testing::HasSubstr("is part of a group of exclusive sTypes"));
+}
+
+// Test that it is invalid to create a shader module that uses both the WGSL descriptor and the
+// Dawn SPIRV options descriptor.
+TEST_F(ShaderModuleValidationTest, MultipleChainedDescriptor_WgslAndDawnSpirvOptions) {
+    wgpu::ShaderModuleDescriptor desc = {};
+    wgpu::DawnShaderModuleSPIRVOptionsDescriptor spirv_options_desc = {};
+    wgpu::ShaderModuleWGSLDescriptor wgsl_desc = {};
+    wgsl_desc.nextInChain = &spirv_options_desc;
+    wgsl_desc.source = "";
+    desc.nextInChain = &wgsl_desc;
+    ASSERT_DEVICE_ERROR(
+        device.CreateShaderModule(&desc),
+        testing::HasSubstr("SPIR-V options descriptor not valid with WGSL descriptor"));
+}
+
+// Test that it is invalid to create a shader module that only uses the Dawn SPIRV options
+// descriptor without the SPIRV descriptor.
+TEST_F(ShaderModuleValidationTest, OnlySpirvOptionsDescriptor) {
+    wgpu::ShaderModuleDescriptor desc = {};
+    wgpu::DawnShaderModuleSPIRVOptionsDescriptor spirv_options_desc = {};
+    desc.nextInChain = &spirv_options_desc;
+    ASSERT_DEVICE_ERROR(
+        device.CreateShaderModule(&desc),
+        testing::HasSubstr("SPIR-V options descriptor can only be used with SPIR-V input"));
+}
+
 // Tests that shader module compilation messages can be queried.
 TEST_F(ShaderModuleValidationTest, GetCompilationMessages) {
     // This test works assuming ShaderModule is backed by a dawn::native::ShaderModuleBase, which
diff --git a/src/dawn/utils/WGPUHelpers.cpp b/src/dawn/utils/WGPUHelpers.cpp
index d6b1a4a..e517945 100644
--- a/src/dawn/utils/WGPUHelpers.cpp
+++ b/src/dawn/utils/WGPUHelpers.cpp
@@ -41,7 +41,10 @@
 
 namespace utils {
 #if TINT_BUILD_SPV_READER
-wgpu::ShaderModule CreateShaderModuleFromASM(const wgpu::Device& device, const char* source) {
+wgpu::ShaderModule CreateShaderModuleFromASM(
+    const wgpu::Device& device,
+    const char* source,
+    wgpu::DawnShaderModuleSPIRVOptionsDescriptor* spirv_options) {
     // Use SPIRV-Tools's C API to assemble the SPIR-V assembly text to binary. Because the types
     // aren't RAII, we don't return directly on success and instead always go through the code
     // path that destroys the SPIRV-Tools objects.
@@ -59,6 +62,7 @@
         wgpu::ShaderModuleSPIRVDescriptor spirvDesc;
         spirvDesc.codeSize = static_cast<uint32_t>(spirv->wordCount);
         spirvDesc.code = spirv->code;
+        spirvDesc.nextInChain = spirv_options;
 
         wgpu::ShaderModuleDescriptor descriptor;
         descriptor.nextInChain = &spirvDesc;
diff --git a/src/dawn/utils/WGPUHelpers.h b/src/dawn/utils/WGPUHelpers.h
index d56dc59..f05e323 100644
--- a/src/dawn/utils/WGPUHelpers.h
+++ b/src/dawn/utils/WGPUHelpers.h
@@ -28,7 +28,10 @@
 enum Expectation { Success, Failure };
 
 #if TINT_BUILD_SPV_READER
-wgpu::ShaderModule CreateShaderModuleFromASM(const wgpu::Device& device, const char* source);
+wgpu::ShaderModule CreateShaderModuleFromASM(
+    const wgpu::Device& device,
+    const char* source,
+    wgpu::DawnShaderModuleSPIRVOptionsDescriptor* spirv_options = nullptr);
 #endif
 wgpu::ShaderModule CreateShaderModule(const wgpu::Device& device, const char* source);