[dawn][wire] Adds Dawn wire invalid extension struct and handling.

- Fixes serialization side by early-ing out when we detect
  invalid chain structure.
- Fixes deserialization by only deserializing allowed types.
- With these fixes, the stack overflows should no longer be
  possible during serialization nor deserialization, so we can
  add back extensibility for PipelineLayoutStorageAttachment.

Bug: 383824627
Change-Id: I8d4858b2ab747a53a5c6721f5d15160cc36bf3d0
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/226175
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
Commit-Queue: Loko Kung <lokokung@google.com>
diff --git a/generator/dawn_json_generator.py b/generator/dawn_json_generator.py
index 502412c..f1cbdb8 100644
--- a/generator/dawn_json_generator.py
+++ b/generator/dawn_json_generator.py
@@ -523,6 +523,9 @@
             if member.type.category == 'structure':
                 max_dependent_depth = max(max_dependent_depth,
                                           compute_depth(member.type) + 1)
+        for extension in struct.extensions:
+            max_dependent_depth = max(max_dependent_depth,
+                                      compute_depth(extension) + 1)
 
         struct.subdag_depth = max_dependent_depth
         struct.visited = True
diff --git a/generator/templates/dawn/native/ChainUtils.cpp b/generator/templates/dawn/native/ChainUtils.cpp
index 2a33d5b..1ad7858 100644
--- a/generator/templates/dawn/native/ChainUtils.cpp
+++ b/generator/templates/dawn/native/ChainUtils.cpp
@@ -152,10 +152,20 @@
                                           result.mBitset,
                                           next,
                                           &duplicate)) {
-                        return DAWN_VALIDATION_ERROR(
-                            "Unexpected chained struct of type %s found on %s chain.",
-                            next->sType, "{{T}}"
-                        );
+                        if (next->sType == wgpu::SType::DawnInjectedInvalidSType) {
+                            // TODO(crbug.com/399470698): Need to reinterpret cast to base C type
+                            // for now because in/out typing are differentiated in C++ bindings.
+                            auto* ext = reinterpret_cast<const WGPUDawnInjectedInvalidSType*>(next);
+                            return DAWN_VALIDATION_ERROR(
+                                "Unexpected chained struct of type %s found on %s chain.",
+                                wgpu::SType(ext->invalidSType), "{{T}}"
+                            );
+                        } else {
+                            return DAWN_VALIDATION_ERROR(
+                                "Unexpected chained struct of type %s found on %s chain.",
+                                next->sType, "{{T}}"
+                            );
+                        }
                     }
                     break;
                 }
diff --git a/generator/templates/dawn/wire/WireCmd.cpp b/generator/templates/dawn/wire/WireCmd.cpp
index dbaf7b6..16b5d90 100644
--- a/generator/templates/dawn/wire/WireCmd.cpp
+++ b/generator/templates/dawn/wire/WireCmd.cpp
@@ -174,8 +174,24 @@
 
         //* Gather how much space will be needed for the extension chain.
         {% if record.extensible %}
