Manual cherry-pick of https://dawn-review.googlesource.com/c/tint/+/71681

Original commit message (modified):

resolver: Fixes for bitcasts

Add missing validation for bitcasts. We were permitting any bitcast whose target type wasn't a pointer, when the spec only allows:
 * numeric_scalar to numeric_scalar
 * vecN<numeric_scalar> to vecN<numeric_scalar>

Add lots of tests.

Fixed: chromium:1276320
Change-Id: Iaaed4759234be1ed739e3c016d27679bde081ddc
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/72800
Reviewed-by: Ben Clayton <bclayton@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Kokoro: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 97e92f6..c92e231 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -666,6 +666,7 @@
     resolver/assignment_validation_test.cc
     resolver/atomics_test.cc
     resolver/atomics_validation_test.cc
+    resolver/bitcast_validation_test.cc
     resolver/builtins_validation_test.cc
     resolver/call_test.cc
     resolver/call_validation_test.cc
diff --git a/src/resolver/bitcast_validation_test.cc b/src/resolver/bitcast_validation_test.cc
new file mode 100644
index 0000000..d4ce082
--- /dev/null
+++ b/src/resolver/bitcast_validation_test.cc
@@ -0,0 +1,228 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/ast/bitcast_expression.h"
+#include "src/resolver/resolver.h"
+#include "src/resolver/resolver_test_helper.h"
+
+#include "gmock/gmock.h"
+
+namespace tint {
+namespace resolver {
+namespace {
+
+struct Type {
+  // Returns a Type holding the builder::DataType<T> builder functions.
+  template <typename T>
+  static constexpr Type Create() {
+    using DT = builder::DataType<T>;
+    return Type{DT::AST, DT::Sem, DT::Expr};
+  }
+  builder::ast_type_func_ptr ast;   // Builds the AST type
+  builder::sem_type_func_ptr sem;   // Builds the semantic type
+  builder::ast_expr_func_ptr expr;  // Builds an AST expression of the type
+};
+
+constexpr Type kNumericScalars[] = {  // Bitcastable scalar types
+    Type::Create<builder::f32>(),
+    Type::Create<builder::i32>(),
+    Type::Create<builder::u32>(),
+};
+constexpr Type kVec2NumericScalars[] = {  // Bitcastable 2-element vectors
+    Type::Create<builder::vec2<builder::f32>>(),
+    Type::Create<builder::vec2<builder::i32>>(),
+    Type::Create<builder::vec2<builder::u32>>(),
+};
+constexpr Type kVec3NumericScalars[] = {  // Bitcastable 3-element vectors
+    Type::Create<builder::vec3<builder::f32>>(),
+    Type::Create<builder::vec3<builder::i32>>(),
+    Type::Create<builder::vec3<builder::u32>>(),
+};
+constexpr Type kVec4NumericScalars[] = {  // Bitcastable 4-element vectors
+    Type::Create<builder::vec4<builder::f32>>(),
+    Type::Create<builder::vec4<builder::i32>>(),
+    Type::Create<builder::vec4<builder::u32>>(),
+};
+constexpr Type kInvalid[] = {
+    // A non-exhaustive selection of types that can never be bitcast
+    Type::Create<bool>(),
+    Type::Create<builder::vec2<bool>>(),
+    Type::Create<builder::vec3<bool>>(),
+    Type::Create<builder::vec4<bool>>(),
+    Type::Create<builder::array<2, builder::i32>>(),
+    Type::Create<builder::array<3, builder::u32>>(),
+    Type::Create<builder::array<4, builder::f32>>(),
+    Type::Create<builder::array<5, bool>>(),
+    Type::Create<builder::mat2x2<builder::f32>>(),
+    Type::Create<builder::mat3x3<builder::f32>>(),
+    Type::Create<builder::mat4x4<builder::f32>>(),
+    Type::Create<builder::ptr<builder::i32>>(),
+    Type::Create<builder::ptr<builder::array<2, builder::i32>>>(),
+    Type::Create<builder::ptr<builder::mat2x2<builder::f32>>>(),
+};
+
+// Test parameter: <type of the bitcast operand, type being bitcast to>
+using ResolverBitcastValidationTest =
+    ResolverTestWithParam<std::tuple<Type, Type>>;
+
+////////////////////////////////////////////////////////////////////////////////
+// Valid bitcasts: numeric scalar to scalar, vecN to vecN of numeric scalars
+////////////////////////////////////////////////////////////////////////////////
+using ResolverBitcastValidationTestPass = ResolverBitcastValidationTest;
+TEST_P(ResolverBitcastValidationTestPass, Test) {
+  const Type& from = std::get<0>(GetParam());
+  const Type& to = std::get<1>(GetParam());
+  auto* bitcast = Bitcast(to.ast(*this), from.expr(*this, 0));
+  WrapInFunction(bitcast);
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+  // A valid bitcast expression resolves to the target type.
+  EXPECT_EQ(TypeOf(bitcast), to.sem(*this));
+}
+INSTANTIATE_TEST_SUITE_P(Scalars,
+                         ResolverBitcastValidationTestPass,
+                         testing::Combine(testing::ValuesIn(kNumericScalars),
+                                          testing::ValuesIn(kNumericScalars)));
+INSTANTIATE_TEST_SUITE_P(
+    Vec2,
+    ResolverBitcastValidationTestPass,
+    testing::Combine(testing::ValuesIn(kVec2NumericScalars),
+                     testing::ValuesIn(kVec2NumericScalars)));
+INSTANTIATE_TEST_SUITE_P(
+    Vec3,
+    ResolverBitcastValidationTestPass,
+    testing::Combine(testing::ValuesIn(kVec3NumericScalars),
+                     testing::ValuesIn(kVec3NumericScalars)));
+INSTANTIATE_TEST_SUITE_P(
+    Vec4,
+    ResolverBitcastValidationTestPass,
+    testing::Combine(testing::ValuesIn(kVec4NumericScalars),
+                     testing::ValuesIn(kVec4NumericScalars)));
+
+////////////////////////////////////////////////////////////////////////////////
+// Invalid source type: the operand is not a numeric scalar or vector
+////////////////////////////////////////////////////////////////////////////////
+using ResolverBitcastValidationTestInvalidSrcTy = ResolverBitcastValidationTest;
+TEST_P(ResolverBitcastValidationTestInvalidSrcTy, Test) {
+  const Type& from = std::get<0>(GetParam());
+  const Type& to = std::get<1>(GetParam());
+
+  // The operand is an identifier so that it can carry a Source.
+  auto* bitcast = Bitcast(to.ast(*this), Expr(Source{{12, 34}}, "src"));
+  WrapInFunction(Const("src", nullptr, from.expr(*this, 0)), bitcast);
+  auto expected = "12:34 error: '" +
+                  from.sem(*this)->FriendlyName(Symbols()) +
+                  "' cannot be bitcast";
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(), expected);
+}
+INSTANTIATE_TEST_SUITE_P(Scalars,
+                         ResolverBitcastValidationTestInvalidSrcTy,
+                         testing::Combine(testing::ValuesIn(kInvalid),
+                                          testing::ValuesIn(kNumericScalars)));
+INSTANTIATE_TEST_SUITE_P(
+    Vec2,
+    ResolverBitcastValidationTestInvalidSrcTy,
+    testing::Combine(testing::ValuesIn(kInvalid),
+                     testing::ValuesIn(kVec2NumericScalars)));
+INSTANTIATE_TEST_SUITE_P(
+    Vec3,
+    ResolverBitcastValidationTestInvalidSrcTy,
+    testing::Combine(testing::ValuesIn(kInvalid),
+                     testing::ValuesIn(kVec3NumericScalars)));
+INSTANTIATE_TEST_SUITE_P(
+    Vec4,
+    ResolverBitcastValidationTestInvalidSrcTy,
+    testing::Combine(testing::ValuesIn(kInvalid),
+                     testing::ValuesIn(kVec4NumericScalars)));
+
+////////////////////////////////////////////////////////////////////////////////
+// Invalid target type: not a numeric scalar or vector
+////////////////////////////////////////////////////////////////////////////////
+using ResolverBitcastValidationTestInvalidDstTy = ResolverBitcastValidationTest;
+TEST_P(ResolverBitcastValidationTestInvalidDstTy, Test) {
+  const Type& from = std::get<0>(GetParam());
+  const Type& to = std::get<1>(GetParam());
+
+  // Refer to the target type through an alias, as a type name is able to
+  // carry a Source for the 'cannot bitcast to' diagnostic.
+  Alias("T", to.ast(*this));
+  WrapInFunction(
+      Bitcast(ty.type_name(Source{{12, 34}}, "T"), from.expr(*this, 0)));
+
+  auto expected = "12:34 error: cannot bitcast to '" +
+                  to.sem(*this)->FriendlyName(Symbols()) + "'";
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(), expected);
+}
+INSTANTIATE_TEST_SUITE_P(Scalars,
+                         ResolverBitcastValidationTestInvalidDstTy,
+                         testing::Combine(testing::ValuesIn(kNumericScalars),
+                                          testing::ValuesIn(kInvalid)));
+INSTANTIATE_TEST_SUITE_P(
+    Vec2,
+    ResolverBitcastValidationTestInvalidDstTy,
+    testing::Combine(testing::ValuesIn(kVec2NumericScalars),
+                     testing::ValuesIn(kInvalid)));
+INSTANTIATE_TEST_SUITE_P(
+    Vec3,
+    ResolverBitcastValidationTestInvalidDstTy,
+    testing::Combine(testing::ValuesIn(kVec3NumericScalars),
+                     testing::ValuesIn(kInvalid)));
+INSTANTIATE_TEST_SUITE_P(
+    Vec4,
+    ResolverBitcastValidationTestInvalidDstTy,
+    testing::Combine(testing::ValuesIn(kVec4NumericScalars),
+                     testing::ValuesIn(kInvalid)));
+
+////////////////////////////////////////////////////////////////////////////////
+// Incompatible widths: both types are valid, but component counts differ
+////////////////////////////////////////////////////////////////////////////////
+using ResolverBitcastValidationTestIncompatible = ResolverBitcastValidationTest;
+TEST_P(ResolverBitcastValidationTestIncompatible, Test) {
+  const Type& from = std::get<0>(GetParam());
+  const Type& to = std::get<1>(GetParam());
+
+  WrapInFunction(Bitcast(Source{{12, 34}}, to.ast(*this), from.expr(*this, 0)));
+
+  auto from_name = from.sem(*this)->FriendlyName(Symbols());
+  auto to_name = to.sem(*this)->FriendlyName(Symbols());
+  auto expected = "12:34 error: cannot bitcast from '" + from_name + "' to '" +
+                  to_name + "'";
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(), expected);
+}
+INSTANTIATE_TEST_SUITE_P(
+    ScalarToVec2,
+    ResolverBitcastValidationTestIncompatible,
+    testing::Combine(testing::ValuesIn(kNumericScalars),
+                     testing::ValuesIn(kVec2NumericScalars)));
+INSTANTIATE_TEST_SUITE_P(
+    Vec2ToVec3,
+    ResolverBitcastValidationTestIncompatible,
+    testing::Combine(testing::ValuesIn(kVec2NumericScalars),
+                     testing::ValuesIn(kVec3NumericScalars)));
+INSTANTIATE_TEST_SUITE_P(
+    Vec3ToVec4,
+    ResolverBitcastValidationTestIncompatible,
+    testing::Combine(testing::ValuesIn(kVec3NumericScalars),
+                     testing::ValuesIn(kVec4NumericScalars)));
+INSTANTIATE_TEST_SUITE_P(
+    Vec4ToScalar,
+    ResolverBitcastValidationTestIncompatible,
+    testing::Combine(testing::ValuesIn(kVec4NumericScalars),
+                     testing::ValuesIn(kNumericScalars)));
+
+}  // namespace
+}  // namespace resolver
+}  // namespace tint
diff --git a/src/resolver/ptr_ref_validation_test.cc b/src/resolver/ptr_ref_validation_test.cc
index 59e7b44..a1c6c10 100644
--- a/src/resolver/ptr_ref_validation_test.cc
+++ b/src/resolver/ptr_ref_validation_test.cc
@@ -143,19 +143,6 @@
             "'ptr<storage, i32, read_write>'");
 }
 
