Add create<T>() method to Module for types

Migrate all uses to use this and the new `unique_type<T>()` and `types()` methods.

Remove the `type_mgr()` accessor. `TypeManager` is now an implementation detail of the module, allowing us to unify the allocation of types and nodes (if we so wish).

Fixes: tint:337
Bug: tint:307
Change-Id: I233fa9dc73d60515dd721f02ea7ba089ef7d374f
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/33667
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/ast/builder.cc b/src/ast/builder.cc
index 871f4ad..0b27fd9 100644
--- a/src/ast/builder.cc
+++ b/src/ast/builder.cc
@@ -17,16 +17,16 @@
 namespace tint {
 namespace ast {
 
-TypesBuilder::TypesBuilder(TypeManager* tm)
-    : bool_(tm->Get<ast::type::BoolType>()),
-      f32(tm->Get<ast::type::F32Type>()),
-      i32(tm->Get<ast::type::I32Type>()),
-      u32(tm->Get<ast::type::U32Type>()),
-      void_(tm->Get<ast::type::VoidType>()),
-      tm_(tm) {}
+TypesBuilder::TypesBuilder(Module* mod)
+    : bool_(mod->create<ast::type::BoolType>()),
+      f32(mod->create<ast::type::F32Type>()),
+      i32(mod->create<ast::type::I32Type>()),
+      u32(mod->create<ast::type::U32Type>()),
+      void_(mod->create<ast::type::VoidType>()),
+      mod_(mod) {}
 
 Builder::Builder(tint::Context* c, tint::ast::Module* m)
-    : ctx(c), mod(m), ty(&m->type_mgr()) {}
+    : ctx(c), mod(m), ty(m) {}
 Builder::~Builder() = default;
 
 ast::Variable* Builder::Var(const std::string& name,
diff --git a/src/ast/builder.h b/src/ast/builder.h
index dcee4a2..1080ead 100644
--- a/src/ast/builder.h
+++ b/src/ast/builder.h
@@ -47,8 +47,8 @@
 class TypesBuilder {
  public:
   /// Constructor
-  /// @param tm the type manager
-  explicit TypesBuilder(TypeManager* tm);
+  /// @param mod the module
+  explicit TypesBuilder(Module* mod);
 
   /// A boolean type
   ast::type::BoolType* const bool_;
@@ -70,80 +70,80 @@
   /// @return the tint AST type for a 2-element vector of the C type `T`.
   template <typename T>
   ast::type::VectorType* vec2() const {
-    return tm_->Get<ast::type::VectorType>(Of<T>(), 2);
+    return mod_->create<ast::type::VectorType>(Of<T>(), 2);
   }
 
   /// @return the tint AST type for a 3-element vector of the C type `T`.
   template <typename T>
   ast::type::VectorType* vec3() const {
-    return tm_->Get<ast::type::VectorType>(Of<T>(), 3);
+    return mod_->create<ast::type::VectorType>(Of<T>(), 3);
   }
 
   /// @return the tint AST type for a 4-element vector of the C type `T`.
   template <typename T>
   ast::type::Type* vec4() const {
-    return tm_->Get<ast::type::VectorType>(Of<T>(), 4);
+    return mod_->create<ast::type::VectorType>(Of<T>(), 4);
   }
 
   /// @return the tint AST type for a 2x3 matrix of the C type `T`.
   template <typename T>
   ast::type::MatrixType* mat2x2() const {
-    return tm_->Get<ast::type::MatrixType>(Of<T>(), 2, 2);
+    return mod_->create<ast::type::MatrixType>(Of<T>(), 2, 2);
   }
 
   /// @return the tint AST type for a 2x3 matrix of the C type `T`.
   template <typename T>
   ast::type::MatrixType* mat2x3() const {
-    return tm_->Get<ast::type::MatrixType>(Of<T>(), 3, 2);
+    return mod_->create<ast::type::MatrixType>(Of<T>(), 3, 2);
   }
 
   /// @return the tint AST type for a 2x4 matrix of the C type `T`.
   template <typename T>
   ast::type::MatrixType* mat2x4() const {
-    return tm_->Get<ast::type::MatrixType>(Of<T>(), 4, 2);
+    return mod_->create<ast::type::MatrixType>(Of<T>(), 4, 2);
   }
 
   /// @return the tint AST type for a 3x2 matrix of the C type `T`.
   template <typename T>
   ast::type::MatrixType* mat3x2() const {
-    return tm_->Get<ast::type::MatrixType>(Of<T>(), 2, 3);
+    return mod_->create<ast::type::MatrixType>(Of<T>(), 2, 3);
   }
 
   /// @return the tint AST type for a 3x3 matrix of the C type `T`.
   template <typename T>
   ast::type::MatrixType* mat3x3() const {
-    return tm_->Get<ast::type::MatrixType>(Of<T>(), 3, 3);
+    return mod_->create<ast::type::MatrixType>(Of<T>(), 3, 3);
   }
 
   /// @return the tint AST type for a 3x4 matrix of the C type `T`.
   template <typename T>
   ast::type::MatrixType* mat3x4() const {
-    return tm_->Get<ast::type::MatrixType>(Of<T>(), 4, 3);
+    return mod_->create<ast::type::MatrixType>(Of<T>(), 4, 3);
   }
 
   /// @return the tint AST type for a 4x2 matrix of the C type `T`.
   template <typename T>
   ast::type::MatrixType* mat4x2() const {
-    return tm_->Get<ast::type::MatrixType>(Of<T>(), 2, 4);
+    return mod_->create<ast::type::MatrixType>(Of<T>(), 2, 4);
   }
 
   /// @return the tint AST type for a 4x3 matrix of the C type `T`.
   template <typename T>
   ast::type::MatrixType* mat4x3() const {
-    return tm_->Get<ast::type::MatrixType>(Of<T>(), 3, 4);
+    return mod_->create<ast::type::MatrixType>(Of<T>(), 3, 4);
   }
 
   /// @return the tint AST type for a 4x4 matrix of the C type `T`.
   template <typename T>
   ast::type::MatrixType* mat4x4() const {
-    return tm_->Get<ast::type::MatrixType>(Of<T>(), 4, 4);
+    return mod_->create<ast::type::MatrixType>(Of<T>(), 4, 4);
   }
 
   /// @param subtype the array element type
   /// @param n the array size. 0 represents unbounded
   /// @return the tint AST type for a array of size `n` of type `T`
   ast::type::ArrayType* array(ast::type::Type* subtype, uint32_t n) const {
-    return tm_->Get<ast::type::ArrayType>(subtype, n);
+    return mod_->create<ast::type::ArrayType>(subtype, n);
   }
 
   /// @return the tint AST type for an array of size `N` of type `T`
@@ -161,7 +161,7 @@
   template <typename T>
   struct CToAST {};
 
-  TypeManager* const tm_;
+  Module* const mod_;
 };
 
 /// Helper for building common AST constructs.
diff --git a/src/ast/module.h b/src/ast/module.h
index db8f353..0b2aa60 100644
--- a/src/ast/module.h
+++ b/src/ast/module.h
@@ -17,6 +17,8 @@
 
 #include <memory>
 #include <string>
+#include <type_traits>
+#include <unordered_map>
 #include <utility>
 #include <vector>
 
@@ -30,6 +32,10 @@
 
 /// Represents all the source in a given program.
 class Module {
+  template <typename T, typename BASE>
+  using EnableIfIsType =
+      typename std::enable_if<std::is_base_of<BASE, T>::value, T>::type;
+
  public:
   Module();
   /// Move constructor
@@ -78,15 +84,12 @@
   /// @returns a string representation of the module
   std::string to_str() const;
 
-  /// @returns the Type Manager
-  ast::TypeManager& type_mgr() { return type_mgr_; }
-
   /// Creates a new `ast::Node` owned by the Module. When the Module is
   /// destructed, the `ast::Node` will also be destructed.
   /// @param args the arguments to pass to the type constructor
   /// @returns the node pointer
   template <typename T, typename... ARGS>
-  T* create(ARGS&&... args) {
+  EnableIfIsType<T, ast::Node>* create(ARGS&&... args) {
     static_assert(std::is_base_of<ast::Node, T>::value,
                   "T does not derive from ast::Node");
     auto uptr = std::make_unique<T>(std::forward<ARGS>(args)...);
@@ -95,6 +98,38 @@
     return ptr;
   }
 
+  /// Creates a new `ast::Type` owned by the Module.
+  /// When the Module is destructed, owned Module and the returned
+  /// `ast::Type` will also be destructed.
+  /// Types are unique (de-aliased), and so `create()` for the same `T` and
+  /// arguments will return the same pointer.
+  /// @param args the arguments to pass to the type constructor
+  /// @returns the de-aliased type pointer
+  template <typename T, typename... ARGS>
+  EnableIfIsType<T, ast::type::Type>* create(ARGS&&... args) {
+    static_assert(std::is_base_of<ast::type::Type, T>::value,
+                  "T does not derive from ast::type::Type");
+    return type_mgr_.Get<T>(std::forward<ARGS>(args)...);
+  }
+
+  /// Moves the type `ty` to the Module, returning a pointer to the unique
+  /// (de-aliased) type.
+  /// When the Module is destructed, the returned `ast::Type` will also be
+  /// destructed.
+  /// @param ty the type to add to the module
+  /// @returns the de-aliased type pointer
+  template <typename T>
+  EnableIfIsType<T, ast::type::Type>* unique_type(std::unique_ptr<T> ty) {
+    return static_cast<T*>(type_mgr_.Get(std::move(ty)));
+  }
+
+  /// Returns all the declared types in the module
+  /// @returns the mapping from name string to type.
+  const std::unordered_map<std::string, std::unique_ptr<ast::type::Type>>&
+  types() {
+    return type_mgr_.types();
+  }
+
  private:
   Module(const Module&) = delete;
 
diff --git a/src/ast/type/storage_texture_type_test.cc b/src/ast/type/storage_texture_type_test.cc
index 4447f30..1600298 100644
--- a/src/ast/type/storage_texture_type_test.cc
+++ b/src/ast/type/storage_texture_type_test.cc
@@ -80,9 +80,9 @@
 TEST_F(StorageTextureTypeTest, F32Type) {
   Context ctx;
   ast::Module mod;
-  ast::type::Type* s = mod.type_mgr().Get(std::make_unique<StorageTextureType>(
+  ast::type::Type* s = mod.create<StorageTextureType>(
       TextureDimension::k2dArray, AccessControl::kReadOnly,
-      ImageFormat::kRgba32Float));
+      ImageFormat::kRgba32Float);
   TypeDeterminer td(&ctx, &mod);
 
   ASSERT_TRUE(td.Determine()) << td.error();
@@ -94,9 +94,9 @@
 TEST_F(StorageTextureTypeTest, U32Type) {
   Context ctx;
   ast::Module mod;
-  ast::type::Type* s = mod.type_mgr().Get(std::make_unique<StorageTextureType>(
+  ast::type::Type* s = mod.create<StorageTextureType>(
       TextureDimension::k2dArray, AccessControl::kReadOnly,
-      ImageFormat::kRgba8Unorm));
+      ImageFormat::kRgba8Unorm);
   TypeDeterminer td(&ctx, &mod);
 
   ASSERT_TRUE(td.Determine()) << td.error();
@@ -108,9 +108,9 @@
 TEST_F(StorageTextureTypeTest, I32Type) {
   Context ctx;
   ast::Module mod;
-  ast::type::Type* s = mod.type_mgr().Get(std::make_unique<StorageTextureType>(
+  ast::type::Type* s = mod.create<StorageTextureType>(
       TextureDimension::k2dArray, AccessControl::kReadOnly,
-      ImageFormat::kRgba32Sint));
+      ImageFormat::kRgba32Sint);
   TypeDeterminer td(&ctx, &mod);
 
   ASSERT_TRUE(td.Determine()) << td.error();
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index a30df6f..57192d5 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -3253,8 +3253,8 @@
     const auto* ast_ptr_type = type->AsPointer();
     const auto sc = GetStorageClassForPointerValue(result_id);
     if (ast_ptr_type->storage_class() != sc) {
-      return parser_impl_.get_module().type_mgr().Get(
-          std::make_unique<ast::type::PointerType>(ast_ptr_type->type(), sc));
+      return parser_impl_.get_module().create<ast::type::PointerType>(
+          ast_ptr_type->type(), sc);
     }
   }
   return type;
diff --git a/src/reader/spirv/parser_impl.cc b/src/reader/spirv/parser_impl.cc
index c58a7b3..4c7e867 100644
--- a/src/reader/spirv/parser_impl.cc
+++ b/src/reader/spirv/parser_impl.cc
@@ -196,8 +196,7 @@
     : Reader(ctx),
       spv_binary_(spv_binary),
       fail_stream_(&success_, &errors_),
-      bool_type_(
-          ast_module_.type_mgr().Get(std::make_unique<ast::type::BoolType>())),
+      bool_type_(ast_module_.create<ast::type::BoolType>()),
       namer_(fail_stream_),
       enum_converter_(fail_stream_),
       tools_context_(kInputEnv) {
@@ -286,8 +285,7 @@
 
   switch (spirv_type->kind()) {
     case spvtools::opt::analysis::Type::kVoid:
-      return save(
-          ast_module_.type_mgr().Get(std::make_unique<ast::type::VoidType>()));
+      return save(ast_module_.create<ast::type::VoidType>());
     case spvtools::opt::analysis::Type::kBool:
       return save(bool_type_);
     case spvtools::opt::analysis::Type::kInteger:
@@ -317,8 +315,7 @@
     case spvtools::opt::analysis::Type::kImage:
       // Fake it for sampler and texture types.  These are handled in an
       // entirely different way.
-      return save(
-          ast_module_.type_mgr().Get(std::make_unique<ast::type::VoidType>()));
+      return save(ast_module_.create<ast::type::VoidType>());
     default:
       break;
   }
@@ -651,10 +648,8 @@
 ast::type::Type* ParserImpl::ConvertType(
     const spvtools::opt::analysis::Integer* int_ty) {
   if (int_ty->width() == 32) {
-    auto* signed_ty =
-        ast_module_.type_mgr().Get(std::make_unique<ast::type::I32Type>());
-    auto* unsigned_ty =
-        ast_module_.type_mgr().Get(std::make_unique<ast::type::U32Type>());
+    ast::type::Type* signed_ty = ast_module_.create<ast::type::I32Type>();
+    ast::type::Type* unsigned_ty = ast_module_.create<ast::type::U32Type>();
     signed_type_for_[unsigned_ty] = signed_ty;
     unsigned_type_for_[signed_ty] = unsigned_ty;
     return int_ty->IsSigned() ? signed_ty : unsigned_ty;
@@ -666,7 +661,7 @@
 ast::type::Type* ParserImpl::ConvertType(
     const spvtools::opt::analysis::Float* float_ty) {
   if (float_ty->width() == 32) {
-    return ast_module_.type_mgr().Get(std::make_unique<ast::type::F32Type>());
+    return ast_module_.create<ast::type::F32Type>();
   }
   Fail() << "unhandled float width: " << float_ty->width();
   return nullptr;
@@ -679,19 +674,17 @@
   if (ast_elem_ty == nullptr) {
     return nullptr;
   }
-  auto* this_ty = ast_module_.type_mgr().Get(
-      std::make_unique<ast::type::VectorType>(ast_elem_ty, num_elem));
+  auto* this_ty =
+      ast_module_.create<ast::type::VectorType>(ast_elem_ty, num_elem);
   // Generate the opposite-signedness vector type, if this type is integral.
   if (unsigned_type_for_.count(ast_elem_ty)) {
-    auto* other_ty =
-        ast_module_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
-            unsigned_type_for_[ast_elem_ty], num_elem));
+    auto* other_ty = ast_module_.create<ast::type::VectorType>(
+        unsigned_type_for_[ast_elem_ty], num_elem);
     signed_type_for_[other_ty] = this_ty;
     unsigned_type_for_[this_ty] = other_ty;
   } else if (signed_type_for_.count(ast_elem_ty)) {
-    auto* other_ty =
-        ast_module_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
-            signed_type_for_[ast_elem_ty], num_elem));
+    auto* other_ty = ast_module_.create<ast::type::VectorType>(
+        signed_type_for_[ast_elem_ty], num_elem);
     unsigned_type_for_[other_ty] = this_ty;
     signed_type_for_[this_ty] = other_ty;
   }
@@ -708,8 +701,8 @@
   if (ast_scalar_ty == nullptr) {
     return nullptr;
   }
-  return ast_module_.type_mgr().Get(std::make_unique<ast::type::MatrixType>(
-      ast_scalar_ty, num_rows, num_columns));
+  return ast_module_.create<ast::type::MatrixType>(ast_scalar_ty, num_rows,
+                                                   num_columns);
 }
 
 ast::type::Type* ParserImpl::ConvertType(
@@ -722,7 +715,7 @@
   if (!ApplyArrayDecorations(rtarr_ty, ast_type.get())) {
     return nullptr;
   }
-  return ast_module_.type_mgr().Get(std::move(ast_type));
+  return ast_module_.unique_type(std::move(ast_type));
 }
 
 ast::type::Type* ParserImpl::ConvertType(
@@ -767,7 +760,7 @@
   if (remap_buffer_block_type_.count(elem_type_id)) {
     remap_buffer_block_type_.insert(type_mgr_->GetId(arr_ty));
   }
-  return ast_module_.type_mgr().Get(std::move(ast_type));
+  return ast_module_.unique_type(std::move(ast_type));
 }
 
 bool ParserImpl::ApplyArrayDecorations(
@@ -892,10 +885,9 @@
                                          std::move(ast_members));
 
   namer_.SuggestSanitizedName(type_id, "S");
-  auto ast_struct_type = std::make_unique<ast::type::StructType>(
-      namer_.GetName(type_id), ast_struct);
 
-  auto* result = ast_module_.type_mgr().Get(std::move(ast_struct_type));
+  auto* result = ast_module_.create<ast::type::StructType>(
+      namer_.GetName(type_id), ast_struct);
   id_to_type_[type_id] = result;
   if (num_non_writable_members == members.size()) {
     read_only_struct_types_.insert(result);
@@ -935,8 +927,8 @@
     ast_storage_class = ast::StorageClass::kStorageBuffer;
     remap_buffer_block_type_.insert(type_id);
   }
-  return ast_module_.type_mgr().Get(
-      std::make_unique<ast::type::PointerType>(ast_elem_ty, ast_storage_class));
+  return ast_module_.create<ast::type::PointerType>(ast_elem_ty,
+                                                    ast_storage_class);
 }
 
 bool ParserImpl::RegisterTypes() {
@@ -1065,10 +1057,8 @@
     return;
   }
   const auto name = namer_.GetName(type_id);
-  auto* ast_alias_type = ast_module_.type_mgr()
-                             .Get(std::make_unique<ast::type::AliasType>(
-                                 name, ast_underlying_type))
-                             ->AsAlias();
+  auto* ast_alias_type =
+      ast_module_.create<ast::type::AliasType>(name, ast_underlying_type);
   // Record this new alias as the AST type for this SPIR-V ID.
   id_to_type_[type_id] = ast_alias_type;
   ast_module_.AddConstructedType(ast_alias_type);
@@ -1169,8 +1159,7 @@
     auto access = read_only_struct_types_.count(type)
                       ? ast::AccessControl::kReadOnly
                       : ast::AccessControl::kReadWrite;
-    type = ast_module_.type_mgr().Get(
-        std::make_unique<ast::type::AccessControlType>(access, type));
+    type = ast_module_.create<ast::type::AccessControlType>(access, type);
   }
 
   auto* ast_var = create<ast::Variable>(namer_.Name(id), sc, type);
@@ -1363,9 +1352,8 @@
   if (type->IsMatrix()) {
     const auto* mat_ty = type->AsMatrix();
     // Matrix components are columns
-    auto* column_ty =
-        ast_module_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
-            mat_ty->type(), mat_ty->rows()));
+    auto* column_ty = ast_module_.create<ast::type::VectorType>(mat_ty->type(),
+                                                                mat_ty->rows());
     ast::ExpressionList ast_components;
     for (size_t i = 0; i < mat_ty->columns(); ++i) {
       ast_components.emplace_back(MakeNullValue(column_ty));
@@ -1446,15 +1434,13 @@
   if (other == nullptr) {
     Fail() << "no type provided";
   }
-  auto* i32 =
-      ast_module_.type_mgr().Get(std::make_unique<ast::type::I32Type>());
+  auto* i32 = ast_module_.create<ast::type::I32Type>();
   if (other->IsF32() || other->IsU32() || other->IsI32()) {
     return i32;
   }
   auto* vec_ty = other->AsVector();
   if (vec_ty) {
-    return ast_module_.type_mgr().Get(
-        std::make_unique<ast::type::VectorType>(i32, vec_ty->size()));
+    return ast_module_.create<ast::type::VectorType>(i32, vec_ty->size());
   }
   Fail() << "required numeric scalar or vector, but got " << other->type_name();
   return nullptr;
@@ -1466,15 +1452,13 @@
     Fail() << "no type provided";
     return nullptr;
   }
-  auto* u32 =
-      ast_module_.type_mgr().Get(std::make_unique<ast::type::U32Type>());
+  auto* u32 = ast_module_.create<ast::type::U32Type>();
   if (other->IsF32() || other->IsU32() || other->IsI32()) {
     return u32;
   }
   auto* vec_ty = other->AsVector();
   if (vec_ty) {
-    return ast_module_.type_mgr().Get(
-        std::make_unique<ast::type::VectorType>(u32, vec_ty->size()));
+    return ast_module_.create<ast::type::VectorType>(u32, vec_ty->size());
   }
   Fail() << "required numeric scalar or vector, but got " << other->type_name();
   return nullptr;
@@ -1632,11 +1616,9 @@
   }
   ast::type::Type* ast_store_type = nullptr;
   if (usage.IsSampler()) {
-    ast_store_type =
-        ast_module_.type_mgr().Get(std::make_unique<ast::type::SamplerType>(
-            usage.IsComparisonSampler()
-                ? ast::type::SamplerKind::kComparisonSampler
-                : ast::type::SamplerKind::kSampler));
+    ast_store_type = ast_module_.create<ast::type::SamplerType>(
+        usage.IsComparisonSampler() ? ast::type::SamplerKind::kComparisonSampler
+                                    : ast::type::SamplerKind::kSampler);
   } else if (usage.IsTexture()) {
     const auto* ptr_type = def_use_mgr_->GetDef(var.type_id());
     if (!ptr_type) {
@@ -1689,17 +1671,14 @@
       // OpImage variable with an OpImage*Dref* instruction.  In WGSL we must
       // treat that as a depth texture.
       if (image_type->depth() || usage.IsDepthTexture()) {
-        ast_store_type = ast_module_.type_mgr().Get(
-            std::make_unique<ast::type::DepthTextureType>(dim));
+        ast_store_type = ast_module_.create<ast::type::DepthTextureType>(dim);
       } else if (image_type->is_multisampled()) {
         // Multisampled textures are never depth textures.
-        ast_store_type = ast_module_.type_mgr().Get(
-            std::make_unique<ast::type::MultisampledTextureType>(
-                dim, ast_sampled_component_type));
+        ast_store_type = ast_module_.create<ast::type::MultisampledTextureType>(
+            dim, ast_sampled_component_type);
       } else {
-        ast_store_type = ast_module_.type_mgr().Get(
-            std::make_unique<ast::type::SampledTextureType>(
-                dim, ast_sampled_component_type));
+        ast_store_type = ast_module_.create<ast::type::SampledTextureType>(
+            dim, ast_sampled_component_type);
       }
     } else {
       // Make a storage texture.
@@ -1731,8 +1710,8 @@
       if (format == ast::type::ImageFormat::kNone) {
         return nullptr;
       }
-      ast_store_type = ast_module_.type_mgr().Get(
-          std::make_unique<ast::type::StorageTextureType>(dim, access, format));
+      ast_store_type = ast_module_.create<ast::type::StorageTextureType>(
+          dim, access, format);
     }
   } else {
     Fail() << "unsupported: UniformConstant variable is not a recognized "
@@ -1741,8 +1720,8 @@
     return nullptr;
   }
   // Form the pointer type.
-  return ast_module_.type_mgr().Get(std::make_unique<ast::type::PointerType>(
-      ast_store_type, ast::StorageClass::kUniformConstant));
+  return ast_module_.create<ast::type::PointerType>(
+      ast_store_type, ast::StorageClass::kUniformConstant);
 }
 
 bool ParserImpl::RegisterHandleUsage() {
diff --git a/src/reader/wgsl/parser_impl.cc b/src/reader/wgsl/parser_impl.cc
index 9ab7068..f914bd4 100644
--- a/src/reader/wgsl/parser_impl.cc
+++ b/src/reader/wgsl/parser_impl.cc
@@ -308,7 +308,7 @@
       if (!expect("struct declaration", Token::Type::kSemicolon))
         return Failure::kErrored;
 
-      auto* type = module_.type_mgr().Get(std::move(str.value));
+      auto* type = module_.unique_type(std::move(str.value));
       register_constructed(type->AsStruct()->name(), type);
       module_.AddConstructedType(type);
       return true;
@@ -462,9 +462,8 @@
     if (subtype.errored)
       return Failure::kErrored;
 
-    return module_.type_mgr().Get(
-        std::make_unique<ast::type::SampledTextureType>(dim.value,
-                                                        subtype.value));
+    return module_.create<ast::type::SampledTextureType>(dim.value,
+                                                         subtype.value);
   }
 
   auto ms_dim = multisampled_texture_type();
@@ -475,9 +474,8 @@
     if (subtype.errored)
       return Failure::kErrored;
 
-    return module_.type_mgr().Get(
-        std::make_unique<ast::type::MultisampledTextureType>(ms_dim.value,
-                                                             subtype.value));
+    return module_.create<ast::type::MultisampledTextureType>(ms_dim.value,
+                                                              subtype.value);
   }
 
   auto storage = storage_texture_type();
