tint/sem: Return vector for Type::ElementOf(matrix)

Return the column vector type, instead of the column vector element
type. This matches what you'd get if you were to index the matrix.

DeepestElementOf() can be used to easily obtain the matrix column
element type.

Change-Id: I5293f4cca205c9e378253ac67880bf9d998814aa
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/94327
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/resolver/materialize_test.cc b/src/tint/resolver/materialize_test.cc
index 76dd62e..5f4a41f 100644
--- a/src/tint/resolver/materialize_test.cc
+++ b/src/tint/resolver/materialize_test.cc
@@ -325,7 +325,7 @@
         EXPECT_TYPE(expr->ConstantValue().Type(), target_sem_ty);
 
         uint32_t num_elems = 0;
-        const sem::Type* target_sem_el_ty = sem::Type::ElementOf(target_sem_ty, &num_elems);
+        const sem::Type* target_sem_el_ty = sem::Type::DeepestElementOf(target_sem_ty, &num_elems);
         EXPECT_TYPE(expr->ConstantValue().ElementType(), target_sem_el_ty);
         expr->ConstantValue().WithElements([&](auto&& vec) {
             using VEC_TY = std::decay_t<decltype(vec)>;
@@ -738,7 +738,8 @@
         EXPECT_TYPE(expr->ConstantValue().Type(), expected_sem_ty);
 
         uint32_t num_elems = 0;
-        const sem::Type* expected_sem_el_ty = sem::Type::ElementOf(expected_sem_ty, &num_elems);
+        const sem::Type* expected_sem_el_ty =
+            sem::Type::DeepestElementOf(expected_sem_ty, &num_elems);
         EXPECT_TYPE(expr->ConstantValue().ElementType(), expected_sem_el_ty);
         expr->ConstantValue().WithElements([&](auto&& vec) {
             using VEC_TY = std::decay_t<decltype(vec)>;
diff --git a/src/tint/resolver/resolver_constants.cc b/src/tint/resolver/resolver_constants.cc
index 95b494e..ad89f3e 100644
--- a/src/tint/resolver/resolver_constants.cc
+++ b/src/tint/resolver/resolver_constants.cc
@@ -152,13 +152,10 @@
 }  // namespace
 
 sem::Constant Resolver::EvaluateConstantValue(const ast::Expression* expr, const sem::Type* type) {
-    if (auto* e = expr->As<ast::LiteralExpression>()) {
-        return EvaluateConstantValue(e, type);
-    }
-    if (auto* e = expr->As<ast::CallExpression>()) {
-        return EvaluateConstantValue(e, type);
-    }
-    return {};
+    return Switch(
+        expr,  //
+        [&](const ast::LiteralExpression* e) { return EvaluateConstantValue(e, type); },
+        [&](const ast::CallExpression* e) { return EvaluateConstantValue(e, type); });
 }
 
 sem::Constant Resolver::EvaluateConstantValue(const ast::LiteralExpression* literal,
@@ -178,10 +175,10 @@
 
 sem::Constant Resolver::EvaluateConstantValue(const ast::CallExpression* call,
                                               const sem::Type* ty) {
-    uint32_t result_size = 0;
-    auto* el_ty = sem::Type::ElementOf(ty, &result_size);
+    uint32_t num_elems = 0;
+    auto* el_ty = sem::Type::DeepestElementOf(ty, &num_elems);
     if (!el_ty) {
-        return sem::Constant{};
+        return {};
     }
 
     // ElementOf() will also return the element type of array, which we do not support.
@@ -194,16 +191,16 @@
         return Switch(
             el_ty,
             [&](const sem::AbstractInt*) {
-                return sem::Constant(ty, std::vector(result_size, AInt(0)));
+                return sem::Constant(ty, std::vector(num_elems, AInt(0)));
             },
             [&](const sem::AbstractFloat*) {
-                return sem::Constant(ty, std::vector(result_size, AFloat(0)));
+                return sem::Constant(ty, std::vector(num_elems, AFloat(0)));
             },
-            [&](const sem::I32*) { return sem::Constant(ty, std::vector(result_size, AInt(0))); },
-            [&](const sem::U32*) { return sem::Constant(ty, std::vector(result_size, AInt(0))); },
-            [&](const sem::F32*) { return sem::Constant(ty, std::vector(result_size, AFloat(0))); },
-            [&](const sem::F16*) { return sem::Constant(ty, std::vector(result_size, AFloat(0))); },
-            [&](const sem::Bool*) { return sem::Constant(ty, std::vector(result_size, AInt(0))); });
+            [&](const sem::I32*) { return sem::Constant(ty, std::vector(num_elems, AInt(0))); },
+            [&](const sem::U32*) { return sem::Constant(ty, std::vector(num_elems, AInt(0))); },
+            [&](const sem::F32*) { return sem::Constant(ty, std::vector(num_elems, AFloat(0))); },
+            [&](const sem::F16*) { return sem::Constant(ty, std::vector(num_elems, AFloat(0))); },
+            [&](const sem::Bool*) { return sem::Constant(ty, std::vector(num_elems, AInt(0))); });
     }
 
     // Build value for type_ctor from each child value by converting to type_ctor's type.
@@ -235,18 +232,27 @@
         }
     }
 
-    // Splat single-value initializers
-    std::visit(
+    if (!elements) {
+        return {};
+    }
+
+    return std::visit(
         [&](auto&& v) {
-            if (v.size() == 1) {
-                for (uint32_t i = 0; i < result_size - 1; ++i) {
-                    v.emplace_back(v[0]);
+            if (num_elems != v.size()) {
+                if (v.size() == 1) {
+                    // Splat single-value initializers
+                    for (uint32_t i = 0; i < num_elems - 1; ++i) {
+                        v.emplace_back(v[0]);
+                    }
+                } else {
+                    // Provided number of arguments does not match the required number of elements.
+                    // Validation should error here.
+                    return sem::Constant{};
                 }
             }
+            return sem::Constant(ty, std::move(elements.value()));
         },
         elements.value());
-
-    return sem::Constant(ty, std::move(elements.value()));
 }
 
 utils::Result<sem::Constant> Resolver::ConvertValue(const sem::Constant& value,
@@ -256,7 +262,7 @@
         return value;
     }
 
-    auto* el_ty = sem::Type::ElementOf(ty);
+    auto* el_ty = sem::Type::DeepestElementOf(ty);
     if (el_ty == nullptr) {
         return sem::Constant{};
     }
diff --git a/src/tint/resolver/resolver_constants_test.cc b/src/tint/resolver/resolver_constants_test.cc
index bbdbfea..c937fa1 100644
--- a/src/tint/resolver/resolver_constants_test.cc
+++ b/src/tint/resolver/resolver_constants_test.cc
@@ -93,9 +93,10 @@
 
     auto* sem = Sem().Get(expr);
     EXPECT_NE(sem, nullptr);
-    ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
-    EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::I32>());
-    EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
+    auto* vec = sem->Type()->As<sem::Vector>();
+    ASSERT_NE(vec, nullptr);
+    EXPECT_TRUE(vec->type()->Is<sem::I32>());
+    EXPECT_EQ(vec->Width(), 3u);
     EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
     EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::I32>());
     ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
@@ -112,9 +113,10 @@
 
     auto* sem = Sem().Get(expr);
     EXPECT_NE(sem, nullptr);
-    ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
-    EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::U32>());
-    EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
+    auto* vec = sem->Type()->As<sem::Vector>();
+    ASSERT_NE(vec, nullptr);
+    EXPECT_TRUE(vec->type()->Is<sem::U32>());
+    EXPECT_EQ(vec->Width(), 3u);
     EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
     EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::U32>());
     ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
