reader/spirv: Return early if MakeExpression() fails
Bug: tint:762
Change-Id: If17e666cf459117c52ece73cb20ea8f70ed1fcd5
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/49531
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index 1955a27..5e206bb 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -2353,7 +2353,9 @@
const auto condition_id =
block_info.basic_block->terminator()->GetSingleWordInOperand(0);
auto* cond = MakeExpression(condition_id).expr;
-
+ if (!cond) {
+ return false;
+ }
// Generate the code for the condition.
auto* builder = AddStatementBuilder<IfStatementBuilder>(cond);
@@ -2484,7 +2486,9 @@
const auto selector_id = branch->GetSingleWordInOperand(0);
// Generate the code for the selector.
auto selector = MakeExpression(selector_id);
-
+ if (!selector.expr) {
+ return false;
+ }
// First, push the statement block for the entire switch.
auto* swch = AddStatementBuilder<SwitchStatementBuilder>(selector.expr);
@@ -2620,6 +2624,9 @@
return true;
case SpvOpReturnValue: {
auto value = MakeExpression(terminator.GetSingleWordInOperand(0));
+ if (!value.expr) {
+ return false;
+ }
AddStatement(create<ast::ReturnStatement>(Source{}, value.expr));
}
return true;
@@ -2663,6 +2670,9 @@
auto* const true_info = GetBlockInfo(true_dest);
auto* const false_info = GetBlockInfo(false_dest);
auto* cond = MakeExpression(terminator.GetSingleWordInOperand(0)).expr;
+ if (!cond) {
+ return false;
+ }
// We have two distinct destinations. But we only get here if this
// is a normal terminator; in particular the source block is *not* the
@@ -2931,6 +2941,9 @@
for (auto assignment : block_info.phi_assignments) {
const auto var_name = GetDefInfo(assignment.phi_id)->phi_var;
auto expr = MakeExpression(assignment.value);
+ if (!expr.expr) {
+ return false;
+ }
AddStatement(create<ast::AssignmentStatement>(
Source{},
create<ast::IdentifierExpression>(
@@ -3047,7 +3060,18 @@
auto ptr_id = inst.GetSingleWordInOperand(0);
const auto value_id = inst.GetSingleWordInOperand(1);
- auto rhs = MakeExpression(value_id);
+ const auto ptr_type_id = def_use_mgr_->GetDef(ptr_id)->type_id();
+ const auto& builtin_position_info = parser_impl_.GetBuiltInPositionInfo();
+ if (ptr_type_id == builtin_position_info.pointer_type_id) {
+ return Fail()
+ << "storing to the whole per-vertex structure is not supported: "
+ << inst.PrettyPrint();
+ }
+
+ TypedExpression rhs = MakeExpression(value_id);
+ if (!rhs.expr) {
+ return false;
+ }
// Handle exceptional cases
switch (GetSkipReason(ptr_id)) {
@@ -3079,16 +3103,12 @@
break;
}
- const auto ptr_type_id = def_use_mgr_->GetDef(ptr_id)->type_id();
- const auto& builtin_position_info = parser_impl_.GetBuiltInPositionInfo();
- if (ptr_type_id == builtin_position_info.pointer_type_id) {
- return Fail()
- << "storing to the whole per-vertex structure is not supported: "
- << inst.PrettyPrint();
- }
-
// Handle an ordinary store as an assignment.
auto lhs = MakeExpression(ptr_id);
+ if (!lhs.expr) {
+ return false;
+ }
+
AddStatement(
create<ast::AssignmentStatement>(Source{}, lhs.expr, rhs.expr));
return success();
@@ -3144,6 +3164,10 @@
break;
}
auto expr = MakeExpression(ptr_id);
+ if (!expr.expr) {
+ return false;
+ }
+
// The load result type is the pointee type of its operand.
TINT_ASSERT(expr.type->Is<sem::Pointer>());
expr.type = expr.type->As<sem::Pointer>()->type();
@@ -3161,6 +3185,9 @@
return true;
}
auto expr = MakeExpression(value_id);
+ if (!expr.type || !expr.expr) {
+ return false;
+ }
expr.type = RemapStorageClass(expr.type, result_id);
return EmitConstDefOrWriteToHoistedVar(inst, expr);
}
@@ -3231,6 +3258,9 @@
const spvtools::opt::Instruction& inst,
uint32_t operand_index) {
auto expr = this->MakeExpression(inst.GetSingleWordInOperand(operand_index));
+ if (!expr.expr) {
+ return {};
+ }
return parser_impl_.RectifyOperandSignedness(inst, std::move(expr));
}
@@ -3823,13 +3853,21 @@
for (uint32_t i = 2; i < inst.NumInOperands(); ++i) {
const auto index = inst.GetSingleWordInOperand(i);
if (index < vec0_len) {
+ auto expr = MakeExpression(vec0_id);
+ if (!expr.expr) {
+ return {};
+ }
values.emplace_back(create<ast::MemberAccessorExpression>(
- source, MakeExpression(vec0_id).expr, Swizzle(index)));
+ source, expr.expr, Swizzle(index)));
} else if (index < vec0_len + vec1_len) {
const auto sub_index = index - vec0_len;
TINT_ASSERT(sub_index < kMaxVectorLen);
+ auto expr = MakeExpression(vec1_id);
+ if (!expr.expr) {
+ return {};
+ }
values.emplace_back(create<ast::MemberAccessorExpression>(
- source, MakeExpression(vec1_id).expr, Swizzle(sub_index)));
+ source, expr.expr, Swizzle(sub_index)));
} else if (index == 0xFFFFFFFF) {
// By rule, this maps to OpUndef. Instead, make it zero.
values.emplace_back(parser_impl_.MakeNullValue(result_type->type()));
@@ -4954,12 +4992,17 @@
if (field_name.empty()) {
Fail() << "struct index out of bounds for array length: "
<< inst.PrettyPrint();
+ return {};
}
auto* member_ident = create<ast::IdentifierExpression>(
Source{}, builder_.Symbols().Register(field_name));
+ auto member_expr = MakeExpression(struct_ptr_id);
+ if (!member_expr.expr) {
+ return {};
+ }
auto* member_access = create<ast::MemberAccessorExpression>(
- Source{}, MakeExpression(struct_ptr_id).expr, member_ident);
+ Source{}, member_expr.expr, member_ident);
// Generate the intrinsic function call.
std::string call_ident_str = "arrayLength";