[ir][msl] Emit `var` instructions.
Add emission of `var` instructions to the MSL IR Generator.
Bug: tint:1967
Change-Id: I32e0a22262e1959872a1fc196152528ee7099b95
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/144402
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index d60dc59..372bfae 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -2328,6 +2328,7 @@
"lang/msl/writer/printer/let_test.cc",
"lang/msl/writer/printer/return_test.cc",
"lang/msl/writer/printer/type_test.cc",
+ "lang/msl/writer/printer/var_test.cc",
]
deps += [
":libtint_ir_src",
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index 890134e..886cac2 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -1561,6 +1561,7 @@
lang/msl/writer/printer/let_test.cc
lang/msl/writer/printer/return_test.cc
lang/msl/writer/printer/type_test.cc
+ lang/msl/writer/printer/var_test.cc
)
endif()
endif()
diff --git a/src/tint/lang/msl/writer/printer/let_test.cc b/src/tint/lang/msl/writer/printer/let_test.cc
index 888ac16..519d905 100644
--- a/src/tint/lang/msl/writer/printer/let_test.cc
+++ b/src/tint/lang/msl/writer/printer/let_test.cc
@@ -179,7 +179,7 @@
)");
}
-TEST_F(MslPrinterTest, LetArrVec2BoolEmit_VariableDeclStatement_Const_arr_vec2_bool) {
+TEST_F(MslPrinterTest, LetArrVec2Bool) {
auto* func = b.Function("foo", ty.void_());
b.Append(func->Block(), [&] {
b.Let("l", b.Composite(ty.array<vec2<bool>, 3>(), //
diff --git a/src/tint/lang/msl/writer/printer/printer.cc b/src/tint/lang/msl/writer/printer/printer.cc
index efa20c1..831efff 100644
--- a/src/tint/lang/msl/writer/printer/printer.cc
+++ b/src/tint/lang/msl/writer/printer/printer.cc
@@ -22,10 +22,12 @@
#include "src/tint/lang/core/ir/exit_if.h"
#include "src/tint/lang/core/ir/if.h"
#include "src/tint/lang/core/ir/let.h"
+#include "src/tint/lang/core/ir/load.h"
#include "src/tint/lang/core/ir/multi_in_block.h"
#include "src/tint/lang/core/ir/return.h"
#include "src/tint/lang/core/ir/unreachable.h"
#include "src/tint/lang/core/ir/validator.h"
+#include "src/tint/lang/core/ir/var.h"
#include "src/tint/lang/core/type/array.h"
#include "src/tint/lang/core/type/atomic.h"
#include "src/tint/lang/core/type/bool.h"
@@ -84,7 +86,7 @@
// Emit module-scope declarations.
if (ir_->root_block) {
- // EmitRootBlock(ir_->root_block);
+ EmitBlockInstructions(ir_->root_block);
}
// Emit functions.
@@ -170,12 +172,59 @@
[&](ir::ExitIf* e) { EmitExitIf(e); }, //
[&](ir::If* if_) { EmitIf(if_); }, //
[&](ir::Let* l) { EmitLet(l); }, //
+ [&](ir::Load* l) { EmitLoad(l); }, //
[&](ir::Return* r) { EmitReturn(r); }, //
[&](ir::Unreachable*) { EmitUnreachable(); }, //
+ [&](ir::Var* v) { EmitVar(v); }, //
[&](Default) { TINT_ICE() << "unimplemented instruction: " << inst->TypeInfo().name; });
}
}
+void Printer::EmitLoad(ir::Load* l) {
+ // Force loads to be bound as inlines
+ bindings_.Add(l->Result(), InlinedValue{Expr(l->From()), PtrKind::kRef});
+}
+
+void Printer::EmitVar(ir::Var* v) {
+ auto out = Line();
+
+ auto* ptr = v->Result()->Type()->As<type::Pointer>();
+ TINT_ASSERT_OR_RETURN(ptr);
+
+ auto space = ptr->AddressSpace();
+ switch (space) {
+ case builtin::AddressSpace::kFunction:
+ case builtin::AddressSpace::kHandle:
+ break;
+ case builtin::AddressSpace::kPrivate:
+ out << "thread ";
+ break;
+ case builtin::AddressSpace::kWorkgroup:
+ out << "threadgroup ";
+ break;
+ default:
+ TINT_ICE() << "unhandled variable address space";
+ return;
+ }
+
+ auto name = ir_->NameOf(v);
+
+ EmitType(out, ptr->UnwrapPtr());
+ out << " " << name.Name();
+
+ if (v->Initializer()) {
+ out << " = " << Expr(v->Initializer());
+ } else if (space == builtin::AddressSpace::kPrivate ||
+ space == builtin::AddressSpace::kFunction ||
+ space == builtin::AddressSpace::kUndefined) {
+ out << " = ";
+ EmitZeroValue(out, ptr->UnwrapPtr());
+ }
+ out << ";";
+
+ Bind(v->Result(), name, PtrKind::kRef);
+}
+
void Printer::EmitLet(ir::Let* l) {
Bind(l->Result(), Expr(l->Value(), PtrKind::kPtr), PtrKind::kPtr);
}
@@ -225,8 +274,8 @@
}
void Printer::EmitReturn(ir::Return* r) {
- // If this return has no arguments and the current block is for the function which is being
- // returned, skip the return.
+ // If this return has no arguments and the current block is for the function which is
+ // being returned, skip the return.
if (current_block_ == current_function_->Block() && r->Args().IsEmpty()) {
return;
}
@@ -419,10 +468,10 @@
return;
}
- // This does not append directly to the preamble because a struct may require other structs, or
- // the array template, to get emitted before it. So, the struct emits into a temporary text
- // buffer, then anything it depends on will emit to the preamble first, and then it copies the
- // text buffer into the preamble.
+ // This does not append directly to the preamble because a struct may require other
+ // structs, or the array template, to get emitted before it. So, the struct emits into a
+ // temporary text buffer, then anything it depends on will emit to the preamble first,
+ // and then it copies the text buffer into the preamble.
TextBuffer str_buf;
Line(&str_buf) << "struct " << StructName(str) << " {";
@@ -623,6 +672,25 @@
[&](Default) { UNHANDLED_CASE(c->Type()); });
}
+void Printer::EmitZeroValue(StringStream& out, const type::Type* ty) {
+ Switch(
+ ty, [&](const type::Bool*) { out << "false"; }, //
+ [&](const type::F16*) { out << "0.0h"; }, //
+ [&](const type::F32*) { out << "0.0f"; }, //
+ [&](const type::I32*) { out << "0"; }, //
+ [&](const type::U32*) { out << "0u"; }, //
+ [&](const type::Vector* vec) { EmitZeroValue(out, vec->type()); }, //
+ [&](const type::Matrix* mat) {
+ EmitType(out, mat);
+
+ ScopedParen sp(out);
+ EmitZeroValue(out, mat->type());
+ },
+ [&](const type::Array*) { out << "{}"; }, //
+ [&](const type::Struct*) { out << "{}"; }, //
+ [&](Default) { TINT_ICE() << "Invalid type for zero emission: " << ty->FriendlyName(); });
+}
+
std::string Printer::StructName(const type::Struct* s) {
auto name = s->Name().Name();
if (HasPrefix(name, "__")) {
@@ -665,10 +733,12 @@
}
if constexpr (std::is_same_v<T, InlinedValue>) {
+ auto result = ExprAndPtrKind{got.expr, got.ptr_kind};
+
// Single use (inlined) expression.
// Mark the bindings_ map entry as consumed.
*lookup = ConsumedValue{};
- return {got.expr, got.ptr_kind};
+ return result;
}
if constexpr (std::is_same_v<T, ConsumedValue>) {
diff --git a/src/tint/lang/msl/writer/printer/printer.h b/src/tint/lang/msl/writer/printer/printer.h
index 2bf3b76..d396965 100644
--- a/src/tint/lang/msl/writer/printer/printer.h
+++ b/src/tint/lang/msl/writer/printer/printer.h
@@ -30,8 +30,10 @@
class ExitIf;
class If;
class Let;
+class Load;
class Return;
class Unreachable;
+class Var;
} // namespace tint::ir
namespace tint::msl::writer {
@@ -71,6 +73,12 @@
/// Emit a let instruction
/// @param l the let instruction
void EmitLet(ir::Let* l);
+ /// Emit a var instruction
+ /// @param v the var instruction
+ void EmitVar(ir::Var* v);
+ /// Emit a load instruction
+ /// @param l the load instruction
+ void EmitLoad(ir::Load* l);
/// Emit a return instruction
/// @param r the return instruction
@@ -126,6 +134,11 @@
/// @param c the constant to emit
void EmitConstant(StringStream& out, const constant::Value* c);
+ /// Emits the zero value for the given type
+ /// @param out the stream to emit too
+ /// @param ty the type
+ void EmitZeroValue(StringStream& out, const type::Type* ty);
+
/// @returns the name of the templated `tint_array` helper type, generating it if needed
const std::string& ArrayTemplateName();
diff --git a/src/tint/lang/msl/writer/printer/var_test.cc b/src/tint/lang/msl/writer/printer/var_test.cc
new file mode 100644
index 0000000..b13529a
--- /dev/null
+++ b/src/tint/lang/msl/writer/printer/var_test.cc
@@ -0,0 +1,288 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/msl/writer/printer/helper_test.h"
+
+namespace tint::msl::writer {
+namespace {
+
+using namespace tint::builtin::fluent_types; // NOLINT
+using namespace tint::number_suffixes; // NOLINT
+
+TEST_F(MslPrinterTest, VarF32) {
+ auto* func = b.Function("foo", ty.void_());
+ b.Append(func->Block(), [&] {
+ b.Var("a", ty.ptr<builtin::AddressSpace::kFunction, f32>());
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+ EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+ float a = 0.0f;
+}
+)");
+}
+
+TEST_F(MslPrinterTest, VarI32) {
+ auto* func = b.Function("foo", ty.void_());
+ b.Append(func->Block(), [&] {
+ auto* v = b.Var("a", ty.ptr<builtin::AddressSpace::kFunction, i32>());
+ v->SetInitializer(b.Constant(1_i));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+ EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+ int a = 1;
+}
+)");
+}
+
+TEST_F(MslPrinterTest, VarU32) {
+ auto* func = b.Function("foo", ty.void_());
+ b.Append(func->Block(), [&] {
+ auto* v = b.Var("a", ty.ptr<builtin::AddressSpace::kFunction, u32>());
+ v->SetInitializer(b.Constant(1_u));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+ EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+ uint a = 1u;
+}
+)");
+}
+
+TEST_F(MslPrinterTest, VarArrayF32) {
+ auto* func = b.Function("foo", ty.void_());
+ b.Append(func->Block(), [&] {
+ b.Var("a", ty.ptr<builtin::AddressSpace::kFunction, array<f32, 5>>());
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+ EXPECT_EQ(generator_.Result(), MetalHeader() + MetalArray() + R"(
+void foo() {
+ tint_array<float, 5> a = {};
+}
+)");
+}
+
+TEST_F(MslPrinterTest, VarStruct) {
+ auto* s = ty.Struct(mod.symbols.New("MyStruct"), {{mod.symbols.Register("a"), ty.f32()}, //
+ {mod.symbols.Register("b"), ty.vec4<i32>()}});
+
+ auto* func = b.Function("foo", ty.void_());
+ b.Append(func->Block(), [&] {
+ b.Var("a", ty.ptr(builtin::AddressSpace::kFunction, s));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+ EXPECT_EQ(generator_.Result(), MetalHeader() + R"(struct MyStruct {
+ float a;
+ int4 b;
+};
+
+void foo() {
+ MyStruct a = {};
+}
+)");
+}
+
+TEST_F(MslPrinterTest, VarVecF32) {
+ auto* func = b.Function("foo", ty.void_());
+ b.Append(func->Block(), [&] {
+ b.Var("a", ty.ptr<builtin::AddressSpace::kFunction, vec2<f32>>());
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+ EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+ float2 a = 0.0f;
+}
+)");
+}
+
+TEST_F(MslPrinterTest, VarVecF16) {
+ // Enable f16?
+ auto* func = b.Function("foo", ty.void_());
+ b.Append(func->Block(), [&] {
+ b.Var("a", ty.ptr<builtin::AddressSpace::kFunction, vec2<f16>>());
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+ EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+ half2 a = 0.0h;
+}
+)");
+}
+
+TEST_F(MslPrinterTest, VarMatF32) {
+ auto* func = b.Function("foo", ty.void_());
+ b.Append(func->Block(), [&] {
+ b.Var("a", ty.ptr<builtin::AddressSpace::kFunction, mat3x2<f32>>());
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+ EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+ float3x2 a = float3x2(0.0f);
+}
+)");
+}
+
+TEST_F(MslPrinterTest, VarMatF16) {
+ // Enable f16?
+ auto* func = b.Function("foo", ty.void_());
+ b.Append(func->Block(), [&] {
+ b.Var("a", ty.ptr<builtin::AddressSpace::kFunction, mat3x2<f16>>());
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+ EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+ half3x2 a = half3x2(0.0h);
+}
+)");
+}
+
+TEST_F(MslPrinterTest, VarVecF32SplatZero) {
+ auto* func = b.Function("foo", ty.void_());
+ b.Append(func->Block(), [&] {
+ auto* v = b.Var("a", ty.ptr<builtin::AddressSpace::kFunction, vec3<f32>>());
+ v->SetInitializer(b.Splat(ty.vec3<f32>(), 0_f, 3));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+ EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+ float3 a = float3(0.0f);
+}
+)");
+}
+
+TEST_F(MslPrinterTest, VarVecF16SplatZero) {
+ // Enable f16
+ auto* func = b.Function("foo", ty.void_());
+ b.Append(func->Block(), [&] {
+ auto* v = b.Var("a", ty.ptr<builtin::AddressSpace::kFunction, vec3<f16>>());
+ v->SetInitializer(b.Splat(ty.vec3<f16>(), 0_h, 3));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+ EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+ half3 a = half3(0.0h);
+}
+)");
+}
+
+TEST_F(MslPrinterTest, VarMatF32SplatZero) {
+ auto* func = b.Function("foo", ty.void_());
+ b.Append(func->Block(), [&] {
+ auto* v = b.Var("a", ty.ptr<builtin::AddressSpace::kFunction, mat2x3<f32>>());
+ v->SetInitializer(b.Composite(ty.mat2x3<f32>(), b.Splat(ty.vec3<f32>(), 0_f, 3),
+ b.Splat(ty.vec3<f32>(), 0_f, 3)));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+ EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+ float2x3 a = float2x3(float3(0.0f), float3(0.0f));
+}
+)");
+}
+
+TEST_F(MslPrinterTest, VarMatF16SplatZero) {
+ // Enable f16?
+ auto* func = b.Function("foo", ty.void_());
+ b.Append(func->Block(), [&] {
+ auto* v = b.Var("a", ty.ptr<builtin::AddressSpace::kFunction, mat2x3<f16>>());
+ v->SetInitializer(b.Composite(ty.mat2x3<f16>(), b.Splat(ty.vec3<f16>(), 0_h, 3),
+ b.Splat(ty.vec3<f16>(), 0_h, 3)));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+ EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+ half2x3 a = half2x3(half3(0.0h), half3(0.0h));
+}
+)");
+}
+
+// TODO(dsinclair): Requires ModuleScopeVarToEntryPointParam transform
+TEST_F(MslPrinterTest, DISABLED_VarGlobalPrivate) {
+ ir::Var* v = nullptr;
+ b.Append(b.RootBlock(),
+ [&] { v = b.Var("v", ty.ptr<builtin::AddressSpace::kPrivate, f32>()); });
+
+ auto* func = b.Function("foo", ty.void_());
+ b.Append(func->Block(), [&] {
+ auto* ld = b.Load(v->Result());
+ auto* a = b.Var("a", ty.ptr<builtin::AddressSpace::kFunction, f32>());
+ a->SetInitializer(ld->Result());
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+ EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+struct tint_private_vars_struct {
+ float a;
+};
+
+void foo() {
+ thread tint_private_vars_struct tint_private_vars = {};
+ float const a = tint_private_vars.a;
+ return;
+}
+)");
+}
+
+TEST_F(MslPrinterTest, VarGlobalWorkgroup) {
+ ir::Var* v = nullptr;
+ b.Append(b.RootBlock(),
+ [&] { v = b.Var("v", ty.ptr<builtin::AddressSpace::kWorkgroup, f32>()); });
+
+ auto* func = b.Function("foo", ty.void_());
+ b.Append(func->Block(), [&] {
+ auto* ld = b.Load(v->Result());
+ auto* a = b.Var("a", ty.ptr<builtin::AddressSpace::kFunction, f32>());
+ a->SetInitializer(ld->Result());
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+ EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+threadgroup float v;
+void foo() {
+ float a = v;
+}
+)");
+}
+
+} // namespace
+} // namespace tint::msl::writer