[tint][ir][fuzz] Prevent encoding a binary that won't decode

Adds plumbing to get errors out of IR binary encoding that are not
ICEs. This is then used in the roundtrip fuzzer to reject inputs that
won't decode correctly due to internal limits.

Fixes: 375220551
Change-Id: I775c150f867124b3e30cc2161645134e88b2c625
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/212314
Commit-Queue: Ryan Harrison <rharrison@chromium.org>
Auto-Submit: Ryan Harrison <rharrison@chromium.org>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/cmd/fuzz/ir/as/main.cc b/src/tint/cmd/fuzz/ir/as/main.cc
index b81db3d..b0e6156 100644
--- a/src/tint/cmd/fuzz/ir/as/main.cc
+++ b/src/tint/cmd/fuzz/ir/as/main.cc
@@ -199,7 +199,11 @@
     tint::cmd::fuzz::ir::pb::Root fuzz_pb;
     {
         auto ir_pb = tint::core::ir::binary::EncodeToProto(module.Get());
-        fuzz_pb.set_allocated_module(ir_pb.release());
+        if (ir_pb != tint::Success) {
+            std::cerr << " Failed to encode IR to proto: " << ir_pb.Failure() << "\n";
+            return tint::Failure();
+        }
+        fuzz_pb.set_allocated_module(ir_pb.Get().release());
     }
 
     return std::move(fuzz_pb);
diff --git a/src/tint/lang/core/ir/binary/encode.cc b/src/tint/lang/core/ir/binary/encode.cc
index 561fb09..8471356 100644
--- a/src/tint/lang/core/ir/binary/encode.cc
+++ b/src/tint/lang/core/ir/binary/encode.cc
@@ -83,6 +83,7 @@
 #include "src/tint/lang/core/type/storage_texture.h"
 #include "src/tint/lang/core/type/u32.h"
 #include "src/tint/lang/core/type/void.h"
+#include "src/tint/utils/constants/internal_limits.h"
 #include "src/tint/utils/macros/compiler.h"
 #include "src/tint/utils/rtti/switch.h"
 
@@ -95,13 +96,16 @@
 struct Encoder {
     const Module& mod_in_;
     pb::Module& mod_out_;
+
     Hashmap<const core::ir::Function*, uint32_t, 32> functions_{};
     Hashmap<const core::ir::Block*, uint32_t, 32> blocks_{};
     Hashmap<const core::type::Type*, uint32_t, 32> types_{};
     Hashmap<const core::ir::Value*, uint32_t, 32> values_{};
     Hashmap<const core::constant::Value*, uint32_t, 32> constant_values_{};
 
-    void Encode() {
+    diag::List diags_{};
+
+    Result<SuccessType> Encode() {
         // Encode all user-declared structures first. This is to ensure that the IR disassembly
         // (which prints structure types first) does not reorder after encoding and decoding.
         for (auto* ty : mod_in_.Types()) {
@@ -119,8 +123,16 @@
             PopulateFunction(fns_out[i], mod_in_.functions[i]);
         }
         mod_out_.set_root_block(Block(mod_in_.root_block));
+
+        if (diags_.ContainsErrors()) {
+            return Failure{std::move(diags_)};
+        }
+        return Success;
     }
 
+    /// Adds a new error to the diagnostics and returns a reference to it
+    diag::Diagnostic& Error() { return diags_.AddError(Source{}); }
+
     ////////////////////////////////////////////////////////////////////////////
     // Functions
     ////////////////////////////////////////////////////////////////////////////
@@ -477,7 +489,13 @@
         array_out.set_stride(array_in->Stride());
         tint::Switch(
             array_in->Count(),  //
-            [&](const core::type::ConstantArrayCount* c) { array_out.set_count(c->value); },
+            [&](const core::type::ConstantArrayCount* c) {
+                array_out.set_count(c->value);
+                if (c->value >= internal_limits::kMaxArrayElementCount) {
+                    Error() << "array count (" << c->value << ") must be less than "
+                            << internal_limits::kMaxArrayElementCount;
+                }
+            },
             [&](const core::type::RuntimeArrayCount*) { array_out.set_count(0); },
             TINT_ICE_ON_NO_MATCH);
     }
@@ -647,6 +665,10 @@
     void ConstantValueSplat(pb::ConstantValueSplat& splat_out,
                             const core::constant::Splat* splat_in) {
         splat_out.set_type(Type(splat_in->type));
+        if (DAWN_UNLIKELY(splat_in->count > internal_limits::kMaxArrayConstructorElements)) {
+            Error() << "array constructor has excessive number of elements (>"
+                    << internal_limits::kMaxArrayConstructorElements << ")";
+        }
         splat_out.set_elements(ConstantValue(splat_in->el));
         splat_out.set_count(static_cast<uint32_t>(splat_in->count));
     }
@@ -1220,23 +1242,29 @@
 
 }  // namespace
 