-TEST_F(ResolverTest, Expr_Bitcast_ptr) {
-  auto* vf = Var("vf", ty.f32());
-  auto* bitcast = create<ast::BitcastExpression>(
-      Source{{12, 34}}, ty.pointer<i32>(ast::StorageClass::kFunction),
-      Expr("vf"));
-  auto* ip =
-      Const("ip", ty.pointer<i32>(ast::StorageClass::kFunction), bitcast);
-  WrapInFunction(Decl(vf), Decl(ip));
-
-  EXPECT_FALSE(r()->Resolve());
-  EXPECT_EQ(r()->error(), "12:34 error: cannot cast to a pointer");
-}
-
 }  // namespace
 }  // namespace resolver
 }  // namespace tint
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 3c52f45..44854b5 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -2408,11 +2408,47 @@
   if (!ty) {
     return false;
   }
-  if (ty->Is<sem::Pointer>()) {
-    AddError("cannot cast to a pointer", expr->source);
+
+  SetExprInfo(expr, ty, expr->type->FriendlyName(builder_->Symbols()));
+
+  if (!ValidateBitcast(expr, ty)) {
     return false;
   }
-  SetExprInfo(expr, ty, expr->type->FriendlyName(builder_->Symbols()));
+
+  return true;
+}
+
+bool Resolver::ValidateBitcast(const ast::BitcastExpression* cast,
+                               const sem::Type* to) {
+  auto* from = TypeOf(cast->expr)->UnwrapRef();
+  auto name_of = [this](const sem::Type* type) {
+    return type->UnwrapRef()->FriendlyName(builder_->Symbols());
+  };
+
+  // Only numeric scalars and vectors of numeric scalars may be bitcast, both
+  // as the operand type and as the target type.
+  if (!from->is_numeric_scalar_or_vector()) {
+    AddError("'" + name_of(from) + "' cannot be bitcast", cast->expr->source);
+    return false;
+  }
+  if (!to->is_numeric_scalar_or_vector()) {
+    AddError("cannot bitcast to '" + name_of(to) + "'", cast->type->source);
+    return false;
+  }
+
+  // A scalar counts as one component and a vector as its width; the operand
+  // and the target must have the same number of components.
+  const auto* from_vec = from->As<sem::Vector>();
+  const auto* to_vec = to->As<sem::Vector>();
+  const uint32_t from_width = from_vec ? from_vec->Width() : 1u;
+  const uint32_t to_width = to_vec ? to_vec->Width() : 1u;
+  if (from_width != to_width) {
+    AddError("cannot bitcast from '" + name_of(from) + "' to '" +
+                 name_of(to) + "'",
+             cast->source);
+    return false;
+  }
+
   return true;
 }
 
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index 212360b..0827da8 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -275,6 +275,7 @@
   bool ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco,
                                  const sem::Type* storage_type,
                                  const bool is_input);
