Import Tint changes from Dawn
Changes:
- 4b35f52f9b33eb358a01170ad2fd334f2b5b4347 [ir] Remove IsConnected by dan sinclair <dsinclair@chromium.org>
- a00fe39f3e823ad38aeb2c575daf94848977a961 [ir] Remove FlowNode. by dan sinclair <dsinclair@chromium.org>
- 0eb4d04d83e8f1565e9dc527f6d84fdbc0268aa8 [ir][spirv-writer] Expand binary arithmetic tests by James Price <jrprice@google.com>
- dd7b3141052053ce7e83057b6f82e4da9ec846af [ir][spirv-writer] Emit binary bitwise operators by James Price <jrprice@google.com>
- 59339216a1f625bce1eded49022c286e1cd53aa3 [ir][spirv-writer] Emit comparison instructions by James Price <jrprice@google.com>
- 1e67e5368d30d5c0ea8669b9c254f03ad86617f0 [tint][constant] Use the new constant::Manager by Ben Clayton <bclayton@google.com>
- a71bd22de1263cbf5e7774e0abad9653b77cec42 [tint][constant] Make Value::Clone() return a const pointer by Ben Clayton <bclayton@google.com>
- 9aa8012d91e5603650532ad1bd90f8d38b3583e2 [tint][constant] Add constant::Manager by Ben Clayton <bclayton@google.com>
- 56e5fb57ef241ce5c9937faacf64d792bccd5856 [ir] Convert function to Value. by dan sinclair <dsinclair@chromium.org>
- a518707db4a6d785ce82ea359b9f0a8d426937e3 [ir] Update disassembly output. by dan sinclair <dsinclair@chromium.org>
- f55ef5e48b14121017d315aa4c18c4b3e58c61a1 [ir] Convert FlowNode to Block where possible. by dan sinclair <dsinclair@chromium.org>
- 0089d5e6e2f9915b8867f68b758c8b467e2116b4 [ir] Change base class for terminators. by dan sinclair <dsinclair@chromium.org>
- e9a4adeff959518ee876ee61734c93967ed906e8 [ir] Remove Jump. by dan sinclair <dsinclair@chromium.org>
- 24c5ed6b0a5ab8ad7b5f708fa6f89dd1ac777e07 [ir][spirv-writer] Emit matrix constants by James Price <jrprice@google.com>
- a8c528052d67f37b9a1e7c7cb9a4e10d927e065b [ir][spirv-writer] Emit matrix types by James Price <jrprice@google.com>
- b54b58d57df2e26e0b6dd8aa87e730b1f7f818f8 [ir][spirv-writer] Implement user function calls by James Price <jrprice@google.com>
- c1fd6316de4817f1d1a5caef97a5623e384e472c [fuzzers] Substitute all override variables by James Price <jrprice@google.com>
- 1ea7f0f351856153e7f51b7a2695da658df38600 [ir][spirv-writer] Add support for parameters by James Price <jrprice@google.com>
- f5a62539f4ea16999fbdac588793aa3a47f0d8b0 [ir][spirv-writer] Add support for return values by James Price <jrprice@google.com>
- 23c0451377ec6c31c5d5b898595d860b605992de [ir] Add scalar type helpers to type::Manager by James Price <jrprice@google.com>
GitOrigin-RevId: 4b35f52f9b33eb358a01170ad2fd334f2b5b4347
Change-Id: I7b21e9697ef4855fefaac7fa557cc341d6b495f4
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/134380
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index b14851f..7e2444c 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -894,6 +894,8 @@
"constant/clone_context.h",
"constant/composite.cc",
"constant/composite.h",
+ "constant/manager.cc",
+ "constant/manager.h",
"constant/node.cc",
"constant/node.h",
"constant/scalar.cc",
@@ -1230,8 +1232,6 @@
"ir/disassembler.h",
"ir/discard.cc",
"ir/discard.h",
- "ir/flow_node.cc",
- "ir/flow_node.h",
"ir/function.cc",
"ir/function.h",
"ir/function_param.cc",
@@ -1242,8 +1242,6 @@
"ir/if.h",
"ir/instruction.cc",
"ir/instruction.h",
- "ir/jump.cc",
- "ir/jump.h",
"ir/load.cc",
"ir/load.h",
"ir/loop.cc",
@@ -1683,6 +1681,7 @@
tint_unittests_source_set("tint_unittests_constant_src") {
sources = [
"constant/composite_test.cc",
+ "constant/manager_test.cc",
"constant/scalar_test.cc",
"constant/splat_test.cc",
]
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index 979c434..ffb4d9d 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -237,6 +237,8 @@
constant/clone_context.h
constant/composite.cc
constant/composite.h
+ constant/manager.cc
+ constant/manager.h
constant/scalar.cc
constant/scalar.h
constant/splat.cc
@@ -740,8 +742,6 @@
ir/discard.h
ir/from_program.cc
ir/from_program.h
- ir/flow_node.cc
- ir/flow_node.h
ir/function.cc
ir/function.h
ir/function_param.cc
@@ -752,8 +752,6 @@
ir/if.h
ir/instruction.cc
ir/instruction.h
- ir/jump.cc
- ir/jump.h
ir/load.cc
ir/load.h
ir/loop.cc
@@ -925,6 +923,7 @@
ast/workgroup_attribute_test.cc
clone_context_test.cc
constant/composite_test.cc
+ constant/manager_test.cc
constant/scalar_test.cc
constant/splat_test.cc
debug_test.cc
diff --git a/src/tint/constant/clone_context.h b/src/tint/constant/clone_context.h
index 3f597ed..5709baf 100644
--- a/src/tint/constant/clone_context.h
+++ b/src/tint/constant/clone_context.h
@@ -16,11 +16,10 @@
#define SRC_TINT_CONSTANT_CLONE_CONTEXT_H_
#include "src/tint/type/clone_context.h"
-#include "src/tint/utils/block_allocator.h"
-// Forward Declarations
+// Forward declarations
namespace tint::constant {
-class Value;
+class Manager;
} // namespace tint::constant
namespace tint::constant {
@@ -31,10 +30,7 @@
type::CloneContext type_ctx;
/// Destination information
- struct {
- /// The constant allocator
- utils::BlockAllocator<constant::Value>* constants;
- } dst;
+ constant::Manager& dst;
};
} // namespace tint::constant
diff --git a/src/tint/constant/composite.cc b/src/tint/constant/composite.cc
index 5b004e0..488f6f0 100644
--- a/src/tint/constant/composite.cc
+++ b/src/tint/constant/composite.cc
@@ -16,6 +16,8 @@
#include <utility>
+#include "src/tint/constant/manager.h"
+
TINT_INSTANTIATE_TYPEINFO(tint::constant::Composite);
namespace tint::constant {
@@ -24,17 +26,19 @@
utils::VectorRef<const constant::Value*> els,
bool all_0,
bool any_0)
- : type(t), elements(std::move(els)), all_zero(all_0), any_zero(any_0), hash(CalcHash()) {}
+ : type(t), elements(std::move(els)), all_zero(all_0), any_zero(any_0), hash(CalcHash()) {
+ TINT_ASSERT(Constant, !elements.IsEmpty());
+}
Composite::~Composite() = default;
-Composite* Composite::Clone(CloneContext& ctx) const {
+const Composite* Composite::Clone(CloneContext& ctx) const {
auto* ty = type->Clone(ctx.type_ctx);
utils::Vector<const constant::Value*, 4> els;
for (const auto* el : elements) {
els.Push(el->Clone(ctx));
}
- return ctx.dst.constants->Create<Composite>(ty, els, all_zero, any_zero);
+ return ctx.dst.Get<Composite>(ty, std::move(els), all_zero, any_zero);
}
} // namespace tint::constant
diff --git a/src/tint/constant/composite.h b/src/tint/constant/composite.h
index 7a50640..23367a5 100644
--- a/src/tint/constant/composite.h
+++ b/src/tint/constant/composite.h
@@ -61,7 +61,7 @@
/// Clones the constant into the provided context
/// @param ctx the clone context
/// @returns the cloned node
- Composite* Clone(CloneContext& ctx) const override;
+ const Composite* Clone(CloneContext& ctx) const override;
/// The composite type
type::Type const* const type;
diff --git a/src/tint/constant/composite_test.cc b/src/tint/constant/composite_test.cc
index fd083db..6a4c12d 100644
--- a/src/tint/constant/composite_test.cc
+++ b/src/tint/constant/composite_test.cc
@@ -25,15 +25,15 @@
using ConstantTest_Composite = TestHelper;
TEST_F(ConstantTest_Composite, AllZero) {
- auto* f32 = create<type::F32>();
+ auto* vec3f = create<type::Vector>(create<type::F32>(), 3u);
- auto* fPos0 = create<Scalar<tint::f32>>(f32, 0_f);
- auto* fNeg0 = create<Scalar<tint::f32>>(f32, -0_f);
- auto* fPos1 = create<Scalar<tint::f32>>(f32, 1_f);
+ auto* fPos0 = constants.Get(0_f);
+ auto* fNeg0 = constants.Get(-0_f);
+ auto* fPos1 = constants.Get(1_f);
- auto* compositeAll = create<Composite>(f32, utils::Vector{fPos0, fPos0});
- auto* compositeAny = create<Composite>(f32, utils::Vector{fNeg0, fPos1, fPos0});
- auto* compositeNone = create<Composite>(f32, utils::Vector{fNeg0, fNeg0});
+ auto* compositeAll = constants.Composite(vec3f, utils::Vector{fPos0, fPos0});
+ auto* compositeAny = constants.Composite(vec3f, utils::Vector{fNeg0, fPos1, fPos0});
+ auto* compositeNone = constants.Composite(vec3f, utils::Vector{fNeg0, fNeg0});
EXPECT_TRUE(compositeAll->AllZero());
EXPECT_FALSE(compositeAny->AllZero());
@@ -41,15 +41,15 @@
}
TEST_F(ConstantTest_Composite, AnyZero) {
- auto* f32 = create<type::F32>();
+ auto* vec3f = create<type::Vector>(create<type::F32>(), 3u);
- auto* fPos0 = create<Scalar<tint::f32>>(f32, 0_f);
- auto* fNeg0 = create<Scalar<tint::f32>>(f32, -0_f);
- auto* fPos1 = create<Scalar<tint::f32>>(f32, 1_f);
+ auto* fPos0 = constants.Get(0_f);
+ auto* fNeg0 = constants.Get(-0_f);
+ auto* fPos1 = constants.Get(1_f);
- auto* compositeAll = create<Composite>(f32, utils::Vector{fPos0, fPos0});
- auto* compositeAny = create<Composite>(f32, utils::Vector{fNeg0, fPos1, fPos0});
- auto* compositeNone = create<Composite>(f32, utils::Vector{fNeg0, fNeg0});
+ auto* compositeAll = constants.Composite(vec3f, utils::Vector{fPos0, fPos0});
+ auto* compositeAny = constants.Composite(vec3f, utils::Vector{fNeg0, fPos1, fPos0});
+ auto* compositeNone = constants.Composite(vec3f, utils::Vector{fNeg0, fNeg0});
EXPECT_TRUE(compositeAll->AnyZero());
EXPECT_TRUE(compositeAny->AnyZero());
@@ -57,12 +57,12 @@
}
TEST_F(ConstantTest_Composite, Index) {
- auto* f32 = create<type::F32>();
+ auto* vec3f = create<type::Vector>(create<type::F32>(), 3u);
- auto* fPos0 = create<Scalar<tint::f32>>(f32, 0_f);
- auto* fPos1 = create<Scalar<tint::f32>>(f32, 1_f);
+ auto* fPos0 = constants.Get(0_f);
+ auto* fPos1 = constants.Get(1_f);
- auto* composite = create<Composite>(f32, utils::Vector{fPos1, fPos0});
+ auto* composite = constants.Composite(vec3f, utils::Vector{fPos1, fPos0});
ASSERT_NE(composite->Index(0), nullptr);
ASSERT_NE(composite->Index(1), nullptr);
@@ -75,20 +75,19 @@
}
TEST_F(ConstantTest_Composite, Clone) {
- auto* f32 = create<type::F32>();
+ auto* vec3f = create<type::Vector>(create<type::F32>(), 3u);
- auto* fPos0 = create<Scalar<tint::f32>>(f32, 0_f);
- auto* fPos1 = create<Scalar<tint::f32>>(f32, 1_f);
+ auto* fPos0 = constants.Get(0_f);
+ auto* fPos1 = constants.Get(1_f);
- auto* composite = create<Composite>(f32, utils::Vector{fPos1, fPos0});
+ auto* composite = constants.Composite(vec3f, utils::Vector{fPos1, fPos0});
- type::Manager mgr;
- utils::BlockAllocator<constant::Value> consts;
- constant::CloneContext ctx{type::CloneContext{{nullptr}, {nullptr, &mgr}}, {&consts}};
+ constant::Manager mgr;
+ constant::CloneContext ctx{type::CloneContext{{nullptr}, {nullptr, &mgr.types}}, mgr};
auto* r = composite->As<Composite>()->Clone(ctx);
ASSERT_NE(r, nullptr);
- EXPECT_TRUE(r->type->Is<type::F32>());
+ EXPECT_TRUE(r->type->Is<type::Vector>());
EXPECT_FALSE(r->all_zero);
EXPECT_TRUE(r->any_zero);
ASSERT_EQ(r->elements.Length(), 2u);
diff --git a/src/tint/constant/manager.cc b/src/tint/constant/manager.cc
new file mode 100644
index 0000000..829edb2
--- /dev/null
+++ b/src/tint/constant/manager.cc
@@ -0,0 +1,105 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/constant/manager.h"
+
+#include "src/tint/constant/composite.h"
+#include "src/tint/constant/scalar.h"
+#include "src/tint/constant/splat.h"
+#include "src/tint/type/abstract_float.h"
+#include "src/tint/type/abstract_int.h"
+#include "src/tint/type/bool.h"
+#include "src/tint/type/f16.h"
+#include "src/tint/type/f32.h"
+#include "src/tint/type/i32.h"
+#include "src/tint/type/manager.h"
+#include "src/tint/type/u32.h"
+#include "src/tint/utils/predicates.h"
+
+namespace tint::constant {
+
+Manager::Manager() = default;
+
+Manager::Manager(Manager&&) = default;
+
+Manager& Manager::operator=(Manager&& rhs) = default;
+
+Manager::~Manager() = default;
+
+const constant::Value* Manager::Composite(const type::Type* type,
+ utils::VectorRef<const constant::Value*> elements) {
+ if (elements.IsEmpty()) {
+ return nullptr;
+ }
+
+ bool any_zero = false;
+ bool all_zero = true;
+ bool all_equal = true;
+ auto* first = elements.Front();
+ for (auto* el : elements) {
+ if (!el) {
+ return nullptr;
+ }
+ if (!any_zero && el->AnyZero()) {
+ any_zero = true;
+ }
+ if (all_zero && !el->AllZero()) {
+ all_zero = false;
+ }
+ if (all_equal && el != first) {
+ all_equal = false;
+ }
+ }
+ if (all_equal) {
+ return Splat(type, elements.Front(), elements.Length());
+ }
+
+ return Get<constant::Composite>(type, std::move(elements), all_zero, any_zero);
+}
+
+const constant::Splat* Manager::Splat(const type::Type* type,
+ const constant::Value* element,
+ size_t n) {
+ return Get<constant::Splat>(type, element, n);
+}
+
+const Scalar<i32>* Manager::Get(i32 value) {
+ return Get<Scalar<i32>>(types.i32(), value);
+}
+
+const Scalar<u32>* Manager::Get(u32 value) {
+ return Get<Scalar<u32>>(types.u32(), value);
+}
+
+const Scalar<f32>* Manager::Get(f32 value) {
+ return Get<Scalar<f32>>(types.f32(), value);
+}
+
+const Scalar<f16>* Manager::Get(f16 value) {
+ return Get<Scalar<f16>>(types.f16(), value);
+}
+
+const Scalar<bool>* Manager::Get(bool value) {
+ return Get<Scalar<bool>>(types.bool_(), value);
+}
+
+const Scalar<AFloat>* Manager::Get(AFloat value) {
+ return Get<Scalar<AFloat>>(types.AFloat(), value);
+}
+
+const Scalar<AInt>* Manager::Get(AInt value) {
+ return Get<Scalar<AInt>>(types.AInt(), value);
+}
+
+} // namespace tint::constant
diff --git a/src/tint/constant/manager.h b/src/tint/constant/manager.h
new file mode 100644
index 0000000..d356988
--- /dev/null
+++ b/src/tint/constant/manager.h
@@ -0,0 +1,157 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_CONSTANT_MANAGER_H_
+#define SRC_TINT_CONSTANT_MANAGER_H_
+
+#include <utility>
+
+#include "src/tint/constant/value.h"
+#include "src/tint/number.h"
+#include "src/tint/type/manager.h"
+#include "src/tint/utils/hash.h"
+#include "src/tint/utils/unique_allocator.h"
+
+namespace tint::constant {
+class Splat;
+
+template <typename T>
+class Scalar;
+} // namespace tint::constant
+
+namespace tint::constant {
+
+/// The constant manager holds a type manager and all the pointers to the known constant values.
+class Manager final {
+ public:
+ /// Iterator is the type returned by begin() and end()
+ using TypeIterator = utils::BlockAllocator<Value>::ConstIterator;
+
+ /// Constructor
+ Manager();
+
+ /// Move constructor
+ Manager(Manager&&);
+
+ /// Move assignment operator
+ /// @param rhs the Manager to move
+ /// @return this Manager
+ Manager& operator=(Manager&& rhs);
+
+ /// Destructor
+ ~Manager();
+
+ /// Wrap returns a new Manager created with the constants and types of `inner`.
+ /// The Manager returned by Wrap is intended to temporarily extend the constants and types of an
+ /// existing immutable Manager. As the copied constants and types are owned by `inner`, `inner`
+ /// must not be destructed or assigned while using the returned Manager.
+ /// TODO(bclayton) - Evaluate whether there are safer alternatives to this
+ /// function. See crbug.com/tint/460.
+ /// @param inner the immutable Manager to extend
+ /// @return the Manager that wraps `inner`
+ static Manager Wrap(const Manager& inner) {
+ Manager out;
+ out.values_.Wrap(inner.values_);
+ out.types = type::Manager::Wrap(inner.types);
+ return out;
+ }
+
+ /// @param args the arguments used to construct the type, unique node or node.
+ /// @return a pointer to an instance of `T` with the provided arguments.
+ /// If NODE derives from UniqueNode and an existing instance of `T` has been
+ /// constructed, then the same pointer is returned.
+ template <typename NODE, typename... ARGS>
+ NODE* Get(ARGS&&... args) {
+ return values_.Get<NODE>(std::forward<ARGS>(args)...);
+ }
+
+ /// @returns an iterator to the beginning of the types
+ TypeIterator begin() const { return values_.begin(); }
+ /// @returns an iterator to the end of the types
+ TypeIterator end() const { return values_.end(); }
+
+ /// Constructs a constant of a vector, matrix or array type.
+ ///
+ /// Examines the element values and will return either a constant::Composite or a
+ /// constant::Splat, depending on the element types and values.
+ ///
+ /// @param type the composite type
+ /// @param elements the composite elements
+ /// @returns the value pointer
+ const constant::Value* Composite(const type::Type* type,
+ utils::VectorRef<const constant::Value*> elements);
+
+ /// Constructs a splat constant.
+ /// @param type the splat type
+ /// @param element the splat element
+ /// @param n the number of elements
+ /// @returns the value pointer
+ const constant::Splat* Splat(const type::Type* type, const constant::Value* element, size_t n);
+
+ /// @param value the constant value
+ /// @return a Scalar holding the i32 value @p value
+ const Scalar<i32>* Get(i32 value);
+
+ /// @param value the constant value
+ /// @return a Scalar holding the u32 value @p value
+ const Scalar<u32>* Get(u32 value);
+
+ /// @param value the constant value
+ /// @return a Scalar holding the f32 value @p value
+ const Scalar<f32>* Get(f32 value);
+
+ /// @param value the constant value
+ /// @return a Scalar holding the f16 value @p value
+ const Scalar<f16>* Get(f16 value);
+
+ /// @param value the constant value
+ /// @return a Scalar holding the bool value @p value
+ const Scalar<bool>* Get(bool value);
+
+ /// @param value the constant value
+ /// @return a Scalar holding the AFloat value @p value
+ const Scalar<AFloat>* Get(AFloat value);
+
+ /// @param value the constant value
+ /// @return a Scalar holding the AInt value @p value
+ const Scalar<AInt>* Get(AInt value);
+
+ /// The type manager
+ type::Manager types;
+
+ private:
+ /// A specialization of utils::Hasher for constant::Value
+ struct Hasher {
+ /// @param value the value to hash
+ /// @returns a hash of the value
+ size_t operator()(const constant::Value& value) const { return value.Hash(); }
+ };
+
+ /// An equality helper for constant::Value
+ struct Equal {
+ /// @param a the LHS value
+ /// @param b the RHS value
+ /// @returns true if the two constants are equal
+ bool operator()(const constant::Value& a, const constant::Value& b) const {
+ return a.Equal(&b);
+ }
+ };
+
+ /// Unique types owned by the manager
+ utils::UniqueAllocator<Value, Hasher, Equal> values_;
+};
+
+} // namespace tint::constant
+
+#endif // SRC_TINT_CONSTANT_MANAGER_H_
diff --git a/src/tint/constant/manager_test.cc b/src/tint/constant/manager_test.cc
new file mode 100644
index 0000000..ab6135c
--- /dev/null
+++ b/src/tint/constant/manager_test.cc
@@ -0,0 +1,187 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/constant/manager.h"
+
+#include "gtest/gtest.h"
+#include "src/tint/constant/scalar.h"
+#include "src/tint/type/abstract_float.h"
+#include "src/tint/type/abstract_int.h"
+#include "src/tint/type/bool.h"
+#include "src/tint/type/f16.h"
+#include "src/tint/type/f32.h"
+#include "src/tint/type/i32.h"
+#include "src/tint/type/manager.h"
+#include "src/tint/type/u32.h"
+
+namespace tint::constant {
+namespace {
+
+using namespace tint::number_suffixes; // NOLINT
+
+template <typename T>
+size_t count(const T& range_loopable) {
+ size_t n = 0;
+ for (auto it : range_loopable) {
+ (void)it;
+ n++;
+ }
+ return n;
+}
+
+using ManagerTest = testing::Test;
+
+TEST_F(ManagerTest, GetUnregistered) {
+ constant::Manager cm;
+
+ auto* c = cm.Get(1_i);
+ static_assert(std::is_same_v<const Scalar<i32>*, decltype(c)>);
+ ASSERT_NE(c, nullptr);
+}
+
+TEST_F(ManagerTest, GetSameConstantReturnsSamePtr) {
+ constant::Manager cm;
+
+ auto* c1 = cm.Get(1_i);
+ static_assert(std::is_same_v<const Scalar<i32>*, decltype(c1)>);
+ ASSERT_NE(c1, nullptr);
+
+ auto* c2 = cm.Get(1_i);
+ static_assert(std::is_same_v<const Scalar<i32>*, decltype(c2)>);
+ EXPECT_EQ(c1, c2);
+ EXPECT_EQ(c1->Type(), c2->Type());
+}
+
+TEST_F(ManagerTest, GetDifferentTypeReturnsDifferentPtr) {
+ constant::Manager cm;
+
+ auto* c1 = cm.Get(1_i);
+ static_assert(std::is_same_v<const Scalar<i32>*, decltype(c1)>);
+ ASSERT_NE(c1, nullptr);
+
+ auto* c2 = cm.Get(1_u);
+ static_assert(std::is_same_v<const Scalar<u32>*, decltype(c2)>);
+ EXPECT_NE(static_cast<const Value*>(c1), static_cast<const Value*>(c2));
+ EXPECT_NE(c1->Type(), c2->Type());
+}
+
+TEST_F(ManagerTest, GetDifferentValueReturnsDifferentPtr) {
+ constant::Manager cm;
+
+ auto* c1 = cm.Get(1_i);
+ static_assert(std::is_same_v<const Scalar<i32>*, decltype(c1)>);
+ ASSERT_NE(c1, nullptr);
+
+ auto* c2 = cm.Get(2_i);
+ static_assert(std::is_same_v<const Scalar<i32>*, decltype(c2)>);
+ ASSERT_NE(c2, nullptr);
+ EXPECT_NE(c1, c2);
+ EXPECT_EQ(c1->Type(), c2->Type());
+}
+
+TEST_F(ManagerTest, Get_i32) {
+ constant::Manager cm;
+
+ auto* c = cm.Get(1_i);
+ static_assert(std::is_same_v<const Scalar<i32>*, decltype(c)>);
+ ASSERT_TRUE(Is<type::I32>(c->Type()));
+ EXPECT_EQ(c->value, 1_i);
+}
+
+TEST_F(ManagerTest, Get_u32) {
+ constant::Manager cm;
+
+ auto* c = cm.Get(1_u);
+ static_assert(std::is_same_v<const Scalar<u32>*, decltype(c)>);
+ ASSERT_TRUE(Is<type::U32>(c->Type()));
+ EXPECT_EQ(c->value, 1_u);
+}
+
+TEST_F(ManagerTest, Get_f32) {
+ constant::Manager cm;
+
+ auto* c = cm.Get(1_f);
+ static_assert(std::is_same_v<const Scalar<f32>*, decltype(c)>);
+ ASSERT_TRUE(Is<type::F32>(c->Type()));
+ EXPECT_EQ(c->value, 1_f);
+}
+
+TEST_F(ManagerTest, Get_f16) {
+ constant::Manager cm;
+
+ auto* c = cm.Get(1_h);
+ static_assert(std::is_same_v<const Scalar<f16>*, decltype(c)>);
+ ASSERT_TRUE(Is<type::F16>(c->Type()));
+ EXPECT_EQ(c->value, 1_h);
+}
+
+TEST_F(ManagerTest, Get_bool) {
+ constant::Manager cm;
+
+ auto* c = cm.Get(true);
+ static_assert(std::is_same_v<const Scalar<bool>*, decltype(c)>);
+ ASSERT_TRUE(Is<type::Bool>(c->Type()));
+ EXPECT_EQ(c->value, true);
+}
+
+TEST_F(ManagerTest, Get_AFloat) {
+ constant::Manager cm;
+
+ auto* c = cm.Get(1._a);
+ static_assert(std::is_same_v<const Scalar<AFloat>*, decltype(c)>);
+ ASSERT_TRUE(Is<type::AbstractFloat>(c->Type()));
+ EXPECT_EQ(c->value, 1._a);
+}
+
+TEST_F(ManagerTest, Get_AInt) {
+ constant::Manager cm;
+
+ auto* c = cm.Get(1_a);
+ static_assert(std::is_same_v<const Scalar<AInt>*, decltype(c)>);
+ ASSERT_TRUE(Is<type::AbstractInt>(c->Type()));
+ EXPECT_EQ(c->value, 1_a);
+}
+
+TEST_F(ManagerTest, WrapDoesntAffectInner_Constant) {
+ Manager inner;
+ Manager outer = Manager::Wrap(inner);
+
+ inner.Get(1_i);
+
+ EXPECT_EQ(count(inner), 1u);
+ EXPECT_EQ(count(outer), 0u);
+
+ outer.Get(1_i);
+
+ EXPECT_EQ(count(inner), 1u);
+ EXPECT_EQ(count(outer), 1u);
+}
+
+TEST_F(ManagerTest, WrapDoesntAffectInner_Types) {
+ Manager inner;
+ Manager outer = Manager::Wrap(inner);
+
+ inner.types.Get<type::I32>();
+
+ EXPECT_EQ(count(inner.types), 1u);
+ EXPECT_EQ(count(outer.types), 0u);
+
+ outer.types.Get<type::U32>();
+
+ EXPECT_EQ(count(inner.types), 1u);
+ EXPECT_EQ(count(outer.types), 1u);
+}
+
+} // namespace
+} // namespace tint::constant
diff --git a/src/tint/constant/scalar.h b/src/tint/constant/scalar.h
index ab5f852..474c9ad 100644
--- a/src/tint/constant/scalar.h
+++ b/src/tint/constant/scalar.h
@@ -15,6 +15,7 @@
#ifndef SRC_TINT_CONSTANT_SCALAR_H_
#define SRC_TINT_CONSTANT_SCALAR_H_
+#include "src/tint/constant/manager.h"
#include "src/tint/constant/value.h"
#include "src/tint/number.h"
#include "src/tint/type/type.h"
@@ -61,9 +62,9 @@
/// Clones the constant into the provided context
/// @param ctx the clone context
/// @returns the cloned node
- Scalar* Clone(CloneContext& ctx) const override {
+ const Scalar* Clone(CloneContext& ctx) const override {
auto* ty = type->Clone(ctx.type_ctx);
- return ctx.dst.constants->Create<Scalar<T>>(ty, value);
+ return ctx.dst.Get<Scalar<T>>(ty, value);
}
/// @returns `value` if `T` is not a Number, otherwise ValueOf returns the inner value of the
diff --git a/src/tint/constant/scalar_test.cc b/src/tint/constant/scalar_test.cc
index 6dedf75..49881e5 100644
--- a/src/tint/constant/scalar_test.cc
+++ b/src/tint/constant/scalar_test.cc
@@ -24,40 +24,34 @@
using ConstantTest_Scalar = TestHelper;
TEST_F(ConstantTest_Scalar, AllZero) {
- auto* i32 = create<type::I32>();
- auto* u32 = create<type::U32>();
- auto* f16 = create<type::F16>();
- auto* f32 = create<type::F32>();
- auto* bool_ = create<type::Bool>();
+ auto* i0 = constants.Get(0_i);
+ auto* iPos1 = constants.Get(1_i);
+ auto* iNeg1 = constants.Get(-1_i);
- auto* i0 = create<Scalar<tint::i32>>(i32, 0_i);
- auto* iPos1 = create<Scalar<tint::i32>>(i32, 1_i);
- auto* iNeg1 = create<Scalar<tint::i32>>(i32, -1_i);
+ auto* u0 = constants.Get(0_u);
+ auto* u1 = constants.Get(1_u);
- auto* u0 = create<Scalar<tint::u32>>(u32, 0_u);
- auto* u1 = create<Scalar<tint::u32>>(u32, 1_u);
+ auto* fPos0 = constants.Get(0_f);
+ auto* fNeg0 = constants.Get(-0_f);
+ auto* fPos1 = constants.Get(1_f);
+ auto* fNeg1 = constants.Get(-1_f);
- auto* fPos0 = create<Scalar<tint::f32>>(f32, 0_f);
- auto* fNeg0 = create<Scalar<tint::f32>>(f32, -0_f);
- auto* fPos1 = create<Scalar<tint::f32>>(f32, 1_f);
- auto* fNeg1 = create<Scalar<tint::f32>>(f32, -1_f);
+ auto* f16Pos0 = constants.Get(0_h);
+ auto* f16Neg0 = constants.Get(-0_h);
+ auto* f16Pos1 = constants.Get(1_h);
+ auto* f16Neg1 = constants.Get(-1_h);
- auto* f16Pos0 = create<Scalar<tint::f16>>(f16, 0_h);
- auto* f16Neg0 = create<Scalar<tint::f16>>(f16, -0_h);
- auto* f16Pos1 = create<Scalar<tint::f16>>(f16, 1_h);
- auto* f16Neg1 = create<Scalar<tint::f16>>(f16, -1_h);
+ auto* bf = constants.Get(false);
+ auto* bt = constants.Get(true);
- auto* bf = create<Scalar<bool>>(bool_, false);
- auto* bt = create<Scalar<bool>>(bool_, true);
+ auto* afPos0 = constants.Get(0.0_a);
+ auto* afNeg0 = constants.Get(-0.0_a);
+ auto* afPos1 = constants.Get(1.0_a);
+ auto* afNeg1 = constants.Get(-1.0_a);
- auto* afPos0 = create<Scalar<tint::AFloat>>(f32, 0.0_a);
- auto* afNeg0 = create<Scalar<tint::AFloat>>(f32, -0.0_a);
- auto* afPos1 = create<Scalar<tint::AFloat>>(f32, 1.0_a);
- auto* afNeg1 = create<Scalar<tint::AFloat>>(f32, -1.0_a);
-
- auto* ai0 = create<Scalar<tint::AInt>>(i32, 0_a);
- auto* aiPos1 = create<Scalar<tint::AInt>>(i32, 1_a);
- auto* aiNeg1 = create<Scalar<tint::AInt>>(i32, -1_a);
+ auto* ai0 = constants.Get(0_a);
+ auto* aiPos1 = constants.Get(1_a);
+ auto* aiNeg1 = constants.Get(-1_a);
EXPECT_TRUE(i0->AllZero());
EXPECT_FALSE(iPos1->AllZero());
@@ -90,40 +84,34 @@
}
TEST_F(ConstantTest_Scalar, AnyZero) {
- auto* i32 = create<type::I32>();
- auto* u32 = create<type::U32>();
- auto* f16 = create<type::F16>();
- auto* f32 = create<type::F32>();
- auto* bool_ = create<type::Bool>();
+ auto* i0 = constants.Get(0_i);
+ auto* iPos1 = constants.Get(1_i);
+ auto* iNeg1 = constants.Get(-1_i);
- auto* i0 = create<Scalar<tint::i32>>(i32, 0_i);
- auto* iPos1 = create<Scalar<tint::i32>>(i32, 1_i);
- auto* iNeg1 = create<Scalar<tint::i32>>(i32, -1_i);
+ auto* u0 = constants.Get(0_u);
+ auto* u1 = constants.Get(1_u);
- auto* u0 = create<Scalar<tint::u32>>(u32, 0_u);
- auto* u1 = create<Scalar<tint::u32>>(u32, 1_u);
+ auto* fPos0 = constants.Get(0_f);
+ auto* fNeg0 = constants.Get(-0_f);
+ auto* fPos1 = constants.Get(1_f);
+ auto* fNeg1 = constants.Get(-1_f);
- auto* fPos0 = create<Scalar<tint::f32>>(f32, 0_f);
- auto* fNeg0 = create<Scalar<tint::f32>>(f32, -0_f);
- auto* fPos1 = create<Scalar<tint::f32>>(f32, 1_f);
- auto* fNeg1 = create<Scalar<tint::f32>>(f32, -1_f);
+ auto* f16Pos0 = constants.Get(0_h);
+ auto* f16Neg0 = constants.Get(-0_h);
+ auto* f16Pos1 = constants.Get(1_h);
+ auto* f16Neg1 = constants.Get(-1_h);
- auto* f16Pos0 = create<Scalar<tint::f16>>(f16, 0_h);
- auto* f16Neg0 = create<Scalar<tint::f16>>(f16, -0_h);
- auto* f16Pos1 = create<Scalar<tint::f16>>(f16, 1_h);
- auto* f16Neg1 = create<Scalar<tint::f16>>(f16, -1_h);
+ auto* bf = constants.Get(false);
+ auto* bt = constants.Get(true);
- auto* bf = create<Scalar<bool>>(bool_, false);
- auto* bt = create<Scalar<bool>>(bool_, true);
+ auto* afPos0 = constants.Get(0.0_a);
+ auto* afNeg0 = constants.Get(-0.0_a);
+ auto* afPos1 = constants.Get(1.0_a);
+ auto* afNeg1 = constants.Get(-1.0_a);
- auto* afPos0 = create<Scalar<tint::AFloat>>(f32, 0.0_a);
- auto* afNeg0 = create<Scalar<tint::AFloat>>(f32, -0.0_a);
- auto* afPos1 = create<Scalar<tint::AFloat>>(f32, 1.0_a);
- auto* afNeg1 = create<Scalar<tint::AFloat>>(f32, -1.0_a);
-
- auto* ai0 = create<Scalar<tint::AInt>>(i32, 0_a);
- auto* aiPos1 = create<Scalar<tint::AInt>>(i32, 1_a);
- auto* aiNeg1 = create<Scalar<tint::AInt>>(i32, -1_a);
+ auto* ai0 = constants.Get(0_a);
+ auto* aiPos1 = constants.Get(1_a);
+ auto* aiNeg1 = constants.Get(-1_a);
EXPECT_TRUE(i0->AnyZero());
EXPECT_FALSE(iPos1->AnyZero());
@@ -156,20 +144,14 @@
}
TEST_F(ConstantTest_Scalar, ValueOf) {
- auto* i32 = create<type::I32>();
- auto* u32 = create<type::U32>();
- auto* f16 = create<type::F16>();
- auto* f32 = create<type::F32>();
- auto* bool_ = create<type::Bool>();
-
- auto* i1 = create<Scalar<tint::i32>>(i32, 1_i);
- auto* u1 = create<Scalar<tint::u32>>(u32, 1_u);
- auto* f1 = create<Scalar<tint::f32>>(f32, 1_f);
- auto* f16Pos1 = create<Scalar<tint::f16>>(f16, 1_h);
- auto* bf = create<Scalar<bool>>(bool_, false);
- auto* bt = create<Scalar<bool>>(bool_, true);
- auto* af1 = create<Scalar<tint::AFloat>>(f32, 1.0_a);
- auto* ai1 = create<Scalar<tint::AInt>>(i32, 1_a);
+ auto* i1 = constants.Get(1_i);
+ auto* u1 = constants.Get(1_u);
+ auto* f1 = constants.Get(1_f);
+ auto* f16Pos1 = constants.Get(1_h);
+ auto* bf = constants.Get(false);
+ auto* bt = constants.Get(true);
+ auto* af1 = constants.Get(1.0_a);
+ auto* ai1 = constants.Get(1_a);
EXPECT_EQ(i1->ValueOf(), 1);
EXPECT_EQ(u1->ValueOf(), 1u);
@@ -182,12 +164,10 @@
}
TEST_F(ConstantTest_Scalar, Clone) {
- auto* i32 = create<type::I32>();
- auto* val = create<Scalar<tint::i32>>(i32, 12_i);
+ auto* val = constants.Get(12_i);
- type::Manager mgr;
- utils::BlockAllocator<constant::Value> consts;
- constant::CloneContext ctx{type::CloneContext{{nullptr}, {nullptr, &mgr}}, {&consts}};
+ constant::Manager mgr;
+ constant::CloneContext ctx{type::CloneContext{{nullptr}, {nullptr, &mgr.types}}, mgr};
auto* r = val->Clone(ctx);
ASSERT_NE(r, nullptr);
diff --git a/src/tint/constant/splat.cc b/src/tint/constant/splat.cc
index 4ebfe5f..8adbed7 100644
--- a/src/tint/constant/splat.cc
+++ b/src/tint/constant/splat.cc
@@ -14,6 +14,8 @@
#include "src/tint/constant/splat.h"
+#include "src/tint/constant/manager.h"
+
TINT_INSTANTIATE_TYPEINFO(tint::constant::Splat);
namespace tint::constant {
@@ -22,10 +24,10 @@
Splat::~Splat() = default;
-Splat* Splat::Clone(CloneContext& ctx) const {
+const Splat* Splat::Clone(CloneContext& ctx) const {
auto* ty = type->Clone(ctx.type_ctx);
auto* element = el->Clone(ctx);
- return ctx.dst.constants->Create<Splat>(ty, element, count);
+ return ctx.dst.Splat(ty, element, count);
}
} // namespace tint::constant
diff --git a/src/tint/constant/splat.h b/src/tint/constant/splat.h
index d8e55a6..2bf82dd 100644
--- a/src/tint/constant/splat.h
+++ b/src/tint/constant/splat.h
@@ -57,7 +57,7 @@
/// Clones the constant into the provided context
/// @param ctx the clone context
/// @returns the cloned node
- Splat* Clone(CloneContext& ctx) const override;
+ const Splat* Clone(CloneContext& ctx) const override;
/// The type of the splat element
type::Type const* const type;
diff --git a/src/tint/constant/splat_test.cc b/src/tint/constant/splat_test.cc
index fe8aeda..0e2f21e 100644
--- a/src/tint/constant/splat_test.cc
+++ b/src/tint/constant/splat_test.cc
@@ -25,15 +25,15 @@
using ConstantTest_Splat = TestHelper;
TEST_F(ConstantTest_Splat, AllZero) {
- auto* f32 = create<type::F32>();
+ auto* vec3f = create<type::Vector>(create<type::F32>(), 3u);
- auto* fPos0 = create<Scalar<tint::f32>>(f32, 0_f);
- auto* fNeg0 = create<Scalar<tint::f32>>(f32, -0_f);
- auto* fPos1 = create<Scalar<tint::f32>>(f32, 1_f);
+ auto* fPos0 = constants.Get(0_f);
+ auto* fNeg0 = constants.Get(-0_f);
+ auto* fPos1 = constants.Get(1_f);
- auto* SpfPos0 = create<Splat>(f32, fPos0, 2);
- auto* SpfNeg0 = create<Splat>(f32, fNeg0, 2);
- auto* SpfPos1 = create<Splat>(f32, fPos1, 2);
+ auto* SpfPos0 = constants.Splat(vec3f, fPos0, 2);
+ auto* SpfNeg0 = constants.Splat(vec3f, fNeg0, 2);
+ auto* SpfPos1 = constants.Splat(vec3f, fPos1, 2);
EXPECT_TRUE(SpfPos0->AllZero());
EXPECT_FALSE(SpfNeg0->AllZero());
@@ -41,15 +41,15 @@
}
TEST_F(ConstantTest_Splat, AnyZero) {
- auto* f32 = create<type::F32>();
+ auto* vec3f = create<type::Vector>(create<type::F32>(), 3u);
- auto* fPos0 = create<Scalar<tint::f32>>(f32, 0_f);
- auto* fNeg0 = create<Scalar<tint::f32>>(f32, -0_f);
- auto* fPos1 = create<Scalar<tint::f32>>(f32, 1_f);
+ auto* fPos0 = constants.Get(0_f);
+ auto* fNeg0 = constants.Get(-0_f);
+ auto* fPos1 = constants.Get(1_f);
- auto* SpfPos0 = create<Splat>(f32, fPos0, 2);
- auto* SpfNeg0 = create<Splat>(f32, fNeg0, 2);
- auto* SpfPos1 = create<Splat>(f32, fPos1, 2);
+ auto* SpfPos0 = constants.Splat(vec3f, fPos0, 2);
+ auto* SpfNeg0 = constants.Splat(vec3f, fNeg0, 2);
+ auto* SpfPos1 = constants.Splat(vec3f, fPos1, 2);
EXPECT_TRUE(SpfPos0->AnyZero());
EXPECT_FALSE(SpfNeg0->AnyZero());
@@ -57,10 +57,10 @@
}
TEST_F(ConstantTest_Splat, Index) {
- auto* f32 = create<type::F32>();
+ auto* vec3f = create<type::Vector>(create<type::F32>(), 3u);
- auto* f1 = create<Scalar<tint::f32>>(f32, 1_f);
- auto* sp = create<Splat>(f32, f1, 2);
+ auto* f1 = constants.Get(1_f);
+ auto* sp = constants.Splat(vec3f, f1, 2);
ASSERT_NE(sp->Index(0), nullptr);
ASSERT_NE(sp->Index(1), nullptr);
@@ -71,17 +71,16 @@
}
TEST_F(ConstantTest_Splat, Clone) {
- auto* i32 = create<type::I32>();
- auto* val = create<Scalar<tint::i32>>(i32, 12_i);
- auto* sp = create<Splat>(i32, val, 2);
+ auto* vec3i = create<type::Vector>(create<type::I32>(), 3u);
+ auto* val = constants.Get(12_i);
+ auto* sp = constants.Splat(vec3i, val, 2);
- type::Manager mgr;
- utils::BlockAllocator<constant::Value> consts;
- constant::CloneContext ctx{type::CloneContext{{nullptr}, {nullptr, &mgr}}, {&consts}};
+ constant::Manager mgr;
+ constant::CloneContext ctx{type::CloneContext{{nullptr}, {nullptr, &mgr.types}}, mgr};
auto* r = sp->Clone(ctx);
ASSERT_NE(r, nullptr);
- EXPECT_TRUE(r->type->Is<type::I32>());
+ EXPECT_TRUE(r->type->Is<type::Vector>());
EXPECT_TRUE(r->el->Is<Scalar<tint::i32>>());
EXPECT_EQ(r->count, 2u);
}
diff --git a/src/tint/constant/value.h b/src/tint/constant/value.h
index fe1f78a..baf1171 100644
--- a/src/tint/constant/value.h
+++ b/src/tint/constant/value.h
@@ -80,7 +80,7 @@
/// Clones the constant into the provided context
/// @param ctx the clone context
/// @returns the cloned node
- virtual Value* Clone(CloneContext& ctx) const = 0;
+ virtual const Value* Clone(CloneContext& ctx) const = 0;
protected:
/// @returns the value, if this is of a scalar value or abstract numeric, otherwise
diff --git a/src/tint/fuzzers/tint_common_fuzzer.cc b/src/tint/fuzzers/tint_common_fuzzer.cc
index 5595d10..6fa3b6e 100644
--- a/src/tint/fuzzers/tint_common_fuzzer.cc
+++ b/src/tint/fuzzers/tint_common_fuzzer.cc
@@ -248,7 +248,7 @@
cfg.map.insert({override_id, 0.0});
}
- if (!cfg.map.empty()) {
+ if (!default_values.empty()) {
transform::DataMap override_data;
override_data.Add<ast::transform::SubstituteOverride::Config>(cfg);
diff --git a/src/tint/ir/binary_test.cc b/src/tint/ir/binary_test.cc
index dc35978..281d04b 100644
--- a/src/tint/ir/binary_test.cc
+++ b/src/tint/ir/binary_test.cc
@@ -27,7 +27,7 @@
Module mod;
Builder b{mod};
- const auto* inst = b.And(b.ir.types.Get<type::I32>(), b.Constant(4_i), b.Constant(2_i));
+ const auto* inst = b.And(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i));
ASSERT_TRUE(inst->Is<Binary>());
EXPECT_EQ(inst->Kind(), Binary::Kind::kAnd);
@@ -48,7 +48,7 @@
Module mod;
Builder b{mod};
- const auto* inst = b.Or(b.ir.types.Get<type::I32>(), b.Constant(4_i), b.Constant(2_i));
+ const auto* inst = b.Or(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i));
ASSERT_TRUE(inst->Is<Binary>());
EXPECT_EQ(inst->Kind(), Binary::Kind::kOr);
@@ -68,7 +68,7 @@
Module mod;
Builder b{mod};
- const auto* inst = b.Xor(b.ir.types.Get<type::I32>(), b.Constant(4_i), b.Constant(2_i));
+ const auto* inst = b.Xor(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i));
ASSERT_TRUE(inst->Is<Binary>());
EXPECT_EQ(inst->Kind(), Binary::Kind::kXor);
@@ -88,7 +88,7 @@
Module mod;
Builder b{mod};
- const auto* inst = b.Equal(b.ir.types.Get<type::Bool>(), b.Constant(4_i), b.Constant(2_i));
+ const auto* inst = b.Equal(mod.Types().bool_(), b.Constant(4_i), b.Constant(2_i));
ASSERT_TRUE(inst->Is<Binary>());
EXPECT_EQ(inst->Kind(), Binary::Kind::kEqual);
@@ -108,7 +108,7 @@
Module mod;
Builder b{mod};
- const auto* inst = b.NotEqual(b.ir.types.Get<type::Bool>(), b.Constant(4_i), b.Constant(2_i));
+ const auto* inst = b.NotEqual(mod.Types().bool_(), b.Constant(4_i), b.Constant(2_i));
ASSERT_TRUE(inst->Is<Binary>());
EXPECT_EQ(inst->Kind(), Binary::Kind::kNotEqual);
@@ -128,7 +128,7 @@
Module mod;
Builder b{mod};
- const auto* inst = b.LessThan(b.ir.types.Get<type::Bool>(), b.Constant(4_i), b.Constant(2_i));
+ const auto* inst = b.LessThan(mod.Types().bool_(), b.Constant(4_i), b.Constant(2_i));
ASSERT_TRUE(inst->Is<Binary>());
EXPECT_EQ(inst->Kind(), Binary::Kind::kLessThan);
@@ -148,8 +148,7 @@
Module mod;
Builder b{mod};
- const auto* inst =
- b.GreaterThan(b.ir.types.Get<type::Bool>(), b.Constant(4_i), b.Constant(2_i));
+ const auto* inst = b.GreaterThan(mod.Types().bool_(), b.Constant(4_i), b.Constant(2_i));
ASSERT_TRUE(inst->Is<Binary>());
EXPECT_EQ(inst->Kind(), Binary::Kind::kGreaterThan);
@@ -169,8 +168,7 @@
Module mod;
Builder b{mod};
- const auto* inst =
- b.LessThanEqual(b.ir.types.Get<type::Bool>(), b.Constant(4_i), b.Constant(2_i));
+ const auto* inst = b.LessThanEqual(mod.Types().bool_(), b.Constant(4_i), b.Constant(2_i));
ASSERT_TRUE(inst->Is<Binary>());
EXPECT_EQ(inst->Kind(), Binary::Kind::kLessThanEqual);
@@ -190,8 +188,7 @@
Module mod;
Builder b{mod};
- const auto* inst =
- b.GreaterThanEqual(b.ir.types.Get<type::Bool>(), b.Constant(4_i), b.Constant(2_i));
+ const auto* inst = b.GreaterThanEqual(mod.Types().bool_(), b.Constant(4_i), b.Constant(2_i));
ASSERT_TRUE(inst->Is<Binary>());
EXPECT_EQ(inst->Kind(), Binary::Kind::kGreaterThanEqual);
@@ -210,7 +207,7 @@
TEST_F(IR_InstructionTest, CreateNot) {
Module mod;
Builder b{mod};
- const auto* inst = b.Not(b.ir.types.Get<type::Bool>(), b.Constant(true));
+ const auto* inst = b.Not(mod.Types().bool_(), b.Constant(true));
ASSERT_TRUE(inst->Is<Binary>());
EXPECT_EQ(inst->Kind(), Binary::Kind::kEqual);
@@ -230,7 +227,7 @@
Module mod;
Builder b{mod};
- const auto* inst = b.ShiftLeft(b.ir.types.Get<type::I32>(), b.Constant(4_i), b.Constant(2_i));
+ const auto* inst = b.ShiftLeft(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i));
ASSERT_TRUE(inst->Is<Binary>());
EXPECT_EQ(inst->Kind(), Binary::Kind::kShiftLeft);
@@ -250,7 +247,7 @@
Module mod;
Builder b{mod};
- const auto* inst = b.ShiftRight(b.ir.types.Get<type::I32>(), b.Constant(4_i), b.Constant(2_i));
+ const auto* inst = b.ShiftRight(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i));
ASSERT_TRUE(inst->Is<Binary>());
EXPECT_EQ(inst->Kind(), Binary::Kind::kShiftRight);
@@ -270,7 +267,7 @@
Module mod;
Builder b{mod};
- const auto* inst = b.Add(b.ir.types.Get<type::I32>(), b.Constant(4_i), b.Constant(2_i));
+ const auto* inst = b.Add(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i));
ASSERT_TRUE(inst->Is<Binary>());
EXPECT_EQ(inst->Kind(), Binary::Kind::kAdd);
@@ -290,7 +287,7 @@
Module mod;
Builder b{mod};
- const auto* inst = b.Subtract(b.ir.types.Get<type::I32>(), b.Constant(4_i), b.Constant(2_i));
+ const auto* inst = b.Subtract(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i));
ASSERT_TRUE(inst->Is<Binary>());
EXPECT_EQ(inst->Kind(), Binary::Kind::kSubtract);
@@ -310,7 +307,7 @@
Module mod;
Builder b{mod};
- const auto* inst = b.Multiply(b.ir.types.Get<type::I32>(), b.Constant(4_i), b.Constant(2_i));
+ const auto* inst = b.Multiply(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i));
ASSERT_TRUE(inst->Is<Binary>());
EXPECT_EQ(inst->Kind(), Binary::Kind::kMultiply);
@@ -330,7 +327,7 @@
Module mod;
Builder b{mod};
- const auto* inst = b.Divide(b.ir.types.Get<type::I32>(), b.Constant(4_i), b.Constant(2_i));
+ const auto* inst = b.Divide(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i));
ASSERT_TRUE(inst->Is<Binary>());
EXPECT_EQ(inst->Kind(), Binary::Kind::kDivide);
@@ -350,7 +347,7 @@
Module mod;
Builder b{mod};
- const auto* inst = b.Modulo(b.ir.types.Get<type::I32>(), b.Constant(4_i), b.Constant(2_i));
+ const auto* inst = b.Modulo(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i));
ASSERT_TRUE(inst->Is<Binary>());
EXPECT_EQ(inst->Kind(), Binary::Kind::kModulo);
@@ -369,7 +366,7 @@
TEST_F(IR_InstructionTest, Binary_Usage) {
Module mod;
Builder b{mod};
- const auto* inst = b.And(b.ir.types.Get<type::I32>(), b.Constant(4_i), b.Constant(2_i));
+ const auto* inst = b.And(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i));
EXPECT_EQ(inst->Kind(), Binary::Kind::kAnd);
@@ -386,7 +383,7 @@
Module mod;
Builder b{mod};
auto val = b.Constant(4_i);
- const auto* inst = b.And(b.ir.types.Get<type::I32>(), val, val);
+ const auto* inst = b.And(mod.Types().i32(), val, val);
EXPECT_EQ(inst->Kind(), Binary::Kind::kAnd);
ASSERT_EQ(inst->LHS(), inst->RHS());
diff --git a/src/tint/ir/bitcast_test.cc b/src/tint/ir/bitcast_test.cc
index 6eda562..bf66e39 100644
--- a/src/tint/ir/bitcast_test.cc
+++ b/src/tint/ir/bitcast_test.cc
@@ -27,7 +27,7 @@
TEST_F(IR_InstructionTest, Bitcast) {
Module mod;
Builder b{mod};
- const auto* inst = b.Bitcast(b.ir.types.Get<type::I32>(), b.Constant(4_i));
+ const auto* inst = b.Bitcast(mod.Types().i32(), b.Constant(4_i));
ASSERT_TRUE(inst->Is<ir::Bitcast>());
ASSERT_NE(inst->Type(), nullptr);
@@ -43,7 +43,7 @@
TEST_F(IR_InstructionTest, Bitcast_Usage) {
Module mod;
Builder b{mod};
- const auto* inst = b.Bitcast(b.ir.types.Get<type::I32>(), b.Constant(4_i));
+ const auto* inst = b.Bitcast(mod.Types().i32(), b.Constant(4_i));
const auto args = inst->Args();
ASSERT_EQ(args.Length(), 1u);
diff --git a/src/tint/ir/block.h b/src/tint/ir/block.h
index abdf2a3..597ff2c 100644
--- a/src/tint/ir/block.h
+++ b/src/tint/ir/block.h
@@ -19,23 +19,22 @@
#include "src/tint/ir/block_param.h"
#include "src/tint/ir/branch.h"
-#include "src/tint/ir/flow_node.h"
#include "src/tint/ir/instruction.h"
#include "src/tint/utils/vector.h"
namespace tint::ir {
-/// A flow node comprising a block of statements. The instructions in the block are a linear list of
-/// instructions to execute. The block will branch at the end. The only blocks which do not branch
-/// are the end blocks of functions.
-class Block : public utils::Castable<Block, FlowNode> {
+/// A block of statements. The instructions in the block are a linear list of instructions to
+/// execute. The block will branch at the end. The only blocks which do not branch are the end
+/// blocks of functions.
+class Block : public utils::Castable<Block> {
public:
/// Constructor
Block();
~Block() override;
/// @returns true if this is block has a branch target set
- bool HasBranchTarget() const override {
+ bool HasBranchTarget() const {
return !instructions_.IsEmpty() && instructions_.Back()->Is<ir::Branch>();
}
@@ -49,7 +48,7 @@
/// @param target the block to see if we trampoline too
/// @returns if this block just branches to the provided target.
- bool IsTrampoline(const FlowNode* target) const {
+ bool IsTrampoline(const Block* target) const {
if (instructions_.Length() != 1) {
return false;
}
@@ -78,9 +77,23 @@
/// @returns the params to the block
utils::Vector<const BlockParam*, 0>& Params() { return params_; }
+ /// @returns the inbound branch list for the block
+ utils::VectorRef<ir::Branch*> InboundBranches() const { return inbound_branches_; }
+
+ /// Adds the given node to the inbound branches
+ /// @param node the node to add
+ void AddInboundBranch(ir::Branch* node) { inbound_branches_.Push(node); }
+
private:
utils::Vector<const Instruction*, 16> instructions_;
utils::Vector<const BlockParam*, 0> params_;
+
+ /// The list of branches into this node. This list maybe empty for several
+ /// reasons:
+ /// - Node is a start node
+ /// - Node is a merge target outside control flow (e.g. an if that returns in both branches)
+ /// - Node is a continue target outside control flow (e.g. a loop that returns)
+ utils::Vector<ir::Branch*, 2> inbound_branches_;
};
} // namespace tint::ir
diff --git a/src/tint/ir/branch.cc b/src/tint/ir/branch.cc
index a16b7ae..3648b61 100644
--- a/src/tint/ir/branch.cc
+++ b/src/tint/ir/branch.cc
@@ -16,13 +16,13 @@
#include <utility>
-#include "src/tint/ir/flow_node.h"
+#include "src/tint/ir/block.h"
TINT_INSTANTIATE_TYPEINFO(tint::ir::Branch);
namespace tint::ir {
-Branch::Branch(FlowNode* to, utils::VectorRef<Value*> args) : to_(to), args_(std::move(args)) {
+Branch::Branch(Block* to, utils::VectorRef<Value*> args) : to_(to), args_(std::move(args)) {
TINT_ASSERT(IR, to_);
to_->AddInboundBranch(this);
for (auto* arg : args) {
diff --git a/src/tint/ir/branch.h b/src/tint/ir/branch.h
index fcb1256..fe08b97 100644
--- a/src/tint/ir/branch.h
+++ b/src/tint/ir/branch.h
@@ -21,7 +21,7 @@
// Forward declarations
namespace tint::ir {
-class FlowNode;
+class Block;
} // namespace tint::ir
namespace tint::ir {
@@ -32,17 +32,17 @@
/// Constructor
/// @param to the block to branch too
/// @param args the branch arguments
- explicit Branch(FlowNode* to, utils::VectorRef<Value*> args = {});
+ explicit Branch(Block* to, utils::VectorRef<Value*> args = {});
~Branch() override;
/// @returns the block being branched too.
- const FlowNode* To() const { return to_; }
+ const Block* To() const { return to_; }
/// @returns the branch arguments
utils::VectorRef<Value*> Args() const { return args_; }
private:
- FlowNode* to_;
+ Block* to_;
utils::Vector<Value*, 2> args_;
};
diff --git a/src/tint/ir/builder.cc b/src/tint/ir/builder.cc
index 205754b..77c56c8 100644
--- a/src/tint/ir/builder.cc
+++ b/src/tint/ir/builder.cc
@@ -34,31 +34,31 @@
}
Block* Builder::CreateBlock() {
- return ir.flow_nodes.Create<Block>();
+ return ir.blocks.Create<Block>();
}
RootTerminator* Builder::CreateRootTerminator() {
- return ir.flow_nodes.Create<RootTerminator>();
+ return ir.blocks.Create<RootTerminator>();
}
FunctionTerminator* Builder::CreateFunctionTerminator() {
- return ir.flow_nodes.Create<FunctionTerminator>();
+ return ir.blocks.Create<FunctionTerminator>();
}
Function* Builder::CreateFunction(std::string_view name,
- type::Type* return_type,
+ const type::Type* return_type,
Function::PipelineStage stage,
std::optional<std::array<uint32_t, 3>> wg_size) {
return CreateFunction(ir.symbols.Register(name), return_type, stage, wg_size);
}
Function* Builder::CreateFunction(Symbol name,
- type::Type* return_type,
+ const type::Type* return_type,
Function::PipelineStage stage,
std::optional<std::array<uint32_t, 3>> wg_size) {
TINT_ASSERT(IR, return_type);
- auto* ir_func = ir.flow_nodes.Create<Function>(name, return_type, stage, wg_size);
+ auto* ir_func = ir.values.Create<Function>(name, return_type, stage, wg_size);
ir_func->SetStartTarget(CreateBlock());
ir_func->SetEndTarget(CreateFunctionTerminator());
@@ -170,7 +170,7 @@
}
Binary* Builder::Not(const type::Type* type, Value* val) {
- return Equal(type, val, Constant(create<constant::Scalar<bool>>(type, false)));
+ return Equal(type, val, Constant(false));
}
ir::Bitcast* Builder::Bitcast(const type::Type* type, Value* val) {
@@ -217,14 +217,10 @@
return ir.values.Create<ir::Var>(type);
}
-ir::Branch* Builder::Branch(FlowNode* to, utils::VectorRef<Value*> args) {
+ir::Branch* Builder::Branch(Block* to, utils::VectorRef<Value*> args) {
return ir.values.Create<ir::Branch>(to, args);
}
-ir::Jump* Builder::Jump(FlowNode* to, utils::VectorRef<Value*> args) {
- return ir.values.Create<ir::Jump>(to, args);
-}
-
ir::BlockParam* Builder::BlockParam(const type::Type* type) {
return ir.values.Create<ir::BlockParam>(type);
}
diff --git a/src/tint/ir/builder.h b/src/tint/ir/builder.h
index 78bcb9b..02d2b36 100644
--- a/src/tint/ir/builder.h
+++ b/src/tint/ir/builder.h
@@ -30,7 +30,6 @@
#include "src/tint/ir/function_param.h"
#include "src/tint/ir/function_terminator.h"
#include "src/tint/ir/if.h"
-#include "src/tint/ir/jump.h"
#include "src/tint/ir/load.h"
#include "src/tint/ir/loop.h"
#include "src/tint/ir/module.h"
@@ -76,7 +75,7 @@
/// @param wg_size the workgroup_size
/// @returns the flow node
Function* CreateFunction(std::string_view name,
- type::Type* return_type,
+ const type::Type* return_type,
Function::PipelineStage stage = Function::PipelineStage::kUndefined,
std::optional<std::array<uint32_t, 3>> wg_size = {});
@@ -87,7 +86,7 @@
/// @param wg_size the workgroup_size
/// @returns the flow node
Function* CreateFunction(Symbol name,
- type::Type* return_type,
+ const type::Type* return_type,
Function::PipelineStage stage = Function::PipelineStage::kUndefined,
std::optional<std::array<uint32_t, 3>> wg_size = {});
@@ -111,50 +110,6 @@
/// @returns the start block for the case flow node
Block* CreateCase(Switch* s, utils::VectorRef<Switch::CaseSelector> selectors);
- /// Creates a constant::Value
- /// @param args the arguments
- /// @returns the new constant value
- template <typename T, typename... ARGS>
- utils::traits::EnableIf<utils::traits::IsTypeOrDerived<T, constant::Value>, const T>* create(
- ARGS&&... args) {
- return ir.constants_arena.Create<T>(std::forward<ARGS>(args)...);
- }
-
- /// @param v the value
- /// @returns the constant value
- const constant::Value* Bool(bool v) {
- // TODO(dsinclair): Replace when constant::Value is uniqed by the arena.
- return Constant(create<constant::Scalar<bool>>(ir.types.Get<type::Bool>(), v))->Value();
- }
-
- /// @param v the value
- /// @returns the constant value
- const constant::Value* U32(uint32_t v) {
- // TODO(dsinclair): Replace when constant::Value is uniqed by the arena.
- return Constant(create<constant::Scalar<u32>>(ir.types.Get<type::U32>(), u32(v)))->Value();
- }
-
- /// @param v the value
- /// @returns the constant value
- const constant::Value* I32(int32_t v) {
- // TODO(dsinclair): Replace when constant::Value is uniqed by the arena.
- return Constant(create<constant::Scalar<i32>>(ir.types.Get<type::I32>(), i32(v)))->Value();
- }
-
- /// @param v the value
- /// @returns the constant value
- const constant::Value* F16(float v) {
- // TODO(dsinclair): Replace when constant::Value is uniqed by the arena.
- return Constant(create<constant::Scalar<f16>>(ir.types.Get<type::F16>(), f16(v)))->Value();
- }
-
- /// @param v the value
- /// @returns the constant value
- const constant::Value* F32(float v) {
- // TODO(dsinclair): Replace when constant::Value is uniqed by the arena.
- return Constant(create<constant::Scalar<f32>>(ir.types.Get<type::F32>(), f32(v)))->Value();
- }
-
/// Creates a new ir::Constant
/// @param val the constant value
/// @returns the new constant
@@ -165,37 +120,27 @@
/// Creates a ir::Constant for an i32 Scalar
/// @param v the value
/// @returns the new constant
- ir::Constant* Constant(i32 v) {
- return Constant(create<constant::Scalar<i32>>(ir.types.Get<type::I32>(), v));
- }
+ ir::Constant* Constant(i32 v) { return Constant(ir.constant_values.Get(v)); }
/// Creates a ir::Constant for a u32 Scalar
/// @param v the value
/// @returns the new constant
- ir::Constant* Constant(u32 v) {
- return Constant(create<constant::Scalar<u32>>(ir.types.Get<type::U32>(), v));
- }
+ ir::Constant* Constant(u32 v) { return Constant(ir.constant_values.Get(v)); }
/// Creates a ir::Constant for a f32 Scalar
/// @param v the value
/// @returns the new constant
- ir::Constant* Constant(f32 v) {
- return Constant(create<constant::Scalar<f32>>(ir.types.Get<type::F32>(), v));
- }
+ ir::Constant* Constant(f32 v) { return Constant(ir.constant_values.Get(v)); }
/// Creates a ir::Constant for a f16 Scalar
/// @param v the value
/// @returns the new constant
- ir::Constant* Constant(f16 v) {
- return Constant(create<constant::Scalar<f16>>(ir.types.Get<type::F16>(), v));
- }
+ ir::Constant* Constant(f16 v) { return Constant(ir.constant_values.Get(v)); }
/// Creates a ir::Constant for a bool Scalar
/// @param v the value
/// @returns the new constant
- ir::Constant* Constant(bool v) {
- return Constant(create<constant::Scalar<bool>>(ir.types.Get<type::Bool>(), v));
- }
+ ir::Constant* Constant(bool v) { return Constant(ir.constant_values.Get(v)); }
/// Creates an op for `lhs kind rhs`
/// @param kind the kind of operation
@@ -403,13 +348,7 @@
/// @param to the node being branched too
/// @param args the branch arguments
/// @returns the instruction
- ir::Branch* Branch(FlowNode* to, utils::VectorRef<Value*> args = {});
-
- /// Creates a jump declaration
- /// @param to the node being branched too
- /// @param args the branch arguments
- /// @returns the instruction
- ir::Jump* Jump(FlowNode* to, utils::VectorRef<Value*> args = {});
+ ir::Branch* Branch(Block* to, utils::VectorRef<Value*> args = {});
/// Creates a new `BlockParam`
/// @param type the parameter type
diff --git a/src/tint/ir/debug.cc b/src/tint/ir/debug.cc
index 7e8155d..ea6363e 100644
--- a/src/tint/ir/debug.cc
+++ b/src/tint/ir/debug.cc
@@ -29,56 +29,56 @@
// static
std::string Debug::AsDotGraph(const Module* mod) {
- size_t node_count = 0;
+ size_t block_count = 0;
- std::unordered_set<const FlowNode*> visited;
- std::unordered_set<const FlowNode*> merge_nodes;
- std::unordered_map<const FlowNode*, std::string> node_to_name;
+ std::unordered_set<const Block*> visited;
+ std::unordered_set<const Block*> merge_blocks;
+ std::unordered_map<const Block*, std::string> block_to_name;
utils::StringStream out;
- auto name_for = [&](const FlowNode* node) -> std::string {
- if (node_to_name.count(node) > 0) {
- return node_to_name[node];
+ auto name_for = [&](const Block* blk) -> std::string {
+ if (block_to_name.count(blk) > 0) {
+ return block_to_name[blk];
}
- std::string name = "node_" + std::to_string(node_count);
- node_count += 1;
+ std::string name = "blk_" + std::to_string(block_count);
+ block_count += 1;
- node_to_name[node] = name;
+ block_to_name[blk] = name;
return name;
};
- std::function<void(const FlowNode*)> Graph = [&](const FlowNode* node) {
- if (visited.count(node) > 0) {
+ std::function<void(const Block*)> Graph = [&](const Block* blk) {
+ if (visited.count(blk) > 0) {
return;
}
- visited.insert(node);
+ visited.insert(blk);
tint::Switch(
- node,
+ blk,
+ [&](const ir::FunctionTerminator*) {
+ // Already done
+ },
[&](const ir::Block* b) {
- if (node_to_name.count(b) == 0) {
+ if (block_to_name.count(b) == 0) {
out << name_for(b) << R"( [label="block"])" << std::endl;
}
out << name_for(b) << " -> " << name_for(b->Branch()->To());
// Dashed lines to merge blocks
- if (merge_nodes.count(b->Branch()->To()) != 0) {
+ if (merge_blocks.count(b->Branch()->To()) != 0) {
out << " [style=dashed]";
}
out << std::endl;
Graph(b->Branch()->To());
- },
- [&](const ir::FunctionTerminator*) {
- // Already done
});
};
out << "digraph G {" << std::endl;
for (const auto* func : mod->functions) {
// Cluster each function to label and draw a box around it.
- out << "subgraph cluster_" << name_for(func) << " {" << std::endl;
+ out << "subgraph cluster_" << func->Name().Name() << " {" << std::endl;
out << R"(label=")" << func->Name().Name() << R"(")" << std::endl;
out << name_for(func->StartTarget()) << R"( [label="start"])" << std::endl;
out << name_for(func->EndTarget()) << R"( [label="end"])" << std::endl;
diff --git a/src/tint/ir/disassembler.cc b/src/tint/ir/disassembler.cc
index 8eb0974..a0edd21 100644
--- a/src/tint/ir/disassembler.cc
+++ b/src/tint/ir/disassembler.cc
@@ -27,7 +27,6 @@
#include "src/tint/ir/discard.h"
#include "src/tint/ir/function_terminator.h"
#include "src/tint/ir/if.h"
-#include "src/tint/ir/jump.h"
#include "src/tint/ir/load.h"
#include "src/tint/ir/loop.h"
#include "src/tint/ir/root_terminator.h"
@@ -73,9 +72,9 @@
}
}
-size_t Disassembler::IdOf(const FlowNode* node) {
+size_t Disassembler::IdOf(const Block* node) {
TINT_ASSERT(IR, node);
- return flow_node_ids_.GetOrCreate(node, [&] { return flow_node_ids_.Count(); });
+ return block_ids_.GetOrCreate(node, [&] { return block_ids_.Count(); });
}
std::string_view Disassembler::IdOf(const Value* value) {
@@ -90,111 +89,106 @@
std::string Disassembler::Disassemble() {
if (mod_.root_block) {
- walk_list_.push_back(mod_.root_block);
- Walk();
- TINT_ASSERT(IR, walk_list_.empty());
+ Indent() << "# Root block" << std::endl;
+ Walk(mod_.root_block);
+ Walk(mod_.root_block->Branch()->To());
}
for (auto* func : mod_.functions) {
- walk_list_.push_back(func);
- Walk();
- TINT_ASSERT(IR, walk_list_.empty());
+ EmitFunction(func);
}
return out_.str();
}
-void Disassembler::Walk() {
- utils::Hashset<const FlowNode*, 32> visited_;
+void Disassembler::Walk(const Block* blk) {
+ if (visited_.Contains(blk)) {
+ return;
+ }
+ visited_.Add(blk);
- while (!walk_list_.empty()) {
- const FlowNode* node = walk_list_.front();
- walk_list_.pop_front();
+ tint::Switch(
+ blk,
+ [&](const ir::FunctionTerminator* t) {
+ TINT_ASSERT(IR, in_function_);
+ Indent() << "%fn" << IdOf(t) << " = func_terminator" << std::endl;
+ in_function_ = false;
+ },
+ [&](const ir::RootTerminator* t) {
+ TINT_ASSERT(IR, !in_function_);
+ Indent() << "%fn" << IdOf(t) << " = root_terminator" << std::endl << std::endl;
+ },
+ [&](const ir::Block* b) {
+ // If this block is dead, nothing to do
+ if (!b->HasBranchTarget()) {
+ return;
+ }
- if (visited_.Contains(node)) {
- continue;
- }
- visited_.Add(node);
-
- tint::Switch(
- node,
- [&](const ir::Function* f) {
- in_function_ = true;
-
- Indent() << "%fn" << IdOf(f) << " = func " << f->Name().Name() << "(";
- for (auto* p : f->Params()) {
- if (p != f->Params().Front()) {
+ Indent() << "%fn" << IdOf(b) << " = block";
+ if (!b->Params().IsEmpty()) {
+ out_ << " (";
+ for (auto* p : b->Params()) {
+ if (p != b->Params().Front()) {
out_ << ", ";
}
- out_ << "%" << IdOf(p) << ":" << p->Type()->FriendlyName();
+ EmitValue(p);
}
- out_ << "):" << f->ReturnType()->FriendlyName();
+ out_ << ")";
+ }
- if (f->Stage() != Function::PipelineStage::kUndefined) {
- out_ << " [@" << f->Stage();
+ out_ << " {" << std::endl;
+ {
+ ScopedIndent si(indent_size_);
+ EmitBlockInstructions(b);
+ }
+ Indent() << "}" << std::endl;
- if (f->WorkgroupSize()) {
- auto arr = f->WorkgroupSize().value();
- out_ << " @workgroup_size(" << arr[0] << ", " << arr[1] << ", " << arr[2]
- << ")";
- }
+ if (!b->Branch()->To()->Is<FunctionTerminator>()) {
+ out_ << std::endl;
+ }
+ });
+}
- if (!f->ReturnAttributes().IsEmpty()) {
- out_ << " ra:";
+void Disassembler::EmitFunction(const Function* func) {
+ in_function_ = true;
- for (auto attr : f->ReturnAttributes()) {
- out_ << " @" << attr;
- if (attr == Function::ReturnAttribute::kLocation) {
- out_ << "(" << f->ReturnLocation().value() << ")";
- }
- }
- }
-
- out_ << "]";
- }
- out_ << " -> %fn" << IdOf(f->StartTarget()) << std::endl;
- walk_list_.push_back(f->StartTarget());
- },
- [&](const ir::Block* b) {
- // If this block is dead, nothing to do
- if (!b->HasBranchTarget()) {
- return;
- }
-
- Indent() << "%fn" << IdOf(b) << " = block";
- if (!b->Params().IsEmpty()) {
- out_ << " (";
- for (auto* p : b->Params()) {
- if (p != b->Params().Front()) {
- out_ << ", ";
- }
- EmitValue(p);
- }
- out_ << ")";
- }
-
- out_ << " {" << std::endl;
- {
- ScopedIndent si(indent_size_);
- EmitBlockInstructions(b);
- }
- Indent() << "}" << std::endl;
-
- if (!b->Branch()->To()->Is<FunctionTerminator>()) {
- out_ << std::endl;
- }
-
- walk_list_.push_back(b->Branch()->To());
- },
- [&](const ir::FunctionTerminator* t) {
- TINT_ASSERT(IR, in_function_);
- Indent() << "%fn" << IdOf(t) << " = func_terminator" << std::endl << std::endl;
- in_function_ = false;
- },
- [&](const ir::RootTerminator* t) {
- TINT_ASSERT(IR, !in_function_);
- Indent() << "%fn" << IdOf(t) << " = root_terminator" << std::endl << std::endl;
- });
+ Indent() << "%" << IdOf(func) << " = func " << func->Name().Name() << "(";
+ for (auto* p : func->Params()) {
+ if (p != func->Params().Front()) {
+ out_ << ", ";
+ }
+ out_ << "%" << IdOf(p) << ":" << p->Type()->FriendlyName();
}
+ out_ << "):" << func->ReturnType()->FriendlyName();
+
+ if (func->Stage() != Function::PipelineStage::kUndefined) {
+ out_ << " [@" << func->Stage();
+
+ if (func->WorkgroupSize()) {
+ auto arr = func->WorkgroupSize().value();
+ out_ << " @workgroup_size(" << arr[0] << ", " << arr[1] << ", " << arr[2] << ")";
+ }
+
+ if (!func->ReturnAttributes().IsEmpty()) {
+ out_ << " ra:";
+
+ for (auto attr : func->ReturnAttributes()) {
+ out_ << " @" << attr;
+ if (attr == Function::ReturnAttribute::kLocation) {
+ out_ << "(" << func->ReturnLocation().value() << ")";
+ }
+ }
+ }
+
+ out_ << "]";
+ }
+ out_ << " -> %fn" << IdOf(func->StartTarget()) << " {" << std::endl;
+
+ {
+ ScopedIndent si(indent_size_);
+ Walk(func->StartTarget());
+ Walk(func->EndTarget());
+ }
+ Indent() << "}" << std::endl;
}
void Disassembler::EmitValueWithType(const Value* val) {
@@ -333,40 +327,51 @@
}
out_ << "f: %fn" << IdOf(i->False());
}
- if (i->Merge()->IsConnected()) {
+ if (i->Merge()->HasBranchTarget()) {
out_ << ", m: %fn" << IdOf(i->Merge());
}
- out_ << "]";
+ out_ << "]" << std::endl;
if (has_true) {
- walk_list_.push_back(i->True());
+ ScopedIndent si(indent_size_);
+ Indent() << "# True block" << std::endl;
+ Walk(i->True());
}
if (has_false) {
- walk_list_.push_back(i->False());
+ ScopedIndent si(indent_size_);
+ Indent() << "# False block" << std::endl;
+ Walk(i->False());
}
- if (i->Merge()->IsConnected()) {
- walk_list_.push_back(i->Merge());
+ if (i->Merge()->HasBranchTarget()) {
+ Indent() << "# Merge block" << std::endl;
+ Walk(i->Merge());
}
}
void Disassembler::EmitLoop(const Loop* l) {
out_ << "loop [s: %fn" << IdOf(l->Start());
- if (l->Continuing()->IsConnected()) {
+ if (l->Continuing()->HasBranchTarget()) {
out_ << ", c: %fn" << IdOf(l->Continuing());
}
- if (l->Merge()->IsConnected()) {
+ if (l->Merge()->HasBranchTarget()) {
out_ << ", m: %fn" << IdOf(l->Merge());
}
- out_ << "]";
+ out_ << "]" << std::endl;
- { walk_list_.push_back(l->Start()); }
-
- if (l->Continuing()->IsConnected()) {
- walk_list_.push_back(l->Continuing());
+ {
+ ScopedIndent si(indent_size_);
+ Walk(l->Start());
}
- if (l->Merge()->IsConnected()) {
- walk_list_.push_back(l->Merge());
+
+ if (l->Continuing()->HasBranchTarget()) {
+ ScopedIndent si(indent_size_);
+ Indent() << "# Continuing block" << std::endl;
+ Walk(l->Continuing());
+ }
+ if (l->Merge()->HasBranchTarget()) {
+ Indent() << "# Merge block" << std::endl;
+ Walk(l->Merge());
}
}
@@ -392,32 +397,25 @@
}
out_ << ", %fn" << IdOf(c.Start()) << ")";
}
- if (s->Merge()->IsConnected()) {
+ if (s->Merge()->HasBranchTarget()) {
out_ << ", m: %fn" << IdOf(s->Merge());
}
- out_ << "]";
+ out_ << "]" << std::endl;
for (auto& c : s->Cases()) {
- walk_list_.push_back(c.Start());
+ ScopedIndent si(indent_size_);
+ Indent() << "# Case block" << std::endl;
+ Walk(c.Start());
}
- if (s->Merge()->IsConnected()) {
- walk_list_.push_back(s->Merge());
+ if (s->Merge()->HasBranchTarget()) {
+ Indent() << "# Merge block" << std::endl;
+ Walk(s->Merge());
}
}
void Disassembler::EmitBranch(const Branch* b) {
- if (b->Is<Jump>()) {
- out_ << "jmp ";
-
- // Stuff the thing we're jumping too into the front of the walk list so it will be emitted
- // next.
- walk_list_.push_front(b->To());
- } else {
- out_ << "br ";
- }
-
std::string suffix = "";
- out_ << "%fn" << IdOf(b->To());
+ out_ << "br %fn" << IdOf(b->To());
if (b->To()->Is<FunctionTerminator>()) {
suffix = "return";
} else if (b->To()->Is<RootTerminator>()) {
diff --git a/src/tint/ir/disassembler.h b/src/tint/ir/disassembler.h
index 7b9e4c5..f171031 100644
--- a/src/tint/ir/disassembler.h
+++ b/src/tint/ir/disassembler.h
@@ -15,18 +15,18 @@
#ifndef SRC_TINT_IR_DISASSEMBLER_H_
#define SRC_TINT_IR_DISASSEMBLER_H_
-#include <deque>
#include <string>
#include "src/tint/ir/binary.h"
+#include "src/tint/ir/block.h"
#include "src/tint/ir/call.h"
-#include "src/tint/ir/flow_node.h"
#include "src/tint/ir/if.h"
#include "src/tint/ir/loop.h"
#include "src/tint/ir/module.h"
#include "src/tint/ir/switch.h"
#include "src/tint/ir/unary.h"
#include "src/tint/utils/hashmap.h"
+#include "src/tint/utils/hashset.h"
#include "src/tint/utils/string_stream.h"
namespace tint::ir {
@@ -53,10 +53,11 @@
private:
utils::StringStream& Indent();
- size_t IdOf(const FlowNode* node);
+ size_t IdOf(const Block* blk);
std::string_view IdOf(const Value* node);
- void Walk();
+ void Walk(const Block* blk);
+ void EmitFunction(const Function* func);
void EmitInstruction(const Instruction* inst);
void EmitValueWithType(const Value* val);
void EmitValue(const Value* val);
@@ -70,8 +71,8 @@
const Module& mod_;
utils::StringStream out_;
- std::deque<const FlowNode*> walk_list_;
- utils::Hashmap<const FlowNode*, size_t, 32> flow_node_ids_;
+ utils::Hashset<const Block*, 32> visited_;
+ utils::Hashmap<const Block*, size_t, 32> block_ids_;
utils::Hashmap<const Value*, std::string, 32> value_ids_;
uint32_t indent_size_ = 0;
bool in_function_ = false;
diff --git a/src/tint/ir/flow_node.cc b/src/tint/ir/flow_node.cc
deleted file mode 100644
index bbbd78b..0000000
--- a/src/tint/ir/flow_node.cc
+++ /dev/null
@@ -1,25 +0,0 @@
-// Copyright 2022 The Tint Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "src/tint/ir/flow_node.h"
-
-TINT_INSTANTIATE_TYPEINFO(tint::ir::FlowNode);
-
-namespace tint::ir {
-
-FlowNode::FlowNode() = default;
-
-FlowNode::~FlowNode() = default;
-
-} // namespace tint::ir
diff --git a/src/tint/ir/flow_node.h b/src/tint/ir/flow_node.h
deleted file mode 100644
index b072964..0000000
--- a/src/tint/ir/flow_node.h
+++ /dev/null
@@ -1,61 +0,0 @@
-// Copyright 2022 The Tint Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef SRC_TINT_IR_FLOW_NODE_H_
-#define SRC_TINT_IR_FLOW_NODE_H_
-
-#include "src/tint/utils/castable.h"
-#include "src/tint/utils/vector.h"
-
-// Forward Declarations
-namespace tint::ir {
-class Branch;
-} // namespace tint::ir
-
-namespace tint::ir {
-
-/// Base class for flow nodes
-class FlowNode : public utils::Castable<FlowNode> {
- public:
- ~FlowNode() override;
-
- /// @returns true if this node has inbound branches and branches out
- bool IsConnected() const { return HasBranchTarget(); }
-
- /// @returns true if the node has a branch target
- virtual bool HasBranchTarget() const { return false; }
-
- /// @returns the inbound branch list for the flow node
- utils::VectorRef<Branch*> InboundBranches() const { return inbound_branches_; }
-
- /// Adds the given node to the inbound branches
- /// @param node the node to add
- void AddInboundBranch(Branch* node) { inbound_branches_.Push(node); }
-
- protected:
- /// Constructor
- FlowNode();
-
- private:
- /// The list of flow nodes which branch into this node. This list maybe empty for several
- /// reasons:
- /// - Node is a start node
- /// - Node is a merge target outside control flow (e.g. an if that returns in both branches)
- /// - Node is a continue target outside control flow (e.g. a loop that returns)
- utils::Vector<Branch*, 2> inbound_branches_;
-};
-
-} // namespace tint::ir
-
-#endif // SRC_TINT_IR_FLOW_NODE_H_
diff --git a/src/tint/ir/from_program.cc b/src/tint/ir/from_program.cc
index 212efea..f64037a 100644
--- a/src/tint/ir/from_program.cc
+++ b/src/tint/ir/from_program.cc
@@ -101,11 +101,7 @@
// For an `if` and `switch` block, the merge has a registered incoming branch instruction of the
// `if` and `switch. So, to determine if the merge is connected to any of the branches that happend
// in the `if` or `switch` we need a `count` value that is larger then 1.
-bool IsConnected(const FlowNode* b, uint32_t count) {
- // Function is always connected as it's the start.
- if (b->Is<ir::Function>()) {
- return true;
- }
+bool IsConnected(const Block* b, uint32_t count) {
return b->InboundBranches().Length() > count;
}
@@ -136,9 +132,9 @@
constant::CloneContext clone_ctx_{
/* type_ctx */ type::CloneContext{
/* src */ {&program_->Symbols()},
- /* dst */ {&builder_.ir.symbols, &builder_.ir.types},
+ /* dst */ {&builder_.ir.symbols, &builder_.ir.Types()},
},
- /* dst */ {&builder_.ir.constants_arena},
+ /* dst */ {builder_.ir.constant_values},
};
/// The stack of control blocks.
@@ -170,21 +166,7 @@
diagnostics_.add_error(tint::diag::System::IR, err, s);
}
- void JumpTo(FlowNode* node, utils::VectorRef<Value*> args = {}) {
- TINT_ASSERT(IR, current_flow_block_);
- TINT_ASSERT(IR, !current_flow_block_->HasBranchTarget());
-
- current_flow_block_->Instructions().Push(builder_.Jump(node, args));
- current_flow_block_ = nullptr;
- }
- void JumpToIfNeeded(FlowNode* node) {
- if (!current_flow_block_ || current_flow_block_->HasBranchTarget()) {
- return;
- }
- JumpTo(node);
- }
-
- void BranchTo(FlowNode* node, utils::VectorRef<Value*> args = {}) {
+ void BranchTo(Block* node, utils::VectorRef<Value*> args = {}) {
TINT_ASSERT(IR, current_flow_block_);
TINT_ASSERT(IR, !current_flow_block_->HasBranchTarget());
@@ -192,7 +174,7 @@
current_flow_block_ = nullptr;
}
- void BranchToIfNeeded(FlowNode* node) {
+ void BranchToIfNeeded(Block* node) {
if (!current_flow_block_ || current_flow_block_->HasBranchTarget()) {
return;
}
@@ -357,7 +339,7 @@
// If the branch target has already been set then a `return` was called. Only set in
// the case where `return` wasn't called.
- JumpToIfNeeded(current_function_->EndTarget());
+ BranchToIfNeeded(current_function_->EndTarget());
}
TINT_ASSERT(IR, control_stack_.IsEmpty());
@@ -587,7 +569,7 @@
// The current block didn't `break`, `return` or `continue`, go to the continuing
// block.
- JumpToIfNeeded(loop_inst->Continuing());
+ BranchToIfNeeded(loop_inst->Continuing());
current_flow_block_ = loop_inst->Continuing();
if (stmt->continuing) {
@@ -634,7 +616,7 @@
current_flow_block_ = if_inst->Merge();
EmitBlock(stmt->body);
- JumpToIfNeeded(loop_inst->Continuing());
+ BranchToIfNeeded(loop_inst->Continuing());
}
// The while loop always has a path to the Merge().target as the break statement comes
// before anything inside the loop.
@@ -678,7 +660,7 @@
}
EmitBlock(stmt->body);
- JumpToIfNeeded(loop_inst->Continuing());
+ BranchToIfNeeded(loop_inst->Continuing());
if (stmt->continuing) {
current_flow_block_ = loop_inst->Continuing();
@@ -766,7 +748,7 @@
// Discard is being treated as an instruction. The semantics in WGSL is demote_to_helper, so
// the code has to continue as before it just predicates writes. If WGSL grows some kind of
- // terminating discard that would probably make sense as a FlowNode but would then require
+ // terminating discard that would probably make sense as a Block but would then require
// figuring out the multi-level exit that is triggered.
void EmitDiscard(const ast::DiscardStatement*) {
auto* inst = builder_.Discard();
@@ -859,7 +841,7 @@
var,
[&](const ast::Var* v) {
auto* ref = sem->Type()->As<type::Reference>();
- auto* ty = builder_.ir.types.Get<type::Pointer>(
+ auto* ty = builder_.ir.Types().Get<type::Pointer>(
ref->StoreType()->Clone(clone_ctx_.type_ctx), ref->AddressSpace(),
ref->Access());
@@ -964,7 +946,7 @@
auto* if_inst = builder_.CreateIf(lhs.Get());
current_flow_block_->Instructions().Push(if_inst);
- auto* result = builder_.BlockParam(builder_.ir.types.Get<type::Bool>());
+ auto* result = builder_.BlockParam(builder_.ir.Types().bool_());
if_inst->Merge()->SetParams(utils::Vector{result});
utils::Result<Value*> rhs;
diff --git a/src/tint/ir/from_program_binary_test.cc b/src/tint/ir/from_program_binary_test.cc
index 6e63e40..7909bda 100644
--- a/src/tint/ir/from_program_binary_test.cc
+++ b/src/tint/ir/from_program_binary_test.cc
@@ -34,20 +34,20 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = func my_func():u32 -> %fn2
-%fn2 = block {
- br %fn3 0u # return
+ EXPECT_EQ(Disassemble(m.Get()), R"(%1 = func my_func():u32 -> %fn1 {
+ %fn1 = block {
+ br %fn2 0u # return
+ }
+ %fn2 = func_terminator
}
-%fn3 = func_terminator
-
-%fn4 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn5
-%fn5 = block {
- %1:u32 = call my_func
- %tint_symbol:u32 = add %1, 4u
- jmp %fn6 # return
+%2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn3 {
+ %fn3 = block {
+ %3:u32 = call my_func
+ %tint_symbol:u32 = add %3, 4u
+ br %fn4 # return
+ }
+ %fn4 = func_terminator
}
-%fn6 = func_terminator
-
)");
}
@@ -59,22 +59,23 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = block {
+ EXPECT_EQ(Disassemble(m.Get()), R"(# Root block
+%fn1 = block {
%v1:ptr<private, u32, read_write> = var
br %fn2 # root_end
}
%fn2 = root_terminator
-%fn3 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn4
-%fn4 = block {
- %2:u32 = load %v1
- %3:u32 = add %2, 1u
- store %v1, %3
- jmp %fn5 # return
+%2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn3 {
+ %fn3 = block {
+ %3:u32 = load %v1
+ %4:u32 = add %3, 1u
+ store %v1, %4
+ br %fn4 # return
+ }
+ %fn4 = func_terminator
}
-%fn5 = func_terminator
-
)");
}
@@ -86,22 +87,23 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = block {
+ EXPECT_EQ(Disassemble(m.Get()), R"(# Root block
+%fn1 = block {
%v1:ptr<private, u32, read_write> = var
br %fn2 # root_end
}
%fn2 = root_terminator
-%fn3 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn4
-%fn4 = block {
- %2:u32 = load %v1
- %3:u32 = add %2, 1u
- store %v1, %3
- jmp %fn5 # return
+%2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn3 {
+ %fn3 = block {
+ %3:u32 = load %v1
+ %4:u32 = add %3, 1u
+ store %v1, %4
+ br %fn4 # return
+ }
+ %fn4 = func_terminator
}
-%fn5 = func_terminator
-
)");
}
@@ -113,20 +115,20 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = func my_func():u32 -> %fn2
-%fn2 = block {
- br %fn3 0u # return
+ EXPECT_EQ(Disassemble(m.Get()), R"(%1 = func my_func():u32 -> %fn1 {
+ %fn1 = block {
+ br %fn2 0u # return
+ }
+ %fn2 = func_terminator
}
-%fn3 = func_terminator
-
-%fn4 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn5
-%fn5 = block {
- %1:u32 = call my_func
- %tint_symbol:u32 = sub %1, 4u
- jmp %fn6 # return
+%2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn3 {
+ %fn3 = block {
+ %3:u32 = call my_func
+ %tint_symbol:u32 = sub %3, 4u
+ br %fn4 # return
+ }
+ %fn4 = func_terminator
}
-%fn6 = func_terminator
-
)");
}
@@ -138,22 +140,23 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = block {
+ EXPECT_EQ(Disassemble(m.Get()), R"(# Root block
+%fn1 = block {
%v1:ptr<private, i32, read_write> = var
br %fn2 # root_end
}
%fn2 = root_terminator
-%fn3 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn4
-%fn4 = block {
- %2:i32 = load %v1
- %3:i32 = sub %2, 1i
- store %v1, %3
- jmp %fn5 # return
+%2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn3 {
+ %fn3 = block {
+ %3:i32 = load %v1
+ %4:i32 = sub %3, 1i
+ store %v1, %4
+ br %fn4 # return
+ }
+ %fn4 = func_terminator
}
-%fn5 = func_terminator
-
)");
}
@@ -165,22 +168,23 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = block {
+ EXPECT_EQ(Disassemble(m.Get()), R"(# Root block
+%fn1 = block {
%v1:ptr<private, u32, read_write> = var
br %fn2 # root_end
}
%fn2 = root_terminator
-%fn3 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn4
-%fn4 = block {
- %2:u32 = load %v1
- %3:u32 = sub %2, 1u
- store %v1, %3
- jmp %fn5 # return
+%2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn3 {
+ %fn3 = block {
+ %3:u32 = load %v1
+ %4:u32 = sub %3, 1u
+ store %v1, %4
+ br %fn4 # return
+ }
+ %fn4 = func_terminator
}
-%fn5 = func_terminator
-
)");
}
@@ -192,20 +196,20 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = func my_func():u32 -> %fn2
-%fn2 = block {
- br %fn3 0u # return
+ EXPECT_EQ(Disassemble(m.Get()), R"(%1 = func my_func():u32 -> %fn1 {
+ %fn1 = block {
+ br %fn2 0u # return
+ }
+ %fn2 = func_terminator
}
-%fn3 = func_terminator
-
-%fn4 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn5
-%fn5 = block {
- %1:u32 = call my_func
- %tint_symbol:u32 = mul %1, 4u
- jmp %fn6 # return
+%2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn3 {
+ %fn3 = block {
+ %3:u32 = call my_func
+ %tint_symbol:u32 = mul %3, 4u
+ br %fn4 # return
+ }
+ %fn4 = func_terminator
}
-%fn6 = func_terminator
-
)");
}
@@ -217,22 +221,23 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = block {
+ EXPECT_EQ(Disassemble(m.Get()), R"(# Root block
+%fn1 = block {
%v1:ptr<private, u32, read_write> = var
br %fn2 # root_end
}
%fn2 = root_terminator
-%fn3 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn4
-%fn4 = block {
- %2:u32 = load %v1
- %3:u32 = mul %2, 1u
- store %v1, %3
- jmp %fn5 # return
+%2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn3 {
+ %fn3 = block {
+ %3:u32 = load %v1
+ %4:u32 = mul %3, 1u
+ store %v1, %4
+ br %fn4 # return
+ }
+ %fn4 = func_terminator
}
-%fn5 = func_terminator
-
)");
}
@@ -244,20 +249,20 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = func my_func():u32 -> %fn2
-%fn2 = block {
- br %fn3 0u # return
+ EXPECT_EQ(Disassemble(m.Get()), R"(%1 = func my_func():u32 -> %fn1 {
+ %fn1 = block {
+ br %fn2 0u # return
+ }
+ %fn2 = func_terminator
}
-%fn3 = func_terminator
-
-%fn4 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn5
-%fn5 = block {
- %1:u32 = call my_func
- %tint_symbol:u32 = div %1, 4u
- jmp %fn6 # return
+%2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn3 {
+ %fn3 = block {
+ %3:u32 = call my_func
+ %tint_symbol:u32 = div %3, 4u
+ br %fn4 # return
+ }
+ %fn4 = func_terminator
}
-%fn6 = func_terminator
-
)");
}
@@ -269,22 +274,23 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = block {
+ EXPECT_EQ(Disassemble(m.Get()), R"(# Root block
+%fn1 = block {
%v1:ptr<private, u32, read_write> = var
br %fn2 # root_end
}
%fn2 = root_terminator
-%fn3 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn4
-%fn4 = block {
- %2:u32 = load %v1
- %3:u32 = div %2, 1u
- store %v1, %3
- jmp %fn5 # return
+%2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn3 {
+ %fn3 = block {
+ %3:u32 = load %v1
+ %4:u32 = div %3, 1u
+ store %v1, %4
+ br %fn4 # return
+ }
+ %fn4 = func_terminator
}
-%fn5 = func_terminator
-
)");
}
@@ -296,20 +302,20 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = func my_func():u32 -> %fn2
-%fn2 = block {
- br %fn3 0u # return
+ EXPECT_EQ(Disassemble(m.Get()), R"(%1 = func my_func():u32 -> %fn1 {
+ %fn1 = block {
+ br %fn2 0u # return
+ }
+ %fn2 = func_terminator
}
-%fn3 = func_terminator
-
-%fn4 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn5
-%fn5 = block {
- %1:u32 = call my_func
- %tint_symbol:u32 = mod %1, 4u
- jmp %fn6 # return
+%2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn3 {
+ %fn3 = block {
+ %3:u32 = call my_func
+ %tint_symbol:u32 = mod %3, 4u
+ br %fn4 # return
+ }
+ %fn4 = func_terminator
}
-%fn6 = func_terminator
-
)");
}
@@ -321,22 +327,23 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = block {
+ EXPECT_EQ(Disassemble(m.Get()), R"(# Root block
+%fn1 = block {
%v1:ptr<private, u32, read_write> = var
br %fn2 # root_end
}
%fn2 = root_terminator
-%fn3 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn4
-%fn4 = block {
- %2:u32 = load %v1
- %3:u32 = mod %2, 1u
- store %v1, %3
- jmp %fn5 # return
+%2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn3 {
+ %fn3 = block {
+ %3:u32 = load %v1
+ %4:u32 = mod %3, 1u
+ store %v1, %4
+ br %fn4 # return
+ }
+ %fn4 = func_terminator
}
-%fn5 = func_terminator
-
)");
}
@@ -348,20 +355,20 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = func my_func():u32 -> %fn2
-%fn2 = block {
- br %fn3 0u # return
+ EXPECT_EQ(Disassemble(m.Get()), R"(%1 = func my_func():u32 -> %fn1 {
+ %fn1 = block {
+ br %fn2 0u # return
+ }
+ %fn2 = func_terminator
}
-%fn3 = func_terminator
-
-%fn4 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn5
-%fn5 = block {
- %1:u32 = call my_func
- %tint_symbol:u32 = and %1, 4u
- jmp %fn6 # return
+%2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn3 {
+ %fn3 = block {
+ %3:u32 = call my_func
+ %tint_symbol:u32 = and %3, 4u
+ br %fn4 # return
+ }
+ %fn4 = func_terminator
}
-%fn6 = func_terminator
-
)");
}
@@ -373,22 +380,23 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = block {
+ EXPECT_EQ(Disassemble(m.Get()), R"(# Root block
+%fn1 = block {
%v1:ptr<private, bool, read_write> = var
br %fn2 # root_end
}
%fn2 = root_terminator
-%fn3 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn4
-%fn4 = block {
- %2:bool = load %v1
- %3:bool = and %2, false
- store %v1, %3
- jmp %fn5 # return
+%2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn3 {
+ %fn3 = block {
+ %3:bool = load %v1
+ %4:bool = and %3, false
+ store %v1, %4
+ br %fn4 # return
+ }
+ %fn4 = func_terminator
}
-%fn5 = func_terminator
-
)");
}
@@ -400,20 +408,20 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = func my_func():u32 -> %fn2
-%fn2 = block {
- br %fn3 0u # return
+ EXPECT_EQ(Disassemble(m.Get()), R"(%1 = func my_func():u32 -> %fn1 {
+ %fn1 = block {
+ br %fn2 0u # return
+ }
+ %fn2 = func_terminator
}
-%fn3 = func_terminator
-
-%fn4 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn5
-%fn5 = block {
- %1:u32 = call my_func
- %tint_symbol:u32 = or %1, 4u
- jmp %fn6 # return
+%2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn3 {
+ %fn3 = block {
+ %3:u32 = call my_func
+ %tint_symbol:u32 = or %3, 4u
+ br %fn4 # return
+ }
+ %fn4 = func_terminator
}
-%fn6 = func_terminator
-
)");
}
@@ -425,22 +433,23 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = block {
+ EXPECT_EQ(Disassemble(m.Get()), R"(# Root block
+%fn1 = block {
%v1:ptr<private, bool, read_write> = var
br %fn2 # root_end
}
%fn2 = root_terminator
-%fn3 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn4
-%fn4 = block {
- %2:bool = load %v1
- %3:bool = or %2, false
- store %v1, %3
- jmp %fn5 # return
+%2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn3 {
+ %fn3 = block {
+ %3:bool = load %v1
+ %4:bool = or %3, false
+ store %v1, %4
+ br %fn4 # return
+ }
+ %fn4 = func_terminator
}
-%fn5 = func_terminator
-
)");
}
@@ -452,20 +461,20 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = func my_func():u32 -> %fn2
-%fn2 = block {
- br %fn3 0u # return
+ EXPECT_EQ(Disassemble(m.Get()), R"(%1 = func my_func():u32 -> %fn1 {
+ %fn1 = block {
+ br %fn2 0u # return
+ }
+ %fn2 = func_terminator
}
-%fn3 = func_terminator
-
-%fn4 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn5
-%fn5 = block {
- %1:u32 = call my_func
- %tint_symbol:u32 = xor %1, 4u
- jmp %fn6 # return
+%2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn3 {
+ %fn3 = block {
+ %3:u32 = call my_func
+ %tint_symbol:u32 = xor %3, 4u
+ br %fn4 # return
+ }
+ %fn4 = func_terminator
}
-%fn6 = func_terminator
-
)");
}
@@ -477,22 +486,23 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = block {
+ EXPECT_EQ(Disassemble(m.Get()), R"(# Root block
+%fn1 = block {
%v1:ptr<private, u32, read_write> = var
br %fn2 # root_end
}
%fn2 = root_terminator
-%fn3 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn4
-%fn4 = block {
- %2:u32 = load %v1
- %3:u32 = xor %2, 1u
- store %v1, %3
- jmp %fn5 # return
+%2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn3 {
+ %fn3 = block {
+ %3:u32 = load %v1
+ %4:u32 = xor %3, 1u
+ store %v1, %4
+ br %fn4 # return
+ }
+ %fn4 = func_terminator
}
-%fn5 = func_terminator
-
)");
}
@@ -504,43 +514,51 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = func my_func():bool -> %fn2
-%fn2 = block {
- br %fn3 true # return
+ EXPECT_EQ(Disassemble(m.Get()), R"(%1 = func my_func():bool -> %fn1 {
+ %fn1 = block {
+ br %fn2 true # return
+ }
+ %fn2 = func_terminator
}
-%fn3 = func_terminator
+%2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn3 {
+ %fn3 = block {
+ %3:bool = call my_func
+ if %3 [t: %fn4, f: %fn5, m: %fn6]
+ # True block
+ %fn4 = block {
+ br %fn6 false
+ }
-%fn4 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn5
-%fn5 = block {
- %1:bool = call my_func
- if %1 [t: %fn6, f: %fn7, m: %fn8]
+ # False block
+ %fn5 = block {
+ br %fn6 %3
+ }
+
+ # Merge block
+ %fn6 = block (%4:bool) {
+ if %4:bool [t: %fn7, f: %fn8, m: %fn9]
+ # True block
+ %fn7 = block {
+ br %fn9
+ }
+
+ # False block
+ %fn8 = block {
+ br %fn9
+ }
+
+ # Merge block
+ %fn9 = block {
+ br %fn10 # return
+ }
+
+ }
+
+
+ }
+
+ %fn10 = func_terminator
}
-
-%fn6 = block {
- br %fn8 false
-}
-
-%fn7 = block {
- br %fn8 %1
-}
-
-%fn8 = block (%2:bool) {
- if %2:bool [t: %fn9, f: %fn10, m: %fn11]
-}
-
-%fn9 = block {
- br %fn11
-}
-
-%fn10 = block {
- br %fn11
-}
-
-%fn11 = block {
- jmp %fn12 # return
-}
-%fn12 = func_terminator
-
)");
}
@@ -552,43 +570,51 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = func my_func():bool -> %fn2
-%fn2 = block {
- br %fn3 true # return
+ EXPECT_EQ(Disassemble(m.Get()), R"(%1 = func my_func():bool -> %fn1 {
+ %fn1 = block {
+ br %fn2 true # return
+ }
+ %fn2 = func_terminator
}
-%fn3 = func_terminator
+%2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn3 {
+ %fn3 = block {
+ %3:bool = call my_func
+ if %3 [t: %fn4, f: %fn5, m: %fn6]
+ # True block
+ %fn4 = block {
+ br %fn6 %3
+ }
-%fn4 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn5
-%fn5 = block {
- %1:bool = call my_func
- if %1 [t: %fn6, f: %fn7, m: %fn8]
+ # False block
+ %fn5 = block {
+ br %fn6 true
+ }
+
+ # Merge block
+ %fn6 = block (%4:bool) {
+ if %4:bool [t: %fn7, f: %fn8, m: %fn9]
+ # True block
+ %fn7 = block {
+ br %fn9
+ }
+
+ # False block
+ %fn8 = block {
+ br %fn9
+ }
+
+ # Merge block
+ %fn9 = block {
+ br %fn10 # return
+ }
+
+ }
+
+
+ }
+
+ %fn10 = func_terminator
}
-
-%fn6 = block {
- br %fn8 %1
-}
-
-%fn7 = block {
- br %fn8 true
-}
-
-%fn8 = block (%2:bool) {
- if %2:bool [t: %fn9, f: %fn10, m: %fn11]
-}
-
-%fn9 = block {
- br %fn11
-}
-
-%fn10 = block {
- br %fn11
-}
-
-%fn11 = block {
- jmp %fn12 # return
-}
-%fn12 = func_terminator
-
)");
}
@@ -600,20 +626,20 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = func my_func():u32 -> %fn2
-%fn2 = block {
- br %fn3 0u # return
+ EXPECT_EQ(Disassemble(m.Get()), R"(%1 = func my_func():u32 -> %fn1 {
+ %fn1 = block {
+ br %fn2 0u # return
+ }
+ %fn2 = func_terminator
}
-%fn3 = func_terminator
-
-%fn4 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn5
-%fn5 = block {
- %1:u32 = call my_func
- %tint_symbol:bool = eq %1, 4u
- jmp %fn6 # return
+%2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn3 {
+ %fn3 = block {
+ %3:u32 = call my_func
+ %tint_symbol:bool = eq %3, 4u
+ br %fn4 # return
+ }
+ %fn4 = func_terminator
}
-%fn6 = func_terminator
-
)");
}
@@ -625,20 +651,20 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = func my_func():u32 -> %fn2
-%fn2 = block {
- br %fn3 0u # return
+ EXPECT_EQ(Disassemble(m.Get()), R"(%1 = func my_func():u32 -> %fn1 {
+ %fn1 = block {
+ br %fn2 0u # return
+ }
+ %fn2 = func_terminator
}
-%fn3 = func_terminator
-
-%fn4 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn5
-%fn5 = block {
- %1:u32 = call my_func
- %tint_symbol:bool = neq %1, 4u
- jmp %fn6 # return
+%2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn3 {
+ %fn3 = block {
+ %3:u32 = call my_func
+ %tint_symbol:bool = neq %3, 4u
+ br %fn4 # return
+ }
+ %fn4 = func_terminator
}
-%fn6 = func_terminator
-
)");
}
@@ -650,20 +676,20 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = func my_func():u32 -> %fn2
-%fn2 = block {
- br %fn3 0u # return
+ EXPECT_EQ(Disassemble(m.Get()), R"(%1 = func my_func():u32 -> %fn1 {
+ %fn1 = block {
+ br %fn2 0u # return
+ }
+ %fn2 = func_terminator
}
-%fn3 = func_terminator
-
-%fn4 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn5
-%fn5 = block {
- %1:u32 = call my_func
- %tint_symbol:bool = lt %1, 4u
- jmp %fn6 # return
+%2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn3 {
+ %fn3 = block {
+ %3:u32 = call my_func
+ %tint_symbol:bool = lt %3, 4u
+ br %fn4 # return
+ }
+ %fn4 = func_terminator
}
-%fn6 = func_terminator
-
)");
}
@@ -675,20 +701,20 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = func my_func():u32 -> %fn2
-%fn2 = block {
- br %fn3 0u # return
+ EXPECT_EQ(Disassemble(m.Get()), R"(%1 = func my_func():u32 -> %fn1 {
+ %fn1 = block {
+ br %fn2 0u # return
+ }
+ %fn2 = func_terminator
}
-%fn3 = func_terminator
-
-%fn4 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn5
-%fn5 = block {
- %1:u32 = call my_func
- %tint_symbol:bool = gt %1, 4u
- jmp %fn6 # return
+%2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn3 {
+ %fn3 = block {
+ %3:u32 = call my_func
+ %tint_symbol:bool = gt %3, 4u
+ br %fn4 # return
+ }
+ %fn4 = func_terminator
}
-%fn6 = func_terminator
-
)");
}
@@ -700,20 +726,20 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = func my_func():u32 -> %fn2
-%fn2 = block {
- br %fn3 0u # return
+ EXPECT_EQ(Disassemble(m.Get()), R"(%1 = func my_func():u32 -> %fn1 {
+ %fn1 = block {
+ br %fn2 0u # return
+ }
+ %fn2 = func_terminator
}
-%fn3 = func_terminator
-
-%fn4 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn5
-%fn5 = block {
- %1:u32 = call my_func
- %tint_symbol:bool = lte %1, 4u
- jmp %fn6 # return
+%2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn3 {
+ %fn3 = block {
+ %3:u32 = call my_func
+ %tint_symbol:bool = lte %3, 4u
+ br %fn4 # return
+ }
+ %fn4 = func_terminator
}
-%fn6 = func_terminator
-
)");
}
@@ -725,20 +751,20 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = func my_func():u32 -> %fn2
-%fn2 = block {
- br %fn3 0u # return
+ EXPECT_EQ(Disassemble(m.Get()), R"(%1 = func my_func():u32 -> %fn1 {
+ %fn1 = block {
+ br %fn2 0u # return
+ }
+ %fn2 = func_terminator
}
-%fn3 = func_terminator
-
-%fn4 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn5
-%fn5 = block {
- %1:u32 = call my_func
- %tint_symbol:bool = gte %1, 4u
- jmp %fn6 # return
+%2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn3 {
+ %fn3 = block {
+ %3:u32 = call my_func
+ %tint_symbol:bool = gte %3, 4u
+ br %fn4 # return
+ }
+ %fn4 = func_terminator
}
-%fn6 = func_terminator
-
)");
}
@@ -750,20 +776,20 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = func my_func():u32 -> %fn2
-%fn2 = block {
- br %fn3 0u # return
+ EXPECT_EQ(Disassemble(m.Get()), R"(%1 = func my_func():u32 -> %fn1 {
+ %fn1 = block {
+ br %fn2 0u # return
+ }
+ %fn2 = func_terminator
}
-%fn3 = func_terminator
-
-%fn4 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn5
-%fn5 = block {
- %1:u32 = call my_func
- %tint_symbol:u32 = shiftl %1, 4u
- jmp %fn6 # return
+%2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn3 {
+ %fn3 = block {
+ %3:u32 = call my_func
+ %tint_symbol:u32 = shiftl %3, 4u
+ br %fn4 # return
+ }
+ %fn4 = func_terminator
}
-%fn6 = func_terminator
-
)");
}
@@ -775,22 +801,23 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = block {
+ EXPECT_EQ(Disassemble(m.Get()), R"(# Root block
+%fn1 = block {
%v1:ptr<private, u32, read_write> = var
br %fn2 # root_end
}
%fn2 = root_terminator
-%fn3 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn4
-%fn4 = block {
- %2:u32 = load %v1
- %3:u32 = shiftl %2, 1u
- store %v1, %3
- jmp %fn5 # return
+%2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn3 {
+ %fn3 = block {
+ %3:u32 = load %v1
+ %4:u32 = shiftl %3, 1u
+ store %v1, %4
+ br %fn4 # return
+ }
+ %fn4 = func_terminator
}
-%fn5 = func_terminator
-
)");
}
@@ -802,20 +829,20 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = func my_func():u32 -> %fn2
-%fn2 = block {
- br %fn3 0u # return
+ EXPECT_EQ(Disassemble(m.Get()), R"(%1 = func my_func():u32 -> %fn1 {
+ %fn1 = block {
+ br %fn2 0u # return
+ }
+ %fn2 = func_terminator
}
-%fn3 = func_terminator
-
-%fn4 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn5
-%fn5 = block {
- %1:u32 = call my_func
- %tint_symbol:u32 = shiftr %1, 4u
- jmp %fn6 # return
+%2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn3 {
+ %fn3 = block {
+ %3:u32 = call my_func
+ %tint_symbol:u32 = shiftr %3, 4u
+ br %fn4 # return
+ }
+ %fn4 = func_terminator
}
-%fn6 = func_terminator
-
)");
}
@@ -827,22 +854,23 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = block {
+ EXPECT_EQ(Disassemble(m.Get()), R"(# Root block
+%fn1 = block {
%v1:ptr<private, u32, read_write> = var
br %fn2 # root_end
}
%fn2 = root_terminator
-%fn3 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn4
-%fn4 = block {
- %2:u32 = load %v1
- %3:u32 = shiftr %2, 1u
- store %v1, %3
- jmp %fn5 # return
+%2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn3 {
+ %fn3 = block {
+ %3:u32 = load %v1
+ %4:u32 = shiftr %3, 1u
+ store %v1, %4
+ br %fn4 # return
+ }
+ %fn4 = func_terminator
}
-%fn5 = func_terminator
-
)");
}
@@ -856,37 +884,41 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = func my_func():f32 -> %fn2
-%fn2 = block {
- br %fn3 0.0f # return
+ EXPECT_EQ(Disassemble(m.Get()), R"(%1 = func my_func():f32 -> %fn1 {
+ %fn1 = block {
+ br %fn2 0.0f # return
+ }
+ %fn2 = func_terminator
}
-%fn3 = func_terminator
+%2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn3 {
+ %fn3 = block {
+ %3:f32 = call my_func
+ %4:bool = lt %3, 2.0f
+ if %4 [t: %fn4, f: %fn5, m: %fn6]
+ # True block
+ %fn4 = block {
+ %5:f32 = call my_func
+ %6:f32 = call my_func
+ %7:f32 = mul 2.29999995231628417969f, %6
+ %8:f32 = div %5, %7
+ %9:bool = gt 2.5f, %8
+ br %fn6 %9
+ }
-%fn4 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn5
-%fn5 = block {
- %1:f32 = call my_func
- %2:bool = lt %1, 2.0f
- if %2 [t: %fn6, f: %fn7, m: %fn8]
+ # False block
+ %fn5 = block {
+ br %fn6 %4
+ }
+
+ # Merge block
+ %fn6 = block (%tint_symbol:bool) {
+ br %fn7 # return
+ }
+
+ }
+
+ %fn7 = func_terminator
}
-
-%fn6 = block {
- %3:f32 = call my_func
- %4:f32 = call my_func
- %5:f32 = mul 2.29999995231628417969f, %4
- %6:f32 = div %3, %5
- %7:bool = gt 2.5f, %6
- br %fn8 %7
-}
-
-%fn7 = block {
- br %fn8 %2
-}
-
-%fn8 = block (%tint_symbol:bool) {
- jmp %fn9 # return
-}
-%fn9 = func_terminator
-
)");
}
@@ -899,19 +931,19 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = func my_func(%p:bool):bool -> %fn2
-%fn2 = block {
- br %fn3 true # return
+ EXPECT_EQ(Disassemble(m.Get()), R"(%1 = func my_func(%p:bool):bool -> %fn1 {
+ %fn1 = block {
+ br %fn2 true # return
+ }
+ %fn2 = func_terminator
}
-%fn3 = func_terminator
-
-%fn4 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn5
-%fn5 = block {
- %tint_symbol:bool = call my_func, false
- jmp %fn6 # return
+%3 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn3 {
+ %fn3 = block {
+ %tint_symbol:bool = call my_func, false
+ br %fn4 # return
+ }
+ %fn4 = func_terminator
}
-%fn6 = func_terminator
-
)");
}
diff --git a/src/tint/ir/from_program_call_test.cc b/src/tint/ir/from_program_call_test.cc
index 2f74508..2d0f1ab 100644
--- a/src/tint/ir/from_program_call_test.cc
+++ b/src/tint/ir/from_program_call_test.cc
@@ -35,20 +35,20 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = func my_func():f32 -> %fn2
-%fn2 = block {
- br %fn3 0.0f # return
+ EXPECT_EQ(Disassemble(m.Get()), R"(%1 = func my_func():f32 -> %fn1 {
+ %fn1 = block {
+ br %fn2 0.0f # return
+ }
+ %fn2 = func_terminator
}
-%fn3 = func_terminator
-
-%fn4 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn5
-%fn5 = block {
- %1:f32 = call my_func
- %tint_symbol:f32 = bitcast %1
- jmp %fn6 # return
+%2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn3 {
+ %fn3 = block {
+ %3:f32 = call my_func
+ %tint_symbol:f32 = bitcast %3
+ br %fn4 # return
+ }
+ %fn4 = func_terminator
}
-%fn6 = func_terminator
-
)");
}
@@ -62,13 +62,13 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = func test_function():void [@fragment] -> %fn2
-%fn2 = block {
- discard
- jmp %fn3 # return
+ EXPECT_EQ(Disassemble(m.Get()), R"(%1 = func test_function():void [@fragment] -> %fn1 {
+ %fn1 = block {
+ discard
+ br %fn2 # return
+ }
+ %fn2 = func_terminator
}
-%fn3 = func_terminator
-
)");
}
@@ -80,19 +80,19 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = func my_func(%p:f32):void -> %fn2
-%fn2 = block {
- jmp %fn3 # return
+ EXPECT_EQ(Disassemble(m.Get()), R"(%1 = func my_func(%p:f32):void -> %fn1 {
+ %fn1 = block {
+ br %fn2 # return
+ }
+ %fn2 = func_terminator
}
-%fn3 = func_terminator
-
-%fn4 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn5
-%fn5 = block {
- %2:void = call my_func, 6.0f
- jmp %fn6 # return
+%3 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn3 {
+ %fn3 = block {
+ %4:void = call my_func, 6.0f
+ br %fn4 # return
+ }
+ %fn4 = func_terminator
}
-%fn6 = func_terminator
-
)");
}
@@ -104,21 +104,22 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = block {
+ EXPECT_EQ(Disassemble(m.Get()), R"(# Root block
+%fn1 = block {
%i:ptr<private, i32, read_write> = var, 1i
br %fn2 # root_end
}
%fn2 = root_terminator
-%fn3 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn4
-%fn4 = block {
- %2:i32 = load %i
- %tint_symbol:f32 = convert i32, %2
- jmp %fn5 # return
+%2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn3 {
+ %fn3 = block {
+ %3:i32 = load %i
+ %tint_symbol:f32 = convert i32, %3
+ br %fn4 # return
+ }
+ %fn4 = func_terminator
}
-%fn5 = func_terminator
-
)");
}
@@ -129,7 +130,8 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = block {
+ EXPECT_EQ(Disassemble(m.Get()), R"(# Root block
+%fn1 = block {
%i:ptr<private, vec3<f32>, read_write> = var, vec3<f32> 0.0f
br %fn2 # root_end
}
@@ -147,21 +149,22 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = block {
+ EXPECT_EQ(Disassemble(m.Get()), R"(# Root block
+%fn1 = block {
%i:ptr<private, f32, read_write> = var, 1.0f
br %fn2 # root_end
}
%fn2 = root_terminator
-%fn3 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn4
-%fn4 = block {
- %2:f32 = load %i
- %tint_symbol:vec3<f32> = construct 2.0f, 3.0f, %2
- jmp %fn5 # return
+%2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn3 {
+ %fn3 = block {
+ %3:f32 = load %i
+ %tint_symbol:vec3<f32> = construct 2.0f, 3.0f, %3
+ br %fn4 # return
+ }
+ %fn4 = func_terminator
}
-%fn5 = func_terminator
-
)");
}
diff --git a/src/tint/ir/from_program_materialize_test.cc b/src/tint/ir/from_program_materialize_test.cc
index 009d417..68bb81a 100644
--- a/src/tint/ir/from_program_materialize_test.cc
+++ b/src/tint/ir/from_program_materialize_test.cc
@@ -34,12 +34,12 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = func test_function():f32 -> %fn2
-%fn2 = block {
- br %fn3 2.0f # return
+ EXPECT_EQ(Disassemble(m.Get()), R"(%1 = func test_function():f32 -> %fn1 {
+ %fn1 = block {
+ br %fn2 2.0f # return
+ }
+ %fn2 = func_terminator
}
-%fn3 = func_terminator
-
)");
}
diff --git a/src/tint/ir/from_program_store_test.cc b/src/tint/ir/from_program_store_test.cc
index ab0945b..1559f95 100644
--- a/src/tint/ir/from_program_store_test.cc
+++ b/src/tint/ir/from_program_store_test.cc
@@ -35,20 +35,21 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = block {
+ EXPECT_EQ(Disassemble(m.Get()), R"(# Root block
+%fn1 = block {
%a:ptr<private, u32, read_write> = var
br %fn2 # root_end
}
%fn2 = root_terminator
-%fn3 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn4
-%fn4 = block {
- store %a, 4u
- jmp %fn5 # return
+%2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn3 {
+ %fn3 = block {
+ store %a, 4u
+ br %fn4 # return
+ }
+ %fn4 = func_terminator
}
-%fn5 = func_terminator
-
)");
}
diff --git a/src/tint/ir/from_program_test.cc b/src/tint/ir/from_program_test.cc
index 197e7e3..bc3b2a1 100644
--- a/src/tint/ir/from_program_test.cc
+++ b/src/tint/ir/from_program_test.cc
@@ -70,12 +70,12 @@
EXPECT_EQ(m->functions[0]->Stage(), Function::PipelineStage::kUndefined);
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = func f():void -> %fn2
-%fn2 = block {
- jmp %fn3 # return
+ EXPECT_EQ(Disassemble(m.Get()), R"(%1 = func f():void -> %fn1 {
+ %fn1 = block {
+ br %fn2 # return
+ }
+ %fn2 = func_terminator
}
-%fn3 = func_terminator
-
)");
}
@@ -95,12 +95,12 @@
EXPECT_EQ(m->functions[0]->Stage(), Function::PipelineStage::kUndefined);
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = func f(%a:u32):u32 -> %fn2
-%fn2 = block {
- br %fn3 %a # return
+ EXPECT_EQ(Disassemble(m.Get()), R"(%1 = func f(%a:u32):u32 -> %fn1 {
+ %fn1 = block {
+ br %fn2 %a # return
+ }
+ %fn2 = func_terminator
}
-%fn3 = func_terminator
-
)");
}
@@ -121,12 +121,12 @@
EXPECT_EQ(m->functions[0]->Stage(), Function::PipelineStage::kUndefined);
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = func f(%a:u32, %b:i32, %c:bool):void -> %fn2
-%fn2 = block {
- jmp %fn3 # return
+ EXPECT_EQ(Disassemble(m.Get()), R"(%1 = func f(%a:u32, %b:i32, %c:bool):void -> %fn1 {
+ %fn1 = block {
+ br %fn2 # return
+ }
+ %fn2 = func_terminator
}
-%fn3 = func_terminator
-
)");
}
@@ -159,24 +159,28 @@
EXPECT_EQ(1u, func->EndTarget()->InboundBranches().Length());
EXPECT_EQ(Disassemble(m),
- R"(%fn1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn2
-%fn2 = block {
- if true [t: %fn3, f: %fn4, m: %fn5]
-}
+ R"(%1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn1 {
+ %fn1 = block {
+ if true [t: %fn2, f: %fn3, m: %fn4]
+ # True block
+ %fn2 = block {
+ br %fn4
+ }
-%fn3 = block {
- br %fn5
-}
+ # False block
+ %fn3 = block {
+ br %fn4
+ }
-%fn4 = block {
- br %fn5
-}
+ # Merge block
+ %fn4 = block {
+ br %fn5 # return
+ }
-%fn5 = block {
- jmp %fn6 # return
-}
-%fn6 = func_terminator
+ }
+ %fn5 = func_terminator
+}
)");
}
@@ -199,23 +203,27 @@
EXPECT_EQ(2u, func->EndTarget()->InboundBranches().Length());
EXPECT_EQ(Disassemble(m),
- R"(%fn1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn2
-%fn2 = block {
- if true [t: %fn3, f: %fn4, m: %fn5]
-}
+ R"(%1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn1 {
+ %fn1 = block {
+ if true [t: %fn2, f: %fn3, m: %fn4]
+ # True block
+ %fn2 = block {
+ br %fn5 # return
+ }
+ # False block
+ %fn3 = block {
+ br %fn4
+ }
-%fn3 = block {
- br %fn6 # return
-}
-%fn4 = block {
- br %fn5
-}
+ # Merge block
+ %fn4 = block {
+ br %fn5 # return
+ }
-%fn5 = block {
- jmp %fn6 # return
-}
-%fn6 = func_terminator
+ }
+ %fn5 = func_terminator
+}
)");
}
@@ -238,23 +246,27 @@
EXPECT_EQ(2u, func->EndTarget()->InboundBranches().Length());
EXPECT_EQ(Disassemble(m),
- R"(%fn1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn2
-%fn2 = block {
- if true [t: %fn3, f: %fn4, m: %fn5]
-}
+ R"(%1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn1 {
+ %fn1 = block {
+ if true [t: %fn2, f: %fn3, m: %fn4]
+ # True block
+ %fn2 = block {
+ br %fn4
+ }
-%fn3 = block {
- br %fn5
-}
+ # False block
+ %fn3 = block {
+ br %fn5 # return
+ }
+ # Merge block
+ %fn4 = block {
+ br %fn5 # return
+ }
-%fn4 = block {
- br %fn6 # return
-}
-%fn5 = block {
- jmp %fn6 # return
-}
-%fn6 = func_terminator
+ }
+ %fn5 = func_terminator
+}
)");
}
@@ -277,19 +289,22 @@
EXPECT_EQ(2u, func->EndTarget()->InboundBranches().Length());
EXPECT_EQ(Disassemble(m),
- R"(%fn1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn2
-%fn2 = block {
- if true [t: %fn3, f: %fn4]
-}
+ R"(%1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn1 {
+ %fn1 = block {
+ if true [t: %fn2, f: %fn3]
+ # True block
+ %fn2 = block {
+ br %fn4 # return
+ }
+ # False block
+ %fn3 = block {
+ br %fn4 # return
+ }
-%fn3 = block {
- br %fn5 # return
-}
-%fn4 = block {
- br %fn5 # return
-}
-%fn5 = func_terminator
+ }
+ %fn4 = func_terminator
+}
)");
}
@@ -309,36 +324,43 @@
ASSERT_NE(loop_flow, nullptr);
EXPECT_EQ(Disassemble(m),
- R"(%fn1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn2
-%fn2 = block {
- if true [t: %fn3, f: %fn4, m: %fn5]
-}
+ R"(%1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn1 {
+ %fn1 = block {
+ if true [t: %fn2, f: %fn3, m: %fn4]
+ # True block
+ %fn2 = block {
+ loop [s: %fn5, c: %fn6, m: %fn7]
+ %fn5 = block {
+ br %fn7
+ }
-%fn3 = block {
- loop [s: %fn6, c: %fn7, m: %fn8]
-}
+ # Continuing block
+ %fn6 = block {
+ br %fn5
+ }
-%fn4 = block {
- br %fn5
-}
+ # Merge block
+ %fn7 = block {
+ br %fn4
+ }
-%fn5 = block {
- jmp %fn9 # return
-}
-%fn9 = func_terminator
-%fn6 = block {
- br %fn8
-}
+ }
-%fn7 = block {
- br %fn6
-}
+ # False block
+ %fn3 = block {
+ br %fn4
+ }
-%fn8 = block {
- br %fn5
-}
+ # Merge block
+ %fn4 = block {
+ br %fn8 # return
+ }
+ }
+
+ %fn8 = func_terminator
+}
)");
}
@@ -361,24 +383,27 @@
EXPECT_EQ(1u, func->EndTarget()->InboundBranches().Length());
EXPECT_EQ(Disassemble(m),
- R"(%fn1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn2
-%fn2 = block {
- loop [s: %fn3, c: %fn4, m: %fn5]
-}
+ R"(%1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn1 {
+ %fn1 = block {
+ loop [s: %fn2, c: %fn3, m: %fn4]
+ %fn2 = block {
+ br %fn4
+ }
-%fn3 = block {
- br %fn5
-}
+ # Continuing block
+ %fn3 = block {
+ br %fn2
+ }
-%fn4 = block {
- br %fn3
-}
+ # Merge block
+ %fn4 = block {
+ br %fn5 # return
+ }
-%fn5 = block {
- jmp %fn6 # return
-}
-%fn6 = func_terminator
+ }
+ %fn5 = func_terminator
+}
)");
}
@@ -407,36 +432,43 @@
EXPECT_EQ(1u, func->EndTarget()->InboundBranches().Length());
EXPECT_EQ(Disassemble(m),
- R"(%fn1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn2
-%fn2 = block {
- loop [s: %fn3, c: %fn4, m: %fn5]
-}
+ R"(%1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn1 {
+ %fn1 = block {
+ loop [s: %fn2, c: %fn3, m: %fn4]
+ %fn2 = block {
+ if true [t: %fn5, f: %fn6, m: %fn7]
+ # True block
+ %fn5 = block {
+ br %fn4
+ }
-%fn3 = block {
- if true [t: %fn6, f: %fn7, m: %fn8]
-}
+ # False block
+ %fn6 = block {
+ br %fn7
+ }
-%fn4 = block {
- br %fn3
-}
+ # Merge block
+ %fn7 = block {
+ br %fn3
+ }
-%fn5 = block {
- jmp %fn9 # return
-}
-%fn9 = func_terminator
-%fn6 = block {
- br %fn5
-}
+ }
-%fn7 = block {
- br %fn8
-}
+ # Continuing block
+ %fn3 = block {
+ br %fn2
+ }
-%fn8 = block {
- br %fn4
-}
+ # Merge block
+ %fn4 = block {
+ br %fn8 # return
+ }
+ }
+
+ %fn8 = func_terminator
+}
)");
}
@@ -464,36 +496,43 @@
EXPECT_EQ(1u, func->EndTarget()->InboundBranches().Length());
EXPECT_EQ(Disassemble(m),
- R"(%fn1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn2
-%fn2 = block {
- loop [s: %fn3, c: %fn4, m: %fn5]
-}
+ R"(%1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn1 {
+ %fn1 = block {
+ loop [s: %fn2, c: %fn3, m: %fn4]
+ %fn2 = block {
+ br %fn3
+ }
-%fn3 = block {
- jmp %fn4
-}
+ # Continuing block
+ %fn3 = block {
+ if true [t: %fn5, f: %fn6, m: %fn7]
+ # True block
+ %fn5 = block {
+ br %fn4
+ }
-%fn4 = block {
- if true [t: %fn6, f: %fn7, m: %fn8]
-}
+ # False block
+ %fn6 = block {
+ br %fn7
+ }
-%fn5 = block {
- jmp %fn9 # return
-}
-%fn9 = func_terminator
+ # Merge block
+ %fn7 = block {
+ br %fn2
+ }
-%fn6 = block {
- br %fn5
-}
-%fn7 = block {
- br %fn8
-}
+ }
-%fn8 = block {
- br %fn3
-}
+ # Merge block
+ %fn4 = block {
+ br %fn8 # return
+ }
+ }
+
+ %fn8 = func_terminator
+}
)");
}
@@ -508,36 +547,43 @@
auto m = res.Move();
EXPECT_EQ(Disassemble(m),
- R"(%fn1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn2
-%fn2 = block {
- loop [s: %fn3, c: %fn4, m: %fn5]
-}
+ R"(%1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn1 {
+ %fn1 = block {
+ loop [s: %fn2, c: %fn3, m: %fn4]
+ %fn2 = block {
+ br %fn3
+ }
-%fn3 = block {
- jmp %fn4
-}
+ # Continuing block
+ %fn3 = block {
+ if true [t: %fn5, f: %fn6, m: %fn7]
+ # True block
+ %fn5 = block {
+ br %fn4
+ }
-%fn4 = block {
- if true [t: %fn6, f: %fn7, m: %fn8]
-}
+ # False block
+ %fn6 = block {
+ br %fn7
+ }
-%fn5 = block {
- jmp %fn9 # return
-}
-%fn9 = func_terminator
+ # Merge block
+ %fn7 = block {
+ br %fn2
+ }
-%fn6 = block {
- br %fn5
-}
-%fn7 = block {
- br %fn8
-}
+ }
-%fn8 = block {
- br %fn3
-}
+ # Merge block
+ %fn4 = block {
+ br %fn8 # return
+ }
+ }
+
+ %fn8 = func_terminator
+}
)");
}
@@ -565,32 +611,38 @@
EXPECT_EQ(1u, func->EndTarget()->InboundBranches().Length());
EXPECT_EQ(Disassemble(m),
- R"(%fn1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn2
-%fn2 = block {
- loop [s: %fn3, c: %fn4]
-}
+ R"(%1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn1 {
+ %fn1 = block {
+ loop [s: %fn2, c: %fn3]
+ %fn2 = block {
+ if true [t: %fn4, f: %fn5, m: %fn6]
+ # True block
+ %fn4 = block {
+ br %fn7 # return
+ }
+ # False block
+ %fn5 = block {
+ br %fn6
+ }
-%fn3 = block {
- if true [t: %fn5, f: %fn6, m: %fn7]
-}
+ # Merge block
+ %fn6 = block {
+ br %fn3
+ }
-%fn4 = block {
- br %fn3
-}
-%fn5 = block {
- br %fn8 # return
-}
-%fn6 = block {
- br %fn7
-}
+ }
-%fn7 = block {
- br %fn4
+ # Continuing block
+ %fn3 = block {
+ br %fn2
+ }
+
+
+ }
+
+ %fn7 = func_terminator
}
-
-%fn8 = func_terminator
-
)");
}
@@ -613,20 +665,22 @@
EXPECT_EQ(1u, func->EndTarget()->InboundBranches().Length());
EXPECT_EQ(Disassemble(m),
- R"(%fn1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn2
-%fn2 = block {
- loop [s: %fn3, c: %fn4]
-}
+ R"(%1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn1 {
+ %fn1 = block {
+ loop [s: %fn2, c: %fn3]
+ %fn2 = block {
+ br %fn4 # return
+ }
+ # Continuing block
+ %fn3 = block {
+ br %fn2
+ }
-%fn3 = block {
- br %fn5 # return
-}
-%fn4 = block {
- br %fn3
-}
-%fn5 = func_terminator
+ }
+ %fn4 = func_terminator
+}
)");
}
@@ -658,45 +712,56 @@
EXPECT_EQ(3u, func->EndTarget()->InboundBranches().Length());
EXPECT_EQ(Disassemble(m),
- R"(%fn1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn2
-%fn2 = block {
- loop [s: %fn3, c: %fn4, m: %fn5]
-}
+ R"(%1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn1 {
+ %fn1 = block {
+ loop [s: %fn2, c: %fn3, m: %fn4]
+ %fn2 = block {
+ br %fn5 # return
+ }
+ # Continuing block
+ %fn3 = block {
+ if true [t: %fn6, f: %fn7, m: %fn8]
+ # True block
+ %fn6 = block {
+ br %fn4
+ }
-%fn3 = block {
- br %fn6 # return
-}
-%fn4 = block {
- if true [t: %fn7, f: %fn8, m: %fn9]
-}
+ # False block
+ %fn7 = block {
+ br %fn8
+ }
-%fn5 = block {
- if true [t: %fn10, f: %fn11, m: %fn12]
-}
+ # Merge block
+ %fn8 = block {
+ br %fn2
+ }
-%fn6 = func_terminator
-%fn7 = block {
- br %fn5
-}
+ }
-%fn8 = block {
- br %fn9
-}
+ # Merge block
+ %fn4 = block {
+ if true [t: %fn9, f: %fn10, m: %fn11]
+ # True block
+ %fn9 = block {
+ br %fn5 # return
+ }
+ # False block
+ %fn10 = block {
+ br %fn11
+ }
-%fn9 = block {
- br %fn3
-}
+ # Merge block
+ %fn11 = block {
+ br %fn5 # return
+ }
-%fn10 = block {
- br %fn6 # return
-}
-%fn11 = block {
- br %fn12
-}
+ }
-%fn12 = block {
- jmp %fn6 # return
+
+ }
+
+ %fn5 = func_terminator
}
)");
}
@@ -725,32 +790,38 @@
EXPECT_EQ(1u, func->EndTarget()->InboundBranches().Length());
EXPECT_EQ(Disassemble(m),
- R"(%fn1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn2
-%fn2 = block {
- loop [s: %fn3, c: %fn4, m: %fn5]
-}
+ R"(%1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn1 {
+ %fn1 = block {
+ loop [s: %fn2, c: %fn3, m: %fn4]
+ %fn2 = block {
+ if true [t: %fn5, f: %fn6]
+ # True block
+ %fn5 = block {
+ br %fn4
+ }
-%fn3 = block {
- if true [t: %fn6, f: %fn7]
-}
+ # False block
+ %fn6 = block {
+ br %fn4
+ }
-%fn4 = block {
- br %fn3
-}
-%fn5 = block {
- jmp %fn8 # return
-}
-%fn8 = func_terminator
+ }
-%fn6 = block {
- br %fn5
-}
+ # Continuing block
+ %fn3 = block {
+ br %fn2
+ }
-%fn7 = block {
- br %fn5
-}
+ # Merge block
+ %fn4 = block {
+ br %fn7 # return
+ }
+ }
+
+ %fn7 = func_terminator
+}
)");
}
@@ -772,108 +843,136 @@
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
EXPECT_EQ(Disassemble(m.Get()),
- R"(%fn1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn2
-%fn2 = block {
- loop [s: %fn3, c: %fn4, m: %fn5]
-}
+ R"(%1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn1 {
+ %fn1 = block {
+ loop [s: %fn2, c: %fn3, m: %fn4]
+ %fn2 = block {
+ loop [s: %fn5, c: %fn6, m: %fn7]
+ %fn5 = block {
+ if true [t: %fn8, f: %fn9, m: %fn10]
+ # True block
+ %fn8 = block {
+ br %fn7
+ }
-%fn3 = block {
- loop [s: %fn6, c: %fn7, m: %fn8]
-}
+ # False block
+ %fn9 = block {
+ br %fn10
+ }
-%fn4 = block {
- br %fn3
-}
+ # Merge block
+ %fn10 = block {
+ if true [t: %fn11, f: %fn12, m: %fn13]
+ # True block
+ %fn11 = block {
+ br %fn6
+ }
-%fn5 = block {
- jmp %fn9 # return
-}
-%fn9 = func_terminator
+ # False block
+ %fn12 = block {
+ br %fn13
+ }
-%fn6 = block {
- if true [t: %fn10, f: %fn11, m: %fn12]
-}
+ # Merge block
+ %fn13 = block {
+ br %fn6
+ }
-%fn7 = block {
- loop [s: %fn13, c: %fn14, m: %fn15]
-}
-%fn8 = block {
- if true [t: %fn16, f: %fn17, m: %fn18]
-}
+ }
-%fn10 = block {
- br %fn8
-}
-%fn11 = block {
- br %fn12
-}
+ }
-%fn12 = block {
- if true [t: %fn19, f: %fn20, m: %fn21]
-}
+ # Continuing block
+ %fn6 = block {
+ loop [s: %fn14, c: %fn15, m: %fn16]
+ %fn14 = block {
+ br %fn16
+ }
-%fn13 = block {
- br %fn15
-}
+ # Continuing block
+ %fn15 = block {
+ br %fn14
+ }
-%fn14 = block {
- br %fn13
-}
+ # Merge block
+ %fn16 = block {
+ loop [s: %fn17, c: %fn18, m: %fn19]
+ %fn17 = block {
+ br %fn18
+ }
-%fn15 = block {
- loop [s: %fn22, c: %fn23, m: %fn24]
-}
+ # Continuing block
+ %fn18 = block {
+ if true [t: %fn20, f: %fn21, m: %fn22]
+ # True block
+ %fn20 = block {
+ br %fn19
+ }
-%fn16 = block {
- br %fn5
-}
+ # False block
+ %fn21 = block {
+ br %fn22
+ }
-%fn17 = block {
- br %fn18
-}
+ # Merge block
+ %fn22 = block {
+ br %fn17
+ }
-%fn18 = block {
- jmp %fn4
-}
-%fn19 = block {
- br %fn7
-}
+ }
-%fn20 = block {
- br %fn21
-}
+ # Merge block
+ %fn19 = block {
+ br %fn5
+ }
-%fn21 = block {
- jmp %fn7
-}
-%fn22 = block {
- jmp %fn23
-}
+ }
-%fn23 = block {
- if true [t: %fn25, f: %fn26, m: %fn27]
-}
-%fn24 = block {
- br %fn6
-}
+ }
-%fn25 = block {
- br %fn24
-}
+ # Merge block
+ %fn7 = block {
+ if true [t: %fn23, f: %fn24, m: %fn25]
+ # True block
+ %fn23 = block {
+ br %fn4
+ }
-%fn26 = block {
- br %fn27
-}
+ # False block
+ %fn24 = block {
+ br %fn25
+ }
-%fn27 = block {
- br %fn22
-}
+ # Merge block
+ %fn25 = block {
+ br %fn3
+ }
+
+ }
+
+
+ }
+
+ # Continuing block
+ %fn3 = block {
+ br %fn2
+ }
+
+ # Merge block
+ %fn4 = block {
+ br %fn26 # return
+ }
+
+ }
+
+ %fn26 = func_terminator
+}
)");
}
@@ -903,36 +1002,43 @@
EXPECT_EQ(2u, if_flow->Merge()->InboundBranches().Length());
EXPECT_EQ(Disassemble(m),
- R"(%fn1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn2
-%fn2 = block {
- loop [s: %fn3, c: %fn4, m: %fn5]
-}
+ R"(%1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn1 {
+ %fn1 = block {
+ loop [s: %fn2, c: %fn3, m: %fn4]
+ %fn2 = block {
+ if false [t: %fn5, f: %fn6, m: %fn7]
+ # True block
+ %fn5 = block {
+ br %fn7
+ }
-%fn3 = block {
- if false [t: %fn6, f: %fn7, m: %fn8]
-}
+ # False block
+ %fn6 = block {
+ br %fn4
+ }
-%fn4 = block {
- br %fn3
-}
+ # Merge block
+ %fn7 = block {
+ br %fn3
+ }
-%fn5 = block {
- jmp %fn9 # return
-}
-%fn9 = func_terminator
-%fn6 = block {
- br %fn8
-}
+ }
-%fn7 = block {
- br %fn5
-}
+ # Continuing block
+ %fn3 = block {
+ br %fn2
+ }
-%fn8 = block {
- jmp %fn4
-}
+ # Merge block
+ %fn4 = block {
+ br %fn8 # return
+ }
+ }
+
+ %fn8 = func_terminator
+}
)");
}
@@ -962,34 +1068,41 @@
EXPECT_EQ(2u, if_flow->Merge()->InboundBranches().Length());
EXPECT_EQ(Disassemble(m),
- R"(%fn1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn2
-%fn2 = block {
- loop [s: %fn3, c: %fn4, m: %fn5]
-}
+ R"(%1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn1 {
+ %fn1 = block {
+ loop [s: %fn2, c: %fn3, m: %fn4]
+ %fn2 = block {
+ if true [t: %fn5, f: %fn6, m: %fn7]
+ # True block
+ %fn5 = block {
+ br %fn7
+ }
-%fn3 = block {
- if true [t: %fn6, f: %fn7, m: %fn8]
-}
+ # False block
+ %fn6 = block {
+ br %fn4
+ }
-%fn4 = block {
- br %fn3
-}
+ # Merge block
+ %fn7 = block {
+ br %fn8 # return
+ }
-%fn5 = block {
- jmp %fn9 # return
-}
-%fn9 = func_terminator
+ }
-%fn6 = block {
- br %fn8
-}
+ # Continuing block
+ %fn3 = block {
+ br %fn2
+ }
-%fn7 = block {
- br %fn5
-}
+ # Merge block
+ %fn4 = block {
+ br %fn8 # return
+ }
-%fn8 = block {
- br %fn9 # return
+ }
+
+ %fn8 = func_terminator
}
)");
}
@@ -1054,24 +1167,27 @@
EXPECT_EQ(1u, func->EndTarget()->InboundBranches().Length());
EXPECT_EQ(Disassemble(m),
- R"(%fn1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn2
-%fn2 = block {
- loop [s: %fn3, c: %fn4, m: %fn5]
-}
+ R"(%1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn1 {
+ %fn1 = block {
+ loop [s: %fn2, c: %fn3, m: %fn4]
+ %fn2 = block {
+ br %fn4
+ }
-%fn3 = block {
- br %fn5
-}
+ # Continuing block
+ %fn3 = block {
+ br %fn2
+ }
-%fn4 = block {
- br %fn3
-}
+ # Merge block
+ %fn4 = block {
+ br %fn5 # return
+ }
-%fn5 = block {
- jmp %fn6 # return
-}
-%fn6 = func_terminator
+ }
+ %fn5 = func_terminator
+}
)");
}
@@ -1114,28 +1230,33 @@
EXPECT_EQ(1u, func->EndTarget()->InboundBranches().Length());
EXPECT_EQ(Disassemble(m),
- R"(%fn1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn2
-%fn2 = block {
- switch 1i [c: (0i, %fn3), c: (1i, %fn4), c: (default, %fn5), m: %fn6]
-}
+ R"(%1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn1 {
+ %fn1 = block {
+ switch 1i [c: (0i, %fn2), c: (1i, %fn3), c: (default, %fn4), m: %fn5]
+ # Case block
+ %fn2 = block {
+ br %fn5
+ }
-%fn3 = block {
- br %fn6
-}
+ # Case block
+ %fn3 = block {
+ br %fn5
+ }
-%fn4 = block {
- br %fn6
-}
+ # Case block
+ %fn4 = block {
+ br %fn5
+ }
-%fn5 = block {
- br %fn6
-}
+ # Merge block
+ %fn5 = block {
+ br %fn6 # return
+ }
-%fn6 = block {
- jmp %fn7 # return
-}
-%fn7 = func_terminator
+ }
+ %fn6 = func_terminator
+}
)");
}
@@ -1174,20 +1295,23 @@
EXPECT_EQ(1u, func->EndTarget()->InboundBranches().Length());
EXPECT_EQ(Disassemble(m),
- R"(%fn1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn2
-%fn2 = block {
- switch 1i [c: (0i 1i default, %fn3), m: %fn4]
-}
+ R"(%1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn1 {
+ %fn1 = block {
+ switch 1i [c: (0i 1i default, %fn2), m: %fn3]
+ # Case block
+ %fn2 = block {
+ br %fn3
+ }
-%fn3 = block {
- br %fn4
-}
+ # Merge block
+ %fn3 = block {
+ br %fn4 # return
+ }
-%fn4 = block {
- jmp %fn5 # return
-}
-%fn5 = func_terminator
+ }
+ %fn4 = func_terminator
+}
)");
}
@@ -1214,20 +1338,23 @@
EXPECT_EQ(1u, func->EndTarget()->InboundBranches().Length());
EXPECT_EQ(Disassemble(m),
- R"(%fn1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn2
-%fn2 = block {
- switch 1i [c: (default, %fn3), m: %fn4]
-}
+ R"(%1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn1 {
+ %fn1 = block {
+ switch 1i [c: (default, %fn2), m: %fn3]
+ # Case block
+ %fn2 = block {
+ br %fn3
+ }
-%fn3 = block {
- br %fn4
-}
+ # Merge block
+ %fn3 = block {
+ br %fn4 # return
+ }
-%fn4 = block {
- jmp %fn5 # return
-}
-%fn5 = func_terminator
+ }
+ %fn4 = func_terminator
+}
)");
}
@@ -1263,24 +1390,28 @@
EXPECT_EQ(1u, func->EndTarget()->InboundBranches().Length());
EXPECT_EQ(Disassemble(m),
- R"(%fn1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn2
-%fn2 = block {
- switch 1i [c: (0i, %fn3), c: (default, %fn4), m: %fn5]
-}
+ R"(%1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn1 {
+ %fn1 = block {
+ switch 1i [c: (0i, %fn2), c: (default, %fn3), m: %fn4]
+ # Case block
+ %fn2 = block {
+ br %fn4
+ }
-%fn3 = block {
- br %fn5
-}
+ # Case block
+ %fn3 = block {
+ br %fn4
+ }
-%fn4 = block {
- br %fn5
-}
+ # Merge block
+ %fn4 = block {
+ br %fn5 # return
+ }
-%fn5 = block {
- jmp %fn6 # return
-}
-%fn6 = func_terminator
+ }
+ %fn5 = func_terminator
+}
)");
}
@@ -1318,19 +1449,22 @@
EXPECT_EQ(2u, func->EndTarget()->InboundBranches().Length());
EXPECT_EQ(Disassemble(m),
- R"(%fn1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn2
-%fn2 = block {
- switch 1i [c: (0i, %fn3), c: (default, %fn4)]
-}
+ R"(%1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn1 {
+ %fn1 = block {
+ switch 1i [c: (0i, %fn2), c: (default, %fn3)]
+ # Case block
+ %fn2 = block {
+ br %fn4 # return
+ }
+ # Case block
+ %fn3 = block {
+ br %fn4 # return
+ }
-%fn3 = block {
- br %fn5 # return
-}
-%fn4 = block {
- br %fn5 # return
-}
-%fn5 = func_terminator
+ }
+ %fn4 = func_terminator
+}
)");
}
@@ -1342,19 +1476,19 @@
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
EXPECT_EQ(Disassemble(m.Get()),
- R"(%fn1 = func b():i32 -> %fn2
-%fn2 = block {
- br %fn3 1i # return
+ R"(%1 = func b():i32 -> %fn1 {
+ %fn1 = block {
+ br %fn2 1i # return
+ }
+ %fn2 = func_terminator
}
-%fn3 = func_terminator
-
-%fn4 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn5
-%fn5 = block {
- %1:i32 = call b
- jmp %fn6 # return
+%2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn3 {
+ %fn3 = block {
+ %3:i32 = call b
+ br %fn4 # return
+ }
+ %fn4 = func_terminator
}
-%fn6 = func_terminator
-
)");
}
diff --git a/src/tint/ir/from_program_unary_test.cc b/src/tint/ir/from_program_unary_test.cc
index a47acf6..dc13f25 100644
--- a/src/tint/ir/from_program_unary_test.cc
+++ b/src/tint/ir/from_program_unary_test.cc
@@ -34,20 +34,20 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = func my_func():bool -> %fn2
-%fn2 = block {
- br %fn3 false # return
+ EXPECT_EQ(Disassemble(m.Get()), R"(%1 = func my_func():bool -> %fn1 {
+ %fn1 = block {
+ br %fn2 false # return
+ }
+ %fn2 = func_terminator
}
-%fn3 = func_terminator
-
-%fn4 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn5
-%fn5 = block {
- %1:bool = call my_func
- %tint_symbol:bool = eq %1, false
- jmp %fn6 # return
+%2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn3 {
+ %fn3 = block {
+ %3:bool = call my_func
+ %tint_symbol:bool = eq %3, false
+ br %fn4 # return
+ }
+ %fn4 = func_terminator
}
-%fn6 = func_terminator
-
)");
}
@@ -59,20 +59,20 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = func my_func():u32 -> %fn2
-%fn2 = block {
- br %fn3 1u # return
+ EXPECT_EQ(Disassemble(m.Get()), R"(%1 = func my_func():u32 -> %fn1 {
+ %fn1 = block {
+ br %fn2 1u # return
+ }
+ %fn2 = func_terminator
}
-%fn3 = func_terminator
-
-%fn4 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn5
-%fn5 = block {
- %1:u32 = call my_func
- %tint_symbol:u32 = complement %1
- jmp %fn6 # return
+%2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn3 {
+ %fn3 = block {
+ %3:u32 = call my_func
+ %tint_symbol:u32 = complement %3
+ br %fn4 # return
+ }
+ %fn4 = func_terminator
}
-%fn6 = func_terminator
-
)");
}
@@ -84,20 +84,20 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = func my_func():i32 -> %fn2
-%fn2 = block {
- br %fn3 1i # return
+ EXPECT_EQ(Disassemble(m.Get()), R"(%1 = func my_func():i32 -> %fn1 {
+ %fn1 = block {
+ br %fn2 1i # return
+ }
+ %fn2 = func_terminator
}
-%fn3 = func_terminator
-
-%fn4 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn5
-%fn5 = block {
- %1:i32 = call my_func
- %tint_symbol:i32 = negation %1
- jmp %fn6 # return
+%2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn3 {
+ %fn3 = block {
+ %3:i32 = call my_func
+ %tint_symbol:i32 = negation %3
+ br %fn4 # return
+ }
+ %fn4 = func_terminator
}
-%fn6 = func_terminator
-
)");
}
@@ -110,19 +110,20 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = block {
+ EXPECT_EQ(Disassemble(m.Get()), R"(# Root block
+%fn1 = block {
%v2:ptr<private, i32, read_write> = var
br %fn2 # root_end
}
%fn2 = root_terminator
-%fn3 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn4
-%fn4 = block {
- jmp %fn5 # return
+%2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn3 {
+ %fn3 = block {
+ br %fn4 # return
+ }
+ %fn4 = func_terminator
}
-%fn5 = func_terminator
-
)");
}
@@ -137,20 +138,21 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = block {
+ EXPECT_EQ(Disassemble(m.Get()), R"(# Root block
+%fn1 = block {
%v3:ptr<private, i32, read_write> = var
br %fn2 # root_end
}
%fn2 = root_terminator
-%fn3 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn4
-%fn4 = block {
- store %v3, 42i
- jmp %fn5 # return
+%2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn3 {
+ %fn3 = block {
+ store %v3, 42i
+ br %fn4 # return
+ }
+ %fn4 = func_terminator
}
-%fn5 = func_terminator
-
)");
}
diff --git a/src/tint/ir/from_program_var_test.cc b/src/tint/ir/from_program_var_test.cc
index 2e49265..85a0319 100644
--- a/src/tint/ir/from_program_var_test.cc
+++ b/src/tint/ir/from_program_var_test.cc
@@ -32,7 +32,8 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = block {
+ EXPECT_EQ(Disassemble(m.Get()), R"(# Root block
+%fn1 = block {
%a:ptr<private, u32, read_write> = var
br %fn2 # root_end
}
@@ -49,7 +50,8 @@
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
- EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = block {
+ EXPECT_EQ(Disassemble(m.Get()), R"(# Root block
+%fn1 = block {
%a:ptr<private, u32, read_write> = var, 2u
br %fn2 # root_end
}
@@ -67,13 +69,13 @@
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
EXPECT_EQ(Disassemble(m.Get()),
- R"(%fn1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn2
-%fn2 = block {
- %a:ptr<function, u32, read_write> = var
- jmp %fn3 # return
+ R"(%1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn1 {
+ %fn1 = block {
+ %a:ptr<function, u32, read_write> = var
+ br %fn2 # return
+ }
+ %fn2 = func_terminator
}
-%fn3 = func_terminator
-
)");
}
@@ -86,13 +88,13 @@
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
EXPECT_EQ(Disassemble(m.Get()),
- R"(%fn1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn2
-%fn2 = block {
- %a:ptr<function, u32, read_write> = var, 2u
- jmp %fn3 # return
+ R"(%1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] -> %fn1 {
+ %fn1 = block {
+ %a:ptr<function, u32, read_write> = var, 2u
+ br %fn2 # return
+ }
+ %fn2 = func_terminator
}
-%fn3 = func_terminator
-
)");
}
} // namespace
diff --git a/src/tint/ir/function.cc b/src/tint/ir/function.cc
index d6cd6fc..d60376f 100644
--- a/src/tint/ir/function.cc
+++ b/src/tint/ir/function.cc
@@ -19,7 +19,7 @@
namespace tint::ir {
Function::Function(Symbol name,
- type::Type* rt,
+ const type::Type* rt,
PipelineStage stage,
std::optional<std::array<uint32_t, 3>> wg_size)
: Base(), name_(name), return_type_(rt), pipeline_stage_(stage), workgroup_size_(wg_size) {}
diff --git a/src/tint/ir/function.h b/src/tint/ir/function.h
index e966e14..a9f0c1a 100644
--- a/src/tint/ir/function.h
+++ b/src/tint/ir/function.h
@@ -19,8 +19,8 @@
#include <optional>
#include <utility>
-#include "src/tint/ir/flow_node.h"
#include "src/tint/ir/function_param.h"
+#include "src/tint/ir/value.h"
#include "src/tint/symbol.h"
#include "src/tint/type/type.h"
@@ -33,7 +33,7 @@
namespace tint::ir {
/// An IR representation of a function
-class Function : public utils::Castable<Function, FlowNode> {
+class Function : public utils::Castable<Function, Value> {
public:
/// The pipeline stage for an entry point
enum class PipelineStage {
@@ -69,7 +69,7 @@
/// @param stage the function stage
/// @param wg_size the workgroup_size
Function(Symbol n,
- type::Type* rt,
+ const type::Type* rt,
PipelineStage stage = PipelineStage::kUndefined,
std::optional<std::array<uint32_t, 3>> wg_size = {});
~Function() override;
diff --git a/src/tint/ir/function_terminator.h b/src/tint/ir/function_terminator.h
index 668eab6..42aa01e 100644
--- a/src/tint/ir/function_terminator.h
+++ b/src/tint/ir/function_terminator.h
@@ -15,13 +15,13 @@
#ifndef SRC_TINT_IR_FUNCTION_TERMINATOR_H_
#define SRC_TINT_IR_FUNCTION_TERMINATOR_H_
-#include "src/tint/ir/flow_node.h"
+#include "src/tint/ir/block.h"
namespace tint::ir {
-/// Flow node used as the end of a function. Must only be used as the `end_target` in a function
-/// flow node. There are no instructions and no branches from this node.
-class FunctionTerminator : public utils::Castable<FunctionTerminator, FlowNode> {
+/// Block used as the end of a function. Must only be used as the `end_target` in a function. There
+/// are no instructions in this block.
+class FunctionTerminator : public utils::Castable<FunctionTerminator, Block> {
public:
/// Constructor
FunctionTerminator();
diff --git a/src/tint/ir/jump.cc b/src/tint/ir/jump.cc
deleted file mode 100644
index cda2f06..0000000
--- a/src/tint/ir/jump.cc
+++ /dev/null
@@ -1,25 +0,0 @@
-// Copyright 2023 The Tint Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "src/tint/ir/jump.h"
-
-TINT_INSTANTIATE_TYPEINFO(tint::ir::Jump);
-
-namespace tint::ir {
-
-Jump::Jump(FlowNode* to, utils::VectorRef<Value*> args) : Base(to, args) {}
-
-Jump::~Jump() = default;
-
-} // namespace tint::ir
diff --git a/src/tint/ir/jump.h b/src/tint/ir/jump.h
deleted file mode 100644
index 3159755..0000000
--- a/src/tint/ir/jump.h
+++ /dev/null
@@ -1,37 +0,0 @@
-// Copyright 2023 The Tint Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef SRC_TINT_IR_JUMP_H_
-#define SRC_TINT_IR_JUMP_H_
-
-#include "src/tint/ir/block.h"
-#include "src/tint/ir/branch.h"
-#include "src/tint/ir/value.h"
-#include "src/tint/utils/castable.h"
-
-namespace tint::ir {
-
-/// A jump instruction. A jump is walk continuing.
-class Jump : public utils::Castable<Jump, Branch> {
- public:
- /// Constructor
- /// @param to the block to branch too
- /// @param args the branch arguments
- explicit Jump(FlowNode* to, utils::VectorRef<Value*> args = {});
- ~Jump() override;
-};
-
-} // namespace tint::ir
-
-#endif // SRC_TINT_IR_JUMP_H_
diff --git a/src/tint/ir/load_test.cc b/src/tint/ir/load_test.cc
index 9881e98..72359e7 100644
--- a/src/tint/ir/load_test.cc
+++ b/src/tint/ir/load_test.cc
@@ -27,8 +27,8 @@
Module mod;
Builder b{mod};
- auto* store_type = b.ir.types.Get<type::I32>();
- auto* var = b.Declare(b.ir.types.Get<type::Pointer>(
+ auto* store_type = mod.Types().i32();
+ auto* var = b.Declare(mod.Types().Get<type::Pointer>(
store_type, builtin::AddressSpace::kFunction, builtin::Access::kReadWrite));
const auto* inst = b.Load(var);
@@ -45,8 +45,8 @@
Module mod;
Builder b{mod};
- auto* store_type = b.ir.types.Get<type::I32>();
- auto* var = b.Declare(b.ir.types.Get<type::Pointer>(
+ auto* store_type = mod.Types().i32();
+ auto* var = b.Declare(mod.Types().Get<type::Pointer>(
store_type, builtin::AddressSpace::kFunction, builtin::Access::kReadWrite));
const auto* inst = b.Load(var);
diff --git a/src/tint/ir/module.h b/src/tint/ir/module.h
index c6cf74c..1520888 100644
--- a/src/tint/ir/module.h
+++ b/src/tint/ir/module.h
@@ -17,7 +17,8 @@
#include <string>
-#include "src/tint/constant/value.h"
+#include "src/tint/constant/manager.h"
+#include "src/tint/ir/block.h"
#include "src/tint/ir/constant.h"
#include "src/tint/ir/function.h"
#include "src/tint/ir/instruction.h"
@@ -65,10 +66,15 @@
/// @return the unique symbol of the given value.
Symbol SetName(const Value* value, std::string_view name);
- /// The flow node allocator
- utils::BlockAllocator<FlowNode> flow_nodes;
- /// The constant allocator
- utils::BlockAllocator<constant::Value> constants_arena;
+ /// @return the type manager for the module
+ type::Manager& Types() { return constant_values.types; }
+
+ /// The block allocator
+ utils::BlockAllocator<Block> blocks;
+
+ /// The constant value manager
+ constant::Manager constant_values;
+
/// The value allocator
utils::BlockAllocator<Value> values;
@@ -78,34 +84,11 @@
/// The block containing module level declarations, if any exist.
Block* root_block = nullptr;
- /// The type manager for the module
- type::Manager types;
-
/// The symbol table for the module
SymbolTable symbols{prog_id_};
- /// ConstantHasher provides a hash function for a constant::Value pointer, hashing the value
- /// instead of the pointer itself.
- struct ConstantHasher {
- /// @param c the constant pointer to create a hash for
- /// @return the hash value
- inline std::size_t operator()(const constant::Value* c) const { return c->Hash(); }
- };
-
- /// ConstantEquals provides an equality function for two constant::Value pointers, comparing
- /// their values instead of the pointers.
- struct ConstantEquals {
- /// @param a the first constant pointer to compare
- /// @param b the second constant pointer to compare
- /// @return the hash value
- inline bool operator()(const constant::Value* a, const constant::Value* b) const {
- return a->Equal(b);
- }
- };
-
/// The map of constant::Value to their ir::Constant.
- utils::Hashmap<const constant::Value*, ir::Constant*, 16, ConstantHasher, ConstantEquals>
- constants;
+ utils::Hashmap<const constant::Value*, ir::Constant*, 16> constants;
};
} // namespace tint::ir
diff --git a/src/tint/ir/module_test.cc b/src/tint/ir/module_test.cc
index 9cd36ae..f15aa35 100644
--- a/src/tint/ir/module_test.cc
+++ b/src/tint/ir/module_test.cc
@@ -25,20 +25,20 @@
TEST_F(IR_ModuleTest, NameOfUnnamed) {
Module mod;
- auto* v = mod.values.Create<ir::Var>(mod.types.Get<type::I32>());
+ auto* v = mod.values.Create<ir::Var>(mod.Types().i32());
EXPECT_FALSE(mod.NameOf(v).IsValid());
}
TEST_F(IR_ModuleTest, SetName) {
Module mod;
- auto* v = mod.values.Create<ir::Var>(mod.types.Get<type::I32>());
+ auto* v = mod.values.Create<ir::Var>(mod.Types().i32());
EXPECT_EQ(mod.SetName(v, "a").Name(), "a");
EXPECT_EQ(mod.NameOf(v).Name(), "a");
}
TEST_F(IR_ModuleTest, SetNameRename) {
Module mod;
- auto* v = mod.values.Create<ir::Var>(mod.types.Get<type::I32>());
+ auto* v = mod.values.Create<ir::Var>(mod.Types().i32());
EXPECT_EQ(mod.SetName(v, "a").Name(), "a");
EXPECT_EQ(mod.SetName(v, "b").Name(), "b");
EXPECT_EQ(mod.NameOf(v).Name(), "b");
@@ -46,9 +46,9 @@
TEST_F(IR_ModuleTest, SetNameCollision) {
Module mod;
- auto* a = mod.values.Create<ir::Var>(mod.types.Get<type::I32>());
- auto* b = mod.values.Create<ir::Var>(mod.types.Get<type::I32>());
- auto* c = mod.values.Create<ir::Var>(mod.types.Get<type::I32>());
+ auto* a = mod.values.Create<ir::Var>(mod.Types().i32());
+ auto* b = mod.values.Create<ir::Var>(mod.Types().i32());
+ auto* c = mod.values.Create<ir::Var>(mod.Types().i32());
EXPECT_EQ(mod.SetName(a, "x").Name(), "x");
EXPECT_EQ(mod.SetName(b, "x_1").Name(), "x_1");
EXPECT_EQ(mod.SetName(c, "x").Name(), "x_2");
diff --git a/src/tint/ir/root_terminator.h b/src/tint/ir/root_terminator.h
index 361aa6d..4a52b32 100644
--- a/src/tint/ir/root_terminator.h
+++ b/src/tint/ir/root_terminator.h
@@ -15,13 +15,12 @@
#ifndef SRC_TINT_IR_ROOT_TERMINATOR_H_
#define SRC_TINT_IR_ROOT_TERMINATOR_H_
-#include "src/tint/ir/flow_node.h"
+#include "src/tint/ir/block.h"
namespace tint::ir {
-/// Flow node used as the end of a function. Must only be used as the `end_target` in a function
-/// flow node. There are no instructions and no branches from this node.
-class RootTerminator : public utils::Castable<RootTerminator, FlowNode> {
+/// Block used as the end of a root block. There are no instructions in this block.
+class RootTerminator : public utils::Castable<RootTerminator, Block> {
public:
/// Constructor
RootTerminator();
diff --git a/src/tint/ir/to_program.cc b/src/tint/ir/to_program.cc
index c1cb9f9..842f41d 100644
--- a/src/tint/ir/to_program.cc
+++ b/src/tint/ir/to_program.cc
@@ -23,7 +23,6 @@
#include "src/tint/ir/function_terminator.h"
#include "src/tint/ir/if.h"
#include "src/tint/ir/instruction.h"
-#include "src/tint/ir/jump.h"
#include "src/tint/ir/load.h"
#include "src/tint/ir/module.h"
#include "src/tint/ir/store.h"
@@ -99,7 +98,7 @@
if (!ret_ty) {
return nullptr;
}
- auto* body = FlowNodeGraph(fn->StartTarget());
+ auto* body = BlockGraph(fn->StartTarget());
if (!body) {
return nullptr;
}
@@ -109,13 +108,13 @@
std::move(ret_attrs));
}
- const ast::BlockStatement* FlowNodeGraph(const ir::Block* start_node) {
+ const ast::BlockStatement* BlockGraph(const ir::Block* start_node) {
// TODO(crbug.com/tint/1902): Check if the block is dead
utils::Vector<const ast::Statement*,
decltype(ast::BlockStatement::statements)::static_length>
stmts;
- const ir::FlowNode* block = start_node;
+ const ir::Block* block = start_node;
// TODO(crbug.com/tint/1902): Handle block arguments.
@@ -127,6 +126,8 @@
Status status = tint::Switch(
block,
+ [&](const ir::FunctionTerminator*) { return kStop; },
+
[&](const ir::Block* blk) {
for (auto* inst : blk->Instructions()) {
auto stmt = Stmt(inst);
@@ -137,10 +138,7 @@
stmts.Push(s);
}
}
- if (blk->Branch()->Is<Jump>() && blk->Branch()->To()->Is<Block>()) {
- block = blk->Branch()->To()->As<Block>();
- return kContinue;
- } else if (auto* if_ = blk->Branch()->As<ir::If>()) {
+ if (auto* if_ = blk->Branch()->As<ir::If>()) {
if (if_->Merge()->HasBranchTarget()) {
block = if_->Merge();
return kContinue;
@@ -154,8 +152,6 @@
return kStop;
},
- [&](const ir::FunctionTerminator*) { return kStop; },
-
[&](Default) {
UNHANDLED_CASE(block);
return kError;
@@ -175,7 +171,7 @@
const ast::IfStatement* If(const ir::If* i) {
SCOPED_NESTING();
auto* cond = Expr(i->Condition());
- auto* t = FlowNodeGraph(i->True());
+ auto* t = BlockGraph(i->True());
if (TINT_UNLIKELY(!t)) {
return nullptr;
}
@@ -191,7 +187,7 @@
}
return b.If(cond, t, b.Else(f));
} else {
- auto* f = FlowNodeGraph(i->False());
+ auto* f = BlockGraph(i->False());
if (!f) {
return nullptr;
}
@@ -214,7 +210,7 @@
s->Cases(), //
[&](const ir::Switch::Case c) -> const tint::ast::CaseStatement* {
SCOPED_NESTING();
- auto* body = FlowNodeGraph(c.start);
+ auto* body = BlockGraph(c.start);
if (!body) {
return nullptr;
}
@@ -274,7 +270,7 @@
}
/// @return true if there are no instructions between @p node and and @p stop_at
- bool IsEmpty(const ir::Block* node, const ir::FlowNode* stop_at) {
+ bool IsEmpty(const ir::Block* node, const ir::Block* stop_at) {
if (node->Instructions().IsEmpty()) {
return true;
}
diff --git a/src/tint/ir/transform/add_empty_entry_point.cc b/src/tint/ir/transform/add_empty_entry_point.cc
index f40c60e..2372ef54 100644
--- a/src/tint/ir/transform/add_empty_entry_point.cc
+++ b/src/tint/ir/transform/add_empty_entry_point.cc
@@ -35,9 +35,8 @@
}
ir::Builder builder(*ir);
- auto* ep =
- builder.CreateFunction(ir->symbols.New("unused_entry_point"), ir->types.Get<type::Void>(),
- Function::PipelineStage::kCompute, std::array{1u, 1u, 1u});
+ auto* ep = builder.CreateFunction(ir->symbols.New("unused_entry_point"), ir->Types().void_(),
+ Function::PipelineStage::kCompute, std::array{1u, 1u, 1u});
ep->StartTarget()->SetInstructions(utils::Vector{builder.Branch(ep->EndTarget())});
ir->functions.Push(ep);
}
diff --git a/src/tint/ir/transform/add_empty_entry_point_test.cc b/src/tint/ir/transform/add_empty_entry_point_test.cc
index c86c09c..fec4192 100644
--- a/src/tint/ir/transform/add_empty_entry_point_test.cc
+++ b/src/tint/ir/transform/add_empty_entry_point_test.cc
@@ -25,12 +25,12 @@
TEST_F(IR_AddEmptyEntryPointTest, EmptyModule) {
auto* expect = R"(
-%fn1 = func unused_entry_point():void [@compute @workgroup_size(1, 1, 1)] -> %fn2
-%fn2 = block {
- br %fn3 # return
+%1 = func unused_entry_point():void [@compute @workgroup_size(1, 1, 1)] -> %fn1 {
+ %fn1 = block {
+ br %fn2 # return
+ }
+ %fn2 = func_terminator
}
-%fn3 = func_terminator
-
)";
Run<AddEmptyEntryPoint>();
@@ -39,18 +39,17 @@
}
TEST_F(IR_AddEmptyEntryPointTest, ExistingEntryPoint) {
- auto* ep =
- b.CreateFunction("main", mod.types.Get<type::Void>(), Function::PipelineStage::kFragment);
+ auto* ep = b.CreateFunction("main", mod.Types().void_(), Function::PipelineStage::kFragment);
ep->StartTarget()->SetInstructions(utils::Vector{b.Branch(ep->EndTarget())});
mod.functions.Push(ep);
auto* expect = R"(
-%fn1 = func main():void [@fragment] -> %fn2
-%fn2 = block {
- br %fn3 # return
+%1 = func main():void [@fragment] -> %fn1 {
+ %fn1 = block {
+ br %fn2 # return
+ }
+ %fn2 = func_terminator
}
-%fn3 = func_terminator
-
)";
Run<AddEmptyEntryPoint>();
diff --git a/src/tint/ir/unary_test.cc b/src/tint/ir/unary_test.cc
index bb0b4f2..92ad24c 100644
--- a/src/tint/ir/unary_test.cc
+++ b/src/tint/ir/unary_test.cc
@@ -26,7 +26,7 @@
TEST_F(IR_InstructionTest, CreateComplement) {
Module mod;
Builder b{mod};
- auto* inst = b.Complement(b.ir.types.Get<type::I32>(), b.Constant(4_i));
+ auto* inst = b.Complement(mod.Types().i32(), b.Constant(4_i));
ASSERT_TRUE(inst->Is<Unary>());
EXPECT_EQ(inst->Kind(), Unary::Kind::kComplement);
@@ -40,7 +40,7 @@
TEST_F(IR_InstructionTest, CreateNegation) {
Module mod;
Builder b{mod};
- auto* inst = b.Negation(b.ir.types.Get<type::I32>(), b.Constant(4_i));
+ auto* inst = b.Negation(mod.Types().i32(), b.Constant(4_i));
ASSERT_TRUE(inst->Is<Unary>());
EXPECT_EQ(inst->Kind(), Unary::Kind::kNegation);
@@ -54,7 +54,7 @@
TEST_F(IR_InstructionTest, Unary_Usage) {
Module mod;
Builder b{mod};
- auto* inst = b.Negation(b.ir.types.Get<type::I32>(), b.Constant(4_i));
+ auto* inst = b.Negation(mod.Types().i32(), b.Constant(4_i));
EXPECT_EQ(inst->Kind(), Unary::Kind::kNegation);
diff --git a/src/tint/program.cc b/src/tint/program.cc
index a1cf5c8..92a9f02 100644
--- a/src/tint/program.cc
+++ b/src/tint/program.cc
@@ -37,10 +37,9 @@
Program::Program(Program&& program)
: id_(std::move(program.id_)),
highest_node_id_(std::move(program.highest_node_id_)),
- types_(std::move(program.types_)),
+ constants_(std::move(program.constants_)),
ast_nodes_(std::move(program.ast_nodes_)),
sem_nodes_(std::move(program.sem_nodes_)),
- constant_nodes_(std::move(program.constant_nodes_)),
ast_(std::move(program.ast_)),
sem_(std::move(program.sem_)),
symbols_(std::move(program.symbols_)),
@@ -63,10 +62,9 @@
}
// The above must be called *before* the calls to std::move() below
- types_ = std::move(builder.Types());
+ constants_ = std::move(builder.constants);
ast_nodes_ = std::move(builder.ASTNodes());
sem_nodes_ = std::move(builder.SemNodes());
- constant_nodes_ = std::move(builder.ConstantNodes());
ast_ = &builder.AST(); // ast::Module is actually a heap allocation.
sem_ = std::move(builder.Sem());
symbols_ = std::move(builder.Symbols());
@@ -89,10 +87,9 @@
moved_ = false;
id_ = std::move(program.id_);
highest_node_id_ = std::move(program.highest_node_id_);
- types_ = std::move(program.types_);
+ constants_ = std::move(program.constants_);
ast_nodes_ = std::move(program.ast_nodes_);
sem_nodes_ = std::move(program.sem_nodes_);
- constant_nodes_ = std::move(program.constant_nodes_);
ast_ = std::move(program.ast_);
sem_ = std::move(program.sem_);
symbols_ = std::move(program.symbols_);
diff --git a/src/tint/program.h b/src/tint/program.h
index ae46acd..940a6ec 100644
--- a/src/tint/program.h
+++ b/src/tint/program.h
@@ -19,7 +19,7 @@
#include <unordered_set>
#include "src/tint/ast/function.h"
-#include "src/tint/constant/value.h"
+#include "src/tint/constant/manager.h"
#include "src/tint/program_id.h"
#include "src/tint/sem/info.h"
#include "src/tint/symbol_table.h"
@@ -44,9 +44,6 @@
/// SemNodeAllocator is an alias to BlockAllocator<sem::Node>
using SemNodeAllocator = utils::BlockAllocator<sem::Node>;
- /// ConstantAllocator is an alias to BlockAllocator<constant::Value>
- using ConstantAllocator = utils::BlockAllocator<constant::Value>;
-
/// Constructor
Program();
@@ -72,10 +69,16 @@
/// @returns the last allocated (numerically highest) AST node identifier.
ast::NodeID HighestASTNodeID() const { return highest_node_id_; }
+ /// @returns a reference to the program's constants
+ const constant::Manager& Constants() const {
+ AssertNotMoved();
+ return constants_;
+ }
+
/// @returns a reference to the program's types
const type::Manager& Types() const {
AssertNotMoved();
- return types_;
+ return constants_.types;
}
/// @returns a reference to the program's AST nodes storage
@@ -165,10 +168,9 @@
ProgramID id_;
ast::NodeID highest_node_id_;
- type::Manager types_;
+ constant::Manager constants_;
ASTNodeAllocator ast_nodes_;
SemNodeAllocator sem_nodes_;
- ConstantAllocator constant_nodes_;
ast::Module* ast_ = nullptr;
sem::Info sem_;
SymbolTable symbols_{id_};
diff --git a/src/tint/program_builder.cc b/src/tint/program_builder.cc
index fc2e795..20a3fc4 100644
--- a/src/tint/program_builder.cc
+++ b/src/tint/program_builder.cc
@@ -38,9 +38,9 @@
ast_(ast_nodes_.Create<ast::Module>(id_, AllocateNodeID(), Source{})) {}
ProgramBuilder::ProgramBuilder(ProgramBuilder&& rhs)
- : id_(std::move(rhs.id_)),
+ : constants(std::move(rhs.constants)),
+ id_(std::move(rhs.id_)),
last_ast_node_id_(std::move(rhs.last_ast_node_id_)),
- types_(std::move(rhs.types_)),
ast_nodes_(std::move(rhs.ast_nodes_)),
sem_nodes_(std::move(rhs.sem_nodes_)),
ast_(std::move(rhs.ast_)),
@@ -57,7 +57,7 @@
AssertNotMoved();
id_ = std::move(rhs.id_);
last_ast_node_id_ = std::move(rhs.last_ast_node_id_);
- types_ = std::move(rhs.types_);
+ constants = std::move(rhs.constants);
ast_nodes_ = std::move(rhs.ast_nodes_);
sem_nodes_ = std::move(rhs.sem_nodes_);
ast_ = std::move(rhs.ast_);
@@ -72,7 +72,7 @@
ProgramBuilder builder;
builder.id_ = program->ID();
builder.last_ast_node_id_ = program->HighestASTNodeID();
- builder.types_ = type::Manager::Wrap(program->Types());
+ builder.constants = constant::Manager::Wrap(program->Constants());
builder.ast_ =
builder.create<ast::Module>(program->AST().source, program->AST().GlobalDeclarations());
builder.sem_ = sem::Info::Wrap(program->Sem());
@@ -136,39 +136,4 @@
});
}
-const constant::Value* ProgramBuilder::createSplatOrComposite(
- const type::Type* type,
- utils::VectorRef<const constant::Value*> elements) {
- if (elements.IsEmpty()) {
- return nullptr;
- }
-
- bool any_zero = false;
- bool all_zero = true;
- bool all_equal = true;
- auto* first = elements.Front();
- for (auto* el : elements) {
- if (!el) {
- return nullptr;
- }
- if (!any_zero && el->AnyZero()) {
- any_zero = true;
- }
- if (all_zero && !el->AllZero()) {
- all_zero = false;
- }
- if (all_equal && el != first) {
- if (!el->Equal(first)) {
- all_equal = false;
- }
- }
- }
- if (all_equal) {
- return create<constant::Splat>(type, elements[0], elements.Length());
- }
-
- return constant_nodes_.Create<constant::Composite>(type, std::move(elements), all_zero,
- any_zero);
-}
-
} // namespace tint
diff --git a/src/tint/program_builder.h b/src/tint/program_builder.h
index b944e21..459edaa 100644
--- a/src/tint/program_builder.h
+++ b/src/tint/program_builder.h
@@ -78,9 +78,7 @@
#include "src/tint/builtin/extension.h"
#include "src/tint/builtin/interpolation_sampling.h"
#include "src/tint/builtin/interpolation_type.h"
-#include "src/tint/constant/composite.h"
-#include "src/tint/constant/splat.h"
-#include "src/tint/constant/value.h"
+#include "src/tint/constant/manager.h"
#include "src/tint/number.h"
#include "src/tint/program.h"
#include "src/tint/program_id.h"
@@ -328,9 +326,6 @@
/// SemNodeAllocator is an alias to BlockAllocator<sem::Node>
using SemNodeAllocator = utils::BlockAllocator<sem::Node>;
- /// ConstantAllocator is an alias to BlockAllocator<constant::Value>
- using ConstantAllocator = utils::BlockAllocator<constant::Value>;
-
/// Constructor
ProgramBuilder();
@@ -364,13 +359,13 @@
/// @returns a reference to the program's types
type::Manager& Types() {
AssertNotMoved();
- return types_;
+ return constants.types;
}
/// @returns a reference to the program's types
const type::Manager& Types() const {
AssertNotMoved();
- return types_;
+ return constants.types;
}
/// @returns a reference to the program's AST nodes storage
@@ -397,12 +392,6 @@
return sem_nodes_;
}
- /// @returns a reference to the program's semantic constant storage
- ConstantAllocator& ConstantNodes() {
- AssertNotMoved();
- return constant_nodes_;
- }
-
/// @returns a reference to the program's AST root Module
ast::Module& AST() {
AssertNotMoved();
@@ -528,53 +517,6 @@
return sem_nodes_.Create<T>(std::forward<ARGS>(args)...);
}
- /// Creates a new constant::Value owned by the ProgramBuilder.
- /// When the ProgramBuilder is destructed, the sem::Node will also be destructed.
- /// @param args the arguments to pass to the constructor
- /// @returns the node pointer
- template <typename T, typename... ARGS>
- utils::traits::EnableIf<utils::traits::IsTypeOrDerived<T, constant::Value> &&
- !utils::traits::IsTypeOrDerived<T, constant::Composite> &&
- !utils::traits::IsTypeOrDerived<T, constant::Splat>,
- T>*
- create(ARGS&&... args) {
- AssertNotMoved();
- return constant_nodes_.Create<T>(std::forward<ARGS>(args)...);
- }
-
- /// Constructs a constant of a vector, matrix or array type.
- ///
- /// Examines the element values and will return either a constant::Composite or a
- /// constant::Splat, depending on the element types and values.
- ///
- /// @param type the composite type
- /// @param elements the composite elements
- /// @returns the node pointer
- template <
- typename T,
- typename = utils::traits::EnableIf<utils::traits::IsTypeOrDerived<T, constant::Composite> ||
- utils::traits::IsTypeOrDerived<T, constant::Splat>>>
- const constant::Value* create(const type::Type* type,
- utils::VectorRef<const constant::Value*> elements) {
- AssertNotMoved();
- return createSplatOrComposite(type, elements);
- }
-
- /// Constructs a splat constant.
- /// @param type the splat type
- /// @param element the splat element
- /// @param n the number of elements
- /// @returns the node pointer
- template <
- typename T,
- typename = utils::traits::EnableIf<utils::traits::IsTypeOrDerived<T, constant::Splat>>>
- const constant::Splat* create(const type::Type* type,
- const constant::Value* element,
- size_t n) {
- AssertNotMoved();
- return constant_nodes_.Create<constant::Splat>(type, element, n);
- }
-
/// Creates a new type::Node owned by the ProgramBuilder.
/// When the ProgramBuilder is destructed, owned ProgramBuilder and the returned node will also
/// be destructed. If T derives from type::UniqueNode, then the calling create() for the same
@@ -584,7 +526,7 @@
template <typename T, typename... ARGS>
utils::traits::EnableIfIsType<T, type::Node>* create(ARGS&&... args) {
AssertNotMoved();
- return types_.Get<T>(std::forward<ARGS>(args)...);
+ return constants.types.Get<T>(std::forward<ARGS>(args)...);
}
/// Marks this builder as moved, preventing any further use of the builder.
@@ -3953,6 +3895,9 @@
/// @returns the function
const ast::Function* WrapInFunction(utils::VectorRef<const ast::Statement*> stmts);
+ /// The constants manager
+ constant::Manager constants;
+
/// The builder types
TypesBuilder const ty{this};
@@ -3961,16 +3906,10 @@
void AssertNotMoved() const;
private:
- const constant::Value* createSplatOrComposite(
- const type::Type* type,
- utils::VectorRef<const constant::Value*> elements);
-
ProgramID id_;
ast::NodeID last_ast_node_id_ = ast::NodeID{static_cast<decltype(ast::NodeID::value)>(0) - 1};
- type::Manager types_;
ASTNodeAllocator ast_nodes_;
SemNodeAllocator sem_nodes_;
- ConstantAllocator constant_nodes_;
ast::Module* ast_;
sem::Info sem_;
SymbolTable symbols_{id_};
diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc
index 35ea8bb..c181f02 100644
--- a/src/tint/resolver/const_eval.cc
+++ b/src/tint/resolver/const_eval.cc
@@ -261,13 +261,15 @@
using FROM = T;
if constexpr (std::is_same_v<TO, bool>) {
// [x -> bool]
- return builder.create<constant::Scalar<TO>>(target_ty, !scalar->IsPositiveZero());
+ return builder.constants.Get<constant::Scalar<TO>>(target_ty,
+ !scalar->IsPositiveZero());
} else if constexpr (std::is_same_v<FROM, bool>) {
// [bool -> x]
- return builder.create<constant::Scalar<TO>>(target_ty, TO(scalar->value ? 1 : 0));
+ return builder.constants.Get<constant::Scalar<TO>>(target_ty,
+ TO(scalar->value ? 1 : 0));
} else if (auto conv = CheckedConvert<TO>(scalar->value)) {
// Conversion success
- return builder.create<constant::Scalar<TO>>(target_ty, conv.Get());
+ return builder.constants.Get<constant::Scalar<TO>>(target_ty, conv.Get());
// --- Below this point are the failure cases ---
} else if constexpr (IsAbstract<FROM>) {
// [abstract-numeric -> x] - materialization failure
@@ -276,9 +278,10 @@
builder.Diagnostics().add_warning(tint::diag::System::Resolver, msg, source);
switch (conv.Failure()) {
case ConversionFailure::kExceedsNegativeLimit:
- return builder.create<constant::Scalar<TO>>(target_ty, TO::Lowest());
+ return builder.constants.Get<constant::Scalar<TO>>(target_ty, TO::Lowest());
case ConversionFailure::kExceedsPositiveLimit:
- return builder.create<constant::Scalar<TO>>(target_ty, TO::Highest());
+ return builder.constants.Get<constant::Scalar<TO>>(target_ty,
+ TO::Highest());
}
} else {
builder.Diagnostics().add_error(tint::diag::System::Resolver, msg, source);
@@ -292,9 +295,10 @@
builder.Diagnostics().add_warning(tint::diag::System::Resolver, msg, source);
switch (conv.Failure()) {
case ConversionFailure::kExceedsNegativeLimit:
- return builder.create<constant::Scalar<TO>>(target_ty, TO::Lowest());
+ return builder.constants.Get<constant::Scalar<TO>>(target_ty, TO::Lowest());
case ConversionFailure::kExceedsPositiveLimit:
- return builder.create<constant::Scalar<TO>>(target_ty, TO::Highest());
+ return builder.constants.Get<constant::Scalar<TO>>(target_ty,
+ TO::Highest());
}
} else {
builder.Diagnostics().add_error(tint::diag::System::Resolver, msg, source);
@@ -305,14 +309,15 @@
// https://www.w3.org/TR/WGSL/#floating-point-conversion
switch (conv.Failure()) {
case ConversionFailure::kExceedsNegativeLimit:
- return builder.create<constant::Scalar<TO>>(target_ty, TO::Lowest());
+ return builder.constants.Get<constant::Scalar<TO>>(target_ty, TO::Lowest());
case ConversionFailure::kExceedsPositiveLimit:
- return builder.create<constant::Scalar<TO>>(target_ty, TO::Highest());
+ return builder.constants.Get<constant::Scalar<TO>>(target_ty, TO::Highest());
}
} else if constexpr (IsIntegral<FROM>) {
// [integer -> integer] - number not exactly representable
// Static cast
- return builder.create<constant::Scalar<TO>>(target_ty, static_cast<TO>(scalar->value));
+ return builder.constants.Get<constant::Scalar<TO>>(target_ty,
+ static_cast<TO>(scalar->value));
}
return nullptr; // Expression is not constant.
});
@@ -362,7 +367,7 @@
}
conv_els.Push(conv_el.Get());
}
- return builder.create<constant::Composite>(target_ty, std::move(conv_els));
+ return builder.constants.Composite(target_ty, std::move(conv_els));
}
ConstEval::Result SplatConvert(const constant::Splat* splat,
@@ -396,7 +401,7 @@
if (!conv_el.Get()) {
return nullptr;
}
- return builder.create<constant::Splat>(target_ty, conv_el.Get(), splat->count);
+ return builder.constants.Splat(target_ty, conv_el.Get(), splat->count);
}
ConstEval::Result ConvertInternal(const constant::Value* c,
@@ -466,7 +471,7 @@
return el.Failure();
}
}
- return builder.create<constant::Composite>(composite_ty, std::move(els));
+ return builder.constants.Composite(composite_ty, std::move(els));
}
} // namespace detail
@@ -520,7 +525,7 @@
return el.Failure();
}
}
- return builder.create<constant::Composite>(composite_ty, std::move(els));
+ return builder.constants.Composite(composite_ty, std::move(els));
}
} // namespace
@@ -542,7 +547,7 @@
}
}
}
- return builder.create<constant::Scalar<T>>(t, v);
+ return builder.constants.Get<constant::Scalar<T>>(t, v);
}
const constant::Value* ConstEval::ZeroValue(const type::Type* type) {
@@ -550,16 +555,16 @@
type, //
[&](const type::Vector* v) -> const constant::Value* {
auto* zero_el = ZeroValue(v->type());
- return builder.create<constant::Splat>(type, zero_el, v->Width());
+ return builder.constants.Splat(type, zero_el, v->Width());
},
[&](const type::Matrix* m) -> const constant::Value* {
auto* zero_el = ZeroValue(m->ColumnType());
- return builder.create<constant::Splat>(type, zero_el, m->columns());
+ return builder.constants.Splat(type, zero_el, m->columns());
},
[&](const type::Array* a) -> const constant::Value* {
if (auto n = a->ConstantCount()) {
if (auto* zero_el = ZeroValue(a->ElemType())) {
- return builder.create<constant::Splat>(type, zero_el, n.value());
+ return builder.constants.Splat(type, zero_el, n.value());
}
}
return nullptr;
@@ -578,9 +583,9 @@
}
if (zero_by_type.Count() == 1) {
// All members were of the same type, so the zero value is the same for all members.
- return builder.create<constant::Splat>(type, zeros[0], s->Members().Length());
+ return builder.constants.Splat(type, zeros[0], s->Members().Length());
}
- return builder.create<constant::Composite>(s, std::move(zeros));
+ return builder.constants.Composite(s, std::move(zeros));
},
[&](Default) -> const constant::Value* {
return ZeroTypeDispatch(type, [&](auto zero) -> const constant::Value* {
@@ -1260,7 +1265,7 @@
}
// Multiple arguments. Must be a value constructor.
- return builder.create<constant::Composite>(ty, std::move(args));
+ return builder.constants.Composite(ty, std::move(args));
}
ConstEval::Result ConstEval::Conv(const type::Type* ty,
@@ -1295,8 +1300,7 @@
utils::VectorRef<const constant::Value*> args,
const Source&) {
if (auto* arg = args[0]) {
- return builder.create<constant::Splat>(ty, arg,
- static_cast<const type::Vector*>(ty)->Width());
+ return builder.constants.Splat(ty, arg, static_cast<const type::Vector*>(ty)->Width());
}
return nullptr;
}
@@ -1304,7 +1308,7 @@
ConstEval::Result ConstEval::VecInitS(const type::Type* ty,
utils::VectorRef<const constant::Value*> args,
const Source&) {
- return builder.create<constant::Composite>(ty, args);
+ return builder.constants.Composite(ty, args);
}
ConstEval::Result ConstEval::VecInitM(const type::Type* ty,
@@ -1330,7 +1334,7 @@
els.Push(val);
}
}
- return builder.create<constant::Composite>(ty, std::move(els));
+ return builder.constants.Composite(ty, std::move(els));
}
ConstEval::Result ConstEval::MatInitS(const type::Type* ty,
@@ -1345,15 +1349,15 @@
auto i = r + c * m->rows();
column.Push(args[i]);
}
- els.Push(builder.create<constant::Composite>(m->ColumnType(), std::move(column)));
+ els.Push(builder.constants.Composite(m->ColumnType(), std::move(column)));
}
- return builder.create<constant::Composite>(ty, std::move(els));
+ return builder.constants.Composite(ty, std::move(els));
}
ConstEval::Result ConstEval::MatInitV(const type::Type* ty,
utils::VectorRef<const constant::Value*> args,
const Source&) {
- return builder.create<constant::Composite>(ty, args);
+ return builder.constants.Composite(ty, args);
}
ConstEval::Result ConstEval::Index(const type::Type* ty,
@@ -1411,7 +1415,7 @@
}
auto values = utils::Transform<4>(
indices, [&](uint32_t i) { return vec_val->Index(static_cast<size_t>(i)); });
- return builder.create<constant::Composite>(ty, std::move(values));
+ return builder.constants.Composite(ty, std::move(values));
}
ConstEval::Result ConstEval::Bitcast(const type::Type* ty,
@@ -1557,7 +1561,7 @@
}
result.Push(r.Get());
}
- return builder.create<constant::Composite>(ty, result);
+ return builder.constants.Composite(ty, result);
}
ConstEval::Result ConstEval::OpMultiplyVecMat(const type::Type* ty,
utils::VectorRef<const constant::Value*> args,
@@ -1607,7 +1611,7 @@
}
result.Push(r.Get());
}
- return builder.create<constant::Composite>(ty, result);
+ return builder.constants.Composite(ty, result);
}
ConstEval::Result ConstEval::OpMultiplyMatMat(const type::Type* ty,
@@ -1669,9 +1673,9 @@
// Add column vector to matrix
auto* col_vec_ty = ty->As<type::Matrix>()->ColumnType();
- result_mat.Push(builder.create<constant::Composite>(col_vec_ty, col_vec));
+ result_mat.Push(builder.constants.Composite(col_vec_ty, col_vec));
}
- return builder.create<constant::Composite>(ty, result_mat);
+ return builder.constants.Composite(ty, result_mat);
}
ConstEval::Result ConstEval::OpDivide(const type::Type* ty,
@@ -2311,7 +2315,7 @@
return utils::Failure;
}
- return builder.create<constant::Composite>(
+ return builder.constants.Composite(
ty, utils::Vector<const constant::Value*, 3>{x.Get(), y.Get(), z.Get()});
}
@@ -2707,20 +2711,20 @@
}
auto fract_ty = builder.create<type::Vector>(fract_els[0]->Type(), vec->Width());
auto exp_ty = builder.create<type::Vector>(exp_els[0]->Type(), vec->Width());
- return builder.create<constant::Composite>(
+ return builder.constants.Composite(
ty, utils::Vector<const constant::Value*, 2>{
- builder.create<constant::Composite>(fract_ty, std::move(fract_els)),
- builder.create<constant::Composite>(exp_ty, std::move(exp_els)),
+ builder.constants.Composite(fract_ty, std::move(fract_els)),
+ builder.constants.Composite(exp_ty, std::move(exp_els)),
});
} else {
auto fe = scalar(arg);
if (!fe.fract || !fe.exp) {
return utils::Failure;
}
- return builder.create<constant::Composite>(ty, utils::Vector<const constant::Value*, 2>{
- fe.fract.Get(),
- fe.exp.Get(),
- });
+ return builder.constants.Composite(ty, utils::Vector<const constant::Value*, 2>{
+ fe.fract.Get(),
+ fe.exp.Get(),
+ });
}
}
@@ -3014,7 +3018,7 @@
return utils::Failure;
}
- return builder.create<constant::Composite>(ty, std::move(fields));
+ return builder.constants.Composite(ty, std::move(fields));
}
ConstEval::Result ConstEval::normalize(const type::Type* ty,
@@ -3600,10 +3604,9 @@
for (size_t c = 0; c < mat_ty->columns(); ++c) {
new_col_vec.Push(me(r, c));
}
- result_mat.Push(
- builder.create<constant::Composite>(result_mat_ty->ColumnType(), new_col_vec));
+ result_mat.Push(builder.constants.Composite(result_mat_ty->ColumnType(), new_col_vec));
}
- return builder.create<constant::Composite>(ty, result_mat);
+ return builder.constants.Composite(ty, result_mat);
}
ConstEval::Result ConstEval::trunc(const type::Type* ty,
@@ -3643,7 +3646,7 @@
}
els.Push(el.Get());
}
- return builder.create<constant::Composite>(ty, std::move(els));
+ return builder.constants.Composite(ty, std::move(els));
}
ConstEval::Result ConstEval::unpack2x16snorm(const type::Type* ty,
@@ -3663,7 +3666,7 @@
}
els.Push(el.Get());
}
- return builder.create<constant::Composite>(ty, std::move(els));
+ return builder.constants.Composite(ty, std::move(els));
}
ConstEval::Result ConstEval::unpack2x16unorm(const type::Type* ty,
@@ -3682,7 +3685,7 @@
}
els.Push(el.Get());
}
- return builder.create<constant::Composite>(ty, std::move(els));
+ return builder.constants.Composite(ty, std::move(els));
}
ConstEval::Result ConstEval::unpack4x8snorm(const type::Type* ty,
@@ -3702,7 +3705,7 @@
}
els.Push(el.Get());
}
- return builder.create<constant::Composite>(ty, std::move(els));
+ return builder.constants.Composite(ty, std::move(els));
}
ConstEval::Result ConstEval::unpack4x8unorm(const type::Type* ty,
@@ -3721,7 +3724,7 @@
}
els.Push(el.Get());
}
- return builder.create<constant::Composite>(ty, std::move(els));
+ return builder.constants.Composite(ty, std::move(els));
}
ConstEval::Result ConstEval::quantizeToF16(const type::Type* ty,
diff --git a/src/tint/resolver/const_eval_conversion_test.cc b/src/tint/resolver/const_eval_conversion_test.cc
index 350afb2..468d2b4 100644
--- a/src/tint/resolver/const_eval_conversion_test.cc
+++ b/src/tint/resolver/const_eval_conversion_test.cc
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include "src/tint/constant/splat.h"
#include "src/tint/resolver/const_eval_test.h"
#include "src/tint/sem/materialize.h"
diff --git a/src/tint/resolver/const_eval_runtime_semantics_test.cc b/src/tint/resolver/const_eval_runtime_semantics_test.cc
index eb0cf52..2e41482 100644
--- a/src/tint/resolver/const_eval_runtime_semantics_test.cc
+++ b/src/tint/resolver/const_eval_runtime_semantics_test.cc
@@ -37,29 +37,11 @@
diag::Formatter formatter{style};
return formatter.format(Diagnostics());
}
-
- /// Helper to make a scalar constant::Value from a value.
- template <typename T>
- const constant::Value* Scalar(T value) {
- if constexpr (IsAbstract<T>) {
- if constexpr (IsFloatingPoint<T>) {
- return create<constant::Scalar<AFloat>>(create<type::AbstractFloat>(), value);
- } else if constexpr (IsIntegral<T>) {
- return create<constant::Scalar<AInt>>(create<type::AbstractInt>(), value);
- }
- } else if constexpr (IsFloatingPoint<T>) {
- return create<constant::Scalar<f32>>(create<type::F32>(), value);
- } else if constexpr (IsSignedIntegral<T>) {
- return create<constant::Scalar<i32>>(create<type::I32>(), value);
- } else if constexpr (IsUnsignedIntegral<T>) {
- return create<constant::Scalar<u32>>(create<type::U32>(), value);
- }
- }
};
TEST_F(ResolverConstEvalRuntimeSemanticsTest, Add_AInt_Overflow) {
- auto* a = Scalar(AInt::Highest());
- auto* b = Scalar(AInt(1));
+ auto* a = constants.Get(AInt::Highest());
+ auto* b = constants.Get(AInt(1));
auto result = const_eval.OpPlus(a->Type(), utils::Vector{a, b}, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<AInt>(), 0);
@@ -68,8 +50,8 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, Add_AFloat_Overflow) {
- auto* a = Scalar(AFloat::Highest());
- auto* b = Scalar(AFloat::Highest());
+ auto* a = constants.Get(AFloat::Highest());
+ auto* b = constants.Get(AFloat::Highest());
auto result = const_eval.OpPlus(a->Type(), utils::Vector{a, b}, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<AFloat>(), 0.f);
@@ -79,8 +61,8 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, Add_F32_Overflow) {
- auto* a = Scalar(f32::Highest());
- auto* b = Scalar(f32::Highest());
+ auto* a = constants.Get(f32::Highest());
+ auto* b = constants.Get(f32::Highest());
auto result = const_eval.OpPlus(a->Type(), utils::Vector{a, b}, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<f32>(), 0.f);
@@ -90,8 +72,8 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, Sub_AInt_Overflow) {
- auto* a = Scalar(AInt::Lowest());
- auto* b = Scalar(AInt(1));
+ auto* a = constants.Get(AInt::Lowest());
+ auto* b = constants.Get(AInt(1));
auto result = const_eval.OpMinus(a->Type(), utils::Vector{a, b}, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<AInt>(), 0);
@@ -100,8 +82,8 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, Sub_AFloat_Overflow) {
- auto* a = Scalar(AFloat::Lowest());
- auto* b = Scalar(AFloat::Highest());
+ auto* a = constants.Get(AFloat::Lowest());
+ auto* b = constants.Get(AFloat::Highest());
auto result = const_eval.OpMinus(a->Type(), utils::Vector{a, b}, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<AFloat>(), 0.f);
@@ -111,8 +93,8 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, Sub_F32_Overflow) {
- auto* a = Scalar(f32::Lowest());
- auto* b = Scalar(f32::Highest());
+ auto* a = constants.Get(f32::Lowest());
+ auto* b = constants.Get(f32::Highest());
auto result = const_eval.OpMinus(a->Type(), utils::Vector{a, b}, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<f32>(), 0.f);
@@ -122,8 +104,8 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, Mul_AInt_Overflow) {
- auto* a = Scalar(AInt::Highest());
- auto* b = Scalar(AInt(2));
+ auto* a = constants.Get(AInt::Highest());
+ auto* b = constants.Get(AInt(2));
auto result = const_eval.OpMultiply(a->Type(), utils::Vector{a, b}, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<AInt>(), 0);
@@ -132,8 +114,8 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, Mul_AFloat_Overflow) {
- auto* a = Scalar(AFloat::Highest());
- auto* b = Scalar(AFloat::Highest());
+ auto* a = constants.Get(AFloat::Highest());
+ auto* b = constants.Get(AFloat::Highest());
auto result = const_eval.OpMultiply(a->Type(), utils::Vector{a, b}, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<AFloat>(), 0.f);
@@ -143,8 +125,8 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, Mul_F32_Overflow) {
- auto* a = Scalar(f32::Highest());
- auto* b = Scalar(f32::Highest());
+ auto* a = constants.Get(f32::Highest());
+ auto* b = constants.Get(f32::Highest());
auto result = const_eval.OpMultiply(a->Type(), utils::Vector{a, b}, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<f32>(), 0.f);
@@ -154,8 +136,8 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, Div_AInt_ZeroDenominator) {
- auto* a = Scalar(AInt(42));
- auto* b = Scalar(AInt(0));
+ auto* a = constants.Get(AInt(42));
+ auto* b = constants.Get(AInt(0));
auto result = const_eval.OpDivide(a->Type(), utils::Vector{a, b}, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<AInt>(), 42);
@@ -163,8 +145,8 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, Div_I32_ZeroDenominator) {
- auto* a = Scalar(i32(42));
- auto* b = Scalar(i32(0));
+ auto* a = constants.Get(i32(42));
+ auto* b = constants.Get(i32(0));
auto result = const_eval.OpDivide(a->Type(), utils::Vector{a, b}, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<i32>(), 42);
@@ -172,8 +154,8 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, Div_U32_ZeroDenominator) {
- auto* a = Scalar(u32(42));
- auto* b = Scalar(u32(0));
+ auto* a = constants.Get(u32(42));
+ auto* b = constants.Get(u32(0));
auto result = const_eval.OpDivide(a->Type(), utils::Vector{a, b}, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<u32>(), 42);
@@ -181,8 +163,8 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, Div_AFloat_ZeroDenominator) {
- auto* a = Scalar(AFloat(42));
- auto* b = Scalar(AFloat(0));
+ auto* a = constants.Get(AFloat(42));
+ auto* b = constants.Get(AFloat(0));
auto result = const_eval.OpDivide(a->Type(), utils::Vector{a, b}, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<AFloat>(), 42.f);
@@ -190,8 +172,8 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, Div_F32_ZeroDenominator) {
- auto* a = Scalar(f32(42));
- auto* b = Scalar(f32(0));
+ auto* a = constants.Get(f32(42));
+ auto* b = constants.Get(f32(0));
auto result = const_eval.OpDivide(a->Type(), utils::Vector{a, b}, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<f32>(), 42.f);
@@ -199,8 +181,8 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, Div_I32_MostNegativeByMinInt) {
- auto* a = Scalar(i32::Lowest());
- auto* b = Scalar(i32(-1));
+ auto* a = constants.Get(i32::Lowest());
+ auto* b = constants.Get(i32(-1));
auto result = const_eval.OpDivide(a->Type(), utils::Vector{a, b}, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<i32>(), i32::Lowest());
@@ -208,8 +190,8 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, Mod_AInt_ZeroDenominator) {
- auto* a = Scalar(AInt(42));
- auto* b = Scalar(AInt(0));
+ auto* a = constants.Get(AInt(42));
+ auto* b = constants.Get(AInt(0));
auto result = const_eval.OpModulo(a->Type(), utils::Vector{a, b}, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<AInt>(), 0);
@@ -217,8 +199,8 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, Mod_I32_ZeroDenominator) {
- auto* a = Scalar(i32(42));
- auto* b = Scalar(i32(0));
+ auto* a = constants.Get(i32(42));
+ auto* b = constants.Get(i32(0));
auto result = const_eval.OpModulo(a->Type(), utils::Vector{a, b}, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<i32>(), 0);
@@ -226,8 +208,8 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, Mod_U32_ZeroDenominator) {
- auto* a = Scalar(u32(42));
- auto* b = Scalar(u32(0));
+ auto* a = constants.Get(u32(42));
+ auto* b = constants.Get(u32(0));
auto result = const_eval.OpModulo(a->Type(), utils::Vector{a, b}, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<u32>(), 0);
@@ -235,8 +217,8 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, Mod_AFloat_ZeroDenominator) {
- auto* a = Scalar(AFloat(42));
- auto* b = Scalar(AFloat(0));
+ auto* a = constants.Get(AFloat(42));
+ auto* b = constants.Get(AFloat(0));
auto result = const_eval.OpModulo(a->Type(), utils::Vector{a, b}, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<AFloat>(), 0.f);
@@ -244,8 +226,8 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, Mod_F32_ZeroDenominator) {
- auto* a = Scalar(f32(42));
- auto* b = Scalar(f32(0));
+ auto* a = constants.Get(f32(42));
+ auto* b = constants.Get(f32(0));
auto result = const_eval.OpModulo(a->Type(), utils::Vector{a, b}, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<f32>(), 0.f);
@@ -253,8 +235,8 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, Mod_I32_MostNegativeByMinInt) {
- auto* a = Scalar(i32::Lowest());
- auto* b = Scalar(i32(-1));
+ auto* a = constants.Get(i32::Lowest());
+ auto* b = constants.Get(i32(-1));
auto result = const_eval.OpModulo(a->Type(), utils::Vector{a, b}, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<i32>(), 0);
@@ -262,8 +244,8 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, ShiftLeft_AInt_SignChange) {
- auto* a = Scalar(AInt(0x0FFFFFFFFFFFFFFFll));
- auto* b = Scalar(u32(9));
+ auto* a = constants.Get(AInt(0x0FFFFFFFFFFFFFFFll));
+ auto* b = constants.Get(u32(9));
auto result = const_eval.OpShiftLeft(a->Type(), utils::Vector{a, b}, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<AInt>(), static_cast<AInt>(0x0FFFFFFFFFFFFFFFull << 9));
@@ -271,8 +253,8 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, ShiftLeft_I32_SignChange) {
- auto* a = Scalar(i32(0x0FFFFFFF));
- auto* b = Scalar(u32(9));
+ auto* a = constants.Get(i32(0x0FFFFFFF));
+ auto* b = constants.Get(u32(9));
auto result = const_eval.OpShiftLeft(a->Type(), utils::Vector{a, b}, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<i32>(), static_cast<i32>(0x0FFFFFFFu << 9));
@@ -280,8 +262,8 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, ShiftLeft_I32_MoreThanBitWidth) {
- auto* a = Scalar(i32(0x1));
- auto* b = Scalar(u32(33));
+ auto* a = constants.Get(i32(0x1));
+ auto* b = constants.Get(u32(33));
auto result = const_eval.OpShiftLeft(a->Type(), utils::Vector{a, b}, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<i32>(), 2);
@@ -291,8 +273,8 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, ShiftLeft_U32_MoreThanBitWidth) {
- auto* a = Scalar(u32(0x1));
- auto* b = Scalar(u32(33));
+ auto* a = constants.Get(u32(0x1));
+ auto* b = constants.Get(u32(33));
auto result = const_eval.OpShiftLeft(a->Type(), utils::Vector{a, b}, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<u32>(), 2);
@@ -302,8 +284,8 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, ShiftRight_I32_MoreThanBitWidth) {
- auto* a = Scalar(i32(0x2));
- auto* b = Scalar(u32(33));
+ auto* a = constants.Get(i32(0x2));
+ auto* b = constants.Get(u32(33));
auto result = const_eval.OpShiftRight(a->Type(), utils::Vector{a, b}, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<i32>(), 1);
@@ -313,8 +295,8 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, ShiftRight_U32_MoreThanBitWidth) {
- auto* a = Scalar(u32(0x2));
- auto* b = Scalar(u32(33));
+ auto* a = constants.Get(u32(0x2));
+ auto* b = constants.Get(u32(33));
auto result = const_eval.OpShiftRight(a->Type(), utils::Vector{a, b}, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<u32>(), 1);
@@ -324,7 +306,7 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, Acos_F32_OutOfRange) {
- auto* a = Scalar(f32(2));
+ auto* a = constants.Get(f32(2));
auto result = const_eval.acos(a->Type(), utils::Vector{a}, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<f32>(), 0.f);
@@ -333,7 +315,7 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, Acosh_F32_OutOfRange) {
- auto* a = Scalar(f32(-1));
+ auto* a = constants.Get(f32(-1));
auto result = const_eval.acosh(a->Type(), utils::Vector{a}, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<f32>(), 0.f);
@@ -341,7 +323,7 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, Asin_F32_OutOfRange) {
- auto* a = Scalar(f32(2));
+ auto* a = constants.Get(f32(2));
auto result = const_eval.asin(a->Type(), utils::Vector{a}, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<f32>(), 0.f);
@@ -350,7 +332,7 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, Atanh_F32_OutOfRange) {
- auto* a = Scalar(f32(2));
+ auto* a = constants.Get(f32(2));
auto result = const_eval.atanh(a->Type(), utils::Vector{a}, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<f32>(), 0.f);
@@ -359,7 +341,7 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, Exp_F32_Overflow) {
- auto* a = Scalar(f32(1000));
+ auto* a = constants.Get(f32(1000));
auto result = const_eval.exp(a->Type(), utils::Vector{a}, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<f32>(), 0.f);
@@ -367,7 +349,7 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, Exp2_F32_Overflow) {
- auto* a = Scalar(f32(1000));
+ auto* a = constants.Get(f32(1000));
auto result = const_eval.exp2(a->Type(), utils::Vector{a}, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<f32>(), 0.f);
@@ -375,9 +357,9 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, ExtractBits_I32_TooManyBits) {
- auto* a = Scalar(i32(0x12345678));
- auto* offset = Scalar(u32(24));
- auto* count = Scalar(u32(16));
+ auto* a = constants.Get(i32(0x12345678));
+ auto* offset = constants.Get(u32(24));
+ auto* count = constants.Get(u32(16));
auto result = const_eval.extractBits(a->Type(), utils::Vector{a, offset, count}, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<i32>(), 0x12);
@@ -386,9 +368,9 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, ExtractBits_U32_TooManyBits) {
- auto* a = Scalar(u32(0x12345678));
- auto* offset = Scalar(u32(24));
- auto* count = Scalar(u32(16));
+ auto* a = constants.Get(u32(0x12345678));
+ auto* offset = constants.Get(u32(24));
+ auto* count = constants.Get(u32(16));
auto result = const_eval.extractBits(a->Type(), utils::Vector{a, offset, count}, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<u32>(), 0x12);
@@ -397,10 +379,10 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, InsertBits_I32_TooManyBits) {
- auto* a = Scalar(i32(0x99345678));
- auto* b = Scalar(i32(0x12));
- auto* offset = Scalar(u32(24));
- auto* count = Scalar(u32(16));
+ auto* a = constants.Get(i32(0x99345678));
+ auto* b = constants.Get(i32(0x12));
+ auto* offset = constants.Get(u32(24));
+ auto* count = constants.Get(u32(16));
auto result = const_eval.insertBits(a->Type(), utils::Vector{a, b, offset, count}, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<i32>(), 0x12345678);
@@ -409,10 +391,10 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, InsertBits_U32_TooManyBits) {
- auto* a = Scalar(u32(0x99345678));
- auto* b = Scalar(u32(0x12));
- auto* offset = Scalar(u32(24));
- auto* count = Scalar(u32(16));
+ auto* a = constants.Get(u32(0x99345678));
+ auto* b = constants.Get(u32(0x12));
+ auto* offset = constants.Get(u32(24));
+ auto* count = constants.Get(u32(16));
auto result = const_eval.insertBits(a->Type(), utils::Vector{a, b, offset, count}, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<u32>(), 0x12345678);
@@ -421,7 +403,7 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, InverseSqrt_F32_OutOfRange) {
- auto* a = Scalar(f32(-1));
+ auto* a = constants.Get(f32(-1));
auto result = const_eval.inverseSqrt(a->Type(), utils::Vector{a}, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<f32>(), 0.f);
@@ -429,8 +411,8 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, LDExpr_F32_OutOfRange) {
- auto* a = Scalar(f32(42.f));
- auto* b = Scalar(f32(200));
+ auto* a = constants.Get(f32(42.f));
+ auto* b = constants.Get(f32(200));
auto result = const_eval.ldexp(a->Type(), utils::Vector{a, b}, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<f32>(), 0.f);
@@ -438,7 +420,7 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, Log_F32_OutOfRange) {
- auto* a = Scalar(f32(-1));
+ auto* a = constants.Get(f32(-1));
auto result = const_eval.log(a->Type(), utils::Vector{a}, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<f32>(), 0.f);
@@ -446,7 +428,7 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, Log2_F32_OutOfRange) {
- auto* a = Scalar(f32(-1));
+ auto* a = constants.Get(f32(-1));
auto result = const_eval.log2(a->Type(), utils::Vector{a}, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<f32>(), 0.f);
@@ -454,7 +436,7 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, Normalize_ZeroLength) {
- auto* zero = Scalar(f32(0));
+ auto* zero = constants.Get(f32(0));
auto* vec =
const_eval.VecSplat(create<type::Vector>(create<type::F32>(), 4u), utils::Vector{zero}, {})
.Get();
@@ -468,8 +450,8 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, Pack2x16Float_OutOfRange) {
- auto* a = Scalar(f32(75250.f));
- auto* b = Scalar(f32(42.1f));
+ auto* a = constants.Get(f32(75250.f));
+ auto* b = constants.Get(f32(42.1f));
auto* vec =
const_eval.VecInitS(create<type::Vector>(create<type::F32>(), 2u), utils::Vector{a, b}, {})
.Get();
@@ -480,8 +462,8 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, Pow_F32_Overflow) {
- auto* a = Scalar(f32(2));
- auto* b = Scalar(f32(1000));
+ auto* a = constants.Get(f32(2));
+ auto* b = constants.Get(f32(1000));
auto result = const_eval.pow(a->Type(), utils::Vector{a, b}, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<f32>(), 0.f);
@@ -489,7 +471,7 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, Unpack2x16Float_OutOfRange) {
- auto* a = Scalar(u32(0x51437C00));
+ auto* a = constants.Get(u32(0x51437C00));
auto result = const_eval.unpack2x16float(create<type::U32>(), utils::Vector{a}, {});
ASSERT_TRUE(result);
EXPECT_FLOAT_EQ(result.Get()->Index(0)->ValueAs<f32>(), 0.f);
@@ -498,7 +480,7 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, QuantizeToF16_OutOfRange) {
- auto* a = Scalar(f32(75250.f));
+ auto* a = constants.Get(f32(75250.f));
auto result = const_eval.quantizeToF16(create<type::U32>(), utils::Vector{a}, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<u32>(), 0);
@@ -506,7 +488,7 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, Sqrt_F32_OutOfRange) {
- auto* a = Scalar(f32(-1));
+ auto* a = constants.Get(f32(-1));
auto result = const_eval.sqrt(a->Type(), utils::Vector{a}, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<f32>(), 0.f);
@@ -514,7 +496,7 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, Bitcast_Infinity) {
- auto* a = Scalar(u32(0x7F800000));
+ auto* a = constants.Get(u32(0x7F800000));
auto result = const_eval.Bitcast(create<type::F32>(), a, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<f32>(), 0.f);
@@ -522,7 +504,7 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, Bitcast_NaN) {
- auto* a = Scalar(u32(0x7FC00000));
+ auto* a = constants.Get(u32(0x7FC00000));
auto result = const_eval.Bitcast(create<type::F32>(), a, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<f32>(), 0.f);
@@ -530,7 +512,7 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, Convert_F32_TooHigh) {
- auto* a = Scalar(AFloat::Highest());
+ auto* a = constants.Get(AFloat::Highest());
auto result = const_eval.Convert(create<type::F32>(), a, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<f32>(), f32::kHighestValue);
@@ -540,7 +522,7 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, Convert_F32_TooLow) {
- auto* a = Scalar(AFloat::Lowest());
+ auto* a = constants.Get(AFloat::Lowest());
auto result = const_eval.Convert(create<type::F32>(), a, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<f32>(), f32::kLowestValue);
@@ -550,7 +532,7 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, Convert_F16_TooHigh) {
- auto* a = Scalar(f32(1000000.0));
+ auto* a = constants.Get(f32(1000000.0));
auto result = const_eval.Convert(create<type::F16>(), a, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<f32>(), f16::kHighestValue);
@@ -558,7 +540,7 @@
}
TEST_F(ResolverConstEvalRuntimeSemanticsTest, Convert_F16_TooLow) {
- auto* a = Scalar(f32(-1000000.0));
+ auto* a = constants.Get(f32(-1000000.0));
auto result = const_eval.Convert(create<type::F16>(), a, {});
ASSERT_TRUE(result);
EXPECT_EQ(result.Get()->ValueAs<f32>(), f16::kLowestValue);
@@ -571,10 +553,10 @@
auto* a = const_eval
.VecInitS(vec4f,
utils::Vector{
- Scalar(f32(1)),
- Scalar(f32(4)),
- Scalar(f32(-1)),
- Scalar(f32(65536)),
+ constants.Get(f32(1)),
+ constants.Get(f32(4)),
+ constants.Get(f32(-1)),
+ constants.Get(f32(65536)),
},
{})
.Get();
diff --git a/src/tint/transform/manager_test.cc b/src/tint/transform/manager_test.cc
index d1f6333..75afa93 100644
--- a/src/tint/transform/manager_test.cc
+++ b/src/tint/transform/manager_test.cc
@@ -51,7 +51,7 @@
void Run(ir::Module* mod, const DataMap&, DataMap&) const override {
ir::Builder builder(*mod);
auto* func =
- builder.CreateFunction(mod->symbols.New("ir_func"), mod->types.Get<type::Void>());
+ builder.CreateFunction(mod->symbols.New("ir_func"), mod->Types().Get<type::Void>());
func->StartTarget()->SetInstructions(utils::Vector{builder.Branch(func->EndTarget())});
mod->functions.Push(func);
}
@@ -69,7 +69,7 @@
ir::Module mod;
ir::Builder builder(mod);
auto* func =
- builder.CreateFunction(builder.ir.symbols.New("main"), builder.ir.types.Get<type::Void>());
+ builder.CreateFunction(builder.ir.symbols.New("main"), mod.Types().Get<type::Void>());
func->StartTarget()->SetInstructions(utils::Vector{builder.Branch(func->EndTarget())});
builder.ir.functions.Push(func);
return mod;
diff --git a/src/tint/type/manager.cc b/src/tint/type/manager.cc
index 666a1d3..0782633 100644
--- a/src/tint/type/manager.cc
+++ b/src/tint/type/manager.cc
@@ -14,6 +14,18 @@
#include "src/tint/type/manager.h"
+#include "src/tint/type/abstract_float.h"
+#include "src/tint/type/abstract_int.h"
+#include "src/tint/type/bool.h"
+#include "src/tint/type/f16.h"
+#include "src/tint/type/f32.h"
+#include "src/tint/type/i32.h"
+#include "src/tint/type/matrix.h"
+#include "src/tint/type/type.h"
+#include "src/tint/type/u32.h"
+#include "src/tint/type/vector.h"
+#include "src/tint/type/void.h"
+
namespace tint::type {
Manager::Manager() = default;
@@ -24,4 +36,91 @@
Manager::~Manager() = default;
+const type::Void* Manager::void_() {
+ return Get<type::Void>();
+}
+
+const type::Bool* Manager::bool_() {
+ return Get<type::Bool>();
+}
+
+const type::I32* Manager::i32() {
+ return Get<type::I32>();
+}
+
+const type::U32* Manager::u32() {
+ return Get<type::U32>();
+}
+
+const type::F32* Manager::f32() {
+ return Get<type::F32>();
+}
+
+const type::F16* Manager::f16() {
+ return Get<type::F16>();
+}
+
+const type::AbstractFloat* Manager::AFloat() {
+ return Get<type::AbstractFloat>();
+}
+
+const type::AbstractInt* Manager::AInt() {
+ return Get<type::AbstractInt>();
+}
+
+const type::Vector* Manager::vec(const type::Type* inner, uint32_t size) {
+ return Get<type::Vector>(inner, size);
+}
+
+const type::Vector* Manager::vec2(const type::Type* inner) {
+ return vec(inner, 2);
+}
+
+const type::Vector* Manager::vec3(const type::Type* inner) {
+ return vec(inner, 3);
+}
+
+const type::Vector* Manager::vec4(const type::Type* inner) {
+ return vec(inner, 4);
+}
+
+const type::Matrix* Manager::mat(const type::Type* inner, uint32_t cols, uint32_t rows) {
+ return Get<type::Matrix>(vec(inner, rows), cols);
+}
+
+const type::Matrix* Manager::mat2x2(const type::Type* inner) {
+ return mat(inner, 2, 2);
+}
+
+const type::Matrix* Manager::mat2x3(const type::Type* inner) {
+ return mat(inner, 2, 3);
+}
+
+const type::Matrix* Manager::mat2x4(const type::Type* inner) {
+ return mat(inner, 2, 4);
+}
+
+const type::Matrix* Manager::mat3x2(const type::Type* inner) {
+ return mat(inner, 3, 2);
+}
+
+const type::Matrix* Manager::mat3x3(const type::Type* inner) {
+ return mat(inner, 3, 3);
+}
+
+const type::Matrix* Manager::mat3x4(const type::Type* inner) {
+ return mat(inner, 3, 4);
+}
+
+const type::Matrix* Manager::mat4x2(const type::Type* inner) {
+ return mat(inner, 4, 2);
+}
+
+const type::Matrix* Manager::mat4x3(const type::Type* inner) {
+ return mat(inner, 4, 3);
+}
+
+const type::Matrix* Manager::mat4x4(const type::Type* inner) {
+ return mat(inner, 4, 4);
+}
} // namespace tint::type
diff --git a/src/tint/type/manager.h b/src/tint/type/manager.h
index 4eb48bb..b4fa560 100644
--- a/src/tint/type/manager.h
+++ b/src/tint/type/manager.h
@@ -18,10 +18,24 @@
#include <utility>
#include "src/tint/type/type.h"
-#include "src/tint/type/vector.h"
+#include "src/tint/type/unique_node.h"
#include "src/tint/utils/hash.h"
#include "src/tint/utils/unique_allocator.h"
+// Forward declarations
+namespace tint::type {
+class AbstractFloat;
+class AbstractInt;
+class Bool;
+class F16;
+class F32;
+class I32;
+class Matrix;
+class U32;
+class Vector;
+class Void;
+} // namespace tint::type
+
namespace tint::type {
/// The type manager holds all the pointers to the known types.
@@ -85,22 +99,88 @@
return types_.Find<TYPE>(std::forward<ARGS>(args)...);
}
+ /// @returns a void type
+ const type::Void* void_();
+
+ /// @returns a bool type
+ const type::Bool* bool_();
+
+ /// @returns an i32 type
+ const type::I32* i32();
+
+ /// @returns a u32 type
+ const type::U32* u32();
+
+ /// @returns an f32 type
+ const type::F32* f32();
+
+ /// @returns an f16 type
+ const type::F16* f16();
+
+ /// @returns a abstract-float type
+ const type::AbstractFloat* AFloat();
+
+ /// @returns a abstract-int type
+ const type::AbstractInt* AInt();
+
/// @param inner the inner type
/// @param size the vector size
/// @returns the vector type
- type::Type* vec(type::Type* inner, uint32_t size) { return Get<type::Vector>(inner, size); }
+ const type::Vector* vec(const type::Type* inner, uint32_t size);
/// @param inner the inner type
/// @returns the vector type
- type::Type* vec2(type::Type* inner) { return vec(inner, 2); }
+ const type::Vector* vec2(const type::Type* inner);
/// @param inner the inner type
/// @returns the vector type
- type::Type* vec3(type::Type* inner) { return vec(inner, 3); }
+ const type::Vector* vec3(const type::Type* inner);
/// @param inner the inner type
/// @returns the vector type
- type::Type* vec4(type::Type* inner) { return vec(inner, 4); }
+ const type::Vector* vec4(const type::Type* inner);
+
+ /// @param inner the inner type
+ /// @param cols the number of columns
+ /// @param rows the number of rows
+ /// @returns the matrix type
+ const type::Matrix* mat(const type::Type* inner, uint32_t cols, uint32_t rows);
+
+ /// @param inner the inner type
+ /// @returns the matrix type
+ const type::Matrix* mat2x2(const type::Type* inner);
+
+ /// @param inner the inner type
+ /// @returns the matrix type
+ const type::Matrix* mat2x3(const type::Type* inner);
+
+ /// @param inner the inner type
+ /// @returns the matrix type
+ const type::Matrix* mat2x4(const type::Type* inner);
+
+ /// @param inner the inner type
+ /// @returns the matrix type
+ const type::Matrix* mat3x2(const type::Type* inner);
+
+ /// @param inner the inner type
+ /// @returns the matrix type
+ const type::Matrix* mat3x3(const type::Type* inner);
+
+ /// @param inner the inner type
+ /// @returns the matrix type
+ const type::Matrix* mat3x4(const type::Type* inner);
+
+ /// @param inner the inner type
+ /// @returns the matrix type
+ const type::Matrix* mat4x2(const type::Type* inner);
+
+ /// @param inner the inner type
+ /// @returns the matrix type
+ const type::Matrix* mat4x3(const type::Type* inner);
+
+ /// @param inner the inner type
+ /// @returns the matrix type
+ const type::Matrix* mat4x4(const type::Type* inner);
/// @returns an iterator to the beginning of the types
TypeIterator begin() const { return types_.begin(); }
diff --git a/src/tint/writer/glsl/generator_impl.cc b/src/tint/writer/glsl/generator_impl.cc
index a15f6a9..549f366 100644
--- a/src/tint/writer/glsl/generator_impl.cc
+++ b/src/tint/writer/glsl/generator_impl.cc
@@ -52,6 +52,7 @@
#include "src/tint/ast/transform/unshadow.h"
#include "src/tint/ast/transform/zero_init_workgroup_memory.h"
#include "src/tint/ast/variable_decl_statement.h"
+#include "src/tint/constant/splat.h"
#include "src/tint/constant/value.h"
#include "src/tint/debug.h"
#include "src/tint/sem/block_statement.h"
diff --git a/src/tint/writer/hlsl/generator_impl.cc b/src/tint/writer/hlsl/generator_impl.cc
index 95d2022..eb07e8e 100644
--- a/src/tint/writer/hlsl/generator_impl.cc
+++ b/src/tint/writer/hlsl/generator_impl.cc
@@ -51,6 +51,7 @@
#include "src/tint/ast/transform/vectorize_scalar_matrix_initializers.h"
#include "src/tint/ast/transform/zero_init_workgroup_memory.h"
#include "src/tint/ast/variable_decl_statement.h"
+#include "src/tint/constant/splat.h"
#include "src/tint/constant/value.h"
#include "src/tint/debug.h"
#include "src/tint/sem/block_statement.h"
diff --git a/src/tint/writer/msl/generator_impl.cc b/src/tint/writer/msl/generator_impl.cc
index 55b9f8d..4c51b17 100644
--- a/src/tint/writer/msl/generator_impl.cc
+++ b/src/tint/writer/msl/generator_impl.cc
@@ -49,6 +49,7 @@
#include "src/tint/ast/transform/vectorize_scalar_matrix_initializers.h"
#include "src/tint/ast/transform/zero_init_workgroup_memory.h"
#include "src/tint/ast/variable_decl_statement.h"
+#include "src/tint/constant/splat.h"
#include "src/tint/constant/value.h"
#include "src/tint/sem/call.h"
#include "src/tint/sem/function.h"
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir.cc b/src/tint/writer/spirv/ir/generator_impl_ir.cc
index 44bead9..05802f8 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir.cc
@@ -14,6 +14,8 @@
#include "src/tint/writer/spirv/ir/generator_impl_ir.h"
+#include <utility>
+
#include "spirv/unified1/spirv.h"
#include "src/tint/ir/binary.h"
#include "src/tint/ir/block.h"
@@ -23,6 +25,7 @@
#include "src/tint/ir/module.h"
#include "src/tint/ir/store.h"
#include "src/tint/ir/transform/add_empty_entry_point.h"
+#include "src/tint/ir/user_call.h"
#include "src/tint/ir/var.h"
#include "src/tint/switch.h"
#include "src/tint/transform/manager.h"
@@ -30,6 +33,7 @@
#include "src/tint/type/f16.h"
#include "src/tint/type/f32.h"
#include "src/tint/type/i32.h"
+#include "src/tint/type/matrix.h"
#include "src/tint/type/pointer.h"
#include "src/tint/type/type.h"
#include "src/tint/type/u32.h"
@@ -146,6 +150,13 @@
}
module_.PushType(spv::Op::OpConstantComposite, operands);
},
+ [&](const type::Matrix* mat) {
+ OperandList operands = {Type(ty), id};
+ for (uint32_t i = 0; i < mat->columns(); i++) {
+ operands.push_back(Constant(constant->Index(i)));
+ }
+ module_.PushType(spv::Op::OpConstantComposite, operands);
+ },
[&](Default) {
TINT_ICE(Writer, diagnostics_) << "unhandled constant type: " << ty->FriendlyName();
});
@@ -175,6 +186,10 @@
[&](const type::Vector* vec) {
module_.PushType(spv::Op::OpTypeVector, {id, Type(vec->type()), vec->Width()});
},
+ [&](const type::Matrix* mat) {
+ module_.PushType(spv::Op::OpTypeMatrix,
+ {id, Type(mat->ColumnType()), mat->columns()});
+ },
[&](const type::Pointer* ptr) {
module_.PushType(
spv::Op::OpTypePointer,
@@ -191,17 +206,13 @@
return Switch(
value, //
[&](const ir::Constant* constant) { return Constant(constant); },
- [&](const ir::Instruction* inst) {
- auto id = instructions_.Find(inst);
+ [&](const ir::Value*) {
+ auto id = values_.Find(value);
if (TINT_UNLIKELY(!id)) {
- TINT_ICE(Writer, diagnostics_) << "missing instruction result";
+ TINT_ICE(Writer, diagnostics_) << "missing result ID for value";
return 0u;
}
return *id;
- },
- [&](Default) {
- TINT_ICE(Writer, diagnostics_) << "unhandled value node: " << value->TypeInfo().name;
- return 0u;
});
}
@@ -212,6 +223,7 @@
void GeneratorImplIr::EmitFunction(const ir::Function* func) {
// Make an ID for the function.
auto id = module_.NextId();
+ functions_.Add(func->Name(), id);
// Emit the function name.
module_.PushDebug(spv::Op::OpName, {id, Operand(func->Name().Name())});
@@ -224,9 +236,22 @@
// Get the ID for the return type.
auto return_type_id = Type(func->ReturnType());
- // Get the ID for the function type (creating it if needed).
- // TODO(jrprice): Add the parameter types when they are supported in the IR.
FunctionType function_type{return_type_id, {}};
+ InstructionList params;
+
+ // Generate function parameter declarations and add their type IDs to the function signature.
+ for (auto* param : func->Params()) {
+ auto param_type_id = Type(param->Type());
+ auto param_id = module_.NextId();
+ params.push_back(Instruction(spv::Op::OpFunctionParameter, {param_type_id, param_id}));
+ values_.Add(param, param_id);
+ function_type.param_type_ids.Push(param_type_id);
+ if (auto name = ir_->NameOf(param)) {
+ module_.PushDebug(spv::Op::OpName, {param_id, Operand(name.Name())});
+ }
+ }
+
+ // Get the ID for the function type (creating it if needed).
auto function_type_id = function_types_.GetOrCreate(function_type, [&]() {
auto func_ty_id = module_.NextId();
OperandList operands = {func_ty_id, return_type_id};
@@ -242,9 +267,8 @@
{return_type_id, id, U32Operand(SpvFunctionControlMaskNone), function_type_id}};
// Create a function that we will add instructions to.
- // TODO(jrprice): Add the parameter declarations when they are supported in the IR.
auto entry_block = module_.NextId();
- current_function_ = Function(decl, entry_block, {});
+ current_function_ = Function(decl, entry_block, std::move(params));
TINT_DEFER(current_function_ = Function());
// Emit the body of the function.
@@ -309,6 +333,7 @@
EmitStore(s);
return 0u;
},
+ [&](const ir::UserCall* c) { return EmitUserCall(c); },
[&](const ir::Var* v) { return EmitVar(v); },
[&](const ir::If* i) {
EmitIf(i);
@@ -323,21 +348,24 @@
<< "unimplemented instruction: " << inst->TypeInfo().name;
return 0u;
});
- instructions_.Add(inst, result);
+ values_.Add(inst, result);
}
}
void GeneratorImplIr::EmitBranch(const ir::Branch* b) {
Switch(
b->To(),
- [&](const ir::Block* blk) { current_function_.push_inst(spv::Op::OpBranch, {Label(blk)}); },
[&](const ir::FunctionTerminator*) {
- // TODO(jrprice): Handle the return value, which will be a branch argument.
if (!b->Args().IsEmpty()) {
- TINT_ICE(Writer, diagnostics_) << "unimplemented return value";
+ TINT_ASSERT(Writer, b->Args().Length() == 1u);
+ OperandList operands;
+ operands.push_back(Value(b->Args()[0]));
+ current_function_.push_inst(spv::Op::OpReturnValue, operands);
+ } else {
+ current_function_.push_inst(spv::Op::OpReturn, {});
}
- current_function_.push_inst(spv::Op::OpReturn, {});
},
+ [&](const ir::Block* blk) { current_function_.push_inst(spv::Op::OpBranch, {Label(blk)}); },
[&](Default) {
// A block may not have an outward branch (e.g. an unreachable merge
// block).
@@ -384,6 +412,7 @@
uint32_t GeneratorImplIr::EmitBinary(const ir::Binary* binary) {
auto id = module_.NextId();
+ auto* lhs_ty = binary->LHS()->Type();
// Determine the opcode.
spv::Op op = spv::Op::Max;
@@ -396,6 +425,81 @@
op = binary->Type()->is_integer_scalar_or_vector() ? spv::Op::OpISub : spv::Op::OpFSub;
break;
}
+
+ case ir::Binary::Kind::kAnd: {
+ op = spv::Op::OpBitwiseAnd;
+ break;
+ }
+ case ir::Binary::Kind::kOr: {
+ op = spv::Op::OpBitwiseOr;
+ break;
+ }
+ case ir::Binary::Kind::kXor: {
+ op = spv::Op::OpBitwiseXor;
+ break;
+ }
+
+ case ir::Binary::Kind::kEqual: {
+ if (lhs_ty->is_bool_scalar_or_vector()) {
+ op = spv::Op::OpLogicalEqual;
+ } else if (lhs_ty->is_float_scalar_or_vector()) {
+ op = spv::Op::OpFOrdEqual;
+ } else if (lhs_ty->is_integer_scalar_or_vector()) {
+ op = spv::Op::OpIEqual;
+ }
+ break;
+ }
+ case ir::Binary::Kind::kNotEqual: {
+ if (lhs_ty->is_bool_scalar_or_vector()) {
+ op = spv::Op::OpLogicalNotEqual;
+ } else if (lhs_ty->is_float_scalar_or_vector()) {
+ op = spv::Op::OpFOrdNotEqual;
+ } else if (lhs_ty->is_integer_scalar_or_vector()) {
+ op = spv::Op::OpINotEqual;
+ }
+ break;
+ }
+ case ir::Binary::Kind::kGreaterThan: {
+ if (lhs_ty->is_float_scalar_or_vector()) {
+ op = spv::Op::OpFOrdGreaterThan;
+ } else if (lhs_ty->is_signed_integer_scalar_or_vector()) {
+ op = spv::Op::OpSGreaterThan;
+ } else if (lhs_ty->is_unsigned_integer_scalar_or_vector()) {
+ op = spv::Op::OpUGreaterThan;
+ }
+ break;
+ }
+ case ir::Binary::Kind::kGreaterThanEqual: {
+ if (lhs_ty->is_float_scalar_or_vector()) {
+ op = spv::Op::OpFOrdGreaterThanEqual;
+ } else if (lhs_ty->is_signed_integer_scalar_or_vector()) {
+ op = spv::Op::OpSGreaterThanEqual;
+ } else if (lhs_ty->is_unsigned_integer_scalar_or_vector()) {
+ op = spv::Op::OpUGreaterThanEqual;
+ }
+ break;
+ }
+ case ir::Binary::Kind::kLessThan: {
+ if (lhs_ty->is_float_scalar_or_vector()) {
+ op = spv::Op::OpFOrdLessThan;
+ } else if (lhs_ty->is_signed_integer_scalar_or_vector()) {
+ op = spv::Op::OpSLessThan;
+ } else if (lhs_ty->is_unsigned_integer_scalar_or_vector()) {
+ op = spv::Op::OpULessThan;
+ }
+ break;
+ }
+ case ir::Binary::Kind::kLessThanEqual: {
+ if (lhs_ty->is_float_scalar_or_vector()) {
+ op = spv::Op::OpFOrdLessThanEqual;
+ } else if (lhs_ty->is_signed_integer_scalar_or_vector()) {
+ op = spv::Op::OpSLessThanEqual;
+ } else if (lhs_ty->is_unsigned_integer_scalar_or_vector()) {
+ op = spv::Op::OpULessThanEqual;
+ }
+ break;
+ }
+
default: {
TINT_ICE(Writer, diagnostics_)
<< "unimplemented binary instruction: " << static_cast<uint32_t>(binary->Kind());
@@ -419,6 +523,16 @@
current_function_.push_inst(spv::Op::OpStore, {Value(store->To()), Value(store->From())});
}
+uint32_t GeneratorImplIr::EmitUserCall(const ir::UserCall* call) {
+ auto id = module_.NextId();
+ OperandList operands = {Type(call->Type()), id, functions_.Get(call->Name()).value()};
+ for (auto* arg : call->Args()) {
+ operands.push_back(Value(arg));
+ }
+ current_function_.push_inst(spv::Op::OpFunctionCall, operands);
+ return id;
+}
+
uint32_t GeneratorImplIr::EmitVar(const ir::Var* var) {
auto id = module_.NextId();
auto* ptr = var->Type()->As<type::Pointer>();
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir.h b/src/tint/writer/spirv/ir/generator_impl_ir.h
index ccd09ad..ee790a0 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir.h
+++ b/src/tint/writer/spirv/ir/generator_impl_ir.h
@@ -20,6 +20,7 @@
#include "src/tint/constant/value.h"
#include "src/tint/diagnostic/diagnostic.h"
#include "src/tint/ir/constant.h"
+#include "src/tint/symbol.h"
#include "src/tint/utils/hashmap.h"
#include "src/tint/utils/vector.h"
#include "src/tint/writer/spirv/binary_writer.h"
@@ -36,6 +37,7 @@
class Load;
class Module;
class Store;
+class UserCall;
class Value;
class Var;
} // namespace tint::ir
@@ -117,6 +119,11 @@
/// @param store the store instruction to emit
void EmitStore(const ir::Store* store);
+ /// Emit a user call instruction.
+ /// @param call the user call instruction to emit
+ /// @returns the result ID of the instruction
+ uint32_t EmitUserCall(const ir::UserCall* call);
+
/// Emit a var instruction.
/// @param var the var instruction to emit
/// @returns the result ID of the instruction
@@ -171,8 +178,12 @@
/// The map of constants to their result IDs.
utils::Hashmap<const constant::Value*, uint32_t, 16> constants_;
- /// The map of instructions to their result IDs.
- utils::Hashmap<const ir::Instruction*, uint32_t, 8> instructions_;
+ /// The map of functions to their result IDs.
+ /// TODO(jrprice): Merge into `values_` map when `ir::Function` becomes an `ir::Value`.
+ utils::Hashmap<Symbol, uint32_t, 8> functions_;
+
+ /// The map of non-constant values to their result IDs.
+ utils::Hashmap<const ir::Value*, uint32_t, 8> values_;
/// The map of blocks to the IDs of their label instructions.
utils::Hashmap<const ir::Block*, uint32_t, 8> block_labels_;
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc
index 8792239..71a2a36 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc
@@ -14,208 +14,272 @@
#include "src/tint/writer/spirv/ir/test_helper_ir.h"
+#include "gmock/gmock.h"
+#include "src/tint/ir/binary.h"
+
using namespace tint::number_suffixes; // NOLINT
namespace tint::writer::spirv {
namespace {
-TEST_F(SpvGeneratorImplTest, Binary_Add_I32) {
- auto* func = b.CreateFunction("foo", mod.types.Get<type::Void>());
+/// The element type of a test.
+enum Type {
+ kBool,
+ kI32,
+ kU32,
+ kF32,
+ kF16,
+};
+
+/// A parameterized test case.
+struct BinaryTestCase {
+ /// The element type to test.
+ Type type;
+ /// The binary operation.
+ enum ir::Binary::Kind kind;
+ /// The expected SPIR-V instruction.
+ std::string spirv_inst;
+};
+
+/// A helper class for parameterized binary instruction tests.
+class BinaryInstructionTest : public SpvGeneratorImplTestWithParam<BinaryTestCase> {
+ protected:
+ /// Helper to make a scalar type corresponding to the element type `ty`.
+ /// @param ty the element type
+ /// @returns the scalar type
+ const type::Type* MakeScalarType(Type ty) {
+ switch (ty) {
+ case kBool:
+ return mod.Types().bool_();
+ case kI32:
+ return mod.Types().i32();
+ case kU32:
+ return mod.Types().u32();
+ case kF32:
+ return mod.Types().f32();
+ case kF16:
+ return mod.Types().f16();
+ }
+ return nullptr;
+ }
+
+ /// Helper to make a vector type corresponding to the element type `ty`.
+ /// @param ty the element type
+ /// @returns the vector type
+ const type::Type* MakeVectorType(Type ty) { return mod.Types().vec2(MakeScalarType(ty)); }
+
+ /// Helper to make a scalar value with the scalar type `ty`.
+ /// @param ty the element type
+ /// @returns the scalar value
+ ir::Value* MakeScalarValue(Type ty) {
+ switch (ty) {
+ case kBool:
+ return b.Constant(true);
+ case kI32:
+ return b.Constant(1_i);
+ case kU32:
+ return b.Constant(1_u);
+ case kF32:
+ return b.Constant(1_f);
+ case kF16:
+ return b.Constant(1_h);
+ }
+ return nullptr;
+ }
+
+ /// Helper to make a vector value with an element type of `ty`.
+ /// @param ty the element type
+ /// @returns the vector value
+ ir::Value* MakeVectorValue(Type ty) {
+ switch (ty) {
+ case kBool:
+ return b.Constant(b.ir.constant_values.Composite(
+ MakeVectorType(ty), utils::Vector{b.ir.constant_values.Get(true),
+ b.ir.constant_values.Get(false)}));
+ case kI32:
+ return b.Constant(b.ir.constant_values.Composite(
+ MakeVectorType(ty), utils::Vector{b.ir.constant_values.Get(42_i),
+ b.ir.constant_values.Get(-10_i)}));
+ case kU32:
+ return b.Constant(b.ir.constant_values.Composite(
+ MakeVectorType(ty),
+ utils::Vector{b.ir.constant_values.Get(42_u), b.ir.constant_values.Get(10_u)}));
+ case kF32:
+ return b.Constant(b.ir.constant_values.Composite(
+ MakeVectorType(ty), utils::Vector{b.ir.constant_values.Get(42_f),
+ b.ir.constant_values.Get(-0.5_f)}));
+ case kF16:
+ return b.Constant(b.ir.constant_values.Composite(
+ MakeVectorType(ty), utils::Vector{b.ir.constant_values.Get(42_h),
+ b.ir.constant_values.Get(-0.5_h)}));
+ }
+ return nullptr;
+ }
+};
+
+using Arithmetic = BinaryInstructionTest;
+TEST_P(Arithmetic, Scalar) {
+ auto params = GetParam();
+
+ auto* func = b.CreateFunction("foo", mod.Types().void_());
func->StartTarget()->SetInstructions(
- utils::Vector{b.Add(mod.types.Get<type::I32>(), b.Constant(1_i), b.Constant(2_i)),
+ utils::Vector{b.CreateBinary(params.kind, MakeScalarType(params.type),
+ MakeScalarValue(params.type), MakeScalarValue(params.type)),
b.Branch(func->EndTarget())});
generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%6 = OpTypeInt 32 1
-%7 = OpConstant %6 1
-%8 = OpConstant %6 2
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-%5 = OpIAdd %6 %7 %8
-OpReturn
-OpFunctionEnd
-)");
+ EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst));
}
+TEST_P(Arithmetic, Vector) {
+ auto params = GetParam();
-TEST_F(SpvGeneratorImplTest, Binary_Add_U32) {
- auto* func = b.CreateFunction("foo", mod.types.Get<type::Void>());
+ auto* func = b.CreateFunction("foo", mod.Types().void_());
func->StartTarget()->SetInstructions(
- utils::Vector{b.Add(mod.types.Get<type::U32>(), b.Constant(1_u), b.Constant(2_u)),
+ utils::Vector{b.CreateBinary(params.kind, MakeVectorType(params.type),
+ MakeVectorValue(params.type), MakeVectorValue(params.type)),
+
b.Branch(func->EndTarget())});
generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%6 = OpTypeInt 32 0
-%7 = OpConstant %6 1
-%8 = OpConstant %6 2
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-%5 = OpIAdd %6 %7 %8
-OpReturn
-OpFunctionEnd
-)");
+ EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst));
}
+INSTANTIATE_TEST_SUITE_P(SpvGeneratorImplTest_Binary_I32,
+ Arithmetic,
+ testing::Values(BinaryTestCase{kI32, ir::Binary::Kind::kAdd, "OpIAdd"},
+ BinaryTestCase{kI32, ir::Binary::Kind::kSubtract,
+ "OpISub"}));
+INSTANTIATE_TEST_SUITE_P(SpvGeneratorImplTest_Binary_U32,
+ Arithmetic,
+ testing::Values(BinaryTestCase{kU32, ir::Binary::Kind::kAdd, "OpIAdd"},
+ BinaryTestCase{kU32, ir::Binary::Kind::kSubtract,
+ "OpISub"}));
+INSTANTIATE_TEST_SUITE_P(SpvGeneratorImplTest_Binary_F32,
+ Arithmetic,
+ testing::Values(BinaryTestCase{kF32, ir::Binary::Kind::kAdd, "OpFAdd"},
+ BinaryTestCase{kF32, ir::Binary::Kind::kSubtract,
+ "OpFSub"}));
+INSTANTIATE_TEST_SUITE_P(SpvGeneratorImplTest_Binary_F16,
+ Arithmetic,
+ testing::Values(BinaryTestCase{kF16, ir::Binary::Kind::kAdd, "OpFAdd"},
+ BinaryTestCase{kF16, ir::Binary::Kind::kSubtract,
+ "OpFSub"}));
-TEST_F(SpvGeneratorImplTest, Binary_Add_F32) {
- auto* func = b.CreateFunction("foo", mod.types.Get<type::Void>());
+using Bitwise = BinaryInstructionTest;
+TEST_P(Bitwise, Scalar) {
+ auto params = GetParam();
+
+ auto* func = b.CreateFunction("foo", mod.Types().void_());
func->StartTarget()->SetInstructions(
- utils::Vector{b.Add(mod.types.Get<type::F32>(), b.Constant(1_f), b.Constant(2_f)),
+ utils::Vector{b.CreateBinary(params.kind, MakeScalarType(params.type),
+ MakeScalarValue(params.type), MakeScalarValue(params.type)),
b.Branch(func->EndTarget())});
generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%6 = OpTypeFloat 32
-%7 = OpConstant %6 1
-%8 = OpConstant %6 2
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-%5 = OpFAdd %6 %7 %8
-OpReturn
-OpFunctionEnd
-)");
+ EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst));
}
+TEST_P(Bitwise, Vector) {
+ auto params = GetParam();
-TEST_F(SpvGeneratorImplTest, Binary_Sub_I32) {
- auto* func = b.CreateFunction("foo", mod.types.Get<type::Void>());
+ auto* func = b.CreateFunction("foo", mod.Types().void_());
func->StartTarget()->SetInstructions(
- utils::Vector{b.Subtract(mod.types.Get<type::I32>(), b.Constant(1_i), b.Constant(2_i)),
+ utils::Vector{b.CreateBinary(params.kind, MakeVectorType(params.type),
+ MakeVectorValue(params.type), MakeVectorValue(params.type)),
+
b.Branch(func->EndTarget())});
generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%6 = OpTypeInt 32 1
-%7 = OpConstant %6 1
-%8 = OpConstant %6 2
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-%5 = OpISub %6 %7 %8
-OpReturn
-OpFunctionEnd
-)");
+ EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst));
}
+INSTANTIATE_TEST_SUITE_P(
+ SpvGeneratorImplTest_Binary_I32,
+ Bitwise,
+ testing::Values(BinaryTestCase{kI32, ir::Binary::Kind::kAnd, "OpBitwiseAnd"},
+ BinaryTestCase{kI32, ir::Binary::Kind::kOr, "OpBitwiseOr"},
+ BinaryTestCase{kI32, ir::Binary::Kind::kXor, "OpBitwiseXor"}));
+INSTANTIATE_TEST_SUITE_P(
+ SpvGeneratorImplTest_Binary_U32,
+ Bitwise,
+ testing::Values(BinaryTestCase{kU32, ir::Binary::Kind::kAnd, "OpBitwiseAnd"},
+ BinaryTestCase{kU32, ir::Binary::Kind::kOr, "OpBitwiseOr"},
+ BinaryTestCase{kU32, ir::Binary::Kind::kXor, "OpBitwiseXor"}));
-TEST_F(SpvGeneratorImplTest, Binary_Sub_U32) {
- auto* func = b.CreateFunction("foo", mod.types.Get<type::Void>());
+using Comparison = BinaryInstructionTest;
+TEST_P(Comparison, Scalar) {
+ auto params = GetParam();
+
+ auto* func = b.CreateFunction("foo", mod.Types().void_());
func->StartTarget()->SetInstructions(
- utils::Vector{b.Subtract(mod.types.Get<type::U32>(), b.Constant(1_u), b.Constant(2_u)),
+ utils::Vector{b.CreateBinary(params.kind, mod.Types().bool_(), MakeScalarValue(params.type),
+ MakeScalarValue(params.type)),
b.Branch(func->EndTarget())});
generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%6 = OpTypeInt 32 0
-%7 = OpConstant %6 1
-%8 = OpConstant %6 2
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-%5 = OpISub %6 %7 %8
-OpReturn
-OpFunctionEnd
-)");
+ EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst));
}
+TEST_P(Comparison, Vector) {
+ auto params = GetParam();
-TEST_F(SpvGeneratorImplTest, Binary_Sub_F32) {
- auto* func = b.CreateFunction("foo", mod.types.Get<type::Void>());
+ auto* func = b.CreateFunction("foo", mod.Types().void_());
func->StartTarget()->SetInstructions(
- utils::Vector{b.Subtract(mod.types.Get<type::F32>(), b.Constant(1_f), b.Constant(2_f)),
+ utils::Vector{b.CreateBinary(params.kind, mod.Types().vec2(mod.Types().bool_()),
+ MakeVectorValue(params.type), MakeVectorValue(params.type)),
+
b.Branch(func->EndTarget())});
generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%6 = OpTypeFloat 32
-%7 = OpConstant %6 1
-%8 = OpConstant %6 2
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-%5 = OpFSub %6 %7 %8
-OpReturn
-OpFunctionEnd
-)");
+ EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst));
}
-
-TEST_F(SpvGeneratorImplTest, Binary_Sub_Vec2i) {
- auto* func = b.CreateFunction("foo", mod.types.Get<type::Void>());
- auto* lhs = b.create<constant::Composite>(mod.types.vec2(mod.types.Get<type::I32>()),
- utils::Vector{b.I32(42), b.I32(-1)}, false, false);
- auto* rhs = b.create<constant::Composite>(mod.types.vec2(mod.types.Get<type::I32>()),
- utils::Vector{b.I32(0), b.I32(-43)}, false, false);
- func->StartTarget()->SetInstructions(
- utils::Vector{b.Subtract(mod.types.Get<type::Vector>(mod.types.Get<type::I32>(), 2u),
- b.Constant(lhs), b.Constant(rhs)),
- b.Branch(func->EndTarget())});
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%7 = OpTypeInt 32 1
-%6 = OpTypeVector %7 2
-%9 = OpConstant %7 42
-%10 = OpConstant %7 -1
-%8 = OpConstantComposite %6 %9 %10
-%12 = OpConstant %7 0
-%13 = OpConstant %7 -43
-%11 = OpConstantComposite %6 %12 %13
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-%5 = OpISub %6 %8 %11
-OpReturn
-OpFunctionEnd
-)");
-}
-
-TEST_F(SpvGeneratorImplTest, Binary_Sub_Vec4f) {
- auto* func = b.CreateFunction("foo", mod.types.Get<type::Void>());
- auto* lhs = b.create<constant::Composite>(
- mod.types.vec4(mod.types.Get<type::F32>()),
- utils::Vector{b.F32(42), b.F32(-1), b.F32(0), b.F32(1.25)}, false, false);
- auto* rhs = b.create<constant::Composite>(
- mod.types.vec4(mod.types.Get<type::F32>()),
- utils::Vector{b.F32(0), b.F32(1.25), b.F32(-42), b.F32(1)}, false, false);
- func->StartTarget()->SetInstructions(
- utils::Vector{b.Subtract(mod.types.Get<type::Vector>(mod.types.Get<type::F32>(), 4u),
- b.Constant(lhs), b.Constant(rhs)),
- b.Branch(func->EndTarget())});
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%7 = OpTypeFloat 32
-%6 = OpTypeVector %7 4
-%9 = OpConstant %7 42
-%10 = OpConstant %7 -1
-%11 = OpConstant %7 0
-%12 = OpConstant %7 1.25
-%8 = OpConstantComposite %6 %9 %10 %11 %12
-%14 = OpConstant %7 -42
-%15 = OpConstant %7 1
-%13 = OpConstantComposite %6 %11 %12 %14 %15
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-%5 = OpFSub %6 %8 %13
-OpReturn
-OpFunctionEnd
-)");
-}
+INSTANTIATE_TEST_SUITE_P(
+ SpvGeneratorImplTest_Binary_I32,
+ Comparison,
+ testing::Values(BinaryTestCase{kI32, ir::Binary::Kind::kEqual, "OpIEqual"},
+ BinaryTestCase{kI32, ir::Binary::Kind::kNotEqual, "OpINotEqual"},
+ BinaryTestCase{kI32, ir::Binary::Kind::kGreaterThan, "OpSGreaterThan"},
+ BinaryTestCase{kI32, ir::Binary::Kind::kGreaterThanEqual,
+ "OpSGreaterThanEqual"},
+ BinaryTestCase{kI32, ir::Binary::Kind::kLessThan, "OpSLessThan"},
+ BinaryTestCase{kI32, ir::Binary::Kind::kLessThanEqual, "OpSLessThanEqual"}));
+INSTANTIATE_TEST_SUITE_P(
+ SpvGeneratorImplTest_Binary_U32,
+ Comparison,
+ testing::Values(BinaryTestCase{kU32, ir::Binary::Kind::kEqual, "OpIEqual"},
+ BinaryTestCase{kU32, ir::Binary::Kind::kNotEqual, "OpINotEqual"},
+ BinaryTestCase{kU32, ir::Binary::Kind::kGreaterThan, "OpUGreaterThan"},
+ BinaryTestCase{kU32, ir::Binary::Kind::kGreaterThanEqual,
+ "OpUGreaterThanEqual"},
+ BinaryTestCase{kU32, ir::Binary::Kind::kLessThan, "OpULessThan"},
+ BinaryTestCase{kU32, ir::Binary::Kind::kLessThanEqual, "OpULessThanEqual"}));
+INSTANTIATE_TEST_SUITE_P(
+ SpvGeneratorImplTest_Binary_F32,
+ Comparison,
+ testing::Values(BinaryTestCase{kF32, ir::Binary::Kind::kEqual, "OpFOrdEqual"},
+ BinaryTestCase{kF32, ir::Binary::Kind::kNotEqual, "OpFOrdNotEqual"},
+ BinaryTestCase{kF32, ir::Binary::Kind::kGreaterThan, "OpFOrdGreaterThan"},
+ BinaryTestCase{kF32, ir::Binary::Kind::kGreaterThanEqual,
+ "OpFOrdGreaterThanEqual"},
+ BinaryTestCase{kF32, ir::Binary::Kind::kLessThan, "OpFOrdLessThan"},
+ BinaryTestCase{kF32, ir::Binary::Kind::kLessThanEqual, "OpFOrdLessThanEqual"}));
+INSTANTIATE_TEST_SUITE_P(
+ SpvGeneratorImplTest_Binary_F16,
+ Comparison,
+ testing::Values(BinaryTestCase{kF16, ir::Binary::Kind::kEqual, "OpFOrdEqual"},
+ BinaryTestCase{kF16, ir::Binary::Kind::kNotEqual, "OpFOrdNotEqual"},
+ BinaryTestCase{kF16, ir::Binary::Kind::kGreaterThan, "OpFOrdGreaterThan"},
+ BinaryTestCase{kF16, ir::Binary::Kind::kGreaterThanEqual,
+ "OpFOrdGreaterThanEqual"},
+ BinaryTestCase{kF16, ir::Binary::Kind::kLessThan, "OpFOrdLessThan"},
+ BinaryTestCase{kF16, ir::Binary::Kind::kLessThanEqual, "OpFOrdLessThanEqual"}));
+INSTANTIATE_TEST_SUITE_P(
+ SpvGeneratorImplTest_Binary_Bool,
+ Comparison,
+ testing::Values(BinaryTestCase{kBool, ir::Binary::Kind::kEqual, "OpLogicalEqual"},
+ BinaryTestCase{kBool, ir::Binary::Kind::kNotEqual, "OpLogicalNotEqual"}));
TEST_F(SpvGeneratorImplTest, Binary_Chain) {
- auto* func = b.CreateFunction("foo", mod.types.Get<type::Void>());
- auto* a = b.Subtract(mod.types.Get<type::I32>(), b.Constant(1_i), b.Constant(2_i));
+ auto* func = b.CreateFunction("foo", mod.Types().void_());
+ auto* a = b.Subtract(mod.Types().i32(), b.Constant(1_i), b.Constant(2_i));
func->StartTarget()->SetInstructions(
- utils::Vector{a, b.Add(mod.types.Get<type::I32>(), a, a), b.Branch(func->EndTarget())});
+ utils::Vector{a, b.Add(mod.Types().i32(), a, a), b.Branch(func->EndTarget())});
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_constant_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_constant_test.cc
index 75775ff..7cb240a 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_constant_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_constant_test.cc
@@ -63,9 +63,10 @@
}
TEST_F(SpvGeneratorImplTest, Constant_Vec4Bool) {
- auto* v = b.create<constant::Composite>(
- mod.types.vec4(mod.types.Get<type::Bool>()),
- utils::Vector{b.Bool(true), b.Bool(false), b.Bool(false), b.Bool(true)}, false, true);
+ auto const_bool = [&](bool val) { return mod.constant_values.Get(val); };
+ auto* v = mod.constant_values.Composite(
+ mod.Types().vec4(mod.Types().bool_()),
+ utils::Vector{const_bool(true), const_bool(false), const_bool(false), const_bool(true)});
generator_.Constant(b.Constant(v));
EXPECT_EQ(DumpTypes(), R"(%3 = OpTypeBool
@@ -77,8 +78,9 @@
}
TEST_F(SpvGeneratorImplTest, Constant_Vec2i) {
- auto* v = b.create<constant::Composite>(mod.types.vec2(mod.types.Get<type::I32>()),
- utils::Vector{b.I32(42), b.I32(-1)}, false, false);
+ auto const_i32 = [&](float val) { return mod.constant_values.Get(i32(val)); };
+ auto* v = mod.constant_values.Composite(mod.Types().vec2(mod.Types().i32()),
+ utils::Vector{const_i32(42), const_i32(-1)});
generator_.Constant(b.Constant(v));
EXPECT_EQ(DumpTypes(), R"(%3 = OpTypeInt 32 1
%2 = OpTypeVector %3 2
@@ -89,9 +91,10 @@
}
TEST_F(SpvGeneratorImplTest, Constant_Vec3u) {
- auto* v = b.create<constant::Composite>(mod.types.vec3(mod.types.Get<type::U32>()),
- utils::Vector{b.U32(42), b.U32(0), b.U32(4000000000)},
- false, true);
+ auto const_u32 = [&](float val) { return mod.constant_values.Get(u32(val)); };
+ auto* v = mod.constant_values.Composite(
+ mod.Types().vec3(mod.Types().u32()),
+ utils::Vector{const_u32(42), const_u32(0), const_u32(4000000000)});
generator_.Constant(b.Constant(v));
EXPECT_EQ(DumpTypes(), R"(%3 = OpTypeInt 32 0
%2 = OpTypeVector %3 3
@@ -103,9 +106,10 @@
}
TEST_F(SpvGeneratorImplTest, Constant_Vec4f) {
- auto* v = b.create<constant::Composite>(
- mod.types.vec4(mod.types.Get<type::F32>()),
- utils::Vector{b.F32(42), b.F32(0), b.F32(0.25), b.F32(-1)}, false, true);
+ auto const_f32 = [&](float val) { return mod.constant_values.Get(f32(val)); };
+ auto* v = mod.constant_values.Composite(
+ mod.Types().vec4(mod.Types().f32()),
+ utils::Vector{const_f32(42), const_f32(0), const_f32(0.25), const_f32(-1)});
generator_.Constant(b.Constant(v));
EXPECT_EQ(DumpTypes(), R"(%3 = OpTypeFloat 32
%2 = OpTypeVector %3 4
@@ -118,8 +122,9 @@
}
TEST_F(SpvGeneratorImplTest, Constant_Vec2h) {
- auto* v = b.create<constant::Composite>(mod.types.vec2(mod.types.Get<type::F16>()),
- utils::Vector{b.F16(42), b.F16(0.25)}, false, false);
+ auto const_f16 = [&](float val) { return mod.constant_values.Get(f16(val)); };
+ auto* v = mod.constant_values.Composite(mod.Types().vec2(mod.Types().f16()),
+ utils::Vector{const_f16(42), const_f16(0.25)});
generator_.Constant(b.Constant(v));
EXPECT_EQ(DumpTypes(), R"(%3 = OpTypeFloat 16
%2 = OpTypeVector %3 2
@@ -129,6 +134,69 @@
)");
}
+TEST_F(SpvGeneratorImplTest, Constant_Mat2x3f) {
+ auto const_f32 = [&](float val) { return mod.constant_values.Get(f32(val)); };
+ auto* f32 = mod.Types().f32();
+ auto* v = mod.constant_values.Composite(
+ mod.Types().mat2x3(f32),
+ utils::Vector{
+ mod.constant_values.Composite(
+ mod.Types().vec3(f32),
+ utils::Vector{const_f32(42), const_f32(-1), const_f32(0.25)}),
+ mod.constant_values.Composite(
+ mod.Types().vec3(f32),
+ utils::Vector{const_f32(-42), const_f32(0), const_f32(-0.25)}),
+ });
+ generator_.Constant(b.Constant(v));
+ EXPECT_EQ(DumpTypes(), R"(%4 = OpTypeFloat 32
+%3 = OpTypeVector %4 3
+%2 = OpTypeMatrix %3 2
+%6 = OpConstant %4 42
+%7 = OpConstant %4 -1
+%8 = OpConstant %4 0.25
+%5 = OpConstantComposite %3 %6 %7 %8
+%10 = OpConstant %4 -42
+%11 = OpConstant %4 0
+%12 = OpConstant %4 -0.25
+%9 = OpConstantComposite %3 %10 %11 %12
+%1 = OpConstantComposite %2 %5 %9
+)");
+}
+
+TEST_F(SpvGeneratorImplTest, Constant_Mat4x2h) {
+ auto const_f16 = [&](float val) { return mod.constant_values.Get(f16(val)); };
+ auto* f16 = mod.Types().f16();
+ auto* v = mod.constant_values.Composite(
+ mod.Types().mat4x2(f16),
+ utils::Vector{
+ mod.constant_values.Composite(mod.Types().vec2(f16),
+ utils::Vector{const_f16(42), const_f16(-1)}),
+ mod.constant_values.Composite(mod.Types().vec2(f16),
+ utils::Vector{const_f16(0), const_f16(0.25)}),
+ mod.constant_values.Composite(mod.Types().vec2(f16),
+ utils::Vector{const_f16(-42), const_f16(1)}),
+ mod.constant_values.Composite(mod.Types().vec2(f16),
+ utils::Vector{const_f16(0.5), const_f16(-0)}),
+ });
+ generator_.Constant(b.Constant(v));
+ EXPECT_EQ(DumpTypes(), R"(%4 = OpTypeFloat 16
+%3 = OpTypeVector %4 2
+%2 = OpTypeMatrix %3 4
+%6 = OpConstant %4 0x1.5p+5
+%7 = OpConstant %4 -0x1p+0
+%5 = OpConstantComposite %3 %6 %7
+%9 = OpConstant %4 0x0p+0
+%10 = OpConstant %4 0x1p-2
+%8 = OpConstantComposite %3 %9 %10
+%12 = OpConstant %4 -0x1.5p+5
+%13 = OpConstant %4 0x1p+0
+%11 = OpConstantComposite %3 %12 %13
+%15 = OpConstant %4 0x1p-1
+%14 = OpConstantComposite %3 %15 %9
+%1 = OpConstantComposite %2 %5 %8 %11 %14
+)");
+}
+
// Test that we do not emit the same constant more than once.
TEST_F(SpvGeneratorImplTest, Constant_Deduplicate) {
generator_.Constant(b.Constant(i32(42)));
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_function_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_function_test.cc
index 7221963..22b0594 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_function_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_function_test.cc
@@ -18,7 +18,7 @@
namespace {
TEST_F(SpvGeneratorImplTest, Function_Empty) {
- auto* func = b.CreateFunction("foo", mod.types.Get<type::Void>());
+ auto* func = b.CreateFunction("foo", mod.Types().void_());
func->StartTarget()->SetInstructions(utils::Vector{b.Branch(func->EndTarget())});
generator_.EmitFunction(func);
@@ -34,7 +34,7 @@
// Test that we do not emit the same function type more than once.
TEST_F(SpvGeneratorImplTest, Function_DeduplicateType) {
- auto* func = b.CreateFunction("foo", mod.types.Get<type::Void>());
+ auto* func = b.CreateFunction("foo", mod.Types().void_());
func->StartTarget()->SetInstructions(utils::Vector{b.Branch(func->EndTarget())});
generator_.EmitFunction(func);
@@ -46,7 +46,7 @@
}
TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Compute) {
- auto* func = b.CreateFunction("main", mod.types.Get<type::Void>(),
+ auto* func = b.CreateFunction("main", mod.Types().void_(),
ir::Function::PipelineStage::kCompute, {{32, 4, 1}});
func->StartTarget()->SetInstructions(utils::Vector{b.Branch(func->EndTarget())});
@@ -64,8 +64,8 @@
}
TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Fragment) {
- auto* func = b.CreateFunction("main", mod.types.Get<type::Void>(),
- ir::Function::PipelineStage::kFragment);
+ auto* func =
+ b.CreateFunction("main", mod.Types().void_(), ir::Function::PipelineStage::kFragment);
func->StartTarget()->SetInstructions(utils::Vector{b.Branch(func->EndTarget())});
generator_.EmitFunction(func);
@@ -83,7 +83,7 @@
TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Vertex) {
auto* func =
- b.CreateFunction("main", mod.types.Get<type::Void>(), ir::Function::PipelineStage::kVertex);
+ b.CreateFunction("main", mod.Types().void_(), ir::Function::PipelineStage::kVertex);
func->StartTarget()->SetInstructions(utils::Vector{b.Branch(func->EndTarget())});
generator_.EmitFunction(func);
@@ -99,16 +99,16 @@
}
TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Multiple) {
- auto* f1 = b.CreateFunction("main1", mod.types.Get<type::Void>(),
- ir::Function::PipelineStage::kCompute, {{32, 4, 1}});
+ auto* f1 = b.CreateFunction("main1", mod.Types().void_(), ir::Function::PipelineStage::kCompute,
+ {{32, 4, 1}});
f1->StartTarget()->SetInstructions(utils::Vector{b.Branch(f1->EndTarget())});
- auto* f2 = b.CreateFunction("main2", mod.types.Get<type::Void>(),
- ir::Function::PipelineStage::kCompute, {{8, 2, 16}});
+ auto* f2 = b.CreateFunction("main2", mod.Types().void_(), ir::Function::PipelineStage::kCompute,
+ {{8, 2, 16}});
f2->StartTarget()->SetInstructions(utils::Vector{b.Branch(f2->EndTarget())});
- auto* f3 = b.CreateFunction("main3", mod.types.Get<type::Void>(),
- ir::Function::PipelineStage::kFragment);
+ auto* f3 =
+ b.CreateFunction("main3", mod.Types().void_(), ir::Function::PipelineStage::kFragment);
f3->StartTarget()->SetInstructions(utils::Vector{b.Branch(f3->EndTarget())});
generator_.EmitFunction(f1);
@@ -140,5 +140,118 @@
)");
}
+TEST_F(SpvGeneratorImplTest, Function_ReturnValue) {
+ auto* func = b.CreateFunction("foo", mod.Types().i32());
+ func->StartTarget()->SetInstructions(
+ utils::Vector{b.Branch(func->EndTarget(), utils::Vector{b.Constant(i32(42))})});
+
+ generator_.EmitFunction(func);
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
+%2 = OpTypeInt 32 1
+%3 = OpTypeFunction %2
+%5 = OpConstant %2 42
+%1 = OpFunction %2 None %3
+%4 = OpLabel
+OpReturnValue %5
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest, Function_Parameters) {
+ auto* i32 = mod.Types().i32();
+ auto* x = b.FunctionParam(i32);
+ auto* y = b.FunctionParam(i32);
+ auto* result = b.Add(i32, x, y);
+ auto* func = b.CreateFunction("foo", i32);
+ func->SetParams(utils::Vector{x, y});
+ func->StartTarget()->SetInstructions(
+ utils::Vector{result, b.Branch(func->EndTarget(), utils::Vector{result})});
+ mod.SetName(x, "x");
+ mod.SetName(y, "y");
+
+ generator_.EmitFunction(func);
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
+OpName %3 "x"
+OpName %4 "y"
+%2 = OpTypeInt 32 1
+%5 = OpTypeFunction %2 %2 %2
+%1 = OpFunction %2 None %5
+%3 = OpFunctionParameter %2
+%4 = OpFunctionParameter %2
+%6 = OpLabel
+%7 = OpIAdd %2 %3 %4
+OpReturnValue %7
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest, Function_Call) {
+ auto* i32_ty = mod.Types().i32();
+ auto* x = b.FunctionParam(i32_ty);
+ auto* y = b.FunctionParam(i32_ty);
+ auto* result = b.Add(i32_ty, x, y);
+ auto* foo = b.CreateFunction("foo", i32_ty);
+ foo->SetParams(utils::Vector{x, y});
+ foo->StartTarget()->SetInstructions(
+ utils::Vector{result, b.Branch(foo->EndTarget(), utils::Vector{result})});
+
+ auto* bar = b.CreateFunction("bar", mod.Types().void_());
+ bar->StartTarget()->SetInstructions(
+ utils::Vector{b.UserCall(i32_ty, mod.symbols.Get("foo"),
+ utils::Vector{b.Constant(i32(2)), b.Constant(i32(3))}),
+ b.Branch(bar->EndTarget())});
+
+ generator_.EmitFunction(foo);
+ generator_.EmitFunction(bar);
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
+OpName %8 "bar"
+%2 = OpTypeInt 32 1
+%5 = OpTypeFunction %2 %2 %2
+%9 = OpTypeVoid
+%10 = OpTypeFunction %9
+%13 = OpConstant %2 2
+%14 = OpConstant %2 3
+%1 = OpFunction %2 None %5
+%3 = OpFunctionParameter %2
+%4 = OpFunctionParameter %2
+%6 = OpLabel
+%7 = OpIAdd %2 %3 %4
+OpReturnValue %7
+OpFunctionEnd
+%8 = OpFunction %9 None %10
+%11 = OpLabel
+%12 = OpFunctionCall %2 %1 %13 %14
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest, Function_Call_Void) {
+ auto* foo = b.CreateFunction("foo", mod.Types().void_());
+ foo->StartTarget()->SetInstructions(utils::Vector{b.Branch(foo->EndTarget())});
+
+ auto* bar = b.CreateFunction("bar", mod.Types().void_());
+ bar->StartTarget()->SetInstructions(
+ utils::Vector{b.UserCall(mod.Types().void_(), mod.symbols.Get("foo"), utils::Empty),
+ b.Branch(bar->EndTarget())});
+
+ generator_.EmitFunction(foo);
+ generator_.EmitFunction(bar);
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
+OpName %5 "bar"
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%1 = OpFunction %2 None %3
+%4 = OpLabel
+OpReturn
+OpFunctionEnd
+%5 = OpFunction %2 None %3
+%6 = OpLabel
+%7 = OpFunctionCall %2 %1
+OpReturn
+OpFunctionEnd
+)");
+}
+
} // namespace
} // namespace tint::writer::spirv
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_if_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_if_test.cc
index dc312d1..ff7defc 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_if_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_if_test.cc
@@ -20,7 +20,7 @@
namespace {
TEST_F(SpvGeneratorImplTest, If_TrueEmpty_FalseEmpty) {
- auto* func = b.CreateFunction("foo", mod.types.Get<type::Void>());
+ auto* func = b.CreateFunction("foo", mod.Types().void_());
auto* i = b.CreateIf(b.Constant(true));
i->True()->SetInstructions(utils::Vector{b.Branch(i->Merge())});
@@ -46,7 +46,7 @@
}
TEST_F(SpvGeneratorImplTest, If_FalseEmpty) {
- auto* func = b.CreateFunction("foo", mod.types.Get<type::Void>());
+ auto* func = b.CreateFunction("foo", mod.Types().void_());
auto* i = b.CreateIf(b.Constant(true));
i->False()->SetInstructions(utils::Vector{b.Branch(i->Merge())});
@@ -54,7 +54,7 @@
auto* true_block = i->True();
true_block->SetInstructions(utils::Vector{
- b.Add(mod.types.Get<type::I32>(), b.Constant(1_i), b.Constant(1_i)), b.Branch(i->Merge())});
+ b.Add(mod.Types().i32(), b.Constant(1_i), b.Constant(1_i)), b.Branch(i->Merge())});
func->StartTarget()->SetInstructions(utils::Vector{i});
@@ -80,7 +80,7 @@
}
TEST_F(SpvGeneratorImplTest, If_TrueEmpty) {
- auto* func = b.CreateFunction("foo", mod.types.Get<type::Void>());
+ auto* func = b.CreateFunction("foo", mod.Types().void_());
auto* i = b.CreateIf(b.Constant(true));
i->True()->SetInstructions(utils::Vector{b.Branch(i->Merge())});
@@ -88,7 +88,7 @@
auto* false_block = i->False();
false_block->SetInstructions(utils::Vector{
- b.Add(mod.types.Get<type::I32>(), b.Constant(1_i), b.Constant(1_i)), b.Branch(i->Merge())});
+ b.Add(mod.Types().i32(), b.Constant(1_i), b.Constant(1_i)), b.Branch(i->Merge())});
func->StartTarget()->SetInstructions(utils::Vector{i});
@@ -114,7 +114,7 @@
}
TEST_F(SpvGeneratorImplTest, If_BothBranchesReturn) {
- auto* func = b.CreateFunction("foo", mod.types.Get<type::Void>());
+ auto* func = b.CreateFunction("foo", mod.Types().void_());
auto* i = b.CreateIf(b.Constant(true));
i->True()->SetInstructions(utils::Vector{b.Branch(func->EndTarget())});
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_type_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_type_test.cc
index 38bf1f0..80e69c3 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_type_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_type_test.cc
@@ -25,43 +25,43 @@
namespace {
TEST_F(SpvGeneratorImplTest, Type_Void) {
- auto id = generator_.Type(mod.types.Get<type::Void>());
+ auto id = generator_.Type(mod.Types().void_());
EXPECT_EQ(id, 1u);
EXPECT_EQ(DumpTypes(), "%1 = OpTypeVoid\n");
}
TEST_F(SpvGeneratorImplTest, Type_Bool) {
- auto id = generator_.Type(mod.types.Get<type::Bool>());
+ auto id = generator_.Type(mod.Types().bool_());
EXPECT_EQ(id, 1u);
EXPECT_EQ(DumpTypes(), "%1 = OpTypeBool\n");
}
TEST_F(SpvGeneratorImplTest, Type_I32) {
- auto id = generator_.Type(mod.types.Get<type::I32>());
+ auto id = generator_.Type(mod.Types().i32());
EXPECT_EQ(id, 1u);
EXPECT_EQ(DumpTypes(), "%1 = OpTypeInt 32 1\n");
}
TEST_F(SpvGeneratorImplTest, Type_U32) {
- auto id = generator_.Type(mod.types.Get<type::U32>());
+ auto id = generator_.Type(mod.Types().u32());
EXPECT_EQ(id, 1u);
EXPECT_EQ(DumpTypes(), "%1 = OpTypeInt 32 0\n");
}
TEST_F(SpvGeneratorImplTest, Type_F32) {
- auto id = generator_.Type(mod.types.Get<type::F32>());
+ auto id = generator_.Type(mod.Types().f32());
EXPECT_EQ(id, 1u);
EXPECT_EQ(DumpTypes(), "%1 = OpTypeFloat 32\n");
}
TEST_F(SpvGeneratorImplTest, Type_F16) {
- auto id = generator_.Type(mod.types.Get<type::F16>());
+ auto id = generator_.Type(mod.Types().f16());
EXPECT_EQ(id, 1u);
EXPECT_EQ(DumpTypes(), "%1 = OpTypeFloat 16\n");
}
TEST_F(SpvGeneratorImplTest, Type_Vec2i) {
- auto* vec = b.ir.types.Get<type::Vector>(b.ir.types.Get<type::I32>(), 2u);
+ auto* vec = mod.Types().Get<type::Vector>(mod.Types().i32(), 2u);
auto id = generator_.Type(vec);
EXPECT_EQ(id, 1u);
EXPECT_EQ(DumpTypes(),
@@ -70,7 +70,7 @@
}
TEST_F(SpvGeneratorImplTest, Type_Vec3u) {
- auto* vec = b.ir.types.Get<type::Vector>(b.ir.types.Get<type::U32>(), 3u);
+ auto* vec = mod.Types().Get<type::Vector>(mod.Types().u32(), 3u);
auto id = generator_.Type(vec);
EXPECT_EQ(id, 1u);
EXPECT_EQ(DumpTypes(),
@@ -79,7 +79,7 @@
}
TEST_F(SpvGeneratorImplTest, Type_Vec4f) {
- auto* vec = b.ir.types.Get<type::Vector>(b.ir.types.Get<type::F32>(), 4u);
+ auto* vec = mod.Types().Get<type::Vector>(mod.Types().f32(), 4u);
auto id = generator_.Type(vec);
EXPECT_EQ(id, 1u);
EXPECT_EQ(DumpTypes(),
@@ -88,7 +88,7 @@
}
TEST_F(SpvGeneratorImplTest, Type_Vec4h) {
- auto* vec = b.ir.types.Get<type::Vector>(b.ir.types.Get<type::F16>(), 2u);
+ auto* vec = mod.Types().Get<type::Vector>(mod.Types().f16(), 2u);
auto id = generator_.Type(vec);
EXPECT_EQ(id, 1u);
EXPECT_EQ(DumpTypes(),
@@ -97,7 +97,7 @@
}
TEST_F(SpvGeneratorImplTest, Type_Vec4Bool) {
- auto* vec = b.ir.types.Get<type::Vector>(b.ir.types.Get<type::Bool>(), 4u);
+ auto* vec = mod.Types().Get<type::Vector>(mod.Types().bool_(), 4u);
auto id = generator_.Type(vec);
EXPECT_EQ(id, 1u);
EXPECT_EQ(DumpTypes(),
@@ -105,13 +105,33 @@
"%1 = OpTypeVector %2 4\n");
}
+TEST_F(SpvGeneratorImplTest, Type_Mat2x3f) {
+ auto* vec = mod.Types().mat2x3(mod.Types().f32());
+ auto id = generator_.Type(vec);
+ EXPECT_EQ(id, 1u);
+ EXPECT_EQ(DumpTypes(),
+ "%3 = OpTypeFloat 32\n"
+ "%2 = OpTypeVector %3 3\n"
+ "%1 = OpTypeMatrix %2 2\n");
+}
+
+TEST_F(SpvGeneratorImplTest, Type_Mat4x2h) {
+ auto* vec = mod.Types().mat4x2(mod.Types().f16());
+ auto id = generator_.Type(vec);
+ EXPECT_EQ(id, 1u);
+ EXPECT_EQ(DumpTypes(),
+ "%3 = OpTypeFloat 16\n"
+ "%2 = OpTypeVector %3 2\n"
+ "%1 = OpTypeMatrix %2 4\n");
+}
+
// Test that we can emit multiple types.
// Includes types with the same opcode but different parameters.
TEST_F(SpvGeneratorImplTest, Type_Multiple) {
- EXPECT_EQ(generator_.Type(mod.types.Get<type::I32>()), 1u);
- EXPECT_EQ(generator_.Type(mod.types.Get<type::U32>()), 2u);
- EXPECT_EQ(generator_.Type(mod.types.Get<type::F32>()), 3u);
- EXPECT_EQ(generator_.Type(mod.types.Get<type::F16>()), 4u);
+ EXPECT_EQ(generator_.Type(mod.Types().i32()), 1u);
+ EXPECT_EQ(generator_.Type(mod.Types().u32()), 2u);
+ EXPECT_EQ(generator_.Type(mod.Types().f32()), 3u);
+ EXPECT_EQ(generator_.Type(mod.Types().f16()), 4u);
EXPECT_EQ(DumpTypes(), R"(%1 = OpTypeInt 32 1
%2 = OpTypeInt 32 0
%3 = OpTypeFloat 32
@@ -121,7 +141,7 @@
// Test that we do not emit the same type more than once.
TEST_F(SpvGeneratorImplTest, Type_Deduplicate) {
- auto* i32 = mod.types.Get<type::I32>();
+ auto* i32 = mod.Types().i32();
EXPECT_EQ(generator_.Type(i32), 1u);
EXPECT_EQ(generator_.Type(i32), 1u);
EXPECT_EQ(generator_.Type(i32), 1u);
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc
index 6c9fbc3..4a42bdc 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc
@@ -21,10 +21,10 @@
namespace {
TEST_F(SpvGeneratorImplTest, FunctionVar_NoInit) {
- auto* func = b.CreateFunction("foo", mod.types.Get<type::Void>());
+ auto* func = b.CreateFunction("foo", mod.Types().void_());
- auto* ty = mod.types.Get<type::Pointer>(
- mod.types.Get<type::I32>(), builtin::AddressSpace::kFunction, builtin::Access::kReadWrite);
+ auto* ty = mod.Types().Get<type::Pointer>(mod.Types().i32(), builtin::AddressSpace::kFunction,
+ builtin::Access::kReadWrite);
func->StartTarget()->SetInstructions(utils::Vector{b.Declare(ty), b.Branch(func->EndTarget())});
generator_.EmitFunction(func);
@@ -42,10 +42,10 @@
}
TEST_F(SpvGeneratorImplTest, FunctionVar_WithInit) {
- auto* func = b.CreateFunction("foo", mod.types.Get<type::Void>());
+ auto* func = b.CreateFunction("foo", mod.Types().void_());
- auto* ty = mod.types.Get<type::Pointer>(
- mod.types.Get<type::I32>(), builtin::AddressSpace::kFunction, builtin::Access::kReadWrite);
+ auto* ty = mod.Types().Get<type::Pointer>(mod.Types().i32(), builtin::AddressSpace::kFunction,
+ builtin::Access::kReadWrite);
auto* v = b.Declare(ty);
v->SetInitializer(b.Constant(42_i));
@@ -68,10 +68,10 @@
}
TEST_F(SpvGeneratorImplTest, FunctionVar_Name) {
- auto* func = b.CreateFunction("foo", mod.types.Get<type::Void>());
+ auto* func = b.CreateFunction("foo", mod.Types().void_());
- auto* ty = mod.types.Get<type::Pointer>(
- mod.types.Get<type::I32>(), builtin::AddressSpace::kFunction, builtin::Access::kReadWrite);
+ auto* ty = mod.Types().Get<type::Pointer>(mod.Types().i32(), builtin::AddressSpace::kFunction,
+ builtin::Access::kReadWrite);
auto* v = b.Declare(ty);
func->StartTarget()->SetInstructions(utils::Vector{v, b.Branch(func->EndTarget())});
mod.SetName(v, "myvar");
@@ -92,10 +92,10 @@
}
TEST_F(SpvGeneratorImplTest, FunctionVar_DeclInsideBlock) {
- auto* func = b.CreateFunction("foo", mod.types.Get<type::Void>());
+ auto* func = b.CreateFunction("foo", mod.Types().void_());
- auto* ty = mod.types.Get<type::Pointer>(
- mod.types.Get<type::I32>(), builtin::AddressSpace::kFunction, builtin::Access::kReadWrite);
+ auto* ty = mod.Types().Get<type::Pointer>(mod.Types().i32(), builtin::AddressSpace::kFunction,
+ builtin::Access::kReadWrite);
auto* v = b.Declare(ty);
v->SetInitializer(b.Constant(42_i));
@@ -132,11 +132,11 @@
}
TEST_F(SpvGeneratorImplTest, FunctionVar_Load) {
- auto* func = b.CreateFunction("foo", mod.types.Get<type::Void>());
+ auto* func = b.CreateFunction("foo", mod.Types().void_());
- auto* store_ty = mod.types.Get<type::I32>();
- auto* ty = mod.types.Get<type::Pointer>(store_ty, builtin::AddressSpace::kFunction,
- builtin::Access::kReadWrite);
+ auto* store_ty = mod.Types().i32();
+ auto* ty = mod.Types().Get<type::Pointer>(store_ty, builtin::AddressSpace::kFunction,
+ builtin::Access::kReadWrite);
auto* v = b.Declare(ty);
func->StartTarget()->SetInstructions(utils::Vector{v, b.Load(v), b.Branch(func->EndTarget())});
@@ -156,10 +156,10 @@
}
TEST_F(SpvGeneratorImplTest, FunctionVar_Store) {
- auto* func = b.CreateFunction("foo", mod.types.Get<type::Void>());
+ auto* func = b.CreateFunction("foo", mod.Types().void_());
- auto* ty = mod.types.Get<type::Pointer>(
- mod.types.Get<type::I32>(), builtin::AddressSpace::kFunction, builtin::Access::kReadWrite);
+ auto* ty = mod.Types().Get<type::Pointer>(mod.Types().i32(), builtin::AddressSpace::kFunction,
+ builtin::Access::kReadWrite);
auto* v = b.Declare(ty);
func->StartTarget()->SetInstructions(
utils::Vector{v, b.Store(v, b.Constant(42_i)), b.Branch(func->EndTarget())});