Implement dawn:MatchVariant to visit std::variant
This patch implments the template dawn::MatchVariant as a convenient
way to call std::visit on a std::variant object.
Bug: dawn:527
Fixed: dawn:2370
Change-Id: I06204888d3e78308edb5142c7cad517d15fecb27
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/171740
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Austin Eng <enga@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Jiawei Shao <jiawei.shao@intel.com>
diff --git a/src/dawn/common/BUILD.gn b/src/dawn/common/BUILD.gn
index 48f594b..aa6eecd 100644
--- a/src/dawn/common/BUILD.gn
+++ b/src/dawn/common/BUILD.gn
@@ -260,6 +260,7 @@
"LinkedList.h",
"Log.cpp",
"Log.h",
+ "MatchVariant.h",
"Math.cpp",
"Math.h",
"Mutex.cpp",
diff --git a/src/dawn/common/CMakeLists.txt b/src/dawn/common/CMakeLists.txt
index 740d5e3..b630f6b 100644
--- a/src/dawn/common/CMakeLists.txt
+++ b/src/dawn/common/CMakeLists.txt
@@ -67,6 +67,7 @@
"LinkedList.h"
"Log.cpp"
"Log.h"
+ "MatchVariant.h"
"Math.cpp"
"Math.h"
"Mutex.cpp"
diff --git a/src/dawn/common/MatchVariant.h b/src/dawn/common/MatchVariant.h
new file mode 100644
index 0000000..e691893
--- /dev/null
+++ b/src/dawn/common/MatchVariant.h
@@ -0,0 +1,82 @@
+// Copyright 2024 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+// list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#ifndef SRC_DAWN_COMMON_MATCHVARIANT_H_
+#define SRC_DAWN_COMMON_MATCHVARIANT_H_
+
+#include <variant>
+
+namespace dawn {
+
+// This is the `Overloaded` template in chromium/src/base/functional/Overloaded.h.
+// std::visit() needs to be called with a functor object, such as
+//
+// struct Visitor {
+// std::string operator()(const PackageA& source) {
+// return "PackageA";
+// }
+//
+// std::string operator()(const PackageB& source) {
+// return "PackageB";
+// }
+// };
+//
+// std::variant<PackageA, PackageB> var = PackageA();
+// return std::visit(Visitor(), var);
+//
+// `Overloaded` enables the above code to be written as:
+//
+// std::visit(
+// Overloaded{
+// [](const PackageA& pack) { return "PackageA"; },
+// [](const PackageB& pack) { return "PackageB"; },
+// }, var);
+//
+// Note: Overloads must be implemented for all the variant options. Otherwise, there will be a
+// compilation error.
+//
+// This struct inherits operator() method from all its base classes. Introduces operator() method
+// from all its base classes into its definition.
+template <typename... Callables>
+struct Overloaded : Callables... {
+ using Callables::operator()...;
+};
+
+// Uses template argument deduction so that the `Overloaded` struct can be used without specifying
+// its template argument. This allows anonymous lambdas passed into the `Overloaded` constructor.
+template <typename... Callables>
+Overloaded(Callables...) -> Overloaded<Callables...>;
+
+// With this template we can simplify the call of std::visit(Overloaded{...}, variant).
+template <typename Variant, typename... Callables>
+auto MatchVariant(const Variant& v, Callables... args) {
+ return std::visit(Overloaded{args...}, v);
+}
+
+} // namespace dawn
+
+#endif
diff --git a/src/dawn/native/Format.cpp b/src/dawn/native/Format.cpp
index 756245a..6ce2515 100644
--- a/src/dawn/native/Format.cpp
+++ b/src/dawn/native/Format.cpp
@@ -30,6 +30,7 @@
#include <bitset>
#include <utility>
+#include "dawn/common/MatchVariant.h"
#include "dawn/common/TypedInteger.h"
#include "dawn/native/Device.h"
#include "dawn/native/EnumMaskIterator.h"
@@ -596,29 +597,16 @@
return table;
}
-namespace {
-
-template <class... Ts>
-struct overloaded : Ts... {
- using Ts::operator()...;
-};
-template <class... Ts>
-overloaded(Ts...) -> overloaded<Ts...>;
-
-} // anonymous namespace
-
absl::FormatConvertResult<absl::FormatConversionCharSet::kString> AbslFormatConvert(
const UnsupportedReason& value,
const absl::FormatConversionSpec& spec,
absl::FormatSink* s) {
- std::visit(
- overloaded{
- [](const std::monostate&) { DAWN_UNREACHABLE(); },
- [s](const RequiresFeature& requiresFeature) {
- s->Append(absl::StrFormat("requires feature %s", requiresFeature.feature));
- },
- [s](const CompatibilityMode&) { s->Append("not supported in compatibility mode"); }},
- value);
+ MatchVariant(
+ value, [](const std::monostate&) { DAWN_UNREACHABLE(); },
+ [s](const RequiresFeature& requiresFeature) {
+ s->Append(absl::StrFormat("requires feature %s", requiresFeature.feature));
+ },
+ [s](const CompatibilityMode&) { s->Append("not supported in compatibility mode"); });
return {true};
}
diff --git a/src/dawn/native/PipelineLayout.cpp b/src/dawn/native/PipelineLayout.cpp
index f2bf900..f17c259 100644
--- a/src/dawn/native/PipelineLayout.cpp
+++ b/src/dawn/native/PipelineLayout.cpp
@@ -34,6 +34,7 @@
#include "dawn/common/Assert.h"
#include "dawn/common/BitSetIterator.h"
#include "dawn/common/Enumerator.h"
+#include "dawn/common/MatchVariant.h"
#include "dawn/common/Numeric.h"
#include "dawn/common/Range.h"
#include "dawn/common/ityp_stack_vec.h"
@@ -233,58 +234,56 @@
-> BindGroupLayoutEntry {
BindGroupLayoutEntry entry = {};
- // TODO(dawn:2370): implement a helper in dawn/utils to simplify the call of std::visit.
- std::visit(
- [&](const auto& bindingInfo) {
- using T = std::decay_t<decltype(bindingInfo)>;
-
- if constexpr (std::is_same_v<T, BufferBindingInfo>) {
- entry.buffer.type = bindingInfo.type;
- entry.buffer.minBindingSize = bindingInfo.minBindingSize;
- } else if constexpr (std::is_same_v<T, SamplerBindingInfo>) {
- if (bindingInfo.isComparison) {
- entry.sampler.type = wgpu::SamplerBindingType::Comparison;
- } else {
- entry.sampler.type = wgpu::SamplerBindingType::Filtering;
- }
- } else if constexpr (std::is_same_v<T, SampledTextureBindingInfo>) {
- switch (bindingInfo.compatibleSampleTypes) {
- case SampleTypeBit::Depth:
- entry.texture.sampleType = wgpu::TextureSampleType::Depth;
- break;
- case SampleTypeBit::Sint:
- entry.texture.sampleType = wgpu::TextureSampleType::Sint;
- break;
- case SampleTypeBit::Uint:
- entry.texture.sampleType = wgpu::TextureSampleType::Uint;
- break;
- case SampleTypeBit::Float:
- case SampleTypeBit::UnfilterableFloat:
- case SampleTypeBit::None:
- DAWN_UNREACHABLE();
- break;
- default:
- if (bindingInfo.compatibleSampleTypes ==
- (SampleTypeBit::Float | SampleTypeBit::UnfilterableFloat)) {
- // Default to UnfilterableFloat. It will be promoted to Float if it
- // is used with a sampler.
- entry.texture.sampleType =
- wgpu::TextureSampleType::UnfilterableFloat;
- } else {
- DAWN_UNREACHABLE();
- }
- }
- entry.texture.viewDimension = bindingInfo.viewDimension;
- entry.texture.multisampled = bindingInfo.multisampled;
- } else if constexpr (std::is_same_v<T, StorageTextureBindingInfo>) {
- entry.storageTexture.access = bindingInfo.access;
- entry.storageTexture.format = bindingInfo.format;
- entry.storageTexture.viewDimension = bindingInfo.viewDimension;
- } else if constexpr (std::is_same_v<T, ExternalTextureBindingInfo>) {
- entry.nextInChain = externalTextureBindingEntry;
+ MatchVariant(
+ shaderBinding.bindingInfo,
+ [&](const BufferBindingInfo& bindingInfo) {
+ entry.buffer.type = bindingInfo.type;
+ entry.buffer.minBindingSize = bindingInfo.minBindingSize;
+ },
+ [&](const SamplerBindingInfo& bindingInfo) {
+ if (bindingInfo.isComparison) {
+ entry.sampler.type = wgpu::SamplerBindingType::Comparison;
+ } else {
+ entry.sampler.type = wgpu::SamplerBindingType::Filtering;
}
},
- shaderBinding.bindingInfo);
+ [&](const SampledTextureBindingInfo& bindingInfo) {
+ switch (bindingInfo.compatibleSampleTypes) {
+ case SampleTypeBit::Depth:
+ entry.texture.sampleType = wgpu::TextureSampleType::Depth;
+ break;
+ case SampleTypeBit::Sint:
+ entry.texture.sampleType = wgpu::TextureSampleType::Sint;
+ break;
+ case SampleTypeBit::Uint:
+ entry.texture.sampleType = wgpu::TextureSampleType::Uint;
+ break;
+ case SampleTypeBit::Float:
+ case SampleTypeBit::UnfilterableFloat:
+ case SampleTypeBit::None:
+ DAWN_UNREACHABLE();
+ break;
+ default:
+ if (bindingInfo.compatibleSampleTypes ==
+ (SampleTypeBit::Float | SampleTypeBit::UnfilterableFloat)) {
+ // Default to UnfilterableFloat. It will be promoted to Float
+ // if it is used with a sampler.
+ entry.texture.sampleType = wgpu::TextureSampleType::UnfilterableFloat;
+ } else {
+ DAWN_UNREACHABLE();
+ }
+ }
+ entry.texture.viewDimension = bindingInfo.viewDimension;
+ entry.texture.multisampled = bindingInfo.multisampled;
+ },
+ [&](const StorageTextureBindingInfo& bindingInfo) {
+ entry.storageTexture.access = bindingInfo.access;
+ entry.storageTexture.format = bindingInfo.format;
+ entry.storageTexture.viewDimension = bindingInfo.viewDimension;
+ },
+ [&](const ExternalTextureBindingInfo&) {
+ entry.nextInChain = externalTextureBindingEntry;
+ });
return entry;
};
diff --git a/src/dawn/native/ShaderModule.cpp b/src/dawn/native/ShaderModule.cpp
index 371d7f5..9cc8d3c 100644
--- a/src/dawn/native/ShaderModule.cpp
+++ b/src/dawn/native/ShaderModule.cpp
@@ -32,6 +32,7 @@
#include "dawn/common/BitSetIterator.h"
#include "dawn/common/Constants.h"
+#include "dawn/common/MatchVariant.h"
#include "dawn/native/BindGroupLayoutInternal.h"
#include "dawn/native/ChainUtils.h"
#include "dawn/native/CompilationMessages.h"
@@ -422,29 +423,12 @@
}
BindingInfoType GetShaderBindingType(const ShaderBindingInfo& shaderInfo) {
- return std::visit(
- [](const auto& bindingInfo) -> BindingInfoType {
- using T = std::decay_t<decltype(bindingInfo)>;
-
- if constexpr (std::is_same_v<T, BufferBindingInfo>) {
- return BindingInfoType::Buffer;
- }
- if constexpr (std::is_same_v<T, StorageTextureBindingInfo>) {
- return BindingInfoType::StorageTexture;
- }
- if constexpr (std::is_same_v<T, SampledTextureBindingInfo>) {
- return BindingInfoType::Texture;
- }
- if constexpr (std::is_same_v<T, SamplerBindingInfo>) {
- return BindingInfoType::Sampler;
- }
- if constexpr (std::is_same_v<T, ExternalTextureBindingInfo>) {
- return BindingInfoType::ExternalTexture;
- }
- DAWN_UNREACHABLE();
- return BindingInfoType::Buffer;
- },
- shaderInfo.bindingInfo);
+ return MatchVariant(
+ shaderInfo.bindingInfo, [&](const BufferBindingInfo&) { return BindingInfoType::Buffer; },
+ [&](const StorageTextureBindingInfo&) { return BindingInfoType::StorageTexture; },
+ [&](const SampledTextureBindingInfo&) { return BindingInfoType::Texture; },
+ [&](const SamplerBindingInfo&) { return BindingInfoType::Sampler; },
+ [&](const ExternalTextureBindingInfo&) { return BindingInfoType::ExternalTexture; });
}
MaybeError ValidateCompatibilityOfSingleBindingWithLayout(const DeviceBase* device,
@@ -494,97 +478,98 @@
"Entry point's stage (%s) is not in the binding visibility in the layout (%s).",
StageBit(entryPointStage), layoutInfo.visibility);
- // TODO(dawn:2370): implement a helper in dawn/utils to simplify the call of std::visit.
- return std::visit(
- [&](const auto& bindingInfo) -> MaybeError {
- using T = std::decay_t<decltype(bindingInfo)>;
+ return MatchVariant(
+ shaderInfo.bindingInfo,
+ [&](const SampledTextureBindingInfo& bindingInfo) -> MaybeError {
+ DAWN_INVALID_IF(
+ layoutInfo.texture.multisampled != bindingInfo.multisampled,
+ "Binding multisampled flag (%u) doesn't match the layout's multisampled "
+ "flag (%u)",
+ layoutInfo.texture.multisampled, bindingInfo.multisampled);
- if constexpr (std::is_same_v<T, SampledTextureBindingInfo>) {
- DAWN_INVALID_IF(
- layoutInfo.texture.multisampled != bindingInfo.multisampled,
- "Binding multisampled flag (%u) doesn't match the layout's multisampled "
- "flag (%u)",
- layoutInfo.texture.multisampled, bindingInfo.multisampled);
-
- // TODO(dawn:563): Provide info about the sample types.
- SampleTypeBit requiredType;
- if (layoutInfo.texture.sampleType == kInternalResolveAttachmentSampleType) {
- // If the layout's texture's sample type is
- // kInternalResolveAttachmentSampleType, then the shader's compatible sample
- // types must contain float.
- requiredType = SampleTypeBit::UnfilterableFloat;
- } else {
- requiredType = SampleTypeToSampleTypeBit(layoutInfo.texture.sampleType);
- }
-
- DAWN_INVALID_IF(!(bindingInfo.compatibleSampleTypes & requiredType),
- "The sample type in the shader is not compatible with the "
- "sample type of the layout.");
-
- DAWN_INVALID_IF(
- layoutInfo.texture.viewDimension != bindingInfo.viewDimension,
- "The shader's binding dimension (%s) doesn't match the shader's binding "
- "dimension (%s).",
- layoutInfo.texture.viewDimension, bindingInfo.viewDimension);
- } else if constexpr (std::is_same_v<T, StorageTextureBindingInfo>) {
- DAWN_ASSERT(layoutInfo.storageTexture.format != wgpu::TextureFormat::Undefined);
- DAWN_ASSERT(bindingInfo.format != wgpu::TextureFormat::Undefined);
-
- DAWN_INVALID_IF(
- !IsShaderCompatibleWithPipelineLayoutOnStorageTextureAccess(layoutInfo,
- bindingInfo),
- "The layout's binding access (%s) isn't compatible with the shader's "
- "binding access (%s).",
- layoutInfo.storageTexture.access, bindingInfo.access);
-
- DAWN_INVALID_IF(
- layoutInfo.storageTexture.format != bindingInfo.format,
- "The layout's binding format (%s) doesn't match the shader's binding "
- "format (%s).",
- layoutInfo.storageTexture.format, bindingInfo.format);
-
- DAWN_INVALID_IF(
- layoutInfo.storageTexture.viewDimension != bindingInfo.viewDimension,
- "The layout's binding dimension (%s) doesn't match the "
- "shader's binding dimension (%s).",
- layoutInfo.storageTexture.viewDimension, bindingInfo.viewDimension);
- } else if constexpr (std::is_same_v<T, BufferBindingInfo>) {
- // Binding mismatch between shader and bind group is invalid. For example, a
- // writable binding in the shader with a readonly storage buffer in the bind
- // group layout is invalid. For internal usage with internal shaders, a storage
- // binding in the shader with an internal storage buffer in the bind group
- // layout is also valid.
- bool validBindingConversion =
- (layoutInfo.buffer.type == kInternalStorageBufferBinding &&
- bindingInfo.type == wgpu::BufferBindingType::Storage);
-
- DAWN_INVALID_IF(
- layoutInfo.buffer.type != bindingInfo.type && !validBindingConversion,
- "The buffer type in the shader (%s) is not compatible with the type in the "
- "layout (%s).",
- bindingInfo.type, layoutInfo.buffer.type);
-
- DAWN_INVALID_IF(layoutInfo.buffer.minBindingSize != 0 &&
- bindingInfo.minBindingSize > layoutInfo.buffer.minBindingSize,
- "The shader uses more bytes of the buffer (%u) than the layout's "
- "minBindingSize (%u).",
- bindingInfo.minBindingSize, layoutInfo.buffer.minBindingSize);
- } else if constexpr (std::is_same_v<T, SamplerBindingInfo>) {
- DAWN_INVALID_IF(
- (layoutInfo.sampler.type == wgpu::SamplerBindingType::Comparison) !=
- bindingInfo.isComparison,
- "The sampler type in the shader (comparison: %u) doesn't match the type in "
- "the layout (comparison: %u).",
- bindingInfo.isComparison,
- layoutInfo.sampler.type == wgpu::SamplerBindingType::Comparison);
- } else if constexpr (std::is_same_v<T, ExternalTextureBindingInfo>) {
- DAWN_UNREACHABLE();
+ // TODO(dawn:563): Provide info about the sample types.
+ SampleTypeBit requiredType;
+ if (layoutInfo.texture.sampleType == kInternalResolveAttachmentSampleType) {
+ // If the layout's texture's sample type is
+ // kInternalResolveAttachmentSampleType, then the shader's compatible sample
+ // types must contain float.
+ requiredType = SampleTypeBit::UnfilterableFloat;
+ } else {
+ requiredType = SampleTypeToSampleTypeBit(layoutInfo.texture.sampleType);
}
+ DAWN_INVALID_IF(!(bindingInfo.compatibleSampleTypes & requiredType),
+ "The sample type in the shader is not compatible with the "
+ "sample type of the layout.");
+
+ DAWN_INVALID_IF(
+ layoutInfo.texture.viewDimension != bindingInfo.viewDimension,
+ "The shader's binding dimension (%s) doesn't match the shader's binding "
+ "dimension (%s).",
+ layoutInfo.texture.viewDimension, bindingInfo.viewDimension);
return {};
},
- shaderInfo.bindingInfo);
+ [&](const StorageTextureBindingInfo& bindingInfo) -> MaybeError {
+ DAWN_ASSERT(layoutInfo.storageTexture.format != wgpu::TextureFormat::Undefined);
+ DAWN_ASSERT(bindingInfo.format != wgpu::TextureFormat::Undefined);
+
+ DAWN_INVALID_IF(!IsShaderCompatibleWithPipelineLayoutOnStorageTextureAccess(
+ layoutInfo, bindingInfo),
+ "The layout's binding access (%s) isn't compatible with the shader's "
+ "binding access (%s).",
+ layoutInfo.storageTexture.access, bindingInfo.access);
+
+ DAWN_INVALID_IF(layoutInfo.storageTexture.format != bindingInfo.format,
+ "The layout's binding format (%s) doesn't match the shader's binding "
+ "format (%s).",
+ layoutInfo.storageTexture.format, bindingInfo.format);
+
+ DAWN_INVALID_IF(layoutInfo.storageTexture.viewDimension != bindingInfo.viewDimension,
+ "The layout's binding dimension (%s) doesn't match the "
+ "shader's binding dimension (%s).",
+ layoutInfo.storageTexture.viewDimension, bindingInfo.viewDimension);
+ return {};
+ },
+ [&](const BufferBindingInfo& bindingInfo) -> MaybeError {
+ // Binding mismatch between shader and bind group is invalid. For example, a
+ // writable binding in the shader with a readonly storage buffer in the bind
+ // group layout is invalid. For internal usage with internal shaders, a storage
+ // binding in the shader with an internal storage buffer in the bind group
+ // layout is also valid.
+ bool validBindingConversion =
+ (layoutInfo.buffer.type == kInternalStorageBufferBinding &&
+ bindingInfo.type == wgpu::BufferBindingType::Storage);
+
+ DAWN_INVALID_IF(
+ layoutInfo.buffer.type != bindingInfo.type && !validBindingConversion,
+ "The buffer type in the shader (%s) is not compatible with the type in the "
+ "layout (%s).",
+ bindingInfo.type, layoutInfo.buffer.type);
+
+ DAWN_INVALID_IF(layoutInfo.buffer.minBindingSize != 0 &&
+ bindingInfo.minBindingSize > layoutInfo.buffer.minBindingSize,
+ "The shader uses more bytes of the buffer (%u) than the layout's "
+ "minBindingSize (%u).",
+ bindingInfo.minBindingSize, layoutInfo.buffer.minBindingSize);
+ return {};
+ },
+ [&](const SamplerBindingInfo& bindingInfo) -> MaybeError {
+ DAWN_INVALID_IF(
+ (layoutInfo.sampler.type == wgpu::SamplerBindingType::Comparison) !=
+ bindingInfo.isComparison,
+ "The sampler type in the shader (comparison: %u) doesn't match the type in "
+ "the layout (comparison: %u).",
+ bindingInfo.isComparison,
+ layoutInfo.sampler.type == wgpu::SamplerBindingType::Comparison);
+ return {};
+ },
+
+ [&](const ExternalTextureBindingInfo&) -> MaybeError {
+ DAWN_UNREACHABLE();
+ return {};
+ });
}
+
MaybeError ValidateCompatibilityWithBindGroupLayout(DeviceBase* device,
BindGroupIndex group,
const EntryPointMetadata& entryPoint,
diff --git a/src/dawn/native/metal/ShaderModuleMTL.mm b/src/dawn/native/metal/ShaderModuleMTL.mm
index fff5b27..52d818f 100644
--- a/src/dawn/native/metal/ShaderModuleMTL.mm
+++ b/src/dawn/native/metal/ShaderModuleMTL.mm
@@ -27,6 +27,7 @@
#include "dawn/native/metal/ShaderModuleMTL.h"
+#include "dawn/common/MatchVariant.h"
#include "dawn/native/BindGroupLayout.h"
#include "dawn/native/CacheRequest.h"
#include "dawn/native/Serializable.h"
@@ -147,65 +148,62 @@
tint::BindingPoint dstBindingPoint{0, shaderIndex};
- // TODO(dawn:2370): implement a helper in dawn/utils to simplify the call of std::visit.
- std::visit(
- [&](const auto& bindingInfo) {
- using T = std::decay_t<decltype(bindingInfo)>;
-
- if constexpr (std::is_same_v<T, BufferBindingInfo>) {
- switch (bindingInfo.type) {
- case wgpu::BufferBindingType::Uniform:
- bindings.uniform.emplace(
- srcBindingPoint,
- tint::msl::writer::binding::Uniform{dstBindingPoint.binding});
- break;
- case kInternalStorageBufferBinding:
- case wgpu::BufferBindingType::Storage:
- case wgpu::BufferBindingType::ReadOnlyStorage:
- bindings.storage.emplace(
- srcBindingPoint,
- tint::msl::writer::binding::Storage{dstBindingPoint.binding});
- // Use the ShaderIndex as the indices for the buffer size lookups in
- // the array length uniform transform. This is used to compute the
- // size of variable length arrays in storage buffers.
- arrayLengthFromUniform.bindpoint_to_size_index.emplace(
- dstBindingPoint, dstBindingPoint.binding);
- break;
- case wgpu::BufferBindingType::Undefined:
- DAWN_UNREACHABLE();
- break;
- }
- } else if constexpr (std::is_same_v<T, SamplerBindingInfo>) {
- bindings.sampler.emplace(
- srcBindingPoint,
- tint::msl::writer::binding::Sampler{dstBindingPoint.binding});
- } else if constexpr (std::is_same_v<T, SampledTextureBindingInfo>) {
- bindings.texture.emplace(
- srcBindingPoint,
- tint::msl::writer::binding::Texture{dstBindingPoint.binding});
- } else if constexpr (std::is_same_v<T, StorageTextureBindingInfo>) {
- bindings.storage_texture.emplace(
- srcBindingPoint,
- tint::msl::writer::binding::StorageTexture{dstBindingPoint.binding});
- } else if constexpr (std::is_same_v<T, ExternalTextureBindingInfo>) {
- const auto& etBindingMap = bgl->GetExternalTextureBindingExpansionMap();
- const auto& expansion = etBindingMap.find(binding);
- DAWN_ASSERT(expansion != etBindingMap.end());
-
- const auto& bindingExpansion = expansion->second;
- tint::msl::writer::binding::BindingInfo plane0{
- static_cast<uint32_t>(shaderIndex)};
- tint::msl::writer::binding::BindingInfo plane1{
- bindingIndexInfo[bgl->GetBindingIndex(bindingExpansion.plane1)]};
- tint::msl::writer::binding::BindingInfo metadata{
- bindingIndexInfo[bgl->GetBindingIndex(bindingExpansion.params)]};
-
- bindings.external_texture.emplace(
- srcBindingPoint,
- tint::msl::writer::binding::ExternalTexture{metadata, plane0, plane1});
+ MatchVariant(
+ shaderBindingInfo.bindingInfo,
+ [&](const BufferBindingInfo& bindingInfo) {
+ switch (bindingInfo.type) {
+ case wgpu::BufferBindingType::Uniform:
+ bindings.uniform.emplace(
+ srcBindingPoint,
+ tint::msl::writer::binding::Uniform{dstBindingPoint.binding});
+ break;
+ case kInternalStorageBufferBinding:
+ case wgpu::BufferBindingType::Storage:
+ case wgpu::BufferBindingType::ReadOnlyStorage:
+ bindings.storage.emplace(
+ srcBindingPoint,
+ tint::msl::writer::binding::Storage{dstBindingPoint.binding});
+ // Use the ShaderIndex as the indices for the buffer size lookups in
+ // the array length uniform transform. This is used to compute the
+ // size of variable length arrays in storage buffers.
+ arrayLengthFromUniform.bindpoint_to_size_index.emplace(
+ dstBindingPoint, dstBindingPoint.binding);
+ break;
+ case wgpu::BufferBindingType::Undefined:
+ DAWN_UNREACHABLE();
+ break;
}
},
- shaderBindingInfo.bindingInfo);
+ [&](const SamplerBindingInfo& bindingInfo) {
+ bindings.sampler.emplace(srcBindingPoint, tint::msl::writer::binding::Sampler{
+ dstBindingPoint.binding});
+ },
+ [&](const SampledTextureBindingInfo& bindingInfo) {
+ bindings.texture.emplace(srcBindingPoint, tint::msl::writer::binding::Texture{
+ dstBindingPoint.binding});
+ },
+ [&](const StorageTextureBindingInfo& bindingInfo) {
+ bindings.storage_texture.emplace(
+ srcBindingPoint,
+ tint::msl::writer::binding::StorageTexture{dstBindingPoint.binding});
+ },
+ [&](const ExternalTextureBindingInfo& bindingInfo) {
+ const auto& etBindingMap = bgl->GetExternalTextureBindingExpansionMap();
+ const auto& expansion = etBindingMap.find(binding);
+ DAWN_ASSERT(expansion != etBindingMap.end());
+
+ const auto& bindingExpansion = expansion->second;
+ tint::msl::writer::binding::BindingInfo plane0{
+ static_cast<uint32_t>(shaderIndex)};
+ tint::msl::writer::binding::BindingInfo plane1{
+ bindingIndexInfo[bgl->GetBindingIndex(bindingExpansion.plane1)]};
+ tint::msl::writer::binding::BindingInfo metadata{
+ bindingIndexInfo[bgl->GetBindingIndex(bindingExpansion.params)]};
+
+ bindings.external_texture.emplace(
+ srcBindingPoint,
+ tint::msl::writer::binding::ExternalTexture{metadata, plane0, plane1});
+ });
}
}
diff --git a/src/dawn/native/vulkan/ShaderModuleVk.cpp b/src/dawn/native/vulkan/ShaderModuleVk.cpp
index 6ad0516..0796340 100644
--- a/src/dawn/native/vulkan/ShaderModuleVk.cpp
+++ b/src/dawn/native/vulkan/ShaderModuleVk.cpp
@@ -32,6 +32,7 @@
#include <vector>
#include "absl/container/flat_hash_map.h"
+#include "dawn/common/MatchVariant.h"
#include "dawn/native/CacheRequest.h"
#include "dawn/native/PhysicalDevice.h"
#include "dawn/native/Serializable.h"
@@ -250,65 +251,64 @@
tint::BindingPoint dstBindingPoint{
static_cast<uint32_t>(group), static_cast<uint32_t>(bgl->GetBindingIndex(binding))};
- // TODO(dawn:2370): implement a helper in dawn/utils to simplify the call of std::visit.
- std::visit(
- [&](const auto& bindingInfo) {
- using T = std::decay_t<decltype(bindingInfo)>;
-
- if constexpr (std::is_same_v<T, BufferBindingInfo>) {
- switch (bindingInfo.type) {
- case wgpu::BufferBindingType::Uniform:
- bindings.uniform.emplace(
- srcBindingPoint,
- tint::spirv::writer::binding::Uniform{dstBindingPoint.group,
- dstBindingPoint.binding});
- break;
- case kInternalStorageBufferBinding:
- case wgpu::BufferBindingType::Storage:
- case wgpu::BufferBindingType::ReadOnlyStorage:
- bindings.storage.emplace(
- srcBindingPoint,
- tint::spirv::writer::binding::Storage{dstBindingPoint.group,
- dstBindingPoint.binding});
- break;
- case wgpu::BufferBindingType::Undefined:
- DAWN_UNREACHABLE();
- break;
- }
- } else if constexpr (std::is_same_v<T, SamplerBindingInfo>) {
- bindings.sampler.emplace(
- srcBindingPoint, tint::spirv::writer::binding::Sampler{
- dstBindingPoint.group, dstBindingPoint.binding});
- } else if constexpr (std::is_same_v<T, SampledTextureBindingInfo>) {
- bindings.texture.emplace(
- srcBindingPoint, tint::spirv::writer::binding::Texture{
- dstBindingPoint.group, dstBindingPoint.binding});
- } else if constexpr (std::is_same_v<T, StorageTextureBindingInfo>) {
- bindings.storage_texture.emplace(
- srcBindingPoint, tint::spirv::writer::binding::StorageTexture{
- dstBindingPoint.group, dstBindingPoint.binding});
- } else if constexpr (std::is_same_v<T, ExternalTextureBindingInfo>) {
- const auto& bindingMap = bgl->GetExternalTextureBindingExpansionMap();
- const auto& expansion = bindingMap.find(binding);
- DAWN_ASSERT(expansion != bindingMap.end());
-
- const auto& bindingExpansion = expansion->second;
- tint::spirv::writer::binding::BindingInfo plane0{
- static_cast<uint32_t>(group),
- static_cast<uint32_t>(bgl->GetBindingIndex(bindingExpansion.plane0))};
- tint::spirv::writer::binding::BindingInfo plane1{
- static_cast<uint32_t>(group),
- static_cast<uint32_t>(bgl->GetBindingIndex(bindingExpansion.plane1))};
- tint::spirv::writer::binding::BindingInfo metadata{
- static_cast<uint32_t>(group),
- static_cast<uint32_t>(bgl->GetBindingIndex(bindingExpansion.params))};
-
- bindings.external_texture.emplace(
- srcBindingPoint, tint::spirv::writer::binding::ExternalTexture{
- metadata, plane0, plane1});
+ MatchVariant(
+ shaderBindingInfo.bindingInfo,
+ [&](const BufferBindingInfo& bindingInfo) {
+ switch (bindingInfo.type) {
+ case wgpu::BufferBindingType::Uniform:
+ bindings.uniform.emplace(
+ srcBindingPoint,
+ tint::spirv::writer::binding::Uniform{dstBindingPoint.group,
+ dstBindingPoint.binding});
+ break;
+ case kInternalStorageBufferBinding:
+ case wgpu::BufferBindingType::Storage:
+ case wgpu::BufferBindingType::ReadOnlyStorage:
+ bindings.storage.emplace(
+ srcBindingPoint,
+ tint::spirv::writer::binding::Storage{dstBindingPoint.group,
+ dstBindingPoint.binding});
+ break;
+ case wgpu::BufferBindingType::Undefined:
+ DAWN_UNREACHABLE();
+ break;
}
},
- shaderBindingInfo.bindingInfo);
+ [&](const SamplerBindingInfo& bindingInfo) {
+ bindings.sampler.emplace(srcBindingPoint,
+ tint::spirv::writer::binding::Sampler{
+ dstBindingPoint.group, dstBindingPoint.binding});
+ },
+ [&](const SampledTextureBindingInfo& bindingInfo) {
+ bindings.texture.emplace(srcBindingPoint,
+ tint::spirv::writer::binding::Texture{
+ dstBindingPoint.group, dstBindingPoint.binding});
+ },
+ [&](const StorageTextureBindingInfo& bindingInfo) {
+ bindings.storage_texture.emplace(
+ srcBindingPoint, tint::spirv::writer::binding::StorageTexture{
+ dstBindingPoint.group, dstBindingPoint.binding});
+ },
+ [&](const ExternalTextureBindingInfo& bindingInfo) {
+ const auto& bindingMap = bgl->GetExternalTextureBindingExpansionMap();
+ const auto& expansion = bindingMap.find(binding);
+ DAWN_ASSERT(expansion != bindingMap.end());
+
+ const auto& bindingExpansion = expansion->second;
+ tint::spirv::writer::binding::BindingInfo plane0{
+ static_cast<uint32_t>(group),
+ static_cast<uint32_t>(bgl->GetBindingIndex(bindingExpansion.plane0))};
+ tint::spirv::writer::binding::BindingInfo plane1{
+ static_cast<uint32_t>(group),
+ static_cast<uint32_t>(bgl->GetBindingIndex(bindingExpansion.plane1))};
+ tint::spirv::writer::binding::BindingInfo metadata{
+ static_cast<uint32_t>(group),
+ static_cast<uint32_t>(bgl->GetBindingIndex(bindingExpansion.params))};
+
+ bindings.external_texture.emplace(
+ srcBindingPoint,
+ tint::spirv::writer::binding::ExternalTexture{metadata, plane0, plane1});
+ });
}
}