dawn: Allow internal extensions when creating shader module.

Bug: 42240662
Change-Id: I64fecb2b019337586fa89a8897def1e8a70b2e52
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/191781
Commit-Queue: Quyen Le <lehoangquyen@chromium.org>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Austin Eng <enga@chromium.org>
diff --git a/src/dawn/native/Device.cpp b/src/dawn/native/Device.cpp
index 713a6eb..e0ac3fd 100644
--- a/src/dawn/native/Device.cpp
+++ b/src/dawn/native/Device.cpp
@@ -1161,11 +1161,13 @@
 
 ResultOrError<Ref<ShaderModuleBase>> DeviceBase::GetOrCreateShaderModule(
     const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+    const std::vector<tint::wgsl::Extension>& internalExtensions,
     ShaderModuleParseResult* parseResult,
     std::unique_ptr<OwnedCompilationMessages>* compilationMessages) {
     DAWN_ASSERT(parseResult != nullptr);
 
-    ShaderModuleBase blueprint(this, descriptor, ApiObjectBase::kUntrackedByDevice);
+    ShaderModuleBase blueprint(this, descriptor, internalExtensions,
+                               ApiObjectBase::kUntrackedByDevice);
 
     const size_t blueprintHash = blueprint.ComputeContentHash();
     blueprint.SetContentHash(blueprintHash);
@@ -1178,13 +1180,14 @@
                 // lookup in the cache without validating and parsing. We need the parsed module
                 // now.
                 DAWN_ASSERT(!IsValidationEnabled());
-                DAWN_TRY(
-                    ValidateAndParseShaderModule(this, descriptor, parseResult, unownedMessages));
+                DAWN_TRY(ValidateAndParseShaderModule(this, descriptor, internalExtensions,
+                                                      parseResult, unownedMessages));
             }
 
             auto resultOrError = [&]() -> ResultOrError<Ref<ShaderModuleBase>> {
                 SCOPED_DAWN_HISTOGRAM_TIMER_MICROS(GetPlatform(), "CreateShaderModuleUS");
-                return CreateShaderModuleImpl(descriptor, parseResult, unownedMessages);
+                return CreateShaderModuleImpl(descriptor, internalExtensions, parseResult,
+                                              unownedMessages);
             }();
             DAWN_HISTOGRAM_BOOLEAN(GetPlatform(), "CreateShaderModuleSuccess",
                                    resultOrError.IsSuccess());
@@ -1528,7 +1531,8 @@
 
     std::unique_ptr<OwnedCompilationMessages> compilationMessages(
         std::make_unique<OwnedCompilationMessages>());
-    auto resultOrError = CreateShaderModule(descriptor, &compilationMessages);
+    auto resultOrError =
+        CreateShaderModule(descriptor, /*internalExtensions=*/{}, &compilationMessages);
     if (resultOrError.IsSuccess()) {
         Ref<ShaderModuleBase> result = resultOrError.AcquireSuccess();
         EmitCompilationLog(result.Get());
@@ -1827,10 +1831,6 @@
     return mWGSLAllowedFeatures;
 }
 
-void DeviceBase::EnableAdditionalWGSLExtension(tint::wgsl::Extension extension) {
-    mWGSLAllowedFeatures.extensions.insert(extension);
-}
-
 bool DeviceBase::IsValidationEnabled() const {
     return !IsToggleEnabled(Toggle::SkipValidation);
 }
@@ -2237,6 +2237,7 @@
 
 ResultOrError<Ref<ShaderModuleBase>> DeviceBase::CreateShaderModule(
     const ShaderModuleDescriptor* descriptor,
+    const std::vector<tint::wgsl::Extension>& internalExtensions,
     std::unique_ptr<OwnedCompilationMessages>* compilationMessages) {
     DAWN_TRY(ValidateIsAlive());
 
@@ -2250,14 +2251,14 @@
         DAWN_TRY_ASSIGN_CONTEXT(unpacked, ValidateAndUnpack(descriptor),
                                 "validating and unpacking %s", descriptor);
         DAWN_TRY_CONTEXT(ValidateAndParseShaderModule(
-                             this, unpacked, &parseResult,
+                             this, unpacked, internalExtensions, &parseResult,
                              compilationMessages ? compilationMessages->get() : nullptr),
                          "validating %s", descriptor);
     } else {
         unpacked = Unpack(descriptor);
     }
 