@@ -490,9 +488,8 @@
     if (format.errored)
       return Failure::kErrored;
 
-    return module_.type_mgr().Get(
-        std::make_unique<ast::type::StorageTextureType>(
-            storage->first, storage->second, format.value));
+    return module_.create<ast::type::StorageTextureType>(
+        storage->first, storage->second, format.value);
   }
 
   return Failure::kNoMatch;
@@ -503,12 +500,12 @@
 //  | SAMPLER_COMPARISON
 Maybe<ast::type::Type*> ParserImpl::sampler_type() {
   if (match(Token::Type::kSampler))
-    return module_.type_mgr().Get(std::make_unique<ast::type::SamplerType>(
-        ast::type::SamplerKind::kSampler));
+    return module_.create<ast::type::SamplerType>(
+        ast::type::SamplerKind::kSampler);
 
   if (match(Token::Type::kComparisonSampler))
-    return module_.type_mgr().Get(std::make_unique<ast::type::SamplerType>(
-        ast::type::SamplerKind::kComparisonSampler));
+    return module_.create<ast::type::SamplerType>(
+        ast::type::SamplerKind::kComparisonSampler);
 
   return Failure::kNoMatch;
 }
@@ -636,20 +633,20 @@
 //  | TEXTURE_DEPTH_CUBE_ARRAY
 Maybe<ast::type::Type*> ParserImpl::depth_texture_type() {
   if (match(Token::Type::kTextureDepth2d))
-    return module_.type_mgr().Get(std::make_unique<ast::type::DepthTextureType>(
-        ast::type::TextureDimension::k2d));
+    return module_.create<ast::type::DepthTextureType>(
+        ast::type::TextureDimension::k2d);
 
   if (match(Token::Type::kTextureDepth2dArray))
-    return module_.type_mgr().Get(std::make_unique<ast::type::DepthTextureType>(
-        ast::type::TextureDimension::k2dArray));
+    return module_.create<ast::type::DepthTextureType>(
+        ast::type::TextureDimension::k2dArray);
 
   if (match(Token::Type::kTextureDepthCube))
-    return module_.type_mgr().Get(std::make_unique<ast::type::DepthTextureType>(
-        ast::type::TextureDimension::kCube));
+    return module_.create<ast::type::DepthTextureType>(
+        ast::type::TextureDimension::kCube);
 
   if (match(Token::Type::kTextureDepthCubeArray))
-    return module_.type_mgr().Get(std::make_unique<ast::type::DepthTextureType>(
-        ast::type::TextureDimension::kCubeArray));
+    return module_.create<ast::type::DepthTextureType>(
+        ast::type::TextureDimension::kCubeArray);
 
   return Failure::kNoMatch;
 }
