[msl-writer] Emit storage buffers.
This Cl adds the emission of storage buffers to the MSL backend.
Bug: tint:8
Change-Id: I6923926b36e73f2e351443cf1d2bf6d70873bc9a
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/25264
Reviewed-by: David Neto <dneto@google.com>
diff --git a/simple.spv b/simple.spv
deleted file mode 100644
index d88a341..0000000
--- a/simple.spv
+++ /dev/null
Binary files differ
diff --git a/src/ast/function.cc b/src/ast/function.cc
index 44d5603..380d9a6 100644
--- a/src/ast/function.cc
+++ b/src/ast/function.cc
@@ -99,6 +99,34 @@
return ret;
}
+const std::vector<std::pair<Variable*, Function::BindingInfo>>
+Function::referenced_storagebuffer_variables() const {
+ std::vector<std::pair<Variable*, Function::BindingInfo>> ret;
+
+ for (auto* var : referenced_module_variables()) {
+ if (!var->IsDecorated() ||
+ var->storage_class() != ast::StorageClass::kStorageBuffer) {
+ continue;
+ }
+
+ BindingDecoration* binding = nullptr;
+ SetDecoration* set = nullptr;
+ for (const auto& deco : var->AsDecorated()->decorations()) {
+ if (deco->IsBinding()) {
+ binding = deco->AsBinding();
+ } else if (deco->IsSet()) {
+ set = deco->AsSet();
+ }
+ }
+ if (binding == nullptr || set == nullptr) {
+ continue;
+ }
+
+ ret.push_back({var, BindingInfo{binding, set}});
+ }
+ return ret;
+}
+
const std::vector<std::pair<Variable*, BuiltinDecoration*>>
Function::referenced_builtin_variables() const {
std::vector<std::pair<Variable*, BuiltinDecoration*>> ret;
diff --git a/src/ast/function.h b/src/ast/function.h
index 4a8c33f..b130ddc 100644
--- a/src/ast/function.h
+++ b/src/ast/function.h
@@ -101,6 +101,11 @@
/// @returns the referenced uniforms
const std::vector<std::pair<Variable*, Function::BindingInfo>>
referenced_uniform_variables() const;
+ /// Retrieves any referenced storagebuffer variables. Note, the storagebuffer
+ /// must be decorated with both binding and set decorations.
+ /// @returns the referenced storagebuffers
+ const std::vector<std::pair<Variable*, Function::BindingInfo>>
+ referenced_storagebuffer_variables() const;
/// Adds an ancestor entry point
/// @param ep the entry point ancestor
diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc
index e9f85d5..b2c5af0 100644
--- a/src/writer/msl/generator_impl.cc
+++ b/src/writer/msl/generator_impl.cc
@@ -464,6 +464,15 @@
out_ << var->name();
}
+ for (const auto& data : func->referenced_storagebuffer_variables()) {
+ auto* var = data.first;
+ if (!first) {
+ out_ << ", ";
+ }
+ first = false;
+ out_ << var->name();
+ }
+
const auto& params = expr->params();
for (const auto& param : params) {
if (!first) {
@@ -1061,7 +1070,20 @@
out_ << "& " << var->name();
}
- // TODO(dsinclair): Binding/Set inputs
+ for (const auto& data : func->referenced_storagebuffer_variables()) {
+ auto* var = data.first;
+ if (!first) {
+ out_ << ", ";
+ }
+ first = false;
+
+ out_ << "device ";
+ // TODO(dsinclair): Can arrays be in storage buffers? If so, fix this ...
+ if (!EmitType(var->type(), "")) {
+ return false;
+ }
+ out_ << "& " << var->name();
+ }
for (const auto& v : func->params()) {
if (!first) {
@@ -1209,7 +1231,27 @@
out_ << "& " << var->name() << " [[buffer(" << binding->value() << ")]]";
}
- // TODO(dsinclair): Binding/Set inputs
+ for (auto data : func->referenced_storagebuffer_variables()) {
+ if (!first) {
+ out_ << ", ";
+ }
+ first = false;
+
+ auto* var = data.first;
+ // TODO(dsinclair): We're using the binding to make up the buffer number but
+ // we should instead be using a provided mapping that uses both buffer and
+ // set. https://bugs.chromium.org/p/tint/issues/detail?id=104
+ auto* binding = data.second.binding;
+ // auto* set = data.second.set;
+
+ out_ << "device ";
+ // TODO(dsinclair): Can you have a storagebuffer have an array? If so, this
+ // needs to be updated to handle arrays property.
+ if (!EmitType(var->type(), "")) {
+ return false;
+ }
+ out_ << "& " << var->name() << " [[buffer(" << binding->value() << ")]]";
+ }
out_ << ") {" << std::endl;
diff --git a/src/writer/msl/generator_impl_function_test.cc b/src/writer/msl/generator_impl_function_test.cc
index 5a3e06e..297f6cd 100644
--- a/src/writer/msl/generator_impl_function_test.cc
+++ b/src/writer/msl/generator_impl_function_test.cc
@@ -344,6 +344,62 @@
)");
}
+TEST_F(MslGeneratorImplTest, Emit_Function_EntryPoint_With_StorageBuffer) {
+ ast::type::VoidType void_type;
+ ast::type::F32Type f32;
+ ast::type::VectorType vec4(&f32, 4);
+
+ auto coord_var =
+ std::make_unique<ast::DecoratedVariable>(std::make_unique<ast::Variable>(
+ "coord", ast::StorageClass::kStorageBuffer, &vec4));
+
+ ast::VariableDecorationList decos;
+ decos.push_back(std::make_unique<ast::BindingDecoration>(0));
+ decos.push_back(std::make_unique<ast::SetDecoration>(1));
+ coord_var->set_decorations(std::move(decos));
+
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ td.RegisterVariableForTesting(coord_var.get());
+
+ mod.AddGlobalVariable(std::move(coord_var));
+
+ ast::VariableList params;
+ auto func = std::make_unique<ast::Function>("frag_main", std::move(params),
+ &void_type);
+
+ auto var =
+ std::make_unique<ast::Variable>("v", ast::StorageClass::kFunction, &f32);
+ var->set_constructor(std::make_unique<ast::MemberAccessorExpression>(
+ std::make_unique<ast::IdentifierExpression>("coord"),
+ std::make_unique<ast::IdentifierExpression>("x")));
+
+ ast::StatementList body;
+ body.push_back(std::make_unique<ast::VariableDeclStatement>(std::move(var)));
+ body.push_back(std::make_unique<ast::ReturnStatement>());
+ func->set_body(std::move(body));
+
+ mod.AddFunction(std::move(func));
+
+ auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kFragment, "",
+ "frag_main");
+ mod.AddEntryPoint(std::move(ep));
+
+ ASSERT_TRUE(td.Determine()) << td.error();
+
+ GeneratorImpl g(&mod);
+ ASSERT_TRUE(g.Generate()) << g.error();
+ EXPECT_EQ(g.result(), R"(#include <metal_stdlib>
+
+fragment void frag_main(device float4& coord [[buffer(0)]]) {
+ float v = coord.x;
+ return;
+}
+
+)");
+}
+
TEST_F(MslGeneratorImplTest,
Emit_Function_Called_By_EntryPoints_WithLocationGlobals_And_Params) {
ast::type::VoidType void_type;
@@ -616,6 +672,84 @@
)");
}
+TEST_F(MslGeneratorImplTest,
+ Emit_Function_Called_By_EntryPoint_With_StorageBuffer) {
+ ast::type::VoidType void_type;
+ ast::type::F32Type f32;
+ ast::type::VectorType vec4(&f32, 4);
+
+ auto coord_var =
+ std::make_unique<ast::DecoratedVariable>(std::make_unique<ast::Variable>(
+ "coord", ast::StorageClass::kStorageBuffer, &vec4));
+
+ ast::VariableDecorationList decos;
+ decos.push_back(std::make_unique<ast::BindingDecoration>(0));
+ decos.push_back(std::make_unique<ast::SetDecoration>(1));
+ coord_var->set_decorations(std::move(decos));
+
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ td.RegisterVariableForTesting(coord_var.get());
+
+ mod.AddGlobalVariable(std::move(coord_var));
+
+ ast::VariableList params;
+ params.push_back(std::make_unique<ast::Variable>(
+ "param", ast::StorageClass::kFunction, &f32));
+ auto sub_func =
+ std::make_unique<ast::Function>("sub_func", std::move(params), &f32);
+
+ ast::StatementList body;
+ body.push_back(std::make_unique<ast::ReturnStatement>(
+ std::make_unique<ast::MemberAccessorExpression>(
+ std::make_unique<ast::IdentifierExpression>("coord"),
+ std::make_unique<ast::IdentifierExpression>("x"))));
+ sub_func->set_body(std::move(body));
+
+ mod.AddFunction(std::move(sub_func));
+
+ auto func = std::make_unique<ast::Function>("frag_main", std::move(params),
+ &void_type);
+
+ ast::ExpressionList expr;
+ expr.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
+
+ auto var =
+ std::make_unique<ast::Variable>("v", ast::StorageClass::kFunction, &f32);
+ var->set_constructor(std::make_unique<ast::CallExpression>(
+ std::make_unique<ast::IdentifierExpression>("sub_func"),
+ std::move(expr)));
+
+ body.push_back(std::make_unique<ast::VariableDeclStatement>(std::move(var)));
+ body.push_back(std::make_unique<ast::ReturnStatement>());
+ func->set_body(std::move(body));
+
+ mod.AddFunction(std::move(func));
+
+ auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kFragment, "",
+ "frag_main");
+ mod.AddEntryPoint(std::move(ep));
+
+ ASSERT_TRUE(td.Determine()) << td.error();
+
+ GeneratorImpl g(&mod);
+ ASSERT_TRUE(g.Generate()) << g.error();
+ EXPECT_EQ(g.result(), R"(#include <metal_stdlib>
+
+float sub_func(device float4& coord, float param) {
+ return coord.x;
+}
+
+fragment void frag_main(device float4& coord [[buffer(0)]]) {
+ float v = sub_func(coord, 1.00000000f);
+ return;
+}
+
+)");
+}
+
TEST_F(MslGeneratorImplTest, Emit_Function_Called_Two_EntryPoints_WithGlobals) {
ast::type::VoidType void_type;
ast::type::F32Type f32;