tint/exe: Add --allow-non-uniform-derivatives flag

When used with the SPIR-V reader, this will insert a module-scope
diagnostic directive to suppress uniformity violations for derivative
operations.

Bug: tint:1809
Change-Id: I2305265231ccffad49461d194669ba598484e8e0
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/117740
Reviewed-by: Ben Clayton <bclayton@google.com>
Reviewed-by: David Neto <dneto@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/cmd/main.cc b/src/tint/cmd/main.cc
index a30a133..fa8dbc0 100644
--- a/src/tint/cmd/main.cc
+++ b/src/tint/cmd/main.cc
@@ -103,6 +103,10 @@
 
     bool rename_all = false;
 
+#if TINT_BUILD_SPV_READER
+    tint::reader::spirv::Options spirv_reader_options;
+#endif
+
     std::vector<std::string> transforms;
 
     std::string fxc_path;
@@ -135,6 +139,9 @@
   --transform <name list>   -- Runs transforms, name list is comma separated
                                Available transforms:
 ${transforms} --parse-only              -- Stop after parsing the input
+  --allow-non-uniform-derivatives  -- When using SPIR-V input, allow non-uniform derivatives by
+                               inserting a module-scope directive to suppress any uniformity
+                               violations that may be produced.
   --disable-workgroup-init  -- Disable workgroup memory zero initialization.
   --demangle                -- Preserve original source names. Demangle them.
                                Affects AST dumping, and text-based output languages.
@@ -443,6 +450,13 @@
             opts->transforms = split_on_comma(args[i]);
         } else if (arg == "--parse-only") {
             opts->parse_only = true;
+        } else if (arg == "--allow-non-uniform-derivatives") {
+#if TINT_BUILD_SPV_READER
+            opts->spirv_reader_options.allow_non_uniform_derivatives = true;
+#else
+            std::cerr << "Tint not built with the SPIR-V reader enabled" << std::endl;
+            return false;
+#endif
         } else if (arg == "--disable-workgroup-init") {
             opts->disable_workgroup_init = true;
         } else if (arg == "--demangle") {
@@ -1285,7 +1299,8 @@
             if (!ReadFile<uint32_t>(options.input_filename, &data)) {
                 return 1;
             }
-            program = std::make_unique<tint::Program>(tint::reader::spirv::Parse(data));
+            program = std::make_unique<tint::Program>(
+                tint::reader::spirv::Parse(data, options.spirv_reader_options));
             break;
 #else
             std::cerr << "Tint not built with the SPIR-V reader enabled" << std::endl;
@@ -1309,7 +1324,8 @@
                                 SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS)) {
                 return 1;
             }
-            program = std::make_unique<tint::Program>(tint::reader::spirv::Parse(data));
+            program = std::make_unique<tint::Program>(
+                tint::reader::spirv::Parse(data, options.spirv_reader_options));
             break;
 #else
             std::cerr << "Tint not built with the SPIR-V reader enabled" << std::endl;
