Add helper functions to iterate over ChainedStructs

This CL adds two helpers for more ergonomic processing of
ChainedStructs.

1. FindInChain(): Iterates through the chain and automatically
   casts the ChainedStruct into the appropriate child type before
   returning.
2. ValidateSTypes(): Verifies that the chain only contains structs
   with sTypes from a pre-defined set. This also allows the caller
   to specify one-of constraints.
3. ValidateSingleSType(): Verifies that the chain contains a
   single struct with a specific sType or is an empty chain. This
   is a common case of |ValidateSTypes()| and is separated out as
   a fast-path.

Change-Id: I938df0bf2a9b1800b1105fb7f80fbde20bef8ec8
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/47680
Commit-Queue: Brian Ho <hob@chromium.org>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
diff --git a/docs/codegen.md b/docs/codegen.md
index 9cdf40b..d79c0e9 100644
--- a/docs/codegen.md
+++ b/docs/codegen.md
@@ -18,6 +18,7 @@
  - validation helper functions for dawn_native
  - the definition of dawn_native's proc table
  - dawn_native's internal version of the webgpu.h types
+ - utilities for working with dawn_native's chained structs
  - a lot of dawn_wire parts, see below
 
 Internally `dawn.json` is a dictionary from the "canonical name" of things to their definition. The "canonical name" is a space-separated (mostly) lower-case version of the name that's parsed into a `Name` Python object. Then that name can be turned into various casings with `.CamelCase()` `.SNAKE_CASE()`, etc. When `dawn.json` things reference each other, it is always via these "canonical names".
diff --git a/generator/dawn_json_generator.py b/generator/dawn_json_generator.py
index 84a2b99..858be52 100644
--- a/generator/dawn_json_generator.py
+++ b/generator/dawn_json_generator.py
@@ -765,6 +765,14 @@
             renders.append(
                 FileRender('dawn_native/ProcTable.cpp',
                            'src/dawn_native/ProcTable.cpp', frontend_params))
+            renders.append(
+                FileRender('dawn_native/ChainUtils.h',
+                           'src/dawn_native/ChainUtils_autogen.h',
+                           frontend_params))
+            renders.append(
+                FileRender('dawn_native/ChainUtils.cpp',
+                           'src/dawn_native/ChainUtils_autogen.cpp',
+                           frontend_params))
 
         if 'dawn_wire' in targets:
             additional_params = compute_wire_params(api_params, wire_json)