-    return GetOrCreateShaderModule(unpacked, &parseResult, compilationMessages);
+    return GetOrCreateShaderModule(unpacked, internalExtensions, &parseResult, compilationMessages);
 }
 
 ResultOrError<Ref<SwapChainBase>> DeviceBase::CreateSwapChain(
diff --git a/src/dawn/native/Device.h b/src/dawn/native/Device.h
index 2d395a4..3a9189f 100644
--- a/src/dawn/native/Device.h
+++ b/src/dawn/native/Device.h
@@ -185,6 +185,7 @@
 
     ResultOrError<Ref<ShaderModuleBase>> GetOrCreateShaderModule(
         const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+        const std::vector<tint::wgsl::Extension>& internalExtensions,
         ShaderModuleParseResult* parseResult,
         std::unique_ptr<OwnedCompilationMessages>* compilationMessages);
 
@@ -228,6 +229,7 @@
     ResultOrError<Ref<SamplerBase>> CreateSampler(const SamplerDescriptor* descriptor = nullptr);
     ResultOrError<Ref<ShaderModuleBase>> CreateShaderModule(
         const ShaderModuleDescriptor* descriptor,
+        const std::vector<tint::wgsl::Extension>& internalExtensions = {},
         std::unique_ptr<OwnedCompilationMessages>* compilationMessages = nullptr);
     // Deprecated: this was the way to create a SwapChain when it was explicitly manipulated by the
     // end user.
@@ -469,8 +471,6 @@
     void DestroyObjects();
     void Destroy();
 
-    void EnableAdditionalWGSLExtension(tint::wgsl::Extension extension);
-
     virtual MaybeError GetAHardwareBufferPropertiesImpl(
         void* handle,
         AHardwareBufferProperties* properties) const {
@@ -501,6 +501,7 @@
         const SamplerDescriptor* descriptor) = 0;
     virtual ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl(
         const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+        const std::vector<tint::wgsl::Extension>& internalExtensions,
         ShaderModuleParseResult* parseResult,
         OwnedCompilationMessages* compilationMessages) = 0;
     // Note that previousSwapChain may be nullptr, or come from a different backend.
diff --git a/src/dawn/native/ShaderModule.cpp b/src/dawn/native/ShaderModule.cpp
index 123633c..6ba5f1f 100644
--- a/src/dawn/native/ShaderModule.cpp
+++ b/src/dawn/native/ShaderModule.cpp
@@ -348,9 +348,12 @@
 
 ResultOrError<tint::Program> ParseWGSL(const tint::Source::File* file,
                                        const tint::wgsl::AllowedFeatures& allowedFeatures,
+                                       const std::vector<tint::wgsl::Extension>& internalExtensions,
                                        OwnedCompilationMessages* outMessages) {
     tint::wgsl::reader::Options options;
     options.allowed_features = allowedFeatures;
+    options.allowed_features.extensions.insert(internalExtensions.begin(),
+                                               internalExtensions.end());
     tint::Program program = tint::wgsl::reader::Parse(file, options);
     if (outMessages != nullptr) {
         DAWN_TRY(outMessages->AddMessages(program.Diagnostics()));
@@ -1071,10 +1074,12 @@
     return tintProgram != nullptr;
 }
 
-MaybeError ValidateAndParseShaderModule(DeviceBase* device,
-                                        const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
-                                        ShaderModuleParseResult* parseResult,
-                                        OwnedCompilationMessages* outMessages) {
+MaybeError ValidateAndParseShaderModule(
+    DeviceBase* device,
+    const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+    const std::vector<tint::wgsl::Extension>& internalExtensions,
+    ShaderModuleParseResult* parseResult,
+    OwnedCompilationMessages* outMessages) {
     DAWN_ASSERT(parseResult != nullptr);
 
     wgpu::SType moduleType;
@@ -1148,8 +1153,8 @@
     }
 
     tint::Program program;
-    DAWN_TRY_ASSIGN(program,
-                    ParseWGSL(tintFile.get(), device->GetWGSLAllowedFeatures(), outMessages));
+    DAWN_TRY_ASSIGN(program, ParseWGSL(tintFile.get(), device->GetWGSLAllowedFeatures(),
+                                       internalExtensions, outMessages));
 
     parseResult->tintProgram = AcquireRef(new TintProgram(std::move(program), std::move(tintFile)));
 
@@ -1305,8 +1310,11 @@
 
 ShaderModuleBase::ShaderModuleBase(DeviceBase* device,
                                    const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+                                   std::vector<tint::wgsl::Extension> internalExtensions,
                                    ApiObjectBase::UntrackedByDeviceTag tag)
-    : Base(device, descriptor->label), mType(Type::Undefined) {
+    : Base(device, descriptor->label),
+      mType(Type::Undefined),
+      mInternalExtensions(std::move(internalExtensions)) {
     if (auto* spirvDesc = descriptor.Get<ShaderModuleSPIRVDescriptor>()) {
         mType = Type::Spirv;
         mOriginalSpirv.assign(spirvDesc->code, spirvDesc->code + spirvDesc->codeSize);
@@ -1323,8 +1331,9 @@
 }
 
 ShaderModuleBase::ShaderModuleBase(DeviceBase* device,
-                                   const UnpackedPtr<ShaderModuleDescriptor>& descriptor)
-    : ShaderModuleBase(device, descriptor, kUntrackedByDevice) {
+                                   const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+                                   std::vector<tint::wgsl::Extension> internalExtensions)
+    : ShaderModuleBase(device, descriptor, std::move(internalExtensions), kUntrackedByDevice) {
     GetObjectTrackingList()->Track(this);
 }
 
@@ -1419,7 +1428,8 @@
         }
 
         ShaderModuleParseResult parseResult;
-        ValidateAndParseShaderModule(GetDevice(), Unpack(&descriptor), &parseResult,
+        ValidateAndParseShaderModule(GetDevice(), Unpack(&descriptor), mInternalExtensions,
+                                     &parseResult,
                                      /*compilationMessages=*/nullptr)
             .AcquireSuccess();
         DAWN_ASSERT(parseResult.tintProgram != nullptr);
diff --git a/src/dawn/native/ShaderModule.h b/src/dawn/native/ShaderModule.h
index 339a56b..18ffe10 100644
--- a/src/dawn/native/ShaderModule.h
+++ b/src/dawn/native/ShaderModule.h
@@ -130,10 +130,12 @@
     std::string name;
 };
 
-MaybeError ValidateAndParseShaderModule(DeviceBase* device,
-                                        const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
-                                        ShaderModuleParseResult* parseResult,
-                                        OwnedCompilationMessages* outMessages);
+MaybeError ValidateAndParseShaderModule(
+    DeviceBase* device,
+    const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+    const std::vector<tint::wgsl::Extension>& internalExtensions,
+    ShaderModuleParseResult* parseResult,
+    OwnedCompilationMessages* outMessages);
 MaybeError ValidateCompatibilityWithPipelineLayout(DeviceBase* device,
                                                    const EntryPointMetadata& entryPoint,
                                                    const PipelineLayoutBase* layout);
@@ -284,8 +286,11 @@
     using Base = RefCountedWithExternalCountBase<ApiObjectBase>;
     ShaderModuleBase(DeviceBase* device,
                      const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+                     std::vector<tint::wgsl::Extension> internalExtensions,
                      ApiObjectBase::UntrackedByDeviceTag tag);
-    ShaderModuleBase(DeviceBase* device, const UnpackedPtr<ShaderModuleDescriptor>& descriptor);
+    ShaderModuleBase(DeviceBase* device,
+                     const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+                     std::vector<tint::wgsl::Extension> internalExtensions);
     ~ShaderModuleBase() override;
 
     static Ref<ShaderModuleBase> MakeError(DeviceBase* device, const char* label);
@@ -361,6 +366,8 @@
     MutexProtected<TintData> mTintData;
 
     std::unique_ptr<OwnedCompilationMessages> mCompilationMessages;
+
+    const std::vector<tint::wgsl::Extension> mInternalExtensions;
 };
 
 }  // namespace dawn::native
diff --git a/src/dawn/native/d3d11/DeviceD3D11.cpp b/src/dawn/native/d3d11/DeviceD3D11.cpp
index 2430bb6..757fad9 100644
--- a/src/dawn/native/d3d11/DeviceD3D11.cpp
+++ b/src/dawn/native/d3d11/DeviceD3D11.cpp
@@ -211,9 +211,11 @@
 
 ResultOrError<Ref<ShaderModuleBase>> Device::CreateShaderModuleImpl(
     const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+    const std::vector<tint::wgsl::Extension>& internalExtensions,
     ShaderModuleParseResult* parseResult,
     OwnedCompilationMessages* compilationMessages) {
-    return ShaderModule::Create(this, descriptor, parseResult, compilationMessages);
+    return ShaderModule::Create(this, descriptor, internalExtensions, parseResult,
+                                compilationMessages);
 }
 
 ResultOrError<Ref<SwapChainBase>> Device::CreateSwapChainImpl(Surface* surface,
diff --git a/src/dawn/native/d3d11/DeviceD3D11.h b/src/dawn/native/d3d11/DeviceD3D11.h
index 332dcdb..de72f45 100644
--- a/src/dawn/native/d3d11/DeviceD3D11.h
+++ b/src/dawn/native/d3d11/DeviceD3D11.h
@@ -112,6 +112,7 @@
     ResultOrError<Ref<SamplerBase>> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
     ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl(
         const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+        const std::vector<tint::wgsl::Extension>& internalExtensions,
         ShaderModuleParseResult* parseResult,
         OwnedCompilationMessages* compilationMessages) override;
     ResultOrError<Ref<SwapChainBase>> CreateSwapChainImpl(
diff --git a/src/dawn/native/d3d11/ShaderModuleD3D11.cpp b/src/dawn/native/d3d11/ShaderModuleD3D11.cpp
index c5469c5..58b214e 100644
--- a/src/dawn/native/d3d11/ShaderModuleD3D11.cpp
+++ b/src/dawn/native/d3d11/ShaderModuleD3D11.cpp
@@ -56,15 +56,18 @@
 ResultOrError<Ref<ShaderModule>> ShaderModule::Create(
     Device* device,
     const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+    const std::vector<tint::wgsl::Extension>& internalExtensions,
     ShaderModuleParseResult* parseResult,
     OwnedCompilationMessages* compilationMessages) {
-    Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor));
+    Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor, internalExtensions));
     DAWN_TRY(module->Initialize(parseResult, compilationMessages));
     return module;
 }
 