+  bool ValidateBitcast(const ast::BitcastExpression* cast, const sem::Type* to);
   bool ValidateCall(const ast::CallExpression* call);
   bool ValidateCallStatement(const ast::CallStatement* stmt);
   bool ValidateEntryPoint(const ast::Function* func, const FunctionInfo* info);
diff --git a/src/resolver/resolver_test_helper.h b/src/resolver/resolver_test_helper.h
index 9f5ff52..bc50d93 100644
--- a/src/resolver/resolver_test_helper.h
+++ b/src/resolver/resolver_test_helper.h
@@ -171,6 +171,9 @@
 template <typename TO>
 using alias3 = alias<TO, 3>;
 
+template <typename TO>
+struct ptr {};  // Tag type: DataType<ptr<T>> builds a pointer to T
+
 using ast_type_func_ptr = const ast::Type* (*)(ProgramBuilder& b);
 using ast_expr_func_ptr = const ast::Expression* (*)(ProgramBuilder& b,
                                                      int elem_value);
@@ -387,6 +390,36 @@
   }
 };
 
+/// Helper for building pointer types and expressions
+template <typename T>
+struct DataType<ptr<T>> {
+  /// false: pointer types are not composite types
+  static constexpr bool is_composite = false;
+
+  /// @param b the ProgramBuilder
+  /// @return a new AST pointer type to T, in the private storage class
+  static inline const ast::Type* AST(ProgramBuilder& b) {
+    return b.create<ast::Pointer>(DataType<T>::AST(b),
+                                  ast::StorageClass::kPrivate,
+                                  ast::Access::kReadWrite);
+  }
+  /// @param b the ProgramBuilder
+  /// @return the semantic pointer type to T, in the private storage class
+  static inline const sem::Type* Sem(ProgramBuilder& b) {
+    return b.create<sem::Pointer>(DataType<T>::Sem(b),
+                                  ast::StorageClass::kPrivate,
+                                  ast::Access::kReadWrite);
+  }
+
+  /// @param b the ProgramBuilder
+  /// @return `&g`, where `g` is a newly declared private global of type T
+  static inline const ast::Expression* Expr(ProgramBuilder& b, int /*unused*/) {
+    auto sym = b.Symbols().New("global_for_ptr");
+    b.Global(sym, DataType<T>::AST(b), ast::StorageClass::kPrivate);
+    return b.AddressOf(sym);
+  }
+};
+
 /// Helper for building array types and expressions
 template <int N, typename T>
 struct DataType<array<N, T>> {
@@ -401,7 +434,14 @@
   /// @param b the ProgramBuilder
   /// @return the semantic array type
   static inline const sem::Type* Sem(ProgramBuilder& b) {
-    return b.create<sem::Array>(DataType<T>::Sem(b), N);
+    auto* el = DataType<T>::Sem(b);
+    // stride = element size rounded up to its alignment; size = N * stride.
+    auto align = el->Align();
+    auto stride = (el->Size() + align - 1) / align * align;
+    return b.create<sem::Array>(
+        /* element */ el, /* count */ N, /* align */ align,
+        /* size */ static_cast<uint32_t>(N) * stride, /* stride */ stride,
+        /* implicit_stride */ stride);
   }
   /// @param b the ProgramBuilder
   /// @param elem_value the value each element in the array will be initialized
diff --git a/src/utils/unique_vector.h b/src/utils/unique_vector.h
index 1ae751a..f2c880b 100644
--- a/src/utils/unique_vector.h
+++ b/src/utils/unique_vector.h
@@ -15,6 +15,7 @@
 #ifndef SRC_UTILS_UNIQUE_VECTOR_H_
 #define SRC_UTILS_UNIQUE_VECTOR_H_
 
+#include <cstddef>
 #include <unordered_set>
 #include <vector>
 
diff --git a/src/writer/glsl/generator_impl_binary_test.cc b/src/writer/glsl/generator_impl_binary_test.cc
index c1b6724..6f9554d 100644
--- a/src/writer/glsl/generator_impl_binary_test.cc
+++ b/src/writer/glsl/generator_impl_binary_test.cc
@@ -464,36 +464,6 @@
 )");
 }
 
