resolver: Fixes for bitcasts

Fix dependency graph traversal for bitcasts. These were not being traversed, leading to an ICE if the bitcast type was an alias, as the symbol was not resolved for later use by the resolver.

Add missing validation for bitcasts. We were permitting any bitcast that wasn't a being cast to a pointer type, when the spec only allows:
 * numeric_scalar to numeric_scalar
 * vecN<numeric_scalar> to vecN<numeric_scalar>

Add lots of tests.

Fixed: chromium:1276320
Change-Id: I9e5487ec7649ac543f73fc878e7e282bf932d8cb
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/71681
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 9db881d..de4d2d4 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -673,6 +673,7 @@
     resolver/assignment_validation_test.cc
     resolver/atomics_test.cc
     resolver/atomics_validation_test.cc
+    resolver/bitcast_validation_test.cc
     resolver/builtins_validation_test.cc
     resolver/call_test.cc
     resolver/call_validation_test.cc
diff --git a/src/resolver/bitcast_validation_test.cc b/src/resolver/bitcast_validation_test.cc
new file mode 100644
index 0000000..d4ce082
--- /dev/null
+++ b/src/resolver/bitcast_validation_test.cc
@@ -0,0 +1,228 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/ast/bitcast_expression.h"
+#include "src/resolver/resolver.h"
+#include "src/resolver/resolver_test_helper.h"
+
+#include "gmock/gmock.h"
+
+namespace tint {
+namespace resolver {
+namespace {
+
+struct Type {
+  template <typename T>
+  static constexpr Type Create() {
+    return Type{builder::DataType<T>::AST, builder::DataType<T>::Sem,
+                builder::DataType<T>::Expr};
+  }
+
+  builder::ast_type_func_ptr ast;
+  builder::sem_type_func_ptr sem;
+  builder::ast_expr_func_ptr expr;
+};
+
+static constexpr Type kNumericScalars[] = {
+    Type::Create<builder::f32>(),
+    Type::Create<builder::i32>(),
+    Type::Create<builder::u32>(),
+};
+static constexpr Type kVec2NumericScalars[] = {
+    Type::Create<builder::vec2<builder::f32>>(),
+    Type::Create<builder::vec2<builder::i32>>(),
+    Type::Create<builder::vec2<builder::u32>>(),
+};
+static constexpr Type kVec3NumericScalars[] = {
+    Type::Create<builder::vec3<builder::f32>>(),
+    Type::Create<builder::vec3<builder::i32>>(),
+    Type::Create<builder::vec3<builder::u32>>(),
+};
+static constexpr Type kVec4NumericScalars[] = {
+    Type::Create<builder::vec4<builder::f32>>(),
+    Type::Create<builder::vec4<builder::i32>>(),
+    Type::Create<builder::vec4<builder::u32>>(),
+};
+static constexpr Type kInvalid[] = {
+    // A non-exhaustive selection of uncastable types
+    Type::Create<bool>(),
+    Type::Create<builder::vec2<bool>>(),
+    Type::Create<builder::vec3<bool>>(),
+    Type::Create<builder::vec4<bool>>(),
+    Type::Create<builder::array<2, builder::i32>>(),
+    Type::Create<builder::array<3, builder::u32>>(),
+    Type::Create<builder::array<4, builder::f32>>(),
+    Type::Create<builder::array<5, bool>>(),
+    Type::Create<builder::mat2x2<builder::f32>>(),
+    Type::Create<builder::mat3x3<builder::f32>>(),
+    Type::Create<builder::mat4x4<builder::f32>>(),
+    Type::Create<builder::ptr<builder::i32>>(),
+    Type::Create<builder::ptr<builder::array<2, builder::i32>>>(),
+    Type::Create<builder::ptr<builder::mat2x2<builder::f32>>>(),
+};
+
+using ResolverBitcastValidationTest =
+    ResolverTestWithParam<std::tuple<Type, Type>>;
+
+////////////////////////////////////////////////////////////////////////////////
+// Valid bitcasts
+////////////////////////////////////////////////////////////////////////////////
+using ResolverBitcastValidationTestPass = ResolverBitcastValidationTest;
+TEST_P(ResolverBitcastValidationTestPass, Test) {
+  auto src = std::get<0>(GetParam());
+  auto dst = std::get<1>(GetParam());
+
+  auto* cast = Bitcast(dst.ast(*this), src.expr(*this, 0));
+  WrapInFunction(cast);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+  EXPECT_EQ(TypeOf(cast), dst.sem(*this));
+}
+INSTANTIATE_TEST_SUITE_P(Scalars,
+                         ResolverBitcastValidationTestPass,
+                         testing::Combine(testing::ValuesIn(kNumericScalars),
+                                          testing::ValuesIn(kNumericScalars)));
+INSTANTIATE_TEST_SUITE_P(
+    Vec2,
+    ResolverBitcastValidationTestPass,
+    testing::Combine(testing::ValuesIn(kVec2NumericScalars),
+                     testing::ValuesIn(kVec2NumericScalars)));
+INSTANTIATE_TEST_SUITE_P(
+    Vec3,
+    ResolverBitcastValidationTestPass,
+    testing::Combine(testing::ValuesIn(kVec3NumericScalars),
+                     testing::ValuesIn(kVec3NumericScalars)));
+INSTANTIATE_TEST_SUITE_P(
+    Vec4,
+    ResolverBitcastValidationTestPass,
+    testing::Combine(testing::ValuesIn(kVec4NumericScalars),
+                     testing::ValuesIn(kVec4NumericScalars)));
+
+////////////////////////////////////////////////////////////////////////////////
+// Invalid source type for bitcasts
+////////////////////////////////////////////////////////////////////////////////
+using ResolverBitcastValidationTestInvalidSrcTy = ResolverBitcastValidationTest;
+TEST_P(ResolverBitcastValidationTestInvalidSrcTy, Test) {
+  auto src = std::get<0>(GetParam());
+  auto dst = std::get<1>(GetParam());
+
+  auto* cast = Bitcast(dst.ast(*this), Expr(Source{{12, 34}}, "src"));
+  WrapInFunction(Const("src", nullptr, src.expr(*this, 0)), cast);
+
+  auto expected = "12:34 error: '" + src.sem(*this)->FriendlyName(Symbols()) +
+                  "' cannot be bitcast";
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(), expected);
+}
+INSTANTIATE_TEST_SUITE_P(Scalars,
+                         ResolverBitcastValidationTestInvalidSrcTy,
+                         testing::Combine(testing::ValuesIn(kInvalid),
+                                          testing::ValuesIn(kNumericScalars)));
+INSTANTIATE_TEST_SUITE_P(
+    Vec2,
+    ResolverBitcastValidationTestInvalidSrcTy,
+    testing::Combine(testing::ValuesIn(kInvalid),
+                     testing::ValuesIn(kVec2NumericScalars)));
+INSTANTIATE_TEST_SUITE_P(
+    Vec3,
+    ResolverBitcastValidationTestInvalidSrcTy,
+    testing::Combine(testing::ValuesIn(kInvalid),
+                     testing::ValuesIn(kVec3NumericScalars)));
+INSTANTIATE_TEST_SUITE_P(
+    Vec4,
+    ResolverBitcastValidationTestInvalidSrcTy,
+    testing::Combine(testing::ValuesIn(kInvalid),
+                     testing::ValuesIn(kVec4NumericScalars)));
+
+////////////////////////////////////////////////////////////////////////////////
+// Invalid target type for bitcasts
+////////////////////////////////////////////////////////////////////////////////
+using ResolverBitcastValidationTestInvalidDstTy = ResolverBitcastValidationTest;
+TEST_P(ResolverBitcastValidationTestInvalidDstTy, Test) {
+  auto src = std::get<0>(GetParam());
+  auto dst = std::get<1>(GetParam());
+
+  // Use an alias so we can put a Source on the bitcast type
+  Alias("T", dst.ast(*this));
+  WrapInFunction(
+      Bitcast(ty.type_name(Source{{12, 34}}, "T"), src.expr(*this, 0)));
+
+  auto expected = "12:34 error: cannot bitcast to '" +
+                  dst.sem(*this)->FriendlyName(Symbols()) + "'";
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(), expected);
+}
+INSTANTIATE_TEST_SUITE_P(Scalars,
+                         ResolverBitcastValidationTestInvalidDstTy,
+                         testing::Combine(testing::ValuesIn(kNumericScalars),
+                                          testing::ValuesIn(kInvalid)));
+INSTANTIATE_TEST_SUITE_P(
+    Vec2,
+    ResolverBitcastValidationTestInvalidDstTy,
+    testing::Combine(testing::ValuesIn(kVec2NumericScalars),
+                     testing::ValuesIn(kInvalid)));
+INSTANTIATE_TEST_SUITE_P(
+    Vec3,
+    ResolverBitcastValidationTestInvalidDstTy,
+    testing::Combine(testing::ValuesIn(kVec3NumericScalars),
+                     testing::ValuesIn(kInvalid)));
+INSTANTIATE_TEST_SUITE_P(
+    Vec4,
+    ResolverBitcastValidationTestInvalidDstTy,
+    testing::Combine(testing::ValuesIn(kVec4NumericScalars),
+                     testing::ValuesIn(kInvalid)));
+
+////////////////////////////////////////////////////////////////////////////////
+// Incompatible bitcast, but both src and dst types are valid
+////////////////////////////////////////////////////////////////////////////////
+using ResolverBitcastValidationTestIncompatible = ResolverBitcastValidationTest;
+TEST_P(ResolverBitcastValidationTestIncompatible, Test) {
+  auto src = std::get<0>(GetParam());
+  auto dst = std::get<1>(GetParam());
+
+  WrapInFunction(Bitcast(Source{{12, 34}}, dst.ast(*this), src.expr(*this, 0)));
+
+  auto expected = "12:34 error: cannot bitcast from '" +
+                  src.sem(*this)->FriendlyName(Symbols()) + "' to '" +
+                  dst.sem(*this)->FriendlyName(Symbols()) + "'";
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(), expected);
+}
+INSTANTIATE_TEST_SUITE_P(
+    ScalarToVec2,
+    ResolverBitcastValidationTestIncompatible,
+    testing::Combine(testing::ValuesIn(kNumericScalars),
+                     testing::ValuesIn(kVec2NumericScalars)));
+INSTANTIATE_TEST_SUITE_P(
+    Vec2ToVec3,
+    ResolverBitcastValidationTestIncompatible,
+    testing::Combine(testing::ValuesIn(kVec2NumericScalars),
+                     testing::ValuesIn(kVec3NumericScalars)));
+INSTANTIATE_TEST_SUITE_P(
+    Vec3ToVec4,
+    ResolverBitcastValidationTestIncompatible,
+    testing::Combine(testing::ValuesIn(kVec3NumericScalars),
+                     testing::ValuesIn(kVec4NumericScalars)));
+INSTANTIATE_TEST_SUITE_P(
+    Vec4ToScalar,
+    ResolverBitcastValidationTestIncompatible,
+    testing::Combine(testing::ValuesIn(kVec4NumericScalars),
+                     testing::ValuesIn(kNumericScalars)));
+
+}  // namespace
+}  // namespace resolver
+}  // namespace tint
diff --git a/src/resolver/dependency_graph.cc b/src/resolver/dependency_graph.cc
index 947be99..8627dcb 100644
--- a/src/resolver/dependency_graph.cc
+++ b/src/resolver/dependency_graph.cc
@@ -329,6 +329,9 @@
                   utils::Lookup(graph_.resolved_symbols, call->target.type));
             }
           }
