Import Tint changes from Dawn
Changes:
- 20d20f6f3f846793ddcec16fbed60a2bb7e97875 Allow @index Attribute On Variables by Brandon Jones <brandon1.jones@intel.com>
- 54fc2a84e91ad8b89dce909922beb771b343c25d Recognize @index attribute when parsing by Brandon Jones <brandon1.jones@intel.com>
- 4a38f0a803b9d2f653a39970b7fc187ca1becc1d Add @index Attribute to WGSL Writer by Brandon Jones <brandon1.jones@intel.com>
- 3d9c5563be46145f153dbf8eec6afdd273e41580 [tint][ir] Swap order of first two BreakIf() params by Ben Clayton <bclayton@google.com>
- 6130722a8173ec3d2d697faf4b52cb3055557521 [ir][spirv-writer] Add helper macro for unit tests by James Price <jrprice@google.com>
- 407137aa46ed7b69e760a246a6d480e58fc4135c [tint][utils] Add bounds assertions to vector by Ben Clayton <bclayton@google.com>
- 5735a675eb5fb259c0361cc06638d0a3ba994611 [tint][utils] Update doxygen for EnumSet by Ben Clayton <bclayton@google.com>
- e04687a3e660e119ff2074511f81ab5f8ee03297 [tint][ir][ToProgram] Inline values respecting ordering by Ben Clayton <bclayton@google.com>
- 3e65d908e9f82ea5e7dece9704af7ec00d2b28c8 [ir][spirv-writer] Rework unit testing by James Price <jrprice@google.com>
- 084e2fdb43a5793f3de800c638b6970fae8947d1 [ir][validation] Add Unary validation by dan sinclair <dsinclair@chromium.org>
- 027636e49941a54e8f0dd0a00f4104b441eb88ab [ir][validate] Extract operand nullptr checks by dan sinclair <dsinclair@chromium.org>
- 9a51768e568ac1c0efe2d5c29f6a4b0142f43844 [ir][validate] Check functions only added to module once by dan sinclair <dsinclair@chromium.org>
- 7686852fd55ed56f41074655a3aebdc408cf9334 [ir][validate] Improve result error messages by dan sinclair <dsinclair@chromium.org>
- 54482355a5f020766c87967d518d64b21b768189 [ir] Add Builder::Var overload with name by James Price <jrprice@google.com>
- ce882c00b218849f5bcadd2ba904362c123ec0e4 [tint][ir] Add EnumSet flags to Value and Instruction by Ben Clayton <bclayton@google.com>
- 610e4e627af6a894cc148cef61b04b2136252953 [ir] Make Builder::Function add the function by James Price <jrprice@google.com>
- 25b514646cd39d56c17bd25a361973aa4ab98d12 [ir] Add Builder::FunctionParam overload with name by James Price <jrprice@google.com>
- 9e819bf9ea71646ec75e0d82d65c22e9a7d7cbdf [ir][spirv-writer] Handle shader IO by James Price <jrprice@google.com>
- 4765e38cdc27112e4b1f06cea35acac4f66ab652 [ir][validation] Update binary tests to mark undef operands by dan sinclair <dsinclair@chromium.org>
- 7257c563dd3855bf493b9d426d8e4231a4248be1 [ir][msl] Split long emit type switch. by dan sinclair <dsinclair@chromium.org>
- 4f13d3116b40935cb79ed3dc717cc71bcb21fd6e [tint][ir] Refactor IRToProgramTest class by Ben Clayton <bclayton@google.com>
- 0306088ebf79036a5b86669795e634c5a9c6956c [tint][utils] Add EnumSet::Set() by Ben Clayton <bclayton@google.com>
- 5ef873a519434764d1fce835b9bdc0ce64332bb7 [tint][utils] Add UniqueVector::Erase by Ben Clayton <bclayton@google.com>
- 607e241ba075f0de26793d52fc4c9b7de20f5632 [tint][utils] Add Vector::Erase by Ben Clayton <bclayton@google.com>
- 45b59a8ba039f3df861bc63d5e02201884eb7932 [ir][msl] Emit struct constants by dan sinclair <dsinclair@chromium.org>
- a356526632a54a28e93bc5d21f4130e230bd1fe3 [ir][msl] Cleanup duplicate emission code. by dan sinclair <dsinclair@chromium.org>
- 676745988943cc9894b7e53dc5fea99ff90e1fa6 [ir][msl] Emit array constants by dan sinclair <dsinclair@chromium.org>
- ae33f9722957ab97abde31a734b59b5c96dad494 [ir][msl] Emit matrix constants by dan sinclair <dsinclair@chromium.org>
- c8a5cf8713eb32eb6a78a7d74f7075456deb0d64 [ir][msl] Emit vector constants types by dan sinclair <dsinclair@chromium.org>
GitOrigin-RevId: 20d20f6f3f846793ddcec16fbed60a2bb7e97875
Change-Id: I18caa4521c6cdb96ee03707ffe202e9d81c71e58
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/139380
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index 83d7421..a315fd4 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -523,6 +523,10 @@
"ir/transform/block_decorated_structs.h",
"ir/transform/merge_return.cc",
"ir/transform/merge_return.h",
+ "ir/transform/shader_io.cc",
+ "ir/transform/shader_io.h",
+ "ir/transform/shader_io_spirv.cc",
+ "ir/transform/shader_io_spirv.h",
"ir/transform/var_for_dynamic_index.cc",
"ir/transform/var_for_dynamic_index.h",
]
@@ -1866,6 +1870,7 @@
"ir/transform/add_empty_entry_point_test.cc",
"ir/transform/block_decorated_structs_test.cc",
"ir/transform/merge_return_test.cc",
+ "ir/transform/shader_io_test.cc",
"ir/transform/test_helper.h",
"ir/transform/var_for_dynamic_index_test.cc",
]
@@ -2398,6 +2403,7 @@
"ir/store_test.cc",
"ir/switch_test.cc",
"ir/swizzle_test.cc",
+ "ir/to_program_inlining_test.cc",
"ir/to_program_roundtrip_test.cc",
"ir/to_program_test.cc",
"ir/unary_test.cc",
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index 0ddae38..967cef2 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -832,6 +832,10 @@
ir/transform/block_decorated_structs.h
ir/transform/merge_return.cc
ir/transform/merge_return.h
+ ir/transform/shader_io.cc
+ ir/transform/shader_io.h
+ ir/transform/shader_io_spirv.cc
+ ir/transform/shader_io_spirv.h
ir/transform/transform.cc
ir/transform/transform.h
ir/transform/var_for_dynamic_index.cc
@@ -1586,6 +1590,7 @@
ir/transform/add_empty_entry_point_test.cc
ir/transform/block_decorated_structs_test.cc
ir/transform/merge_return_test.cc
+ ir/transform/shader_io_test.cc
ir/transform/var_for_dynamic_index_test.cc
ir/unary_test.cc
ir/user_call_test.cc
@@ -1603,6 +1608,7 @@
if (${TINT_BUILD_IR} AND ${TINT_BUILD_WGSL_WRITER})
list(APPEND TINT_TEST_SRCS
+ ir/to_program_inlining_test.cc
ir/to_program_test.cc
)
endif()
diff --git a/src/tint/ir/block_test.cc b/src/tint/ir/block_test.cc
index 9f5fbfa..79ae399 100644
--- a/src/tint/ir/block_test.cc
+++ b/src/tint/ir/block_test.cc
@@ -36,7 +36,7 @@
TEST_F(IR_BlockTest, HasTerminator_BreakIf) {
auto* blk = b.Block();
auto* loop = b.Loop();
- blk->Append(b.BreakIf(true, loop));
+ blk->Append(b.BreakIf(loop, true));
EXPECT_TRUE(blk->HasTerminator());
}
diff --git a/src/tint/ir/break_if_test.cc b/src/tint/ir/break_if_test.cc
index 2ce17ae..506d93f 100644
--- a/src/tint/ir/break_if_test.cc
+++ b/src/tint/ir/break_if_test.cc
@@ -30,7 +30,7 @@
auto* arg1 = b.Constant(1_u);
auto* arg2 = b.Constant(2_u);
- auto* brk = b.BreakIf(cond, loop, arg1, arg2);
+ auto* brk = b.BreakIf(loop, cond, arg1, arg2);
EXPECT_THAT(cond->Usages(), testing::UnorderedElementsAre(Usage{brk, 0u}));
EXPECT_THAT(arg1->Usages(), testing::UnorderedElementsAre(Usage{brk, 1u}));
@@ -43,7 +43,7 @@
auto* arg1 = b.Constant(1_u);
auto* arg2 = b.Constant(2_u);
- auto* brk = b.BreakIf(cond, loop, arg1, arg2);
+ auto* brk = b.BreakIf(loop, cond, arg1, arg2);
EXPECT_FALSE(brk->HasResults());
EXPECT_FALSE(brk->HasMultiResults());
}
@@ -53,7 +53,7 @@
{
Module mod;
Builder b{mod};
- b.BreakIf(true, nullptr);
+ b.BreakIf(nullptr, true);
},
"");
}
diff --git a/src/tint/ir/builder.cc b/src/tint/ir/builder.cc
index 2b3c9a4..6049a85 100644
--- a/src/tint/ir/builder.cc
+++ b/src/tint/ir/builder.cc
@@ -50,6 +50,7 @@
auto* ir_func = ir.values.Create<ir::Function>(return_type, stage, wg_size);
ir_func->SetBlock(Block());
ir.SetName(ir_func, name);
+ ir.functions.Push(ir_func);
return ir_func;
}
@@ -76,6 +77,12 @@
return Append(ir.instructions.Create<ir::Var>(InstructionResult(type)));
}
+ir::Var* Builder::Var(std::string_view name, const type::Pointer* type) {
+ auto* var = Var(type);
+ ir.SetName(var, name);
+ return var;
+}
+
ir::BlockParam* Builder::BlockParam(const type::Type* type) {
return ir.values.Create<ir::BlockParam>(type);
}
@@ -84,6 +91,12 @@
return ir.values.Create<ir::FunctionParam>(type);
}
+ir::FunctionParam* Builder::FunctionParam(std::string_view name, const type::Type* type) {
+ auto* param = ir.values.Create<ir::FunctionParam>(type);
+ ir.SetName(param, name);
+ return param;
+}
+
ir::Unreachable* Builder::Unreachable() {
return Append(ir.instructions.Create<ir::Unreachable>());
}
diff --git a/src/tint/ir/builder.h b/src/tint/ir/builder.h
index 3ad87a6..95c6df0 100644
--- a/src/tint/ir/builder.h
+++ b/src/tint/ir/builder.h
@@ -17,7 +17,9 @@
#include <utility>
+#include "src/tint/constant/composite.h"
#include "src/tint/constant/scalar.h"
+#include "src/tint/constant/splat.h"
#include "src/tint/ir/access.h"
#include "src/tint/ir/binary.h"
#include "src/tint/ir/bitcast.h"
@@ -559,6 +561,12 @@
/// @returns the instruction
ir::Var* Var(const type::Pointer* type);
+ /// Creates a new `var` declaration with a name
+ /// @param name the var name
+ /// @param type the var type
+ /// @returns the instruction
+ ir::Var* Var(std::string_view name, const type::Pointer* type);
+
/// Creates a return instruction
/// @param func the function being returned
/// @returns the instruction
@@ -572,6 +580,11 @@
/// @returns the instruction
template <typename ARG>
ir::Return* Return(ir::Function* func, ARG&& value) {
+ if constexpr (std::is_same_v<std::decay_t<ARG>, ir::Value*>) {
+ if (value == nullptr) {
+ return Append(ir.instructions.Create<ir::Return>(func));
+ }
+ }
return Append(ir.instructions.Create<ir::Return>(func, Value(std::forward<ARG>(value))));
}
@@ -591,7 +604,7 @@
/// @param args the arguments for the target MultiInBlock
/// @returns the instruction
template <typename CONDITION, typename... ARGS>
- ir::BreakIf* BreakIf(CONDITION&& condition, ir::Loop* loop, ARGS&&... args) {
+ ir::BreakIf* BreakIf(ir::Loop* loop, CONDITION&& condition, ARGS&&... args) {
return Append(ir.instructions.Create<ir::BreakIf>(
Value(std::forward<CONDITION>(condition)), loop, Values(std::forward<ARGS>(args)...)));
}
@@ -658,6 +671,12 @@
/// @returns the value
ir::FunctionParam* FunctionParam(const type::Type* type);
+ /// Creates a new `FunctionParam` with a name.
+ /// @param name the parameter name
+ /// @param type the parameter type
+ /// @returns the value
+ ir::FunctionParam* FunctionParam(std::string_view name, const type::Type* type);
+
/// Creates a new `Access`
/// @param type the return type
/// @param object the object being accessed
diff --git a/src/tint/ir/call.cc b/src/tint/ir/call.cc
index ead6124..cacdaa4 100644
--- a/src/tint/ir/call.cc
+++ b/src/tint/ir/call.cc
@@ -20,7 +20,9 @@
namespace tint::ir {
-Call::Call() = default;
+Call::Call() {
+ flags_.Add(Flag::kSequenced);
+}
Call::~Call() = default;
diff --git a/src/tint/ir/control_instruction.cc b/src/tint/ir/control_instruction.cc
index 6451622..a5c3f0f 100644
--- a/src/tint/ir/control_instruction.cc
+++ b/src/tint/ir/control_instruction.cc
@@ -18,6 +18,10 @@
namespace tint::ir {
+ControlInstruction::ControlInstruction() {
+ flags_.Add(Flag::kSequenced);
+}
+
ControlInstruction::~ControlInstruction() = default;
void ControlInstruction::AddExit(Exit* exit) {
diff --git a/src/tint/ir/control_instruction.h b/src/tint/ir/control_instruction.h
index 701e032..30fbf4d 100644
--- a/src/tint/ir/control_instruction.h
+++ b/src/tint/ir/control_instruction.h
@@ -31,6 +31,9 @@
/// ControlInstruction.
class ControlInstruction : public utils::Castable<ControlInstruction, OperandInstruction<1, 1>> {
public:
+ /// Constructor
+ ControlInstruction();
+
/// Destructor
~ControlInstruction() override;
diff --git a/src/tint/ir/disassembler.cc b/src/tint/ir/disassembler.cc
index f8e301f..38e9be4 100644
--- a/src/tint/ir/disassembler.cc
+++ b/src/tint/ir/disassembler.cc
@@ -294,11 +294,13 @@
}
void Disassembler::EmitValueWithType(Instruction* val) {
+ SourceMarker sm(this);
if (val->Result()) {
EmitValueWithType(val->Result());
} else {
out_ << "undef";
}
+ sm.StoreResult(Usage{val, 0});
}
void Disassembler::EmitValueWithType(Value* val) {
@@ -371,6 +373,13 @@
}
void Disassembler::EmitInstruction(Instruction* inst) {
+ if (!inst->Alive()) {
+ SourceMarker sm(this);
+ out_ << "<destroyed " << inst->TypeInfo().name << " " << utils::ToString(inst) << ">";
+ sm.Store(inst);
+ EmitLine();
+ return;
+ }
tint::Switch(
inst, //
[&](Switch* s) { EmitSwitch(s); }, //
@@ -512,46 +521,54 @@
}
}
-void Disassembler::EmitIf(If* i) {
+void Disassembler::EmitIf(If* if_) {
SourceMarker sm(this);
- if (i->Result()) {
- EmitValueWithType(i->Result());
+ if (if_->HasResults()) {
+ auto res = if_->Results();
+ for (size_t i = 0; i < res.Length(); ++i) {
+ if (i > 0) {
+ out_ << ", ";
+ }
+ SourceMarker rs(this);
+ EmitValueWithType(res[i]);
+ rs.StoreResult(Usage{if_, i});
+ }
out_ << " = ";
}
out_ << "if ";
- EmitOperand(i, i->Condition(), If::kConditionOperandOffset);
+ EmitOperand(if_, if_->Condition(), If::kConditionOperandOffset);
- bool has_true = !i->True()->IsEmpty();
- bool has_false = !i->False()->IsEmpty();
+ bool has_true = !if_->True()->IsEmpty();
+ bool has_false = !if_->False()->IsEmpty();
out_ << " [";
if (has_true) {
- out_ << "t: %b" << IdOf(i->True());
+ out_ << "t: %b" << IdOf(if_->True());
}
if (has_false) {
if (has_true) {
out_ << ", ";
}
- out_ << "f: %b" << IdOf(i->False());
+ out_ << "f: %b" << IdOf(if_->False());
}
out_ << "]";
- sm.Store(i);
+ sm.Store(if_);
- out_ << " { # " << NameOf(i);
+ out_ << " { # " << NameOf(if_);
EmitLine();
if (has_true) {
ScopedIndent si(indent_size_);
- EmitBlock(i->True(), "true");
+ EmitBlock(if_->True(), "true");
}
if (has_false) {
ScopedIndent si(indent_size_);
- EmitBlock(i->False(), "false");
- } else if (i->HasResults()) {
+ EmitBlock(if_->False(), "false");
+ } else if (if_->HasResults()) {
ScopedIndent si(indent_size_);
Indent();
out_ << "# implicit false block: exit_if undef";
- for (size_t v = 1; v < i->Results().Length(); v++) {
+ for (size_t v = 1; v < if_->Results().Length(); v++) {
out_ << ", undef";
}
EmitLine();
@@ -738,9 +755,9 @@
break;
}
out_ << " ";
- EmitValue(b->LHS());
+ EmitOperand(b, b->LHS(), Binary::kLhsOperandOffset);
out_ << ", ";
- EmitValue(b->RHS());
+ EmitOperand(b, b->RHS(), Binary::kRhsOperandOffset);
sm.Store(b);
EmitLine();
@@ -759,7 +776,7 @@
break;
}
out_ << " ";
- EmitValue(u->Val());
+ EmitOperand(u, u->Val(), Unary::kValueOperandOffset);
sm.Store(u);
EmitLine();
diff --git a/src/tint/ir/disassembler.h b/src/tint/ir/disassembler.h
index 940f7f6..4a0e495 100644
--- a/src/tint/ir/disassembler.h
+++ b/src/tint/ir/disassembler.h
@@ -61,6 +61,10 @@
/// @returns the source for the operand
Source OperandSource(Usage operand) { return operand_to_src_.Get(operand).value_or(Source{}); }
+ /// @param result the result to retrieve
+ /// @returns the source for the result
+ Source ResultSource(Usage result) { return result_to_src_.Get(result).value_or(Source{}); }
+
/// @param blk teh block to retrieve
/// @returns the source for the block
Source BlockSource(Block* blk) { return block_to_src_.Get(blk).value_or(Source{}); }
@@ -80,6 +84,11 @@
/// @param src the source location
void SetSource(Usage op, Source src) { operand_to_src_.Add(op, src); }
+ /// Stores the given @p src location for @p result
+ /// @param result the result to store
+ /// @param src the source location
+ void SetResultSource(Usage result, Source src) { result_to_src_.Add(result, src); }
+
/// @returns the source location for the current emission location
Source::Location MakeCurrentLocation();
@@ -95,6 +104,8 @@
void Store(Usage operand) { dis_->SetSource(operand, MakeSource()); }
+ void StoreResult(Usage result) { dis_->SetResultSource(result, MakeSource()); }
+
Source MakeSource() const {
return Source(Source::Range(begin_, dis_->MakeCurrentLocation()));
}
@@ -151,6 +162,7 @@
utils::Hashmap<Block*, Source, 8> block_to_src_;
utils::Hashmap<Instruction*, Source, 8> instruction_to_src_;
utils::Hashmap<Usage, Source, 8, Usage::Hasher> operand_to_src_;
+ utils::Hashmap<Usage, Source, 8, Usage::Hasher> result_to_src_;
utils::Hashmap<If*, std::string, 8> if_names_;
utils::Hashmap<Loop*, std::string, 8> loop_names_;
utils::Hashmap<Switch*, std::string, 8> switch_names_;
diff --git a/src/tint/ir/from_program.cc b/src/tint/ir/from_program.cc
index b6c6586..69b80ea 100644
--- a/src/tint/ir/from_program.cc
+++ b/src/tint/ir/from_program.cc
@@ -263,7 +263,6 @@
auto* ir_func = builder_.Function(ast_func->name->symbol.NameView(),
sem->ReturnType()->Clone(clone_ctx_.type_ctx));
current_function_ = ir_func;
- builder_.ir.functions.Push(ir_func);
scopes_.Set(ast_func->name->symbol, ir_func);
@@ -338,7 +337,7 @@
for (auto* p : ast_func->params) {
const auto* param_sem = program_->Sem().Get(p)->As<sem::Parameter>();
auto* ty = param_sem->Type()->Clone(clone_ctx_.type_ctx);
- auto* param = builder_.FunctionParam(ty);
+ auto* param = builder_.FunctionParam(p->name->symbol.NameView(), ty);
// Note, interpolated is only valid when paired with Location, so it will only be set
// when the location is set.
@@ -412,7 +411,6 @@
}
scopes_.Set(p->name->symbol, param);
- builder_.ir.SetName(param, p->name->symbol.NameView());
params.Push(param);
}
ir_func->SetParams(params);
@@ -855,7 +853,7 @@
if (!cond) {
return;
}
- SetTerminator(builder_.BreakIf(cond.Get(), current_control->As<ir::Loop>()));
+ SetTerminator(builder_.BreakIf(current_control->As<ir::Loop>(), cond.Get()));
}
struct AccessorInfo {
diff --git a/src/tint/ir/function.h b/src/tint/ir/function.h
index 0aadc64..4310bc6 100644
--- a/src/tint/ir/function.h
+++ b/src/tint/ir/function.h
@@ -79,6 +79,9 @@
/// @param z the z size
void SetWorkgroupSize(uint32_t x, uint32_t y, uint32_t z) { workgroup_size_ = {x, y, z}; }
+ /// Clears the workgroup size.
+ void ClearWorkgroupSize() { workgroup_size_ = {}; }
+
/// @returns the workgroup size information
std::optional<std::array<uint32_t, 3>> WorkgroupSize() { return workgroup_size_; }
@@ -93,6 +96,8 @@
}
/// @returns the return builtin attribute
std::optional<enum ReturnBuiltin> ReturnBuiltin() { return return_.builtin; }
+ /// Clears the return builtin attribute.
+ void ClearReturnBuiltin() { return_.builtin = {}; }
/// Sets the return location
/// @param loc the location to set
@@ -102,6 +107,8 @@
}
/// @returns the return location
std::optional<Location> ReturnLocation() { return return_.location; }
+ /// Clears the return location attribute.
+ void ClearReturnLocation() { return_.location = {}; }
/// Sets the return as invariant
/// @param val the invariant value to set
diff --git a/src/tint/ir/function_param.h b/src/tint/ir/function_param.h
index 6a1e8ca..0c138bc 100644
--- a/src/tint/ir/function_param.h
+++ b/src/tint/ir/function_param.h
@@ -70,6 +70,8 @@
}
/// @returns the builtin set for the parameter
std::optional<FunctionParam::Builtin> Builtin() { return builtin_; }
+ /// Clears the builtin attribute.
+ void ClearBuiltin() { builtin_ = {}; }
/// Sets the parameter as invariant
/// @param val the value to set for invariant
@@ -85,13 +87,15 @@
}
/// @returns the location if `Attributes` contains `kLocation`
std::optional<struct Location> Location() { return location_; }
+ /// Clears the location attribute.
+ void ClearLocation() { location_ = {}; }
/// Sets the binding point
/// @param group the group
/// @param binding the binding
void SetBindingPoint(uint32_t group, uint32_t binding) { binding_point_ = {group, binding}; }
/// @returns the binding points if `Attributes` contains `kBindingPoint`
- std::optional<struct BindingPoint> BindingPoint() { return binding_point_; }
+ std::optional<struct BindingPoint>& BindingPoint() { return binding_point_; }
private:
const type::Type* type_ = nullptr;
diff --git a/src/tint/ir/instruction.cc b/src/tint/ir/instruction.cc
index e8bc12c..635302f 100644
--- a/src/tint/ir/instruction.cc
+++ b/src/tint/ir/instruction.cc
@@ -34,7 +34,7 @@
result->SetSource(nullptr);
result->Destroy();
}
- alive_ = false;
+ flags_.Add(Flag::kDead);
}
void Instruction::InsertBefore(Instruction* before) {
diff --git a/src/tint/ir/instruction.h b/src/tint/ir/instruction.h
index 8145258..5b6a940 100644
--- a/src/tint/ir/instruction.h
+++ b/src/tint/ir/instruction.h
@@ -18,6 +18,7 @@
#include "src/tint/ir/instruction_result.h"
#include "src/tint/ir/value.h"
#include "src/tint/utils/castable.h"
+#include "src/tint/utils/enum_set.h"
// Forward declarations
namespace tint::ir {
@@ -57,7 +58,11 @@
virtual void Destroy();
/// @returns true if the Instruction has not been destroyed with Destroy()
- bool Alive() const { return alive_; }
+ bool Alive() const { return !flags_.Contains(Flag::kDead); }
+
+ /// @returns true if the Instruction is sequenced. Sequenced instructions cannot be implicitly
+ /// reordered with other sequenced instructions.
+ bool Sequenced() const { return flags_.Contains(Flag::kSequenced); }
/// Sets the block that owns this instruction
/// @param block the new owner block
@@ -92,14 +97,22 @@
Instruction* prev = nullptr;
protected:
+ /// Flags applied to an Instruction
+ enum class Flag {
+ /// The instruction has been destroyed
+ kDead,
+ /// The instruction must not be reordered with another sequenced instruction
+ kSequenced,
+ };
+
/// Constructor
Instruction();
/// The block that owns this instruction
ir::Block* block_ = nullptr;
- private:
- bool alive_ = true;
+ /// Bitset of instruction flags
+ utils::EnumSet<Flag> flags_;
};
} // namespace tint::ir
diff --git a/src/tint/ir/load.cc b/src/tint/ir/load.cc
index 0918398..ac62f46 100644
--- a/src/tint/ir/load.cc
+++ b/src/tint/ir/load.cc
@@ -22,6 +22,8 @@
namespace tint::ir {
Load::Load(InstructionResult* result, Value* from) {
+ flags_.Add(Flag::kSequenced);
+
TINT_ASSERT(IR, from->Type()->Is<type::Pointer>());
TINT_ASSERT(IR, from && from->Type()->UnwrapPtr() == result->Type());
diff --git a/src/tint/ir/store.cc b/src/tint/ir/store.cc
index d1ea611..23b5f83 100644
--- a/src/tint/ir/store.cc
+++ b/src/tint/ir/store.cc
@@ -20,6 +20,8 @@
namespace tint::ir {
Store::Store(Value* to, Value* from) {
+ flags_.Add(Flag::kSequenced);
+
AddOperand(Store::kToOperandOffset, to);
AddOperand(Store::kFromOperandOffset, from);
}
diff --git a/src/tint/ir/to_program.cc b/src/tint/ir/to_program.cc
index e3a9604..3ca56a6 100644
--- a/src/tint/ir/to_program.cc
+++ b/src/tint/ir/to_program.cc
@@ -21,6 +21,7 @@
#include "src/tint/ir/access.h"
#include "src/tint/ir/binary.h"
#include "src/tint/ir/block.h"
+#include "src/tint/ir/break_if.h"
#include "src/tint/ir/call.h"
#include "src/tint/ir/constant.h"
#include "src/tint/ir/continue.h"
@@ -118,6 +119,9 @@
/// The current switch case block
ir::Block* current_switch_case_ = nullptr;
+ // Values that can be inlined.
+ utils::Hashset<ir::Value*, 64> can_inline_;
+
const ast::Function* Fn(ir::Function* fn) {
SCOPED_NESTING();
@@ -146,6 +150,7 @@
StatementList Statements(ir::Block* block) {
StatementList stmts;
if (block) {
+ MarkInlinable(block);
TINT_SCOPED_ASSIGNMENT(statements_, &stmts);
for (auto* inst : *block) {
Instruction(inst);
@@ -154,12 +159,80 @@
return stmts;
}
+ void MarkInlinable(ir::Block* block) {
+ // An ordered list of possibly-inlinable values returned by sequenced instructions that have
+ // not yet been marked-for or ruled-out-for inlining.
+ utils::UniqueVector<ir::Value*, 32> pending_resolution;
+
+ // Walk the instructions of the block starting with the first.
+ for (auto* inst : *block) {
+ // Is the instruction sequenced?
+ bool sequenced = inst->Sequenced();
+
+ // Walk the instruction's operands starting with the right-most.
+ auto operands = inst->Operands();
+ for (auto* operand : utils::Reverse(operands)) {
+ if (!pending_resolution.Contains(operand)) {
+ continue;
+ }
+ // Operand is in 'pending_resolution'
+
+ if (pending_resolution.TryPop(operand)) {
+ // Operand was the last sequenced value to be added to 'pending_resolution'
+ // This operand can be inlined as it does not change the sequencing order.
+ can_inline_.Add(operand);
+ sequenced = true; // Inherit the 'sequenced' flag from the inlined value
+ } else {
+ // Operand was in 'pending_resolution', but was not the last sequenced value to
+ // be added. Inlining this operand would break the sequencing order, so must be
+ // emitted as a let. All preceding pending values must also be emitted as a
+ // let to prevent them being inlined and breaking the sequencing order.
+ // Remove all the values in pending upto and including 'operand'.
+ for (size_t i = 0; i < pending_resolution.Length(); i++) {
+ if (pending_resolution[i] == operand) {
+ pending_resolution.Erase(0, i + 1);
+ break;
+ }
+ }
+ }
+ }
+
+ if (inst->Results().Length() == 1) {
+ // Instruction has a single result value.
+ // Check to see if the result of this instruction is a candidate for inlining.
+ auto* result = inst->Result();
+ // Only values with a single usage can be inlined.
+ // Named values are not inlined, as we want to emit the name for a let.
+ if (result->Usages().Count() == 1 && !mod.NameOf(result).IsValid()) {
+ if (sequenced) {
+ // The value comes from a sequenced instruction. We need to ensure
+ // instruction ordering so add it to 'pending_resolution'.
+ pending_resolution.Add(result);
+ } else {
+ // The value comes from an unsequenced instruction. Just inline.
+ can_inline_.Add(result);
+ }
+ continue;
+ }
+ }
+
+ // At this point the value has been ruled out for inlining.
+
+ if (sequenced) {
+ // A sequenced instruction with zero or multiple return values cannot be inlined.
+ // All preceding sequenced instructions cannot be inlined past this point.
+ pending_resolution.Clear();
+ }
+ }
+ }
+
void Append(const ast::Statement* inst) { statements_->Push(inst); }
void Instruction(ir::Instruction* inst) {
tint::Switch(
inst, //
- [&](ir::Binary* u) { Binary(u); }, //
+ [&](ir::Binary* i) { Binary(i); }, //
+ [&](ir::BreakIf* i) { BreakIf(i); }, //
[&](ir::Call* i) { Call(i); }, //
[&](ir::ExitIf*) {}, //
[&](ir::ExitSwitch* i) { ExitSwitch(i); }, //
@@ -214,6 +287,7 @@
StatementList body_stmts;
{
+ MarkInlinable(l->Body());
TINT_SCOPED_ASSIGNMENT(statements_, &body_stmts);
for (auto* inst : *l->Body()) {
if (body_stmts.IsEmpty()) {
@@ -293,6 +367,8 @@
void ExitLoop(const ir::ExitLoop*) { Append(b.Break()); }
+ void BreakIf(ir::BreakIf* i) { Append(b.BreakIf(Expr(i->Condition()))); }
+
void Return(ir::Return* ret) {
if (ret->Args().IsEmpty()) {
// Return has no arguments.
@@ -615,7 +691,7 @@
template <typename T>
void Bind(ir::Value* value, const T* expr) {
TINT_ASSERT(IR, value);
- if (CanInline(value)) {
+ if (can_inline_.Remove(value)) {
// Value will be inlined at its place of usage.
bool added = bindings_.Add(value, expr);
if (TINT_UNLIKELY(!added)) {
@@ -627,12 +703,6 @@
}
}
- /// @returns true if the if the value can be inlined into its single place
- /// of usage. Currently a value is inlined if it has a single usage and is unnamed.
- /// TODO(crbug.com/tint/1902): This logic needs to check that the sequence of side-effecting
- /// expressions is not changed by inlining the expression. This needs fixing.
- bool CanInline(Value* val) { return val->Usages().Count() == 1 && !mod.NameOf(val).IsValid(); }
-
////////////////////////////////////////////////////////////////////////////////////////////////
// Helpers
////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/src/tint/ir/to_program_inlining_test.cc b/src/tint/ir/to_program_inlining_test.cc
new file mode 100644
index 0000000..72b9d39
--- /dev/null
+++ b/src/tint/ir/to_program_inlining_test.cc
@@ -0,0 +1,1155 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+
+#include "src/tint/ir/disassembler.h"
+#include "src/tint/ir/to_program.h"
+#include "src/tint/ir/to_program_test.h"
+#include "src/tint/utils/string.h"
+#include "src/tint/writer/wgsl/generator.h"
+
+namespace tint::ir::test {
+namespace {
+
+using namespace tint::number_suffixes; // NOLINT
+using namespace tint::builtin::fluent_types; // NOLINT
+
+using IRToProgramInliningTest = IRToProgramTest;
+
+////////////////////////////////////////////////////////////////////////////////
+// Load / Store
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(IRToProgramInliningTest, LoadVar_ThenStoreVar_ThenUseLoad) {
+ auto* fn = b.Function("f", ty.i32());
+ b.With(fn->Block(), [&] {
+ auto* var = b.Var(ty.ptr<function, i32>());
+ b.Store(var, 1_i);
+ auto* load = b.Load(var);
+ b.Store(var, 2_i);
+ b.Return(fn, load);
+ });
+
+ EXPECT_WGSL(R"(
+fn f() -> i32 {
+ var v : i32;
+ v = 1i;
+ let v_1 = v;
+ v = 2i;
+ return v_1;
+}
+)");
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Binary op
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(IRToProgramInliningTest, BinaryOpUnsequencedLHSThenUnsequencedRHS) {
+ auto* fn_a = b.Function("a", ty.i32());
+ b.With(fn_a->Block(), [&] { b.Return(fn_a, 0_i); });
+ fn_a->SetParams({b.FunctionParam(ty.i32())});
+
+ auto* fn_b = b.Function("b", ty.i32());
+ b.With(fn_b->Block(), [&] {
+ auto* lhs = b.Add(ty.i32(), 1_i, 2_i);
+ auto* rhs = b.Add(ty.i32(), 3_i, 4_i);
+ auto* bin = b.Add(ty.i32(), lhs, rhs);
+ b.Return(fn_b, bin);
+ });
+
+ EXPECT_WGSL(R"(
+fn a(v : i32) -> i32 {
+ return 0i;
+}
+
+fn b() -> i32 {
+ return ((1i + 2i) + (3i + 4i));
+}
+)");
+}
+
+TEST_F(IRToProgramInliningTest, BinaryOpSequencedLHSThenUnsequencedRHS) {
+ auto* fn_a = b.Function("a", ty.i32());
+ b.With(fn_a->Block(), [&] { b.Return(fn_a, 0_i); });
+ fn_a->SetParams({b.FunctionParam(ty.i32())});
+
+ auto* fn_b = b.Function("b", ty.i32());
+ b.With(fn_b->Block(), [&] {
+ auto* lhs = b.Call(ty.i32(), fn_a, 1_i);
+ auto* rhs = b.Add(ty.i32(), 2_i, 3_i);
+ auto* bin = b.Add(ty.i32(), lhs, rhs);
+ b.Return(fn_b, bin);
+ });
+
+ EXPECT_WGSL(R"(
+fn a(v : i32) -> i32 {
+ return 0i;
+}
+
+fn b() -> i32 {
+ return (a(1i) + (2i + 3i));
+}
+)");
+}
+
+TEST_F(IRToProgramInliningTest, BinaryOpUnsequencedLHSThenSequencedRHS) {
+ auto* fn_a = b.Function("a", ty.i32());
+ b.With(fn_a->Block(), [&] { b.Return(fn_a, 0_i); });
+ fn_a->SetParams({b.FunctionParam(ty.i32())});
+
+ auto* fn_b = b.Function("b", ty.i32());
+ b.With(fn_b->Block(), [&] {
+ auto* lhs = b.Add(ty.i32(), 1_i, 2_i);
+ auto* rhs = b.Call(ty.i32(), fn_a, 3_i);
+ auto* bin = b.Add(ty.i32(), lhs, rhs);
+ b.Return(fn_b, bin);
+ });
+
+ EXPECT_WGSL(R"(
+fn a(v : i32) -> i32 {
+ return 0i;
+}
+
+fn b() -> i32 {
+ return ((1i + 2i) + a(3i));
+}
+)");
+}
+
+TEST_F(IRToProgramInliningTest, BinaryOpSequencedLHSThenSequencedRHS) {
+ auto* fn_a = b.Function("a", ty.i32());
+ b.With(fn_a->Block(), [&] { b.Return(fn_a, 0_i); });
+ fn_a->SetParams({b.FunctionParam(ty.i32())});
+
+ auto* fn_b = b.Function("b", ty.i32());
+ b.With(fn_b->Block(), [&] {
+ auto* lhs = b.Call(ty.i32(), fn_a, 1_i);
+ auto* rhs = b.Call(ty.i32(), fn_a, 2_i);
+ auto* bin = b.Add(ty.i32(), lhs, rhs);
+ b.Return(fn_b, bin);
+ });
+
+ EXPECT_WGSL(R"(
+fn a(v : i32) -> i32 {
+ return 0i;
+}
+
+fn b() -> i32 {
+ return (a(1i) + a(2i));
+}
+)");
+}
+
+TEST_F(IRToProgramInliningTest, BinaryOpUnsequencedRHSThenUnsequencedLHS) {
+ auto* fn_a = b.Function("a", ty.i32());
+ b.With(fn_a->Block(), [&] { b.Return(fn_a, 0_i); });
+ fn_a->SetParams({b.FunctionParam(ty.i32())});
+
+ auto* fn_b = b.Function("b", ty.i32());
+ b.With(fn_b->Block(), [&] {
+ auto* rhs = b.Add(ty.i32(), 3_i, 4_i);
+ auto* lhs = b.Add(ty.i32(), 1_i, 2_i);
+ auto* bin = b.Add(ty.i32(), lhs, rhs);
+ b.Return(fn_b, bin);
+ });
+
+ EXPECT_WGSL(R"(
+fn a(v : i32) -> i32 {
+ return 0i;
+}
+
+fn b() -> i32 {
+ return ((1i + 2i) + (3i + 4i));
+}
+)");
+}
+
+TEST_F(IRToProgramInliningTest, BinaryOpUnsequencedRHSThenSequencedLHS) {
+ auto* fn_a = b.Function("a", ty.i32());
+ b.With(fn_a->Block(), [&] { b.Return(fn_a, 0_i); });
+ fn_a->SetParams({b.FunctionParam(ty.i32())});
+
+ auto* fn_b = b.Function("b", ty.i32());
+ b.With(fn_b->Block(), [&] {
+ auto* rhs = b.Add(ty.i32(), 2_i, 3_i);
+ auto* lhs = b.Call(ty.i32(), fn_a, 1_i);
+ auto* bin = b.Add(ty.i32(), lhs, rhs);
+ b.Return(fn_b, bin);
+ });
+
+ EXPECT_WGSL(R"(
+fn a(v : i32) -> i32 {
+ return 0i;
+}
+
+fn b() -> i32 {
+ return (a(1i) + (2i + 3i));
+}
+)");
+}
+
+TEST_F(IRToProgramInliningTest, BinaryOpSequencedRHSThenUnsequencedLHS) {
+ auto* fn_a = b.Function("a", ty.i32());
+ b.With(fn_a->Block(), [&] { b.Return(fn_a, 0_i); });
+ fn_a->SetParams({b.FunctionParam(ty.i32())});
+
+ auto* fn_b = b.Function("b", ty.i32());
+ b.With(fn_b->Block(), [&] {
+ auto* rhs = b.Call(ty.i32(), fn_a, 3_i);
+ auto* lhs = b.Add(ty.i32(), 1_i, 2_i);
+ auto* bin = b.Add(ty.i32(), lhs, rhs);
+ b.Return(fn_b, bin);
+ });
+
+ EXPECT_WGSL(R"(
+fn a(v : i32) -> i32 {
+ return 0i;
+}
+
+fn b() -> i32 {
+ return ((1i + 2i) + a(3i));
+}
+)");
+}
+
+TEST_F(IRToProgramInliningTest, BinaryOpSequencedRHSThenSequencedLHS) {
+ auto* fn_a = b.Function("a", ty.i32());
+ b.With(fn_a->Block(), [&] { b.Return(fn_a, 0_i); });
+ fn_a->SetParams({b.FunctionParam(ty.i32())});
+
+ auto* fn_b = b.Function("b", ty.i32());
+ b.With(fn_b->Block(), [&] {
+ auto* rhs = b.Call(ty.i32(), fn_a, 2_i);
+ auto* lhs = b.Call(ty.i32(), fn_a, 1_i);
+ auto* bin = b.Add(ty.i32(), lhs, rhs);
+ b.Return(fn_b, bin);
+ });
+
+ EXPECT_WGSL(R"(
+fn a(v : i32) -> i32 {
+ return 0i;
+}
+
+fn b() -> i32 {
+ let v_1 = a(2i);
+ return (a(1i) + v_1);
+}
+)");
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Call
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(IRToProgramInliningTest, CallSequencedXYZ) {
+ auto* fn_a = b.Function("a", ty.i32());
+ b.With(fn_a->Block(), [&] { b.Return(fn_a, 0_i); });
+ fn_a->SetParams({b.FunctionParam(ty.i32())});
+
+ auto* fn_b = b.Function("b", ty.i32());
+ b.With(fn_b->Block(), [&] { b.Return(fn_b, 0_i); });
+ fn_b->SetParams(
+ {b.FunctionParam(ty.i32()), b.FunctionParam(ty.i32()), b.FunctionParam(ty.i32())});
+
+ auto* fn_c = b.Function("c", ty.i32());
+ b.With(fn_c->Block(), [&] {
+ auto* x = b.Call(ty.i32(), fn_a, 1_i);
+ auto* y = b.Call(ty.i32(), fn_a, 2_i);
+ auto* z = b.Call(ty.i32(), fn_a, 3_i);
+ auto* call = b.Call(ty.i32(), fn_b, x, y, z);
+ b.Return(fn_c, call);
+ });
+
+ EXPECT_WGSL(R"(
+fn a(v : i32) -> i32 {
+ return 0i;
+}
+
+fn b(v_1 : i32, v_2 : i32, v_3 : i32) -> i32 {
+ return 0i;
+}
+
+fn c() -> i32 {
+ return b(a(1i), a(2i), a(3i));
+}
+)");
+}
+
+TEST_F(IRToProgramInliningTest, CallSequencedYXZ) {
+ auto* fn_a = b.Function("a", ty.i32());
+ b.With(fn_a->Block(), [&] { b.Return(fn_a, 0_i); });
+ fn_a->SetParams({b.FunctionParam(ty.i32())});
+
+ auto* fn_b = b.Function("b", ty.i32());
+ b.With(fn_b->Block(), [&] { b.Return(fn_b, 0_i); });
+ fn_b->SetParams(
+ {b.FunctionParam(ty.i32()), b.FunctionParam(ty.i32()), b.FunctionParam(ty.i32())});
+
+ auto* fn_c = b.Function("c", ty.i32());
+ b.With(fn_c->Block(), [&] {
+ auto* y = b.Call(ty.i32(), fn_a, 2_i);
+ auto* x = b.Call(ty.i32(), fn_a, 1_i);
+ auto* z = b.Call(ty.i32(), fn_a, 3_i);
+ auto* call = b.Call(ty.i32(), fn_b, x, y, z);
+ b.Return(fn_c, call);
+ });
+
+ EXPECT_WGSL(R"(
+fn a(v : i32) -> i32 {
+ return 0i;
+}
+
+fn b(v_1 : i32, v_2 : i32, v_3 : i32) -> i32 {
+ return 0i;
+}
+
+fn c() -> i32 {
+ let v_4 = a(2i);
+ return b(a(1i), v_4, a(3i));
+}
+)");
+}
+
+TEST_F(IRToProgramInliningTest, CallSequencedXZY) {
+ auto* fn_a = b.Function("a", ty.i32());
+ b.With(fn_a->Block(), [&] { b.Return(fn_a, 0_i); });
+ fn_a->SetParams({b.FunctionParam(ty.i32())});
+
+ auto* fn_b = b.Function("b", ty.i32());
+ b.With(fn_b->Block(), [&] { b.Return(fn_b, 0_i); });
+ fn_b->SetParams(
+ {b.FunctionParam(ty.i32()), b.FunctionParam(ty.i32()), b.FunctionParam(ty.i32())});
+
+ auto* fn_c = b.Function("c", ty.i32());
+ b.With(fn_c->Block(), [&] {
+ auto* x = b.Call(ty.i32(), fn_a, 1_i);
+ auto* z = b.Call(ty.i32(), fn_a, 3_i);
+ auto* y = b.Call(ty.i32(), fn_a, 2_i);
+ auto* call = b.Call(ty.i32(), fn_b, x, y, z);
+ b.Return(fn_c, call);
+ });
+
+ EXPECT_WGSL(R"(
+fn a(v : i32) -> i32 {
+ return 0i;
+}
+
+fn b(v_1 : i32, v_2 : i32, v_3 : i32) -> i32 {
+ return 0i;
+}
+
+fn c() -> i32 {
+ let v_4 = a(1i);
+ let v_5 = a(3i);
+ return b(v_4, a(2i), v_5);
+}
+)");
+}
+
+TEST_F(IRToProgramInliningTest, CallSequencedZXY) {
+ auto* fn_a = b.Function("a", ty.i32());
+ b.With(fn_a->Block(), [&] { b.Return(fn_a, 0_i); });
+ fn_a->SetParams({b.FunctionParam(ty.i32())});
+
+ auto* fn_b = b.Function("b", ty.i32());
+ b.With(fn_b->Block(), [&] { b.Return(fn_b, 0_i); });
+ fn_b->SetParams(
+ {b.FunctionParam(ty.i32()), b.FunctionParam(ty.i32()), b.FunctionParam(ty.i32())});
+
+ auto* fn_c = b.Function("c", ty.i32());
+ b.With(fn_c->Block(), [&] {
+ auto* z = b.Call(ty.i32(), fn_a, 3_i);
+ auto* x = b.Call(ty.i32(), fn_a, 1_i);
+ auto* y = b.Call(ty.i32(), fn_a, 2_i);
+ auto* call = b.Call(ty.i32(), fn_b, x, y, z);
+ b.Return(fn_c, call);
+ });
+
+ EXPECT_WGSL(R"(
+fn a(v : i32) -> i32 {
+ return 0i;
+}
+
+fn b(v_1 : i32, v_2 : i32, v_3 : i32) -> i32 {
+ return 0i;
+}
+
+fn c() -> i32 {
+ let v_4 = a(3i);
+ return b(a(1i), a(2i), v_4);
+}
+)");
+}
+
+TEST_F(IRToProgramInliningTest, CallSequencedYZX) {
+ auto* fn_a = b.Function("a", ty.i32());
+ b.With(fn_a->Block(), [&] { b.Return(fn_a, 0_i); });
+ fn_a->SetParams({b.FunctionParam(ty.i32())});
+
+ auto* fn_b = b.Function("b", ty.i32());
+ b.With(fn_b->Block(), [&] { b.Return(fn_b, 0_i); });
+ fn_b->SetParams(
+ {b.FunctionParam(ty.i32()), b.FunctionParam(ty.i32()), b.FunctionParam(ty.i32())});
+
+ auto* fn_c = b.Function("c", ty.i32());
+ b.With(fn_c->Block(), [&] {
+ auto* y = b.Call(ty.i32(), fn_a, 2_i);
+ auto* z = b.Call(ty.i32(), fn_a, 3_i);
+ auto* x = b.Call(ty.i32(), fn_a, 1_i);
+ auto* call = b.Call(ty.i32(), fn_b, x, y, z);
+ b.Return(fn_c, call);
+ });
+
+ EXPECT_WGSL(R"(
+fn a(v : i32) -> i32 {
+ return 0i;
+}
+
+fn b(v_1 : i32, v_2 : i32, v_3 : i32) -> i32 {
+ return 0i;
+}
+
+fn c() -> i32 {
+ let v_4 = a(2i);
+ let v_5 = a(3i);
+ return b(a(1i), v_4, v_5);
+}
+)");
+}
+
+TEST_F(IRToProgramInliningTest, CallSequencedZYX) {
+ auto* fn_a = b.Function("a", ty.i32());
+ b.With(fn_a->Block(), [&] { b.Return(fn_a, 0_i); });
+ fn_a->SetParams({b.FunctionParam(ty.i32())});
+
+ auto* fn_b = b.Function("b", ty.i32());
+ b.With(fn_b->Block(), [&] { b.Return(fn_b, 0_i); });
+ fn_b->SetParams(
+ {b.FunctionParam(ty.i32()), b.FunctionParam(ty.i32()), b.FunctionParam(ty.i32())});
+
+ auto* fn_c = b.Function("c", ty.i32());
+ b.With(fn_c->Block(), [&] {
+ auto* z = b.Call(ty.i32(), fn_a, 3_i);
+ auto* y = b.Call(ty.i32(), fn_a, 2_i);
+ auto* x = b.Call(ty.i32(), fn_a, 1_i);
+ auto* call = b.Call(ty.i32(), fn_b, x, y, z);
+ b.Return(fn_c, call);
+ });
+
+ EXPECT_WGSL(R"(
+fn a(v : i32) -> i32 {
+ return 0i;
+}
+
+fn b(v_1 : i32, v_2 : i32, v_3 : i32) -> i32 {
+ return 0i;
+}
+
+fn c() -> i32 {
+ let v_4 = a(3i);
+ let v_5 = a(2i);
+ return b(a(1i), v_5, v_4);
+}
+)");
+}
+
+TEST_F(IRToProgramInliningTest, LoadVar_ThenCallVoidFn_ThenUseLoad) {
+ auto* fn_a = b.Function("a", ty.void_());
+
+ auto* fn = b.Function("f", ty.i32());
+ b.With(fn->Block(), [&] {
+ auto* var = b.Var(ty.ptr<function, i32>());
+ b.Store(var, 1_i);
+ auto* load = b.Load(var);
+ b.Call(ty.void_(), fn_a);
+ b.Return(fn, load);
+ });
+
+ EXPECT_WGSL(R"(
+fn a() {
+}
+
+fn f() -> i32 {
+ var v : i32;
+ v = 1i;
+ let v_1 = v;
+ a();
+ return v_1;
+}
+)");
+}
+
+TEST_F(IRToProgramInliningTest, LoadVar_ThenCallUnusedi32Fn_ThenUseLoad) {
+ auto* fn_a = b.Function("a", ty.i32());
+ b.With(fn_a->Block(), [&] { b.Return(fn_a, 1_i); });
+
+ auto* fn = b.Function("f", ty.i32());
+ b.With(fn->Block(), [&] {
+ auto* var = b.Var(ty.ptr<function, i32>());
+ b.Store(var, 1_i);
+ auto* load = b.Load(var);
+ b.Call(ty.i32(), fn_a);
+ b.Return(fn, load);
+ });
+
+ EXPECT_WGSL(R"(
+fn a() -> i32 {
+ return 1i;
+}
+
+fn f() -> i32 {
+ var v : i32;
+ v = 1i;
+ let v_1 = v;
+ a();
+ return v_1;
+}
+)");
+}
+
+TEST_F(IRToProgramInliningTest, LoadVar_ThenCalli32Fn_ThenUseLoadBeforeCall) {
+ auto* fn_a = b.Function("a", ty.i32());
+ b.With(fn_a->Block(), [&] { b.Return(fn_a, 1_i); });
+
+ auto* fn = b.Function("f", ty.i32());
+ b.With(fn->Block(), [&] {
+ auto* var = b.Var(ty.ptr<function, i32>());
+ b.Store(var, 1_i);
+ auto* load = b.Load(var);
+ auto* call = b.Call(ty.i32(), fn_a);
+ b.Return(fn, b.Add(ty.i32(), load, call));
+ });
+
+ EXPECT_WGSL(R"(
+fn a() -> i32 {
+ return 1i;
+}
+
+fn f() -> i32 {
+ var v : i32;
+ v = 1i;
+ return (v + a());
+}
+)");
+}
+
+TEST_F(IRToProgramInliningTest, LoadVar_ThenCalli32Fn_ThenUseCallBeforeLoad) {
+ auto* fn_a = b.Function("a", ty.i32());
+ b.With(fn_a->Block(), [&] { b.Return(fn_a, 1_i); });
+
+ auto* fn = b.Function("f", ty.i32());
+ b.With(fn->Block(), [&] {
+ auto* var = b.Var(ty.ptr<function, i32>());
+ b.Store(var, 1_i);
+ auto* load = b.Load(var);
+ auto* call = b.Call(ty.i32(), fn_a);
+ b.Return(fn, b.Add(ty.i32(), call, load));
+ });
+
+ EXPECT_WGSL(R"(
+fn a() -> i32 {
+ return 1i;
+}
+
+fn f() -> i32 {
+ var v : i32;
+ v = 1i;
+ let v_1 = v;
+ return (a() + v_1);
+}
+)");
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// If
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(IRToProgramInliningTest, UnsequencedOutsideIf) {
+ auto* fn = b.Function("f", ty.i32());
+ b.With(fn->Block(), [&] {
+ auto* v = b.Add(ty.i32(), 1_i, 2_i);
+ auto* if_ = b.If(true);
+ b.With(if_->True(), [&] { b.Return(fn, v); });
+ b.Return(fn, 0_i);
+ });
+
+ EXPECT_WGSL(R"(
+fn f() -> i32 {
+ if (true) {
+ return (1i + 2i);
+ }
+ return 0i;
+}
+)");
+}
+
+TEST_F(IRToProgramInliningTest, SequencedOutsideIf) {
+ auto* fn = b.Function("f", ty.i32());
+ b.With(fn->Block(), [&] {
+ auto* var = b.Var(ty.ptr<function, i32>());
+ var->SetInitializer(b.Constant(1_i));
+ auto* v_1 = b.Load(var);
+ auto* v_2 = b.Add(ty.i32(), v_1, 2_i);
+ auto* if_ = b.If(true);
+ b.With(if_->True(), [&] { b.Return(fn, v_2); });
+ b.Return(fn, 0_i);
+ });
+
+ EXPECT_WGSL(R"(
+fn f() -> i32 {
+ var v : i32 = 1i;
+ let v_1 = (v + 2i);
+ if (true) {
+ return v_1;
+ }
+ return 0i;
+}
+)");
+}
+
+TEST_F(IRToProgramInliningTest, UnsequencedUsedByIfCondition) {
+ auto* fn = b.Function("f", ty.i32());
+ b.With(fn->Block(), [&] {
+ auto* v = b.Equal(ty.bool_(), 1_i, 2_i);
+ auto* if_ = b.If(v);
+ b.With(if_->True(), [&] { b.Return(fn, 3_i); });
+ b.Return(fn, 0_i);
+ });
+
+ EXPECT_WGSL(R"(
+fn f() -> i32 {
+ if ((1i == 2i)) {
+ return 3i;
+ }
+ return 0i;
+}
+)");
+}
+
+TEST_F(IRToProgramInliningTest, SequencedUsedByIfCondition) {
+ auto* fn = b.Function("f", ty.i32());
+ b.With(fn->Block(), [&] {
+ auto* var = b.Var(ty.ptr<function, i32>());
+ var->SetInitializer(b.Constant(1_i));
+ auto* v_1 = b.Load(var);
+ auto* v_2 = b.Equal(ty.bool_(), v_1, 2_i);
+ auto* if_ = b.If(v_2);
+ b.With(if_->True(), [&] { b.Return(fn, 3_i); });
+ b.Return(fn, 0_i);
+ });
+
+ EXPECT_WGSL(R"(
+fn f() -> i32 {
+ var v : i32 = 1i;
+ if ((v == 2i)) {
+ return 3i;
+ }
+ return 0i;
+}
+)");
+}
+
+TEST_F(IRToProgramInliningTest, LoadVar_ThenWriteToVarInIf_ThenUseLoad) {
+ auto* fn = b.Function("f", ty.i32());
+ b.With(fn->Block(), [&] {
+ auto* var = b.Var(ty.ptr<function, i32>());
+ b.Store(var, 1_i);
+ auto* load = b.Load(var);
+ auto* if_ = b.If(true);
+ b.With(if_->True(), [&] { b.Store(var, 2_i); });
+ b.Return(fn, load);
+ });
+
+ EXPECT_WGSL(R"(
+fn f() -> i32 {
+ var v : i32;
+ v = 1i;
+ let v_1 = v;
+ if (true) {
+ v = 2i;
+ }
+ return v_1;
+}
+)");
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Switch
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(IRToProgramInliningTest, UnsequencedOutsideSwitch) {
+ auto* fn = b.Function("f", ty.i32());
+ b.With(fn->Block(), [&] {
+ auto* v = b.Add(ty.i32(), 1_i, 2_i);
+ auto* switch_ = b.Switch(3_i);
+ auto* case_ = b.Case(switch_, {Switch::CaseSelector{}});
+ b.With(case_, [&] { b.Return(fn, v); });
+ b.Return(fn, 0_i);
+ });
+
+ EXPECT_WGSL(R"(
+fn f() -> i32 {
+ switch(3i) {
+ default: {
+ return (1i + 2i);
+ }
+ }
+ return 0i;
+}
+)");
+}
+
+TEST_F(IRToProgramInliningTest, SequencedOutsideSwitch) {
+ auto* fn = b.Function("f", ty.i32());
+ b.With(fn->Block(), [&] {
+ auto* var = b.Var(ty.ptr<function, i32>());
+ var->SetInitializer(b.Constant(1_i));
+ auto* v_1 = b.Load(var);
+ auto* v_2 = b.Add(ty.i32(), v_1, 2_i);
+ auto* switch_ = b.Switch(3_i);
+ auto* case_ = b.Case(switch_, {Switch::CaseSelector{}});
+ b.With(case_, [&] { b.Return(fn, v_2); });
+ b.Return(fn, 0_i);
+ });
+
+ EXPECT_WGSL(R"(
+fn f() -> i32 {
+ var v : i32 = 1i;
+ let v_1 = (v + 2i);
+ switch(3i) {
+ default: {
+ return v_1;
+ }
+ }
+ return 0i;
+}
+)");
+}
+
+TEST_F(IRToProgramInliningTest, UnsequencedUsedBySwitchCondition) {
+ auto* fn = b.Function("f", ty.i32());
+ b.With(fn->Block(), [&] {
+ auto* v = b.Add(ty.i32(), 1_i, 2_i);
+ auto* switch_ = b.Switch(v);
+ auto* case_ = b.Case(switch_, {Switch::CaseSelector{}});
+ b.With(case_, [&] { b.Return(fn, 3_i); });
+ b.Return(fn, 0_i);
+ });
+
+ EXPECT_WGSL(R"(
+fn f() -> i32 {
+ switch((1i + 2i)) {
+ default: {
+ return 3i;
+ }
+ }
+ return 0i;
+}
+)");
+}
+
+TEST_F(IRToProgramInliningTest, SequencedUsedBySwitchCondition) {
+ auto* fn = b.Function("f", ty.i32());
+ b.With(fn->Block(), [&] {
+ auto* var = b.Var(ty.ptr<function, i32>());
+ var->SetInitializer(b.Constant(1_i));
+ auto* v_1 = b.Load(var);
+ auto* switch_ = b.Switch(v_1);
+ auto* case_ = b.Case(switch_, {Switch::CaseSelector{}});
+ b.With(case_, [&] { b.Return(fn, 3_i); });
+ b.Return(fn, 0_i);
+ });
+
+ EXPECT_WGSL(R"(
+fn f() -> i32 {
+ var v : i32 = 1i;
+ switch(v) {
+ default: {
+ return 3i;
+ }
+ }
+ return 0i;
+}
+)");
+}
+
+TEST_F(IRToProgramInliningTest, LoadVar_ThenWriteToVarInSwitch_ThenUseLoad) {
+ auto* fn = b.Function("f", ty.i32());
+ b.With(fn->Block(), [&] {
+ auto* var = b.Var(ty.ptr<function, i32>());
+ b.Store(var, 1_i);
+ auto* load = b.Load(var);
+ auto* switch_ = b.Switch(1_i);
+ auto* case_ = b.Case(switch_, {Switch::CaseSelector{}});
+ b.With(case_, [&] { b.Store(var, 2_i); });
+ b.Return(fn, load);
+ });
+
+ EXPECT_WGSL(R"(
+fn f() -> i32 {
+ var v : i32;
+ v = 1i;
+ let v_1 = v;
+ switch(1i) {
+ default: {
+ v = 2i;
+ }
+ }
+ return v_1;
+}
+)");
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Loop
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(IRToProgramInliningTest, UnsequencedOutsideLoopInitializer) {
+ auto* fn = b.Function("f", ty.i32());
+ b.With(fn->Block(), [&] {
+ auto* var = b.Var(ty.ptr<function, i32>());
+ auto* v = b.Add(ty.i32(), 1_i, 2_i);
+ auto* loop = b.Loop();
+ b.With(loop->Initializer(), [&] { b.Store(var, v); });
+ b.With(loop->Body(), [&] { b.ExitLoop(loop); });
+ b.Return(fn, 0_i);
+ });
+
+ EXPECT_WGSL(R"(
+fn f() -> i32 {
+ var v : i32;
+ {
+ v = (1i + 2i);
+ loop {
+ break;
+ }
+ }
+ return 0i;
+}
+)");
+}
+
+TEST_F(IRToProgramInliningTest, SequencedOutsideLoopInitializer) {
+ auto* fn = b.Function("f", ty.i32());
+ b.With(fn->Block(), [&] {
+ auto* var = b.Var(ty.ptr<function, i32>());
+ auto* v_1 = b.Load(var);
+ auto* v_2 = b.Add(ty.i32(), v_1, 2_i);
+ auto* loop = b.Loop();
+ b.With(loop->Initializer(), [&] { b.Store(var, v_2); });
+ b.With(loop->Body(), [&] { b.ExitLoop(loop); });
+ b.Return(fn, 0_i);
+ });
+
+ EXPECT_WGSL(R"(
+fn f() -> i32 {
+ var v : i32;
+ let v_1 = (v + 2i);
+ {
+ v = v_1;
+ loop {
+ break;
+ }
+ }
+ return 0i;
+}
+)");
+}
+
+TEST_F(IRToProgramInliningTest, LoadVar_ThenWriteToVarInLoopInitializer_ThenUseLoad) {
+ auto* fn = b.Function("f", ty.i32());
+ b.With(fn->Block(), [&] {
+ auto* var = b.Var(ty.ptr<function, i32>());
+ b.Store(var, 1_i);
+ auto* load = b.Load(var);
+ auto* loop = b.Loop();
+ b.With(loop->Initializer(), [&] { b.Store(var, 2_i); });
+ b.With(loop->Body(), [&] { b.ExitLoop(loop); });
+ b.Return(fn, load);
+ });
+
+ EXPECT_WGSL(R"(
+fn f() -> i32 {
+ var v : i32;
+ v = 1i;
+ let v_1 = v;
+ {
+ v = 2i;
+ loop {
+ break;
+ }
+ }
+ return v_1;
+}
+)");
+}
+
+TEST_F(IRToProgramInliningTest, UnsequencedOutsideLoopBody) {
+ auto* fn = b.Function("f", ty.i32());
+ b.With(fn->Block(), [&] {
+ auto* v = b.Add(ty.i32(), 1_i, 2_i);
+ auto* loop = b.Loop();
+ b.With(loop->Body(), [&] { b.Return(fn, v); });
+ b.Return(fn, 0_i);
+ });
+
+ EXPECT_WGSL(R"(
+fn f() -> i32 {
+ loop {
+ return (1i + 2i);
+ }
+ return 0i;
+}
+)");
+}
+
+TEST_F(IRToProgramInliningTest, SequencedOutsideLoopBody) {
+ auto* fn = b.Function("f", ty.i32());
+ b.With(fn->Block(), [&] {
+ auto* var = b.Var(ty.ptr<function, i32>());
+ auto* v_1 = b.Load(var);
+ auto* v_2 = b.Add(ty.i32(), v_1, 2_i);
+ auto* loop = b.Loop();
+ b.With(loop->Body(), [&] { b.Return(fn, v_2); });
+ b.Return(fn, 0_i);
+ });
+
+ EXPECT_WGSL(R"(
+fn f() -> i32 {
+ var v : i32;
+ let v_1 = (v + 2i);
+ loop {
+ return v_1;
+ }
+ return 0i;
+}
+)");
+}
+
+TEST_F(IRToProgramInliningTest, LoadVar_ThenWriteToVarInLoopBody_ThenUseLoad) {
+ auto* fn = b.Function("f", ty.i32());
+ b.With(fn->Block(), [&] {
+ auto* var = b.Var(ty.ptr<function, i32>());
+ b.Store(var, 1_i);
+ auto* load = b.Load(var);
+ auto* loop = b.Loop();
+ b.With(loop->Body(), [&] {
+ b.Store(var, 2_i);
+ b.ExitLoop(loop);
+ });
+ b.Return(fn, load);
+ });
+
+ EXPECT_WGSL(R"(
+fn f() -> i32 {
+ var v : i32;
+ v = 1i;
+ let v_1 = v;
+ loop {
+ v = 2i;
+ break;
+ }
+ return v_1;
+}
+)");
+}
+
+TEST_F(IRToProgramInliningTest, UnsequencedOutsideLoopContinuing) {
+ auto* fn = b.Function("f", ty.i32());
+ b.With(fn->Block(), [&] {
+ auto* v = b.Add(ty.i32(), 1_i, 2_i);
+ auto* loop = b.Loop();
+ b.With(loop->Body(), [&] { b.Continue(loop); });
+ b.With(loop->Continuing(), [&] { b.BreakIf(loop, b.Equal(ty.bool_(), v, 3_i)); });
+ b.Return(fn, 0_i);
+ });
+
+ EXPECT_WGSL(R"(
+fn f() -> i32 {
+ loop {
+
+ continuing {
+ break if ((1i + 2i) == 3i);
+ }
+ }
+ return 0i;
+}
+)");
+}
+
+TEST_F(IRToProgramInliningTest, SequencedOutsideLoopContinuing) {
+ auto* fn = b.Function("f", ty.i32());
+ b.With(fn->Block(), [&] {
+ auto* var = b.Var(ty.ptr<function, i32>());
+ auto* v_1 = b.Load(var);
+ auto* v_2 = b.Add(ty.i32(), v_1, 2_i);
+ auto* loop = b.Loop();
+ b.With(loop->Body(), [&] { b.Continue(loop); });
+ b.With(loop->Continuing(), [&] { b.BreakIf(loop, b.Equal(ty.bool_(), v_2, 3_i)); });
+ b.Return(fn, 0_i);
+ });
+
+ EXPECT_WGSL(R"(
+fn f() -> i32 {
+ var v : i32;
+ let v_1 = (v + 2i);
+ loop {
+
+ continuing {
+ break if (v_1 == 3i);
+ }
+ }
+ return 0i;
+}
+)");
+}
+
+TEST_F(IRToProgramInliningTest, LoadVar_ThenWriteToVarInLoopContinuing_ThenUseLoad) {
+ auto* fn = b.Function("f", ty.i32());
+ b.With(fn->Block(), [&] {
+ auto* var = b.Var(ty.ptr<function, i32>());
+ b.Store(var, 1_i);
+ auto* load = b.Load(var);
+ auto* loop = b.Loop();
+ b.With(loop->Body(), [&] { b.Continue(loop); });
+ b.With(loop->Continuing(), [&] {
+ b.Store(var, 2_i);
+ b.NextIteration(loop);
+ });
+ b.With(loop->Body(), [&] {
+ b.Store(var, 2_i);
+ b.ExitLoop(loop);
+ });
+ b.Return(fn, load);
+ });
+
+ EXPECT_WGSL(R"(
+fn f() -> i32 {
+ var v : i32;
+ v = 1i;
+ let v_1 = v;
+ loop {
+ v = 2i;
+ break;
+
+ continuing {
+ v = 2i;
+ }
+ }
+ return v_1;
+}
+)");
+}
+
+TEST_F(IRToProgramInliningTest, LoadVarInLoopInitializer_ThenReadAndWriteToVarInLoopBody) {
+ auto* fn = b.Function("f", ty.i32());
+ b.With(fn->Block(), [&] {
+ auto* var = b.Var(ty.ptr<function, i32>());
+ b.Store(var, 1_i);
+ auto* loop = b.Loop();
+ b.With(loop->Initializer(), [&] {
+ auto* load = b.Load(var);
+ b.With(loop->Body(), [&] {
+ b.Store(var, b.Add(ty.i32(), load, 1_i));
+ b.ExitLoop(loop);
+ });
+ });
+ b.Return(fn, 3_i);
+ });
+
+ EXPECT_WGSL(R"(
+fn f() -> i32 {
+ var v : i32;
+ v = 1i;
+ {
+ let v_1 = v;
+ loop {
+ v = (v_1 + 1i);
+ break;
+ }
+ }
+ return 3i;
+}
+)");
+}
+
+TEST_F(IRToProgramInliningTest, LoadVarInLoopInitializer_ThenReadAndWriteToVarInLoopContinuing) {
+ auto* fn = b.Function("f", ty.i32());
+ b.With(fn->Block(), [&] {
+ auto* var = b.Var(ty.ptr<function, i32>());
+ b.Store(var, 1_i);
+ auto* loop = b.Loop();
+ b.With(loop->Initializer(), [&] {
+ auto* load = b.Load(var);
+ b.With(loop->Body(), [&] { b.Continue(loop); });
+ b.With(loop->Continuing(), [&] {
+ b.Store(var, b.Add(ty.i32(), load, 1_i));
+ b.BreakIf(loop, true);
+ });
+ });
+ b.Return(fn, 3_i);
+ });
+
+ EXPECT_WGSL(R"(
+fn f() -> i32 {
+ var v : i32;
+ v = 1i;
+ {
+ let v_1 = v;
+ loop {
+
+ continuing {
+ v = (v_1 + 1i);
+ break if true;
+ }
+ }
+ }
+ return 3i;
+}
+)");
+}
+
+TEST_F(IRToProgramInliningTest, LoadVarInLoopBody_ThenReadAndWriteToVarInLoopContinuing) {
+ auto* fn = b.Function("f", ty.i32());
+ b.With(fn->Block(), [&] {
+ auto* var = b.Var(ty.ptr<function, i32>());
+ b.Store(var, 1_i);
+ auto* loop = b.Loop();
+ b.With(loop->Body(), [&] {
+ auto* load = b.Load(var);
+ b.Continue(loop);
+
+ b.With(loop->Continuing(), [&] {
+ b.Store(var, b.Add(ty.i32(), load, 1_i));
+ b.BreakIf(loop, true);
+ });
+ });
+ b.Return(fn, 3_i);
+ });
+
+ EXPECT_WGSL(R"(
+fn f() -> i32 {
+ var v : i32;
+ v = 1i;
+ loop {
+ let v_1 = v;
+
+ continuing {
+ v = (v_1 + 1i);
+ break if true;
+ }
+ }
+ return 3i;
+}
+)");
+}
+
+} // namespace
+} // namespace tint::ir::test
diff --git a/src/tint/ir/to_program_roundtrip_test.cc b/src/tint/ir/to_program_roundtrip_test.cc
index c8552bc..c5b482a 100644
--- a/src/tint/ir/to_program_roundtrip_test.cc
+++ b/src/tint/ir/to_program_roundtrip_test.cc
@@ -264,7 +264,7 @@
////////////////////////////////////////////////////////////////////////////////
// Short-circuiting binary ops
////////////////////////////////////////////////////////////////////////////////
-TEST_F(IRToProgramRoundtripTest, BinaryOp_LogicalAnd_Param_2) {
+TEST_F(IRToProgramRoundtripTest, ShortCircuit_And_Param_2) {
Test(R"(
fn f(a : bool, b : bool) -> bool {
return (a && b);
@@ -272,7 +272,7 @@
)");
}
-TEST_F(IRToProgramRoundtripTest, BinaryOp_LogicalAnd_Param_3_ab_c) {
+TEST_F(IRToProgramRoundtripTest, ShortCircuit_And_Param_3_ab_c) {
Test(R"(
fn f(a : bool, b : bool, c : bool) -> bool {
return ((a && b) && c);
@@ -280,7 +280,7 @@
)");
}
-TEST_F(IRToProgramRoundtripTest, BinaryOp_LogicalAnd_Param_3_a_bc) {
+TEST_F(IRToProgramRoundtripTest, ShortCircuit_And_Param_3_a_bc) {
Test(R"(
fn f(a : bool, b : bool, c : bool) -> bool {
return ((a && b) && c);
@@ -288,7 +288,7 @@
)");
}
-TEST_F(IRToProgramRoundtripTest, BinaryOp_LogicalAnd_Let_2) {
+TEST_F(IRToProgramRoundtripTest, ShortCircuit_And_Let_2) {
Test(R"(
fn f(a : bool, b : bool) -> bool {
let l = (a && b);
@@ -297,7 +297,7 @@
)");
}
-TEST_F(IRToProgramRoundtripTest, BinaryOp_LogicalAnd_Let_3_ab_c) {
+TEST_F(IRToProgramRoundtripTest, ShortCircuit_And_Let_3_ab_c) {
Test(R"(
fn f(a : bool, b : bool, c : bool) -> bool {
let l = ((a && b) && c);
@@ -306,7 +306,7 @@
)");
}
-TEST_F(IRToProgramRoundtripTest, BinaryOp_LogicalAnd_Let_3_a_bc) {
+TEST_F(IRToProgramRoundtripTest, ShortCircuit_And_Let_3_a_bc) {
Test(R"(
fn f(a : bool, b : bool, c : bool) -> bool {
let l = (a && (b && c));
@@ -315,7 +315,7 @@
)");
}
-TEST_F(IRToProgramRoundtripTest, BinaryOp_LogicalAnd_Call_2) {
+TEST_F(IRToProgramRoundtripTest, ShortCircuit_And_Call_2) {
Test(R"(
fn a() -> bool {
return true;
@@ -331,7 +331,7 @@
)");
}
-TEST_F(IRToProgramRoundtripTest, BinaryOp_LogicalAnd_Call_3_ab_c) {
+TEST_F(IRToProgramRoundtripTest, ShortCircuit_And_Call_3_ab_c) {
Test(R"(
fn a() -> bool {
return true;
@@ -351,7 +351,7 @@
)");
}
-TEST_F(IRToProgramRoundtripTest, BinaryOp_LogicalAnd_Call_3_a_bc) {
+TEST_F(IRToProgramRoundtripTest, ShortCircuit_And_Call_3_a_bc) {
Test(R"(
fn a() -> bool {
return true;
@@ -371,7 +371,7 @@
)");
}
-TEST_F(IRToProgramRoundtripTest, BinaryOp_LogicalOr_Param_2) {
+TEST_F(IRToProgramRoundtripTest, ShortCircuit_Or_Param_2) {
Test(R"(
fn f(a : bool, b : bool) -> bool {
return (a || b);
@@ -379,7 +379,7 @@
)");
}
-TEST_F(IRToProgramRoundtripTest, BinaryOp_LogicalOr_Param_3_ab_c) {
+TEST_F(IRToProgramRoundtripTest, ShortCircuit_Or_Param_3_ab_c) {
Test(R"(
fn f(a : bool, b : bool, c : bool) -> bool {
return ((a || b) || c);
@@ -387,7 +387,7 @@
)");
}
-TEST_F(IRToProgramRoundtripTest, BinaryOp_LogicalOr_Param_3_a_bc) {
+TEST_F(IRToProgramRoundtripTest, ShortCircuit_Or_Param_3_a_bc) {
Test(R"(
fn f(a : bool, b : bool, c : bool) -> bool {
return (a || (b || c));
@@ -395,7 +395,7 @@
)");
}
-TEST_F(IRToProgramRoundtripTest, BinaryOp_LogicalOr_Let_2) {
+TEST_F(IRToProgramRoundtripTest, ShortCircuit_Or_Let_2) {
Test(R"(
fn f(a : bool, b : bool) -> bool {
let l = (a || b);
@@ -404,7 +404,7 @@
)");
}
-TEST_F(IRToProgramRoundtripTest, BinaryOp_LogicalOr_Let_3_ab_c) {
+TEST_F(IRToProgramRoundtripTest, ShortCircuit_Or_Let_3_ab_c) {
Test(R"(
fn f(a : bool, b : bool, c : bool) -> bool {
let l = ((a || b) || c);
@@ -413,7 +413,7 @@
)");
}
-TEST_F(IRToProgramRoundtripTest, BinaryOp_LogicalOr_Let_3_a_bc) {
+TEST_F(IRToProgramRoundtripTest, ShortCircuit_Or_Let_3_a_bc) {
Test(R"(
fn f(a : bool, b : bool, c : bool) -> bool {
let l = (a || (b || c));
@@ -422,7 +422,7 @@
)");
}
-TEST_F(IRToProgramRoundtripTest, BinaryOp_LogicalOr_Call_2) {
+TEST_F(IRToProgramRoundtripTest, ShortCircuit_Or_Call_2) {
Test(R"(
fn a() -> bool {
return true;
@@ -438,7 +438,7 @@
)");
}
-TEST_F(IRToProgramRoundtripTest, BinaryOp_LogicalOr_Call_3_ab_c) {
+TEST_F(IRToProgramRoundtripTest, ShortCircuit_Or_Call_3_ab_c) {
Test(R"(
fn a() -> bool {
return true;
@@ -458,7 +458,7 @@
)");
}
-TEST_F(IRToProgramRoundtripTest, BinaryOp_LogicalOr_Call_3_a_bc) {
+TEST_F(IRToProgramRoundtripTest, ShortCircuit_Or_Call_3_a_bc) {
Test(R"(
fn a() -> bool {
return true;
@@ -478,7 +478,7 @@
)");
}
-TEST_F(IRToProgramRoundtripTest, BinaryOp_LogicalMixed) {
+TEST_F(IRToProgramRoundtripTest, ShortCircuit_Mixed) {
Test(R"(
fn b() -> bool {
return true;
diff --git a/src/tint/ir/to_program_test.cc b/src/tint/ir/to_program_test.cc
index 8cfae5c..00b5681 100644
--- a/src/tint/ir/to_program_test.cc
+++ b/src/tint/ir/to_program_test.cc
@@ -12,65 +12,59 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include <sstream>
#include <string>
#include "src/tint/ir/disassembler.h"
-#include "src/tint/ir/ir_test_helper.h"
#include "src/tint/ir/to_program.h"
+#include "src/tint/ir/to_program_test.h"
#include "src/tint/utils/string.h"
#include "src/tint/writer/wgsl/generator.h"
-#if !TINT_BUILD_WGSL_WRITER
-#error "to_program_test.cc requires both the WGSL writer to be enabled"
-#endif
-
-namespace tint::ir {
-namespace {
+namespace tint::ir::test {
using namespace tint::number_suffixes; // NOLINT
using namespace tint::builtin::fluent_types; // NOLINT
-class IRToProgramTest : public IRTestHelper {
- public:
- void Test(std::string_view expected_wgsl) {
- tint::ir::Disassembler d{mod};
- auto disassembly = d.Disassemble();
+IRToProgramTest::Result IRToProgramTest::Run() {
+ Result result;
- auto output_program = ToProgram(mod);
- if (!output_program.IsValid()) {
- FAIL() << output_program.Diagnostics().str() << std::endl //
- << "IR:" << std::endl //
- << disassembly << std::endl //
- << "AST:" << std::endl //
- << Program::printer(&output_program) << std::endl;
- }
+ tint::ir::Disassembler d{mod};
+ result.ir = d.Disassemble();
- ASSERT_TRUE(output_program.IsValid()) << output_program.Diagnostics().str();
-
- auto output = writer::wgsl::Generate(&output_program, {});
- ASSERT_TRUE(output.success) << output.error;
-
- auto expected = std::string(utils::TrimSpace(expected_wgsl));
- if (!expected.empty()) {
- expected = "\n" + expected + "\n";
- }
- auto got = std::string(utils::TrimSpace(output.wgsl));
- if (!got.empty()) {
- got = "\n" + got + "\n";
- }
- EXPECT_EQ(expected, got) << "IR:" << std::endl << disassembly;
+ auto output_program = ToProgram(mod);
+ if (!output_program.IsValid()) {
+ result.err = output_program.Diagnostics().str();
+ result.ast = Program::printer(&output_program);
+ return result;
}
-};
+
+ auto output = writer::wgsl::Generate(&output_program, {});
+ if (!output.success) {
+ std::stringstream ss;
+ ss << "wgsl::Generate() errored: " << output.error;
+ result.err = ss.str();
+ return result;
+ }
+
+ result.wgsl = std::string(utils::TrimSpace(output.wgsl));
+ if (!result.wgsl.empty()) {
+ result.wgsl = "\n" + result.wgsl + "\n";
+ }
+
+ return result;
+}
+
+namespace {
TEST_F(IRToProgramTest, EmptyModule) {
- Test("");
+ EXPECT_WGSL("");
}
TEST_F(IRToProgramTest, SingleFunction_Empty) {
- auto* fn = b.Function("f", ty.void_());
- mod.functions.Push(fn);
+ b.Function("f", ty.void_());
- Test(R"(
+ EXPECT_WGSL(R"(
fn f() {
}
)");
@@ -78,11 +72,10 @@
TEST_F(IRToProgramTest, SingleFunction_Return) {
auto* fn = b.Function("f", ty.void_());
- mod.functions.Push(fn);
fn->Block()->Append(b.Return(fn));
- Test(R"(
+ EXPECT_WGSL(R"(
fn f() {
}
)");
@@ -90,11 +83,10 @@
TEST_F(IRToProgramTest, SingleFunction_Return_i32) {
auto* fn = b.Function("f", ty.i32());
- mod.functions.Push(fn);
fn->Block()->Append(b.Return(fn, 42_i));
- Test(R"(
+ EXPECT_WGSL(R"(
fn f() -> i32 {
return 42i;
}
@@ -103,16 +95,13 @@
TEST_F(IRToProgramTest, SingleFunction_Parameters) {
auto* fn = b.Function("f", ty.i32());
- auto* i = b.FunctionParam(ty.i32());
- auto* u = b.FunctionParam(ty.u32());
- mod.SetName(i, "i");
- mod.SetName(u, "u");
+ auto* i = b.FunctionParam("i", ty.i32());
+ auto* u = b.FunctionParam("u", ty.u32());
fn->SetParams({i, u});
- mod.functions.Push(fn);
fn->Block()->Append(b.Return(fn, i));
- Test(R"(
+ EXPECT_WGSL(R"(
fn f(i : i32, u : u32) -> i32 {
return i;
}
@@ -124,14 +113,12 @@
////////////////////////////////////////////////////////////////////////////////
TEST_F(IRToProgramTest, UnaryOp_Negate) {
auto* fn = b.Function("f", ty.i32());
- auto* i = b.FunctionParam(ty.i32());
- mod.SetName(i, "i");
+ auto* i = b.FunctionParam("i", ty.i32());
fn->SetParams({i});
- mod.functions.Push(fn);
b.With(fn->Block(), [&] { b.Return(fn, b.Negation(ty.i32(), i)); });
- Test(R"(
+ EXPECT_WGSL(R"(
fn f(i : i32) -> i32 {
return -(i);
}
@@ -140,14 +127,12 @@
TEST_F(IRToProgramTest, UnaryOp_Complement) {
auto* fn = b.Function("f", ty.u32());
- auto* i = b.FunctionParam(ty.u32());
- mod.SetName(i, "i");
+ auto* i = b.FunctionParam("i", ty.u32());
fn->SetParams({i});
- mod.functions.Push(fn);
b.With(fn->Block(), [&] { b.Return(fn, b.Complement(ty.u32(), i)); });
- Test(R"(
+ EXPECT_WGSL(R"(
fn f(i : u32) -> u32 {
return ~(i);
}
@@ -156,14 +141,12 @@
TEST_F(IRToProgramTest, UnaryOp_Not) {
auto* fn = b.Function("f", ty.bool_());
- auto* i = b.FunctionParam(ty.bool_());
- mod.SetName(i, "b");
+ auto* i = b.FunctionParam("b", ty.bool_());
fn->SetParams({i});
- mod.functions.Push(fn);
b.With(fn->Block(), [&] { b.Return(fn, b.Not(ty.bool_(), i)); });
- Test(R"(
+ EXPECT_WGSL(R"(
fn f(b : bool) -> bool {
return !(b);
}
@@ -175,16 +158,13 @@
////////////////////////////////////////////////////////////////////////////////
TEST_F(IRToProgramTest, BinaryOp_Add) {
auto* fn = b.Function("f", ty.i32());
- auto* pa = b.FunctionParam(ty.i32());
- auto* pb = b.FunctionParam(ty.i32());
- mod.SetName(pa, "a");
- mod.SetName(pb, "b");
+ auto* pa = b.FunctionParam("a", ty.i32());
+ auto* pb = b.FunctionParam("b", ty.i32());
fn->SetParams({pa, pb});
- mod.functions.Push(fn);
b.With(fn->Block(), [&] { b.Return(fn, b.Add(ty.i32(), pa, pb)); });
- Test(R"(
+ EXPECT_WGSL(R"(
fn f(a : i32, b : i32) -> i32 {
return (a + b);
}
@@ -193,16 +173,13 @@
TEST_F(IRToProgramTest, BinaryOp_Subtract) {
auto* fn = b.Function("f", ty.i32());
- auto* pa = b.FunctionParam(ty.i32());
- auto* pb = b.FunctionParam(ty.i32());
- mod.SetName(pa, "a");
- mod.SetName(pb, "b");
+ auto* pa = b.FunctionParam("a", ty.i32());
+ auto* pb = b.FunctionParam("b", ty.i32());
fn->SetParams({pa, pb});
- mod.functions.Push(fn);
b.With(fn->Block(), [&] { b.Return(fn, b.Subtract(ty.i32(), pa, pb)); });
- Test(R"(
+ EXPECT_WGSL(R"(
fn f(a : i32, b : i32) -> i32 {
return (a - b);
}
@@ -211,16 +188,13 @@
TEST_F(IRToProgramTest, BinaryOp_Multiply) {
auto* fn = b.Function("f", ty.i32());
- auto* pa = b.FunctionParam(ty.i32());
- auto* pb = b.FunctionParam(ty.i32());
- mod.SetName(pa, "a");
- mod.SetName(pb, "b");
+ auto* pa = b.FunctionParam("a", ty.i32());
+ auto* pb = b.FunctionParam("b", ty.i32());
fn->SetParams({pa, pb});
- mod.functions.Push(fn);
b.With(fn->Block(), [&] { b.Return(fn, b.Multiply(ty.i32(), pa, pb)); });
- Test(R"(
+ EXPECT_WGSL(R"(
fn f(a : i32, b : i32) -> i32 {
return (a * b);
}
@@ -229,16 +203,13 @@
TEST_F(IRToProgramTest, BinaryOp_Divide) {
auto* fn = b.Function("f", ty.i32());
- auto* pa = b.FunctionParam(ty.i32());
- auto* pb = b.FunctionParam(ty.i32());
- mod.SetName(pa, "a");
- mod.SetName(pb, "b");
+ auto* pa = b.FunctionParam("a", ty.i32());
+ auto* pb = b.FunctionParam("b", ty.i32());
fn->SetParams({pa, pb});
- mod.functions.Push(fn);
b.With(fn->Block(), [&] { b.Return(fn, b.Divide(ty.i32(), pa, pb)); });
- Test(R"(
+ EXPECT_WGSL(R"(
fn f(a : i32, b : i32) -> i32 {
return (a / b);
}
@@ -247,16 +218,13 @@
TEST_F(IRToProgramTest, BinaryOp_Modulo) {
auto* fn = b.Function("f", ty.i32());
- auto* pa = b.FunctionParam(ty.i32());
- auto* pb = b.FunctionParam(ty.i32());
- mod.SetName(pa, "a");
- mod.SetName(pb, "b");
+ auto* pa = b.FunctionParam("a", ty.i32());
+ auto* pb = b.FunctionParam("b", ty.i32());
fn->SetParams({pa, pb});
- mod.functions.Push(fn);
b.With(fn->Block(), [&] { b.Return(fn, b.Modulo(ty.i32(), pa, pb)); });
- Test(R"(
+ EXPECT_WGSL(R"(
fn f(a : i32, b : i32) -> i32 {
return (a % b);
}
@@ -265,16 +233,13 @@
TEST_F(IRToProgramTest, BinaryOp_And) {
auto* fn = b.Function("f", ty.i32());
- auto* pa = b.FunctionParam(ty.i32());
- auto* pb = b.FunctionParam(ty.i32());
- mod.SetName(pa, "a");
- mod.SetName(pb, "b");
+ auto* pa = b.FunctionParam("a", ty.i32());
+ auto* pb = b.FunctionParam("b", ty.i32());
fn->SetParams({pa, pb});
- mod.functions.Push(fn);
b.With(fn->Block(), [&] { b.Return(fn, b.And(ty.i32(), pa, pb)); });
- Test(R"(
+ EXPECT_WGSL(R"(
fn f(a : i32, b : i32) -> i32 {
return (a & b);
}
@@ -283,16 +248,13 @@
TEST_F(IRToProgramTest, BinaryOp_Or) {
auto* fn = b.Function("f", ty.i32());
- auto* pa = b.FunctionParam(ty.i32());
- auto* pb = b.FunctionParam(ty.i32());
- mod.SetName(pa, "a");
- mod.SetName(pb, "b");
+ auto* pa = b.FunctionParam("a", ty.i32());
+ auto* pb = b.FunctionParam("b", ty.i32());
fn->SetParams({pa, pb});
- mod.functions.Push(fn);
b.With(fn->Block(), [&] { b.Return(fn, b.Or(ty.i32(), pa, pb)); });
- Test(R"(
+ EXPECT_WGSL(R"(
fn f(a : i32, b : i32) -> i32 {
return (a | b);
}
@@ -301,16 +263,13 @@
TEST_F(IRToProgramTest, BinaryOp_Xor) {
auto* fn = b.Function("f", ty.i32());
- auto* pa = b.FunctionParam(ty.i32());
- auto* pb = b.FunctionParam(ty.i32());
- mod.SetName(pa, "a");
- mod.SetName(pb, "b");
+ auto* pa = b.FunctionParam("a", ty.i32());
+ auto* pb = b.FunctionParam("b", ty.i32());
fn->SetParams({pa, pb});
- mod.functions.Push(fn);
b.With(fn->Block(), [&] { b.Return(fn, b.Xor(ty.i32(), pa, pb)); });
- Test(R"(
+ EXPECT_WGSL(R"(
fn f(a : i32, b : i32) -> i32 {
return (a ^ b);
}
@@ -319,16 +278,13 @@
TEST_F(IRToProgramTest, BinaryOp_Equal) {
auto* fn = b.Function("f", ty.bool_());
- auto* pa = b.FunctionParam(ty.i32());
- auto* pb = b.FunctionParam(ty.i32());
- mod.SetName(pa, "a");
- mod.SetName(pb, "b");
+ auto* pa = b.FunctionParam("a", ty.i32());
+ auto* pb = b.FunctionParam("b", ty.i32());
fn->SetParams({pa, pb});
- mod.functions.Push(fn);
b.With(fn->Block(), [&] { b.Return(fn, b.Equal(ty.i32(), pa, pb)); });
- Test(R"(
+ EXPECT_WGSL(R"(
fn f(a : i32, b : i32) -> bool {
return (a == b);
}
@@ -337,16 +293,13 @@
TEST_F(IRToProgramTest, BinaryOp_NotEqual) {
auto* fn = b.Function("f", ty.bool_());
- auto* pa = b.FunctionParam(ty.i32());
- auto* pb = b.FunctionParam(ty.i32());
- mod.SetName(pa, "a");
- mod.SetName(pb, "b");
+ auto* pa = b.FunctionParam("a", ty.i32());
+ auto* pb = b.FunctionParam("b", ty.i32());
fn->SetParams({pa, pb});
- mod.functions.Push(fn);
b.With(fn->Block(), [&] { b.Return(fn, b.NotEqual(ty.i32(), pa, pb)); });
- Test(R"(
+ EXPECT_WGSL(R"(
fn f(a : i32, b : i32) -> bool {
return (a != b);
}
@@ -355,16 +308,13 @@
TEST_F(IRToProgramTest, BinaryOp_LessThan) {
auto* fn = b.Function("f", ty.bool_());
- auto* pa = b.FunctionParam(ty.i32());
- auto* pb = b.FunctionParam(ty.i32());
- mod.SetName(pa, "a");
- mod.SetName(pb, "b");
+ auto* pa = b.FunctionParam("a", ty.i32());
+ auto* pb = b.FunctionParam("b", ty.i32());
fn->SetParams({pa, pb});
- mod.functions.Push(fn);
b.With(fn->Block(), [&] { b.Return(fn, b.LessThan(ty.i32(), pa, pb)); });
- Test(R"(
+ EXPECT_WGSL(R"(
fn f(a : i32, b : i32) -> bool {
return (a < b);
}
@@ -373,16 +323,13 @@
TEST_F(IRToProgramTest, BinaryOp_GreaterThan) {
auto* fn = b.Function("f", ty.bool_());
- auto* pa = b.FunctionParam(ty.i32());
- auto* pb = b.FunctionParam(ty.i32());
- mod.SetName(pa, "a");
- mod.SetName(pb, "b");
+ auto* pa = b.FunctionParam("a", ty.i32());
+ auto* pb = b.FunctionParam("b", ty.i32());
fn->SetParams({pa, pb});
- mod.functions.Push(fn);
b.With(fn->Block(), [&] { b.Return(fn, b.GreaterThan(ty.i32(), pa, pb)); });
- Test(R"(
+ EXPECT_WGSL(R"(
fn f(a : i32, b : i32) -> bool {
return (a > b);
}
@@ -391,16 +338,13 @@
TEST_F(IRToProgramTest, BinaryOp_LessThanEqual) {
auto* fn = b.Function("f", ty.bool_());
- auto* pa = b.FunctionParam(ty.i32());
- auto* pb = b.FunctionParam(ty.i32());
- mod.SetName(pa, "a");
- mod.SetName(pb, "b");
+ auto* pa = b.FunctionParam("a", ty.i32());
+ auto* pb = b.FunctionParam("b", ty.i32());
fn->SetParams({pa, pb});
- mod.functions.Push(fn);
b.With(fn->Block(), [&] { b.Return(fn, b.LessThanEqual(ty.i32(), pa, pb)); });
- Test(R"(
+ EXPECT_WGSL(R"(
fn f(a : i32, b : i32) -> bool {
return (a <= b);
}
@@ -409,16 +353,13 @@
TEST_F(IRToProgramTest, BinaryOp_GreaterThanEqual) {
auto* fn = b.Function("f", ty.bool_());
- auto* pa = b.FunctionParam(ty.i32());
- auto* pb = b.FunctionParam(ty.i32());
- mod.SetName(pa, "a");
- mod.SetName(pb, "b");
+ auto* pa = b.FunctionParam("a", ty.i32());
+ auto* pb = b.FunctionParam("b", ty.i32());
fn->SetParams({pa, pb});
- mod.functions.Push(fn);
b.With(fn->Block(), [&] { b.Return(fn, b.GreaterThanEqual(ty.i32(), pa, pb)); });
- Test(R"(
+ EXPECT_WGSL(R"(
fn f(a : i32, b : i32) -> bool {
return (a >= b);
}
@@ -427,16 +368,13 @@
TEST_F(IRToProgramTest, BinaryOp_ShiftLeft) {
auto* fn = b.Function("f", ty.i32());
- auto* pa = b.FunctionParam(ty.i32());
- auto* pb = b.FunctionParam(ty.u32());
- mod.SetName(pa, "a");
- mod.SetName(pb, "b");
+ auto* pa = b.FunctionParam("a", ty.i32());
+ auto* pb = b.FunctionParam("b", ty.u32());
fn->SetParams({pa, pb});
- mod.functions.Push(fn);
b.With(fn->Block(), [&] { b.Return(fn, b.ShiftLeft(ty.i32(), pa, pb)); });
- Test(R"(
+ EXPECT_WGSL(R"(
fn f(a : i32, b : u32) -> i32 {
return (a << b);
}
@@ -445,16 +383,13 @@
TEST_F(IRToProgramTest, BinaryOp_ShiftRight) {
auto* fn = b.Function("f", ty.i32());
- auto* pa = b.FunctionParam(ty.i32());
- auto* pb = b.FunctionParam(ty.u32());
- mod.SetName(pa, "a");
- mod.SetName(pb, "b");
+ auto* pa = b.FunctionParam("a", ty.i32());
+ auto* pb = b.FunctionParam("b", ty.u32());
fn->SetParams({pa, pb});
- mod.functions.Push(fn);
b.With(fn->Block(), [&] { b.Return(fn, b.ShiftRight(ty.i32(), pa, pb)); });
- Test(R"(
+ EXPECT_WGSL(R"(
fn f(a : i32, b : u32) -> i32 {
return (a >> b);
}
@@ -464,14 +399,11 @@
////////////////////////////////////////////////////////////////////////////////
// Short-circuiting binary ops
////////////////////////////////////////////////////////////////////////////////
-TEST_F(IRToProgramTest, BinaryOp_LogicalAnd_Param_2) {
+TEST_F(IRToProgramTest, ShortCircuit_And_Param_2) {
auto* fn = b.Function("f", ty.bool_());
- auto* pa = b.FunctionParam(ty.bool_());
- auto* pb = b.FunctionParam(ty.bool_());
- mod.SetName(pa, "a");
- mod.SetName(pb, "b");
+ auto* pa = b.FunctionParam("a", ty.bool_());
+ auto* pb = b.FunctionParam("b", ty.bool_());
fn->SetParams({pa, pb});
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
auto* if_ = b.If(pa);
@@ -482,23 +414,19 @@
b.Return(fn, if_->Result(0));
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn f(a : bool, b : bool) -> bool {
return (a && b);
}
)");
}
-TEST_F(IRToProgramTest, BinaryOp_LogicalAnd_Param_3_ab_c) {
+TEST_F(IRToProgramTest, ShortCircuit_And_Param_3_ab_c) {
auto* fn = b.Function("f", ty.bool_());
- auto* pa = b.FunctionParam(ty.bool_());
- auto* pb = b.FunctionParam(ty.bool_());
- auto* pc = b.FunctionParam(ty.bool_());
- mod.SetName(pa, "a");
- mod.SetName(pb, "b");
- mod.SetName(pc, "c");
+ auto* pa = b.FunctionParam("a", ty.bool_());
+ auto* pb = b.FunctionParam("b", ty.bool_());
+ auto* pc = b.FunctionParam("c", ty.bool_());
fn->SetParams({pa, pb, pc});
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
auto* if1 = b.If(pa);
@@ -514,52 +442,47 @@
b.Return(fn, if2->Result(0));
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn f(a : bool, b : bool, c : bool) -> bool {
return ((a && b) && c);
}
)");
}
-TEST_F(IRToProgramTest, BinaryOp_LogicalAnd_Param_3_a_bc) {
+TEST_F(IRToProgramTest, ShortCircuit_And_Param_3_a_bc) {
auto* fn = b.Function("f", ty.bool_());
- auto* pa = b.FunctionParam(ty.bool_());
- auto* pb = b.FunctionParam(ty.bool_());
- auto* pc = b.FunctionParam(ty.bool_());
- mod.SetName(pa, "a");
- mod.SetName(pb, "b");
- mod.SetName(pc, "c");
+ auto* pa = b.FunctionParam("a", ty.bool_());
+ auto* pb = b.FunctionParam("b", ty.bool_());
+ auto* pc = b.FunctionParam("c", ty.bool_());
fn->SetParams({pa, pb, pc});
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
- auto* if1 = b.If(pb);
+ auto* if1 = b.If(pa);
if1->SetResults(b.InstructionResult(ty.bool_()));
- b.With(if1->True(), [&] { b.ExitIf(if1, pc); });
- b.With(if1->False(), [&] { b.ExitIf(if1, false); });
+ b.With(if1->True(), [&] {
+ auto* if2 = b.If(pb);
+ if2->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if2->True(), [&] { b.ExitIf(if2, pc); });
+ b.With(if2->False(), [&] { b.ExitIf(if2, false); });
- auto* if2 = b.If(pa);
- if2->SetResults(b.InstructionResult(ty.bool_()));
- b.With(if2->True(), [&] { b.ExitIf(if2, if1->Result(0)); });
- b.With(if2->False(), [&] { b.ExitIf(if2, false); });
- b.Return(fn, if2->Result(0));
+ b.ExitIf(if1, if2->Result(0));
+ });
+ b.With(if1->False(), [&] { b.ExitIf(if1, false); });
+ b.Return(fn, if1->Result(0));
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn f(a : bool, b : bool, c : bool) -> bool {
return (a && (b && c));
}
)");
}
-TEST_F(IRToProgramTest, BinaryOp_LogicalAnd_Let_2) {
+TEST_F(IRToProgramTest, ShortCircuit_And_Let_2) {
auto* fn = b.Function("f", ty.bool_());
- auto* pa = b.FunctionParam(ty.bool_());
- auto* pb = b.FunctionParam(ty.bool_());
- mod.SetName(pa, "a");
- mod.SetName(pb, "b");
+ auto* pa = b.FunctionParam("a", ty.bool_());
+ auto* pb = b.FunctionParam("b", ty.bool_());
fn->SetParams({pa, pb});
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
auto* if_ = b.If(pa);
@@ -571,7 +494,7 @@
b.Return(fn, if_->Result(0));
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn f(a : bool, b : bool) -> bool {
let l = (a && b);
return l;
@@ -579,16 +502,12 @@
)");
}
-TEST_F(IRToProgramTest, BinaryOp_LogicalAnd_Let_3_ab_c) {
+TEST_F(IRToProgramTest, ShortCircuit_And_Let_3_ab_c) {
auto* fn = b.Function("f", ty.bool_());
- auto* pa = b.FunctionParam(ty.bool_());
- auto* pb = b.FunctionParam(ty.bool_());
- auto* pc = b.FunctionParam(ty.bool_());
- mod.SetName(pa, "a");
- mod.SetName(pb, "b");
- mod.SetName(pc, "c");
+ auto* pa = b.FunctionParam("a", ty.bool_());
+ auto* pb = b.FunctionParam("b", ty.bool_());
+ auto* pc = b.FunctionParam("c", ty.bool_());
fn->SetParams({pa, pb, pc});
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
auto* if1 = b.If(pa);
@@ -605,7 +524,7 @@
b.Return(fn, if2->Result(0));
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn f(a : bool, b : bool, c : bool) -> bool {
let l = ((a && b) && c);
return l;
@@ -613,33 +532,31 @@
)");
}
-TEST_F(IRToProgramTest, BinaryOp_LogicalAnd_Let_3_a_bc) {
+TEST_F(IRToProgramTest, ShortCircuit_And_Let_3_a_bc) {
auto* fn = b.Function("f", ty.bool_());
- auto* pa = b.FunctionParam(ty.bool_());
- auto* pb = b.FunctionParam(ty.bool_());
- auto* pc = b.FunctionParam(ty.bool_());
- mod.SetName(pa, "a");
- mod.SetName(pb, "b");
- mod.SetName(pc, "c");
+ auto* pa = b.FunctionParam("a", ty.bool_());
+ auto* pb = b.FunctionParam("b", ty.bool_());
+ auto* pc = b.FunctionParam("c", ty.bool_());
fn->SetParams({pa, pb, pc});
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
- auto* if1 = b.If(pb);
+ auto* if1 = b.If(pa);
if1->SetResults(b.InstructionResult(ty.bool_()));
- b.With(if1->True(), [&] { b.ExitIf(if1, pc); });
+ b.With(if1->True(), [&] {
+ auto* if2 = b.If(pb);
+ if2->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if2->True(), [&] { b.ExitIf(if2, pc); });
+ b.With(if2->False(), [&] { b.ExitIf(if2, false); });
+
+ b.ExitIf(if1, if2->Result(0));
+ });
b.With(if1->False(), [&] { b.ExitIf(if1, false); });
- auto* if2 = b.If(pa);
- if2->SetResults(b.InstructionResult(ty.bool_()));
- b.With(if2->True(), [&] { b.ExitIf(if2, if1->Result(0)); });
- b.With(if2->False(), [&] { b.ExitIf(if2, false); });
-
- mod.SetName(if2->Result(0), "l");
- b.Return(fn, if2->Result(0));
+ mod.SetName(if1->Result(0), "l");
+ b.Return(fn, if1->Result(0));
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn f(a : bool, b : bool, c : bool) -> bool {
let l = (a && (b && c));
return l;
@@ -647,17 +564,14 @@
)");
}
-TEST_F(IRToProgramTest, BinaryOp_LogicalAnd_Call_2) {
+TEST_F(IRToProgramTest, ShortCircuit_And_Call_2) {
auto* fn_a = b.Function("a", ty.bool_());
- mod.functions.Push(fn_a);
b.With(fn_a->Block(), [&] { b.Return(fn_a, true); });
auto* fn_b = b.Function("b", ty.bool_());
- mod.functions.Push(fn_b);
b.With(fn_b->Block(), [&] { b.Return(fn_b, true); });
auto* fn = b.Function("f", ty.bool_());
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
auto* if_ = b.If(b.Call(ty.bool_(), fn_a));
@@ -668,7 +582,7 @@
b.Return(fn, if_->Result(0));
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn a() -> bool {
return true;
}
@@ -683,21 +597,17 @@
)");
}
-TEST_F(IRToProgramTest, BinaryOp_LogicalAnd_Call_3_ab_c) {
+TEST_F(IRToProgramTest, ShortCircuit_And_Call_3_ab_c) {
auto* fn_a = b.Function("a", ty.bool_());
- mod.functions.Push(fn_a);
b.With(fn_a->Block(), [&] { b.Return(fn_a, true); });
auto* fn_b = b.Function("b", ty.bool_());
- mod.functions.Push(fn_b);
b.With(fn_b->Block(), [&] { b.Return(fn_b, true); });
auto* fn_c = b.Function("c", ty.bool_());
- mod.functions.Push(fn_c);
b.With(fn_c->Block(), [&] { b.Return(fn_c, true); });
auto* fn = b.Function("f", ty.bool_());
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
auto* if1 = b.If(b.Call(ty.bool_(), fn_a));
@@ -713,7 +623,7 @@
b.Return(fn, if2->Result(0));
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn a() -> bool {
return true;
}
@@ -732,37 +642,35 @@
)");
}
-TEST_F(IRToProgramTest, BinaryOp_LogicalAnd_Call_3_a_bc) {
+TEST_F(IRToProgramTest, ShortCircuit_And_Call_3_a_bc) {
auto* fn_a = b.Function("a", ty.bool_());
- mod.functions.Push(fn_a);
b.With(fn_a->Block(), [&] { b.Return(fn_a, true); });
auto* fn_b = b.Function("b", ty.bool_());
- mod.functions.Push(fn_b);
b.With(fn_b->Block(), [&] { b.Return(fn_b, true); });
auto* fn_c = b.Function("c", ty.bool_());
- mod.functions.Push(fn_c);
b.With(fn_c->Block(), [&] { b.Return(fn_c, true); });
auto* fn = b.Function("f", ty.bool_());
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
- auto* if1 = b.If(b.Call(ty.bool_(), fn_b));
+ auto* if1 = b.If(b.Call(ty.bool_(), fn_a));
if1->SetResults(b.InstructionResult(ty.bool_()));
- b.With(if1->True(), [&] { b.ExitIf(if1, b.Call(ty.bool_(), fn_c)); });
+ b.With(if1->True(), [&] {
+ auto* if2 = b.If(b.Call(ty.bool_(), fn_b));
+ if2->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if2->True(), [&] { b.ExitIf(if2, b.Call(ty.bool_(), fn_c)); });
+ b.With(if2->False(), [&] { b.ExitIf(if2, false); });
+
+ b.ExitIf(if1, if2->Result(0));
+ });
b.With(if1->False(), [&] { b.ExitIf(if1, false); });
- auto* if2 = b.If(b.Call(ty.bool_(), fn_a));
- if2->SetResults(b.InstructionResult(ty.bool_()));
- b.With(if2->True(), [&] { b.ExitIf(if2, if1->Result(0)); });
- b.With(if2->False(), [&] { b.ExitIf(if2, false); });
-
- b.Return(fn, if2->Result(0));
+ b.Return(fn, if1->Result(0));
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn a() -> bool {
return true;
}
@@ -781,14 +689,11 @@
)");
}
-TEST_F(IRToProgramTest, BinaryOp_LogicalOr_Param_2) {
+TEST_F(IRToProgramTest, ShortCircuit_Or_Param_2) {
auto* fn = b.Function("f", ty.bool_());
- auto* pa = b.FunctionParam(ty.bool_());
- auto* pb = b.FunctionParam(ty.bool_());
- mod.SetName(pa, "a");
- mod.SetName(pb, "b");
+ auto* pa = b.FunctionParam("a", ty.bool_());
+ auto* pb = b.FunctionParam("b", ty.bool_());
fn->SetParams({pa, pb});
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
auto* if_ = b.If(pa);
@@ -799,23 +704,19 @@
b.Return(fn, if_->Result(0));
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn f(a : bool, b : bool) -> bool {
return (a || b);
}
)");
}
-TEST_F(IRToProgramTest, BinaryOp_LogicalOr_Param_3_ab_c) {
+TEST_F(IRToProgramTest, ShortCircuit_Or_Param_3_ab_c) {
auto* fn = b.Function("f", ty.bool_());
- auto* pa = b.FunctionParam(ty.bool_());
- auto* pb = b.FunctionParam(ty.bool_());
- auto* pc = b.FunctionParam(ty.bool_());
- mod.SetName(pa, "a");
- mod.SetName(pb, "b");
- mod.SetName(pc, "c");
+ auto* pa = b.FunctionParam("a", ty.bool_());
+ auto* pb = b.FunctionParam("b", ty.bool_());
+ auto* pc = b.FunctionParam("c", ty.bool_());
fn->SetParams({pa, pb, pc});
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
auto* if1 = b.If(pa);
@@ -831,53 +732,48 @@
b.Return(fn, if2->Result(0));
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn f(a : bool, b : bool, c : bool) -> bool {
return ((a || b) || c);
}
)");
}
-TEST_F(IRToProgramTest, BinaryOp_LogicalOr_Param_3_a_bc) {
+TEST_F(IRToProgramTest, ShortCircuit_Or_Param_3_a_bc) {
auto* fn = b.Function("f", ty.bool_());
- auto* pa = b.FunctionParam(ty.bool_());
- auto* pb = b.FunctionParam(ty.bool_());
- auto* pc = b.FunctionParam(ty.bool_());
- mod.SetName(pa, "a");
- mod.SetName(pb, "b");
- mod.SetName(pc, "c");
+ auto* pa = b.FunctionParam("a", ty.bool_());
+ auto* pb = b.FunctionParam("b", ty.bool_());
+ auto* pc = b.FunctionParam("c", ty.bool_());
fn->SetParams({pa, pb, pc});
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
- auto* if1 = b.If(pb);
+ auto* if1 = b.If(pa);
if1->SetResults(b.InstructionResult(ty.bool_()));
b.With(if1->True(), [&] { b.ExitIf(if1, true); });
- b.With(if1->False(), [&] { b.ExitIf(if1, pc); });
+ b.With(if1->False(), [&] {
+ auto* if2 = b.If(pb);
+ if2->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if2->True(), [&] { b.ExitIf(if2, true); });
+ b.With(if2->False(), [&] { b.ExitIf(if2, pc); });
- auto* if2 = b.If(pa);
- if2->SetResults(b.InstructionResult(ty.bool_()));
- b.With(if2->True(), [&] { b.ExitIf(if2, true); });
- b.With(if2->False(), [&] { b.ExitIf(if2, if1->Result(0)); });
+ b.ExitIf(if1, if2->Result(0));
+ });
- b.Return(fn, if2->Result(0));
+ b.Return(fn, if1->Result(0));
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn f(a : bool, b : bool, c : bool) -> bool {
return (a || (b || c));
}
)");
}
-TEST_F(IRToProgramTest, BinaryOp_LogicalOr_Let_2) {
+TEST_F(IRToProgramTest, ShortCircuit_Or_Let_2) {
auto* fn = b.Function("f", ty.bool_());
- auto* pa = b.FunctionParam(ty.bool_());
- auto* pb = b.FunctionParam(ty.bool_());
- mod.SetName(pa, "a");
- mod.SetName(pb, "b");
+ auto* pa = b.FunctionParam("a", ty.bool_());
+ auto* pb = b.FunctionParam("b", ty.bool_());
fn->SetParams({pa, pb});
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
auto* if_ = b.If(pa);
@@ -889,7 +785,7 @@
b.Return(fn, if_->Result(0));
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn f(a : bool, b : bool) -> bool {
let l = (a || b);
return l;
@@ -897,16 +793,12 @@
)");
}
-TEST_F(IRToProgramTest, BinaryOp_LogicalOr_Let_3_ab_c) {
+TEST_F(IRToProgramTest, ShortCircuit_Or_Let_3_ab_c) {
auto* fn = b.Function("f", ty.bool_());
- auto* pa = b.FunctionParam(ty.bool_());
- auto* pb = b.FunctionParam(ty.bool_());
- auto* pc = b.FunctionParam(ty.bool_());
- mod.SetName(pa, "a");
- mod.SetName(pb, "b");
- mod.SetName(pc, "c");
+ auto* pa = b.FunctionParam("a", ty.bool_());
+ auto* pb = b.FunctionParam("b", ty.bool_());
+ auto* pc = b.FunctionParam("c", ty.bool_());
fn->SetParams({pa, pb, pc});
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
auto* if1 = b.If(pa);
@@ -923,7 +815,7 @@
b.Return(fn, if2->Result(0));
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn f(a : bool, b : bool, c : bool) -> bool {
let l = ((a || b) || c);
return l;
@@ -931,33 +823,31 @@
)");
}
-TEST_F(IRToProgramTest, BinaryOp_LogicalOr_Let_3_a_bc) {
+TEST_F(IRToProgramTest, ShortCircuit_Or_Let_3_a_bc) {
auto* fn = b.Function("f", ty.bool_());
- auto* pa = b.FunctionParam(ty.bool_());
- auto* pb = b.FunctionParam(ty.bool_());
- auto* pc = b.FunctionParam(ty.bool_());
- mod.SetName(pa, "a");
- mod.SetName(pb, "b");
- mod.SetName(pc, "c");
+ auto* pa = b.FunctionParam("a", ty.bool_());
+ auto* pb = b.FunctionParam("b", ty.bool_());
+ auto* pc = b.FunctionParam("c", ty.bool_());
fn->SetParams({pa, pb, pc});
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
- auto* if1 = b.If(pb);
+ auto* if1 = b.If(pa);
if1->SetResults(b.InstructionResult(ty.bool_()));
b.With(if1->True(), [&] { b.ExitIf(if1, true); });
- b.With(if1->False(), [&] { b.ExitIf(if1, pc); });
+ b.With(if1->False(), [&] {
+ auto* if2 = b.If(pb);
+ if2->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if2->True(), [&] { b.ExitIf(if2, true); });
+ b.With(if2->False(), [&] { b.ExitIf(if2, pc); });
- auto* if2 = b.If(pa);
- if2->SetResults(b.InstructionResult(ty.bool_()));
- b.With(if2->True(), [&] { b.ExitIf(if2, true); });
- b.With(if2->False(), [&] { b.ExitIf(if2, if1->Result(0)); });
+ b.ExitIf(if1, if2->Result(0));
+ });
- mod.SetName(if2->Result(0), "l");
- b.Return(fn, if2->Result(0));
+ mod.SetName(if1->Result(0), "l");
+ b.Return(fn, if1->Result(0));
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn f(a : bool, b : bool, c : bool) -> bool {
let l = (a || (b || c));
return l;
@@ -965,17 +855,14 @@
)");
}
-TEST_F(IRToProgramTest, BinaryOp_LogicalOr_Call_2) {
+TEST_F(IRToProgramTest, ShortCircuit_Or_Call_2) {
auto* fn_a = b.Function("a", ty.bool_());
- mod.functions.Push(fn_a);
b.With(fn_a->Block(), [&] { b.Return(fn_a, true); });
auto* fn_b = b.Function("b", ty.bool_());
- mod.functions.Push(fn_b);
b.With(fn_b->Block(), [&] { b.Return(fn_b, true); });
auto* fn = b.Function("f", ty.bool_());
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
auto* if_ = b.If(b.Call(ty.bool_(), fn_a));
@@ -986,7 +873,7 @@
b.Return(fn, if_->Result(0));
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn a() -> bool {
return true;
}
@@ -1001,21 +888,17 @@
)");
}
-TEST_F(IRToProgramTest, BinaryOp_LogicalOr_Call_3_ab_c) {
+TEST_F(IRToProgramTest, ShortCircuit_Or_Call_3_ab_c) {
auto* fn_a = b.Function("a", ty.bool_());
- mod.functions.Push(fn_a);
b.With(fn_a->Block(), [&] { b.Return(fn_a, true); });
auto* fn_b = b.Function("b", ty.bool_());
- mod.functions.Push(fn_b);
b.With(fn_b->Block(), [&] { b.Return(fn_b, true); });
auto* fn_c = b.Function("c", ty.bool_());
- mod.functions.Push(fn_c);
b.With(fn_c->Block(), [&] { b.Return(fn_c, true); });
auto* fn = b.Function("f", ty.bool_());
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
auto* if1 = b.If(b.Call(ty.bool_(), fn_a));
@@ -1031,7 +914,7 @@
b.Return(fn, if2->Result(0));
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn a() -> bool {
return true;
}
@@ -1050,37 +933,35 @@
)");
}
-TEST_F(IRToProgramTest, BinaryOp_LogicalOr_Call_3_a_bc) {
+TEST_F(IRToProgramTest, ShortCircuit_Or_Call_3_a_bc) {
auto* fn_a = b.Function("a", ty.bool_());
- mod.functions.Push(fn_a);
b.With(fn_a->Block(), [&] { b.Return(fn_a, true); });
auto* fn_b = b.Function("b", ty.bool_());
- mod.functions.Push(fn_b);
b.With(fn_b->Block(), [&] { b.Return(fn_b, true); });
auto* fn_c = b.Function("c", ty.bool_());
- mod.functions.Push(fn_c);
b.With(fn_c->Block(), [&] { b.Return(fn_c, true); });
auto* fn = b.Function("f", ty.bool_());
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
- auto* if1 = b.If(b.Call(ty.bool_(), fn_b));
+ auto* if1 = b.If(b.Call(ty.bool_(), fn_a));
if1->SetResults(b.InstructionResult(ty.bool_()));
b.With(if1->True(), [&] { b.ExitIf(if1, true); });
- b.With(if1->False(), [&] { b.ExitIf(if1, b.Call(ty.bool_(), fn_c)); });
+ b.With(if1->False(), [&] {
+ auto* if2 = b.If(b.Call(ty.bool_(), fn_b));
+ if2->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if2->True(), [&] { b.ExitIf(if2, true); });
+ b.With(if2->False(), [&] { b.ExitIf(if2, b.Call(ty.bool_(), fn_c)); });
- auto* if2 = b.If(b.Call(ty.bool_(), fn_a));
- if2->SetResults(b.InstructionResult(ty.bool_()));
- b.With(if2->True(), [&] { b.ExitIf(if2, true); });
- b.With(if2->False(), [&] { b.ExitIf(if2, if1->Result(0)); });
+ b.ExitIf(if1, if2->Result(0));
+ });
- b.Return(fn, if2->Result(0));
+ b.Return(fn, if1->Result(0));
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn a() -> bool {
return true;
}
@@ -1099,22 +980,17 @@
)");
}
-TEST_F(IRToProgramTest, BinaryOp_LogicalMixed) {
+TEST_F(IRToProgramTest, ShortCircuit_Mixed) {
auto* fn_b = b.Function("b", ty.bool_());
- mod.functions.Push(fn_b);
b.With(fn_b->Block(), [&] { b.Return(fn_b, true); });
auto* fn_d = b.Function("d", ty.bool_());
- mod.functions.Push(fn_d);
b.With(fn_d->Block(), [&] { b.Return(fn_d, true); });
auto* fn = b.Function("f", ty.bool_());
- auto* pa = b.FunctionParam(ty.bool_());
- auto* pc = b.FunctionParam(ty.bool_());
- mod.SetName(pa, "a");
- mod.SetName(pc, "c");
+ auto* pa = b.FunctionParam("a", ty.bool_());
+ auto* pc = b.FunctionParam("c", ty.bool_());
fn->SetParams({pa, pc});
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
auto* if1 = b.If(pa);
@@ -1138,7 +1014,7 @@
b.Return(fn, if2->Result(0));
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn b() -> bool {
return true;
}
@@ -1155,19 +1031,266 @@
}
////////////////////////////////////////////////////////////////////////////////
+// Non short-circuiting binary ops
+// Similar to the above, but cannot be short-circuited as the RHS is evaluated
+// outside of the if block.
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(IRToProgramTest, NonShortCircuit_And_ParamCallParam_a_bc) {
+ auto* fn_b = b.Function("b", ty.bool_());
+ b.With(fn_b->Block(), [&] { b.Return(fn_b, true); });
+
+ auto* fn = b.Function("f", ty.bool_());
+ auto* pa = b.FunctionParam(ty.bool_());
+ auto* pc = b.FunctionParam(ty.bool_());
+ mod.SetName(pa, "a");
+ mod.SetName(pc, "c");
+ fn->SetParams({pa, pc});
+
+ b.With(fn->Block(), [&] {
+ // 'b() && c' is evaluated before 'a'.
+ auto* if1 = b.If(b.Call(ty.bool_(), fn_b));
+ if1->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if1->True(), [&] { b.ExitIf(if1, pc); });
+ b.With(if1->False(), [&] { b.ExitIf(if1, false); });
+
+ auto* if2 = b.If(pa);
+ if2->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if2->True(), [&] { b.ExitIf(if2, if1->Result(0)); });
+ b.With(if2->False(), [&] { b.ExitIf(if2, false); });
+ b.Return(fn, if2->Result(0));
+ });
+
+ EXPECT_WGSL(R"(
+fn b() -> bool {
+ return true;
+}
+
+fn f(a : bool, c : bool) -> bool {
+ let v = (b() && c);
+ return (a && v);
+}
+)");
+}
+
+TEST_F(IRToProgramTest, NonShortCircuit_And_Call_3_a_bc) {
+ auto* fn_a = b.Function("a", ty.bool_());
+
+ b.With(fn_a->Block(), [&] { b.Return(fn_a, true); });
+
+ auto* fn_b = b.Function("b", ty.bool_());
+
+ b.With(fn_b->Block(), [&] { b.Return(fn_b, true); });
+
+ auto* fn_c = b.Function("c", ty.bool_());
+
+ b.With(fn_c->Block(), [&] { b.Return(fn_c, true); });
+
+ auto* fn = b.Function("f", ty.bool_());
+
+ b.With(fn->Block(), [&] {
+ // 'b() && c()' is evaluated before 'a()'.
+ auto* if1 = b.If(b.Call(ty.bool_(), fn_b));
+ if1->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if1->True(), [&] { b.ExitIf(if1, b.Call(ty.bool_(), fn_c)); });
+ b.With(if1->False(), [&] { b.ExitIf(if1, false); });
+
+ auto* if2 = b.If(b.Call(ty.bool_(), fn_a));
+ if2->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if2->True(), [&] { b.ExitIf(if2, if1->Result(0)); });
+ b.With(if2->False(), [&] { b.ExitIf(if2, false); });
+
+ b.Return(fn, if2->Result(0));
+ });
+
+ EXPECT_WGSL(R"(
+fn a() -> bool {
+ return true;
+}
+
+fn b() -> bool {
+ return true;
+}
+
+fn c() -> bool {
+ return true;
+}
+
+fn f() -> bool {
+ let v = (b() && c());
+ return (a() && v);
+}
+)");
+}
+
+TEST_F(IRToProgramTest, NonShortCircuit_And_Param_3_a_bc_EarlyEval) {
+ auto* fn = b.Function("f", ty.bool_());
+ auto* pa = b.FunctionParam(ty.bool_());
+ auto* pb = b.FunctionParam(ty.bool_());
+ auto* pc = b.FunctionParam(ty.bool_());
+ mod.SetName(pa, "a");
+ mod.SetName(pb, "b");
+ mod.SetName(pc, "c");
+ fn->SetParams({pa, pb, pc});
+
+ b.With(fn->Block(), [&] {
+ // 'b && c' is evaluated outside the true block of if2, but these can be moved to the RHS
+ // of the 'a &&' as the 'b && c' is not sequenced.
+ auto* if1 = b.If(pb);
+ if1->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if1->True(), [&] { b.ExitIf(if1, pc); });
+ b.With(if1->False(), [&] { b.ExitIf(if1, false); });
+
+ auto* if2 = b.If(pa);
+ if2->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if2->True(), [&] { b.ExitIf(if2, if1->Result(0)); });
+ b.With(if2->False(), [&] { b.ExitIf(if2, false); });
+
+ b.Return(fn, if2->Result(0));
+ });
+
+ EXPECT_WGSL(R"(
+fn f(a : bool, b : bool, c : bool) -> bool {
+ let v = (b && c);
+ return (a && v);
+}
+)");
+}
+
+TEST_F(IRToProgramTest, NonShortCircuit_Or_ParamCallParam_a_bc) {
+ auto* fn_b = b.Function("b", ty.bool_());
+ b.With(fn_b->Block(), [&] { b.Return(fn_b, true); });
+
+ auto* fn = b.Function("f", ty.bool_());
+ auto* pa = b.FunctionParam(ty.bool_());
+ auto* pc = b.FunctionParam(ty.bool_());
+ mod.SetName(pa, "a");
+ mod.SetName(pc, "c");
+ fn->SetParams({pa, pc});
+
+ b.With(fn->Block(), [&] {
+ // 'b() && c' is evaluated before 'a'.
+ auto* if1 = b.If(b.Call(ty.bool_(), fn_b));
+ if1->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if1->True(), [&] { b.ExitIf(if1, true); });
+ b.With(if1->False(), [&] { b.ExitIf(if1, pc); });
+
+ auto* if2 = b.If(pa);
+ if2->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if2->True(), [&] { b.ExitIf(if2, true); });
+ b.With(if2->False(), [&] { b.ExitIf(if2, if1->Result(0)); });
+
+ mod.SetName(if2->Result(0), "l");
+ b.Return(fn, if2->Result(0));
+ });
+
+ EXPECT_WGSL(R"(
+fn b() -> bool {
+ return true;
+}
+
+fn f(a : bool, c : bool) -> bool {
+ let v = (b() || c);
+ let l = (a || v);
+ return l;
+}
+)");
+}
+
+TEST_F(IRToProgramTest, NonShortCircuit_Or_Call_3_a_bc) {
+ auto* fn_a = b.Function("a", ty.bool_());
+
+ b.With(fn_a->Block(), [&] { b.Return(fn_a, true); });
+
+ auto* fn_b = b.Function("b", ty.bool_());
+
+ b.With(fn_b->Block(), [&] { b.Return(fn_b, true); });
+
+ auto* fn_c = b.Function("c", ty.bool_());
+
+ b.With(fn_c->Block(), [&] { b.Return(fn_c, true); });
+
+ auto* fn = b.Function("f", ty.bool_());
+
+ b.With(fn->Block(), [&] {
+ auto* if1 = b.If(b.Call(ty.bool_(), fn_b));
+ if1->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if1->True(), [&] { b.ExitIf(if1, true); });
+ b.With(if1->False(), [&] { b.ExitIf(if1, b.Call(ty.bool_(), fn_c)); });
+
+ auto* if2 = b.If(b.Call(ty.bool_(), fn_a));
+ if2->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if2->True(), [&] { b.ExitIf(if2, true); });
+ b.With(if2->False(), [&] { b.ExitIf(if2, if1->Result(0)); });
+
+ b.Return(fn, if2->Result(0));
+ });
+
+ EXPECT_WGSL(R"(
+fn a() -> bool {
+ return true;
+}
+
+fn b() -> bool {
+ return true;
+}
+
+fn c() -> bool {
+ return true;
+}
+
+fn f() -> bool {
+ let v = (b() || c());
+ return (a() || v);
+}
+)");
+}
+
+TEST_F(IRToProgramTest, NonShortCircuit_Or_Param_3_a_bc_EarlyEval) {
+ auto* fn = b.Function("f", ty.bool_());
+ auto* pa = b.FunctionParam(ty.bool_());
+ auto* pb = b.FunctionParam(ty.bool_());
+ auto* pc = b.FunctionParam(ty.bool_());
+ mod.SetName(pa, "a");
+ mod.SetName(pb, "b");
+ mod.SetName(pc, "c");
+ fn->SetParams({pa, pb, pc});
+
+ b.With(fn->Block(), [&] {
+ // 'b || c' is evaluated outside the true block of if2, but these can be moved to the RHS
+ // of the 'a ||' as the 'b || c' is not sequenced.
+ auto* if1 = b.If(pb);
+ if1->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if1->True(), [&] { b.ExitIf(if1, true); });
+ b.With(if1->False(), [&] { b.ExitIf(if1, pc); });
+
+ auto* if2 = b.If(pa);
+ if2->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if2->True(), [&] { b.ExitIf(if2, true); });
+ b.With(if2->False(), [&] { b.ExitIf(if2, if1->Result(0)); });
+
+ b.Return(fn, if2->Result(0));
+ });
+
+ EXPECT_WGSL(R"(
+fn f(a : bool, b : bool, c : bool) -> bool {
+ let v = (b || c);
+ return (a || v);
+}
+)");
+}
+
+////////////////////////////////////////////////////////////////////////////////
// Compound assignment
////////////////////////////////////////////////////////////////////////////////
TEST_F(IRToProgramTest, CompoundAssign_Increment) {
auto* fn = b.Function("f", ty.void_());
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
- auto* v = b.Var(ty.ptr<function, i32>());
+ auto* v = b.Var("v", ty.ptr<function, i32>());
b.Store(v, b.Add(ty.i32(), b.Load(v), 1_i));
- mod.SetName(v, "v");
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn f() {
var v : i32;
v = (v + 1i);
@@ -1177,15 +1300,13 @@
TEST_F(IRToProgramTest, CompoundAssign_Decrement) {
auto* fn = b.Function("f", ty.void_());
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
- auto* v = b.Var(ty.ptr<function, i32>());
+ auto* v = b.Var("v", ty.ptr<function, i32>());
b.Store(v, b.Subtract(ty.i32(), b.Load(v), 1_i));
- mod.SetName(v, "v");
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn f() {
var v : i32;
v = (v - 1i);
@@ -1195,15 +1316,13 @@
TEST_F(IRToProgramTest, CompoundAssign_Add) {
auto* fn = b.Function("f", ty.void_());
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
- auto* v = b.Var(ty.ptr<function, i32>());
+ auto* v = b.Var("v", ty.ptr<function, i32>());
b.Store(v, b.Add(ty.i32(), b.Load(v), 8_i));
- mod.SetName(v, "v");
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn f() {
var v : i32;
v = (v + 8i);
@@ -1213,15 +1332,13 @@
TEST_F(IRToProgramTest, CompoundAssign_Subtract) {
auto* fn = b.Function("f", ty.void_());
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
- auto* v = b.Var(ty.ptr<function, i32>());
+ auto* v = b.Var("v", ty.ptr<function, i32>());
b.Store(v, b.Subtract(ty.i32(), b.Load(v), 8_i));
- mod.SetName(v, "v");
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn f() {
var v : i32;
v = (v - 8i);
@@ -1231,15 +1348,13 @@
TEST_F(IRToProgramTest, CompoundAssign_Multiply) {
auto* fn = b.Function("f", ty.void_());
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
- auto* v = b.Var(ty.ptr<function, i32>());
+ auto* v = b.Var("v", ty.ptr<function, i32>());
b.Store(v, b.Multiply(ty.i32(), b.Load(v), 8_i));
- mod.SetName(v, "v");
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn f() {
var v : i32;
v = (v * 8i);
@@ -1249,15 +1364,13 @@
TEST_F(IRToProgramTest, CompoundAssign_Divide) {
auto* fn = b.Function("f", ty.void_());
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
- auto* v = b.Var(ty.ptr<function, i32>());
+ auto* v = b.Var("v", ty.ptr<function, i32>());
b.Store(v, b.Divide(ty.i32(), b.Load(v), 8_i));
- mod.SetName(v, "v");
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn f() {
var v : i32;
v = (v / 8i);
@@ -1267,15 +1380,13 @@
TEST_F(IRToProgramTest, CompoundAssign_Xor) {
auto* fn = b.Function("f", ty.void_());
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
- auto* v = b.Var(ty.ptr<function, i32>());
+ auto* v = b.Var("v", ty.ptr<function, i32>());
b.Store(v, b.Xor(ty.i32(), b.Load(v), 8_i));
- mod.SetName(v, "v");
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn f() {
var v : i32;
v = (v ^ 8i);
@@ -1288,10 +1399,8 @@
////////////////////////////////////////////////////////////////////////////////
TEST_F(IRToProgramTest, LetUsedOnce) {
auto* fn = b.Function("f", ty.u32());
- auto* i = b.FunctionParam(ty.u32());
- mod.SetName(i, "i");
+ auto* i = b.FunctionParam("i", ty.u32());
fn->SetParams({i});
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
auto* v = b.Complement(ty.u32(), i);
@@ -1299,7 +1408,7 @@
mod.SetName(v, "v");
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn f(i : u32) -> u32 {
let v = ~(i);
return v;
@@ -1309,10 +1418,8 @@
TEST_F(IRToProgramTest, LetUsedTwice) {
auto* fn = b.Function("f", ty.i32());
- auto* i = b.FunctionParam(ty.i32());
- mod.SetName(i, "i");
+ auto* i = b.FunctionParam("i", ty.i32());
fn->SetParams({i});
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
auto* v = b.Multiply(ty.i32(), i, 2_i);
@@ -1320,7 +1427,7 @@
mod.SetName(v, "v");
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn f(i : i32) -> i32 {
let v = (i * 2i);
return (v + v);
@@ -1333,14 +1440,12 @@
////////////////////////////////////////////////////////////////////////////////
TEST_F(IRToProgramTest, FunctionScopeVar_i32) {
auto* fn = b.Function("f", ty.void_());
- mod.functions.Push(fn);
- b.With(fn->Block(), [&] {
- auto* i = b.Var(ty.ptr<function, i32>());
- mod.SetName(i, "i");
+ b.With(fn->Block(), [&] { //
+ b.Var("i", ty.ptr<function, i32>());
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn f() {
var i : i32;
}
@@ -1349,15 +1454,13 @@
TEST_F(IRToProgramTest, FunctionScopeVar_i32_InitLiteral) {
auto* fn = b.Function("f", ty.void_());
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
- auto* i = b.Var(ty.ptr<function, i32>());
+ auto* i = b.Var("i", ty.ptr<function, i32>());
i->SetInitializer(b.Constant(42_i));
- mod.SetName(i, "i");
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn f() {
var i : i32 = 42i;
}
@@ -1366,26 +1469,21 @@
TEST_F(IRToProgramTest, FunctionScopeVar_Chained) {
auto* fn = b.Function("f", ty.void_());
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
- auto* va = b.Var(ty.ptr<function, i32>());
+ auto* va = b.Var("a", ty.ptr<function, i32>());
va->SetInitializer(b.Constant(42_i));
auto* la = b.Load(va)->Result();
- auto* vb = b.Var(ty.ptr<function, i32>());
+ auto* vb = b.Var("b", ty.ptr<function, i32>());
vb->SetInitializer(la);
auto* lb = b.Load(vb)->Result();
- auto* vc = b.Var(ty.ptr<function, i32>());
+ auto* vc = b.Var("c", ty.ptr<function, i32>());
vc->SetInitializer(lb);
-
- mod.SetName(va, "a");
- mod.SetName(vb, "b");
- mod.SetName(vc, "c");
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn f() {
var a : i32 = 42i;
var b : i32 = a;
@@ -1399,13 +1497,10 @@
////////////////////////////////////////////////////////////////////////////////
TEST_F(IRToProgramTest, If_CallFn) {
auto* a = b.Function("a", ty.void_());
- mod.functions.Push(a);
auto* fn = b.Function("f", ty.void_());
- auto* cond = b.FunctionParam(ty.bool_());
- mod.SetName(cond, "cond");
+ auto* cond = b.FunctionParam("cond", ty.bool_());
fn->SetParams({cond});
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
auto* if_ = b.If(cond);
@@ -1415,7 +1510,7 @@
});
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn a() {
}
@@ -1429,17 +1524,15 @@
TEST_F(IRToProgramTest, If_Return) {
auto* fn = b.Function("f", ty.void_());
- auto* cond = b.FunctionParam(ty.bool_());
- mod.SetName(cond, "cond");
+ auto* cond = b.FunctionParam("cond", ty.bool_());
fn->SetParams({cond});
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
auto if_ = b.If(cond);
b.With(if_->True(), [&] { b.Return(fn); });
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn f(cond : bool) {
if (cond) {
return;
@@ -1450,18 +1543,16 @@
TEST_F(IRToProgramTest, If_Return_i32) {
auto* fn = b.Function("f", ty.i32());
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
- auto* cond = b.Var(ty.ptr<function, bool>());
- mod.SetName(cond, "cond");
+ auto* cond = b.Var("cond", ty.ptr<function, bool>());
cond->SetInitializer(b.Constant(true));
auto if_ = b.If(b.Load(cond));
b.With(if_->True(), [&] { b.Return(fn, 42_i); });
b.Return(fn, 10_i);
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn f() -> i32 {
var cond : bool = true;
if (cond) {
@@ -1474,16 +1565,12 @@
TEST_F(IRToProgramTest, If_CallFn_Else_CallFn) {
auto* fn_a = b.Function("a", ty.void_());
- mod.functions.Push(fn_a);
auto* fn_b = b.Function("b", ty.void_());
- mod.functions.Push(fn_b);
auto* fn = b.Function("f", ty.void_());
- auto* cond = b.FunctionParam(ty.bool_());
- mod.SetName(cond, "cond");
+ auto* cond = b.FunctionParam("cond", ty.bool_());
fn->SetParams({cond});
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
auto if_ = b.If(cond);
@@ -1497,7 +1584,7 @@
});
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn a() {
}
@@ -1516,18 +1603,16 @@
TEST_F(IRToProgramTest, If_Return_f32_Else_Return_f32) {
auto* fn = b.Function("f", ty.f32());
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
- auto* cond = b.Var(ty.ptr<function, bool>());
- mod.SetName(cond, "cond");
+ auto* cond = b.Var("cond", ty.ptr<function, bool>());
cond->SetInitializer(b.Constant(true));
auto if_ = b.If(b.Load(cond));
b.With(if_->True(), [&] { b.Return(fn, 1.0_f); });
b.With(if_->False(), [&] { b.Return(fn, 2.0_f); });
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn f() -> f32 {
var cond : bool = true;
if (cond) {
@@ -1541,17 +1626,13 @@
TEST_F(IRToProgramTest, If_Return_u32_Else_CallFn) {
auto* fn_a = b.Function("a", ty.void_());
- mod.functions.Push(fn_a);
auto* fn_b = b.Function("b", ty.void_());
- mod.functions.Push(fn_b);
auto* fn = b.Function("f", ty.u32());
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
- auto* cond = b.Var(ty.ptr<function, bool>());
- mod.SetName(cond, "cond");
+ auto* cond = b.Var("cond", ty.ptr<function, bool>());
cond->SetInitializer(b.Constant(true));
auto if_ = b.If(b.Load(cond));
b.With(if_->True(), [&] { b.Return(fn, 1_u); });
@@ -1563,7 +1644,7 @@
b.Return(fn, 2_u);
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn a() {
}
@@ -1585,20 +1666,15 @@
TEST_F(IRToProgramTest, If_CallFn_ElseIf_CallFn) {
auto* fn_a = b.Function("a", ty.void_());
- mod.functions.Push(fn_a);
auto* fn_b = b.Function("b", ty.void_());
- mod.functions.Push(fn_b);
auto* fn_c = b.Function("c", ty.void_());
- mod.functions.Push(fn_c);
auto* fn = b.Function("f", ty.void_());
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
- auto* cond = b.Var(ty.ptr<function, bool>());
- mod.SetName(cond, "cond");
+ auto* cond = b.Var("cond", ty.ptr<function, bool>());
cond->SetInitializer(b.Constant(true));
auto if1 = b.If(b.Load(cond));
b.With(if1->True(), [&] {
@@ -1616,7 +1692,7 @@
b.Call(ty.void_(), fn_c);
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn a() {
}
@@ -1640,23 +1716,16 @@
TEST_F(IRToProgramTest, If_Else_Chain) {
auto* x = b.Function("x", ty.bool_());
- auto* i = b.FunctionParam(ty.i32());
- mod.SetName(i, "i");
+ auto* i = b.FunctionParam("i", ty.i32());
x->SetParams({i});
- mod.functions.Push(x);
b.With(x->Block(), [&] { b.Return(x, true); });
auto* fn = b.Function("f", ty.void_());
- auto* pa = b.FunctionParam(ty.bool_());
- auto* pb = b.FunctionParam(ty.bool_());
- auto* pc = b.FunctionParam(ty.bool_());
- auto* pd = b.FunctionParam(ty.bool_());
- mod.SetName(pa, "a");
- mod.SetName(pb, "b");
- mod.SetName(pc, "c");
- mod.SetName(pd, "d");
+ auto* pa = b.FunctionParam("a", ty.bool_());
+ auto* pb = b.FunctionParam("b", ty.bool_());
+ auto* pc = b.FunctionParam("c", ty.bool_());
+ auto* pd = b.FunctionParam("d", ty.bool_());
fn->SetParams({pa, pb, pc, pd});
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
auto if1 = b.If(pa);
@@ -1683,7 +1752,7 @@
});
});
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn x(i : i32) -> bool {
return true;
}
@@ -1707,14 +1776,11 @@
////////////////////////////////////////////////////////////////////////////////
TEST_F(IRToProgramTest, Switch_Default) {
auto* fn_a = b.Function("a", ty.void_());
- mod.functions.Push(fn_a);
auto* fn = b.Function("f", ty.void_());
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
- auto* v = b.Var(ty.ptr<function, i32>());
- mod.SetName(v, "v");
+ auto* v = b.Var("v", ty.ptr<function, i32>());
v->SetInitializer(b.Constant(42_i));
auto s = b.Switch(b.Load(v));
@@ -1724,7 +1790,7 @@
});
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn a() {
}
@@ -1741,20 +1807,15 @@
TEST_F(IRToProgramTest, Switch_3_Cases) {
auto* fn_a = b.Function("a", ty.void_());
- mod.functions.Push(fn_a);
auto* fn_b = b.Function("b", ty.void_());
- mod.functions.Push(fn_b);
auto* fn_c = b.Function("c", ty.void_());
- mod.functions.Push(fn_c);
auto* fn = b.Function("f", ty.void_());
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
- auto* v = b.Var(ty.ptr<function, i32>());
- mod.SetName(v, "v");
+ auto* v = b.Var("v", ty.ptr<function, i32>());
v->SetInitializer(b.Constant(42_i));
auto s = b.Switch(b.Load(v));
@@ -1777,7 +1838,7 @@
});
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn a() {
}
@@ -1806,14 +1867,11 @@
TEST_F(IRToProgramTest, Switch_3_Cases_AllReturn) {
auto* fn_a = b.Function("a", ty.void_());
- mod.functions.Push(fn_a);
auto* fn = b.Function("f", ty.void_());
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
- auto* v = b.Var(ty.ptr<function, i32>());
- mod.SetName(v, "v");
+ auto* v = b.Var("v", ty.ptr<function, i32>());
v->SetInitializer(b.Constant(42_i));
auto s = b.Switch(b.Load(v));
@@ -1830,7 +1888,7 @@
b.Return(fn);
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn a() {
}
@@ -1854,24 +1912,18 @@
TEST_F(IRToProgramTest, Switch_Nested) {
auto* fn_a = b.Function("a", ty.void_());
- mod.functions.Push(fn_a);
- auto* fn_b = b.Function("b", ty.void_());
- mod.functions.Push(fn_b);
+ b.Function("b", ty.void_());
auto* fn_c = b.Function("c", ty.void_());
- mod.functions.Push(fn_c);
auto* fn = b.Function("f", ty.void_());
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
- auto* v1 = b.Var(ty.ptr<function, i32>());
- mod.SetName(v1, "v1");
+ auto* v1 = b.Var("v1", ty.ptr<function, i32>());
v1->SetInitializer(b.Constant(42_i));
- auto* v2 = b.Var(ty.ptr<function, i32>());
- mod.SetName(v2, "v2");
+ auto* v2 = b.Var("v2", ty.ptr<function, i32>());
v2->SetInitializer(b.Constant(24_i));
auto s1 = b.Switch(b.Load(v1));
@@ -1900,7 +1952,7 @@
b.ExitSwitch(s1);
});
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn a() {
}
@@ -1939,14 +1991,12 @@
////////////////////////////////////////////////////////////////////////////////
TEST_F(IRToProgramTest, For_Empty) {
auto* fn = b.Function("f", ty.void_());
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
auto* loop = b.Loop();
b.With(loop->Initializer(), [&] {
- auto* i = b.Var(ty.ptr<function, i32>());
- mod.SetName(i, "i");
+ auto* i = b.Var("i", ty.ptr<function, i32>());
i->SetInitializer(b.Constant(0_i));
b.With(loop->Body(), [&] {
@@ -1959,7 +2009,7 @@
});
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn f() {
for(var i : i32 = 0i; (i < 5i); i = (i + 1i)) {
}
@@ -1969,11 +2019,9 @@
TEST_F(IRToProgramTest, For_Empty_NoInit) {
auto* fn = b.Function("f", ty.void_());
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
- auto* i = b.Var(ty.ptr<function, i32>());
- mod.SetName(i, "i");
+ auto* i = b.Var("i", ty.ptr<function, i32>());
i->SetInitializer(b.Constant(0_i));
auto* loop = b.Loop();
@@ -1987,7 +2035,7 @@
b.With(loop->Continuing(), [&] { b.Store(i, b.Add(ty.i32(), b.Load(i), 1_i)); });
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn f() {
var i : i32 = 0i;
for(; (i < 5i); i = (i + 1i)) {
@@ -1998,14 +2046,12 @@
TEST_F(IRToProgramTest, For_Empty_NoCont) {
auto* fn = b.Function("f", ty.void_());
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
auto* loop = b.Loop();
b.With(loop->Initializer(), [&] {
- auto* i = b.Var(ty.ptr<function, i32>());
- mod.SetName(i, "i");
+ auto* i = b.Var("i", ty.ptr<function, i32>());
i->SetInitializer(b.Constant(0_i));
b.With(loop->Body(), [&] {
@@ -2016,7 +2062,7 @@
});
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn f() {
for(var i : i32 = 0i; (i < 5i); ) {
}
@@ -2026,21 +2072,17 @@
TEST_F(IRToProgramTest, For_ComplexBody) {
auto* a = b.Function("a", ty.bool_());
- auto* v = b.FunctionParam(ty.i32());
- mod.SetName(v, "v");
+ auto* v = b.FunctionParam("v", ty.i32());
a->SetParams({v});
b.With(a->Block(), [&] { b.Return(a, b.Equal(ty.bool_(), v, 1_i)); });
- mod.functions.Push(a);
auto* fn = b.Function("f", ty.i32());
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
auto* loop = b.Loop();
b.With(loop->Initializer(), [&] {
- auto* i = b.Var(ty.ptr<function, i32>());
- mod.SetName(i, "i");
+ auto* i = b.Var("i", ty.ptr<function, i32>());
i->SetInitializer(b.Constant(0_i));
b.With(loop->Body(), [&] {
@@ -2059,7 +2101,7 @@
b.Return(fn, 3_i);
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn a(v : i32) -> bool {
return (v == 1i);
}
@@ -2079,18 +2121,14 @@
TEST_F(IRToProgramTest, For_ComplexBody_NoInit) {
auto* a = b.Function("a", ty.bool_());
- auto* v = b.FunctionParam(ty.i32());
- mod.SetName(v, "v");
+ auto* v = b.FunctionParam("v", ty.i32());
a->SetParams({v});
b.With(a->Block(), [&] { b.Return(a, b.Equal(ty.bool_(), v, 1_i)); });
- mod.functions.Push(a);
auto* fn = b.Function("f", ty.i32());
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
- auto* i = b.Var(ty.ptr<function, i32>());
- mod.SetName(i, "i");
+ auto* i = b.Var("i", ty.ptr<function, i32>());
i->SetInitializer(b.Constant(0_i));
auto* loop = b.Loop();
@@ -2110,7 +2148,7 @@
b.Return(fn, 3_i);
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn a(v : i32) -> bool {
return (v == 1i);
}
@@ -2131,21 +2169,17 @@
TEST_F(IRToProgramTest, For_ComplexBody_NoCont) {
auto* a = b.Function("a", ty.bool_());
- auto* v = b.FunctionParam(ty.i32());
- mod.SetName(v, "v");
+ auto* v = b.FunctionParam("v", ty.i32());
a->SetParams({v});
b.With(a->Block(), [&] { b.Return(a, b.Equal(ty.bool_(), v, 1_i)); });
- mod.functions.Push(a);
auto* fn = b.Function("f", ty.i32());
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
auto* loop = b.Loop();
b.With(loop->Initializer(), [&] {
- auto* i = b.Var(ty.ptr<function, i32>());
- mod.SetName(i, "i");
+ auto* i = b.Var("i", ty.ptr<function, i32>());
i->SetInitializer(b.Constant(0_i));
b.With(loop->Body(), [&] {
@@ -2161,7 +2195,7 @@
b.Return(fn, 3_i);
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn a(v : i32) -> bool {
return (v == 1i);
}
@@ -2181,26 +2215,24 @@
TEST_F(IRToProgramTest, For_CallInInitCondCont) {
auto* fn_n = b.Function("n", ty.i32());
- auto* v = b.FunctionParam(ty.i32());
- mod.SetName(v, "v");
+ auto* v = b.FunctionParam("v", ty.i32());
fn_n->SetParams({v});
b.With(fn_n->Block(), [&] { b.Return(fn_n, b.Add(ty.i32(), v, 1_i)); });
- mod.functions.Push(fn_n);
auto* fn_f = b.Function("f", ty.void_());
- mod.functions.Push(fn_f);
b.With(fn_f->Block(), [&] {
auto* loop = b.Loop();
b.With(loop->Initializer(), [&] {
auto* n_0 = b.Call(ty.i32(), fn_n, 0_i)->Result();
- auto* i = b.Var(ty.ptr<function, i32>());
- mod.SetName(i, "i");
+ auto* i = b.Var("i", ty.ptr<function, i32>());
i->SetInitializer(n_0);
b.With(loop->Body(), [&] {
- auto* if_ = b.If(b.LessThan(ty.bool_(), b.Load(i), b.Call(ty.i32(), fn_n, 1_i)));
+ auto* load = b.Load(i);
+ auto* call = b.Call(ty.i32(), fn_n, 1_i);
+ auto* if_ = b.If(b.LessThan(ty.bool_(), load, call));
b.With(if_->True(), [&] { b.ExitIf(if_); });
b.With(if_->False(), [&] { b.ExitLoop(loop); });
});
@@ -2209,7 +2241,7 @@
});
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn n(v : i32) -> i32 {
return (v + 1i);
}
@@ -2226,7 +2258,6 @@
////////////////////////////////////////////////////////////////////////////////
TEST_F(IRToProgramTest, While_Empty) {
auto* fn = b.Function("f", ty.void_());
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
auto* loop = b.Loop();
@@ -2238,7 +2269,7 @@
});
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn f() {
while(true) {
}
@@ -2248,10 +2279,8 @@
TEST_F(IRToProgramTest, While_Cond) {
auto* fn = b.Function("f", ty.void_());
- auto* cond = b.FunctionParam(ty.bool_());
- mod.SetName(cond, "cond");
+ auto* cond = b.FunctionParam("cond", ty.bool_());
fn->SetParams({cond});
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
auto* loop = b.Loop();
@@ -2263,7 +2292,7 @@
});
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn f(cond : bool) {
while(cond) {
}
@@ -2273,7 +2302,6 @@
TEST_F(IRToProgramTest, While_Break) {
auto* fn = b.Function("f", ty.void_());
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
auto* loop = b.Loop();
@@ -2286,7 +2314,7 @@
});
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn f() {
while(true) {
break;
@@ -2297,10 +2325,8 @@
TEST_F(IRToProgramTest, While_IfBreak) {
auto* fn = b.Function("f", ty.void_());
- auto* cond = b.FunctionParam(ty.bool_());
- mod.SetName(cond, "cond");
+ auto* cond = b.FunctionParam("cond", ty.bool_());
fn->SetParams({cond});
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
auto* loop = b.Loop();
@@ -2315,7 +2341,7 @@
});
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn f(cond : bool) {
while(true) {
if (cond) {
@@ -2328,10 +2354,8 @@
TEST_F(IRToProgramTest, While_IfReturn) {
auto* fn = b.Function("f", ty.void_());
- auto* cond = b.FunctionParam(ty.bool_());
- mod.SetName(cond, "cond");
+ auto* cond = b.FunctionParam("cond", ty.bool_());
fn->SetParams({cond});
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
auto* loop = b.Loop();
@@ -2346,7 +2370,7 @@
});
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn f(cond : bool) {
while(true) {
if (cond) {
@@ -2362,7 +2386,6 @@
////////////////////////////////////////////////////////////////////////////////
TEST_F(IRToProgramTest, Loop_Break) {
auto* fn = b.Function("f", ty.void_());
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
auto* loop = b.Loop();
@@ -2370,7 +2393,7 @@
b.With(loop->Body(), [&] { b.ExitLoop(loop); });
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn f() {
loop {
break;
@@ -2381,10 +2404,8 @@
TEST_F(IRToProgramTest, Loop_IfBreak) {
auto* fn = b.Function("f", ty.void_());
- auto* cond = b.FunctionParam(ty.bool_());
- mod.SetName(cond, "cond");
+ auto* cond = b.FunctionParam("cond", ty.bool_());
fn->SetParams({cond});
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
auto* loop = b.Loop();
@@ -2394,7 +2415,7 @@
b.With(if_->True(), [&] { b.ExitLoop(loop); });
});
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn f(cond : bool) {
loop {
if (cond) {
@@ -2407,10 +2428,8 @@
TEST_F(IRToProgramTest, Loop_IfReturn) {
auto* fn = b.Function("f", ty.void_());
- auto* cond = b.FunctionParam(ty.bool_());
- mod.SetName(cond, "cond");
+ auto* cond = b.FunctionParam("cond", ty.bool_());
fn->SetParams({cond});
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
auto* loop = b.Loop();
@@ -2421,7 +2440,7 @@
});
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn f(cond : bool) {
loop {
if (cond) {
@@ -2434,12 +2453,10 @@
TEST_F(IRToProgramTest, Loop_IfContinuing) {
auto* fn = b.Function("f", ty.void_());
- mod.functions.Push(fn);
b.With(fn->Block(), [&] {
- auto* cond = b.Var(ty.ptr<function, bool>());
+ auto* cond = b.Var("cond", ty.ptr<function, bool>());
cond->SetInitializer(b.Constant(false));
- mod.SetName(cond, "cond");
auto* loop = b.Loop();
@@ -2451,7 +2468,7 @@
b.With(loop->Continuing(), [&] { b.Store(cond, true); });
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn f() {
var cond : bool = false;
loop {
@@ -2469,30 +2486,32 @@
TEST_F(IRToProgramTest, Loop_VarsDeclaredOutsideAndInside) {
auto* f = b.Function("f", ty.void_());
- mod.functions.Push(f);
b.With(f->Block(), [&] {
- auto* var_b = b.Var(ty.ptr<function, i32>());
+ auto* var_b = b.Var("b", ty.ptr<function, i32>());
var_b->SetInitializer(b.Constant(1_i));
- mod.SetName(var_b, "b");
auto* loop = b.Loop();
b.With(loop->Body(), [&] {
- auto* var_a = b.Var(ty.ptr<function, i32>());
+ auto* var_a = b.Var("a", ty.ptr<function, i32>());
var_a->SetInitializer(b.Constant(2_i));
- mod.SetName(var_a, "a");
- auto* if_ = b.If(b.Equal(ty.bool_(), b.Load(var_a), b.Load(var_b)));
+ auto* body_load_a = b.Load(var_a);
+ auto* body_load_b = b.Load(var_b);
+ auto* if_ = b.If(b.Equal(ty.bool_(), body_load_a, body_load_b));
b.With(if_->True(), [&] { b.Return(f); });
b.With(if_->False(), [&] { b.ExitIf(if_); });
- b.With(loop->Continuing(),
- [&] { b.Store(var_b, b.Add(ty.i32(), b.Load(var_a), b.Load(var_b))); });
+ b.With(loop->Continuing(), [&] {
+ auto* cont_load_a = b.Load(var_a);
+ auto* cont_load_b = b.Load(var_b);
+ b.Store(var_b, b.Add(ty.i32(), cont_load_a, cont_load_b));
+ });
});
});
- Test(R"(
+ EXPECT_WGSL(R"(
fn f() {
var b : i32 = 1i;
loop {
@@ -2510,4 +2529,4 @@
}
} // namespace
-} // namespace tint::ir
+} // namespace tint::ir::test
diff --git a/src/tint/ir/to_program_test.h b/src/tint/ir/to_program_test.h
new file mode 100644
index 0000000..ccf13ce
--- /dev/null
+++ b/src/tint/ir/to_program_test.h
@@ -0,0 +1,61 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_IR_TO_PROGRAM_TEST_H_
+#define SRC_TINT_IR_TO_PROGRAM_TEST_H_
+
+#include <string>
+
+#include "src/tint/ir/ir_test_helper.h"
+
+#if !TINT_BUILD_WGSL_WRITER
+#error "to_program_test.h requires the WGSL writer to be enabled"
+#endif
+
+namespace tint::ir::test {
+
+/// Class used for IR to Program tests
+class IRToProgramTest : public IRTestHelper {
+ public:
+ /// The result of Run()
+ struct Result {
+ /// The resulting WGSL
+ std::string wgsl;
+ /// The resulting AST
+ std::string ast;
+ /// The resulting IR
+ std::string ir;
+ /// The resulting error
+ std::string err;
+ };
+ /// @returns the WGSL generated from the IR
+ Result Run();
+};
+
+#define EXPECT_WGSL(expected_wgsl) \
+ do { \
+ if (auto got = Run(); got.err.empty()) { \
+ auto expected = std::string(utils::TrimSpace(expected_wgsl)); \
+ if (!expected.empty()) { \
+ expected = "\n" + expected + "\n"; \
+ } \
+ EXPECT_EQ(expected, got.wgsl) << "IR: " << got.ir; \
+ } else { \
+ FAIL() << got.err << std::endl << "IR: " << got.ir << "AST:" << std::endl; \
+ } \
+ } while (false)
+
+} // namespace tint::ir::test
+
+#endif // SRC_TINT_IR_TO_PROGRAM_TEST_H_
diff --git a/src/tint/ir/transform/add_empty_entry_point.cc b/src/tint/ir/transform/add_empty_entry_point.cc
index 5781e93..985c2ff 100644
--- a/src/tint/ir/transform/add_empty_entry_point.cc
+++ b/src/tint/ir/transform/add_empty_entry_point.cc
@@ -38,7 +38,6 @@
auto* ep = builder.Function("unused_entry_point", ir->Types().void_(),
Function::PipelineStage::kCompute, std::array{1u, 1u, 1u});
ep->Block()->Append(builder.Return(ep));
- ir->functions.Push(ep);
}
} // namespace tint::ir::transform
diff --git a/src/tint/ir/transform/add_empty_entry_point_test.cc b/src/tint/ir/transform/add_empty_entry_point_test.cc
index 567feae..d74ed0e 100644
--- a/src/tint/ir/transform/add_empty_entry_point_test.cc
+++ b/src/tint/ir/transform/add_empty_entry_point_test.cc
@@ -40,7 +40,6 @@
TEST_F(IR_AddEmptyEntryPointTest, ExistingEntryPoint) {
auto* ep = b.Function("main", mod.Types().void_(), Function::PipelineStage::kFragment);
ep->Block()->Append(b.Return(ep));
- mod.functions.Push(ep);
auto* expect = R"(
%main = @fragment func():void -> %b1 {
diff --git a/src/tint/ir/transform/block_decorated_structs_test.cc b/src/tint/ir/transform/block_decorated_structs_test.cc
index 2d2a856..2d93a2d 100644
--- a/src/tint/ir/transform/block_decorated_structs_test.cc
+++ b/src/tint/ir/transform/block_decorated_structs_test.cc
@@ -32,7 +32,6 @@
TEST_F(IR_BlockDecoratedStructsTest, NoRootBlock) {
auto* func = b.Function("foo", ty.void_());
func->Block()->Append(b.Return(func));
- mod.functions.Push(func);
auto* expect = R"(
%foo = func():void -> %b1 {
@@ -57,7 +56,6 @@
auto* block = func->Block();
auto* load = block->Append(b.Load(buffer));
block->Append(b.Return(func, load));
- mod.functions.Push(func);
auto* expect = R"(
tint_symbol_1 = struct @align(4), @block {
@@ -90,7 +88,6 @@
auto* func = b.Function("foo", ty.void_());
func->Block()->Append(b.Store(buffer, 42_i));
func->Block()->Append(b.Return(func));
- mod.functions.Push(func);
auto* expect = R"(
tint_symbol_1 = struct @align(4), @block {
@@ -128,8 +125,6 @@
b.Return(func);
});
- mod.functions.Push(func);
-
auto* expect = R"(
tint_symbol_1 = struct @align(4), @block {
tint_symbol:array<i32> @offset(0)
@@ -177,8 +172,6 @@
b.Return(func);
});
- mod.functions.Push(func);
-
auto* expect = R"(
MyStruct = struct @align(4) {
i:i32 @offset(0)
@@ -226,7 +219,6 @@
auto* func = b.Function("foo", ty.void_());
func->Block()->Append(b.Store(buffer, private_var));
func->Block()->Append(b.Return(func));
- mod.functions.Push(func);
auto* expect = R"(
MyStruct = struct @align(4) {
@@ -275,7 +267,6 @@
auto* load_c = block->Append(b.Load(buffer_c));
block->Append(b.Store(buffer_a, b.Add(ty.i32(), load_b, load_c)));
block->Append(b.Return(func));
- mod.functions.Push(func);
auto* expect = R"(
tint_symbol_1 = struct @align(4), @block {
diff --git a/src/tint/ir/transform/merge_return.cc b/src/tint/ir/transform/merge_return.cc
index d2d4259..af42ab9 100644
--- a/src/tint/ir/transform/merge_return.cc
+++ b/src/tint/ir/transform/merge_return.cc
@@ -78,16 +78,14 @@
}
// Create a boolean variable that can be used to check whether the function is returning.
- continue_execution = b.Var(ty.ptr<function, bool>());
+ continue_execution = b.Var("continue_execution", ty.ptr<function, bool>());
continue_execution->SetInitializer(b.Constant(true));
fn->Block()->Prepend(continue_execution);
- ir->SetName(continue_execution, "continue_execution");
// Create a variable to hold the return value if needed.
if (!fn->ReturnType()->Is<type::Void>()) {
- return_val = b.Var(ty.ptr(function, fn->ReturnType()));
+ return_val = b.Var("return_value", ty.ptr(function, fn->ReturnType()));
fn->Block()->Prepend(return_val);
- ir->SetName(return_val, "return_value");
}
// Look to see if the function ends with a return
diff --git a/src/tint/ir/transform/merge_return_test.cc b/src/tint/ir/transform/merge_return_test.cc
index c54f809..0991bb5 100644
--- a/src/tint/ir/transform/merge_return_test.cc
+++ b/src/tint/ir/transform/merge_return_test.cc
@@ -30,7 +30,6 @@
auto* in = b.FunctionParam(ty.i32());
auto* func = b.Function("foo", ty.i32());
func->SetParams({in});
- mod.functions.Push(func);
b.With(func->Block(), [&] { b.Return(func, b.Add(ty.i32(), in, 1_i)); });
@@ -56,7 +55,6 @@
auto* cond = b.FunctionParam(ty.bool_());
auto* func = b.Function("foo", ty.i32());
func->SetParams({in});
- mod.functions.Push(func);
b.With(func->Block(), [&] {
auto* ifelse = b.If(cond);
@@ -97,7 +95,6 @@
auto* cond = b.FunctionParam(ty.bool_());
auto* func = b.Function("foo", ty.i32());
func->SetParams({in});
- mod.functions.Push(func);
b.With(func->Block(), [&] {
auto* swtch = b.Switch(in);
@@ -150,7 +147,6 @@
auto* cond = b.FunctionParam(ty.bool_());
auto* func = b.Function("foo", ty.void_());
func->SetParams({cond});
- mod.functions.Push(func);
b.With(func->Block(), [&] {
auto* ifelse = b.If(cond);
@@ -204,7 +200,6 @@
auto* cond = b.FunctionParam(ty.bool_());
auto* func = b.Function("foo", ty.void_());
func->SetParams({cond});
- mod.functions.Push(func);
b.With(func->Block(), [&] {
auto* ifelse = b.If(cond);
@@ -256,7 +251,6 @@
auto* cond = b.FunctionParam(ty.bool_());
auto* func = b.Function("foo", ty.i32());
func->SetParams({cond});
- mod.functions.Push(func);
b.With(func->Block(), [&] {
auto* ifelse = b.If(cond);
@@ -320,7 +314,6 @@
auto* cond = b.FunctionParam(ty.bool_());
auto* func = b.Function("foo", ty.i32());
func->SetParams({cond});
- mod.functions.Push(func);
b.With(func->Block(), [&] {
auto* ifelse = b.If(cond);
@@ -385,7 +378,6 @@
auto* cond = b.FunctionParam(ty.bool_());
auto* func = b.Function("foo", ty.i32());
func->SetParams({cond});
- mod.functions.Push(func);
b.With(func->Block(), [&] {
auto* ifelse = b.If(cond);
@@ -450,7 +442,6 @@
auto* cond = b.FunctionParam(ty.bool_());
auto* func = b.Function("foo", ty.void_());
func->SetParams({cond});
- mod.functions.Push(func);
b.With(func->Block(), [&] {
auto* ifelse = b.If(cond);
@@ -506,7 +497,6 @@
auto* cond = b.FunctionParam(ty.bool_());
auto* func = b.Function("foo", ty.void_());
func->SetParams({cond});
- mod.functions.Push(func);
b.With(func->Block(), [&] {
auto* ifelse = b.If(cond);
@@ -583,7 +573,6 @@
auto* cond = b.FunctionParam(ty.bool_());
auto* func = b.Function("foo", ty.void_());
func->SetParams({cond});
- mod.functions.Push(func);
b.With(func->Block(), [&] {
auto* ifelse = b.If(cond);
@@ -656,14 +645,10 @@
b.RootBlock()->Append(global);
auto* func = b.Function("foo", ty.i32());
- auto* condA = b.FunctionParam(ty.bool_());
- auto* condB = b.FunctionParam(ty.bool_());
- auto* condC = b.FunctionParam(ty.bool_());
- mod.SetName(condA, "condA");
- mod.SetName(condB, "condB");
- mod.SetName(condC, "condC");
+ auto* condA = b.FunctionParam("condA", ty.bool_());
+ auto* condB = b.FunctionParam("condB", ty.bool_());
+ auto* condC = b.FunctionParam("condC", ty.bool_());
func->SetParams({condA, condB, condC});
- mod.functions.Push(func);
b.With(func->Block(), [&] {
auto* ifelse_outer = b.If(condA);
@@ -806,14 +791,10 @@
b.RootBlock()->Append(global);
auto* func = b.Function("foo", ty.i32());
- auto* condA = b.FunctionParam(ty.bool_());
- auto* condB = b.FunctionParam(ty.bool_());
- auto* condC = b.FunctionParam(ty.bool_());
- mod.SetName(condA, "condA");
- mod.SetName(condB, "condB");
- mod.SetName(condC, "condC");
+ auto* condA = b.FunctionParam("condA", ty.bool_());
+ auto* condB = b.FunctionParam("condB", ty.bool_());
+ auto* condC = b.FunctionParam("condC", ty.bool_());
func->SetParams({condA, condB, condC});
- mod.functions.Push(func);
b.With(func->Block(), [&] {
auto* ifelse_outer = b.If(condA);
@@ -931,14 +912,10 @@
b.RootBlock()->Append(global);
auto* func = b.Function("foo", ty.i32());
- auto* condA = b.FunctionParam(ty.bool_());
- auto* condB = b.FunctionParam(ty.bool_());
- auto* condC = b.FunctionParam(ty.bool_());
- mod.SetName(condA, "condA");
- mod.SetName(condB, "condB");
- mod.SetName(condC, "condC");
+ auto* condA = b.FunctionParam("condA", ty.bool_());
+ auto* condB = b.FunctionParam("condB", ty.bool_());
+ auto* condC = b.FunctionParam("condC", ty.bool_());
func->SetParams({condA, condB, condC});
- mod.functions.Push(func);
b.With(func->Block(), [&] {
auto* ifelse_outer = b.If(condA);
@@ -1080,7 +1057,6 @@
TEST_F(IR_MergeReturnTest, Loop_UnconditionalReturnInBody) {
auto* func = b.Function("foo", ty.i32());
- mod.functions.Push(func);
b.With(func->Block(), [&] {
auto* loop = b.Loop();
@@ -1130,7 +1106,6 @@
auto* cond = b.FunctionParam(ty.bool_());
auto* func = b.Function("foo", ty.i32());
func->SetParams({cond});
- mod.functions.Push(func);
b.With(func->Block(), [&] {
auto* loop = b.Loop();
@@ -1145,7 +1120,7 @@
b.With(loop->Continuing(), [&] {
b.Store(global, 1_i);
- b.BreakIf(true, loop);
+ b.BreakIf(loop, true);
});
b.Store(global, 3_i);
@@ -1245,7 +1220,6 @@
auto* cond = b.FunctionParam(ty.bool_());
auto* func = b.Function("foo", ty.i32());
func->SetParams({cond});
- mod.functions.Push(func);
b.With(func->Block(), [&] {
auto* loop = b.Loop();
@@ -1351,7 +1325,6 @@
auto* cond = b.FunctionParam(ty.bool_());
auto* func = b.Function("foo", ty.i32());
func->SetParams({cond});
- mod.functions.Push(func);
b.With(func->Block(), [&] {
auto* loop = b.Loop();
@@ -1367,7 +1340,7 @@
b.With(loop->Continuing(), [&] {
b.Store(global, 1_i);
- b.BreakIf(true, loop, 4_i);
+ b.BreakIf(loop, true, 4_i);
});
b.Store(global, 3_i);
@@ -1463,7 +1436,6 @@
auto* cond = b.FunctionParam(ty.i32());
auto* func = b.Function("foo", ty.i32());
func->SetParams({cond});
- mod.functions.Push(func);
b.With(func->Block(), [&] {
auto* sw = b.Switch(cond);
@@ -1530,7 +1502,6 @@
auto* cond = b.FunctionParam(ty.i32());
auto* func = b.Function("foo", ty.i32());
func->SetParams({cond});
- mod.functions.Push(func);
b.With(func->Block(), [&] {
auto* sw = b.Switch(cond);
@@ -1634,7 +1605,6 @@
auto* cond = b.FunctionParam(ty.i32());
auto* func = b.Function("foo", ty.i32());
func->SetParams({cond});
- mod.functions.Push(func);
b.With(func->Block(), [&] {
auto* sw = b.Switch(cond);
@@ -1713,7 +1683,6 @@
TEST_F(IR_MergeReturnTest, LoopIfReturnThenContinue) {
auto* func = b.Function("foo", ty.void_());
- mod.functions.Push(func);
b.With(func->Block(), [&] {
auto* loop = b.Loop();
@@ -1776,7 +1745,6 @@
TEST_F(IR_MergeReturnTest, NestedIfsWithReturns) {
auto* func = b.Function("foo", ty.i32());
- mod.functions.Push(func);
b.With(func->Block(), [&] {
b.With(b.If(true)->True(), [&] {
diff --git a/src/tint/ir/transform/shader_io.cc b/src/tint/ir/transform/shader_io.cc
new file mode 100644
index 0000000..6456710
--- /dev/null
+++ b/src/tint/ir/transform/shader_io.cc
@@ -0,0 +1,279 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/transform/shader_io.h"
+
+#include <memory>
+#include <utility>
+
+#include "src/tint/ir/builder.h"
+#include "src/tint/ir/module.h"
+#include "src/tint/type/struct.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ir::transform::ShaderIO);
+TINT_INSTANTIATE_TYPEINFO(tint::ir::transform::ShaderIO::Config);
+
+using namespace tint::builtin::fluent_types; // NOLINT
+using namespace tint::number_suffixes; // NOLINT
+
+namespace tint::ir::transform {
+
+namespace {
+
+builtin::BuiltinValue FunctionParamBuiltin(enum FunctionParam::Builtin builtin) {
+ switch (builtin) {
+ case FunctionParam::Builtin::kVertexIndex:
+ return builtin::BuiltinValue::kVertexIndex;
+ case FunctionParam::Builtin::kInstanceIndex:
+ return builtin::BuiltinValue::kInstanceIndex;
+ case FunctionParam::Builtin::kPosition:
+ return builtin::BuiltinValue::kPosition;
+ case FunctionParam::Builtin::kFrontFacing:
+ return builtin::BuiltinValue::kFrontFacing;
+ case FunctionParam::Builtin::kLocalInvocationId:
+ return builtin::BuiltinValue::kLocalInvocationId;
+ case FunctionParam::Builtin::kLocalInvocationIndex:
+ return builtin::BuiltinValue::kLocalInvocationIndex;
+ case FunctionParam::Builtin::kGlobalInvocationId:
+ return builtin::BuiltinValue::kGlobalInvocationId;
+ case FunctionParam::Builtin::kWorkgroupId:
+ return builtin::BuiltinValue::kWorkgroupId;
+ case FunctionParam::Builtin::kNumWorkgroups:
+ return builtin::BuiltinValue::kNumWorkgroups;
+ case FunctionParam::Builtin::kSampleIndex:
+ return builtin::BuiltinValue::kSampleIndex;
+ case FunctionParam::Builtin::kSampleMask:
+ return builtin::BuiltinValue::kSampleMask;
+ }
+ return builtin::BuiltinValue::kUndefined;
+}
+
+builtin::BuiltinValue ReturnBuiltin(enum Function::ReturnBuiltin builtin) {
+ switch (builtin) {
+ case Function::ReturnBuiltin::kPosition:
+ return builtin::BuiltinValue::kPosition;
+ case Function::ReturnBuiltin::kFragDepth:
+ return builtin::BuiltinValue::kFragDepth;
+ case Function::ReturnBuiltin::kSampleMask:
+ return builtin::BuiltinValue::kSampleMask;
+ }
+ return builtin::BuiltinValue::kUndefined;
+}
+
+} // namespace
+
+ShaderIO::ShaderIO() = default;
+
+ShaderIO::~ShaderIO() = default;
+
+/// PIMPL state for the transform, for a single entry point function.
+struct ShaderIO::State {
+ /// The configuration data.
+ const ShaderIO::Config& config;
+ /// The IR module.
+ Module* ir = nullptr;
+ /// The IR builder.
+ Builder b{*ir};
+ /// The type manager.
+ type::Manager& ty{ir->Types()};
+ /// The set of struct members that need to have their IO attributes stripped.
+ utils::Hashset<const type::StructMember*, 8> members_to_strip;
+
+ /// The entry point currently being processed.
+ Function* func = nullptr;
+
+ /// The backend state object for the current entry point.
+ std::unique_ptr<ShaderIO::BackendState> backend;
+
+ /// Constructor
+ /// @param cfg the transform config
+ /// @param mod the module
+ State(const ShaderIO::Config& cfg, Module* mod) : config(cfg), ir(mod) {}
+
+ /// Process an entry point.
+ /// @param f the original entry point function
+ /// @param bs the backend state object
+ void Process(Function* f, std::unique_ptr<ShaderIO::BackendState> bs) {
+ TINT_SCOPED_ASSIGNMENT(func, f);
+ backend = std::move(bs);
+ TINT_DEFER(backend = nullptr);
+
+ // Process the parameters and return value to prepare for building a wrapper function.
+ GatherInputs();
+ GatherOutput();
+ auto new_params = backend->FinalizeInputs();
+ auto* new_ret_val = backend->FinalizeOutputs();
+
+ // Rename the old function and remove its pipeline stage and workgroup size, as we will be
+ // wrapping it with a new entry point.
+ auto name = ir->NameOf(func).Name();
+ auto stage = func->Stage();
+ auto wgsize = func->WorkgroupSize();
+ ir->SetName(func, name + "_inner");
+ func->SetStage(Function::PipelineStage::kUndefined);
+ func->ClearWorkgroupSize();
+
+ // Create the entry point wrapper function.
+ auto* ep = b.Function(name, new_ret_val ? new_ret_val->Type() : ty.void_());
+ ep->SetStage(stage);
+ if (wgsize) {
+ ep->SetWorkgroupSize((*wgsize)[0], (*wgsize)[1], (*wgsize)[2]);
+ }
+ auto wrapper = b.With(ep->Block());
+
+ // Call the original function, passing it the inputs and capturing its return value.
+ auto inner_call_args = BuildInnerCallArgs(wrapper);
+ auto* inner_result = wrapper.Call(func->ReturnType(), func, std::move(inner_call_args));
+ SetOutputs(wrapper, inner_result->Result());
+
+ // Return the new result.
+ wrapper.Return(ep, new_ret_val);
+ }
+
+ /// Gather the shader inputs.
+ void GatherInputs() {
+ for (auto* param : func->Params()) {
+ if (auto* str = param->Type()->As<type::Struct>()) {
+ for (auto* member : str->Members()) {
+ auto name = str->Name().Name() + "_" + member->Name().Name();
+ backend->AddInput(ir->symbols.Register(name), member->Type(),
+ member->Attributes());
+ members_to_strip.Add(member);
+ }
+ } else {
+ // Pull out the IO attributes and remove them from the parameter.
+ type::StructMemberAttributes attributes;
+ if (auto loc = param->Location()) {
+ attributes.location = loc->value;
+ if (loc->interpolation) {
+ attributes.interpolation = *loc->interpolation;
+ }
+ param->ClearLocation();
+ } else if (auto builtin = param->Builtin()) {
+ attributes.builtin = FunctionParamBuiltin(*builtin);
+ param->ClearBuiltin();
+ }
+ attributes.invariant = param->Invariant();
+ param->SetInvariant(false);
+
+ auto name = ir->NameOf(param);
+ backend->AddInput(name, param->Type(), std::move(attributes));
+ }
+ }
+ }
+
+ /// Gather the shader outputs.
+ void GatherOutput() {
+ if (func->ReturnType()->Is<type::Void>()) {
+ return;
+ }
+
+ if (auto* str = func->ReturnType()->As<type::Struct>()) {
+ for (auto* member : str->Members()) {
+ auto name = str->Name().Name() + "_" + member->Name().Name();
+ backend->AddOutput(ir->symbols.Register(name), member->Type(),
+ member->Attributes());
+ members_to_strip.Add(member);
+ }
+ } else {
+ // Pull out the IO attributes and remove them from the original function.
+ type::StructMemberAttributes attributes;
+ if (auto loc = func->ReturnLocation()) {
+ attributes.location = loc->value;
+ func->ClearReturnLocation();
+ } else if (auto builtin = func->ReturnBuiltin()) {
+ attributes.builtin = ReturnBuiltin(*builtin);
+ func->ClearReturnBuiltin();
+ }
+ attributes.invariant = func->ReturnInvariant();
+ func->SetReturnInvariant(false);
+
+ backend->AddOutput(ir->symbols.New(), func->ReturnType(), std::move(attributes));
+ }
+ }
+
+ /// Build the argument list to call the original entry point function.
+ /// @param builder the IR builder for new instructions
+ /// @returns the argument list
+ utils::Vector<Value*, 4> BuildInnerCallArgs(Builder& builder) {
+ uint32_t input_idx = 0;
+ utils::Vector<Value*, 4> args;
+ for (auto* param : func->Params()) {
+ if (auto* str = param->Type()->As<type::Struct>()) {
+ utils::Vector<Value*, 4> construct_args;
+ for (uint32_t i = 0; i < str->Members().Length(); i++) {
+ construct_args.Push(backend->GetInput(builder, input_idx++));
+ }
+ args.Push(builder.Construct(param->Type(), construct_args)->Result());
+ } else {
+ args.Push(backend->GetInput(builder, input_idx++));
+ }
+ }
+
+ return args;
+ }
+
+ /// Propagate outputs from the inner function call to their final destination.
+ /// @param builder the IR builder for new instructions
+ /// @param inner_result the return value from calling the original entry point function
+ void SetOutputs(Builder& builder, Value* inner_result) {
+ if (auto* str = inner_result->Type()->As<type::Struct>()) {
+ for (auto* member : str->Members()) {
+ Value* from =
+ builder.Access(member->Type(), inner_result, u32(member->Index()))->Result();
+ backend->SetOutput(builder, member->Index(), from);
+ }
+ } else if (!inner_result->Type()->Is<type::Void>()) {
+ backend->SetOutput(builder, 0u, inner_result);
+ }
+ }
+
+ /// Finalize any state needed to complete the transform.
+ void Finalize() {
+ // Remove IO attributes from all structure members that had them prior to this transform.
+ for (auto* member : members_to_strip) {
+ // TODO(crbug.com/tint/745): Remove the const_cast.
+ const_cast<type::StructMember*>(member)->SetAttributes({});
+ }
+ }
+};
+
+void ShaderIO::Run(Module* ir, const DataMap& inputs, DataMap&) const {
+ auto* cfg = inputs.Get<Config>();
+ TINT_ASSERT(Transform, cfg);
+
+ ShaderIO::State state(*cfg, ir);
+ for (auto* func : ir->functions) {
+ // Only process entry points.
+ if (func->Stage() == Function::PipelineStage::kUndefined) {
+ continue;
+ }
+
+ // Skip entry points with no inputs or outputs.
+ if (func->Params().IsEmpty() && func->ReturnType()->Is<type::Void>()) {
+ continue;
+ }
+
+ state.Process(func, MakeBackendState(ir, func));
+ }
+ state.Finalize();
+}
+
+ShaderIO::Config::Config() = default;
+
+ShaderIO::Config::~Config() = default;
+
+ShaderIO::BackendState::~BackendState() = default;
+
+} // namespace tint::ir::transform
diff --git a/src/tint/ir/transform/shader_io.h b/src/tint/ir/transform/shader_io.h
new file mode 100644
index 0000000..1f80811
--- /dev/null
+++ b/src/tint/ir/transform/shader_io.h
@@ -0,0 +1,134 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_IR_TRANSFORM_SHADER_IO_H_
+#define SRC_TINT_IR_TRANSFORM_SHADER_IO_H_
+
+#include <memory>
+#include <utility>
+
+#include "src/tint/ir/builder.h"
+#include "src/tint/ir/transform/transform.h"
+#include "src/tint/type/manager.h"
+
+namespace tint::ir::transform {
+
+/// ShaderIO is a transform that modifies an entry point function's parameters and return value to
+/// prepare them for backend codegen.
+class ShaderIO : public utils::Castable<ShaderIO, Transform> {
+ public:
+ /// Configuration options for the transform.
+ struct Config final : public utils::Castable<Config, Data> {
+ /// Constructor
+ Config();
+
+ /// Copy constructor
+ Config(const Config&) = default;
+
+ /// Destructor
+ ~Config() override;
+ };
+
+ /// Constructor
+ ShaderIO();
+ /// Destructor
+ ~ShaderIO() override;
+
+ /// @copydoc Transform::Run
+ void Run(ir::Module* module, const DataMap& inputs, DataMap& outputs) const override;
+
+ /// Abstract base class for the state needed to handle IO for a particular backend target.
+ struct BackendState {
+ /// Constructor
+ /// @param mod the IR module
+ /// @param f the entry point function
+ BackendState(Module* mod, Function* f) : ir(mod), func(f) {}
+
+ /// Destructor
+ virtual ~BackendState();
+
+ /// Add an input.
+ /// @param name the name of the input
+ /// @param type the type of the input
+ /// @param attributes the IO attributes
+ virtual void AddInput(Symbol name,
+ const type::Type* type,
+ type::StructMemberAttributes attributes) {
+ inputs.Push({name, type, std::move(attributes)});
+ }
+
+ /// Add an output.
+ /// @param name the name of the output
+ /// @param type the type of the output
+ /// @param attributes the IO attributes
+ virtual void AddOutput(Symbol name,
+ const type::Type* type,
+ type::StructMemberAttributes attributes) {
+ outputs.Push({name, type, std::move(attributes)});
+ }
+
+ /// Finalize the shader inputs and create any state needed for the new entry point function.
+ /// @returns the list of function parameters for the new entry point
+ virtual utils::Vector<FunctionParam*, 4> FinalizeInputs() = 0;
+
+ /// Finalize the shader outputs and create state needed for the new entry point function.
+ /// @returns the return value for the new entry point
+ virtual Value* FinalizeOutputs() = 0;
+
+ /// Get the value of the input at index @p idx
+ /// @param builder the IR builder for new instructions
+ /// @param idx the index of the input
+ /// @returns the value of the input
+ virtual Value* GetInput(Builder& builder, uint32_t idx) = 0;
+
+ /// Set the value of the output at index @p idx
+ /// @param builder the IR builder for new instructions
+ /// @param idx the index of the output
+ /// @param value the value to set
+ virtual void SetOutput(Builder& builder, uint32_t idx, Value* value) = 0;
+
+ protected:
+ /// The IR module.
+ Module* ir = nullptr;
+
+ /// The IR builder.
+ Builder b{*ir};
+
+ /// The type manager.
+ type::Manager& ty{ir->Types()};
+
+ /// The original entry point function.
+ Function* func = nullptr;
+
+ /// The list of shader inputs.
+ utils::Vector<type::Manager::StructMemberDesc, 4> inputs;
+
+ /// The list of shader outputs.
+ utils::Vector<type::Manager::StructMemberDesc, 4> outputs;
+ };
+
+ protected:
+ struct State;
+
+ /// Create a backend state object.
+ /// @param mod the IR module
+ /// @param func the entry point function
+ /// @returns the backend state object
+ virtual std::unique_ptr<ShaderIO::BackendState> MakeBackendState(Module* mod,
+ Function* func) const = 0;
+};
+
+} // namespace tint::ir::transform
+
+#endif // SRC_TINT_IR_TRANSFORM_SHADER_IO_H_
diff --git a/src/tint/ir/transform/shader_io_spirv.cc b/src/tint/ir/transform/shader_io_spirv.cc
new file mode 100644
index 0000000..4a2400e
--- /dev/null
+++ b/src/tint/ir/transform/shader_io_spirv.cc
@@ -0,0 +1,172 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/transform/shader_io_spirv.h"
+
+#include <memory>
+#include <utility>
+
+#include "src/tint/ir/builder.h"
+#include "src/tint/ir/module.h"
+#include "src/tint/type/array.h"
+#include "src/tint/type/struct.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ir::transform::ShaderIOSpirv);
+
+using namespace tint::builtin::fluent_types; // NOLINT
+using namespace tint::number_suffixes; // NOLINT
+
+namespace tint::ir::transform {
+
+ShaderIOSpirv::ShaderIOSpirv() = default;
+
+ShaderIOSpirv::~ShaderIOSpirv() = default;
+
+namespace {
+
+/// PIMPL state for the parts of the shader IO transform specific to SPIR-V.
+/// For SPIR-V, we split builtins and locations into two separate structures each for input and
+/// output, and declare global variables for them. The wrapper entry point then loads from and
+/// stores to these variables.
+/// We also modify the type of the SampleMask builtin to be an array, as required by Vulkan.
+struct StateImpl : ShaderIO::BackendState {
+ /// The global variable for input builtins.
+ Var* builtin_input_var = nullptr;
+ /// The global variable for input locations.
+ Var* location_input_var = nullptr;
+ /// The global variable for output builtins.
+ Var* builtin_output_var = nullptr;
+ /// The global variable for output locations.
+ Var* location_output_var = nullptr;
+ /// The member indices for inputs.
+ utils::Vector<uint32_t, 4> input_indices;
+ /// The member indices for outputs.
+ utils::Vector<uint32_t, 4> output_indices;
+
+ /// Constructor
+ /// @copydoc ShaderIO::BackendState::BackendState
+ using ShaderIO::BackendState::BackendState;
+ /// Destructor
+ ~StateImpl() override {}
+
+ /// Split the members listed in @p entries into two separate structures for builtins and
+ /// locations, and make global variables for them. Record the new member indices in @p indices.
+ /// @param builtin_var the generated global variable for builtins
+ /// @param location_var the generated global variable for locations
+ /// @param indices the new member indices
+ /// @param entries the entries to split
+ /// @param addrspace the address to use for the global variables
+ /// @param access the access mode to use for the global variables
+ /// @param name_suffix the suffix to add to struct and variable names
+ void MakeStructs(Var*& builtin_var,
+ Var*& location_var,
+ utils::Vector<uint32_t, 4>* indices,
+ utils::Vector<type::Manager::StructMemberDesc, 4>& entries,
+ builtin::AddressSpace addrspace,
+ builtin::Access access,
+ const char* name_suffix) {
+ // Build separate lists of builtin and location entries and record their new indices.
+ uint32_t next_builtin_idx = 0;
+ uint32_t next_location_idx = 0;
+ utils::Vector<type::Manager::StructMemberDesc, 4> builtin_members;
+ utils::Vector<type::Manager::StructMemberDesc, 4> location_members;
+ for (auto io : entries) {
+ if (io.attributes.builtin) {
+ // SampleMask must be an array for Vulkan.
+ if (io.attributes.builtin.value() == builtin::BuiltinValue::kSampleMask) {
+ io.type = ty.array<u32, 1>();
+ }
+ builtin_members.Push(io);
+ indices->Push(next_builtin_idx++);
+ } else {
+ location_members.Push(io);
+ indices->Push(next_location_idx++);
+ }
+ }
+
+ // Declare the structs and variables if needed.
+ auto make_struct = [&](auto& members, const char* iotype) {
+ auto name = ir->NameOf(func).Name() + iotype + name_suffix;
+ auto* str = ty.Struct(ir->symbols.New(name + "Struct"), std::move(members));
+ auto* var = b.Var(name, ty.ptr(addrspace, str, access));
+ str->SetStructFlag(type::kBlock);
+ b.RootBlock()->Append(var);
+ return var;
+ };
+ if (!builtin_members.IsEmpty()) {
+ builtin_var = make_struct(builtin_members, "_Builtin");
+ }
+ if (!location_members.IsEmpty()) {
+ location_var = make_struct(location_members, "_Location");
+ }
+ }
+
+ /// @copydoc ShaderIO::BackendState::FinalizeInputs
+ utils::Vector<FunctionParam*, 4> FinalizeInputs() override {
+ MakeStructs(builtin_input_var, location_input_var, &input_indices, inputs,
+ builtin::AddressSpace::kIn, builtin::Access::kRead, "Inputs");
+ return utils::Empty;
+ }
+
+ /// @copydoc ShaderIO::BackendState::FinalizeOutputs
+ Value* FinalizeOutputs() override {
+ MakeStructs(builtin_output_var, location_output_var, &output_indices, outputs,
+ builtin::AddressSpace::kOut, builtin::Access::kWrite, "Outputs");
+ return nullptr;
+ }
+
+ /// @copydoc ShaderIO::BackendState::GetInput
+ Value* GetInput(Builder& builder, uint32_t idx) override {
+ // Load the input from the global variable declared earlier.
+ auto* ptr = ty.ptr(builtin::AddressSpace::kIn, inputs[idx].type, builtin::Access::kRead);
+ Access* from = nullptr;
+ if (inputs[idx].attributes.builtin) {
+ if (inputs[idx].attributes.builtin.value() == builtin::BuiltinValue::kSampleMask) {
+ // SampleMask becomes an array for SPIR-V, so load from the first element.
+ from = builder.Access(ptr, builtin_input_var, u32(input_indices[idx]), 0_u);
+ } else {
+ from = builder.Access(ptr, builtin_input_var, u32(input_indices[idx]));
+ }
+ } else {
+ from = builder.Access(ptr, location_input_var, u32(input_indices[idx]));
+ }
+ return builder.Load(from)->Result();
+ }
+
+ /// @copydoc ShaderIO::BackendState::SetOutput
+ void SetOutput(Builder& builder, uint32_t idx, Value* value) override {
+ // Store the output to the global variable declared earlier.
+ auto* ptr = ty.ptr(builtin::AddressSpace::kOut, outputs[idx].type, builtin::Access::kWrite);
+ Access* to = nullptr;
+ if (outputs[idx].attributes.builtin) {
+ if (outputs[idx].attributes.builtin.value() == builtin::BuiltinValue::kSampleMask) {
+ // SampleMask becomes an array for SPIR-V, so store to the first element.
+ to = builder.Access(ptr, builtin_output_var, u32(output_indices[idx]), 0_u);
+ } else {
+ to = builder.Access(ptr, builtin_output_var, u32(output_indices[idx]));
+ }
+ } else {
+ to = builder.Access(ptr, location_output_var, u32(output_indices[idx]));
+ }
+ builder.Store(to, value);
+ }
+};
+} // namespace
+
+std::unique_ptr<ShaderIO::BackendState> ShaderIOSpirv::MakeBackendState(Module* mod,
+ Function* func) const {
+ return std::make_unique<StateImpl>(mod, func);
+}
+
+} // namespace tint::ir::transform
diff --git a/src/tint/ir/transform/shader_io_spirv.h b/src/tint/ir/transform/shader_io_spirv.h
new file mode 100644
index 0000000..f374e9c
--- /dev/null
+++ b/src/tint/ir/transform/shader_io_spirv.h
@@ -0,0 +1,39 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_IR_TRANSFORM_SHADER_IO_SPIRV_H_
+#define SRC_TINT_IR_TRANSFORM_SHADER_IO_SPIRV_H_
+
+#include "src/tint/ir/transform/shader_io.h"
+
+#include <memory>
+
+namespace tint::ir::transform {
+
+/// ShaderIOSpirv is the subclass of the ShaderIO transform used for the SPIR-V backend.
+class ShaderIOSpirv final : public utils::Castable<ShaderIOSpirv, ShaderIO> {
+ public:
+ /// Constructor
+ ShaderIOSpirv();
+ /// Destructor
+ ~ShaderIOSpirv() override;
+
+ /// @copydoc ShaderIO::MakeBackendState
+ std::unique_ptr<ShaderIO::BackendState> MakeBackendState(Module* mod,
+ Function* func) const override;
+};
+
+} // namespace tint::ir::transform
+
+#endif // SRC_TINT_IR_TRANSFORM_SHADER_IO_SPIRV_H_
diff --git a/src/tint/ir/transform/shader_io_test.cc b/src/tint/ir/transform/shader_io_test.cc
new file mode 100644
index 0000000..7f18722
--- /dev/null
+++ b/src/tint/ir/transform/shader_io_test.cc
@@ -0,0 +1,948 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <utility>
+
+#include "src/tint/ir/transform/shader_io.h"
+#include "src/tint/ir/transform/shader_io_spirv.h"
+#include "src/tint/ir/transform/test_helper.h"
+#include "src/tint/type/struct.h"
+
+namespace tint::ir::transform {
+namespace {
+
+using namespace tint::builtin::fluent_types; // NOLINT
+using namespace tint::number_suffixes; // NOLINT
+
+using IR_ShaderIOTest = TransformTest;
+
+TEST_F(IR_ShaderIOTest, NoInputsOrOutputs) {
+ auto* ep = b.Function("foo", ty.void_());
+ ep->SetStage(Function::PipelineStage::kCompute);
+
+ b.With(ep->Block(), [&] { //
+ b.Return(ep);
+ });
+
+ auto* src = R"(
+%foo = @compute func():void -> %b1 {
+ %b1 = block {
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = src;
+
+ Transform::DataMap data;
+ data.Add<ShaderIO::Config>();
+ Run<ShaderIOSpirv>(data);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ShaderIOTest, Parameters_NonStruct_Spirv) {
+ auto* ep = b.Function("foo", ty.void_());
+ auto* front_facing = b.FunctionParam("front_facing", ty.bool_());
+ front_facing->SetBuiltin(FunctionParam::Builtin::kFrontFacing);
+ auto* position = b.FunctionParam("position", ty.vec4<f32>());
+ position->SetBuiltin(FunctionParam::Builtin::kPosition);
+ position->SetInvariant(true);
+ auto* color1 = b.FunctionParam("color1", ty.f32());
+ color1->SetLocation(0, {});
+ auto* color2 = b.FunctionParam("color2", ty.f32());
+ color2->SetLocation(1, builtin::Interpolation{builtin::InterpolationType::kLinear,
+ builtin::InterpolationSampling::kSample});
+
+ ep->SetParams({front_facing, position, color1, color2});
+ ep->SetStage(Function::PipelineStage::kFragment);
+
+ b.With(ep->Block(), [&] {
+ auto* ifelse = b.If(front_facing);
+ b.With(ifelse->True(), [&] {
+ b.Multiply(ty.vec4<f32>(), position, b.Add(ty.f32(), color1, color2));
+ b.ExitIf(ifelse);
+ });
+ b.Return(ep);
+ });
+
+ auto* src = R"(
+%foo = @fragment func(%front_facing:bool [@front_facing], %position:vec4<f32> [@invariant, @position], %color1:f32 [@location(0)], %color2:f32 [@location(1), @interpolate(linear, sample)]):void -> %b1 {
+ %b1 = block {
+ if %front_facing [t: %b2] { # if_1
+ %b2 = block { # true
+ %6:f32 = add %color1, %color2
+ %7:vec4<f32> = mul %position, %6
+ exit_if # if_1
+ }
+ }
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+foo_BuiltinInputsStruct = struct @align(16), @block {
+ front_facing:bool @offset(0), @builtin(front_facing)
+ position:vec4<f32> @offset(16), @invariant, @builtin(position)
+}
+
+foo_LocationInputsStruct = struct @align(4), @block {
+ color1:f32 @offset(0), @location(0)
+ color2:f32 @offset(4), @location(1), @interpolate(linear, sample)
+}
+
+%b1 = block { # root
+ %foo_BuiltinInputs:ptr<__in, foo_BuiltinInputsStruct, read> = var
+ %foo_LocationInputs:ptr<__in, foo_LocationInputsStruct, read> = var
+}
+
+%foo_inner = func(%front_facing:bool, %position:vec4<f32>, %color1:f32, %color2:f32):void -> %b2 {
+ %b2 = block {
+ if %front_facing [t: %b3] { # if_1
+ %b3 = block { # true
+ %8:f32 = add %color1, %color2
+ %9:vec4<f32> = mul %position, %8
+ exit_if # if_1
+ }
+ }
+ ret
+ }
+}
+%foo = @fragment func():void -> %b4 {
+ %b4 = block {
+ %11:ptr<__in, bool, read> = access %foo_BuiltinInputs, 0u
+ %12:bool = load %11
+ %13:ptr<__in, vec4<f32>, read> = access %foo_BuiltinInputs, 1u
+ %14:vec4<f32> = load %13
+ %15:ptr<__in, f32, read> = access %foo_LocationInputs, 0u
+ %16:f32 = load %15
+ %17:ptr<__in, f32, read> = access %foo_LocationInputs, 1u
+ %18:f32 = load %17
+ %19:void = call %foo_inner, %12, %14, %16, %18
+ ret
+ }
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<ShaderIO::Config>();
+ Run<ShaderIOSpirv>(data);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ShaderIOTest, Parameters_Struct_Spirv) {
+ auto* str_ty =
+ ty.Struct(mod.symbols.New("Inputs"),
+ {
+ {
+ mod.symbols.New("front_facing"),
+ ty.bool_(),
+ {{}, {}, builtin::BuiltinValue::kFrontFacing, {}, false},
+ },
+ {
+ mod.symbols.New("position"),
+ ty.vec4<f32>(),
+ {{}, {}, builtin::BuiltinValue::kPosition, {}, true},
+ },
+ {
+ mod.symbols.New("color1"),
+ ty.f32(),
+ {0u, {}, {}, {}, false},
+ },
+ {
+ mod.symbols.New("color2"),
+ ty.f32(),
+ {1u,
+ {},
+ {},
+ builtin::Interpolation{builtin::InterpolationType::kLinear,
+ builtin::InterpolationSampling::kSample},
+ false},
+ },
+ });
+
+ auto* ep = b.Function("foo", ty.void_());
+ auto* str_param = b.FunctionParam("inputs", str_ty);
+ ep->SetParams({str_param});
+ ep->SetStage(Function::PipelineStage::kFragment);
+
+ b.With(ep->Block(), [&] {
+ auto* ifelse = b.If(b.Access(ty.bool_(), str_param, 0_i));
+ b.With(ifelse->True(), [&] {
+ auto* position = b.Access(ty.vec4<f32>(), str_param, 1_i);
+ auto* color1 = b.Access(ty.f32(), str_param, 2_i);
+ auto* color2 = b.Access(ty.f32(), str_param, 3_i);
+ b.Multiply(ty.vec4<f32>(), position, b.Add(ty.f32(), color1, color2));
+ b.ExitIf(ifelse);
+ });
+ b.Return(ep);
+ });
+
+ auto* src = R"(
+Inputs = struct @align(16) {
+ front_facing:bool @offset(0), @builtin(front_facing)
+ position:vec4<f32> @offset(16), @invariant, @builtin(position)
+ color1:f32 @offset(32), @location(0)
+ color2:f32 @offset(36), @location(1), @interpolate(linear, sample)
+}
+
+%foo = @fragment func(%inputs:Inputs):void -> %b1 {
+ %b1 = block {
+ %3:bool = access %inputs, 0i
+ if %3 [t: %b2] { # if_1
+ %b2 = block { # true
+ %4:vec4<f32> = access %inputs, 1i
+ %5:f32 = access %inputs, 2i
+ %6:f32 = access %inputs, 3i
+ %7:f32 = add %5, %6
+ %8:vec4<f32> = mul %4, %7
+ exit_if # if_1
+ }
+ }
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+Inputs = struct @align(16) {
+ front_facing:bool @offset(0)
+ position:vec4<f32> @offset(16)
+ color1:f32 @offset(32)
+ color2:f32 @offset(36)
+}
+
+foo_BuiltinInputsStruct = struct @align(16), @block {
+ Inputs_front_facing:bool @offset(0), @builtin(front_facing)
+ Inputs_position:vec4<f32> @offset(16), @invariant, @builtin(position)
+}
+
+foo_LocationInputsStruct = struct @align(4), @block {
+ Inputs_color1:f32 @offset(0), @location(0)
+ Inputs_color2:f32 @offset(4), @location(1), @interpolate(linear, sample)
+}
+
+%b1 = block { # root
+ %foo_BuiltinInputs:ptr<__in, foo_BuiltinInputsStruct, read> = var
+ %foo_LocationInputs:ptr<__in, foo_LocationInputsStruct, read> = var
+}
+
+%foo_inner = func(%inputs:Inputs):void -> %b2 {
+ %b2 = block {
+ %5:bool = access %inputs, 0i
+ if %5 [t: %b3] { # if_1
+ %b3 = block { # true
+ %6:vec4<f32> = access %inputs, 1i
+ %7:f32 = access %inputs, 2i
+ %8:f32 = access %inputs, 3i
+ %9:f32 = add %7, %8
+ %10:vec4<f32> = mul %6, %9
+ exit_if # if_1
+ }
+ }
+ ret
+ }
+}
+%foo = @fragment func():void -> %b4 {
+ %b4 = block {
+ %12:ptr<__in, bool, read> = access %foo_BuiltinInputs, 0u
+ %13:bool = load %12
+ %14:ptr<__in, vec4<f32>, read> = access %foo_BuiltinInputs, 1u
+ %15:vec4<f32> = load %14
+ %16:ptr<__in, f32, read> = access %foo_LocationInputs, 0u
+ %17:f32 = load %16
+ %18:ptr<__in, f32, read> = access %foo_LocationInputs, 1u
+ %19:f32 = load %18
+ %20:Inputs = construct %13, %15, %17, %19
+ %21:void = call %foo_inner, %20
+ ret
+ }
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<ShaderIO::Config>();
+ Run<ShaderIOSpirv>(data);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ShaderIOTest, Parameters_Mixed_Spirv) {
+ auto* str_ty = ty.Struct(mod.symbols.New("Inputs"),
+ {
+ {
+ mod.symbols.New("position"),
+ ty.vec4<f32>(),
+ {{}, {}, builtin::BuiltinValue::kPosition, {}, true},
+ },
+ {
+ mod.symbols.New("color1"),
+ ty.f32(),
+ {0u, {}, {}, {}, false},
+ },
+ });
+
+ auto* ep = b.Function("foo", ty.void_());
+ auto* front_facing = b.FunctionParam("front_facing", ty.bool_());
+ front_facing->SetBuiltin(FunctionParam::Builtin::kFrontFacing);
+ auto* str_param = b.FunctionParam("inputs", str_ty);
+ auto* color2 = b.FunctionParam("color2", ty.f32());
+ color2->SetLocation(1, builtin::Interpolation{builtin::InterpolationType::kLinear,
+ builtin::InterpolationSampling::kSample});
+
+ ep->SetParams({front_facing, str_param, color2});
+ ep->SetStage(Function::PipelineStage::kFragment);
+
+ b.With(ep->Block(), [&] {
+ auto* ifelse = b.If(front_facing);
+ b.With(ifelse->True(), [&] {
+ auto* position = b.Access(ty.vec4<f32>(), str_param, 0_i);
+ auto* color1 = b.Access(ty.f32(), str_param, 1_i);
+ b.Multiply(ty.vec4<f32>(), position, b.Add(ty.f32(), color1, color2));
+ b.ExitIf(ifelse);
+ });
+ b.Return(ep);
+ });
+
+ auto* src = R"(
+Inputs = struct @align(16) {
+ position:vec4<f32> @offset(0), @invariant, @builtin(position)
+ color1:f32 @offset(16), @location(0)
+}
+
+%foo = @fragment func(%front_facing:bool [@front_facing], %inputs:Inputs, %color2:f32 [@location(1), @interpolate(linear, sample)]):void -> %b1 {
+ %b1 = block {
+ if %front_facing [t: %b2] { # if_1
+ %b2 = block { # true
+ %5:vec4<f32> = access %inputs, 0i
+ %6:f32 = access %inputs, 1i
+ %7:f32 = add %6, %color2
+ %8:vec4<f32> = mul %5, %7
+ exit_if # if_1
+ }
+ }
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+Inputs = struct @align(16) {
+ position:vec4<f32> @offset(0)
+ color1:f32 @offset(16)
+}
+
+foo_BuiltinInputsStruct = struct @align(16), @block {
+ front_facing:bool @offset(0), @builtin(front_facing)
+ Inputs_position:vec4<f32> @offset(16), @invariant, @builtin(position)
+}
+
+foo_LocationInputsStruct = struct @align(4), @block {
+ Inputs_color1:f32 @offset(0), @location(0)
+ color2:f32 @offset(4), @location(1), @interpolate(linear, sample)
+}
+
+%b1 = block { # root
+ %foo_BuiltinInputs:ptr<__in, foo_BuiltinInputsStruct, read> = var
+ %foo_LocationInputs:ptr<__in, foo_LocationInputsStruct, read> = var
+}
+
+%foo_inner = func(%front_facing:bool, %inputs:Inputs, %color2:f32):void -> %b2 {
+ %b2 = block {
+ if %front_facing [t: %b3] { # if_1
+ %b3 = block { # true
+ %7:vec4<f32> = access %inputs, 0i
+ %8:f32 = access %inputs, 1i
+ %9:f32 = add %8, %color2
+ %10:vec4<f32> = mul %7, %9
+ exit_if # if_1
+ }
+ }
+ ret
+ }
+}
+%foo = @fragment func():void -> %b4 {
+ %b4 = block {
+ %12:ptr<__in, bool, read> = access %foo_BuiltinInputs, 0u
+ %13:bool = load %12
+ %14:ptr<__in, vec4<f32>, read> = access %foo_BuiltinInputs, 1u
+ %15:vec4<f32> = load %14
+ %16:ptr<__in, f32, read> = access %foo_LocationInputs, 0u
+ %17:f32 = load %16
+ %18:Inputs = construct %15, %17
+ %19:ptr<__in, f32, read> = access %foo_LocationInputs, 1u
+ %20:f32 = load %19
+ %21:void = call %foo_inner, %13, %18, %20
+ ret
+ }
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<ShaderIO::Config>();
+ Run<ShaderIOSpirv>(data);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ShaderIOTest, ReturnValue_NonStructBuiltin_Spirv) {
+ auto* ep = b.Function("foo", ty.vec4<f32>());
+ ep->SetReturnBuiltin(Function::ReturnBuiltin::kPosition);
+ ep->SetReturnInvariant(true);
+ ep->SetStage(Function::PipelineStage::kVertex);
+
+ b.With(ep->Block(), [&] { //
+ b.Return(ep, b.Construct(ty.vec4<f32>(), 0.5_f));
+ });
+
+ auto* src = R"(
+%foo = @vertex func():vec4<f32> [@invariant, @position] -> %b1 {
+ %b1 = block {
+ %2:vec4<f32> = construct 0.5f
+ ret %2
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+foo_BuiltinOutputsStruct = struct @align(16), @block {
+ tint_symbol:vec4<f32> @offset(0), @invariant, @builtin(position)
+}
+
+%b1 = block { # root
+ %foo_BuiltinOutputs:ptr<__out, foo_BuiltinOutputsStruct, write> = var
+}
+
+%foo_inner = func():vec4<f32> -> %b2 {
+ %b2 = block {
+ %3:vec4<f32> = construct 0.5f
+ ret %3
+ }
+}
+%foo = @vertex func():void -> %b3 {
+ %b3 = block {
+ %5:vec4<f32> = call %foo_inner
+ %6:ptr<__out, vec4<f32>, write> = access %foo_BuiltinOutputs, 0u
+ store %6, %5
+ ret
+ }
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<ShaderIO::Config>();
+ Run<ShaderIOSpirv>(data);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ShaderIOTest, ReturnValue_NonStructLocation_Spirv) {
+ auto* ep = b.Function("foo", ty.vec4<f32>());
+ ep->SetReturnLocation(1u, {});
+ ep->SetStage(Function::PipelineStage::kFragment);
+
+ b.With(ep->Block(), [&] { //
+ b.Return(ep, b.Construct(ty.vec4<f32>(), 0.5_f));
+ });
+
+ auto* src = R"(
+%foo = @fragment func():vec4<f32> [@location(1)] -> %b1 {
+ %b1 = block {
+ %2:vec4<f32> = construct 0.5f
+ ret %2
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+foo_LocationOutputsStruct = struct @align(16), @block {
+ tint_symbol:vec4<f32> @offset(0), @location(1)
+}
+
+%b1 = block { # root
+ %foo_LocationOutputs:ptr<__out, foo_LocationOutputsStruct, write> = var
+}
+
+%foo_inner = func():vec4<f32> -> %b2 {
+ %b2 = block {
+ %3:vec4<f32> = construct 0.5f
+ ret %3
+ }
+}
+%foo = @fragment func():void -> %b3 {
+ %b3 = block {
+ %5:vec4<f32> = call %foo_inner
+ %6:ptr<__out, vec4<f32>, write> = access %foo_LocationOutputs, 0u
+ store %6, %5
+ ret
+ }
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<ShaderIO::Config>();
+ Run<ShaderIOSpirv>(data);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ShaderIOTest, ReturnValue_Struct_Spirv) {
+ auto* str_ty =
+ ty.Struct(mod.symbols.New("Outputs"),
+ {
+ {
+ mod.symbols.New("position"),
+ ty.vec4<f32>(),
+ {{}, {}, builtin::BuiltinValue::kPosition, {}, true},
+ },
+ {
+ mod.symbols.New("color1"),
+ ty.f32(),
+ {0u, {}, {}, {}, false},
+ },
+ {
+ mod.symbols.New("color2"),
+ ty.f32(),
+ {1u,
+ {},
+ {},
+ builtin::Interpolation{builtin::InterpolationType::kLinear,
+ builtin::InterpolationSampling::kSample},
+ false},
+ },
+ });
+
+ auto* ep = b.Function("foo", str_ty);
+ ep->SetStage(Function::PipelineStage::kVertex);
+
+ b.With(ep->Block(), [&] { //
+ b.Return(ep, b.Construct(str_ty, b.Construct(ty.vec4<f32>(), 0_f), 0.25_f, 0.75_f));
+ });
+
+ auto* src = R"(
+Outputs = struct @align(16) {
+ position:vec4<f32> @offset(0), @invariant, @builtin(position)
+ color1:f32 @offset(16), @location(0)
+ color2:f32 @offset(20), @location(1), @interpolate(linear, sample)
+}
+
+%foo = @vertex func():Outputs -> %b1 {
+ %b1 = block {
+ %2:vec4<f32> = construct 0.0f
+ %3:Outputs = construct %2, 0.25f, 0.75f
+ ret %3
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+Outputs = struct @align(16) {
+ position:vec4<f32> @offset(0)
+ color1:f32 @offset(16)
+ color2:f32 @offset(20)
+}
+
+foo_BuiltinOutputsStruct = struct @align(16), @block {
+ Outputs_position:vec4<f32> @offset(0), @invariant, @builtin(position)
+}
+
+foo_LocationOutputsStruct = struct @align(4), @block {
+ Outputs_color1:f32 @offset(0), @location(0)
+ Outputs_color2:f32 @offset(4), @location(1), @interpolate(linear, sample)
+}
+
+%b1 = block { # root
+ %foo_BuiltinOutputs:ptr<__out, foo_BuiltinOutputsStruct, write> = var
+ %foo_LocationOutputs:ptr<__out, foo_LocationOutputsStruct, write> = var
+}
+
+%foo_inner = func():Outputs -> %b2 {
+ %b2 = block {
+ %4:vec4<f32> = construct 0.0f
+ %5:Outputs = construct %4, 0.25f, 0.75f
+ ret %5
+ }
+}
+%foo = @vertex func():void -> %b3 {
+ %b3 = block {
+ %7:Outputs = call %foo_inner
+ %8:vec4<f32> = access %7, 0u
+ %9:ptr<__out, vec4<f32>, write> = access %foo_BuiltinOutputs, 0u
+ store %9, %8
+ %10:f32 = access %7, 1u
+ %11:ptr<__out, f32, write> = access %foo_LocationOutputs, 0u
+ store %11, %10
+ %12:f32 = access %7, 2u
+ %13:ptr<__out, f32, write> = access %foo_LocationOutputs, 1u
+ store %13, %12
+ ret
+ }
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<ShaderIO::Config>();
+ Run<ShaderIOSpirv>(data);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ShaderIOTest, Struct_SharedByVertexAndFragment_Spirv) {
+ auto* vec4f = ty.vec4<f32>();
+ auto* str_ty = ty.Struct(mod.symbols.New("Interface"),
+ {
+ {
+ mod.symbols.New("position"),
+ vec4f,
+ {{}, {}, builtin::BuiltinValue::kPosition, {}, false},
+ },
+ {
+ mod.symbols.New("color"),
+ vec4f,
+ {0u, {}, {}, {}, false},
+ },
+ });
+
+ // Vertex shader.
+ {
+ auto* ep = b.Function("vert", str_ty);
+ ep->SetStage(Function::PipelineStage::kVertex);
+
+ b.With(ep->Block(), [&] { //
+ auto* position = b.Construct(vec4f, 0_f);
+ auto* color = b.Construct(vec4f, 1_f);
+ b.Return(ep, b.Construct(str_ty, position, color));
+ });
+ }
+
+ // Fragment shader.
+ {
+ auto* ep = b.Function("frag", vec4f);
+ auto* inputs = b.FunctionParam("inputs", str_ty);
+ ep->SetStage(Function::PipelineStage::kFragment);
+ ep->SetParams({inputs});
+
+ b.With(ep->Block(), [&] { //
+ auto* position = b.Access(vec4f, inputs, 0_u);
+ auto* color = b.Access(vec4f, inputs, 1_u);
+ b.Return(ep, b.Add(vec4f, position, color));
+ });
+ }
+
+ auto* src = R"(
+Interface = struct @align(16) {
+ position:vec4<f32> @offset(0), @builtin(position)
+ color:vec4<f32> @offset(16), @location(0)
+}
+
+%vert = @vertex func():Interface -> %b1 {
+ %b1 = block {
+ %2:vec4<f32> = construct 0.0f
+ %3:vec4<f32> = construct 1.0f
+ %4:Interface = construct %2, %3
+ ret %4
+ }
+}
+%frag = @fragment func(%inputs:Interface):vec4<f32> -> %b2 {
+ %b2 = block {
+ %7:vec4<f32> = access %inputs, 0u
+ %8:vec4<f32> = access %inputs, 1u
+ %9:vec4<f32> = add %7, %8
+ ret %9
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+Interface = struct @align(16) {
+ position:vec4<f32> @offset(0)
+ color:vec4<f32> @offset(16)
+}
+
+vert_BuiltinOutputsStruct = struct @align(16), @block {
+ Interface_position:vec4<f32> @offset(0), @builtin(position)
+}
+
+vert_LocationOutputsStruct = struct @align(16), @block {
+ Interface_color:vec4<f32> @offset(0), @location(0)
+}
+
+frag_BuiltinInputsStruct = struct @align(16), @block {
+ Interface_position:vec4<f32> @offset(0), @builtin(position)
+}
+
+frag_LocationInputsStruct = struct @align(16), @block {
+ Interface_color:vec4<f32> @offset(0), @location(0)
+}
+
+frag_LocationOutputsStruct = struct @align(16), @block {
+ tint_symbol:vec4<f32> @offset(0)
+}
+
+%b1 = block { # root
+ %vert_BuiltinOutputs:ptr<__out, vert_BuiltinOutputsStruct, write> = var
+ %vert_LocationOutputs:ptr<__out, vert_LocationOutputsStruct, write> = var
+ %frag_BuiltinInputs:ptr<__in, frag_BuiltinInputsStruct, read> = var
+ %frag_LocationInputs:ptr<__in, frag_LocationInputsStruct, read> = var
+ %frag_LocationOutputs:ptr<__out, frag_LocationOutputsStruct, write> = var
+}
+
+%vert_inner = func():Interface -> %b2 {
+ %b2 = block {
+ %7:vec4<f32> = construct 0.0f
+ %8:vec4<f32> = construct 1.0f
+ %9:Interface = construct %7, %8
+ ret %9
+ }
+}
+%frag_inner = func(%inputs:Interface):vec4<f32> -> %b3 {
+ %b3 = block {
+ %12:vec4<f32> = access %inputs, 0u
+ %13:vec4<f32> = access %inputs, 1u
+ %14:vec4<f32> = add %12, %13
+ ret %14
+ }
+}
+%vert = @vertex func():void -> %b4 {
+ %b4 = block {
+ %16:Interface = call %vert_inner
+ %17:vec4<f32> = access %16, 0u
+ %18:ptr<__out, vec4<f32>, write> = access %vert_BuiltinOutputs, 0u
+ store %18, %17
+ %19:vec4<f32> = access %16, 1u
+ %20:ptr<__out, vec4<f32>, write> = access %vert_LocationOutputs, 0u
+ store %20, %19
+ ret
+ }
+}
+%frag = @fragment func():void -> %b5 {
+ %b5 = block {
+ %22:ptr<__in, vec4<f32>, read> = access %frag_BuiltinInputs, 0u
+ %23:vec4<f32> = load %22
+ %24:ptr<__in, vec4<f32>, read> = access %frag_LocationInputs, 0u
+ %25:vec4<f32> = load %24
+ %26:Interface = construct %23, %25
+ %27:vec4<f32> = call %frag_inner, %26
+ %28:ptr<__out, vec4<f32>, write> = access %frag_LocationOutputs, 0u
+ store %28, %27
+ ret
+ }
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<ShaderIO::Config>();
+ Run<ShaderIOSpirv>(data);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ShaderIOTest, Struct_SharedWithBuffer_Spirv) {
+ auto* vec4f = ty.vec4<f32>();
+ auto* str_ty = ty.Struct(mod.symbols.New("Outputs"),
+ {
+ {
+ mod.symbols.New("position"),
+ vec4f,
+ {{}, {}, builtin::BuiltinValue::kPosition, {}, false},
+ },
+ {
+ mod.symbols.New("color"),
+ vec4f,
+ {0u, {}, {}, {}, false},
+ },
+ });
+
+ auto* buffer = b.RootBlock()->Append(b.Var(ty.ptr(storage, str_ty, read)));
+
+ auto* ep = b.Function("vert", str_ty);
+ ep->SetStage(Function::PipelineStage::kVertex);
+
+ b.With(ep->Block(), [&] { //
+ b.Return(ep, b.Load(buffer));
+ });
+
+ auto* src = R"(
+Outputs = struct @align(16) {
+ position:vec4<f32> @offset(0), @builtin(position)
+ color:vec4<f32> @offset(16), @location(0)
+}
+
+%b1 = block { # root
+ %1:ptr<storage, Outputs, read> = var
+}
+
+%vert = @vertex func():Outputs -> %b2 {
+ %b2 = block {
+ %3:Outputs = load %1
+ ret %3
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+Outputs = struct @align(16) {
+ position:vec4<f32> @offset(0)
+ color:vec4<f32> @offset(16)
+}
+
+vert_BuiltinOutputsStruct = struct @align(16), @block {
+ Outputs_position:vec4<f32> @offset(0), @builtin(position)
+}
+
+vert_LocationOutputsStruct = struct @align(16), @block {
+ Outputs_color:vec4<f32> @offset(0), @location(0)
+}
+
+%b1 = block { # root
+ %1:ptr<storage, Outputs, read> = var
+ %vert_BuiltinOutputs:ptr<__out, vert_BuiltinOutputsStruct, write> = var
+ %vert_LocationOutputs:ptr<__out, vert_LocationOutputsStruct, write> = var
+}
+
+%vert_inner = func():Outputs -> %b2 {
+ %b2 = block {
+ %5:Outputs = load %1
+ ret %5
+ }
+}
+%vert = @vertex func():void -> %b3 {
+ %b3 = block {
+ %7:Outputs = call %vert_inner
+ %8:vec4<f32> = access %7, 0u
+ %9:ptr<__out, vec4<f32>, write> = access %vert_BuiltinOutputs, 0u
+ store %9, %8
+ %10:vec4<f32> = access %7, 1u
+ %11:ptr<__out, vec4<f32>, write> = access %vert_LocationOutputs, 0u
+ store %11, %10
+ ret
+ }
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<ShaderIO::Config>();
+ Run<ShaderIOSpirv>(data);
+
+ EXPECT_EQ(expect, str());
+}
+
+// Test that we change the type of the sample mask builtin to an array for SPIR-V.
+TEST_F(IR_ShaderIOTest, SampleMask_Spirv) {
+ auto* str_ty = ty.Struct(mod.symbols.New("Outputs"),
+ {
+ {
+ mod.symbols.New("color"),
+ ty.f32(),
+ {0u, {}, {}, {}, false},
+ },
+ {
+ mod.symbols.New("mask"),
+ ty.u32(),
+ {{}, {}, builtin::BuiltinValue::kSampleMask, {}, false},
+ },
+ });
+
+ auto* mask_in = b.FunctionParam("mask_in", ty.u32());
+ mask_in->SetBuiltin(FunctionParam::Builtin::kSampleMask);
+
+ auto* ep = b.Function("foo", str_ty);
+ ep->SetStage(Function::PipelineStage::kFragment);
+ ep->SetParams({mask_in});
+
+ b.With(ep->Block(), [&] { //
+ b.Return(ep, b.Construct(str_ty, 0.5_f, mask_in));
+ });
+
+ auto* src = R"(
+Outputs = struct @align(4) {
+ color:f32 @offset(0), @location(0)
+ mask:u32 @offset(4), @builtin(sample_mask)
+}
+
+%foo = @fragment func(%mask_in:u32 [@sample_mask]):Outputs -> %b1 {
+ %b1 = block {
+ %3:Outputs = construct 0.5f, %mask_in
+ ret %3
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+Outputs = struct @align(4) {
+ color:f32 @offset(0)
+ mask:u32 @offset(4)
+}
+
+foo_BuiltinInputsStruct = struct @align(4), @block {
+ mask_in:array<u32, 1> @offset(0), @builtin(sample_mask)
+}
+
+foo_BuiltinOutputsStruct = struct @align(4), @block {
+ Outputs_mask:array<u32, 1> @offset(0), @builtin(sample_mask)
+}
+
+foo_LocationOutputsStruct = struct @align(4), @block {
+ Outputs_color:f32 @offset(0), @location(0)
+}
+
+%b1 = block { # root
+ %foo_BuiltinInputs:ptr<__in, foo_BuiltinInputsStruct, read> = var
+ %foo_BuiltinOutputs:ptr<__out, foo_BuiltinOutputsStruct, write> = var
+ %foo_LocationOutputs:ptr<__out, foo_LocationOutputsStruct, write> = var
+}
+
+%foo_inner = func(%mask_in:u32):Outputs -> %b2 {
+ %b2 = block {
+ %6:Outputs = construct 0.5f, %mask_in
+ ret %6
+ }
+}
+%foo = @fragment func():void -> %b3 {
+ %b3 = block {
+ %8:ptr<__in, u32, read> = access %foo_BuiltinInputs, 0u, 0u
+ %9:u32 = load %8
+ %10:Outputs = call %foo_inner, %9
+ %11:f32 = access %10, 0u
+ %12:ptr<__out, f32, write> = access %foo_LocationOutputs, 0u
+ store %12, %11
+ %13:u32 = access %10, 1u
+ %14:ptr<__out, u32, write> = access %foo_BuiltinOutputs, 0u, 0u
+ store %14, %13
+ ret
+ }
+}
+)";
+
+ Transform::DataMap data;
+ data.Add<ShaderIO::Config>();
+ Run<ShaderIOSpirv>(data);
+
+ EXPECT_EQ(expect, str());
+}
+
+} // namespace
+} // namespace tint::ir::transform
diff --git a/src/tint/ir/transform/var_for_dynamic_index_test.cc b/src/tint/ir/transform/var_for_dynamic_index_test.cc
index ce0d691..7848174 100644
--- a/src/tint/ir/transform/var_for_dynamic_index_test.cc
+++ b/src/tint/ir/transform/var_for_dynamic_index_test.cc
@@ -37,7 +37,6 @@
auto* block = func->Block();
auto* access = block->Append(b.Access(ty.i32(), arr, 1_i));
block->Append(b.Return(func, access));
- mod.functions.Push(func);
auto* expect = R"(
%foo = func(%2:array<i32, 4>):i32 -> %b1 {
@@ -61,7 +60,6 @@
auto* block = func->Block();
auto* access = block->Append(b.Access(ty.f32(), mat, 1_i, 0_i));
block->Append(b.Return(func, access));
- mod.functions.Push(func);
auto* expect = R"(
%foo = func(%2:mat2x2<f32>):f32 -> %b1 {
@@ -87,7 +85,6 @@
auto* access = block->Append(b.Access(ty.ptr<function, i32>(), arr, idx));
auto* load = block->Append(b.Load(access));
block->Append(b.Return(func, load));
- mod.functions.Push(func);
auto* expect = R"(
%foo = func(%2:ptr<function, array<i32, 4>, read_write>, %3:i32):i32 -> %b1 {
@@ -114,7 +111,6 @@
auto* access = block->Append(b.Access(ty.ptr<function, f32>(), mat, idx, idx));
auto* load = block->Append(b.Load(access));
block->Append(b.Return(func, load));
- mod.functions.Push(func);
auto* expect = R"(
%foo = func(%2:ptr<function, mat2x2<f32>, read_write>, %3:i32):f32 -> %b1 {
@@ -140,7 +136,6 @@
auto* block = func->Block();
auto* access = block->Append(b.Access(ty.f32(), vec, idx));
block->Append(b.Return(func, access));
- mod.functions.Push(func);
auto* expect = R"(
%foo = func(%2:vec4<f32>, %3:i32):f32 -> %b1 {
@@ -165,7 +160,6 @@
auto* block = func->Block();
auto* access = block->Append(b.Access(ty.i32(), arr, idx));
block->Append(b.Return(func, access));
- mod.functions.Push(func);
auto* expect = R"(
%foo = func(%2:array<i32, 4>, %3:i32):i32 -> %b1 {
@@ -192,7 +186,6 @@
auto* block = func->Block();
auto* access = block->Append(b.Access(ty.vec2<f32>(), mat, idx));
block->Append(b.Return(func, access));
- mod.functions.Push(func);
auto* expect = R"(
%foo = func(%2:mat2x2<f32>, %3:i32):vec2<f32> -> %b1 {
@@ -219,7 +212,6 @@
auto* block = func->Block();
auto* access = block->Append(b.Access(ty.i32(), arr, idx, 1_u, idx));
block->Append(b.Return(func, access));
- mod.functions.Push(func);
auto* expect = R"(
%foo = func(%2:array<array<array<i32, 4>, 4>, 4>, %3:i32):i32 -> %b1 {
@@ -246,7 +238,6 @@
auto* block = func->Block();
auto* access = block->Append(b.Access(ty.i32(), arr, 1_u, 2_u, idx));
block->Append(b.Return(func, access));
- mod.functions.Push(func);
auto* expect = R"(
%foo = func(%2:array<array<array<i32, 4>, 4>, 4>, %3:i32):i32 -> %b1 {
@@ -274,7 +265,6 @@
auto* block = func->Block();
auto* access = block->Append(b.Access(ty.i32(), arr, 1_u, idx, 2_u, idx));
block->Append(b.Return(func, access));
- mod.functions.Push(func);
auto* expect = R"(
%foo = func(%2:array<array<array<array<i32, 4>, 4>, 4>, 4>, %3:i32):i32 -> %b1 {
@@ -308,7 +298,6 @@
auto* block = func->Block();
auto* access = block->Append(b.Access(ty.f32(), str_val, 1_u, idx, 0_u));
block->Append(b.Return(func, access));
- mod.functions.Push(func);
auto* expect = R"(
MyStruct = struct @align(16) {
@@ -346,7 +335,6 @@
block->Append(b.Access(ty.i32(), arr, idx_b));
auto* access_c = block->Append(b.Access(ty.i32(), arr, idx_c));
block->Append(b.Return(func, access_c));
- mod.functions.Push(func);
auto* expect = R"(
%foo = func(%2:array<i32, 4>, %3:i32, %4:i32, %5:i32):i32 -> %b1 {
@@ -381,7 +369,6 @@
block->Append(b.Access(ty.i32(), arr, 1_u, 2_u, idx_b));
auto* access_c = block->Append(b.Access(ty.i32(), arr, 1_u, 2_u, idx_c));
block->Append(b.Return(func, access_c));
- mod.functions.Push(func);
auto* expect = R"(
%foo = func(%2:array<array<array<i32, 4>, 4>, 4>, %3:i32, %4:i32, %5:i32):i32 -> %b1 {
diff --git a/src/tint/ir/validate.cc b/src/tint/ir/validate.cc
index d8871b9..a743786 100644
--- a/src/tint/ir/validate.cc
+++ b/src/tint/ir/validate.cc
@@ -66,12 +66,9 @@
}
if (diagnostics_.contains_errors()) {
- // If a diassembly file was generated then one of the diagnostics referenced the
- // disasembly. Emit the entire disassembly file at the end of the messages.
- if (mod_.disassembly_file) {
- diagnostics_.add_note(tint::diag::System::IR,
- "# Disassembly\n" + mod_.disassembly_file->content.data, {});
- }
+ DisassembleIfNeeded();
+ diagnostics_.add_note(tint::diag::System::IR,
+ "# Disassembly\n" + mod_.disassembly_file->content.data, {});
return std::move(diagnostics_);
}
return Success{};
@@ -83,6 +80,7 @@
Disassembler dis_{mod_};
Block* current_block_ = nullptr;
+ utils::Hashset<Function*, 4> seen_functions_;
void DisassembleIfNeeded() {
if (mod_.disassembly_file) {
@@ -113,6 +111,17 @@
}
}
+ void AddResultError(Instruction* inst, size_t idx, std::string err) {
+ DisassembleIfNeeded();
+ auto src = dis_.ResultSource(Usage{inst, static_cast<uint32_t>(idx)});
+ src.file = mod_.disassembly_file.get();
+ AddError(std::move(err), src);
+
+ if (current_block_) {
+ AddNote(current_block_, "In block");
+ }
+ }
+
void AddError(Block* blk, std::string err) {
DisassembleIfNeeded();
auto src = dis_.BlockSource(blk);
@@ -142,7 +151,30 @@
diagnostics_.add_note(tint::diag::System::IR, std::move(note), src);
}
- // std::string Name(Value* v) { return mod_.NameOf(v).Name(); }
+ std::string Name(Value* v) { return mod_.NameOf(v).Name(); }
+
+ template <typename FUNC>
+ void CheckOperandNotNull(ir::Instruction* inst,
+ ir::Value* operand,
+ size_t idx,
+ std::string_view name,
+ FUNC&& cb) {
+ if (operand == nullptr) {
+ AddError(inst, idx, std::string(name) + ": " + cb(idx) + " operand is undefined");
+ }
+ }
+
+ template <typename FUNC>
+ void CheckOperandsNotNull(ir::Instruction* inst,
+ size_t start_operand,
+ size_t end_operand,
+ std::string_view name,
+ FUNC&& cb) {
+ auto operands = inst->Operands();
+ for (size_t i = start_operand; i <= end_operand; i++) {
+ CheckOperandNotNull(inst, operands[i], i, name, cb);
+ }
+ }
void CheckRootBlock(Block* blk) {
if (!blk) {
@@ -158,11 +190,17 @@
std::string("root block: invalid instruction: ") + inst->TypeInfo().name);
continue;
}
- CheckVar(var);
+ CheckInstruction(var);
}
}
- void CheckFunction(Function* func) { CheckBlock(func->Block()); }
+ void CheckFunction(Function* func) {
+ if (!seen_functions_.Add(func)) {
+ AddError("function '" + Name(func) + "' added to module multiple times");
+ }
+
+ CheckBlock(func->Block());
+ }
void CheckBlock(Block* blk) {
TINT_SCOPED_ASSIGNMENT(current_block_, blk);
@@ -184,12 +222,22 @@
void CheckInstruction(Instruction* inst) {
if (!inst->Alive()) {
AddError(inst, "destroyed instruction found in instruction list");
+ return;
}
- if (inst->Result()) {
- if (inst->Result()->Source() == nullptr) {
- AddError(inst, "instruction result source is undefined");
- } else if (inst->Result()->Source() != inst) {
- AddError(inst, "instruction result source has wrong instruction");
+ if (inst->HasResults()) {
+ auto results = inst->Results();
+ for (size_t i = 0; i < results.Length(); ++i) {
+ auto* res = results[i];
+ if (!res) {
+ AddResultError(inst, i, "instruction result is undefined");
+ continue;
+ }
+
+ if (res->Source() == nullptr) {
+ AddResultError(inst, i, "instruction result source is undefined");
+ } else if (res->Source() != inst) {
+ AddResultError(inst, i, "instruction result source has wrong instruction");
+ }
}
}
@@ -203,7 +251,7 @@
// Note, a `nullptr` is a valid operand in some cases, like `var` so we can't just check
// for `nullptr` here.
if (!op->Alive()) {
- AddError(inst, "instruction has undefined operand");
+ AddError(inst, i, "instruction has operand which is not alive");
}
if (!op->Usages().Contains({inst, i})) {
@@ -223,7 +271,7 @@
[&](Switch*) {}, //
[&](Swizzle*) {}, //
[&](Terminator* b) { CheckTerminator(b); }, //
- [&](Unary*) {}, //
+ [&](Unary* u) { CheckUnary(u); }, //
[&](Var* var) { CheckVar(var); }, //
[&](Default) {
AddError(std::string("missing validation of: ") + inst->TypeInfo().name);
@@ -313,14 +361,20 @@
}
void CheckBinary(ir::Binary* b) {
- if (b->LHS() == nullptr) {
- AddError(b, "binary: left operand is undefined");
- }
- if (b->RHS() == nullptr) {
- AddError(b, "binary: right operand is undefined");
- }
- if (b->Result() == nullptr) {
- AddError(b, "binary: result is undefined");
+ CheckOperandsNotNull(b, Binary::kLhsOperandOffset, Binary::kRhsOperandOffset, "binary",
+ [](size_t err_idx) {
+ return (err_idx == Binary::kLhsOperandOffset) ? "left" : "right";
+ });
+ }
+
+ void CheckUnary(ir::Unary* u) {
+ CheckOperandNotNull(u, u->Val(), Unary::kValueOperandOffset, "unary",
+ [](size_t) { return "value"; });
+
+ if (u->Result() && u->Val()) {
+ if (u->Result()->Type() != u->Val()->Type()) {
+ AddError(u, "unary: result type must match value type");
+ }
}
}
@@ -345,19 +399,15 @@
}
void CheckIf(If* if_) {
- if (!if_->Condition()) {
- AddError(if_, "if: condition is undefined");
- }
+ CheckOperandNotNull(if_, if_->Condition(), If::kConditionOperandOffset, "if",
+ [](size_t) { return "condition"; });
+
if (if_->Condition() && !if_->Condition()->Type()->Is<type::Bool>()) {
AddError(if_, If::kConditionOperandOffset, "if: condition must be a `bool` type");
}
}
void CheckVar(Var* var) {
- if (var->Result() == nullptr) {
- AddError(var, "var: result is undefined");
- }
-
if (var->Result() && var->Initializer()) {
if (var->Initializer()->Type() != var->Result()->Type()->UnwrapPtr()) {
AddError(var, "var initializer has incorrect type");
diff --git a/src/tint/ir/validate_test.cc b/src/tint/ir/validate_test.cc
index 8d696e1..723c3fc 100644
--- a/src/tint/ir/validate_test.cc
+++ b/src/tint/ir/validate_test.cc
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include <string>
#include <utility>
#include "gmock/gmock.h"
@@ -21,6 +22,7 @@
#include "src/tint/type/matrix.h"
#include "src/tint/type/pointer.h"
#include "src/tint/type/struct.h"
+#include "src/tint/utils/string.h"
namespace tint::ir {
namespace {
@@ -68,7 +70,6 @@
TEST_F(IR_ValidateTest, Function) {
auto* f = b.Function("my_func", ty.void_());
- mod.functions.Push(f);
f->SetParams({b.FunctionParam(ty.i32()), b.FunctionParam(ty.f32())});
f->Block()->Append(b.Return(f));
@@ -77,10 +78,34 @@
EXPECT_TRUE(res) << res.Failure().str();
}
-TEST_F(IR_ValidateTest, Block_NoTerminator) {
+TEST_F(IR_ValidateTest, Function_Duplicate) {
auto* f = b.Function("my_func", ty.void_());
+ // Function would auto-push by the builder, so this adds a duplicate
mod.functions.Push(f);
+ f->SetParams({b.FunctionParam(ty.i32()), b.FunctionParam(ty.f32())});
+ f->Block()->Append(b.Return(f));
+
+ auto res = ir::Validate(mod);
+ ASSERT_FALSE(res);
+ EXPECT_EQ(res.Failure().str(), R"(error: function 'my_func' added to module multiple times
+note: # Disassembly
+%my_func = func(%2:i32, %3:f32):void -> %b1 {
+ %b1 = block {
+ ret
+ }
+}
+%my_func = func(%2:i32, %3:f32):void -> %b1 {
+ %b1 = block {
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_ValidateTest, Block_NoTerminator) {
+ b.Function("my_func", ty.void_());
+
auto res = ir::Validate(mod);
ASSERT_FALSE(res);
EXPECT_EQ(res.Failure().str(), R"(:2:3 error: block: does not end in a terminator instruction
@@ -99,7 +124,6 @@
auto* f = b.Function("my_func", ty.void_());
auto* obj = b.FunctionParam(ty.mat3x2<f32>());
f->SetParams({obj});
- mod.functions.Push(f);
b.With(f->Block(), [&] {
b.Access(ty.f32(), obj, 1_u, 0_u);
@@ -114,7 +138,6 @@
auto* f = b.Function("my_func", ty.void_());
auto* obj = b.FunctionParam(ty.ptr<private_, mat3x2<f32>>());
f->SetParams({obj});
- mod.functions.Push(f);
b.With(f->Block(), [&] {
b.Access(ty.ptr<private_, f32>(), obj, 1_u, 0_u);
@@ -129,7 +152,6 @@
auto* f = b.Function("my_func", ty.void_());
auto* obj = b.FunctionParam(ty.vec3<f32>());
f->SetParams({obj});
- mod.functions.Push(f);
b.With(f->Block(), [&] {
b.Access(ty.f32(), obj, -1_i);
@@ -160,7 +182,6 @@
auto* f = b.Function("my_func", ty.void_());
auto* obj = b.FunctionParam(ty.mat3x2<f32>());
f->SetParams({obj});
- mod.functions.Push(f);
b.With(f->Block(), [&] {
b.Access(ty.f32(), obj, 1_u, 3_u);
@@ -195,7 +216,6 @@
auto* f = b.Function("my_func", ty.void_());
auto* obj = b.FunctionParam(ty.ptr<private_, mat3x2<f32>>());
f->SetParams({obj});
- mod.functions.Push(f);
b.With(f->Block(), [&] {
b.Access(ty.ptr<private_, f32>(), obj, 1_u, 3_u);
@@ -231,7 +251,6 @@
auto* f = b.Function("my_func", ty.void_());
auto* obj = b.FunctionParam(ty.f32());
f->SetParams({obj});
- mod.functions.Push(f);
b.With(f->Block(), [&] {
b.Access(ty.f32(), obj, 1_u);
@@ -262,7 +281,6 @@
auto* f = b.Function("my_func", ty.void_());
auto* obj = b.FunctionParam(ty.ptr<private_, f32>());
f->SetParams({obj});
- mod.functions.Push(f);
b.With(f->Block(), [&] {
b.Access(ty.ptr<private_, f32>(), obj, 1_u);
@@ -299,7 +317,6 @@
auto* obj = b.FunctionParam(str_ty);
auto* idx = b.FunctionParam(ty.i32());
f->SetParams({obj, idx});
- mod.functions.Push(f);
b.With(f->Block(), [&] {
b.Access(ty.i32(), obj, idx);
@@ -342,7 +359,6 @@
auto* obj = b.FunctionParam(ty.ptr<private_, read_write>(str_ty));
auto* idx = b.FunctionParam(ty.i32());
f->SetParams({obj, idx});
- mod.functions.Push(f);
b.With(f->Block(), [&] {
b.Access(ty.i32(), obj, idx);
@@ -379,7 +395,6 @@
auto* f = b.Function("my_func", ty.void_());
auto* obj = b.FunctionParam(ty.mat3x2<f32>());
f->SetParams({obj});
- mod.functions.Push(f);
b.With(f->Block(), [&] {
b.Access(ty.i32(), obj, 1_u, 1_u);
@@ -411,7 +426,6 @@
auto* f = b.Function("my_func", ty.void_());
auto* obj = b.FunctionParam(ty.ptr<private_, mat3x2<f32>>());
f->SetParams({obj});
- mod.functions.Push(f);
b.With(f->Block(), [&] {
b.Access(ty.ptr<private_, i32>(), obj, 1_u, 1_u);
@@ -444,7 +458,6 @@
auto* f = b.Function("my_func", ty.void_());
auto* obj = b.FunctionParam(ty.ptr<private_, mat3x2<f32>>());
f->SetParams({obj});
- mod.functions.Push(f);
b.With(f->Block(), [&] {
b.Access(ty.f32(), obj, 1_u, 1_u);
@@ -475,7 +488,6 @@
TEST_F(IR_ValidateTest, Block_TerminatorInMiddle) {
auto* f = b.Function("my_func", ty.void_());
- mod.functions.Push(f);
b.With(f->Block(), [&] {
b.Return(f);
@@ -505,7 +517,6 @@
TEST_F(IR_ValidateTest, If_ConditionIsBool) {
auto* f = b.Function("my_func", ty.void_());
- mod.functions.Push(f);
auto* if_ = b.If(1_i);
if_->True()->Append(b.Return(f));
@@ -541,15 +552,52 @@
)");
}
+TEST_F(IR_ValidateTest, If_ConditionIsNullptr) {
+ auto* f = b.Function("my_func", ty.void_());
+
+ auto* if_ = b.If(nullptr);
+ if_->True()->Append(b.Return(f));
+ if_->False()->Append(b.Return(f));
+
+ f->Block()->Append(if_);
+ f->Block()->Append(b.Return(f));
+
+ auto res = ir::Validate(mod);
+ ASSERT_FALSE(res);
+ EXPECT_EQ(res.Failure().str(), R"(:3:8 error: if: condition operand is undefined
+ if undef [t: %b2, f: %b3] { # if_1
+ ^^^^^
+
+:2:3 note: In block
+ %b1 = block {
+ ^^^^^^^^^^^
+
+note: # Disassembly
+%my_func = func():void -> %b1 {
+ %b1 = block {
+ if undef [t: %b2, f: %b3] { # if_1
+ %b2 = block { # true
+ ret
+ }
+ %b3 = block { # false
+ ret
+ }
+ }
+ ret
+ }
+}
+)");
+}
+
TEST_F(IR_ValidateTest, Var_RootBlock_NullResult) {
auto* v = mod.instructions.Create<ir::Var>(nullptr);
b.RootBlock()->Append(v);
auto res = ir::Validate(mod);
ASSERT_FALSE(res);
- EXPECT_EQ(res.Failure().str(), R"(:2:11 error: var: result is undefined
+ EXPECT_EQ(res.Failure().str(), R"(:2:3 error: instruction result is undefined
undef = var
- ^^^
+ ^^^^^
:1:1 note: In block
%b1 = block { # root
@@ -567,7 +615,6 @@
auto* v = mod.instructions.Create<ir::Var>(nullptr);
auto* f = b.Function("my_func", ty.void_());
- mod.functions.Push(f);
auto sb = b.With(f->Block());
sb.Append(v);
@@ -575,9 +622,9 @@
auto res = ir::Validate(mod);
ASSERT_FALSE(res);
- EXPECT_EQ(res.Failure().str(), R"(:3:13 error: var: result is undefined
+ EXPECT_EQ(res.Failure().str(), R"(:3:5 error: instruction result is undefined
undef = var
- ^^^
+ ^^^^^
:2:3 note: In block
%b1 = block {
@@ -595,7 +642,6 @@
TEST_F(IR_ValidateTest, Var_Init_WrongType) {
auto* f = b.Function("my_func", ty.void_());
- mod.functions.Push(f);
auto sb = b.With(f->Block());
auto* v = sb.Var(ty.ptr<function, f32>());
@@ -626,7 +672,6 @@
TEST_F(IR_ValidateTest, Instruction_AppendedDead) {
auto* f = b.Function("my_func", ty.void_());
- mod.functions.Push(f);
auto sb = b.With(f->Block());
auto* v = sb.Var(ty.ptr<function, f32>());
@@ -635,19 +680,12 @@
v->Destroy();
v->InsertBefore(ret);
- auto res = ir::Validate(mod);
- ASSERT_FALSE(res);
- EXPECT_EQ(res.Failure().str(), R"(:3:41 error: destroyed instruction found in instruction list
- %2:ptr<function, f32, read_write> = var
- ^^^
+ auto addr = utils::ToString(v);
+ auto arrows = std::string(addr.length(), '^');
-:2:3 note: In block
- %b1 = block {
- ^^^^^^^^^^^
-
-:3:41 error: instruction result source is undefined
- %2:ptr<function, f32, read_write> = var
- ^^^
+ std::string expected = R"(:3:5 error: destroyed instruction found in instruction list
+ <destroyed tint::ir::Var $ADDRESS>
+ ^^^^^^^^^^^^^^^^^^^^^^^^^$ARROWS^
:2:3 note: In block
%b1 = block {
@@ -656,16 +694,22 @@
note: # Disassembly
%my_func = func():void -> %b1 {
%b1 = block {
- %2:ptr<function, f32, read_write> = var
+ <destroyed tint::ir::Var $ADDRESS>
ret
}
}
-)");
+)";
+
+ expected = utils::ReplaceAll(expected, "$ADDRESS", addr);
+ expected = utils::ReplaceAll(expected, "$ARROWS", arrows);
+
+ auto res = ir::Validate(mod);
+ ASSERT_FALSE(res);
+ EXPECT_EQ(res.Failure().str(), expected);
}
TEST_F(IR_ValidateTest, Instruction_NullSource) {
auto* f = b.Function("my_func", ty.void_());
- mod.functions.Push(f);
auto sb = b.With(f->Block());
auto* v = sb.Var(ty.ptr<function, f32>());
@@ -675,9 +719,9 @@
auto res = ir::Validate(mod);
ASSERT_FALSE(res);
- EXPECT_EQ(res.Failure().str(), R"(:3:41 error: instruction result source is undefined
+ EXPECT_EQ(res.Failure().str(), R"(:3:5 error: instruction result source is undefined
%2:ptr<function, f32, read_write> = var
- ^^^
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
:2:3 note: In block
%b1 = block {
@@ -695,7 +739,6 @@
TEST_F(IR_ValidateTest, Instruction_DeadOperand) {
auto* f = b.Function("my_func", ty.void_());
- mod.functions.Push(f);
auto sb = b.With(f->Block());
auto* v = sb.Var(ty.ptr<function, f32>());
@@ -707,9 +750,9 @@
auto res = ir::Validate(mod);
ASSERT_FALSE(res);
- EXPECT_EQ(res.Failure().str(), R"(:3:41 error: instruction has undefined operand
+ EXPECT_EQ(res.Failure().str(), R"(:3:46 error: instruction has operand which is not alive
%2:ptr<function, f32, read_write> = var, %3
- ^^^
+ ^^
:2:3 note: In block
%b1 = block {
@@ -727,7 +770,6 @@
TEST_F(IR_ValidateTest, Instruction_OperandUsageRemoved) {
auto* f = b.Function("my_func", ty.void_());
- mod.functions.Push(f);
auto sb = b.With(f->Block());
auto* v = sb.Var(ty.ptr<function, f32>());
@@ -759,7 +801,6 @@
TEST_F(IR_ValidateTest, Binary_LHS_Nullptr) {
auto* f = b.Function("my_func", ty.void_());
- mod.functions.Push(f);
auto sb = b.With(f->Block());
sb.Add(ty.i32(), nullptr, sb.Constant(2_i));
@@ -767,9 +808,9 @@
auto res = ir::Validate(mod);
ASSERT_FALSE(res);
- EXPECT_EQ(res.Failure().str(), R"(:3:5 error: binary: left operand is undefined
+ EXPECT_EQ(res.Failure().str(), R"(:3:18 error: binary: left operand is undefined
%2:i32 = add undef, 2i
- ^^^^^^^^^^^^^^^^^^^^^^
+ ^^^^^
:2:3 note: In block
%b1 = block {
@@ -787,7 +828,6 @@
TEST_F(IR_ValidateTest, Binary_RHS_Nullptr) {
auto* f = b.Function("my_func", ty.void_());
- mod.functions.Push(f);
auto sb = b.With(f->Block());
sb.Add(ty.i32(), sb.Constant(2_i), nullptr);
@@ -795,9 +835,9 @@
auto res = ir::Validate(mod);
ASSERT_FALSE(res);
- EXPECT_EQ(res.Failure().str(), R"(:3:5 error: binary: right operand is undefined
+ EXPECT_EQ(res.Failure().str(), R"(:3:22 error: binary: right operand is undefined
%2:i32 = add 2i, undef
- ^^^^^^^^^^^^^^^^^^^^^^
+ ^^^^^
:2:3 note: In block
%b1 = block {
@@ -818,7 +858,6 @@
b.Constant(3_i), b.Constant(2_i));
auto* f = b.Function("my_func", ty.void_());
- mod.functions.Push(f);
auto sb = b.With(f->Block());
sb.Append(bin);
@@ -826,9 +865,9 @@
auto res = ir::Validate(mod);
ASSERT_FALSE(res);
- EXPECT_EQ(res.Failure().str(), R"(:3:5 error: binary: result is undefined
+ EXPECT_EQ(res.Failure().str(), R"(:3:5 error: instruction result is undefined
undef = add 3i, 2i
- ^^^^^^^^^^^^^^^^^^
+ ^^^^^
:2:3 note: In block
%b1 = block {
@@ -844,5 +883,91 @@
)");
}
+TEST_F(IR_ValidateTest, Unary_Value_Nullptr) {
+ auto* f = b.Function("my_func", ty.void_());
+
+ auto sb = b.With(f->Block());
+ sb.Negation(ty.i32(), nullptr);
+ sb.Return(f);
+
+ auto res = ir::Validate(mod);
+ ASSERT_FALSE(res);
+ EXPECT_EQ(res.Failure().str(), R"(:3:23 error: unary: value operand is undefined
+ %2:i32 = negation undef
+ ^^^^^
+
+:2:3 note: In block
+ %b1 = block {
+ ^^^^^^^^^^^
+
+note: # Disassembly
+%my_func = func():void -> %b1 {
+ %b1 = block {
+ %2:i32 = negation undef
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_ValidateTest, Unary_Result_Nullptr) {
+ auto* bin =
+ mod.instructions.Create<ir::Unary>(nullptr, ir::Unary::Kind::kNegation, b.Constant(2_i));
+
+ auto* f = b.Function("my_func", ty.void_());
+
+ auto sb = b.With(f->Block());
+ sb.Append(bin);
+ sb.Return(f);
+
+ auto res = ir::Validate(mod);
+ ASSERT_FALSE(res);
+ EXPECT_EQ(res.Failure().str(), R"(:3:5 error: instruction result is undefined
+ undef = negation 2i
+ ^^^^^
+
+:2:3 note: In block
+ %b1 = block {
+ ^^^^^^^^^^^
+
+note: # Disassembly
+%my_func = func():void -> %b1 {
+ %b1 = block {
+ undef = negation 2i
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_ValidateTest, Unary_ResultTypeNotMatchValueType) {
+ auto* bin = b.Complement(ty.f32(), 2_i);
+
+ auto* f = b.Function("my_func", ty.void_());
+
+ auto sb = b.With(f->Block());
+ sb.Append(bin);
+ sb.Return(f);
+
+ auto res = ir::Validate(mod);
+ ASSERT_FALSE(res);
+ EXPECT_EQ(res.Failure().str(), R"(:3:5 error: unary: result type must match value type
+ %2:f32 = complement 2i
+ ^^^^^^^^^^^^^^^^^^^^^^
+
+:2:3 note: In block
+ %b1 = block {
+ ^^^^^^^^^^^
+
+note: # Disassembly
+%my_func = func():void -> %b1 {
+ %b1 = block {
+ %2:f32 = complement 2i
+ ret
+ }
+}
+)");
+}
+
} // namespace
} // namespace tint::ir
diff --git a/src/tint/ir/value.cc b/src/tint/ir/value.cc
index 1471681..dd93f4b 100644
--- a/src/tint/ir/value.cc
+++ b/src/tint/ir/value.cc
@@ -28,7 +28,7 @@
void Value::Destroy() {
TINT_ASSERT(IR, Alive());
TINT_ASSERT(IR, Usages().Count() == 0);
- alive_ = false;
+ flags_.Add(Flag::kDead);
}
void Value::ReplaceAllUsesWith(std::function<Value*(Usage use)> replacer) {
diff --git a/src/tint/ir/value.h b/src/tint/ir/value.h
index 0efa6d0..1f7f3b5 100644
--- a/src/tint/ir/value.h
+++ b/src/tint/ir/value.h
@@ -64,7 +64,7 @@
virtual void Destroy();
/// @returns true if the Value has not been destroyed with Destroy()
- bool Alive() const { return alive_; }
+ bool Alive() const { return !flags_.Contains(Flag::kDead); }
/// Adds a usage of this value.
/// @param u the usage
@@ -91,8 +91,16 @@
Value();
private:
+ /// Flags applied to an Value
+ enum class Flag {
+ /// The value has been destroyed
+ kDead,
+ };
+
utils::Hashset<Usage, 4, Usage::Hasher> uses_;
- bool alive_ = true;
+
+ /// Bitset of value flags
+ utils::EnumSet<Flag> flags_;
};
} // namespace tint::ir
diff --git a/src/tint/reader/wgsl/parser_impl.cc b/src/tint/reader/wgsl/parser_impl.cc
index 5b5c684..352fd47 100644
--- a/src/tint/reader/wgsl/parser_impl.cc
+++ b/src/tint/reader/wgsl/parser_impl.cc
@@ -3067,6 +3067,8 @@
return create<ast::GroupAttribute>(t.source(), args[0]);
case builtin::Attribute::kId:
return create<ast::IdAttribute>(t.source(), args[0]);
+ case builtin::Attribute::kIndex:
+ return create<ast::IndexAttribute>(t.source(), args[0]);
case builtin::Attribute::kInterpolate:
return create<ast::InterpolateAttribute>(t.source(), args[0],
args.Length() == 2 ? args[1] : nullptr);
diff --git a/src/tint/reader/wgsl/parser_impl_struct_member_attribute_test.cc b/src/tint/reader/wgsl/parser_impl_struct_member_attribute_test.cc
index 30834ad..0cc1138 100644
--- a/src/tint/reader/wgsl/parser_impl_struct_member_attribute_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_struct_member_attribute_test.cc
@@ -223,5 +223,22 @@
EXPECT_EQ(p->error(), "1:9: mixing '+' and '<<' requires parenthesis");
}
+TEST_F(ParserImplTest, Attribute_Index) {
+ auto p = parser("index(1)");
+ auto attr = p->attribute();
+ EXPECT_TRUE(attr.matched);
+ EXPECT_FALSE(attr.errored);
+ ASSERT_NE(attr.value, nullptr);
+ ASSERT_FALSE(p->has_error()) << p->error();
+
+ auto* member_attr = attr.value->As<ast::Attribute>();
+ ASSERT_NE(member_attr, nullptr);
+ ASSERT_TRUE(member_attr->Is<ast::IndexAttribute>());
+
+ auto* o = member_attr->As<ast::IndexAttribute>();
+ ASSERT_TRUE(o->expr->Is<ast::IntLiteralExpression>());
+ EXPECT_EQ(o->expr->As<ast::IntLiteralExpression>()->value, 1);
+}
+
} // namespace
} // namespace tint::reader::wgsl
diff --git a/src/tint/resolver/dual_source_blending_extension_test.cc b/src/tint/resolver/dual_source_blending_extension_test.cc
index 4974940..fe8ed1d 100644
--- a/src/tint/resolver/dual_source_blending_extension_test.cc
+++ b/src/tint/resolver/dual_source_blending_extension_test.cc
@@ -120,5 +120,17 @@
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
+// Using an index attribute on a global variable should pass. This is needed internally when using
+// @index with the canonicalize_entry_point transform. This test uses an internal attribute to
+// ignore address space, which is how it is used with the canonicalize_entry_point transform.
+TEST_F(DualSourceBlendingExtensionTests, GlobalVariableIndexAttribute) {
+ GlobalVar("var", ty.vec4<f32>(),
+ utils::Vector{Location(0_a), Index(0_a),
+ Disable(ast::DisabledValidation::kIgnoreAddressSpace)},
+ builtin::AddressSpace::kOut);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
} // namespace
} // namespace tint::resolver
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index 254c969..6e960ec 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -572,7 +572,7 @@
bool has_io_address_space = address_space == builtin::AddressSpace::kIn ||
address_space == builtin::AddressSpace::kOut;
- std::optional<uint32_t> group, binding, location;
+ std::optional<uint32_t> group, binding, location, index;
for (auto* attribute : var->attributes) {
Mark(attribute);
enum Status { kSuccess, kErrored, kInvalid };
@@ -605,6 +605,17 @@
location = value.Get();
return kSuccess;
},
+ [&](const ast::IndexAttribute* attr) {
+ if (!has_io_address_space) {
+ return kInvalid;
+ }
+ auto value = IndexAttribute(attr);
+ if (!value) {
+ return kErrored;
+ }
+ index = value.Get();
+ return kSuccess;
+ },
[&](const ast::BuiltinAttribute* attr) {
if (!has_io_address_space) {
return kInvalid;
@@ -645,7 +656,7 @@
}
sem = builder_->create<sem::GlobalVariable>(
var, var_ty, sem::EvaluationStage::kRuntime, address_space, access,
- /* constant_value */ nullptr, binding_point, location);
+ /* constant_value */ nullptr, binding_point, location, index);
} else {
for (auto* attribute : var->attributes) {
diff --git a/src/tint/sem/variable.cc b/src/tint/sem/variable.cc
index a220a80..95c4236 100644
--- a/src/tint/sem/variable.cc
+++ b/src/tint/sem/variable.cc
@@ -62,10 +62,12 @@
builtin::Access access,
const constant::Value* constant_value,
std::optional<sem::BindingPoint> binding_point,
- std::optional<uint32_t> location)
+ std::optional<uint32_t> location,
+ std::optional<uint32_t> index)
: Base(declaration, type, stage, address_space, access, constant_value),
binding_point_(binding_point),
- location_(location) {}
+ location_(location),
+ index_(index) {}
GlobalVariable::~GlobalVariable() = default;
diff --git a/src/tint/sem/variable.h b/src/tint/sem/variable.h
index 7660df3..cc8b661 100644
--- a/src/tint/sem/variable.h
+++ b/src/tint/sem/variable.h
@@ -156,6 +156,7 @@
/// @param constant_value the constant value for the variable. May be null
/// @param binding_point the optional resource binding point of the variable
/// @param location the location value if provided
+ /// @param index the index value if provided
///
/// Note, a GlobalVariable generally doesn't have a `location` in WGSL, as it isn't allowed by
/// the spec. The location maybe attached by transforms such as CanonicalizeEntryPointIO.
@@ -166,7 +167,8 @@
builtin::Access access,
const constant::Value* constant_value,
std::optional<sem::BindingPoint> binding_point = std::nullopt,
- std::optional<uint32_t> location = std::nullopt);
+ std::optional<uint32_t> location = std::nullopt,
+ std::optional<uint32_t> index = std::nullopt);
/// Destructor
~GlobalVariable() override;
@@ -188,6 +190,7 @@
tint::OverrideId override_id_;
std::optional<uint32_t> location_;
+ std::optional<uint32_t> index_;
};
/// Parameter is a function parameter
diff --git a/src/tint/transform/manager_test.cc b/src/tint/transform/manager_test.cc
index e578766..f8f8e0b 100644
--- a/src/tint/transform/manager_test.cc
+++ b/src/tint/transform/manager_test.cc
@@ -52,7 +52,6 @@
ir::Builder builder(*mod);
auto* func = builder.Function("ir_func", mod->Types().Get<type::Void>());
func->Block()->Append(builder.Return(func));
- mod->functions.Push(func);
}
};
#endif // TINT_BUILD_IR
@@ -69,7 +68,6 @@
ir::Builder builder(mod);
auto* func = builder.Function("main", mod.Types().Get<type::Void>());
func->Block()->Append(builder.Return(func));
- builder.ir.functions.Push(func);
return mod;
}
#endif // TINT_BUILD_IR
diff --git a/src/tint/type/manager.cc b/src/tint/type/manager.cc
index 17a3011..33bed10 100644
--- a/src/tint/type/manager.cc
+++ b/src/tint/type/manager.cc
@@ -172,7 +172,7 @@
return Get<type::Pointer>(address_space, subtype, access);
}
-const type::Struct* Manager::Struct(Symbol name, utils::VectorRef<StructMemberDesc> md) {
+type::Struct* Manager::Struct(Symbol name, utils::VectorRef<StructMemberDesc> md) {
utils::Vector<const type::StructMember*, 4> members;
uint32_t current_size = 0u;
uint32_t max_align = 0u;
diff --git a/src/tint/type/manager.h b/src/tint/type/manager.h
index 63e02e3..76cf959 100644
--- a/src/tint/type/manager.h
+++ b/src/tint/type/manager.h
@@ -429,13 +429,13 @@
/// @param name the name of the structure
/// @param members the list of structure member descriptors
/// @returns the structure type
- const type::Struct* Struct(Symbol name, utils::VectorRef<StructMemberDesc> members);
+ type::Struct* Struct(Symbol name, utils::VectorRef<StructMemberDesc> members);
/// Create a new structure declaration.
/// @param name the name of the structure
/// @param members the list of structure member descriptors
/// @returns the structure type
- const type::Struct* Struct(Symbol name, std::initializer_list<StructMemberDesc> members) {
+ type::Struct* Struct(Symbol name, std::initializer_list<StructMemberDesc> members) {
return Struct(name, utils::Vector<StructMemberDesc, 4>(members));
}
diff --git a/src/tint/type/struct.h b/src/tint/type/struct.h
index 907485c..5a6c4dd 100644
--- a/src/tint/type/struct.h
+++ b/src/tint/type/struct.h
@@ -15,11 +15,11 @@
#ifndef SRC_TINT_TYPE_STRUCT_H_
#define SRC_TINT_TYPE_STRUCT_H_
-#include <stdint.h>
-
+#include <cstdint>
#include <optional>
#include <string>
#include <unordered_set>
+#include <utility>
#include "src/tint/builtin/address_space.h"
#include "src/tint/builtin/interpolation.h"
@@ -245,6 +245,10 @@
/// @returns the optional attributes
const StructMemberAttributes& Attributes() const { return attributes_; }
+ /// Set the attributes of the struct member.
+ /// @param attributes the new attributes
+ void SetAttributes(StructMemberAttributes&& attributes) { attributes_ = std::move(attributes); }
+
/// @param ctx the clone context
/// @returns a clone of this struct member
StructMember* Clone(CloneContext& ctx) const;
@@ -257,7 +261,7 @@
const uint32_t offset_;
const uint32_t align_;
const uint32_t size_;
- const StructMemberAttributes attributes_;
+ StructMemberAttributes attributes_;
};
} // namespace tint::type
diff --git a/src/tint/utils/enum_set.h b/src/tint/utils/enum_set.h
index 575868b..c156d50 100644
--- a/src/tint/utils/enum_set.h
+++ b/src/tint/utils/enum_set.h
@@ -37,27 +37,27 @@
constexpr EnumSet() = default;
/// Copy constructor.
- /// @param s the set to copy
+ /// @param s the EnumSet to copy
constexpr EnumSet(const EnumSet& s) = default;
/// Constructor. Initializes the EnumSet with the given values.
- /// @param values the enumerator values to construct the set with
+ /// @param values the enumerator values to construct the EnumSet with
template <typename... VALUES>
explicit constexpr EnumSet(VALUES... values) : set(Union(values...)) {}
/// Copy assignment operator.
- /// @param set the set to assign to this set
- /// @return this set so calls can be chained
+ /// @param set the EnumSet to assign to this set
+ /// @returns this EnumSet so calls can be chained
inline EnumSet& operator=(const EnumSet& set) = default;
/// Copy assignment operator.
/// @param e the enum value
- /// @return this set so calls can be chained
+ /// @returns this EnumSet so calls can be chained
inline EnumSet& operator=(Enum e) { return *this = EnumSet{e}; }
/// Adds all the given values to this set
/// @param values the values to add
- /// @return this set so calls can be chained
+ /// @returns this EnumSet so calls can be chained
template <typename... VALUES>
inline EnumSet& Add(VALUES... values) {
return Add(EnumSet(std::forward<VALUES>(values)...));
@@ -65,24 +65,30 @@
/// Removes all the given values from this set
/// @param values the values to remove
- /// @return this set so calls can be chained
+ /// @returns this EnumSet so calls can be chained
template <typename... VALUES>
inline EnumSet& Remove(VALUES... values) {
return Remove(EnumSet(std::forward<VALUES>(values)...));
}
- /// Adds all of s to this set
+ /// Adds all of @p s to this set
/// @param s the enum value
- /// @return this set so calls can be chained
+ /// @returns this EnumSet so calls can be chained
inline EnumSet& Add(EnumSet s) { return (*this = *this + s); }
- /// Removes all of s from this set
+ /// Removes all of @p s from this set
/// @param s the enum value
- /// @return this set so calls can be chained
+ /// @returns this EnumSet so calls can be chained
inline EnumSet& Remove(EnumSet s) { return (*this = *this - s); }
+ /// Adds or removes @p e to the set
/// @param e the enum value
- /// @returns a copy of this set with e added
+ /// @param add if true the enum value is added, otherwise removed
+ /// @returns this EnumSet so calls can be chained
+ inline EnumSet& Set(Enum e, bool add = true) { return add ? Add(e) : Remove(e); }
+
+ /// @param e the enum value
+ /// @returns a copy of this EnumSet with @p e added
inline EnumSet operator+(Enum e) const {
EnumSet out;
out.set = set | Bit(e);
@@ -90,7 +96,7 @@
}
/// @param e the enum value
- /// @returns a copy of this set with e removed
+ /// @returns a copy of this EnumSet with @p e removed
inline EnumSet operator-(Enum e) const {
EnumSet out;
out.set = set & ~Bit(e);
@@ -98,7 +104,7 @@
}
/// @param s the other set
- /// @returns the union of this set with s (this ∪ rhs)
+ /// @returns the union of this EnumSet with @p s (`this` ∪ @p s)
inline EnumSet operator+(EnumSet s) const {
EnumSet out;
out.set = set | s.set;
@@ -106,7 +112,7 @@
}
/// @param s the other set
- /// @returns the set of entries found in this but not in s (this \ s)
+ /// @returns the set of entries found in this but not in s (`this` \ @p s)
inline EnumSet operator-(EnumSet s) const {
EnumSet out;
out.set = set & ~s.set;
@@ -114,7 +120,7 @@
}
/// @param s the other set
- /// @returns the intersection of this set with s (this ∩ rhs)
+ /// @returns the intersection of this EnumSet with s (`this` ∩ @p s)
inline EnumSet operator&(EnumSet s) const {
EnumSet out;
out.set = set & s.set;
@@ -122,7 +128,7 @@
}
/// @param e the enum value
- /// @return true if the set contains `e`
+ /// @return true if the set contains @p e
inline bool Contains(Enum e) const { return (set & Bit(e)) != 0; }
/// @return true if the set is empty
@@ -130,22 +136,22 @@
/// Equality operator
/// @param rhs the other EnumSet to compare this to
- /// @return true if this EnumSet is equal to rhs
+ /// @return true if this EnumSet is equal to @p rhs
inline bool operator==(EnumSet rhs) const { return set == rhs.set; }
/// Inequality operator
/// @param rhs the other EnumSet to compare this to
- /// @return true if this EnumSet is not equal to rhs
+ /// @return true if this EnumSet is not equal to @p rhs
inline bool operator!=(EnumSet rhs) const { return set != rhs.set; }
/// Equality operator
/// @param rhs the enum to compare this to
- /// @return true if this EnumSet only contains `rhs`
+ /// @return true if this EnumSet only contains @p rhs
inline bool operator==(Enum rhs) const { return set == Bit(rhs); }
/// Inequality operator
/// @param rhs the enum to compare this to
- /// @return false if this EnumSet only contains `rhs`
+ /// @return false if this EnumSet only contains @p rhs
inline bool operator!=(Enum rhs) const { return set != Bit(rhs); }
/// @return the underlying value for the EnumSet
diff --git a/src/tint/utils/enum_set_test.cc b/src/tint/utils/enum_set_test.cc
index 6e9b59c..cf0496a 100644
--- a/src/tint/utils/enum_set_test.cc
+++ b/src/tint/utils/enum_set_test.cc
@@ -126,6 +126,19 @@
EXPECT_FALSE(set.Contains(E::C));
}
+TEST(EnumSetTest, Set) {
+ EnumSet<E> set;
+ set.Set(E::B);
+ EXPECT_FALSE(set.Contains(E::A));
+ EXPECT_TRUE(set.Contains(E::B));
+ EXPECT_FALSE(set.Contains(E::C));
+
+ set.Set(E::B, false);
+ EXPECT_FALSE(set.Contains(E::A));
+ EXPECT_FALSE(set.Contains(E::B));
+ EXPECT_FALSE(set.Contains(E::C));
+}
+
TEST(EnumSetTest, OperatorPlusEnum) {
EnumSet<E> set = EnumSet<E>{E::B} + E::C;
EXPECT_FALSE(set.Contains(E::A));
diff --git a/src/tint/utils/slice.h b/src/tint/utils/slice.h
index 49483ec..efe2ac4 100644
--- a/src/tint/utils/slice.h
+++ b/src/tint/utils/slice.h
@@ -18,6 +18,7 @@
#include <cstdint>
#include <iterator>
+#include "src/tint/debug.h"
#include "src/tint/utils/bitcast.h"
#include "src/tint/utils/castable.h"
#include "src/tint/utils/traits.h"
@@ -186,24 +187,42 @@
/// Index operator
/// @param i the element index. Must be less than `len`.
/// @returns a reference to the i'th element.
- T& operator[](size_t i) { return data[i]; }
+ T& operator[](size_t i) {
+ TINT_ASSERT(Utils, i < Length());
+ return data[i];
+ }
/// Index operator
/// @param i the element index. Must be less than `len`.
/// @returns a reference to the i'th element.
- const T& operator[](size_t i) const { return data[i]; }
+ const T& operator[](size_t i) const {
+ TINT_ASSERT(Utils, i < Length());
+ return data[i];
+ }
/// @returns a reference to the first element in the vector
- T& Front() { return data[0]; }
+ T& Front() {
+ TINT_ASSERT(Utils, !IsEmpty());
+ return data[0];
+ }
/// @returns a reference to the first element in the vector
- const T& Front() const { return data[0]; }
+ const T& Front() const {
+ TINT_ASSERT(Utils, !IsEmpty());
+ return data[0];
+ }
/// @returns a reference to the last element in the vector
- T& Back() { return data[len - 1]; }
+ T& Back() {
+ TINT_ASSERT(Utils, !IsEmpty());
+ return data[len - 1];
+ }
/// @returns a reference to the last element in the vector
- const T& Back() const { return data[len - 1]; }
+ const T& Back() const {
+ TINT_ASSERT(Utils, !IsEmpty());
+ return data[len - 1];
+ }
/// @returns a pointer to the first element in the vector
T* begin() { return data; }
diff --git a/src/tint/utils/string.cc b/src/tint/utils/string.cc
index 00b51a8..320335d 100644
--- a/src/tint/utils/string.cc
+++ b/src/tint/utils/string.cc
@@ -25,7 +25,7 @@
const auto len_b = str_b.size();
Vector<size_t, 64> mat;
- mat.Reserve((len_a + 1) * (len_b + 1));
+ mat.Resize((len_a + 1) * (len_b + 1));
auto at = [&](size_t a, size_t b) -> size_t& { return mat[a + b * (len_a + 1)]; };
diff --git a/src/tint/utils/string_stream.h b/src/tint/utils/string_stream.h
index 2ab62a2..7485a9f 100644
--- a/src/tint/utils/string_stream.h
+++ b/src/tint/utils/string_stream.h
@@ -24,7 +24,6 @@
#include <utility>
#include "src/tint/utils/unicode.h"
-#include "src/tint/utils/vector.h"
namespace tint::utils {
@@ -193,44 +192,6 @@
/// @returns out so calls can be chained
utils::StringStream& operator<<(utils::StringStream& out, CodePoint codepoint);
-/// Prints the vector @p vec to @p o
-/// @param o the stream to write to
-/// @param vec the vector
-/// @return the stream so calls can be chained
-template <typename T, size_t N>
-inline utils::StringStream& operator<<(utils::StringStream& o, const utils::Vector<T, N>& vec) {
- o << "[";
- bool first = true;
- for (auto& el : vec) {
- if (!first) {
- o << ", ";
- }
- first = false;
- o << el;
- }
- o << "]";
- return o;
-}
-
-/// Prints the vector @p vec to @p o
-/// @param o the stream to write to
-/// @param vec the vector reference
-/// @return the stream so calls can be chained
-template <typename T>
-inline utils::StringStream& operator<<(utils::StringStream& o, utils::VectorRef<T> vec) {
- o << "[";
- bool first = true;
- for (auto& el : vec) {
- if (!first) {
- o << ", ";
- }
- first = false;
- o << el;
- }
- o << "]";
- return o;
-}
-
} // namespace tint::utils
#endif // SRC_TINT_UTILS_STRING_STREAM_H_
diff --git a/src/tint/utils/unique_vector.h b/src/tint/utils/unique_vector.h
index eca0d4c..22df7ec 100644
--- a/src/tint/utils/unique_vector.h
+++ b/src/tint/utils/unique_vector.h
@@ -45,7 +45,7 @@
}
}
- /// add appends the item to the end of the vector, if the vector does not
+ /// Add appends the item to the end of the vector, if the vector does not
/// already contain the given item.
/// @param item the item to append to the end of the vector
/// @returns true if the item was added, otherwise false.
@@ -57,6 +57,16 @@
return false;
}
+ /// Removes @p count elements from the vector
+ /// @param start the index of the first element to remove
+ /// @param count the number of elements to remove
+ void Erase(size_t start, size_t count = 1) {
+ for (size_t i = 0; i < count; i++) {
+ set.Remove(vector[start + i]);
+ }
+ vector.Erase(start, count);
+ }
+
/// @returns true if the vector contains `item`
/// @param item the item
bool Contains(const T& item) const { return set.Contains(item); }
@@ -72,6 +82,12 @@
/// @returns true if the vector is empty
bool IsEmpty() const { return vector.IsEmpty(); }
+ /// Removes all elements from the vector
+ void Clear() {
+ vector.Clear();
+ set.Clear();
+ }
+
/// @returns the number of items in the vector
size_t Length() const { return vector.Length(); }
@@ -111,6 +127,18 @@
return vector.Pop();
}
+ /// Removes the last element from the vector if it is equal to @p value
+ /// @param value the value to pop if it is at the back of the vector
+ /// @returns true if the value was popped, otherwise false
+ bool TryPop(T value) {
+ if (!vector.IsEmpty() && vector.Back() == value) {
+ set.Remove(vector.Back());
+ vector.Pop();
+ return true;
+ }
+ return false;
+ }
+
private:
Vector<T, N> vector;
Hashset<T, N, HASH, EQUAL> set;
diff --git a/src/tint/utils/unique_vector_test.cc b/src/tint/utils/unique_vector_test.cc
index c198615..1d8b2f8 100644
--- a/src/tint/utils/unique_vector_test.cc
+++ b/src/tint/utils/unique_vector_test.cc
@@ -25,14 +25,14 @@
TEST(UniqueVectorTest, Empty) {
UniqueVector<int, 4> unique_vec;
- EXPECT_EQ(unique_vec.Length(), 0u);
+ ASSERT_EQ(unique_vec.Length(), 0u);
EXPECT_EQ(unique_vec.IsEmpty(), true);
EXPECT_EQ(unique_vec.begin(), unique_vec.end());
}
TEST(UniqueVectorTest, MoveConstructor) {
UniqueVector<int, 4> unique_vec(std::vector<int>{0, 3, 2, 1, 2});
- EXPECT_EQ(unique_vec.Length(), 4u);
+ ASSERT_EQ(unique_vec.Length(), 4u);
EXPECT_EQ(unique_vec.IsEmpty(), false);
EXPECT_EQ(unique_vec[0], 0);
EXPECT_EQ(unique_vec[1], 3);
@@ -45,7 +45,7 @@
unique_vec.Add(0);
unique_vec.Add(1);
unique_vec.Add(2);
- EXPECT_EQ(unique_vec.Length(), 3u);
+ ASSERT_EQ(unique_vec.Length(), 3u);
EXPECT_EQ(unique_vec.IsEmpty(), false);
int i = 0;
for (auto n : unique_vec) {
@@ -69,7 +69,7 @@
unique_vec.Add(1);
unique_vec.Add(1);
unique_vec.Add(2);
- EXPECT_EQ(unique_vec.Length(), 3u);
+ ASSERT_EQ(unique_vec.Length(), 3u);
EXPECT_EQ(unique_vec.IsEmpty(), false);
int i = 0;
for (auto n : unique_vec) {
@@ -85,6 +85,87 @@
EXPECT_EQ(unique_vec[2], 2);
}
+TEST(UniqueVectorTest, Erase) {
+ UniqueVector<int, 4> unique_vec;
+ unique_vec.Add(0);
+ unique_vec.Add(3);
+ unique_vec.Add(2);
+ unique_vec.Add(5);
+ unique_vec.Add(1);
+ unique_vec.Add(6);
+ EXPECT_EQ(unique_vec.Length(), 6u);
+ EXPECT_EQ(unique_vec.IsEmpty(), false);
+
+ unique_vec.Erase(2, 2);
+
+ ASSERT_EQ(unique_vec.Length(), 4u);
+ EXPECT_EQ(unique_vec[0], 0);
+ EXPECT_EQ(unique_vec[1], 3);
+ EXPECT_EQ(unique_vec[2], 1);
+ EXPECT_EQ(unique_vec[3], 6);
+ EXPECT_TRUE(unique_vec.Contains(0));
+ EXPECT_TRUE(unique_vec.Contains(3));
+ EXPECT_FALSE(unique_vec.Contains(2));
+ EXPECT_FALSE(unique_vec.Contains(5));
+ EXPECT_TRUE(unique_vec.Contains(1));
+ EXPECT_TRUE(unique_vec.Contains(6));
+ EXPECT_EQ(unique_vec.IsEmpty(), false);
+
+ unique_vec.Erase(1);
+
+ ASSERT_EQ(unique_vec.Length(), 3u);
+ EXPECT_EQ(unique_vec[0], 0);
+ EXPECT_EQ(unique_vec[1], 1);
+ EXPECT_EQ(unique_vec[2], 6);
+ EXPECT_TRUE(unique_vec.Contains(0));
+ EXPECT_FALSE(unique_vec.Contains(3));
+ EXPECT_FALSE(unique_vec.Contains(2));
+ EXPECT_FALSE(unique_vec.Contains(5));
+ EXPECT_TRUE(unique_vec.Contains(1));
+ EXPECT_TRUE(unique_vec.Contains(6));
+ EXPECT_EQ(unique_vec.IsEmpty(), false);
+
+ unique_vec.Erase(2);
+
+ ASSERT_EQ(unique_vec.Length(), 2u);
+ EXPECT_EQ(unique_vec[0], 0);
+ EXPECT_EQ(unique_vec[1], 1);
+ EXPECT_TRUE(unique_vec.Contains(0));
+ EXPECT_FALSE(unique_vec.Contains(3));
+ EXPECT_FALSE(unique_vec.Contains(2));
+ EXPECT_FALSE(unique_vec.Contains(5));
+ EXPECT_TRUE(unique_vec.Contains(1));
+ EXPECT_FALSE(unique_vec.Contains(6));
+ EXPECT_EQ(unique_vec.IsEmpty(), false);
+
+ unique_vec.Erase(0, 2);
+
+ ASSERT_EQ(unique_vec.Length(), 0u);
+ EXPECT_FALSE(unique_vec.Contains(0));
+ EXPECT_FALSE(unique_vec.Contains(3));
+ EXPECT_FALSE(unique_vec.Contains(2));
+ EXPECT_FALSE(unique_vec.Contains(5));
+ EXPECT_FALSE(unique_vec.Contains(1));
+ EXPECT_FALSE(unique_vec.Contains(6));
+ EXPECT_EQ(unique_vec.IsEmpty(), true);
+
+ unique_vec.Add(6);
+ unique_vec.Add(0);
+ unique_vec.Add(2);
+
+ ASSERT_EQ(unique_vec.Length(), 3u);
+ EXPECT_EQ(unique_vec[0], 6);
+ EXPECT_EQ(unique_vec[1], 0);
+ EXPECT_EQ(unique_vec[2], 2);
+ EXPECT_TRUE(unique_vec.Contains(0));
+ EXPECT_FALSE(unique_vec.Contains(3));
+ EXPECT_TRUE(unique_vec.Contains(2));
+ EXPECT_FALSE(unique_vec.Contains(5));
+ EXPECT_FALSE(unique_vec.Contains(1));
+ EXPECT_TRUE(unique_vec.Contains(6));
+ EXPECT_EQ(unique_vec.IsEmpty(), false);
+}
+
TEST(UniqueVectorTest, AsVector) {
UniqueVector<int, 4> unique_vec;
unique_vec.Add(0);
@@ -115,25 +196,25 @@
unique_vec.Add(1);
EXPECT_EQ(unique_vec.Pop(), 1);
- EXPECT_EQ(unique_vec.Length(), 2u);
+ ASSERT_EQ(unique_vec.Length(), 2u);
EXPECT_EQ(unique_vec.IsEmpty(), false);
EXPECT_EQ(unique_vec[0], 0);
EXPECT_EQ(unique_vec[1], 2);
EXPECT_EQ(unique_vec.Pop(), 2);
- EXPECT_EQ(unique_vec.Length(), 1u);
+ ASSERT_EQ(unique_vec.Length(), 1u);
EXPECT_EQ(unique_vec.IsEmpty(), false);
EXPECT_EQ(unique_vec[0], 0);
unique_vec.Add(1);
- EXPECT_EQ(unique_vec.Length(), 2u);
+ ASSERT_EQ(unique_vec.Length(), 2u);
EXPECT_EQ(unique_vec.IsEmpty(), false);
EXPECT_EQ(unique_vec[0], 0);
EXPECT_EQ(unique_vec[1], 1);
EXPECT_EQ(unique_vec.Pop(), 1);
- EXPECT_EQ(unique_vec.Length(), 1u);
+ ASSERT_EQ(unique_vec.Length(), 1u);
EXPECT_EQ(unique_vec.IsEmpty(), false);
EXPECT_EQ(unique_vec[0], 0);
diff --git a/src/tint/utils/vector.h b/src/tint/utils/vector.h
index 78cddbd..aa17a9f 100644
--- a/src/tint/utils/vector.h
+++ b/src/tint/utils/vector.h
@@ -23,9 +23,11 @@
#include <utility>
#include <vector>
+#include "src/tint/debug.h"
#include "src/tint/utils/bitcast.h"
#include "src/tint/utils/compiler_macros.h"
#include "src/tint/utils/slice.h"
+#include "src/tint/utils/string_stream.h"
namespace tint::utils {
@@ -209,12 +211,18 @@
/// Index operator
/// @param i the element index. Must be less than `len`.
/// @returns a reference to the i'th element.
- T& operator[](size_t i) { return impl_.slice[i]; }
+ T& operator[](size_t i) {
+ TINT_ASSERT(Utils, i < Length());
+ return impl_.slice[i];
+ }
/// Index operator
/// @param i the element index. Must be less than `len`.
/// @returns a reference to the i'th element.
- const T& operator[](size_t i) const { return impl_.slice[i]; }
+ const T& operator[](size_t i) const {
+ TINT_ASSERT(Utils, i < Length());
+ return impl_.slice[i];
+ }
/// @return the number of elements in the vector
size_t Length() const { return impl_.slice.len; }
@@ -313,12 +321,32 @@
/// Removes and returns the last element from the vector.
/// @returns the popped element
T Pop() {
+ TINT_ASSERT(Utils, !IsEmpty());
auto& el = impl_.slice.data[--impl_.slice.len];
auto val = std::move(el);
el.~T();
return val;
}
+ /// Removes @p count elements from the vector
+ /// @param start the index of the first element to remove
+ /// @param count the number of elements to remove
+ void Erase(size_t start, size_t count = 1) {
+ TINT_ASSERT(Utils, start < Length());
+ TINT_ASSERT(Utils, (start + count) <= Length());
+ // Shuffle
+ for (size_t i = start + count; i < impl_.slice.len; i++) {
+ auto& src = impl_.slice.data[i];
+ auto& dst = impl_.slice.data[i - count];
+ dst = std::move(src);
+ }
+ // Pop
+ for (size_t i = 0; i < count; i++) {
+ auto& el = impl_.slice.data[--impl_.slice.len];
+ el.~T();
+ }
+ }
+
/// Sort sorts the vector in-place using the predicate function @p pred
/// @param pred a function that has the signature `bool(const T& a, const T& b)` which returns
/// true if `a` is ordered before `b`.
@@ -766,6 +794,44 @@
return out;
}
+/// Prints the vector @p vec to @p o
+/// @param o the stream to write to
+/// @param vec the vector
+/// @return the stream so calls can be chained
+template <typename T, size_t N>
+inline StringStream& operator<<(StringStream& o, const Vector<T, N>& vec) {
+ o << "[";
+ bool first = true;
+ for (auto& el : vec) {
+ if (!first) {
+ o << ", ";
+ }
+ first = false;
+ o << el;
+ }
+ o << "]";
+ return o;
+}
+
+/// Prints the vector @p vec to @p o
+/// @param o the stream to write to
+/// @param vec the vector reference
+/// @return the stream so calls can be chained
+template <typename T>
+inline StringStream& operator<<(StringStream& o, VectorRef<T> vec) {
+ o << "[";
+ bool first = true;
+ for (auto& el : vec) {
+ if (!first) {
+ o << ", ";
+ }
+ first = false;
+ o << el;
+ }
+ o << "]";
+ return o;
+}
+
namespace detail {
/// IsVectorLike<T>::value is true if T is a utils::Vector or utils::VectorRef.
diff --git a/src/tint/utils/vector_test.cc b/src/tint/utils/vector_test.cc
index d65f570..91f7166 100644
--- a/src/tint/utils/vector_test.cc
+++ b/src/tint/utils/vector_test.cc
@@ -151,6 +151,70 @@
EXPECT_TRUE(AllExternallyHeld(vec));
}
+TEST(TintVectorTest, Erase_Front) {
+ Vector<std::string, 3> vec;
+ vec.Push("one");
+ vec.Push("two");
+ vec.Push("three");
+ vec.Push("four");
+ EXPECT_EQ(vec.Length(), 4u);
+
+ vec.Erase(0);
+ EXPECT_EQ(vec.Length(), 3u);
+ EXPECT_EQ(vec[0], "two");
+ EXPECT_EQ(vec[1], "three");
+ EXPECT_EQ(vec[2], "four");
+
+ vec.Erase(0, 1);
+ EXPECT_EQ(vec.Length(), 2u);
+ EXPECT_EQ(vec[0], "three");
+ EXPECT_EQ(vec[1], "four");
+
+ vec.Erase(0, 2);
+ EXPECT_EQ(vec.Length(), 0u);
+}
+
+TEST(TintVectorTest, Erase_Mid) {
+ Vector<std::string, 5> vec;
+ vec.Push("one");
+ vec.Push("two");
+ vec.Push("three");
+ vec.Push("four");
+ vec.Push("five");
+ EXPECT_EQ(vec.Length(), 5u);
+
+ vec.Erase(1);
+ EXPECT_EQ(vec.Length(), 4u);
+ EXPECT_EQ(vec[0], "one");
+ EXPECT_EQ(vec[1], "three");
+ EXPECT_EQ(vec[2], "four");
+ EXPECT_EQ(vec[3], "five");
+
+ vec.Erase(1, 2);
+ EXPECT_EQ(vec.Length(), 2u);
+ EXPECT_EQ(vec[0], "one");
+ EXPECT_EQ(vec[1], "five");
+}
+
+TEST(TintVectorTest, Erase_Back) {
+ Vector<std::string, 3> vec;
+ vec.Push("one");
+ vec.Push("two");
+ vec.Push("three");
+ vec.Push("four");
+ EXPECT_EQ(vec.Length(), 4u);
+
+ vec.Erase(3);
+ EXPECT_EQ(vec.Length(), 3u);
+ EXPECT_EQ(vec[0], "one");
+ EXPECT_EQ(vec[1], "two");
+ EXPECT_EQ(vec[2], "three");
+
+ vec.Erase(1, 2);
+ EXPECT_EQ(vec.Length(), 1u);
+ EXPECT_EQ(vec[0], "one");
+}
+
TEST(TintVectorTest, InferTN_1CString) {
auto vec = Vector{"one"};
static_assert(std::is_same_v<decltype(vec)::value_type, const char*>);
diff --git a/src/tint/writer/hlsl/generator_impl.cc b/src/tint/writer/hlsl/generator_impl.cc
index a00c31a..357dd3a 100644
--- a/src/tint/writer/hlsl/generator_impl.cc
+++ b/src/tint/writer/hlsl/generator_impl.cc
@@ -1444,7 +1444,7 @@
const ast::transform::DecomposeMemoryAccess::Intrinsic* intrinsic) {
auto const buffer = intrinsic->Buffer()->identifier->symbol.Name();
auto* const offset = expr->args[0];
- auto* const value = expr->args[1];
+ auto* const value = expr->args.Length() > 1 ? expr->args[1] : nullptr;
using Op = ast::transform::DecomposeMemoryAccess::Intrinsic::Op;
using DataType = ast::transform::DecomposeMemoryAccess::Intrinsic::DataType;
diff --git a/src/tint/writer/msl/ir/generator_impl_ir.cc b/src/tint/writer/msl/ir/generator_impl_ir.cc
index 3209562..5c47493 100644
--- a/src/tint/writer/msl/ir/generator_impl_ir.cc
+++ b/src/tint/writer/msl/ir/generator_impl_ir.cc
@@ -14,6 +14,8 @@
#include "src/tint/writer/msl/ir/generator_impl_ir.h"
+#include "src/tint/constant/composite.h"
+#include "src/tint/constant/splat.h"
#include "src/tint/ir/constant.h"
#include "src/tint/ir/validate.h"
#include "src/tint/switch.h"
@@ -162,129 +164,13 @@
[&](const type::F16*) { out << "half"; }, //
[&](const type::I32*) { out << "int"; }, //
[&](const type::U32*) { out << "uint"; }, //
- [&](const type::Array* arr) {
- out << ArrayTemplateName() << "<";
- EmitType(out, arr->ElemType());
- out << ", ";
- if (arr->Count()->Is<type::RuntimeArrayCount>()) {
- out << "1";
- } else {
- auto count = arr->ConstantCount();
- if (!count) {
- diagnostics_.add_error(diag::System::Writer,
- type::Array::kErrExpectedConstantCount);
- return;
- }
- out << count.value();
- }
- out << ">";
- },
- [&](const type::Vector* vec) {
- if (vec->Packed()) {
- out << "packed_";
- }
- EmitType(out, vec->type());
- out << vec->Width();
- },
- [&](const type::Matrix* mat) {
- EmitType(out, mat->type());
- out << mat->columns() << "x" << mat->rows();
- },
- [&](const type::Atomic* atomic) {
- if (atomic->Type()->Is<type::I32>()) {
- out << "atomic_int";
- return;
- }
- if (TINT_LIKELY(atomic->Type()->Is<type::U32>())) {
- out << "atomic_uint";
- return;
- }
- TINT_ICE(Writer, diagnostics_)
- << "unhandled atomic type " << atomic->Type()->FriendlyName();
- },
- [&](const type::Pointer* ptr) {
- if (ptr->Access() == builtin::Access::kRead) {
- out << "const ";
- }
- EmitAddressSpace(out, ptr->AddressSpace());
- out << " ";
- EmitType(out, ptr->StoreType());
- out << "*";
- },
+ [&](const type::Array* arr) { EmitArrayType(out, arr); },
+ [&](const type::Vector* vec) { EmitVectorType(out, vec); },
+ [&](const type::Matrix* mat) { EmitMatrixType(out, mat); },
+ [&](const type::Atomic* atomic) { EmitAtomicType(out, atomic); },
+ [&](const type::Pointer* ptr) { EmitPointerType(out, ptr); },
[&](const type::Sampler*) { out << "sampler"; }, //
- [&](const type::Texture* tex) {
- if (TINT_UNLIKELY(tex->Is<type::ExternalTexture>())) {
- TINT_ICE(Writer, diagnostics_)
- << "Multiplanar external texture transform was not run.";
- return;
- }
-
- if (tex->IsAnyOf<type::DepthTexture, type::DepthMultisampledTexture>()) {
- out << "depth";
- } else {
- out << "texture";
- }
-
- switch (tex->dim()) {
- case type::TextureDimension::k1d:
- out << "1d";
- break;
- case type::TextureDimension::k2d:
- out << "2d";
- break;
- case type::TextureDimension::k2dArray:
- out << "2d_array";
- break;
- case type::TextureDimension::k3d:
- out << "3d";
- break;
- case type::TextureDimension::kCube:
- out << "cube";
- break;
- case type::TextureDimension::kCubeArray:
- out << "cube_array";
- break;
- default:
- diagnostics_.add_error(diag::System::Writer, "Invalid texture dimensions");
- return;
- }
- if (tex->IsAnyOf<type::MultisampledTexture, type::DepthMultisampledTexture>()) {
- out << "_ms";
- }
- out << "<";
- TINT_DEFER(out << ">");
-
- tint::Switch(
- tex, //
- [&](const type::DepthTexture*) { out << "float, access::sample"; },
- [&](const type::DepthMultisampledTexture*) { out << "float, access::read"; },
- [&](const type::StorageTexture* storage) {
- EmitType(out, storage->type());
- out << ", ";
-
- std::string access_str;
- if (storage->access() == builtin::Access::kRead) {
- out << "access::read";
- } else if (storage->access() == builtin::Access::kWrite) {
- out << "access::write";
- } else {
- diagnostics_.add_error(diag::System::Writer,
- "Invalid access control for storage texture");
- return;
- }
- },
- [&](const type::MultisampledTexture* ms) {
- EmitType(out, ms->type());
- out << ", access::read";
- },
- [&](const type::SampledTexture* sampled) {
- EmitType(out, sampled->type());
- out << ", access::sample";
- },
- [&](Default) {
- diagnostics_.add_error(diag::System::Writer, "invalid texture type");
- });
- },
+ [&](const type::Texture* tex) { EmitTextureType(out, tex); },
[&](const type::Struct* str) {
out << StructName(str);
@@ -294,6 +180,129 @@
[&](Default) { UNHANDLED_CASE(ty); });
}
+void GeneratorImplIr::EmitPointerType(utils::StringStream& out, const type::Pointer* ptr) {
+ if (ptr->Access() == builtin::Access::kRead) {
+ out << "const ";
+ }
+ EmitAddressSpace(out, ptr->AddressSpace());
+ out << " ";
+ EmitType(out, ptr->StoreType());
+ out << "*";
+}
+
+void GeneratorImplIr::EmitAtomicType(utils::StringStream& out, const type::Atomic* atomic) {
+ if (atomic->Type()->Is<type::I32>()) {
+ out << "atomic_int";
+ return;
+ }
+ if (TINT_LIKELY(atomic->Type()->Is<type::U32>())) {
+ out << "atomic_uint";
+ return;
+ }
+ TINT_ICE(Writer, diagnostics_) << "unhandled atomic type " << atomic->Type()->FriendlyName();
+}
+
+void GeneratorImplIr::EmitArrayType(utils::StringStream& out, const type::Array* arr) {
+ out << ArrayTemplateName() << "<";
+ EmitType(out, arr->ElemType());
+ out << ", ";
+ if (arr->Count()->Is<type::RuntimeArrayCount>()) {
+ out << "1";
+ } else {
+ auto count = arr->ConstantCount();
+ if (!count) {
+ diagnostics_.add_error(diag::System::Writer, type::Array::kErrExpectedConstantCount);
+ return;
+ }
+ out << count.value();
+ }
+ out << ">";
+}
+
+void GeneratorImplIr::EmitVectorType(utils::StringStream& out, const type::Vector* vec) {
+ if (vec->Packed()) {
+ out << "packed_";
+ }
+ EmitType(out, vec->type());
+ out << vec->Width();
+}
+
+void GeneratorImplIr::EmitMatrixType(utils::StringStream& out, const type::Matrix* mat) {
+ EmitType(out, mat->type());
+ out << mat->columns() << "x" << mat->rows();
+}
+
+void GeneratorImplIr::EmitTextureType(utils::StringStream& out, const type::Texture* tex) {
+ if (TINT_UNLIKELY(tex->Is<type::ExternalTexture>())) {
+ TINT_ICE(Writer, diagnostics_) << "Multiplanar external texture transform was not run.";
+ return;
+ }
+
+ if (tex->IsAnyOf<type::DepthTexture, type::DepthMultisampledTexture>()) {
+ out << "depth";
+ } else {
+ out << "texture";
+ }
+
+ switch (tex->dim()) {
+ case type::TextureDimension::k1d:
+ out << "1d";
+ break;
+ case type::TextureDimension::k2d:
+ out << "2d";
+ break;
+ case type::TextureDimension::k2dArray:
+ out << "2d_array";
+ break;
+ case type::TextureDimension::k3d:
+ out << "3d";
+ break;
+ case type::TextureDimension::kCube:
+ out << "cube";
+ break;
+ case type::TextureDimension::kCubeArray:
+ out << "cube_array";
+ break;
+ default:
+ diagnostics_.add_error(diag::System::Writer, "Invalid texture dimensions");
+ return;
+ }
+ if (tex->IsAnyOf<type::MultisampledTexture, type::DepthMultisampledTexture>()) {
+ out << "_ms";
+ }
+ out << "<";
+ TINT_DEFER(out << ">");
+
+ tint::Switch(
+ tex, //
+ [&](const type::DepthTexture*) { out << "float, access::sample"; },
+ [&](const type::DepthMultisampledTexture*) { out << "float, access::read"; },
+ [&](const type::StorageTexture* storage) {
+ EmitType(out, storage->type());
+ out << ", ";
+
+ std::string access_str;
+ if (storage->access() == builtin::Access::kRead) {
+ out << "access::read";
+ } else if (storage->access() == builtin::Access::kWrite) {
+ out << "access::write";
+ } else {
+ diagnostics_.add_error(diag::System::Writer,
+ "Invalid access control for storage texture");
+ return;
+ }
+ },
+ [&](const type::MultisampledTexture* ms) {
+ EmitType(out, ms->type());
+ out << ", access::read";
+ },
+ [&](const type::SampledTexture* sampled) {
+ EmitType(out, sampled->type());
+ out << ", access::sample";
+ },
+ [&](Default) { diagnostics_.add_error(diag::System::Writer, "invalid texture type"); });
+}
+
void GeneratorImplIr::EmitStructType(const type::Struct* str) {
auto it = emitted_structs_.emplace(str);
if (!it.second) {
@@ -431,14 +440,77 @@
}
void GeneratorImplIr::EmitConstant(utils::StringStream& out, ir::Constant* c) {
- return tint::Switch(
+ EmitConstant(out, c->Value());
+}
+
+void GeneratorImplIr::EmitConstant(utils::StringStream& out, const constant::Value* c) {
+ auto emit_values = [&](uint32_t count) {
+ for (size_t i = 0; i < count; i++) {
+ if (i > 0) {
+ out << ", ";
+ }
+ EmitConstant(out, c->Index(i));
+ }
+ };
+
+ tint::Switch(
c->Type(), //
- [&](const type::Bool*) { out << (c->Value()->ValueAs<bool>() ? "true" : "false"); },
- [&](const type::I32*) { PrintI32(out, c->Value()->ValueAs<i32>()); },
- [&](const type::U32*) { out << c->Value()->ValueAs<u32>() << "u"; },
- [&](const type::F32*) { PrintF32(out, c->Value()->ValueAs<f32>()); },
- [&](const type::F16*) { PrintF16(out, c->Value()->ValueAs<f16>()); },
- [&](Default) { UNHANDLED_CASE(c); });
+ [&](const type::Bool*) { out << (c->ValueAs<bool>() ? "true" : "false"); },
+ [&](const type::I32*) { PrintI32(out, c->ValueAs<i32>()); },
+ [&](const type::U32*) { out << c->ValueAs<u32>() << "u"; },
+ [&](const type::F32*) { PrintF32(out, c->ValueAs<f32>()); },
+ [&](const type::F16*) { PrintF16(out, c->ValueAs<f16>()); },
+ [&](const type::Vector* v) {
+ EmitType(out, v);
+
+ ScopedParen sp(out);
+ if (auto* splat = c->As<constant::Splat>()) {
+ EmitConstant(out, splat->el);
+ return;
+ }
+ emit_values(v->Width());
+ },
+ [&](const type::Matrix* m) {
+ EmitType(out, m);
+ ScopedParen sp(out);
+ emit_values(m->columns());
+ },
+ [&](const type::Array* a) {
+ EmitType(out, a);
+ out << "{";
+ TINT_DEFER(out << "}");
+
+ if (c->AllZero()) {
+ return;
+ }
+
+ auto count = a->ConstantCount();
+ if (!count) {
+ diagnostics_.add_error(diag::System::Writer,
+ type::Array::kErrExpectedConstantCount);
+ return;
+ }
+ emit_values(*count);
+ },
+ [&](const type::Struct* s) {
+ EmitStructType(s);
+ out << StructName(s) << "{";
+ TINT_DEFER(out << "}");
+
+ if (c->AllZero()) {
+ return;
+ }
+
+ auto members = s->Members();
+ for (size_t i = 0; i < members.Length(); i++) {
+ if (i > 0) {
+ out << ", ";
+ }
+ out << "." << members[i]->Name().Name() << "=";
+ EmitConstant(out, c->Index(i));
+ }
+ },
+ [&](Default) { UNHANDLED_CASE(c->Type()); });
}
} // namespace tint::writer::msl
diff --git a/src/tint/writer/msl/ir/generator_impl_ir.h b/src/tint/writer/msl/ir/generator_impl_ir.h
index 89ec1a1..80f0cb2 100644
--- a/src/tint/writer/msl/ir/generator_impl_ir.h
+++ b/src/tint/writer/msl/ir/generator_impl_ir.h
@@ -20,6 +20,7 @@
#include "src/tint/diagnostic/diagnostic.h"
#include "src/tint/ir/module.h"
+#include "src/tint/type/texture.h"
#include "src/tint/utils/string_stream.h"
#include "src/tint/writer/ir_text_generator.h"
@@ -45,6 +46,30 @@
/// @param ty the type to emit
void EmitType(utils::StringStream& out, const type::Type* ty);
+ /// Handles generating an array declaration
+ /// @param out the output stream
+ /// @param arr the array to emit
+ void EmitArrayType(utils::StringStream& out, const type::Array* arr);
+ /// Handles generating an atomic declaration
+ /// @param out the output stream
+ /// @param atomic the atomic to emit
+ void EmitAtomicType(utils::StringStream& out, const type::Atomic* atomic);
+ /// Handles generating a pointer declaration
+ /// @param out the output stream
+ /// @param ptr the pointer to emit
+ void EmitPointerType(utils::StringStream& out, const type::Pointer* ptr);
+ /// Handles generating a vector declaration
+ /// @param out the output stream
+ /// @param vec the vector to emit
+ void EmitVectorType(utils::StringStream& out, const type::Vector* vec);
+ /// Handles generating a matrix declaration
+ /// @param out the output stream
+ /// @param mat the matrix to emit
+ void EmitMatrixType(utils::StringStream& out, const type::Matrix* mat);
+ /// Handles generating a texture declaration
+ /// @param out the output stream
+ /// @param tex the texture to emit
+ void EmitTextureType(utils::StringStream& out, const type::Texture* tex);
/// Handles generating a struct declaration. If the structure has already been emitted, then
/// this function will simply return without emitting anything.
/// @param str the struct to generate
@@ -59,6 +84,10 @@
/// @param out the stream to write the constant too
/// @param c the constant to emit
void EmitConstant(utils::StringStream& out, ir::Constant* c);
+ /// Handles constant::Value values
+ /// @param out the stream to write the constant too
+ /// @param c the constant to emit
+ void EmitConstant(utils::StringStream& out, const constant::Value* c);
/// @returns the name of the templated `tint_array` helper type, generating it if needed
const std::string& ArrayTemplateName();
diff --git a/src/tint/writer/msl/ir/generator_impl_ir_constant_test.cc b/src/tint/writer/msl/ir/generator_impl_ir_constant_test.cc
index 41a6139..9cdd570 100644
--- a/src/tint/writer/msl/ir/generator_impl_ir_constant_test.cc
+++ b/src/tint/writer/msl/ir/generator_impl_ir_constant_test.cc
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include "src/tint/type/array.h"
+#include "src/tint/type/matrix.h"
#include "src/tint/utils/string.h"
#include "src/tint/writer/msl/ir/test_helper_ir.h"
@@ -62,5 +64,251 @@
EXPECT_EQ(utils::TrimSpace(generator_.Result()), std::string("32752.0h"));
}
+TEST_F(MslGeneratorImplIrTest, Constant_Vector_Splat) {
+ auto* c = b.Constant(mod.constant_values.Splat(ty.vec3<f32>(), b.Constant(1.5_f)->Value(), 3));
+ generator_.EmitConstant(generator_.Line(), c);
+ ASSERT_TRUE(generator_.Diagnostics().empty()) << generator_.Diagnostics().str();
+ EXPECT_EQ(utils::TrimSpace(generator_.Result()), std::string("float3(1.5f)"));
+}
+
+TEST_F(MslGeneratorImplIrTest, Constant_Vector_Composite) {
+ auto* c = b.Constant(mod.constant_values.Composite(
+ ty.vec3<f32>(), utils::Vector{b.Constant(1.5_f)->Value(), b.Constant(1.0_f)->Value(),
+ b.Constant(1.5_f)->Value()}));
+ generator_.EmitConstant(generator_.Line(), c);
+ ASSERT_TRUE(generator_.Diagnostics().empty()) << generator_.Diagnostics().str();
+ EXPECT_EQ(utils::TrimSpace(generator_.Result()), std::string("float3(1.5f, 1.0f, 1.5f)"));
+}
+
+TEST_F(MslGeneratorImplIrTest, Constant_Vector_Composite_AnyZero) {
+ auto* c = b.Constant(mod.constant_values.Composite(
+ ty.vec3<f32>(), utils::Vector{b.Constant(1.0_f)->Value(), b.Constant(0.0_f)->Value(),
+ b.Constant(1.5_f)->Value()}));
+ generator_.EmitConstant(generator_.Line(), c);
+ ASSERT_TRUE(generator_.Diagnostics().empty()) << generator_.Diagnostics().str();
+ EXPECT_EQ(utils::TrimSpace(generator_.Result()), std::string("float3(1.0f, 0.0f, 1.5f)"));
+}
+
+TEST_F(MslGeneratorImplIrTest, Constant_Vector_Composite_AllZero) {
+ auto* c = b.Constant(mod.constant_values.Composite(
+ ty.vec3<f32>(), utils::Vector{b.Constant(0.0_f)->Value(), b.Constant(0.0_f)->Value(),
+ b.Constant(0.0_f)->Value()}));
+ generator_.EmitConstant(generator_.Line(), c);
+ ASSERT_TRUE(generator_.Diagnostics().empty()) << generator_.Diagnostics().str();
+ EXPECT_EQ(utils::TrimSpace(generator_.Result()), std::string("float3(0.0f)"));
+}
+
+TEST_F(MslGeneratorImplIrTest, Constant_Matrix_Splat) {
+ auto* c =
+ b.Constant(mod.constant_values.Splat(ty.mat3x2<f32>(), b.Constant(1.5_f)->Value(), 3));
+ generator_.EmitConstant(generator_.Line(), c);
+ ASSERT_TRUE(generator_.Diagnostics().empty()) << generator_.Diagnostics().str();
+ EXPECT_EQ(utils::TrimSpace(generator_.Result()), std::string("float3x2(1.5f, 1.5f, 1.5f)"));
+}
+
+TEST_F(MslGeneratorImplIrTest, Constant_Matrix_Composite) {
+ auto* c = b.Constant(mod.constant_values.Composite(
+ ty.mat3x2<f32>(),
+ utils::Vector{mod.constant_values.Composite(
+ ty.vec2<f32>(),
+ utils::Vector{b.Constant(1.5_f)->Value(), b.Constant(1.0_f)->Value()}),
+ mod.constant_values.Composite(
+ ty.vec2<f32>(),
+ utils::Vector{b.Constant(1.5_f)->Value(), b.Constant(2.0_f)->Value()}),
+ mod.constant_values.Composite(
+ ty.vec2<f32>(),
+ utils::Vector{b.Constant(2.5_f)->Value(), b.Constant(3.5_f)->Value()})}));
+ generator_.EmitConstant(generator_.Line(), c);
+ ASSERT_TRUE(generator_.Diagnostics().empty()) << generator_.Diagnostics().str();
+ EXPECT_EQ(utils::TrimSpace(generator_.Result()),
+ std::string("float3x2(float2(1.5f, 1.0f), float2(1.5f, 2.0f), float2(2.5f, 3.5f))"));
+}
+
+TEST_F(MslGeneratorImplIrTest, Constant_Matrix_Composite_AnyZero) {
+ auto* c = b.Constant(mod.constant_values.Composite(
+ ty.mat2x2<f32>(),
+ utils::Vector{mod.constant_values.Composite(
+ ty.vec2<f32>(),
+ utils::Vector{b.Constant(1.0_f)->Value(), b.Constant(0.0_f)->Value()}),
+ mod.constant_values.Composite(
+ ty.vec2<f32>(),
+ utils::Vector{b.Constant(1.5_f)->Value(), b.Constant(2.5_f)->Value()})}));
+ generator_.EmitConstant(generator_.Line(), c);
+ ASSERT_TRUE(generator_.Diagnostics().empty()) << generator_.Diagnostics().str();
+ EXPECT_EQ(utils::TrimSpace(generator_.Result()),
+ std::string("float2x2(float2(1.0f, 0.0f), float2(1.5f, 2.5f))"));
+}
+
+TEST_F(MslGeneratorImplIrTest, Constant_Matrix_Composite_AllZero) {
+ auto* c = b.Constant(mod.constant_values.Composite(
+ ty.mat3x2<f32>(),
+ utils::Vector{mod.constant_values.Composite(
+ ty.vec2<f32>(),
+ utils::Vector{b.Constant(0.0_f)->Value(), b.Constant(0.0_f)->Value()}),
+ mod.constant_values.Composite(
+ ty.vec2<f32>(),
+ utils::Vector{b.Constant(0.0_f)->Value(), b.Constant(0.0_f)->Value()}),
+ mod.constant_values.Composite(
+ ty.vec2<f32>(),
+ utils::Vector{b.Constant(0.0_f)->Value(), b.Constant(0.0_f)->Value()})}));
+ generator_.EmitConstant(generator_.Line(), c);
+ ASSERT_TRUE(generator_.Diagnostics().empty()) << generator_.Diagnostics().str();
+ EXPECT_EQ(utils::TrimSpace(generator_.Result()),
+ std::string("float3x2(float2(0.0f), float2(0.0f), float2(0.0f))"));
+}
+
+TEST_F(MslGeneratorImplIrTest, Constant_Array_Splat) {
+ auto* c =
+ b.Constant(mod.constant_values.Splat(ty.array<f32, 3>(), b.Constant(1.5_f)->Value(), 3));
+ generator_.EmitConstant(generator_.Line(), c);
+ ASSERT_TRUE(generator_.Diagnostics().empty()) << generator_.Diagnostics().str();
+ EXPECT_EQ(utils::TrimSpace(generator_.Result()), R"(template<typename T, size_t N>
+struct tint_array {
+ const constant T& operator[](size_t i) const constant { return elements[i]; }
+ device T& operator[](size_t i) device { return elements[i]; }
+ const device T& operator[](size_t i) const device { return elements[i]; }
+ thread T& operator[](size_t i) thread { return elements[i]; }
+ const thread T& operator[](size_t i) const thread { return elements[i]; }
+ threadgroup T& operator[](size_t i) threadgroup { return elements[i]; }
+ const threadgroup T& operator[](size_t i) const threadgroup { return elements[i]; }
+ T elements[N];
+};
+
+
+tint_array<float, 3>{1.5f, 1.5f, 1.5f})");
+}
+
+TEST_F(MslGeneratorImplIrTest, Constant_Array_Composite) {
+ auto* c = b.Constant(mod.constant_values.Composite(
+ ty.array<f32, 3>(), utils::Vector{b.Constant(1.5_f)->Value(), b.Constant(1.0_f)->Value(),
+ b.Constant(2.0_f)->Value()}));
+ generator_.EmitConstant(generator_.Line(), c);
+ ASSERT_TRUE(generator_.Diagnostics().empty()) << generator_.Diagnostics().str();
+ EXPECT_EQ(utils::TrimSpace(generator_.Result()), std::string(R"(template<typename T, size_t N>
+struct tint_array {
+ const constant T& operator[](size_t i) const constant { return elements[i]; }
+ device T& operator[](size_t i) device { return elements[i]; }
+ const device T& operator[](size_t i) const device { return elements[i]; }
+ thread T& operator[](size_t i) thread { return elements[i]; }
+ const thread T& operator[](size_t i) const thread { return elements[i]; }
+ threadgroup T& operator[](size_t i) threadgroup { return elements[i]; }
+ const threadgroup T& operator[](size_t i) const threadgroup { return elements[i]; }
+ T elements[N];
+};
+
+
+tint_array<float, 3>{1.5f, 1.0f, 2.0f})"));
+}
+
+TEST_F(MslGeneratorImplIrTest, Constant_Array_Composite_AnyZero) {
+ auto* c = b.Constant(mod.constant_values.Composite(
+ ty.array<f32, 2>(), utils::Vector{b.Constant(1.0_f)->Value(), b.Constant(0.0_f)->Value()}));
+ generator_.EmitConstant(generator_.Line(), c);
+ ASSERT_TRUE(generator_.Diagnostics().empty()) << generator_.Diagnostics().str();
+ EXPECT_EQ(utils::TrimSpace(generator_.Result()), R"(template<typename T, size_t N>
+struct tint_array {
+ const constant T& operator[](size_t i) const constant { return elements[i]; }
+ device T& operator[](size_t i) device { return elements[i]; }
+ const device T& operator[](size_t i) const device { return elements[i]; }
+ thread T& operator[](size_t i) thread { return elements[i]; }
+ const thread T& operator[](size_t i) const thread { return elements[i]; }
+ threadgroup T& operator[](size_t i) threadgroup { return elements[i]; }
+ const threadgroup T& operator[](size_t i) const threadgroup { return elements[i]; }
+ T elements[N];
+};
+
+
+tint_array<float, 2>{1.0f, 0.0f})");
+}
+
+TEST_F(MslGeneratorImplIrTest, Constant_Array_Composite_AllZero) {
+ auto* c = b.Constant(mod.constant_values.Composite(
+ ty.array<f32, 3>(), utils::Vector{b.Constant(0.0_f)->Value(), b.Constant(0.0_f)->Value(),
+ b.Constant(0.0_f)->Value()}));
+ generator_.EmitConstant(generator_.Line(), c);
+ ASSERT_TRUE(generator_.Diagnostics().empty()) << generator_.Diagnostics().str();
+ EXPECT_EQ(utils::TrimSpace(generator_.Result()), R"(template<typename T, size_t N>
+struct tint_array {
+ const constant T& operator[](size_t i) const constant { return elements[i]; }
+ device T& operator[](size_t i) device { return elements[i]; }
+ const device T& operator[](size_t i) const device { return elements[i]; }
+ thread T& operator[](size_t i) thread { return elements[i]; }
+ const thread T& operator[](size_t i) const thread { return elements[i]; }
+ threadgroup T& operator[](size_t i) threadgroup { return elements[i]; }
+ const threadgroup T& operator[](size_t i) const threadgroup { return elements[i]; }
+ T elements[N];
+};
+
+
+tint_array<float, 3>{})");
+}
+
+TEST_F(MslGeneratorImplIrTest, Constant_Struct_Splat) {
+ auto* s = ty.Struct(mod.symbols.New("S"), {
+ {mod.symbols.Register("a"), ty.f32()},
+ {mod.symbols.Register("b"), ty.f32()},
+ });
+ auto* c = b.Constant(mod.constant_values.Splat(s, b.Constant(1.5_f)->Value(), 2));
+ generator_.EmitConstant(generator_.Line(), c);
+ ASSERT_TRUE(generator_.Diagnostics().empty()) << generator_.Diagnostics().str();
+ EXPECT_EQ(utils::TrimSpace(generator_.Result()), std::string(R"(struct S {
+ float a;
+ float b;
+};
+
+S{.a=1.5f, .b=1.5f})"));
+}
+
+TEST_F(MslGeneratorImplIrTest, Constant_Struct_Composite) {
+ auto* s = ty.Struct(mod.symbols.New("S"), {
+ {mod.symbols.Register("a"), ty.f32()},
+ {mod.symbols.Register("b"), ty.f32()},
+ });
+ auto* c = b.Constant(mod.constant_values.Composite(
+ s, utils::Vector{b.Constant(1.5_f)->Value(), b.Constant(1.0_f)->Value()}));
+ generator_.EmitConstant(generator_.Line(), c);
+ ASSERT_TRUE(generator_.Diagnostics().empty()) << generator_.Diagnostics().str();
+ EXPECT_EQ(utils::TrimSpace(generator_.Result()), std::string(R"(struct S {
+ float a;
+ float b;
+};
+
+S{.a=1.5f, .b=1.0f})"));
+}
+
+TEST_F(MslGeneratorImplIrTest, Constant_Struct_Composite_AnyZero) {
+ auto* s = ty.Struct(mod.symbols.New("S"), {
+ {mod.symbols.Register("a"), ty.f32()},
+ {mod.symbols.Register("b"), ty.f32()},
+ });
+ auto* c = b.Constant(mod.constant_values.Composite(
+ s, utils::Vector{b.Constant(1.0_f)->Value(), b.Constant(0.0_f)->Value()}));
+ generator_.EmitConstant(generator_.Line(), c);
+ ASSERT_TRUE(generator_.Diagnostics().empty()) << generator_.Diagnostics().str();
+ EXPECT_EQ(utils::TrimSpace(generator_.Result()), std::string(R"(struct S {
+ float a;
+ float b;
+};
+
+S{.a=1.0f, .b=0.0f})"));
+}
+
+TEST_F(MslGeneratorImplIrTest, Constant_Struct_Composite_AllZero) {
+ auto* s = ty.Struct(mod.symbols.New("S"), {
+ {mod.symbols.Register("a"), ty.f32()},
+ {mod.symbols.Register("b"), ty.f32()},
+ });
+ auto* c = b.Constant(mod.constant_values.Composite(
+ s, utils::Vector{b.Constant(0.0_f)->Value(), b.Constant(0.0_f)->Value()}));
+ generator_.EmitConstant(generator_.Line(), c);
+ ASSERT_TRUE(generator_.Diagnostics().empty()) << generator_.Diagnostics().str();
+ EXPECT_EQ(utils::TrimSpace(generator_.Result()), std::string(R"(struct S {
+ float a;
+ float b;
+};
+
+S{})"));
+}
+
} // namespace
} // namespace tint::writer::msl
diff --git a/src/tint/writer/msl/ir/test_helper_ir.h b/src/tint/writer/msl/ir/test_helper_ir.h
index 38bbf9b..cf74688 100644
--- a/src/tint/writer/msl/ir/test_helper_ir.h
+++ b/src/tint/writer/msl/ir/test_helper_ir.h
@@ -15,6 +15,7 @@
#ifndef SRC_TINT_WRITER_MSL_IR_TEST_HELPER_IR_H_
#define SRC_TINT_WRITER_MSL_IR_TEST_HELPER_IR_H_
+#include <iostream>
#include <string>
#include "gtest/gtest.h"
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir.cc b/src/tint/writer/spirv/ir/generator_impl_ir.cc
index 4f07640..1fd5db9 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir.cc
@@ -43,6 +43,7 @@
#include "src/tint/ir/transform/add_empty_entry_point.h"
#include "src/tint/ir/transform/block_decorated_structs.h"
#include "src/tint/ir/transform/merge_return.h"
+#include "src/tint/ir/transform/shader_io_spirv.h"
#include "src/tint/ir/transform/var_for_dynamic_index.h"
#include "src/tint/ir/unreachable.h"
#include "src/tint/ir/user_call.h"
@@ -77,8 +78,11 @@
manager.Add<ir::transform::AddEmptyEntryPoint>();
manager.Add<ir::transform::BlockDecoratedStructs>();
manager.Add<ir::transform::MergeReturn>();
+ manager.Add<ir::transform::ShaderIOSpirv>();
manager.Add<ir::transform::VarForDynamicIndex>();
+ data.Add<ir::transform::ShaderIO::Config>(ir::transform::ShaderIO::Config());
+
transform::DataMap outputs;
manager.Run(module, data, outputs);
}
@@ -87,8 +91,12 @@
switch (addrspace) {
case builtin::AddressSpace::kFunction:
return SpvStorageClassFunction;
+ case builtin::AddressSpace::kIn:
+ return SpvStorageClassInput;
case builtin::AddressSpace::kPrivate:
return SpvStorageClassPrivate;
+ case builtin::AddressSpace::kOut:
+ return SpvStorageClassOutput;
case builtin::AddressSpace::kStorage:
return SpvStorageClassStorageBuffer;
case builtin::AddressSpace::kUniform:
@@ -144,6 +152,47 @@
return true;
}
+uint32_t GeneratorImplIr::Builtin(builtin::BuiltinValue builtin, builtin::AddressSpace addrspace) {
+ switch (builtin) {
+ case builtin::BuiltinValue::kPointSize:
+ return SpvBuiltInPointSize;
+ case builtin::BuiltinValue::kFragDepth:
+ return SpvBuiltInFragDepth;
+ case builtin::BuiltinValue::kFrontFacing:
+ return SpvBuiltInFrontFacing;
+ case builtin::BuiltinValue::kGlobalInvocationId:
+ return SpvBuiltInGlobalInvocationId;
+ case builtin::BuiltinValue::kInstanceIndex:
+ return SpvBuiltInInstanceIndex;
+ case builtin::BuiltinValue::kLocalInvocationId:
+ return SpvBuiltInLocalInvocationId;
+ case builtin::BuiltinValue::kLocalInvocationIndex:
+ return SpvBuiltInLocalInvocationIndex;
+ case builtin::BuiltinValue::kNumWorkgroups:
+ return SpvBuiltInNumWorkgroups;
+ case builtin::BuiltinValue::kPosition:
+ if (addrspace == builtin::AddressSpace::kOut) {
+ // Vertex output.
+ return SpvBuiltInPosition;
+ } else {
+ // Fragment input.
+ return SpvBuiltInFragCoord;
+ }
+ case builtin::BuiltinValue::kSampleIndex:
+ module_.PushCapability(SpvCapabilitySampleRateShading);
+ return SpvBuiltInSampleId;
+ case builtin::BuiltinValue::kSampleMask:
+ return SpvBuiltInSampleMask;
+ case builtin::BuiltinValue::kVertexIndex:
+ return SpvBuiltInVertexIndex;
+ case builtin::BuiltinValue::kWorkgroupId:
+ return SpvBuiltInWorkgroupId;
+ case builtin::BuiltinValue::kUndefined:
+ return SpvBuiltInMax;
+ }
+ return SpvBuiltInMax;
+}
+
uint32_t GeneratorImplIr::Constant(ir::Constant* constant) {
return Constant(constant->Value());
}
@@ -218,7 +267,8 @@
});
}
-uint32_t GeneratorImplIr::Type(const type::Type* ty) {
+uint32_t GeneratorImplIr::Type(const type::Type* ty,
+ builtin::AddressSpace addrspace /* = kUndefined */) {
return types_.GetOrCreate(ty, [&] {
auto id = module_.NextId();
Switch(
@@ -261,11 +311,11 @@
{id, U32Operand(SpvDecorationArrayStride), arr->Stride()});
},
[&](const type::Pointer* ptr) {
- module_.PushType(
- spv::Op::OpTypePointer,
- {id, U32Operand(StorageClass(ptr->AddressSpace())), Type(ptr->StoreType())});
+ module_.PushType(spv::Op::OpTypePointer,
+ {id, U32Operand(StorageClass(ptr->AddressSpace())),
+ Type(ptr->StoreType(), ptr->AddressSpace())});
},
- [&](const type::Struct* str) { EmitStructType(id, str); },
+ [&](const type::Struct* str) { EmitStructType(id, str, addrspace); },
[&](Default) {
TINT_ICE(Writer, diagnostics_) << "unhandled type: " << ty->FriendlyName();
});
@@ -288,7 +338,9 @@
return block_labels_.GetOrCreate(block, [&] { return module_.NextId(); });
}
-void GeneratorImplIr::EmitStructType(uint32_t id, const type::Struct* str) {
+void GeneratorImplIr::EmitStructType(uint32_t id,
+ const type::Struct* str,
+ builtin::AddressSpace addrspace /* = kUndefined */) {
// Helper to return `type` or a potentially nested array element type within `type` as a matrix
// type, or nullptr if no such matrix type is present.
auto get_nested_matrix_type = [&](const type::Type* type) {
@@ -307,6 +359,56 @@
spv::Op::OpMemberDecorate,
{operands[0], member->Index(), U32Operand(SpvDecorationOffset), member->Offset()});
+ // Generate shader IO decorations.
+ const auto& attrs = member->Attributes();
+ if (attrs.location) {
+ module_.PushAnnot(
+ spv::Op::OpMemberDecorate,
+ {operands[0], member->Index(), U32Operand(SpvDecorationLocation), *attrs.location});
+ if (attrs.interpolation) {
+ switch (attrs.interpolation->type) {
+ case builtin::InterpolationType::kLinear:
+ module_.PushAnnot(
+ spv::Op::OpMemberDecorate,
+ {operands[0], member->Index(), U32Operand(SpvDecorationNoPerspective)});
+ break;
+ case builtin::InterpolationType::kFlat:
+ module_.PushAnnot(
+ spv::Op::OpMemberDecorate,
+ {operands[0], member->Index(), U32Operand(SpvDecorationFlat)});
+ break;
+ case builtin::InterpolationType::kPerspective:
+ case builtin::InterpolationType::kUndefined:
+ break;
+ }
+ switch (attrs.interpolation->sampling) {
+ case builtin::InterpolationSampling::kCentroid:
+ module_.PushAnnot(
+ spv::Op::OpMemberDecorate,
+ {operands[0], member->Index(), U32Operand(SpvDecorationCentroid)});
+ break;
+ case builtin::InterpolationSampling::kSample:
+ module_.PushCapability(SpvCapabilitySampleRateShading);
+ module_.PushAnnot(
+ spv::Op::OpMemberDecorate,
+ {operands[0], member->Index(), U32Operand(SpvDecorationSample)});
+ break;
+ case builtin::InterpolationSampling::kCenter:
+ case builtin::InterpolationSampling::kUndefined:
+ break;
+ }
+ }
+ }
+ if (attrs.builtin) {
+ module_.PushAnnot(spv::Op::OpMemberDecorate,
+ {operands[0], member->Index(), U32Operand(SpvDecorationBuiltIn),
+ Builtin(*attrs.builtin, addrspace)});
+ }
+ if (attrs.invariant) {
+ module_.PushAnnot(spv::Op::OpMemberDecorate,
+ {operands[0], member->Index(), U32Operand(SpvDecorationInvariant)});
+ }
+
// Emit matrix layout decorations if necessary.
if (auto* matrix_type = get_nested_matrix_type(member->Type())) {
const uint32_t effective_row_count = (matrix_type->rows() == 2) ? 2 : 4;
@@ -404,7 +506,6 @@
stage = SpvExecutionModelFragment;
module_.PushExecutionMode(spv::Op::OpExecutionMode,
{id, U32Operand(SpvExecutionModeOriginUpperLeft)});
- // TODO(jrprice): Add DepthReplacing execution mode if FragDepth is used.
break;
}
case ir::Function::PipelineStage::kVertex: {
@@ -416,9 +517,52 @@
return;
}
- // TODO(jrprice): Add the interface list of all referenced global variables.
- module_.PushEntryPoint(spv::Op::OpEntryPoint,
- {U32Operand(stage), id, ir_->NameOf(func).Name()});
+ OperandList operands = {U32Operand(stage), id, ir_->NameOf(func).Name()};
+
+ // Add the list of all referenced shader IO variables.
+ if (ir_->root_block) {
+ for (auto* global : *ir_->root_block) {
+ auto* var = global->As<ir::Var>();
+ if (!var) {
+ continue;
+ }
+
+ auto* ptr = var->Result()->Type()->As<type::Pointer>();
+ if (!(ptr->AddressSpace() == builtin::AddressSpace::kIn ||
+ ptr->AddressSpace() == builtin::AddressSpace::kOut)) {
+ continue;
+ }
+
+ // Determine if this IO variable is used by the entry point.
+ bool used = false;
+ for (const auto& use : var->Result()->Usages()) {
+ auto* block = use.instruction->Block();
+ while (block->Parent()) {
+ block = block->Parent()->Block();
+ }
+ if (block == func->Block()) {
+ used = true;
+ break;
+ }
+ }
+ if (!used) {
+ continue;
+ }
+ operands.push_back(Value(var));
+
+ // Add the `DepthReplacing` execution mode if `frag_depth` is used.
+ if (auto* str = ptr->StoreType()->As<type::Struct>()) {
+ for (auto* member : str->Members()) {
+ if (member->Attributes().builtin == builtin::BuiltinValue::kFragDepth) {
+ module_.PushExecutionMode(spv::Op::OpExecutionMode,
+ {id, U32Operand(SpvExecutionModeDepthReplacing)});
+ }
+ }
+ }
+ }
+ }
+
+ module_.PushEntryPoint(spv::Op::OpEntryPoint, operands);
}
void GeneratorImplIr::EmitRootBlock(ir::Block* root_block) {
@@ -492,6 +636,13 @@
TINT_ICE(Writer, diagnostics_)
<< "unimplemented instruction: " << inst->TypeInfo().name;
});
+
+ // Set the name for the SPIR-V result ID if provided in the module.
+ if (inst->Result() && !inst->Is<ir::Var>()) {
+ if (auto name = ir_->NameOf(inst)) {
+ module_.PushDebug(spv::Op::OpName, {Value(inst), Operand(name.Name())});
+ }
+ }
}
if (block->IsEmpty()) {
@@ -932,6 +1083,11 @@
}
break;
}
+ case builtin::AddressSpace::kIn: {
+ TINT_ASSERT(Writer, !current_function_);
+ module_.PushType(spv::Op::OpVariable, {ty, id, U32Operand(SpvStorageClassInput)});
+ break;
+ }
case builtin::AddressSpace::kPrivate: {
TINT_ASSERT(Writer, !current_function_);
OperandList operands = {ty, id, U32Operand(SpvStorageClassPrivate)};
@@ -942,6 +1098,11 @@
module_.PushType(spv::Op::OpVariable, operands);
break;
}
+ case builtin::AddressSpace::kOut: {
+ TINT_ASSERT(Writer, !current_function_);
+ module_.PushType(spv::Op::OpVariable, {ty, id, U32Operand(SpvStorageClassOutput)});
+ break;
+ }
case builtin::AddressSpace::kStorage:
case builtin::AddressSpace::kUniform: {
TINT_ASSERT(Writer, !current_function_);
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir.h b/src/tint/writer/spirv/ir/generator_impl_ir.h
index d0dc4bd..ae1bc95 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir.h
+++ b/src/tint/writer/spirv/ir/generator_impl_ir.h
@@ -92,8 +92,10 @@
/// Get the result ID of the type `ty`, emitting a type declaration instruction if necessary.
/// @param ty the type to get the ID for
+ /// @param addrspace the optional address space that this type is being used for
/// @returns the result ID of the type
- uint32_t Type(const type::Type* ty);
+ uint32_t Type(const type::Type* ty,
+ builtin::AddressSpace addrspace = builtin::AddressSpace::kUndefined);
/// Get the result ID of the value `value`, emitting its instruction if necessary.
/// @param value the value to get the ID for
@@ -112,8 +114,11 @@
/// Emit a struct type.
/// @param id the result ID to use
+ /// @param addrspace the optional address space that this type is being used for
/// @param str the struct type to emit
- void EmitStructType(uint32_t id, const type::Struct* str);
+ void EmitStructType(uint32_t id,
+ const type::Struct* str,
+ builtin::AddressSpace addrspace = builtin::AddressSpace::kUndefined);
/// Emit a function.
/// @param func the function to emit
@@ -193,6 +198,13 @@
void EmitExitPhis(ir::ControlInstruction* inst);
private:
+ /// Convert a builtin to the corresponding SPIR-V enum value, taking into account the target
+ /// address space. Adds any capabilities needed for the builtin.
+ /// @param builtin the builtin to convert
+ /// @param addrspace the address space the builtin is being used in
+ /// @returns the enum value of the corresponding SPIR-V builtin
+ uint32_t Builtin(builtin::BuiltinValue builtin, builtin::AddressSpace addrspace);
+
/// Get the result ID of the constant `constant`, emitting its instruction if necessary.
/// @param constant the constant to get the ID for
/// @returns the result ID of the constant
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_access_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_access_test.cc
index 40f4d4d..df494fb 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_access_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_access_test.cc
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include "gmock/gmock.h"
#include "src/tint/writer/spirv/ir/test_helper_ir.h"
namespace tint::writer::spirv {
@@ -20,452 +21,213 @@
using namespace tint::builtin::fluent_types; // NOLINT
using namespace tint::number_suffixes; // NOLINT
-using SpvGeneratorImplTest_Access = SpvGeneratorImplTest;
-
-TEST_F(SpvGeneratorImplTest_Access, Array_Value_ConstantIndex) {
- auto* arr_val = b.FunctionParam(ty.array(ty.i32(), 4));
+TEST_F(SpvGeneratorImplTest, Access_Array_Value_ConstantIndex) {
+ auto* arr_val = b.FunctionParam("arr", ty.array(ty.i32(), 4));
auto* func = b.Function("foo", ty.void_());
func->SetParams({arr_val});
-
b.With(func->Block(), [&] {
- b.Access(ty.i32(), arr_val, 1_u);
+ auto* result = b.Access(ty.i32(), arr_val, 1_u);
b.Return(func);
+ mod.SetName(result, "result");
});
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-OpDecorate %3 ArrayStride 4
-%2 = OpTypeVoid
-%4 = OpTypeInt 32 1
-%6 = OpTypeInt 32 0
-%5 = OpConstant %6 4
-%3 = OpTypeArray %4 %5
-%8 = OpTypeFunction %2 %3
-%1 = OpFunction %2 None %8
-%7 = OpFunctionParameter %3
-%9 = OpLabel
-%10 = OpCompositeExtract %4 %7 1
-OpReturn
-OpFunctionEnd
-)");
+ ASSERT_TRUE(Generate()) << Error();
+ EXPECT_INST("%result = OpCompositeExtract %int %arr 1");
}
-TEST_F(SpvGeneratorImplTest_Access, Array_Pointer_ConstantIndex) {
+TEST_F(SpvGeneratorImplTest, Access_Array_Pointer_ConstantIndex) {
auto* func = b.Function("foo", ty.void_());
-
b.With(func->Block(), [&] {
- auto* arr_var = b.Var(ty.ptr<function, array<i32, 4>>());
- b.Access(ty.ptr<function, i32>(), arr_var, 1_u);
+ auto* arr_var = b.Var("arr", ty.ptr<function, array<i32, 4>>());
+ auto* result = b.Access(ty.ptr<function, i32>(), arr_var, 1_u);
b.Return(func);
+ mod.SetName(result, "result");
});
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-OpDecorate %7 ArrayStride 4
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%8 = OpTypeInt 32 1
-%10 = OpTypeInt 32 0
-%9 = OpConstant %10 4
-%7 = OpTypeArray %8 %9
-%6 = OpTypePointer Function %7
-%12 = OpTypePointer Function %8
-%13 = OpConstant %10 1
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-%5 = OpVariable %6 Function
-%11 = OpAccessChain %12 %5 %13
-OpReturn
-OpFunctionEnd
-)");
+ ASSERT_TRUE(Generate()) << Error();
+ EXPECT_INST("%result = OpAccessChain %_ptr_Function_int %arr %uint_1");
}
-TEST_F(SpvGeneratorImplTest_Access, Array_Pointer_DynamicIndex) {
+TEST_F(SpvGeneratorImplTest, Access_Array_Pointer_DynamicIndex) {
+ auto* idx = b.FunctionParam("idx", ty.i32());
auto* func = b.Function("foo", ty.void_());
-
+ func->SetParams({idx});
b.With(func->Block(), [&] {
- auto* idx_var = b.Var(ty.ptr<function, i32>());
- auto* idx = b.Load(idx_var);
- auto* arr_var = b.Var(ty.ptr<function, array<i32, 4>>());
- b.Access(ty.ptr<function, i32>(), arr_var, idx);
+ auto* arr_var = b.Var("arr", ty.ptr<function, array<i32, 4>>());
+ auto* result = b.Access(ty.ptr<function, i32>(), arr_var, idx);
b.Return(func);
+ mod.SetName(result, "result");
});
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-OpDecorate %11 ArrayStride 4
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%7 = OpTypeInt 32 1
-%6 = OpTypePointer Function %7
-%13 = OpTypeInt 32 0
-%12 = OpConstant %13 4
-%11 = OpTypeArray %7 %12
-%10 = OpTypePointer Function %11
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-%5 = OpVariable %6 Function
-%9 = OpVariable %10 Function
-%8 = OpLoad %7 %5
-%14 = OpAccessChain %6 %9 %8
-OpReturn
-OpFunctionEnd
-)");
+ ASSERT_TRUE(Generate()) << Error();
+ EXPECT_INST("%result = OpAccessChain %_ptr_Function_int %arr %idx");
}
-TEST_F(SpvGeneratorImplTest_Access, Matrix_Value_ConstantIndex) {
- auto* mat_val = b.FunctionParam(ty.mat2x2(ty.f32()));
+TEST_F(SpvGeneratorImplTest, Access_Matrix_Value_ConstantIndex) {
+ auto* mat_val = b.FunctionParam("mat", ty.mat2x2(ty.f32()));
auto* func = b.Function("foo", ty.void_());
func->SetParams({mat_val});
-
b.With(func->Block(), [&] {
- b.Access(ty.vec2(ty.f32()), mat_val, 1_u);
- b.Access(ty.f32(), mat_val, 1_u, 0_u);
+ auto* result_vector = b.Access(ty.vec2(ty.f32()), mat_val, 1_u);
+ auto* result_scalar = b.Access(ty.f32(), mat_val, 1_u, 0_u);
b.Return(func);
+ mod.SetName(result_vector, "result_vector");
+ mod.SetName(result_scalar, "result_scalar");
});
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeVoid
-%5 = OpTypeFloat 32
-%4 = OpTypeVector %5 2
-%3 = OpTypeMatrix %4 2
-%7 = OpTypeFunction %2 %3
-%1 = OpFunction %2 None %7
-%6 = OpFunctionParameter %3
-%8 = OpLabel
-%9 = OpCompositeExtract %4 %6 1
-%10 = OpCompositeExtract %5 %6 1 0
-OpReturn
-OpFunctionEnd
-)");
+ ASSERT_TRUE(Generate()) << Error();
+ EXPECT_INST("%result_vector = OpCompositeExtract %v2float %mat 1");
+ EXPECT_INST("%result_scalar = OpCompositeExtract %float %mat 1 0");
}
-TEST_F(SpvGeneratorImplTest_Access, Matrix_Pointer_ConstantIndex) {
+TEST_F(SpvGeneratorImplTest, Access_Matrix_Pointer_ConstantIndex) {
auto* func = b.Function("foo", ty.void_());
-
b.With(func->Block(), [&] {
- auto* mat_var = b.Var(ty.ptr<function, mat2x2<f32>>());
- b.Access(ty.ptr<function, vec2<f32>>(), mat_var, 1_u);
- b.Access(ty.ptr<function, f32>(), mat_var, 1_u, 0_u);
+ auto* mat_var = b.Var("mat", ty.ptr<function, mat2x2<f32>>());
+ auto* result_vector = b.Access(ty.ptr<function, vec2<f32>>(), mat_var, 1_u);
+ auto* result_scalar = b.Access(ty.ptr<function, f32>(), mat_var, 1_u, 0_u);
b.Return(func);
+ mod.SetName(result_vector, "result_vector");
+ mod.SetName(result_scalar, "result_scalar");
});
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%9 = OpTypeFloat 32
-%8 = OpTypeVector %9 2
-%7 = OpTypeMatrix %8 2
-%6 = OpTypePointer Function %7
-%11 = OpTypePointer Function %8
-%13 = OpTypeInt 32 0
-%12 = OpConstant %13 1
-%15 = OpTypePointer Function %9
-%16 = OpConstant %13 0
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-%5 = OpVariable %6 Function
-%10 = OpAccessChain %11 %5 %12
-%14 = OpAccessChain %15 %5 %12 %16
-OpReturn
-OpFunctionEnd
-)");
+ ASSERT_TRUE(Generate()) << Error();
+ EXPECT_INST("%result_vector = OpAccessChain %_ptr_Function_v2float %mat %uint_1");
+ EXPECT_INST("%result_scalar = OpAccessChain %_ptr_Function_float %mat %uint_1 %uint_0");
}
-TEST_F(SpvGeneratorImplTest_Access, Matrix_Pointer_DynamicIndex) {
+TEST_F(SpvGeneratorImplTest, Access_Matrix_Pointer_DynamicIndex) {
+ auto* idx = b.FunctionParam("idx", ty.i32());
auto* func = b.Function("foo", ty.void_());
-
+ func->SetParams({idx});
b.With(func->Block(), [&] {
- auto* idx_var = b.Var(ty.ptr<function, i32>());
- auto* idx = b.Load(idx_var);
- auto* mat_var = b.Var(ty.ptr<function, mat2x2<f32>>());
- b.Access(ty.ptr<function, vec2<f32>>(), mat_var, idx);
- b.Access(ty.ptr<function, f32>(), mat_var, idx, idx);
+ auto* mat_var = b.Var("mat", ty.ptr<function, mat2x2<f32>>());
+ auto* result_vector = b.Access(ty.ptr<function, vec2<f32>>(), mat_var, idx);
+ auto* result_scalar = b.Access(ty.ptr<function, f32>(), mat_var, idx, idx);
b.Return(func);
+ mod.SetName(result_vector, "result_vector");
+ mod.SetName(result_scalar, "result_scalar");
});
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%7 = OpTypeInt 32 1
-%6 = OpTypePointer Function %7
-%13 = OpTypeFloat 32
-%12 = OpTypeVector %13 2
-%11 = OpTypeMatrix %12 2
-%10 = OpTypePointer Function %11
-%15 = OpTypePointer Function %12
-%17 = OpTypePointer Function %13
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-%5 = OpVariable %6 Function
-%9 = OpVariable %10 Function
-%8 = OpLoad %7 %5
-%14 = OpAccessChain %15 %9 %8
-%16 = OpAccessChain %17 %9 %8 %8
-OpReturn
-OpFunctionEnd
-)");
+ ASSERT_TRUE(Generate()) << Error();
+ EXPECT_INST("%result_vector = OpAccessChain %_ptr_Function_v2float %mat %idx");
+ EXPECT_INST("%result_scalar = OpAccessChain %_ptr_Function_float %mat %idx %idx");
}
-TEST_F(SpvGeneratorImplTest_Access, Vector_Value_ConstantIndex) {
+TEST_F(SpvGeneratorImplTest, Access_Vector_Value_ConstantIndex) {
+ auto* vec_val = b.FunctionParam("vec", ty.vec4(ty.i32()));
auto* func = b.Function("foo", ty.void_());
- auto* vec_val = b.FunctionParam(ty.vec4(ty.i32()));
func->SetParams({vec_val});
-
b.With(func->Block(), [&] {
- b.Access(ty.i32(), vec_val, 1_u);
+ auto* result = b.Access(ty.i32(), vec_val, 1_u);
b.Return(func);
+ mod.SetName(result, "result");
});
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeVoid
-%4 = OpTypeInt 32 1
-%3 = OpTypeVector %4 4
-%6 = OpTypeFunction %2 %3
-%1 = OpFunction %2 None %6
-%5 = OpFunctionParameter %3
-%7 = OpLabel
-%8 = OpCompositeExtract %4 %5 1
-OpReturn
-OpFunctionEnd
-)");
+ ASSERT_TRUE(Generate()) << Error();
+ EXPECT_INST("%result = OpCompositeExtract %int %vec 1");
}
-TEST_F(SpvGeneratorImplTest_Access, Vector_Value_DynamicIndex) {
+TEST_F(SpvGeneratorImplTest, Access_Vector_Value_DynamicIndex) {
+ auto* vec_val = b.FunctionParam("vec", ty.vec4(ty.i32()));
+ auto* idx = b.FunctionParam("idx", ty.i32());
auto* func = b.Function("foo", ty.void_());
- auto* vec_val = b.FunctionParam(ty.vec4(ty.i32()));
- func->SetParams({vec_val});
-
+ func->SetParams({vec_val, idx});
b.With(func->Block(), [&] {
- auto* idx_var = b.Var(ty.ptr<function, i32>());
- auto* idx = b.Load(idx_var);
- b.Access(ty.i32(), vec_val, idx);
+ auto* result = b.Access(ty.i32(), vec_val, idx);
b.Return(func);
+ mod.SetName(result, "result");
});
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeVoid
-%4 = OpTypeInt 32 1
-%3 = OpTypeVector %4 4
-%6 = OpTypeFunction %2 %3
-%9 = OpTypePointer Function %4
-%1 = OpFunction %2 None %6
-%5 = OpFunctionParameter %3
-%7 = OpLabel
-%8 = OpVariable %9 Function
-%10 = OpLoad %4 %8
-%11 = OpVectorExtractDynamic %4 %5 %10
-OpReturn
-OpFunctionEnd
-)");
+ ASSERT_TRUE(Generate()) << Error();
+ EXPECT_INST("%result = OpVectorExtractDynamic %int %vec %idx");
}
-TEST_F(SpvGeneratorImplTest_Access, Vector_Pointer_ConstantIndex) {
+TEST_F(SpvGeneratorImplTest, Access_Vector_Pointer_ConstantIndex) {
auto* func = b.Function("foo", ty.void_());
-
b.With(func->Block(), [&] {
- auto* vec_var = b.Var(ty.ptr<function, vec4<i32>>());
- b.Access(ty.ptr<function, i32>(), vec_var, 1_u);
+ auto* vec_var = b.Var("vec", ty.ptr<function, vec4<i32>>());
+ auto* result = b.Access(ty.ptr<function, i32>(), vec_var, 1_u);
b.Return(func);
+ mod.SetName(result, "result");
});
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%8 = OpTypeInt 32 1
-%7 = OpTypeVector %8 4
-%6 = OpTypePointer Function %7
-%10 = OpTypePointer Function %8
-%12 = OpTypeInt 32 0
-%11 = OpConstant %12 1
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-%5 = OpVariable %6 Function
-%9 = OpAccessChain %10 %5 %11
-OpReturn
-OpFunctionEnd
-)");
+ ASSERT_TRUE(Generate()) << Error();
+ EXPECT_INST("%result = OpAccessChain %_ptr_Function_int %vec %uint_1");
}
-TEST_F(SpvGeneratorImplTest_Access, Vector_Pointer_DynamicIndex) {
+TEST_F(SpvGeneratorImplTest, Access_Vector_Pointer_DynamicIndex) {
+ auto* idx = b.FunctionParam("idx", ty.i32());
auto* func = b.Function("foo", ty.void_());
-
+ func->SetParams({idx});
b.With(func->Block(), [&] {
- auto* idx_var = b.Var(ty.ptr<function, i32>());
- auto* idx = b.Load(idx_var);
- auto* vec_var = b.Var(ty.ptr<function, vec4<i32>>());
- b.Access(ty.ptr<function, i32>(), vec_var, idx);
+ auto* vec_var = b.Var("vec", ty.ptr<function, vec4<i32>>());
+ auto* result = b.Access(ty.ptr<function, i32>(), vec_var, idx);
b.Return(func);
+ mod.SetName(result, "result");
});
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%7 = OpTypeInt 32 1
-%6 = OpTypePointer Function %7
-%11 = OpTypeVector %7 4
-%10 = OpTypePointer Function %11
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-%5 = OpVariable %6 Function
-%9 = OpVariable %10 Function
-%8 = OpLoad %7 %5
-%12 = OpAccessChain %6 %9 %8
-OpReturn
-OpFunctionEnd
-)");
+ ASSERT_TRUE(Generate()) << Error();
+ EXPECT_INST("%result = OpAccessChain %_ptr_Function_int %vec %idx");
}
-TEST_F(SpvGeneratorImplTest_Access, NestedVector_Value_DynamicIndex) {
- auto* val = b.FunctionParam(ty.array(ty.array(ty.vec4(ty.i32()), 4), 4));
+TEST_F(SpvGeneratorImplTest, Access_NestedVector_Value_DynamicIndex) {
+ auto* val = b.FunctionParam("arr", ty.array(ty.array(ty.vec4(ty.i32()), 4), 4));
+ auto* idx = b.FunctionParam("idx", ty.i32());
auto* func = b.Function("foo", ty.void_());
- func->SetParams({val});
-
+ func->SetParams({val, idx});
b.With(func->Block(), [&] {
- auto* idx_var = b.Var(ty.ptr<function, i32>());
- auto* idx = b.Load(idx_var);
- b.Access(ty.i32(), val, 1_u, 2_u, idx);
+ auto* result = b.Access(ty.i32(), val, 1_u, 2_u, idx);
b.Return(func);
+ mod.SetName(result, "result");
});
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-OpDecorate %4 ArrayStride 16
-OpDecorate %3 ArrayStride 64
-%2 = OpTypeVoid
-%6 = OpTypeInt 32 1
-%5 = OpTypeVector %6 4
-%8 = OpTypeInt 32 0
-%7 = OpConstant %8 4
-%4 = OpTypeArray %5 %7
-%3 = OpTypeArray %4 %7
-%10 = OpTypeFunction %2 %3
-%13 = OpTypePointer Function %6
-%1 = OpFunction %2 None %10
-%9 = OpFunctionParameter %3
-%11 = OpLabel
-%12 = OpVariable %13 Function
-%14 = OpLoad %6 %12
-%16 = OpCompositeExtract %5 %9 1 2
-%15 = OpVectorExtractDynamic %6 %16 %14
-OpReturn
-OpFunctionEnd
-)");
+ ASSERT_TRUE(Generate()) << Error();
+ EXPECT_INST("%14 = OpCompositeExtract %v4int %arr 1 2");
+ EXPECT_INST("%result = OpVectorExtractDynamic %int %14 %idx");
}
-TEST_F(SpvGeneratorImplTest_Access, Struct_Value_ConstantIndex) {
+TEST_F(SpvGeneratorImplTest, Access_Struct_Value_ConstantIndex) {
auto* str =
ty.Struct(mod.symbols.New("MyStruct"), {
{mod.symbols.Register("a"), ty.f32()},
{mod.symbols.Register("b"), ty.vec4<i32>()},
});
- auto* str_val = b.FunctionParam(str);
+ auto* str_val = b.FunctionParam("str", str);
auto* func = b.Function("foo", ty.void_());
func->SetParams({str_val});
-
b.With(func->Block(), [&] {
- b.Access(ty.i32(), str_val, 1_u);
- b.Access(ty.i32(), str_val, 1_u, 2_u);
+ auto* result_a = b.Access(ty.f32(), str_val, 0_u);
+ auto* result_b = b.Access(ty.i32(), str_val, 1_u, 2_u);
b.Return(func);
+ mod.SetName(result_a, "result_a");
+ mod.SetName(result_b, "result_b");
});
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-OpMemberName %3 0 "a"
-OpMemberName %3 1 "b"
-OpName %3 "MyStruct"
-OpMemberDecorate %3 0 Offset 0
-OpMemberDecorate %3 1 Offset 16
-%2 = OpTypeVoid
-%4 = OpTypeFloat 32
-%6 = OpTypeInt 32 1
-%5 = OpTypeVector %6 4
-%3 = OpTypeStruct %4 %5
-%8 = OpTypeFunction %2 %3
-%1 = OpFunction %2 None %8
-%7 = OpFunctionParameter %3
-%9 = OpLabel
-%10 = OpCompositeExtract %6 %7 1
-%11 = OpCompositeExtract %6 %7 1 2
-OpReturn
-OpFunctionEnd
-)");
+ ASSERT_TRUE(Generate()) << Error();
+ EXPECT_INST("%result_a = OpCompositeExtract %float %str 0");
+ EXPECT_INST("%result_b = OpCompositeExtract %int %str 1 2");
}
-TEST_F(SpvGeneratorImplTest_Access, Struct_Pointer_ConstantIndex) {
+TEST_F(SpvGeneratorImplTest, Access_Struct_Pointer_ConstantIndex) {
auto* str =
ty.Struct(mod.symbols.New("MyStruct"), {
{mod.symbols.Register("a"), ty.f32()},
{mod.symbols.Register("b"), ty.vec4<i32>()},
});
auto* func = b.Function("foo", ty.void_());
-
b.With(func->Block(), [&] {
- auto* str_var = b.Var(ty.ptr(function, str, read_write));
- b.Access(ty.ptr<function, i32>(), str_var, 1_u);
- b.Access(ty.ptr<function, i32>(), str_var, 1_u, 2_u);
+ auto* str_var = b.Var("str", ty.ptr(function, str, read_write));
+ auto* result_a = b.Access(ty.ptr<function, f32>(), str_var, 0_u);
+ auto* result_b = b.Access(ty.ptr<function, i32>(), str_var, 1_u, 2_u);
b.Return(func);
+ mod.SetName(result_a, "result_a");
+ mod.SetName(result_b, "result_b");
});
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-OpMemberName %7 0 "a"
-OpMemberName %7 1 "b"
-OpName %7 "MyStruct"
-OpMemberDecorate %7 0 Offset 0
-OpMemberDecorate %7 1 Offset 16
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%8 = OpTypeFloat 32
-%10 = OpTypeInt 32 1
-%9 = OpTypeVector %10 4
-%7 = OpTypeStruct %8 %9
-%6 = OpTypePointer Function %7
-%12 = OpTypePointer Function %10
-%14 = OpTypeInt 32 0
-%13 = OpConstant %14 1
-%16 = OpConstant %14 2
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-%5 = OpVariable %6 Function
-%11 = OpAccessChain %12 %5 %13
-%15 = OpAccessChain %12 %5 %13 %16
-OpReturn
-OpFunctionEnd
-)");
+ ASSERT_TRUE(Generate()) << Error();
+ EXPECT_INST("%result_a = OpAccessChain %_ptr_Function_float %str %uint_0");
+ EXPECT_INST("%result_b = OpAccessChain %_ptr_Function_int %str %uint_1 %uint_2");
}
} // namespace
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_function_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_function_test.cc
index 00bfeb0..8450235 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_function_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_function_test.cc
@@ -167,12 +167,10 @@
TEST_F(SpvGeneratorImplTest, Function_Parameters) {
auto* i32 = ty.i32();
- auto* x = b.FunctionParam(i32);
- auto* y = b.FunctionParam(i32);
+ auto* x = b.FunctionParam("x", i32);
+ auto* y = b.FunctionParam("y", i32);
auto* func = b.Function("foo", i32);
func->SetParams({x, y});
- mod.SetName(x, "x");
- mod.SetName(y, "y");
b.With(func->Block(), [&] {
auto* result = b.Add(i32, x, y);
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_loop_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_loop_test.cc
index 63bc82c..d0d6d24 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_loop_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_loop_test.cc
@@ -25,7 +25,7 @@
auto* loop = b.Loop();
loop->Body()->Append(b.Continue(loop));
- loop->Continuing()->Append(b.BreakIf(true, loop));
+ loop->Continuing()->Append(b.BreakIf(loop, true));
func->Block()->Append(loop);
func->Block()->Append(b.Return(func));
@@ -218,7 +218,7 @@
auto* result = loop->Body()->Append(b.Equal(ty.i32(), 1_i, 2_i));
loop->Body()->Append(b.Continue(loop, result));
- loop->Continuing()->Append(b.BreakIf(result, loop));
+ loop->Continuing()->Append(b.BreakIf(loop, result));
func->Block()->Append(loop);
func->Block()->Append(b.Return(func));
@@ -260,7 +260,7 @@
outer_loop->Body()->Append(inner_loop);
outer_loop->Body()->Append(b.Continue(outer_loop));
- outer_loop->Continuing()->Append(b.BreakIf(true, outer_loop));
+ outer_loop->Continuing()->Append(b.BreakIf(outer_loop, true));
func->Block()->Append(outer_loop);
func->Block()->Append(b.Return(func));
@@ -305,11 +305,11 @@
auto* inner_loop = b.Loop();
inner_loop->Body()->Append(b.Continue(inner_loop));
- inner_loop->Continuing()->Append(b.BreakIf(true, inner_loop));
+ inner_loop->Continuing()->Append(b.BreakIf(inner_loop, true));
outer_loop->Body()->Append(b.Continue(outer_loop));
outer_loop->Continuing()->Append(inner_loop);
- outer_loop->Continuing()->Append(b.BreakIf(true, outer_loop));
+ outer_loop->Continuing()->Append(b.BreakIf(outer_loop, true));
func->Block()->Append(outer_loop);
func->Block()->Append(b.Return(func));
@@ -351,23 +351,23 @@
auto* func = b.Function("foo", ty.void_());
b.With(func->Block(), [&] {
- auto* l = b.Loop();
+ auto* loop = b.Loop();
- b.With(l->Initializer(), [&] { b.NextIteration(l, 1_i, false); });
+ b.With(loop->Initializer(), [&] { b.NextIteration(loop, 1_i, false); });
auto* loop_param = b.BlockParam(ty.i32());
- l->Body()->SetParams({loop_param});
+ loop->Body()->SetParams({loop_param});
- b.With(l->Body(), [&] {
+ b.With(loop->Body(), [&] {
auto* inc = b.Add(ty.i32(), loop_param, 1_i);
- b.Continue(l, inc);
+ b.Continue(loop, inc);
});
auto* cont_param = b.BlockParam(ty.i32());
- l->Continuing()->SetParams({cont_param});
- b.With(l->Continuing(), [&] {
+ loop->Continuing()->SetParams({cont_param});
+ b.With(loop->Continuing(), [&] {
auto* cmp = b.GreaterThan(ty.bool_(), cont_param, 5_i);
- b.BreakIf(cmp, l, cont_param);
+ b.BreakIf(loop, cmp, cont_param);
});
b.Return(func);
@@ -409,26 +409,26 @@
auto* func = b.Function("foo", ty.void_());
b.With(func->Block(), [&] {
- auto* l = b.Loop();
+ auto* loop = b.Loop();
- b.With(l->Initializer(), [&] { b.NextIteration(l, 1_i, false); });
+ b.With(loop->Initializer(), [&] { b.NextIteration(loop, 1_i, false); });
auto* loop_param_a = b.BlockParam(ty.i32());
auto* loop_param_b = b.BlockParam(ty.bool_());
- l->Body()->SetParams({loop_param_a, loop_param_b});
+ loop->Body()->SetParams({loop_param_a, loop_param_b});
- b.With(l->Body(), [&] {
+ b.With(loop->Body(), [&] {
auto* inc = b.Add(ty.i32(), loop_param_a, 1_i);
- b.Continue(l, inc, loop_param_b);
+ b.Continue(loop, inc, loop_param_b);
});
auto* cont_param_a = b.BlockParam(ty.i32());
auto* cont_param_b = b.BlockParam(ty.bool_());
- l->Continuing()->SetParams({cont_param_a, cont_param_b});
- b.With(l->Continuing(), [&] {
+ loop->Continuing()->SetParams({cont_param_a, cont_param_b});
+ b.With(loop->Continuing(), [&] {
auto* cmp = b.GreaterThan(ty.bool_(), cont_param_a, 5_i);
auto* not_b = b.Not(ty.bool_(), cont_param_b);
- b.BreakIf(cmp, l, cont_param_a, not_b);
+ b.BreakIf(loop, cmp, cont_param_a, not_b);
});
b.Return(func);
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc
index f0d3e0d..a79e326 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc
@@ -77,10 +77,8 @@
auto* func = b.Function("foo", ty.void_());
b.With(func->Block(), [&] {
- auto* v = b.Var(ty.ptr<function, i32>());
+ b.Var("myvar", ty.ptr<function, i32>());
b.Return(func);
-
- mod.SetName(v, "myvar");
});
ASSERT_TRUE(IRIsValid()) << Error();
@@ -241,12 +239,10 @@
}
TEST_F(SpvGeneratorImplTest, PrivateVar_Name) {
- auto* v = b.Var(ty.ptr<private_, i32>());
+ auto* v = b.Var("myvar", ty.ptr<private_, i32>());
v->SetInitializer(b.Constant(42_i));
b.RootBlock()->Append(v);
- mod.SetName(v, "myvar");
-
ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
OpMemoryModel Logical GLSL450
@@ -269,7 +265,6 @@
TEST_F(SpvGeneratorImplTest, PrivateVar_LoadAndStore) {
auto* func = b.Function("foo", ty.void_(), ir::Function::PipelineStage::kFragment);
- mod.functions.Push(func);
auto* store_ty = ty.i32();
auto* v = b.Var(ty.ptr(private_, store_ty));
@@ -328,8 +323,7 @@
}
TEST_F(SpvGeneratorImplTest, WorkgroupVar_Name) {
- auto* v = b.RootBlock()->Append(b.Var(ty.ptr<workgroup, i32>()));
- mod.SetName(v, "myvar");
+ b.RootBlock()->Append(b.Var("myvar", ty.ptr<workgroup, i32>()));
ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
@@ -353,7 +347,6 @@
TEST_F(SpvGeneratorImplTest, WorkgroupVar_LoadAndStore) {
auto* func = b.Function("foo", ty.void_(), ir::Function::PipelineStage::kCompute,
std::array{1u, 1u, 1u});
- mod.functions.Push(func);
auto* store_ty = ty.i32();
auto* v = b.RootBlock()->Append(b.Var(ty.ptr(workgroup, store_ty)));
@@ -442,10 +435,9 @@
}
TEST_F(SpvGeneratorImplTest, StorageVar_Name) {
- auto* v = b.Var(ty.ptr<storage, i32>());
+ auto* v = b.Var("myvar", ty.ptr<storage, i32>());
v->SetBindingPoint(0, 0);
b.RootBlock()->Append(v);
- mod.SetName(v, "myvar");
ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
@@ -479,7 +471,6 @@
auto* func = b.Function("foo", ty.void_(), ir::Function::PipelineStage::kCompute,
std::array{1u, 1u, 1u});
- mod.functions.Push(func);
b.With(func->Block(), [&] {
b.Load(v);
@@ -554,10 +545,9 @@
}
TEST_F(SpvGeneratorImplTest, UniformVar_Name) {
- auto* v = b.Var(ty.ptr<uniform, i32>());
+ auto* v = b.Var("myvar", ty.ptr<uniform, i32>());
v->SetBindingPoint(0, 0);
b.RootBlock()->Append(v);
- mod.SetName(v, "myvar");
ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
@@ -591,7 +581,6 @@
auto* func = b.Function("foo", ty.void_(), ir::Function::PipelineStage::kCompute,
std::array{1u, 1u, 1u});
- mod.functions.Push(func);
b.With(func->Block(), [&] {
b.Load(v);
diff --git a/src/tint/writer/spirv/ir/test_helper_ir.h b/src/tint/writer/spirv/ir/test_helper_ir.h
index 08f8246..27559be 100644
--- a/src/tint/writer/spirv/ir/test_helper_ir.h
+++ b/src/tint/writer/spirv/ir/test_helper_ir.h
@@ -16,8 +16,10 @@
#define SRC_TINT_WRITER_SPIRV_IR_TEST_HELPER_IR_H_
#include <string>
+#include <utility>
#include "gtest/gtest.h"
+#include "spirv-tools/libspirv.hpp"
#include "src/tint/ir/builder.h"
#include "src/tint/ir/validate.h"
#include "src/tint/writer/spirv/ir/generator_impl_ir.h"
@@ -25,6 +27,10 @@
namespace tint::writer::spirv {
+// Helper macro to check whether the SPIR-V output contains an instruction, dumping the full output
+// if the instruction was not present.
+#define EXPECT_INST(inst) ASSERT_THAT(output_, testing::HasSubstr(inst)) << output_
+
/// The element type of a test.
enum TestElementType {
kBool,
@@ -54,6 +60,9 @@
/// Validation errors
std::string err_;
+ /// SPIR-V output.
+ std::string output_;
+
/// @returns the error string from the validation
std::string Error() const { return err_; }
@@ -67,6 +76,59 @@
return true;
}
+ /// Run the generator on the IR module and validate the result.
+ /// @returns true if generation and validation succeeded
+ bool Generate() {
+ if (!generator_.Generate()) {
+ err_ = generator_.Diagnostics().str();
+ return false;
+ }
+ if (!Validate()) {
+ return false;
+ }
+
+ output_ = Disassemble(generator_.Result(), SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES |
+ SPV_BINARY_TO_TEXT_OPTION_INDENT |
+ SPV_BINARY_TO_TEXT_OPTION_COMMENT);
+ return true;
+ }
+
+ /// Validate the generated SPIR-V using the SPIR-V Tools Validator.
+ /// @returns true if validation succeeded, false otherwise
+ bool Validate() {
+ auto binary = generator_.Result();
+
+ std::string spv_errors;
+ auto msg_consumer = [&spv_errors](spv_message_level_t level, const char*,
+ const spv_position_t& position, const char* message) {
+ switch (level) {
+ case SPV_MSG_FATAL:
+ case SPV_MSG_INTERNAL_ERROR:
+ case SPV_MSG_ERROR:
+ spv_errors +=
+ "error: line " + std::to_string(position.index) + ": " + message + "\n";
+ break;
+ case SPV_MSG_WARNING:
+ spv_errors +=
+ "warning: line " + std::to_string(position.index) + ": " + message + "\n";
+ break;
+ case SPV_MSG_INFO:
+ spv_errors +=
+ "info: line " + std::to_string(position.index) + ": " + message + "\n";
+ break;
+ case SPV_MSG_DEBUG:
+ break;
+ }
+ };
+
+ spvtools::SpirvTools tools(SPV_ENV_VULKAN_1_2);
+ tools.SetMessageConsumer(msg_consumer);
+
+ auto result = tools.Validate(binary);
+ err_ = std::move(spv_errors);
+ return result;
+ }
+
/// @returns the disassembled types from the generated module.
std::string DumpTypes() { return DumpInstructions(generator_.Module().Types()); }
diff --git a/src/tint/writer/spirv/spv_dump.cc b/src/tint/writer/spirv/spv_dump.cc
index da24446..6ded6ff 100644
--- a/src/tint/writer/spirv/spv_dump.cc
+++ b/src/tint/writer/spirv/spv_dump.cc
@@ -19,7 +19,7 @@
namespace tint::writer::spirv {
-std::string Disassemble(const std::vector<uint32_t>& data) {
+std::string Disassemble(const std::vector<uint32_t>& data, uint32_t options /* = 0u */) {
std::string spv_errors;
spv_target_env target_env = SPV_ENV_UNIVERSAL_1_0;
@@ -49,7 +49,7 @@
tools.SetMessageConsumer(msg_consumer);
std::string result;
- if (!tools.Disassemble(data, &result, SPV_BINARY_TO_TEXT_OPTION_NO_HEADER)) {
+ if (!tools.Disassemble(data, &result, SPV_BINARY_TO_TEXT_OPTION_NO_HEADER | options)) {
return "*** Invalid SPIR-V ***\n" + spv_errors;
}
return result;
diff --git a/src/tint/writer/spirv/spv_dump.h b/src/tint/writer/spirv/spv_dump.h
index 359f0cb..b9030a4 100644
--- a/src/tint/writer/spirv/spv_dump.h
+++ b/src/tint/writer/spirv/spv_dump.h
@@ -24,8 +24,9 @@
/// Disassembles SPIR-V binary data into its textual form.
/// @param data the SPIR-V binary data
+/// @param options the additional SPIR-V disassembler options to use
/// @returns the disassembled SPIR-V string
-std::string Disassemble(const std::vector<uint32_t>& data);
+std::string Disassemble(const std::vector<uint32_t>& data, uint32_t options = 0u);
/// Dumps the given builder to a SPIR-V disassembly string
/// @param builder the builder to convert
diff --git a/src/tint/writer/syntax_tree/generator_impl.cc b/src/tint/writer/syntax_tree/generator_impl.cc
index 4574c14..b6ba3ea 100644
--- a/src/tint/writer/syntax_tree/generator_impl.cc
+++ b/src/tint/writer/syntax_tree/generator_impl.cc
@@ -520,6 +520,14 @@
}
Line() << "]";
},
+ [&](const ast::IndexAttribute* index) {
+ line() << "IndexAttribute [";
+ {
+ ScopedIndent idx(this);
+ EmitExpression(index->expr);
+ }
+ line() << "]";
+ },
[&](const ast::BuiltinAttribute* builtin) {
Line() << "BuiltinAttribute [";
{
diff --git a/src/tint/writer/wgsl/generator_impl.cc b/src/tint/writer/wgsl/generator_impl.cc
index 59aa35c..b36994a 100644
--- a/src/tint/writer/wgsl/generator_impl.cc
+++ b/src/tint/writer/wgsl/generator_impl.cc
@@ -430,6 +430,11 @@
EmitExpression(out, location->expr);
out << ")";
},
+ [&](const ast::IndexAttribute* index) {
+ out << "index(";
+ EmitExpression(out, index->expr);
+ out << ")";
+ },
[&](const ast::BuiltinAttribute* builtin) {
out << "builtin(";
EmitExpression(out, builtin->builtin);