[tint] validate insertBits count+offset when other args are runtime-eval

count+offset must not exceed the bit width of the other arguments.

Bug: chromium:351372334
Change-Id: I41749e9700f247544c3e396c8362a395c3eac708
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/202554
Reviewed-by: James Price <jrprice@google.com>
Auto-Submit: David Neto <dneto@google.com>
Commit-Queue: David Neto <dneto@google.com>
diff --git a/src/dawn/tests/end2end/ShaderBuiltinPartialConstArgsErrorTests.cpp b/src/dawn/tests/end2end/ShaderBuiltinPartialConstArgsErrorTests.cpp
index 9f12a70..33bd6bc 100644
--- a/src/dawn/tests/end2end/ShaderBuiltinPartialConstArgsErrorTests.cpp
+++ b/src/dawn/tests/end2end/ShaderBuiltinPartialConstArgsErrorTests.cpp
@@ -83,8 +83,12 @@
     return o;
 }
 
+///////
+// clamp, smoothstep
+///////
+
 template <class Params>
-struct BuiltinPartialConstArgsErrorBase : public DawnTestWithParams<Params> {
+struct BuiltinPartialConstLowHighBase : public DawnTestWithParams<Params> {
     using DawnTestWithParams<Params>::GetParam;
     Phase mLowPhase = Phase::kConst;
     Phase mHighPhase = Phase::kConst;
@@ -95,15 +99,15 @@
 using LowPhase = Phase;
 using HighPhase = Phase;
 using Scalar = bool;
-DAWN_TEST_PARAM_STRUCT(BuiltinPartialConstArgsErrorTestParams,
+DAWN_TEST_PARAM_STRUCT(BuiltinPartialConstLowHighParams,
                        Builtin,
                        LowPhase,
                        HighPhase,
                        Compare,
                        Scalar);
 
-class ShaderBuiltinPartialConstArgsErrorTest
-    : public BuiltinPartialConstArgsErrorBase<BuiltinPartialConstArgsErrorTestParams> {
+class ShaderBuiltinPartialConstLowHighTest
+    : public BuiltinPartialConstLowHighBase<BuiltinPartialConstLowHighParams> {
   protected:
     std::string Shader() {
         const auto builtin = GetParam().mBuiltin;
@@ -159,7 +163,7 @@
     }
 };
 
-TEST_P(ShaderBuiltinPartialConstArgsErrorTest, All) {
+TEST_P(ShaderBuiltinPartialConstLowHighTest, All) {
     const auto builtin = GetParam().mBuiltin;
     const auto lowPhase = GetParam().mLowPhase;
     const auto highPhase = GetParam().mHighPhase;
@@ -186,8 +190,7 @@
     }
 }
 
-// DawnTestBase::CreateDeviceImpl always enables allow_unsafe_apis toggle.
-DAWN_INSTANTIATE_TEST_P(ShaderBuiltinPartialConstArgsErrorTest,
+DAWN_INSTANTIATE_TEST_P(ShaderBuiltinPartialConstLowHighTest,
                         {D3D11Backend(), D3D12Backend(), MetalBackend(), NullBackend(),
                          OpenGLBackend(), OpenGLESBackend(), VulkanBackend()},
                         {"clamp", "smoothstep"},                             // mBuiltin
@@ -196,5 +199,98 @@
                         {Compare::kLess, Compare::kEqual, Compare::kMore},   // mCompare
                         {true, false});                                      // Scalar (else Vector)
 
+///////
+// insertBits
+///////
+
+template <class Params>
+struct BuiltinPartialConstOffsetCountBase : public DawnTestWithParams<Params> {
+    using DawnTestWithParams<Params>::GetParam;
+    Phase mOffsetPhase = Phase::kConst;
+    Phase mCountPhase = Phase::kConst;
+};
+
+using Builtin = std::string;
+using OffsetPhase = Phase;
+using CountPhase = Phase;
+using OffsetTooBig = bool;
+using CountTooBig = bool;
+DAWN_TEST_PARAM_STRUCT(BuiltinPartialConstOffsetCountParams,
+                       Builtin,
+                       OffsetPhase,
+                       CountPhase,
+                       OffsetTooBig,
+                       CountTooBig);
+
+class ShaderBuiltinPartialConstOffsetCountTest
+    : public BuiltinPartialConstOffsetCountBase<BuiltinPartialConstOffsetCountParams> {
+  protected:
+    std::string Shader() {
+        // Assume 32-bit integers.
+        const auto builtin = GetParam().mBuiltin;
+        const int offset_val = 16 + (GetParam().mOffsetTooBig ? 1 : 0);
+        const int count_val = 16 + (GetParam().mCountTooBig ? 1 : 0);
+
+        std::stringstream code;
+        auto module_var = [&](std::string ident, Phase p, float value) {
+            if (p != Phase::kRuntime) {
+                code << p << " " << ident << ": u32 = " << value << ";\n";
+            }
+        };
+        auto function_var = [&](std::string ident, Phase p, int value) {
+            if (p == Phase::kRuntime) {
+                code << "  var " << ident << ": u32 = " << value << ";\n";
+            }
+        };
+        module_var("offset", GetParam().mOffsetPhase, offset_val);
+        module_var("count", GetParam().mCountPhase, count_val);
+        code << "@compute @workgroup_size(1) fn main() {\n";
+        function_var("offset", GetParam().mOffsetPhase, offset_val);
+        function_var("count", GetParam().mCountPhase, count_val);
+        code << "  var s: u32 = 0;\n";
+        code << " _ = insertBits(s,s,offset,count);\n";
+        code << "}";
+        return code.str();
+    }
+};
+
+TEST_P(ShaderBuiltinPartialConstOffsetCountTest, All) {
+    const auto builtin = GetParam().mBuiltin;
+    const auto offsetPhase = GetParam().mOffsetPhase;
+    const auto countPhase = GetParam().mCountPhase;
+    const auto offsetTooBig = GetParam().mOffsetTooBig;
+    const auto countTooBig = GetParam().mCountTooBig;
+    const auto wgsl = Shader();
+
+    const bool expect_create_shader_error = (offsetTooBig || countTooBig)    //
+                                            && offsetPhase == Phase::kConst  //
+                                            && countPhase == Phase::kConst;
+
+    if (expect_create_shader_error) {
+        ASSERT_DEVICE_ERROR(utils::CreateShaderModule(device, wgsl));
+    } else {
+        const bool expect_pipeline_error = (offsetTooBig || countTooBig)      //
+                                           && offsetPhase != Phase::kRuntime  //
+                                           && countPhase != Phase::kRuntime;
+        auto shader = utils::CreateShaderModule(device, wgsl);
+        wgpu::ComputePipelineDescriptor desc;
+        desc.compute.module = shader;
+        if (expect_pipeline_error) {
+            ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&desc));
+        } else {
+            device.CreateComputePipeline(&desc);
+        }
+    }
+}
+
+DAWN_INSTANTIATE_TEST_P(ShaderBuiltinPartialConstOffsetCountTest,
+                        {D3D11Backend(), D3D12Backend(), MetalBackend(), NullBackend(),
+                         OpenGLBackend(), OpenGLESBackend(), VulkanBackend()},
+                        {"insertBits"},                                      // mBuiltin
+                        {Phase::kConst, Phase::kOverride, Phase::kRuntime},  // mOffsetPhase
+                        {Phase::kConst, Phase::kOverride, Phase::kRuntime},  // mCountPhase
+                        {true, false},                                       // mOffsetTooBig
+                        {true, false});                                      // mCountTooBig
+
 }  // anonymous namespace
 }  // namespace dawn
diff --git a/src/tint/lang/core/constant/eval.cc b/src/tint/lang/core/constant/eval.cc
index 8db34d5..ded2d07 100644
--- a/src/tint/lang/core/constant/eval.cc
+++ b/src/tint/lang/core/constant/eval.cc
@@ -2627,7 +2627,7 @@
                 constexpr UT w = sizeof(UT) * 8;
                 if (o > w || c > w || (o + c) > w) {
                     AddError(source)
-                        << "'offset + 'count' must be less than or equal to the bit width of 'e'";
+                        << "'offset' + 'count' must be less than or equal to the bit width of 'e'";
                     if (use_runtime_semantics_) {
                         o = std::min(o, w);
                         c = std::min(c, w - o);
@@ -2896,7 +2896,7 @@
                 constexpr UT w = sizeof(UT) * 8;
                 if (o > w || c > w || (o + c) > w) {
                     AddError(source)
-                        << "'offset + 'count' must be less than or equal to the bit width of 'e'";
+                        << "'offset' + 'count' must be less than or equal to the bit width of 'e'";
                     if (use_runtime_semantics_) {
                         o = std::min(o, w);
                         c = std::min(c, w - o);
diff --git a/src/tint/lang/core/constant/eval_binary_op_test.cc b/src/tint/lang/core/constant/eval_binary_op_test.cc
index 812d633..ef9f163 100644
--- a/src/tint/lang/core/constant/eval_binary_op_test.cc
+++ b/src/tint/lang/core/constant/eval_binary_op_test.cc
@@ -2207,7 +2207,7 @@
 
     EXPECT_FALSE(r()->Resolve());
     EXPECT_EQ(r()->error(),
-              "12:34 error: 'offset + 'count' must be less than or equal to the bit width of 'e'");
+              "12:34 error: 'offset' + 'count' must be less than or equal to the bit width of 'e'");
 }
 
 TEST_F(ConstEvalTest, ShortCircuit_And_Error_BuiltinCall) {
@@ -2255,7 +2255,7 @@
 
     EXPECT_FALSE(r()->Resolve());
     EXPECT_EQ(r()->error(),
-              "12:34 error: 'offset + 'count' must be less than or equal to the bit width of 'e'");
+              "12:34 error: 'offset' + 'count' must be less than or equal to the bit width of 'e'");
 }
 
 TEST_F(ConstEvalTest, ShortCircuit_Or_Error_BuiltinCall) {
diff --git a/src/tint/lang/core/constant/eval_builtin_test.cc b/src/tint/lang/core/constant/eval_builtin_test.cc
index 4eb8152..0eaf4d2 100644
--- a/src/tint/lang/core/constant/eval_builtin_test.cc
+++ b/src/tint/lang/core/constant/eval_builtin_test.cc
@@ -1374,7 +1374,7 @@
     };
 
     const char* error_msg =
-        "12:34 error: 'offset + 'count' must be less than or equal to the bit width of 'e'";
+        "12:34 error: 'offset' + 'count' must be less than or equal to the bit width of 'e'";
     ConcatInto(  //
         r, std::vector<Case>{
                E({T(1), T(1), UT(33), UT(0)}, error_msg),         //
@@ -1582,7 +1582,7 @@
     };
 
     const char* error_msg =
-        "12:34 error: 'offset + 'count' must be less than or equal to the bit width of 'e'";
+        "12:34 error: 'offset' + 'count' must be less than or equal to the bit width of 'e'";
     ConcatInto(  //
         r, std::vector<Case>{
                E({T(1), UT(33), UT(0)}, error_msg),
diff --git a/src/tint/lang/core/constant/eval_runtime_semantics_test.cc b/src/tint/lang/core/constant/eval_runtime_semantics_test.cc
index ab244d5..be30634 100644
--- a/src/tint/lang/core/constant/eval_runtime_semantics_test.cc
+++ b/src/tint/lang/core/constant/eval_runtime_semantics_test.cc
@@ -372,7 +372,7 @@
     ASSERT_EQ(result, Success);
     EXPECT_EQ(result.Get()->ValueAs<i32>(), 0x12);
     EXPECT_EQ(error(),
-              R"(warning: 'offset + 'count' must be less than or equal to the bit width of 'e')");
+              R"(warning: 'offset' + 'count' must be less than or equal to the bit width of 'e')");
 }
 
 TEST_F(ConstEvalRuntimeSemanticsTest, ExtractBits_U32_TooManyBits) {
@@ -383,7 +383,7 @@
     ASSERT_EQ(result, Success);
     EXPECT_EQ(result.Get()->ValueAs<u32>(), 0x12);
     EXPECT_EQ(error(),
-              R"(warning: 'offset + 'count' must be less than or equal to the bit width of 'e')");
+              R"(warning: 'offset' + 'count' must be less than or equal to the bit width of 'e')");
 }
 
 TEST_F(ConstEvalRuntimeSemanticsTest, InsertBits_I32_TooManyBits) {
@@ -395,7 +395,7 @@
     ASSERT_EQ(result, Success);
     EXPECT_EQ(result.Get()->ValueAs<i32>(), 0x12345678);
     EXPECT_EQ(error(),
-              R"(warning: 'offset + 'count' must be less than or equal to the bit width of 'e')");
+              R"(warning: 'offset' + 'count' must be less than or equal to the bit width of 'e')");
 }
 
 TEST_F(ConstEvalRuntimeSemanticsTest, InsertBits_U32_TooManyBits) {
@@ -407,7 +407,7 @@
     ASSERT_EQ(result, Success);
     EXPECT_EQ(result.Get()->ValueAs<u32>(), 0x12345678);
     EXPECT_EQ(error(),
-              R"(warning: 'offset + 'count' must be less than or equal to the bit width of 'e')");
+              R"(warning: 'offset' + 'count' must be less than or equal to the bit width of 'e')");
 }
 
 TEST_F(ConstEvalRuntimeSemanticsTest, InverseSqrt_F32_OutOfRange) {
diff --git a/src/tint/lang/wgsl/resolver/builtins_validation_test.cc b/src/tint/lang/wgsl/resolver/builtins_validation_test.cc
index 2c3075f..25d4dc1 100644
--- a/src/tint/lang/wgsl/resolver/builtins_validation_test.cc
+++ b/src/tint/lang/wgsl/resolver/builtins_validation_test.cc
@@ -1709,5 +1709,119 @@
                          SmoothstepPartialConst,
                          ::testing::ValuesIn(smoothstepCases()));
 