+          if (auto* cast = expr->As<ast::BitcastExpression>()) {
+            TraverseType(cast->type);
+          }
           return ast::TraverseAction::Descend;
         });
   }
diff --git a/src/resolver/dependency_graph_test.cc b/src/resolver/dependency_graph_test.cc
index 5fd048d..dcbcd72 100644
--- a/src/resolver/dependency_graph_test.cc
+++ b/src/resolver/dependency_graph_test.cc
@@ -1282,6 +1282,7 @@
               Block(Assign(V, V)),                    //
               Else(V,                                 //
                    Block(Assign(V, V)))),             //
+           Ignore(Bitcast(T, V)),                     //
            For(Decl(Var(Sym(), T, V)),                //
                Equal(V, V),                           //
                Assign(V, V),                          //
diff --git a/src/resolver/ptr_ref_validation_test.cc b/src/resolver/ptr_ref_validation_test.cc
index 57aea32..367b209 100644
--- a/src/resolver/ptr_ref_validation_test.cc
+++ b/src/resolver/ptr_ref_validation_test.cc
@@ -171,19 +171,6 @@
             "'ptr<storage, i32, read_write>'");
 }
 
-TEST_F(ResolverTest, Expr_Bitcast_ptr) {
-  auto* vf = Var("vf", ty.f32());
-  auto* bitcast = create<ast::BitcastExpression>(
-      Source{{12, 34}}, ty.pointer<i32>(ast::StorageClass::kFunction),
-      Expr("vf"));
-  auto* ip =
-      Const("ip", ty.pointer<i32>(ast::StorageClass::kFunction), bitcast);
-  WrapInFunction(Decl(vf), Decl(ip));
-
-  EXPECT_FALSE(r()->Resolve());
-  EXPECT_EQ(r()->error(), "12:34 error: cannot cast to a pointer");
-}
-
 }  // namespace
 }  // namespace resolver
 }  // namespace tint
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index c53de55..abe47c4 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -1221,15 +1221,17 @@
   if (!ty) {
     return nullptr;
   }