-TEST_F(GlslGeneratorImplTest_Binary, Bitcast_WithLogical) {
-  // as<i32>(a && (b || c))
-
-  Global("a", ty.bool_(), ast::StorageClass::kPrivate);
-  Global("b", ty.bool_(), ast::StorageClass::kPrivate);
-  Global("c", ty.bool_(), ast::StorageClass::kPrivate);
-
-  auto* expr = create<ast::BitcastExpression>(
-      ty.i32(), create<ast::BinaryExpression>(
-                    ast::BinaryOp::kLogicalAnd, Expr("a"),
-                    create<ast::BinaryExpression>(ast::BinaryOp::kLogicalOr,
-                                                  Expr("b"), Expr("c"))));
-  WrapInFunction(expr);
-
-  GeneratorImpl& gen = Build();
-
-  std::stringstream out;
-  ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
-  EXPECT_EQ(gen.result(), R"(bool tint_tmp = a;
-if (tint_tmp) {
-  bool tint_tmp_1 = b;
-  if (!tint_tmp_1) {
-    tint_tmp_1 = c;
-  }
-  tint_tmp = (tint_tmp_1);
-}
-)");
-  EXPECT_EQ(out.str(), R"(asint((tint_tmp)))");
-}
-
 TEST_F(GlslGeneratorImplTest_Binary, Call_WithLogical) {
   // foo(a && b, c || d, (a || c) && (b || d))
 
diff --git a/src/writer/hlsl/generator_impl_binary_test.cc b/src/writer/hlsl/generator_impl_binary_test.cc
index f7184ac..3e9b7f1 100644
--- a/src/writer/hlsl/generator_impl_binary_test.cc
+++ b/src/writer/hlsl/generator_impl_binary_test.cc
@@ -464,36 +464,6 @@
 )");
 }
 