@@ -131,9 +133,10 @@
 
     auto* sem = Sem().Get(expr);
     EXPECT_NE(sem, nullptr);
-    ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
-    EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::F32>());
-    EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
+    auto* vec = sem->Type()->As<sem::Vector>();
+    ASSERT_NE(vec, nullptr);
+    EXPECT_TRUE(vec->type()->Is<sem::F32>());
+    EXPECT_EQ(vec->Width(), 3u);
     EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
     EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F32>());
     ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
@@ -150,9 +153,10 @@
 
     auto* sem = Sem().Get(expr);
     EXPECT_NE(sem, nullptr);
-    ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
-    EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::Bool>());
-    EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
+    auto* vec = sem->Type()->As<sem::Vector>();
+    ASSERT_NE(vec, nullptr);
+    EXPECT_TRUE(vec->type()->Is<sem::Bool>());
+    EXPECT_EQ(vec->Width(), 3u);
     EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
     EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::Bool>());
     ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
@@ -169,9 +173,10 @@
 
     auto* sem = Sem().Get(expr);
     EXPECT_NE(sem, nullptr);
-    ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
-    EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::I32>());
-    EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
+    auto* vec = sem->Type()->As<sem::Vector>();
+    ASSERT_NE(vec, nullptr);
+    EXPECT_TRUE(vec->type()->Is<sem::I32>());
+    EXPECT_EQ(vec->Width(), 3u);
     EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
     EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::I32>());
     ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
