[spirv-reader] Handle vector types and constants
Change the `Constant()` helper to return a `core::constant::Value*`
instead of a `core::ir::Constant*` so that it can be used recursively.
Test OpConstantNull too.
Bug: tint:1907, tint:2123
Change-Id: Ie65a410f4f412f8a79b6978305ef85f3edbd5da4
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/168202
Reviewed-by: David Neto <dneto@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/lang/spirv/reader/parser/constant_test.cc b/src/tint/lang/spirv/reader/parser/constant_test.cc
index d8d8d57..f323d86 100644
--- a/src/tint/lang/spirv/reader/parser/constant_test.cc
+++ b/src/tint/lang/spirv/reader/parser/constant_test.cc
@@ -286,5 +286,143 @@
)");
}
+TEST_F(SpirvParserTest, Constant_Vec2Bool) {
+ EXPECT_IR(R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ %void = OpTypeVoid
+ %bool = OpTypeBool
+ %vec2b = OpTypeVector %bool 2
+ %true = OpConstantTrue %bool
+ %false = OpConstantFalse %bool
+%vec2b_const = OpConstantComposite %vec2b %true %false
+ %null = OpConstantNull %vec2b
+ %void_fn = OpTypeFunction %void
+ %fn_type = OpTypeFunction %vec2b %vec2b
+
+ %main = OpFunction %void None %void_fn
+ %main_start = OpLabel
+ OpReturn
+ OpFunctionEnd
+
+ %foo = OpFunction %vec2b None %fn_type
+ %param = OpFunctionParameter %vec2b
+ %foo_start = OpLabel
+ OpReturnValue %param
+ OpFunctionEnd
+
+ %bar = OpFunction %void None %void_fn
+ %bar_start = OpLabel
+ %1 = OpFunctionCall %vec2b %foo %vec2b_const
+ %2 = OpFunctionCall %vec2b %foo %null
+ OpReturn
+ OpFunctionEnd
+)",
+ R"(
+%4 = func():void -> %b3 {
+ %b3 = block {
+ %5:vec2<bool> = call %2, vec2<bool>(true, false)
+ %6:vec2<bool> = call %2, vec2<bool>(false)
+ ret
+ }
+}
+)");
+}
+
+TEST_F(SpirvParserTest, Constant_Vec3I32) {
+ EXPECT_IR(R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ %void = OpTypeVoid
+ %i32 = OpTypeInt 32 1
+ %vec3i = OpTypeVector %i32 3
+ %i32_0 = OpConstant %i32 0
+ %i32_1 = OpConstant %i32 1
+ %i32_n1 = OpConstant %i32 -1
+%vec3i_const = OpConstantComposite %vec3i %i32_0 %i32_1 %i32_n1
+ %null = OpConstantNull %vec3i
+ %void_fn = OpTypeFunction %void
+ %fn_type = OpTypeFunction %vec3i %vec3i
+
+ %main = OpFunction %void None %void_fn
+ %main_start = OpLabel
+ OpReturn
+ OpFunctionEnd
+
+ %foo = OpFunction %vec3i None %fn_type
+ %param = OpFunctionParameter %vec3i
+ %foo_start = OpLabel
+ OpReturnValue %param
+ OpFunctionEnd
+
+ %bar = OpFunction %void None %void_fn
+ %bar_start = OpLabel
+ %1 = OpFunctionCall %vec3i %foo %vec3i_const
+ %2 = OpFunctionCall %vec3i %foo %null
+ OpReturn
+ OpFunctionEnd
+)",
+ R"(
+%4 = func():void -> %b3 {
+ %b3 = block {
+ %5:vec3<i32> = call %2, vec3<i32>(0i, 1i, -1i)
+ %6:vec3<i32> = call %2, vec3<i32>(0i)
+ ret
+ }
+}
+)");
+}
+
+TEST_F(SpirvParserTest, Constant_Vec4F32) {
+ EXPECT_IR(R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ %void = OpTypeVoid
+ %f32 = OpTypeFloat 32
+ %vec4f = OpTypeVector %f32 4
+ %f32_0 = OpConstant %f32 0
+ %f32_1 = OpConstant %f32 1
+ %f32_max = OpConstant %f32 0x1.fffffep+127
+ %f32_min = OpConstant %f32 -0x1.fffffep+127
+%vec4f_const = OpConstantComposite %vec4f %f32_0 %f32_1 %f32_max %f32_min
+ %null = OpConstantNull %vec4f
+ %void_fn = OpTypeFunction %void
+ %fn_type = OpTypeFunction %vec4f %vec4f
+
+ %main = OpFunction %void None %void_fn
+ %main_start = OpLabel
+ OpReturn
+ OpFunctionEnd
+
+ %foo = OpFunction %vec4f None %fn_type
+ %param = OpFunctionParameter %vec4f
+ %foo_start = OpLabel
+ OpReturnValue %param
+ OpFunctionEnd
+
+ %bar = OpFunction %void None %void_fn
+ %bar_start = OpLabel
+ %1 = OpFunctionCall %vec4f %foo %vec4f_const
+ %2 = OpFunctionCall %vec4f %foo %null
+ OpReturn
+ OpFunctionEnd
+)",
+ R"(
+%4 = func():void -> %b3 {
+ %b3 = block {
+ %5:vec4<f32> = call %2, vec4<f32>(0.0f, 1.0f, 340282346638528859811704183484516925440.0f, -340282346638528859811704183484516925440.0f)
+ %6:vec4<f32> = call %2, vec4<f32>(0.0f)
+ ret
+ }
+}
+)");
+}
+
} // namespace
} // namespace tint::spirv::reader
diff --git a/src/tint/lang/spirv/reader/parser/parser.cc b/src/tint/lang/spirv/reader/parser/parser.cc
index d7a8ef1..d068078 100644
--- a/src/tint/lang/spirv/reader/parser/parser.cc
+++ b/src/tint/lang/spirv/reader/parser/parser.cc
@@ -131,6 +131,11 @@
return ty_.void_();
}
}
+ case spvtools::opt::analysis::Type::kVector: {
+ auto* vec_ty = type->AsVector();
+ TINT_ASSERT_OR_RETURN_VALUE(vec_ty->element_count() <= 4, ty_.void_());
+ return ty_.vec(Type(vec_ty->element_type()), vec_ty->element_count());
+ }
case spvtools::opt::analysis::Type::kPointer: {
auto* ptr_ty = type->AsPointer();
return ty_.ptr(AddressSpace(ptr_ty->storage_class()), Type(ptr_ty->pointee_type()));
@@ -161,7 +166,7 @@
core::ir::Value* Value(uint32_t id) {
return values_.GetOrCreate(id, [&]() -> core::ir::Value* {
if (auto* c = spirv_context_->get_constant_mgr()->FindDeclaredConstant(id)) {
- return Constant(c);
+ return b_.Constant(Constant(c));
}
TINT_UNREACHABLE() << "missing value for result ID " << id;
return nullptr;
@@ -169,36 +174,43 @@
}
/// @param constant a SPIR-V constant object
- /// @returns a Tint constant object
- core::ir::Constant* Constant(const spvtools::opt::analysis::Constant* constant) {
+ /// @returns a Tint constant value
+ const core::constant::Value* Constant(const spvtools::opt::analysis::Constant* constant) {
// Handle OpConstantNull for all types.
if (constant->AsNullConstant()) {
- return b_.Constant(ir_.constant_values.Zero(Type(constant->type())));
+ return ir_.constant_values.Zero(Type(constant->type()));
}
if (auto* bool_ = constant->AsBoolConstant()) {
- return b_.Constant(bool_->value());
+ return b_.ConstantValue(bool_->value());
}
if (auto* i = constant->AsIntConstant()) {
auto* int_ty = i->type()->AsInteger();
TINT_ASSERT_OR_RETURN_VALUE(int_ty->width() == 32, nullptr);
if (int_ty->IsSigned()) {
- return b_.Constant(i32(i->GetS32BitValue()));
+ return b_.ConstantValue(i32(i->GetS32BitValue()));
} else {
- return b_.Constant(u32(i->GetU32BitValue()));
+ return b_.ConstantValue(u32(i->GetU32BitValue()));
}
}
if (auto* f = constant->AsFloatConstant()) {
auto* float_ty = f->type()->AsFloat();
if (float_ty->width() == 16) {
- return b_.Constant(f16::FromBits(static_cast<uint16_t>(f->words()[0])));
+ return b_.ConstantValue(f16::FromBits(static_cast<uint16_t>(f->words()[0])));
} else if (float_ty->width() == 32) {
- return b_.Constant(f32(f->GetFloat()));
+ return b_.ConstantValue(f32(f->GetFloat()));
} else {
TINT_UNREACHABLE() << "unsupported floating point type width";
return nullptr;
}
}
+ if (auto* v = constant->AsVectorConstant()) {
+ Vector<const core::constant::Value*, 4> elements;
+ for (auto& el : v->GetComponents()) {
+ elements.Push(Constant(el));
+ }
+ return ir_.constant_values.Composite(Type(v->type()), std::move(elements));
+ }
TINT_UNIMPLEMENTED() << "unhandled constant type";
return nullptr;
}