[ir] Fix Std140 for non-decomposed matrices
When recursively replacing the instructions that access a buffer whose
struct contains a decomposed matrix, we cannot assume that every
instruction operates on a decomposed matrix, since other matrices in the
same struct may not have been decomposed.
Fixed: tint:2100
Change-Id: Ifc28636a39276d931592bcc1fe2026540d25c8c3
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/160742
Commit-Queue: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Auto-Submit: James Price <jrprice@google.com>
diff --git a/src/tint/lang/core/ir/transform/std140.cc b/src/tint/lang/core/ir/transform/std140.cc
index 0f2b8f6..56897fd 100644
--- a/src/tint/lang/core/ir/transform/std140.cc
+++ b/src/tint/lang/core/ir/transform/std140.cc
@@ -330,12 +330,17 @@
load->Destroy();
},
[&](LoadVectorElement* load) {
- // We should have loaded the decomposed matrix, reconstructed it, so this is now
- // extracting from a value type.
- TINT_ASSERT(!replacement->Type()->Is<core::type::Pointer>());
- auto* access = b.Access(load->Result()->Type(), replacement, load->Index());
- load->Result()->ReplaceAllUsesWith(access->Result());
- load->Destroy();
+ if (!replacement->Type()->Is<core::type::Pointer>()) {
+ // We have loaded a decomposed matrix and reconstructed it, so this is now
+ // extracting from a value type.
+ auto* access = b.Access(load->Result()->Type(), replacement, load->Index());
+ load->Result()->ReplaceAllUsesWith(access->Result());
+ load->Destroy();
+ } else {
+ // There was no decomposed matrix on the path to this instruction so just
+ // update the source operand.
+ load->SetOperand(LoadVectorElement::kFromOperandOffset, replacement);
+ }
},
[&](Let* let) {
// Let instructions just fold away.
diff --git a/src/tint/lang/core/ir/transform/std140_test.cc b/src/tint/lang/core/ir/transform/std140_test.cc
index 8415529..105f2cd 100644
--- a/src/tint/lang/core/ir/transform/std140_test.cc
+++ b/src/tint/lang/core/ir/transform/std140_test.cc
@@ -1483,6 +1483,283 @@
EXPECT_EQ(expect, str());
}
+TEST_F(IR_Std140Test, NotAllMatricesDecomposed) {
+ auto* mat4x4 = ty.mat4x4<f32>();
+ auto* mat3x2 = ty.mat3x2<f32>();
+ auto* structure = ty.Struct(mod.symbols.New("MyStruct"), {
+ {mod.symbols.New("a"), mat4x4},
+ {mod.symbols.New("b"), mat3x2},
+ });
+ structure->SetStructFlag(core::type::kBlock);
+
+ auto* buffer = b.Var("buffer", ty.ptr(uniform, structure));
+ buffer->SetBindingPoint(0, 0);
+ mod.root_block->Append(buffer);
+
+ {
+ auto* func = b.Function("load_struct_a", mat4x4);
+ b.Append(func->Block(), [&] {
+ auto* load_struct = b.Load(buffer);
+ auto* extract_mat = b.Access(mat4x4, load_struct, 0_u);
+ b.Return(func, extract_mat);
+ });
+ }
+
+ {
+ auto* func = b.Function("load_struct_b", mat3x2);
+ b.Append(func->Block(), [&] {
+ auto* load_struct = b.Load(buffer);
+ auto* extract_mat = b.Access(mat3x2, load_struct, 1_u);
+ b.Return(func, extract_mat);
+ });
+ }
+
+ {
+ auto* func = b.Function("load_mat_a", ty.vec4<f32>());
+ b.Append(func->Block(), [&] {
+ auto* access_mat = b.Access(ty.ptr(uniform, mat4x4), buffer, 0_u);
+ auto* load_mat = b.Load(access_mat);
+ auto* extract_vec = b.Access(ty.vec4<f32>(), load_mat, 0_u);
+ b.Return(func, extract_vec);
+ });
+ }
+
+ {
+ auto* func = b.Function("load_mat_b", ty.vec2<f32>());
+ b.Append(func->Block(), [&] {
+ auto* access_mat = b.Access(ty.ptr(uniform, mat3x2), buffer, 1_u);
+ auto* load_mat = b.Load(access_mat);
+ auto* extract_vec = b.Access(ty.vec2<f32>(), load_mat, 0_u);
+ b.Return(func, extract_vec);
+ });
+ }
+
+ {
+ auto* func = b.Function("load_vec_a", ty.f32());
+ b.Append(func->Block(), [&] {
+ auto* access_vec = b.Access(ty.ptr(uniform, mat4x4->ColumnType()), buffer, 0_u, 1_u);
+ auto* load_vec = b.Load(access_vec);
+ auto* extract_el = b.Access(ty.f32(), load_vec, 1_u);
+ b.Return(func, extract_el);
+ });
+ }
+
+ {
+ auto* func = b.Function("load_vec_b", ty.f32());
+ b.Append(func->Block(), [&] {
+ auto* access_vec = b.Access(ty.ptr(uniform, mat3x2->ColumnType()), buffer, 1_u, 1_u);
+ auto* load_vec = b.Load(access_vec);
+ auto* extract_el = b.Access(ty.f32(), load_vec, 1_u);
+ b.Return(func, extract_el);
+ });
+ }
+
+ {
+ auto* func = b.Function("lve_a", ty.f32());
+ b.Append(func->Block(), [&] {
+ auto* access_vec = b.Access(ty.ptr(uniform, mat4x4->ColumnType()), buffer, 0_u, 1_u);
+ auto* lve = b.LoadVectorElement(access_vec, 1_u);
+ b.Return(func, lve);
+ });
+ }
+
+ {
+ auto* func = b.Function("lve_b", ty.f32());
+ b.Append(func->Block(), [&] {
+ auto* access_vec = b.Access(ty.ptr(uniform, mat3x2->ColumnType()), buffer, 1_u, 1_u);
+ auto* lve = b.LoadVectorElement(access_vec, 1_u);
+ b.Return(func, lve);
+ });
+ }
+
+ auto* src = R"(
+MyStruct = struct @align(16), @block {
+ a:mat4x4<f32> @offset(0)
+ b:mat3x2<f32> @offset(64)
+}
+
+%b1 = block { # root
+ %buffer:ptr<uniform, MyStruct, read_write> = var @binding_point(0, 0)
+}
+
+%load_struct_a = func():mat4x4<f32> -> %b2 {
+ %b2 = block {
+ %3:MyStruct = load %buffer
+ %4:mat4x4<f32> = access %3, 0u
+ ret %4
+ }
+}
+%load_struct_b = func():mat3x2<f32> -> %b3 {
+ %b3 = block {
+ %6:MyStruct = load %buffer
+ %7:mat3x2<f32> = access %6, 1u
+ ret %7
+ }
+}
+%load_mat_a = func():vec4<f32> -> %b4 {
+ %b4 = block {
+ %9:ptr<uniform, mat4x4<f32>, read_write> = access %buffer, 0u
+ %10:mat4x4<f32> = load %9
+ %11:vec4<f32> = access %10, 0u
+ ret %11
+ }
+}
+%load_mat_b = func():vec2<f32> -> %b5 {
+ %b5 = block {
+ %13:ptr<uniform, mat3x2<f32>, read_write> = access %buffer, 1u
+ %14:mat3x2<f32> = load %13
+ %15:vec2<f32> = access %14, 0u
+ ret %15
+ }
+}
+%load_vec_a = func():f32 -> %b6 {
+ %b6 = block {
+ %17:ptr<uniform, vec4<f32>, read_write> = access %buffer, 0u, 1u
+ %18:vec4<f32> = load %17
+ %19:f32 = access %18, 1u
+ ret %19
+ }
+}
+%load_vec_b = func():f32 -> %b7 {
+ %b7 = block {
+ %21:ptr<uniform, vec2<f32>, read_write> = access %buffer, 1u, 1u
+ %22:vec2<f32> = load %21
+ %23:f32 = access %22, 1u
+ ret %23
+ }
+}
+%lve_a = func():f32 -> %b8 {
+ %b8 = block {
+ %25:ptr<uniform, vec4<f32>, read_write> = access %buffer, 0u, 1u
+ %26:f32 = load_vector_element %25, 1u
+ ret %26
+ }
+}
+%lve_b = func():f32 -> %b9 {
+ %b9 = block {
+ %28:ptr<uniform, vec2<f32>, read_write> = access %buffer, 1u, 1u
+ %29:f32 = load_vector_element %28, 1u
+ ret %29
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+MyStruct = struct @align(16), @block {
+ a:mat4x4<f32> @offset(0)
+ b:mat3x2<f32> @offset(64)
+}
+
+MyStruct_std140 = struct @align(16), @block {
+ a:mat4x4<f32> @offset(0)
+ b_col0:vec2<f32> @offset(64)
+ b_col1:vec2<f32> @offset(72)
+ b_col2:vec2<f32> @offset(80)
+}
+
+%b1 = block { # root
+ %buffer:ptr<uniform, MyStruct_std140, read_write> = var @binding_point(0, 0)
+}
+
+%load_struct_a = func():mat4x4<f32> -> %b2 {
+ %b2 = block {
+ %3:MyStruct_std140 = load %buffer
+ %4:MyStruct = call %convert_MyStruct, %3
+ %6:mat4x4<f32> = access %4, 0u
+ ret %6
+ }
+}
+%load_struct_b = func():mat3x2<f32> -> %b3 {
+ %b3 = block {
+ %8:MyStruct_std140 = load %buffer
+ %9:MyStruct = call %convert_MyStruct, %8
+ %10:mat3x2<f32> = access %9, 1u
+ ret %10
+ }
+}
+%load_mat_a = func():vec4<f32> -> %b4 {
+ %b4 = block {
+ %12:ptr<uniform, mat4x4<f32>, read_write> = access %buffer, 0u
+ %13:mat4x4<f32> = load %12
+ %14:vec4<f32> = access %13, 0u
+ ret %14
+ }
+}
+%load_mat_b = func():vec2<f32> -> %b5 {
+ %b5 = block {
+ %16:ptr<uniform, vec2<f32>, read_write> = access %buffer, 1u
+ %17:vec2<f32> = load %16
+ %18:ptr<uniform, vec2<f32>, read_write> = access %buffer, 2u
+ %19:vec2<f32> = load %18
+ %20:ptr<uniform, vec2<f32>, read_write> = access %buffer, 3u
+ %21:vec2<f32> = load %20
+ %22:mat3x2<f32> = construct %17, %19, %21
+ %23:vec2<f32> = access %22, 0u
+ ret %23
+ }
+}
+%load_vec_a = func():f32 -> %b6 {
+ %b6 = block {
+ %25:ptr<uniform, vec4<f32>, read_write> = access %buffer, 0u, 1u
+ %26:vec4<f32> = load %25
+ %27:f32 = access %26, 1u
+ ret %27
+ }
+}
+%load_vec_b = func():f32 -> %b7 {
+ %b7 = block {
+ %29:ptr<uniform, vec2<f32>, read_write> = access %buffer, 1u
+ %30:vec2<f32> = load %29
+ %31:ptr<uniform, vec2<f32>, read_write> = access %buffer, 2u
+ %32:vec2<f32> = load %31
+ %33:ptr<uniform, vec2<f32>, read_write> = access %buffer, 3u
+ %34:vec2<f32> = load %33
+ %35:mat3x2<f32> = construct %30, %32, %34
+ %36:vec2<f32> = access %35, 1u
+ %37:f32 = access %36, 1u
+ ret %37
+ }
+}
+%lve_a = func():f32 -> %b8 {
+ %b8 = block {
+ %39:ptr<uniform, vec4<f32>, read_write> = access %buffer, 0u, 1u
+ %40:f32 = load_vector_element %39, 1u
+ ret %40
+ }
+}
+%lve_b = func():f32 -> %b9 {
+ %b9 = block {
+ %42:ptr<uniform, vec2<f32>, read_write> = access %buffer, 1u
+ %43:vec2<f32> = load %42
+ %44:ptr<uniform, vec2<f32>, read_write> = access %buffer, 2u
+ %45:vec2<f32> = load %44
+ %46:ptr<uniform, vec2<f32>, read_write> = access %buffer, 3u
+ %47:vec2<f32> = load %46
+ %48:mat3x2<f32> = construct %43, %45, %47
+ %49:vec2<f32> = access %48, 1u
+ %50:f32 = access %49, 1u
+ ret %50
+ }
+}
+%convert_MyStruct = func(%input:MyStruct_std140):MyStruct -> %b10 {
+ %b10 = block {
+ %52:mat4x4<f32> = access %input, 0u
+ %53:vec2<f32> = access %input, 1u
+ %54:vec2<f32> = access %input, 2u
+ %55:vec2<f32> = access %input, 3u
+ %56:mat3x2<f32> = construct %53, %54, %55
+ %57:MyStruct = construct %52, %56
+ ret %57
+ }
+}
+)";
+
+ Run(Std140);
+
+ EXPECT_EQ(expect, str());
+}
+
TEST_F(IR_Std140Test, F16) {
auto* structure =
ty.Struct(mod.symbols.New("MyStruct"), {
diff --git a/test/tint/bug/tint/2100.wgsl b/test/tint/bug/tint/2100.wgsl
new file mode 100644
index 0000000..360ee9c
--- /dev/null
+++ b/test/tint/bug/tint/2100.wgsl
@@ -0,0 +1,12 @@
+struct S {
+ matrix_view : mat4x4<f32>,
+ matrix_normal : mat3x3<f32>,
+}
+
+@group(0) @binding(0) var<uniform> buffer : S;
+
+@vertex
+fn main() -> @builtin(position) vec4f {
+ let x = buffer.matrix_view[0].z;
+ return vec4f(x, 0, 0, 1);
+}
diff --git a/test/tint/bug/tint/2100.wgsl.expected.dxc.hlsl b/test/tint/bug/tint/2100.wgsl.expected.dxc.hlsl
new file mode 100644
index 0000000..6b7e066
--- /dev/null
+++ b/test/tint/bug/tint/2100.wgsl.expected.dxc.hlsl
@@ -0,0 +1,19 @@
+cbuffer cbuffer_buffer : register(b0) {
+ uint4 buffer[7];
+};
+
+struct tint_symbol {
+ float4 value : SV_Position;
+};
+
+float4 main_inner() {
+ const float x = asfloat(buffer[0].z);
+ return float4(x, 0.0f, 0.0f, 1.0f);
+}
+
+tint_symbol main() {
+ const float4 inner_result = main_inner();
+ tint_symbol wrapper_result = (tint_symbol)0;
+ wrapper_result.value = inner_result;
+ return wrapper_result;
+}
diff --git a/test/tint/bug/tint/2100.wgsl.expected.fxc.hlsl b/test/tint/bug/tint/2100.wgsl.expected.fxc.hlsl
new file mode 100644
index 0000000..6b7e066
--- /dev/null
+++ b/test/tint/bug/tint/2100.wgsl.expected.fxc.hlsl
@@ -0,0 +1,19 @@
+cbuffer cbuffer_buffer : register(b0) {
+ uint4 buffer[7];
+};
+
+struct tint_symbol {
+ float4 value : SV_Position;
+};
+
+float4 main_inner() {
+ const float x = asfloat(buffer[0].z);
+ return float4(x, 0.0f, 0.0f, 1.0f);
+}
+
+tint_symbol main() {
+ const float4 inner_result = main_inner();
+ tint_symbol wrapper_result = (tint_symbol)0;
+ wrapper_result.value = inner_result;
+ return wrapper_result;
+}
diff --git a/test/tint/bug/tint/2100.wgsl.expected.glsl b/test/tint/bug/tint/2100.wgsl.expected.glsl
new file mode 100644
index 0000000..110fd04
--- /dev/null
+++ b/test/tint/bug/tint/2100.wgsl.expected.glsl
@@ -0,0 +1,24 @@
+#version 310 es
+
+struct S {
+ mat4 matrix_view;
+ mat3 matrix_normal;
+};
+
+layout(binding = 0, std140) uniform tint_symbol_block_ubo {
+ S inner;
+} tint_symbol;
+
+vec4 tint_symbol_1() {
+ float x = tint_symbol.inner.matrix_view[0].z;
+ return vec4(x, 0.0f, 0.0f, 1.0f);
+}
+
+void main() {
+ gl_PointSize = 1.0;
+ vec4 inner_result = tint_symbol_1();
+ gl_Position = inner_result;
+ gl_Position.y = -(gl_Position.y);
+ gl_Position.z = ((2.0f * gl_Position.z) - gl_Position.w);
+ return;
+}
diff --git a/test/tint/bug/tint/2100.wgsl.expected.msl b/test/tint/bug/tint/2100.wgsl.expected.msl
new file mode 100644
index 0000000..89c4fe8
--- /dev/null
+++ b/test/tint/bug/tint/2100.wgsl.expected.msl
@@ -0,0 +1,47 @@
+#include <metal_stdlib>
+
+using namespace metal;
+
+template<typename T, size_t N>
+struct tint_array {
+ const constant T& operator[](size_t i) const constant { return elements[i]; }
+ device T& operator[](size_t i) device { return elements[i]; }
+ const device T& operator[](size_t i) const device { return elements[i]; }
+ thread T& operator[](size_t i) thread { return elements[i]; }
+ const thread T& operator[](size_t i) const thread { return elements[i]; }
+ threadgroup T& operator[](size_t i) threadgroup { return elements[i]; }
+ const threadgroup T& operator[](size_t i) const threadgroup { return elements[i]; }
+ T elements[N];
+};
+
+struct tint_packed_vec3_f32_array_element {
+ /* 0x0000 */ packed_float3 elements;
+ /* 0x000c */ tint_array<int8_t, 4> tint_pad;
+};
+
+struct S_tint_packed_vec3 {
+ /* 0x0000 */ float4x4 matrix_view;
+ /* 0x0040 */ tint_array<tint_packed_vec3_f32_array_element, 3> matrix_normal;
+};
+
+struct S {
+ float4x4 matrix_view;
+ float3x3 matrix_normal;
+};
+
+struct tint_symbol_2 {
+ float4 value [[position]];
+};
+
+float4 tint_symbol_1_inner(const constant S_tint_packed_vec3* const tint_symbol_3) {
+ float const x = (*(tint_symbol_3)).matrix_view[0][2];
+ return float4(x, 0.0f, 0.0f, 1.0f);
+}
+
+vertex tint_symbol_2 tint_symbol_1(const constant S_tint_packed_vec3* tint_symbol_4 [[buffer(0)]]) {
+ float4 const inner_result = tint_symbol_1_inner(tint_symbol_4);
+ tint_symbol_2 wrapper_result = {};
+ wrapper_result.value = inner_result;
+ return wrapper_result;
+}
+
diff --git a/test/tint/bug/tint/2100.wgsl.expected.spvasm b/test/tint/bug/tint/2100.wgsl.expected.spvasm
new file mode 100644
index 0000000..c3f7f3f
--- /dev/null
+++ b/test/tint/bug/tint/2100.wgsl.expected.spvasm
@@ -0,0 +1,70 @@
+; SPIR-V
+; Version: 1.3
+; Generator: Google Tint Compiler; 0
+; Bound: 34
+; Schema: 0
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Vertex %main "main" %value %vertex_point_size
+ OpName %value "value"
+ OpName %vertex_point_size "vertex_point_size"
+ OpName %buffer_block "buffer_block"
+ OpMemberName %buffer_block 0 "inner"
+ OpName %S "S"
+ OpMemberName %S 0 "matrix_view"
+ OpMemberName %S 1 "matrix_normal"
+ OpName %buffer "buffer"
+ OpName %main_inner "main_inner"
+ OpName %main "main"
+ OpDecorate %value BuiltIn Position
+ OpDecorate %vertex_point_size BuiltIn PointSize
+ OpDecorate %buffer_block Block
+ OpMemberDecorate %buffer_block 0 Offset 0
+ OpMemberDecorate %S 0 Offset 0
+ OpMemberDecorate %S 0 ColMajor
+ OpMemberDecorate %S 0 MatrixStride 16
+ OpMemberDecorate %S 1 Offset 64
+ OpMemberDecorate %S 1 ColMajor
+ OpMemberDecorate %S 1 MatrixStride 16
+ OpDecorate %buffer NonWritable
+ OpDecorate %buffer DescriptorSet 0
+ OpDecorate %buffer Binding 0
+ %float = OpTypeFloat 32
+ %v4float = OpTypeVector %float 4
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+ %5 = OpConstantNull %v4float
+ %value = OpVariable %_ptr_Output_v4float Output %5
+%_ptr_Output_float = OpTypePointer Output %float
+ %8 = OpConstantNull %float
+%vertex_point_size = OpVariable %_ptr_Output_float Output %8
+%mat4v4float = OpTypeMatrix %v4float 4
+ %v3float = OpTypeVector %float 3
+%mat3v3float = OpTypeMatrix %v3float 3
+ %S = OpTypeStruct %mat4v4float %mat3v3float
+%buffer_block = OpTypeStruct %S
+%_ptr_Uniform_buffer_block = OpTypePointer Uniform %buffer_block
+ %buffer = OpVariable %_ptr_Uniform_buffer_block Uniform
+ %16 = OpTypeFunction %v4float
+ %uint = OpTypeInt 32 0
+ %uint_0 = OpConstant %uint 0
+ %int = OpTypeInt 32 1
+ %22 = OpConstantNull %int
+ %uint_2 = OpConstant %uint 2
+%_ptr_Uniform_float = OpTypePointer Uniform %float
+ %float_1 = OpConstant %float 1
+ %void = OpTypeVoid
+ %29 = OpTypeFunction %void
+ %main_inner = OpFunction %v4float None %16
+ %18 = OpLabel
+ %25 = OpAccessChain %_ptr_Uniform_float %buffer %uint_0 %uint_0 %22 %uint_2
+ %26 = OpLoad %float %25
+ %28 = OpCompositeConstruct %v4float %26 %8 %8 %float_1
+ OpReturnValue %28
+ OpFunctionEnd
+ %main = OpFunction %void None %29
+ %32 = OpLabel
+ %33 = OpFunctionCall %v4float %main_inner
+ OpStore %value %33
+ OpStore %vertex_point_size %float_1
+ OpReturn
+ OpFunctionEnd
diff --git a/test/tint/bug/tint/2100.wgsl.expected.wgsl b/test/tint/bug/tint/2100.wgsl.expected.wgsl
new file mode 100644
index 0000000..360ee9c
--- /dev/null
+++ b/test/tint/bug/tint/2100.wgsl.expected.wgsl
@@ -0,0 +1,12 @@
+struct S {
+ matrix_view : mat4x4<f32>,
+ matrix_normal : mat3x3<f32>,
+}
+
+@group(0) @binding(0) var<uniform> buffer : S;
+
+@vertex
+fn main() -> @builtin(position) vec4f {
+ let x = buffer.matrix_view[0].z;
+ return vec4f(x, 0, 0, 1);
+}