@@ -188,9 +193,10 @@
 
     auto* sem = Sem().Get(expr);
     EXPECT_NE(sem, nullptr);
-    ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
-    EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::U32>());
-    EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
+    auto* vec = sem->Type()->As<sem::Vector>();
+    ASSERT_NE(vec, nullptr);
+    EXPECT_TRUE(vec->type()->Is<sem::U32>());
+    EXPECT_EQ(vec->Width(), 3u);
     EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
     EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::U32>());
     ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
@@ -207,9 +213,10 @@
 
     auto* sem = Sem().Get(expr);
     EXPECT_NE(sem, nullptr);
-    ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
-    EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::F32>());
-    EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
+    auto* vec = sem->Type()->As<sem::Vector>();
+    ASSERT_NE(vec, nullptr);
+    EXPECT_TRUE(vec->type()->Is<sem::F32>());
+    EXPECT_EQ(vec->Width(), 3u);
     EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
     EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F32>());
     ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
@@ -226,9 +233,10 @@
 
     auto* sem = Sem().Get(expr);
     EXPECT_NE(sem, nullptr);
-    ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
-    EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::Bool>());
-    EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
+    auto* vec = sem->Type()->As<sem::Vector>();
+    ASSERT_NE(vec, nullptr);
+    EXPECT_TRUE(vec->type()->Is<sem::Bool>());
+    EXPECT_EQ(vec->Width(), 3u);
     EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
     EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::Bool>());
     ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
@@ -245,9 +253,10 @@
 
     auto* sem = Sem().Get(expr);
     EXPECT_NE(sem, nullptr);
-    ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
-    EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::I32>());
-    EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
+    auto* vec = sem->Type()->As<sem::Vector>();
+    ASSERT_NE(vec, nullptr);
+    EXPECT_TRUE(vec->type()->Is<sem::I32>());
+    EXPECT_EQ(vec->Width(), 3u);
     EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
     EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::I32>());
     ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
@@ -264,9 +273,10 @@
 
     auto* sem = Sem().Get(expr);
     EXPECT_NE(sem, nullptr);
-    ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
-    EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::U32>());
-    EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
+    auto* vec = sem->Type()->As<sem::Vector>();
+    ASSERT_NE(vec, nullptr);
+    EXPECT_TRUE(vec->type()->Is<sem::U32>());
+    EXPECT_EQ(vec->Width(), 3u);
     EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
     EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::U32>());
     ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
@@ -283,9 +293,10 @@
 
     auto* sem = Sem().Get(expr);
     EXPECT_NE(sem, nullptr);
-    ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
-    EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::F32>());
-    EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
+    auto* vec = sem->Type()->As<sem::Vector>();
+    ASSERT_NE(vec, nullptr);
+    EXPECT_TRUE(vec->type()->Is<sem::F32>());
+    EXPECT_EQ(vec->Width(), 3u);
     EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
     EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F32>());
     ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
