| // Copyright 2024 The Dawn & Tint Authors |
| // |
| // Redistribution and use in source and binary forms, with or without |
| // modification, are permitted provided that the following conditions are met: |
| // |
| // 1. Redistributions of source code must retain the above copyright notice, this |
| // list of conditions and the following disclaimer. |
| // |
| // 2. Redistributions in binary form must reproduce the above copyright notice, |
| // this list of conditions and the following disclaimer in the documentation |
| // and/or other materials provided with the distribution. |
| // |
| // 3. Neither the name of the copyright holder nor the names of its |
| // contributors may be used to endorse or promote products derived from |
| // this software without specific prior written permission. |
| // |
| // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" |
| // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE |
| // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE |
| // DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE |
| // FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL |
| // DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR |
| // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER |
| // CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, |
| // OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE |
| // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
| |
#include <ostream>
#include <sstream>
#include <string>
#include <vector>

#include "dawn/tests/DawnTest.h"
#include "dawn/utils/ComboRenderPipelineDescriptor.h"
#include "dawn/utils/WGPUHelpers.h"
| |
| namespace dawn { |
| namespace { |
| |
| enum class Phase { |
| kConst, |
| kOverride, |
| kRuntime, |
| }; |
| |
| std::ostream& operator<<(std::ostream& o, Phase p) { |
| switch (p) { |
| case Phase::kConst: |
| o << "const"; |
| break; |
| case Phase::kOverride: |
| o << "override"; |
| break; |
| case Phase::kRuntime: |
| o << "runtime"; |
| break; |
| default: |
| DAWN_UNREACHABLE(); |
| break; |
| } |
| return o; |
| } |
| |
| enum class Compare : int { |
| kLess = -1, |
| kEqual = 0, |
| kMore = 1, |
| }; |
| |
| std::ostream& operator<<(std::ostream& o, Compare c) { |
| switch (c) { |
| case Compare::kLess: |
| o << "less"; |
| break; |
| case Compare::kEqual: |
| o << "equal"; |
| break; |
| case Compare::kMore: |
| o << "more"; |
| break; |
| default: |
| DAWN_UNREACHABLE(); |
| break; |
| } |
| return o; |
| } |
| |
| /////// |
| // clamp, smoothstep |
| /////// |
| |
| template <class Params> |
| struct BuiltinPartialConstLowHighBase : public DawnTestWithParams<Params> { |
| using DawnTestWithParams<Params>::GetParam; |
| Phase mLowPhase = Phase::kConst; |
| Phase mHighPhase = Phase::kConst; |
| Compare mCompare = Compare::kLess; |
| }; |
| |
// Parameters of ShaderBuiltinPartialConstLowHighTest. Judging by the fields the
// tests read through GetParam() (mBuiltin, mLowPhase, mHighPhase, mCompare,
// mScalar), DAWN_TEST_PARAM_STRUCT names each field after its type, which is
// why every parameter gets its own alias.
using Builtin = std::string;  // name of the WGSL builtin under test
using LowPhase = Phase;       // phase in which `low` is declared
using HighPhase = Phase;      // phase in which `high` is declared
using Scalar = bool;          // true: scalar operands; false: vec3 operands
DAWN_TEST_PARAM_STRUCT(BuiltinPartialConstLowHighParams,
                       Builtin,
                       LowPhase,
                       HighPhase,
                       Compare,
                       Scalar);
| |
| class ShaderBuiltinPartialConstLowHighTest |
| : public BuiltinPartialConstLowHighBase<BuiltinPartialConstLowHighParams> { |
| protected: |
| std::string Shader() { |
| const auto builtin = GetParam().mBuiltin; |
| const float high_val = 10; // stay away from divide by zero for smoothstep |
| const float low_val = high_val + static_cast<float>(GetParam().mCompare); |
| |
| std::stringstream code; |
| auto module_var = [&](std::string ident, Phase p, float value) { |
| if (p != Phase::kRuntime) { |
| code << p << " " << ident << ": f32 = " << value << ";\n"; |
| } |
| }; |
| auto function_var = [&](std::string ident, Phase p, int value) { |
| if (p == Phase::kRuntime) { |
| code << " var " << ident << ": f32 = " << value << ";\n"; |
| } |
| }; |
| module_var("low", GetParam().mLowPhase, low_val); |
| module_var("high", GetParam().mHighPhase, high_val); |
| code << "@compute @workgroup_size(1) fn main() {\n"; |
| function_var("low", GetParam().mLowPhase, low_val); |
| function_var("high", GetParam().mHighPhase, high_val); |
| code << " var s: f32 = 0;\n"; |
| if (GetParam().mScalar) { |
| if (builtin == "clamp") { |
| code << " _ = clamp(s,low,high);\n"; |
| } |
| if (builtin == "smoothstep") { |
| code << " _ = smoothstep(low,high,s);\n"; |
| } |
| } else { |
| if (builtin == "clamp") { |
| code << " _ = clamp(vec3(s),vec3(0,low,0),vec3(1,high,1));\n"; |
| } |
| if (builtin == "smoothstep") { |
| code << " _ = smoothstep(vec3(0,low,0),vec3(1,high,1),vec3(s));\n"; |
| } |
| } |
| code << "}"; |
| return code.str(); |
| } |
| |
| bool BadCaseForBuiltin() { |
| if (GetParam().mBuiltin == "smoothstep") { |
| // The more case is bad because low can't be more than high. |
| // The equal case generates a divide by zero. |
| return GetParam().mCompare != Compare::kLess; |
| } |
| if (GetParam().mBuiltin == "clamp") { |
| return GetParam().mCompare == Compare::kMore; |
| } |
| DAWN_UNREACHABLE(); |
| } |
| }; |
| |
| TEST_P(ShaderBuiltinPartialConstLowHighTest, All) { |
| const auto builtin = GetParam().mBuiltin; |
| const auto lowPhase = GetParam().mLowPhase; |
| const auto highPhase = GetParam().mHighPhase; |
| const auto wgsl = Shader(); |
| |
| const bool expect_create_shader_error = BadCaseForBuiltin() // |
| && lowPhase == Phase::kConst // |
| && highPhase == Phase::kConst; |
| |
| if (expect_create_shader_error) { |
| ASSERT_DEVICE_ERROR(utils::CreateShaderModule(device, wgsl)); |
| } else { |
| const bool expect_pipeline_error = BadCaseForBuiltin() // |
| && lowPhase != Phase::kRuntime // |
| && highPhase != Phase::kRuntime; |
| auto shader = utils::CreateShaderModule(device, wgsl); |
| wgpu::ComputePipelineDescriptor desc; |
| desc.compute.module = shader; |
| if (expect_pipeline_error) { |
| ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&desc)); |
| } else { |
| device.CreateComputePipeline(&desc); |
| } |
| } |
| } |
| |
// Backends followed by the value lists for each parameter, in
// BuiltinPartialConstLowHighParams field order.
DAWN_INSTANTIATE_TEST_P(ShaderBuiltinPartialConstLowHighTest,
                        {D3D11Backend(), D3D12Backend(), MetalBackend(), NullBackend(),
                         OpenGLBackend(), OpenGLESBackend(), VulkanBackend()},
                        {"clamp", "smoothstep"},                             // mBuiltin
                        {Phase::kConst, Phase::kOverride, Phase::kRuntime},  // mLowPhase
                        {Phase::kConst, Phase::kOverride, Phase::kRuntime},  // mHighPhase
                        {Compare::kLess, Compare::kEqual, Compare::kMore},   // mCompare
                        {true, false});  // mScalar (true: scalar, false: vec3)
