[fuzzers] Add checks that bad SPIRV isn't getting through
BUG=tint:963
Change-Id: I3cac636c194a36581f372ee22acad36d5e94eb07
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/57500
Auto-Submit: Ryan Harrison <rharrison@chromium.org>
Commit-Queue: Ryan Harrison <rharrison@chromium.org>
Kokoro: Ryan Harrison <rharrison@chromium.org>
Reviewed-by: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/fuzzers/tint_common_fuzzer.cc b/fuzzers/tint_common_fuzzer.cc
index 7469150..c663648 100644
--- a/fuzzers/tint_common_fuzzer.cc
+++ b/fuzzers/tint_common_fuzzer.cc
@@ -16,10 +16,15 @@
#include <cstring>
#include <memory>
+#include <sstream>
#include <string>
#include <utility>
#include <vector>
+#if TINT_BUILD_SPV_READER
+#include "spirv-tools/libspirv.hpp"
+#endif // TINT_BUILD_SPV_READER
+
#include "src/ast/module.h"
#include "src/diagnostic/formatter.h"
#include "src/program.h"
@@ -29,21 +34,19 @@
namespace {
-[[noreturn]] void TintInternalCompilerErrorReporter(
- const tint::diag::List& diagnostics) {
+[[noreturn]] void FatalError(const tint::diag::List& diags,
+ std::string msg = "") {
auto printer = tint::diag::Printer::create(stderr, true);
- tint::diag::Formatter{}.format(diagnostics, printer.get());
+ if (msg.size()) {
+ printer->write((msg + "\n").c_str(), {diag::Color::kRed, true});
+ }
+ tint::diag::Formatter().format(diags, printer.get());
__builtin_trap();
}
-[[noreturn]] void ValidityErrorReporter(const tint::diag::List& diags) {
- auto printer = tint::diag::Printer::create(stderr, true);
- printer->write(
- "Fuzzing detected valid input program being transformed into an invalid "
- "output progam\n",
- {diag::Color::kRed, true});
- tint::diag::Formatter().format(diags, printer.get());
- __builtin_trap();
+[[noreturn]] void TintInternalCompilerErrorReporter(
+ const tint::diag::List& diagnostics) {
+ FatalError(diagnostics);
}
transform::VertexAttributeDescriptor ExtractVertexAttributeDescriptor(
@@ -66,6 +69,26 @@
return desc;
}
+bool SPIRVToolsValidationCheck(const tint::Program& program,
+ std::vector<uint32_t> spirv) {
+ spvtools::SpirvTools tools(SPV_ENV_VULKAN_1_1);
+ const tint::diag::List& diags = program.Diagnostics();
+ tools.SetMessageConsumer([diags](spv_message_level_t, const char*,
+ const spv_position_t& pos, const char* msg) {
+ std::stringstream out;
+ out << "Unexpected spirv-val error:\n"
+ << (pos.line + 1) << ":" << (pos.column + 1) << ": " << msg
+ << std::endl;
+
+ auto printer = tint::diag::Printer::create(stderr, true);
+ printer->write(out.str(), {diag::Color::kYellow, false});
+ tint::diag::Formatter().format(diags, printer.get());
+ });
+
+ return tools.Validate(spirv.data(), spirv.size(),
+ spvtools::ValidatorOptions());
+}
+
} // namespace
Reader::Reader(const uint8_t* data, size_t size) : data_(data), size_(size) {}
@@ -162,6 +185,13 @@
std::unique_ptr<Source::File> file;
#endif // TINT_BUILD_WGSL_READER
+#if TINT_BUILD_SPV_READER
+ size_t u32_size = size / sizeof(uint32_t);
+ const uint32_t* u32_data = reinterpret_cast<const uint32_t*>(data);
+ std::vector<uint32_t> spirv_input(u32_data, u32_data + u32_size);
+
+#endif // TINT_BUILD_SPV_READER
+
switch (input_) {
#if TINT_BUILD_WGSL_READER
case InputFormat::kWGSL: {
@@ -173,16 +203,12 @@
#endif // TINT_BUILD_WGSL_READER
#if TINT_BUILD_SPV_READER
case InputFormat::kSpv: {
- size_t sizeInU32 = size / sizeof(uint32_t);
- const uint32_t* u32Data = reinterpret_cast<const uint32_t*>(data);
- std::vector<uint32_t> input(u32Data, u32Data + sizeInU32);
-
- if (input.size() != 0) {
- program = reader::spirv::Parse(input);
+ if (spirv_input.size() != 0) {
+ program = reader::spirv::Parse(spirv_input);
}
break;
}
-#endif // TINT_BUILD_WGSL_READER
+#endif // TINT_BUILD_SPV_READER
default:
return 0;
}
@@ -196,6 +222,14 @@
return 0;
}
+#if TINT_BUILD_SPV_READER
+ if (input_ == InputFormat::kSpv &&
+ !SPIRVToolsValidationCheck(program, spirv_input)) {
+ FatalError(program.Diagnostics(),
+ "Fuzzing detected invalid input spirv not being caught by Tint");
+ }
+#endif // TINT_BUILD_SPV_READER
+
if (inspector_enabled_) {
inspector::Inspector inspector(&program);
@@ -276,7 +310,9 @@
for (auto diag : out.program.Diagnostics()) {
if (diag.severity > diag::Severity::Error ||
diag.system != diag::System::Transform) {
- ValidityErrorReporter(out.program.Diagnostics());
+ FatalError(out.program.Diagnostics(),
+ "Fuzzing detected valid input program being transformed "
+ "into an invalid output program");
}
}
}
@@ -314,6 +350,16 @@
errors_ = writer_->error();
return 0;
}
+
+#if TINT_BUILD_SPV_WRITER
+ if (output_ == OutputFormat::kSpv &&
+ !SPIRVToolsValidationCheck(
+ program,
+ static_cast<writer::spirv::Generator*>(writer_.get())->result())) {
+ FatalError(program.Diagnostics(),
+ "Fuzzing detected invalid spirv being emitted by Tint");
+ }
+#endif // TINT_BUILD_SPV_WRITER
}
return 0;