[spirv-writer] Allow casting of vectors along with scalars.

The current `cast` conversion code only handles scalar types and fails
if provided with vectors. This CL updates the logic to accept scalars
along with the provided scalar cases.

Bug: tint:96
Change-Id: I60772e75286fc3ee7a9dfba6634db069062b22d0
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/23820
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index 93f3133..59a36c0 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -1527,20 +1527,36 @@
   auto* from_type = cast->expr()->result_type()->UnwrapPtrIfNeeded();
 
   spv::Op op = spv::Op::OpNop;
-  if (from_type->IsI32() && to_type->IsF32()) {
+  if ((from_type->IsI32() && to_type->IsF32()) ||
+      (from_type->is_signed_integer_vector() && to_type->is_float_vector())) {
     op = spv::Op::OpConvertSToF;
-  } else if (from_type->IsU32() && to_type->IsF32()) {
+  } else if ((from_type->IsU32() && to_type->IsF32()) ||
+             (from_type->is_unsigned_integer_vector() &&
+              to_type->is_float_vector())) {
     op = spv::Op::OpConvertUToF;
-  } else if (from_type->IsF32() && to_type->IsI32()) {
+  } else if ((from_type->IsF32() && to_type->IsI32()) ||
+             (from_type->is_float_vector() &&
+              to_type->is_signed_integer_vector())) {
     op = spv::Op::OpConvertFToS;
-  } else if (from_type->IsF32() && to_type->IsU32()) {
+  } else if ((from_type->IsF32() && to_type->IsU32()) ||
+             (from_type->is_float_vector() &&
+              to_type->is_unsigned_integer_vector())) {
     op = spv::Op::OpConvertFToU;
   } else if ((from_type->IsU32() && to_type->IsU32()) ||
              (from_type->IsI32() && to_type->IsI32()) ||
-             (from_type->IsF32() && to_type->IsF32())) {
+             (from_type->IsF32() && to_type->IsF32()) ||
+             (from_type->is_unsigned_integer_vector() &&
+              to_type->is_unsigned_integer_vector()) ||
+             (from_type->is_signed_integer_vector() &&
+              to_type->is_signed_integer_vector()) ||
+             (from_type->is_float_vector() && to_type->is_float_vector())) {
     op = spv::Op::OpCopyObject;
   } else if ((from_type->IsI32() && to_type->IsU32()) ||
-             (from_type->IsU32() && to_type->IsI32())) {
+             (from_type->IsU32() && to_type->IsI32()) ||
+             (from_type->is_signed_integer_vector() &&
+              to_type->is_unsigned_integer_vector()) ||
+             (from_type->is_unsigned_integer_vector() &&
+              to_type->is_integer_scalar_or_vector())) {
     op = spv::Op::OpBitcast;
   }
 
diff --git a/src/writer/spirv/builder_cast_expression_test.cc b/src/writer/spirv/builder_cast_expression_test.cc
index c559ec9..675bb20 100644
--- a/src/writer/spirv/builder_cast_expression_test.cc
+++ b/src/writer/spirv/builder_cast_expression_test.cc
@@ -22,6 +22,7 @@
 #include "src/ast/type/f32_type.h"
 #include "src/ast/type/i32_type.h"
 #include "src/ast/type/u32_type.h"
+#include "src/ast/type/vector_type.h"
 #include "src/ast/uint_literal.h"
 #include "src/context.h"
 #include "src/type_determiner.h"
@@ -329,6 +330,224 @@
 )");
 }
 
