Import Tint changes from Dawn
Changes:
- d7b8efa431ad32b22fa48177ba85def3435b11ad [ir][spirv-writer] VarForDynamicIndex transform by James Price <jrprice@google.com>
- 7967019cfd7f8b5095e5466798cc90bcd08d69c5 [tint][utils] Add Offset and Truncate to Slice by James Price <jrprice@google.com>
- 117d4f387ec570b52887e0a447e7d8d404b02a7a [ir][spirv-writer] BlockDecoratedStructs transform by James Price <jrprice@google.com>
- 2f7c540ec52185a55e4e4b39b69f9e24373824e9 [tint][utils] Allow assignment of Slice to Vector by James Price <jrprice@google.com>
- b984643f1c0ae6e3758a10d7aa543616fbfa1281 [ir] Emit disassembly from the validator by dan sinclair <dsinclair@chromium.org>
- 7b6fd80c3576b9f040fd4b9f06d19f8cfe9e486f [tint][ir] Add ControlInstruction by Ben Clayton <bclayton@google.com>
- aa78dcd4c94ec2aa474ff1a2d26416ad313d3716 [ir] Emit struct declarations in the disassembly by James Price <jrprice@google.com>
- 879e948d69a7fa598cff4c31737a69f3ae617074 [tint][ir] Add Initializer block to ir::Loop by Ben Clayton <bclayton@google.com>
- 79899ef43883484062ceeef0209f50def7b3386c [tint][ir] Add Front() and Back() accessors to Block by Ben Clayton <bclayton@google.com>
- b165466d8e957ea33bb57127f2ae9237e7082b0e [tint][utils] Change return type of 'Vector::Slice() const' by Ben Clayton <bclayton@google.com>
- 7207e7108305b258e40dbdf6385b07eeef067fa6 [ir] Add SetOperand() to ir::Instruction by James Price <jrprice@google.com>
- e040ca685af2925b7bb0673b16c88ff5ed4071bd [ir] Add Value::RemoveUsage() to remove a usage by James Price <jrprice@google.com>
- 8e2ca44835f5d8e7f1e223033094f16d816e5a81 [tint][utils] Add Hasher specialization for VectorRef by Ben Clayton <bclayton@google.com>
- 2114131d0f5ac6bb64bf2ef241851982937f68f4 [ir] Track the operand index for value usages by James Price <jrprice@google.com>
- ae8a9f9f55c471f5136d11fb96277ff854e7a42f [ir] Add OperandInstruction class by James Price <jrprice@google.com>
- 03b708ad0ae42f365fd3d03f66c06e38b8623747 [ir][to_program]: Implement unary & binary ops by Ben Clayton <bclayton@google.com>
- f2c55d792dfc635a4570d370a71278dbf31da1d2 [ir][to_program]: Implement fn parameters by Ben Clayton <bclayton@google.com>
- efdcbdfc899944adddc3a72ea3e8ae37728e1774 [ir][spirv-writer] Emit access instructions by James Price <jrprice@google.com>
- 0a99719ad014347d96d3e76b5571370ef7dc1af8 [tint][utils] Add casting constructor to Slice by Ben Clayton <bclayton@google.com>
- cff18f3ec89141714c1e93ab325a290f21b7bc3e [ir] Add validator to spirv tests. by dan sinclair <dsinclair@chromium.org>
- fb0466e046fea4f27dcc233b40cf07c3eeb0b39b [ir] Stub out validator checks. by dan sinclair <dsinclair@chromium.org>
- 5baa4b41c352cdff713a68a922ce08a5bf9640f8 [ir][to_program]: Simplify error handling by Ben Clayton <bclayton@google.com>
- eefa7fe22c1e63d6a04f79074b0307e432f03e20 [ir] Add guardrails to IR classes. by dan sinclair <dsinclair@chromium.org>
GitOrigin-RevId: d7b8efa431ad32b22fa48177ba85def3435b11ad
Change-Id: I4afa1d1c244e73e860162cc133e5e17189ae8193
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/136120
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@chromium.org>
Reviewed-by: Ben Clayton <bclayton@chromium.org>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index afd91d6..bd06461 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -521,6 +521,10 @@
sources = [
"ir/transform/add_empty_entry_point.cc",
"ir/transform/add_empty_entry_point.h",
+ "ir/transform/block_decorated_structs.cc",
+ "ir/transform/block_decorated_structs.h",
+ "ir/transform/var_for_dynamic_index.cc",
+ "ir/transform/var_for_dynamic_index.h",
]
deps = [
":libtint_builtins_src",
@@ -1236,6 +1240,8 @@
"ir/construct.h",
"ir/continue.cc",
"ir/continue.h",
+ "ir/control_instruction.cc",
+ "ir/control_instruction.h",
"ir/convert.cc",
"ir/convert.h",
"ir/disassembler.cc",
@@ -1265,6 +1271,8 @@
"ir/module.h",
"ir/next_iteration.cc",
"ir/next_iteration.h",
+ "ir/operand_instruction.cc",
+ "ir/operand_instruction.h",
"ir/return.cc",
"ir/return.h",
"ir/store.cc",
@@ -1810,7 +1818,9 @@
tint_unittests_source_set("tint_unittests_ir_transform_src") {
sources = [
"ir/transform/add_empty_entry_point_test.cc",
+ "ir/transform/block_decorated_structs_test.cc",
"ir/transform/test_helper.h",
+ "ir/transform/var_for_dynamic_index_test.cc",
]
deps = [
@@ -1978,6 +1988,7 @@
if (tint_build_ir) {
sources += [
+ "writer/spirv/ir/generator_impl_ir_access_test.cc",
"writer/spirv/ir/generator_impl_ir_binary_test.cc",
"writer/spirv/ir/generator_impl_ir_builtin_test.cc",
"writer/spirv/ir/generator_impl_ir_constant_test.cc",
@@ -2285,11 +2296,21 @@
if (tint_build_ir) {
tint_unittests_source_set("tint_unittests_ir_src") {
sources = [
+ "ir/access_test.cc",
"ir/binary_test.cc",
"ir/bitcast_test.cc",
+ "ir/block_param_test.cc",
"ir/block_test.cc",
+ "ir/break_if_test.cc",
+ "ir/builtin_test.cc",
"ir/constant_test.cc",
+ "ir/construct_test.cc",
+ "ir/continue_test.cc",
+ "ir/convert_test.cc",
"ir/discard_test.cc",
+ "ir/exit_if_test.cc",
+ "ir/exit_loop_test.cc",
+ "ir/exit_switch_test.cc",
"ir/from_program_accessor_test.cc",
"ir/from_program_binary_test.cc",
"ir/from_program_builtin_test.cc",
@@ -2301,16 +2322,25 @@
"ir/from_program_test.cc",
"ir/from_program_unary_test.cc",
"ir/from_program_var_test.cc",
+ "ir/function_param_test.cc",
+ "ir/function_test.cc",
+ "ir/if_test.cc",
"ir/instruction_test.cc",
"ir/ir_test_helper.h",
"ir/load_test.cc",
+ "ir/loop_test.cc",
"ir/module_test.cc",
+ "ir/next_iteration_test.cc",
"ir/program_test_helper.h",
+ "ir/return_test.cc",
"ir/store_test.cc",
+ "ir/switch_test.cc",
+ "ir/swizzle_test.cc",
"ir/to_program_roundtrip_test.cc",
- "ir/transform/add_empty_entry_point_test.cc",
"ir/unary_test.cc",
+ "ir/user_call_test.cc",
"ir/validate_test.cc",
+ "ir/var_test.cc",
]
deps = [
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index cdcd22a..764be39 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -743,6 +743,8 @@
ir/construct.h
ir/continue.cc
ir/continue.h
+ ir/control_instruction.cc
+ ir/control_instruction.h
ir/convert.cc
ir/convert.h
ir/disassembler.cc
@@ -774,6 +776,8 @@
ir/module.h
ir/next_iteration.cc
ir/next_iteration.h
+ ir/operand_instruction.cc
+ ir/operand_instruction.h
ir/return.cc
ir/return.h
ir/store.cc
@@ -796,8 +800,12 @@
ir/var.h
ir/transform/add_empty_entry_point.cc
ir/transform/add_empty_entry_point.h
+ ir/transform/block_decorated_structs.cc
+ ir/transform/block_decorated_structs.h
ir/transform/transform.cc
ir/transform/transform.h
+ ir/transform/var_for_dynamic_index.cc
+ ir/transform/var_for_dynamic_index.h
)
endif()
@@ -1267,6 +1275,7 @@
if(${TINT_BUILD_IR})
list(APPEND TINT_TEST_SRCS
+ writer/spirv/ir/generator_impl_ir_access_test.cc
writer/spirv/ir/generator_impl_ir_binary_test.cc
writer/spirv/ir/generator_impl_ir_builtin_test.cc
writer/spirv/ir/generator_impl_ir_constant_test.cc
@@ -1488,11 +1497,21 @@
if (${TINT_BUILD_IR})
list(APPEND TINT_TEST_SRCS
+ ir/access_test.cc
ir/binary_test.cc
ir/bitcast_test.cc
+ ir/block_param_test.cc
ir/block_test.cc
+ ir/break_if_test.cc
+ ir/builtin_test.cc
ir/constant_test.cc
+ ir/construct_test.cc
+ ir/continue_test.cc
+ ir/convert_test.cc
ir/discard_test.cc
+ ir/exit_if_test.cc
+ ir/exit_loop_test.cc
+ ir/exit_switch_test.cc
ir/from_program_accessor_test.cc
ir/from_program_binary_test.cc
ir/from_program_builtin_test.cc
@@ -1504,15 +1523,27 @@
ir/from_program_test.cc
ir/from_program_unary_test.cc
ir/from_program_var_test.cc
+ ir/function_param_test.cc
+ ir/function_test.cc
+ ir/if_test.cc
ir/instruction_test.cc
ir/ir_test_helper.h
ir/load_test.cc
+ ir/loop_test.cc
ir/module_test.cc
+ ir/next_iteration_test.cc
ir/program_test_helper.h
+ ir/return_test.cc
ir/store_test.cc
+ ir/switch_test.cc
+ ir/swizzle_test.cc
ir/transform/add_empty_entry_point_test.cc
+ ir/transform/block_decorated_structs_test.cc
+ ir/transform/var_for_dynamic_index_test.cc
ir/unary_test.cc
+ ir/user_call_test.cc
ir/validate_test.cc
+ ir/var_test.cc
)
endif()
diff --git a/src/tint/ir/access.cc b/src/tint/ir/access.cc
index d88ded8..8f2a793 100644
--- a/src/tint/ir/access.cc
+++ b/src/tint/ir/access.cc
@@ -24,8 +24,13 @@
//! @cond Doxygen_Suppress
Access::Access(const type::Type* ty, Value* object, utils::VectorRef<Value*> indices)
- : result_type_(ty), object_(object), indices_(std::move(indices)) {
- object_->AddUsage(this);
+ : result_type_(ty) {
+ TINT_ASSERT(IR, object);
+ TINT_ASSERT(IR, result_type_);
+ TINT_ASSERT(IR, !indices.IsEmpty());
+
+ AddOperand(object);
+ AddOperands(std::move(indices));
}
Access::~Access() = default;
diff --git a/src/tint/ir/access.h b/src/tint/ir/access.h
index ea03d65..55f7306 100644
--- a/src/tint/ir/access.h
+++ b/src/tint/ir/access.h
@@ -15,13 +15,13 @@
#ifndef SRC_TINT_IR_ACCESS_H_
#define SRC_TINT_IR_ACCESS_H_
-#include "src/tint/ir/instruction.h"
+#include "src/tint/ir/operand_instruction.h"
#include "src/tint/utils/castable.h"
namespace tint::ir {
/// An access instruction in the IR.
-class Access : public utils::Castable<Access, Instruction> {
+class Access : public utils::Castable<Access, OperandInstruction<3>> {
public:
/// Constructor
/// @param result_type the result type
@@ -34,15 +34,18 @@
const type::Type* Type() const override { return result_type_; }
/// @returns the object used for the access
- Value* Object() const { return object_; }
+ Value* Object() const { return operands_[0]; }
/// @returns the accessor indices
- utils::VectorRef<Value*> Indices() const { return indices_; }
+ utils::Slice<Value const* const> Indices() const {
+ return operands_.Slice().Offset(1).Reinterpret<Value const* const>();
+ }
+
+ /// @returns the accessor indices
+ utils::Slice<Value*> Indices() { return operands_.Slice().Offset(1); }
private:
const type::Type* result_type_ = nullptr;
- Value* object_ = nullptr;
- utils::Vector<Value*, 1> indices_;
};
} // namespace tint::ir
diff --git a/src/tint/ir/access_test.cc b/src/tint/ir/access_test.cc
new file mode 100644
index 0000000..9e09930
--- /dev/null
+++ b/src/tint/ir/access_test.cc
@@ -0,0 +1,87 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/access.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest-spi.h"
+#include "src/tint/ir/ir_test_helper.h"
+
+namespace tint::ir {
+namespace {
+
+using IR_AccessTest = IRTestHelper;
+
+TEST_F(IR_AccessTest, SetsUsage) {
+ auto* ty = mod.Types().pointer(mod.Types().i32(), builtin::AddressSpace::kFunction,
+ builtin::Access::kReadWrite);
+ auto* var = b.Declare(ty);
+ auto* idx = b.Constant(u32(1));
+ auto* a = b.Access(mod.Types().i32(), var, utils::Vector{idx});
+
+ EXPECT_THAT(var->Usages(), testing::UnorderedElementsAre(Usage{a, 0u}));
+ EXPECT_THAT(idx->Usages(), testing::UnorderedElementsAre(Usage{a, 1u}));
+}
+
+TEST_F(IR_AccessTest, Fail_NullType) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ auto* ty = mod.Types().pointer(mod.Types().i32(), builtin::AddressSpace::kFunction,
+ builtin::Access::kReadWrite);
+ auto* var = b.Declare(ty);
+ b.Access(nullptr, var, utils::Vector{b.Constant(u32(1))});
+ },
+ "");
+}
+
+TEST_F(IR_AccessTest, Fail_NullObject) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.Access(mod.Types().i32(), nullptr, utils::Vector{b.Constant(u32(1))});
+ },
+ "");
+}
+
+TEST_F(IR_AccessTest, Fail_EmptyIndices) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ auto* ty = mod.Types().pointer(mod.Types().i32(), builtin::AddressSpace::kFunction,
+ builtin::Access::kReadWrite);
+ auto* var = b.Declare(ty);
+ b.Access(mod.Types().i32(), var, utils::Empty);
+ },
+ "");
+}
+
+TEST_F(IR_AccessTest, Fail_NullIndex) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ auto* ty = mod.Types().pointer(mod.Types().i32(), builtin::AddressSpace::kFunction,
+ builtin::Access::kReadWrite);
+ auto* var = b.Declare(ty);
+ b.Access(mod.Types().i32(), var, utils::Vector<Value*, 1>{nullptr});
+ },
+ "");
+}
+
+} // namespace
+} // namespace tint::ir
diff --git a/src/tint/ir/binary.cc b/src/tint/ir/binary.cc
index 2b179ac..6287715 100644
--- a/src/tint/ir/binary.cc
+++ b/src/tint/ir/binary.cc
@@ -20,11 +20,13 @@
namespace tint::ir {
Binary::Binary(enum Kind kind, const type::Type* res_ty, Value* lhs, Value* rhs)
- : kind_(kind), result_type_(res_ty), lhs_(lhs), rhs_(rhs) {
+ : kind_(kind), result_type_(res_ty) {
+ TINT_ASSERT(IR, result_type_);
TINT_ASSERT(IR, lhs);
TINT_ASSERT(IR, rhs);
- lhs_->AddUsage(this);
- rhs_->AddUsage(this);
+
+ AddOperand(lhs);
+ AddOperand(rhs);
}
Binary::~Binary() = default;
diff --git a/src/tint/ir/binary.h b/src/tint/ir/binary.h
index 3de078a..0998c85 100644
--- a/src/tint/ir/binary.h
+++ b/src/tint/ir/binary.h
@@ -15,13 +15,13 @@
#ifndef SRC_TINT_IR_BINARY_H_
#define SRC_TINT_IR_BINARY_H_
-#include "src/tint/ir/instruction.h"
+#include "src/tint/ir/operand_instruction.h"
#include "src/tint/utils/castable.h"
namespace tint::ir {
-/// An instruction in the IR.
-class Binary : public utils::Castable<Binary, Instruction> {
+/// A binary instruction in the IR.
+class Binary : public utils::Castable<Binary, OperandInstruction<2>> {
public:
/// The kind of instruction.
enum class Kind {
@@ -61,16 +61,14 @@
const type::Type* Type() const override { return result_type_; }
/// @returns the left-hand-side value for the instruction
- const Value* LHS() const { return lhs_; }
+ const Value* LHS() const { return operands_[0]; }
/// @returns the right-hand-side value for the instruction
- const Value* RHS() const { return rhs_; }
+ const Value* RHS() const { return operands_[1]; }
private:
enum Kind kind_;
- const type::Type* result_type_;
- Value* lhs_;
- Value* rhs_;
+ const type::Type* result_type_ = nullptr;
};
} // namespace tint::ir
diff --git a/src/tint/ir/binary_test.cc b/src/tint/ir/binary_test.cc
index be78a9d..9bdaba0 100644
--- a/src/tint/ir/binary_test.cc
+++ b/src/tint/ir/binary_test.cc
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include "gmock/gmock.h"
+#include "gtest/gtest-spi.h"
#include "src/tint/ir/builder.h"
#include "src/tint/ir/instruction.h"
#include "src/tint/ir/ir_test_helper.h"
@@ -23,6 +25,36 @@
using IR_BinaryTest = IRTestHelper;
+TEST_F(IR_BinaryTest, Fail_NullType) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.Add(nullptr, b.Constant(u32(1)), b.Constant(u32(2)));
+ },
+ "");
+}
+
+TEST_F(IR_BinaryTest, Fail_NullLHS) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.Add(mod.Types().u32(), nullptr, b.Constant(u32(2)));
+ },
+ "");
+}
+
+TEST_F(IR_BinaryTest, Fail_NullRHS) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.Add(mod.Types().u32(), b.Constant(u32(1)), nullptr);
+ },
+ "");
+}
+
TEST_F(IR_BinaryTest, CreateAnd) {
const auto* inst = b.And(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i));
@@ -314,29 +346,41 @@
}
TEST_F(IR_BinaryTest, Binary_Usage) {
- const auto* inst = b.And(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i));
+ auto* inst = b.And(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i));
EXPECT_EQ(inst->Kind(), Binary::Kind::kAnd);
ASSERT_NE(inst->LHS(), nullptr);
- ASSERT_EQ(inst->LHS()->Usage().Length(), 1u);
- EXPECT_EQ(inst->LHS()->Usage()[0], inst);
+ EXPECT_THAT(inst->LHS()->Usages(), testing::UnorderedElementsAre(Usage{inst, 0u}));
ASSERT_NE(inst->RHS(), nullptr);
- ASSERT_EQ(inst->RHS()->Usage().Length(), 1u);
- EXPECT_EQ(inst->RHS()->Usage()[0], inst);
+ EXPECT_THAT(inst->RHS()->Usages(), testing::UnorderedElementsAre(Usage{inst, 1u}));
}
TEST_F(IR_BinaryTest, Binary_Usage_DuplicateValue) {
auto val = b.Constant(4_i);
- const auto* inst = b.And(mod.Types().i32(), val, val);
+ auto* inst = b.And(mod.Types().i32(), val, val);
EXPECT_EQ(inst->Kind(), Binary::Kind::kAnd);
ASSERT_EQ(inst->LHS(), inst->RHS());
ASSERT_NE(inst->LHS(), nullptr);
- ASSERT_EQ(inst->LHS()->Usage().Length(), 1u);
- EXPECT_EQ(inst->LHS()->Usage()[0], inst);
+ EXPECT_THAT(inst->LHS()->Usages(),
+ testing::UnorderedElementsAre(Usage{inst, 0u}, Usage{inst, 1u}));
+}
+
+TEST_F(IR_BinaryTest, Binary_Usage_SetOperand) {
+ auto* rhs_a = b.Constant(2_i);
+ auto* rhs_b = b.Constant(3_i);
+ auto* inst = b.And(mod.Types().i32(), b.Constant(4_i), rhs_a);
+
+ EXPECT_EQ(inst->Kind(), Binary::Kind::kAnd);
+
+ EXPECT_THAT(rhs_a->Usages(), testing::UnorderedElementsAre(Usage{inst, 1u}));
+ EXPECT_THAT(rhs_b->Usages(), testing::UnorderedElementsAre());
+ inst->SetOperand(1, rhs_b);
+ EXPECT_THAT(rhs_a->Usages(), testing::UnorderedElementsAre());
+ EXPECT_THAT(rhs_b->Usages(), testing::UnorderedElementsAre(Usage{inst, 1u}));
}
} // namespace
diff --git a/src/tint/ir/bitcast.cc b/src/tint/ir/bitcast.cc
index 455bc69..a219749 100644
--- a/src/tint/ir/bitcast.cc
+++ b/src/tint/ir/bitcast.cc
@@ -19,7 +19,11 @@
namespace tint::ir {
-Bitcast::Bitcast(const type::Type* ty, Value* val) : Base(ty, utils::Vector{val}) {}
+Bitcast::Bitcast(const type::Type* ty, Value* val) : Base(ty) {
+ TINT_ASSERT(IR, val);
+
+ AddOperand(val);
+}
Bitcast::~Bitcast() = default;
diff --git a/src/tint/ir/bitcast_test.cc b/src/tint/ir/bitcast_test.cc
index dd9209f..e8ed473 100644
--- a/src/tint/ir/bitcast_test.cc
+++ b/src/tint/ir/bitcast_test.cc
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include "gmock/gmock.h"
+#include "gtest/gtest-spi.h"
#include "src/tint/ir/builder.h"
#include "src/tint/ir/constant.h"
#include "src/tint/ir/instruction.h"
@@ -39,13 +41,32 @@
}
TEST_F(IR_BitcastTest, Bitcast_Usage) {
- const auto* inst = b.Bitcast(mod.Types().i32(), b.Constant(4_i));
+ auto* inst = b.Bitcast(mod.Types().i32(), b.Constant(4_i));
const auto args = inst->Args();
ASSERT_EQ(args.Length(), 1u);
ASSERT_NE(args[0], nullptr);
- ASSERT_EQ(args[0]->Usage().Length(), 1u);
- EXPECT_EQ(args[0]->Usage()[0], inst);
+ EXPECT_THAT(args[0]->Usages(), testing::UnorderedElementsAre(Usage{inst, 0u}));
+}
+
+TEST_F(IR_BitcastTest, Fail_NullValue) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.Bitcast(mod.Types().i32(), nullptr);
+ },
+ "");
+}
+
+TEST_F(IR_BitcastTest, Fail_NullType) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.Bitcast(nullptr, b.Constant(u32(1)));
+ },
+ "");
}
} // namespace
diff --git a/src/tint/ir/block.cc b/src/tint/ir/block.cc
index 8942ea9..b545f34 100644
--- a/src/tint/ir/block.cc
+++ b/src/tint/ir/block.cc
@@ -22,6 +22,22 @@
Block::~Block() = default;
+void Block::SetParams(utils::VectorRef<const BlockParam*> params) {
+ params_ = std::move(params);
+
+ for (auto* param : params_) {
+ TINT_ASSERT(IR, param != nullptr);
+ }
+}
+
+void Block::AddInboundBranch(ir::Branch* node) {
+ TINT_ASSERT(IR, node != nullptr);
+
+ if (node) {
+ inbound_branches_.Push(node);
+ }
+}
+
void Block::Prepend(Instruction* inst) {
TINT_ASSERT_OR_RETURN(IR, inst);
TINT_ASSERT_OR_RETURN(IR, inst->Block() == nullptr);
diff --git a/src/tint/ir/block.h b/src/tint/ir/block.h
index 6b1697a..1da6fb9 100644
--- a/src/tint/ir/block.h
+++ b/src/tint/ir/block.h
@@ -22,6 +22,11 @@
#include "src/tint/ir/instruction.h"
#include "src/tint/utils/vector.h"
+// Forward declarations
+namespace tint::ir {
+class ControlInstruction;
+} // namespace tint::ir
+
namespace tint::ir {
/// A block of statements. The instructions in the block are a linear list of instructions to
@@ -92,6 +97,12 @@
/// @returns the ending iterator
Iterator end() const { return Iterator{nullptr}; }
+ /// @returns the first instruction in the instruction list
+ Instruction* Front() const { return instructions_.first; }
+
+ /// @returns the last instruction in the instruction list
+ Instruction* Back() const { return instructions_.last; }
+
/// Adds the instruction to the beginning of the block
/// @param inst the instruction to add
void Prepend(Instruction* inst);
@@ -122,7 +133,7 @@
/// Sets the params to the block
/// @param params the params for the block
- void SetParams(utils::VectorRef<const BlockParam*> params) { params_ = std::move(params); }
+ void SetParams(utils::VectorRef<const BlockParam*> params);
/// @return the parameters passed into the block
utils::VectorRef<const BlockParam*> Params() const { return params_; }
/// @returns the params to the block
@@ -133,7 +144,13 @@
/// Adds the given node to the inbound branches
/// @param node the node to add
- void AddInboundBranch(ir::Branch* node) { inbound_branches_.Push(node); }
+ void AddInboundBranch(ir::Branch* node);
+
+ /// @return the parent instruction that owns this block
+ ControlInstruction* Parent() const { return parent_; }
+
+ /// @param parent the parent instruction that owns this block
+ void SetParent(ControlInstruction* parent) { parent_ = parent; }
private:
struct {
@@ -150,6 +167,8 @@
/// - Node is a merge target outside control flow (e.g. an if that returns in both branches)
/// - Node is a continue target outside control flow (e.g. a loop that returns)
utils::Vector<ir::Branch*, 2> inbound_branches_;
+
+ ControlInstruction* parent_ = nullptr;
};
} // namespace tint::ir
diff --git a/src/tint/ir/block_param.cc b/src/tint/ir/block_param.cc
index a3a0be5..99aa827 100644
--- a/src/tint/ir/block_param.cc
+++ b/src/tint/ir/block_param.cc
@@ -18,7 +18,9 @@
namespace tint::ir {
-BlockParam::BlockParam(const type::Type* ty) : type_(ty) {}
+BlockParam::BlockParam(const type::Type* ty) : type_(ty) {
+ TINT_ASSERT(IR, type_ != nullptr);
+}
BlockParam::~BlockParam() = default;
diff --git a/src/tint/ir/block_param.h b/src/tint/ir/block_param.h
index 386ea4d..0a90ba8 100644
--- a/src/tint/ir/block_param.h
+++ b/src/tint/ir/block_param.h
@@ -33,7 +33,7 @@
private:
/// the result type of the instruction
- const type::Type* type_;
+ const type::Type* type_ = nullptr;
};
} // namespace tint::ir
diff --git a/src/tint/ir/block_param_test.cc b/src/tint/ir/block_param_test.cc
new file mode 100644
index 0000000..9f4141a
--- /dev/null
+++ b/src/tint/ir/block_param_test.cc
@@ -0,0 +1,36 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/block_param.h"
+#include "gtest/gtest-spi.h"
+#include "src/tint/ir/ir_test_helper.h"
+
+namespace tint::ir {
+namespace {
+
+using namespace tint::number_suffixes; // NOLINT
+using IR_BlockParamTest = IRTestHelper;
+
+TEST_F(IR_BlockParamTest, Fail_NullType) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.BlockParam(nullptr);
+ },
+ "");
+}
+
+} // namespace
+} // namespace tint::ir
diff --git a/src/tint/ir/block_test.cc b/src/tint/ir/block_test.cc
index ce48026..8a5308f 100644
--- a/src/tint/ir/block_test.cc
+++ b/src/tint/ir/block_test.cc
@@ -14,13 +14,94 @@
#include "src/tint/ir/block.h"
#include "gtest/gtest-spi.h"
+#include "src/tint/ir/block_param.h"
#include "src/tint/ir/ir_test_helper.h"
namespace tint::ir {
namespace {
+using namespace tint::number_suffixes; // NOLINT
using IR_BlockTest = IRTestHelper;
+TEST_F(IR_BlockTest, HasBranchTarget_Empty) {
+ auto* blk = b.CreateBlock();
+ EXPECT_FALSE(blk->HasBranchTarget());
+}
+
+TEST_F(IR_BlockTest, HasBranchTarget_NoBranch) {
+ auto* blk = b.CreateBlock();
+ blk->Append(b.Add(mod.Types().i32(), b.Constant(1_u), b.Constant(2_u)));
+ EXPECT_FALSE(blk->HasBranchTarget());
+}
+
+TEST_F(IR_BlockTest, HasBranchTarget_BreakIf) {
+ auto* blk = b.CreateBlock();
+ auto* loop = b.CreateLoop();
+ blk->Append(b.BreakIf(b.Constant(true), loop));
+ EXPECT_TRUE(blk->HasBranchTarget());
+}
+
+TEST_F(IR_BlockTest, HasBranchTarget_Continue) {
+ auto* blk = b.CreateBlock();
+ auto* loop = b.CreateLoop();
+ blk->Append(b.Continue(loop));
+ EXPECT_TRUE(blk->HasBranchTarget());
+}
+
+TEST_F(IR_BlockTest, HasBranchTarget_ExitIf) {
+ auto* blk = b.CreateBlock();
+ auto* if_ = b.CreateIf(b.Constant(true));
+ blk->Append(b.ExitIf(if_));
+ EXPECT_TRUE(blk->HasBranchTarget());
+}
+
+TEST_F(IR_BlockTest, HasBranchTarget_ExitLoop) {
+ auto* blk = b.CreateBlock();
+ auto* loop = b.CreateLoop();
+ blk->Append(b.ExitLoop(loop));
+ EXPECT_TRUE(blk->HasBranchTarget());
+}
+
+TEST_F(IR_BlockTest, HasBranchTarget_ExitSwitch) {
+ auto* blk = b.CreateBlock();
+ auto* s = b.CreateSwitch(b.Constant(1_u));
+ blk->Append(b.ExitSwitch(s));
+ EXPECT_TRUE(blk->HasBranchTarget());
+}
+
+TEST_F(IR_BlockTest, HasBranchTarget_If) {
+ auto* blk = b.CreateBlock();
+ blk->Append(b.CreateIf(b.Constant(true)));
+ EXPECT_TRUE(blk->HasBranchTarget());
+}
+
+TEST_F(IR_BlockTest, HasBranchTarget_Loop) {
+ auto* blk = b.CreateBlock();
+ blk->Append(b.CreateLoop());
+ EXPECT_TRUE(blk->HasBranchTarget());
+}
+
+TEST_F(IR_BlockTest, HasBranchTarget_NextIteration) {
+ auto* blk = b.CreateBlock();
+ auto* loop = b.CreateLoop();
+ blk->Append(b.NextIteration(loop));
+ EXPECT_TRUE(blk->HasBranchTarget());
+}
+
+TEST_F(IR_BlockTest, HasBranchTarget_Return) {
+ auto* f = b.CreateFunction("myFunc", mod.Types().void_());
+
+ auto* blk = b.CreateBlock();
+ blk->Append(b.Return(f));
+ EXPECT_TRUE(blk->HasBranchTarget());
+}
+
+TEST_F(IR_BlockTest, HasBranchTarget_Switch) {
+ auto* blk = b.CreateBlock();
+ blk->Append(b.CreateSwitch(b.Constant(true)));
+ EXPECT_TRUE(blk->HasBranchTarget());
+}
+
TEST_F(IR_BlockTest, SetInstructions) {
auto* inst1 = b.CreateLoop();
auto* inst2 = b.CreateLoop();
@@ -658,5 +739,29 @@
"internal compiler error");
}
+TEST_F(IR_BlockTest, Fail_NullBlockParam) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+
+ auto* blk = b.CreateBlock();
+ blk->SetParams(utils::Vector<const BlockParam*, 1>{nullptr});
+ },
+ "");
+}
+
+TEST_F(IR_BlockTest, Fail_NullInboundBranch) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+
+ auto* blk = b.CreateBlock();
+ blk->AddInboundBranch(nullptr);
+ },
+ "");
+}
+
} // namespace
} // namespace tint::ir
diff --git a/src/tint/ir/branch.cc b/src/tint/ir/branch.cc
index 2c59c2c..a7ea495 100644
--- a/src/tint/ir/branch.cc
+++ b/src/tint/ir/branch.cc
@@ -16,19 +16,10 @@
#include <utility>
-#include "src/tint/ir/block.h"
-
TINT_INSTANTIATE_TYPEINFO(tint::ir::Branch);
namespace tint::ir {
-Branch::Branch(utils::VectorRef<Value*> args) : args_(std::move(args)) {
- for (auto* arg : args) {
- TINT_ASSERT(IR, arg);
- arg->AddUsage(this);
- }
-}
-
Branch::~Branch() = default;
} // namespace tint::ir
diff --git a/src/tint/ir/branch.h b/src/tint/ir/branch.h
index 5c926e9..f7dfd12 100644
--- a/src/tint/ir/branch.h
+++ b/src/tint/ir/branch.h
@@ -15,7 +15,7 @@
#ifndef SRC_TINT_IR_BRANCH_H_
#define SRC_TINT_IR_BRANCH_H_
-#include "src/tint/ir/instruction.h"
+#include "src/tint/ir/operand_instruction.h"
#include "src/tint/ir/value.h"
#include "src/tint/utils/castable.h"
@@ -27,20 +27,12 @@
namespace tint::ir {
/// A branch instruction.
-class Branch : public utils::Castable<Branch, Instruction> {
+class Branch : public utils::Castable<Branch, OperandInstruction<1>> {
public:
~Branch() override;
/// @returns the branch arguments
- utils::VectorRef<Value*> Args() const { return args_; }
-
- protected:
- /// Constructor
- /// @param args the branch arguments
- explicit Branch(utils::VectorRef<Value*> args);
-
- private:
- utils::Vector<Value*, 2> args_;
+ virtual utils::Slice<Value const* const> Args() const { return operands_.Slice(); }
};
} // namespace tint::ir
diff --git a/src/tint/ir/break_if.cc b/src/tint/ir/break_if.cc
index c17ef54..644502b 100644
--- a/src/tint/ir/break_if.cc
+++ b/src/tint/ir/break_if.cc
@@ -16,6 +16,7 @@
#include <utility>
+#include "src/tint/ir/block.h"
#include "src/tint/ir/loop.h"
TINT_INSTANTIATE_TYPEINFO(tint::ir::BreakIf);
@@ -25,13 +26,16 @@
BreakIf::BreakIf(Value* condition,
ir::Loop* loop,
utils::VectorRef<Value*> args /* = utils::Empty */)
- : Base(std::move(args)), condition_(condition), loop_(loop) {
- TINT_ASSERT(IR, condition_);
+ : loop_(loop) {
+ TINT_ASSERT(IR, condition);
TINT_ASSERT(IR, loop_);
- condition_->AddUsage(this);
- loop_->AddUsage(this);
- loop_->Body()->AddInboundBranch(this);
- loop_->Merge()->AddInboundBranch(this);
+
+ AddOperand(condition);
+ if (loop_) {
+ loop_->Body()->AddInboundBranch(this);
+ loop_->Merge()->AddInboundBranch(this);
+ }
+ AddOperands(std::move(args));
}
BreakIf::~BreakIf() = default;
diff --git a/src/tint/ir/break_if.h b/src/tint/ir/break_if.h
index 9769378..b673cd8 100644
--- a/src/tint/ir/break_if.h
+++ b/src/tint/ir/break_if.h
@@ -36,14 +36,18 @@
BreakIf(Value* condition, ir::Loop* loop, utils::VectorRef<Value*> args = utils::Empty);
~BreakIf() override;
+ /// @returns the branch arguments
+ utils::Slice<Value const* const> Args() const override {
+ return operands_.Slice().Offset(1).Reinterpret<Value const* const>();
+ }
+
/// @returns the break condition
- const Value* Condition() const { return condition_; }
+ const Value* Condition() const { return operands_[0]; }
/// @returns the loop containing the break-if
const ir::Loop* Loop() const { return loop_; }
private:
- Value* condition_ = nullptr;
ir::Loop* loop_ = nullptr;
};
diff --git a/src/tint/ir/break_if_test.cc b/src/tint/ir/break_if_test.cc
new file mode 100644
index 0000000..96363f1
--- /dev/null
+++ b/src/tint/ir/break_if_test.cc
@@ -0,0 +1,71 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/break_if.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest-spi.h"
+#include "src/tint/ir/ir_test_helper.h"
+
+namespace tint::ir {
+namespace {
+
+using namespace tint::number_suffixes; // NOLINT
+using IR_BreakIfTest = IRTestHelper;
+
+TEST_F(IR_BreakIfTest, Usage) {
+ auto* loop = b.CreateLoop();
+ auto* cond = b.Constant(true);
+ auto* arg1 = b.Constant(1_u);
+ auto* arg2 = b.Constant(2_u);
+
+ auto* brk = b.BreakIf(cond, loop, utils::Vector{arg1, arg2});
+
+ EXPECT_THAT(cond->Usages(), testing::UnorderedElementsAre(Usage{brk, 0u}));
+ EXPECT_THAT(arg1->Usages(), testing::UnorderedElementsAre(Usage{brk, 1u}));
+ EXPECT_THAT(arg2->Usages(), testing::UnorderedElementsAre(Usage{brk, 2u}));
+}
+
+TEST_F(IR_BreakIfTest, Fail_NullCondition) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.BreakIf(nullptr, b.CreateLoop());
+ },
+ "");
+}
+
+TEST_F(IR_BreakIfTest, Fail_NullLoop) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.BreakIf(b.Constant(true), nullptr);
+ },
+ "");
+}
+
+TEST_F(IR_BreakIfTest, Fail_NullArg) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.BreakIf(b.Constant(true), b.CreateLoop(), utils::Vector<Value*, 1>{nullptr});
+ },
+ "");
+}
+
+} // namespace
+} // namespace tint::ir
diff --git a/src/tint/ir/builder.cc b/src/tint/ir/builder.cc
index b93cb94..fdf11b9 100644
--- a/src/tint/ir/builder.cc
+++ b/src/tint/ir/builder.cc
@@ -41,8 +41,6 @@
const type::Type* return_type,
Function::PipelineStage stage,
std::optional<std::array<uint32_t, 3>> wg_size) {
- TINT_ASSERT(IR, return_type);
-
auto* ir_func = ir.values.Create<Function>(return_type, stage, wg_size);
ir_func->SetStartTarget(CreateBlock());
ir.SetName(ir_func, name);
@@ -50,12 +48,11 @@
}
If* Builder::CreateIf(Value* condition) {
- TINT_ASSERT(IR, condition);
return ir.values.Create<If>(condition, CreateBlock(), CreateBlock(), CreateBlock());
}
-Loop* Builder::CreateLoop(utils::VectorRef<Value*> args /* = utils::Empty */) {
- return ir.values.Create<Loop>(CreateBlock(), CreateBlock(), CreateBlock(), std::move(args));
+Loop* Builder::CreateLoop() {
+ return ir.values.Create<Loop>(CreateBlock(), CreateBlock(), CreateBlock(), CreateBlock());
}
Switch* Builder::CreateSwitch(Value* condition) {
@@ -67,6 +64,7 @@
Block* b = s->Cases().Back().Start();
b->AddInboundBranch(s);
+ b->SetParent(s);
return b;
}
@@ -162,7 +160,7 @@
}
ir::Discard* Builder::Discard() {
- return ir.values.Create<ir::Discard>();
+ return ir.values.Create<ir::Discard>(ir.Types().void_());
}
ir::UserCall* Builder::UserCall(const type::Type* type,
@@ -188,8 +186,17 @@
}
ir::Load* Builder::Load(Value* from) {
+ TINT_ASSERT(IR, from != nullptr);
+ if (from == nullptr) {
+ return nullptr;
+ }
+
auto* ptr = from->Type()->As<type::Pointer>();
- TINT_ASSERT(IR, ptr);
+ TINT_ASSERT(IR, ptr != nullptr);
+ if (ptr == nullptr) {
+ return nullptr;
+ }
+
return ir.values.Create<ir::Load>(ptr->StoreType(), from);
}
diff --git a/src/tint/ir/builder.h b/src/tint/ir/builder.h
index 9a554cd..4849743 100644
--- a/src/tint/ir/builder.h
+++ b/src/tint/ir/builder.h
@@ -87,9 +87,8 @@
If* CreateIf(Value* condition);
/// Creates a loop flow node
- /// @param args the branch arguments
/// @returns the flow node
- Loop* CreateLoop(utils::VectorRef<Value*> args = utils::Empty);
+ Loop* CreateLoop();
/// Creates a switch flow node
/// @param condition the switch condition
@@ -294,7 +293,9 @@
/// @param func the function being called
/// @param args the call arguments
/// @returns the instruction
- ir::UserCall* UserCall(const type::Type* type, Function* func, utils::VectorRef<Value*> args);
+ ir::UserCall* UserCall(const type::Type* type,
+ Function* func,
+ utils::VectorRef<Value*> args = utils::Empty);
/// Creates a value conversion instruction
/// @param to the type converted to
@@ -309,7 +310,7 @@
/// @param to the type being converted
/// @param args the arguments to be converted
/// @returns the instruction
- ir::Construct* Construct(const type::Type* to, utils::VectorRef<Value*> args);
+ ir::Construct* Construct(const type::Type* to, utils::VectorRef<Value*> args = utils::Empty);
/// Creates a builtin call instruction
/// @param type the return type
@@ -318,7 +319,7 @@
/// @returns the instruction
ir::Builtin* Builtin(const type::Type* type,
builtin::Function func,
- utils::VectorRef<Value*> args);
+ utils::VectorRef<Value*> args = utils::Empty);
/// Creates a load instruction
/// @param from the expression being loaded from
diff --git a/src/tint/ir/builtin.cc b/src/tint/ir/builtin.cc
index fac0793..c571168 100644
--- a/src/tint/ir/builtin.cc
+++ b/src/tint/ir/builtin.cc
@@ -24,7 +24,11 @@
namespace tint::ir {
Builtin::Builtin(const type::Type* ty, builtin::Function func, utils::VectorRef<Value*> arguments)
- : Base(ty, std::move(arguments)), func_(func) {}
+ : Base(ty), func_(func) {
+ TINT_ASSERT(IR, func != builtin::Function::kNone);
+ TINT_ASSERT(IR, func != builtin::Function::kTintMaterialize);
+ AddOperands(std::move(arguments));
+}
Builtin::~Builtin() = default;
diff --git a/src/tint/ir/builtin.h b/src/tint/ir/builtin.h
index ee096bb..628a04f 100644
--- a/src/tint/ir/builtin.h
+++ b/src/tint/ir/builtin.h
@@ -21,14 +21,16 @@
namespace tint::ir {
-/// A value conversion instruction in the IR.
+/// A builtin call instruction in the IR.
class Builtin : public utils::Castable<Builtin, Call> {
public:
/// Constructor
/// @param res_type the result type
/// @param func the builtin function
/// @param args the conversion arguments
- Builtin(const type::Type* res_type, builtin::Function func, utils::VectorRef<Value*> args);
+ Builtin(const type::Type* res_type,
+ builtin::Function func,
+ utils::VectorRef<Value*> args = utils::Empty);
~Builtin() override;
/// @returns the builtin function
diff --git a/src/tint/ir/builtin_test.cc b/src/tint/ir/builtin_test.cc
new file mode 100644
index 0000000..d5cf04a
--- /dev/null
+++ b/src/tint/ir/builtin_test.cc
@@ -0,0 +1,78 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gmock/gmock.h"
+#include "gtest/gtest-spi.h"
+#include "src/tint/ir/block_param.h"
+#include "src/tint/ir/ir_test_helper.h"
+
+namespace tint::ir {
+namespace {
+
+using namespace tint::number_suffixes; // NOLINT
+using IR_BuiltinTest = IRTestHelper;
+
+TEST_F(IR_BuiltinTest, Usage) {
+ auto* arg1 = b.Constant(1_u);
+ auto* arg2 = b.Constant(2_u);
+ auto* builtin =
+ b.Builtin(mod.Types().f32(), builtin::Function::kAbs, utils::Vector{arg1, arg2});
+
+ EXPECT_THAT(arg1->Usages(), testing::UnorderedElementsAre(Usage{builtin, 0u}));
+ EXPECT_THAT(arg2->Usages(), testing::UnorderedElementsAre(Usage{builtin, 1u}));
+}
+
+TEST_F(IR_BuiltinTest, Fail_NullType) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.Builtin(nullptr, builtin::Function::kAbs);
+ },
+ "");
+}
+
+TEST_F(IR_BuiltinTest, Fail_NoneFunction) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.Builtin(mod.Types().f32(), builtin::Function::kNone);
+ },
+ "");
+}
+
+TEST_F(IR_BuiltinTest, Fail_TintMaterializeFunction) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.Builtin(mod.Types().f32(), builtin::Function::kTintMaterialize);
+ },
+ "");
+}
+
+TEST_F(IR_BuiltinTest, Fail_NullArg) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.Builtin(mod.Types().f32(), builtin::Function::kAbs,
+ utils::Vector<Value*, 1>{nullptr});
+ },
+ "");
+}
+
+} // namespace
+} // namespace tint::ir
diff --git a/src/tint/ir/call.cc b/src/tint/ir/call.cc
index 909079f..30174d4 100644
--- a/src/tint/ir/call.cc
+++ b/src/tint/ir/call.cc
@@ -20,11 +20,8 @@
namespace tint::ir {
-Call::Call(const type::Type* res_ty, utils::VectorRef<Value*> arguments)
- : result_type_(res_ty), args_(std::move(arguments)) {
- for (auto* arg : args_) {
- arg->AddUsage(this);
- }
+Call::Call(const type::Type* res_ty) : result_type_(res_ty) {
+ TINT_ASSERT(IR, result_type_);
}
Call::~Call() = default;
diff --git a/src/tint/ir/call.h b/src/tint/ir/call.h
index c933331..669aaf5 100644
--- a/src/tint/ir/call.h
+++ b/src/tint/ir/call.h
@@ -15,13 +15,13 @@
#ifndef SRC_TINT_IR_CALL_H_
#define SRC_TINT_IR_CALL_H_
-#include "src/tint/ir/instruction.h"
+#include "src/tint/ir/operand_instruction.h"
#include "src/tint/utils/castable.h"
namespace tint::ir {
/// A Call instruction in the IR.
-class Call : public utils::Castable<Call, Instruction> {
+class Call : public utils::Castable<Call, OperandInstruction<4>> {
public:
~Call() override;
@@ -29,19 +29,17 @@
const type::Type* Type() const override { return result_type_; }
/// @returns the call arguments
- utils::VectorRef<Value*> Args() const { return args_; }
+ virtual utils::Slice<Value const* const> Args() const { return operands_.Slice(); }
protected:
/// Constructor
Call() = delete;
/// Constructor
/// @param result_type the result type
- /// @param args the constructor arguments
- Call(const type::Type* result_type, utils::VectorRef<Value*> args);
+ explicit Call(const type::Type* result_type);
private:
- const type::Type* result_type_;
- utils::Vector<Value*, 1> args_;
+ const type::Type* result_type_ = nullptr;
};
} // namespace tint::ir
diff --git a/src/tint/ir/constant.cc b/src/tint/ir/constant.cc
index a49d140..6ae678a 100644
--- a/src/tint/ir/constant.cc
+++ b/src/tint/ir/constant.cc
@@ -18,7 +18,9 @@
namespace tint::ir {
-Constant::Constant(const constant::Value* val) : value_(val) {}
+Constant::Constant(const constant::Value* val) : value_(val) {
+ TINT_ASSERT(IR, value_);
+}
Constant::~Constant() = default;
diff --git a/src/tint/ir/constant.h b/src/tint/ir/constant.h
index 2df7250..cf9ed40 100644
--- a/src/tint/ir/constant.h
+++ b/src/tint/ir/constant.h
@@ -35,7 +35,7 @@
const type::Type* Type() const override { return value_->Type(); }
private:
- const constant::Value* const value_;
+ const constant::Value* const value_ = nullptr;
};
} // namespace tint::ir
diff --git a/src/tint/ir/constant_test.cc b/src/tint/ir/constant_test.cc
index 67fd5e7..5339e41 100644
--- a/src/tint/ir/constant_test.cc
+++ b/src/tint/ir/constant_test.cc
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include "gtest/gtest-spi.h"
#include "src/tint/ir/builder.h"
#include "src/tint/ir/ir_test_helper.h"
#include "src/tint/ir/value.h"
@@ -96,5 +97,19 @@
}
}
+TEST_F(IR_ConstantTest, Fail_NullValue) {
+ EXPECT_FATAL_FAILURE({ Constant c(nullptr); }, "");
+}
+
+TEST_F(IR_ConstantTest, Fail_Builder_NullValue) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.Constant(nullptr);
+ },
+ "");
+}
+
} // namespace
} // namespace tint::ir
diff --git a/src/tint/ir/construct.cc b/src/tint/ir/construct.cc
index 18e830d..c4a330e 100644
--- a/src/tint/ir/construct.cc
+++ b/src/tint/ir/construct.cc
@@ -22,8 +22,9 @@
namespace tint::ir {
-Construct::Construct(const type::Type* ty, utils::VectorRef<Value*> arguments)
- : Base(ty, std::move(arguments)) {}
+Construct::Construct(const type::Type* ty, utils::VectorRef<Value*> arguments) : Base(ty) {
+ AddOperands(std::move(arguments));
+}
Construct::~Construct() = default;
diff --git a/src/tint/ir/construct.h b/src/tint/ir/construct.h
index f4da78d..4a15849 100644
--- a/src/tint/ir/construct.h
+++ b/src/tint/ir/construct.h
@@ -26,7 +26,7 @@
/// Constructor
/// @param type the result type
/// @param args the constructor arguments
- Construct(const type::Type* type, utils::VectorRef<Value*> args);
+ explicit Construct(const type::Type* type, utils::VectorRef<Value*> args = utils::Empty);
~Construct() override;
};
diff --git a/src/tint/ir/construct_test.cc b/src/tint/ir/construct_test.cc
new file mode 100644
index 0000000..cdab810
--- /dev/null
+++ b/src/tint/ir/construct_test.cc
@@ -0,0 +1,57 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/construct.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest-spi.h"
+#include "src/tint/ir/ir_test_helper.h"
+
+namespace tint::ir {
+namespace {
+
+using namespace tint::number_suffixes; // NOLINT
+using IR_ConstructTest = IRTestHelper;
+
+TEST_F(IR_ConstructTest, Usage) {
+ auto* arg1 = b.Constant(true);
+ auto* arg2 = b.Constant(false);
+ auto* c = b.Construct(mod.Types().f32(), utils::Vector{arg1, arg2});
+
+ EXPECT_THAT(arg1->Usages(), testing::UnorderedElementsAre(Usage{c, 0u}));
+ EXPECT_THAT(arg2->Usages(), testing::UnorderedElementsAre(Usage{c, 1u}));
+}
+
+TEST_F(IR_ConstructTest, Fail_NullType) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.Construct(nullptr);
+ },
+ "");
+}
+
+TEST_F(IR_ConstructTest, Fail_NullArg) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.Construct(mod.Types().f32(), utils::Vector<Value*, 1>{nullptr});
+ },
+ "");
+}
+
+} // namespace
+} // namespace tint::ir
diff --git a/src/tint/ir/continue.cc b/src/tint/ir/continue.cc
index 7edebf8..a7282ac 100644
--- a/src/tint/ir/continue.cc
+++ b/src/tint/ir/continue.cc
@@ -16,6 +16,7 @@
#include <utility>
+#include "src/tint/ir/block.h"
#include "src/tint/ir/loop.h"
TINT_INSTANTIATE_TYPEINFO(tint::ir::Continue);
@@ -23,10 +24,13 @@
namespace tint::ir {
Continue::Continue(ir::Loop* loop, utils::VectorRef<Value*> args /* = utils::Empty */)
- : Base(std::move(args)), loop_(loop) {
+ : loop_(loop) {
TINT_ASSERT(IR, loop_);
- loop_->AddUsage(this);
- loop_->Continuing()->AddInboundBranch(this);
+
+ if (loop_) {
+ loop_->Continuing()->AddInboundBranch(this);
+ }
+ AddOperands(std::move(args));
}
Continue::~Continue() = default;
diff --git a/src/tint/ir/continue_test.cc b/src/tint/ir/continue_test.cc
new file mode 100644
index 0000000..631d696
--- /dev/null
+++ b/src/tint/ir/continue_test.cc
@@ -0,0 +1,59 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/continue.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest-spi.h"
+#include "src/tint/ir/ir_test_helper.h"
+
+namespace tint::ir {
+namespace {
+
+using namespace tint::number_suffixes; // NOLINT
+using IR_ContinueTest = IRTestHelper;
+
+TEST_F(IR_ContinueTest, Usage) {
+ auto* loop = b.CreateLoop();
+ auto* arg1 = b.Constant(1_u);
+ auto* arg2 = b.Constant(2_u);
+
+ auto* brk = b.Continue(loop, utils::Vector{arg1, arg2});
+
+ EXPECT_THAT(arg1->Usages(), testing::UnorderedElementsAre(Usage{brk, 0u}));
+ EXPECT_THAT(arg2->Usages(), testing::UnorderedElementsAre(Usage{brk, 1u}));
+}
+
+TEST_F(IR_ContinueTest, Fail_NullLoop) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.Continue(nullptr);
+ },
+ "");
+}
+
+TEST_F(IR_ContinueTest, Fail_NullArg) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.Continue(b.CreateLoop(), utils::Vector<Value*, 1>{nullptr});
+ },
+ "");
+}
+
+} // namespace
+} // namespace tint::ir
diff --git a/src/tint/ir/control_instruction.cc b/src/tint/ir/control_instruction.cc
new file mode 100644
index 0000000..efcac64
--- /dev/null
+++ b/src/tint/ir/control_instruction.cc
@@ -0,0 +1,23 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/control_instruction.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ir::ControlInstruction);
+
+namespace tint::ir {
+
+ControlInstruction::~ControlInstruction() = default;
+
+} // namespace tint::ir
diff --git a/src/tint/ir/control_instruction.h b/src/tint/ir/control_instruction.h
new file mode 100644
index 0000000..cd0e7d3
--- /dev/null
+++ b/src/tint/ir/control_instruction.h
@@ -0,0 +1,32 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_IR_CONTROL_INSTRUCTION_H_
+#define SRC_TINT_IR_CONTROL_INSTRUCTION_H_
+
+#include "src/tint/ir/branch.h"
+
+namespace tint::ir {
+
+/// Base class of instructions that perform branches to two or more blocks, owned by the
+/// ControlInstruction.
+class ControlInstruction : public utils::Castable<ControlInstruction, Branch> {
+ public:
+ /// Destructor
+ ~ControlInstruction() override;
+};
+
+} // namespace tint::ir
+
+#endif // SRC_TINT_IR_CONTROL_INSTRUCTION_H_
diff --git a/src/tint/ir/convert.cc b/src/tint/ir/convert.cc
index edc2fe7..b55243b 100644
--- a/src/tint/ir/convert.cc
+++ b/src/tint/ir/convert.cc
@@ -13,6 +13,9 @@
// limitations under the License.
#include "src/tint/ir/convert.h"
+
+#include <utility>
+
#include "src/tint/debug.h"
TINT_INSTANTIATE_TYPEINFO(tint::ir::Convert);
@@ -22,7 +25,11 @@
Convert::Convert(const type::Type* to_type,
const type::Type* from_type,
utils::VectorRef<Value*> arguments)
- : Base(to_type, arguments), from_type_(from_type) {}
+ : Base(to_type), from_type_(from_type) {
+ TINT_ASSERT(IR, from_type_);
+ TINT_ASSERT(IR, !arguments.IsEmpty());
+ AddOperands(std::move(arguments));
+}
Convert::~Convert() = default;
diff --git a/src/tint/ir/convert_test.cc b/src/tint/ir/convert_test.cc
new file mode 100644
index 0000000..d0ef71d
--- /dev/null
+++ b/src/tint/ir/convert_test.cc
@@ -0,0 +1,66 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/convert.h"
+#include "gtest/gtest-spi.h"
+#include "src/tint/ir/ir_test_helper.h"
+
+namespace tint::ir {
+namespace {
+
+using namespace tint::number_suffixes; // NOLINT
+using IR_ConvertTest = IRTestHelper;
+
+TEST_F(IR_ConvertTest, Fail_NullToType) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.Convert(nullptr, mod.Types().f32(), utils::Vector{b.Constant(1_u)});
+ },
+ "");
+}
+
+TEST_F(IR_ConvertTest, Fail_NullFromType) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.Convert(mod.Types().f32(), nullptr, utils::Vector{b.Constant(1_u)});
+ },
+ "");
+}
+
+TEST_F(IR_ConvertTest, Fail_NoArg) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.Convert(mod.Types().f32(), mod.Types().i32(), utils::Empty);
+ },
+ "");
+}
+
+TEST_F(IR_ConvertTest, Fail_NullArg) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.Convert(mod.Types().f32(), mod.Types().i32(), utils::Vector<Value*, 1>{nullptr});
+ },
+ "");
+}
+
+} // namespace
+} // namespace tint::ir
diff --git a/src/tint/ir/disassembler.cc b/src/tint/ir/disassembler.cc
index 65e1524..9f7aa8d 100644
--- a/src/tint/ir/disassembler.cc
+++ b/src/tint/ir/disassembler.cc
@@ -39,11 +39,14 @@
#include "src/tint/ir/store.h"
#include "src/tint/ir/switch.h"
#include "src/tint/ir/swizzle.h"
+#include "src/tint/ir/transform/block_decorated_structs.h"
#include "src/tint/ir/user_call.h"
#include "src/tint/ir/var.h"
#include "src/tint/switch.h"
+#include "src/tint/type/struct.h"
#include "src/tint/type/type.h"
#include "src/tint/utils/scoped_assignment.h"
+#include "src/tint/utils/string.h"
namespace tint::ir {
namespace {
@@ -71,6 +74,12 @@
return out_;
}
+void Disassembler::EmitLine() {
+ out_ << std::endl;
+ current_output_line_ += 1;
+ current_output_start_pos_ = out_.tellp();
+}
+
void Disassembler::EmitBlockInstructions(const Block* b) {
for (const auto* inst : *b) {
Indent();
@@ -93,11 +102,22 @@
});
}
+Source::Location Disassembler::MakeCurrentLocation() {
+ return Source::Location{current_output_line_, out_.tellp() - current_output_start_pos_ + 1};
+}
+
std::string Disassembler::Disassemble() {
+ for (auto* ty : mod_.Types()) {
+ if (auto* str = ty->As<type::Struct>()) {
+ EmitStructDecl(str);
+ }
+ }
+
if (mod_.root_block) {
- Indent() << "# Root block" << std::endl;
+ Indent() << "# Root block";
+ EmitLine();
WalkInternal(mod_.root_block);
- out_ << std::endl;
+ EmitLine();
}
for (auto* func : mod_.functions) {
@@ -111,29 +131,28 @@
return;
}
visited_.Add(blk);
-
- // If this block is dead, nothing to do
- if (!blk->HasBranchTarget()) {
- return;
- }
-
WalkInternal(blk);
}
void Disassembler::WalkInternal(const Block* blk) {
+ SourceMarker sm(this);
Indent() << "%b" << IdOf(blk) << " = block";
if (!blk->Params().IsEmpty()) {
out_ << " (";
- EmitValueList(blk->Params());
+ EmitValueList(blk->Params().Slice());
out_ << ")";
}
- out_ << " {" << std::endl;
+ out_ << " {";
+ EmitLine();
{
ScopedIndent si(indent_size_);
EmitBlockInstructions(blk);
}
- Indent() << "}" << std::endl;
+ Indent() << "}";
+ sm.Store(blk);
+
+ EmitLine();
}
void Disassembler::EmitBindingPoint(BindingPoint p) {
@@ -249,13 +268,15 @@
EmitReturnAttributes(func);
- out_ << " -> %b" << IdOf(func->StartTarget()) << " {" << std::endl;
+ out_ << " -> %b" << IdOf(func->StartTarget()) << " {";
+ EmitLine();
{
ScopedIndent si(indent_size_);
Walk(func->StartTarget());
}
- Indent() << "}" << std::endl;
+ Indent() << "}";
+ EmitLine();
}
void Disassembler::EmitValueWithType(const Value* val) {
@@ -319,6 +340,12 @@
[&](Default) { out_ << "Unknown value: " << val->TypeInfo().name; });
}
+void Disassembler::EmitInstructionName(std::string_view name, const Instruction* inst) {
+ SourceMarker sm(this);
+ out_ << name;
+ sm.Store(inst);
+}
+
void Disassembler::EmitInstruction(const Instruction* inst) {
tint::Switch(
inst, //
@@ -329,54 +356,71 @@
[&](const ir::Unary* u) { EmitUnary(u); },
[&](const ir::Bitcast* b) {
EmitValueWithType(b);
- out_ << " = bitcast ";
+ out_ << " = ";
+ EmitInstructionName("bitcast", b);
+ out_ << " ";
EmitArgs(b);
- out_ << std::endl;
+ EmitLine();
},
- [&](const ir::Discard*) { out_ << "discard" << std::endl; },
+ [&](const ir::Discard* d) {
+ EmitInstructionName("discard", d);
+ EmitLine();
+ },
[&](const ir::Builtin* b) {
EmitValueWithType(b);
- out_ << " = " << builtin::str(b->Func()) << " ";
+ out_ << " = ";
+ EmitInstructionName(builtin::str(b->Func()), b);
+ out_ << " ";
EmitArgs(b);
- out_ << std::endl;
+ EmitLine();
},
[&](const ir::Construct* c) {
EmitValueWithType(c);
- out_ << " = construct ";
+ out_ << " = ";
+ EmitInstructionName("construct", c);
+ out_ << " ";
EmitArgs(c);
- out_ << std::endl;
+ EmitLine();
},
[&](const ir::Convert* c) {
EmitValueWithType(c);
- out_ << " = convert " << c->FromType()->FriendlyName() << ", ";
+ out_ << " = ";
+ EmitInstructionName("convert", c);
+ out_ << " " << c->FromType()->FriendlyName() << ", ";
EmitArgs(c);
- out_ << std::endl;
+ EmitLine();
},
[&](const ir::Load* l) {
EmitValueWithType(l);
- out_ << " = load ";
+ out_ << " = ";
+ EmitInstructionName("load", l);
+ out_ << " ";
EmitValue(l->From());
- out_ << std::endl;
+ EmitLine();
},
[&](const ir::Store* s) {
- out_ << "store ";
+ EmitInstructionName("store", s);
+ out_ << " ";
EmitValue(s->To());
out_ << ", ";
EmitValue(s->From());
- out_ << std::endl;
+ EmitLine();
},
[&](const ir::UserCall* uc) {
EmitValueWithType(uc);
- out_ << " = call %" << IdOf(uc->Func());
+ out_ << " = ";
+ EmitInstructionName("call", uc);
+ out_ << " %" << IdOf(uc->Func());
if (!uc->Args().IsEmpty()) {
out_ << ", ";
}
EmitArgs(uc);
- out_ << std::endl;
+ EmitLine();
},
[&](const ir::Var* v) {
EmitValueWithType(v);
- out_ << " = var";
+ out_ << " = ";
+ EmitInstructionName("var", v);
if (v->Initializer()) {
out_ << ", ";
EmitValue(v->Initializer());
@@ -385,12 +429,13 @@
out_ << " ";
EmitBindingPoint(v->BindingPoint().value());
}
-
- out_ << std::endl;
+ EmitLine();
},
[&](const ir::Access* a) {
EmitValueWithType(a);
- out_ << " = access ";
+ out_ << " = ";
+ EmitInstructionName("access", a);
+ out_ << " ";
EmitValue(a->Object());
out_ << ", ";
for (size_t i = 0; i < a->Indices().Length(); ++i) {
@@ -399,11 +444,13 @@
}
EmitValue(a->Indices()[i]);
}
- out_ << std::endl;
+ EmitLine();
},
[&](const ir::Swizzle* s) {
EmitValueWithType(s);
- out_ << " = swizzle ";
+ out_ << " = ";
+ EmitInstructionName("swizzle", s);
+ out_ << " ";
EmitValue(s->Object());
out_ << ", ";
for (auto idx : s->Indices()) {
@@ -422,15 +469,22 @@
break;
}
}
- out_ << std::endl;
+ EmitLine();
},
[&](const ir::Branch* b) { EmitBranch(b); },
[&](Default) { out_ << "Unknown instruction: " << inst->TypeInfo().name; });
}
+void Disassembler::EmitOperand(const Value* val, const Instruction* inst, uint32_t index) {
+ SourceMarker condMarker(this);
+ EmitValue(val);
+ condMarker.Store(Operand{inst, index});
+}
+
void Disassembler::EmitIf(const If* i) {
+ SourceMarker sm(this);
out_ << "if ";
- EmitValue(i->Condition());
+ EmitOperand(i->Condition(), i, If::kConditionOperandIndex);
bool has_true = i->True()->HasBranchTarget();
bool has_false = i->False()->HasBranchTarget();
@@ -448,61 +502,84 @@
if (i->Merge()->HasBranchTarget()) {
out_ << ", m: %b" << IdOf(i->Merge());
}
- out_ << "]" << std::endl;
+ out_ << "]";
+ sm.Store(i);
+
+ EmitLine();
if (has_true) {
ScopedIndent si(indent_size_);
- Indent() << "# True block" << std::endl;
+ Indent() << "# True block";
+ EmitLine();
+
Walk(i->True());
- out_ << std::endl;
+ EmitLine();
}
if (has_false) {
ScopedIndent si(indent_size_);
- Indent() << "# False block" << std::endl;
+ Indent() << "# False block";
+ EmitLine();
+
Walk(i->False());
- out_ << std::endl;
+ EmitLine();
}
if (i->Merge()->HasBranchTarget()) {
- Indent() << "# Merge block" << std::endl;
+ Indent() << "# Merge block";
+ EmitLine();
Walk(i->Merge());
- out_ << std::endl;
+ EmitLine();
}
}
void Disassembler::EmitLoop(const Loop* l) {
- out_ << "loop [s: %b" << IdOf(l->Body());
+ utils::Vector<std::string, 4> parts;
+ if (l->Initializer()->HasBranchTarget()) {
+ parts.Push("i: %b" + std::to_string(IdOf(l->Initializer())));
+ }
+ if (l->Body()->HasBranchTarget()) {
+ parts.Push("b: %b" + std::to_string(IdOf(l->Body())));
+ }
if (l->Continuing()->HasBranchTarget()) {
- out_ << ", c: %b" << IdOf(l->Continuing());
+ parts.Push("c: %b" + std::to_string(IdOf(l->Continuing())));
}
if (l->Merge()->HasBranchTarget()) {
- out_ << ", m: %b" << IdOf(l->Merge());
+ parts.Push("m: %b" + std::to_string(IdOf(l->Merge())));
}
- out_ << "]";
+ SourceMarker sm(this);
+ out_ << "loop [" << utils::Join(parts, ", ") << "]";
+ sm.Store(l);
+ EmitLine();
- if (!l->Args().IsEmpty()) {
- out_ << " ";
- EmitValueList(l->Args());
- }
-
- out_ << std::endl;
-
- {
+ if (l->Initializer()->HasBranchTarget()) {
ScopedIndent si(indent_size_);
+ Indent() << "# Initializer block";
+ EmitLine();
+ Walk(l->Initializer());
+ EmitLine();
+ }
+
+ if (l->Body()->HasBranchTarget()) {
+ ScopedIndent si(indent_size_);
+ Indent() << "# Body block";
+ EmitLine();
Walk(l->Body());
- out_ << std::endl;
+ EmitLine();
}
if (l->Continuing()->HasBranchTarget()) {
ScopedIndent si(indent_size_);
- Indent() << "# Continuing block" << std::endl;
+ Indent() << "# Continuing block";
+ EmitLine();
Walk(l->Continuing());
- out_ << std::endl;
+ EmitLine();
}
if (l->Merge()->HasBranchTarget()) {
- Indent() << "# Merge block" << std::endl;
+ Indent() << "# Merge block";
+ EmitLine();
+
Walk(l->Merge());
- out_ << std::endl;
+ EmitLine();
}
}
@@ -531,22 +608,28 @@
if (s->Merge()->HasBranchTarget()) {
out_ << ", m: %b" << IdOf(s->Merge());
}
- out_ << "]" << std::endl;
+ out_ << "]";
+ EmitLine();
for (auto& c : s->Cases()) {
ScopedIndent si(indent_size_);
- Indent() << "# Case block" << std::endl;
+ Indent() << "# Case block";
+ EmitLine();
+
Walk(c.Start());
- out_ << std::endl;
+ EmitLine();
}
if (s->Merge()->HasBranchTarget()) {
- Indent() << "# Merge block" << std::endl;
+ Indent() << "# Merge block";
+ EmitLine();
+
Walk(s->Merge());
- out_ << std::endl;
+ EmitLine();
}
}
void Disassembler::EmitBranch(const Branch* b) {
+ SourceMarker sm(this);
tint::Switch(
b, //
[&](const ir::Return*) { out_ << "ret"; },
@@ -570,10 +653,12 @@
out_ << " ";
EmitValueList(b->Args());
}
- out_ << std::endl;
+ sm.Store(b);
+
+ EmitLine();
}
-void Disassembler::EmitValueList(tint::utils::VectorRef<const tint::ir::Value*> values) {
+void Disassembler::EmitValueList(utils::Slice<Value const* const> values) {
for (auto* v : values) {
if (v != values.Front()) {
out_ << ", ";
@@ -643,7 +728,7 @@
EmitValue(b->LHS());
out_ << ", ";
EmitValue(b->RHS());
- out_ << std::endl;
+ EmitLine();
}
void Disassembler::EmitUnary(const Unary* u) {
@@ -659,7 +744,41 @@
}
out_ << " ";
EmitValue(u->Val());
- out_ << std::endl;
+ EmitLine();
+}
+
+void Disassembler::EmitStructDecl(const type::Struct* str) {
+ out_ << str->Name().Name() << " = struct @align(" << str->Align() << ")";
+ if (str->StructFlags().Contains(type::StructFlag::kBlock)) {
+ out_ << ", @block";
+ }
+ out_ << " {";
+ EmitLine();
+ for (auto* member : str->Members()) {
+ out_ << " " << member->Name().Name() << ":" << member->Type()->FriendlyName();
+ out_ << " @offset(" << member->Offset() << ")";
+ if (member->Attributes().invariant) {
+ out_ << ", @invariant";
+ }
+ if (member->Attributes().location.has_value()) {
+ out_ << ", @location(" << member->Attributes().location.value() << ")";
+ }
+ if (member->Attributes().interpolation.has_value()) {
+ auto& interp = member->Attributes().interpolation.value();
+ out_ << ", @interpolate(" << interp.type;
+ if (interp.sampling != builtin::InterpolationSampling::kUndefined) {
+ out_ << ", " << interp.sampling;
+ }
+ out_ << ")";
+ }
+ if (member->Attributes().builtin.has_value()) {
+ out_ << ", @builtin(" << member->Attributes().builtin.value() << ")";
+ }
+ EmitLine();
+ }
+ out_ << "}";
+ EmitLine();
+ EmitLine();
}
} // namespace tint::ir
diff --git a/src/tint/ir/disassembler.h b/src/tint/ir/disassembler.h
index 4d5d928..6c58639 100644
--- a/src/tint/ir/disassembler.h
+++ b/src/tint/ir/disassembler.h
@@ -29,11 +29,40 @@
#include "src/tint/utils/hashset.h"
#include "src/tint/utils/string_stream.h"
+// Forward declarations.
+namespace tint::type {
+class Struct;
+}
+
namespace tint::ir {
/// Helper class to disassemble the IR
class Disassembler {
public:
+ /// An operand used in an instruction
+ struct Operand {
+ /// The instruction
+ const Instruction* instruction = nullptr;
+ /// The operand index
+ uint32_t operand_index = 0u;
+
+ /// A specialization of utils::Hasher for Operand.
+ struct Hasher {
+ /// @param u the operand to hash
+ /// @returns a hash of the operand
+ inline std::size_t operator()(const Operand& u) const {
+ return utils::Hash(u.instruction, u.operand_index);
+ }
+ };
+
+ /// An equality helper for Operand.
+ /// @param other the operand to compare against
+ /// @returns true if the two operands are equal
+ bool operator==(const Operand& other) const {
+ return instruction == other.instruction && operand_index == other.operand_index;
+ }
+ };
+
/// Constructor
/// @param mod the module
explicit Disassembler(const Module& mod);
@@ -50,7 +79,61 @@
/// @returns the string representation
std::string AsString() const { return out_.str(); }
+ /// @param inst the instruction to retrieve
+ /// @returns the source for the instruction
+ Source InstructionSource(const Instruction* inst) {
+ return instruction_to_src_.Get(inst).value_or(Source{});
+ }
+
+ /// @param operand the operand to retrieve
+ /// @returns the source for the operand
+ Source OperandSource(Operand operand) {
+ return operand_to_src_.Get(operand).value_or(Source{});
+ }
+
+ /// @param blk teh block to retrieve
+ /// @returns the source for the block
+ Source BlockSource(const Block* blk) { return block_to_src_.Get(blk).value_or(Source{}); }
+
+ /// Stores the given @p src location for @p inst instruction
+ /// @param inst the instruction to store
+ /// @param src the source location
+ void SetSource(const Instruction* inst, Source src) { instruction_to_src_.Add(inst, src); }
+
+ /// Stores the given @p src location for @p blk block
+ /// @param blk the block to store
+ /// @param src the source location
+ void SetSource(const Block* blk, Source src) { block_to_src_.Add(blk, src); }
+
+ /// Stores the given @p src location for @p op operand
+ /// @param op the operand to store
+ /// @param src the source location
+ void SetSource(Operand op, Source src) { operand_to_src_.Add(op, src); }
+
+ /// @returns the source location for the current emission location
+ Source::Location MakeCurrentLocation();
+
private:
+ class SourceMarker {
+ public:
+ explicit SourceMarker(Disassembler* d) : dis_(d), begin_(dis_->MakeCurrentLocation()) {}
+ ~SourceMarker() = default;
+
+ void Store(const Instruction* inst) { dis_->SetSource(inst, MakeSource()); }
+
+ void Store(const Block* blk) { dis_->SetSource(blk, MakeSource()); }
+
+ void Store(Operand operand) { dis_->SetSource(operand, MakeSource()); }
+
+ Source MakeSource() const {
+ return Source(Source::Range(begin_, dis_->MakeCurrentLocation()));
+ }
+
+ private:
+ Disassembler* dis_ = nullptr;
+ Source::Location begin_;
+ };
+
utils::StringStream& Indent();
size_t IdOf(const Block* blk);
@@ -66,7 +149,7 @@
void EmitInstruction(const Instruction* inst);
void EmitValueWithType(const Value* val);
void EmitValue(const Value* val);
- void EmitValueList(tint::utils::VectorRef<const tint::ir::Value*> values);
+ void EmitValueList(utils::Slice<ir::Value const* const> values);
void EmitArgs(const Call* call);
void EmitBinary(const Binary* b);
void EmitUnary(const Unary* b);
@@ -74,6 +157,10 @@
void EmitSwitch(const Switch* s);
void EmitLoop(const Loop* l);
void EmitIf(const If* i);
+ void EmitStructDecl(const type::Struct* str);
+ void EmitLine();
+ void EmitOperand(const Value* val, const Instruction* inst, uint32_t index);
+ void EmitInstructionName(std::string_view name, const Instruction* inst);
const Module& mod_;
utils::StringStream out_;
@@ -82,6 +169,13 @@
utils::Hashmap<const Value*, std::string, 32> value_ids_;
uint32_t indent_size_ = 0;
bool in_function_ = false;
+
+ uint32_t current_output_line_ = 1;
+ uint32_t current_output_start_pos_ = 0;
+
+ utils::Hashmap<const Block*, Source, 8> block_to_src_;
+ utils::Hashmap<const Instruction*, Source, 8> instruction_to_src_;
+ utils::Hashmap<Operand, Source, 8, Operand::Hasher> operand_to_src_;
};
} // namespace tint::ir
diff --git a/src/tint/ir/discard.cc b/src/tint/ir/discard.cc
index feebb07..ab9a511 100644
--- a/src/tint/ir/discard.cc
+++ b/src/tint/ir/discard.cc
@@ -14,12 +14,17 @@
#include "src/tint/ir/discard.h"
#include "src/tint/debug.h"
+#include "src/tint/type/void.h"
TINT_INSTANTIATE_TYPEINFO(tint::ir::Discard);
namespace tint::ir {
-Discard::Discard() : Base(nullptr, utils::Empty) {}
+Discard::Discard(const type::Type* ty) : Base(ty) {
+ if (ty) {
+ TINT_ASSERT(IR, ty->Is<type::Void>());
+ }
+}
Discard::~Discard() = default;
diff --git a/src/tint/ir/discard.h b/src/tint/ir/discard.h
index e87474c..24b9c59 100644
--- a/src/tint/ir/discard.h
+++ b/src/tint/ir/discard.h
@@ -25,7 +25,8 @@
class Discard : public utils::Castable<Discard, Call> {
public:
/// Constructor
- Discard();
+ /// @param ty the type of the discard, must be Void type.
+ explicit Discard(const type::Type* ty);
~Discard() override;
};
diff --git a/src/tint/ir/discard_test.cc b/src/tint/ir/discard_test.cc
index ea264f1..92eb6da 100644
--- a/src/tint/ir/discard_test.cc
+++ b/src/tint/ir/discard_test.cc
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include "gtest/gtest-spi.h"
#include "src/tint/ir/builder.h"
#include "src/tint/ir/instruction.h"
#include "src/tint/ir/ir_test_helper.h"
@@ -26,5 +27,18 @@
ASSERT_TRUE(inst->Is<ir::Discard>());
}
+TEST_F(IR_DiscardTest, Fail_NullType) {
+ EXPECT_FATAL_FAILURE({ Discard d(nullptr); }, "");
+}
+
+TEST_F(IR_DiscardTest, Fail_NonVoidType) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Discard d(mod.Types().i32());
+ },
+ "");
+}
+
} // namespace
} // namespace tint::ir
diff --git a/src/tint/ir/exit_if.cc b/src/tint/ir/exit_if.cc
index 16fac7f..44a7002 100644
--- a/src/tint/ir/exit_if.cc
+++ b/src/tint/ir/exit_if.cc
@@ -16,17 +16,20 @@
#include <utility>
+#include "src/tint/ir/block.h"
#include "src/tint/ir/if.h"
TINT_INSTANTIATE_TYPEINFO(tint::ir::ExitIf);
namespace tint::ir {
-ExitIf::ExitIf(ir::If* i, utils::VectorRef<Value*> args /* = utils::Empty */)
- : Base(std::move(args)), if_(i) {
+ExitIf::ExitIf(ir::If* i, utils::VectorRef<Value*> args /* = utils::Empty */) : if_(i) {
TINT_ASSERT(IR, if_);
- if_->AddUsage(this);
- if_->Merge()->AddInboundBranch(this);
+
+ if (if_) {
+ if_->Merge()->AddInboundBranch(this);
+ }
+ AddOperands(std::move(args));
}
ExitIf::~ExitIf() = default;
diff --git a/src/tint/ir/exit_if_test.cc b/src/tint/ir/exit_if_test.cc
new file mode 100644
index 0000000..13e0cc1
--- /dev/null
+++ b/src/tint/ir/exit_if_test.cc
@@ -0,0 +1,58 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/exit_if.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest-spi.h"
+#include "src/tint/ir/ir_test_helper.h"
+
+namespace tint::ir {
+namespace {
+
+using namespace tint::number_suffixes; // NOLINT
+using IR_ExitIfTest = IRTestHelper;
+
+TEST_F(IR_ExitIfTest, Usage) {
+ auto* arg1 = b.Constant(1_u);
+ auto* arg2 = b.Constant(2_u);
+ auto* if_ = b.CreateIf(b.Constant(true));
+ auto* e = b.ExitIf(if_, utils::Vector{arg1, arg2});
+
+ EXPECT_THAT(arg1->Usages(), testing::UnorderedElementsAre(Usage{e, 0u}));
+ EXPECT_THAT(arg2->Usages(), testing::UnorderedElementsAre(Usage{e, 1u}));
+}
+
+TEST_F(IR_ExitIfTest, Fail_NullIf) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.ExitIf(nullptr);
+ },
+ "");
+}
+
+TEST_F(IR_ExitIfTest, Fail_NullArg) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.ExitIf(b.CreateIf(b.Constant(false)), utils::Vector<Value*, 1>{nullptr});
+ },
+ "");
+}
+
+} // namespace
+} // namespace tint::ir
diff --git a/src/tint/ir/exit_loop.cc b/src/tint/ir/exit_loop.cc
index 0effcfc..729e466 100644
--- a/src/tint/ir/exit_loop.cc
+++ b/src/tint/ir/exit_loop.cc
@@ -16,6 +16,7 @@
#include <utility>
+#include "src/tint/ir/block.h"
#include "src/tint/ir/loop.h"
TINT_INSTANTIATE_TYPEINFO(tint::ir::ExitLoop);
@@ -23,10 +24,13 @@
namespace tint::ir {
ExitLoop::ExitLoop(ir::Loop* loop, utils::VectorRef<Value*> args /* = utils::Empty */)
- : Base(std::move(args)), loop_(loop) {
+ : loop_(loop) {
TINT_ASSERT(IR, loop_);
- loop_->AddUsage(this);
- loop_->Merge()->AddInboundBranch(this);
+
+ if (loop_) {
+ loop_->Merge()->AddInboundBranch(this);
+ }
+ AddOperands(std::move(args));
}
ExitLoop::~ExitLoop() = default;
diff --git a/src/tint/ir/exit_loop_test.cc b/src/tint/ir/exit_loop_test.cc
new file mode 100644
index 0000000..11a6b28
--- /dev/null
+++ b/src/tint/ir/exit_loop_test.cc
@@ -0,0 +1,58 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/exit_loop.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest-spi.h"
+#include "src/tint/ir/ir_test_helper.h"
+
+namespace tint::ir {
+namespace {
+
+using namespace tint::number_suffixes; // NOLINT
+using IR_ExitLoopTest = IRTestHelper;
+
+TEST_F(IR_ExitLoopTest, Usage) {
+ auto* arg1 = b.Constant(1_u);
+ auto* arg2 = b.Constant(2_u);
+ auto* loop = b.CreateLoop();
+ auto* e = b.ExitLoop(loop, utils::Vector{arg1, arg2});
+
+ EXPECT_THAT(arg1->Usages(), testing::UnorderedElementsAre(Usage{e, 0u}));
+ EXPECT_THAT(arg2->Usages(), testing::UnorderedElementsAre(Usage{e, 1u}));
+}
+
+TEST_F(IR_ExitLoopTest, Fail_NullLoop) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.ExitLoop(nullptr);
+ },
+ "");
+}
+
+TEST_F(IR_ExitLoopTest, Fail_NullArg) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.ExitLoop(b.CreateLoop(), utils::Vector<Value*, 1>{nullptr});
+ },
+ "");
+}
+
+} // namespace
+} // namespace tint::ir
diff --git a/src/tint/ir/exit_switch.cc b/src/tint/ir/exit_switch.cc
index e9679e5..29c6a24 100644
--- a/src/tint/ir/exit_switch.cc
+++ b/src/tint/ir/exit_switch.cc
@@ -16,6 +16,7 @@
#include <utility>
+#include "src/tint/ir/block.h"
#include "src/tint/ir/switch.h"
TINT_INSTANTIATE_TYPEINFO(tint::ir::ExitSwitch);
@@ -23,10 +24,13 @@
namespace tint::ir {
ExitSwitch::ExitSwitch(ir::Switch* sw, utils::VectorRef<Value*> args /* = utils::Empty */)
- : Base(std::move(args)), switch_(sw) {
+ : switch_(sw) {
TINT_ASSERT(IR, switch_);
- switch_->AddUsage(this);
- switch_->Merge()->AddInboundBranch(this);
+
+ if (switch_) {
+ switch_->Merge()->AddInboundBranch(this);
+ }
+ AddOperands(std::move(args));
}
ExitSwitch::~ExitSwitch() = default;
diff --git a/src/tint/ir/exit_switch_test.cc b/src/tint/ir/exit_switch_test.cc
new file mode 100644
index 0000000..945bb5c
--- /dev/null
+++ b/src/tint/ir/exit_switch_test.cc
@@ -0,0 +1,58 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/exit_switch.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest-spi.h"
+#include "src/tint/ir/ir_test_helper.h"
+
+namespace tint::ir {
+namespace {
+
+using namespace tint::number_suffixes; // NOLINT
+using IR_ExitSwitchTest = IRTestHelper;
+
+TEST_F(IR_ExitSwitchTest, Usage) {
+ auto* arg1 = b.Constant(1_u);
+ auto* arg2 = b.Constant(2_u);
+ auto* switch_ = b.CreateSwitch(b.Constant(true));
+ auto* e = b.ExitSwitch(switch_, utils::Vector{arg1, arg2});
+
+ EXPECT_THAT(arg1->Usages(), testing::UnorderedElementsAre(Usage{e, 0u}));
+ EXPECT_THAT(arg2->Usages(), testing::UnorderedElementsAre(Usage{e, 1u}));
+}
+
+TEST_F(IR_ExitSwitchTest, Fail_NullSwitch) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.ExitSwitch(nullptr);
+ },
+ "");
+}
+
+TEST_F(IR_ExitSwitchTest, Fail_NullArg) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.ExitSwitch(b.CreateSwitch(b.Constant(false)), utils::Vector<Value*, 1>{nullptr});
+ },
+ "");
+}
+
+} // namespace
+} // namespace tint::ir
diff --git a/src/tint/ir/from_program.cc b/src/tint/ir/from_program.cc
index 49fbfe2..ad4d29f 100644
--- a/src/tint/ir/from_program.cc
+++ b/src/tint/ir/from_program.cc
@@ -145,8 +145,8 @@
/* dst */ {builder_.ir.constant_values},
};
- /// The stack of control blocks.
- utils::Vector<Branch*, 8> control_stack_;
+ /// The stack of flow control instructions.
+ utils::Vector<ControlInstruction*, 8> control_stack_;
/// The current block for expressions.
Block* current_block_ = nullptr;
@@ -162,7 +162,9 @@
class ControlStackScope {
public:
- ControlStackScope(Impl* impl, Branch* b) : impl_(impl) { impl_->control_stack_.Push(b); }
+ ControlStackScope(Impl* impl, ControlInstruction* b) : impl_(impl) {
+ impl_->control_stack_.Push(b);
+ }
~ControlStackScope() { impl_->control_stack_.Pop(); }
@@ -646,6 +648,9 @@
auto* loop_inst = builder_.CreateLoop();
current_block_->Append(loop_inst);
+ // Loop branches directly to the body (no initializer)
+ loop_inst->Body()->AddInboundBranch(loop_inst);
+
{
ControlStackScope scope(this, loop_inst);
current_block_ = loop_inst->Body();
@@ -690,6 +695,9 @@
auto* loop_inst = builder_.CreateLoop();
current_block_->Append(loop_inst);
+ // Loop branches directly to the body (no initializer)
+ loop_inst->Body()->AddInboundBranch(loop_inst);
+
// Continue is always empty, just go back to the start
current_block_ = loop_inst->Continuing();
SetBranch(builder_.NextIteration(loop_inst));
@@ -735,16 +743,23 @@
scopes_.Push();
TINT_DEFER(scopes_.Pop());
- if (stmt->initializer) {
- // Emit the for initializer before branching to the loop
- EmitStatement(stmt->initializer);
- }
-
{
ControlStackScope scope(this, loop_inst);
- current_block_ = loop_inst->Body();
+ if (stmt->initializer) {
+ // Loop branches to the initializer
+ loop_inst->Initializer()->AddInboundBranch(loop_inst);
+ // Emit the for initializer before branching to the body
+ current_block_ = loop_inst->Initializer();
+ EmitStatement(stmt->initializer);
+ SetBranch(builder_.NextIteration(loop_inst));
+ } else {
+ // If there's no initializer, then the loop branches directly to the body block
+ loop_inst->Body()->AddInboundBranch(loop_inst);
+ }
+
+ current_block_ = loop_inst->Body();
if (stmt->condition) {
// Emit the condition into the target target of the loop
auto reg = EmitExpression(stmt->condition);
diff --git a/src/tint/ir/from_program_accessor_test.cc b/src/tint/ir/from_program_accessor_test.cc
index 9da49a1..d8e6601 100644
--- a/src/tint/ir/from_program_accessor_test.cc
+++ b/src/tint/ir/from_program_accessor_test.cc
@@ -90,7 +90,11 @@
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
EXPECT_EQ(Disassemble(m.Get()),
- R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
+ R"(MyStruct = struct @align(4) {
+ foo:i32 @offset(0)
+}
+
+%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
%b1 = block {
%a:ptr<function, MyStruct, read_write> = var
%3:ptr<function, i32, read_write> = access %a, 0u
@@ -122,7 +126,16 @@
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
EXPECT_EQ(Disassemble(m.Get()),
- R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
+ R"(Inner = struct @align(4) {
+ bar:f32 @offset(0)
+}
+
+Outer = struct @align(4) {
+ a:i32 @offset(0)
+ foo:Inner @offset(4)
+}
+
+%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
%b1 = block {
%a:ptr<function, Outer, read_write> = var
%3:ptr<function, f32, read_write> = access %a, 1u, 0u
@@ -158,7 +171,18 @@
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
EXPECT_EQ(Disassemble(m.Get()),
- R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
+ R"(Inner = struct @align(16) {
+ b:i32 @offset(0)
+ c:f32 @offset(4)
+ bar:vec4<f32> @offset(16)
+}
+
+Outer = struct @align(16) {
+ a:i32 @offset(0)
+ foo:array<Inner, 4> @offset(16)
+}
+
+%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
%b1 = block {
%a:ptr<function, array<Outer, 4>, read_write> = var
%3:ptr<function, vec4<f32>, read_write> = access %a, 0u, 1u, 1u, 2u
@@ -281,7 +305,12 @@
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
EXPECT_EQ(Disassemble(m.Get()),
- R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
+ R"(MyStruct = struct @align(16) {
+ a:i32 @offset(0)
+ foo:vec4<f32> @offset(16)
+}
+
+%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
%b1 = block {
%a:ptr<function, MyStruct, read_write> = var
%3:ptr<function, vec4<f32>, read_write> = access %a, 1u
@@ -352,7 +381,11 @@
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
EXPECT_EQ(Disassemble(m.Get()),
- R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
+ R"(MyStruct = struct @align(4) {
+ foo:i32 @offset(0)
+}
+
+%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
%b1 = block {
%b:i32 = access MyStruct(0i), 0u
ret
@@ -382,7 +415,16 @@
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
EXPECT_EQ(Disassemble(m.Get()),
- R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
+ R"(Inner = struct @align(4) {
+ bar:f32 @offset(0)
+}
+
+Outer = struct @align(4) {
+ a:i32 @offset(0)
+ foo:Inner @offset(4)
+}
+
+%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
%b1 = block {
%b:f32 = access Outer(0i, Inner(0.0f)), 1u, 0u
ret
@@ -416,7 +458,18 @@
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
EXPECT_EQ(Disassemble(m.Get()),
- R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
+ R"(Inner = struct @align(16) {
+ b:i32 @offset(0)
+ c:f32 @offset(4)
+ bar:vec4<f32> @offset(16)
+}
+
+Outer = struct @align(16) {
+ a:i32 @offset(0)
+ foo:array<Inner, 4> @offset(16)
+}
+
+%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
%b1 = block {
%b:vec4<f32> = access array<Outer, 4>(Outer(0i, array<Inner, 4>(Inner(0i, 0.0f, vec4<f32>(0.0f))))), 0u, 1u, 1u, 2u
ret
@@ -508,7 +561,12 @@
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
EXPECT_EQ(Disassemble(m.Get()),
- R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
+ R"(MyStruct = struct @align(16) {
+ a:i32 @offset(0)
+ foo:vec4<f32> @offset(16)
+}
+
+%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
%b1 = block {
%2:vec4<f32> = access MyStruct(0i, vec4<f32>(0.0f)), 1u
%3:vec3<f32> = swizzle %2, zyx
diff --git a/src/tint/ir/from_program_test.cc b/src/tint/ir/from_program_test.cc
index 6b6abd9..e016481 100644
--- a/src/tint/ir/from_program_test.cc
+++ b/src/tint/ir/from_program_test.cc
@@ -302,7 +302,8 @@
if true [t: %b2, f: %b3, m: %b4]
# True block
%b2 = block {
- loop [s: %b5, m: %b6]
+ loop [b: %b5, m: %b6]
+ # Body block
%b5 = block {
exit_loop %b6
}
@@ -348,7 +349,8 @@
EXPECT_EQ(Disassemble(m),
R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
%b1 = block {
- loop [s: %b2, m: %b3]
+ loop [b: %b2, m: %b3]
+ # Body block
%b2 = block {
exit_loop %b3
}
@@ -388,7 +390,8 @@
EXPECT_EQ(Disassemble(m),
R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
%b1 = block {
- loop [s: %b2, c: %b3, m: %b4]
+ loop [b: %b2, c: %b3, m: %b4]
+ # Body block
%b2 = block {
if true [t: %b5, f: %b6, m: %b7]
# True block
@@ -443,7 +446,8 @@
EXPECT_EQ(Disassemble(m),
R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
%b1 = block {
- loop [s: %b2, c: %b3, m: %b4]
+ loop [b: %b2, c: %b3, m: %b4]
+ # Body block
%b2 = block {
continue %b3
}
@@ -476,7 +480,8 @@
EXPECT_EQ(Disassemble(m),
R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
%b1 = block {
- loop [s: %b2, c: %b3, m: %b4]
+ loop [b: %b2, c: %b3, m: %b4]
+ # Body block
%b2 = block {
continue %b3
}
@@ -520,7 +525,8 @@
EXPECT_EQ(Disassemble(m),
R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
%b1 = block {
- loop [s: %b2, c: %b3]
+ loop [b: %b2, c: %b3]
+ # Body block
%b2 = block {
if true [t: %b4, f: %b5, m: %b6]
# True block
@@ -569,7 +575,8 @@
EXPECT_EQ(Disassemble(m),
R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
%b1 = block {
- loop [s: %b2]
+ loop [b: %b2]
+ # Body block
%b2 = block {
ret
}
@@ -607,7 +614,8 @@
EXPECT_EQ(Disassemble(m),
R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
%b1 = block {
- loop [s: %b2]
+ loop [b: %b2]
+ # Body block
%b2 = block {
ret
}
@@ -641,7 +649,8 @@
EXPECT_EQ(Disassemble(m),
R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
%b1 = block {
- loop [s: %b2, m: %b3]
+ loop [b: %b2, m: %b3]
+ # Body block
%b2 = block {
if true [t: %b4, f: %b5]
# True block
@@ -686,9 +695,11 @@
EXPECT_EQ(Disassemble(m.Get()),
R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
%b1 = block {
- loop [s: %b2, c: %b3, m: %b4]
+ loop [b: %b2, c: %b3, m: %b4]
+ # Body block
%b2 = block {
- loop [s: %b5, c: %b6, m: %b7]
+ loop [b: %b5, c: %b6, m: %b7]
+ # Body block
%b5 = block {
if true [t: %b8, f: %b9, m: %b10]
# True block
@@ -725,14 +736,16 @@
# Continuing block
%b6 = block {
- loop [s: %b14, m: %b15]
+ loop [b: %b14, m: %b15]
+ # Body block
%b14 = block {
exit_loop %b15
}
# Merge block
%b15 = block {
- loop [s: %b16, c: %b17, m: %b18]
+ loop [b: %b16, c: %b17, m: %b18]
+ # Body block
%b16 = block {
continue %b17
}
@@ -814,7 +827,8 @@
EXPECT_EQ(Disassemble(m),
R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
%b1 = block {
- loop [s: %b2, c: %b3, m: %b4]
+ loop [b: %b2, c: %b3, m: %b4]
+ # Body block
%b2 = block {
if false [t: %b5, f: %b6, m: %b7]
# True block
@@ -875,7 +889,8 @@
EXPECT_EQ(Disassemble(m),
R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
%b1 = block {
- loop [s: %b2, c: %b3, m: %b4]
+ loop [b: %b2, c: %b3, m: %b4]
+ # Body block
%b2 = block {
if true [t: %b5, f: %b6, m: %b7]
# True block
@@ -949,6 +964,48 @@
EXPECT_EQ(Disassemble(m), R"()");
}
+TEST_F(IR_FromProgramTest, For_Init_NoCondOrContinuing) {
+ auto* ast_for = For(Decl(Var("i", ty.i32())), nullptr, nullptr, Block(Break()));
+ WrapInFunction(ast_for);
+
+ auto res = Build();
+ ASSERT_TRUE(res) << (!res ? res.Failure() : "");
+
+ auto m = res.Move();
+ auto* flow = FindSingleValue<ir::Loop>(m);
+
+ ASSERT_EQ(1u, m.functions.Length());
+
+ EXPECT_EQ(1u, flow->Initializer()->InboundBranches().Length());
+ EXPECT_EQ(1u, flow->Body()->InboundBranches().Length());
+ EXPECT_EQ(0u, flow->Continuing()->InboundBranches().Length());
+ EXPECT_EQ(1u, flow->Merge()->InboundBranches().Length());
+
+ EXPECT_EQ(Disassemble(m),
+ R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
+ %b1 = block {
+ loop [i: %b2, b: %b3, m: %b4]
+ # Initializer block
+ %b2 = block {
+ %i:ptr<function, i32, read_write> = var
+ next_iteration %b3
+ }
+
+ # Body block
+ %b3 = block {
+ exit_loop %b4
+ }
+
+ # Merge block
+ %b4 = block {
+ ret
+ }
+
+ }
+}
+)");
+}
+
TEST_F(IR_FromProgramTest, For_NoInitCondOrContinuing) {
auto* ast_for = For(nullptr, nullptr, nullptr, Block(Break()));
WrapInFunction(ast_for);
@@ -968,7 +1025,8 @@
EXPECT_EQ(Disassemble(m),
R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
%b1 = block {
- loop [s: %b2, m: %b3]
+ loop [b: %b2, m: %b3]
+ # Body block
%b2 = block {
exit_loop %b3
}
diff --git a/src/tint/ir/function.cc b/src/tint/ir/function.cc
index 3769d4f..fa13fb8 100644
--- a/src/tint/ir/function.cc
+++ b/src/tint/ir/function.cc
@@ -21,12 +21,21 @@
Function::Function(const type::Type* rt,
PipelineStage stage,
std::optional<std::array<uint32_t, 3>> wg_size)
- : Base(), pipeline_stage_(stage), workgroup_size_(wg_size) {
+ : pipeline_stage_(stage), workgroup_size_(wg_size) {
+ TINT_ASSERT(IR, rt != nullptr);
+
return_.type = rt;
}
Function::~Function() = default;
+void Function::SetParams(utils::VectorRef<FunctionParam*> params) {
+ params_ = std::move(params);
+ for (auto* param : params_) {
+ TINT_ASSERT(IR, param != nullptr);
+ }
+}
+
utils::StringStream& operator<<(utils::StringStream& out, Function::PipelineStage value) {
switch (value) {
case Function::PipelineStage::kVertex:
diff --git a/src/tint/ir/function.h b/src/tint/ir/function.h
index fa5c547..775033d 100644
--- a/src/tint/ir/function.h
+++ b/src/tint/ir/function.h
@@ -111,13 +111,16 @@
/// Sets the function parameters
/// @param params the function paramters
- void SetParams(utils::VectorRef<FunctionParam*> params) { params_ = std::move(params); }
+ void SetParams(utils::VectorRef<FunctionParam*> params);
/// @returns the function parameters
utils::VectorRef<FunctionParam*> Params() const { return params_; }
/// Sets the start target for the function
/// @param target the start target
- void SetStartTarget(Block* target) { start_target_ = target; }
+ void SetStartTarget(Block* target) {
+ TINT_ASSERT(IR, target != nullptr);
+ start_target_ = target;
+ }
/// @returns the function start target
Block* StartTarget() const { return start_target_; }
diff --git a/src/tint/ir/function_param.cc b/src/tint/ir/function_param.cc
index ae94f24..710a380 100644
--- a/src/tint/ir/function_param.cc
+++ b/src/tint/ir/function_param.cc
@@ -18,7 +18,9 @@
namespace tint::ir {
-FunctionParam::FunctionParam(const type::Type* ty) : type_(ty) {}
+FunctionParam::FunctionParam(const type::Type* ty) : type_(ty) {
+ TINT_ASSERT(IR, ty != nullptr);
+}
FunctionParam::~FunctionParam() = default;
diff --git a/src/tint/ir/function_param_test.cc b/src/tint/ir/function_param_test.cc
new file mode 100644
index 0000000..43c00b9
--- /dev/null
+++ b/src/tint/ir/function_param_test.cc
@@ -0,0 +1,48 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/function_param.h"
+#include "gtest/gtest-spi.h"
+#include "src/tint/ir/ir_test_helper.h"
+
+namespace tint::ir {
+namespace {
+
+using namespace tint::number_suffixes; // NOLINT
+using IR_FunctionParamTest = IRTestHelper;
+
+TEST_F(IR_FunctionParamTest, Fail_NullType) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.FunctionParam(nullptr);
+ },
+ "");
+}
+
+TEST_F(IR_FunctionParamTest, Fail_SetDuplicateBuiltin) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ auto* fp = b.FunctionParam(mod.Types().f32());
+ fp->SetBuiltin(FunctionParam::Builtin::kVertexIndex);
+ fp->SetBuiltin(FunctionParam::Builtin::kSampleMask);
+ },
+ "");
+}
+
+} // namespace
+} // namespace tint::ir
diff --git a/src/tint/ir/function_test.cc b/src/tint/ir/function_test.cc
new file mode 100644
index 0000000..e8ee091
--- /dev/null
+++ b/src/tint/ir/function_test.cc
@@ -0,0 +1,70 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/function.h"
+#include "gtest/gtest-spi.h"
+#include "src/tint/ir/ir_test_helper.h"
+
+namespace tint::ir {
+namespace {
+
+using namespace tint::number_suffixes; // NOLINT
+using IR_FunctionTest = IRTestHelper;
+
+TEST_F(IR_FunctionTest, Fail_NullReturnType) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.CreateFunction("my_func", nullptr);
+ },
+ "");
+}
+
+TEST_F(IR_FunctionTest, Fail_DoubleReturnBuiltin) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ auto* f = b.CreateFunction("my_func", mod.Types().void_());
+ f->SetReturnBuiltin(Function::ReturnBuiltin::kFragDepth);
+ f->SetReturnBuiltin(Function::ReturnBuiltin::kPosition);
+ },
+ "");
+}
+
+TEST_F(IR_FunctionTest, Fail_NullParam) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ auto* f = b.CreateFunction("my_func", mod.Types().void_());
+ f->SetParams(utils::Vector<FunctionParam*, 1>{nullptr});
+ },
+ "");
+}
+
+TEST_F(IR_FunctionTest, Fail_NullStartTarget) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ auto* f = b.CreateFunction("my_func", mod.Types().void_());
+ f->SetStartTarget(nullptr);
+ },
+ "");
+}
+
+} // namespace
+} // namespace tint::ir
diff --git a/src/tint/ir/if.cc b/src/tint/ir/if.cc
index aa1e884..29bcbb2 100644
--- a/src/tint/ir/if.cc
+++ b/src/tint/ir/if.cc
@@ -16,17 +16,28 @@
TINT_INSTANTIATE_TYPEINFO(tint::ir::If);
+#include "src/tint/ir/block.h"
+
namespace tint::ir {
-If::If(Value* cond, ir::Block* t, ir::Block* f, ir::Block* m)
- : Base(utils::Empty), condition_(cond), true_(t), false_(f), merge_(m) {
+If::If(Value* cond, ir::Block* t, ir::Block* f, ir::Block* m) : true_(t), false_(f), merge_(m) {
+ TINT_ASSERT(IR, cond);
TINT_ASSERT(IR, true_);
TINT_ASSERT(IR, false_);
TINT_ASSERT(IR, merge_);
- condition_->AddUsage(this);
- true_->AddInboundBranch(this);
- false_->AddInboundBranch(this);
+ AddOperand(cond);
+ if (true_) {
+ true_->AddInboundBranch(this);
+ true_->SetParent(this);
+ }
+ if (false_) {
+ false_->AddInboundBranch(this);
+ false_->SetParent(this);
+ }
+ if (merge_) {
+ merge_->SetParent(this);
+ }
}
If::~If() = default;
diff --git a/src/tint/ir/if.h b/src/tint/ir/if.h
index 6dbe5fc..4566ef8 100644
--- a/src/tint/ir/if.h
+++ b/src/tint/ir/if.h
@@ -15,20 +15,34 @@
#ifndef SRC_TINT_IR_IF_H_
#define SRC_TINT_IR_IF_H_
-#include "src/tint/ir/block.h"
-#include "src/tint/ir/branch.h"
-#include "src/tint/ir/value.h"
-
-// Forward declarations
-namespace tint::ir {
-class Block;
-} // namespace tint::ir
+#include "src/tint/ir/control_instruction.h"
namespace tint::ir {
-/// An if instruction
-class If : public utils::Castable<If, Branch> {
+/// If instruction.
+///
+/// ```
+/// in
+/// ┃
+/// ┏━━━━━━━━━━┻━━━━━━━━━━┓
+/// ▼ ▼
+/// ┌────────────┐ ┌────────────┐
+/// │ True │ │ False │
+/// | (optional) | | (optional) |
+/// └────────────┘ └────────────┘
+/// ExitIf ┃ ┌──────────┐ ┃ ExitIf
+/// ┗━━━━▶│ Merge │◀━━━━┛
+/// │(optional)│
+/// └──────────┘
+/// ┃
+/// ▼
+/// out
+/// ```
+class If : public utils::Castable<If, ControlInstruction> {
public:
+ /// The index of the condition operand
+ static constexpr size_t kConditionOperandIndex = 0;
+
/// Constructor
/// @param cond the if condition
/// @param t the true block
@@ -37,10 +51,13 @@
explicit If(Value* cond, ir::Block* t, ir::Block* f, ir::Block* m);
~If() override;
+ /// @returns the branch arguments
+ utils::Slice<Value const* const> Args() const override { return utils::Slice<Value*>{}; }
+
/// @returns the if condition
- const Value* Condition() const { return condition_; }
+ const Value* Condition() const { return operands_[kConditionOperandIndex]; }
/// @returns the if condition
- Value* Condition() { return condition_; }
+ Value* Condition() { return operands_[kConditionOperandIndex]; }
/// @returns the true branch block
const ir::Block* True() const { return true_; }
@@ -58,7 +75,6 @@
ir::Block* Merge() { return merge_; }
private:
- Value* condition_ = nullptr;
ir::Block* true_ = nullptr;
ir::Block* false_ = nullptr;
ir::Block* merge_ = nullptr;
diff --git a/src/tint/ir/if_test.cc b/src/tint/ir/if_test.cc
new file mode 100644
index 0000000..4d5bcdf
--- /dev/null
+++ b/src/tint/ir/if_test.cc
@@ -0,0 +1,81 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/if.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest-spi.h"
+#include "src/tint/ir/ir_test_helper.h"
+
+namespace tint::ir {
+namespace {
+
+using namespace tint::number_suffixes; // NOLINT
+using IR_IfTest = IRTestHelper;
+
+TEST_F(IR_IfTest, Usage) {
+ auto* cond = b.Constant(true);
+ auto* if_ = b.CreateIf(cond);
+ EXPECT_THAT(cond->Usages(), testing::UnorderedElementsAre(Usage{if_, 0u}));
+}
+
+TEST_F(IR_IfTest, Parent) {
+ auto* cond = b.Constant(true);
+ auto* if_ = b.CreateIf(cond);
+ EXPECT_EQ(if_->True()->Parent(), if_);
+ EXPECT_EQ(if_->False()->Parent(), if_);
+ EXPECT_EQ(if_->Merge()->Parent(), if_);
+}
+
+TEST_F(IR_IfTest, Fail_NullCondition) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.CreateIf(nullptr);
+ },
+ "");
+}
+
+TEST_F(IR_IfTest, Fail_NullTrueBlock) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ If if_(b.Constant(false), nullptr, b.CreateBlock(), b.CreateBlock());
+ },
+ "");
+}
+
+TEST_F(IR_IfTest, Fail_NullFalseBlock) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ If if_(b.Constant(false), b.CreateBlock(), nullptr, b.CreateBlock());
+ },
+ "");
+}
+
+TEST_F(IR_IfTest, Fail_NullMergeBlock) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ If if_(b.Constant(false), b.CreateBlock(), b.CreateBlock(), nullptr);
+ },
+ "");
+}
+
+} // namespace
+} // namespace tint::ir
diff --git a/src/tint/ir/instruction.h b/src/tint/ir/instruction.h
index 9a6087e..c19171b 100644
--- a/src/tint/ir/instruction.h
+++ b/src/tint/ir/instruction.h
@@ -52,6 +52,11 @@
/// Removes this instruction from the owning block
void Remove();
+ /// Set an operand at a given index.
+ /// @param index the operand index
+ /// @param value the value to use
+ virtual void SetOperand(uint32_t index, ir::Value* value) = 0;
+
/// Pointer to the next instruction in the list
Instruction* next = nullptr;
/// Pointer to the previous instruction in the list
diff --git a/src/tint/ir/load.cc b/src/tint/ir/load.cc
index 3b64cdd..7761192 100644
--- a/src/tint/ir/load.cc
+++ b/src/tint/ir/load.cc
@@ -19,10 +19,11 @@
namespace tint::ir {
-Load::Load(const type::Type* type, Value* f) : Base(), result_type_(type), from_(f) {
+Load::Load(const type::Type* type, Value* f) : Base(), result_type_(type) {
TINT_ASSERT(IR, result_type_);
- TINT_ASSERT(IR, from_);
- from_->AddUsage(this);
+ TINT_ASSERT(IR, f);
+
+ AddOperand(f);
}
Load::~Load() = default;
diff --git a/src/tint/ir/load.h b/src/tint/ir/load.h
index bee65be..3a33320 100644
--- a/src/tint/ir/load.h
+++ b/src/tint/ir/load.h
@@ -15,13 +15,13 @@
#ifndef SRC_TINT_IR_LOAD_H_
#define SRC_TINT_IR_LOAD_H_
-#include "src/tint/ir/instruction.h"
+#include "src/tint/ir/operand_instruction.h"
#include "src/tint/utils/castable.h"
namespace tint::ir {
/// A load instruction in the IR.
-class Load : public utils::Castable<Load, Instruction> {
+class Load : public utils::Castable<Load, OperandInstruction<1>> {
public:
/// Constructor
/// @param type the result type
@@ -32,12 +32,11 @@
/// @returns the type of the value
const type::Type* Type() const override { return result_type_; }
- /// @returns the avlue being loaded from
- Value* From() const { return from_; }
+ /// @returns the value being loaded from
+ Value* From() const { return operands_[0]; }
private:
- const type::Type* result_type_;
- Value* from_;
+ const type::Type* result_type_ = nullptr;
};
} // namespace tint::ir
diff --git a/src/tint/ir/load_test.cc b/src/tint/ir/load_test.cc
index d552fac..ad21d13 100644
--- a/src/tint/ir/load_test.cc
+++ b/src/tint/ir/load_test.cc
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include "gmock/gmock.h"
+#include "gtest/gtest-spi.h"
#include "src/tint/ir/builder.h"
#include "src/tint/ir/instruction.h"
#include "src/tint/ir/ir_test_helper.h"
@@ -23,7 +25,7 @@
using IR_LoadTest = IRTestHelper;
-TEST_F(IR_LoadTest, CreateLoad) {
+TEST_F(IR_LoadTest, Create) {
auto* store_type = mod.Types().i32();
auto* var = b.Declare(mod.Types().pointer(store_type, builtin::AddressSpace::kFunction,
builtin::Access::kReadWrite));
@@ -38,15 +40,58 @@
EXPECT_EQ(inst->From(), var);
}
-TEST_F(IR_LoadTest, Load_Usage) {
+TEST_F(IR_LoadTest, Usage) {
auto* store_type = mod.Types().i32();
auto* var = b.Declare(mod.Types().pointer(store_type, builtin::AddressSpace::kFunction,
builtin::Access::kReadWrite));
- const auto* inst = b.Load(var);
+ auto* inst = b.Load(var);
ASSERT_NE(inst->From(), nullptr);
- ASSERT_EQ(inst->From()->Usage().Length(), 1u);
- EXPECT_EQ(inst->From()->Usage()[0], inst);
+ EXPECT_THAT(inst->From()->Usages(), testing::UnorderedElementsAre(Usage{inst, 0u}));
+}
+
+TEST_F(IR_LoadTest, Fail_NullType) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+
+ auto* store_type = mod.Types().i32();
+ auto* var = b.Declare(mod.Types().pointer(store_type, builtin::AddressSpace::kFunction,
+ builtin::Access::kReadWrite));
+ Load l(nullptr, var);
+ },
+ "");
+}
+
+TEST_F(IR_LoadTest, Fail_NonPtr_Builder) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.Load(b.Declare(mod.Types().f32()));
+ },
+ "");
+}
+
+TEST_F(IR_LoadTest, Fail_NullValue_Builder) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.Load(nullptr);
+ },
+ "");
+}
+
+TEST_F(IR_LoadTest, Fail_NullValue) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ Load l(mod.Types().f32(), nullptr);
+ },
+ "");
}
} // namespace
diff --git a/src/tint/ir/loop.cc b/src/tint/ir/loop.cc
index 06764c8..37c4f00 100644
--- a/src/tint/ir/loop.cc
+++ b/src/tint/ir/loop.cc
@@ -16,22 +16,37 @@
#include <utility>
+#include "src/tint/ir/block.h"
+
TINT_INSTANTIATE_TYPEINFO(tint::ir::Loop);
namespace tint::ir {
-Loop::Loop(ir::Block* b,
- ir::Block* c,
- ir::Block* m,
- utils::VectorRef<Value*> args /* = utils::Empty */)
- : Base(std::move(args)), body_(b), continuing_(c), merge_(m) {
+Loop::Loop(ir::Block* i, ir::Block* b, ir::Block* c, ir::Block* m)
+ : initializer_(i), body_(b), continuing_(c), merge_(m) {
+ TINT_ASSERT(IR, initializer_);
TINT_ASSERT(IR, body_);
TINT_ASSERT(IR, continuing_);
TINT_ASSERT(IR, merge_);
- body_->AddInboundBranch(this);
+ if (initializer_) {
+ initializer_->SetParent(this);
+ }
+ if (body_) {
+ body_->SetParent(this);
+ }
+ if (continuing_) {
+ continuing_->SetParent(this);
+ }
+ if (merge_) {
+ merge_->SetParent(this);
+ }
}
Loop::~Loop() = default;
+bool Loop::HasInitializer() const {
+ return initializer_->HasBranchTarget();
+}
+
} // namespace tint::ir
diff --git a/src/tint/ir/loop.h b/src/tint/ir/loop.h
index 0ac9463..50c69c7 100644
--- a/src/tint/ir/loop.h
+++ b/src/tint/ir/loop.h
@@ -15,30 +15,73 @@
#ifndef SRC_TINT_IR_LOOP_H_
#define SRC_TINT_IR_LOOP_H_
-#include "src/tint/ir/block.h"
-#include "src/tint/ir/branch.h"
+#include "src/tint/ir/control_instruction.h"
namespace tint::ir {
-/// Flow node describing a loop.
-class Loop : public utils::Castable<Loop, Branch> {
+/// Loop instruction.
+///
+/// ```
+/// in
+/// ┃
+/// ┣━━━━━━━━━━━┓
+/// ▼ ┃
+/// ┌─────────────────┐ ┃
+/// │ Initializer │ ┃
+/// │ (optional) │ ┃
+/// └─────────────────┘ ┃
+/// NextIteration ┃ ┃
+/// ┃◀━━━━━━━━━━┫
+/// ▼ ┃
+/// ┌─────────────────┐ ┃
+/// ┏━━│ Body │ ┃
+/// ┃ └─────────────────┘ ┃
+/// ┃ Continue ┃ ┃ NextIteration
+/// ┃ ▼ ┃
+/// ┃ ┌─────────────────┐ ┃ BreakIf(false)
+/// ExitLoop ┃ │ Continuing │━━┛
+/// │ (optional) │
+/// ┃ └─────────────────┘
+/// ┃ ┃
+/// ┃ ┃ BreakIf(true)
+/// ┗━━━━━━━━━━▶┃
+/// ▼
+/// ┌────────────────┐
+/// │ Merge │
+/// │ (optional) │
+/// └────────────────┘
+/// ┃
+/// ▼
+/// out
+///
+/// ```
+class Loop : public utils::Castable<Loop, ControlInstruction> {
public:
/// Constructor
+ /// @param i the initializer block
/// @param b the body block
/// @param c the continuing block
/// @param m the merge block
- /// @param args the branch arguments
- Loop(ir::Block* b, ir::Block* c, ir::Block* m, utils::VectorRef<Value*> args = utils::Empty);
+ Loop(ir::Block* i, ir::Block* b, ir::Block* c, ir::Block* m);
~Loop() override;
- /// @returns the switch start branch
+ /// @returns the switch initializer block
+ const ir::Block* Initializer() const { return initializer_; }
+ /// @returns the switch initializer block
+ ir::Block* Initializer() { return initializer_; }
+
+ /// @returns true if the loop uses an initializer block. If true, then the Loop first branches
+ /// to the initializer block, otherwise it first branches to the body block.
+ bool HasInitializer() const;
+
+ /// @returns the switch start block
const ir::Block* Body() const { return body_; }
- /// @returns the switch start branch
+ /// @returns the switch start block
ir::Block* Body() { return body_; }
- /// @returns the switch continuing branch
+ /// @returns the switch continuing block
const ir::Block* Continuing() const { return continuing_; }
- /// @returns the switch continuing branch
+ /// @returns the switch continuing block
ir::Block* Continuing() { return continuing_; }
/// @returns the switch merge branch
@@ -47,6 +90,7 @@
ir::Block* Merge() { return merge_; }
private:
+ ir::Block* initializer_ = nullptr;
ir::Block* body_ = nullptr;
ir::Block* continuing_ = nullptr;
ir::Block* merge_ = nullptr;
diff --git a/src/tint/ir/loop_test.cc b/src/tint/ir/loop_test.cc
new file mode 100644
index 0000000..e51a005
--- /dev/null
+++ b/src/tint/ir/loop_test.cc
@@ -0,0 +1,74 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/loop.h"
+#include "gtest/gtest-spi.h"
+#include "src/tint/ir/ir_test_helper.h"
+
+namespace tint::ir {
+namespace {
+
+using namespace tint::number_suffixes; // NOLINT
+using IR_LoopTest = IRTestHelper;
+
+TEST_F(IR_LoopTest, Parent) {
+ auto* loop = b.CreateLoop();
+ EXPECT_EQ(loop->Initializer()->Parent(), loop);
+ EXPECT_EQ(loop->Body()->Parent(), loop);
+ EXPECT_EQ(loop->Continuing()->Parent(), loop);
+ EXPECT_EQ(loop->Merge()->Parent(), loop);
+}
+
+TEST_F(IR_LoopTest, Fail_NullInitializerBlock) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ Loop loop(nullptr, b.CreateBlock(), b.CreateBlock(), b.CreateBlock());
+ },
+ "");
+}
+
+TEST_F(IR_LoopTest, Fail_NullBodyBlock) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ Loop loop(b.CreateBlock(), nullptr, b.CreateBlock(), b.CreateBlock());
+ },
+ "");
+}
+
+TEST_F(IR_LoopTest, Fail_NullContinuingBlock) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ Loop loop(b.CreateBlock(), b.CreateBlock(), nullptr, b.CreateBlock());
+ },
+ "");
+}
+
+TEST_F(IR_LoopTest, Fail_NullMergeBlock) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ Loop loop(b.CreateBlock(), b.CreateBlock(), b.CreateBlock(), nullptr);
+ },
+ "");
+}
+
+} // namespace
+} // namespace tint::ir
diff --git a/src/tint/ir/module.h b/src/tint/ir/module.h
index 1520888..d3393e6 100644
--- a/src/tint/ir/module.h
+++ b/src/tint/ir/module.h
@@ -15,6 +15,7 @@
#ifndef SRC_TINT_IR_MODULE_H_
#define SRC_TINT_IR_MODULE_H_
+#include <memory>
#include <string>
#include "src/tint/constant/manager.h"
@@ -67,6 +68,9 @@
Symbol SetName(const Value* value, std::string_view name);
/// @return the type manager for the module
+ const type::Manager& Types() const { return constant_values.types; }
+
+ /// @return the type manager for the module
type::Manager& Types() { return constant_values.types; }
/// The block allocator
@@ -89,6 +93,9 @@
/// The map of constant::Value to their ir::Constant.
utils::Hashmap<const constant::Value*, ir::Constant*, 16> constants;
+
+ /// If the module generated a validation error, will store the file for the disassembly text.
+ std::unique_ptr<Source::File> disassembly_file;
};
} // namespace tint::ir
diff --git a/src/tint/ir/next_iteration.cc b/src/tint/ir/next_iteration.cc
index 1e057d0..b1bf620 100644
--- a/src/tint/ir/next_iteration.cc
+++ b/src/tint/ir/next_iteration.cc
@@ -16,6 +16,7 @@
#include <utility>
+#include "src/tint/ir/block.h"
#include "src/tint/ir/loop.h"
TINT_INSTANTIATE_TYPEINFO(tint::ir::NextIteration);
@@ -23,10 +24,13 @@
namespace tint::ir {
NextIteration::NextIteration(ir::Loop* loop, utils::VectorRef<Value*> args /* = utils::Empty */)
- : Base(std::move(args)), loop_(loop) {
+ : loop_(loop) {
TINT_ASSERT(IR, loop_);
- loop_->AddUsage(this);
- loop_->Body()->AddInboundBranch(this);
+
+ if (loop_) {
+ loop_->Body()->AddInboundBranch(this);
+ }
+ AddOperands(std::move(args));
}
NextIteration::~NextIteration() = default;
diff --git a/src/tint/ir/next_iteration_test.cc b/src/tint/ir/next_iteration_test.cc
new file mode 100644
index 0000000..33f99da
--- /dev/null
+++ b/src/tint/ir/next_iteration_test.cc
@@ -0,0 +1,46 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/next_iteration.h"
+#include "gtest/gtest-spi.h"
+#include "src/tint/ir/ir_test_helper.h"
+
+namespace tint::ir {
+namespace {
+
+using namespace tint::number_suffixes; // NOLINT
+using IR_NextIterationTest = IRTestHelper;
+
+TEST_F(IR_NextIterationTest, Fail_NullLoop) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.NextIteration(nullptr);
+ },
+ "");
+}
+
+TEST_F(IR_NextIterationTest, Fail_NullValue) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.NextIteration(b.CreateLoop(), utils::Vector<Value*, 1>{nullptr});
+ },
+ "");
+}
+
+} // namespace
+} // namespace tint::ir
diff --git a/src/tint/ir/operand_instruction.cc b/src/tint/ir/operand_instruction.cc
new file mode 100644
index 0000000..7d56ed2
--- /dev/null
+++ b/src/tint/ir/operand_instruction.cc
@@ -0,0 +1,23 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/operand_instruction.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ir::OperandInstruction<1>);
+TINT_INSTANTIATE_TYPEINFO(tint::ir::OperandInstruction<2>);
+TINT_INSTANTIATE_TYPEINFO(tint::ir::OperandInstruction<3>);
+TINT_INSTANTIATE_TYPEINFO(tint::ir::OperandInstruction<4>);
+TINT_INSTANTIATE_TYPEINFO(tint::ir::OperandInstruction<8>);
+
+namespace tint::ir {} // namespace tint::ir
diff --git a/src/tint/ir/operand_instruction.h b/src/tint/ir/operand_instruction.h
new file mode 100644
index 0000000..8a121f7
--- /dev/null
+++ b/src/tint/ir/operand_instruction.h
@@ -0,0 +1,69 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_IR_OPERAND_INSTRUCTION_H_
+#define SRC_TINT_IR_OPERAND_INSTRUCTION_H_
+
+#include "src/tint/ir/instruction.h"
+
+namespace tint::ir {
+
+/// An instruction in the IR that expects one or more operands.
+template <unsigned N>
+class OperandInstruction : public utils::Castable<OperandInstruction<N>, Instruction> {
+ public:
+ /// Destructor
+ ~OperandInstruction() override = default;
+
+ /// Set an operand at a given index.
+ /// @param index the operand index
+ /// @param value the value to use
+ void SetOperand(uint32_t index, ir::Value* value) override {
+ TINT_ASSERT(IR, index < operands_.Length());
+ if (operands_[index]) {
+ operands_[index]->RemoveUsage({this, index});
+ }
+ operands_[index] = value;
+ if (value) {
+ value->AddUsage({this, index});
+ }
+ return;
+ }
+
+ protected:
+ /// Append a new operand to the operand list for this instruction.
+ /// @param value the operand value to append
+ void AddOperand(ir::Value* value) {
+ if (value) {
+ value->AddUsage({this, static_cast<uint32_t>(operands_.Length())});
+ }
+ operands_.Push(value);
+ }
+
+ /// Append a list of non-null operands to the operand list for this instruction.
+ /// @param values the operand values to append
+ void AddOperands(utils::VectorRef<ir::Value*> values) {
+ for (auto* val : values) {
+ TINT_ASSERT(IR, val != nullptr);
+ AddOperand(val);
+ }
+ }
+
+ /// The operands to this instruction.
+ utils::Vector<ir::Value*, N> operands_;
+};
+
+} // namespace tint::ir
+
+#endif // SRC_TINT_IR_OPERAND_INSTRUCTION_H_
diff --git a/src/tint/ir/return.cc b/src/tint/ir/return.cc
index bf7fca7..73bcde0 100644
--- a/src/tint/ir/return.cc
+++ b/src/tint/ir/return.cc
@@ -14,15 +14,21 @@
#include "src/tint/ir/return.h"
+#include <utility>
+
#include "src/tint/ir/function.h"
TINT_INSTANTIATE_TYPEINFO(tint::ir::Return);
namespace tint::ir {
-Return::Return(Function* func, utils::VectorRef<Value*> args) : Base(args), func_(func) {
+Return::Return(Function* func, utils::VectorRef<Value*> args) : func_(func) {
TINT_ASSERT(IR, func_);
- func_->AddUsage(this);
+
+ if (func_) {
+ func_->AddUsage({this, 0u});
+ }
+ AddOperands(std::move(args));
}
Return::~Return() = default;
diff --git a/src/tint/ir/return_test.cc b/src/tint/ir/return_test.cc
new file mode 100644
index 0000000..e7a032e
--- /dev/null
+++ b/src/tint/ir/return_test.cc
@@ -0,0 +1,47 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/return.h"
+#include "gtest/gtest-spi.h"
+#include "src/tint/ir/ir_test_helper.h"
+
+namespace tint::ir {
+namespace {
+
+using namespace tint::number_suffixes; // NOLINT
+using IR_ReturnTest = IRTestHelper;
+
+TEST_F(IR_ReturnTest, Fail_NullFunction) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.Return(nullptr);
+ },
+ "");
+}
+
+TEST_F(IR_ReturnTest, Fail_NullValue) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.Return(b.CreateFunction("myfunc", mod.Types().void_()),
+ utils::Vector<Value*, 1>{nullptr});
+ },
+ "");
+}
+
+} // namespace
+} // namespace tint::ir
diff --git a/src/tint/ir/store.cc b/src/tint/ir/store.cc
index 8b8de46..958f1e4 100644
--- a/src/tint/ir/store.cc
+++ b/src/tint/ir/store.cc
@@ -19,11 +19,12 @@
namespace tint::ir {
-Store::Store(Value* to, Value* from) : Base(), to_(to), from_(from) {
- TINT_ASSERT(IR, to_);
- TINT_ASSERT(IR, from_);
- to_->AddUsage(this);
- from_->AddUsage(this);
+Store::Store(Value* to, Value* from) {
+ TINT_ASSERT(IR, to);
+ TINT_ASSERT(IR, from);
+
+ AddOperand(to);
+ AddOperand(from);
}
Store::~Store() = default;
diff --git a/src/tint/ir/store.h b/src/tint/ir/store.h
index 374fe54..0dee73c 100644
--- a/src/tint/ir/store.h
+++ b/src/tint/ir/store.h
@@ -15,13 +15,13 @@
#ifndef SRC_TINT_IR_STORE_H_
#define SRC_TINT_IR_STORE_H_
-#include "src/tint/ir/instruction.h"
+#include "src/tint/ir/operand_instruction.h"
#include "src/tint/utils/castable.h"
namespace tint::ir {
-/// An instruction in the IR.
-class Store : public utils::Castable<Store, Instruction> {
+/// A store instruction in the IR.
+class Store : public utils::Castable<Store, OperandInstruction<2>> {
public:
/// Constructor
/// @param to the value to store too
@@ -30,14 +30,10 @@
~Store() override;
/// @returns the value being stored too
- Value* To() const { return to_; }
+ Value* To() const { return operands_[0]; }
/// @returns the value being stored
- Value* From() const { return from_; }
-
- private:
- Value* to_;
- Value* from_;
+ Value* From() const { return operands_[1]; }
};
} // namespace tint::ir
diff --git a/src/tint/ir/store_test.cc b/src/tint/ir/store_test.cc
index a1b0d38..a28ecad 100644
--- a/src/tint/ir/store_test.cc
+++ b/src/tint/ir/store_test.cc
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include "gmock/gmock.h"
+#include "gtest/gtest-spi.h"
#include "src/tint/ir/builder.h"
#include "src/tint/ir/instruction.h"
#include "src/tint/ir/ir_test_helper.h"
@@ -24,9 +26,8 @@
using IR_StoreTest = IRTestHelper;
TEST_F(IR_StoreTest, CreateStore) {
- // TODO(dsinclair): This is wrong, but we don't have anything correct to store too at the
- // moment.
- auto* to = b.Discard();
+ auto* to = b.Declare(mod.Types().pointer(mod.Types().i32(), builtin::AddressSpace::kPrivate,
+ builtin::Access::kReadWrite));
const auto* inst = b.Store(to, b.Constant(4_i));
ASSERT_TRUE(inst->Is<Store>());
@@ -40,15 +41,35 @@
TEST_F(IR_StoreTest, Store_Usage) {
auto* to = b.Discard();
- const auto* inst = b.Store(to, b.Constant(4_i));
+ auto* inst = b.Store(to, b.Constant(4_i));
ASSERT_NE(inst->To(), nullptr);
- ASSERT_EQ(inst->To()->Usage().Length(), 1u);
- EXPECT_EQ(inst->To()->Usage()[0], inst);
+ EXPECT_THAT(inst->To()->Usages(), testing::UnorderedElementsAre(Usage{inst, 0u}));
ASSERT_NE(inst->From(), nullptr);
- ASSERT_EQ(inst->From()->Usage().Length(), 1u);
- EXPECT_EQ(inst->From()->Usage()[0], inst);
+ EXPECT_THAT(inst->From()->Usages(), testing::UnorderedElementsAre(Usage{inst, 1u}));
+}
+
+TEST_F(IR_StoreTest, Fail_NullTo) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.Store(nullptr, b.Constant(1_u));
+ },
+ "");
+}
+
+TEST_F(IR_StoreTest, Fail_NullFrom) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ auto* to = b.Declare(mod.Types().pointer(
+ mod.Types().i32(), builtin::AddressSpace::kPrivate, builtin::Access::kReadWrite));
+ b.Store(to, nullptr);
+ },
+ "");
}
} // namespace
diff --git a/src/tint/ir/switch.cc b/src/tint/ir/switch.cc
index 3756f45..68137ac 100644
--- a/src/tint/ir/switch.cc
+++ b/src/tint/ir/switch.cc
@@ -16,12 +16,19 @@
TINT_INSTANTIATE_TYPEINFO(tint::ir::Switch);
+#include "src/tint/ir/block.h"
+
namespace tint::ir {
-Switch::Switch(Value* cond, ir::Block* m) : Base(utils::Empty), condition_(cond), merge_(m) {
- TINT_ASSERT(IR, condition_);
+Switch::Switch(Value* cond, ir::Block* m) : merge_(m) {
+ TINT_ASSERT(IR, cond);
TINT_ASSERT(IR, merge_);
- condition_->AddUsage(this);
+
+ AddOperand(cond);
+
+ if (merge_) {
+ merge_->SetParent(this);
+ }
}
Switch::~Switch() = default;
diff --git a/src/tint/ir/switch.h b/src/tint/ir/switch.h
index d9355f1..b065074 100644
--- a/src/tint/ir/switch.h
+++ b/src/tint/ir/switch.h
@@ -15,15 +15,35 @@
#ifndef SRC_TINT_IR_SWITCH_H_
#define SRC_TINT_IR_SWITCH_H_
-#include "src/tint/ir/block.h"
-#include "src/tint/ir/branch.h"
-#include "src/tint/ir/constant.h"
-#include "src/tint/ir/value.h"
+#include "src/tint/ir/control_instruction.h"
+
+// Forward declarations
+namespace tint::ir {
+class Constant;
+} // namespace tint::ir
namespace tint::ir {
-
-/// Flow node representing a switch statement
-class Switch : public utils::Castable<Switch, Branch> {
+/// Switch instruction.
+///
+/// ```
+/// in
+/// ┃
+/// ╌╌╌╌╌╌╌╌┲━━━━━━━━━━━━━━╋━━━━━━━━━━━━━━┱╌╌╌╌╌╌╌╌
+/// ▼ ▼ ▼
+/// ┌────────┐ ┌────────┐ ┌────────┐
+/// │ Case A │ │ Case B │ │ Case C │
+/// └────────┘ └────────┘ └────────┘
+/// ExitSwitch ┃ ExitSwitch ┃ ExitSwitch ┃
+/// ┃ ▼ ┃
+/// ┃ ┌────────────┐ ┃
+/// ╌╌╌╌╌╌╌╌┺━━━━━━▶│ Merge │◀━━━━━━━┹╌╌╌╌╌╌╌╌
+/// │ (optional) │
+/// └────────────┘
+/// ┃
+/// ▼
+/// out
+/// ```
+class Switch : public utils::Castable<Switch, ControlInstruction> {
public:
/// A case selector
struct CaseSelector {
@@ -63,13 +83,15 @@
/// @returns the switch cases
utils::Vector<Case, 4>& Cases() { return cases_; }
+ /// @returns the branch arguments
+ utils::Slice<Value const* const> Args() const override { return {}; }
+
/// @returns the condition
- const Value* Condition() const { return condition_; }
+ const Value* Condition() const { return operands_[0]; }
/// @returns the condition
- Value* Condition() { return condition_; }
+ Value* Condition() { return operands_[0]; }
private:
- Value* condition_ = nullptr;
ir::Block* merge_ = nullptr;
utils::Vector<Case, 4> cases_;
};
diff --git a/src/tint/ir/switch_test.cc b/src/tint/ir/switch_test.cc
new file mode 100644
index 0000000..cbcf11a
--- /dev/null
+++ b/src/tint/ir/switch_test.cc
@@ -0,0 +1,61 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/switch.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest-spi.h"
+#include "src/tint/ir/ir_test_helper.h"
+
+namespace tint::ir {
+namespace {
+
+using namespace tint::number_suffixes; // NOLINT
+using IR_SwitchTest = IRTestHelper;
+
+TEST_F(IR_SwitchTest, Usage) {
+ auto* cond = b.Constant(true);
+ auto* switch_ = b.CreateSwitch(cond);
+ EXPECT_THAT(cond->Usages(), testing::UnorderedElementsAre(Usage{switch_, 0u}));
+}
+
+TEST_F(IR_SwitchTest, Parent) {
+ auto* switch_ = b.CreateSwitch(b.Constant(1_i));
+ b.CreateCase(switch_, utils::Vector{Switch::CaseSelector{nullptr}});
+ EXPECT_THAT(switch_->Merge()->Parent(), switch_);
+ EXPECT_THAT(switch_->Cases().Front().Start()->Parent(), switch_);
+}
+
+TEST_F(IR_SwitchTest, Fail_NullCondition) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.CreateSwitch(nullptr);
+ },
+ "");
+}
+
+TEST_F(IR_SwitchTest, Fail_NullMergeBlock) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ Switch switch_(b.Constant(false), nullptr);
+ },
+ "");
+}
+
+} // namespace
+} // namespace tint::ir
diff --git a/src/tint/ir/swizzle.cc b/src/tint/ir/swizzle.cc
index 1887d4d..58cecb2 100644
--- a/src/tint/ir/swizzle.cc
+++ b/src/tint/ir/swizzle.cc
@@ -23,8 +23,17 @@
namespace tint::ir {
Swizzle::Swizzle(const type::Type* ty, Value* object, utils::VectorRef<uint32_t> indices)
- : result_type_(ty), object_(object), indices_(std::move(indices)) {
- object_->AddUsage(this);
+ : result_type_(ty), indices_(std::move(indices)) {
+ TINT_ASSERT(IR, object != nullptr);
+ TINT_ASSERT(IR, result_type_ != nullptr);
+ TINT_ASSERT(IR, !indices.IsEmpty());
+ TINT_ASSERT(IR, indices.Length() <= 4);
+
+ AddOperand(object);
+
+ for (auto idx : indices_) {
+ TINT_ASSERT(IR, idx < 4);
+ }
}
Swizzle::~Swizzle() = default;
diff --git a/src/tint/ir/swizzle.h b/src/tint/ir/swizzle.h
index 31c3da1..0dde62d 100644
--- a/src/tint/ir/swizzle.h
+++ b/src/tint/ir/swizzle.h
@@ -15,13 +15,13 @@
#ifndef SRC_TINT_IR_SWIZZLE_H_
#define SRC_TINT_IR_SWIZZLE_H_
-#include "src/tint/ir/instruction.h"
+#include "src/tint/ir/operand_instruction.h"
#include "src/tint/utils/castable.h"
namespace tint::ir {
/// A swizzle instruction in the IR.
-class Swizzle : public utils::Castable<Swizzle, Instruction> {
+class Swizzle : public utils::Castable<Swizzle, OperandInstruction<1>> {
public:
/// Constructor
/// @param result_type the result type
@@ -34,14 +34,13 @@
const type::Type* Type() const override { return result_type_; }
/// @returns the object used for the access
- Value* Object() const { return object_; }
+ Value* Object() const { return operands_[0]; }
/// @returns the swizzle indices
utils::VectorRef<uint32_t> Indices() const { return indices_; }
private:
const type::Type* result_type_ = nullptr;
- Value* object_ = nullptr;
utils::Vector<uint32_t, 4> indices_;
};
diff --git a/src/tint/ir/swizzle_test.cc b/src/tint/ir/swizzle_test.cc
new file mode 100644
index 0000000..64df5c6
--- /dev/null
+++ b/src/tint/ir/swizzle_test.cc
@@ -0,0 +1,98 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/swizzle.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest-spi.h"
+#include "src/tint/ir/ir_test_helper.h"
+
+namespace tint::ir {
+namespace {
+
+using IR_SwizzleTest = IRTestHelper;
+
+TEST_F(IR_SwizzleTest, SetsUsage) {
+ auto* ty = mod.Types().pointer(mod.Types().i32(), builtin::AddressSpace::kFunction,
+ builtin::Access::kReadWrite);
+ auto* var = b.Declare(ty);
+ auto* a = b.Swizzle(mod.Types().i32(), var, utils::Vector{1u});
+
+ EXPECT_THAT(var->Usages(), testing::UnorderedElementsAre(Usage{a, 0u}));
+}
+
+TEST_F(IR_SwizzleTest, Fail_NullType) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ auto* ty = mod.Types().pointer(mod.Types().i32(), builtin::AddressSpace::kFunction,
+ builtin::Access::kReadWrite);
+ auto* var = b.Declare(ty);
+ b.Swizzle(nullptr, var, utils::Vector{1u});
+ },
+ "");
+}
+
+TEST_F(IR_SwizzleTest, Fail_NullObject) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.Swizzle(mod.Types().i32(), nullptr, utils::Vector{1u});
+ },
+ "");
+}
+
+TEST_F(IR_SwizzleTest, Fail_EmptyIndices) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ auto* ty = mod.Types().pointer(mod.Types().i32(), builtin::AddressSpace::kFunction,
+ builtin::Access::kReadWrite);
+ auto* var = b.Declare(ty);
+ b.Swizzle(mod.Types().i32(), var, utils::Empty);
+ },
+ "");
+}
+
+TEST_F(IR_SwizzleTest, Fail_TooManyIndices) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ auto* ty = mod.Types().pointer(mod.Types().i32(), builtin::AddressSpace::kFunction,
+ builtin::Access::kReadWrite);
+ auto* var = b.Declare(ty);
+ b.Swizzle(mod.Types().i32(), var, utils::Vector{1u, 1u, 1u, 1u, 1u});
+ },
+ "");
+}
+
+TEST_F(IR_SwizzleTest, Fail_IndexOutOfRange) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ auto* ty = mod.Types().pointer(mod.Types().i32(), builtin::AddressSpace::kFunction,
+ builtin::Access::kReadWrite);
+ auto* var = b.Declare(ty);
+ b.Swizzle(mod.Types().i32(), var, utils::Vector{4u});
+ },
+ "");
+}
+
+} // namespace
+} // namespace tint::ir
diff --git a/src/tint/ir/to_program.cc b/src/tint/ir/to_program.cc
index 33c9674..b226236 100644
--- a/src/tint/ir/to_program.cc
+++ b/src/tint/ir/to_program.cc
@@ -17,6 +17,7 @@
#include <string>
#include <utility>
+#include "src/tint/ir/binary.h"
#include "src/tint/ir/block.h"
#include "src/tint/ir/call.h"
#include "src/tint/ir/constant.h"
@@ -28,6 +29,7 @@
#include "src/tint/ir/return.h"
#include "src/tint/ir/store.h"
#include "src/tint/ir/switch.h"
+#include "src/tint/ir/unary.h"
#include "src/tint/ir/user_call.h"
#include "src/tint/ir/var.h"
#include "src/tint/program_builder.h"
@@ -92,20 +94,20 @@
const ast::Function* Fn(const Function* fn) {
SCOPED_NESTING();
- auto name = NameOf(fn);
// TODO(crbug.com/tint/1915): Properly implement this when we've fleshed out Function
- utils::Vector<const ast::Parameter*, 1> params{};
+ static constexpr size_t N = decltype(ast::Function::params)::static_length;
+ auto params = utils::Transform<N>(fn->Params(), [&](const ir::FunctionParam* param) {
+ auto name = AssignNameTo(param);
+ auto ty = Type(param->Type());
+ return b.Param(name, ty);
+ });
+
+ auto name = AssignNameTo(fn);
auto ret_ty = Type(fn->ReturnType());
- if (!ret_ty) {
- return nullptr;
- }
auto* body = BlockGraph(fn->StartTarget());
- if (!body) {
- return nullptr;
- }
utils::Vector<const ast::Attribute*, 1> attrs{};
utils::Vector<const ast::Attribute*, 1> ret_attrs{};
- return b.Func(name, std::move(params), ret_ty.Get(), body, std::move(attrs),
+ return b.Func(name, std::move(params), ret_ty, body, std::move(attrs),
std::move(ret_attrs));
}
@@ -123,12 +125,8 @@
TINT_ASSERT(IR, block->HasBranchTarget());
for (auto* inst : *block) {
- auto stmt = Stmt(inst);
- if (TINT_UNLIKELY(!stmt)) {
- return nullptr;
- }
- if (auto* s = stmt.Get()) {
- stmts.Push(s);
+ if (auto* stmt = Stmt(inst)) {
+ stmts.Push(stmt);
}
}
if (auto* if_ = block->Branch()->As<ir::If>()) {
@@ -148,6 +146,35 @@
return b.Block(std::move(stmts));
}
+ ////////////////////////////////////////////////////////////////////////////////////////////////
+ // Statements
+ //
+ // Statement methods may return nullptr, in the case of instructions that do not map to an AST
+ // statement, or in the case of an error. These should simply be ignored.
+ ////////////////////////////////////////////////////////////////////////////////////////////////
+
+ /// @param inst the ir::Instruction
+ /// @return an ast::Statement from @p inst, or nullptr if there was an error
+ const ast::Statement* Stmt(const ir::Instruction* inst) {
+ return tint::Switch(
+ inst, //
+ [&](const ir::Store* i) { return Store(i); }, //
+ [&](const ir::Call* i) { return CallStmt(i); }, //
+ [&](const ir::Var* i) { return Var(i); }, //
+ [&](const ir::If* if_) { return If(if_); }, //
+ [&](const ir::Switch* switch_) { return Switch(switch_); }, //
+ [&](const ir::Return* ret) { return Return(ret); }, //
+ [&](const ir::Value*) { return ValueStmt(inst); },
+ // TODO(dsinclair): Remove when branch is only a parent ...
+ [&](const ir::Branch*) { return nullptr; },
+ [&](Default) {
+ UNHANDLED_CASE(inst);
+ return nullptr;
+ });
+ }
+
+ /// @param i the ir::If
+ /// @return an ast::IfStatement from @p i, or nullptr if there was an error
const ast::IfStatement* If(const ir::If* i) {
SCOPED_NESTING();
auto* cond = Expr(i->Condition());
@@ -181,6 +208,8 @@
return b.If(cond, t);
}
+ /// @param s the ir::Switch
+ /// @return an ast::SwitchStatement from @p s, or nullptr if there was an error
const ast::SwitchStatement* Switch(const ir::Switch* s) {
SCOPED_NESTING();
@@ -223,7 +252,9 @@
return b.Switch(cond, std::move(cases));
}
- utils::Result<const ast::ReturnStatement*> Return(const ir::Return* ret) {
+ /// @param ret the ir::Return
+ /// @return an ast::ReturnStatement from @p ret, or nullptr if there was an error
+ const ast::ReturnStatement* Return(const ir::Return* ret) {
if (ret->Args().IsEmpty()) {
// Return has no arguments.
// If this block is nested withing some control flow, then we must
@@ -239,100 +270,115 @@
if (ret->Args().Length() != 1) {
TINT_ICE(IR, b.Diagnostics())
<< "expected 1 value for return, got " << ret->Args().Length();
- return utils::Failure;
+ return b.Return();
}
auto* val = Expr(ret->Args().Front());
if (TINT_UNLIKELY(!val)) {
- return utils::Failure;
+ return b.Return();
}
return b.Return(val);
}
- utils::Result<const ast::Statement*> Stmt(const ir::Instruction* inst) {
- return tint::Switch<utils::Result<const ast::Statement*>>(
- inst, //
- [&](const ir::Call* i) { return CallStmt(i); }, //
- [&](const ir::Var* i) { return Var(i); }, //
- [&](const ir::Load*) { return nullptr; },
- [&](const ir::Store* i) { return Store(i); }, //
- [&](const ir::If* if_) { return If(if_); },
- [&](const ir::Switch* switch_) { return Switch(switch_); },
- [&](const ir::Return* ret) { return Return(ret); },
- // TODO(dsinclair): Remove when branch is only a parent ...
- [&](const ir::Branch*) { return utils::Result<const ast::Statement*>{nullptr}; },
- [&](Default) {
- UNHANDLED_CASE(inst);
- return utils::Failure;
- });
- }
+ /// @param call the ir::Call
+ /// @return an ast::CallStatement from @p call, or nullptr if there was an error
+ const ast::CallStatement* CallStmt(const ir::Call* call) { return b.CallStmt(Call(call)); }
- const ast::CallStatement* CallStmt(const ir::Call* call) {
- auto* expr = Call(call);
- if (!expr) {
- return nullptr;
- }
- return b.CallStmt(expr);
- }
-
+ /// @param var the ir::Var
+ /// @return an ast::VariableDeclStatement from @p var
const ast::VariableDeclStatement* Var(const ir::Var* var) {
- Symbol name = NameOf(var);
+ Symbol name = AssignNameTo(var);
auto* ptr = var->Type()->As<type::Pointer>();
- if (!ptr) {
- Err("Incorrect type for var");
- return nullptr;
- }
auto ty = Type(ptr->StoreType());
const ast::Expression* init = nullptr;
if (var->Initializer()) {
init = Expr(var->Initializer());
- if (!init) {
- return nullptr;
- }
}
switch (ptr->AddressSpace()) {
case builtin::AddressSpace::kFunction:
- return b.Decl(b.Var(name, ty.Get(), init));
+ return b.Decl(b.Var(name, ty, init));
case builtin::AddressSpace::kStorage:
- return b.Decl(b.Var(name, ty.Get(), init, ptr->Access(), ptr->AddressSpace()));
+ return b.Decl(b.Var(name, ty, init, ptr->Access(), ptr->AddressSpace()));
default:
- return b.Decl(b.Var(name, ty.Get(), init, ptr->AddressSpace()));
+ return b.Decl(b.Var(name, ty, init, ptr->AddressSpace()));
}
}
+ /// @param store the ir::Store
+ /// @return an ast::AssignmentStatement from @p call
const ast::AssignmentStatement* Store(const ir::Store* store) {
auto* expr = Expr(store->From());
- return b.Assign(NameOf(store->To()), expr);
+ return b.Assign(AssignNameTo(store->To()), expr);
}
- const ast::CallExpression* Call(const ir::Call* call) {
- auto args =
- utils::Transform<2>(call->Args(), [&](const ir::Value* arg) { return Expr(arg); });
- if (args.Any(utils::IsNull)) {
- return nullptr;
+ /// @param val the ir::Value
+ /// @return an ast::Statement from @p val, or nullptr if the value does not produce a statement.
+ const ast::Statement* ValueStmt(const ir::Value* val) {
+ // As we're visiting this value's declaration it shouldn't already have a name reserved.
+ TINT_ASSERT(IR, !value_names_.Contains(val));
+
+ // Determine whether the value should be placed into a let, or inlined in its single place
+ // of usage. Currently a value is inlined if it has a single usage and is unnamed.
+ // TODO(crbug.com/tint/1902): This logic needs to check that the sequence of side-effecting
+ // expressions is not changed by inlining the expression. This needs fixing.
+ bool create_let = val->Usages().Count() > 1 || mod.NameOf(val).IsValid();
+ if (create_let) {
+ auto* init = Expr(val); // Must come before giving the value a name
+ auto name = AssignNameTo(val);
+ return b.Decl(b.Let(name, init));
}
- return tint::Switch(
- call, //
- [&](const ir::UserCall* c) { return b.Call(NameOf(c->Func()), std::move(args)); },
- [&](Default) {
- UNHANDLED_CASE(call);
- return nullptr;
- });
+ return nullptr; // Value will be inlined at its place of usage.
}
+ ////////////////////////////////////////////////////////////////////////////////////////////////
+ // Expressions
+ //
+ // The the case of an error:
+ // * The expression generating methods must return a non-null ast expression pointer, which may
+ // not be semantically legal, but is enough to populate the AST.
+ // * A diagnostic error must be added to the ast::ProgramBuilder.
+ // This prevents littering the ToProgram logic with expensive error checking code.
+ ////////////////////////////////////////////////////////////////////////////////////////////////
+
+ /// @param val the ir::Expression
+ /// @return an ast::Expression from @p val.
+ /// @note May be a semantically-invalid placeholder expression on error.
const ast::Expression* Expr(const ir::Value* val) {
+ if (auto name = value_names_.Get(val)) {
+ return b.Expr(name.value());
+ }
+
return tint::Switch(
val, //
[&](const ir::Constant* c) { return ConstExpr(c); },
[&](const ir::Load* l) { return LoadExpr(l); },
- [&](const ir::Var* v) { return VarExpr(v); },
+ [&](const ir::Unary* u) { return UnaryExpr(u); },
+ [&](const ir::Binary* u) { return BinaryExpr(u); },
[&](Default) {
UNHANDLED_CASE(val);
- return nullptr;
+ return b.Expr("<error>");
});
}
+ /// @param call the ir::Call
+ /// @return an ast::CallExpression from @p call.
+ /// @note May be a semantically-invalid placeholder expression on error.
+ const ast::CallExpression* Call(const ir::Call* call) {
+ auto args =
+ utils::Transform<2>(call->Args(), [&](const ir::Value* arg) { return Expr(arg); });
+ return tint::Switch(
+ call, //
+ [&](const ir::UserCall* c) { return b.Call(AssignNameTo(c->Func()), std::move(args)); },
+ [&](Default) {
+ UNHANDLED_CASE(call);
+ return b.Call("<error>");
+ });
+ }
+
+ /// @param c the ir::Constant
+ /// @return an ast::Expression from @p c.
+ /// @note May be a semantically-invalid placeholder expression on error.
const ast::Expression* ConstExpr(const ir::Constant* c) {
return tint::Switch(
c->Type(), //
@@ -343,16 +389,93 @@
[&](const type::Bool*) { return b.Expr(c->Value()->ValueAs<bool>()); },
[&](Default) {
UNHANDLED_CASE(c);
- return nullptr;
+ return b.Expr("<error>");
});
}
+ /// @param l the ir::Load
+ /// @return an ast::Expression from @p l.
+ /// @note May be a semantically-invalid placeholder expression on error.
const ast::Expression* LoadExpr(const ir::Load* l) { return Expr(l->From()); }
- const ast::Expression* VarExpr(const ir::Var* v) { return b.Expr(NameOf(v)); }
+ /// @param u the ir::Unary
+ /// @return an ast::UnaryOpExpression from @p u.
+ /// @note May be a semantically-invalid placeholder expression on error.
+ const ast::Expression* UnaryExpr(const ir::Unary* u) {
+ switch (u->Kind()) {
+ case ir::Unary::Kind::kComplement:
+ return b.Complement(Expr(u->Val()));
+ case ir::Unary::Kind::kNegation:
+ return b.Negation(Expr(u->Val()));
+ }
+ return b.Expr("<error>");
+ }
- utils::Result<ast::Type> Type(const type::Type* ty) {
- return tint::Switch<utils::Result<ast::Type>>(
+ /// @param e the ir::Binary
+ /// @return an ast::BinaryOpExpression from @p e.
+ /// @note May be a semantically-invalid placeholder expression on error.
+ const ast::Expression* BinaryExpr(const ir::Binary* e) {
+ if (e->Kind() == ir::Binary::Kind::kEqual) {
+ auto* rhs = e->RHS()->As<ir::Constant>();
+ if (rhs && rhs->Type()->Is<type::Bool>() && rhs->Value()->ValueAs<bool>() == false) {
+ // expr == false
+ return b.Not(Expr(e->LHS()));
+ }
+ }
+ auto* lhs = Expr(e->LHS());
+ auto* rhs = Expr(e->RHS());
+ switch (e->Kind()) {
+ case ir::Binary::Kind::kAdd:
+ return b.Add(lhs, rhs);
+ case ir::Binary::Kind::kSubtract:
+ return b.Sub(lhs, rhs);
+ case ir::Binary::Kind::kMultiply:
+ return b.Mul(lhs, rhs);
+ case ir::Binary::Kind::kDivide:
+ return b.Div(lhs, rhs);
+ case ir::Binary::Kind::kModulo:
+ return b.Mod(lhs, rhs);
+ case ir::Binary::Kind::kAnd:
+ return b.And(lhs, rhs);
+ case ir::Binary::Kind::kOr:
+ return b.Or(lhs, rhs);
+ case ir::Binary::Kind::kXor:
+ return b.Xor(lhs, rhs);
+ case ir::Binary::Kind::kEqual:
+ return b.Equal(lhs, rhs);
+ case ir::Binary::Kind::kNotEqual:
+ return b.NotEqual(lhs, rhs);
+ case ir::Binary::Kind::kLessThan:
+ return b.LessThan(lhs, rhs);
+ case ir::Binary::Kind::kGreaterThan:
+ return b.GreaterThan(lhs, rhs);
+ case ir::Binary::Kind::kLessThanEqual:
+ return b.LessThanEqual(lhs, rhs);
+ case ir::Binary::Kind::kGreaterThanEqual:
+ return b.GreaterThanEqual(lhs, rhs);
+ case ir::Binary::Kind::kShiftLeft:
+ return b.Shl(lhs, rhs);
+ case ir::Binary::Kind::kShiftRight:
+ return b.Shr(lhs, rhs);
+ }
+ return b.Expr("<error>");
+ }
+
+ ////////////////////////////////////////////////////////////////////////////////////////////////
+ // Types
+ //
+ // The the case of an error:
+ // * The types generating methods must return a non-null ast type, which may not be semantically
+ // legal, but is enough to populate the AST.
+ // * A diagnostic error must be added to the ast::ProgramBuilder.
+ // This prevents littering the ToProgram logic with expensive error checking code.
+ ////////////////////////////////////////////////////////////////////////////////////////////////
+
+ /// @param ty the type::Type
+ /// @return an ast::Type from @p ty.
+ /// @note May be a semantically-invalid placeholder type on error.
+ ast::Type Type(const type::Type* ty) {
+ return tint::Switch(
ty, //
[&](const type::Void*) { return ast::Type{}; }, //
[&](const type::I32*) { return b.ty.i32(); }, //
@@ -360,99 +483,81 @@
[&](const type::F16*) { return b.ty.f16(); }, //
[&](const type::F32*) { return b.ty.f32(); }, //
[&](const type::Bool*) { return b.ty.bool_(); },
- [&](const type::Matrix* m) -> utils::Result<ast::Type> {
- auto el = Type(m->type());
- if (!el) {
- return utils::Failure;
- }
- return b.ty.mat(el.Get(), m->columns(), m->rows());
+ [&](const type::Matrix* m) {
+ return b.ty.mat(Type(m->type()), m->columns(), m->rows());
},
- [&](const type::Vector* v) -> utils::Result<ast::Type> {
+ [&](const type::Vector* v) {
auto el = Type(v->type());
- if (!el) {
- return utils::Failure;
- }
if (v->Packed()) {
TINT_ASSERT(IR, v->Width() == 3u);
- return b.ty(builtin::Builtin::kPackedVec3, el.Get());
+ return b.ty(builtin::Builtin::kPackedVec3, el);
} else {
- return b.ty.vec(el.Get(), v->Width());
+ return b.ty.vec(el, v->Width());
}
},
- [&](const type::Array* a) -> utils::Result<ast::Type> {
+ [&](const type::Array* a) {
auto el = Type(a->ElemType());
- if (!el) {
- return utils::Failure;
- }
utils::Vector<const ast::Attribute*, 1> attrs;
if (!a->IsStrideImplicit()) {
attrs.Push(b.Stride(a->Stride()));
}
if (a->Count()->Is<type::RuntimeArrayCount>()) {
- return b.ty.array(el.Get(), std::move(attrs));
+ return b.ty.array(el, std::move(attrs));
}
auto count = a->ConstantCount();
if (TINT_UNLIKELY(!count)) {
TINT_ICE(IR, b.Diagnostics()) << type::Array::kErrExpectedConstantCount;
- return b.ty.array(el.Get(), u32(1), std::move(attrs));
+ return b.ty.array(el, u32(1), std::move(attrs));
}
- return b.ty.array(el.Get(), u32(count.value()), std::move(attrs));
+ return b.ty.array(el, u32(count.value()), std::move(attrs));
},
[&](const type::Struct* s) { return b.ty(s->Name().NameView()); },
- [&](const type::Atomic* a) -> utils::Result<ast::Type> {
- auto el = Type(a->Type());
- if (!el) {
- return utils::Failure;
- }
- return b.ty.atomic(el.Get());
- },
+ [&](const type::Atomic* a) { return b.ty.atomic(Type(a->Type())); },
[&](const type::DepthTexture* t) { return b.ty.depth_texture(t->dim()); },
[&](const type::DepthMultisampledTexture* t) {
return b.ty.depth_multisampled_texture(t->dim());
},
[&](const type::ExternalTexture*) { return b.ty.external_texture(); },
- [&](const type::MultisampledTexture* t) -> utils::Result<ast::Type> {
+ [&](const type::MultisampledTexture* t) {
auto el = Type(t->type());
- if (!el) {
- return utils::Failure;
- }
- return b.ty.multisampled_texture(t->dim(), el.Get());
+ return b.ty.multisampled_texture(t->dim(), el);
},
- [&](const type::SampledTexture* t) -> utils::Result<ast::Type> {
+ [&](const type::SampledTexture* t) {
auto el = Type(t->type());
- if (!el) {
- return utils::Failure;
- }
- return b.ty.sampled_texture(t->dim(), el.Get());
+ return b.ty.sampled_texture(t->dim(), el);
},
[&](const type::StorageTexture* t) {
return b.ty.storage_texture(t->dim(), t->texel_format(), t->access());
},
[&](const type::Sampler* s) { return b.ty.sampler(s->kind()); },
- [&](const type::Pointer* p) -> utils::Result<ast::Type> {
+ [&](const type::Pointer* p) {
// Note: type::Pointer always has an inferred access, but WGSL only allows an
// explicit access in the 'storage' address space.
auto el = Type(p->StoreType());
- if (!el) {
- return utils::Failure;
- }
auto address_space = p->AddressSpace();
auto access = address_space == builtin::AddressSpace::kStorage
? p->Access()
: builtin::Access::kUndefined;
- return b.ty.pointer(el.Get(), address_space, access);
+ return b.ty.pointer(el, address_space, access);
},
- [&](const type::Reference*) -> utils::Result<ast::Type> {
+ [&](const type::Reference*) {
TINT_ICE(IR, b.Diagnostics()) << "reference types should never appear in the IR";
- return ast::Type{};
+ return b.ty.i32();
},
[&](Default) {
UNHANDLED_CASE(ty);
- return ast::Type{};
+ return b.ty.i32();
});
}
- Symbol NameOf(const Value* value) {
+ ////////////////////////////////////////////////////////////////////////////////////////////////
+ // Helpers
+ ////////////////////////////////////////////////////////////////////////////////////////////////
+
+ /// Creates and returns a new, unique name for the given value, or returns the previously
+ /// created name.
+ /// @return the value's name
+ Symbol AssignNameTo(const Value* value) {
TINT_ASSERT(IR, value);
return value_names_.GetOrCreate(value, [&] {
if (auto sym = mod.NameOf(value)) {
@@ -461,8 +566,6 @@
return b.Symbols().New("v" + std::to_string(value_names_.Count()));
});
}
-
- void Err(std::string str) { b.Diagnostics().add_error(diag::System::IR, std::move(str)); }
};
} // namespace
diff --git a/src/tint/ir/to_program_roundtrip_test.cc b/src/tint/ir/to_program_roundtrip_test.cc
index 48e9c5d..587d999 100644
--- a/src/tint/ir/to_program_roundtrip_test.cc
+++ b/src/tint/ir/to_program_roundtrip_test.cc
@@ -42,9 +42,11 @@
auto output_program = ToProgram(ir_module.Get());
if (!output_program.IsValid()) {
tint::ir::Disassembler d{ir_module.Get()};
- FAIL() << output_program.Diagnostics().str() << std::endl
- << "IR:" << std::endl
- << d.Disassemble();
+ FAIL() << output_program.Diagnostics().str() << std::endl //
+ << "IR:" << std::endl //
+ << d.Disassemble() << std::endl //
+ << "AST:" << std::endl //
+ << Program::printer(&output_program) << std::endl;
}
ASSERT_TRUE(output_program.IsValid()) << output_program.Diagnostics().str();
@@ -94,6 +96,215 @@
)");
}
+TEST_F(IRToProgramRoundtripTest, SingleFunction_Parameters) {
+ Test(R"(
+fn f(i : i32, u : u32) -> i32 {
+ return i;
+}
+)");
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Unary ops
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(IRToProgramRoundtripTest, UnaryOp_Negate) {
+ Test(R"(
+fn f(i : i32) -> i32 {
+ return -(i);
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, UnaryOp_Complement) {
+ Test(R"(
+fn f(i : u32) -> u32 {
+ return ~(i);
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, UnaryOp_Not) {
+ Test(R"(
+fn f(b : bool) -> bool {
+ return !(b);
+}
+)");
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Binary ops
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(IRToProgramRoundtripTest, BinaryOp_Add) {
+ Test(R"(
+fn f(a : i32, b : i32) -> i32 {
+ return (a + b);
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, BinaryOp_Subtract) {
+ Test(R"(
+fn f(a : i32, b : i32) -> i32 {
+ return (a - b);
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, BinaryOp_Multiply) {
+ Test(R"(
+fn f(a : i32, b : i32) -> i32 {
+ return (a * b);
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, BinaryOp_Divide) {
+ Test(R"(
+fn f(a : i32, b : i32) -> i32 {
+ return (a / b);
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, BinaryOp_Modulo) {
+ Test(R"(
+fn f(a : i32, b : i32) -> i32 {
+ return (a % b);
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, BinaryOp_And) {
+ Test(R"(
+fn f(a : i32, b : i32) -> i32 {
+ return (a & b);
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, BinaryOp_Or) {
+ Test(R"(
+fn f(a : i32, b : i32) -> i32 {
+ return (a | b);
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, BinaryOp_Xor) {
+ Test(R"(
+fn f(a : i32, b : i32) -> i32 {
+ return (a ^ b);
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, BinaryOp_Equal) {
+ Test(R"(
+fn f(a : i32, b : i32) -> bool {
+ return (a == b);
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, BinaryOp_NotEqual) {
+ Test(R"(
+fn f(a : i32, b : i32) -> bool {
+ return (a != b);
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, BinaryOp_LessThan) {
+ Test(R"(
+fn f(a : i32, b : i32) -> bool {
+ return (a < b);
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, BinaryOp_GreaterThan) {
+ Test(R"(
+fn f(a : i32, b : i32) -> bool {
+ return (a > b);
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, BinaryOp_LessThanEqual) {
+ Test(R"(
+fn f(a : i32, b : i32) -> bool {
+ return (a <= b);
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, BinaryOp_GreaterThanEqual) {
+ Test(R"(
+fn f(a : i32, b : i32) -> bool {
+ return (a >= b);
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, BinaryOp_ShiftLeft) {
+ Test(R"(
+fn f(a : i32, b : u32) -> i32 {
+ return (a << b);
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, BinaryOp_ShiftRight) {
+ Test(R"(
+fn f(a : i32, b : u32) -> i32 {
+ return (a >> b);
+}
+)");
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Short-circuiting binary ops
+////////////////////////////////////////////////////////////////////////////////
+
+// TODO(crbug.com/tint/1902): Pattern detect this
+TEST_F(IRToProgramRoundtripTest, DISABLED_BinaryOp_LogicalAnd) {
+ Test(R"(
+fn f(a : bool, b : bool) -> bool {
+ return (a && b);
+}
+)");
+}
+
+// TODO(crbug.com/tint/1902): Pattern detect this
+TEST_F(IRToProgramRoundtripTest, DISABLED_BinaryOp_LogicalOr) {
+ Test(R"(
+fn f(a : bool, b : bool) -> bool {
+ return (a && b);
+}
+)");
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// let
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(IRToProgramRoundtripTest, LetUsedOnce) {
+ Test(R"(
+fn f(i : u32) -> u32 {
+ let v = ~(i);
+ return v;
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, LetUsedTwice) {
+ Test(R"(
+fn f(i : i32) -> i32 {
+ let v = (i * 2i);
+ return (v + v);
+}
+)");
+}
+
////////////////////////////////////////////////////////////////////////////////
// Function-scope var
////////////////////////////////////////////////////////////////////////////////
diff --git a/src/tint/ir/transform/block_decorated_structs.cc b/src/tint/ir/transform/block_decorated_structs.cc
new file mode 100644
index 0000000..f9f191c
--- /dev/null
+++ b/src/tint/ir/transform/block_decorated_structs.cc
@@ -0,0 +1,117 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/transform/block_decorated_structs.h"
+
+#include <utility>
+
+#include "src/tint/ir/builder.h"
+#include "src/tint/ir/module.h"
+#include "src/tint/type/pointer.h"
+#include "src/tint/type/struct.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ir::transform::BlockDecoratedStructs);
+
+using namespace tint::number_suffixes; // NOLINT
+
+namespace tint::ir::transform {
+
+BlockDecoratedStructs::BlockDecoratedStructs() = default;
+
+BlockDecoratedStructs::~BlockDecoratedStructs() = default;
+
+void BlockDecoratedStructs::Run(Module* ir, const DataMap&, DataMap&) const {
+ Builder builder(*ir);
+
+ if (!ir->root_block) {
+ return;
+ }
+
+ // Loop over module-scope declarations, looking for storage or uniform buffers.
+ utils::Vector<Var*, 8> buffer_variables;
+ for (auto inst : *ir->root_block) {
+ auto* var = inst->As<Var>();
+ if (!var) {
+ continue;
+ }
+ auto* ptr = var->Type()->As<type::Pointer>();
+ if (!ptr || !(ptr->AddressSpace() == builtin::AddressSpace::kStorage ||
+ ptr->AddressSpace() == builtin::AddressSpace::kUniform)) {
+ continue;
+ }
+ buffer_variables.Push(var);
+ }
+
+ // Now process the buffer variables.
+ for (auto* var : buffer_variables) {
+ auto* ptr = var->Type()->As<type::Pointer>();
+ auto* store_ty = ptr->StoreType();
+
+ bool wrapped = false;
+ utils::Vector<const type::StructMember*, 4> members;
+
+ // Build the member list for the block-decorated structure.
+ if (auto* str = store_ty->As<type::Struct>(); str && !str->HasFixedFootprint()) {
+ // We know the original struct will only ever be used as the store type of a buffer, so
+ // just redeclare it as a block-decorated struct.
+ for (auto* member : str->Members()) {
+ members.Push(member);
+ }
+ } else {
+ // The original struct might be used in other places, so create a new block-decorated
+ // struct that wraps the original struct.
+ members.Push(ir->Types().Get<type::StructMember>(
+ /* name */ ir->symbols.New(),
+ /* type */ store_ty,
+ /* index */ 0u,
+ /* offset */ 0u,
+ /* align */ store_ty->Align(),
+ /* size */ store_ty->Size(),
+ /* attributes */ type::StructMemberAttributes{}));
+ wrapped = true;
+ }
+
+ // Create the block-decorated struct.
+ auto* block_struct = ir->Types().Get<type::Struct>(
+ /* name */ ir->symbols.New(),
+ /* members */ members,
+ /* align */ store_ty->Align(),
+ /* size */ utils::RoundUp(store_ty->Align(), store_ty->Size()),
+ /* size_no_padding */ store_ty->Size());
+ block_struct->SetStructFlag(type::StructFlag::kBlock);
+
+ // Replace the old variable declaration with one that uses the block-decorated struct type.
+ auto* new_var =
+ builder.Declare(ir->Types().pointer(block_struct, ptr->AddressSpace(), ptr->Access()));
+ new_var->SetBindingPoint(var->BindingPoint()->group, var->BindingPoint()->binding);
+ var->ReplaceWith(new_var);
+
+ // Replace uses of the old variable.
+ while (!var->Usages().IsEmpty()) {
+ auto& use = *var->Usages().begin();
+ if (wrapped) {
+ // The structure has been wrapped, so replace all uses of the old variable with a
+ // member accessor on the new variable.
+ auto* access =
+ builder.Access(var->Type(), new_var, utils::Vector{builder.Constant(0_u)});
+ access->InsertBefore(use.instruction);
+ use.instruction->SetOperand(use.operand_index, access);
+ } else {
+ use.instruction->SetOperand(use.operand_index, new_var);
+ }
+ }
+ }
+}
+
+} // namespace tint::ir::transform
diff --git a/src/tint/ir/transform/block_decorated_structs.h b/src/tint/ir/transform/block_decorated_structs.h
new file mode 100644
index 0000000..654f818
--- /dev/null
+++ b/src/tint/ir/transform/block_decorated_structs.h
@@ -0,0 +1,40 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_IR_TRANSFORM_BLOCK_DECORATED_STRUCTS_H_
+#define SRC_TINT_IR_TRANSFORM_BLOCK_DECORATED_STRUCTS_H_
+
+#include "src/tint/ir/transform/transform.h"
+
+#include "src/tint/type/struct.h"
+
+namespace tint::ir::transform {
+
+/// BlockDecoratedStructs is a transform that changes the store type of a buffer to be a special
+/// structure that is recognized as needing a block decoration in SPIR-V, potentially wrapping the
+/// existing store type in a new structure if necessary.
+class BlockDecoratedStructs final : public utils::Castable<BlockDecoratedStructs, Transform> {
+ public:
+ /// Constructor
+ BlockDecoratedStructs();
+ /// Destructor
+ ~BlockDecoratedStructs() override;
+
+ /// @copydoc Transform::Run
+ void Run(ir::Module* module, const DataMap& inputs, DataMap& outputs) const override;
+};
+
+} // namespace tint::ir::transform
+
+#endif // SRC_TINT_IR_TRANSFORM_BLOCK_DECORATED_STRUCTS_H_
diff --git a/src/tint/ir/transform/block_decorated_structs_test.cc b/src/tint/ir/transform/block_decorated_structs_test.cc
new file mode 100644
index 0000000..a308f68
--- /dev/null
+++ b/src/tint/ir/transform/block_decorated_structs_test.cc
@@ -0,0 +1,334 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/transform/block_decorated_structs.h"
+
+#include <utility>
+
+#include "src/tint/ir/transform/test_helper.h"
+#include "src/tint/type/array.h"
+#include "src/tint/type/pointer.h"
+#include "src/tint/type/struct.h"
+
+namespace tint::ir::transform {
+namespace {
+
+using IR_BlockDecoratedStructsTest = TransformTest;
+
+using namespace tint::number_suffixes; // NOLINT
+
+TEST_F(IR_BlockDecoratedStructsTest, NoRootBlock) {
+ auto* func = b.CreateFunction("foo", ty.void_());
+ func->StartTarget()->Append(b.Return(func));
+ mod.functions.Push(func);
+
+ auto* expect = R"(
+%foo = func():void -> %b1 {
+ %b1 = block {
+ ret
+ }
+}
+)";
+
+ Run<BlockDecoratedStructs>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_BlockDecoratedStructsTest, Scalar_Uniform) {
+ auto* buffer = b.Declare(
+ ty.pointer(ty.i32(), builtin::AddressSpace::kUniform, builtin::Access::kReadWrite));
+ buffer->SetBindingPoint(0, 0);
+ b.CreateRootBlockIfNeeded()->Append(buffer);
+
+ auto* func = b.CreateFunction("foo", ty.i32());
+ auto* load = b.Load(buffer);
+ func->StartTarget()->Append(load);
+ func->StartTarget()->Append(b.Return(func, utils::Vector{load}));
+ mod.functions.Push(func);
+
+ auto* expect = R"(
+tint_symbol_1 = struct @align(4), @block {
+ tint_symbol:i32 @offset(0)
+}
+
+# Root block
+%b1 = block {
+ %1:ptr<uniform, tint_symbol_1, read_write> = var @binding_point(0, 0)
+}
+
+%foo = func():i32 -> %b2 {
+ %b2 = block {
+ %3:ptr<uniform, i32, read_write> = access %1, 0u
+ %4:i32 = load %3
+ ret %4
+ }
+}
+)";
+
+ Run<BlockDecoratedStructs>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_BlockDecoratedStructsTest, Scalar_Storage) {
+ auto* buffer = b.Declare(
+ ty.pointer(ty.i32(), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite));
+ buffer->SetBindingPoint(0, 0);
+ b.CreateRootBlockIfNeeded()->Append(buffer);
+
+ auto* func = b.CreateFunction("foo", ty.void_());
+ func->StartTarget()->Append(b.Store(buffer, b.Constant(42_i)));
+ func->StartTarget()->Append(b.Return(func));
+ mod.functions.Push(func);
+
+ auto* expect = R"(
+tint_symbol_1 = struct @align(4), @block {
+ tint_symbol:i32 @offset(0)
+}
+
+# Root block
+%b1 = block {
+ %1:ptr<storage, tint_symbol_1, read_write> = var @binding_point(0, 0)
+}
+
+%foo = func():void -> %b2 {
+ %b2 = block {
+ %3:ptr<storage, i32, read_write> = access %1, 0u
+ store %3, 42i
+ ret
+ }
+}
+)";
+
+ Run<BlockDecoratedStructs>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_BlockDecoratedStructsTest, RuntimeArray) {
+ auto* buffer = b.Declare(ty.pointer(ty.runtime_array(ty.i32()), builtin::AddressSpace::kStorage,
+ builtin::Access::kReadWrite));
+ buffer->SetBindingPoint(0, 0);
+ b.CreateRootBlockIfNeeded()->Append(buffer);
+
+ auto* func = b.CreateFunction("foo", ty.void_());
+ auto* access =
+ b.Access(ty.pointer(ty.i32(), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite),
+ buffer, utils::Vector{b.Constant(1_u)});
+ func->StartTarget()->Append(access);
+ func->StartTarget()->Append(b.Store(access, b.Constant(42_i)));
+ func->StartTarget()->Append(b.Return(func));
+ mod.functions.Push(func);
+
+ auto* expect = R"(
+tint_symbol_1 = struct @align(4), @block {
+ tint_symbol:array<i32> @offset(0)
+}
+
+# Root block
+%b1 = block {
+ %1:ptr<storage, tint_symbol_1, read_write> = var @binding_point(0, 0)
+}
+
+%foo = func():void -> %b2 {
+ %b2 = block {
+ %3:ptr<storage, array<i32>, read_write> = access %1, 0u
+ %4:ptr<storage, i32, read_write> = access %3, 1u
+ store %4, 42i
+ ret
+ }
+}
+)";
+
+ Run<BlockDecoratedStructs>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_BlockDecoratedStructsTest, RuntimeArray_InStruct) {
+ utils::Vector<const type::StructMember*, 4> members;
+ members.Push(ty.Get<type::StructMember>(mod.symbols.New(), ty.i32(), 0u, 0u, 4u, 4u,
+ type::StructMemberAttributes{}));
+ members.Push(ty.Get<type::StructMember>(mod.symbols.New(), ty.runtime_array(ty.i32()), 1u, 4u,
+ 4u, 4u, type::StructMemberAttributes{}));
+ auto* structure = ty.Get<type::Struct>(mod.symbols.New(), members, 4u, 8u, 8u);
+
+ auto* buffer = b.Declare(
+ ty.pointer(structure, builtin::AddressSpace::kStorage, builtin::Access::kReadWrite));
+ buffer->SetBindingPoint(0, 0);
+ b.CreateRootBlockIfNeeded()->Append(buffer);
+
+ auto* i32_ptr =
+ ty.pointer(ty.i32(), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite);
+
+ auto* func = b.CreateFunction("foo", ty.void_());
+ auto* val_ptr = b.Access(i32_ptr, buffer, utils::Vector{b.Constant(0_u)});
+ auto* load = b.Load(val_ptr);
+ auto* elem_ptr = b.Access(i32_ptr, buffer, utils::Vector{b.Constant(1_u), b.Constant(3_u)});
+ func->StartTarget()->Append(val_ptr);
+ func->StartTarget()->Append(load);
+ func->StartTarget()->Append(elem_ptr);
+ func->StartTarget()->Append(b.Store(elem_ptr, load));
+ func->StartTarget()->Append(b.Return(func));
+ mod.functions.Push(func);
+
+ auto* expect = R"(
+tint_symbol_2 = struct @align(4) {
+ tint_symbol:i32 @offset(0)
+ tint_symbol_1:array<i32> @offset(4)
+}
+
+tint_symbol_3 = struct @align(4), @block {
+ tint_symbol:i32 @offset(0)
+ tint_symbol_1:array<i32> @offset(4)
+}
+
+# Root block
+%b1 = block {
+ %1:ptr<storage, tint_symbol_3, read_write> = var @binding_point(0, 0)
+}
+
+%foo = func():void -> %b2 {
+ %b2 = block {
+ %3:ptr<storage, i32, read_write> = access %1, 0u
+ %4:i32 = load %3
+ %5:ptr<storage, i32, read_write> = access %1, 1u, 3u
+ store %5, %4
+ ret
+ }
+}
+)";
+
+ Run<BlockDecoratedStructs>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_BlockDecoratedStructsTest, StructUsedElsewhere) {
+ utils::Vector<const type::StructMember*, 4> members;
+ members.Push(ty.Get<type::StructMember>(mod.symbols.New(), ty.i32(), 0u, 0u, 4u, 4u,
+ type::StructMemberAttributes{}));
+ members.Push(ty.Get<type::StructMember>(mod.symbols.New(), ty.i32(), 1u, 4u, 4u, 4u,
+ type::StructMemberAttributes{}));
+ auto* structure = ty.Get<type::Struct>(mod.symbols.New(), members, 4u, 8u, 8u);
+
+ auto* buffer = b.Declare(
+ ty.pointer(structure, builtin::AddressSpace::kStorage, builtin::Access::kReadWrite));
+ buffer->SetBindingPoint(0, 0);
+ b.CreateRootBlockIfNeeded()->Append(buffer);
+
+ auto* private_var = b.Declare(
+ ty.pointer(structure, builtin::AddressSpace::kPrivate, builtin::Access::kReadWrite));
+ b.CreateRootBlockIfNeeded()->Append(private_var);
+
+ auto* func = b.CreateFunction("foo", ty.void_());
+ func->StartTarget()->Append(b.Store(buffer, private_var));
+ func->StartTarget()->Append(b.Return(func));
+ mod.functions.Push(func);
+
+ auto* expect = R"(
+tint_symbol_2 = struct @align(4) {
+ tint_symbol:i32 @offset(0)
+ tint_symbol_1:i32 @offset(4)
+}
+
+tint_symbol_4 = struct @align(4), @block {
+ tint_symbol_3:tint_symbol_2 @offset(0)
+}
+
+# Root block
+%b1 = block {
+ %1:ptr<storage, tint_symbol_4, read_write> = var @binding_point(0, 0)
+ %2:ptr<private, tint_symbol_2, read_write> = var
+}
+
+%foo = func():void -> %b2 {
+ %b2 = block {
+ %4:ptr<storage, tint_symbol_2, read_write> = access %1, 0u
+ store %4, %2
+ ret
+ }
+}
+)";
+
+ Run<BlockDecoratedStructs>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_BlockDecoratedStructsTest, MultipleBuffers) {
+ auto* buffer_a = b.Declare(
+ ty.pointer(ty.i32(), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite));
+ auto* buffer_b = b.Declare(
+ ty.pointer(ty.i32(), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite));
+ auto* buffer_c = b.Declare(
+ ty.pointer(ty.i32(), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite));
+ buffer_a->SetBindingPoint(0, 0);
+ buffer_b->SetBindingPoint(0, 1);
+ buffer_c->SetBindingPoint(0, 2);
+ auto* root = b.CreateRootBlockIfNeeded();
+ root->Append(buffer_a);
+ root->Append(buffer_b);
+ root->Append(buffer_c);
+
+ auto* func = b.CreateFunction("foo", ty.void_());
+ auto* load_b = b.Load(buffer_b);
+ auto* load_c = b.Load(buffer_c);
+ func->StartTarget()->Append(load_b);
+ func->StartTarget()->Append(load_c);
+ func->StartTarget()->Append(b.Store(buffer_a, b.Add(ty.i32(), load_b, load_c)));
+ func->StartTarget()->Append(b.Return(func));
+ mod.functions.Push(func);
+
+ auto* expect = R"(
+tint_symbol_1 = struct @align(4), @block {
+ tint_symbol:i32 @offset(0)
+}
+
+tint_symbol_3 = struct @align(4), @block {
+ tint_symbol_2:i32 @offset(0)
+}
+
+tint_symbol_5 = struct @align(4), @block {
+ tint_symbol_4:i32 @offset(0)
+}
+
+# Root block
+%b1 = block {
+ %1:ptr<storage, tint_symbol_1, read_write> = var @binding_point(0, 0)
+ %2:ptr<storage, tint_symbol_3, read_write> = var @binding_point(0, 1)
+ %3:ptr<storage, tint_symbol_5, read_write> = var @binding_point(0, 2)
+}
+
+%foo = func():void -> %b2 {
+ %b2 = block {
+ %5:ptr<storage, i32, read_write> = access %2, 0u
+ %6:i32 = load %5
+ %7:ptr<storage, i32, read_write> = access %3, 0u
+ %8:i32 = load %7
+ %9:ptr<storage, i32, read_write> = access %1, 0u
+ store %9, %10
+ ret
+ }
+}
+)";
+
+ Run<BlockDecoratedStructs>();
+
+ EXPECT_EQ(expect, str());
+}
+
+} // namespace
+} // namespace tint::ir::transform
diff --git a/src/tint/ir/transform/test_helper.h b/src/tint/ir/transform/test_helper.h
index a122eb2..83ff8c2 100644
--- a/src/tint/ir/transform/test_helper.h
+++ b/src/tint/ir/transform/test_helper.h
@@ -57,6 +57,8 @@
ir::Module mod;
/// The test IR builder.
ir::Builder b{mod};
+ /// The type manager.
+ type::Manager& ty{mod.Types()};
private:
std::vector<std::unique_ptr<Source::File>> files_;
diff --git a/src/tint/ir/transform/var_for_dynamic_index.cc b/src/tint/ir/transform/var_for_dynamic_index.cc
new file mode 100644
index 0000000..fa063e7
--- /dev/null
+++ b/src/tint/ir/transform/var_for_dynamic_index.cc
@@ -0,0 +1,174 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/transform/var_for_dynamic_index.h"
+
+#include <utility>
+
+#include "src/tint/ir/builder.h"
+#include "src/tint/ir/module.h"
+#include "src/tint/switch.h"
+#include "src/tint/type/array.h"
+#include "src/tint/type/matrix.h"
+#include "src/tint/type/pointer.h"
+#include "src/tint/type/struct.h"
+#include "src/tint/type/vector.h"
+#include "src/tint/utils/hashmap.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ir::transform::VarForDynamicIndex);
+
+using namespace tint::number_suffixes; // NOLINT
+
+namespace tint::ir::transform {
+
+namespace {
+// An access that needs replacing.
+struct AccessToReplace {
+ // The access instruction.
+ Access* access = nullptr;
+ // The index of the first dynamic index.
+ uint32_t first_dynamic_index = 0;
+ // The object type that corresponds to the source of the first dynamic index.
+ const type::Type* dynamic_index_source_type = nullptr;
+};
+
+// A partial access chain that uses constant indices to get to an object that will be
+// dynamically indexed.
+struct PartialAccess {
+ // The base object.
+ Value* base = nullptr;
+ // The list of constant indices to get from the base to the source object.
+ utils::Vector<Value*, 4> indices;
+
+ // A specialization of utils::Hasher for PartialAccess.
+ struct Hasher {
+ inline std::size_t operator()(const PartialAccess& src) const {
+ return utils::Hash(src.base, src.indices);
+ }
+ };
+
+ // An equality helper for PartialAccess.
+ bool operator==(const PartialAccess& other) const {
+ return base == other.base && indices == other.indices;
+ }
+};
+} // namespace
+
+VarForDynamicIndex::VarForDynamicIndex() = default;
+
+VarForDynamicIndex::~VarForDynamicIndex() = default;
+
+static std::optional<AccessToReplace> ShouldReplace(Access* access) {
+ AccessToReplace to_replace{access, 0, access->Object()->Type()};
+
+ // Find the first dynamic index, if any.
+ bool has_dynamic_index = false;
+ for (auto* idx : access->Indices()) {
+ if (to_replace.dynamic_index_source_type->Is<type::Vector>()) {
+ // Stop if we hit a vector, as they can support dynamic accesses.
+ break;
+ }
+
+ // Check if the index is dynamic.
+ auto* const_idx = idx->As<Constant>();
+ if (!const_idx) {
+ has_dynamic_index = true;
+ break;
+ }
+ to_replace.first_dynamic_index++;
+
+ // Update the current object type.
+ to_replace.dynamic_index_source_type = tint::Switch(
+ to_replace.dynamic_index_source_type, //
+ [&](const type::Array* arr) { return arr->ElemType(); },
+ [&](const type::Matrix* mat) { return mat->ColumnType(); },
+ [&](const type::Struct* str) {
+ return str->Members()[const_idx->Value()->ValueAs<u32>()]->Type();
+ },
+ [&](const type::Vector* vec) { return vec->type(); }, //
+ [&](Default) { return nullptr; });
+ }
+ if (!has_dynamic_index) {
+ // No need to modify accesses that only use constant indices.
+ return {};
+ }
+
+ return to_replace;
+}
+
+void VarForDynamicIndex::Run(ir::Module* ir, const DataMap&, DataMap&) const {
+ ir::Builder builder(*ir);
+
+ // Find the access instructions that need replacing.
+ utils::Vector<AccessToReplace, 4> worklist;
+ for (auto* inst : ir->values.Objects()) {
+ auto* access = inst->As<Access>();
+ if (access && !access->Type()->Is<type::Pointer>()) {
+ if (auto to_replace = ShouldReplace(access)) {
+ worklist.Push(to_replace.value());
+ }
+ }
+ }
+
+ // Replace each access instruction that we recorded.
+ utils::Hashmap<Value*, Value*, 4> object_to_local;
+ utils::Hashmap<PartialAccess, Value*, 4, PartialAccess::Hasher> source_object_to_value;
+ for (const auto& to_replace : worklist) {
+ auto* access = to_replace.access;
+ Value* source_object = access->Object();
+
+ // If the access starts with at least one constant index, extract the source of the first
+ // dynamic access to avoid copying the whole object.
+ if (to_replace.first_dynamic_index > 0) {
+ PartialAccess partial_access = {
+ access->Object(), access->Indices().Truncate(to_replace.first_dynamic_index)};
+ source_object = source_object_to_value.GetOrCreate(partial_access, [&]() {
+ auto* intermediate_source = builder.Access(to_replace.dynamic_index_source_type,
+ source_object, partial_access.indices);
+ intermediate_source->InsertBefore(access);
+ return intermediate_source;
+ });
+ }
+
+ // Declare a local variable and copy the source object to it.
+ auto* local = object_to_local.GetOrCreate(source_object, [&]() {
+ auto* decl = builder.Declare(ir->Types().pointer(to_replace.dynamic_index_source_type,
+ builtin::AddressSpace::kFunction,
+ builtin::Access::kReadWrite));
+ decl->SetInitializer(source_object);
+ decl->InsertBefore(access);
+ return decl;
+ });
+
+ // Create a new access instruction using the local variable as the source.
+ utils::Vector<Value*, 4> indices{access->Indices().Offset(to_replace.first_dynamic_index)};
+ auto* new_access =
+ builder.Access(ir->Types().pointer(access->Type(), builtin::AddressSpace::kFunction,
+ builtin::Access::kReadWrite),
+ local, indices);
+ access->ReplaceWith(new_access);
+
+ // Load from the access to get the final result value.
+ auto* load = builder.Load(new_access);
+ load->InsertAfter(new_access);
+
+ // Replace all uses of the old access instruction with the loaded result.
+ while (!access->Usages().IsEmpty()) {
+ auto& use = *access->Usages().begin();
+ use.instruction->SetOperand(use.operand_index, load);
+ }
+ }
+}
+
+} // namespace tint::ir::transform
diff --git a/src/tint/ir/transform/var_for_dynamic_index.h b/src/tint/ir/transform/var_for_dynamic_index.h
new file mode 100644
index 0000000..1f86b7d
--- /dev/null
+++ b/src/tint/ir/transform/var_for_dynamic_index.h
@@ -0,0 +1,39 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_IR_TRANSFORM_VAR_FOR_DYNAMIC_INDEX_H_
+#define SRC_TINT_IR_TRANSFORM_VAR_FOR_DYNAMIC_INDEX_H_
+
+#include "src/tint/ir/transform/transform.h"
+
+namespace tint::ir::transform {
+
+/// VarForDynamicIndex is a transform that copies array and matrix values that are dynamically
+/// indexed to a temporary local `var` before performing the index. This transform is used by the
+/// SPIR-V writer as there is no SPIR-V instruction that can dynamically index a non-pointer
+/// composite.
+class VarForDynamicIndex final : public utils::Castable<VarForDynamicIndex, Transform> {
+ public:
+ /// Constructor
+ VarForDynamicIndex();
+ /// Destructor
+ ~VarForDynamicIndex() override;
+
+ /// @copydoc Transform::Run
+ void Run(ir::Module* module, const DataMap& inputs, DataMap& outputs) const override;
+};
+
+} // namespace tint::ir::transform
+
+#endif // SRC_TINT_IR_TRANSFORM_VAR_FOR_DYNAMIC_INDEX_H_
diff --git a/src/tint/ir/transform/var_for_dynamic_index_test.cc b/src/tint/ir/transform/var_for_dynamic_index_test.cc
new file mode 100644
index 0000000..0f6af06
--- /dev/null
+++ b/src/tint/ir/transform/var_for_dynamic_index_test.cc
@@ -0,0 +1,428 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/transform/var_for_dynamic_index.h"
+
+#include <utility>
+
+#include "src/tint/ir/transform/test_helper.h"
+#include "src/tint/type/array.h"
+#include "src/tint/type/matrix.h"
+#include "src/tint/type/struct.h"
+
+namespace tint::ir::transform {
+namespace {
+
+using namespace tint::number_suffixes; // NOLINT
+
+class IR_VarForDynamicIndexTest : public TransformTest {
+ protected:
+ const type::Type* ptr(const type::Type* elem) {
+ return ty.pointer(elem, builtin::AddressSpace::kFunction, builtin::Access::kReadWrite);
+ }
+};
+
+TEST_F(IR_VarForDynamicIndexTest, NoModify_ConstantIndex_ArrayValue) {
+ auto* arr = b.FunctionParam(ty.array(ty.i32(), 4u));
+ auto* func = b.CreateFunction("foo", ty.i32());
+ func->SetParams(utils::Vector{arr});
+
+ auto* access = b.Access(ty.i32(), arr, utils::Vector{b.Constant(1_i)});
+ func->StartTarget()->Append(access);
+ func->StartTarget()->Append(b.Return(func, utils::Vector{access}));
+ mod.functions.Push(func);
+
+ auto* expect = R"(
+%foo = func(%2:array<i32, 4>):i32 -> %b1 {
+ %b1 = block {
+ %3:i32 = access %2, 1i
+ ret %3
+ }
+}
+)";
+
+ Run<VarForDynamicIndex>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_VarForDynamicIndexTest, NoModify_ConstantIndex_MatrixValue) {
+ auto* mat = b.FunctionParam(ty.mat2x2(ty.f32()));
+ auto* func = b.CreateFunction("foo", ty.f32());
+ func->SetParams(utils::Vector{mat});
+
+ auto* access = b.Access(ty.f32(), mat, utils::Vector{b.Constant(1_i), b.Constant(0_i)});
+ func->StartTarget()->Append(access);
+ func->StartTarget()->Append(b.Return(func, utils::Vector{access}));
+ mod.functions.Push(func);
+
+ auto* expect = R"(
+%foo = func(%2:mat2x2<f32>):f32 -> %b1 {
+ %b1 = block {
+ %3:f32 = access %2, 1i, 0i
+ ret %3
+ }
+}
+)";
+
+ Run<VarForDynamicIndex>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_VarForDynamicIndexTest, NoModify_DynamicIndex_ArrayPointer) {
+ auto* arr = b.FunctionParam(ptr(ty.array(ty.i32(), 4u)));
+ auto* idx = b.FunctionParam(ty.i32());
+ auto* func = b.CreateFunction("foo", ty.i32());
+ func->SetParams(utils::Vector{arr, idx});
+
+ auto* access = b.Access(ptr(ty.i32()), arr, utils::Vector{idx});
+ auto* load = b.Load(access);
+ func->StartTarget()->Append(access);
+ func->StartTarget()->Append(load);
+ func->StartTarget()->Append(b.Return(func, utils::Vector{load}));
+ mod.functions.Push(func);
+
+ auto* expect = R"(
+%foo = func(%2:ptr<function, array<i32, 4>, read_write>, %3:i32):i32 -> %b1 {
+ %b1 = block {
+ %4:ptr<function, i32, read_write> = access %2, %3
+ %5:i32 = load %4
+ ret %5
+ }
+}
+)";
+
+ Run<VarForDynamicIndex>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_VarForDynamicIndexTest, NoModify_DynamicIndex_MatrixPointer) {
+ auto* mat = b.FunctionParam(ptr(ty.mat2x2(ty.f32())));
+ auto* idx = b.FunctionParam(ty.i32());
+ auto* func = b.CreateFunction("foo", ty.f32());
+ func->SetParams(utils::Vector{mat, idx});
+
+ auto* access = b.Access(ptr(ty.f32()), mat, utils::Vector{idx, idx});
+ auto* load = b.Load(access);
+ func->StartTarget()->Append(access);
+ func->StartTarget()->Append(load);
+ func->StartTarget()->Append(b.Return(func, utils::Vector{load}));
+ mod.functions.Push(func);
+
+ auto* expect = R"(
+%foo = func(%2:ptr<function, mat2x2<f32>, read_write>, %3:i32):f32 -> %b1 {
+ %b1 = block {
+ %4:ptr<function, f32, read_write> = access %2, %3, %3
+ %5:f32 = load %4
+ ret %5
+ }
+}
+)";
+
+ Run<VarForDynamicIndex>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_VarForDynamicIndexTest, NoModify_DynamicIndex_VectorValue) {
+ auto* vec = b.FunctionParam(ty.vec4(ty.f32()));
+ auto* idx = b.FunctionParam(ty.i32());
+ auto* func = b.CreateFunction("foo", ty.f32());
+ func->SetParams(utils::Vector{vec, idx});
+
+ auto* access = b.Access(ty.f32(), vec, utils::Vector{idx});
+ func->StartTarget()->Append(access);
+ func->StartTarget()->Append(b.Return(func, utils::Vector{access}));
+ mod.functions.Push(func);
+
+ auto* expect = R"(
+%foo = func(%2:vec4<f32>, %3:i32):f32 -> %b1 {
+ %b1 = block {
+ %4:f32 = access %2, %3
+ ret %4
+ }
+}
+)";
+
+ Run<VarForDynamicIndex>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_VarForDynamicIndexTest, DynamicIndex_ArrayValue) {
+ auto* arr = b.FunctionParam(ty.array(ty.i32(), 4u));
+ auto* idx = b.FunctionParam(ty.i32());
+ auto* func = b.CreateFunction("foo", ty.i32());
+ func->SetParams(utils::Vector{arr, idx});
+
+ auto* access = b.Access(ty.i32(), arr, utils::Vector{idx});
+ func->StartTarget()->Append(access);
+ func->StartTarget()->Append(b.Return(func, utils::Vector{access}));
+ mod.functions.Push(func);
+
+ auto* expect = R"(
+%foo = func(%2:array<i32, 4>, %3:i32):i32 -> %b1 {
+ %b1 = block {
+ %4:ptr<function, array<i32, 4>, read_write> = var, %2
+ %5:ptr<function, i32, read_write> = access %4, %3
+ %6:i32 = load %5
+ ret %6
+ }
+}
+)";
+
+ Run<VarForDynamicIndex>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_VarForDynamicIndexTest, DynamicIndex_MatrixValue) {
+ auto* arr = b.FunctionParam(ty.mat2x2(ty.f32()));
+ auto* idx = b.FunctionParam(ty.i32());
+ auto* func = b.CreateFunction("foo", ty.f32());
+ func->SetParams(utils::Vector{arr, idx});
+
+ auto* access = b.Access(ty.f32(), arr, utils::Vector{idx});
+ func->StartTarget()->Append(access);
+ func->StartTarget()->Append(b.Return(func, utils::Vector{access}));
+ mod.functions.Push(func);
+
+ auto* expect = R"(
+%foo = func(%2:mat2x2<f32>, %3:i32):f32 -> %b1 {
+ %b1 = block {
+ %4:ptr<function, mat2x2<f32>, read_write> = var, %2
+ %5:ptr<function, f32, read_write> = access %4, %3
+ %6:f32 = load %5
+ ret %6
+ }
+}
+)";
+
+ Run<VarForDynamicIndex>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_VarForDynamicIndexTest, AccessChain) {
+ auto* arr = b.FunctionParam(ty.array(ty.array(ty.array(ty.i32(), 4u), 4u), 4u));
+ auto* idx = b.FunctionParam(ty.i32());
+ auto* func = b.CreateFunction("foo", ty.i32());
+ func->SetParams(utils::Vector{arr, idx});
+
+ auto* access = b.Access(ty.i32(), arr, utils::Vector{idx, b.Constant(1_u), idx});
+ func->StartTarget()->Append(access);
+ func->StartTarget()->Append(b.Return(func, utils::Vector{access}));
+ mod.functions.Push(func);
+
+ auto* expect = R"(
+%foo = func(%2:array<array<array<i32, 4>, 4>, 4>, %3:i32):i32 -> %b1 {
+ %b1 = block {
+ %4:ptr<function, array<array<array<i32, 4>, 4>, 4>, read_write> = var, %2
+ %5:ptr<function, i32, read_write> = access %4, %3, 1u, %3
+ %6:i32 = load %5
+ ret %6
+ }
+}
+)";
+
+ Run<VarForDynamicIndex>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_VarForDynamicIndexTest, AccessChain_SkipConstantIndices) {
+ auto* arr = b.FunctionParam(ty.array(ty.array(ty.array(ty.i32(), 4u), 4u), 4u));
+ auto* idx = b.FunctionParam(ty.i32());
+ auto* func = b.CreateFunction("foo", ty.i32());
+ func->SetParams(utils::Vector{arr, idx});
+
+ auto* access = b.Access(ty.i32(), arr, utils::Vector{b.Constant(1_u), b.Constant(2_u), idx});
+ func->StartTarget()->Append(access);
+ func->StartTarget()->Append(b.Return(func, utils::Vector{access}));
+ mod.functions.Push(func);
+
+ auto* expect = R"(
+%foo = func(%2:array<array<array<i32, 4>, 4>, 4>, %3:i32):i32 -> %b1 {
+ %b1 = block {
+ %4:array<i32, 4> = access %2, 1u, 2u
+ %5:ptr<function, array<i32, 4>, read_write> = var, %4
+ %6:ptr<function, i32, read_write> = access %5, %3
+ %7:i32 = load %6
+ ret %7
+ }
+}
+)";
+
+ Run<VarForDynamicIndex>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_VarForDynamicIndexTest, AccessChain_SkipConstantIndices_Interleaved) {
+ auto* arr = b.FunctionParam(ty.array(ty.array(ty.array(ty.array(ty.i32(), 4u), 4u), 4u), 4u));
+ auto* idx = b.FunctionParam(ty.i32());
+ auto* func = b.CreateFunction("foo", ty.i32());
+ func->SetParams(utils::Vector{arr, idx});
+
+ auto* access =
+ b.Access(ty.i32(), arr, utils::Vector{b.Constant(1_u), idx, b.Constant(2_u), idx});
+ func->StartTarget()->Append(access);
+ func->StartTarget()->Append(b.Return(func, utils::Vector{access}));
+ mod.functions.Push(func);
+
+ auto* expect = R"(
+%foo = func(%2:array<array<array<array<i32, 4>, 4>, 4>, 4>, %3:i32):i32 -> %b1 {
+ %b1 = block {
+ %4:array<array<array<i32, 4>, 4>, 4> = access %2, 1u
+ %5:ptr<function, array<array<array<i32, 4>, 4>, 4>, read_write> = var, %4
+ %6:ptr<function, i32, read_write> = access %5, %3, 2u, %3
+ %7:i32 = load %6
+ ret %7
+ }
+}
+)";
+
+ Run<VarForDynamicIndex>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_VarForDynamicIndexTest, AccessChain_SkipConstantIndices_Struct) {
+ auto* str_ty = ty.Get<type::Struct>(
+ mod.symbols.Register("MyStruct"),
+ utils::Vector{
+ ty.Get<type::StructMember>(mod.symbols.Register("arr1"), ty.array(ty.f32(), 1024u), 0u,
+ 0u, 4u, 4096u, type::StructMemberAttributes{}),
+ ty.Get<type::StructMember>(mod.symbols.Register("mat"), ty.mat4x4(ty.f32()), 1u, 4096u,
+ 16u, 64u, type::StructMemberAttributes{}),
+ ty.Get<type::StructMember>(mod.symbols.Register("arr2"), ty.array(ty.f32(), 1024u), 2u,
+ 4160u, 4u, 4096u, type::StructMemberAttributes{}),
+ },
+ 16u, 32u, 32u);
+ auto* str_val = b.FunctionParam(str_ty);
+ auto* idx = b.FunctionParam(ty.i32());
+ auto* func = b.CreateFunction("foo", ty.f32());
+ func->SetParams(utils::Vector{str_val, idx});
+
+ auto* access =
+ b.Access(ty.f32(), str_val, utils::Vector{b.Constant(1_u), idx, b.Constant(0_u)});
+ func->StartTarget()->Append(access);
+ func->StartTarget()->Append(b.Return(func, utils::Vector{access}));
+ mod.functions.Push(func);
+
+ auto* expect = R"(
+MyStruct = struct @align(16) {
+ arr1:array<f32, 1024> @offset(0)
+ mat:mat4x4<f32> @offset(4096)
+ arr2:array<f32, 1024> @offset(4160)
+}
+
+%foo = func(%2:MyStruct, %3:i32):f32 -> %b1 {
+ %b1 = block {
+ %4:mat4x4<f32> = access %2, 1u
+ %5:ptr<function, mat4x4<f32>, read_write> = var, %4
+ %6:ptr<function, f32, read_write> = access %5, %3, 0u
+ %7:f32 = load %6
+ ret %7
+ }
+}
+)";
+
+ Run<VarForDynamicIndex>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_VarForDynamicIndexTest, MultipleAccessesFromSameSource) {
+ auto* arr = b.FunctionParam(ty.array(ty.i32(), 4u));
+ auto* idx_a = b.FunctionParam(ty.i32());
+ auto* idx_b = b.FunctionParam(ty.i32());
+ auto* idx_c = b.FunctionParam(ty.i32());
+ auto* func = b.CreateFunction("foo", ty.i32());
+ func->SetParams(utils::Vector{arr, idx_a, idx_b, idx_c});
+
+ auto* access_a = b.Access(ty.i32(), arr, utils::Vector{idx_a});
+ auto* access_b = b.Access(ty.i32(), arr, utils::Vector{idx_b});
+ auto* access_c = b.Access(ty.i32(), arr, utils::Vector{idx_c});
+ func->StartTarget()->Append(access_a);
+ func->StartTarget()->Append(access_b);
+ func->StartTarget()->Append(access_c);
+ func->StartTarget()->Append(b.Return(func, utils::Vector{access_c}));
+ mod.functions.Push(func);
+
+ auto* expect = R"(
+%foo = func(%2:array<i32, 4>, %3:i32, %4:i32, %5:i32):i32 -> %b1 {
+ %b1 = block {
+ %6:ptr<function, array<i32, 4>, read_write> = var, %2
+ %7:ptr<function, i32, read_write> = access %6, %3
+ %8:i32 = load %7
+ %9:ptr<function, i32, read_write> = access %6, %4
+ %10:i32 = load %9
+ %11:ptr<function, i32, read_write> = access %6, %5
+ %12:i32 = load %11
+ ret %12
+ }
+}
+)";
+
+ Run<VarForDynamicIndex>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_VarForDynamicIndexTest, MultipleAccessesFromSameSource_SkipConstantIndices) {
+ auto* arr = b.FunctionParam(ty.array(ty.array(ty.array(ty.i32(), 4u), 4u), 4u));
+ auto* idx_a = b.FunctionParam(ty.i32());
+ auto* idx_b = b.FunctionParam(ty.i32());
+ auto* idx_c = b.FunctionParam(ty.i32());
+ auto* func = b.CreateFunction("foo", ty.i32());
+ func->SetParams(utils::Vector{arr, idx_a, idx_b, idx_c});
+
+ auto* access_a =
+ b.Access(ty.i32(), arr, utils::Vector{b.Constant(1_u), b.Constant(2_u), idx_a});
+ auto* access_b =
+ b.Access(ty.i32(), arr, utils::Vector{b.Constant(1_u), b.Constant(2_u), idx_b});
+ auto* access_c =
+ b.Access(ty.i32(), arr, utils::Vector{b.Constant(1_u), b.Constant(2_u), idx_c});
+ func->StartTarget()->Append(access_a);
+ func->StartTarget()->Append(access_b);
+ func->StartTarget()->Append(access_c);
+ func->StartTarget()->Append(b.Return(func, utils::Vector{access_c}));
+ mod.functions.Push(func);
+
+ auto* expect = R"(
+%foo = func(%2:array<array<array<i32, 4>, 4>, 4>, %3:i32, %4:i32, %5:i32):i32 -> %b1 {
+ %b1 = block {
+ %6:array<i32, 4> = access %2, 1u, 2u
+ %7:ptr<function, array<i32, 4>, read_write> = var, %6
+ %8:ptr<function, i32, read_write> = access %7, %3
+ %9:i32 = load %8
+ %10:ptr<function, i32, read_write> = access %7, %4
+ %11:i32 = load %10
+ %12:ptr<function, i32, read_write> = access %7, %5
+ %13:i32 = load %12
+ ret %13
+ }
+}
+)";
+
+ Run<VarForDynamicIndex>();
+
+ EXPECT_EQ(expect, str());
+}
+
+} // namespace
+} // namespace tint::ir::transform
diff --git a/src/tint/ir/unary.cc b/src/tint/ir/unary.cc
index 2a0ee94..90509e0 100644
--- a/src/tint/ir/unary.cc
+++ b/src/tint/ir/unary.cc
@@ -19,10 +19,11 @@
namespace tint::ir {
-Unary::Unary(enum Kind k, const type::Type* res_ty, Value* val)
- : kind_(k), result_type_(res_ty), val_(val) {
- TINT_ASSERT(IR, val_);
- val_->AddUsage(this);
+Unary::Unary(enum Kind k, const type::Type* res_ty, Value* val) : kind_(k), result_type_(res_ty) {
+ TINT_ASSERT(IR, val != nullptr);
+ TINT_ASSERT(IR, result_type_ != nullptr);
+
+ AddOperand(val);
}
Unary::~Unary() = default;
diff --git a/src/tint/ir/unary.h b/src/tint/ir/unary.h
index 698413d..0ac191a 100644
--- a/src/tint/ir/unary.h
+++ b/src/tint/ir/unary.h
@@ -15,13 +15,13 @@
#ifndef SRC_TINT_IR_UNARY_H_
#define SRC_TINT_IR_UNARY_H_
-#include "src/tint/ir/instruction.h"
+#include "src/tint/ir/operand_instruction.h"
#include "src/tint/utils/castable.h"
namespace tint::ir {
-/// An instruction in the IR.
-class Unary : public utils::Castable<Unary, Instruction> {
+/// A unary instruction in the IR.
+class Unary : public utils::Castable<Unary, OperandInstruction<1>> {
public:
/// The kind of instruction.
enum class Kind {
@@ -40,17 +40,16 @@
const type::Type* Type() const override { return result_type_; }
/// @returns the value for the instruction
- const Value* Val() const { return val_; }
+ const Value* Val() const { return operands_[0]; }
/// @returns the value for the instruction
- Value* Val() { return val_; }
+ Value* Val() { return operands_[0]; }
/// @returns the kind of unary instruction
enum Kind Kind() const { return kind_; }
private:
enum Kind kind_;
- const type::Type* result_type_;
- Value* val_;
+ const type::Type* result_type_ = nullptr;
};
} // namespace tint::ir
diff --git a/src/tint/ir/unary_test.cc b/src/tint/ir/unary_test.cc
index e04b808..2f2307e 100644
--- a/src/tint/ir/unary_test.cc
+++ b/src/tint/ir/unary_test.cc
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include "gmock/gmock.h"
+#include "gtest/gtest-spi.h"
#include "src/tint/ir/builder.h"
#include "src/tint/ir/instruction.h"
#include "src/tint/ir/ir_test_helper.h"
@@ -53,8 +55,27 @@
EXPECT_EQ(inst->Kind(), Unary::Kind::kNegation);
ASSERT_NE(inst->Val(), nullptr);
- ASSERT_EQ(inst->Val()->Usage().Length(), 1u);
- EXPECT_EQ(inst->Val()->Usage()[0], inst);
+ EXPECT_THAT(inst->Val()->Usages(), testing::UnorderedElementsAre(Usage{inst, 0u}));
+}
+
+TEST_F(IR_UnaryTest, Fail_NullType) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.Negation(nullptr, b.Constant(1_i));
+ },
+ "");
+}
+
+TEST_F(IR_UnaryTest, Fail_NullValue) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.Negation(mod.Types().i32(), nullptr);
+ },
+ "");
}
} // namespace
diff --git a/src/tint/ir/user_call.cc b/src/tint/ir/user_call.cc
index 74f3365..e963e14 100644
--- a/src/tint/ir/user_call.cc
+++ b/src/tint/ir/user_call.cc
@@ -23,8 +23,11 @@
namespace tint::ir {
UserCall::UserCall(const type::Type* ty, Function* func, utils::VectorRef<Value*> arguments)
- : Base(ty, std::move(arguments)), func_(func) {
- func->AddUsage(this);
+ : Base(ty) {
+ TINT_ASSERT(IR, func);
+
+ AddOperand(func);
+ AddOperands(std::move(arguments));
}
UserCall::~UserCall() = default;
diff --git a/src/tint/ir/user_call.h b/src/tint/ir/user_call.h
index 7cc35ae..d165715 100644
--- a/src/tint/ir/user_call.h
+++ b/src/tint/ir/user_call.h
@@ -31,11 +31,15 @@
UserCall(const type::Type* type, Function* func, utils::VectorRef<Value*> args);
~UserCall() override;
+ /// @returns the call arguments
+ utils::Slice<Value const* const> Args() const override {
+ return operands_.Slice().Offset(1).Reinterpret<Value const* const>();
+ }
+
/// @returns the called function name
- const Function* Func() const { return func_; }
+ const Function* Func() const { return operands_.Front()->As<ir::Function>(); }
private:
- const Function* func_ = nullptr;
};
} // namespace tint::ir
diff --git a/src/tint/ir/user_call_test.cc b/src/tint/ir/user_call_test.cc
new file mode 100644
index 0000000..2fd9151
--- /dev/null
+++ b/src/tint/ir/user_call_test.cc
@@ -0,0 +1,69 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/user_call.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest-spi.h"
+#include "src/tint/ir/ir_test_helper.h"
+
+namespace tint::ir {
+namespace {
+
+using namespace tint::number_suffixes; // NOLINT
+using IR_UserCallTest = IRTestHelper;
+
+TEST_F(IR_UserCallTest, Usage) {
+ auto* func = b.CreateFunction("myfunc", mod.Types().void_());
+ auto* arg1 = b.Constant(1_u);
+ auto* arg2 = b.Constant(2_u);
+ auto* e = b.UserCall(mod.Types().void_(), func, utils::Vector{arg1, arg2});
+ EXPECT_THAT(func->Usages(), testing::UnorderedElementsAre(Usage{e, 0u}));
+ EXPECT_THAT(arg1->Usages(), testing::UnorderedElementsAre(Usage{e, 1u}));
+ EXPECT_THAT(arg2->Usages(), testing::UnorderedElementsAre(Usage{e, 2u}));
+}
+
+TEST_F(IR_UserCallTest, Fail_NullType) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.UserCall(nullptr, b.CreateFunction("myfunc", mod.Types().void_()));
+ },
+ "");
+}
+
+TEST_F(IR_UserCallTest, Fail_NullFunction) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.UserCall(mod.Types().f32(), nullptr);
+ },
+ "");
+}
+
+TEST_F(IR_UserCallTest, Fail_NullArg) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.UserCall(mod.Types().void_(), b.CreateFunction("myfunc", mod.Types().void_()),
+ utils::Vector<Value*, 1>{nullptr});
+ },
+ "");
+}
+
+} // namespace
+} // namespace tint::ir
diff --git a/src/tint/ir/validate.cc b/src/tint/ir/validate.cc
index 5e28173..029ed8f 100644
--- a/src/tint/ir/validate.cc
+++ b/src/tint/ir/validate.cc
@@ -14,23 +14,46 @@
#include "src/tint/ir/validate.h"
+#include <memory>
+#include <string>
#include <utility>
+#include "src/tint/ir/access.h"
+#include "src/tint/ir/binary.h"
+#include "src/tint/ir/bitcast.h"
+#include "src/tint/ir/break_if.h"
+#include "src/tint/ir/builtin.h"
+#include "src/tint/ir/construct.h"
+#include "src/tint/ir/continue.h"
+#include "src/tint/ir/convert.h"
+#include "src/tint/ir/disassembler.h"
+#include "src/tint/ir/discard.h"
+#include "src/tint/ir/exit_if.h"
+#include "src/tint/ir/exit_loop.h"
+#include "src/tint/ir/exit_switch.h"
#include "src/tint/ir/function.h"
#include "src/tint/ir/if.h"
+#include "src/tint/ir/load.h"
#include "src/tint/ir/loop.h"
+#include "src/tint/ir/next_iteration.h"
#include "src/tint/ir/return.h"
+#include "src/tint/ir/store.h"
#include "src/tint/ir/switch.h"
+#include "src/tint/ir/swizzle.h"
+#include "src/tint/ir/unary.h"
+#include "src/tint/ir/user_call.h"
#include "src/tint/ir/var.h"
#include "src/tint/switch.h"
+#include "src/tint/type/bool.h"
#include "src/tint/type/pointer.h"
+#include "src/tint/utils/scoped_assignment.h"
namespace tint::ir {
namespace {
class Validator {
public:
- explicit Validator(const Module& mod) : mod_(mod) {}
+ explicit Validator(Module& mod) : mod_(mod) {}
~Validator() {}
@@ -42,16 +65,74 @@
}
if (diagnostics_.contains_errors()) {
+ // If a diassembly file was generated then one of the diagnostics referenced the
+ // disasembly. Emit the entire disassembly file at the end of the messages.
+ if (mod_.disassembly_file) {
+ diagnostics_.add_note(tint::diag::System::IR,
+ "# Disassembly\n" + mod_.disassembly_file->content.data, {});
+ }
return std::move(diagnostics_);
}
return Success{};
}
private:
- const Module& mod_;
+ Module& mod_;
diag::List diagnostics_;
+ Disassembler dis_{mod_};
- void AddError(const std::string& err) { diagnostics_.add_error(tint::diag::System::IR, err); }
+ const Block* current_block_ = nullptr;
+
+ void DisassembleIfNeeded() {
+ if (mod_.disassembly_file) {
+ return;
+ }
+ mod_.disassembly_file = std::make_unique<Source::File>("", dis_.Disassemble());
+ }
+
+ void AddError(const Instruction* inst, const std::string& err) {
+ DisassembleIfNeeded();
+ auto src = dis_.InstructionSource(inst);
+ src.file = mod_.disassembly_file.get();
+ AddError(err, src);
+
+ if (current_block_) {
+ AddNote(current_block_, "In block");
+ }
+ }
+
+ void AddError(const Instruction* inst, uint32_t idx, const std::string& err) {
+ DisassembleIfNeeded();
+ auto src = dis_.OperandSource(Disassembler::Operand{inst, idx});
+ src.file = mod_.disassembly_file.get();
+ AddError(err, src);
+
+ if (current_block_) {
+ AddNote(current_block_, "In block");
+ }
+ }
+
+ void AddError(const Block* blk, const std::string& err) {
+ DisassembleIfNeeded();
+ auto src = dis_.BlockSource(blk);
+ src.file = mod_.disassembly_file.get();
+ AddError(err, src);
+ }
+
+ void AddNote(const Block* blk, const std::string& err) {
+ DisassembleIfNeeded();
+ auto src = dis_.BlockSource(blk);
+ src.file = mod_.disassembly_file.get();
+ AddNote(err, src);
+ }
+
+ void AddError(const std::string& err, Source src = {}) {
+ diagnostics_.add_error(tint::diag::System::IR, err, src);
+ }
+
+ void AddNote(const std::string& note, Source src = {}) {
+ diagnostics_.add_note(tint::diag::System::IR, note, src);
+ }
std::string Name(const Value* v) { return mod_.NameOf(v).Name(); }
@@ -60,42 +141,34 @@
return;
}
+ TINT_SCOPED_ASSIGNMENT(current_block_, blk);
+
for (const auto* inst : *blk) {
auto* var = inst->As<ir::Var>();
if (!var) {
- AddError(std::string("root block: invalid instruction: ") + inst->TypeInfo().name);
+ AddError(inst,
+ std::string("root block: invalid instruction: ") + inst->TypeInfo().name);
continue;
}
if (!var->Type()->Is<type::Pointer>()) {
- AddError(std::string("root block: 'var' ") + Name(var) +
- "type is not a pointer: " + var->Type()->TypeInfo().name);
+ AddError(inst, std::string("root block: 'var' ") + Name(var) +
+ "type is not a pointer: " + var->Type()->TypeInfo().name);
}
}
}
- void CheckFunction(const Function* func) {
- for (const auto* param : func->Params()) {
- if (param == nullptr) {
- AddError("function '" + Name(func) + "': null parameter");
- continue;
- }
- }
-
- if (func->StartTarget() == nullptr) {
- AddError("function '" + Name(func) + "': null start target");
- } else {
- CheckBlock(func->StartTarget());
- }
- }
+ void CheckFunction(const Function* func) { CheckBlock(func->StartTarget()); }
void CheckBlock(const Block* blk) {
+ TINT_SCOPED_ASSIGNMENT(current_block_, blk);
+
if (!blk->HasBranchTarget()) {
- AddError("block: does not end in a branch");
+ AddError(blk, "block: does not end in a branch");
}
for (const auto* inst : *blk) {
if (inst->Is<ir::Branch>() && inst != blk->Branch()) {
- AddError("block: branch which isn't the final instruction");
+ AddError(inst, "block: branch which isn't the final instruction");
continue;
}
@@ -105,27 +178,72 @@
void CheckInstruction(const Instruction* inst) {
tint::Switch(
- inst, //
+ inst, //
+ [&](const ir::Access*) {}, //
+ [&](const ir::Binary*) {}, //
+ [&](const ir::Branch* b) { CheckBranch(b); }, //
+ [&](const ir::Call* c) { CheckCall(c); }, //
+ [&](const ir::Load*) {}, //
+ [&](const ir::Store*) {}, //
+ [&](const ir::Swizzle*) {}, //
+ [&](const ir::Unary*) {}, //
+ [&](const ir::Var*) {}, //
+ [&](Default) {
+ AddError(std::string("missing validation of: ") + inst->TypeInfo().name);
+ });
+ }
+
+ void CheckCall(const ir::Call* call) {
+ tint::Switch(
+ call, //
+ [&](const ir::Bitcast*) {}, //
+ [&](const ir::Builtin*) {}, //
+ [&](const ir::Construct*) {}, //
+ [&](const ir::Convert*) {}, //
+ [&](const ir::Discard*) {}, //
+ [&](const ir::UserCall*) {}, //
+ [&](Default) {
+ AddError(std::string("missing validation of call: ") + call->TypeInfo().name);
+ });
+ }
+
+ void CheckBranch(const ir::Branch* b) {
+ tint::Switch(
+ b, //
+ [&](const ir::BreakIf*) {}, //
+ [&](const ir::Continue*) {}, //
+ [&](const ir::ExitIf*) {}, //
+ [&](const ir::ExitLoop*) {}, //
+ [&](const ir::ExitSwitch*) {}, //
+ [&](const ir::If* if_) { CheckIf(if_); }, //
+ [&](const ir::Loop*) {}, //
+ [&](const ir::NextIteration*) {}, //
[&](const ir::Return* ret) {
if (ret->Func() == nullptr) {
AddError("return: null function");
}
- },
+ }, //
+ [&](const ir::Switch*) {}, //
[&](Default) {
- AddError(std::string("missing validation of: ") + inst->TypeInfo().name);
+ AddError(std::string("missing validation of branch: ") + b->TypeInfo().name);
});
}
+
+ void CheckIf(const ir::If* if_) {
+ if (!if_->Condition()) {
+ AddError(if_, "if: condition is nullptr");
+ }
+ if (if_->Condition() && !if_->Condition()->Type()->Is<type::Bool>()) {
+ AddError(if_, If::kConditionOperandIndex, "if: condition must be a `bool` type");
+ }
+ }
};
} // namespace
-utils::Result<Success, std::string> Validate(const Module& mod) {
+utils::Result<Success, diag::List> Validate(Module& mod) {
Validator v(mod);
- auto r = v.IsValid();
- if (!r) {
- return r.Failure().str();
- }
- return Success{};
+ return v.IsValid();
}
} // namespace tint::ir
diff --git a/src/tint/ir/validate.h b/src/tint/ir/validate.h
index 233295b..350dd57 100644
--- a/src/tint/ir/validate.h
+++ b/src/tint/ir/validate.h
@@ -15,8 +15,7 @@
#ifndef SRC_TINT_IR_VALIDATE_H_
#define SRC_TINT_IR_VALIDATE_H_
-#include <string>
-
+#include "src/tint/diagnostic/diagnostic.h"
#include "src/tint/ir/module.h"
#include "src/tint/utils/result.h"
@@ -28,7 +27,7 @@
/// Validates that a given IR module is correctly formed
/// @param mod the module to validate
/// @returns true on success, an error result otherwise
-utils::Result<Success, std::string> Validate(const Module& mod);
+utils::Result<Success, diag::List> Validate(Module& mod);
} // namespace tint::ir
diff --git a/src/tint/ir/validate_test.cc b/src/tint/ir/validate_test.cc
index 9c3676c..0fab76f 100644
--- a/src/tint/ir/validate_test.cc
+++ b/src/tint/ir/validate_test.cc
@@ -30,16 +30,52 @@
mod.root_block->Append(b.Declare(mod.Types().pointer(
mod.Types().i32(), builtin::AddressSpace::kPrivate, builtin::Access::kReadWrite)));
auto res = ir::Validate(mod);
- EXPECT_TRUE(res) << res.Failure();
+ EXPECT_TRUE(res) << res.Failure().str();
}
TEST_F(IR_ValidateTest, RootBlock_NonVar) {
+ auto* l = b.CreateLoop();
+ l->Body()->Append(b.Continue(l));
+
mod.root_block = b.CreateRootBlockIfNeeded();
- mod.root_block->Append(b.CreateLoop());
+ mod.root_block->Append(l);
auto res = ir::Validate(mod);
ASSERT_FALSE(res);
- EXPECT_EQ(res.Failure(), "error: root block: invalid instruction: tint::ir::Loop");
+ EXPECT_EQ(res.Failure().str(), R"(:3:3 error: root block: invalid instruction: tint::ir::Loop
+ loop [b: %b2]
+ ^^^^^^^^^^^^^
+
+:2:1 note: In block
+%b1 = block {
+^^^^^^^^^^^^^
+ loop [b: %b2]
+^^^^^^^^^^^^^^^
+ # Body block
+^^^^^^^^^^^^^^^^
+ %b2 = block {
+^^^^^^^^^^^^^^^^^
+ continue %b3
+^^^^^^^^^^^^^^^^^^
+ }
+^^^^^
+
+
+}
+^
+
+note: # Disassembly
+# Root block
+%b1 = block {
+ loop [b: %b2]
+ # Body block
+ %b2 = block {
+ continue %b3
+ }
+
+}
+
+)");
}
TEST_F(IR_ValidateTest, RootBlock_VarBadType) {
@@ -47,7 +83,26 @@
mod.root_block->Append(b.Declare(mod.Types().i32()));
auto res = ir::Validate(mod);
ASSERT_FALSE(res);
- EXPECT_EQ(res.Failure(), "error: root block: 'var' type is not a pointer: tint::type::I32");
+ EXPECT_EQ(res.Failure().str(),
+ R"(:3:12 error: root block: 'var' type is not a pointer: tint::type::I32
+ %1:i32 = var
+ ^^^
+
+:2:1 note: In block
+%b1 = block {
+^^^^^^^^^^^^^
+ %1:i32 = var
+^^^^^^^^^^^^^^
+}
+^
+
+note: # Disassembly
+# Root block
+%b1 = block {
+ %1:i32 = var
+}
+
+)");
}
TEST_F(IR_ValidateTest, Function) {
@@ -58,28 +113,7 @@
utils::Vector{b.FunctionParam(mod.Types().i32()), b.FunctionParam(mod.Types().f32())});
f->StartTarget()->SetInstructions(utils::Vector{b.Return(f)});
auto res = ir::Validate(mod);
- EXPECT_TRUE(res) << res.Failure();
-}
-
-TEST_F(IR_ValidateTest, Function_NullStartTarget) {
- auto* f = b.CreateFunction("my_func", mod.Types().void_());
- mod.functions.Push(f);
-
- f->SetStartTarget(nullptr);
- auto res = ir::Validate(mod);
- ASSERT_FALSE(res);
- EXPECT_EQ(res.Failure(), "error: function 'my_func': null start target");
-}
-
-TEST_F(IR_ValidateTest, Function_ParamNull) {
- auto* f = b.CreateFunction("my_func", mod.Types().void_());
- mod.functions.Push(f);
-
- f->SetParams(utils::Vector<FunctionParam*, 1>{nullptr});
- f->StartTarget()->SetInstructions(utils::Vector{b.Return(f)});
- auto res = ir::Validate(mod);
- ASSERT_FALSE(res);
- EXPECT_EQ(res.Failure(), "error: function 'my_func': null parameter");
+ EXPECT_TRUE(res) << res.Failure().str();
}
TEST_F(IR_ValidateTest, Block_NoBranchAtEnd) {
@@ -88,7 +122,18 @@
auto res = ir::Validate(mod);
ASSERT_FALSE(res);
- EXPECT_EQ(res.Failure(), "error: block: does not end in a branch");
+ EXPECT_EQ(res.Failure().str(), R"(:2:1 error: block: does not end in a branch
+ %b1 = block {
+^^^^^^^^^^^^^^^
+ }
+^^^
+
+note: # Disassembly
+%my_func = func():void -> %b1 {
+ %b1 = block {
+ }
+}
+)");
}
TEST_F(IR_ValidateTest, Block_BranchInMiddle) {
@@ -98,7 +143,91 @@
f->StartTarget()->SetInstructions(utils::Vector{b.Return(f), b.Return(f)});
auto res = ir::Validate(mod);
ASSERT_FALSE(res);
- EXPECT_EQ(res.Failure(), "error: block: branch which isn't the final instruction");
+ EXPECT_EQ(res.Failure().str(), R"(:3:5 error: block: branch which isn't the final instruction
+ ret
+ ^^^
+
+:2:1 note: In block
+ %b1 = block {
+^^^^^^^^^^^^^^^
+ ret
+^^^^^^^
+ ret
+^^^^^^^
+ }
+^^^
+
+note: # Disassembly
+%my_func = func():void -> %b1 {
+ %b1 = block {
+ ret
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_ValidateTest, If_ConditionIsBool) {
+ auto* f = b.CreateFunction("my_func", mod.Types().void_());
+ mod.functions.Push(f);
+
+ auto* if_ = b.CreateIf(b.Constant(1_i));
+ if_->True()->Append(b.Return(f));
+ if_->False()->Append(b.Return(f));
+
+ f->StartTarget()->Append(if_);
+
+ auto res = ir::Validate(mod);
+ ASSERT_FALSE(res);
+ EXPECT_EQ(res.Failure().str(), R"(:3:8 error: if: condition must be a `bool` type
+ if 1i [t: %b2, f: %b3]
+ ^^
+
+:2:1 note: In block
+ %b1 = block {
+^^^^^^^^^^^^^^^
+ if 1i [t: %b2, f: %b3]
+^^^^^^^^^^^^^^^^^^^^^^^^^^
+ # True block
+^^^^^^^^^^^^^^^^^^
+ %b2 = block {
+^^^^^^^^^^^^^^^^^^^
+ ret
+^^^^^^^^^^^
+ }
+^^^^^^^
+
+
+ # False block
+^^^^^^^^^^^^^^^^^^^
+ %b3 = block {
+^^^^^^^^^^^^^^^^^^^
+ ret
+^^^^^^^^^^^
+ }
+^^^^^^^
+
+
+ }
+^^^
+
+note: # Disassembly
+%my_func = func():void -> %b1 {
+ %b1 = block {
+ if 1i [t: %b2, f: %b3]
+ # True block
+ %b2 = block {
+ ret
+ }
+
+ # False block
+ %b3 = block {
+ ret
+ }
+
+ }
+}
+)");
}
} // namespace
diff --git a/src/tint/ir/value.h b/src/tint/ir/value.h
index b090ad8..3166168 100644
--- a/src/tint/ir/value.h
+++ b/src/tint/ir/value.h
@@ -17,7 +17,7 @@
#include "src/tint/type/type.h"
#include "src/tint/utils/castable.h"
-#include "src/tint/utils/unique_vector.h"
+#include "src/tint/utils/hashset.h"
// Forward declarations
namespace tint::ir {
@@ -26,19 +26,47 @@
namespace tint::ir {
+/// A specific usage of a Value in the IR.
+struct Usage {
+ /// The instruction that is using the value;
+ Instruction* instruction = nullptr;
+ /// The index of the operand that is the value being used.
+ uint32_t operand_index = 0u;
+
+ /// A specialization of utils::Hasher for Usage.
+ struct Hasher {
+ /// @param u the usage to hash
+ /// @returns a hash of the usage
+ inline std::size_t operator()(const Usage& u) const {
+ return utils::Hash(u.instruction, u.operand_index);
+ }
+ };
+
+ /// An equality helper for Usage.
+ /// @param other the usage to compare against
+ /// @returns true if the two usages are equal
+ bool operator==(const Usage& other) const {
+ return instruction == other.instruction && operand_index == other.operand_index;
+ }
+};
+
/// Value in the IR.
class Value : public utils::Castable<Value> {
public:
/// Destructor
~Value() override;
- /// Adds an instruction which uses this value.
- /// @param inst the instruction
- void AddUsage(const Instruction* inst) { uses_.Add(inst); }
+ /// Adds a usage of this value.
+ /// @param u the usage
+ void AddUsage(Usage u) { uses_.Add(u); }
- /// @returns the vector of instructions which use this value. An instruction will only be
- /// returned once even if that instruction uses the given value multiple times.
- utils::VectorRef<const Instruction*> Usage() const { return uses_; }
+ /// Remove a usage of this value.
+ /// @param u the usage
+ void RemoveUsage(Usage u) { uses_.Remove(u); }
+
+ /// @returns the set of usages of this value. An instruction may appear multiple times if it
+ /// uses the value for multiple different operands.
+ const utils::Hashset<Usage, 4, Usage::Hasher>& Usages() const { return uses_; }
/// @returns the type of the value
virtual const type::Type* Type() const { return nullptr; }
@@ -48,9 +76,8 @@
Value();
private:
- utils::UniqueVector<const Instruction*, 4> uses_;
+ utils::Hashset<Usage, 4, Usage::Hasher> uses_;
};
-
} // namespace tint::ir
#endif // SRC_TINT_IR_VALUE_H_
diff --git a/src/tint/ir/var.cc b/src/tint/ir/var.cc
index 9b43329..8b0225f 100644
--- a/src/tint/ir/var.cc
+++ b/src/tint/ir/var.cc
@@ -19,14 +19,17 @@
namespace tint::ir {
-Var::Var(const type::Type* ty) : type_(ty) {}
+Var::Var(const type::Type* ty) : type_(ty) {
+ TINT_ASSERT(IR, type_ != nullptr);
+
+ // Default to no initializer.
+ AddOperand(nullptr);
+}
Var::~Var() = default;
void Var::SetInitializer(Value* initializer) {
- initializer_ = initializer;
- initializer_->AddUsage(this);
- // TODO(dsinclair): Probably should do a RemoveUsage on an existing initializer if set
+ SetOperand(0, initializer);
}
} // namespace tint::ir
diff --git a/src/tint/ir/var.h b/src/tint/ir/var.h
index 8048463..3c633cc 100644
--- a/src/tint/ir/var.h
+++ b/src/tint/ir/var.h
@@ -18,14 +18,14 @@
#include "src/tint/builtin/access.h"
#include "src/tint/builtin/address_space.h"
#include "src/tint/ir/binding_point.h"
-#include "src/tint/ir/instruction.h"
+#include "src/tint/ir/operand_instruction.h"
#include "src/tint/utils/castable.h"
#include "src/tint/utils/vector.h"
namespace tint::ir {
-/// An instruction in the IR.
-class Var : public utils::Castable<Var, Instruction> {
+/// A var instruction in the IR.
+class Var : public utils::Castable<Var, OperandInstruction<1>> {
public:
/// Constructor
/// @param type the type of the var
@@ -39,7 +39,7 @@
/// @param initializer the initializer
void SetInitializer(Value* initializer);
/// @returns the initializer
- const Value* Initializer() const { return initializer_; }
+ const Value* Initializer() const { return operands_[0]; }
/// Sets the binding point
/// @param group the group
@@ -49,8 +49,7 @@
std::optional<struct BindingPoint> BindingPoint() const { return binding_point_; }
private:
- const type::Type* type_;
- Value* initializer_ = nullptr;
+ const type::Type* type_ = nullptr;
std::optional<struct BindingPoint> binding_point_;
};
diff --git a/src/tint/ir/var_test.cc b/src/tint/ir/var_test.cc
new file mode 100644
index 0000000..ebde1bc
--- /dev/null
+++ b/src/tint/ir/var_test.cc
@@ -0,0 +1,52 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/var.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest-spi.h"
+#include "src/tint/ir/builder.h"
+#include "src/tint/ir/instruction.h"
+#include "src/tint/ir/ir_test_helper.h"
+
+namespace tint::ir {
+namespace {
+
+using namespace tint::number_suffixes; // NOLINT
+using IR_VarTest = IRTestHelper;
+
+TEST_F(IR_VarTest, Fail_NullType) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+ b.Declare(nullptr);
+ },
+ "");
+}
+
+TEST_F(IR_VarTest, Initializer_Usage) {
+ Module mod;
+ Builder b{mod};
+ auto* var = b.Declare(mod.Types().f32());
+ auto* init = b.Constant(1_f);
+ var->SetInitializer(init);
+
+ EXPECT_THAT(init->Usages(), testing::UnorderedElementsAre(Usage{var, 0u}));
+ var->SetInitializer(nullptr);
+ EXPECT_TRUE(init->Usages().IsEmpty());
+}
+
+} // namespace
+} // namespace tint::ir
diff --git a/src/tint/type/struct.h b/src/tint/type/struct.h
index 08119d0..dd2f7da 100644
--- a/src/tint/type/struct.h
+++ b/src/tint/type/struct.h
@@ -45,6 +45,14 @@
kComputeOutput,
};
+enum StructFlag {
+ /// The structure is a block-decorated structure (for SPIR-V or GLSL).
+ kBlock,
+};
+
+/// An alias to utils::EnumSet<StructFlag>
+using StructFlags = utils::EnumSet<StructFlag>;
+
/// Struct holds the Type information for structures.
class Struct : public utils::Castable<Struct, Type> {
public:
@@ -94,6 +102,13 @@
/// alignment padding
uint32_t SizeNoPadding() const { return size_no_padding_; }
+ /// @returns the structure flags
+ type::StructFlags StructFlags() const { return struct_flags_; }
+
+ /// Set a structure flag.
+ /// @param flag the flag to set
+ void SetStructFlag(StructFlag flag) { struct_flags_.Add(flag); }
+
/// Adds the AddressSpace usage to the structure.
/// @param usage the storage usage
void AddUsage(builtin::AddressSpace usage) { address_space_usage_.emplace(usage); }
@@ -153,6 +168,7 @@
const uint32_t align_;
const uint32_t size_;
const uint32_t size_no_padding_;
+ type::StructFlags struct_flags_;
std::unordered_set<builtin::AddressSpace> address_space_usage_;
std::unordered_set<PipelineStageUsage> pipeline_stage_uses_;
utils::Vector<const Struct*, 2> concrete_types_;
diff --git a/src/tint/utils/hash.h b/src/tint/utils/hash.h
index af590f4..fe1614e 100644
--- a/src/tint/utils/hash.h
+++ b/src/tint/utils/hash.h
@@ -111,11 +111,11 @@
}
};
-/// Hasher specialization for utils::vector
+/// Hasher specialization for utils::Vector
template <typename T, size_t N>
struct Hasher<utils::Vector<T, N>> {
- /// @param vector the vector to hash
- /// @returns a hash of the vector
+ /// @param vector the Vector to hash
+ /// @returns a hash of the Vector
size_t operator()(const utils::Vector<T, N>& vector) const {
auto hash = Hash(vector.Length());
for (auto& el : vector) {
@@ -125,6 +125,20 @@
}
};
+/// Hasher specialization for utils::VectorRef
+template <typename T>
+struct Hasher<utils::VectorRef<T>> {
+ /// @param vector the VectorRef reference to hash
+ /// @returns a hash of the Vector
+ size_t operator()(const utils::VectorRef<T>& vector) const {
+ auto hash = Hash(vector.Length());
+ for (auto& el : vector) {
+ hash = HashCombine(hash, el);
+ }
+ return hash;
+ }
+};
+
/// Hasher specialization for std::tuple
template <typename... TYPES>
struct Hasher<std::tuple<TYPES...>> {
diff --git a/src/tint/utils/hash_test.cc b/src/tint/utils/hash_test.cc
index 2261eb8..80595d4 100644
--- a/src/tint/utils/hash_test.cc
+++ b/src/tint/utils/hash_test.cc
@@ -43,6 +43,21 @@
EXPECT_EQ(Hash(Vector<int, 3>({1, 2, 3})), Hash(Vector<int, 2>({1, 2, 3})));
}
+TEST(HashTests, TintVectorRef) {
+ EXPECT_EQ(Hash(VectorRef<int>(Vector<int, 0>({}))), Hash(VectorRef<int>(Vector<int, 0>({}))));
+ EXPECT_EQ(Hash(VectorRef<int>(Vector<int, 0>({1, 2, 3}))),
+ Hash(VectorRef<int>(Vector<int, 0>({1, 2, 3}))));
+ EXPECT_EQ(Hash(VectorRef<int>(Vector<int, 3>({1, 2, 3}))),
+ Hash(VectorRef<int>(Vector<int, 4>({1, 2, 3}))));
+ EXPECT_EQ(Hash(VectorRef<int>(Vector<int, 3>({1, 2, 3}))),
+ Hash(VectorRef<int>(Vector<int, 2>({1, 2, 3}))));
+
+ EXPECT_EQ(Hash(VectorRef<int>(Vector<int, 0>({}))), Hash(Vector<int, 0>({})));
+ EXPECT_EQ(Hash(VectorRef<int>(Vector<int, 0>({1, 2, 3}))), Hash(Vector<int, 0>({1, 2, 3})));
+ EXPECT_EQ(Hash(VectorRef<int>(Vector<int, 3>({1, 2, 3}))), Hash(Vector<int, 4>({1, 2, 3})));
+ EXPECT_EQ(Hash(VectorRef<int>(Vector<int, 3>({1, 2, 3}))), Hash(Vector<int, 2>({1, 2, 3})));
+}
+
TEST(HashTests, Tuple) {
EXPECT_EQ(Hash(std::make_tuple(1)), Hash(std::make_tuple(1)));
EXPECT_EQ(Hash(std::make_tuple(1, 2, 3)), Hash(std::make_tuple(1, 2, 3)));
diff --git a/src/tint/utils/slice.h b/src/tint/utils/slice.h
index ad4dbcf..ca25949 100644
--- a/src/tint/utils/slice.h
+++ b/src/tint/utils/slice.h
@@ -117,6 +117,15 @@
/// Constructor
Slice(EmptyType) {} // NOLINT
+ /// Copy constructor with covariance / const conversion
+ /// @param other the vector to copy
+ /// @see CanReinterpretSlice for rules about conversion
+ template <typename U,
+ typename = std::enable_if_t<CanReinterpretSlice<ReinterpretMode::kSafe, T, U>>>
+ Slice(const Slice<U>& other) { // NOLINT(runtime/explicit)
+ *this = other.template Reinterpret<T, ReinterpretMode::kSafe>();
+ }
+
/// Constructor
/// @param d pointer to the first element in the slice
/// @param l total number of elements in the slice
@@ -150,6 +159,29 @@
/// @return true if the slice length is zero
bool IsEmpty() const { return len == 0; }
+ /// @return the length of the slice
+ size_t Length() const { return len; }
+
+ /// Create a new slice that represents an offset into this slice
+ /// @param offset the number of elements to offset
+ /// @return the new slice
+ Slice<T> Offset(size_t offset) const {
+ if (offset > len) {
+ offset = len;
+ }
+ return Slice(data + offset, len - offset, cap - offset);
+ }
+
+ /// Create a new slice that represents a truncated version of this slice
+ /// @param length the new length
+ /// @return a new slice that is truncated to `length` elements
+ Slice<T> Truncate(size_t length) const {
+ if (length > len) {
+ length = len;
+ }
+ return Slice(data, length, length);
+ }
+
/// Index operator
/// @param i the element index. Must be less than `len`.
/// @returns a reference to the i'th element.
diff --git a/src/tint/utils/slice_test.cc b/src/tint/utils/slice_test.cc
index 6a4493c..5725870 100644
--- a/src/tint/utils/slice_test.cc
+++ b/src/tint/utils/slice_test.cc
@@ -72,6 +72,21 @@
EXPECT_TRUE(slice.IsEmpty());
}
+TEST(TintSliceTest, CtorCast) {
+ C1* elements[3];
+
+ Slice<C1*> slice_a;
+ slice_a.data = &elements[0];
+ slice_a.len = 3;
+ slice_a.cap = 3;
+
+ Slice<const C0*> slice_b(slice_a);
+ EXPECT_EQ(slice_b.data, Bitcast<const C0**>(&elements[0]));
+ EXPECT_EQ(slice_b.len, 3u);
+ EXPECT_EQ(slice_b.cap, 3u);
+ EXPECT_FALSE(slice_b.IsEmpty());
+}
+
TEST(TintSliceTest, CtorEmpty) {
Slice<int> slice{Empty};
EXPECT_EQ(slice.data, nullptr);
@@ -127,5 +142,44 @@
}
}
+TEST(TintSliceTest, Offset) {
+ int elements[] = {1, 2, 3};
+
+ auto slice = Slice{elements};
+ auto offset = slice.Offset(1);
+ EXPECT_EQ(offset.Length(), 2u);
+ EXPECT_EQ(offset[0], 2);
+ EXPECT_EQ(offset[1], 3);
+}
+
+TEST(TintSliceTest, Offset_PastEnd) {
+ int elements[] = {1, 2, 3};
+
+ auto slice = Slice{elements};
+ auto offset = slice.Offset(4);
+ EXPECT_EQ(offset.Length(), 0u);
+}
+
+TEST(TintSliceTest, Truncate) {
+ int elements[] = {1, 2, 3};
+
+ auto slice = Slice{elements};
+ auto truncated = slice.Truncate(2);
+ EXPECT_EQ(truncated.Length(), 2u);
+ EXPECT_EQ(truncated[0], 1);
+ EXPECT_EQ(truncated[1], 2);
+}
+
+TEST(TintSliceTest, Truncate_PastEnd) {
+ int elements[] = {1, 2, 3};
+
+ auto slice = Slice{elements};
+ auto truncated = slice.Truncate(4);
+ EXPECT_EQ(truncated.Length(), 3u);
+ EXPECT_EQ(truncated[0], 1);
+ EXPECT_EQ(truncated[1], 2);
+ EXPECT_EQ(truncated[2], 3);
+}
+
} // namespace
} // namespace tint::utils
diff --git a/src/tint/utils/string_stream.h b/src/tint/utils/string_stream.h
index ecb88f7..2ab62a2 100644
--- a/src/tint/utils/string_stream.h
+++ b/src/tint/utils/string_stream.h
@@ -177,6 +177,9 @@
return *this;
}
+ /// @returns the current location in the output stream
+ uint32_t tellp() { return static_cast<uint32_t>(sstream_.tellp()); }
+
/// @returns the string contents of the stream
std::string str() const { return sstream_.str(); }
diff --git a/src/tint/utils/vector.h b/src/tint/utils/vector.h
index acb0e39..5cb0ff3 100644
--- a/src/tint/utils/vector.h
+++ b/src/tint/utils/vector.h
@@ -133,6 +133,10 @@
/// @param other the vector reference to copy
Vector(const VectorRef<T>& other) { Copy(other.slice_); } // NOLINT(runtime/explicit)
+ /// Copy constructor from an immutable slice
+ /// @param other the slice to copy
+ Vector(const Slice<T>& other) { Copy(other); } // NOLINT(runtime/explicit)
+
/// Destructor
~Vector() { ClearAndFree(); }
@@ -194,6 +198,14 @@
return *this;
}
+ /// Assignment operator for Slice
+ /// @param other the slice to copy
+ /// @returns this vector so calls can be chained
+ Vector& operator=(const Slice<T>& other) {
+ Copy(other);
+ return *this;
+ }
+
/// Index operator
/// @param i the element index. Must be less than `len`.
/// @returns a reference to the i'th element.
@@ -402,6 +414,9 @@
/// @returns the internal slice of the vector
utils::Slice<T> Slice() { return impl_.slice; }
+ /// @returns the internal slice of the vector
+ utils::Slice<const T> Slice() const { return impl_.slice; }
+
private:
/// Friend class (differing specializations of this class)
template <typename, size_t>
diff --git a/src/tint/utils/vector_test.cc b/src/tint/utils/vector_test.cc
index bb70e08..31e1a86 100644
--- a/src/tint/utils/vector_test.cc
+++ b/src/tint/utils/vector_test.cc
@@ -1106,6 +1106,54 @@
EXPECT_TRUE(AllExternallyHeld(vec));
}
+TEST(TintVectorTest, CopyAssignSlice_N2_to_N2) {
+ std::string data[] = {"hello", "world"};
+ Slice<std::string> slice(data);
+ Vector<std::string, 2> vec_b;
+ vec_b = slice;
+ EXPECT_EQ(vec_b.Length(), 2u);
+ EXPECT_EQ(vec_b.Capacity(), 2u);
+ EXPECT_EQ(vec_b[0], "hello");
+ EXPECT_EQ(vec_b[1], "world");
+ EXPECT_TRUE(AllInternallyHeld(vec_b));
+}
+
+TEST(TintVectorTest, CopyAssignSlice_N2_to_N1) {
+ std::string data[] = {"hello", "world"};
+ Slice<std::string> slice(data);
+ Vector<std::string, 1> vec_b;
+ vec_b = slice;
+ EXPECT_EQ(vec_b.Length(), 2u);
+ EXPECT_EQ(vec_b.Capacity(), 2u);
+ EXPECT_EQ(vec_b[0], "hello");
+ EXPECT_EQ(vec_b[1], "world");
+ EXPECT_TRUE(AllExternallyHeld(vec_b));
+}
+
+TEST(TintVectorTest, CopyAssignSlice_N2_to_N3) {
+ std::string data[] = {"hello", "world"};
+ Slice<std::string> slice(data);
+ Vector<std::string, 3> vec_b;
+ vec_b = slice;
+ EXPECT_EQ(vec_b.Length(), 2u);
+ EXPECT_EQ(vec_b.Capacity(), 3u);
+ EXPECT_EQ(vec_b[0], "hello");
+ EXPECT_EQ(vec_b[1], "world");
+ EXPECT_TRUE(AllInternallyHeld(vec_b));
+}
+
+TEST(TintVectorTest, CopyAssignSlice_N2_to_N0) {
+ std::string data[] = {"hello", "world"};
+ Slice<std::string> slice(data);
+ Vector<std::string, 0> vec_b;
+ vec_b = slice;
+ EXPECT_EQ(vec_b.Length(), 2u);
+ EXPECT_EQ(vec_b.Capacity(), 2u);
+ EXPECT_EQ(vec_b[0], "hello");
+ EXPECT_EQ(vec_b[1], "world");
+ EXPECT_TRUE(AllExternallyHeld(vec_b));
+}
+
TEST(TintVectorTest, Index) {
Vector<std::string, 2> vec{"hello", "world"};
static_assert(!std::is_const_v<std::remove_reference_t<decltype(vec[0])>>);
@@ -1821,6 +1869,24 @@
EXPECT_FALSE(vec.All(Ne(9)));
}
+TEST(TintVectorTest, Slice) {
+ Vector<std::string, 3> vec{"hello", "world"};
+ auto slice = vec.Slice();
+ static_assert(std::is_same_v<decltype(slice), Slice<std::string>>);
+ EXPECT_EQ(slice.data, &vec[0]);
+ EXPECT_EQ(slice.len, 2u);
+ EXPECT_EQ(slice.cap, 3u);
+}
+
+TEST(TintVectorTest, SliceConst) {
+ const Vector<std::string, 3> vec{"hello", "world"};
+ auto slice = vec.Slice();
+ static_assert(std::is_same_v<decltype(slice), Slice<const std::string>>);
+ EXPECT_EQ(slice.data, &vec[0]);
+ EXPECT_EQ(slice.len, 2u);
+ EXPECT_EQ(slice.cap, 3u);
+}
+
TEST(TintVectorTest, ostream) {
utils::StringStream ss;
ss << Vector{1, 2, 3};
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir.cc b/src/tint/writer/spirv/ir/generator_impl_ir.cc
index 6196254..a173b49 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir.cc
@@ -19,6 +19,7 @@
#include "spirv/unified1/GLSL.std.450.h"
#include "spirv/unified1/spirv.h"
#include "src/tint/constant/scalar.h"
+#include "src/tint/ir/access.h"
#include "src/tint/ir/binary.h"
#include "src/tint/ir/block.h"
#include "src/tint/ir/break_if.h"
@@ -36,7 +37,10 @@
#include "src/tint/ir/store.h"
#include "src/tint/ir/switch.h"
#include "src/tint/ir/transform/add_empty_entry_point.h"
+#include "src/tint/ir/transform/block_decorated_structs.h"
+#include "src/tint/ir/transform/var_for_dynamic_index.h"
#include "src/tint/ir/user_call.h"
+#include "src/tint/ir/validate.h"
#include "src/tint/ir/var.h"
#include "src/tint/switch.h"
#include "src/tint/transform/manager.h"
@@ -64,6 +68,8 @@
transform::DataMap data;
manager.Add<ir::transform::AddEmptyEntryPoint>();
+ manager.Add<ir::transform::BlockDecoratedStructs>();
+ manager.Add<ir::transform::VarForDynamicIndex>();
transform::DataMap outputs;
manager.Run(module, data, outputs);
@@ -92,6 +98,12 @@
: ir_(module), zero_init_workgroup_memory_(zero_init_workgroup_mem) {}
bool GeneratorImplIr::Generate() {
+ auto valid = ir::Validate(*ir_);
+ if (!valid) {
+ diagnostics_ = valid.Failure();
+ return false;
+ }
+
// Run the IR transformations to prepare for SPIR-V emission.
Sanitize(ir_);
@@ -291,6 +303,11 @@
}
module_.PushType(spv::Op::OpTypeStruct, std::move(operands));
+ // Add a Block decoration if necessary.
+ if (str->StructFlags().Contains(type::StructFlag::kBlock)) {
+ module_.PushAnnot(spv::Op::OpDecorate, {id, U32Operand(SpvDecorationBlock)});
+ }
+
if (str->Name().IsValid()) {
module_.PushDebug(spv::Op::OpName, {operands[0], Operand(str->Name().Name())});
}
@@ -409,6 +426,14 @@
return;
}
+ // Emit all OpPhi nodes for incoming branches to block.
+ EmitIncomingPhis(block);
+
+ // Emit the block's statements.
+ EmitBlockInstructions(block);
+}
+
+void GeneratorImplIr::EmitIncomingPhis(const ir::Block* block) {
// Emit Phi nodes for all the incoming block parameters
for (size_t param_idx = 0; param_idx < block->Params().Length(); param_idx++) {
auto* param = block->Params()[param_idx];
@@ -422,11 +447,13 @@
current_function_.push_inst(spv::Op::OpPhi, std::move(ops));
}
+}
- // Emit the instructions.
+void GeneratorImplIr::EmitBlockInstructions(const ir::Block* block) {
for (auto* inst : *block) {
Switch(
inst, //
+ [&](const ir::Access* a) { EmitAccess(a); }, //
[&](const ir::Binary* b) { EmitBinary(b); }, //
[&](const ir::Builtin* b) { EmitBuiltin(b); }, //
[&](const ir::Load* l) { EmitLoad(l); }, //
@@ -526,6 +553,58 @@
EmitBlock(merge_block);
}
+void GeneratorImplIr::EmitAccess(const ir::Access* access) {
+ auto id = Value(access);
+ OperandList operands = {Type(access->Type()), id, Value(access->Object())};
+
+ if (access->Type()->Is<type::Pointer>()) {
+ // Use OpAccessChain for accesses into pointer types.
+ for (auto* idx : access->Indices()) {
+ operands.push_back(Value(idx));
+ }
+ current_function_.push_inst(spv::Op::OpAccessChain, std::move(operands));
+ return;
+ }
+
+ // For non-pointer types, we assume that the indices are constants and use OpCompositeExtract.
+ // If we hit a non-constant index into a vector type, use OpVectorExtractDynamic for it.
+ auto* ty = access->Object()->Type();
+ for (auto* idx : access->Indices()) {
+ if (auto* constant = idx->As<ir::Constant>()) {
+ // Push the index to the chain and update the current type.
+ auto i = constant->Value()->ValueAs<u32>();
+ operands.push_back(i);
+ ty = Switch(
+ ty, //
+ [&](const type::Array* arr) { return arr->ElemType(); },
+ [&](const type::Matrix* mat) { return mat->ColumnType(); },
+ [&](const type::Struct* str) { return str->Members()[i]->Type(); },
+ [&](const type::Vector* vec) { return vec->type(); },
+ [&](Default) { return nullptr; });
+ } else {
+ // The VarForDynamicIndex transform ensures that only value types that are vectors
+ // will be dynamically indexed, as we can use OpVectorExtractDynamic for this case.
+ TINT_ASSERT(Writer, ty->Is<type::Vector>());
+
+ // If this wasn't the first access in the chain then emit the chain so far as an
+ // OpCompositeExtract, creating a new result ID for the resulting vector.
+ auto vec_id = Value(access->Object());
+ if (operands.size() > 3) {
+ vec_id = module_.NextId();
+ operands[0] = Type(ty);
+ operands[1] = vec_id;
+ current_function_.push_inst(spv::Op::OpCompositeExtract, std::move(operands));
+ }
+
+ // Now emit the OpVectorExtractDynamic instruction.
+ operands = {Type(access->Type()), id, vec_id, Value(idx)};
+ current_function_.push_inst(spv::Op::OpVectorExtractDynamic, std::move(operands));
+ return;
+ }
+ }
+ current_function_.push_inst(spv::Op::OpCompositeExtract, std::move(operands));
+}
+
void GeneratorImplIr::EmitBinary(const ir::Binary* binary) {
auto id = Value(binary);
auto* lhs_ty = binary->LHS()->Type();
@@ -702,24 +781,34 @@
}
void GeneratorImplIr::EmitLoop(const ir::Loop* loop) {
- auto header_label = module_.NextId();
- auto body_label = Label(loop->Body());
+ auto init_label = loop->HasInitializer() ? Label(loop->Initializer()) : 0;
+ auto header_label = Label(loop->Body()); // Back-edge needs to branch to the loop header
+ auto body_label = module_.NextId();
auto continuing_label = Label(loop->Continuing());
auto merge_label = Label(loop->Merge());
- // Branch to and emit the loop header, which contains OpLoopMerge and OpBranch instructions.
- current_function_.push_inst(spv::Op::OpBranch, {header_label});
+ if (init_label != 0) {
+ // Emit the loop initializer.
+ current_function_.push_inst(spv::Op::OpBranch, {init_label});
+ EmitBlock(loop->Initializer());
+ } else {
+ // No initializer. Branch to body.
+ current_function_.push_inst(spv::Op::OpBranch, {header_label});
+ }
+
+ // Emit the loop body header, which contains the OpLoopMerge and OpPhis.
+ // This then unconditionally branches to body_label
current_function_.push_inst(spv::Op::OpLabel, {header_label});
+ EmitIncomingPhis(loop->Body());
current_function_.push_inst(
spv::Op::OpLoopMerge, {merge_label, continuing_label, U32Operand(SpvLoopControlMaskNone)});
current_function_.push_inst(spv::Op::OpBranch, {body_label});
- // Emit the loop body.
- EmitBlock(loop->Body());
+ // Emit the loop body
+ current_function_.push_inst(spv::Op::OpLabel, {body_label});
+ EmitBlockInstructions(loop->Body());
// Emit the loop continuing block.
- // The back-edge needs to go to the loop header, so update the label for the start block.
- block_labels_.Replace(loop->Body(), header_label);
if (loop->Continuing()->HasBranchTarget()) {
EmitBlock(loop->Continuing());
} else {
@@ -809,6 +898,18 @@
module_.PushType(spv::Op::OpVariable, operands);
break;
}
+ case builtin::AddressSpace::kStorage:
+ case builtin::AddressSpace::kUniform: {
+ TINT_ASSERT(Writer, !current_function_);
+ module_.PushType(spv::Op::OpVariable,
+ {ty, id, U32Operand(StorageClass(ptr->AddressSpace()))});
+ auto bp = var->BindingPoint().value();
+ module_.PushAnnot(spv::Op::OpDecorate,
+ {id, U32Operand(SpvDecorationDescriptorSet), bp.group});
+ module_.PushAnnot(spv::Op::OpDecorate,
+ {id, U32Operand(SpvDecorationBinding), bp.binding});
+ break;
+ }
case builtin::AddressSpace::kWorkgroup: {
TINT_ASSERT(Writer, !current_function_);
OperandList operands = {ty, id, U32Operand(SpvStorageClassWorkgroup)};
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir.h b/src/tint/writer/spirv/ir/generator_impl_ir.h
index 7c3105b..d3abeb6 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir.h
+++ b/src/tint/writer/spirv/ir/generator_impl_ir.h
@@ -29,6 +29,7 @@
// Forward declarations
namespace tint::ir {
+class Access;
class Binary;
class Block;
class BlockParam;
@@ -112,10 +113,18 @@
/// @param id the result ID of the function declaration
void EmitEntryPoint(const ir::Function* func, uint32_t id);
- /// Emit a block.
+ /// Emit a block, including the initial OpLabel, OpPhis and instructions.
/// @param block the block to emit
void EmitBlock(const ir::Block* block);
+ /// Emit all OpPhi nodes for incoming branches to @p block.
+ /// @param block the block to emit the OpPhis for
+ void EmitIncomingPhis(const ir::Block* block);
+
+ /// Emit all instructions of @p block.
+ /// @param block the block's instructions to emit
+ void EmitBlockInstructions(const ir::Block* block);
+
/// Emit the root block.
/// @param root_block the root block to emit
void EmitRootBlock(const ir::Block* root_block);
@@ -124,6 +133,10 @@
/// @param i the if node to emit
void EmitIf(const ir::If* i);
+ /// Emit an access instruction
+ /// @param access the access instruction to emit
+ void EmitAccess(const ir::Access* access);
+
/// Emit a binary instruction.
/// @param binary the binary instruction to emit
void EmitBinary(const ir::Binary* binary);
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_access_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_access_test.cc
new file mode 100644
index 0000000..62dd469
--- /dev/null
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_access_test.cc
@@ -0,0 +1,448 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/writer/spirv/ir/test_helper_ir.h"
+
+using namespace tint::number_suffixes; // NOLINT
+
+namespace tint::writer::spirv {
+namespace {
+
+class SpvGeneratorImplTest_Access : public SpvGeneratorImplTest {
+ protected:
+ const type::Type* ptr(const type::Type* elem) {
+ return ty.pointer(elem, builtin::AddressSpace::kFunction, builtin::Access::kReadWrite);
+ }
+};
+
+TEST_F(SpvGeneratorImplTest_Access, Array_Value_ConstantIndex) {
+ auto* arr_val = b.FunctionParam(ty.array(ty.i32(), 4));
+ auto* access = b.Access(ty.i32(), arr_val, utils::Vector{b.Constant(1_u)});
+ auto* func = b.CreateFunction("foo", ty.void_());
+ func->SetParams(utils::Vector{arr_val});
+ func->StartTarget()->SetInstructions(utils::Vector{access, b.Return(func)});
+
+ ASSERT_TRUE(IRIsValid()) << Error();
+
+ generator_.EmitFunction(func);
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
+OpDecorate %3 ArrayStride 4
+%2 = OpTypeVoid
+%4 = OpTypeInt 32 1
+%6 = OpTypeInt 32 0
+%5 = OpConstant %6 4
+%3 = OpTypeArray %4 %5
+%8 = OpTypeFunction %2 %3
+%1 = OpFunction %2 None %8
+%7 = OpFunctionParameter %3
+%9 = OpLabel
+%10 = OpCompositeExtract %4 %7 1
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest_Access, Array_Pointer_ConstantIndex) {
+ auto* arr_var = b.Declare(ptr(ty.array(ty.i32(), 4)));
+ auto* access = b.Access(ptr(ty.i32()), arr_var, utils::Vector{b.Constant(1_u)});
+ auto* func = b.CreateFunction("foo", ty.void_());
+ func->StartTarget()->SetInstructions(utils::Vector{arr_var, access, b.Return(func)});
+
+ ASSERT_TRUE(IRIsValid()) << Error();
+
+ generator_.EmitFunction(func);
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
+OpDecorate %7 ArrayStride 4
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%8 = OpTypeInt 32 1
+%10 = OpTypeInt 32 0
+%9 = OpConstant %10 4
+%7 = OpTypeArray %8 %9
+%6 = OpTypePointer Function %7
+%12 = OpTypePointer Function %8
+%13 = OpConstant %10 1
+%1 = OpFunction %2 None %3
+%4 = OpLabel
+%5 = OpVariable %6 Function
+%11 = OpAccessChain %12 %5 %13
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest_Access, Array_Pointer_DynamicIndex) {
+ auto* arr_var = b.Declare(ptr(ty.array(ty.i32(), 4)));
+ auto* idx_var = b.Declare(ptr(ty.i32()));
+ auto* idx = b.Load(idx_var);
+ auto* access = b.Access(ptr(ty.i32()), arr_var, utils::Vector{idx});
+ auto* func = b.CreateFunction("foo", ty.void_());
+ func->StartTarget()->SetInstructions(
+ utils::Vector{idx_var, idx, arr_var, access, b.Return(func)});
+
+ ASSERT_TRUE(IRIsValid()) << Error();
+
+ generator_.EmitFunction(func);
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
+OpDecorate %11 ArrayStride 4
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%7 = OpTypeInt 32 1
+%6 = OpTypePointer Function %7
+%13 = OpTypeInt 32 0
+%12 = OpConstant %13 4
+%11 = OpTypeArray %7 %12
+%10 = OpTypePointer Function %11
+%1 = OpFunction %2 None %3
+%4 = OpLabel
+%5 = OpVariable %6 Function
+%9 = OpVariable %10 Function
+%8 = OpLoad %7 %5
+%14 = OpAccessChain %6 %9 %8
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest_Access, Matrix_Value_ConstantIndex) {
+ auto* mat_val = b.FunctionParam(ty.mat2x2(ty.f32()));
+ auto* access_vec = b.Access(ty.vec2(ty.f32()), mat_val, utils::Vector{b.Constant(1_u)});
+ auto* access_el = b.Access(ty.f32(), mat_val, utils::Vector{b.Constant(1_u), b.Constant(0_u)});
+ auto* func = b.CreateFunction("foo", ty.void_());
+ func->SetParams(utils::Vector{mat_val});
+ func->StartTarget()->SetInstructions(utils::Vector{access_vec, access_el, b.Return(func)});
+
+ ASSERT_TRUE(IRIsValid()) << Error();
+
+ generator_.EmitFunction(func);
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
+%2 = OpTypeVoid
+%5 = OpTypeFloat 32
+%4 = OpTypeVector %5 2
+%3 = OpTypeMatrix %4 2
+%7 = OpTypeFunction %2 %3
+%1 = OpFunction %2 None %7
+%6 = OpFunctionParameter %3
+%8 = OpLabel
+%9 = OpCompositeExtract %4 %6 1
+%10 = OpCompositeExtract %5 %6 1 0
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest_Access, Matrix_Pointer_ConstantIndex) {
+ auto* mat_var = b.Declare(ptr(ty.mat2x2(ty.f32())));
+ auto* access_vec = b.Access(ptr(ty.vec2(ty.f32())), mat_var, utils::Vector{b.Constant(1_u)});
+ auto* access_el =
+ b.Access(ptr(ty.f32()), mat_var, utils::Vector{b.Constant(1_u), b.Constant(0_u)});
+ auto* func = b.CreateFunction("foo", ty.void_());
+ func->StartTarget()->SetInstructions(utils::Vector{access_vec, access_el, b.Return(func)});
+
+ ASSERT_TRUE(IRIsValid()) << Error();
+
+ generator_.EmitFunction(func);
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%8 = OpTypeFloat 32
+%7 = OpTypeVector %8 2
+%6 = OpTypePointer Function %7
+%11 = OpTypeInt 32 0
+%10 = OpConstant %11 1
+%13 = OpTypePointer Function %8
+%14 = OpConstant %11 0
+%1 = OpFunction %2 None %3
+%4 = OpLabel
+%5 = OpAccessChain %6 %9 %10
+%12 = OpAccessChain %13 %9 %10 %14
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest_Access, Matrix_Pointer_DynamicIndex) {
+ auto* mat_var = b.Declare(ptr(ty.mat2x2(ty.f32())));
+ auto* idx_var = b.Declare(ptr(ty.i32()));
+ auto* idx = b.Load(idx_var);
+ auto* access_vec = b.Access(ptr(ty.vec2(ty.f32())), mat_var, utils::Vector{idx});
+ auto* access_el = b.Access(ptr(ty.f32()), mat_var, utils::Vector{idx, idx});
+ auto* func = b.CreateFunction("foo", ty.void_());
+ func->StartTarget()->SetInstructions(
+ utils::Vector{idx_var, idx, mat_var, access_vec, access_el, b.Return(func)});
+
+ ASSERT_TRUE(IRIsValid()) << Error();
+
+ generator_.EmitFunction(func);
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%7 = OpTypeInt 32 1
+%6 = OpTypePointer Function %7
+%13 = OpTypeFloat 32
+%12 = OpTypeVector %13 2
+%11 = OpTypeMatrix %12 2
+%10 = OpTypePointer Function %11
+%15 = OpTypePointer Function %12
+%17 = OpTypePointer Function %13
+%1 = OpFunction %2 None %3
+%4 = OpLabel
+%5 = OpVariable %6 Function
+%9 = OpVariable %10 Function
+%8 = OpLoad %7 %5
+%14 = OpAccessChain %15 %9 %8
+%16 = OpAccessChain %17 %9 %8 %8
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest_Access, Vector_Value_ConstantIndex) {
+ auto* vec_val = b.FunctionParam(ty.vec4(ty.i32()));
+ auto* access = b.Access(ty.i32(), vec_val, utils::Vector{b.Constant(1_u)});
+ auto* func = b.CreateFunction("foo", ty.void_());
+ func->SetParams(utils::Vector{vec_val});
+ func->StartTarget()->SetInstructions(utils::Vector{access, b.Return(func)});
+
+ ASSERT_TRUE(IRIsValid()) << Error();
+
+ generator_.EmitFunction(func);
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
+%2 = OpTypeVoid
+%4 = OpTypeInt 32 1
+%3 = OpTypeVector %4 4
+%6 = OpTypeFunction %2 %3
+%1 = OpFunction %2 None %6
+%5 = OpFunctionParameter %3
+%7 = OpLabel
+%8 = OpCompositeExtract %4 %5 1
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest_Access, Vector_Value_DynamicIndex) {
+ auto* vec_val = b.FunctionParam(ty.vec4(ty.i32()));
+ auto* idx_var = b.Declare(ptr(ty.i32()));
+ auto* idx = b.Load(idx_var);
+ auto* access = b.Access(ty.i32(), vec_val, utils::Vector{idx});
+ auto* func = b.CreateFunction("foo", ty.void_());
+ func->SetParams(utils::Vector{vec_val});
+ func->StartTarget()->SetInstructions(utils::Vector{idx_var, idx, access, b.Return(func)});
+
+ ASSERT_TRUE(IRIsValid()) << Error();
+
+ generator_.EmitFunction(func);
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
+%2 = OpTypeVoid
+%4 = OpTypeInt 32 1
+%3 = OpTypeVector %4 4
+%6 = OpTypeFunction %2 %3
+%9 = OpTypePointer Function %4
+%1 = OpFunction %2 None %6
+%5 = OpFunctionParameter %3
+%7 = OpLabel
+%8 = OpVariable %9 Function
+%10 = OpLoad %4 %8
+%11 = OpVectorExtractDynamic %4 %5 %10
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest_Access, Vector_Pointer_ConstantIndex) {
+ auto* vec_var = b.Declare(ptr(ty.vec4(ty.i32())));
+ auto* access = b.Access(ptr(ty.i32()), vec_var, utils::Vector{b.Constant(1_u)});
+ auto* func = b.CreateFunction("foo", ty.void_());
+ func->StartTarget()->SetInstructions(utils::Vector{vec_var, access, b.Return(func)});
+
+ ASSERT_TRUE(IRIsValid()) << Error();
+
+ generator_.EmitFunction(func);
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%8 = OpTypeInt 32 1
+%7 = OpTypeVector %8 4
+%6 = OpTypePointer Function %7
+%10 = OpTypePointer Function %8
+%12 = OpTypeInt 32 0
+%11 = OpConstant %12 1
+%1 = OpFunction %2 None %3
+%4 = OpLabel
+%5 = OpVariable %6 Function
+%9 = OpAccessChain %10 %5 %11
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest_Access, Vector_Pointer_DynamicIndex) {
+ auto* vec_var = b.Declare(ptr(ty.vec4(ty.i32())));
+ auto* idx_var = b.Declare(ptr(ty.i32()));
+ auto* idx = b.Load(idx_var);
+ auto* access = b.Access(ptr(ty.i32()), vec_var, utils::Vector{idx});
+ auto* func = b.CreateFunction("foo", ty.void_());
+ func->StartTarget()->SetInstructions(
+ utils::Vector{idx_var, idx, vec_var, access, b.Return(func)});
+
+ ASSERT_TRUE(IRIsValid()) << Error();
+
+ generator_.EmitFunction(func);
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%7 = OpTypeInt 32 1
+%6 = OpTypePointer Function %7
+%11 = OpTypeVector %7 4
+%10 = OpTypePointer Function %11
+%1 = OpFunction %2 None %3
+%4 = OpLabel
+%5 = OpVariable %6 Function
+%9 = OpVariable %10 Function
+%8 = OpLoad %7 %5
+%12 = OpAccessChain %6 %9 %8
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest_Access, NestedVector_Value_DynamicIndex) {
+ auto* val = b.FunctionParam(ty.array(ty.array(ty.vec4(ty.i32()), 4), 4));
+ auto* idx_var = b.Declare(ptr(ty.i32()));
+ auto* idx = b.Load(idx_var);
+ auto* access = b.Access(ty.i32(), val, utils::Vector{b.Constant(1_u), b.Constant(2_u), idx});
+ auto* func = b.CreateFunction("foo", ty.void_());
+ func->SetParams(utils::Vector{val});
+ func->StartTarget()->SetInstructions(utils::Vector{idx_var, idx, access, b.Return(func)});
+
+ ASSERT_TRUE(IRIsValid()) << Error();
+
+ generator_.EmitFunction(func);
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
+OpDecorate %4 ArrayStride 16
+OpDecorate %3 ArrayStride 64
+%2 = OpTypeVoid
+%6 = OpTypeInt 32 1
+%5 = OpTypeVector %6 4
+%8 = OpTypeInt 32 0
+%7 = OpConstant %8 4
+%4 = OpTypeArray %5 %7
+%3 = OpTypeArray %4 %7
+%10 = OpTypeFunction %2 %3
+%13 = OpTypePointer Function %6
+%1 = OpFunction %2 None %10
+%9 = OpFunctionParameter %3
+%11 = OpLabel
+%12 = OpVariable %13 Function
+%14 = OpLoad %6 %12
+%16 = OpCompositeExtract %5 %9 1 2
+%15 = OpVectorExtractDynamic %6 %16 %14
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest_Access, Struct_Value_ConstantIndex) {
+ auto* str = ty.Get<type::Struct>(
+ mod.symbols.Register("MyStruct"),
+ utils::Vector{
+ ty.Get<type::StructMember>(mod.symbols.Register("a"), ty.f32(), 0u, 0u, 4u, 4u,
+ type::StructMemberAttributes{}),
+ ty.Get<type::StructMember>(mod.symbols.Register("b"), ty.vec4(ty.i32()), 1u, 16u, 16u,
+ 16u, type::StructMemberAttributes{}),
+ },
+ 16u, 32u, 32u);
+ auto* str_val = b.FunctionParam(str);
+ auto* access_vec = b.Access(ty.i32(), str_val, utils::Vector{b.Constant(1_u)});
+ auto* access_el = b.Access(ty.i32(), str_val, utils::Vector{b.Constant(1_u), b.Constant(2_u)});
+ auto* func = b.CreateFunction("foo", ty.void_());
+ func->SetParams(utils::Vector{str_val});
+ func->StartTarget()->SetInstructions(utils::Vector{access_vec, access_el, b.Return(func)});
+
+ ASSERT_TRUE(IRIsValid()) << Error();
+
+ generator_.EmitFunction(func);
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
+OpMemberName %3 0 "a"
+OpMemberName %3 1 "b"
+OpName %3 "MyStruct"
+OpMemberDecorate %3 0 Offset 0
+OpMemberDecorate %3 1 Offset 16
+%2 = OpTypeVoid
+%4 = OpTypeFloat 32
+%6 = OpTypeInt 32 1
+%5 = OpTypeVector %6 4
+%3 = OpTypeStruct %4 %5
+%8 = OpTypeFunction %2 %3
+%1 = OpFunction %2 None %8
+%7 = OpFunctionParameter %3
+%9 = OpLabel
+%10 = OpCompositeExtract %6 %7 1
+%11 = OpCompositeExtract %6 %7 1 2
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest_Access, Struct_Pointer_ConstantIndex) {
+ auto* str = ty.Get<type::Struct>(
+ mod.symbols.Register("MyStruct"),
+ utils::Vector{
+ ty.Get<type::StructMember>(mod.symbols.Register("a"), ty.f32(), 0u, 0u, 4u, 4u,
+ type::StructMemberAttributes{}),
+ ty.Get<type::StructMember>(mod.symbols.Register("b"), ty.vec4(ty.i32()), 1u, 16u, 16u,
+ 16u, type::StructMemberAttributes{}),
+ },
+ 16u, 32u, 32u);
+ auto* str_var = b.Declare(ptr(str));
+ auto* access_vec = b.Access(ptr(ty.i32()), str_var, utils::Vector{b.Constant(1_u)});
+ auto* access_el =
+ b.Access(ptr(ty.i32()), str_var, utils::Vector{b.Constant(1_u), b.Constant(2_u)});
+ auto* func = b.CreateFunction("foo", ty.void_());
+ func->StartTarget()->SetInstructions(
+ utils::Vector{str_var, access_vec, access_el, b.Return(func)});
+
+ ASSERT_TRUE(IRIsValid()) << Error();
+
+ generator_.EmitFunction(func);
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
+OpMemberName %7 0 "a"
+OpMemberName %7 1 "b"
+OpName %7 "MyStruct"
+OpMemberDecorate %7 0 Offset 0
+OpMemberDecorate %7 1 Offset 16
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%8 = OpTypeFloat 32
+%10 = OpTypeInt 32 1
+%9 = OpTypeVector %10 4
+%7 = OpTypeStruct %8 %9
+%6 = OpTypePointer Function %7
+%12 = OpTypePointer Function %10
+%14 = OpTypeInt 32 0
+%13 = OpConstant %14 1
+%16 = OpConstant %14 2
+%1 = OpFunction %2 None %3
+%4 = OpLabel
+%5 = OpVariable %6 Function
+%11 = OpAccessChain %12 %5 %13
+%15 = OpAccessChain %12 %5 %13 %16
+OpReturn
+OpFunctionEnd
+)");
+}
+
+} // namespace
+} // namespace tint::writer::spirv
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc
index 5ae21b3..5481e03 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc
@@ -42,6 +42,8 @@
MakeScalarValue(params.type), MakeScalarValue(params.type)),
b.Return(func)});
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst));
}
@@ -55,6 +57,8 @@
b.Return(func)});
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst));
}
@@ -89,6 +93,8 @@
MakeScalarValue(params.type), MakeScalarValue(params.type)),
b.Return(func)});
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst));
}
@@ -102,6 +108,8 @@
b.Return(func)});
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst));
}
@@ -128,6 +136,8 @@
MakeScalarValue(params.type)),
b.Return(func)});
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst));
}
@@ -141,6 +151,8 @@
b.Return(func)});
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst));
}
@@ -195,6 +207,8 @@
auto* a = b.Subtract(ty.i32(), b.Constant(1_i), b.Constant(2_i));
func->StartTarget()->SetInstructions(utils::Vector{a, b.Add(ty.i32(), a, a), b.Return(func)});
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%2 = OpTypeVoid
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_builtin_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_builtin_test.cc
index 69d3f0f..04d89fd 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_builtin_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_builtin_test.cc
@@ -43,6 +43,8 @@
utils::Vector{MakeScalarValue(params.type)}),
b.Return(func)});
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst));
}
@@ -56,6 +58,8 @@
b.Return(func)});
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst));
}
@@ -72,6 +76,8 @@
func->StartTarget()->SetInstructions(
utils::Vector{result, b.Return(func, utils::Vector{result})});
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%2 = OpTypeInt 32 0
@@ -90,6 +96,8 @@
func->StartTarget()->SetInstructions(
utils::Vector{result, b.Return(func, utils::Vector{result})});
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%3 = OpTypeInt 32 0
@@ -116,6 +124,8 @@
utils::Vector{MakeScalarValue(params.type), MakeScalarValue(params.type)}),
b.Return(func)});
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst));
}
@@ -129,6 +139,8 @@
b.Return(func)});
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst));
}
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_function_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_function_test.cc
index 0eaa859..0879e58 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_function_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_function_test.cc
@@ -21,6 +21,8 @@
auto* func = b.CreateFunction("foo", ty.void_());
func->StartTarget()->SetInstructions(utils::Vector{b.Return(func)});
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%2 = OpTypeVoid
@@ -37,6 +39,8 @@
auto* func = b.CreateFunction("foo", ty.void_());
func->StartTarget()->SetInstructions(utils::Vector{b.Return(func)});
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
generator_.EmitFunction(func);
generator_.EmitFunction(func);
@@ -50,6 +54,8 @@
b.CreateFunction("main", ty.void_(), ir::Function::PipelineStage::kCompute, {{32, 4, 1}});
func->StartTarget()->SetInstructions(utils::Vector{b.Return(func)});
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpEntryPoint GLCompute %1 "main"
OpExecutionMode %1 LocalSize 32 4 1
@@ -67,6 +73,8 @@
auto* func = b.CreateFunction("main", ty.void_(), ir::Function::PipelineStage::kFragment);
func->StartTarget()->SetInstructions(utils::Vector{b.Return(func)});
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpEntryPoint Fragment %1 "main"
OpExecutionMode %1 OriginUpperLeft
@@ -84,6 +92,8 @@
auto* func = b.CreateFunction("main", ty.void_(), ir::Function::PipelineStage::kVertex);
func->StartTarget()->SetInstructions(utils::Vector{b.Return(func)});
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpEntryPoint Vertex %1 "main"
OpName %1 "main"
@@ -108,6 +118,8 @@
auto* f3 = b.CreateFunction("main3", ty.void_(), ir::Function::PipelineStage::kFragment);
f3->StartTarget()->SetInstructions(utils::Vector{b.Return(f3)});
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(f1);
generator_.EmitFunction(f2);
generator_.EmitFunction(f3);
@@ -142,6 +154,8 @@
func->StartTarget()->SetInstructions(
utils::Vector{b.Return(func, utils::Vector{b.Constant(i32(42))})});
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%2 = OpTypeInt 32 1
@@ -166,6 +180,8 @@
mod.SetName(x, "x");
mod.SetName(y, "y");
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
OpName %3 "x"
@@ -197,6 +213,8 @@
b.UserCall(i32_ty, foo, utils::Vector{b.Constant(i32(2)), b.Constant(i32(3))}),
b.Return(bar)});
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(foo);
generator_.EmitFunction(bar);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
@@ -230,6 +248,8 @@
bar->StartTarget()->SetInstructions(
utils::Vector{b.UserCall(ty.void_(), foo, utils::Empty), b.Return(bar)});
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(foo);
generator_.EmitFunction(bar);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_if_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_if_test.cc
index 792b735..4c4e5c3 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_if_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_if_test.cc
@@ -29,6 +29,8 @@
func->StartTarget()->SetInstructions(utils::Vector{i});
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%2 = OpTypeVoid
@@ -58,6 +60,8 @@
func->StartTarget()->SetInstructions(utils::Vector{i});
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%2 = OpTypeVoid
@@ -92,6 +96,8 @@
func->StartTarget()->SetInstructions(utils::Vector{i});
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%2 = OpTypeVoid
@@ -122,6 +128,8 @@
func->StartTarget()->SetInstructions(utils::Vector{i});
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%2 = OpTypeVoid
@@ -155,6 +163,8 @@
func->StartTarget()->SetInstructions(utils::Vector{i});
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%2 = OpTypeVoid
@@ -192,6 +202,8 @@
func->StartTarget()->SetInstructions(utils::Vector{i});
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%2 = OpTypeVoid
@@ -229,6 +241,8 @@
func->StartTarget()->SetInstructions(utils::Vector{i});
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%2 = OpTypeVoid
@@ -271,6 +285,8 @@
func->StartTarget()->SetInstructions(utils::Vector{i});
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%2 = OpTypeVoid
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_loop_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_loop_test.cc
index cb43f9d..807833f 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_loop_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_loop_test.cc
@@ -30,6 +30,8 @@
func->StartTarget()->Append(loop);
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%2 = OpTypeVoid
@@ -63,6 +65,8 @@
func->StartTarget()->Append(loop);
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%2 = OpTypeVoid
@@ -99,6 +103,8 @@
func->StartTarget()->Append(loop);
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%2 = OpTypeVoid
@@ -142,6 +148,8 @@
func->StartTarget()->Append(loop);
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%2 = OpTypeVoid
@@ -180,6 +188,8 @@
func->StartTarget()->Append(loop);
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%2 = OpTypeVoid
@@ -213,6 +223,8 @@
func->StartTarget()->Append(loop);
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%2 = OpTypeVoid
@@ -252,6 +264,8 @@
func->StartTarget()->Append(outer_loop);
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%2 = OpTypeVoid
@@ -299,6 +313,8 @@
func->StartTarget()->Append(outer_loop);
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%2 = OpTypeVoid
@@ -333,9 +349,12 @@
TEST_F(SpvGeneratorImplTest, Loop_Phi_SingleValue) {
auto* func = b.CreateFunction("foo", ty.void_());
- auto* l = b.CreateLoop(utils::Vector{b.Constant(1_i)});
+ auto* l = b.CreateLoop();
func->StartTarget()->Append(l);
+ l->Initializer()->AddInboundBranch(l);
+ l->Initializer()->Append(b.NextIteration(l, utils::Vector{b.Constant(1_i)}));
+
auto* loop_param = b.BlockParam(b.ir.Types().i32());
l->Body()->SetParams(utils::Vector{loop_param});
auto* inc = b.Add(b.ir.Types().i32(), loop_param, b.Constant(1_i));
@@ -348,29 +367,33 @@
l->Continuing()->Append(cmp);
l->Continuing()->Append(b.BreakIf(cmp, l, utils::Vector{cont_param}));
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%2 = OpTypeVoid
%3 = OpTypeFunction %2
-%9 = OpTypeInt 32 1
-%11 = OpConstant %9 1
+%10 = OpTypeInt 32 1
+%12 = OpConstant %10 1
%16 = OpTypeBool
-%17 = OpConstant %9 5
+%17 = OpConstant %10 5
%1 = OpFunction %2 None %3
%4 = OpLabel
OpBranch %5
%5 = OpLabel
-OpLoopMerge %8 %7 None
OpBranch %6
%6 = OpLabel
-%10 = OpPhi %9 %11 %12 %13 %7
-%14 = OpIAdd %9 %10 %11
+%11 = OpPhi %10 %12 %5 %13 %8
+OpLoopMerge %9 %8 None
OpBranch %7
%7 = OpLabel
-%13 = OpPhi %9 %14 %5
-%15 = OpSGreaterThan %16 %13 %17
-OpBranchConditional %15 %8 %5
+%14 = OpIAdd %10 %11 %12
+OpBranch %8
%8 = OpLabel
+%13 = OpPhi %10 %14 %6
+%15 = OpSGreaterThan %16 %13 %17
+OpBranchConditional %15 %9 %6
+%9 = OpLabel
OpUnreachable
OpFunctionEnd
)");
@@ -379,9 +402,12 @@
TEST_F(SpvGeneratorImplTest, Loop_Phi_MultipleValue) {
auto* func = b.CreateFunction("foo", ty.void_());
- auto* l = b.CreateLoop(utils::Vector{b.Constant(1_i), b.Constant(false)});
+ auto* l = b.CreateLoop();
func->StartTarget()->Append(l);
+ l->Initializer()->AddInboundBranch(l);
+ l->Initializer()->Append(b.NextIteration(l, utils::Vector{b.Constant(1_i), b.Constant(false)}));
+
auto* loop_param_a = b.BlockParam(b.ir.Types().i32());
auto* loop_param_b = b.BlockParam(b.ir.Types().bool_());
l->Body()->SetParams(utils::Vector{loop_param_a, loop_param_b});
@@ -398,33 +424,37 @@
l->Continuing()->Append(not_b);
l->Continuing()->Append(b.BreakIf(cmp, l, utils::Vector{cont_param_a, not_b}));
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%2 = OpTypeVoid
%3 = OpTypeFunction %2
-%9 = OpTypeInt 32 1
-%11 = OpConstant %9 1
+%10 = OpTypeInt 32 1
+%12 = OpConstant %10 1
%14 = OpTypeBool
%16 = OpConstantFalse %14
-%21 = OpConstant %9 5
+%21 = OpConstant %10 5
%1 = OpFunction %2 None %3
%4 = OpLabel
OpBranch %5
%5 = OpLabel
-OpLoopMerge %8 %7 None
OpBranch %6
%6 = OpLabel
-%10 = OpPhi %9 %11 %12 %13 %7
-%15 = OpPhi %14 %16 %12 %17 %7
-%18 = OpIAdd %9 %10 %11
+%11 = OpPhi %10 %12 %5 %13 %8
+%15 = OpPhi %14 %16 %5 %17 %8
+OpLoopMerge %9 %8 None
OpBranch %7
%7 = OpLabel
-%13 = OpPhi %9 %18 %5
-%19 = OpPhi %14 %15 %5
+%18 = OpIAdd %10 %11 %12
+OpBranch %8
+%8 = OpLabel
+%13 = OpPhi %10 %18 %6
+%19 = OpPhi %14 %15 %6
%20 = OpSGreaterThan %14 %13 %21
%17 = OpLogicalEqual %14 %19 %16
-OpBranchConditional %20 %8 %5
-%8 = OpLabel
+OpBranchConditional %20 %9 %6
+%9 = OpLabel
OpUnreachable
OpFunctionEnd
)");
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_switch_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_switch_test.cc
index dca1dac..64e9491 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_switch_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_switch_test.cc
@@ -31,6 +31,8 @@
func->StartTarget()->Append(swtch);
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%2 = OpTypeVoid
@@ -67,6 +69,8 @@
func->StartTarget()->Append(swtch);
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%2 = OpTypeVoid
@@ -110,6 +114,8 @@
func->StartTarget()->Append(swtch);
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%2 = OpTypeVoid
@@ -148,6 +154,8 @@
func->StartTarget()->Append(swtch);
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%2 = OpTypeVoid
@@ -190,6 +198,8 @@
func->StartTarget()->Append(swtch);
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%2 = OpTypeVoid
@@ -235,6 +245,8 @@
func->StartTarget()->SetInstructions(utils::Vector{s});
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%2 = OpTypeVoid
@@ -274,6 +286,8 @@
func->StartTarget()->SetInstructions(utils::Vector{s});
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%2 = OpTypeVoid
@@ -316,6 +330,8 @@
func->StartTarget()->SetInstructions(utils::Vector{s});
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%2 = OpTypeVoid
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc
index 516e504..70c1712 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc
@@ -28,6 +28,8 @@
builtin::Access::kReadWrite)),
b.Return(func)});
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%2 = OpTypeVoid
@@ -51,6 +53,8 @@
func->StartTarget()->SetInstructions(utils::Vector{v, b.Return(func)});
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%2 = OpTypeVoid
@@ -75,6 +79,8 @@
func->StartTarget()->SetInstructions(utils::Vector{v, b.Return(func)});
mod.SetName(v, "myvar");
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
OpName %5 "myvar"
@@ -104,6 +110,8 @@
func->StartTarget()->SetInstructions(utils::Vector{i});
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%2 = OpTypeVoid
@@ -137,6 +145,8 @@
ty.pointer(store_ty, builtin::AddressSpace::kFunction, builtin::Access::kReadWrite));
func->StartTarget()->SetInstructions(utils::Vector{v, b.Load(v), b.Return(func)});
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%2 = OpTypeVoid
@@ -160,6 +170,8 @@
func->StartTarget()->SetInstructions(
utils::Vector{v, b.Store(v, b.Constant(42_i)), b.Return(func)});
+ ASSERT_TRUE(IRIsValid()) << Error();
+
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%2 = OpTypeVoid
@@ -180,7 +192,7 @@
b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{b.Declare(
ty.pointer(ty.i32(), builtin::AddressSpace::kPrivate, builtin::Access::kReadWrite))});
- generator_.Generate();
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %4 "unused_entry_point"
@@ -204,7 +216,7 @@
b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{v});
v->SetInitializer(b.Constant(42_i));
- generator_.Generate();
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %5 "unused_entry_point"
@@ -230,7 +242,7 @@
v->SetInitializer(b.Constant(42_i));
mod.SetName(v, "myvar");
- generator_.Generate();
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %5 "unused_entry_point"
@@ -265,7 +277,7 @@
auto* store = b.Store(v, add);
func->StartTarget()->SetInstructions(utils::Vector{load, add, store, b.Return(func)});
- generator_.Generate();
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %5 "foo"
@@ -292,7 +304,7 @@
b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{b.Declare(
ty.pointer(ty.i32(), builtin::AddressSpace::kWorkgroup, builtin::Access::kReadWrite))});
- generator_.Generate();
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %4 "unused_entry_point"
@@ -316,7 +328,7 @@
b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{v});
mod.SetName(v, "myvar");
- generator_.Generate();
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %4 "unused_entry_point"
@@ -350,7 +362,7 @@
auto* store = b.Store(v, add);
func->StartTarget()->SetInstructions(utils::Vector{load, add, store, b.Return(func)});
- generator_.Generate();
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %4 "foo"
@@ -378,7 +390,7 @@
// Create a generator with the zero_init_workgroup_memory flag set to `true`.
spirv::GeneratorImplIr gen(&mod, true);
- gen.Generate();
+ ASSERT_TRUE(gen.Generate()) << gen.Diagnostics().str();
EXPECT_EQ(DumpModule(gen.Module()), R"(OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %5 "unused_entry_point"
@@ -397,5 +409,224 @@
)");
}
+TEST_F(SpvGeneratorImplTest, StorageVar) {
+ auto* v = b.Declare(
+ ty.pointer(ty.i32(), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite));
+ v->SetBindingPoint(0, 0);
+ b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{v});
+
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %5 "unused_entry_point"
+OpExecutionMode %5 LocalSize 1 1 1
+OpMemberName %3 0 "tint_symbol"
+OpName %3 "tint_symbol_1"
+OpName %5 "unused_entry_point"
+OpMemberDecorate %3 0 Offset 0
+OpDecorate %3 Block
+OpDecorate %1 DescriptorSet 0
+OpDecorate %1 Binding 0
+%4 = OpTypeInt 32 1
+%3 = OpTypeStruct %4
+%2 = OpTypePointer StorageBuffer %3
+%1 = OpVariable %2 StorageBuffer
+%6 = OpTypeVoid
+%7 = OpTypeFunction %6
+%5 = OpFunction %6 None %7
+%8 = OpLabel
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest, StorageVar_Name) {
+ auto* v = b.Declare(
+ ty.pointer(ty.i32(), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite));
+ v->SetBindingPoint(0, 0);
+ b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{v});
+ mod.SetName(v, "myvar");
+
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %5 "unused_entry_point"
+OpExecutionMode %5 LocalSize 1 1 1
+OpMemberName %3 0 "tint_symbol"
+OpName %3 "tint_symbol_1"
+OpName %5 "unused_entry_point"
+OpMemberDecorate %3 0 Offset 0
+OpDecorate %3 Block
+OpDecorate %1 DescriptorSet 0
+OpDecorate %1 Binding 0
+%4 = OpTypeInt 32 1
+%3 = OpTypeStruct %4
+%2 = OpTypePointer StorageBuffer %3
+%1 = OpVariable %2 StorageBuffer
+%6 = OpTypeVoid
+%7 = OpTypeFunction %6
+%5 = OpFunction %6 None %7
+%8 = OpLabel
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest, StorageVar_LoadAndStore) {
+ auto* v = b.Declare(
+ ty.pointer(ty.i32(), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite));
+ v->SetBindingPoint(0, 0);
+ b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{v});
+
+ auto* func = b.CreateFunction("foo", ty.void_(), ir::Function::PipelineStage::kCompute,
+ std::array{1u, 1u, 1u});
+ mod.functions.Push(func);
+
+ auto* load = b.Load(v);
+ auto* add = b.Add(ty.i32(), v, b.Constant(1_i));
+ auto* store = b.Store(v, add);
+ func->StartTarget()->SetInstructions(utils::Vector{load, add, store, b.Return(func)});
+
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %5 "foo"
+OpExecutionMode %5 LocalSize 1 1 1
+OpMemberName %3 0 "tint_symbol"
+OpName %3 "tint_symbol_1"
+OpName %5 "foo"
+OpMemberDecorate %3 0 Offset 0
+OpDecorate %3 Block
+OpDecorate %1 DescriptorSet 0
+OpDecorate %1 Binding 0
+%4 = OpTypeInt 32 1
+%3 = OpTypeStruct %4
+%2 = OpTypePointer StorageBuffer %3
+%1 = OpVariable %2 StorageBuffer
+%6 = OpTypeVoid
+%7 = OpTypeFunction %6
+%10 = OpTypePointer StorageBuffer %4
+%12 = OpTypeInt 32 0
+%11 = OpConstant %12 0
+%16 = OpConstant %4 1
+%5 = OpFunction %6 None %7
+%8 = OpLabel
+%9 = OpAccessChain %10 %1 %11
+%13 = OpLoad %4 %9
+%14 = OpAccessChain %10 %1 %11
+%15 = OpIAdd %4 %14 %16
+%17 = OpAccessChain %10 %1 %11
+OpStore %17 %15
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest, UniformVar) {
+ auto* v = b.Declare(
+ ty.pointer(ty.i32(), builtin::AddressSpace::kUniform, builtin::Access::kReadWrite));
+ v->SetBindingPoint(0, 0);
+ b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{v});
+
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %5 "unused_entry_point"
+OpExecutionMode %5 LocalSize 1 1 1
+OpMemberName %3 0 "tint_symbol"
+OpName %3 "tint_symbol_1"
+OpName %5 "unused_entry_point"
+OpMemberDecorate %3 0 Offset 0
+OpDecorate %3 Block
+OpDecorate %1 DescriptorSet 0
+OpDecorate %1 Binding 0
+%4 = OpTypeInt 32 1
+%3 = OpTypeStruct %4
+%2 = OpTypePointer Uniform %3
+%1 = OpVariable %2 Uniform
+%6 = OpTypeVoid
+%7 = OpTypeFunction %6
+%5 = OpFunction %6 None %7
+%8 = OpLabel
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest, UniformVar_Name) {
+ auto* v = b.Declare(
+ ty.pointer(ty.i32(), builtin::AddressSpace::kUniform, builtin::Access::kReadWrite));
+ v->SetBindingPoint(0, 0);
+ b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{v});
+ mod.SetName(v, "myvar");
+
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %5 "unused_entry_point"
+OpExecutionMode %5 LocalSize 1 1 1
+OpMemberName %3 0 "tint_symbol"
+OpName %3 "tint_symbol_1"
+OpName %5 "unused_entry_point"
+OpMemberDecorate %3 0 Offset 0
+OpDecorate %3 Block
+OpDecorate %1 DescriptorSet 0
+OpDecorate %1 Binding 0
+%4 = OpTypeInt 32 1
+%3 = OpTypeStruct %4
+%2 = OpTypePointer Uniform %3
+%1 = OpVariable %2 Uniform
+%6 = OpTypeVoid
+%7 = OpTypeFunction %6
+%5 = OpFunction %6 None %7
+%8 = OpLabel
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest, UniformVar_Load) {
+ auto* v = b.Declare(
+ ty.pointer(ty.i32(), builtin::AddressSpace::kUniform, builtin::Access::kReadWrite));
+ v->SetBindingPoint(0, 0);
+ b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{v});
+
+ auto* func = b.CreateFunction("foo", ty.void_(), ir::Function::PipelineStage::kCompute,
+ std::array{1u, 1u, 1u});
+ mod.functions.Push(func);
+
+ auto* load = b.Load(v);
+ func->StartTarget()->SetInstructions(utils::Vector{load, b.Return(func)});
+
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %5 "foo"
+OpExecutionMode %5 LocalSize 1 1 1
+OpMemberName %3 0 "tint_symbol"
+OpName %3 "tint_symbol_1"
+OpName %5 "foo"
+OpMemberDecorate %3 0 Offset 0
+OpDecorate %3 Block
+OpDecorate %1 DescriptorSet 0
+OpDecorate %1 Binding 0
+%4 = OpTypeInt 32 1
+%3 = OpTypeStruct %4
+%2 = OpTypePointer Uniform %3
+%1 = OpVariable %2 Uniform
+%6 = OpTypeVoid
+%7 = OpTypeFunction %6
+%10 = OpTypePointer Uniform %4
+%12 = OpTypeInt 32 0
+%11 = OpConstant %12 0
+%5 = OpFunction %6 None %7
+%8 = OpLabel
+%9 = OpAccessChain %10 %1 %11
+%13 = OpLoad %4 %9
+OpReturn
+OpFunctionEnd
+)");
+}
+
} // namespace
} // namespace tint::writer::spirv
diff --git a/src/tint/writer/spirv/ir/test_helper_ir.h b/src/tint/writer/spirv/ir/test_helper_ir.h
index 791dfca..08f8246 100644
--- a/src/tint/writer/spirv/ir/test_helper_ir.h
+++ b/src/tint/writer/spirv/ir/test_helper_ir.h
@@ -19,6 +19,7 @@
#include "gtest/gtest.h"
#include "src/tint/ir/builder.h"
+#include "src/tint/ir/validate.h"
#include "src/tint/writer/spirv/ir/generator_impl_ir.h"
#include "src/tint/writer/spirv/spv_dump.h"
@@ -50,6 +51,22 @@
/// The SPIR-V generator.
GeneratorImplIr generator_;
+ /// Validation errors
+ std::string err_;
+
+ /// @returns the error string from the validation
+ std::string Error() const { return err_; }
+
+ /// @returns true if the IR module is valid
+ bool IRIsValid() {
+ auto res = ir::Validate(mod);
+ if (!res) {
+ err_ = res.Failure().str();
+ return false;
+ }
+ return true;
+ }
+
/// @returns the disassembled types from the generated module.
std::string DumpTypes() { return DumpInstructions(generator_.Module().Types()); }