[fuzzers] Add checks that bad SPIRV isn't getting through

BUG=tint:963

Change-Id: I3cac636c194a36581f372ee22acad36d5e94eb07
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/57500
Auto-Submit: Ryan Harrison <rharrison@chromium.org>
Commit-Queue: Ryan Harrison <rharrison@chromium.org>
Kokoro: Ryan Harrison <rharrison@chromium.org>
Reviewed-by: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/fuzzers/tint_common_fuzzer.cc b/fuzzers/tint_common_fuzzer.cc
index 7469150..c663648 100644
--- a/fuzzers/tint_common_fuzzer.cc
+++ b/fuzzers/tint_common_fuzzer.cc
@@ -16,10 +16,15 @@
 
 #include <cstring>
 #include <memory>
+#include <sstream>
 #include <string>
 #include <utility>
 #include <vector>
 
+#if TINT_BUILD_SPV_READER
+#include "spirv-tools/libspirv.hpp"
+#endif  // TINT_BUILD_SPV_READER
+
 #include "src/ast/module.h"
 #include "src/diagnostic/formatter.h"
 #include "src/program.h"
@@ -29,21 +34,19 @@
 
 namespace {
 
-[[noreturn]] void TintInternalCompilerErrorReporter(
-    const tint::diag::List& diagnostics) {
+[[noreturn]] void FatalError(const tint::diag::List& diags,
+                             std::string msg = "") {
   auto printer = tint::diag::Printer::create(stderr, true);
-  tint::diag::Formatter{}.format(diagnostics, printer.get());
+  if (msg.size()) {
+    printer->write((msg + "\n").c_str(), {diag::Color::kRed, true});
+  }
+  tint::diag::Formatter().format(diags, printer.get());
   __builtin_trap();
 }
 
-[[noreturn]] void ValidityErrorReporter(const tint::diag::List& diags) {
-  auto printer = tint::diag::Printer::create(stderr, true);
-  printer->write(
-      "Fuzzing detected valid input program being transformed into an invalid "
-      "output progam\n",
-      {diag::Color::kRed, true});
-  tint::diag::Formatter().format(diags, printer.get());
-  __builtin_trap();
+[[noreturn]] void TintInternalCompilerErrorReporter(
+    const tint::diag::List& diagnostics) {
+  FatalError(diagnostics);
 }
 
 transform::VertexAttributeDescriptor ExtractVertexAttributeDescriptor(
@@ -66,6 +69,26 @@
   return desc;
 }
 
+bool SPIRVToolsValidationCheck(const tint::Program& program,
+                               std::vector<uint32_t> spirv) {
+  spvtools::SpirvTools tools(SPV_ENV_VULKAN_1_1);
+  const tint::diag::List& diags = program.Diagnostics();
+  tools.SetMessageConsumer([diags](spv_message_level_t, const char*,
+                                   const spv_position_t& pos, const char* msg) {
+    std::stringstream out;
+    out << "Unexpected spirv-val error:\n"
+        << (pos.line + 1) << ":" << (pos.column + 1) << ": " << msg
+        << std::endl;
+
+    auto printer = tint::diag::Printer::create(stderr, true);
+    printer->write(out.str(), {diag::Color::kYellow, false});
+    tint::diag::Formatter().format(diags, printer.get());
+  });
+
+  return tools.Validate(spirv.data(), spirv.size(),
+                        spvtools::ValidatorOptions());
+}
+
 }  // namespace
 
 Reader::Reader(const uint8_t* data, size_t size) : data_(data), size_(size) {}
@@ -162,6 +185,13 @@
   std::unique_ptr<Source::File> file;
 #endif  // TINT_BUILD_WGSL_READER
 
+#if TINT_BUILD_SPV_READER
+  size_t u32_size = size / sizeof(uint32_t);
+  const uint32_t* u32_data = reinterpret_cast<const uint32_t*>(data);
+  std::vector<uint32_t> spirv_input(u32_data, u32_data + u32_size);
+
+#endif  // TINT_BUILD_SPV_READER
+
   switch (input_) {
 #if TINT_BUILD_WGSL_READER
     case InputFormat::kWGSL: {
@@ -173,16 +203,12 @@
 #endif  // TINT_BUILD_WGSL_READER
 #if TINT_BUILD_SPV_READER
     case InputFormat::kSpv: {
-      size_t sizeInU32 = size / sizeof(uint32_t);
-      const uint32_t* u32Data = reinterpret_cast<const uint32_t*>(data);
-      std::vector<uint32_t> input(u32Data, u32Data + sizeInU32);
-
-      if (input.size() != 0) {
-        program = reader::spirv::Parse(input);
+      if (spirv_input.size() != 0) {
+        program = reader::spirv::Parse(spirv_input);
       }
       break;
     }
-#endif  // TINT_BUILD_WGSL_READER
+#endif  // TINT_BUILD_SPV_READER
     default:
       return 0;
   }
@@ -196,6 +222,14 @@
     return 0;
   }
 
+#if TINT_BUILD_SPV_READER
+  if (input_ == InputFormat::kSpv &&
+      !SPIRVToolsValidationCheck(program, spirv_input)) {
+    FatalError(program.Diagnostics(),
+               "Fuzzing detected invalid input spirv not being caught by Tint");
+  }
+#endif  // TINT_BUILD_SPV_READER
+
   if (inspector_enabled_) {
     inspector::Inspector inspector(&program);
 
@@ -276,7 +310,9 @@
       for (auto diag : out.program.Diagnostics()) {
         if (diag.severity > diag::Severity::Error ||
             diag.system != diag::System::Transform) {
-          ValidityErrorReporter(out.program.Diagnostics());
+          FatalError(out.program.Diagnostics(),
+                     "Fuzzing detected valid input program being transformed "
+                     "into an invalid output program");
         }
       }
     }
@@ -314,6 +350,16 @@
       errors_ = writer_->error();
       return 0;
     }
+
+#if TINT_BUILD_SPV_WRITER
+    if (output_ == OutputFormat::kSpv &&
+        !SPIRVToolsValidationCheck(
+            program,
+            static_cast<writer::spirv::Generator*>(writer_.get())->result())) {
+      FatalError(program.Diagnostics(),
+                 "Fuzzing detected invalid spirv being emitted by Tint");
+    }
+#endif  // TINT_BUILD_SPV_WRITER
   }
 
   return 0;