tint/resolver: Further simplify test const eval framework
Replace ScalarArgs struct with Scalar variant and vector.
Fold ValueBase and ConcreteValue into Value.
Change-Id: I5cc5811a87f1aae162feb65fb6b1ecdac033d0fe
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/111761
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/resolver/const_eval_conversion_test.cc b/src/tint/resolver/const_eval_conversion_test.cc
index da37f3b..8640a1d 100644
--- a/src/tint/resolver/const_eval_conversion_test.cc
+++ b/src/tint/resolver/const_eval_conversion_test.cc
@@ -94,12 +94,12 @@
ASSERT_NE(sem->ConstantValue(), nullptr);
EXPECT_TYPE(sem->ConstantValue()->Type(), target_sem_ty);
- auto expected_values = expected.Args();
+ auto expected_values = expected.args;
if (kind == Kind::kVector) {
- expected_values.values.Push(expected_values.values[0]);
- expected_values.values.Push(expected_values.values[0]);
+ expected_values.Push(expected_values[0]);
+ expected_values.Push(expected_values[0]);
}
- auto got_values = ScalarArgsFrom(sem->ConstantValue());
+ auto got_values = ScalarsFrom(sem->ConstantValue());
EXPECT_EQ(expected_values, got_values);
}
}
diff --git a/src/tint/resolver/const_eval_test.h b/src/tint/resolver/const_eval_test.h
index dcb91d0..0d27805 100644
--- a/src/tint/resolver/const_eval_test.h
+++ b/src/tint/resolver/const_eval_test.h
@@ -37,28 +37,29 @@
inline const auto k3PiOver4 = T(UnwrapNumber<T>(2.356194490192344928846));
/// Walks the sem::Constant @p c, accumulating all the inner-most scalar values into @p args
-inline void CollectScalarArgs(const sem::Constant* c, builder::ScalarArgs& args) {
+template <size_t N>
+inline void CollectScalars(const sem::Constant* c, utils::Vector<builder::Scalar, N>& scalars) {
Switch(
c->Type(), //
- [&](const sem::AbstractInt*) { args.values.Push(c->As<AInt>()); },
- [&](const sem::AbstractFloat*) { args.values.Push(c->As<AFloat>()); },
- [&](const sem::Bool*) { args.values.Push(c->As<bool>()); },
- [&](const sem::I32*) { args.values.Push(c->As<i32>()); },
- [&](const sem::U32*) { args.values.Push(c->As<u32>()); },
- [&](const sem::F32*) { args.values.Push(c->As<f32>()); },
- [&](const sem::F16*) { args.values.Push(c->As<f16>()); },
+ [&](const sem::AbstractInt*) { scalars.Push(c->As<AInt>()); },
+ [&](const sem::AbstractFloat*) { scalars.Push(c->As<AFloat>()); },
+ [&](const sem::Bool*) { scalars.Push(c->As<bool>()); },
+ [&](const sem::I32*) { scalars.Push(c->As<i32>()); },
+ [&](const sem::U32*) { scalars.Push(c->As<u32>()); },
+ [&](const sem::F32*) { scalars.Push(c->As<f32>()); },
+ [&](const sem::F16*) { scalars.Push(c->As<f16>()); },
[&](Default) {
size_t i = 0;
while (auto* child = c->Index(i++)) {
- CollectScalarArgs(child, args);
+ CollectScalars(child, scalars);
}
});
}
/// Walks the sem::Constant @p c, returning all the inner-most scalar values.
-inline builder::ScalarArgs ScalarArgsFrom(const sem::Constant* c) {
- builder::ScalarArgs out;
- CollectScalarArgs(c, out);
+inline utils::Vector<builder::Scalar, 16> ScalarsFrom(const sem::Constant* c) {
+ utils::Vector<builder::Scalar, 16> out;
+ CollectScalars(c, out);
return out;
}
@@ -90,14 +91,14 @@
inline void CheckConstant(const sem::Constant* got_constant,
const builder::Value& expected_value,
CheckConstantFlags flags = {}) {
- auto values_flat = ScalarArgsFrom(got_constant);
- auto expected_values_flat = expected_value.Args();
- ASSERT_EQ(values_flat.values.Length(), expected_values_flat.values.Length());
- for (size_t i = 0; i < values_flat.values.Length(); ++i) {
- auto& got_scalar = values_flat.values[i];
- auto& expected_scalar = expected_values_flat.values[i];
+ auto values_flat = ScalarsFrom(got_constant);
+ auto expected_values_flat = expected_value.args;
+ ASSERT_EQ(values_flat.Length(), expected_values_flat.Length());
+ for (size_t i = 0; i < values_flat.Length(); ++i) {
+ auto& got_scalar = values_flat[i];
+ auto& expected_scalar = expected_values_flat[i];
std::visit(
- [&](auto&& expected) {
+ [&](const auto& expected) {
using T = std::decay_t<decltype(expected)>;
ASSERT_TRUE(std::holds_alternative<T>(got_scalar));
diff --git a/src/tint/resolver/const_eval_unary_op_test.cc b/src/tint/resolver/const_eval_unary_op_test.cc
index d24c27b..f7f5928 100644
--- a/src/tint/resolver/const_eval_unary_op_test.cc
+++ b/src/tint/resolver/const_eval_unary_op_test.cc
@@ -61,14 +61,14 @@
ASSERT_NE(value, nullptr);
EXPECT_TYPE(value->Type(), sem->Type());
- auto values_flat = ScalarArgsFrom(value);
- auto expected_values_flat = expected.Args();
- ASSERT_EQ(values_flat.values.Length(), expected_values_flat.values.Length());
- for (size_t i = 0; i < values_flat.values.Length(); ++i) {
- auto& a = values_flat.values[i];
- auto& b = expected_values_flat.values[i];
+ auto values_flat = ScalarsFrom(value);
+ auto expected_values_flat = expected.args;
+ ASSERT_EQ(values_flat.Length(), expected_values_flat.Length());
+ for (size_t i = 0; i < values_flat.Length(); ++i) {
+ auto& a = values_flat[i];
+ auto& b = expected_values_flat[i];
EXPECT_EQ(a, b);
- if (expected.IsIntegral()) {
+ if (expected.is_integral) {
// Check that the constant's integer doesn't contain unexpected
// data in the MSBs that are outside of the bit-width of T.
EXPECT_EQ(builder::As<AInt>(a), builder::As<AInt>(b));
diff --git a/src/tint/resolver/resolver_test_helper.h b/src/tint/resolver/resolver_test_helper.h
index edbc456..cf17b61 100644
--- a/src/tint/resolver/resolver_test_helper.h
+++ b/src/tint/resolver/resolver_test_helper.h
@@ -180,63 +180,18 @@
template <typename TO>
struct ptr {};
-/// Type used to accept scalars as arguments. Can be either a single value that gets splatted for
-/// composite types, or all values required by the composite type.
-struct ScalarArgs {
- /// Constructor
- ScalarArgs() = default;
-
- /// Constructor
- /// @param single_value single value to initialize with
- template <typename T>
- explicit ScalarArgs(T single_value) : values(utils::Vector<Storage, 1>{single_value}) {}
-
- /// Constructor
- /// @param all_values all values to initialize the composite type with
- template <typename T>
- ScalarArgs(utils::VectorRef<T> all_values) // NOLINT: implicit on purpose
- {
- for (auto& v : all_values) {
- values.Push(v);
- }
- }
-
- /// @param other the other ScalarArgs to compare against
- /// @returns true if all values are equal to the values in @p other
- bool operator==(const ScalarArgs& other) const { return values == other.values; }
-
- /// Valid scalar types for args
- using Storage = std::variant<i32, u32, f32, f16, AInt, AFloat, bool>;
-
- /// The vector of values
- utils::Vector<Storage, 16> values;
-};
+/// A scalar value
+using Scalar = std::variant<i32, u32, f32, f16, AInt, AFloat, bool>;
/// Returns current variant value in `s` cast to type `T`
template <typename T>
-T As(ScalarArgs::Storage& s) {
+T As(Scalar& s) {
return std::visit([](auto&& v) { return static_cast<T>(v); }, s);
}
-/// @param o the std::ostream to write to
-/// @param args the ScalarArgs
-/// @return the std::ostream so calls can be chained
-inline std::ostream& operator<<(std::ostream& o, const ScalarArgs& args) {
- o << "[";
- bool first = true;
- for (auto& val : args.values) {
- if (!first) {
- o << ", ";
- }
- first = false;
- std::visit([&](auto&& v) { o << v; }, val);
- }
- o << "]";
- return o;
-}
-
using ast_type_func_ptr = const ast::Type* (*)(ProgramBuilder& b);
-using ast_expr_func_ptr = const ast::Expression* (*)(ProgramBuilder& b, ScalarArgs args);
+using ast_expr_func_ptr = const ast::Expression* (*)(ProgramBuilder& b,
+ utils::VectorRef<Scalar> args);
using ast_expr_from_double_func_ptr = const ast::Expression* (*)(ProgramBuilder& b, double v);
using sem_type_func_ptr = const sem::Type* (*)(ProgramBuilder& b);
using type_name_func_ptr = std::string (*)();
@@ -280,14 +235,14 @@
/// @param b the ProgramBuilder
/// @param args args of size 1 with the boolean value to init with
/// @return a new AST expression of the bool type
- static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
- return b.Expr(std::get<bool>(args.values[0]));
+ static inline const ast::Expression* Expr(ProgramBuilder& b, utils::VectorRef<Scalar> args) {
+ return b.Expr(std::get<bool>(args[0]));
}
/// @param b the ProgramBuilder
/// @param v arg of type double that will be cast to bool.
/// @return a new AST expression of the bool type
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
- return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
+ return Expr(b, utils::Vector<Scalar, 1>{static_cast<ElementType>(v)});
}
/// @returns the WGSL name for the type
static inline std::string Name() { return "bool"; }
@@ -311,14 +266,14 @@
/// @param b the ProgramBuilder
/// @param args args of size 1 with the i32 value to init with
/// @return a new AST i32 literal value expression
- static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
- return b.Expr(std::get<i32>(args.values[0]));
+ static inline const ast::Expression* Expr(ProgramBuilder& b, utils::VectorRef<Scalar> args) {
+ return b.Expr(std::get<i32>(args[0]));
}
/// @param b the ProgramBuilder
/// @param v arg of type double that will be cast to i32.
/// @return a new AST i32 literal value expression
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
- return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
+ return Expr(b, utils::Vector<Scalar, 1>{static_cast<ElementType>(v)});
}
/// @returns the WGSL name for the type
static inline std::string Name() { return "i32"; }
@@ -342,14 +297,14 @@
/// @param b the ProgramBuilder
/// @param args args of size 1 with the u32 value to init with
/// @return a new AST u32 literal value expression
- static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
- return b.Expr(std::get<u32>(args.values[0]));
+ static inline const ast::Expression* Expr(ProgramBuilder& b, utils::VectorRef<Scalar> args) {
+ return b.Expr(std::get<u32>(args[0]));
}
/// @param b the ProgramBuilder
/// @param v arg of type double that will be cast to u32.
/// @return a new AST u32 literal value expression
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
- return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
+ return Expr(b, utils::Vector<Scalar, 1>{static_cast<ElementType>(v)});
}
/// @returns the WGSL name for the type
static inline std::string Name() { return "u32"; }
@@ -373,14 +328,14 @@
/// @param b the ProgramBuilder
/// @param args args of size 1 with the f32 value to init with
/// @return a new AST f32 literal value expression
- static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
- return b.Expr(std::get<f32>(args.values[0]));
+ static inline const ast::Expression* Expr(ProgramBuilder& b, utils::VectorRef<Scalar> args) {
+ return b.Expr(std::get<f32>(args[0]));
}
/// @param b the ProgramBuilder
/// @param v arg of type double that will be cast to f32.
/// @return a new AST f32 literal value expression
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
- return Expr(b, ScalarArgs{static_cast<f32>(v)});
+ return Expr(b, utils::Vector<Scalar, 1>{static_cast<f32>(v)});
}
/// @returns the WGSL name for the type
static inline std::string Name() { return "f32"; }
@@ -404,14 +359,14 @@
/// @param b the ProgramBuilder
/// @param args args of size 1 with the f16 value to init with
/// @return a new AST f16 literal value expression
- static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
- return b.Expr(std::get<f16>(args.values[0]));
+ static inline const ast::Expression* Expr(ProgramBuilder& b, utils::VectorRef<Scalar> args) {
+ return b.Expr(std::get<f16>(args[0]));
}
/// @param b the ProgramBuilder
/// @param v arg of type double that will be cast to f16.
/// @return a new AST f16 literal value expression
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
- return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
+ return Expr(b, utils::Vector<Scalar, 1>{static_cast<ElementType>(v)});
}
/// @returns the WGSL name for the type
static inline std::string Name() { return "f16"; }
@@ -434,14 +389,14 @@
/// @param b the ProgramBuilder
/// @param args args of size 1 with the abstract-float value to init with
/// @return a new AST abstract-float literal value expression
- static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
- return b.Expr(std::get<AFloat>(args.values[0]));
+ static inline const ast::Expression* Expr(ProgramBuilder& b, utils::VectorRef<Scalar> args) {
+ return b.Expr(std::get<AFloat>(args[0]));
}
/// @param b the ProgramBuilder
/// @param v arg of type double that will be cast to AFloat.
/// @return a new AST abstract-float literal value expression
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
- return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
+ return Expr(b, utils::Vector<Scalar, 1>{static_cast<ElementType>(v)});
}
/// @returns the WGSL name for the type
static inline std::string Name() { return "abstract-float"; }
@@ -464,14 +419,14 @@
/// @param b the ProgramBuilder
/// @param args args of size 1 with the abstract-int value to init with
/// @return a new AST abstract-int literal value expression
- static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
- return b.Expr(std::get<AInt>(args.values[0]));
+ static inline const ast::Expression* Expr(ProgramBuilder& b, utils::VectorRef<Scalar> args) {
+ return b.Expr(std::get<AInt>(args[0]));
}
/// @param b the ProgramBuilder
/// @param v arg of type double that will be cast to AInt.
/// @return a new AST abstract-int literal value expression
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
- return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
+ return Expr(b, utils::Vector<Scalar, 1>{static_cast<ElementType>(v)});
}
/// @returns the WGSL name for the type
static inline std::string Name() { return "abstract-int"; }
@@ -499,17 +454,17 @@
/// @param b the ProgramBuilder
/// @param args args of size 1 or N with values of type T to initialize with
/// @return a new AST vector value expression
- static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
+ static inline const ast::Expression* Expr(ProgramBuilder& b, utils::VectorRef<Scalar> args) {
return b.Construct(AST(b), ExprArgs(b, std::move(args)));
}
/// @param b the ProgramBuilder
/// @param args args of size 1 or N with values of type T to initialize with
/// @return the list of expressions that are used to construct the vector
- static inline auto ExprArgs(ProgramBuilder& b, ScalarArgs args) {
- const bool one_value = args.values.Length() == 1;
+ static inline auto ExprArgs(ProgramBuilder& b, utils::VectorRef<Scalar> args) {
+ const bool one_value = args.Length() == 1;
utils::Vector<const ast::Expression*, N> r;
for (size_t i = 0; i < N; ++i) {
- r.Push(DataType<T>::Expr(b, ScalarArgs{one_value ? args.values[0] : args.values[i]}));
+ r.Push(DataType<T>::Expr(b, utils::Vector<Scalar, 1>{one_value ? args[0] : args[i]}));
}
return r;
}
@@ -517,7 +472,7 @@
/// @param v arg of type double that will be cast to ElementType
/// @return a new AST vector value expression
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
- return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
+ return Expr(b, utils::Vector<Scalar, 1>{static_cast<ElementType>(v)});
}
/// @returns the WGSL name for the type
static inline std::string Name() {
@@ -548,25 +503,25 @@
/// @param b the ProgramBuilder
/// @param args args of size 1 or N*M with values of type T to initialize with
/// @return a new AST matrix value expression
- static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
+ static inline const ast::Expression* Expr(ProgramBuilder& b, utils::VectorRef<Scalar> args) {
return b.Construct(AST(b), ExprArgs(b, std::move(args)));
}
/// @param b the ProgramBuilder
/// @param args args of size 1 or N*M with values of type T to initialize with
/// @return a new AST matrix value expression
- static inline auto ExprArgs(ProgramBuilder& b, ScalarArgs args) {
- const bool one_value = args.values.Length() == 1;
+ static inline auto ExprArgs(ProgramBuilder& b, utils::VectorRef<Scalar> args) {
+ const bool one_value = args.Length() == 1;
size_t next = 0;
utils::Vector<const ast::Expression*, N> r;
for (uint32_t i = 0; i < N; ++i) {
if (one_value) {
- r.Push(DataType<vec<M, T>>::Expr(b, ScalarArgs{args.values[0]}));
+ r.Push(DataType<vec<M, T>>::Expr(b, utils::Vector<Scalar, 1>{args[0]}));
} else {
- utils::Vector<T, M> v;
+ utils::Vector<Scalar, M> v;
for (size_t j = 0; j < M; ++j) {
- v.Push(std::get<T>(args.values[next++]));
+ v.Push(args[next++]);
}
- r.Push(DataType<vec<M, T>>::Expr(b, utils::VectorRef<T>{v}));
+ r.Push(DataType<vec<M, T>>::Expr(b, std::move(v)));
}
}
return r;
@@ -575,7 +530,7 @@
/// @param v arg of type double that will be cast to ElementType
/// @return a new AST matrix value expression
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
- return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
+ return Expr(b, utils::Vector<Scalar, 1>{static_cast<ElementType>(v)});
}
/// @returns the WGSL name for the type
static inline std::string Name() {
@@ -611,8 +566,9 @@
/// @param args the value nested elements will be initialized with
/// @return a new AST expression of the alias type
template <bool IS_COMPOSITE = is_composite>
- static inline traits::EnableIf<!IS_COMPOSITE, const ast::Expression*> Expr(ProgramBuilder& b,
- ScalarArgs args) {
+ static inline traits::EnableIf<!IS_COMPOSITE, const ast::Expression*> Expr(
+ ProgramBuilder& b,
+ utils::VectorRef<Scalar> args) {
// Cast
return b.Construct(AST(b), DataType<T>::Expr(b, std::move(args)));
}
@@ -621,8 +577,9 @@
/// @param args the value nested elements will be initialized with
/// @return a new AST expression of the alias type
template <bool IS_COMPOSITE = is_composite>
- static inline traits::EnableIf<IS_COMPOSITE, const ast::Expression*> Expr(ProgramBuilder& b,
- ScalarArgs args) {
+ static inline traits::EnableIf<IS_COMPOSITE, const ast::Expression*> Expr(
+ ProgramBuilder& b,
+ utils::VectorRef<Scalar> args) {
// Construct
return b.Construct(AST(b), DataType<T>::ExprArgs(b, std::move(args)));
}
@@ -631,7 +588,7 @@
/// @param v arg of type double that will be cast to ElementType
/// @return a new AST expression of the alias type
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
- return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
+ return Expr(b, utils::Vector<Scalar, 1>{static_cast<ElementType>(v)});
}
/// @returns the WGSL name for the type
@@ -662,7 +619,8 @@
/// @param b the ProgramBuilder
/// @return a new AST expression of the pointer type
- static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs /*unused*/) {
+ static inline const ast::Expression* Expr(ProgramBuilder& b,
+ utils::VectorRef<Scalar> /*unused*/) {
auto sym = b.Symbols().New("global_for_ptr");
b.GlobalVar(sym, DataType<T>::AST(b), ast::AddressSpace::kPrivate);
return b.AddressOf(sym);
@@ -672,7 +630,7 @@
/// @param v arg of type double that will be cast to ElementType
/// @return a new AST expression of the pointer type
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
- return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
+ return Expr(b, utils::Vector<Scalar, 1>{static_cast<ElementType>(v)});
}
/// @returns the WGSL name for the type
@@ -716,17 +674,17 @@
/// @param args args of size 1 or N with values of type T to initialize with
/// with
/// @return a new AST array value expression
- static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
+ static inline const ast::Expression* Expr(ProgramBuilder& b, utils::VectorRef<Scalar> args) {
return b.Construct(AST(b), ExprArgs(b, std::move(args)));
}
/// @param b the ProgramBuilder
/// @param args args of size 1 or N with values of type T to initialize with
/// @return the list of expressions that are used to construct the array
- static inline auto ExprArgs(ProgramBuilder& b, ScalarArgs args) {
- const bool one_value = args.values.Length() == 1;
+ static inline auto ExprArgs(ProgramBuilder& b, utils::VectorRef<Scalar> args) {
+ const bool one_value = args.Length() == 1;
utils::Vector<const ast::Expression*, N> r;
for (uint32_t i = 0; i < N; i++) {
- r.Push(DataType<T>::Expr(b, ScalarArgs{one_value ? args.values[0] : args.values[i]}));
+ r.Push(DataType<T>::Expr(b, utils::Vector<Scalar, 1>{one_value ? args[0] : args[i]}));
}
return r;
}
@@ -734,7 +692,7 @@
/// @param v arg of type double that will be cast to ElementType
/// @return a new AST array value expression
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
- return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
+ return Expr(b, utils::Vector<Scalar, 1>{static_cast<ElementType>(v)});
}
/// @returns the WGSL name for the type
static inline std::string Name() {
@@ -776,80 +734,34 @@
const bool IsDataTypeSpecializedFor =
!std::is_same_v<typename DataType<T>::ElementType, UnspecializedElementType>;
-namespace detail {
-/// ValueBase is a base class of ConcreteValue<T>
-struct ValueBase {
- /// Constructor
- ValueBase() = default;
- /// Destructor
- virtual ~ValueBase() = default;
- /// Move constructor
- ValueBase(ValueBase&&) = default;
- /// Copy constructor
- ValueBase(const ValueBase&) = default;
- /// Copy assignment operator
- /// @returns this instance
- ValueBase& operator=(const ValueBase&) = default;
- /// Creates an `ast::Expression` for the type T passing in previously stored args
- /// @param b the ProgramBuilder
- /// @returns an expression node
- virtual const ast::Expression* Expr(ProgramBuilder& b) const = 0;
- /// @returns args used to create expression via `Expr`
- virtual const ScalarArgs& Args() const = 0;
- /// @returns true if element type is abstract
- virtual bool IsAbstract() const = 0;
- /// @returns true if element type is an integral
- virtual bool IsIntegral() const = 0;
- /// @returns element type name
- virtual std::string TypeName() const = 0;
- /// Prints this value to the output stream
- /// @param o the output stream
- /// @returns input argument `o`
- virtual std::ostream& Print(std::ostream& o) const = 0;
-};
-
-/// ConcreteValue<T> is used to create Values of type DataType<T> with a ScalarArgs initializer.
-template <typename T>
-struct ConcreteValue : ValueBase {
- /// Constructor
+/// Value is used to create Values with a Scalar vector initializer.
+struct Value {
+ /// Creates a Value for type T initialized with `args`
/// @param args the scalar args
- explicit ConcreteValue(ScalarArgs args) : args_(std::move(args)) {}
-
- /// Alias to T
- using Type = T;
- /// Alias to DataType<T>
- using DataType = builder::DataType<T>;
- /// Alias to DataType::ElementType
- using ElementType = typename DataType::ElementType;
-
- /// Creates an `ast::Expression` for the type T passing in previously stored args
- /// @param b the ProgramBuilder
- /// @returns an expression node
- const ast::Expression* Expr(ProgramBuilder& b) const override {
- auto create = CreatePtrsFor<T>();
- return (*create.expr)(b, args_);
+ /// @returns Value
+ template <typename T>
+ static Value Create(utils::VectorRef<Scalar> args) {
+ static_assert(IsDataTypeSpecializedFor<T>, "No DataType<T> specialization exists");
+ using EL_TY = typename builder::DataType<T>::ElementType;
+ return Value{
+ std::move(args), CreatePtrsFor<T>().expr, tint::IsAbstract<EL_TY>,
+ tint::IsIntegral<EL_TY>, tint::FriendlyName<EL_TY>(),
+ };
}
- /// @returns args used to create expression via `Expr`
- const ScalarArgs& Args() const override { return args_; }
-
- /// @returns true if element type is abstract
- bool IsAbstract() const override { return tint::IsAbstract<ElementType>; }
-
- /// @returns true if element type is an integral
- bool IsIntegral() const override { return tint::IsIntegral<ElementType>; }
-
- /// @returns element type name
- std::string TypeName() const override { return tint::FriendlyName<ElementType>(); }
+ /// Creates an `ast::Expression` for the type T passing in previously stored args
+ /// @param b the ProgramBuilder
+ /// @returns an expression node
+ const ast::Expression* Expr(ProgramBuilder& b) const { return (*create)(b, args); }
/// Prints this value to the output stream
/// @param o the output stream
/// @returns input argument `o`
- std::ostream& Print(std::ostream& o) const override {
- o << TypeName() << "(";
- for (auto& a : args_.values) {
- o << std::get<ElementType>(a);
- if (&a != &args_.values.Back()) {
+ std::ostream& Print(std::ostream& o) const {
+ o << type_name << "(";
+ for (auto& a : args) {
+ std::visit([&](auto& v) { o << v; }, a);
+ if (&a != &args.Back()) {
o << ", ";
}
}
@@ -857,54 +769,16 @@
return o;
}
- private:
- /// args to create expression with
- ScalarArgs args_;
-};
-} // namespace detail
-
-/// A Value represents a value of type DataType<T> created with ScalarArgs. Useful for storing
-/// values for unit tests.
-class Value {
- public:
- /// Creates a Value for type T initialized with `args`
- /// @param args the scalar args
- /// @returns Value
- template <typename T>
- static Value Create(ScalarArgs args) {
- static_assert(IsDataTypeSpecializedFor<T>, "No DataType<T> specialization exists");
- return Value{std::make_shared<detail::ConcreteValue<T>>(std::move(args))};
- }
-
- /// Creates an `ast::Expression` for the type T passing in previously stored args
- /// @param b the ProgramBuilder
- /// @returns an expression node
- const ast::Expression* Expr(ProgramBuilder& b) const { return value_->Expr(b); }
-
- /// @returns args used to create expression via `Expr`
- const ScalarArgs& Args() const { return value_->Args(); }
-
- /// @returns true if element type is abstract
- bool IsAbstract() const { return value_->IsAbstract(); }
-
- /// @returns true if element type is an integral
- bool IsIntegral() const { return value_->IsIntegral(); }
-
- /// @returns element type name
- std::string TypeName() const { return value_->TypeName(); }
-
- /// Prints this value to the output stream
- /// @param o the output stream
- /// @returns input argument `o`
- std::ostream& Print(std::ostream& o) const { return value_->Print(o); }
-
- private:
- /// Private constructor
- explicit Value(std::shared_ptr<const detail::ValueBase> value) : value_(std::move(value)) {}
-
- /// Shared pointer to an immutable value. This type-erasure pattern allows Value to wrap a
- /// polymorphic type, while being used like a value-type (i.e. copyable).
- std::shared_ptr<const detail::ValueBase> value_;
+ /// The arguments used to construct the value
+ utils::Vector<Scalar, 4> args;
+ /// Function used to construct an expression with the given value
+ builder::ast_expr_func_ptr create;
+ /// True if the element type is abstract
+ bool is_abstract = false;
+ /// True if the element type is an integer
+ bool is_integral = false;
+ /// The name of the type.
+ const char* type_name = "<invalid>";
};
/// Prints Value to ostream
@@ -919,7 +793,7 @@
/// Creates a Value of DataType<T> from a scalar `v`
template <typename T>
Value Val(T v) {
- return Value::Create<T>(ScalarArgs{v});
+ return Value::Create<T>(utils::Vector<Scalar, 1>{v});
}
/// Creates a Value of DataType<vec<N, T>> from N scalar `args`
@@ -927,41 +801,41 @@
Value Vec(T... args) {
using FirstT = std::tuple_element_t<0, std::tuple<T...>>;
constexpr size_t N = sizeof...(args);
- utils::Vector v{args...};
- return Value::Create<vec<N, FirstT>>(utils::VectorRef<FirstT>{v});
+ utils::Vector<Scalar, sizeof...(args)> v{args...};
+ return Value::Create<vec<N, FirstT>>(std::move(v));
}
/// Creates a Value of DataType<mat<C,R,T> from C*R scalar `args`
template <size_t C, size_t R, typename T>
Value Mat(const T (&m_in)[C][R]) {
- utils::Vector<T, C * R> m;
+ utils::Vector<Scalar, C * R> m;
for (uint32_t i = 0; i < C; ++i) {
for (size_t j = 0; j < R; ++j) {
m.Push(m_in[i][j]);
}
}
- return Value::Create<mat<C, R, T>>(utils::VectorRef<T>{m});
+ return Value::Create<mat<C, R, T>>(std::move(m));
}
/// Creates a Value of DataType<mat<2,R,T> from column vectors `c0` and `c1`
template <typename T, size_t R>
Value Mat(const T (&c0)[R], const T (&c1)[R]) {
constexpr size_t C = 2;
- utils::Vector<T, C * R> m;
+ utils::Vector<Scalar, C * R> m;
for (auto v : c0) {
m.Push(v);
}
for (auto v : c1) {
m.Push(v);
}
- return Value::Create<mat<C, R, T>>(utils::VectorRef<T>{m});
+ return Value::Create<mat<C, R, T>>(std::move(m));
}
/// Creates a Value of DataType<mat<3,R,T> from column vectors `c0`, `c1`, and `c2`
template <typename T, size_t R>
Value Mat(const T (&c0)[R], const T (&c1)[R], const T (&c2)[R]) {
constexpr size_t C = 3;
- utils::Vector<T, C * R> m;
+ utils::Vector<Scalar, C * R> m;
for (auto v : c0) {
m.Push(v);
}
@@ -971,14 +845,14 @@
for (auto v : c2) {
m.Push(v);
}
- return Value::Create<mat<C, R, T>>(utils::VectorRef<T>{m});
+ return Value::Create<mat<C, R, T>>(std::move(m));
}
/// Creates a Value of DataType<mat<4,R,T> from column vectors `c0`, `c1`, `c2`, and `c3`
template <typename T, size_t R>
Value Mat(const T (&c0)[R], const T (&c1)[R], const T (&c2)[R], const T (&c3)[R]) {
constexpr size_t C = 4;
- utils::Vector<T, C * R> m;
+ utils::Vector<Scalar, C * R> m;
for (auto v : c0) {
m.Push(v);
}
@@ -991,7 +865,7 @@
for (auto v : c3) {
m.Push(v);
}
- return Value::Create<mat<C, R, T>>(utils::VectorRef<T>{m});
+ return Value::Create<mat<C, R, T>>(std::move(m));
}
} // namespace builder