Deprecate ShaderModuleDescriptor.code in favor of chained descriptor

This also adds the definition of the WGSL sub descriptor but forbids
using it for now.

Bug: dawn:22
Change-Id: I0514eec95bbcda28911547d6bda4d5257b62432b
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/19865
Commit-Queue: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Ryan Harrison <rharrison@chromium.org>
Reviewed-by: Austin Eng <enga@chromium.org>
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
diff --git a/dawn.json b/dawn.json
index 29b99c7..5f49e04 100644
--- a/dawn.json
+++ b/dawn.json
@@ -1273,10 +1273,25 @@
         "extensible": true,
         "members": [
             {"name": "label", "type": "char", "annotation": "const*", "length": "strlen", "optional": true},
+            {"name": "code size", "type": "uint32_t",  "default": 0},
+            {"name": "code", "type": "uint32_t", "annotation": "const*", "length": "code size", "optional": true}
+        ]
+    },
+    "shader module SPIRV descriptor": {
+        "category": "structure",
+        "chained": true,
+        "members": [
             {"name": "code size", "type": "uint32_t"},
             {"name": "code", "type": "uint32_t", "annotation": "const*", "length": "code size"}
         ]
     },
+    "shader module WGSL descriptor": {
+        "category": "structure",
+        "chained": true,
+        "members": [
+            {"name": "source", "type": "char", "annotation": "const*", "length": "strlen"}
+        ]
+    },
     "shader stage": {
         "category": "bitmask",
         "values": [
@@ -1390,8 +1405,10 @@
             {"value": 2, "name": "surface descriptor from windows HWND"},
             {"value": 3, "name": "surface descriptor from xlib"},
             {"value": 4, "name": "surface descriptor from HTML canvas id"},
-            {"value": 5, "name": "sampler descriptor dummy anisotropic filtering"},
-            {"value": 6, "name": "render pipeline descriptor dummy extension"}
+            {"value": 5, "name": "shader module SPIRV descriptor"},
+            {"value": 6, "name": "shader module WGSL descriptor"},
+            {"value": 7, "name": "sampler descriptor dummy anisotropic filtering"},
+            {"value": 8, "name": "render pipeline descriptor dummy extension"}
         ]
     },
     "texture": {
diff --git a/src/dawn_native/ShaderModule.cpp b/src/dawn_native/ShaderModule.cpp
index 713dd95..b72ce35 100644
--- a/src/dawn_native/ShaderModule.cpp
+++ b/src/dawn_native/ShaderModule.cpp
@@ -282,12 +282,7 @@
         }
     }  // anonymous namespace
 
-    MaybeError ValidateShaderModuleDescriptor(DeviceBase*,
-                                              const ShaderModuleDescriptor* descriptor) {
-        if (descriptor->nextInChain != nullptr) {
-            return DAWN_VALIDATION_ERROR("nextInChain must be nullptr");
-        }
-
+    MaybeError ValidateSpirv(DeviceBase*, const uint32_t* code, uint32_t codeSize) {
         spvtools::SpirvTools spirvTools(SPV_ENV_VULKAN_1_1);
 
         std::ostringstream errorStream;
@@ -314,17 +309,68 @@
             }
         });
 
-        if (!spirvTools.Validate(descriptor->code, descriptor->codeSize)) {
+        if (!spirvTools.Validate(code, codeSize)) {
             return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
         }
 
         return {};
+    }
+
+    MaybeError ValidateShaderModuleDescriptor(DeviceBase* device,
+                                              const ShaderModuleDescriptor* descriptor) {
+        if (descriptor->codeSize != 0) {
+            if (descriptor->nextInChain != nullptr) {
+                return DAWN_VALIDATION_ERROR("Cannot set both code/codeSize and nextInChain");
+            }
+
+            device->EmitDeprecationWarning(
+                "ShaderModuleDescriptor::code/codeSize is deprecated, chain "
+                "ShaderModuleSPIRVDescriptor instead.");
+            return ValidateSpirv(device, descriptor->code, descriptor->codeSize);
+        }
+
+        // For now only a single SPIRV or WGSL subdescriptor is allowed.
+        const ChainedStruct* chainedDescriptor = descriptor->nextInChain;
+        if (chainedDescriptor->nextInChain != nullptr) {
+            return DAWN_VALIDATION_ERROR("chained nextInChain must be nullptr");
+        }
+
+        switch (chainedDescriptor->sType) {
+            case wgpu::SType::ShaderModuleSPIRVDescriptor: {
+                const ShaderModuleSPIRVDescriptor* spirvDesc =
+                    static_cast<const ShaderModuleSPIRVDescriptor*>(chainedDescriptor);
+                DAWN_TRY(ValidateSpirv(device, spirvDesc->code, spirvDesc->codeSize));
+                break;
+            }
+
+            case wgpu::SType::ShaderModuleWGSLDescriptor: {
+                return DAWN_VALIDATION_ERROR("WGSL not supported (yet)");
+                break;
+            }
+
+            default:
+                return DAWN_VALIDATION_ERROR("Unsupported sType");
+        }
+
+        return {};
     }  // namespace
 
     // ShaderModuleBase
 
     ShaderModuleBase::ShaderModuleBase(DeviceBase* device, const ShaderModuleDescriptor* descriptor)
