blob: 1a89a685f57e91e139bcd6157bfaf68e64d65e37 [file] [log] [blame]
// Copyright 2024 The Dawn & Tint Authors
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "src/tint/lang/core/ir/evaluator.h"
#include "src/tint/lang/core/ir/constant.h"
#include "src/tint/lang/core/type/matrix.h"
#include "src/tint/utils/rtti/switch.h"
namespace tint::core::ir {
namespace eval {

/// Evaluates the first result of an instruction to a constant.
/// @param b the IR builder used to materialize the resulting constant
/// @param inst the instruction whose result at index 0 is evaluated
/// @returns the constant, nullptr if the result is not a compile-time value, or a Failure
Result<core::ir::Constant*> Eval(core::ir::Builder& b, core::ir::Instruction* inst) {
    // Only the instruction's first result (index 0) is considered.
    core::ir::Value* first_result = inst->Result(0);
    return Eval(b, first_result);
}

/// Evaluates a value to a constant using a fresh, single-use Evaluator.
/// @param b the IR builder used to materialize the resulting constant
/// @param val the value to evaluate
/// @returns the constant, nullptr if @p val is not a compile-time value, or a Failure
Result<core::ir::Constant*> Eval(core::ir::Builder& b, core::ir::Value* val) {
    return ir::Evaluator{b}.Evaluate(val);
}

}  // namespace eval
/// Creates an evaluator over the module owned by @p builder.
/// The constant evaluator is handed the module's constant-value manager and this evaluator's
/// diagnostic list (presumably so its own errors are collected alongside ours).
Evaluator::Evaluator(ir::Builder& builder)
    : b_(builder), const_eval_(builder.ir.constant_values, diagnostics_) {}

/// Out-of-line so that members with incomplete types in the header can be destroyed here.
Evaluator::~Evaluator() = default;
/// Evaluates @p src and wraps the folded value in a new core::ir::Constant.
/// @param src the value to evaluate
/// @returns the new constant; nullptr if @p src is not a compile-time value; or a Failure
///          carrying the evaluator's diagnostics if evaluation raised an error
Result<core::ir::Constant*> Evaluator::Evaluate(core::ir::Value* src) {
    auto evaluated = EvalValue(src);
    if (evaluated != Success) {
        return Failure(diagnostics_);
    }
    if (const auto* folded = evaluated.Get()) {
        return b_.Constant(folded);
    }
    // Not an error: the value simply could not be folded.
    return nullptr;
}
/// Appends a new diagnostic located at @p src to this evaluator's diagnostic list.
/// @param src the source location to attach to the diagnostic
/// @returns a reference to the stored diagnostic, so callers can stream the message into it
diag::Diagnostic& Evaluator::AddError(Source src) {
    diag::Diagnostic error;
    error.source = src;
    auto& stored = diagnostics_.Add(error);
    return stored;
}
/// @param val the instruction to look up
/// @returns the source location the IR module recorded for @p val
Source Evaluator::SourceOf(core::ir::Instruction* val) {
    auto& module = b_.ir;
    return module.SourceOf(val);
}
/// Attempts to fold @p val into a compile-time constant value.
/// @param val the value to fold; must be either a core::ir::Constant or an instruction result
/// @returns the constant's value for a core::ir::Constant; the folded value for a result of a
///          supported instruction; nullptr when the producing instruction is not handled; or a
///          Failure if folding raised an error
Evaluator::EvalResult Evaluator::EvalValue(core::ir::Value* val) {
    // Dispatches on the instruction that produced an instruction result.
    auto eval_instruction = [&](core::ir::Instruction* inst) {
        return tint::Switch(
            inst,  //
            [&](core::ir::Bitcast* bitcast) { return EvalBitcast(bitcast); },
            [&](core::ir::Access* access) { return EvalAccess(access); },
            [&](core::ir::Construct* construct) { return EvalConstruct(construct); },
            [&](core::ir::Convert* convert) { return EvalConvert(convert); },
            [&](core::ir::CoreBinary* binary) { return EvalBinary(binary); },
            [&](core::ir::CoreBuiltinCall* call) { return EvalCoreBuiltinCall(call); },
            [&](core::ir::CoreUnary* unary) { return EvalUnary(unary); },
            [&](core::ir::Swizzle* swizzle) { return EvalSwizzle(swizzle); },  //
            [&](Default) {
                // Any other instruction ends the attempt to fold. This is not an error; the
                // value is just not known here.
                return nullptr;
            });
    };

    return tint::Switch(
        val,  //
        [&](core::ir::Constant* constant) { return constant->Value(); },
        [&](core::ir::InstructionResult* result) {
            return eval_instruction(result->Instruction());
        },
        TINT_ICE_ON_NO_MATCH);
}
/// Folds a bitcast instruction.
/// @param bc the bitcast instruction
/// @returns the reinterpreted constant, nullptr if the operand is not a compile-time value, or a
///          Failure if the operand or the bitcast itself failed to evaluate
Evaluator::EvalResult Evaluator::EvalBitcast(core::ir::Bitcast* bc) {
    auto operand = EvalValue(bc->Val());
    if (operand != Success) {
        return operand;
    }
    const auto* input = operand.Get();
    if (input == nullptr) {
        // The operand is unknown, so the bitcast result is too.
        return nullptr;
    }
    auto* target_ty = bc->Result(0)->Type();
    auto folded = const_eval_.bitcast(target_ty, Vector{input}, SourceOf(bc));
    if (folded != Success) {
        return Failure();
    }
    return folded.Get();
}
/// Folds an access instruction by evaluating its object and applying each index in turn.
/// @param a the access instruction
/// @returns the accessed element; nullptr if the object or any index is not a compile-time
///          value; or a Failure if evaluation or indexing (e.g. an out-of-bounds index) failed
Evaluator::EvalResult Evaluator::EvalAccess(core::ir::Access* a) {
    auto obj_res = EvalValue(a->Object());
    if (obj_res != Success) {
        return obj_res;
    }
    auto* obj = obj_res.Get();
    // Check if the object could be evaluated
    if (!obj) {
        return nullptr;
    }
    for (auto* idx : a->Indices()) {
        auto val = EvalValue(idx);
        if (val != Success) {
            return val;
        }
        // Check if the index could be evaluated
        if (!val.Get()) {
            return nullptr;
        }
        // Index() takes the type of the object being indexed, not the final result type. Per
        // the const-eval API it uses that type to bounds-check the index. Passing the access's
        // result type skipped or mis-sized that check, and was wrong for every intermediate
        // step of a multi-index access.
        auto res = const_eval_.Index(obj, obj->Type(), val.Get(), SourceOf(a));
        if (res != Success) {
            return Failure();
        }
        obj = res.Get();
        // Guard the next iteration's obj->Type(); an unknown element means an unknown result.
        if (!obj) {
            return nullptr;
        }
    }
    return obj;
}
/// Folds a construct instruction into a constant.
///
/// All arguments are evaluated first. If any of them is not a compile-time value, the construct
/// isn't either and nullptr is returned. Arrays and structures are built directly. Vectors and
/// matrices are resolved through the intrinsic table's constructor/conversion overloads. Scalars
/// accept either no arguments (the zero value) or exactly one argument of the same type
/// (identity).
/// @param c the construct instruction
/// @returns the constructed constant, nullptr if it cannot be evaluated, or a Failure if an
///          error was raised (every error path adds a diagnostic before failing)
Evaluator::EvalResult Evaluator::EvalConstruct(core::ir::Construct* c) {
    auto result_ty = c->Result(0)->Type();

    Vector<const core::type::Type*, 4> arg_types;
    arg_types.Reserve(c->Args().Length());
    Vector<const core::constant::Value*, 4> arg_values;
    arg_values.Reserve(c->Args().Length());
    for (auto* arg : c->Args()) {
        arg_types.Push(arg->Type());
        auto val = EvalValue(arg);
        if (val != Success) {
            return val;
        }
        // Check if the value could be evaluated
        if (!val.Get()) {
            return nullptr;
        }
        arg_values.Push(val.Get());
    }

    // Only needed for vector / matrix constructors. It is built after the arguments are known
    // to be constant, so non-constant constructs do no table work.
    auto table = core::intrinsic::Table<core::intrinsic::Dialect>(b_.ir.Types(), b_.ir.symbols);

    // Resolves a vector or matrix constructor through the intrinsic table and invokes its
    // constant-evaluation function.
    auto mat_vec = [&](const core::type::Type* type,
                       core::intrinsic::CtorConv intrinsic) -> constant::Eval::Result {
        auto op =
            table.Lookup(intrinsic, Vector{type}, arg_types, core::EvaluationStage::kOverride);
        if (op != Success) {
            AddError(SourceOf(c)) << "unable to find intrinsic for construct: " << op.Failure();
            return constant::Eval::Error();
        }
        if (!op->const_eval_fn) {
            AddError(SourceOf(c)) << "unhandled type constructor";
            return constant::Eval::Error();
        }
        auto r = (const_eval_.*op->const_eval_fn)(result_ty, arg_values, SourceOf(c));
        if (r != Success) {
            return constant::Eval::Error();
        }
        return r.Get();
    };

    // Dispatch to the appropriate const eval function.
    auto r = tint::Switch(
        result_ty,  //
        [&](const core::type::Array*) -> constant::Eval::Result {
            return const_eval_.ArrayOrStructCtor(result_ty, arg_values);
        },
        [&](const core::type::Struct*) -> constant::Eval::Result {
            return const_eval_.ArrayOrStructCtor(result_ty, arg_values);
        },
        [&](const core::type::Vector* vec) -> constant::Eval::Result {
            return mat_vec(vec->Type(), core::intrinsic::VectorCtorConv(vec->Width()));
        },
        [&](const core::type::Matrix* mat) -> constant::Eval::Result {
            return mat_vec(mat->Type(),
                           core::intrinsic::MatrixCtorConv(mat->Columns(), mat->Rows()));
        },
        [&](Default) -> constant::Eval::Result {
            // An error diagnostic must be paired with a Failure. Returning nullptr here used to
            // report "not constant" (success) while leaving a stray error in diagnostics_.
            if (!result_ty->Is<core::type::Scalar>()) {
                AddError(SourceOf(c)) << "unhandled type constructor";
                return constant::Eval::Error();
            }
            // A scalar construct with no arguments produces the zero value. Indexing
            // arg_values[0] here would be out of bounds.
            if (arg_values.IsEmpty()) {
                return const_eval_.Zero(result_ty, arg_values, SourceOf(c));
            }
            // Otherwise this must be an identity constructor.
            if (arg_values.Length() != 1 || arg_values[0]->Type() != result_ty) {
                AddError(SourceOf(c)) << "invalid type constructor";
                return constant::Eval::Error();
            }
            return const_eval_.Identity(result_ty, arg_values, SourceOf(c));
        });
    if (r != Success) {
        return Failure();
    }
    return r.Get();
}
/// Folds a value-conversion instruction.
/// @param c the convert instruction; its first argument is the value being converted
/// @returns the converted constant, nullptr if the input is not a compile-time value, or a
///          Failure if evaluation or the conversion failed
Evaluator::EvalResult Evaluator::EvalConvert(core::ir::Convert* c) {
    core::ir::Value* input = c->Args()[0];
    auto from = EvalValue(input);
    if (from != Success) {
        return from;
    }
    if (from.Get() == nullptr) {
        // Unknown input, so there is nothing to convert.
        return nullptr;
    }
    auto converted = const_eval_.Convert(c->Result(0)->Type(), from.Get(), SourceOf(c));
    if (converted != Success) {
        return Failure();
    }
    return converted.Get();
}
/// Folds a swizzle instruction.
/// @param s the swizzle instruction
/// @returns the swizzled constant, nullptr if the swizzled object is not a compile-time value,
///          or a Failure if evaluation or the swizzle failed
Evaluator::EvalResult Evaluator::EvalSwizzle(core::ir::Swizzle* s) {
    auto source = EvalValue(s->Object());
    if (source != Success) {
        return source;
    }
    const auto* vec = source.Get();
    if (vec == nullptr) {
        return nullptr;
    }
    auto picked = const_eval_.Swizzle(s->Result(0)->Type(), vec, s->Indices());
    if (picked != Success) {
        return Failure();
    }
    return picked.Get();
}
/// Folds a core unary instruction.
/// The matching overload is resolved first. A missing overload, or an overload without a
/// constant-evaluation function, is reported as an error.
/// @param u the unary instruction
/// @returns the folded constant, nullptr if the operand is not a compile-time value, or a
///          Failure on error
Evaluator::EvalResult Evaluator::EvalUnary(core::ir::CoreUnary* u) {
    intrinsic::Context context{u->TableData(), b_.ir.Types(), b_.ir.symbols};
    auto overload = core::intrinsic::LookupUnary(context, u->Op(), u->Val()->Type(),
                                                 core::EvaluationStage::kOverride);
    if (overload != Success) {
        AddError(SourceOf(u)) << overload.Failure().Plain();
        return Failure();
    }
    auto const_eval_fn = overload->const_eval_fn;
    if (!const_eval_fn) {
        AddError(SourceOf(u)) << "invalid unary expression";
        return Failure();
    }
    auto val = EvalValue(u->Val());
    if (val != Success) {
        // Propagate the operand's failure unchanged, as every other Eval* helper does,
        // rather than replacing it with a fresh, empty Failure.
        return val;
    }
    // Check if the value could be evaluated
    if (!val.Get()) {
        return nullptr;
    }
    auto r = (const_eval_.*const_eval_fn)(u->Result(0)->Type(), Vector{val.Get()}, SourceOf(u));
    if (r != Success) {
        return Failure();
    }
    return r.Get();
}
/// Folds a core binary instruction.
/// The overload is resolved from the operand types before the operands are evaluated. A
/// missing overload, or one without a constant-evaluation function, is reported as an error.
/// @param cb the binary instruction
/// @returns the folded constant, nullptr if either operand is not a compile-time value, or a
///          Failure on error
Evaluator::EvalResult Evaluator::EvalBinary(core::ir::CoreBinary* cb) {
    intrinsic::Context ctx{cb->TableData(), b_.ir.Types(), b_.ir.symbols};
    auto overload =
        core::intrinsic::LookupBinary(ctx, cb->Op(), cb->LHS()->Type(), cb->RHS()->Type(),
                                      core::EvaluationStage::kOverride, /* is_compound */ false);
    if (overload != Success) {
        AddError(SourceOf(cb)) << overload.Failure().Plain();
        return Failure();
    }
    auto fold = overload->const_eval_fn;
    if (fold == nullptr) {
        AddError(SourceOf(cb)) << "invalid binary expression";
        return Failure();
    }

    // Evaluate LHS then RHS, stopping at the first operand that fails or is not constant.
    core::ir::Value* operands[] = {cb->LHS(), cb->RHS()};
    Vector<const core::constant::Value*, 2> values;
    for (auto* operand : operands) {
        auto res = EvalValue(operand);
        if (res != Success) {
            return res;
        }
        if (!res.Get()) {
            return nullptr;
        }
        values.Push(res.Get());
    }

    auto folded = (const_eval_.*fold)(cb->Result(0)->Type(), values, SourceOf(cb));
    if (folded != Success) {
        return Failure();
    }
    return folded.Get();
}
/// Folds a call to a core builtin function.
///
/// Every argument must evaluate to a constant first. If the resolved overload has no
/// constant-evaluation function (e.g. `dpdx`), that is not an error: the call is simply not
/// folded and nullptr is returned.
/// @param c the builtin call instruction
/// @returns the folded constant, nullptr if the call cannot be folded, or a Failure on error
Evaluator::EvalResult Evaluator::EvalCoreBuiltinCall(core::ir::CoreBuiltinCall* c) {
    intrinsic::Context context{c->TableData(), b_.ir.Types(), b_.ir.symbols};
    // Small inline capacity, matching EvalConstruct. With capacity 0, the Reserve() calls
    // below always forced a heap allocation.
    Vector<const core::type::Type*, 4> arg_types;
    arg_types.Reserve(c->Args().Length());
    Vector<const core::constant::Value*, 4> args;
    args.Reserve(c->Args().Length());
    for (auto* arg : c->Args()) {
        arg_types.Push(arg->Type());
        auto val = EvalValue(arg);
        if (val != Success) {
            return val;
        }
        // Check if the value could be evaluated
        if (!val.Get()) {
            return nullptr;
        }
        args.Push(val.Get());
    }
    auto overload = core::intrinsic::LookupFn(context, c->FriendlyName().c_str(), c->FuncId(),
                                              Empty, arg_types, core::EvaluationStage::kOverride);
    if (overload != Success) {
        AddError(SourceOf(c)) << overload.Failure();
        return Failure();
    }
    // Some builtins (e.g. `dpdx`) have no `@const` evaluation function. Reaching one is not an
    // evaluation failure; the call just can't be folded, so report "not constant".
    auto const_eval_fn = overload->const_eval_fn;
    if (!const_eval_fn) {
        return nullptr;
    }
    auto r = (const_eval_.*const_eval_fn)(c->Result(0)->Type(), args, SourceOf(c));
    if (r != Success) {
        return Failure();
    }
    return r.Get();
}
} // namespace tint::core::ir