Import Tint changes from Dawn
Changes:
- 97ab6a3a706121a6a3c046c13361e9964d53feaa [ir][spirv-writer] Emit store instructions by James Price <jrprice@google.com>
- 75aaa49d772bf99f9f6c3cc662422fe74626ff93 [ir][spirv-writer] Emit function variables by James Price <jrprice@google.com>
- 6e40b1a9df6215573aab4dfe0a814aebb53a554e [tint][ir][ToProgram] Begin emitting Switch statements by Ben Clayton <bclayton@google.com>
- 43b110ce5311aaeaef9b648fe53adfa5a850e7d4 [ir] Only show type on lhs by dan sinclair <dsinclair@chromium.org>
- c42014805a3d4829109930f47c8d8621e39d4760 [tint][ir] Guard transform source sets in GN build by James Price <jrprice@google.com>
- 54d1d714ce7cac3762bb997393698d0cd059c8f1 [ir][spirv-writer] Move code to `ir` subdirectory by James Price <jrprice@google.com>
- 90b8cc1e93677bc76db880fc284fcb5da27714d6 [ir] Add load instruction by James Price <jrprice@google.com>
- 0bb1bb3067dafb048fea59579e5885a7299e5dc4 [ir] Remove references, indirection and address-of by James Price <jrprice@google.com>
- 82db91ac96fcebcb72f2edc4a49672d995c5a430 [ir][spirv-writer] Emit `If` flow nodes by James Price <jrprice@google.com>
- 7ac28d3c6e5159c70525ebe4609459a141630f72 [ir] Add AddEmptyEntryPoint transform by James Price <jrprice@google.com>
- 95b06129f0f9b0141fa681ae7d0a29da3a814d5c [ir] Add base `ir::transform::Transform` class by James Price <jrprice@google.com>
- db5ad9f357b5be9e6a89e57b10a82dc9e0e60d2b [tint] Materialize compound assignment RHS by James Price <jrprice@google.com>
- 11ee6b6cc6775ff54bbef5bac2d20c2e75b83fe1 [ir] Handle phony assignment. by dan sinclair <dsinclair@chromium.org>
- bbaa456b18e052ddcc6da7b061cbbad046673196 [ir] Remove instruction allocator. by dan sinclair <dsinclair@chromium.org>
- 0531610e99a2f193e3f425704d665d5512a90881 [ir] Add basic block arguments. by dan sinclair <dsinclair@chromium.org>
- 9fc46dc3c1e2e7f9fbbc23f3d76f8a5da9c38aef [ir] Drop address space and access from ir::Var. by dan sinclair <dsinclair@chromium.org>
- 25ae3114b37811aa185b3d1cc097e2808a0f6fb1 [tint][ir][ToProgram] Emit returns with values by Ben Clayton <bclayton@google.com>
- a6e7cfc1d08c5bf41caa9116b84fe92c65de2ed1 [tint][ir][ToProgram] Emit returns without values by Ben Clayton <bclayton@google.com>
- 6ab77f16f0560b84b7e67e3e110c3e218dcd38b7 [tint][ir][ToProgram] Emit 'else if' instead of 'else { if' by Ben Clayton <bclayton@google.com>
- 1ea1e1a37578539aeafe13177346dadd3f171af8 [tint][ir][ToProgram] Begin flow node traversal by Ben Clayton <bclayton@google.com>
- 0b9cb101bf878ed3e672123f93cffae573a2f329 [tint][ir][ToProgram] Implement var expressions by Ben Clayton <bclayton@google.com>
GitOrigin-RevId: 97ab6a3a706121a6a3c046c13361e9964d53feaa
Change-Id: I09d69b9ce06712f1e95b64d53f882ec8e032be26
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/133540
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index 7e85a38..47efd52 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -359,10 +359,7 @@
"transform/transform.cc",
"transform/transform.h",
]
- deps = [
- ":libtint_program_src",
- ":libtint_utils_src",
- ]
+ deps = [ ":libtint_utils_src" ]
}
libtint_source_set("libtint_ast_transform_base_src") {
@@ -370,11 +367,11 @@
"ast/transform/transform.cc",
"ast/transform/transform.h",
]
- public_deps = [ ":libtint_transform_src" ]
deps = [
":libtint_builtins_src",
":libtint_program_src",
":libtint_sem_src",
+ ":libtint_transform_src",
":libtint_type_src",
":libtint_utils_src",
]
@@ -387,7 +384,10 @@
]
deps = [
":libtint_ast_transform_base_src",
+ ":libtint_ir_builder_src",
+ ":libtint_ir_src",
":libtint_program_src",
+ ":libtint_transform_src",
]
}
@@ -492,9 +492,9 @@
"ast/transform/zero_init_workgroup_memory.cc",
"ast/transform/zero_init_workgroup_memory.h",
]
- public_deps = [ ":libtint_ast_transform_base_src" ]
deps = [
":libtint_ast_src",
+ ":libtint_ast_transform_base_src",
":libtint_builtins_src",
":libtint_program_src",
":libtint_sem_src",
@@ -505,6 +505,22 @@
]
}
+if (tint_build_ir) {
+ libtint_source_set("libtint_ir_transform_src") {
+ sources = [
+ "ir/transform/add_empty_entry_point.cc",
+ "ir/transform/add_empty_entry_point.h",
+ ]
+ deps = [
+ ":libtint_builtins_src",
+ ":libtint_ir_src",
+ ":libtint_symbols_src",
+ ":libtint_type_src",
+ ":libtint_utils_src",
+ ]
+ }
+}
+
libtint_source_set("libtint_ast_hdrs") {
sources = [
"ast/accessor_expression.h",
@@ -1014,12 +1030,13 @@
if (tint_build_ir) {
sources += [
- "writer/spirv/generator_impl_ir.cc",
- "writer/spirv/generator_impl_ir.h",
+ "writer/spirv/ir/generator_impl_ir.cc",
+ "writer/spirv/ir/generator_impl_ir.h",
]
deps += [
":libtint_ir_builder_src",
":libtint_ir_src",
+ ":libtint_ir_transform_src",
]
}
}
@@ -1185,6 +1202,8 @@
"ir/bitcast.h",
"ir/block.cc",
"ir/block.h",
+ "ir/block_param.cc",
+ "ir/block_param.h",
"ir/builder.cc",
"ir/builder.h",
"ir/builtin.cc",
@@ -1213,6 +1232,8 @@
"ir/if.h",
"ir/instruction.cc",
"ir/instruction.h",
+ "ir/load.cc",
+ "ir/load.h",
"ir/loop.cc",
"ir/loop.h",
"ir/module.cc",
@@ -1223,6 +1244,8 @@
"ir/store.h",
"ir/switch.cc",
"ir/switch.h",
+ "ir/transform/transform.cc",
+ "ir/transform/transform.h",
"ir/unary.cc",
"ir/unary.h",
"ir/user_call.cc",
@@ -1237,6 +1260,7 @@
":libtint_builtins_src",
":libtint_constant_src",
":libtint_symbols_src",
+ ":libtint_transform_src",
":libtint_type_src",
":libtint_utils_src",
]
@@ -1245,6 +1269,7 @@
source_set("libtint") {
public_deps = [
":libtint_ast_src",
+ ":libtint_ast_transform_base_src",
":libtint_ast_transform_src",
":libtint_constant_src",
":libtint_initializer_src",
@@ -1736,6 +1761,7 @@
]
deps = [
+ ":libtint_ast_transform_base_src",
":libtint_ast_transform_src",
":libtint_builtins_src",
":libtint_transform_manager_src",
@@ -1746,6 +1772,21 @@
]
}
+ if (tint_build_ir) {
+ tint_unittests_source_set("tint_unittests_ir_transform_src") {
+ sources = [
+ "ir/transform/add_empty_entry_point_test.cc",
+ "ir/transform/test_helper.h",
+ ]
+
+ deps = [
+ ":libtint_ir_src",
+ ":libtint_ir_transform_src",
+ ":libtint_transform_manager_src",
+ ]
+ }
+ }
+
tint_unittests_source_set("tint_unittests_utils_src") {
sources = [
"debug_test.cc",
@@ -1902,11 +1943,14 @@
if (tint_build_ir) {
sources += [
- "writer/spirv/generator_impl_binary_test.cc",
- "writer/spirv/generator_impl_constant_test.cc",
- "writer/spirv/generator_impl_function_test.cc",
- "writer/spirv/generator_impl_ir_test.cc",
- "writer/spirv/generator_impl_type_test.cc",
+ "writer/spirv/ir/generator_impl_ir_binary_test.cc",
+ "writer/spirv/ir/generator_impl_ir_constant_test.cc",
+ "writer/spirv/ir/generator_impl_ir_function_test.cc",
+ "writer/spirv/ir/generator_impl_ir_if_test.cc",
+ "writer/spirv/ir/generator_impl_ir_test.cc",
+ "writer/spirv/ir/generator_impl_ir_type_test.cc",
+ "writer/spirv/ir/generator_impl_ir_var_test.cc",
+ "writer/spirv/ir/test_helper_ir.h",
]
deps += [ ":libtint_ir_src" ]
}
@@ -2181,12 +2225,23 @@
"clone_context_test.cc",
"program_builder_test.cc",
"program_test.cc",
+ "transform/manager_test.cc",
]
deps = [
+ ":libtint_ast_transform_base_src",
+ ":libtint_program_src",
+ ":libtint_transform_manager_src",
":libtint_unittests_ast_helper",
":tint_unittests_ast_src",
]
+
+ if (tint_build_ir) {
+ deps += [
+ ":libtint_ir_builder_src",
+ ":libtint_ir_src",
+ ]
+ }
}
tint_unittests_source_set("tint_unittests_ir_src") {
@@ -2203,6 +2258,7 @@
"ir/from_program_test.cc",
"ir/from_program_unary_test.cc",
"ir/from_program_var_test.cc",
+ "ir/load_test.cc",
"ir/module_test.cc",
"ir/store_test.cc",
"ir/test_helper.h",
@@ -2278,7 +2334,10 @@
}
if (tint_build_ir) {
- deps += [ ":tint_unittests_ir_src" ]
+ deps += [
+ ":tint_unittests_ir_src",
+ ":tint_unittests_ir_transform_src",
+ ]
}
if (build_with_chromium) {
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index 7d12ae2..d0c7487 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -656,8 +656,8 @@
if(${TINT_BUILD_IR})
list(APPEND TINT_LIB_SRCS
- writer/spirv/generator_impl_ir.cc
- writer/spirv/generator_impl_ir.h
+ writer/spirv/ir/generator_impl_ir.cc
+ writer/spirv/ir/generator_impl_ir.h
)
endif()
endif()
@@ -716,6 +716,8 @@
ir/bitcast.h
ir/block.cc
ir/block.h
+ ir/block_param.cc
+ ir/block_param.h
ir/builder.cc
ir/builder.h
ir/builtin.cc
@@ -746,6 +748,8 @@
ir/if.h
ir/instruction.cc
ir/instruction.h
+ ir/load.cc
+ ir/load.h
ir/loop.cc
ir/loop.h
ir/module.cc
@@ -766,6 +770,10 @@
ir/value.h
ir/var.cc
ir/var.h
+ ir/transform/add_empty_entry_point.cc
+ ir/transform/add_empty_entry_point.h
+ ir/transform/transform.cc
+ ir/transform/transform.h
)
endif()
@@ -996,6 +1004,7 @@
symbol_test.cc
test_main.cc
ast/transform/transform_test.cc
+ transform/manager_test.cc
type/array_test.cc
type/atomic_test.cc
type/bool_test.cc
@@ -1231,12 +1240,14 @@
if(${TINT_BUILD_IR})
list(APPEND TINT_TEST_SRCS
- writer/spirv/generator_impl_binary_test.cc
- writer/spirv/generator_impl_constant_test.cc
- writer/spirv/generator_impl_function_test.cc
- writer/spirv/generator_impl_ir_test.cc
- writer/spirv/generator_impl_type_test.cc
- writer/spirv/test_helper_ir.h
+ writer/spirv/ir/generator_impl_ir_binary_test.cc
+ writer/spirv/ir/generator_impl_ir_constant_test.cc
+ writer/spirv/ir/generator_impl_ir_function_test.cc
+ writer/spirv/ir/generator_impl_ir_if_test.cc
+ writer/spirv/ir/generator_impl_ir_test.cc
+ writer/spirv/ir/generator_impl_ir_type_test.cc
+ writer/spirv/ir/generator_impl_ir_var_test.cc
+ writer/spirv/ir/test_helper_ir.h
)
endif()
endif()
@@ -1459,6 +1470,7 @@
ir/from_program_test.cc
ir/from_program_unary_test.cc
ir/from_program_var_test.cc
+ ir/load_test.cc
ir/module_test.cc
ir/store_test.cc
ir/test_helper.h
diff --git a/src/tint/ir/block.h b/src/tint/ir/block.h
index 3981355..5af32fe 100644
--- a/src/tint/ir/block.h
+++ b/src/tint/ir/block.h
@@ -15,6 +15,7 @@
#ifndef SRC_TINT_IR_BLOCK_H_
#define SRC_TINT_IR_BLOCK_H_
+#include "src/tint/ir/block_param.h"
#include "src/tint/ir/branch.h"
#include "src/tint/ir/flow_node.h"
#include "src/tint/ir/instruction.h"
@@ -40,6 +41,9 @@
/// The instructions in the block
utils::Vector<const Instruction*, 16> instructions;
+
+ /// The parameters passed into the block
+ utils::Vector<const BlockParam*, 0> params;
};
} // namespace tint::ir
diff --git a/src/tint/ir/block_param.cc b/src/tint/ir/block_param.cc
new file mode 100644
index 0000000..f014d19
--- /dev/null
+++ b/src/tint/ir/block_param.cc
@@ -0,0 +1,25 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/block_param.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ir::BlockParam);
+
+namespace tint::ir {
+
+BlockParam::BlockParam(const type::Type* ty) : type(ty) {}
+
+BlockParam::~BlockParam() = default;
+
+} // namespace tint::ir
diff --git a/src/tint/ir/block_param.h b/src/tint/ir/block_param.h
new file mode 100644
index 0000000..8ba68a7
--- /dev/null
+++ b/src/tint/ir/block_param.h
@@ -0,0 +1,45 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_IR_BLOCK_PARAM_H_
+#define SRC_TINT_IR_BLOCK_PARAM_H_
+
+#include "src/tint/ir/value.h"
+#include "src/tint/utils/castable.h"
+
+namespace tint::ir {
+
+/// An instruction in the IR.
+class BlockParam : public utils::Castable<BlockParam, Value> {
+ public:
+ /// Constructor
+ /// @param type the type of the var
+ explicit BlockParam(const type::Type* type);
+ BlockParam(const BlockParam& inst) = delete;
+ BlockParam(BlockParam&& inst) = delete;
+ ~BlockParam() override;
+
+ BlockParam& operator=(const BlockParam& inst) = delete;
+ BlockParam& operator=(BlockParam&& inst) = delete;
+
+ /// @returns the type of the var
+ const type::Type* Type() const override { return type; }
+
+ /// the result type of the instruction
+ const type::Type* type = nullptr;
+};
+
+} // namespace tint::ir
+
+#endif // SRC_TINT_IR_BLOCK_PARAM_H_
diff --git a/src/tint/ir/builder.cc b/src/tint/ir/builder.cc
index d62bb43..983815e 100644
--- a/src/tint/ir/builder.cc
+++ b/src/tint/ir/builder.cc
@@ -17,6 +17,8 @@
#include <utility>
#include "src/tint/constant/scalar.h"
+#include "src/tint/type/pointer.h"
+#include "src/tint/type/reference.h"
namespace tint::ir {
@@ -114,7 +116,7 @@
}
Binary* Builder::CreateBinary(Binary::Kind kind, const type::Type* type, Value* lhs, Value* rhs) {
- return ir.instructions.Create<ir::Binary>(kind, type, lhs, rhs);
+ return ir.values.Create<ir::Binary>(kind, type, lhs, rhs);
}
Binary* Builder::And(const type::Type* type, Value* lhs, Value* rhs) {
@@ -182,21 +184,13 @@
}
Unary* Builder::CreateUnary(Unary::Kind kind, const type::Type* type, Value* val) {
- return ir.instructions.Create<ir::Unary>(kind, type, val);
-}
-
-Unary* Builder::AddressOf(const type::Type* type, Value* val) {
- return CreateUnary(Unary::Kind::kAddressOf, type, val);
+ return ir.values.Create<ir::Unary>(kind, type, val);
}
Unary* Builder::Complement(const type::Type* type, Value* val) {
return CreateUnary(Unary::Kind::kComplement, type, val);
}
-Unary* Builder::Indirection(const type::Type* type, Value* val) {
- return CreateUnary(Unary::Kind::kIndirection, type, val);
-}
-
Unary* Builder::Negation(const type::Type* type, Value* val) {
return CreateUnary(Unary::Kind::kNegation, type, val);
}
@@ -206,43 +200,51 @@
}
ir::Bitcast* Builder::Bitcast(const type::Type* type, Value* val) {
- return ir.instructions.Create<ir::Bitcast>(type, val);
+ return ir.values.Create<ir::Bitcast>(type, val);
}
ir::Discard* Builder::Discard() {
- return ir.instructions.Create<ir::Discard>();
+ return ir.values.Create<ir::Discard>();
}
ir::UserCall* Builder::UserCall(const type::Type* type,
Symbol name,
utils::VectorRef<Value*> args) {
- return ir.instructions.Create<ir::UserCall>(type, name, std::move(args));
+ return ir.values.Create<ir::UserCall>(type, name, std::move(args));
}
ir::Convert* Builder::Convert(const type::Type* to,
const type::Type* from,
utils::VectorRef<Value*> args) {
- return ir.instructions.Create<ir::Convert>(to, from, std::move(args));
+ return ir.values.Create<ir::Convert>(to, from, std::move(args));
}
ir::Construct* Builder::Construct(const type::Type* to, utils::VectorRef<Value*> args) {
- return ir.instructions.Create<ir::Construct>(to, std::move(args));
+ return ir.values.Create<ir::Construct>(to, std::move(args));
}
ir::Builtin* Builder::Builtin(const type::Type* type,
builtin::Function func,
utils::VectorRef<Value*> args) {
- return ir.instructions.Create<ir::Builtin>(type, func, args);
+ return ir.values.Create<ir::Builtin>(type, func, args);
+}
+
+ir::Load* Builder::Load(Value* from) {
+ auto* ptr = from->Type()->As<type::Pointer>();
+ TINT_ASSERT(IR, ptr);
+ return ir.values.Create<ir::Load>(ptr->StoreType(), from);
}
ir::Store* Builder::Store(Value* to, Value* from) {
- return ir.instructions.Create<ir::Store>(to, from);
+ return ir.values.Create<ir::Store>(to, from);
}
-ir::Var* Builder::Declare(const type::Type* type,
- builtin::AddressSpace address_space,
- builtin::Access access) {
- return ir.instructions.Create<ir::Var>(type, address_space, access);
+ir::Var* Builder::Declare(const type::Type* type) {
+ return ir.values.Create<ir::Var>(type);
+}
+
+ir::BlockParam* Builder::BlockParam(const type::Type* type) {
+ return ir.values.Create<ir::BlockParam>(type);
}
} // namespace tint::ir
diff --git a/src/tint/ir/builder.h b/src/tint/ir/builder.h
index b4a2a81..7c0cab6 100644
--- a/src/tint/ir/builder.h
+++ b/src/tint/ir/builder.h
@@ -28,6 +28,7 @@
#include "src/tint/ir/function.h"
#include "src/tint/ir/function_terminator.h"
#include "src/tint/ir/if.h"
+#include "src/tint/ir/load.h"
#include "src/tint/ir/loop.h"
#include "src/tint/ir/module.h"
#include "src/tint/ir/root_terminator.h"
@@ -279,24 +280,12 @@
/// @returns the operation
Unary* CreateUnary(Unary::Kind kind, const type::Type* type, Value* val);
- /// Creates an AddressOf operation
- /// @param type the result type of the expression
- /// @param val the value
- /// @returns the operation
- Unary* AddressOf(const type::Type* type, Value* val);
-
/// Creates a Complement operation
/// @param type the result type of the expression
/// @param val the value
/// @returns the operation
Unary* Complement(const type::Type* type, Value* val);
- /// Creates an Indirection operation
- /// @param type the result type of the expression
- /// @param val the value
- /// @returns the operation
- Unary* Indirection(const type::Type* type, Value* val);
-
/// Creates a Negation operation
/// @param type the result type of the expression
/// @param val the value
@@ -350,7 +339,12 @@
builtin::Function func,
utils::VectorRef<Value*> args);
- /// Creates an store instruction
+ /// Creates a load instruction
+ /// @param from the expression being loaded from
+ /// @returns the instruction
+ ir::Load* Load(Value* from);
+
+ /// Creates a store instruction
/// @param to the expression being stored too
/// @param from the expression being stored
/// @returns the instruction
@@ -358,12 +352,13 @@
/// Creates a new `var` declaration
/// @param type the var type
- /// @param address_space the address space
- /// @param access the access mode
/// @returns the instruction
- ir::Var* Declare(const type::Type* type,
- builtin::AddressSpace address_space,
- builtin::Access access);
+ ir::Var* Declare(const type::Type* type);
+
+ /// Creates a new `BlockParam`
+ /// @param type the parameter type
+ /// @returns the value
+ ir::BlockParam* BlockParam(const type::Type* type);
/// Retrieves the root block for the module, creating if necessary
/// @returns the root block
diff --git a/src/tint/ir/disassembler.cc b/src/tint/ir/disassembler.cc
index d4ee3a7..f68b21a 100644
--- a/src/tint/ir/disassembler.cc
+++ b/src/tint/ir/disassembler.cc
@@ -27,6 +27,7 @@
#include "src/tint/ir/discard.h"
#include "src/tint/ir/function_terminator.h"
#include "src/tint/ir/if.h"
+#include "src/tint/ir/load.h"
#include "src/tint/ir/loop.h"
#include "src/tint/ir/root_terminator.h"
#include "src/tint/ir/store.h"
@@ -154,7 +155,19 @@
return;
}
- Indent() << "%fn" << IdOf(b) << " = block {" << std::endl;
+ Indent() << "%fn" << IdOf(b) << " = block";
+ if (!b->params.IsEmpty()) {
+ out_ << " (";
+ for (auto* p : b->params) {
+ if (p != b->params.Front()) {
+ out_ << ", ";
+ }
+ EmitValue(p);
+ }
+ out_ << ")";
+ }
+
+ out_ << " {" << std::endl;
{
ScopedIndent si(indent_size_);
EmitBlockInstructions(b);
@@ -248,7 +261,20 @@
[&](const ir::If* i) {
Indent() << "%fn" << IdOf(i) << " = if ";
EmitValue(i->condition);
- out_ << " [t: %fn" << IdOf(i->true_.target) << ", f: %fn" << IdOf(i->false_.target);
+
+ bool has_true = !i->true_.target->IsDead();
+ bool has_false = !i->false_.target->IsDead();
+
+ out_ << " [";
+ if (has_true) {
+ out_ << "t: %fn" << IdOf(i->true_.target);
+ }
+ if (has_false) {
+ if (has_true) {
+ out_ << ", ";
+ }
+ out_ << "f: %fn" << IdOf(i->false_.target);
+ }
if (i->merge.target->IsConnected()) {
out_ << ", m: %fn" << IdOf(i->merge.target);
}
@@ -258,10 +284,12 @@
ScopedIndent if_indent(indent_size_);
ScopedStopNode scope(stop_nodes_, i->merge.target);
- Indent() << "# true branch" << std::endl;
- Walk(i->true_.target);
+ if (has_true) {
+ Indent() << "# true branch" << std::endl;
+ Walk(i->true_.target);
+ }
- if (!i->false_.target->IsDead()) {
+ if (has_false) {
Indent() << "# false branch" << std::endl;
Walk(i->false_.target);
}
@@ -324,6 +352,13 @@
return out_.str();
}
+void Disassembler::EmitValueWithType(const Value* val) {
+ EmitValue(val);
+ if (auto* i = val->As<ir::Instruction>(); i->Type() != nullptr) {
+ out_ << ":" << i->Type()->FriendlyName();
+ }
+}
+
void Disassembler::EmitValue(const Value* val) {
tint::Switch(
val,
@@ -368,12 +403,11 @@
};
emit(constant->value);
},
- [&](const ir::Instruction* i) {
- out_ << "%" << IdOf(i);
- if (i->Type() != nullptr) {
- out_ << ":" << i->Type()->FriendlyName();
- }
- });
+ [&](const ir::Instruction* i) { out_ << "%" << IdOf(i); },
+ [&](const ir::BlockParam* p) {
+ out_ << "%" << IdOf(p) << ":" << p->Type()->FriendlyName();
+ },
+ [&](Default) { out_ << "Unknown value: " << val->TypeInfo().name; });
}
void Disassembler::EmitInstruction(const Instruction* inst) {
@@ -381,26 +415,31 @@
inst, //
[&](const ir::Binary* b) { EmitBinary(b); }, [&](const ir::Unary* u) { EmitUnary(u); },
[&](const ir::Bitcast* b) {
- EmitValue(b);
+ EmitValueWithType(b);
out_ << " = bitcast ";
EmitArgs(b);
},
[&](const ir::Discard*) { out_ << "discard"; },
[&](const ir::Builtin* b) {
- EmitValue(b);
+ EmitValueWithType(b);
out_ << " = " << builtin::str(b->Func()) << " ";
EmitArgs(b);
},
[&](const ir::Construct* c) {
- EmitValue(c);
+ EmitValueWithType(c);
out_ << " = construct ";
EmitArgs(c);
},
[&](const ir::Convert* c) {
- EmitValue(c);
+ EmitValueWithType(c);
out_ << " = convert " << c->FromType()->FriendlyName() << ", ";
EmitArgs(c);
},
+ [&](const ir::Load* l) {
+ EmitValueWithType(l);
+ out_ << " = load ";
+ EmitValue(l->from);
+ },
[&](const ir::Store* s) {
out_ << "store ";
EmitValue(s->to);
@@ -408,7 +447,7 @@
EmitValue(s->from);
},
[&](const ir::UserCall* uc) {
- EmitValue(uc);
+ EmitValueWithType(uc);
out_ << " = call " << uc->name.Name();
if (uc->args.Length() > 0) {
out_ << ", ";
@@ -416,8 +455,8 @@
EmitArgs(uc);
},
[&](const ir::Var* v) {
- EmitValue(v);
- out_ << " = var " << v->address_space << ", " << v->access;
+ EmitValueWithType(v);
+ out_ << " = var";
if (v->initializer) {
out_ << ", ";
EmitValue(v->initializer);
@@ -437,7 +476,7 @@
}
void Disassembler::EmitBinary(const Binary* b) {
- EmitValue(b);
+ EmitValueWithType(b);
out_ << " = ";
switch (b->kind) {
case Binary::Kind::kAdd:
@@ -496,18 +535,12 @@
}
void Disassembler::EmitUnary(const Unary* u) {
- EmitValue(u);
+ EmitValueWithType(u);
out_ << " = ";
switch (u->kind) {
- case Unary::Kind::kAddressOf:
- out_ << "addr_of";
- break;
case Unary::Kind::kComplement:
out_ << "complement";
break;
- case Unary::Kind::kIndirection:
- out_ << "indirection";
- break;
case Unary::Kind::kNegation:
out_ << "negation";
break;
diff --git a/src/tint/ir/disassembler.h b/src/tint/ir/disassembler.h
index 2438d80..c8953db 100644
--- a/src/tint/ir/disassembler.h
+++ b/src/tint/ir/disassembler.h
@@ -55,6 +55,7 @@
void Walk(const FlowNode* node);
void EmitInstruction(const Instruction* inst);
+ void EmitValueWithType(const Value* val);
void EmitValue(const Value* val);
void EmitArgs(const Call* call);
void EmitBinary(const Binary* b);
diff --git a/src/tint/ir/from_program.cc b/src/tint/ir/from_program.cc
index f7729f3..bb9c3f9 100644
--- a/src/tint/ir/from_program.cc
+++ b/src/tint/ir/from_program.cc
@@ -48,6 +48,7 @@
#include "src/tint/ast/literal_expression.h"
#include "src/tint/ast/loop_statement.h"
#include "src/tint/ast/override.h"
+#include "src/tint/ast/phony_expression.h"
#include "src/tint/ast/return_statement.h"
#include "src/tint/ast/statement.h"
#include "src/tint/ast/struct.h"
@@ -59,6 +60,7 @@
#include "src/tint/ast/var.h"
#include "src/tint/ast/variable_decl_statement.h"
#include "src/tint/ast/while_statement.h"
+#include "src/tint/ir/block_param.h"
#include "src/tint/ir/builder.h"
#include "src/tint/ir/function.h"
#include "src/tint/ir/if.h"
@@ -72,6 +74,7 @@
#include "src/tint/sem/builtin.h"
#include "src/tint/sem/call.h"
#include "src/tint/sem/function.h"
+#include "src/tint/sem/load.h"
#include "src/tint/sem/materialize.h"
#include "src/tint/sem/module.h"
#include "src/tint/sem/switch_statement.h"
@@ -80,6 +83,8 @@
#include "src/tint/sem/value_expression.h"
#include "src/tint/sem/variable.h"
#include "src/tint/switch.h"
+#include "src/tint/type/pointer.h"
+#include "src/tint/type/reference.h"
#include "src/tint/type/void.h"
#include "src/tint/utils/defer.h"
#include "src/tint/utils/result.h"
@@ -336,7 +341,6 @@
current_flow_block_ = ir_func->start_target;
EmitBlock(ast_func->body);
- // TODO(dsinclair): Store return type and attributes
// TODO(dsinclair): Store parameters
// If the branch target has already been set then a `return` was called. Only set in the
@@ -390,6 +394,16 @@
}
void EmitAssignment(const ast::AssignmentStatement* stmt) {
+ // If assigning to a phony, just generate the RHS and we're done. Note that, because this
+ // isn't used, a subsequent transform could remove it due to it being dead code. This could
+ // then change the interface for the program (i.e. a global var no longer used). If that
+ // happens we have to either fix this to store to a phony value, or make sure we pull the
+ // interface before doing the dead code elimination.
+ if (stmt->lhs->Is<ast::PhonyExpression>()) {
+ (void)EmitExpression(stmt->rhs);
+ return;
+ }
+
auto lhs = EmitExpression(stmt->lhs);
if (!lhs) {
return;
@@ -409,15 +423,20 @@
return;
}
- auto* ty = lhs.Get()->Type();
- auto* rhs = ty->UnwrapRef()->is_signed_integer_scalar() ? builder_.Constant(1_i)
- : builder_.Constant(1_u);
+ // Load from the LHS.
+ auto* lhs_value = builder_.Load(lhs.Get());
+ current_flow_block_->instructions.Push(lhs_value);
+
+ auto* ty = lhs_value->Type();
+
+ auto* rhs =
+ ty->is_signed_integer_scalar() ? builder_.Constant(1_i) : builder_.Constant(1_u);
Binary* inst = nullptr;
if (stmt->increment) {
- inst = builder_.Add(ty, lhs.Get(), rhs);
+ inst = builder_.Add(ty, lhs_value, rhs);
} else {
- inst = builder_.Subtract(ty, lhs.Get(), rhs);
+ inst = builder_.Subtract(ty, lhs_value, rhs);
}
current_flow_block_->instructions.Push(inst);
@@ -435,38 +454,44 @@
if (!rhs) {
return;
}
- auto* ty = lhs.Get()->Type();
+
+ // Load from the LHS.
+ auto* lhs_value = builder_.Load(lhs.Get());
+ current_flow_block_->instructions.Push(lhs_value);
+
+ auto* ty = lhs_value->Type();
+
Binary* inst = nullptr;
switch (stmt->op) {
case ast::BinaryOp::kAnd:
- inst = builder_.And(ty, lhs.Get(), rhs.Get());
+ inst = builder_.And(ty, lhs_value, rhs.Get());
break;
case ast::BinaryOp::kOr:
- inst = builder_.Or(ty, lhs.Get(), rhs.Get());
+ inst = builder_.Or(ty, lhs_value, rhs.Get());
break;
case ast::BinaryOp::kXor:
- inst = builder_.Xor(ty, lhs.Get(), rhs.Get());
+ inst = builder_.Xor(ty, lhs_value, rhs.Get());
break;
case ast::BinaryOp::kShiftLeft:
- inst = builder_.ShiftLeft(ty, lhs.Get(), rhs.Get());
+ inst = builder_.ShiftLeft(ty, lhs_value, rhs.Get());
break;
case ast::BinaryOp::kShiftRight:
- inst = builder_.ShiftRight(ty, lhs.Get(), rhs.Get());
+ inst = builder_.ShiftRight(ty, lhs_value, rhs.Get());
break;
case ast::BinaryOp::kAdd:
- inst = builder_.Add(ty, lhs.Get(), rhs.Get());
+ inst = builder_.Add(ty, lhs_value, rhs.Get());
break;
case ast::BinaryOp::kSubtract:
- inst = builder_.Subtract(ty, lhs.Get(), rhs.Get());
+ inst = builder_.Subtract(ty, lhs_value, rhs.Get());
break;
case ast::BinaryOp::kMultiply:
- inst = builder_.Multiply(ty, lhs.Get(), rhs.Get());
+ inst = builder_.Multiply(ty, lhs_value, rhs.Get());
break;
case ast::BinaryOp::kDivide:
- inst = builder_.Divide(ty, lhs.Get(), rhs.Get());
+ inst = builder_.Divide(ty, lhs_value, rhs.Get());
break;
case ast::BinaryOp::kModulo:
- inst = builder_.Modulo(ty, lhs.Get(), rhs.Get());
+ inst = builder_.Modulo(ty, lhs_value, rhs.Get());
break;
case ast::BinaryOp::kLessThanEqual:
case ast::BinaryOp::kGreaterThanEqual:
@@ -796,7 +821,8 @@
utils::Result<Value*> EmitExpression(const ast::Expression* expr) {
// If this is a value that has been const-eval'd return the result.
- if (auto* sem = program_->Sem().Get(expr)->As<sem::ValueExpression>()) {
+ auto* sem = program_->Sem().GetVal(expr);
+ if (sem) {
if (auto* v = sem->ConstantValue()) {
if (auto* cv = v->Clone(clone_ctx_)) {
return builder_.Constant(cv);
@@ -804,7 +830,7 @@
}
}
- return tint::Switch(
+ auto result = tint::Switch(
expr,
// [&](const ast::IndexAccessorExpression* a) {
// TODO(dsinclair): Implement
@@ -825,15 +851,23 @@
// [&](const ast::MemberAccessorExpression* m) {
// TODO(dsinclair): Implement
// },
- // [&](const ast::PhonyExpression*) {
- // TODO(dsinclair): Implement. The call may have side effects so has to be made.
- // },
[&](const ast::UnaryOpExpression* u) { return EmitUnary(u); },
+ // Note, ast::PhonyExpression is explicitly not handled here as it should never get into
+ // this method. The assignment statement should have filtered it out already.
[&](Default) {
add_error(expr->source,
"unknown expression type: " + std::string(expr->TypeInfo().name));
return utils::Failure;
});
+
+ // If this expression maps to sem::Load, insert a load instruction to get the result.
+ if (result && sem->Is<sem::Load>()) {
+ auto* load = builder_.Load(result.Get());
+ current_flow_block_->instructions.Push(load);
+ return load;
+ }
+
+ return result;
}
void EmitVariable(const ast::Variable* var) {
@@ -842,8 +876,12 @@
return tint::Switch( //
var,
[&](const ast::Var* v) {
- auto* ty = sem->Type()->Clone(clone_ctx_.type_ctx);
- auto* val = builder_.Declare(ty, sem->AddressSpace(), sem->Access());
+ auto* ref = sem->Type()->As<type::Reference>();
+ auto* ty = builder_.ir.types.Get<type::Pointer>(
+ ref->StoreType()->Clone(clone_ctx_.type_ctx), ref->AddressSpace(),
+ ref->Access());
+
+ auto* val = builder_.Declare(ty);
current_flow_block_->instructions.Push(val);
if (v->initializer) {
@@ -904,14 +942,12 @@
Instruction* inst = nullptr;
switch (expr->op) {
case ast::UnaryOp::kAddressOf:
- inst = builder_.AddressOf(ty, val.Get());
- break;
+ case ast::UnaryOp::kIndirection:
+ // 'address-of' and 'indirection' just fold away and we propagate the pointer.
+ return val;
case ast::UnaryOp::kComplement:
inst = builder_.Complement(ty, val.Get());
break;
- case ast::UnaryOp::kIndirection:
- inst = builder_.Indirection(ty, val.Get());
- break;
case ast::UnaryOp::kNegation:
inst = builder_.Negation(ty, val.Get());
break;
@@ -943,27 +979,34 @@
return utils::Failure;
}
- // Generate a variable to store the short-circut into
- auto* ty = builder_.ir.types.Get<type::Bool>();
- auto* result_var =
- builder_.Declare(ty, builtin::AddressSpace::kFunction, builtin::Access::kReadWrite);
- current_flow_block_->instructions.Push(result_var);
-
- auto* lhs_store = builder_.Store(result_var, lhs.Get());
- current_flow_block_->instructions.Push(lhs_store);
-
auto* if_node = builder_.CreateIf(lhs.Get());
BranchTo(if_node);
+ auto* result = builder_.BlockParam(builder_.ir.types.Get<type::Bool>());
+ if_node->merge.target->As<Block>()->params.Push(result);
+
utils::Result<Value*> rhs;
{
FlowStackScope scope(this, if_node);
+ utils::Vector<Value*, 1> alt_args;
+ alt_args.Push(lhs.Get());
+
// If this is an `&&` then we only evaluate the RHS expression in the true block.
// If this is an `||` then we only evaluate the RHS expression in the false block.
if (expr->op == ast::BinaryOp::kLogicalAnd) {
+ // If the lhs is false, then that is the result we want to pass to the merge block
+ // as our argument
+ current_flow_block_ = if_node->false_.target->As<Block>();
+ BranchTo(if_node->merge.target, std::move(alt_args));
+
current_flow_block_ = if_node->true_.target->As<Block>();
} else {
+ // If the lhs is true, then that is the result we want to pass to the merge block
+ // as our argument
+ current_flow_block_ = if_node->true_.target->As<Block>();
+ BranchTo(if_node->merge.target, std::move(alt_args));
+
current_flow_block_ = if_node->false_.target->As<Block>();
}
@@ -971,14 +1014,14 @@
if (!rhs) {
return utils::Failure;
}
- auto* rhs_store = builder_.Store(result_var, rhs.Get());
- current_flow_block_->instructions.Push(rhs_store);
+ utils::Vector<Value*, 1> args;
+ args.Push(rhs.Get());
- BranchTo(if_node->merge.target);
+ BranchTo(if_node->merge.target, std::move(args));
}
current_flow_block_ = if_node->merge.target->As<Block>();
- return result_var;
+ return result;
}
utils::Result<Value*> EmitBinary(const ast::BinaryExpression* expr) {
diff --git a/src/tint/ir/from_program_binary_test.cc b/src/tint/ir/from_program_binary_test.cc
index 33266cb..85377ff 100644
--- a/src/tint/ir/from_program_binary_test.cc
+++ b/src/tint/ir/from_program_binary_test.cc
@@ -42,7 +42,7 @@
%fn3 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn4 = block {
%1:u32 = call my_func
- %tint_symbol:u32 = add %1:u32, 4u
+ %tint_symbol:u32 = add %1, 4u
} -> %func_end # return
} %func_end
@@ -58,14 +58,15 @@
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = block {
- %v1:ref<private, u32, read_write> = var private, read_write
+ %v1:ptr<private, u32, read_write> = var
}
%fn2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn3 = block {
- %2:ref<private, u32, read_write> = add %v1:ref<private, u32, read_write>, 1u
- store %v1:ref<private, u32, read_write>, %2:ref<private, u32, read_write>
+ %2:u32 = load %v1
+ %3:u32 = add %2, 1u
+ store %v1, %3
} -> %func_end # return
} %func_end
@@ -81,14 +82,15 @@
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = block {
- %v1:ref<private, u32, read_write> = var private, read_write
+ %v1:ptr<private, u32, read_write> = var
}
%fn2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn3 = block {
- %2:ref<private, u32, read_write> = add %v1:ref<private, u32, read_write>, 1u
- store %v1:ref<private, u32, read_write>, %2:ref<private, u32, read_write>
+ %2:u32 = load %v1
+ %3:u32 = add %2, 1u
+ store %v1, %3
} -> %func_end # return
} %func_end
@@ -111,7 +113,7 @@
%fn3 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn4 = block {
%1:u32 = call my_func
- %tint_symbol:u32 = sub %1:u32, 4u
+ %tint_symbol:u32 = sub %1, 4u
} -> %func_end # return
} %func_end
@@ -127,14 +129,15 @@
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = block {
- %v1:ref<private, i32, read_write> = var private, read_write
+ %v1:ptr<private, i32, read_write> = var
}
%fn2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn3 = block {
- %2:ref<private, i32, read_write> = sub %v1:ref<private, i32, read_write>, 1i
- store %v1:ref<private, i32, read_write>, %2:ref<private, i32, read_write>
+ %2:i32 = load %v1
+ %3:i32 = sub %2, 1i
+ store %v1, %3
} -> %func_end # return
} %func_end
@@ -150,14 +153,15 @@
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = block {
- %v1:ref<private, u32, read_write> = var private, read_write
+ %v1:ptr<private, u32, read_write> = var
}
%fn2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn3 = block {
- %2:ref<private, u32, read_write> = sub %v1:ref<private, u32, read_write>, 1u
- store %v1:ref<private, u32, read_write>, %2:ref<private, u32, read_write>
+ %2:u32 = load %v1
+ %3:u32 = sub %2, 1u
+ store %v1, %3
} -> %func_end # return
} %func_end
@@ -180,7 +184,7 @@
%fn3 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn4 = block {
%1:u32 = call my_func
- %tint_symbol:u32 = mul %1:u32, 4u
+ %tint_symbol:u32 = mul %1, 4u
} -> %func_end # return
} %func_end
@@ -196,14 +200,15 @@
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = block {
- %v1:ref<private, u32, read_write> = var private, read_write
+ %v1:ptr<private, u32, read_write> = var
}
%fn2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn3 = block {
- %2:ref<private, u32, read_write> = mul %v1:ref<private, u32, read_write>, 1u
- store %v1:ref<private, u32, read_write>, %2:ref<private, u32, read_write>
+ %2:u32 = load %v1
+ %3:u32 = mul %2, 1u
+ store %v1, %3
} -> %func_end # return
} %func_end
@@ -226,7 +231,7 @@
%fn3 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn4 = block {
%1:u32 = call my_func
- %tint_symbol:u32 = div %1:u32, 4u
+ %tint_symbol:u32 = div %1, 4u
} -> %func_end # return
} %func_end
@@ -242,14 +247,15 @@
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = block {
- %v1:ref<private, u32, read_write> = var private, read_write
+ %v1:ptr<private, u32, read_write> = var
}
%fn2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn3 = block {
- %2:ref<private, u32, read_write> = div %v1:ref<private, u32, read_write>, 1u
- store %v1:ref<private, u32, read_write>, %2:ref<private, u32, read_write>
+ %2:u32 = load %v1
+ %3:u32 = div %2, 1u
+ store %v1, %3
} -> %func_end # return
} %func_end
@@ -272,7 +278,7 @@
%fn3 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn4 = block {
%1:u32 = call my_func
- %tint_symbol:u32 = mod %1:u32, 4u
+ %tint_symbol:u32 = mod %1, 4u
} -> %func_end # return
} %func_end
@@ -288,14 +294,15 @@
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = block {
- %v1:ref<private, u32, read_write> = var private, read_write
+ %v1:ptr<private, u32, read_write> = var
}
%fn2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn3 = block {
- %2:ref<private, u32, read_write> = mod %v1:ref<private, u32, read_write>, 1u
- store %v1:ref<private, u32, read_write>, %2:ref<private, u32, read_write>
+ %2:u32 = load %v1
+ %3:u32 = mod %2, 1u
+ store %v1, %3
} -> %func_end # return
} %func_end
@@ -318,7 +325,7 @@
%fn3 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn4 = block {
%1:u32 = call my_func
- %tint_symbol:u32 = and %1:u32, 4u
+ %tint_symbol:u32 = and %1, 4u
} -> %func_end # return
} %func_end
@@ -334,14 +341,15 @@
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = block {
- %v1:ref<private, bool, read_write> = var private, read_write
+ %v1:ptr<private, bool, read_write> = var
}
%fn2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn3 = block {
- %2:ref<private, bool, read_write> = and %v1:ref<private, bool, read_write>, false
- store %v1:ref<private, bool, read_write>, %2:ref<private, bool, read_write>
+ %2:bool = load %v1
+ %3:bool = and %2, false
+ store %v1, %3
} -> %func_end # return
} %func_end
@@ -364,7 +372,7 @@
%fn3 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn4 = block {
%1:u32 = call my_func
- %tint_symbol:u32 = or %1:u32, 4u
+ %tint_symbol:u32 = or %1, 4u
} -> %func_end # return
} %func_end
@@ -380,14 +388,15 @@
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = block {
- %v1:ref<private, bool, read_write> = var private, read_write
+ %v1:ptr<private, bool, read_write> = var
}
%fn2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn3 = block {
- %2:ref<private, bool, read_write> = or %v1:ref<private, bool, read_write>, false
- store %v1:ref<private, bool, read_write>, %2:ref<private, bool, read_write>
+ %2:bool = load %v1
+ %3:bool = or %2, false
+ store %v1, %3
} -> %func_end # return
} %func_end
@@ -410,7 +419,7 @@
%fn3 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn4 = block {
%1:u32 = call my_func
- %tint_symbol:u32 = xor %1:u32, 4u
+ %tint_symbol:u32 = xor %1, 4u
} -> %func_end # return
} %func_end
@@ -426,14 +435,15 @@
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = block {
- %v1:ref<private, u32, read_write> = var private, read_write
+ %v1:ptr<private, u32, read_write> = var
}
%fn2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn3 = block {
- %2:ref<private, u32, read_write> = xor %v1:ref<private, u32, read_write>, 1u
- store %v1:ref<private, u32, read_write>, %2:ref<private, u32, read_write>
+ %2:u32 = load %v1
+ %3:u32 = xor %2, 1u
+ store %v1, %3
} -> %func_end # return
} %func_end
@@ -442,7 +452,7 @@
TEST_F(IR_BuilderImplTest, EmitExpression_Binary_LogicalAnd) {
Func("my_func", utils::Empty, ty.bool_(), utils::Vector{Return(true)});
- auto* expr = LogicalAnd(Call("my_func"), false);
+ auto* expr = If(LogicalAnd(Call("my_func"), false), Block());
WrapInFunction(expr);
auto m = Build();
@@ -456,18 +466,32 @@
%fn3 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn4 = block {
%1:bool = call my_func
- %tint_symbol:bool = var function, read_write
- store %tint_symbol:bool, %1:bool
} -> %fn5 # branch
- %fn5 = if %1:bool [t: %fn6, f: %fn7, m: %fn8]
+ %fn5 = if %1 [t: %fn6, f: %fn7, m: %fn8]
# true branch
%fn6 = block {
- store %tint_symbol:bool, false
- } -> %fn8 # branch
+ } -> %fn8 false # branch
+
+ # false branch
+ %fn7 = block {
+ } -> %fn8 %1 # branch
# if merge
- %fn8 = block {
+ %fn8 = block (%2:bool) {
+ } -> %fn9 # branch
+
+ %fn9 = if %2:bool [t: %fn10, f: %fn11, m: %fn12]
+ # true branch
+ %fn10 = block {
+ } -> %fn12 # branch
+
+ # false branch
+ %fn11 = block {
+ } -> %fn12 # branch
+
+ # if merge
+ %fn12 = block {
} -> %func_end # return
} %func_end
@@ -476,7 +500,7 @@
TEST_F(IR_BuilderImplTest, EmitExpression_Binary_LogicalOr) {
Func("my_func", utils::Empty, ty.bool_(), utils::Vector{Return(true)});
- auto* expr = LogicalOr(Call("my_func"), true);
+ auto* expr = If(LogicalOr(Call("my_func"), true), Block());
WrapInFunction(expr);
auto m = Build();
@@ -490,19 +514,32 @@
%fn3 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn4 = block {
%1:bool = call my_func
- %tint_symbol:bool = var function, read_write
- store %tint_symbol:bool, %1:bool
} -> %fn5 # branch
- %fn5 = if %1:bool [t: %fn6, f: %fn7, m: %fn8]
+ %fn5 = if %1 [t: %fn6, f: %fn7, m: %fn8]
# true branch
+ %fn6 = block {
+ } -> %fn8 %1 # branch
+
# false branch
%fn7 = block {
- store %tint_symbol:bool, true
- } -> %fn8 # branch
+ } -> %fn8 true # branch
# if merge
- %fn8 = block {
+ %fn8 = block (%2:bool) {
+ } -> %fn9 # branch
+
+ %fn9 = if %2:bool [t: %fn10, f: %fn11, m: %fn12]
+ # true branch
+ %fn10 = block {
+ } -> %fn12 # branch
+
+ # false branch
+ %fn11 = block {
+ } -> %fn12 # branch
+
+ # if merge
+ %fn12 = block {
} -> %func_end # return
} %func_end
@@ -525,7 +562,7 @@
%fn3 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn4 = block {
%1:u32 = call my_func
- %tint_symbol:bool = eq %1:u32, 4u
+ %tint_symbol:bool = eq %1, 4u
} -> %func_end # return
} %func_end
@@ -548,7 +585,7 @@
%fn3 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn4 = block {
%1:u32 = call my_func
- %tint_symbol:bool = neq %1:u32, 4u
+ %tint_symbol:bool = neq %1, 4u
} -> %func_end # return
} %func_end
@@ -571,7 +608,7 @@
%fn3 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn4 = block {
%1:u32 = call my_func
- %tint_symbol:bool = lt %1:u32, 4u
+ %tint_symbol:bool = lt %1, 4u
} -> %func_end # return
} %func_end
@@ -594,7 +631,7 @@
%fn3 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn4 = block {
%1:u32 = call my_func
- %tint_symbol:bool = gt %1:u32, 4u
+ %tint_symbol:bool = gt %1, 4u
} -> %func_end # return
} %func_end
@@ -617,7 +654,7 @@
%fn3 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn4 = block {
%1:u32 = call my_func
- %tint_symbol:bool = lte %1:u32, 4u
+ %tint_symbol:bool = lte %1, 4u
} -> %func_end # return
} %func_end
@@ -640,7 +677,7 @@
%fn3 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn4 = block {
%1:u32 = call my_func
- %tint_symbol:bool = gte %1:u32, 4u
+ %tint_symbol:bool = gte %1, 4u
} -> %func_end # return
} %func_end
@@ -663,7 +700,7 @@
%fn3 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn4 = block {
%1:u32 = call my_func
- %tint_symbol:u32 = shiftl %1:u32, 4u
+ %tint_symbol:u32 = shiftl %1, 4u
} -> %func_end # return
} %func_end
@@ -679,14 +716,15 @@
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = block {
- %v1:ref<private, u32, read_write> = var private, read_write
+ %v1:ptr<private, u32, read_write> = var
}
%fn2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn3 = block {
- %2:ref<private, u32, read_write> = shiftl %v1:ref<private, u32, read_write>, 1u
- store %v1:ref<private, u32, read_write>, %2:ref<private, u32, read_write>
+ %2:u32 = load %v1
+ %3:u32 = shiftl %2, 1u
+ store %v1, %3
} -> %func_end # return
} %func_end
@@ -709,7 +747,7 @@
%fn3 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn4 = block {
%1:u32 = call my_func
- %tint_symbol:u32 = shiftr %1:u32, 4u
+ %tint_symbol:u32 = shiftr %1, 4u
} -> %func_end # return
} %func_end
@@ -725,14 +763,15 @@
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = block {
- %v1:ref<private, u32, read_write> = var private, read_write
+ %v1:ptr<private, u32, read_write> = var
}
%fn2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn3 = block {
- %2:ref<private, u32, read_write> = shiftr %v1:ref<private, u32, read_write>, 1u
- store %v1:ref<private, u32, read_write>, %2:ref<private, u32, read_write>
+ %2:u32 = load %v1
+ %3:u32 = shiftr %2, 1u
+ store %v1, %3
} -> %func_end # return
} %func_end
@@ -757,24 +796,25 @@
%fn3 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn4 = block {
%1:f32 = call my_func
- %2:bool = lt %1:f32, 2.0f
- %tint_symbol:bool = var function, read_write
- store %tint_symbol:bool, %2:bool
+ %2:bool = lt %1, 2.0f
} -> %fn5 # branch
- %fn5 = if %2:bool [t: %fn6, f: %fn7, m: %fn8]
+ %fn5 = if %2 [t: %fn6, f: %fn7, m: %fn8]
# true branch
%fn6 = block {
+ %3:f32 = call my_func
%4:f32 = call my_func
- %5:f32 = call my_func
- %6:f32 = mul 2.29999995231628417969f, %5:f32
- %7:f32 = div %4:f32, %6:f32
- %8:bool = gt 2.5f, %7:f32
- store %tint_symbol:bool, %8:bool
- } -> %fn8 # branch
+ %5:f32 = mul 2.29999995231628417969f, %4
+ %6:f32 = div %3, %5
+ %7:bool = gt 2.5f, %6
+ } -> %fn8 %7 # branch
+
+ # false branch
+ %fn7 = block {
+ } -> %fn8 %2 # branch
# if merge
- %fn8 = block {
+ %fn8 = block (%tint_symbol:bool) {
} -> %func_end # return
} %func_end
diff --git a/src/tint/ir/from_program_call_test.cc b/src/tint/ir/from_program_call_test.cc
index ae217d1..038f4dd 100644
--- a/src/tint/ir/from_program_call_test.cc
+++ b/src/tint/ir/from_program_call_test.cc
@@ -43,7 +43,7 @@
%fn3 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn4 = block {
%1:f32 = call my_func
- %tint_symbol:f32 = bitcast %1:f32
+ %tint_symbol:f32 = bitcast %1
} -> %func_end # return
} %func_end
@@ -100,13 +100,14 @@
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = block {
- %i:ref<private, i32, read_write> = var private, read_write, 1i
+ %i:ptr<private, i32, read_write> = var, 1i
}
%fn2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn3 = block {
- %tint_symbol:f32 = convert i32, %i:ref<private, i32, read_write>
+ %2:i32 = load %i
+ %tint_symbol:f32 = convert i32, %2
} -> %func_end # return
} %func_end
@@ -121,7 +122,7 @@
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = block {
- %i:ref<private, vec3<f32>, read_write> = var private, read_write, vec3<f32> 0.0f
+ %i:ptr<private, vec3<f32>, read_write> = var, vec3<f32> 0.0f
}
@@ -137,13 +138,14 @@
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = block {
- %i:ref<private, f32, read_write> = var private, read_write, 1.0f
+ %i:ptr<private, f32, read_write> = var, 1.0f
}
%fn2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn3 = block {
- %tint_symbol:vec3<f32> = construct 2.0f, 3.0f, %i:ref<private, f32, read_write>
+ %2:f32 = load %i
+ %tint_symbol:vec3<f32> = construct 2.0f, 3.0f, %2
} -> %func_end # return
} %func_end
diff --git a/src/tint/ir/from_program_store_test.cc b/src/tint/ir/from_program_store_test.cc
index 4e89dbe..5bb6398 100644
--- a/src/tint/ir/from_program_store_test.cc
+++ b/src/tint/ir/from_program_store_test.cc
@@ -36,13 +36,13 @@
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = block {
- %a:ref<private, u32, read_write> = var private, read_write
+ %a:ptr<private, u32, read_write> = var
}
%fn2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn3 = block {
- store %a:ref<private, u32, read_write>, 4u
+ store %a, 4u
} -> %func_end # return
} %func_end
diff --git a/src/tint/ir/from_program_test.cc b/src/tint/ir/from_program_test.cc
index 234b6cc..47dc911 100644
--- a/src/tint/ir/from_program_test.cc
+++ b/src/tint/ir/from_program_test.cc
@@ -1427,5 +1427,27 @@
)");
}
+TEST_F(IR_BuilderImplTest, Emit_Phony) {
+ Func("b", utils::Empty, ty.i32(), Return(1_i));
+ WrapInFunction(Ignore(Call("b")));
+
+ auto m = Build();
+ ASSERT_TRUE(m) << (!m ? m.Failure() : "");
+
+ EXPECT_EQ(Disassemble(m.Get()),
+ R"(%fn1 = func b():i32 {
+ %fn2 = block {
+ } -> %func_end 1i # return
+} %func_end
+
+%fn3 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
+ %fn4 = block {
+ %1:i32 = call b
+ } -> %func_end # return
+} %func_end
+
+)");
+}
+
} // namespace
} // namespace tint::ir
diff --git a/src/tint/ir/from_program_unary_test.cc b/src/tint/ir/from_program_unary_test.cc
index b2ba2d7..2afd2a2 100644
--- a/src/tint/ir/from_program_unary_test.cc
+++ b/src/tint/ir/from_program_unary_test.cc
@@ -42,7 +42,7 @@
%fn3 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn4 = block {
%1:bool = call my_func
- %tint_symbol:bool = eq %1:bool, false
+ %tint_symbol:bool = eq %1, false
} -> %func_end # return
} %func_end
@@ -65,7 +65,7 @@
%fn3 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn4 = block {
%1:u32 = call my_func
- %tint_symbol:u32 = complement %1:u32
+ %tint_symbol:u32 = complement %1
} -> %func_end # return
} %func_end
@@ -88,7 +88,7 @@
%fn3 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn4 = block {
%1:i32 = call my_func
- %tint_symbol:i32 = negation %1:i32
+ %tint_symbol:i32 = negation %1
} -> %func_end # return
} %func_end
@@ -105,13 +105,12 @@
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = block {
- %v1:ref<private, i32, read_write> = var private, read_write
+ %v2:ptr<private, i32, read_write> = var
}
%fn2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn3 = block {
- %v2:ptr<private, i32, read_write> = addr_of %v1:ref<private, i32, read_write>
} -> %func_end # return
} %func_end
@@ -122,7 +121,7 @@
GlobalVar("v1", builtin::AddressSpace::kPrivate, ty.i32());
utils::Vector stmts = {
Decl(Let("v3", AddressOf("v1"))),
- Decl(Let("v2", Deref("v3"))),
+ Assign(Deref("v3"), 42_i),
};
WrapInFunction(stmts);
@@ -130,14 +129,13 @@
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = block {
- %v1:ref<private, i32, read_write> = var private, read_write
+ %v3:ptr<private, i32, read_write> = var
}
%fn2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn3 = block {
- %v3:ptr<private, i32, read_write> = addr_of %v1:ref<private, i32, read_write>
- %v2:i32 = indirection %v3:ptr<private, i32, read_write>
+ store %v3, 42i
} -> %func_end # return
} %func_end
diff --git a/src/tint/ir/from_program_var_test.cc b/src/tint/ir/from_program_var_test.cc
index b5379ac..e235f88 100644
--- a/src/tint/ir/from_program_var_test.cc
+++ b/src/tint/ir/from_program_var_test.cc
@@ -33,7 +33,7 @@
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = block {
- %a:ref<private, u32, read_write> = var private, read_write
+ %a:ptr<private, u32, read_write> = var
}
@@ -48,7 +48,7 @@
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
EXPECT_EQ(Disassemble(m.Get()), R"(%fn1 = block {
- %a:ref<private, u32, read_write> = var private, read_write, 2u
+ %a:ptr<private, u32, read_write> = var, 2u
}
@@ -65,7 +65,7 @@
EXPECT_EQ(Disassemble(m.Get()),
R"(%fn1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn2 = block {
- %a:ref<function, u32, read_write> = var function, read_write
+ %a:ptr<function, u32, read_write> = var
} -> %func_end # return
} %func_end
@@ -83,7 +83,7 @@
EXPECT_EQ(Disassemble(m.Get()),
R"(%fn1 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn2 = block {
- %a:ref<function, u32, read_write> = var function, read_write, 2u
+ %a:ptr<function, u32, read_write> = var, 2u
} -> %func_end # return
} %func_end
diff --git a/src/tint/writer/spirv/generator_impl_ir_test.cc b/src/tint/ir/load.cc
similarity index 60%
copy from src/tint/writer/spirv/generator_impl_ir_test.cc
copy to src/tint/ir/load.cc
index a202eea..1fe55c0 100644
--- a/src/tint/writer/spirv/generator_impl_ir_test.cc
+++ b/src/tint/ir/load.cc
@@ -12,18 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/writer/spirv/test_helper_ir.h"
+#include "src/tint/ir/load.h"
+#include "src/tint/debug.h"
-namespace tint::writer::spirv {
-namespace {
+TINT_INSTANTIATE_TYPEINFO(tint::ir::Load);
-TEST_F(SpvGeneratorImplTest, ModuleHeader) {
- ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
- auto got = Disassemble(generator_.Result());
- EXPECT_EQ(got, R"(OpCapability Shader
-OpMemoryModel Logical GLSL450
-)");
+namespace tint::ir {
+
+Load::Load(const type::Type* type, Value* f) : Base(), result_type(type), from(f) {
+ TINT_ASSERT(IR, result_type);
+ TINT_ASSERT(IR, from);
+ from->AddUsage(this);
}
-} // namespace
-} // namespace tint::writer::spirv
+Load::~Load() = default;
+
+} // namespace tint::ir
diff --git a/src/tint/ir/load.h b/src/tint/ir/load.h
new file mode 100644
index 0000000..b15eced
--- /dev/null
+++ b/src/tint/ir/load.h
@@ -0,0 +1,49 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_IR_LOAD_H_
+#define SRC_TINT_IR_LOAD_H_
+
+#include "src/tint/ir/instruction.h"
+#include "src/tint/utils/castable.h"
+
+namespace tint::ir {
+
+/// A load instruction in the IR.
+class Load : public utils::Castable<Load, Instruction> {
+ public:
+ /// Constructor
+ /// @param type the result type
+ /// @param from the value being loaded from
+ Load(const type::Type* type, Value* from);
+ Load(const Load& inst) = delete;
+ Load(Load&& inst) = delete;
+ ~Load() override;
+
+ Load& operator=(const Load& inst) = delete;
+ Load& operator=(Load&& inst) = delete;
+
+ /// @returns the type of the value
+ const type::Type* Type() const override { return result_type; }
+
+ /// the result type of the instruction
+ const type::Type* result_type = nullptr;
+
+ /// the value being loaded
+ Value* from = nullptr;
+};
+
+} // namespace tint::ir
+
+#endif // SRC_TINT_IR_LOAD_H_
diff --git a/src/tint/ir/load_test.cc b/src/tint/ir/load_test.cc
new file mode 100644
index 0000000..2c6e5c1
--- /dev/null
+++ b/src/tint/ir/load_test.cc
@@ -0,0 +1,59 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/builder.h"
+#include "src/tint/ir/instruction.h"
+#include "src/tint/ir/test_helper.h"
+
+namespace tint::ir {
+namespace {
+
+using namespace tint::number_suffixes; // NOLINT
+
+using IR_InstructionTest = TestHelper;
+
+TEST_F(IR_InstructionTest, CreateLoad) {
+ Module mod;
+ Builder b{mod};
+
+ auto* store_type = b.ir.types.Get<type::I32>();
+ auto* var = b.Declare(b.ir.types.Get<type::Pointer>(
+ store_type, builtin::AddressSpace::kFunction, builtin::Access::kReadWrite));
+ const auto* inst = b.Load(var);
+
+ ASSERT_TRUE(inst->Is<Load>());
+ ASSERT_EQ(inst->from, var);
+
+ EXPECT_EQ(inst->Type(), store_type);
+
+ ASSERT_TRUE(inst->from->Is<ir::Var>());
+ EXPECT_EQ(inst->from, var);
+}
+
+TEST_F(IR_InstructionTest, Load_Usage) {
+ Module mod;
+ Builder b{mod};
+
+ auto* store_type = b.ir.types.Get<type::I32>();
+ auto* var = b.Declare(b.ir.types.Get<type::Pointer>(
+ store_type, builtin::AddressSpace::kFunction, builtin::Access::kReadWrite));
+ const auto* inst = b.Load(var);
+
+ ASSERT_NE(inst->from, nullptr);
+ ASSERT_EQ(inst->from->Usage().Length(), 1u);
+ EXPECT_EQ(inst->from->Usage()[0], inst);
+}
+
+} // namespace
+} // namespace tint::ir
diff --git a/src/tint/ir/module.h b/src/tint/ir/module.h
index a7929d2..244bc0a 100644
--- a/src/tint/ir/module.h
+++ b/src/tint/ir/module.h
@@ -70,8 +70,6 @@
utils::BlockAllocator<constant::Value> constants;
/// The value allocator
utils::BlockAllocator<Value> values;
- /// The instruction allocator
- utils::BlockAllocator<Instruction> instructions;
/// List of functions in the program
utils::Vector<Function*, 8> functions;
diff --git a/src/tint/ir/module_test.cc b/src/tint/ir/module_test.cc
index c9b1150..9cd36ae 100644
--- a/src/tint/ir/module_test.cc
+++ b/src/tint/ir/module_test.cc
@@ -25,23 +25,20 @@
TEST_F(IR_ModuleTest, NameOfUnnamed) {
Module mod;
- auto* v = mod.values.Create<ir::Var>(
- mod.types.Get<type::I32>(), builtin::AddressSpace::kUndefined, builtin::Access::kUndefined);
+ auto* v = mod.values.Create<ir::Var>(mod.types.Get<type::I32>());
EXPECT_FALSE(mod.NameOf(v).IsValid());
}
TEST_F(IR_ModuleTest, SetName) {
Module mod;
- auto* v = mod.values.Create<ir::Var>(
- mod.types.Get<type::I32>(), builtin::AddressSpace::kUndefined, builtin::Access::kUndefined);
+ auto* v = mod.values.Create<ir::Var>(mod.types.Get<type::I32>());
EXPECT_EQ(mod.SetName(v, "a").Name(), "a");
EXPECT_EQ(mod.NameOf(v).Name(), "a");
}
TEST_F(IR_ModuleTest, SetNameRename) {
Module mod;
- auto* v = mod.values.Create<ir::Var>(
- mod.types.Get<type::I32>(), builtin::AddressSpace::kUndefined, builtin::Access::kUndefined);
+ auto* v = mod.values.Create<ir::Var>(mod.types.Get<type::I32>());
EXPECT_EQ(mod.SetName(v, "a").Name(), "a");
EXPECT_EQ(mod.SetName(v, "b").Name(), "b");
EXPECT_EQ(mod.NameOf(v).Name(), "b");
@@ -49,12 +46,9 @@
TEST_F(IR_ModuleTest, SetNameCollision) {
Module mod;
- auto* a = mod.values.Create<ir::Var>(
- mod.types.Get<type::I32>(), builtin::AddressSpace::kUndefined, builtin::Access::kUndefined);
- auto* b = mod.values.Create<ir::Var>(
- mod.types.Get<type::I32>(), builtin::AddressSpace::kUndefined, builtin::Access::kUndefined);
- auto* c = mod.values.Create<ir::Var>(
- mod.types.Get<type::I32>(), builtin::AddressSpace::kUndefined, builtin::Access::kUndefined);
+ auto* a = mod.values.Create<ir::Var>(mod.types.Get<type::I32>());
+ auto* b = mod.values.Create<ir::Var>(mod.types.Get<type::I32>());
+ auto* c = mod.values.Create<ir::Var>(mod.types.Get<type::I32>());
EXPECT_EQ(mod.SetName(a, "x").Name(), "x");
EXPECT_EQ(mod.SetName(b, "x_1").Name(), "x_1");
EXPECT_EQ(mod.SetName(c, "x").Name(), "x_2");
diff --git a/src/tint/ir/to_program.cc b/src/tint/ir/to_program.cc
index 8a2dca9..b1aa67f 100644
--- a/src/tint/ir/to_program.cc
+++ b/src/tint/ir/to_program.cc
@@ -14,15 +14,19 @@
#include "src/tint/ir/to_program.h"
+#include <string>
#include <utility>
#include "src/tint/ir/block.h"
#include "src/tint/ir/call.h"
#include "src/tint/ir/constant.h"
+#include "src/tint/ir/function_terminator.h"
#include "src/tint/ir/if.h"
#include "src/tint/ir/instruction.h"
+#include "src/tint/ir/load.h"
#include "src/tint/ir/module.h"
#include "src/tint/ir/store.h"
+#include "src/tint/ir/switch.h"
#include "src/tint/ir/user_call.h"
#include "src/tint/ir/var.h"
#include "src/tint/program_builder.h"
@@ -40,6 +44,17 @@
#include "src/tint/utils/transform.h"
#include "src/tint/utils/vector.h"
+// Helper for calling TINT_UNIMPLEMENTED() from a Switch(object_ptr) default case.
+#define UNHANDLED_CASE(object_ptr) \
+ TINT_UNIMPLEMENTED(IR, b.Diagnostics()) \
+ << "unhandled case in Switch(): " << (object_ptr ? object_ptr->TypeInfo().name : "<null>")
+
+// Helper for incrementing nesting_depth_ and then decrementing nesting_depth_ at the end
+// of the scope that holds the call.
+#define SCOPED_NESTING() \
+ nesting_depth_++; \
+ TINT_DEFER(nesting_depth_--)
+
namespace tint::ir {
namespace {
@@ -58,76 +73,264 @@
}
private:
+ /// The source IR module
const Module& mod;
+
+ /// The target ProgramBuilder
ProgramBuilder b;
+
+ /// A hashmap of value to symbol used in the emitted AST
utils::Hashmap<const Value*, Symbol, 32> value_names_;
- void Fn(const Function* fn) {
+ // The nesting depth of the currently generated AST
+ // 0 is module scope
+ // 1 is root-level function scope
+ // 2+ is within control flow
+ uint32_t nesting_depth_ = 0;
+
+ const ast::Function* Fn(const Function* fn) {
+ SCOPED_NESTING();
+
auto name = Sym(fn->name);
// TODO(crbug.com/tint/1915): Properly implement this when we've fleshed out Function
utils::Vector<const ast::Parameter*, 1> params{};
- ast::Type ret_ty;
- auto* body = Block(fn->start_target);
+ auto ret_ty = Type(fn->return_type);
+ if (!ret_ty) {
+ return nullptr;
+ }
+ auto* body = FlowNodeGraph(fn->start_target);
+ if (!body) {
+ return nullptr;
+ }
utils::Vector<const ast::Attribute*, 1> attrs{};
utils::Vector<const ast::Attribute*, 1> ret_attrs{};
- b.Func(name, std::move(params), ret_ty, body, std::move(attrs), std::move(ret_attrs));
+ return b.Func(name, std::move(params), ret_ty.Get(), body, std::move(attrs),
+ std::move(ret_attrs));
}
- const ast::BlockStatement* Block(const ir::Block* block) {
+ const ast::BlockStatement* FlowNodeGraph(ir::FlowNode* start_node,
+ ir::FlowNode* stop_at = nullptr) {
// TODO(crbug.com/tint/1902): Check if the block is dead
- utils::Vector<const ast::Statement*, decltype(ir::Block::instructions)::static_length>
+ utils::Vector<const ast::Statement*,
+ decltype(ast::BlockStatement::statements)::static_length>
stmts;
- for (auto* inst : block->instructions) {
- auto* stmt = Stmt(inst);
- if (!stmt) {
+
+ ir::Branch root_branch{start_node, {}};
+ const ir::Branch* branch = &root_branch;
+
+ // TODO(crbug.com/tint/1902): Handle block arguments.
+
+ while (branch->target != stop_at) {
+ enum Status { kContinue, kStop, kError };
+ Status status = tint::Switch(
+ branch->target,
+
+ [&](const ir::Block* block) {
+ for (auto* inst : block->instructions) {
+ auto stmt = Stmt(inst);
+ if (TINT_UNLIKELY(!stmt)) {
+ return kError;
+ }
+ if (auto* s = stmt.Get()) {
+ stmts.Push(s);
+ }
+ }
+ branch = &block->branch;
+ return kContinue;
+ },
+
+ [&](const ir::If* if_) {
+ auto* stmt = If(if_);
+ if (TINT_UNLIKELY(!stmt)) {
+ return kError;
+ }
+ stmts.Push(stmt);
+ branch = &if_->merge;
+ return branch->target->inbound_branches.IsEmpty() ? kStop : kContinue;
+ },
+
+ [&](const ir::Switch* switch_) {
+ auto* stmt = Switch(switch_);
+ if (TINT_UNLIKELY(!stmt)) {
+ return kError;
+ }
+ stmts.Push(stmt);
+ branch = &switch_->merge;
+ return branch->target->inbound_branches.IsEmpty() ? kStop : kContinue;
+ },
+
+ [&](const ir::FunctionTerminator*) {
+ auto res = FunctionTerminator(branch);
+ if (TINT_UNLIKELY(!res)) {
+ return kError;
+ }
+ if (auto* stmt = res.Get()) {
+ stmts.Push(stmt);
+ }
+ return kStop;
+ },
+
+ [&](Default) {
+ UNHANDLED_CASE(branch->target);
+ return kError;
+ });
+
+ if (TINT_UNLIKELY(status == kError)) {
return nullptr;
}
- stmts.Push(stmt);
+ if (status == kStop) {
+ break;
+ }
}
+
return b.Block(std::move(stmts));
}
- const ast::Statement* FlowNode(const ir::FlowNode* node) {
- // TODO(crbug.com/tint/1902): Check the node is connected
- return Switch(
- node, //
- [&](const ir::If* i) {
- auto* cond = Expr(i->condition);
- auto* t = Branch(i->true_);
- if (auto* f = Branch(i->false_)) {
- return b.If(cond, t, b.Else(f));
- }
- // TODO(crbug.com/tint/1902): Emit merge block
- return b.If(cond, t);
- },
- [&](Default) {
- TINT_UNIMPLEMENTED(IR, b.Diagnostics())
- << "unhandled case in Switch(): " << node->TypeInfo().name;
- return nullptr;
- });
- }
+ const ast::IfStatement* If(const ir::If* i) {
+ SCOPED_NESTING();
- const ast::BlockStatement* Branch(const ir::Branch& branch) {
- auto* stmt = FlowNode(branch.target);
- if (!stmt) {
+ auto* cond = Expr(i->condition);
+ auto* t = FlowNodeGraph(i->true_.target, i->merge.target);
+ if (TINT_UNLIKELY(!t)) {
return nullptr;
}
- if (auto* block = stmt->As<ast::BlockStatement>()) {
- return block;
+
+ if (!IsEmpty(i->false_.target, i->merge.target)) {
+ // If the else target is an if flow node with the same merge target as this if, then
+ // emit an 'else if' instead of a block statement for the else.
+ if (auto* else_if = As<ir::If>(NextNonEmptyNode(i->false_.target));
+ else_if &&
+ NextNonEmptyNode(i->merge.target) == NextNonEmptyNode(else_if->merge.target)) {
+ auto* f = If(else_if);
+ if (!f) {
+ return nullptr;
+ }
+ return b.If(cond, t, b.Else(f));
+ } else {
+ auto* f = FlowNodeGraph(i->false_.target, i->merge.target);
+ if (!f) {
+ return nullptr;
+ }
+ return b.If(cond, t, b.Else(f));
+ }
}
- return b.Block(stmt);
+
+ return b.If(cond, t);
}
- const ast::Statement* Stmt(const ir::Instruction* inst) {
- return Switch(
+ const ast::SwitchStatement* Switch(const ir::Switch* s) {
+ SCOPED_NESTING();
+
+ auto* cond = Expr(s->condition);
+ if (!cond) {
+ return nullptr;
+ }
+
+ auto cases = utils::Transform(
+ s->cases, //
+ [&](const ir::Switch::Case& c) -> const tint::ast::CaseStatement* {
+ SCOPED_NESTING();
+ auto* body = FlowNodeGraph(c.start.target, s->merge.target);
+ if (!body) {
+ return nullptr;
+ }
+
+ auto selectors = utils::Transform(
+ c.selectors, //
+ [&](const ir::Switch::CaseSelector& cs) -> const ast::CaseSelector* {
+ if (cs.IsDefault()) {
+ return b.DefaultCaseSelector();
+ }
+ auto* expr = Expr(cs.val);
+ if (!expr) {
+ return nullptr;
+ }
+ return b.CaseSelector(expr);
+ });
+ if (selectors.Any(utils::IsNull)) {
+ return nullptr;
+ }
+
+ return b.Case(std::move(selectors), body);
+ });
+ if (cases.Any(utils::IsNull)) {
+ return nullptr;
+ }
+
+ return b.Switch(cond, std::move(cases));
+ }
+
+ utils::Result<const ast::ReturnStatement*> FunctionTerminator(const ir::Branch* branch) {
+ if (branch->args.IsEmpty()) {
+ // Branch to function terminator has no arguments.
+ // If this block is nested withing some control flow, then we must emit a
+ // 'return' statement, otherwise we've just naturally reached the end of the
+ // function where the 'return' is redundant.
+ if (nesting_depth_ > 1) {
+ return b.Return();
+ }
+ return nullptr;
+ }
+
+ // Branch to function terminator has arguments - this is the return value.
+ if (branch->args.Length() != 1) {
+ TINT_ICE(IR, b.Diagnostics())
+ << "expected 1 value for function terminator (return value), got "
+ << branch->args.Length();
+ return utils::Failure;
+ }
+
+ auto* val = Expr(branch->args.Front());
+ if (TINT_UNLIKELY(!val)) {
+ return utils::Failure;
+ }
+
+ return b.Return(val);
+ }
+
+ /// @return true if there are no instructions between @p node and and @p stop_at
+ bool IsEmpty(const ir::FlowNode* node, const ir::FlowNode* stop_at) {
+ while (node != stop_at) {
+ if (auto* block = node->As<ir::Block>()) {
+ if (block->instructions.Length() > 0) {
+ return false;
+ }
+ node = block->branch.target;
+ } else {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ /// @return the next flow node that isn't an empty block
+ const ir::FlowNode* NextNonEmptyNode(const ir::FlowNode* node) {
+ while (node) {
+ if (auto* block = node->As<ir::Block>()) {
+ for (auto* inst : block->instructions) {
+ // Load instructions will be inlined, so ignore them.
+ if (!inst->Is<ir::Load>()) {
+ return node;
+ }
+ }
+ node = block->branch.target;
+ } else {
+ return node;
+ }
+ }
+ return nullptr;
+ }
+
+ utils::Result<const ast::Statement*> Stmt(const ir::Instruction* inst) {
+ return tint::Switch<utils::Result<const ast::Statement*>>(
inst, //
[&](const ir::Call* i) { return CallStmt(i); }, //
[&](const ir::Var* i) { return Var(i); }, //
- [&](const ir::Store* i) { return Store(i); },
+ [&](const ir::Load*) { return nullptr; },
+ [&](const ir::Store* i) { return Store(i); }, //
[&](Default) {
- TINT_UNIMPLEMENTED(IR, b.Diagnostics())
- << "unhandled case in Switch(): " << inst->TypeInfo().name;
- return nullptr;
+ UNHANDLED_CASE(inst);
+ return utils::Failure;
});
}
@@ -141,7 +344,12 @@
const ast::VariableDeclStatement* Var(const ir::Var* var) {
Symbol name = NameOf(var);
- auto ty = Type(var->Type());
+ auto* ptr = var->Type()->As<type::Pointer>();
+ if (!ptr) {
+ Err("Incorrect type for var");
+ return nullptr;
+ }
+ auto ty = Type(ptr->StoreType());
const ast::Expression* init = nullptr;
if (var->initializer) {
init = Expr(var->initializer);
@@ -149,13 +357,13 @@
return nullptr;
}
}
- switch (var->address_space) {
+ switch (ptr->AddressSpace()) {
case builtin::AddressSpace::kFunction:
- return b.Decl(b.Var(name, ty, init));
+ return b.Decl(b.Var(name, ty.Get(), init));
case builtin::AddressSpace::kStorage:
- return b.Decl(b.Var(name, ty, init, var->access, var->address_space));
+ return b.Decl(b.Var(name, ty.Get(), init, ptr->Access(), ptr->AddressSpace()));
default:
- return b.Decl(b.Var(name, ty, init, var->address_space));
+ return b.Decl(b.Var(name, ty.Get(), init, ptr->AddressSpace()));
}
}
@@ -169,29 +377,29 @@
if (args.Any(utils::IsNull)) {
return nullptr;
}
- return Switch(
+ return tint::Switch(
call, //
[&](const ir::UserCall* c) { return b.Call(Sym(c->name), std::move(args)); },
[&](Default) {
- TINT_UNIMPLEMENTED(IR, b.Diagnostics())
- << "unhandled case in Switch(): " << call->TypeInfo().name;
+ UNHANDLED_CASE(call);
return nullptr;
});
}
const ast::Expression* Expr(const ir::Value* val) {
- return Switch(
+ return tint::Switch(
val, //
[&](const ir::Constant* c) { return ConstExpr(c); },
+ [&](const ir::Load* l) { return LoadExpr(l); },
+ [&](const ir::Var* v) { return VarExpr(v); },
[&](Default) {
- TINT_UNIMPLEMENTED(IR, b.Diagnostics())
- << "unhandled case in Switch(): " << val->TypeInfo().name;
+ UNHANDLED_CASE(val);
return nullptr;
});
}
const ast::Expression* ConstExpr(const ir::Constant* c) {
- return Switch(
+ return tint::Switch(
c->Type(), //
[&](const type::I32*) { return b.Expr(c->value->ValueAs<i32>()); },
[&](const type::U32*) { return b.Expr(c->value->ValueAs<u32>()); },
@@ -199,14 +407,17 @@
[&](const type::F16*) { return b.Expr(c->value->ValueAs<f16>()); },
[&](const type::Bool*) { return b.Expr(c->value->ValueAs<bool>()); },
[&](Default) {
- TINT_UNIMPLEMENTED(IR, b.Diagnostics())
- << "unhandled case in Switch(): " << c->TypeInfo().name;
+ UNHANDLED_CASE(c);
return nullptr;
});
}
- const ast::Type Type(const type::Type* ty) {
- return Switch(
+ const ast::Expression* LoadExpr(const ir::Load* l) { return Expr(l->from); }
+
+ const ast::Expression* VarExpr(const ir::Var* v) { return b.Expr(NameOf(v)); }
+
+ utils::Result<ast::Type> Type(const type::Type* ty) {
+ return tint::Switch<utils::Result<ast::Type>>(
ty, //
[&](const type::Void*) { return ast::Type{}; }, //
[&](const type::I32*) { return b.ty.i32(); }, //
@@ -214,64 +425,94 @@
[&](const type::F16*) { return b.ty.f16(); }, //
[&](const type::F32*) { return b.ty.f32(); }, //
[&](const type::Bool*) { return b.ty.bool_(); },
- [&](const type::Matrix* m) {
+ [&](const type::Matrix* m) -> utils::Result<ast::Type> {
auto el = Type(m->type());
- return b.ty.mat(el, m->columns(), m->rows());
+ if (!el) {
+ return utils::Failure;
+ }
+ return b.ty.mat(el.Get(), m->columns(), m->rows());
},
- [&](const type::Vector* v) {
+ [&](const type::Vector* v) -> utils::Result<ast::Type> {
auto el = Type(v->type());
+ if (!el) {
+ return utils::Failure;
+ }
if (v->Packed()) {
TINT_ASSERT(IR, v->Width() == 3u);
- return b.ty(builtin::Builtin::kPackedVec3, el);
+ return b.ty(builtin::Builtin::kPackedVec3, el.Get());
} else {
- return b.ty.vec(el, v->Width());
+ return b.ty.vec(el.Get(), v->Width());
}
},
- [&](const type::Array* a) {
+ [&](const type::Array* a) -> utils::Result<ast::Type> {
auto el = Type(a->ElemType());
+ if (!el) {
+ return utils::Failure;
+ }
utils::Vector<const ast::Attribute*, 1> attrs;
if (!a->IsStrideImplicit()) {
attrs.Push(b.Stride(a->Stride()));
}
if (a->Count()->Is<type::RuntimeArrayCount>()) {
- return b.ty.array(el, std::move(attrs));
+ return b.ty.array(el.Get(), std::move(attrs));
}
auto count = a->ConstantCount();
if (TINT_UNLIKELY(!count)) {
TINT_ICE(IR, b.Diagnostics()) << type::Array::kErrExpectedConstantCount;
- return b.ty.array(el, u32(1), std::move(attrs));
+ return b.ty.array(el.Get(), u32(1), std::move(attrs));
}
- return b.ty.array(el, u32(count.value()), std::move(attrs));
+ return b.ty.array(el.Get(), u32(count.value()), std::move(attrs));
},
[&](const type::Struct* s) { return b.ty(s->Name().NameView()); },
- [&](const type::Atomic* a) { return b.ty.atomic(Type(a->Type())); },
+ [&](const type::Atomic* a) -> utils::Result<ast::Type> {
+ auto el = Type(a->Type());
+ if (!el) {
+ return utils::Failure;
+ }
+ return b.ty.atomic(el.Get());
+ },
[&](const type::DepthTexture* t) { return b.ty.depth_texture(t->dim()); },
[&](const type::DepthMultisampledTexture* t) {
return b.ty.depth_multisampled_texture(t->dim());
},
[&](const type::ExternalTexture*) { return b.ty.external_texture(); },
- [&](const type::MultisampledTexture* t) {
- return b.ty.multisampled_texture(t->dim(), Type(t->type()));
+ [&](const type::MultisampledTexture* t) -> utils::Result<ast::Type> {
+ auto el = Type(t->type());
+ if (!el) {
+ return utils::Failure;
+ }
+ return b.ty.multisampled_texture(t->dim(), el.Get());
},
- [&](const type::SampledTexture* t) {
- return b.ty.sampled_texture(t->dim(), Type(t->type()));
+ [&](const type::SampledTexture* t) -> utils::Result<ast::Type> {
+ auto el = Type(t->type());
+ if (!el) {
+ return utils::Failure;
+ }
+ return b.ty.sampled_texture(t->dim(), el.Get());
},
[&](const type::StorageTexture* t) {
return b.ty.storage_texture(t->dim(), t->texel_format(), t->access());
},
[&](const type::Sampler* s) { return b.ty.sampler(s->kind()); },
- [&](const type::Pointer* p) {
+ [&](const type::Pointer* p) -> utils::Result<ast::Type> {
// Note: type::Pointer always has an inferred access, but WGSL only allows an
// explicit access in the 'storage' address space.
+ auto el = Type(p->StoreType());
+ if (!el) {
+ return utils::Failure;
+ }
auto address_space = p->AddressSpace();
auto access = address_space == builtin::AddressSpace::kStorage
? p->Access()
: builtin::Access::kUndefined;
- return b.ty.pointer(Type(p->StoreType()), address_space, access);
+ return b.ty.pointer(el.Get(), address_space, access);
},
- [&](const type::Reference* r) { return Type(r->StoreType()); },
+ [&](const type::Reference*) -> utils::Result<ast::Type> {
+ TINT_ICE(IR, b.Diagnostics()) << "reference types should never appear in the IR";
+ return ast::Type{};
+ },
[&](Default) {
- TINT_UNREACHABLE(IR, b.Diagnostics()) << "unhandled type: " << ty->TypeInfo().name;
+ UNHANDLED_CASE(ty);
return ast::Type{};
});
}
@@ -288,7 +529,7 @@
Symbol Sym(const Symbol& s) { return b.Symbols().Register(s.NameView()); }
- // void Err(std::string str) { b.Diagnostics().add_error(diag::System::IR, std::move(str)); }
+ void Err(std::string str) { b.Diagnostics().add_error(diag::System::IR, std::move(str)); }
};
} // namespace
diff --git a/src/tint/ir/to_program_roundtrip_test.cc b/src/tint/ir/to_program_roundtrip_test.cc
index 11aff46..8f26792 100644
--- a/src/tint/ir/to_program_roundtrip_test.cc
+++ b/src/tint/ir/to_program_roundtrip_test.cc
@@ -40,6 +40,13 @@
ASSERT_TRUE(ir_module);
auto output_program = ToProgram(ir_module.Get());
+ if (!output_program.IsValid()) {
+ tint::ir::Disassembler d{ir_module.Get()};
+ FAIL() << output_program.Diagnostics().str() << std::endl
+ << "IR:" << std::endl
+ << d.Disassemble();
+ }
+
ASSERT_TRUE(output_program.IsValid()) << output_program.Diagnostics().str();
auto output = writer::wgsl::Generate(&output_program, {});
@@ -60,13 +67,44 @@
Test("");
}
-TEST_F(IRToProgramRoundtripTest, EmptySingleFunction) {
+TEST_F(IRToProgramRoundtripTest, SingleFunction_Empty) {
Test(R"(
fn f() {
}
)");
}
+TEST_F(IRToProgramRoundtripTest, SingleFunction_Return) {
+ Test(R"(
+fn f() {
+ return;
+}
+)",
+ R"(
+fn f() {
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, SingleFunction_Return_i32) {
+ Test(R"(
+fn f() -> i32 {
+ return 42i;
+}
+)");
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Function-scope var
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(IRToProgramRoundtripTest, FunctionScopeVar_i32) {
+ Test(R"(
+fn f() {
+ var i : i32;
+}
+)");
+}
+
TEST_F(IRToProgramRoundtripTest, FunctionScopeVar_i32_InitLiteral) {
Test(R"(
fn f() {
@@ -75,5 +113,256 @@
)");
}
+TEST_F(IRToProgramRoundtripTest, FunctionScopeVar_Chained) {
+ Test(R"(
+fn f() {
+ var a : i32 = 42i;
+ var b : i32 = a;
+ var c : i32 = b;
+}
+)");
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// If
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(IRToProgramRoundtripTest, If_CallFn) {
+ Test(R"(
+fn a() {
+}
+
+fn f() {
+ var cond : bool = true;
+ if (cond) {
+ a();
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, If_Return) {
+ Test(R"(
+fn f() {
+ var cond : bool = true;
+ if (cond) {
+ return;
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, If_Return_i32) {
+ Test(R"(
+fn f() -> i32 {
+ var cond : bool = true;
+ if (cond) {
+ return 42i;
+ }
+ return 10i;
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, If_CallFn_Else_CallFn) {
+ Test(R"(
+fn a() {
+}
+
+fn b() {
+}
+
+fn f() {
+ var cond : bool = true;
+ if (cond) {
+ a();
+ } else {
+ b();
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, If_Return_f32_Else_Return_f32) {
+ Test(R"(
+fn f() -> f32 {
+ var cond : bool = true;
+ if (cond) {
+ return 1.0f;
+ } else {
+ return 2.0f;
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, If_Return_u32_Else_CallFn) {
+ Test(R"(
+fn a() {
+}
+
+fn b() {
+}
+
+fn f() -> u32 {
+ var cond : bool = true;
+ if (cond) {
+ return 1u;
+ } else {
+ a();
+ }
+ b();
+ return 2u;
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, If_CallFn_ElseIf_CallFn) {
+ Test(R"(
+fn a() {
+}
+
+fn b() {
+}
+
+fn c() {
+}
+
+fn f() {
+ var cond_a : bool = true;
+ var cond_b : bool = true;
+ if (cond_a) {
+ a();
+ } else if (cond_b) {
+ b();
+ }
+ c();
+}
+)");
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Switch
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(IRToProgramRoundtripTest, Switch_Default) {
+ Test(R"(
+fn a() {
+}
+
+fn f() {
+ var v : i32 = 42i;
+ switch(v) {
+ default: {
+ a();
+ }
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, Switch_3_Cases) {
+ Test(R"(
+fn a() {
+}
+
+fn b() {
+}
+
+fn c() {
+}
+
+fn f() {
+ var v : i32 = 42i;
+ switch(v) {
+ case 0i: {
+ a();
+ }
+ case 1i, default: {
+ b();
+ }
+ case 2i: {
+ c();
+ }
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, Switch_3_Cases_AllReturn) {
+ Test(R"(
+fn a() {
+}
+
+fn f() {
+ var v : i32 = 42i;
+ switch(v) {
+ case 0i: {
+ return;
+ }
+ case 1i, default: {
+ return;
+ }
+ case 2i: {
+ return;
+ }
+ }
+ a();
+}
+)",
+ R"(
+fn a() {
+}
+
+fn f() {
+ var v : i32 = 42i;
+ switch(v) {
+ case 0i: {
+ return;
+ }
+ case 1i, default: {
+ return;
+ }
+ case 2i: {
+ return;
+ }
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, Switch_Nested) {
+ Test(R"(
+fn a() {
+}
+
+fn b() {
+}
+
+fn c() {
+}
+
+fn f() {
+ var v1 : i32 = 42i;
+ var v2 : i32 = 24i;
+ switch(v1) {
+ case 0i: {
+ a();
+ }
+ case 1i, default: {
+ switch(v2) {
+ case 0i: {
+ }
+ case 1i, default: {
+ return;
+ }
+ }
+ }
+ case 2i: {
+ c();
+ }
+ }
+}
+)");
+}
+
} // namespace
} // namespace tint::ir
diff --git a/src/tint/ir/transform/add_empty_entry_point.cc b/src/tint/ir/transform/add_empty_entry_point.cc
new file mode 100644
index 0000000..809a6ad
--- /dev/null
+++ b/src/tint/ir/transform/add_empty_entry_point.cc
@@ -0,0 +1,45 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/transform/add_empty_entry_point.h"
+
+#include <utility>
+
+#include "src/tint/ir/builder.h"
+#include "src/tint/ir/module.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ir::transform::AddEmptyEntryPoint);
+
+namespace tint::ir::transform {
+
+AddEmptyEntryPoint::AddEmptyEntryPoint() = default;
+
+AddEmptyEntryPoint::~AddEmptyEntryPoint() = default;
+
+void AddEmptyEntryPoint::Run(ir::Module* ir, const DataMap&, DataMap&) const {
+ for (auto* func : ir->functions) {
+ if (func->pipeline_stage != Function::PipelineStage::kUndefined) {
+ return;
+ }
+ }
+
+ ir::Builder builder(*ir);
+ auto* ep =
+ builder.CreateFunction(ir->symbols.New("unused_entry_point"), ir->types.Get<type::Void>(),
+ Function::PipelineStage::kCompute, std::array{1u, 1u, 1u});
+ builder.Branch(ep->start_target, ep->end_target);
+ ir->functions.Push(ep);
+}
+
+} // namespace tint::ir::transform
diff --git a/src/tint/ir/transform/add_empty_entry_point.h b/src/tint/ir/transform/add_empty_entry_point.h
new file mode 100644
index 0000000..39f3413
--- /dev/null
+++ b/src/tint/ir/transform/add_empty_entry_point.h
@@ -0,0 +1,36 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_IR_TRANSFORM_ADD_EMPTY_ENTRY_POINT_H_
+#define SRC_TINT_IR_TRANSFORM_ADD_EMPTY_ENTRY_POINT_H_
+
+#include "src/tint/ir/transform/transform.h"
+
+namespace tint::ir::transform {
+
+/// Add an empty entry point to the module, if no other entry points exist.
+class AddEmptyEntryPoint final : public utils::Castable<AddEmptyEntryPoint, Transform> {
+ public:
+ /// Constructor
+ AddEmptyEntryPoint();
+ /// Destructor
+ ~AddEmptyEntryPoint() override;
+
+ /// @copydoc Transform::Run
+ void Run(ir::Module* module, const DataMap& inputs, DataMap& outputs) const override;
+};
+
+} // namespace tint::ir::transform
+
+#endif // SRC_TINT_IR_TRANSFORM_ADD_EMPTY_ENTRY_POINT_H_
diff --git a/src/tint/ir/transform/add_empty_entry_point_test.cc b/src/tint/ir/transform/add_empty_entry_point_test.cc
new file mode 100644
index 0000000..baba8e0
--- /dev/null
+++ b/src/tint/ir/transform/add_empty_entry_point_test.cc
@@ -0,0 +1,60 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/transform/add_empty_entry_point.h"
+
+#include <utility>
+
+#include "src/tint/ir/transform/test_helper.h"
+
+namespace tint::ir::transform {
+namespace {
+
+using IR_AddEmptyEntryPointTest = TransformTest;
+
+TEST_F(IR_AddEmptyEntryPointTest, EmptyModule) {
+ auto* expect = R"(
+%fn1 = func unused_entry_point():void [@compute @workgroup_size(1, 1, 1)] {
+ %fn2 = block {
+ } -> %func_end # return
+} %func_end
+
+)";
+
+ Run<AddEmptyEntryPoint>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_AddEmptyEntryPointTest, ExistingEntryPoint) {
+ auto* ep = b.CreateFunction(mod.symbols.New("main"), mod.types.Get<type::Void>(),
+ Function::PipelineStage::kFragment);
+ b.Branch(ep->start_target, ep->end_target);
+ mod.functions.Push(ep);
+
+ auto* expect = R"(
+%fn1 = func main():void [@fragment] {
+ %fn2 = block {
+ } -> %func_end # return
+} %func_end
+
+)";
+
+ Run<AddEmptyEntryPoint>();
+
+ EXPECT_EQ(expect, str());
+}
+
+} // namespace
+} // namespace tint::ir::transform
diff --git a/src/tint/ir/transform/test_helper.h b/src/tint/ir/transform/test_helper.h
new file mode 100644
index 0000000..a122eb2
--- /dev/null
+++ b/src/tint/ir/transform/test_helper.h
@@ -0,0 +1,72 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_IR_TRANSFORM_TEST_HELPER_H_
+#define SRC_TINT_IR_TRANSFORM_TEST_HELPER_H_
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "src/tint/ir/builder.h"
+#include "src/tint/ir/disassembler.h"
+#include "src/tint/ir/transform/transform.h"
+#include "src/tint/transform/manager.h"
+
+namespace tint::ir::transform {
+
+/// Helper class for testing IR transforms.
+template <typename BASE>
+class TransformTestBase : public BASE {
+ public:
+ /// Transforms the module, using transforms in `TRANSFORMS`.
+ /// @param data the optional Transform::DataMap to pass to Transform::Run()
+ /// @returns the transform outputs, if any
+ template <typename... TRANSFORMS>
+ Transform::DataMap Run(const Transform::DataMap& data = {}) {
+ tint::transform::Manager manager;
+ tint::transform::DataMap outputs;
+ for (auto* transform_ptr : std::initializer_list<Transform*>{new TRANSFORMS()...}) {
+ manager.append(std::unique_ptr<Transform>(transform_ptr));
+ }
+ manager.Run(&mod, data, outputs);
+ return outputs;
+ }
+
+ /// @returns the transformed module as a disassembled string
+ std::string str() {
+ ir::Disassembler dis(mod);
+ return "\n" + dis.Disassemble();
+ }
+
+ protected:
+ /// The test IR module.
+ ir::Module mod;
+ /// The test IR builder.
+ ir::Builder b{mod};
+
+ private:
+ std::vector<std::unique_ptr<Source::File>> files_;
+};
+
+using TransformTest = TransformTestBase<testing::Test>;
+
+template <typename T>
+using TransformTestWithParam = TransformTestBase<testing::TestWithParam<T>>;
+
+} // namespace tint::ir::transform
+
+#endif // SRC_TINT_IR_TRANSFORM_TEST_HELPER_H_
diff --git a/src/tint/ir/transform/transform.cc b/src/tint/ir/transform/transform.cc
new file mode 100644
index 0000000..eb5e68e
--- /dev/null
+++ b/src/tint/ir/transform/transform.cc
@@ -0,0 +1,24 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/transform/transform.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ir::transform::Transform);
+
+namespace tint::ir::transform {
+
+Transform::Transform() = default;
+Transform::~Transform() = default;
+
+} // namespace tint::ir::transform
diff --git a/src/tint/ir/transform/transform.h b/src/tint/ir/transform/transform.h
new file mode 100644
index 0000000..a09fe77
--- /dev/null
+++ b/src/tint/ir/transform/transform.h
@@ -0,0 +1,48 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_IR_TRANSFORM_TRANSFORM_H_
+#define SRC_TINT_IR_TRANSFORM_TRANSFORM_H_
+
+#include "src/tint/transform/transform.h"
+
+#include <utility>
+
+#include "src/tint/utils/castable.h"
+
+// Forward declarations
+namespace tint::ir {
+class Module;
+} // namespace tint::ir
+
+namespace tint::ir::transform {
+
+/// Interface for IR Module transforms.
+class Transform : public utils::Castable<Transform, tint::transform::Transform> {
+ public:
+ /// Constructor
+ Transform();
+ /// Destructor
+ ~Transform() override;
+
+ /// Run the transform on @p module
+ /// @param module the source module to transform
+ /// @param inputs optional extra transform-specific input data
+ /// @param outputs optional extra transform-specific output data
+ virtual void Run(ir::Module* module, const DataMap& inputs, DataMap& outputs) const = 0;
+};
+
+} // namespace tint::ir::transform
+
+#endif // SRC_TINT_IR_TRANSFORM_TRANSFORM_H_
diff --git a/src/tint/ir/unary.h b/src/tint/ir/unary.h
index e665386..46edd45 100644
--- a/src/tint/ir/unary.h
+++ b/src/tint/ir/unary.h
@@ -25,9 +25,7 @@
public:
/// The kind of instruction.
enum class Kind {
- kAddressOf,
kComplement,
- kIndirection,
kNegation,
};
@@ -50,7 +48,7 @@
const Value* Val() const { return val_; }
/// the kind of unary instruction
- Kind kind = Kind::kAddressOf;
+ Kind kind = Kind::kNegation;
/// the result type of the instruction
const type::Type* result_type = nullptr;
diff --git a/src/tint/ir/unary_test.cc b/src/tint/ir/unary_test.cc
index 392a75e..280c7de 100644
--- a/src/tint/ir/unary_test.cc
+++ b/src/tint/ir/unary_test.cc
@@ -23,27 +23,6 @@
using IR_InstructionTest = TestHelper;
-TEST_F(IR_InstructionTest, CreateAddressOf) {
- Module mod;
- Builder b{mod};
-
- // TODO(dsinclair): This would be better as an identifier, but works for now.
- const auto* inst = b.AddressOf(
- b.ir.types.Get<type::Pointer>(b.ir.types.Get<type::I32>(), builtin::AddressSpace::kPrivate,
- builtin::Access::kReadWrite),
- b.Constant(4_i));
-
- ASSERT_TRUE(inst->Is<Unary>());
- EXPECT_EQ(inst->kind, Unary::Kind::kAddressOf);
-
- ASSERT_NE(inst->Type(), nullptr);
-
- ASSERT_TRUE(inst->Val()->Is<Constant>());
- auto lhs = inst->Val()->As<Constant>()->value;
- ASSERT_TRUE(lhs->Is<constant::Scalar<i32>>());
- EXPECT_EQ(4_i, lhs->As<constant::Scalar<i32>>()->ValueAs<i32>());
-}
-
TEST_F(IR_InstructionTest, CreateComplement) {
Module mod;
Builder b{mod};
@@ -58,22 +37,6 @@
EXPECT_EQ(4_i, lhs->As<constant::Scalar<i32>>()->ValueAs<i32>());
}
-TEST_F(IR_InstructionTest, CreateIndirection) {
- Module mod;
- Builder b{mod};
-
- // TODO(dsinclair): This would be better as an identifier, but works for now.
- const auto* inst = b.Indirection(b.ir.types.Get<type::I32>(), b.Constant(4_i));
-
- ASSERT_TRUE(inst->Is<Unary>());
- EXPECT_EQ(inst->kind, Unary::Kind::kIndirection);
-
- ASSERT_TRUE(inst->Val()->Is<Constant>());
- auto lhs = inst->Val()->As<Constant>()->value;
- ASSERT_TRUE(lhs->Is<constant::Scalar<i32>>());
- EXPECT_EQ(4_i, lhs->As<constant::Scalar<i32>>()->ValueAs<i32>());
-}
-
TEST_F(IR_InstructionTest, CreateNegation) {
Module mod;
Builder b{mod};
diff --git a/src/tint/ir/var.cc b/src/tint/ir/var.cc
index e33da54..9bf4b8d 100644
--- a/src/tint/ir/var.cc
+++ b/src/tint/ir/var.cc
@@ -19,8 +19,7 @@
namespace tint::ir {
-Var::Var(const type::Type* ty, builtin::AddressSpace addr_space, builtin::Access acc)
- : type(ty), address_space(addr_space), access(acc) {}
+Var::Var(const type::Type* ty) : type(ty) {}
Var::~Var() = default;
diff --git a/src/tint/ir/var.h b/src/tint/ir/var.h
index 5a61104..c874a62 100644
--- a/src/tint/ir/var.h
+++ b/src/tint/ir/var.h
@@ -27,9 +27,7 @@
public:
/// Constructor
/// @param type the type of the var
- /// @param address_space the address space of the var
- /// @param access the access mode of the var
- Var(const type::Type* type, builtin::AddressSpace address_space, builtin::Access access);
+ explicit Var(const type::Type* type);
Var(const Var& inst) = delete;
Var(Var&& inst) = delete;
~Var() override;
@@ -43,12 +41,6 @@
/// the result type of the instruction
const type::Type* type = nullptr;
- /// The variable address space
- builtin::AddressSpace address_space = builtin::AddressSpace::kUndefined;
-
- /// The variable access mode
- builtin::Access access = builtin::Access::kUndefined;
-
/// The optional initializer
Value* initializer = nullptr;
};
diff --git a/src/tint/resolver/materialize_test.cc b/src/tint/resolver/materialize_test.cc
index 696f279..cf67c93 100644
--- a/src/tint/resolver/materialize_test.cc
+++ b/src/tint/resolver/materialize_test.cc
@@ -211,6 +211,10 @@
// abstract_expr[runtime-index]
kRuntimeIndex,
+
+ // var a : target_type;
+ // a += abstract_expr;
+ kCompoundAssign,
};
static std::ostream& operator<<(std::ostream& o, Method m) {
@@ -247,6 +251,8 @@
return o << "workgroup-size";
case Method::kRuntimeIndex:
return o << "runtime-index";
+ case Method::kCompoundAssign:
+ return o << "compound-assign";
}
return o << "<unknown>";
}
@@ -387,10 +393,15 @@
utils::Vector{WorkgroupSize(target_expr(), abstract_expr, Expr(123_a)),
Stage(ast::PipelineStage::kCompute)});
break;
- case Method::kRuntimeIndex:
+ case Method::kRuntimeIndex: {
auto* runtime_index = Var("runtime_index", Expr(1_i));
WrapInFunction(runtime_index, IndexAccessor(abstract_expr, runtime_index));
break;
+ }
+ case Method::kCompoundAssign:
+ WrapInFunction(Decl(Var("a", target_ty())),
+ CompoundAssign("a", abstract_expr, ast::BinaryOp::kAdd));
+ break;
}
switch (expectation) {
@@ -421,6 +432,10 @@
expect = "error: no matching overload for operator + (" +
data.target_type_name + ", " + data.abstract_type_name + ")";
break;
+ case Method::kCompoundAssign:
+ expect = "error: no matching overload for operator += (" +
+ data.target_type_name + ", " + data.abstract_type_name + ")";
+ break;
default:
expect = "error: cannot convert value of type '" + data.abstract_type_name +
"' to type '" + data.target_type_name + "'";
@@ -440,13 +455,13 @@
/// Methods that support scalar materialization
constexpr Method kScalarMethods[] = {
Method::kLet, Method::kVar, Method::kAssign, Method::kFnArg, Method::kBuiltinArg,
- Method::kReturn, Method::kArray, Method::kStruct, Method::kBinaryOp,
+ Method::kReturn, Method::kArray, Method::kStruct, Method::kBinaryOp, Method::kCompoundAssign,
};
/// Methods that support vector materialization
constexpr Method kVectorMethods[] = {
Method::kLet, Method::kVar, Method::kAssign, Method::kFnArg, Method::kBuiltinArg,
- Method::kReturn, Method::kArray, Method::kStruct, Method::kBinaryOp,
+ Method::kReturn, Method::kArray, Method::kStruct, Method::kBinaryOp, Method::kCompoundAssign,
};
/// Methods that support matrix materialization
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index a65d825..09d8464 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -4479,7 +4479,7 @@
return false;
}
- auto* rhs = Load(ValueExpression(stmt->rhs));
+ const auto* rhs = ValueExpression(stmt->rhs);
if (!rhs) {
return false;
}
@@ -4491,12 +4491,19 @@
auto* lhs_ty = lhs->Type()->UnwrapRef();
auto* rhs_ty = rhs->Type()->UnwrapRef();
auto stage = sem::EarliestStage(lhs->Stage(), rhs->Stage());
- auto* ty =
- intrinsic_table_->Lookup(stmt->op, lhs_ty, rhs_ty, stage, stmt->source, true).result;
- if (!ty) {
+
+ auto op = intrinsic_table_->Lookup(stmt->op, lhs_ty, rhs_ty, stage, stmt->source, true);
+ if (!op.result) {
return false;
}
- return validator_.Assignment(stmt, ty);
+
+ // Load or materialize the RHS if necessary.
+ rhs = Load(Materialize(rhs, op.rhs));
+ if (!rhs) {
+ return false;
+ }
+
+ return validator_.Assignment(stmt, op.result);
});
}
diff --git a/src/tint/transform/manager.cc b/src/tint/transform/manager.cc
index eff3a0c..92d0427 100644
--- a/src/tint/transform/manager.cc
+++ b/src/tint/transform/manager.cc
@@ -17,12 +17,24 @@
#include "src/tint/ast/transform/transform.h"
#include "src/tint/program_builder.h"
+#if TINT_BUILD_IR
+#include "src/tint/ir/from_program.h"
+#include "src/tint/ir/to_program.h"
+#include "src/tint/ir/transform/transform.h"
+#else
+// Declare an ir::Module class so that the transform target variant compiles.
+namespace ir {
+class Module;
+}
+#endif // TINT_BUILD_IR
+
/// If set to 1 then the transform::Manager will dump the WGSL of the program
/// before and after each transform. Helpful for debugging bad output.
#define TINT_PRINT_PROGRAM_FOR_EACH_TRANSFORM 0
#if TINT_PRINT_PROGRAM_FOR_EACH_TRANSFORM
#include <iostream>
+#include "src/tint/ir/disassembler.h"
#define TINT_IF_PRINT_PROGRAM(x) x
#else // TINT_PRINT_PROGRAM_FOR_EACH_TRANSFORM
#define TINT_IF_PRINT_PROGRAM(x)
@@ -33,62 +45,140 @@
Manager::Manager() = default;
Manager::~Manager() = default;
-Program Manager::Run(const Program* program,
- const transform::DataMap& inputs,
- transform::DataMap& outputs) const {
+template <typename OUTPUT, typename INPUT>
+OUTPUT Manager::RunTransforms(INPUT in,
+ const transform::DataMap& inputs,
+ transform::DataMap& outputs) const {
+ static_assert(std::is_same<INPUT, const Program*>() || std::is_same<INPUT, ir::Module*>());
+ static_assert(std::is_same<OUTPUT, Program>() || std::is_same<OUTPUT, ir::Module*>());
+
+ // The current transform target, which could be either AST or IR.
+ std::variant<const Program*, ir::Module*> target = in;
+ // A local AST program to hold the result of AST transforms.
+ Program ast_result;
+#if TINT_BUILD_IR
+ // A local IR module to hold the result of AST->IR conversions.
+ ir::Module ir_result;
+#endif
+
#if TINT_PRINT_PROGRAM_FOR_EACH_TRANSFORM
- auto print_program = [&](const char* msg, const Transform* transform) {
- auto wgsl = Program::printer(program);
+ auto print_program = [&](const char* msg, const char* name) {
std::cout << "=========================================================" << std::endl;
- std::cout << "== " << msg << " " << transform->TypeInfo().name << ":" << std::endl;
+ std::cout << "== " << msg << " " << name << ":" << std::endl;
std::cout << "=========================================================" << std::endl;
- std::cout << wgsl << std::endl;
- if (!program->IsValid()) {
- std::cout << "-- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- --" << std::endl;
- std::cout << program->Diagnostics().str() << std::endl;
+ if (std::holds_alternative<const Program*>(target)) {
+ auto* program = std::get<const Program*>(target);
+ auto wgsl = Program::printer(program);
+ std::cout << wgsl << std::endl;
+ if (!program->IsValid()) {
+ std::cout << "-- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- --"
+ << std::endl;
+ std::cout << program->Diagnostics().str() << std::endl;
+ }
+ } else if (std::holds_alternative<ir::Module*>(target)) {
+#if TINT_BUILD_IR
+ ir::Disassembler dis(*std::get<ir::Module*>(target));
+ std::cout << dis.Disassemble();
+#endif // TINT_BUILD_IR
}
std::cout << "=========================================================" << std::endl
<< std::endl;
};
#endif
- std::optional<Program> output;
+ // Helper functions to get the current program state as either an AST program or IR module,
+ // performing a conversion if necessary.
+ auto get_ast = [&]() {
+#if TINT_BUILD_IR
+ if (std::holds_alternative<ir::Module*>(target)) {
+ // Convert the IR module to an AST program.
+ ast_result = ir::ToProgram(*std::get<ir::Module*>(target));
+ target = &ast_result;
+ }
+#endif // TINT_BUILD_IR
+ TINT_ASSERT(Transform, std::holds_alternative<const Program*>(target));
+ return std::get<const Program*>(target);
+ };
+#if TINT_BUILD_IR
+ auto get_ir = [&]() {
+ if (std::holds_alternative<const Program*>(target)) {
+ // Convert the AST program to an IR module.
+ auto converted = ir::FromProgram(std::get<const Program*>(target));
+ TINT_ASSERT(Transform, converted);
+ ir_result = converted.Move();
+ target = &ir_result;
+ }
+ TINT_ASSERT(Transform, std::holds_alternative<ir::Module*>(target));
+ return std::get<ir::Module*>(target);
+ };
+#endif // TINT_BUILD_IR
- TINT_IF_PRINT_PROGRAM(print_program("Input of", this));
+ TINT_IF_PRINT_PROGRAM(print_program("Input of", "transform manager"));
for (const auto& transform : transforms_) {
if (auto* ast_transform = transform->As<ast::transform::Transform>()) {
- if (auto result = ast_transform->Apply(program, inputs, outputs)) {
- output.emplace(std::move(result.value()));
- program = &output.value();
+ if (auto result = ast_transform->Apply(get_ast(), inputs, outputs)) {
+ ast_result = std::move(result.value());
+ target = &ast_result;
- if (!program->IsValid()) {
- TINT_IF_PRINT_PROGRAM(print_program("Invalid output of", transform.get()));
+ if (!ast_result.IsValid()) {
+ TINT_IF_PRINT_PROGRAM(
+ print_program("Invalid output of", transform->TypeInfo().name));
break;
}
- TINT_IF_PRINT_PROGRAM(print_program("Output of", transform.get()));
+ TINT_IF_PRINT_PROGRAM(print_program("Output of", transform->TypeInfo().name));
} else {
TINT_IF_PRINT_PROGRAM(std::cout << "Skipped " << transform->TypeInfo().name
<< std::endl);
}
+#if TINT_BUILD_IR
+ } else if (auto* ir_transform = transform->As<ir::transform::Transform>()) {
+ ir_transform->Run(get_ir(), inputs, outputs);
+ TINT_IF_PRINT_PROGRAM(print_program("Output of", transform->TypeInfo().name));
+#endif // TINT_BUILD_IR
} else {
- ProgramBuilder b;
- TINT_ICE(Transform, b.Diagnostics()) << "unhandled transform type";
- return Program(std::move(b));
+ TINT_ASSERT(Transform, false && "unhandled transform type");
}
}
- TINT_IF_PRINT_PROGRAM(print_program("Final output of", this));
+ TINT_IF_PRINT_PROGRAM(print_program("Final output of", "transform manager"));
- if (!output) {
- ProgramBuilder b;
- CloneContext ctx{&b, program, /* auto_clone_symbols */ true};
- ctx.Clone();
- output = Program(std::move(b));
+ if constexpr (std::is_same<OUTPUT, Program>()) {
+ auto* result = get_ast();
+ if (result == in) {
+ // AST transform pipelines are expected to return a clone of the program, so make sure
+ // the input is cloned at least once even if nothing changed.
+ ProgramBuilder b;
+ CloneContext ctx{&b, result, /* auto_clone_symbols */ true};
+ ctx.Clone();
+ ast_result = Program(std::move(b));
+ }
+ return ast_result;
+#if TINT_BUILD_IR
+ } else if constexpr (std::is_same<OUTPUT, ir::Module*>()) {
+ auto* result = get_ir();
+ if (result == &ir_result) {
+ // IR transform pipelines are expected to mutate the module in place, so move the local
+ // temporary result to the original input.
+ *in = std::move(ir_result);
+ }
+ return in;
+#endif // TINT_BUILD_IR
}
-
- return std::move(output.value());
}
+Program Manager::Run(const Program* program,
+ const transform::DataMap& inputs,
+ transform::DataMap& outputs) const {
+ return RunTransforms<Program>(program, inputs, outputs);
+}
+
+#if TINT_BUILD_IR
+void Manager::Run(ir::Module* module, const DataMap& inputs, DataMap& outputs) const {
+ auto* output = RunTransforms<ir::Module*>(module, inputs, outputs);
+ TINT_ASSERT(Transform, output == module);
+}
+#endif // TINT_BUILD_IR
+
} // namespace tint::transform
diff --git a/src/tint/transform/manager.h b/src/tint/transform/manager.h
index 2df5785..289a16d 100644
--- a/src/tint/transform/manager.h
+++ b/src/tint/transform/manager.h
@@ -19,8 +19,16 @@
#include <utility>
#include <vector>
+#include "src/tint/program.h"
#include "src/tint/transform/transform.h"
+#if TINT_BUILD_IR
+// Forward declarations
+namespace tint::ir {
+class Module;
+} // namespace tint::ir
+#endif // TINT_BUILD_IR
+
namespace tint::transform {
/// A collection of Transforms that act as a single Transform.
@@ -54,8 +62,19 @@
/// @returns the transformed program
Program Run(const Program* program, const DataMap& inputs, DataMap& outputs) const;
+#if TINT_BUILD_IR
+ /// Runs the transforms on @p module
+ /// @param module the module to transform
+ /// @param inputs optional extra transform-specific input data
+ /// @param outputs optional extra transform-specific output data
+ void Run(ir::Module* module, const DataMap& inputs, DataMap& outputs) const;
+#endif // TINT_BUILD_IR
+
private:
std::vector<std::unique_ptr<Transform>> transforms_;
+
+ template <typename OUTPUT, typename INPUT>
+ OUTPUT RunTransforms(INPUT in, const DataMap& inputs, DataMap& outputs) const;
};
} // namespace tint::transform
diff --git a/src/tint/transform/manager_test.cc b/src/tint/transform/manager_test.cc
new file mode 100644
index 0000000..3b0ee25
--- /dev/null
+++ b/src/tint/transform/manager_test.cc
@@ -0,0 +1,174 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/transform/manager.h"
+
+#include <string>
+
+#include "gtest/gtest.h"
+#include "src/tint/ast/transform/transform.h"
+#include "src/tint/program_builder.h"
+
+#if TINT_BUILD_IR
+#include "src/tint/ir/builder.h" // nogncheck
+#include "src/tint/ir/transform/transform.h" // nogncheck
+#endif // TINT_BUILD_IR
+
+namespace tint::transform {
+namespace {
+
+using TransformManagerTest = testing::Test;
+
+class AST_NoOp final : public ast::transform::Transform {
+ ApplyResult Apply(const Program*, const DataMap&, DataMap&) const override {
+ return SkipTransform;
+ }
+};
+
+class AST_AddFunction final : public ast::transform::Transform {
+ ApplyResult Apply(const Program* src, const DataMap&, DataMap&) const override {
+ ProgramBuilder b;
+ CloneContext ctx{&b, src};
+ b.Func(b.Sym("ast_func"), {}, b.ty.void_(), {});
+ ctx.Clone();
+ return Program(std::move(b));
+ }
+};
+
+#if TINT_BUILD_IR
+class IR_AddFunction final : public ir::transform::Transform {
+ void Run(ir::Module* mod, const DataMap&, DataMap&) const override {
+ ir::Builder builder(*mod);
+ auto* func =
+ builder.CreateFunction(mod->symbols.New("ir_func"), mod->types.Get<type::Void>());
+ builder.Branch(func->start_target, func->end_target);
+ mod->functions.Push(func);
+ }
+};
+#endif // TINT_BUILD_IR
+
+Program MakeAST() {
+ ProgramBuilder b;
+ b.Func(b.Sym("main"), {}, b.ty.void_(), {});
+ return Program(std::move(b));
+}
+
+#if TINT_BUILD_IR
+ir::Module MakeIR() {
+ ir::Module mod;
+ ir::Builder builder(mod);
+ auto* func =
+ builder.CreateFunction(builder.ir.symbols.New("main"), builder.ir.types.Get<type::Void>());
+ builder.Branch(func->start_target, func->end_target);
+ builder.ir.functions.Push(func);
+ return mod;
+}
+#endif // TINT_BUILD_IR
+
+// Test that an AST program is always cloned, even if all transforms are skipped.
+TEST_F(TransformManagerTest, AST_AlwaysClone) {
+ Program ast = MakeAST();
+
+ transform::Manager manager;
+ transform::DataMap outputs;
+ manager.Add<AST_NoOp>();
+
+ auto result = manager.Run(&ast, {}, outputs);
+ EXPECT_TRUE(result.IsValid()) << result.Diagnostics();
+ EXPECT_NE(result.ID(), ast.ID());
+ ASSERT_EQ(result.AST().Functions().Length(), 1u);
+ EXPECT_EQ(result.AST().Functions()[0]->name->symbol.Name(), "main");
+}
+
+#if TINT_BUILD_IR
+
+// Test that an IR module is mutated in place.
+TEST_F(TransformManagerTest, IR_MutateInPlace) {
+ ir::Module ir = MakeIR();
+
+ transform::Manager manager;
+ transform::DataMap outputs;
+ manager.Add<IR_AddFunction>();
+
+ manager.Run(&ir, {}, outputs);
+ ASSERT_EQ(ir.functions.Length(), 2u);
+ EXPECT_EQ(ir.functions[0]->name.Name(), "main");
+ EXPECT_EQ(ir.functions[1]->name.Name(), "ir_func");
+}
+
+TEST_F(TransformManagerTest, AST_MixedTransforms_AST_Before_IR) {
+ Program ast = MakeAST();
+
+ transform::Manager manager;
+ transform::DataMap outputs;
+ manager.Add<AST_AddFunction>();
+ manager.Add<IR_AddFunction>();
+
+ auto result = manager.Run(&ast, {}, outputs);
+ ASSERT_TRUE(result.IsValid()) << result.Diagnostics();
+ ASSERT_EQ(result.AST().Functions().Length(), 3u);
+ EXPECT_EQ(result.AST().Functions()[0]->name->symbol.Name(), "ast_func");
+ EXPECT_EQ(result.AST().Functions()[1]->name->symbol.Name(), "main");
+ EXPECT_EQ(result.AST().Functions()[2]->name->symbol.Name(), "ir_func");
+}
+
+TEST_F(TransformManagerTest, AST_MixedTransforms_IR_Before_AST) {
+ Program ast = MakeAST();
+
+ transform::Manager manager;
+ transform::DataMap outputs;
+ manager.Add<IR_AddFunction>();
+ manager.Add<AST_AddFunction>();
+
+ auto result = manager.Run(&ast, {}, outputs);
+ ASSERT_TRUE(result.IsValid()) << result.Diagnostics();
+ ASSERT_EQ(result.AST().Functions().Length(), 3u);
+ EXPECT_EQ(result.AST().Functions()[0]->name->symbol.Name(), "ast_func");
+ EXPECT_EQ(result.AST().Functions()[1]->name->symbol.Name(), "main");
+ EXPECT_EQ(result.AST().Functions()[2]->name->symbol.Name(), "ir_func");
+}
+
+TEST_F(TransformManagerTest, IR_MixedTransforms_AST_Before_IR) {
+ ir::Module ir = MakeIR();
+
+ transform::Manager manager;
+ transform::DataMap outputs;
+ manager.Add<AST_AddFunction>();
+ manager.Add<IR_AddFunction>();
+
+ manager.Run(&ir, {}, outputs);
+ ASSERT_EQ(ir.functions.Length(), 3u);
+ EXPECT_EQ(ir.functions[0]->name.Name(), "ast_func");
+ EXPECT_EQ(ir.functions[1]->name.Name(), "main");
+ EXPECT_EQ(ir.functions[2]->name.Name(), "ir_func");
+}
+
+TEST_F(TransformManagerTest, IR_MixedTransforms_IR_Before_AST) {
+ ir::Module ir = MakeIR();
+
+ transform::Manager manager;
+ transform::DataMap outputs;
+ manager.Add<IR_AddFunction>();
+ manager.Add<AST_AddFunction>();
+
+ manager.Run(&ir, {}, outputs);
+ ASSERT_EQ(ir.functions.Length(), 3u);
+ EXPECT_EQ(ir.functions[0]->name.Name(), "ast_func");
+ EXPECT_EQ(ir.functions[1]->name.Name(), "main");
+ EXPECT_EQ(ir.functions[2]->name.Name(), "ir_func");
+}
+#endif // TINT_BUILD_IR
+
+} // namespace
+} // namespace tint::transform
diff --git a/src/tint/transform/transform.cc b/src/tint/transform/transform.cc
index 657f9be..5d2117c 100644
--- a/src/tint/transform/transform.cc
+++ b/src/tint/transform/transform.cc
@@ -14,8 +14,6 @@
#include "src/tint/transform/transform.h"
-#include "src/tint/program_builder.h"
-
TINT_INSTANTIATE_TYPEINFO(tint::transform::Transform);
TINT_INSTANTIATE_TYPEINFO(tint::transform::Data);
diff --git a/src/tint/transform/transform.h b/src/tint/transform/transform.h
index 1f8a3f5..d142a96 100644
--- a/src/tint/transform/transform.h
+++ b/src/tint/transform/transform.h
@@ -19,7 +19,6 @@
#include <unordered_map>
#include <utility>
-#include "src/tint/program.h"
#include "src/tint/utils/castable.h"
namespace tint::transform {
diff --git a/src/tint/writer/spirv/generator.cc b/src/tint/writer/spirv/generator.cc
index de6178b..6531ec5 100644
--- a/src/tint/writer/spirv/generator.cc
+++ b/src/tint/writer/spirv/generator.cc
@@ -18,9 +18,9 @@
#include "src/tint/writer/spirv/generator_impl.h"
#if TINT_BUILD_IR
-#include "src/tint/ir/from_program.h" // nogncheck
-#include "src/tint/writer/spirv/generator_impl_ir.h" // nogncheck
-#endif // TINT_BUILD_IR
+#include "src/tint/ir/from_program.h" // nogncheck
+#include "src/tint/writer/spirv/ir/generator_impl_ir.h" // nogncheck
+#endif // TINT_BUILD_IR
namespace tint::writer::spirv {
@@ -41,14 +41,15 @@
#if TINT_BUILD_IR
if (options.use_tint_ir) {
// Convert the AST program to an IR module.
- auto ir = ir::FromProgram(program);
- if (!ir) {
- result.error = "IR converter: " + ir.Failure();
+ auto converted = ir::FromProgram(program);
+ if (!converted) {
+ result.error = "IR converter: " + converted.Failure();
return result;
}
// Generate the SPIR-V code.
- auto impl = std::make_unique<GeneratorImplIr>(&ir.Get(), zero_initialize_workgroup_memory);
+ auto ir = converted.Move();
+ auto impl = std::make_unique<GeneratorImplIr>(&ir, zero_initialize_workgroup_memory);
result.success = impl->Generate();
result.error = impl->Diagnostics().str();
result.spirv = std::move(impl->Result());
diff --git a/src/tint/writer/spirv/generator_impl_ir.cc b/src/tint/writer/spirv/ir/generator_impl_ir.cc
similarity index 67%
rename from src/tint/writer/spirv/generator_impl_ir.cc
rename to src/tint/writer/spirv/ir/generator_impl_ir.cc
index 5d8c2f0..4ec4d9d 100644
--- a/src/tint/writer/spirv/generator_impl_ir.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir.cc
@@ -12,30 +12,71 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/writer/spirv/generator_impl_ir.h"
+#include "src/tint/writer/spirv/ir/generator_impl_ir.h"
#include "spirv/unified1/spirv.h"
#include "src/tint/ir/binary.h"
#include "src/tint/ir/block.h"
#include "src/tint/ir/function_terminator.h"
+#include "src/tint/ir/if.h"
#include "src/tint/ir/module.h"
+#include "src/tint/ir/store.h"
+#include "src/tint/ir/transform/add_empty_entry_point.h"
+#include "src/tint/ir/var.h"
#include "src/tint/switch.h"
+#include "src/tint/transform/manager.h"
#include "src/tint/type/bool.h"
#include "src/tint/type/f16.h"
#include "src/tint/type/f32.h"
#include "src/tint/type/i32.h"
+#include "src/tint/type/pointer.h"
#include "src/tint/type/type.h"
#include "src/tint/type/u32.h"
#include "src/tint/type/vector.h"
#include "src/tint/type/void.h"
+#include "src/tint/writer/spirv/generator.h"
#include "src/tint/writer/spirv/module.h"
namespace tint::writer::spirv {
-GeneratorImplIr::GeneratorImplIr(const ir::Module* module, bool zero_init_workgroup_mem)
+namespace {
+
+void Sanitize(ir::Module* module) {
+ transform::Manager manager;
+ transform::DataMap data;
+
+ manager.Add<ir::transform::AddEmptyEntryPoint>();
+
+ transform::DataMap outputs;
+ manager.Run(module, data, outputs);
+}
+
+SpvStorageClass StorageClass(builtin::AddressSpace addrspace) {
+ switch (addrspace) {
+ case builtin::AddressSpace::kFunction:
+ return SpvStorageClassFunction;
+ case builtin::AddressSpace::kPrivate:
+ return SpvStorageClassPrivate;
+ case builtin::AddressSpace::kStorage:
+ return SpvStorageClassStorageBuffer;
+ case builtin::AddressSpace::kUniform:
+ return SpvStorageClassUniform;
+ case builtin::AddressSpace::kWorkgroup:
+ return SpvStorageClassWorkgroup;
+ default:
+ return SpvStorageClassMax;
+ }
+}
+
+} // namespace
+
+GeneratorImplIr::GeneratorImplIr(ir::Module* module, bool zero_init_workgroup_mem)
: ir_(module), zero_init_workgroup_memory_(zero_init_workgroup_mem) {}
bool GeneratorImplIr::Generate() {
+ // Run the IR transformations to prepare for SPIR-V emission.
+ Sanitize(ir_);
+
// TODO(crbug.com/tint/1906): Check supported extensions.
module_.PushCapability(SpvCapabilityShader);
@@ -133,6 +174,11 @@
[&](const type::Vector* vec) {
module_.PushType(spv::Op::OpTypeVector, {id, Type(vec->type()), vec->Width()});
},
+ [&](const type::Pointer* ptr) {
+ module_.PushType(
+ spv::Op::OpTypePointer,
+ {id, U32Operand(StorageClass(ptr->AddressSpace())), Type(ptr->StoreType())});
+ },
[&](Default) {
TINT_ICE(Writer, diagnostics_) << "unhandled type: " << ty->FriendlyName();
});
@@ -158,6 +204,10 @@
});
}
+uint32_t GeneratorImplIr::Label(const ir::Block* block) {
+ return block_labels_.GetOrCreate(block, [&]() { return module_.NextId(); });
+}
+
void GeneratorImplIr::EmitFunction(const ir::Function* func) {
// Make an ID for the function.
auto id = module_.NextId();
@@ -235,11 +285,22 @@
}
void GeneratorImplIr::EmitBlock(const ir::Block* block) {
+ // Emit the label.
+ // Skip if this is the function's entry block, as it will be emitted by the function object.
+ if (!current_function_.instructions().empty()) {
+ current_function_.push_inst(spv::Op::OpLabel, {Label(block)});
+ }
+
// Emit the instructions.
for (auto* inst : block->instructions) {
auto result = Switch(
inst, //
[&](const ir::Binary* b) { return EmitBinary(b); },
+ [&](const ir::Store* s) {
+ EmitStore(s);
+ return 0u;
+ },
+ [&](const ir::Var* v) { return EmitVar(v); },
[&](Default) {
TINT_ICE(Writer, diagnostics_)
<< "unimplemented instruction: " << inst->TypeInfo().name;
@@ -251,6 +312,8 @@
// Handle the branch at the end of the block.
Switch(
block->branch.target,
+ [&](const ir::Block* b) { current_function_.push_inst(spv::Op::OpBranch, {Label(b)}); },
+ [&](const ir::If* i) { EmitIf(i); },
[&](const ir::FunctionTerminator*) {
// TODO(jrprice): Handle the return value, which will be a branch argument.
if (!block->branch.args.IsEmpty()) {
@@ -258,7 +321,52 @@
}
current_function_.push_inst(spv::Op::OpReturn, {});
},
- [&](Default) { TINT_ICE(Writer, diagnostics_) << "unimplemented branch target"; });
+ [&](Default) {
+ if (!block->branch.target) {
+ // A block may not have an outward branch (e.g. an unreachable merge block).
+ current_function_.push_inst(spv::Op::OpUnreachable, {});
+ } else {
+ TINT_ICE(Writer, diagnostics_)
+ << "unimplemented branch target: " << block->branch.target->TypeInfo().name;
+ }
+ });
+}
+
+void GeneratorImplIr::EmitIf(const ir::If* i) {
+ auto* merge_block = i->merge.target->As<ir::Block>();
+ auto* true_block = i->true_.target->As<ir::Block>();
+ auto* false_block = i->false_.target->As<ir::Block>();
+
+ // Generate labels for the blocks. We emit the true or false block if it:
+ // 1. contains instructions, or
+ // 2. branches somewhere other then the merge target.
+ // Otherwise we skip them and branch straight to the merge block.
+ uint32_t merge_label = Label(merge_block);
+ uint32_t true_label = merge_label;
+ uint32_t false_label = merge_label;
+ if (!true_block->instructions.IsEmpty() || true_block->branch.target != merge_block) {
+ true_label = Label(true_block);
+ }
+ if (!false_block->instructions.IsEmpty() || false_block->branch.target != merge_block) {
+ false_label = Label(false_block);
+ }
+
+ // Emit the OpSelectionMerge and OpBranchConditional instructions.
+ current_function_.push_inst(spv::Op::OpSelectionMerge,
+ {merge_label, U32Operand(SpvSelectionControlMaskNone)});
+ current_function_.push_inst(spv::Op::OpBranchConditional,
+ {Value(i->condition), true_label, false_label});
+
+ // Emit the `true` and `false` blocks, if they're not being skipped.
+ if (true_label != merge_label) {
+ EmitBlock(true_block);
+ }
+ if (false_label != merge_label) {
+ EmitBlock(false_block);
+ }
+
+ // Emit the merge block.
+ EmitBlock(merge_block);
}
uint32_t GeneratorImplIr::EmitBinary(const ir::Binary* binary) {
@@ -288,4 +396,34 @@
return id;
}
+void GeneratorImplIr::EmitStore(const ir::Store* store) {
+ current_function_.push_inst(spv::Op::OpStore, {Value(store->to), Value(store->from)});
+}
+
+uint32_t GeneratorImplIr::EmitVar(const ir::Var* var) {
+ auto id = module_.NextId();
+ auto* ptr = var->Type()->As<type::Pointer>();
+ TINT_ASSERT(Writer, ptr);
+ auto ty = Type(ptr);
+
+ if (ptr->AddressSpace() == builtin::AddressSpace::kFunction) {
+ TINT_ASSERT(Writer, current_function_);
+ current_function_.push_var({ty, id, U32Operand(SpvStorageClassFunction)});
+ if (var->initializer) {
+ current_function_.push_inst(spv::Op::OpStore, {id, Value(var->initializer)});
+ }
+ } else {
+ TINT_ICE(Writer, diagnostics_)
+ << "unimplemented variable address space " << ptr->AddressSpace();
+ return 0u;
+ }
+
+ // Set the name if present.
+ if (auto name = ir_->NameOf(var)) {
+ module_.PushDebug(spv::Op::OpName, {id, Operand(name.Name())});
+ }
+
+ return id;
+}
+
} // namespace tint::writer::spirv
diff --git a/src/tint/writer/spirv/generator_impl_ir.h b/src/tint/writer/spirv/ir/generator_impl_ir.h
similarity index 85%
rename from src/tint/writer/spirv/generator_impl_ir.h
rename to src/tint/writer/spirv/ir/generator_impl_ir.h
index e885edb..66ffe48 100644
--- a/src/tint/writer/spirv/generator_impl_ir.h
+++ b/src/tint/writer/spirv/ir/generator_impl_ir.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#ifndef SRC_TINT_WRITER_SPIRV_GENERATOR_IMPL_IR_H_
-#define SRC_TINT_WRITER_SPIRV_GENERATOR_IMPL_IR_H_
+#ifndef SRC_TINT_WRITER_SPIRV_IR_GENERATOR_IMPL_IR_H_
+#define SRC_TINT_WRITER_SPIRV_IR_GENERATOR_IMPL_IR_H_
#include <vector>
@@ -30,9 +30,12 @@
namespace tint::ir {
class Binary;
class Block;
+class If;
class Function;
class Module;
+class Store;
class Value;
+class Var;
} // namespace tint::ir
namespace tint::type {
class Type;
@@ -47,7 +50,7 @@
/// @param module the Tint IR module to generate
/// @param zero_init_workgroup_memory `true` to initialize all the variables in the Workgroup
/// storage class with OpConstantNull
- GeneratorImplIr(const ir::Module* module, bool zero_init_workgroup_memory);
+ GeneratorImplIr(ir::Module* module, bool zero_init_workgroup_memory);
/// @returns true on successful generation; false otherwise
bool Generate();
@@ -76,6 +79,11 @@
/// @returns the result ID of the value
uint32_t Value(const ir::Value* value);
+ /// Get the ID of the label for `block`.
+ /// @param block the block to get the label ID for
+ /// @returns the ID of the block's label
+ uint32_t Label(const ir::Block* block);
+
/// Emit a function.
/// @param func the function to emit
void EmitFunction(const ir::Function* func);
@@ -89,18 +97,31 @@
/// @param block the block to emit
void EmitBlock(const ir::Block* block);
+ /// Emit an `if` flow node.
+ /// @param i the if node to emit
+ void EmitIf(const ir::If* i);
+
/// Emit a binary instruction.
/// @param binary the binary instruction to emit
/// @returns the result ID of the instruction
uint32_t EmitBinary(const ir::Binary* binary);
+ /// Emit a store instruction.
+ /// @param store the store instruction to emit
+ void EmitStore(const ir::Store* store);
+
+ /// Emit a var instruction.
+ /// @param var the var instruction to emit
+ /// @returns the result ID of the instruction
+ uint32_t EmitVar(const ir::Var* var);
+
private:
/// Get the result ID of the constant `constant`, emitting its instruction if necessary.
/// @param constant the constant to get the ID for
/// @returns the result ID of the constant
uint32_t Constant(const constant::Value* constant);
- const ir::Module* ir_;
+ ir::Module* ir_;
spirv::Module module_;
BinaryWriter writer_;
diag::List diagnostics_;
@@ -161,6 +182,9 @@
/// The map of instructions to their result IDs.
utils::Hashmap<const ir::Instruction*, uint32_t, 8> instructions_;
+ /// The map of blocks to the IDs of their label instructions.
+ utils::Hashmap<const ir::Block*, uint32_t, 8> block_labels_;
+
/// The current function that is being emitted.
Function current_function_;
@@ -169,4 +193,4 @@
} // namespace tint::writer::spirv
-#endif // SRC_TINT_WRITER_SPIRV_GENERATOR_IMPL_IR_H_
+#endif // SRC_TINT_WRITER_SPIRV_IR_GENERATOR_IMPL_IR_H_
diff --git a/src/tint/writer/spirv/generator_impl_binary_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc
similarity index 98%
rename from src/tint/writer/spirv/generator_impl_binary_test.cc
rename to src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc
index 4bf6b81..30dc059 100644
--- a/src/tint/writer/spirv/generator_impl_binary_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/writer/spirv/test_helper_ir.h"
+#include "src/tint/writer/spirv/ir/test_helper_ir.h"
using namespace tint::number_suffixes; // NOLINT
diff --git a/src/tint/writer/spirv/generator_impl_constant_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_constant_test.cc
similarity index 98%
rename from src/tint/writer/spirv/generator_impl_constant_test.cc
rename to src/tint/writer/spirv/ir/generator_impl_ir_constant_test.cc
index 1e27e89..95fce03 100644
--- a/src/tint/writer/spirv/generator_impl_constant_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_constant_test.cc
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/writer/spirv/test_helper_ir.h"
+#include "src/tint/writer/spirv/ir/test_helper_ir.h"
namespace tint::writer::spirv {
namespace {
diff --git a/src/tint/writer/spirv/generator_impl_function_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_function_test.cc
similarity index 98%
rename from src/tint/writer/spirv/generator_impl_function_test.cc
rename to src/tint/writer/spirv/ir/generator_impl_ir_function_test.cc
index b8faebb..77b4a62 100644
--- a/src/tint/writer/spirv/generator_impl_function_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_function_test.cc
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/writer/spirv/test_helper_ir.h"
+#include "src/tint/writer/spirv/ir/test_helper_ir.h"
namespace tint::writer::spirv {
namespace {
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_if_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_if_test.cc
new file mode 100644
index 0000000..7b41184
--- /dev/null
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_if_test.cc
@@ -0,0 +1,149 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/writer/spirv/ir/test_helper_ir.h"
+
+using namespace tint::number_suffixes; // NOLINT
+
+namespace tint::writer::spirv {
+namespace {
+
+TEST_F(SpvGeneratorImplTest, If_TrueEmpty_FalseEmpty) {
+ auto* func = b.CreateFunction(mod.symbols.Register("foo"), mod.types.Get<type::Void>());
+
+ auto* i = b.CreateIf(b.Constant(true));
+ b.Branch(i->true_.target->As<ir::Block>(), i->merge.target);
+ b.Branch(i->false_.target->As<ir::Block>(), i->merge.target);
+ b.Branch(i->merge.target->As<ir::Block>(), func->end_target);
+
+ b.Branch(func->start_target, i);
+
+ generator_.EmitFunction(func);
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%7 = OpTypeBool
+%6 = OpConstantTrue %7
+%1 = OpFunction %2 None %3
+%4 = OpLabel
+OpSelectionMerge %5 None
+OpBranchConditional %6 %5 %5
+%5 = OpLabel
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest, If_FalseEmpty) {
+ auto* func = b.CreateFunction(mod.symbols.Register("foo"), mod.types.Get<type::Void>());
+
+ auto* i = b.CreateIf(b.Constant(true));
+ b.Branch(i->false_.target->As<ir::Block>(), i->merge.target);
+ b.Branch(i->merge.target->As<ir::Block>(), func->end_target);
+
+ auto* true_block = i->true_.target->As<ir::Block>();
+ true_block->instructions.Push(
+ b.Add(mod.types.Get<type::I32>(), b.Constant(1_i), b.Constant(1_i)));
+ b.Branch(true_block, i->merge.target);
+
+ b.Branch(func->start_target, i);
+
+ generator_.EmitFunction(func);
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%8 = OpTypeBool
+%7 = OpConstantTrue %8
+%10 = OpTypeInt 32 1
+%11 = OpConstant %10 1
+%1 = OpFunction %2 None %3
+%4 = OpLabel
+OpSelectionMerge %5 None
+OpBranchConditional %7 %6 %5
+%6 = OpLabel
+%9 = OpIAdd %10 %11 %11
+OpBranch %5
+%5 = OpLabel
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest, If_TrueEmpty) {
+ auto* func = b.CreateFunction(mod.symbols.Register("foo"), mod.types.Get<type::Void>());
+
+ auto* i = b.CreateIf(b.Constant(true));
+ b.Branch(i->true_.target->As<ir::Block>(), i->merge.target);
+ b.Branch(i->merge.target->As<ir::Block>(), func->end_target);
+
+ auto* false_block = i->false_.target->As<ir::Block>();
+ false_block->instructions.Push(
+ b.Add(mod.types.Get<type::I32>(), b.Constant(1_i), b.Constant(1_i)));
+ b.Branch(false_block, i->merge.target);
+
+ b.Branch(func->start_target, i);
+
+ generator_.EmitFunction(func);
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%8 = OpTypeBool
+%7 = OpConstantTrue %8
+%10 = OpTypeInt 32 1
+%11 = OpConstant %10 1
+%1 = OpFunction %2 None %3
+%4 = OpLabel
+OpSelectionMerge %5 None
+OpBranchConditional %7 %5 %6
+%6 = OpLabel
+%9 = OpIAdd %10 %11 %11
+OpBranch %5
+%5 = OpLabel
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest, If_BothBranchesReturn) {
+ auto* func = b.CreateFunction(mod.symbols.Register("foo"), mod.types.Get<type::Void>());
+
+ auto* i = b.CreateIf(b.Constant(true));
+ b.Branch(i->true_.target->As<ir::Block>(), func->end_target);
+ b.Branch(i->false_.target->As<ir::Block>(), func->end_target);
+ i->merge.target->As<ir::Block>()->branch.target = nullptr;
+
+ b.Branch(func->start_target, i);
+
+ generator_.EmitFunction(func);
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%9 = OpTypeBool
+%8 = OpConstantTrue %9
+%1 = OpFunction %2 None %3
+%4 = OpLabel
+OpSelectionMerge %5 None
+OpBranchConditional %8 %6 %7
+%6 = OpLabel
+OpReturn
+%7 = OpLabel
+OpReturn
+%5 = OpLabel
+OpUnreachable
+OpFunctionEnd
+)");
+}
+
+} // namespace
+} // namespace tint::writer::spirv
diff --git a/src/tint/writer/spirv/generator_impl_ir_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_test.cc
similarity index 85%
rename from src/tint/writer/spirv/generator_impl_ir_test.cc
rename to src/tint/writer/spirv/ir/generator_impl_ir_test.cc
index a202eea..ffafa60 100644
--- a/src/tint/writer/spirv/generator_impl_ir_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_test.cc
@@ -12,7 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/writer/spirv/test_helper_ir.h"
+#include "src/tint/writer/spirv/ir/test_helper_ir.h"
+
+#include "gmock/gmock.h"
namespace tint::writer::spirv {
namespace {
@@ -20,9 +22,9 @@
TEST_F(SpvGeneratorImplTest, ModuleHeader) {
ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
auto got = Disassemble(generator_.Result());
- EXPECT_EQ(got, R"(OpCapability Shader
+ EXPECT_THAT(got, testing::StartsWith(R"(OpCapability Shader
OpMemoryModel Logical GLSL450
-)");
+)"));
}
} // namespace
diff --git a/src/tint/writer/spirv/generator_impl_type_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_type_test.cc
similarity index 98%
rename from src/tint/writer/spirv/generator_impl_type_test.cc
rename to src/tint/writer/spirv/ir/generator_impl_ir_type_test.cc
index c86d363..38bf1f0 100644
--- a/src/tint/writer/spirv/generator_impl_type_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_type_test.cc
@@ -19,7 +19,7 @@
#include "src/tint/type/type.h"
#include "src/tint/type/u32.h"
#include "src/tint/type/void.h"
-#include "src/tint/writer/spirv/test_helper_ir.h"
+#include "src/tint/writer/spirv/ir/test_helper_ir.h"
namespace tint::writer::spirv {
namespace {
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc
new file mode 100644
index 0000000..ff94862
--- /dev/null
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc
@@ -0,0 +1,168 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/type/pointer.h"
+#include "src/tint/writer/spirv/ir/test_helper_ir.h"
+
+using namespace tint::number_suffixes; // NOLINT
+
+namespace tint::writer::spirv {
+namespace {
+
+TEST_F(SpvGeneratorImplTest, FunctionVar_NoInit) {
+ auto* func = b.CreateFunction(mod.symbols.Register("foo"), mod.types.Get<type::Void>());
+ b.Branch(func->start_target, func->end_target);
+
+ auto* ty = mod.types.Get<type::Pointer>(
+ mod.types.Get<type::I32>(), builtin::AddressSpace::kFunction, builtin::Access::kReadWrite);
+ auto* v = b.Declare(ty);
+ func->start_target->instructions.Push(v);
+
+ generator_.EmitFunction(func);
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%7 = OpTypeInt 32 1
+%6 = OpTypePointer Function %7
+%1 = OpFunction %2 None %3
+%4 = OpLabel
+%5 = OpVariable %6 Function
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest, FunctionVar_WithInit) {
+ auto* func = b.CreateFunction(mod.symbols.Register("foo"), mod.types.Get<type::Void>());
+ b.Branch(func->start_target, func->end_target);
+
+ auto* ty = mod.types.Get<type::Pointer>(
+ mod.types.Get<type::I32>(), builtin::AddressSpace::kFunction, builtin::Access::kReadWrite);
+ auto* v = b.Declare(ty);
+ func->start_target->instructions.Push(v);
+ v->initializer = b.Constant(42_i);
+
+ generator_.EmitFunction(func);
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%7 = OpTypeInt 32 1
+%6 = OpTypePointer Function %7
+%8 = OpConstant %7 42
+%1 = OpFunction %2 None %3
+%4 = OpLabel
+%5 = OpVariable %6 Function
+OpStore %5 %8
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest, FunctionVar_Name) {
+ auto* func = b.CreateFunction(mod.symbols.Register("foo"), mod.types.Get<type::Void>());
+ b.Branch(func->start_target, func->end_target);
+
+ auto* ty = mod.types.Get<type::Pointer>(
+ mod.types.Get<type::I32>(), builtin::AddressSpace::kFunction, builtin::Access::kReadWrite);
+ auto* v = b.Declare(ty);
+ func->start_target->instructions.Push(v);
+ mod.SetName(v, "myvar");
+
+ generator_.EmitFunction(func);
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
+OpName %5 "myvar"
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%7 = OpTypeInt 32 1
+%6 = OpTypePointer Function %7
+%1 = OpFunction %2 None %3
+%4 = OpLabel
+%5 = OpVariable %6 Function
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest, FunctionVar_DeclInsideBlock) {
+ auto* func = b.CreateFunction(mod.symbols.Register("foo"), mod.types.Get<type::Void>());
+ b.Branch(func->start_target, func->end_target);
+
+ auto* ty = mod.types.Get<type::Pointer>(
+ mod.types.Get<type::I32>(), builtin::AddressSpace::kFunction, builtin::Access::kReadWrite);
+ auto* v = b.Declare(ty);
+ v->initializer = b.Constant(42_i);
+
+ auto* i = b.CreateIf(b.Constant(true));
+ b.Branch(i->false_.target->As<ir::Block>(), func->end_target);
+ b.Branch(i->merge.target->As<ir::Block>(), func->end_target);
+
+ auto* true_block = i->true_.target->As<ir::Block>();
+ true_block->instructions.Push(v);
+ b.Branch(true_block, i->merge.target);
+
+ b.Branch(func->start_target, i);
+
+ generator_.EmitFunction(func);
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%9 = OpTypeBool
+%8 = OpConstantTrue %9
+%12 = OpTypeInt 32 1
+%11 = OpTypePointer Function %12
+%13 = OpConstant %12 42
+%1 = OpFunction %2 None %3
+%4 = OpLabel
+%10 = OpVariable %11 Function
+OpSelectionMerge %5 None
+OpBranchConditional %8 %6 %7
+%6 = OpLabel
+OpStore %10 %13
+OpBranch %5
+%7 = OpLabel
+OpReturn
+%5 = OpLabel
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest, FunctionVar_Store) {
+ auto* func = b.CreateFunction(mod.symbols.Register("foo"), mod.types.Get<type::Void>());
+ b.Branch(func->start_target, func->end_target);
+
+ auto* ty = mod.types.Get<type::Pointer>(
+ mod.types.Get<type::I32>(), builtin::AddressSpace::kFunction, builtin::Access::kReadWrite);
+ auto* v = b.Declare(ty);
+ func->start_target->instructions.Push(v);
+ func->start_target->instructions.Push(b.Store(v, b.Constant(42_i)));
+
+ generator_.EmitFunction(func);
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%7 = OpTypeInt 32 1
+%6 = OpTypePointer Function %7
+%8 = OpConstant %7 42
+%1 = OpFunction %2 None %3
+%4 = OpLabel
+%5 = OpVariable %6 Function
+OpStore %5 %8
+OpReturn
+OpFunctionEnd
+)");
+}
+
+} // namespace
+} // namespace tint::writer::spirv
diff --git a/src/tint/writer/spirv/test_helper_ir.h b/src/tint/writer/spirv/ir/test_helper_ir.h
similarity index 87%
rename from src/tint/writer/spirv/test_helper_ir.h
rename to src/tint/writer/spirv/ir/test_helper_ir.h
index 3574645..9509b42 100644
--- a/src/tint/writer/spirv/test_helper_ir.h
+++ b/src/tint/writer/spirv/ir/test_helper_ir.h
@@ -12,14 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#ifndef SRC_TINT_WRITER_SPIRV_TEST_HELPER_IR_H_
-#define SRC_TINT_WRITER_SPIRV_TEST_HELPER_IR_H_
+#ifndef SRC_TINT_WRITER_SPIRV_IR_TEST_HELPER_IR_H_
+#define SRC_TINT_WRITER_SPIRV_IR_TEST_HELPER_IR_H_
#include <string>
#include "gtest/gtest.h"
#include "src/tint/ir/builder.h"
-#include "src/tint/writer/spirv/generator_impl_ir.h"
+#include "src/tint/writer/spirv/ir/generator_impl_ir.h"
#include "src/tint/writer/spirv/spv_dump.h"
namespace tint::writer::spirv {
@@ -50,4 +50,4 @@
} // namespace tint::writer::spirv
-#endif // SRC_TINT_WRITER_SPIRV_TEST_HELPER_IR_H_
+#endif // SRC_TINT_WRITER_SPIRV_IR_TEST_HELPER_IR_H_