-        : CachedObject(device), mSpirv(descriptor->code, descriptor->code + descriptor->codeSize) {
+        : CachedObject(device) {
+        // Extract the correct SPIRV from the descriptor.
+        if (descriptor->codeSize != 0) {
+            mSpirv.assign(descriptor->code, descriptor->code + descriptor->codeSize);
+        } else {
+            ASSERT(descriptor->nextInChain != nullptr);
+            ASSERT(descriptor->nextInChain->sType == wgpu::SType::ShaderModuleSPIRVDescriptor);
+
+            const ShaderModuleSPIRVDescriptor* spirvDesc =
+                static_cast<const ShaderModuleSPIRVDescriptor*>(descriptor->nextInChain);
+            mSpirv.assign(spirvDesc->code, spirvDesc->code + spirvDesc->codeSize);
+        }
+
         mFragmentOutputFormatBaseTypes.fill(Format::Other);
         if (GetDevice()->IsToggleEnabled(Toggle::UseSpvcParser)) {
             mSpvcContext.SetUseSpvcParser(true);
diff --git a/src/tests/end2end/DeprecatedAPITests.cpp b/src/tests/end2end/DeprecatedAPITests.cpp
index 23fc1b9..17d5e14 100644
--- a/src/tests/end2end/DeprecatedAPITests.cpp
+++ b/src/tests/end2end/DeprecatedAPITests.cpp
@@ -19,6 +19,7 @@
 
 #include "tests/DawnTest.h"
 
+#include "utils/ComboRenderPipelineDescriptor.h"
 #include "utils/WGPUHelpers.h"
 
 class DeprecationTests : public DawnTest {
@@ -307,6 +308,70 @@
     EXPECT_DEPRECATION_WARNING(ASSERT_DEVICE_ERROR(device.CreateBindGroup(&bgDesc)));
 }
 
+// Tests for ShaderModuleDescriptor.code/codeSize -> ShaderModuleSPIRVDescriptor
+
+static const char kEmptyShader[] = R"(#version 450
+void main() {
+})";
+
+// That creating a ShaderModule without the chained descriptor gives a warning.
+TEST_P(DeprecationTests, ShaderModuleNoSubDescriptorIsDeprecated) {
+    std::vector<uint32_t> spirv =
+        CompileGLSLToSpirv(utils::SingleShaderStage::Compute, kEmptyShader);
+
+    wgpu::ShaderModuleDescriptor descriptor = {
+        .codeSize = static_cast<uint32_t>(spirv.size()),
+        .code = spirv.data(),
+    };
+    EXPECT_DEPRECATION_WARNING(device.CreateShaderModule(&descriptor));
+}
+
+// That creating a ShaderModule with both inline code and the chained descriptor is an error.
+TEST_P(DeprecationTests, ShaderModuleBothInlinedAndChainedIsInvalid) {
+    std::vector<uint32_t> spirv =
+        CompileGLSLToSpirv(utils::SingleShaderStage::Compute, kEmptyShader);
+
+    wgpu::ShaderModuleSPIRVDescriptor spirvDesc;
+    spirvDesc.codeSize = static_cast<uint32_t>(spirv.size());
+    spirvDesc.code = spirv.data();
+
+    wgpu::ShaderModuleDescriptor descriptor = {
+        .nextInChain = &spirvDesc,
+        .codeSize = static_cast<uint32_t>(spirv.size()),
+        .code = spirv.data(),
+    };
+    ASSERT_DEVICE_ERROR(device.CreateShaderModule(&descriptor));
+}
+
+// That creating a ShaderModule with both inline code still does correct state tracking
+TEST_P(DeprecationTests, ShaderModuleInlinedCodeStateTracking) {
+    std::vector<uint32_t> spirv =
+        CompileGLSLToSpirv(utils::SingleShaderStage::Compute, kEmptyShader);
+
+    wgpu::ShaderModuleDescriptor descriptor = {
+        .codeSize = static_cast<uint32_t>(spirv.size()),
+        .code = spirv.data(),
+    };
+    wgpu::ShaderModule module;
+    EXPECT_DEPRECATION_WARNING(module = device.CreateShaderModule(&descriptor));
+
+    // Creating a compute pipeline works, because it is a compute module.
+    wgpu::ComputePipelineDescriptor computePipelineDesc = {
+        .computeStage =
+            {
+                .module = module,
+                .entryPoint = "main",
+            },
+    };
+    device.CreateComputePipeline(&computePipelineDesc);
+
+    utils::ComboRenderPipelineDescriptor renderPipelineDesc(device);
+    renderPipelineDesc.vertexStage.module =
+        utils::CreateShaderModule(device, utils::SingleShaderStage::Vertex, kEmptyShader);
+    renderPipelineDesc.cFragmentStage.module = module;
+    ASSERT_DEVICE_ERROR(device.CreateRenderPipeline(&renderPipelineDesc));
+}
+
 DAWN_INSTANTIATE_TEST(DeprecationTests,
                       D3D12Backend(),
                       MetalBackend(),
diff --git a/src/tests/unittests/wire/WireArgumentTests.cpp b/src/tests/unittests/wire/WireArgumentTests.cpp
index fe2b141..23d895d 100644
--- a/src/tests/unittests/wire/WireArgumentTests.cpp
+++ b/src/tests/unittests/wire/WireArgumentTests.cpp
@@ -96,7 +96,6 @@
 TEST_F(WireArgumentTests, CStringArgument) {
     // Create shader module
     WGPUShaderModuleDescriptor vertexDescriptor = {};
-    vertexDescriptor.codeSize = 0;
     WGPUShaderModule vsModule = wgpuDeviceCreateShaderModule(device, &vertexDescriptor);
     WGPUShaderModule apiVsModule = api.GetNewShaderModule();
     EXPECT_CALL(api, DeviceCreateShaderModule(apiDevice, _)).WillOnce(Return(apiVsModule));
diff --git a/src/tests/unittests/wire/WireOptionalTests.cpp b/src/tests/unittests/wire/WireOptionalTests.cpp
index ca80994..503d91d 100644
--- a/src/tests/unittests/wire/WireOptionalTests.cpp
+++ b/src/tests/unittests/wire/WireOptionalTests.cpp
@@ -66,7 +66,6 @@
 TEST_F(WireOptionalTests, OptionalStructPointer) {
     // Create shader module
     WGPUShaderModuleDescriptor vertexDescriptor = {};
-    vertexDescriptor.codeSize = 0;
     WGPUShaderModule vsModule = wgpuDeviceCreateShaderModule(device, &vertexDescriptor);
     WGPUShaderModule apiVsModule = api.GetNewShaderModule();
     EXPECT_CALL(api, DeviceCreateShaderModule(apiDevice, _)).WillOnce(Return(apiVsModule));
diff --git a/src/utils/WGPUHelpers.cpp b/src/utils/WGPUHelpers.cpp
index 9eb3912..adc270b 100644
--- a/src/utils/WGPUHelpers.cpp
+++ b/src/utils/WGPUHelpers.cpp
@@ -51,9 +51,13 @@
             ptrdiff_t resultSize = resultEnd - resultBegin;
             // SetSource takes data as uint32_t*.
 
+            wgpu::ShaderModuleSPIRVDescriptor spirvDesc;
+            spirvDesc.codeSize = static_cast<uint32_t>(resultSize);
+            spirvDesc.code = result.cbegin();
+
             wgpu::ShaderModuleDescriptor descriptor;
-            descriptor.codeSize = static_cast<uint32_t>(resultSize);
-            descriptor.code = result.cbegin();
+            descriptor.nextInChain = &spirvDesc;
+
             return device.CreateShaderModule(&descriptor);
         }
 
@@ -113,6 +117,18 @@
         return CreateShaderModuleFromResult(device, result);
     }
 
+    std::vector<uint32_t> CompileGLSLToSpirv(SingleShaderStage stage, const char* source) {
+        shaderc_shader_kind kind = ShadercShaderKind(stage);
+
+        shaderc::Compiler compiler;
+        auto result = compiler.CompileGlslToSpv(source, strlen(source), kind, "myshader?");
+        if (result.GetCompilationStatus() != shaderc_compilation_status_success) {
+            dawn::ErrorLog() << result.GetErrorMessage();
+            return {};
+        }
+        return {result.cbegin(), result.cend()};
+    }
+
     wgpu::Buffer CreateBufferFromData(const wgpu::Device& device,
                                       const void* data,
                                       uint64_t size,
diff --git a/src/utils/WGPUHelpers.h b/src/utils/WGPUHelpers.h
index 8e2f963..eafc45c 100644
--- a/src/utils/WGPUHelpers.h
+++ b/src/utils/WGPUHelpers.h
@@ -32,6 +32,7 @@
                                           SingleShaderStage stage,
                                           const char* source);
     wgpu::ShaderModule CreateShaderModuleFromASM(const wgpu::Device& device, const char* source);
+    std::vector<uint32_t> CompileGLSLToSpirv(SingleShaderStage stage, const char* source);
 
     wgpu::Buffer CreateBufferFromData(const wgpu::Device& device,
                                       const void* data,