Move TypeManager from tint::Context to ast::Module

Bug: tint:307
Bug: tint:337
Change-Id: I726cdf89182813ba6f468f8ac35e5d44b22e1e1f
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/33666
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/ast/builder.cc b/src/ast/builder.cc
index b4aacca..871f4ad 100644
--- a/src/ast/builder.cc
+++ b/src/ast/builder.cc
@@ -26,7 +26,7 @@
       tm_(tm) {}
 
 Builder::Builder(tint::Context* c, tint::ast::Module* m)
-    : ctx(c), mod(m), ty(&c->type_mgr()) {}
+    : ctx(c), mod(m), ty(&m->type_mgr()) {}
 Builder::~Builder() = default;
 
 ast::Variable* Builder::Var(const std::string& name,
diff --git a/src/ast/module.h b/src/ast/module.h
index 6ec3c78..db8f353 100644
--- a/src/ast/module.h
+++ b/src/ast/module.h
@@ -22,6 +22,7 @@
 
 #include "src/ast/function.h"
 #include "src/ast/type/alias_type.h"
+#include "src/ast/type_manager.h"
 #include "src/ast/variable.h"
 
 namespace tint {
@@ -77,6 +78,9 @@
   /// @returns a string representation of the module
   std::string to_str() const;
 
+  /// @returns the Type Manager
+  ast::TypeManager& type_mgr() { return type_mgr_; }
+
   /// Creates a new `ast::Node` owned by the Module. When the Module is
   /// destructed, the `ast::Node` will also be destructed.
   /// @param args the arguments to pass to the type constructor
@@ -99,6 +103,7 @@
   std::vector<type::Type*> constructed_types_;
   FunctionList functions_;
   std::vector<std::unique_ptr<ast::Node>> ast_nodes_;
+  ast::TypeManager type_mgr_;
 };
 
 }  // namespace ast