diff --git a/src/tint/reader/spirv/parser.cc b/src/tint/reader/spirv/parser.cc
index ac43b9e..41e6df3 100644
--- a/src/tint/reader/spirv/parser.cc
+++ b/src/tint/reader/spirv/parser.cc
@@ -27,7 +27,7 @@
 
 namespace tint::reader::spirv {
 
-Program Parse(const std::vector<uint32_t>& input) {
+Program Parse(const std::vector<uint32_t>& input, const Options& options) {
     ParserImpl parser(input);
     bool parsed = parser.Parse();
 
@@ -38,6 +38,13 @@
         return Program(std::move(builder));
     }
 
+    if (options.allow_non_uniform_derivatives) {
+        // Suppress errors regarding non-uniform derivative operations if requested, by adding a
+        // diagnostic directive to the module.
+        builder.DiagnosticDirective(ast::DiagnosticSeverity::kOff,
+                                    builder.Expr("derivative_uniformity"));
+    }
+
     // The SPIR-V parser can construct disjoint AST nodes, which is invalid for
     // the Resolver. Clone the Program to clean these up.
     builder.SetResolveOnBuild(false);
diff --git a/src/tint/reader/spirv/parser.h b/src/tint/reader/spirv/parser.h
index 3641e08..78c80c0d 100644
--- a/src/tint/reader/spirv/parser.h
+++ b/src/tint/reader/spirv/parser.h
@@ -21,13 +21,20 @@
 
 namespace tint::reader::spirv {
 
+/// Options that control how the SPIR-V parser should behave.
+struct Options {
+    /// Set to `true` to allow calls to derivative builtins in non-uniform control flow.
+    bool allow_non_uniform_derivatives = false;
+};
+
 /// Parses the SPIR-V source data, returning the parsed program.
 /// If the source data fails to parse then the returned
 /// `program.Diagnostics.contains_errors()` will be true, and the
 /// `program.Diagnostics()` will describe the error.
 /// @param input the source data
+/// @param options the parser options
 /// @returns the parsed program
-Program Parse(const std::vector<uint32_t>& input);
+Program Parse(const std::vector<uint32_t>& input, const Options& options = {});
 
 }  // namespace tint::reader::spirv
 
diff --git a/src/tint/reader/spirv/parser_test.cc b/src/tint/reader/spirv/parser_test.cc
index 35cb5da..3f5e370 100644
--- a/src/tint/reader/spirv/parser_test.cc
+++ b/src/tint/reader/spirv/parser_test.cc
@@ -14,7 +14,9 @@
 
 #include "src/tint/reader/spirv/parser.h"
 
+#include "gmock/gmock.h"
 #include "gtest/gtest.h"
+#include "src/tint/reader/spirv/spirv_tools_helpers_test.h"
 
 namespace tint::reader::spirv {
 namespace {
@@ -29,6 +31,54 @@
     EXPECT_EQ(errs, "error: line:0: Invalid SPIR-V magic number.\n");
 }
 
+constexpr auto kShaderWithNonUniformDerivative = R"(
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %foo "foo" %x
+               OpExecutionMode %foo OriginUpperLeft
+               OpDecorate %x Location 0
+      %float = OpTypeFloat 32
+%_ptr_Input_float = OpTypePointer Input %float
+          %x = OpVariable %_ptr_Input_float Input
+       %void = OpTypeVoid
+    %float_0 = OpConstantNull %float
+       %bool = OpTypeBool
+  %func_type = OpTypeFunction %void
+        %foo = OpFunction %void None %func_type
+  %foo_start = OpLabel
+    %x_value = OpLoad %float %x
+  %condition = OpFOrdGreaterThan %bool %x_value %float_0
+               OpSelectionMerge %merge None
+               OpBranchConditional %condition %true_branch %merge
+%true_branch = OpLabel
+     %result = OpDPdx %float %x_value
+               OpBranch %merge
+      %merge = OpLabel
+               OpReturn
+               OpFunctionEnd
+)";
+
+TEST_F(ParserTest, AllowNonUniformDerivatives_False) {
+    auto spv = test::Assemble(kShaderWithNonUniformDerivative);
+    Options options;
+    options.allow_non_uniform_derivatives = false;
+    auto program = Parse(spv, options);
+    auto errs = diag::Formatter().format(program.Diagnostics());
+    // TODO(jrprice): This will become EXPECT_FALSE.
+    EXPECT_TRUE(program.IsValid()) << errs;
+    EXPECT_THAT(errs, ::testing::HasSubstr("'dpdx' must only be called from uniform control flow"));
+}
+
+TEST_F(ParserTest, AllowNonUniformDerivatives_True) {
+    auto spv = test::Assemble(kShaderWithNonUniformDerivative);
+    Options options;
+    options.allow_non_uniform_derivatives = true;
+    auto program = Parse(spv, options);
+    auto errs = diag::Formatter().format(program.Diagnostics());
+    EXPECT_TRUE(program.IsValid()) << errs;
+    EXPECT_EQ(program.Diagnostics().count(), 0u) << errs;
+}
+
 // TODO(dneto): uint32 vec, valid SPIR-V
 // TODO(dneto): uint32 vec, invalid SPIR-V