[spirv-reader] Handle matrix types and constants
Test OpConstantNull too.
Bug: tint:1907, tint:2123
Change-Id: I4d8cb96116793a84672caf937cf6d3492e741e22
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/168203
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/tint/lang/spirv/reader/parser/constant_test.cc b/src/tint/lang/spirv/reader/parser/constant_test.cc
index f323d86..4c27c4c 100644
--- a/src/tint/lang/spirv/reader/parser/constant_test.cc
+++ b/src/tint/lang/spirv/reader/parser/constant_test.cc
@@ -424,5 +424,104 @@
)");
}
+TEST_F(SpirvParserTest, Constant_Mat2x4F32) {
+ EXPECT_IR(R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ %void = OpTypeVoid
+ %f32 = OpTypeFloat 32
+ %vec4f = OpTypeVector %f32 4
+ %mat2x4f = OpTypeMatrix %vec4f 2
+ %f32_0 = OpConstant %f32 0
+ %f32_1 = OpConstant %f32 1
+%vec4f_const_0 = OpConstantComposite %vec4f %f32_0 %f32_0 %f32_0 %f32_0
+%vec4f_const_1 = OpConstantComposite %vec4f %f32_1 %f32_1 %f32_1 %f32_1
+%mat2x4f_const = OpConstantComposite %mat2x4f %vec4f_const_0 %vec4f_const_1
+ %null = OpConstantNull %mat2x4f
+ %void_fn = OpTypeFunction %void
+ %fn_type = OpTypeFunction %mat2x4f %mat2x4f
+
+ %main = OpFunction %void None %void_fn
+ %main_start = OpLabel
+ OpReturn
+ OpFunctionEnd
+
+ %foo = OpFunction %mat2x4f None %fn_type
+ %param = OpFunctionParameter %mat2x4f
+ %foo_start = OpLabel
+ OpReturnValue %param
+ OpFunctionEnd
+
+ %bar = OpFunction %void None %void_fn
+ %bar_start = OpLabel
+ %1 = OpFunctionCall %mat2x4f %foo %mat2x4f_const
+ %2 = OpFunctionCall %mat2x4f %foo %null
+ OpReturn
+ OpFunctionEnd
+)",
+ R"(
+%4 = func():void -> %b3 {
+ %b3 = block {
+ %5:mat2x4<f32> = call %2, mat2x4<f32>(vec4<f32>(0.0f), vec4<f32>(1.0f))
+ %6:mat2x4<f32> = call %2, mat2x4<f32>(vec4<f32>(0.0f))
+ ret
+ }
+}
+)");
+}
+
+TEST_F(SpirvParserTest, Constant_Mat3x2F16) {
+ EXPECT_IR(R"(
+ OpCapability Shader
+ OpCapability Float16
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ %void = OpTypeVoid
+ %f16 = OpTypeFloat 16
+ %vec2h = OpTypeVector %f16 2
+ %mat3x2h = OpTypeMatrix %vec2h 3
+ %f16_0 = OpConstant %f16 0
+ %f16_1 = OpConstant %f16 1
+ %f16_max = OpConstant %f16 0x1.ffcp+15
+%vec2h_const_0 = OpConstantComposite %vec2h %f16_0 %f16_0
+%vec2h_const_1 = OpConstantComposite %vec2h %f16_1 %f16_1
+%vec2h_const_max = OpConstantComposite %vec2h %f16_max %f16_max
+%mat3x2h_const = OpConstantComposite %mat3x2h %vec2h_const_0 %vec2h_const_1 %vec2h_const_max
+ %null = OpConstantNull %mat3x2h
+ %void_fn = OpTypeFunction %void
+ %fn_type = OpTypeFunction %mat3x2h %mat3x2h
+
+ %main = OpFunction %void None %void_fn
+ %main_start = OpLabel
+ OpReturn
+ OpFunctionEnd
+
+ %foo = OpFunction %mat3x2h None %fn_type
+ %param = OpFunctionParameter %mat3x2h
+ %foo_start = OpLabel
+ OpReturnValue %param
+ OpFunctionEnd
+
+ %bar = OpFunction %void None %void_fn
+ %bar_start = OpLabel
+ %1 = OpFunctionCall %mat3x2h %foo %mat3x2h_const
+ %2 = OpFunctionCall %mat3x2h %foo %null
+ OpReturn
+ OpFunctionEnd
+)",
+ R"(
+%4 = func():void -> %b3 {
+ %b3 = block {
+ %5:mat3x2<f16> = call %2, mat3x2<f16>(vec2<f16>(0.0h), vec2<f16>(1.0h), vec2<f16>(65504.0h))
+ %6:mat3x2<f16> = call %2, mat3x2<f16>(vec2<f16>(0.0h))
+ ret
+ }
+}
+)");
+}
+
} // namespace
} // namespace tint::spirv::reader
diff --git a/src/tint/lang/spirv/reader/parser/parser.cc b/src/tint/lang/spirv/reader/parser/parser.cc
index d068078..90e0f1d 100644
--- a/src/tint/lang/spirv/reader/parser/parser.cc
+++ b/src/tint/lang/spirv/reader/parser/parser.cc
@@ -136,6 +136,12 @@
TINT_ASSERT_OR_RETURN_VALUE(vec_ty->element_count() <= 4, ty_.void_());
return ty_.vec(Type(vec_ty->element_type()), vec_ty->element_count());
}
+ case spvtools::opt::analysis::Type::kMatrix: {
+ auto* mat_ty = type->AsMatrix();
+ TINT_ASSERT_OR_RETURN_VALUE(mat_ty->element_count() <= 4, ty_.void_());
+ return ty_.mat(As<core::type::Vector>(Type(mat_ty->element_type())),
+ mat_ty->element_count());
+ }
case spvtools::opt::analysis::Type::kPointer: {
auto* ptr_ty = type->AsPointer();
return ty_.ptr(AddressSpace(ptr_ty->storage_class()), Type(ptr_ty->pointee_type()));
@@ -211,6 +217,13 @@
}
return ir_.constant_values.Composite(Type(v->type()), std::move(elements));
}
+ if (auto* m = constant->AsMatrixConstant()) {
+ Vector<const core::constant::Value*, 4> columns;
+ for (auto& el : m->GetComponents()) {
+ columns.Push(Constant(el));
+ }
+ return ir_.constant_values.Composite(Type(m->type()), std::move(columns));
+ }
TINT_UNIMPLEMENTED() << "unhandled constant type";
return nullptr;
}