Clean up the ScopeStack interface
There's no need for the ScopeStack to include 'global' information. This
is easily obtainable from the element type.
Replace the get-by-reference, with a simpler return value.
Change-Id: Ic6f4c0f656a2019417d68ffb3fe85ba8343ad15e
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/68403
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index fb1b25d..6836f12 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -656,7 +656,7 @@
if (!info) {
return false;
}
- variable_stack_.set_global(var->symbol, info);
+ variable_stack_.Set(var->symbol, info);
if (!var->is_const && info->storage_class == ast::StorageClass::kNone) {
AddError("global variables must have a storage class", var->source);
@@ -1769,7 +1769,7 @@
TINT_SCOPED_ASSIGNMENT(current_function_, info);
- variable_stack_.push_scope();
+ variable_stack_.Push();
uint32_t parameter_index = 0;
std::unordered_map<Symbol, Source> parameter_names;
for (auto* param : func->params) {
@@ -1798,7 +1798,7 @@
return false;
}
- variable_stack_.set(param->symbol, param_info);
+ variable_stack_.Set(param->symbol, param_info);
info->parameters.emplace_back(param_info);
if (!ApplyStorageClassUsageToType(param->declared_storage_class,
@@ -1887,7 +1887,7 @@
return false;
}
}
- variable_stack_.pop_scope();
+ variable_stack_.Pop();
for (auto* deco : func->decorations) {
Mark(deco);
@@ -1952,9 +1952,8 @@
if (auto* ident = expr->As<ast::IdentifierExpression>()) {
// We have an identifier of a module-scope constant.
- VariableInfo* var = nullptr;
- if (!variable_stack_.get(ident->symbol, &var) ||
- !(var->declaration->is_const)) {
+ VariableInfo* var = variable_stack_.Get(ident->symbol);
+ if (!var || !(var->declaration->is_const)) {
AddError(kErrBadType, expr->source);
return false;
}
@@ -2635,8 +2634,8 @@
if (param->declaration->type->Is<ast::Pointer>()) {
auto is_valid = false;
if (auto* ident_expr = arg_expr->As<ast::IdentifierExpression>()) {
- VariableInfo* var;
- if (!variable_stack_.get(ident_expr->symbol, &var)) {
+ VariableInfo* var = variable_stack_.Get(ident_expr->symbol);
+ if (!var) {
TINT_ICE(Resolver, diagnostics_) << "failed to resolve identifier";
return false;
}
@@ -2647,8 +2646,8 @@
if (unary->op == ast::UnaryOp::kAddressOf) {
if (auto* ident_unary =
unary->expr->As<ast::IdentifierExpression>()) {
- VariableInfo* var;
- if (!variable_stack_.get(ident_unary->symbol, &var)) {
+ VariableInfo* var = variable_stack_.Get(ident_unary->symbol);
+ if (!var) {
TINT_ICE(Resolver, diagnostics_)
<< "failed to resolve identifier";
return false;
@@ -2987,8 +2986,7 @@
bool Resolver::Identifier(const ast::IdentifierExpression* expr) {
auto symbol = expr->symbol;
- VariableInfo* var;
- if (variable_stack_.get(symbol, &var)) {
+ if (VariableInfo* var = variable_stack_.Get(symbol)) {
SetExprInfo(expr, var->type, var->type_name);
var->users.push_back(expr);
@@ -3485,7 +3483,7 @@
}
}
- variable_stack_.set(var->symbol, info);
+ variable_stack_.Set(var->symbol, info);
if (current_block_) { // Not all statements are inside a block
current_block_->AddDecl(var);
}
@@ -3909,9 +3907,8 @@
if (auto* ident = count_expr->As<ast::IdentifierExpression>()) {
// Make sure the identifier is a non-overridable module-scope constant.
- VariableInfo* var = nullptr;
- bool is_global = false;
- if (!variable_stack_.get(ident->symbol, &var, &is_global) || !is_global ||
+ VariableInfo* var = variable_stack_.Get(ident->symbol);
+ if (!var || var->kind != VariableKind::kGlobal ||
!var->declaration->is_const) {
AddError("array size identifier must be a module-scope constant",
size_source);
@@ -4489,8 +4486,7 @@
auto const* lhs_type = TypeOf(a->lhs);
if (auto* ident = a->lhs->As<ast::IdentifierExpression>()) {
- VariableInfo* var;
- if (variable_stack_.get(ident->symbol, &var)) {
+ if (VariableInfo* var = variable_stack_.Get(ident->symbol)) {
if (var->kind == VariableKind::kParameter) {
AddError("cannot assign to function parameter", a->lhs->source);
AddNote("'" + builder_->Symbols().NameFor(ident->symbol) +
@@ -4542,10 +4538,8 @@
const Source& source,
bool check_global_scope_only) {
if (check_global_scope_only) {
- bool is_global = false;
- VariableInfo* var;
- if (variable_stack_.get(sym, &var, &is_global)) {
- if (is_global) {
+ if (VariableInfo* var = variable_stack_.Get(sym)) {
+ if (var->kind == VariableKind::kGlobal) {
AddError("redefinition of '" + builder_->Symbols().NameFor(sym) + "'",
source);
AddNote("previous definition is here", var->declaration->source);
@@ -4560,8 +4554,7 @@
return false;
}
} else {
- VariableInfo* var;
- if (variable_stack_.get(sym, &var)) {
+ if (VariableInfo* var = variable_stack_.Get(sym)) {
AddError("redefinition of '" + builder_->Symbols().NameFor(sym) + "'",
source);
AddNote("previous definition is here", var->declaration->source);
@@ -4635,10 +4628,10 @@
current_statement_ = stmt;
current_compound_statement_ = stmt;
current_block_ = stmt->As<sem::BlockStatement>();
- variable_stack_.push_scope();
+ variable_stack_.Push();
TINT_DEFER({
- TINT_DEFER(variable_stack_.pop_scope());
+ TINT_DEFER(variable_stack_.Pop());
current_block_ = prev_current_block;
current_compound_statement_ = prev_current_compound_statement;
current_statement_ = prev_current_statement;
diff --git a/src/scope_stack.h b/src/scope_stack.h
index a1619c4..f8ddadc 100644
--- a/src/scope_stack.h
+++ b/src/scope_stack.h
@@ -36,60 +36,33 @@
~ScopeStack() = default;
/// Push a new scope on to the stack
- void push_scope() { stack_.push_back({}); }
+ void Push() { stack_.push_back({}); }
/// Pop the scope off the top of the stack
- void pop_scope() {
+ void Pop() {
if (stack_.size() > 1) {
stack_.pop_back();
}
}
- /// Set a global variable in the stack
+ /// Assigns the value into the top most scope of the stack
/// @param symbol the symbol of the variable
/// @param val the value
- void set_global(const Symbol& symbol, T val) { stack_[0][symbol] = val; }
+ void Set(const Symbol& symbol, T val) { stack_.back()[symbol] = val; }
- /// Sets variable into the top most scope of the stack
- /// @param symbol the symbol of the variable
- /// @param val the value
- void set(const Symbol& symbol, T val) { stack_.back()[symbol] = val; }
-
- /// Checks for the given `symbol` in the stack
+ /// Retrieves a value from the stack
/// @param symbol the symbol to look for
- /// @returns true if the stack contains `symbol`
- bool has(const Symbol& symbol) const { return get(symbol, nullptr); }
-
- /// Retrieves a given variable from the stack
- /// @param symbol the symbol to look for
- /// @param ret where to place the value
- /// @returns true if the symbol was successfully found, false otherwise
- bool get(const Symbol& symbol, T* ret) const {
- return get(symbol, ret, nullptr);
- }
-
- /// Retrieves a given variable from the stack
- /// @param symbol the symbol to look for
- /// @param ret where to place the value
- /// @param is_global set true if the symbol references a global variable
- /// otherwise unchanged
- /// @returns true if the symbol was successfully found, false otherwise
- bool get(const Symbol& symbol, T* ret, bool* is_global) const {
+ /// @returns the variable, or the zero initializer if the value was not found
+ T Get(const Symbol& symbol) const {
for (auto iter = stack_.rbegin(); iter != stack_.rend(); ++iter) {
auto& map = *iter;
auto val = map.find(symbol);
-
if (val != map.end()) {
- if (ret) {
- *ret = val->second;
- }
- if (is_global && iter == stack_.rend() - 1) {
- *is_global = true;
- }
- return true;
+ return val->second;
}
}
- return false;
+
+ return T{};
}
private:
diff --git a/src/scope_stack_test.cc b/src/scope_stack_test.cc
index 0006347..f96f0d1 100644
--- a/src/scope_stack_test.cc
+++ b/src/scope_stack_test.cc
@@ -21,78 +21,32 @@
class ScopeStackTest : public ProgramBuilder, public testing::Test {};
-TEST_F(ScopeStackTest, Global) {
+TEST_F(ScopeStackTest, Get) {
ScopeStack<uint32_t> s;
- Symbol sym(1, ID());
- s.set_global(sym, 5);
+ Symbol a(1, ID());
+ Symbol b(3, ID());
+ s.Push();
+ s.Set(a, 5u);
+ s.Set(b, 10u);
- uint32_t val = 0;
- EXPECT_TRUE(s.get(sym, &val));
- EXPECT_EQ(val, 5u);
-}
+ EXPECT_EQ(s.Get(a), 5u);
+ EXPECT_EQ(s.Get(b), 10u);
-TEST_F(ScopeStackTest, Global_SetWithPointer) {
- auto* v = Var("my_var", ty.f32(), ast::StorageClass::kNone);
- ScopeStack<const ast::Variable*> s;
- s.set_global(v->symbol, v);
+ s.Push();
- const ast::Variable* v2 = nullptr;
- EXPECT_TRUE(s.get(v->symbol, &v2));
- EXPECT_EQ(v2->symbol, v->symbol);
-}
+ s.Set(a, 15u);
+ EXPECT_EQ(s.Get(a), 15u);
+ EXPECT_EQ(s.Get(b), 10u);
-TEST_F(ScopeStackTest, Global_CanNotPop) {
- ScopeStack<uint32_t> s;
- Symbol sym(1, ID());
- s.set_global(sym, 5);
- s.pop_scope();
-
- uint32_t val = 0;
- EXPECT_TRUE(s.get(sym, &val));
- EXPECT_EQ(val, 5u);
-}
-
-TEST_F(ScopeStackTest, Scope) {
- ScopeStack<uint32_t> s;
- Symbol sym(1, ID());
- s.push_scope();
- s.set(sym, 5);
-
- uint32_t val = 0;
- EXPECT_TRUE(s.get(sym, &val));
- EXPECT_EQ(val, 5u);
+ s.Pop();
+ EXPECT_EQ(s.Get(a), 5u);
+ EXPECT_EQ(s.Get(b), 10u);
}
TEST_F(ScopeStackTest, Get_MissingSymbol) {
ScopeStack<uint32_t> s;
Symbol sym(1, ID());
- uint32_t ret = 0;
- EXPECT_FALSE(s.get(sym, &ret));
- EXPECT_EQ(ret, 0u);
-}
-
-TEST_F(ScopeStackTest, Has) {
- ScopeStack<uint32_t> s;
- Symbol sym(1, ID());
- Symbol sym2(2, ID());
- s.set_global(sym2, 3);
- s.push_scope();
- s.set(sym, 5);
-
- EXPECT_TRUE(s.has(sym));
- EXPECT_TRUE(s.has(sym2));
-}
-
-TEST_F(ScopeStackTest, ReturnsScopeBeforeGlobalFirst) {
- ScopeStack<uint32_t> s;
- Symbol sym(1, ID());
- s.set_global(sym, 3);
- s.push_scope();
- s.set(sym, 5);
-
- uint32_t ret;
- EXPECT_TRUE(s.get(sym, &ret));
- EXPECT_EQ(ret, 5u);
+ EXPECT_EQ(s.Get(sym), 0u);
}
} // namespace
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index 1ca479f..71dc541 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -46,6 +46,7 @@
#include "src/transform/simplify.h"
#include "src/transform/vectorize_scalar_matrix_constructors.h"
#include "src/transform/zero_init_workgroup_memory.h"
+#include "src/utils/defer.h"
#include "src/utils/get_or_create.h"
#include "src/writer/append_vector.h"
@@ -468,8 +469,8 @@
continue;
}
- uint32_t var_id;
- if (!scope_stack_.get(var->Declaration()->symbol, &var_id)) {
+ uint32_t var_id = scope_stack_.Get(var->Declaration()->symbol);
+ if (var_id == 0) {
error_ = "unable to find ID for global variable: " +
builder_.Symbols().NameFor(var->Declaration()->symbol);
return false;
@@ -613,7 +614,8 @@
return false;
}
- scope_stack_.push_scope();
+ scope_stack_.Push();
+ TINT_DEFER(scope_stack_.Pop());
auto definition_inst = Instruction{
spv::Op::OpFunction,
@@ -636,7 +638,7 @@
params.push_back(Instruction{spv::Op::OpFunctionParameter,
{Operand::Int(param_type_id), param_op}});
- scope_stack_.set(param->Declaration()->symbol, param_id);
+ scope_stack_.Set(param->Declaration()->symbol, param_id);
}
push_function(Function{definition_inst, result_op(), std::move(params)});
@@ -656,8 +658,6 @@
}
}
- scope_stack_.pop_scope();
-
func_symbol_to_id_[func_ast->symbol] = func_id;
return true;
@@ -706,7 +706,7 @@
error_ = "missing constructor for constant";
return false;
}
- scope_stack_.set(var->symbol, init_id);
+ scope_stack_.Set(var->symbol, init_id);
spirv_id_to_variable_[init_id] = var;
return true;
}
@@ -740,7 +740,7 @@
}
}
- scope_stack_.set(var->symbol, var_id);
+ scope_stack_.Set(var->symbol, var_id);
spirv_id_to_variable_[var_id] = var;
return true;
@@ -803,7 +803,7 @@
{Operand::Int(init_id),
Operand::String(builder_.Symbols().NameFor(var->symbol))});
- scope_stack_.set_global(var->symbol, init_id);
+ scope_stack_.Set(var->symbol, init_id);
spirv_id_to_variable_[init_id] = var;
return true;
}
@@ -898,7 +898,7 @@
}
}
- scope_stack_.set_global(var->symbol, var_id);
+ scope_stack_.Set(var->symbol, var_id);
spirv_id_to_variable_[var_id] = var;
return true;
}
@@ -1173,14 +1173,12 @@
uint32_t Builder::GenerateIdentifierExpression(
const ast::IdentifierExpression* expr) {
- uint32_t val = 0;
- if (scope_stack_.get(expr->symbol, &val)) {
- return val;
+ uint32_t val = scope_stack_.Get(expr->symbol);
+ if (val == 0) {
+ error_ = "unable to find variable with identifier: " +
+ builder_.Symbols().NameFor(expr->symbol);
}
-
- error_ = "unable to find variable with identifier: " +
- builder_.Symbols().NameFor(expr->symbol);
- return 0;
+ return val;
}
uint32_t Builder::GenerateLoadIfNeeded(const sem::Type* type, uint32_t id) {
@@ -2231,10 +2229,9 @@
}
bool Builder::GenerateBlockStatement(const ast::BlockStatement* stmt) {
- scope_stack_.push_scope();
- auto result = GenerateBlockStatementWithoutScoping(stmt);
- scope_stack_.pop_scope();
- return result;
+ scope_stack_.Push();
+ TINT_DEFER(scope_stack_.Pop());
+ return GenerateBlockStatementWithoutScoping(stmt);
}
bool Builder::GenerateBlockStatementWithoutScoping(
@@ -3736,34 +3733,35 @@
// We need variables from the body to be visible in the continuing block, so
// manage scope outside of GenerateBlockStatement.
- scope_stack_.push_scope();
+ {
+ scope_stack_.Push();
+ TINT_DEFER(scope_stack_.Pop());
- if (!GenerateBlockStatementWithoutScoping(stmt->body)) {
- return false;
- }
-
- // We only branch if the last element of the body didn't already branch.
- if (!LastIsTerminator(stmt->body)) {
- if (!push_function_inst(spv::Op::OpBranch,
- {Operand::Int(continue_block_id)})) {
+ if (!GenerateBlockStatementWithoutScoping(stmt->body)) {
return false;
}
- }
- if (!GenerateLabel(continue_block_id)) {
- return false;
- }
- if (stmt->continuing && !stmt->continuing->Empty()) {
- continuing_stack_.emplace_back(stmt->continuing->Last(), loop_header_id,
- merge_block_id);
- if (!GenerateBlockStatementWithoutScoping(stmt->continuing)) {
+ // We only branch if the last element of the body didn't already branch.
+ if (!LastIsTerminator(stmt->body)) {
+ if (!push_function_inst(spv::Op::OpBranch,
+ {Operand::Int(continue_block_id)})) {
+ return false;
+ }
+ }
+
+ if (!GenerateLabel(continue_block_id)) {
return false;
}
- continuing_stack_.pop_back();
+ if (stmt->continuing && !stmt->continuing->Empty()) {
+ continuing_stack_.emplace_back(stmt->continuing->Last(), loop_header_id,
+ merge_block_id);
+ if (!GenerateBlockStatementWithoutScoping(stmt->continuing)) {
+ return false;
+ }
+ continuing_stack_.pop_back();
+ }
}
- scope_stack_.pop_scope();
-
// Generate the backedge.
TINT_ASSERT(Writer, !backedge_stack_.empty());
const Backedge& backedge = backedge_stack_.back();