[tint][wgsl] validate ldexp const exponent < bias + 1 when other arg non-const
Bug: chromium:360155641
Change-Id: Iad0f580c56f1e4f09592f30be599f24c704fa00b
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/202794
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: David Neto <dneto@google.com>
diff --git a/src/dawn/tests/end2end/ShaderBuiltinPartialConstArgsErrorTests.cpp b/src/dawn/tests/end2end/ShaderBuiltinPartialConstArgsErrorTests.cpp
index 7036da5..cd2c005 100644
--- a/src/dawn/tests/end2end/ShaderBuiltinPartialConstArgsErrorTests.cpp
+++ b/src/dawn/tests/end2end/ShaderBuiltinPartialConstArgsErrorTests.cpp
@@ -200,7 +200,7 @@
{true, false}); // Scalar (else Vector)
///////
-// insertBits
+// insertBits, extractBits
///////
template <class Params>
@@ -297,5 +297,150 @@
{true, false}, // mOffsetTooBig
{true, false}); // mCountTooBig
+///////
+// ldexp
+// If the first parameter is not const, then it must be concrete. So we don't
+// have to check the abstract float case.
+///////
+
+enum class FloatType {
+ f32,
+ f16,
+};
+std::ostream& operator<<(std::ostream& o, FloatType ty) {
+ switch (ty) {
+ case FloatType::f16:
+ o << "f16";
+ break;
+ case FloatType::f32:
+ o << "f32";
+ break;
+ default:
+ DAWN_UNREACHABLE();
+ break;
+ }
+ return o;
+}
+
+template <class Params>
+class BuiltinPartialConstExponentBase : public DawnTestWithParams<Params> {
+ public:
+ using DawnTestWithParams<Params>::GetParam;
+ using DawnTestWithParams<Params>::SupportsFeatures;
+ Phase mExponentPhase = Phase::kConst;
+
+ protected:
+ std::vector<wgpu::FeatureName> GetRequiredFeatures() override {
+ // Always require related features if available.
+ std::vector<wgpu::FeatureName> requiredFeatures;
+ if (SupportsFeatures({wgpu::FeatureName::ShaderF16})) {
+ mShaderF16Supported = true;
+ requiredFeatures.push_back(wgpu::FeatureName::ShaderF16);
+ }
+ return requiredFeatures;
+ }
+
+ bool mShaderF16Supported = false;
+};
+
+using ExponentPhase = Phase;
+using ExponentTooBig = bool;
+using Scalar = bool;
+DAWN_TEST_PARAM_STRUCT(BuiltinPartialConstExponentParams,
+ FloatType,
+ ExponentPhase,
+ ExponentTooBig,
+ Scalar);
+
+class ShaderBuiltinPartialConstExponentTest
+ : public BuiltinPartialConstExponentBase<BuiltinPartialConstExponentParams> {
+ protected:
+ static int bias(FloatType ft) {
+ switch (ft) {
+ case FloatType::f32:
+ return 127;
+ case FloatType::f16:
+ return 15;
+ }
+ DAWN_UNREACHABLE();
+ }
+
+ static std::string suffix(FloatType ft) {
+ switch (ft) {
+ case FloatType::f32:
+ return "f";
+ case FloatType::f16:
+ return "h";
+ }
+ DAWN_UNREACHABLE();
+ }
+
+ std::string Shader() {
+ const FloatType ty = GetParam().mFloatType;
+
+ const bool too_big = GetParam().mExponentTooBig;
+ const int exponent_val = bias(ty) + 1 + (too_big ? 1 : 0);
+
+ std::stringstream code;
+ if (ty == FloatType::f16) {
+ code << "enable f16;\n";
+ }
+ auto module_var = [&](std::string ident, Phase p, float value) {
+ if (p != Phase::kRuntime) {
+ code << p << " " << ident << ": i32 = " << value << ";\n";
+ }
+ };
+ auto function_var = [&](std::string ident, Phase p, int value) {
+ if (p == Phase::kRuntime) {
+ code << " var " << ident << ": i32 = " << value << ";\n";
+ }
+ };
+ module_var("exponent", GetParam().mExponentPhase, exponent_val);
+ code << "@compute @workgroup_size(1) fn main() {\n";
+ function_var("exponent", GetParam().mExponentPhase, exponent_val);
+ code << " var x: " << ty << " = 0;\n";
+ if (GetParam().mScalar) {
+ code << " _ = ldexp(x,exponent);\n";
+ } else {
+ code << " _ = ldexp(vec2(x,x),vec2(0,exponent));\n";
+ }
+ code << "}";
+ return code.str();
+ }
+};
+
+TEST_P(ShaderBuiltinPartialConstExponentTest, All) {
+ if (GetParam().mFloatType == FloatType::f16) {
+ DAWN_TEST_UNSUPPORTED_IF(!mShaderF16Supported);
+ }
+ const auto exponentPhase = GetParam().mExponentPhase;
+ const auto exponentTooBig = GetParam().mExponentTooBig;
+ const auto wgsl = Shader();
+
+ const bool expect_create_shader_error = exponentTooBig && exponentPhase == Phase::kConst;
+
+ if (expect_create_shader_error) {
+ ASSERT_DEVICE_ERROR(utils::CreateShaderModule(device, wgsl));
+ } else {
+ const bool expect_pipeline_error = exponentTooBig && exponentPhase != Phase::kRuntime;
+
+ auto shader = utils::CreateShaderModule(device, wgsl);
+ wgpu::ComputePipelineDescriptor desc;
+ desc.compute.module = shader;
+ if (expect_pipeline_error) {
+ ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&desc));
+ } else {
+ device.CreateComputePipeline(&desc);
+ }
+ }
+}
+
+DAWN_INSTANTIATE_TEST_P(ShaderBuiltinPartialConstExponentTest,
+ {D3D12Backend(), MetalBackend(), VulkanBackend()},
+ {FloatType::f16, FloatType::f32}, // mFloatType
+ {Phase::kConst, Phase::kOverride, Phase::kRuntime}, // mCountPhase
+ {true, false}, // mExponentTooBig
+ {true, false}); // Scalar (or Vector)
+
} // anonymous namespace
} // namespace dawn
diff --git a/src/tint/lang/wgsl/resolver/builtins_validation_test.cc b/src/tint/lang/wgsl/resolver/builtins_validation_test.cc
index 61024f2..7552048 100644
--- a/src/tint/lang/wgsl/resolver/builtins_validation_test.cc
+++ b/src/tint/lang/wgsl/resolver/builtins_validation_test.cc
@@ -28,6 +28,7 @@
#include <functional>
#include "src/tint/lang/core/builtin_value.h"
#include "src/tint/lang/wgsl/ast/call_statement.h"
+#include "src/tint/lang/wgsl/extension.h"
#include "src/tint/lang/wgsl/resolver/resolver_helper_test.h"
#include "src/tint/utils/text/string_stream.h"
@@ -2003,5 +2004,88 @@
::testing::ValuesIn({true}),
::testing::ValuesIn({true})));
+// We'll construct cases like this:
+// fn foo() {
+// var x: XTYPE;
+// _ = ldexp(x, EXPONENT);
+// }
+
+struct LdexpPartialConstCase {
+ builder::ast_type_func_ptr eType;
+ ExprMaker makeExponent;
+ bool expectPass = true;
+ int highestAllowed = 128;
+};
+
+using LdexpPartialConst = ResolverBuiltinsValidationTestWithParams<LdexpPartialConstCase>;
+
+TEST_P(LdexpPartialConst, Scalar) {
+ auto params = GetParam();
+ auto xTy = params.eType(*this);
+ const ast::Expression* exponent = params.makeExponent(this);
+
+ Enable(wgsl::Extension::kF16);
+ WrapInFunction(Var("x", xTy), Ignore(Call(Source{{12, 34}}, "ldexp", "x", exponent)));
+
+ if (params.expectPass) {
+ EXPECT_TRUE(r()->Resolve());
+ } else {
+ EXPECT_FALSE(r()->Resolve());
+ const std::string expect = "12:34 error: e2 must be less than or equal to " +
+ std::to_string(params.highestAllowed);
+ EXPECT_EQ(r()->error(), expect);
+ }
+}
+
+TEST_P(LdexpPartialConst, Vector) {
+ auto params = GetParam();
+ auto xTy = params.eType(*this);
+ const ast::Expression* exponent = params.makeExponent(this);
+
+ Enable(wgsl::Extension::kF16);
+ WrapInFunction(Var("x", xTy), //
+ Ignore(Call(Source{{12, 34}}, "ldexp", //
+ Call(Ident("vec3"), "x", "x", "x"), //
+ Call(Ident("vec3"), Expr(0_a), exponent, Expr(1_a)))));
+
+ if (params.expectPass) {
+ EXPECT_TRUE(r()->Resolve());
+ } else {
+ EXPECT_FALSE(r()->Resolve());
+ const std::string expect = "12:34 error: e2 must be less than or equal to " +
+ std::to_string(params.highestAllowed);
+ EXPECT_EQ(r()->error(), expect);
+ }
+}
+
+std::vector<LdexpPartialConstCase> ldexpCases() {
+ // Abstract Float cases don't apply here, because if the first parameter
+ // is abstract float, then it must be const already. So that case is
+ // already checked by the full const-eval rules.
+ return std::vector<LdexpPartialConstCase>{
+ // Simple passing cases.
+ {DataType<f32>::AST, Mk(128_a), true},
+ {DataType<f32>::AST, Mk(128_i), true},
+ {DataType<f32>::AST, Mk(-5000_a), true},
+ {DataType<f32>::AST, Mk(-5000_i), true},
+ {DataType<f16>::AST, Mk(16_a), true},
+ {DataType<f16>::AST, Mk(16_i), true},
+ {DataType<f16>::AST, Mk(-5000_a), true},
+ {DataType<f16>::AST, Mk(-5000_i), true},
+
+ // Failing cases
+ {DataType<f32>::AST, Mk(129_a), false, 128},
+ {DataType<f32>::AST, Mk(129_i), false, 128},
+ {DataType<f32>::AST, Mk(5000_a), false, 128},
+ {DataType<f32>::AST, Mk(5000_i), false, 128},
+ {DataType<f16>::AST, Mk(17_a), false, 16},
+ {DataType<f16>::AST, Mk(17_i), false, 16},
+ {DataType<f16>::AST, Mk(5000_a), false, 16},
+ {DataType<f16>::AST, Mk(5000_i), false, 16},
+ };
+}
+
+INSTANTIATE_TEST_SUITE_P(Ldexp, LdexpPartialConst, ::testing::ValuesIn(ldexpCases()));
+
} // namespace
} // namespace tint::resolver
diff --git a/src/tint/lang/wgsl/resolver/resolver.cc b/src/tint/lang/wgsl/resolver/resolver.cc
index 67a054d..c44661f 100644
--- a/src/tint/lang/wgsl/resolver/resolver.cc
+++ b/src/tint/lang/wgsl/resolver/resolver.cc
@@ -2533,6 +2533,16 @@
}
}
} break;
+ case wgsl::BuiltinFn::kLdexp:
+ if (auto* exponentConst = ConvertConstArgument(args, target, 1)) {
+ auto* zero = const_eval_.Zero(call->Type(), {}, Source{}).Get();
+ auto fakeArgs = Vector{zero, exponentConst};
+ auto res = const_eval_.ldexp(call->Type(), fakeArgs, call->Declaration()->source);
+ if (res != Success) {
+ return nullptr;
+ }
+ }
+ break;
case wgsl::BuiltinFn::kExtractBits: {
auto* offsetConst = ConvertConstArgument(args, target, 1);
auto* countConst = ConvertConstArgument(args, target, 2);
diff --git a/test/tint/vk-gl-cts/graphicsfuzz/cov-inst-combine-compares-ldexp/0-opt.spvasm b/test/tint/vk-gl-cts/graphicsfuzz/cov-inst-combine-compares-ldexp/0-opt.spvasm
index f80464d..b55a00a 100644
--- a/test/tint/vk-gl-cts/graphicsfuzz/cov-inst-combine-compares-ldexp/0-opt.spvasm
+++ b/test/tint/vk-gl-cts/graphicsfuzz/cov-inst-combine-compares-ldexp/0-opt.spvasm
@@ -35,7 +35,7 @@
%int = OpTypeInt 32 1
%int_0 = OpConstant %int 0
%_ptr_Uniform_float = OpTypePointer Uniform %float
- %int_10000 = OpConstant %int 10000
+ %int_100 = OpConstant %int 100
%bool = OpTypeBool
%v4float = OpTypeVector %float 4
%_ptr_Output_v4float = OpTypePointer Output %v4float
@@ -51,7 +51,7 @@
%27 = OpLabel
%28 = OpAccessChain %_ptr_Uniform_float %_ %int_0 %int_0
%29 = OpLoad %float %28
- %30 = OpExtInst %float %1 Ldexp %29 %int_10000
+ %30 = OpExtInst %float %1 Ldexp %29 %int_100
%31 = OpAccessChain %_ptr_Uniform_float %_ %int_0 %int_0
%32 = OpLoad %float %31
%33 = OpFOrdEqual %bool %30 %32
diff --git a/test/tint/vk-gl-cts/graphicsfuzz/cov-inst-combine-compares-ldexp/0-opt.wgsl b/test/tint/vk-gl-cts/graphicsfuzz/cov-inst-combine-compares-ldexp/0-opt.wgsl
index 49babb2..c3faec3 100644
--- a/test/tint/vk-gl-cts/graphicsfuzz/cov-inst-combine-compares-ldexp/0-opt.wgsl
+++ b/test/tint/vk-gl-cts/graphicsfuzz/cov-inst-combine-compares-ldexp/0-opt.wgsl
@@ -29,7 +29,7 @@
fn main_1() {
let x_29 : f32 = x_5.x_GLF_uniform_float_values[0].el;
let x_32 : f32 = x_5.x_GLF_uniform_float_values[0].el;
- if ((ldexp(x_29, 10000) == x_32)) {
+ if ((ldexp(x_29, 100) == x_32)) {
let x_38 : i32 = x_7.x_GLF_uniform_int_values[1].el;
let x_41 : i32 = x_7.x_GLF_uniform_int_values[0].el;
let x_44 : i32 = x_7.x_GLF_uniform_int_values[0].el;