Move CreateComposite into ProgramBuilder.
This CL moves the CreateComposite helper into the ProgramBuilder.
Bug: tint:1718
Change-Id: I4aca7dc3d7192a7aa8b300f00529670aa9c09a27
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/114202
Kokoro: Kokoro <noreply+kokoro@google.com>
Auto-Submit: Dan Sinclair <dsinclair@chromium.org>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/program_builder.h b/src/tint/program_builder.h
index 83ad287..597463c 100644
--- a/src/tint/program_builder.h
+++ b/src/tint/program_builder.h
@@ -87,6 +87,8 @@
#include "src/tint/ast/void.h"
#include "src/tint/ast/while_statement.h"
#include "src/tint/ast/workgroup_attribute.h"
+#include "src/tint/constant/composite.h"
+#include "src/tint/constant/splat.h"
#include "src/tint/constant/value.h"
#include "src/tint/number.h"
#include "src/tint/program.h"
@@ -470,11 +472,73 @@
/// @param args the arguments to pass to the constructor
/// @returns the node pointer
template <typename T, typename... ARGS>
- traits::EnableIf<traits::IsTypeOrDerived<T, constant::Value>, T>* create(ARGS&&... args) {
+ traits::EnableIf<traits::IsTypeOrDerived<T, constant::Value> &&
+ !traits::IsTypeOrDerived<T, constant::Composite> &&
+ !traits::IsTypeOrDerived<T, constant::Splat>,
+ T>*
+ create(ARGS&&... args) {
AssertNotMoved();
return constant_nodes_.Create<T>(std::forward<ARGS>(args)...);
}
+ /// Constructs a constant of a vector, matrix or array type.
+ ///
+ /// Examines the element values and will return either a constant::Composite or a
+ /// constant::Splat, depending on the element types and values.
+ ///
+ /// @param type the composite type
+ /// @param elements the composite elements
+ /// @returns the node pointer
+ template <typename T>
+ traits::EnableIf<traits::IsTypeOrDerived<T, constant::Composite> ||
+ traits::IsTypeOrDerived<T, constant::Splat>,
+ const constant::Value>*
+ create(const type::Type* type, utils::VectorRef<const constant::Value*> elements) {
+ AssertNotMoved();
+ if (elements.IsEmpty()) {
+ return nullptr;
+ }
+
+ bool any_zero = false;
+ bool all_zero = true;
+ bool all_equal = true;
+ auto* first = elements.Front();
+ for (auto* el : elements) {
+ if (!el) {
+ return nullptr;
+ }
+ if (!any_zero && el->AnyZero()) {
+ any_zero = true;
+ }
+ if (all_zero && !el->AllZero()) {
+ all_zero = false;
+ }
+ if (all_equal && el != first) {
+ if (!el->Equal(first)) {
+ all_equal = false;
+ }
+ }
+ }
+ if (all_equal) {
+ return create<constant::Splat>(type, elements[0], elements.Length());
+ }
+
+ return constant_nodes_.Create<constant::Composite>(type, std::move(elements), all_zero,
+ any_zero);
+ }
+
+ /// Constructs a splat constant.
+ /// @param type the splat type
+ /// @param element the splat element
+ /// @param n the number of elements
+ /// @returns the node pointer
+ template <typename T>
+ traits::EnableIf<traits::IsTypeOrDerived<T, constant::Splat>, const constant::Splat>*
+ create(const type::Type* type, const constant::Value* element, size_t n) {
+ AssertNotMoved();
+ return constant_nodes_.Create<constant::Splat>(type, element, n);
+ }
+
/// Creates a new type::Type owned by the ProgramBuilder.
/// When the ProgramBuilder is destructed, owned ProgramBuilder and the
/// returned `Type` will also be destructed.
diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc
index a92b02a..628975b 100644
--- a/src/tint/resolver/const_eval.cc
+++ b/src/tint/resolver/const_eval.cc
@@ -232,11 +232,6 @@
return count;
}
-// Forward declaration
-const constant::Value* CreateComposite(ProgramBuilder& builder,
- const type::Type* type,
- utils::VectorRef<const constant::Value*> elements);
-
template <typename T>
ConstEval::Result ScalarConvert(const constant::Scalar<T>* scalar,
ProgramBuilder& builder,
@@ -347,7 +342,7 @@
}
conv_els.Push(conv_el.Get());
}
- return CreateComposite(builder, target_ty, std::move(conv_els));
+ return builder.create<constant::Composite>(target_ty, std::move(conv_els));
}
ConstEval::Result ConvertInternal(const constant::Value* c,
@@ -438,7 +433,7 @@
// All members were of the same type, so the zero value is the same for all members.
return builder.create<constant::Splat>(type, zeros[0], s->Members().Length());
}
- return CreateComposite(builder, s, std::move(zeros));
+ return builder.create<constant::Composite>(s, std::move(zeros));
},
[&](Default) -> const constant::Value* {
return ZeroTypeDispatch(type, [&](auto zero) -> const constant::Value* {
@@ -449,42 +444,6 @@
});
}
-/// CreateComposite is used to construct a constant of a vector, matrix or array type.
-/// CreateComposite examines the element values and will return either a Composite or a Splat,
-/// depending on the element types and values.
-const constant::Value* CreateComposite(ProgramBuilder& builder,
- const type::Type* type,
- utils::VectorRef<const constant::Value*> elements) {
- if (elements.IsEmpty()) {
- return nullptr;
- }
- bool any_zero = false;
- bool all_zero = true;
- bool all_equal = true;
- auto* first = elements.Front();
- for (auto* el : elements) {
- if (!el) {
- return nullptr;
- }
- if (!any_zero && el->AnyZero()) {
- any_zero = true;
- }
- if (all_zero && !el->AllZero()) {
- all_zero = false;
- }
- if (all_equal && el != first) {
- if (!el->Equal(first)) {
- all_equal = false;
- }
- }
- }
- if (all_equal) {
- return builder.create<constant::Splat>(type, elements[0], elements.Length());
- } else {
- return builder.create<constant::Composite>(type, std::move(elements), all_zero, any_zero);
- }
-}
-
namespace detail {
/// Implementation of TransformElements
template <typename F, typename... CONSTANTS>
@@ -515,7 +474,7 @@
return el.Failure();
}
}
- return CreateComposite(builder, composite_ty, std::move(els));
+ return builder.create<constant::Composite>(composite_ty, std::move(els));
}
} // namespace detail
@@ -569,7 +528,7 @@
return el.Failure();
}
}
- return CreateComposite(builder, composite_ty, std::move(els));
+ return builder.create<constant::Composite>(composite_ty, std::move(els));
}
} // namespace
@@ -1211,7 +1170,7 @@
for (auto* arg : args) {
els.Push(arg->ConstantValue());
}
- return CreateComposite(builder, ty, std::move(els));
+ return builder.create<constant::Composite>(ty, std::move(els));
}
ConstEval::Result ConstEval::Conv(const type::Type* ty,
@@ -1255,7 +1214,7 @@
ConstEval::Result ConstEval::VecInitS(const type::Type* ty,
utils::VectorRef<const constant::Value*> args,
const Source&) {
- return CreateComposite(builder, ty, args);
+ return builder.create<constant::Composite>(ty, args);
}
ConstEval::Result ConstEval::VecInitM(const type::Type* ty,
@@ -1281,7 +1240,7 @@
els.Push(val);
}
}
- return CreateComposite(builder, ty, std::move(els));
+ return builder.create<constant::Composite>(ty, std::move(els));
}
ConstEval::Result ConstEval::MatInitS(const type::Type* ty,
@@ -1296,15 +1255,15 @@
auto i = r + c * m->rows();
column.Push(args[i]);
}
- els.Push(CreateComposite(builder, m->ColumnType(), std::move(column)));
+ els.Push(builder.create<constant::Composite>(m->ColumnType(), std::move(column)));
}
- return CreateComposite(builder, ty, std::move(els));
+ return builder.create<constant::Composite>(ty, std::move(els));
}
ConstEval::Result ConstEval::MatInitV(const type::Type* ty,
utils::VectorRef<const constant::Value*> args,
const Source&) {
- return CreateComposite(builder, ty, args);
+ return builder.create<constant::Composite>(ty, args);
}
ConstEval::Result ConstEval::Index(const sem::Expression* obj_expr,
@@ -1357,7 +1316,7 @@
}
auto values = utils::Transform<4>(
indices, [&](uint32_t i) { return vec_val->Index(static_cast<size_t>(i)); });
- return CreateComposite(builder, ty, std::move(values));
+ return builder.create<constant::Composite>(ty, std::move(values));
}
ConstEval::Result ConstEval::Bitcast(const type::Type*, const sem::Expression*) {
@@ -1484,7 +1443,7 @@
}
result.Push(r.Get());
}
- return CreateComposite(builder, ty, result);
+ return builder.create<constant::Composite>(ty, result);
}
ConstEval::Result ConstEval::OpMultiplyVecMat(const type::Type* ty,
utils::VectorRef<const constant::Value*> args,
@@ -1534,7 +1493,7 @@
}
result.Push(r.Get());
}
- return CreateComposite(builder, ty, result);
+ return builder.create<constant::Composite>(ty, result);
}
ConstEval::Result ConstEval::OpMultiplyMatMat(const type::Type* ty,
@@ -1596,9 +1555,9 @@
// Add column vector to matrix
auto* col_vec_ty = ty->As<type::Matrix>()->ColumnType();
- result_mat.Push(CreateComposite(builder, col_vec_ty, col_vec));
+ result_mat.Push(builder.create<constant::Composite>(col_vec_ty, col_vec));
}
- return CreateComposite(builder, ty, result_mat);
+ return builder.create<constant::Composite>(ty, result_mat);
}
ConstEval::Result ConstEval::OpDivide(const type::Type* ty,
@@ -2208,8 +2167,8 @@
return utils::Failure;
}
- return CreateComposite(builder, ty,
- utils::Vector<const constant::Value*, 3>{x.Get(), y.Get(), z.Get()});
+ return builder.create<constant::Composite>(
+ ty, utils::Vector<const constant::Value*, 3>{x.Get(), y.Get(), z.Get()});
}
ConstEval::Result ConstEval::degrees(const type::Type* ty,
@@ -2592,21 +2551,20 @@
}
auto fract_ty = builder.create<type::Vector>(fract_els[0]->Type(), vec->Width());
auto exp_ty = builder.create<type::Vector>(exp_els[0]->Type(), vec->Width());
- return CreateComposite(builder, ty,
- utils::Vector<const constant::Value*, 2>{
- CreateComposite(builder, fract_ty, std::move(fract_els)),
- CreateComposite(builder, exp_ty, std::move(exp_els)),
- });
+ return builder.create<constant::Composite>(
+ ty, utils::Vector<const constant::Value*, 2>{
+ builder.create<constant::Composite>(fract_ty, std::move(fract_els)),
+ builder.create<constant::Composite>(exp_ty, std::move(exp_els)),
+ });
} else {
auto fe = scalar(arg);
if (!fe.fract || !fe.exp) {
return utils::Failure;
}
- return CreateComposite(builder, ty,
- utils::Vector<const constant::Value*, 2>{
- fe.fract.Get(),
- fe.exp.Get(),
- });
+ return builder.create<constant::Composite>(ty, utils::Vector<const constant::Value*, 2>{
+ fe.fract.Get(),
+ fe.exp.Get(),
+ });
}
}
@@ -2838,7 +2796,7 @@
return utils::Failure;
}
- return CreateComposite(builder, ty, std::move(fields));
+ return builder.create<constant::Composite>(ty, std::move(fields));
}
ConstEval::Result ConstEval::normalize(const type::Type* ty,
@@ -3412,9 +3370,10 @@
for (size_t c = 0; c < mat_ty->columns(); ++c) {
new_col_vec.Push(me(r, c));
}
- result_mat.Push(CreateComposite(builder, result_mat_ty->ColumnType(), new_col_vec));
+ result_mat.Push(
+ builder.create<constant::Composite>(result_mat_ty->ColumnType(), new_col_vec));
}
- return CreateComposite(builder, ty, result_mat);
+ return builder.create<constant::Composite>(ty, result_mat);
}
ConstEval::Result ConstEval::trunc(const type::Type* ty,
@@ -3450,7 +3409,7 @@
}
els.Push(el.Get());
}
- return CreateComposite(builder, ty, std::move(els));
+ return builder.create<constant::Composite>(ty, std::move(els));
}
ConstEval::Result ConstEval::unpack2x16snorm(const type::Type* ty,
@@ -3470,7 +3429,7 @@
}
els.Push(el.Get());
}
- return CreateComposite(builder, ty, std::move(els));
+ return builder.create<constant::Composite>(ty, std::move(els));
}
ConstEval::Result ConstEval::unpack2x16unorm(const type::Type* ty,
@@ -3489,7 +3448,7 @@
}
els.Push(el.Get());
}
- return CreateComposite(builder, ty, std::move(els));
+ return builder.create<constant::Composite>(ty, std::move(els));
}
ConstEval::Result ConstEval::unpack4x8snorm(const type::Type* ty,
@@ -3509,7 +3468,7 @@
}
els.Push(el.Get());
}
- return CreateComposite(builder, ty, std::move(els));
+ return builder.create<constant::Composite>(ty, std::move(els));
}
ConstEval::Result ConstEval::unpack4x8unorm(const type::Type* ty,
@@ -3528,7 +3487,7 @@
}
els.Push(el.Get());
}
- return CreateComposite(builder, ty, std::move(els));
+ return builder.create<constant::Composite>(ty, std::move(els));
}
ConstEval::Result ConstEval::quantizeToF16(const type::Type* ty,