diff --git a/src/ast/type/storage_texture_type_test.cc b/src/ast/type/storage_texture_type_test.cc
index 4e5d8ac..4447f30 100644
--- a/src/ast/type/storage_texture_type_test.cc
+++ b/src/ast/type/storage_texture_type_test.cc
@@ -79,10 +79,10 @@
 
 TEST_F(StorageTextureTypeTest, F32Type) {
   Context ctx;
-  ast::type::Type* s = ctx.type_mgr().Get(std::make_unique<StorageTextureType>(
+  ast::Module mod;
+  ast::type::Type* s = mod.type_mgr().Get(std::make_unique<StorageTextureType>(
       TextureDimension::k2dArray, AccessControl::kReadOnly,
       ImageFormat::kRgba32Float));
-  ast::Module mod;
   TypeDeterminer td(&ctx, &mod);
 
   ASSERT_TRUE(td.Determine()) << td.error();
@@ -93,10 +93,10 @@
 
 TEST_F(StorageTextureTypeTest, U32Type) {
   Context ctx;
-  ast::type::Type* s = ctx.type_mgr().Get(std::make_unique<StorageTextureType>(
+  ast::Module mod;
+  ast::type::Type* s = mod.type_mgr().Get(std::make_unique<StorageTextureType>(
       TextureDimension::k2dArray, AccessControl::kReadOnly,
       ImageFormat::kRgba8Unorm));
-  ast::Module mod;
   TypeDeterminer td(&ctx, &mod);
 
   ASSERT_TRUE(td.Determine()) << td.error();
@@ -107,10 +107,10 @@
 
 TEST_F(StorageTextureTypeTest, I32Type) {
   Context ctx;
-  ast::type::Type* s = ctx.type_mgr().Get(std::make_unique<StorageTextureType>(
+  ast::Module mod;
+  ast::type::Type* s = mod.type_mgr().Get(std::make_unique<StorageTextureType>(
       TextureDimension::k2dArray, AccessControl::kReadOnly,
       ImageFormat::kRgba32Sint));
-  ast::Module mod;
   TypeDeterminer td(&ctx, &mod);
 
   ASSERT_TRUE(td.Determine()) << td.error();
diff --git a/src/ast/type_manager.cc b/src/ast/type_manager.cc
index 6166287..bd841ef 100644
--- a/src/ast/type_manager.cc
+++ b/src/ast/type_manager.cc
@@ -20,7 +20,7 @@
 namespace ast {
 
 TypeManager::TypeManager() = default;
-
+TypeManager::TypeManager(TypeManager&&) = default;
 TypeManager::~TypeManager() = default;
 
 void TypeManager::Reset() {
diff --git a/src/ast/type_manager.h b/src/ast/type_manager.h
index 8f4eb99..c9ca90f 100644
--- a/src/ast/type_manager.h
+++ b/src/ast/type_manager.h
@@ -29,6 +29,8 @@
 class TypeManager {
  public:
   TypeManager();
+  /// Move constructor
+  TypeManager(TypeManager&&);
   ~TypeManager();
 
   /// Clears all registered types.
diff --git a/src/context.cc b/src/context.cc
index fff1052..88e69f9 100644
--- a/src/context.cc
+++ b/src/context.cc
@@ -27,8 +27,4 @@
 
 Context::~Context() = default;
 
-void Context::Reset() {
-  type_mgr_.Reset();
-}
-
 }  // namespace tint
diff --git a/src/context.h b/src/context.h
index 072ab86..d21fc5b 100644
--- a/src/context.h
+++ b/src/context.h
@@ -22,7 +22,6 @@
 #include <utility>
 #include <vector>
 
-#include "src/ast/type_manager.h"
 #include "src/namer.h"
 
 namespace tint {
@@ -42,17 +41,11 @@
   explicit Context(std::unique_ptr<Namer> namer);
   /// Destructor
   ~Context();
-  /// Resets the state of this context.
-  void Reset();
-
-  /// @returns the Type Manager
-  ast::TypeManager& type_mgr() { return type_mgr_; }
 
   /// @returns the namer object
   Namer* namer() const { return namer_.get(); }
 
  private:
-  ast::TypeManager type_mgr_;
   std::unique_ptr<Namer> namer_;
 };
 
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index 51010cb..a30df6f 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -3253,7 +3253,7 @@
     const auto* ast_ptr_type = type->AsPointer();
     const auto sc = GetStorageClassForPointerValue(result_id);
     if (ast_ptr_type->storage_class() != sc) {
-      return parser_impl_.context().type_mgr().Get(
+      return parser_impl_.get_module().type_mgr().Get(
           std::make_unique<ast::type::PointerType>(ast_ptr_type->type(), sc));
     }
   }
diff --git a/src/reader/spirv/parser_impl.cc b/src/reader/spirv/parser_impl.cc
index c0bf2c7..c58a7b3 100644
--- a/src/reader/spirv/parser_impl.cc
+++ b/src/reader/spirv/parser_impl.cc
@@ -196,7 +196,8 @@
     : Reader(ctx),
       spv_binary_(spv_binary),
       fail_stream_(&success_, &errors_),
-      bool_type_(ctx->type_mgr().Get(std::make_unique<ast::type::BoolType>())),
+      bool_type_(
+          ast_module_.type_mgr().Get(std::make_unique<ast::type::BoolType>())),
       namer_(fail_stream_),
       enum_converter_(fail_stream_),
       tools_context_(kInputEnv) {
@@ -285,7 +286,8 @@
 
   switch (spirv_type->kind()) {
     case spvtools::opt::analysis::Type::kVoid:
-      return save(ctx_.type_mgr().Get(std::make_unique<ast::type::VoidType>()));
+      return save(
+          ast_module_.type_mgr().Get(std::make_unique<ast::type::VoidType>()));
     case spvtools::opt::analysis::Type::kBool:
       return save(bool_type_);
     case spvtools::opt::analysis::Type::kInteger:
@@ -315,7 +317,8 @@
     case spvtools::opt::analysis::Type::kImage:
       // Fake it for sampler and texture types.  These are handled in an
       // entirely different way.
-      return save(ctx_.type_mgr().Get(std::make_unique<ast::type::VoidType>()));
+      return save(
+          ast_module_.type_mgr().Get(std::make_unique<ast::type::VoidType>()));
     default:
       break;
   }
@@ -649,9 +652,9 @@
     const spvtools::opt::analysis::Integer* int_ty) {
   if (int_ty->width() == 32) {
     auto* signed_ty =
-        ctx_.type_mgr().Get(std::make_unique<ast::type::I32Type>());
+        ast_module_.type_mgr().Get(std::make_unique<ast::type::I32Type>());
     auto* unsigned_ty =
-        ctx_.type_mgr().Get(std::make_unique<ast::type::U32Type>());
+        ast_module_.type_mgr().Get(std::make_unique<ast::type::U32Type>());
     signed_type_for_[unsigned_ty] = signed_ty;
     unsigned_type_for_[signed_ty] = unsigned_ty;
     return int_ty->IsSigned() ? signed_ty : unsigned_ty;
@@ -663,7 +666,7 @@
 ast::type::Type* ParserImpl::ConvertType(
     const spvtools::opt::analysis::Float* float_ty) {
   if (float_ty->width() == 32) {
-    return ctx_.type_mgr().Get(std::make_unique<ast::type::F32Type>());
+    return ast_module_.type_mgr().Get(std::make_unique<ast::type::F32Type>());
   }
   Fail() << "unhandled float width: " << float_ty->width();
   return nullptr;
@@ -676,18 +679,18 @@
   if (ast_elem_ty == nullptr) {
     return nullptr;
   }
-  auto* this_ty = ctx_.type_mgr().Get(
+  auto* this_ty = ast_module_.type_mgr().Get(
       std::make_unique<ast::type::VectorType>(ast_elem_ty, num_elem));
   // Generate the opposite-signedness vector type, if this type is integral.
   if (unsigned_type_for_.count(ast_elem_ty)) {
     auto* other_ty =
-        ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
+        ast_module_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
             unsigned_type_for_[ast_elem_ty], num_elem));
     signed_type_for_[other_ty] = this_ty;
     unsigned_type_for_[this_ty] = other_ty;
   } else if (signed_type_for_.count(ast_elem_ty)) {
     auto* other_ty =
-        ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
+        ast_module_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
             signed_type_for_[ast_elem_ty], num_elem));
     unsigned_type_for_[other_ty] = this_ty;
     signed_type_for_[this_ty] = other_ty;
@@ -705,7 +708,7 @@
   if (ast_scalar_ty == nullptr) {
     return nullptr;
   }
-  return ctx_.type_mgr().Get(std::make_unique<ast::type::MatrixType>(
+  return ast_module_.type_mgr().Get(std::make_unique<ast::type::MatrixType>(
       ast_scalar_ty, num_rows, num_columns));
 }
 
@@ -719,7 +722,7 @@
   if (!ApplyArrayDecorations(rtarr_ty, ast_type.get())) {
     return nullptr;
   }
-  return ctx_.type_mgr().Get(std::move(ast_type));
+  return ast_module_.type_mgr().Get(std::move(ast_type));
 }
 
 ast::type::Type* ParserImpl::ConvertType(
@@ -764,7 +767,7 @@
   if (remap_buffer_block_type_.count(elem_type_id)) {
     remap_buffer_block_type_.insert(type_mgr_->GetId(arr_ty));
   }
-  return ctx_.type_mgr().Get(std::move(ast_type));
+  return ast_module_.type_mgr().Get(std::move(ast_type));
 }
 
 bool ParserImpl::ApplyArrayDecorations(
@@ -892,7 +895,7 @@
   auto ast_struct_type = std::make_unique<ast::type::StructType>(
       namer_.GetName(type_id), ast_struct);
 
-  auto* result = ctx_.type_mgr().Get(std::move(ast_struct_type));
+  auto* result = ast_module_.type_mgr().Get(std::move(ast_struct_type));
   id_to_type_[type_id] = result;
   if (num_non_writable_members == members.size()) {
     read_only_struct_types_.insert(result);
@@ -932,7 +935,7 @@
     ast_storage_class = ast::StorageClass::kStorageBuffer;
     remap_buffer_block_type_.insert(type_id);
   }
-  return ctx_.type_mgr().Get(
+  return ast_module_.type_mgr().Get(
       std::make_unique<ast::type::PointerType>(ast_elem_ty, ast_storage_class));
 }
 
@@ -1062,7 +1065,7 @@
     return;
   }
   const auto name = namer_.GetName(type_id);
-  auto* ast_alias_type = ctx_.type_mgr()
+  auto* ast_alias_type = ast_module_.type_mgr()
                              .Get(std::make_unique<ast::type::AliasType>(
                                  name, ast_underlying_type))
                              ->AsAlias();
@@ -1166,7 +1169,7 @@
     auto access = read_only_struct_types_.count(type)
                       ? ast::AccessControl::kReadOnly
                       : ast::AccessControl::kReadWrite;
-    type = ctx_.type_mgr().Get(
+    type = ast_module_.type_mgr().Get(
         std::make_unique<ast::type::AccessControlType>(access, type));
   }
 
@@ -1361,7 +1364,7 @@
     const auto* mat_ty = type->AsMatrix();
     // Matrix components are columns
     auto* column_ty =
-        ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
+        ast_module_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
             mat_ty->type(), mat_ty->rows()));
     ast::ExpressionList ast_components;
     for (size_t i = 0; i < mat_ty->columns(); ++i) {
@@ -1443,13 +1446,14 @@
   if (other == nullptr) {
     Fail() << "no type provided";
   }
-  auto* i32 = ctx_.type_mgr().Get(std::make_unique<ast::type::I32Type>());
+  auto* i32 =
+      ast_module_.type_mgr().Get(std::make_unique<ast::type::I32Type>());
   if (other->IsF32() || other->IsU32() || other->IsI32()) {
     return i32;
   }
   auto* vec_ty = other->AsVector();
   if (vec_ty) {
-    return ctx_.type_mgr().Get(
+    return ast_module_.type_mgr().Get(
         std::make_unique<ast::type::VectorType>(i32, vec_ty->size()));
   }
   Fail() << "required numeric scalar or vector, but got " << other->type_name();
@@ -1462,13 +1466,14 @@
     Fail() << "no type provided";
     return nullptr;
   }
-  auto* u32 = ctx_.type_mgr().Get(std::make_unique<ast::type::U32Type>());
+  auto* u32 =
+      ast_module_.type_mgr().Get(std::make_unique<ast::type::U32Type>());
   if (other->IsF32() || other->IsU32() || other->IsI32()) {
     return u32;
   }
   auto* vec_ty = other->AsVector();
   if (vec_ty) {
-    return ctx_.type_mgr().Get(
+    return ast_module_.type_mgr().Get(
         std::make_unique<ast::type::VectorType>(u32, vec_ty->size()));
   }
   Fail() << "required numeric scalar or vector, but got " << other->type_name();
@@ -1628,7 +1633,7 @@
   ast::type::Type* ast_store_type = nullptr;
   if (usage.IsSampler()) {
     ast_store_type =
-        ctx_.type_mgr().Get(std::make_unique<ast::type::SamplerType>(
+        ast_module_.type_mgr().Get(std::make_unique<ast::type::SamplerType>(
             usage.IsComparisonSampler()
                 ? ast::type::SamplerKind::kComparisonSampler
                 : ast::type::SamplerKind::kSampler));
@@ -1684,16 +1689,16 @@
       // OpImage variable with an OpImage*Dref* instruction.  In WGSL we must
       // treat that as a depth texture.
       if (image_type->depth() || usage.IsDepthTexture()) {
-        ast_store_type = ctx_.type_mgr().Get(
+        ast_store_type = ast_module_.type_mgr().Get(
             std::make_unique<ast::type::DepthTextureType>(dim));
       } else if (image_type->is_multisampled()) {
         // Multisampled textures are never depth textures.
-        ast_store_type = ctx_.type_mgr().Get(
+        ast_store_type = ast_module_.type_mgr().Get(
             std::make_unique<ast::type::MultisampledTextureType>(
                 dim, ast_sampled_component_type));
       } else {
-        ast_store_type =
-            ctx_.type_mgr().Get(std::make_unique<ast::type::SampledTextureType>(
+        ast_store_type = ast_module_.type_mgr().Get(
+            std::make_unique<ast::type::SampledTextureType>(
                 dim, ast_sampled_component_type));
       }
     } else {
@@ -1726,7 +1731,7 @@
       if (format == ast::type::ImageFormat::kNone) {
         return nullptr;
       }
-      ast_store_type = ctx_.type_mgr().Get(
+      ast_store_type = ast_module_.type_mgr().Get(
           std::make_unique<ast::type::StorageTextureType>(dim, access, format));
     }
   } else {
@@ -1736,7 +1741,7 @@
     return nullptr;
   }
   // Form the pointer type.
-  return ctx_.type_mgr().Get(std::make_unique<ast::type::PointerType>(
+  return ast_module_.type_mgr().Get(std::make_unique<ast::type::PointerType>(
       ast_store_type, ast::StorageClass::kUniformConstant));
 }
 
diff --git a/src/reader/wgsl/parser_impl.cc b/src/reader/wgsl/parser_impl.cc
index 0b8fb24..9ab7068 100644
--- a/src/reader/wgsl/parser_impl.cc
+++ b/src/reader/wgsl/parser_impl.cc
@@ -162,8 +162,8 @@
 
 }  // namespace
 
-ParserImpl::ParserImpl(Context* ctx, Source::File const* file)
-    : ctx_(*ctx), lexer_(std::make_unique<Lexer>(file)) {}
+ParserImpl::ParserImpl(Context*, Source::File const* file)
+    : lexer_(std::make_unique<Lexer>(file)) {}
 
 ParserImpl::~ParserImpl() = default;
 
@@ -308,7 +308,7 @@
       if (!expect("struct declaration", Token::Type::kSemicolon))
         return Failure::kErrored;
 
-      auto* type = ctx_.type_mgr().Get(std::move(str.value));
+      auto* type = module_.type_mgr().Get(std::move(str.value));
       register_constructed(type->AsStruct()->name(), type);
       module_.AddConstructedType(type);
       return true;
@@ -462,8 +462,9 @@
     if (subtype.errored)
       return Failure::kErrored;
 
-    return ctx_.type_mgr().Get(std::make_unique<ast::type::SampledTextureType>(
-        dim.value, subtype.value));
+    return module_.type_mgr().Get(
+        std::make_unique<ast::type::SampledTextureType>(dim.value,
+                                                        subtype.value));
   }
 
   auto ms_dim = multisampled_texture_type();
@@ -474,7 +475,7 @@
     if (subtype.errored)
       return Failure::kErrored;
 
-    return ctx_.type_mgr().Get(
+    return module_.type_mgr().Get(
         std::make_unique<ast::type::MultisampledTextureType>(ms_dim.value,
                                                              subtype.value));
   }
@@ -489,8 +490,9 @@
     if (format.errored)
       return Failure::kErrored;
 
-    return ctx_.type_mgr().Get(std::make_unique<ast::type::StorageTextureType>(
-        storage->first, storage->second, format.value));
+    return module_.type_mgr().Get(
+        std::make_unique<ast::type::StorageTextureType>(
+            storage->first, storage->second, format.value));
   }
 
   return Failure::kNoMatch;
@@ -501,11 +503,11 @@
 //  | SAMPLER_COMPARISON
 Maybe<ast::type::Type*> ParserImpl::sampler_type() {
   if (match(Token::Type::kSampler))
-    return ctx_.type_mgr().Get(std::make_unique<ast::type::SamplerType>(
+    return module_.type_mgr().Get(std::make_unique<ast::type::SamplerType>(
         ast::type::SamplerKind::kSampler));
 
   if (match(Token::Type::kComparisonSampler))
-    return ctx_.type_mgr().Get(std::make_unique<ast::type::SamplerType>(
+    return module_.type_mgr().Get(std::make_unique<ast::type::SamplerType>(
         ast::type::SamplerKind::kComparisonSampler));
 
   return Failure::kNoMatch;
@@ -634,19 +636,19 @@
 //  | TEXTURE_DEPTH_CUBE_ARRAY
 Maybe<ast::type::Type*> ParserImpl::depth_texture_type() {
   if (match(Token::Type::kTextureDepth2d))
-    return ctx_.type_mgr().Get(std::make_unique<ast::type::DepthTextureType>(
+    return module_.type_mgr().Get(std::make_unique<ast::type::DepthTextureType>(
         ast::type::TextureDimension::k2d));
 
   if (match(Token::Type::kTextureDepth2dArray))
-    return ctx_.type_mgr().Get(std::make_unique<ast::type::DepthTextureType>(
+    return module_.type_mgr().Get(std::make_unique<ast::type::DepthTextureType>(
         ast::type::TextureDimension::k2dArray));
 
   if (match(Token::Type::kTextureDepthCube))
-    return ctx_.type_mgr().Get(std::make_unique<ast::type::DepthTextureType>(
+    return module_.type_mgr().Get(std::make_unique<ast::type::DepthTextureType>(
         ast::type::TextureDimension::kCube));
 
   if (match(Token::Type::kTextureDepthCubeArray))
-    return ctx_.type_mgr().Get(std::make_unique<ast::type::DepthTextureType>(
+    return module_.type_mgr().Get(std::make_unique<ast::type::DepthTextureType>(
         ast::type::TextureDimension::kCubeArray));
 
   return Failure::kNoMatch;
@@ -832,7 +834,7 @@
   for (auto* deco : access_decos) {
     // If we have an access control decoration then we take it and wrap our
     // type up with that decoration
-    ty = ctx_.type_mgr().Get(std::make_unique<ast::type::AccessControlType>(
+    ty = module_.type_mgr().Get(std::make_unique<ast::type::AccessControlType>(
         deco->AsAccess()->value(), ty));
   }
 
@@ -892,7 +894,7 @@
   if (!type.matched)
     return add_error(peek(), "invalid type alias");
 
-  auto* alias = ctx_.type_mgr().Get(
+  auto* alias = module_.type_mgr().Get(
       std::make_unique<ast::type::AliasType>(name.value, type.value));
   register_constructed(name.value, alias);
 
@@ -951,16 +953,16 @@
   }
 
   if (match(Token::Type::kBool))
-    return ctx_.type_mgr().Get(std::make_unique<ast::type::BoolType>());
+    return module_.type_mgr().Get(std::make_unique<ast::type::BoolType>());
 
   if (match(Token::Type::kF32))
-    return ctx_.type_mgr().Get(std::make_unique<ast::type::F32Type>());
+    return module_.type_mgr().Get(std::make_unique<ast::type::F32Type>());
 
   if (match(Token::Type::kI32))
-    return ctx_.type_mgr().Get(std::make_unique<ast::type::I32Type>());
+    return module_.type_mgr().Get(std::make_unique<ast::type::I32Type>());
 
   if (match(Token::Type::kU32))
-    return ctx_.type_mgr().Get(std::make_unique<ast::type::U32Type>());
+    return module_.type_mgr().Get(std::make_unique<ast::type::U32Type>());
 
   if (t.IsVec2() || t.IsVec3() || t.IsVec4()) {
     next();  // Consume the peek
@@ -1018,7 +1020,7 @@
     if (subtype.errored)
       return Failure::kErrored;
 
-    return ctx_.type_mgr().Get(
+    return module_.type_mgr().Get(
         std::make_unique<ast::type::PointerType>(subtype.value, sc.value));
   });
 }
@@ -1036,7 +1038,7 @@
   if (subtype.errored)
     return Failure::kErrored;
 
-  return ctx_.type_mgr().Get(
+  return module_.type_mgr().Get(
       std::make_unique<ast::type::VectorType>(subtype.value, count));
 }
 
@@ -1059,7 +1061,7 @@
 
     auto ty = std::make_unique<ast::type::ArrayType>(subtype.value, size);
     ty->set_decorations(std::move(decos));
-    return ctx_.type_mgr().Get(std::move(ty));
+    return module_.type_mgr().Get(std::move(ty));
   });
 }
 
@@ -1083,7 +1085,7 @@
   if (subtype.errored)
     return Failure::kErrored;
 
-  return ctx_.type_mgr().Get(
+  return module_.type_mgr().Get(
       std::make_unique<ast::type::MatrixType>(subtype.value, rows, columns));
 }
 
@@ -1252,7 +1254,7 @@
 //   | VOID
 Maybe<ast::type::Type*> ParserImpl::function_type_decl() {
   if (match(Token::Type::kVoid))
-    return ctx_.type_mgr().Get(std::make_unique<ast::type::VoidType>());
+    return module_.type_mgr().Get(std::make_unique<ast::type::VoidType>());
 
   return type_decl();
 }
@@ -2611,23 +2613,25 @@
 Maybe<ast::Literal*> ParserImpl::const_literal() {
   auto t = peek();
   if (match(Token::Type::kTrue)) {
-    auto* type = ctx_.type_mgr().Get(std::make_unique<ast::type::BoolType>());
+    auto* type =
+        module_.type_mgr().Get(std::make_unique<ast::type::BoolType>());
     return create<ast::BoolLiteral>(type, true);
   }
   if (match(Token::Type::kFalse)) {
-    auto* type = ctx_.type_mgr().Get(std::make_unique<ast::type::BoolType>());
+    auto* type =
+        module_.type_mgr().Get(std::make_unique<ast::type::BoolType>());
     return create<ast::BoolLiteral>(type, false);
   }
   if (match(Token::Type::kSintLiteral)) {
-    auto* type = ctx_.type_mgr().Get(std::make_unique<ast::type::I32Type>());
+    auto* type = module_.type_mgr().Get(std::make_unique<ast::type::I32Type>());
     return create<ast::SintLiteral>(type, t.to_i32());
   }
   if (match(Token::Type::kUintLiteral)) {
-    auto* type = ctx_.type_mgr().Get(std::make_unique<ast::type::U32Type>());
+    auto* type = module_.type_mgr().Get(std::make_unique<ast::type::U32Type>());
     return create<ast::UintLiteral>(type, t.to_u32());
   }
   if (match(Token::Type::kFloatLiteral)) {
-    auto* type = ctx_.type_mgr().Get(std::make_unique<ast::type::F32Type>());
+    auto* type = module_.type_mgr().Get(std::make_unique<ast::type::F32Type>());
     return create<ast::FloatLiteral>(type, t.to_f32());
   }
   return Failure::kNoMatch;
diff --git a/src/reader/wgsl/parser_impl.h b/src/reader/wgsl/parser_impl.h
index ab476c9..9b6a841 100644
--- a/src/reader/wgsl/parser_impl.h
+++ b/src/reader/wgsl/parser_impl.h
@@ -251,6 +251,9 @@
   /// @returns the module. The module in the parser will be reset after this.
   ast::Module module() { return std::move(module_); }
 
+  /// @returns a pointer to the module, without resetting it.
+  ast::Module& get_module() { return module_; }
+
   /// @returns the next token
   Token next();
   /// @returns the next token without advancing
@@ -768,7 +771,6 @@
     return module_.create<T>(std::forward<ARGS>(args)...);
   }
 
-  Context& ctx_;
   diag::List diags_;
   std::unique_ptr<Lexer> lexer_;
   std::deque<Token> token_queue_;
diff --git a/src/reader/wgsl/parser_impl_function_type_decl_test.cc b/src/reader/wgsl/parser_impl_function_type_decl_test.cc
index 58bf1e4..ad4d362 100644
--- a/src/reader/wgsl/parser_impl_function_type_decl_test.cc
+++ b/src/reader/wgsl/parser_impl_function_type_decl_test.cc
@@ -27,9 +27,11 @@
 namespace {
 
 TEST_F(ParserImplTest, FunctionTypeDecl_Void) {
-  auto* v = tm()->Get(std::make_unique<ast::type::VoidType>());
-
   auto p = parser("void");
+
+  auto& mod = p->get_module();
+  auto* v = mod.type_mgr().Get(std::make_unique<ast::type::VoidType>());
+
   auto e = p->function_type_decl();
   EXPECT_TRUE(e.matched);
   EXPECT_FALSE(e.errored);
@@ -38,10 +40,13 @@
 }
 
 TEST_F(ParserImplTest, FunctionTypeDecl_Type) {
-  auto* f32 = tm()->Get(std::make_unique<ast::type::F32Type>());
-  auto* vec2 = tm()->Get(std::make_unique<ast::type::VectorType>(f32, 2));
-
   auto p = parser("vec2<f32>");
+
+  auto& mod = p->get_module();
+  auto* f32 = mod.type_mgr().Get(std::make_unique<ast::type::F32Type>());
+  auto* vec2 =
+      mod.type_mgr().Get(std::make_unique<ast::type::VectorType>(f32, 2));
+
   auto e = p->function_type_decl();
   EXPECT_TRUE(e.matched);
   EXPECT_FALSE(e.errored);
diff --git a/src/reader/wgsl/parser_impl_global_decl_test.cc b/src/reader/wgsl/parser_impl_global_decl_test.cc
index 906efe5..956aaf8 100644
--- a/src/reader/wgsl/parser_impl_global_decl_test.cc
+++ b/src/reader/wgsl/parser_impl_global_decl_test.cc
@@ -34,7 +34,7 @@
   p->expect_global_decl();
   ASSERT_FALSE(p->has_error()) << p->error();
 
-  auto m = p->module();
+  auto& m = p->get_module();
   ASSERT_EQ(m.global_variables().size(), 1u);
 
   auto* v = m.global_variables()[0];
@@ -60,7 +60,7 @@
   p->expect_global_decl();
   ASSERT_FALSE(p->has_error()) << p->error();
 
-  auto m = p->module();
+  auto& m = p->get_module();
   ASSERT_EQ(m.global_variables().size(), 1u);
 
   auto* v = m.global_variables()[0];
@@ -86,7 +86,7 @@
   p->expect_global_decl();
   ASSERT_FALSE(p->has_error()) << p->error();
 
-  auto m = p->module();
+  auto& m = p->get_module();
   ASSERT_EQ(m.constructed_types().size(), 1u);
   ASSERT_TRUE(m.constructed_types()[0]->IsAlias());
   EXPECT_EQ(m.constructed_types()[0]->AsAlias()->name(), "A");
@@ -101,7 +101,7 @@
   p->expect_global_decl();
   ASSERT_FALSE(p->has_error()) << p->error();
 
-  auto m = p->module();
+  auto& m = p->get_module();
   ASSERT_EQ(m.constructed_types().size(), 2u);
   ASSERT_TRUE(m.constructed_types()[0]->IsStruct());
   auto* str = m.constructed_types()[0]->AsStruct();
@@ -132,7 +132,7 @@
   p->expect_global_decl();
   ASSERT_FALSE(p->has_error()) << p->error();
 
-  auto m = p->module();
+  auto& m = p->get_module();
   ASSERT_EQ(m.functions().size(), 1u);
   EXPECT_EQ(m.functions()[0]->name(), "main");
 }
@@ -142,7 +142,7 @@
   p->expect_global_decl();
   ASSERT_FALSE(p->has_error()) << p->error();
 
-  auto m = p->module();
+  auto& m = p->get_module();
   ASSERT_EQ(m.functions().size(), 1u);
   EXPECT_EQ(m.functions()[0]->name(), "main");
 }
@@ -159,7 +159,7 @@
   p->expect_global_decl();
   ASSERT_FALSE(p->has_error()) << p->error();
 
-  auto m = p->module();
+  auto& m = p->get_module();
   ASSERT_EQ(m.constructed_types().size(), 1u);
 
   auto* t = m.constructed_types()[0];
@@ -174,10 +174,11 @@
 TEST_F(ParserImplTest, GlobalDecl_Struct_WithStride) {
   auto p =
       parser("struct A { [[offset(0)]] data: [[stride(4)]] array<f32>; };");
+
   p->expect_global_decl();
   ASSERT_FALSE(p->has_error()) << p->error();
 
-  auto m = p->module();
+  auto& m = p->get_module();
   ASSERT_EQ(m.constructed_types().size(), 1u);
 
   auto* t = m.constructed_types()[0];
@@ -201,7 +202,7 @@
   p->expect_global_decl();
   ASSERT_FALSE(p->has_error()) << p->error();
 
-  auto m = p->module();
+  auto& m = p->get_module();
   ASSERT_EQ(m.constructed_types().size(), 1u);
 
   auto* t = m.constructed_types()[0];
diff --git a/src/reader/wgsl/parser_impl_param_list_test.cc b/src/reader/wgsl/parser_impl_param_list_test.cc
index b1b2bf5..2aa8737 100644
--- a/src/reader/wgsl/parser_impl_param_list_test.cc
+++ b/src/reader/wgsl/parser_impl_param_list_test.cc
@@ -28,9 +28,11 @@
 namespace {
 
 TEST_F(ParserImplTest, ParamList_Single) {
-  auto* i32 = tm()->Get(std::make_unique<ast::type::I32Type>());
-
   auto p = parser("a : i32");
+
+  auto& mod = p->get_module();
+  auto* i32 = mod.type_mgr().Get(std::make_unique<ast::type::I32Type>());
+
   auto e = p->expect_param_list();
   ASSERT_FALSE(p->has_error()) << p->error();
   ASSERT_FALSE(e.errored);
@@ -47,11 +49,14 @@
 }
 
 TEST_F(ParserImplTest, ParamList_Multiple) {
-  auto* i32 = tm()->Get(std::make_unique<ast::type::I32Type>());
-  auto* f32 = tm()->Get(std::make_unique<ast::type::F32Type>());
-  auto* vec2 = tm()->Get(std::make_unique<ast::type::VectorType>(f32, 2));
-
   auto p = parser("a : i32, b: f32, c: vec2<f32>");
+
+  auto& mod = p->get_module();
+  auto* i32 = mod.type_mgr().Get(std::make_unique<ast::type::I32Type>());
+  auto* f32 = mod.type_mgr().Get(std::make_unique<ast::type::F32Type>());
+  auto* vec2 =
+      mod.type_mgr().Get(std::make_unique<ast::type::VectorType>(f32, 2));
+
   auto e = p->expect_param_list();
   ASSERT_FALSE(p->has_error()) << p->error();
   ASSERT_FALSE(e.errored);
diff --git a/src/reader/wgsl/parser_impl_primary_expression_test.cc b/src/reader/wgsl/parser_impl_primary_expression_test.cc
index 1a60ddc..de6d4d2 100644
--- a/src/reader/wgsl/parser_impl_primary_expression_test.cc
+++ b/src/reader/wgsl/parser_impl_primary_expression_test.cc
@@ -190,9 +190,11 @@
 }
 
 TEST_F(ParserImplTest, PrimaryExpression_Cast) {
-  auto* f32_type = tm()->Get(std::make_unique<ast::type::F32Type>());
-
   auto p = parser("f32(1)");
+
+  auto& mod = p->get_module();
+  auto* f32 = mod.type_mgr().Get(std::make_unique<ast::type::F32Type>());
+
   auto e = p->primary_expression();
   EXPECT_TRUE(e.matched);
   EXPECT_FALSE(e.errored);
@@ -202,7 +204,7 @@
   ASSERT_TRUE(e->AsConstructor()->IsTypeConstructor());
 
   auto* c = e->AsConstructor()->AsTypeConstructor();
-  ASSERT_EQ(c->type(), f32_type);
+  ASSERT_EQ(c->type(), f32);
   ASSERT_EQ(c->values().size(), 1u);
 
   ASSERT_TRUE(c->values()[0]->IsConstructor());
@@ -210,9 +212,11 @@
 }
 
 TEST_F(ParserImplTest, PrimaryExpression_Bitcast) {
-  auto* f32_type = tm()->Get(std::make_unique<ast::type::F32Type>());
-
   auto p = parser("bitcast<f32>(1)");
+
+  auto& mod = p->get_module();
+  auto* f32 = mod.type_mgr().Get(std::make_unique<ast::type::F32Type>());
+
   auto e = p->primary_expression();
   EXPECT_TRUE(e.matched);
   EXPECT_FALSE(e.errored);
@@ -221,7 +225,7 @@
   ASSERT_TRUE(e->IsBitcast());
 
   auto* c = e->AsBitcast();
-  ASSERT_EQ(c->type(), f32_type);
+  ASSERT_EQ(c->type(), f32);
 
   ASSERT_TRUE(c->expr()->IsConstructor());
   ASSERT_TRUE(c->expr()->AsConstructor()->IsScalarConstructor());
diff --git a/src/reader/wgsl/parser_impl_struct_body_decl_test.cc b/src/reader/wgsl/parser_impl_struct_body_decl_test.cc
index 49a2437..0d44e4c 100644
--- a/src/reader/wgsl/parser_impl_struct_body_decl_test.cc
+++ b/src/reader/wgsl/parser_impl_struct_body_decl_test.cc
@@ -23,9 +23,11 @@
 namespace {
 
 TEST_F(ParserImplTest, StructBodyDecl_Parses) {
-  auto* i32 = tm()->Get(std::make_unique<ast::type::I32Type>());
-
   auto p = parser("{a : i32;}");
+
+  auto& mod = p->get_module();
+  auto* i32 = mod.type_mgr().Get(std::make_unique<ast::type::I32Type>());
+
   auto m = p->expect_struct_body_decl();
   ASSERT_FALSE(p->has_error());
   ASSERT_FALSE(m.errored);
diff --git a/src/reader/wgsl/parser_impl_struct_member_test.cc b/src/reader/wgsl/parser_impl_struct_member_test.cc
index b46b456..45d733f 100644
--- a/src/reader/wgsl/parser_impl_struct_member_test.cc
+++ b/src/reader/wgsl/parser_impl_struct_member_test.cc
@@ -24,9 +24,11 @@
 namespace {
 
 TEST_F(ParserImplTest, StructMember_Parses) {
-  auto* i32 = tm()->Get(std::make_unique<ast::type::I32Type>());
-
   auto p = parser("a : i32;");
+
+  auto& mod = p->get_module();
+  auto* i32 = mod.type_mgr().Get(std::make_unique<ast::type::I32Type>());
+
   auto decos = p->decoration_list();
   EXPECT_FALSE(decos.errored);
   EXPECT_FALSE(decos.matched);
@@ -48,9 +50,11 @@
 }
 
 TEST_F(ParserImplTest, StructMember_ParsesWithDecoration) {
-  auto* i32 = tm()->Get(std::make_unique<ast::type::I32Type>());
-
   auto p = parser("[[offset(2)]] a : i32;");
+
+  auto& mod = p->get_module();
+  auto* i32 = mod.type_mgr().Get(std::make_unique<ast::type::I32Type>());
+
   auto decos = p->decoration_list();
   EXPECT_FALSE(decos.errored);
   EXPECT_TRUE(decos.matched);
@@ -74,10 +78,12 @@
 }
 
 TEST_F(ParserImplTest, StructMember_ParsesWithMultipleDecorations) {
-  auto* i32 = tm()->Get(std::make_unique<ast::type::I32Type>());
-
   auto p = parser(R"([[offset(2)]]
 [[offset(4)]] a : i32;)");
+
+  auto& mod = p->get_module();
+  auto* i32 = mod.type_mgr().Get(std::make_unique<ast::type::I32Type>());
+
   auto decos = p->decoration_list();
   EXPECT_FALSE(decos.errored);
   EXPECT_TRUE(decos.matched);
diff --git a/src/reader/wgsl/parser_impl_test.cc b/src/reader/wgsl/parser_impl_test.cc
index 829b203..475e3c2 100644
--- a/src/reader/wgsl/parser_impl_test.cc
+++ b/src/reader/wgsl/parser_impl_test.cc
@@ -39,7 +39,7 @@
 )");
   ASSERT_TRUE(p->Parse()) << p->error();
 
-  auto m = p->module();
+  auto& m = p->get_module();
   ASSERT_EQ(1u, m.functions().size());
   ASSERT_EQ(1u, m.global_variables().size());
 }
diff --git a/src/reader/wgsl/parser_impl_test_helper.h b/src/reader/wgsl/parser_impl_test_helper.h
index db44780..73d2900 100644
--- a/src/reader/wgsl/parser_impl_test_helper.h
+++ b/src/reader/wgsl/parser_impl_test_helper.h
@@ -45,9 +45,6 @@
     return impl;
   }
 
-  /// @returns the type manager
-  ast::TypeManager* tm() { return &(ctx_.type_mgr()); }
-
  private:
   std::vector<std::unique_ptr<Source::File>> files_;
   Context ctx_;
@@ -71,9 +68,6 @@
     return impl;
   }
 
-  /// @returns the type manager
-  ast::TypeManager* tm() { return &(ctx_.type_mgr()); }
-
  private:
   std::vector<std::unique_ptr<Source::File>> files_;
   Context ctx_;
diff --git a/src/reader/wgsl/parser_impl_type_alias_test.cc b/src/reader/wgsl/parser_impl_type_alias_test.cc
index 000c115..0431ccf 100644
--- a/src/reader/wgsl/parser_impl_type_alias_test.cc
+++ b/src/reader/wgsl/parser_impl_type_alias_test.cc
@@ -26,9 +26,11 @@
 namespace {
 
 TEST_F(ParserImplTest, TypeDecl_ParsesType) {
-  auto* i32 = tm()->Get(std::make_unique<ast::type::I32Type>());
-
   auto p = parser("type a = i32");
+
+  auto& mod = p->get_module();
+  auto* i32 = mod.type_mgr().Get(std::make_unique<ast::type::I32Type>());
+
   auto t = p->type_alias();
   EXPECT_FALSE(p->has_error());
   EXPECT_FALSE(t.errored);
diff --git a/src/reader/wgsl/parser_impl_type_decl_test.cc b/src/reader/wgsl/parser_impl_type_decl_test.cc
index 32895dc..b141031 100644
--- a/src/reader/wgsl/parser_impl_type_decl_test.cc
+++ b/src/reader/wgsl/parser_impl_type_decl_test.cc
@@ -46,10 +46,11 @@
 TEST_F(ParserImplTest, TypeDecl_Identifier) {
   auto p = parser("A");
 
-  auto* int_type = tm()->Get(std::make_unique<ast::type::I32Type>());
-  // Pre-register to make sure that it's the same type.
+  auto& mod = p->get_module();
+
+  auto* int_type = mod.type_mgr().Get(std::make_unique<ast::type::I32Type>());
   auto* alias_type =
-      tm()->Get(std::make_unique<ast::type::AliasType>("A", int_type));
+      mod.type_mgr().Get(std::make_unique<ast::type::AliasType>("A", int_type));
 
   p->register_constructed("A", alias_type);
 
@@ -79,7 +80,8 @@
 TEST_F(ParserImplTest, TypeDecl_Bool) {
   auto p = parser("bool");
 
-  auto* bool_type = tm()->Get(std::make_unique<ast::type::BoolType>());
+  auto& mod = p->get_module();
+  auto* bool_type = mod.type_mgr().Get(std::make_unique<ast::type::BoolType>());
 
   auto t = p->type_decl();
   EXPECT_TRUE(t.matched);
@@ -92,7 +94,8 @@
 TEST_F(ParserImplTest, TypeDecl_F32) {
   auto p = parser("f32");
 
-  auto* float_type = tm()->Get(std::make_unique<ast::type::F32Type>());
+  auto& mod = p->get_module();
+  auto* float_type = mod.type_mgr().Get(std::make_unique<ast::type::F32Type>());
 
   auto t = p->type_decl();
   EXPECT_TRUE(t.matched);
@@ -105,7 +108,8 @@
 TEST_F(ParserImplTest, TypeDecl_I32) {
   auto p = parser("i32");
 
-  auto* int_type = tm()->Get(std::make_unique<ast::type::I32Type>());
+  auto& mod = p->get_module();
+  auto* int_type = mod.type_mgr().Get(std::make_unique<ast::type::I32Type>());
 
   auto t = p->type_decl();
   EXPECT_TRUE(t.matched);
@@ -118,7 +122,8 @@
 TEST_F(ParserImplTest, TypeDecl_U32) {
   auto p = parser("u32");
 
-  auto* uint_type = tm()->Get(std::make_unique<ast::type::U32Type>());
+  auto& mod = p->get_module();
+  auto* uint_type = mod.type_mgr().Get(std::make_unique<ast::type::U32Type>());
 
   auto t = p->type_decl();
   EXPECT_TRUE(t.matched);
@@ -734,7 +739,8 @@
 TEST_F(ParserImplTest, TypeDecl_Sampler) {
   auto p = parser("sampler");
 
-  auto* type = tm()->Get(std::make_unique<ast::type::SamplerType>(
+  auto& mod = p->get_module();
+  auto* type = mod.type_mgr().Get(std::make_unique<ast::type::SamplerType>(
       ast::type::SamplerKind::kSampler));
 
   auto t = p->type_decl();
@@ -749,9 +755,11 @@
 TEST_F(ParserImplTest, TypeDecl_Texture_Old) {
   auto p = parser("texture_sampled_cube<f32>");
 
+  auto& mod = p->get_module();
   ast::type::F32Type f32;
-  auto* type = tm()->Get(std::make_unique<ast::type::SampledTextureType>(
-      ast::type::TextureDimension::kCube, &f32));
+  auto* type =
+      mod.type_mgr().Get(std::make_unique<ast::type::SampledTextureType>(
+          ast::type::TextureDimension::kCube, &f32));
 
   auto t = p->type_decl();
   EXPECT_TRUE(t.matched);
@@ -767,8 +775,10 @@
   auto p = parser("texture_cube<f32>");
 
   ast::type::F32Type f32;
-  auto* type = tm()->Get(std::make_unique<ast::type::SampledTextureType>(
-      ast::type::TextureDimension::kCube, &f32));
+  auto& mod = p->get_module();
+  auto* type =
+      mod.type_mgr().Get(std::make_unique<ast::type::SampledTextureType>(
+          ast::type::TextureDimension::kCube, &f32));
 
   auto t = p->type_decl();
   EXPECT_TRUE(t.matched);
diff --git a/src/transform/bound_array_accessors_transform.cc b/src/transform/bound_array_accessors_transform.cc
index f305c5e..b4b404e 100644
--- a/src/transform/bound_array_accessors_transform.cc
+++ b/src/transform/bound_array_accessors_transform.cc
@@ -237,7 +237,7 @@
       return false;
     }
   } else {
-    auto* u32 = ctx_->type_mgr().Get(std::make_unique<ast::type::U32Type>());
+    auto* u32 = mod_->type_mgr().Get(std::make_unique<ast::type::U32Type>());
 
     ast::ExpressionList cast_expr;
     cast_expr.push_back(expr->idx_expr());
diff --git a/src/transform/vertex_pulling_transform.cc b/src/transform/vertex_pulling_transform.cc
index d099d8f..c941df9 100644
--- a/src/transform/vertex_pulling_transform.cc
+++ b/src/transform/vertex_pulling_transform.cc
@@ -222,7 +222,7 @@
   ary_decos.push_back(create<ast::StrideDecoration>(4u, Source{}));
   internal_array->set_decorations(std::move(ary_decos));
 
-  auto* internal_array_type = ctx_->type_mgr().Get(std::move(internal_array));
+  auto* internal_array_type = mod_->type_mgr().Get(std::move(internal_array));
 
   // Creating the struct type
   ast::StructMemberList members;
@@ -236,7 +236,7 @@
   decos.push_back(create<ast::StructBlockDecoration>(Source{}));
 
   auto* struct_type =
-      ctx_->type_mgr().Get(std::make_unique<ast::type::StructType>(
+      mod_->type_mgr().Get(std::make_unique<ast::type::StructType>(
           kStructName,
           create<ast::Struct>(std::move(decos), std::move(members))));
 
@@ -411,21 +411,21 @@
   }
 
   return create<ast::TypeConstructorExpression>(
-      ctx_->type_mgr().Get(
+      mod_->type_mgr().Get(
           std::make_unique<ast::type::VectorType>(base_type, count)),
       std::move(expr_list));
 }
 
 ast::type::Type* VertexPullingTransform::GetU32Type() {
-  return ctx_->type_mgr().Get(std::make_unique<ast::type::U32Type>());
+  return mod_->type_mgr().Get(std::make_unique<ast::type::U32Type>());
 }
 
 ast::type::Type* VertexPullingTransform::GetI32Type() {
-  return ctx_->type_mgr().Get(std::make_unique<ast::type::I32Type>());
+  return mod_->type_mgr().Get(std::make_unique<ast::type::I32Type>());
 }
 
 ast::type::Type* VertexPullingTransform::GetF32Type() {
-  return ctx_->type_mgr().Get(std::make_unique<ast::type::F32Type>());
+  return mod_->type_mgr().Get(std::make_unique<ast::type::F32Type>());
 }
 
 VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor() = default;
diff --git a/src/transform/vertex_pulling_transform_test.cc b/src/transform/vertex_pulling_transform_test.cc
index 91b632b..2b31f1f 100644
--- a/src/transform/vertex_pulling_transform_test.cc
+++ b/src/transform/vertex_pulling_transform_test.cc
@@ -48,7 +48,7 @@
   void InitBasicModule() {
     auto* func = create<ast::Function>(
         "main", ast::VariableList{},
-        ctx_.type_mgr().Get(std::make_unique<ast::type::VoidType>()),
+        mod_->type_mgr().Get(std::make_unique<ast::type::VoidType>()),
         create<ast::BlockStatement>());
     func->add_decoration(
         create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}));
@@ -81,7 +81,6 @@
     mod_->AddGlobalVariable(var);
   }
 
-  Context* ctx() { return &ctx_; }
   ast::Module* mod() { return mod_.get(); }
   Manager* manager() { return manager_.get(); }
   VertexPullingTransform* transform() { return transform_; }
@@ -128,7 +127,7 @@
 TEST_F(VertexPullingTransformTest, Error_EntryPointWrongStage) {
   auto* func = create<ast::Function>(
       "main", ast::VariableList{},
-      ctx()->type_mgr().Get(std::make_unique<ast::type::VoidType>()),
+      mod()->type_mgr().Get(std::make_unique<ast::type::VoidType>()),
       create<ast::BlockStatement>());
   func->add_decoration(
       create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}));
diff --git a/src/type_determiner.cc b/src/type_determiner.cc
index 6208fef..6f95f2f 100644
--- a/src/type_determiner.cc
+++ b/src/type_determiner.cc
@@ -83,7 +83,7 @@
 }
 
 bool TypeDeterminer::Determine() {
-  for (auto& iter : ctx_.type_mgr().types()) {
+  for (auto& iter : mod_->type_mgr().types()) {
     auto& type = iter.second;
     if (!type->IsTexture() || !type->AsTexture()->IsStorage()) {
       continue;
@@ -339,7 +339,7 @@
     ret = parent_type->AsVector()->type();
   } else if (parent_type->IsMatrix()) {
     auto* m = parent_type->AsMatrix();
-    ret = ctx_.type_mgr().Get(
+    ret = mod_->type_mgr().Get(
         std::make_unique<ast::type::VectorType>(m->type(), m->rows()));
   } else {
     set_error(expr->source(), "invalid parent type (" +
@@ -350,14 +350,14 @@
 
   // If we're extracting from a pointer, we return a pointer.
   if (res->IsPointer()) {
-    ret = ctx_.type_mgr().Get(std::make_unique<ast::type::PointerType>(
+    ret = mod_->type_mgr().Get(std::make_unique<ast::type::PointerType>(
         ret, res->AsPointer()->storage_class()));
   } else if (parent_type->IsArray() &&
              !parent_type->AsArray()->type()->is_scalar()) {
     // If we extract a non-scalar from an array then we also get a pointer. We
     // will generate a Function storage class variable to store this
     // into.
-    ret = ctx_.type_mgr().Get(std::make_unique<ast::type::PointerType>(
+    ret = mod_->type_mgr().Get(std::make_unique<ast::type::PointerType>(
         ret, ast::StorageClass::kFunction));
   }
   expr->set_result_type(ret);
@@ -523,12 +523,12 @@
   if (ident->intrinsic() == ast::Intrinsic::kAny ||
       ident->intrinsic() == ast::Intrinsic::kAll) {
     expr->func()->set_result_type(
-        ctx_.type_mgr().Get(std::make_unique<ast::type::BoolType>()));
+        mod_->type_mgr().Get(std::make_unique<ast::type::BoolType>()));
     return true;
   }
   if (ident->intrinsic() == ast::Intrinsic::kArrayLength) {
     expr->func()->set_result_type(
-        ctx_.type_mgr().Get(std::make_unique<ast::type::U32Type>()));
+        mod_->type_mgr().Get(std::make_unique<ast::type::U32Type>()));
     return true;
   }
   if (ast::intrinsic::IsFloatClassificationIntrinsic(ident->intrinsic())) {
@@ -539,12 +539,12 @@
     }
 
     auto* bool_type =
-        ctx_.type_mgr().Get(std::make_unique<ast::type::BoolType>());
+        mod_->type_mgr().Get(std::make_unique<ast::type::BoolType>());
 
     auto* param_type = expr->params()[0]->result_type()->UnwrapPtrIfNeeded();
     if (param_type->IsVector()) {
       expr->func()->set_result_type(
-          ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
+          mod_->type_mgr().Get(std::make_unique<ast::type::VectorType>(
               bool_type, param_type->AsVector()->size())));
     } else {
       expr->func()->set_result_type(bool_type);
@@ -667,7 +667,7 @@
 
     if (texture->IsDepth()) {
       expr->func()->set_result_type(
-          ctx_.type_mgr().Get(std::make_unique<ast::type::F32Type>()));
+          mod_->type_mgr().Get(std::make_unique<ast::type::F32Type>()));
       return true;
     }
 
@@ -689,12 +689,12 @@
       return false;
     }
     expr->func()->set_result_type(
-        ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>(type, 4)));
+        mod_->type_mgr().Get(std::make_unique<ast::type::VectorType>(type, 4)));
     return true;
   }
   if (ident->intrinsic() == ast::Intrinsic::kDot) {
     expr->func()->set_result_type(
-        ctx_.type_mgr().Get(std::make_unique<ast::type::F32Type>()));
+        mod_->type_mgr().Get(std::make_unique<ast::type::F32Type>()));
     return true;
   }
   if (ident->intrinsic() == ast::Intrinsic::kOuterProduct) {
@@ -712,8 +712,8 @@
     }
 
     expr->func()->set_result_type(
-        ctx_.type_mgr().Get(std::make_unique<ast::type::MatrixType>(
-            ctx_.type_mgr().Get(std::make_unique<ast::type::F32Type>()),
+        mod_->type_mgr().Get(std::make_unique<ast::type::MatrixType>(
+            mod_->type_mgr().Get(std::make_unique<ast::type::F32Type>()),
             param0_type->AsVector()->size(), param1_type->AsVector()->size())));
     return true;
   }
@@ -862,7 +862,7 @@
       expr->set_result_type(var->type());
     } else {
       expr->set_result_type(
-          ctx_.type_mgr().Get(std::make_unique<ast::type::PointerType>(
+          mod_->type_mgr().Get(std::make_unique<ast::type::PointerType>(
               var->type(), var->storage_class())));
     }
 
@@ -1055,7 +1055,7 @@
 
     // If we're extracting from a pointer, we return a pointer.
     if (res->IsPointer()) {
-      ret = ctx_.type_mgr().Get(std::make_unique<ast::type::PointerType>(
+      ret = mod_->type_mgr().Get(std::make_unique<ast::type::PointerType>(
           ret, res->AsPointer()->storage_class()));
     }
   } else if (data_type->IsVector()) {
@@ -1067,14 +1067,14 @@
       ret = vec->type();
       // If we're extracting from a pointer, we return a pointer.
       if (res->IsPointer()) {
-        ret = ctx_.type_mgr().Get(std::make_unique<ast::type::PointerType>(
+        ret = mod_->type_mgr().Get(std::make_unique<ast::type::PointerType>(
             ret, res->AsPointer()->storage_class()));
       }
     } else {
       // The vector will have a number of components equal to the length of the
       // swizzle. This assumes the validator will check that the swizzle
       // is correct.
-      ret = ctx_.type_mgr().Get(
+      ret = mod_->type_mgr().Get(
           std::make_unique<ast::type::VectorType>(vec->type(), size));
     }
   } else {
@@ -1107,11 +1107,11 @@
       expr->IsNotEqual() || expr->IsLessThan() || expr->IsGreaterThan() ||
       expr->IsLessThanEqual() || expr->IsGreaterThanEqual()) {
     auto* bool_type =
-        ctx_.type_mgr().Get(std::make_unique<ast::type::BoolType>());
+        mod_->type_mgr().Get(std::make_unique<ast::type::BoolType>());
     auto* param_type = expr->lhs()->result_type()->UnwrapPtrIfNeeded();
     if (param_type->IsVector()) {
       expr->set_result_type(
-          ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
+          mod_->type_mgr().Get(std::make_unique<ast::type::VectorType>(
               bool_type, param_type->AsVector()->size())));
     } else {
       expr->set_result_type(bool_type);
@@ -1126,18 +1126,18 @@
     // checks having been done.
     if (lhs_type->IsMatrix() && rhs_type->IsMatrix()) {
       expr->set_result_type(
-          ctx_.type_mgr().Get(std::make_unique<ast::type::MatrixType>(
+          mod_->type_mgr().Get(std::make_unique<ast::type::MatrixType>(
               lhs_type->AsMatrix()->type(), lhs_type->AsMatrix()->rows(),
               rhs_type->AsMatrix()->columns())));
 
     } else if (lhs_type->IsMatrix() && rhs_type->IsVector()) {
       auto* mat = lhs_type->AsMatrix();
-      expr->set_result_type(ctx_.type_mgr().Get(
+      expr->set_result_type(mod_->type_mgr().Get(
           std::make_unique<ast::type::VectorType>(mat->type(), mat->rows())));
     } else if (lhs_type->IsVector() && rhs_type->IsMatrix()) {
       auto* mat = rhs_type->AsMatrix();
       expr->set_result_type(
-          ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
+          mod_->type_mgr().Get(std::make_unique<ast::type::VectorType>(
               mat->type(), mat->columns())));
     } else if (lhs_type->IsMatrix()) {
       // matrix * scalar
@@ -1198,7 +1198,7 @@
     case ast::type::ImageFormat::kRgba16Uint:
     case ast::type::ImageFormat::kRgba32Uint: {
       tex->set_type(
-          ctx_.type_mgr().Get(std::make_unique<ast::type::U32Type>()));
+          mod_->type_mgr().Get(std::make_unique<ast::type::U32Type>()));
       return true;
     }
 
@@ -1215,7 +1215,7 @@
     case ast::type::ImageFormat::kRgba16Sint:
     case ast::type::ImageFormat::kRgba32Sint: {
       tex->set_type(
-          ctx_.type_mgr().Get(std::make_unique<ast::type::I32Type>()));
+          mod_->type_mgr().Get(std::make_unique<ast::type::I32Type>()));
       return true;
     }
 
@@ -1227,7 +1227,7 @@
     case ast::type::ImageFormat::kRgba16Float:
     case ast::type::ImageFormat::kRgba32Float: {
       tex->set_type(
-          ctx_.type_mgr().Get(std::make_unique<ast::type::F32Type>()));
+          mod_->type_mgr().Get(std::make_unique<ast::type::F32Type>()));
       return true;
     }
 
diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc
index 66094a2..81ec3ef 100644
--- a/src/type_determiner_test.cc
+++ b/src/type_determiner_test.cc
@@ -1787,7 +1787,7 @@
   auto coords_type = get_coords_type(dim, &i32);
 
   ast::type::Type* texture_type =
-      ctx->type_mgr().Get(std::make_unique<ast::type::StorageTextureType>(
+      mod->type_mgr().Get(std::make_unique<ast::type::StorageTextureType>(
           dim, ast::AccessControl::kReadOnly, format));
 
   ast::ExpressionList call_params;
@@ -4549,13 +4549,13 @@
   switch (param.texture_kind) {
     case ast::intrinsic::test::TextureKind::kRegular:
       Var("texture", ast::StorageClass::kNone,
-          ctx->type_mgr().Get<ast::type::SampledTextureType>(
+          mod->type_mgr().Get<ast::type::SampledTextureType>(
               param.texture_dimension, datatype));
       break;
 
     case ast::intrinsic::test::TextureKind::kDepth:
       Var("texture", ast::StorageClass::kNone,
-          ctx->type_mgr().Get<ast::type::DepthTextureType>(
+          mod->type_mgr().Get<ast::type::DepthTextureType>(
               param.texture_dimension));
       break;
   }
diff --git a/src/writer/hlsl/generator_impl_intrinsic_texture_test.cc b/src/writer/hlsl/generator_impl_intrinsic_texture_test.cc
index bc22da4..94dc341 100644
--- a/src/writer/hlsl/generator_impl_intrinsic_texture_test.cc
+++ b/src/writer/hlsl/generator_impl_intrinsic_texture_test.cc
@@ -183,13 +183,13 @@
   switch (param.texture_kind) {
     case ast::intrinsic::test::TextureKind::kRegular:
       Var("texture", ast::StorageClass::kNone,
-          ctx->type_mgr().Get<ast::type::SampledTextureType>(
+          mod->type_mgr().Get<ast::type::SampledTextureType>(
               param.texture_dimension, datatype));
       break;
 
     case ast::intrinsic::test::TextureKind::kDepth:
       Var("texture", ast::StorageClass::kNone,
-          ctx->type_mgr().Get<ast::type::DepthTextureType>(
+          mod->type_mgr().Get<ast::type::DepthTextureType>(
               param.texture_dimension));
       break;
   }
diff --git a/src/writer/spirv/builder_intrinsic_texture_test.cc b/src/writer/spirv/builder_intrinsic_texture_test.cc
index a677191..d11c80e 100644
--- a/src/writer/spirv/builder_intrinsic_texture_test.cc
+++ b/src/writer/spirv/builder_intrinsic_texture_test.cc
@@ -1618,13 +1618,13 @@
   switch (param.texture_kind) {
     case ast::intrinsic::test::TextureKind::kRegular:
       tex = Var("texture", ast::StorageClass::kNone,
-                ctx->type_mgr().Get<ast::type::SampledTextureType>(
+                mod->type_mgr().Get<ast::type::SampledTextureType>(
                     param.texture_dimension, datatype));
       break;
 
     case ast::intrinsic::test::TextureKind::kDepth:
       tex = Var("texture", ast::StorageClass::kNone,
-                ctx->type_mgr().Get<ast::type::DepthTextureType>(
+                mod->type_mgr().Get<ast::type::DepthTextureType>(
                     param.texture_dimension));
       break;
   }