+// We'll construct cases like this:
+// fn foo() {
+//   var e: ETYPE;
+//   _ = insertBits(e, e, COUNT, OFFSET);
+// }
+
+struct InsertBitsPartialConstCase {
+    builder::ast_type_func_ptr eType;
+    ExprMaker makeOffset;
+    ExprMaker makeCount;
+    bool expectPass = true;
+    int width = 32;
+};
+
+using InsertBitsPartialConst =
+    ResolverBuiltinsValidationTestWithParams<std::tuple<InsertBitsPartialConstCase, bool, bool>>;
+
+TEST_P(InsertBitsPartialConst, Scalar) {
+    auto [params, firstConst, secondConst] = GetParam();
+    auto eTy = params.eType(*this);
+    const ast::Expression* offset = params.makeOffset(this);
+    const ast::Expression* count = params.makeCount(this);
+    const ast::Variable* offsetDecl;
+    if (firstConst) {
+        offsetDecl = Const("offset", offset);
+    } else {
+        offsetDecl = Var("offset", offset);
+    }
+    const ast::Variable* countDecl;
+    if (secondConst) {
+        countDecl = Const("count", count);
+    } else {
+        countDecl = Var("count", count);
+    }
+    WrapInFunction(Var("e", eTy), offsetDecl, countDecl,
+                   Ignore(Call(Source{{12, 34}}, "insertBits", "e", "e", "offset", "count")));
+
+    const auto expectPass = params.expectPass || !(firstConst && secondConst);
+
+    if (expectPass) {
+        EXPECT_TRUE(r()->Resolve());
+    } else {
+        EXPECT_FALSE(r()->Resolve());
+        const std::string expect =
+            "12:34 error: 'offset' + 'count' must be less than or equal to the bit width of 'e'";
+        EXPECT_EQ(r()->error(), expect);
+    }
+}
+
+TEST_P(InsertBitsPartialConst, Vector) {
+    auto [params, firstConst, secondConst] = GetParam();
+    auto eTy = params.eType(*this);
+    const ast::Expression* offset = params.makeOffset(this);
+    const ast::Expression* count = params.makeCount(this);
+    const ast::Variable* offsetDecl;
+    if (firstConst) {
+        offsetDecl = Const("offset", offset);
+    } else {
+        offsetDecl = Var("offset", offset);
+    }
+    const ast::Variable* countDecl;
+    if (secondConst) {
+        countDecl = Const("count", count);
+    } else {
+        countDecl = Var("count", count);
+    }
+    WrapInFunction(Var("e", eTy), offsetDecl, countDecl,            //
+                   Ignore(Call(Source{{12, 34}}, "insertBits",      //
+                               Call(Ident("vec3"), "e", "e", "e"),  //
+                               Call(Ident("vec3"), "e", "e", "e"),  //
+                               "offset", "count")));
+
+    const auto expectPass = params.expectPass || !(firstConst && secondConst);
+
+    if (expectPass) {
+        EXPECT_TRUE(r()->Resolve());
+    } else {
+        EXPECT_FALSE(r()->Resolve());
+        const std::string expect =
+            "12:34 error: 'offset' + 'count' must be less than or equal to the bit width of 'e'";
+        EXPECT_EQ(r()->error(), expect);
+    }
+}
+
+std::vector<InsertBitsPartialConstCase> insertBitsCases() {
+    return std::vector<InsertBitsPartialConstCase>{
+        // Simple passing cases.
+        {DataType<u32>::AST, Mk(0_a), Mk(0_a), true},
+        {DataType<i32>::AST, Mk(0_a), Mk(0_a), true},
+        {DataType<u32>::AST, Mk(16_a), Mk(16_a), true},
+        {DataType<i32>::AST, Mk(16_a), Mk(16_a), true},
+        {DataType<u32>::AST, Mk(32_a), Mk(0_a), true},
+        {DataType<i32>::AST, Mk(32_a), Mk(0_a), true},
+        {DataType<u32>::AST, Mk(0_a), Mk(0_a), true},
+        {DataType<i32>::AST, Mk(0_a), Mk(0_a), true},
+        {DataType<u32>::AST, Mk(32_a), Mk(0_u), true},
+        {DataType<i32>::AST, Mk(32_u), Mk(0_a), true},
+
+        //  AInt AInt
+        {DataType<u32>::AST, Mk(0_a), Mk(33_a), false},
+        {DataType<u32>::AST, Mk(16_a), Mk(17_a), false},
+        {DataType<u32>::AST, Mk(33_a), Mk(0_a), false},
+
+        {DataType<i32>::AST, Mk(0_a), Mk(33_u), false},   // Aint u32
+        {DataType<i32>::AST, Mk(16_u), Mk(17_a), false},  // u32 AInt
+    };
+}
+
+INSTANTIATE_TEST_SUITE_P(InsertBits,
+                         InsertBitsPartialConst,
+                         ::testing::Combine(::testing::ValuesIn(insertBitsCases()),
+                                            ::testing::ValuesIn({true}),
+                                            ::testing::ValuesIn({true})));
+
 }  // namespace
 }  // namespace tint::resolver
diff --git a/src/tint/lang/wgsl/resolver/resolver.cc b/src/tint/lang/wgsl/resolver/resolver.cc
index 954de2d..79069eb 100644
--- a/src/tint/lang/wgsl/resolver/resolver.cc
+++ b/src/tint/lang/wgsl/resolver/resolver.cc
@@ -2533,6 +2533,19 @@
                 }
             }
         } break;
+        case wgsl::BuiltinFn::kInsertBits: {
+            auto* offsetConst = ConvertConstArgument(args, target, 2);
+            auto* countConst = ConvertConstArgument(args, target, 3);
+            if (offsetConst && countConst) {
+                auto* zero = const_eval_.Zero(call->Type(), {}, Source{}).Get();
+                auto fakeArgs = Vector{zero, zero, offsetConst, countConst};
+                auto res =
+                    const_eval_.insertBits(call->Type(), fakeArgs, call->Declaration()->source);
+                if (res != Success) {
+                    return nullptr;
+                }
+            }
+        } break;
         case wgsl::BuiltinFn::kSmoothstep: {
             auto* lowConst = ConvertConstArgument(args, target, 0);
             auto* highConst = ConvertConstArgument(args, target, 1);