@@ -834,8 +831,8 @@
   for (auto* deco : access_decos) {
     // If we have an access control decoration then we take it and wrap our
     // type up with that decoration
-    ty = module_.type_mgr().Get(std::make_unique<ast::type::AccessControlType>(
-        deco->AsAccess()->value(), ty));
+    ty = module_.create<ast::type::AccessControlType>(deco->AsAccess()->value(),
+                                                      ty);
   }
 
   return TypedIdentifier{ty, ident.value, ident.source};
@@ -894,8 +891,7 @@
   if (!type.matched)
     return add_error(peek(), "invalid type alias");
 
-  auto* alias = module_.type_mgr().Get(
-      std::make_unique<ast::type::AliasType>(name.value, type.value));
+  auto* alias = module_.create<ast::type::AliasType>(name.value, type.value);
   register_constructed(name.value, alias);
 
   return alias->AsAlias();
@@ -953,16 +949,16 @@
   }
 
   if (match(Token::Type::kBool))
-    return module_.type_mgr().Get(std::make_unique<ast::type::BoolType>());
+    return module_.create<ast::type::BoolType>();
 
   if (match(Token::Type::kF32))
-    return module_.type_mgr().Get(std::make_unique<ast::type::F32Type>());
+    return module_.create<ast::type::F32Type>();
 
   if (match(Token::Type::kI32))