-ShaderModule::ShaderModule(Device* device, const UnpackedPtr<ShaderModuleDescriptor>& descriptor)
-    : ShaderModuleBase(device, descriptor) {}
+ShaderModule::ShaderModule(Device* device,
+                           const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+                           std::vector<tint::wgsl::Extension> internalExtensions)
+    : ShaderModuleBase(device, descriptor, std::move(internalExtensions)) {}
 
 MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult,
                                     OwnedCompilationMessages* compilationMessages) {
diff --git a/src/dawn/native/d3d11/ShaderModuleD3D11.h b/src/dawn/native/d3d11/ShaderModuleD3D11.h
index b436aee..a99f963 100644
--- a/src/dawn/native/d3d11/ShaderModuleD3D11.h
+++ b/src/dawn/native/d3d11/ShaderModuleD3D11.h
@@ -30,6 +30,7 @@
 
 #include <optional>
 #include <string>
+#include <vector>
 
 #include "dawn/native/Blob.h"
 #include "dawn/native/Serializable.h"
@@ -51,6 +52,7 @@
     static ResultOrError<Ref<ShaderModule>> Create(
         Device* device,
         const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+        const std::vector<tint::wgsl::Extension>& internalExtensions,
         ShaderModuleParseResult* parseResult,
         OwnedCompilationMessages* compilationMessages);
 
@@ -64,7 +66,9 @@
         const std::optional<tint::hlsl::writer::PixelLocalOptions>& pixelLocalOptions = {});
 
   private:
