[tint][wgsl] validate ldexp const exponent < bias + 1 when other arg non-const

Bug: chromium:360155641
Change-Id: Iad0f580c56f1e4f09592f30be599f24c704fa00b
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/202794
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: David Neto <dneto@google.com>
diff --git a/src/dawn/tests/end2end/ShaderBuiltinPartialConstArgsErrorTests.cpp b/src/dawn/tests/end2end/ShaderBuiltinPartialConstArgsErrorTests.cpp
index 7036da5..cd2c005 100644
--- a/src/dawn/tests/end2end/ShaderBuiltinPartialConstArgsErrorTests.cpp
+++ b/src/dawn/tests/end2end/ShaderBuiltinPartialConstArgsErrorTests.cpp
@@ -200,7 +200,7 @@
                         {true, false});                                      // Scalar (else Vector)
 
 ///////
-// insertBits
+// insertBits, extractBits
 ///////
 
 template <class Params>
@@ -297,5 +297,150 @@
                         {true, false},                                       // mOffsetTooBig
                         {true, false});                                      // mCountTooBig
 
+///////
+// ldexp
+// If the first parameter is not const, then it must be concrete. So we don't
+// have to check the abstract float case.
+///////
+
+enum class FloatType {
+    f32,
+    f16,
+};
+std::ostream& operator<<(std::ostream& o, FloatType ty) {
+    switch (ty) {
+        case FloatType::f16:
+            o << "f16";
+            break;
+        case FloatType::f32:
+            o << "f32";
+            break;
+        default:
+            DAWN_UNREACHABLE();
+            break;
+    }
+    return o;
+}
+
+template <class Params>
+class BuiltinPartialConstExponentBase : public DawnTestWithParams<Params> {
+  public:
+    using DawnTestWithParams<Params>::GetParam;
+    using DawnTestWithParams<Params>::SupportsFeatures;
+    Phase mExponentPhase = Phase::kConst;
+
+  protected:
+    std::vector<wgpu::FeatureName> GetRequiredFeatures() override {
+        // Always require related features if available.
+        std::vector<wgpu::FeatureName> requiredFeatures;
+        if (SupportsFeatures({wgpu::FeatureName::ShaderF16})) {
+            mShaderF16Supported = true;
+            requiredFeatures.push_back(wgpu::FeatureName::ShaderF16);
+        }
+        return requiredFeatures;
+    }
+
+    bool mShaderF16Supported = false;
+};
+
+using ExponentPhase = Phase;
+using ExponentTooBig = bool;
+using Scalar = bool;
+DAWN_TEST_PARAM_STRUCT(BuiltinPartialConstExponentParams,
+                       FloatType,
+                       ExponentPhase,
+                       ExponentTooBig,
+                       Scalar);
+
+class ShaderBuiltinPartialConstExponentTest
+    : public BuiltinPartialConstExponentBase<BuiltinPartialConstExponentParams> {
+  protected:
+    static int bias(FloatType ft) {
+        switch (ft) {
+            case FloatType::f32:
+                return 127;
+            case FloatType::f16:
+                return 15;
+        }
+        DAWN_UNREACHABLE();
+    }
+
+    static std::string suffix(FloatType ft) {
+        switch (ft) {
+            case FloatType::f32:
+                return "f";
+            case FloatType::f16:
+                return "h";
+        }
+        DAWN_UNREACHABLE();
+    }
+
+    std::string Shader() {
+        const FloatType ty = GetParam().mFloatType;
+
+        const bool too_big = GetParam().mExponentTooBig;
+        const int exponent_val = bias(ty) + 1 + (too_big ? 1 : 0);
+
+        std::stringstream code;
+        if (ty == FloatType::f16) {
+            code << "enable f16;\n";
+        }
+        auto module_var = [&](std::string ident, Phase p, float value) {
+            if (p != Phase::kRuntime) {
+                code << p << " " << ident << ": i32 = " << value << ";\n";
+            }
+        };
+        auto function_var = [&](std::string ident, Phase p, int value) {
+            if (p == Phase::kRuntime) {
+                code << "  var " << ident << ": i32 = " << value << ";\n";
+            }
+        };
+        module_var("exponent", GetParam().mExponentPhase, exponent_val);
+        code << "@compute @workgroup_size(1) fn main() {\n";
+        function_var("exponent", GetParam().mExponentPhase, exponent_val);
+        code << "  var x: " << ty << " = 0;\n";
+        if (GetParam().mScalar) {
+            code << "  _ = ldexp(x,exponent);\n";
+        } else {
+            code << "  _ = ldexp(vec2(x,x),vec2(0,exponent));\n";
+        }
+        code << "}";
+        return code.str();
+    }
+};
+
+TEST_P(ShaderBuiltinPartialConstExponentTest, All) {
+    if (GetParam().mFloatType == FloatType::f16) {
+        DAWN_TEST_UNSUPPORTED_IF(!mShaderF16Supported);
+    }
+    const auto exponentPhase = GetParam().mExponentPhase;
+    const auto exponentTooBig = GetParam().mExponentTooBig;
+    const auto wgsl = Shader();
+
+    const bool expect_create_shader_error = exponentTooBig && exponentPhase == Phase::kConst;
+
+    if (expect_create_shader_error) {
+        ASSERT_DEVICE_ERROR(utils::CreateShaderModule(device, wgsl));
+    } else {
+        const bool expect_pipeline_error = exponentTooBig && exponentPhase != Phase::kRuntime;
+
+        auto shader = utils::CreateShaderModule(device, wgsl);
+        wgpu::ComputePipelineDescriptor desc;
+        desc.compute.module = shader;
+        if (expect_pipeline_error) {
+            ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&desc));
+        } else {
+            device.CreateComputePipeline(&desc);
+        }
+    }
+}
+
+DAWN_INSTANTIATE_TEST_P(ShaderBuiltinPartialConstExponentTest,
+                        {D3D12Backend(), MetalBackend(), VulkanBackend()},
+                        {FloatType::f16, FloatType::f32},                    // mFloatType
+                        {Phase::kConst, Phase::kOverride, Phase::kRuntime},  // mCountPhase
+                        {true, false},                                       // mExponentTooBig
+                        {true, false});                                      // Scalar (or Vector)
+
 }  // anonymous namespace
 }  // namespace dawn
