Add support for parsing hex floats to WGSL frontend
Implements the HEX_FLOAT_LITERAL token as specified in
https://gpuweb.github.io/gpuweb/wgsl/#literals.
Bug: tint:77
Change-Id: I09105df15a1888c2f0c84d7cccd2cc53e596f5cc
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/58781
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/reader/wgsl/lexer.cc b/src/reader/wgsl/lexer.cc
index c4eed5f..9ddcdb8 100644
--- a/src/reader/wgsl/lexer.cc
+++ b/src/reader/wgsl/lexer.cc
@@ -14,8 +14,12 @@
#include "src/reader/wgsl/lexer.h"
+#include <cmath>
+#include <cstring>
#include <limits>
+#include "src/debug.h"
+
namespace tint {
namespace reader {
namespace wgsl {
@@ -25,6 +29,26 @@
return std::isspace(c);
}
+/// @param c the character to convert
+/// @returns the numeric value of `c` when it is one of '0'..'9', otherwise 0
+uint32_t dec_value(char c) {
+  const bool is_decimal_digit = ('0' <= c) && (c <= '9');
+  return is_decimal_digit ? static_cast<uint32_t>(c - '0') : 0u;
+}
+
+/// @param c the character to convert
+/// @returns the numeric value of `c` when it is a hexadecimal digit (either
+/// case), otherwise 0
+uint32_t hex_value(char c) {
+  uint32_t value = 0;
+  if ('0' <= c && c <= '9') {
+    value = static_cast<uint32_t>(c - '0');
+  } else if ('a' <= c && c <= 'f') {
+    value = 0xA + static_cast<uint32_t>(c - 'a');
+  } else if ('A' <= c && c <= 'F') {
+    value = 0xA + static_cast<uint32_t>(c - 'A');
+  }
+  return value;
+}
+
} // namespace
Lexer::Lexer(const std::string& file_path, const Source::FileContent* content)
@@ -43,7 +67,12 @@
return {Token::Type::kEOF, begin_source()};
}
- auto t = try_hex_integer();
+ auto t = try_hex_float();
+ if (!t.IsUninitialized()) {
+ return t;
+ }
+
+ t = try_hex_integer();
if (!t.IsUninitialized()) {
return t;
}
@@ -239,6 +268,225 @@
return {source, static_cast<float>(res)};
}
+/// Attempts to lex a hexadecimal floating point literal starting at `pos_`.
+///
+/// The literal is converted directly to an IEEE 754 binary32 bit pattern:
+///  - low mantissa bits that binary32 cannot hold are truncated (no rounding),
+///  - values too large to represent, and bit patterns that would encode a
+///    NaN, become infinity,
+///  - values too small to represent become zero,
+///  - a zero significand produces zero whatever the exponent is.
+/// A literal whose digits do not fit the 32-bit working mantissa is not
+/// matched.
+///
+/// @returns the float token, or an uninitialized token (with `pos_` left
+/// unchanged) if the input at `pos_` is not a hex float
+Token Lexer::try_hex_float() {
+  constexpr uint32_t kTotalBits = 32;
+  constexpr uint32_t kTotalMsb = kTotalBits - 1;
+  constexpr uint32_t kMantissaBits = 23;
+  constexpr uint32_t kMantissaMsb = kMantissaBits - 1;
+  constexpr uint32_t kMantissaShiftRight = kTotalBits - kMantissaBits;
+  constexpr int32_t kExponentBias = 127;
+  constexpr int32_t kExponentMax = 255;
+  constexpr uint32_t kExponentBits = 8;
+  constexpr uint32_t kExponentMask = (1u << kExponentBits) - 1;
+  constexpr uint32_t kExponentLeftShift = kMantissaBits;
+  constexpr uint32_t kSignBit = 31;
+  // Once the decimal exponent reaches this value, further digits are consumed
+  // but no longer accumulated. This keeps `input_exponent` well inside the
+  // int32_t range (signed overflow is undefined behaviour). The threshold is
+  // far beyond any exponent that yields a finite, non-zero binary32 value, so
+  // the result is still infinity or zero. The only exception would be a
+  // literal with millions of leading zero fraction digits.
+  constexpr int32_t kInputExponentSaturation = 1 << 24;
+
+  auto start = pos_;
+  auto end = pos_;
+
+  auto source = begin_source();
+
+  // clang-format off
+  // -?0x([0-9a-fA-F]*.?[0-9a-fA-F]+ | [0-9a-fA-F]+.[0-9a-fA-F]*)(p|P)(+|-)?[0-9]+ // NOLINT
+  // clang-format on
+
+  // -?
+  // Kept unsigned so that shifting it into bit 31 is well defined.
+  uint32_t sign_bit = 0;
+  if (matches(end, "-")) {
+    sign_bit = 1;
+    end++;
+  }
+  // 0x
+  if (matches(end, "0x")) {
+    end += 2;
+  } else {
+    return {};
+  }
+
+  uint32_t mantissa = 0;
+  int32_t exponent = 0;
+
+  // `set_next_mantissa_bit_to` sets next `mantissa` bit starting from msb to
+  // lsb to value 1 if `set` is true, 0 otherwise.
+  // Returns false once all 32 bits have been written: `mantissa_next_bit`
+  // wraps around below zero, which makes it greater than `kTotalMsb`.
+  uint32_t mantissa_next_bit = kTotalMsb;
+  auto set_next_mantissa_bit_to = [&](bool set) -> bool {
+    if (mantissa_next_bit > kTotalMsb) {
+      return false;  // Overflowed mantissa
+    }
+    if (set) {
+      // Unsigned literal: the first bit written is bit 31, and `1 << 31` on a
+      // signed int would overflow.
+      mantissa |= (1u << mantissa_next_bit);
+    }
+    --mantissa_next_bit;
+    return true;
+  };
+
+  // Parse integer part
+  // [0-9a-fA-F]*
+  bool has_integer = false;
+  bool has_zero_integer = true;
+  bool leading_bit_seen = false;
+  while (end < len_ && is_hex(content_->data[end])) {
+    has_integer = true;
+
+    const auto nibble = hex_value(content_->data[end]);
+    if (nibble != 0) {
+      has_zero_integer = false;
+    }
+
+    for (int32_t i = 3; i >= 0; --i) {
+      auto v = 1 & (nibble >> i);
+
+      // Skip leading 0s and the first 1
+      if (leading_bit_seen) {
+        if (!set_next_mantissa_bit_to(v != 0)) {
+          return {};
+        }
+        // Every integer bit after the leading 1 scales the value by two.
+        ++exponent;
+      } else {
+        if (v == 1) {
+          leading_bit_seen = true;
+        }
+      }
+    }
+
+    end++;
+  }
+
+  // .?
+  if (matches(end, ".")) {
+    end++;
+  }
+
+  // Parse fractional part
+  // [0-9a-fA-F]*
+  bool has_fractional = false;
+  leading_bit_seen = false;
+  while (end < len_ && is_hex(content_->data[end])) {
+    has_fractional = true;
+    auto nibble = hex_value(content_->data[end]);
+    for (int32_t i = 3; i >= 0; --i) {
+      auto v = 1 & (nibble >> i);
+
+      if (v == 1) {
+        leading_bit_seen = true;
+      }
+
+      // If integer part is 0 (denorm), we only start writing bits to the
+      // mantissa once we have a non-zero fractional bit. While the fractional
+      // values are 0, we adjust the exponent to avoid overflowing `mantissa`.
+      if (has_zero_integer && !leading_bit_seen) {
+        --exponent;
+      } else {
+        if (!set_next_mantissa_bit_to(v != 0)) {
+          return {};
+        }
+      }
+    }
+
+    end++;
+  }
+
+  if (!(has_integer || has_fractional)) {
+    return {};
+  }
+
+  // (p|P)
+  if (matches(end, "p") || matches(end, "P")) {
+    end++;
+  } else {
+    return {};
+  }
+
+  // (+|-)?
+  int32_t exponent_sign = 1;
+  if (matches(end, "+")) {
+    end++;
+  } else if (matches(end, "-")) {
+    exponent_sign = -1;
+    end++;
+  }
+
+  // Parse exponent from input
+  // [0-9]+
+  bool has_exponent = false;
+  int32_t input_exponent = 0;
+  while (end < len_) {
+    // Explicit range check instead of isdigit(): passing a negative `char` to
+    // isdigit() is undefined behaviour.
+    const auto c = content_->data[end];
+    if (c < '0' || c > '9') {
+      break;
+    }
+    has_exponent = true;
+    if (input_exponent < kInputExponentSaturation) {
+      input_exponent =
+          (input_exponent * 10) + static_cast<int32_t>(dec_value(c));
+    }
+    end++;
+  }
+  if (!has_exponent) {
+    return {};
+  }
+
+  // The literal is syntactically valid from here on: commit the lexer state.
+  pos_ = end;
+  location_.column += (end - start);
+  end_source(source);
+
+  // Compute exponent so far
+  exponent = exponent + (input_exponent * exponent_sign);
+
+  // Determine if value is zero
+  // Note: it's not enough to check mantissa == 0 as we drop initial bit from
+  // integer part.
+  const bool is_zero = has_zero_integer && mantissa == 0;
+  TINT_ASSERT(Reader, !is_zero || mantissa == 0);
+  if (is_zero) {
+    // Zero times any power of two is zero, so the exponent is ignored.
+    // Without this, inputs such as `0x0p+5` or `0x0.0p0` would leak their
+    // exponent into the encoded result.
+    exponent = 0;
+  }
+
+  if (!is_zero) {
+    // Bias exponent if non-zero
+    // After this, if exponent is <= 0, our value is a denormal
+    exponent += kExponentBias;
+
+    // Denormal uses biased exponent of -126, not -127
+    if (has_zero_integer) {
+      mantissa <<= 1;
+      --exponent;
+    }
+  }
+
+  // Shift mantissa to occupy the low 23 bits
+  mantissa >>= kMantissaShiftRight;
+
+  // If denormal, shift mantissa until our exponent is zero
+  if (!is_zero) {
+    // Denorm has exponent 0 and non-zero mantissa. We set the top bit here,
+    // then shift the mantissa to make exponent zero.
+    if (exponent <= 0) {
+      mantissa >>= 1;
+      mantissa |= (1u << kMantissaMsb);
+    }
+
+    while (exponent < 0) {
+      mantissa >>= 1;
+      ++exponent;
+
+      // If underflow, clamp to zero
+      if (mantissa == 0) {
+        exponent = 0;
+      }
+    }
+  }
+
+  if (exponent > kExponentMax) {
+    // Overflow: set to infinity
+    exponent = kExponentMax;
+    mantissa = 0;
+  } else if (exponent == kExponentMax && mantissa != 0) {
+    // NaN: set to infinity
+    mantissa = 0;
+  }
+
+  // Combine sign, mantissa, and exponent
+  uint32_t result_u32 = sign_bit << kSignBit;
+  result_u32 |= mantissa;
+  result_u32 |= (static_cast<uint32_t>(exponent) & kExponentMask)
+                << kExponentLeftShift;
+
+  // Reinterpret as float and return
+  float result;
+  std::memcpy(&result, &result_u32, sizeof(result));
+  return {source, static_cast<float>(result)};
+}
+
Token Lexer::build_token_from_int_if_possible(Source source,
size_t start,
size_t end,
diff --git a/src/reader/wgsl/lexer.h b/src/reader/wgsl/lexer.h
index b1774e9..9c96bb5 100644
--- a/src/reader/wgsl/lexer.h
+++ b/src/reader/wgsl/lexer.h
@@ -46,6 +46,7 @@
int32_t base);
Token check_keyword(const Source&, const std::string&);
Token try_float();
+ Token try_hex_float();
Token try_hex_integer();
Token try_ident();
Token try_integer();
diff --git a/src/reader/wgsl/parser_impl_const_literal_test.cc b/src/reader/wgsl/parser_impl_const_literal_test.cc
index d73abaa..4dee6ce 100644
--- a/src/reader/wgsl/parser_impl_const_literal_test.cc
+++ b/src/reader/wgsl/parser_impl_const_literal_test.cc
@@ -14,17 +14,39 @@
#include "src/reader/wgsl/parser_impl_test_helper.h"
+#include <cmath>
+#include <cstring>
+
namespace tint {
namespace reader {
namespace wgsl {
namespace {
+/// Builds an IEEE 754 binary32 value directly from its three bit fields.
+/// @param sign zero leaves the sign bit clear, any other value sets it
+/// @param biased_exponent the exponent field; only its low 8 bits are used,
+/// and the exponent bias must already be included
+/// @param mantissa the fraction field; only its low 23 bits are used
+/// @returns the float whose bit pattern is the combination of the three fields
+float MakeFloat(int sign, int biased_exponent, int mantissa) {
+  // Field layout, most to least significant: 1 sign bit, 8 exponent bits,
+  // 23 mantissa bits.
+  uint32_t bits = (sign != 0) ? 0x80000000u : 0u;
+  bits |= (biased_exponent & 0xffu) << 23;
+  bits |= (mantissa & 0x7fffffu);
+
+  float result = 0.0f;
+  static_assert(sizeof(result) == sizeof(bits),
+                "expected float and uint32_t to be the same size");
+  std::memcpy(&result, &bits, sizeof(bits));
+  return result;
+}
+
TEST_F(ParserImplTest, ConstLiteral_Int) {
auto p = parser("-234");
auto c = p->const_literal();
EXPECT_TRUE(c.matched);
EXPECT_FALSE(c.errored);
- EXPECT_FALSE(p->has_error());
+ EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_NE(c.value, nullptr);
ASSERT_TRUE(c->Is<ast::SintLiteral>());
EXPECT_EQ(c->As<ast::SintLiteral>()->value(), -234);
@@ -36,7 +58,7 @@
auto c = p->const_literal();
EXPECT_TRUE(c.matched);
EXPECT_FALSE(c.errored);
- EXPECT_FALSE(p->has_error());
+ EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_NE(c.value, nullptr);
ASSERT_TRUE(c->Is<ast::UintLiteral>());
EXPECT_EQ(c->As<ast::UintLiteral>()->value(), 234u);
@@ -48,7 +70,7 @@
auto c = p->const_literal();
EXPECT_TRUE(c.matched);
EXPECT_FALSE(c.errored);
- EXPECT_FALSE(p->has_error());
+ EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_NE(c.value, nullptr);
ASSERT_TRUE(c->Is<ast::FloatLiteral>());
EXPECT_FLOAT_EQ(c->As<ast::FloatLiteral>()->value(), 234e12f);
@@ -63,12 +85,221 @@
ASSERT_EQ(c.value, nullptr);
}
+// One parameterized test case: the source text of a float literal and the
+// value the parsed literal is compared against.
+struct FloatLiteralTestCase {
+ // Source text handed to the parser.
+ const char* input;
+ // Value the parsed literal is expected to have.
+ float expected;
+};
+
+/// Writes the test case's input text (and nothing else) to `out`.
+/// @param out the stream to write to
+/// @param data the test case to describe
+/// @returns `out`, to allow chaining
+inline std::ostream& operator<<(std::ostream& out, FloatLiteralTestCase data) {
+  return out << data.input;
+}
+
+// Fixture for tests parameterized over FloatLiteralTestCase.
+class ParserImplFloatLiteralTest
+    : public ParserImplTestWithParam<FloatLiteralTestCase> {};
+
+// Parses the case's input as a const literal and checks that it produces a
+// float literal matching the expected value.
+// NOTE(review): EXPECT_FLOAT_EQ is a ULP-based comparison and appears to treat
+// +0.0 and -0.0 as equal, so the sign of zero results is probably not verified
+// by this test - confirm.
+TEST_P(ParserImplFloatLiteralTest, Parse) {
+  const auto test_case = GetParam();
+  SCOPED_TRACE(test_case.input);
+
+  auto p = parser(test_case.input);
+  auto literal = p->const_literal();
+  EXPECT_TRUE(literal.matched);
+  EXPECT_FALSE(literal.errored);
+  EXPECT_FALSE(p->has_error()) << p->error();
+  ASSERT_NE(literal.value, nullptr);
+  ASSERT_TRUE(literal->Is<ast::FloatLiteral>());
+
+  EXPECT_FLOAT_EQ(literal->As<ast::FloatLiteral>()->value(),
+                  test_case.expected);
+}
+
+// Decimal float literals paired with the value each is expected to parse to.
+// The first group states the expectation as a C++ float expression, the
+// second group states the same expectations as explicit bit patterns built
+// with MakeFloat.
+FloatLiteralTestCase float_literal_test_cases[] = {
+ {"0.0", 0.0f}, // Zero
+ {"1.0", 1.0f}, // One
+ {"-1.0", -1.0f}, // MinusOne
+ {"1000000000.0", 1e9f}, // Billion
+ {"-0.0", std::copysign(0.0f, -5.0f)}, // NegativeZero
+ {"0.0", MakeFloat(0, 0, 0)}, // Zero
+ {"-0.0", MakeFloat(1, 0, 0)}, // NegativeZero
+ {"1.0", MakeFloat(0, 127, 0)}, // One
+ {"-1.0", MakeFloat(1, 127, 0)}, // NegativeOne
+};
+INSTANTIATE_TEST_SUITE_P(ParserImplFloatLiteralTest_Float,
+ ParserImplFloatLiteralTest,
+ testing::ValuesIn(float_literal_test_cases));
+
+// Negative and positive infinity, built as bit patterns: exponent field 255
+// with a zero mantissa.
+const float NegInf = MakeFloat(1, 255, 0);
+const float PosInf = MakeFloat(0, 255, 0);
+// Hex float literals paired with the value each is expected to parse to.
+// Literals that are too large, or whose bit pattern would encode a NaN, are
+// expected to produce infinity; literals that are too small are expected to
+// produce zero.
+FloatLiteralTestCase hexfloat_literal_test_cases[] = {
+ // Regular numbers
+ {"0x0p+0", 0.f},
+ {"0x1p+0", 1.f},
+ {"0x1p+1", 2.f},
+ {"0x1.8p+1", 3.f},
+ {"0x1.99999ap-4", 0.1f},
+ {"0x1p-1", 0.5f},
+ {"0x1p-2", 0.25f},
+ {"0x1.8p-1", 0.75f},
+ {"-0x0p+0", -0.f},
+ {"-0x1p+0", -1.f},
+ {"-0x1p-1", -0.5f},
+ {"-0x1p-2", -0.25f},
+ {"-0x1.8p-1", -0.75f},
+
+ // Large numbers
+ {"0x1p+9", 512.f},
+ {"0x1p+10", 1024.f},
+ {"0x1.02p+10", 1024.f + 8.f},
+ {"-0x1p+9", -512.f},
+ {"-0x1p+10", -1024.f},
+ {"-0x1.02p+10", -1024.f - 8.f},
+
+ // Small numbers
+ {"0x1p-9", 1.0f / 512.f},
+ {"0x1p-10", 1.0f / 1024.f},
+ {"0x1.02p-3", 1.0f / 1024.f + 1.0f / 8.f},
+ {"-0x1p-9", 1.0f / -512.f},
+ {"-0x1p-10", 1.0f / -1024.f},
+ {"-0x1.02p-3", 1.0f / -1024.f - 1.0f / 8.f},
+
+ // Near lowest non-denorm
+ {"0x1p-124", std::ldexp(1.f * 8.f, -127)},
+ {"0x1p-125", std::ldexp(1.f * 4.f, -127)},
+ {"-0x1p-124", -std::ldexp(1.f * 8.f, -127)},
+ {"-0x1p-125", -std::ldexp(1.f * 4.f, -127)},
+
+ // Lowest non-denorm
+ {"0x1p-126", std::ldexp(1.f * 2.f, -127)},
+ {"-0x1p-126", -std::ldexp(1.f * 2.f, -127)},
+
+ // Denormalized values
+ {"0x1p-127", std::ldexp(1.f, -127)},
+ {"0x1p-128", std::ldexp(1.f / 2.f, -127)},
+ {"0x1p-129", std::ldexp(1.f / 4.f, -127)},
+ {"0x1p-130", std::ldexp(1.f / 8.f, -127)},
+ {"-0x1p-127", -std::ldexp(1.f, -127)},
+ {"-0x1p-128", -std::ldexp(1.f / 2.f, -127)},
+ {"-0x1p-129", -std::ldexp(1.f / 4.f, -127)},
+ {"-0x1p-130", -std::ldexp(1.f / 8.f, -127)},
+
+ {"0x1.8p-127", std::ldexp(1.f, -127) + (std::ldexp(1.f, -127) / 2.f)},
+ {"0x1.8p-128", std::ldexp(1.f, -127) / 2.f + (std::ldexp(1.f, -127) / 4.f)},
+
+ // Denormals stated as explicit bit patterns (exponent field 0)
+ {"0x1p-149", MakeFloat(0, 0, 1)}, // +SmallestDenormal
+ {"0x1p-148", MakeFloat(0, 0, 2)}, // +BiggerDenormal
+ {"0x1.fffffcp-127", MakeFloat(0, 0, 0x7fffff)}, // +LargestDenormal
+ {"-0x1p-149", MakeFloat(1, 0, 1)}, // -SmallestDenormal
+ {"-0x1p-148", MakeFloat(1, 0, 2)}, // -BiggerDenormal
+ {"-0x1.fffffcp-127", MakeFloat(1, 0, 0x7fffff)}, // -LargestDenormal
+
+ // Note: MakeFloat keeps only the low 23 bits of its mantissa argument, so
+ // 0xcafebe below is stored as 0x4afebe.
+ {"0x1.2bfaf8p-127", MakeFloat(0, 0, 0xcafebe)}, // +Subnormal
+ {"-0x1.2bfaf8p-127", MakeFloat(1, 0, 0xcafebe)}, // -Subnormal
+ {"0x1.55554p-130", MakeFloat(0, 0, 0xaaaaa)}, // +Subnormal
+ {"-0x1.55554p-130", MakeFloat(1, 0, 0xaaaaa)}, // -Subnormal
+
+ // NaN bit pattern (exponent field 255, non-zero mantissa) -> Infinity
+ {"0x1.8p+128", PosInf},
+ {"0x1.0002p+128", PosInf},
+ {"0x1.0018p+128", PosInf},
+ {"0x1.01ep+128", PosInf},
+ {"0x1.fffffep+128", PosInf},
+ {"-0x1.8p+128", NegInf},
+ {"-0x1.0002p+128", NegInf},
+ {"-0x1.0018p+128", NegInf},
+ {"-0x1.01ep+128", NegInf},
+ {"-0x1.fffffep+128", NegInf},
+
+ // Infinity
+ {"0x1p+128", PosInf},
+ {"-0x1p+128", NegInf},
+ {"0x32p+127", PosInf},
+ {"0x32p+500", PosInf},
+ {"-0x32p+127", NegInf},
+ {"-0x32p+500", NegInf},
+
+ // Overflow -> Infinity
+ {"0x1p+129", PosInf},
+ {"0x1.1p+128", PosInf},
+ {"-0x1p+129", NegInf},
+ {"-0x1.1p+128", NegInf},
+
+ // Underflow -> Zero
+ {"0x1p-500", 0.f}, // Exponent underflows
+ {"-0x1p-500", -0.f},
+ {"0x0.00000000001p-126", 0.f}, // Fraction causes underflow
+ {"-0x0.0000000001p-127", -0.f},
+ {"0x0.01p-142", 0.f},
+ {"-0x0.01p-142", -0.f}, // Fraction causes additional underflow
+
+ // Test parsing: redundant zeros, optional exponent sign, upper-case digits
+ // and upper-case 'P'
+ {"0x0p0", 0.f},
+ {"0x0p-0", 0.f},
+ {"0x0p+000", 0.f},
+ {"0x00000000000000p+000000000000000", 0.f},
+ {"0x00000000000000p-000000000000000", 0.f},
+ {"0x00000000000001p+000000000000000", 1.f},
+ {"0x00000000000001p-000000000000000", 1.f},
+ {"0x0000000000000000000001.99999ap-000000000000000004", 0.1f},
+ {"0x2p+0", 2.f},
+ {"0xFFp+0", 255.f},
+ {"0x0.8p+0", 0.5f},
+ {"0x0.4p+0", 0.25f},
+ {"0x0.4p+1", 2 * 0.25f},
+ {"0x0.4p+2", 4 * 0.25f},
+ {"0x123Ep+1", 9340.f},
+ {"-0x123Ep+1", -9340.f},
+ {"0x1a2b3cP12", 7.024656e+09f},
+ {"-0x1a2b3cP12", -7.024656e+09f},
+};
+INSTANTIATE_TEST_SUITE_P(ParserImplFloatLiteralTest_HexFloat,
+ ParserImplFloatLiteralTest,
+ testing::ValuesIn(hexfloat_literal_test_cases));
+
+// Checks that the largest finite binary32 value, written out as a decimal
+// literal, parses to std::numeric_limits<float>::max().
+TEST_F(ParserImplTest, ConstLiteral_FloatHighest) {
+  // Some compilers complain if you test floating point numbers for equality.
+  // So say it via two inequalities.
+  const auto float_max = std::numeric_limits<float>::max();
+  const auto binary32_max = 340282346638528859811704183484516925440.0f;
+  if (float_max < binary32_max || float_max > binary32_max) {
+    GTEST_SKIP() << "std::numeric_limits<float>::max() is not as expected for "
+                    "this target";
+  }
+
+  auto p = parser("340282346638528859811704183484516925440.0");
+  auto literal = p->const_literal();
+  EXPECT_TRUE(literal.matched);
+  EXPECT_FALSE(literal.errored);
+  EXPECT_FALSE(p->has_error()) << p->error();
+  ASSERT_NE(literal.value, nullptr);
+  ASSERT_TRUE(literal->Is<ast::FloatLiteral>());
+  EXPECT_FLOAT_EQ(literal->As<ast::FloatLiteral>()->value(), float_max);
+
+  // The literal spans columns 1 through 41 of line 1.
+  const Source::Range expected_range{{1u, 1u}, {1u, 42u}};
+  EXPECT_EQ(literal->source().range, expected_range);
+}
+
+// Checks that the most negative finite binary32 value, written out as a
+// decimal literal, parses to std::numeric_limits<float>::lowest().
+TEST_F(ParserImplTest, ConstLiteral_FloatLowest) {
+  // Direct floating point equality tests trigger warnings on some compilers,
+  // so equality is expressed as a pair of inequalities.
+  const auto float_lowest = std::numeric_limits<float>::lowest();
+  const auto binary32_lowest = -340282346638528859811704183484516925440.0f;
+  if (float_lowest < binary32_lowest || float_lowest > binary32_lowest) {
+    GTEST_SKIP()
+        << "std::numeric_limits<float>::lowest() is not as expected for "
+           "this target";
+  }
+
+  auto p = parser("-340282346638528859811704183484516925440.0");
+  auto literal = p->const_literal();
+  EXPECT_TRUE(literal.matched);
+  EXPECT_FALSE(literal.errored);
+  EXPECT_FALSE(p->has_error()) << p->error();
+  ASSERT_NE(literal.value, nullptr);
+  ASSERT_TRUE(literal->Is<ast::FloatLiteral>());
+  EXPECT_FLOAT_EQ(literal->As<ast::FloatLiteral>()->value(), float_lowest);
+
+  // The literal, including its leading '-', spans columns 1 through 42 of
+  // line 1.
+  const Source::Range expected_range{{1u, 1u}, {1u, 43u}};
+  EXPECT_EQ(literal->source().range, expected_range);
+}
+
TEST_F(ParserImplTest, ConstLiteral_True) {
auto p = parser("true");
auto c = p->const_literal();
EXPECT_TRUE(c.matched);
EXPECT_FALSE(c.errored);
- EXPECT_FALSE(p->has_error());
+ EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_NE(c.value, nullptr);
ASSERT_TRUE(c->Is<ast::BoolLiteral>());
EXPECT_TRUE(c->As<ast::BoolLiteral>()->IsTrue());
@@ -80,7 +311,7 @@
auto c = p->const_literal();
EXPECT_TRUE(c.matched);
EXPECT_FALSE(c.errored);
- EXPECT_FALSE(p->has_error());
+ EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_NE(c.value, nullptr);
ASSERT_TRUE(c->Is<ast::BoolLiteral>());
EXPECT_TRUE(c->As<ast::BoolLiteral>()->IsFalse());
@@ -92,7 +323,7 @@
auto c = p->const_literal();
EXPECT_FALSE(c.matched);
EXPECT_FALSE(c.errored);
- EXPECT_FALSE(p->has_error());
+ EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_EQ(c.value, nullptr);
}