[tint][ir] Add LoadVectorElement / StoreVectorElement
These new instructions must now be used to load and store a vector
element. It is now a validation error to obtain a pointer to a
vector member.
This is done to match the limitations of MSL and WGSL.
Change-Id: I7a0fa7287ecfc7cd441f9cda432df1134c4bc45a
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/139924
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Ben Clayton <bclayton@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index a3a7730..53e9874 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -1311,6 +1311,8 @@
"ir/let.h",
"ir/load.cc",
"ir/load.h",
+ "ir/load_vector_element.cc",
+ "ir/load_vector_element.h",
"ir/location.h",
"ir/loop.cc",
"ir/loop.h",
@@ -1326,6 +1328,8 @@
"ir/return.h",
"ir/store.cc",
"ir/store.h",
+ "ir/store_vector_element.cc",
+ "ir/store_vector_element.h",
"ir/switch.cc",
"ir/switch.h",
"ir/swizzle.cc",
@@ -2407,6 +2411,7 @@
"ir/ir_test_helper.h",
"ir/let_test.cc",
"ir/load_test.cc",
+ "ir/load_vector_element_test.cc",
"ir/loop_test.cc",
"ir/module_test.cc",
"ir/multi_in_block_test.cc",
@@ -2415,6 +2420,7 @@
"ir/program_test_helper.h",
"ir/return_test.cc",
"ir/store_test.cc",
+ "ir/store_vector_element_test.cc",
"ir/switch_test.cc",
"ir/swizzle_test.cc",
"ir/to_program_inlining_test.cc",
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index 34e3b12..52b5c1f 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -797,6 +797,8 @@
ir/let.h
ir/load.cc
ir/load.h
+ ir/load_vector_element.cc
+ ir/load_vector_element.h
ir/location.h
ir/loop.cc
ir/loop.h
@@ -812,6 +814,8 @@
ir/return.h
ir/store.cc
ir/store.h
+ ir/store_vector_element.cc
+ ir/store_vector_element.h
ir/switch.cc
ir/switch.h
ir/swizzle.cc
@@ -1591,6 +1595,7 @@
ir/ir_test_helper.h
ir/let_test.cc
ir/load_test.cc
+ ir/load_vector_element_test.cc
ir/loop_test.cc
ir/module_test.cc
ir/multi_in_block_test.cc
@@ -1599,6 +1604,7 @@
ir/program_test_helper.h
ir/return_test.cc
ir/store_test.cc
+ ir/store_vector_element_test.cc
ir/switch_test.cc
ir/swizzle_test.cc
ir/transform/add_empty_entry_point_test.cc
diff --git a/src/tint/ir/builder.cc b/src/tint/ir/builder.cc
index 6049a85..78e0947 100644
--- a/src/tint/ir/builder.cc
+++ b/src/tint/ir/builder.cc
@@ -101,4 +101,17 @@
return Append(ir.instructions.Create<ir::Unreachable>());
}
+const type::Type* Builder::VectorPtrElementType(const type::Type* type) {
+ auto* vec_ptr_ty = type->As<type::Pointer>();
+ TINT_ASSERT(IR, vec_ptr_ty);
+ if (TINT_LIKELY(vec_ptr_ty)) {
+ auto* vec_ty = vec_ptr_ty->StoreType()->As<type::Vector>();
+ TINT_ASSERT(IR, vec_ty);
+ if (TINT_LIKELY(vec_ty)) {
+ return vec_ty->type();
+ }
+ }
+ return ir.Types().i32();
+}
+
} // namespace tint::ir
diff --git a/src/tint/ir/builder.h b/src/tint/ir/builder.h
index 65f7df9..707d22d 100644
--- a/src/tint/ir/builder.h
+++ b/src/tint/ir/builder.h
@@ -40,12 +40,14 @@
#include "src/tint/ir/instruction_result.h"
#include "src/tint/ir/let.h"
#include "src/tint/ir/load.h"
+#include "src/tint/ir/load_vector_element.h"
#include "src/tint/ir/loop.h"
#include "src/tint/ir/module.h"
#include "src/tint/ir/multi_in_block.h"
#include "src/tint/ir/next_iteration.h"
#include "src/tint/ir/return.h"
#include "src/tint/ir/store.h"
+#include "src/tint/ir/store_vector_element.h"
#include "src/tint/ir/switch.h"
#include "src/tint/ir/swizzle.h"
#include "src/tint/ir/unary.h"
@@ -581,6 +583,33 @@
return Append(ir.instructions.Create<ir::Store>(to_val, from_val));
}
+ /// Creates a store vector element instruction
+ /// @param to the vector pointer expression being stored too
+ /// @param index the new vector element index
+ /// @param value the new vector element expression
+ /// @returns the instruction
+ template <typename TO, typename INDEX, typename VALUE>
+ ir::StoreVectorElement* StoreVectorElement(TO&& to, INDEX&& index, VALUE&& value) {
+ CheckForNonDeterministicEvaluation<TO, INDEX, VALUE>();
+ auto* to_val = Value(std::forward<TO>(to));
+ auto* index_val = Value(std::forward<INDEX>(index));
+ auto* value_val = Value(std::forward<VALUE>(value));
+ return Append(ir.instructions.Create<ir::StoreVectorElement>(to_val, index_val, value_val));
+ }
+
+ /// Creates a load vector element instruction
+ /// @param from the vector pointer expression being loaded from
+ /// @param index the new vector element index
+ /// @returns the instruction
+ template <typename FROM, typename INDEX>
+ ir::LoadVectorElement* LoadVectorElement(FROM&& from, INDEX&& index) {
+ CheckForNonDeterministicEvaluation<FROM, INDEX>();
+ auto* from_val = Value(std::forward<FROM>(from));
+ auto* index_val = Value(std::forward<INDEX>(index));
+ auto* res = InstructionResult(VectorPtrElementType(from_val->Type()));
+ return Append(ir.instructions.Create<ir::LoadVectorElement>(res, from_val, index_val));
+ }
+
/// Creates a new `var` declaration
/// @param type the var type
/// @returns the instruction
@@ -778,6 +807,11 @@
/// The IR module.
Module& ir;
+
+ private:
+ /// @returns the element type of the vector-pointer type
+ /// Asserts and return i32 if @p type is not a pointer to a vector
+ const type::Type* VectorPtrElementType(const type::Type* type);
};
} // namespace tint::ir
diff --git a/src/tint/ir/disassembler.cc b/src/tint/ir/disassembler.cc
index 497695a..a337abc 100644
--- a/src/tint/ir/disassembler.cc
+++ b/src/tint/ir/disassembler.cc
@@ -36,11 +36,13 @@
#include "src/tint/ir/instruction_result.h"
#include "src/tint/ir/let.h"
#include "src/tint/ir/load.h"
+#include "src/tint/ir/load_vector_element.h"
#include "src/tint/ir/loop.h"
#include "src/tint/ir/multi_in_block.h"
#include "src/tint/ir/next_iteration.h"
#include "src/tint/ir/return.h"
#include "src/tint/ir/store.h"
+#include "src/tint/ir/store_vector_element.h"
#include "src/tint/ir/switch.h"
#include "src/tint/ir/swizzle.h"
#include "src/tint/ir/unreachable.h"
@@ -407,7 +409,7 @@
out_ << " = ";
EmitInstructionName("bitcast", b);
out_ << " ";
- EmitArgs(b);
+ EmitOperandList(b);
EmitLine();
},
[&](Discard* d) {
@@ -419,7 +421,7 @@
out_ << " = ";
EmitInstructionName(builtin::str(b->Func()), b);
out_ << " ";
- EmitArgs(b);
+ EmitOperandList(b);
EmitLine();
},
[&](Construct* c) {
@@ -427,7 +429,7 @@
out_ << " = ";
EmitInstructionName("construct", c);
out_ << " ";
- EmitArgs(c);
+ EmitOperandList(c);
EmitLine();
},
[&](Convert* c) {
@@ -435,7 +437,7 @@
out_ << " = ";
EmitInstructionName("convert", c);
out_ << " ";
- EmitArgs(c);
+ EmitOperandList(c);
EmitLine();
},
[&](Load* l) {
@@ -454,6 +456,22 @@
EmitValue(s->From());
EmitLine();
},
+ [&](LoadVectorElement* l) {
+ EmitValueWithType(l);
+ out_ << " = ";
+ EmitInstructionName("load_vector_element", l);
+ out_ << " ";
+ EmitOperandList(l);
+ EmitLine();
+ },
+ [&](StoreVectorElement* s) {
+ EmitInstructionName("store_vector_element", s);
+ out_ << " ";
+ EmitValue(s->To());
+ out_ << " ";
+ EmitOperandList(s);
+ EmitLine();
+ },
[&](UserCall* uc) {
EmitValueWithType(uc);
out_ << " = ";
@@ -462,7 +480,7 @@
if (!uc->Args().IsEmpty()) {
out_ << ", ";
}
- EmitArgs(uc);
+ EmitOperandList(uc, UserCall::kArgsOperandOffset);
EmitLine();
},
[&](Var* v) {
@@ -471,7 +489,7 @@
EmitInstructionName("var", v);
if (v->Initializer()) {
out_ << ", ";
- EmitOperand(v, v->Initializer(), Var::kInitializerOperandOffset);
+ EmitOperand(v, Var::kInitializerOperandOffset);
}
if (v->BindingPoint().has_value()) {
out_ << " ";
@@ -484,7 +502,7 @@
out_ << " = ";
EmitInstructionName("let", l);
out_ << " ";
- EmitOperand(l, l->Value(), Let::kValueOperandOffset);
+ EmitOperandList(l);
EmitLine();
},
[&](Access* a) {
@@ -492,9 +510,7 @@
out_ << " = ";
EmitInstructionName("access", a);
out_ << " ";
- EmitOperand(a, a->Object(), Access::kObjectOperandOffset);
- out_ << ", ";
- EmitOperandList(a, a->Indices(), Access::kIndicesOperandOffset);
+ EmitOperandList(a);
EmitLine();
},
[&](Swizzle* s) {
@@ -526,21 +542,18 @@
[&](Default) { out_ << "Unknown instruction: " << inst->TypeInfo().name; });
}
-void Disassembler::EmitOperand(Instruction* inst, Value* val, size_t index) {
+void Disassembler::EmitOperand(Instruction* inst, size_t index) {
SourceMarker condMarker(this);
- EmitValue(val);
+ EmitValue(inst->Operands()[index]);
condMarker.Store(Usage{inst, static_cast<uint32_t>(index)});
}
-void Disassembler::EmitOperandList(Instruction* inst,
- utils::Slice<Value* const> operands,
- size_t start_index) {
- size_t index = start_index;
- for (auto* operand : operands) {
- if (index != start_index) {
+void Disassembler::EmitOperandList(Instruction* inst, size_t start_index /* = 0 */) {
+ for (size_t i = start_index, n = inst->Operands().Length(); i < n; i++) {
+ if (i != start_index) {
out_ << ", ";
}
- EmitOperand(inst, operand, index++);
+ EmitOperand(inst, i);
}
}
@@ -559,7 +572,7 @@
out_ << " = ";
}
out_ << "if ";
- EmitOperand(if_, if_->Condition(), If::kConditionOperandOffset);
+ EmitOperand(if_, If::kConditionOperandOffset);
bool has_false = !if_->False()->IsEmpty();
@@ -700,26 +713,46 @@
void Disassembler::EmitTerminator(Terminator* b) {
SourceMarker sm(this);
+ size_t args_offset = 0;
tint::Switch(
- b, //
- [&](ir::Return*) { out_ << "ret"; }, //
- [&](ir::Continue* cont) { out_ << "continue %b" << IdOf(cont->Loop()->Continuing()); }, //
- [&](ir::ExitIf*) { out_ << "exit_if"; }, //
- [&](ir::ExitSwitch*) { out_ << "exit_switch"; }, //
- [&](ir::ExitLoop*) { out_ << "exit_loop"; }, //
- [&](ir::NextIteration* ni) { out_ << "next_iteration %b" << IdOf(ni->Loop()->Body()); }, //
- [&](ir::Unreachable*) { out_ << "unreachable"; }, //
+ b,
+ [&](ir::Return*) {
+ out_ << "ret";
+ args_offset = ir::Return::kArgOperandOffset;
+ },
+ [&](ir::Continue* cont) {
+ out_ << "continue %b" << IdOf(cont->Loop()->Continuing());
+ args_offset = ir::Continue::kArgsOperandOffset;
+ },
+ [&](ir::ExitIf*) {
+ out_ << "exit_if";
+ args_offset = ir::ExitIf::kArgsOperandOffset;
+ },
+ [&](ir::ExitSwitch*) {
+ out_ << "exit_switch";
+ args_offset = ir::ExitSwitch::kArgsOperandOffset;
+ },
+ [&](ir::ExitLoop*) {
+ out_ << "exit_loop";
+ args_offset = ir::ExitLoop::kArgsOperandOffset;
+ },
+ [&](ir::NextIteration* ni) {
+ out_ << "next_iteration %b" << IdOf(ni->Loop()->Body());
+ args_offset = ir::NextIteration::kArgsOperandOffset;
+ },
+ [&](ir::Unreachable*) { out_ << "unreachable"; },
[&](ir::BreakIf* bi) {
out_ << "break_if ";
EmitValue(bi->Condition());
out_ << " %b" << IdOf(bi->Loop()->Body());
+ args_offset = ir::BreakIf::kArgsOperandOffset;
},
[&](Unreachable*) { out_ << "unreachable"; },
[&](Default) { out_ << "unknown terminator " << b->TypeInfo().name; });
if (!b->Args().IsEmpty()) {
out_ << " ";
- EmitValueList(b, b->Args());
+ EmitOperandList(b, args_offset);
}
sm.Store(b);
@@ -734,32 +767,14 @@
}
void Disassembler::EmitValueList(utils::Slice<Value* const> values) {
- for (auto* v : values) {
- if (v != values.Front()) {
+ for (size_t i = 0, n = values.Length(); i < n; i++) {
+ if (i > 0) {
out_ << ", ";
}
- EmitValue(v);
+ EmitValue(values[i]);
}
}
-void Disassembler::EmitValueList(Instruction* inst, utils::Slice<Value* const> values) {
- auto len = values.Length();
- for (size_t i = 0; i < len; ++i) {
- auto* v = values[i];
- if (v != values.Front()) {
- out_ << ", ";
- }
-
- SourceMarker sm(this);
- EmitValue(v);
- sm.Store(Usage{inst, static_cast<uint32_t>(i)});
- }
-}
-
-void Disassembler::EmitArgs(Call* call) {
- EmitValueList(call, call->Args());
-}
-
void Disassembler::EmitBinary(Binary* b) {
SourceMarker sm(this);
EmitValueWithType(b);
@@ -815,9 +830,7 @@
break;
}
out_ << " ";
- EmitOperand(b, b->LHS(), Binary::kLhsOperandOffset);
- out_ << ", ";
- EmitOperand(b, b->RHS(), Binary::kRhsOperandOffset);
+ EmitOperandList(b);
sm.Store(b);
EmitLine();
@@ -836,7 +849,7 @@
break;
}
out_ << " ";
- EmitOperand(u, u->Val(), Unary::kValueOperandOffset);
+ EmitOperandList(u);
sm.Store(u);
EmitLine();
diff --git a/src/tint/ir/disassembler.h b/src/tint/ir/disassembler.h
index 1712071..fc09f01 100644
--- a/src/tint/ir/disassembler.h
+++ b/src/tint/ir/disassembler.h
@@ -133,9 +133,7 @@
void EmitValueWithType(Instruction* val);
void EmitValueWithType(Value* val);
void EmitValue(Value* val);
- void EmitValueList(Instruction* inst, utils::Slice<Value* const> values);
void EmitValueList(utils::Slice<ir::Value* const> values);
- void EmitArgs(Call* call);
void EmitBinary(Binary* b);
void EmitUnary(Unary* b);
void EmitTerminator(Terminator* b);
@@ -144,10 +142,8 @@
void EmitIf(If* i);
void EmitStructDecl(const type::Struct* str);
void EmitLine();
- void EmitOperand(Instruction* inst, Value* val, size_t index);
- void EmitOperandList(Instruction* inst,
- utils::Slice<Value* const> operands,
- size_t start_index);
+ void EmitOperand(Instruction* inst, size_t index);
+ void EmitOperandList(Instruction* inst, size_t start_index = 0);
void EmitInstructionName(std::string_view name, Instruction* inst);
Module& mod_;
diff --git a/src/tint/ir/from_program.cc b/src/tint/ir/from_program.cc
index 2835d9d..d8e49aa 100644
--- a/src/tint/ir/from_program.cc
+++ b/src/tint/ir/from_program.cc
@@ -82,6 +82,7 @@
#include "src/tint/sem/builtin.h"
#include "src/tint/sem/call.h"
#include "src/tint/sem/function.h"
+#include "src/tint/sem/index_accessor_expression.h"
#include "src/tint/sem/load.h"
#include "src/tint/sem/materialize.h"
#include "src/tint/sem/member_accessor_expression.h"
@@ -468,6 +469,14 @@
}
void EmitAssignment(const ast::AssignmentStatement* stmt) {
+ auto b = builder_.With(current_block_);
+ if (auto access = AsVectorRefElementAccess(stmt->lhs)) {
+ if (auto rhs = EmitExpression(stmt->rhs)) {
+ b.StoreVectorElement(access->vector, access->index, rhs.Get());
+ }
+ return;
+ }
+
// If assigning to a phony, just generate the RHS and we're done. Note that, because
// this isn't used, a subsequent transform could remove it due to it being dead code.
// This could then change the interface for the program (i.e. a global var no longer
@@ -486,8 +495,7 @@
if (!rhs) {
return;
}
- auto store = builder_.Store(lhs.Get(), rhs.Get());
- current_block_->Append(store);
+ b.Store(lhs.Get(), rhs.Get());
}
void EmitIncrementDecrement(const ast::IncrementDecrementStatement* stmt) {
@@ -512,6 +520,19 @@
void EmitCompoundAssignment(const ast::Expression* lhs_expr,
EMIT_RHS&& emit_rhs,
ast::BinaryOp op) {
+ auto b = builder_.With(current_block_);
+ if (auto access = AsVectorRefElementAccess(lhs_expr)) {
+ // Compound assignment of vector element needs to use LoadVectorElement() and
+ // StoreVectorElement().
+ if (auto rhs = emit_rhs()) {
+ auto* load = b.LoadVectorElement(access->vector, access->index);
+ auto* ty = load->Result()->Type();
+ auto* inst = b.Append(BinaryOp(ty, load->Result(), rhs, op));
+ b.StoreVectorElement(access->vector, access->index, inst);
+ }
+ return;
+ }
+
auto lhs = EmitExpression(lhs_expr);
if (!lhs) {
return;
@@ -520,10 +541,10 @@
if (!rhs) {
return;
}
- auto* load = current_block_->Append(builder_.Load(lhs.Get()));
+ auto* load = b.Load(lhs.Get());
auto* ty = load->Result()->Type();
auto* inst = current_block_->Append(BinaryOp(ty, load->Result(), rhs, op));
- current_block_->Append(builder_.Store(lhs.Get(), inst));
+ b.Store(lhs.Get(), inst);
}
void EmitBlock(const ast::BlockStatement* block) {
@@ -804,6 +825,13 @@
};
utils::Result<Value*> EmitAccess(const ast::AccessorExpression* expr) {
+ if (auto vec_access = AsVectorRefElementAccess(expr)) {
+ // Vector reference accesses need to map to LoadVectorElement()
+ auto* load = builder_.LoadVectorElement(vec_access->vector, vec_access->index);
+ current_block_->Append(load);
+ return load->Result();
+ }
+
std::vector<const ast::Expression*> accessors;
const ast::Expression* object = expr;
while (true) {
@@ -979,7 +1007,7 @@
});
// If this expression maps to sem::Load, insert a load instruction to get the result.
- if (result && sem->Is<sem::Load>()) {
+ if (result && result.Get()->Type()->Is<type::Pointer>() && sem->Is<sem::Load>()) {
auto* load = builder_.Load(result.Get());
current_block_->Append(load);
return load->Result();
@@ -1328,6 +1356,51 @@
TINT_UNREACHABLE(IR, diagnostics_);
return nullptr;
}
+
+ struct VectorRefElementAccess {
+ ir::Value* vector = nullptr;
+ ir::Value* index = nullptr;
+ };
+
+ std::optional<VectorRefElementAccess> AsVectorRefElementAccess(const ast::Expression* expr) {
+ return AsVectorRefElementAccess(
+ program_->Sem().Get<sem::ValueExpression>(expr)->UnwrapLoad());
+ }
+
+ std::optional<VectorRefElementAccess> AsVectorRefElementAccess(
+ const sem::ValueExpression* expr) {
+ auto* access = As<sem::AccessorExpression>(expr);
+ if (!access) {
+ return std::nullopt;
+ }
+
+ auto* ref = access->Object()->Type()->As<type::Reference>();
+ if (!ref) {
+ return std::nullopt;
+ }
+
+ if (!ref->StoreType()->Is<type::Vector>()) {
+ return std::nullopt;
+ }
+
+ return tint::Switch(
+ access,
+ [&](const sem::Swizzle* s) -> std::optional<VectorRefElementAccess> {
+ if (auto vec = EmitExpression(access->Object()->Declaration())) {
+ return VectorRefElementAccess{vec.Get(),
+ builder_.Constant(u32(s->Indices()[0]))};
+ }
+ return std::nullopt;
+ },
+ [&](const sem::IndexAccessorExpression* i) -> std::optional<VectorRefElementAccess> {
+ if (auto vec = EmitExpression(access->Object()->Declaration())) {
+ if (auto idx = EmitExpression(i->Index()->Declaration())) {
+ return VectorRefElementAccess{vec.Get(), idx.Get()};
+ }
+ }
+ return std::nullopt;
+ });
+ }
};
} // namespace
diff --git a/src/tint/ir/from_program_accessor_test.cc b/src/tint/ir/from_program_accessor_test.cc
index 51da5af..43f2bee 100644
--- a/src/tint/ir/from_program_accessor_test.cc
+++ b/src/tint/ir/from_program_accessor_test.cc
@@ -67,8 +67,7 @@
R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
%b1 = block {
%a:ptr<function, vec3<u32>, read_write> = var
- %3:ptr<function, u32, read_write> = access %a, 2u
- %b:u32 = load %3
+ %b:u32 = load_vector_element %a, 2u
ret
}
}
@@ -113,8 +112,8 @@
R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
%b1 = block {
%a:ptr<function, mat3x4<f32>, read_write> = var
- %3:ptr<function, f32, read_write> = access %a, 2u, 3u
- %b:f32 = load %3
+ %3:ptr<function, vec4<f32>, read_write> = access %a, 2u
+ %b:f32 = load_vector_element %3, 3u
ret
}
}
@@ -278,8 +277,7 @@
R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
%b1 = block {
%a:ptr<function, vec2<f32>, read_write> = var
- %3:ptr<function, f32, read_write> = access %a, 1u
- %b:f32 = load %3
+ %b:f32 = load_vector_element %a, 1u
ret
}
}
diff --git a/src/tint/ir/from_program_var_test.cc b/src/tint/ir/from_program_var_test.cc
index 358e0cb..23bbf1c 100644
--- a/src/tint/ir/from_program_var_test.cc
+++ b/src/tint/ir/from_program_var_test.cc
@@ -141,6 +141,107 @@
)");
}
+TEST_F(IR_FromProgramVarTest, Emit_Var_Assign_ArrayOfArray_EvalOrder) {
+ Func("f", utils::Vector{Param("p", ty.i32())}, ty.i32(), utils::Vector{Return("p")});
+
+ auto* lhs = //
+ IndexAccessor( //
+ IndexAccessor( //
+ IndexAccessor("a", //
+ Call("f", 1_i)), //
+ Call("f", 2_i)), //
+ Call("f", 3_i));
+
+ auto* rhs = //
+ IndexAccessor( //
+ IndexAccessor( //
+ IndexAccessor("a", //
+ Call("f", 4_i)), //
+ Call("f", 5_i)), //
+ Call("f", 6_i));
+
+ WrapInFunction(
+ Var("a", ty.array<array<array<i32, 5>, 5>, 5>(), builtin::AddressSpace::kFunction), //
+ Assign(lhs, rhs));
+
+ auto m = Build();
+ ASSERT_TRUE(m) << (!m ? m.Failure() : "");
+
+ EXPECT_EQ(Disassemble(m.Get()),
+ R"(%f = func(%p:i32):i32 -> %b1 {
+ %b1 = block {
+ ret %p
+ }
+}
+%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b2 {
+ %b2 = block {
+ %a:ptr<function, array<array<array<i32, 5>, 5>, 5>, read_write> = var
+ %5:i32 = call %f, 1i
+ %6:i32 = call %f, 2i
+ %7:i32 = call %f, 3i
+ %8:ptr<function, i32, read_write> = access %a, %5, %6, %7
+ %9:i32 = call %f, 4i
+ %10:i32 = call %f, 5i
+ %11:i32 = call %f, 6i
+ %12:ptr<function, i32, read_write> = access %a, %9, %10, %11
+ %13:i32 = load %12
+ store %8, %13
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_FromProgramVarTest, Emit_Var_Assign_ArrayOfMatrix_EvalOrder) {
+ Func("f", utils::Vector{Param("p", ty.i32())}, ty.i32(), utils::Vector{Return("p")});
+
+ auto* lhs = //
+ IndexAccessor( //
+ IndexAccessor( //
+ IndexAccessor("a", //
+ Call("f", 1_i)), //
+ Call("f", 2_i)), //
+ Call("f", 3_i));
+
+ auto* rhs = //
+ IndexAccessor( //
+ IndexAccessor( //
+ IndexAccessor("a", //
+ Call("f", 4_i)), //
+ Call("f", 5_i)), //
+ Call("f", 6_i));
+
+ WrapInFunction(Var("a", ty.array<mat3x4<f32>, 5>(), builtin::AddressSpace::kFunction), //
+ Assign(lhs, rhs));
+
+ auto m = Build();
+ ASSERT_TRUE(m) << (!m ? m.Failure() : "");
+
+ EXPECT_EQ(Disassemble(m.Get()),
+ R"(%f = func(%p:i32):i32 -> %b1 {
+ %b1 = block {
+ ret %p
+ }
+}
+%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b2 {
+ %b2 = block {
+ %a:ptr<function, array<mat3x4<f32>, 5>, read_write> = var
+ %5:i32 = call %f, 1i
+ %6:i32 = call %f, 2i
+ %7:ptr<function, vec4<f32>, read_write> = access %a, %5, %6
+ %8:i32 = call %f, 3i
+ %9:i32 = call %f, 4i
+ %10:i32 = call %f, 5i
+ %11:ptr<function, vec4<f32>, read_write> = access %a, %9, %10
+ %12:i32 = call %f, 6i
+ %13:f32 = load_vector_element %11, %12
+ store_vector_element %7 %7, %8, %13
+ ret
+ }
+}
+)");
+}
+
TEST_F(IR_FromProgramVarTest, Emit_Var_CompoundAssign_42i) {
WrapInFunction(Var("a", ty.i32(), builtin::AddressSpace::kFunction), //
CompoundAssign("a", 42_i, ast::BinaryOp::kAdd));
@@ -161,7 +262,7 @@
)");
}
-TEST_F(IR_FromProgramVarTest, Emit_Var_CompoundAssign_EvalOrder) {
+TEST_F(IR_FromProgramVarTest, Emit_Var_CompoundAssign_ArrayOfArray_EvalOrder) {
Func("f", utils::Vector{Param("p", ty.i32())}, ty.i32(), utils::Vector{Return("p")});
auto* lhs = //
@@ -214,5 +315,57 @@
)");
}
+TEST_F(IR_FromProgramVarTest, Emit_Var_CompoundAssign_ArrayOfMatrix_EvalOrder) {
+ Func("f", utils::Vector{Param("p", ty.i32())}, ty.i32(), utils::Vector{Return("p")});
+
+ auto* lhs = //
+ IndexAccessor( //
+ IndexAccessor( //
+ IndexAccessor("a", //
+ Call("f", 1_i)), //
+ Call("f", 2_i)), //
+ Call("f", 3_i));
+
+ auto* rhs = //
+ IndexAccessor( //
+ IndexAccessor( //
+ IndexAccessor("a", //
+ Call("f", 4_i)), //
+ Call("f", 5_i)), //
+ Call("f", 6_i));
+
+ WrapInFunction(Var("a", ty.array<mat3x4<f32>, 5>(), builtin::AddressSpace::kFunction), //
+ CompoundAssign(lhs, rhs, ast::BinaryOp::kAdd));
+
+ auto m = Build();
+ ASSERT_TRUE(m) << (!m ? m.Failure() : "");
+
+ EXPECT_EQ(Disassemble(m.Get()),
+ R"(%f = func(%p:i32):i32 -> %b1 {
+ %b1 = block {
+ ret %p
+ }
+}
+%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b2 {
+ %b2 = block {
+ %a:ptr<function, array<mat3x4<f32>, 5>, read_write> = var
+ %5:i32 = call %f, 1i
+ %6:i32 = call %f, 2i
+ %7:ptr<function, vec4<f32>, read_write> = access %a, %5, %6
+ %8:i32 = call %f, 3i
+ %9:i32 = call %f, 4i
+ %10:i32 = call %f, 5i
+ %11:ptr<function, vec4<f32>, read_write> = access %a, %9, %10
+ %12:i32 = call %f, 6i
+ %13:f32 = load_vector_element %11, %12
+ %14:f32 = load_vector_element %7, %8
+ %15:f32 = add %14, %13
+ store_vector_element %7 %7, %8, %15
+ ret
+ }
+}
+)");
+}
+
} // namespace
} // namespace tint::ir
diff --git a/src/tint/ir/load_vector_element.cc b/src/tint/ir/load_vector_element.cc
new file mode 100644
index 0000000..73e33ff
--- /dev/null
+++ b/src/tint/ir/load_vector_element.cc
@@ -0,0 +1,32 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/load_vector_element.h"
+#include "src/tint/debug.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ir::LoadVectorElement);
+
+namespace tint::ir {
+
+LoadVectorElement::LoadVectorElement(InstructionResult* result, ir::Value* from, ir::Value* index) {
+ flags_.Add(Flag::kSequenced);
+
+ AddOperand(LoadVectorElement::kFromOperandOffset, from);
+ AddOperand(LoadVectorElement::kIndexOperandOffset, index);
+ AddResult(result);
+}
+
+LoadVectorElement::~LoadVectorElement() = default;
+
+} // namespace tint::ir
diff --git a/src/tint/ir/load_vector_element.h b/src/tint/ir/load_vector_element.h
new file mode 100644
index 0000000..056f08b
--- /dev/null
+++ b/src/tint/ir/load_vector_element.h
@@ -0,0 +1,48 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_IR_LOAD_VECTOR_ELEMENT_H_
+#define SRC_TINT_IR_LOAD_VECTOR_ELEMENT_H_
+
+#include "src/tint/ir/operand_instruction.h"
+#include "src/tint/utils/castable.h"
+
+namespace tint::ir {
+
+/// A load instruction for a single vector element in the IR.
+class LoadVectorElement : public utils::Castable<LoadVectorElement, OperandInstruction<3, 0>> {
+ public:
+ /// The offset in Operands() for the `from` value
+ static constexpr size_t kFromOperandOffset = 0;
+
+ /// The offset in Operands() for the `index` value
+ static constexpr size_t kIndexOperandOffset = 1;
+
+ /// Constructor
+ /// @param result the result value
+ /// @param from the vector pointer
+ /// @param index the new vector element index
+ LoadVectorElement(InstructionResult* result, ir::Value* from, ir::Value* index);
+ ~LoadVectorElement() override;
+
+ /// @returns the vector pointer value
+ ir::Value* From() { return operands_[kFromOperandOffset]; }
+
+ /// @returns the new vector element index
+ ir::Value* Index() { return operands_[kIndexOperandOffset]; }
+};
+
+} // namespace tint::ir
+
+#endif // SRC_TINT_IR_LOAD_VECTOR_ELEMENT_H_
diff --git a/src/tint/ir/load_vector_element_test.cc b/src/tint/ir/load_vector_element_test.cc
new file mode 100644
index 0000000..1117ca3
--- /dev/null
+++ b/src/tint/ir/load_vector_element_test.cc
@@ -0,0 +1,62 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gmock/gmock.h"
+#include "gtest/gtest-spi.h"
+#include "src/tint/ir/builder.h"
+#include "src/tint/ir/instruction.h"
+#include "src/tint/ir/ir_test_helper.h"
+
+namespace tint::ir {
+namespace {
+
+using namespace tint::builtin::fluent_types; // NOLINT
+using namespace tint::number_suffixes; // NOLINT
+
+using IR_LoadVectorElementTest = IRTestHelper;
+
+TEST_F(IR_LoadVectorElementTest, Create) {
+ auto* from = b.Var(ty.ptr<private_, vec3<i32>>());
+ auto* inst = b.LoadVectorElement(from, 2_i);
+
+ ASSERT_TRUE(inst->Is<LoadVectorElement>());
+ ASSERT_EQ(inst->From(), from->Result());
+
+ ASSERT_TRUE(inst->Index()->Is<Constant>());
+ auto index = inst->Index()->As<Constant>()->Value();
+ ASSERT_TRUE(index->Is<constant::Scalar<i32>>());
+ EXPECT_EQ(2_i, index->As<constant::Scalar<i32>>()->ValueAs<i32>());
+}
+
+TEST_F(IR_LoadVectorElementTest, Usage) {
+ auto* from = b.Var(ty.ptr<private_, vec3<i32>>());
+ auto* inst = b.LoadVectorElement(from, 2_i);
+
+ ASSERT_NE(inst->From(), nullptr);
+ EXPECT_THAT(inst->From()->Usages(), testing::UnorderedElementsAre(Usage{inst, 0u}));
+
+ ASSERT_NE(inst->Index(), nullptr);
+ EXPECT_THAT(inst->Index()->Usages(), testing::UnorderedElementsAre(Usage{inst, 1u}));
+}
+
+TEST_F(IR_LoadVectorElementTest, Result) {
+ auto* from = b.Var(ty.ptr<private_, vec3<i32>>());
+ auto* inst = b.LoadVectorElement(from, 2_i);
+
+ EXPECT_TRUE(inst->HasResults());
+ EXPECT_FALSE(inst->HasMultiResults());
+}
+
+} // namespace
+} // namespace tint::ir
diff --git a/src/tint/ir/operand_instruction.cc b/src/tint/ir/operand_instruction.cc
index 9a1f4aa..3b9f38e 100644
--- a/src/tint/ir/operand_instruction.cc
+++ b/src/tint/ir/operand_instruction.cc
@@ -17,6 +17,7 @@
using Op10 = tint::ir::OperandInstruction<1, 0>;
using Op11 = tint::ir::OperandInstruction<1, 1>;
using Op20 = tint::ir::OperandInstruction<2, 0>;
+using Op30 = tint::ir::OperandInstruction<3, 0>;
using Op21 = tint::ir::OperandInstruction<2, 1>;
using Op31 = tint::ir::OperandInstruction<3, 1>;
using Op41 = tint::ir::OperandInstruction<4, 1>;
@@ -24,6 +25,7 @@
TINT_INSTANTIATE_TYPEINFO(Op10);
TINT_INSTANTIATE_TYPEINFO(Op11);
TINT_INSTANTIATE_TYPEINFO(Op20);
+TINT_INSTANTIATE_TYPEINFO(Op30);
TINT_INSTANTIATE_TYPEINFO(Op21);
TINT_INSTANTIATE_TYPEINFO(Op31);
TINT_INSTANTIATE_TYPEINFO(Op41);
diff --git a/src/tint/ir/store_vector_element.cc b/src/tint/ir/store_vector_element.cc
new file mode 100644
index 0000000..1b50f89
--- /dev/null
+++ b/src/tint/ir/store_vector_element.cc
@@ -0,0 +1,32 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/store_vector_element.h"
+#include "src/tint/debug.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ir::StoreVectorElement);
+
+namespace tint::ir {
+
+StoreVectorElement::StoreVectorElement(ir::Value* to, ir::Value* index, ir::Value* value) {
+ flags_.Add(Flag::kSequenced);
+
+ AddOperand(StoreVectorElement::kToOperandOffset, to);
+ AddOperand(StoreVectorElement::kIndexOperandOffset, index);
+ AddOperand(StoreVectorElement::kValueOperandOffset, value);
+}
+
+StoreVectorElement::~StoreVectorElement() = default;
+
+} // namespace tint::ir
diff --git a/src/tint/ir/store_vector_element.h b/src/tint/ir/store_vector_element.h
new file mode 100644
index 0000000..6b2e376
--- /dev/null
+++ b/src/tint/ir/store_vector_element.h
@@ -0,0 +1,54 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_IR_STORE_VECTOR_ELEMENT_H_
+#define SRC_TINT_IR_STORE_VECTOR_ELEMENT_H_
+
+#include "src/tint/ir/operand_instruction.h"
+#include "src/tint/utils/castable.h"
+
+namespace tint::ir {
+
+/// A store instruction for a single vector element in the IR.
+class StoreVectorElement : public utils::Castable<StoreVectorElement, OperandInstruction<3, 0>> {
+ public:
+ /// The offset in Operands() for the `to` value
+ static constexpr size_t kToOperandOffset = 0;
+
+ /// The offset in Operands() for the `index` value
+ static constexpr size_t kIndexOperandOffset = 1;
+
+ /// The offset in Operands() for the `value` value
+ static constexpr size_t kValueOperandOffset = 2;
+
+ /// Constructor
+ /// @param to the vector pointer
+ /// @param index the new vector element index
+ /// @param value the new vector element value
+ StoreVectorElement(ir::Value* to, ir::Value* index, ir::Value* value);
+ ~StoreVectorElement() override;
+
+ /// @returns the vector pointer value
+ ir::Value* To() { return operands_[kToOperandOffset]; }
+
+ /// @returns the new vector element index
+ ir::Value* Index() { return operands_[kIndexOperandOffset]; }
+
+ /// @returns the new vector element value
+ ir::Value* Value() { return operands_[kValueOperandOffset]; }
+};
+
+} // namespace tint::ir
+
+#endif // SRC_TINT_IR_STORE_VECTOR_ELEMENT_H_
diff --git a/src/tint/ir/store_vector_element_test.cc b/src/tint/ir/store_vector_element_test.cc
new file mode 100644
index 0000000..1d71a71
--- /dev/null
+++ b/src/tint/ir/store_vector_element_test.cc
@@ -0,0 +1,70 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gmock/gmock.h"
+#include "gtest/gtest-spi.h"
+#include "src/tint/ir/builder.h"
+#include "src/tint/ir/instruction.h"
+#include "src/tint/ir/ir_test_helper.h"
+
+namespace tint::ir {
+namespace {
+
+using namespace tint::builtin::fluent_types; // NOLINT
+using namespace tint::number_suffixes; // NOLINT
+
+using IR_StoreVectorElementTest = IRTestHelper;
+
+TEST_F(IR_StoreVectorElementTest, Create) {
+ auto* to = b.Var(ty.ptr<private_, vec3<i32>>());
+ auto* inst = b.StoreVectorElement(to, 2_i, 4_i);
+
+ ASSERT_TRUE(inst->Is<StoreVectorElement>());
+ ASSERT_EQ(inst->To(), to->Result());
+
+ ASSERT_TRUE(inst->Index()->Is<Constant>());
+ auto index = inst->Index()->As<Constant>()->Value();
+ ASSERT_TRUE(index->Is<constant::Scalar<i32>>());
+ EXPECT_EQ(2_i, index->As<constant::Scalar<i32>>()->ValueAs<i32>());
+
+ ASSERT_TRUE(inst->Value()->Is<Constant>());
+ auto value = inst->Value()->As<Constant>()->Value();
+ ASSERT_TRUE(value->Is<constant::Scalar<i32>>());
+ EXPECT_EQ(4_i, value->As<constant::Scalar<i32>>()->ValueAs<i32>());
+}
+
+TEST_F(IR_StoreVectorElementTest, Usage) {
+ auto* to = b.Var(ty.ptr<private_, vec3<i32>>());
+ auto* inst = b.StoreVectorElement(to, 2_i, 4_i);
+
+ ASSERT_NE(inst->To(), nullptr);
+ EXPECT_THAT(inst->To()->Usages(), testing::UnorderedElementsAre(Usage{inst, 0u}));
+
+ ASSERT_NE(inst->Index(), nullptr);
+ EXPECT_THAT(inst->Index()->Usages(), testing::UnorderedElementsAre(Usage{inst, 1u}));
+
+ ASSERT_NE(inst->Value(), nullptr);
+ EXPECT_THAT(inst->Value()->Usages(), testing::UnorderedElementsAre(Usage{inst, 2u}));
+}
+
+TEST_F(IR_StoreVectorElementTest, Result) {
+ auto* to = b.Var(ty.ptr<private_, vec3<i32>>());
+ auto* inst = b.StoreVectorElement(to, 2_i, 4_i);
+
+ EXPECT_FALSE(inst->HasResults());
+ EXPECT_FALSE(inst->HasMultiResults());
+}
+
+} // namespace
+} // namespace tint::ir
diff --git a/src/tint/ir/to_program.cc b/src/tint/ir/to_program.cc
index 42fd8c9..0a4abcd 100644
--- a/src/tint/ir/to_program.cc
+++ b/src/tint/ir/to_program.cc
@@ -37,12 +37,14 @@
#include "src/tint/ir/instruction.h"
#include "src/tint/ir/let.h"
#include "src/tint/ir/load.h"
+#include "src/tint/ir/load_vector_element.h"
#include "src/tint/ir/loop.h"
#include "src/tint/ir/module.h"
#include "src/tint/ir/multi_in_block.h"
#include "src/tint/ir/next_iteration.h"
#include "src/tint/ir/return.h"
#include "src/tint/ir/store.h"
+#include "src/tint/ir/store_vector_element.h"
#include "src/tint/ir/switch.h"
#include "src/tint/ir/unary.h"
#include "src/tint/ir/unreachable.h"
@@ -279,26 +281,28 @@
void Instruction(ir::Instruction* inst) {
tint::Switch(
- inst, //
- [&](ir::Access* i) { Access(i); }, //
- [&](ir::Binary* i) { Binary(i); }, //
- [&](ir::BreakIf* i) { BreakIf(i); }, //
- [&](ir::Call* i) { Call(i); }, //
- [&](ir::Continue*) {}, //
- [&](ir::ExitIf*) {}, //
- [&](ir::ExitLoop* i) { ExitLoop(i); }, //
- [&](ir::ExitSwitch* i) { ExitSwitch(i); }, //
- [&](ir::If* i) { If(i); }, //
- [&](ir::Load* l) { Load(l); }, //
- [&](ir::Loop* l) { Loop(l); }, //
- [&](ir::NextIteration*) {}, //
- [&](ir::Return* i) { Return(i); }, //
- [&](ir::Store* i) { Store(i); }, //
- [&](ir::Switch* i) { Switch(i); }, //
- [&](ir::Unary* u) { Unary(u); }, //
- [&](ir::Unreachable*) {}, //
- [&](ir::Var* i) { Var(i); }, //
- [&](ir::Let* i) { Let(i); }, //
+ inst, //
+ [&](ir::Access* i) { Access(i); }, //
+ [&](ir::Binary* i) { Binary(i); }, //
+ [&](ir::BreakIf* i) { BreakIf(i); }, //
+ [&](ir::Call* i) { Call(i); }, //
+ [&](ir::Continue*) {}, //
+ [&](ir::ExitIf*) {}, //
+ [&](ir::ExitLoop* i) { ExitLoop(i); }, //
+ [&](ir::ExitSwitch* i) { ExitSwitch(i); }, //
+ [&](ir::If* i) { If(i); }, //
+ [&](ir::Let* i) { Let(i); }, //
+ [&](ir::Load* l) { Load(l); }, //
+ [&](ir::LoadVectorElement* i) { LoadVectorElement(i); }, //
+ [&](ir::Loop* l) { Loop(l); }, //
+ [&](ir::NextIteration*) {}, //
+ [&](ir::Return* i) { Return(i); }, //
+ [&](ir::Store* i) { Store(i); }, //
+ [&](ir::StoreVectorElement* i) { StoreVectorElement(i); }, //
+ [&](ir::Switch* i) { Switch(i); }, //
+ [&](ir::Unary* u) { Unary(u); }, //
+ [&](ir::Unreachable*) {}, //
+ [&](ir::Var* i) { Var(i); }, //
[&](Default) { UNHANDLED_CASE(inst); });
}
@@ -485,6 +489,12 @@
Append(b.Assign(dst, src));
}
+ void StoreVectorElement(ir::StoreVectorElement* store) {
+ auto* ptr = Expr(store->To());
+ auto* val = Expr(store->Value());
+ Append(b.Assign(VectorMemberAccess(ptr, store->Index()), val));
+ }
+
void Call(ir::Call* call) {
auto args = utils::Transform<4>(call->Args(), [&](ir::Value* arg) {
// Pointer-like arguments are passed by pointer, never reference.
@@ -525,6 +535,11 @@
void Load(ir::Load* l) { Bind(l->Result(), Expr(l->From())); }
+ void LoadVectorElement(ir::LoadVectorElement* load) {
+ auto* ptr = Expr(load->From());
+ Bind(load->Result(), VectorMemberAccess(ptr, load->Index()));
+ }
+
void Unary(ir::Unary* u) {
const ast::Expression* expr = nullptr;
switch (u->Kind()) {
@@ -546,23 +561,7 @@
obj_ty,
[&](const type::Vector* vec) {
TINT_DEFER(obj_ty = vec->type());
- if (auto* c = index->As<ir::Constant>()) {
- switch (c->Value()->ValueAs<int>()) {
- case 0:
- expr = b.MemberAccessor(expr, "x");
- return;
- case 1:
- expr = b.MemberAccessor(expr, "y");
- return;
- case 2:
- expr = b.MemberAccessor(expr, "z");
- return;
- case 3:
- expr = b.MemberAccessor(expr, "w");
- return;
- }
- }
- expr = b.IndexAccessor(expr, Expr(index));
+ expr = VectorMemberAccess(expr, index);
},
[&](const type::Matrix* mat) {
obj_ty = mat->ColumnType();
@@ -1004,6 +1003,22 @@
}
return false;
}
+
+ const ast::Expression* VectorMemberAccess(const ast::Expression* expr, ir::Value* index) {
+ if (auto* c = index->As<ir::Constant>()) {
+ switch (c->Value()->ValueAs<int>()) {
+ case 0:
+ return b.MemberAccessor(expr, "x");
+ case 1:
+ return b.MemberAccessor(expr, "y");
+ case 2:
+ return b.MemberAccessor(expr, "z");
+ case 3:
+ return b.MemberAccessor(expr, "w");
+ }
+ }
+ return b.IndexAccessor(expr, Expr(index));
+ }
};
} // namespace
diff --git a/src/tint/ir/to_program_roundtrip_test.cc b/src/tint/ir/to_program_roundtrip_test.cc
index 8949fef..3d8bc7a 100644
--- a/src/tint/ir/to_program_roundtrip_test.cc
+++ b/src/tint/ir/to_program_roundtrip_test.cc
@@ -1131,6 +1131,99 @@
}
////////////////////////////////////////////////////////////////////////////////
+// Assignment
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(IRToProgramRoundtripTest, Assign_ArrayOfArrayOfArrayAccess_123456) {
+ Test(R"(
+fn e(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var v : array<array<array<i32, 5u>, 5u>, 5u>;
+ v[e(1i)][e(2i)][e(3i)] = v[e(4i)][e(5i)][e(6i)];
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, Assign_ArrayOfArrayOfArrayAccess_261345) {
+ Test(R"(
+fn e(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var v : array<array<array<i32, 5u>, 5u>, 5u>;
+ let v_2 = e(2i);
+ let v_3 = e(6i);
+ v[e(1i)][v_2][e(3i)] = v[e(4i)][e(5i)][v_3];
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, Assign_ArrayOfArrayOfArrayAccess_532614) {
+ Test(R"(
+fn e(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var v : array<array<array<i32, 5u>, 5u>, 5u>;
+ let v_2 = e(5i);
+ let v_3 = e(3i);
+ let v_4 = e(2i);
+ let v_5 = e(6i);
+ v[e(1i)][v_4][v_3] = v[e(4i)][v_2][v_5];
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, Assign_ArrayOfMatrixAccess_123456) {
+ Test(R"(
+fn e(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var v : array<mat3x4<f32>, 5u>;
+ v[e(1i)][e(2i)][e(3i)] = v[e(4i)][e(5i)][e(6i)];
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, Assign_ArrayOfMatrixAccess_261345) {
+ Test(R"(
+fn e(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var v : array<mat3x4<f32>, 5u>;
+ let v_2 = e(2i);
+ let v_3 = e(6i);
+ v[e(1i)][v_2][e(3i)] = v[e(4i)][e(5i)][v_3];
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, Assign_ArrayOfMatrixAccess_532614) {
+ Test(R"(
+fn e(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var v : array<mat3x4<f32>, 5u>;
+ let v_2 = e(5i);
+ let v_3 = e(3i);
+ let v_4 = e(2i);
+ let v_5 = e(6i);
+ v[e(1i)][v_4][v_3] = v[e(4i)][v_2][v_5];
+}
+)");
+}
+
+////////////////////////////////////////////////////////////////////////////////
// Compound assignment
////////////////////////////////////////////////////////////////////////////////
TEST_F(IRToProgramRoundtripTest, CompoundAssign_Increment) {
@@ -1319,6 +1412,89 @@
})");
}
+TEST_F(IRToProgramRoundtripTest, CompoundAssign_ArrayOfMatrixAccess_123456) {
+ Test(R"(
+fn e(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var v : array<mat3x4<f32>, 5u>;
+ v[e(1i)][e(2i)][e(3i)] += v[e(4i)][e(5i)][e(6i)];
+}
+)",
+ R"(fn e(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var v : array<mat3x4<f32>, 5u>;
+ let v_1 = &(v[e(1i)][e(2i)]);
+ let v_2 = e(3i);
+ let v_3 = v[e(4i)][e(5i)][e(6i)];
+ (*(v_1))[v_2] = ((*(v_1))[v_2] + v_3);
+})");
+}
+
+TEST_F(IRToProgramRoundtripTest, CompoundAssign_ArrayOfMatrixAccess_261345) {
+ Test(R"(
+fn e(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var v : array<mat3x4<f32>, 5u>;
+ let v_2 = e(2i);
+ let v_3 = e(6i);
+ v[e(1i)][v_2][e(3i)] += v[e(4i)][e(5i)][v_3];
+}
+)",
+ R"(fn e(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var v : array<mat3x4<f32>, 5u>;
+ let v_2 = e(2i);
+ let v_3 = e(6i);
+ let v_1 = &(v[e(1i)][v_2]);
+ let v_4 = e(3i);
+ let v_5 = v[e(4i)][e(5i)][v_3];
+ (*(v_1))[v_4] = ((*(v_1))[v_4] + v_5);
+})");
+}
+
+TEST_F(IRToProgramRoundtripTest, CompoundAssign_ArrayOfMatrixAccess_532614) {
+ Test(R"(
+fn e(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var v : array<mat3x4<f32>, 5u>;
+ let v_2 = e(5i);
+ let v_3 = e(3i);
+ let v_4 = e(2i);
+ let v_5 = e(6i);
+ v[e(1i)][v_4][v_3] += v[e(4i)][v_2][v_5];
+}
+)",
+ R"(fn e(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var v : array<mat3x4<f32>, 5u>;
+ let v_2 = e(5i);
+ let v_3 = e(3i);
+ let v_4 = e(2i);
+ let v_5 = e(6i);
+ let v_1 = &(v[e(1i)][v_4]);
+ let v_6 = v[e(4i)][v_2][v_5];
+ (*(v_1))[v_3] = ((*(v_1))[v_3] + v_6);
+})");
+}
+
////////////////////////////////////////////////////////////////////////////////
// let
////////////////////////////////////////////////////////////////////////////////
diff --git a/src/tint/ir/transform/var_for_dynamic_index.cc b/src/tint/ir/transform/var_for_dynamic_index.cc
index a41e307..b328437 100644
--- a/src/tint/ir/transform/var_for_dynamic_index.cc
+++ b/src/tint/ir/transform/var_for_dynamic_index.cc
@@ -36,9 +36,11 @@
// The access instruction.
Access* access = nullptr;
// The index of the first dynamic index.
- uint32_t first_dynamic_index = 0;
+ size_t first_dynamic_index = 0;
// The object type that corresponds to the source of the first dynamic index.
const type::Type* dynamic_index_source_type = nullptr;
+ // If the access indexes a vector, then the type of that vector
+ const type::Vector* vector_access_type = nullptr;
};
// A partial access chain that uses constant indices to get to an object that will be
@@ -62,32 +64,50 @@
}
};
+enum class Action { kStop, kContinue };
+
+template <typename CALLBACK>
+void WalkAccessChain(ir::Access* access, CALLBACK&& callback) {
+ auto indices = access->Indices();
+ auto* ty = access->Object()->Type();
+ for (size_t i = 0; i < indices.Length(); i++) {
+ if (callback(i, indices[i], ty) == Action::kStop) {
+ break;
+ }
+ auto* const_idx = indices[i]->As<Constant>();
+ ty = const_idx ? ty->Element(const_idx->Value()->ValueAs<u32>()) : ty->Elements().type;
+ }
+}
+
std::optional<AccessToReplace> ShouldReplace(Access* access) {
if (access->Result()->Type()->Is<type::Pointer>()) {
// No need to modify accesses into pointer types.
return {};
}
- // Find the first dynamic index, if any.
- const auto& indices = access->Indices();
- auto* source_type = access->Object()->Type();
- for (uint32_t i = 0; i < indices.Length(); i++) {
- if (source_type->Is<type::Vector>()) {
- // Stop if we hit a vector, as they can support dynamic accesses.
- return {};
+ std::optional<AccessToReplace> result;
+ WalkAccessChain(access, [&](size_t i, ir::Value* index, const type::Type* ty) {
+ if (auto* vec = ty->As<type::Vector>()) {
+ // If we haven't found a dynamic index before the vector, then the transform doesn't
+ // need to hoist the access into a var as a vector value can be dynamically indexed.
+ // If we have found a dynamic index before the vector, then make a note that we're
+ // indexing a vector as we can't obtain a pointer to a vector element, so this needs to
+ // be handled specially.
+ if (result) {
+ result->vector_access_type = vec;
+ }
+ return Action::kStop;
}
- // Check if the index is dynamic.
- auto* const_idx = indices[i]->As<Constant>();
- if (!const_idx) {
- return AccessToReplace{access, i, source_type};
+ // Check if this is the first dynamic index.
+ if (!result && !index->Is<Constant>()) {
+ result = AccessToReplace{access, i, ty};
}
- // Update the current source object type.
- source_type = source_type->Element(const_idx->Value()->ValueAs<u32>());
- }
- // No dynamic indices were found.
- return {};
+ return Action::kContinue;
+ });
+
+ return result;
}
} // namespace
@@ -141,18 +161,35 @@
// Create a new access instruction using the local variable as the source.
utils::Vector<Value*, 4> indices{access->Indices().Offset(to_replace.first_dynamic_index)};
- auto* new_access =
- builder.Access(ir->Types().ptr(builtin::AddressSpace::kFunction,
- access->Result()->Type(), builtin::Access::kReadWrite),
- local, indices);
- access->ReplaceWith(new_access);
+ const type::Type* access_type = access->Result()->Type();
+ Value* vector_index = nullptr;
+ if (to_replace.vector_access_type) {
+ // The old access indexed the element of a vector.
+ // Its not valid to obtain the address of an element of a vector, so we need to access
+ // up to the vector, then use LoadVectorElement to load the element.
+ // As a vector element is always a scalar, we know the last index of the access is the
+ // index on the vector. Pop that index to obtain the index to pass to
+ // LoadVectorElement(), and perform the rest of the access chain.
+ access_type = to_replace.vector_access_type;
+ vector_index = indices.Pop();
+ }
- // Load from the access to get the final result value.
- auto* load = builder.Load(new_access);
- load->InsertAfter(new_access);
+ ir::Instruction* new_access =
+ builder.Access(ir->Types().ptr(builtin::AddressSpace::kFunction, access_type,
+ builtin::Access::kReadWrite),
+ local, indices);
+ new_access->InsertBefore(access);
+
+ ir::Instruction* load = nullptr;
+ if (to_replace.vector_access_type) {
+ load = builder.LoadVectorElement(new_access->Result(), vector_index);
+ } else {
+ load = builder.Load(new_access);
+ }
// Replace all uses of the old access instruction with the loaded result.
- access->Result()->ReplaceAllUsesWith([&](Usage) { return load->Result(); });
+ access->Result()->ReplaceAllUsesWith(load->Result());
+ access->ReplaceWith(load);
}
}
diff --git a/src/tint/ir/transform/var_for_dynamic_index_test.cc b/src/tint/ir/transform/var_for_dynamic_index_test.cc
index 7848174..66e7591 100644
--- a/src/tint/ir/transform/var_for_dynamic_index_test.cc
+++ b/src/tint/ir/transform/var_for_dynamic_index_test.cc
@@ -108,15 +108,15 @@
func->SetParams({mat, idx});
auto* block = func->Block();
- auto* access = block->Append(b.Access(ty.ptr<function, f32>(), mat, idx, idx));
- auto* load = block->Append(b.Load(access));
+ auto* access = block->Append(b.Access(ty.ptr<function, vec2<f32>>(), mat, idx));
+ auto* load = block->Append(b.LoadVectorElement(access, idx));
block->Append(b.Return(func, load));
auto* expect = R"(
%foo = func(%2:ptr<function, mat2x2<f32>, read_write>, %3:i32):f32 -> %b1 {
%b1 = block {
- %4:ptr<function, f32, read_write> = access %2, %3, %3
- %5:f32 = load %4
+ %4:ptr<function, vec2<f32>, read_write> = access %2, %3
+ %5:f32 = load_vector_element %4, %3
ret %5
}
}
@@ -203,6 +203,32 @@
EXPECT_EQ(expect, str());
}
+TEST_F(IR_VarForDynamicIndexTest, DynamicIndex_VectorValue) {
+ auto* mat = b.FunctionParam(ty.mat2x2<f32>());
+ auto* idx = b.FunctionParam(ty.i32());
+ auto* func = b.Function("foo", ty.vec2<f32>());
+ func->SetParams({mat, idx});
+
+ auto* block = func->Block();
+ auto* access = block->Append(b.Access(ty.f32(), mat, idx, idx));
+ block->Append(b.Return(func, access));
+
+ auto* expect = R"(
+%foo = func(%2:mat2x2<f32>, %3:i32):vec2<f32> -> %b1 {
+ %b1 = block {
+ %4:ptr<function, mat2x2<f32>, read_write> = var, %2
+ %5:ptr<function, vec2<f32>, read_write> = access %4, %3
+ %6:f32 = load_vector_element %5, %3
+ ret %6
+ }
+}
+)";
+
+ Run<VarForDynamicIndex>();
+
+ EXPECT_EQ(expect, str());
+}
+
TEST_F(IR_VarForDynamicIndexTest, AccessChain) {
auto* arr = b.FunctionParam(ty.array(ty.array(ty.array<i32, 4u>(), 4u), 4u));
auto* idx = b.FunctionParam(ty.i32());
@@ -310,8 +336,8 @@
%b1 = block {
%4:mat4x4<f32> = access %2, 1u
%5:ptr<function, mat4x4<f32>, read_write> = var, %4
- %6:ptr<function, f32, read_write> = access %5, %3, 0u
- %7:f32 = load %6
+ %6:ptr<function, vec4<f32>, read_write> = access %5, %3
+ %7:f32 = load_vector_element %6, 0u
ret %7
}
}
diff --git a/src/tint/ir/validate.cc b/src/tint/ir/validate.cc
index b37acd9..73be407 100644
--- a/src/tint/ir/validate.cc
+++ b/src/tint/ir/validate.cc
@@ -35,11 +35,13 @@
#include "src/tint/ir/if.h"
#include "src/tint/ir/let.h"
#include "src/tint/ir/load.h"
+#include "src/tint/ir/load_vector_element.h"
#include "src/tint/ir/loop.h"
#include "src/tint/ir/multi_in_block.h"
#include "src/tint/ir/next_iteration.h"
#include "src/tint/ir/return.h"
#include "src/tint/ir/store.h"
+#include "src/tint/ir/store_vector_element.h"
#include "src/tint/ir/switch.h"
#include "src/tint/ir/swizzle.h"
#include "src/tint/ir/unary.h"
@@ -49,6 +51,7 @@
#include "src/tint/switch.h"
#include "src/tint/type/bool.h"
#include "src/tint/type/pointer.h"
+#include "src/tint/type/vector.h"
#include "src/tint/utils/reverse.h"
#include "src/tint/utils/scoped_assignment.h"
@@ -267,20 +270,22 @@
}
tint::Switch(
- inst, //
- [&](Access* a) { CheckAccess(a); }, //
- [&](Binary* b) { CheckBinary(b); }, //
- [&](Call* c) { CheckCall(c); }, //
- [&](If* if_) { CheckIf(if_); }, //
- [&](Load*) {}, //
- [&](Loop* l) { CheckLoop(l); }, //
- [&](Store*) {}, //
- [&](Switch* s) { CheckSwitch(s); }, //
- [&](Swizzle*) {}, //
- [&](Terminator* b) { CheckTerminator(b); }, //
- [&](Unary* u) { CheckUnary(u); }, //
- [&](Var* var) { CheckVar(var); }, //
- [&](Let* let) { CheckLet(let); }, //
+ inst, //
+ [&](Access* a) { CheckAccess(a); }, //
+ [&](Binary* b) { CheckBinary(b); }, //
+ [&](Call* c) { CheckCall(c); }, //
+ [&](If* if_) { CheckIf(if_); }, //
+ [&](Let* let) { CheckLet(let); }, //
+ [&](Load*) {}, //
+ [&](LoadVectorElement* l) { CheckLoadVectorElement(l); }, //
+ [&](Loop* l) { CheckLoop(l); }, //
+ [&](Store*) {}, //
+ [&](StoreVectorElement* s) { CheckStoreVectorElement(s); }, //
+ [&](Switch* s) { CheckSwitch(s); }, //
+ [&](Swizzle*) {}, //
+ [&](Terminator* b) { CheckTerminator(b); }, //
+ [&](Unary* u) { CheckUnary(u); }, //
+ [&](Var* var) { CheckVar(var); }, //
[&](Default) {
AddError(std::string("missing validation of: ") + inst->TypeInfo().name);
});
@@ -328,15 +333,20 @@
for (size_t i = 0; i < a->Indices().Length(); i++) {
auto err = [&](std::string msg) {
- AddError(a, i + Access::kIndicesOperandOffset, std::move(msg));
+ AddError(a, i + Access::kIndicesOperandOffset, "access: " + msg);
};
auto note = [&](std::string msg) {
- AddNote(a, i + Access::kIndicesOperandOffset, std::move(msg));
+ AddNote(a, i + Access::kIndicesOperandOffset, msg);
};
auto* index = a->Indices()[i];
if (TINT_UNLIKELY(!index->Type()->is_integer_scalar())) {
- err("access: index must be integer, got " + index->Type()->FriendlyName());
+ err("index must be integer, got " + index->Type()->FriendlyName());
+ return;
+ }
+
+ if (is_ptr && ty->Is<type::Vector>()) {
+ err("cannot obtain address of vector element");
return;
}
@@ -347,7 +357,7 @@
// If the index is unsigned, we can skip this.
auto idx = value->ValueAs<AInt>();
if (TINT_UNLIKELY(idx < 0)) {
- err("access: constant index must be positive, got " + std::to_string(idx));
+ err("constant index must be positive, got " + std::to_string(idx));
return;
}
}
@@ -357,18 +367,18 @@
if (TINT_UNLIKELY(!el)) {
// Is index in bounds?
if (auto el_count = ty->Elements().count; el_count != 0 && idx >= el_count) {
- err("access: index out of bounds for type " + current());
+ err("index out of bounds for type " + current());
note("acceptable range: [0.." + std::to_string(el_count - 1) + "]");
return;
}
- err("access: type " + current() + " cannot be indexed");
+ err("type " + current() + " cannot be indexed");
return;
}
ty = el;
} else {
auto* el = ty->Elements().type;
if (TINT_UNLIKELY(!el)) {
- err("access: type " + current() + " cannot be dynamically indexed");
+ err("type " + current() + " cannot be dynamically indexed");
return;
}
ty = el;
@@ -555,6 +565,58 @@
inst = inst->Block()->Parent();
}
}
+
+ void CheckLoadVectorElement(LoadVectorElement* l) {
+ CheckOperandsNotNull(l, //
+ LoadVectorElement::kFromOperandOffset,
+ LoadVectorElement::kIndexOperandOffset, "load_vector_element");
+
+ if (auto* res = l->Result()) {
+ if (auto* el_ty = GetVectorPtrElementType(l, LoadVectorElement::kFromOperandOffset)) {
+ if (res->Type() != el_ty) {
+ AddResultError(l, 0, "result type does not match vector pointer element type");
+ }
+ }
+ }
+ }
+
+ void CheckStoreVectorElement(StoreVectorElement* s) {
+ CheckOperandsNotNull(s, //
+ StoreVectorElement::kToOperandOffset,
+ StoreVectorElement::kValueOperandOffset, "store_vector_element");
+
+ if (auto* value = s->Value()) {
+ if (auto* el_ty = GetVectorPtrElementType(s, StoreVectorElement::kToOperandOffset)) {
+ if (value->Type() != el_ty) {
+ AddError(s, StoreVectorElement::kValueOperandOffset,
+ "value type does not match vector pointer element type");
+ }
+ }
+ }
+ }
+
+ const type::Type* GetVectorPtrElementType(Instruction* inst, size_t idx) {
+ auto* operand = inst->Operands()[idx];
+ if (TINT_UNLIKELY(!operand)) {
+ return nullptr;
+ }
+
+ auto* type = operand->Type();
+ if (TINT_UNLIKELY(!type)) {
+ return nullptr;
+ }
+
+ auto* vec_ptr_ty = type->As<type::Pointer>();
+ if (TINT_LIKELY(vec_ptr_ty)) {
+ auto* vec_ty = vec_ptr_ty->StoreType()->As<type::Vector>();
+ if (TINT_LIKELY(vec_ty)) {
+ return vec_ty->type();
+ }
+ }
+
+ AddError(inst, idx, "operand must be a pointer to vector, got " + type->FriendlyName());
+ return nullptr;
+ }
};
} // namespace
diff --git a/src/tint/ir/validate_test.cc b/src/tint/ir/validate_test.cc
index 7741cd9..acdc714 100644
--- a/src/tint/ir/validate_test.cc
+++ b/src/tint/ir/validate_test.cc
@@ -19,6 +19,7 @@
#include "src/tint/ir/builder.h"
#include "src/tint/ir/ir_test_helper.h"
#include "src/tint/ir/validate.h"
+#include "src/tint/type/array.h"
#include "src/tint/type/matrix.h"
#include "src/tint/type/pointer.h"
#include "src/tint/type/struct.h"
@@ -120,34 +121,6 @@
)");
}
-TEST_F(IR_ValidateTest, Valid_Access_Value) {
- auto* f = b.Function("my_func", ty.void_());
- auto* obj = b.FunctionParam(ty.mat3x2<f32>());
- f->SetParams({obj});
-
- b.With(f->Block(), [&] {
- b.Access(ty.f32(), obj, 1_u, 0_u);
- b.Return(f);
- });
-
- auto res = ir::Validate(mod);
- EXPECT_TRUE(res) << res.Failure().str();
-}
-
-TEST_F(IR_ValidateTest, Valid_Access_Ptr) {
- auto* f = b.Function("my_func", ty.void_());
- auto* obj = b.FunctionParam(ty.ptr<private_, mat3x2<f32>>());
- f->SetParams({obj});
-
- b.With(f->Block(), [&] {
- b.Access(ty.ptr<private_, f32>(), obj, 1_u, 0_u);
- b.Return(f);
- });
-
- auto res = ir::Validate(mod);
- EXPECT_TRUE(res) << res.Failure().str();
-}
-
TEST_F(IR_ValidateTest, Access_NegativeIndex) {
auto* f = b.Function("my_func", ty.void_());
auto* obj = b.FunctionParam(ty.vec3<f32>());
@@ -214,7 +187,7 @@
TEST_F(IR_ValidateTest, Access_OOB_Index_Ptr) {
auto* f = b.Function("my_func", ty.void_());
- auto* obj = b.FunctionParam(ty.ptr<private_, mat3x2<f32>>());
+ auto* obj = b.FunctionParam(ty.ptr<private_, array<array<f32, 2>, 3>>());
f->SetParams({obj});
b.With(f->Block(), [&] {
@@ -225,7 +198,7 @@
auto res = ir::Validate(mod);
ASSERT_FALSE(res);
EXPECT_EQ(res.Failure().str(),
- R"(:3:55 error: access: index out of bounds for type ptr<vec2<f32>>
+ R"(:3:55 error: access: index out of bounds for type ptr<array<f32, 2>>
%3:ptr<private, f32, read_write> = access %2, 1u, 3u
^^
@@ -238,7 +211,7 @@
^^
note: # Disassembly
-%my_func = func(%2:ptr<private, mat3x2<f32>, read_write>):void -> %b1 {
+%my_func = func(%2:ptr<private, array<array<f32, 2>, 3>, read_write>):void -> %b1 {
%b1 = block {
%3:ptr<private, f32, read_write> = access %2, 1u, 3u
ret
@@ -424,7 +397,7 @@
TEST_F(IR_ValidateTest, Access_Incorrect_Type_Ptr_Ptr) {
auto* f = b.Function("my_func", ty.void_());
- auto* obj = b.FunctionParam(ty.ptr<private_, mat3x2<f32>>());
+ auto* obj = b.FunctionParam(ty.ptr<private_, array<array<f32, 2>, 3>>());
f->SetParams({obj});
b.With(f->Block(), [&] {
@@ -445,7 +418,7 @@
^^^^^^^^^^^
note: # Disassembly
-%my_func = func(%2:ptr<private, mat3x2<f32>, read_write>):void -> %b1 {
+%my_func = func(%2:ptr<private, array<array<f32, 2>, 3>, read_write>):void -> %b1 {
%b1 = block {
%3:ptr<private, i32, read_write> = access %2, 1u, 1u
ret
@@ -456,7 +429,7 @@
TEST_F(IR_ValidateTest, Access_Incorrect_Type_Ptr_Value) {
auto* f = b.Function("my_func", ty.void_());
- auto* obj = b.FunctionParam(ty.ptr<private_, mat3x2<f32>>());
+ auto* obj = b.FunctionParam(ty.ptr<private_, array<array<f32, 2>, 3>>());
f->SetParams({obj});
b.With(f->Block(), [&] {
@@ -477,6 +450,68 @@
^^^^^^^^^^^
note: # Disassembly
+%my_func = func(%2:ptr<private, array<array<f32, 2>, 3>, read_write>):void -> %b1 {
+ %b1 = block {
+ %3:f32 = access %2, 1u, 1u
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_ValidateTest, Access_IndexVectorPtr) {
+ auto* f = b.Function("my_func", ty.void_());
+ auto* obj = b.FunctionParam(ty.ptr<private_, vec3<f32>>());
+ f->SetParams({obj});
+
+ b.With(f->Block(), [&] {
+ b.Access(ty.f32(), obj, 1_u);
+ b.Return(f);
+ });
+
+ auto res = ir::Validate(mod);
+ ASSERT_FALSE(res);
+ EXPECT_EQ(res.Failure().str(),
+ R"(:3:25 error: access: cannot obtain address of vector element
+ %3:f32 = access %2, 1u
+ ^^
+
+:2:3 note: In block
+ %b1 = block {
+ ^^^^^^^^^^^
+
+note: # Disassembly
+%my_func = func(%2:ptr<private, vec3<f32>, read_write>):void -> %b1 {
+ %b1 = block {
+ %3:f32 = access %2, 1u
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_ValidateTest, Access_IndexVectorPtr_ViaMatrixPtr) {
+ auto* f = b.Function("my_func", ty.void_());
+ auto* obj = b.FunctionParam(ty.ptr<private_, mat3x2<f32>>());
+ f->SetParams({obj});
+
+ b.With(f->Block(), [&] {
+ b.Access(ty.f32(), obj, 1_u, 1_u);
+ b.Return(f);
+ });
+
+ auto res = ir::Validate(mod);
+ ASSERT_FALSE(res);
+ EXPECT_EQ(res.Failure().str(),
+ R"(:3:29 error: access: cannot obtain address of vector element
+ %3:f32 = access %2, 1u, 1u
+ ^^
+
+:2:3 note: In block
+ %b1 = block {
+ ^^^^^^^^^^^
+
+note: # Disassembly
%my_func = func(%2:ptr<private, mat3x2<f32>, read_write>):void -> %b1 {
%b1 = block {
%3:f32 = access %2, 1u, 1u
@@ -486,6 +521,34 @@
)");
}
+TEST_F(IR_ValidateTest, Access_IndexVector) {
+ auto* f = b.Function("my_func", ty.void_());
+ auto* obj = b.FunctionParam(ty.vec3<f32>());
+ f->SetParams({obj});
+
+ b.With(f->Block(), [&] {
+ b.Access(ty.f32(), obj, 1_u);
+ b.Return(f);
+ });
+
+ auto res = ir::Validate(mod);
+ ASSERT_TRUE(res) << res.Failure().str();
+}
+
+TEST_F(IR_ValidateTest, Access_IndexVector_ViaMatrix) {
+ auto* f = b.Function("my_func", ty.void_());
+ auto* obj = b.FunctionParam(ty.mat3x2<f32>());
+ f->SetParams({obj});
+
+ b.With(f->Block(), [&] {
+ b.Access(ty.f32(), obj, 1_u, 1_u);
+ b.Return(f);
+ });
+
+ auto res = ir::Validate(mod);
+ ASSERT_TRUE(res) << res.Failure().str();
+}
+
TEST_F(IR_ValidateTest, Block_TerminatorInMiddle) {
auto* f = b.Function("my_func", ty.void_());
@@ -2514,5 +2577,195 @@
)");
}
+TEST_F(IR_ValidateTest, LoadVectorElement_NullResult) {
+ auto* f = b.Function("my_func", ty.void_());
+
+ b.With(f->Block(), [&] {
+ auto* var = b.Var(ty.ptr<function, vec3<f32>>());
+ b.Append(mod.instructions.Create<ir::LoadVectorElement>(nullptr, var->Result(),
+ b.Constant(1_i)));
+ b.Return(f);
+ });
+
+ auto res = ir::Validate(mod);
+ ASSERT_FALSE(res);
+ EXPECT_EQ(res.Failure().str(), R"(:4:5 error: instruction result is undefined
+ undef = load_vector_element %2, 1i
+ ^^^^^
+
+:2:3 note: In block
+ %b1 = block {
+ ^^^^^^^^^^^
+
+note: # Disassembly
+%my_func = func():void -> %b1 {
+ %b1 = block {
+ %2:ptr<function, vec3<f32>, read_write> = var
+ undef = load_vector_element %2, 1i
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_ValidateTest, LoadVectorElement_NullFrom) {
+ auto* f = b.Function("my_func", ty.void_());
+
+ b.With(f->Block(), [&] {
+ b.Append(mod.instructions.Create<ir::LoadVectorElement>(b.InstructionResult(ty.f32()),
+ nullptr, b.Constant(1_i)));
+ b.Return(f);
+ });
+
+ auto res = ir::Validate(mod);
+ ASSERT_FALSE(res);
+ EXPECT_EQ(res.Failure().str(), R"(:3:34 error: load_vector_element: operand is undefined
+ %2:f32 = load_vector_element undef, 1i
+ ^^^^^
+
+:2:3 note: In block
+ %b1 = block {
+ ^^^^^^^^^^^
+
+note: # Disassembly
+%my_func = func():void -> %b1 {
+ %b1 = block {
+ %2:f32 = load_vector_element undef, 1i
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_ValidateTest, LoadVectorElement_NullIndex) {
+ auto* f = b.Function("my_func", ty.void_());
+
+ b.With(f->Block(), [&] {
+ auto* var = b.Var(ty.ptr<function, vec3<f32>>());
+ b.Append(mod.instructions.Create<ir::LoadVectorElement>(b.InstructionResult(ty.f32()),
+ var->Result(), nullptr));
+ b.Return(f);
+ });
+
+ auto res = ir::Validate(mod);
+ ASSERT_FALSE(res);
+ EXPECT_EQ(res.Failure().str(), R"(:4:38 error: load_vector_element: operand is undefined
+ %3:f32 = load_vector_element %2, undef
+ ^^^^^
+
+:2:3 note: In block
+ %b1 = block {
+ ^^^^^^^^^^^
+
+note: # Disassembly
+%my_func = func():void -> %b1 {
+ %b1 = block {
+ %2:ptr<function, vec3<f32>, read_write> = var
+ %3:f32 = load_vector_element %2, undef
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_ValidateTest, StoreVectorElement_NullTo) {
+ auto* f = b.Function("my_func", ty.void_());
+
+ b.With(f->Block(), [&] {
+ b.Append(mod.instructions.Create<ir::StoreVectorElement>(nullptr, b.Constant(1_i),
+ b.Constant(2_i)));
+ b.Return(f);
+ });
+
+ auto res = ir::Validate(mod);
+ ASSERT_FALSE(res);
+ EXPECT_EQ(res.Failure().str(), R"(:3:32 error: store_vector_element: operand is undefined
+ store_vector_element undef undef, 1i, 2i
+ ^^^^^
+
+:2:3 note: In block
+ %b1 = block {
+ ^^^^^^^^^^^
+
+note: # Disassembly
+%my_func = func():void -> %b1 {
+ %b1 = block {
+ store_vector_element undef undef, 1i, 2i
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_ValidateTest, StoreVectorElement_NullIndex) {
+ auto* f = b.Function("my_func", ty.void_());
+
+ b.With(f->Block(), [&] {
+ auto* var = b.Var(ty.ptr<function, vec3<f32>>());
+ b.Append(mod.instructions.Create<ir::StoreVectorElement>(var->Result(), nullptr,
+ b.Constant(2_i)));
+ b.Return(f);
+ });
+
+ auto res = ir::Validate(mod);
+ ASSERT_FALSE(res);
+ EXPECT_EQ(res.Failure().str(), R"(:4:33 error: store_vector_element: operand is undefined
+ store_vector_element %2 %2, undef, 2i
+ ^^^^^
+
+:2:3 note: In block
+ %b1 = block {
+ ^^^^^^^^^^^
+
+:4:40 error: value type does not match vector pointer element type
+ store_vector_element %2 %2, undef, 2i
+ ^^
+
+:2:3 note: In block
+ %b1 = block {
+ ^^^^^^^^^^^
+
+note: # Disassembly
+%my_func = func():void -> %b1 {
+ %b1 = block {
+ %2:ptr<function, vec3<f32>, read_write> = var
+ store_vector_element %2 %2, undef, 2i
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_ValidateTest, StoreVectorElement_NullValue) {
+ auto* f = b.Function("my_func", ty.void_());
+
+ b.With(f->Block(), [&] {
+ auto* var = b.Var(ty.ptr<function, vec3<f32>>());
+ b.Append(mod.instructions.Create<ir::StoreVectorElement>(var->Result(), b.Constant(1_i),
+ nullptr));
+ b.Return(f);
+ });
+
+ auto res = ir::Validate(mod);
+ ASSERT_FALSE(res);
+ EXPECT_EQ(res.Failure().str(), R"(:4:37 error: store_vector_element: operand is undefined
+ store_vector_element %2 %2, 1i, undef
+ ^^^^^
+
+:2:3 note: In block
+ %b1 = block {
+ ^^^^^^^^^^^
+
+note: # Disassembly
+%my_func = func():void -> %b1 {
+ %b1 = block {
+ %2:ptr<function, vec3<f32>, read_write> = var
+ store_vector_element %2 %2, 1i, undef
+ ret
+ }
+}
+)");
+}
+
} // namespace
} // namespace tint::ir
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir.cc b/src/tint/writer/spirv/ir/generator_impl_ir.cc
index 5be2c60..1eda18e 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir.cc
@@ -736,24 +736,26 @@
void GeneratorImplIr::EmitBlockInstructions(ir::Block* block) {
for (auto* inst : *block) {
Switch(
- inst, //
- [&](ir::Access* a) { EmitAccess(a); }, //
- [&](ir::Binary* b) { EmitBinary(b); }, //
- [&](ir::Bitcast* b) { EmitBitcast(b); }, //
- [&](ir::BuiltinCall* b) { EmitBuiltinCall(b); }, //
- [&](ir::Construct* c) { EmitConstruct(c); }, //
- [&](ir::Convert* c) { EmitConvert(c); }, //
- [&](ir::Load* l) { EmitLoad(l); }, //
- [&](ir::Loop* l) { EmitLoop(l); }, //
- [&](ir::Switch* sw) { EmitSwitch(sw); }, //
- [&](ir::Swizzle* s) { EmitSwizzle(s); }, //
- [&](ir::Store* s) { EmitStore(s); }, //
- [&](ir::UserCall* c) { EmitUserCall(c); }, //
- [&](ir::Unary* u) { EmitUnary(u); }, //
- [&](ir::Var* v) { EmitVar(v); }, //
- [&](ir::Let* l) { EmitLet(l); }, //
- [&](ir::If* i) { EmitIf(i); }, //
- [&](ir::Terminator* t) { EmitTerminator(t); }, //
+ inst, //
+ [&](ir::Access* a) { EmitAccess(a); }, //
+ [&](ir::Binary* b) { EmitBinary(b); }, //
+ [&](ir::Bitcast* b) { EmitBitcast(b); }, //
+ [&](ir::BuiltinCall* b) { EmitBuiltinCall(b); }, //
+ [&](ir::Construct* c) { EmitConstruct(c); }, //
+ [&](ir::Convert* c) { EmitConvert(c); }, //
+ [&](ir::Load* l) { EmitLoad(l); }, //
+ [&](ir::LoadVectorElement* l) { EmitLoadVectorElement(l); }, //
+ [&](ir::Loop* l) { EmitLoop(l); }, //
+ [&](ir::Switch* sw) { EmitSwitch(sw); }, //
+ [&](ir::Swizzle* s) { EmitSwizzle(s); }, //
+ [&](ir::Store* s) { EmitStore(s); }, //
+ [&](ir::StoreVectorElement* s) { EmitStoreVectorElement(s); }, //
+ [&](ir::UserCall* c) { EmitUserCall(c); }, //
+ [&](ir::Unary* u) { EmitUnary(u); }, //
+ [&](ir::Var* v) { EmitVar(v); }, //
+ [&](ir::Let* l) { EmitLet(l); }, //
+ [&](ir::If* i) { EmitIf(i); }, //
+ [&](ir::Terminator* t) { EmitTerminator(t); }, //
[&](Default) {
TINT_ICE(Writer, diagnostics_)
<< "unimplemented instruction: " << inst->TypeInfo().name;
@@ -1375,6 +1377,18 @@
{Type(load->Result()->Type()), Value(load), Value(load->From())});
}
+void GeneratorImplIr::EmitLoadVectorElement(ir::LoadVectorElement* load) {
+ auto* vec_ptr_ty = load->From()->Type()->As<type::Pointer>();
+ auto* el_ty = load->Result()->Type();
+ auto* el_ptr_ty = ir_->Types().ptr(vec_ptr_ty->AddressSpace(), el_ty, vec_ptr_ty->Access());
+ auto el_ptr_id = module_.NextId();
+ current_function_.push_inst(
+ spv::Op::OpAccessChain,
+ {Type(el_ptr_ty), el_ptr_id, Value(load->From()), Value(load->Index())});
+ current_function_.push_inst(spv::Op::OpLoad,
+ {Type(load->Result()->Type()), Value(load), el_ptr_id});
+}
+
void GeneratorImplIr::EmitLoop(ir::Loop* loop) {
auto init_label = loop->HasInitializer() ? Label(loop->Initializer()) : 0;
auto body_label = Label(loop->Body());
@@ -1482,6 +1496,17 @@
current_function_.push_inst(spv::Op::OpStore, {Value(store->To()), Value(store->From())});
}
+void GeneratorImplIr::EmitStoreVectorElement(ir::StoreVectorElement* store) {
+ auto* vec_ptr_ty = store->To()->Type()->As<type::Pointer>();
+ auto* el_ty = store->Value()->Type();
+ auto* el_ptr_ty = ir_->Types().ptr(vec_ptr_ty->AddressSpace(), el_ty, vec_ptr_ty->Access());
+ auto el_ptr_id = module_.NextId();
+ current_function_.push_inst(
+ spv::Op::OpAccessChain,
+ {Type(el_ptr_ty), el_ptr_id, Value(store->To()), Value(store->Index())});
+ current_function_.push_inst(spv::Op::OpStore, {el_ptr_id, Value(store->Value())});
+}
+
void GeneratorImplIr::EmitUnary(ir::Unary* unary) {
auto id = Value(unary);
auto* ty = unary->Result()->Type();
@@ -1583,10 +1608,6 @@
void GeneratorImplIr::EmitLet(ir::Let* let) {
auto id = Value(let->Value());
values_.Add(let->Result(), id);
- // Set the name if present, and the source value isn't named
- // if (auto name = ir_->NameOf(let->Result()); name && !ir_->NameOf(let->Value())) {
- // module_.PushDebug(spv::Op::OpName, {id, Operand(name.Name())});
- // }
}
void GeneratorImplIr::EmitExitPhis(ir::ControlInstruction* inst) {
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir.h b/src/tint/writer/spirv/ir/generator_impl_ir.h
index 3fcbcd9..f13c667 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir.h
+++ b/src/tint/writer/spirv/ir/generator_impl_ir.h
@@ -45,10 +45,12 @@
class If;
class Let;
class Load;
+class LoadVectorElement;
class Loop;
class Module;
class MultiInBlock;
class Store;
+class StoreVectorElement;
class Switch;
class Swizzle;
class Terminator;
@@ -211,6 +213,10 @@
/// @param load the load instruction to emit
void EmitLoad(ir::Load* load);
+ /// Emit a load vector element instruction.
+ /// @param load the load vector element instruction to emit
+ void EmitLoadVectorElement(ir::LoadVectorElement* load);
+
/// Emit a loop instruction.
/// @param loop the loop instruction to emit
void EmitLoop(ir::Loop* loop);
@@ -219,6 +225,10 @@
/// @param store the store instruction to emit
void EmitStore(ir::Store* store);
+ /// Emit a store vector element instruction.
+ /// @param store the store vector element instruction to emit
+ void EmitStoreVectorElement(ir::StoreVectorElement* store);
+
/// Emit a switch instruction.
/// @param swtch the switch instruction to emit
void EmitSwitch(ir::Switch* swtch);
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_access_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_access_test.cc
index 86d09f7..ca171d1 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_access_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_access_test.cc
@@ -84,7 +84,7 @@
b.With(func->Block(), [&] {
auto* mat_var = b.Var("mat", ty.ptr<function, mat2x2<f32>>());
auto* result_vector = b.Access(ty.ptr<function, vec2<f32>>(), mat_var, 1_u);
- auto* result_scalar = b.Access(ty.ptr<function, f32>(), mat_var, 1_u, 0_u);
+ auto* result_scalar = b.LoadVectorElement(result_vector, 0_u);
b.Return(func);
mod.SetName(result_vector, "result_vector");
mod.SetName(result_scalar, "result_scalar");
@@ -92,7 +92,8 @@
ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST("%result_vector = OpAccessChain %_ptr_Function_v2float %mat %uint_1");
- EXPECT_INST("%result_scalar = OpAccessChain %_ptr_Function_float %mat %uint_1 %uint_0");
+ EXPECT_INST("%14 = OpAccessChain %_ptr_Function_float %result_vector %uint_0");
+ EXPECT_INST("%result_scalar = OpLoad %float %14");
}
TEST_F(SpvGeneratorImplTest, Access_Matrix_Pointer_DynamicIndex) {
@@ -102,7 +103,7 @@
b.With(func->Block(), [&] {
auto* mat_var = b.Var("mat", ty.ptr<function, mat2x2<f32>>());
auto* result_vector = b.Access(ty.ptr<function, vec2<f32>>(), mat_var, idx);
- auto* result_scalar = b.Access(ty.ptr<function, f32>(), mat_var, idx, idx);
+ auto* result_scalar = b.LoadVectorElement(result_vector, idx);
b.Return(func);
mod.SetName(result_vector, "result_vector");
mod.SetName(result_scalar, "result_scalar");
@@ -110,7 +111,8 @@
ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST("%result_vector = OpAccessChain %_ptr_Function_v2float %mat %idx");
- EXPECT_INST("%result_scalar = OpAccessChain %_ptr_Function_float %mat %idx %idx");
+ EXPECT_INST("%14 = OpAccessChain %_ptr_Function_float %result_vector %idx");
+ EXPECT_INST("%result_scalar = OpLoad %float %14");
}
TEST_F(SpvGeneratorImplTest, Access_Vector_Value_ConstantIndex) {
@@ -142,34 +144,6 @@
EXPECT_INST("%result = OpVectorExtractDynamic %int %vec %idx");
}
-TEST_F(SpvGeneratorImplTest, Access_Vector_Pointer_ConstantIndex) {
- auto* func = b.Function("foo", ty.void_());
- b.With(func->Block(), [&] {
- auto* vec_var = b.Var("vec", ty.ptr<function, vec4<i32>>());
- auto* result = b.Access(ty.ptr<function, i32>(), vec_var, 1_u);
- b.Return(func);
- mod.SetName(result, "result");
- });
-
- ASSERT_TRUE(Generate()) << Error() << output_;
- EXPECT_INST("%result = OpAccessChain %_ptr_Function_int %vec %uint_1");
-}
-
-TEST_F(SpvGeneratorImplTest, Access_Vector_Pointer_DynamicIndex) {
- auto* idx = b.FunctionParam("idx", ty.i32());
- auto* func = b.Function("foo", ty.void_());
- func->SetParams({idx});
- b.With(func->Block(), [&] {
- auto* vec_var = b.Var("vec", ty.ptr<function, vec4<i32>>());
- auto* result = b.Access(ty.ptr<function, i32>(), vec_var, idx);
- b.Return(func);
- mod.SetName(result, "result");
- });
-
- ASSERT_TRUE(Generate()) << Error() << output_;
- EXPECT_INST("%result = OpAccessChain %_ptr_Function_int %vec %idx");
-}
-
TEST_F(SpvGeneratorImplTest, Access_NestedVector_Value_DynamicIndex) {
auto* val = b.FunctionParam("arr", ty.array(ty.array(ty.vec4(ty.i32()), 4), 4));
auto* idx = b.FunctionParam("idx", ty.i32());
@@ -218,7 +192,7 @@
b.With(func->Block(), [&] {
auto* str_var = b.Var("str", ty.ptr(function, str, read_write));
auto* result_a = b.Access(ty.ptr<function, f32>(), str_var, 0_u);
- auto* result_b = b.Access(ty.ptr<function, i32>(), str_var, 1_u, 2_u);
+ auto* result_b = b.Access(ty.ptr<function, vec4<i32>>(), str_var, 1_u);
b.Return(func);
mod.SetName(result_a, "result_a");
mod.SetName(result_b, "result_b");
@@ -226,7 +200,65 @@
ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST("%result_a = OpAccessChain %_ptr_Function_float %str %uint_0");
- EXPECT_INST("%result_b = OpAccessChain %_ptr_Function_int %str %uint_1 %uint_2");
+ EXPECT_INST("%result_b = OpAccessChain %_ptr_Function_v4int %str %uint_1");
+}
+
+TEST_F(SpvGeneratorImplTest, LoadVectorElement_ConstantIndex) {
+ auto* func = b.Function("foo", ty.void_());
+ b.With(func->Block(), [&] {
+ auto* vec_var = b.Var("vec", ty.ptr<function, vec4<i32>>());
+ auto* result = b.LoadVectorElement(vec_var, 1_u);
+ b.Return(func);
+ mod.SetName(result, "result");
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%9 = OpAccessChain %_ptr_Function_int %vec %uint_1");
+ EXPECT_INST("%result = OpLoad %int %9");
+}
+
+TEST_F(SpvGeneratorImplTest, LoadVectorElement_DynamicIndex) {
+ auto* idx = b.FunctionParam("idx", ty.i32());
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({idx});
+ b.With(func->Block(), [&] {
+ auto* vec_var = b.Var("vec", ty.ptr<function, vec4<i32>>());
+ auto* result = b.LoadVectorElement(vec_var, idx);
+ b.Return(func);
+ mod.SetName(result, "result");
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%10 = OpAccessChain %_ptr_Function_int %vec %idx");
+ EXPECT_INST("%result = OpLoad %int %10");
+}
+
+TEST_F(SpvGeneratorImplTest, StoreVectorElement_ConstantIndex) {
+ auto* func = b.Function("foo", ty.void_());
+ b.With(func->Block(), [&] {
+ auto* vec_var = b.Var("vec", ty.ptr<function, vec4<i32>>());
+ b.StoreVectorElement(vec_var, 1_u, b.Constant(42_i));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%9 = OpAccessChain %_ptr_Function_int %vec %uint_1");
+ EXPECT_INST("OpStore %9 %int_42");
+}
+
+TEST_F(SpvGeneratorImplTest, StoreVectorElement_DynamicIndex) {
+ auto* idx = b.FunctionParam("idx", ty.i32());
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({idx});
+ b.With(func->Block(), [&] {
+ auto* vec_var = b.Var("vec", ty.ptr<function, vec4<i32>>());
+ b.StoreVectorElement(vec_var, idx, b.Constant(42_i));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%10 = OpAccessChain %_ptr_Function_int %vec %idx");
+ EXPECT_INST("OpStore %10 %int_42");
}
} // namespace