| |
| /////// |
| // insertBits, extractBits |
| /////// |
| |
| template <class Params> |
| struct BuiltinPartialConstOffsetCountBase : public DawnTestWithParams<Params> { |
| using DawnTestWithParams<Params>::GetParam; |
| Phase mOffsetPhase = Phase::kConst; |
| Phase mCountPhase = Phase::kConst; |
| }; |
| |
// Parameters of ShaderBuiltinPartialConstOffsetCountTest. Judging by the fields
// read through GetParam() (mBuiltin, mOffsetPhase, mCountPhase, mOffsetTooBig,
// mCountTooBig), DAWN_TEST_PARAM_STRUCT names each field after its type, so
// every parameter gets a distinct alias. Re-declaring `Builtin` with the same
// type is permitted.
using Builtin = std::string;  // name of the WGSL builtin under test
using OffsetPhase = Phase;    // phase in which `offset` is declared
using CountPhase = Phase;     // phase in which `count` is declared
using OffsetTooBig = bool;    // use offset 17 instead of 16
using CountTooBig = bool;     // use count 17 instead of 16
DAWN_TEST_PARAM_STRUCT(BuiltinPartialConstOffsetCountParams,
                       Builtin,
                       OffsetPhase,
                       CountPhase,
                       OffsetTooBig,
                       CountTooBig);
