[spirv-writer] Allow casting of vectors along with scalars.
The current `cast` conversion code only handles scalar types and fails
if provided with vectors. This CL updates the logic to accept scalars
along with the provided scalar cases.
Bug: tint:96
Change-Id: I60772e75286fc3ee7a9dfba6634db069062b22d0
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/23820
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index 93f3133..59a36c0 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -1527,20 +1527,36 @@
auto* from_type = cast->expr()->result_type()->UnwrapPtrIfNeeded();
spv::Op op = spv::Op::OpNop;
- if (from_type->IsI32() && to_type->IsF32()) {
+ if ((from_type->IsI32() && to_type->IsF32()) ||
+ (from_type->is_signed_integer_vector() && to_type->is_float_vector())) {
op = spv::Op::OpConvertSToF;
- } else if (from_type->IsU32() && to_type->IsF32()) {
+ } else if ((from_type->IsU32() && to_type->IsF32()) ||
+ (from_type->is_unsigned_integer_vector() &&
+ to_type->is_float_vector())) {
op = spv::Op::OpConvertUToF;
- } else if (from_type->IsF32() && to_type->IsI32()) {
+ } else if ((from_type->IsF32() && to_type->IsI32()) ||
+ (from_type->is_float_vector() &&
+ to_type->is_signed_integer_vector())) {
op = spv::Op::OpConvertFToS;
- } else if (from_type->IsF32() && to_type->IsU32()) {
+ } else if ((from_type->IsF32() && to_type->IsU32()) ||
+ (from_type->is_float_vector() &&
+ to_type->is_unsigned_integer_vector())) {
op = spv::Op::OpConvertFToU;
} else if ((from_type->IsU32() && to_type->IsU32()) ||
(from_type->IsI32() && to_type->IsI32()) ||
- (from_type->IsF32() && to_type->IsF32())) {
+ (from_type->IsF32() && to_type->IsF32()) ||
+ (from_type->is_unsigned_integer_vector() &&
+ to_type->is_unsigned_integer_vector()) ||
+ (from_type->is_signed_integer_vector() &&
+ to_type->is_signed_integer_vector()) ||
+ (from_type->is_float_vector() && to_type->is_float_vector())) {
op = spv::Op::OpCopyObject;
} else if ((from_type->IsI32() && to_type->IsU32()) ||
- (from_type->IsU32() && to_type->IsI32())) {
+ (from_type->IsU32() && to_type->IsI32()) ||
+ (from_type->is_signed_integer_vector() &&
+ to_type->is_unsigned_integer_vector()) ||
+ (from_type->is_unsigned_integer_vector() &&
+ to_type->is_integer_scalar_or_vector())) {
op = spv::Op::OpBitcast;
}
diff --git a/src/writer/spirv/builder_cast_expression_test.cc b/src/writer/spirv/builder_cast_expression_test.cc
index c559ec9..675bb20 100644
--- a/src/writer/spirv/builder_cast_expression_test.cc
+++ b/src/writer/spirv/builder_cast_expression_test.cc
@@ -22,6 +22,7 @@
#include "src/ast/type/f32_type.h"
#include "src/ast/type/i32_type.h"
#include "src/ast/type/u32_type.h"
+#include "src/ast/type/vector_type.h"
#include "src/ast/uint_literal.h"
#include "src/context.h"
#include "src/type_determiner.h"
@@ -329,6 +330,224 @@
)");
}
+TEST_F(BuilderTest, Cast_Vectors_I32_to_F32) {
+ ast::type::I32Type i32;
+ ast::type::VectorType ivec3(&i32, 3);
+ ast::type::F32Type f32;
+ ast::type::VectorType fvec3(&f32, 3);
+
+ auto var =
+ std::make_unique<ast::Variable>("i", ast::StorageClass::kPrivate, &ivec3);
+
+ ast::CastExpression cast(&fvec3,
+ std::make_unique<ast::IdentifierExpression>("i"));
+
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ td.RegisterVariableForTesting(var.get());
+ ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error();
+
+ Builder b(&mod);
+ b.push_function(Function{});
+ ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
+ EXPECT_EQ(b.GenerateCastExpression(&cast), 6u) << b.error();
+
+ EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeInt 32 1
+%3 = OpTypeVector %4 3
+%2 = OpTypePointer Private %3
+%5 = OpConstantNull %3
+%1 = OpVariable %2 Private %5
+%8 = OpTypeFloat 32
+%7 = OpTypeVector %8 3
+)");
+ EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+ R"(%9 = OpLoad %3 %1
+%6 = OpConvertSToF %7 %9
+)");
+}
+
+TEST_F(BuilderTest, Cast_Vectors_U32_to_F32) {
+ ast::type::U32Type u32;
+ ast::type::VectorType uvec3(&u32, 3);
+ ast::type::F32Type f32;
+ ast::type::VectorType fvec3(&f32, 3);
+
+ auto var =
+ std::make_unique<ast::Variable>("i", ast::StorageClass::kPrivate, &uvec3);
+
+ ast::CastExpression cast(&fvec3,
+ std::make_unique<ast::IdentifierExpression>("i"));
+
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ td.RegisterVariableForTesting(var.get());
+ ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error();
+
+ Builder b(&mod);
+ b.push_function(Function{});
+ ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
+ EXPECT_EQ(b.GenerateCastExpression(&cast), 6u) << b.error();
+
+ EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeInt 32 0
+%3 = OpTypeVector %4 3
+%2 = OpTypePointer Private %3
+%5 = OpConstantNull %3
+%1 = OpVariable %2 Private %5
+%8 = OpTypeFloat 32
+%7 = OpTypeVector %8 3
+)");
+ EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+ R"(%9 = OpLoad %3 %1
+%6 = OpConvertUToF %7 %9
+)");
+}
+
+TEST_F(BuilderTest, Cast_Vectors_F32_to_I32) {
+ ast::type::I32Type i32;
+ ast::type::VectorType ivec3(&i32, 3);
+ ast::type::F32Type f32;
+ ast::type::VectorType fvec3(&f32, 3);
+
+ auto var =
+ std::make_unique<ast::Variable>("i", ast::StorageClass::kPrivate, &fvec3);
+
+ ast::CastExpression cast(&ivec3,
+ std::make_unique<ast::IdentifierExpression>("i"));
+
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ td.RegisterVariableForTesting(var.get());
+ ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error();
+
+ Builder b(&mod);
+ b.push_function(Function{});
+ ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
+ EXPECT_EQ(b.GenerateCastExpression(&cast), 6u) << b.error();
+
+ EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeFloat 32
+%3 = OpTypeVector %4 3
+%2 = OpTypePointer Private %3
+%5 = OpConstantNull %3
+%1 = OpVariable %2 Private %5
+%8 = OpTypeInt 32 1
+%7 = OpTypeVector %8 3
+)");
+ EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+ R"(%9 = OpLoad %3 %1
+%6 = OpConvertFToS %7 %9
+)");
+}
+
+TEST_F(BuilderTest, Cast_Vectors_F32_to_U32) {
+ ast::type::U32Type u32;
+ ast::type::VectorType uvec3(&u32, 3);
+ ast::type::F32Type f32;
+ ast::type::VectorType fvec3(&f32, 3);
+
+ auto var =
+ std::make_unique<ast::Variable>("i", ast::StorageClass::kPrivate, &fvec3);
+
+ ast::CastExpression cast(&uvec3,
+ std::make_unique<ast::IdentifierExpression>("i"));
+
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ td.RegisterVariableForTesting(var.get());
+ ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error();
+
+ Builder b(&mod);
+ b.push_function(Function{});
+ ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
+ EXPECT_EQ(b.GenerateCastExpression(&cast), 6u) << b.error();
+
+ EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeFloat 32
+%3 = OpTypeVector %4 3
+%2 = OpTypePointer Private %3
+%5 = OpConstantNull %3
+%1 = OpVariable %2 Private %5
+%8 = OpTypeInt 32 0
+%7 = OpTypeVector %8 3
+)");
+ EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+ R"(%9 = OpLoad %3 %1
+%6 = OpConvertFToU %7 %9
+)");
+}
+
+TEST_F(BuilderTest, Cast_Vectors_U32_to_U32) {
+ ast::type::U32Type u32;
+ ast::type::VectorType uvec3(&u32, 3);
+
+ auto var =
+ std::make_unique<ast::Variable>("i", ast::StorageClass::kPrivate, &uvec3);
+
+ ast::CastExpression cast(&uvec3,
+ std::make_unique<ast::IdentifierExpression>("i"));
+
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ td.RegisterVariableForTesting(var.get());
+ ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error();
+
+ Builder b(&mod);
+ b.push_function(Function{});
+ ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
+ EXPECT_EQ(b.GenerateCastExpression(&cast), 6u) << b.error();
+
+ EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeInt 32 0
+%3 = OpTypeVector %4 3
+%2 = OpTypePointer Private %3
+%5 = OpConstantNull %3
+%1 = OpVariable %2 Private %5
+)");
+ EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+ R"(%7 = OpLoad %3 %1
+%6 = OpCopyObject %3 %7
+)");
+}
+
+TEST_F(BuilderTest, Cast_Vectors_I32_to_U32) {
+ ast::type::U32Type u32;
+ ast::type::VectorType uvec3(&u32, 3);
+ ast::type::I32Type i32;
+ ast::type::VectorType ivec3(&i32, 3);
+
+ auto var =
+ std::make_unique<ast::Variable>("i", ast::StorageClass::kPrivate, &ivec3);
+
+ ast::CastExpression cast(&uvec3,
+ std::make_unique<ast::IdentifierExpression>("i"));
+
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ td.RegisterVariableForTesting(var.get());
+ ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error();
+
+ Builder b(&mod);
+ b.push_function(Function{});
+ ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
+ EXPECT_EQ(b.GenerateCastExpression(&cast), 6u) << b.error();
+
+ EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeInt 32 1
+%3 = OpTypeVector %4 3
+%2 = OpTypePointer Private %3
+%5 = OpConstantNull %3
+%1 = OpVariable %2 Private %5
+%8 = OpTypeInt 32 0
+%7 = OpTypeVector %8 3
+)");
+ EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+ R"(%9 = OpLoad %3 %1
+%6 = OpBitcast %7 %9
+)");
+}
+
} // namespace
} // namespace spirv
} // namespace writer