@@ -302,9 +313,10 @@
 
     auto* sem = Sem().Get(expr);
     EXPECT_NE(sem, nullptr);
-    ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
-    EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::Bool>());
-    EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
+    auto* vec = sem->Type()->As<sem::Vector>();
+    ASSERT_NE(vec, nullptr);
+    EXPECT_TRUE(vec->type()->Is<sem::Bool>());
+    EXPECT_EQ(vec->Width(), 3u);
     EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
     EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::Bool>());
     ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
@@ -321,9 +333,10 @@
 
     auto* sem = Sem().Get(expr);
     EXPECT_NE(sem, nullptr);
-    ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
-    EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::I32>());
-    EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
+    auto* vec = sem->Type()->As<sem::Vector>();
+    ASSERT_NE(vec, nullptr);
+    EXPECT_TRUE(vec->type()->Is<sem::I32>());
+    EXPECT_EQ(vec->Width(), 3u);
     EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
     EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::I32>());
     ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
@@ -340,9 +353,10 @@
 
     auto* sem = Sem().Get(expr);
     EXPECT_NE(sem, nullptr);
-    ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
-    EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::U32>());
-    EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
+    auto* vec = sem->Type()->As<sem::Vector>();
+    ASSERT_NE(vec, nullptr);
+    EXPECT_TRUE(vec->type()->Is<sem::U32>());
+    EXPECT_EQ(vec->Width(), 3u);
     EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
     EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::U32>());
     ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
@@ -359,9 +373,10 @@
 
     auto* sem = Sem().Get(expr);
     EXPECT_NE(sem, nullptr);
-    ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
-    EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::F32>());
-    EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
+    auto* vec = sem->Type()->As<sem::Vector>();
+    ASSERT_NE(vec, nullptr);
+    EXPECT_TRUE(vec->type()->Is<sem::F32>());
+    EXPECT_EQ(vec->Width(), 3u);
     EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
     EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F32>());
     ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
@@ -378,9 +393,10 @@
 
     auto* sem = Sem().Get(expr);
     EXPECT_NE(sem, nullptr);
-    ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
-    EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::Bool>());
-    EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
+    auto* vec = sem->Type()->As<sem::Vector>();
+    ASSERT_NE(vec, nullptr);
+    EXPECT_TRUE(vec->type()->Is<sem::Bool>());
+    EXPECT_EQ(vec->Width(), 3u);
     EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
     EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::Bool>());
     ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
@@ -397,9 +413,10 @@
 
     auto* sem = Sem().Get(expr);
     EXPECT_NE(sem, nullptr);
-    ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
-    EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::I32>());
-    EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
+    auto* vec = sem->Type()->As<sem::Vector>();
+    ASSERT_NE(vec, nullptr);
+    EXPECT_TRUE(vec->type()->Is<sem::I32>());
+    EXPECT_EQ(vec->Width(), 3u);
     EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
     EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::I32>());
     ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
@@ -416,9 +433,10 @@
 
     auto* sem = Sem().Get(expr);
     EXPECT_NE(sem, nullptr);
-    ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
-    EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::F32>());
-    EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
+    auto* vec = sem->Type()->As<sem::Vector>();
+    ASSERT_NE(vec, nullptr);
+    EXPECT_TRUE(vec->type()->Is<sem::F32>());
+    EXPECT_EQ(vec->Width(), 3u);
     EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
     EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F32>());
     ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
@@ -435,9 +453,10 @@
 
     auto* sem = Sem().Get(expr);
     EXPECT_NE(sem, nullptr);
-    ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
-    EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::I32>());
-    EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
+    auto* vec = sem->Type()->As<sem::Vector>();
+    ASSERT_NE(vec, nullptr);
+    EXPECT_TRUE(vec->type()->Is<sem::I32>());
+    EXPECT_EQ(vec->Width(), 3u);
     EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
     EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::I32>());
     ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
@@ -454,9 +473,10 @@
 
     auto* sem = Sem().Get(expr);
     EXPECT_NE(sem, nullptr);
