dawn: Allow internal extensions when creating shader module.
Bug: 42240662
Change-Id: I64fecb2b019337586fa89a8897def1e8a70b2e52
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/191781
Commit-Queue: Quyen Le <lehoangquyen@chromium.org>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Austin Eng <enga@chromium.org>
diff --git a/src/dawn/native/Device.cpp b/src/dawn/native/Device.cpp
index 713a6eb..e0ac3fd 100644
--- a/src/dawn/native/Device.cpp
+++ b/src/dawn/native/Device.cpp
@@ -1161,11 +1161,13 @@
ResultOrError<Ref<ShaderModuleBase>> DeviceBase::GetOrCreateShaderModule(
const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+ const std::vector<tint::wgsl::Extension>& internalExtensions,
ShaderModuleParseResult* parseResult,
std::unique_ptr<OwnedCompilationMessages>* compilationMessages) {
DAWN_ASSERT(parseResult != nullptr);
- ShaderModuleBase blueprint(this, descriptor, ApiObjectBase::kUntrackedByDevice);
+ ShaderModuleBase blueprint(this, descriptor, internalExtensions,
+ ApiObjectBase::kUntrackedByDevice);
const size_t blueprintHash = blueprint.ComputeContentHash();
blueprint.SetContentHash(blueprintHash);
@@ -1178,13 +1180,14 @@
// lookup in the cache without validating and parsing. We need the parsed module
// now.
DAWN_ASSERT(!IsValidationEnabled());
- DAWN_TRY(
- ValidateAndParseShaderModule(this, descriptor, parseResult, unownedMessages));
+ DAWN_TRY(ValidateAndParseShaderModule(this, descriptor, internalExtensions,
+ parseResult, unownedMessages));
}
auto resultOrError = [&]() -> ResultOrError<Ref<ShaderModuleBase>> {
SCOPED_DAWN_HISTOGRAM_TIMER_MICROS(GetPlatform(), "CreateShaderModuleUS");
- return CreateShaderModuleImpl(descriptor, parseResult, unownedMessages);
+ return CreateShaderModuleImpl(descriptor, internalExtensions, parseResult,
+ unownedMessages);
}();
DAWN_HISTOGRAM_BOOLEAN(GetPlatform(), "CreateShaderModuleSuccess",
resultOrError.IsSuccess());
@@ -1528,7 +1531,8 @@
std::unique_ptr<OwnedCompilationMessages> compilationMessages(
std::make_unique<OwnedCompilationMessages>());
- auto resultOrError = CreateShaderModule(descriptor, &compilationMessages);
+ auto resultOrError =
+ CreateShaderModule(descriptor, /*internalExtensions=*/{}, &compilationMessages);
if (resultOrError.IsSuccess()) {
Ref<ShaderModuleBase> result = resultOrError.AcquireSuccess();
EmitCompilationLog(result.Get());
@@ -1827,10 +1831,6 @@
return mWGSLAllowedFeatures;
}
-void DeviceBase::EnableAdditionalWGSLExtension(tint::wgsl::Extension extension) {
- mWGSLAllowedFeatures.extensions.insert(extension);
-}
-
bool DeviceBase::IsValidationEnabled() const {
return !IsToggleEnabled(Toggle::SkipValidation);
}
@@ -2237,6 +2237,7 @@
ResultOrError<Ref<ShaderModuleBase>> DeviceBase::CreateShaderModule(
const ShaderModuleDescriptor* descriptor,
+ const std::vector<tint::wgsl::Extension>& internalExtensions,
std::unique_ptr<OwnedCompilationMessages>* compilationMessages) {
DAWN_TRY(ValidateIsAlive());
@@ -2250,14 +2251,14 @@
DAWN_TRY_ASSIGN_CONTEXT(unpacked, ValidateAndUnpack(descriptor),
"validating and unpacking %s", descriptor);
DAWN_TRY_CONTEXT(ValidateAndParseShaderModule(
- this, unpacked, &parseResult,
+ this, unpacked, internalExtensions, &parseResult,
compilationMessages ? compilationMessages->get() : nullptr),
"validating %s", descriptor);
} else {
unpacked = Unpack(descriptor);
}
- return GetOrCreateShaderModule(unpacked, &parseResult, compilationMessages);
+ return GetOrCreateShaderModule(unpacked, internalExtensions, &parseResult, compilationMessages);
}
ResultOrError<Ref<SwapChainBase>> DeviceBase::CreateSwapChain(
diff --git a/src/dawn/native/Device.h b/src/dawn/native/Device.h
index 2d395a4..3a9189f 100644
--- a/src/dawn/native/Device.h
+++ b/src/dawn/native/Device.h
@@ -185,6 +185,7 @@
ResultOrError<Ref<ShaderModuleBase>> GetOrCreateShaderModule(
const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+ const std::vector<tint::wgsl::Extension>& internalExtensions,
ShaderModuleParseResult* parseResult,
std::unique_ptr<OwnedCompilationMessages>* compilationMessages);
@@ -228,6 +229,7 @@
ResultOrError<Ref<SamplerBase>> CreateSampler(const SamplerDescriptor* descriptor = nullptr);
ResultOrError<Ref<ShaderModuleBase>> CreateShaderModule(
const ShaderModuleDescriptor* descriptor,
+ const std::vector<tint::wgsl::Extension>& internalExtensions = {},
std::unique_ptr<OwnedCompilationMessages>* compilationMessages = nullptr);
// Deprecated: this was the way to create a SwapChain when it was explicitly manipulated by the
// end user.
@@ -469,8 +471,6 @@
void DestroyObjects();
void Destroy();
- void EnableAdditionalWGSLExtension(tint::wgsl::Extension extension);
-
virtual MaybeError GetAHardwareBufferPropertiesImpl(
void* handle,
AHardwareBufferProperties* properties) const {
@@ -501,6 +501,7 @@
const SamplerDescriptor* descriptor) = 0;
virtual ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl(
const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+ const std::vector<tint::wgsl::Extension>& internalExtensions,
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) = 0;
// Note that previousSwapChain may be nullptr, or come from a different backend.
diff --git a/src/dawn/native/ShaderModule.cpp b/src/dawn/native/ShaderModule.cpp
index 123633c..6ba5f1f 100644
--- a/src/dawn/native/ShaderModule.cpp
+++ b/src/dawn/native/ShaderModule.cpp
@@ -348,9 +348,12 @@
ResultOrError<tint::Program> ParseWGSL(const tint::Source::File* file,
const tint::wgsl::AllowedFeatures& allowedFeatures,
+ const std::vector<tint::wgsl::Extension>& internalExtensions,
OwnedCompilationMessages* outMessages) {
tint::wgsl::reader::Options options;
options.allowed_features = allowedFeatures;
+ options.allowed_features.extensions.insert(internalExtensions.begin(),
+ internalExtensions.end());
tint::Program program = tint::wgsl::reader::Parse(file, options);
if (outMessages != nullptr) {
DAWN_TRY(outMessages->AddMessages(program.Diagnostics()));
@@ -1071,10 +1074,12 @@
return tintProgram != nullptr;
}
-MaybeError ValidateAndParseShaderModule(DeviceBase* device,
- const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
- ShaderModuleParseResult* parseResult,
- OwnedCompilationMessages* outMessages) {
+MaybeError ValidateAndParseShaderModule(
+ DeviceBase* device,
+ const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+ const std::vector<tint::wgsl::Extension>& internalExtensions,
+ ShaderModuleParseResult* parseResult,
+ OwnedCompilationMessages* outMessages) {
DAWN_ASSERT(parseResult != nullptr);
wgpu::SType moduleType;
@@ -1148,8 +1153,8 @@
}
tint::Program program;
- DAWN_TRY_ASSIGN(program,
- ParseWGSL(tintFile.get(), device->GetWGSLAllowedFeatures(), outMessages));
+ DAWN_TRY_ASSIGN(program, ParseWGSL(tintFile.get(), device->GetWGSLAllowedFeatures(),
+ internalExtensions, outMessages));
parseResult->tintProgram = AcquireRef(new TintProgram(std::move(program), std::move(tintFile)));
@@ -1305,8 +1310,11 @@
ShaderModuleBase::ShaderModuleBase(DeviceBase* device,
const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+ std::vector<tint::wgsl::Extension> internalExtensions,
ApiObjectBase::UntrackedByDeviceTag tag)
- : Base(device, descriptor->label), mType(Type::Undefined) {
+ : Base(device, descriptor->label),
+ mType(Type::Undefined),
+ mInternalExtensions(std::move(internalExtensions)) {
if (auto* spirvDesc = descriptor.Get<ShaderModuleSPIRVDescriptor>()) {
mType = Type::Spirv;
mOriginalSpirv.assign(spirvDesc->code, spirvDesc->code + spirvDesc->codeSize);
@@ -1323,8 +1331,9 @@
}
ShaderModuleBase::ShaderModuleBase(DeviceBase* device,
- const UnpackedPtr<ShaderModuleDescriptor>& descriptor)
- : ShaderModuleBase(device, descriptor, kUntrackedByDevice) {
+ const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+ std::vector<tint::wgsl::Extension> internalExtensions)
+ : ShaderModuleBase(device, descriptor, std::move(internalExtensions), kUntrackedByDevice) {
GetObjectTrackingList()->Track(this);
}
@@ -1419,7 +1428,8 @@
}
ShaderModuleParseResult parseResult;
- ValidateAndParseShaderModule(GetDevice(), Unpack(&descriptor), &parseResult,
+ ValidateAndParseShaderModule(GetDevice(), Unpack(&descriptor), mInternalExtensions,
+ &parseResult,
/*compilationMessages=*/nullptr)
.AcquireSuccess();
DAWN_ASSERT(parseResult.tintProgram != nullptr);
diff --git a/src/dawn/native/ShaderModule.h b/src/dawn/native/ShaderModule.h
index 339a56b..18ffe10 100644
--- a/src/dawn/native/ShaderModule.h
+++ b/src/dawn/native/ShaderModule.h
@@ -130,10 +130,12 @@
std::string name;
};
-MaybeError ValidateAndParseShaderModule(DeviceBase* device,
- const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
- ShaderModuleParseResult* parseResult,
- OwnedCompilationMessages* outMessages);
+MaybeError ValidateAndParseShaderModule(
+ DeviceBase* device,
+ const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+ const std::vector<tint::wgsl::Extension>& internalExtensions,
+ ShaderModuleParseResult* parseResult,
+ OwnedCompilationMessages* outMessages);
MaybeError ValidateCompatibilityWithPipelineLayout(DeviceBase* device,
const EntryPointMetadata& entryPoint,
const PipelineLayoutBase* layout);
@@ -284,8 +286,11 @@
using Base = RefCountedWithExternalCountBase<ApiObjectBase>;
ShaderModuleBase(DeviceBase* device,
const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+ std::vector<tint::wgsl::Extension> internalExtensions,
ApiObjectBase::UntrackedByDeviceTag tag);
- ShaderModuleBase(DeviceBase* device, const UnpackedPtr<ShaderModuleDescriptor>& descriptor);
+ ShaderModuleBase(DeviceBase* device,
+ const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+ std::vector<tint::wgsl::Extension> internalExtensions);
~ShaderModuleBase() override;
static Ref<ShaderModuleBase> MakeError(DeviceBase* device, const char* label);
@@ -361,6 +366,8 @@
MutexProtected<TintData> mTintData;
std::unique_ptr<OwnedCompilationMessages> mCompilationMessages;
+
+ const std::vector<tint::wgsl::Extension> mInternalExtensions;
};
} // namespace dawn::native
diff --git a/src/dawn/native/d3d11/DeviceD3D11.cpp b/src/dawn/native/d3d11/DeviceD3D11.cpp
index 2430bb6..757fad9 100644
--- a/src/dawn/native/d3d11/DeviceD3D11.cpp
+++ b/src/dawn/native/d3d11/DeviceD3D11.cpp
@@ -211,9 +211,11 @@
ResultOrError<Ref<ShaderModuleBase>> Device::CreateShaderModuleImpl(
const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+ const std::vector<tint::wgsl::Extension>& internalExtensions,
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) {
- return ShaderModule::Create(this, descriptor, parseResult, compilationMessages);
+ return ShaderModule::Create(this, descriptor, internalExtensions, parseResult,
+ compilationMessages);
}
ResultOrError<Ref<SwapChainBase>> Device::CreateSwapChainImpl(Surface* surface,
diff --git a/src/dawn/native/d3d11/DeviceD3D11.h b/src/dawn/native/d3d11/DeviceD3D11.h
index 332dcdb..de72f45 100644
--- a/src/dawn/native/d3d11/DeviceD3D11.h
+++ b/src/dawn/native/d3d11/DeviceD3D11.h
@@ -112,6 +112,7 @@
ResultOrError<Ref<SamplerBase>> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl(
const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+ const std::vector<tint::wgsl::Extension>& internalExtensions,
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) override;
ResultOrError<Ref<SwapChainBase>> CreateSwapChainImpl(
diff --git a/src/dawn/native/d3d11/ShaderModuleD3D11.cpp b/src/dawn/native/d3d11/ShaderModuleD3D11.cpp
index c5469c5..58b214e 100644
--- a/src/dawn/native/d3d11/ShaderModuleD3D11.cpp
+++ b/src/dawn/native/d3d11/ShaderModuleD3D11.cpp
@@ -56,15 +56,18 @@
ResultOrError<Ref<ShaderModule>> ShaderModule::Create(
Device* device,
const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+ const std::vector<tint::wgsl::Extension>& internalExtensions,
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) {
- Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor));
+ Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor, internalExtensions));
DAWN_TRY(module->Initialize(parseResult, compilationMessages));
return module;
}
-ShaderModule::ShaderModule(Device* device, const UnpackedPtr<ShaderModuleDescriptor>& descriptor)
- : ShaderModuleBase(device, descriptor) {}
+ShaderModule::ShaderModule(Device* device,
+ const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+ std::vector<tint::wgsl::Extension> internalExtensions)
+ : ShaderModuleBase(device, descriptor, std::move(internalExtensions)) {}
MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) {
diff --git a/src/dawn/native/d3d11/ShaderModuleD3D11.h b/src/dawn/native/d3d11/ShaderModuleD3D11.h
index b436aee..a99f963 100644
--- a/src/dawn/native/d3d11/ShaderModuleD3D11.h
+++ b/src/dawn/native/d3d11/ShaderModuleD3D11.h
@@ -30,6 +30,7 @@
#include <optional>
#include <string>
+#include <vector>
#include "dawn/native/Blob.h"
#include "dawn/native/Serializable.h"
@@ -51,6 +52,7 @@
static ResultOrError<Ref<ShaderModule>> Create(
Device* device,
const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+ const std::vector<tint::wgsl::Extension>& internalExtensions,
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages);
@@ -64,7 +66,9 @@
const std::optional<tint::hlsl::writer::PixelLocalOptions>& pixelLocalOptions = {});
private:
- ShaderModule(Device* device, const UnpackedPtr<ShaderModuleDescriptor>& descriptor);
+ ShaderModule(Device* device,
+ const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+ std::vector<tint::wgsl::Extension> internalExtensions);
~ShaderModule() override = default;
MaybeError Initialize(ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages);
diff --git a/src/dawn/native/d3d12/DeviceD3D12.cpp b/src/dawn/native/d3d12/DeviceD3D12.cpp
index a46aad7..04613dd 100644
--- a/src/dawn/native/d3d12/DeviceD3D12.cpp
+++ b/src/dawn/native/d3d12/DeviceD3D12.cpp
@@ -395,9 +395,11 @@
}
ResultOrError<Ref<ShaderModuleBase>> Device::CreateShaderModuleImpl(
const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+ const std::vector<tint::wgsl::Extension>& internalExtensions,
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) {
- return ShaderModule::Create(this, descriptor, parseResult, compilationMessages);
+ return ShaderModule::Create(this, descriptor, internalExtensions, parseResult,
+ compilationMessages);
}
ResultOrError<Ref<SwapChainBase>> Device::CreateSwapChainImpl(Surface* surface,
SwapChainBase* previousSwapChain,
diff --git a/src/dawn/native/d3d12/DeviceD3D12.h b/src/dawn/native/d3d12/DeviceD3D12.h
index a46bf8f..bced302 100644
--- a/src/dawn/native/d3d12/DeviceD3D12.h
+++ b/src/dawn/native/d3d12/DeviceD3D12.h
@@ -191,6 +191,7 @@
ResultOrError<Ref<SamplerBase>> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl(
const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+ const std::vector<tint::wgsl::Extension>& internalExtensions,
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) override;
ResultOrError<Ref<SwapChainBase>> CreateSwapChainImpl(
diff --git a/src/dawn/native/d3d12/ShaderModuleD3D12.cpp b/src/dawn/native/d3d12/ShaderModuleD3D12.cpp
index c5b7401..099f258 100644
--- a/src/dawn/native/d3d12/ShaderModuleD3D12.cpp
+++ b/src/dawn/native/d3d12/ShaderModuleD3D12.cpp
@@ -106,15 +106,18 @@
ResultOrError<Ref<ShaderModule>> ShaderModule::Create(
Device* device,
const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+ const std::vector<tint::wgsl::Extension>& internalExtensions,
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) {
- Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor));
+ Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor, internalExtensions));
DAWN_TRY(module->Initialize(parseResult, compilationMessages));
return module;
}
-ShaderModule::ShaderModule(Device* device, const UnpackedPtr<ShaderModuleDescriptor>& descriptor)
- : ShaderModuleBase(device, descriptor) {}
+ShaderModule::ShaderModule(Device* device,
+ const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+ std::vector<tint::wgsl::Extension> internalExtensions)
+ : ShaderModuleBase(device, descriptor, std::move(internalExtensions)) {}
MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) {
diff --git a/src/dawn/native/d3d12/ShaderModuleD3D12.h b/src/dawn/native/d3d12/ShaderModuleD3D12.h
index f98050a..f5e143d 100644
--- a/src/dawn/native/d3d12/ShaderModuleD3D12.h
+++ b/src/dawn/native/d3d12/ShaderModuleD3D12.h
@@ -30,6 +30,7 @@
#include <optional>
#include <string>
+#include <vector>
#include "dawn/native/Blob.h"
#include "dawn/native/Serializable.h"
@@ -51,6 +52,7 @@
static ResultOrError<Ref<ShaderModule>> Create(
Device* device,
const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+ const std::vector<tint::wgsl::Extension>& internalExtensions,
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages);
@@ -64,7 +66,9 @@
std::optional<uint32_t> maxSubgroupSizeForFullSubgroups = std::nullopt);
private:
- ShaderModule(Device* device, const UnpackedPtr<ShaderModuleDescriptor>& descriptor);
+ ShaderModule(Device* device,
+ const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+ std::vector<tint::wgsl::Extension> internalExtensions);
~ShaderModule() override = default;
MaybeError Initialize(ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages);
diff --git a/src/dawn/native/metal/DeviceMTL.h b/src/dawn/native/metal/DeviceMTL.h
index 44e8f8f..4ac2200 100644
--- a/src/dawn/native/metal/DeviceMTL.h
+++ b/src/dawn/native/metal/DeviceMTL.h
@@ -110,6 +110,7 @@
ResultOrError<Ref<SamplerBase>> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl(
const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+ const std::vector<tint::wgsl::Extension>& internalExtensions,
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) override;
ResultOrError<Ref<SwapChainBase>> CreateSwapChainImpl(
diff --git a/src/dawn/native/metal/DeviceMTL.mm b/src/dawn/native/metal/DeviceMTL.mm
index 2a58e4e..73c810e 100644
--- a/src/dawn/native/metal/DeviceMTL.mm
+++ b/src/dawn/native/metal/DeviceMTL.mm
@@ -225,9 +225,11 @@
}
ResultOrError<Ref<ShaderModuleBase>> Device::CreateShaderModuleImpl(
const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+ const std::vector<tint::wgsl::Extension>& internalExtensions,
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) {
- return ShaderModule::Create(this, descriptor, parseResult, compilationMessages);
+ return ShaderModule::Create(this, descriptor, internalExtensions, parseResult,
+ compilationMessages);
}
ResultOrError<Ref<SwapChainBase>> Device::CreateSwapChainImpl(Surface* surface,
SwapChainBase* previousSwapChain,
diff --git a/src/dawn/native/metal/ShaderModuleMTL.h b/src/dawn/native/metal/ShaderModuleMTL.h
index c04327f..9f34270 100644
--- a/src/dawn/native/metal/ShaderModuleMTL.h
+++ b/src/dawn/native/metal/ShaderModuleMTL.h
@@ -53,6 +53,7 @@
static ResultOrError<Ref<ShaderModule>> Create(
Device* device,
const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+ const std::vector<tint::wgsl::Extension>& internalExtensions,
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages);
@@ -73,7 +74,9 @@
std::optional<uint32_t> maxSubgroupSizeForFullSubgroups = std::nullopt);
private:
- ShaderModule(Device* device, const UnpackedPtr<ShaderModuleDescriptor>& descriptor);
+ ShaderModule(Device* device,
+ const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+ std::vector<tint::wgsl::Extension> internalExtensions);
~ShaderModule() override;
MaybeError Initialize(ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages);
diff --git a/src/dawn/native/metal/ShaderModuleMTL.mm b/src/dawn/native/metal/ShaderModuleMTL.mm
index be8d15a..70c4282 100644
--- a/src/dawn/native/metal/ShaderModuleMTL.mm
+++ b/src/dawn/native/metal/ShaderModuleMTL.mm
@@ -92,15 +92,18 @@
ResultOrError<Ref<ShaderModule>> ShaderModule::Create(
Device* device,
const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+ const std::vector<tint::wgsl::Extension>& internalExtensions,
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) {
- Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor));
+ Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor, internalExtensions));
DAWN_TRY(module->Initialize(parseResult, compilationMessages));
return module;
}
-ShaderModule::ShaderModule(Device* device, const UnpackedPtr<ShaderModuleDescriptor>& descriptor)
- : ShaderModuleBase(device, descriptor) {}
+ShaderModule::ShaderModule(Device* device,
+ const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+ std::vector<tint::wgsl::Extension> internalExtensions)
+ : ShaderModuleBase(device, descriptor, std::move(internalExtensions)) {}
ShaderModule::~ShaderModule() = default;
diff --git a/src/dawn/native/null/DeviceNull.cpp b/src/dawn/native/null/DeviceNull.cpp
index 6636f88..d8a1e2f 100644
--- a/src/dawn/native/null/DeviceNull.cpp
+++ b/src/dawn/native/null/DeviceNull.cpp
@@ -227,9 +227,10 @@
}
ResultOrError<Ref<ShaderModuleBase>> Device::CreateShaderModuleImpl(
const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+ const std::vector<tint::wgsl::Extension>& internalExtensions,
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) {
- Ref<ShaderModule> module = AcquireRef(new ShaderModule(this, descriptor));
+ Ref<ShaderModule> module = AcquireRef(new ShaderModule(this, descriptor, internalExtensions));
DAWN_TRY(module->Initialize(parseResult, compilationMessages));
return module;
}
diff --git a/src/dawn/native/null/DeviceNull.h b/src/dawn/native/null/DeviceNull.h
index a9190ca..954ba68 100644
--- a/src/dawn/native/null/DeviceNull.h
+++ b/src/dawn/native/null/DeviceNull.h
@@ -160,6 +160,7 @@
ResultOrError<Ref<SamplerBase>> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl(
const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+ const std::vector<tint::wgsl::Extension>& internalExtensions,
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) override;
ResultOrError<Ref<SwapChainBase>> CreateSwapChainImpl(
diff --git a/src/dawn/native/opengl/DeviceGL.cpp b/src/dawn/native/opengl/DeviceGL.cpp
index e9842fc..c8b3db8 100644
--- a/src/dawn/native/opengl/DeviceGL.cpp
+++ b/src/dawn/native/opengl/DeviceGL.cpp
@@ -257,9 +257,11 @@
}
ResultOrError<Ref<ShaderModuleBase>> Device::CreateShaderModuleImpl(
const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+ const std::vector<tint::wgsl::Extension>& internalExtensions,
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) {
- return ShaderModule::Create(this, descriptor, parseResult, compilationMessages);
+ return ShaderModule::Create(this, descriptor, internalExtensions, parseResult,
+ compilationMessages);
}
ResultOrError<Ref<SwapChainBase>> Device::CreateSwapChainImpl(Surface* surface,
SwapChainBase* previousSwapChain,
diff --git a/src/dawn/native/opengl/DeviceGL.h b/src/dawn/native/opengl/DeviceGL.h
index 1e943fe..f4418f6 100644
--- a/src/dawn/native/opengl/DeviceGL.h
+++ b/src/dawn/native/opengl/DeviceGL.h
@@ -29,6 +29,7 @@
#define SRC_DAWN_NATIVE_OPENGL_DEVICEGL_H_
#include <memory>
+#include <vector>
#include "dawn/native/dawn_platform.h"
@@ -121,6 +122,7 @@
ResultOrError<Ref<SamplerBase>> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl(
const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+ const std::vector<tint::wgsl::Extension>& internalExtensions,
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) override;
ResultOrError<Ref<SwapChainBase>> CreateSwapChainImpl(
diff --git a/src/dawn/native/opengl/ShaderModuleGL.cpp b/src/dawn/native/opengl/ShaderModuleGL.cpp
index fa031a2..9a8bce8 100644
--- a/src/dawn/native/opengl/ShaderModuleGL.cpp
+++ b/src/dawn/native/opengl/ShaderModuleGL.cpp
@@ -146,15 +146,18 @@
ResultOrError<Ref<ShaderModule>> ShaderModule::Create(
Device* device,
const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+ const std::vector<tint::wgsl::Extension>& internalExtensions,
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) {
- Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor));
+ Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor, internalExtensions));
DAWN_TRY(module->Initialize(parseResult, compilationMessages));
return module;
}
-ShaderModule::ShaderModule(Device* device, const UnpackedPtr<ShaderModuleDescriptor>& descriptor)
- : ShaderModuleBase(device, descriptor) {}
+ShaderModule::ShaderModule(Device* device,
+ const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+ std::vector<tint::wgsl::Extension> internalExtensions)
+ : ShaderModuleBase(device, descriptor, std::move(internalExtensions)) {}
MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) {
diff --git a/src/dawn/native/opengl/ShaderModuleGL.h b/src/dawn/native/opengl/ShaderModuleGL.h
index 2c04be3..133eb44 100644
--- a/src/dawn/native/opengl/ShaderModuleGL.h
+++ b/src/dawn/native/opengl/ShaderModuleGL.h
@@ -83,6 +83,7 @@
static ResultOrError<Ref<ShaderModule>> Create(
Device* device,
const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+ const std::vector<tint::wgsl::Extension>& internalExtensions,
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages);
@@ -99,7 +100,9 @@
BindingPointToFunctionAndOffset* bindingPointToData) const;
private:
- ShaderModule(Device* device, const UnpackedPtr<ShaderModuleDescriptor>& descriptor);
+ ShaderModule(Device* device,
+ const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+ std::vector<tint::wgsl::Extension> internalExtensions);
~ShaderModule() override = default;
MaybeError Initialize(ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages);
diff --git a/src/dawn/native/utils/WGPUHelpers.cpp b/src/dawn/native/utils/WGPUHelpers.cpp
index 06aedae..b1606ac 100644
--- a/src/dawn/native/utils/WGPUHelpers.cpp
+++ b/src/dawn/native/utils/WGPUHelpers.cpp
@@ -47,12 +47,15 @@
namespace dawn::native::utils {
-ResultOrError<Ref<ShaderModuleBase>> CreateShaderModule(DeviceBase* device, const char* source) {
+ResultOrError<Ref<ShaderModuleBase>> CreateShaderModule(
+ DeviceBase* device,
+ const char* source,
+ const std::vector<tint::wgsl::Extension>& internalExtensions) {
ShaderModuleWGSLDescriptor wgslDesc;
wgslDesc.code = source;
ShaderModuleDescriptor descriptor;
descriptor.nextInChain = &wgslDesc;
- return device->CreateShaderModule(&descriptor);
+ return device->CreateShaderModule(&descriptor, internalExtensions);
}
ResultOrError<Ref<BufferBase>> CreateBufferFromData(DeviceBase* device,
diff --git a/src/dawn/native/utils/WGPUHelpers.h b/src/dawn/native/utils/WGPUHelpers.h
index 5790df7..ae18e3f 100644
--- a/src/dawn/native/utils/WGPUHelpers.h
+++ b/src/dawn/native/utils/WGPUHelpers.h
@@ -37,9 +37,16 @@
#include "dawn/native/UsageValidationMode.h"
#include "dawn/native/dawn_platform.h"
+namespace tint::wgsl {
+enum class Extension : uint8_t;
+}
+
namespace dawn::native::utils {
-ResultOrError<Ref<ShaderModuleBase>> CreateShaderModule(DeviceBase* device, const char* source);
+ResultOrError<Ref<ShaderModuleBase>> CreateShaderModule(
+ DeviceBase* device,
+ const char* source,
+ const std::vector<tint::wgsl::Extension>& internalExtensions = {});
ResultOrError<Ref<BufferBase>> CreateBufferFromData(DeviceBase* device,
wgpu::BufferUsage usage,
diff --git a/src/dawn/native/vulkan/DeviceVk.cpp b/src/dawn/native/vulkan/DeviceVk.cpp
index 55879ea..0e504c2 100644
--- a/src/dawn/native/vulkan/DeviceVk.cpp
+++ b/src/dawn/native/vulkan/DeviceVk.cpp
@@ -168,15 +168,7 @@
Ref<Queue> queue;
DAWN_TRY_ASSIGN(queue, Queue::Create(this, &descriptor->defaultQueue, mMainQueueFamily));
- DAWN_TRY(DeviceBase::Initialize(std::move(queue)));
-
- if (HasFeature(Feature::DawnLoadResolveTexture)) {
- // TODO(42240662): Add a way to add additional extensions when compiling specific shader
- // modules only.
- EnableAdditionalWGSLExtension(tint::wgsl::Extension::kChromiumInternalInputAttachments);
- }
-
- return {};
+ return DeviceBase::Initialize(std::move(queue));
}
Device::~Device() {
@@ -220,9 +212,11 @@
}
ResultOrError<Ref<ShaderModuleBase>> Device::CreateShaderModuleImpl(
const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+ const std::vector<tint::wgsl::Extension>& internalExtensions,
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) {
- return ShaderModule::Create(this, descriptor, parseResult, compilationMessages);
+ return ShaderModule::Create(this, descriptor, internalExtensions, parseResult,
+ compilationMessages);
}
ResultOrError<Ref<SwapChainBase>> Device::CreateSwapChainImpl(Surface* surface,
SwapChainBase* previousSwapChain,
diff --git a/src/dawn/native/vulkan/DeviceVk.h b/src/dawn/native/vulkan/DeviceVk.h
index dddbd4a..ba12063 100644
--- a/src/dawn/native/vulkan/DeviceVk.h
+++ b/src/dawn/native/vulkan/DeviceVk.h
@@ -143,6 +143,7 @@
ResultOrError<Ref<SamplerBase>> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl(
const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+ const std::vector<tint::wgsl::Extension>& internalExtensions,
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) override;
ResultOrError<Ref<SwapChainBase>> CreateSwapChainImpl(
diff --git a/src/dawn/native/vulkan/ResolveTextureLoadingUtilsVk.cpp b/src/dawn/native/vulkan/ResolveTextureLoadingUtilsVk.cpp
index 3f00cad..6f4cf7c 100644
--- a/src/dawn/native/vulkan/ResolveTextureLoadingUtilsVk.cpp
+++ b/src/dawn/native/vulkan/ResolveTextureLoadingUtilsVk.cpp
@@ -116,7 +116,9 @@
// fragment shader's source will depend on pipeline key.
std::string fsCode = GenerateFS(pipelineKey);
Ref<ShaderModuleBase> fshaderModule;
- DAWN_TRY_ASSIGN(fshaderModule, utils::CreateShaderModule(device, fsCode.c_str()));
+ DAWN_TRY_ASSIGN(fshaderModule, utils::CreateShaderModule(
+ device, fsCode.c_str(),
+ {tint::wgsl::Extension::kChromiumInternalInputAttachments}));
FragmentState fragmentState = {};
fragmentState.module = fshaderModule.Get();
diff --git a/src/dawn/native/vulkan/ShaderModuleVk.cpp b/src/dawn/native/vulkan/ShaderModuleVk.cpp
index 5fbdd3c..2aa94bd 100644
--- a/src/dawn/native/vulkan/ShaderModuleVk.cpp
+++ b/src/dawn/native/vulkan/ShaderModuleVk.cpp
@@ -165,15 +165,18 @@
ResultOrError<Ref<ShaderModule>> ShaderModule::Create(
Device* device,
const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+ const std::vector<tint::wgsl::Extension>& internalExtensions,
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) {
- Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor));
+ Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor, internalExtensions));
DAWN_TRY(module->Initialize(parseResult, compilationMessages));
return module;
}
-ShaderModule::ShaderModule(Device* device, const UnpackedPtr<ShaderModuleDescriptor>& descriptor)
- : ShaderModuleBase(device, descriptor),
+ShaderModule::ShaderModule(Device* device,
+ const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+ std::vector<tint::wgsl::Extension> internalExtensions)
+ : ShaderModuleBase(device, descriptor, std::move(internalExtensions)),
mTransformedShaderModuleCache(
std::make_unique<ConcurrentTransformedShaderModuleCache>(device)) {}
diff --git a/src/dawn/native/vulkan/ShaderModuleVk.h b/src/dawn/native/vulkan/ShaderModuleVk.h
index 37fe631..aad21e5 100644
--- a/src/dawn/native/vulkan/ShaderModuleVk.h
+++ b/src/dawn/native/vulkan/ShaderModuleVk.h
@@ -33,6 +33,7 @@
#include <optional>
#include <string>
#include <utility>
+#include <vector>
#include "dawn/common/HashUtils.h"
#include "dawn/common/vulkan_platform.h"
@@ -76,6 +77,7 @@
static ResultOrError<Ref<ShaderModule>> Create(
Device* device,
const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+ const std::vector<tint::wgsl::Extension>& internalExtensions,
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages);
@@ -88,7 +90,9 @@
std::optional<uint32_t> maxSubgroupSizeForFullSubgroups);
private:
- ShaderModule(Device* device, const UnpackedPtr<ShaderModuleDescriptor>& descriptor);
+ ShaderModule(Device* device,
+ const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+ std::vector<tint::wgsl::Extension> internalExtensions);
~ShaderModule() override;
MaybeError Initialize(ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages);
diff --git a/src/dawn/tests/unittests/native/mocks/DeviceMock.h b/src/dawn/tests/unittests/native/mocks/DeviceMock.h
index ec7b899..0a3703c 100644
--- a/src/dawn/tests/unittests/native/mocks/DeviceMock.h
+++ b/src/dawn/tests/unittests/native/mocks/DeviceMock.h
@@ -29,6 +29,7 @@
#define SRC_DAWN_TESTS_UNITTESTS_NATIVE_MOCKS_DEVICEMOCK_H_
#include <memory>
+#include <vector>
#include "dawn/native/Device.h"
#include "dawn/native/Instance.h"
@@ -119,6 +120,7 @@
MOCK_METHOD(ResultOrError<Ref<ShaderModuleBase>>,
CreateShaderModuleImpl,
(const UnpackedPtr<ShaderModuleDescriptor>&,
+ const std::vector<tint::wgsl::Extension>&,
ShaderModuleParseResult*,
OwnedCompilationMessages*),
(override));
diff --git a/src/dawn/tests/unittests/native/mocks/ShaderModuleMock.cpp b/src/dawn/tests/unittests/native/mocks/ShaderModuleMock.cpp
index 1f13df0..2066b0e 100644
--- a/src/dawn/tests/unittests/native/mocks/ShaderModuleMock.cpp
+++ b/src/dawn/tests/unittests/native/mocks/ShaderModuleMock.cpp
@@ -35,7 +35,7 @@
ShaderModuleMock::ShaderModuleMock(DeviceMock* device,
const UnpackedPtr<ShaderModuleDescriptor>& descriptor)
- : ShaderModuleBase(device, descriptor) {
+ : ShaderModuleBase(device, descriptor, {}) {
ON_CALL(*this, DestroyImpl).WillByDefault([this] { this->ShaderModuleBase::DestroyImpl(); });
SetContentHash(ComputeContentHash());
@@ -48,7 +48,7 @@
DeviceMock* device,
const UnpackedPtr<ShaderModuleDescriptor>& descriptor) {
ShaderModuleParseResult parseResult;
- ValidateAndParseShaderModule(device, descriptor, &parseResult, nullptr).AcquireSuccess();
+ ValidateAndParseShaderModule(device, descriptor, {}, &parseResult, nullptr).AcquireSuccess();
Ref<ShaderModuleMock> shaderModule =
AcquireRef(new NiceMock<ShaderModuleMock>(device, descriptor));