-    ShaderModule(Device* device, const UnpackedPtr<ShaderModuleDescriptor>& descriptor);
+    ShaderModule(Device* device,
+                 const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+                 std::vector<tint::wgsl::Extension> internalExtensions);
     ~ShaderModule() override = default;
     MaybeError Initialize(ShaderModuleParseResult* parseResult,
                           OwnedCompilationMessages* compilationMessages);
diff --git a/src/dawn/native/d3d12/DeviceD3D12.cpp b/src/dawn/native/d3d12/DeviceD3D12.cpp
index a46aad7..04613dd 100644
--- a/src/dawn/native/d3d12/DeviceD3D12.cpp
+++ b/src/dawn/native/d3d12/DeviceD3D12.cpp
@@ -395,9 +395,11 @@
 }
 ResultOrError<Ref<ShaderModuleBase>> Device::CreateShaderModuleImpl(
     const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+    const std::vector<tint::wgsl::Extension>& internalExtensions,
     ShaderModuleParseResult* parseResult,
     OwnedCompilationMessages* compilationMessages) {
-    return ShaderModule::Create(this, descriptor, parseResult, compilationMessages);
+    return ShaderModule::Create(this, descriptor, internalExtensions, parseResult,
+                                compilationMessages);
 }
 ResultOrError<Ref<SwapChainBase>> Device::CreateSwapChainImpl(Surface* surface,
                                                               SwapChainBase* previousSwapChain,
diff --git a/src/dawn/native/d3d12/DeviceD3D12.h b/src/dawn/native/d3d12/DeviceD3D12.h
index a46bf8f..bced302 100644
--- a/src/dawn/native/d3d12/DeviceD3D12.h
+++ b/src/dawn/native/d3d12/DeviceD3D12.h
@@ -191,6 +191,7 @@
     ResultOrError<Ref<SamplerBase>> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
     ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl(
         const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+        const std::vector<tint::wgsl::Extension>& internalExtensions,
         ShaderModuleParseResult* parseResult,
         OwnedCompilationMessages* compilationMessages) override;
     ResultOrError<Ref<SwapChainBase>> CreateSwapChainImpl(
diff --git a/src/dawn/native/d3d12/ShaderModuleD3D12.cpp b/src/dawn/native/d3d12/ShaderModuleD3D12.cpp
index c5b7401..099f258 100644
--- a/src/dawn/native/d3d12/ShaderModuleD3D12.cpp
+++ b/src/dawn/native/d3d12/ShaderModuleD3D12.cpp
@@ -106,15 +106,18 @@
 ResultOrError<Ref<ShaderModule>> ShaderModule::Create(
     Device* device,
     const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+    const std::vector<tint::wgsl::Extension>& internalExtensions,
     ShaderModuleParseResult* parseResult,
     OwnedCompilationMessages* compilationMessages) {
-    Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor));
+    Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor, internalExtensions));
     DAWN_TRY(module->Initialize(parseResult, compilationMessages));
     return module;
 }
 
-ShaderModule::ShaderModule(Device* device, const UnpackedPtr<ShaderModuleDescriptor>& descriptor)
-    : ShaderModuleBase(device, descriptor) {}
+ShaderModule::ShaderModule(Device* device,
+                           const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+                           std::vector<tint::wgsl::Extension> internalExtensions)
+    : ShaderModuleBase(device, descriptor, std::move(internalExtensions)) {}
 
 MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult,
                                     OwnedCompilationMessages* compilationMessages) {
diff --git a/src/dawn/native/d3d12/ShaderModuleD3D12.h b/src/dawn/native/d3d12/ShaderModuleD3D12.h
index f98050a..f5e143d 100644
--- a/src/dawn/native/d3d12/ShaderModuleD3D12.h
+++ b/src/dawn/native/d3d12/ShaderModuleD3D12.h
@@ -30,6 +30,7 @@
 
 #include <optional>
 #include <string>
+#include <vector>
 
 #include "dawn/native/Blob.h"
 #include "dawn/native/Serializable.h"
@@ -51,6 +52,7 @@
     static ResultOrError<Ref<ShaderModule>> Create(
         Device* device,
         const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+        const std::vector<tint::wgsl::Extension>& internalExtensions,
         ShaderModuleParseResult* parseResult,
         OwnedCompilationMessages* compilationMessages);
 
@@ -64,7 +66,9 @@
         std::optional<uint32_t> maxSubgroupSizeForFullSubgroups = std::nullopt);
 
   private:
-    ShaderModule(Device* device, const UnpackedPtr<ShaderModuleDescriptor>& descriptor);
+    ShaderModule(Device* device,
+                 const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+                 std::vector<tint::wgsl::Extension> internalExtensions);
     ~ShaderModule() override = default;
     MaybeError Initialize(ShaderModuleParseResult* parseResult,
                           OwnedCompilationMessages* compilationMessages);
