Add support for parsing hex floats to WGSL frontend

As per https://gpuweb.github.io/gpuweb/wgsl/#literals, HEX_FLOAT_LITERAL
token.

Bug: tint:77
Change-Id: I09105df15a1888c2f0c84d7cccd2cc53e596f5cc
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/58781
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/reader/wgsl/lexer.cc b/src/reader/wgsl/lexer.cc
index c4eed5f..9ddcdb8 100644
--- a/src/reader/wgsl/lexer.cc
+++ b/src/reader/wgsl/lexer.cc
@@ -14,8 +14,12 @@
 
 #include "src/reader/wgsl/lexer.h"
 
+#include <cmath>
+#include <cstring>
 #include <limits>
 
+#include "src/debug.h"
+
 namespace tint {
 namespace reader {
 namespace wgsl {
@@ -25,6 +29,26 @@
   return std::isspace(c);
 }
 
+uint32_t dec_value(char c) {
+  if (c >= '0' && c <= '9') {
+    return static_cast<uint32_t>(c - '0');
+  }
+  return 0;
+}
+
+uint32_t hex_value(char c) {
+  if (c >= '0' && c <= '9') {
+    return static_cast<uint32_t>(c - '0');
+  }
+  if (c >= 'a' && c <= 'f') {
+    return 0xA + static_cast<uint32_t>(c - 'a');
+  }
+  if (c >= 'A' && c <= 'F') {
+    return 0xA + static_cast<uint32_t>(c - 'A');
+  }
+  return 0;
+}
+
 }  // namespace
 
 Lexer::Lexer(const std::string& file_path, const Source::FileContent* content)
@@ -43,7 +67,12 @@
     return {Token::Type::kEOF, begin_source()};
   }
 
-  auto t = try_hex_integer();
+  auto t = try_hex_float();
+  if (!t.IsUninitialized()) {
+    return t;
+  }
+
+  t = try_hex_integer();
   if (!t.IsUninitialized()) {
     return t;
   }
@@ -239,6 +268,225 @@
   return {source, static_cast<float>(res)};
 }
 
