[tint][ir] Validate value scoping

Errors if a value is used before it has been declared.

Note: this doesn't currently validate value declarations used in a
continuing block are all made before the first 'ir::Continue' statement.

Change-Id: Id2419135ff168e72dc012289839cdd1ca7f460e0
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/186463
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/lang/core/ir/transform/bgra8unorm_polyfill_test.cc b/src/tint/lang/core/ir/transform/bgra8unorm_polyfill_test.cc
index 1e4b39e..7595c2c 100644
--- a/src/tint/lang/core/ir/transform/bgra8unorm_polyfill_test.cc
+++ b/src/tint/lang/core/ir/transform/bgra8unorm_polyfill_test.cc
@@ -525,7 +525,7 @@
     auto* coords = b.FunctionParam("coords", ty.vec2<u32>());
     auto* index = b.FunctionParam("index", ty.u32());
     auto* value = b.FunctionParam("value", ty.vec4<f32>());
-    func->SetParams({value, coords});
+    func->SetParams({value, coords, index, value});
     b.Append(func->Block(), [&] {
         auto* load = b.Load(var->Result(0));
         b.Call(ty.void_(), core::BuiltinFn::kTextureStore, load, coords, index, value);
@@ -537,10 +537,10 @@
   %texture:ptr<handle, texture_storage_2d_array<bgra8unorm, write>, read> = var @binding_point(1, 2)
 }
 
-%foo = func(%value:vec4<f32>, %coords:vec2<u32>):void {
+%foo = func(%value:vec4<f32>, %coords:vec2<u32>, %index:u32%value:vec4<f32>):void {
   $B2: {
-    %5:texture_storage_2d_array<bgra8unorm, write> = load %texture
-    %6:void = textureStore %5, %coords, %index, %value
+    %6:texture_storage_2d_array<bgra8unorm, write> = load %texture
+    %7:void = textureStore %6, %coords, %index, %value
     ret
   }
 }
@@ -550,11 +550,11 @@
   %texture:ptr<handle, texture_storage_2d_array<rgba8unorm, write>, read> = var @binding_point(1, 2)
 }
 
-%foo = func(%value:vec4<f32>, %coords:vec2<u32>):void {
+%foo = func(%value:vec4<f32>, %coords:vec2<u32>, %index:u32%value:vec4<f32>):void {
   $B2: {
-    %5:texture_storage_2d_array<rgba8unorm, write> = load %texture
-    %6:vec4<f32> = swizzle %value, zyxw
-    %7:void = textureStore %5, %coords, %index, %6
+    %6:texture_storage_2d_array<rgba8unorm, write> = load %texture
+    %7:vec4<f32> = swizzle %value, zyxw
+    %8:void = textureStore %6, %coords, %index, %7
     ret
   }
 }
diff --git a/src/tint/lang/core/ir/transform/combine_access_instructions_test.cc b/src/tint/lang/core/ir/transform/combine_access_instructions_test.cc
index 2d07019..404b94b 100644
--- a/src/tint/lang/core/ir/transform/combine_access_instructions_test.cc
+++ b/src/tint/lang/core/ir/transform/combine_access_instructions_test.cc
@@ -677,6 +677,7 @@
     auto* func = b.Function("foo", ty.f32());
     auto* indices = b.FunctionParam("indices", ty.array<u32, 4>());
     auto* values = b.FunctionParam("values", ty.array<f32, 4>());
+    func->SetParams({indices, values});
     b.Append(func->Block(), [&] {
         auto* access_index = b.Access(ty.u32(), indices, 1_u);
         auto* access_value = b.Access(ty.f32(), values, access_index);
@@ -684,11 +685,11 @@
     });
 
     auto* src = R"(
-%foo = func():f32 {
+%foo = func(%indices:array<u32, 4>, %values:array<f32, 4>):f32 {
   $B1: {
-    %2:u32 = access %indices, 1u
-    %4:f32 = access %values, %2
-    ret %4
+    %4:u32 = access %indices, 1u
+    %5:f32 = access %values, %4
+    ret %5
   }
 }
 )";
diff --git a/src/tint/lang/core/ir/validator.cc b/src/tint/lang/core/ir/validator.cc
index 21d8213..5c12eff 100644
--- a/src/tint/lang/core/ir/validator.cc
+++ b/src/tint/lang/core/ir/validator.cc
@@ -201,11 +201,15 @@
     /// @param func the function
     diag::Diagnostic& AddNote(const Function* func);
 
-    /// Adds a note to @p inst for operand @p idx and highlights the operand in the
-    /// disassembly
+    /// Adds a note to @p inst for operand @p idx and highlights the operand in the disassembly
     /// @param inst the instruction
     /// @param idx the operand index
-    diag::Diagnostic& AddNote(const Instruction* inst, size_t idx);
+    diag::Diagnostic& AddOperandNote(const Instruction* inst, size_t idx);
+
+    /// Adds a note to @p inst for result @p idx and highlights the result in the disassembly
+    /// @param inst the instruction
+    /// @param idx the result index
+    diag::Diagnostic& AddResultNote(const Instruction* inst, size_t idx);
 
     /// Adds a note to @p blk and highlights the block in the disassembly
     /// @param blk the block
@@ -215,6 +219,11 @@
     /// @param src the source lines to highlight
     diag::Diagnostic& AddNote(Source src = {});
 
+    /// Adds a note to the diagnostics highlighting where the value was declared, if it has a source
+    /// location.
+    /// @param value the value
+    void AddDeclarationNote(const Value* value);
+
     /// @param v the value to get the name for
     /// @returns the name for the given value
     StyledText NameOf(const Value* v);
@@ -352,12 +361,28 @@
     void QueueInstructions(const Instruction* inst);
 
     /// Begins validation of the block @p blk, and its instructions.
+    /// BeginBlock() pushes a new scope for values.
     /// Must be paired with a call to EndBlock().
     void BeginBlock(const Block* blk);
 
-    /// Ends validation of the block opened with BeginBlock().
+    /// Ends validation of the block opened with BeginBlock() and closes the block's scope for
+    /// values.
     void EndBlock();
 
+    /// ScopeStack holds a stack of values that are currently in scope
+    struct ScopeStack {
+        void Push() { stack_.Push({}); }
+        void Pop() { stack_.Pop(); }
+        void Add(const Value* value) { stack_.Back().Add(value); }
+        bool Contains(const Value* value) {
+            return stack_.Any([&](auto& v) { return v.Contains(value); });
+        }
+        bool IsEmpty() const { return stack_.IsEmpty(); }
+
+      private:
+        Vector<Hashset<const Value*, 8>, 4> stack_;
+    };
+
     const Module& mod_;
     Capabilities capabilities_;
     std::optional<ir::Disassembly> disassembly_;  // Use Disassembly()
@@ -366,6 +391,7 @@
     Hashset<const Instruction*, 4> visited_instructions_;
     Vector<const ControlInstruction*, 8> control_stack_;
     Vector<const Block*, 8> block_stack_;
+    ScopeStack scope_stack_;
     Vector<std::function<void()>, 16> tasks_;
 };
 
@@ -382,7 +408,10 @@
 }
 
 Result<SuccessType> Validator::Run() {
+    scope_stack_.Push();
     TINT_DEFER({
+        scope_stack_.Pop();
+        TINT_ASSERT(scope_stack_.IsEmpty());
         TINT_ASSERT(tasks_.IsEmpty());
         TINT_ASSERT(control_stack_.IsEmpty());
         TINT_ASSERT(block_stack_.IsEmpty());
@@ -394,6 +423,7 @@
             AddError(func) << "function " << NameOf(func.Get())
                            << " added to module multiple times";
         }
+        scope_stack_.Add(func);
     }
 
     for (auto& func : mod_.functions) {
@@ -479,12 +509,18 @@
     return AddNote(src);
 }
 
-diag::Diagnostic& Validator::AddNote(const Instruction* inst, size_t idx) {
+diag::Diagnostic& Validator::AddOperandNote(const Instruction* inst, size_t idx) {
     auto src =
         Disassembly().OperandSource(Disassembly::IndexedValue{inst, static_cast<uint32_t>(idx)});
     return AddNote(src);
 }
 
+diag::Diagnostic& Validator::AddResultNote(const Instruction* inst, size_t idx) {
+    auto src =
+        Disassembly().ResultSource(Disassembly::IndexedValue{inst, static_cast<uint32_t>(idx)});
+    return AddNote(src);
+}
+
 diag::Diagnostic& Validator::AddNote(const Block* blk) {
     auto src = Disassembly().BlockSource(blk);
     return AddNote(src);
@@ -508,6 +544,35 @@
     return diag;
 }
 
+void Validator::AddDeclarationNote(const Value* value) {
+    tint::Switch(
+        value,  //
+        [&](const InstructionResult* res) {
+            if (auto* inst = res->Instruction()) {
+                auto results = inst->Results();
+                for (size_t i = 0; i < results.Length(); i++) {
+                    if (results[i] == value) {
+                        AddResultNote(res->Instruction(), i) << NameOf(value) << " declared here";
+                        return;
+                    }
+                }
+            }
+        },
+        [&](const FunctionParam* param) {
+            auto src = Disassembly().FunctionParamSource(param);
+            if (src.file) {
+                AddNote(src) << NameOf(value) << " declared here";
+            }
+        },
+        [&](const BlockParam* param) {
+            auto src = Disassembly().BlockParamSource(param);
+            if (src.file) {
+                AddNote(src) << NameOf(value) << " declared here";
+            }
+        },
+        [&](const Function* fn) { AddNote(fn) << NameOf(value) << " declared here"; });
+}
+
 StyledText Validator::NameOf(const Value* value) {
     return Disassembly().NameOf(value);
 }
@@ -546,6 +611,10 @@
 }
 
 void Validator::CheckFunction(const Function* func) {
+    // Scope holds the parameters and block
+    scope_stack_.Push();
+    TINT_DEFER(scope_stack_.Pop());
+
     for (auto* param : func->Params()) {
         if (!param->Alive()) {
             AddError(param) << "destroyed parameter found in function parameter list";
@@ -564,6 +633,8 @@
         if (HoldsType<type::Reference>(param->Type())) {
             AddError(param) << "references are not permitted as parameter types";
         }
+
+        scope_stack_.Add(param);
     }
     if (HoldsType<type::Reference>(func->ReturnType())) {
         AddError(func) << "references are not permitted as return types";
@@ -585,6 +656,7 @@
 }
 
 void Validator::BeginBlock(const Block* blk) {
+    scope_stack_.Push();
     block_stack_.Push(blk);
 
     if (auto* mb = blk->As<MultiInBlock>()) {
@@ -601,6 +673,7 @@
                 AddNote(param->Block()) << "parent block declared here";
                 return;
             }
+            scope_stack_.Add(param);
         }
     }
 
@@ -628,6 +701,7 @@
 }
 
 void Validator::EndBlock() {
+    scope_stack_.Pop();
     block_stack_.Pop();
 }
 
@@ -682,6 +756,9 @@
             AddError(inst, i) << "operand missing usage";
         } else if (auto fn = op->As<Function>(); fn && !all_functions_.Contains(fn)) {
             AddError(inst, i) << NameOf(op) << " is not part of the module";
+        } else if (!op->Is<Constant>() && !scope_stack_.Contains(op)) {
+            AddError(inst, i) << NameOf(op) << " is not in scope";
+            AddDeclarationNote(op);
         }
 
         if (!capabilities_.Contains(Capability::kAllowRefTypes)) {
@@ -709,6 +786,10 @@
         [&](const Unary* u) { CheckUnary(u); },                            //
         [&](const Var* var) { CheckVar(var); },                            //
         [&](const Default) { AddError(inst) << "missing validation"; });
