Support for OpVectorShuffle
Emit Swizzle IRs for shuffles that use only one vector and Construct IRs
for shuffles that mix components from two vectors.
Bug: 391484672
Change-Id: Idaacc5a40ab14165e29213a45a625dcecda9e390
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/233654
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: James Price <jrprice@google.com>
diff --git a/src/tint/lang/spirv/reader/parser/composite_test.cc b/src/tint/lang/spirv/reader/parser/composite_test.cc
index e0e2b8e..9327adf 100644
--- a/src/tint/lang/spirv/reader/parser/composite_test.cc
+++ b/src/tint/lang/spirv/reader/parser/composite_test.cc
@@ -744,5 +744,422 @@
)");
}
+TEST_F(SpirvParserTest, VectorShuffle_BothVectors_A) {
+ EXPECT_IR(R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ %void = OpTypeVoid
+ %u32 = OpTypeInt 32 0
+ %v4u32 = OpTypeVector %u32 4
+ %v4u32_ptr = OpTypePointer Function %v4u32
+ %ep_type = OpTypeFunction %void
+
+ %main = OpFunction %void None %ep_type
+ %main_start = OpLabel
+ %vec1 = OpVariable %v4u32_ptr Function
+ %vec2 = OpVariable %v4u32_ptr Function
+ %tmp1 = OpLoad %v4u32 %vec1
+ %tmp2 = OpLoad %v4u32 %vec2
+ %shuf = OpVectorShuffle %v4u32 %tmp1 %tmp2 0 5 2 7
+ OpReturn
+ OpFunctionEnd
+ )",
+ R"(
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:ptr<function, vec4<u32>, read_write> = var undef
+ %3:ptr<function, vec4<u32>, read_write> = var undef
+ %4:vec4<u32> = load %2
+ %5:vec4<u32> = load %3
+ %6:u32 = access %4, 0u
+ %7:u32 = access %5, 1u
+ %8:u32 = access %4, 2u
+ %9:u32 = access %5, 3u
+ %10:vec4<u32> = construct %6, %7, %8, %9
+ ret
+ }
+}
+)");
+}
+
+TEST_F(SpirvParserTest, VectorShuffle_BothVectors_B) {
+ EXPECT_IR(R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ %void = OpTypeVoid
+ %u32 = OpTypeInt 32 0
+ %v4u32 = OpTypeVector %u32 4
+ %v4u32_ptr = OpTypePointer Function %v4u32
+ %ep_type = OpTypeFunction %void
+
+ %main = OpFunction %void None %ep_type
+ %main_start = OpLabel
+ %vec1 = OpVariable %v4u32_ptr Function
+ %vec2 = OpVariable %v4u32_ptr Function
+ %tmp1 = OpLoad %v4u32 %vec1
+ %tmp2 = OpLoad %v4u32 %vec2
+ %shuf = OpVectorShuffle %v4u32 %tmp1 %tmp2 0 2 5 7
+ OpReturn
+ OpFunctionEnd
+ )",
+ R"(
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:ptr<function, vec4<u32>, read_write> = var undef
+ %3:ptr<function, vec4<u32>, read_write> = var undef
+ %4:vec4<u32> = load %2
+ %5:vec4<u32> = load %3
+ %6:u32 = access %4, 0u
+ %7:u32 = access %4, 2u
+ %8:u32 = access %5, 1u
+ %9:u32 = access %5, 3u
+ %10:vec4<u32> = construct %6, %7, %8, %9
+ ret
+ }
+}
+)");
+}
+
+TEST_F(SpirvParserTest, VectorShuffle_BothVectorsToBigger) {
+ EXPECT_IR(R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ %void = OpTypeVoid
+ %u32 = OpTypeInt 32 0
+ %v2u32 = OpTypeVector %u32 2
+ %v3u32 = OpTypeVector %u32 3
+ %v4u32 = OpTypeVector %u32 4
+ %v2u32_ptr = OpTypePointer Function %v2u32
+ %v3u32_ptr = OpTypePointer Function %v3u32
+ %ep_type = OpTypeFunction %void
+
+ %main = OpFunction %void None %ep_type
+ %main_start = OpLabel
+ %vec1 = OpVariable %v2u32_ptr Function
+ %vec2 = OpVariable %v3u32_ptr Function
+ %tmp1 = OpLoad %v2u32 %vec1
+ %tmp2 = OpLoad %v3u32 %vec2
+ %shuf = OpVectorShuffle %v4u32 %tmp1 %tmp2 0 2 1 4
+ OpReturn
+ OpFunctionEnd
+ )",
+ R"(
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:ptr<function, vec2<u32>, read_write> = var undef
+ %3:ptr<function, vec3<u32>, read_write> = var undef
+ %4:vec2<u32> = load %2
+ %5:vec3<u32> = load %3
+ %6:u32 = access %4, 0u
+ %7:u32 = access %5, 0u
+ %8:u32 = access %4, 1u
+ %9:u32 = access %5, 2u
+ %10:vec4<u32> = construct %6, %7, %8, %9
+ ret
+ }
+}
+)");
+}
+
+TEST_F(SpirvParserTest, VectorShuffle_BothVectorsToSmaller) {
+ EXPECT_IR(R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ %void = OpTypeVoid
+ %u32 = OpTypeInt 32 0
+ %v2u32 = OpTypeVector %u32 2
+ %v3u32 = OpTypeVector %u32 3
+ %v4u32 = OpTypeVector %u32 4
+ %v3u32_ptr = OpTypePointer Function %v3u32
+ %v4u32_ptr = OpTypePointer Function %v4u32
+ %ep_type = OpTypeFunction %void
+
+ %main = OpFunction %void None %ep_type
+ %main_start = OpLabel
+ %vec1 = OpVariable %v3u32_ptr Function
+ %vec2 = OpVariable %v4u32_ptr Function
+ %tmp1 = OpLoad %v3u32 %vec1
+ %tmp2 = OpLoad %v4u32 %vec2
+ %shuf = OpVectorShuffle %v2u32 %tmp1 %tmp2 0 4
+ OpReturn
+ OpFunctionEnd
+ )",
+ R"(
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:ptr<function, vec3<u32>, read_write> = var undef
+ %3:ptr<function, vec4<u32>, read_write> = var undef
+ %4:vec3<u32> = load %2
+ %5:vec4<u32> = load %3
+ %6:u32 = access %4, 0u
+ %7:u32 = access %5, 1u
+ %8:vec2<u32> = construct %6, %7
+ ret
+ }
+}
+)");
+}
+
+TEST_F(SpirvParserTest, VectorShuffle_BothVectors_UndefinedIndex) {
+ EXPECT_IR(R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ %void = OpTypeVoid
+ %u32 = OpTypeInt 32 0
+ %v4u32 = OpTypeVector %u32 4
+ %v4u32_ptr = OpTypePointer Function %v4u32
+ %ep_type = OpTypeFunction %void
+
+ %main = OpFunction %void None %ep_type
+ %main_start = OpLabel
+ %vecA = OpVariable %v4u32_ptr Function
+ %vecB = OpVariable %v4u32_ptr Function
+ %tmpA = OpLoad %v4u32 %vecA
+ %tmpB = OpLoad %v4u32 %vecB
+ %shuf = OpVectorShuffle %v4u32 %tmpA %tmpB 0 4294967295 6 3
+ OpReturn
+ OpFunctionEnd
+ )",
+ R"(
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:ptr<function, vec4<u32>, read_write> = var undef
+ %3:ptr<function, vec4<u32>, read_write> = var undef
+ %4:vec4<u32> = load %2
+ %5:vec4<u32> = load %3
+ %6:u32 = access %4, 0u
+ %7:u32 = access %4, 0u
+ %8:u32 = access %5, 2u
+ %9:u32 = access %4, 3u
+ %10:vec4<u32> = construct %6, %7, %8, %9
+ ret
+ }
+}
+)");
+}
+
+TEST_F(SpirvParserTest, VectorShuffle_MixedDimensions_234) {
+ EXPECT_IR(R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ %void = OpTypeVoid
+ %u32 = OpTypeInt 32 0
+ %v2u32 = OpTypeVector %u32 2
+ %v3u32 = OpTypeVector %u32 3
+ %v4u32 = OpTypeVector %u32 4
+ %v2u32_ptr = OpTypePointer Function %v2u32
+ %v3u32_ptr = OpTypePointer Function %v3u32
+ %ep_type = OpTypeFunction %void
+
+ %main = OpFunction %void None %ep_type
+ %main_start = OpLabel
+ %vec1 = OpVariable %v2u32_ptr Function
+ %vec2 = OpVariable %v3u32_ptr Function
+ %tmp1 = OpLoad %v2u32 %vec1
+ %tmp2 = OpLoad %v3u32 %vec2
+ %shuf = OpVectorShuffle %v4u32 %tmp1 %tmp2 0 3 4 1
+ OpReturn
+ OpFunctionEnd
+ )",
+ R"(
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:ptr<function, vec2<u32>, read_write> = var undef
+ %3:ptr<function, vec3<u32>, read_write> = var undef
+ %4:vec2<u32> = load %2
+ %5:vec3<u32> = load %3
+ %6:u32 = access %4, 0u
+ %7:u32 = access %5, 1u
+ %8:u32 = access %5, 2u
+ %9:u32 = access %4, 1u
+ %10:vec4<u32> = construct %6, %7, %8, %9
+ ret
+ }
+}
+)");
+}
+
+TEST_F(SpirvParserTest, VectorShuffle_Swizzle_FirstVector) {
+ EXPECT_IR(R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ %void = OpTypeVoid
+ %u32 = OpTypeInt 32 0
+ %v4u32 = OpTypeVector %u32 4
+ %v4u32_ptr = OpTypePointer Function %v4u32
+ %v2u32 = OpTypeVector %u32 2
+ %ep_type = OpTypeFunction %void
+
+ %main = OpFunction %void None %ep_type
+ %main_start = OpLabel
+ %vec = OpVariable %v4u32_ptr Function
+ %tmp = OpLoad %v4u32 %vec
+ %shuf = OpVectorShuffle %v2u32 %tmp %tmp 0 2
+ OpReturn
+ OpFunctionEnd
+ )",
+ R"(
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:ptr<function, vec4<u32>, read_write> = var undef
+ %3:vec4<u32> = load %2
+ %4:vec2<u32> = swizzle %3, xz
+ ret
+ }
+}
+)");
+}
+
+TEST_F(SpirvParserTest, VectorShuffle_Swizzle_SecondVector) {
+ EXPECT_IR(R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ %void = OpTypeVoid
+ %u32 = OpTypeInt 32 0
+ %v4u32 = OpTypeVector %u32 4
+ %v4u32_ptr = OpTypePointer Function %v4u32
+ %ep_type = OpTypeFunction %void
+
+ %main = OpFunction %void None %ep_type
+ %main_start = OpLabel
+ %vec1 = OpVariable %v4u32_ptr Function
+ %vec2 = OpVariable %v4u32_ptr Function
+ %tmp1 = OpLoad %v4u32 %vec1
+ %tmp2 = OpLoad %v4u32 %vec2
+ %shuf = OpVectorShuffle %v4u32 %tmp1 %tmp2 4 5 6 7
+ OpReturn
+ OpFunctionEnd
+ )",
+ R"(
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:ptr<function, vec4<u32>, read_write> = var undef
+ %3:ptr<function, vec4<u32>, read_write> = var undef
+ %4:vec4<u32> = load %2
+ %5:vec4<u32> = load %3
+ %6:vec4<u32> = swizzle %5, xyzw
+ ret
+ }
+}
+)");
+}
+
+TEST_F(SpirvParserTest, VectorShuffle_Swizzle_SmallerResult) {
+ EXPECT_IR(R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ %void = OpTypeVoid
+ %u32 = OpTypeInt 32 0
+ %v4u32 = OpTypeVector %u32 4
+ %v2u32 = OpTypeVector %u32 2
+ %v4u32_ptr = OpTypePointer Function %v4u32
+ %ep_type = OpTypeFunction %void
+
+ %main = OpFunction %void None %ep_type
+ %main_start = OpLabel
+ %vec = OpVariable %v4u32_ptr Function
+ %tmp = OpLoad %v4u32 %vec
+ %shuf = OpVectorShuffle %v2u32 %tmp %tmp 2 0
+ OpReturn
+ OpFunctionEnd
+ )",
+ R"(
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:ptr<function, vec4<u32>, read_write> = var undef
+ %3:vec4<u32> = load %2
+ %4:vec2<u32> = swizzle %3, zx
+ ret
+ }
+}
+)");
+}
+
+TEST_F(SpirvParserTest, VectorShuffle_Swizzle_BiggerResult) {
+ EXPECT_IR(R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ %void = OpTypeVoid
+ %u32 = OpTypeInt 32 0
+ %v2u32 = OpTypeVector %u32 2
+ %v4u32 = OpTypeVector %u32 4
+ %v2u32_ptr = OpTypePointer Function %v2u32
+ %ep_type = OpTypeFunction %void
+
+ %main = OpFunction %void None %ep_type
+ %main_start = OpLabel
+ %vec = OpVariable %v2u32_ptr Function
+ %tmp = OpLoad %v2u32 %vec
+ %shuf = OpVectorShuffle %v4u32 %tmp %tmp 3 2 3 2
+ OpReturn
+ OpFunctionEnd
+ )",
+ R"(
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:ptr<function, vec2<u32>, read_write> = var undef
+ %3:vec2<u32> = load %2
+ %4:vec4<u32> = swizzle %3, yxyx
+ ret
+ }
+}
+)");
+}
+
+TEST_F(SpirvParserTest, VectorShuffle_Swizzle_OneVectorUndef) {
+ EXPECT_IR(R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+
+ %void = OpTypeVoid
+ %u32 = OpTypeInt 32 0
+ %v2u32 = OpTypeVector %u32 2
+ %v4u32 = OpTypeVector %u32 4
+ %v2u32_ptr = OpTypePointer Function %v2u32
+ %ep_type = OpTypeFunction %void
+
+ %main = OpFunction %void None %ep_type
+ %main_start = OpLabel
+ %vec = OpVariable %v2u32_ptr Function
+ %tmp = OpLoad %v2u32 %vec
+ %undef4 = OpUndef %v4u32
+ %shuf = OpVectorShuffle %v2u32 %undef4 %tmp 4 5
+ OpReturn
+ OpFunctionEnd
+ )",
+ R"(
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:ptr<function, vec2<u32>, read_write> = var undef
+ %3:vec2<u32> = load %2
+ %4:vec2<u32> = swizzle %3, xy
+ ret
+ }
+}
+)");
+}
+
} // namespace
} // namespace tint::spirv::reader
diff --git a/src/tint/lang/spirv/reader/parser/parser.cc b/src/tint/lang/spirv/reader/parser/parser.cc
index 18160ed..3fc303f 100644
--- a/src/tint/lang/spirv/reader/parser/parser.cc
+++ b/src/tint/lang/spirv/reader/parser/parser.cc
@@ -1147,6 +1147,9 @@
case spv::Op::OpOuterProduct:
EmitSpirvBuiltinCall(inst, spirv::BuiltinFn::kOuterProduct);
break;
+ case spv::Op::OpVectorShuffle:
+ EmitVectorShuffle(inst);
+ break;
case spv::Op::OpAtomicStore:
EmitAtomicStore(inst);
break;
@@ -2201,6 +2204,76 @@
Emit(b_.Load(tmp), inst.result_id());
}
+ /// @param inst the SPIR-V instruction for OpVectorShuffle
+ void EmitVectorShuffle(const spvtools::opt::Instruction& inst) {
+ auto* vector1 = Value(inst.GetSingleWordOperand(2));
+ auto* vector2 = Value(inst.GetSingleWordOperand(3));
+ auto* result_ty = Type(inst.type_id());
+
+ uint32_t n1 = vector1->Type()->As<core::type::Vector>()->Width();
+ uint32_t n2 = vector2->Type()->As<core::type::Vector>()->Width();
+
+ Vector<uint32_t, 4> literals;
+ for (uint32_t i = 4; i < inst.NumOperandWords(); i++) {
+ literals.Push(inst.GetSingleWordOperand(i));
+ }
+
+ // Check if all literals fall entirely within `vector1` or `vector2`,
+ // which would allow us to use a single-vector swizzle.
+ bool swizzle_from_vector1_only = true;
+ bool swizzle_from_vector2_only = true;
+ for (auto& literal : literals) {
+ if (literal == ~0u) {
+ // A `0xFFFFFFFF` literal represents an undefined index,
+ // fallback to first index.
+ literal = 0;
+ }
+ if (literal >= n1) {
+ swizzle_from_vector1_only = false;
+ }
+ if (literal < n1) {
+ swizzle_from_vector2_only = false;
+ }
+ }
+
+ // If only one vector is used, we can swizzle it.
+ if (swizzle_from_vector1_only) {
+ // Indices are already within `[0, n1)`, as expected by `Swizzle` IR
+ // for `vector1`.
+ Emit(b_.Swizzle(result_ty, vector1, literals), inst.result_id());
+ return;
+ }
+ if (swizzle_from_vector2_only) {
+ // Map logical concatenated indices' range `[n1, n1 + n2)` into the range
+ // `[0, n2)`, as expected by `Swizzle` IR for `vector2`.
+ for (auto& literal : literals) {
+ literal -= n1;
+ }
+ Emit(b_.Swizzle(result_ty, vector2, literals), inst.result_id());
+ return;
+ }
+
+ // Swizzle is not possible, construct the result vector out of elements
+ // from both vectors.
+ auto* element_ty = vector1->Type()->DeepestElement();
+ Vector<core::ir::Value*, 4> result;
+ for (auto idx : literals) {
+ TINT_ASSERT(idx < n1 + n2);
+
+ if (idx < n1) {
+ auto* access_inst = b_.Access(element_ty, vector1, b_.Constant(u32(idx)));
+ EmitWithoutSpvResult(access_inst);
+ result.Push(access_inst->Result(0));
+ } else {
+ auto* access_inst = b_.Access(element_ty, vector2, b_.Constant(u32(idx - n1)));
+ EmitWithoutSpvResult(access_inst);
+ result.Push(access_inst->Result(0));
+ }
+ }
+
+ Emit(b_.Construct(result_ty, result), inst.result_id());
+ }
+
/// @param inst the SPIR-V instruction for OpFunctionCall
void EmitFunctionCall(const spvtools::opt::Instruction& inst) {
Emit(b_.Call(Function(inst.GetSingleWordInOperand(0)), Args(inst, 3)), inst.result_id());