[ir][msl] Emit Lets
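Add emission of `let` instructions to the MSL IR generator. This also adds
`ConstantValue()` and `Composite()` helpers to the IR builder, and a
`MetalArray()` helper to the MSL printer test fixture, to support the new tests.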
Bug: tint:1967
Change-Id: I5b6af981bce7ee4e497b101c95a1fc662f70cc3e
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/144042
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index 16920a1..0c554e2 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -2323,6 +2323,7 @@
"lang/msl/writer/printer/function_test.cc",
"lang/msl/writer/printer/helper_test.h",
"lang/msl/writer/printer/if_test.cc",
+ "lang/msl/writer/printer/let_test.cc",
"lang/msl/writer/printer/return_test.cc",
"lang/msl/writer/printer/type_test.cc",
]
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index 7c52c4b..89e7d6c 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -1556,6 +1556,7 @@
lang/msl/writer/printer/function_test.cc
lang/msl/writer/printer/helper_test.h
lang/msl/writer/printer/if_test.cc
+ lang/msl/writer/printer/let_test.cc
lang/msl/writer/printer/return_test.cc
lang/msl/writer/printer/type_test.cc
)
diff --git a/src/tint/lang/core/ir/builder.h b/src/tint/lang/core/ir/builder.h
index 2048955..a71257c 100644
--- a/src/tint/lang/core/ir/builder.h
+++ b/src/tint/lang/core/ir/builder.h
@@ -57,10 +57,12 @@
#include "src/tint/lang/core/ir/user_call.h"
#include "src/tint/lang/core/ir/value.h"
#include "src/tint/lang/core/ir/var.h"
+#include "src/tint/lang/core/type/array.h"
#include "src/tint/lang/core/type/bool.h"
#include "src/tint/lang/core/type/f16.h"
#include "src/tint/lang/core/type/f32.h"
#include "src/tint/lang/core/type/i32.h"
+#include "src/tint/lang/core/type/matrix.h"
#include "src/tint/lang/core/type/pointer.h"
#include "src/tint/lang/core/type/u32.h"
#include "src/tint/lang/core/type/vector.h"
@@ -219,29 +221,72 @@
+ // Each scalar Constant() overload below wraps the constant::Value returned by the matching
+ // ConstantValue() overload.
/// Creates a ir::Constant for an i32 Scalar
/// @param v the value
/// @returns the new constant
- ir::Constant* Constant(i32 v) { return Constant(ir.constant_values.Get(v)); }
+ ir::Constant* Constant(i32 v) { return Constant(ConstantValue(v)); }
/// Creates a ir::Constant for a u32 Scalar
/// @param v the value
/// @returns the new constant
- ir::Constant* Constant(u32 v) { return Constant(ir.constant_values.Get(v)); }
+ ir::Constant* Constant(u32 v) { return Constant(ConstantValue(v)); }
/// Creates a ir::Constant for a f32 Scalar
/// @param v the value
/// @returns the new constant
- ir::Constant* Constant(f32 v) { return Constant(ir.constant_values.Get(v)); }
+ ir::Constant* Constant(f32 v) { return Constant(ConstantValue(v)); }
/// Creates a ir::Constant for a f16 Scalar
/// @param v the value
/// @returns the new constant
- ir::Constant* Constant(f16 v) { return Constant(ir.constant_values.Get(v)); }
+ ir::Constant* Constant(f16 v) { return Constant(ConstantValue(v)); }
/// Creates a ir::Constant for a bool Scalar
/// @param v the value
/// @returns the new constant
template <typename BOOL, typename = std::enable_if_t<std::is_same_v<BOOL, bool>>>
ir::Constant* Constant(BOOL v) {
- return Constant(ir.constant_values.Get(v));
+ return Constant(ConstantValue(v));
+ }
+
+ // The ConstantValue() overloads map each kind of element value accepted by Composite() (an
+ // existing ir::Constant, or an i32/u32/f32/f16/bool scalar) to a constant::Value.
+
+ /// Retrieves the constant::Value held by an ir::Constant
+ /// @param c the ir constant
+ /// @returns the constant::Value inside @p c
+ const constant::Value* ConstantValue(ir::Constant* c) {
+     // Deliberately not named `constant`: that would hide the `constant` namespace in this body.
+     return c->Value();
+ }
+
+ /// Returns the constant::Value for an i32 Scalar
+ /// @param v the value
+ /// @returns the constant::Value for @p v
+ const constant::Value* ConstantValue(i32 v) { return ir.constant_values.Get(v); }
+
+ /// Returns the constant::Value for a u32 Scalar
+ /// @param v the value
+ /// @returns the constant::Value for @p v
+ const constant::Value* ConstantValue(u32 v) { return ir.constant_values.Get(v); }
+
+ /// Returns the constant::Value for a f32 Scalar
+ /// @param v the value
+ /// @returns the constant::Value for @p v
+ const constant::Value* ConstantValue(f32 v) { return ir.constant_values.Get(v); }
+
+ /// Returns the constant::Value for a f16 Scalar
+ /// @param v the value
+ /// @returns the constant::Value for @p v
+ const constant::Value* ConstantValue(f16 v) { return ir.constant_values.Get(v); }
+
+ /// Returns the constant::Value for a bool Scalar
+ /// @param v the value
+ /// @returns the constant::Value for @p v
+ template <typename BOOL, typename = std::enable_if_t<std::is_same_v<BOOL, bool>>>
+ const constant::Value* ConstantValue(BOOL v) {
+     return ir.constant_values.Get(v);
+ }
+
+ /// Creates a new composite ir::Constant
+ /// @param ty the composite type (for example a vector, matrix or array type)
+ /// @param values the element values; each may be an ir::Constant* or an i32, u32, f32, f16 or
+ /// bool scalar (i.e. anything accepted by ConstantValue())
+ /// @returns the new constant
+ template <typename... ARGS, typename = DisableIfVectorLike<ARGS...>>
+ ir::Constant* Composite(const type::Type* ty, ARGS&&... values) {
+     // Convert every element to a constant::Value, then build the composite from them.
+     auto* composite =
+         ir.constant_values.Composite(ty, Vector{ConstantValue(std::forward<ARGS>(values))...});
+     return Constant(composite);
}
/// @param in the input value. One of: nullptr, ir::Value*, ir::Instruction* or a numeric value.
diff --git a/src/tint/lang/msl/writer/printer/helper_test.h b/src/tint/lang/msl/writer/printer/helper_test.h
index 152cf3b..9569f75 100644
--- a/src/tint/lang/msl/writer/printer/helper_test.h
+++ b/src/tint/lang/msl/writer/printer/helper_test.h
@@ -29,6 +29,20 @@
using namespace metal;
)";
+/// MSL source of the `tint_array` helper template. Tests that declare arrays expect this text
+/// immediately after the Metal header and before the generated functions.
+constexpr auto kMetalArray = R"(template<typename T, size_t N>
+struct tint_array {
+ const constant T& operator[](size_t i) const constant { return elements[i]; }
+ device T& operator[](size_t i) device { return elements[i]; }
+ const device T& operator[](size_t i) const device { return elements[i]; }
+ thread T& operator[](size_t i) thread { return elements[i]; }
+ const thread T& operator[](size_t i) const thread { return elements[i]; }
+ threadgroup T& operator[](size_t i) threadgroup { return elements[i]; }
+ const threadgroup T& operator[](size_t i) const threadgroup { return elements[i]; }
+ T elements[N];
+};
+
+)";
+
/// Base helper class for testing the MSL generator implementation.
template <typename BASE>
class MslPrinterTestHelperBase : public BASE {
@@ -44,6 +58,8 @@
/// @returns the metal header string
std::string MetalHeader() const { return kMetalHeader; }
+ /// @returns the MSL `tint_array` helper template, which tests expect between the Metal header
+ /// and the generated functions whenever an array type is used
+ std::string MetalArray() const { return kMetalArray; }
protected:
/// The MSL generator.
diff --git a/src/tint/lang/msl/writer/printer/let_test.cc b/src/tint/lang/msl/writer/printer/let_test.cc
new file mode 100644
index 0000000..888ac16
--- /dev/null
+++ b/src/tint/lang/msl/writer/printer/let_test.cc
@@ -0,0 +1,204 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/msl/writer/printer/helper_test.h"
+
+namespace tint::msl::writer {
+namespace {
+
+using namespace tint::builtin::fluent_types; // NOLINT
+using namespace tint::number_suffixes; // NOLINT
+
+// Each test below builds a function `foo` containing `let` instructions and checks the exact MSL
+// that the printer generates for it.
+
+TEST_F(MslPrinterTest, LetU32) {
+ auto* func = b.Function("foo", ty.void_());
+ b.Append(func->Block(), [&] {
+ b.Let("l", 42_u);
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+ EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+ uint const l = 42u;
+}
+)");
+}
+
+TEST_F(MslPrinterTest, LetDuplicate) {
+ auto* func = b.Function("foo", ty.void_());
+ b.Append(func->Block(), [&] {
+ b.Let("l1", 42_u);
+ b.Let("l2", 42_u);
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+ EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+ uint const l1 = 42u;
+ uint const l2 = 42u;
+}
+)");
+}
+
+TEST_F(MslPrinterTest, LetF32) {
+ auto* func = b.Function("foo", ty.void_());
+ b.Append(func->Block(), [&] {
+ b.Let("l", 42.0_f);
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+ EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+ float const l = 42.0f;
+}
+)");
+}
+
+TEST_F(MslPrinterTest, LetI32) {
+ auto* func = b.Function("foo", ty.void_());
+ b.Append(func->Block(), [&] {
+ b.Let("l", 42_i);
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+ EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+ int const l = 42;
+}
+)");
+}
+
+TEST_F(MslPrinterTest, LetF16) {
+ // TODO(review): Confirm whether f16 support must be enabled on the module for this test.
+ auto* func = b.Function("foo", ty.void_());
+ b.Append(func->Block(), [&] {
+ b.Let("l", 42_h);
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+ EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+ half const l = 42.0h;
+}
+)");
+}
+
+TEST_F(MslPrinterTest, LetVec3F32) {
+ auto* func = b.Function("foo", ty.void_());
+ b.Append(func->Block(), [&] {
+ b.Let("l", b.Composite(ty.vec3<f32>(), 1_f, 2_f, 3_f));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+ EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+ float3 const l = float3(1.0f, 2.0f, 3.0f);
+}
+)");
+}
+
+TEST_F(MslPrinterTest, LetVec3F16) {
+ // TODO(review): Confirm whether f16 support must be enabled on the module for this test.
+ auto* func = b.Function("foo", ty.void_());
+ b.Append(func->Block(), [&] {
+ b.Let("l", b.Composite(ty.vec3<f16>(), 1_h, 2_h, 3_h));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+ EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+ half3 const l = half3(1.0h, 2.0h, 3.0h);
+}
+)");
+}
+
+TEST_F(MslPrinterTest, LetMat2x3F32) {
+ auto* func = b.Function("foo", ty.void_());
+ b.Append(func->Block(), [&] {
+ b.Let("l", b.Composite(ty.mat2x3<f32>(), //
+ b.Composite(ty.vec3<f32>(), 1_f, 2_f, 3_f),
+ b.Composite(ty.vec3<f32>(), 4_f, 5_f, 6_f)));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+ EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+ float2x3 const l = float2x3(float3(1.0f, 2.0f, 3.0f), float3(4.0f, 5.0f, 6.0f));
+}
+)");
+}
+
+TEST_F(MslPrinterTest, LetMat2x3F16) {
+ // TODO(review): Confirm whether f16 support must be enabled on the module for this test.
+ auto* func = b.Function("foo", ty.void_());
+ b.Append(func->Block(), [&] {
+ b.Let("l", b.Composite(ty.mat2x3<f16>(), //
+ b.Composite(ty.vec3<f16>(), 1_h, 2_h, 3_h),
+ b.Composite(ty.vec3<f16>(), 4_h, 5_h, 6_h)));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+ EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+ half2x3 const l = half2x3(half3(1.0h, 2.0h, 3.0h), half3(4.0h, 5.0h, 6.0h));
+}
+)");
+}
+
+TEST_F(MslPrinterTest, LetArrF32) {
+ auto* func = b.Function("foo", ty.void_());
+ b.Append(func->Block(), [&] {
+ b.Let("l", b.Composite(ty.array<f32, 3>(), 1_f, 2_f, 3_f));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+ EXPECT_EQ(generator_.Result(), MetalHeader() + MetalArray() + R"(
+void foo() {
+ tint_array<float, 3> const l = tint_array<float, 3>{1.0f, 2.0f, 3.0f};
+}
+)");
+}
+
+TEST_F(MslPrinterTest, LetArrVec2Bool) {
+ auto* func = b.Function("foo", ty.void_());
+ b.Append(func->Block(), [&] {
+ b.Let("l", b.Composite(ty.array<vec2<bool>, 3>(), //
+ b.Composite(ty.vec2<bool>(), true, false),
+ b.Composite(ty.vec2<bool>(), false, true),
+ b.Composite(ty.vec2<bool>(), true, false)));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+ EXPECT_EQ(generator_.Result(), MetalHeader() + MetalArray() + R"(
+void foo() {
+ tint_array<bool2, 3> const l = tint_array<bool2, 3>{bool2(true, false), bool2(false, true), bool2(true, false)};
+}
+)");
+}
+
+} // namespace
+} // namespace tint::msl::writer
diff --git a/src/tint/lang/msl/writer/printer/printer.cc b/src/tint/lang/msl/writer/printer/printer.cc
index 1f0a5cc..efa20c1 100644
--- a/src/tint/lang/msl/writer/printer/printer.cc
+++ b/src/tint/lang/msl/writer/printer/printer.cc
@@ -21,6 +21,7 @@
#include "src/tint/lang/core/ir/constant.h"
#include "src/tint/lang/core/ir/exit_if.h"
#include "src/tint/lang/core/ir/if.h"
+#include "src/tint/lang/core/ir/let.h"
#include "src/tint/lang/core/ir/multi_in_block.h"
#include "src/tint/lang/core/ir/return.h"
#include "src/tint/lang/core/ir/unreachable.h"
@@ -168,12 +169,17 @@
inst, //
[&](ir::ExitIf* e) { EmitExitIf(e); }, //
[&](ir::If* if_) { EmitIf(if_); }, //
+ [&](ir::Let* l) { EmitLet(l); }, //
[&](ir::Return* r) { EmitReturn(r); }, //
[&](ir::Unreachable*) { EmitUnreachable(); }, //
[&](Default) { TINT_ICE() << "unimplemented instruction: " << inst->TypeInfo().name; });
}
}
+// Emits a `let` by binding the let's result to the MSL expression of its value. The printer tests
+// expect this to produce a `<type> const <name> = <expr>;` declaration.
+void Printer::EmitLet(ir::Let* let) {
+    auto* value = let->Value();
+    auto* result = let->Result();
+    Bind(result, Expr(value, PtrKind::kPtr), PtrKind::kPtr);
+}
+
void Printer::EmitIf(ir::If* if_) {
// Emit any nodes that need to be used as PHI nodes
for (auto* phi : if_->Results()) {
@@ -723,6 +729,7 @@
} else {
out << expr;
}
+ out << ";";
Bind(value, mod_name, PtrKind::kPtr);
}
diff --git a/src/tint/lang/msl/writer/printer/printer.h b/src/tint/lang/msl/writer/printer/printer.h
index 3144147..2bf3b76 100644
--- a/src/tint/lang/msl/writer/printer/printer.h
+++ b/src/tint/lang/msl/writer/printer/printer.h
@@ -29,6 +29,7 @@
namespace tint::ir {
class ExitIf;
class If;
+class Let;
class Return;
class Unreachable;
} // namespace tint::ir
@@ -67,6 +68,10 @@
/// @param e the exit-if instruction
void EmitExitIf(ir::ExitIf* e);
+ /// Emit a `let` instruction as an immutable local declaration
+ /// @param let the let instruction to emit
+ void EmitLet(ir::Let* let);
+
/// Emit a return instruction
/// @param r the return instruction
void EmitReturn(ir::Return* r);