[spirv-reader] Handle float types and constants
Bug: tint:1907
Change-Id: I9b93794121a376fe24eb903df93a3e08509341df
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/165640
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/tint/lang/spirv/reader/parser/constant_test.cc b/src/tint/lang/spirv/reader/parser/constant_test.cc
index e37e214..78d6d63 100644
--- a/src/tint/lang/spirv/reader/parser/constant_test.cc
+++ b/src/tint/lang/spirv/reader/parser/constant_test.cc
@@ -168,5 +168,108 @@
)");
}
+TEST_F(SpirvParserTest, Constant_F16) {
+ EXPECT_IR(R"(
+ OpCapability Shader
+ OpCapability Float16
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ %void = OpTypeVoid
+ %f16 = OpTypeFloat 16
+ %f16_0 = OpConstant %f16 0
+ %f16_1 = OpConstant %f16 1
+ %f16_max = OpConstant %f16 0x1.ffcp+15
+ %f16_min = OpConstant %f16 -0x1.ffcp+15
+ %f16_denorm = OpConstant %f16 0x0.004p-14
+ %void_fn = OpTypeFunction %void
+ %fn_type = OpTypeFunction %void %f16
+
+ %main = OpFunction %void None %void_fn
+ %main_start = OpLabel
+ OpReturn
+ OpFunctionEnd
+
+ %foo = OpFunction %void None %fn_type
+ %param = OpFunctionParameter %f16
+ %foo_start = OpLabel
+ OpReturn
+ OpFunctionEnd
+
+ %bar = OpFunction %void None %void_fn
+ %bar_start = OpLabel
+ %1 = OpFunctionCall %void %foo %f16_0
+ %2 = OpFunctionCall %void %foo %f16_1
+ %3 = OpFunctionCall %void %foo %f16_max
+ %4 = OpFunctionCall %void %foo %f16_min
+ %5 = OpFunctionCall %void %foo %f16_denorm
+ OpReturn
+ OpFunctionEnd
+)",
+ R"(
+%4 = func():void -> %b3 {
+ %b3 = block {
+ %5:void = call %2, 0.0h
+ %6:void = call %2, 1.0h
+ %7:void = call %2, 65504.0h
+ %8:void = call %2, -65504.0h
+ %9:void = call %2, 0.00000005960464477539h
+ ret
+ }
+}
+)");
+}
+
+TEST_F(SpirvParserTest, Constant_F32) {
+ EXPECT_IR(R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ %void = OpTypeVoid
+ %f32 = OpTypeFloat 32
+ %f32_0 = OpConstant %f32 0
+ %f32_1 = OpConstant %f32 1
+ %f32_max = OpConstant %f32 0x1.fffffep+127
+ %f32_min = OpConstant %f32 -0x1.fffffep+127
+ %f32_denorm = OpConstant %f32 0x0.000002p-126
+ %void_fn = OpTypeFunction %void
+ %fn_type = OpTypeFunction %void %f32
+
+ %main = OpFunction %void None %void_fn
+ %main_start = OpLabel
+ OpReturn
+ OpFunctionEnd
+
+ %foo = OpFunction %void None %fn_type
+ %param = OpFunctionParameter %f32
+ %foo_start = OpLabel
+ OpReturn
+ OpFunctionEnd
+
+ %bar = OpFunction %void None %void_fn
+ %bar_start = OpLabel
+ %1 = OpFunctionCall %void %foo %f32_0
+ %2 = OpFunctionCall %void %foo %f32_1
+ %3 = OpFunctionCall %void %foo %f32_max
+ %4 = OpFunctionCall %void %foo %f32_min
+ %5 = OpFunctionCall %void %foo %f32_denorm
+ OpReturn
+ OpFunctionEnd
+)",
+ R"(
+%4 = func():void -> %b3 {
+ %b3 = block {
+ %5:void = call %2, 0.0f
+ %6:void = call %2, 1.0f
+ %7:void = call %2, 340282346638528859811704183484516925440.0f
+ %8:void = call %2, -340282346638528859811704183484516925440.0f
+ %9:void = call %2, 1.40129846e-45f
+ ret
+ }
+}
+)");
+}
+
} // namespace
} // namespace tint::spirv::reader
diff --git a/src/tint/lang/spirv/reader/parser/parser.cc b/src/tint/lang/spirv/reader/parser/parser.cc
index 160c131..3b1cf2d 100644
--- a/src/tint/lang/spirv/reader/parser/parser.cc
+++ b/src/tint/lang/spirv/reader/parser/parser.cc
@@ -104,6 +104,18 @@
return ty_.u32();
}
}
+ case spvtools::opt::analysis::Type::kFloat: {
+ auto* float_ty = type->AsFloat();
+ if (float_ty->width() == 16) {
+ return ty_.f16();
+ } else if (float_ty->width() == 32) {
+ return ty_.f32();
+ } else {
+ TINT_UNREACHABLE()
+ << "unsupported floating point type width: " << float_ty->width();
+ return ty_.void_();
+ }
+ }
default:
TINT_UNIMPLEMENTED() << "unhandled SPIR-V type: " << type->str();
return ty_.void_();
@@ -152,6 +164,17 @@
return b_.Constant(u32(i->GetU32BitValue()));
}
}
+ if (auto* f = constant->AsFloatConstant()) {
+ auto* float_ty = f->type()->AsFloat();
+ if (float_ty->width() == 16) {
+ return b_.Constant(f16::FromBits(static_cast<uint16_t>(f->words()[0])));
+ } else if (float_ty->width() == 32) {
+ return b_.Constant(f32(f->GetFloat()));
+ } else {
+ TINT_UNREACHABLE() << "unsupported floating point type width";
+ return nullptr;
+ }
+ }
TINT_UNIMPLEMENTED() << "unhandled constant type";
return nullptr;
}