[tint][ir][fuzz] Prevent fuzzer from crashing on malformed accesses - Adds support for variable length operands to checking logic Fixes: 353259704 Change-Id: I22dd9d03419bff9ae551b284f0d72ea1be13a754 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/198560 Reviewed-by: dan sinclair <dsinclair@chromium.org> Commit-Queue: Ryan Harrison <rharrison@chromium.org> Auto-Submit: Ryan Harrison <rharrison@chromium.org>
diff --git a/src/tint/lang/core/ir/access.h b/src/tint/lang/core/ir/access.h index 55a332c..0df6ee4 100644 --- a/src/tint/lang/core/ir/access.h +++ b/src/tint/lang/core/ir/access.h
@@ -44,6 +44,12 @@ /// The base offset in Operands() for the access indices static constexpr size_t kIndicesOperandOffset = 1; + /// The fixed number of results returned by this instruction + static constexpr size_t kNumResults = 1; + + /// The minimum number of operands used by this instruction + static constexpr size_t kMinNumOperands = 1; + /// Constructor (no results, no operands) Access();
diff --git a/src/tint/lang/core/ir/validator.cc b/src/tint/lang/core/ir/validator.cc index d3b7ad0..e8b1694 100644 --- a/src/tint/lang/core/ir/validator.cc +++ b/src/tint/lang/core/ir/validator.cc
@@ -316,6 +316,16 @@ /// @returns true if the operand is not null bool CheckOperandNotNull(const ir::Instruction* inst, size_t idx); + /// Checks the number of operands provided to @p inst and that none of them are null. + /// @param inst the instruction + /// @param min_count the minimum number of operands to expect + /// @param max_count the maximum number of operands to expect, if not set, than only the minimum + /// number is checked. + /// @returns true if the number of operands is in the expected range and none are null + bool CheckOperands(const ir::Instruction* inst, + size_t min_count, + std::optional<size_t> max_count); + /// Checks the number of operands for @p inst are exactly equal to @p count and that none of /// them are null. /// @param inst the instruction @@ -323,6 +333,19 @@ /// @returns true if the operands count is as expected and none are null bool CheckOperands(const ir::Instruction* inst, size_t count); + /// Checks the number of results for @p inst are exactly equal to @p num_results and the number + /// of operands is correctly. Both results and operands are confirmed to be non-null. + /// @param inst the instruction + /// @param num_results expected number of results for the instruction + /// @param min_operands the minimum number of operands to expect + /// @param max_operands the maximum number of operands to expect, if not set, than only the + /// minimum number is checked. + /// @returns true if the result and operand counts are as expected and none are null + bool CheckResultsAndOperandRange(const ir::Instruction* inst, + size_t num_results, + size_t min_operands, + std::optional<size_t> max_operands); + /// Checks the number of results and operands for @p inst are exactly equal to num_results /// and num_operands, respectively, and that none of them are null. /// @param inst the instruction @@ -802,6 +825,35 @@ return true; } +bool Validator::CheckOperands(const ir::Instruction* inst, + size_t min_count, + std::optional<size_t> max_count) { + if (TINT_UNLIKELY(inst->Operands().Length() < min_count)) { + if (max_count.has_value()) { + AddError(inst) << "expected between " << min_count << " and " << max_count.value() + << " operands, got " << inst->Operands().Length(); + } else { + AddError(inst) << "expected at least " << min_count << " operands, got " + << inst->Operands().Length(); + } + return false; + } + + if (TINT_UNLIKELY(max_count.has_value() && inst->Operands().Length() > max_count.value())) { + AddError(inst) << "expected between " << min_count << " and " << max_count.value() + << " operands, got " << inst->Operands().Length(); + return false; + } + + bool passed = true; + for (size_t i = 0; i < inst->Operands().Length(); i++) { + if (TINT_UNLIKELY(!CheckOperandNotNull(inst, i))) { + passed = false; + } + } + return passed; +} + bool Validator::CheckOperands(const ir::Instruction* inst, size_t count) { if (TINT_UNLIKELY(inst->Operands().Length() != count)) { AddError(inst) << "expected exactly " << count << " operands, got " @@ -818,6 +870,16 @@ return passed; } +bool Validator::CheckResultsAndOperandRange(const ir::Instruction* inst, + size_t num_results, + size_t min_operands, + std::optional<size_t> max_operands = {}) { + // Intentionally avoiding short-circuiting here + bool results_passed = CheckResults(inst, num_results); + bool operands_passed = CheckOperands(inst, min_operands, max_operands); + return results_passed && operands_passed; +} + bool Validator::CheckResultsAndOperands(const ir::Instruction* inst, size_t num_results, size_t num_operands) { @@ -1200,8 +1262,7 @@ } void Validator::CheckAccess(const Access* a) { - if (!a->Object()) { - AddError(a, Access::kObjectOperandOffset) << "null object"; + if (!CheckResultsAndOperandRange(a, Access::kNumResults, Access::kMinNumOperands)) { return; } @@ -1670,7 +1731,7 @@ } void Validator::CheckLoad(const Load* l) { - if (TINT_UNLIKELY(!CheckResultsAndOperands(l, Load::kNumResults, Load::kNumOperands))) { + if (!CheckResultsAndOperands(l, Load::kNumResults, Load::kNumOperands)) { return; } @@ -1690,7 +1751,7 @@ } void Validator::CheckStore(const Store* s) { - if (TINT_UNLIKELY(!CheckResultsAndOperands(s, Store::kNumResults, Store::kNumOperands))) { + if (!CheckResultsAndOperands(s, Store::kNumResults, Store::kNumOperands)) { return; }
diff --git a/src/tint/lang/core/ir/validator_test.cc b/src/tint/lang/core/ir/validator_test.cc index dc96213..8aa1830 100644 --- a/src/tint/lang/core/ir/validator_test.cc +++ b/src/tint/lang/core/ir/validator_test.cc
@@ -849,22 +849,23 @@ )"); } -TEST_F(IR_ValidatorTest, Access_NoObject) { +TEST_F(IR_ValidatorTest, Access_NoOperands) { auto* f = b.Function("my_func", ty.void_()); auto* obj = b.FunctionParam(ty.vec3<f32>()); f->SetParams({obj}); b.Append(f->Block(), [&] { - b.Access(ty.f32(), nullptr); + auto* access = b.Access(ty.f32(), obj, 0_i); + access->ClearOperands(); b.Return(f); }); auto res = ir::Validate(mod); ASSERT_NE(res, Success); EXPECT_EQ(res.Failure().reason.Str(), - R"(:3:21 error: access: null object - %3:f32 = access undef - ^^^^^ + R"(:3:14 error: access: expected at least 1 operands, got 0 + %3:f32 = access + ^^^^^^ :2:3 note: in block $B1: { @@ -873,7 +874,98 @@ note: # Disassembly %my_func = func(%2:vec3<f32>):void { $B1: { - %3:f32 = access undef + %3:f32 = access + ret + } +} +)"); +} + +TEST_F(IR_ValidatorTest, Access_NoResults) { + auto* f = b.Function("my_func", ty.void_()); + auto* obj = b.FunctionParam(ty.vec3<f32>()); + f->SetParams({obj}); + + b.Append(f->Block(), [&] { + auto* access = b.Access(ty.f32(), obj, 0_i); + access->ClearResults(); + b.Return(f); + }); + + auto res = ir::Validate(mod); + ASSERT_NE(res, Success); + EXPECT_EQ(res.Failure().reason.Str(), + R"(:3:13 error: access: expected exactly 1 results, got 0 + undef = access %2, 0i + ^^^^^^ + +:2:3 note: in block + $B1: { + ^^^ + +note: # Disassembly +%my_func = func(%2:vec3<f32>):void { + $B1: { + undef = access %2, 0i + ret + } +} +)"); +} + +TEST_F(IR_ValidatorTest, Access_NullObject) { + auto* f = b.Function("my_func", ty.void_()); + b.Append(f->Block(), [&] { + b.Access(ty.f32(), nullptr); + b.Return(f); + }); + + auto res = ir::Validate(mod); + ASSERT_NE(res, Success); + EXPECT_EQ(res.Failure().reason.Str(), + R"(:3:21 error: access: operand is undefined + %2:f32 = access undef + ^^^^^ + +:2:3 note: in block + $B1: { + ^^^ + +note: # Disassembly +%my_func = func():void { + $B1: { + %2:f32 = access undef + ret + } +} +)"); +} + +TEST_F(IR_ValidatorTest, Access_NullIndex) { + auto* f = b.Function("my_func", ty.void_()); + auto* obj = b.FunctionParam(ty.vec3<f32>()); + f->SetParams({obj}); + + b.Append(f->Block(), [&] { + b.Access(ty.f32(), obj, nullptr); + b.Return(f); + }); + + auto res = ir::Validate(mod); + ASSERT_NE(res, Success); + EXPECT_EQ(res.Failure().reason.Str(), + R"(:3:25 error: access: operand is undefined + %3:f32 = access %2, undef + ^^^^^ + +:2:3 note: in block + $B1: { + ^^^ + +note: # Disassembly +%my_func = func(%2:vec3<f32>):void { + $B1: { + %3:f32 = access %2, undef ret } }