[ir][spirv-writer] Implement convert instructions
These map to one of several instructions, depending on the input and
output types. Conversions to bool are implemented using OpSelect
between a zero and one value, splatted to vectors if needed.
Bug: tint:1906
Change-Id: If208873506c314e3dd1bc4026ef9ebeba7c8933a
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/139900
Auto-Submit: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index 5cbdf81..4af3f92 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -2045,6 +2045,7 @@
"writer/spirv/ir/generator_impl_ir_builtin_test.cc",
"writer/spirv/ir/generator_impl_ir_constant_test.cc",
"writer/spirv/ir/generator_impl_ir_construct_test.cc",
+ "writer/spirv/ir/generator_impl_ir_convert_test.cc",
"writer/spirv/ir/generator_impl_ir_function_test.cc",
"writer/spirv/ir/generator_impl_ir_if_test.cc",
"writer/spirv/ir/generator_impl_ir_loop_test.cc",
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index 7a91c92..d52ea2b 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -1316,6 +1316,7 @@
writer/spirv/ir/generator_impl_ir_builtin_test.cc
writer/spirv/ir/generator_impl_ir_constant_test.cc
writer/spirv/ir/generator_impl_ir_construct_test.cc
+ writer/spirv/ir/generator_impl_ir_convert_test.cc
writer/spirv/ir/generator_impl_ir_function_test.cc
writer/spirv/ir/generator_impl_ir_if_test.cc
writer/spirv/ir/generator_impl_ir_loop_test.cc
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir.cc b/src/tint/writer/spirv/ir/generator_impl_ir.cc
index 00b021b..edc3675 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir.cc
@@ -78,6 +78,8 @@
namespace {
+using namespace tint::number_suffixes; // NOLINT
+
constexpr uint32_t kGeneratorVersion = 1;
void Sanitize(ir::Module* module) {
@@ -738,6 +740,7 @@
[&](ir::Binary* b) { EmitBinary(b); }, //
[&](ir::BuiltinCall* b) { EmitBuiltinCall(b); }, //
[&](ir::Construct* c) { EmitConstruct(c); }, //
+ [&](ir::Convert* c) { EmitConvert(c); }, //
[&](ir::Load* l) { EmitLoad(l); }, //
[&](ir::Loop* l) { EmitLoop(l); }, //
[&](ir::Switch* sw) { EmitSwitch(sw); }, //
@@ -1136,6 +1139,88 @@
current_function_.push_inst(spv::Op::OpCompositeConstruct, std::move(operands));
}
+void GeneratorImplIr::EmitConvert(ir::Convert* convert) {
+ auto* res_ty = convert->Result()->Type();
+ auto* arg_ty = convert->Args()[0]->Type();
+
+ OperandList operands = {Type(convert->Result()->Type()), Value(convert)};
+ for (auto* arg : convert->Args()) {
+ operands.push_back(Value(arg));
+ }
+
+ spv::Op op = spv::Op::Max;
+ if (res_ty->is_signed_integer_scalar_or_vector() && arg_ty->is_float_scalar_or_vector()) {
+ // float to signed int.
+ op = spv::Op::OpConvertFToS;
+ } else if (res_ty->is_unsigned_integer_scalar_or_vector() &&
+ arg_ty->is_float_scalar_or_vector()) {
+ // float to unsigned int.
+ op = spv::Op::OpConvertFToU;
+ } else if (res_ty->is_float_scalar_or_vector() &&
+ arg_ty->is_signed_integer_scalar_or_vector()) {
+ // signed int to float.
+ op = spv::Op::OpConvertSToF;
+ } else if (res_ty->is_float_scalar_or_vector() &&
+ arg_ty->is_unsigned_integer_scalar_or_vector()) {
+ // unsigned int to float.
+ op = spv::Op::OpConvertUToF;
+ } else if (res_ty->is_float_scalar_or_vector() && arg_ty->is_float_scalar_or_vector() &&
+ res_ty->Size() != arg_ty->Size()) {
+ // float to float (different bitwidth).
+ op = spv::Op::OpFConvert;
+ } else if (res_ty->is_integer_scalar_or_vector() && arg_ty->is_integer_scalar_or_vector() &&
+ res_ty->Size() == arg_ty->Size()) {
+ // int to int (same bitwidth, different signedness).
+ op = spv::Op::OpBitcast;
+ } else if (res_ty->is_bool_scalar_or_vector()) {
+ if (arg_ty->is_integer_scalar_or_vector()) {
+ // int to bool.
+ op = spv::Op::OpINotEqual;
+ } else {
+ // float to bool.
+ op = spv::Op::OpFUnordNotEqual;
+ }
+ operands.push_back(ConstantNull(arg_ty));
+ } else if (arg_ty->is_bool_scalar_or_vector()) {
+ // Select between constant one and zero, splatting them to vectors if necessary.
+ const constant::Value* one = nullptr;
+ const constant::Value* zero = nullptr;
+ Switch(
+ res_ty->DeepestElement(), //
+ [&](const type::F32*) {
+ one = ir_->constant_values.Get(1_f);
+ zero = ir_->constant_values.Get(0_f);
+ },
+ [&](const type::F16*) {
+ one = ir_->constant_values.Get(1_h);
+ zero = ir_->constant_values.Get(0_h);
+ },
+ [&](const type::I32*) {
+ one = ir_->constant_values.Get(1_i);
+ zero = ir_->constant_values.Get(0_i);
+ },
+ [&](const type::U32*) {
+ one = ir_->constant_values.Get(1_u);
+ zero = ir_->constant_values.Get(0_u);
+ });
+ TINT_ASSERT_OR_RETURN(Writer, one && zero);
+
+ if (auto* vec = res_ty->As<type::Vector>()) {
+ // Splat the scalars into vectors.
+ one = ir_->constant_values.Splat(vec, one, vec->Width());
+ zero = ir_->constant_values.Splat(vec, zero, vec->Width());
+ }
+
+ op = spv::Op::OpSelect;
+ operands.push_back(Constant(one));
+ operands.push_back(Constant(zero));
+ } else {
+ TINT_ICE(Writer, diagnostics_) << "unhandled convert instruction";
+ }
+
+ current_function_.push_inst(op, std::move(operands));
+}
+
void GeneratorImplIr::EmitLoad(ir::Load* load) {
current_function_.push_inst(spv::Op::OpLoad,
{Type(load->Result()->Type()), Value(load), Value(load->From())});
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir.h b/src/tint/writer/spirv/ir/generator_impl_ir.h
index 34e5899..3a463c2 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir.h
+++ b/src/tint/writer/spirv/ir/generator_impl_ir.h
@@ -36,6 +36,7 @@
class BuiltinCall;
class Construct;
class ControlInstruction;
+class Convert;
class ExitIf;
class ExitLoop;
class ExitSwitch;
@@ -195,6 +196,10 @@
/// @param construct the construct instruction to emit
void EmitConstruct(ir::Construct* construct);
+ /// Emit a convert instruction.
+ /// @param convert the convert instruction to emit
+ void EmitConvert(ir::Convert* convert);
+
/// Emit a load instruction.
/// @param load the load instruction to emit
void EmitLoad(ir::Load* load);
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_convert_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_convert_test.cc
new file mode 100644
index 0000000..22481bb
--- /dev/null
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_convert_test.cc
@@ -0,0 +1,102 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/writer/spirv/ir/test_helper_ir.h"
+
+namespace tint::writer::spirv {
+namespace {
+
+using namespace tint::builtin::fluent_types; // NOLINT
+using namespace tint::number_suffixes; // NOLINT
+
+/// A parameterized test case.
+struct ConvertCase {
+ /// The input type.
+ TestElementType in;
+ /// The output type.
+ TestElementType out;
+ /// The expected SPIR-V instruction.
+ std::string spirv_inst;
+ /// The expected SPIR-V result type name.
+ std::string spirv_type_name;
+};
+std::string PrintCase(testing::TestParamInfo<ConvertCase> cc) {
+ utils::StringStream ss;
+ ss << cc.param.in << "_to_" << cc.param.out;
+ return ss.str();
+}
+
+using Convert = SpvGeneratorImplTestWithParam<ConvertCase>;
+TEST_P(Convert, Scalar) {
+ auto& params = GetParam();
+ auto* func = b.Function("foo", MakeScalarType(params.out));
+ func->SetParams({b.FunctionParam("arg", MakeScalarType(params.in))});
+ b.With(func->Block(), [&] {
+ auto* result = b.Convert(MakeScalarType(params.out), func->Params()[0]);
+ b.Return(func, result);
+ mod.SetName(result, "result");
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%result = " + params.spirv_inst + " %" + params.spirv_type_name + " %arg");
+}
+TEST_P(Convert, Vector) {
+ auto& params = GetParam();
+ auto* func = b.Function("foo", MakeVectorType(params.out));
+ func->SetParams({b.FunctionParam("arg", MakeVectorType(params.in))});
+ b.With(func->Block(), [&] {
+ auto* result = b.Convert(MakeVectorType(params.out), func->Params()[0]);
+ b.Return(func, result);
+ mod.SetName(result, "result");
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%result = " + params.spirv_inst + " %v2" + params.spirv_type_name + " %arg");
+}
+INSTANTIATE_TEST_SUITE_P(SpvGeneratorImplTest,
+ Convert,
+ testing::Values(
+ // To f32.
+ ConvertCase{kF16, kF32, "OpFConvert", "float"},
+ ConvertCase{kI32, kF32, "OpConvertSToF", "float"},
+ ConvertCase{kU32, kF32, "OpConvertUToF", "float"},
+ ConvertCase{kBool, kF32, "OpSelect", "float"},
+
+ // To f16.
+ ConvertCase{kF32, kF16, "OpFConvert", "half"},
+ ConvertCase{kI32, kF16, "OpConvertSToF", "half"},
+ ConvertCase{kU32, kF16, "OpConvertUToF", "half"},
+ ConvertCase{kBool, kF16, "OpSelect", "half"},
+
+ // To i32.
+ ConvertCase{kF32, kI32, "OpConvertFToS", "int"},
+ ConvertCase{kF16, kI32, "OpConvertFToS", "int"},
+ ConvertCase{kU32, kI32, "OpBitcast", "int"},
+ ConvertCase{kBool, kI32, "OpSelect", "int"},
+
+ // To u32.
+ ConvertCase{kF32, kU32, "OpConvertFToU", "uint"},
+ ConvertCase{kF16, kU32, "OpConvertFToU", "uint"},
+ ConvertCase{kI32, kU32, "OpBitcast", "uint"},
+ ConvertCase{kBool, kU32, "OpSelect", "uint"},
+
+ // To bool.
+ ConvertCase{kF32, kBool, "OpFUnordNotEqual", "bool"},
+ ConvertCase{kF16, kBool, "OpFUnordNotEqual", "bool"},
+ ConvertCase{kI32, kBool, "OpINotEqual", "bool"},
+ ConvertCase{kU32, kBool, "OpINotEqual", "bool"}),
+ PrintCase);
+
+} // namespace
+} // namespace tint::writer::spirv
diff --git a/src/tint/writer/spirv/ir/test_helper_ir.h b/src/tint/writer/spirv/ir/test_helper_ir.h
index 5fbf6e4..dfc05b8 100644
--- a/src/tint/writer/spirv/ir/test_helper_ir.h
+++ b/src/tint/writer/spirv/ir/test_helper_ir.h
@@ -41,6 +41,26 @@
kF32,
kF16,
};
+inline utils::StringStream& operator<<(utils::StringStream& out, TestElementType type) {
+ switch (type) {
+ case kBool:
+ out << "bool";
+ break;
+ case kI32:
+ out << "i32";
+ break;
+ case kU32:
+ out << "u32";
+ break;
+ case kF32:
+ out << "f32";
+ break;
+ case kF16:
+ out << "f16";
+ break;
+ }
+ return out;
+}
/// Base helper class for testing the SPIR-V generator implementation.
template <typename BASE>