+Token Lexer::try_hex_float() {
+  constexpr uint32_t kTotalBits = 32;
+  constexpr uint32_t kTotalMsb = kTotalBits - 1;
+  constexpr uint32_t kMantissaBits = 23;
+  constexpr uint32_t kMantissaMsb = kMantissaBits - 1;
+  constexpr uint32_t kMantissaShiftRight = kTotalBits - kMantissaBits;
+  constexpr int32_t kExponentBias = 127;
+  constexpr int32_t kExponentMax = 255;
+  constexpr uint32_t kExponentBits = 8;
+  constexpr uint32_t kExponentMask = (1 << kExponentBits) - 1;
+  constexpr uint32_t kExponentLeftShift = kMantissaBits;
+  constexpr uint32_t kSignBit = 31;
+
+  auto start = pos_;
+  auto end = pos_;
+
+  auto source = begin_source();
+
+  // clang-format off
+  // -?0x([0-9a-fA-F]*.?[0-9a-fA-F]+ | [0-9a-fA-F]+.[0-9a-fA-F]*)(p|P)(+|-)?[0-9]+  // NOLINT
+  // clang-format on
+
+  // -?
+  int32_t sign_bit = 0;
+  if (matches(end, "-")) {
+    sign_bit = 1;
+    end++;
+  }
+  // 0x
+  if (matches(end, "0x")) {
+    end += 2;
+  } else {
+    return {};
+  }
+
+  uint32_t mantissa = 0;
+  int32_t exponent = 0;
+
+  // `set_next_mantissa_bit_to` sets next `mantissa` bit starting from msb to
+  // lsb to value 1 if `set` is true, 0 otherwise
+  uint32_t mantissa_next_bit = kTotalMsb;
+  auto set_next_mantissa_bit_to = [&](bool set) -> bool {
+    if (mantissa_next_bit > kTotalMsb) {
+      return false;  // Overflowed mantissa
+    }
+    if (set) {
+      mantissa |= (1 << mantissa_next_bit);
+    }
+    --mantissa_next_bit;
+    return true;
+  };
+
+  // Parse integer part
+  // [0-9a-fA-F]*
+  bool has_integer = false;
+  bool has_zero_integer = true;
+  bool leading_bit_seen = false;
+  while (end < len_ && is_hex(content_->data[end])) {
+    has_integer = true;
+
+    const auto nibble = hex_value(content_->data[end]);
+    if (nibble != 0) {
+      has_zero_integer = false;
+    }
+
+    for (int32_t i = 3; i >= 0; --i) {
+      auto v = 1 & (nibble >> i);
+
+      // Skip leading 0s and the first 1
+      if (leading_bit_seen) {
+        if (!set_next_mantissa_bit_to(v != 0)) {
+          return {};
+        }
+        ++exponent;
+      } else {
+        if (v == 1) {
+          leading_bit_seen = true;
+        }
+      }
+    }
+
+    end++;
+  }
+
+  // .?
+  if (matches(end, ".")) {
+    end++;
+  }
+
+  // Parse fractional part
+  // [0-9a-fA-F]*
+  bool has_fractional = false;
+  leading_bit_seen = false;
+  while (end < len_ && is_hex(content_->data[end])) {
+    has_fractional = true;
+    auto nibble = hex_value(content_->data[end]);
+    for (int32_t i = 3; i >= 0; --i) {
+      auto v = 1 & (nibble >> i);
+
+      if (v == 1) {
+        leading_bit_seen = true;
+      }
+
+      // If integer part is 0 (denorm), we only start writing bits to the
+      // mantissa once we have a non-zero fractional bit. While the fractional
+      // values are 0, we adjust the exponent to avoid overflowing `mantissa`.
+      if (has_zero_integer && !leading_bit_seen) {
+        --exponent;
+      } else {
+        if (!set_next_mantissa_bit_to(v != 0)) {
+          return {};
+        }
+      }
+    }
+
+    end++;
+  }
+
+  if (!(has_integer || has_fractional)) {
+    return {};
+  }
+
+  // (p|P)
+  if (matches(end, "p") || matches(end, "P")) {
+    end++;
+  } else {
+    return {};
+  }
+
+  // (+|-)?
+  int32_t exponent_sign = 1;
+  if (matches(end, "+")) {
+    end++;
+  } else if (matches(end, "-")) {
+    exponent_sign = -1;
+    end++;
+  }
+
+  // Parse exponent from input
+  // [0-9]+
+  bool has_exponent = false;
+  int32_t input_exponent = 0;
+  while (end < len_ && isdigit(content_->data[end])) {
+    has_exponent = true;
+    input_exponent = (input_exponent * 10) + dec_value(content_->data[end]);
+    end++;
+  }
+  if (!has_exponent) {
+    return {};
+  }
+
+  pos_ = end;
+  location_.column += (end - start);
+  end_source(source);
+
+  // Compute exponent so far
+  exponent = exponent + (input_exponent * exponent_sign);
+
+  // Determine if value is zero
+  // Note: it's not enough to check mantissa == 0 as we drop initial bit from
+  // integer part.
+  bool is_zero = has_zero_integer && mantissa == 0;
+  TINT_ASSERT(Reader, !is_zero || (exponent == 0 && mantissa == 0));
+
+  if (!is_zero) {
+    // Bias exponent if non-zero
+    // After this, if exponent is <= 0, our value is a denormal
+    exponent += kExponentBias;
+
+    // Denormal uses biased exponent of -126, not -127
+    if (has_zero_integer) {
+      mantissa <<= 1;
+      --exponent;
+    }
+  }
+
+  // Shift mantissa to occupy the low 23 bits
+  mantissa >>= kMantissaShiftRight;
+
+  // If denormal, shift mantissa until our exponent is zero
+  if (!is_zero) {
+    // Denorm has exponent 0 and non-zero mantissa. We set the top bit here,
+    // then shift the mantissa to make exponent zero.
+    if (exponent <= 0) {
+      mantissa >>= 1;
+      mantissa |= (1 << kMantissaMsb);
+    }
+
+    while (exponent < 0) {
+      mantissa >>= 1;
+      ++exponent;
+
+      // If underflow, clamp to zero
+      if (mantissa == 0) {
+        exponent = 0;
+      }
+    }
+  }
+
+  if (exponent > kExponentMax) {
+    // Overflow: set to infinity
+    exponent = kExponentMax;
+    mantissa = 0;
+  } else if (exponent == kExponentMax && mantissa != 0) {
+    // NaN: set to infinity
+    mantissa = 0;
+  }
+
+  // Combine sign, mantissa, and exponent
+  uint32_t result_u32 = sign_bit << kSignBit;
+  result_u32 |= mantissa;
+  result_u32 |= (exponent & kExponentMask) << kExponentLeftShift;
+
+  // Reinterpret as float and return
+  float result;
+  std::memcpy(&result, &result_u32, sizeof(result));
+  return {source, static_cast<float>(result)};
+}
+
 Token Lexer::build_token_from_int_if_possible(Source source,
                                               size_t start,
                                               size_t end,
diff --git a/src/reader/wgsl/lexer.h b/src/reader/wgsl/lexer.h
index b1774e9..9c96bb5 100644
--- a/src/reader/wgsl/lexer.h
+++ b/src/reader/wgsl/lexer.h
@@ -46,6 +46,7 @@
                                          int32_t base);
   Token check_keyword(const Source&, const std::string&);
   Token try_float();