+
+    for (auto* result : results) {
+        scope_stack_.Add(result);
+    }
 }
 
 void Validator::CheckVar(const Var* var) {
@@ -820,7 +901,7 @@
             return AddError(a, i + Access::kIndicesOperandOffset);
         };
         auto note = [&]() -> diag::Diagnostic& {
-            return AddNote(a, i + Access::kIndicesOperandOffset);
+            return AddOperandNote(a, i + Access::kIndicesOperandOffset);
         };
 
         auto* index = a->Indices()[i];
diff --git a/src/tint/lang/core/ir/validator_test.cc b/src/tint/lang/core/ir/validator_test.cc
index 7c03112..864e24d 100644
--- a/src/tint/lang/core/ir/validator_test.cc
+++ b/src/tint/lang/core/ir/validator_test.cc
@@ -1392,17 +1392,16 @@
 TEST_F(IR_ValidatorTest, Var_Init_WrongType) {
     auto* f = b.Function("my_func", ty.void_());
 
-    auto sb = b.Append(f->Block());
-    auto* v = sb.Var(ty.ptr<function, f32>());
-    sb.Return(f);
-
-    auto* result = sb.InstructionResult(ty.i32());
-    v->SetInitializer(result);
+    b.Append(f->Block(), [&] {
+        auto* v = b.Var<function, f32>();
+        v->SetInitializer(b.Constant(1_i));
+        b.Return(f);
+    });
 
     auto res = ir::Validate(mod);
     ASSERT_NE(res, Success);
     EXPECT_EQ(res.Failure().reason.Str(), R"(:3:41 error: var: initializer has incorrect type
-    %2:ptr<function, f32, read_write> = var, %3
+    %2:ptr<function, f32, read_write> = var, 1i
                                         ^^^
 
 :2:3 note: in block
@@ -1412,7 +1411,7 @@
 note: # Disassembly
 %my_func = func():void {
   $B1: {
-    %2:ptr<function, f32, read_write> = var, %3
+    %2:ptr<function, f32, read_write> = var, 1i
     ret
   }
 }
@@ -3723,6 +3722,42 @@
 )");
 }
 