| |
| class ShaderBuiltinPartialConstOffsetCountTest |
| : public BuiltinPartialConstOffsetCountBase<BuiltinPartialConstOffsetCountParams> { |
| protected: |
| std::string Shader() { |
| // Assume 32-bit integers. |
| const auto builtin = GetParam().mBuiltin; |
| const int offset_val = 16 + (GetParam().mOffsetTooBig ? 1 : 0); |
| const int count_val = 16 + (GetParam().mCountTooBig ? 1 : 0); |
| |
| std::stringstream code; |
| auto module_var = [&](std::string ident, Phase p, float value) { |
| if (p != Phase::kRuntime) { |
| code << p << " " << ident << ": u32 = " << value << ";\n"; |
| } |
| }; |
| auto function_var = [&](std::string ident, Phase p, int value) { |
| if (p == Phase::kRuntime) { |
| code << " var " << ident << ": u32 = " << value << ";\n"; |
| } |
| }; |
| module_var("offset", GetParam().mOffsetPhase, offset_val); |
| module_var("count", GetParam().mCountPhase, count_val); |
| code << "@compute @workgroup_size(1) fn main() {\n"; |
| function_var("offset", GetParam().mOffsetPhase, offset_val); |
| function_var("count", GetParam().mCountPhase, count_val); |
| code << " var s: u32 = 0;\n"; |
| if (builtin == "insertBits") { |
| code << " _ = insertBits(s,s,offset,count);\n"; |
| } |
| if (builtin == "extractBits") { |
| code << " _ = extractBits(s,offset,count);\n"; |
| } |
| code << "}"; |
| return code.str(); |
| } |
| }; |
| |
| TEST_P(ShaderBuiltinPartialConstOffsetCountTest, All) { |
| const auto builtin = GetParam().mBuiltin; |
| const auto offsetPhase = GetParam().mOffsetPhase; |
| const auto countPhase = GetParam().mCountPhase; |
| const auto offsetTooBig = GetParam().mOffsetTooBig; |
| const auto countTooBig = GetParam().mCountTooBig; |
| const auto wgsl = Shader(); |
| |
| const bool expect_create_shader_error = (offsetTooBig || countTooBig) // |
| && offsetPhase == Phase::kConst // |
| && countPhase == Phase::kConst; |
| |
| if (expect_create_shader_error) { |
| ASSERT_DEVICE_ERROR(utils::CreateShaderModule(device, wgsl)); |
| } else { |
| const bool expect_pipeline_error = (offsetTooBig || countTooBig) // |
| && offsetPhase != Phase::kRuntime // |
| && countPhase != Phase::kRuntime; |
| auto shader = utils::CreateShaderModule(device, wgsl); |
| wgpu::ComputePipelineDescriptor desc; |
| desc.compute.module = shader; |
| if (expect_pipeline_error) { |
| ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&desc)); |
| } else { |
| device.CreateComputePipeline(&desc); |
| } |
| } |
| } |
| |
// Backends followed by the value lists for each parameter, in
// BuiltinPartialConstOffsetCountParams field order.
DAWN_INSTANTIATE_TEST_P(ShaderBuiltinPartialConstOffsetCountTest,
                        {D3D11Backend(), D3D12Backend(), MetalBackend(), NullBackend(),
                         OpenGLBackend(), OpenGLESBackend(), VulkanBackend()},
                        {"insertBits", "extractBits"},                       // mBuiltin
                        {Phase::kConst, Phase::kOverride, Phase::kRuntime},  // mOffsetPhase
                        {Phase::kConst, Phase::kOverride, Phase::kRuntime},  // mCountPhase
                        {true, false},                                       // mOffsetTooBig
                        {true, false});                                      // mCountTooBig
| |
| /////// |
| // ldexp |
| // If the first parameter is not const, then it must be concrete. So we don't |
| // have to check the abstract float case. |
| /////// |
| |
| enum class FloatType { |
| f32, |
| f16, |
| }; |
| std::ostream& operator<<(std::ostream& o, FloatType ty) { |
| switch (ty) { |
| case FloatType::f16: |
| o << "f16"; |
| break; |
| case FloatType::f32: |
| o << "f32"; |
| break; |
| default: |
| DAWN_UNREACHABLE(); |
| break; |
| } |
| return o; |
| } |
| |
| template <class Params> |
| class BuiltinPartialConstExponentBase : public DawnTestWithParams<Params> { |
| public: |
| using DawnTestWithParams<Params>::GetParam; |
| using DawnTestWithParams<Params>::SupportsFeatures; |
| Phase mExponentPhase = Phase::kConst; |
| |
| protected: |
| std::vector<wgpu::FeatureName> GetRequiredFeatures() override { |
| // Always require related features if available. |
| std::vector<wgpu::FeatureName> requiredFeatures; |
| if (SupportsFeatures({wgpu::FeatureName::ShaderF16})) { |
| mShaderF16Supported = true; |
| requiredFeatures.push_back(wgpu::FeatureName::ShaderF16); |
| } |
| return requiredFeatures; |
| } |
| |
| bool mShaderF16Supported = false; |
| }; |
| |
// Parameters of ShaderBuiltinPartialConstExponentTest. Judging by the fields
// read through GetParam() (mFloatType, mExponentPhase, mExponentTooBig,
// mScalar), DAWN_TEST_PARAM_STRUCT names each field after its type, so every
// parameter gets a distinct alias. Re-declaring `Scalar` with the same type is
// permitted.
using ExponentPhase = Phase;  // phase in which `exponent` is declared
using ExponentTooBig = bool;  // use exponent bias + 2 instead of bias + 1
using Scalar = bool;          // true: scalar ldexp; false: vec2 ldexp
DAWN_TEST_PARAM_STRUCT(BuiltinPartialConstExponentParams,
                       FloatType,
                       ExponentPhase,
                       ExponentTooBig,
                       Scalar);