-    return module_.type_mgr().Get(std::make_unique<ast::type::I32Type>());
+    return module_.create<ast::type::I32Type>();
 
   if (match(Token::Type::kU32))
-    return module_.type_mgr().Get(std::make_unique<ast::type::U32Type>());
+    return module_.create<ast::type::U32Type>();
 
   if (t.IsVec2() || t.IsVec3() || t.IsVec4()) {
     next();  // Consume the peek
@@ -1020,8 +1016,7 @@
     if (subtype.errored)
       return Failure::kErrored;
 
-    return module_.type_mgr().Get(
-        std::make_unique<ast::type::PointerType>(subtype.value, sc.value));
+    return module_.create<ast::type::PointerType>(subtype.value, sc.value);
   });
 }
 
@@ -1038,8 +1033,7 @@
   if (subtype.errored)
     return Failure::kErrored;
 
-  return module_.type_mgr().Get(
-      std::make_unique<ast::type::VectorType>(subtype.value, count));
+  return module_.create<ast::type::VectorType>(subtype.value, count);
 }
 
 Expect<ast::type::Type*> ParserImpl::expect_type_decl_array(
@@ -1061,7 +1055,7 @@
 
     auto ty = std::make_unique<ast::type::ArrayType>(subtype.value, size);
     ty->set_decorations(std::move(decos));
-    return module_.type_mgr().Get(std::move(ty));
+    return module_.unique_type(std::move(ty));
   });
 }
 