-std::unique_ptr<pb::Module> EncodeToProto(const Module& mod_in) {
+Result<std::unique_ptr<pb::Module>> EncodeToProto(const Module& mod_in) {
     GOOGLE_PROTOBUF_VERIFY_VERSION;
 
     pb::Module mod_out;
-    Encoder{mod_in, mod_out}.Encode();
+    auto res = Encoder{mod_in, mod_out}.Encode();
+    if (res != Success) {
+        return res.Failure();
+    }
 
     return std::make_unique<pb::Module>(mod_out);
 }
 
 Result<Vector<std::byte, 0>> EncodeToBinary(const Module& mod_in) {
     auto mod_out = EncodeToProto(mod_in);
+    if (mod_out != Success) {
+        return mod_out.Failure();
+    }
 
     Vector<std::byte, 0> buffer;
-    size_t len = mod_out->ByteSizeLong();
+    size_t len = mod_out.Get()->ByteSizeLong();
     buffer.Resize(len);
     if (len > 0) {
-        if (!mod_out->SerializeToArray(&buffer[0], static_cast<int>(len))) {
+        if (!mod_out.Get()->SerializeToArray(&buffer[0], static_cast<int>(len))) {
             return Failure{"failed to serialize protobuf"};
         }
     }
diff --git a/src/tint/lang/core/ir/binary/encode.h b/src/tint/lang/core/ir/binary/encode.h
index 3920542..82e94f6 100644
--- a/src/tint/lang/core/ir/binary/encode.h
+++ b/src/tint/lang/core/ir/binary/encode.h
@@ -46,7 +46,7 @@
 namespace tint::core::ir::binary {
 
 // Encode the module into a proto representation.
-std::unique_ptr<pb::Module> EncodeToProto(const Module& module);
+Result<std::unique_ptr<pb::Module>> EncodeToProto(const Module& module);
 
 // Encode the module into a binary representation.
 Result<Vector<std::byte, 0>> EncodeToBinary(const Module& module);
diff --git a/src/tint/lang/core/ir/binary/roundtrip_fuzz.cc b/src/tint/lang/core/ir/binary/roundtrip_fuzz.cc
index 5622e09..17cfe54 100644
--- a/src/tint/lang/core/ir/binary/roundtrip_fuzz.cc
+++ b/src/tint/lang/core/ir/binary/roundtrip_fuzz.cc
@@ -29,6 +29,7 @@
 #include "src/tint/lang/core/ir/binary/decode.h"
 #include "src/tint/lang/core/ir/binary/encode.h"
 #include "src/tint/lang/core/ir/disassembler.h"
+#include "src/tint/lang/core/ir/validator.h"
 
 namespace tint::core::ir::binary {
 namespace {
@@ -36,7 +37,11 @@
 void IRBinaryRoundtripFuzzer(core::ir::Module& module) {
     auto encoded = EncodeToBinary(module);
     if (encoded != Success) {
-        TINT_ICE() << "Encode() failed\n" << encoded.Failure();
+        // Failing to encode, not ICE'ing, indicates that an internal limit to the IR binary
+        // encoding/decoding logic was hit. Due to differences between the AST and IR
+        // implementations, there exist corner cases where these internal limits are hit for IR,
+        // but not AST.
+        return;
     }
 
     auto decoded = Decode(encoded->Slice());