-  if (ty->Is<sem::Pointer>()) {
-    AddError("cannot cast to a pointer", expr->source);
-    return nullptr;
-  }
 
   auto val = EvaluateConstantValue(expr, ty);
   auto* sem =
       builder_->create<sem::Expression>(expr, ty, current_statement_, val);
+
   sem->Behaviors() = inner->Behaviors();
+
+  if (!ValidateBitcast(expr, ty)) {
+    return nullptr;
+  }
+
   return sem;
 }
 
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index 2de937e..e7f6deb 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -238,6 +238,7 @@
   bool ValidateAtomic(const ast::Atomic* a, const sem::Atomic* s);
   bool ValidateAtomicVariable(const sem::Variable* var);
   bool ValidateAssignment(const ast::AssignmentStatement* a);
+  bool ValidateBitcast(const ast::BitcastExpression* cast, const sem::Type* to);
   bool ValidateBreakStatement(const sem::Statement* stmt);
   bool ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco,
                                  const sem::Type* storage_type,
diff --git a/src/resolver/resolver_test_helper.h b/src/resolver/resolver_test_helper.h
index 9f5ff52..bc50d93 100644
--- a/src/resolver/resolver_test_helper.h
+++ b/src/resolver/resolver_test_helper.h
@@ -171,6 +171,9 @@
 template <typename TO>
 using alias3 = alias<TO, 3>;
 