diff --git a/src/tint/lang/wgsl/resolver/builtins_validation_test.cc b/src/tint/lang/wgsl/resolver/builtins_validation_test.cc
index 61024f2..7552048 100644
--- a/src/tint/lang/wgsl/resolver/builtins_validation_test.cc
+++ b/src/tint/lang/wgsl/resolver/builtins_validation_test.cc
@@ -28,6 +28,7 @@
 #include <functional>
 #include "src/tint/lang/core/builtin_value.h"
 #include "src/tint/lang/wgsl/ast/call_statement.h"
+#include "src/tint/lang/wgsl/extension.h"
 #include "src/tint/lang/wgsl/resolver/resolver_helper_test.h"
 #include "src/tint/utils/text/string_stream.h"
 
@@ -2003,5 +2004,88 @@
                                             ::testing::ValuesIn({true}),
                                             ::testing::ValuesIn({true})));
 
+// We'll construct cases like this:
+// fn foo() {
+//   var x: XTYPE;
+//   _ = ldexp(x, EXPONENT);
+// }
+
+struct LdexpPartialConstCase {
+    builder::ast_type_func_ptr eType;
+    ExprMaker makeExponent;
+    bool expectPass = true;
+    int highestAllowed = 128;
+};
+
+using LdexpPartialConst = ResolverBuiltinsValidationTestWithParams<LdexpPartialConstCase>;
+
+TEST_P(LdexpPartialConst, Scalar) {
+    auto params = GetParam();
+    auto xTy = params.eType(*this);
+    const ast::Expression* exponent = params.makeExponent(this);
+
+    Enable(wgsl::Extension::kF16);
+    WrapInFunction(Var("x", xTy), Ignore(Call(Source{{12, 34}}, "ldexp", "x", exponent)));
+
+    if (params.expectPass) {
+        EXPECT_TRUE(r()->Resolve());
+    } else {
+        EXPECT_FALSE(r()->Resolve());
+        const std::string expect = "12:34 error: e2 must be less than or equal to " +
+                                   std::to_string(params.highestAllowed);
+        EXPECT_EQ(r()->error(), expect);
+    }
+}
+
+TEST_P(LdexpPartialConst, Vector) {
+    auto params = GetParam();
+    auto xTy = params.eType(*this);
+    const ast::Expression* exponent = params.makeExponent(this);
+
+    Enable(wgsl::Extension::kF16);
+    WrapInFunction(Var("x", xTy),                                   //
+                   Ignore(Call(Source{{12, 34}}, "ldexp",           //
+                               Call(Ident("vec3"), "x", "x", "x"),  //
+                               Call(Ident("vec3"), Expr(0_a), exponent, Expr(1_a)))));
+
+    if (params.expectPass) {
+        EXPECT_TRUE(r()->Resolve());
+    } else {
+        EXPECT_FALSE(r()->Resolve());
+        const std::string expect = "12:34 error: e2 must be less than or equal to " +
+                                   std::to_string(params.highestAllowed);
+        EXPECT_EQ(r()->error(), expect);
+    }
+}
+
+std::vector<LdexpPartialConstCase> ldexpCases() {
+    // Abstract Float cases don't apply here, because if the first parameter
+    // is abstract float, then it must be const already.  So that case is
+    // already checked by the full const-eval rules.
+    return std::vector<LdexpPartialConstCase>{
+        // Simple passing cases.
+        {DataType<f32>::AST, Mk(128_a), true},
+        {DataType<f32>::AST, Mk(128_i), true},
+        {DataType<f32>::AST, Mk(-5000_a), true},
+        {DataType<f32>::AST, Mk(-5000_i), true},
+        {DataType<f16>::AST, Mk(16_a), true},
+        {DataType<f16>::AST, Mk(16_i), true},
+        {DataType<f16>::AST, Mk(-5000_a), true},
+        {DataType<f16>::AST, Mk(-5000_i), true},
+
+        // Failing cases
+        {DataType<f32>::AST, Mk(129_a), false, 128},
+        {DataType<f32>::AST, Mk(129_i), false, 128},
+        {DataType<f32>::AST, Mk(5000_a), false, 128},
+        {DataType<f32>::AST, Mk(5000_i), false, 128},
+        {DataType<f16>::AST, Mk(17_a), false, 16},
+        {DataType<f16>::AST, Mk(17_i), false, 16},
+        {DataType<f16>::AST, Mk(5000_a), false, 16},
+        {DataType<f16>::AST, Mk(5000_i), false, 16},
+    };
+}
+
+INSTANTIATE_TEST_SUITE_P(Ldexp, LdexpPartialConst, ::testing::ValuesIn(ldexpCases()));
+
 }  // namespace
 }  // namespace tint::resolver
