[msl] Support handle types in ModuleScopeVars
These are passed around by value instead of as pointers, so add some
extra logic to drop the pointers and remove load instructions.
Bug: 42251016
Change-Id: Ie395f12fce678f62942610d2a6b50b5e6986cbb1
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/188842
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/msl/writer/printer/function_test.cc b/src/tint/lang/msl/writer/printer/function_test.cc
index 0513a48..8ba60c2 100644
--- a/src/tint/lang/msl/writer/printer/function_test.cc
+++ b/src/tint/lang/msl/writer/printer/function_test.cc
@@ -25,6 +25,7 @@
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#include "src/tint/lang/core/type/sampled_texture.h"
#include "src/tint/lang/msl/writer/printer/helper_test.h"
namespace tint::msl::writer {
@@ -41,7 +42,7 @@
)");
}
-TEST_F(MslPrinterTest, EntryPointParameterBindingPoint) {
+TEST_F(MslPrinterTest, EntryPointParameterBufferBindingPoint) {
auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
auto* storage = b.FunctionParam("storage", ty.ptr(core::AddressSpace::kStorage, ty.i32()));
auto* uniform = b.FunctionParam("uniform", ty.ptr(core::AddressSpace::kUniform, ty.i32()));
@@ -57,5 +58,22 @@
)");
}
+TEST_F(MslPrinterTest, EntryPointParameterHandleBindingPoint) {
+ auto* t = ty.Get<core::type::SampledTexture>(core::type::TextureDimension::k2d, ty.f32());
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ auto* texture = b.FunctionParam("texture", t);
+ auto* sampler = b.FunctionParam("sampler", ty.sampler());
+ texture->SetBindingPoint(0, 1);
+ sampler->SetBindingPoint(0, 2);
+ func->SetParams({texture, sampler});
+ func->Block()->Append(b.Return(func));
+
+ ASSERT_TRUE(Generate()) << err_ << output_;
+ EXPECT_EQ(output_, MetalHeader() + R"(
+fragment void foo(texture2d<float, access::sample> texture [[texture(1)]], sampler sampler [[sampler(2)]]) {
+}
+)");
+}
+
} // namespace
} // namespace tint::msl::writer
diff --git a/src/tint/lang/msl/writer/printer/printer.cc b/src/tint/lang/msl/writer/printer/printer.cc
index 7c58cd4..0ae4046 100644
--- a/src/tint/lang/msl/writer/printer/printer.cc
+++ b/src/tint/lang/msl/writer/printer/printer.cc
@@ -329,16 +329,28 @@
}
if (auto binding_point = param->BindingPoint()) {
- auto ptr = param->Type()->As<core::type::Pointer>();
TINT_ASSERT(binding_point->group == 0);
- switch (ptr->AddressSpace()) {
- case core::AddressSpace::kStorage:
- case core::AddressSpace::kUniform:
- out << " [[buffer(" << binding_point->binding << ")]]";
- break;
- default:
- TINT_UNREACHABLE() << "invalid address space with binding point: "
- << ptr->AddressSpace();
+ if (auto ptr = param->Type()->As<core::type::Pointer>()) {
+ switch (ptr->AddressSpace()) {
+ case core::AddressSpace::kStorage:
+ case core::AddressSpace::kUniform:
+ out << " [[buffer(" << binding_point->binding << ")]]";
+ break;
+ default:
+ TINT_UNREACHABLE() << "invalid address space with binding point: "
+ << ptr->AddressSpace();
+ }
+ } else {
+ // Handle types are declared by value instead of by pointer.
+ Switch(
+ param->Type(),
+ [&](const core::type::Texture*) {
+ out << " [[texture(" << binding_point->binding << ")]]";
+ },
+ [&](const core::type::Sampler*) {
+ out << " [[sampler(" << binding_point->binding << ")]]";
+ },
+ TINT_ICE_ON_NO_MATCH);
}
}
}
@@ -1040,7 +1052,6 @@
switch (sc) {
case core::AddressSpace::kFunction:
case core::AddressSpace::kPrivate:
- case core::AddressSpace::kHandle:
out << "thread";
break;
case core::AddressSpace::kWorkgroup:
diff --git a/src/tint/lang/msl/writer/raise/module_scope_vars.cc b/src/tint/lang/msl/writer/raise/module_scope_vars.cc
index 0cb07c1..00a569f 100644
--- a/src/tint/lang/msl/writer/raise/module_scope_vars.cc
+++ b/src/tint/lang/msl/writer/raise/module_scope_vars.cc
@@ -88,14 +88,32 @@
ProcessFunction(*func);
}
- // Replace uses of each module-scope variable with pointers extracted from the structure.
+ // Replace uses of each module-scope variable with values extracted from the structure.
uint32_t index = 0;
for (auto& var : module_vars) {
- var->Result(0)->ReplaceAllUsesWith([&](core::ir::Usage use) { //
- return GetPointerFromStruct(var, use.instruction, index);
+ Vector<core::ir::Instruction*, 16> to_destroy;
+ auto* ptr = var->Result(0)->Type()->As<core::type::Pointer>();
+ var->Result(0)->ForEachUse([&](core::ir::Usage use) { //
+ auto* extracted_variable = GetVariableFromStruct(var, use.instruction, index);
+
+ // We drop the pointer from handle variables and store them in the struct by value
+ // instead, so remove any load instructions for the handle address space.
+ if (use.instruction->Is<core::ir::Load>() &&
+ ptr->AddressSpace() == core::AddressSpace::kHandle) {
+ use.instruction->Result(0)->ReplaceAllUsesWith(extracted_variable);
+ to_destroy.Push(use.instruction);
+ return;
+ }
+
+ use.instruction->SetOperand(use.operand_index, extracted_variable);
});
var->Destroy();
index++;
+
+ // Clean up instructions that need to be removed.
+ for (auto* inst : to_destroy) {
+ inst->Destroy();
+ }
}
}
@@ -106,6 +124,13 @@
for (auto* global : *ir.root_block) {
if (auto* var = global->As<core::ir::Var>()) {
auto* type = var->Result(0)->Type();
+
+ // Handle types drop the pointer and are passed around by value.
+ auto* ptr = type->As<core::type::Pointer>();
+ if (ptr->AddressSpace() == core::AddressSpace::kHandle) {
+ type = ptr->StoreType();
+ }
+
auto name = ir.NameOf(var);
if (!name) {
name = ir.symbols.New();
@@ -181,6 +206,14 @@
decl = param;
break;
}
+ case core::AddressSpace::kHandle: {
+ // Handle types become function parameters and drop the pointer.
+ auto* param = b.FunctionParam(ptr->UnwrapPtr());
+ param->SetBindingPoint(var->BindingPoint());
+ func->AppendParam(param);
+ decl = param;
+ break;
+ }
default:
TINT_UNREACHABLE() << "unhandled address space: " << ptr->AddressSpace();
}
@@ -218,18 +251,26 @@
return param;
}
- /// Get a pointer from the module-scope variable replacement structure, inserting new access
+ /// Get a variable from the module-scope variable replacement structure, inserting new access
/// instructions before @p inst.
/// @param var the variable to get the replacement for
/// @param inst the instruction that uses the variable
/// @param index the index of the variable in the structure member list
- /// @returns the pointer extracted from the structure
- core::ir::Value* GetPointerFromStruct(core::ir::Var* var,
- core::ir::Instruction* inst,
- uint32_t index) {
+ /// @returns the variable extracted from the structure
+ core::ir::Value* GetVariableFromStruct(core::ir::Var* var,
+ core::ir::Instruction* inst,
+ uint32_t index) {
auto* func = ContainingFunction(inst);
auto* struct_value = function_to_struct_value.GetOr(func, nullptr);
- auto* access = b.Access(var->Result(0)->Type(), struct_value, u32(index));
+ auto* type = var->Result(0)->Type();
+
+ // Handle types drop the pointer and are passed around by value.
+ auto* ptr = type->As<core::type::Pointer>();
+ if (ptr->AddressSpace() == core::AddressSpace::kHandle) {
+ type = ptr->StoreType();
+ }
+
+ auto* access = b.Access(type, struct_value, u32(index));
access->InsertBefore(inst);
return access->Result(0);
}
diff --git a/src/tint/lang/msl/writer/raise/module_scope_vars_test.cc b/src/tint/lang/msl/writer/raise/module_scope_vars_test.cc
index 17ffdc1..3bef228 100644
--- a/src/tint/lang/msl/writer/raise/module_scope_vars_test.cc
+++ b/src/tint/lang/msl/writer/raise/module_scope_vars_test.cc
@@ -30,6 +30,7 @@
#include <utility>
#include "src/tint/lang/core/ir/transform/helper_test.h"
+#include "src/tint/lang/core/type/sampled_texture.h"
using namespace tint::core::fluent_types; // NOLINT
using namespace tint::core::number_suffixes; // NOLINT
@@ -306,6 +307,63 @@
EXPECT_EQ(expect, str());
}
+TEST_F(MslWriter_ModuleScopeVarsTest, HandleTypes) {
+ auto* t = ty.Get<core::type::SampledTexture>(core::type::TextureDimension::k2d, ty.f32());
+ auto* var_t = b.Var("t", ty.ptr<handle>(t));
+ auto* var_s = b.Var("s", ty.ptr<handle>(ty.sampler()));
+ var_t->SetBindingPoint(1, 2);
+ var_s->SetBindingPoint(3, 4);
+ mod.root_block->Append(var_t);
+ mod.root_block->Append(var_s);
+
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ auto* load_t = b.Load(var_t);
+ auto* load_s = b.Load(var_s);
+ b.Call<vec4<f32>>(core::BuiltinFn::kTextureSample, load_t, load_s, b.Splat<vec2<f32>>(0_f));
+ b.Return(func);
+ });
+
+ auto* src = R"(
+$B1: { # root
+ %t:ptr<handle, texture_2d<f32>, read> = var @binding_point(1, 2)
+ %s:ptr<handle, sampler, read> = var @binding_point(3, 4)
+}
+
+%foo = @fragment func():void {
+ $B2: {
+ %4:texture_2d<f32> = load %t
+ %5:sampler = load %s
+ %6:vec4<f32> = textureSample %4, %5, vec2<f32>(0.0f)
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+tint_module_vars_struct = struct @align(1) {
+ t:texture_2d<f32> @offset(0)
+ s:sampler @offset(0)
+}
+
+%foo = @fragment func(%t:texture_2d<f32> [@binding_point(1, 2)], %s:sampler [@binding_point(3, 4)]):void {
+ $B1: {
+ %4:tint_module_vars_struct = construct %t, %s
+ %tint_module_vars:tint_module_vars_struct = let %4
+ %6:texture_2d<f32> = access %tint_module_vars, 0u
+ %7:sampler = access %tint_module_vars, 1u
+ %8:vec4<f32> = textureSample %6, %7, vec2<f32>(0.0f)
+ ret
+ }
+}
+)";
+
+ Run(ModuleScopeVars);
+
+ EXPECT_EQ(expect, str());
+}
+
TEST_F(MslWriter_ModuleScopeVarsTest, MultipleAddressSpaces) {
auto* var_a = b.Var("a", ty.ptr<uniform, i32, core::Access::kRead>());
auto* var_b = b.Var("b", ty.ptr<storage, i32, core::Access::kReadWrite>());
@@ -607,6 +665,84 @@
EXPECT_EQ(expect, str());
}
+TEST_F(MslWriter_ModuleScopeVarsTest, CallFunctionThatUsesVars_HandleTypes) {
+ auto* t = ty.Get<core::type::SampledTexture>(core::type::TextureDimension::k2d, ty.f32());
+ auto* var_t = b.Var("t", ty.ptr<handle>(t));
+ auto* var_s = b.Var("s", ty.ptr<handle>(ty.sampler()));
+ var_t->SetBindingPoint(1, 2);
+ var_s->SetBindingPoint(3, 4);
+ mod.root_block->Append(var_t);
+ mod.root_block->Append(var_s);
+
+ auto* foo = b.Function("foo", ty.vec4<f32>());
+ auto* param = b.FunctionParam<i32>("param");
+ foo->SetParams({param});
+ b.Append(foo->Block(), [&] {
+ auto* load_t = b.Load(var_t);
+ auto* load_s = b.Load(var_s);
+ auto* result = b.Call<vec4<f32>>(core::BuiltinFn::kTextureSample, load_t, load_s,
+ b.Splat<vec2<f32>>(0_f));
+ b.Return(foo, result);
+ });
+
+ auto* func = b.Function("main", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ b.Call(foo, 42_i);
+ b.Return(func);
+ });
+
+ auto* src = R"(
+$B1: { # root
+ %t:ptr<handle, texture_2d<f32>, read> = var @binding_point(1, 2)
+ %s:ptr<handle, sampler, read> = var @binding_point(3, 4)
+}
+
+%foo = func(%param:i32):vec4<f32> {
+ $B2: {
+ %5:texture_2d<f32> = load %t
+ %6:sampler = load %s
+ %7:vec4<f32> = textureSample %5, %6, vec2<f32>(0.0f)
+ ret %7
+ }
+}
+%main = @fragment func():void {
+ $B3: {
+ %9:vec4<f32> = call %foo, 42i
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+tint_module_vars_struct = struct @align(1) {
+ t:texture_2d<f32> @offset(0)
+ s:sampler @offset(0)
+}
+
+%foo = func(%param:i32, %tint_module_vars:tint_module_vars_struct):vec4<f32> {
+ $B1: {
+ %4:texture_2d<f32> = access %tint_module_vars, 0u
+ %5:sampler = access %tint_module_vars, 1u
+ %6:vec4<f32> = textureSample %4, %5, vec2<f32>(0.0f)
+ ret %6
+ }
+}
+%main = @fragment func(%t:texture_2d<f32> [@binding_point(1, 2)], %s:sampler [@binding_point(3, 4)]):void {
+ $B2: {
+ %10:tint_module_vars_struct = construct %t, %s
+ %tint_module_vars_1:tint_module_vars_struct = let %10 # %tint_module_vars_1: 'tint_module_vars'
+ %12:vec4<f32> = call %foo, 42i, %tint_module_vars_1
+ ret
+ }
+}
+)";
+
+ Run(ModuleScopeVars);
+
+ EXPECT_EQ(expect, str());
+}
+
TEST_F(MslWriter_ModuleScopeVarsTest, CallFunctionThatUsesVars_OutOfOrder) {
auto* var_a = b.Var("a", ty.ptr<storage, i32, core::Access::kRead>());
auto* var_b = b.Var("b", ty.ptr<storage, i32, core::Access::kReadWrite>());