+template <typename TO>
+struct ptr {};
+
 using ast_type_func_ptr = const ast::Type* (*)(ProgramBuilder& b);
 using ast_expr_func_ptr = const ast::Expression* (*)(ProgramBuilder& b,
                                                      int elem_value);
@@ -387,6 +390,36 @@
   }
 };
 
+/// Helper for building pointer types and expressions
+template <typename T>
+struct DataType<ptr<T>> {
+  /// true if the pointer type is a composite type
+  static constexpr bool is_composite = false;
+
+  /// @param b the ProgramBuilder
+  /// @return a new AST alias type
+  static inline const ast::Type* AST(ProgramBuilder& b) {
+    return b.create<ast::Pointer>(DataType<T>::AST(b),
+                                  ast::StorageClass::kPrivate,
+                                  ast::Access::kReadWrite);
+  }
+  /// @param b the ProgramBuilder
+  /// @return the semantic aliased type
+  static inline const sem::Type* Sem(ProgramBuilder& b) {
+    return b.create<sem::Pointer>(DataType<T>::Sem(b),
+                                  ast::StorageClass::kPrivate,
+                                  ast::Access::kReadWrite);
+  }
+
+  /// @param b the ProgramBuilder
+  /// @return a new AST expression of the alias type
+  static inline const ast::Expression* Expr(ProgramBuilder& b, int /*unused*/) {
+    auto sym = b.Symbols().New("global_for_ptr");
+    b.Global(sym, DataType<T>::AST(b), ast::StorageClass::kPrivate);
+    return b.AddressOf(sym);
+  }
+};
+
 /// Helper for building array types and expressions
 template <int N, typename T>
 struct DataType<array<N, T>> {
@@ -401,7 +434,14 @@
   /// @param b the ProgramBuilder
   /// @return the semantic array type
   static inline const sem::Type* Sem(ProgramBuilder& b) {
-    return b.create<sem::Array>(DataType<T>::Sem(b), N);
+    auto* el = DataType<T>::Sem(b);
+    return b.create<sem::Array>(
+        /* element */ el,
+        /* count */ N,
+        /* align */ el->Align(),
+        /* size */ el->Size(),
+        /* stride */ el->Align(),
+        /* implicit_stride */ el->Align());
   }
   /// @param b the ProgramBuilder
   /// @param elem_value the value each element in the array will be initialized
diff --git a/src/resolver/resolver_validation.cc b/src/resolver/resolver_validation.cc
index f376335..7c6afef 100644
--- a/src/resolver/resolver_validation.cc
+++ b/src/resolver/resolver_validation.cc
@@ -1347,6 +1347,36 @@
   return true;
 }
 