@@ -1085,8 +1079,7 @@
   if (subtype.errored)
     return Failure::kErrored;
 
-  return module_.type_mgr().Get(
-      std::make_unique<ast::type::MatrixType>(subtype.value, rows, columns));
+  return module_.create<ast::type::MatrixType>(subtype.value, rows, columns);
 }
 
 // storage_class
@@ -1254,7 +1247,7 @@
 //   | VOID
 Maybe<ast::type::Type*> ParserImpl::function_type_decl() {
   if (match(Token::Type::kVoid))
-    return module_.type_mgr().Get(std::make_unique<ast::type::VoidType>());
+    return module_.create<ast::type::VoidType>();
 
   return type_decl();
 }
@@ -2613,25 +2606,23 @@
 Maybe<ast::Literal*> ParserImpl::const_literal() {
   auto t = peek();
   if (match(Token::Type::kTrue)) {
-    auto* type =
-        module_.type_mgr().Get(std::make_unique<ast::type::BoolType>());
+    auto* type = module_.create<ast::type::BoolType>();
     return create<ast::BoolLiteral>(type, true);
   }
   if (match(Token::Type::kFalse)) {
-    auto* type =
-        module_.type_mgr().Get(std::make_unique<ast::type::BoolType>());
+    auto* type = module_.create<ast::type::BoolType>();
     return create<ast::BoolLiteral>(type, false);
   }
   if (match(Token::Type::kSintLiteral)) {
-    auto* type = module_.type_mgr().Get(std::make_unique<ast::type::I32Type>());
+    auto* type = module_.create<ast::type::I32Type>();
     return create<ast::SintLiteral>(type, t.to_i32());
   }
   if (match(Token::Type::kUintLiteral)) {
-    auto* type = module_.type_mgr().Get(std::make_unique<ast::type::U32Type>());
+    auto* type = module_.create<ast::type::U32Type>();
     return create<ast::UintLiteral>(type, t.to_u32());
   }
   if (match(Token::Type::kFloatLiteral)) {
-    auto* type = module_.type_mgr().Get(std::make_unique<ast::type::F32Type>());
+    auto* type = module_.create<ast::type::F32Type>();
     return create<ast::FloatLiteral>(type, t.to_f32());
   }
   return Failure::kNoMatch;
diff --git a/src/reader/wgsl/parser_impl_function_type_decl_test.cc b/src/reader/wgsl/parser_impl_function_type_decl_test.cc
index ad4d362..64b1624 100644
--- a/src/reader/wgsl/parser_impl_function_type_decl_test.cc
+++ b/src/reader/wgsl/parser_impl_function_type_decl_test.cc
@@ -30,7 +30,7 @@
   auto p = parser("void");
 
   auto& mod = p->get_module();
-  auto* v = mod.type_mgr().Get(std::make_unique<ast::type::VoidType>());
+  auto* v = mod.create<ast::type::VoidType>();
 
   auto e = p->function_type_decl();
   EXPECT_TRUE(e.matched);
@@ -43,9 +43,8 @@
   auto p = parser("vec2<f32>");
 
   auto& mod = p->get_module();
-  auto* f32 = mod.type_mgr().Get(std::make_unique<ast::type::F32Type>());
-  auto* vec2 =
-      mod.type_mgr().Get(std::make_unique<ast::type::VectorType>(f32, 2));
+  auto* f32 = mod.create<ast::type::F32Type>();
+  auto* vec2 = mod.create<ast::type::VectorType>(f32, 2);
 
   auto e = p->function_type_decl();
   EXPECT_TRUE(e.matched);
diff --git a/src/reader/wgsl/parser_impl_param_list_test.cc b/src/reader/wgsl/parser_impl_param_list_test.cc
index 2aa8737..de27bda 100644
--- a/src/reader/wgsl/parser_impl_param_list_test.cc
+++ b/src/reader/wgsl/parser_impl_param_list_test.cc
@@ -31,7 +31,7 @@
   auto p = parser("a : i32");
 
   auto& mod = p->get_module();
-  auto* i32 = mod.type_mgr().Get(std::make_unique<ast::type::I32Type>());
+  auto* i32 = mod.create<ast::type::I32Type>();
 
   auto e = p->expect_param_list();
   ASSERT_FALSE(p->has_error()) << p->error();
@@ -52,10 +52,9 @@
   auto p = parser("a : i32, b: f32, c: vec2<f32>");
 
   auto& mod = p->get_module();
-  auto* i32 = mod.type_mgr().Get(std::make_unique<ast::type::I32Type>());
-  auto* f32 = mod.type_mgr().Get(std::make_unique<ast::type::F32Type>());
-  auto* vec2 =
-      mod.type_mgr().Get(std::make_unique<ast::type::VectorType>(f32, 2));
+  auto* i32 = mod.create<ast::type::I32Type>();
+  auto* f32 = mod.create<ast::type::F32Type>();
+  auto* vec2 = mod.create<ast::type::VectorType>(f32, 2);
 
   auto e = p->expect_param_list();
   ASSERT_FALSE(p->has_error()) << p->error();
diff --git a/src/reader/wgsl/parser_impl_primary_expression_test.cc b/src/reader/wgsl/parser_impl_primary_expression_test.cc
index de6d4d2..24a9c59 100644
--- a/src/reader/wgsl/parser_impl_primary_expression_test.cc
+++ b/src/reader/wgsl/parser_impl_primary_expression_test.cc
@@ -193,7 +193,7 @@
   auto p = parser("f32(1)");
 
   auto& mod = p->get_module();
-  auto* f32 = mod.type_mgr().Get(std::make_unique<ast::type::F32Type>());
+  auto* f32 = mod.create<ast::type::F32Type>();
 
   auto e = p->primary_expression();
   EXPECT_TRUE(e.matched);
@@ -215,7 +215,7 @@
   auto p = parser("bitcast<f32>(1)");
 
   auto& mod = p->get_module();
-  auto* f32 = mod.type_mgr().Get(std::make_unique<ast::type::F32Type>());
+  auto* f32 = mod.create<ast::type::F32Type>();
 
   auto e = p->primary_expression();
   EXPECT_TRUE(e.matched);
diff --git a/src/reader/wgsl/parser_impl_struct_body_decl_test.cc b/src/reader/wgsl/parser_impl_struct_body_decl_test.cc
index 0d44e4c..74c7fdf 100644
--- a/src/reader/wgsl/parser_impl_struct_body_decl_test.cc
+++ b/src/reader/wgsl/parser_impl_struct_body_decl_test.cc
@@ -26,7 +26,7 @@
   auto p = parser("{a : i32;}");
 
   auto& mod = p->get_module();
-  auto* i32 = mod.type_mgr().Get(std::make_unique<ast::type::I32Type>());
+  auto* i32 = mod.create<ast::type::I32Type>();
 
   auto m = p->expect_struct_body_decl();
   ASSERT_FALSE(p->has_error());
diff --git a/src/reader/wgsl/parser_impl_struct_member_test.cc b/src/reader/wgsl/parser_impl_struct_member_test.cc
index 45d733f..4e98980 100644
--- a/src/reader/wgsl/parser_impl_struct_member_test.cc
+++ b/src/reader/wgsl/parser_impl_struct_member_test.cc
@@ -27,7 +27,7 @@
   auto p = parser("a : i32;");
 
   auto& mod = p->get_module();
-  auto* i32 = mod.type_mgr().Get(std::make_unique<ast::type::I32Type>());
+  auto* i32 = mod.create<ast::type::I32Type>();
 
   auto decos = p->decoration_list();
   EXPECT_FALSE(decos.errored);
@@ -53,7 +53,7 @@
   auto p = parser("[[offset(2)]] a : i32;");
 
   auto& mod = p->get_module();
-  auto* i32 = mod.type_mgr().Get(std::make_unique<ast::type::I32Type>());
+  auto* i32 = mod.create<ast::type::I32Type>();
 
   auto decos = p->decoration_list();
   EXPECT_FALSE(decos.errored);
@@ -82,7 +82,7 @@
 [[offset(4)]] a : i32;)");
 
   auto& mod = p->get_module();
