transform::VarForDynamicIndex: Operate on matrices
Much like arrays, the SPIR-V writer cannot cope with dynamic indexing of matrices.
Fixed: tint:825
Change-Id: Ia111f15e0cf6fbd441861a4b3455a33b82b692ab
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/51781
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: David Neto <dneto@google.com>
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Auto-Submit: Ben Clayton <bclayton@google.com>
diff --git a/src/transform/var_for_dynamic_index.cc b/src/transform/var_for_dynamic_index.cc
index e9af1e7..ca24645 100644
--- a/src/transform/var_for_dynamic_index.cc
+++ b/src/transform/var_for_dynamic_index.cc
@@ -37,7 +37,7 @@
if (auto* access_expr = node->As<ast::ArrayAccessorExpression>()) {
// Found an array accessor expression
auto* index_expr = access_expr->idx_expr();
- auto* array_expr = access_expr->array();
+ auto* indexed_expr = access_expr->array();
if (index_expr->Is<ast::ScalarConstructorExpression>()) {
// Index expression is a literal value. As this isn't a dynamic index,
@@ -45,20 +45,20 @@
continue;
}
- auto* array = ctx.src->Sem().Get(array_expr);
- if (!array->Type()->Is<sem::Array>()) {
- // This transform currently only cares about arrays.
+ auto* indexed = ctx.src->Sem().Get(indexed_expr);
+ if (!indexed->Type()->IsAnyOf<sem::Array, sem::Matrix>()) {
+ // This transform currently only cares about array and matrices.
continue;
}
- auto* stmt = array->Stmt(); // Statement that owns the expression
- auto* block = stmt->Block(); // Block that owns the statement
+ auto* stmt = indexed->Stmt(); // Statement that owns the expression
+ auto* block = stmt->Block(); // Block that owns the statement
// Construct a `var` declaration to hold the value in memory.
- auto* ty = CreateASTTypeFor(&ctx, array->Type());
- auto var_name = ctx.dst->Symbols().New("var_for_array");
+ auto* ty = CreateASTTypeFor(&ctx, indexed->Type());
+ auto var_name = ctx.dst->Symbols().New("var_for_index");
auto* var_decl = ctx.dst->Decl(ctx.dst->Var(
- var_name, ty, ast::StorageClass::kNone, ctx.Clone(array_expr)));
+ var_name, ty, ast::StorageClass::kNone, ctx.Clone(indexed_expr)));
// Insert the `var` declaration before the statement that performs the
// indexing. Note that for indexing chains, AST node ordering guarantees
@@ -67,7 +67,7 @@
var_decl);
// Replace the original index expression with the new `var`.
- ctx.Replace(array_expr, ctx.dst->Expr(var_name));
+ ctx.Replace(indexed_expr, ctx.dst->Expr(var_name));
}
}
diff --git a/src/transform/var_for_dynamic_index.h b/src/transform/var_for_dynamic_index.h
index 7d928c6..86fa30a 100644
--- a/src/transform/var_for_dynamic_index.h
+++ b/src/transform/var_for_dynamic_index.h
@@ -23,10 +23,10 @@
namespace tint {
namespace transform {
-/// A transform that extracts array values that are dynamically indexed to a
-/// temporary `var` local before performing the index. This transform is used by
-/// the SPIR-V writer for dynamically indexing arrays, as there is no SPIR-V
-/// instruction that can dynamically index a non-pointer composite.
+/// A transform that extracts array and matrix values that are dynamically
+/// indexed to a temporary `var` local before performing the index. This
+/// transform is used by the SPIR-V writer as there is no SPIR-V instruction
+/// that can dynamically index a non-pointer composite.
class VarForDynamicIndex : public Transform {
public:
/// Constructor
diff --git a/src/transform/var_for_dynamic_index_test.cc b/src/transform/var_for_dynamic_index_test.cc
index 4626dea..29dd84e 100644
--- a/src/transform/var_for_dynamic_index_test.cc
+++ b/src/transform/var_for_dynamic_index_test.cc
@@ -44,8 +44,8 @@
fn f() {
var i : i32;
let p : array<i32, 4> = array<i32, 4>(1, 2, 3, 4);
- var var_for_array : array<i32, 4> = p;
- let x : i32 = var_for_array[i];
+ var var_for_index : array<i32, 4> = p;
+ let x : i32 = var_for_index[i];
}
)";
@@ -67,7 +67,7 @@
// TODO(bclayton): Optimize this case:
// This output is not as efficient as it could be.
// We only actually need to hoist the inner-most array to a `var`
- // (`var_for_array`), as later indexing operations will be working with
+ // (`var_for_index`), as later indexing operations will be working with
// references, not values.
auto* expect = R"(
@@ -75,9 +75,9 @@
var i : i32;
var j : i32;
let p : array<array<i32, 2>, 2> = array<array<i32, 2>, 2>(array<i32, 2>(1, 2), array<i32, 2>(3, 4));
- var var_for_array : array<array<i32, 2>, 2> = p;
- var var_for_array_1 : array<i32, 2> = var_for_array[i];
- let x : i32 = var_for_array_1[j];
+ var var_for_index : array<array<i32, 2>, 2> = p;
+ var var_for_index_1 : array<i32, 2> = var_for_index[i];
+ let x : i32 = var_for_index_1[j];
}
)";
diff --git a/src/writer/spirv/builder_accessor_expression_test.cc b/src/writer/spirv/builder_accessor_expression_test.cc
index 7e1cb1d..f1afe46 100644
--- a/src/writer/spirv/builder_accessor_expression_test.cc
+++ b/src/writer/spirv/builder_accessor_expression_test.cc
@@ -22,9 +22,9 @@
using BuilderTest = TestHelper;
-TEST_F(BuilderTest, ArrayAccessor) {
- // vec3<f32> ary;
- // ary[1] -> ptr<f32>
+TEST_F(BuilderTest, ArrayAccessor_VectorRef_Literal) {
+ // var ary : vec3<f32>;
+ // ary[1] -> ref<f32>
auto* var = Var("ary", ty.vec3<f32>());
@@ -57,10 +57,10 @@
)");
}
-TEST_F(BuilderTest, Accessor_Array_LoadIndex) {
- // ary : vec3<f32>;
- // idx : i32;
- // ary[idx] -> ptr<f32>
+TEST_F(BuilderTest, ArrayAccessor_VectorRef_Dynamic) {
+ // var ary : vec3<f32>;
+ // var idx : i32;
+ // ary[idx] -> ref<f32>
auto* var = Var("ary", ty.vec3<f32>());
auto* idx = Var("idx", ty.i32());
@@ -98,9 +98,9 @@
)");
}
-TEST_F(BuilderTest, ArrayAccessor_Dynamic) {
- // vec3<f32> ary;
- // ary[1 + 2] -> ptr<f32>
+TEST_F(BuilderTest, ArrayAccessor_VectorRef_Dynamic2) {
+ // var ary : vec3<f32>;
+ // ary[1 + 2] -> ref<f32>
auto* var = Var("ary", ty.vec3<f32>());
@@ -134,10 +134,10 @@
)");
}
-TEST_F(BuilderTest, ArrayAccessor_MultiLevel) {
+TEST_F(BuilderTest, ArrayAccessor_ArrayRef_MultiLevel) {
auto* ary4 = ty.array(ty.vec3<f32>(), 4);
- // ary = array<vec3<f32>, 4>
+ // var ary : array<vec3<f32>, 4>
// ary[3][2];
auto* var = Var("ary", ary4);
@@ -172,7 +172,7 @@
)");
}
-TEST_F(BuilderTest, Accessor_ArrayWithSwizzle) {
+TEST_F(BuilderTest, ArrayAccessor_ArrayRef_ArrayWithSwizzle) {
auto* ary4 = ty.array(ty.vec3<f32>(), 4);
// var a : array<vec3<f32>, 4>;
@@ -680,7 +680,7 @@
)");
}
-TEST_F(BuilderTest, Accessor_Mixed_ArrayAndMember) {
+TEST_F(BuilderTest, ArrayAccessor_Mixed_ArrayAndMember) {
// type C = struct {
// baz : vec3<f32>
// }
@@ -747,7 +747,7 @@
)");
}
-TEST_F(BuilderTest, Accessor_Array_Of_Vec) {
+TEST_F(BuilderTest, ArrayAccessor_Of_Vec) {
// let pos : array<vec2<f32>, 3> = array<vec2<f32>, 3>(
// vec2<f32>(0.0, 0.5),
// vec2<f32>(-0.5, -0.5),
@@ -790,7 +790,7 @@
Validate(b);
}
-TEST_F(BuilderTest, Accessor_Array_Of_Array_Of_f32) {
+TEST_F(BuilderTest, ArrayAccessor_Of_Array_Of_f32) {
// let pos : array<array<f32, 2>, 3> = array<vec2<f32, 2>, 3>(
// array<f32, 2>(0.0, 0.5),
// array<f32, 2>(-0.5, -0.5),
@@ -835,7 +835,7 @@
Validate(b);
}
-TEST_F(BuilderTest, Accessor_Const_Vec) {
+TEST_F(BuilderTest, ArrayAccessor_Vec_Literal) {
// let pos : vec2<f32> = vec2<f32>(0.0, 0.5);
// pos[1]
@@ -864,7 +864,7 @@
)");
}
-TEST_F(BuilderTest, Accessor_Const_Vec_Dynamic) {
+TEST_F(BuilderTest, ArrayAccessor_Vec_Dynamic) {
// let pos : vec2<f32> = vec2<f32>(0.0, 0.5);
// idx : i32
// pos[idx]
@@ -900,7 +900,7 @@
)");
}
-TEST_F(BuilderTest, Accessor_Array_NonPointer) {
+TEST_F(BuilderTest, ArrayAccessor_Array_Literal) {
// let a : array<f32, 3>;
// a[2]
@@ -934,7 +934,7 @@
Validate(b);
}
-TEST_F(BuilderTest, Accessor_Array_NonPointer_Dynamic) {
+TEST_F(BuilderTest, ArrayAccessor_Array_Dynamic) {
// let a : array<f32, 3>;
// idx : i32
// a[idx]
@@ -982,6 +982,58 @@
Validate(b);
}
+TEST_F(BuilderTest, ArrayAccessor_Matrix_Dynamic) {
+ // let a : mat2x2<f32>(vec2<f32>(1., 2.), vec2<f32>(3., 4.));
+ // idx : i32
+ // a[idx]
+
+ auto* var =
+ Const("a", ty.mat2x2<f32>(),
+ Construct(ty.mat2x2<f32>(), Construct(ty.vec2<f32>(), 1.f, 2.f),
+ Construct(ty.vec2<f32>(), 3.f, 4.f)));
+
+ auto* idx = Var("idx", ty.i32());
+ auto* expr = IndexAccessor("a", idx);
+
+ WrapInFunction(var, idx, expr);
+
+ spirv::Builder& b = SanitizeAndBuild();
+
+ ASSERT_TRUE(b.Build());
+
+ EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeVoid
+%1 = OpTypeFunction %2
+%7 = OpTypeFloat 32
+%6 = OpTypeVector %7 2
+%5 = OpTypeMatrix %6 2
+%8 = OpConstant %7 1
+%9 = OpConstant %7 2
+%10 = OpConstantComposite %6 %8 %9
+%11 = OpConstant %7 3
+%12 = OpConstant %7 4
+%13 = OpConstantComposite %6 %11 %12
+%14 = OpConstantComposite %5 %10 %13
+%17 = OpTypeInt 32 1
+%16 = OpTypePointer Function %17
+%18 = OpConstantNull %17
+%20 = OpTypePointer Function %5
+%21 = OpConstantNull %5
+%23 = OpTypePointer Function %6
+)");
+ EXPECT_EQ(DumpInstructions(b.functions()[0].variables()),
+ R"(%15 = OpVariable %16 Function %18
+%19 = OpVariable %20 Function %21
+)");
+ EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+ R"(OpStore %19 %14
+%22 = OpLoad %17 %15
+%24 = OpAccessChain %23 %19 %22
+%25 = OpLoad %6 %24
+)");
+
+ Validate(b);
+}
+
} // namespace
} // namespace spirv
} // namespace writer
diff --git a/test/bug/tint/824.wgsl.expected.spvasm b/test/bug/tint/824.wgsl.expected.spvasm
index bbdd3fd..394df83 100644
--- a/test/bug/tint/824.wgsl.expected.spvasm
+++ b/test/bug/tint/824.wgsl.expected.spvasm
@@ -17,9 +17,9 @@
OpName %tint_symbol_5 "tint_symbol_5"
OpName %tint_symbol_2 "tint_symbol_2"
OpName %main "main"
- OpName %var_for_array "var_for_array"
+ OpName %var_for_index "var_for_index"
OpName %output "output"
- OpName %var_for_array_1 "var_for_array_1"
+ OpName %var_for_index_1 "var_for_index_1"
OpDecorate %tint_pointsize BuiltIn PointSize
OpDecorate %tint_symbol BuiltIn VertexIndex
OpDecorate %tint_symbol_1 BuiltIn InstanceIndex
@@ -88,21 +88,21 @@
OpFunctionEnd
%main = OpFunction %void None %22
%24 = OpLabel
-%var_for_array = OpVariable %_ptr_Function__arr_v2float_uint_4 Function %40
+%var_for_index = OpVariable %_ptr_Function__arr_v2float_uint_4 Function %40
%output = OpVariable %_ptr_Function_Output Function %48
-%var_for_array_1 = OpVariable %_ptr_Function__arr_v4float_uint_4 Function %62
+%var_for_index_1 = OpVariable %_ptr_Function__arr_v4float_uint_4 Function %62
OpStore %tint_pointsize %float_1
- OpStore %var_for_array %37
+ OpStore %var_for_index %37
%41 = OpLoad %uint %tint_symbol_1
- %44 = OpAccessChain %_ptr_Function_float %var_for_array %41 %uint_0
+ %44 = OpAccessChain %_ptr_Function_float %var_for_index %41 %uint_0
%45 = OpLoad %float %44
%50 = OpAccessChain %_ptr_Function_v4float %output %uint_0
%52 = OpCompositeConstruct %v4float %float_0_5 %float_0_5 %45 %float_1
OpStore %50 %52
- OpStore %var_for_array_1 %59
+ OpStore %var_for_index_1 %59
%64 = OpAccessChain %_ptr_Function_v4float %output %uint_1
%65 = OpLoad %uint %tint_symbol_1
- %66 = OpAccessChain %_ptr_Function_v4float %var_for_array_1 %65
+ %66 = OpAccessChain %_ptr_Function_v4float %var_for_index_1 %65
%67 = OpLoad %v4float %66
OpStore %64 %67
%69 = OpLoad %Output %output
diff --git a/test/bug/tint/825.wgsl b/test/bug/tint/825.wgsl
new file mode 100644
index 0000000..92a0b7e
--- /dev/null
+++ b/test/bug/tint/825.wgsl
@@ -0,0 +1,6 @@
+fn f() {
+ var i : i32;
+ var j : i32;
+ let m : mat2x2<f32> = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
+ let f : f32 = m[i][j];
+}
diff --git a/test/bug/tint/825.wgsl.expected.hlsl b/test/bug/tint/825.wgsl.expected.hlsl
new file mode 100644
index 0000000..0acca6f
--- /dev/null
+++ b/test/bug/tint/825.wgsl.expected.hlsl
@@ -0,0 +1,12 @@
+void f() {
+ int i = 0;
+ int j = 0;
+ const float2x2 m = float2x2(float2(1.0f, 2.0f), float2(3.0f, 4.0f));
+ const float f = m[i][j];
+}
+
+[numthreads(1, 1, 1)]
+void unused_entry_point() {
+ return;
+}
+
diff --git a/test/bug/tint/825.wgsl.expected.msl b/test/bug/tint/825.wgsl.expected.msl
new file mode 100644
index 0000000..5a8c360
--- /dev/null
+++ b/test/bug/tint/825.wgsl.expected.msl
@@ -0,0 +1,10 @@
+#include <metal_stdlib>
+
+using namespace metal;
+void f() {
+ int i = 0;
+ int j = 0;
+ float2x2 const m = float2x2(float2(1.0f, 2.0f), float2(3.0f, 4.0f));
+ float const f = m[i][j];
+}
+
diff --git a/test/bug/tint/825.wgsl.expected.spvasm b/test/bug/tint/825.wgsl.expected.spvasm
new file mode 100644
index 0000000..d7996bf
--- /dev/null
+++ b/test/bug/tint/825.wgsl.expected.spvasm
@@ -0,0 +1,48 @@
+; SPIR-V
+; Version: 1.3
+; Generator: Google Tint Compiler; 0
+; Bound: 30
+; Schema: 0
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %unused_entry_point "unused_entry_point"
+ OpExecutionMode %unused_entry_point LocalSize 1 1 1
+ OpName %unused_entry_point "unused_entry_point"
+ OpName %f "f"
+ OpName %i "i"
+ OpName %j "j"
+ OpName %var_for_index "var_for_index"
+ %void = OpTypeVoid
+ %1 = OpTypeFunction %void
+ %int = OpTypeInt 32 1
+%_ptr_Function_int = OpTypePointer Function %int
+ %10 = OpConstantNull %int
+ %float = OpTypeFloat 32
+ %v2float = OpTypeVector %float 2
+%mat2v2float = OpTypeMatrix %v2float 2
+ %float_1 = OpConstant %float 1
+ %float_2 = OpConstant %float 2
+ %17 = OpConstantComposite %v2float %float_1 %float_2
+ %float_3 = OpConstant %float 3
+ %float_4 = OpConstant %float 4
+ %20 = OpConstantComposite %v2float %float_3 %float_4
+ %21 = OpConstantComposite %mat2v2float %17 %20
+%_ptr_Function_mat2v2float = OpTypePointer Function %mat2v2float
+ %24 = OpConstantNull %mat2v2float
+%_ptr_Function_float = OpTypePointer Function %float
+%unused_entry_point = OpFunction %void None %1
+ %4 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ %f = OpFunction %void None %1
+ %6 = OpLabel
+ %i = OpVariable %_ptr_Function_int Function %10
+ %j = OpVariable %_ptr_Function_int Function %10
+%var_for_index = OpVariable %_ptr_Function_mat2v2float Function %24
+ OpStore %var_for_index %21
+ %25 = OpLoad %int %i
+ %26 = OpLoad %int %j
+ %28 = OpAccessChain %_ptr_Function_float %var_for_index %25 %26
+ %29 = OpLoad %float %28
+ OpReturn
+ OpFunctionEnd
diff --git a/test/bug/tint/825.wgsl.expected.wgsl b/test/bug/tint/825.wgsl.expected.wgsl
new file mode 100644
index 0000000..92a0b7e
--- /dev/null
+++ b/test/bug/tint/825.wgsl.expected.wgsl
@@ -0,0 +1,6 @@
+fn f() {
+ var i : i32;
+ var j : i32;
+ let m : mat2x2<f32> = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
+ let f : f32 = m[i][j];
+}
diff --git a/test/samples/triangle.wgsl.expected.spvasm b/test/samples/triangle.wgsl.expected.spvasm
index c6f874d..5c9ada3 100644
--- a/test/samples/triangle.wgsl.expected.spvasm
+++ b/test/samples/triangle.wgsl.expected.spvasm
@@ -16,7 +16,7 @@
OpName %tint_symbol_3 "tint_symbol_3"
OpName %tint_symbol_1 "tint_symbol_1"
OpName %vtx_main "vtx_main"
- OpName %var_for_array "var_for_array"
+ OpName %var_for_index "var_for_index"
OpName %tint_symbol_6 "tint_symbol_6"
OpName %tint_symbol_4 "tint_symbol_4"
OpName %frag_main "frag_main"
@@ -64,11 +64,11 @@
OpFunctionEnd
%vtx_main = OpFunction %void None %29
%31 = OpLabel
-%var_for_array = OpVariable %_ptr_Function__arr_v2float_uint_3 Function %35
+%var_for_index = OpVariable %_ptr_Function__arr_v2float_uint_3 Function %35
OpStore %tint_pointsize %float_1
- OpStore %var_for_array %pos
+ OpStore %var_for_index %pos
%37 = OpLoad %int %tint_symbol
- %39 = OpAccessChain %_ptr_Function_v2float %var_for_array %37
+ %39 = OpAccessChain %_ptr_Function_v2float %var_for_index %37
%40 = OpLoad %v2float %39
%41 = OpCompositeExtract %float %40 0
%42 = OpCompositeExtract %float %40 1