diff --git a/src/tint/lang/wgsl/resolver/resolver.cc b/src/tint/lang/wgsl/resolver/resolver.cc
index 67a054d..c44661f 100644
--- a/src/tint/lang/wgsl/resolver/resolver.cc
+++ b/src/tint/lang/wgsl/resolver/resolver.cc
@@ -2533,6 +2533,16 @@
                 }
             }
         } break;
+        case wgsl::BuiltinFn::kLdexp:
+            if (auto* exponentConst = ConvertConstArgument(args, target, 1)) {
+                auto* zero = const_eval_.Zero(call->Type(), {}, Source{}).Get();
+                auto fakeArgs = Vector{zero, exponentConst};
+                auto res = const_eval_.ldexp(call->Type(), fakeArgs, call->Declaration()->source);
+                if (res != Success) {
+                    return nullptr;
+                }
+            }
+            break;
         case wgsl::BuiltinFn::kExtractBits: {
             auto* offsetConst = ConvertConstArgument(args, target, 1);
             auto* countConst = ConvertConstArgument(args, target, 2);
diff --git a/test/tint/vk-gl-cts/graphicsfuzz/cov-inst-combine-compares-ldexp/0-opt.spvasm b/test/tint/vk-gl-cts/graphicsfuzz/cov-inst-combine-compares-ldexp/0-opt.spvasm
index f80464d..b55a00a 100644
--- a/test/tint/vk-gl-cts/graphicsfuzz/cov-inst-combine-compares-ldexp/0-opt.spvasm
+++ b/test/tint/vk-gl-cts/graphicsfuzz/cov-inst-combine-compares-ldexp/0-opt.spvasm
@@ -35,7 +35,7 @@
         %int = OpTypeInt 32 1
       %int_0 = OpConstant %int 0
 %_ptr_Uniform_float = OpTypePointer Uniform %float
-  %int_10000 = OpConstant %int 10000
+  %int_100 = OpConstant %int 100
        %bool = OpTypeBool
     %v4float = OpTypeVector %float 4
 %_ptr_Output_v4float = OpTypePointer Output %v4float
@@ -51,7 +51,7 @@
          %27 = OpLabel
          %28 = OpAccessChain %_ptr_Uniform_float %_ %int_0 %int_0
          %29 = OpLoad %float %28
-         %30 = OpExtInst %float %1 Ldexp %29 %int_10000
+         %30 = OpExtInst %float %1 Ldexp %29 %int_100
          %31 = OpAccessChain %_ptr_Uniform_float %_ %int_0 %int_0
          %32 = OpLoad %float %31
          %33 = OpFOrdEqual %bool %30 %32
diff --git a/test/tint/vk-gl-cts/graphicsfuzz/cov-inst-combine-compares-ldexp/0-opt.wgsl b/test/tint/vk-gl-cts/graphicsfuzz/cov-inst-combine-compares-ldexp/0-opt.wgsl
index 49babb2..c3faec3 100644
--- a/test/tint/vk-gl-cts/graphicsfuzz/cov-inst-combine-compares-ldexp/0-opt.wgsl
+++ b/test/tint/vk-gl-cts/graphicsfuzz/cov-inst-combine-compares-ldexp/0-opt.wgsl
@@ -29,7 +29,7 @@
 fn main_1() {
   let x_29 : f32 = x_5.x_GLF_uniform_float_values[0].el;
   let x_32 : f32 = x_5.x_GLF_uniform_float_values[0].el;
-  if ((ldexp(x_29, 10000) == x_32)) {
+  if ((ldexp(x_29, 100) == x_32)) {
     let x_38 : i32 = x_7.x_GLF_uniform_int_values[1].el;
     let x_41 : i32 = x_7.x_GLF_uniform_int_values[0].el;
     let x_44 : i32 = x_7.x_GLF_uniform_int_values[0].el;