-            if (record.nextInChain != nullptr) {
-                result += GetChainedStructExtraRequiredSize(record.nextInChain);
+            const WGPUChainedStruct* next = record.nextInChain;
+            while (next != nullptr) {
+                switch (next->sType) {
+                    {% for extension in record.extensions if extension.name.CamelCase() not in client_side_structures %}
+                        {% set CType = as_cType(extension.name) %}
+                        case {{as_cEnum(types["s type"].name, extension.name)}}: {
+                            const auto& typedStruct = *reinterpret_cast<{{CType}} const *>(next);
+                            result += WireAlignSizeof<{{CType}}Transfer>();
+                            result += {{CType}}GetExtraRequiredSize(typedStruct);
+                            break;
+                        }
+                    {% endfor %}
+                    default: {
+                        result += WireAlignSizeof<WGPUDawnInjectedInvalidSTypeTransfer>();
+                        break;
+                    }
+                }
+                next = next->next;
             }
         {% endif %}
         //* Gather space needed for pointer members.
@@ -230,11 +246,36 @@
         {% endif %}
 
         {% if record.extensible %}
-            if (record.nextInChain != nullptr) {
+            const WGPUChainedStruct* next = record.nextInChain;
+            transfer->hasNextInChain = false;
+            while (next != nullptr) {
                 transfer->hasNextInChain = true;
-                WIRE_TRY(SerializeChainedStruct(record.nextInChain, buffer, provider));
-            } else {
-                transfer->hasNextInChain = false;
+                switch (next->sType) {
+                    {% for extension in record.extensions if extension.name.CamelCase() not in client_side_structures %}
+                        {% set CType = as_cType(extension.name) %}
+                        case {{as_cEnum(types["s type"].name, extension.name)}}: {
+                            {{CType}}Transfer* chainTransfer;
+                            WIRE_TRY(buffer->Next(&chainTransfer));
+                            chainTransfer->chain.sType = next->sType;
+                            chainTransfer->chain.hasNext = next->next != nullptr;
+
+                            WIRE_TRY({{CType}}Serialize(*reinterpret_cast<{{CType}} const*>(next), chainTransfer, buffer, provider));
+                            break;
+                        }
+                    {% endfor %}
+                    default: {
+                        // Invalid enum. Serialize just the invalid sType for validation purposes.
+                        dawn::WarningLog() << "Unknown sType " << next->sType << " discarded.";
+
+                        WGPUDawnInjectedInvalidSTypeTransfer* chainTransfer;
+                        WIRE_TRY(buffer->Next(&chainTransfer));
+                        chainTransfer->chain.sType = WGPUSType_DawnInjectedInvalidSType;
+                        chainTransfer->chain.hasNext = next->next != nullptr;
+                        chainTransfer->invalidSType = next->sType;
+                        break;
+                    }
+                }
+                next = next->next;
             }
         {% endif %}
         {% if record.chained %}
@@ -305,8 +346,7 @@
         [[maybe_unused]] DeserializeAllocator* allocator
         {%- if record.may_have_dawn_object -%}
             , const ObjectIdResolver& resolver
-        {%- endif -%}
-    ) {
+        {%- endif -%}) {
         {% if is_cmd %}
             DAWN_ASSERT(transfer->commandId == {{Return}}WireCmd::{{name}});
         {% endif %}
@@ -315,10 +355,43 @@
         {% endif %}
 
         {% if record.extensible %}
-            record->nextInChain = nullptr;
-            if (transfer->hasNextInChain) {
-                WIRE_TRY(DeserializeChainedStruct(&record->nextInChain, deserializeBuffer, allocator, resolver));
+            WGPUChainedStruct** outChainNext = &record->nextInChain;
+            bool hasNext = transfer->hasNextInChain;
+            while (hasNext) {
+                const volatile WGPUChainedStructTransfer* header;
+                WIRE_TRY(deserializeBuffer->Peek(&header));
+                WGPUSType sType = header->sType;
+                hasNext = header->hasNext;
+
+                switch (sType) {
+                    //* All extensible types need to be able to handle deserializing the invalid
+                    //* sType struct.
+                    {% set extensions = record.extensions + [types['dawn injected invalid s type']] %}
+                    {% for extension in extensions if extension.name.CamelCase() not in client_side_structures %}
+                        {% set CType = as_cType(extension.name) %}
+                        case {{as_cEnum(types["s type"].name, extension.name)}}: {
+                            const volatile {{CType}}Transfer* chainTransfer;
+                            WIRE_TRY(deserializeBuffer->Read(&chainTransfer));
+
+                            {{CType}}* typedOutStruct;
+                            WIRE_TRY(GetSpace(allocator, 1u, &typedOutStruct));
+                            typedOutStruct->chain.sType = sType;
+                            typedOutStruct->chain.next = nullptr;
+                            WIRE_TRY({{CType}}Deserialize(typedOutStruct, chainTransfer,
+                                                          deserializeBuffer, allocator, resolver));
+                            *outChainNext = &typedOutStruct->chain;
+                            outChainNext = &typedOutStruct->chain.next;
+                            break;
+                        }
+                    {% endfor %}
+                    default: {
+                        //* For invalid sTypes, it's a fatal error since this implies a compromised
+                        //* or corrupt client.
+                        return WireResult::FatalError;
+                    }
+                }
             }
+            *outChainNext = nullptr;
         {% endif %}
         {% if record.chained %}
             //* Should be set by the root descriptor's call to DeserializeChainedStruct.
@@ -505,14 +578,8 @@
     bool hasNext;
 };
 
-size_t GetChainedStructExtraRequiredSize(WGPUChainedStruct* chainedStruct);
-[[nodiscard]] WireResult SerializeChainedStruct(WGPUChainedStruct* chainedStruct,
-                                                SerializeBuffer* buffer,
-                                                const ObjectIdProvider& provider);
-WireResult DeserializeChainedStruct(WGPUChainedStruct** outChainNext,
-                                    DeserializeBuffer* deserializeBuffer,
-                                    DeserializeAllocator* allocator,
-                                    const ObjectIdResolver& resolver);
+//* Structs that need special handling for [de]serialization code generation.
+{% set SpecialSerializeStructs = ["string view", "dawn injected invalid s type"] %}
 
 // Manually define serialization and deserialization for WGPUStringView because
 // it has a special encoding where:
@@ -599,10 +666,15 @@
     return WireResult::Success;
 }
 
+//* Force generation of de[serialization] methods for WGPUDawnInjectedInvalidSType early.
+{% set type = types["dawn injected invalid s type"] %}
+{%- set name = as_cType(type.name) -%}
+{{write_record_serialization_helpers(type, name, type.members, is_cmd=False)}}
+
 //* Output structure [de]serialization first because it is used by commands.
 {% for type in by_category["structure"] %}
     {%- set name = as_cType(type.name) -%}
-    {% if type.name.CamelCase() not in client_side_structures and name != "WGPUStringView" -%}
+    {% if type.name.CamelCase() not in client_side_structures and type.name.get() not in SpecialSerializeStructs -%}
         {{write_record_serialization_helpers(type, name, type.members, is_cmd=False)}}
     {% endif %}
 {% endfor %}
@@ -618,138 +690,6 @@
     {% do sTypes.append(sType) %}
 {% endfor %}
 
-size_t GetChainedStructExtraRequiredSize(WGPUChainedStruct* chainedStruct) {
-    DAWN_ASSERT(chainedStruct != nullptr);
-    size_t result = 0;
-    while (chainedStruct != nullptr) {
-        uint32_t sType_as_uint;
-        std::memcpy(&sType_as_uint, &(chainedStruct->sType), sizeof(uint32_t));
-        switch (sType_as_uint) {
-            {% for sType in sTypes %}
-                case {{as_cEnum(types["s type"].name, sType.name)}}: {
-                    const auto& typedStruct = *reinterpret_cast<{{as_cType(sType.name)}} const *>(chainedStruct);
-                    result += WireAlignSizeof<{{as_cType(sType.name)}}Transfer>();
-                    result += {{as_cType(sType.name)}}GetExtraRequiredSize(typedStruct);
-                    chainedStruct = typedStruct.chain.next;
-                    break;
-                }
-            {% endfor %}
-            default:
-                // Invalid enum. Reserve space just for the transfer header (sType and hasNext).
-                result += WireAlignSizeof<WGPUChainedStructTransfer>();
-                chainedStruct = chainedStruct->next;
-                break;
-        }
-    }
-    return result;
-}
-
-[[nodiscard]] WireResult SerializeChainedStruct(WGPUChainedStruct* chainedStruct,
-                                                SerializeBuffer* buffer,
-                                                const ObjectIdProvider& provider) {
-    DAWN_ASSERT(chainedStruct != nullptr);
-    DAWN_ASSERT(buffer != nullptr);
-    do {
-        uint32_t sType_as_uint;
-        std::memcpy(&sType_as_uint, &(chainedStruct->sType), sizeof(uint32_t));
-        switch (sType_as_uint) {
-            {% for sType in sTypes %}
-                {% set CType = as_cType(sType.name) %}
-                case {{as_cEnum(types["s type"].name, sType.name)}}: {
-                    {{CType}}Transfer* transfer;
-                    WIRE_TRY(buffer->Next(&transfer));
-                    transfer->chain.sType = chainedStruct->sType;
-                    transfer->chain.hasNext = chainedStruct->next != nullptr;
-
-                    WIRE_TRY({{CType}}Serialize(*reinterpret_cast<{{CType}} const*>(chainedStruct), transfer, buffer
-                        {%- if types[sType.name.get()].may_have_dawn_object -%}
-                        , provider
-                        {%- endif -%}
-                    ));
-
-                    chainedStruct = chainedStruct->next;
-                } break;
-            {% endfor %}
-            default: {
-                // Invalid enum. Serialize just the transfer header with Invalid as the sType.
-                // TODO(crbug.com/dawn/369): Unknown sTypes are silently discarded.
-                if (sType_as_uint != 0u) {
-                    dawn::WarningLog() << "Unknown sType " << sType_as_uint << " discarded.";
-                }
-
-                WGPUChainedStructTransfer* transfer;
-                WIRE_TRY(buffer->Next(&transfer));
-                transfer->sType = WGPUSType(0);
-                transfer->hasNext = chainedStruct->next != nullptr;
-
-                // Still move on in case there are valid structs after this.
-                chainedStruct = chainedStruct->next;
-                break;
-            }
-        }
-    } while (chainedStruct != nullptr);
-    return WireResult::Success;
-}
-
-WireResult DeserializeChainedStruct(WGPUChainedStruct** outChainNext,
-                                    DeserializeBuffer* deserializeBuffer,
-                                    DeserializeAllocator* allocator,
-                                    const ObjectIdResolver& resolver) {
-    bool hasNext;
-    do {
-        const volatile WGPUChainedStructTransfer* header;
-        WIRE_TRY(deserializeBuffer->Peek(&header));
-        WGPUSType sType = header->sType;
-        switch (sType) {
-            {% for sType in sTypes %}
-                {% set CType = as_cType(sType.name) %}
-                case {{as_cEnum(types["s type"].name, sType.name)}}: {
-                    const volatile {{CType}}Transfer* transfer;
-                    WIRE_TRY(deserializeBuffer->Read(&transfer));
-
-                    {{CType}}* outStruct;
-                    WIRE_TRY(GetSpace(allocator, 1u, &outStruct));
-                    outStruct->chain.sType = sType;
-                    outStruct->chain.next = nullptr;
-
-                    *outChainNext = &outStruct->chain;
-                    outChainNext = &outStruct->chain.next;
-
-                    WIRE_TRY({{CType}}Deserialize(outStruct, transfer, deserializeBuffer, allocator
-                        {%- if types[sType.name.get()].may_have_dawn_object -%}
-                            , resolver
-                        {%- endif -%}
-                    ));
-
-                    hasNext = transfer->chain.hasNext;
-                } break;
-            {% endfor %}
-            default: {
-                // Invalid enum. Deserialize just the transfer header with Invalid as the sType.
-                // TODO(crbug.com/dawn/369): Unknown sTypes are silently discarded.
-                if (sType != WGPUSType(0)) {
-                    dawn::WarningLog() << "Unknown sType " << sType << " discarded.";
-                }
-
-                const volatile WGPUChainedStructTransfer* transfer;
-                WIRE_TRY(deserializeBuffer->Read(&transfer));
-
-                WGPUChainedStruct* outStruct;
-                WIRE_TRY(GetSpace(allocator, 1u, &outStruct));
-                outStruct->sType = WGPUSType(0);
-                outStruct->next = nullptr;
-
-                // Still move on in case there are valid structs after this.
-                *outChainNext = outStruct;
-                outChainNext = &outStruct->next;
-                hasNext = transfer->hasNext;
-                break;
-            }
-        }
-    } while (hasNext);
-    return WireResult::Success;
-}
-
 //* Output [de]serialization helpers for commands
 {% for command in cmd_records["command"] %}
     {%- set name = command.name.CamelCase() -%}
diff --git a/src/dawn/dawn.json b/src/dawn/dawn.json
index c618176..01b809b 100644
--- a/src/dawn/dawn.json
+++ b/src/dawn/dawn.json
@@ -2508,6 +2508,15 @@
             {"name": "enable testing", "type": "bool", "default": "false"}
         ]
     },