+bool Resolver::ValidateBitcast(const ast::BitcastExpression* cast,
+                               const sem::Type* to) {
+  auto* from = TypeOf(cast->expr)->UnwrapRef();
+  if (!from->is_numeric_scalar_or_vector()) {
+    AddError("'" + TypeNameOf(from) + "' cannot be bitcast",
+             cast->expr->source);
+    return false;
+  }
+  if (!to->is_numeric_scalar_or_vector()) {
+    AddError("cannot bitcast to '" + TypeNameOf(to) + "'", cast->type->source);
+    return false;
+  }
+
+  auto width = [&](const sem::Type* ty) {
+    if (auto* vec = ty->As<sem::Vector>()) {
+      return vec->Width();
+    }
+    return 1u;
+  };
+
+  if (width(from) != width(to)) {
+    AddError("cannot bitcast from '" + TypeNameOf(from) + "' to '" +
+                 TypeNameOf(to) + "'",
+             cast->source);
+    return false;
+  }
+
+  return true;
+}
+
 bool Resolver::ValidateBreakStatement(const sem::Statement* stmt) {
   if (!stmt->FindFirstParent<sem::LoopBlockStatement, sem::CaseStatement>()) {
     AddError("break statement must be in a loop or switch case",
diff --git a/src/writer/glsl/generator_impl_binary_test.cc b/src/writer/glsl/generator_impl_binary_test.cc
index 5c397e9..5ef5339 100644
--- a/src/writer/glsl/generator_impl_binary_test.cc
+++ b/src/writer/glsl/generator_impl_binary_test.cc
@@ -452,36 +452,6 @@
 )");
 }
 
-TEST_F(GlslGeneratorImplTest_Binary, Bitcast_WithLogical) {
-  // as<i32>(a && (b || c))
-
-  Global("a", ty.bool_(), ast::StorageClass::kPrivate);
-  Global("b", ty.bool_(), ast::StorageClass::kPrivate);
-  Global("c", ty.bool_(), ast::StorageClass::kPrivate);
-
-  auto* expr = create<ast::BitcastExpression>(
-      ty.i32(), create<ast::BinaryExpression>(
-                    ast::BinaryOp::kLogicalAnd, Expr("a"),
-                    create<ast::BinaryExpression>(ast::BinaryOp::kLogicalOr,
-                                                  Expr("b"), Expr("c"))));
-  WrapInFunction(expr);
-
-  GeneratorImpl& gen = Build();
-
-  std::stringstream out;
-  ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
-  EXPECT_EQ(gen.result(), R"(bool tint_tmp = a;
-if (tint_tmp) {
-  bool tint_tmp_1 = b;
-  if (!tint_tmp_1) {
-    tint_tmp_1 = c;
-  }
-  tint_tmp = (tint_tmp_1);
-}
-)");
-  EXPECT_EQ(out.str(), R"(int((tint_tmp)))");
-}
-
 TEST_F(GlslGeneratorImplTest_Binary, Call_WithLogical) {
   // foo(a && b, c || d, (a || c) && (b || d))
 
diff --git a/src/writer/hlsl/generator_impl_binary_test.cc b/src/writer/hlsl/generator_impl_binary_test.cc
index 44d27b7..c14a89a 100644
--- a/src/writer/hlsl/generator_impl_binary_test.cc
+++ b/src/writer/hlsl/generator_impl_binary_test.cc
@@ -452,36 +452,6 @@
 )");
 }
 
