Use LOAD_LIBRARY_SEARCH_SYSTEM32 for system DLL loads To avoid DLL search order attacks, always load system DLLs like d3d11.dll, d3d12.dll by passing LOAD_LIBRARY_SEARCH_SYSTEM32 to LoadLibraryExW. Change-Id: I0af9f26bb070ea745ec97c23578047543961084f Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/229734 Reviewed-by: Corentin Wallez <cwallez@chromium.org> Commit-Queue: Rafael Cintron <rafael.cintron@microsoft.com> Reviewed-by: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/dawn/common/DynamicLib.cpp b/src/dawn/common/DynamicLib.cpp index 7d823db..d700aa0 100644 --- a/src/dawn/common/DynamicLib.cpp +++ b/src/dawn/common/DynamicLib.cpp
@@ -61,6 +61,18 @@ return mHandle != nullptr; } +#if DAWN_PLATFORM_IS(WINDOWS) && !DAWN_PLATFORM_IS(WINUWP) +bool DynamicLib::OpenSystemLibrary(std::wstring_view filename, std::string* error) { + // Force LOAD_LIBRARY_SEARCH_SYSTEM32 for system libraries to avoid DLL search path + // attacks. + mHandle = ::LoadLibraryExW(filename.data(), nullptr, LOAD_LIBRARY_SEARCH_SYSTEM32); + if (mHandle == nullptr && error != nullptr) { + *error = "Windows Error: " + std::to_string(GetLastError()); + } + return mHandle != nullptr; +} +#endif + bool DynamicLib::Open(const std::string& filename, std::string* error) { #if DAWN_PLATFORM_IS(WINDOWS) #if DAWN_PLATFORM_IS(WINUWP)
diff --git a/src/dawn/common/DynamicLib.h b/src/dawn/common/DynamicLib.h index 7197221..8f89eb2 100644 --- a/src/dawn/common/DynamicLib.h +++ b/src/dawn/common/DynamicLib.h
@@ -32,6 +32,7 @@ #include <type_traits> #include "dawn/common/Assert.h" +#include "dawn/common/Platform.h" namespace dawn { @@ -48,6 +49,9 @@ bool Valid() const; +#if DAWN_PLATFORM_IS(WINDOWS) && !DAWN_PLATFORM_IS(WINUWP) + bool OpenSystemLibrary(std::wstring_view filename, std::string* error = nullptr); +#endif bool Open(const std::string& filename, std::string* error = nullptr); void Close();
diff --git a/src/dawn/native/d3d/PlatformFunctions.cpp b/src/dawn/native/d3d/PlatformFunctions.cpp index f63bbeb..5efe68e 100644 --- a/src/dawn/native/d3d/PlatformFunctions.cpp +++ b/src/dawn/native/d3d/PlatformFunctions.cpp
@@ -80,7 +80,7 @@ createDxgiFactory2 = &CreateDXGIFactory2; #else std::string error; - if (!mDXGILib.Open("dxgi.dll", &error) || + if (!mDXGILib.OpenSystemLibrary(L"dxgi.dll", &error) || !mDXGILib.GetProc(&dxgiGetDebugInterface1, "DXGIGetDebugInterface1", &error) || !mDXGILib.GetProc(&createDxgiFactory2, "CreateDXGIFactory2", &error)) { return DAWN_INTERNAL_ERROR(error.c_str());
diff --git a/src/dawn/native/d3d11/PlatformFunctionsD3D11.cpp b/src/dawn/native/d3d11/PlatformFunctionsD3D11.cpp index c917f93..d843218 100644 --- a/src/dawn/native/d3d11/PlatformFunctionsD3D11.cpp +++ b/src/dawn/native/d3d11/PlatformFunctionsD3D11.cpp
@@ -46,7 +46,7 @@ d3d11CreateDevice = &D3D11CreateDevice; #else std::string error; - if (!mD3D11Lib.Open("d3d11.dll", &error) || + if (!mD3D11Lib.OpenSystemLibrary(L"d3d11.dll", &error) || !mD3D11Lib.GetProc(&d3d11CreateDevice, "D3D11CreateDevice", &error)) { return DAWN_INTERNAL_ERROR(error.c_str()); }
diff --git a/src/dawn/native/d3d12/PlatformFunctionsD3D12.cpp b/src/dawn/native/d3d12/PlatformFunctionsD3D12.cpp index 1cf1c60..c84c8e4 100644 --- a/src/dawn/native/d3d12/PlatformFunctionsD3D12.cpp +++ b/src/dawn/native/d3d12/PlatformFunctionsD3D12.cpp
@@ -131,7 +131,7 @@ d3d12CreateVersionedRootSignatureDeserializer = &D3D12CreateVersionedRootSignatureDeserializer; #else std::string error; - if (!mD3D12Lib.Open("d3d12.dll", &error) || + if (!mD3D12Lib.OpenSystemLibrary(L"d3d12.dll", &error) || !mD3D12Lib.GetProc(&d3d12CreateDevice, "D3D12CreateDevice", &error) || !mD3D12Lib.GetProc(&d3d12GetDebugInterface, "D3D12GetDebugInterface", &error) || !mD3D12Lib.GetProc(&d3d12SerializeRootSignature, "D3D12SerializeRootSignature", &error) || @@ -153,7 +153,7 @@ d3d11on12CreateDevice = &D3D11On12CreateDevice; #else std::string error; - if (!mD3D11Lib.Open("d3d11.dll", &error) || + if (!mD3D11Lib.OpenSystemLibrary(L"d3d11.dll", &error) || !mD3D11Lib.GetProc(&d3d11on12CreateDevice, "D3D11On12CreateDevice", &error)) { return DAWN_INTERNAL_ERROR(error.c_str()); }