tint: fix insertBits edge case

If count is highest and offset is non-zero, or vice-versa, we'd overflow
the count + offset > bit-width check. This CL fixes this case.

Also folded in error tests into extractBits and insertBits.

Bug: tint:1581
Bug: chromium:53440
Change-Id: Id1e9e737b8076e8075da5992a41d18b6b7c8afd4
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/110482
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc
index 8dc4728..a7f5d7d 100644
--- a/src/tint/resolver/const_eval.cc
+++ b/src/tint/resolver/const_eval.cc
@@ -2092,19 +2092,19 @@
             NumberUT in_offset = args[2]->As<NumberUT>();
             NumberUT in_count = args[3]->As<NumberUT>();
 
-            constexpr UT w = sizeof(UT) * 8;
-            if ((in_offset + in_count) > w) {
-                AddError("'offset + 'count' must be less than or equal to the bit width of 'e'",
-                         source);
-                return utils::Failure;
-            }
-
             // Cast all to unsigned
             UT e = static_cast<UT>(in_e);
             UT newbits = static_cast<UT>(in_newbits);
             UT o = static_cast<UT>(in_offset);
             UT c = static_cast<UT>(in_count);
 
+            constexpr UT w = sizeof(UT) * 8;
+            if (o > w || c > w || (o + c) > w) {
+                AddError("'offset + 'count' must be less than or equal to the bit width of 'e'",
+                         source);
+                return utils::Failure;
+            }
+
             NumberT result;
             if (c == UT{0}) {
                 // The result is e if c is 0
diff --git a/src/tint/resolver/const_eval_builtin_test.cc b/src/tint/resolver/const_eval_builtin_test.cc
index 69a30a4..cc950ab 100644
--- a/src/tint/resolver/const_eval_builtin_test.cc
+++ b/src/tint/resolver/const_eval_builtin_test.cc
@@ -1150,6 +1150,26 @@
               T(0b1010'0101'1010'0101'1010'0111'1111'1101))),
     };
 
+    const char* error_msg =
+        "12:34 error: 'offset + 'count' must be less than or equal to the bit width of 'e'";
+    ConcatInto(  //
+        r, std::vector<Case>{
+               E({T(1), T(1), UT(33), UT(0)}, error_msg),         //
+               E({T(1), T(1), UT(34), UT(0)}, error_msg),         //
+               E({T(1), T(1), UT(1000), UT(0)}, error_msg),       //
+               E({T(1), T(1), UT::Highest(), UT()}, error_msg),   //
+               E({T(1), T(1), UT(0), UT(33)}, error_msg),         //
+               E({T(1), T(1), UT(0), UT(34)}, error_msg),         //
+               E({T(1), T(1), UT(0), UT(1000)}, error_msg),       //
+               E({T(1), T(1), UT(0), UT::Highest()}, error_msg),  //
+               E({T(1), T(1), UT(33), UT(33)}, error_msg),        //
+               E({T(1), T(1), UT(34), UT(34)}, error_msg),        //
+               E({T(1), T(1), UT(1000), UT(1000)}, error_msg),    //
+               E({T(1), T(1), UT::Highest(), UT(1)}, error_msg),
+               E({T(1), T(1), UT(1), UT::Highest()}, error_msg),
+               E({T(1), T(1), UT::Highest(), u32::Highest()}, error_msg),
+           });
+
     return r;
 }
 INSTANTIATE_TEST_SUITE_P(  //
@@ -1253,6 +1273,26 @@
               set_msbs_if_signed(T(0b11010001)))),
     };
 
+    const char* error_msg =
+        "12:34 error: 'offset + 'count' must be less than or equal to the bit width of 'e'";
+    ConcatInto(  //
+        r, std::vector<Case>{
+               E({T(1), UT(33), UT(0)}, error_msg),
+               E({T(1), UT(34), UT(0)}, error_msg),
+               E({T(1), UT(1000), UT(0)}, error_msg),
+               E({T(1), UT::Highest(), UT(0)}, error_msg),
+               E({T(1), UT(0), UT(33)}, error_msg),
+               E({T(1), UT(0), UT(34)}, error_msg),
+               E({T(1), UT(0), UT(1000)}, error_msg),
+               E({T(1), UT(0), UT::Highest()}, error_msg),
+               E({T(1), UT(33), UT(33)}, error_msg),
+               E({T(1), UT(34), UT(34)}, error_msg),
+               E({T(1), UT(1000), UT(1000)}, error_msg),
+               E({T(1), UT::Highest(), UT(1)}, error_msg),
+               E({T(1), UT(1), UT::Highest()}, error_msg),
+               E({T(1), UT::Highest(), UT::Highest()}, error_msg),
+           });
+
     return r;
 }
 INSTANTIATE_TEST_SUITE_P(  //
@@ -1262,35 +1302,6 @@
                      testing::ValuesIn(Concat(ExtractBitsCases<i32>(),  //
                                               ExtractBitsCases<u32>()))));
 
-using ResolverConstEvalBuiltinTest_ExtractBits_InvalidOffsetAndCount =
-    ResolverTestWithParam<std::tuple<size_t, size_t>>;
-TEST_P(ResolverConstEvalBuiltinTest_ExtractBits_InvalidOffsetAndCount, Test) {
-    auto& p = GetParam();
-    auto* expr = Call(Source{{12, 34}}, sem::str(sem::BuiltinType::kExtractBits), Expr(1_u),
-                      Expr(u32(std::get<0>(p))), Expr(u32(std::get<1>(p))));
-    GlobalConst("C", expr);
-    EXPECT_FALSE(r()->Resolve());
-    EXPECT_EQ(r()->error(),
-              "12:34 error: 'offset + 'count' must be less than or equal to the bit width of 'e'");
-}
-INSTANTIATE_TEST_SUITE_P(ExtractBits,
-                         ResolverConstEvalBuiltinTest_ExtractBits_InvalidOffsetAndCount,
-                         testing::Values(                         //
-                             std::make_tuple(33, 0),              //
-                             std::make_tuple(34, 0),              //
-                             std::make_tuple(1000, 0),            //
-                             std::make_tuple(u32::Highest(), 0),  //
-                             std::make_tuple(0, 33),              //
-                             std::make_tuple(0, 34),              //
-                             std::make_tuple(0, 1000),            //
-                             std::make_tuple(0, u32::Highest()),  //
-                             std::make_tuple(33, 33),             //
-                             std::make_tuple(34, 34),             //
-                             std::make_tuple(1000, 1000),         //
-                             std::make_tuple(u32::Highest(), 1),  //
-                             std::make_tuple(1, u32::Highest()),  //
-                             std::make_tuple(u32::Highest(), u32::Highest())));
-
 template <typename T>
 std::vector<Case> MaxCases() {
     return {