+    "dawn injected invalid s type": {
+        "category": "structure",
+        "tags": ["dawn"],
+        "chained": "in",
+        "chain roots": [],
+        "members": [
+            {"name": "invalid s type", "type": "s type"}
+        ]
+    },
     "vertex attribute": {
         "category": "structure",
         "extensible": "in",
@@ -2629,6 +2638,7 @@
     "pipeline layout storage attachment": {
         "category": "structure",
         "tags": ["dawn"],
+        "extensible": "in",
         "members": [
             {"name": "offset", "type": "uint64_t", "default": 0},
             {"name": "format", "type": "texture format"}
@@ -3806,7 +3816,8 @@
             {"value": 59, "name": "dawn texel copy buffer row alignment limits", "tags": ["dawn"]},
             {"value": 60, "name": "adapter properties subgroup matrix configs", "tags": ["dawn"]},
             {"value": 61, "name": "shared fence EGL sync descriptor", "tags": ["dawn", "native"]},
-            {"value": 62, "name": "shared fence EGL sync export info", "tags": ["dawn", "native"]}
+            {"value": 62, "name": "shared fence EGL sync export info", "tags": ["dawn", "native"]},
+            {"value": 63, "name": "dawn injected invalid s type", "tags": ["dawn"]}
         ]
     },
     "texture": {
diff --git a/src/dawn/tests/unittests/ChainUtilsTests.cpp b/src/dawn/tests/unittests/ChainUtilsTests.cpp
index 518e166..3cca776 100644
--- a/src/dawn/tests/unittests/ChainUtilsTests.cpp
+++ b/src/dawn/tests/unittests/ChainUtilsTests.cpp
@@ -86,6 +86,29 @@
     }
 }
 
