Import Tint changes from Dawn
Changes:
- 30f51b94da225f8cbffb18ee8a9f176f6df469d0 Delete the remove_stale_autogen_files mechanism. by Corentin Wallez <cwallez@chromium.org>
- 7d9140feb9a7273801400c0e7166f8d994db8ce1 spirv-reader: Refactor tracking of locally-defined values by David Neto <dneto@google.com>
- e68d4506c03f1455e124271d89a9d84e5ec8ec0f tint/resolver: Consistently use utils::Result in ConstEval by Ben Clayton <bclayton@google.com>
GitOrigin-RevId: 30f51b94da225f8cbffb18ee8a9f176f6df469d0
Change-Id: I7eff5c8c3034ebd5f6ede46dcca2cf073d184873
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/103160
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/fuzzers/BUILD.gn b/src/tint/fuzzers/BUILD.gn
index 25e475e..d0c72a7 100644
--- a/src/tint/fuzzers/BUILD.gn
+++ b/src/tint/fuzzers/BUILD.gn
@@ -33,8 +33,6 @@
rebase_path(fuzzer_corpus_wgsl_dir, root_build_dir),
]
outputs = [ fuzzer_corpus_wgsl_stamp ]
-
- deps = [ "${dawn_root}/generator:remove_stale_autogen_files" ]
}
tint_fuzzer_common_libfuzzer_options = [
diff --git a/src/tint/fuzzers/tint_ast_fuzzer/BUILD.gn b/src/tint/fuzzers/tint_ast_fuzzer/BUILD.gn
index 58a36ea..416128c 100644
--- a/src/tint/fuzzers/tint_ast_fuzzer/BUILD.gn
+++ b/src/tint/fuzzers/tint_ast_fuzzer/BUILD.gn
@@ -23,7 +23,6 @@
sources = [ "protobufs/tint_ast_fuzzer.proto" ]
generate_python = false
use_protobuf_full = true
- deps = [ "${dawn_root}/generator:remove_stale_autogen_files" ]
}
source_set("tint_ast_fuzzer") {
diff --git a/src/tint/reader/spirv/function.cc b/src/tint/reader/spirv/function.cc
index 59a98e1..00b49d9 100644
--- a/src/tint/reader/spirv/function.cc
+++ b/src/tint/reader/spirv/function.cc
@@ -759,17 +759,22 @@
BlockInfo::~BlockInfo() = default;
-DefInfo::DefInfo(const spvtools::opt::Instruction& def_inst,
- bool the_locally_defined,
- uint32_t the_block_pos,
- size_t the_index)
- : inst(def_inst),
- locally_defined(the_locally_defined),
- block_pos(the_block_pos),
- index(the_index) {}
+DefInfo::DefInfo(size_t the_index,
+ const spvtools::opt::Instruction& def_inst,
+ uint32_t the_block_pos)
+ : index(the_index), inst(def_inst), local(DefInfo::Local(the_block_pos)) {}
+
+DefInfo::DefInfo(size_t the_index, const spvtools::opt::Instruction& def_inst)
+ : index(the_index), inst(def_inst) {}
DefInfo::~DefInfo() = default;
+DefInfo::Local::Local(uint32_t the_block_pos) : block_pos(the_block_pos) {}
+
+DefInfo::Local::Local(const Local& other) = default;
+
+DefInfo::Local::~Local() = default;
+
ast::Node* StatementBuilder::Clone(CloneContext*) const {
return nullptr;
}
@@ -3380,7 +3385,7 @@
utils::Vector<BlockInfo::PhiAssignment, 4> worklist;
worklist.Reserve(block_info.phi_assignments.Length());
for (const auto assignment : block_info.phi_assignments) {
- if (GetDefInfo(assignment.phi_id)->num_uses > 0) {
+ if (GetDefInfo(assignment.phi_id)->local->num_uses > 0) {
worklist.Push(assignment);
}
}
@@ -3462,7 +3467,7 @@
TypedExpression expr) {
const auto result_id = inst.result_id();
const auto* def_info = GetDefInfo(result_id);
- if (def_info && def_info->requires_hoisted_def) {
+ if (def_info && def_info->requires_hoisted_var_def) {
auto name = namer_.Name(result_id);
// Emit an assignment of the expression to the hoisted variable.
AddStatement(create<ast::AssignmentStatement>(
@@ -3512,8 +3517,11 @@
}
if (combinatorial_expr.expr != nullptr) {
- if (def_info->requires_hoisted_def || def_info->requires_named_const_def ||
- def_info->num_uses != 1) {
+ // If the expression is combinatorial, then it's not a direct access
+ // of a builtin variable.
+ TINT_ASSERT(Reader, def_info->local.has_value());
+ if (def_info->requires_hoisted_var_def || def_info->requires_named_let_def ||
+ def_info->local->num_uses != 1) {
// Generate a const definition or an assignment to a hoisted definition
// now and later use the const or variable name at the uses of this
// value.
@@ -4739,7 +4747,7 @@
const auto id = special_var.first;
const auto builtin = special_var.second;
const auto* var = def_use_mgr_->GetDef(id);
- def_info_[id] = std::make_unique<DefInfo>(*var, false, 0, index);
+ def_info_[id] = std::make_unique<DefInfo>(index, *var);
++index;
auto& def = def_info_[id];
// Builtins are always defined outside the function.
@@ -4787,7 +4795,7 @@
if ((result_id == 0) || inst.opcode() == SpvOpLabel) {
continue;
}
- def_info_[result_id] = std::make_unique<DefInfo>(inst, true, block_pos, index);
+ def_info_[result_id] = std::make_unique<DefInfo>(index, inst, block_pos);
++index;
auto& info = def_info_[result_id];
@@ -4876,7 +4884,7 @@
const auto id = inst.GetSingleWordInOperand(static_cast<uint32_t>(in_operand_index));
auto* const operand_def = GetDefInfo(id);
if (operand_def) {
- operand_def->requires_named_const_def = true;
+ operand_def->requires_named_let_def = true;
}
};
for (auto& id_def_info_pair : def_info_) {
@@ -4913,18 +4921,21 @@
// Ignores values defined outside this function.
auto record_value_use = [this](uint32_t id, const BlockInfo* block_info) {
if (auto* def_info = GetDefInfo(id)) {
- // Update usage count.
- def_info->num_uses++;
- // Update usage span.
- def_info->first_use_pos = std::min(def_info->first_use_pos, block_info->pos);
- def_info->last_use_pos = std::max(def_info->last_use_pos, block_info->pos);
+ if (def_info->local.has_value()) {
+ auto& local_def = def_info->local.value();
+ // Update usage count.
+ local_def.num_uses++;
+ // Update usage span.
+ local_def.first_use_pos = std::min(local_def.first_use_pos, block_info->pos);
+ local_def.last_use_pos = std::max(local_def.last_use_pos, block_info->pos);
- // Determine whether this ID is defined in a different construct
- // from this use.
- const auto defining_block = block_order_[def_info->block_pos];
- const auto* def_in_construct = GetBlockInfo(defining_block)->construct;
- if (def_in_construct != block_info->construct) {
- def_info->used_in_another_construct = true;
+ // Determine whether this ID is defined in a different construct
+ // from this use.
+ const auto defining_block = block_order_[local_def.block_pos];
+ const auto* def_in_construct = GetBlockInfo(defining_block)->construct;
+ if (def_in_construct != block_info->construct) {
+ local_def.used_in_another_construct = true;
+ }
}
}
};
@@ -4941,8 +4952,8 @@
// in the parent block B.
const auto phi_id = inst.result_id();
- auto* phi_def_info = GetDefInfo(phi_id);
- phi_def_info->is_phi = true;
+ auto& phi_local_def = GetDefInfo(phi_id)->local.value();
+ phi_local_def.is_phi = true;
// Track all the places where we need to mention the variable,
// so we can place its declaration. First, record the location of
@@ -4962,9 +4973,9 @@
// Track where P needs to be in scope. It's not an ordinary use, so don't
// count it as one.
const auto pred_pos = pred_block_info->pos;
- phi_def_info->first_use_pos =
- std::min(phi_def_info->first_use_pos, pred_pos);
- phi_def_info->last_use_pos = std::max(phi_def_info->last_use_pos, pred_pos);
+ phi_local_def.first_use_pos =
+ std::min(phi_local_def.first_use_pos, pred_pos);
+ phi_local_def.last_use_pos = std::max(phi_local_def.last_use_pos, pred_pos);
// Record the assignment that needs to occur at the end
// of the predecessor block.
@@ -4974,7 +4985,7 @@
// Schedule the declaration of the state variable.
const auto* enclosing_construct =
- GetEnclosingScope(phi_def_info->first_use_pos, phi_def_info->last_use_pos);
+ GetEnclosingScope(phi_local_def.first_use_pos, phi_local_def.last_use_pos);
GetBlockInfo(enclosing_construct->begin_id)->phis_needing_state_vars.Push(phi_id);
} else {
inst.ForEachInId([block_info, &record_value_use](const uint32_t* id_ptr) {
@@ -4998,22 +5009,24 @@
for (auto& id_def_info_pair : def_info_) {
const auto def_id = id_def_info_pair.first;
auto* def_info = id_def_info_pair.second.get();
- if (def_info->num_uses == 0) {
- // There is no need to adjust the location of the declaration.
- continue;
- }
- if (!def_info->locally_defined) {
+ if (!def_info->local.has_value()) {
// Never hoist a variable declared at module scope.
// This occurs for builtin variables, which are mapped to module-scope
// private variables.
continue;
}
+ auto& local_def = def_info->local.value();
- const auto* def_in_construct = GetBlockInfo(block_order_[def_info->block_pos])->construct;
+ if (local_def.num_uses == 0) {
+ // There is no need to adjust the location of the declaration.
+ continue;
+ }
+
+ const auto* def_in_construct = GetBlockInfo(block_order_[local_def.block_pos])->construct;
// A definition in the first block of an kIfSelection or kSwitchSelection
// occurs before the branch, and so that definition should count as
// having been defined at the scope of the parent construct.
- if (def_info->block_pos == def_in_construct->begin_pos) {
+ if (local_def.block_pos == def_in_construct->begin_pos) {
if ((def_in_construct->kind == Construct::kIfSelection) ||
(def_in_construct->kind == Construct::kSwitchSelection)) {
def_in_construct = def_in_construct->parent;
@@ -5022,12 +5035,12 @@
// We care about the earliest between the place of definition, and the first
// use of the value.
- const auto first_pos = std::min(def_info->block_pos, def_info->first_use_pos);
- const auto last_use_pos = def_info->last_use_pos;
+ const auto first_pos = std::min(local_def.block_pos, local_def.first_use_pos);
+ const auto last_use_pos = local_def.last_use_pos;
bool should_hoist_to_let = false;
bool should_hoist_to_var = false;
- if (def_info->is_phi) {
+ if (local_def.is_phi) {
// We need to generate a variable, and assignments to that variable in
// all the phi parent blocks.
should_hoist_to_var = true;
@@ -5041,7 +5054,7 @@
// simple heuristic to avoid changing the cost of an operation
// by moving it into or out of a loop, for example.
if ((def_info->storage_class == ast::StorageClass::kInvalid) &&
- def_info->used_in_another_construct) {
+ local_def.used_in_another_construct) {
should_hoist_to_let = true;
}
}
@@ -5050,11 +5063,11 @@
const auto* enclosing_construct = GetEnclosingScope(first_pos, last_use_pos);
if (should_hoist_to_let && (enclosing_construct == def_in_construct)) {
// We can use a plain 'let' declaration.
- def_info->requires_named_const_def = true;
+ def_info->requires_named_let_def = true;
} else {
// We need to make a hoisted variable definition.
// TODO(dneto): Handle non-storable types, particularly pointers.
- def_info->requires_hoisted_def = true;
+ def_info->requires_hoisted_var_def = true;
auto* hoist_to_block = GetBlockInfo(enclosing_construct->begin_id);
hoist_to_block->hoisted_ids.Push(def_id);
}
diff --git a/src/tint/reader/spirv/function.h b/src/tint/reader/spirv/function.h
index ff7336e..53c983c 100644
--- a/src/tint/reader/spirv/function.h
+++ b/src/tint/reader/spirv/function.h
@@ -159,7 +159,7 @@
/// The result IDs that this block is responsible for declaring as a
/// hoisted variable.
- /// @see DefInfo#requires_hoisted_def
+ /// @see DefInfo#requires_hoisted_var_def
utils::Vector<uint32_t, 4> hoisted_ids;
/// A PhiAssignment represents the assignment of a value to the state
@@ -241,57 +241,74 @@
/// function.
/// - certain module-scope builtin variables.
struct DefInfo {
- /// Constructor.
+ /// Constructor for a locally defined value.
+ /// @param index an ordering index for uniqueness.
/// @param def_inst the SPIR-V instruction defining the ID
- /// @param locally_defined true if the defining instruction is in the function
- /// @param block_pos the position of the basic block where the ID is defined
- /// @param index an ordering index for this local definition
- DefInfo(const spvtools::opt::Instruction& def_inst,
- bool locally_defined,
- uint32_t block_pos,
- size_t index);
+ /// @param block_pos the position of the first basic block dominated by the
+ /// definition
+ DefInfo(size_t index, const spvtools::opt::Instruction& def_inst, uint32_t block_pos);
+ /// Constructor for a value defined at module scope.
+ /// @param index an ordering index for uniqueness.
+ /// @param def_inst the SPIR-V instruction defining the ID
+ DefInfo(size_t index, const spvtools::opt::Instruction& def_inst);
+
/// Destructor.
~DefInfo();
- /// The SPIR-V instruction that defines the ID.
- const spvtools::opt::Instruction& inst;
-
- /// True if the definition of this ID is inside the function.
- const bool locally_defined = true;
-
- /// For IDs defined in the function, this is the position of the block
- /// containing the definition of the ID, in function block order.
- /// For IDs defined outside of the function, it is 0.
- /// See method `FunctionEmitter::ComputeBlockOrderAndPositions`
- const uint32_t block_pos = 0;
-
/// An index for uniquely and deterministically ordering all DefInfo records
/// in a function.
const size_t index = 0;
- /// The number of uses of this ID.
- uint32_t num_uses = 0;
+ /// The SPIR-V instruction that defines the ID.
+ const spvtools::opt::Instruction& inst;
- /// The block position of the first use of this ID, or MAX_UINT if it is not
- /// used at all. The "first" ordering is determined by the function block
- /// order. The first use of an ID might be in an OpPhi that precedes the
- /// definition of the ID.
- /// The ID defined by an OpPhi is counted as being "used" in each of its
- /// parent blocks.
- uint32_t first_use_pos = std::numeric_limits<uint32_t>::max();
- /// The block position of the last use of this ID, or 0 if it is not used
- /// at all. The "last" ordering is determined by the function block order.
- /// The ID defined by an OpPhi is counted as being "used" in each of its
- /// parent blocks.
- uint32_t last_use_pos = 0;
+ /// Information about a definition created inside a function.
+ struct Local {
+ /// Constructor.
+ /// @param block_pos the position of the basic block defining the value.
+ explicit Local(uint32_t block_pos);
+ /// Copy constructor.
+ /// @param other the original object to copy from.
+ Local(const Local& other);
+ /// Destructor.
+ ~Local();
- /// Is this value used in a construct other than the one in which it was
- /// defined?
- bool used_in_another_construct = false;
+ /// The position of the basic block defininig the value, in function
+ /// block order.
+ /// See method `FunctionEmitter::ComputeBlockOrderAndPositions` for block
+ /// ordering.
+ const uint32_t block_pos = 0;
- /// True if this ID requires a WGSL 'const' definition, due to context. It
+ /// The number of uses of this ID.
+ uint32_t num_uses = 0;
+
+ /// The block position of the first use of this ID, or MAX_UINT if it is not
+ /// used at all. The "first" ordering is determined by the function block
+ /// order. The first use of an ID might be in an OpPhi that precedes the
+ /// definition of the ID.
+ /// The ID defined by an OpPhi is counted as being "used" in each of its
+ /// parent blocks.
+ uint32_t first_use_pos = std::numeric_limits<uint32_t>::max();
+ /// The block position of the last use of this ID, or 0 if it is not used
+ /// at all. The "last" ordering is determined by the function block order.
+ /// The ID defined by an OpPhi is counted as being "used" in each of its
+ /// parent blocks.
+ uint32_t last_use_pos = 0;
+
+ /// Is this value used in a construct other than the one in which it was
+ /// defined?
+ bool used_in_another_construct = false;
+ /// Is this ID an OpPhi?
+ bool is_phi = false;
+ };
+
+ /// Information about a definition inside the function. Populated if and only
+ /// if the definition actually is inside the function.
+ std::optional<Local> local;
+
+ /// True if this ID requires a WGSL 'let' definition, due to context. It
/// might get one anyway (so this is *not* an if-and-only-if condition).
- bool requires_named_const_def = false;
+ bool requires_named_let_def = false;
/// True if this ID must map to a WGSL variable declaration before the
/// corresponding position of the ID definition in SPIR-V. This compensates
@@ -306,10 +323,7 @@
/// variable.
/// TODO(dneto): This works for constants of storable type, but not, for
/// example, pointers. crbug.com/tint/98
- bool requires_hoisted_def = false;
-
- /// Is this ID an OpPhi?
- bool is_phi = false;
+ bool requires_hoisted_var_def = false;
/// The storage class to use for this value, if it is of pointer type.
/// This is required to carry a storage class override from a storage
@@ -336,14 +350,16 @@
/// @returns the ostream so calls can be chained
inline std::ostream& operator<<(std::ostream& o, const DefInfo& di) {
o << "DefInfo{"
- << " inst.result_id: " << di.inst.result_id()
- << " locally_defined: " << (di.locally_defined ? "true" : "false")
- << " block_pos: " << di.block_pos << " num_uses: " << di.num_uses
- << " first_use_pos: " << di.first_use_pos << " last_use_pos: " << di.last_use_pos
- << " used_in_another_construct: " << (di.used_in_another_construct ? "true" : "false")
- << " requires_named_const_def: " << (di.requires_named_const_def ? "true" : "false")
- << " requires_hoisted_def: " << (di.requires_hoisted_def ? "true" : "false")
- << " is_phi: " << (di.is_phi ? "true" : "false") << "";
+ << " inst.result_id: " << di.inst.result_id();
+ if (di.local.has_value()) {
+ const auto& dil = di.local.value();
+ o << " block_pos: " << dil.block_pos << " num_uses: " << dil.num_uses
+ << " first_use_pos: " << dil.first_use_pos << " last_use_pos: " << dil.last_use_pos
+ << " used_in_another_construct: " << (dil.used_in_another_construct ? "true" : "false")
+ << " is_phi: " << (dil.is_phi ? "true" : "false") << "";
+ }
+ o << " requires_named_let_def: " << (di.requires_named_let_def ? "true" : "false")
+ << " requires_hoisted_var_def: " << (di.requires_hoisted_var_def ? "true" : "false");
if (di.storage_class != ast::StorageClass::kNone) {
o << " sc:" << int(di.storage_class);
}
@@ -616,14 +632,14 @@
/// - When a SPIR-V instruction might use the dynamically computed value
/// only once, but the WGSL code might reference it multiple times.
/// For example, this occurs for the vector operands of OpVectorShuffle.
- /// In this case the definition's DefInfo#requires_named_const_def property
+ /// In this case the definition's DefInfo#requires_named_let_def property
/// is set to true.
/// - When a definition and at least one of its uses are not in the
/// same structured construct.
- /// In this case the definition's DefInfo#requires_named_const_def property
+ /// In this case the definition's DefInfo#requires_named_let_def property
/// is set to true.
/// - When a definition is in a construct that does not enclose all the
- /// uses. In this case the definition's DefInfo#requires_hoisted_def
+ /// uses. In this case the definition's DefInfo#requires_hoisted_var_def
/// property is set to true.
/// Updates the `def_info_` mapping.
void FindValuesNeedingNamedOrHoistedDefinition();
diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc
index 7e33244..8208236 100644
--- a/src/tint/resolver/const_eval.cc
+++ b/src/tint/resolver/const_eval.cc
@@ -182,24 +182,27 @@
return ss.str();
}
-/// Constant inherits from sem::Constant to add an private implementation method for conversion.
-struct Constant : public sem::Constant {
+/// ImplConstant inherits from sem::Constant to add an private implementation method for conversion.
+struct ImplConstant : public sem::Constant {
/// Convert attempts to convert the constant value to the given type. On error, Convert()
/// creates a new diagnostic message and returns a Failure.
- virtual utils::Result<const Constant*> Convert(ProgramBuilder& builder,
- const sem::Type* target_ty,
- const Source& source) const = 0;
+ virtual utils::Result<const ImplConstant*> Convert(ProgramBuilder& builder,
+ const sem::Type* target_ty,
+ const Source& source) const = 0;
};
+/// A result templated with a ImplConstant.
+using ImplResult = utils::Result<const ImplConstant*>;
+
// Forward declaration
-const Constant* CreateComposite(ProgramBuilder& builder,
- const sem::Type* type,
- utils::VectorRef<const sem::Constant*> elements);
+const ImplConstant* CreateComposite(ProgramBuilder& builder,
+ const sem::Type* type,
+ utils::VectorRef<const sem::Constant*> elements);
/// Element holds a single scalar or abstract-numeric value.
/// Element implements the Constant interface.
template <typename T>
-struct Element : Constant {
+struct Element : ImplConstant {
static_assert(!std::is_same_v<UnwrapNumber<T>, T> || std::is_same_v<T, bool>,
"T must be a Number or bool");
@@ -219,16 +222,15 @@
bool AllEqual() const override { return true; }
size_t Hash() const override { return utils::Hash(type, ValueOf(value)); }
- utils::Result<const Constant*> Convert(ProgramBuilder& builder,
- const sem::Type* target_ty,
- const Source& source) const override {
+ ImplResult Convert(ProgramBuilder& builder,
+ const sem::Type* target_ty,
+ const Source& source) const override {
TINT_BEGIN_DISABLE_WARNING(UNREACHABLE_CODE);
if (target_ty == type) {
// If the types are identical, then no conversion is needed.
return this;
}
- bool failed = false;
- auto* res = ZeroTypeDispatch(target_ty, [&](auto zero_to) -> const Constant* {
+ return ZeroTypeDispatch(target_ty, [&](auto zero_to) -> ImplResult {
// `T` is the source type, `value` is the source value.
// `TO` is the target type.
using TO = std::decay_t<decltype(zero_to)>;
@@ -248,7 +250,7 @@
ss << "value " << value << " cannot be represented as ";
ss << "'" << builder.FriendlyName(target_ty) << "'";
builder.Diagnostics().add_error(tint::diag::System::Resolver, ss.str(), source);
- failed = true;
+ return utils::Failure;
} else if constexpr (IsFloatingPoint<UnwrapNumber<TO>>) {
// [x -> floating-point] - number not exactly representable
// https://www.w3.org/TR/WGSL/#floating-point-conversion
@@ -270,11 +272,6 @@
}
return nullptr; // Expression is not constant.
});
- if (failed) {
- // A diagnostic error has been raised, and resolving should abort.
- return utils::Failure;
- }
- return res;
TINT_END_DISABLE_WARNING(UNREACHABLE_CODE);
}
@@ -286,7 +283,7 @@
/// Splat is used for zero-initializers, 'splat' constructors, or constructors where each element is
/// identical. Splat may be of a vector, matrix or array type.
/// Splat implements the Constant interface.
-struct Splat : Constant {
+struct Splat : ImplConstant {
Splat(const sem::Type* t, const sem::Constant* e, size_t n) : type(t), el(e), count(n) {}
~Splat() override = default;
const sem::Type* Type() const override { return type; }
@@ -297,13 +294,13 @@
bool AllEqual() const override { return true; }
size_t Hash() const override { return utils::Hash(type, el->Hash(), count); }
- utils::Result<const Constant*> Convert(ProgramBuilder& builder,
- const sem::Type* target_ty,
- const Source& source) const override {
+ ImplResult Convert(ProgramBuilder& builder,
+ const sem::Type* target_ty,
+ const Source& source) const override {
// Convert the single splatted element type.
// Note: This file is the only place where `sem::Constant`s are created, so this static_cast
// is safe.
- auto conv_el = static_cast<const Constant*>(el)->Convert(
+ auto conv_el = static_cast<const ImplConstant*>(el)->Convert(
builder, sem::Type::ElementOf(target_ty), source);
if (!conv_el) {
return utils::Failure;
@@ -324,7 +321,7 @@
/// If each element is the same type and value, then a Splat would be a more efficient constant
/// implementation. Use CreateComposite() to create the appropriate Constant type.
/// Composite implements the Constant interface.
-struct Composite : Constant {
+struct Composite : ImplConstant {
Composite(const sem::Type* t,
utils::VectorRef<const sem::Constant*> els,
bool all_0,
@@ -341,9 +338,9 @@
bool AllEqual() const override { return false; /* otherwise this should be a Splat */ }
size_t Hash() const override { return hash; }
- utils::Result<const Constant*> Convert(ProgramBuilder& builder,
- const sem::Type* target_ty,
- const Source& source) const override {
+ ImplResult Convert(ProgramBuilder& builder,
+ const sem::Type* target_ty,
+ const Source& source) const override {
// Convert each of the composite element types.
auto* el_ty = sem::Type::ElementOf(target_ty);
utils::Vector<const sem::Constant*, 4> conv_els;
@@ -351,7 +348,7 @@
for (auto* el : elements) {
// Note: This file is the only place where `sem::Constant`s are created, so this
// static_cast is safe.
- auto conv_el = static_cast<const Constant*>(el)->Convert(builder, el_ty, source);
+ auto conv_el = static_cast<const ImplConstant*>(el)->Convert(builder, el_ty, source);
if (!conv_el) {
return utils::Failure;
}
@@ -380,30 +377,30 @@
/// CreateElement constructs and returns an Element<T>.
template <typename T>
-const Constant* CreateElement(ProgramBuilder& builder, const sem::Type* t, T v) {
+const ImplConstant* CreateElement(ProgramBuilder& builder, const sem::Type* t, T v) {
return builder.create<Element<T>>(t, v);
}
/// ZeroValue returns a Constant for the zero-value of the type `type`.
-const Constant* ZeroValue(ProgramBuilder& builder, const sem::Type* type) {
+const ImplConstant* ZeroValue(ProgramBuilder& builder, const sem::Type* type) {
return Switch(
type, //
- [&](const sem::Vector* v) -> const Constant* {
+ [&](const sem::Vector* v) -> const ImplConstant* {
auto* zero_el = ZeroValue(builder, v->type());
return builder.create<Splat>(type, zero_el, v->Width());
},
- [&](const sem::Matrix* m) -> const Constant* {
+ [&](const sem::Matrix* m) -> const ImplConstant* {
auto* zero_el = ZeroValue(builder, m->ColumnType());
return builder.create<Splat>(type, zero_el, m->columns());
},
- [&](const sem::Array* a) -> const Constant* {
+ [&](const sem::Array* a) -> const ImplConstant* {
if (auto* zero_el = ZeroValue(builder, a->ElemType())) {
return builder.create<Splat>(type, zero_el, a->Count());
}
return nullptr;
},
- [&](const sem::Struct* s) -> const Constant* {
- std::unordered_map<const sem::Type*, const Constant*> zero_by_type;
+ [&](const sem::Struct* s) -> const ImplConstant* {
+ std::unordered_map<const sem::Type*, const ImplConstant*> zero_by_type;
utils::Vector<const sem::Constant*, 4> zeros;
zeros.Reserve(s->Members().size());
for (auto* member : s->Members()) {
@@ -420,8 +417,8 @@
}
return CreateComposite(builder, s, std::move(zeros));
},
- [&](Default) -> const Constant* {
- return ZeroTypeDispatch(type, [&](auto zero) -> const Constant* {
+ [&](Default) -> const ImplConstant* {
+ return ZeroTypeDispatch(type, [&](auto zero) -> const ImplConstant* {
return CreateElement(builder, type, zero);
});
});
@@ -467,9 +464,9 @@
/// CreateComposite is used to construct a constant of a vector, matrix or array type.
/// CreateComposite examines the element values and will return either a Composite or a Splat,
/// depending on the element types and values.
-const Constant* CreateComposite(ProgramBuilder& builder,
- const sem::Type* type,
- utils::VectorRef<const sem::Constant*> elements) {
+const ImplConstant* CreateComposite(ProgramBuilder& builder,
+ const sem::Type* type,
+ utils::VectorRef<const sem::Constant*> elements) {
if (elements.IsEmpty()) {
return nullptr;
}
@@ -504,10 +501,10 @@
/// transformation function 'f' on each of the most deeply nested elements of 'cs'. Assumes that all
/// input constants `cs` are of the same type.
template <typename F, typename... CONSTANTS>
-const Constant* TransformElements(ProgramBuilder& builder,
- const sem::Type* composite_ty,
- F&& f,
- CONSTANTS&&... cs) {
+ImplResult TransformElements(ProgramBuilder& builder,
+ const sem::Type* composite_ty,
+ F&& f,
+ CONSTANTS&&... cs) {
uint32_t n = 0;
auto* ty = First(cs...)->Type();
auto* el_ty = sem::Type::ElementOf(ty, &n);
@@ -517,8 +514,13 @@
utils::Vector<const sem::Constant*, 8> els;
els.Reserve(n);
for (uint32_t i = 0; i < n; i++) {
- els.Push(TransformElements(builder, sem::Type::ElementOf(composite_ty), std::forward<F>(f),
- cs->Index(i)...));
+ if (auto el = TransformElements(builder, sem::Type::ElementOf(composite_ty),
+ std::forward<F>(f), cs->Index(i)...)) {
+ els.Push(el.Get());
+
+ } else {
+ return el.Failure();
+ }
}
return CreateComposite(builder, composite_ty, std::move(els));
}
@@ -528,11 +530,11 @@
/// Unlike TransformElements, this function handles the constants being of different types, e.g.
/// vector-scalar, scalar-vector.
template <typename F>
-const Constant* TransformBinaryElements(ProgramBuilder& builder,
- const sem::Type* composite_ty,
- F&& f,
- const sem::Constant* c0,
- const sem::Constant* c1) {
+ImplResult TransformBinaryElements(ProgramBuilder& builder,
+ const sem::Type* composite_ty,
+ F&& f,
+ const sem::Constant* c0,
+ const sem::Constant* c1) {
uint32_t n0 = 0, n1 = 0;
sem::Type::ElementOf(c0->Type(), &n0);
sem::Type::ElementOf(c1->Type(), &n1);
@@ -551,9 +553,13 @@
}
return c->Index(i);
};
- els.Push(TransformBinaryElements(builder, sem::Type::ElementOf(composite_ty),
- std::forward<F>(f), nested_or_self(c0, n0),
- nested_or_self(c1, n1)));
+ if (auto el = TransformBinaryElements(builder, sem::Type::ElementOf(composite_ty),
+ std::forward<F>(f), nested_or_self(c0, n0),
+ nested_or_self(c1, n1))) {
+ els.Push(el.Get());
+ } else {
+ return el.Failure();
+ }
}
return CreateComposite(builder, composite_ty, std::move(els));
}
@@ -703,7 +709,7 @@
}
auto ConstEval::AddFunc(const sem::Type* elem_ty) {
- return [=](auto a1, auto a2) -> utils::Result<const Constant*> {
+ return [=](auto a1, auto a2) -> ImplResult {
if (auto r = Add(a1, a2)) {
return CreateElement(builder, elem_ty, r.Get());
}
@@ -712,7 +718,7 @@
}
auto ConstEval::MulFunc(const sem::Type* elem_ty) {
- return [=](auto a1, auto a2) -> utils::Result<const Constant*> {
+ return [=](auto a1, auto a2) -> ImplResult {
if (auto r = Mul(a1, a2)) {
return CreateElement(builder, elem_ty, r.Get());
}
@@ -721,7 +727,7 @@
}
auto ConstEval::Dot2Func(const sem::Type* elem_ty) {
- return [=](auto a1, auto a2, auto b1, auto b2) -> utils::Result<const Constant*> {
+ return [=](auto a1, auto a2, auto b1, auto b2) -> ImplResult {
if (auto r = Dot2(a1, a2, b1, b2)) {
return CreateElement(builder, elem_ty, r.Get());
}
@@ -730,8 +736,7 @@
}
auto ConstEval::Dot3Func(const sem::Type* elem_ty) {
- return [=](auto a1, auto a2, auto a3, auto b1, auto b2,
- auto b3) -> utils::Result<const Constant*> {
+ return [=](auto a1, auto a2, auto a3, auto b1, auto b2, auto b3) -> ImplResult {
if (auto r = Dot3(a1, a2, a3, b1, b2, b3)) {
return CreateElement(builder, elem_ty, r.Get());
}
@@ -740,23 +745,22 @@
}
auto ConstEval::Dot4Func(const sem::Type* elem_ty) {
- return [=](auto a1, auto a2, auto a3, auto a4, auto b1, auto b2, auto b3,
- auto b4) -> utils::Result<const Constant*> {
- if (auto r = Dot4(a1, a2, a3, a4, b1, b2, b3, b4)) {
- return CreateElement(builder, elem_ty, r.Get());
- }
- return utils::Failure;
- };
+ return
+ [=](auto a1, auto a2, auto a3, auto a4, auto b1, auto b2, auto b3, auto b4) -> ImplResult {
+ if (auto r = Dot4(a1, a2, a3, a4, b1, b2, b3, b4)) {
+ return CreateElement(builder, elem_ty, r.Get());
+ }
+ return utils::Failure;
+ };
}
-ConstEval::ConstantResult ConstEval::Literal(const sem::Type* ty,
- const ast::LiteralExpression* literal) {
+ConstEval::Result ConstEval::Literal(const sem::Type* ty, const ast::LiteralExpression* literal) {
return Switch(
literal,
[&](const ast::BoolLiteralExpression* lit) {
return CreateElement(builder, ty, lit->value);
},
- [&](const ast::IntLiteralExpression* lit) -> const Constant* {
+ [&](const ast::IntLiteralExpression* lit) -> ImplResult {
switch (lit->suffix) {
case ast::IntLiteralExpression::Suffix::kNone:
return CreateElement(builder, ty, AInt(lit->value));
@@ -767,7 +771,7 @@
}
return nullptr;
},
- [&](const ast::FloatLiteralExpression* lit) -> const Constant* {
+ [&](const ast::FloatLiteralExpression* lit) -> ImplResult {
switch (lit->suffix) {
case ast::FloatLiteralExpression::Suffix::kNone:
return CreateElement(builder, ty, AFloat(lit->value));
@@ -780,9 +784,8 @@
});
}
-ConstEval::ConstantResult ConstEval::ArrayOrStructCtor(
- const sem::Type* ty,
- utils::VectorRef<const sem::Expression*> args) {
+ConstEval::Result ConstEval::ArrayOrStructCtor(const sem::Type* ty,
+ utils::VectorRef<const sem::Expression*> args) {
if (args.IsEmpty()) {
return ZeroValue(builder, ty);
}
@@ -801,9 +804,9 @@
return CreateComposite(builder, ty, std::move(els));
}
-ConstEval::ConstantResult ConstEval::Conv(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source& source) {
+ConstEval::Result ConstEval::Conv(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source) {
uint32_t el_count = 0;
auto* el_ty = sem::Type::ElementOf(ty, &el_count);
if (!el_ty) {
@@ -821,36 +824,36 @@
return nullptr;
}
-ConstEval::ConstantResult ConstEval::Zero(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*>,
- const Source&) {
+ConstEval::Result ConstEval::Zero(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*>,
+ const Source&) {
return ZeroValue(builder, ty);
}
-ConstEval::ConstantResult ConstEval::Identity(const sem::Type*,
- utils::VectorRef<const sem::Constant*> args,
- const Source&) {
+ConstEval::Result ConstEval::Identity(const sem::Type*,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source&) {
return args[0];
}
-ConstEval::ConstantResult ConstEval::VecSplat(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source&) {
+ConstEval::Result ConstEval::VecSplat(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source&) {
if (auto* arg = args[0]) {
return builder.create<Splat>(ty, arg, static_cast<const sem::Vector*>(ty)->Width());
}
return nullptr;
}
-ConstEval::ConstantResult ConstEval::VecCtorS(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source&) {
+ConstEval::Result ConstEval::VecCtorS(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source&) {
return CreateComposite(builder, ty, args);
}
-ConstEval::ConstantResult ConstEval::VecCtorM(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source&) {
+ConstEval::Result ConstEval::VecCtorM(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source&) {
utils::Vector<const sem::Constant*, 4> els;
for (auto* arg : args) {
auto* val = arg;
@@ -874,9 +877,9 @@
return CreateComposite(builder, ty, std::move(els));
}
-ConstEval::ConstantResult ConstEval::MatCtorS(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source&) {
+ConstEval::Result ConstEval::MatCtorS(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source&) {
auto* m = static_cast<const sem::Matrix*>(ty);
utils::Vector<const sem::Constant*, 4> els;
@@ -891,14 +894,14 @@
return CreateComposite(builder, ty, std::move(els));
}
-ConstEval::ConstantResult ConstEval::MatCtorV(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source&) {
+ConstEval::Result ConstEval::MatCtorV(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source&) {
return CreateComposite(builder, ty, args);
}
-ConstEval::ConstantResult ConstEval::Index(const sem::Expression* obj_expr,
- const sem::Expression* idx_expr) {
+ConstEval::Result ConstEval::Index(const sem::Expression* obj_expr,
+ const sem::Expression* idx_expr) {
auto idx_val = idx_expr->ConstantValue();
if (!idx_val) {
return nullptr;
@@ -926,8 +929,8 @@
return obj_val->Index(static_cast<size_t>(idx));
}
-ConstEval::ConstantResult ConstEval::MemberAccess(const sem::Expression* obj_expr,
- const sem::StructMember* member) {
+ConstEval::Result ConstEval::MemberAccess(const sem::Expression* obj_expr,
+ const sem::StructMember* member) {
auto obj_val = obj_expr->ConstantValue();
if (!obj_val) {
return nullptr;
@@ -935,30 +938,29 @@
return obj_val->Index(static_cast<size_t>(member->Index()));
}
-ConstEval::ConstantResult ConstEval::Swizzle(const sem::Type* ty,
- const sem::Expression* vec_expr,
- utils::VectorRef<uint32_t> indices) {
+ConstEval::Result ConstEval::Swizzle(const sem::Type* ty,
+ const sem::Expression* vec_expr,
+ utils::VectorRef<uint32_t> indices) {
auto* vec_val = vec_expr->ConstantValue();
if (!vec_val) {
return nullptr;
}
if (indices.Length() == 1) {
return vec_val->Index(static_cast<size_t>(indices[0]));
- } else {
- auto values = utils::Transform<4>(
- indices, [&](uint32_t i) { return vec_val->Index(static_cast<size_t>(i)); });
- return CreateComposite(builder, ty, std::move(values));
}
+ auto values = utils::Transform<4>(
+ indices, [&](uint32_t i) { return vec_val->Index(static_cast<size_t>(i)); });
+ return CreateComposite(builder, ty, std::move(values));
}
-ConstEval::ConstantResult ConstEval::Bitcast(const sem::Type*, const sem::Expression*) {
+ConstEval::Result ConstEval::Bitcast(const sem::Type*, const sem::Expression*) {
// TODO(crbug.com/tint/1581): Implement @const intrinsics
return nullptr;
}
-ConstEval::ConstantResult ConstEval::OpComplement(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source&) {
+ConstEval::Result ConstEval::OpComplement(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source&) {
auto transform = [&](const sem::Constant* c) {
auto create = [&](auto i) {
return CreateElement(builder, c->Type(), decltype(i)(~i.value));
@@ -968,9 +970,9 @@
return TransformElements(builder, ty, transform, args[0]);
}
-ConstEval::ConstantResult ConstEval::OpUnaryMinus(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source&) {
+ConstEval::Result ConstEval::OpUnaryMinus(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source&) {
auto transform = [&](const sem::Constant* c) {
auto create = [&](auto i) {
// For signed integrals, avoid C++ UB by not negating the
@@ -993,9 +995,9 @@
return TransformElements(builder, ty, transform, args[0]);
}
-ConstEval::ConstantResult ConstEval::OpNot(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source&) {
+ConstEval::Result ConstEval::OpNot(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source&) {
auto transform = [&](const sem::Constant* c) {
auto create = [&](auto i) { return CreateElement(builder, c->Type(), decltype(i)(!i)); };
return Dispatch_bool(create, c);
@@ -1003,29 +1005,22 @@
return TransformElements(builder, ty, transform, args[0]);
}
-ConstEval::ConstantResult ConstEval::OpPlus(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source& source) {
+ConstEval::Result ConstEval::OpPlus(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source) {
TINT_SCOPED_ASSIGNMENT(current_source, &source);
- auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) -> const Constant* {
- if (auto r = Dispatch_fia_fiu32_f16(AddFunc(c0->Type()), c0, c1)) {
- return r.Get();
- }
- return nullptr;
+ auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
+ return Dispatch_fia_fiu32_f16(AddFunc(c0->Type()), c0, c1);
};
- auto r = TransformBinaryElements(builder, ty, transform, args[0], args[1]);
- if (builder.Diagnostics().contains_errors()) {
- return utils::Failure;
- }
- return r;
+ return TransformBinaryElements(builder, ty, transform, args[0], args[1]);
}
-ConstEval::ConstantResult ConstEval::OpMinus(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source& source) {
+ConstEval::Result ConstEval::OpMinus(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
- auto create = [&](auto i, auto j) -> const Constant* {
+ auto create = [&](auto i, auto j) -> ImplResult {
using NumberT = decltype(i);
NumberT result;
if constexpr (std::is_same_v<NumberT, AInt> || std::is_same_v<NumberT, AFloat>) {
@@ -1034,7 +1029,7 @@
result = r->value;
} else {
AddError(OverflowErrorMessage(i, "-", j), source);
- return nullptr;
+ return utils::Failure;
}
} else {
using T = UnwrapNumber<NumberT>;
@@ -1054,41 +1049,30 @@
return Dispatch_fia_fiu32_f16(create, c0, c1);
};
- auto r = TransformBinaryElements(builder, ty, transform, args[0], args[1]);
- if (builder.Diagnostics().contains_errors()) {
- return utils::Failure;
- }
- return r;
+ return TransformBinaryElements(builder, ty, transform, args[0], args[1]);
}
-ConstEval::ConstantResult ConstEval::OpMultiply(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source& source) {
+ConstEval::Result ConstEval::OpMultiply(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source) {
TINT_SCOPED_ASSIGNMENT(current_source, &source);
- auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) -> const Constant* {
- if (auto r = Dispatch_fia_fiu32_f16(MulFunc(c0->Type()), c0, c1)) {
- return r.Get();
- }
- return nullptr;
+ auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
+ return Dispatch_fia_fiu32_f16(MulFunc(c0->Type()), c0, c1);
};
- auto r = TransformBinaryElements(builder, ty, transform, args[0], args[1]);
- if (builder.Diagnostics().contains_errors()) {
- return utils::Failure;
- }
- return r;
+ return TransformBinaryElements(builder, ty, transform, args[0], args[1]);
}
-ConstEval::ConstantResult ConstEval::OpMultiplyMatVec(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source& source) {
+ConstEval::Result ConstEval::OpMultiplyMatVec(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source) {
TINT_SCOPED_ASSIGNMENT(current_source, &source);
auto* mat_ty = args[0]->Type()->As<sem::Matrix>();
auto* vec_ty = args[1]->Type()->As<sem::Vector>();
auto* elem_ty = vec_ty->type();
auto dot = [&](const sem::Constant* m, size_t row, const sem::Constant* v) {
- utils::Result<const Constant*> result;
+ ImplResult result;
switch (mat_ty->columns()) {
case 2:
result = Dispatch_fa_f32_f16(Dot2Func(elem_ty), //
@@ -1130,16 +1114,16 @@
}
return CreateComposite(builder, ty, result);
}
-ConstEval::ConstantResult ConstEval::OpMultiplyVecMat(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source& source) {
+ConstEval::Result ConstEval::OpMultiplyVecMat(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source) {
TINT_SCOPED_ASSIGNMENT(current_source, &source);
auto* vec_ty = args[0]->Type()->As<sem::Vector>();
auto* mat_ty = args[1]->Type()->As<sem::Matrix>();
auto* elem_ty = vec_ty->type();
auto dot = [&](const sem::Constant* v, const sem::Constant* m, size_t col) {
- utils::Result<const Constant*> result;
+ ImplResult result;
switch (mat_ty->rows()) {
case 2:
result = Dispatch_fa_f32_f16(Dot2Func(elem_ty), //
@@ -1182,9 +1166,9 @@
return CreateComposite(builder, ty, result);
}
-ConstEval::ConstantResult ConstEval::OpMultiplyMatMat(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source& source) {
+ConstEval::Result ConstEval::OpMultiplyMatMat(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source) {
TINT_SCOPED_ASSIGNMENT(current_source, &source);
auto* mat1 = args[0];
auto* mat2 = args[1];
@@ -1196,7 +1180,7 @@
auto m1e = [&](size_t r, size_t c) { return m1->Index(c)->Index(r); };
auto m2e = [&](size_t r, size_t c) { return m2->Index(c)->Index(r); };
- utils::Result<const Constant*> result;
+ ImplResult result;
switch (mat1_ty->columns()) {
case 2:
result = Dispatch_fa_f32_f16(Dot2Func(elem_ty), //
@@ -1247,11 +1231,11 @@
return CreateComposite(builder, ty, result_mat);
}
-ConstEval::ConstantResult ConstEval::OpDivide(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source& source) {
+ConstEval::Result ConstEval::OpDivide(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
- auto create = [&](auto i, auto j) -> const Constant* {
+ auto create = [&](auto i, auto j) -> ImplResult {
using NumberT = decltype(i);
NumberT result;
if constexpr (std::is_same_v<NumberT, AInt> || std::is_same_v<NumberT, AFloat>) {
@@ -1260,7 +1244,7 @@
result = r->value;
} else {
AddError(OverflowErrorMessage(i, "/", j), source);
- return nullptr;
+ return utils::Failure;
}
} else {
using T = UnwrapNumber<NumberT>;
@@ -1288,120 +1272,92 @@
return Dispatch_fia_fiu32_f16(create, c0, c1);
};
- auto r = TransformBinaryElements(builder, ty, transform, args[0], args[1]);
- if (builder.Diagnostics().contains_errors()) {
- return utils::Failure;
- }
- return r;
+ return TransformBinaryElements(builder, ty, transform, args[0], args[1]);
}
-ConstEval::ConstantResult ConstEval::OpEqual(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source&) {
+ConstEval::Result ConstEval::OpEqual(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source&) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
- auto create = [&](auto i, auto j) -> const Constant* {
+ auto create = [&](auto i, auto j) -> ImplResult {
return CreateElement(builder, sem::Type::DeepestElementOf(ty), i == j);
};
return Dispatch_fia_fiu32_f16_bool(create, c0, c1);
};
- auto r = TransformElements(builder, ty, transform, args[0], args[1]);
- if (builder.Diagnostics().contains_errors()) {
- return utils::Failure;
- }
- return r;
+ return TransformElements(builder, ty, transform, args[0], args[1]);
}
-ConstEval::ConstantResult ConstEval::OpNotEqual(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source&) {
+ConstEval::Result ConstEval::OpNotEqual(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source&) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
- auto create = [&](auto i, auto j) -> const Constant* {
+ auto create = [&](auto i, auto j) -> ImplResult {
return CreateElement(builder, sem::Type::DeepestElementOf(ty), i != j);
};
return Dispatch_fia_fiu32_f16_bool(create, c0, c1);
};
- auto r = TransformElements(builder, ty, transform, args[0], args[1]);
- if (builder.Diagnostics().contains_errors()) {
- return utils::Failure;
- }
- return r;
+ return TransformElements(builder, ty, transform, args[0], args[1]);
}
-ConstEval::ConstantResult ConstEval::OpLessThan(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source&) {
+ConstEval::Result ConstEval::OpLessThan(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source&) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
- auto create = [&](auto i, auto j) -> const Constant* {
+ auto create = [&](auto i, auto j) -> ImplResult {
return CreateElement(builder, sem::Type::DeepestElementOf(ty), i < j);
};
return Dispatch_fia_fiu32_f16(create, c0, c1);
};
- auto r = TransformElements(builder, ty, transform, args[0], args[1]);
- if (builder.Diagnostics().contains_errors()) {
- return utils::Failure;
- }
- return r;
+ return TransformElements(builder, ty, transform, args[0], args[1]);
}
-ConstEval::ConstantResult ConstEval::OpGreaterThan(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source&) {
+ConstEval::Result ConstEval::OpGreaterThan(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source&) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
- auto create = [&](auto i, auto j) -> const Constant* {
+ auto create = [&](auto i, auto j) -> ImplResult {
return CreateElement(builder, sem::Type::DeepestElementOf(ty), i > j);
};
return Dispatch_fia_fiu32_f16(create, c0, c1);
};
- auto r = TransformElements(builder, ty, transform, args[0], args[1]);
- if (builder.Diagnostics().contains_errors()) {
- return utils::Failure;
- }
- return r;
+ return TransformElements(builder, ty, transform, args[0], args[1]);
}
-ConstEval::ConstantResult ConstEval::OpLessThanEqual(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source&) {
+ConstEval::Result ConstEval::OpLessThanEqual(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source&) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
- auto create = [&](auto i, auto j) -> const Constant* {
+ auto create = [&](auto i, auto j) -> ImplResult {
return CreateElement(builder, sem::Type::DeepestElementOf(ty), i <= j);
};
return Dispatch_fia_fiu32_f16(create, c0, c1);
};
- auto r = TransformElements(builder, ty, transform, args[0], args[1]);
- if (builder.Diagnostics().contains_errors()) {
- return utils::Failure;
- }
- return r;
+ return TransformElements(builder, ty, transform, args[0], args[1]);
}
-ConstEval::ConstantResult ConstEval::OpGreaterThanEqual(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source&) {
+ConstEval::Result ConstEval::OpGreaterThanEqual(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source&) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
- auto create = [&](auto i, auto j) -> const Constant* {
+ auto create = [&](auto i, auto j) -> ImplResult {
return CreateElement(builder, sem::Type::DeepestElementOf(ty), i >= j);
};
return Dispatch_fia_fiu32_f16(create, c0, c1);
};
- auto r = TransformElements(builder, ty, transform, args[0], args[1]);
- if (builder.Diagnostics().contains_errors()) {
- return utils::Failure;
- }
- return r;
+ return TransformElements(builder, ty, transform, args[0], args[1]);
}
-ConstEval::ConstantResult ConstEval::OpAnd(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source&) {
+ConstEval::Result ConstEval::OpAnd(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source&) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
- auto create = [&](auto i, auto j) -> const Constant* {
+ auto create = [&](auto i, auto j) -> ImplResult {
using T = decltype(i);
T result;
if constexpr (std::is_same_v<T, bool>) {
@@ -1414,18 +1370,14 @@
return Dispatch_ia_iu32_bool(create, c0, c1);
};
- auto r = TransformElements(builder, ty, transform, args[0], args[1]);
- if (builder.Diagnostics().contains_errors()) {
- return utils::Failure;
- }
- return r;
+ return TransformElements(builder, ty, transform, args[0], args[1]);
}
-ConstEval::ConstantResult ConstEval::OpOr(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source&) {
+ConstEval::Result ConstEval::OpOr(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source&) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
- auto create = [&](auto i, auto j) -> const Constant* {
+ auto create = [&](auto i, auto j) -> ImplResult {
using T = decltype(i);
T result;
if constexpr (std::is_same_v<T, bool>) {
@@ -1438,18 +1390,14 @@
return Dispatch_ia_iu32_bool(create, c0, c1);
};
- auto r = TransformElements(builder, ty, transform, args[0], args[1]);
- if (builder.Diagnostics().contains_errors()) {
- return utils::Failure;
- }
- return r;
+ return TransformElements(builder, ty, transform, args[0], args[1]);
}
-ConstEval::ConstantResult ConstEval::OpXor(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source&) {
+ConstEval::Result ConstEval::OpXor(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source&) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
- auto create = [&](auto i, auto j) -> const Constant* {
+ auto create = [&](auto i, auto j) -> const ImplConstant* {
return CreateElement(builder, sem::Type::DeepestElementOf(ty), decltype(i){i ^ j});
};
return Dispatch_ia_iu32(create, c0, c1);
@@ -1462,9 +1410,9 @@
return r;
}
-ConstEval::ConstantResult ConstEval::atan2(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source&) {
+ConstEval::Result ConstEval::atan2(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source&) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
auto create = [&](auto i, auto j) {
return CreateElement(builder, c0->Type(), decltype(i)(std::atan2(i.value, j.value)));
@@ -1474,9 +1422,9 @@
return TransformElements(builder, ty, transform, args[0], args[1]);
}
-ConstEval::ConstantResult ConstEval::clamp(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source&) {
+ConstEval::Result ConstEval::clamp(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source&) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1,
const sem::Constant* c2) {
auto create = [&](auto e, auto low, auto high) {
@@ -1488,17 +1436,13 @@
return TransformElements(builder, ty, transform, args[0], args[1], args[2]);
}
-utils::Result<const sem::Constant*> ConstEval::Convert(const sem::Type* target_ty,
- const sem::Constant* value,
- const Source& source) {
+ConstEval::Result ConstEval::Convert(const sem::Type* target_ty,
+ const sem::Constant* value,
+ const Source& source) {
if (value->Type() == target_ty) {
return value;
}
- auto conv = static_cast<const Constant*>(value)->Convert(builder, target_ty, source);
- if (!conv) {
- return utils::Failure;
- }
- return conv.Get();
+ return static_cast<const ImplConstant*>(value)->Convert(builder, target_ty, source);
}
void ConstEval::AddError(const std::string& msg, const Source& source) const {
diff --git a/src/tint/resolver/const_eval.h b/src/tint/resolver/const_eval.h
index 04e2282..10a4d60 100644
--- a/src/tint/resolver/const_eval.h
+++ b/src/tint/resolver/const_eval.h
@@ -53,12 +53,12 @@
/// * `utils::Failure`. Returned when there was a resolver error. In this situation the method
/// will have already reported a diagnostic error message, and the caller should abort
/// resolving.
- using ConstantResult = utils::Result<const sem::Constant*>;
+ using Result = utils::Result<const sem::Constant*>;
/// Typedef for a constant evaluation function
- using Function = ConstantResult (ConstEval::*)(const sem::Type* result_ty,
- utils::VectorRef<const sem::Constant*>,
- const Source&);
+ using Function = Result (ConstEval::*)(const sem::Type* result_ty,
+ utils::VectorRef<const sem::Constant*>,
+ const Source&);
/// Constructor
/// @param b the program builder
@@ -71,44 +71,43 @@
/// @param ty the target type - must be an array or constructor
/// @param args the input arguments
/// @return the constructed value, or null if the value cannot be calculated
- ConstantResult ArrayOrStructCtor(const sem::Type* ty,
- utils::VectorRef<const sem::Expression*> args);
+ Result ArrayOrStructCtor(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args);
/// @param ty the target type
/// @param expr the input expression
/// @return the bit-cast of the given expression to the given type, or null if the value cannot
/// be calculated
- ConstantResult Bitcast(const sem::Type* ty, const sem::Expression* expr);
+ Result Bitcast(const sem::Type* ty, const sem::Expression* expr);
/// @param obj the object being indexed
/// @param idx the index expression
/// @return the result of the index, or null if the value cannot be calculated
- ConstantResult Index(const sem::Expression* obj, const sem::Expression* idx);
+ Result Index(const sem::Expression* obj, const sem::Expression* idx);
/// @param ty the result type
/// @param lit the literal AST node
/// @return the constant value of the literal
- ConstantResult Literal(const sem::Type* ty, const ast::LiteralExpression* lit);
+ Result Literal(const sem::Type* ty, const ast::LiteralExpression* lit);
/// @param obj the object being accessed
/// @param member the member
/// @return the result of the member access, or null if the value cannot be calculated
- ConstantResult MemberAccess(const sem::Expression* obj, const sem::StructMember* member);
+ Result MemberAccess(const sem::Expression* obj, const sem::StructMember* member);
/// @param ty the result type
/// @param vector the vector being swizzled
/// @param indices the swizzle indices
/// @return the result of the swizzle, or null if the value cannot be calculated
- ConstantResult Swizzle(const sem::Type* ty,
- const sem::Expression* vector,
- utils::VectorRef<uint32_t> indices);
+ Result Swizzle(const sem::Type* ty,
+ const sem::Expression* vector,
+ utils::VectorRef<uint32_t> indices);
/// Convert the `value` to `target_type`
/// @param ty the result type
/// @param value the value being converted
/// @param source the source location of the conversion
/// @return the converted value, or null if the value cannot be calculated
- ConstantResult Convert(const sem::Type* ty, const sem::Constant* value, const Source& source);
+ Result Convert(const sem::Type* ty, const sem::Constant* value, const Source& source);
////////////////////////////////////////////////////////////////////////////////////////////////
// Constant value evaluation methods, to be indirectly called via the intrinsic table
@@ -119,72 +118,72 @@
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the converted value, or null if the value cannot be calculated
- ConstantResult Conv(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source& source);
+ Result Conv(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source);
/// Zero value type constructor
/// @param ty the result type
/// @param args the input arguments (no arguments provided)
/// @param source the source location of the conversion
/// @return the constructed value, or null if the value cannot be calculated
- ConstantResult Zero(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source& source);
+ Result Zero(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source);
/// Identity value type constructor
/// @param ty the result type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the constructed value, or null if the value cannot be calculated
- ConstantResult Identity(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source& source);
+ Result Identity(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source);
/// Vector splat constructor
/// @param ty the vector type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the constructed value, or null if the value cannot be calculated
- ConstantResult VecSplat(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source& source);
+ Result VecSplat(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source);
/// Vector constructor using scalars
/// @param ty the vector type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the constructed value, or null if the value cannot be calculated
- ConstantResult VecCtorS(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source& source);
+ Result VecCtorS(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source);
/// Vector constructor using a mix of scalars and smaller vectors
/// @param ty the vector type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the constructed value, or null if the value cannot be calculated
- ConstantResult VecCtorM(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source& source);
+ Result VecCtorM(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source);
/// Matrix constructor using scalar values
/// @param ty the matrix type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the constructed value, or null if the value cannot be calculated
- ConstantResult MatCtorS(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source& source);
+ Result MatCtorS(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source);
/// Matrix constructor using column vectors
/// @param ty the matrix type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the constructed value, or null if the value cannot be calculated
- ConstantResult MatCtorV(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source& source);
+ Result MatCtorV(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source);
////////////////////////////////////////////////////////////////////////////
// Unary Operators
@@ -195,27 +194,27 @@
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
- ConstantResult OpComplement(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source& source);
+ Result OpComplement(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source);
/// Unary minus operator '-'
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
- ConstantResult OpUnaryMinus(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source& source);
+ Result OpUnaryMinus(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source);
/// Unary not operator '!'
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
- ConstantResult OpNot(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source& source);
+ Result OpNot(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source);
////////////////////////////////////////////////////////////////////////////
// Binary Operators
@@ -226,142 +225,142 @@
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
- ConstantResult OpPlus(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source& source);
+ Result OpPlus(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source);
/// Minus operator '-'
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
- ConstantResult OpMinus(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source& source);
+ Result OpMinus(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source);
/// Multiply operator '*' for the same type on the LHS and RHS
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
- ConstantResult OpMultiply(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source& source);
+ Result OpMultiply(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source);
/// Multiply operator '*' for matCxR<T> * vecC<T>
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
- ConstantResult OpMultiplyMatVec(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source& source);
+ Result OpMultiplyMatVec(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source);
/// Multiply operator '*' for vecR<T> * matCxR<T>
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
- ConstantResult OpMultiplyVecMat(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source& source);
+ Result OpMultiplyVecMat(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source);
/// Multiply operator '*' for matKxR<T> * matCxK<T>
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
- ConstantResult OpMultiplyMatMat(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source& source);
+ Result OpMultiplyMatMat(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source);
/// Divide operator '/'
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
- ConstantResult OpDivide(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source& source);
+ Result OpDivide(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source);
/// Equality operator '=='
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
- ConstantResult OpEqual(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source& source);
+ Result OpEqual(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source);
/// Inequality operator '!='
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
- ConstantResult OpNotEqual(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source& source);
+ Result OpNotEqual(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source);
/// Less than operator '<'
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
- ConstantResult OpLessThan(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source& source);
+ Result OpLessThan(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source);
/// Greater than operator '>'
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
- ConstantResult OpGreaterThan(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source& source);
+ Result OpGreaterThan(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source);
/// Less than or equal operator '<='
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
- ConstantResult OpLessThanEqual(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source& source);
+ Result OpLessThanEqual(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source);
/// Greater than or equal operator '>='
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
- ConstantResult OpGreaterThanEqual(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source& source);
+ Result OpGreaterThanEqual(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source);
/// Bitwise and operator '&'
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
- ConstantResult OpAnd(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source& source);
+ Result OpAnd(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source);
/// Bitwise or operator '|'
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
- ConstantResult OpOr(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source& source);
+ Result OpOr(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source);
/// Bitwise xor operator '^'
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
- ConstantResult OpXor(const sem::Type* ty,
+ Result OpXor(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
@@ -374,18 +373,18 @@
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
- ConstantResult atan2(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source& source);
+ Result atan2(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source);
/// clamp builtin
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
- ConstantResult clamp(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source& source);
+ Result clamp(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source);
private:
/// Adds the given error message to the diagnostics
diff --git a/src/tint/utils/result.h b/src/tint/utils/result.h
index 6a14352..baca65f 100644
--- a/src/tint/utils/result.h
+++ b/src/tint/utils/result.h
@@ -17,6 +17,7 @@
#include <ostream>
#include <variant>
+
#include "src/tint/debug.h"
namespace tint::utils {
@@ -50,6 +51,20 @@
Result(const FAILURE_TYPE& failure) // NOLINT(runtime/explicit):
: value{failure} {}
+ /// Copy constructor with success / failure casting
+ /// @param other the Result to copy
+ template <typename S,
+ typename F,
+ typename = std::void_t<decltype(SUCCESS_TYPE{std::declval<S>()}),
+ decltype(FAILURE_TYPE{std::declval<F>()})>>
+ Result(const Result<S, F>& other) { // NOLINT(runtime/explicit):
+ if (other) {
+ value = SUCCESS_TYPE{other.Get()};
+ } else {
+ value = FAILURE_TYPE{other.Failure()};
+ }
+ }
+
/// @returns true if the result was a success
operator bool() const {
Validate();
diff --git a/src/tint/utils/result_test.cc b/src/tint/utils/result_test.cc
index ce125f4..6614028 100644
--- a/src/tint/utils/result_test.cc
+++ b/src/tint/utils/result_test.cc
@@ -51,5 +51,17 @@
EXPECT_EQ(r.Failure(), "oh noes!");
}
+TEST(ResultTest, ValueCast) {
+ struct X {};
+ struct Y : X {};
+
+ Y* y = nullptr;
+ auto r_y = Result<Y*>{y};
+ auto r_x = Result<X*>{r_y};
+
+ (void)r_x;
+ (void)r_y;
+}
+
} // namespace
} // namespace tint::utils