Resolver: Check that every AST node is reached once

AST nodes must not be shared. Diamonds in the AST will cause all sorts
of exciting, non trivial bugs.

All AST nodes must be reached by the Resolver. There are two common
reasons why they may not be:
(a) They were constructed and not attached to the AST. Several
    transforms scan the full list of constructed AST nodes to find nodes
    of a given type. Having detached nodes will likely cause bugs in
    these transforms. Detached nodes is also just a waste of memory.
(b) They are attached to the AST, but the resolver did not traverse
    them. Having the resolver skip over parts of the AST will fail to
    catch validation issues, and will leave semantic gaps, likely
    breaking downstream logic.

Bug: tint:469
Change-Id: I143b84fd830699f874d2936146f0e93197db610c
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/47778
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@chromium.org>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/resolver/decoration_validation_test.cc b/src/resolver/decoration_validation_test.cc
index 6b74e1e..ea372e5 100644
--- a/src/resolver/decoration_validation_test.cc
+++ b/src/resolver/decoration_validation_test.cc
@@ -420,6 +420,20 @@
         Params{ty_mat3x3<f32>, (default_mat3x3.align - 1) * 7, false},
         Params{ty_mat4x4<f32>, (default_mat4x4.align - 1) * 7, false}));
 
+TEST_F(ArrayStrideTest, MultipleDecorations) {
+  auto* arr = create<type::Array>(ty.i32(), 4,
+                                  ast::DecorationList{
+                                      create<ast::StrideDecoration>(4),
+                                      create<ast::StrideDecoration>(4),
+                                  });
+
+  Global(Source{{12, 34}}, "myarray", arr, ast::StorageClass::kInput);
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(),
+            "12:34 error: array must have at most one [[stride]] decoration");
+}
+
 }  // namespace
 }  // namespace ArrayStrideTests
 }  // namespace resolver
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 1bbccd6..556b555 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -17,6 +17,7 @@
 #include <algorithm>
 #include <utility>
 
+#include "src/ast/access_decoration.h"
 #include "src/ast/assignment_statement.h"
 #include "src/ast/bitcast_expression.h"
 #include "src/ast/break_statement.h"
