[hlsl-writer] StorageBuffer support.
This Cl adds support for storage buffers to the HLSL backend.
Bug: tint:7
Change-Id: I7adb655de8ccfcb6771fa661ff205c543b4efe66
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/27001
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/src/ast/struct.cc b/src/ast/struct.cc
index f0b2a9c..01344b6 100644
--- a/src/ast/struct.cc
+++ b/src/ast/struct.cc
@@ -31,6 +31,15 @@
Struct::~Struct() = default;
+StructMember* Struct::get_member(const std::string& name) const {
+ for (auto& mem : members_) {
+ if (mem->name() == name) {
+ return mem.get();
+ }
+ }
+ return nullptr;
+}
+
bool Struct::IsValid() const {
for (const auto& mem : members_) {
if (mem == nullptr || !mem->IsValid()) {
diff --git a/src/ast/struct.h b/src/ast/struct.h
index 9705d0f..e12cfd5 100644
--- a/src/ast/struct.h
+++ b/src/ast/struct.h
@@ -59,6 +59,11 @@
/// @returns the members
const StructMemberList& members() const { return members_; }
+ /// Returns the struct member with the given name or nullptr if non exists.
+ /// @param name the name of the member
+ /// @returns the struct member or nullptr if not found
+ StructMember* get_member(const std::string& name) const;
+
/// @returns true if the node is valid
bool IsValid() const override;
diff --git a/src/ast/struct_member.cc b/src/ast/struct_member.cc
index f0ce746..447fed9 100644
--- a/src/ast/struct_member.cc
+++ b/src/ast/struct_member.cc
@@ -14,6 +14,8 @@
#include "src/ast/struct_member.h"
+#include "src/ast/struct_member_offset_decoration.h"
+
namespace tint {
namespace ast {
@@ -37,6 +39,24 @@
StructMember::~StructMember() = default;
+bool StructMember::has_offset_decoration() const {
+ for (const auto& deco : decorations_) {
+ if (deco->IsOffset()) {
+ return true;
+ }
+ }
+ return false;
+}
+
+uint32_t StructMember::offset() const {
+ for (const auto& deco : decorations_) {
+ if (deco->IsOffset()) {
+ return deco->AsOffset()->offset();
+ }
+ }
+ return 0;
+}
+
bool StructMember::IsValid() const {
if (name_.empty() || type_ == nullptr) {
return false;
diff --git a/src/ast/struct_member.h b/src/ast/struct_member.h
index fb21ea8..91a0621 100644
--- a/src/ast/struct_member.h
+++ b/src/ast/struct_member.h
@@ -72,6 +72,11 @@
/// @returns the decorations
const StructMemberDecorationList& decorations() const { return decorations_; }
+ /// @returns true if the struct member has an offset decoration
+ bool has_offset_decoration() const;
+ /// @returns the offset decoration value.
+ uint32_t offset() const;
+
/// @returns true if the node is valid
bool IsValid() const override;
diff --git a/src/type_determiner.cc b/src/type_determiner.cc
index 6cf1791..6af10e0 100644
--- a/src/type_determiner.cc
+++ b/src/type_determiner.cc
@@ -424,7 +424,9 @@
ret = ctx_.type_mgr().Get(
std::make_unique<ast::type::VectorType>(m->type(), m->rows()));
} else {
- set_error(expr->source(), "invalid parent type in array accessor");
+ set_error(expr->source(), "invalid parent type (" +
+ parent_type->type_name() +
+ ") in array accessor");
return false;
}
diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc
index 0ae0e94..cb9b824 100644
--- a/src/writer/hlsl/generator_impl.cc
+++ b/src/writer/hlsl/generator_impl.cc
@@ -65,6 +65,37 @@
return stmts->last()->IsBreak() || stmts->last()->IsFallthrough();
}
+std::string get_buffer_name(ast::Expression* expr) {
+ for (;;) {
+ if (expr->IsIdentifier()) {
+ return expr->AsIdentifier()->name();
+ } else if (expr->IsMemberAccessor()) {
+ expr = expr->AsMemberAccessor()->structure();
+ } else if (expr->IsArrayAccessor()) {
+ expr = expr->AsArrayAccessor()->array();
+ } else {
+ break;
+ }
+ }
+ return "";
+}
+
+uint32_t convert_swizzle_to_index(const std::string& swizzle) {
+ if (swizzle == "r" || swizzle == "x") {
+ return 0;
+ }
+ if (swizzle == "g" || swizzle == "y") {
+ return 1;
+ }
+ if (swizzle == "b" || swizzle == "z") {
+ return 2;
+ }
+ if (swizzle == "a" || swizzle == "w") {
+ return 3;
+ }
+ return 0;
+}
+
} // namespace
GeneratorImpl::GeneratorImpl(ast::Module* module) : module_(module) {}
@@ -73,7 +104,7 @@
bool GeneratorImpl::Generate() {
for (const auto& global : module_->global_variables()) {
- global_variables_.set(global->name(), global.get());
+ register_global(global.get());
}
for (auto* const alias : module_->alias_types()) {
@@ -114,6 +145,10 @@
return true;
}
+void GeneratorImpl::register_global(ast::Variable* global) {
+ global_variables_.set(global->name(), global);
+}
+
std::string GeneratorImpl::generate_name(const std::string& prefix) {
std::string name = prefix;
uint32_t i = 0;
@@ -166,6 +201,11 @@
}
bool GeneratorImpl::EmitArrayAccessor(ast::ArrayAccessorExpression* expr) {
+ // Handle writing into a storage buffer array
+ if (is_storage_buffer_access(expr)) {
+ return EmitStorageBufferAccessor(expr, nullptr);
+ }
+
if (!EmitExpression(expr->array())) {
return false;
}
@@ -201,6 +241,28 @@
bool GeneratorImpl::EmitAssign(ast::AssignmentStatement* stmt) {
make_indent();
+ // If the LHS is an accessor into a storage buffer then we have to
+ // emit a Store operation instead of an ='s.
+ if (stmt->lhs()->IsMemberAccessor()) {
+ auto* mem = stmt->lhs()->AsMemberAccessor();
+ if (is_storage_buffer_access(mem)) {
+ if (!EmitStorageBufferAccessor(mem, stmt->rhs())) {
+ return false;
+ }
+ out_ << ";" << std::endl;
+ return true;
+ }
+ } else if (stmt->lhs()->IsArrayAccessor()) {
+ auto* ary = stmt->lhs()->AsArrayAccessor();
+ if (is_storage_buffer_access(ary)) {
+ if (!EmitStorageBufferAccessor(ary, stmt->rhs())) {
+ return false;
+ }
+ out_ << ";" << std::endl;
+ return true;
+ }
+ }
+
if (!EmitExpression(stmt->lhs())) {
return false;
}
@@ -1108,6 +1170,19 @@
out_ << std::endl;
}
+ bool emitted_storagebuffer = false;
+ for (auto data : func->referenced_storagebuffer_variables()) {
+ auto* var = data.first;
+ auto* binding = data.second.binding;
+
+ out_ << "RWByteAddressBuffer " << var->name() << " : register(u"
+ << binding->value() << ");" << std::endl;
+ emitted_storagebuffer = true;
+ }
+ if (emitted_storagebuffer) {
+ out_ << std::endl;
+ }
+
auto ep_name = ep->name();
if (ep_name.empty()) {
ep_name = ep->function_name();
@@ -1396,7 +1471,188 @@
return true;
}
+// TODO(dsinclair): This currently only handles loading of 4, 8, 12 or 16 byte
+// members. If we need to support larger we'll need to do the loading into
+// chunks.
+//
+// TODO(dsinclair): Need to support loading through a pointer. The pointer is
+// just a memory address in the storage buffer, so need to do the correct
+// calculation.
+bool GeneratorImpl::EmitStorageBufferAccessor(ast::Expression* expr,
+ ast::Expression* rhs) {
+ auto* result_type = expr->result_type()->UnwrapAliasPtrAlias();
+ std::string access_method = rhs != nullptr ? "Store" : "Load";
+ if (result_type->IsVector()) {
+ access_method += std::to_string(result_type->AsVector()->size());
+ }
+
+ // If we aren't storing then we need to put in the outer cast.
+ if (rhs == nullptr) {
+ if (result_type->is_float_scalar_or_vector()) {
+ out_ << "asfloat(";
+ } else if (result_type->is_signed_scalar_or_vector()) {
+ out_ << "asint(";
+ } else if (result_type->is_unsigned_scalar_or_vector()) {
+ out_ << "asuint(";
+ }
+ }
+
+ auto buffer_name = get_buffer_name(expr);
+ if (buffer_name.empty()) {
+ error_ = "error emitting storage buffer access";
+ return false;
+ }
+ out_ << buffer_name << "." << access_method << "(";
+
+ auto* ptr = expr;
+ bool first = true;
+ for (;;) {
+ if (ptr->IsIdentifier()) {
+ break;
+ }
+
+ if (!first) {
+ out_ << " + ";
+ }
+ first = false;
+ if (ptr->IsMemberAccessor()) {
+ auto* mem = ptr->AsMemberAccessor();
+ auto* res_type = mem->structure()->result_type()->UnwrapAliasPtrAlias();
+
+ if (res_type->IsStruct()) {
+ auto* str_type = res_type->AsStruct()->impl();
+ auto* str_member = str_type->get_member(mem->member()->name());
+
+ if (!str_member->has_offset_decoration()) {
+ error_ = "missing offset decoration for struct member";
+ return false;
+ }
+ out_ << str_member->offset();
+ } else if (res_type->IsVector()) {
+ // This must be a single element swizzle if we've got a vector at this
+ // point.
+ if (mem->member()->name().size() != 1) {
+ error_ =
+ "Encountered multi-element swizzle when should have only one "
+ "level";
+ return false;
+ }
+
+ // TODO(dsinclair): All our types are currently 4 bytes (f32, i32, u32)
+ // so this is assuming 4. This will need to be fixed when we get f16 or
+ // f64 types.
+ out_ << "(4 * " << convert_swizzle_to_index(mem->member()->name())
+ << ")";
+ } else {
+ error_ =
+ "Invalid result type for member accessor: " + res_type->type_name();
+ return false;
+ }
+
+ ptr = mem->structure();
+ } else if (ptr->IsArrayAccessor()) {
+ auto* ary = ptr->AsArrayAccessor();
+ auto* ary_type = ary->array()->result_type()->UnwrapAliasPtrAlias();
+
+ out_ << "(";
+ // TODO(dsinclair): Handle matrix case and struct case.
+ if (ary_type->IsArray()) {
+ out_ << ary_type->AsArray()->array_stride();
+ } else if (ary_type->IsVector()) {
+ // TODO(dsinclair): This is a hack. Our vectors can only be f32, i32
+ // or u32 which are all 4 bytes. When we get f16 or other types we'll
+ // have to ask the type for the byte size.
+ out_ << "4";
+ } else {
+ error_ = "Invalid array type in storage buffer access";
+ return false;
+ }
+ out_ << " * ";
+ if (!EmitExpression(ary->idx_expr())) {
+ return false;
+ }
+ out_ << ")";
+
+ ptr = ary->array();
+ } else {
+ error_ = "error emitting storage buffer access";
+ return false;
+ }
+ }
+
+ if (rhs != nullptr) {
+ out_ << ", asuint(";
+ if (!EmitExpression(rhs)) {
+ return false;
+ }
+ out_ << ")";
+ }
+
+ out_ << ")";
+
+ // Close the outer cast.
+ if (rhs == nullptr) {
+ out_ << ")";
+ }
+
+ return true;
+}
+
+bool GeneratorImpl::is_storage_buffer_access(
+ ast::ArrayAccessorExpression* expr) {
+ // We only care about array so we can get to the next part of the expression.
+ // If it isn't an array or a member accessor we can stop looking as it won't
+ // be a storage buffer.
+ auto* ary = expr->array();
+ if (ary->IsMemberAccessor()) {
+ return is_storage_buffer_access(ary->AsMemberAccessor());
+ } else if (ary->IsArrayAccessor()) {
+ return is_storage_buffer_access(ary->AsArrayAccessor());
+ }
+ return false;
+}
+
+bool GeneratorImpl::is_storage_buffer_access(
+ ast::MemberAccessorExpression* expr) {
+ auto* structure = expr->structure();
+ auto* data_type = structure->result_type()->UnwrapAliasPtrAlias();
+ // If the data is a multi-element swizzle then we will not load the swizzle
+ // portion through the Load command.
+ if (data_type->IsVector() && expr->member()->name().size() > 1) {
+ return false;
+ }
+
+ // Check if this is a storage buffer variable
+ if (structure->IsIdentifier()) {
+ auto* ident = expr->structure()->AsIdentifier();
+ if (ident->has_path()) {
+ return false;
+ }
+
+ ast::Variable* var = nullptr;
+ if (!global_variables_.get(ident->name(), &var)) {
+ return false;
+ }
+ return var->storage_class() == ast::StorageClass::kStorageBuffer;
+ } else if (structure->IsMemberAccessor()) {
+ return is_storage_buffer_access(structure->AsMemberAccessor());
+ } else if (structure->IsArrayAccessor()) {
+ return is_storage_buffer_access(structure->AsArrayAccessor());
+ }
+
+ // Technically I don't think this is possible, but if we don't have a struct
+ // or array accessor then we can't have a storage buffer I believe.
+ return false;
+}
+
bool GeneratorImpl::EmitMemberAccessor(ast::MemberAccessorExpression* expr) {
+ // Look for storage buffer accesses as we have to convert them into Load
+ // expressions. Stores will be identified in the assignment emission and a
+ // member accessor store of a storage buffer will not get here.
+ if (is_storage_buffer_access(expr)) {
+ return EmitStorageBufferAccessor(expr, nullptr);
+ }
+
if (!EmitExpression(expr->structure())) {
return false;
}
diff --git a/src/writer/hlsl/generator_impl.h b/src/writer/hlsl/generator_impl.h
index 4b5159a..c9eae28 100644
--- a/src/writer/hlsl/generator_impl.h
+++ b/src/writer/hlsl/generator_impl.h
@@ -159,6 +159,11 @@
/// @param expr the member accessor expression
/// @returns true if the member accessor was emitted
bool EmitMemberAccessor(ast::MemberAccessorExpression* expr);
+ /// Handles a storage buffer accessor expression
+ /// @param expr the storage buffer accessor expression
+ /// @param rhs the right side of a store expression. Set to nullptr for a load
+ /// @returns true if the storage buffer accessor was emitted
+ bool EmitStorageBufferAccessor(ast::Expression* expr, ast::Expression* rhs);
/// Handles return statements
/// @param stmt the statement to emit
/// @returns true if the statement was successfully emitted
@@ -193,6 +198,18 @@
/// @returns true if the variable was emitted
bool EmitProgramConstVariable(const ast::Variable* var);
+ /// Returns true if the accessor is accessing a storage buffer.
+ /// @param expr the expression to check
+ /// @returns true if the accessor is accessing a storage buffer for which
+ /// we need to execute a Load instruction.
+ bool is_storage_buffer_access(ast::MemberAccessorExpression* expr);
+ /// Returns true if the accessor is accessing a storage buffer.
+ /// @param expr the expression to check
+ /// @returns true if the accessor is accessing a storage buffer
+ bool is_storage_buffer_access(ast::ArrayAccessorExpression* expr);
+ /// Registers the given global with the generator
+ /// @param global the global to register
+ void register_global(ast::Variable* global);
/// Checks if the global variable is in an input or output struct
/// @param var the variable to check
/// @returns true if the global is in an input or output struct
diff --git a/src/writer/hlsl/generator_impl_function_test.cc b/src/writer/hlsl/generator_impl_function_test.cc
index 47a79d9..9170fc7 100644
--- a/src/writer/hlsl/generator_impl_function_test.cc
+++ b/src/writer/hlsl/generator_impl_function_test.cc
@@ -30,6 +30,7 @@
#include "src/ast/set_decoration.h"
#include "src/ast/sint_literal.h"
#include "src/ast/struct.h"
+#include "src/ast/struct_member_offset_decoration.h"
#include "src/ast/type/alias_type.h"
#include "src/ast/type/array_type.h"
#include "src/ast/type/f32_type.h"
@@ -417,14 +418,104 @@
}
TEST_F(HlslGeneratorImplTest,
- DISABLED_Emit_Function_EntryPoint_With_StorageBuffer) {
+ Emit_Function_EntryPoint_With_StorageBuffer_Read) {
ast::type::VoidType void_type;
ast::type::F32Type f32;
- ast::type::VectorType vec4(&f32, 4);
+ ast::type::I32Type i32;
+
+ ast::StructMemberList members;
+ ast::StructMemberDecorationList a_deco;
+ a_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(0));
+ members.push_back(
+ std::make_unique<ast::StructMember>("a", &i32, std::move(a_deco)));
+
+ ast::StructMemberDecorationList b_deco;
+ b_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(4));
+ members.push_back(
+ std::make_unique<ast::StructMember>("b", &f32, std::move(b_deco)));
+
+ auto str = std::make_unique<ast::Struct>();
+ str->set_members(std::move(members));
+
+ ast::type::StructType s(std::move(str));
+ s.set_name("Data");
auto coord_var =
std::make_unique<ast::DecoratedVariable>(std::make_unique<ast::Variable>(
- "coord", ast::StorageClass::kStorageBuffer, &vec4));
+ "coord", ast::StorageClass::kStorageBuffer, &s));
+
+ ast::VariableDecorationList decos;
+ decos.push_back(std::make_unique<ast::BindingDecoration>(0));
+ decos.push_back(std::make_unique<ast::SetDecoration>(1));
+ coord_var->set_decorations(std::move(decos));
+
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ td.RegisterVariableForTesting(coord_var.get());
+ mod.AddGlobalVariable(std::move(coord_var));
+
+ ast::VariableList params;
+ auto func = std::make_unique<ast::Function>("frag_main", std::move(params),
+ &void_type);
+
+ auto var =
+ std::make_unique<ast::Variable>("v", ast::StorageClass::kFunction, &f32);
+ var->set_constructor(std::make_unique<ast::MemberAccessorExpression>(
+ std::make_unique<ast::IdentifierExpression>("coord"),
+ std::make_unique<ast::IdentifierExpression>("b")));
+
+ auto body = std::make_unique<ast::BlockStatement>();
+ body->append(std::make_unique<ast::VariableDeclStatement>(std::move(var)));
+ body->append(std::make_unique<ast::ReturnStatement>());
+ func->set_body(std::move(body));
+
+ mod.AddFunction(std::move(func));
+
+ auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kFragment, "",
+ "frag_main");
+ mod.AddEntryPoint(std::move(ep));
+
+ ASSERT_TRUE(td.Determine()) << td.error();
+
+ GeneratorImpl g(&mod);
+ ASSERT_TRUE(g.Generate()) << g.error();
+ EXPECT_EQ(g.result(), R"(RWByteAddressBuffer coord : register(u0);
+
+void frag_main() {
+ float v = asfloat(coord.Load(4));
+ return;
+}
+
+)");
+}
+
+TEST_F(HlslGeneratorImplTest,
+ Emit_Function_EntryPoint_With_StorageBuffer_Store) {
+ ast::type::VoidType void_type;
+ ast::type::F32Type f32;
+ ast::type::I32Type i32;
+
+ ast::StructMemberList members;
+ ast::StructMemberDecorationList a_deco;
+ a_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(0));
+ members.push_back(
+ std::make_unique<ast::StructMember>("a", &i32, std::move(a_deco)));
+
+ ast::StructMemberDecorationList b_deco;
+ b_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(4));
+ members.push_back(
+ std::make_unique<ast::StructMember>("b", &f32, std::move(b_deco)));
+
+ auto str = std::make_unique<ast::Struct>();
+ str->set_members(std::move(members));
+
+ ast::type::StructType s(std::move(str));
+ s.set_name("Data");
+
+ auto coord_var =
+ std::make_unique<ast::DecoratedVariable>(std::make_unique<ast::Variable>(
+ "coord", ast::StorageClass::kStorageBuffer, &s));
ast::VariableDecorationList decos;
decos.push_back(std::make_unique<ast::BindingDecoration>(0));
@@ -442,14 +533,15 @@
auto func = std::make_unique<ast::Function>("frag_main", std::move(params),
&void_type);
- auto var =
- std::make_unique<ast::Variable>("v", ast::StorageClass::kFunction, &f32);
- var->set_constructor(std::make_unique<ast::MemberAccessorExpression>(
- std::make_unique<ast::IdentifierExpression>("coord"),
- std::make_unique<ast::IdentifierExpression>("x")));
+ auto assign = std::make_unique<ast::AssignmentStatement>(
+ std::make_unique<ast::MemberAccessorExpression>(
+ std::make_unique<ast::IdentifierExpression>("coord"),
+ std::make_unique<ast::IdentifierExpression>("b")),
+ std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 2.0f)));
auto body = std::make_unique<ast::BlockStatement>();
- body->append(std::make_unique<ast::VariableDeclStatement>(std::move(var)));
+ body->append(std::move(assign));
body->append(std::make_unique<ast::ReturnStatement>());
func->set_body(std::move(body));
@@ -463,7 +555,14 @@
GeneratorImpl g(&mod);
ASSERT_TRUE(g.Generate()) << g.error();
- EXPECT_EQ(g.result(), R"( ... )");
+ EXPECT_EQ(g.result(), R"(RWByteAddressBuffer coord : register(u0);
+
+void frag_main() {
+ coord.Store(4, asuint(2.00000000f));
+ return;
+}
+
+)");
}
TEST_F(HlslGeneratorImplTest,
diff --git a/src/writer/hlsl/generator_impl_member_accessor_test.cc b/src/writer/hlsl/generator_impl_member_accessor_test.cc
index a4df1d8..88e3660 100644
--- a/src/writer/hlsl/generator_impl_member_accessor_test.cc
+++ b/src/writer/hlsl/generator_impl_member_accessor_test.cc
@@ -15,9 +15,27 @@
#include <memory>
#include "gtest/gtest.h"
+#include "src/ast/array_accessor_expression.h"
+#include "src/ast/assignment_statement.h"
+#include "src/ast/binary_expression.h"
+#include "src/ast/decorated_variable.h"
+#include "src/ast/float_literal.h"
#include "src/ast/identifier_expression.h"
#include "src/ast/member_accessor_expression.h"
#include "src/ast/module.h"
+#include "src/ast/scalar_constructor_expression.h"
+#include "src/ast/sint_literal.h"
+#include "src/ast/struct.h"
+#include "src/ast/struct_member.h"
+#include "src/ast/struct_member_offset_decoration.h"
+#include "src/ast/type/array_type.h"
+#include "src/ast/type/f32_type.h"
+#include "src/ast/type/i32_type.h"
+#include "src/ast/type/struct_type.h"
+#include "src/ast/type/vector_type.h"
+#include "src/ast/type_constructor_expression.h"
+#include "src/context.h"
+#include "src/type_determiner.h"
#include "src/writer/hlsl/generator_impl.h"
namespace tint {
@@ -28,17 +46,1061 @@
using HlslGeneratorImplTest = testing::Test;
TEST_F(HlslGeneratorImplTest, EmitExpression_MemberAccessor) {
+ ast::type::F32Type f32;
+
+ ast::StructMemberList members;
+ ast::StructMemberDecorationList deco;
+ deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(0));
+ members.push_back(
+ std::make_unique<ast::StructMember>("mem", &f32, std::move(deco)));
+
+ auto strct = std::make_unique<ast::Struct>();
+ strct->set_members(std::move(members));
+
+ ast::type::StructType s(std::move(strct));
+ s.set_name("Str");
+
+ auto str_var = std::make_unique<ast::DecoratedVariable>(
+ std::make_unique<ast::Variable>("str", ast::StorageClass::kPrivate, &s));
+
auto str = std::make_unique<ast::IdentifierExpression>("str");
auto mem = std::make_unique<ast::IdentifierExpression>("mem");
ast::MemberAccessorExpression expr(std::move(str), std::move(mem));
- ast::Module m;
- GeneratorImpl g(&m);
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ GeneratorImpl g(&mod);
+ td.RegisterVariableForTesting(str_var.get());
+ g.register_global(str_var.get());
+ mod.AddGlobalVariable(std::move(str_var));
+
+ ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ASSERT_TRUE(g.EmitExpression(&expr)) << g.error();
EXPECT_EQ(g.result(), "str.mem");
}
+TEST_F(HlslGeneratorImplTest,
+ EmitExpression_MemberAccessor_StorageBuffer_Load) {
+ // struct Data {
+ // [[offset 0]] a : i32;
+ // [[offset 4]] b : f32;
+ // };
+ // var<storage_buffer> data : Data;
+ // data.b;
+ //
+ // -> asfloat(data.Load(4));
+ ast::type::F32Type f32;
+ ast::type::I32Type i32;
+
+ ast::StructMemberList members;
+ ast::StructMemberDecorationList a_deco;
+ a_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(0));
+ members.push_back(
+ std::make_unique<ast::StructMember>("a", &i32, std::move(a_deco)));
+
+ ast::StructMemberDecorationList b_deco;
+ b_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(4));
+ members.push_back(
+ std::make_unique<ast::StructMember>("b", &f32, std::move(b_deco)));
+
+ auto str = std::make_unique<ast::Struct>();
+ str->set_members(std::move(members));
+
+ ast::type::StructType s(std::move(str));
+ s.set_name("Data");
+
+ auto coord_var =
+ std::make_unique<ast::DecoratedVariable>(std::make_unique<ast::Variable>(
+ "data", ast::StorageClass::kStorageBuffer, &s));
+
+ ast::MemberAccessorExpression expr(
+ std::make_unique<ast::IdentifierExpression>("data"),
+ std::make_unique<ast::IdentifierExpression>("b"));
+
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ GeneratorImpl g(&mod);
+
+ td.RegisterVariableForTesting(coord_var.get());
+ g.register_global(coord_var.get());
+ mod.AddGlobalVariable(std::move(coord_var));
+
+ ASSERT_TRUE(td.Determine()) << td.error();
+ ASSERT_TRUE(td.DetermineResultType(&expr));
+
+ ASSERT_TRUE(g.EmitExpression(&expr)) << g.error();
+ EXPECT_EQ(g.result(), "asfloat(data.Load(4))");
+}
+
+TEST_F(HlslGeneratorImplTest,
+ EmitExpression_MemberAccessor_StorageBuffer_Load_Int) {
+ // struct Data {
+ // [[offset 0]] a : i32;
+ // [[offset 4]] b : f32;
+ // };
+ // var<storage_buffer> data : Data;
+ // data.a;
+ //
+ // -> asint(data.Load(0));
+ ast::type::F32Type f32;
+ ast::type::I32Type i32;
+
+ ast::StructMemberList members;
+ ast::StructMemberDecorationList a_deco;
+ a_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(0));
+ members.push_back(
+ std::make_unique<ast::StructMember>("a", &i32, std::move(a_deco)));
+
+ ast::StructMemberDecorationList b_deco;
+ b_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(4));
+ members.push_back(
+ std::make_unique<ast::StructMember>("b", &f32, std::move(b_deco)));
+
+ auto str = std::make_unique<ast::Struct>();
+ str->set_members(std::move(members));
+
+ ast::type::StructType s(std::move(str));
+ s.set_name("Data");
+
+ auto coord_var =
+ std::make_unique<ast::DecoratedVariable>(std::make_unique<ast::Variable>(
+ "data", ast::StorageClass::kStorageBuffer, &s));
+
+ ast::MemberAccessorExpression expr(
+ std::make_unique<ast::IdentifierExpression>("data"),
+ std::make_unique<ast::IdentifierExpression>("a"));
+
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ GeneratorImpl g(&mod);
+ td.RegisterVariableForTesting(coord_var.get());
+ g.register_global(coord_var.get());
+ mod.AddGlobalVariable(std::move(coord_var));
+
+ ASSERT_TRUE(td.Determine()) << td.error();
+ ASSERT_TRUE(td.DetermineResultType(&expr));
+
+ ASSERT_TRUE(g.EmitExpression(&expr)) << g.error();
+ EXPECT_EQ(g.result(), "asint(data.Load(0))");
+}
+
+TEST_F(HlslGeneratorImplTest,
+ EmitExpression_ArrayAccessor_StorageBuffer_Load_Int_FromArray) {
+ // struct Data {
+ // [[offset 0]] a : [[stride 4]] array<i32, 5>;
+ // };
+ // var<storage_buffer> data : Data;
+ // data.a[2];
+ //
+ // -> asint(data.Load((2 * 4));
+ ast::type::F32Type f32;
+ ast::type::I32Type i32;
+ ast::type::ArrayType ary(&i32, 5);
+ ary.set_array_stride(4);
+
+ ast::StructMemberList members;
+ ast::StructMemberDecorationList a_deco;
+ a_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(0));
+ members.push_back(
+ std::make_unique<ast::StructMember>("a", &ary, std::move(a_deco)));
+
+ auto str = std::make_unique<ast::Struct>();
+ str->set_members(std::move(members));
+
+ ast::type::StructType s(std::move(str));
+ s.set_name("Data");
+
+ auto coord_var =
+ std::make_unique<ast::DecoratedVariable>(std::make_unique<ast::Variable>(
+ "data", ast::StorageClass::kStorageBuffer, &s));
+
+ ast::ArrayAccessorExpression expr(
+ std::make_unique<ast::MemberAccessorExpression>(
+ std::make_unique<ast::IdentifierExpression>("data"),
+ std::make_unique<ast::IdentifierExpression>("a")),
+ std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::SintLiteral>(&i32, 2)));
+
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ GeneratorImpl g(&mod);
+ td.RegisterVariableForTesting(coord_var.get());
+ g.register_global(coord_var.get());
+ mod.AddGlobalVariable(std::move(coord_var));
+
+ ASSERT_TRUE(td.Determine()) << td.error();
+ ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+
+ ASSERT_TRUE(g.EmitExpression(&expr)) << g.error();
+ EXPECT_EQ(g.result(), "asint(data.Load((4 * 2) + 0))");
+}
+
+TEST_F(HlslGeneratorImplTest,
+ EmitExpression_ArrayAccessor_StorageBuffer_Load_Int_FromArray_ExprIdx) {
+ // struct Data {
+ // [[offset 0]] a : [[stride 4]] array<i32, 5>;
+ // };
+ // var<storage_buffer> data : Data;
+ // data.a[(2 + 4) - 3];
+ //
+ // -> asint(data.Load((4 * ((2 + 4) - 3)));
+ ast::type::F32Type f32;
+ ast::type::I32Type i32;
+ ast::type::ArrayType ary(&i32, 5);
+ ary.set_array_stride(4);
+
+ ast::StructMemberList members;
+ ast::StructMemberDecorationList a_deco;
+ a_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(0));
+ members.push_back(
+ std::make_unique<ast::StructMember>("a", &ary, std::move(a_deco)));
+
+ auto str = std::make_unique<ast::Struct>();
+ str->set_members(std::move(members));
+
+ ast::type::StructType s(std::move(str));
+ s.set_name("Data");
+
+ auto coord_var =
+ std::make_unique<ast::DecoratedVariable>(std::make_unique<ast::Variable>(
+ "data", ast::StorageClass::kStorageBuffer, &s));
+
+ ast::ArrayAccessorExpression expr(
+ std::make_unique<ast::MemberAccessorExpression>(
+ std::make_unique<ast::IdentifierExpression>("data"),
+ std::make_unique<ast::IdentifierExpression>("a")),
+ std::make_unique<ast::BinaryExpression>(
+ ast::BinaryOp::kSubtract,
+ std::make_unique<ast::BinaryExpression>(
+ ast::BinaryOp::kAdd,
+ std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::SintLiteral>(&i32, 2)),
+ std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::SintLiteral>(&i32, 4))),
+ std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::SintLiteral>(&i32, 3))));
+
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ GeneratorImpl g(&mod);
+ td.RegisterVariableForTesting(coord_var.get());
+ g.register_global(coord_var.get());
+ mod.AddGlobalVariable(std::move(coord_var));
+
+ ASSERT_TRUE(td.Determine()) << td.error();
+ ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+
+ ASSERT_TRUE(g.EmitExpression(&expr)) << g.error();
+ EXPECT_EQ(g.result(), "asint(data.Load((4 * ((2 + 4) - 3)) + 0))");
+}
+
+TEST_F(HlslGeneratorImplTest,
+ EmitExpression_MemberAccessor_StorageBuffer_Store) {
+ // struct Data {
+ // [[offset 0]] a : i32;
+ // [[offset 4]] b : f32;
+ // };
+ // var<storage_buffer> data : Data;
+ // data.b = 2.3f;
+ //
+ // -> data.Store(0, asuint(2.0f));
+
+ ast::type::F32Type f32;
+ ast::type::I32Type i32;
+
+ ast::StructMemberList members;
+ ast::StructMemberDecorationList a_deco;
+ a_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(0));
+ members.push_back(
+ std::make_unique<ast::StructMember>("a", &i32, std::move(a_deco)));
+
+ ast::StructMemberDecorationList b_deco;
+ b_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(4));
+ members.push_back(
+ std::make_unique<ast::StructMember>("b", &f32, std::move(b_deco)));
+
+ auto str = std::make_unique<ast::Struct>();
+ str->set_members(std::move(members));
+
+ ast::type::StructType s(std::move(str));
+ s.set_name("Data");
+
+ auto coord_var =
+ std::make_unique<ast::DecoratedVariable>(std::make_unique<ast::Variable>(
+ "data", ast::StorageClass::kStorageBuffer, &s));
+
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ GeneratorImpl g(&mod);
+ td.RegisterVariableForTesting(coord_var.get());
+ g.register_global(coord_var.get());
+ mod.AddGlobalVariable(std::move(coord_var));
+
+ ASSERT_TRUE(td.Determine()) << td.error();
+
+ auto lhs = std::make_unique<ast::MemberAccessorExpression>(
+ std::make_unique<ast::IdentifierExpression>("data"),
+ std::make_unique<ast::IdentifierExpression>("b"));
+ auto rhs = std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 2.0f));
+ ast::AssignmentStatement assign(std::move(lhs), std::move(rhs));
+
+ ASSERT_TRUE(td.DetermineResultType(&assign));
+ ASSERT_TRUE(g.EmitStatement(&assign)) << g.error();
+ EXPECT_EQ(g.result(), R"(data.Store(4, asuint(2.00000000f));
+)");
+}
+
+TEST_F(HlslGeneratorImplTest,
+ EmitExpression_MemberAccessor_StorageBuffer_Store_ToArray) {
+ // struct Data {
+ // [[offset 0]] a : [[stride 4]] array<i32, 5>;
+ // };
+ // var<storage_buffer> data : Data;
+ // data.a[2] = 2;
+ //
+ // -> data.Store((2 * 4), asuint(2.3f));
+
+ ast::type::F32Type f32;
+ ast::type::I32Type i32;
+ ast::type::ArrayType ary(&i32, 5);
+ ary.set_array_stride(4);
+
+ ast::StructMemberList members;
+ ast::StructMemberDecorationList a_deco;
+ a_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(0));
+ members.push_back(
+ std::make_unique<ast::StructMember>("a", &ary, std::move(a_deco)));
+
+ auto str = std::make_unique<ast::Struct>();
+ str->set_members(std::move(members));
+
+ ast::type::StructType s(std::move(str));
+ s.set_name("Data");
+
+ auto coord_var =
+ std::make_unique<ast::DecoratedVariable>(std::make_unique<ast::Variable>(
+ "data", ast::StorageClass::kStorageBuffer, &s));
+
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ GeneratorImpl g(&mod);
+ td.RegisterVariableForTesting(coord_var.get());
+ g.register_global(coord_var.get());
+ mod.AddGlobalVariable(std::move(coord_var));
+
+ ASSERT_TRUE(td.Determine()) << td.error();
+
+ auto lhs = std::make_unique<ast::ArrayAccessorExpression>(
+ std::make_unique<ast::MemberAccessorExpression>(
+ std::make_unique<ast::IdentifierExpression>("data"),
+ std::make_unique<ast::IdentifierExpression>("a")),
+ std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::SintLiteral>(&i32, 2)));
+ auto rhs = std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::SintLiteral>(&i32, 2));
+ ast::AssignmentStatement assign(std::move(lhs), std::move(rhs));
+
+ ASSERT_TRUE(td.DetermineResultType(&assign)) << td.error();
+ ASSERT_TRUE(g.EmitStatement(&assign)) << g.error();
+ EXPECT_EQ(g.result(), R"(data.Store((4 * 2) + 0, asuint(2));
+)");
+}
+
+TEST_F(HlslGeneratorImplTest,
+ EmitExpression_MemberAccessor_StorageBuffer_Store_Int) {
+ // struct Data {
+ // [[offset 0]] a : i32;
+ // [[offset 4]] b : f32;
+ // };
+ // var<storage_buffer> data : Data;
+ // data.a = 2;
+ //
+ // -> data.Store(0, asuint(2));
+
+ ast::type::F32Type f32;
+ ast::type::I32Type i32;
+
+ ast::StructMemberList members;
+ ast::StructMemberDecorationList a_deco;
+ a_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(0));
+ members.push_back(
+ std::make_unique<ast::StructMember>("a", &i32, std::move(a_deco)));
+
+ ast::StructMemberDecorationList b_deco;
+ b_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(4));
+ members.push_back(
+ std::make_unique<ast::StructMember>("b", &f32, std::move(b_deco)));
+
+ auto str = std::make_unique<ast::Struct>();
+ str->set_members(std::move(members));
+
+ ast::type::StructType s(std::move(str));
+ s.set_name("Data");
+
+ auto coord_var =
+ std::make_unique<ast::DecoratedVariable>(std::make_unique<ast::Variable>(
+ "data", ast::StorageClass::kStorageBuffer, &s));
+
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ GeneratorImpl g(&mod);
+ td.RegisterVariableForTesting(coord_var.get());
+ g.register_global(coord_var.get());
+ mod.AddGlobalVariable(std::move(coord_var));
+
+ ASSERT_TRUE(td.Determine()) << td.error();
+
+ auto lhs = std::make_unique<ast::MemberAccessorExpression>(
+ std::make_unique<ast::IdentifierExpression>("data"),
+ std::make_unique<ast::IdentifierExpression>("a"));
+ auto rhs = std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::SintLiteral>(&i32, 2));
+ ast::AssignmentStatement assign(std::move(lhs), std::move(rhs));
+
+ ASSERT_TRUE(td.DetermineResultType(&assign));
+ ASSERT_TRUE(g.EmitStatement(&assign)) << g.error();
+ EXPECT_EQ(g.result(), R"(data.Store(0, asuint(2));
+)");
+}
+
+TEST_F(HlslGeneratorImplTest,
+ EmitExpression_MemberAccessor_StorageBuffer_Load_Vec3) {
+ // struct Data {
+ // [[offset 0]] a : vec3<i32>;
+ // [[offset 16]] b : vec3<f32>;
+ // };
+ // var<storage_buffer> data : Data;
+ // data.b;
+ //
+ // -> asfloat(data.Load(16));
+
+ ast::type::F32Type f32;
+ ast::type::I32Type i32;
+ ast::type::VectorType ivec3(&i32, 3);
+ ast::type::VectorType fvec3(&f32, 3);
+
+ ast::StructMemberList members;
+ ast::StructMemberDecorationList a_deco;
+ a_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(0));
+ members.push_back(
+ std::make_unique<ast::StructMember>("a", &ivec3, std::move(a_deco)));
+
+ ast::StructMemberDecorationList b_deco;
+ b_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(16));
+ members.push_back(
+ std::make_unique<ast::StructMember>("b", &fvec3, std::move(b_deco)));
+
+ auto str = std::make_unique<ast::Struct>();
+ str->set_members(std::move(members));
+
+ ast::type::StructType s(std::move(str));
+ s.set_name("Data");
+
+ auto coord_var =
+ std::make_unique<ast::DecoratedVariable>(std::make_unique<ast::Variable>(
+ "data", ast::StorageClass::kStorageBuffer, &s));
+
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ GeneratorImpl g(&mod);
+ td.RegisterVariableForTesting(coord_var.get());
+ g.register_global(coord_var.get());
+ mod.AddGlobalVariable(std::move(coord_var));
+
+ ASSERT_TRUE(td.Determine()) << td.error();
+
+ ast::MemberAccessorExpression expr(
+ std::make_unique<ast::IdentifierExpression>("data"),
+ std::make_unique<ast::IdentifierExpression>("b"));
+
+ ASSERT_TRUE(td.DetermineResultType(&expr));
+ ASSERT_TRUE(g.EmitExpression(&expr)) << g.error();
+ EXPECT_EQ(g.result(), "asfloat(data.Load3(16))");
+}
+
+TEST_F(HlslGeneratorImplTest,
+ EmitExpression_MemberAccessor_StorageBuffer_Store_Vec3) {
+ // struct Data {
+ // [[offset 0]] a : vec3<i32>;
+ // [[offset 16]] b : vec3<f32>;
+ // };
+ // var<storage_buffer> data : Data;
+ // data.b = vec3<f32>(2.3f, 1.2f, 0.2f);
+ //
+ // -> data.Store(16, asuint(vector<float, 3>(2.3f, 1.2f, 0.2f)));
+
+ ast::type::F32Type f32;
+ ast::type::I32Type i32;
+ ast::type::VectorType ivec3(&i32, 3);
+ ast::type::VectorType fvec3(&f32, 3);
+
+ ast::StructMemberList members;
+ ast::StructMemberDecorationList a_deco;
+ a_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(0));
+ members.push_back(
+ std::make_unique<ast::StructMember>("a", &ivec3, std::move(a_deco)));
+
+ ast::StructMemberDecorationList b_deco;
+ b_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(16));
+ members.push_back(
+ std::make_unique<ast::StructMember>("b", &fvec3, std::move(b_deco)));
+
+ auto str = std::make_unique<ast::Struct>();
+ str->set_members(std::move(members));
+
+ ast::type::StructType s(std::move(str));
+ s.set_name("Data");
+
+ auto coord_var =
+ std::make_unique<ast::DecoratedVariable>(std::make_unique<ast::Variable>(
+ "data", ast::StorageClass::kStorageBuffer, &s));
+
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ GeneratorImpl g(&mod);
+ td.RegisterVariableForTesting(coord_var.get());
+ g.register_global(coord_var.get());
+ mod.AddGlobalVariable(std::move(coord_var));
+
+ ASSERT_TRUE(td.Determine()) << td.error();
+
+ auto lit1 = std::make_unique<ast::FloatLiteral>(&f32, 1.f);
+ auto lit2 = std::make_unique<ast::FloatLiteral>(&f32, 2.f);
+ auto lit3 = std::make_unique<ast::FloatLiteral>(&f32, 3.f);
+ ast::ExpressionList values;
+ values.push_back(
+ std::make_unique<ast::ScalarConstructorExpression>(std::move(lit1)));
+ values.push_back(
+ std::make_unique<ast::ScalarConstructorExpression>(std::move(lit2)));
+ values.push_back(
+ std::make_unique<ast::ScalarConstructorExpression>(std::move(lit3)));
+
+ auto lhs = std::make_unique<ast::MemberAccessorExpression>(
+ std::make_unique<ast::IdentifierExpression>("data"),
+ std::make_unique<ast::IdentifierExpression>("b"));
+ auto rhs = std::make_unique<ast::TypeConstructorExpression>(
+ &fvec3, std::move(values));
+
+ ast::AssignmentStatement assign(std::move(lhs), std::move(rhs));
+
+ ASSERT_TRUE(td.DetermineResultType(&assign));
+ ASSERT_TRUE(g.EmitStatement(&assign)) << g.error();
+ EXPECT_EQ(
+ g.result(),
+ R"(data.Store3(16, asuint(vector<float, 3>(1.00000000f, 2.00000000f, 3.00000000f)));
+)");
+}
+
+TEST_F(HlslGeneratorImplTest,
+ EmitExpression_MemberAccessor_StorageBuffer_Load_MultiLevel) {
+ // struct Data {
+ // [[offset 0]] a : vec3<i32>;
+ // [[offset 16]] b : vec3<f32>;
+ // };
+ // struct Pre {
+ // var c : [[stride 32]] array<Data, 4>;
+ // };
+ //
+ // var<storage_buffer> data : Pre;
+ // data.c[2].b
+ //
+ // -> asfloat(data.Load3(16 + (2 * 32)))
+
+ ast::type::F32Type f32;
+ ast::type::I32Type i32;
+ ast::type::VectorType ivec3(&i32, 3);
+ ast::type::VectorType fvec3(&f32, 3);
+
+ ast::StructMemberList members;
+ ast::StructMemberDecorationList deco;
+ deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(0));
+ members.push_back(
+ std::make_unique<ast::StructMember>("a", &ivec3, std::move(deco)));
+
+ deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(16));
+ members.push_back(
+ std::make_unique<ast::StructMember>("b", &fvec3, std::move(deco)));
+
+ auto data_str = std::make_unique<ast::Struct>();
+ data_str->set_members(std::move(members));
+
+ ast::type::StructType data(std::move(data_str));
+ data.set_name("Data");
+
+ ast::type::ArrayType ary(&data, 4);
+ ary.set_array_stride(32);
+
+ deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(0));
+ members.push_back(
+ std::make_unique<ast::StructMember>("c", &ary, std::move(deco)));
+
+ auto pre_str = std::make_unique<ast::Struct>();
+ pre_str->set_members(std::move(members));
+
+ ast::type::StructType pre(std::move(pre_str));
+ pre.set_name("Pre");
+
+ auto coord_var =
+ std::make_unique<ast::DecoratedVariable>(std::make_unique<ast::Variable>(
+ "data", ast::StorageClass::kStorageBuffer, &pre));
+
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ GeneratorImpl g(&mod);
+ td.RegisterVariableForTesting(coord_var.get());
+ g.register_global(coord_var.get());
+ mod.AddGlobalVariable(std::move(coord_var));
+
+ ASSERT_TRUE(td.Determine()) << td.error();
+
+ ast::MemberAccessorExpression expr(
+ std::make_unique<ast::ArrayAccessorExpression>(
+ std::make_unique<ast::MemberAccessorExpression>(
+ std::make_unique<ast::IdentifierExpression>("data"),
+ std::make_unique<ast::IdentifierExpression>("c")),
+ std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::SintLiteral>(&i32, 2))),
+ std::make_unique<ast::IdentifierExpression>("b"));
+
+ ASSERT_TRUE(td.DetermineResultType(&expr));
+ ASSERT_TRUE(g.EmitExpression(&expr)) << g.error();
+ EXPECT_EQ(g.result(), "asfloat(data.Load3(16 + (32 * 2) + 0))");
+}
+
+TEST_F(HlslGeneratorImplTest,
+ EmitExpression_MemberAccessor_StorageBuffer_Load_MultiLevel_Swizzle) {
+ // struct Data {
+ // [[offset 0]] a : vec3<i32>;
+ // [[offset 16]] b : vec3<f32>;
+ // };
+ // struct Pre {
+ // var c : [[stride 32]] array<Data, 4>;
+ // };
+ //
+ // var<storage_buffer> data : Pre;
+ // data.c[2].b.xy
+ //
+ // -> asfloat(data.Load3(16 + (2 * 32))).xy
+
+ ast::type::F32Type f32;
+ ast::type::I32Type i32;
+ ast::type::VectorType ivec3(&i32, 3);
+ ast::type::VectorType fvec3(&f32, 3);
+
+ ast::StructMemberList members;
+ ast::StructMemberDecorationList deco;
+ deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(0));
+ members.push_back(
+ std::make_unique<ast::StructMember>("a", &ivec3, std::move(deco)));
+
+ deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(16));
+ members.push_back(
+ std::make_unique<ast::StructMember>("b", &fvec3, std::move(deco)));
+
+ auto data_str = std::make_unique<ast::Struct>();
+ data_str->set_members(std::move(members));
+
+ ast::type::StructType data(std::move(data_str));
+ data.set_name("Data");
+
+ ast::type::ArrayType ary(&data, 4);
+ ary.set_array_stride(32);
+
+ deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(0));
+ members.push_back(
+ std::make_unique<ast::StructMember>("c", &ary, std::move(deco)));
+
+ auto pre_str = std::make_unique<ast::Struct>();
+ pre_str->set_members(std::move(members));
+
+ ast::type::StructType pre(std::move(pre_str));
+ pre.set_name("Pre");
+
+ auto coord_var =
+ std::make_unique<ast::DecoratedVariable>(std::make_unique<ast::Variable>(
+ "data", ast::StorageClass::kStorageBuffer, &pre));
+
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ GeneratorImpl g(&mod);
+ td.RegisterVariableForTesting(coord_var.get());
+ g.register_global(coord_var.get());
+ mod.AddGlobalVariable(std::move(coord_var));
+
+ ASSERT_TRUE(td.Determine()) << td.error();
+
+ ast::MemberAccessorExpression expr(
+ std::make_unique<ast::MemberAccessorExpression>(
+ std::make_unique<ast::ArrayAccessorExpression>(
+ std::make_unique<ast::MemberAccessorExpression>(
+ std::make_unique<ast::IdentifierExpression>("data"),
+ std::make_unique<ast::IdentifierExpression>("c")),
+ std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::SintLiteral>(&i32, 2))),
+ std::make_unique<ast::IdentifierExpression>("b")),
+ std::make_unique<ast::IdentifierExpression>("xy"));
+
+ ASSERT_TRUE(td.DetermineResultType(&expr));
+ ASSERT_TRUE(g.EmitExpression(&expr)) << g.error();
+ EXPECT_EQ(g.result(), "asfloat(data.Load3(16 + (32 * 2) + 0)).xy");
+}
+
+TEST_F(
+ HlslGeneratorImplTest,
+ EmitExpression_MemberAccessor_StorageBuffer_Load_MultiLevel_Swizzle_SingleLetter) {
+ // struct Data {
+ // [[offset 0]] a : vec3<i32>;
+ // [[offset 16]] b : vec3<f32>;
+ // };
+ // struct Pre {
+ // var c : [[stride 32]] array<Data, 4>;
+ // };
+ //
+ // var<storage_buffer> data : Pre;
+ // data.c[2].b.g
+ //
+ // -> asfloat(data.Load((4 * 1) + 16 + (2 * 32) + 0))
+
+ ast::type::F32Type f32;
+ ast::type::I32Type i32;
+ ast::type::VectorType ivec3(&i32, 3);
+ ast::type::VectorType fvec3(&f32, 3);
+
+ ast::StructMemberList members;
+ ast::StructMemberDecorationList deco;
+ deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(0));
+ members.push_back(
+ std::make_unique<ast::StructMember>("a", &ivec3, std::move(deco)));
+
+ deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(16));
+ members.push_back(
+ std::make_unique<ast::StructMember>("b", &fvec3, std::move(deco)));
+
+ auto data_str = std::make_unique<ast::Struct>();
+ data_str->set_members(std::move(members));
+
+ ast::type::StructType data(std::move(data_str));
+ data.set_name("Data");
+
+ ast::type::ArrayType ary(&data, 4);
+ ary.set_array_stride(32);
+
+ deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(0));
+ members.push_back(
+ std::make_unique<ast::StructMember>("c", &ary, std::move(deco)));
+
+ auto pre_str = std::make_unique<ast::Struct>();
+ pre_str->set_members(std::move(members));
+
+ ast::type::StructType pre(std::move(pre_str));
+ pre.set_name("Pre");
+
+ auto coord_var =
+ std::make_unique<ast::DecoratedVariable>(std::make_unique<ast::Variable>(
+ "data", ast::StorageClass::kStorageBuffer, &pre));
+
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ GeneratorImpl g(&mod);
+ td.RegisterVariableForTesting(coord_var.get());
+ g.register_global(coord_var.get());
+ mod.AddGlobalVariable(std::move(coord_var));
+
+ ASSERT_TRUE(td.Determine()) << td.error();
+
+ ast::MemberAccessorExpression expr(
+ std::make_unique<ast::MemberAccessorExpression>(
+ std::make_unique<ast::ArrayAccessorExpression>(
+ std::make_unique<ast::MemberAccessorExpression>(
+ std::make_unique<ast::IdentifierExpression>("data"),
+ std::make_unique<ast::IdentifierExpression>("c")),
+ std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::SintLiteral>(&i32, 2))),
+ std::make_unique<ast::IdentifierExpression>("b")),
+ std::make_unique<ast::IdentifierExpression>("g"));
+
+ ASSERT_TRUE(td.DetermineResultType(&expr));
+ ASSERT_TRUE(g.EmitExpression(&expr)) << g.error();
+ EXPECT_EQ(g.result(), "asfloat(data.Load((4 * 1) + 16 + (32 * 2) + 0))");
+}
+
+TEST_F(HlslGeneratorImplTest,
+ EmitExpression_MemberAccessor_StorageBuffer_Load_MultiLevel_Index) {
+ // struct Data {
+ // [[offset 0]] a : vec3<i32>;
+ // [[offset 16]] b : vec3<f32>;
+ // };
+ // struct Pre {
+ // var c : [[stride 32]] array<Data, 4>;
+ // };
+ //
+ // var<storage_buffer> data : Pre;
+ // data.c[2].b[1]
+ //
+ // -> asfloat(data.Load(4 + 16 + (2 * 32)))
+
+ ast::type::F32Type f32;
+ ast::type::I32Type i32;
+ ast::type::VectorType ivec3(&i32, 3);
+ ast::type::VectorType fvec3(&f32, 3);
+
+ ast::StructMemberList members;
+ ast::StructMemberDecorationList deco;
+ deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(0));
+ members.push_back(
+ std::make_unique<ast::StructMember>("a", &ivec3, std::move(deco)));
+
+ deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(16));
+ members.push_back(
+ std::make_unique<ast::StructMember>("b", &fvec3, std::move(deco)));
+
+ auto data_str = std::make_unique<ast::Struct>();
+ data_str->set_members(std::move(members));
+
+ ast::type::StructType data(std::move(data_str));
+ data.set_name("Data");
+
+ ast::type::ArrayType ary(&data, 4);
+ ary.set_array_stride(32);
+
+ deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(0));
+ members.push_back(
+ std::make_unique<ast::StructMember>("c", &ary, std::move(deco)));
+
+ auto pre_str = std::make_unique<ast::Struct>();
+ pre_str->set_members(std::move(members));
+
+ ast::type::StructType pre(std::move(pre_str));
+ pre.set_name("Pre");
+
+ auto coord_var =
+ std::make_unique<ast::DecoratedVariable>(std::make_unique<ast::Variable>(
+ "data", ast::StorageClass::kStorageBuffer, &pre));
+
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ GeneratorImpl g(&mod);
+ td.RegisterVariableForTesting(coord_var.get());
+ g.register_global(coord_var.get());
+ mod.AddGlobalVariable(std::move(coord_var));
+
+ ASSERT_TRUE(td.Determine()) << td.error();
+
+ ast::ArrayAccessorExpression expr(
+ std::make_unique<ast::MemberAccessorExpression>(
+ std::make_unique<ast::ArrayAccessorExpression>(
+ std::make_unique<ast::MemberAccessorExpression>(
+ std::make_unique<ast::IdentifierExpression>("data"),
+ std::make_unique<ast::IdentifierExpression>("c")),
+ std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::SintLiteral>(&i32, 2))),
+ std::make_unique<ast::IdentifierExpression>("b")),
+ std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::SintLiteral>(&i32, 1)));
+
+ ASSERT_TRUE(td.DetermineResultType(&expr));
+ ASSERT_TRUE(g.EmitExpression(&expr)) << g.error();
+ EXPECT_EQ(g.result(), "asfloat(data.Load((4 * 1) + 16 + (32 * 2) + 0))");
+}
+
+TEST_F(HlslGeneratorImplTest,
+ EmitExpression_MemberAccessor_StorageBuffer_Store_MultiLevel) {
+ // struct Data {
+ // [[offset 0]] a : vec3<i32>;
+ // [[offset 16]] b : vec3<f32>;
+ // };
+ // struct Pre {
+ // var c : [[stride 32]] array<Data, 4>;
+ // };
+ //
+ // var<storage_buffer> data : Pre;
+ // data.c[2].b = vec3<f32>(1.f, 2.f, 3.f);
+ //
+ // -> data.Store3(16 + (2 * 32), asuint(vector<float, 3>(1.0f, 2.0f, 3.0f)));
+
+ ast::type::F32Type f32;
+ ast::type::I32Type i32;
+ ast::type::VectorType ivec3(&i32, 3);
+ ast::type::VectorType fvec3(&f32, 3);
+
+ ast::StructMemberList members;
+ ast::StructMemberDecorationList deco;
+ deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(0));
+ members.push_back(
+ std::make_unique<ast::StructMember>("a", &ivec3, std::move(deco)));
+
+ deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(16));
+ members.push_back(
+ std::make_unique<ast::StructMember>("b", &fvec3, std::move(deco)));
+
+ auto data_str = std::make_unique<ast::Struct>();
+ data_str->set_members(std::move(members));
+
+ ast::type::StructType data(std::move(data_str));
+ data.set_name("Data");
+
+ ast::type::ArrayType ary(&data, 4);
+ ary.set_array_stride(32);
+
+ deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(0));
+ members.push_back(
+ std::make_unique<ast::StructMember>("c", &ary, std::move(deco)));
+
+ auto pre_str = std::make_unique<ast::Struct>();
+ pre_str->set_members(std::move(members));
+
+ ast::type::StructType pre(std::move(pre_str));
+ pre.set_name("Pre");
+
+ auto coord_var =
+ std::make_unique<ast::DecoratedVariable>(std::make_unique<ast::Variable>(
+ "data", ast::StorageClass::kStorageBuffer, &pre));
+
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ GeneratorImpl g(&mod);
+ td.RegisterVariableForTesting(coord_var.get());
+ g.register_global(coord_var.get());
+ mod.AddGlobalVariable(std::move(coord_var));
+
+ ASSERT_TRUE(td.Determine()) << td.error();
+
+ auto lhs = std::make_unique<ast::MemberAccessorExpression>(
+ std::make_unique<ast::ArrayAccessorExpression>(
+ std::make_unique<ast::MemberAccessorExpression>(
+ std::make_unique<ast::IdentifierExpression>("data"),
+ std::make_unique<ast::IdentifierExpression>("c")),
+ std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::SintLiteral>(&i32, 2))),
+ std::make_unique<ast::IdentifierExpression>("b"));
+
+ auto lit1 = std::make_unique<ast::FloatLiteral>(&f32, 1.f);
+ auto lit2 = std::make_unique<ast::FloatLiteral>(&f32, 2.f);
+ auto lit3 = std::make_unique<ast::FloatLiteral>(&f32, 3.f);
+ ast::ExpressionList values;
+ values.push_back(
+ std::make_unique<ast::ScalarConstructorExpression>(std::move(lit1)));
+ values.push_back(
+ std::make_unique<ast::ScalarConstructorExpression>(std::move(lit2)));
+ values.push_back(
+ std::make_unique<ast::ScalarConstructorExpression>(std::move(lit3)));
+
+ auto rhs = std::make_unique<ast::TypeConstructorExpression>(
+ &fvec3, std::move(values));
+
+ ast::AssignmentStatement assign(std::move(lhs), std::move(rhs));
+
+ ASSERT_TRUE(td.DetermineResultType(&assign));
+ ASSERT_TRUE(g.EmitStatement(&assign)) << g.error();
+ EXPECT_EQ(
+ g.result(),
+ R"(data.Store3(16 + (32 * 2) + 0, asuint(vector<float, 3>(1.00000000f, 2.00000000f, 3.00000000f)));
+)");
+}
+
+TEST_F(HlslGeneratorImplTest,
+ EmitExpression_MemberAccessor_StorageBuffer_Store_Swizzle_SingleLetter) {
+ // struct Data {
+ // [[offset 0]] a : vec3<i32>;
+ // [[offset 16]] b : vec3<f32>;
+ // };
+ // struct Pre {
+ // var c : [[stride 32]] array<Data, 4>;
+ // };
+ //
+ // var<storage_buffer> data : Pre;
+ // data.c[2].b.y = 1.f;
+ //
+ // -> data.Store((4 * 1) + 16 + (2 * 32) + 0, asuint(1.0f));
+
+ ast::type::F32Type f32;
+ ast::type::I32Type i32;
+ ast::type::VectorType ivec3(&i32, 3);
+ ast::type::VectorType fvec3(&f32, 3);
+
+ ast::StructMemberList members;
+ ast::StructMemberDecorationList deco;
+ deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(0));
+ members.push_back(
+ std::make_unique<ast::StructMember>("a", &ivec3, std::move(deco)));
+
+ deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(16));
+ members.push_back(
+ std::make_unique<ast::StructMember>("b", &fvec3, std::move(deco)));
+
+ auto data_str = std::make_unique<ast::Struct>();
+ data_str->set_members(std::move(members));
+
+ ast::type::StructType data(std::move(data_str));
+ data.set_name("Data");
+
+ ast::type::ArrayType ary(&data, 4);
+ ary.set_array_stride(32);
+
+ deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(0));
+ members.push_back(
+ std::make_unique<ast::StructMember>("c", &ary, std::move(deco)));
+
+ auto pre_str = std::make_unique<ast::Struct>();
+ pre_str->set_members(std::move(members));
+
+ ast::type::StructType pre(std::move(pre_str));
+ pre.set_name("Pre");
+
+ auto coord_var =
+ std::make_unique<ast::DecoratedVariable>(std::make_unique<ast::Variable>(
+ "data", ast::StorageClass::kStorageBuffer, &pre));
+
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ GeneratorImpl g(&mod);
+ td.RegisterVariableForTesting(coord_var.get());
+ g.register_global(coord_var.get());
+ mod.AddGlobalVariable(std::move(coord_var));
+
+ ASSERT_TRUE(td.Determine()) << td.error();
+
+ auto lhs = std::make_unique<ast::MemberAccessorExpression>(
+ std::make_unique<ast::MemberAccessorExpression>(
+ std::make_unique<ast::ArrayAccessorExpression>(
+ std::make_unique<ast::MemberAccessorExpression>(
+ std::make_unique<ast::IdentifierExpression>("data"),
+ std::make_unique<ast::IdentifierExpression>("c")),
+ std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::SintLiteral>(&i32, 2))),
+ std::make_unique<ast::IdentifierExpression>("b")),
+ std::make_unique<ast::IdentifierExpression>("y"));
+
+ auto rhs = std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&i32, 1.f));
+
+ ast::AssignmentStatement assign(std::move(lhs), std::move(rhs));
+
+ ASSERT_TRUE(td.DetermineResultType(&assign));
+ ASSERT_TRUE(g.EmitStatement(&assign)) << g.error();
+ EXPECT_EQ(g.result(),
+ R"(data.Store((4 * 1) + 16 + (32 * 2) + 0, asuint(1.00000000f));
+)");
+}
+
} // namespace
} // namespace hlsl
} // namespace writer