Clean up the ScopeStack interface

There's no need for the ScopeStack to include 'global' information. This
is easily obtainable from the element type.
Replace the get-by-reference, with a simpler return value.

Change-Id: Ic6f4c0f656a2019417d68ffb3fe85ba8343ad15e
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/68403
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index fb1b25d..6836f12 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -656,7 +656,7 @@
   if (!info) {
     return false;
   }
-  variable_stack_.set_global(var->symbol, info);
+  variable_stack_.Set(var->symbol, info);
 
   if (!var->is_const && info->storage_class == ast::StorageClass::kNone) {
     AddError("global variables must have a storage class", var->source);
@@ -1769,7 +1769,7 @@
 
   TINT_SCOPED_ASSIGNMENT(current_function_, info);
 
-  variable_stack_.push_scope();
+  variable_stack_.Push();
   uint32_t parameter_index = 0;
   std::unordered_map<Symbol, Source> parameter_names;
   for (auto* param : func->params) {
@@ -1798,7 +1798,7 @@
       return false;
     }
 
-    variable_stack_.set(param->symbol, param_info);
+    variable_stack_.Set(param->symbol, param_info);
     info->parameters.emplace_back(param_info);
 
     if (!ApplyStorageClassUsageToType(param->declared_storage_class,
@@ -1887,7 +1887,7 @@
       return false;
     }
   }
-  variable_stack_.pop_scope();
+  variable_stack_.Pop();
 
   for (auto* deco : func->decorations) {
     Mark(deco);
@@ -1952,9 +1952,8 @@
 
       if (auto* ident = expr->As<ast::IdentifierExpression>()) {
         // We have an identifier of a module-scope constant.
-        VariableInfo* var = nullptr;
-        if (!variable_stack_.get(ident->symbol, &var) ||
-            !(var->declaration->is_const)) {
+        VariableInfo* var = variable_stack_.Get(ident->symbol);
+        if (!var || !(var->declaration->is_const)) {
           AddError(kErrBadType, expr->source);
           return false;
         }
@@ -2635,8 +2634,8 @@
     if (param->declaration->type->Is<ast::Pointer>()) {
       auto is_valid = false;
       if (auto* ident_expr = arg_expr->As<ast::IdentifierExpression>()) {
-        VariableInfo* var;
-        if (!variable_stack_.get(ident_expr->symbol, &var)) {
+        VariableInfo* var = variable_stack_.Get(ident_expr->symbol);
+        if (!var) {
           TINT_ICE(Resolver, diagnostics_) << "failed to resolve identifier";
           return false;
         }
@@ -2647,8 +2646,8 @@
         if (unary->op == ast::UnaryOp::kAddressOf) {
           if (auto* ident_unary =
                   unary->expr->As<ast::IdentifierExpression>()) {
-            VariableInfo* var;
-            if (!variable_stack_.get(ident_unary->symbol, &var)) {
+            VariableInfo* var = variable_stack_.Get(ident_unary->symbol);
+            if (!var) {
               TINT_ICE(Resolver, diagnostics_)
                   << "failed to resolve identifier";
               return false;
@@ -2987,8 +2986,7 @@
 
 bool Resolver::Identifier(const ast::IdentifierExpression* expr) {
   auto symbol = expr->symbol;
-  VariableInfo* var;
-  if (variable_stack_.get(symbol, &var)) {
+  if (VariableInfo* var = variable_stack_.Get(symbol)) {
     SetExprInfo(expr, var->type, var->type_name);
 
     var->users.push_back(expr);
@@ -3485,7 +3483,7 @@
     }
   }
 
-  variable_stack_.set(var->symbol, info);
+  variable_stack_.Set(var->symbol, info);
   if (current_block_) {  // Not all statements are inside a block
     current_block_->AddDecl(var);
   }
@@ -3909,9 +3907,8 @@
 
     if (auto* ident = count_expr->As<ast::IdentifierExpression>()) {
       // Make sure the identifier is a non-overridable module-scope constant.
-      VariableInfo* var = nullptr;
-      bool is_global = false;
-      if (!variable_stack_.get(ident->symbol, &var, &is_global) || !is_global ||
+      VariableInfo* var = variable_stack_.Get(ident->symbol);
+      if (!var || var->kind != VariableKind::kGlobal ||
           !var->declaration->is_const) {
         AddError("array size identifier must be a module-scope constant",
                  size_source);
@@ -4489,8 +4486,7 @@
   auto const* lhs_type = TypeOf(a->lhs);
 
   if (auto* ident = a->lhs->As<ast::IdentifierExpression>()) {
-    VariableInfo* var;
-    if (variable_stack_.get(ident->symbol, &var)) {
+    if (VariableInfo* var = variable_stack_.Get(ident->symbol)) {
       if (var->kind == VariableKind::kParameter) {
         AddError("cannot assign to function parameter", a->lhs->source);
         AddNote("'" + builder_->Symbols().NameFor(ident->symbol) +
@@ -4542,10 +4538,8 @@
                                              const Source& source,
                                              bool check_global_scope_only) {
   if (check_global_scope_only) {
-    bool is_global = false;
-    VariableInfo* var;
-    if (variable_stack_.get(sym, &var, &is_global)) {
-      if (is_global) {
+    if (VariableInfo* var = variable_stack_.Get(sym)) {
+      if (var->kind == VariableKind::kGlobal) {
         AddError("redefinition of '" + builder_->Symbols().NameFor(sym) + "'",
                  source);
         AddNote("previous definition is here", var->declaration->source);
@@ -4560,8 +4554,7 @@
       return false;
     }
   } else {
-    VariableInfo* var;
-    if (variable_stack_.get(sym, &var)) {
+    if (VariableInfo* var = variable_stack_.Get(sym)) {
       AddError("redefinition of '" + builder_->Symbols().NameFor(sym) + "'",
                source);
       AddNote("previous definition is here", var->declaration->source);
@@ -4635,10 +4628,10 @@
   current_statement_ = stmt;
   current_compound_statement_ = stmt;
   current_block_ = stmt->As<sem::BlockStatement>();
-  variable_stack_.push_scope();
+  variable_stack_.Push();
 
   TINT_DEFER({
-    TINT_DEFER(variable_stack_.pop_scope());
+    TINT_DEFER(variable_stack_.Pop());
     current_block_ = prev_current_block;
     current_compound_statement_ = prev_current_compound_statement;
     current_statement_ = prev_current_statement;
diff --git a/src/scope_stack.h b/src/scope_stack.h
index a1619c4..f8ddadc 100644
--- a/src/scope_stack.h
+++ b/src/scope_stack.h
@@ -36,60 +36,33 @@
   ~ScopeStack() = default;
 
   /// Push a new scope on to the stack
-  void push_scope() { stack_.push_back({}); }
+  void Push() { stack_.push_back({}); }
 
   /// Pop the scope off the top of the stack
-  void pop_scope() {
+  void Pop() {
     if (stack_.size() > 1) {
       stack_.pop_back();
     }
   }
 
-  /// Set a global variable in the stack
+  /// Assigns the value into the top most scope of the stack
   /// @param symbol the symbol of the variable
   /// @param val the value
-  void set_global(const Symbol& symbol, T val) { stack_[0][symbol] = val; }
+  void Set(const Symbol& symbol, T val) { stack_.back()[symbol] = val; }
 
-  /// Sets variable into the top most scope of the stack
-  /// @param symbol the symbol of the variable
-  /// @param val the value
-  void set(const Symbol& symbol, T val) { stack_.back()[symbol] = val; }
-
-  /// Checks for the given `symbol` in the stack
+  /// Retrieves a value from the stack
   /// @param symbol the symbol to look for
-  /// @returns true if the stack contains `symbol`
-  bool has(const Symbol& symbol) const { return get(symbol, nullptr); }
-
-  /// Retrieves a given variable from the stack
-  /// @param symbol the symbol to look for
-  /// @param ret where to place the value
-  /// @returns true if the symbol was successfully found, false otherwise
-  bool get(const Symbol& symbol, T* ret) const {
-    return get(symbol, ret, nullptr);
-  }
-
-  /// Retrieves a given variable from the stack
-  /// @param symbol the symbol to look for
-  /// @param ret where to place the value
-  /// @param is_global set true if the symbol references a global variable
-  /// otherwise unchanged
-  /// @returns true if the symbol was successfully found, false otherwise
-  bool get(const Symbol& symbol, T* ret, bool* is_global) const {
+  /// @returns the variable, or the zero initializer if the value was not found
+  T Get(const Symbol& symbol) const {
     for (auto iter = stack_.rbegin(); iter != stack_.rend(); ++iter) {
       auto& map = *iter;
       auto val = map.find(symbol);
-
       if (val != map.end()) {
-        if (ret) {
-          *ret = val->second;
-        }
-        if (is_global && iter == stack_.rend() - 1) {
-          *is_global = true;
-        }
-        return true;
+        return val->second;
       }
     }
-    return false;
+
+    return T{};
   }
 
  private:
diff --git a/src/scope_stack_test.cc b/src/scope_stack_test.cc
index 0006347..f96f0d1 100644
--- a/src/scope_stack_test.cc
+++ b/src/scope_stack_test.cc
@@ -21,78 +21,32 @@
 
 class ScopeStackTest : public ProgramBuilder, public testing::Test {};
 
-TEST_F(ScopeStackTest, Global) {
+TEST_F(ScopeStackTest, Get) {
   ScopeStack<uint32_t> s;
-  Symbol sym(1, ID());
-  s.set_global(sym, 5);
+  Symbol a(1, ID());
+  Symbol b(3, ID());
+  s.Push();
+  s.Set(a, 5u);
+  s.Set(b, 10u);
 
-  uint32_t val = 0;
-  EXPECT_TRUE(s.get(sym, &val));
-  EXPECT_EQ(val, 5u);
-}
+  EXPECT_EQ(s.Get(a), 5u);
+  EXPECT_EQ(s.Get(b), 10u);
 
-TEST_F(ScopeStackTest, Global_SetWithPointer) {
-  auto* v = Var("my_var", ty.f32(), ast::StorageClass::kNone);
-  ScopeStack<const ast::Variable*> s;
-  s.set_global(v->symbol, v);
+  s.Push();
 
-  const ast::Variable* v2 = nullptr;
-  EXPECT_TRUE(s.get(v->symbol, &v2));
-  EXPECT_EQ(v2->symbol, v->symbol);
-}
+  s.Set(a, 15u);
+  EXPECT_EQ(s.Get(a), 15u);
+  EXPECT_EQ(s.Get(b), 10u);
 
-TEST_F(ScopeStackTest, Global_CanNotPop) {
-  ScopeStack<uint32_t> s;
-  Symbol sym(1, ID());
-  s.set_global(sym, 5);
-  s.pop_scope();
-
-  uint32_t val = 0;
-  EXPECT_TRUE(s.get(sym, &val));
-  EXPECT_EQ(val, 5u);
-}
-
-TEST_F(ScopeStackTest, Scope) {
-  ScopeStack<uint32_t> s;
-  Symbol sym(1, ID());
-  s.push_scope();
-  s.set(sym, 5);
-
-  uint32_t val = 0;
-  EXPECT_TRUE(s.get(sym, &val));
-  EXPECT_EQ(val, 5u);
+  s.Pop();
+  EXPECT_EQ(s.Get(a), 5u);
+  EXPECT_EQ(s.Get(b), 10u);
 }
 
 TEST_F(ScopeStackTest, Get_MissingSymbol) {
   ScopeStack<uint32_t> s;
   Symbol sym(1, ID());
-  uint32_t ret = 0;
-  EXPECT_FALSE(s.get(sym, &ret));
-  EXPECT_EQ(ret, 0u);
-}
-
-TEST_F(ScopeStackTest, Has) {
-  ScopeStack<uint32_t> s;
-  Symbol sym(1, ID());
-  Symbol sym2(2, ID());
-  s.set_global(sym2, 3);
-  s.push_scope();
-  s.set(sym, 5);
-
-  EXPECT_TRUE(s.has(sym));
-  EXPECT_TRUE(s.has(sym2));
-}
-
-TEST_F(ScopeStackTest, ReturnsScopeBeforeGlobalFirst) {
-  ScopeStack<uint32_t> s;
-  Symbol sym(1, ID());
-  s.set_global(sym, 3);
-  s.push_scope();
-  s.set(sym, 5);
-
-  uint32_t ret;
-  EXPECT_TRUE(s.get(sym, &ret));
-  EXPECT_EQ(ret, 5u);
+  EXPECT_EQ(s.Get(sym), 0u);
 }
 
 }  // namespace
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index 1ca479f..71dc541 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -46,6 +46,7 @@
 #include "src/transform/simplify.h"
 #include "src/transform/vectorize_scalar_matrix_constructors.h"
 #include "src/transform/zero_init_workgroup_memory.h"
+#include "src/utils/defer.h"
 #include "src/utils/get_or_create.h"
 #include "src/writer/append_vector.h"
 
@@ -468,8 +469,8 @@
       continue;
     }
 
-    uint32_t var_id;
-    if (!scope_stack_.get(var->Declaration()->symbol, &var_id)) {
+    uint32_t var_id = scope_stack_.Get(var->Declaration()->symbol);
+    if (var_id == 0) {
       error_ = "unable to find ID for global variable: " +
                builder_.Symbols().NameFor(var->Declaration()->symbol);
       return false;
@@ -613,7 +614,8 @@
     return false;
   }
 
-  scope_stack_.push_scope();
+  scope_stack_.Push();
+  TINT_DEFER(scope_stack_.Pop());
 
   auto definition_inst = Instruction{
       spv::Op::OpFunction,
@@ -636,7 +638,7 @@
     params.push_back(Instruction{spv::Op::OpFunctionParameter,
                                  {Operand::Int(param_type_id), param_op}});
 
-    scope_stack_.set(param->Declaration()->symbol, param_id);
+    scope_stack_.Set(param->Declaration()->symbol, param_id);
   }
 
   push_function(Function{definition_inst, result_op(), std::move(params)});
@@ -656,8 +658,6 @@
     }
   }
 
-  scope_stack_.pop_scope();
-
   func_symbol_to_id_[func_ast->symbol] = func_id;
 
   return true;
@@ -706,7 +706,7 @@
       error_ = "missing constructor for constant";
       return false;
     }
-    scope_stack_.set(var->symbol, init_id);
+    scope_stack_.Set(var->symbol, init_id);
     spirv_id_to_variable_[init_id] = var;
     return true;
   }
@@ -740,7 +740,7 @@
     }
   }
 
-  scope_stack_.set(var->symbol, var_id);
+  scope_stack_.Set(var->symbol, var_id);
   spirv_id_to_variable_[var_id] = var;
 
   return true;
@@ -803,7 +803,7 @@
                {Operand::Int(init_id),
                 Operand::String(builder_.Symbols().NameFor(var->symbol))});
 
-    scope_stack_.set_global(var->symbol, init_id);
+    scope_stack_.Set(var->symbol, init_id);
     spirv_id_to_variable_[init_id] = var;
     return true;
   }
@@ -898,7 +898,7 @@
     }
   }
 
-  scope_stack_.set_global(var->symbol, var_id);
+  scope_stack_.Set(var->symbol, var_id);
   spirv_id_to_variable_[var_id] = var;
   return true;
 }
@@ -1173,14 +1173,12 @@
 
 uint32_t Builder::GenerateIdentifierExpression(
     const ast::IdentifierExpression* expr) {
-  uint32_t val = 0;
-  if (scope_stack_.get(expr->symbol, &val)) {
-    return val;
+  uint32_t val = scope_stack_.Get(expr->symbol);
+  if (val == 0) {
+    error_ = "unable to find variable with identifier: " +
+             builder_.Symbols().NameFor(expr->symbol);
   }
-
-  error_ = "unable to find variable with identifier: " +
-           builder_.Symbols().NameFor(expr->symbol);
-  return 0;
+  return val;
 }
 
 uint32_t Builder::GenerateLoadIfNeeded(const sem::Type* type, uint32_t id) {
@@ -2231,10 +2229,9 @@
 }
 
 bool Builder::GenerateBlockStatement(const ast::BlockStatement* stmt) {
-  scope_stack_.push_scope();
-  auto result = GenerateBlockStatementWithoutScoping(stmt);
-  scope_stack_.pop_scope();
-  return result;
+  scope_stack_.Push();
+  TINT_DEFER(scope_stack_.Pop());
+  return GenerateBlockStatementWithoutScoping(stmt);
 }
 
 bool Builder::GenerateBlockStatementWithoutScoping(
@@ -3736,34 +3733,35 @@
 
   // We need variables from the body to be visible in the continuing block, so
   // manage scope outside of GenerateBlockStatement.
-  scope_stack_.push_scope();
+  {
+    scope_stack_.Push();
+    TINT_DEFER(scope_stack_.Pop());
 
-  if (!GenerateBlockStatementWithoutScoping(stmt->body)) {
-    return false;
-  }
-
-  // We only branch if the last element of the body didn't already branch.
-  if (!LastIsTerminator(stmt->body)) {
-    if (!push_function_inst(spv::Op::OpBranch,
-                            {Operand::Int(continue_block_id)})) {
+    if (!GenerateBlockStatementWithoutScoping(stmt->body)) {
       return false;
     }
-  }
 
-  if (!GenerateLabel(continue_block_id)) {
-    return false;
-  }
-  if (stmt->continuing && !stmt->continuing->Empty()) {
-    continuing_stack_.emplace_back(stmt->continuing->Last(), loop_header_id,
-                                   merge_block_id);
-    if (!GenerateBlockStatementWithoutScoping(stmt->continuing)) {
+    // We only branch if the last element of the body didn't already branch.
+    if (!LastIsTerminator(stmt->body)) {
+      if (!push_function_inst(spv::Op::OpBranch,
+                              {Operand::Int(continue_block_id)})) {
+        return false;
+      }
+    }
+
+    if (!GenerateLabel(continue_block_id)) {
       return false;
     }
-    continuing_stack_.pop_back();
+    if (stmt->continuing && !stmt->continuing->Empty()) {
+      continuing_stack_.emplace_back(stmt->continuing->Last(), loop_header_id,
+                                     merge_block_id);
+      if (!GenerateBlockStatementWithoutScoping(stmt->continuing)) {
+        return false;
+      }
+      continuing_stack_.pop_back();
+    }
   }
 
-  scope_stack_.pop_scope();
-
   // Generate the backedge.
   TINT_ASSERT(Writer, !backedge_stack_.empty());
   const Backedge& backedge = backedge_stack_.back();