-  auto* i32 = mod.type_mgr().Get(std::make_unique<ast::type::I32Type>());
+  auto* i32 = mod.create<ast::type::I32Type>();
 
   auto decos = p->decoration_list();
   EXPECT_FALSE(decos.errored);
diff --git a/src/reader/wgsl/parser_impl_type_alias_test.cc b/src/reader/wgsl/parser_impl_type_alias_test.cc
index 0431ccf..e224fef 100644
--- a/src/reader/wgsl/parser_impl_type_alias_test.cc
+++ b/src/reader/wgsl/parser_impl_type_alias_test.cc
@@ -29,7 +29,7 @@
   auto p = parser("type a = i32");
 
   auto& mod = p->get_module();
-  auto* i32 = mod.type_mgr().Get(std::make_unique<ast::type::I32Type>());
+  auto* i32 = mod.create<ast::type::I32Type>();
 
   auto t = p->type_alias();
   EXPECT_FALSE(p->has_error());
diff --git a/src/reader/wgsl/parser_impl_type_decl_test.cc b/src/reader/wgsl/parser_impl_type_decl_test.cc
index b141031..50d9bd4 100644
--- a/src/reader/wgsl/parser_impl_type_decl_test.cc
+++ b/src/reader/wgsl/parser_impl_type_decl_test.cc
@@ -48,9 +48,8 @@
 
   auto& mod = p->get_module();
 
-  auto* int_type = mod.type_mgr().Get(std::make_unique<ast::type::I32Type>());
-  auto* alias_type =
-      mod.type_mgr().Get(std::make_unique<ast::type::AliasType>("A", int_type));
+  auto* int_type = mod.create<ast::type::I32Type>();
+  auto* alias_type = mod.create<ast::type::AliasType>("A", int_type);
 
   p->register_constructed("A", alias_type);
 
@@ -81,7 +80,7 @@
   auto p = parser("bool");
 
   auto& mod = p->get_module();
-  auto* bool_type = mod.type_mgr().Get(std::make_unique<ast::type::BoolType>());
+  auto* bool_type = mod.create<ast::type::BoolType>();
 
   auto t = p->type_decl();
   EXPECT_TRUE(t.matched);
@@ -95,7 +94,7 @@
   auto p = parser("f32");
 
   auto& mod = p->get_module();
-  auto* float_type = mod.type_mgr().Get(std::make_unique<ast::type::F32Type>());
+  auto* float_type = mod.create<ast::type::F32Type>();
 
   auto t = p->type_decl();
   EXPECT_TRUE(t.matched);
@@ -109,7 +108,7 @@
   auto p = parser("i32");
 
   auto& mod = p->get_module();
-  auto* int_type = mod.type_mgr().Get(std::make_unique<ast::type::I32Type>());
+  auto* int_type = mod.create<ast::type::I32Type>();
 
   auto t = p->type_decl();
   EXPECT_TRUE(t.matched);
@@ -123,7 +122,7 @@
   auto p = parser("u32");
 
   auto& mod = p->get_module();
-  auto* uint_type = mod.type_mgr().Get(std::make_unique<ast::type::U32Type>());
+  auto* uint_type = mod.create<ast::type::U32Type>();
 
   auto t = p->type_decl();
   EXPECT_TRUE(t.matched);
@@ -740,8 +739,8 @@
   auto p = parser("sampler");
 
   auto& mod = p->get_module();
-  auto* type = mod.type_mgr().Get(std::make_unique<ast::type::SamplerType>(
-      ast::type::SamplerKind::kSampler));
+  auto* type =
+      mod.create<ast::type::SamplerType>(ast::type::SamplerKind::kSampler);
 
   auto t = p->type_decl();
   EXPECT_TRUE(t.matched);
@@ -757,9 +756,8 @@
 
   auto& mod = p->get_module();
   ast::type::F32Type f32;
-  auto* type =
-      mod.type_mgr().Get(std::make_unique<ast::type::SampledTextureType>(
-          ast::type::TextureDimension::kCube, &f32));
+  auto* type = mod.create<ast::type::SampledTextureType>(
+      ast::type::TextureDimension::kCube, &f32);
 
   auto t = p->type_decl();
   EXPECT_TRUE(t.matched);
@@ -776,9 +774,8 @@
 
   ast::type::F32Type f32;
   auto& mod = p->get_module();
-  auto* type =
-      mod.type_mgr().Get(std::make_unique<ast::type::SampledTextureType>(
-          ast::type::TextureDimension::kCube, &f32));
+  auto* type = mod.create<ast::type::SampledTextureType>(
+      ast::type::TextureDimension::kCube, &f32);
 
   auto t = p->type_decl();
   EXPECT_TRUE(t.matched);
diff --git a/src/transform/bound_array_accessors_transform.cc b/src/transform/bound_array_accessors_transform.cc
index b4b404e..b41d02b 100644
--- a/src/transform/bound_array_accessors_transform.cc
+++ b/src/transform/bound_array_accessors_transform.cc
@@ -237,7 +237,7 @@
       return false;
     }
   } else {
-    auto* u32 = mod_->type_mgr().Get(std::make_unique<ast::type::U32Type>());
+    auto* u32 = mod_->create<ast::type::U32Type>();
 
     ast::ExpressionList cast_expr;
     cast_expr.push_back(expr->idx_expr());
diff --git a/src/transform/vertex_pulling_transform.cc b/src/transform/vertex_pulling_transform.cc
index c941df9..4d7a6e3 100644
--- a/src/transform/vertex_pulling_transform.cc
+++ b/src/transform/vertex_pulling_transform.cc
@@ -222,7 +222,7 @@
   ary_decos.push_back(create<ast::StrideDecoration>(4u, Source{}));
   internal_array->set_decorations(std::move(ary_decos));
 
-  auto* internal_array_type = mod_->type_mgr().Get(std::move(internal_array));
+  auto* internal_array_type = mod_->unique_type(std::move(internal_array));
 
   // Creating the struct type
   ast::StructMemberList members;
@@ -235,10 +235,8 @@
   ast::StructDecorationList decos;
   decos.push_back(create<ast::StructBlockDecoration>(Source{}));
 
-  auto* struct_type =
-      mod_->type_mgr().Get(std::make_unique<ast::type::StructType>(
-          kStructName,
-          create<ast::Struct>(std::move(decos), std::move(members))));
+  auto* struct_type = mod_->create<ast::type::StructType>(
+      kStructName, create<ast::Struct>(std::move(decos), std::move(members)));
 
   for (uint32_t i = 0; i < vertex_state_->vertex_buffers.size(); ++i) {
     // The decorated variable with struct type
@@ -411,21 +409,20 @@
   }
 
   return create<ast::TypeConstructorExpression>(
-      mod_->type_mgr().Get(
-          std::make_unique<ast::type::VectorType>(base_type, count)),
+      mod_->create<ast::type::VectorType>(base_type, count),
       std::move(expr_list));
 }
 
 ast::type::Type* VertexPullingTransform::GetU32Type() {
-  return mod_->type_mgr().Get(std::make_unique<ast::type::U32Type>());
+  return mod_->create<ast::type::U32Type>();
 }
 
 ast::type::Type* VertexPullingTransform::GetI32Type() {
-  return mod_->type_mgr().Get(std::make_unique<ast::type::I32Type>());
+  return mod_->create<ast::type::I32Type>();
 }
 
 ast::type::Type* VertexPullingTransform::GetF32Type() {
-  return mod_->type_mgr().Get(std::make_unique<ast::type::F32Type>());
+  return mod_->create<ast::type::F32Type>();
 }
 
 VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor() = default;