-TEST_F(HlslGeneratorImplTest_Binary, Bitcast_WithLogical) {
-  // as<i32>(a && (b || c))
-
-  Global("a", ty.bool_(), ast::StorageClass::kPrivate);
-  Global("b", ty.bool_(), ast::StorageClass::kPrivate);
-  Global("c", ty.bool_(), ast::StorageClass::kPrivate);
-
-  auto* expr = create<ast::BitcastExpression>(
-      ty.i32(), create<ast::BinaryExpression>(
-                    ast::BinaryOp::kLogicalAnd, Expr("a"),
-                    create<ast::BinaryExpression>(ast::BinaryOp::kLogicalOr,
-                                                  Expr("b"), Expr("c"))));
-  WrapInFunction(expr);
-
-  GeneratorImpl& gen = Build();
-
-  std::stringstream out;
-  ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
-  EXPECT_EQ(gen.result(), R"(bool tint_tmp = a;
-if (tint_tmp) {
-  bool tint_tmp_1 = b;
-  if (!tint_tmp_1) {
-    tint_tmp_1 = c;
-  }
-  tint_tmp = (tint_tmp_1);
-}
-)");
-  EXPECT_EQ(out.str(), R"(asint((tint_tmp)))");
-}
-
 TEST_F(HlslGeneratorImplTest_Binary, Call_WithLogical) {
   // foo(a && b, c || d, (a || c) && (b || d))
 
diff --git a/test/BUILD.gn b/test/BUILD.gn
index a4467af..cad09fb 100644
--- a/test/BUILD.gn
+++ b/test/BUILD.gn
@@ -236,6 +236,7 @@
     "../src/resolver/assignment_validation_test.cc",
     "../src/resolver/atomics_test.cc",
     "../src/resolver/atomics_validation_test.cc",
+    "../src/resolver/bitcast_validation_test.cc",
     "../src/resolver/builtins_validation_test.cc",
     "../src/resolver/call_test.cc",
     "../src/resolver/call_validation_test.cc",