| |
| class ShaderBuiltinPartialConstExponentTest |
| : public BuiltinPartialConstExponentBase<BuiltinPartialConstExponentParams> { |
| protected: |
| static int bias(FloatType ft) { |
| switch (ft) { |
| case FloatType::f32: |
| return 127; |
| case FloatType::f16: |
| return 15; |
| } |
| DAWN_UNREACHABLE(); |
| } |
| |
| static std::string suffix(FloatType ft) { |
| switch (ft) { |
| case FloatType::f32: |
| return "f"; |
| case FloatType::f16: |
| return "h"; |
| } |
| DAWN_UNREACHABLE(); |
| } |
| |
| std::string Shader() { |
| const FloatType ty = GetParam().mFloatType; |
| |
| const bool too_big = GetParam().mExponentTooBig; |
| const int exponent_val = bias(ty) + 1 + (too_big ? 1 : 0); |
| |
| std::stringstream code; |
| if (ty == FloatType::f16) { |
| code << "enable f16;\n"; |
| } |
| auto module_var = [&](std::string ident, Phase p, float value) { |
| if (p != Phase::kRuntime) { |
| code << p << " " << ident << ": i32 = " << value << ";\n"; |
| } |
| }; |
| auto function_var = [&](std::string ident, Phase p, int value) { |
| if (p == Phase::kRuntime) { |
| code << " var " << ident << ": i32 = " << value << ";\n"; |
| } |
| }; |
| module_var("exponent", GetParam().mExponentPhase, exponent_val); |
| code << "@compute @workgroup_size(1) fn main() {\n"; |
| function_var("exponent", GetParam().mExponentPhase, exponent_val); |
| code << " var x: " << ty << " = 0;\n"; |
| if (GetParam().mScalar) { |
| code << " _ = ldexp(x,exponent);\n"; |
| } else { |
| code << " _ = ldexp(vec2(x,x),vec2(0,exponent));\n"; |
| } |
| code << "}"; |
| return code.str(); |
| } |
| }; |
| |
| TEST_P(ShaderBuiltinPartialConstExponentTest, All) { |
| if (GetParam().mFloatType == FloatType::f16) { |
| DAWN_TEST_UNSUPPORTED_IF(!mShaderF16Supported); |
| } |
| const auto exponentPhase = GetParam().mExponentPhase; |
| const auto exponentTooBig = GetParam().mExponentTooBig; |
| const auto wgsl = Shader(); |
| |
| const bool expect_create_shader_error = exponentTooBig && exponentPhase == Phase::kConst; |
| |
| if (expect_create_shader_error) { |
| ASSERT_DEVICE_ERROR(utils::CreateShaderModule(device, wgsl)); |
| } else { |
| const bool expect_pipeline_error = exponentTooBig && exponentPhase != Phase::kRuntime; |
| |
| auto shader = utils::CreateShaderModule(device, wgsl); |
| wgpu::ComputePipelineDescriptor desc; |
| desc.compute.module = shader; |
| if (expect_pipeline_error) { |
| ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&desc)); |
| } else { |
| device.CreateComputePipeline(&desc); |
| } |
| } |
| } |
| |
// Backends followed by the value lists for each parameter, in
// BuiltinPartialConstExponentParams field order.
DAWN_INSTANTIATE_TEST_P(ShaderBuiltinPartialConstExponentTest,
                        {D3D12Backend(), MetalBackend(), VulkanBackend()},
                        {FloatType::f16, FloatType::f32},                    // mFloatType
                        {Phase::kConst, Phase::kOverride, Phase::kRuntime},  // mExponentPhase
                        {true, false},                                       // mExponentTooBig
                        {true, false});  // mScalar (true: scalar, false: vec2)
| |
| } // anonymous namespace |
| } // namespace dawn |