diff --git a/src/transform/vertex_pulling_transform_test.cc b/src/transform/vertex_pulling_transform_test.cc
index 2b31f1f..093d7fc 100644
--- a/src/transform/vertex_pulling_transform_test.cc
+++ b/src/transform/vertex_pulling_transform_test.cc
@@ -46,10 +46,9 @@
 
   // Create basic module with an entry point and vertex function
   void InitBasicModule() {
-    auto* func = create<ast::Function>(
-        "main", ast::VariableList{},
-        mod_->type_mgr().Get(std::make_unique<ast::type::VoidType>()),
-        create<ast::BlockStatement>());
+    auto* func = create<ast::Function>("main", ast::VariableList{},
+                                       mod_->create<ast::type::VoidType>(),
+                                       create<ast::BlockStatement>());
     func->add_decoration(
         create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}));
     mod()->AddFunction(func);
@@ -125,10 +124,9 @@
 }
 
 TEST_F(VertexPullingTransformTest, Error_EntryPointWrongStage) {
-  auto* func = create<ast::Function>(
-      "main", ast::VariableList{},
-      mod()->type_mgr().Get(std::make_unique<ast::type::VoidType>()),
-      create<ast::BlockStatement>());
+  auto* func = create<ast::Function>("main", ast::VariableList{},
+                                     mod()->create<ast::type::VoidType>(),
+                                     create<ast::BlockStatement>());
   func->add_decoration(
       create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}));
   mod()->AddFunction(func);
diff --git a/src/type_determiner.cc b/src/type_determiner.cc
index 6f95f2f..d980086 100644
--- a/src/type_determiner.cc
+++ b/src/type_determiner.cc
@@ -83,7 +83,7 @@
 }
 
 bool TypeDeterminer::Determine() {
-  for (auto& iter : mod_->type_mgr().types()) {
+  for (auto& iter : mod_->types()) {
     auto& type = iter.second;
     if (!type->IsTexture() || !type->AsTexture()->IsStorage()) {
       continue;
@@ -339,8 +339,7 @@
     ret = parent_type->AsVector()->type();
   } else if (parent_type->IsMatrix()) {
     auto* m = parent_type->AsMatrix();
-    ret = mod_->type_mgr().Get(
-        std::make_unique<ast::type::VectorType>(m->type(), m->rows()));
+    ret = mod_->create<ast::type::VectorType>(m->type(), m->rows());
   } else {
     set_error(expr->source(), "invalid parent type (" +
                                   parent_type->type_name() +
@@ -350,15 +349,15 @@
 
   // If we're extracting from a pointer, we return a pointer.
   if (res->IsPointer()) {
-    ret = mod_->type_mgr().Get(std::make_unique<ast::type::PointerType>(
-        ret, res->AsPointer()->storage_class()));
+    ret = mod_->create<ast::type::PointerType>(
+        ret, res->AsPointer()->storage_class());
   } else if (parent_type->IsArray() &&
              !parent_type->AsArray()->type()->is_scalar()) {
     // If we extract a non-scalar from an array then we also get a pointer. We
     // will generate a Function storage class variable to store this
     // into.
-    ret = mod_->type_mgr().Get(std::make_unique<ast::type::PointerType>(
-        ret, ast::StorageClass::kFunction));
+    ret =
+        mod_->create<ast::type::PointerType>(ret, ast::StorageClass::kFunction);
   }
   expr->set_result_type(ret);
 
@@ -522,13 +521,11 @@
   }
   if (ident->intrinsic() == ast::Intrinsic::kAny ||
       ident->intrinsic() == ast::Intrinsic::kAll) {
-    expr->func()->set_result_type(
-        mod_->type_mgr().Get(std::make_unique<ast::type::BoolType>()));
+    expr->func()->set_result_type(mod_->create<ast::type::BoolType>());
     return true;
   }
   if (ident->intrinsic() == ast::Intrinsic::kArrayLength) {
-    expr->func()->set_result_type(
-        mod_->type_mgr().Get(std::make_unique<ast::type::U32Type>()));
+    expr->func()->set_result_type(mod_->create<ast::type::U32Type>());
     return true;
   }
   if (ast::intrinsic::IsFloatClassificationIntrinsic(ident->intrinsic())) {
@@ -538,14 +535,12 @@
       return false;
     }
 
-    auto* bool_type =
-        mod_->type_mgr().Get(std::make_unique<ast::type::BoolType>());
+    auto* bool_type = mod_->create<ast::type::BoolType>();
 
     auto* param_type = expr->params()[0]->result_type()->UnwrapPtrIfNeeded();
     if (param_type->IsVector()) {
-      expr->func()->set_result_type(
-          mod_->type_mgr().Get(std::make_unique<ast::type::VectorType>(
-              bool_type, param_type->AsVector()->size())));
+      expr->func()->set_result_type(mod_->create<ast::type::VectorType>(
+          bool_type, param_type->AsVector()->size()));
     } else {
       expr->func()->set_result_type(bool_type);
     }
@@ -666,8 +661,7 @@
         std::make_unique<ast::intrinsic::TextureSignature>(param));
 
     if (texture->IsDepth()) {
-      expr->func()->set_result_type(
-          mod_->type_mgr().Get(std::make_unique<ast::type::F32Type>()));
+      expr->func()->set_result_type(mod_->create<ast::type::F32Type>());
       return true;
     }
 
@@ -688,13 +682,11 @@
       set_error(expr->source(), "unknown texture type for texture sampling");
       return false;
     }
-    expr->func()->set_result_type(
-        mod_->type_mgr().Get(std::make_unique<ast::type::VectorType>(type, 4)));
+    expr->func()->set_result_type(mod_->create<ast::type::VectorType>(type, 4));
     return true;
   }
   if (ident->intrinsic() == ast::Intrinsic::kDot) {
-    expr->func()->set_result_type(
-        mod_->type_mgr().Get(std::make_unique<ast::type::F32Type>()));
+    expr->func()->set_result_type(mod_->create<ast::type::F32Type>());
     return true;
   }
   if (ident->intrinsic() == ast::Intrinsic::kOuterProduct) {
@@ -711,10 +703,9 @@
       return false;
     }
 
-    expr->func()->set_result_type(
-        mod_->type_mgr().Get(std::make_unique<ast::type::MatrixType>(
-            mod_->type_mgr().Get(std::make_unique<ast::type::F32Type>()),
-            param0_type->AsVector()->size(), param1_type->AsVector()->size())));
+    expr->func()->set_result_type(mod_->create<ast::type::MatrixType>(
+        mod_->create<ast::type::F32Type>(), param0_type->AsVector()->size(),
+        param1_type->AsVector()->size()));
     return true;
   }
   if (ident->intrinsic() == ast::Intrinsic::kSelect) {
@@ -861,9 +852,8 @@
     } else if (var->type()->IsPointer()) {
       expr->set_result_type(var->type());
     } else {
-      expr->set_result_type(
-          mod_->type_mgr().Get(std::make_unique<ast::type::PointerType>(
-              var->type(), var->storage_class())));
+      expr->set_result_type(mod_->create<ast::type::PointerType>(
+          var->type(), var->storage_class()));
     }
 
     set_referenced_from_function_if_needed(var);
