Import Tint changes from Dawn
Changes:
- 5ccafa4dece04658608a39c5eaeddd83d00fc75d [ir] Add a validator by dan sinclair <dsinclair@chromium.org>
- 68bfd6db9d27dd28104026c4130d6e7ab832bd4c [tint] Add assertions for iterator invalidation by James Price <jrprice@google.com>
- 5682a629756e1efe36a1179505cc9877e6abcc6a [ir][spirv-writer] Implement constant arrays by James Price <jrprice@google.com>
- 41d7513507a99e17ff82acdb11ab85014eea96d9 [ir][spirv-writer] Add type manager alias in tests by James Price <jrprice@google.com>
- 319dd5c4a66b93af0845eed119a8437778b79b98 [ir] Rename Instruction::Replace to ReplaceWith by James Price <jrprice@google.com>
- 0b454a372c9acc37c36414fea2b21577a4ec78c1 [ir] Fix type for accessor on LHS of assignment by James Price <jrprice@google.com>
- ef390b9d12369093d0b8cf8d75d381b824804dbc [ir] Cleanup ir test helper. by dan sinclair <dsinclair@chromium.org>
- f4a3d2220d9422c8cb6bfbd1e1edc8021de6f4d5 [ir] Fix access disassembly. by dan sinclair <dsinclair@chromium.org>
- 9c2369afe0b478746070fb835cc7bd50467257e0 [ir] Add multi-element swizzle. by dan sinclair <dsinclair@chromium.org>
- b8f7ad179867715d25421f6629a4a435552af141 [tint][reader][wgsl]: Improve diagnostics for template er... by Ben Clayton <bclayton@google.com>
- 8037f42c0756cf137702a90a31d685de0a522716 [ir] Fix let accessor tests by dan sinclair <dsinclair@chromium.org>
- d39a400f98d782c176381d1bdf505092ddd4fc16 [tint][reader][wgsl]: Remove source from Expect by Ben Clayton <bclayton@google.com>
- fb4d0ae372378bdde98288be43edd450a9fcdd8c [ir] Add accessors by dan sinclair <dsinclair@chromium.org>
- 0f8b7043dda6d7b3d7258c39e0de1a8114155355 [tint][reader][wgsl]: Unify expression list parsing. by Ben Clayton <bclayton@google.com>
GitOrigin-RevId: 5ccafa4dece04658608a39c5eaeddd83d00fc75d
Change-Id: Id20cebb88c89d159f09bca11255937f9ee4ff59a
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/135860
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index 11f6a78..afd91d6 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -1209,6 +1209,8 @@
libtint_source_set("libtint_ir_src") {
sources = [
+ "ir/access.cc",
+ "ir/access.h",
"ir/binary.cc",
"ir/binary.h",
"ir/binding_point.h",
@@ -1269,12 +1271,16 @@
"ir/store.h",
"ir/switch.cc",
"ir/switch.h",
+ "ir/swizzle.cc",
+ "ir/swizzle.h",
"ir/transform/transform.cc",
"ir/transform/transform.h",
"ir/unary.cc",
"ir/unary.h",
"ir/user_call.cc",
"ir/user_call.h",
+ "ir/validate.cc",
+ "ir/validate.h",
"ir/value.cc",
"ir/value.h",
"ir/var.cc",
@@ -2284,6 +2290,7 @@
"ir/block_test.cc",
"ir/constant_test.cc",
"ir/discard_test.cc",
+ "ir/from_program_accessor_test.cc",
"ir/from_program_binary_test.cc",
"ir/from_program_builtin_test.cc",
"ir/from_program_call_test.cc",
@@ -2295,13 +2302,15 @@
"ir/from_program_unary_test.cc",
"ir/from_program_var_test.cc",
"ir/instruction_test.cc",
+ "ir/ir_test_helper.h",
"ir/load_test.cc",
"ir/module_test.cc",
+ "ir/program_test_helper.h",
"ir/store_test.cc",
- "ir/test_helper.h",
"ir/to_program_roundtrip_test.cc",
"ir/transform/add_empty_entry_point_test.cc",
"ir/unary_test.cc",
+ "ir/validate_test.cc",
]
deps = [
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index 40dd3ee..cdcd22a 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -716,6 +716,8 @@
if(${TINT_BUILD_IR})
list(APPEND TINT_LIB_SRCS
+ ir/access.cc
+ ir/access.h
ir/binary.cc
ir/binary.h
ir/binding_point.h
@@ -778,12 +780,16 @@
ir/store.h
ir/switch.cc
ir/switch.h
+ ir/swizzle.cc
+ ir/swizzle.h
ir/to_program.cc
ir/to_program.h
ir/unary.cc
ir/unary.h
ir/user_call.cc
ir/user_call.h
+ ir/validate.cc
+ ir/validate.h
ir/value.cc
ir/value.h
ir/var.cc
@@ -1487,6 +1493,7 @@
ir/block_test.cc
ir/constant_test.cc
ir/discard_test.cc
+ ir/from_program_accessor_test.cc
ir/from_program_binary_test.cc
ir/from_program_builtin_test.cc
ir/from_program_call_test.cc
@@ -1498,12 +1505,14 @@
ir/from_program_unary_test.cc
ir/from_program_var_test.cc
ir/instruction_test.cc
+ ir/ir_test_helper.h
ir/load_test.cc
ir/module_test.cc
+ ir/program_test_helper.h
ir/store_test.cc
- ir/test_helper.h
ir/transform/add_empty_entry_point_test.cc
ir/unary_test.cc
+ ir/validate_test.cc
)
endif()
diff --git a/src/tint/ir/access.cc b/src/tint/ir/access.cc
new file mode 100644
index 0000000..d88ded8
--- /dev/null
+++ b/src/tint/ir/access.cc
@@ -0,0 +1,34 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/access.h"
+
+#include <utility>
+
+#include "src/tint/debug.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ir::Access);
+
+namespace tint::ir {
+
+//! @cond Doxygen_Suppress
+Access::Access(const type::Type* ty, Value* object, utils::VectorRef<Value*> indices)
+ : result_type_(ty), object_(object), indices_(std::move(indices)) {
+ object_->AddUsage(this);
+}
+
+Access::~Access() = default;
+//! @endcond
+
+} // namespace tint::ir
diff --git a/src/tint/ir/access.h b/src/tint/ir/access.h
new file mode 100644
index 0000000..ea03d65
--- /dev/null
+++ b/src/tint/ir/access.h
@@ -0,0 +1,50 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_IR_ACCESS_H_
+#define SRC_TINT_IR_ACCESS_H_
+
+#include "src/tint/ir/instruction.h"
+#include "src/tint/utils/castable.h"
+
+namespace tint::ir {
+
+/// An access instruction in the IR.
+class Access : public utils::Castable<Access, Instruction> {
+ public:
+ /// Constructor
+ /// @param result_type the result type
+ /// @param object the accessor object
+ /// @param indices the indices to access
+ Access(const type::Type* result_type, Value* object, utils::VectorRef<Value*> indices);
+ ~Access() override;
+
+ /// @returns the type of the value
+ const type::Type* Type() const override { return result_type_; }
+
+ /// @returns the object used for the access
+ Value* Object() const { return object_; }
+
+ /// @returns the accessor indices
+ utils::VectorRef<Value*> Indices() const { return indices_; }
+
+ private:
+ const type::Type* result_type_ = nullptr;
+ Value* object_ = nullptr;
+ utils::Vector<Value*, 1> indices_;
+};
+
+} // namespace tint::ir
+
+#endif // SRC_TINT_IR_ACCESS_H_
diff --git a/src/tint/ir/binary_test.cc b/src/tint/ir/binary_test.cc
index 4c4fab8..be78a9d 100644
--- a/src/tint/ir/binary_test.cc
+++ b/src/tint/ir/binary_test.cc
@@ -14,19 +14,16 @@
#include "src/tint/ir/builder.h"
#include "src/tint/ir/instruction.h"
-#include "src/tint/ir/test_helper.h"
+#include "src/tint/ir/ir_test_helper.h"
namespace tint::ir {
namespace {
using namespace tint::number_suffixes; // NOLINT
-using IR_BinaryTest = TestHelper;
+using IR_BinaryTest = IRTestHelper;
TEST_F(IR_BinaryTest, CreateAnd) {
- Module mod;
- Builder b{mod};
-
const auto* inst = b.And(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i));
ASSERT_TRUE(inst->Is<Binary>());
@@ -45,9 +42,6 @@
}
TEST_F(IR_BinaryTest, CreateOr) {
- Module mod;
- Builder b{mod};
-
const auto* inst = b.Or(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i));
ASSERT_TRUE(inst->Is<Binary>());
@@ -65,9 +59,6 @@
}
TEST_F(IR_BinaryTest, CreateXor) {
- Module mod;
- Builder b{mod};
-
const auto* inst = b.Xor(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i));
ASSERT_TRUE(inst->Is<Binary>());
@@ -85,9 +76,6 @@
}
TEST_F(IR_BinaryTest, CreateEqual) {
- Module mod;
- Builder b{mod};
-
const auto* inst = b.Equal(mod.Types().bool_(), b.Constant(4_i), b.Constant(2_i));
ASSERT_TRUE(inst->Is<Binary>());
@@ -105,9 +93,6 @@
}
TEST_F(IR_BinaryTest, CreateNotEqual) {
- Module mod;
- Builder b{mod};
-
const auto* inst = b.NotEqual(mod.Types().bool_(), b.Constant(4_i), b.Constant(2_i));
ASSERT_TRUE(inst->Is<Binary>());
@@ -125,9 +110,6 @@
}
TEST_F(IR_BinaryTest, CreateLessThan) {
- Module mod;
- Builder b{mod};
-
const auto* inst = b.LessThan(mod.Types().bool_(), b.Constant(4_i), b.Constant(2_i));
ASSERT_TRUE(inst->Is<Binary>());
@@ -145,9 +127,6 @@
}
TEST_F(IR_BinaryTest, CreateGreaterThan) {
- Module mod;
- Builder b{mod};
-
const auto* inst = b.GreaterThan(mod.Types().bool_(), b.Constant(4_i), b.Constant(2_i));
ASSERT_TRUE(inst->Is<Binary>());
@@ -165,9 +144,6 @@
}
TEST_F(IR_BinaryTest, CreateLessThanEqual) {
- Module mod;
- Builder b{mod};
-
const auto* inst = b.LessThanEqual(mod.Types().bool_(), b.Constant(4_i), b.Constant(2_i));
ASSERT_TRUE(inst->Is<Binary>());
@@ -185,9 +161,6 @@
}
TEST_F(IR_BinaryTest, CreateGreaterThanEqual) {
- Module mod;
- Builder b{mod};
-
const auto* inst = b.GreaterThanEqual(mod.Types().bool_(), b.Constant(4_i), b.Constant(2_i));
ASSERT_TRUE(inst->Is<Binary>());
@@ -205,8 +178,6 @@
}
TEST_F(IR_BinaryTest, CreateNot) {
- Module mod;
- Builder b{mod};
const auto* inst = b.Not(mod.Types().bool_(), b.Constant(true));
ASSERT_TRUE(inst->Is<Binary>());
@@ -224,9 +195,6 @@
}
TEST_F(IR_BinaryTest, CreateShiftLeft) {
- Module mod;
- Builder b{mod};
-
const auto* inst = b.ShiftLeft(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i));
ASSERT_TRUE(inst->Is<Binary>());
@@ -244,9 +212,6 @@
}
TEST_F(IR_BinaryTest, CreateShiftRight) {
- Module mod;
- Builder b{mod};
-
const auto* inst = b.ShiftRight(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i));
ASSERT_TRUE(inst->Is<Binary>());
@@ -264,9 +229,6 @@
}
TEST_F(IR_BinaryTest, CreateAdd) {
- Module mod;
- Builder b{mod};
-
const auto* inst = b.Add(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i));
ASSERT_TRUE(inst->Is<Binary>());
@@ -284,9 +246,6 @@
}
TEST_F(IR_BinaryTest, CreateSubtract) {
- Module mod;
- Builder b{mod};
-
const auto* inst = b.Subtract(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i));
ASSERT_TRUE(inst->Is<Binary>());
@@ -304,9 +263,6 @@
}
TEST_F(IR_BinaryTest, CreateMultiply) {
- Module mod;
- Builder b{mod};
-
const auto* inst = b.Multiply(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i));
ASSERT_TRUE(inst->Is<Binary>());
@@ -324,9 +280,6 @@
}
TEST_F(IR_BinaryTest, CreateDivide) {
- Module mod;
- Builder b{mod};
-
const auto* inst = b.Divide(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i));
ASSERT_TRUE(inst->Is<Binary>());
@@ -344,9 +297,6 @@
}
TEST_F(IR_BinaryTest, CreateModulo) {
- Module mod;
- Builder b{mod};
-
const auto* inst = b.Modulo(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i));
ASSERT_TRUE(inst->Is<Binary>());
@@ -364,8 +314,6 @@
}
TEST_F(IR_BinaryTest, Binary_Usage) {
- Module mod;
- Builder b{mod};
const auto* inst = b.And(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i));
EXPECT_EQ(inst->Kind(), Binary::Kind::kAnd);
@@ -380,8 +328,6 @@
}
TEST_F(IR_BinaryTest, Binary_Usage_DuplicateValue) {
- Module mod;
- Builder b{mod};
auto val = b.Constant(4_i);
const auto* inst = b.And(mod.Types().i32(), val, val);
diff --git a/src/tint/ir/bitcast_test.cc b/src/tint/ir/bitcast_test.cc
index d63853f..dd9209f 100644
--- a/src/tint/ir/bitcast_test.cc
+++ b/src/tint/ir/bitcast_test.cc
@@ -15,18 +15,16 @@
#include "src/tint/ir/builder.h"
#include "src/tint/ir/constant.h"
#include "src/tint/ir/instruction.h"
-#include "src/tint/ir/test_helper.h"
+#include "src/tint/ir/ir_test_helper.h"
namespace tint::ir {
namespace {
using namespace tint::number_suffixes; // NOLINT
-using IR_BitcastTest = TestHelper;
+using IR_BitcastTest = IRTestHelper;
TEST_F(IR_BitcastTest, Bitcast) {
- Module mod;
- Builder b{mod};
const auto* inst = b.Bitcast(mod.Types().i32(), b.Constant(4_i));
ASSERT_TRUE(inst->Is<ir::Bitcast>());
@@ -41,8 +39,6 @@
}
TEST_F(IR_BitcastTest, Bitcast_Usage) {
- Module mod;
- Builder b{mod};
const auto* inst = b.Bitcast(mod.Types().i32(), b.Constant(4_i));
const auto args = inst->Args();
diff --git a/src/tint/ir/block_test.cc b/src/tint/ir/block_test.cc
index adce60f..ce48026 100644
--- a/src/tint/ir/block_test.cc
+++ b/src/tint/ir/block_test.cc
@@ -14,18 +14,12 @@
#include "src/tint/ir/block.h"
#include "gtest/gtest-spi.h"
-#include "gtest/gtest.h"
-#include "src/tint/ir/builder.h"
-#include "src/tint/ir/module.h"
+#include "src/tint/ir/ir_test_helper.h"
namespace tint::ir {
namespace {
-class IR_BlockTest : public ::testing::Test {
- public:
- Module mod;
- Builder b{mod};
-};
+using IR_BlockTest = IRTestHelper;
TEST_F(IR_BlockTest, SetInstructions) {
auto* inst1 = b.CreateLoop();
diff --git a/src/tint/ir/branch.cc b/src/tint/ir/branch.cc
index 191831f..2c59c2c 100644
--- a/src/tint/ir/branch.cc
+++ b/src/tint/ir/branch.cc
@@ -24,6 +24,7 @@
Branch::Branch(utils::VectorRef<Value*> args) : args_(std::move(args)) {
for (auto* arg : args) {
+ TINT_ASSERT(IR, arg);
arg->AddUsage(this);
}
}
diff --git a/src/tint/ir/builder.cc b/src/tint/ir/builder.cc
index 4706d59..b93cb94 100644
--- a/src/tint/ir/builder.cc
+++ b/src/tint/ir/builder.cc
@@ -241,4 +241,16 @@
return ir.values.Create<ir::FunctionParam>(type);
}
+ir::Access* Builder::Access(const type::Type* type,
+ Value* source,
+ utils::VectorRef<Value*> indices) {
+ return ir.values.Create<ir::Access>(type, source, indices);
+}
+
+ir::Swizzle* Builder::Swizzle(const type::Type* type,
+ Value* source,
+ utils::VectorRef<uint32_t> indices) {
+ return ir.values.Create<ir::Swizzle>(type, source, indices);
+}
+
} // namespace tint::ir
diff --git a/src/tint/ir/builder.h b/src/tint/ir/builder.h
index 092162c..9a554cd 100644
--- a/src/tint/ir/builder.h
+++ b/src/tint/ir/builder.h
@@ -18,6 +18,7 @@
#include <utility>
#include "src/tint/constant/scalar.h"
+#include "src/tint/ir/access.h"
#include "src/tint/ir/binary.h"
#include "src/tint/ir/bitcast.h"
#include "src/tint/ir/block_param.h"
@@ -41,6 +42,7 @@
#include "src/tint/ir/return.h"
#include "src/tint/ir/store.h"
#include "src/tint/ir/switch.h"
+#include "src/tint/ir/swizzle.h"
#include "src/tint/ir/unary.h"
#include "src/tint/ir/user_call.h"
#include "src/tint/ir/value.h"
@@ -49,6 +51,7 @@
#include "src/tint/type/f16.h"
#include "src/tint/type/f32.h"
#include "src/tint/type/i32.h"
+#include "src/tint/type/pointer.h"
#include "src/tint/type/u32.h"
#include "src/tint/type/vector.h"
#include "src/tint/type/void.h"
@@ -388,6 +391,20 @@
/// @returns the value
ir::FunctionParam* FunctionParam(const type::Type* type);
+ /// Creates a new `Access`
+ /// @param type the return type
+ /// @param source the source value
+ /// @param indices the access indices
+ /// @returns the instruction
+ ir::Access* Access(const type::Type* type, Value* source, utils::VectorRef<Value*> indices);
+
+ /// Creates a new `Swizzle`
+ /// @param type the return type
+ /// @param source the source value
+ /// @param indices the access indices
+ /// @returns the instruction
+ ir::Swizzle* Swizzle(const type::Type* type, Value* source, utils::VectorRef<uint32_t> indices);
+
/// Retrieves the root block for the module, creating if necessary
/// @returns the root block
ir::Block* CreateRootBlockIfNeeded();
diff --git a/src/tint/ir/constant_test.cc b/src/tint/ir/constant_test.cc
index 7c3024a..67fd5e7 100644
--- a/src/tint/ir/constant_test.cc
+++ b/src/tint/ir/constant_test.cc
@@ -13,7 +13,7 @@
// limitations under the License.
#include "src/tint/ir/builder.h"
-#include "src/tint/ir/test_helper.h"
+#include "src/tint/ir/ir_test_helper.h"
#include "src/tint/ir/value.h"
namespace tint::ir {
@@ -21,12 +21,9 @@
using namespace tint::number_suffixes; // NOLINT
-using IR_ConstantTest = TestHelper;
+using IR_ConstantTest = IRTestHelper;
TEST_F(IR_ConstantTest, f32) {
- Module mod;
- Builder b{mod};
-
utils::StringStream str;
auto* c = b.Constant(1.2_f);
@@ -40,9 +37,6 @@
}
TEST_F(IR_ConstantTest, f16) {
- Module mod;
- Builder b{mod};
-
utils::StringStream str;
auto* c = b.Constant(1.1_h);
@@ -56,9 +50,6 @@
}
TEST_F(IR_ConstantTest, i32) {
- Module mod;
- Builder b{mod};
-
utils::StringStream str;
auto* c = b.Constant(1_i);
@@ -72,9 +63,6 @@
}
TEST_F(IR_ConstantTest, u32) {
- Module mod;
- Builder b{mod};
-
utils::StringStream str;
auto* c = b.Constant(2_u);
@@ -88,9 +76,6 @@
}
TEST_F(IR_ConstantTest, bool) {
- Module mod;
- Builder b{mod};
-
{
utils::StringStream str;
diff --git a/src/tint/ir/disassembler.cc b/src/tint/ir/disassembler.cc
index 6ca9356..65e1524 100644
--- a/src/tint/ir/disassembler.cc
+++ b/src/tint/ir/disassembler.cc
@@ -18,6 +18,7 @@
#include "src/tint/constant/composite.h"
#include "src/tint/constant/scalar.h"
#include "src/tint/constant/splat.h"
+#include "src/tint/ir/access.h"
#include "src/tint/ir/binary.h"
#include "src/tint/ir/bitcast.h"
#include "src/tint/ir/block.h"
@@ -37,6 +38,7 @@
#include "src/tint/ir/return.h"
#include "src/tint/ir/store.h"
#include "src/tint/ir/switch.h"
+#include "src/tint/ir/swizzle.h"
#include "src/tint/ir/user_call.h"
#include "src/tint/ir/var.h"
#include "src/tint/switch.h"
@@ -292,17 +294,19 @@
out_ << (scalar->ValueAs<bool>() ? "true" : "false");
},
[&](const constant::Splat* splat) {
- out_ << splat->Type()->FriendlyName() << " ";
+ out_ << splat->Type()->FriendlyName() << "(";
emit(splat->Index(0));
+ out_ << ")";
},
[&](const constant::Composite* composite) {
- out_ << composite->Type()->FriendlyName() << " ";
+ out_ << composite->Type()->FriendlyName() << "(";
for (const auto* elem : composite->elements) {
if (elem != composite->elements[0]) {
out_ << ", ";
}
emit(elem);
}
+ out_ << ")";
});
};
emit(constant->Value());
@@ -384,6 +388,42 @@
out_ << std::endl;
},
+ [&](const ir::Access* a) {
+ EmitValueWithType(a);
+ out_ << " = access ";
+ EmitValue(a->Object());
+ out_ << ", ";
+ for (size_t i = 0; i < a->Indices().Length(); ++i) {
+ if (i > 0) {
+ out_ << ", ";
+ }
+ EmitValue(a->Indices()[i]);
+ }
+ out_ << std::endl;
+ },
+ [&](const ir::Swizzle* s) {
+ EmitValueWithType(s);
+ out_ << " = swizzle ";
+ EmitValue(s->Object());
+ out_ << ", ";
+ for (auto idx : s->Indices()) {
+ switch (idx) {
+ case 0:
+ out_ << "x";
+ break;
+ case 1:
+ out_ << "y";
+ break;
+ case 2:
+ out_ << "z";
+ break;
+ case 3:
+ out_ << "w";
+ break;
+ }
+ }
+ out_ << std::endl;
+ },
[&](const ir::Branch* b) { EmitBranch(b); },
[&](Default) { out_ << "Unknown instruction: " << inst->TypeInfo().name; });
}
diff --git a/src/tint/ir/discard_test.cc b/src/tint/ir/discard_test.cc
index a09c2e9..ea264f1 100644
--- a/src/tint/ir/discard_test.cc
+++ b/src/tint/ir/discard_test.cc
@@ -14,17 +14,14 @@
#include "src/tint/ir/builder.h"
#include "src/tint/ir/instruction.h"
-#include "src/tint/ir/test_helper.h"
+#include "src/tint/ir/ir_test_helper.h"
namespace tint::ir {
namespace {
-using IR_DiscardTest = TestHelper;
+using IR_DiscardTest = IRTestHelper;
TEST_F(IR_DiscardTest, Discard) {
- Module mod;
- Builder b{mod};
-
const auto* inst = b.Discard();
ASSERT_TRUE(inst->Is<ir::Discard>());
}
diff --git a/src/tint/ir/from_program.cc b/src/tint/ir/from_program.cc
index a3c36e3..49fbfe2 100644
--- a/src/tint/ir/from_program.cc
+++ b/src/tint/ir/from_program.cc
@@ -17,7 +17,9 @@
#include <iostream>
#include <unordered_map>
#include <utility>
+#include <vector>
+#include "src/tint/ast/accessor_expression.h"
#include "src/tint/ast/alias.h"
#include "src/tint/ast/assignment_statement.h"
#include "src/tint/ast/binary_expression.h"
@@ -42,12 +44,14 @@
#include "src/tint/ast/identifier_expression.h"
#include "src/tint/ast/if_statement.h"
#include "src/tint/ast/increment_decrement_statement.h"
+#include "src/tint/ast/index_accessor_expression.h"
#include "src/tint/ast/int_literal_expression.h"
#include "src/tint/ast/interpolate_attribute.h"
#include "src/tint/ast/invariant_attribute.h"
#include "src/tint/ast/let.h"
#include "src/tint/ast/literal_expression.h"
#include "src/tint/ast/loop_statement.h"
+#include "src/tint/ast/member_accessor_expression.h"
#include "src/tint/ast/override.h"
#include "src/tint/ast/phony_expression.h"
#include "src/tint/ast/return_statement.h"
@@ -80,6 +84,7 @@
#include "src/tint/sem/function.h"
#include "src/tint/sem/load.h"
#include "src/tint/sem/materialize.h"
+#include "src/tint/sem/member_accessor_expression.h"
#include "src/tint/sem/module.h"
#include "src/tint/sem/switch_statement.h"
#include "src/tint/sem/value_constructor.h"
@@ -89,9 +94,11 @@
#include "src/tint/switch.h"
#include "src/tint/type/pointer.h"
#include "src/tint/type/reference.h"
+#include "src/tint/type/struct.h"
#include "src/tint/type/void.h"
#include "src/tint/utils/defer.h"
#include "src/tint/utils/result.h"
+#include "src/tint/utils/reverse.h"
#include "src/tint/utils/scoped_assignment.h"
using namespace tint::number_suffixes; // NOLINT
@@ -476,7 +483,6 @@
(void)EmitExpression(stmt->rhs);
return;
}
-
auto lhs = EmitExpression(stmt->lhs);
if (!lhs) {
return;
@@ -870,6 +876,148 @@
SetBranch(builder_.BreakIf(cond.Get(), current_control->As<ir::Loop>()));
}
+ struct AccessorInfo {
+ Value* object = nullptr;
+ Instruction* result = nullptr;
+ const type::Type* result_type = nullptr;
+ utils::Vector<Value*, 1> indices;
+ };
+
+ utils::Result<Value*> EmitAccess(const ast::AccessorExpression* expr) {
+ std::vector<const ast::Expression*> accessors;
+ const ast::Expression* object = expr;
+ while (true) {
+ if (auto* array = object->As<ast::IndexAccessorExpression>()) {
+ accessors.push_back(object);
+ object = array->object;
+ } else if (auto* member = object->As<ast::MemberAccessorExpression>()) {
+ accessors.push_back(object);
+ object = member->object;
+ } else {
+ break;
+ }
+ }
+
+ AccessorInfo info;
+ {
+ auto res = EmitExpression(object);
+ if (!res) {
+ return utils::Failure;
+ }
+ info.object = res.Get();
+ }
+ info.result_type =
+ program_->Sem().Get(expr)->Type()->UnwrapRef()->Clone(clone_ctx_.type_ctx);
+
+ // The AST chain is `inside-out` compared to what we need, which means the list it generates
+ // is backwards. We need to operate on the list in reverse order to have the correct access
+ // chain.
+ for (auto* accessor : utils::Reverse(accessors)) {
+ bool ok = tint::Switch(
+ accessor,
+ [&](const ast::IndexAccessorExpression* idx) {
+ return GenerateIndexAccessor(idx, info);
+ },
+ [&](const ast::MemberAccessorExpression* member) {
+ return GenerateMemberAccessor(member, info);
+ },
+ [&](Default) {
+ TINT_ICE(Writer, diagnostics_)
+ << "invalid accessor in list: " + std::string(accessor->TypeInfo().name);
+ return false;
+ });
+ if (!ok) {
+ return utils::Failure;
+ }
+ }
+
+ if (!info.indices.IsEmpty()) {
+ info.result = GenerateAccess(info);
+ }
+ return info.result;
+ }
+
+ Instruction* GenerateAccess(const AccessorInfo& info) {
+ // The access result type should match the source result type. If the source is a pointer,
+ // we generate a pointer.
+ const type::Type* ty = nullptr;
+ if (info.object->Type()->Is<type::Pointer>() && !info.result_type->Is<type::Pointer>()) {
+ auto* ptr = info.object->Type()->As<type::Pointer>();
+ ty = builder_.ir.Types().pointer(info.result_type, ptr->AddressSpace(), ptr->Access());
+ } else {
+ ty = info.result_type;
+ }
+
+ auto* a = builder_.Access(ty, info.object, info.indices);
+ current_block_->Append(a);
+ return a;
+ }
+
+ bool GenerateIndexAccessor(const ast::IndexAccessorExpression* expr, AccessorInfo& info) {
+ auto res = EmitExpression(expr->index);
+ if (!res) {
+ return false;
+ }
+
+ info.indices.Push(res.Get());
+ return true;
+ }
+
+ bool GenerateMemberAccessor(const ast::MemberAccessorExpression* expr, AccessorInfo& info) {
+ auto* expr_sem = program_->Sem().Get(expr)->UnwrapLoad();
+
+ return tint::Switch(
+ expr_sem, //
+ [&](const sem::StructMemberAccess* access) {
+ uint32_t idx = access->Member()->Index();
+ info.indices.Push(builder_.Constant(u32(idx)));
+ return true;
+ },
+ [&](const sem::Swizzle* swizzle) {
+ auto& indices = swizzle->Indices();
+
+ // A single element swizzle is just treated as an accessor.
+ if (indices.Length() == 1) {
+ info.indices.Push(builder_.Constant(u32(indices[0])));
+ return true;
+ }
+
+ // Store the result type away, this will be the result of the swizzle, but the
+ // intermediate steps need different result types.
+ auto* result_type = info.result_type;
+
+ // Emit any preceeding member/index accessors
+ if (!info.indices.IsEmpty()) {
+ // The access chain is being split, the initial part of than will have a
+ // resulting type that matches the object being swizzled.
+ info.result_type = swizzle->Object()->Type()->Clone(clone_ctx_.type_ctx);
+ info.object = GenerateAccess(info);
+ info.indices.Clear();
+
+ // If the sub-accessor generated a pointer result, make sure a load is emitted
+ if (auto* ptr = info.object->Type()->As<type::Pointer>()) {
+ auto* load = builder_.Load(info.object);
+ info.result_type = ptr->StoreType();
+ info.object = load;
+ current_block_->Append(load);
+ }
+ }
+
+ info.result = builder_.Swizzle(swizzle->Type()->Clone(clone_ctx_.type_ctx),
+ info.object, std::move(indices));
+ current_block_->Append(info.result);
+
+ info.object = info.result;
+ info.result_type = result_type;
+ return true;
+ },
+ [&](Default) {
+ TINT_ICE(IR, diagnostics_)
+ << "unhandled member index type: " << expr_sem->TypeInfo().name;
+ return false;
+ });
+ }
+
utils::Result<Value*> EmitExpression(const ast::Expression* expr) {
// If this is a value that has been const-eval'd return the result.
auto* sem = program_->Sem().GetVal(expr);
@@ -882,10 +1030,8 @@
}
auto result = tint::Switch(
- expr,
- // [&](const ast::IndexAccessorExpression* a) {
- // TODO(dsinclair): Implement
- // },
+ expr, //
+ [&](const ast::AccessorExpression* a) { return EmitAccess(a); },
[&](const ast::BinaryExpression* b) { return EmitBinary(b); },
[&](const ast::BitcastExpression* b) { return EmitBitcast(b); },
[&](const ast::CallExpression* c) { return EmitCall(c); },
@@ -899,9 +1045,6 @@
return {v};
},
[&](const ast::LiteralExpression* l) { return EmitLiteral(l); },
- // [&](const ast::MemberAccessorExpression* m) {
- // TODO(dsinclair): Implement
- // },
[&](const ast::UnaryOpExpression* u) { return EmitUnary(u); },
// Note, ast::PhonyExpression is explicitly not handled here as it should never get
// into this method. The assignment statement should have filtered it out already.
@@ -917,7 +1060,6 @@
current_block_->Append(load);
return load;
}
-
return result;
}
diff --git a/src/tint/ir/from_program_accessor_test.cc b/src/tint/ir/from_program_accessor_test.cc
new file mode 100644
index 0000000..9da49a1
--- /dev/null
+++ b/src/tint/ir/from_program_accessor_test.cc
@@ -0,0 +1,524 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gmock/gmock.h"
+#include "src/tint/ast/case_selector.h"
+#include "src/tint/ast/int_literal_expression.h"
+#include "src/tint/constant/scalar.h"
+#include "src/tint/ir/block.h"
+#include "src/tint/ir/constant.h"
+#include "src/tint/ir/program_test_helper.h"
+#include "src/tint/ir/var.h"
+
+namespace tint::ir {
+namespace {
+
+using namespace tint::number_suffixes; // NOLINT
+
+using IR_FromProgramAccessorTest = ProgramTestHelper;
+
+TEST_F(IR_FromProgramAccessorTest, Accessor_Var_SingleIndex) {
+ // var a: vec3<u32>
+ // let b = a[2]
+
+ auto* a = Var("a", ty.vec3<u32>(), builtin::AddressSpace::kFunction);
+ auto* expr = Decl(Let("b", IndexAccessor(a, 2_u)));
+ WrapInFunction(Block(utils::Vector{Decl(a), expr}));
+
+ auto m = Build();
+ ASSERT_TRUE(m) << (!m ? m.Failure() : "");
+
+ EXPECT_EQ(Disassemble(m.Get()),
+ R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
+ %b1 = block {
+ %a:ptr<function, vec3<u32>, read_write> = var
+ %3:ptr<function, u32, read_write> = access %a, 2u
+ %b:u32 = load %3
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_FromProgramAccessorTest, Accessor_Var_MultiIndex) {
+ // var a: mat3x4<f32>
+ // let b = a[2][3]
+
+ auto* a = Var("a", ty.mat3x4<f32>(), builtin::AddressSpace::kFunction);
+ auto* expr = Decl(Let("b", IndexAccessor(IndexAccessor(a, 2_u), 3_u)));
+ WrapInFunction(Block(utils::Vector{Decl(a), expr}));
+
+ auto m = Build();
+ ASSERT_TRUE(m) << (!m ? m.Failure() : "");
+
+ EXPECT_EQ(Disassemble(m.Get()),
+ R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
+ %b1 = block {
+ %a:ptr<function, mat3x4<f32>, read_write> = var
+ %3:ptr<function, f32, read_write> = access %a, 2u, 3u
+ %b:f32 = load %3
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_FromProgramAccessorTest, Accessor_Var_SingleMember) {
+ // struct MyStruct { foo: i32 }
+ // var a: MyStruct;
+ // let b = a.foo
+
+ auto* s = Structure("MyStruct", utils::Vector{
+ Member("foo", ty.i32()),
+ });
+ auto* a = Var("a", ty.Of(s), builtin::AddressSpace::kFunction);
+ auto* expr = Decl(Let("b", MemberAccessor(a, "foo")));
+ WrapInFunction(Block(utils::Vector{Decl(a), expr}));
+
+ auto m = Build();
+ ASSERT_TRUE(m) << (!m ? m.Failure() : "");
+
+ EXPECT_EQ(Disassemble(m.Get()),
+ R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
+ %b1 = block {
+ %a:ptr<function, MyStruct, read_write> = var
+ %3:ptr<function, i32, read_write> = access %a, 0u
+ %b:i32 = load %3
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_FromProgramAccessorTest, Accessor_Var_MultiMember) {
+ // struct Inner { bar: f32 }
+ // struct Outer { a: i32, foo: Inner }
+ // var a: Outer;
+ // let b = a.foo.bar
+
+ auto* inner = Structure("Inner", utils::Vector{
+ Member("bar", ty.f32()),
+ });
+ auto* outer = Structure("Outer", utils::Vector{
+ Member("a", ty.i32()),
+ Member("foo", ty.Of(inner)),
+ });
+ auto* a = Var("a", ty.Of(outer), builtin::AddressSpace::kFunction);
+ auto* expr = Decl(Let("b", MemberAccessor(MemberAccessor(a, "foo"), "bar")));
+ WrapInFunction(Block(utils::Vector{Decl(a), expr}));
+
+ auto m = Build();
+ ASSERT_TRUE(m) << (!m ? m.Failure() : "");
+
+ EXPECT_EQ(Disassemble(m.Get()),
+ R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
+ %b1 = block {
+ %a:ptr<function, Outer, read_write> = var
+ %3:ptr<function, f32, read_write> = access %a, 1u, 0u
+ %b:f32 = load %3
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_FromProgramAccessorTest, Accessor_Var_Mixed) {
+ // struct Inner { b: i32, c: f32, bar: vec4<f32> }
+ // struct Outer { a: i32, foo: array<Inner, 4> }
+ // var a: array<Outer, 4>
+ // let b = a[0].foo[1].bar
+
+ auto* inner = Structure("Inner", utils::Vector{
+ Member("b", ty.i32()),
+ Member("c", ty.f32()),
+ Member("bar", ty.vec4<f32>()),
+ });
+ auto* outer = Structure("Outer", utils::Vector{
+ Member("a", ty.i32()),
+ Member("foo", ty.array(ty.Of(inner), 4_u)),
+ });
+ auto* a = Var("a", ty.array(ty.Of(outer), 4_u), builtin::AddressSpace::kFunction);
+ auto* expr = Decl(Let(
+ "b",
+ MemberAccessor(IndexAccessor(MemberAccessor(IndexAccessor(a, 0_u), "foo"), 1_u), "bar")));
+ WrapInFunction(Block(utils::Vector{Decl(a), expr}));
+
+ auto m = Build();
+ ASSERT_TRUE(m) << (!m ? m.Failure() : "");
+
+ EXPECT_EQ(Disassemble(m.Get()),
+ R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
+ %b1 = block {
+ %a:ptr<function, array<Outer, 4>, read_write> = var
+ %3:ptr<function, vec4<f32>, read_write> = access %a, 0u, 1u, 1u, 2u
+ %b:vec4<f32> = load %3
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_FromProgramAccessorTest, Accessor_Var_AssignmentLHS) {
+ // var a: array<u32, 4>();
+ // a[2] = 0;
+
+ auto* a = Var("a", ty.array<u32, 4>(), builtin::AddressSpace::kFunction);
+ auto* assign = Assign(IndexAccessor(a, 2_u), 0_u);
+ WrapInFunction(Block(utils::Vector{Decl(a), assign}));
+
+ auto m = Build();
+ ASSERT_TRUE(m) << (!m ? m.Failure() : "");
+
+ EXPECT_EQ(Disassemble(m.Get()),
+ R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
+ %b1 = block {
+ %a:ptr<function, array<u32, 4>, read_write> = var
+ %3:ptr<function, u32, read_write> = access %a, 2u
+ store %3, 0u
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_FromProgramAccessorTest, Accessor_Var_SingleElementSwizzle) {
+ // var a: vec2<f32>
+ // let b = a.y
+
+ auto* a = Var("a", ty.vec2<f32>(), builtin::AddressSpace::kFunction);
+ auto* expr = Decl(Let("b", MemberAccessor(a, "y")));
+ WrapInFunction(Block(utils::Vector{Decl(a), expr}));
+
+ auto m = Build();
+ ASSERT_TRUE(m) << (!m ? m.Failure() : "");
+
+ EXPECT_EQ(Disassemble(m.Get()),
+ R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
+ %b1 = block {
+ %a:ptr<function, vec2<f32>, read_write> = var
+ %3:ptr<function, f32, read_write> = access %a, 1u
+ %b:f32 = load %3
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_FromProgramAccessorTest, Accessor_Var_MultiElementSwizzle) {
+ // var a: vec3<f32>
+ // let b = a.zyxz
+
+ auto* a = Var("a", ty.vec3<f32>(), builtin::AddressSpace::kFunction);
+ auto* expr = Decl(Let("b", MemberAccessor(a, "zyxz")));
+ WrapInFunction(Block(utils::Vector{Decl(a), expr}));
+
+ auto m = Build();
+ ASSERT_TRUE(m) << (!m ? m.Failure() : "");
+
+ EXPECT_EQ(Disassemble(m.Get()),
+ R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
+ %b1 = block {
+ %a:ptr<function, vec3<f32>, read_write> = var
+ %3:vec3<f32> = load %a
+ %b:vec4<f32> = swizzle %3, zyxz
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_FromProgramAccessorTest, Accessor_Var_MultiElementSwizzleOfSwizzle) {
+ // var a: vec3<f32>
+ // let b = a.zyx.yy
+
+ auto* a = Var("a", ty.vec3<f32>(), builtin::AddressSpace::kFunction);
+ auto* expr = Decl(Let("b", MemberAccessor(MemberAccessor(a, "zyx"), "yy")));
+ WrapInFunction(Block(utils::Vector{Decl(a), expr}));
+
+ auto m = Build();
+ ASSERT_TRUE(m) << (!m ? m.Failure() : "");
+
+ EXPECT_EQ(Disassemble(m.Get()),
+ R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
+ %b1 = block {
+ %a:ptr<function, vec3<f32>, read_write> = var
+ %3:vec3<f32> = load %a
+ %4:vec3<f32> = swizzle %3, zyx
+ %b:vec2<f32> = swizzle %4, yy
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_FromProgramAccessorTest, Accessor_Var_MultiElementSwizzle_MiddleOfChain) {
+ // struct MyStruct { a: i32; foo: vec4<f32> }
+ // var a: MyStruct;
+ // let b = a.foo.zyx.yx[0]
+
+ auto* s = Structure("MyStruct", utils::Vector{
+ Member("a", ty.i32()),
+ Member("foo", ty.vec4<f32>()),
+ });
+ auto* a = Var("a", ty.Of(s), builtin::AddressSpace::kFunction);
+ auto* expr = Decl(Let(
+ "b",
+ IndexAccessor(MemberAccessor(MemberAccessor(MemberAccessor(a, "foo"), "zyx"), "yx"), 0_u)));
+ WrapInFunction(Block(utils::Vector{Decl(a), expr}));
+
+ auto m = Build();
+ ASSERT_TRUE(m) << (!m ? m.Failure() : "");
+
+ EXPECT_EQ(Disassemble(m.Get()),
+ R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
+ %b1 = block {
+ %a:ptr<function, MyStruct, read_write> = var
+ %3:ptr<function, vec4<f32>, read_write> = access %a, 1u
+ %4:vec4<f32> = load %3
+ %5:vec3<f32> = swizzle %4, zyx
+ %6:vec2<f32> = swizzle %5, yx
+ %b:f32 = access %6, 0u
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_FromProgramAccessorTest, Accessor_Let_SingleIndex) {
+ // let a: vec3<u32> = vec3()
+ // let b = a[2]
+ auto* a = Let("a", ty.vec3<u32>(), vec(ty.u32(), 3));
+ auto* expr = Decl(Let("b", IndexAccessor(a, 2_u)));
+ WrapInFunction(Block(utils::Vector{Decl(a), expr}));
+
+ auto m = Build();
+ ASSERT_TRUE(m) << (!m ? m.Failure() : "");
+
+ EXPECT_EQ(Disassemble(m.Get()),
+ R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
+ %b1 = block {
+ %b:u32 = access vec3<u32>(0u), 2u
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_FromProgramAccessorTest, Accessor_Let_MultiIndex) {
+ // let a: mat3x4<f32> = mat3x4<u32>()
+ // let b = a[2][3]
+
+ auto* a = Let("a", ty.mat3x4<f32>(), mat3x4<f32>());
+ auto* expr = Decl(Let("b", IndexAccessor(IndexAccessor(a, 2_u), 3_u)));
+ WrapInFunction(Block(utils::Vector{Decl(a), expr}));
+
+ auto m = Build();
+ ASSERT_TRUE(m) << (!m ? m.Failure() : "");
+
+ EXPECT_EQ(Disassemble(m.Get()),
+ R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
+ %b1 = block {
+ %b:f32 = access mat3x4<f32>(vec4<f32>(0.0f)), 2u, 3u
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_FromProgramAccessorTest, Accessor_Let_SingleMember) {
+ // struct MyStruct { foo: i32 }
+ // let a: MyStruct = MyStruct();
+ // let b = a.foo
+
+ auto* s = Structure("MyStruct", utils::Vector{
+ Member("foo", ty.i32()),
+ });
+ auto* a = Let("a", ty.Of(s), Call("MyStruct"));
+ auto* expr = Decl(Let("b", MemberAccessor(a, "foo")));
+ WrapInFunction(Block(utils::Vector{Decl(a), expr}));
+
+ auto m = Build();
+ ASSERT_TRUE(m) << (!m ? m.Failure() : "");
+
+ EXPECT_EQ(Disassemble(m.Get()),
+ R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
+ %b1 = block {
+ %b:i32 = access MyStruct(0i), 0u
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_FromProgramAccessorTest, Accessor_Let_MultiMember) {
+ // struct Inner { bar: f32 }
+ // struct Outer { a: i32, foo: Inner }
+ // let a: Outer = Outer();
+ // let b = a.foo.bar
+
+ auto* inner = Structure("Inner", utils::Vector{
+ Member("bar", ty.f32()),
+ });
+ auto* outer = Structure("Outer", utils::Vector{
+ Member("a", ty.i32()),
+ Member("foo", ty.Of(inner)),
+ });
+ auto* a = Let("a", ty.Of(outer), Call("Outer"));
+ auto* expr = Decl(Let("b", MemberAccessor(MemberAccessor(a, "foo"), "bar")));
+ WrapInFunction(Block(utils::Vector{Decl(a), expr}));
+
+ auto m = Build();
+ ASSERT_TRUE(m) << (!m ? m.Failure() : "");
+
+ EXPECT_EQ(Disassemble(m.Get()),
+ R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
+ %b1 = block {
+ %b:f32 = access Outer(0i, Inner(0.0f)), 1u, 0u
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_FromProgramAccessorTest, Accessor_Let_Mixed) {
+ // struct Outer { a: i32, foo: array<Inner, 4> }
+ // struct Inner { b: i32, c: f32, bar: vec4<f32> }
+ // let a: array<Outer, 4> = array();
+ // let b = a[0].foo[1].bar
+
+ auto* inner = Structure("Inner", utils::Vector{
+ Member("b", ty.i32()),
+ Member("c", ty.f32()),
+ Member("bar", ty.vec4<f32>()),
+ });
+ auto* outer = Structure("Outer", utils::Vector{
+ Member("a", ty.i32()),
+ Member("foo", ty.array(ty.Of(inner), 4_u)),
+ });
+ auto* a = Let("a", ty.array(ty.Of(outer), 4_u), array(ty.Of(outer), 4_u));
+ auto* expr = Decl(Let(
+ "b",
+ MemberAccessor(IndexAccessor(MemberAccessor(IndexAccessor(a, 0_u), "foo"), 1_u), "bar")));
+ WrapInFunction(Block(utils::Vector{Decl(a), expr}));
+
+ auto m = Build();
+ ASSERT_TRUE(m) << (!m ? m.Failure() : "");
+
+ EXPECT_EQ(Disassemble(m.Get()),
+ R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
+ %b1 = block {
+ %b:vec4<f32> = access array<Outer, 4>(Outer(0i, array<Inner, 4>(Inner(0i, 0.0f, vec4<f32>(0.0f))))), 0u, 1u, 1u, 2u
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_FromProgramAccessorTest, Accessor_Let_SingleElement) {
+ // let a: vec2<f32> = vec2()
+ // let b = a.y
+
+ auto* a = Let("a", ty.vec2<f32>(), vec(ty.f32(), 2));
+ auto* expr = Decl(Let("b", MemberAccessor(a, "y")));
+ WrapInFunction(Block(utils::Vector{Decl(a), expr}));
+
+ auto m = Build();
+ ASSERT_TRUE(m) << (!m ? m.Failure() : "");
+
+ EXPECT_EQ(Disassemble(m.Get()),
+ R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
+ %b1 = block {
+ %b:f32 = access vec2<f32>(0.0f), 1u
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_FromProgramAccessorTest, Accessor_Let_MultiElementSwizzle) {
+ // let a: vec3<f32 = vec3()>
+ // let b = a.zyxz
+
+ auto* a = Let("a", ty.vec3<f32>(), vec(ty.f32(), 3));
+ auto* expr = Decl(Let("b", MemberAccessor(a, "zyxz")));
+ WrapInFunction(Block(utils::Vector{Decl(a), expr}));
+
+ auto m = Build();
+ ASSERT_TRUE(m) << (!m ? m.Failure() : "");
+
+ EXPECT_EQ(Disassemble(m.Get()),
+ R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
+ %b1 = block {
+ %b:vec4<f32> = swizzle vec3<f32>(0.0f), zyxz
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_FromProgramAccessorTest, Accessor_Let_MultiElementSwizzleOfSwizzle) {
+ // let a: vec3<f32> = vec3();
+ // let b = a.zyx.yy
+
+ auto* a = Let("a", ty.vec3<f32>(), vec(ty.f32(), 3));
+ auto* expr = Decl(Let("b", MemberAccessor(MemberAccessor(a, "zyx"), "yy")));
+ WrapInFunction(Block(utils::Vector{Decl(a), expr}));
+
+ auto m = Build();
+ ASSERT_TRUE(m) << (!m ? m.Failure() : "");
+
+ EXPECT_EQ(Disassemble(m.Get()),
+ R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
+ %b1 = block {
+ %2:vec3<f32> = swizzle vec3<f32>(0.0f), zyx
+ %b:vec2<f32> = swizzle %2, yy
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_FromProgramAccessorTest, Accessor_Let_MultiElementSwizzle_MiddleOfChain) {
+ // struct MyStruct { a: i32; foo: vec4<f32> }
+ // let a: MyStruct = MyStruct();
+ // let b = a.foo.zyx.yx[0]
+
+ auto* s = Structure("MyStruct", utils::Vector{
+ Member("a", ty.i32()),
+ Member("foo", ty.vec4<f32>()),
+ });
+ auto* a = Let("a", ty.Of(s), Call("MyStruct"));
+ auto* expr = Decl(Let(
+ "b",
+ IndexAccessor(MemberAccessor(MemberAccessor(MemberAccessor(a, "foo"), "zyx"), "yx"), 0_u)));
+ WrapInFunction(Block(utils::Vector{Decl(a), expr}));
+
+ auto m = Build();
+ ASSERT_TRUE(m) << (!m ? m.Failure() : "");
+
+ EXPECT_EQ(Disassemble(m.Get()),
+ R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
+ %b1 = block {
+ %2:vec4<f32> = access MyStruct(0i, vec4<f32>(0.0f)), 1u
+ %3:vec3<f32> = swizzle %2, zyx
+ %4:vec2<f32> = swizzle %3, yx
+ %b:f32 = access %4, 0u
+ ret
+ }
+}
+)");
+}
+
+} // namespace
+} // namespace tint::ir
diff --git a/src/tint/ir/from_program_binary_test.cc b/src/tint/ir/from_program_binary_test.cc
index 2be72e1..a3bcb1e 100644
--- a/src/tint/ir/from_program_binary_test.cc
+++ b/src/tint/ir/from_program_binary_test.cc
@@ -12,19 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ir/test_helper.h"
-
#include "gmock/gmock.h"
#include "src/tint/ast/case_selector.h"
#include "src/tint/ast/int_literal_expression.h"
#include "src/tint/constant/scalar.h"
+#include "src/tint/ir/program_test_helper.h"
namespace tint::ir {
namespace {
using namespace tint::number_suffixes; // NOLINT
-using IR_FromProgramBinaryTest = TestHelper;
+using IR_FromProgramBinaryTest = ProgramTestHelper;
TEST_F(IR_FromProgramBinaryTest, EmitExpression_Binary_Add) {
Func("my_func", utils::Empty, ty.u32(), utils::Vector{Return(0_u)});
diff --git a/src/tint/ir/from_program_builtin_test.cc b/src/tint/ir/from_program_builtin_test.cc
index 7eafa17..262c8bd 100644
--- a/src/tint/ir/from_program_builtin_test.cc
+++ b/src/tint/ir/from_program_builtin_test.cc
@@ -12,19 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ir/test_helper.h"
-
#include "gmock/gmock.h"
#include "src/tint/ast/case_selector.h"
#include "src/tint/ast/int_literal_expression.h"
#include "src/tint/constant/scalar.h"
+#include "src/tint/ir/program_test_helper.h"
namespace tint::ir {
namespace {
using namespace tint::number_suffixes; // NOLINT
-using IR_FromProgramBuiltinTest = TestHelper;
+using IR_FromProgramBuiltinTest = ProgramTestHelper;
TEST_F(IR_FromProgramBuiltinTest, EmitExpression_Builtin) {
auto i = GlobalVar("i", builtin::AddressSpace::kPrivate, Expr(1_f));
diff --git a/src/tint/ir/from_program_call_test.cc b/src/tint/ir/from_program_call_test.cc
index 1aa224c..b2194ec 100644
--- a/src/tint/ir/from_program_call_test.cc
+++ b/src/tint/ir/from_program_call_test.cc
@@ -12,19 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ir/test_helper.h"
-
#include "gmock/gmock.h"
#include "src/tint/ast/case_selector.h"
#include "src/tint/ast/int_literal_expression.h"
#include "src/tint/constant/scalar.h"
+#include "src/tint/ir/program_test_helper.h"
namespace tint::ir {
namespace {
using namespace tint::number_suffixes; // NOLINT
-using IR_FromProgramCallTest = TestHelper;
+using IR_FromProgramCallTest = ProgramTestHelper;
TEST_F(IR_FromProgramCallTest, EmitExpression_Bitcast) {
Func("my_func", utils::Empty, ty.f32(), utils::Vector{Return(0_f)});
@@ -123,7 +122,7 @@
EXPECT_EQ(Disassemble(m.Get()), R"(# Root block
%b1 = block {
- %i:ptr<private, vec3<f32>, read_write> = var, vec3<f32> 0.0f
+ %i:ptr<private, vec3<f32>, read_write> = var, vec3<f32>(0.0f)
}
)");
diff --git a/src/tint/ir/from_program_function_test.cc b/src/tint/ir/from_program_function_test.cc
index a130281..c467f28 100644
--- a/src/tint/ir/from_program_function_test.cc
+++ b/src/tint/ir/from_program_function_test.cc
@@ -12,19 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ir/test_helper.h"
-
#include "gmock/gmock.h"
#include "src/tint/ast/case_selector.h"
#include "src/tint/ast/int_literal_expression.h"
#include "src/tint/constant/scalar.h"
+#include "src/tint/ir/program_test_helper.h"
namespace tint::ir {
namespace {
using namespace tint::number_suffixes; // NOLINT
-using IR_FromProgramFunctionTest = TestHelper;
+using IR_FromProgramFunctionTest = ProgramTestHelper;
TEST_F(IR_FromProgramFunctionTest, EmitFunction_Vertex) {
Func("test", utils::Empty, ty.vec4<f32>(), utils::Vector{Return(vec4<f32>(0_f, 0_f, 0_f, 0_f))},
@@ -36,7 +35,7 @@
EXPECT_EQ(Disassemble(m.Get()), R"(%test = @vertex func():vec4<f32> [@position] -> %b1 {
%b1 = block {
- ret vec4<f32> 0.0f
+ ret vec4<f32>(0.0f)
}
}
)");
@@ -82,7 +81,7 @@
EXPECT_EQ(Disassemble(m.Get()), R"(%test = func():vec3<f32> -> %b1 {
%b1 = block {
- ret vec3<f32> 0.0f
+ ret vec3<f32>(0.0f)
}
}
)");
@@ -98,7 +97,7 @@
EXPECT_EQ(Disassemble(m.Get()), R"(%test = @vertex func():vec4<f32> [@position] -> %b1 {
%b1 = block {
- ret vec4<f32> 1.0f, 2.0f, 3.0f, 4.0f
+ ret vec4<f32>(1.0f, 2.0f, 3.0f, 4.0f)
}
}
)");
@@ -115,7 +114,7 @@
EXPECT_EQ(Disassemble(m.Get()),
R"(%test = @vertex func():vec4<f32> [@invariant, @position] -> %b1 {
%b1 = block {
- ret vec4<f32> 1.0f, 2.0f, 3.0f, 4.0f
+ ret vec4<f32>(1.0f, 2.0f, 3.0f, 4.0f)
}
}
)");
@@ -131,7 +130,7 @@
EXPECT_EQ(Disassemble(m.Get()),
R"(%test = @fragment func():vec4<f32> [@location(1)] -> %b1 {
%b1 = block {
- ret vec4<f32> 1.0f, 2.0f, 3.0f, 4.0f
+ ret vec4<f32>(1.0f, 2.0f, 3.0f, 4.0f)
}
}
)");
@@ -150,7 +149,7 @@
Disassemble(m.Get()),
R"(%test = @fragment func():vec4<f32> [@location(1), @interpolate(linear, centroid)] -> %b1 {
%b1 = block {
- ret vec4<f32> 1.0f, 2.0f, 3.0f, 4.0f
+ ret vec4<f32>(1.0f, 2.0f, 3.0f, 4.0f)
}
}
)");
diff --git a/src/tint/ir/from_program_literal_test.cc b/src/tint/ir/from_program_literal_test.cc
index 1ea9706..d542a63 100644
--- a/src/tint/ir/from_program_literal_test.cc
+++ b/src/tint/ir/from_program_literal_test.cc
@@ -12,14 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ir/test_helper.h"
-
#include "gmock/gmock.h"
#include "src/tint/ast/case_selector.h"
#include "src/tint/ast/int_literal_expression.h"
#include "src/tint/constant/scalar.h"
#include "src/tint/ir/block.h"
#include "src/tint/ir/constant.h"
+#include "src/tint/ir/program_test_helper.h"
#include "src/tint/ir/var.h"
namespace tint::ir {
@@ -42,7 +41,7 @@
using namespace tint::number_suffixes; // NOLINT
-using IR_FromProgramLiteralTest = TestHelper;
+using IR_FromProgramLiteralTest = ProgramTestHelper;
TEST_F(IR_FromProgramLiteralTest, EmitLiteral_Bool_True) {
auto* expr = Expr(true);
diff --git a/src/tint/ir/from_program_materialize_test.cc b/src/tint/ir/from_program_materialize_test.cc
index 07b49c1..21ccebd 100644
--- a/src/tint/ir/from_program_materialize_test.cc
+++ b/src/tint/ir/from_program_materialize_test.cc
@@ -12,19 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ir/test_helper.h"
-
#include "gmock/gmock.h"
#include "src/tint/ast/case_selector.h"
#include "src/tint/ast/int_literal_expression.h"
#include "src/tint/constant/scalar.h"
+#include "src/tint/ir/program_test_helper.h"
namespace tint::ir {
namespace {
using namespace tint::number_suffixes; // NOLINT
-using IR_FromProgramMaterializeTest = TestHelper;
+using IR_FromProgramMaterializeTest = ProgramTestHelper;
TEST_F(IR_FromProgramMaterializeTest, EmitExpression_MaterializedCall) {
auto* expr = Return(Call("trunc", 2.5_f));
diff --git a/src/tint/ir/from_program_store_test.cc b/src/tint/ir/from_program_store_test.cc
index daeed37..c58cd47 100644
--- a/src/tint/ir/from_program_store_test.cc
+++ b/src/tint/ir/from_program_store_test.cc
@@ -12,19 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ir/test_helper.h"
-
#include "gmock/gmock.h"
#include "src/tint/ast/case_selector.h"
#include "src/tint/ast/int_literal_expression.h"
#include "src/tint/constant/scalar.h"
+#include "src/tint/ir/program_test_helper.h"
namespace tint::ir {
namespace {
using namespace tint::number_suffixes; // NOLINT
-using IR_FromProgramStoreTest = TestHelper;
+using IR_FromProgramStoreTest = ProgramTestHelper;
TEST_F(IR_FromProgramStoreTest, EmitStatement_Assign) {
GlobalVar("a", ty.u32(), builtin::AddressSpace::kPrivate);
diff --git a/src/tint/ir/from_program_test.cc b/src/tint/ir/from_program_test.cc
index d7825c5..6b6abd9 100644
--- a/src/tint/ir/from_program_test.cc
+++ b/src/tint/ir/from_program_test.cc
@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ir/test_helper.h"
-
#include "gmock/gmock.h"
#include "src/tint/ast/case_selector.h"
#include "src/tint/ast/int_literal_expression.h"
@@ -21,6 +19,7 @@
#include "src/tint/ir/block.h"
#include "src/tint/ir/if.h"
#include "src/tint/ir/loop.h"
+#include "src/tint/ir/program_test_helper.h"
#include "src/tint/ir/switch.h"
namespace tint::ir {
@@ -51,7 +50,7 @@
using namespace tint::number_suffixes; // NOLINT
-using IR_FromProgramTest = TestHelper;
+using IR_FromProgramTest = ProgramTestHelper;
TEST_F(IR_FromProgramTest, Func) {
Func("f", utils::Empty, ty.void_(), utils::Empty);
diff --git a/src/tint/ir/from_program_unary_test.cc b/src/tint/ir/from_program_unary_test.cc
index a7af3dd..2a66d1c 100644
--- a/src/tint/ir/from_program_unary_test.cc
+++ b/src/tint/ir/from_program_unary_test.cc
@@ -12,9 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ir/test_helper.h"
+#include "src/tint/ir/program_test_helper.h"
-#include "gmock/gmock.h"
#include "src/tint/ast/case_selector.h"
#include "src/tint/ast/int_literal_expression.h"
#include "src/tint/constant/scalar.h"
@@ -24,7 +23,7 @@
using namespace tint::number_suffixes; // NOLINT
-using IR_FromProgramUnaryTest = TestHelper;
+using IR_FromProgramUnaryTest = ProgramTestHelper;
TEST_F(IR_FromProgramUnaryTest, EmitExpression_Unary_Not) {
Func("my_func", utils::Empty, ty.bool_(), utils::Vector{Return(false)});
diff --git a/src/tint/ir/from_program_var_test.cc b/src/tint/ir/from_program_var_test.cc
index 5f6e78c..aedb0cb 100644
--- a/src/tint/ir/from_program_var_test.cc
+++ b/src/tint/ir/from_program_var_test.cc
@@ -12,19 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ir/test_helper.h"
-
#include "gmock/gmock.h"
#include "src/tint/ast/case_selector.h"
#include "src/tint/ast/int_literal_expression.h"
#include "src/tint/constant/scalar.h"
+#include "src/tint/ir/program_test_helper.h"
namespace tint::ir {
namespace {
using namespace tint::number_suffixes; // NOLINT
-using IR_FromProgramVarTest = TestHelper;
+using IR_FromProgramVarTest = ProgramTestHelper;
TEST_F(IR_FromProgramVarTest, Emit_GlobalVar_NoInit) {
GlobalVar("a", ty.u32(), builtin::AddressSpace::kPrivate);
diff --git a/src/tint/ir/instruction.cc b/src/tint/ir/instruction.cc
index 22e3904..ea9cf31 100644
--- a/src/tint/ir/instruction.cc
+++ b/src/tint/ir/instruction.cc
@@ -37,7 +37,7 @@
after->Block()->InsertAfter(after, this);
}
-void Instruction::Replace(Instruction* replacement) {
+void Instruction::ReplaceWith(Instruction* replacement) {
TINT_ASSERT_OR_RETURN(IR, replacement);
TINT_ASSERT_OR_RETURN(IR, Block() != nullptr);
Block()->Replace(this, replacement);
diff --git a/src/tint/ir/instruction.h b/src/tint/ir/instruction.h
index a79b3da..9a6087e 100644
--- a/src/tint/ir/instruction.h
+++ b/src/tint/ir/instruction.h
@@ -48,7 +48,7 @@
void InsertAfter(Instruction* after);
/// Replaces this instruction with @p replacement in the owning block owning this instruction
/// @param replacement the instruction to replace with
- void Replace(Instruction* replacement);
+ void ReplaceWith(Instruction* replacement);
/// Removes this instruction from the owning block
void Remove();
diff --git a/src/tint/ir/instruction_test.cc b/src/tint/ir/instruction_test.cc
index 3fb8ab0..3a19cd1 100644
--- a/src/tint/ir/instruction_test.cc
+++ b/src/tint/ir/instruction_test.cc
@@ -13,19 +13,15 @@
// limitations under the License.
#include "gtest/gtest-spi.h"
-#include "gtest/gtest.h"
#include "src/tint/ir/block.h"
#include "src/tint/ir/builder.h"
+#include "src/tint/ir/ir_test_helper.h"
#include "src/tint/ir/module.h"
namespace tint::ir {
namespace {
-class IR_InstructionTest : public ::testing::Test {
- public:
- Module mod;
- Builder b{mod};
-};
+using IR_InstructionTest = IRTestHelper;
TEST_F(IR_InstructionTest, InsertBefore) {
auto* inst1 = b.CreateLoop();
@@ -97,18 +93,18 @@
"");
}
-TEST_F(IR_InstructionTest, Replace) {
+TEST_F(IR_InstructionTest, ReplaceWith) {
auto* inst1 = b.CreateLoop();
auto* inst2 = b.CreateLoop();
auto* blk = b.CreateBlock();
blk->Append(inst2);
- inst2->Replace(inst1);
+ inst2->ReplaceWith(inst1);
EXPECT_EQ(1u, blk->Length());
EXPECT_EQ(inst1->Block(), blk);
EXPECT_EQ(inst2->Block(), nullptr);
}
-TEST_F(IR_InstructionTest, Fail_ReplaceNullptr) {
+TEST_F(IR_InstructionTest, Fail_ReplaceWithNullptr) {
EXPECT_FATAL_FAILURE(
{
Module mod;
@@ -117,12 +113,12 @@
auto* inst1 = b.CreateLoop();
auto* blk = b.CreateBlock();
blk->Append(inst1);
- inst1->Replace(nullptr);
+ inst1->ReplaceWith(nullptr);
},
"");
}
-TEST_F(IR_InstructionTest, Fail_ReplaceNotInserted) {
+TEST_F(IR_InstructionTest, Fail_ReplaceWithNotInserted) {
EXPECT_FATAL_FAILURE(
{
Module mod;
@@ -130,7 +126,7 @@
auto* inst1 = b.CreateLoop();
auto* inst2 = b.CreateLoop();
- inst1->Replace(inst2);
+ inst1->ReplaceWith(inst2);
},
"");
}
diff --git a/src/tint/ir/ir_test_helper.h b/src/tint/ir/ir_test_helper.h
new file mode 100644
index 0000000..a7826ec
--- /dev/null
+++ b/src/tint/ir/ir_test_helper.h
@@ -0,0 +1,44 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_IR_IR_TEST_HELPER_H_
+#define SRC_TINT_IR_IR_TEST_HELPER_H_
+
+#include "gtest/gtest.h"
+#include "src/tint/ir/builder.h"
+#include "src/tint/ir/module.h"
+
+namespace tint::ir {
+
+/// Helper class for testing
+template <typename BASE>
+class IRTestHelperBase : public BASE {
+ public:
+ IRTestHelperBase() = default;
+ ~IRTestHelperBase() override = default;
+
+ /// The IR module
+ Module mod;
+ /// The IR builder
+ Builder b{mod};
+};
+
+using IRTestHelper = IRTestHelperBase<testing::Test>;
+
+template <typename T>
+using IRTestParamHelper = IRTestHelperBase<testing::TestWithParam<T>>;
+
+} // namespace tint::ir
+
+#endif // SRC_TINT_IR_IR_TEST_HELPER_H_
diff --git a/src/tint/ir/load_test.cc b/src/tint/ir/load_test.cc
index 02bc65b..d552fac 100644
--- a/src/tint/ir/load_test.cc
+++ b/src/tint/ir/load_test.cc
@@ -14,22 +14,19 @@
#include "src/tint/ir/builder.h"
#include "src/tint/ir/instruction.h"
-#include "src/tint/ir/test_helper.h"
+#include "src/tint/ir/ir_test_helper.h"
namespace tint::ir {
namespace {
using namespace tint::number_suffixes; // NOLINT
-using IR_LoadTest = TestHelper;
+using IR_LoadTest = IRTestHelper;
TEST_F(IR_LoadTest, CreateLoad) {
- Module mod;
- Builder b{mod};
-
auto* store_type = mod.Types().i32();
- auto* var = b.Declare(mod.Types().Get<type::Pointer>(
- store_type, builtin::AddressSpace::kFunction, builtin::Access::kReadWrite));
+ auto* var = b.Declare(mod.Types().pointer(store_type, builtin::AddressSpace::kFunction,
+ builtin::Access::kReadWrite));
const auto* inst = b.Load(var);
ASSERT_TRUE(inst->Is<Load>());
@@ -42,12 +39,9 @@
}
TEST_F(IR_LoadTest, Load_Usage) {
- Module mod;
- Builder b{mod};
-
auto* store_type = mod.Types().i32();
- auto* var = b.Declare(mod.Types().Get<type::Pointer>(
- store_type, builtin::AddressSpace::kFunction, builtin::Access::kReadWrite));
+ auto* var = b.Declare(mod.Types().pointer(store_type, builtin::AddressSpace::kFunction,
+ builtin::Access::kReadWrite));
const auto* inst = b.Load(var);
ASSERT_NE(inst->From(), nullptr);
diff --git a/src/tint/ir/module_test.cc b/src/tint/ir/module_test.cc
index f15aa35..945e499 100644
--- a/src/tint/ir/module_test.cc
+++ b/src/tint/ir/module_test.cc
@@ -13,7 +13,7 @@
// limitations under the License.
#include "src/tint/ir/module.h"
-#include "src/tint/ir/test_helper.h"
+#include "src/tint/ir/ir_test_helper.h"
#include "src/tint/ir/var.h"
namespace tint::ir {
@@ -21,23 +21,20 @@
using namespace tint::number_suffixes; // NOLINT
-using IR_ModuleTest = TestHelper;
+using IR_ModuleTest = IRTestHelper;
TEST_F(IR_ModuleTest, NameOfUnnamed) {
- Module mod;
auto* v = mod.values.Create<ir::Var>(mod.Types().i32());
EXPECT_FALSE(mod.NameOf(v).IsValid());
}
TEST_F(IR_ModuleTest, SetName) {
- Module mod;
auto* v = mod.values.Create<ir::Var>(mod.Types().i32());
EXPECT_EQ(mod.SetName(v, "a").Name(), "a");
EXPECT_EQ(mod.NameOf(v).Name(), "a");
}
TEST_F(IR_ModuleTest, SetNameRename) {
- Module mod;
auto* v = mod.values.Create<ir::Var>(mod.Types().i32());
EXPECT_EQ(mod.SetName(v, "a").Name(), "a");
EXPECT_EQ(mod.SetName(v, "b").Name(), "b");
@@ -45,7 +42,6 @@
}
TEST_F(IR_ModuleTest, SetNameCollision) {
- Module mod;
auto* a = mod.values.Create<ir::Var>(mod.Types().i32());
auto* b = mod.values.Create<ir::Var>(mod.Types().i32());
auto* c = mod.values.Create<ir::Var>(mod.Types().i32());
diff --git a/src/tint/ir/test_helper.h b/src/tint/ir/program_test_helper.h
similarity index 79%
rename from src/tint/ir/test_helper.h
rename to src/tint/ir/program_test_helper.h
index b7279f3..c0ecdb4 100644
--- a/src/tint/ir/test_helper.h
+++ b/src/tint/ir/program_test_helper.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#ifndef SRC_TINT_IR_TEST_HELPER_H_
-#define SRC_TINT_IR_TEST_HELPER_H_
+#ifndef SRC_TINT_IR_PROGRAM_TEST_HELPER_H_
+#define SRC_TINT_IR_PROGRAM_TEST_HELPER_H_
#include <memory>
#include <string>
@@ -30,10 +30,10 @@
/// Helper class for testing
template <typename BASE>
-class TestHelperBase : public BASE, public ProgramBuilder {
+class ProgramTestHelperBase : public BASE, public ProgramBuilder {
public:
- TestHelperBase() = default;
- ~TestHelperBase() override = default;
+ ProgramTestHelperBase() = default;
+ ~ProgramTestHelperBase() override = default;
/// Build the module, cleaning up the program before returning.
/// @returns the generated module
@@ -57,11 +57,11 @@
}
};
-using TestHelper = TestHelperBase<testing::Test>;
+using ProgramTestHelper = ProgramTestHelperBase<testing::Test>;
template <typename T>
-using TestParamHelper = TestHelperBase<testing::TestWithParam<T>>;
+using ProgramTestParamHelper = ProgramTestHelperBase<testing::TestWithParam<T>>;
} // namespace tint::ir
-#endif // SRC_TINT_IR_TEST_HELPER_H_
+#endif // SRC_TINT_IR_PROGRAM_TEST_HELPER_H_
diff --git a/src/tint/ir/store_test.cc b/src/tint/ir/store_test.cc
index bc68a98..a1b0d38 100644
--- a/src/tint/ir/store_test.cc
+++ b/src/tint/ir/store_test.cc
@@ -14,19 +14,16 @@
#include "src/tint/ir/builder.h"
#include "src/tint/ir/instruction.h"
-#include "src/tint/ir/test_helper.h"
+#include "src/tint/ir/ir_test_helper.h"
namespace tint::ir {
namespace {
using namespace tint::number_suffixes; // NOLINT
-using IR_StoreTest = TestHelper;
+using IR_StoreTest = IRTestHelper;
TEST_F(IR_StoreTest, CreateStore) {
- Module mod;
- Builder b{mod};
-
// TODO(dsinclair): This is wrong, but we don't have anything correct to store too at the
// moment.
auto* to = b.Discard();
@@ -42,9 +39,6 @@
}
TEST_F(IR_StoreTest, Store_Usage) {
- Module mod;
- Builder b{mod};
-
auto* to = b.Discard();
const auto* inst = b.Store(to, b.Constant(4_i));
diff --git a/src/tint/ir/swizzle.cc b/src/tint/ir/swizzle.cc
new file mode 100644
index 0000000..1887d4d
--- /dev/null
+++ b/src/tint/ir/swizzle.cc
@@ -0,0 +1,32 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/swizzle.h"
+
+#include <utility>
+
+#include "src/tint/debug.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ir::Swizzle);
+
+namespace tint::ir {
+
+Swizzle::Swizzle(const type::Type* ty, Value* object, utils::VectorRef<uint32_t> indices)
+ : result_type_(ty), object_(object), indices_(std::move(indices)) {
+ object_->AddUsage(this);
+}
+
+Swizzle::~Swizzle() = default;
+
+} // namespace tint::ir
diff --git a/src/tint/ir/swizzle.h b/src/tint/ir/swizzle.h
new file mode 100644
index 0000000..31c3da1
--- /dev/null
+++ b/src/tint/ir/swizzle.h
@@ -0,0 +1,50 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_IR_SWIZZLE_H_
+#define SRC_TINT_IR_SWIZZLE_H_
+
+#include "src/tint/ir/instruction.h"
+#include "src/tint/utils/castable.h"
+
+namespace tint::ir {
+
+/// A swizzle instruction in the IR.
+class Swizzle : public utils::Castable<Swizzle, Instruction> {
+ public:
+ /// Constructor
+ /// @param result_type the result type
+ /// @param object the object being swizzled
+ /// @param indices the indices to swizzle
+ Swizzle(const type::Type* result_type, Value* object, utils::VectorRef<uint32_t> indices);
+ ~Swizzle() override;
+
+ /// @returns the type of the value
+ const type::Type* Type() const override { return result_type_; }
+
+ /// @returns the object used for the access
+ Value* Object() const { return object_; }
+
+ /// @returns the swizzle indices
+ utils::VectorRef<uint32_t> Indices() const { return indices_; }
+
+ private:
+ const type::Type* result_type_ = nullptr;
+ Value* object_ = nullptr;
+ utils::Vector<uint32_t, 4> indices_;
+};
+
+} // namespace tint::ir
+
+#endif // SRC_TINT_IR_SWIZZLE_H_
diff --git a/src/tint/ir/to_program_roundtrip_test.cc b/src/tint/ir/to_program_roundtrip_test.cc
index c8af503..48e9c5d 100644
--- a/src/tint/ir/to_program_roundtrip_test.cc
+++ b/src/tint/ir/to_program_roundtrip_test.cc
@@ -13,7 +13,7 @@
// limitations under the License.
#include "src/tint/ir/from_program.h"
-#include "src/tint/ir/test_helper.h"
+#include "src/tint/ir/program_test_helper.h"
#include "src/tint/ir/to_program.h"
#include "src/tint/reader/wgsl/parser.h"
#include "src/tint/utils/string.h"
@@ -28,7 +28,7 @@
using namespace tint::number_suffixes; // NOLINT
-class IRToProgramRoundtripTest : public TestHelper {
+class IRToProgramRoundtripTest : public ProgramTestHelper {
public:
void Test(std::string_view input_wgsl, std::string_view expected_wgsl) {
auto input = utils::TrimSpace(input_wgsl);
diff --git a/src/tint/ir/unary_test.cc b/src/tint/ir/unary_test.cc
index d6e5d50..e04b808 100644
--- a/src/tint/ir/unary_test.cc
+++ b/src/tint/ir/unary_test.cc
@@ -14,18 +14,16 @@
#include "src/tint/ir/builder.h"
#include "src/tint/ir/instruction.h"
-#include "src/tint/ir/test_helper.h"
+#include "src/tint/ir/ir_test_helper.h"
namespace tint::ir {
namespace {
using namespace tint::number_suffixes; // NOLINT
-using IR_UnaryTest = TestHelper;
+using IR_UnaryTest = IRTestHelper;
TEST_F(IR_UnaryTest, CreateComplement) {
- Module mod;
- Builder b{mod};
auto* inst = b.Complement(mod.Types().i32(), b.Constant(4_i));
ASSERT_TRUE(inst->Is<Unary>());
@@ -38,8 +36,6 @@
}
TEST_F(IR_UnaryTest, CreateNegation) {
- Module mod;
- Builder b{mod};
auto* inst = b.Negation(mod.Types().i32(), b.Constant(4_i));
ASSERT_TRUE(inst->Is<Unary>());
@@ -52,8 +48,6 @@
}
TEST_F(IR_UnaryTest, Unary_Usage) {
- Module mod;
- Builder b{mod};
auto* inst = b.Negation(mod.Types().i32(), b.Constant(4_i));
EXPECT_EQ(inst->Kind(), Unary::Kind::kNegation);
diff --git a/src/tint/ir/validate.cc b/src/tint/ir/validate.cc
new file mode 100644
index 0000000..5e28173
--- /dev/null
+++ b/src/tint/ir/validate.cc
@@ -0,0 +1,131 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/validate.h"
+
+#include <utility>
+
+#include "src/tint/ir/function.h"
+#include "src/tint/ir/if.h"
+#include "src/tint/ir/loop.h"
+#include "src/tint/ir/return.h"
+#include "src/tint/ir/switch.h"
+#include "src/tint/ir/var.h"
+#include "src/tint/switch.h"
+#include "src/tint/type/pointer.h"
+
+namespace tint::ir {
+namespace {
+
+class Validator {
+ public:
+ explicit Validator(const Module& mod) : mod_(mod) {}
+
+ ~Validator() {}
+
+ utils::Result<Success, diag::List> IsValid() {
+ CheckRootBlock(mod_.root_block);
+
+ for (const auto* func : mod_.functions) {
+ CheckFunction(func);
+ }
+
+ if (diagnostics_.contains_errors()) {
+ return std::move(diagnostics_);
+ }
+ return Success{};
+ }
+
+ private:
+ const Module& mod_;
+ diag::List diagnostics_;
+
+ void AddError(const std::string& err) { diagnostics_.add_error(tint::diag::System::IR, err); }
+
+ std::string Name(const Value* v) { return mod_.NameOf(v).Name(); }
+
+ void CheckRootBlock(const Block* blk) {
+ if (!blk) {
+ return;
+ }
+
+ for (const auto* inst : *blk) {
+ auto* var = inst->As<ir::Var>();
+ if (!var) {
+ AddError(std::string("root block: invalid instruction: ") + inst->TypeInfo().name);
+ continue;
+ }
+ if (!var->Type()->Is<type::Pointer>()) {
+ AddError(std::string("root block: 'var' ") + Name(var) +
+ "type is not a pointer: " + var->Type()->TypeInfo().name);
+ }
+ }
+ }
+
+ void CheckFunction(const Function* func) {
+ for (const auto* param : func->Params()) {
+ if (param == nullptr) {
+ AddError("function '" + Name(func) + "': null parameter");
+ continue;
+ }
+ }
+
+ if (func->StartTarget() == nullptr) {
+ AddError("function '" + Name(func) + "': null start target");
+ } else {
+ CheckBlock(func->StartTarget());
+ }
+ }
+
+ void CheckBlock(const Block* blk) {
+ if (!blk->HasBranchTarget()) {
+ AddError("block: does not end in a branch");
+ }
+
+ for (const auto* inst : *blk) {
+ if (inst->Is<ir::Branch>() && inst != blk->Branch()) {
+ AddError("block: branch which isn't the final instruction");
+ continue;
+ }
+
+ CheckInstruction(inst);
+ }
+ }
+
+ void CheckInstruction(const Instruction* inst) {
+ tint::Switch(
+ inst, //
+ [&](const ir::Return* ret) {
+ if (ret->Func() == nullptr) {
+ AddError("return: null function");
+ }
+ },
+ [&](Default) {
+ AddError(std::string("missing validation of: ") + inst->TypeInfo().name);
+ });
+ }
+};
+
+} // namespace
+
+utils::Result<Success, std::string> Validate(const Module& mod) {
+ Validator v(mod);
+ auto r = v.IsValid();
+ if (!r) {
+ return r.Failure().str();
+ }
+ return Success{};
+}
+
+} // namespace tint::ir
diff --git a/src/tint/ir/validate.h b/src/tint/ir/validate.h
new file mode 100644
index 0000000..233295b
--- /dev/null
+++ b/src/tint/ir/validate.h
@@ -0,0 +1,35 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_IR_VALIDATE_H_
+#define SRC_TINT_IR_VALIDATE_H_
+
+#include <string>
+
+#include "src/tint/ir/module.h"
+#include "src/tint/utils/result.h"
+
+namespace tint::ir {
+
+/// Signifies the validation completed successfully
+struct Success {};
+
+/// Validates that a given IR module is correctly formed
+/// @param mod the module to validate
+/// @returns true on success, an error result otherwise
+utils::Result<Success, std::string> Validate(const Module& mod);
+
+} // namespace tint::ir
+
+#endif // SRC_TINT_IR_VALIDATE_H_
diff --git a/src/tint/ir/validate_test.cc b/src/tint/ir/validate_test.cc
new file mode 100644
index 0000000..9c3676c
--- /dev/null
+++ b/src/tint/ir/validate_test.cc
@@ -0,0 +1,105 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/validate.h"
+#include "gmock/gmock.h"
+#include "src/tint/ir/builder.h"
+#include "src/tint/ir/ir_test_helper.h"
+#include "src/tint/type/pointer.h"
+
+namespace tint::ir {
+namespace {
+
+using namespace tint::number_suffixes; // NOLINT
+
+using IR_ValidateTest = IRTestHelper;
+
+TEST_F(IR_ValidateTest, RootBlock_Var) {
+ mod.root_block = b.CreateRootBlockIfNeeded();
+ mod.root_block->Append(b.Declare(mod.Types().pointer(
+ mod.Types().i32(), builtin::AddressSpace::kPrivate, builtin::Access::kReadWrite)));
+ auto res = ir::Validate(mod);
+ EXPECT_TRUE(res) << res.Failure();
+}
+
+TEST_F(IR_ValidateTest, RootBlock_NonVar) {
+ mod.root_block = b.CreateRootBlockIfNeeded();
+ mod.root_block->Append(b.CreateLoop());
+
+ auto res = ir::Validate(mod);
+ ASSERT_FALSE(res);
+ EXPECT_EQ(res.Failure(), "error: root block: invalid instruction: tint::ir::Loop");
+}
+
+TEST_F(IR_ValidateTest, RootBlock_VarBadType) {
+ mod.root_block = b.CreateRootBlockIfNeeded();
+ mod.root_block->Append(b.Declare(mod.Types().i32()));
+ auto res = ir::Validate(mod);
+ ASSERT_FALSE(res);
+ EXPECT_EQ(res.Failure(), "error: root block: 'var' type is not a pointer: tint::type::I32");
+}
+
+TEST_F(IR_ValidateTest, Function) {
+ auto* f = b.CreateFunction("my_func", mod.Types().void_());
+ mod.functions.Push(f);
+
+ f->SetParams(
+ utils::Vector{b.FunctionParam(mod.Types().i32()), b.FunctionParam(mod.Types().f32())});
+ f->StartTarget()->SetInstructions(utils::Vector{b.Return(f)});
+ auto res = ir::Validate(mod);
+ EXPECT_TRUE(res) << res.Failure();
+}
+
+TEST_F(IR_ValidateTest, Function_NullStartTarget) {
+ auto* f = b.CreateFunction("my_func", mod.Types().void_());
+ mod.functions.Push(f);
+
+ f->SetStartTarget(nullptr);
+ auto res = ir::Validate(mod);
+ ASSERT_FALSE(res);
+ EXPECT_EQ(res.Failure(), "error: function 'my_func': null start target");
+}
+
+TEST_F(IR_ValidateTest, Function_ParamNull) {
+ auto* f = b.CreateFunction("my_func", mod.Types().void_());
+ mod.functions.Push(f);
+
+ f->SetParams(utils::Vector<FunctionParam*, 1>{nullptr});
+ f->StartTarget()->SetInstructions(utils::Vector{b.Return(f)});
+ auto res = ir::Validate(mod);
+ ASSERT_FALSE(res);
+ EXPECT_EQ(res.Failure(), "error: function 'my_func': null parameter");
+}
+
+TEST_F(IR_ValidateTest, Block_NoBranchAtEnd) {
+ auto* f = b.CreateFunction("my_func", mod.Types().void_());
+ mod.functions.Push(f);
+
+ auto res = ir::Validate(mod);
+ ASSERT_FALSE(res);
+ EXPECT_EQ(res.Failure(), "error: block: does not end in a branch");
+}
+
+TEST_F(IR_ValidateTest, Block_BranchInMiddle) {
+ auto* f = b.CreateFunction("my_func", mod.Types().void_());
+ mod.functions.Push(f);
+
+ f->StartTarget()->SetInstructions(utils::Vector{b.Return(f), b.Return(f)});
+ auto res = ir::Validate(mod);
+ ASSERT_FALSE(res);
+ EXPECT_EQ(res.Failure(), "error: block: branch which isn't the final instruction");
+}
+
+} // namespace
+} // namespace tint::ir
diff --git a/src/tint/reader/wgsl/parser_impl.cc b/src/tint/reader/wgsl/parser_impl.cc
index 514e7f8..835c305 100644
--- a/src/tint/reader/wgsl/parser_impl.cc
+++ b/src/tint/reader/wgsl/parser_impl.cc
@@ -579,7 +579,7 @@
errored = true;
}
if (decl.matched) {
- if (!expect_attributes_consumed(attrs.value)) {
+ if (expect_attributes_consumed(attrs.value).errored) {
return Failure::kErrored;
}
return kSuccess;
@@ -590,7 +590,7 @@
errored = true;
}
if (str.matched) {
- if (!expect_attributes_consumed(attrs.value)) {
+ if (expect_attributes_consumed(attrs.value).errored) {
return Failure::kErrored;
}
return kSuccess;
@@ -902,7 +902,7 @@
if (val != ENUM::kUndefined) {
synchronized_ = true;
next();
- return {val, t.source()};
+ return val;
}
}
@@ -2115,10 +2115,10 @@
// | PERIOD swizzle_name component_or_swizzle_specifier?
Maybe<const ast::Expression*> ParserImpl::component_or_swizzle_specifier(
const ast::Expression* prefix) {
- Source source;
+ MultiTokenSource source(this, prefix->source);
while (continue_parsing()) {
- if (match(Token::Type::kBracketLeft, &source)) {
+ if (match(Token::Type::kBracketLeft)) {
auto res = sync(Token::Type::kBracketRight, [&]() -> Maybe<const ast::Expression*> {
auto param = expression();
if (param.errored) {
@@ -2132,7 +2132,7 @@
return Failure::kErrored;
}
- return create<ast::IndexAccessorExpression>(source, prefix, param.value);
+ return create<ast::IndexAccessorExpression>(source.Source(), prefix, param.value);
});
if (res.errored) {
@@ -2148,7 +2148,7 @@
return Failure::kErrored;
}
- prefix = builder_.MemberAccessor(ident.source, prefix, ident.value);
+ prefix = builder_.MemberAccessor(source.Source(), prefix, ident.value);
continue;
}
@@ -2162,22 +2162,12 @@
// : PAREN_LEFT ((expression COMMA)* expression COMMA?)? PAREN_RIGHT
Expect<ParserImpl::ExpressionList> ParserImpl::expect_argument_expression_list(
std::string_view use) {
- return expect_paren_block(use, [&]() -> Expect<ExpressionList> {
- ExpressionList ret;
- while (continue_parsing()) {
- auto arg = expression();
- if (arg.errored) {
- return Failure::kErrored;
- } else if (!arg.matched) {
- break;
- }
- ret.Push(arg.value);
-
- if (!match(Token::Type::kComma)) {
- break;
- }
+ return expect_paren_block(use, [&]() -> Expect<ParserImpl::ExpressionList> { //
+ auto list = expression_list(use, Token::Type::kParenRight);
+ if (list.errored) {
+ return Failure::kErrored;
}
- return ret;
+ return list.value;
});
}
@@ -2482,10 +2472,21 @@
return add_error(t, "expected expression for " + std::string(use));
}
-Expect<utils::Vector<const ast::Expression*, 3>> ParserImpl::expect_expression_list(
- std::string_view use,
- Token::Type terminator) {
- utils::Vector<const ast::Expression*, 3> exprs;
+Maybe<ParserImpl::ExpressionList> ParserImpl::expression_list(std::string_view use,
+ Token::Type terminator) {
+ if (peek_is(terminator)) {
+ return Failure::kNoMatch;
+ }
+ auto list = expect_expression_list(use, terminator);
+ if (list.errored) {
+ return Failure::kErrored;
+ }
+ return list.value;
+}
+
+Expect<ParserImpl::ExpressionList> ParserImpl::expect_expression_list(std::string_view use,
+ Token::Type terminator) {
+ ParserImpl::ExpressionList exprs;
while (continue_parsing()) {
auto expr = expect_expression(use);
if (expr.errored) {
@@ -2495,7 +2496,22 @@
if (peek_is(terminator)) {
break;
}
- if (!expect(use, Token::Type::kComma)) {
+
+ // Check if the next token is a template start, which was likely intended as a less-than.
+ if (expect_next_not_template_list(expr->source).errored) {
+ return Failure::kErrored; // expect_next_not_template_list() raised an error.
+ }
+ if (!match(Token::Type::kComma)) {
+ // Next expression is not a terminator or comma, so this is a parse error.
+
+ // Check if last parsed expression was a templated identifier, which was likely indented
+ // as a less-than / greater-than.
+ if (expect_not_templated_ident_expr(expr.value).errored) {
+ return Failure::kErrored; // expect_not_templated_ident_expr() raised an error.
+ }
+
+ // Emit the expected ',' error
+ expect(use, Token::Type::kComma);
return Failure::kErrored;
}
if (peek_is(terminator)) {
@@ -3073,12 +3089,58 @@
}
}
-bool ParserImpl::expect_attributes_consumed(utils::VectorRef<const ast::Attribute*> in) {
+Expect<Void> ParserImpl::expect_attributes_consumed(utils::VectorRef<const ast::Attribute*> in) {
if (in.IsEmpty()) {
- return true;
+ return kSuccess;
}
add_error(in[0]->source, "unexpected attributes");
- return false;
+ return Failure::kErrored;
+}
+
+Expect<Void> ParserImpl::expect_next_not_template_list(const Source& lhs_source) {
+ Source end;
+ if (!match(Token::Type::kTemplateArgsLeft, &end)) {
+ return kSuccess;
+ }
+
+ // Try to find end of template
+ for (size_t i = 0; i < 32; i++) {
+ if (auto& t = peek(i); t.type() == Token::Type::kTemplateArgsRight) {
+ end = t.source();
+ }
+ }
+ Source template_source = lhs_source;
+ template_source.range.end = end.range.end;
+ add_error(template_source, "parsed as template list");
+
+ if (auto rhs = expression(); rhs.matched) {
+ Source lt_source = lhs_source;
+ lt_source.range.end = rhs->source.range.end;
+ add_note(lt_source,
+ "if this is intended to be a less-than expression then wrap in parentheses");
+ }
+ return Failure::kErrored;
+}
+
+Expect<Void> ParserImpl::expect_not_templated_ident_expr(const ast::Expression* expr) {
+ auto* ident_expr = expr->As<ast::IdentifierExpression>();
+ if (!ident_expr) {
+ return kSuccess;
+ }
+ auto* ident = ident_expr->identifier->As<ast::TemplatedIdentifier>();
+ if (!ident) {
+ return kSuccess;
+ }
+
+ add_error(ident->source, "parsed as template list");
+
+ if (auto rhs = expression(); rhs.matched) {
+ Source gt_source = ident->arguments.Back()->source;
+ gt_source.range.end = rhs->source.range.end;
+ add_note(gt_source,
+ "if this is intended to be a greater-than expression then wrap in parentheses");
+ }
+ return Failure::kErrored;
}
// severity_control_name
@@ -3196,8 +3258,11 @@
return false;
}
-Expect<int32_t> ParserImpl::expect_sint(std::string_view use) {
+Expect<int32_t> ParserImpl::expect_sint(std::string_view use, Source* source /* = nullptr */) {
auto& t = peek();
+ if (source) {
+ *source = t.source();
+ }
if (!t.Is(Token::Type::kIntLiteral) && !t.Is(Token::Type::kIntLiteral_I)) {
return add_error(t.source(), "expected signed integer literal", use);
}
@@ -3210,33 +3275,35 @@
}
next();
- return {static_cast<int32_t>(t.to_i64()), t.source()};
+ return static_cast<int32_t>(t.to_i64());
}
Expect<uint32_t> ParserImpl::expect_positive_sint(std::string_view use) {
- auto sint = expect_sint(use);
+ Source source;
+ auto sint = expect_sint(use, &source);
if (sint.errored) {
return Failure::kErrored;
}
if (sint.value < 0) {
- return add_error(sint.source, std::string(use) + " must be positive");
+ return add_error(source, std::string(use) + " must be positive");
}
- return {static_cast<uint32_t>(sint.value), sint.source};
+ return static_cast<uint32_t>(sint.value);
}
Expect<uint32_t> ParserImpl::expect_nonzero_positive_sint(std::string_view use) {
- auto sint = expect_sint(use);
+ Source source;
+ auto sint = expect_sint(use, &source);
if (sint.errored) {
return Failure::kErrored;
}
if (sint.value <= 0) {
- return add_error(sint.source, std::string(use) + " must be greater than 0");
+ return add_error(source, std::string(use) + " must be greater than 0");
}
- return {static_cast<uint32_t>(sint.value), sint.source};
+ return static_cast<uint32_t>(sint.value);
}
Expect<const ast::Identifier*> ParserImpl::expect_ident(
diff --git a/src/tint/reader/wgsl/parser_impl.h b/src/tint/reader/wgsl/parser_impl.h
index c7ac96d..b167f56 100644
--- a/src/tint/reader/wgsl/parser_impl.h
+++ b/src/tint/reader/wgsl/parser_impl.h
@@ -99,10 +99,9 @@
/// Constructor for a successful parse.
/// @param val the result value of the parse
- /// @param s the optional source of the value
template <typename U>
- inline Expect(U&& val, const Source& s = {}) // NOLINT
- : value(std::forward<U>(val)), source(s) {}
+ inline Expect(U&& val) // NOLINT
+ : value(std::forward<U>(val)) {}
/// Constructor for parse error.
inline Expect(Failure::Errored) : errored(true) {} // NOLINT
@@ -130,8 +129,6 @@
/// The expected value of a successful parse.
/// Zero-initialized when there was a parse error.
T value{};
- /// Optional source of the value.
- Source source;
/// True if there was a error parsing.
bool errored = false;
};
@@ -565,8 +562,13 @@
/// @param use the use of the expression list
/// @param terminator the terminating token for the list
/// @returns the parsed expression list or error
- Expect<utils::Vector<const ast::Expression*, 3>> expect_expression_list(std::string_view use,
- Token::Type terminator);
+ Maybe<ParserImpl::ExpressionList> expression_list(std::string_view use, Token::Type terminator);
+ /// Parses a comma separated expression list, with at least one expression
+ /// @param use the use of the expression list
+ /// @param terminator the terminating token for the list
+ /// @returns the parsed expression list or error
+ Expect<ParserImpl::ExpressionList> expect_expression_list(std::string_view use,
+ Token::Type terminator);
/// Parses the `bitwise_expression.post.unary_expression` grammar element
/// @param lhs the left side of the expression
/// @returns the parsed expression or nullptr
@@ -679,7 +681,9 @@
/// Consumes the next token on match.
/// @param use a description of what was being parsed if an error was raised
/// @returns the parsed integer.
- Expect<int32_t> expect_sint(std::string_view use);
+ /// @param source if not nullptr, the next token's source is written to this
+ /// pointer, regardless of success or error
+ Expect<int32_t> expect_sint(std::string_view use, Source* source = nullptr);
/// Parses a signed integer from the next token in the stream, erroring if
/// the next token is not a signed integer or is negative.
/// Consumes the next token if it is a signed integer (not necessarily
@@ -820,7 +824,28 @@
/// Reports an error if the attribute list `list` is not empty.
/// Used to ensure that all attributes are consumed.
- bool expect_attributes_consumed(utils::VectorRef<const ast::Attribute*> list);
+ Expect<Void> expect_attributes_consumed(utils::VectorRef<const ast::Attribute*> list);
+
+ /// Raises an error if the next token is the start of a template list.
+ /// Used to hint to the user that the parser interpreted the following as a templated identifier
+ /// expression:
+ ///
+ /// ```
+ /// a < b, c >
+ /// ^~~~~~~~
+ /// ```
+ Expect<Void> expect_next_not_template_list(const Source& lhs_source);
+
+ /// Raises an error if the parsed expression is a templated identifier expression
+ /// Used to hint to the user that the parser intepreted the following as a templated identifier
+ /// expression:
+ ///
+ /// ```
+ /// a < b, c > d
+ /// ^^^^^^^^^^
+ /// expr
+ /// ```
+ Expect<Void> expect_not_templated_ident_expr(const ast::Expression* expr);
/// Parses the given enum, providing sensible error messages if the next token does not match
/// any of the enum values.
diff --git a/src/tint/reader/wgsl/parser_impl_argument_expression_list_test.cc b/src/tint/reader/wgsl/parser_impl_argument_expression_list_test.cc
index e22a706..fbab695 100644
--- a/src/tint/reader/wgsl/parser_impl_argument_expression_list_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_argument_expression_list_test.cc
@@ -72,7 +72,7 @@
auto e = p->expect_argument_expression_list("argument list");
ASSERT_TRUE(p->has_error());
ASSERT_TRUE(e.errored);
- EXPECT_EQ(p->error(), "1:3: expected ')' for argument list");
+ EXPECT_EQ(p->error(), "1:3: expected ',' for argument list");
}
TEST_F(ParserImplTest, ArgumentExpressionList_HandlesMissingExpression_0) {
@@ -80,7 +80,7 @@
auto e = p->expect_argument_expression_list("argument list");
ASSERT_TRUE(p->has_error());
ASSERT_TRUE(e.errored);
- EXPECT_EQ(p->error(), "1:2: expected ')' for argument list");
+ EXPECT_EQ(p->error(), "1:2: expected expression for argument list");
}
TEST_F(ParserImplTest, ArgumentExpressionList_HandlesMissingExpression_1) {
@@ -88,7 +88,7 @@
auto e = p->expect_argument_expression_list("argument list");
ASSERT_TRUE(p->has_error());
ASSERT_TRUE(e.errored);
- EXPECT_EQ(p->error(), "1:5: expected ')' for argument list");
+ EXPECT_EQ(p->error(), "1:5: expected expression for argument list");
}
TEST_F(ParserImplTest, ArgumentExpressionList_HandlesInvalidExpression) {
@@ -96,7 +96,7 @@
auto e = p->expect_argument_expression_list("argument list");
ASSERT_TRUE(p->has_error());
ASSERT_TRUE(e.errored);
- EXPECT_EQ(p->error(), "1:2: expected ')' for argument list");
+ EXPECT_EQ(p->error(), "1:2: expected expression for argument list");
}
} // namespace
diff --git a/src/tint/reader/wgsl/parser_impl_call_stmt_test.cc b/src/tint/reader/wgsl/parser_impl_call_stmt_test.cc
index 5fd1b59..d9a2491 100644
--- a/src/tint/reader/wgsl/parser_impl_call_stmt_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_call_stmt_test.cc
@@ -83,7 +83,7 @@
EXPECT_TRUE(p->has_error());
EXPECT_TRUE(e.errored);
EXPECT_FALSE(e.matched);
- EXPECT_EQ(p->error(), "1:3: expected ')' for function call");
+ EXPECT_EQ(p->error(), "1:3: expected expression for function call");
}
TEST_F(ParserImplTest, Statement_Call_Missing_Semi) {
@@ -101,7 +101,7 @@
EXPECT_TRUE(p->has_error());
EXPECT_TRUE(e.errored);
EXPECT_FALSE(e.matched);
- EXPECT_EQ(p->error(), "1:5: expected ')' for function call");
+ EXPECT_EQ(p->error(), "1:5: expected ',' for function call");
}
} // namespace
diff --git a/src/tint/reader/wgsl/parser_impl_error_msg_test.cc b/src/tint/reader/wgsl/parser_impl_error_msg_test.cc
index af50894..9156931 100644
--- a/src/tint/reader/wgsl/parser_impl_error_msg_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_error_msg_test.cc
@@ -158,15 +158,39 @@
TEST_F(ParserImplErrorTest, CallExprMissingRParen) {
EXPECT("fn f() { x = f(1.; }",
- R"(test.wgsl:1:18 error: expected ')' for function call
+ R"(test.wgsl:1:18 error: expected ',' for function call
fn f() { x = f(1.; }
^
)");
}
+TEST_F(ParserImplErrorTest, CallStmtArgsTreatedAsTemplateLHS) {
+ EXPECT("fn f() { f( a, b.x < c, d > e ); }",
+ R"(test.wgsl:1:16 error: parsed as template list
+fn f() { f( a, b.x < c, d > e ); }
+ ^^^^^^^^^^^^
+
+test.wgsl:1:16 note: if this is intended to be a less-than expression then wrap in parentheses
+fn f() { f( a, b.x < c, d > e ); }
+ ^^^^^^^
+)");
+}
+
+TEST_F(ParserImplErrorTest, CallStmtArgsTreatedAsTemplateRHS) {
+ EXPECT("fn f() { f( a, b < c, d > e ); }",
+ R"(test.wgsl:1:16 error: parsed as template list
+fn f() { f( a, b < c, d > e ); }
+ ^^^^^^^^^^
+
+test.wgsl:1:23 note: if this is intended to be a greater-than expression then wrap in parentheses
+fn f() { f( a, b < c, d > e ); }
+ ^^^^^
+)");
+}
+
TEST_F(ParserImplErrorTest, CallStmtMissingRParen) {
EXPECT("fn f() { f(1.; }",
- R"(test.wgsl:1:14 error: expected ')' for function call
+ R"(test.wgsl:1:14 error: expected ',' for function call
fn f() { f(1.; }
^
)");
@@ -174,7 +198,7 @@
TEST_F(ParserImplErrorTest, CallStmtInvalidArgument0) {
EXPECT("fn f() { f(<); }",
- R"(test.wgsl:1:12 error: expected ')' for function call
+ R"(test.wgsl:1:12 error: expected expression for function call
fn f() { f(<); }
^
)");
@@ -182,7 +206,7 @@
TEST_F(ParserImplErrorTest, CallStmtInvalidArgument1) {
EXPECT("fn f() { f(1.0, <); }",
- R"(test.wgsl:1:17 error: expected ')' for function call
+ R"(test.wgsl:1:17 error: expected expression for function call
fn f() { f(1.0, <); }
^
)");
@@ -206,7 +230,7 @@
TEST_F(ParserImplErrorTest, InitializerExprMissingRParen) {
EXPECT("fn f() { x = vec2<u32>(1,2; }",
- R"(test.wgsl:1:27 error: expected ')' for function call
+ R"(test.wgsl:1:27 error: expected ',' for function call
fn f() { x = vec2<u32>(1,2; }
^
)");
@@ -498,7 +522,7 @@
TEST_F(ParserImplErrorTest, GlobalDeclConstMissingRParen) {
EXPECT("const i : vec2<i32> = vec2<i32>(1., 2.;",
- R"(test.wgsl:1:39 error: expected ')' for function call
+ R"(test.wgsl:1:39 error: expected ',' for function call
const i : vec2<i32> = vec2<i32>(1., 2.;
^
)");
@@ -549,7 +573,7 @@
TEST_F(ParserImplErrorTest, GlobalDeclConstExprMissingRParen) {
EXPECT("const i : vec2<i32> = vec2<i32>(1, 2;",
- R"(test.wgsl:1:37 error: expected ')' for function call
+ R"(test.wgsl:1:37 error: expected ',' for function call
const i : vec2<i32> = vec2<i32>(1, 2;
^
)");
diff --git a/src/tint/reader/wgsl/parser_impl_for_stmt_test.cc b/src/tint/reader/wgsl/parser_impl_for_stmt_test.cc
index 528be9f..e97ead9 100644
--- a/src/tint/reader/wgsl/parser_impl_for_stmt_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_for_stmt_test.cc
@@ -314,7 +314,7 @@
// Test a for loop with an invalid continuing statement.
TEST_F(ForStmtErrorTest, InvalidContinuingAsFuncCall) {
std::string for_str = "for (;; a(,) ) { }";
- std::string error_str = "1:11: expected ')' for function call";
+ std::string error_str = "1:11: expected expression for function call";
TestForWithError(for_str, error_str);
}
diff --git a/src/tint/reader/wgsl/parser_impl_primary_expression_test.cc b/src/tint/reader/wgsl/parser_impl_primary_expression_test.cc
index 904e881..51392da 100644
--- a/src/tint/reader/wgsl/parser_impl_primary_expression_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_primary_expression_test.cc
@@ -84,7 +84,7 @@
EXPECT_TRUE(e.errored);
EXPECT_EQ(e.value, nullptr);
ASSERT_TRUE(p->has_error());
- EXPECT_EQ(p->error(), "1:25: expected ')' for function call");
+ EXPECT_EQ(p->error(), "1:25: expected ',' for function call");
}
TEST_F(ParserImplTest, PrimaryExpression_TypeDecl_InvalidValue) {
@@ -94,7 +94,7 @@
EXPECT_TRUE(e.errored);
EXPECT_EQ(e.value, nullptr);
ASSERT_TRUE(p->has_error());
- EXPECT_EQ(p->error(), "1:5: expected ')' for function call");
+ EXPECT_EQ(p->error(), "1:5: expected expression for function call");
}
TEST_F(ParserImplTest, PrimaryExpression_TypeDecl_StructInitializer_Empty) {
diff --git a/src/tint/reader/wgsl/parser_impl_singular_expression_test.cc b/src/tint/reader/wgsl/parser_impl_singular_expression_test.cc
index 8809ba5..3f5dc16 100644
--- a/src/tint/reader/wgsl/parser_impl_singular_expression_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_singular_expression_test.cc
@@ -55,6 +55,11 @@
EXPECT_EQ(ident_expr->identifier->symbol, p->builder().Symbols().Get("a"));
ASSERT_TRUE(idx->index->Is<ast::BinaryExpression>());
+
+ EXPECT_EQ(e->source.range.begin.line, 1u);
+ EXPECT_EQ(e->source.range.begin.column, 1u);
+ EXPECT_EQ(e->source.range.end.line, 1u);
+ EXPECT_EQ(e->source.range.end.column, 13u);
}
TEST_F(ParserImplTest, SingularExpression_Array_MissingIndex) {
@@ -141,7 +146,7 @@
EXPECT_TRUE(e.errored);
EXPECT_EQ(e.value, nullptr);
EXPECT_TRUE(p->has_error());
- EXPECT_EQ(p->error(), "1:3: expected ')' for function call");
+ EXPECT_EQ(p->error(), "1:3: expected expression for function call");
}
TEST_F(ParserImplTest, SingularExpression_Call_MissingRightParen) {
@@ -151,7 +156,7 @@
EXPECT_TRUE(e.errored);
EXPECT_EQ(e.value, nullptr);
EXPECT_TRUE(p->has_error());
- EXPECT_EQ(p->error(), "1:3: expected ')' for function call");
+ EXPECT_EQ(p->error(), "1:3: expected expression for function call");
}
TEST_F(ParserImplTest, SingularExpression_MemberAccessor) {
@@ -169,6 +174,11 @@
p->builder().Symbols().Get("a"));
EXPECT_EQ(m->member->symbol, p->builder().Symbols().Get("b"));
+
+ EXPECT_EQ(e->source.range.begin.line, 1u);
+ EXPECT_EQ(e->source.range.begin.column, 1u);
+ EXPECT_EQ(e->source.range.end.line, 1u);
+ EXPECT_EQ(e->source.range.end.column, 4u);
}
TEST_F(ParserImplTest, SingularExpression_MemberAccesssor_InvalidIdent) {
@@ -212,20 +222,45 @@
const auto* outer_accessor = e->As<ast::IndexAccessorExpression>();
ASSERT_TRUE(outer_accessor);
+ EXPECT_EQ(outer_accessor->source.range.begin.line, 1u);
+ EXPECT_EQ(outer_accessor->source.range.begin.column, 1u);
+ EXPECT_EQ(outer_accessor->source.range.end.line, 1u);
+ EXPECT_EQ(outer_accessor->source.range.end.column, 8u);
+
const auto* outer_object = outer_accessor->object->As<ast::IdentifierExpression>();
ASSERT_TRUE(outer_object);
EXPECT_EQ(outer_object->identifier->symbol, p->builder().Symbols().Get("a"));
+ EXPECT_EQ(outer_object->source.range.begin.line, 1u);
+ EXPECT_EQ(outer_object->source.range.begin.column, 1u);
+ EXPECT_EQ(outer_object->source.range.end.line, 1u);
+ EXPECT_EQ(outer_object->source.range.end.column, 2u);
+
const auto* inner_accessor = outer_accessor->index->As<ast::IndexAccessorExpression>();
ASSERT_TRUE(inner_accessor);
+ EXPECT_EQ(inner_accessor->source.range.begin.line, 1u);
+ EXPECT_EQ(inner_accessor->source.range.begin.column, 3u);
+ EXPECT_EQ(inner_accessor->source.range.end.line, 1u);
+ EXPECT_EQ(inner_accessor->source.range.end.column, 7u);
+
const auto* inner_object = inner_accessor->object->As<ast::IdentifierExpression>();
ASSERT_TRUE(inner_object);
EXPECT_EQ(inner_object->identifier->symbol, p->builder().Symbols().Get("b"));
+ EXPECT_EQ(inner_object->source.range.begin.line, 1u);
+ EXPECT_EQ(inner_object->source.range.begin.column, 3u);
+ EXPECT_EQ(inner_object->source.range.end.line, 1u);
+ EXPECT_EQ(inner_object->source.range.end.column, 4u);
+
const auto* index_expr = inner_accessor->index->As<ast::IdentifierExpression>();
ASSERT_TRUE(index_expr);
EXPECT_EQ(index_expr->identifier->symbol, p->builder().Symbols().Get("c"));
+
+ EXPECT_EQ(index_expr->source.range.begin.line, 1u);
+ EXPECT_EQ(index_expr->source.range.begin.column, 5u);
+ EXPECT_EQ(index_expr->source.range.end.line, 1u);
+ EXPECT_EQ(index_expr->source.range.end.column, 6u);
}
} // namespace
diff --git a/src/tint/type/manager.cc b/src/tint/type/manager.cc
index 136f910..d9c8667 100644
--- a/src/tint/type/manager.cc
+++ b/src/tint/type/manager.cc
@@ -22,6 +22,7 @@
#include "src/tint/type/f32.h"
#include "src/tint/type/i32.h"
#include "src/tint/type/matrix.h"
+#include "src/tint/type/pointer.h"
#include "src/tint/type/type.h"
#include "src/tint/type/u32.h"
#include "src/tint/type/vector.h"
@@ -155,4 +156,10 @@
/* implicit stride */ elem_ty->Align());
}
+const type::Pointer* Manager::pointer(const type::Type* subtype,
+ builtin::AddressSpace address_space,
+ builtin::Access access) {
+ return Get<type::Pointer>(subtype, address_space, access);
+}
+
} // namespace tint::type
diff --git a/src/tint/type/manager.h b/src/tint/type/manager.h
index 6492e53..fa29ae0 100644
--- a/src/tint/type/manager.h
+++ b/src/tint/type/manager.h
@@ -17,6 +17,8 @@
#include <utility>
+#include "src/tint/builtin/access.h"
+#include "src/tint/builtin/address_space.h"
#include "src/tint/type/type.h"
#include "src/tint/type/unique_node.h"
#include "src/tint/utils/hash.h"
@@ -32,6 +34,7 @@
class F32;
class I32;
class Matrix;
+class Pointer;
class U32;
class Vector;
class Void;
@@ -194,6 +197,14 @@
/// @returns the runtime array type
const type::Array* runtime_array(const type::Type* elem_ty, uint32_t stride = 0);
+ /// @param subtype the pointer subtype
+ /// @param address_space the address space
+ /// @param access the access settings
+ /// @returns the pointer type
+ const type::Pointer* pointer(const type::Type* subtype,
+ builtin::AddressSpace address_space,
+ builtin::Access access);
+
/// @returns an iterator to the beginning of the types
TypeIterator begin() const { return types_.begin(); }
/// @returns an iterator to the end of the types
diff --git a/src/tint/utils/hashmap_base.h b/src/tint/utils/hashmap_base.h
index f31dcf2..73cb0d7 100644
--- a/src/tint/utils/hashmap_base.h
+++ b/src/tint/utils/hashmap_base.h
@@ -25,6 +25,10 @@
#include "src/tint/utils/hash.h"
#include "src/tint/utils/vector.h"
+#ifndef NDEBUG
+#define TINT_ASSERT_ITERATORS_NOT_INVALIDATED
+#endif
+
namespace tint::utils {
/// Action taken by a map mutation
@@ -190,10 +194,20 @@
class IteratorT {
public:
/// @returns the value pointed to by this iterator
- EntryRef<IS_CONST> operator->() const { return *this; }
+ EntryRef<IS_CONST> operator->() const {
+#ifdef TINT_ASSERT_ITERATORS_NOT_INVALIDATED
+ TINT_ASSERT(Utils, map.Generation() == initial_generation &&
+ "iterator invalidated by container modification");
+#endif
+ return *this;
+ }
/// @returns a reference to the value at the iterator
EntryRef<IS_CONST> operator*() const {
+#ifdef TINT_ASSERT_ITERATORS_NOT_INVALIDATED
+ TINT_ASSERT(Utils, map.Generation() == initial_generation &&
+ "iterator invalidated by container modification");
+#endif
auto& ref = current->entry.value();
if constexpr (ValueIsVoid) {
return ref;
@@ -205,6 +219,10 @@
/// Increments the iterator
/// @returns this iterator
IteratorT& operator++() {
+#ifdef TINT_ASSERT_ITERATORS_NOT_INVALIDATED
+ TINT_ASSERT(Utils, map.Generation() == initial_generation &&
+ "iterator invalidated by container modification");
+#endif
if (current == end) {
return *this;
}
@@ -216,12 +234,24 @@
/// Equality operator
/// @param other the other iterator to compare this iterator to
/// @returns true if this iterator is equal to other
- bool operator==(const IteratorT& other) const { return current == other.current; }
+ bool operator==(const IteratorT& other) const {
+#ifdef TINT_ASSERT_ITERATORS_NOT_INVALIDATED
+ TINT_ASSERT(Utils, map.Generation() == initial_generation &&
+ "iterator invalidated by container modification");
+#endif
+ return current == other.current;
+ }
/// Inequality operator
/// @param other the other iterator to compare this iterator to
/// @returns true if this iterator is not equal to other
- bool operator!=(const IteratorT& other) const { return current != other.current; }
+ bool operator!=(const IteratorT& other) const {
+#ifdef TINT_ASSERT_ITERATORS_NOT_INVALIDATED
+ TINT_ASSERT(Utils, map.Generation() == initial_generation &&
+ "iterator invalidated by container modification");
+#endif
+ return current != other.current;
+ }
private:
/// Friend class
@@ -229,7 +259,17 @@
using SLOT = std::conditional_t<IS_CONST, const Slot, Slot>;
- IteratorT(SLOT* c, SLOT* e) : current(c), end(e) { SkipToNextValue(); }
+ IteratorT(SLOT* c, SLOT* e, [[maybe_unused]] const HashmapBase& m)
+ : current(c),
+ end(e)
+#ifdef TINT_ASSERT_ITERATORS_NOT_INVALIDATED
+ ,
+ map(m),
+ initial_generation(m.Generation())
+#endif
+ {
+ SkipToNextValue();
+ }
/// Moves the iterator forward, stopping at the next slot that is not empty.
void SkipToNextValue() {
@@ -240,6 +280,11 @@
SLOT* current; /// The slot the iterator is pointing to
SLOT* end; /// One past the last slot in the map
+
+#ifdef TINT_ASSERT_ITERATORS_NOT_INVALIDATED
+ const HashmapBase& map; /// The hashmap that is being iterated over.
+ size_t initial_generation; /// The generation ID when the iterator was created.
+#endif
};
/// An immutable key and mutable value iterator
@@ -373,16 +418,16 @@
size_t Generation() const { return generation_; }
/// @returns an immutable iterator to the start of the map.
- ConstIterator begin() const { return ConstIterator{slots_.begin(), slots_.end()}; }
+ ConstIterator begin() const { return ConstIterator{slots_.begin(), slots_.end(), *this}; }
/// @returns an immutable iterator to the end of the map.
- ConstIterator end() const { return ConstIterator{slots_.end(), slots_.end()}; }
+ ConstIterator end() const { return ConstIterator{slots_.end(), slots_.end(), *this}; }
/// @returns an iterator to the start of the map.
- Iterator begin() { return Iterator{slots_.begin(), slots_.end()}; }
+ Iterator begin() { return Iterator{slots_.begin(), slots_.end(), *this}; }
/// @returns an iterator to the end of the map.
- Iterator end() { return Iterator{slots_.end(), slots_.end()}; }
+ Iterator end() { return Iterator{slots_.end(), slots_.end(), *this}; }
/// A debug function for checking that the map is in good health.
/// Asserts if the map is corrupted.
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir.cc b/src/tint/writer/spirv/ir/generator_impl_ir.cc
index 89e6a1f..6196254 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir.cc
@@ -168,6 +168,14 @@
}
module_.PushType(spv::Op::OpConstantComposite, operands);
},
+ [&](const type::Array* arr) {
+ TINT_ASSERT(Writer, arr->ConstantCount());
+ OperandList operands = {Type(ty), id};
+ for (uint32_t i = 0; i < arr->ConstantCount(); i++) {
+ operands.push_back(Constant(constant->Index(i)));
+ }
+ module_.PushType(spv::Op::OpConstantComposite, operands);
+ },
[&](Default) {
TINT_ICE(Writer, diagnostics_) << "unhandled constant type: " << ty->FriendlyName();
});
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc
index f6e5015..5ae21b3 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc
@@ -36,7 +36,7 @@
TEST_P(Arithmetic, Scalar) {
auto params = GetParam();
- auto* func = b.CreateFunction("foo", mod.Types().void_());
+ auto* func = b.CreateFunction("foo", ty.void_());
func->StartTarget()->SetInstructions(
utils::Vector{b.CreateBinary(params.kind, MakeScalarType(params.type),
MakeScalarValue(params.type), MakeScalarValue(params.type)),
@@ -48,7 +48,7 @@
TEST_P(Arithmetic, Vector) {
auto params = GetParam();
- auto* func = b.CreateFunction("foo", mod.Types().void_());
+ auto* func = b.CreateFunction("foo", ty.void_());
func->StartTarget()->SetInstructions(
utils::Vector{b.CreateBinary(params.kind, MakeVectorType(params.type),
MakeVectorValue(params.type), MakeVectorValue(params.type)),
@@ -83,7 +83,7 @@
TEST_P(Bitwise, Scalar) {
auto params = GetParam();
- auto* func = b.CreateFunction("foo", mod.Types().void_());
+ auto* func = b.CreateFunction("foo", ty.void_());
func->StartTarget()->SetInstructions(
utils::Vector{b.CreateBinary(params.kind, MakeScalarType(params.type),
MakeScalarValue(params.type), MakeScalarValue(params.type)),
@@ -95,7 +95,7 @@
TEST_P(Bitwise, Vector) {
auto params = GetParam();
- auto* func = b.CreateFunction("foo", mod.Types().void_());
+ auto* func = b.CreateFunction("foo", ty.void_());
func->StartTarget()->SetInstructions(
utils::Vector{b.CreateBinary(params.kind, MakeVectorType(params.type),
MakeVectorValue(params.type), MakeVectorValue(params.type)),
@@ -122,9 +122,9 @@
TEST_P(Comparison, Scalar) {
auto params = GetParam();
- auto* func = b.CreateFunction("foo", mod.Types().void_());
+ auto* func = b.CreateFunction("foo", ty.void_());
func->StartTarget()->SetInstructions(
- utils::Vector{b.CreateBinary(params.kind, mod.Types().bool_(), MakeScalarValue(params.type),
+ utils::Vector{b.CreateBinary(params.kind, ty.bool_(), MakeScalarValue(params.type),
MakeScalarValue(params.type)),
b.Return(func)});
@@ -134,10 +134,10 @@
TEST_P(Comparison, Vector) {
auto params = GetParam();
- auto* func = b.CreateFunction("foo", mod.Types().void_());
+ auto* func = b.CreateFunction("foo", ty.void_());
func->StartTarget()->SetInstructions(
- utils::Vector{b.CreateBinary(params.kind, mod.Types().vec2(mod.Types().bool_()),
- MakeVectorValue(params.type), MakeVectorValue(params.type)),
+ utils::Vector{b.CreateBinary(params.kind, ty.vec2(ty.bool_()), MakeVectorValue(params.type),
+ MakeVectorValue(params.type)),
b.Return(func)});
@@ -191,10 +191,9 @@
BinaryTestCase{kBool, ir::Binary::Kind::kNotEqual, "OpLogicalNotEqual"}));
TEST_F(SpvGeneratorImplTest, Binary_Chain) {
- auto* func = b.CreateFunction("foo", mod.Types().void_());
- auto* a = b.Subtract(mod.Types().i32(), b.Constant(1_i), b.Constant(2_i));
- func->StartTarget()->SetInstructions(
- utils::Vector{a, b.Add(mod.Types().i32(), a, a), b.Return(func)});
+ auto* func = b.CreateFunction("foo", ty.void_());
+ auto* a = b.Subtract(ty.i32(), b.Constant(1_i), b.Constant(2_i));
+ func->StartTarget()->SetInstructions(utils::Vector{a, b.Add(ty.i32(), a, a), b.Return(func)});
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_builtin_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_builtin_test.cc
index 3bc06c1..69d3f0f 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_builtin_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_builtin_test.cc
@@ -37,7 +37,7 @@
TEST_P(Builtin_1arg, Scalar) {
auto params = GetParam();
- auto* func = b.CreateFunction("foo", mod.Types().void_());
+ auto* func = b.CreateFunction("foo", ty.void_());
func->StartTarget()->SetInstructions(
utils::Vector{b.Builtin(MakeScalarType(params.type), params.function,
utils::Vector{MakeScalarValue(params.type)}),
@@ -49,7 +49,7 @@
TEST_P(Builtin_1arg, Vector) {
auto params = GetParam();
- auto* func = b.CreateFunction("foo", mod.Types().void_());
+ auto* func = b.CreateFunction("foo", ty.void_());
func->StartTarget()->SetInstructions(
utils::Vector{b.Builtin(MakeVectorType(params.type), params.function,
utils::Vector{MakeVectorValue(params.type)}),
@@ -110,7 +110,7 @@
TEST_P(Builtin_2arg, Scalar) {
auto params = GetParam();
- auto* func = b.CreateFunction("foo", mod.Types().void_());
+ auto* func = b.CreateFunction("foo", ty.void_());
func->StartTarget()->SetInstructions(utils::Vector{
b.Builtin(MakeScalarType(params.type), params.function,
utils::Vector{MakeScalarValue(params.type), MakeScalarValue(params.type)}),
@@ -122,7 +122,7 @@
TEST_P(Builtin_2arg, Vector) {
auto params = GetParam();
- auto* func = b.CreateFunction("foo", mod.Types().void_());
+ auto* func = b.CreateFunction("foo", ty.void_());
func->StartTarget()->SetInstructions(utils::Vector{
b.Builtin(MakeVectorType(params.type), params.function,
utils::Vector{MakeVectorValue(params.type), MakeVectorValue(params.type)}),
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_constant_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_constant_test.cc
index 7cb240a..b0b7665 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_constant_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_constant_test.cc
@@ -17,6 +17,8 @@
namespace tint::writer::spirv {
namespace {
+using namespace tint::number_suffixes; // NOLINT
+
TEST_F(SpvGeneratorImplTest, Constant_Bool) {
generator_.Constant(b.Constant(true));
generator_.Constant(b.Constant(false));
@@ -65,7 +67,7 @@
TEST_F(SpvGeneratorImplTest, Constant_Vec4Bool) {
auto const_bool = [&](bool val) { return mod.constant_values.Get(val); };
auto* v = mod.constant_values.Composite(
- mod.Types().vec4(mod.Types().bool_()),
+ ty.vec4(ty.bool_()),
utils::Vector{const_bool(true), const_bool(false), const_bool(false), const_bool(true)});
generator_.Constant(b.Constant(v));
@@ -79,7 +81,7 @@
TEST_F(SpvGeneratorImplTest, Constant_Vec2i) {
auto const_i32 = [&](float val) { return mod.constant_values.Get(i32(val)); };
- auto* v = mod.constant_values.Composite(mod.Types().vec2(mod.Types().i32()),
+ auto* v = mod.constant_values.Composite(ty.vec2(ty.i32()),
utils::Vector{const_i32(42), const_i32(-1)});
generator_.Constant(b.Constant(v));
EXPECT_EQ(DumpTypes(), R"(%3 = OpTypeInt 32 1
@@ -93,8 +95,7 @@
TEST_F(SpvGeneratorImplTest, Constant_Vec3u) {
auto const_u32 = [&](float val) { return mod.constant_values.Get(u32(val)); };
auto* v = mod.constant_values.Composite(
- mod.Types().vec3(mod.Types().u32()),
- utils::Vector{const_u32(42), const_u32(0), const_u32(4000000000)});
+ ty.vec3(ty.u32()), utils::Vector{const_u32(42), const_u32(0), const_u32(4000000000)});
generator_.Constant(b.Constant(v));
EXPECT_EQ(DumpTypes(), R"(%3 = OpTypeInt 32 0
%2 = OpTypeVector %3 3
@@ -108,7 +109,7 @@
TEST_F(SpvGeneratorImplTest, Constant_Vec4f) {
auto const_f32 = [&](float val) { return mod.constant_values.Get(f32(val)); };
auto* v = mod.constant_values.Composite(
- mod.Types().vec4(mod.Types().f32()),
+ ty.vec4(ty.f32()),
utils::Vector{const_f32(42), const_f32(0), const_f32(0.25), const_f32(-1)});
generator_.Constant(b.Constant(v));
EXPECT_EQ(DumpTypes(), R"(%3 = OpTypeFloat 32
@@ -123,7 +124,7 @@
TEST_F(SpvGeneratorImplTest, Constant_Vec2h) {
auto const_f16 = [&](float val) { return mod.constant_values.Get(f16(val)); };
- auto* v = mod.constant_values.Composite(mod.Types().vec2(mod.Types().f16()),
+ auto* v = mod.constant_values.Composite(ty.vec2(ty.f16()),
utils::Vector{const_f16(42), const_f16(0.25)});
generator_.Constant(b.Constant(v));
EXPECT_EQ(DumpTypes(), R"(%3 = OpTypeFloat 16
@@ -136,16 +137,14 @@
TEST_F(SpvGeneratorImplTest, Constant_Mat2x3f) {
auto const_f32 = [&](float val) { return mod.constant_values.Get(f32(val)); };
- auto* f32 = mod.Types().f32();
+ auto* f32 = ty.f32();
auto* v = mod.constant_values.Composite(
- mod.Types().mat2x3(f32),
+ ty.mat2x3(f32),
utils::Vector{
mod.constant_values.Composite(
- mod.Types().vec3(f32),
- utils::Vector{const_f32(42), const_f32(-1), const_f32(0.25)}),
+ ty.vec3(f32), utils::Vector{const_f32(42), const_f32(-1), const_f32(0.25)}),
mod.constant_values.Composite(
- mod.Types().vec3(f32),
- utils::Vector{const_f32(-42), const_f32(0), const_f32(-0.25)}),
+ ty.vec3(f32), utils::Vector{const_f32(-42), const_f32(0), const_f32(-0.25)}),
});
generator_.Constant(b.Constant(v));
EXPECT_EQ(DumpTypes(), R"(%4 = OpTypeFloat 32
@@ -165,19 +164,18 @@
TEST_F(SpvGeneratorImplTest, Constant_Mat4x2h) {
auto const_f16 = [&](float val) { return mod.constant_values.Get(f16(val)); };
- auto* f16 = mod.Types().f16();
+ auto* f16 = ty.f16();
auto* v = mod.constant_values.Composite(
- mod.Types().mat4x2(f16),
- utils::Vector{
- mod.constant_values.Composite(mod.Types().vec2(f16),
- utils::Vector{const_f16(42), const_f16(-1)}),
- mod.constant_values.Composite(mod.Types().vec2(f16),
- utils::Vector{const_f16(0), const_f16(0.25)}),
- mod.constant_values.Composite(mod.Types().vec2(f16),
- utils::Vector{const_f16(-42), const_f16(1)}),
- mod.constant_values.Composite(mod.Types().vec2(f16),
- utils::Vector{const_f16(0.5), const_f16(-0)}),
- });
+ ty.mat4x2(f16), utils::Vector{
+ mod.constant_values.Composite(
+ ty.vec2(f16), utils::Vector{const_f16(42), const_f16(-1)}),
+ mod.constant_values.Composite(
+ ty.vec2(f16), utils::Vector{const_f16(0), const_f16(0.25)}),
+ mod.constant_values.Composite(
+ ty.vec2(f16), utils::Vector{const_f16(-42), const_f16(1)}),
+ mod.constant_values.Composite(
+ ty.vec2(f16), utils::Vector{const_f16(0.5), const_f16(-0)}),
+ });
generator_.Constant(b.Constant(v));
EXPECT_EQ(DumpTypes(), R"(%4 = OpTypeFloat 16
%3 = OpTypeVector %4 2
@@ -197,6 +195,56 @@
)");
}
+TEST_F(SpvGeneratorImplTest, Constant_Array_I32) {
+ auto* arr =
+ mod.constant_values.Composite(ty.array(ty.i32(), 4), utils::Vector{
+ mod.constant_values.Get(1_i),
+ mod.constant_values.Get(2_i),
+ mod.constant_values.Get(3_i),
+ mod.constant_values.Get(4_i),
+ });
+ generator_.Constant(b.Constant(arr));
+ EXPECT_EQ(DumpTypes(), R"(%3 = OpTypeInt 32 1
+%5 = OpTypeInt 32 0
+%4 = OpConstant %5 4
+%2 = OpTypeArray %3 %4
+%6 = OpConstant %3 1
+%7 = OpConstant %3 2
+%8 = OpConstant %3 3
+%9 = OpConstant %3 4
+%1 = OpConstantComposite %2 %6 %7 %8 %9
+)");
+}
+
+TEST_F(SpvGeneratorImplTest, Constant_Array_Array_I32) {
+ auto* inner =
+ mod.constant_values.Composite(ty.array(ty.i32(), 4), utils::Vector{
+ mod.constant_values.Get(1_i),
+ mod.constant_values.Get(2_i),
+ mod.constant_values.Get(3_i),
+ mod.constant_values.Get(4_i),
+ });
+ auto* arr = mod.constant_values.Composite(ty.array(ty.array(ty.i32(), 4), 4), utils::Vector{
+ inner,
+ inner,
+ inner,
+ inner,
+ });
+ generator_.Constant(b.Constant(arr));
+ EXPECT_EQ(DumpTypes(), R"(%4 = OpTypeInt 32 1
+%6 = OpTypeInt 32 0
+%5 = OpConstant %6 4
+%3 = OpTypeArray %4 %5
+%2 = OpTypeArray %3 %5
+%8 = OpConstant %4 1
+%9 = OpConstant %4 2
+%10 = OpConstant %4 3
+%11 = OpConstant %4 4
+%7 = OpConstantComposite %3 %8 %9 %10 %11
+%1 = OpConstantComposite %2 %7 %7 %7 %7
+)");
+}
+
// Test that we do not emit the same constant more than once.
TEST_F(SpvGeneratorImplTest, Constant_Deduplicate) {
generator_.Constant(b.Constant(i32(42)));
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_function_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_function_test.cc
index 6412bc6..0eaa859 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_function_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_function_test.cc
@@ -18,7 +18,7 @@
namespace {
TEST_F(SpvGeneratorImplTest, Function_Empty) {
- auto* func = b.CreateFunction("foo", mod.Types().void_());
+ auto* func = b.CreateFunction("foo", ty.void_());
func->StartTarget()->SetInstructions(utils::Vector{b.Return(func)});
generator_.EmitFunction(func);
@@ -34,7 +34,7 @@
// Test that we do not emit the same function type more than once.
TEST_F(SpvGeneratorImplTest, Function_DeduplicateType) {
- auto* func = b.CreateFunction("foo", mod.Types().void_());
+ auto* func = b.CreateFunction("foo", ty.void_());
func->StartTarget()->SetInstructions(utils::Vector{b.Return(func)});
generator_.EmitFunction(func);
@@ -46,8 +46,8 @@
}
TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Compute) {
- auto* func = b.CreateFunction("main", mod.Types().void_(),
- ir::Function::PipelineStage::kCompute, {{32, 4, 1}});
+ auto* func =
+ b.CreateFunction("main", ty.void_(), ir::Function::PipelineStage::kCompute, {{32, 4, 1}});
func->StartTarget()->SetInstructions(utils::Vector{b.Return(func)});
generator_.EmitFunction(func);
@@ -64,8 +64,7 @@
}
TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Fragment) {
- auto* func =
- b.CreateFunction("main", mod.Types().void_(), ir::Function::PipelineStage::kFragment);
+ auto* func = b.CreateFunction("main", ty.void_(), ir::Function::PipelineStage::kFragment);
func->StartTarget()->SetInstructions(utils::Vector{b.Return(func)});
generator_.EmitFunction(func);
@@ -82,8 +81,7 @@
}
TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Vertex) {
- auto* func =
- b.CreateFunction("main", mod.Types().void_(), ir::Function::PipelineStage::kVertex);
+ auto* func = b.CreateFunction("main", ty.void_(), ir::Function::PipelineStage::kVertex);
func->StartTarget()->SetInstructions(utils::Vector{b.Return(func)});
generator_.EmitFunction(func);
@@ -99,16 +97,15 @@
}
TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Multiple) {
- auto* f1 = b.CreateFunction("main1", mod.Types().void_(), ir::Function::PipelineStage::kCompute,
- {{32, 4, 1}});
+ auto* f1 =
+ b.CreateFunction("main1", ty.void_(), ir::Function::PipelineStage::kCompute, {{32, 4, 1}});
f1->StartTarget()->SetInstructions(utils::Vector{b.Return(f1)});
- auto* f2 = b.CreateFunction("main2", mod.Types().void_(), ir::Function::PipelineStage::kCompute,
- {{8, 2, 16}});
+ auto* f2 =
+ b.CreateFunction("main2", ty.void_(), ir::Function::PipelineStage::kCompute, {{8, 2, 16}});
f2->StartTarget()->SetInstructions(utils::Vector{b.Return(f2)});
- auto* f3 =
- b.CreateFunction("main3", mod.Types().void_(), ir::Function::PipelineStage::kFragment);
+ auto* f3 = b.CreateFunction("main3", ty.void_(), ir::Function::PipelineStage::kFragment);
f3->StartTarget()->SetInstructions(utils::Vector{b.Return(f3)});
generator_.EmitFunction(f1);
@@ -141,7 +138,7 @@
}
TEST_F(SpvGeneratorImplTest, Function_ReturnValue) {
- auto* func = b.CreateFunction("foo", mod.Types().i32());
+ auto* func = b.CreateFunction("foo", ty.i32());
func->StartTarget()->SetInstructions(
utils::Vector{b.Return(func, utils::Vector{b.Constant(i32(42))})});
@@ -158,7 +155,7 @@
}
TEST_F(SpvGeneratorImplTest, Function_Parameters) {
- auto* i32 = mod.Types().i32();
+ auto* i32 = ty.i32();
auto* x = b.FunctionParam(i32);
auto* y = b.FunctionParam(i32);
auto* result = b.Add(i32, x, y);
@@ -186,7 +183,7 @@
}
TEST_F(SpvGeneratorImplTest, Function_Call) {
- auto* i32_ty = mod.Types().i32();
+ auto* i32_ty = ty.i32();
auto* x = b.FunctionParam(i32_ty);
auto* y = b.FunctionParam(i32_ty);
auto* result = b.Add(i32_ty, x, y);
@@ -195,7 +192,7 @@
foo->StartTarget()->SetInstructions(
utils::Vector{result, b.Return(foo, utils::Vector{result})});
- auto* bar = b.CreateFunction("bar", mod.Types().void_());
+ auto* bar = b.CreateFunction("bar", ty.void_());
bar->StartTarget()->SetInstructions(utils::Vector{
b.UserCall(i32_ty, foo, utils::Vector{b.Constant(i32(2)), b.Constant(i32(3))}),
b.Return(bar)});
@@ -226,12 +223,12 @@
}
TEST_F(SpvGeneratorImplTest, Function_Call_Void) {
- auto* foo = b.CreateFunction("foo", mod.Types().void_());
+ auto* foo = b.CreateFunction("foo", ty.void_());
foo->StartTarget()->SetInstructions(utils::Vector{b.Return(foo)});
- auto* bar = b.CreateFunction("bar", mod.Types().void_());
+ auto* bar = b.CreateFunction("bar", ty.void_());
bar->StartTarget()->SetInstructions(
- utils::Vector{b.UserCall(mod.Types().void_(), foo, utils::Empty), b.Return(bar)});
+ utils::Vector{b.UserCall(ty.void_(), foo, utils::Empty), b.Return(bar)});
generator_.EmitFunction(foo);
generator_.EmitFunction(bar);
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_if_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_if_test.cc
index 8d88852..792b735 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_if_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_if_test.cc
@@ -20,7 +20,7 @@
namespace {
TEST_F(SpvGeneratorImplTest, If_TrueEmpty_FalseEmpty) {
- auto* func = b.CreateFunction("foo", mod.Types().void_());
+ auto* func = b.CreateFunction("foo", ty.void_());
auto* i = b.CreateIf(b.Constant(true));
i->True()->SetInstructions(utils::Vector{b.ExitIf(i)});
@@ -46,7 +46,7 @@
}
TEST_F(SpvGeneratorImplTest, If_FalseEmpty) {
- auto* func = b.CreateFunction("foo", mod.Types().void_());
+ auto* func = b.CreateFunction("foo", ty.void_());
auto* i = b.CreateIf(b.Constant(true));
i->False()->SetInstructions(utils::Vector{b.ExitIf(i)});
@@ -54,7 +54,7 @@
auto* true_block = i->True();
true_block->SetInstructions(
- utils::Vector{b.Add(mod.Types().i32(), b.Constant(1_i), b.Constant(1_i)), b.ExitIf(i)});
+ utils::Vector{b.Add(ty.i32(), b.Constant(1_i), b.Constant(1_i)), b.ExitIf(i)});
func->StartTarget()->SetInstructions(utils::Vector{i});
@@ -80,7 +80,7 @@
}
TEST_F(SpvGeneratorImplTest, If_TrueEmpty) {
- auto* func = b.CreateFunction("foo", mod.Types().void_());
+ auto* func = b.CreateFunction("foo", ty.void_());
auto* i = b.CreateIf(b.Constant(true));
i->True()->SetInstructions(utils::Vector{b.ExitIf(i)});
@@ -88,7 +88,7 @@
auto* false_block = i->False();
false_block->SetInstructions(
- utils::Vector{b.Add(mod.Types().i32(), b.Constant(1_i), b.Constant(1_i)), b.ExitIf(i)});
+ utils::Vector{b.Add(ty.i32(), b.Constant(1_i), b.Constant(1_i)), b.ExitIf(i)});
func->StartTarget()->SetInstructions(utils::Vector{i});
@@ -114,7 +114,7 @@
}
TEST_F(SpvGeneratorImplTest, If_BothBranchesReturn) {
- auto* func = b.CreateFunction("foo", mod.Types().void_());
+ auto* func = b.CreateFunction("foo", ty.void_());
auto* i = b.CreateIf(b.Constant(true));
i->True()->SetInstructions(utils::Vector{b.Return(func)});
@@ -143,7 +143,7 @@
}
TEST_F(SpvGeneratorImplTest, If_Phi_SingleValue) {
- auto* func = b.CreateFunction("foo", mod.Types().void_());
+ auto* func = b.CreateFunction("foo", ty.void_());
auto* merge_param = b.BlockParam(b.ir.Types().i32());
@@ -180,7 +180,7 @@
}
TEST_F(SpvGeneratorImplTest, If_Phi_SingleValue_TrueReturn) {
- auto* func = b.CreateFunction("foo", mod.Types().void_());
+ auto* func = b.CreateFunction("foo", ty.void_());
auto* merge_param = b.BlockParam(b.ir.Types().i32());
@@ -217,7 +217,7 @@
}
TEST_F(SpvGeneratorImplTest, If_Phi_SingleValue_FalseReturn) {
- auto* func = b.CreateFunction("foo", mod.Types().void_());
+ auto* func = b.CreateFunction("foo", ty.void_());
auto* merge_param = b.BlockParam(b.ir.Types().i32());
@@ -254,7 +254,7 @@
}
TEST_F(SpvGeneratorImplTest, If_Phi_MultipleValue) {
- auto* func = b.CreateFunction("foo", mod.Types().void_());
+ auto* func = b.CreateFunction("foo", ty.void_());
auto* merge_param_0 = b.BlockParam(b.ir.Types().i32());
auto* merge_param_1 = b.BlockParam(b.ir.Types().bool_());
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_loop_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_loop_test.cc
index 577d9d2..cb43f9d 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_loop_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_loop_test.cc
@@ -20,7 +20,7 @@
namespace {
TEST_F(SpvGeneratorImplTest, Loop_BreakIf) {
- auto* func = b.CreateFunction("foo", mod.Types().void_());
+ auto* func = b.CreateFunction("foo", ty.void_());
auto* loop = b.CreateLoop();
@@ -54,7 +54,7 @@
// Test that we still emit the continuing block with a back-edge, even when it is unreachable.
TEST_F(SpvGeneratorImplTest, Loop_UnconditionalBreakInBody) {
- auto* func = b.CreateFunction("foo", mod.Types().void_());
+ auto* func = b.CreateFunction("foo", ty.void_());
auto* loop = b.CreateLoop();
@@ -84,7 +84,7 @@
}
TEST_F(SpvGeneratorImplTest, Loop_ConditionalBreakInBody) {
- auto* func = b.CreateFunction("foo", mod.Types().void_());
+ auto* func = b.CreateFunction("foo", ty.void_());
auto* loop = b.CreateLoop();
@@ -127,7 +127,7 @@
}
TEST_F(SpvGeneratorImplTest, Loop_ConditionalContinueInBody) {
- auto* func = b.CreateFunction("foo", mod.Types().void_());
+ auto* func = b.CreateFunction("foo", ty.void_());
auto* loop = b.CreateLoop();
@@ -172,7 +172,7 @@
// Test that we still emit the continuing block with a back-edge, and the merge block, even when
// they are unreachable.
TEST_F(SpvGeneratorImplTest, Loop_UnconditionalReturnInBody) {
- auto* func = b.CreateFunction("foo", mod.Types().void_());
+ auto* func = b.CreateFunction("foo", ty.void_());
auto* loop = b.CreateLoop();
@@ -201,11 +201,11 @@
}
TEST_F(SpvGeneratorImplTest, Loop_UseResultFromBodyInContinuing) {
- auto* func = b.CreateFunction("foo", mod.Types().void_());
+ auto* func = b.CreateFunction("foo", ty.void_());
auto* loop = b.CreateLoop();
- auto* result = b.Equal(mod.Types().i32(), b.Constant(1_i), b.Constant(2_i));
+ auto* result = b.Equal(ty.i32(), b.Constant(1_i), b.Constant(2_i));
loop->Body()->Append(result);
loop->Continuing()->Append(b.BreakIf(result, loop));
@@ -237,7 +237,7 @@
}
TEST_F(SpvGeneratorImplTest, Loop_NestedLoopInBody) {
- auto* func = b.CreateFunction("foo", mod.Types().void_());
+ auto* func = b.CreateFunction("foo", ty.void_());
auto* outer_loop = b.CreateLoop();
auto* inner_loop = b.CreateLoop();
@@ -284,7 +284,7 @@
}
TEST_F(SpvGeneratorImplTest, Loop_NestedLoopInContinuing) {
- auto* func = b.CreateFunction("foo", mod.Types().void_());
+ auto* func = b.CreateFunction("foo", ty.void_());
auto* outer_loop = b.CreateLoop();
auto* inner_loop = b.CreateLoop();
@@ -331,7 +331,7 @@
}
TEST_F(SpvGeneratorImplTest, Loop_Phi_SingleValue) {
- auto* func = b.CreateFunction("foo", mod.Types().void_());
+ auto* func = b.CreateFunction("foo", ty.void_());
auto* l = b.CreateLoop(utils::Vector{b.Constant(1_i)});
func->StartTarget()->Append(l);
@@ -377,7 +377,7 @@
}
TEST_F(SpvGeneratorImplTest, Loop_Phi_MultipleValue) {
- auto* func = b.CreateFunction("foo", mod.Types().void_());
+ auto* func = b.CreateFunction("foo", ty.void_());
auto* l = b.CreateLoop(utils::Vector{b.Constant(1_i), b.Constant(false)});
func->StartTarget()->Append(l);
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_switch_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_switch_test.cc
index 2773ae0..dca1dac 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_switch_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_switch_test.cc
@@ -20,7 +20,7 @@
namespace {
TEST_F(SpvGeneratorImplTest, Switch_Basic) {
- auto* func = b.CreateFunction("foo", mod.Types().void_());
+ auto* func = b.CreateFunction("foo", ty.void_());
auto* swtch = b.CreateSwitch(b.Constant(42_i));
@@ -50,7 +50,7 @@
}
TEST_F(SpvGeneratorImplTest, Switch_MultipleCases) {
- auto* func = b.CreateFunction("foo", mod.Types().void_());
+ auto* func = b.CreateFunction("foo", ty.void_());
auto* swtch = b.CreateSwitch(b.Constant(42_i));
@@ -90,7 +90,7 @@
}
TEST_F(SpvGeneratorImplTest, Switch_MultipleSelectorsPerCase) {
- auto* func = b.CreateFunction("foo", mod.Types().void_());
+ auto* func = b.CreateFunction("foo", ty.void_());
auto* swtch = b.CreateSwitch(b.Constant(42_i));
@@ -133,7 +133,7 @@
}
TEST_F(SpvGeneratorImplTest, Switch_AllCasesReturn) {
- auto* func = b.CreateFunction("foo", mod.Types().void_());
+ auto* func = b.CreateFunction("foo", ty.void_());
auto* swtch = b.CreateSwitch(b.Constant(42_i));
@@ -171,7 +171,7 @@
}
TEST_F(SpvGeneratorImplTest, Switch_ConditionalBreak) {
- auto* func = b.CreateFunction("foo", mod.Types().void_());
+ auto* func = b.CreateFunction("foo", ty.void_());
auto* swtch = b.CreateSwitch(b.Constant(42_i));
@@ -218,7 +218,7 @@
}
TEST_F(SpvGeneratorImplTest, Switch_Phi_SingleValue) {
- auto* func = b.CreateFunction("foo", mod.Types().void_());
+ auto* func = b.CreateFunction("foo", ty.void_());
auto* merge_param = b.BlockParam(b.ir.Types().i32());
@@ -259,7 +259,7 @@
}
TEST_F(SpvGeneratorImplTest, Switch_Phi_SingleValue_CaseReturn) {
- auto* func = b.CreateFunction("foo", mod.Types().void_());
+ auto* func = b.CreateFunction("foo", ty.void_());
auto* s = b.CreateSwitch(b.Constant(42_i));
auto* case_a = b.CreateCase(s, utils::Vector{ir::Switch::CaseSelector{b.Constant(1_i)},
@@ -298,7 +298,7 @@
}
TEST_F(SpvGeneratorImplTest, Switch_Phi_MultipleValue) {
- auto* func = b.CreateFunction("foo", mod.Types().void_());
+ auto* func = b.CreateFunction("foo", ty.void_());
auto* merge_param_0 = b.BlockParam(b.ir.Types().i32());
auto* merge_param_1 = b.BlockParam(b.ir.Types().bool_());
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_type_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_type_test.cc
index d6c654f..5deefe0 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_type_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_type_test.cc
@@ -25,44 +25,43 @@
namespace {
TEST_F(SpvGeneratorImplTest, Type_Void) {
- auto id = generator_.Type(mod.Types().void_());
+ auto id = generator_.Type(ty.void_());
EXPECT_EQ(id, 1u);
EXPECT_EQ(DumpTypes(), "%1 = OpTypeVoid\n");
}
TEST_F(SpvGeneratorImplTest, Type_Bool) {
- auto id = generator_.Type(mod.Types().bool_());
+ auto id = generator_.Type(ty.bool_());
EXPECT_EQ(id, 1u);
EXPECT_EQ(DumpTypes(), "%1 = OpTypeBool\n");
}
TEST_F(SpvGeneratorImplTest, Type_I32) {
- auto id = generator_.Type(mod.Types().i32());
+ auto id = generator_.Type(ty.i32());
EXPECT_EQ(id, 1u);
EXPECT_EQ(DumpTypes(), "%1 = OpTypeInt 32 1\n");
}
TEST_F(SpvGeneratorImplTest, Type_U32) {
- auto id = generator_.Type(mod.Types().u32());
+ auto id = generator_.Type(ty.u32());
EXPECT_EQ(id, 1u);
EXPECT_EQ(DumpTypes(), "%1 = OpTypeInt 32 0\n");
}
TEST_F(SpvGeneratorImplTest, Type_F32) {
- auto id = generator_.Type(mod.Types().f32());
+ auto id = generator_.Type(ty.f32());
EXPECT_EQ(id, 1u);
EXPECT_EQ(DumpTypes(), "%1 = OpTypeFloat 32\n");
}
TEST_F(SpvGeneratorImplTest, Type_F16) {
- auto id = generator_.Type(mod.Types().f16());
+ auto id = generator_.Type(ty.f16());
EXPECT_EQ(id, 1u);
EXPECT_EQ(DumpTypes(), "%1 = OpTypeFloat 16\n");
}
TEST_F(SpvGeneratorImplTest, Type_Vec2i) {
- auto* vec = mod.Types().Get<type::Vector>(mod.Types().i32(), 2u);
- auto id = generator_.Type(vec);
+ auto id = generator_.Type(ty.vec2(ty.i32()));
EXPECT_EQ(id, 1u);
EXPECT_EQ(DumpTypes(),
"%2 = OpTypeInt 32 1\n"
@@ -70,8 +69,7 @@
}
TEST_F(SpvGeneratorImplTest, Type_Vec3u) {
- auto* vec = mod.Types().Get<type::Vector>(mod.Types().u32(), 3u);
- auto id = generator_.Type(vec);
+ auto id = generator_.Type(ty.vec3(ty.u32()));
EXPECT_EQ(id, 1u);
EXPECT_EQ(DumpTypes(),
"%2 = OpTypeInt 32 0\n"
@@ -79,17 +77,15 @@
}
TEST_F(SpvGeneratorImplTest, Type_Vec4f) {
- auto* vec = mod.Types().Get<type::Vector>(mod.Types().f32(), 4u);
- auto id = generator_.Type(vec);
+ auto id = generator_.Type(ty.vec4(ty.f32()));
EXPECT_EQ(id, 1u);
EXPECT_EQ(DumpTypes(),
"%2 = OpTypeFloat 32\n"
"%1 = OpTypeVector %2 4\n");
}
-TEST_F(SpvGeneratorImplTest, Type_Vec4h) {
- auto* vec = mod.Types().Get<type::Vector>(mod.Types().f16(), 2u);
- auto id = generator_.Type(vec);
+TEST_F(SpvGeneratorImplTest, Type_Vec2h) {
+ auto id = generator_.Type(ty.vec2(ty.f16()));
EXPECT_EQ(id, 1u);
EXPECT_EQ(DumpTypes(),
"%2 = OpTypeFloat 16\n"
@@ -97,8 +93,7 @@
}
TEST_F(SpvGeneratorImplTest, Type_Vec4Bool) {
- auto* vec = mod.Types().Get<type::Vector>(mod.Types().bool_(), 4u);
- auto id = generator_.Type(vec);
+ auto id = generator_.Type(ty.vec4(ty.bool_()));
EXPECT_EQ(id, 1u);
EXPECT_EQ(DumpTypes(),
"%2 = OpTypeBool\n"
@@ -106,7 +101,7 @@
}
TEST_F(SpvGeneratorImplTest, Type_Mat2x3f) {
- auto* vec = mod.Types().mat2x3(mod.Types().f32());
+ auto* vec = ty.mat2x3(ty.f32());
auto id = generator_.Type(vec);
EXPECT_EQ(id, 1u);
EXPECT_EQ(DumpTypes(),
@@ -116,7 +111,7 @@
}
TEST_F(SpvGeneratorImplTest, Type_Mat4x2h) {
- auto* vec = mod.Types().mat4x2(mod.Types().f16());
+ auto* vec = ty.mat4x2(ty.f16());
auto id = generator_.Type(vec);
EXPECT_EQ(id, 1u);
EXPECT_EQ(DumpTypes(),
@@ -126,7 +121,7 @@
}
TEST_F(SpvGeneratorImplTest, Type_Array_DefaultStride) {
- auto* arr = mod.Types().array(mod.Types().f32(), 4u);
+ auto* arr = ty.array(ty.f32(), 4u);
auto id = generator_.Type(arr);
EXPECT_EQ(id, 1u);
EXPECT_EQ(DumpTypes(),
@@ -138,7 +133,7 @@
}
TEST_F(SpvGeneratorImplTest, Type_Array_ExplicitStride) {
- auto* arr = mod.Types().array(mod.Types().f32(), 4u, 16);
+ auto* arr = ty.array(ty.f32(), 4u, 16);
auto id = generator_.Type(arr);
EXPECT_EQ(id, 1u);
EXPECT_EQ(DumpTypes(),
@@ -150,7 +145,7 @@
}
TEST_F(SpvGeneratorImplTest, Type_Array_NestedArray) {
- auto* arr = mod.Types().array(mod.Types().array(mod.Types().f32(), 64u), 4u);
+ auto* arr = ty.array(ty.array(ty.f32(), 64u), 4u);
auto id = generator_.Type(arr);
EXPECT_EQ(id, 1u);
EXPECT_EQ(DumpTypes(),
@@ -166,7 +161,7 @@
}
TEST_F(SpvGeneratorImplTest, Type_RuntimeArray_DefaultStride) {
- auto* arr = mod.Types().runtime_array(mod.Types().f32());
+ auto* arr = ty.runtime_array(ty.f32());
auto id = generator_.Type(arr);
EXPECT_EQ(id, 1u);
EXPECT_EQ(DumpTypes(),
@@ -176,7 +171,7 @@
}
TEST_F(SpvGeneratorImplTest, Type_RuntimeArray_ExplicitStride) {
- auto* arr = mod.Types().runtime_array(mod.Types().f32(), 16);
+ auto* arr = ty.runtime_array(ty.f32(), 16);
auto id = generator_.Type(arr);
EXPECT_EQ(id, 1u);
EXPECT_EQ(DumpTypes(),
@@ -186,14 +181,13 @@
}
TEST_F(SpvGeneratorImplTest, Type_Struct) {
- auto* str = mod.Types().Get<type::Struct>(
+ auto* str = ty.Get<type::Struct>(
mod.symbols.Register("MyStruct"),
utils::Vector{
- mod.Types().Get<type::StructMember>(mod.symbols.Register("a"), mod.Types().f32(), 0u,
- 0u, 4u, 4u, type::StructMemberAttributes{}),
- mod.Types().Get<type::StructMember>(mod.symbols.Register("b"),
- mod.Types().vec4(mod.Types().i32()), 1u, 16u, 16u,
- 16u, type::StructMemberAttributes{}),
+ ty.Get<type::StructMember>(mod.symbols.Register("a"), ty.f32(), 0u, 0u, 4u, 4u,
+ type::StructMemberAttributes{}),
+ ty.Get<type::StructMember>(mod.symbols.Register("b"), ty.vec4(ty.i32()), 1u, 16u, 16u,
+ 16u, type::StructMemberAttributes{}),
},
16u, 32u, 32u);
auto id = generator_.Type(str);
@@ -213,17 +207,15 @@
}
TEST_F(SpvGeneratorImplTest, Type_Struct_MatrixLayout) {
- auto* str = mod.Types().Get<type::Struct>(
+ auto* str = ty.Get<type::Struct>(
mod.symbols.Register("MyStruct"),
utils::Vector{
- mod.Types().Get<type::StructMember>(mod.symbols.Register("m"),
- mod.Types().mat3x3(mod.Types().f32()), 0u, 0u, 16u,
- 48u, type::StructMemberAttributes{}),
+ ty.Get<type::StructMember>(mod.symbols.Register("m"), ty.mat3x3(ty.f32()), 0u, 0u, 16u,
+ 48u, type::StructMemberAttributes{}),
// Matrices nested inside arrays need layout decorations on the struct member too.
- mod.Types().Get<type::StructMember>(
- mod.symbols.Register("arr"),
- mod.Types().array(mod.Types().array(mod.Types().mat2x4(mod.Types().f16()), 4), 4),
- 1u, 64u, 8u, 64u, type::StructMemberAttributes{}),
+ ty.Get<type::StructMember>(mod.symbols.Register("arr"),
+ ty.array(ty.array(ty.mat2x4(ty.f16()), 4), 4), 1u, 64u, 8u,
+ 64u, type::StructMemberAttributes{}),
},
16u, 128u, 128u);
auto id = generator_.Type(str);
@@ -258,10 +250,10 @@
// Test that we can emit multiple types.
// Includes types with the same opcode but different parameters.
TEST_F(SpvGeneratorImplTest, Type_Multiple) {
- EXPECT_EQ(generator_.Type(mod.Types().i32()), 1u);
- EXPECT_EQ(generator_.Type(mod.Types().u32()), 2u);
- EXPECT_EQ(generator_.Type(mod.Types().f32()), 3u);
- EXPECT_EQ(generator_.Type(mod.Types().f16()), 4u);
+ EXPECT_EQ(generator_.Type(ty.i32()), 1u);
+ EXPECT_EQ(generator_.Type(ty.u32()), 2u);
+ EXPECT_EQ(generator_.Type(ty.f32()), 3u);
+ EXPECT_EQ(generator_.Type(ty.f16()), 4u);
EXPECT_EQ(DumpTypes(), R"(%1 = OpTypeInt 32 1
%2 = OpTypeInt 32 0
%3 = OpTypeFloat 32
@@ -271,7 +263,7 @@
// Test that we do not emit the same type more than once.
TEST_F(SpvGeneratorImplTest, Type_Deduplicate) {
- auto* i32 = mod.Types().i32();
+ auto* i32 = ty.i32();
EXPECT_EQ(generator_.Type(i32), 1u);
EXPECT_EQ(generator_.Type(i32), 1u);
EXPECT_EQ(generator_.Type(i32), 1u);
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc
index 8ad2c10..516e504 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc
@@ -21,11 +21,12 @@
namespace {
TEST_F(SpvGeneratorImplTest, FunctionVar_NoInit) {
- auto* func = b.CreateFunction("foo", mod.Types().void_());
+ auto* func = b.CreateFunction("foo", ty.void_());
- auto* ty = mod.Types().Get<type::Pointer>(mod.Types().i32(), builtin::AddressSpace::kFunction,
- builtin::Access::kReadWrite);
- func->StartTarget()->SetInstructions(utils::Vector{b.Declare(ty), b.Return(func)});
+ func->StartTarget()->SetInstructions(
+ utils::Vector{b.Declare(ty.pointer(ty.i32(), builtin::AddressSpace::kFunction,
+ builtin::Access::kReadWrite)),
+ b.Return(func)});
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
@@ -42,11 +43,10 @@
}
TEST_F(SpvGeneratorImplTest, FunctionVar_WithInit) {
- auto* func = b.CreateFunction("foo", mod.Types().void_());
+ auto* func = b.CreateFunction("foo", ty.void_());
- auto* ty = mod.Types().Get<type::Pointer>(mod.Types().i32(), builtin::AddressSpace::kFunction,
- builtin::Access::kReadWrite);
- auto* v = b.Declare(ty);
+ auto* v = b.Declare(
+ ty.pointer(ty.i32(), builtin::AddressSpace::kFunction, builtin::Access::kReadWrite));
v->SetInitializer(b.Constant(42_i));
func->StartTarget()->SetInstructions(utils::Vector{v, b.Return(func)});
@@ -68,11 +68,10 @@
}
TEST_F(SpvGeneratorImplTest, FunctionVar_Name) {
- auto* func = b.CreateFunction("foo", mod.Types().void_());
+ auto* func = b.CreateFunction("foo", ty.void_());
- auto* ty = mod.Types().Get<type::Pointer>(mod.Types().i32(), builtin::AddressSpace::kFunction,
- builtin::Access::kReadWrite);
- auto* v = b.Declare(ty);
+ auto* v = b.Declare(
+ ty.pointer(ty.i32(), builtin::AddressSpace::kFunction, builtin::Access::kReadWrite));
func->StartTarget()->SetInstructions(utils::Vector{v, b.Return(func)});
mod.SetName(v, "myvar");
@@ -92,11 +91,10 @@
}
TEST_F(SpvGeneratorImplTest, FunctionVar_DeclInsideBlock) {
- auto* func = b.CreateFunction("foo", mod.Types().void_());
+ auto* func = b.CreateFunction("foo", ty.void_());
- auto* ty = mod.Types().Get<type::Pointer>(mod.Types().i32(), builtin::AddressSpace::kFunction,
- builtin::Access::kReadWrite);
- auto* v = b.Declare(ty);
+ auto* v = b.Declare(
+ ty.pointer(ty.i32(), builtin::AddressSpace::kFunction, builtin::Access::kReadWrite));
v->SetInitializer(b.Constant(42_i));
auto* i = b.CreateIf(b.Constant(true));
@@ -132,12 +130,11 @@
}
TEST_F(SpvGeneratorImplTest, FunctionVar_Load) {
- auto* func = b.CreateFunction("foo", mod.Types().void_());
+ auto* func = b.CreateFunction("foo", ty.void_());
- auto* store_ty = mod.Types().i32();
- auto* ty = mod.Types().Get<type::Pointer>(store_ty, builtin::AddressSpace::kFunction,
- builtin::Access::kReadWrite);
- auto* v = b.Declare(ty);
+ auto* store_ty = ty.i32();
+ auto* v = b.Declare(
+ ty.pointer(store_ty, builtin::AddressSpace::kFunction, builtin::Access::kReadWrite));
func->StartTarget()->SetInstructions(utils::Vector{v, b.Load(v), b.Return(func)});
generator_.EmitFunction(func);
@@ -156,11 +153,10 @@
}
TEST_F(SpvGeneratorImplTest, FunctionVar_Store) {
- auto* func = b.CreateFunction("foo", mod.Types().void_());
+ auto* func = b.CreateFunction("foo", ty.void_());
- auto* ty = mod.Types().Get<type::Pointer>(mod.Types().i32(), builtin::AddressSpace::kFunction,
- builtin::Access::kReadWrite);
- auto* v = b.Declare(ty);
+ auto* v = b.Declare(
+ ty.pointer(ty.i32(), builtin::AddressSpace::kFunction, builtin::Access::kReadWrite));
func->StartTarget()->SetInstructions(
utils::Vector{v, b.Store(v, b.Constant(42_i)), b.Return(func)});
@@ -181,9 +177,8 @@
}
TEST_F(SpvGeneratorImplTest, PrivateVar_NoInit) {
- auto* ty = mod.Types().Get<type::Pointer>(mod.Types().i32(), builtin::AddressSpace::kPrivate,
- builtin::Access::kReadWrite);
- b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{b.Declare(ty)});
+ b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{b.Declare(
+ ty.pointer(ty.i32(), builtin::AddressSpace::kPrivate, builtin::Access::kReadWrite))});
generator_.Generate();
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
@@ -204,9 +199,8 @@
}
TEST_F(SpvGeneratorImplTest, PrivateVar_WithInit) {
- auto* ty = mod.Types().Get<type::Pointer>(mod.Types().i32(), builtin::AddressSpace::kPrivate,
- builtin::Access::kReadWrite);
- auto* v = b.Declare(ty);
+ auto* v = b.Declare(
+ ty.pointer(ty.i32(), builtin::AddressSpace::kPrivate, builtin::Access::kReadWrite));
b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{v});
v->SetInitializer(b.Constant(42_i));
@@ -230,9 +224,8 @@
}
TEST_F(SpvGeneratorImplTest, PrivateVar_Name) {
- auto* ty = mod.Types().Get<type::Pointer>(mod.Types().i32(), builtin::AddressSpace::kPrivate,
- builtin::Access::kReadWrite);
- auto* v = b.Declare(ty);
+ auto* v = b.Declare(
+ ty.pointer(ty.i32(), builtin::AddressSpace::kPrivate, builtin::Access::kReadWrite));
b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{v});
v->SetInitializer(b.Constant(42_i));
mod.SetName(v, "myvar");
@@ -258,14 +251,12 @@
}
TEST_F(SpvGeneratorImplTest, PrivateVar_LoadAndStore) {
- auto* func =
- b.CreateFunction("foo", mod.Types().void_(), ir::Function::PipelineStage::kFragment);
+ auto* func = b.CreateFunction("foo", ty.void_(), ir::Function::PipelineStage::kFragment);
mod.functions.Push(func);
- auto* store_ty = mod.Types().i32();
- auto* ty = mod.Types().Get<type::Pointer>(store_ty, builtin::AddressSpace::kPrivate,
- builtin::Access::kReadWrite);
- auto* v = b.Declare(ty);
+ auto* store_ty = ty.i32();
+ auto* v = b.Declare(
+ ty.pointer(store_ty, builtin::AddressSpace::kPrivate, builtin::Access::kReadWrite));
b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{v});
v->SetInitializer(b.Constant(42_i));
@@ -298,9 +289,8 @@
}
TEST_F(SpvGeneratorImplTest, WorkgroupVar) {
- auto* ty = mod.Types().Get<type::Pointer>(mod.Types().i32(), builtin::AddressSpace::kWorkgroup,
- builtin::Access::kReadWrite);
- b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{b.Declare(ty)});
+ b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{b.Declare(
+ ty.pointer(ty.i32(), builtin::AddressSpace::kWorkgroup, builtin::Access::kReadWrite))});
generator_.Generate();
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
@@ -321,9 +311,8 @@
}
TEST_F(SpvGeneratorImplTest, WorkgroupVar_Name) {
- auto* ty = mod.Types().Get<type::Pointer>(mod.Types().i32(), builtin::AddressSpace::kWorkgroup,
- builtin::Access::kReadWrite);
- auto* v = b.Declare(ty);
+ auto* v = b.Declare(
+ ty.pointer(ty.i32(), builtin::AddressSpace::kWorkgroup, builtin::Access::kReadWrite));
b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{v});
mod.SetName(v, "myvar");
@@ -347,14 +336,13 @@
}
TEST_F(SpvGeneratorImplTest, WorkgroupVar_LoadAndStore) {
- auto* func = b.CreateFunction("foo", mod.Types().void_(), ir::Function::PipelineStage::kCompute,
+ auto* func = b.CreateFunction("foo", ty.void_(), ir::Function::PipelineStage::kCompute,
std::array{1u, 1u, 1u});
mod.functions.Push(func);
- auto* store_ty = mod.Types().i32();
- auto* ty = mod.Types().Get<type::Pointer>(store_ty, builtin::AddressSpace::kWorkgroup,
- builtin::Access::kReadWrite);
- auto* v = b.Declare(ty);
+ auto* store_ty = ty.i32();
+ auto* v = b.Declare(
+ ty.pointer(store_ty, builtin::AddressSpace::kWorkgroup, builtin::Access::kReadWrite));
b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{v});
auto* load = b.Load(v);
@@ -385,9 +373,8 @@
}
TEST_F(SpvGeneratorImplTest, WorkgroupVar_ZeroInitializeWithExtension) {
- auto* ty = mod.Types().Get<type::Pointer>(mod.Types().i32(), builtin::AddressSpace::kWorkgroup,
- builtin::Access::kReadWrite);
- b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{b.Declare(ty)});
+ b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{b.Declare(
+ ty.pointer(ty.i32(), builtin::AddressSpace::kWorkgroup, builtin::Access::kReadWrite))});
// Create a generator with the zero_init_workgroup_memory flag set to `true`.
spirv::GeneratorImplIr gen(&mod, true);
diff --git a/src/tint/writer/spirv/ir/test_helper_ir.h b/src/tint/writer/spirv/ir/test_helper_ir.h
index add9eaf..791dfca 100644
--- a/src/tint/writer/spirv/ir/test_helper_ir.h
+++ b/src/tint/writer/spirv/ir/test_helper_ir.h
@@ -39,10 +39,12 @@
public:
SpvGeneratorTestHelperBase() : generator_(&mod, false) {}
- /// The test module
+ /// The test module.
ir::Module mod;
- /// The test builder
+ /// The test builder.
ir::Builder b{mod};
+ /// The type manager.
+ type::Manager& ty{mod.Types()};
protected:
/// The SPIR-V generator.
@@ -51,37 +53,35 @@
/// @returns the disassembled types from the generated module.
std::string DumpTypes() { return DumpInstructions(generator_.Module().Types()); }
- /// Helper to make a scalar type corresponding to the element type `ty`.
- /// @param ty the element type
+ /// Helper to make a scalar type corresponding to the element type `type`.
+ /// @param type the element type
/// @returns the scalar type
- const type::Type* MakeScalarType(TestElementType ty) {
- switch (ty) {
+ const type::Type* MakeScalarType(TestElementType type) {
+ switch (type) {
case kBool:
- return mod.Types().bool_();
+ return ty.bool_();
case kI32:
- return mod.Types().i32();
+ return ty.i32();
case kU32:
- return mod.Types().u32();
+ return ty.u32();
case kF32:
- return mod.Types().f32();
+ return ty.f32();
case kF16:
- return mod.Types().f16();
+ return ty.f16();
}
return nullptr;
}
- /// Helper to make a vector type corresponding to the element type `ty`.
- /// @param ty the element type
+ /// Helper to make a vector type corresponding to the element type `type`.
+ /// @param type the element type
/// @returns the vector type
- const type::Type* MakeVectorType(TestElementType ty) {
- return mod.Types().vec2(MakeScalarType(ty));
- }
+ const type::Type* MakeVectorType(TestElementType type) { return ty.vec2(MakeScalarType(type)); }
- /// Helper to make a scalar value with the scalar type `ty`.
- /// @param ty the element type
+ /// Helper to make a scalar value with the scalar type `type`.
+ /// @param type the element type
/// @returns the scalar value
- ir::Value* MakeScalarValue(TestElementType ty) {
- switch (ty) {
+ ir::Value* MakeScalarValue(TestElementType type) {
+ switch (type) {
case kBool:
return b.Constant(true);
case kI32:
@@ -96,34 +96,34 @@
return nullptr;
}
- /// Helper to make a vector value with an element type of `ty`.
- /// @param ty the element type
+ /// Helper to make a vector value with an element type of `type`.
+ /// @param type the element type
/// @returns the vector value
- ir::Value* MakeVectorValue(TestElementType ty) {
- switch (ty) {
+ ir::Value* MakeVectorValue(TestElementType type) {
+ switch (type) {
case kBool:
return b.Constant(mod.constant_values.Composite(
- MakeVectorType(ty),
+ MakeVectorType(type),
utils::Vector<const constant::Value*, 2>{mod.constant_values.Get(true),
mod.constant_values.Get(false)}));
case kI32:
return b.Constant(mod.constant_values.Composite(
- MakeVectorType(ty),
+ MakeVectorType(type),
utils::Vector<const constant::Value*, 2>{mod.constant_values.Get(i32(42)),
mod.constant_values.Get(i32(-10))}));
case kU32:
return b.Constant(mod.constant_values.Composite(
- MakeVectorType(ty),
+ MakeVectorType(type),
utils::Vector<const constant::Value*, 2>{mod.constant_values.Get(u32(42)),
mod.constant_values.Get(u32(10))}));
case kF32:
return b.Constant(mod.constant_values.Composite(
- MakeVectorType(ty),
+ MakeVectorType(type),
utils::Vector<const constant::Value*, 2>{mod.constant_values.Get(f32(42)),
mod.constant_values.Get(f32(-0.5))}));
case kF16:
return b.Constant(mod.constant_values.Composite(
- MakeVectorType(ty),
+ MakeVectorType(type),
utils::Vector<const constant::Value*, 2>{mod.constant_values.Get(f16(42)),
mod.constant_values.Get(f16(-0.5))}));
}