Resolver: compute canonical types and store them as semantic::Variable::Type

We define the canonical type as a type stripped of all aliases. For
example, Canonical(alias<alias<vec3<alias<f32>>>>) is vec3<f32>. This
change adds Resolver::Canonical(Type*) which caches and returns the
resulting canonical type. We use this throughout the Resolver instead of
UnwrapAliasIfNeeded(), and we store the result in semantic::Variable,
returned from it's Type() member function.

Also:

* Wrote unit tests for Resolver::Canonical()

* Added semantic::Variable::DeclaredType() as a convenience to
retrieve the AST variable's type.

* Updated post-resolve code (transforms) to make use of Type and
DeclaredType appropriately, removing unnecessary calls to
UnwrapAliasIfNeeded.

* Added IntrinsicTableTest.MatchWithNestedAliasUnwrapping to ensure we
don't need to pass canonical parameter types for instrinsic table
lookups.

* ProgramBuilder: added vecN and matMxN overloads that take a Type* arg
to create them with alias types.

Bug: tint:705
Change-Id: I58a3b62538356b8dad2b1161a19b38bcefdd5d62
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/47360
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/intrinsic_table_test.cc b/src/intrinsic_table_test.cc
index 1cf1868..350eec7 100644
--- a/src/intrinsic_table_test.cc
+++ b/src/intrinsic_table_test.cc
@@ -357,6 +357,23 @@
   EXPECT_THAT(result.intrinsic->Parameters(), ElementsAre(Parameter{ty.f32()}));
 }
 