diff --git a/src/dawn/native/metal/DeviceMTL.h b/src/dawn/native/metal/DeviceMTL.h
index 44e8f8f..4ac2200 100644
--- a/src/dawn/native/metal/DeviceMTL.h
+++ b/src/dawn/native/metal/DeviceMTL.h
@@ -110,6 +110,7 @@
     ResultOrError<Ref<SamplerBase>> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
     ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl(
         const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+        const std::vector<tint::wgsl::Extension>& internalExtensions,
         ShaderModuleParseResult* parseResult,
         OwnedCompilationMessages* compilationMessages) override;
     ResultOrError<Ref<SwapChainBase>> CreateSwapChainImpl(
diff --git a/src/dawn/native/metal/DeviceMTL.mm b/src/dawn/native/metal/DeviceMTL.mm
index 2a58e4e..73c810e 100644
--- a/src/dawn/native/metal/DeviceMTL.mm
+++ b/src/dawn/native/metal/DeviceMTL.mm
@@ -225,9 +225,11 @@
 }
 ResultOrError<Ref<ShaderModuleBase>> Device::CreateShaderModuleImpl(
     const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+    const std::vector<tint::wgsl::Extension>& internalExtensions,
     ShaderModuleParseResult* parseResult,
     OwnedCompilationMessages* compilationMessages) {
-    return ShaderModule::Create(this, descriptor, parseResult, compilationMessages);
+    return ShaderModule::Create(this, descriptor, internalExtensions, parseResult,
+                                compilationMessages);
 }
 ResultOrError<Ref<SwapChainBase>> Device::CreateSwapChainImpl(Surface* surface,
                                                               SwapChainBase* previousSwapChain,
diff --git a/src/dawn/native/metal/ShaderModuleMTL.h b/src/dawn/native/metal/ShaderModuleMTL.h
index c04327f..9f34270 100644
--- a/src/dawn/native/metal/ShaderModuleMTL.h
+++ b/src/dawn/native/metal/ShaderModuleMTL.h
@@ -53,6 +53,7 @@
     static ResultOrError<Ref<ShaderModule>> Create(
         Device* device,
         const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+        const std::vector<tint::wgsl::Extension>& internalExtensions,
         ShaderModuleParseResult* parseResult,
         OwnedCompilationMessages* compilationMessages);
 
@@ -73,7 +74,9 @@
         std::optional<uint32_t> maxSubgroupSizeForFullSubgroups = std::nullopt);
 
   private:
-    ShaderModule(Device* device, const UnpackedPtr<ShaderModuleDescriptor>& descriptor);
+    ShaderModule(Device* device,
+                 const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+                 std::vector<tint::wgsl::Extension> internalExtensions);
     ~ShaderModule() override;
     MaybeError Initialize(ShaderModuleParseResult* parseResult,
                           OwnedCompilationMessages* compilationMessages);
diff --git a/src/dawn/native/metal/ShaderModuleMTL.mm b/src/dawn/native/metal/ShaderModuleMTL.mm
index be8d15a..70c4282 100644
--- a/src/dawn/native/metal/ShaderModuleMTL.mm
+++ b/src/dawn/native/metal/ShaderModuleMTL.mm
@@ -92,15 +92,18 @@
 ResultOrError<Ref<ShaderModule>> ShaderModule::Create(
     Device* device,
     const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+    const std::vector<tint::wgsl::Extension>& internalExtensions,
     ShaderModuleParseResult* parseResult,
     OwnedCompilationMessages* compilationMessages) {
-    Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor));
+    Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor, internalExtensions));
     DAWN_TRY(module->Initialize(parseResult, compilationMessages));
     return module;
 }
 
-ShaderModule::ShaderModule(Device* device, const UnpackedPtr<ShaderModuleDescriptor>& descriptor)
-    : ShaderModuleBase(device, descriptor) {}
+ShaderModule::ShaderModule(Device* device,
+                           const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+                           std::vector<tint::wgsl::Extension> internalExtensions)
+    : ShaderModuleBase(device, descriptor, std::move(internalExtensions)) {}
 
 ShaderModule::~ShaderModule() = default;
 
diff --git a/src/dawn/native/null/DeviceNull.cpp b/src/dawn/native/null/DeviceNull.cpp
index 6636f88..d8a1e2f 100644
--- a/src/dawn/native/null/DeviceNull.cpp
+++ b/src/dawn/native/null/DeviceNull.cpp
@@ -227,9 +227,10 @@
 }
 ResultOrError<Ref<ShaderModuleBase>> Device::CreateShaderModuleImpl(
     const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+    const std::vector<tint::wgsl::Extension>& internalExtensions,
     ShaderModuleParseResult* parseResult,
     OwnedCompilationMessages* compilationMessages) {
-    Ref<ShaderModule> module = AcquireRef(new ShaderModule(this, descriptor));
+    Ref<ShaderModule> module = AcquireRef(new ShaderModule(this, descriptor, internalExtensions));
     DAWN_TRY(module->Initialize(parseResult, compilationMessages));
     return module;
 }
