spirv-reader: Use GenerateExpressionWithLoadIfNeeded more
* Rename GenerateNonReferenceExpression to
GenerateExpressionWithLoadIfNeeded.
This version takes an ast::Expression
* Add a variant that takes a sem::Expression, because the sem
expression already knows the resolved type, and so we can save
a lookup.
* Replace most uses of GenerateExpression ... GenerateLoadIfNeeded
with a call to one of the above.
This is a non-functional change.
Followup to the fix in tint:1343.
Bug: tint:1343
Change-Id: If19a1bc7670edd2badc1533861d8b42f0825c7b8
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/72720
Auto-Submit: David Neto <dneto@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: David Neto <dneto@google.com>
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index a0eba5a..04334a2 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -408,15 +408,10 @@
if (lhs_id == 0) {
return false;
}
- auto rhs_id = GenerateExpression(assign->rhs);
+ auto rhs_id = GenerateExpressionWithLoadIfNeeded(assign->rhs);
if (rhs_id == 0) {
return false;
}
-
- // If the thing we're assigning is a reference then we must load it first.
- auto* type = TypeOf(assign->rhs);
- rhs_id = GenerateLoadIfNeeded(type, rhs_id);
-
return GenerateStore(lhs_id, rhs_id);
}
}
@@ -706,14 +701,10 @@
bool Builder::GenerateFunctionVariable(const ast::Variable* var) {
uint32_t init_id = 0;
if (var->constructor) {
- init_id = GenerateExpression(var->constructor);
+ init_id = GenerateExpressionWithLoadIfNeeded(var->constructor);
if (init_id == 0) {
return false;
}
- auto* type = TypeOf(var->constructor);
- if (type->Is<sem::Reference>()) {
- init_id = GenerateLoadIfNeeded(type, init_id);
- }
}
if (var->is_const) {
@@ -914,12 +905,10 @@
bool Builder::GenerateIndexAccessor(const ast::IndexAccessorExpression* expr,
AccessorInfo* info) {
- auto idx_id = GenerateExpression(expr->index);
+ auto idx_id = GenerateExpressionWithLoadIfNeeded(expr->index);
if (idx_id == 0) {
return 0;
}
- auto* type = TypeOf(expr->index);
- idx_id = GenerateLoadIfNeeded(type, idx_id);
// If the source is a reference, we access chain into it.
// In the future, pointers may support access-chaining.
@@ -1183,8 +1172,19 @@
return val;
}
-uint32_t Builder::GenerateNonReferenceExpression(const ast::Expression* expr) {
+uint32_t Builder::GenerateExpressionWithLoadIfNeeded(
+ const sem::Expression* expr) {
+ // The semantic node directly knows both the AST node and the resolved type.
+ if (const auto id = GenerateExpression(expr->Declaration())) {
+ return GenerateLoadIfNeeded(expr->Type(), id);
+ }
+ return 0;
+}
+
+uint32_t Builder::GenerateExpressionWithLoadIfNeeded(
+ const ast::Expression* expr) {
if (const auto id = GenerateExpression(expr)) {
+ // Perform a lookup to get the resolved type.
return GenerateLoadIfNeeded(TypeOf(expr), id);
}
return 0;
@@ -1212,11 +1212,6 @@
auto result = result_op();
auto result_id = result.to_i();
- auto val_id = GenerateExpression(expr->expr);
- if (val_id == 0) {
- return 0;
- }
-
spv::Op op = spv::Op::OpNop;
switch (expr->op) {
case ast::UnaryOp::kComplement:
@@ -1237,10 +1232,13 @@
// Address-of converts a reference to a pointer, and dereference converts
// a pointer to a reference. These are the same thing in SPIR-V, so this
// is a no-op.
- return val_id;
+ return GenerateExpression(expr->expr);
}
- val_id = GenerateLoadIfNeeded(TypeOf(expr->expr), val_id);
+ auto val_id = GenerateExpressionWithLoadIfNeeded(expr->expr);
+ if (val_id == 0) {
+ return 0;
+ }
auto type_id = GenerateTypeIfNeeded(TypeOf(expr));
if (type_id == 0) {
@@ -1380,11 +1378,7 @@
OperandList ops;
for (auto* e : args) {
uint32_t id = 0;
- id = GenerateExpression(e->Declaration());
- if (id == 0) {
- return 0;
- }
- id = GenerateLoadIfNeeded(e->Type(), id);
+ id = GenerateExpressionWithLoadIfNeeded(e);
if (id == 0) {
return 0;
}
@@ -1532,11 +1526,10 @@
return 0;
}
- auto val_id = GenerateExpression(from_expr);
+ auto val_id = GenerateExpressionWithLoadIfNeeded(from_expr);
if (val_id == 0) {
return 0;
}
- val_id = GenerateLoadIfNeeded(TypeOf(from_expr), val_id);
auto* from_type = TypeOf(from_expr)->UnwrapRef();
@@ -1804,11 +1797,10 @@
uint32_t Builder::GenerateShortCircuitBinaryExpression(
const ast::BinaryExpression* expr) {
- auto lhs_id = GenerateExpression(expr->lhs);
+ auto lhs_id = GenerateExpressionWithLoadIfNeeded(expr->lhs);
if (lhs_id == 0) {
return false;
}
- lhs_id = GenerateLoadIfNeeded(TypeOf(expr->lhs), lhs_id);
// Get the ID of the basic block where control flow will diverge. It's the
// last basic block generated for the left-hand-side of the operator.
@@ -1848,11 +1840,10 @@
if (!GenerateLabel(block_id)) {
return 0;
}
- auto rhs_id = GenerateExpression(expr->rhs);
+ auto rhs_id = GenerateExpressionWithLoadIfNeeded(expr->rhs);
if (rhs_id == 0) {
return 0;
}
- rhs_id = GenerateLoadIfNeeded(TypeOf(expr->rhs), rhs_id);
// Get the block ID of the last basic block generated for the right-hand-side
// expression. That block will be an immediate predecessor to the merge block.
@@ -1971,17 +1962,15 @@
return GenerateShortCircuitBinaryExpression(expr);
}
- auto lhs_id = GenerateExpression(expr->lhs);
+ auto lhs_id = GenerateExpressionWithLoadIfNeeded(expr->lhs);
if (lhs_id == 0) {
return 0;
}
- lhs_id = GenerateLoadIfNeeded(TypeOf(expr->lhs), lhs_id);
- auto rhs_id = GenerateExpression(expr->rhs);
+ auto rhs_id = GenerateExpressionWithLoadIfNeeded(expr->rhs);
if (rhs_id == 0) {
return 0;
}
- rhs_id = GenerateLoadIfNeeded(TypeOf(expr->rhs), rhs_id);
auto result = result_op();
auto result_id = result.to_i();
@@ -2258,11 +2247,7 @@
size_t arg_idx = 0;
for (auto* arg : expr->args) {
- auto id = GenerateExpression(arg);
- if (id == 0) {
- return 0;
- }
- id = GenerateLoadIfNeeded(TypeOf(arg), id);
+ auto id = GenerateExpressionWithLoadIfNeeded(arg);
if (id == 0) {
return 0;
}
@@ -2715,12 +2700,7 @@
// Generates the given expression, returning the operand ID
auto gen = [&](const sem::Expression* expr) {
- auto val_id = GenerateExpression(expr->Declaration());
- if (val_id == 0) {
- return Operand::Int(0);
- }
- val_id = GenerateLoadIfNeeded(expr->Type(), val_id);
-
+ const auto val_id = GenerateExpressionWithLoadIfNeeded(expr);
return Operand::Int(val_id);
};
@@ -3218,11 +3198,7 @@
uint32_t value_id = 0;
if (call->Arguments().size() > 1) {
- value_id = GenerateExpression(call->Arguments().back()->Declaration());
- if (value_id == 0) {
- return false;
- }
- value_id = GenerateLoadIfNeeded(call->Arguments().back()->Type(), value_id);
+ value_id = GenerateExpressionWithLoadIfNeeded(call->Arguments().back());
if (value_id == 0) {
return false;
}
@@ -3458,11 +3434,10 @@
return 0;
}
- auto val_id = GenerateExpression(expr->expr);
+ auto val_id = GenerateExpressionWithLoadIfNeeded(expr->expr);
if (val_id == 0) {
return 0;
}
- val_id = GenerateLoadIfNeeded(TypeOf(expr->expr), val_id);
// Bitcast does not allow same types, just emit a CopyObject
auto* to_type = TypeOf(expr)->UnwrapRef();
@@ -3489,11 +3464,10 @@
const ast::BlockStatement* true_body,
size_t cur_else_idx,
const ast::ElseStatementList& else_stmts) {
- auto cond_id = GenerateExpression(cond);
+ auto cond_id = GenerateExpressionWithLoadIfNeeded(cond);
if (cond_id == 0) {
return false;
}
- cond_id = GenerateLoadIfNeeded(TypeOf(cond), cond_id);
auto merge_block = result_op();
auto merge_block_id = merge_block.to_i();
@@ -3585,7 +3559,7 @@
if (is_just_a_break(stmt->body) && stmt->else_statements.empty()) {
// It's a break-if.
TINT_ASSERT(Writer, !backedge_stack_.empty());
- const auto cond_id = GenerateNonReferenceExpression(stmt->condition);
+ const auto cond_id = GenerateExpressionWithLoadIfNeeded(stmt->condition);
if (!cond_id) {
return false;
}
@@ -3600,7 +3574,8 @@
is_just_a_break(es.back()->body)) {
// It's a break-unless.
TINT_ASSERT(Writer, !backedge_stack_.empty());
- const auto cond_id = GenerateNonReferenceExpression(stmt->condition);
+ const auto cond_id =
+ GenerateExpressionWithLoadIfNeeded(stmt->condition);
if (!cond_id) {
return false;
}
@@ -3626,11 +3601,10 @@
merge_stack_.push_back(merge_block_id);
- auto cond_id = GenerateExpression(stmt->condition);
+ auto cond_id = GenerateExpressionWithLoadIfNeeded(stmt->condition);
if (cond_id == 0) {
return false;
}
- cond_id = GenerateLoadIfNeeded(TypeOf(stmt->condition), cond_id);
auto default_block = result_op();
auto default_block_id = default_block.to_i();
@@ -3724,11 +3698,10 @@
bool Builder::GenerateReturnStatement(const ast::ReturnStatement* stmt) {
if (stmt->value) {
- auto val_id = GenerateExpression(stmt->value);
+ auto val_id = GenerateExpressionWithLoadIfNeeded(stmt->value);
if (val_id == 0) {
return false;
}
- val_id = GenerateLoadIfNeeded(TypeOf(stmt->value), val_id);
if (!push_function_inst(spv::Op::OpReturnValue, {Operand::Int(val_id)})) {
return false;
}
diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h
index 1b43ee4..d2b5237 100644
--- a/src/writer/spirv/builder.h
+++ b/src/writer/spirv/builder.h
@@ -458,9 +458,16 @@
/// type, then return the SPIR-V ID for the expression. Otherwise implement
/// the WGSL Load Rule: generate an OpLoad and return the ID of the result.
/// Returns 0 if the expression could not be generated.
- /// @param expr the expression to be generate
+ /// @param expr the semantic expression node to be generated
/// @returns the the ID of the expression, or loaded expression
- uint32_t GenerateNonReferenceExpression(const ast::Expression* expr);
+ uint32_t GenerateExpressionWithLoadIfNeeded(const sem::Expression* expr);
+ /// Generates an expression. If the WGSL expression does not have reference
+ /// type, then return the SPIR-V ID for the expression. Otherwise implement
+ /// the WGSL Load Rule: generate an OpLoad and return the ID of the result.
+ /// Returns 0 if the expression could not be generated.
+ /// @param expr the AST expression to be generated
+ /// @returns the the ID of the expression, or loaded expression
+ uint32_t GenerateExpressionWithLoadIfNeeded(const ast::Expression* expr);
/// Generates an OpLoad on the given ID if it has reference type in WGSL,
/// othewrise return the ID itself.
/// @param type the type of the expression