+  Token try_hex_float();
   Token try_hex_integer();
   Token try_ident();
   Token try_integer();
diff --git a/src/reader/wgsl/parser_impl_const_literal_test.cc b/src/reader/wgsl/parser_impl_const_literal_test.cc
index d73abaa..4dee6ce 100644
--- a/src/reader/wgsl/parser_impl_const_literal_test.cc
+++ b/src/reader/wgsl/parser_impl_const_literal_test.cc
@@ -14,17 +14,39 @@
 
 #include "src/reader/wgsl/parser_impl_test_helper.h"
 
+#include <cmath>
+#include <cstring>
+
 namespace tint {
 namespace reader {
 namespace wgsl {
 namespace {
 
+// Makes an IEEE 754 binary32 floating point number with
+// - 0 sign if sign is 0, 1 otherwise
+// - 'exponent_bits' is placed in the exponent space.
+//   So, the exponent bias must already be included.
+float MakeFloat(int sign, int biased_exponent, int mantissa) {
+  const uint32_t sign_bit = sign ? 0x80000000u : 0u;
+  // The binary32 exponent is 8 bits, just below the sign.
+  const uint32_t exponent_bits = (biased_exponent & 0xffu) << 23;
+  // The mantissa is the bottom 23 bits.
+  const uint32_t mantissa_bits = (mantissa & 0x7fffffu);
+
+  uint32_t bits = sign_bit | exponent_bits | mantissa_bits;
+  float result = 0.0f;
+  static_assert(sizeof(result) == sizeof(bits),
+                "expected float and uint32_t to be the same size");
+  std::memcpy(&result, &bits, sizeof(bits));
+  return result;
+}
+
 TEST_F(ParserImplTest, ConstLiteral_Int) {
   auto p = parser("-234");
   auto c = p->const_literal();
   EXPECT_TRUE(c.matched);
   EXPECT_FALSE(c.errored);
-  EXPECT_FALSE(p->has_error());
+  EXPECT_FALSE(p->has_error()) << p->error();
   ASSERT_NE(c.value, nullptr);
   ASSERT_TRUE(c->Is<ast::SintLiteral>());
   EXPECT_EQ(c->As<ast::SintLiteral>()->value(), -234);
@@ -36,7 +58,7 @@
   auto c = p->const_literal();
   EXPECT_TRUE(c.matched);
   EXPECT_FALSE(c.errored);
-  EXPECT_FALSE(p->has_error());
+  EXPECT_FALSE(p->has_error()) << p->error();
   ASSERT_NE(c.value, nullptr);
   ASSERT_TRUE(c->Is<ast::UintLiteral>());
   EXPECT_EQ(c->As<ast::UintLiteral>()->value(), 234u);
@@ -48,7 +70,7 @@
   auto c = p->const_literal();
   EXPECT_TRUE(c.matched);
   EXPECT_FALSE(c.errored);
-  EXPECT_FALSE(p->has_error());
+  EXPECT_FALSE(p->has_error()) << p->error();
   ASSERT_NE(c.value, nullptr);
   ASSERT_TRUE(c->Is<ast::FloatLiteral>());
   EXPECT_FLOAT_EQ(c->As<ast::FloatLiteral>()->value(), 234e12f);
@@ -63,12 +85,221 @@
   ASSERT_EQ(c.value, nullptr);
 }
 
+struct FloatLiteralTestCase {
+  const char* input;
+  float expected;
+};
+
+inline std::ostream& operator<<(std::ostream& out, FloatLiteralTestCase data) {
+  out << data.input;
+  return out;
+}
+
+class ParserImplFloatLiteralTest
+    : public ParserImplTestWithParam<FloatLiteralTestCase> {};
+TEST_P(ParserImplFloatLiteralTest, Parse) {
+  auto params = GetParam();
+  SCOPED_TRACE(params.input);
+  auto p = parser(params.input);
+  auto c = p->const_literal();
+  EXPECT_TRUE(c.matched);
+  EXPECT_FALSE(c.errored);
+  EXPECT_FALSE(p->has_error()) << p->error();
+  ASSERT_NE(c.value, nullptr);
+  ASSERT_TRUE(c->Is<ast::FloatLiteral>());
+  EXPECT_FLOAT_EQ(c->As<ast::FloatLiteral>()->value(), params.expected);
+}
+
+FloatLiteralTestCase float_literal_test_cases[] = {
+    {"0.0", 0.0f},                         // Zero
+    {"1.0", 1.0f},                         // One
+    {"-1.0", -1.0f},                       // MinusOne
+    {"1000000000.0", 1e9f},                // Billion
+    {"-0.0", std::copysign(0.0f, -5.0f)},  // NegativeZero
+    {"0.0", MakeFloat(0, 0, 0)},           // Zero
+    {"-0.0", MakeFloat(1, 0, 0)},          // NegativeZero
+    {"1.0", MakeFloat(0, 127, 0)},         // One
+    {"-1.0", MakeFloat(1, 127, 0)},        // NegativeOne
+};
+INSTANTIATE_TEST_SUITE_P(ParserImplFloatLiteralTest_Float,
+                         ParserImplFloatLiteralTest,
+                         testing::ValuesIn(float_literal_test_cases));
+
+const float NegInf = MakeFloat(1, 255, 0);
+const float PosInf = MakeFloat(0, 255, 0);
+FloatLiteralTestCase hexfloat_literal_test_cases[] = {
+    // Regular numbers
+    {"0x0p+0", 0.f},
+    {"0x1p+0", 1.f},
+    {"0x1p+1", 2.f},
+    {"0x1.8p+1", 3.f},
+    {"0x1.99999ap-4", 0.1f},
+    {"0x1p-1", 0.5f},
+    {"0x1p-2", 0.25f},
+    {"0x1.8p-1", 0.75f},
+    {"-0x0p+0", -0.f},
+    {"-0x1p+0", -1.f},
+    {"-0x1p-1", -0.5f},
+    {"-0x1p-2", -0.25f},
+    {"-0x1.8p-1", -0.75f},
+
+    // Large numbers
+    {"0x1p+9", 512.f},
+    {"0x1p+10", 1024.f},
+    {"0x1.02p+10", 1024.f + 8.f},
+    {"-0x1p+9", -512.f},
+    {"-0x1p+10", -1024.f},
+    {"-0x1.02p+10", -1024.f - 8.f},
+
+    // Small numbers
+    {"0x1p-9", 1.0f / 512.f},
+    {"0x1p-10", 1.0f / 1024.f},
+    {"0x1.02p-3", 1.0f / 1024.f + 1.0f / 8.f},
+    {"-0x1p-9", 1.0f / -512.f},
+    {"-0x1p-10", 1.0f / -1024.f},
+    {"-0x1.02p-3", 1.0f / -1024.f - 1.0f / 8.f},
+
+    // Near lowest non-denorm
+    {"0x1p-124", std::ldexp(1.f * 8.f, -127)},
+    {"0x1p-125", std::ldexp(1.f * 4.f, -127)},
+    {"-0x1p-124", -std::ldexp(1.f * 8.f, -127)},
+    {"-0x1p-125", -std::ldexp(1.f * 4.f, -127)},
+
+    // Lowest non-denorm
+    {"0x1p-126", std::ldexp(1.f * 2.f, -127)},
+    {"-0x1p-126", -std::ldexp(1.f * 2.f, -127)},
+
+    // Denormalized values
+    {"0x1p-127", std::ldexp(1.f, -127)},
+    {"0x1p-128", std::ldexp(1.f / 2.f, -127)},
+    {"0x1p-129", std::ldexp(1.f / 4.f, -127)},
+    {"0x1p-130", std::ldexp(1.f / 8.f, -127)},
+    {"-0x1p-127", -std::ldexp(1.f, -127)},
+    {"-0x1p-128", -std::ldexp(1.f / 2.f, -127)},
+    {"-0x1p-129", -std::ldexp(1.f / 4.f, -127)},
+    {"-0x1p-130", -std::ldexp(1.f / 8.f, -127)},
+
+    {"0x1.8p-127", std::ldexp(1.f, -127) + (std::ldexp(1.f, -127) / 2.f)},
+    {"0x1.8p-128", std::ldexp(1.f, -127) / 2.f + (std::ldexp(1.f, -127) / 4.f)},
+
+    {"0x1p-149", MakeFloat(0, 0, 1)},                 // +SmallestDenormal
+    {"0x1p-148", MakeFloat(0, 0, 2)},                 // +BiggerDenormal
+    {"0x1.fffffcp-127", MakeFloat(0, 0, 0x7fffff)},   // +LargestDenormal
+    {"-0x1p-149", MakeFloat(1, 0, 1)},                // -SmallestDenormal
+    {"-0x1p-148", MakeFloat(1, 0, 2)},                // -BiggerDenormal
+    {"-0x1.fffffcp-127", MakeFloat(1, 0, 0x7fffff)},  // -LargestDenormal
+
+    {"0x1.2bfaf8p-127", MakeFloat(0, 0, 0xcafebe)},   // +Subnormal
+    {"-0x1.2bfaf8p-127", MakeFloat(1, 0, 0xcafebe)},  // -Subnormal
+    {"0x1.55554p-130", MakeFloat(0, 0, 0xaaaaa)},     // +Subnormal
+    {"-0x1.55554p-130", MakeFloat(1, 0, 0xaaaaa)},    // -Subnormal
+
+    // Nan -> Infinity
+    {"0x1.8p+128", PosInf},
+    {"0x1.0002p+128", PosInf},
+    {"0x1.0018p+128", PosInf},
+    {"0x1.01ep+128", PosInf},
+    {"0x1.fffffep+128", PosInf},
+    {"-0x1.8p+128", NegInf},
+    {"-0x1.0002p+128", NegInf},
+    {"-0x1.0018p+128", NegInf},
+    {"-0x1.01ep+128", NegInf},
+    {"-0x1.fffffep+128", NegInf},
+
+    // Infinity
+    {"0x1p+128", PosInf},
+    {"-0x1p+128", NegInf},
+    {"0x32p+127", PosInf},
+    {"0x32p+500", PosInf},
+    {"-0x32p+127", NegInf},
+    {"-0x32p+500", NegInf},
+
+    // Overflow -> Infinity
+    {"0x1p+129", PosInf},
+    {"0x1.1p+128", PosInf},
+    {"-0x1p+129", NegInf},
+    {"-0x1.1p+128", NegInf},
+
+    // Underflow -> Zero
+    {"0x1p-500", 0.f},  // Exponent underflows
+    {"-0x1p-500", -0.f},
+    {"0x0.00000000001p-126", 0.f},  // Fraction causes underflow
+    {"-0x0.0000000001p-127", -0.f},
+    {"0x0.01p-142", 0.f},
+    {"-0x0.01p-142", -0.f},  // Fraction causes additional underflow
+
+    // Test parsing
+    {"0x0p0", 0.f},
+    {"0x0p-0", 0.f},
+    {"0x0p+000", 0.f},
+    {"0x00000000000000p+000000000000000", 0.f},
+    {"0x00000000000000p-000000000000000", 0.f},
+    {"0x00000000000001p+000000000000000", 1.f},
+    {"0x00000000000001p-000000000000000", 1.f},
+    {"0x0000000000000000000001.99999ap-000000000000000004", 0.1f},
+    {"0x2p+0", 2.f},
+    {"0xFFp+0", 255.f},
+    {"0x0.8p+0", 0.5f},
+    {"0x0.4p+0", 0.25f},
+    {"0x0.4p+1", 2 * 0.25f},
+    {"0x0.4p+2", 4 * 0.25f},
+    {"0x123Ep+1", 9340.f},
+    {"-0x123Ep+1", -9340.f},
+    {"0x1a2b3cP12", 7.024656e+09f},
+    {"-0x1a2b3cP12", -7.024656e+09f},
+};
+INSTANTIATE_TEST_SUITE_P(ParserImplFloatLiteralTest_HexFloat,
+                         ParserImplFloatLiteralTest,
+                         testing::ValuesIn(hexfloat_literal_test_cases));
+
+TEST_F(ParserImplTest, ConstLiteral_FloatHighest) {
+  const auto highest = std::numeric_limits<float>::max();
+  const auto expected_highest = 340282346638528859811704183484516925440.0f;
+  if (highest < expected_highest || highest > expected_highest) {
+    GTEST_SKIP() << "std::numeric_limits<float>::max() is not as expected for "
+                    "this target";
+  }
+  auto p = parser("340282346638528859811704183484516925440.0");
+  auto c = p->const_literal();
+  EXPECT_TRUE(c.matched);
+  EXPECT_FALSE(c.errored);
+  EXPECT_FALSE(p->has_error()) << p->error();
+  ASSERT_NE(c.value, nullptr);
+  ASSERT_TRUE(c->Is<ast::FloatLiteral>());
+  EXPECT_FLOAT_EQ(c->As<ast::FloatLiteral>()->value(),
+                  std::numeric_limits<float>::max());
+  EXPECT_EQ(c->source().range, (Source::Range{{1u, 1u}, {1u, 42u}}));
+}
+
+TEST_F(ParserImplTest, ConstLiteral_FloatLowest) {
+  // Some compilers complain if you test floating point numbers for equality.
+  // So say it via two inequalities.
+  const auto lowest = std::numeric_limits<float>::lowest();
+  const auto expected_lowest = -340282346638528859811704183484516925440.0f;
+  if (lowest < expected_lowest || lowest > expected_lowest) {
+    GTEST_SKIP()
+        << "std::numeric_limits<float>::lowest() is not as expected for "
+           "this target";
+  }
+
+  auto p = parser("-340282346638528859811704183484516925440.0");
+  auto c = p->const_literal();
+  EXPECT_TRUE(c.matched);
+  EXPECT_FALSE(c.errored);
+  EXPECT_FALSE(p->has_error()) << p->error();
+  ASSERT_NE(c.value, nullptr);
+  ASSERT_TRUE(c->Is<ast::FloatLiteral>());
+  EXPECT_FLOAT_EQ(c->As<ast::FloatLiteral>()->value(),
+                  std::numeric_limits<float>::lowest());
+  EXPECT_EQ(c->source().range, (Source::Range{{1u, 1u}, {1u, 43u}}));
+}
+
 TEST_F(ParserImplTest, ConstLiteral_True) {
   auto p = parser("true");
   auto c = p->const_literal();
   EXPECT_TRUE(c.matched);
   EXPECT_FALSE(c.errored);
-  EXPECT_FALSE(p->has_error());
+  EXPECT_FALSE(p->has_error()) << p->error();
   ASSERT_NE(c.value, nullptr);
   ASSERT_TRUE(c->Is<ast::BoolLiteral>());
   EXPECT_TRUE(c->As<ast::BoolLiteral>()->IsTrue());
@@ -80,7 +311,7 @@
   auto c = p->const_literal();
   EXPECT_TRUE(c.matched);
   EXPECT_FALSE(c.errored);
-  EXPECT_FALSE(p->has_error());
+  EXPECT_FALSE(p->has_error()) << p->error();
   ASSERT_NE(c.value, nullptr);
   ASSERT_TRUE(c->Is<ast::BoolLiteral>());
   EXPECT_TRUE(c->As<ast::BoolLiteral>()->IsFalse());
@@ -92,7 +323,7 @@
   auto c = p->const_literal();
   EXPECT_FALSE(c.matched);
   EXPECT_FALSE(c.errored);
-  EXPECT_FALSE(p->has_error());
+  EXPECT_FALSE(p->has_error()) << p->error();
   ASSERT_EQ(c.value, nullptr);
 }