diff --git a/src/dawn/native/null/DeviceNull.h b/src/dawn/native/null/DeviceNull.h
index a9190ca..954ba68 100644
--- a/src/dawn/native/null/DeviceNull.h
+++ b/src/dawn/native/null/DeviceNull.h
@@ -160,6 +160,7 @@
     ResultOrError<Ref<SamplerBase>> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
     ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl(
         const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+        const std::vector<tint::wgsl::Extension>& internalExtensions,
         ShaderModuleParseResult* parseResult,
         OwnedCompilationMessages* compilationMessages) override;
     ResultOrError<Ref<SwapChainBase>> CreateSwapChainImpl(
diff --git a/src/dawn/native/opengl/DeviceGL.cpp b/src/dawn/native/opengl/DeviceGL.cpp
index e9842fc..c8b3db8 100644
--- a/src/dawn/native/opengl/DeviceGL.cpp
+++ b/src/dawn/native/opengl/DeviceGL.cpp
@@ -257,9 +257,11 @@
 }
 ResultOrError<Ref<ShaderModuleBase>> Device::CreateShaderModuleImpl(
     const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+    const std::vector<tint::wgsl::Extension>& internalExtensions,
     ShaderModuleParseResult* parseResult,
     OwnedCompilationMessages* compilationMessages) {
-    return ShaderModule::Create(this, descriptor, parseResult, compilationMessages);
+    return ShaderModule::Create(this, descriptor, internalExtensions, parseResult,
+                                compilationMessages);
 }
 ResultOrError<Ref<SwapChainBase>> Device::CreateSwapChainImpl(Surface* surface,
                                                               SwapChainBase* previousSwapChain,
diff --git a/src/dawn/native/opengl/DeviceGL.h b/src/dawn/native/opengl/DeviceGL.h
index 1e943fe..f4418f6 100644
--- a/src/dawn/native/opengl/DeviceGL.h
+++ b/src/dawn/native/opengl/DeviceGL.h
@@ -29,6 +29,7 @@
 #define SRC_DAWN_NATIVE_OPENGL_DEVICEGL_H_
 
 #include <memory>
+#include <vector>
 
 #include "dawn/native/dawn_platform.h"
 
@@ -121,6 +122,7 @@
     ResultOrError<Ref<SamplerBase>> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
     ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl(
         const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+        const std::vector<tint::wgsl::Extension>& internalExtensions,
         ShaderModuleParseResult* parseResult,
         OwnedCompilationMessages* compilationMessages) override;
     ResultOrError<Ref<SwapChainBase>> CreateSwapChainImpl(
diff --git a/src/dawn/native/opengl/ShaderModuleGL.cpp b/src/dawn/native/opengl/ShaderModuleGL.cpp
index fa031a2..9a8bce8 100644
--- a/src/dawn/native/opengl/ShaderModuleGL.cpp
+++ b/src/dawn/native/opengl/ShaderModuleGL.cpp
@@ -146,15 +146,18 @@
 ResultOrError<Ref<ShaderModule>> ShaderModule::Create(
     Device* device,
     const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+    const std::vector<tint::wgsl::Extension>& internalExtensions,
     ShaderModuleParseResult* parseResult,
     OwnedCompilationMessages* compilationMessages) {
-    Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor));
+    Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor, internalExtensions));
     DAWN_TRY(module->Initialize(parseResult, compilationMessages));
     return module;
 }
 
-ShaderModule::ShaderModule(Device* device, const UnpackedPtr<ShaderModuleDescriptor>& descriptor)
-    : ShaderModuleBase(device, descriptor) {}
+ShaderModule::ShaderModule(Device* device,
+                           const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+                           std::vector<tint::wgsl::Extension> internalExtensions)
+    : ShaderModuleBase(device, descriptor, std::move(internalExtensions)) {}
 
 MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult,
                                     OwnedCompilationMessages* compilationMessages) {
diff --git a/src/dawn/native/opengl/ShaderModuleGL.h b/src/dawn/native/opengl/ShaderModuleGL.h
index 2c04be3..133eb44 100644
--- a/src/dawn/native/opengl/ShaderModuleGL.h
+++ b/src/dawn/native/opengl/ShaderModuleGL.h
@@ -83,6 +83,7 @@
     static ResultOrError<Ref<ShaderModule>> Create(
         Device* device,
         const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+        const std::vector<tint::wgsl::Extension>& internalExtensions,
         ShaderModuleParseResult* parseResult,
         OwnedCompilationMessages* compilationMessages);
 
@@ -99,7 +100,9 @@
                                         BindingPointToFunctionAndOffset* bindingPointToData) const;
 
   private:
-    ShaderModule(Device* device, const UnpackedPtr<ShaderModuleDescriptor>& descriptor);
+    ShaderModule(Device* device,
+                 const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+                 std::vector<tint::wgsl::Extension> internalExtensions);
     ~ShaderModule() override = default;
     MaybeError Initialize(ShaderModuleParseResult* parseResult,
                           OwnedCompilationMessages* compilationMessages);
