[hlsl] Decompose load/store vector element
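This CL updates the DecomposeMemoryAccess transform to convert
`load_vector_element` and `store_vector_element` instructions on storage
buffers into calls to the equivalent HLSL `ByteAddressBuffer` `Load` and
`Store` member functions, represented by a new `hlsl::ir::MemberBuiltinCall`
instruction.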
Bug: 349867642
Change-Id: I9e7a8808d664478126e98d63a735eed2011d5ff9
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/196314
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/hlsl/ir/BUILD.bazel b/src/tint/lang/hlsl/ir/BUILD.bazel
index b8a34ae..c52853bb 100644
--- a/src/tint/lang/hlsl/ir/BUILD.bazel
+++ b/src/tint/lang/hlsl/ir/BUILD.bazel
@@ -40,10 +40,12 @@
name = "ir",
srcs = [
"builtin_call.cc",
+ "member_builtin_call.cc",
"ternary.cc",
],
hdrs = [
"builtin_call.h",
+ "member_builtin_call.h",
"ternary.h",
],
deps = [
@@ -77,6 +79,7 @@
alwayslink = True,
srcs = [
"builtin_call_test.cc",
+ "member_builtin_call_test.cc",
"ternary_test.cc",
],
deps = [
@@ -90,6 +93,7 @@
"//src/tint/lang/hlsl",
"//src/tint/lang/hlsl/intrinsic",
"//src/tint/lang/hlsl/ir",
+ "//src/tint/lang/hlsl/type",
"//src/tint/utils/containers",
"//src/tint/utils/diagnostic",
"//src/tint/utils/ice",
diff --git a/src/tint/lang/hlsl/ir/BUILD.cmake b/src/tint/lang/hlsl/ir/BUILD.cmake
index f89052f..08baa1e 100644
--- a/src/tint/lang/hlsl/ir/BUILD.cmake
+++ b/src/tint/lang/hlsl/ir/BUILD.cmake
@@ -41,6 +41,8 @@
tint_add_target(tint_lang_hlsl_ir lib
lang/hlsl/ir/builtin_call.cc
lang/hlsl/ir/builtin_call.h
+ lang/hlsl/ir/member_builtin_call.cc
+ lang/hlsl/ir/member_builtin_call.h
lang/hlsl/ir/ternary.cc
lang/hlsl/ir/ternary.h
)
@@ -75,6 +77,7 @@
################################################################################
tint_add_target(tint_lang_hlsl_ir_test test
lang/hlsl/ir/builtin_call_test.cc
+ lang/hlsl/ir/member_builtin_call_test.cc
lang/hlsl/ir/ternary_test.cc
)
@@ -89,6 +92,7 @@
tint_lang_hlsl
tint_lang_hlsl_intrinsic
tint_lang_hlsl_ir
+ tint_lang_hlsl_type
tint_utils_containers
tint_utils_diagnostic
tint_utils_ice
diff --git a/src/tint/lang/hlsl/ir/BUILD.gn b/src/tint/lang/hlsl/ir/BUILD.gn
index 21087ee..2fa8d6c 100644
--- a/src/tint/lang/hlsl/ir/BUILD.gn
+++ b/src/tint/lang/hlsl/ir/BUILD.gn
@@ -46,6 +46,8 @@
sources = [
"builtin_call.cc",
"builtin_call.h",
+ "member_builtin_call.cc",
+ "member_builtin_call.h",
"ternary.cc",
"ternary.h",
]
@@ -77,6 +79,7 @@
tint_unittests_source_set("unittests") {
sources = [
"builtin_call_test.cc",
+ "member_builtin_call_test.cc",
"ternary_test.cc",
]
deps = [
@@ -91,6 +94,7 @@
"${tint_src_dir}/lang/hlsl",
"${tint_src_dir}/lang/hlsl/intrinsic",
"${tint_src_dir}/lang/hlsl/ir",
+ "${tint_src_dir}/lang/hlsl/type",
"${tint_src_dir}/utils/containers",
"${tint_src_dir}/utils/diagnostic",
"${tint_src_dir}/utils/ice",
diff --git a/src/tint/lang/hlsl/ir/member_builtin_call.cc b/src/tint/lang/hlsl/ir/member_builtin_call.cc
new file mode 100644
index 0000000..81707ca
--- /dev/null
+++ b/src/tint/lang/hlsl/ir/member_builtin_call.cc
@@ -0,0 +1,63 @@
+// Copyright 2024 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+// list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include "src/tint/lang/hlsl/ir/member_builtin_call.h"
+
+#include <utility>
+
+#include "src/tint/lang/core/ir/clone_context.h"
+#include "src/tint/lang/core/ir/instruction_result.h"
+#include "src/tint/lang/core/ir/module.h"
+#include "src/tint/lang/core/ir/value.h"
+#include "src/tint/lang/hlsl/builtin_fn.h"
+#include "src/tint/utils/containers/vector.h"
+#include "src/tint/utils/ice/ice.h"
+#include "src/tint/utils/rtti/castable.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::hlsl::ir::MemberBuiltinCall);
+
+namespace tint::hlsl::ir {
+
+MemberBuiltinCall::MemberBuiltinCall(core::ir::InstructionResult* result,
+                                     BuiltinFn func,
+                                     core::ir::Value* object,
+                                     VectorRef<core::ir::Value*> args)
+    : Base(result, object, args), func_{func} {
+    TINT_ASSERT(func_ != BuiltinFn::kNone);
+}
+
+MemberBuiltinCall::~MemberBuiltinCall() = default;
+
+MemberBuiltinCall* MemberBuiltinCall::Clone(core::ir::CloneContext& ctx) {
+    auto* result = ctx.Clone(Result(0));
+    auto* object = ctx.Clone(Object());
+    auto args = ctx.Clone<MemberBuiltinCall::kDefaultNumOperands>(Args());
+    auto& instructions = ctx.ir.allocators.instructions;
+    return instructions.Create<MemberBuiltinCall>(result, func_, object, std::move(args));
+}
+
+} // namespace tint::hlsl::ir
diff --git a/src/tint/lang/hlsl/ir/member_builtin_call.h b/src/tint/lang/hlsl/ir/member_builtin_call.h
new file mode 100644
index 0000000..068ff97
--- /dev/null
+++ b/src/tint/lang/hlsl/ir/member_builtin_call.h
@@ -0,0 +1,78 @@
+// Copyright 2024 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+// list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#ifndef SRC_TINT_LANG_HLSL_IR_MEMBER_BUILTIN_CALL_H_
+#define SRC_TINT_LANG_HLSL_IR_MEMBER_BUILTIN_CALL_H_
+
+#include <string>
+
+#include "src/tint/lang/core/intrinsic/table_data.h"
+#include "src/tint/lang/core/ir/member_builtin_call.h"
+#include "src/tint/lang/hlsl/builtin_fn.h"
+#include "src/tint/lang/hlsl/intrinsic/dialect.h"
+#include "src/tint/utils/rtti/castable.h"
+
+namespace tint::hlsl::ir {
+
+/// An HLSL builtin invoked as a member function of an object, e.g. `buf.Load(offset)`.
+class MemberBuiltinCall final : public Castable<MemberBuiltinCall, core::ir::MemberBuiltinCall> {
+ public:
+ /// Constructor
+ /// @param result the instruction result that receives the call's value
+ /// @param func the HLSL builtin function; must not be `BuiltinFn::kNone`
+ /// @param object the value the builtin is called on (e.g. a ByteAddressBuffer var)
+ /// @param args the call arguments, not including @p object
+ MemberBuiltinCall(core::ir::InstructionResult* result,
+ BuiltinFn func,
+ core::ir::Value* object,
+ VectorRef<core::ir::Value*> args = tint::Empty);
+ ~MemberBuiltinCall() override;
+
+ /// @copydoc core::ir::Instruction::Clone()
+ MemberBuiltinCall* Clone(core::ir::CloneContext& ctx) override;
+
+ /// @returns the HLSL builtin function being called
+ BuiltinFn Func() const { return func_; }
+
+ /// @returns the builtin function as an integer identifier
+ size_t FuncId() const override { return static_cast<size_t>(func_); }
+
+ /// @returns the name of the builtin function as produced by `str()` (e.g. `Load`)
+ std::string FriendlyName() const override { return str(func_); }
+
+ /// @returns the HLSL intrinsic table used to validate this builtin call
+ const core::intrinsic::TableData& TableData() const override {
+ return hlsl::intrinsic::Dialect::kData;
+ }
+
+ private:
+ BuiltinFn func_;
+};
+
+} // namespace tint::hlsl::ir
+
+#endif // SRC_TINT_LANG_HLSL_IR_MEMBER_BUILTIN_CALL_H_
diff --git a/src/tint/lang/hlsl/ir/member_builtin_call_test.cc b/src/tint/lang/hlsl/ir/member_builtin_call_test.cc
new file mode 100644
index 0000000..86ff37d
--- /dev/null
+++ b/src/tint/lang/hlsl/ir/member_builtin_call_test.cc
@@ -0,0 +1,149 @@
+// Copyright 2024 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+// list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include "src/tint/lang/hlsl/ir/member_builtin_call.h"
+
+#include "gtest/gtest.h"
+
+#include "src/tint/lang/core/fluent_types.h"
+#include "src/tint/lang/core/ir/ir_helper_test.h"
+#include "src/tint/lang/core/ir/validator.h"
+#include "src/tint/lang/core/number.h"
+#include "src/tint/lang/core/type/sampled_texture.h"
+#include "src/tint/lang/core/type/sampler.h"
+#include "src/tint/lang/core/type/texture_dimension.h"
+#include "src/tint/lang/core/type/vector.h"
+#include "src/tint/lang/hlsl/builtin_fn.h"
+#include "src/tint/lang/hlsl/type/byte_address_buffer.h"
+#include "src/tint/utils/result/result.h"
+
+using namespace tint::core::fluent_types; // NOLINT
+using namespace tint::core::number_suffixes; // NOLINT
+
+namespace tint::hlsl::ir {
+namespace {
+
+using IR_HlslMemberBuiltinCallTest = core::ir::IRTestHelper;
+
+TEST_F(IR_HlslMemberBuiltinCallTest, Clone) {
+    auto* buf = ty.Get<hlsl::type::ByteAddressBuffer>(ty.vec3<i32>(), core::Access::kReadWrite);
+
+    auto* param = b.FunctionParam("t", buf);
+    auto* call = b.MemberCall<MemberBuiltinCall>(ty.u32(), BuiltinFn::kLoad, param, 2_u);
+
+    auto* copy = clone_ctx.Clone(call);
+
+    EXPECT_NE(call, copy);
+    EXPECT_NE(call->Result(0), copy->Result(0));
+    EXPECT_EQ(ty.u32(), copy->Result(0)->Type());
+
+    EXPECT_EQ(BuiltinFn::kLoad, copy->Func());
+
+    EXPECT_TRUE(copy->Object()->Type()->Is<hlsl::type::ByteAddressBuffer>());
+
+    auto copy_args = copy->Args();
+    ASSERT_EQ(1u, copy_args.Length());
+    EXPECT_TRUE(copy_args[0]->Type()->Is<core::type::U32>());
+}
+
+TEST_F(IR_HlslMemberBuiltinCallTest, DoesNotMatchNonMemberFunction) {
+    auto* buf = ty.Get<hlsl::type::ByteAddressBuffer>(ty.vec3<i32>(), core::Access::kRead);
+
+    auto* param = b.FunctionParam("t", buf);
+
+    auto* fn = b.Function("foo", ty.u32());
+    fn->SetParams({param});
+    b.Append(fn->Block(), [&] {
+        // `asint` is a non-member builtin, so validation must reject this member call.
+        auto* call = b.MemberCall<MemberBuiltinCall>(ty.u32(), BuiltinFn::kAsint, param, 2_u);
+        b.Return(fn, call);
+    });
+
+    auto result = core::ir::Validate(mod);
+    ASSERT_NE(result, Success);
+    EXPECT_EQ(
+        result.Failure().reason.Str(),
+        R"(:3:17 error: asint: no matching call to 'asint(hlsl.byte_address_buffer<vec3<i32>, read>, u32)'
+
+ %3:u32 = %t.asint 2u
+ ^^^^^
+
+:2:3 note: in block
+ $B1: {
+ ^^^
+
+note: # Disassembly
+%foo = func(%t:hlsl.byte_address_buffer<vec3<i32>, read>):u32 {
+ $B1: {
+ %3:u32 = %t.asint 2u
+ ret %3
+ }
+}
+)");
+}
+
+TEST_F(IR_HlslMemberBuiltinCallTest, DoesNotMatchIncorrectAccess) {
+    auto* buf = ty.Get<hlsl::type::ByteAddressBuffer>(ty.vec3<i32>(), core::Access::kRead);
+
+    auto* t = b.FunctionParam("t", buf);
+
+    auto* func = b.Function("foo", ty.u32());
+    func->SetParams({t});
+    b.Append(func->Block(), [&] {
+        // `Store` needs a write or read_write buffer, but `buf` is read-only.
+        auto* builtin = b.MemberCall<MemberBuiltinCall>(ty.u32(), BuiltinFn::kStore, t, 2_u, 2_u);
+        b.Return(func, builtin);
+    });
+
+    auto res = core::ir::Validate(mod);
+    ASSERT_NE(res, Success);
+    EXPECT_EQ(
+        res.Failure().reason.Str(),
+        R"(:3:17 error: Store: no matching call to 'Store(hlsl.byte_address_buffer<vec3<i32>, read>, u32, u32)'
+
+1 candidate function:
+ • 'Store(byte_address_buffer<T, write' or 'read_write> ✗ , offset: u32 ✓ , value: u32 ✓ )'
+
+ %3:u32 = %t.Store 2u, 2u
+ ^^^^^
+
+:2:3 note: in block
+ $B1: {
+ ^^^
+
+note: # Disassembly
+%foo = func(%t:hlsl.byte_address_buffer<vec3<i32>, read>):u32 {
+ $B1: {
+ %3:u32 = %t.Store 2u, 2u
+ ret %3
+ }
+}
+)");
+}
+
+} // namespace
+} // namespace tint::hlsl::ir
diff --git a/src/tint/lang/hlsl/writer/access_test.cc b/src/tint/lang/hlsl/writer/access_test.cc
index 02d78f1..63a5822 100644
--- a/src/tint/lang/hlsl/writer/access_test.cc
+++ b/src/tint/lang/hlsl/writer/access_test.cc
@@ -269,5 +269,61 @@
 )");
 }
+TEST_F(HlslWriterTest, AccessVectorLoad) {
+    auto* var = b.Var<storage, vec4<f32>, core::Access::kRead>("v");
+    var->SetBindingPoint(0, 0);
+
+    b.ir.root_block->Append(var);
+    auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+    b.Append(func->Block(), [&] {
+        // Load elements 0..3 of `v` into lets named `a`..`d`.
+        const char* names[] = {"a", "b", "c", "d"};
+        for (uint32_t i = 0; i < 4; ++i) {
+            b.Let(names[i], b.LoadVectorElement(var, u32(i)));
+        }
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+    EXPECT_EQ(output_.hlsl, R"(
+ByteAddressBuffer v : register(t0);
+void foo() {
+ float a = asfloat(v.Load(0u));
+ float b = asfloat(v.Load(4u));
+ float c = asfloat(v.Load(8u));
+ float d = asfloat(v.Load(12u));
+}
+
+)");
+}
+
+TEST_F(HlslWriterTest, AccessVectorStore) {
+    auto* var = b.Var<storage, vec4<f32>, core::Access::kReadWrite>("v");
+    var->SetBindingPoint(0, 0);
+
+    b.ir.root_block->Append(var);
+    auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+    b.Append(func->Block(), [&] {
+        // Store 2, 4, 8 and 16 into elements 0..3 of `v`.
+        const float values[] = {2.0f, 4.0f, 8.0f, 16.0f};
+        for (uint32_t i = 0; i < 4; ++i) {
+            b.StoreVectorElement(var, u32(i), f32(values[i]));
+        }
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+    EXPECT_EQ(output_.hlsl, R"(
+RWByteAddressBuffer v : register(u0);
+void foo() {
+ v.Store(0u, asuint(2.0f));
+ v.Store(4u, asuint(4.0f));
+ v.Store(8u, asuint(8.0f));
+ v.Store(12u, asuint(16.0f));
+}
+
+)");
+}
+
} // namespace
} // namespace tint::hlsl::writer
diff --git a/src/tint/lang/hlsl/writer/printer/BUILD.bazel b/src/tint/lang/hlsl/writer/printer/BUILD.bazel
index 69bea38..3b779b1 100644
--- a/src/tint/lang/hlsl/writer/printer/BUILD.bazel
+++ b/src/tint/lang/hlsl/writer/printer/BUILD.bazel
@@ -54,6 +54,7 @@
"//src/tint/lang/hlsl",
"//src/tint/lang/hlsl/intrinsic",
"//src/tint/lang/hlsl/ir",
+ "//src/tint/lang/hlsl/type",
"//src/tint/utils/containers",
"//src/tint/utils/diagnostic",
"//src/tint/utils/generator",
diff --git a/src/tint/lang/hlsl/writer/printer/BUILD.cmake b/src/tint/lang/hlsl/writer/printer/BUILD.cmake
index 4fcaeb4..349113e 100644
--- a/src/tint/lang/hlsl/writer/printer/BUILD.cmake
+++ b/src/tint/lang/hlsl/writer/printer/BUILD.cmake
@@ -53,6 +53,7 @@
tint_lang_hlsl
tint_lang_hlsl_intrinsic
tint_lang_hlsl_ir
+ tint_lang_hlsl_type
tint_utils_containers
tint_utils_diagnostic
tint_utils_generator
diff --git a/src/tint/lang/hlsl/writer/printer/BUILD.gn b/src/tint/lang/hlsl/writer/printer/BUILD.gn
index eb14752..9c7a49f 100644
--- a/src/tint/lang/hlsl/writer/printer/BUILD.gn
+++ b/src/tint/lang/hlsl/writer/printer/BUILD.gn
@@ -53,6 +53,7 @@
"${tint_src_dir}/lang/hlsl",
"${tint_src_dir}/lang/hlsl/intrinsic",
"${tint_src_dir}/lang/hlsl/ir",
+ "${tint_src_dir}/lang/hlsl/type",
"${tint_src_dir}/utils/containers",
"${tint_src_dir}/utils/diagnostic",
"${tint_src_dir}/utils/generator",
diff --git a/src/tint/lang/hlsl/writer/printer/printer.cc b/src/tint/lang/hlsl/writer/printer/printer.cc
index 8aa8284..f4adafc 100644
--- a/src/tint/lang/hlsl/writer/printer/printer.cc
+++ b/src/tint/lang/hlsl/writer/printer/printer.cc
@@ -31,7 +31,6 @@
#include <cstddef>
#include <cstdint>
#include <string>
-#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
@@ -68,7 +67,7 @@
#include "src/tint/lang/core/ir/load_vector_element.h"
#include "src/tint/lang/core/ir/loop.h"
#include "src/tint/lang/core/ir/module.h"
-#include "src/tint/lang/core/ir/multi_in_block.h"
+#include "src/tint/lang/core/ir/multi_in_block.h" // IWYU pragma: export
#include "src/tint/lang/core/ir/next_iteration.h"
#include "src/tint/lang/core/ir/return.h"
#include "src/tint/lang/core/ir/store.h"
@@ -105,7 +104,9 @@
#include "src/tint/lang/core/type/vector.h"
#include "src/tint/lang/core/type/void.h"
#include "src/tint/lang/hlsl/ir/builtin_call.h"
+#include "src/tint/lang/hlsl/ir/member_builtin_call.h"
#include "src/tint/lang/hlsl/ir/ternary.h"
+#include "src/tint/lang/hlsl/type/byte_address_buffer.h"
#include "src/tint/utils/containers/hashmap.h"
#include "src/tint/utils/containers/map.h"
#include "src/tint/utils/generator/text_generator.h"
@@ -247,9 +248,6 @@
++i;
auto ptr = param->Type()->As<core::type::Pointer>();
- auto address_space = core::AddressSpace::kUndefined;
- auto access = core::Access::kUndefined;
-
if (is_ep && !param->Type()->Is<core::type::Struct>()) {
// ICE likely indicates that the ShaderIO transform was not run, or a builtin
// parameter was added after it was run.
@@ -265,7 +263,8 @@
out << "inout ";
}
}
- EmitTypeAndName(out, param->Type(), address_space, access, NameOf(param));
+ EmitTypeAndName(out, param->Type(), core::AddressSpace::kUndefined,
+ core::Access::kUndefined, NameOf(param));
}
out << ") {";
@@ -469,38 +468,42 @@
}
void EmitGlobalVar(const core::ir::Var* var) {
- auto* ptr = var->Result(0)->Type()->As<core::type::Pointer>();
- TINT_ASSERT(ptr);
+ // Storage buffer vars are expected to arrive retyped as hlsl::type::ByteAddressBuffer
+ // (presumably by the DecomposeMemoryAccess raise transform); every other module-scope var
+ // is still a pointer. A pointer in the storage address space falls through to the ICE below.
+ Switch(
+ var->Result(0)->Type(), //
+ [&](const hlsl::type::ByteAddressBuffer* buf) { EmitStorageVariable(var, buf); },
+ [&](const core::type::Pointer* ptr) {
+ auto space = ptr->AddressSpace();
- auto space = ptr->AddressSpace();
+ switch (space) {
+ case core::AddressSpace::kUniform:
+ EmitUniformVariable(var);
+ break;
+ case core::AddressSpace::kHandle:
+ EmitHandleVariable(var);
+ break;
+ case core::AddressSpace::kPrivate: {
+ auto out = Line();
+ out << "static ";
+ EmitVar(out, var);
+ break;
+ }
+ case core::AddressSpace::kWorkgroup: {
+ auto out = Line();
- switch (space) {
- case core::AddressSpace::kUniform:
- EmitUniformVariable(var);
- break;
- case core::AddressSpace::kStorage:
- EmitStorageVariable(var);
- break;
- case core::AddressSpace::kHandle:
- EmitHandleVariable(var);
- break;
- case core::AddressSpace::kPrivate: {
- auto out = Line();
- out << "static ";
- EmitVar(out, var);
- break;
- }
- case core::AddressSpace::kWorkgroup: {
- auto out = Line();
- out << "groupshared ";
- EmitVar(out, var);
- break;
- }
- case core::AddressSpace::kPushConstant:
- default: {
- TINT_ICE() << "unhandled address space " << space;
- }
- }
+ out << "groupshared ";
+ EmitVar(out, var);
+ break;
+ }
+ case core::AddressSpace::kPushConstant:
+ default: {
+ TINT_ICE() << "unhandled address space " << space;
+ }
+ }
+ },
+ TINT_ICE_ON_NO_MATCH);
}
void EmitUniformVariable(const core::ir::Var* var) {
@@ -524,18 +524,17 @@
Line() << "};";
}
- void EmitStorageVariable(const core::ir::Var* var) {
- auto* ptr = var->Result(0)->Type()->As<core::type::Pointer>();
- TINT_ASSERT(ptr);
-
+ /// Emits the declaration of the storage buffer @p var, whose HLSL type is @p buf.
+ /// Read-only buffers are bound to a `t` register; writable ones to a `u` register.
+ void EmitStorageVariable(const core::ir::Var* var, const hlsl::type::ByteAddressBuffer* buf) {
auto out = Line();
- EmitTypeAndName(out, var->Result(0)->Type(), core::AddressSpace::kStorage, ptr->Access(),
+ EmitTypeAndName(out, var->Result(0)->Type(), core::AddressSpace::kStorage, buf->Access(),
NameOf(var->Result(0)));
auto bp = var->BindingPoint();
TINT_ASSERT(bp.has_value());
- out << RegisterAndSpace(ptr->Access() == core::Access::kRead ? 't' : 'u', bp.value())
+ out << RegisterAndSpace(buf->Access() == core::Access::kRead ? 't' : 'u', bp.value())
<< ";";
}
@@ -648,13 +645,30 @@
[&](const core::ir::Var* var) { out << NameOf(var->Result(0)); }, //
[&](const hlsl::ir::BuiltinCall* c) { EmitHlslBuiltinCall(out, c); }, //
- [&](const hlsl::ir::Ternary* t) { EmitTernary(out, t); }, //
+ [&](const hlsl::ir::Ternary* t) { EmitTernary(out, t); },
+ [&](const hlsl::ir::MemberBuiltinCall* mbc) {
+ EmitHlslMemberBuiltinCall(out, mbc);
+ },
TINT_ICE_ON_NO_MATCH);
},
[&](const core::ir::FunctionParam* p) { out << NameOf(p); }, //
TINT_ICE_ON_NO_MATCH);
}
+    /// Emits a member builtin call as `<object>.<builtin>(<arg0>, <arg1>, ...)`,
+    /// e.g. `v.Load(0u)`.
+    void EmitHlslMemberBuiltinCall(StringStream& out, const hlsl::ir::MemberBuiltinCall* c) {
+        EmitValue(out, c->Object());
+        out << "." << c->Func() << "(";
+        const char* separator = "";
+        for (const auto* arg : c->Args()) {
+            out << separator;
+            EmitValue(out, arg);
+            separator = ", ";
+        }
+        out << ")";
+    }
+
void EmitTernary(StringStream& out, const hlsl::ir::Ternary* t) {
out << "((";
EmitValue(out, t->Cmp());
diff --git a/src/tint/lang/hlsl/writer/raise/BUILD.bazel b/src/tint/lang/hlsl/writer/raise/BUILD.bazel
index 3f7094e..428e2fc 100644
--- a/src/tint/lang/hlsl/writer/raise/BUILD.bazel
+++ b/src/tint/lang/hlsl/writer/raise/BUILD.bazel
@@ -66,6 +66,7 @@
"//src/tint/lang/hlsl",
"//src/tint/lang/hlsl/intrinsic",
"//src/tint/lang/hlsl/ir",
+ "//src/tint/lang/hlsl/type",
"//src/tint/lang/hlsl/writer/common",
"//src/tint/utils/containers",
"//src/tint/utils/diagnostic",
diff --git a/src/tint/lang/hlsl/writer/raise/BUILD.cmake b/src/tint/lang/hlsl/writer/raise/BUILD.cmake
index 5d38ac3..f211265 100644
--- a/src/tint/lang/hlsl/writer/raise/BUILD.cmake
+++ b/src/tint/lang/hlsl/writer/raise/BUILD.cmake
@@ -65,6 +65,7 @@
tint_lang_hlsl
tint_lang_hlsl_intrinsic
tint_lang_hlsl_ir
+ tint_lang_hlsl_type
tint_lang_hlsl_writer_common
tint_utils_containers
tint_utils_diagnostic
diff --git a/src/tint/lang/hlsl/writer/raise/BUILD.gn b/src/tint/lang/hlsl/writer/raise/BUILD.gn
index 4a4d467..9a106f2 100644
--- a/src/tint/lang/hlsl/writer/raise/BUILD.gn
+++ b/src/tint/lang/hlsl/writer/raise/BUILD.gn
@@ -69,6 +69,7 @@
"${tint_src_dir}/lang/hlsl",
"${tint_src_dir}/lang/hlsl/intrinsic",
"${tint_src_dir}/lang/hlsl/ir",
+ "${tint_src_dir}/lang/hlsl/type",
"${tint_src_dir}/lang/hlsl/writer/common",
"${tint_src_dir}/utils/containers",
"${tint_src_dir}/utils/diagnostic",
diff --git a/src/tint/lang/hlsl/writer/raise/decompose_memory_access.cc b/src/tint/lang/hlsl/writer/raise/decompose_memory_access.cc
index 883c4ed..0a2609b 100644
--- a/src/tint/lang/hlsl/writer/raise/decompose_memory_access.cc
+++ b/src/tint/lang/hlsl/writer/raise/decompose_memory_access.cc
@@ -27,15 +27,18 @@
#include "src/tint/lang/hlsl/writer/raise/decompose_memory_access.h"
-#include "src/tint/lang/core/ir/block.h"
#include "src/tint/lang/core/ir/builder.h"
#include "src/tint/lang/core/ir/validator.h"
+#include "src/tint/lang/hlsl/builtin_fn.h"
+#include "src/tint/lang/hlsl/ir/member_builtin_call.h"
+#include "src/tint/lang/hlsl/type/byte_address_buffer.h"
#include "src/tint/utils/result/result.h"
namespace tint::hlsl::writer::raise {
namespace {
-using namespace tint::core::fluent_types; // NOLINT
+using namespace tint::core::fluent_types; // NOLINT
+using namespace tint::core::number_suffixes; // NOLINT
/// PIMPL state for the transform.
struct State {
@@ -49,7 +52,135 @@
     core::type::Manager& ty{ir.Types()};
     /// Process the module.
-    void Process() {}
+    void Process() {
+        // Gather the storage buffer `var`s before modifying any of them: processing a var
+        // rewrites its result type.
+        Vector<core::ir::Var*, 4> var_worklist;
+        for (auto* inst : *ir.root_block) {
+            // Allow this to run before or after PromoteInitializers by handling non-var
+            // root_block entries
+            auto* var = inst->As<core::ir::Var>();
+            if (!var) {
+                continue;
+            }
+
+            // Var must be a pointer
+            auto* var_ty = var->Result(0)->Type()->As<core::type::Pointer>();
+            TINT_ASSERT(var_ty);
+
+            // Only care about storage address space variables.
+            if (var_ty->AddressSpace() != core::AddressSpace::kStorage) {
+                continue;
+            }
+
+            var_worklist.Push(var);
+        }
+
+        for (auto* var : var_worklist) {
+            ProcessStorageVar(var);
+        }
+    }
+
+    /// Retypes the storage buffer @p var as an HLSL ByteAddressBuffer and lowers its
+    /// `load_vector_element` / `store_vector_element` usages to `Load` / `Store` member calls.
+    /// @param var the `var` in the storage address space
+    void ProcessStorageVar(core::ir::Var* var) {
+        // Capture the pointer type before the result is retyped below.
+        auto* var_ty = var->Result(0)->Type()->As<core::type::Pointer>();
+
+        // Swap the result type of the `var` to the new HLSL result type
+        auto* result = var->Result(0);
+        result->SetType(
+            ty.Get<hlsl::type::ByteAddressBuffer>(var_ty->StoreType(), var_ty->Access()));
+
+        // Snapshot the users first, as rewriting an instruction edits the usage list.
+        Vector<core::ir::Instruction*, 4> usage_worklist;
+        for (auto& usage : result->Usages()) {
+            usage_worklist.Push(usage->instruction);
+        }
+
+        // TODO(crbug.com/349867642): Lower whole-value `load` / `store` and `access` usages.
+        // Until then they are left untouched, still referencing the now non-pointer var.
+        for (auto* inst : usage_worklist) {
+            Switch(
+                inst,  //
+                [&](core::ir::LoadVectorElement* lve) {
+                    LowerLoadVectorElement(var, var_ty, lve);
+                },
+                [&](core::ir::StoreVectorElement* sve) {
+                    LowerStoreVectorElement(var, var_ty, sve);
+                });
+        }
+    }
+
+    /// Computes the byte offset of a vector element within a ByteAddressBuffer.
+    /// A constant @p idx is folded into a `u32` constant; any other index is converted to `u32`
+    /// if needed and multiplied by the element size, emitting instructions at the builder's
+    /// current insertion point.
+    /// @param idx the vector element index
+    /// @param vec_ty the vector type held by the buffer
+    /// @returns the `u32` byte offset
+    core::ir::Value* VectorElementOffset(core::ir::Value* idx, const core::type::Type* vec_ty) {
+        const uint32_t elem_size = vec_ty->DeepestElement()->Size();
+        // `Load` / `Store` on a ByteAddressBuffer transfer one 32-bit word, so the bitcast
+        // lowering used here is only valid for 4-byte elements (f32, i32, u32).
+        // TODO(crbug.com/349867642): Use templated `Load<T>` / `Store<T>` for f16 elements.
+        TINT_ASSERT(elem_size == 4u);
+
+        if (auto* cnst = idx->As<core::ir::Constant>()) {
+            return b.Constant(u32(cnst->Value()->ValueAs<uint32_t>() * elem_size));
+        }
+
+        // Dynamic index: offset = u32(idx) * elem_size
+        core::ir::Value* u32_idx = idx;
+        if (!idx->Type()->Is<core::type::U32>()) {
+            u32_idx = b.Convert(ty.u32(), idx)->Result(0);
+        }
+        return b.Multiply(ty.u32(), u32_idx, u32(elem_size))->Result(0);
+    }
+
+    /// Replaces @p lve with a `Load` of the 32-bit word holding the element, bitcast to the
+    /// element type:
+    ///
+    ///   %1:u32 = %v.Load <byte offset>
+    ///   %2:<elem> = bitcast %1
+    ///
+    /// @param var the storage variable, already retyped to a ByteAddressBuffer
+    /// @param var_ty the original pointer type of @p var
+    /// @param lve the instruction to replace
+    void LowerLoadVectorElement(core::ir::Var* var,
+                                const core::type::Pointer* var_ty,
+                                core::ir::LoadVectorElement* lve) {
+        b.InsertBefore(lve, [&] {
+            auto* offset = VectorElementOffset(lve->Index(), var_ty->StoreType());
+            auto* load = b.MemberCall<hlsl::ir::MemberBuiltinCall>(ty.u32(), BuiltinFn::kLoad,
+                                                                    var, offset);
+            auto* cast = b.Bitcast(lve->Result(0)->Type(), load->Result(0));
+            lve->Result(0)->ReplaceAllUsesWith(cast->Result(0));
+        });
+        lve->Destroy();
+    }
+
+    /// Replaces @p sve with a bitcast of the stored value to `u32`, followed by a `Store` of
+    /// that word:
+    ///
+    ///   %1:u32 = bitcast <value>
+    ///   %2:void = %v.Store <byte offset>, %1
+    ///
+    /// @param var the storage variable, already retyped to a ByteAddressBuffer
+    /// @param var_ty the original pointer type of @p var
+    /// @param sve the instruction to replace
+    void LowerStoreVectorElement(core::ir::Var* var,
+                                 const core::type::Pointer* var_ty,
+                                 core::ir::StoreVectorElement* sve) {
+        b.InsertBefore(sve, [&] {
+            auto* offset = VectorElementOffset(sve->Index(), var_ty->StoreType());
+            auto* cast = b.Bitcast(ty.u32(), sve->Value());
+            b.MemberCall<hlsl::ir::MemberBuiltinCall>(ty.void_(), BuiltinFn::kStore, var,
+                                                      offset, cast);
+        });
+        sve->Destroy();
+    }
};
} // namespace
diff --git a/src/tint/lang/hlsl/writer/raise/decompose_memory_access_test.cc b/src/tint/lang/hlsl/writer/raise/decompose_memory_access_test.cc
index 08df8a2..940976d 100644
--- a/src/tint/lang/hlsl/writer/raise/decompose_memory_access_test.cc
+++ b/src/tint/lang/hlsl/writer/raise/decompose_memory_access_test.cc
@@ -61,5 +61,123 @@
     EXPECT_EQ(expect, str());
 }
+TEST_F(HlslWriterDecomposeMemoryAccessTest, VectorLoad) {
+    auto* var = b.Var<storage, vec4<f32>, core::Access::kRead>("v");
+
+    b.ir.root_block->Append(var);
+    auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+    b.Append(func->Block(), [&] {
+        // Load elements 0..3 of `v` into lets named `a`..`d`.
+        const char* names[] = {"a", "b", "c", "d"};
+        for (uint32_t i = 0; i < 4; ++i) {
+            b.Let(names[i], b.LoadVectorElement(var, u32(i)));
+        }
+        b.Return(func);
+    });
+
+    auto* src = R"(
+$B1: { # root
+ %v:ptr<storage, vec4<f32>, read> = var
+}
+
+%foo = @fragment func():void {
+ $B2: {
+ %3:f32 = load_vector_element %v, 0u
+ %a:f32 = let %3
+ %5:f32 = load_vector_element %v, 1u
+ %b:f32 = let %5
+ %7:f32 = load_vector_element %v, 2u
+ %c:f32 = let %7
+ %9:f32 = load_vector_element %v, 3u
+ %d:f32 = let %9
+ ret
+ }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+$B1: { # root
+ %v:hlsl.byte_address_buffer<vec4<f32>, read> = var
+}
+
+%foo = @fragment func():void {
+ $B2: {
+ %3:u32 = %v.Load 0u
+ %4:f32 = bitcast %3
+ %a:f32 = let %4
+ %6:u32 = %v.Load 4u
+ %7:f32 = bitcast %6
+ %b:f32 = let %7
+ %9:u32 = %v.Load 8u
+ %10:f32 = bitcast %9
+ %c:f32 = let %10
+ %12:u32 = %v.Load 12u
+ %13:f32 = bitcast %12
+ %d:f32 = let %13
+ ret
+ }
+}
+)";
+
+    Run(DecomposeMemoryAccess);
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(HlslWriterDecomposeMemoryAccessTest, VectorStore) {
+    auto* var = b.Var<storage, vec4<f32>, core::Access::kReadWrite>("v");
+
+    b.ir.root_block->Append(var);
+    auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+    b.Append(func->Block(), [&] {
+        // Store 2, 4, 8 and 16 into elements 0..3 of `v`.
+        const float values[] = {2.0f, 4.0f, 8.0f, 16.0f};
+        for (uint32_t i = 0; i < 4; ++i) {
+            b.StoreVectorElement(var, u32(i), f32(values[i]));
+        }
+        b.Return(func);
+    });
+
+    auto* src = R"(
+$B1: { # root
+ %v:ptr<storage, vec4<f32>, read_write> = var
+}
+
+%foo = @fragment func():void {
+ $B2: {
+ store_vector_element %v, 0u, 2.0f
+ store_vector_element %v, 1u, 4.0f
+ store_vector_element %v, 2u, 8.0f
+ store_vector_element %v, 3u, 16.0f
+ ret
+ }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+$B1: { # root
+ %v:hlsl.byte_address_buffer<vec4<f32>, read_write> = var
+}
+
+%foo = @fragment func():void {
+ $B2: {
+ %3:u32 = bitcast 2.0f
+ %4:void = %v.Store 0u, %3
+ %5:u32 = bitcast 4.0f
+ %6:void = %v.Store 4u, %5
+ %7:u32 = bitcast 8.0f
+ %8:void = %v.Store 8u, %7
+ %9:u32 = bitcast 16.0f
+ %10:void = %v.Store 12u, %9
+ ret
+ }
+}
+)";
+
+    Run(DecomposeMemoryAccess);
+    EXPECT_EQ(expect, str());
+}
+
} // namespace
} // namespace tint::hlsl::writer::raise