+TEST_F(IR_ValidatorTest, Scoping_UseBeforeDecl) {
+    auto* f = b.Function("my_func", ty.void_());
+
+    auto* y = b.Add<i32>(2_i, 3_i);
+    auto* x = b.Add<i32>(y, 1_i);
+
+    f->Block()->Append(x);
+    f->Block()->Append(y);
+    f->Block()->Append(b.Return(f));
+
+    auto res = ir::Validate(mod);
+    ASSERT_NE(res, Success);
+    EXPECT_EQ(res.Failure().reason.Str(),
+              R"(:3:18 error: binary: %3 is not in scope
+    %2:i32 = add %3, 1i
+                 ^^
+
+:2:3 note: in block
+  $B1: {
+  ^^^
+
+:4:5 note: %3 declared here
+    %3:i32 = add 2i, 3i
+    ^^^^^^
+
+note: # Disassembly
+%my_func = func():void {
+  $B1: {
+    %2:i32 = add %3, 1i
+    %3:i32 = add 2i, 3i
+    ret
+  }
+}
+)");
+}
+
 template <typename T>
 static const type::Type* TypeBuilder(type::Manager& m) {
     return m.Get<T>();
diff --git a/src/tint/lang/spirv/reader/lower/vector_element_pointer_test.cc b/src/tint/lang/spirv/reader/lower/vector_element_pointer_test.cc
index ade391c..bd30d05 100644
--- a/src/tint/lang/spirv/reader/lower/vector_element_pointer_test.cc
+++ b/src/tint/lang/spirv/reader/lower/vector_element_pointer_test.cc
@@ -42,16 +42,17 @@
 TEST_F(SpirvReader_VectorElementPointerTest, NonPointerAccess) {
     auto* vec = b.FunctionParam("vec", ty.vec4<u32>());
     auto* foo = b.Function("foo", ty.u32());
+    foo->SetParams({vec});
     b.Append(foo->Block(), [&] {
         auto* access = b.Access<u32>(vec, 2_u);
         b.Return(foo, access);
     });
 
     auto* src = R"(
-%foo = func():u32 {
+%foo = func(%vec:vec4<u32>):u32 {
   $B1: {
-    %2:u32 = access %vec, 2u
-    ret %2
+    %3:u32 = access %vec, 2u
+    ret %3
   }
 }
 )";
diff --git a/src/tint/lang/spirv/writer/raise/merge_return_test.cc b/src/tint/lang/spirv/writer/raise/merge_return_test.cc
index 75998f5..40d44b0 100644
--- a/src/tint/lang/spirv/writer/raise/merge_return_test.cc
+++ b/src/tint/lang/spirv/writer/raise/merge_return_test.cc
@@ -67,7 +67,7 @@
     auto* in = b.FunctionParam(ty.i32());
     auto* cond = b.FunctionParam(ty.bool_());
     auto* func = b.Function("foo", ty.i32());
-    func->SetParams({in});
+    func->SetParams({in, cond});
 
     b.Append(func->Block(), [&] {
         auto* ifelse = b.If(cond);
@@ -78,9 +78,9 @@
         b.Return(func, ifelse->Result(0));
     });
     auto* src = R"(
-%foo = func(%2:i32):i32 {
+%foo = func(%2:i32, %3:bool):i32 {
   $B1: {
-    %3:i32 = if %4 [t: $B2, f: $B3] {  # if_1
+    %4:i32 = if %3 [t: $B2, f: $B3] {  # if_1
       $B2: {  # true
         %5:i32 = add %2, 1i
         exit_if %5  # if_1
@@ -90,7 +90,7 @@
         exit_if %6  # if_1
       }
     }
-    ret %3
+    ret %4
   }
 }
 )";
@@ -107,7 +107,7 @@
     auto* in = b.FunctionParam(ty.i32());
     auto* cond = b.FunctionParam(ty.bool_());
     auto* func = b.Function("foo", ty.i32());
-    func->SetParams({in});
+    func->SetParams({in, cond});
 
     b.Append(func->Block(), [&] {
         auto* swtch = b.Switch(in);
@@ -125,7 +125,7 @@
     });
 
     auto* src = R"(
-%foo = func(%2:i32):i32 {
+%foo = func(%2:i32, %3:bool):i32 {
   $B1: {
     switch %2 [c: (default, $B2)] {  # switch_1
       $B2: {  # case
@@ -137,7 +137,7 @@
         exit_loop  # loop_1
       }
     }
-    %3:i32 = if %4 [t: $B4, f: $B5] {  # if_1
+    %4:i32 = if %3 [t: $B4, f: $B5] {  # if_1
       $B4: {  # true
         %5:i32 = add %2, 1i
         exit_if %5  # if_1
@@ -147,7 +147,7 @@
         exit_if %6  # if_1
       }
     }
-    ret %3
+    ret %4
   }
 }
 )";