diff --git a/src/dawn/native/utils/WGPUHelpers.cpp b/src/dawn/native/utils/WGPUHelpers.cpp
index 06aedae..b1606ac 100644
--- a/src/dawn/native/utils/WGPUHelpers.cpp
+++ b/src/dawn/native/utils/WGPUHelpers.cpp
@@ -47,12 +47,15 @@
 
 namespace dawn::native::utils {
 
-ResultOrError<Ref<ShaderModuleBase>> CreateShaderModule(DeviceBase* device, const char* source) {
+ResultOrError<Ref<ShaderModuleBase>> CreateShaderModule(
+    DeviceBase* device,
+    const char* source,
+    const std::vector<tint::wgsl::Extension>& internalExtensions) {
     ShaderModuleWGSLDescriptor wgslDesc;
     wgslDesc.code = source;
     ShaderModuleDescriptor descriptor;
     descriptor.nextInChain = &wgslDesc;
-    return device->CreateShaderModule(&descriptor);
+    return device->CreateShaderModule(&descriptor, internalExtensions);
 }
 
 ResultOrError<Ref<BufferBase>> CreateBufferFromData(DeviceBase* device,
diff --git a/src/dawn/native/utils/WGPUHelpers.h b/src/dawn/native/utils/WGPUHelpers.h
index 5790df7..ae18e3f 100644
--- a/src/dawn/native/utils/WGPUHelpers.h
+++ b/src/dawn/native/utils/WGPUHelpers.h
@@ -37,9 +37,16 @@
 #include "dawn/native/UsageValidationMode.h"
 #include "dawn/native/dawn_platform.h"
 
+namespace tint::wgsl {
+enum class Extension : uint8_t;
+}
+
 namespace dawn::native::utils {
 
-ResultOrError<Ref<ShaderModuleBase>> CreateShaderModule(DeviceBase* device, const char* source);
+ResultOrError<Ref<ShaderModuleBase>> CreateShaderModule(
+    DeviceBase* device,
+    const char* source,
+    const std::vector<tint::wgsl::Extension>& internalExtensions = {});
 
 ResultOrError<Ref<BufferBase>> CreateBufferFromData(DeviceBase* device,
                                                     wgpu::BufferUsage usage,
diff --git a/src/dawn/native/vulkan/DeviceVk.cpp b/src/dawn/native/vulkan/DeviceVk.cpp
index 55879ea..0e504c2 100644
--- a/src/dawn/native/vulkan/DeviceVk.cpp
+++ b/src/dawn/native/vulkan/DeviceVk.cpp
@@ -168,15 +168,7 @@
     Ref<Queue> queue;
     DAWN_TRY_ASSIGN(queue, Queue::Create(this, &descriptor->defaultQueue, mMainQueueFamily));
 
-    DAWN_TRY(DeviceBase::Initialize(std::move(queue)));
-
-    if (HasFeature(Feature::DawnLoadResolveTexture)) {
-        // TODO(42240662): Add a way to add additional extensions when compiling specific shader
-        // modules only.
-        EnableAdditionalWGSLExtension(tint::wgsl::Extension::kChromiumInternalInputAttachments);
-    }
-
-    return {};
+    return DeviceBase::Initialize(std::move(queue));
 }
 
 Device::~Device() {
@@ -220,9 +212,11 @@
 }
 ResultOrError<Ref<ShaderModuleBase>> Device::CreateShaderModuleImpl(
     const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+    const std::vector<tint::wgsl::Extension>& internalExtensions,
     ShaderModuleParseResult* parseResult,
     OwnedCompilationMessages* compilationMessages) {
-    return ShaderModule::Create(this, descriptor, parseResult, compilationMessages);
+    return ShaderModule::Create(this, descriptor, internalExtensions, parseResult,
+                                compilationMessages);
 }
 ResultOrError<Ref<SwapChainBase>> Device::CreateSwapChainImpl(Surface* surface,
                                                               SwapChainBase* previousSwapChain,
diff --git a/src/dawn/native/vulkan/DeviceVk.h b/src/dawn/native/vulkan/DeviceVk.h
index dddbd4a..ba12063 100644
--- a/src/dawn/native/vulkan/DeviceVk.h
+++ b/src/dawn/native/vulkan/DeviceVk.h
@@ -143,6 +143,7 @@
     ResultOrError<Ref<SamplerBase>> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
     ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl(
         const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+        const std::vector<tint::wgsl::Extension>& internalExtensions,
         ShaderModuleParseResult* parseResult,
         OwnedCompilationMessages* compilationMessages) override;
     ResultOrError<Ref<SwapChainBase>> CreateSwapChainImpl(
diff --git a/src/dawn/native/vulkan/ResolveTextureLoadingUtilsVk.cpp b/src/dawn/native/vulkan/ResolveTextureLoadingUtilsVk.cpp
index 3f00cad..6f4cf7c 100644
--- a/src/dawn/native/vulkan/ResolveTextureLoadingUtilsVk.cpp
+++ b/src/dawn/native/vulkan/ResolveTextureLoadingUtilsVk.cpp
@@ -116,7 +116,9 @@
     // fragment shader's source will depend on pipeline key.
     std::string fsCode = GenerateFS(pipelineKey);
     Ref<ShaderModuleBase> fshaderModule;
-    DAWN_TRY_ASSIGN(fshaderModule, utils::CreateShaderModule(device, fsCode.c_str()));
+    DAWN_TRY_ASSIGN(fshaderModule, utils::CreateShaderModule(
+                                       device, fsCode.c_str(),
+                                       {tint::wgsl::Extension::kChromiumInternalInputAttachments}));
 
     FragmentState fragmentState = {};
     fragmentState.module = fshaderModule.Get();
diff --git a/src/dawn/native/vulkan/ShaderModuleVk.cpp b/src/dawn/native/vulkan/ShaderModuleVk.cpp
index 5fbdd3c..2aa94bd 100644
--- a/src/dawn/native/vulkan/ShaderModuleVk.cpp
+++ b/src/dawn/native/vulkan/ShaderModuleVk.cpp
@@ -165,15 +165,18 @@
 ResultOrError<Ref<ShaderModule>> ShaderModule::Create(
     Device* device,
     const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+    const std::vector<tint::wgsl::Extension>& internalExtensions,
     ShaderModuleParseResult* parseResult,
     OwnedCompilationMessages* compilationMessages) {
-    Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor));
+    Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor, internalExtensions));
     DAWN_TRY(module->Initialize(parseResult, compilationMessages));
     return module;
 }
 
-ShaderModule::ShaderModule(Device* device, const UnpackedPtr<ShaderModuleDescriptor>& descriptor)
-    : ShaderModuleBase(device, descriptor),
+ShaderModule::ShaderModule(Device* device,
+                           const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+                           std::vector<tint::wgsl::Extension> internalExtensions)
+    : ShaderModuleBase(device, descriptor, std::move(internalExtensions)),
       mTransformedShaderModuleCache(
           std::make_unique<ConcurrentTransformedShaderModuleCache>(device)) {}
 
diff --git a/src/dawn/native/vulkan/ShaderModuleVk.h b/src/dawn/native/vulkan/ShaderModuleVk.h
index 37fe631..aad21e5 100644
--- a/src/dawn/native/vulkan/ShaderModuleVk.h
+++ b/src/dawn/native/vulkan/ShaderModuleVk.h
@@ -33,6 +33,7 @@
 #include <optional>
 #include <string>
 #include <utility>
+#include <vector>
 
 #include "dawn/common/HashUtils.h"
 #include "dawn/common/vulkan_platform.h"
@@ -76,6 +77,7 @@
     static ResultOrError<Ref<ShaderModule>> Create(
         Device* device,
         const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+        const std::vector<tint::wgsl::Extension>& internalExtensions,
         ShaderModuleParseResult* parseResult,
         OwnedCompilationMessages* compilationMessages);
 
@@ -88,7 +90,9 @@
         std::optional<uint32_t> maxSubgroupSizeForFullSubgroups);
 
   private:
-    ShaderModule(Device* device, const UnpackedPtr<ShaderModuleDescriptor>& descriptor);
+    ShaderModule(Device* device,
+                 const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+                 std::vector<tint::wgsl::Extension> internalExtensions);
     ~ShaderModule() override;
     MaybeError Initialize(ShaderModuleParseResult* parseResult,
                           OwnedCompilationMessages* compilationMessages);
diff --git a/src/dawn/tests/unittests/native/mocks/DeviceMock.h b/src/dawn/tests/unittests/native/mocks/DeviceMock.h
index ec7b899..0a3703c 100644
--- a/src/dawn/tests/unittests/native/mocks/DeviceMock.h
+++ b/src/dawn/tests/unittests/native/mocks/DeviceMock.h
@@ -29,6 +29,7 @@
 #define SRC_DAWN_TESTS_UNITTESTS_NATIVE_MOCKS_DEVICEMOCK_H_
 
 #include <memory>
+#include <vector>
 
 #include "dawn/native/Device.h"
 #include "dawn/native/Instance.h"
@@ -119,6 +120,7 @@
     MOCK_METHOD(ResultOrError<Ref<ShaderModuleBase>>,
                 CreateShaderModuleImpl,
                 (const UnpackedPtr<ShaderModuleDescriptor>&,
+                 const std::vector<tint::wgsl::Extension>&,
                  ShaderModuleParseResult*,
                  OwnedCompilationMessages*),
                 (override));
diff --git a/src/dawn/tests/unittests/native/mocks/ShaderModuleMock.cpp b/src/dawn/tests/unittests/native/mocks/ShaderModuleMock.cpp
index 1f13df0..2066b0e 100644
--- a/src/dawn/tests/unittests/native/mocks/ShaderModuleMock.cpp
+++ b/src/dawn/tests/unittests/native/mocks/ShaderModuleMock.cpp
@@ -35,7 +35,7 @@
 
 ShaderModuleMock::ShaderModuleMock(DeviceMock* device,
                                    const UnpackedPtr<ShaderModuleDescriptor>& descriptor)
-    : ShaderModuleBase(device, descriptor) {
+    : ShaderModuleBase(device, descriptor, {}) {
     ON_CALL(*this, DestroyImpl).WillByDefault([this] { this->ShaderModuleBase::DestroyImpl(); });
 
     SetContentHash(ComputeContentHash());
@@ -48,7 +48,7 @@
     DeviceMock* device,
     const UnpackedPtr<ShaderModuleDescriptor>& descriptor) {
     ShaderModuleParseResult parseResult;
-    ValidateAndParseShaderModule(device, descriptor, &parseResult, nullptr).AcquireSuccess();
+    ValidateAndParseShaderModule(device, descriptor, {}, &parseResult, nullptr).AcquireSuccess();
 
     Ref<ShaderModuleMock> shaderModule =
         AcquireRef(new NiceMock<ShaderModuleMock>(device, descriptor));