tint: Fix const eval of type conversions
These were quite spectacularly broken.
Also:
* Fix the definition of 'scalar' in `intrinsics.def`. This was partly why conversions were broken: abstract-typed values were materialized before reaching the conversion builtin, when they should not have been.
* Implement the `ScalarArgsFrom()` helper in `const_eval_test.cc`. It is used by the new conversion tests, and it also implements part of the improvement suggested in tint:1709.
Fixed: tint:1707
Bug: tint:1709
Change-Id: Iab962b671305e868f92710912d2ed07e3338c680
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/105261
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@chromium.org>
diff --git a/src/tint/resolver/resolver_test_helper.h b/src/tint/resolver/resolver_test_helper.h
index 501cd0f..923f3ff 100644
--- a/src/tint/resolver/resolver_test_helper.h
+++ b/src/tint/resolver/resolver_test_helper.h
@@ -17,6 +17,7 @@
#include <functional>
#include <memory>
+#include <ostream>
#include <string>
#include <tuple>
#include <utility>
@@ -174,13 +175,15 @@
struct ptr {};
/// Type used to accept scalars as arguments. Can be either a single value that gets splatted for
-/// composite types, or all values requried by the composite type.
+/// composite types, or all values required by the composite type.
struct ScalarArgs {
/// Constructor
+ ScalarArgs() = default;
+
+ /// Constructor
/// @param single_value single value to initialize with
template <typename T>
- ScalarArgs(T single_value) // NOLINT: implicit on purpose
- : values(utils::Vector<Storage, 1>{single_value}) {}
+ explicit ScalarArgs(T single_value) : values(utils::Vector<Storage, 1>{single_value}) {}
/// Constructor
/// @param all_values all values to initialize the composite type with
@@ -192,6 +195,10 @@
}
}
+ /// @param other the other ScalarArgs to compare against
+ /// @returns true if all values are equal to the values in @p other
+ bool operator==(const ScalarArgs& other) const { return values == other.values; }
+
/// Valid scalar types for args
using Storage = std::variant<i32, u32, f32, f16, AInt, AFloat, bool>;
@@ -199,10 +206,28 @@
utils::Vector<Storage, 16> values;
};
+/// @param o the std::ostream to write to
+/// @param args the ScalarArgs
+/// @return the std::ostream so calls can be chained
+inline std::ostream& operator<<(std::ostream& o, const ScalarArgs& args) {
+ o << "[";
+ bool first = true;
+ for (auto& val : args.values) {
+ if (!first) {
+ o << ", ";
+ }
+ first = false;
+ std::visit([&](auto&& v) { o << v; }, val);
+ }
+ o << "]";
+ return o;
+}
+
using ast_type_func_ptr = const ast::Type* (*)(ProgramBuilder& b);
using ast_expr_func_ptr = const ast::Expression* (*)(ProgramBuilder& b, ScalarArgs args);
using ast_expr_from_double_func_ptr = const ast::Expression* (*)(ProgramBuilder& b, double v);
using sem_type_func_ptr = const sem::Type* (*)(ProgramBuilder& b);
+using type_name_func_ptr = std::string (*)();
template <typename T>
struct DataType {};
@@ -241,7 +266,7 @@
/// @param v arg of type double that will be cast to bool.
/// @return a new AST expression of the bool type
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
- return Expr(b, static_cast<ElementType>(v));
+ return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
}
/// @returns the WGSL name for the type
static inline std::string Name() { return "bool"; }
@@ -272,7 +297,7 @@
/// @param v arg of type double that will be cast to i32.
/// @return a new AST i32 literal value expression
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
- return Expr(b, static_cast<ElementType>(v));
+ return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
}
/// @returns the WGSL name for the type
static inline std::string Name() { return "i32"; }
@@ -303,7 +328,7 @@
/// @param v arg of type double that will be cast to u32.
/// @return a new AST u32 literal value expression
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
- return Expr(b, static_cast<ElementType>(v));
+ return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
}
/// @returns the WGSL name for the type
static inline std::string Name() { return "u32"; }
@@ -334,7 +359,7 @@
/// @param v arg of type double that will be cast to f32.
/// @return a new AST f32 literal value expression
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
- return Expr(b, static_cast<f32>(v));
+ return Expr(b, ScalarArgs{static_cast<f32>(v)});
}
/// @returns the WGSL name for the type
static inline std::string Name() { return "f32"; }
@@ -365,7 +390,7 @@
/// @param v arg of type double that will be cast to f16.
/// @return a new AST f16 literal value expression
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
- return Expr(b, static_cast<ElementType>(v));
+ return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
}
/// @returns the WGSL name for the type
static inline std::string Name() { return "f16"; }
@@ -395,7 +420,7 @@
/// @param v arg of type double that will be cast to AFloat.
/// @return a new AST abstract-float literal value expression
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
- return Expr(b, static_cast<ElementType>(v));
+ return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
}
/// @returns the WGSL name for the type
static inline std::string Name() { return "abstract-float"; }
@@ -425,7 +450,7 @@
/// @param v arg of type double that will be cast to AInt.
/// @return a new AST abstract-int literal value expression
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
- return Expr(b, static_cast<ElementType>(v));
+ return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
}
/// @returns the WGSL name for the type
static inline std::string Name() { return "abstract-int"; }
@@ -463,7 +488,7 @@
const bool one_value = args.values.Length() == 1;
utils::Vector<const ast::Expression*, N> r;
for (size_t i = 0; i < N; ++i) {
- r.Push(DataType<T>::Expr(b, one_value ? args.values[0] : args.values[i]));
+ r.Push(DataType<T>::Expr(b, ScalarArgs{one_value ? args.values[0] : args.values[i]}));
}
return r;
}
@@ -471,7 +496,7 @@
/// @param v arg of type double that will be cast to ElementType
/// @return a new AST vector value expression
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
- return Expr(b, static_cast<ElementType>(v));
+ return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
}
/// @returns the WGSL name for the type
static inline std::string Name() {
@@ -514,7 +539,7 @@
utils::Vector<const ast::Expression*, N> r;
for (uint32_t i = 0; i < N; ++i) {
if (one_value) {
- r.Push(DataType<vec<M, T>>::Expr(b, args.values[0]));
+ r.Push(DataType<vec<M, T>>::Expr(b, ScalarArgs{args.values[0]}));
} else {
utils::Vector<T, M> v;
for (size_t j = 0; j < M; ++j) {
@@ -529,7 +554,7 @@
/// @param v arg of type double that will be cast to ElementType
/// @return a new AST matrix value expression
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
- return Expr(b, static_cast<ElementType>(v));
+ return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
}
/// @returns the WGSL name for the type
static inline std::string Name() {
@@ -585,7 +610,7 @@
/// @param v arg of type double that will be cast to ElementType
/// @return a new AST expression of the alias type
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
- return Expr(b, static_cast<ElementType>(v));
+ return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
}
/// @returns the WGSL name for the type
@@ -626,7 +651,7 @@
/// @param v arg of type double that will be cast to ElementType
/// @return a new AST expression of the pointer type
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
- return Expr(b, static_cast<ElementType>(v));
+ return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
}
/// @returns the WGSL name for the type
@@ -680,7 +705,7 @@
const bool one_value = args.values.Length() == 1;
utils::Vector<const ast::Expression*, N> r;
for (uint32_t i = 0; i < N; i++) {
- r.Push(DataType<T>::Expr(b, one_value ? args.values[0] : args.values[i]));
+ r.Push(DataType<T>::Expr(b, ScalarArgs{one_value ? args.values[0] : args.values[i]}));
}
return r;
}
@@ -688,7 +713,7 @@
/// @param v arg of type double that will be cast to ElementType
/// @return a new AST array value expression
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
- return Expr(b, static_cast<ElementType>(v));
+ return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
}
/// @returns the WGSL name for the type
static inline std::string Name() {
@@ -706,13 +731,23 @@
ast_expr_from_double_func_ptr expr_from_double;
/// sem type create function
sem_type_func_ptr sem;
+ /// type name function
+ type_name_func_ptr name;
};
+/// @param o the std::ostream to write to
+/// @param ptrs the CreatePtrs
+/// @return the std::ostream so calls can be chained
+inline std::ostream& operator<<(std::ostream& o, const CreatePtrs& ptrs) {
+ return o << (ptrs.name ? ptrs.name() : "<unknown>");
+}
+
/// Returns a CreatePtrs struct instance with all creation pointer types for
/// type `T`
template <typename T>
constexpr CreatePtrs CreatePtrsFor() {
- return {DataType<T>::AST, DataType<T>::Expr, DataType<T>::ExprFromDouble, DataType<T>::Sem};
+ return {DataType<T>::AST, DataType<T>::Expr, DataType<T>::ExprFromDouble, DataType<T>::Sem,
+ DataType<T>::Name};
}
/// Value<T> is an instance of a value of type DataType<T>. Useful for storing values to create
@@ -729,15 +764,15 @@
/// Creates a Value<T> with `args`
/// @param args the args that will be passed to the expression
/// @returns a Value<T>
- static Value Create(ScalarArgs args) { return Value{DataType::Expr, std::move(args)}; }
+ static Value Create(ScalarArgs args) { return Value{CreatePtrsFor<T>(), std::move(args)}; }
/// Creates an `ast::Expression` for the type T passing in previously stored args
/// @param b the ProgramBuilder
/// @returns an expression node
- const ast::Expression* Expr(ProgramBuilder& b) const { return (*expr)(b, args); }
+ const ast::Expression* Expr(ProgramBuilder& b) const { return (*create.expr)(b, args); }
- /// ast expression type create function
- ast_expr_func_ptr expr;
+ /// functions to create values / types of the value
+ CreatePtrs create;
/// args to create expression with
ScalarArgs args;
};
@@ -764,7 +799,7 @@
/// Creates a `Value<T>` from a scalar `v`
template <typename T>
auto Val(T v) {
- return Value<T>::Create(v);
+ return Value<T>::Create(ScalarArgs{v});
}
/// Creates a `Value<vec<N, T>>` from N scalar `args`