-    ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
-    EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::U32>());
-    EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
+    auto* vec = sem->Type()->As<sem::Vector>();
+    ASSERT_NE(vec, nullptr);
+    EXPECT_TRUE(vec->type()->Is<sem::U32>());
+    EXPECT_EQ(vec->Width(), 3u);
     EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
     EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::U32>());
     ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
@@ -478,9 +498,10 @@
 
     auto* sem = Sem().Get(expr);
     EXPECT_NE(sem, nullptr);
-    ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
-    EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::F16>());
-    EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
+    auto* vec = sem->Type()->As<sem::Vector>();
+    ASSERT_NE(vec, nullptr);
+    EXPECT_TRUE(vec->type()->Is<sem::F16>());
+    EXPECT_EQ(vec->Width(), 3u);
     EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
     EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F16>());
     ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
@@ -500,9 +521,10 @@
 
     auto* sem = Sem().Get(expr);
     EXPECT_NE(sem, nullptr);
-    ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
-    EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::F16>());
-    EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
+    auto* vec = sem->Type()->As<sem::Vector>();
+    ASSERT_NE(vec, nullptr);
+    EXPECT_TRUE(vec->type()->Is<sem::F16>());
+    EXPECT_EQ(vec->Width(), 3u);
     EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
     EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F16>());
     ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
@@ -511,5 +533,80 @@
     EXPECT_EQ(sem->ConstantValue().Element<AFloat>(2).value, 0.0);
 }
 
+TEST_F(ResolverConstantsTest, Mat2x3_ZeroInit_f32) {
+    auto* expr = mat2x3<f32>();
+    WrapInFunction(expr);
+
+    EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+    auto* sem = Sem().Get(expr);
+    EXPECT_NE(sem, nullptr);
+    auto* mat = sem->Type()->As<sem::Matrix>();
+    ASSERT_NE(mat, nullptr);
+    EXPECT_TRUE(mat->type()->Is<sem::F32>());
+    EXPECT_EQ(mat->columns(), 2u);
+    EXPECT_EQ(mat->rows(), 3u);
+    EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
+    EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F32>());
+    ASSERT_EQ(sem->ConstantValue().ElementCount(), 6u);
+    EXPECT_EQ(sem->ConstantValue().Element<f32>(0).value, 0._f);
+    EXPECT_EQ(sem->ConstantValue().Element<f32>(1).value, 0._f);
+    EXPECT_EQ(sem->ConstantValue().Element<f32>(2).value, 0._f);
+    EXPECT_EQ(sem->ConstantValue().Element<f32>(3).value, 0._f);
+    EXPECT_EQ(sem->ConstantValue().Element<f32>(4).value, 0._f);
+    EXPECT_EQ(sem->ConstantValue().Element<f32>(5).value, 0._f);
+}
+
+TEST_F(ResolverConstantsTest, Mat3x2_Construct_Scalars_af) {
+    auto* expr = Construct(ty.mat(nullptr, 3, 2), 1.0_a, 2.0_a, 3.0_a, 4.0_a, 5.0_a, 6.0_a);
+    WrapInFunction(expr);
+
+    EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+    auto* sem = Sem().Get(expr);
+    EXPECT_NE(sem, nullptr);
+    auto* mat = sem->Type()->As<sem::Matrix>();
+    ASSERT_NE(mat, nullptr);
+    EXPECT_TRUE(mat->type()->Is<sem::F32>());
+    EXPECT_EQ(mat->columns(), 3u);
+    EXPECT_EQ(mat->rows(), 2u);
+    EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
+    EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F32>());
+    ASSERT_EQ(sem->ConstantValue().ElementCount(), 6u);
+    EXPECT_EQ(sem->ConstantValue().Element<AFloat>(0).value, 1._a);
+    EXPECT_EQ(sem->ConstantValue().Element<AFloat>(1).value, 2._a);
+    EXPECT_EQ(sem->ConstantValue().Element<AFloat>(2).value, 3._a);
+    EXPECT_EQ(sem->ConstantValue().Element<AFloat>(3).value, 4._a);
+    EXPECT_EQ(sem->ConstantValue().Element<AFloat>(4).value, 5._a);
+    EXPECT_EQ(sem->ConstantValue().Element<AFloat>(5).value, 6._a);
+}
+
+TEST_F(ResolverConstantsTest, Mat3x2_Construct_Columns_af) {
+    auto* expr = Construct(ty.mat(nullptr, 3, 2),           //
+                           vec(nullptr, 2u, 1.0_a, 2.0_a),  //
+                           vec(nullptr, 2u, 3.0_a, 4.0_a),  //
+                           vec(nullptr, 2u, 5.0_a, 6.0_a));
+    WrapInFunction(expr);
+
+    EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+    auto* sem = Sem().Get(expr);
+    EXPECT_NE(sem, nullptr);
+    auto* mat = sem->Type()->As<sem::Matrix>();
+    ASSERT_NE(mat, nullptr);
+    EXPECT_TRUE(mat->type()->Is<sem::F32>());
+    EXPECT_EQ(mat->columns(), 3u);
+    EXPECT_EQ(mat->rows(), 2u);
+    EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
+    EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F32>());
+    ASSERT_EQ(sem->ConstantValue().ElementCount(), 6u);
+    EXPECT_EQ(sem->ConstantValue().Element<AFloat>(0).value, 1._a);
+    EXPECT_EQ(sem->ConstantValue().Element<AFloat>(1).value, 2._a);
+    EXPECT_EQ(sem->ConstantValue().Element<AFloat>(2).value, 3._a);
+    EXPECT_EQ(sem->ConstantValue().Element<AFloat>(3).value, 4._a);
+    EXPECT_EQ(sem->ConstantValue().Element<AFloat>(4).value, 5._a);
+    EXPECT_EQ(sem->ConstantValue().Element<AFloat>(5).value, 6._a);
+}
+
 }  // namespace
 }  // namespace tint::resolver