+TEST_F(IntrinsicTableTest, MatchWithNestedAliasUnwrapping) {
+  auto* alias_a = ty.alias("alias_a", ty.bool_());
+  auto* alias_b = ty.alias("alias_b", alias_a);
+  auto* alias_c = ty.alias("alias_c", alias_b);
+  auto* vec4_of_c = ty.vec4(alias_c);
+  auto* alias_d = ty.alias("alias_d", vec4_of_c);
+  auto* alias_e = ty.alias("alias_e", alias_d);
+
+  auto result = table->Lookup(*this, IntrinsicType::kAll, {alias_e}, Source{});
+  ASSERT_NE(result.intrinsic, nullptr);
+  ASSERT_EQ(result.diagnostics.str(), "");
+  EXPECT_THAT(result.intrinsic->Type(), IntrinsicType::kAll);
+  EXPECT_THAT(result.intrinsic->ReturnType(), ty.bool_());
+  EXPECT_THAT(result.intrinsic->Parameters(),
+              ElementsAre(Parameter{ty.vec4<bool>()}));
+}
+
 TEST_F(IntrinsicTableTest, MatchOpenType) {
   auto result = table->Lookup(*this, IntrinsicType::kClamp,
                               {ty.f32(), ty.f32(), ty.f32()}, Source{});
diff --git a/src/program_builder.h b/src/program_builder.h
index 5fddb56..fb261cc 100644
--- a/src/program_builder.h
+++ b/src/program_builder.h
@@ -338,76 +338,148 @@
     /// @returns a void type
     type::Void* void_() const { return builder->create<type::Void>(); }
 
+    /// @param type vector subtype
+    /// @return the tint AST type for a 2-element vector of `type`.
+    type::Vector* vec2(type::Type* type) const {
+      return builder->create<type::Vector>(type, 2u);
+    }
+
+    /// @param type vector subtype
+    /// @return the tint AST type for a 3-element vector of `type`.
+    type::Vector* vec3(type::Type* type) const {
+      return builder->create<type::Vector>(type, 3u);
+    }
+
+    /// @param type vector subtype
+    /// @return the tint AST type for a 4-element vector of `type`.
+    type::Type* vec4(type::Type* type) const {
+      return builder->create<type::Vector>(type, 4u);
+    }
+
     /// @return the tint AST type for a 2-element vector of the C type `T`.
     template <typename T>
     type::Vector* vec2() const {
-      return builder->create<type::Vector>(Of<T>(), 2);
+      return vec2(Of<T>());
     }
 
     /// @return the tint AST type for a 3-element vector of the C type `T`.
     template <typename T>
     type::Vector* vec3() const {
-      return builder->create<type::Vector>(Of<T>(), 3);
+      return vec3(Of<T>());
     }
 
     /// @return the tint AST type for a 4-element vector of the C type `T`.
     template <typename T>
     type::Type* vec4() const {
-      return builder->create<type::Vector>(Of<T>(), 4);
+      return vec4(Of<T>());
+    }
+
+    /// @param type matrix subtype
+    /// @return the tint AST type for a 2x3 matrix of `type`.
+    type::Matrix* mat2x2(type::Type* type) const {
+      return builder->create<type::Matrix>(type, 2u, 2u);
+    }
+
+    /// @param type matrix subtype
+    /// @return the tint AST type for a 2x3 matrix of `type`.
+    type::Matrix* mat2x3(type::Type* type) const {
+      return builder->create<type::Matrix>(type, 3u, 2u);
+    }
+
+    /// @param type matrix subtype
+    /// @return the tint AST type for a 2x4 matrix of `type`.
+    type::Matrix* mat2x4(type::Type* type) const {
+      return builder->create<type::Matrix>(type, 4u, 2u);
+    }
+
+    /// @param type matrix subtype
+    /// @return the tint AST type for a 3x2 matrix of `type`.
+    type::Matrix* mat3x2(type::Type* type) const {
+      return builder->create<type::Matrix>(type, 2u, 3u);
+    }
+
+    /// @param type matrix subtype
+    /// @return the tint AST type for a 3x3 matrix of `type`.
+    type::Matrix* mat3x3(type::Type* type) const {
+      return builder->create<type::Matrix>(type, 3u, 3u);
+    }
+
+    /// @param type matrix subtype
+    /// @return the tint AST type for a 3x4 matrix of `type`.
+    type::Matrix* mat3x4(type::Type* type) const {
+      return builder->create<type::Matrix>(type, 4u, 3u);
+    }
+
+    /// @param type matrix subtype
+    /// @return the tint AST type for a 4x2 matrix of `type`.
+    type::Matrix* mat4x2(type::Type* type) const {
+      return builder->create<type::Matrix>(type, 2u, 4u);
+    }
+
+    /// @param type matrix subtype
+    /// @return the tint AST type for a 4x3 matrix of `type`.
+    type::Matrix* mat4x3(type::Type* type) const {
+      return builder->create<type::Matrix>(type, 3u, 4u);
+    }
+
+    /// @param type matrix subtype
+    /// @return the tint AST type for a 4x4 matrix of `type`.
+    type::Matrix* mat4x4(type::Type* type) const {
+      return builder->create<type::Matrix>(type, 4u, 4u);
     }
 
     /// @return the tint AST type for a 2x3 matrix of the C type `T`.
     template <typename T>
     type::Matrix* mat2x2() const {
-      return builder->create<type::Matrix>(Of<T>(), 2, 2);
+      return mat2x2(Of<T>());
     }
 
     /// @return the tint AST type for a 2x3 matrix of the C type `T`.
     template <typename T>
     type::Matrix* mat2x3() const {
-      return builder->create<type::Matrix>(Of<T>(), 3, 2);
+      return mat2x3(Of<T>());
     }
 
     /// @return the tint AST type for a 2x4 matrix of the C type `T`.
     template <typename T>
     type::Matrix* mat2x4() const {
-      return builder->create<type::Matrix>(Of<T>(), 4, 2);
+      return mat2x4(Of<T>());
     }
 
     /// @return the tint AST type for a 3x2 matrix of the C type `T`.
     template <typename T>
     type::Matrix* mat3x2() const {
-      return builder->create<type::Matrix>(Of<T>(), 2, 3);
+      return mat3x2(Of<T>());
     }
 
     /// @return the tint AST type for a 3x3 matrix of the C type `T`.
     template <typename T>
     type::Matrix* mat3x3() const {
-      return builder->create<type::Matrix>(Of<T>(), 3, 3);
+      return mat3x3(Of<T>());
     }
 
     /// @return the tint AST type for a 3x4 matrix of the C type `T`.
     template <typename T>
     type::Matrix* mat3x4() const {
-      return builder->create<type::Matrix>(Of<T>(), 4, 3);
+      return mat3x4(Of<T>());
     }
 
     /// @return the tint AST type for a 4x2 matrix of the C type `T`.
     template <typename T>
     type::Matrix* mat4x2() const {
-      return builder->create<type::Matrix>(Of<T>(), 2, 4);
+      return mat4x2(Of<T>());
     }
 
     /// @return the tint AST type for a 4x3 matrix of the C type `T`.
     template <typename T>
     type::Matrix* mat4x3() const {
-      return builder->create<type::Matrix>(Of<T>(), 3, 4);
+      return mat4x3(Of<T>());
     }
 
     /// @return the tint AST type for a 4x4 matrix of the C type `T`.
     template <typename T>
     type::Matrix* mat4x4() const {
-      return builder->create<type::Matrix>(Of<T>(), 4, 4);
+      return mat4x4(Of<T>());
     }
 
     /// @param subtype the array element type
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 06461e5..a18b481 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -42,6 +42,7 @@
 #include "src/semantic/struct.h"
 #include "src/semantic/variable.h"
 #include "src/type/access_control_type.h"
+#include "src/utils/get_or_create.h"
 #include "src/utils/math.h"
 
 namespace tint {
@@ -430,7 +431,7 @@
         }
 
         // Check that we saw a pipeline IO attribute iff we need one.
-        if (ty->UnwrapAliasIfNeeded()->Is<type::Struct>()) {
+        if (Canonical(ty)->Is<type::Struct>()) {
           if (pipeline_io_attribute) {
             diagnostics_.add_error(
                 "entry point IO attributes must not be used on structure " +
@@ -466,11 +467,11 @@
       return false;
     }
 
-    if (auto* struct_ty = ty->UnwrapAliasIfNeeded()->As<type::Struct>()) {
+    if (auto* struct_ty = Canonical(ty)->As<type::Struct>()) {
       // Validate the decorations for each struct members, and also check for
       // invalid member types.
       for (auto* member : struct_ty->impl()->members()) {
-        auto* member_ty = member->type()->UnwrapAliasIfNeeded();
+        auto* member_ty = Canonical(member->type());
         if (member_ty->Is<type::Struct>()) {
           diagnostics_.add_error(
               "entry point IO types cannot contain nested structures",
@@ -547,8 +548,7 @@
       return false;
     }
 
-    if (auto* str =
-            param->declared_type()->UnwrapAliasIfNeeded()->As<type::Struct>()) {
+    if (auto* str = param_info->type->As<type::Struct>()) {
       auto* info = Structure(str);
       if (!info) {
         return false;
@@ -572,8 +572,7 @@
     }
   }
 
-  if (auto* str =
-          func->return_type()->UnwrapAliasIfNeeded()->As<type::Struct>()) {
+  if (auto* str = Canonical(func->return_type())->As<type::Struct>()) {
     if (!ApplyStorageClassUsageToType(ast::StorageClass::kNone, str,
                                       func->source())) {
       diagnostics_.add_note("while instantiating return type for " +
@@ -1288,15 +1287,16 @@
   using Matrix = type::Matrix;
   using Vector = type::Vector;
 
-  auto* lhs_type = TypeOf(expr->lhs())->UnwrapAll();
-  auto* rhs_type = TypeOf(expr->rhs())->UnwrapAll();
+  auto* lhs_declared_type = TypeOf(expr->lhs())->UnwrapAll();
+  auto* rhs_declared_type = TypeOf(expr->rhs())->UnwrapAll();
+
+  auto* lhs_type = Canonical(lhs_declared_type);
+  auto* rhs_type = Canonical(rhs_declared_type);
 
   auto* lhs_vec = lhs_type->As<Vector>();
-  auto* lhs_vec_elem_type =
-      lhs_vec ? lhs_vec->type()->UnwrapAliasIfNeeded() : nullptr;
+  auto* lhs_vec_elem_type = lhs_vec ? lhs_vec->type() : nullptr;
   auto* rhs_vec = rhs_type->As<Vector>();
-  auto* rhs_vec_elem_type =
-      rhs_vec ? rhs_vec->type()->UnwrapAliasIfNeeded() : nullptr;
+  auto* rhs_vec_elem_type = rhs_vec ? rhs_vec->type() : nullptr;
 
   const bool matching_vec_elem_types =
       lhs_vec_elem_type && rhs_vec_elem_type &&
@@ -1348,11 +1348,9 @@
     }
 
     auto* lhs_mat = lhs_type->As<Matrix>();
-    auto* lhs_mat_elem_type =
-        lhs_mat ? lhs_mat->type()->UnwrapAliasIfNeeded() : nullptr;
+    auto* lhs_mat_elem_type = lhs_mat ? lhs_mat->type() : nullptr;
     auto* rhs_mat = rhs_type->As<Matrix>();
-    auto* rhs_mat_elem_type =
-        rhs_mat ? rhs_mat->type()->UnwrapAliasIfNeeded() : nullptr;
+    auto* rhs_mat_elem_type = rhs_mat ? rhs_mat->type() : nullptr;
 
     // Multiplication of a matrix and a scalar
     if (lhs_type->Is<F32>() && rhs_mat_elem_type &&
@@ -1438,9 +1436,9 @@
 
   diagnostics_.add_error(
       "Binary expression operand types are invalid for this operation: " +
-          lhs_type->FriendlyName(builder_->Symbols()) + " " +
+          lhs_declared_type->FriendlyName(builder_->Symbols()) + " " +
           FriendlyName(expr->op()) + " " +
-          rhs_type->FriendlyName(builder_->Symbols()),
+          rhs_declared_type->FriendlyName(builder_->Symbols()),
       expr->source());
   return false;
 }
@@ -1600,7 +1598,7 @@
 }
 
 Resolver::VariableInfo* Resolver::CreateVariableInfo(ast::Variable* var) {
-  auto* info = variable_infos_.Create(var);
+  auto* info = variable_infos_.Create(var, Canonical(var->declared_type()));
   variable_to_info_.emplace(var, info);
   return info;
 }
@@ -1748,13 +1746,13 @@
       /*vec4*/ 16,
   };
 
-  ty = ty->UnwrapAliasIfNeeded();
-  if (ty->is_scalar()) {
+  auto* cty = Canonical(ty);
+  if (cty->is_scalar()) {
     // Note: Also captures booleans, but these are not host-shareable.
     align = 4;
     size = 4;
     return true;
-  } else if (auto* vec = ty->As<type::Vector>()) {
+  } else if (auto* vec = cty->As<type::Vector>()) {
     if (vec->size() < 2 || vec->size() > 4) {
       TINT_UNREACHABLE(diagnostics_)
           << "Invalid vector size: vec" << vec->size();
@@ -1763,7 +1761,7 @@
     align = vector_align[vec->size()];
     size = vector_size[vec->size()];
     return true;
-  } else if (auto* mat = ty->As<type::Matrix>()) {
+  } else if (auto* mat = cty->As<type::Matrix>()) {
     if (mat->columns() < 2 || mat->columns() > 4 || mat->rows() < 2 ||
         mat->rows() > 4) {
       TINT_UNREACHABLE(diagnostics_)
@@ -1773,15 +1771,15 @@
     align = vector_align[mat->rows()];
     size = vector_align[mat->rows()] * mat->columns();
     return true;
-  } else if (auto* s = ty->As<type::Struct>()) {
+  } else if (auto* s = cty->As<type::Struct>()) {
     if (auto* si = Structure(s)) {
       align = si->align;
       size = si->size;
       return true;
     }
     return false;
-  } else if (auto* arr = ty->As<type::Array>()) {
-    if (auto* sem = Array(arr)) {
+  } else if (cty->Is<type::Array>()) {
+    if (auto* sem = Array(ty->UnwrapAliasIfNeeded()->As<type::Array>())) {
       align = sem->Align();
       size = sem->Size();
       return true;
@@ -2249,9 +2247,37 @@
   return vec_type.FriendlyName(builder_->Symbols());
 }
 
-Resolver::VariableInfo::VariableInfo(ast::Variable* decl)
+type::Type* Resolver::Canonical(type::Type* type) {
+  using Type = type::Type;
+  using Alias = type::Alias;
+  using Matrix = type::Matrix;
+  using Vector = type::Vector;
+
+  std::function<Type*(Type*)> make_canonical;
+  make_canonical = [&](Type* t) -> type::Type* {
+    // Unwrap alias sequence
+    Type* ct = t;
+    while (auto* p = ct->As<Alias>()) {
+      ct = p->type();
+    }
+
+    if (auto* v = ct->As<Vector>()) {
+      return builder_->create<Vector>(make_canonical(v->type()), v->size());
+    }
+    if (auto* m = ct->As<Matrix>()) {
+      return builder_->create<Matrix>(make_canonical(m->type()), m->rows(),
+                                      m->columns());
+    }
+    return ct;
+  };
+
+  return utils::GetOrCreate(type_to_canonical_, type,
+                            [&] { return make_canonical(type); });
+}
+
+Resolver::VariableInfo::VariableInfo(ast::Variable* decl, type::Type* ctype)
     : declaration(decl),
-      type(decl->declared_type()),
+      type(ctype),
       storage_class(decl->declared_storage_class()) {}
 
 Resolver::VariableInfo::~VariableInfo() = default;
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index e7f1f76..d4f73fc 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -86,11 +86,17 @@
   /// structure member or array element of type `lhs`
   static bool IsValidAssignment(type::Type* lhs, type::Type* rhs);
 
+  /// @param type the input type
+  /// @returns the canonical type for `type`; that is, a type with all aliases
+  /// removed. For example, `Canonical(alias<alias<vec3<alias<f32>>>>)` is
+  /// `vec3<f32>`.
+  type::Type* Canonical(type::Type* type);
+
  private:
   /// Structure holding semantic information about a variable.
   /// Used to build the semantic::Variable nodes at the end of resolving.
   struct VariableInfo {
-    explicit VariableInfo(ast::Variable* decl);
+    VariableInfo(ast::Variable* decl, type::Type* type);
     ~VariableInfo();
 
     ast::Variable* const declaration;
@@ -306,6 +312,7 @@
   std::unordered_map<ast::CallExpression*, FunctionCallInfo> function_calls_;
   std::unordered_map<ast::Expression*, ExpressionInfo> expr_info_;
   std::unordered_map<type::Struct*, StructInfo*> struct_info_;
+  std::unordered_map<type::Type*, type::Type*> type_to_canonical_;
   FunctionInfo* current_function_ = nullptr;
   semantic::Statement* current_statement_ = nullptr;
   BlockAllocator<VariableInfo> variable_infos_;
diff --git a/src/resolver/resolver_test_helper.h b/src/resolver/resolver_test_helper.h
index cbcb2a6..cd9dd7a 100644
--- a/src/resolver/resolver_test_helper.h
+++ b/src/resolver/resolver_test_helper.h
@@ -119,18 +119,30 @@
   return ty.f32();
 }
 
+using create_type_func_ptr =
+    type::Type* (*)(const ProgramBuilder::TypesBuilder& ty);
+
 template <typename T>
 type::Type* ty_vec3(const ProgramBuilder::TypesBuilder& ty) {
   return ty.vec3<T>();
 }
 
+template <create_type_func_ptr create_type>
+type::Type* ty_vec3(const ProgramBuilder::TypesBuilder& ty) {
+  auto* type = create_type(ty);
+  return ty.vec3(type);
+}
+
 template <typename T>
 type::Type* ty_mat3x3(const ProgramBuilder::TypesBuilder& ty) {
   return ty.mat3x3<T>();
 }
 
-using create_type_func_ptr =
-    type::Type* (*)(const ProgramBuilder::TypesBuilder& ty);
+template <create_type_func_ptr create_type>
+type::Type* ty_mat3x3(const ProgramBuilder::TypesBuilder& ty) {
+  auto* type = create_type(ty);
+  return ty.mat3x3(type);
+}
 
 template <create_type_func_ptr create_type>
 type::Type* ty_alias(const ProgramBuilder::TypesBuilder& ty) {
diff --git a/src/resolver/type_validation_test.cc b/src/resolver/type_validation_test.cc
index ec2113c..973c718 100644
--- a/src/resolver/type_validation_test.cc
+++ b/src/resolver/type_validation_test.cc
@@ -21,6 +21,7 @@
 #include "gmock/gmock.h"
 
 namespace tint {
+namespace resolver {
 namespace {
 
 class ResolverTypeValidationTest : public resolver::TestHelper,
@@ -463,5 +464,51 @@
   EXPECT_TRUE(r()->Resolve()) << r()->error();
 }
 
+namespace GetCanonicalTests {
+struct Params {
+  create_type_func_ptr create_type;
+  create_type_func_ptr create_canonical_type;
+};
+
+static constexpr Params cases[] = {
+    Params{ty_bool_, ty_bool_},
+    Params{ty_alias<ty_bool_>, ty_bool_},
+    Params{ty_alias<ty_alias<ty_bool_>>, ty_bool_},
+
+    Params{ty_vec3<ty_f32>, ty_vec3<ty_f32>},
+    Params{ty_alias<ty_vec3<ty_f32>>, ty_vec3<ty_f32>},
+    Params{ty_alias<ty_alias<ty_vec3<ty_f32>>>, ty_vec3<ty_f32>},
+
+    Params{ty_vec3<ty_alias<ty_f32>>, ty_vec3<ty_f32>},
+    Params{ty_alias<ty_vec3<ty_alias<ty_f32>>>, ty_vec3<ty_f32>},
+    Params{ty_alias<ty_alias<ty_vec3<ty_alias<ty_f32>>>>, ty_vec3<ty_f32>},
+    Params{ty_alias<ty_alias<ty_vec3<ty_alias<ty_alias<ty_f32>>>>>,
+           ty_vec3<ty_f32>},
+
+    Params{ty_mat3x3<ty_alias<ty_f32>>, ty_mat3x3<ty_f32>},
+    Params{ty_alias<ty_mat3x3<ty_alias<ty_f32>>>, ty_mat3x3<ty_f32>},
+    Params{ty_alias<ty_alias<ty_mat3x3<ty_alias<ty_f32>>>>, ty_mat3x3<ty_f32>},
+    Params{ty_alias<ty_alias<ty_mat3x3<ty_alias<ty_alias<ty_f32>>>>>,
+           ty_mat3x3<ty_f32>},
+};
+
+using CanonicalTest = ResolverTestWithParam<Params>;
+TEST_P(CanonicalTest, All) {
+  auto& params = GetParam();
+
+  auto* type = params.create_type(ty);
+  auto* expected_canonical_type = params.create_canonical_type(ty);
+
+  auto* canonical_type = r()->Canonical(type);
+
+  EXPECT_EQ(canonical_type, expected_canonical_type);
+}
+INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest,
+                         CanonicalTest,
+                         testing::ValuesIn(cases));
+
+}  // namespace GetCanonicalTests
+
 }  // namespace
+}  // namespace resolver
 }  // namespace tint
diff --git a/src/semantic/sem_variable.cc b/src/semantic/sem_variable.cc
index 13b3820..c345154 100644
--- a/src/semantic/sem_variable.cc
+++ b/src/semantic/sem_variable.cc
@@ -15,6 +15,7 @@
 #include "src/semantic/variable.h"
 
 #include "src/ast/identifier_expression.h"
+#include "src/ast/variable.h"
 
 TINT_INSTANTIATE_TYPEINFO(tint::semantic::Variable);
 TINT_INSTANTIATE_TYPEINFO(tint::semantic::VariableUser);
@@ -29,6 +30,10 @@
 
 Variable::~Variable() = default;
 
+type::Type* Variable::DeclaredType() const {
+  return declaration_->declared_type();
+}
+
 VariableUser::VariableUser(ast::IdentifierExpression* declaration,
                            type::Type* type,
                            Statement* statement,
diff --git a/src/semantic/variable.h b/src/semantic/variable.h
index 7cae490..9873176 100644
--- a/src/semantic/variable.h
+++ b/src/semantic/variable.h
@@ -52,9 +52,12 @@
   /// @returns the AST declaration node
   const ast::Variable* Declaration() const { return declaration_; }
 
-  /// @returns the type for the variable
+  /// @returns the canonical type for the variable
   type::Type* Type() const { return type_; }
 
+  /// @returns the AST node's type. May be nullptr.
+  type::Type* DeclaredType() const;
+
   /// @returns the storage class for the variable
   ast::StorageClass StorageClass() const { return storage_class_; }
 
diff --git a/src/transform/binding_remapper.cc b/src/transform/binding_remapper.cc
index 4c8ed7b..aed0b0a 100644
--- a/src/transform/binding_remapper.cc
+++ b/src/transform/binding_remapper.cc
@@ -65,8 +65,7 @@
       auto ac_it = remappings->access_controls.find(from);
       if (ac_it != remappings->access_controls.end()) {
         ast::AccessControl ac = ac_it->second;
-        auto* var_ty = in->Sem().Get(var)->Type();
-        auto* ty = var_ty->UnwrapAliasIfNeeded();
+        auto* ty = in->Sem().Get(var)->Type();
         type::Type* inner_ty = nullptr;
         if (auto* old_ac = ty->As<type::AccessControl>()) {
           inner_ty = ctx.Clone(old_ac->type());
diff --git a/src/transform/canonicalize_entry_point_io.cc b/src/transform/canonicalize_entry_point_io.cc
index db1b5a7..72c88b4 100644
--- a/src/transform/canonicalize_entry_point_io.cc
+++ b/src/transform/canonicalize_entry_point_io.cc
@@ -72,11 +72,11 @@
       for (auto* param : func->params()) {
         auto param_name = ctx.Clone(param->symbol());
         auto* param_ty = ctx.src->Sem().Get(param)->Type();
+        auto* param_declared_ty = ctx.src->Sem().Get(param)->DeclaredType();
 
         ast::Expression* func_const_initializer = nullptr;
 
-        if (auto* struct_ty =
-                param_ty->UnwrapAliasIfNeeded()->As<type::Struct>()) {
+        if (auto* struct_ty = param_ty->As<type::Struct>()) {
           // Pull out all struct members and build initializer list.
           ast::ExpressionList init_values;
           for (auto* member : struct_ty->impl()->members()) {
@@ -97,7 +97,7 @@
           }
 
           func_const_initializer =
-              ctx.dst->Construct(ctx.Clone(param_ty), init_values);
+              ctx.dst->Construct(ctx.Clone(param_declared_ty), init_values);
         } else {
           ast::DecorationList new_decorations = RemoveDecorations(
               &ctx, param->decorations(), [](const ast::Decoration* deco) {
@@ -105,7 +105,7 @@
                                       ast::LocationDecoration>();
               });
           new_struct_members.push_back(ctx.dst->Member(
-              param_name, ctx.Clone(param_ty), new_decorations));
+              param_name, ctx.Clone(param_declared_ty), new_decorations));
           func_const_initializer =
               ctx.dst->MemberAccessor(new_struct_param_symbol, param_name);
         }
@@ -117,8 +117,8 @@
 
         // Create a function-scope const to replace the parameter.
         // Initialize it with the value extracted from the new struct parameter.
-        auto* func_const = ctx.dst->Const(param_name, ctx.Clone(param_ty),
-                                          func_const_initializer);
+        auto* func_const = ctx.dst->Const(
+            param_name, ctx.Clone(param_declared_ty), func_const_initializer);
         ctx.InsertBefore(func->body()->statements(), *func->body()->begin(),
                          ctx.dst->WrapInStatement(func_const));
 
diff --git a/src/transform/spirv.cc b/src/transform/spirv.cc
index 6945cbf..6adc931 100644
--- a/src/transform/spirv.cc
+++ b/src/transform/spirv.cc
@@ -141,7 +141,8 @@
 
     for (auto* param : func->params()) {
       Symbol new_var = HoistToInputVariables(
-          ctx, func, ctx.src->Sem().Get(param)->Type(), param->decorations());
+          ctx, func, ctx.src->Sem().Get(param)->Type(),
+          ctx.src->Sem().Get(param)->DeclaredType(), param->decorations());
 
       // Replace all uses of the function parameter with the new variable.
       for (auto* user : ctx.src->Sem().Get(param)->Users()) {
@@ -153,9 +154,9 @@
     if (!func->return_type()->Is<type::Void>()) {
       ast::StatementList stores;
       auto store_value_symbol = ctx.dst->Symbols().New();
-      HoistToOutputVariables(ctx, func, func->return_type(),
-                             func->return_type_decorations(), {},
-                             store_value_symbol, stores);
+      HoistToOutputVariables(
+          ctx, func, func->return_type(), func->return_type(),
+          func->return_type_decorations(), {}, store_value_symbol, stores);
 
       // Create a function that writes a return value to all output variables.
       auto* store_value =
@@ -251,8 +252,9 @@
     CloneContext& ctx,
     const ast::Function* func,
     type::Type* ty,
+    type::Type* declared_ty,
     const ast::DecorationList& decorations) const {
-  if (!ty->UnwrapAliasIfNeeded()->Is<type::Struct>()) {
+  if (!ty->Is<type::Struct>()) {
     // Base case: create a global variable and return.
     ast::DecorationList new_decorations =
         RemoveDecorations(&ctx, decorations, [](const ast::Decoration* deco) {
@@ -261,7 +263,7 @@
         });
     auto global_var_symbol = ctx.dst->Symbols().New();
     auto* global_var =
-        ctx.dst->Var(global_var_symbol, ctx.Clone(ty),
+        ctx.dst->Var(global_var_symbol, ctx.Clone(declared_ty),
                      ast::StorageClass::kInput, nullptr, new_decorations);
     ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func, global_var);
     return global_var_symbol;
@@ -269,10 +271,10 @@
 
   // Recurse into struct members and build the initializer list.
   ast::ExpressionList init_values;
-  auto* struct_ty = ty->UnwrapAliasIfNeeded()->As<type::Struct>();
+  auto* struct_ty = ty->As<type::Struct>();
   for (auto* member : struct_ty->impl()->members()) {
-    auto member_var =
-        HoistToInputVariables(ctx, func, member->type(), member->decorations());
+    auto member_var = HoistToInputVariables(
+        ctx, func, member->type(), member->type(), member->decorations());
     init_values.push_back(ctx.dst->Expr(member_var));
   }
 
@@ -283,8 +285,9 @@
   }
 
   // Create a function-scope variable for the struct.
-  auto* initializer = ctx.dst->Construct(ctx.Clone(ty), init_values);
-  auto* func_var = ctx.dst->Const(func_var_symbol, ctx.Clone(ty), initializer);
+  auto* initializer = ctx.dst->Construct(ctx.Clone(declared_ty), init_values);
+  auto* func_var =
+      ctx.dst->Const(func_var_symbol, ctx.Clone(declared_ty), initializer);
   ctx.InsertBefore(func->body()->statements(), *func->body()->begin(),
                    ctx.dst->WrapInStatement(func_var));
   return func_var_symbol;
@@ -293,12 +296,13 @@
 void Spirv::HoistToOutputVariables(CloneContext& ctx,
                                    const ast::Function* func,
                                    type::Type* ty,
+                                   type::Type* declared_ty,
                                    const ast::DecorationList& decorations,
                                    std::vector<Symbol> member_accesses,
                                    Symbol store_value,
                                    ast::StatementList& stores) const {
   // Base case.
-  if (!ty->UnwrapAliasIfNeeded()->Is<type::Struct>()) {
+  if (!ty->Is<type::Struct>()) {
     // Create a global variable.
     ast::DecorationList new_decorations =
         RemoveDecorations(&ctx, decorations, [](const ast::Decoration* deco) {
@@ -307,7 +311,7 @@
         });
     auto global_var_symbol = ctx.dst->Symbols().New();
     auto* global_var =
-        ctx.dst->Var(global_var_symbol, ctx.Clone(ty),
+        ctx.dst->Var(global_var_symbol, ctx.Clone(declared_ty),
                      ast::StorageClass::kOutput, nullptr, new_decorations);
     ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func, global_var);
 
@@ -322,11 +326,12 @@
   }
 
   // Recurse into struct members.
-  auto* struct_ty = ty->UnwrapAliasIfNeeded()->As<type::Struct>();
+  auto* struct_ty = ty->As<type::Struct>();
   for (auto* member : struct_ty->impl()->members()) {
     member_accesses.push_back(ctx.Clone(member->symbol()));
-    HoistToOutputVariables(ctx, func, member->type(), member->decorations(),
-                           member_accesses, store_value, stores);
+    HoistToOutputVariables(ctx, func, member->type(), member->type(),
+                           member->decorations(), member_accesses, store_value,
+                           stores);
     member_accesses.pop_back();
   }
 }
diff --git a/src/transform/spirv.h b/src/transform/spirv.h
index 7ac53e1..a80a1dc 100644
--- a/src/transform/spirv.h
+++ b/src/transform/spirv.h
@@ -60,6 +60,7 @@
   Symbol HoistToInputVariables(CloneContext& ctx,
                                const ast::Function* func,
                                type::Type* ty,
+                               type::Type* declared_ty,
                                const ast::DecorationList& decorations) const;
 
   /// Recursively create module-scope output variables for `ty` and build a list
@@ -73,6 +74,7 @@
   void HoistToOutputVariables(CloneContext& ctx,
                               const ast::Function* func,
                               type::Type* ty,
+                              type::Type* declared_ty,
                               const ast::DecorationList& decorations,
                               std::vector<Symbol> member_accesses,
                               Symbol store_value,
diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc
index 2c35689..81ae49a 100644
--- a/src/writer/hlsl/generator_impl.cc
+++ b/src/writer/hlsl/generator_impl.cc
@@ -1894,7 +1894,7 @@
     for (auto* var : func_sem->ReferencedModuleVariables()) {
       auto* decl = var->Declaration();
 
-      auto* unwrapped_type = var->Type()->UnwrapAll();
+      auto* unwrapped_type = var->DeclaredType()->UnwrapAll();
       if (!emitted_globals.emplace(decl->symbol()).second) {
         continue;  // Global already emitted
       }
@@ -1905,7 +1905,7 @@
         continue;  // Not interested in this type
       }
 
-      if (!EmitType(out, var->Type(), var->StorageClass(), "")) {
+      if (!EmitType(out, var->DeclaredType(), var->StorageClass(), "")) {
         return false;
       }
       out << " " << builder_.Symbols().NameFor(decl->symbol());
@@ -1915,9 +1915,7 @@
       if (unwrapped_type->Is<type::Texture>()) {
         register_space = "t";
         if (unwrapped_type->Is<type::StorageTexture>()) {
-          if (auto* ac = var->Type()
-                             ->UnwrapAliasIfNeeded()
-                             ->As<type::AccessControl>()) {
+          if (auto* ac = var->Type()->As<type::AccessControl>()) {
             if (!ac->IsReadOnly()) {
               register_space = "u";
             }
diff --git a/src/writer/wgsl/generator_impl.cc b/src/writer/wgsl/generator_impl.cc
index 22285c0..78242ec 100644
--- a/src/writer/wgsl/generator_impl.cc
+++ b/src/writer/wgsl/generator_impl.cc
@@ -325,7 +325,7 @@
 
     out_ << program_->Symbols().NameFor(v->symbol()) << " : ";
 
-    if (!EmitType(program_->Sem().Get(v)->Type())) {
+    if (!EmitType(program_->Sem().Get(v)->DeclaredType())) {
       return false;
     }
   }
@@ -599,7 +599,7 @@
   }
 
   out_ << " " << program_->Symbols().NameFor(var->symbol()) << " : ";
-  if (!EmitType(sem->Type())) {
+  if (!EmitType(sem->DeclaredType())) {
     return false;
   }