Import Tint changes from Dawn
Changes:
- e6c03a3799922151e821c57a69c78b5e422ac2a5 tint/reader/wgsl: Lex abstract floats by Ben Clayton <bclayton@google.com>
- a644c3d835a4011d8d05454d061f391603df7bc1 tint/reader/wgsl: Use C++17 hex floats for tests by Ben Clayton <bclayton@google.com>
- 30f01c17908bc5d18371f122b2e594dc30f40b98 tint/reader/wgsl: Lex abstract hex floats by Ben Clayton <bclayton@google.com>
- 09373989ecf3bb551565f2caf7016cb7ed0b7957 tint: Clamp constants to type's limits when number is unr... by Ben Clayton <bclayton@google.com>
- ce6adf4c678102197e5bd8b3bfa0a34b2df48b48 tint: Implement DP4a on SPIR-V writer by Jiawei Shao <jiawei.shao@intel.com>
- 8ae9e94344d38ca38723b9b56e2b4c047faaff48 tint/reader/wgsl: Restructure Lexer::try_hex_float() cons... by Ben Clayton <bclayton@google.com>
- 3ad927cc73dac935f5435f8b1f2855f0ea289b31 tint/writer: Check for inf / nan after casting to f32. by Ben Clayton <bclayton@google.com>
- e34e059804709726e9cbd35547c3ff857924af33 tint/resolver: Ensure materialized values are representable by Ben Clayton <bclayton@google.com>
- a8d52280494492a00b7ca6312ca9f69eef3fdd6a tint/resolver: Add `DataType<T>::ElementType` typedef by Ben Clayton <bclayton@google.com>
- 6ae7c0601760acc3309580d5ca5d24e3799a846e tint/resolver: Change DataType<T>::Expr() value type to d... by Ben Clayton <bclayton@google.com>
- 9707e6bb38c249e49e372b46bd8492d32e20315d tint: Rework sem::Constant to be variant-of-vector by Ben Clayton <bclayton@google.com>
- ef702af6c8b8d12ec5695e5600f0ef4844dd73d3 tint/reader/wgsl: Use CheckedConvert() for lexing by Ben Clayton <bclayton@google.com>
- c2eccfc887def447e2f1833408674095ad7d0443 tint: Add more helpers to tint::Number by Ben Clayton <bclayton@google.com>
- 3c83be8a5b34ef9c506e151b8736281f5df07877 tint: Add utils::Result by Ben Clayton <bclayton@google.com>
- 3bb360f0bb91695d1fe85624a8ef9d70e857c632 tint: Add utils::TransformN() by Ben Clayton <bclayton@google.com>
GitOrigin-RevId: e6c03a3799922151e821c57a69c78b5e422ac2a5
Change-Id: I91304c07aa556f5f1d84a1003c078204ad76f5df
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/91603
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index 057433a..0805d5d 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -360,6 +360,7 @@
"inspector/resource_binding.h",
"inspector/scalar.cc",
"inspector/scalar.h",
+ "number.cc",
"number.h",
"program.cc",
"program.h",
@@ -528,6 +529,8 @@
"transform/zero_init_workgroup_memory.h",
"utils/bitcast.h",
"utils/block_allocator.h",
+ "utils/compiler_macros.h",
+ "utils/concat.h",
"utils/crc32.h",
"utils/debugger.cc",
"utils/debugger.h",
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index 1eef3a1..2fd7d0d 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -239,6 +239,7 @@
inspector/resource_binding.h
inspector/scalar.cc
inspector/scalar.h
+ number.cc
number.h
program_builder.cc
program_builder.h
@@ -455,6 +456,8 @@
transform/zero_init_workgroup_memory.h
utils/bitcast.h
utils/block_allocator.h
+ utils/compiler_macros.h
+ utils/concat.h
utils/crc32.h
utils/enum_set.h
utils/hash.h
@@ -748,21 +751,23 @@
diagnostic/diagnostic_test.cc
diagnostic/formatter_test.cc
diagnostic/printer_test.cc
+ number_test.cc
+ program_builder_test.cc
program_test.cc
resolver/array_accessor_test.cc
resolver/assignment_validation_test.cc
resolver/atomics_test.cc
resolver/atomics_validation_test.cc
+ resolver/attribute_validation_test.cc
resolver/bitcast_validation_test.cc
- resolver/builtins_validation_test.cc
resolver/builtin_test.cc
resolver/builtin_validation_test.cc
+ resolver/builtins_validation_test.cc
resolver/call_test.cc
resolver/call_validation_test.cc
resolver/compound_assignment_validation_test.cc
resolver/compound_statement_test.cc
resolver/control_block_validation_test.cc
- resolver/attribute_validation_test.cc
resolver/dependency_graph_test.cc
resolver/entry_point_validation_test.cc
resolver/function_validation_test.cc
@@ -798,6 +803,7 @@
sem/atomic.cc
sem/bool_test.cc
sem/builtin_test.cc
+ sem/constant_test.cc
sem/depth_multisampled_texture_test.cc
sem/depth_texture_test.cc
sem/expression_test.cc
@@ -815,8 +821,8 @@
sem/sem_struct_test.cc
sem/storage_texture_test.cc
sem/texture_test.cc
- sem/type_test.cc
sem/type_manager_test.cc
+ sem/type_test.cc
sem/u32_test.cc
sem/vector_test.cc
source_test.cc
@@ -836,6 +842,7 @@
utils/io/tmpfile_test.cc
utils/map_test.cc
utils/math_test.cc
+ utils/result_test.cc
utils/reverse_test.cc
utils/scoped_assignment_test.cc
utils/string_test.cc
diff --git a/src/tint/number.cc b/src/tint/number.cc
new file mode 100644
index 0000000..33b2af5
--- /dev/null
+++ b/src/tint/number.cc
@@ -0,0 +1,59 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/number.h"
+
+#include <algorithm>
+#include <cstring>
+#include <ostream>
+
+namespace tint {
+
+std::ostream& operator<<(std::ostream& out, ConversionFailure failure) {
+ switch (failure) {
+ case ConversionFailure::kExceedsPositiveLimit:
+ return out << "value exceeds positive limit for type";
+ case ConversionFailure::kExceedsNegativeLimit:
+ return out << "value exceeds negative limit for type";
+ case ConversionFailure::kTooSmall:
+ return out << "value is too small for type";
+ }
+ return out << "<unknown>";
+}
+
+f16::type f16::Quantize(f16::type value) {
+ if (value > kHighest) {
+ return std::numeric_limits<f16::type>::infinity();
+ }
+ if (value < kLowest) {
+ return -std::numeric_limits<f16::type>::infinity();
+ }
+ // Below value must be within the finite range of a f16.
+ uint32_t u32;
+ memcpy(&u32, &value, 4);
+ if ((u32 & 0x7fffffffu) == 0) { // ~sign
+ return value; // +/- zero
+ }
+ if ((u32 & 0x7f800000) == 0x7f800000) { // exponent all 1's
+ return value; // inf or nan
+ }
+ // f32 bits : 1 sign, 8 exponent, 23 mantissa
+ // f16 bits : 1 sign, 5 exponent, 10 mantissa
+ // Mask the value to preserve the sign, exponent and most-significant 10 mantissa bits.
+ u32 = u32 & 0xffffe000u;
+ memcpy(&value, &u32, 4);
+ return value;
+}
+
+} // namespace tint
diff --git a/src/tint/number.h b/src/tint/number.h
index c3be236..4154462 100644
--- a/src/tint/number.h
+++ b/src/tint/number.h
@@ -17,18 +17,72 @@
#include <stdint.h>
#include <functional>
+#include <limits>
+#include <ostream>
+
+#include "src/tint/utils/result.h"
+
+// Forward declaration
+namespace tint {
+/// Number wraps a integer or floating point number, enforcing explicit casting.
+template <typename T>
+struct Number;
+} // namespace tint
namespace tint::detail {
/// An empty structure used as a unique template type for Number when
/// specializing for the f16 type.
struct NumberKindF16 {};
+
+/// Helper for obtaining the underlying type for a Number.
+template <typename T>
+struct NumberUnwrapper {
+ /// When T is not a Number, then type defined to be T.
+ using type = T;
+};
+
+/// NumberUnwrapper specialization for Number<T>.
+template <typename T>
+struct NumberUnwrapper<Number<T>> {
+ /// The Number's underlying type.
+ using type = typename Number<T>::type;
+};
+
} // namespace tint::detail
namespace tint {
+/// Evaluates to true iff T is a floating-point type or is NumberKindF16.
+template <typename T>
+constexpr bool IsFloatingPoint =
+ std::is_floating_point_v<T> || std::is_same_v<T, detail::NumberKindF16>;
+
+/// Evaluates to true iff T is an integer type.
+template <typename T>
+constexpr bool IsInteger = std::is_integral_v<T>;
+
+/// Evaluates to true iff T is an integer type, floating-point type or is NumberKindF16.
+template <typename T>
+constexpr bool IsNumeric = IsInteger<T> || IsFloatingPoint<T>;
+
/// Number wraps a integer or floating point number, enforcing explicit casting.
template <typename T>
struct Number {
+ static_assert(IsNumeric<T>, "Number<T> constructed with non-numeric type");
+
+ /// type is the underlying type of the Number
+ using type = T;
+
+ /// Highest finite representable value of this type.
+ static constexpr type kHighest = std::numeric_limits<type>::max();
+
+ /// Lowest finite representable value of this type.
+ static constexpr type kLowest = std::numeric_limits<type>::lowest();
+
+ /// Smallest positive normal value of this type.
+ static constexpr type kSmallest =
+ std::is_integral_v<type> ? 0 : std::numeric_limits<type>::min();
+
/// Constructor. The value is zero-initialized.
Number() = default;
@@ -59,41 +113,139 @@
}
/// The number value
- T value = {};
+ type value = {};
};
+/// Resolves to the underlying type for a Number.
+template <typename T>
+using UnwrapNumber = typename detail::NumberUnwrapper<T>::type;
+
+/// Writes the number to the ostream.
+/// @param out the std::ostream to write to
+/// @param num the Number
+/// @return the std::ostream so calls can be chained
+template <typename T>
+inline std::ostream& operator<<(std::ostream& out, Number<T> num) {
+ return out << num.value;
+}
+
+/// Equality operator.
+/// @param a the LHS number
+/// @param b the RHS number
+/// @returns true if the numbers `a` and `b` are exactly equal.
template <typename A, typename B>
bool operator==(Number<A> a, Number<B> b) {
using T = decltype(a.value + b.value);
- return std::equal_to<T>()(a.value, b.value);
+ return std::equal_to<T>()(static_cast<T>(a.value), static_cast<T>(b.value));
}
+/// Inequality operator.
+/// @param a the LHS number
+/// @param b the RHS number
+/// @returns true if the numbers `a` and `b` are exactly unequal.
template <typename A, typename B>
-bool operator==(Number<A> a, B b) {
+bool operator!=(Number<A> a, Number<B> b) {
+ return !(a == b);
+}
+
+/// Equality operator.
+/// @param a the LHS number
+/// @param b the RHS number
+/// @returns true if the numbers `a` and `b` are exactly equal.
+template <typename A, typename B>
+std::enable_if_t<IsNumeric<B>, bool> operator==(Number<A> a, B b) {
return a == Number<B>(b);
}
+/// Inequality operator.
+/// @param a the LHS number
+/// @param b the RHS number
+/// @returns true if the numbers `a` and `b` are exactly unequal.
template <typename A, typename B>
-bool operator==(A a, Number<B> b) {
+std::enable_if_t<IsNumeric<B>, bool> operator!=(Number<A> a, B b) {
+ return !(a == b);
+}
+
+/// Equality operator.
+/// @param a the LHS number
+/// @param b the RHS number
+/// @returns true if the numbers `a` and `b` are exactly equal.
+template <typename A, typename B>
+std::enable_if_t<IsNumeric<A>, bool> operator==(A a, Number<B> b) {
return Number<A>(a) == b;
}
+/// Inequality operator.
+/// @param a the LHS number
+/// @param b the RHS number
+/// @returns true if the numbers `a` and `b` are exactly unequal.
+template <typename A, typename B>
+std::enable_if_t<IsNumeric<A>, bool> operator!=(A a, Number<B> b) {
+ return !(a == b);
+}
+
+/// Enumerator of failure reasons when converting from one number to another.
+enum class ConversionFailure {
+ kExceedsPositiveLimit, // The value was too big (+'ve) to fit in the target type
+ kExceedsNegativeLimit, // The value was too big (-'ve) to fit in the target type
+ kTooSmall, // The value was too small to fit in the target type
+};
+
+/// Writes the conversion failure message to the ostream.
+/// @param out the std::ostream to write to
+/// @param failure the ConversionFailure
+/// @return the std::ostream so calls can be chained
+std::ostream& operator<<(std::ostream& out, ConversionFailure failure);
+
+/// Converts a number from one type to another, checking that the value fits in the target type.
+/// @returns the resulting value of the conversion, or a failure reason.
+template <typename TO, typename FROM>
+utils::Result<TO, ConversionFailure> CheckedConvert(Number<FROM> num) {
+ using T = decltype(UnwrapNumber<TO>() + num.value);
+ const auto value = static_cast<T>(num.value);
+ if (value > static_cast<T>(TO::kHighest)) {
+ return ConversionFailure::kExceedsPositiveLimit;
+ }
+ if (value < static_cast<T>(TO::kLowest)) {
+ return ConversionFailure::kExceedsNegativeLimit;
+ }
+ if constexpr (IsFloatingPoint<UnwrapNumber<TO>>) {
+ if ((value < T(0) && value > static_cast<T>(-TO::kSmallest)) ||
+ (value > T(0) && value < static_cast<T>(TO::kSmallest))) {
+ return ConversionFailure::kTooSmall;
+ }
+ }
+ return TO(value); // Success
+}
+
/// The partial specification of Number for f16 type, storing the f16 value as float,
/// and enforcing proper explicit casting.
template <>
struct Number<detail::NumberKindF16> {
+ /// C++ does not have a native float16 type, so we use a 32-bit float instead.
+ using type = float;
+
+ /// Highest finite representable value of this type.
+ static constexpr type kHighest = 65504.0f; // 2¹⁵ × (1 + 1023/1024)
+
+ /// Lowest finite representable value of this type.
+ static constexpr type kLowest = -65504.0f;
+
+ /// Smallest positive normal value of this type.
+ static constexpr type kSmallest = 0.00006103515625f; // 2⁻¹⁴
+
/// Constructor. The value is zero-initialized.
Number() = default;
/// Constructor.
/// @param v the value to initialize this Number to
template <typename U>
- explicit Number(U v) : value(static_cast<float>(v)) {}
+ explicit Number(U v) : value(Quantize(static_cast<type>(v))) {}
/// Constructor.
/// @param v the value to initialize this Number to
template <typename U>
- explicit Number(Number<U> v) : value(static_cast<float>(v.value)) {}
+ explicit Number(Number<U> v) : value(Quantize(static_cast<type>(v.value))) {}
/// Conversion operator
/// @returns the value as the internal representation type of F16
@@ -106,13 +258,20 @@
/// Assignment operator with parameter as native floating point type
/// @param v the new value
/// @returns this Number so calls can be chained
- Number& operator=(float v) {
- value = v;
+ Number& operator=(type v) {
+ value = Quantize(v);
return *this;
}
+ /// @param value the input float32 value
+ /// @returns the float32 value quantized to the smaller float16 value, through truncation of the
+ /// mantissa bits (no rounding). If the float32 value is too large (positive or negative) to be
+ /// represented by a float16 value, then the returned value will be positive or negative
+ /// infinity.
+ static type Quantize(type value);
+
/// The number value, stored as float
- float value = {};
+ type value = {};
};
/// `AInt` is a type alias to `Number<int64_t>`.
diff --git a/src/tint/number_test.cc b/src/tint/number_test.cc
new file mode 100644
index 0000000..615be0e
--- /dev/null
+++ b/src/tint/number_test.cc
@@ -0,0 +1,145 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <cmath>
+
+#include "src/tint/program_builder.h"
+
+#include "gtest/gtest.h"
+
+using namespace tint::number_suffixes; // NOLINT
+
+namespace tint {
+namespace {
+
+constexpr int64_t kHighestI32 = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
+constexpr int64_t kHighestU32 = static_cast<int64_t>(std::numeric_limits<uint32_t>::max());
+constexpr int64_t kLowestI32 = static_cast<int64_t>(std::numeric_limits<int32_t>::min());
+constexpr int64_t kLowestU32 = static_cast<int64_t>(std::numeric_limits<uint32_t>::min());
+
+// Highest float32 value. Calculated as:
+// (2^127)×(1+(0x7fffff÷0x800000))
+constexpr double kHighestF32 = 340282346638528859811704183484516925440.0;
+
+// Next ULP up from kHighestF32 for a float64. Calculated as:
+// (2^127)×(1+(0xfffffe0000001÷0x10000000000000))
+constexpr double kHighestF32NextULP = 340282346638528897590636046441678635008.0;
+
+// Smallest positive normal float32 value. Calculated as:
+// 2^-126
+constexpr double kSmallestF32 = 1.1754943508222875e-38;
+
+// Next ULP down from kSmallestF32 for a float64. Calculated as:
+// (2^-127)×(1+(0xfffffffffffff÷0x10000000000000))
+constexpr double kSmallestF32PrevULP = 1.1754943508222874e-38;
+
+// Highest float16 value. Calculated as:
+// (2^15)×(1+(0x3ff÷0x400))
+constexpr double kHighestF16 = 65504.0;
+
+// Next ULP up from kHighestF16 for a float64. Calculated as:
+// (2^15)×(1+(0xffc0000000001÷0x10000000000000))
+constexpr double kHighestF16NextULP = 65504.00000000001;
+
+// Smallest positive normal float16 value. Calculated as:
+// 2^-14
+constexpr double kSmallestF16 = 0.00006103515625;
+
+// Next ULP down from kSmallestF16 for a float64. Calculated as:
+// (2^-15)×(1+(0xfffffffffffff÷0x10000000000000))
+constexpr double kSmallestF16PrevULP = 0.00006103515624999999;
+
+constexpr double kLowestF32 = -kHighestF32;
+constexpr double kLowestF32NextULP = -kHighestF32NextULP;
+constexpr double kLowestF16 = -kHighestF16;
+constexpr double kLowestF16NextULP = -kHighestF16NextULP;
+
+TEST(NumberTest, CheckedConvertIdentity) {
+ EXPECT_EQ(CheckedConvert<AInt>(0_a), 0_a);
+ EXPECT_EQ(CheckedConvert<AFloat>(0_a), 0.0_a);
+ EXPECT_EQ(CheckedConvert<i32>(0_i), 0_i);
+ EXPECT_EQ(CheckedConvert<u32>(0_u), 0_u);
+ EXPECT_EQ(CheckedConvert<f32>(0_f), 0_f);
+ EXPECT_EQ(CheckedConvert<f16>(0_h), 0_h);
+
+ EXPECT_EQ(CheckedConvert<AInt>(1_a), 1_a);
+ EXPECT_EQ(CheckedConvert<AFloat>(1_a), 1.0_a);
+ EXPECT_EQ(CheckedConvert<i32>(1_i), 1_i);
+ EXPECT_EQ(CheckedConvert<u32>(1_u), 1_u);
+ EXPECT_EQ(CheckedConvert<f32>(1_f), 1_f);
+ EXPECT_EQ(CheckedConvert<f16>(1_h), 1_h);
+}
+
+TEST(NumberTest, CheckedConvertLargestValue) {
+ EXPECT_EQ(CheckedConvert<i32>(AInt(kHighestI32)), i32(kHighestI32));
+ EXPECT_EQ(CheckedConvert<u32>(AInt(kHighestU32)), u32(kHighestU32));
+ EXPECT_EQ(CheckedConvert<f32>(AFloat(kHighestF32)), f32(kHighestF32));
+ EXPECT_EQ(CheckedConvert<f16>(AFloat(kHighestF16)), f16(kHighestF16));
+}
+
+TEST(NumberTest, CheckedConvertLowestValue) {
+ EXPECT_EQ(CheckedConvert<i32>(AInt(kLowestI32)), i32(kLowestI32));
+ EXPECT_EQ(CheckedConvert<u32>(AInt(kLowestU32)), u32(kLowestU32));
+ EXPECT_EQ(CheckedConvert<f32>(AFloat(kLowestF32)), f32(kLowestF32));
+ EXPECT_EQ(CheckedConvert<f16>(AFloat(kLowestF16)), f16(kLowestF16));
+}
+
+TEST(NumberTest, CheckedConvertSmallestValue) {
+ EXPECT_EQ(CheckedConvert<i32>(AInt(0)), i32(0));
+ EXPECT_EQ(CheckedConvert<u32>(AInt(0)), u32(0));
+ EXPECT_EQ(CheckedConvert<f32>(AFloat(kSmallestF32)), f32(kSmallestF32));
+ EXPECT_EQ(CheckedConvert<f16>(AFloat(kSmallestF16)), f16(kSmallestF16));
+}
+
+TEST(NumberTest, CheckedConvertExceedsPositiveLimit) {
+ EXPECT_EQ(CheckedConvert<i32>(AInt(kHighestI32 + 1)), ConversionFailure::kExceedsPositiveLimit);
+ EXPECT_EQ(CheckedConvert<u32>(AInt(kHighestU32 + 1)), ConversionFailure::kExceedsPositiveLimit);
+ EXPECT_EQ(CheckedConvert<f32>(AFloat(kHighestF32NextULP)),
+ ConversionFailure::kExceedsPositiveLimit);
+ EXPECT_EQ(CheckedConvert<f16>(AFloat(kHighestF16NextULP)),
+ ConversionFailure::kExceedsPositiveLimit);
+}
+
+TEST(NumberTest, CheckedConvertExceedsNegativeLimit) {
+ EXPECT_EQ(CheckedConvert<i32>(AInt(kLowestI32 - 1)), ConversionFailure::kExceedsNegativeLimit);
+ EXPECT_EQ(CheckedConvert<u32>(AInt(kLowestU32 - 1)), ConversionFailure::kExceedsNegativeLimit);
+ EXPECT_EQ(CheckedConvert<f32>(AFloat(kLowestF32NextULP)),
+ ConversionFailure::kExceedsNegativeLimit);
+ EXPECT_EQ(CheckedConvert<f16>(AFloat(kLowestF16NextULP)),
+ ConversionFailure::kExceedsNegativeLimit);
+}
+
+TEST(NumberTest, CheckedConvertTooSmall) {
+ EXPECT_EQ(CheckedConvert<f32>(AFloat(kSmallestF32PrevULP)), ConversionFailure::kTooSmall);
+ EXPECT_EQ(CheckedConvert<f16>(AFloat(kSmallestF16PrevULP)), ConversionFailure::kTooSmall);
+}
+
+TEST(NumberTest, QuantizeF16) {
+ constexpr float nan = std::numeric_limits<float>::quiet_NaN();
+ constexpr float inf = std::numeric_limits<float>::infinity();
+
+ EXPECT_EQ(f16(0.0), 0.0f);
+ EXPECT_EQ(f16(1.0), 1.0f);
+ EXPECT_EQ(f16(0.00006106496), 0.000061035156f);
+ EXPECT_EQ(f16(1.0004883), 1.0f);
+ EXPECT_EQ(f16(-8196), -8192.f);
+ EXPECT_EQ(f16(65504.003), inf);
+ EXPECT_EQ(f16(-65504.003), -inf);
+ EXPECT_EQ(f16(inf), inf);
+ EXPECT_EQ(f16(-inf), -inf);
+ EXPECT_TRUE(std::isnan(f16(nan)));
+}
+
+} // namespace
+} // namespace tint
diff --git a/src/tint/reader/wgsl/lexer.cc b/src/tint/reader/wgsl/lexer.cc
index 20b3364..bfec1d1 100644
--- a/src/tint/reader/wgsl/lexer.cc
+++ b/src/tint/reader/wgsl/lexer.cc
@@ -24,6 +24,7 @@
#include <utility>
#include "src/tint/debug.h"
+#include "src/tint/number.h"
#include "src/tint/text/unicode.h"
namespace tint::reader::wgsl {
@@ -80,42 +81,6 @@
return 0;
}
-/// LimitCheck is the enumerator result of check_limits().
-enum class LimitCheck {
- /// The value was within the limits of the data type.
- kWithinLimits,
- /// The value was too small to fit within the data type.
- kTooSmall,
- /// The value was too large to fit within the data type.
- kTooLarge,
-};
-
-/// Checks whether the value fits within the integer type `T`
-template <typename T>
-LimitCheck check_limits(int64_t value) {
- static_assert(std::is_integral_v<T>, "T must be an integer");
- if (value < static_cast<int64_t>(std::numeric_limits<T>::lowest())) {
- return LimitCheck::kTooSmall;
- }
- if (value > static_cast<int64_t>(std::numeric_limits<T>::max())) {
- return LimitCheck::kTooLarge;
- }
- return LimitCheck::kWithinLimits;
-}
-
-/// Checks whether the value fits within the floating point type `T`
-template <typename T>
-LimitCheck check_limits(double value) {
- static_assert(std::is_floating_point_v<T>, "T must be a floating point");
- if (value < static_cast<double>(std::numeric_limits<T>::lowest())) {
- return LimitCheck::kTooSmall;
- }
- if (value > static_cast<double>(std::numeric_limits<T>::max())) {
- return LimitCheck::kTooLarge;
- }
- return LimitCheck::kWithinLimits;
-}
-
} // namespace
Lexer::Lexer(const Source::File* file) : file_(file), location_{1, 1} {}
@@ -393,54 +358,38 @@
advance(end - start);
end_source(source);
- double value = strtod(&at(start), nullptr);
- const double magnitude = std::abs(value);
+ double value = std::strtod(&at(start), nullptr);
if (has_f_suffix) {
- // This errors out if a non-zero magnitude is too small to represent in a
- // float. It can't be represented faithfully in an f32.
- if (0.0 < magnitude && magnitude < static_cast<double>(std::numeric_limits<float>::min())) {
- return {Token::Type::kError, source, "magnitude too small to be represented as f32"};
- }
- switch (check_limits<float>(value)) {
- case LimitCheck::kTooSmall:
- return {Token::Type::kError, source, "value too small for f32"};
- case LimitCheck::kTooLarge:
- return {Token::Type::kError, source, "value too large for f32"};
- default:
- return {Token::Type::kFloatLiteral_F, source, value};
+ if (auto f = CheckedConvert<f32>(AFloat(value))) {
+ return {Token::Type::kFloatLiteral_F, source, static_cast<double>(f.Get())};
+ } else if (f.Failure() == ConversionFailure::kTooSmall) {
+ return {Token::Type::kFloatLiteral_F, source, 0.0};
+ } else {
+ return {Token::Type::kError, source, "value cannot be represented as 'f32'"};
}
}
- // TODO(crbug.com/tint/1504): Properly support abstract float:
- // Change `AbstractFloatType` to `double`, update errors to say 'abstract int'.
- using AbstractFloatType = float;
- if (0.0 < magnitude &&
- magnitude < static_cast<double>(std::numeric_limits<AbstractFloatType>::min())) {
- return {Token::Type::kError, source, "magnitude too small to be represented as f32"};
- }
- switch (check_limits<AbstractFloatType>(value)) {
- case LimitCheck::kTooSmall:
- return {Token::Type::kError, source, "value too small for f32"};
- case LimitCheck::kTooLarge:
- return {Token::Type::kError, source, "value too large for f32"};
- default:
- return {Token::Type::kFloatLiteral, source, value};
+ if (value == HUGE_VAL || -value == HUGE_VAL) {
+ return {Token::Type::kError, source, "value cannot be represented as 'abstract-float'"};
+ } else {
+ return {Token::Type::kFloatLiteral, source, value};
}
}
Token Lexer::try_hex_float() {
- constexpr uint32_t kTotalBits = 32;
- constexpr uint32_t kTotalMsb = kTotalBits - 1;
- constexpr uint32_t kMantissaBits = 23;
- constexpr uint32_t kMantissaMsb = kMantissaBits - 1;
- constexpr uint32_t kMantissaShiftRight = kTotalBits - kMantissaBits;
- constexpr int32_t kExponentBias = 127;
- constexpr int32_t kExponentMax = 255;
- constexpr uint32_t kExponentBits = 8;
- constexpr uint32_t kExponentMask = (1 << kExponentBits) - 1;
- constexpr uint32_t kExponentLeftShift = kMantissaBits;
- constexpr uint32_t kSignBit = 31;
+ constexpr uint64_t kExponentBits = 11;
+ constexpr uint64_t kMantissaBits = 52;
+ constexpr uint64_t kTotalBits = 1 + kExponentBits + kMantissaBits;
+ constexpr uint64_t kTotalMsb = kTotalBits - 1;
+ constexpr uint64_t kMantissaMsb = kMantissaBits - 1;
+ constexpr uint64_t kMantissaShiftRight = kTotalBits - kMantissaBits;
+ constexpr int64_t kExponentBias = 1023;
+ constexpr uint64_t kExponentMask = (1 << kExponentBits) - 1;
+ constexpr int64_t kExponentMax = kExponentMask; // Including NaN / inf
+ constexpr uint64_t kExponentLeftShift = kMantissaBits;
+ constexpr uint64_t kSignBit = kTotalBits - 1;
+ constexpr uint64_t kOne = 1;
auto start = pos();
auto end = pos();
@@ -452,7 +401,7 @@
// clang-format on
// -?
- int32_t sign_bit = 0;
+ int64_t sign_bit = 0;
if (matches(end, "-")) {
sign_bit = 1;
end++;
@@ -464,8 +413,8 @@
return {};
}
- uint32_t mantissa = 0;
- uint32_t exponent = 0;
+ uint64_t mantissa = 0;
+ uint64_t exponent = 0;
// TODO(dneto): Values in the normal range for the format do not explicitly
// store the most significant bit. The algorithm here works hard to eliminate
@@ -478,7 +427,7 @@
// `set_next_mantissa_bit_to` sets next `mantissa` bit starting from msb to
// lsb to value 1 if `set` is true, 0 otherwise. Returns true on success, i.e.
// when the bit can be accommodated in the available space.
- uint32_t mantissa_next_bit = kTotalMsb;
+ uint64_t mantissa_next_bit = kTotalMsb;
auto set_next_mantissa_bit_to = [&](bool set, bool integer_part) -> bool {
// If adding bits for the integer part, we can overflow whether we set the
// bit or not. For the fractional part, we can only overflow when setting
@@ -490,7 +439,7 @@
return false; // Overflowed mantissa
}
if (set) {
- mantissa |= (1 << mantissa_next_bit);
+ mantissa |= (kOne << mantissa_next_bit);
}
--mantissa_next_bit;
return true;
@@ -546,7 +495,7 @@
has_zero_integer = false;
}
- for (int32_t bit = 3; bit >= 0; --bit) {
+ for (int bit = 3; bit >= 0; --bit) {
auto v = 1 & (nibble >> bit);
// Skip leading 0s and the first 1
@@ -567,7 +516,7 @@
// [0-9a-fA-F]*
for (auto i = fractional_range.first; i < fractional_range.second; ++i) {
auto nibble = hex_value(at(i));
- for (int32_t bit = 3; bit >= 0; --bit) {
+ for (int bit = 3; bit >= 0; --bit) {
auto v = 1 & (nibble >> bit);
if (v == 1) {
@@ -595,8 +544,8 @@
// Parse the optional exponent.
// ((p|P)(\+|-)?[0-9]+)?
- uint32_t input_exponent = 0; // Defaults to 0 if not present
- int32_t exponent_sign = 1;
+ uint64_t input_exponent = 0; // Defaults to 0 if not present
+ int64_t exponent_sign = 1;
// If the 'p' part is present, the rest of the exponent must exist.
bool has_f_suffix = false;
if (has_exponent) {
@@ -611,7 +560,7 @@
// Parse exponent from input
// [0-9]+
- // Allow overflow (in uint32_t) when the floating point value magnitude is
+ // Allow overflow (in uint64_t) when the floating point value magnitude is
// zero.
bool has_exponent_digits = false;
while (end < length() && isdigit(at(end))) {
@@ -648,14 +597,14 @@
} else {
// Ensure input exponent is not too large; i.e. that it won't overflow when
// adding the exponent bias.
- const uint32_t kIntMax = static_cast<uint32_t>(std::numeric_limits<int32_t>::max());
- const uint32_t kMaxInputExponent = kIntMax - kExponentBias;
+ const uint64_t kIntMax = static_cast<uint64_t>(std::numeric_limits<int64_t>::max());
+ const uint64_t kMaxInputExponent = kIntMax - kExponentBias;
if (input_exponent > kMaxInputExponent) {
return {Token::Type::kError, source, "exponent is too large for hex float"};
}
// Compute exponent so far
- exponent += static_cast<uint32_t>(static_cast<int32_t>(input_exponent) * exponent_sign);
+ exponent += static_cast<uint64_t>(static_cast<int64_t>(input_exponent) * exponent_sign);
// Bias exponent if non-zero
// After this, if exponent is <= 0, our value is a denormal
@@ -674,7 +623,7 @@
// We can now safely work with exponent as a signed quantity, as there's no
// chance to overflow
- int32_t signed_exponent = static_cast<int32_t>(exponent);
+ int64_t signed_exponent = static_cast<int64_t>(exponent);
// Shift mantissa to occupy the low 23 bits
mantissa >>= kMantissaShiftRight;
@@ -685,7 +634,7 @@
// then shift the mantissa to make exponent zero.
if (signed_exponent <= 0) {
mantissa >>= 1;
- mantissa |= (1 << kMantissaMsb);
+ mantissa |= (kOne << kMantissaMsb);
}
while (signed_exponent < 0) {
@@ -699,24 +648,30 @@
}
}
- if (signed_exponent > kExponentMax) {
- // Overflow: set to infinity
- signed_exponent = kExponentMax;
- mantissa = 0;
- } else if (signed_exponent == kExponentMax && mantissa != 0) {
- // NaN: set to infinity
- mantissa = 0;
+ if (signed_exponent >= kExponentMax || (signed_exponent == kExponentMax && mantissa != 0)) {
+ std::string type = has_f_suffix ? "f32" : "abstract-float";
+ return {Token::Type::kError, source, "value cannot be represented as '" + type + "'"};
}
// Combine sign, mantissa, and exponent
- uint32_t result_u32 = sign_bit << kSignBit;
- result_u32 |= mantissa;
- result_u32 |= (static_cast<uint32_t>(signed_exponent) & kExponentMask) << kExponentLeftShift;
+ uint64_t result_u64 = sign_bit << kSignBit;
+ result_u64 |= mantissa;
+ result_u64 |= (static_cast<uint64_t>(signed_exponent) & kExponentMask) << kExponentLeftShift;
// Reinterpret as float and return
- float result_f32;
- std::memcpy(&result_f32, &result_u32, sizeof(result_f32));
- double result_f64 = static_cast<double>(result_f32);
+ double result_f64;
+ std::memcpy(&result_f64, &result_u64, 8);
+
+ if (has_f_suffix) {
+ // Quantize to f32
+ // TODO(crbug.com/tint/1564): If the hex-float value is not exactly representable then we
+ // should be erroring here.
+ result_f64 = static_cast<double>(static_cast<float>(result_f64));
+ if (std::isinf(result_f64)) {
+ return {Token::Type::kError, source, "value cannot be represented as 'f32'"};
+ }
+ }
+
return {has_f_suffix ? Token::Type::kFloatLiteral_F : Token::Type::kFloatLiteral, source,
result_f64};
}
@@ -725,44 +680,31 @@
int64_t res = strtoll(&at(start), nullptr, base);
if (matches(pos(), "u")) {
- switch (check_limits<uint32_t>(res)) {
- case LimitCheck::kTooSmall:
- return {Token::Type::kError, source, "unsigned literal cannot be negative"};
- case LimitCheck::kTooLarge:
- return {Token::Type::kError, source, "value too large for u32"};
- default:
- advance(1);
- end_source(source);
- return {Token::Type::kIntLiteral_U, source, res};
+ if (CheckedConvert<u32>(AInt(res))) {
+ advance(1);
+ end_source(source);
+ return {Token::Type::kIntLiteral_U, source, res};
}
+ return {Token::Type::kError, source, "value cannot be represented as 'u32'"};
}
if (matches(pos(), "i")) {
- switch (check_limits<int32_t>(res)) {
- case LimitCheck::kTooSmall:
- return {Token::Type::kError, source, "value too small for i32"};
- case LimitCheck::kTooLarge:
- return {Token::Type::kError, source, "value too large for i32"};
- default:
- break;
+ if (CheckedConvert<i32>(AInt(res))) {
+ advance(1);
+ end_source(source);
+ return {Token::Type::kIntLiteral_I, source, res};
}
- advance(1);
- end_source(source);
- return {Token::Type::kIntLiteral_I, source, res};
+ return {Token::Type::kError, source, "value cannot be represented as 'i32'"};
}
// TODO(crbug.com/tint/1504): Properly support abstract int:
// Change `AbstractIntType` to `int64_t`, update errors to say 'abstract int'.
- using AbstractIntType = int32_t;
- switch (check_limits<AbstractIntType>(res)) {
- case LimitCheck::kTooSmall:
- return {Token::Type::kError, source, "value too small for i32"};
- case LimitCheck::kTooLarge:
- return {Token::Type::kError, source, "value too large for i32"};
- default:
- end_source(source);
- return {Token::Type::kIntLiteral, source, res};
+ using AbstractIntType = i32;
+ if (CheckedConvert<AbstractIntType>(AInt(res))) {
+ end_source(source);
+ return {Token::Type::kIntLiteral, source, res};
}
+ return {Token::Type::kError, source, "value cannot be represented as 'i32'"};
}
Token Lexer::try_hex_integer() {
diff --git a/src/tint/reader/wgsl/lexer_test.cc b/src/tint/reader/wgsl/lexer_test.cc
index 0c3e162..801b62c 100644
--- a/src/tint/reader/wgsl/lexer_test.cc
+++ b/src/tint/reader/wgsl/lexer_test.cc
@@ -362,12 +362,12 @@
FloatData{"-5.", -5.},
FloatData{"-.7", -.7},
// Non-zero with decimal and 'f' suffix
- FloatData{"5.7f", 5.7},
- FloatData{"5.f", 5.},
- FloatData{".7f", .7},
- FloatData{"-5.7f", -5.7},
- FloatData{"-5.f", -5.},
- FloatData{"-.7f", -.7},
+ FloatData{"5.7f", static_cast<double>(5.7f)},
+ FloatData{"5.f", static_cast<double>(5.f)},
+ FloatData{".7f", static_cast<double>(.7f)},
+ FloatData{"-5.7f", static_cast<double>(-5.7f)},
+ FloatData{"-5.f", static_cast<double>(-5.f)},
+ FloatData{"-.7f", static_cast<double>(-.7f)},
// No decimal, with exponent
FloatData{"1e5", 1e5},
@@ -375,10 +375,10 @@
FloatData{"1e-5", 1e-5},
FloatData{"1E-5", 1e-5},
// No decimal, with exponent and 'f' suffix
- FloatData{"1e5f", 1e5},
- FloatData{"1E5f", 1e5},
- FloatData{"1e-5f", 1e-5},
- FloatData{"1E-5f", 1e-5},
+ FloatData{"1e5f", static_cast<double>(1e5f)},
+ FloatData{"1E5f", static_cast<double>(1e5f)},
+ FloatData{"1e-5f", static_cast<double>(1e-5f)},
+ FloatData{"1E-5f", static_cast<double>(1e-5f)},
// With decimal and exponents
FloatData{"0.2e+12", 0.2e12},
FloatData{"1.2e-5", 1.2e-5},
@@ -386,11 +386,15 @@
FloatData{"2.5e+0", 2.5},
FloatData{"2.5e-0", 2.5},
// With decimal and exponents and 'f' suffix
- FloatData{"0.2e+12f", 0.2e12},
- FloatData{"1.2e-5f", 1.2e-5},
- FloatData{"2.57e23f", 2.57e23},
- FloatData{"2.5e+0f", 2.5},
- FloatData{"2.5e-0f", 2.5}));
+ FloatData{"0.2e+12f", static_cast<double>(0.2e12f)},
+ FloatData{"1.2e-5f", static_cast<double>(1.2e-5f)},
+ FloatData{"2.57e23f", static_cast<double>(2.57e23f)},
+ FloatData{"2.5e+0f", static_cast<double>(2.5f)},
+ FloatData{"2.5e-0f", static_cast<double>(2.5f)},
+ // Quantization
+ FloatData{"3.141592653589793", 3.141592653589793}, // no quantization
+ FloatData{"3.141592653589793f", 3.1415927410125732} // f32 quantized
+ ));
using FloatTest_Invalid = testing::TestWithParam<const char*>;
TEST_P(FloatTest_Invalid, Handles) {
@@ -415,11 +419,11 @@
".e+",
".e-",
// Overflow
- "2.5e+256",
- "-2.5e+127",
+ "2.5e+256f",
+ "-2.5e+127f",
// Magnitude smaller than smallest positive f32.
- "2.5e-300",
- "-2.5e-300",
+ "2.5e-300f",
+ "-2.5e-300f",
// Decimal exponent must immediately
// follow the 'e'.
"2.5e 12",
@@ -680,7 +684,7 @@
auto t = l.next();
ASSERT_TRUE(t.Is(Token::Type::kError));
- EXPECT_EQ(t.to_str(), "value too large for i32");
+ EXPECT_EQ(t.to_str(), "value cannot be represented as 'i32'");
}
TEST_F(LexerTest, IntegerTest_HexSignedTooSmall) {
@@ -689,7 +693,7 @@
auto t = l.next();
ASSERT_TRUE(t.Is(Token::Type::kError));
- EXPECT_EQ(t.to_str(), "value too small for i32");
+ EXPECT_EQ(t.to_str(), "value cannot be represented as 'i32'");
}
TEST_F(LexerTest, IntegerTest_HexSignedTooManyDigits) {
diff --git a/src/tint/reader/wgsl/parser_impl_const_literal_test.cc b/src/tint/reader/wgsl/parser_impl_const_literal_test.cc
index cbddce3..8ce5178 100644
--- a/src/tint/reader/wgsl/parser_impl_const_literal_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_const_literal_test.cc
@@ -14,7 +14,6 @@
#include "src/tint/reader/wgsl/parser_impl_test_helper.h"
-#include <cmath>
#include <cstring>
#include "gmock/gmock.h"
@@ -22,25 +21,6 @@
namespace tint::reader::wgsl {
namespace {
-// Makes an IEEE 754 binary32 floating point number with
-// - 0 sign if sign is 0, 1 otherwise
-// - 'exponent_bits' is placed in the exponent space.
-// So, the exponent bias must already be included.
-float MakeFloat(uint32_t sign, uint32_t biased_exponent, uint32_t mantissa) {
- const uint32_t sign_bit = sign ? 0x80000000u : 0u;
- // The binary32 exponent is 8 bits, just below the sign.
- const uint32_t exponent_bits = (biased_exponent & 0xffu) << 23;
- // The mantissa is the bottom 23 bits.
- const uint32_t mantissa_bits = (mantissa & 0x7fffffu);
-
- uint32_t bits = sign_bit | exponent_bits | mantissa_bits;
- float result = 0.0f;
- static_assert(sizeof(result) == sizeof(bits),
- "expected float and uint32_t to be the same size");
- std::memcpy(&result, &bits, sizeof(bits));
- return result;
-}
-
// Makes an IEEE 754 binary64 floating point number with
// - 0 sign if sign is 0, 1 otherwise
// - 'exponent_bits' is placed in the exponent space.
@@ -133,38 +113,10 @@
auto c = p->const_literal();
EXPECT_FALSE(c.matched);
EXPECT_TRUE(c.errored);
- EXPECT_EQ(p->error(), "1:1: unsigned literal cannot be negative");
+ EXPECT_EQ(p->error(), "1:1: value cannot be represented as 'u32'");
ASSERT_EQ(c.value, nullptr);
}
-TEST_F(ParserImplTest, ConstLiteral_Float) {
- auto p = parser("234.e12");
- auto c = p->const_literal();
- EXPECT_TRUE(c.matched);
- EXPECT_FALSE(c.errored);
- EXPECT_FALSE(p->has_error()) << p->error();
- ASSERT_NE(c.value, nullptr);
- ASSERT_TRUE(c->Is<ast::FloatLiteralExpression>());
- EXPECT_DOUBLE_EQ(c->As<ast::FloatLiteralExpression>()->value, 234e12);
- EXPECT_EQ(c->As<ast::FloatLiteralExpression>()->suffix,
- ast::FloatLiteralExpression::Suffix::kNone);
- EXPECT_EQ(c->source.range, (Source::Range{{1u, 1u}, {1u, 8u}}));
-}
-
-TEST_F(ParserImplTest, ConstLiteral_FloatF) {
- auto p = parser("234.e12f");
- auto c = p->const_literal();
- EXPECT_TRUE(c.matched);
- EXPECT_FALSE(c.errored);
- EXPECT_FALSE(p->has_error()) << p->error();
- ASSERT_NE(c.value, nullptr);
- ASSERT_TRUE(c->Is<ast::FloatLiteralExpression>());
- EXPECT_DOUBLE_EQ(c->As<ast::FloatLiteralExpression>()->value, 234e12);
- EXPECT_EQ(c->As<ast::FloatLiteralExpression>()->suffix,
- ast::FloatLiteralExpression::Suffix::kF);
- EXPECT_EQ(c->source.range, (Source::Range{{1u, 1u}, {1u, 9u}}));
-}
-
TEST_F(ParserImplTest, ConstLiteral_InvalidFloat_IncompleteExponent) {
auto p = parser("1.0e+");
auto c = p->const_literal();
@@ -174,33 +126,6 @@
ASSERT_EQ(c.value, nullptr);
}
-TEST_F(ParserImplTest, ConstLiteral_InvalidFloat_TooSmallMagnitude) {
- auto p = parser("1e-256");
- auto c = p->const_literal();
- EXPECT_FALSE(c.matched);
- EXPECT_TRUE(c.errored);
- EXPECT_EQ(p->error(), "1:1: magnitude too small to be represented as f32");
- ASSERT_EQ(c.value, nullptr);
-}
-
-TEST_F(ParserImplTest, ConstLiteral_InvalidFloat_TooLargeNegative) {
- auto p = parser("-1.2e+256");
- auto c = p->const_literal();
- EXPECT_FALSE(c.matched);
- EXPECT_TRUE(c.errored);
- EXPECT_EQ(p->error(), "1:1: value too small for f32");
- ASSERT_EQ(c.value, nullptr);
-}
-
-TEST_F(ParserImplTest, ConstLiteral_InvalidFloat_TooLargePositive) {
- auto p = parser("1.2e+256");
- auto c = p->const_literal();
- EXPECT_FALSE(c.matched);
- EXPECT_TRUE(c.errored);
- EXPECT_EQ(p->error(), "1:1: value too large for f32");
- ASSERT_EQ(c.value, nullptr);
-}
-
struct FloatLiteralTestCase {
std::string input;
double expected;
@@ -224,8 +149,12 @@
EXPECT_FALSE(c.errored);
EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_NE(c.value, nullptr);
- ASSERT_TRUE(c->Is<ast::FloatLiteralExpression>());
- EXPECT_DOUBLE_EQ(c->As<ast::FloatLiteralExpression>()->value, params.expected);
+ auto* literal = c->As<ast::FloatLiteralExpression>();
+ ASSERT_NE(literal, nullptr);
+ EXPECT_DOUBLE_EQ(literal->value, params.expected)
+ << "\n"
+ << "got: " << std::hexfloat << literal->value << "\n"
+ << "expected: " << std::hexfloat << params.expected;
if (params.input.back() == 'f') {
EXPECT_EQ(c->As<ast::FloatLiteralExpression>()->suffix,
ast::FloatLiteralExpression::Suffix::kF);
@@ -233,133 +162,180 @@
EXPECT_EQ(c->As<ast::FloatLiteralExpression>()->suffix,
ast::FloatLiteralExpression::Suffix::kNone);
}
+ EXPECT_EQ(c->source.range, (Source::Range{{1u, 1u}, {1u, 1u + params.input.size()}}));
}
using FloatLiteralTestCaseList = std::vector<FloatLiteralTestCase>;
-FloatLiteralTestCaseList DecimalFloatCases() {
- return FloatLiteralTestCaseList{
- {"0.0", 0.0}, // Zero
- {"1.0", 1.0}, // One
- {"-1.0", -1.0}, // MinusOne
- {"1000000000.0", 1e9}, // Billion
- {"-0.0", std::copysign(0.0, -5.0)}, // NegativeZero
- {"0.0", MakeDouble(0, 0, 0)}, // Zero
- {"-0.0", MakeDouble(1, 0, 0)}, // NegativeZero
- {"1.0", MakeDouble(0, 1023, 0)}, // One
- {"-1.0", MakeDouble(1, 1023, 0)}, // NegativeOne
- };
-}
-
INSTANTIATE_TEST_SUITE_P(ParserImplFloatLiteralTest_Float,
ParserImplFloatLiteralTest,
- testing::ValuesIn(DecimalFloatCases()));
+ testing::ValuesIn(FloatLiteralTestCaseList{
+ {"0.0", 0.0}, // Zero
+ {"1.0", 1.0}, // One
+ {"-1.0", -1.0}, // MinusOne
+ {"1000000000.0", 1e9}, // Billion
+ {"-0.0", std::copysign(0.0, -5.0)}, // NegativeZero
+ {"0.0", MakeDouble(0, 0, 0)}, // Zero
+ {"-0.0", MakeDouble(1, 0, 0)}, // NegativeZero
+ {"1.0", MakeDouble(0, 1023, 0)}, // One
+ {"-1.0", MakeDouble(1, 1023, 0)}, // NegativeOne
+
+ {"234.e12", 234.e12},
+ {"234.e12f", static_cast<double>(234.e12f)},
+
+ // Tiny cases
+ {"1e-5000", 0.0},
+ {"-1e-5000", 0.0},
+ {"1e-5000f", 0.0},
+ {"-1e-5000f", 0.0},
+ {"1e-50f", 0.0},
+ {"-1e-50f", 0.0},
+
+ // Nearly overflow
+ {"1.e308", 1.e308},
+ {"-1.e308", -1.e308},
+ {"1.8e307", 1.8e307},
+ {"-1.8e307", -1.8e307},
+ {"1.798e307", 1.798e307},
+ {"-1.798e307", -1.798e307},
+ {"1.7977e307", 1.7977e307},
+ {"-1.7977e307", -1.7977e307},
+
+ // Nearly overflow
+ {"1e38f", static_cast<double>(1e38f)},
+ {"-1e38f", static_cast<double>(-1e38f)},
+ {"4.0e37f", static_cast<double>(4.0e37f)},
+ {"-4.0e37f", static_cast<double>(-4.0e37f)},
+ {"3.5e37f", static_cast<double>(3.5e37f)},
+ {"-3.5e37f", static_cast<double>(-3.5e37f)},
+ {"3.403e37f", static_cast<double>(3.403e37f)},
+ {"-3.403e37f", static_cast<double>(-3.403e37f)},
+ }));
const double NegInf = MakeDouble(1, 0x7FF, 0);
const double PosInf = MakeDouble(0, 0x7FF, 0);
FloatLiteralTestCaseList HexFloatCases() {
return FloatLiteralTestCaseList{
// Regular numbers
- {"0x0p+0", 0.0},
- {"0x1p+0", 1.0},
- {"0x1p+1", 2.0},
- {"0x1.8p+1", 3.0},
- {"0x1.99999ap-4", 0.10000000149011612},
- {"0x1p-1", 0.5},
- {"0x1p-2", 0.25},
- {"0x1.8p-1", 0.75},
- {"-0x0p+0", -0.0},
- {"-0x1p+0", -1.0},
- {"-0x1p-1", -0.5},
- {"-0x1p-2", -0.25},
- {"-0x1.8p-1", -0.75},
+ {"0x0p+0", 0x0p+0},
+ {"0x1p+0", 0x1p+0},
+ {"0x1p+1", 0x1p+1},
+ {"0x1.8p+1", 0x1.8p+1},
+ {"0x1.99999ap-4", 0x1.99999ap-4},
+ {"0x1p-1", 0x1p-1},
+ {"0x1p-2", 0x1p-2},
+ {"0x1.8p-1", 0x1.8p-1},
+ {"-0x0p+0", -0x0p+0},
+ {"-0x1p+0", -0x1p+0},
+ {"-0x1p-1", -0x1p-1},
+ {"-0x1p-2", -0x1p-2},
+ {"-0x1.8p-1", -0x1.8p-1},
// Large numbers
- {"0x1p+9", 512.0},
- {"0x1p+10", 1024.0},
- {"0x1.02p+10", 1024.0 + 8.0},
- {"-0x1p+9", -512.0},
- {"-0x1p+10", -1024.0},
- {"-0x1.02p+10", -1024.0 - 8.0},
+ {"0x1p+9", 0x1p+9},
+ {"0x1p+10", 0x1p+10},
+ {"0x1.02p+10", 0x1.02p+10},
+ {"-0x1p+9", -0x1p+9},
+ {"-0x1p+10", -0x1p+10},
+ {"-0x1.02p+10", -0x1.02p+10},
// Small numbers
- {"0x1p-9", 1.0 / 512.0},
- {"0x1p-10", 1.0 / 1024.0},
- {"0x1.02p-3", 1.0 / 1024.0 + 1.0 / 8.0},
- {"-0x1p-9", 1.0 / -512.0},
- {"-0x1p-10", 1.0 / -1024.0},
- {"-0x1.02p-3", 1.0 / -1024.0 - 1.0 / 8.0},
+ {"0x1p-9", 0x1p-9},
+ {"0x1p-10", 0x1p-10},
+ {"0x1.02p-3", 0x1.02p-3},
+ {"-0x1p-9", -0x1p-9},
+ {"-0x1p-10", -0x1p-10},
+ {"-0x1.02p-3", -0x1.02p-3},
// Near lowest non-denorm
- {"0x1p-124", std::ldexp(1.0 * 8.0, -127)},
- {"0x1p-125", std::ldexp(1.0 * 4.0, -127)},
- {"-0x1p-124", -std::ldexp(1.0 * 8.0, -127)},
- {"-0x1p-125", -std::ldexp(1.0 * 4.0, -127)},
+ {"0x1p-1020", 0x1p-1020},
+ {"0x1p-1021", 0x1p-1021},
+ {"-0x1p-1020", -0x1p-1020},
+ {"-0x1p-1021", -0x1p-1021},
+
+ {"0x1p-124f", 0x1p-124},
+ {"0x1p-125f", 0x1p-125},
+ {"-0x1p-124f", -0x1p-124},
+ {"-0x1p-125f", -0x1p-125},
// Lowest non-denorm
- {"0x1p-126", std::ldexp(1.0 * 2.0, -127)},
- {"-0x1p-126", -std::ldexp(1.0 * 2.0, -127)},
+ {"0x1p-1022", 0x1p-1022},
+ {"-0x1p-1022", -0x1p-1022},
+
+ {"0x1p-126f", 0x1p-126},
+ {"-0x1p-126f", -0x1p-126},
// Denormalized values
- {"0x1p-127", std::ldexp(1.0, -127)},
- {"0x1p-128", std::ldexp(1.0 / 2.0, -127)},
- {"0x1p-129", std::ldexp(1.0 / 4.0, -127)},
- {"0x1p-130", std::ldexp(1.0 / 8.0, -127)},
- {"-0x1p-127", -std::ldexp(1.0, -127)},
- {"-0x1p-128", -std::ldexp(1.0 / 2.0, -127)},
- {"-0x1p-129", -std::ldexp(1.0 / 4.0, -127)},
- {"-0x1p-130", -std::ldexp(1.0 / 8.0, -127)},
+ {"0x1p-1023", 0x1p-1023},
+ {"0x1p-1024", 0x1p-1024},
+ {"0x1p-1025", 0x1p-1025},
+ {"0x1p-1026", 0x1p-1026},
+ {"-0x1p-1023", -0x1p-1023},
+ {"-0x1p-1024", -0x1p-1024},
+ {"-0x1p-1025", -0x1p-1025},
+ {"-0x1p-1026", -0x1p-1026},
+ {"0x1.8p-1023", 0x1.8p-1023},
+ {"0x1.8p-1024", 0x1.8p-1024},
- {"0x1.8p-127", std::ldexp(1.0, -127) + (std::ldexp(1.0, -127) / 2.0)},
- {"0x1.8p-128", std::ldexp(1.0, -127) / 2.0 + (std::ldexp(1.0, -127) / 4.0)},
+ {"0x1p-127f", 0x1p-127},
+ {"0x1p-128f", 0x1p-128},
+ {"0x1p-129f", 0x1p-129},
+ {"0x1p-130f", 0x1p-130},
+ {"-0x1p-127f", -0x1p-127},
+ {"-0x1p-128f", -0x1p-128},
+ {"-0x1p-129f", -0x1p-129},
+ {"-0x1p-130f", -0x1p-130},
+ {"0x1.8p-127f", 0x1.8p-127},
+ {"0x1.8p-128f", 0x1.8p-128},
+
+ // F64 extremities
+ {"0x1p-1074", 0x1p-1074}, // +SmallestDenormal
+ {"0x1p-1073", 0x1p-1073}, // +BiggerDenormal
+ {"0x1.ffffffffffffp-1027", 0x1.ffffffffffffp-1027}, // +LargestDenormal
+ {"-0x1p-1074", -0x1p-1074}, // -SmallestDenormal
+ {"-0x1p-1073", -0x1p-1073}, // -BiggerDenormal
+ {"-0x1.ffffffffffffp-1027", -0x1.ffffffffffffp-1027}, // -LargestDenormal
+
+ {"0x0.cafebeeff000dp-1022", 0x0.cafebeeff000dp-1022}, // +Subnormal
+ {"-0x0.cafebeeff000dp-1022", -0x0.cafebeeff000dp-1022}, // -Subnormal
+ {"0x1.2bfaf8p-1052", 0x1.2bfaf8p-1052}, // +Subnormal
+ {"-0x1.2bfaf8p-1052", -0x1.2bfaf8p-1052}, // +Subnormal
+ {"0x1.55554p-1055", 0x1.55554p-1055}, // +Subnormal
+ {"-0x1.55554p-1055", -0x1.55554p-1055}, // -Subnormal
// F32 extremities
- {"0x1p-149", static_cast<double>(MakeFloat(0, 0, 1))}, // +SmallestDenormal
- {"0x1p-148", static_cast<double>(MakeFloat(0, 0, 2))}, // +BiggerDenormal
- {"0x1.fffffcp-127", static_cast<double>(MakeFloat(0, 0, 0x7fffff))}, // +LargestDenormal
- {"-0x1p-149", static_cast<double>(MakeFloat(1, 0, 1))}, // -SmallestDenormal
- {"-0x1p-148", static_cast<double>(MakeFloat(1, 0, 2))}, // -BiggerDenormal
- {"-0x1.fffffcp-127", static_cast<double>(MakeFloat(1, 0, 0x7fffff))}, // -LargestDenormal
+ {"0x1p-149", 0x1p-149}, // +SmallestDenormal
+ {"0x1p-148", 0x1p-148}, // +BiggerDenormal
+ {"0x1.fffffcp-127", 0x1.fffffcp-127}, // +LargestDenormal
+ {"-0x1p-149", -0x1p-149}, // -SmallestDenormal
+ {"-0x1p-148", -0x1p-148}, // -BiggerDenormal
+ {"-0x1.fffffcp-127", -0x1.fffffcp-127}, // -LargestDenormal
- {"0x1.2bfaf8p-127", static_cast<double>(MakeFloat(0, 0, 0xcafebe))}, // +Subnormal
- {"-0x1.2bfaf8p-127", static_cast<double>(MakeFloat(1, 0, 0xcafebe))}, // -Subnormal
- {"0x1.55554p-130", static_cast<double>(MakeFloat(0, 0, 0xaaaaa))}, // +Subnormal
- {"-0x1.55554p-130", static_cast<double>(MakeFloat(1, 0, 0xaaaaa))}, // -Subnormal
-
- // Nan -> Infinity
- {"0x1.8p+128", PosInf},
- {"0x1.0002p+128", PosInf},
- {"0x1.0018p+128", PosInf},
- {"0x1.01ep+128", PosInf},
- {"0x1.fffffep+128", PosInf},
- {"-0x1.8p+128", NegInf},
- {"-0x1.0002p+128", NegInf},
- {"-0x1.0018p+128", NegInf},
- {"-0x1.01ep+128", NegInf},
- {"-0x1.fffffep+128", NegInf},
-
- // Infinity
- {"0x1p+128", PosInf},
- {"-0x1p+128", NegInf},
- {"0x32p+127", PosInf},
- {"0x32p+500", PosInf},
- {"-0x32p+127", NegInf},
- {"-0x32p+500", NegInf},
-
- // Overflow -> Infinity
- {"0x1p+129", PosInf},
- {"0x1.1p+128", PosInf},
- {"-0x1p+129", NegInf},
- {"-0x1.1p+128", NegInf},
- {"0x1.0p2147483520", PosInf}, // INT_MAX - 127 (largest valid exponent)
+ {"0x0.cafebp-129", 0x0.cafebp-129}, // +Subnormal
+ {"-0x0.cafebp-129", -0x0.cafebp-129}, // -Subnormal
+ {"0x1.2bfaf8p-127", 0x1.2bfaf8p-127}, // +Subnormal
+ {"-0x1.2bfaf8p-127", -0x1.2bfaf8p-127}, // -Subnormal
+ {"0x1.55554p-130", 0x1.55554p-130}, // +Subnormal
+ {"-0x1.55554p-130", -0x1.55554p-130}, // -Subnormal
// Underflow -> Zero
- {"0x1p-500", 0.0}, // Exponent underflows
- {"-0x1p-500", -0.0},
- {"0x0.00000000001p-126", 0.0}, // Fraction causes underflow
- {"-0x0.0000000001p-127", -0.0},
- {"0x0.01p-142", 0.0},
- {"-0x0.01p-142", -0.0}, // Fraction causes additional underflow
- {"0x1.0p-2147483520", 0}, // -(INT_MAX - 127) (smallest valid exponent)
+ {"0x1p-1074", 0.0}, // Exponent underflows
+ {"-0x1p-1074", 0.0},
+ {"0x1p-5000", 0.0},
+ {"-0x1p-5000", 0.0},
+ {"0x0.00000000000000000000001p-1022", 0.0}, // Fraction causes underflow
+ {"-0x0.0000000000000000000001p-1023", -0.0},
+ {"0x0.01p-1073", -0.0},
+ {"-0x0.01p-1073", -0.0}, // Fraction causes additional underflow
+
+ {"0x1p-150f", 0.0}, // Exponent underflows
+ {"-0x1p-150f", 0.0},
+ {"0x1p-500f", 0.0},
+ {"-0x1p-500f", -0.0},
+ {"0x0.00000000001p-126f", 0.0}, // Fraction causes underflow
+ {"-0x0.0000000001p-127f", -0.0},
+ {"0x0.01p-142f", 0.0},
+ {"-0x0.01p-142f", -0.0}, // Fraction causes additional underflow
+ {"0x1.0p-9223372036854774784", 0}, // -(INT64_MAX - 1023) (smallest valid exponent)
// Zero with non-zero exponent -> Zero
{"0x0p+0", 0.0},
@@ -369,22 +345,22 @@
{"0x0p-9999999999", 0.0},
// Same, but with very large positive exponents that would cause overflow
// if the mantissa were non-zero.
- {"0x0p+4000000000", 0.0}, // 4 billion:
- {"0x0p+40000000000", 0.0}, // 40 billion
- {"-0x0p+40000000000", 0.0}, // As above 2, but negative mantissa
- {"-0x0p+400000000000", 0.0},
- {"0x0.00p+4000000000", 0.0}, // As above 4, but with fractional part
- {"0x0.00p+40000000000", 0.0},
- {"-0x0.00p+40000000000", 0.0},
- {"-0x0.00p+400000000000", 0.0},
- {"0x0p-4000000000", 0.0}, // As above 8, but with negative exponents
- {"0x0p-40000000000", 0.0},
- {"-0x0p-40000000000", 0.0},
- {"-0x0p-400000000000", 0.0},
- {"0x0.00p-4000000000", 0.0},
- {"0x0.00p-40000000000", 0.0},
- {"-0x0.00p-40000000000", 0.0},
- {"-0x0.00p-400000000000", 0.0},
+ {"0x0p+10000000000000000000", 0.0}, // 10 quintillion (10,000,000,000,000,000,000)
+ {"0x0p+100000000000000000000", 0.0}, // 100 quintillion (100,000,000,000,000,000,000)
+ {"-0x0p+100000000000000000000", 0.0}, // As above 2, but negative mantissa
+ {"-0x0p+1000000000000000000000", 0.0},
+ {"0x0.00p+10000000000000000000", 0.0}, // As above 4, but with fractional part
+ {"0x0.00p+100000000000000000000", 0.0},
+ {"-0x0.00p+100000000000000000000", 0.0},
+ {"-0x0.00p+1000000000000000000000", 0.0},
+ {"0x0p-10000000000000000000", 0.0}, // As above 8, but with negative exponents
+ {"0x0p-100000000000000000000", 0.0},
+ {"-0x0p-100000000000000000000", 0.0},
+ {"-0x0p-1000000000000000000000", 0.0},
+ {"0x0.00p-10000000000000000000", 0.0},
+ {"0x0.00p-100000000000000000000", 0.0},
+ {"-0x0.00p-100000000000000000000", 0.0},
+ {"-0x0.00p-1000000000000000000000", 0.0},
// Test parsing
{"0x0p0", 0.0},
@@ -462,65 +438,178 @@
ParserImplFloatLiteralTest,
testing::ValuesIn(UpperCase0X(HexFloatCases())));
-struct InvalidLiteralTestCase {
- const char* input;
- const char* error_msg;
-};
+// <error, source>
+using InvalidLiteralTestCase = std::tuple<const char*, const char*>;
+
class ParserImplInvalidLiteralTest : public ParserImplTestWithParam<InvalidLiteralTestCase> {};
TEST_P(ParserImplInvalidLiteralTest, Parse) {
- auto params = GetParam();
- SCOPED_TRACE(params.input);
- auto p = parser(params.input);
+ auto* error = std::get<0>(GetParam());
+ auto* source = std::get<1>(GetParam());
+ auto p = parser(source);
auto c = p->const_literal();
EXPECT_FALSE(c.matched);
EXPECT_TRUE(c.errored);
- EXPECT_EQ(p->error(), params.error_msg);
+ EXPECT_EQ(p->error(), std::string(error));
ASSERT_EQ(c.value, nullptr);
}
-InvalidLiteralTestCase invalid_hexfloat_mantissa_too_large_cases[] = {
- {"0x1.ffffffff8p0", "1:1: mantissa is too large for hex float"},
- {"0x1f.fffffff8p0", "1:1: mantissa is too large for hex float"},
- {"0x1ff.ffffff8p0", "1:1: mantissa is too large for hex float"},
- {"0x1fff.fffff8p0", "1:1: mantissa is too large for hex float"},
- {"0x1ffff.ffff8p0", "1:1: mantissa is too large for hex float"},
- {"0x1fffff.fff8p0", "1:1: mantissa is too large for hex float"},
- {"0x1ffffff.ff8p0", "1:1: mantissa is too large for hex float"},
- {"0x1fffffff.f8p0", "1:1: mantissa is too large for hex float"},
- {"0x1ffffffff.8p0", "1:1: mantissa is too large for hex float"},
- {"0x1ffffffff8.p0", "1:1: mantissa is too large for hex float"},
-};
-INSTANTIATE_TEST_SUITE_P(ParserImplInvalidLiteralTest_HexFloatMantissaTooLarge,
- ParserImplInvalidLiteralTest,
- testing::ValuesIn(invalid_hexfloat_mantissa_too_large_cases));
+INSTANTIATE_TEST_SUITE_P(
+ HexFloatMantissaTooLarge,
+ ParserImplInvalidLiteralTest,
+ testing::Combine(testing::Values("1:1: mantissa is too large for hex float"),
+ testing::ValuesIn(std::vector<const char*>{
+ "0x1.ffffffffffffffff8p0",
+ "0x1f.fffffffffffffff8p0",
+ "0x1ff.ffffffffffffff8p0",
+ "0x1fff.fffffffffffff8p0",
+ "0x1ffff.ffffffffffff8p0",
+ "0x1fffff.fffffffffff8p0",
+ "0x1ffffff.ffffffffff8p0",
+ "0x1fffffff.fffffffff8p0",
+ "0x1ffffffff.ffffffff8p0",
+ "0x1fffffffff.fffffff8p0",
+ "0x1ffffffffff.ffffff8p0",
+ "0x1fffffffffff.fffff8p0",
+ "0x1ffffffffffff.ffff8p0",
+ "0x1fffffffffffff.fff8p0",
+ "0x1ffffffffffffff.ff8p0",
+ "0x1ffffffffffffffff.8p0",
+ "0x1ffffffffffffffff8.p0",
+ })));
-InvalidLiteralTestCase invalid_hexfloat_exponent_too_large_cases[] = {
- {"0x1p+2147483521", "1:1: exponent is too large for hex float"},
- {"0x1p-2147483521", "1:1: exponent is too large for hex float"},
- {"0x1p+4294967296", "1:1: exponent is too large for hex float"},
- {"0x1p-4294967296", "1:1: exponent is too large for hex float"},
-};
-INSTANTIATE_TEST_SUITE_P(ParserImplInvalidLiteralTest_HexFloatExponentTooLarge,
- ParserImplInvalidLiteralTest,
- testing::ValuesIn(invalid_hexfloat_exponent_too_large_cases));
+INSTANTIATE_TEST_SUITE_P(
+ HexFloatExponentTooLarge,
+ ParserImplInvalidLiteralTest,
+ testing::Combine(testing::Values("1:1: exponent is too large for hex float"),
+ testing::ValuesIn(std::vector<const char*>{
+ "0x1p+9223372036854774785",
+ "0x1p-9223372036854774785",
+ "0x1p+18446744073709551616",
+ "0x1p-18446744073709551616",
+ })));
-InvalidLiteralTestCase invalid_hexfloat_exponent_missing_cases[] = {
- // Lower case p
- {"0x0p", "1:1: expected an exponent value for hex float"},
- {"0x0p+", "1:1: expected an exponent value for hex float"},
- {"0x0p-", "1:1: expected an exponent value for hex float"},
- {"0x1.0p", "1:1: expected an exponent value for hex float"},
- {"0x0.1p", "1:1: expected an exponent value for hex float"},
- // Upper case p
- {"0x0P", "1:1: expected an exponent value for hex float"},
- {"0x0P+", "1:1: expected an exponent value for hex float"},
- {"0x0P-", "1:1: expected an exponent value for hex float"},
- {"0x1.0P", "1:1: expected an exponent value for hex float"},
- {"0x0.1P", "1:1: expected an exponent value for hex float"},
-};
-INSTANTIATE_TEST_SUITE_P(ParserImplInvalidLiteralTest_HexFloatExponentMissing,
- ParserImplInvalidLiteralTest,
- testing::ValuesIn(invalid_hexfloat_exponent_missing_cases));
+INSTANTIATE_TEST_SUITE_P(
+ HexFloatMissingExponent,
+ ParserImplInvalidLiteralTest,
+ testing::Combine(testing::Values("1:1: expected an exponent value for hex float"),
+ testing::ValuesIn(std::vector<const char*>{
+ // Lower case p
+ "0x0p",
+ "0x0p+",
+ "0x0p-",
+ "0x1.0p",
+ "0x0.1p",
+ // Upper case p
+ "0x0P",
+ "0x0P+",
+ "0x0P-",
+ "0x1.0P",
+ "0x0.1P",
+ })));
+
+INSTANTIATE_TEST_SUITE_P(
+ HexNaNAFloat,
+ ParserImplInvalidLiteralTest,
+ testing::Combine(testing::Values("1:1: value cannot be represented as 'abstract-float'"),
+ testing::ValuesIn(std::vector<const char*>{
+ "0x1.8p+1024",
+ "0x1.0002p+1024",
+ "0x1.0018p+1024",
+ "0x1.01ep+1024",
+ "0x1.fffffep+1024",
+ "-0x1.8p+1024",
+ "-0x1.0002p+1024",
+ "-0x1.0018p+1024",
+ "-0x1.01ep+1024",
+ "-0x1.fffffep+1024",
+ })));
+
+INSTANTIATE_TEST_SUITE_P(
+ HexNaNF32,
+ ParserImplInvalidLiteralTest,
+ testing::Combine(testing::Values("1:1: value cannot be represented as 'f32'"),
+ testing::ValuesIn(std::vector<const char*>{
+ "0x1.8p+128f",
+ "0x1.0002p+128f",
+ "0x1.0018p+128f",
+ "0x1.01ep+128f",
+ "0x1.fffffep+128f",
+ "-0x1.8p+128f",
+ "-0x1.0002p+128f",
+ "-0x1.0018p+128f",
+ "-0x1.01ep+128f",
+ "-0x1.fffffep+128f",
+ })));
+
+INSTANTIATE_TEST_SUITE_P(
+ HexOverflowAFloat,
+ ParserImplInvalidLiteralTest,
+ testing::Combine(testing::Values("1:1: value cannot be represented as 'abstract-float'"),
+ testing::ValuesIn(std::vector<const char*>{
+ "0x1p+1024",
+ "-0x1p+1024",
+ "0x1.1p+1024",
+ "-0x1.1p+1024",
+ "0x1p+1025",
+ "-0x1p+1025",
+ "0x32p+1023",
+ "-0x32p+1023",
+ "0x32p+5000",
+ "-0x32p+5000",
+ "0x1.0p9223372036854774784",
+ "-0x1.0p9223372036854774784",
+ })));
+
+INSTANTIATE_TEST_SUITE_P(
+ HexOverflowF32,
+ ParserImplInvalidLiteralTest,
+ testing::Combine(testing::Values("1:1: value cannot be represented as 'f32'"),
+ testing::ValuesIn(std::vector<const char*>{
+ "0x1p+128f",
+ "-0x1p+128f",
+ "0x1.1p+128f",
+ "-0x1.1p+128f",
+ "0x1p+129f",
+ "-0x1p+129f",
+ "0x32p+127f",
+ "-0x32p+127f",
+ "0x32p+500f",
+ "-0x32p+500f",
+ })));
+
+INSTANTIATE_TEST_SUITE_P(
+ DecOverflowAFloat,
+ ParserImplInvalidLiteralTest,
+ testing::Combine(testing::Values("1:1: value cannot be represented as 'abstract-float'"),
+ testing::ValuesIn(std::vector<const char*>{
+ "1.e309",
+ "-1.e309",
+ "1.8e308",
+ "-1.8e308",
+ "1.798e308",
+ "-1.798e308",
+ "1.7977e308",
+ "-1.7977e308",
+ "1.2e+5000",
+ "-1.2e+5000",
+ })));
+
+INSTANTIATE_TEST_SUITE_P(
+ DecOverflowF32,
+ ParserImplInvalidLiteralTest,
+ testing::Combine(testing::Values("1:1: value cannot be represented as 'f32'"),
+ testing::ValuesIn(std::vector<const char*>{
+ "1e39f",
+ "-1e39f",
+ "4.0e38f",
+ "-4.0e38f",
+ "3.5e38f",
+ "-3.5e38f",
+ "3.403e38f",
+ "-3.403e38f",
+ "1.2e+256f",
+ "-1.2e+256f",
+ })));
TEST_F(ParserImplTest, ConstLiteral_FloatHighest) {
const auto highest = std::numeric_limits<float>::max();
diff --git a/src/tint/reader/wgsl/parser_impl_test.cc b/src/tint/reader/wgsl/parser_impl_test.cc
index 33b279c..66291c5 100644
--- a/src/tint/reader/wgsl/parser_impl_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_test.cc
@@ -82,7 +82,7 @@
TEST_F(ParserImplTest, HandlesBadToken_InMiddle) {
auto p = parser(R"(
fn main() {
- let f = 0x1p500000000000; // Exponent too big for hex float
+ let f = 0x1p10000000000000000000; // Exponent too big for hex float
return;
})");
@@ -96,7 +96,7 @@
fn main() {
return;
}
-0x1p5000000000000
+0x1p10000000000000000000
)");
ASSERT_FALSE(p->Parse());
diff --git a/src/tint/resolver/materialize_test.cc b/src/tint/resolver/materialize_test.cc
index e2e1a02..9f7be54 100644
--- a/src/tint/resolver/materialize_test.cc
+++ b/src/tint/resolver/materialize_test.cc
@@ -134,35 +134,65 @@
struct Data {
std::string target_type_name;
+ std::string target_element_type_name;
builder::ast_type_func_ptr target_ast_ty;
builder::sem_type_func_ptr target_sem_ty;
builder::ast_expr_func_ptr target_expr;
- std::string literal_type_name;
- builder::ast_expr_func_ptr literal_value;
+ std::string source_type_name;
+ builder::ast_expr_func_ptr source_builder;
std::variant<AInt, AFloat> materialized_value;
+ double literal_value;
};
-template <typename TARGET_TYPE, typename LITERAL_TYPE, typename MATERIALIZED_TYPE = AInt>
-Data Types(MATERIALIZED_TYPE materialized_value = 0_a) {
+template <typename TARGET_TYPE, typename SOURCE_TYPE, typename MATERIALIZED_TYPE>
+Data Types(MATERIALIZED_TYPE materialized_value, double literal_value) {
+ using TargetDataType = builder::DataType<TARGET_TYPE>;
+ using SourceDataType = builder::DataType<SOURCE_TYPE>;
+ using TargetElementDataType = builder::DataType<typename TargetDataType::ElementType>;
return {
- builder::DataType<TARGET_TYPE>::Name(), //
- builder::DataType<TARGET_TYPE>::AST, //
- builder::DataType<TARGET_TYPE>::Sem, //
- builder::DataType<TARGET_TYPE>::Expr, //
- builder::DataType<LITERAL_TYPE>::Name(), //
- builder::DataType<LITERAL_TYPE>::Expr, //
+ TargetDataType::Name(), // target_type_name
+ TargetElementDataType::Name(), // target_element_type_name
+ TargetDataType::AST, // target_ast_ty
+ TargetDataType::Sem, // target_sem_ty
+ TargetDataType::Expr, // target_expr
+ SourceDataType::Name(), // literal_type_name
+ SourceDataType::Expr, // literal_builder
materialized_value,
+ literal_value,
+ };
+}
+
+template <typename TARGET_TYPE, typename SOURCE_TYPE>
+Data Types() {
+ using TargetDataType = builder::DataType<TARGET_TYPE>;
+ using SourceDataType = builder::DataType<SOURCE_TYPE>;
+ using TargetElementDataType = builder::DataType<typename TargetDataType::ElementType>;
+ return {
+ TargetDataType::Name(), // target_type_name
+ TargetElementDataType::Name(), // target_element_type_name
+ TargetDataType::AST, // target_ast_ty
+ TargetDataType::Sem, // target_sem_ty
+ TargetDataType::Expr, // target_expr
+ SourceDataType::Name(), // literal_type_name
+ SourceDataType::Expr, // literal_builder
+ 0_a,
+ 0.0,
};
}
static std::ostream& operator<<(std::ostream& o, const Data& c) {
- return o << "[" << c.target_type_name << " <- " << c.literal_type_name << "]";
+ auto print_value = [&](auto&& v) { o << v; };
+ o << "[" << c.target_type_name << " <- " << c.source_type_name << "] [";
+ std::visit(print_value, c.materialized_value);
+ o << " <- " << c.literal_value << "]";
+ return o;
}
enum class Expectation {
kMaterialize,
kNoMaterialize,
kInvalidCast,
+ kValueCannotBeRepresented,
};
static std::ostream& operator<<(std::ostream& o, Expectation m) {
@@ -173,6 +203,8 @@
return o << "no-materialize";
case Expectation::kInvalidCast:
return o << "invalid-cast";
+ case Expectation::kValueCannotBeRepresented:
+ return o << "value too low or high";
}
return o << "<unknown>";
}
@@ -191,7 +223,7 @@
auto target_ty = [&] { return data.target_ast_ty(*this); };
auto target_expr = [&] { return data.target_expr(*this, 42); };
- auto* literal = data.literal_value(*this, 1);
+ auto* literal = data.source_builder(*this, data.literal_value);
switch (method) {
case Method::kVar:
WrapInFunction(Decl(Var("a", target_ty(), literal)));
@@ -252,11 +284,13 @@
uint32_t num_elems = 0;
const sem::Type* target_sem_el_ty = sem::Type::ElementOf(target_sem_ty, &num_elems);
EXPECT_TYPE(expr->ConstantValue().ElementType(), target_sem_el_ty);
- std::visit(
- [&](auto&& v) {
- EXPECT_EQ(expr->ConstantValue().Elements(), sem::Constant::Scalars(num_elems, {v}));
- },
- data.materialized_value);
+ expr->ConstantValue().WithElements([&](auto&& vec) {
+ using VEC_TY = std::decay_t<decltype(vec)>;
+ using EL_TY = typename VEC_TY::value_type;
+ ASSERT_TRUE(std::holds_alternative<EL_TY>(data.materialized_value));
+ VEC_TY expected(num_elems, std::get<EL_TY>(data.materialized_value));
+ EXPECT_EQ(vec, expected);
+ });
};
switch (expectation) {
@@ -281,110 +315,191 @@
switch (method) {
case Method::kBuiltinArg:
expect = "error: no matching call to min(" + data.target_type_name + ", " +
- data.literal_type_name + ")";
+ data.source_type_name + ")";
break;
case Method::kBinaryOp:
expect = "error: no matching overload for operator + (" +
- data.target_type_name + ", " + data.literal_type_name + ")";
+ data.target_type_name + ", " + data.source_type_name + ")";
break;
default:
- expect = "error: cannot convert value of type '" + data.literal_type_name +
+ expect = "error: cannot convert value of type '" + data.source_type_name +
"' to type '" + data.target_type_name + "'";
break;
}
EXPECT_THAT(r()->error(), testing::StartsWith(expect));
break;
}
+ case Expectation::kValueCannotBeRepresented:
+ ASSERT_FALSE(r()->Resolve());
+ EXPECT_THAT(r()->error(), testing::HasSubstr("cannot be represented as '" +
+ data.target_element_type_name + "'"));
+ break;
}
}
-// TODO(crbug.com/tint/1504): Test for abstract-numeric values not fitting in materialized types.
+/// Methods that support scalar materialization
+constexpr Method kScalarMethods[] = {Method::kLet, //
+ Method::kVar, //
+ Method::kFnArg, //
+ Method::kBuiltinArg, //
+ Method::kReturn, //
+ Method::kArray, //
+ Method::kStruct, //
+ Method::kBinaryOp};
-INSTANTIATE_TEST_SUITE_P(MaterializeScalar,
- MaterializeAbstractNumeric, //
- testing::Combine(testing::Values(Expectation::kMaterialize), //
- testing::Values(Method::kLet, //
- Method::kVar, //
- Method::kFnArg, //
- Method::kBuiltinArg, //
- Method::kReturn, //
- Method::kArray, //
- Method::kStruct, //
- Method::kBinaryOp), //
- testing::Values(Types<i32, AInt>(1_a), //
- Types<u32, AInt>(1_a), //
- Types<f32, AFloat>(1.0_a) //
- /* Types<f16, AFloat>(1.0_a), */ //
- /* Types<f16, AFloat>(1.0_a), */)));
+/// Methods that support vector materialization
+constexpr Method kVectorMethods[] = {Method::kLet, //
+ Method::kVar, //
+ Method::kFnArg, //
+ Method::kBuiltinArg, //
+ Method::kReturn, //
+ Method::kArray, //
+ Method::kStruct, //
+ Method::kBinaryOp};
-INSTANTIATE_TEST_SUITE_P(MaterializeVector,
- MaterializeAbstractNumeric, //
- testing::Combine(testing::Values(Expectation::kMaterialize), //
- testing::Values(Method::kLet, //
- Method::kVar, //
- Method::kFnArg, //
- Method::kBuiltinArg, //
- Method::kReturn, //
- Method::kArray, //
- Method::kStruct, //
- Method::kBinaryOp), //
- testing::Values(Types<i32V, AIntV>(1_a), //
- Types<u32V, AIntV>(1_a), //
- Types<f32V, AFloatV>(1.0_a) //
- /* Types<f16V, AFloatV>(1.0_a), */ //
- /* Types<f16V, AFloatV>(1.0_a), */)));
+/// Methods that support matrix materialization
+constexpr Method kMatrixMethods[] = {Method::kLet, //
+ Method::kVar, //
+ Method::kFnArg, //
+ Method::kReturn, //
+ Method::kArray, //
+ Method::kStruct, //
+ Method::kBinaryOp};
-INSTANTIATE_TEST_SUITE_P(MaterializeMatrix,
- MaterializeAbstractNumeric, //
- testing::Combine(testing::Values(Expectation::kMaterialize), //
- testing::Values(Method::kLet, //
- Method::kVar, //
- Method::kFnArg, //
- Method::kReturn, //
- Method::kArray, //
- Method::kStruct, //
- Method::kBinaryOp), //
- testing::Values(Types<f32M, AFloatM>(1.0_a) //
- /* Types<f16V, AFloatM>(1.0_a), */ //
- )));
+/// Methods that support materialization for switch cases
+constexpr Method kSwitchMethods[] = {Method::kSwitchCond, //
+ Method::kSwitchCase, //
+ Method::kSwitchCondWithAbstractCase, //
+ Method::kSwitchCaseWithAbstractCase};
-INSTANTIATE_TEST_SUITE_P(MaterializeSwitch,
- MaterializeAbstractNumeric, //
- testing::Combine(testing::Values(Expectation::kMaterialize), //
- testing::Values(Method::kSwitchCond, //
- Method::kSwitchCase, //
- Method::kSwitchCondWithAbstractCase, //
- Method::kSwitchCaseWithAbstractCase), //
- testing::Values(Types<i32, AInt>(1_a), //
- Types<u32, AInt>(1_a))));
+constexpr double kMaxF32 = static_cast<double>(f32::kHighest);
+constexpr double kPiF64 = 3.141592653589793;
+constexpr double kPiF32 = 3.1415927410125732; // kPiF64 quantized to f32
+
+// (2^-127)×(1+(0xfffffffffffff÷0x10000000000000))
+constexpr double kTooSmallF32 = 1.1754943508222874e-38;
+
+INSTANTIATE_TEST_SUITE_P(
+ MaterializeScalar,
+ MaterializeAbstractNumeric, //
+ testing::Combine(testing::Values(Expectation::kMaterialize), //
+ testing::ValuesIn(kScalarMethods), //
+ testing::Values(Types<i32, AInt>(0_a, 0.0), //
+ Types<i32, AInt>(2147483647_a, 2147483647.0), //
+ Types<i32, AInt>(-2147483648_a, -2147483648.0), //
+ Types<u32, AInt>(0_a, 0.0), //
+ Types<u32, AInt>(4294967295_a, 4294967295.0), //
+ Types<f32, AFloat>(0.0_a, 0.0), //
+ Types<f32, AFloat>(AFloat(kMaxF32), kMaxF32), //
+ Types<f32, AFloat>(AFloat(-kMaxF32), -kMaxF32), //
+ Types<f32, AFloat>(AFloat(kPiF32), kPiF64), //
+ Types<f32, AFloat>(0.0_a, kTooSmallF32), //
+ Types<f32, AFloat>(-0.0_a, -kTooSmallF32) //
+ /* Types<f16, AFloat>(1.0_a), */ //
+ /* Types<f16, AFloat>(1.0_a), */)));
+
+INSTANTIATE_TEST_SUITE_P(
+ MaterializeVector,
+ MaterializeAbstractNumeric, //
+ testing::Combine(testing::Values(Expectation::kMaterialize), //
+ testing::ValuesIn(kVectorMethods), //
+ testing::Values(Types<i32V, AIntV>(0_a, 0.0), //
+ Types<i32V, AIntV>(2147483647_a, 2147483647.0), //
+ Types<i32V, AIntV>(-2147483648_a, -2147483648.0), //
+ Types<u32V, AIntV>(0_a, 0.0), //
+ Types<u32V, AIntV>(4294967295_a, 4294967295.0), //
+ Types<f32V, AFloatV>(0.0_a, 0.0), //
+ Types<f32V, AFloatV>(AFloat(kMaxF32), kMaxF32), //
+ Types<f32V, AFloatV>(AFloat(-kMaxF32), -kMaxF32), //
+ Types<f32V, AFloatV>(AFloat(kPiF32), kPiF64), //
+ Types<f32V, AFloatV>(0.0_a, kTooSmallF32), //
+ Types<f32V, AFloatV>(-0.0_a, -kTooSmallF32) //
+ /* Types<f16V, AFloatV>(1.0_a), */ //
+ /* Types<f16V, AFloatV>(1.0_a), */)));
+
+INSTANTIATE_TEST_SUITE_P(
+ MaterializeMatrix,
+ MaterializeAbstractNumeric, //
+ testing::Combine(testing::Values(Expectation::kMaterialize), //
+ testing::ValuesIn(kMatrixMethods), //
+ testing::Values(Types<f32M, AFloatM>(0.0_a, 0.0), //
+ Types<f32M, AFloatM>(AFloat(kMaxF32), kMaxF32), //
+ Types<f32M, AFloatM>(AFloat(-kMaxF32), -kMaxF32), //
+ Types<f32M, AFloatM>(AFloat(kPiF32), kPiF64), //
+ Types<f32M, AFloatM>(0.0_a, kTooSmallF32), //
+ Types<f32M, AFloatM>(-0.0_a, -kTooSmallF32) //
+ /* Types<f16V, AFloatM>(1.0_a), */ //
+ )));
+
+INSTANTIATE_TEST_SUITE_P(
+ MaterializeSwitch,
+ MaterializeAbstractNumeric, //
+ testing::Combine(testing::Values(Expectation::kMaterialize), //
+ testing::ValuesIn(kSwitchMethods), //
+ testing::Values(Types<i32, AInt>(0_a, 0.0), //
+ Types<i32, AInt>(2147483647_a, 2147483647.0), //
+ Types<i32, AInt>(-2147483648_a, -2147483648.0), //
+ Types<u32, AInt>(0_a, 0.0), //
+ Types<u32, AInt>(4294967295_a, 4294967295.0))));
// TODO(crbug.com/tint/1504): Enable once we have abstract overloads of builtins / binary ops.
INSTANTIATE_TEST_SUITE_P(DISABLED_NoMaterialize,
- MaterializeAbstractNumeric, //
- testing::Combine(testing::Values(Expectation::kNoMaterialize), //
- testing::Values(Method::kBuiltinArg, //
- Method::kBinaryOp), //
- testing::Values(Types<AInt, AInt>(1_a), //
- Types<AFloat, AFloat>(1.0_a), //
- Types<AIntV, AIntV>(1_a), //
- Types<AFloatV, AFloatV>(1.0_a), //
- Types<AFloatM, AFloatM>(1.0_a))));
+ MaterializeAbstractNumeric, //
+ testing::Combine(testing::Values(Expectation::kNoMaterialize), //
+ testing::Values(Method::kBuiltinArg, //
+ Method::kBinaryOp), //
+ testing::Values(Types<AInt, AInt>(), //
+ Types<AFloat, AFloat>(), //
+ Types<AIntV, AIntV>(), //
+ Types<AFloatV, AFloatV>(), //
+ Types<AFloatM, AFloatM>())));
INSTANTIATE_TEST_SUITE_P(InvalidCast,
MaterializeAbstractNumeric, //
testing::Combine(testing::Values(Expectation::kInvalidCast), //
- testing::Values(Method::kLet, //
- Method::kVar, //
- Method::kFnArg, //
- Method::kBuiltinArg, //
- Method::kReturn, //
- Method::kArray, //
- Method::kStruct, //
- Method::kBinaryOp), //
+ testing::ValuesIn(kScalarMethods), //
testing::Values(Types<i32, AFloat>(), //
Types<u32, AFloat>(), //
Types<i32V, AFloatV>(), //
Types<u32V, AFloatV>())));
+INSTANTIATE_TEST_SUITE_P(
+ ScalarValueCannotBeRepresented,
+ MaterializeAbstractNumeric, //
+ testing::Combine(testing::Values(Expectation::kValueCannotBeRepresented), //
+ testing::ValuesIn(kScalarMethods), //
+ testing::Values(Types<i32, AInt>(0_a, 2147483648.0), //
+ Types<i32, AInt>(0_a, -2147483649.0), //
+ Types<u32, AInt>(0_a, 4294967296), //
+ Types<u32, AInt>(0_a, -1.0), //
+ Types<f32, AFloat>(0.0_a, 3.5e+38), //
+ Types<f32, AFloat>(0.0_a, -3.5e+38) //
+ /* Types<f16, AFloat>(), */ //
+ /* Types<f16, AFloat>(), */)));
+
+INSTANTIATE_TEST_SUITE_P(
+ VectorValueCannotBeRepresented,
+ MaterializeAbstractNumeric, //
+ testing::Combine(testing::Values(Expectation::kValueCannotBeRepresented), //
+ testing::ValuesIn(kVectorMethods), //
+ testing::Values(Types<i32V, AIntV>(0_a, 2147483648.0), //
+ Types<i32V, AIntV>(0_a, -2147483649.0), //
+ Types<u32V, AIntV>(0_a, 4294967296), //
+ Types<u32V, AIntV>(0_a, -1.0), //
+ Types<f32V, AFloatV>(0.0_a, 3.5e+38), //
+ Types<f32V, AFloatV>(0.0_a, -3.5e+38) //
+ /* Types<f16V, AFloatV>(), */ //
+ /* Types<f16V, AFloatV>(), */)));
+
+INSTANTIATE_TEST_SUITE_P(
+ MatrixValueCannotBeRepresented,
+ MaterializeAbstractNumeric, //
+ testing::Combine(testing::Values(Expectation::kValueCannotBeRepresented), //
+ testing::ValuesIn(kMatrixMethods), //
+ testing::Values(Types<f32M, AFloatM>(0.0_a, 3.5e+38), //
+ Types<f32M, AFloatM>(0.0_a, -3.5e+38) //
+ /* Types<f16M, AFloatM>(), */ //
+ /* Types<f16M, AFloatM>(), */)));
+
} // namespace MaterializeTests
} // namespace
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index 9d3d7a4..1dcedf9 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -806,7 +806,7 @@
return false;
}
- ws[i].value = static_cast<uint32_t>(value.Element<AInt>(0).value);
+ ws[i].value = value.Element<uint32_t>(0);
}
current_function_->SetWorkgroupSize(std::move(ws));
@@ -1110,19 +1110,27 @@
// Helper for actually creating the the materialize node, performing the constant cast, updating
// the ast -> sem binding, and performing validation.
auto materialize = [&](const sem::Type* target_ty) -> sem::Materialize* {
- auto expr_val = EvaluateConstantValue(expr->Declaration(), expr->Type());
- if (!expr_val.IsValid()) {
+ auto* decl = expr->Declaration();
+ auto expr_val = EvaluateConstantValue(decl, expr->Type());
+ if (!expr_val) {
+ return nullptr;
+ }
+ if (!expr_val->IsValid()) {
TINT_ICE(Resolver, builder_->Diagnostics())
- << expr->Declaration()->source
+ << decl->source
<< " EvaluateConstantValue() returned invalid value for materialized "
"value of type: "
<< (expr->Type() ? expr->Type()->FriendlyName(builder_->Symbols()) : "<null>");
return nullptr;
}
- auto materialized_val = ConstantCast(expr_val, target_ty);
- auto* m = builder_->create<sem::Materialize>(expr, current_statement_, materialized_val);
+ auto materialized_val = ConvertValue(expr_val.Get(), target_ty, decl->source);
+ if (!materialized_val) {
+ return nullptr;
+ }
+ auto* m =
+ builder_->create<sem::Materialize>(expr, current_statement_, materialized_val.Get());
m->Behaviors() = expr->Behaviors();
- builder_->Sem().Replace(expr->Declaration(), m);
+ builder_->Sem().Replace(decl, m);
return validator_.Materialize(m) ? m : nullptr;
};
@@ -1215,8 +1223,11 @@
}
auto val = EvaluateConstantValue(expr, ty);
+ if (!val) {
+ return nullptr;
+ }
bool has_side_effects = idx->HasSideEffects() || obj->HasSideEffects();
- auto* sem = builder_->create<sem::Expression>(expr, ty, current_statement_, val,
+ auto* sem = builder_->create<sem::Expression>(expr, ty, current_statement_, val.Get(),
has_side_effects, obj->SourceVariable());
sem->Behaviors() = idx->Behaviors() + obj->Behaviors();
return sem;
@@ -1230,7 +1241,10 @@
}
auto val = EvaluateConstantValue(expr, ty);
- auto* sem = builder_->create<sem::Expression>(expr, ty, current_statement_, val,
+ if (!val) {
+ return nullptr;
+ }
+ auto* sem = builder_->create<sem::Expression>(expr, ty, current_statement_, val.Get(),
inner->HasSideEffects());
sem->Behaviors() = inner->Behaviors();
@@ -1277,9 +1291,12 @@
if (!MaterializeArguments(args, call_target)) {
return nullptr;
}
- auto value = EvaluateConstantValue(expr, call_target->ReturnType());
+ auto val = EvaluateConstantValue(expr, call_target->ReturnType());
+ if (!val) {
+ return nullptr;
+ }
return builder_->create<sem::Call>(expr, call_target, std::move(args), current_statement_,
- value, has_side_effects);
+ val.Get(), has_side_effects);
};
// ct_ctor_or_conv is a helper for building either a sem::TypeConstructor or sem::TypeConversion
@@ -1315,9 +1332,12 @@
if (!MaterializeArguments(args, call_target)) {
return nullptr;
}
- auto value = EvaluateConstantValue(expr, call_target->ReturnType());
+ auto val = EvaluateConstantValue(expr, call_target->ReturnType());
+ if (!val) {
+ return nullptr;
+ }
return builder_->create<sem::Call>(expr, call_target, std::move(args),
- current_statement_, value, has_side_effects);
+ current_statement_, val.Get(), has_side_effects);
},
[&](const sem::Struct* str) -> sem::Call* {
auto* call_target = utils::GetOrCreate(
@@ -1337,9 +1357,12 @@
if (!MaterializeArguments(args, call_target)) {
return nullptr;
}
- auto value = EvaluateConstantValue(expr, call_target->ReturnType());
+ auto val = EvaluateConstantValue(expr, call_target->ReturnType());
+ if (!val) {
+ return nullptr;
+ }
return builder_->create<sem::Call>(expr, call_target, std::move(args),
- current_statement_, value, has_side_effects);
+ current_statement_, val.Get(), has_side_effects);
},
[&](Default) {
AddError("type is not constructible", expr->source);
@@ -1616,7 +1639,10 @@
}
auto val = EvaluateConstantValue(literal, ty);
- return builder_->create<sem::Expression>(literal, ty, current_statement_, val,
+ if (!val) {
+ return nullptr;
+ }
+ return builder_->create<sem::Expression>(literal, ty, current_statement_, val.Get(),
/* has_side_effects */ false);
}
@@ -1828,8 +1854,11 @@
}
auto val = EvaluateConstantValue(expr, op.result);
+ if (!val) {
+ return nullptr;
+ }
bool has_side_effects = lhs->HasSideEffects() || rhs->HasSideEffects();
- auto* sem = builder_->create<sem::Expression>(expr, op.result, current_statement_, val,
+ auto* sem = builder_->create<sem::Expression>(expr, op.result, current_statement_, val.Get(),
has_side_effects);
sem->Behaviors() = lhs->Behaviors() + rhs->Behaviors();
@@ -1902,7 +1931,10 @@
}
auto val = EvaluateConstantValue(unary, ty);
- auto* sem = builder_->create<sem::Expression>(unary, ty, current_statement_, val,
+ if (!val) {
+ return nullptr;
+ }
+ auto* sem = builder_->create<sem::Expression>(unary, ty, current_statement_, val.Get(),
expr->HasSideEffects(), source_var);
sem->Behaviors() = expr->Behaviors();
return sem;
@@ -2022,7 +2054,7 @@
return nullptr;
}
- count = static_cast<uint32_t>(count_val.Element<AInt>(0).value);
+ count = count_val.Element<uint32_t>(0);
}
auto size = std::max<uint64_t>(count, 1) * stride;
diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h
index b03bb32..7d01934 100644
--- a/src/tint/resolver/resolver.h
+++ b/src/tint/resolver/resolver.h
@@ -34,6 +34,7 @@
#include "src/tint/sem/constant.h"
#include "src/tint/sem/function.h"
#include "src/tint/sem/struct.h"
+#include "src/tint/utils/result.h"
#include "src/tint/utils/unique_vector.h"
// Forward declarations
@@ -354,16 +355,19 @@
//////////////////////////////////////////////////////////////////////////////
/// Constant value evaluation methods
//////////////////////////////////////////////////////////////////////////////
- /// Cast `Value` to `target_type`
- /// @return the casted value
- sem::Constant ConstantCast(const sem::Constant& value,
- const sem::Type* target_type,
- const sem::Type* target_element_type = nullptr);
+ /// The result type of a ConstantEvaluation method. Holds the constant value and a boolean,
+ /// which is true on success, false on an error.
+ using ConstantResult = utils::Result<sem::Constant>;
- sem::Constant EvaluateConstantValue(const ast::Expression* expr, const sem::Type* type);
- sem::Constant EvaluateConstantValue(const ast::LiteralExpression* literal,
- const sem::Type* type);
- sem::Constant EvaluateConstantValue(const ast::CallExpression* call, const sem::Type* type);
+ /// Convert the `value` to `target_type`
+ /// @return the converted value
+ ConstantResult ConvertValue(const sem::Constant& value,
+ const sem::Type* target_type,
+ const Source& source);
+ ConstantResult EvaluateConstantValue(const ast::Expression* expr, const sem::Type* type);
+ ConstantResult EvaluateConstantValue(const ast::LiteralExpression* literal,
+ const sem::Type* type);
+ ConstantResult EvaluateConstantValue(const ast::CallExpression* call, const sem::Type* type);
/// @returns true if the symbol is the name of a builtin function.
bool IsBuiltin(Symbol) const;
diff --git a/src/tint/resolver/resolver_constants.cc b/src/tint/resolver/resolver_constants.cc
index 265b06d..3bf8df0 100644
--- a/src/tint/resolver/resolver_constants.cc
+++ b/src/tint/resolver/resolver_constants.cc
@@ -14,146 +14,267 @@
#include "src/tint/resolver/resolver.h"
+#include <cmath>
+// TODO(https://crbug.com/dawn/1379) Update cpplint and remove NOLINT
+#include <optional> // NOLINT(build/include_order))
+
#include "src/tint/sem/abstract_float.h"
#include "src/tint/sem/abstract_int.h"
#include "src/tint/sem/constant.h"
#include "src/tint/sem/type_constructor.h"
+#include "src/tint/utils/compiler_macros.h"
#include "src/tint/utils/map.h"
+#include "src/tint/utils/transform.h"
using namespace tint::number_suffixes; // NOLINT
namespace tint::resolver {
+
namespace {
-sem::Constant::Scalars CastScalars(sem::Constant::Scalars in, const sem::Type* target_type) {
- sem::Constant::Scalars out;
- out.reserve(in.size());
- for (auto v : in) {
- // TODO(crbug.com/tint/1504): Check that value fits in new type
- out.emplace_back(Switch<sem::Constant::Scalar>(
- target_type, //
- [&](const sem::AbstractInt*) { return sem::Constant::Cast<AInt>(v); },
- [&](const sem::AbstractFloat*) { return sem::Constant::Cast<AFloat>(v); },
- [&](const sem::I32*) { return sem::Constant::Cast<AInt>(v); },
- [&](const sem::U32*) { return sem::Constant::Cast<AInt>(v); },
- [&](const sem::F32*) { return sem::Constant::Cast<AFloat>(v); },
- [&](const sem::F16*) { return sem::Constant::Cast<AFloat>(v); },
- [&](const sem::Bool*) { return sem::Constant::Cast<bool>(v); },
- [&](Default) {
- diag::List diags;
- TINT_UNREACHABLE(Semantic, diags)
- << "invalid element type " << target_type->TypeInfo().name;
- return sem::Constant::Scalar(false);
- }));
+/// Converts and returns all the element values of `in` to the type `T`, using the converter
+/// function `CONVERTER`.
+/// @param elements_in the vector of elements to be converted
+/// @param converter a function-like with the signature `void(TO&, FROM)`
+/// @returns the elements converted to type T.
+template <typename T, typename ELEMENTS_IN, typename CONVERTER>
+sem::Constant::Elements Transform(const ELEMENTS_IN& elements_in, CONVERTER&& converter) {
+ TINT_BEGIN_DISABLE_WARNING_UNREACHABLE_CODE();
+
+ return utils::Transform(elements_in, [&](auto value_in) {
+ if constexpr (std::is_same_v<UnwrapNumber<T>, bool>) {
+ return AInt(value_in != 0);
+ } else {
+ T converted{};
+ converter(converted, value_in);
+ if constexpr (IsFloatingPoint<UnwrapNumber<T>>) {
+ return AFloat(converted);
+ } else {
+ return AInt(converted);
+ }
+ }
+ });
+
+ TINT_END_DISABLE_WARNING_UNREACHABLE_CODE();
+}
+
+/// Converts and returns all the element values of `in` to the semantic type `el_ty`, using the
+/// converter function `CONVERTER`.
+/// @param in the constant to convert
+/// @param el_ty the target element type
+/// @param converter a function-like with the signature `void(TO&, FROM)`
+/// @returns the elements converted to `el_ty`
+template <typename CONVERTER>
+sem::Constant::Elements Transform(const sem::Constant::Elements& in,
+ const sem::Type* el_ty,
+ CONVERTER&& converter) {
+ return std::visit(
+ [&](auto&& v) {
+ return Switch(
+ el_ty, //
+ [&](const sem::AbstractInt*) { return Transform<AInt>(v, converter); },
+ [&](const sem::AbstractFloat*) { return Transform<AFloat>(v, converter); },
+ [&](const sem::I32*) { return Transform<i32>(v, converter); },
+ [&](const sem::U32*) { return Transform<u32>(v, converter); },
+ [&](const sem::F32*) { return Transform<f32>(v, converter); },
+ [&](const sem::F16*) { return Transform<f16>(v, converter); },
+ [&](const sem::Bool*) { return Transform<bool>(v, converter); },
+ [&](Default) -> sem::Constant::Elements {
+ diag::List diags;
+ TINT_UNREACHABLE(Semantic, diags)
+ << "invalid element type " << el_ty->TypeInfo().name;
+ return {};
+ });
+ },
+ in);
+}
+
+/// Converts and returns all the elements in `in` to the type `el_ty`.
+/// If the value does not fit in the target type, and:
+/// * the target type is an integer type, then the resulting value will be clamped to the integer's
+/// highest or lowest value.
+/// * the target type is an float type, then the resulting value will be either positive or
+/// negative infinity, based on the sign of the input value.
+/// @param in the input elements
+/// @param el_ty the target element type
+/// @returns the elements converted to `el_ty`
+sem::Constant::Elements ConvertElements(const sem::Constant::Elements& in, const sem::Type* el_ty) {
+ return Transform(in, el_ty, [](auto& el_out, auto el_in) {
+ using OUT = std::decay_t<decltype(el_out)>;
+ if (auto conv = CheckedConvert<OUT>(el_in)) {
+ el_out = conv.Get();
+ } else {
+ constexpr auto kInf = std::numeric_limits<double>::infinity();
+ switch (conv.Failure()) {
+ case ConversionFailure::kExceedsNegativeLimit:
+ el_out = IsFloatingPoint<UnwrapNumber<OUT>> ? OUT(-kInf) : OUT::kLowest;
+ break;
+ case ConversionFailure::kExceedsPositiveLimit:
+ el_out = IsFloatingPoint<UnwrapNumber<OUT>> ? OUT(kInf) : OUT::kHighest;
+ break;
+ case ConversionFailure::kTooSmall:
+ el_out = OUT(el_in < 0 ? -0.0 : 0.0);
+ break;
+ }
+ }
+ });
+}
+
+/// Converts and returns all the elements in `in` to the type `el_ty`, by performing a
+/// `CheckedConvert` on each element value. A single error diagnostic will be raised if an element
+/// value cannot be represented by the target type.
+/// @param in the input elements
+/// @param el_ty the target element type
+/// @returns the elements converted to `el_ty`, or a Failure if some elements could not be
+/// represented by the target type.
+utils::Result<sem::Constant::Elements> MaterializeElements(const sem::Constant::Elements& in,
+ const sem::Type* el_ty,
+ ProgramBuilder& builder,
+ Source source) {
+ std::optional<std::string> failure;
+
+ auto out = Transform(in, el_ty, [&](auto& el_out, auto el_in) {
+ using OUT = std::decay_t<decltype(el_out)>;
+ if (auto conv = CheckedConvert<OUT>(el_in)) {
+ el_out = conv.Get();
+ } else if (conv.Failure() == ConversionFailure::kTooSmall) {
+ el_out = OUT(el_in < 0 ? -0.0 : 0.0);
+ } else if (!failure.has_value()) {
+ std::stringstream ss;
+ ss << "value " << el_in << " cannot be represented as ";
+ ss << "'" << builder.FriendlyName(el_ty) << "'";
+ failure = ss.str();
+ }
+ });
+
+ if (failure.has_value()) {
+ builder.Diagnostics().add_error(diag::System::Resolver, std::move(failure.value()), source);
+ return utils::Failure;
}
+
return out;
}
} // namespace
-sem::Constant Resolver::EvaluateConstantValue(const ast::Expression* expr, const sem::Type* type) {
+utils::Result<sem::Constant> Resolver::EvaluateConstantValue(const ast::Expression* expr,
+ const sem::Type* type) {
if (auto* e = expr->As<ast::LiteralExpression>()) {
return EvaluateConstantValue(e, type);
}
if (auto* e = expr->As<ast::CallExpression>()) {
return EvaluateConstantValue(e, type);
}
- return {};
+ return sem::Constant{};
}
-sem::Constant Resolver::EvaluateConstantValue(const ast::LiteralExpression* literal,
- const sem::Type* type) {
+utils::Result<sem::Constant> Resolver::EvaluateConstantValue(const ast::LiteralExpression* literal,
+ const sem::Type* type) {
return Switch(
literal,
+ [&](const ast::BoolLiteralExpression* lit) {
+ return sem::Constant{type, {AInt(lit->value ? 1 : 0)}};
+ },
[&](const ast::IntLiteralExpression* lit) {
return sem::Constant{type, {AInt(lit->value)}};
},
[&](const ast::FloatLiteralExpression* lit) {
return sem::Constant{type, {AFloat(lit->value)}};
- },
- [&](const ast::BoolLiteralExpression* lit) {
- return sem::Constant{type, {lit->value}};
});
}
-sem::Constant Resolver::EvaluateConstantValue(const ast::CallExpression* call,
- const sem::Type* type) {
+utils::Result<sem::Constant> Resolver::EvaluateConstantValue(const ast::CallExpression* call,
+ const sem::Type* ty) {
uint32_t result_size = 0;
- auto* el_ty = sem::Type::ElementOf(type, &result_size);
+ auto* el_ty = sem::Type::ElementOf(ty, &result_size);
if (!el_ty) {
- return {};
+ return sem::Constant{};
}
// ElementOf() will also return the element type of array, which we do not support.
- if (type->Is<sem::Array>()) {
- return {};
+ if (ty->Is<sem::Array>()) {
+ return sem::Constant{};
}
// For zero value init, return 0s
if (call->args.empty()) {
- using Scalars = sem::Constant::Scalars;
return Switch(
el_ty,
[&](const sem::AbstractInt*) {
- return sem::Constant(type, Scalars(result_size, AInt(0)));
+ return sem::Constant(ty, std::vector(result_size, AInt(0)));
},
[&](const sem::AbstractFloat*) {
- return sem::Constant(type, Scalars(result_size, AFloat(0)));
+ return sem::Constant(ty, std::vector(result_size, AFloat(0)));
},
- [&](const sem::I32*) { return sem::Constant(type, Scalars(result_size, AInt(0))); },
- [&](const sem::U32*) { return sem::Constant(type, Scalars(result_size, AInt(0))); },
- [&](const sem::F32*) { return sem::Constant(type, Scalars(result_size, AFloat(0))); },
- [&](const sem::F16*) { return sem::Constant(type, Scalars(result_size, AFloat(0))); },
- [&](const sem::Bool*) { return sem::Constant(type, Scalars(result_size, false)); });
+ [&](const sem::I32*) { return sem::Constant(ty, std::vector(result_size, AInt(0))); },
+ [&](const sem::U32*) { return sem::Constant(ty, std::vector(result_size, AInt(0))); },
+ [&](const sem::F32*) { return sem::Constant(ty, std::vector(result_size, AFloat(0))); },
+ [&](const sem::F16*) { return sem::Constant(ty, std::vector(result_size, AFloat(0))); },
+ [&](const sem::Bool*) { return sem::Constant(ty, std::vector(result_size, AInt(0))); });
}
- // Build value for type_ctor from each child value by casting to type_ctor's type.
- sem::Constant::Scalars elems;
+ // Build value for type_ctor from each child value by converting to type_ctor's type.
+ std::optional<sem::Constant::Elements> elements;
for (auto* expr : call->args) {
auto* arg = builder_->Sem().Get(expr);
if (!arg) {
- return {};
+ return sem::Constant{};
}
auto value = arg->ConstantValue();
if (!value) {
- return {};
+ return sem::Constant{};
}
- elems.insert(elems.end(), value.Elements().begin(), value.Elements().end());
+
+ // Convert the elements to the desired type.
+ auto converted = ConvertElements(value.GetElements(), el_ty);
+
+ if (elements.has_value()) {
+ // Append the converted vector to elements
+ std::visit(
+ [&](auto&& dst) {
+ using VEC_TY = std::decay_t<decltype(dst)>;
+ const auto& src = std::get<VEC_TY>(converted);
+ dst.insert(dst.end(), src.begin(), src.end());
+ },
+ elements.value());
+ } else {
+ elements = std::move(converted);
+ }
}
// Splat single-value initializers
- if (elems.size() == 1) {
- for (uint32_t i = 0; i < result_size - 1; ++i) {
- elems.emplace_back(elems[0]);
- }
- }
+ std::visit(
+ [&](auto&& v) {
+ if (v.size() == 1) {
+ for (uint32_t i = 0; i < result_size - 1; ++i) {
+ v.emplace_back(v[0]);
+ }
+ }
+ },
+ elements.value());
- // Finally cast the elements to the desired type.
- auto cast = CastScalars(elems, el_ty);
-
- return sem::Constant(type, std::move(cast));
+ return sem::Constant(ty, std::move(elements.value()));
}
-sem::Constant Resolver::ConstantCast(const sem::Constant& value,
- const sem::Type* target_type,
- const sem::Type* target_element_type /* = nullptr */) {
- if (value.Type() == target_type) {
+utils::Result<sem::Constant> Resolver::ConvertValue(const sem::Constant& value,
+ const sem::Type* ty,
+ const Source& source) {
+ if (value.Type() == ty) {
return value;
}
- if (target_element_type == nullptr) {
- target_element_type = sem::Type::ElementOf(target_type);
+ auto* el_ty = sem::Type::ElementOf(ty);
+ if (el_ty == nullptr) {
+ return sem::Constant{};
}
- if (target_element_type == nullptr) {
- return {};
- }
- if (value.ElementType() == target_element_type) {
- return sem::Constant(target_type, value.Elements());
+ if (value.ElementType() == el_ty) {
+ return sem::Constant(ty, value.GetElements());
}
- auto elems = CastScalars(value.Elements(), target_element_type);
-
- return sem::Constant(target_type, elems);
+ if (auto res = MaterializeElements(value.GetElements(), el_ty, *builder_, source)) {
+ return sem::Constant(ty, std::move(res.Get()));
+ }
+ return utils::Failure;
}
} // namespace tint::resolver
diff --git a/src/tint/resolver/resolver_constants_test.cc b/src/tint/resolver/resolver_constants_test.cc
index 05e6e4c..bbdbfea 100644
--- a/src/tint/resolver/resolver_constants_test.cc
+++ b/src/tint/resolver/resolver_constants_test.cc
@@ -23,8 +23,6 @@
namespace tint::resolver {
namespace {
-using Scalar = sem::Constant::Scalar;
-
using ResolverConstantsTest = ResolverTest;
TEST_F(ResolverConstantsTest, Scalar_i32) {
@@ -38,7 +36,7 @@
EXPECT_TRUE(sem->Type()->Is<sem::I32>());
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_EQ(sem->ConstantValue().ElementType(), sem->Type());
- ASSERT_EQ(sem->ConstantValue().Elements().size(), 1u);
+ ASSERT_EQ(sem->ConstantValue().ElementCount(), 1u);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(0).value, 99);
}
@@ -53,7 +51,7 @@
EXPECT_TRUE(sem->Type()->Is<sem::U32>());
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_EQ(sem->ConstantValue().ElementType(), sem->Type());
- ASSERT_EQ(sem->ConstantValue().Elements().size(), 1u);
+ ASSERT_EQ(sem->ConstantValue().ElementCount(), 1u);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(0).value, 99u);
}
@@ -68,7 +66,7 @@
EXPECT_TRUE(sem->Type()->Is<sem::F32>());
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_EQ(sem->ConstantValue().ElementType(), sem->Type());
- ASSERT_EQ(sem->ConstantValue().Elements().size(), 1u);
+ ASSERT_EQ(sem->ConstantValue().ElementCount(), 1u);
EXPECT_EQ(sem->ConstantValue().Element<AFloat>(0).value, 9.9f);
}
@@ -83,7 +81,7 @@
EXPECT_TRUE(sem->Type()->Is<sem::Bool>());
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_EQ(sem->ConstantValue().ElementType(), sem->Type());
- ASSERT_EQ(sem->ConstantValue().Elements().size(), 1u);
+ ASSERT_EQ(sem->ConstantValue().ElementCount(), 1u);
EXPECT_EQ(sem->ConstantValue().Element<bool>(0), true);
}
@@ -100,7 +98,7 @@
EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::I32>());
- ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
+ ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(0).value, 0);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(1).value, 0);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(2).value, 0);
@@ -119,7 +117,7 @@
EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::U32>());
- ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
+ ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(0).value, 0u);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(1).value, 0u);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(2).value, 0u);
@@ -138,7 +136,7 @@
EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F32>());
- ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
+ ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
EXPECT_EQ(sem->ConstantValue().Element<AFloat>(0).value, 0.0);
EXPECT_EQ(sem->ConstantValue().Element<AFloat>(1).value, 0.0);
EXPECT_EQ(sem->ConstantValue().Element<AFloat>(2).value, 0.0);
@@ -157,7 +155,7 @@
EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::Bool>());
- ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
+ ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
EXPECT_EQ(sem->ConstantValue().Element<bool>(0), false);
EXPECT_EQ(sem->ConstantValue().Element<bool>(1), false);
EXPECT_EQ(sem->ConstantValue().Element<bool>(2), false);
@@ -176,7 +174,7 @@
EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::I32>());
- ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
+ ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(0).value, 99);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(1).value, 99);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(2).value, 99);
@@ -195,7 +193,7 @@
EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::U32>());
- ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
+ ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(0).value, 99u);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(1).value, 99u);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(2).value, 99u);
@@ -214,7 +212,7 @@
EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F32>());
- ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
+ ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
EXPECT_EQ(sem->ConstantValue().Element<AFloat>(0).value, 9.9f);
EXPECT_EQ(sem->ConstantValue().Element<AFloat>(1).value, 9.9f);
EXPECT_EQ(sem->ConstantValue().Element<AFloat>(2).value, 9.9f);
@@ -233,7 +231,7 @@
EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::Bool>());
- ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
+ ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
EXPECT_EQ(sem->ConstantValue().Element<bool>(0), true);
EXPECT_EQ(sem->ConstantValue().Element<bool>(1), true);
EXPECT_EQ(sem->ConstantValue().Element<bool>(2), true);
@@ -252,7 +250,7 @@
EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::I32>());
- ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
+ ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(0).value, 1);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(1).value, 2);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(2).value, 3);
@@ -271,7 +269,7 @@
EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::U32>());
- ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
+ ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(0).value, 1);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(1).value, 2);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(2).value, 3);
@@ -290,7 +288,7 @@
EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F32>());
- ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
+ ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
EXPECT_EQ(sem->ConstantValue().Element<AFloat>(0).value, 1.f);
EXPECT_EQ(sem->ConstantValue().Element<AFloat>(1).value, 2.f);
EXPECT_EQ(sem->ConstantValue().Element<AFloat>(2).value, 3.f);
@@ -309,7 +307,7 @@
EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::Bool>());
- ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
+ ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
EXPECT_EQ(sem->ConstantValue().Element<bool>(0), true);
EXPECT_EQ(sem->ConstantValue().Element<bool>(1), false);
EXPECT_EQ(sem->ConstantValue().Element<bool>(2), true);
@@ -328,7 +326,7 @@
EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::I32>());
- ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
+ ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(0).value, 1);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(1).value, 2);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(2).value, 3);
@@ -347,7 +345,7 @@
EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::U32>());
- ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
+ ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(0).value, 1);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(1).value, 2);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(2).value, 3);
@@ -366,7 +364,7 @@
EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F32>());
- ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
+ ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
EXPECT_EQ(sem->ConstantValue().Element<AFloat>(0).value, 1.f);
EXPECT_EQ(sem->ConstantValue().Element<AFloat>(1).value, 2.f);
EXPECT_EQ(sem->ConstantValue().Element<AFloat>(2).value, 3.f);
@@ -385,13 +383,13 @@
EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::Bool>());
- ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
+ ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
EXPECT_EQ(sem->ConstantValue().Element<bool>(0), true);
EXPECT_EQ(sem->ConstantValue().Element<bool>(1), false);
EXPECT_EQ(sem->ConstantValue().Element<bool>(2), true);
}
-TEST_F(ResolverConstantsTest, Vec3_Cast_f32_to_32) {
+TEST_F(ResolverConstantsTest, Vec3_Convert_f32_to_i32) {
auto* expr = vec3<i32>(vec3<f32>(1.1_f, 2.2_f, 3.3_f));
WrapInFunction(expr);
@@ -404,13 +402,13 @@
EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::I32>());
- ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
+ ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(0).value, 1);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(1).value, 2);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(2).value, 3);
}
-TEST_F(ResolverConstantsTest, Vec3_Cast_u32_to_f32) {
+TEST_F(ResolverConstantsTest, Vec3_Convert_u32_to_f32) {
auto* expr = vec3<f32>(vec3<u32>(10_u, 20_u, 30_u));
WrapInFunction(expr);
@@ -423,11 +421,95 @@
EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F32>());
- ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
+ ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
EXPECT_EQ(sem->ConstantValue().Element<AFloat>(0).value, 10.f);
EXPECT_EQ(sem->ConstantValue().Element<AFloat>(1).value, 20.f);
EXPECT_EQ(sem->ConstantValue().Element<AFloat>(2).value, 30.f);
}
+TEST_F(ResolverConstantsTest, Vec3_Convert_Large_f32_to_i32) {
+ auto* expr = vec3<i32>(vec3<f32>(1e10_f, -1e20_f, 1e30_f));
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ EXPECT_NE(sem, nullptr);
+ ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
+ EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::I32>());
+ EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
+ EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
+ EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::I32>());
+ ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
+ EXPECT_EQ(sem->ConstantValue().Element<AInt>(0).value, i32::kHighest);
+ EXPECT_EQ(sem->ConstantValue().Element<AInt>(1).value, i32::kLowest);
+ EXPECT_EQ(sem->ConstantValue().Element<AInt>(2).value, i32::kHighest);
+}
+
+TEST_F(ResolverConstantsTest, Vec3_Convert_Large_f32_to_u32) {
+ auto* expr = vec3<u32>(vec3<f32>(1e10_f, -1e20_f, 1e30_f));
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ EXPECT_NE(sem, nullptr);
+ ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
+ EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::U32>());
+ EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
+ EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
+ EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::U32>());
+ ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
+ EXPECT_EQ(sem->ConstantValue().Element<AInt>(0).value, u32::kHighest);
+ EXPECT_EQ(sem->ConstantValue().Element<AInt>(1).value, u32::kLowest);
+ EXPECT_EQ(sem->ConstantValue().Element<AInt>(2).value, u32::kHighest);
+}
+
+// TODO(crbug.com/tint/1502): Enable when f16 overloads are implemented
+TEST_F(ResolverConstantsTest, DISABLED_Vec3_Convert_Large_f32_to_f16) {
+ Enable(ast::Extension::kF16);
+
+ auto* expr = vec3<f16>(vec3<f32>(0.00001_f, -0.00002_f, 0.00003_f));
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ constexpr auto kInf = std::numeric_limits<double>::infinity();
+
+ auto* sem = Sem().Get(expr);
+ EXPECT_NE(sem, nullptr);
+ ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
+ EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::F16>());
+ EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
+ EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
+ EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F16>());
+ ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
+ EXPECT_EQ(sem->ConstantValue().Element<AFloat>(0).value, kInf);
+ EXPECT_EQ(sem->ConstantValue().Element<AFloat>(1).value, -kInf);
+ EXPECT_EQ(sem->ConstantValue().Element<AFloat>(2).value, kInf);
+}
+
+// TODO(crbug.com/tint/1502): Enable when f16 overloads are implemented
+TEST_F(ResolverConstantsTest, DISABLED_Vec3_Convert_Small_f32_to_f16) {
+ Enable(ast::Extension::kF16);
+
+ auto* expr = vec3<f16>(vec3<f32>(1e-10_f, -1e20_f, 1e30_f));
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ EXPECT_NE(sem, nullptr);
+ ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
+ EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::F16>());
+ EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
+ EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
+ EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F16>());
+ ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
+ EXPECT_EQ(sem->ConstantValue().Element<AFloat>(0).value, 0.0);
+ EXPECT_EQ(sem->ConstantValue().Element<AFloat>(1).value, -0.0);
+ EXPECT_EQ(sem->ConstantValue().Element<AFloat>(2).value, 0.0);
+}
+
} // namespace
} // namespace tint::resolver
diff --git a/src/tint/resolver/resolver_test_helper.h b/src/tint/resolver/resolver_test_helper.h
index 6c5a74f..a0d71e5 100644
--- a/src/tint/resolver/resolver_test_helper.h
+++ b/src/tint/resolver/resolver_test_helper.h
@@ -15,6 +15,7 @@
#ifndef SRC_TINT_RESOLVER_RESOLVER_TEST_HELPER_H_
#define SRC_TINT_RESOLVER_RESOLVER_TEST_HELPER_H_
+#include <functional>
#include <memory>
#include <string>
#include <vector>
@@ -170,7 +171,7 @@
struct ptr {};
using ast_type_func_ptr = const ast::Type* (*)(ProgramBuilder& b);
-using ast_expr_func_ptr = const ast::Expression* (*)(ProgramBuilder& b, int elem_value);
+using ast_expr_func_ptr = const ast::Expression* (*)(ProgramBuilder& b, double elem_value);
using sem_type_func_ptr = const sem::Type* (*)(ProgramBuilder& b);
template <typename T>
@@ -188,6 +189,9 @@
/// Helper for building bool types and expressions
template <>
struct DataType<bool> {
+ /// The element type
+ using ElementType = bool;
+
/// false as bool is not a composite type
static constexpr bool is_composite = false;
@@ -200,8 +204,8 @@
/// @param b the ProgramBuilder
/// @param elem_value the b
/// @return a new AST expression of the bool type
- static inline const ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
- return b.Expr(elem_value == 0);
+ static inline const ast::Expression* Expr(ProgramBuilder& b, double elem_value) {
+ return b.Expr(std::equal_to<double>()(elem_value, 0));
}
/// @returns the WGSL name for the type
static inline std::string Name() { return "bool"; }
@@ -210,6 +214,9 @@
/// Helper for building i32 types and expressions
template <>
struct DataType<i32> {
+ /// The element type
+ using ElementType = i32;
+
/// false as i32 is not a composite type
static constexpr bool is_composite = false;
@@ -222,7 +229,7 @@
/// @param b the ProgramBuilder
/// @param elem_value the value i32 will be initialized with
/// @return a new AST i32 literal value expression
- static inline const ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
+ static inline const ast::Expression* Expr(ProgramBuilder& b, double elem_value) {
return b.Expr(static_cast<i32>(elem_value));
}
/// @returns the WGSL name for the type
@@ -232,6 +239,9 @@
/// Helper for building u32 types and expressions
template <>
struct DataType<u32> {
+ /// The element type
+ using ElementType = u32;
+
/// false as u32 is not a composite type
static constexpr bool is_composite = false;
@@ -244,7 +254,7 @@
/// @param b the ProgramBuilder
/// @param elem_value the value u32 will be initialized with
/// @return a new AST u32 literal value expression
- static inline const ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
+ static inline const ast::Expression* Expr(ProgramBuilder& b, double elem_value) {
return b.Expr(static_cast<u32>(elem_value));
}
/// @returns the WGSL name for the type
@@ -254,6 +264,9 @@
/// Helper for building f32 types and expressions
template <>
struct DataType<f32> {
+ /// The element type
+ using ElementType = f32;
+
/// false as f32 is not a composite type
static constexpr bool is_composite = false;
@@ -266,7 +279,7 @@
/// @param b the ProgramBuilder
/// @param elem_value the value f32 will be initialized with
/// @return a new AST f32 literal value expression
- static inline const ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
+ static inline const ast::Expression* Expr(ProgramBuilder& b, double elem_value) {
return b.Expr(static_cast<f32>(elem_value));
}
/// @returns the WGSL name for the type
@@ -276,6 +289,9 @@
/// Helper for building f16 types and expressions
template <>
struct DataType<f16> {
+ /// The element type
+ using ElementType = f16;
+
/// false as f16 is not a composite type
static constexpr bool is_composite = false;
@@ -288,7 +304,7 @@
/// @param b the ProgramBuilder
/// @param elem_value the value f16 will be initialized with
/// @return a new AST f16 literal value expression
- static inline const ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
+ static inline const ast::Expression* Expr(ProgramBuilder& b, double elem_value) {
return b.Expr(static_cast<f16>(elem_value));
}
/// @returns the WGSL name for the type
@@ -298,6 +314,9 @@
/// Helper for building abstract float types and expressions
template <>
struct DataType<AFloat> {
+ /// The element type
+ using ElementType = AFloat;
+
/// false as AFloat is not a composite type
static constexpr bool is_composite = false;
@@ -309,7 +328,7 @@
/// @param b the ProgramBuilder
/// @param elem_value the value the abstract-float literal will be constructed with
/// @return a new AST abstract-float literal value expression
- static inline const ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
+ static inline const ast::Expression* Expr(ProgramBuilder& b, double elem_value) {
return b.Expr(AFloat(elem_value));
}
/// @returns the WGSL name for the type
@@ -319,6 +338,9 @@
/// Helper for building abstract integer types and expressions
template <>
struct DataType<AInt> {
+ /// The element type
+ using ElementType = AInt;
+
/// false as AFloat is not a composite type
static constexpr bool is_composite = false;
@@ -330,7 +352,7 @@
/// @param b the ProgramBuilder
/// @param elem_value the value the abstract-int literal will be constructed with
/// @return a new AST abstract-int literal value expression
- static inline const ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
+ static inline const ast::Expression* Expr(ProgramBuilder& b, double elem_value) {
return b.Expr(AInt(elem_value));
}
/// @returns the WGSL name for the type
@@ -340,6 +362,9 @@
/// Helper for building vector types and expressions
template <uint32_t N, typename T>
struct DataType<vec<N, T>> {
+ /// The element type
+ using ElementType = T;
+
/// true as vectors are a composite type
static constexpr bool is_composite = true;
@@ -357,14 +382,14 @@
/// @param elem_value the value each element in the vector will be initialized
/// with
/// @return a new AST vector value expression
- static inline const ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
+ static inline const ast::Expression* Expr(ProgramBuilder& b, double elem_value) {
return b.Construct(AST(b), ExprArgs(b, elem_value));
}
/// @param b the ProgramBuilder
/// @param elem_value the value each element will be initialized with
/// @return the list of expressions that are used to construct the vector
- static inline ast::ExpressionList ExprArgs(ProgramBuilder& b, int elem_value) {
+ static inline ast::ExpressionList ExprArgs(ProgramBuilder& b, double elem_value) {
ast::ExpressionList args;
for (uint32_t i = 0; i < N; i++) {
args.emplace_back(DataType<T>::Expr(b, elem_value));
@@ -380,6 +405,9 @@
/// Helper for building matrix types and expressions
template <uint32_t N, uint32_t M, typename T>
struct DataType<mat<N, M, T>> {
+ /// The element type
+ using ElementType = T;
+
/// true as matrices are a composite type
static constexpr bool is_composite = true;
@@ -398,14 +426,14 @@
/// @param elem_value the value each element in the matrix will be initialized
/// with
/// @return a new AST matrix value expression
- static inline const ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
+ static inline const ast::Expression* Expr(ProgramBuilder& b, double elem_value) {
return b.Construct(AST(b), ExprArgs(b, elem_value));
}
/// @param b the ProgramBuilder
/// @param elem_value the value each element will be initialized with
/// @return the list of expressions that are used to construct the matrix
- static inline ast::ExpressionList ExprArgs(ProgramBuilder& b, int elem_value) {
+ static inline ast::ExpressionList ExprArgs(ProgramBuilder& b, double elem_value) {
ast::ExpressionList args;
for (uint32_t i = 0; i < N; i++) {
args.emplace_back(DataType<vec<M, T>>::Expr(b, elem_value));
@@ -422,6 +450,9 @@
/// Helper for building alias types and expressions
template <typename T, int ID>
struct DataType<alias<T, ID>> {
+ /// The element type
+ using ElementType = T;
+
/// true if the aliased type is a composite type
static constexpr bool is_composite = DataType<T>::is_composite;
@@ -444,7 +475,7 @@
/// @return a new AST expression of the alias type
template <bool IS_COMPOSITE = is_composite>
static inline traits::EnableIf<!IS_COMPOSITE, const ast::Expression*> Expr(ProgramBuilder& b,
- int elem_value) {
+ double elem_value) {
// Cast
return b.Construct(AST(b), DataType<T>::Expr(b, elem_value));
}
@@ -454,7 +485,7 @@
/// @return a new AST expression of the alias type
template <bool IS_COMPOSITE = is_composite>
static inline traits::EnableIf<IS_COMPOSITE, const ast::Expression*> Expr(ProgramBuilder& b,
- int elem_value) {
+ double elem_value) {
// Construct
return b.Construct(AST(b), DataType<T>::ExprArgs(b, elem_value));
}
@@ -465,6 +496,9 @@
/// Helper for building pointer types and expressions
template <typename T>
struct DataType<ptr<T>> {
+ /// The element type
+ using ElementType = T;
+
/// true if the pointer type is a composite type
static constexpr bool is_composite = false;
@@ -483,7 +517,7 @@
/// @param b the ProgramBuilder
/// @return a new AST expression of the alias type
- static inline const ast::Expression* Expr(ProgramBuilder& b, int /*unused*/) {
+ static inline const ast::Expression* Expr(ProgramBuilder& b, double /*unused*/) {
auto sym = b.Symbols().New("global_for_ptr");
b.Global(sym, DataType<T>::AST(b), ast::StorageClass::kPrivate);
return b.AddressOf(sym);
@@ -495,6 +529,9 @@
/// Helper for building array types and expressions
template <uint32_t N, typename T>
struct DataType<array<N, T>> {
+ /// The element type
+ using ElementType = T;
+
/// true as arrays are a composite type
static constexpr bool is_composite = true;
@@ -519,14 +556,14 @@
/// @param elem_value the value each element in the array will be initialized
/// with
/// @return a new AST array value expression
- static inline const ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
+ static inline const ast::Expression* Expr(ProgramBuilder& b, double elem_value) {
return b.Construct(AST(b), ExprArgs(b, elem_value));
}
/// @param b the ProgramBuilder
/// @param elem_value the value each element will be initialized with
/// @return the list of expressions that are used to construct the array
- static inline ast::ExpressionList ExprArgs(ProgramBuilder& b, int elem_value) {
+ static inline ast::ExpressionList ExprArgs(ProgramBuilder& b, double elem_value) {
ast::ExpressionList args;
for (uint32_t i = 0; i < N; i++) {
args.emplace_back(DataType<T>::Expr(b, elem_value));
diff --git a/src/tint/resolver/validator.cc b/src/tint/resolver/validator.cc
index fdd74e9..fd59988 100644
--- a/src/tint/resolver/validator.cc
+++ b/src/tint/resolver/validator.cc
@@ -1535,7 +1535,7 @@
});
if (is_const_expr) {
auto vector = builtin->Parameters()[index]->Type()->Is<sem::Vector>();
- for (size_t i = 0; i < values.Elements().size(); i++) {
+ for (size_t i = 0, n = values.ElementCount(); i < n; i++) {
auto value = values.Element<AInt>(i).value;
if (value < min || value > max) {
if (vector) {
diff --git a/src/tint/sem/constant.cc b/src/tint/sem/constant.cc
index 98c724c..1fa83f5 100644
--- a/src/tint/sem/constant.cc
+++ b/src/tint/sem/constant.cc
@@ -23,29 +23,19 @@
namespace tint::sem {
namespace {
-
-const Type* CheckElemType(const Type* ty, size_t num_scalars) {
- diag::List diag;
- if (ty->is_abstract_or_scalar() || ty->IsAnyOf<Vector, Matrix>()) {
- uint32_t count = 0;
- auto* el_ty = Type::ElementOf(ty, &count);
- if (num_scalars != count) {
- TINT_ICE(Semantic, diag) << "sem::Constant() type <-> scalar mismatch. type: '"
- << ty->TypeInfo().name << "' scalar: " << num_scalars;
- }
- TINT_ASSERT(Semantic, el_ty->is_abstract_or_scalar());
- return el_ty;
- }
- TINT_UNREACHABLE(Semantic, diag) << "Unsupported sem::Constant type: " << ty->TypeInfo().name;
- return nullptr;
+size_t CountElements(const Constant::Elements& elements) {
+ return std::visit([](auto&& vec) { return vec.size(); }, elements);
}
-
} // namespace
Constant::Constant() {}
-Constant::Constant(const sem::Type* ty, Scalars els)
- : type_(ty), elem_type_(CheckElemType(ty, els.size())), elems_(std::move(els)) {}
+Constant::Constant(const sem::Type* ty, Elements els)
+ : type_(ty), elem_type_(CheckElemType(ty, CountElements(els))), elems_(std::move(els)) {}
+
+Constant::Constant(const sem::Type* ty, AInts vec) : Constant(ty, Elements{std::move(vec)}) {}
+
+Constant::Constant(const sem::Type* ty, AFloats vec) : Constant(ty, Elements{std::move(vec)}) {}
Constant::Constant(const Constant&) = default;
@@ -54,16 +44,31 @@
Constant& Constant::operator=(const Constant& rhs) = default;
bool Constant::AnyZero() const {
- for (auto scalar : elems_) {
- auto is_zero = [&](auto&& s) {
- using T = std::remove_reference_t<decltype(s)>;
- return s == T(0);
- };
- if (std::visit(is_zero, scalar)) {
- return true;
+ return WithElements([&](auto&& vec) {
+ for (auto scalar : vec) {
+ using T = std::remove_reference_t<decltype(scalar)>;
+ if (scalar == T(0)) {
+ return true;
+ }
}
+ return false;
+ });
+}
+
+const Type* Constant::CheckElemType(const sem::Type* ty, size_t num_elements) {
+ diag::List diag;
+ if (ty->is_abstract_or_scalar() || ty->IsAnyOf<Vector, Matrix>()) {
+ uint32_t count = 0;
+ auto* el_ty = Type::ElementOf(ty, &count);
+ if (num_elements != count) {
+ TINT_ICE(Semantic, diag) << "sem::Constant() type <-> element mismatch. type: '"
+ << ty->TypeInfo().name << "' element: " << num_elements;
+ }
+ TINT_ASSERT(Semantic, el_ty->is_abstract_or_scalar());
+ return el_ty;
}
- return false;
+ TINT_UNREACHABLE(Semantic, diag) << "Unsupported sem::Constant type: " << ty->TypeInfo().name;
+ return nullptr;
}
} // namespace tint::sem
diff --git a/src/tint/sem/constant.h b/src/tint/sem/constant.h
index ea143b6..43109f7 100644
--- a/src/tint/sem/constant.h
+++ b/src/tint/sem/constant.h
@@ -15,7 +15,10 @@
#ifndef SRC_TINT_SEM_CONSTANT_H_
#define SRC_TINT_SEM_CONSTANT_H_
-#include <variant>
+#include <ostream>
+// TODO(https://crbug.com/dawn/1379) Update cpplint and remove NOLINT
+#include <utility>
+#include <variant> // NOLINT(build/include_order)
#include <vector>
#include "src/tint/program_builder.h"
@@ -23,15 +26,31 @@
namespace tint::sem {
-/// A Constant is compile-time known expression value, expressed as a flattened
-/// list of scalar values. Value may be of a scalar or vector type.
+/// A Constant holds a compile-time evaluated expression value, expressed as a flattened list of
+/// element values. The expression type may be of an abstract-numeric, scalar, vector or matrix
+/// type. Constant holds the element values in either a vector of abstract-integer (AInt) or
+/// abstract-float (AFloat), depending on the element type.
class Constant {
public:
- /// Scalar holds a single constant scalar value - one of: AInt, AFloat or bool.
- using Scalar = std::variant<AInt, AFloat, bool>;
+ /// AInts is a vector of AInt, used to hold elements of the WGSL types:
+ /// * abstract-integer
+ /// * i32
+ /// * u32
+ /// * bool (0 or 1)
+ using AInts = std::vector<AInt>;
- /// Scalars is a list of scalar values
- using Scalars = std::vector<Scalar>;
+ /// AFloats is a vector of AFloat, used to hold elements of the WGSL types:
+ /// * abstract-float
+ /// * f32
+ /// * f16
+ using AFloats = std::vector<AFloat>;
+
+ /// Elements is either a vector of AInts or AFloats
+ using Elements = std::variant<AInts, AFloats>;
+
+ /// Helper that resolves to either AInts or AFloats based on the element type T.
+ template <typename T>
+ using ElementVectorFor = std::conditional_t<IsFloatingPoint<UnwrapNumber<T>>, AFloats, AInts>;
/// Constructs an invalid Constant
Constant();
@@ -39,7 +58,23 @@
/// Constructs a Constant of the given type and element values
/// @param ty the Constant type
/// @param els the Constant element values
- Constant(const Type* ty, Scalars els);
+ Constant(const sem::Type* ty, Elements els);
+
+ /// Constructs a Constant of the given type and element values
+ /// @param ty the Constant type
+ /// @param vec the Constant element values
+ Constant(const sem::Type* ty, AInts vec);
+
+ /// Constructs a Constant of the given type and element values
+ /// @param ty the Constant type
+ /// @param vec the Constant element values
+ Constant(const sem::Type* ty, AFloats vec);
+
+ /// Constructs a Constant of the given type and element values
+ /// @param ty the Constant type
+ /// @param els the Constant element values
+ template <typename T>
+ Constant(const sem::Type* ty, std::initializer_list<T> els);
/// Copy constructor
Constant(const Constant&);
@@ -61,42 +96,77 @@
/// @returns the type of the Constant
const sem::Type* Type() const { return type_; }
+ /// @returns the number of elements
+ size_t ElementCount() const {
+ return std::visit([](auto&& v) { return v.size(); }, elems_);
+ }
+
/// @returns the element type of the Constant
const sem::Type* ElementType() const { return elem_type_; }
- /// @returns the constant's scalar elements
- const Scalars& Elements() const { return elems_; }
+ /// @returns the constant's elements
+ const Elements& GetElements() const { return elems_; }
- /// @returns true if any scalar element is zero
+ /// WithElements calls the function `f` with the vector of elements as either AFloats or AInts
+ /// @param f a function-like with the signature `R(auto&&)`.
+ /// @returns the result of calling `f`.
+ template <typename F>
+ auto WithElements(F&& f) const {
+ return std::visit(std::forward<F>(f), elems_);
+ }
+
+ /// WithElements calls the function `f` with the element vector as either AFloats or AInts
+ /// @param f a function-like with the signature `R(auto&&)`.
+ /// @returns the result of calling `f`.
+ template <typename F>
+ auto WithElements(F&& f) {
+ return std::visit(std::forward<F>(f), elems_);
+ }
+
+ /// @returns the elements as a vector of AInt
+ inline const AInts& IElements() const { return std::get<AInts>(elems_); }
+
+ /// @returns the elements as a vector of AFloat
+ inline const AFloats& FElements() const { return std::get<AFloats>(elems_); }
+
+ /// @returns true if any element is zero
bool AnyZero() const;
- /// @param index the index of the scalar value
- /// @return the value of the scalar at `index`, which must be of type `T`.
+ /// @param index the index of the element
+ /// @return the element at `index`, which must be of type `T`.
template <typename T>
- T Element(size_t index) const {
- return std::get<T>(elems_[index]);
- }
-
- /// @param index the index of the scalar value
- /// @return the value of the scalar `static_cast` to type T.
- template <typename T>
- T ElementAs(size_t index) const {
- return Cast<T>(elems_[index]);
- }
-
- /// @param s the input scalar
- /// @returns the scalar `s` cast to the type `T`.
- template <typename T>
- static T Cast(Scalar s) {
- return std::visit([](auto v) { return static_cast<T>(v); }, s);
- }
+ T Element(size_t index) const;
private:
+ /// Checks that the provided type matches the number of expected elements.
+ /// @returns the element type of `ty`.
+ const sem::Type* CheckElemType(const sem::Type* ty, size_t num_elements);
+
const sem::Type* type_ = nullptr;
const sem::Type* elem_type_ = nullptr;
- Scalars elems_;
+ Elements elems_;
};
+template <typename T>
+Constant::Constant(const sem::Type* ty, std::initializer_list<T> els)
+ : type_(ty), elem_type_(CheckElemType(type_, els.size())) {
+ ElementVectorFor<T> elements;
+ elements.reserve(els.size());
+ for (auto el : els) {
+ elements.emplace_back(AFloat(el));
+ }
+ elems_ = Elements{std::move(elements)};
+}
+
+template <typename T>
+T Constant::Element(size_t index) const {
+ if constexpr (std::is_same_v<ElementVectorFor<T>, AFloats>) {
+ return static_cast<T>(FElements()[index].value);
+ } else {
+ return static_cast<T>(IElements()[index].value);
+ }
+}
+
} // namespace tint::sem
#endif // SRC_TINT_SEM_CONSTANT_H_
diff --git a/src/tint/sem/constant_test.cc b/src/tint/sem/constant_test.cc
new file mode 100644
index 0000000..345ebd8
--- /dev/null
+++ b/src/tint/sem/constant_test.cc
@@ -0,0 +1,199 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/sem/constant.h"
+
+#include <gmock/gmock.h>
+
+#include "src/tint/sem/abstract_float.h"
+#include "src/tint/sem/abstract_int.h"
+#include "src/tint/sem/test_helper.h"
+
+using namespace tint::number_suffixes; // NOLINT
+
+namespace tint::sem {
+namespace {
+
+using ConstantTest = TestHelper;
+
+TEST_F(ConstantTest, ConstructorInitializerList) {
+ {
+ Constant c(create<AbstractInt>(), {1_a});
+ c.WithElements([&](auto&& vec) { EXPECT_THAT(vec, testing::ElementsAre(1_a)); });
+ }
+ {
+ Constant c(create<I32>(), {1_i});
+ c.WithElements([&](auto&& vec) { EXPECT_THAT(vec, testing::ElementsAre(1_a)); });
+ }
+ {
+ Constant c(create<U32>(), {1_u});
+ c.WithElements([&](auto&& vec) { EXPECT_THAT(vec, testing::ElementsAre(1_a)); });
+ }
+ {
+ Constant c(create<Bool>(), {false});
+ c.WithElements([&](auto&& vec) { EXPECT_THAT(vec, testing::ElementsAre(0_a)); });
+ }
+ {
+ Constant c(create<Bool>(), {true});
+ c.WithElements([&](auto&& vec) { EXPECT_THAT(vec, testing::ElementsAre(1_a)); });
+ }
+ {
+ Constant c(create<AbstractFloat>(), {1.0_a});
+ c.WithElements([&](auto&& vec) { EXPECT_THAT(vec, testing::ElementsAre(1.0_a)); });
+ }
+ {
+ Constant c(create<F32>(), {1.0_f});
+ c.WithElements([&](auto&& vec) { EXPECT_THAT(vec, testing::ElementsAre(1.0_a)); });
+ }
+ {
+ Constant c(create<F16>(), {1.0_h});
+ c.WithElements([&](auto&& vec) { EXPECT_THAT(vec, testing::ElementsAre(1.0_a)); });
+ }
+}
+
+TEST_F(ConstantTest, Element_ai) {
+ Constant c(create<AbstractInt>(), {1_a});
+ EXPECT_EQ(c.Element<AInt>(0), 1_a);
+ EXPECT_EQ(c.ElementCount(), 1u);
+}
+
+TEST_F(ConstantTest, Element_i32) {
+ Constant c(create<I32>(), {1_a});
+ EXPECT_EQ(c.Element<i32>(0), 1_i);
+ EXPECT_EQ(c.ElementCount(), 1u);
+}
+
+TEST_F(ConstantTest, Element_u32) {
+ Constant c(create<U32>(), {1_a});
+ EXPECT_EQ(c.Element<u32>(0), 1_u);
+ EXPECT_EQ(c.ElementCount(), 1u);
+}
+
+TEST_F(ConstantTest, Element_bool) {
+ Constant c(create<Bool>(), {true});
+ EXPECT_EQ(c.Element<bool>(0), true);
+ EXPECT_EQ(c.ElementCount(), 1u);
+}
+
+TEST_F(ConstantTest, Element_af) {
+ Constant c(create<AbstractFloat>(), {1.0_a});
+ EXPECT_EQ(c.Element<AFloat>(0), 1.0_a);
+ EXPECT_EQ(c.ElementCount(), 1u);
+}
+
+TEST_F(ConstantTest, Element_f32) {
+ Constant c(create<F32>(), {1.0_a});
+ EXPECT_EQ(c.Element<f32>(0), 1.0_f);
+ EXPECT_EQ(c.ElementCount(), 1u);
+}
+
+TEST_F(ConstantTest, Element_f16) {
+ Constant c(create<F16>(), {1.0_a});
+ EXPECT_EQ(c.Element<f16>(0), 1.0_h);
+ EXPECT_EQ(c.ElementCount(), 1u);
+}
+
+TEST_F(ConstantTest, Element_vec3_ai) {
+ Constant c(create<Vector>(create<AbstractInt>(), 3u), {1_a, 2_a, 3_a});
+ EXPECT_EQ(c.Element<AInt>(0), 1_a);
+ EXPECT_EQ(c.Element<AInt>(1), 2_a);
+ EXPECT_EQ(c.Element<AInt>(2), 3_a);
+ EXPECT_EQ(c.ElementCount(), 3u);
+}
+
+TEST_F(ConstantTest, Element_vec3_i32) {
+ Constant c(create<Vector>(create<I32>(), 3u), {1_a, 2_a, 3_a});
+ EXPECT_EQ(c.Element<i32>(0), 1_i);
+ EXPECT_EQ(c.Element<i32>(1), 2_i);
+ EXPECT_EQ(c.Element<i32>(2), 3_i);
+ EXPECT_EQ(c.ElementCount(), 3u);
+}
+
+TEST_F(ConstantTest, Element_vec3_u32) {
+ Constant c(create<Vector>(create<U32>(), 3u), {1_a, 2_a, 3_a});
+ EXPECT_EQ(c.Element<u32>(0), 1_u);
+ EXPECT_EQ(c.Element<u32>(1), 2_u);
+ EXPECT_EQ(c.Element<u32>(2), 3_u);
+ EXPECT_EQ(c.ElementCount(), 3u);
+}
+
+TEST_F(ConstantTest, Element_vec3_bool) {
+ Constant c(create<Vector>(create<Bool>(), 2u), {true, false});
+ EXPECT_EQ(c.Element<bool>(0), true);
+ EXPECT_EQ(c.Element<bool>(1), false);
+ EXPECT_EQ(c.ElementCount(), 2u);
+}
+
+TEST_F(ConstantTest, Element_vec3_af) {
+ Constant c(create<Vector>(create<AbstractFloat>(), 3u), {1.0_a, 2.0_a, 3.0_a});
+ EXPECT_EQ(c.Element<AFloat>(0), 1.0_a);
+ EXPECT_EQ(c.Element<AFloat>(1), 2.0_a);
+ EXPECT_EQ(c.Element<AFloat>(2), 3.0_a);
+ EXPECT_EQ(c.ElementCount(), 3u);
+}
+
+TEST_F(ConstantTest, Element_vec3_f32) {
+ Constant c(create<Vector>(create<F32>(), 3u), {1.0_a, 2.0_a, 3.0_a});
+ EXPECT_EQ(c.Element<f32>(0), 1.0_f);
+ EXPECT_EQ(c.Element<f32>(1), 2.0_f);
+ EXPECT_EQ(c.Element<f32>(2), 3.0_f);
+ EXPECT_EQ(c.ElementCount(), 3u);
+}
+
+TEST_F(ConstantTest, Element_vec3_f16) {
+ Constant c(create<Vector>(create<F16>(), 3u), {1.0_a, 2.0_a, 3.0_a});
+ EXPECT_EQ(c.Element<f16>(0), 1.0_h);
+ EXPECT_EQ(c.Element<f16>(1), 2.0_h);
+ EXPECT_EQ(c.Element<f16>(2), 3.0_h);
+ EXPECT_EQ(c.ElementCount(), 3u);
+}
+
+TEST_F(ConstantTest, Element_mat2x3_af) {
+ Constant c(create<Matrix>(create<Vector>(create<AbstractFloat>(), 3u), 2u),
+ {1.0_a, 2.0_a, 3.0_a, 4.0_a, 5.0_a, 6.0_a});
+ EXPECT_EQ(c.Element<AFloat>(0), 1.0_a);
+ EXPECT_EQ(c.Element<AFloat>(1), 2.0_a);
+ EXPECT_EQ(c.Element<AFloat>(2), 3.0_a);
+ EXPECT_EQ(c.Element<AFloat>(3), 4.0_a);
+ EXPECT_EQ(c.Element<AFloat>(4), 5.0_a);
+ EXPECT_EQ(c.Element<AFloat>(5), 6.0_a);
+ EXPECT_EQ(c.ElementCount(), 6u);
+}
+
+TEST_F(ConstantTest, Element_mat2x3_f32) {
+ Constant c(create<Matrix>(create<Vector>(create<F32>(), 3u), 2u),
+ {1.0_a, 2.0_a, 3.0_a, 4.0_a, 5.0_a, 6.0_a});
+ EXPECT_EQ(c.Element<f32>(0), 1.0_f);
+ EXPECT_EQ(c.Element<f32>(1), 2.0_f);
+ EXPECT_EQ(c.Element<f32>(2), 3.0_f);
+ EXPECT_EQ(c.Element<f32>(3), 4.0_f);
+ EXPECT_EQ(c.Element<f32>(4), 5.0_f);
+ EXPECT_EQ(c.Element<f32>(5), 6.0_f);
+ EXPECT_EQ(c.ElementCount(), 6u);
+}
+
+TEST_F(ConstantTest, Element_mat2x3_f16) {
+ Constant c(create<Matrix>(create<Vector>(create<F16>(), 3u), 2u),
+ {1.0_a, 2.0_a, 3.0_a, 4.0_a, 5.0_a, 6.0_a});
+ EXPECT_EQ(c.Element<f16>(0), 1.0_h);
+ EXPECT_EQ(c.Element<f16>(1), 2.0_h);
+ EXPECT_EQ(c.Element<f16>(2), 3.0_h);
+ EXPECT_EQ(c.Element<f16>(3), 4.0_h);
+ EXPECT_EQ(c.Element<f16>(4), 5.0_h);
+ EXPECT_EQ(c.Element<f16>(5), 6.0_h);
+ EXPECT_EQ(c.ElementCount(), 6u);
+}
+
+} // namespace
+} // namespace tint::sem
diff --git a/src/tint/transform/fold_constants.cc b/src/tint/transform/fold_constants.cc
index d51a68a..f268800 100644
--- a/src/tint/transform/fold_constants.cc
+++ b/src/tint/transform/fold_constants.cc
@@ -23,6 +23,7 @@
#include "src/tint/sem/expression.h"
#include "src/tint/sem/type_constructor.h"
#include "src/tint/sem/type_conversion.h"
+#include "src/tint/utils/transform.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::FoldConstants);
@@ -50,24 +51,40 @@
return nullptr;
}
- // If original ctor expression had no init values, don't replace the
- // expression
+ // If original ctor expression had no init values, don't replace the expression
if (call->Arguments().empty()) {
return nullptr;
}
- auto build_scalar = [&](sem::Constant::Scalar s) {
+ auto build_elements = [&](size_t limit) {
return Switch(
value.ElementType(), //
- [&](const sem::I32*) { return ctx.dst->Expr(i32(std::get<AInt>(s).value)); },
- [&](const sem::U32*) { return ctx.dst->Expr(u32(std::get<AInt>(s).value)); },
- [&](const sem::F32*) { return ctx.dst->Expr(f32(std::get<AFloat>(s).value)); },
- [&](const sem::Bool*) { return ctx.dst->Expr(std::get<bool>(s)); },
+ [&](const sem::Bool*) {
+ return utils::TransformN(value.IElements(), limit, [&](AInt i) {
+ return static_cast<const ast::Expression*>(
+ ctx.dst->Expr(static_cast<bool>(i.value)));
+ });
+ },
+ [&](const sem::I32*) {
+ return utils::TransformN(value.IElements(), limit, [&](AInt i) {
+ return static_cast<const ast::Expression*>(ctx.dst->Expr(i32(i.value)));
+ });
+ },
+ [&](const sem::U32*) {
+ return utils::TransformN(value.IElements(), limit, [&](AInt i) {
+ return static_cast<const ast::Expression*>(ctx.dst->Expr(u32(i.value)));
+ });
+ },
+ [&](const sem::F32*) {
+ return utils::TransformN(value.FElements(), limit, [&](AFloat f) {
+ return static_cast<const ast::Expression*>(ctx.dst->Expr(f32(f.value)));
+ });
+ },
[&](Default) {
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "unhandled Constant::Scalar type: "
<< value.ElementType()->FriendlyName(ctx.src->Symbols());
- return nullptr;
+ return ast::ExpressionList{};
});
};
@@ -78,17 +95,17 @@
// constructor args that the original node had, but after folding
// constants, cases like the following are problematic:
//
- // vec3<f32> = vec3<f32>(vec2<f32>, 1.0) // vec_size=3, ctor_size=2
+ // vec3<f32> = vec3<f32>(vec2<f32>(), 1.0) // vec_size=3, ctor_size=2
//
// In this case, creating a vec3 with 2 args is invalid, so we should
// create it with 3. So what we do is construct with vec_size args,
// except if the original vector was single-value initialized, in
// which case, we only construct with one arg again.
- uint32_t ctor_size = (call->Arguments().size() == 1) ? 1 : vec_size;
-
ast::ExpressionList ctors;
- for (uint32_t i = 0; i < ctor_size; ++i) {
- ctors.emplace_back(build_scalar(value.Elements()[i]));
+ if (call->Arguments().size() == 1) {
+ ctors = build_elements(1);
+ } else {
+ ctors = build_elements(value.ElementCount());
}
auto* el_ty = CreateASTTypeFor(ctx, vec->type());
@@ -96,7 +113,7 @@
}
if (ty->is_scalar()) {
- return build_scalar(value.Elements()[0]);
+ return build_elements(1)[0];
}
return nullptr;
diff --git a/src/tint/utils/compiler_macros.h b/src/tint/utils/compiler_macros.h
new file mode 100644
index 0000000..8b360ff
--- /dev/null
+++ b/src/tint/utils/compiler_macros.h
@@ -0,0 +1,39 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_UTILS_COMPILER_MACROS_H_
+#define SRC_TINT_UTILS_COMPILER_MACROS_H_
+
+#define TINT_REQUIRE_SEMICOLON \
+ do { \
+ } while (false)
+
+#if defined(_MSC_VER)
+// clang-format off
+#define TINT_BEGIN_DISABLE_WARNING_UNREACHABLE_CODE() \
+ __pragma(warning(push)) \
+ __pragma(warning(disable:4702)) \
+ TINT_REQUIRE_SEMICOLON
+#define TINT_END_DISABLE_WARNING_UNREACHABLE_CODE() \
+ __pragma(warning(pop)) \
+ TINT_REQUIRE_SEMICOLON
+// clang-format on
+#else
+// clang-format off
+#define TINT_BEGIN_DISABLE_WARNING_UNREACHABLE_CODE() TINT_REQUIRE_SEMICOLON
+#define TINT_END_DISABLE_WARNING_UNREACHABLE_CODE() TINT_REQUIRE_SEMICOLON
+// clang-format on
+#endif // defined(_MSC_VER)
+
+#endif // SRC_TINT_UTILS_COMPILER_MACROS_H_
diff --git a/src/tint/utils/result.h b/src/tint/utils/result.h
new file mode 100644
index 0000000..b2a69d5
--- /dev/null
+++ b/src/tint/utils/result.h
@@ -0,0 +1,103 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_UTILS_RESULT_H_
+#define SRC_TINT_UTILS_RESULT_H_
+
+#include <ostream>
+// TODO(https://crbug.com/dawn/1379) Update cpplint and remove NOLINT
+#include <variant> // NOLINT(build/include_order)
+
+namespace tint::utils {
+
+/// Empty structure used as the default FAILURE_TYPE for a Result.
+struct FailureType {};
+
+static constexpr const FailureType Failure;
+
+/// Result is a helper for functions that need to return a value, or an failure value.
+/// Result can be constructed with either a 'success' or 'failure' value.
+/// @tparam SUCCESS_TYPE the 'success' value type.
+/// @tparam FAILURE_TYPE the 'failure' value type. Defaults to FailureType which provides no
+/// information about the failure, except that something failed. Must not be the same type
+/// as SUCCESS_TYPE.
+template <typename SUCCESS_TYPE, typename FAILURE_TYPE = FailureType>
+struct Result {
+ static_assert(!std::is_same_v<SUCCESS_TYPE, FAILURE_TYPE>,
+ "Result must not have the same type for SUCCESS_TYPE and FAILURE_TYPE");
+
+ /// Constructor
+ /// @param success the success result
+ Result(const SUCCESS_TYPE& success) // NOLINT(runtime/explicit):
+ : value{success} {}
+
+ /// Constructor
+ /// @param failure the failure result
+ Result(const FAILURE_TYPE& failure) // NOLINT(runtime/explicit):
+ : value{failure} {}
+
+ /// @returns true if the result was a success
+ operator bool() const { return std::holds_alternative<SUCCESS_TYPE>(value); }
+
+ /// @returns true if the result was a failure
+ bool operator!() const { return std::holds_alternative<FAILURE_TYPE>(value); }
+
+ /// @returns the success value
+ /// @warning attempting to call this when the Result holds an failure will result in UB.
+ const SUCCESS_TYPE* operator->() const { return &std::get<SUCCESS_TYPE>(value); }
+
+ /// @returns the success value
+ /// @warning attempting to call this when the Result holds an failure value will result in UB.
+ const SUCCESS_TYPE& Get() const { return std::get<SUCCESS_TYPE>(value); }
+
+ /// @returns the failure value
+ /// @warning attempting to call this when the Result holds a success value will result in UB.
+ const FAILURE_TYPE& Failure() const { return std::get<FAILURE_TYPE>(value); }
+
+ /// Equality operator
+ /// @param val the value to compare this Result to
+ /// @returns true if this result holds a success value equal to `value`
+ bool operator==(SUCCESS_TYPE val) const {
+ if (auto* v = std::get_if<SUCCESS_TYPE>(&value)) {
+ return *v == val;
+ }
+ return false;
+ }
+
+ /// Equality operator
+ /// @param val the value to compare this Result to
+ /// @returns true if this result holds a failure value equal to `value`
+ bool operator==(FAILURE_TYPE val) const {
+ if (auto* v = std::get_if<FAILURE_TYPE>(&value)) {
+ return *v == val;
+ }
+ return false;
+ }
+
+ /// The result. Either a success of failure value.
+ std::variant<SUCCESS_TYPE, FAILURE_TYPE> value;
+};
+
+/// Writes the result to the ostream.
+/// @param out the std::ostream to write to
+/// @param res the result
+/// @return the std::ostream so calls can be chained
+template <typename SUCCESS, typename FAILURE>
+inline std::ostream& operator<<(std::ostream& out, Result<SUCCESS, FAILURE> res) {
+ return res ? (out << "success: " << res.Get()) : (out << "failure: " << res.Failure());
+}
+
+} // namespace tint::utils
+
+#endif // SRC_TINT_UTILS_RESULT_H_
diff --git a/src/tint/utils/result_test.cc b/src/tint/utils/result_test.cc
new file mode 100644
index 0000000..ce125f4
--- /dev/null
+++ b/src/tint/utils/result_test.cc
@@ -0,0 +1,55 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/utils/result.h"
+
+#include <string>
+
+#include "gmock/gmock.h"
+
+namespace tint::utils {
+namespace {
+
+TEST(ResultTest, SuccessInt) {
+ auto r = Result<int>(123);
+ EXPECT_TRUE(r);
+ EXPECT_FALSE(!r);
+ EXPECT_EQ(r.Get(), 123);
+}
+
+TEST(ResultTest, SuccessStruct) {
+ struct S {
+ int value;
+ };
+ auto r = Result<S>({123});
+ EXPECT_TRUE(r);
+ EXPECT_FALSE(!r);
+ EXPECT_EQ(r->value, 123);
+}
+
+TEST(ResultTest, Failure) {
+ auto r = Result<int>(Failure);
+ EXPECT_FALSE(r);
+ EXPECT_TRUE(!r);
+}
+
+TEST(ResultTest, CustomFailure) {
+ auto r = Result<int, std::string>("oh noes!");
+ EXPECT_FALSE(r);
+ EXPECT_TRUE(!r);
+ EXPECT_EQ(r.Failure(), "oh noes!");
+}
+
+} // namespace
+} // namespace tint::utils
diff --git a/src/tint/utils/transform.h b/src/tint/utils/transform.h
index 2cd9481..ff92530 100644
--- a/src/tint/utils/transform.h
+++ b/src/tint/utils/transform.h
@@ -27,8 +27,7 @@
/// Transform performs an element-wise transformation of a vector.
/// @param in the input vector.
/// @param transform the transformation function with signature: `OUT(IN)`
-/// @returns a new vector with each element of the source vector transformed by
-/// `transform`.
+/// @returns a new vector with each element of the source vector transformed by `transform`.
template <typename IN, typename TRANSFORMER>
auto Transform(const std::vector<IN>& in, TRANSFORMER&& transform)
-> std::vector<decltype(transform(in[0]))> {
@@ -41,10 +40,8 @@
/// Transform performs an element-wise transformation of a vector.
/// @param in the input vector.
-/// @param transform the transformation function with signature:
-/// `OUT(IN, size_t)`
-/// @returns a new vector with each element of the source vector transformed by
-/// `transform`.
+/// @param transform the transformation function with signature: `OUT(IN, size_t)`
+/// @returns a new vector with each element of the source vector transformed by `transform`.
template <typename IN, typename TRANSFORMER>
auto Transform(const std::vector<IN>& in, TRANSFORMER&& transform)
-> std::vector<decltype(transform(in[0], 1u))> {
@@ -55,6 +52,40 @@
return result;
}
+/// TransformN performs an element-wise transformation of a vector, transforming and returning at
+/// most `n` elements.
+/// @param in the input vector.
+/// @param n the maximum number of elements to transform.
+/// @param transform the transformation function with signature: `OUT(IN)`
+/// @returns a new vector with at most n-elements of the source vector transformed by `transform`.
+template <typename IN, typename TRANSFORMER>
+auto TransformN(const std::vector<IN>& in, size_t n, TRANSFORMER&& transform)
+ -> std::vector<decltype(transform(in[0]))> {
+ const auto count = std::min(n, in.size());
+ std::vector<decltype(transform(in[0]))> result(count);
+ for (size_t i = 0; i < count; ++i) {
+ result[i] = transform(in[i]);
+ }
+ return result;
+}
+
+/// TransformN performs an element-wise transformation of a vector, transforming and returning at
+/// most `n` elements.
+/// @param in the input vector.
+/// @param n the maximum number of elements to transform.
+/// @param transform the transformation function with signature: `OUT(IN, size_t)`
+/// @returns a new vector with at most n-elements of the source vector transformed by `transform`.
+template <typename IN, typename TRANSFORMER>
+auto TransformN(const std::vector<IN>& in, size_t n, TRANSFORMER&& transform)
+ -> std::vector<decltype(transform(in[0], 1u))> {
+ const auto count = std::min(n, in.size());
+ std::vector<decltype(transform(in[0], 1u))> result(count);
+ for (size_t i = 0; i < count; ++i) {
+ result[i] = transform(in[i], i);
+ }
+ return result;
+}
+
} // namespace tint::utils
#endif // SRC_TINT_UTILS_TRANSFORM_H_
diff --git a/src/tint/utils/transform_test.cc b/src/tint/utils/transform_test.cc
index af8b832..89c0756 100644
--- a/src/tint/utils/transform_test.cc
+++ b/src/tint/utils/transform_test.cc
@@ -30,7 +30,7 @@
const std::vector<int> empty{};
{
auto transformed = Transform(empty, [](int) -> int {
- [] { FAIL() << "Transform should not be called for empty vector"; }();
+ [] { FAIL() << "Callback should not be called for empty vector"; }();
return 0;
});
CHECK_ELEMENT_TYPE(transformed, int);
@@ -38,7 +38,7 @@
}
{
auto transformed = Transform(empty, [](int, size_t) -> int {
- [] { FAIL() << "Transform should not be called for empty vector"; }();
+ [] { FAIL() << "Callback should not be called for empty vector"; }();
return 0;
});
CHECK_ELEMENT_TYPE(transformed, int);
@@ -48,16 +48,16 @@
TEST(TransformTest, Identity) {
const std::vector<int> input{1, 2, 3, 4};
- {
- auto transformed = Transform(input, [](int i) { return i; });
- CHECK_ELEMENT_TYPE(transformed, int);
- EXPECT_THAT(transformed, testing::ElementsAre(1, 2, 3, 4));
- }
- {
- auto transformed = Transform(input, [](int i, size_t) { return i; });
- CHECK_ELEMENT_TYPE(transformed, int);
- EXPECT_THAT(transformed, testing::ElementsAre(1, 2, 3, 4));
- }
+ auto transformed = Transform(input, [](int i) { return i; });
+ CHECK_ELEMENT_TYPE(transformed, int);
+ EXPECT_THAT(transformed, testing::ElementsAre(1, 2, 3, 4));
+}
+
+TEST(TransformTest, IdentityWithIndex) {
+ const std::vector<int> input{1, 2, 3, 4};
+ auto transformed = Transform(input, [](int i, size_t) { return i; });
+ CHECK_ELEMENT_TYPE(transformed, int);
+ EXPECT_THAT(transformed, testing::ElementsAre(1, 2, 3, 4));
}
TEST(TransformTest, Index) {
@@ -87,5 +87,135 @@
}
}
+TEST(TransformNTest, Empty) {
+ const std::vector<int> empty{};
+ {
+ auto transformed = TransformN(empty, 4u, [](int) -> int {
+ [] { FAIL() << "Callback should not be called for empty vector"; }();
+ return 0;
+ });
+ CHECK_ELEMENT_TYPE(transformed, int);
+ EXPECT_EQ(transformed.size(), 0u);
+ }
+ {
+ auto transformed = TransformN(empty, 4u, [](int, size_t) -> int {
+ [] { FAIL() << "Callback should not be called for empty vector"; }();
+ return 0;
+ });
+ CHECK_ELEMENT_TYPE(transformed, int);
+ EXPECT_EQ(transformed.size(), 0u);
+ }
+}
+
+TEST(TransformNTest, Identity) {
+ const std::vector<int> input{1, 2, 3, 4};
+ {
+ auto transformed = TransformN(input, 0u, [](int) {
+ [] { FAIL() << "Callback should not call the transform when n == 0"; }();
+ return 0;
+ });
+ CHECK_ELEMENT_TYPE(transformed, int);
+ EXPECT_TRUE(transformed.empty());
+ }
+ {
+ auto transformed = TransformN(input, 2u, [](int i) { return i; });
+ CHECK_ELEMENT_TYPE(transformed, int);
+ EXPECT_THAT(transformed, testing::ElementsAre(1, 2));
+ }
+ {
+ auto transformed = TransformN(input, 6u, [](int i) { return i; });
+ CHECK_ELEMENT_TYPE(transformed, int);
+ EXPECT_THAT(transformed, testing::ElementsAre(1, 2, 3, 4));
+ }
+}
+
+TEST(TransformNTest, IdentityWithIndex) {
+ const std::vector<int> input{1, 2, 3, 4};
+ {
+ auto transformed = TransformN(input, 0u, [](int, size_t) {
+ [] { FAIL() << "Callback should not call the transform when n == 0"; }();
+ return 0;
+ });
+ CHECK_ELEMENT_TYPE(transformed, int);
+ EXPECT_TRUE(transformed.empty());
+ }
+ {
+ auto transformed = TransformN(input, 3u, [](int i, size_t) { return i; });
+ CHECK_ELEMENT_TYPE(transformed, int);
+ EXPECT_THAT(transformed, testing::ElementsAre(1, 2, 3));
+ }
+ {
+ auto transformed = TransformN(input, 9u, [](int i, size_t) { return i; });
+ CHECK_ELEMENT_TYPE(transformed, int);
+ EXPECT_THAT(transformed, testing::ElementsAre(1, 2, 3, 4));
+ }
+}
+
+TEST(TransformNTest, Index) {
+ const std::vector<int> input{10, 20, 30, 40};
+ {
+ auto transformed = TransformN(input, 0u, [](int, size_t) {
+ [] { FAIL() << "Callback should not call the transform when n == 0"; }();
+ return static_cast<size_t>(0);
+ });
+ CHECK_ELEMENT_TYPE(transformed, size_t);
+ EXPECT_TRUE(transformed.empty());
+ }
+ {
+ auto transformed = TransformN(input, 2u, [](int, size_t idx) { return idx; });
+ CHECK_ELEMENT_TYPE(transformed, size_t);
+ EXPECT_THAT(transformed, testing::ElementsAre(0u, 1u));
+ }
+ {
+ auto transformed = TransformN(input, 9u, [](int, size_t idx) { return idx; });
+ CHECK_ELEMENT_TYPE(transformed, size_t);
+ EXPECT_THAT(transformed, testing::ElementsAre(0u, 1u, 2u, 3u));
+ }
+}
+
+TEST(TransformNTest, TransformSameType) {
+ const std::vector<int> input{1, 2, 3, 4};
+ {
+ auto transformed = TransformN(input, 0u, [](int, size_t) {
+ [] { FAIL() << "Callback should not call the transform when n == 0"; }();
+ return 0;
+ });
+ CHECK_ELEMENT_TYPE(transformed, int);
+ EXPECT_TRUE(transformed.empty());
+ }
+ {
+ auto transformed = TransformN(input, 2u, [](int i) { return i * 10; });
+ CHECK_ELEMENT_TYPE(transformed, int);
+ EXPECT_THAT(transformed, testing::ElementsAre(10, 20));
+ }
+ {
+ auto transformed = TransformN(input, 9u, [](int i) { return i * 10; });
+ CHECK_ELEMENT_TYPE(transformed, int);
+ EXPECT_THAT(transformed, testing::ElementsAre(10, 20, 30, 40));
+ }
+}
+
+TEST(TransformNTest, TransformDifferentType) {
+ const std::vector<int> input{1, 2, 3, 4};
+ {
+ auto transformed = TransformN(input, 0u, [](int) {
+ [] { FAIL() << "Callback should not call the transform when n == 0"; }();
+ return std::string();
+ });
+ CHECK_ELEMENT_TYPE(transformed, std::string);
+ EXPECT_TRUE(transformed.empty());
+ }
+ {
+ auto transformed = TransformN(input, 2u, [](int i) { return std::to_string(i); });
+ CHECK_ELEMENT_TYPE(transformed, std::string);
+ EXPECT_THAT(transformed, testing::ElementsAre("1", "2"));
+ }
+ {
+ auto transformed = TransformN(input, 9u, [](int i) { return std::to_string(i); });
+ CHECK_ELEMENT_TYPE(transformed, std::string);
+ EXPECT_THAT(transformed, testing::ElementsAre("1", "2", "3", "4"));
+ }
+}
+
} // namespace
} // namespace tint::utils
diff --git a/src/tint/writer/glsl/generator_impl.cc b/src/tint/writer/glsl/generator_impl.cc
index 1cf7480..ccce06a 100644
--- a/src/tint/writer/glsl/generator_impl.cc
+++ b/src/tint/writer/glsl/generator_impl.cc
@@ -2183,13 +2183,14 @@
return true;
},
[&](const ast::FloatLiteralExpression* l) {
- if (std::isinf(l->value)) {
+ auto f32 = static_cast<float>(l->value);
+ if (std::isinf(f32)) {
out << (l->value >= 0 ? "uintBitsToFloat(0x7f800000u)"
: "uintBitsToFloat(0xff800000u)");
} else if (std::isnan(l->value)) {
out << "uintBitsToFloat(0x7fc00000u)";
} else {
- out << FloatToString(static_cast<float>(l->value)) << "f";
+ out << FloatToString(f32) << "f";
}
return true;
},
diff --git a/src/tint/writer/hlsl/generator_impl.cc b/src/tint/writer/hlsl/generator_impl.cc
index affaf5e..d6a5fa7 100644
--- a/src/tint/writer/hlsl/generator_impl.cc
+++ b/src/tint/writer/hlsl/generator_impl.cc
@@ -639,7 +639,6 @@
bool GeneratorImpl::EmitExpressionOrOneIfZero(std::ostream& out, const ast::Expression* expr) {
// For constants, replace literal 0 with 1.
- sem::Constant::Scalars elems;
if (const auto& val = builder_.Sem().Get(expr)->ConstantValue()) {
if (!val.AnyZero()) {
return EmitExpression(out, expr);
@@ -657,7 +656,7 @@
}
out << "(";
- for (size_t i = 0; i < val.Elements().size(); ++i) {
+ for (size_t i = 0; i < val.ElementCount(); ++i) {
if (i != 0) {
out << ", ";
}
@@ -3140,13 +3139,14 @@
out << (l->value ? "true" : "false");
return true;
},
- [&](const ast::FloatLiteralExpression* fl) {
- if (std::isinf(fl->value)) {
- out << (fl->value >= 0 ? "asfloat(0x7f800000u)" : "asfloat(0xff800000u)");
- } else if (std::isnan(fl->value)) {
+ [&](const ast::FloatLiteralExpression* l) {
+ auto f32 = static_cast<float>(l->value);
+ if (std::isinf(f32)) {
+ out << (f32 >= 0 ? "asfloat(0x7f800000u)" : "asfloat(0xff800000u)");
+ } else if (std::isnan(f32)) {
out << "asfloat(0x7fc00000u)";
} else {
- out << FloatToString(static_cast<float>(fl->value)) << "f";
+ out << FloatToString(f32) << "f";
}
return true;
},
diff --git a/src/tint/writer/msl/generator_impl.cc b/src/tint/writer/msl/generator_impl.cc
index be94e2c..54d9164 100644
--- a/src/tint/writer/msl/generator_impl.cc
+++ b/src/tint/writer/msl/generator_impl.cc
@@ -1544,12 +1544,13 @@
return true;
},
[&](const ast::FloatLiteralExpression* l) {
- if (std::isinf(l->value)) {
- out << (l->value >= 0 ? "INFINITY" : "-INFINITY");
- } else if (std::isnan(l->value)) {
+ auto f32 = static_cast<float>(l->value);
+ if (std::isinf(f32)) {
+ out << (f32 >= 0 ? "INFINITY" : "-INFINITY");
+ } else if (std::isnan(f32)) {
out << "NAN";
} else {
- out << FloatToString(static_cast<float>(l->value)) << "f";
+ out << FloatToString(f32) << "f";
}
return true;
},
diff --git a/src/tint/writer/spirv/builder.cc b/src/tint/writer/spirv/builder.cc
index 5f8072b..e2d1fac 100644
--- a/src/tint/writer/spirv/builder.cc
+++ b/src/tint/writer/spirv/builder.cc
@@ -368,7 +368,11 @@
}
}
-bool Builder::GenerateExtension(ast::Extension) {
+void Builder::push_extension(const char* extension) {
+ extensions_.push_back(Instruction{spv::Op::OpExtension, {Operand(extension)}});
+}
+
+bool Builder::GenerateExtension(ast::Extension extension) {
/*
For each supported extension, push corresponding capability into the builder.
For example:
@@ -379,6 +383,15 @@
push_capability(SpvCapabilityStorageInputOutput16);
}
*/
+ switch (extension) {
+ case ast::Extension::kChromiumExperimentalDP4a:
+ push_extension("SPV_KHR_integer_dot_product");
+ push_capability(SpvCapabilityDotProductKHR);
+ push_capability(SpvCapabilityDotProductInput4x8BitPackedKHR);
+ break;
+ default:
+ return false;
+ }
return true;
}
@@ -924,7 +937,7 @@
Operand(result_type_id),
extract,
Operand(info->source_id),
- Operand(idx_constval.ElementAs<uint32_t>(0)),
+ Operand(idx_constval.Element<uint32_t>(0)),
})) {
return false;
}
@@ -2494,6 +2507,30 @@
glsl_std450(GLSLstd450SAbs);
}
break;
+ case BuiltinType::kDot4I8Packed: {
+ auto first_param_id = get_arg_as_value_id(0);
+ auto second_param_id = get_arg_as_value_id(1);
+ if (!push_function_inst(spv::Op::OpSDotKHR,
+ {Operand(result_type_id), result, Operand(first_param_id),
+ Operand(second_param_id),
+ Operand(static_cast<uint32_t>(
+ spv::PackedVectorFormat::PackedVectorFormat4x8BitKHR))})) {
+ return 0;
+ }
+ return result_id;
+ }
+ case BuiltinType::kDot4U8Packed: {
+ auto first_param_id = get_arg_as_value_id(0);
+ auto second_param_id = get_arg_as_value_id(1);
+ if (!push_function_inst(spv::Op::OpUDotKHR,
+ {Operand(result_type_id), result, Operand(first_param_id),
+ Operand(second_param_id),
+ Operand(static_cast<uint32_t>(
+ spv::PackedVectorFormat::PackedVectorFormat4x8BitKHR))})) {
+ return 0;
+ }
+ return result_id;
+ }
default: {
auto inst_id = builtin_to_glsl_method(builtin);
if (inst_id == 0) {
diff --git a/src/tint/writer/spirv/builder.h b/src/tint/writer/spirv/builder.h
index 1745ed5..fc2fa13 100644
--- a/src/tint/writer/spirv/builder.h
+++ b/src/tint/writer/spirv/builder.h
@@ -113,11 +113,8 @@
/// @returns the capabilities
const InstructionList& capabilities() const { return capabilities_; }
/// Adds an instruction to the extensions
- /// @param op the op to set
- /// @param operands the operands for the instruction
- void push_extension(spv::Op op, const OperandList& operands) {
- extensions_.push_back(Instruction{op, operands});
- }
+ /// @param extension the name of the extension
+ void push_extension(const char* extension);
/// @returns the extensions
const InstructionList& extensions() const { return extensions_; }
/// Adds an instruction to the ext import
diff --git a/src/tint/writer/spirv/builder_builtin_test.cc b/src/tint/writer/spirv/builder_builtin_test.cc
index 6d1316e..59a567a 100644
--- a/src/tint/writer/spirv/builder_builtin_test.cc
+++ b/src/tint/writer/spirv/builder_builtin_test.cc
@@ -2601,5 +2601,80 @@
)");
}
+TEST_F(BuiltinBuilderTest, Call_Dot4I8Packed) {
+ auto* ext =
+ create<ast::Enable>(Source{Source::Range{Source::Location{10, 2}, Source::Location{10, 5}}},
+ ast::Extension::kChromiumExperimentalDP4a);
+ AST().AddEnable(ext);
+
+ auto* val1 = Var("val1", ty.u32());
+ auto* val2 = Var("val2", ty.u32());
+ auto* call = Call("dot4I8Packed", val1, val2);
+ auto* func = WrapInFunction(val1, val2, call);
+
+ spirv::Builder& b = Build();
+
+ ASSERT_TRUE(b.GenerateFunction(func)) << b.error();
+
+ EXPECT_EQ(DumpBuilder(b), R"(OpEntryPoint GLCompute %3 "test_function"
+OpExecutionMode %3 LocalSize 1 1 1
+OpName %3 "test_function"
+OpName %5 "val1"
+OpName %9 "val2"
+%2 = OpTypeVoid
+%1 = OpTypeFunction %2
+%7 = OpTypeInt 32 0
+%6 = OpTypePointer Function %7
+%8 = OpConstantNull %7
+%11 = OpTypeInt 32 1
+%3 = OpFunction %2 None %1
+%4 = OpLabel
+%5 = OpVariable %6 Function %8
+%9 = OpVariable %6 Function %8
+%12 = OpLoad %7 %5
+%13 = OpLoad %7 %9
+%10 = OpSDot %11 %12 %13 PackedVectorFormat4x8Bit
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(BuiltinBuilderTest, Call_Dot4U8Packed) {
+ auto* ext =
+ create<ast::Enable>(Source{Source::Range{Source::Location{10, 2}, Source::Location{10, 5}}},
+ ast::Extension::kChromiumExperimentalDP4a);
+ AST().AddEnable(ext);
+
+ auto* val1 = Var("val1", ty.u32());
+ auto* val2 = Var("val2", ty.u32());
+ auto* call = Call("dot4U8Packed", val1, val2);
+ auto* func = WrapInFunction(val1, val2, call);
+
+ spirv::Builder& b = Build();
+
+ ASSERT_TRUE(b.GenerateFunction(func)) << b.error();
+
+ EXPECT_EQ(DumpBuilder(b), R"(OpEntryPoint GLCompute %3 "test_function"
+OpExecutionMode %3 LocalSize 1 1 1
+OpName %3 "test_function"
+OpName %5 "val1"
+OpName %9 "val2"
+%2 = OpTypeVoid
+%1 = OpTypeFunction %2
+%7 = OpTypeInt 32 0
+%6 = OpTypePointer Function %7
+%8 = OpConstantNull %7
+%3 = OpFunction %2 None %1
+%4 = OpLabel
+%5 = OpVariable %6 Function %8
+%9 = OpVariable %6 Function %8
+%11 = OpLoad %7 %5
+%12 = OpLoad %7 %9
+%10 = OpUDot %7 %11 %12 PackedVectorFormat4x8Bit
+OpReturn
+OpFunctionEnd
+)");
+}
+
} // namespace
} // namespace tint::writer::spirv
diff --git a/src/tint/writer/spirv/builder_test.cc b/src/tint/writer/spirv/builder_test.cc
index 3548f9a..24d5b72 100644
--- a/src/tint/writer/spirv/builder_test.cc
+++ b/src/tint/writer/spirv/builder_test.cc
@@ -49,5 +49,13 @@
EXPECT_EQ(DumpInstructions(b.capabilities()), "OpCapability Shader\n");
}
+TEST_F(BuilderTest, DeclareExtension) {
+ spirv::Builder& b = Build();
+
+ b.push_extension("SPV_KHR_integer_dot_product");
+
+ EXPECT_EQ(DumpInstructions(b.extensions()), "OpExtension \"SPV_KHR_integer_dot_product\"\n");
+}
+
} // namespace
} // namespace tint::writer::spirv
diff --git a/test/tint/BUILD.gn b/test/tint/BUILD.gn
index 4f1619f..7acdea9 100644
--- a/test/tint/BUILD.gn
+++ b/test/tint/BUILD.gn
@@ -289,6 +289,7 @@
"../../src/tint/sem/atomic_test.cc",
"../../src/tint/sem/bool_test.cc",
"../../src/tint/sem/builtin_test.cc",
+ "../../src/tint/sem/constant_test.cc",
"../../src/tint/sem/depth_multisampled_texture_test.cc",
"../../src/tint/sem/depth_texture_test.cc",
"../../src/tint/sem/expression_test.cc",
@@ -375,6 +376,7 @@
"../../src/tint/utils/io/tmpfile_test.cc",
"../../src/tint/utils/map_test.cc",
"../../src/tint/utils/math_test.cc",
+ "../../src/tint/utils/result_test.cc",
"../../src/tint/utils/reverse_test.cc",
"../../src/tint/utils/scoped_assignment_test.cc",
"../../src/tint/utils/string_test.cc",
@@ -726,6 +728,7 @@
"../../src/tint/clone_context_test.cc",
"../../src/tint/debug_test.cc",
"../../src/tint/demangler_test.cc",
+ "../../src/tint/number_test.cc",
"../../src/tint/program_builder_test.cc",
"../../src/tint/program_test.cc",
"../../src/tint/scope_stack_test.cc",