[tint] validate insertBits count+offset when other args are runtime-eval
count+offset must not exceed the bit width of the other arguments.
Bug: chromium:351372334
Change-Id: I41749e9700f247544c3e396c8362a395c3eac708
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/202554
Reviewed-by: James Price <jrprice@google.com>
Auto-Submit: David Neto <dneto@google.com>
Commit-Queue: David Neto <dneto@google.com>
diff --git a/src/dawn/tests/end2end/ShaderBuiltinPartialConstArgsErrorTests.cpp b/src/dawn/tests/end2end/ShaderBuiltinPartialConstArgsErrorTests.cpp
index 9f12a70..33bd6bc 100644
--- a/src/dawn/tests/end2end/ShaderBuiltinPartialConstArgsErrorTests.cpp
+++ b/src/dawn/tests/end2end/ShaderBuiltinPartialConstArgsErrorTests.cpp
@@ -83,8 +83,12 @@
return o;
}
+///////
+// clamp, smoothstep
+///////
+
template <class Params>
-struct BuiltinPartialConstArgsErrorBase : public DawnTestWithParams<Params> {
+struct BuiltinPartialConstLowHighBase : public DawnTestWithParams<Params> {
using DawnTestWithParams<Params>::GetParam;
Phase mLowPhase = Phase::kConst;
Phase mHighPhase = Phase::kConst;
@@ -95,15 +99,15 @@
using LowPhase = Phase;
using HighPhase = Phase;
using Scalar = bool;
-DAWN_TEST_PARAM_STRUCT(BuiltinPartialConstArgsErrorTestParams,
+DAWN_TEST_PARAM_STRUCT(BuiltinPartialConstLowHighParams,
Builtin,
LowPhase,
HighPhase,
Compare,
Scalar);
-class ShaderBuiltinPartialConstArgsErrorTest
- : public BuiltinPartialConstArgsErrorBase<BuiltinPartialConstArgsErrorTestParams> {
+class ShaderBuiltinPartialConstLowHighTest
+ : public BuiltinPartialConstLowHighBase<BuiltinPartialConstLowHighParams> {
protected:
std::string Shader() {
const auto builtin = GetParam().mBuiltin;
@@ -159,7 +163,7 @@
}
};
-TEST_P(ShaderBuiltinPartialConstArgsErrorTest, All) {
+TEST_P(ShaderBuiltinPartialConstLowHighTest, All) {
const auto builtin = GetParam().mBuiltin;
const auto lowPhase = GetParam().mLowPhase;
const auto highPhase = GetParam().mHighPhase;
@@ -186,8 +190,7 @@
}
}
-// DawnTestBase::CreateDeviceImpl always enables allow_unsafe_apis toggle.
-DAWN_INSTANTIATE_TEST_P(ShaderBuiltinPartialConstArgsErrorTest,
+DAWN_INSTANTIATE_TEST_P(ShaderBuiltinPartialConstLowHighTest,
{D3D11Backend(), D3D12Backend(), MetalBackend(), NullBackend(),
OpenGLBackend(), OpenGLESBackend(), VulkanBackend()},
{"clamp", "smoothstep"}, // mBuiltin
@@ -196,5 +199,98 @@
{Compare::kLess, Compare::kEqual, Compare::kMore}, // mCompare
{true, false}); // Scalar (else Vector)
+///////
+// insertBits
+///////
+
+template <class Params>
+struct BuiltinPartialConstOffsetCountBase : public DawnTestWithParams<Params> {
+ using DawnTestWithParams<Params>::GetParam;
+ Phase mOffsetPhase = Phase::kConst;
+ Phase mCountPhase = Phase::kConst;
+};
+
+using Builtin = std::string;
+using OffsetPhase = Phase;
+using CountPhase = Phase;
+using OffsetTooBig = bool;
+using CountTooBig = bool;
+DAWN_TEST_PARAM_STRUCT(BuiltinPartialConstOffsetCountParams,
+ Builtin,
+ OffsetPhase,
+ CountPhase,
+ OffsetTooBig,
+ CountTooBig);
+
+class ShaderBuiltinPartialConstOffsetCountTest
+ : public BuiltinPartialConstOffsetCountBase<BuiltinPartialConstOffsetCountParams> {
+ protected:
+ std::string Shader() {
+ // Assume 32-bit integers.
+ const auto builtin = GetParam().mBuiltin;
+ const int offset_val = 16 + (GetParam().mOffsetTooBig ? 1 : 0);
+ const int count_val = 16 + (GetParam().mCountTooBig ? 1 : 0);
+
+ std::stringstream code;
+ auto module_var = [&](std::string ident, Phase p, float value) {
+ if (p != Phase::kRuntime) {
+ code << p << " " << ident << ": u32 = " << value << ";\n";
+ }
+ };
+ auto function_var = [&](std::string ident, Phase p, int value) {
+ if (p == Phase::kRuntime) {
+ code << " var " << ident << ": u32 = " << value << ";\n";
+ }
+ };
+ module_var("offset", GetParam().mOffsetPhase, offset_val);
+ module_var("count", GetParam().mCountPhase, count_val);
+ code << "@compute @workgroup_size(1) fn main() {\n";
+ function_var("offset", GetParam().mOffsetPhase, offset_val);
+ function_var("count", GetParam().mCountPhase, count_val);
+ code << " var s: u32 = 0;\n";
+ code << " _ = insertBits(s,s,offset,count);\n";
+ code << "}";
+ return code.str();
+ }
+};
+
+TEST_P(ShaderBuiltinPartialConstOffsetCountTest, All) {
+ const auto builtin = GetParam().mBuiltin;
+ const auto offsetPhase = GetParam().mOffsetPhase;
+ const auto countPhase = GetParam().mCountPhase;
+ const auto offsetTooBig = GetParam().mOffsetTooBig;
+ const auto countTooBig = GetParam().mCountTooBig;
+ const auto wgsl = Shader();
+
+ const bool expect_create_shader_error = (offsetTooBig || countTooBig) //
+ && offsetPhase == Phase::kConst //
+ && countPhase == Phase::kConst;
+
+ if (expect_create_shader_error) {
+ ASSERT_DEVICE_ERROR(utils::CreateShaderModule(device, wgsl));
+ } else {
+ const bool expect_pipeline_error = (offsetTooBig || countTooBig) //
+ && offsetPhase != Phase::kRuntime //
+ && countPhase != Phase::kRuntime;
+ auto shader = utils::CreateShaderModule(device, wgsl);
+ wgpu::ComputePipelineDescriptor desc;
+ desc.compute.module = shader;
+ if (expect_pipeline_error) {
+ ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&desc));
+ } else {
+ device.CreateComputePipeline(&desc);
+ }
+ }
+}
+
+DAWN_INSTANTIATE_TEST_P(ShaderBuiltinPartialConstOffsetCountTest,
+ {D3D11Backend(), D3D12Backend(), MetalBackend(), NullBackend(),
+ OpenGLBackend(), OpenGLESBackend(), VulkanBackend()},
+ {"insertBits"}, // mBuiltin
+ {Phase::kConst, Phase::kOverride, Phase::kRuntime}, // mOffsetPhase
+ {Phase::kConst, Phase::kOverride, Phase::kRuntime}, // mCountPhase
+ {true, false}, // mOffsetTooBig
+ {true, false}); // mCountTooBig
+
} // anonymous namespace
} // namespace dawn
diff --git a/src/tint/lang/core/constant/eval.cc b/src/tint/lang/core/constant/eval.cc
index 8db34d5..ded2d07 100644
--- a/src/tint/lang/core/constant/eval.cc
+++ b/src/tint/lang/core/constant/eval.cc
@@ -2627,7 +2627,7 @@
constexpr UT w = sizeof(UT) * 8;
if (o > w || c > w || (o + c) > w) {
AddError(source)
- << "'offset + 'count' must be less than or equal to the bit width of 'e'";
+ << "'offset' + 'count' must be less than or equal to the bit width of 'e'";
if (use_runtime_semantics_) {
o = std::min(o, w);
c = std::min(c, w - o);
@@ -2896,7 +2896,7 @@
constexpr UT w = sizeof(UT) * 8;
if (o > w || c > w || (o + c) > w) {
AddError(source)
- << "'offset + 'count' must be less than or equal to the bit width of 'e'";
+ << "'offset' + 'count' must be less than or equal to the bit width of 'e'";
if (use_runtime_semantics_) {
o = std::min(o, w);
c = std::min(c, w - o);
diff --git a/src/tint/lang/core/constant/eval_binary_op_test.cc b/src/tint/lang/core/constant/eval_binary_op_test.cc
index 812d633..ef9f163 100644
--- a/src/tint/lang/core/constant/eval_binary_op_test.cc
+++ b/src/tint/lang/core/constant/eval_binary_op_test.cc
@@ -2207,7 +2207,7 @@
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
- "12:34 error: 'offset + 'count' must be less than or equal to the bit width of 'e'");
+ "12:34 error: 'offset' + 'count' must be less than or equal to the bit width of 'e'");
}
TEST_F(ConstEvalTest, ShortCircuit_And_Error_BuiltinCall) {
@@ -2255,7 +2255,7 @@
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
- "12:34 error: 'offset + 'count' must be less than or equal to the bit width of 'e'");
+ "12:34 error: 'offset' + 'count' must be less than or equal to the bit width of 'e'");
}
TEST_F(ConstEvalTest, ShortCircuit_Or_Error_BuiltinCall) {
diff --git a/src/tint/lang/core/constant/eval_builtin_test.cc b/src/tint/lang/core/constant/eval_builtin_test.cc
index 4eb8152..0eaf4d2 100644
--- a/src/tint/lang/core/constant/eval_builtin_test.cc
+++ b/src/tint/lang/core/constant/eval_builtin_test.cc
@@ -1374,7 +1374,7 @@
};
const char* error_msg =
- "12:34 error: 'offset + 'count' must be less than or equal to the bit width of 'e'";
+ "12:34 error: 'offset' + 'count' must be less than or equal to the bit width of 'e'";
ConcatInto( //
r, std::vector<Case>{
E({T(1), T(1), UT(33), UT(0)}, error_msg), //
@@ -1582,7 +1582,7 @@
};
const char* error_msg =
- "12:34 error: 'offset + 'count' must be less than or equal to the bit width of 'e'";
+ "12:34 error: 'offset' + 'count' must be less than or equal to the bit width of 'e'";
ConcatInto( //
r, std::vector<Case>{
E({T(1), UT(33), UT(0)}, error_msg),
diff --git a/src/tint/lang/core/constant/eval_runtime_semantics_test.cc b/src/tint/lang/core/constant/eval_runtime_semantics_test.cc
index ab244d5..be30634 100644
--- a/src/tint/lang/core/constant/eval_runtime_semantics_test.cc
+++ b/src/tint/lang/core/constant/eval_runtime_semantics_test.cc
@@ -372,7 +372,7 @@
ASSERT_EQ(result, Success);
EXPECT_EQ(result.Get()->ValueAs<i32>(), 0x12);
EXPECT_EQ(error(),
- R"(warning: 'offset + 'count' must be less than or equal to the bit width of 'e')");
+ R"(warning: 'offset' + 'count' must be less than or equal to the bit width of 'e')");
}
TEST_F(ConstEvalRuntimeSemanticsTest, ExtractBits_U32_TooManyBits) {
@@ -383,7 +383,7 @@
ASSERT_EQ(result, Success);
EXPECT_EQ(result.Get()->ValueAs<u32>(), 0x12);
EXPECT_EQ(error(),
- R"(warning: 'offset + 'count' must be less than or equal to the bit width of 'e')");
+ R"(warning: 'offset' + 'count' must be less than or equal to the bit width of 'e')");
}
TEST_F(ConstEvalRuntimeSemanticsTest, InsertBits_I32_TooManyBits) {
@@ -395,7 +395,7 @@
ASSERT_EQ(result, Success);
EXPECT_EQ(result.Get()->ValueAs<i32>(), 0x12345678);
EXPECT_EQ(error(),
- R"(warning: 'offset + 'count' must be less than or equal to the bit width of 'e')");
+ R"(warning: 'offset' + 'count' must be less than or equal to the bit width of 'e')");
}
TEST_F(ConstEvalRuntimeSemanticsTest, InsertBits_U32_TooManyBits) {
@@ -407,7 +407,7 @@
ASSERT_EQ(result, Success);
EXPECT_EQ(result.Get()->ValueAs<u32>(), 0x12345678);
EXPECT_EQ(error(),
- R"(warning: 'offset + 'count' must be less than or equal to the bit width of 'e')");
+ R"(warning: 'offset' + 'count' must be less than or equal to the bit width of 'e')");
}
TEST_F(ConstEvalRuntimeSemanticsTest, InverseSqrt_F32_OutOfRange) {
diff --git a/src/tint/lang/wgsl/resolver/builtins_validation_test.cc b/src/tint/lang/wgsl/resolver/builtins_validation_test.cc
index 2c3075f..25d4dc1 100644
--- a/src/tint/lang/wgsl/resolver/builtins_validation_test.cc
+++ b/src/tint/lang/wgsl/resolver/builtins_validation_test.cc
@@ -1709,5 +1709,119 @@
SmoothstepPartialConst,
::testing::ValuesIn(smoothstepCases()));
+// We'll construct cases like this:
+// fn foo() {
+// var e: ETYPE;
+// _ = insertBits(e, e, COUNT, OFFSET);
+// }
+
+struct InsertBitsPartialConstCase {
+ builder::ast_type_func_ptr eType;
+ ExprMaker makeOffset;
+ ExprMaker makeCount;
+ bool expectPass = true;
+ int width = 32;
+};
+
+using InsertBitsPartialConst =
+ ResolverBuiltinsValidationTestWithParams<std::tuple<InsertBitsPartialConstCase, bool, bool>>;
+
+TEST_P(InsertBitsPartialConst, Scalar) {
+ auto [params, firstConst, secondConst] = GetParam();
+ auto eTy = params.eType(*this);
+ const ast::Expression* offset = params.makeOffset(this);
+ const ast::Expression* count = params.makeCount(this);
+ const ast::Variable* offsetDecl;
+ if (firstConst) {
+ offsetDecl = Const("offset", offset);
+ } else {
+ offsetDecl = Var("offset", offset);
+ }
+ const ast::Variable* countDecl;
+ if (secondConst) {
+ countDecl = Const("count", count);
+ } else {
+ countDecl = Var("count", count);
+ }
+ WrapInFunction(Var("e", eTy), offsetDecl, countDecl,
+ Ignore(Call(Source{{12, 34}}, "insertBits", "e", "e", "offset", "count")));
+
+ const auto expectPass = params.expectPass || !(firstConst && secondConst);
+
+ if (expectPass) {
+ EXPECT_TRUE(r()->Resolve());
+ } else {
+ EXPECT_FALSE(r()->Resolve());
+ const std::string expect =
+ "12:34 error: 'offset' + 'count' must be less than or equal to the bit width of 'e'";
+ EXPECT_EQ(r()->error(), expect);
+ }
+}
+
+TEST_P(InsertBitsPartialConst, Vector) {
+ auto [params, firstConst, secondConst] = GetParam();
+ auto eTy = params.eType(*this);
+ const ast::Expression* offset = params.makeOffset(this);
+ const ast::Expression* count = params.makeCount(this);
+ const ast::Variable* offsetDecl;
+ if (firstConst) {
+ offsetDecl = Const("offset", offset);
+ } else {
+ offsetDecl = Var("offset", offset);
+ }
+ const ast::Variable* countDecl;
+ if (secondConst) {
+ countDecl = Const("count", count);
+ } else {
+ countDecl = Var("count", count);
+ }
+ WrapInFunction(Var("e", eTy), offsetDecl, countDecl, //
+ Ignore(Call(Source{{12, 34}}, "insertBits", //
+ Call(Ident("vec3"), "e", "e", "e"), //
+ Call(Ident("vec3"), "e", "e", "e"), //
+ "offset", "count")));
+
+ const auto expectPass = params.expectPass || !(firstConst && secondConst);
+
+ if (expectPass) {
+ EXPECT_TRUE(r()->Resolve());
+ } else {
+ EXPECT_FALSE(r()->Resolve());
+ const std::string expect =
+ "12:34 error: 'offset' + 'count' must be less than or equal to the bit width of 'e'";
+ EXPECT_EQ(r()->error(), expect);
+ }
+}
+
+std::vector<InsertBitsPartialConstCase> insertBitsCases() {
+ return std::vector<InsertBitsPartialConstCase>{
+ // Simple passing cases.
+ {DataType<u32>::AST, Mk(0_a), Mk(0_a), true},
+ {DataType<i32>::AST, Mk(0_a), Mk(0_a), true},
+ {DataType<u32>::AST, Mk(16_a), Mk(16_a), true},
+ {DataType<i32>::AST, Mk(16_a), Mk(16_a), true},
+ {DataType<u32>::AST, Mk(32_a), Mk(0_a), true},
+ {DataType<i32>::AST, Mk(32_a), Mk(0_a), true},
+ {DataType<u32>::AST, Mk(0_a), Mk(0_a), true},
+ {DataType<i32>::AST, Mk(0_a), Mk(0_a), true},
+ {DataType<u32>::AST, Mk(32_a), Mk(0_u), true},
+ {DataType<i32>::AST, Mk(32_u), Mk(0_a), true},
+
+ // AInt AInt
+ {DataType<u32>::AST, Mk(0_a), Mk(33_a), false},
+ {DataType<u32>::AST, Mk(16_a), Mk(17_a), false},
+ {DataType<u32>::AST, Mk(33_a), Mk(0_a), false},
+
+ {DataType<i32>::AST, Mk(0_a), Mk(33_u), false}, // Aint u32
+ {DataType<i32>::AST, Mk(16_u), Mk(17_a), false}, // u32 AInt
+ };
+}
+
+INSTANTIATE_TEST_SUITE_P(InsertBits,
+ InsertBitsPartialConst,
+ ::testing::Combine(::testing::ValuesIn(insertBitsCases()),
+ ::testing::ValuesIn({true}),
+ ::testing::ValuesIn({true})));
+
} // namespace
} // namespace tint::resolver
diff --git a/src/tint/lang/wgsl/resolver/resolver.cc b/src/tint/lang/wgsl/resolver/resolver.cc
index 954de2d..79069eb 100644
--- a/src/tint/lang/wgsl/resolver/resolver.cc
+++ b/src/tint/lang/wgsl/resolver/resolver.cc
@@ -2533,6 +2533,19 @@
}
}
} break;
+ case wgsl::BuiltinFn::kInsertBits: {
+ auto* offsetConst = ConvertConstArgument(args, target, 2);
+ auto* countConst = ConvertConstArgument(args, target, 3);
+ if (offsetConst && countConst) {
+ auto* zero = const_eval_.Zero(call->Type(), {}, Source{}).Get();
+ auto fakeArgs = Vector{zero, zero, offsetConst, countConst};
+ auto res =
+ const_eval_.insertBits(call->Type(), fakeArgs, call->Declaration()->source);
+ if (res != Success) {
+ return nullptr;
+ }
+ }
+ } break;
case wgsl::BuiltinFn::kSmoothstep: {
auto* lowConst = ConvertConstArgument(args, target, 0);
auto* highConst = ConvertConstArgument(args, target, 1);