Validate SPIR-V code when creating ShaderModules
This integrates spirv-val in dawn_native so that regular and
WebGPU-specific validation of shaders is done.
Also adds tests to check OpUndef is correctly rejected so we know
WebGPU-specific validation is working.
Change-Id: If49d276c98bca8cd3c6c1a420903fe34923a2942
diff --git a/BUILD.gn b/BUILD.gn
index fad9992..dc6641e 100644
--- a/BUILD.gn
+++ b/BUILD.gn
@@ -314,6 +314,7 @@
deps = [
":dawn_common",
":libdawn_native_utils_gen",
+ "${dawn_spirv_tools_dir}:spvtools_val",
"third_party:spirv_cross",
]
@@ -772,6 +773,7 @@
"src/tests/unittests/validation/PushConstantsValidationTests.cpp",
"src/tests/unittests/validation/RenderPassDescriptorValidationTests.cpp",
"src/tests/unittests/validation/RenderPipelineValidationTests.cpp",
+ "src/tests/unittests/validation/ShaderModuleValidationTests.cpp",
"src/tests/unittests/validation/ValidationTest.cpp",
"src/tests/unittests/validation/ValidationTest.h",
"src/tests/unittests/validation/VertexBufferValidationTests.cpp",
diff --git a/src/dawn_native/CMakeLists.txt b/src/dawn_native/CMakeLists.txt
index a24999d..031728a 100644
--- a/src/dawn_native/CMakeLists.txt
+++ b/src/dawn_native/CMakeLists.txt
@@ -31,8 +31,8 @@
)
set(DAWN_NATIVE_SOURCES)
-set(DAWN_NATIVE_DEPS dawn_common spirv_cross dawn_native_utils_autogen)
-set(DAWN_NATIVE_INCLUDE_DIRS ${SPIRV_CROSS_INCLUDE_DIR})
+set(DAWN_NATIVE_DEPS dawn_common spirv_cross dawn_native_utils_autogen SPIRV-Tools)
+set(DAWN_NATIVE_INCLUDE_DIRS ${SPIRV_CROSS_INCLUDE_DIR} ${SPIRV_TOOLS_INCLUDE_DIR})
################################################################################
# OpenGL Backend
diff --git a/src/dawn_native/ShaderModule.cpp b/src/dawn_native/ShaderModule.cpp
index f7f7a24..faf9551 100644
--- a/src/dawn_native/ShaderModule.cpp
+++ b/src/dawn_native/ShaderModule.cpp
@@ -20,13 +20,44 @@
#include "dawn_native/PipelineLayout.h"
#include <spirv-cross/spirv_cross.hpp>
+#include <spirv-tools/libspirv.hpp>
namespace dawn_native {
MaybeError ValidateShaderModuleDescriptor(DeviceBase*,
const ShaderModuleDescriptor* descriptor) {
DAWN_TRY_ASSERT(descriptor->nextInChain == nullptr, "nextInChain must be nullptr");
- // TODO(cwallez@chromium.org): Use spirv-val to check the module is well-formed
+
+ spvtools::SpirvTools spirvTools(SPV_ENV_WEBGPU_0);
+
+ std::ostringstream errorStream;
+ errorStream << "SPIRV Validation failure:" << std::endl;
+
+ spirvTools.SetMessageConsumer([&errorStream](spv_message_level_t level, const char*,
+ const spv_position_t& position,
+ const char* message) {
+ switch (level) {
+ case SPV_MSG_FATAL:
+ case SPV_MSG_INTERNAL_ERROR:
+ case SPV_MSG_ERROR:
+ errorStream << "error: line " << position.index << ": " << message << std::endl;
+ break;
+ case SPV_MSG_WARNING:
+ errorStream << "warning: line " << position.index << ": " << message
+ << std::endl;
+ break;
+ case SPV_MSG_INFO:
+ errorStream << "info: line " << position.index << ": " << message << std::endl;
+ break;
+ default:
+ break;
+ }
+ });
+
+ if (!spirvTools.Validate(descriptor->code, descriptor->codeSize)) {
+ DAWN_RETURN_ERROR(errorStream.str().c_str());
+ }
+
return {};
}
diff --git a/src/tests/CMakeLists.txt b/src/tests/CMakeLists.txt
index 27828e2..68c573b 100644
--- a/src/tests/CMakeLists.txt
+++ b/src/tests/CMakeLists.txt
@@ -53,6 +53,7 @@
${VALIDATION_TESTS_DIR}/PushConstantsValidationTests.cpp
${VALIDATION_TESTS_DIR}/RenderPassDescriptorValidationTests.cpp
${VALIDATION_TESTS_DIR}/RenderPipelineValidationTests.cpp
+ ${VALIDATION_TESTS_DIR}/ShaderModuleValidationTests.cpp
${VALIDATION_TESTS_DIR}/VertexBufferValidationTests.cpp
${VALIDATION_TESTS_DIR}/ValidationTest.cpp
${VALIDATION_TESTS_DIR}/ValidationTest.h
diff --git a/src/tests/unittests/validation/ShaderModuleValidationTests.cpp b/src/tests/unittests/validation/ShaderModuleValidationTests.cpp
new file mode 100644
index 0000000..5bf4a216
--- /dev/null
+++ b/src/tests/unittests/validation/ShaderModuleValidationTests.cpp
@@ -0,0 +1,89 @@
+// Copyright 2018 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tests/unittests/validation/ValidationTest.h"
+
+#include "utils/DawnHelpers.h"
+
+class ShaderModuleValidationTest : public ValidationTest {
+};
+
+// Test case with a simpler shader that should successfully be created
+TEST_F(ShaderModuleValidationTest, CreationSuccess) {
+ const char* shader = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %main "main" %fragColor
+ OpExecutionMode %main OriginUpperLeft
+ OpSource GLSL 450
+ OpSourceExtension "GL_GOOGLE_cpp_style_line_directive"
+ OpSourceExtension "GL_GOOGLE_include_directive"
+ OpName %main "main"
+ OpName %fragColor "fragColor"
+ OpDecorate %fragColor Location 0
+ %void = OpTypeVoid
+ %3 = OpTypeFunction %void
+ %float = OpTypeFloat 32
+ %v4float = OpTypeVector %float 4
+ %_ptr_Output_v4float = OpTypePointer Output %v4float
+ %fragColor = OpVariable %_ptr_Output_v4float Output
+ %float_1 = OpConstant %float 1
+ %float_0 = OpConstant %float 0
+ %12 = OpConstantComposite %v4float %float_1 %float_0 %float_0 %float_1
+ %main = OpFunction %void None %3
+ %5 = OpLabel
+ OpStore %fragColor %12
+ OpReturn
+ OpFunctionEnd)";
+
+ utils::CreateShaderModuleFromASM(device, shader);
+}
+
+// Test case with a shader with OpUndef to test WebGPU-specific validation
+TEST_F(ShaderModuleValidationTest, OpUndef) {
+ const char* shader = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %main "main" %fragColor
+ OpExecutionMode %main OriginUpperLeft
+ OpSource GLSL 450
+ OpSourceExtension "GL_GOOGLE_cpp_style_line_directive"
+ OpSourceExtension "GL_GOOGLE_include_directive"
+ OpName %main "main"
+ OpName %fragColor "fragColor"
+ OpDecorate %fragColor Location 0
+ %void = OpTypeVoid
+ %3 = OpTypeFunction %void
+ %float = OpTypeFloat 32
+ %v4float = OpTypeVector %float 4
+ %_ptr_Output_v4float = OpTypePointer Output %v4float
+ %fragColor = OpVariable %_ptr_Output_v4float Output
+ %float_1 = OpConstant %float 1
+ %float_0 = OpConstant %float 0
+ %12 = OpConstantComposite %v4float %float_1 %float_0 %float_0 %float_1
+ %main = OpFunction %void None %3
+ %5 = OpLabel
+ %6 = OpUndef %v4float
+ OpStore %fragColor %12
+ OpReturn
+ OpFunctionEnd)";
+
+ // Notice "%6 = OpUndef %v4float" above
+ ASSERT_DEVICE_ERROR(utils::CreateShaderModuleFromASM(device, shader));
+
+ std::string error = GetLastDeviceErrorMessage();
+ ASSERT_NE(error.find("OpUndef"), std::string::npos);
+}
diff --git a/src/tests/unittests/validation/ValidationTest.cpp b/src/tests/unittests/validation/ValidationTest.cpp
index f511b5f..85ba237 100644
--- a/src/tests/unittests/validation/ValidationTest.cpp
+++ b/src/tests/unittests/validation/ValidationTest.cpp
@@ -69,6 +69,9 @@
mExpectError = false;
return mError;
}
+std::string ValidationTest::GetLastDeviceErrorMessage() const {
+ return mDeviceErrorMessage;
+}
dawn::RenderPassDescriptor ValidationTest::CreateSimpleRenderPass() {
dawn::TextureDescriptor descriptor;
@@ -91,14 +94,16 @@
}
void ValidationTest::OnDeviceError(const char* message, dawnCallbackUserdata userdata) {
+ auto self = reinterpret_cast<ValidationTest*>(static_cast<uintptr_t>(userdata));
+ self->mDeviceErrorMessage = message;
+
// Skip this one specific error that is raised when a builder is used after it got an error
// this is important because we don't want to wrap all creation tests in ASSERT_DEVICE_ERROR.
// Yes the error message is misleading.
- if (std::string(message) == "Builder cannot be used after GetResult") {
+ if (self->mDeviceErrorMessage == "Builder cannot be used after GetResult") {
return;
}
- auto self = reinterpret_cast<ValidationTest*>(static_cast<uintptr_t>(userdata));
ASSERT_TRUE(self->mExpectError) << "Got unexpected device error: " << message;
ASSERT_FALSE(self->mError) << "Got two errors in expect block";
self->mError = true;
diff --git a/src/tests/unittests/validation/ValidationTest.h b/src/tests/unittests/validation/ValidationTest.h
index 2189b10..c6bd908 100644
--- a/src/tests/unittests/validation/ValidationTest.h
+++ b/src/tests/unittests/validation/ValidationTest.h
@@ -47,6 +47,7 @@
void StartExpectDeviceError();
bool EndExpectDeviceError();
+ std::string GetLastDeviceErrorMessage() const;
dawn::RenderPassDescriptor CreateSimpleRenderPass();
@@ -66,6 +67,7 @@
private:
static void OnDeviceError(const char* message, dawnCallbackUserdata userdata);
+ std::string mDeviceErrorMessage;
bool mExpectError = false;
bool mError = false;
diff --git a/src/utils/DawnHelpers.cpp b/src/utils/DawnHelpers.cpp
index 01daa9b..0871f26 100644
--- a/src/utils/DawnHelpers.cpp
+++ b/src/utils/DawnHelpers.cpp
@@ -25,46 +25,53 @@
namespace utils {
+ namespace {
+
+ shaderc_shader_kind ShadercShaderKind(dawn::ShaderStage stage) {
+ switch (stage) {
+ case dawn::ShaderStage::Vertex:
+ return shaderc_glsl_vertex_shader;
+ case dawn::ShaderStage::Fragment:
+ return shaderc_glsl_fragment_shader;
+ case dawn::ShaderStage::Compute:
+ return shaderc_glsl_compute_shader;
+ default:
+ UNREACHABLE();
+ }
+ }
+
+ dawn::ShaderModule CreateShaderModuleFromResult(
+ const dawn::Device& device,
+ const shaderc::SpvCompilationResult& result) {
+ // result.cend and result.cbegin return pointers to uint32_t.
+ const uint32_t* resultBegin = result.cbegin();
+ const uint32_t* resultEnd = result.cend();
+ // So this size is in units of sizeof(uint32_t).
+ ptrdiff_t resultSize = resultEnd - resultBegin;
+ // SetSource takes data as uint32_t*.
+
+ dawn::ShaderModuleDescriptor descriptor;
+ descriptor.codeSize = static_cast<uint32_t>(resultSize);
+ descriptor.code = result.cbegin();
+ return device.CreateShaderModule(&descriptor);
+ }
+
+ } // anonymous namespace
+
dawn::ShaderModule CreateShaderModule(const dawn::Device& device,
dawn::ShaderStage stage,
const char* source) {
+ shaderc_shader_kind kind = ShadercShaderKind(stage);
+
shaderc::Compiler compiler;
- shaderc::CompileOptions options;
-
- shaderc_shader_kind kind;
- switch (stage) {
- case dawn::ShaderStage::Vertex:
- kind = shaderc_glsl_vertex_shader;
- break;
- case dawn::ShaderStage::Fragment:
- kind = shaderc_glsl_fragment_shader;
- break;
- case dawn::ShaderStage::Compute:
- kind = shaderc_glsl_compute_shader;
- break;
- default:
- UNREACHABLE();
- }
-
- auto result = compiler.CompileGlslToSpv(source, strlen(source), kind, "myshader?", options);
+ auto result = compiler.CompileGlslToSpv(source, strlen(source), kind, "myshader?");
if (result.GetCompilationStatus() != shaderc_compilation_status_success) {
std::cerr << result.GetErrorMessage();
return {};
}
-
- // result.cend and result.cbegin return pointers to uint32_t.
- const uint32_t* resultBegin = result.cbegin();
- const uint32_t* resultEnd = result.cend();
- // So this size is in units of sizeof(uint32_t).
- ptrdiff_t resultSize = resultEnd - resultBegin;
- // SetSource takes data as uint32_t*.
-
- dawn::ShaderModuleDescriptor descriptor;
- descriptor.codeSize = static_cast<uint32_t>(resultSize);
- descriptor.code = result.cbegin();
-
#ifdef DUMP_SPIRV_ASSEMBLY
{
+ shaderc::CompileOptions options;
auto resultAsm = compiler.CompileGlslToSpvAssembly(source, strlen(source), kind,
"myshader?", options);
size_t sizeAsm = (resultAsm.cend() - resultAsm.cbegin());
@@ -91,7 +98,18 @@
printf("SPIRV JS ARRAY DUMP END\n");
#endif
- return device.CreateShaderModule(&descriptor);
+ return CreateShaderModuleFromResult(device, result);
+ }
+
+ dawn::ShaderModule CreateShaderModuleFromASM(const dawn::Device& device, const char* source) {
+ shaderc::Compiler compiler;
+ shaderc::SpvCompilationResult result = compiler.AssembleToSpv(source, strlen(source));
+ if (result.GetCompilationStatus() != shaderc_compilation_status_success) {
+ std::cerr << result.GetErrorMessage();
+ return {};
+ }
+
+ return CreateShaderModuleFromResult(device, result);
}
dawn::Buffer CreateBufferFromData(const dawn::Device& device,
diff --git a/src/utils/DawnHelpers.h b/src/utils/DawnHelpers.h
index 88698ce..03a3e30 100644
--- a/src/utils/DawnHelpers.h
+++ b/src/utils/DawnHelpers.h
@@ -21,6 +21,8 @@
dawn::ShaderModule CreateShaderModule(const dawn::Device& device,
dawn::ShaderStage stage,
const char* source);
+ dawn::ShaderModule CreateShaderModuleFromASM(const dawn::Device& device, const char* source);
+
dawn::Buffer CreateBufferFromData(const dawn::Device& device,
const void* data,
uint32_t size,
diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt
index 4d70444..61e614c 100644
--- a/third_party/CMakeLists.txt
+++ b/third_party/CMakeLists.txt
@@ -41,6 +41,9 @@
target_include_directories(glad SYSTEM PUBLIC ${GLAD_INCLUDE_DIR})
DawnExternalTarget("third_party" glad)
+# SPIRV-Tools
+set(SPIRV_TOOLS_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/spirv-tools/include PARENT_SCOPE)
+
# ShaderC
# Prevent SPIRV-Tools from using Werror as it has a warning on MSVC
set(SPIRV_WERROR OFF CACHE BOOL "" FORCE)