+// Inject invalid chain extensions cause an error.
+TEST(ChainUtilsTests, ValidateAndUnpackInjected) {
+    {
+        // TextureViewDescriptor (as of when this test was written) does not have any valid chains
+        // in the JSON nor via additional extensions.
+        TextureViewDescriptor desc;
+        DawnInjectedInvalidSType chain;
+        chain.invalidSType = wgpu::SType::ShaderSourceWGSL;
+        desc.nextInChain = &chain;
+        EXPECT_THAT(ValidateAndUnpack(&desc).AcquireError()->GetFormattedMessage(),
+                    HasSubstr("ShaderSourceWGSL"));
+    }
+    {
+        // InstanceDescriptor has at least 1 valid chain extension.
+        InstanceDescriptor desc;
+        DawnInjectedInvalidSType chain;
+        chain.invalidSType = wgpu::SType::ShaderSourceWGSL;
+        desc.nextInChain = &chain;
+        EXPECT_THAT(ValidateAndUnpack(&desc).AcquireError()->GetFormattedMessage(),
+                    HasSubstr("ShaderSourceWGSL"));
+    }
+}
+
 // Nominal unpacking valid descriptors should return the expected descriptors in the unpacked type.
 TEST(ChainUtilsTests, ValidateAndUnpack) {
     // DawnTogglesDescriptor is a valid extension for InstanceDescriptor.
diff --git a/src/dawn/tests/unittests/wire/WireExtensionTests.cpp b/src/dawn/tests/unittests/wire/WireExtensionTests.cpp
index f5f2717..8ff8f0d 100644
--- a/src/dawn/tests/unittests/wire/WireExtensionTests.cpp
+++ b/src/dawn/tests/unittests/wire/WireExtensionTests.cpp
@@ -71,8 +71,8 @@
 TEST_F(WireExtensionTests, MultipleChainedStructs) {
     wgpu::ShaderModuleDescriptor shaderModuleDesc = {};
 
-    wgpu::ShaderSourceWGSL clientExt2 = {};
-    clientExt2.code = {"/* comment 2 */", WGPU_STRLEN};
+    wgpu::ShaderModuleCompilationOptions clientExt2 = {};
+    clientExt2.strictMath = true;
 
     wgpu::ShaderSourceWGSL clientExt1 = {};
     clientExt1.code = {"/* comment 1 */", WGPU_STRLEN};
@@ -91,11 +91,10 @@
                 EXPECT_EQ(0, memcmp(ext1->code.data, clientExt1.code.data, ext1->code.length));
                 EXPECT_EQ(ext1->code.length, strlen(clientExt1.code.data));
 
-                const auto* ext2 = reinterpret_cast<const WGPUShaderSourceWGSL*>(ext1->chain.next);
-                EXPECT_EQ(ext2->chain.sType, WGPUSType_ShaderSourceWGSL);
-                EXPECT_NE(ext2->code.length, WGPU_STRLEN) << "The wire should decay WGPU_STRLEN";
-                EXPECT_EQ(0, memcmp(ext2->code.data, clientExt2.code.data, ext2->code.length));
-                EXPECT_EQ(ext2->code.length, strlen(clientExt2.code.data));
+                const auto* ext2 =
+                    reinterpret_cast<const WGPUShaderModuleCompilationOptions*>(ext1->chain.next);
+                EXPECT_EQ(ext2->chain.sType, WGPUSType_ShaderModuleCompilationOptions);
+                EXPECT_NE(ext2->strictMath, 0u);
                 EXPECT_EQ(ext2->chain.next, nullptr);
 
                 return apiShaderModule;
@@ -111,12 +110,10 @@
     EXPECT_CALL(api, DeviceCreateShaderModule(apiDevice, _))
         .WillOnce(
             Invoke([&](Unused, const WGPUShaderModuleDescriptor* serverDesc) -> WGPUShaderModule {
-                const auto* ext2 =
-                    reinterpret_cast<const WGPUShaderSourceWGSL*>(serverDesc->nextInChain);
-                EXPECT_EQ(ext2->chain.sType, WGPUSType_ShaderSourceWGSL);
-                EXPECT_NE(ext2->code.length, WGPU_STRLEN) << "The wire should decay WGPU_STRLEN";
-                EXPECT_EQ(0, memcmp(ext2->code.data, clientExt2.code.data, ext2->code.length));
-                EXPECT_EQ(ext2->code.length, strlen(clientExt2.code.data));
+                const auto* ext2 = reinterpret_cast<const WGPUShaderModuleCompilationOptions*>(
+                    serverDesc->nextInChain);
+                EXPECT_EQ(ext2->chain.sType, WGPUSType_ShaderModuleCompilationOptions);
+                EXPECT_NE(ext2->strictMath, 0u);
 
                 const auto* ext1 = reinterpret_cast<const WGPUShaderSourceWGSL*>(ext2->chain.next);
                 EXPECT_EQ(ext1->chain.sType, WGPUSType_ShaderSourceWGSL);
@@ -133,17 +130,20 @@
 // Test that a chained struct with Invalid sType passes through as Invalid.
 TEST_F(WireExtensionTests, InvalidSType) {
     wgpu::ShaderModuleDescriptor shaderModuleDesc = {};
-    wgpu::ShaderSourceWGSL clientExt = {};
+
+    wgpu::DawnWireWGSLControl clientExt = {};
     shaderModuleDesc.nextInChain = &clientExt;
-    clientExt.sType = wgpu::SType(0);
 
     WGPUShaderModule apiShaderModule = api.GetNewShaderModule();
     wgpu::ShaderModule shaderModule = device.CreateShaderModule(&shaderModuleDesc);
     EXPECT_CALL(api, DeviceCreateShaderModule(apiDevice, _))
         .WillOnce(
             Invoke([&](Unused, const WGPUShaderModuleDescriptor* serverDesc) -> WGPUShaderModule {
-                EXPECT_EQ(serverDesc->nextInChain->sType, WGPUSType(0));
-                EXPECT_EQ(serverDesc->nextInChain->next, nullptr);
+                const auto* ext =
+                    reinterpret_cast<const WGPUDawnInjectedInvalidSType*>(serverDesc->nextInChain);
+                EXPECT_EQ(ext->chain.sType, WGPUSType_DawnInjectedInvalidSType);
+                EXPECT_EQ(ext->chain.next, nullptr);
+                EXPECT_EQ(ext->invalidSType, WGPUSType_DawnWireWGSLControl);
 
                 return apiShaderModule;
             }));
@@ -153,17 +153,19 @@
 // Test that a chained struct with unknown sType passes through as Invalid.
 TEST_F(WireExtensionTests, UnknownSType) {
     wgpu::ShaderModuleDescriptor shaderModuleDesc = {};
-    wgpu::ShaderSourceWGSL clientExt = {};
+    wgpu::ChainedStruct clientExt = {};
     shaderModuleDesc.nextInChain = &clientExt;
-    clientExt.sType = static_cast<wgpu::SType>(-1);
 
     WGPUShaderModule apiShaderModule = api.GetNewShaderModule();
     wgpu::ShaderModule shaderModule = device.CreateShaderModule(&shaderModuleDesc);
     EXPECT_CALL(api, DeviceCreateShaderModule(apiDevice, _))
         .WillOnce(
             Invoke([&](Unused, const WGPUShaderModuleDescriptor* serverDesc) -> WGPUShaderModule {
-                EXPECT_EQ(serverDesc->nextInChain->sType, WGPUSType(0));
-                EXPECT_EQ(serverDesc->nextInChain->next, nullptr);
+                const auto* ext =
+                    reinterpret_cast<const WGPUDawnInjectedInvalidSType*>(serverDesc->nextInChain);
+                EXPECT_EQ(ext->chain.sType, WGPUSType_DawnInjectedInvalidSType);
+                EXPECT_EQ(ext->chain.next, nullptr);
+                EXPECT_EQ(ext->invalidSType, WGPUSType(0));
 
                 return apiShaderModule;
             }));
@@ -173,55 +175,56 @@
 // Test that if both an invalid and valid stype are passed on the chain, only the invalid
 // sType passes through as Invalid.
 TEST_F(WireExtensionTests, ValidAndInvalidSTypeInChain) {
-    WGPUShaderModuleDescriptor shaderModuleDesc = {};
+    wgpu::ShaderModuleDescriptor shaderModuleDesc = {};
 
-    WGPUShaderSourceWGSL clientExt2 = {};
-    clientExt2.chain.sType = WGPUSType(0);
-    clientExt2.chain.next = nullptr;
-
-    WGPUShaderSourceWGSL clientExt1 = {};
-    clientExt1.chain.sType = WGPUSType_ShaderSourceWGSL;
-    clientExt1.chain.next = &clientExt2.chain;
+    wgpu::DawnWireWGSLControl clientExt2 = {};
+    wgpu::ShaderSourceWGSL clientExt1 = {};
     clientExt1.code = {"/* comment 1 */", WGPU_STRLEN};
-    shaderModuleDesc.nextInChain = &clientExt1.chain;
+    clientExt1.nextInChain = &clientExt2;
+    shaderModuleDesc.nextInChain = &clientExt1;
 
     WGPUShaderModule apiShaderModule = api.GetNewShaderModule();
-    wgpu::ShaderModule shaderModule1 = wgpuDeviceCreateShaderModule(cDevice, &shaderModuleDesc);
+    wgpu::ShaderModule shaderModule1 = device.CreateShaderModule(&shaderModuleDesc);
     EXPECT_CALL(api, DeviceCreateShaderModule(apiDevice, _))
         .WillOnce(
             Invoke([&](Unused, const WGPUShaderModuleDescriptor* serverDesc) -> WGPUShaderModule {
-                const auto* ext =
+                const auto* ext1 =
                     reinterpret_cast<const WGPUShaderSourceWGSL*>(serverDesc->nextInChain);
-                EXPECT_EQ(ext->chain.sType, clientExt1.chain.sType);
-                EXPECT_NE(ext->code.length, WGPU_STRLEN) << "The wire should decay WGPU_STRLEN";
-                EXPECT_EQ(0, memcmp(ext->code.data, clientExt1.code.data, ext->code.length));
-                EXPECT_EQ(ext->code.length, strlen(clientExt1.code.data));
+                EXPECT_EQ(ext1->chain.sType, WGPUSType_ShaderSourceWGSL);
+                EXPECT_NE(ext1->code.length, WGPU_STRLEN) << "The wire should decay WGPU_STRLEN";
+                EXPECT_EQ(0, memcmp(ext1->code.data, clientExt1.code.data, ext1->code.length));
+                EXPECT_EQ(ext1->code.length, strlen(clientExt1.code.data));
 
-                EXPECT_EQ(ext->chain.next->sType, WGPUSType(0));
-                EXPECT_EQ(ext->chain.next->next, nullptr);
+                const auto* ext2 =
+                    reinterpret_cast<const WGPUDawnInjectedInvalidSType*>(ext1->chain.next);
+                EXPECT_EQ(ext2->chain.sType, WGPUSType_DawnInjectedInvalidSType);
+                EXPECT_EQ(ext2->chain.next, nullptr);
+                EXPECT_EQ(ext2->invalidSType, WGPUSType_DawnWireWGSLControl);
 
                 return apiShaderModule;
             }));
     FlushClient();
 
     // Swap the order of the chained structs.
-    shaderModuleDesc.nextInChain = &clientExt2.chain;
-    clientExt2.chain.next = &clientExt1.chain;
-    clientExt1.chain.next = nullptr;
+    shaderModuleDesc.nextInChain = &clientExt2;
+    clientExt2.nextInChain = &clientExt1;
+    clientExt1.nextInChain = nullptr;
 
-    wgpu::ShaderModule shaderModule2 = wgpuDeviceCreateShaderModule(cDevice, &shaderModuleDesc);
+    wgpu::ShaderModule shaderModule2 = device.CreateShaderModule(&shaderModuleDesc);
     EXPECT_CALL(api, DeviceCreateShaderModule(apiDevice, _))
         .WillOnce(
             Invoke([&](Unused, const WGPUShaderModuleDescriptor* serverDesc) -> WGPUShaderModule {
-                EXPECT_EQ(serverDesc->nextInChain->sType, WGPUSType(0));
+                const auto* ext2 =
+                    reinterpret_cast<const WGPUDawnInjectedInvalidSType*>(serverDesc->nextInChain);
+                EXPECT_EQ(ext2->chain.sType, WGPUSType_DawnInjectedInvalidSType);
+                EXPECT_EQ(ext2->invalidSType, WGPUSType_DawnWireWGSLControl);
 
-                const auto* ext =
-                    reinterpret_cast<const WGPUShaderSourceWGSL*>(serverDesc->nextInChain->next);
-                EXPECT_EQ(ext->chain.sType, clientExt1.chain.sType);
-                EXPECT_NE(ext->code.length, WGPU_STRLEN) << "The wire should decay WGPU_STRLEN";
-                EXPECT_EQ(0, memcmp(ext->code.data, clientExt1.code.data, ext->code.length));
-                EXPECT_EQ(ext->code.length, strlen(clientExt1.code.data));
-                EXPECT_EQ(ext->chain.next, nullptr);
+                const auto* ext1 = reinterpret_cast<const WGPUShaderSourceWGSL*>(ext2->chain.next);
+                EXPECT_EQ(ext1->chain.sType, WGPUSType_ShaderSourceWGSL);
+                EXPECT_EQ(ext1->chain.next, nullptr);
+                EXPECT_NE(ext1->code.length, WGPU_STRLEN) << "The wire should decay WGPU_STRLEN";
+                EXPECT_EQ(0, memcmp(ext1->code.data, clientExt1.code.data, ext1->code.length));
+                EXPECT_EQ(ext1->code.length, strlen(clientExt1.code.data));
 
                 return apiShaderModule;
             }));