+TEST_F(BuilderTest, Cast_Vectors_I32_to_F32) {
+  ast::type::I32Type i32;
+  ast::type::VectorType ivec3(&i32, 3);
+  ast::type::F32Type f32;
+  ast::type::VectorType fvec3(&f32, 3);
+
+  auto var =
+      std::make_unique<ast::Variable>("i", ast::StorageClass::kPrivate, &ivec3);
+
+  ast::CastExpression cast(&fvec3,
+                           std::make_unique<ast::IdentifierExpression>("i"));
+
+  Context ctx;
+  ast::Module mod;
+  TypeDeterminer td(&ctx, &mod);
+  td.RegisterVariableForTesting(var.get());
+  ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error();
+
+  Builder b(&mod);
+  b.push_function(Function{});
+  ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
+  EXPECT_EQ(b.GenerateCastExpression(&cast), 6u) << b.error();
+
+  EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeInt 32 1
+%3 = OpTypeVector %4 3
+%2 = OpTypePointer Private %3
+%5 = OpConstantNull %3
+%1 = OpVariable %2 Private %5
+%8 = OpTypeFloat 32
+%7 = OpTypeVector %8 3
+)");
+  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+            R"(%9 = OpLoad %3 %1
+%6 = OpConvertSToF %7 %9
+)");
+}
+
+TEST_F(BuilderTest, Cast_Vectors_U32_to_F32) {
+  ast::type::U32Type u32;
+  ast::type::VectorType uvec3(&u32, 3);
+  ast::type::F32Type f32;
+  ast::type::VectorType fvec3(&f32, 3);
+
+  auto var =
+      std::make_unique<ast::Variable>("i", ast::StorageClass::kPrivate, &uvec3);
+
+  ast::CastExpression cast(&fvec3,
+                           std::make_unique<ast::IdentifierExpression>("i"));
+
+  Context ctx;
+  ast::Module mod;
+  TypeDeterminer td(&ctx, &mod);
+  td.RegisterVariableForTesting(var.get());
+  ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error();
+
+  Builder b(&mod);
+  b.push_function(Function{});
+  ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
+  EXPECT_EQ(b.GenerateCastExpression(&cast), 6u) << b.error();
+
+  EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeInt 32 0
+%3 = OpTypeVector %4 3
+%2 = OpTypePointer Private %3
+%5 = OpConstantNull %3
+%1 = OpVariable %2 Private %5
+%8 = OpTypeFloat 32
+%7 = OpTypeVector %8 3
+)");
+  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+            R"(%9 = OpLoad %3 %1
+%6 = OpConvertUToF %7 %9
+)");
+}
+
+TEST_F(BuilderTest, Cast_Vectors_F32_to_I32) {
+  ast::type::I32Type i32;
+  ast::type::VectorType ivec3(&i32, 3);
+  ast::type::F32Type f32;
+  ast::type::VectorType fvec3(&f32, 3);
+
+  auto var =
+      std::make_unique<ast::Variable>("i", ast::StorageClass::kPrivate, &fvec3);
+
+  ast::CastExpression cast(&ivec3,
+                           std::make_unique<ast::IdentifierExpression>("i"));
+
+  Context ctx;
+  ast::Module mod;
+  TypeDeterminer td(&ctx, &mod);
+  td.RegisterVariableForTesting(var.get());
+  ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error();
+
+  Builder b(&mod);
+  b.push_function(Function{});
+  ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
+  EXPECT_EQ(b.GenerateCastExpression(&cast), 6u) << b.error();
+
+  EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeFloat 32
+%3 = OpTypeVector %4 3
+%2 = OpTypePointer Private %3
+%5 = OpConstantNull %3
+%1 = OpVariable %2 Private %5
+%8 = OpTypeInt 32 1
+%7 = OpTypeVector %8 3
+)");
+  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+            R"(%9 = OpLoad %3 %1
+%6 = OpConvertFToS %7 %9
+)");
+}
+
+TEST_F(BuilderTest, Cast_Vectors_F32_to_U32) {
+  ast::type::U32Type u32;
+  ast::type::VectorType uvec3(&u32, 3);
+  ast::type::F32Type f32;
+  ast::type::VectorType fvec3(&f32, 3);
+
+  auto var =
+      std::make_unique<ast::Variable>("i", ast::StorageClass::kPrivate, &fvec3);
+
+  ast::CastExpression cast(&uvec3,
+                           std::make_unique<ast::IdentifierExpression>("i"));
+
+  Context ctx;
+  ast::Module mod;
+  TypeDeterminer td(&ctx, &mod);
+  td.RegisterVariableForTesting(var.get());
+  ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error();
+
+  Builder b(&mod);
+  b.push_function(Function{});
+  ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
+  EXPECT_EQ(b.GenerateCastExpression(&cast), 6u) << b.error();
+
+  EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeFloat 32
+%3 = OpTypeVector %4 3
+%2 = OpTypePointer Private %3
+%5 = OpConstantNull %3
+%1 = OpVariable %2 Private %5
+%8 = OpTypeInt 32 0
+%7 = OpTypeVector %8 3
+)");
+  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+            R"(%9 = OpLoad %3 %1
+%6 = OpConvertFToU %7 %9
+)");
+}
+
+TEST_F(BuilderTest, Cast_Vectors_U32_to_U32) {
+  ast::type::U32Type u32;
+  ast::type::VectorType uvec3(&u32, 3);
+
+  auto var =
+      std::make_unique<ast::Variable>("i", ast::StorageClass::kPrivate, &uvec3);
+
+  ast::CastExpression cast(&uvec3,
+                           std::make_unique<ast::IdentifierExpression>("i"));
+
+  Context ctx;
+  ast::Module mod;
+  TypeDeterminer td(&ctx, &mod);
+  td.RegisterVariableForTesting(var.get());
+  ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error();
+
+  Builder b(&mod);
+  b.push_function(Function{});
+  ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
+  EXPECT_EQ(b.GenerateCastExpression(&cast), 6u) << b.error();
+
+  EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeInt 32 0
+%3 = OpTypeVector %4 3
+%2 = OpTypePointer Private %3
+%5 = OpConstantNull %3
+%1 = OpVariable %2 Private %5
+)");
+  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+            R"(%7 = OpLoad %3 %1
+%6 = OpCopyObject %3 %7
+)");
+}
+
+TEST_F(BuilderTest, Cast_Vectors_I32_to_U32) {
+  ast::type::U32Type u32;
+  ast::type::VectorType uvec3(&u32, 3);
+  ast::type::I32Type i32;
+  ast::type::VectorType ivec3(&i32, 3);
+
+  auto var =
+      std::make_unique<ast::Variable>("i", ast::StorageClass::kPrivate, &ivec3);
+
+  ast::CastExpression cast(&uvec3,
+                           std::make_unique<ast::IdentifierExpression>("i"));
+
+  Context ctx;
+  ast::Module mod;
+  TypeDeterminer td(&ctx, &mod);
+  td.RegisterVariableForTesting(var.get());
+  ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error();
+
+  Builder b(&mod);
+  b.push_function(Function{});
+  ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
+  EXPECT_EQ(b.GenerateCastExpression(&cast), 6u) << b.error();
+
+  EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeInt 32 1
+%3 = OpTypeVector %4 3
+%2 = OpTypePointer Private %3
+%5 = OpConstantNull %3
+%1 = OpVariable %2 Private %5
+%8 = OpTypeInt 32 0
+%7 = OpTypeVector %8 3
+)");
+  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+            R"(%9 = OpLoad %3 %1
+%6 = OpBitcast %7 %9
+)");
+}
+
 }  // namespace
 }  // namespace spirv
 }  // namespace writer