@@ -1055,8 +1045,8 @@
 
     // If we're extracting from a pointer, we return a pointer.
     if (res->IsPointer()) {
-      ret = mod_->type_mgr().Get(std::make_unique<ast::type::PointerType>(
-          ret, res->AsPointer()->storage_class()));
+      ret = mod_->create<ast::type::PointerType>(
+          ret, res->AsPointer()->storage_class());
     }
   } else if (data_type->IsVector()) {
     auto* vec = data_type->AsVector();
@@ -1067,15 +1057,14 @@
       ret = vec->type();
       // If we're extracting from a pointer, we return a pointer.
       if (res->IsPointer()) {
-        ret = mod_->type_mgr().Get(std::make_unique<ast::type::PointerType>(
-            ret, res->AsPointer()->storage_class()));
+        ret = mod_->create<ast::type::PointerType>(
+            ret, res->AsPointer()->storage_class());
       }
     } else {
       // The vector will have a number of components equal to the length of the
       // swizzle. This assumes the validator will check that the swizzle
       // is correct.
-      ret = mod_->type_mgr().Get(
-          std::make_unique<ast::type::VectorType>(vec->type(), size));
+      ret = mod_->create<ast::type::VectorType>(vec->type(), size);
     }
   } else {
     set_error(
@@ -1106,13 +1095,11 @@
   if (expr->IsLogicalAnd() || expr->IsLogicalOr() || expr->IsEqual() ||
       expr->IsNotEqual() || expr->IsLessThan() || expr->IsGreaterThan() ||
       expr->IsLessThanEqual() || expr->IsGreaterThanEqual()) {
-    auto* bool_type =
-        mod_->type_mgr().Get(std::make_unique<ast::type::BoolType>());
+    auto* bool_type = mod_->create<ast::type::BoolType>();
     auto* param_type = expr->lhs()->result_type()->UnwrapPtrIfNeeded();
     if (param_type->IsVector()) {
-      expr->set_result_type(
-          mod_->type_mgr().Get(std::make_unique<ast::type::VectorType>(
-              bool_type, param_type->AsVector()->size())));
+      expr->set_result_type(mod_->create<ast::type::VectorType>(
+          bool_type, param_type->AsVector()->size()));
     } else {
       expr->set_result_type(bool_type);
     }
@@ -1125,20 +1112,18 @@
     // Note, the ordering here matters. The later checks depend on the prior
     // checks having been done.
     if (lhs_type->IsMatrix() && rhs_type->IsMatrix()) {
-      expr->set_result_type(
-          mod_->type_mgr().Get(std::make_unique<ast::type::MatrixType>(
-              lhs_type->AsMatrix()->type(), lhs_type->AsMatrix()->rows(),
-              rhs_type->AsMatrix()->columns())));
+      expr->set_result_type(mod_->create<ast::type::MatrixType>(
+          lhs_type->AsMatrix()->type(), lhs_type->AsMatrix()->rows(),
+          rhs_type->AsMatrix()->columns()));
 
     } else if (lhs_type->IsMatrix() && rhs_type->IsVector()) {
       auto* mat = lhs_type->AsMatrix();
-      expr->set_result_type(mod_->type_mgr().Get(
-          std::make_unique<ast::type::VectorType>(mat->type(), mat->rows())));
+      expr->set_result_type(
+          mod_->create<ast::type::VectorType>(mat->type(), mat->rows()));
     } else if (lhs_type->IsVector() && rhs_type->IsMatrix()) {
       auto* mat = rhs_type->AsMatrix();
       expr->set_result_type(
-          mod_->type_mgr().Get(std::make_unique<ast::type::VectorType>(
-              mat->type(), mat->columns())));
+          mod_->create<ast::type::VectorType>(mat->type(), mat->columns()));
     } else if (lhs_type->IsMatrix()) {
       // matrix * scalar
       expr->set_result_type(lhs_type);
@@ -1197,8 +1182,7 @@
     case ast::type::ImageFormat::kRg32Uint:
     case ast::type::ImageFormat::kRgba16Uint:
     case ast::type::ImageFormat::kRgba32Uint: {
-      tex->set_type(
-          mod_->type_mgr().Get(std::make_unique<ast::type::U32Type>()));
+      tex->set_type(mod_->create<ast::type::U32Type>());
       return true;
     }
 
@@ -1214,8 +1198,7 @@
     case ast::type::ImageFormat::kRg32Sint:
     case ast::type::ImageFormat::kRgba16Sint:
     case ast::type::ImageFormat::kRgba32Sint: {
-      tex->set_type(
-          mod_->type_mgr().Get(std::make_unique<ast::type::I32Type>()));
+      tex->set_type(mod_->create<ast::type::I32Type>());
       return true;
     }
 
@@ -1226,8 +1209,7 @@
     case ast::type::ImageFormat::kRg32Float:
     case ast::type::ImageFormat::kRgba16Float:
     case ast::type::ImageFormat::kRgba32Float: {
-      tex->set_type(
-          mod_->type_mgr().Get(std::make_unique<ast::type::F32Type>()));
+      tex->set_type(mod_->create<ast::type::F32Type>());
       return true;
     }
 
diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc
index 81ec3ef..c007cac 100644
--- a/src/type_determiner_test.cc
+++ b/src/type_determiner_test.cc
@@ -1786,9 +1786,8 @@
   ast::type::I32Type i32;
   auto coords_type = get_coords_type(dim, &i32);
 
-  ast::type::Type* texture_type =
-      mod->type_mgr().Get(std::make_unique<ast::type::StorageTextureType>(
-          dim, ast::AccessControl::kReadOnly, format));
+  ast::type::Type* texture_type = mod->create<ast::type::StorageTextureType>(
+      dim, ast::AccessControl::kReadOnly, format);
 
   ast::ExpressionList call_params;
 
@@ -4549,14 +4548,13 @@
   switch (param.texture_kind) {
     case ast::intrinsic::test::TextureKind::kRegular:
       Var("texture", ast::StorageClass::kNone,
-          mod->type_mgr().Get<ast::type::SampledTextureType>(
-              param.texture_dimension, datatype));
+          mod->create<ast::type::SampledTextureType>(param.texture_dimension,
+                                                     datatype));
       break;
 
     case ast::intrinsic::test::TextureKind::kDepth:
       Var("texture", ast::StorageClass::kNone,
-          mod->type_mgr().Get<ast::type::DepthTextureType>(
-              param.texture_dimension));
+          mod->create<ast::type::DepthTextureType>(param.texture_dimension));
       break;
   }
 
diff --git a/src/writer/hlsl/generator_impl_intrinsic_texture_test.cc b/src/writer/hlsl/generator_impl_intrinsic_texture_test.cc
index 94dc341..19c8612 100644
--- a/src/writer/hlsl/generator_impl_intrinsic_texture_test.cc
+++ b/src/writer/hlsl/generator_impl_intrinsic_texture_test.cc
@@ -183,14 +183,13 @@
   switch (param.texture_kind) {
     case ast::intrinsic::test::TextureKind::kRegular:
       Var("texture", ast::StorageClass::kNone,
-          mod->type_mgr().Get<ast::type::SampledTextureType>(
-              param.texture_dimension, datatype));
+          mod->create<ast::type::SampledTextureType>(param.texture_dimension,
+                                                     datatype));
       break;
 
     case ast::intrinsic::test::TextureKind::kDepth:
       Var("texture", ast::StorageClass::kNone,
-          mod->type_mgr().Get<ast::type::DepthTextureType>(
-              param.texture_dimension));
+          mod->create<ast::type::DepthTextureType>(param.texture_dimension));
       break;
   }
 
diff --git a/src/writer/spirv/builder_intrinsic_texture_test.cc b/src/writer/spirv/builder_intrinsic_texture_test.cc
index d11c80e..5d0a255 100644
--- a/src/writer/spirv/builder_intrinsic_texture_test.cc
+++ b/src/writer/spirv/builder_intrinsic_texture_test.cc
@@ -1618,14 +1618,14 @@
   switch (param.texture_kind) {
     case ast::intrinsic::test::TextureKind::kRegular:
       tex = Var("texture", ast::StorageClass::kNone,
-                mod->type_mgr().Get<ast::type::SampledTextureType>(
+                mod->create<ast::type::SampledTextureType>(
                     param.texture_dimension, datatype));
       break;
 
     case ast::intrinsic::test::TextureKind::kDepth:
-      tex = Var("texture", ast::StorageClass::kNone,
-                mod->type_mgr().Get<ast::type::DepthTextureType>(
-                    param.texture_dimension));
+      tex = Var(
+          "texture", ast::StorageClass::kNone,
+          mod->create<ast::type::DepthTextureType>(param.texture_dimension));
       break;
   }