@@ -178,20 +179,22 @@
 }
 
 bool Resolver::ResolveInternal() {
+  Mark(&builder_->AST());
+
   // Process everything else in the order they appear in the module. This is
   // necessary for validation of use-before-declaration.
   for (auto* decl : builder_->AST().GlobalDeclarations()) {
-    if (decl->Is<type::Type>()) {
-      if (auto* str = decl->As<type::Struct>()) {
-        if (!Structure(str)) {
-          return false;
-        }
+    if (auto* ty = decl->As<type::Type>()) {
+      if (!Type(ty)) {
+        return false;
       }
     } else if (auto* func = decl->As<ast::Function>()) {
+      Mark(func);
       if (!Function(func)) {
         return false;
       }
     } else if (auto* var = decl->As<ast::Variable>()) {
+      Mark(var);
       if (!GlobalVariable(var)) {
         return false;
       }
@@ -202,6 +205,34 @@
     }
   }
 
+  for (auto* node : builder_->ASTNodes().Objects()) {
+    if (marked_.count(node) == 0) {
+      if (node->Is<ast::AccessDecoration>()) {
+        // These are generated by the WGSL parser, used to build
+        // type::AccessControls and then leaked.
+        // Once we introduce AST types, this should be fixed.
+        continue;
+      }
+      TINT_ICE(diagnostics_) << "AST node '" << node->TypeInfo().name
+                             << "' was not reached by the resolver\n"
+                             << "At: " << node->source();
+    }
+  }
+
+  return true;
+}
+
+bool Resolver::Type(type::Type* ty) {
+  ty = ty->UnwrapAliasIfNeeded();
+  if (auto* str = ty->As<type::Struct>()) {
+    if (!Structure(str)) {
+      return false;
+    }
+  } else if (auto* arr = ty->As<type::Array>()) {
+    if (!Array(arr, Source{})) {
+      return false;
+    }
+  }
   return true;
 }
 
@@ -275,6 +306,7 @@
   }
 
   for (auto* deco : var->decorations()) {
+    Mark(deco);
     if (!(deco->Is<ast::BindingDecoration>() ||
           deco->Is<ast::BuiltinDecoration>() ||
           deco->Is<ast::ConstantIdDecoration>() ||
@@ -287,6 +319,7 @@
   }
 
   if (var->has_constructor()) {
+    Mark(var->constructor());
     if (!Expression(var->constructor())) {
       return false;
     }
@@ -570,10 +603,17 @@
 
   variable_stack_.push_scope();
   for (auto* param : func->params()) {
+    Mark(param);
     auto* param_info = Variable(param);
     if (!param_info) {
       return false;
     }
+
+    // TODO(amaiorano): Validate parameter decorations
+    for (auto* deco : param->decorations()) {
+      Mark(deco);
+    }
+
     variable_stack_.set(param->symbol(), param_info);
     func_info->parameters.emplace_back(param_info);
 
@@ -642,12 +682,20 @@
   }
 
   if (func->body()) {
+    Mark(func->body());
     if (!BlockStatement(func->body())) {
       return false;
     }
   }
   variable_stack_.pop_scope();
 
+  for (auto* deco : func->decorations()) {
+    Mark(deco);
+  }
+  for (auto* deco : func->return_type_decorations()) {
+    Mark(deco);
+  }
+
   if (!ValidateFunction(func)) {
     return false;
   }
@@ -668,6 +716,7 @@
 
 bool Resolver::Statements(const ast::StatementList& stmts) {
   for (auto* stmt : stmts) {
+    Mark(stmt);
     if (!Statement(stmt)) {
       return false;
     }
@@ -698,6 +747,7 @@
     return true;
   }
   if (auto* c = stmt->As<ast::CallStatement>()) {
+    Mark(c->expr());
     return Expression(c->expr());
   }
   if (auto* c = stmt->As<ast::CaseStatement>()) {
@@ -732,11 +782,15 @@
     // these would make their BlockInfo siblings as in the AST, but we want the
     // body BlockInfo to parent the continuing BlockInfo for semantics and
     // validation. Also, we need to set their types differently.
+    Mark(l->body());
     return BlockScope(l->body(), BlockInfo::Type::kLoop, [&] {
       if (!Statements(l->body()->list())) {
         return false;
       }
 
+      if (l->continuing()) {  // has_continuing() also checks for empty()
+        Mark(l->continuing());
+      }
       if (l->has_continuing()) {
         if (!BlockScope(l->continuing(), BlockInfo::Type::kLoopContinuing,
                         [&] { return Statements(l->continuing()->list()); })) {
@@ -764,11 +818,16 @@
 }
 
 bool Resolver::CaseStatement(ast::CaseStatement* stmt) {
+  Mark(stmt->body());
+  for (auto* sel : stmt->selectors()) {
+    Mark(sel);
+  }
   return BlockScope(stmt->body(), BlockInfo::Type::kSwitchCase,
                     [&] { return Statements(stmt->body()->list()); });
 }
 
 bool Resolver::IfStatement(ast::IfStatement* stmt) {
+  Mark(stmt->condition());
   if (!Expression(stmt->condition())) {
     return false;
   }
@@ -781,11 +840,13 @@
     return false;
   }
 
+  Mark(stmt->body());
   if (!BlockStatement(stmt->body())) {
     return false;
   }
 
   for (auto* else_stmt : stmt->else_statements()) {
+    Mark(else_stmt);
     // Else statements are a bit unusual - they're owned by the if-statement,
     // not a BlockStatement.
     constexpr ast::BlockStatement* no_block_statement = nullptr;
@@ -793,9 +854,13 @@
         builder_->create<sem::Statement>(else_stmt, no_block_statement);
     builder_->Sem().Add(else_stmt, sem_else_stmt);
     ScopedAssignment<sem::Statement*> sa(current_statement_, sem_else_stmt);
-    if (!Expression(else_stmt->condition())) {
-      return false;
+    if (auto* cond = else_stmt->condition()) {
+      Mark(cond);
+      if (!Expression(cond)) {
+        return false;
+      }
     }
+    Mark(else_stmt->body());
     if (!BlockStatement(else_stmt->body())) {
       return false;
     }
@@ -805,6 +870,7 @@
 
 bool Resolver::Expressions(const ast::ExpressionList& list) {
   for (auto* expr : list) {
+    Mark(expr);
     if (!Expression(expr)) {
       return false;
     }
@@ -813,11 +879,6 @@
 }
 
 bool Resolver::Expression(ast::Expression* expr) {
-  // This is blindly called above, so in some cases the expression won't exist.
-  if (!expr) {
-    return true;
-  }
-
   if (TypeOf(expr)) {
     return true;  // Already resolved
   }
@@ -853,9 +914,11 @@
 }
 
 bool Resolver::ArrayAccessor(ast::ArrayAccessorExpression* expr) {
+  Mark(expr->array());
   if (!Expression(expr->array())) {
     return false;
   }
+  Mark(expr->idx_expr());
   if (!Expression(expr->idx_expr())) {
     return false;
   }
@@ -892,6 +955,7 @@
 }
 
 bool Resolver::Bitcast(ast::BitcastExpression* expr) {
+  Mark(expr->expr());
   if (!Expression(expr->expr())) {
     return false;
   }
@@ -907,6 +971,7 @@
   // The expression has to be an identifier as you can't store function pointers
   // but, if it isn't we'll just use the normal result determination to be on
   // the safe side.
+  Mark(call->func());
   auto* ident = call->func()->As<ast::IdentifierExpression>();
   if (!ident) {
     diagnostics_.add_error("call target is not an identifier", call->source());
@@ -993,6 +1058,7 @@
 bool Resolver::Constructor(ast::ConstructorExpression* expr) {
   if (auto* type_ctor = expr->As<ast::TypeConstructorExpression>()) {
     for (auto* value : type_ctor->values()) {
+      Mark(value);
       if (!Expression(value)) {
         return false;
       }
@@ -1003,13 +1069,14 @@
     // obey the constructor type rules laid out in
     // https://gpuweb.github.io/gpuweb/wgsl.html#type-constructor-expr.
     if (auto* vec_type = type_ctor->type()->As<type::Vector>()) {
-      return VectorConstructor(vec_type, type_ctor->values());
+      return ValidateVectorConstructor(vec_type, type_ctor->values());
     }
     if (auto* mat_type = type_ctor->type()->As<type::Matrix>()) {
-      return MatrixConstructor(mat_type, type_ctor->values());
+      return ValidateMatrixConstructor(mat_type, type_ctor->values());
     }
     // TODO(crbug.com/tint/634): Validate array constructor
   } else if (auto* scalar_ctor = expr->As<ast::ScalarConstructorExpression>()) {
+    Mark(scalar_ctor->literal());
     SetType(expr, scalar_ctor->literal()->type());
   } else {
     TINT_ICE(diagnostics_) << "unexpected constructor expression type";
@@ -1017,8 +1084,8 @@
   return true;
 }
 
-bool Resolver::VectorConstructor(const type::Vector* vec_type,
-                                 const ast::ExpressionList& values) {
+bool Resolver::ValidateVectorConstructor(const type::Vector* vec_type,
+                                         const ast::ExpressionList& values) {
   type::Type* elem_type = vec_type->type()->UnwrapAll();
   size_t value_cardinality_sum = 0;
   for (auto* value : values) {
@@ -1085,8 +1152,8 @@
   return true;
 }
 
-bool Resolver::MatrixConstructor(const type::Matrix* matrix_type,
-                                 const ast::ExpressionList& values) {
+bool Resolver::ValidateMatrixConstructor(const type::Matrix* matrix_type,
+                                         const ast::ExpressionList& values) {
   // Zero Value expression
   if (values.empty()) {
     return true;
@@ -1198,6 +1265,7 @@
 }
 
 bool Resolver::MemberAccessor(ast::MemberAccessorExpression* expr) {
+  Mark(expr->structure());
   if (!Expression(expr->structure())) {
     return false;
   }
@@ -1209,8 +1277,9 @@
   std::vector<uint32_t> swizzle;
 
   if (auto* ty = data_type->As<type::Struct>()) {
-    auto* str = Structure(ty);
+    Mark(expr->member());
     auto symbol = expr->member()->symbol();
+    auto* str = Structure(ty);
 
     const sem::StructMember* member = nullptr;
     for (auto* m : str->members) {
@@ -1236,6 +1305,7 @@
     builder_->Sem().Add(expr, builder_->create<sem::StructMemberAccess>(
                                   expr, ret, current_statement_, member));
   } else if (auto* vec = data_type->As<type::Vector>()) {
+    Mark(expr->member());
     std::string str = builder_->Symbols().NameFor(expr->member()->symbol());
     auto size = str.size();
     swizzle.reserve(str.size());
@@ -1481,6 +1551,8 @@
 }
 
 bool Resolver::Binary(ast::BinaryExpression* expr) {
+  Mark(expr->lhs());
+  Mark(expr->rhs());
   if (!Expression(expr->lhs()) || !Expression(expr->rhs())) {
     return false;
   }
@@ -1557,6 +1629,8 @@
 }
 
 bool Resolver::UnaryOp(ast::UnaryOpExpression* expr) {
+  Mark(expr->expr());
+
   // Result type matches the parameter type.
   if (!Expression(expr->expr())) {
     return false;
@@ -1569,6 +1643,8 @@
 
 bool Resolver::VariableDeclStatement(const ast::VariableDeclStatement* stmt) {
   ast::Variable* var = stmt->variable();
+  Mark(var);
+
   type::Type* type = var->declared_type();
 
   bool is_global = false;
@@ -1582,6 +1658,7 @@
   }
 
   if (auto* ctor = stmt->variable()->constructor()) {
+    Mark(ctor);
     if (!Expression(ctor)) {
       return false;
     }
@@ -1602,11 +1679,16 @@
     }
   }
 
+  for (auto* deco : var->decorations()) {
+    // TODO(bclayton): Validate decorations
+    Mark(deco);
+  }
+
   auto* info = Variable(var, type);
   if (!info) {
     return false;
   }
-  // TODO(amaiorano): Remove this and fix tests. We're overriding the semantic
+  // TODO(bclayton): Remove this and fix tests. We're overriding the semantic
   // type stored in info->type here with a possibly non-canonicalized type.
   info->type = type;
   variable_stack_.set(var->symbol(), info);
@@ -1860,9 +1942,16 @@
   };
 
   // Look for explicit stride via [[stride(n)]] decoration
+  uint32_t explicit_stride = 0;
   for (auto* deco : arr->decorations()) {
+    Mark(deco);
     if (auto* stride = deco->As<ast::StrideDecoration>()) {
-      auto explicit_stride = stride->stride();
+      if (explicit_stride) {
+        diagnostics_.add_error(
+            "array must have at most one [[stride]] decoration", source);
+        return nullptr;
+      }
+      explicit_stride = stride->stride();
       bool is_valid_stride = (explicit_stride >= el_size) &&
                              (explicit_stride >= el_align) &&
                              (explicit_stride % el_align == 0);
@@ -1878,10 +1967,11 @@
             source);
         return nullptr;
       }
-
-      return create_semantic(explicit_stride);
     }
   }
+  if (explicit_stride) {
+    return create_semantic(explicit_stride);
+  }
 
   // Calculate implicit stride
   auto implicit_stride = utils::RoundUp(el_align, el_size);
@@ -1950,6 +2040,11 @@
     return info_it->second;
   }
 
+  Mark(str->impl());
+  for (auto* deco : str->impl()->decorations()) {
+    Mark(deco);
+  }
+
   if (!ValidateStructure(str)) {
     return nullptr;
   }
@@ -1972,6 +2067,8 @@
   uint32_t struct_align = 1;
 
   for (auto* member : str->impl()->members()) {
+    Mark(member);
+
     // First check the member type is legal
     if (!IsStorable(member->type())) {
       builder_->Diagnostics().add_error(
@@ -1991,6 +2088,7 @@
     bool has_align_deco = false;
     bool has_size_deco = false;
     for (auto* deco : member->decorations()) {
+      Mark(deco);
       if (auto* o = deco->As<ast::StructMemberOffsetDecoration>()) {
         // Offset decorations are not part of the WGSL spec, but are emitted by
         // the SPIR-V reader.
@@ -2077,11 +2175,15 @@
 bool Resolver::Return(ast::ReturnStatement* ret) {
   current_function_->return_statements.push_back(ret);
 
-  auto result = Expression(ret->value());
+  if (auto* value = ret->value()) {
+    Mark(value);
 
-  // Validate after processing the return value expression so that its type is
-  // available for validation
-  return result && ValidateReturn(ret);
+    // Validate after processing the return value expression so that its type is
+    // available for validation
+    return Expression(value) && ValidateReturn(ret);
+  }
+
+  return true;
 }
 
 bool Resolver::ValidateSwitch(const ast::SwitchStatement* s) {
@@ -2155,10 +2257,12 @@
 }
 
 bool Resolver::Switch(ast::SwitchStatement* s) {
+  Mark(s->condition());
   if (!Expression(s->condition())) {
     return false;
   }
   for (auto* case_stmt : s->body()) {
+    Mark(case_stmt);
     if (!CaseStatement(case_stmt)) {
       return false;
     }
@@ -2231,6 +2335,9 @@
 }
 
 bool Resolver::Assignment(ast::AssignmentStatement* a) {
+  Mark(a->lhs());
+  Mark(a->rhs());
+
   if (!Expression(a->lhs()) || !Expression(a->rhs())) {
     return false;
   }
@@ -2330,6 +2437,19 @@
                             [&] { return make_canonical(type); });
 }
 
+void Resolver::Mark(ast::Node* node) {
+  if (node == nullptr) {
+    TINT_ICE(diagnostics_) << "Resolver::Mark() called with nullptr";
+  }
+  if (marked_.emplace(node).second) {
+    return;
+  }
+  TINT_ICE(diagnostics_)
+      << "AST node '" << node->TypeInfo().name
+      << "' was encountered twice in the same AST of a Program\n"
+      << "At: " << node->source();
+}
+
 Resolver::VariableInfo::VariableInfo(ast::Variable* decl, type::Type* ctype)
     : declaration(decl),
       type(ctype),
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index ae3d823..7bc2940 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -210,43 +210,45 @@
   // AST and Type traversal methods
   // Each return true on success, false on failure.
   bool ArrayAccessor(ast::ArrayAccessorExpression*);
+  bool Assignment(ast::AssignmentStatement* a);
   bool Binary(ast::BinaryExpression*);
   bool Bitcast(ast::BitcastExpression*);
   bool BlockStatement(const ast::BlockStatement*);
   bool Call(ast::CallExpression*);
   bool CaseStatement(ast::CaseStatement*);
   bool Constructor(ast::ConstructorExpression*);
-  bool VectorConstructor(const type::Vector* vec_type,
-                         const ast::ExpressionList& values);
-  bool MatrixConstructor(const type::Matrix* matrix_type,
-                         const ast::ExpressionList& values);
   bool Expression(ast::Expression*);
   bool Expressions(const ast::ExpressionList&);
   bool Function(ast::Function*);
+  bool GlobalVariable(ast::Variable* var);
   bool Identifier(ast::IdentifierExpression*);
   bool IfStatement(ast::IfStatement*);
   bool IntrinsicCall(ast::CallExpression*, sem::IntrinsicType);
   bool MemberAccessor(ast::MemberAccessorExpression*);
+  bool Parameter(ast::Variable* param);
+  bool Return(ast::ReturnStatement* ret);
   bool Statement(ast::Statement*);
   bool Statements(const ast::StatementList&);
+  bool Switch(ast::SwitchStatement* s);
+  bool Type(type::Type* ty);
   bool UnaryOp(ast::UnaryOpExpression*);
   bool VariableDeclStatement(const ast::VariableDeclStatement*);
-  bool Return(ast::ReturnStatement* ret);
-  bool Switch(ast::SwitchStatement* s);
-  bool Assignment(ast::AssignmentStatement* a);
-  bool GlobalVariable(ast::Variable* var);
 
   // AST and Type validation methods
   // Each return true on success, false on failure.
-  bool ValidateBinary(ast::BinaryExpression* expr);
-  bool ValidateVariable(const ast::Variable* param);
-  bool ValidateParameter(const ast::Variable* param);
-  bool ValidateFunction(const ast::Function* func);
-  bool ValidateEntryPoint(const ast::Function* func);
-  bool ValidateStructure(const type::Struct* st);
-  bool ValidateReturn(const ast::ReturnStatement* ret);
-  bool ValidateSwitch(const ast::SwitchStatement* s);
   bool ValidateAssignment(const ast::AssignmentStatement* a);
+  bool ValidateBinary(ast::BinaryExpression* expr);
+  bool ValidateEntryPoint(const ast::Function* func);
+  bool ValidateFunction(const ast::Function* func);
+  bool ValidateMatrixConstructor(const type::Matrix* matrix_type,
+                                 const ast::ExpressionList& values);
+  bool ValidateParameter(const ast::Variable* param);
+  bool ValidateReturn(const ast::ReturnStatement* ret);
+  bool ValidateStructure(const type::Struct* st);
+  bool ValidateSwitch(const ast::SwitchStatement* s);
+  bool ValidateVariable(const ast::Variable* param);
+  bool ValidateVectorConstructor(const type::Vector* vec_type,
+                                 const ast::ExpressionList& values);
 
   /// @returns the semantic information for the array `arr`, building it if it
   /// hasn't been constructed already. If an error is raised, nullptr is
@@ -312,6 +314,11 @@
   /// @return pretty string representation
   std::string VectorPretty(uint32_t size, type::Type* element_type);
 
+  /// Mark records that the given AST node has been visited, and asserts that
+  /// the given node has not already been seen. Diamonds in the AST are illegal.
+  /// @param node the AST node.
+  void Mark(ast::Node* node);
+
   ProgramBuilder* const builder_;
   std::unique_ptr<IntrinsicTable> const intrinsic_table_;
   diag::List diagnostics_;
@@ -324,6 +331,7 @@
   std::unordered_map<ast::Expression*, ExpressionInfo> expr_info_;
   std::unordered_map<type::Struct*, StructInfo*> struct_info_;
   std::unordered_map<type::Type*, type::Type*> type_to_canonical_;
+  std::unordered_set<ast::Node*> marked_;
   FunctionInfo* current_function_ = nullptr;
   sem::Statement* current_statement_ = nullptr;
   BlockAllocator<VariableInfo> variable_infos_;
diff --git a/src/resolver/resolver_test.cc b/src/resolver/resolver_test.cc
index f54ffb1..d9f8ac7 100644
--- a/src/resolver/resolver_test.cc
+++ b/src/resolver/resolver_test.cc
@@ -17,6 +17,7 @@
 #include <tuple>
 
 #include "gmock/gmock.h"
+#include "gtest/gtest-spi.h"
 #include "src/ast/assignment_statement.h"
 #include "src/ast/bitcast_expression.h"
 #include "src/ast/break_statement.h"
@@ -1619,6 +1620,31 @@
   ASSERT_TRUE(r()->Resolve()) << r()->error();
 }
 
+TEST_F(ResolverTest, ASTNodeNotReached) {
+  EXPECT_FATAL_FAILURE(
+      {
+        ProgramBuilder builder;
+        builder.Expr("1");
+        Resolver(&builder).Resolve();
+      },
+      "internal compiler error: AST node 'tint::ast::IdentifierExpression' was "
+      "not reached by the resolver");
+}
+
+TEST_F(ResolverTest, ASTNodeReachedTwice) {
+  EXPECT_FATAL_FAILURE(
+      {
+        ProgramBuilder builder;
+        auto* expr = builder.Expr("1");
+        auto* usesExprTwice = builder.Add(expr, expr);
+        builder.Global("g", builder.ty.i32(), ast::StorageClass::kPrivate,
+                       usesExprTwice);
+        Resolver(&builder).Resolve();
+      },
+      "internal compiler error: AST node 'tint::ast::IdentifierExpression' was "
+      "encountered twice in the same AST of a Program");
+}
+
 }  // namespace
 }  // namespace resolver
 }  // namespace tint
diff --git a/src/source.cc b/src/source.cc
index f4a7236..c9e5076 100644
--- a/src/source.cc
+++ b/src/source.cc
@@ -14,6 +14,7 @@
 
 #include "src/source.h"
 
+#include <algorithm>
 #include <sstream>
 #include <utility>
 
@@ -37,4 +38,57 @@
 
 Source::File::~File() = default;
 
+std::ostream& operator<<(std::ostream& out, const Source& source) {
+  auto rng = source.range;
+
+  if (!source.file_path.empty()) {
+    out << source.file_path << ":";
+  }
+  if (rng.begin.line) {
+    out << rng.begin.line << ":";
+    if (rng.begin.column) {
+      out << rng.begin.column;
+    }
+
+    if (source.file_content) {
+      out << std::endl << std::endl;
+
+      auto repeat = [&](char c, size_t n) {
+        while (n--) {
+          out << c;
+        }
+      };
+
+      for (size_t line = rng.begin.line; line <= rng.end.line; line++) {
+        if (line < source.file_content->lines.size() + 1) {
+          auto len = source.file_content->lines[line - 1].size();
+
+          out << source.file_content->lines[line - 1];
+
+          out << std::endl;
+
+          if (line == rng.begin.line && line == rng.end.line) {
+            // Single line
+            repeat(' ', rng.begin.column - 1);
+            repeat('^', std::max<size_t>(rng.end.column - rng.begin.column, 1));
+          } else if (line == rng.begin.line) {
+            // Start of multi-line
+            repeat(' ', rng.begin.column - 1);
+            repeat('^', len - (rng.begin.column - 1));
+          } else if (line == rng.end.line) {
+            // End of multi-line
+            repeat('^', rng.end.column - 1);
+          } else {
+            // Middle of multi-line
+            repeat('^', len);
+          }
+
+          out << std::endl;
+        }
+      }
+    }
+  }
+  return out;
+}
+
 }  // namespace tint
diff --git a/src/source.h b/src/source.h
index 416a30e..07146da 100644
--- a/src/source.h
+++ b/src/source.h
@@ -147,6 +147,12 @@
   const FileContent* file_content = nullptr;
 };
 
+/// Writes the Source to the std::ostream.
+/// @param out the std::ostream to write to
+/// @param source the source to write
+/// @returns out so calls can be chained
+std::ostream& operator<<(std::ostream& out, const Source& source);
+
 /// Writes the Source::FileContent to the std::ostream.
 /// @param out the std::ostream to write to
 /// @param content the file content to write