diff --git a/generator/templates/dawn_native/ChainUtils.cpp b/generator/templates/dawn_native/ChainUtils.cpp
new file mode 100644
index 0000000..2a42db2
--- /dev/null
+++ b/generator/templates/dawn_native/ChainUtils.cpp
@@ -0,0 +1,61 @@
+// Copyright 2021 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "dawn_native/ChainUtils_autogen.h"
+
+#include <unordered_set>
+
+namespace dawn_native {
+
+{% for value in types["s type"].values %}
+    {% if value.valid %}
+        void FindInChain(const ChainedStruct* chain, const {{as_cppEnum(value.name)}}** out) {
+            for (; chain; chain = chain->nextInChain) {
+                if (chain->sType == wgpu::SType::{{as_cppEnum(value.name)}}) {
+                    *out = static_cast<const {{as_cppEnum(value.name)}}*>(chain);
+                    break;
+                }
+            }
+        }
+    {% endif %}
+{% endfor %}
+
+MaybeError ValidateSTypes(const ChainedStruct* chain,
+                          std::vector<std::vector<wgpu::SType>> oneOfConstraints) {
+    std::unordered_set<wgpu::SType> allSTypes;
+    for (; chain; chain = chain->nextInChain) {
+        if (allSTypes.find(chain->sType) != allSTypes.end()) {
+            return DAWN_VALIDATION_ERROR("Chain cannot have duplicate sTypes");
+        }
+        allSTypes.insert(chain->sType);
+    }
+    for (const auto& oneOfConstraint : oneOfConstraints) {
+        bool satisfied = false;
+        for (wgpu::SType oneOfSType : oneOfConstraint) {
+            if (allSTypes.find(oneOfSType) != allSTypes.end()) {
+                if (satisfied) {
+                    return DAWN_VALIDATION_ERROR("Unsupported sType combination");
+                }
+                satisfied = true;
+                allSTypes.erase(oneOfSType);
+            }
+        }
+    }
+    if (!allSTypes.empty()) {
+        return DAWN_VALIDATION_ERROR("Unsupported sType");
+    }
+    return {};
+}
+
+}  // namespace dawn_native
diff --git a/generator/templates/dawn_native/ChainUtils.h b/generator/templates/dawn_native/ChainUtils.h
new file mode 100644
index 0000000..ce46591
--- /dev/null
+++ b/generator/templates/dawn_native/ChainUtils.h
@@ -0,0 +1,81 @@
+// Copyright 2021 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef DAWNNATIVE_CHAIN_UTILS_H_
+#define DAWNNATIVE_CHAIN_UTILS_H_
+
+#include "dawn_native/dawn_platform.h"
+#include "dawn_native/Error.h"
+
+namespace dawn_native {
+    {% for value in types["s type"].values %}
+        {% if value.valid %}
+            void FindInChain(const ChainedStruct* chain, const {{as_cppEnum(value.name)}}** out);
+        {% endif %}
+    {% endfor %}
+
+    // Verifies that |chain| only contains ChainedStructs of types enumerated in
+    // |oneOfConstraints| and contains no duplicate sTypes. Each vector in
+    // |oneOfConstraints| defines a set of sTypes that cannot coexist in the same chain.
+    // For example:
+    //   ValidateSTypes(chain, { { ShaderModuleSPIRVDescriptor, ShaderModuleWGSLDescriptor } }))
+    //   ValidateSTypes(chain, { { Extension1 }, { Extension2 } })
+    MaybeError ValidateSTypes(const ChainedStruct* chain,
+                              std::vector<std::vector<wgpu::SType>> oneOfConstraints);
+
+    template <typename T>
+    MaybeError ValidateSingleSTypeInner(const ChainedStruct* chain, T sType) {
+        if (chain->sType != sType) {
+            return DAWN_VALIDATION_ERROR("Unsupported sType");
+        }
+        return {};
+    }
+
+    template <typename T, typename... Args>
+    MaybeError ValidateSingleSTypeInner(const ChainedStruct* chain, T sType, Args... sTypes) {
+        if (chain->sType == sType) {
+            return {};
+        }
+        return ValidateSingleSTypeInner(chain, sTypes...);
+    }
+
+    // Verifies that |chain| contains a single ChainedStruct of type |sType| or no ChainedStructs
+    // at all.
+    template <typename T>
+    MaybeError ValidateSingleSType(const ChainedStruct* chain, T sType) {
+        if (chain == nullptr) {
+            return {};
+        }
+        if (chain->nextInChain != nullptr) {
+            return DAWN_VALIDATION_ERROR("Chain can only contain a single chained struct");
+        }
+        return ValidateSingleSTypeInner(chain, sType);
+    }
+
+    // Verifies that |chain| contains a single ChainedStruct with a type enumerated in the
+    // parameter pack or no ChainedStructs at all.
+    template <typename T, typename... Args>
+    MaybeError ValidateSingleSType(const ChainedStruct* chain, T sType, Args... sTypes) {
+        if (chain == nullptr) {
+            return {};
+        }
+        if (chain->nextInChain != nullptr) {
+            return DAWN_VALIDATION_ERROR("Chain can only contain a single chained struct");
+        }
+        return ValidateSingleSTypeInner(chain, sType, sTypes...);
+    }
+
+}  // namespace dawn_native
+
+#endif  // DAWNNATIVE_CHAIN_UTILS_H_
diff --git a/src/dawn_native/BUILD.gn b/src/dawn_native/BUILD.gn
index 803ee30..2a5a59b 100644
--- a/src/dawn_native/BUILD.gn
+++ b/src/dawn_native/BUILD.gn
@@ -91,6 +91,8 @@
 dawn_json_generator("dawn_native_utils_gen") {
   target = "dawn_native_utils"
   outputs = [
+    "src/dawn_native/ChainUtils_autogen.h",
+    "src/dawn_native/ChainUtils_autogen.cpp",
     "src/dawn_native/ProcTable.cpp",
     "src/dawn_native/wgpu_structs_autogen.h",
     "src/dawn_native/wgpu_structs_autogen.cpp",
diff --git a/src/dawn_native/RenderPipeline.cpp b/src/dawn_native/RenderPipeline.cpp
index c09bdfb..c32034d 100644
--- a/src/dawn_native/RenderPipeline.cpp
+++ b/src/dawn_native/RenderPipeline.cpp
@@ -16,6 +16,7 @@
 
 #include "common/BitSetIterator.h"
 #include "common/VertexFormatUtils.h"
+#include "dawn_native/ChainUtils_autogen.h"
 #include "dawn_native/Commands.h"
 #include "dawn_native/Device.h"
 #include "dawn_native/ObjectContentHasher.h"
@@ -133,16 +134,13 @@
 
         MaybeError ValidatePrimitiveState(const DeviceBase* device,
                                           const PrimitiveState* descriptor) {
-            const ChainedStruct* chained = descriptor->nextInChain;
-            if (chained != nullptr) {
-                if (chained->sType != wgpu::SType::PrimitiveDepthClampingState) {
-                    return DAWN_VALIDATION_ERROR("Unsupported sType");
-                }
-                if (!device->IsExtensionEnabled(Extension::DepthClamping)) {
-                    return DAWN_VALIDATION_ERROR("The depth clamping feature is not supported");
-                }
+            DAWN_TRY(ValidateSingleSType(descriptor->nextInChain,
+                wgpu::SType::PrimitiveDepthClampingState));
+            const PrimitiveDepthClampingState* clampInfo = nullptr;
+            FindInChain(descriptor->nextInChain, &clampInfo);
+            if (clampInfo && !device->IsExtensionEnabled(Extension::DepthClamping)) {
+                return DAWN_VALIDATION_ERROR("The depth clamping feature is not supported");
             }
-
             DAWN_TRY(ValidatePrimitiveTopology(descriptor->topology));
             DAWN_TRY(ValidateIndexFormat(descriptor->stripIndexFormat));
             DAWN_TRY(ValidateFrontFace(descriptor->frontFace));
@@ -426,11 +424,10 @@
         }
 
         mPrimitive = descriptor->primitive;
-        const ChainedStruct* chained = mPrimitive.nextInChain;
-        if (chained != nullptr) {
-            ASSERT(chained->sType == wgpu::SType::PrimitiveDepthClampingState);
-            const auto* clampState = static_cast<const PrimitiveDepthClampingState*>(chained);
-            mClampDepth = clampState->clampDepth;
+        const PrimitiveDepthClampingState* clampInfo = nullptr;
+        FindInChain(mPrimitive.nextInChain, &clampInfo);
+        if (clampInfo) {
+            mClampDepth = clampInfo->clampDepth;
         }
         mMultisample = descriptor->multisample;
 
diff --git a/src/dawn_native/ShaderModule.cpp b/src/dawn_native/ShaderModule.cpp
index a9cceac..5e5762f 100644
--- a/src/dawn_native/ShaderModule.cpp
+++ b/src/dawn_native/ShaderModule.cpp
@@ -17,6 +17,7 @@
 #include "common/HashUtils.h"
 #include "common/VertexFormatUtils.h"
 #include "dawn_native/BindGroupLayout.h"
+#include "dawn_native/ChainUtils_autogen.h"
 #include "dawn_native/CompilationMessages.h"
 #include "dawn_native/Device.h"
 #include "dawn_native/ObjectContentHasher.h"
@@ -1069,65 +1070,56 @@
             return DAWN_VALIDATION_ERROR("Shader module descriptor missing chained descriptor");
         }
         // For now only a single SPIRV or WGSL subdescriptor is allowed.
-        if (chainedDescriptor->nextInChain != nullptr) {
-            return DAWN_VALIDATION_ERROR(
-                "Shader module descriptor chained nextInChain must be nullptr");
-        }
+        DAWN_TRY(ValidateSingleSType(chainedDescriptor,
+            wgpu::SType::ShaderModuleSPIRVDescriptor,
+            wgpu::SType::ShaderModuleWGSLDescriptor));
 
         OwnedCompilationMessages* outMessages = parseResult->compilationMessages.get();
 
         ScopedTintICEHandler scopedICEHandler(device);
 
-        switch (chainedDescriptor->sType) {
-            case wgpu::SType::ShaderModuleSPIRVDescriptor: {
-                const auto* spirvDesc =
-                    static_cast<const ShaderModuleSPIRVDescriptor*>(chainedDescriptor);
-                std::vector<uint32_t> spirv(spirvDesc->code, spirvDesc->code + spirvDesc->codeSize);
-                if (device->IsToggleEnabled(Toggle::UseTintGenerator)) {
-                    tint::Program program;
-                    DAWN_TRY_ASSIGN(program, ParseSPIRV(spirv, outMessages));
-                    parseResult->tintProgram = std::make_unique<tint::Program>(std::move(program));
-                } else {
-                    if (device->IsValidationEnabled()) {
-                        DAWN_TRY(ValidateSpirv(spirv.data(), spirv.size()));
-                    }
-                    parseResult->spirv = std::move(spirv);
-                }
-                break;
-            }
+        const ShaderModuleSPIRVDescriptor* spirvDesc = nullptr;
+        FindInChain(chainedDescriptor, &spirvDesc);
+        const ShaderModuleWGSLDescriptor* wgslDesc = nullptr;
+        FindInChain(chainedDescriptor, &wgslDesc);
 
-            case wgpu::SType::ShaderModuleWGSLDescriptor: {
-                const auto* wgslDesc =
-                    static_cast<const ShaderModuleWGSLDescriptor*>(chainedDescriptor);
-
-                auto tintSource = std::make_unique<TintSource>("", wgslDesc->source);
-
+        if (spirvDesc) {
+            std::vector<uint32_t> spirv(spirvDesc->code, spirvDesc->code + spirvDesc->codeSize);
+            if (device->IsToggleEnabled(Toggle::UseTintGenerator)) {
                 tint::Program program;
-                DAWN_TRY_ASSIGN(program, ParseWGSL(&tintSource->file, outMessages));
-
-                if (device->IsToggleEnabled(Toggle::UseTintGenerator)) {
-                    parseResult->tintProgram = std::make_unique<tint::Program>(std::move(program));
-                    parseResult->tintSource = std::move(tintSource);
-                } else {
-                    tint::transform::Manager transformManager;
-                    transformManager.Add<tint::transform::EmitVertexPointSize>();
-                    transformManager.Add<tint::transform::Spirv>();
-
-                    tint::transform::DataMap transformInputs;
-
-                    DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, &program,
-                                                           transformInputs, nullptr, outMessages));
-
-                    std::vector<uint32_t> spirv;
-                    DAWN_TRY_ASSIGN(spirv, ModuleToSPIRV(&program));
+                DAWN_TRY_ASSIGN(program, ParseSPIRV(spirv, outMessages));
+                parseResult->tintProgram = std::make_unique<tint::Program>(std::move(program));
+            } else {
+                if (device->IsValidationEnabled()) {
                     DAWN_TRY(ValidateSpirv(spirv.data(), spirv.size()));
-
-                    parseResult->spirv = std::move(spirv);
                 }
-                break;
+                parseResult->spirv = std::move(spirv);
             }
-            default:
-                return DAWN_VALIDATION_ERROR("Unsupported sType");
+        } else if (wgslDesc) {
+            auto tintSource = std::make_unique<TintSource>("", wgslDesc->source);
+
+            tint::Program program;
+            DAWN_TRY_ASSIGN(program, ParseWGSL(&tintSource->file, outMessages));
+
+            if (device->IsToggleEnabled(Toggle::UseTintGenerator)) {
+                parseResult->tintProgram = std::make_unique<tint::Program>(std::move(program));
+                parseResult->tintSource = std::move(tintSource);
+            } else {
+                tint::transform::Manager transformManager;
+                transformManager.Add<tint::transform::EmitVertexPointSize>();
+                transformManager.Add<tint::transform::Spirv>();
+
+                tint::transform::DataMap transformInputs;
+
+                DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, &program,
+                                                       transformInputs, nullptr, outMessages));
+
+                std::vector<uint32_t> spirv;
+                DAWN_TRY_ASSIGN(spirv, ModuleToSPIRV(&program));
+                DAWN_TRY(ValidateSpirv(spirv.data(), spirv.size()));
+
+                parseResult->spirv = std::move(spirv);
+            }
         }
 
         return {};
@@ -1216,23 +1208,18 @@
     ShaderModuleBase::ShaderModuleBase(DeviceBase* device, const ShaderModuleDescriptor* descriptor)
         : CachedObject(device), mType(Type::Undefined) {
         ASSERT(descriptor->nextInChain != nullptr);
-        switch (descriptor->nextInChain->sType) {
-            case wgpu::SType::ShaderModuleSPIRVDescriptor: {
-                mType = Type::Spirv;
-                const auto* spirvDesc =
-                    static_cast<const ShaderModuleSPIRVDescriptor*>(descriptor->nextInChain);
-                mOriginalSpirv.assign(spirvDesc->code, spirvDesc->code + spirvDesc->codeSize);
-                break;
-            }
-            case wgpu::SType::ShaderModuleWGSLDescriptor: {
-                mType = Type::Wgsl;
-                const auto* wgslDesc =
-                    static_cast<const ShaderModuleWGSLDescriptor*>(descriptor->nextInChain);
-                mWgsl = std::string(wgslDesc->source);
-                break;
-            }
-            default:
-                UNREACHABLE();
+        const ShaderModuleSPIRVDescriptor* spirvDesc = nullptr;
+        FindInChain(descriptor->nextInChain, &spirvDesc);
+        const ShaderModuleWGSLDescriptor* wgslDesc = nullptr;
+        FindInChain(descriptor->nextInChain, &wgslDesc);
+        ASSERT(spirvDesc || wgslDesc);
+
+        if (spirvDesc) {
+            mType = Type::Spirv;
+            mOriginalSpirv.assign(spirvDesc->code, spirvDesc->code + spirvDesc->codeSize);
+        } else if (wgslDesc) {
+            mType = Type::Wgsl;
+            mWgsl = std::string(wgslDesc->source);
         }
     }
 
diff --git a/src/dawn_native/Surface.cpp b/src/dawn_native/Surface.cpp
index 4afe05e..9b317bc 100644
--- a/src/dawn_native/Surface.cpp
+++ b/src/dawn_native/Surface.cpp
@@ -15,6 +15,7 @@
 #include "dawn_native/Surface.h"
 
 #include "common/Platform.h"
+#include "dawn_native/ChainUtils_autogen.h"
 #include "dawn_native/Instance.h"
 #include "dawn_native/SwapChain.h"
 
@@ -34,75 +35,60 @@
 
     MaybeError ValidateSurfaceDescriptor(const InstanceBase* instance,
                                          const SurfaceDescriptor* descriptor) {
-        // TODO(cwallez@chromium.org): Have some type of helper to iterate over all the chained
-        // structures.
         if (descriptor->nextInChain == nullptr) {
             return DAWN_VALIDATION_ERROR("Surface cannot be created with just the base descriptor");
         }
 
-        const ChainedStruct* chainedDescriptor = descriptor->nextInChain;
-        if (chainedDescriptor->nextInChain != nullptr) {
-            return DAWN_VALIDATION_ERROR("Cannot specify two windows for a single surface");
-        }
+        DAWN_TRY(ValidateSingleSType(descriptor->nextInChain,
+            wgpu::SType::SurfaceDescriptorFromMetalLayer,
+            wgpu::SType::SurfaceDescriptorFromWindowsHWND,
+            wgpu::SType::SurfaceDescriptorFromXlib));
 
-        switch (chainedDescriptor->sType) {
 #if defined(DAWN_ENABLE_BACKEND_METAL)
-            case wgpu::SType::SurfaceDescriptorFromMetalLayer: {
-                const SurfaceDescriptorFromMetalLayer* metalDesc =
-                    static_cast<const SurfaceDescriptorFromMetalLayer*>(chainedDescriptor);
-
-                // Check that the layer is a CAMetalLayer (or a derived class).
-                if (!InheritsFromCAMetalLayer(metalDesc->layer)) {
-                    return DAWN_VALIDATION_ERROR("layer must be a CAMetalLayer");
-                }
-                break;
-            }
+        const SurfaceDescriptorFromMetalLayer* metalDesc = nullptr;
+        FindInChain(descriptor->nextInChain, &metalDesc);
+        if (!metalDesc) {
+            return DAWN_VALIDATION_ERROR("Unsupported sType");
+        }
+        // Check that the layer is a CAMetalLayer (or a derived class).
+        if (!InheritsFromCAMetalLayer(metalDesc->layer)) {
+            return DAWN_VALIDATION_ERROR("layer must be a CAMetalLayer");
+        }
 #endif  // defined(DAWN_ENABLE_BACKEND_METAL)
 
 #if defined(DAWN_PLATFORM_WINDOWS)
-            case wgpu::SType::SurfaceDescriptorFromWindowsHWND: {
-                const SurfaceDescriptorFromWindowsHWND* hwndDesc =
-                    static_cast<const SurfaceDescriptorFromWindowsHWND*>(chainedDescriptor);
-
-                // It is not possible to validate an HINSTANCE.
-
-                // Validate the hwnd using the windows.h IsWindow function.
-                if (IsWindow(static_cast<HWND>(hwndDesc->hwnd)) == 0) {
-                    return DAWN_VALIDATION_ERROR("Invalid HWND");
-                }
-                break;
-            }
+        const SurfaceDescriptorFromWindowsHWND* hwndDesc = nullptr;
+        FindInChain(descriptor->nextInChain, &hwndDesc);
+        if (!hwndDesc) {
+            return DAWN_VALIDATION_ERROR("Unsupported sType");
+        }
+        // Validate the hwnd using the windows.h IsWindow function.
+        if (IsWindow(static_cast<HWND>(hwndDesc->hwnd)) == 0) {
+            return DAWN_VALIDATION_ERROR("Invalid HWND");
+        }
 #endif  // defined(DAWN_PLATFORM_WINDOWS)
 
 #if defined(DAWN_USE_X11)
-            case wgpu::SType::SurfaceDescriptorFromXlib: {
-                const SurfaceDescriptorFromXlib* xDesc =
-                    static_cast<const SurfaceDescriptorFromXlib*>(chainedDescriptor);
-
-                // It is not possible to validate an X Display.
-
-                // Check the validity of the window by calling a getter function on the window that
-                // returns a status code. If the window is bad the call return a status of zero. We
-                // need to set a temporary X11 error handler while doing this because the default
-                // X11 error handler exits the program on any error.
-                XErrorHandler oldErrorHandler =
-                    XSetErrorHandler([](Display*, XErrorEvent*) { return 0; });
-                XWindowAttributes attributes;
-                int status = XGetWindowAttributes(reinterpret_cast<Display*>(xDesc->display),
-                                                  xDesc->window, &attributes);
-                XSetErrorHandler(oldErrorHandler);
-
-                if (status == 0) {
-                    return DAWN_VALIDATION_ERROR("Invalid X Window");
-                }
-                break;
-            }
-#endif  // defined(DAWN_USE_X11)
-
-            case wgpu::SType::SurfaceDescriptorFromCanvasHTMLSelector:
-            default:
-                return DAWN_VALIDATION_ERROR("Unsupported sType");
+        const SurfaceDescriptorFromXlib* xDesc = nullptr;
+        FindInChain(descriptor->nextInChain, &xDesc);
+        if (!xDesc) {
+            return DAWN_VALIDATION_ERROR("Unsupported sType");
         }
+        // Check the validity of the window by calling a getter function on the window that
+        // returns a status code. If the window is bad the call return a status of zero. We
+        // need to set a temporary X11 error handler while doing this because the default
+        // X11 error handler exits the program on any error.
+        XErrorHandler oldErrorHandler =
+            XSetErrorHandler([](Display*, XErrorEvent*) { return 0; });
+        XWindowAttributes attributes;
+        int status = XGetWindowAttributes(reinterpret_cast<Display*>(xDesc->display),
+                                          xDesc->window, &attributes);
+        XSetErrorHandler(oldErrorHandler);
+
+        if (status == 0) {
+            return DAWN_VALIDATION_ERROR("Invalid X Window");
+        }
+#endif  // defined(DAWN_USE_X11)
 
         return {};
     }
@@ -110,37 +96,24 @@
     Surface::Surface(InstanceBase* instance, const SurfaceDescriptor* descriptor)
         : mInstance(instance) {
         ASSERT(descriptor->nextInChain != nullptr);
-        const ChainedStruct* chainedDescriptor = descriptor->nextInChain;
-
-        switch (chainedDescriptor->sType) {
-            case wgpu::SType::SurfaceDescriptorFromMetalLayer: {
-                const SurfaceDescriptorFromMetalLayer* metalDesc =
-                    static_cast<const SurfaceDescriptorFromMetalLayer*>(chainedDescriptor);
-                mType = Type::MetalLayer;
-                mMetalLayer = metalDesc->layer;
-                break;
-            }
-
-            case wgpu::SType::SurfaceDescriptorFromWindowsHWND: {
-                const SurfaceDescriptorFromWindowsHWND* hwndDesc =
-                    static_cast<const SurfaceDescriptorFromWindowsHWND*>(chainedDescriptor);
-                mType = Type::WindowsHWND;
-                mHInstance = hwndDesc->hinstance;
-                mHWND = hwndDesc->hwnd;
-                break;
-            }
-
-            case wgpu::SType::SurfaceDescriptorFromXlib: {
-                const SurfaceDescriptorFromXlib* xDesc =
-                    static_cast<const SurfaceDescriptorFromXlib*>(chainedDescriptor);
-                mType = Type::Xlib;
-                mXDisplay = xDesc->display;
-                mXWindow = xDesc->window;
-                break;
-            }
-
-            default:
-                UNREACHABLE();
+        const SurfaceDescriptorFromMetalLayer* metalDesc = nullptr;
+        const SurfaceDescriptorFromWindowsHWND* hwndDesc = nullptr;
+        const SurfaceDescriptorFromXlib* xDesc = nullptr;
+        FindInChain(descriptor->nextInChain, &metalDesc);
+        FindInChain(descriptor->nextInChain, &hwndDesc);
+        FindInChain(descriptor->nextInChain, &xDesc);
+        ASSERT(metalDesc || hwndDesc || xDesc);
+        if (metalDesc) {
+            mType = Type::MetalLayer;
+            mMetalLayer = metalDesc->layer;
+        } else if (hwndDesc) {
+            mType = Type::WindowsHWND;
+            mHInstance = hwndDesc->hinstance;
+            mHWND = hwndDesc->hwnd;
+        } else if (xDesc) {
+            mType = Type::Xlib;
+            mXDisplay = xDesc->display;
+            mXWindow = xDesc->window;
         }
     }
 
diff --git a/src/tests/BUILD.gn b/src/tests/BUILD.gn
index faf4b22..284cc41c 100644
--- a/src/tests/BUILD.gn
+++ b/src/tests/BUILD.gn
@@ -156,6 +156,7 @@
     "unittests/BitSetIteratorTests.cpp",
     "unittests/BuddyAllocatorTests.cpp",
     "unittests/BuddyMemoryAllocatorTests.cpp",
+    "unittests/ChainUtilsTests.cpp",
     "unittests/CommandAllocatorTests.cpp",
     "unittests/EnumClassBitmasksTests.cpp",
     "unittests/EnumMaskIteratorTests.cpp",
diff --git a/src/tests/unittests/ChainUtilsTests.cpp b/src/tests/unittests/ChainUtilsTests.cpp
new file mode 100644
index 0000000..2d43729
--- /dev/null
+++ b/src/tests/unittests/ChainUtilsTests.cpp
@@ -0,0 +1,181 @@
+// Copyright 2021 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+
+#include "dawn_native/ChainUtils_autogen.h"
+#include "dawn_native/dawn_platform.h"
+
+// Checks that we cannot find any structs in an empty chain
+TEST(ChainUtilsTests, FindEmptyChain) {
+    const dawn_native::PrimitiveDepthClampingState* info = nullptr;
+    dawn_native::FindInChain(nullptr, &info);
+
+    ASSERT_EQ(nullptr, info);
+}
+
+// Checks that searching a chain for a present struct returns that struct
+TEST(ChainUtilsTests, FindPresentInChain) {
+    dawn_native::PrimitiveDepthClampingState chain1;
+    dawn_native::ShaderModuleSPIRVDescriptor chain2;
+    chain1.nextInChain = &chain2;
+    const dawn_native::PrimitiveDepthClampingState* info1 = nullptr;
+    const dawn_native::ShaderModuleSPIRVDescriptor* info2 = nullptr;
+    dawn_native::FindInChain(&chain1, &info1);
+    dawn_native::FindInChain(&chain1, &info2);
+
+    ASSERT_NE(nullptr, info1);
+    ASSERT_NE(nullptr, info2);
+}
+
+// Checks that searching a chain for a struct that doesn't exist returns a nullptr
+TEST(ChainUtilsTests, FindMissingInChain) {
+    dawn_native::PrimitiveDepthClampingState chain1;
+    dawn_native::ShaderModuleSPIRVDescriptor chain2;
+    chain1.nextInChain = &chain2;
+    const dawn_native::SurfaceDescriptorFromMetalLayer* info = nullptr;
+    dawn_native::FindInChain(&chain1, &info);
+
+    ASSERT_EQ(nullptr, info);
+}
+
+// Checks that validation rejects chains with duplicate STypes
+TEST(ChainUtilsTests, ValidateDuplicateSTypes) {
+    dawn_native::PrimitiveDepthClampingState chain1;
+    dawn_native::ShaderModuleSPIRVDescriptor chain2;
+    dawn_native::PrimitiveDepthClampingState chain3;
+    chain1.nextInChain = &chain2;
+    chain2.nextInChain = &chain3;
+
+    dawn_native::MaybeError result = dawn_native::ValidateSTypes(&chain1, {});
+    ASSERT_TRUE(result.IsError());
+    result.AcquireError();
+}
+
+// Checks that validation rejects chains that contain unspecified STypes
+TEST(ChainUtilsTests, ValidateUnspecifiedSTypes) {
+    dawn_native::PrimitiveDepthClampingState chain1;
+    dawn_native::ShaderModuleSPIRVDescriptor chain2;
+    dawn_native::ShaderModuleWGSLDescriptor chain3;
+    chain1.nextInChain = &chain2;
+    chain2.nextInChain = &chain3;
+
+    dawn_native::MaybeError result = dawn_native::ValidateSTypes(&chain1, {
+        {wgpu::SType::PrimitiveDepthClampingState},
+        {wgpu::SType::ShaderModuleSPIRVDescriptor},
+    });
+    ASSERT_TRUE(result.IsError());
+    result.AcquireError();
+}
+
+// Checks that validation rejects chains that contain multiple STypes from the same oneof
+// constraint.
+TEST(ChainUtilsTests, ValidateOneOfFailure) {
+    dawn_native::PrimitiveDepthClampingState chain1;
+    dawn_native::ShaderModuleSPIRVDescriptor chain2;
+    dawn_native::ShaderModuleWGSLDescriptor chain3;
+    chain1.nextInChain = &chain2;
+    chain2.nextInChain = &chain3;
+
+    dawn_native::MaybeError result = dawn_native::ValidateSTypes(&chain1,
+        {{wgpu::SType::ShaderModuleSPIRVDescriptor, wgpu::SType::ShaderModuleWGSLDescriptor}});
+    ASSERT_TRUE(result.IsError());
+    result.AcquireError();
+}
+
+// Checks that validation accepts chains that match the constraints.
+TEST(ChainUtilsTests, ValidateSuccess) {
+    dawn_native::PrimitiveDepthClampingState chain1;
+    dawn_native::ShaderModuleSPIRVDescriptor chain2;
+    chain1.nextInChain = &chain2;
+
+    dawn_native::MaybeError result = dawn_native::ValidateSTypes(&chain1, {
+        {wgpu::SType::ShaderModuleSPIRVDescriptor, wgpu::SType::ShaderModuleWGSLDescriptor},
+        {wgpu::SType::PrimitiveDepthClampingState},
+        {wgpu::SType::SurfaceDescriptorFromMetalLayer},
+    });
+    ASSERT_TRUE(result.IsSuccess());
+}
+
+// Checks that validation always passes on empty chains.
+TEST(ChainUtilsTests, ValidateEmptyChain) {
+    dawn_native::MaybeError result = dawn_native::ValidateSTypes(nullptr, {
+        {wgpu::SType::ShaderModuleSPIRVDescriptor},
+        {wgpu::SType::PrimitiveDepthClampingState},
+    });
+    ASSERT_TRUE(result.IsSuccess());
+
+    result = dawn_native::ValidateSTypes(nullptr, {});
+    ASSERT_TRUE(result.IsSuccess());
+}
+
+// Checks that singleton validation always passes on empty chains.
+TEST(ChainUtilsTests, ValidateSingleEmptyChain) {
+    dawn_native::MaybeError result = dawn_native::ValidateSingleSType(nullptr,
+        wgpu::SType::ShaderModuleSPIRVDescriptor);
+    ASSERT_TRUE(result.IsSuccess());
+
+    result = dawn_native::ValidateSingleSType(nullptr,
+        wgpu::SType::ShaderModuleSPIRVDescriptor, wgpu::SType::PrimitiveDepthClampingState);
+    ASSERT_TRUE(result.IsSuccess());
+}
+
+// Checks that singleton validation always fails on chains with multiple children.
+TEST(ChainUtilsTests, ValidateSingleMultiChain) {
+    dawn_native::PrimitiveDepthClampingState chain1;
+    dawn_native::ShaderModuleSPIRVDescriptor chain2;
+    chain1.nextInChain = &chain2;
+
+    dawn_native::MaybeError result = dawn_native::ValidateSingleSType(&chain1,
+        wgpu::SType::PrimitiveDepthClampingState);
+    ASSERT_TRUE(result.IsError());
+    result.AcquireError();
+
+    result = dawn_native::ValidateSingleSType(&chain1,
+        wgpu::SType::PrimitiveDepthClampingState, wgpu::SType::ShaderModuleSPIRVDescriptor);
+    ASSERT_TRUE(result.IsError());
+    result.AcquireError();
+}
+
+// Checks that singleton validation passes when the oneof constraint is met.
+TEST(ChainUtilsTests, ValidateSingleSatisfied) {
+    dawn_native::ShaderModuleWGSLDescriptor chain1;
+
+    dawn_native::MaybeError result = dawn_native::ValidateSingleSType(&chain1,
+        wgpu::SType::ShaderModuleWGSLDescriptor);
+    ASSERT_TRUE(result.IsSuccess());
+
+    result = dawn_native::ValidateSingleSType(&chain1,
+        wgpu::SType::ShaderModuleSPIRVDescriptor, wgpu::SType::ShaderModuleWGSLDescriptor);
+    ASSERT_TRUE(result.IsSuccess());
+
+    result = dawn_native::ValidateSingleSType(&chain1,
+        wgpu::SType::ShaderModuleWGSLDescriptor, wgpu::SType::ShaderModuleSPIRVDescriptor);
+    ASSERT_TRUE(result.IsSuccess());
+}
+
+// Checks that singleton validation passes when the oneof constraint is not met.
+TEST(ChainUtilsTests, ValidateSingleUnsatisfied) {
+    dawn_native::PrimitiveDepthClampingState chain1;
+
+    dawn_native::MaybeError result = dawn_native::ValidateSingleSType(&chain1,
+        wgpu::SType::ShaderModuleWGSLDescriptor);
+    ASSERT_TRUE(result.IsError());
+    result.AcquireError();
+
+    result = dawn_native::ValidateSingleSType(&chain1,
+        wgpu::SType::ShaderModuleSPIRVDescriptor, wgpu::SType::ShaderModuleWGSLDescriptor);
+    ASSERT_TRUE(result.IsError());
+    result.AcquireError();
+}