tint/writer/wgsl: Support for F16 type, constructor, and convertor
This patch makes the WGSL writer support emitting f16 types, f16 literals,
f16 constructors and conversions. Unit tests are also implemented.
Bug: tint:1473, tint:1502
Change-Id: Id2a5eec54b95add330366cf141b36999e604a63b
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/95990
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Zhaoming Jiang <zhaoming.jiang@intel.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ryan Harrison <rharrison@chromium.org>
diff --git a/src/tint/writer/wgsl/generator_impl.cc b/src/tint/writer/wgsl/generator_impl.cc
index 26e0a29..50f23fc 100644
--- a/src/tint/writer/wgsl/generator_impl.cc
+++ b/src/tint/writer/wgsl/generator_impl.cc
@@ -258,6 +258,10 @@
return true;
},
[&](const ast::FloatLiteralExpression* l) { //
+ // f16 literals are also emitted as float value with suffix "h".
+ // Note that all normal and subnormal f16 values are normal f32 values, and since NaN
+ // and Inf are not allowed to be spelled in literal, it should be fine to emit f16
+ // literals in this way.
out << FloatToBitPreservingString(static_cast<float>(l->value)) << l->suffix;
return true;
},
@@ -402,9 +406,8 @@
return true;
},
[&](const ast::F16*) {
- diagnostics_.add_error(diag::System::Writer,
- "Type f16 is not completely implemented yet.");
- return false;
+ out << "f16";
+ return true;
},
[&](const ast::I32*) {
out << "i32";
diff --git a/src/tint/writer/wgsl/generator_impl_cast_test.cc b/src/tint/writer/wgsl/generator_impl_cast_test.cc
index c423943..9b9379b 100644
--- a/src/tint/writer/wgsl/generator_impl_cast_test.cc
+++ b/src/tint/writer/wgsl/generator_impl_cast_test.cc
@@ -21,7 +21,7 @@
using WgslGeneratorImplTest = TestHelper;
-TEST_F(WgslGeneratorImplTest, EmitExpression_Cast_Scalar) {
+TEST_F(WgslGeneratorImplTest, EmitExpression_Cast_Scalar_F32_From_I32) {
auto* cast = Construct<f32>(1_i);
WrapInFunction(cast);
@@ -32,7 +32,20 @@
EXPECT_EQ(out.str(), "f32(1i)");
}
-TEST_F(WgslGeneratorImplTest, EmitExpression_Cast_Vector) {
+TEST_F(WgslGeneratorImplTest, EmitExpression_Cast_Scalar_F16_From_I32) {
+ Enable(ast::Extension::kF16);
+
+ auto* cast = Construct<f16>(1_i);
+ WrapInFunction(cast);
+
+ GeneratorImpl& gen = Build();
+
+ std::stringstream out;
+ ASSERT_TRUE(gen.EmitExpression(out, cast)) << gen.error();
+ EXPECT_EQ(out.str(), "f16(1i)");
+}
+
+TEST_F(WgslGeneratorImplTest, EmitExpression_Cast_Vector_F32_From_I32) {
auto* cast = vec3<f32>(vec3<i32>(1_i, 2_i, 3_i));
WrapInFunction(cast);
@@ -43,5 +56,18 @@
EXPECT_EQ(out.str(), "vec3<f32>(vec3<i32>(1i, 2i, 3i))");
}
+TEST_F(WgslGeneratorImplTest, EmitExpression_Cast_Vector_F16_From_I32) {
+ Enable(ast::Extension::kF16);
+
+ auto* cast = vec3<f16>(vec3<i32>(1_i, 2_i, 3_i));
+ WrapInFunction(cast);
+
+ GeneratorImpl& gen = Build();
+
+ std::stringstream out;
+ ASSERT_TRUE(gen.EmitExpression(out, cast)) << gen.error();
+ EXPECT_EQ(out.str(), "vec3<f16>(vec3<i32>(1i, 2i, 3i))");
+}
+
} // namespace
} // namespace tint::writer::wgsl
diff --git a/src/tint/writer/wgsl/generator_impl_constructor_test.cc b/src/tint/writer/wgsl/generator_impl_constructor_test.cc
index ae9f2b7..07b10ad 100644
--- a/src/tint/writer/wgsl/generator_impl_constructor_test.cc
+++ b/src/tint/writer/wgsl/generator_impl_constructor_test.cc
@@ -51,7 +51,7 @@
EXPECT_THAT(gen.result(), HasSubstr("56779u"));
}
-TEST_F(WgslGeneratorImplTest, EmitConstructor_Float) {
+TEST_F(WgslGeneratorImplTest, EmitConstructor_F32) {
// Use a number close to 1<<30 but whose decimal representation ends in 0.
WrapInFunction(Expr(f32((1 << 30) - 4)));
@@ -61,7 +61,19 @@
EXPECT_THAT(gen.result(), HasSubstr("1073741824.0f"));
}
-TEST_F(WgslGeneratorImplTest, EmitConstructor_Type_Float) {
+TEST_F(WgslGeneratorImplTest, EmitConstructor_F16) {
+ Enable(ast::Extension::kF16);
+
+ // (1 << 15) - 8 = 32760 is not exactly representable as f16; the emitted literal is 32752.
+ WrapInFunction(Expr(f16((1 << 15) - 8)));
+
+ GeneratorImpl& gen = Build();
+
+ ASSERT_TRUE(gen.Generate()) << gen.error();
+ EXPECT_THAT(gen.result(), HasSubstr("32752.0h"));
+}
+
+TEST_F(WgslGeneratorImplTest, EmitConstructor_Type_F32) {
WrapInFunction(Construct<f32>(Expr(-1.2e-5_f)));
GeneratorImpl& gen = Build();
@@ -70,6 +82,17 @@
EXPECT_THAT(gen.result(), HasSubstr("f32(-0.000012f)"));
}
+TEST_F(WgslGeneratorImplTest, EmitConstructor_Type_F16) {
+ Enable(ast::Extension::kF16);
+
+ WrapInFunction(Construct<f16>(Expr(-1.2e-5_h)));
+
+ GeneratorImpl& gen = Build();
+
+ ASSERT_TRUE(gen.Generate()) << gen.error();
+ EXPECT_THAT(gen.result(), HasSubstr("f16(-1.19805336e-05h)"));
+}
+
TEST_F(WgslGeneratorImplTest, EmitConstructor_Type_Bool) {
WrapInFunction(Construct<bool>(true));
@@ -97,7 +120,7 @@
EXPECT_THAT(gen.result(), HasSubstr("u32(12345u)"));
}
-TEST_F(WgslGeneratorImplTest, EmitConstructor_Type_Vec) {
+TEST_F(WgslGeneratorImplTest, EmitConstructor_Type_Vec_F32) {
WrapInFunction(vec3<f32>(1_f, 2_f, 3_f));
GeneratorImpl& gen = Build();
@@ -106,7 +129,18 @@
EXPECT_THAT(gen.result(), HasSubstr("vec3<f32>(1.0f, 2.0f, 3.0f)"));
}
-TEST_F(WgslGeneratorImplTest, EmitConstructor_Type_Mat) {
+TEST_F(WgslGeneratorImplTest, EmitConstructor_Type_Vec_F16) {
+ Enable(ast::Extension::kF16);
+
+ WrapInFunction(vec3<f16>(1_h, 2_h, 3_h));
+
+ GeneratorImpl& gen = Build();
+
+ ASSERT_TRUE(gen.Generate()) << gen.error();
+ EXPECT_THAT(gen.result(), HasSubstr("vec3<f16>(1.0h, 2.0h, 3.0h)"));
+}
+
+TEST_F(WgslGeneratorImplTest, EmitConstructor_Type_Mat_F32) {
WrapInFunction(mat2x3<f32>(vec3<f32>(1_f, 2_f, 3_f), vec3<f32>(3_f, 4_f, 5_f)));
GeneratorImpl& gen = Build();
@@ -116,6 +150,18 @@
"vec3<f32>(3.0f, 4.0f, 5.0f))"));
}
+TEST_F(WgslGeneratorImplTest, EmitConstructor_Type_Mat_F16) {
+ Enable(ast::Extension::kF16);
+
+ WrapInFunction(mat2x3<f16>(vec3<f16>(1_h, 2_h, 3_h), vec3<f16>(3_h, 4_h, 5_h)));
+
+ GeneratorImpl& gen = Build();
+
+ ASSERT_TRUE(gen.Generate()) << gen.error();
+ EXPECT_THAT(gen.result(), HasSubstr("mat2x3<f16>(vec3<f16>(1.0h, 2.0h, 3.0h), "
+ "vec3<f16>(3.0h, 4.0h, 5.0h))"));
+}
+
TEST_F(WgslGeneratorImplTest, EmitConstructor_Type_Array) {
WrapInFunction(Construct(ty.array(ty.vec3<f32>(), 3_u), vec3<f32>(1_f, 2_f, 3_f),
vec3<f32>(4_f, 5_f, 6_f), vec3<f32>(7_f, 8_f, 9_f)));
diff --git a/src/tint/writer/wgsl/generator_impl_literal_test.cc b/src/tint/writer/wgsl/generator_impl_literal_test.cc
index 78d70f6..f33bd0e 100644
--- a/src/tint/writer/wgsl/generator_impl_literal_test.cc
+++ b/src/tint/writer/wgsl/generator_impl_literal_test.cc
@@ -25,7 +25,7 @@
// - 0 sign if sign is 0, 1 otherwise
// - 'exponent_bits' is placed in the exponent space.
// So, the exponent bias must already be included.
-f32 MakeFloat(uint32_t sign, uint32_t biased_exponent, uint32_t mantissa) {
+f32 MakeF32(uint32_t sign, uint32_t biased_exponent, uint32_t mantissa) {
const uint32_t sign_bit = sign ? 0x80000000u : 0u;
// The binary32 exponent is 8 bits, just below the sign.
const uint32_t exponent_bits = (biased_exponent & 0xffu) << 23;
@@ -40,18 +40,75 @@
return f32(result);
}
-struct FloatData {
+// Get the representation of an IEEE 754 binary16 floating point number with
+// - 0 sign if sign is 0, 1 otherwise
+// - 'f16_biased_exponent' (5 bits, bias of 15 already included) in the exponent space
+// - 'f16_mantissa' (10 bits) in the mantissa space; exponent 31 encodes Inf or NaN.
+f16 MakeF16(uint32_t sign, uint32_t f16_biased_exponent, uint16_t f16_mantissa) {
+ assert((f16_biased_exponent & 0xffffffe0u) == 0);
+ assert((f16_mantissa & 0xfc00u) == 0);
+
+ const uint32_t sign_bit = sign ? 0x80000000u : 0u;
+
+ // Rebias 15 -> 127 by adding 112. The all-ones f16 exponent (Inf/NaN) maps to 255, not 143.
+ const uint32_t masked_exponent = f16_biased_exponent & 0x1fu;
+ uint32_t f32_biased_exponent = (masked_exponent == 0x1fu) ? 0xffu : masked_exponent + 112;
+ assert((f32_biased_exponent & 0xffffff00u) == 0);
+
+ if (masked_exponent == 0) {
+ // +/- zero, or subnormal
+ if (f16_mantissa == 0) {
+ // +/- zero
+ return sign ? f16(-0.0f) : f16(0.0f);
+ }
+ // Subnormal f16, calc the corresponding exponent and mantissa of normal f32.
+ f32_biased_exponent += 1;
+ // There must be at least one of the 10 mantissa bits being 1, left-shift the mantissa bits
+ // until the most significant 1 bit is left-shifted to 10th bit (count from zero), which
+ // will be omitted in the resulting f32 mantissa part.
+ assert(f16_mantissa & 0x03ffu);
+ while ((f16_mantissa & 0x0400u) == 0) {
+ f16_mantissa = static_cast<uint16_t>(f16_mantissa << 1);
+ f32_biased_exponent--;
+ }
+ }
+
+ // The binary32 exponent is 8 bits, just below the sign.
+ const uint32_t f32_exponent_bits = (f32_biased_exponent & 0xffu) << 23;
+ // The mantissa is the bottom 23 bits; the 10 f16 mantissa bits become its top bits.
+ const uint32_t f32_mantissa_bits = (f16_mantissa & 0x03ffu) << 13;
+
+ uint32_t bits = sign_bit | f32_exponent_bits | f32_mantissa_bits;
+ float result = 0.0f;
+ static_assert(sizeof(result) == sizeof(bits),
+ "expected float and uint32_t to be the same size");
+ std::memcpy(&result, &bits, sizeof(bits));
+ return f16(result);
+}
+
+struct F32Data {
f32 value;
std::string expected;
};
-inline std::ostream& operator<<(std::ostream& out, FloatData data) {
+
+struct F16Data {
+ f16 value;
+ std::string expected;
+};
+
+inline std::ostream& operator<<(std::ostream& out, F32Data data) {
out << "{" << data.value << "," << data.expected << "}";
return out;
}
-using WgslGenerator_FloatLiteralTest = TestParamHelper<FloatData>;
+inline std::ostream& operator<<(std::ostream& out, F16Data data) {
+ out << "{" << data.value << "," << data.expected << "}";
+ return out;
+}
-TEST_P(WgslGenerator_FloatLiteralTest, Emit) {
+using WgslGenerator_F32LiteralTest = TestParamHelper<F32Data>;
+
+TEST_P(WgslGenerator_F32LiteralTest, Emit) {
auto* v = Expr(GetParam().value);
SetResolveOnBuild(false);
@@ -63,38 +120,37 @@
}
INSTANTIATE_TEST_SUITE_P(Zero,
- WgslGenerator_FloatLiteralTest,
- ::testing::ValuesIn(std::vector<FloatData>{
- {0_f, "0.0f"},
- {MakeFloat(0, 0, 0), "0.0f"},
- {MakeFloat(1, 0, 0), "-0.0f"}}));
+ WgslGenerator_F32LiteralTest,
+ ::testing::ValuesIn(std::vector<F32Data>{{0_f, "0.0f"},
+ {MakeF32(0, 0, 0), "0.0f"},
+ {MakeF32(1, 0, 0), "-0.0f"}}));
INSTANTIATE_TEST_SUITE_P(Normal,
- WgslGenerator_FloatLiteralTest,
- ::testing::ValuesIn(std::vector<FloatData>{{1_f, "1.0f"},
- {-1_f, "-1.0f"},
- {101.375_f, "101.375f"}}));
+ WgslGenerator_F32LiteralTest,
+ ::testing::ValuesIn(std::vector<F32Data>{{1_f, "1.0f"},
+ {-1_f, "-1.0f"},
+ {101.375_f, "101.375f"}}));
INSTANTIATE_TEST_SUITE_P(Subnormal,
- WgslGenerator_FloatLiteralTest,
- ::testing::ValuesIn(std::vector<FloatData>{
- {MakeFloat(0, 0, 1), "0x1p-149f"}, // Smallest
- {MakeFloat(1, 0, 1), "-0x1p-149f"},
- {MakeFloat(0, 0, 2), "0x1p-148f"},
- {MakeFloat(1, 0, 2), "-0x1p-148f"},
- {MakeFloat(0, 0, 0x7fffff), "0x1.fffffcp-127f"}, // Largest
- {MakeFloat(1, 0, 0x7fffff), "-0x1.fffffcp-127f"}, // Largest
- {MakeFloat(0, 0, 0xcafebe), "0x1.2bfaf8p-127f"}, // Scattered bits
- {MakeFloat(1, 0, 0xcafebe), "-0x1.2bfaf8p-127f"}, // Scattered bits
- {MakeFloat(0, 0, 0xaaaaa), "0x1.55554p-130f"}, // Scattered bits
- {MakeFloat(1, 0, 0xaaaaa), "-0x1.55554p-130f"}, // Scattered bits
+ WgslGenerator_F32LiteralTest,
+ ::testing::ValuesIn(std::vector<F32Data>{
+ {MakeF32(0, 0, 1), "0x1p-149f"}, // Smallest
+ {MakeF32(1, 0, 1), "-0x1p-149f"},
+ {MakeF32(0, 0, 2), "0x1p-148f"},
+ {MakeF32(1, 0, 2), "-0x1p-148f"},
+ {MakeF32(0, 0, 0x7fffff), "0x1.fffffcp-127f"}, // Largest
+ {MakeF32(1, 0, 0x7fffff), "-0x1.fffffcp-127f"}, // Largest
+ {MakeF32(0, 0, 0xcafebe), "0x1.2bfaf8p-127f"}, // Scattered bits
+ {MakeF32(1, 0, 0xcafebe), "-0x1.2bfaf8p-127f"}, // Scattered bits
+ {MakeF32(0, 0, 0xaaaaa), "0x1.55554p-130f"}, // Scattered bits
+ {MakeF32(1, 0, 0xaaaaa), "-0x1.55554p-130f"}, // Scattered bits
}));
INSTANTIATE_TEST_SUITE_P(Infinity,
- WgslGenerator_FloatLiteralTest,
- ::testing::ValuesIn(std::vector<FloatData>{
- {MakeFloat(0, 255, 0), "0x1p+128f"},
- {MakeFloat(1, 255, 0), "-0x1p+128f"}}));
+ WgslGenerator_F32LiteralTest,
+ ::testing::ValuesIn(std::vector<F32Data>{
+ {MakeF32(0, 255, 0), "0x1p+128f"},
+ {MakeF32(1, 255, 0), "-0x1p+128f"}}));
INSTANTIATE_TEST_SUITE_P(
// TODO(dneto): It's unclear how Infinity and NaN should be handled.
@@ -106,23 +162,95 @@
// whether the NaN is signalling or quiet, but no agreement between
// different machine architectures on whether 1 means signalling or
// if 1 means quiet.
- WgslGenerator_FloatLiteralTest,
- ::testing::ValuesIn(std::vector<FloatData>{
+ WgslGenerator_F32LiteralTest,
+ ::testing::ValuesIn(std::vector<F32Data>{
// LSB only. Smallest mantissa.
- {MakeFloat(0, 255, 1), "0x1.000002p+128f"}, // Smallest mantissa
- {MakeFloat(1, 255, 1), "-0x1.000002p+128f"},
+ {MakeF32(0, 255, 1), "0x1.000002p+128f"}, // Smallest mantissa
+ {MakeF32(1, 255, 1), "-0x1.000002p+128f"},
// MSB only.
- {MakeFloat(0, 255, 0x400000), "0x1.8p+128f"},
- {MakeFloat(1, 255, 0x400000), "-0x1.8p+128f"},
+ {MakeF32(0, 255, 0x400000), "0x1.8p+128f"},
+ {MakeF32(1, 255, 0x400000), "-0x1.8p+128f"},
// All 1s in the mantissa.
- {MakeFloat(0, 255, 0x7fffff), "0x1.fffffep+128f"},
- {MakeFloat(1, 255, 0x7fffff), "-0x1.fffffep+128f"},
+ {MakeF32(0, 255, 0x7fffff), "0x1.fffffep+128f"},
+ {MakeF32(1, 255, 0x7fffff), "-0x1.fffffep+128f"},
// Scattered bits, with 0 in top mantissa bit.
- {MakeFloat(0, 255, 0x20101f), "0x1.40203ep+128f"},
- {MakeFloat(1, 255, 0x20101f), "-0x1.40203ep+128f"},
+ {MakeF32(0, 255, 0x20101f), "0x1.40203ep+128f"},
+ {MakeF32(1, 255, 0x20101f), "-0x1.40203ep+128f"},
// Scattered bits, with 1 in top mantissa bit.
- {MakeFloat(0, 255, 0x40101f), "0x1.80203ep+128f"},
- {MakeFloat(1, 255, 0x40101f), "-0x1.80203ep+128f"}}));
+ {MakeF32(0, 255, 0x40101f), "0x1.80203ep+128f"},
+ {MakeF32(1, 255, 0x40101f), "-0x1.80203ep+128f"}}));
+
+using WgslGenerator_F16LiteralTest = TestParamHelper<F16Data>;
+
+TEST_P(WgslGenerator_F16LiteralTest, Emit) {
+ Enable(ast::Extension::kF16);
+
+ auto* v = Expr(GetParam().value);
+
+ SetResolveOnBuild(false);
+ GeneratorImpl& gen = Build();
+
+ std::stringstream out;
+ ASSERT_TRUE(gen.EmitLiteral(out, v)) << gen.error();
+ EXPECT_EQ(out.str(), GetParam().expected);
+}
+
+INSTANTIATE_TEST_SUITE_P(Zero,
+ WgslGenerator_F16LiteralTest,
+ ::testing::ValuesIn(std::vector<F16Data>{{0_h, "0.0h"},
+ {MakeF16(0, 0, 0), "0.0h"},
+ {MakeF16(1, 0, 0), "-0.0h"}}));
+
+INSTANTIATE_TEST_SUITE_P(Normal,
+ WgslGenerator_F16LiteralTest,
+ ::testing::ValuesIn(std::vector<F16Data>{{1_h, "1.0h"},
+ {-1_h, "-1.0h"},
+ {101.375_h, "101.375h"}}));
+
+INSTANTIATE_TEST_SUITE_P(Subnormal,
+ WgslGenerator_F16LiteralTest,
+ ::testing::ValuesIn(std::vector<F16Data>{
+ {MakeF16(0, 0, 1), "5.96046448e-08h"}, // Smallest
+ {MakeF16(1, 0, 1), "-5.96046448e-08h"},
+ {MakeF16(0, 0, 2), "1.1920929e-07h"},
+ {MakeF16(1, 0, 2), "-1.1920929e-07h"},
+ {MakeF16(0, 0, 0x3ffu), "6.09755516e-05h"}, // Largest
+ {MakeF16(1, 0, 0x3ffu), "-6.09755516e-05h"}, // Largest
+ {MakeF16(0, 0, 0x3afu), "5.620718e-05h"}, // Scattered bits
+ {MakeF16(1, 0, 0x3afu), "-5.620718e-05h"}, // Scattered bits
+ {MakeF16(0, 0, 0x2c7u), "4.23789024e-05h"}, // Scattered bits
+ {MakeF16(1, 0, 0x2c7u), "-4.23789024e-05h"}, // Scattered bits
+ }));
+
+INSTANTIATE_TEST_SUITE_P(
+ // Currently Inf is impossible to be spelled out in literal.
+ // https://github.com/gpuweb/gpuweb/issues/1769
+ DISABLED_Infinity,
+ WgslGenerator_F16LiteralTest,
+ ::testing::ValuesIn(std::vector<F16Data>{{MakeF16(0, 31, 0), "0x1p+128h"},
+ {MakeF16(1, 31, 0), "-0x1p+128h"}}));
+
+INSTANTIATE_TEST_SUITE_P(
+ // Currently NaN is impossible to be spelled out in literal.
+ // https://github.com/gpuweb/gpuweb/issues/1769
+ DISABLED_NaN,
+ WgslGenerator_F16LiteralTest,
+ ::testing::ValuesIn(std::vector<F16Data>{
+ // LSB only. Smallest mantissa.
+ {MakeF16(0, 31, 1), "0x1.004p+128h"}, // Smallest mantissa
+ {MakeF16(1, 31, 1), "-0x1.004p+128h"},
+ // MSB only.
+ {MakeF16(0, 31, 0x200u), "0x1.8p+128h"},
+ {MakeF16(1, 31, 0x200u), "-0x1.8p+128h"},
+ // All 1s in the mantissa.
+ {MakeF16(0, 31, 0x3ffu), "0x1.ffcp+128h"},
+ {MakeF16(1, 31, 0x3ffu), "-0x1.ffcp+128h"},
+ // Scattered bits, with 0 in top mantissa bit.
+ {MakeF16(0, 31, 0x11fu), "0x1.47cp+128h"},
+ {MakeF16(1, 31, 0x11fu), "-0x1.47cp+128h"},
+ // Scattered bits, with 1 in top mantissa bit.
+ {MakeF16(0, 31, 0x23fu), "0x1.8fcp+128h"},
+ {MakeF16(1, 31, 0x23fu), "-0x1.8fcp+128h"}}));
} // namespace
} // namespace tint::writer::wgsl
diff --git a/src/tint/writer/wgsl/generator_impl_type_test.cc b/src/tint/writer/wgsl/generator_impl_type_test.cc
index b6ae9c5..2e54620 100644
--- a/src/tint/writer/wgsl/generator_impl_type_test.cc
+++ b/src/tint/writer/wgsl/generator_impl_type_test.cc
@@ -91,6 +91,19 @@
EXPECT_EQ(out.str(), "f32");
}
+TEST_F(WgslGeneratorImplTest, EmitType_F16) {
+ Enable(ast::Extension::kF16);
+
+ auto* f16 = ty.f16();
+ Alias("make_type_reachable", f16);
+
+ GeneratorImpl& gen = Build();
+
+ std::stringstream out;
+ ASSERT_TRUE(gen.EmitType(out, f16)) << gen.error();
+ EXPECT_EQ(out.str(), "f16");
+}
+
TEST_F(WgslGeneratorImplTest, EmitType_I32) {
auto* i32 = ty.i32();
Alias("make_type_reachable", i32);
@@ -102,7 +115,7 @@
EXPECT_EQ(out.str(), "i32");
}
-TEST_F(WgslGeneratorImplTest, EmitType_Matrix) {
+TEST_F(WgslGeneratorImplTest, EmitType_Matrix_F32) {
auto* mat2x3 = ty.mat2x3<f32>();
Alias("make_type_reachable", mat2x3);
@@ -113,6 +126,19 @@
EXPECT_EQ(out.str(), "mat2x3<f32>");
}
+TEST_F(WgslGeneratorImplTest, EmitType_Matrix_F16) {
+ Enable(ast::Extension::kF16);
+
+ auto* mat2x3 = ty.mat2x3<f16>();
+ Alias("make_type_reachable", mat2x3);
+
+ GeneratorImpl& gen = Build();
+
+ std::stringstream out;
+ ASSERT_TRUE(gen.EmitType(out, mat2x3)) << gen.error();
+ EXPECT_EQ(out.str(), "mat2x3<f16>");
+}
+
TEST_F(WgslGeneratorImplTest, EmitType_Pointer) {
auto* p = ty.pointer<f32>(ast::StorageClass::kWorkgroup);
Alias("make_type_reachable", p);
@@ -271,7 +297,7 @@
EXPECT_EQ(out.str(), "u32");
}
-TEST_F(WgslGeneratorImplTest, EmitType_Vector) {
+TEST_F(WgslGeneratorImplTest, EmitType_Vector_F32) {
auto* vec3 = ty.vec3<f32>();
Alias("make_type_reachable", vec3);
@@ -282,6 +308,19 @@
EXPECT_EQ(out.str(), "vec3<f32>");
}
+TEST_F(WgslGeneratorImplTest, EmitType_Vector_F16) {
+ Enable(ast::Extension::kF16);
+
+ auto* vec3 = ty.vec3<f16>();
+ Alias("make_type_reachable", vec3);
+
+ GeneratorImpl& gen = Build();
+
+ std::stringstream out;
+ ASSERT_TRUE(gen.EmitType(out, vec3)) << gen.error();
+ EXPECT_EQ(out.str(), "vec3<f16>");
+}
+
struct TextureData {
ast::TextureDimension dim;
const char* name;
diff --git a/src/tint/writer/wgsl/generator_impl_variable_decl_statement_test.cc b/src/tint/writer/wgsl/generator_impl_variable_decl_statement_test.cc
index aeacd30..abfe4fec 100644
--- a/src/tint/writer/wgsl/generator_impl_variable_decl_statement_test.cc
+++ b/src/tint/writer/wgsl/generator_impl_variable_decl_statement_test.cc
@@ -125,6 +125,25 @@
)");
}
+TEST_F(WgslGeneratorImplTest, Emit_VariableDeclStatement_Const_f16) {
+ Enable(ast::Extension::kF16);
+
+ auto* C = Const("C", nullptr, Expr(1_h));
+ Func("f", {}, ty.void_(), {Decl(C), Decl(Let("l", nullptr, Expr(C)))});
+
+ GeneratorImpl& gen = Build();
+
+ ASSERT_TRUE(gen.Generate()) << gen.error();
+
+ EXPECT_EQ(gen.result(), R"(enable f16;
+
+fn f() {
+ const C = 1.0h;
+ let l = C;
+}
+)");
+}
+
TEST_F(WgslGeneratorImplTest, Emit_VariableDeclStatement_Const_vec3_AInt) {
auto* C = Const("C", nullptr, Construct(ty.vec3(nullptr), 1_a, 2_a, 3_a));
Func("f", {}, ty.void_(), {Decl(C), Decl(Let("l", nullptr, Expr(C)))});
@@ -170,6 +189,25 @@
)");
}
+TEST_F(WgslGeneratorImplTest, Emit_VariableDeclStatement_Const_vec3_f16) {
+ Enable(ast::Extension::kF16);
+
+ auto* C = Const("C", nullptr, vec3<f16>(1_h, 2_h, 3_h));
+ Func("f", {}, ty.void_(), {Decl(C), Decl(Let("l", nullptr, Expr(C)))});
+
+ GeneratorImpl& gen = Build();
+
+ ASSERT_TRUE(gen.Generate()) << gen.error();
+
+ EXPECT_EQ(gen.result(), R"(enable f16;
+
+fn f() {
+ const C = vec3<f16>(1.0h, 2.0h, 3.0h);
+ let l = C;
+}
+)");
+}
+
TEST_F(WgslGeneratorImplTest, Emit_VariableDeclStatement_Const_mat2x3_AFloat) {
auto* C =
Const("C", nullptr, Construct(ty.mat(nullptr, 2, 3), 1._a, 2._a, 3._a, 4._a, 5._a, 6._a));
@@ -201,6 +239,25 @@
)");
}
+TEST_F(WgslGeneratorImplTest, Emit_VariableDeclStatement_Const_mat2x3_f16) {
+ Enable(ast::Extension::kF16);
+
+ auto* C = Const("C", nullptr, mat2x3<f16>(1_h, 2_h, 3_h, 4_h, 5_h, 6_h));
+ Func("f", {}, ty.void_(), {Decl(C), Decl(Let("l", nullptr, Expr(C)))});
+
+ GeneratorImpl& gen = Build();
+
+ ASSERT_TRUE(gen.Generate()) << gen.error();
+
+ EXPECT_EQ(gen.result(), R"(enable f16;
+
+fn f() {
+ const C = mat2x3<f16>(1.0h, 2.0h, 3.0h, 4.0h, 5.0h, 6.0h);
+ let l = C;
+}
+)");
+}
+
TEST_F(WgslGeneratorImplTest, Emit_VariableDeclStatement_Const_arr_f32) {
auto* C = Const("C", nullptr, Construct(ty.array<f32, 3>(), 1_f, 2_f, 3_f));
Func("f", {}, ty.void_(), {Decl(C), Decl(Let("l", nullptr, Expr(C)))});