diff --git a/src/tint/sem/constant.cc b/src/tint/sem/constant.cc
index 8086992..921e9a7 100644
--- a/src/tint/sem/constant.cc
+++ b/src/tint/sem/constant.cc
@@ -98,7 +98,7 @@
     diag::List diag;
     if (ty->is_abstract_or_scalar() || ty->IsAnyOf<Vector, Matrix>()) {
         uint32_t count = 0;
-        auto* el_ty = Type::ElementOf(ty, &count);
+        auto* el_ty = Type::DeepestElementOf(ty, &count);
         if (num_elements != count) {
             TINT_ICE(Semantic, diag) << "sem::Constant() type <-> element mismatch. type: '"
                                      << ty->TypeInfo().name << "' element: " << num_elements;
diff --git a/src/tint/sem/constant.h b/src/tint/sem/constant.h
index 4f7132a..42683cf 100644
--- a/src/tint/sem/constant.h
+++ b/src/tint/sem/constant.h
@@ -104,10 +104,10 @@
         return std::visit([](auto&& v) { return v.size(); }, elems_);
     }
 
-    /// @returns the element type of the Constant
+    /// @returns the flattened element type of the Constant
     const sem::Type* ElementType() const { return elem_type_; }
 
-    /// @returns the constant's elements
+    /// @returns the constant's flattened elements
     const Elements& GetElements() const { return elems_; }
 
     /// WithElements calls the function `f` with the vector of elements as either AFloats or AInts
diff --git a/src/tint/sem/type.cc b/src/tint/sem/type.cc
index d92b236..9d4c469 100644
--- a/src/tint/sem/type.cc
+++ b/src/tint/sem/type.cc
@@ -220,9 +220,9 @@
         },
         [&](const Matrix* m) {
             if (count) {
-                *count = m->columns() * m->rows();
+                *count = m->columns();
             }
-            return m->type();
+            return m->ColumnType();
         },
         [&](const Array* a) {
             if (count) {
diff --git a/src/tint/sem/type.h b/src/tint/sem/type.h
index ac863207..25f3a43 100644
--- a/src/tint/sem/type.h
+++ b/src/tint/sem/type.h
@@ -132,15 +132,22 @@
     /// @param ty the type to obtain the element type from
     /// @param count if not null, then this is assigned the number of child elements in the type.
     /// For example, the count of an `array<vec3<f32>, 5>` type would be 5.
-    /// @returns `ty` if `ty` is an abstract or scalar, or the element type if ty is a vector,
-    /// matrix or array, otherwise nullptr.
+    /// @returns
+    ///   * `ty` if `ty` is an abstract or scalar
+    ///   * the element type if `ty` is a vector or array
+    ///   * the column type if `ty` is a matrix
+    ///   * `nullptr` if `ty` is none of the above
     static const Type* ElementOf(const Type* ty, uint32_t* count = nullptr);
 
     /// @param ty the type to obtain the deepest element type from
     /// @param count if not null, then this is assigned the full number of most deeply nested
     /// elements in the type. For example, the count of an `array<vec3<f32>, 5>` type would be 15.
-    /// @returns `ty` if `ty` is an abstract or scalar, or the element type if ty is a vector,
-    /// matrix, or the deepest element type if ty is an array, otherwise nullptr.
+    /// @returns
+    ///   * `ty` if `ty` is an abstract or scalar
+    ///   * the element type if `ty` is a vector
+    ///   * the matrix element type if `ty` is a matrix
+    ///   * the deepest element type if `ty` is an array
+    ///   * `nullptr` if `ty` is none of the above
     static const Type* DeepestElementOf(const Type* ty, uint32_t* count = nullptr);
 
     /// @param types a pointer to a list of `const Type*`.
diff --git a/src/tint/sem/type_test.cc b/src/tint/sem/type_test.cc
index e604453..8458880 100644
--- a/src/tint/sem/type_test.cc
+++ b/src/tint/sem/type_test.cc
@@ -157,9 +157,9 @@
     EXPECT_TYPE(Type::ElementOf(vec4_f32), f32);
     EXPECT_TYPE(Type::ElementOf(vec3_u32), u32);
     EXPECT_TYPE(Type::ElementOf(vec3_i32), i32);
-    EXPECT_TYPE(Type::ElementOf(mat2x4_f32), f32);
-    EXPECT_TYPE(Type::ElementOf(mat4x2_f32), f32);
-    EXPECT_TYPE(Type::ElementOf(mat4x3_f16), f16);
+    EXPECT_TYPE(Type::ElementOf(mat2x4_f32), vec4_f32);
+    EXPECT_TYPE(Type::ElementOf(mat4x2_f32), vec2_f32);
+    EXPECT_TYPE(Type::ElementOf(mat4x3_f16), vec3_f16);
     EXPECT_TYPE(Type::ElementOf(str), nullptr);
     EXPECT_TYPE(Type::ElementOf(arr_i32), i32);
     EXPECT_TYPE(Type::ElementOf(arr_vec3_i32), vec3_i32);
@@ -195,14 +195,14 @@
     EXPECT_TYPE(Type::ElementOf(vec3_i32, &count), i32);
     EXPECT_EQ(count, 3u);
     count = 42;
-    EXPECT_TYPE(Type::ElementOf(mat2x4_f32, &count), f32);
-    EXPECT_EQ(count, 8u);
+    EXPECT_TYPE(Type::ElementOf(mat2x4_f32, &count), vec4_f32);
+    EXPECT_EQ(count, 2u);
     count = 42;
-    EXPECT_TYPE(Type::ElementOf(mat4x2_f32, &count), f32);
-    EXPECT_EQ(count, 8u);
+    EXPECT_TYPE(Type::ElementOf(mat4x2_f32, &count), vec2_f32);
+    EXPECT_EQ(count, 4u);
     count = 42;
-    EXPECT_TYPE(Type::ElementOf(mat4x3_f16, &count), f16);
-    EXPECT_EQ(count, 12u);
+    EXPECT_TYPE(Type::ElementOf(mat4x3_f16, &count), vec3_f16);
+    EXPECT_EQ(count, 4u);
     count = 42;
     EXPECT_TYPE(Type::ElementOf(str, &count), nullptr);
     EXPECT_EQ(count, 0u);