-TEST_F(HlslGeneratorImplTest_Binary, Bitcast_WithLogical) {
-  // as<i32>(a && (b || c))
-
-  Global("a", ty.bool_(), ast::StorageClass::kPrivate);
-  Global("b", ty.bool_(), ast::StorageClass::kPrivate);
-  Global("c", ty.bool_(), ast::StorageClass::kPrivate);
-
-  auto* expr = create<ast::BitcastExpression>(
-      ty.i32(), create<ast::BinaryExpression>(
-                    ast::BinaryOp::kLogicalAnd, Expr("a"),
-                    create<ast::BinaryExpression>(ast::BinaryOp::kLogicalOr,
-                                                  Expr("b"), Expr("c"))));
-  WrapInFunction(expr);
-
-  GeneratorImpl& gen = Build();
-
-  std::stringstream out;
-  ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
-  EXPECT_EQ(gen.result(), R"(bool tint_tmp = a;
-if (tint_tmp) {
-  bool tint_tmp_1 = b;
-  if (!tint_tmp_1) {
-    tint_tmp_1 = c;
-  }
-  tint_tmp = (tint_tmp_1);
-}
-)");
-  EXPECT_EQ(out.str(), R"(asint((tint_tmp)))");
-}
-
 TEST_F(HlslGeneratorImplTest_Binary, Call_WithLogical) {
   // foo(a && b, c || d, (a || c) && (b || d))
 
diff --git a/test/BUILD.gn b/test/BUILD.gn
index edd66b9..f5a6ad4 100644
--- a/test/BUILD.gn
+++ b/test/BUILD.gn
@@ -235,6 +235,7 @@
     "../src/resolver/assignment_validation_test.cc",
     "../src/resolver/atomics_test.cc",
     "../src/resolver/atomics_validation_test.cc",
+    "../src/resolver/bitcast_validation_test.cc",
     "../src/resolver/builtins_validation_test.cc",
     "../src/resolver/call_test.cc",
     "../src/resolver/call_validation_test.cc",