blob: 27fe8b48d22466716917cc4634d937c3e1da3572 [file] [log] [blame]
// Copyright 2020 The Tint Authors.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/lang/spirv/writer/ast_printer/builder.h"
#include <algorithm>
#include <limits>
#include <utility>
#include "spirv/unified1/GLSL.std.450.h"
#include "src/tint/lang/core/constant/value.h"
#include "src/tint/lang/core/fluent_types.h"
#include "src/tint/lang/core/type/array.h"
#include "src/tint/lang/core/type/atomic.h"
#include "src/tint/lang/core/type/depth_multisampled_texture.h"
#include "src/tint/lang/core/type/depth_texture.h"
#include "src/tint/lang/core/type/multisampled_texture.h"
#include "src/tint/lang/core/type/reference.h"
#include "src/tint/lang/core/type/sampled_texture.h"
#include "src/tint/lang/core/type/texture_dimension.h"
#include "src/tint/lang/core/type/vector.h"
#include "src/tint/lang/wgsl/ast/call_statement.h"
#include "src/tint/lang/wgsl/ast/id_attribute.h"
#include "src/tint/lang/wgsl/ast/internal_attribute.h"
#include "src/tint/lang/wgsl/ast/transform/add_block_attribute.h"
#include "src/tint/lang/wgsl/ast/traverse_expressions.h"
#include "src/tint/lang/wgsl/helpers/append_vector.h"
#include "src/tint/lang/wgsl/helpers/check_supported_extensions.h"
#include "src/tint/lang/wgsl/sem/builtin.h"
#include "src/tint/lang/wgsl/sem/call.h"
#include "src/tint/lang/wgsl/sem/function.h"
#include "src/tint/lang/wgsl/sem/load.h"
#include "src/tint/lang/wgsl/sem/materialize.h"
#include "src/tint/lang/wgsl/sem/member_accessor_expression.h"
#include "src/tint/lang/wgsl/sem/module.h"
#include "src/tint/lang/wgsl/sem/statement.h"
#include "src/tint/lang/wgsl/sem/struct.h"
#include "src/tint/lang/wgsl/sem/switch_statement.h"
#include "src/tint/lang/wgsl/sem/value_constructor.h"
#include "src/tint/lang/wgsl/sem/value_conversion.h"
#include "src/tint/lang/wgsl/sem/variable.h"
#include "src/tint/utils/containers/map.h"
#include "src/tint/utils/macros/compiler.h"
#include "src/tint/utils/macros/defer.h"
#include "src/tint/utils/text/string_stream.h"
using namespace tint::core::fluent_types; // NOLINT
namespace tint::spirv::writer {
namespace {
/// Name string of the "GLSL.std.450" SPIR-V extended instruction set (matches the included
/// spirv/unified1/GLSL.std.450.h). NOTE(review): presumably consumed by an OpExtInstImport
/// emitted outside this chunk — confirm at the use site.
const char kGLSLstd450[] = "GLSL.std.450";
/// Maps a WGSL pipeline stage onto the corresponding SPIR-V execution model.
/// @param stage the pipeline stage of an entry point
/// @returns the SPIR-V execution model, or SpvExecutionModelMax for ast::PipelineStage::kNone
///          (callers treat that value as "unknown stage")
uint32_t pipeline_stage_to_execution_model(ast::PipelineStage stage) {
    switch (stage) {
        case ast::PipelineStage::kFragment:
            return SpvExecutionModelFragment;
        case ast::PipelineStage::kVertex:
            return SpvExecutionModelVertex;
        case ast::PipelineStage::kCompute:
            return SpvExecutionModelGLCompute;
        case ast::PipelineStage::kNone:
            return SpvExecutionModelMax;
    }
    // Out-of-range enum values fall back to the vertex model.
    return SpvExecutionModelVertex;
}
/// Looks through any number (zero or more) of array levels in `type` and returns the matrix
/// found at the bottom, if any.
/// @param type the given type, which must not be null
/// @returns the nested matrix type, or nullptr if the innermost element type is not a matrix
const core::type::Matrix* GetNestedMatrixType(const core::type::Type* type) {
    if (auto* array = type->As<core::type::Array>()) {
        // Peel one array level and keep searching in its element type.
        return GetNestedMatrixType(array->ElemType());
    }
    return type->As<core::type::Matrix>();
}
/// Selects the GLSL.std.450 extended-instruction opcode that implements a WGSL builtin.
/// For builtins whose GLSL.std.450 instruction depends on the element type (clamp, max, min,
/// sign), the choice is made from the builtin's return type.
/// @param builtin the semantic builtin call target
/// @returns the GLSL.std.450 instruction number, or 0 (GLSLstd450Bad, not a usable
///          instruction) when the builtin has no GLSL.std.450 mapping here
uint32_t builtin_to_glsl_method(const sem::Builtin* builtin) {
    switch (builtin->Type()) {
        case core::Function::kAcos:
            return GLSLstd450Acos;
        case core::Function::kAcosh:
            return GLSLstd450Acosh;
        case core::Function::kAsin:
            return GLSLstd450Asin;
        case core::Function::kAsinh:
            return GLSLstd450Asinh;
        case core::Function::kAtan:
            return GLSLstd450Atan;
        case core::Function::kAtan2:
            return GLSLstd450Atan2;
        case core::Function::kAtanh:
            return GLSLstd450Atanh;
        case core::Function::kCeil:
            return GLSLstd450Ceil;
        case core::Function::kClamp:
            // Floats use the N-prefixed (NaN-handling) form; integers pick U/S by signedness.
            if (builtin->ReturnType()->is_float_scalar_or_vector()) {
                return GLSLstd450NClamp;
            } else if (builtin->ReturnType()->is_unsigned_integer_scalar_or_vector()) {
                return GLSLstd450UClamp;
            } else {
                return GLSLstd450SClamp;
            }
        case core::Function::kCos:
            return GLSLstd450Cos;
        case core::Function::kCosh:
            return GLSLstd450Cosh;
        case core::Function::kCross:
            return GLSLstd450Cross;
        case core::Function::kDegrees:
            return GLSLstd450Degrees;
        case core::Function::kDeterminant:
            return GLSLstd450Determinant;
        case core::Function::kDistance:
            return GLSLstd450Distance;
        case core::Function::kExp:
            return GLSLstd450Exp;
        case core::Function::kExp2:
            return GLSLstd450Exp2;
        case core::Function::kFaceForward:
            return GLSLstd450FaceForward;
        case core::Function::kFloor:
            return GLSLstd450Floor;
        case core::Function::kFma:
            return GLSLstd450Fma;
        case core::Function::kFract:
            return GLSLstd450Fract;
        case core::Function::kFrexp:
            // Struct-returning form, matching WGSL's frexp result structure.
            return GLSLstd450FrexpStruct;
        case core::Function::kInverseSqrt:
            return GLSLstd450InverseSqrt;
        case core::Function::kLdexp:
            return GLSLstd450Ldexp;
        case core::Function::kLength:
            return GLSLstd450Length;
        case core::Function::kLog:
            return GLSLstd450Log;
        case core::Function::kLog2:
            return GLSLstd450Log2;
        case core::Function::kMax:
            // Floats use the N-prefixed (NaN-handling) form; integers pick U/S by signedness.
            if (builtin->ReturnType()->is_float_scalar_or_vector()) {
                return GLSLstd450NMax;
            } else if (builtin->ReturnType()->is_unsigned_integer_scalar_or_vector()) {
                return GLSLstd450UMax;
            } else {
                return GLSLstd450SMax;
            }
        case core::Function::kMin:
            // Floats use the N-prefixed (NaN-handling) form; integers pick U/S by signedness.
            if (builtin->ReturnType()->is_float_scalar_or_vector()) {
                return GLSLstd450NMin;
            } else if (builtin->ReturnType()->is_unsigned_integer_scalar_or_vector()) {
                return GLSLstd450UMin;
            } else {
                return GLSLstd450SMin;
            }
        case core::Function::kMix:
            return GLSLstd450FMix;
        case core::Function::kModf:
            // Struct-returning form, matching WGSL's modf result structure.
            return GLSLstd450ModfStruct;
        case core::Function::kNormalize:
            return GLSLstd450Normalize;
        case core::Function::kPack4X8Snorm:
            return GLSLstd450PackSnorm4x8;
        case core::Function::kPack4X8Unorm:
            return GLSLstd450PackUnorm4x8;
        case core::Function::kPack2X16Snorm:
            return GLSLstd450PackSnorm2x16;
        case core::Function::kPack2X16Unorm:
            return GLSLstd450PackUnorm2x16;
        case core::Function::kPack2X16Float:
            return GLSLstd450PackHalf2x16;
        case core::Function::kPow:
            return GLSLstd450Pow;
        case core::Function::kRadians:
            return GLSLstd450Radians;
        case core::Function::kReflect:
            return GLSLstd450Reflect;
        case core::Function::kRefract:
            return GLSLstd450Refract;
        case core::Function::kRound:
            // WGSL round() breaks ties to even, hence RoundEven rather than Round.
            return GLSLstd450RoundEven;
        case core::Function::kSign:
            // Signed integers use SSign; every other accepted type is treated as float.
            if (builtin->ReturnType()->is_signed_integer_scalar_or_vector()) {
                return GLSLstd450SSign;
            } else {
                return GLSLstd450FSign;
            }
        case core::Function::kSin:
            return GLSLstd450Sin;
        case core::Function::kSinh:
            return GLSLstd450Sinh;
        case core::Function::kSmoothstep:
            return GLSLstd450SmoothStep;
        case core::Function::kSqrt:
            return GLSLstd450Sqrt;
        case core::Function::kStep:
            return GLSLstd450Step;
        case core::Function::kTan:
            return GLSLstd450Tan;
        case core::Function::kTanh:
            return GLSLstd450Tanh;
        case core::Function::kTrunc:
            return GLSLstd450Trunc;
        case core::Function::kUnpack4X8Snorm:
            return GLSLstd450UnpackSnorm4x8;
        case core::Function::kUnpack4X8Unorm:
            return GLSLstd450UnpackUnorm4x8;
        case core::Function::kUnpack2X16Snorm:
            return GLSLstd450UnpackSnorm2x16;
        case core::Function::kUnpack2X16Unorm:
            return GLSLstd450UnpackUnorm2x16;
        case core::Function::kUnpack2X16Float:
            return GLSLstd450UnpackHalf2x16;
        default:
            break;
    }
    // No GLSL.std.450 equivalent; callers must handle this builtin another way.
    return 0;
}
/// Strips one level of vector-ness from a type.
/// @param ty the type to inspect
/// @return the element type when `ty` is a vector, otherwise `ty` itself
const core::type::Type* ElementTypeOf(const core::type::Type* ty) {
    const auto* vec = ty->As<core::type::Vector>();
    return vec ? vec->type() : ty;
}
} // namespace
Builder::AccessorInfo::AccessorInfo() : source_id(0), source_type(nullptr) {}
Builder::AccessorInfo::~AccessorInfo() {}
Builder::Builder(const Program* program,
bool zero_initialize_workgroup_memory,
bool experimental_require_subgroup_uniform_control_flow)
: builder_(ProgramBuilder::Wrap(program)),
scope_stack_{Scope{}},
zero_initialize_workgroup_memory_(zero_initialize_workgroup_memory),
experimental_require_subgroup_uniform_control_flow_(
experimental_require_subgroup_uniform_control_flow) {}
Builder::~Builder() = default;
/// Translates the whole program into `module_`.
/// Emission order:
/// 1. the extension-support check;
/// 2. the Shader capability and the logical GLSL450 memory model;
/// 3. the capabilities/extensions for each enabled WGSL extension;
/// 4. module-scope variables;
/// 5. functions, in dependency order.
/// @returns false if an unsupported extension is enabled or any global/function fails
bool Builder::Build() {
    const bool extensions_supported = tint::writer::CheckSupportedExtensions(
        "SPIR-V", builder_.AST(), builder_.Diagnostics(),
        Vector{
            wgsl::Extension::kChromiumDisableUniformityAnalysis,
            wgsl::Extension::kChromiumExperimentalDp4A,
            wgsl::Extension::kChromiumExperimentalFullPtrParameters,
            wgsl::Extension::kChromiumExperimentalPushConstant,
            wgsl::Extension::kChromiumExperimentalReadWriteStorageTexture,
            wgsl::Extension::kChromiumExperimentalSubgroups,
            wgsl::Extension::kF16,
            wgsl::Extension::kChromiumInternalDualSourceBlending,
        });
    if (!extensions_supported) {
        return false;
    }

    module_.PushCapability(SpvCapabilityShader);
    module_.PushMemoryModel(spv::Op::OpMemoryModel, {U32Operand(SpvAddressingModelLogical),
                                                     U32Operand(SpvMemoryModelGLSL450)});

    auto* sem_module = builder_.Sem().Module();
    for (auto extension : sem_module->Extensions()) {
        // The result is deliberately ignored: extensions without SPIR-V requirements are fine.
        GenerateExtension(extension);
    }
    if (experimental_require_subgroup_uniform_control_flow_) {
        module_.PushExtension("SPV_KHR_subgroup_uniform_control_flow");
    }

    for (auto* global : builder_.AST().GlobalVariables()) {
        if (!GenerateGlobalVariable(global)) {
            return false;
        }
    }

    for (auto* decl : sem_module->DependencyOrderedDeclarations()) {
        auto* function = decl->As<ast::Function>();
        if (function != nullptr && !GenerateFunction(function)) {
            return false;
        }
    }
    return true;
}
/// Records the two-way association between a semantic variable and its SPIR-V result id.
/// Existing entries are kept (emplace does not overwrite).
void Builder::RegisterVariable(const sem::Variable* var, uint32_t id) {
    var_to_id_.emplace(var, id);  // variable -> id
    id_to_var_.emplace(id, var);  // id -> variable
}
/// Finds the SPIR-V id previously registered for `var`.
/// @returns the id, or 0 (after reporting an ICE) if `var` was never registered
uint32_t Builder::LookupVariableID(const sem::Variable* var) {
    if (auto found = var_to_id_.find(var); found != var_to_id_.end()) {
        return found->second;
    }
    TINT_ICE() << "unable to find ID for variable: " + var->Declaration()->name->symbol.Name();
    return 0;
}
void Builder::PushScope() {
// Push a new scope, by copying the top-most stack
scope_stack_.push_back(scope_stack_.back());
}
/// Discards the innermost scope; must be paired with an earlier PushScope().
void Builder::PopScope() {
    scope_stack_.pop_back();
}
/// Allocates a fresh SPIR-V id from the module and wraps it as an operand.
/// @returns an operand holding the newly allocated id
Operand Builder::result_op() {
    const auto fresh_id = module_.NextId();
    return Operand(fresh_id);
}
/// Emits the SPIR-V extensions/capabilities required by an enabled WGSL extension.
/// @param extension the WGSL extension
/// @returns true for Dp4A and F16 (after emitting their requirements), false for any other
///          extension, for which nothing is emitted
bool Builder::GenerateExtension(wgsl::Extension extension) {
    switch (extension) {
        case wgsl::Extension::kChromiumExperimentalDp4A:
            module_.PushExtension("SPV_KHR_integer_dot_product");
            module_.PushCapability(SpvCapabilityDotProductKHR);
            module_.PushCapability(SpvCapabilityDotProductInput4x8BitPackedKHR);
            return true;
        case wgsl::Extension::kF16:
            module_.PushCapability(SpvCapabilityFloat16);
            module_.PushCapability(SpvCapabilityUniformAndStorageBuffer16BitAccess);
            module_.PushCapability(SpvCapabilityStorageBuffer16BitAccess);
            module_.PushCapability(SpvCapabilityStorageInputOutput16);
            return true;
        default:
            return false;
    }
}
/// Starts a new basic block labelled `id` and makes it the current block.
/// @returns false if the OpLabel could not be emitted (the current label is then unchanged)
bool Builder::GenerateLabel(uint32_t id) {
    const bool emitted = push_function_inst(spv::Op::OpLabel, {Operand(id)});
    if (emitted) {
        current_label_id_ = id;
    }
    return emitted;
}
/// Emits an assignment statement.
/// An ordinary assignment evaluates the LHS then the RHS and stores the result. A phony
/// assignment (`_ = expr`) only evaluates the RHS for its side effects, and is dropped
/// entirely when the RHS is a constant.
/// @returns false on error
bool Builder::GenerateAssignStatement(const ast::AssignmentStatement* assign) {
    if (!assign->lhs->Is<ast::PhonyExpression>()) {
        const uint32_t target_id = GenerateExpression(assign->lhs);
        if (target_id == 0) {
            return false;
        }
        const uint32_t value_id = GenerateExpression(assign->rhs);
        if (value_id == 0) {
            return false;
        }
        return GenerateStore(target_id, value_id);
    }

    // Constants can't have side-effects, so a constant phony RHS needs no code at all.
    if (builder_.Sem().GetVal(assign->rhs)->ConstantValue()) {
        return true;
    }
    return GenerateExpression(assign->rhs) != 0;
}
/// Emits `break` as an unconditional branch to the innermost merge block.
/// @returns false (after an ICE) if there is no enclosing merge block, or if emission fails
bool Builder::GenerateBreakStatement(const ast::BreakStatement*) {
    if (merge_stack_.empty()) {
        TINT_ICE() << "Attempted to break without a merge block";
        return false;
    }
    const uint32_t merge_target = merge_stack_.back();
    return push_function_inst(spv::Op::OpBranch, {Operand(merge_target)});
}
/// Emits `break if <cond>` inside a loop's continuing block.
/// Rather than emitting a branch immediately, this replaces the pending backedge of the
/// innermost loop with an OpBranchConditional: true -> the loop's break target, false -> the
/// loop header.
/// @param stmt the break-if statement
/// @returns false on error (including being outside any continuing block)
bool Builder::GenerateBreakIfStatement(const ast::BreakIfStatement* stmt) {
    // Both stacks are read via back() below; guard both so an empty stack is reported
    // instead of being undefined behaviour.
    if (TINT_UNLIKELY(backedge_stack_.empty() || continuing_stack_.empty())) {
        TINT_ICE() << "Attempted to break-if outside of a loop continuing block";
        return false;
    }
    const auto cond_id = GenerateExpression(stmt->condition);
    if (cond_id == 0) {
        return false;
    }
    const ContinuingInfo& ci = continuing_stack_.back();
    backedge_stack_.back() =
        Backedge(spv::Op::OpBranchConditional,
                 {Operand(cond_id), Operand(ci.break_target_id), Operand(ci.loop_header_id)});
    return true;
}
/// Emits `continue` as an unconditional branch to the innermost continue target.
/// @returns false (after an ICE) if there is no enclosing continue block, or if emission fails
bool Builder::GenerateContinueStatement(const ast::ContinueStatement*) {
    if (continue_stack_.empty()) {
        TINT_ICE() << "Attempted to continue without a continue block";
        return false;
    }
    const uint32_t continue_target = continue_stack_.back();
    return push_function_inst(spv::Op::OpBranch, {Operand(continue_target)});
}
/// Emits `discard` as OpKill.
// TODO(dsinclair): The semantics of kill hadn't been defined for WGSL when this was written,
// so the choice of OpKill may need to change.
// https://github.com/gpuweb/gpuweb/issues/676
bool Builder::GenerateDiscardStatement(const ast::DiscardStatement*) {
    return push_function_inst(spv::Op::OpKill, {});
}
/// Emits the OpEntryPoint instruction for an entry-point function.
/// The interface list contains only the transitively referenced globals in the Input/Output
/// address spaces, as required for SPIR-V 1.3.
/// @param func the entry-point function
/// @param id the SPIR-V id of the function
/// @returns false (after an ICE) for an unknown stage or an unregistered interface variable
bool Builder::GenerateEntryPoint(const ast::Function* func, uint32_t id) {
    const auto execution_model = pipeline_stage_to_execution_model(func->PipelineStage());
    if (execution_model == SpvExecutionModelMax) {
        TINT_ICE() << "Unknown pipeline stage provided";
        return false;
    }

    OperandList entry_operands = {Operand(execution_model), Operand(id),
                                  Operand(func->name->symbol.Name())};

    auto* sem_func = builder_.Sem().Get(func);
    for (const auto* global : sem_func->TransitivelyReferencedGlobals()) {
        // SPIR-V 1.4+ would list every referenced global here; 1.3 only lists I/O variables.
        const auto space = global->AddressSpace();
        const bool is_io = space == core::AddressSpace::kIn || space == core::AddressSpace::kOut;
        if (!is_io) {
            continue;
        }
        const uint32_t global_id = LookupVariableID(global);
        if (global_id == 0) {
            TINT_ICE() << "unable to find ID for global variable: " +
                              global->Declaration()->name->symbol.Name();
            return false;
        }
        entry_operands.push_back(Operand(global_id));
    }

    module_.PushEntryPoint(spv::Op::OpEntryPoint, entry_operands);
    return true;
}
/// Emits the OpExecutionMode instructions for an entry-point function:
/// - OriginUpperLeft for fragment shaders (WGSL's fragment origin is upper left);
/// - LocalSize for compute shaders, taken from the already-constant workgroup size;
/// - DepthReplacing if the frag_depth builtin is referenced;
/// - SubgroupUniformControlFlowKHR for compute shaders when that option is enabled.
/// @param func the entry-point function
/// @param id the SPIR-V id of the function
/// @returns false (after an ICE) if the workgroup size still has override-expressions
bool Builder::GenerateExecutionModes(const ast::Function* func, uint32_t id) {
    auto* sem_func = builder_.Sem().Get(func);
    const ast::PipelineStage stage = func->PipelineStage();
    const bool is_compute = stage == ast::PipelineStage::kCompute;

    if (stage == ast::PipelineStage::kFragment) {
        module_.PushExecutionMode(spv::Op::OpExecutionMode,
                                  {Operand(id), U32Operand(SpvExecutionModeOriginUpperLeft)});
    } else if (is_compute) {
        const auto& wg = sem_func->WorkgroupSize();
        // Every dimension must already be a concrete value; overrides are substituted earlier.
        const bool all_constant = wg[0].has_value() && wg[1].has_value() && wg[2].has_value();
        if (!all_constant) {
            TINT_ICE()
                << "override-expressions should have been removed with the SubstituteOverride "
                   "transform";
            return false;
        }
        module_.PushExecutionMode(
            spv::Op::OpExecutionMode,
            {Operand(id), U32Operand(SpvExecutionModeLocalSize),  //
             Operand(wg[0].value()), Operand(wg[1].value()), Operand(wg[2].value())});
    }

    for (const auto& entry : sem_func->TransitivelyReferencedBuiltinVariables()) {
        if (builder_.Sem().Get(entry.second)->Value() == core::BuiltinValue::kFragDepth) {
            module_.PushExecutionMode(spv::Op::OpExecutionMode,
                                      {Operand(id), U32Operand(SpvExecutionModeDepthReplacing)});
            break;  // one DepthReplacing mode is enough
        }
    }

    if (is_compute && experimental_require_subgroup_uniform_control_flow_) {
        module_.PushExecutionMode(
            spv::Op::OpExecutionMode,
            {Operand(id), U32Operand(SpvExecutionModeSubgroupUniformControlFlowKHR)});
    }
    return true;
}
/// Emits code for a semantic expression.
/// Handling, in order of precedence:
/// - constant-valued expressions become (deduplicated) SPIR-V constants;
/// - loads become an OpLoad of the generated reference;
/// - everything else is dispatched on its AST node kind.
/// @returns the result id, or 0 on error
uint32_t Builder::GenerateExpression(const sem::Expression* expr) {
    const auto* value_expr = expr->As<sem::ValueExpression>();
    if (const auto* constant = value_expr ? value_expr->ConstantValue() : nullptr) {
        return GenerateConstantIfNeeded(constant);
    }

    if (const auto* load = expr->As<sem::Load>()) {
        const uint32_t reference_id = GenerateExpression(load->Reference());
        if (reference_id == 0) {
            return 0;
        }
        return GenerateLoad(load->ReferenceType(), reference_id);
    }

    return Switch(
        expr->Declaration(),  //
        [&](const ast::AccessorExpression* accessor) {
            return GenerateAccessorExpression(accessor);
        },
        [&](const ast::BinaryExpression* binary) { return GenerateBinaryExpression(binary); },
        [&](const ast::BitcastExpression* bitcast) { return GenerateBitcastExpression(bitcast); },
        [&](const ast::CallExpression* call) { return GenerateCallExpression(call); },
        [&](const ast::IdentifierExpression* ident) {
            return GenerateIdentifierExpression(ident);
        },
        [&](const ast::LiteralExpression* literal) { return GenerateLiteralIfNeeded(literal); },
        [&](const ast::UnaryOpExpression* unary) { return GenerateUnaryOpExpression(unary); },
        [&](Default) {
            TINT_ICE() << "unknown expression type: " + std::string(expr->TypeInfo().name);
            return 0;
        });
}
/// Emits code for an AST expression by resolving it to its semantic node first.
/// @returns the result id, or 0 on error
uint32_t Builder::GenerateExpression(const ast::Expression* expr) {
    const auto* resolved = builder_.Sem().Get(expr);
    return GenerateExpression(resolved);
}
/// Emits a complete SPIR-V function (OpFunction, parameters, body) for a WGSL function.
/// If the body falls off the end of a basic block, it is terminated with OpReturn (void) or
/// OpReturnValue of a null constant (non-void). For entry points, this also emits the
/// OpEntryPoint and execution modes.
/// @param func_ast the function to emit
/// @returns false on error
bool Builder::GenerateFunction(const ast::Function* func_ast) {
    auto* func = builder_.Sem().Get(func_ast);
    uint32_t func_type_id = GenerateFunctionTypeIfNeeded(func);
    if (func_type_id == 0) {
        return false;
    }
    auto func_op = result_op();
    auto func_id = std::get<uint32_t>(func_op);
    module_.PushDebug(spv::Op::OpName, {Operand(func_id), Operand(func_ast->name->symbol.Name())});
    auto ret_id = GenerateTypeIfNeeded(func->ReturnType());
    if (ret_id == 0) {
        return false;
    }

    // Parameters and locals live in a fresh scope that is popped on every exit path.
    PushScope();
    TINT_DEFER(PopScope());

    auto definition_inst = Instruction{
        spv::Op::OpFunction,
        {Operand(ret_id), func_op, U32Operand(SpvFunctionControlMaskNone), Operand(func_type_id)}};
    InstructionList params;
    for (auto* param : func->Parameters()) {
        auto param_op = result_op();
        auto param_id = std::get<uint32_t>(param_op);
        auto param_type_id = GenerateTypeIfNeeded(param->Type());
        if (param_type_id == 0) {
            return false;
        }
        module_.PushDebug(spv::Op::OpName,
                          {Operand(param_id), Operand(param->Declaration()->name->symbol.Name())});
        params.push_back(
            Instruction{spv::Op::OpFunctionParameter, {Operand(param_type_id), param_op}});
        RegisterVariable(param, param_id);
    }

    // Start a new function; it is reset on every exit path.
    current_function_ = Function{definition_inst, result_op(), std::move(params)};
    current_label_id_ = current_function_.label_id();
    TINT_DEFER(current_function_ = Function());

    for (auto* stmt : func_ast->body->statements) {
        if (!GenerateStatement(stmt)) {
            return false;
        }
    }

    // Terminate a body that falls off the end of its last block.
    if (InsideBasicBlock()) {
        if (func->ReturnType()->Is<core::type::Void>()) {
            if (!push_function_inst(spv::Op::OpReturn, {})) {
                return false;
            }
        } else {
            auto zero = GenerateConstantNullIfNeeded(func->ReturnType());
            // Without this check a failure would emit `OpReturnValue %0`.
            if (zero == 0) {
                return false;
            }
            if (!push_function_inst(spv::Op::OpReturnValue, {Operand(zero)})) {
                return false;
            }
        }
    }

    if (func_ast->IsEntryPoint()) {
        if (!GenerateEntryPoint(func_ast, func_id)) {
            return false;
        }
        if (!GenerateExecutionModes(func_ast, func_id)) {
            return false;
        }
    }

    func_symbol_to_id_[func_ast->name->symbol] = func_id;
    // Add the function to the module.
    module_.PushFunction(std::move(current_function_));
    return true;
}
/// Returns the OpTypeFunction id for `func`'s signature, emitting it on first request.
/// Results are cached per signature in `func_sig_to_id_`.
/// @returns the function type id, or 0 if the return or a parameter type failed
uint32_t Builder::GenerateFunctionTypeIfNeeded(const sem::Function* func) {
    // Invoked only on a cache miss.
    auto build_function_type = [&]() -> uint32_t {
        // Reserve the result id before generating any dependent types.
        auto type_op = result_op();
        const auto function_type_id = std::get<uint32_t>(type_op);

        const auto return_type_id = GenerateTypeIfNeeded(func->ReturnType());
        if (return_type_id == 0) {
            return 0;
        }

        OperandList operands = {type_op, Operand(return_type_id)};
        for (auto* param : func->Parameters()) {
            const auto param_type_id = GenerateTypeIfNeeded(param->Type());
            if (param_type_id == 0) {
                return 0;
            }
            operands.push_back(Operand(param_type_id));
        }

        module_.PushType(spv::Op::OpTypeFunction, std::move(operands));
        return function_type_id;
    };
    return tint::GetOrCreate(func_sig_to_id_, func->Signature(), build_function_type);
}
/// Emits a function-scope variable declaration.
/// - `const` declarations emit nothing (they are materialized at each use).
/// - `let` declarations emit only their initializer and bind the name to its id.
/// - `var` declarations emit a Function-address-space OpVariable initialized to null, then
///   store the initializer (if any) into it.
/// @param v the variable declaration
/// @returns false on error
bool Builder::GenerateFunctionVariable(const ast::Variable* v) {
    if (v->Is<ast::Const>()) {
        // Constants are generated at their use. This is required as the 'const' declaration may be
        // abstract-numeric, which has no SPIR-V type.
        return true;
    }
    uint32_t init_id = 0;
    if (v->initializer) {
        init_id = GenerateExpression(v->initializer);
        if (init_id == 0) {
            return false;
        }
    }
    auto* sem = builder_.Sem().Get(v);
    if (v->Is<ast::Let>()) {
        if (!v->initializer) {
            TINT_ICE() << "missing initializer for let";
            return false;
        }
        // A `let` is just a name for its initializer's value.
        RegisterVariable(sem, init_id);
        return true;
    }
    auto result = result_op();
    auto var_id = std::get<uint32_t>(result);
    auto sc = core::AddressSpace::kFunction;
    auto* type = sem->Type();
    auto type_id = GenerateTypeIfNeeded(type);
    if (type_id == 0) {
        return false;
    }
    module_.PushDebug(spv::Op::OpName, {Operand(var_id), Operand(v->name->symbol.Name())});
    // TODO(dsinclair) We could detect if the initializer is fully const and emit
    // an initializer value for the variable instead of doing the OpLoad.
    auto null_id = GenerateConstantNullIfNeeded(type->UnwrapRef());
    if (null_id == 0) {
        return false;
    }
    push_function_var(
        {Operand(type_id), result, U32Operand(ConvertAddressSpace(sc)), Operand(null_id)});
    if (v->initializer) {
        if (!GenerateStore(var_id, init_id)) {
            return false;
        }
    }
    RegisterVariable(sem, var_id);
    return true;
}
/// Emits `OpStore to from`, writing the value id `from` through the pointer id `to`.
/// @returns false if the instruction could not be emitted
bool Builder::GenerateStore(uint32_t to, uint32_t from) {
    return push_function_inst(spv::Op::OpStore, {Operand(to),  // pointer
                                                 Operand(from)});  // value
}
/// Emits a module-scope variable (OpVariable) together with its decorations.
/// Variables with an undefined address space are emitted as Private. Uninitialized
/// Private/Output variables (and Workgroup variables when zero_initialize_workgroup_memory_
/// is set) receive a null-constant initializer; samplers never do. Structs and storage
/// textures without an initializer get NonReadable/NonWritable from their access mode.
/// @param v the module-scope variable declaration
/// @returns false on error
bool Builder::GenerateGlobalVariable(const ast::Variable* v) {
    if (v->Is<ast::Const>()) {
        // Constants are generated at their use. This is required as the 'const' declaration may be
        // abstract-numeric, which has no SPIR-V type.
        return true;
    }
    auto* sem = builder_.Sem().Get<sem::GlobalVariable>(v);
    if (TINT_UNLIKELY(!sem)) {
        TINT_ICE() << "attempted to generate a global from a non-global variable";
        return false;
    }
    auto* type = sem->Type()->UnwrapRef();
    uint32_t init_id = 0;
    if (auto* ctor = v->initializer) {
        init_id = GenerateConstructorExpression(v, ctor);
        if (init_id == 0) {
            return false;
        }
    }
    auto result = result_op();
    auto var_id = std::get<uint32_t>(result);
    auto sc = sem->AddressSpace() == core::AddressSpace::kUndefined ? core::AddressSpace::kPrivate
                                                                    : sem->AddressSpace();
    auto type_id = GenerateTypeIfNeeded(sem->Type());
    if (type_id == 0) {
        return false;
    }
    module_.PushDebug(spv::Op::OpName, {Operand(var_id), Operand(v->name->symbol.Name())});
    OperandList ops = {Operand(type_id), result, U32Operand(ConvertAddressSpace(sc))};
    if (v->initializer) {
        ops.push_back(Operand(init_id));
    } else {
        auto* st = type->As<core::type::StorageTexture>();
        if (st || type->Is<core::type::Struct>()) {
            // type is a type::Struct or a type::StorageTexture
            auto access = st ? st->access() : sem->Access();
            switch (access) {
                case core::Access::kWrite:
                    module_.PushAnnot(spv::Op::OpDecorate,
                                      {Operand(var_id), U32Operand(SpvDecorationNonReadable)});
                    break;
                case core::Access::kRead:
                    module_.PushAnnot(spv::Op::OpDecorate,
                                      {Operand(var_id), U32Operand(SpvDecorationNonWritable)});
                    break;
                case core::Access::kUndefined:
                case core::Access::kReadWrite:
                    break;
            }
        }
        if (!type->Is<core::type::Sampler>()) {
            // If we don't have an initializer and we're an Output or Private
            // variable, then WGSL requires that we zero-initialize.
            // If we're a Workgroup variable, and the
            // VK_KHR_zero_initialize_workgroup_memory extension is enabled, we should
            // also zero-initialize.
            if (sem->AddressSpace() == core::AddressSpace::kPrivate ||
                sem->AddressSpace() == core::AddressSpace::kOut ||
                (zero_initialize_workgroup_memory_ &&
                 sem->AddressSpace() == core::AddressSpace::kWorkgroup)) {
                init_id = GenerateConstantNullIfNeeded(type);
                if (init_id == 0) {
                    return false;
                }
                ops.push_back(Operand(init_id));
            }
        }
    }
    module_.PushType(spv::Op::OpVariable, std::move(ops));

    // Translate each WGSL attribute on the declaration into the matching decoration.
    for (auto* attr : v->attributes) {
        bool ok = Switch(
            attr,
            [&](const ast::BuiltinAttribute* builtin_attr) {
                auto builtin = builder_.Sem().Get(builtin_attr)->Value();
                module_.PushAnnot(spv::Op::OpDecorate,
                                  {Operand(var_id), U32Operand(SpvDecorationBuiltIn),
                                   U32Operand(ConvertBuiltin(builtin, sem->AddressSpace()))});
                return true;
            },
            [&](const ast::LocationAttribute*) {
                module_.PushAnnot(spv::Op::OpDecorate,
                                  {Operand(var_id), U32Operand(SpvDecorationLocation),
                                   Operand(sem->Location().value())});
                return true;
            },
            [&](const ast::IndexAttribute*) {
                module_.PushAnnot(spv::Op::OpDecorate,
                                  {Operand(var_id), U32Operand(SpvDecorationIndex),
                                   Operand(sem->Index().value())});
                return true;
            },
            [&](const ast::InterpolateAttribute* interpolate) {
                auto& s = builder_.Sem();
                auto i_type =
                    s.Get<sem::BuiltinEnumExpression<core::InterpolationType>>(interpolate->type)
                        ->Value();
                auto i_smpl = core::InterpolationSampling::kUndefined;
                if (interpolate->sampling) {
                    i_smpl = s.Get<sem::BuiltinEnumExpression<core::InterpolationSampling>>(
                                  interpolate->sampling)
                                 ->Value();
                }
                AddInterpolationDecorations(var_id, i_type, i_smpl);
                return true;
            },
            [&](const ast::InvariantAttribute*) {
                module_.PushAnnot(spv::Op::OpDecorate,
                                  {Operand(var_id), U32Operand(SpvDecorationInvariant)});
                return true;
            },
            [&](const ast::BindingAttribute*) {
                auto bp = sem->BindingPoint();
                module_.PushAnnot(
                    spv::Op::OpDecorate,
                    {Operand(var_id), U32Operand(SpvDecorationBinding), Operand(bp->binding)});
                return true;
            },
            [&](const ast::GroupAttribute*) {
                auto bp = sem->BindingPoint();
                module_.PushAnnot(
                    spv::Op::OpDecorate,
                    {Operand(var_id), U32Operand(SpvDecorationDescriptorSet), Operand(bp->group)});
                return true;
            },
            [&](const ast::IdAttribute*) {
                return true;  // Spec constants are handled elsewhere
            },
            [&](const ast::InternalAttribute*) {
                return true;  // ignored
            },
            [&](Default) {
                TINT_ICE() << "unknown attribute";
                return false;
            });
        if (!ok) {
            return false;
        }
    }
    RegisterVariable(sem, var_id);
    return true;
}
/// Applies one index accessor (`obj[idx]`) to the accessor chain state in `info`.
/// - Reference sources: the index id is appended to `info->access_chain_indices` for a later
///   OpAccessChain.
/// - Value sources, constant index: OpCompositeExtract with a literal index.
/// - Value sources, dynamic index: OpVectorExtractDynamic (vectors only).
/// Dynamic indices on array/matrix values are expected to have been promoted to memory by
/// the VarForDynamicIndex transform; anything else is an ICE.
/// @param expr the index accessor expression
/// @param info the in-progress accessor state; updated in place
/// @returns true on success, false on error
bool Builder::GenerateIndexAccessor(const ast::IndexAccessorExpression* expr, AccessorInfo* info) {
    auto idx_id = GenerateExpression(expr->index);
    if (idx_id == 0) {
        return false;
    }
    // If the source is a reference, we access chain into it.
    // In the future, pointers may support access-chaining.
    // See https://github.com/gpuweb/gpuweb/pull/1580
    if (info->source_type->Is<core::type::Reference>()) {
        info->access_chain_indices.push_back(idx_id);
        info->source_type = builder_.Sem().Get(expr)->UnwrapLoad()->Type();
        return true;
    }
    auto result_type_id = GenerateTypeIfNeeded(TypeOf(expr));
    if (result_type_id == 0) {
        return false;
    }
    // We don't have a pointer, so we can just directly extract the value.
    auto extract = result_op();
    auto extract_id = std::get<uint32_t>(extract);
    // If the index is compile-time constant, we use OpCompositeExtract.
    auto* idx = builder_.Sem().GetVal(expr->index);
    if (auto* idx_constval = idx->ConstantValue()) {
        if (!push_function_inst(spv::Op::OpCompositeExtract,
                                {
                                    Operand(result_type_id),
                                    extract,
                                    Operand(info->source_id),
                                    Operand(idx_constval->ValueAs<u32>()),
                                })) {
            return false;
        }
        info->source_id = extract_id;
        info->source_type = TypeOf(expr);
        return true;
    }
    // If the source is a vector, we use OpVectorExtractDynamic.
    if (TINT_LIKELY(info->source_type->Is<core::type::Vector>())) {
        if (!push_function_inst(
                spv::Op::OpVectorExtractDynamic,
                {Operand(result_type_id), extract, Operand(info->source_id), Operand(idx_id)})) {
            return false;
        }
        info->source_id = extract_id;
        info->source_type = TypeOf(expr);
        return true;
    }
    TINT_ICE() << "unsupported index accessor expression";
    return false;
}
/// Applies one member accessor (struct member access or swizzle) to the accessor chain state
/// in `info`.
/// When `info->source_type` is a reference, single-element accesses append a u32 index
/// constant to `info->access_chain_indices`, to be emitted later as one OpAccessChain.
/// Otherwise the element is extracted immediately with OpCompositeExtract. Multi-element
/// swizzles first flush any pending access chain, then emit an OpVectorShuffle.
/// @param expr the member accessor expression
/// @param info the in-progress accessor state; updated in place
/// @returns true on success, false on error
bool Builder::GenerateMemberAccessor(const ast::MemberAccessorExpression* expr,
                                     AccessorInfo* info) {
    auto* expr_sem = builder_.Sem().Get(expr)->UnwrapLoad();
    auto* expr_type = expr_sem->Type();
    return Switch(
        expr_sem,  //
        [&](const sem::StructMemberAccess* access) {
            uint32_t idx = access->Member()->Index();
            if (info->source_type->Is<core::type::Reference>()) {
                // Still addressing memory: record the member index for the eventual access chain.
                auto idx_id = GenerateConstantIfNeeded(ScalarConstant::U32(idx));
                if (TINT_UNLIKELY(idx_id == 0)) {
                    return false;
                }
                info->access_chain_indices.push_back(idx_id);
                info->source_type = expr_type;
            } else {
                // Source is a value: extract the member directly using a literal index.
                auto result_type_id = GenerateTypeIfNeeded(expr_type);
                if (TINT_UNLIKELY(result_type_id == 0)) {
                    return false;
                }
                auto extract = result_op();
                auto extract_id = std::get<uint32_t>(extract);
                if (!push_function_inst(spv::Op::OpCompositeExtract,
                                        {Operand(result_type_id), extract, Operand(info->source_id),
                                         Operand(idx)})) {
                    return false;
                }
                info->source_id = extract_id;
                info->source_type = expr_type;
            }
            return true;
        },
        [&](const sem::Swizzle* swizzle) {
            // Single element swizzle is either an access chain or a composite extract
            auto& indices = swizzle->Indices();
            if (indices.Length() == 1) {
                if (info->source_type->Is<core::type::Reference>()) {
                    auto idx_id = GenerateConstantIfNeeded(ScalarConstant::U32(indices[0]));
                    if (TINT_UNLIKELY(idx_id == 0)) {
                        return false;
                    }
                    // NOTE(review): unlike the struct-member path, info->source_type is left
                    // unchanged here (still the vector reference); confirm nothing downstream
                    // depends on it.
                    info->access_chain_indices.push_back(idx_id);
                } else {
                    auto result_type_id = GenerateTypeIfNeeded(expr_type);
                    if (TINT_UNLIKELY(result_type_id == 0)) {
                        return false;
                    }
                    auto extract = result_op();
                    auto extract_id = std::get<uint32_t>(extract);
                    if (!push_function_inst(spv::Op::OpCompositeExtract,
                                            {Operand(result_type_id), extract,
                                             Operand(info->source_id), Operand(indices[0])})) {
                        return false;
                    }
                    info->source_id = extract_id;
                    info->source_type = expr_type;
                }
                return true;
            }
            // Store the type away as it may change if we run the access chain
            auto* incoming_type = info->source_type;
            // Multi-item extract is a VectorShuffle. We have to emit any existing
            // access chain data, then load the access chain and shuffle that.
            if (!info->access_chain_indices.empty()) {
                // The access chain's result type is the current source type (the reference).
                auto result_type_id = GenerateTypeIfNeeded(info->source_type);
                if (TINT_UNLIKELY(result_type_id == 0)) {
                    return false;
                }
                auto extract = result_op();
                auto extract_id = std::get<uint32_t>(extract);
                OperandList ops = {Operand(result_type_id), extract, Operand(info->source_id)};
                for (auto id : info->access_chain_indices) {
                    ops.push_back(Operand(id));
                }
                if (!push_function_inst(spv::Op::OpAccessChain, ops)) {
                    return false;
                }
                // NOTE(review): expr_type is the swizzle's own type, presumably a value type for
                // a multi-element swizzle, which would make this load a no-op. The load of the
                // vector would then happen below via `incoming_type`. Confirm intent.
                info->source_id = GenerateLoadIfNeeded(expr_type, extract_id);
                info->source_type = expr_type->UnwrapRef();
                info->access_chain_indices.clear();
            }
            auto result_type_id = GenerateTypeIfNeeded(expr_type);
            if (TINT_UNLIKELY(result_type_id == 0)) {
                return false;
            }
            auto vec_id = GenerateLoadIfNeeded(incoming_type, info->source_id);
            auto result = result_op();
            auto result_id = std::get<uint32_t>(result);
            // Both shuffle operands are the same vector, so every swizzle index selects from it.
            OperandList ops = {Operand(result_type_id), result, Operand(vec_id), Operand(vec_id)};
            for (auto idx : indices) {
                ops.push_back(Operand(idx));
            }
            if (!push_function_inst(spv::Op::OpVectorShuffle, ops)) {
                return false;
            }
            info->source_id = result_id;
            info->source_type = expr_type;
            return true;
        },
        [&](Default) {
            TINT_ICE() << "unhandled member index type: " << expr_sem->TypeInfo().name;
            return false;
        });
}
/// Emits a chain of index/member accessors (e.g. `a.b[i].c`).
/// The chain is collected from the outermost expression down to its root object. The walk
/// stops early at a constant-valued object, which is then emitted as a whole constant. The
/// accessors are then applied innermost-first. Accessors on references accumulate indices that
/// are emitted as a single OpAccessChain at the end; accessors on values extract directly.
/// @param expr the outermost accessor expression
/// @returns the result id (a pointer for reference chains, else a value), or 0 on error
uint32_t Builder::GenerateAccessorExpression(const ast::AccessorExpression* expr) {
    std::vector<const ast::Expression*> accessors;
    const ast::Expression* source = expr;
    while (true) {
        if (auto* array = source->As<ast::IndexAccessorExpression>()) {
            accessors.push_back(source);
            source = array->object;
        } else if (auto* member = source->As<ast::MemberAccessorExpression>()) {
            accessors.push_back(source);
            source = member->object;
        } else {
            break;
        }
        // Stop traversing if we've hit a constant source expression.
        if (builder_.Sem().GetVal(source)->ConstantValue()) {
            break;
        }
    }
    // The walk collected accessors outermost-first, but they must be applied innermost-first.
    // Appending then reversing once avoids the O(n^2) cost of repeated front insertion.
    std::reverse(accessors.begin(), accessors.end());

    AccessorInfo info;
    info.source_id = GenerateExpression(source);
    if (info.source_id == 0) {
        return 0;
    }
    info.source_type = TypeOf(source);

    // Note: Dynamic index on array and matrix values (lets) should have been
    // promoted to storage with the VarForDynamicIndex transform.
    for (auto* accessor : accessors) {
        bool ok = Switch(
            accessor,
            [&](const ast::IndexAccessorExpression* array) {
                return GenerateIndexAccessor(array, &info);
            },
            [&](const ast::MemberAccessorExpression* member) {
                return GenerateMemberAccessor(member, &info);
            },
            [&](Default) {
                TINT_ICE() << "invalid accessor in list: " + std::string(accessor->TypeInfo().name);
                return false;
            });
        if (!ok) {
            return 0;
        }
    }

    // Flush any indices accumulated against a reference as one OpAccessChain.
    if (!info.access_chain_indices.empty()) {
        auto* type = builder_.Sem().Get(expr)->UnwrapLoad()->Type();
        auto result_type_id = GenerateTypeIfNeeded(type);
        if (result_type_id == 0) {
            return 0;
        }
        auto result = result_op();
        auto result_id = std::get<uint32_t>(result);
        OperandList ops = {Operand(result_type_id), result, Operand(info.source_id)};
        for (auto id : info.access_chain_indices) {
            ops.push_back(Operand(id));
        }
        if (!push_function_inst(spv::Op::OpAccessChain, ops)) {
            return 0;
        }
        info.source_id = result_id;
    }
    return info.source_id;
}
// Returns the SPIR-V ID of the variable that the identifier expression `expr` refers to.
// Raises an ICE (and returns 0) when the identifier does not resolve to a variable.
uint32_t Builder::GenerateIdentifierExpression(const ast::IdentifierExpression* expr) {
    const auto* value = builder_.Sem().GetVal(expr);
    const auto* user = value ? value->UnwrapLoad()->As<sem::VariableUser>() : nullptr;
    if (user != nullptr) {
        return LookupVariableID(user->Variable());
    }
    TINT_ICE() << "identifier '" + expr->identifier->symbol.Name() +
                      "' does not resolve to a variable";
    return 0;
}
// Emits an OpLoad of the value stored behind the reference `type`, whose pointer ID is `id`.
// Returns the ID of the loaded value, or 0 on failure.
uint32_t Builder::GenerateLoad(const core::type::Reference* type, uint32_t id) {
    auto type_id = GenerateTypeIfNeeded(type->StoreType());
    if (type_id == 0) {
        // Type generation failed; don't emit an OpLoad with an invalid result type.
        return 0;
    }
    auto result = result_op();
    auto result_id = std::get<uint32_t>(result);
    if (!push_function_inst(spv::Op::OpLoad, {Operand(type_id), result, Operand(id)})) {
        return 0;
    }
    return result_id;
}
// Returns `id` unchanged for non-reference types. For a reference type, emits an OpLoad of the
// referenced value and returns the loaded value's ID instead.
uint32_t Builder::GenerateLoadIfNeeded(const core::type::Type* type, uint32_t id) {
    const auto* ref = type->As<core::type::Reference>();
    return ref == nullptr ? id : GenerateLoad(ref, id);
}
// Emits the SPIR-V for the unary expression `expr` and returns its result ID (0 on failure).
// Address-of and indirection only convert between pointers and references, which are the same
// thing in SPIR-V, so they forward the operand's ID without emitting an instruction.
uint32_t Builder::GenerateUnaryOpExpression(const ast::UnaryOpExpression* expr) {
    // Note: the result ID is reserved before anything else, including for the forwarding cases.
    auto result = result_op();
    auto result_id = std::get<uint32_t>(result);

    spv::Op op = spv::Op::OpNop;
    switch (expr->op) {
        case core::UnaryOp::kAddressOf:
        case core::UnaryOp::kIndirection:
            // Pointer <-> reference conversion: a no-op in SPIR-V.
            return GenerateExpression(expr->expr);
        case core::UnaryOp::kComplement:
            op = spv::Op::OpNot;
            break;
        case core::UnaryOp::kNegation:
            op = TypeOf(expr)->is_float_scalar_or_vector() ? spv::Op::OpFNegate
                                                           : spv::Op::OpSNegate;
            break;
        case core::UnaryOp::kNot:
            op = spv::Op::OpLogicalNot;
            break;
    }

    auto operand_id = GenerateExpression(expr->expr);
    if (operand_id == 0) {
        return 0;
    }
    auto type_id = GenerateTypeIfNeeded(TypeOf(expr));
    if (type_id == 0) {
        return 0;
    }
    if (!push_function_inst(op, {Operand(type_id), result, Operand(operand_id)})) {
        return 0;
    }
    return result_id;
}
uint32_t Builder::GetGLSLstd450Import() {
auto where = import_name_to_id_.find(kGLSLstd450);
if (where != import_name_to_id_.end()) {
return where->second;
}
// It doesn't exist yet. Generate it.
auto result = result_op();
auto id = std::get<uint32_t>(result);
module_.PushExtImport(spv::Op::OpExtInstImport, {result, Operand(kGLSLstd450)});
// Remember it for later.
import_name_to_id_[kGLSLstd450] = id;
return id;
}
// Generates the initializer expression `expr` of the variable `var`.
// Constant-valued expressions are emitted via GenerateConstantIfNeeded(). Value constructor and
// value conversion calls are lowered via GenerateValueConstructorOrConversion(). Anything else
// is an ICE.
uint32_t Builder::GenerateConstructorExpression(const ast::Variable* var,
                                                const ast::Expression* expr) {
    const auto* value = builder_.Sem().GetVal(expr);
    if (value != nullptr) {
        if (auto constant = value->ConstantValue()) {
            return GenerateConstantIfNeeded(constant);
        }
    }
    const auto* call = builder_.Sem().Get<sem::Call>(expr);
    if (call != nullptr &&
        call->Target()->IsAnyOf<sem::ValueConstructor, sem::ValueConversion>()) {
        return GenerateValueConstructorOrConversion(call, var);
    }
    TINT_ICE() << "unknown constructor expression";
    return 0;
}
// Returns true if `expr` is built only from literals, value-constructor calls and materialized
// (compile-time) sub-expressions. GenerateValueConstructorOrConversion() uses this to decide
// whether a constant composite may be emitted.
bool Builder::IsConstructorConst(const ast::Expression* expr) {
    bool only_constants = true;
    ast::TraverseExpressions(expr, [&](const ast::Expression* node) {
        if (node->Is<ast::LiteralExpression>()) {
            return ast::TraverseAction::Descend;
        }
        if (auto* call_expr = node->As<ast::CallExpression>()) {
            auto* call_sem = builder_.Sem().Get(call_expr);
            // Materialize can only occur on compile time expressions, so this sub-tree must be
            // constant.
            if (call_sem->Is<sem::Materialize>()) {
                return ast::TraverseAction::Skip;
            }
            // A value constructor is constant if all of its arguments are, so keep walking.
            if (call_sem->As<sem::Call>()->Target()->Is<sem::ValueConstructor>()) {
                return ast::TraverseAction::Descend;
            }
        }
        // Anything else (identifiers, conversions, function calls, ...) is not constant.
        only_constants = false;
        return ast::TraverseAction::Stop;
    });
    return only_constants;
}
// Generates the value constructor or value conversion `call` and returns the result ID
// (0 on failure).
// `var` is the variable being initialized, or nullptr. When it is a module-scope variable,
// components of vector arguments are extracted with OpSpecConstantOp.
// Composites whose arguments are all constant are emitted as global OpConstantComposite (or
// OpSpecConstantComposite) instructions, deduplicated by operand list. Otherwise an
// OpCompositeConstruct is emitted in the current function.
uint32_t Builder::GenerateValueConstructorOrConversion(const sem::Call* call,
                                                       const ast::Variable* var) {
    auto& args = call->Arguments();
    auto* global_var = builder_.Sem().Get<sem::GlobalVariable>(var);
    auto* result_type = call->Type();
    // Generate the zero constructor if there are no values provided.
    if (args.IsEmpty()) {
        return GenerateConstantNullIfNeeded(result_type->UnwrapRef());
    }
    result_type = result_type->UnwrapRef();
    bool constructor_is_const = IsConstructorConst(call->Declaration());
    if (has_error()) {
        return 0;
    }
    // Scalar results, same-width scalar-vector results and same-type matrix results are a single
    // cast / copy / passthrough of the first argument.
    bool can_cast_or_copy = result_type->Is<core::type::Scalar>();
    if (auto* res_vec = result_type->As<core::type::Vector>()) {
        if (res_vec->type()->Is<core::type::Scalar>()) {
            auto* value_type = args[0]->Type()->UnwrapRef();
            if (auto* val_vec = value_type->As<core::type::Vector>()) {
                if (val_vec->type()->Is<core::type::Scalar>()) {
                    can_cast_or_copy = res_vec->Width() == val_vec->Width();
                }
            }
        }
    }
    if (auto* res_mat = result_type->As<core::type::Matrix>()) {
        auto* value_type = args[0]->Type()->UnwrapRef();
        if (auto* val_mat = value_type->As<core::type::Matrix>()) {
            // Generate passthrough for matrices of the same type
            can_cast_or_copy = res_mat == val_mat;
        }
    }
    if (can_cast_or_copy) {
        return GenerateCastOrCopyOrPassthrough(result_type, args[0]->Declaration(), global_var);
    }
    auto type_id = GenerateTypeIfNeeded(result_type);
    if (type_id == 0) {
        return 0;
    }
    bool result_is_constant_composite = constructor_is_const;
    bool result_is_spec_composite = false;
    // For vectors, compare each argument against the element type from here on.
    if (auto* vec = result_type->As<core::type::Vector>()) {
        result_type = vec->type();
    }
    OperandList ops;
    static constexpr size_t kOpsResultIdx = 1;
    static constexpr size_t kOpsFirstValueIdx = 2;
    ops.reserve(8);
    ops.push_back(Operand(type_id));
    ops.push_back(Operand(0u));  // Placeholder for the result ID
    for (auto* e : args) {
        uint32_t id = GenerateExpression(e);
        if (id == 0) {
            return 0;
        }
        auto* value_type = e->Type()->UnwrapRef();
        // If the result and value types are the same we can just use the object.
        // If the result is not a vector then we should have validated that the
        // value type is a correctly sized vector so we can just use it directly.
        if (result_type == value_type || result_type->Is<core::type::Matrix>() ||
            result_type->Is<core::type::Array>() || result_type->Is<core::type::Struct>()) {
            ops.push_back(Operand(id));
            continue;
        }
        // Both scalars, but not the same type so we need to generate a conversion
        // of the value. Convert *this* argument (`e`), not the first argument.
        // NOTE(review): `e` was already generated above and is generated again by
        // GenerateCastOrCopyOrPassthrough(). If `e` can have side effects they are emitted twice.
        // This path is presumably unreachable for resolved WGSL, where constructor arguments
        // already match the element type. Confirm.
        if (value_type->Is<core::type::Scalar>() && result_type->Is<core::type::Scalar>()) {
            id = GenerateCastOrCopyOrPassthrough(result_type, e->Declaration(), global_var);
            if (id == 0) {
                return 0;
            }
            ops.push_back(Operand(id));
            continue;
        }
        // When handling vectors as the values there are a few cases to take into
        // consideration:
        //  1. Module scoped vec3<f32>(vec2<f32>(1, 2), 3)  -> OpSpecConstantOp
        //  2. Function scoped vec3<f32>(vec2<f32>(1, 2), 3) ->  OpCompositeExtract
        //  3. Either array<vec3<f32>, 1>(vec3<f32>(1, 2, 3))  -> use the ID.
        //       -> handled above
        //
        // For cases 1 and 2, if the type is different we also may need to insert
        // a type cast.
        if (auto* vec = value_type->As<core::type::Vector>()) {
            auto* vec_type = vec->type();
            auto value_type_id = GenerateTypeIfNeeded(vec_type);
            if (value_type_id == 0) {
                return 0;
            }
            for (uint32_t i = 0; i < vec->Width(); ++i) {
                auto extract = result_op();
                auto extract_id = std::get<uint32_t>(extract);
                if (!global_var) {
                    // A non-global initializer. Case 2.
                    if (!push_function_inst(
                            spv::Op::OpCompositeExtract,
                            {Operand(value_type_id), extract, Operand(id), Operand(i)})) {
                        return 0;
                    }
                    // We no longer have a constant composite, but have to do a
                    // composite construction as these calls are inside a function.
                    result_is_constant_composite = false;
                } else {
                    // A global initializer, must use OpSpecConstantOp. Case 1.
                    auto idx_id = GenerateConstantIfNeeded(ScalarConstant::U32(i));
                    if (idx_id == 0) {
                        return 0;
                    }
                    module_.PushType(spv::Op::OpSpecConstantOp, {Operand(value_type_id), extract,
                                                                 U32Operand(SpvOpCompositeExtract),
                                                                 Operand(id), Operand(idx_id)});
                    result_is_spec_composite = true;
                }
                ops.push_back(Operand(extract_id));
            }
        } else {
            TINT_ICE() << "Unhandled type cast value type";
            return 0;
        }
    }
    // For a single-value vector initializer, splat the initializer value.
    auto* const init_result_type = call->Type()->UnwrapRef();
    if (args.Length() == 1 && init_result_type->is_scalar_vector() &&
        args[0]->Type()->UnwrapRef()->Is<core::type::Scalar>()) {
        size_t vec_size = init_result_type->As<core::type::Vector>()->Width();
        for (size_t i = 0; i < (vec_size - 1); ++i) {
            ops.push_back(ops[kOpsFirstValueIdx]);
        }
    }
    auto& stack = (result_is_spec_composite || result_is_constant_composite)
                      ? scope_stack_[0]       // Global scope
                      : scope_stack_.back();  // Lexical scope
    return tint::GetOrCreate(stack.type_init_to_id_, OperandListKey{ops}, [&]() -> uint32_t {
        auto result = result_op();
        ops[kOpsResultIdx] = result;
        if (result_is_spec_composite) {
            module_.PushType(spv::Op::OpSpecConstantComposite, ops);
        } else if (result_is_constant_composite) {
            module_.PushType(spv::Op::OpConstantComposite, ops);
        } else {
            if (!push_function_inst(spv::Op::OpCompositeConstruct, ops)) {
                return 0;
            }
        }
        return std::get<uint32_t>(result);
    });
}
// Generates `from_expr` and converts the resulting value to `to_type`, returning the ID of the
// converted value (0 on failure).
//  - Numeric scalar/vector conversions map to OpConvert*, OpFConvert or OpBitcast.
//  - Numeric-to-bool conversions compare against zero.
//  - Bool-to-numeric conversions select between one and zero.
//  - Identity conversions, including identical matrix types, pass the generated ID through.
// Must not be used for module-scope initializers (`is_global_init`): those conversions are
// expected to have been constant-folded already.
uint32_t Builder::GenerateCastOrCopyOrPassthrough(const core::type::Type* to_type,
                                                  const ast::Expression* from_expr,
                                                  bool is_global_init) {
    // This should not happen as we rely on constant folding to obviate
    // casts/conversions for module-scope variables
    if (TINT_UNLIKELY(is_global_init)) {
        TINT_ICE() << "Module-level conversions are not supported. Conversions should "
                      "have already been constant-folded by the FoldConstants transform.";
        return 0;
    }
    // Returns the scalar element type of a scalar or vector type, or nullptr otherwise.
    auto elem_type_of = [](const core::type::Type* t) -> const core::type::Type* {
        if (t->Is<core::type::Scalar>()) {
            return t;
        }
        if (auto* v = t->As<core::type::Vector>()) {
            return v->type();
        }
        return nullptr;
    };
    auto result = result_op();
    auto result_id = std::get<uint32_t>(result);
    auto result_type_id = GenerateTypeIfNeeded(to_type);
    if (result_type_id == 0) {
        return 0;
    }
    auto val_id = GenerateExpression(from_expr);
    if (val_id == 0) {
        return 0;
    }
    auto* from_type = TypeOf(from_expr)->UnwrapRef();
    spv::Op op = spv::Op::OpNop;
    if ((from_type->Is<core::type::I32>() && to_type->is_float_scalar()) ||
        (from_type->is_signed_integer_vector() && to_type->is_float_vector())) {
        op = spv::Op::OpConvertSToF;
    } else if ((from_type->Is<core::type::U32>() && to_type->is_float_scalar()) ||
               (from_type->is_unsigned_integer_vector() && to_type->is_float_vector())) {
        op = spv::Op::OpConvertUToF;
    } else if ((from_type->is_float_scalar() && to_type->Is<core::type::I32>()) ||
               (from_type->is_float_vector() && to_type->is_signed_integer_vector())) {
        op = spv::Op::OpConvertFToS;
    } else if ((from_type->is_float_scalar() && to_type->Is<core::type::U32>()) ||
               (from_type->is_float_vector() && to_type->is_unsigned_integer_vector())) {
        op = spv::Op::OpConvertFToU;
    } else if (from_type->IsAnyOf<core::type::Bool, core::type::F32, core::type::I32,
                                  core::type::U32, core::type::F16, core::type::Vector>() &&
               from_type == to_type) {
        // Identity initializer for scalar and vector types
        return val_id;
    } else if ((from_type->is_float_scalar() && to_type->is_float_scalar()) ||
               (from_type->is_float_vector() && to_type->is_float_vector() &&
                from_type->As<core::type::Vector>()->Width() ==
                    to_type->As<core::type::Vector>()->Width())) {
        // Convert between f32 and f16 types.
        // OpFConvert requires the scalar component types to be different, and the case of from_type
        // and to_type being the same floating point scalar or vector type, i.e. identity
        // initializer, is already handled in the previous else-if clause.
        op = spv::Op::OpFConvert;
    } else if ((from_type->Is<core::type::I32>() && to_type->Is<core::type::U32>()) ||
               (from_type->Is<core::type::U32>() && to_type->Is<core::type::I32>()) ||
               (from_type->is_signed_integer_vector() && to_type->is_unsigned_integer_vector()) ||
               (from_type->is_unsigned_integer_vector() &&
                to_type->is_signed_integer_vector())) {
        // Reinterpret between signed and unsigned integers: scalar <-> scalar or
        // vector <-> vector only. Matching vector widths are presumably guaranteed by the
        // resolver. Vector <-> scalar reinterpretation would be an invalid OpBitcast.
        op = spv::Op::OpBitcast;
    } else if ((from_type->Is<core::type::NumericScalar>() && to_type->Is<core::type::Bool>()) ||
               (from_type->is_numeric_vector() && to_type->is_bool_vector())) {
        // Convert scalar (vector) to bool (vector)
        // Return the result of comparing from_expr with zero
        uint32_t zero = GenerateConstantNullIfNeeded(from_type);
        if (zero == 0) {
            return 0;
        }
        const auto* from_elem_type = elem_type_of(from_type);
        op = from_elem_type->is_integer_scalar() ? spv::Op::OpINotEqual : spv::Op::OpFUnordNotEqual;
        if (!push_function_inst(op, {Operand(result_type_id), Operand(result_id), Operand(val_id),
                                     Operand(zero)})) {
            return 0;
        }
        return result_id;
    } else if (from_type->is_bool_scalar_or_vector() && to_type->is_numeric_scalar_or_vector()) {
        // Convert bool scalar/vector to numeric scalar/vector.
        // Use the bool to select between 1 (if true) and 0 (if false).
        const auto* to_elem_type = elem_type_of(to_type);
        uint32_t one_id;
        uint32_t zero_id;
        if (to_elem_type->Is<core::type::F32>()) {
            zero_id = GenerateConstantIfNeeded(ScalarConstant::F32(0));
            one_id = GenerateConstantIfNeeded(ScalarConstant::F32(1));
        } else if (to_elem_type->Is<core::type::F16>()) {
            zero_id = GenerateConstantIfNeeded(ScalarConstant::F16(0));
            one_id = GenerateConstantIfNeeded(ScalarConstant::F16(1));
        } else if (to_elem_type->Is<core::type::U32>()) {
            zero_id = GenerateConstantIfNeeded(ScalarConstant::U32(0));
            one_id = GenerateConstantIfNeeded(ScalarConstant::U32(1));
        } else if (to_elem_type->Is<core::type::I32>()) {
            zero_id = GenerateConstantIfNeeded(ScalarConstant::I32(0));
            one_id = GenerateConstantIfNeeded(ScalarConstant::I32(1));
        } else {
            TINT_ICE() << "invalid destination type for bool conversion";
            return 0;
        }
        if (auto* to_vec = to_type->As<core::type::Vector>()) {
            // Splat the scalars into vectors.
            zero_id = GenerateConstantVectorSplatIfNeeded(to_vec, zero_id);
            one_id = GenerateConstantVectorSplatIfNeeded(to_vec, one_id);
        }
        if (!one_id || !zero_id) {
            return 0;
        }
        op = spv::Op::OpSelect;
        if (!push_function_inst(op, {Operand(result_type_id), Operand(result_id), Operand(val_id),
                                     Operand(one_id), Operand(zero_id)})) {
            return 0;
        }
        return result_id;
    } else if (TINT_LIKELY(from_type->Is<core::type::Matrix>() &&
                           to_type->Is<core::type::Matrix>())) {
        // SPIRV does not support matrix conversion, the only valid case is matrix identity
        // initializer. Matrix conversion between f32 and f16 should be transformed into vector
        // conversions for each column vectors by VectorizeMatrixConversions.
        auto* from_mat = from_type->As<core::type::Matrix>();
        auto* to_mat = to_type->As<core::type::Matrix>();
        if (TINT_LIKELY(from_mat == to_mat)) {
            return val_id;
        }
        TINT_ICE() << "matrix conversion is not supported and should have been handled by "
                      "VectorizeMatrixConversions";
        return 0;
    } else {
        TINT_ICE() << "Invalid from_type";
        return 0;
    }
    if (op == spv::Op::OpNop) {
        TINT_ICE() << "unable to determine conversion type for cast, from: " +
                          from_type->FriendlyName() + " to: " + to_type->FriendlyName();
        return 0;
    }
    if (!push_function_inst(op, {Operand(result_type_id), result, Operand(val_id)})) {
        return 0;
    }
    return result_id;
}
// Returns the ID of the scalar constant for the literal `lit`, emitting it if needed.
// Unsuffixed integer literals are emitted as i32 and unsuffixed float literals as f32.
// `h`-suffixed float literals are stored as their f16 bit representation.
uint32_t Builder::GenerateLiteralIfNeeded(const ast::LiteralExpression* lit) {
    ScalarConstant constant;
    Switch(
        lit,
        [&](const ast::BoolLiteralExpression* bool_lit) {
            constant.kind = ScalarConstant::Kind::kBool;
            constant.value.b = bool_lit->value;
        },
        [&](const ast::IntLiteralExpression* int_lit) {
            switch (int_lit->suffix) {
                case ast::IntLiteralExpression::Suffix::kNone:
                case ast::IntLiteralExpression::Suffix::kI:
                    constant.kind = ScalarConstant::Kind::kI32;
                    constant.value.i32 = static_cast<int32_t>(int_lit->value);
                    break;
                case ast::IntLiteralExpression::Suffix::kU:
                    constant.kind = ScalarConstant::Kind::kU32;
                    constant.value.u32 = static_cast<uint32_t>(int_lit->value);
                    break;
            }
        },
        [&](const ast::FloatLiteralExpression* float_lit) {
            switch (float_lit->suffix) {
                case ast::FloatLiteralExpression::Suffix::kNone:
                case ast::FloatLiteralExpression::Suffix::kF:
                    constant.kind = ScalarConstant::Kind::kF32;
                    constant.value.f32 = static_cast<float>(float_lit->value);
                    break;
                case ast::FloatLiteralExpression::Suffix::kH:
                    constant.kind = ScalarConstant::Kind::kF16;
                    constant.value.f16 = {
                        f16(static_cast<float>(float_lit->value)).BitsRepresentation()};
                    break;
            }
        },
        [&](Default) { TINT_ICE() << "unknown literal type"; });
    // Don't emit a constant once an error has been reported.
    if (has_error()) {
        return 0;
    }
    return GenerateConstantIfNeeded(constant);
}
// Returns the ID of the constant value `constant`, emitting it (and its constituents) if needed.
//  - All-zero values become an OpConstantNull.
//  - Scalars are forwarded to the ScalarConstant overload.
//  - Vectors, matrices, arrays and structures become OpConstantComposite instructions built from
//    their element constants, deduplicated in the global scope.
uint32_t Builder::GenerateConstantIfNeeded(const core::constant::Value* constant) {
    if (constant->AllZero()) {
        return GenerateConstantNullIfNeeded(constant->Type());
    }

    auto* ty = constant->Type();

    // Emits (or reuses) an OpConstantComposite of type `ty` whose constituents are the constants
    // constant->Index(0 .. num_elements - 1).
    auto composite = [&](size_t num_elements) -> uint32_t {
        auto type_id = GenerateTypeIfNeeded(ty);
        if (!type_id) {
            return 0;
        }
        static constexpr size_t kOpsResultIdx = 1;  // operand index of the result
        std::vector<Operand> ops;
        ops.reserve(num_elements + 2);
        ops.emplace_back(type_id);
        ops.push_back(Operand(0u));  // Placeholder for the result ID
        for (size_t i = 0; i < num_elements; i++) {
            auto element_id = GenerateConstantIfNeeded(constant->Index(i));
            if (!element_id) {
                return 0;
            }
            ops.emplace_back(element_id);
        }
        auto& global_scope = scope_stack_[0];
        return tint::GetOrCreate(global_scope.type_init_to_id_, OperandListKey{ops},
                                 [&]() -> uint32_t {
                                     auto result = result_op();
                                     ops[kOpsResultIdx] = result;
                                     module_.PushType(spv::Op::OpConstantComposite, std::move(ops));
                                     return std::get<uint32_t>(result);
                                 });
    };

    if (ty->Is<core::type::Bool>()) {
        return GenerateConstantIfNeeded(ScalarConstant::Bool(constant->ValueAs<bool>()));
    }
    if (ty->Is<core::type::F32>()) {
        return GenerateConstantIfNeeded(ScalarConstant::F32(constant->ValueAs<f32>().value));
    }
    if (ty->Is<core::type::F16>()) {
        return GenerateConstantIfNeeded(ScalarConstant::F16(constant->ValueAs<f16>().value));
    }
    if (ty->Is<core::type::I32>()) {
        return GenerateConstantIfNeeded(ScalarConstant::I32(constant->ValueAs<i32>().value));
    }
    if (ty->Is<core::type::U32>()) {
        return GenerateConstantIfNeeded(ScalarConstant::U32(constant->ValueAs<u32>().value));
    }
    if (auto* vec_ty = ty->As<core::type::Vector>()) {
        return composite(vec_ty->Width());
    }
    if (auto* mat_ty = ty->As<core::type::Matrix>()) {
        return composite(mat_ty->columns());
    }
    if (auto* arr_ty = ty->As<core::type::Array>()) {
        auto count = arr_ty->ConstantCount();
        if (!count) {
            TINT_ICE() << core::type::Array::kErrExpectedConstantCount;
            return 0;
        }
        return composite(count.value());
    }
    if (auto* str_ty = ty->As<core::type::Struct>()) {
        return composite(str_ty->Members().Length());
    }
    TINT_ICE() << "unhandled constant type: " + ty->FriendlyName();
    return 0;
}
uint32_t Builder::GenerateConstantIfNeeded(const ScalarConstant& constant) {
auto it = const_to_id_.find(constant);
if (it != const_to_id_.end()) {
return it->second;
}
uint32_t type_id = 0;
switch (constant.kind) {
case ScalarConstant::Kind::kU32: {
type_id = GenerateTypeIfNeeded(builder_.create<core::type::U32>());
break;
}
case ScalarConstant::Kind::kI32: {
type_id = GenerateTypeIfNeeded(builder_.create<core::type::I32>());
break;
}
case ScalarConstant::Kind::kF32: {
type_id = GenerateTypeIfNeeded(builder_.create<core::type::F32>());
break;
}
case ScalarConstant::Kind::kF16: {
type_id = GenerateTypeIfNeeded(builder_.create<core::type::F16>());
break;
}
case ScalarConstant::Kind::kBool: {
type_id = GenerateTypeIfNeeded(builder_.create<core::type::Bool>());
break;
}
}
if (type_id == 0) {
return 0;
}
auto result = result_op();
auto result_id = std::get<uint32_t>(result);
switch (constant.kind) {
case ScalarConstant::Kind::kU32: {
module_.PushType(spv::Op::OpConstant,
{Operand(type_id), result, Operand(constant.value.u32)});
break;
}
case ScalarConstant::Kind::kI32: {
module_.PushType(spv::Op::OpConstant,
{Operand(type_id), result, U32Operand(constant.value.i32)});
break;
}
case ScalarConstant::Kind::kF32: {
module_.PushType(spv::Op::OpConstant,
{Operand(type_id), result, Operand(constant.value.f32)});
break;
}
case ScalarConstant::Kind::kF16: {
module_.PushType(
spv::Op::OpConstant,
{Operand(type_id), result, U32Operand(constant.value.f16.bits_representation)});
break;
}
case ScalarConstant::Kind::kBool: {
if (constant.value.b) {
module_.PushType(spv::Op::OpConstantTrue, {Operand(type_id), result});
} else {
module_.PushType(spv::Op::OpConstantFalse, {Operand(type_id), result});
}
break;
}
}
const_to_id_[constant] = result_id;
return result_id;
}
// Returns the ID of an OpConstantNull of `type`, emitting it the first time a given type is
// requested (memoized in const_null_to_id_). Returns 0 if the type could not be generated.
uint32_t Builder::GenerateConstantNullIfNeeded(const core::type::Type* type) {
    const auto type_id = GenerateTypeIfNeeded(type);
    if (type_id == 0) {
        return 0;
    }
    auto emit_null = [&] {
        auto null_value = result_op();
        module_.PushType(spv::Op::OpConstantNull, {Operand(type_id), null_value});
        return std::get<uint32_t>(null_value);
    };
    return tint::GetOrCreate(const_null_to_id_, type, emit_null);
}
// Returns the ID of an OpConstantComposite of vector type `type` in which every component is
// the constant `value_id`. Results are memoized per (vector width, value ID) pair.
// Returns 0 if `value_id` is 0 or the vector type could not be generated.
uint32_t Builder::GenerateConstantVectorSplatIfNeeded(const core::type::Vector* type,
                                                      uint32_t value_id) {
    auto type_id = GenerateTypeIfNeeded(type);
    if (type_id == 0 || value_id == 0) {
        return 0;
    }
    // The component type is presumably implied by `value_id` (a scalar constant ID), so the
    // width and the value ID together identify the splat.
    uint64_t key = (static_cast<uint64_t>(type->Width()) << 32) + value_id;
    return tint::GetOrCreate(const_splat_to_id_, key, [&] {
        auto result = result_op();
        OperandList ops;
        ops.push_back(Operand(type_id));
        ops.push_back(result);
        for (uint32_t i = 0; i < type->Width(); i++) {
            ops.push_back(Operand(value_id));
        }
        module_.PushType(spv::Op::OpConstantComposite, ops);
        // GetOrCreate() records the returned ID in const_splat_to_id_ itself; the map must not
        // also be written from inside this factory.
        return std::get<uint32_t>(result);
    });
}
// Emits a short-circuiting logical `&&` / `||` as structured control flow:
//   <lhs>; OpSelectionMerge %merge; OpBranchConditional %lhs <true target> <false target>
//   %rhs_block: <rhs>; OpBranch %merge
//   %merge: %result = OpPhi %type %lhs <lhs block> %rhs <last rhs block>
// Returns the ID of the OpPhi result, or 0 on failure.
uint32_t Builder::GenerateShortCircuitBinaryExpression(const ast::BinaryExpression* expr) {
    const auto lhs_id = GenerateExpression(expr->lhs);
    if (lhs_id == 0) {
        return 0;
    }
    // Control flow diverges from the last basic block generated for the left-hand side.
    const auto lhs_block_id = current_label_id_;

    const auto type_id = GenerateTypeIfNeeded(TypeOf(expr));
    if (type_id == 0) {
        return 0;
    }

    // Allocate the merge block's ID first, then the ID of the block that evaluates the RHS.
    const auto merge_block_id = std::get<uint32_t>(result_op());
    const auto rhs_entry_block_id = std::get<uint32_t>(result_op());

    // `a && b` only needs `b` when `a` is true; `a || b` only needs `b` when `a` is false.
    const bool is_logical_or = expr->IsLogicalOr();
    const auto true_target_id = is_logical_or ? merge_block_id : rhs_entry_block_id;
    const auto false_target_id = is_logical_or ? rhs_entry_block_id : merge_block_id;

    const bool branched =
        push_function_inst(spv::Op::OpSelectionMerge,
                           {Operand(merge_block_id), U32Operand(SpvSelectionControlMaskNone)}) &&
        push_function_inst(spv::Op::OpBranchConditional,
                           {Operand(lhs_id), Operand(true_target_id), Operand(false_target_id)});
    if (!branched) {
        return 0;
    }

    // Evaluate the RHS in its own block.
    if (!GenerateLabel(rhs_entry_block_id)) {
        return 0;
    }
    const auto rhs_id = GenerateExpression(expr->rhs);
    if (rhs_id == 0) {
        return 0;
    }
    // The RHS may have generated further blocks; the last one is the merge block's predecessor.
    const auto rhs_exit_block_id = current_label_id_;
    if (!push_function_inst(spv::Op::OpBranch, {Operand(merge_block_id)})) {
        return 0;
    }

    if (!GenerateLabel(merge_block_id)) {
        return 0;
    }
    auto result = result_op();
    const auto result_id = std::get<uint32_t>(result);
    if (!push_function_inst(spv::Op::OpPhi,
                            {Operand(type_id), result, Operand(lhs_id), Operand(lhs_block_id),
                             Operand(rhs_id), Operand(rhs_exit_block_id)})) {
        return 0;
    }
    return result_id;
}
// Replicates the scalar `scalar_id` into every component of the vector type `vec_type` with an
// OpCompositeConstruct, and returns the ID of the constructed vector (0 if the instruction could
// not be pushed). GenerateBinaryExpression uses this for mixed vector/scalar arithmetic.
// NOTE(review): assumes `vec_type` is a core::type::Vector. The As<>() result below is
// dereferenced without a null check.
uint32_t Builder::GenerateSplat(uint32_t scalar_id, const core::type::Type* vec_type) {
    // Create a new vector to splat scalar into
    // NOTE(review): this function-scope OpVariable (`splat_vector`) is not referenced again in
    // this function; the returned value comes from the OpCompositeConstruct below. It looks like
    // dead output. Removing it would shift every ID allocated afterwards, so check the
    // expected-output tests before dropping it. The results of GenerateTypeIfNeeded() and
    // push_function_var() are also not checked for failure here.
    auto splat_vector = result_op();
    auto* splat_vector_type = builder_.create<core::type::Pointer>(
        core::AddressSpace::kFunction, vec_type, core::Access::kReadWrite);
    push_function_var({Operand(GenerateTypeIfNeeded(splat_vector_type)), splat_vector,
                       U32Operand(ConvertAddressSpace(core::AddressSpace::kFunction)),
                       Operand(GenerateConstantNullIfNeeded(vec_type))});
    // Splat scalar into vector
    auto splat_result = result_op();
    OperandList ops;
    ops.push_back(Operand(GenerateTypeIfNeeded(vec_type)));
    ops.push_back(splat_result);
    // One copy of the scalar per vector component.
    for (size_t i = 0; i < vec_type->As<core::type::Vector>()->Width(); ++i) {
        ops.push_back(Operand(scalar_id));
    }
    if (!push_function_inst(spv::Op::OpCompositeConstruct, ops)) {
        return 0;
    }
    return std::get<uint32_t>(splat_result);
}
// Emits a column-wise matrix addition or subtraction of `lhs_id` and `rhs_id`, both of matrix
// type `type`. Each column is extracted from both operands and combined with `op`
// (GenerateBinaryExpression passes OpFAdd or OpFSub). The resulting columns are reassembled
// with OpCompositeConstruct. Returns the ID of the resulting matrix, or 0 on failure.
uint32_t Builder::GenerateMatrixAddOrSub(uint32_t lhs_id,
                                         uint32_t rhs_id,
                                         const core::type::Matrix* type,
                                         spv::Op op) {
    // Example addition of two matrices:
    // %31 = OpLoad %mat3v4float %m34
    // %32 = OpLoad %mat3v4float %m34
    // %33 = OpCompositeExtract %v4float %31 0
    // %34 = OpCompositeExtract %v4float %32 0
    // %35 = OpFAdd %v4float %33 %34
    // %36 = OpCompositeExtract %v4float %31 1
    // %37 = OpCompositeExtract %v4float %32 1
    // %38 = OpFAdd %v4float %36 %37
    // %39 = OpCompositeExtract %v4float %31 2
    // %40 = OpCompositeExtract %v4float %32 2
    // %41 = OpFAdd %v4float %39 %40
    // %42 = OpCompositeConstruct %mat3v4float %35 %38 %41
    // Each column is a vector of `rows()` elements of the matrix's element type.
    // NOTE(review): column_type_id is not checked for 0 (type generation failure).
    auto* column_type = builder_.create<core::type::Vector>(type->type(), type->rows());
    auto column_type_id = GenerateTypeIfNeeded(column_type);
    // Collects the result column IDs; the matrix type and result ID are prepended afterwards.
    OperandList ops;
    for (uint32_t i = 0; i < type->columns(); ++i) {
        // Extract column `i` from lhs mat
        auto lhs_column_id = result_op();
        if (!push_function_inst(
                spv::Op::OpCompositeExtract,
                {Operand(column_type_id), lhs_column_id, Operand(lhs_id), Operand(i)})) {
            return 0;
        }
        // Extract column `i` from rhs mat
        auto rhs_column_id = result_op();
        if (!push_function_inst(
                spv::Op::OpCompositeExtract,
                {Operand(column_type_id), rhs_column_id, Operand(rhs_id), Operand(i)})) {
            return 0;
        }
        // Add or subtract the two columns
        auto result = result_op();
        if (!push_function_inst(op,
                                {Operand(column_type_id), result, lhs_column_id, rhs_column_id})) {
            return 0;
        }
        ops.push_back(result);
    }
    // Create the result matrix from the added/subtracted column vectors
    TINT_BEGIN_DISABLE_WARNING(MAYBE_UNINITIALIZED);  // GCC false-positive
    auto result_mat_id = result_op();
    // Inserting at the front twice yields the operand order [type, result, columns...].
    ops.insert(ops.begin(), result_mat_id);
    ops.insert(ops.begin(), Operand(GenerateTypeIfNeeded(type)));
    if (!push_function_inst(spv::Op::OpCompositeConstruct, ops)) {
        return 0;
    }
    TINT_END_DISABLE_WARNING(MAYBE_UNINITIALIZED);  // GCC false-positive
    return std::get<uint32_t>(result_mat_id);
}
// Emits the SPIR-V for the binary expression `expr` and returns its result ID (0 on failure).
// Logical `&&` / `||` are delegated to GenerateShortCircuitBinaryExpression(). Matrix +/- matrix
// is lowered column-wise. For other arithmetic, a scalar operand mixed with a vector operand is
// splatted to a vector first, except for float vector * scalar, which uses OpVectorTimesScalar.
// The SPIR-V opcode is then chosen from the operator and the LHS type.
uint32_t Builder::GenerateBinaryExpression(const ast::BinaryExpression* expr) {
    // There is special logic for short circuiting operators.
    if (expr->IsLogicalAnd() || expr->IsLogicalOr()) {
        return GenerateShortCircuitBinaryExpression(expr);
    }
    auto lhs_id = GenerateExpression(expr->lhs);
    if (lhs_id == 0) {
        return 0;
    }
    auto rhs_id = GenerateExpression(expr->rhs);
    if (rhs_id == 0) {
        return 0;
    }
    auto result = result_op();
    auto result_id = std::get<uint32_t>(result);
    auto type_id = GenerateTypeIfNeeded(TypeOf(expr));
    if (type_id == 0) {
        return 0;
    }
    // Handle int and float and the vectors of those types. Other types
    // should have been rejected by validation.
    auto* lhs_type = TypeOf(expr->lhs)->UnwrapRef();
    auto* rhs_type = TypeOf(expr->rhs)->UnwrapRef();
    // Handle matrix-matrix addition and subtraction
    if ((expr->IsAdd() || expr->IsSubtract()) && lhs_type->is_float_matrix() &&
        rhs_type->is_float_matrix()) {
        auto* lhs_mat = lhs_type->As<core::type::Matrix>();
        auto* rhs_mat = rhs_type->As<core::type::Matrix>();
        // This should already have been validated by resolver
        if (lhs_mat->rows() != rhs_mat->rows() || lhs_mat->columns() != rhs_mat->columns()) {
            TINT_ICE() << "matrices must have same dimensionality for add or subtract";
            return 0;
        }
        return GenerateMatrixAddOrSub(lhs_id, rhs_id, lhs_mat,
                                      expr->IsAdd() ? spv::Op::OpFAdd : spv::Op::OpFSub);
    }
    // For vector-scalar arithmetic operations, splat scalar into a vector. We
    // skip this for multiply as we can use OpVectorTimesScalar.
    const bool is_float_scalar_vector_multiply =
        expr->IsMultiply() && ((lhs_type->is_float_scalar() && rhs_type->is_float_vector()) ||
                               (lhs_type->is_float_vector() && rhs_type->is_float_scalar()));
    if (expr->IsArithmetic() && !is_float_scalar_vector_multiply) {
        if (lhs_type->Is<core::type::Vector>() && rhs_type->Is<core::type::NumericScalar>()) {
            uint32_t splat_vector_id = GenerateSplat(rhs_id, lhs_type);
            if (splat_vector_id == 0) {
                return 0;
            }
            rhs_id = splat_vector_id;
            rhs_type = lhs_type;
        } else if (lhs_type->Is<core::type::NumericScalar>() &&
                   rhs_type->Is<core::type::Vector>()) {
            uint32_t splat_vector_id = GenerateSplat(lhs_id, rhs_type);
            if (splat_vector_id == 0) {
                return 0;
            }
            lhs_id = splat_vector_id;
            lhs_type = rhs_type;
        }
    }
    bool lhs_is_float_or_vec = lhs_type->is_float_scalar_or_vector();
    bool lhs_is_bool_or_vec = lhs_type->is_bool_scalar_or_vector();
    bool lhs_is_integer_or_vec = lhs_type->is_integer_scalar_or_vector();
    bool lhs_is_unsigned = lhs_type->is_unsigned_integer_scalar_or_vector();
    spv::Op op = spv::Op::OpNop;
    if (expr->IsAnd()) {
        if (lhs_is_integer_or_vec) {
            op = spv::Op::OpBitwiseAnd;
        } else if (lhs_is_bool_or_vec) {
            op = spv::Op::OpLogicalAnd;
        } else {
            TINT_ICE() << "invalid and expression";
            return 0;
        }
    } else if (expr->IsAdd()) {
        op = lhs_is_float_or_vec ? spv::Op::OpFAdd : spv::Op::OpIAdd;
    } else if (expr->IsDivide()) {
        if (lhs_is_float_or_vec) {
            op = spv::Op::OpFDiv;
        } else if (lhs_is_unsigned) {
            op = spv::Op::OpUDiv;
        } else {
            op = spv::Op::OpSDiv;
        }
    } else if (expr->IsEqual()) {
        if (lhs_is_float_or_vec) {
            op = spv::Op::OpFOrdEqual;
        } else if (lhs_is_bool_or_vec) {
            op = spv::Op::OpLogicalEqual;
        } else if (lhs_is_integer_or_vec) {
            op = spv::Op::OpIEqual;
        } else {
            TINT_ICE() << "invalid equal expression";
            return 0;
        }
    } else if (expr->IsGreaterThan()) {
        if (lhs_is_float_or_vec) {
            op = spv::Op::OpFOrdGreaterThan;
        } else if (lhs_is_unsigned) {
            op = spv::Op::OpUGreaterThan;
        } else {
            op = spv::Op::OpSGreaterThan;
        }
    } else if (expr->IsGreaterThanEqual()) {
        if (lhs_is_float_or_vec) {
            op = spv::Op::OpFOrdGreaterThanEqual;
        } else if (lhs_is_unsigned) {
            op = spv::Op::OpUGreaterThanEqual;
        } else {
            op = spv::Op::OpSGreaterThanEqual;
        }
    } else if (expr->IsLessThan()) {
        if (lhs_is_float_or_vec) {
            op = spv::Op::OpFOrdLessThan;
        } else if (lhs_is_unsigned) {
            op = spv::Op::OpULessThan;
        } else {
            op = spv::Op::OpSLessThan;
        }
    } else if (expr->IsLessThanEqual()) {
        if (lhs_is_float_or_vec) {
            op = spv::Op::OpFOrdLessThanEqual;
        } else if (lhs_is_unsigned) {
            op = spv::Op::OpULessThanEqual;
        } else {
            op = spv::Op::OpSLessThanEqual;
        }
    } else if (expr->IsModulo()) {
        if (lhs_is_float_or_vec) {
            op = spv::Op::OpFRem;
        } else if (lhs_is_unsigned) {
            op = spv::Op::OpUMod;
        } else {
            op = spv::Op::OpSRem;
        }
    } else if (expr->IsMultiply()) {
        if (lhs_type->is_integer_scalar_or_vector()) {
            // If the left hand side is an integer then this _has_ to be OpIMul as
            // there is no other integer multiplication.
            op = spv::Op::OpIMul;
        } else if (lhs_type->is_float_scalar() && rhs_type->is_float_scalar()) {
            // Float scalars multiply with OpFMul
            op = spv::Op::OpFMul;
        } else if (lhs_type->is_float_vector() && rhs_type->is_float_vector()) {
            // Float vectors must be validated to be the same size and then use OpFMul
            op = spv::Op::OpFMul;
        } else if (lhs_type->is_float_scalar() && rhs_type->is_float_vector()) {
            // Scalar * Vector we need to flip lhs and rhs types
            // because OpVectorTimesScalar expects <vector>, <scalar>
            std::swap(lhs_id, rhs_id);
            op = spv::Op::OpVectorTimesScalar;
        } else if (lhs_type->is_float_vector() && rhs_type->is_float_scalar()) {
            // float vector * scalar
            op = spv::Op::OpVectorTimesScalar;
        } else if (lhs_type->is_float_scalar() && rhs_type->is_float_matrix()) {
            // Scalar * Matrix we need to flip lhs and rhs types because
            // OpMatrixTimesScalar expects <matrix>, <scalar>
            std::swap(lhs_id, rhs_id);
            op = spv::Op::OpMatrixTimesScalar;
        } else if (lhs_type->is_float_matrix() && rhs_type->is_float_scalar()) {
            // float matrix * scalar
            op = spv::Op::OpMatrixTimesScalar;
        } else if (lhs_type->is_float_vector() && rhs_type->is_float_matrix()) {
            // float vector * matrix
            op = spv::Op::OpVectorTimesMatrix;
        } else if (lhs_type->is_float_matrix() && rhs_type->is_float_vector()) {
            // float matrix * vector
            op = spv::Op::OpMatrixTimesVector;
        } else if (lhs_type->is_float_matrix() && rhs_type->is_float_matrix()) {
            // float matrix * matrix
            op = spv::Op::OpMatrixTimesMatrix;
        } else {
            TINT_ICE() << "invalid multiply expression";
            return 0;
        }
    } else if (expr->IsNotEqual()) {
        if (lhs_is_float_or_vec) {
            op = spv::Op::OpFOrdNotEqual;
        } else if (lhs_is_bool_or_vec) {
            op = spv::Op::OpLogicalNotEqual;
        } else if (lhs_is_integer_or_vec) {
            op = spv::Op::OpINotEqual;
        } else {
            TINT_ICE() << "invalid not-equal expression";
            return 0;
        }
    } else if (expr->IsOr()) {
        if (lhs_is_integer_or_vec) {
            op = spv::Op::OpBitwiseOr;
        } else if (lhs_is_bool_or_vec) {
            op = spv::Op::OpLogicalOr;
        } else {
            TINT_ICE() << "invalid or expression";
            return 0;
        }
    } else if (expr->IsShiftLeft()) {
        op = spv::Op::OpShiftLeftLogical;
    } else if (expr->IsShiftRight() && lhs_type->is_signed_integer_scalar_or_vector()) {
        // A shift right with a signed LHS is an arithmetic shift.
        op = spv::Op::OpShiftRightArithmetic;
    } else if (expr->IsShiftRight()) {
        op = spv::Op::OpShiftRightLogical;
    } else if (expr->IsSubtract()) {
        op = lhs_is_float_or_vec ? spv::Op::OpFSub : spv::Op::OpISub;
    } else if (expr->IsXor()) {
        op = spv::Op::OpBitwiseXor;
    } else {
        TINT_ICE() << "unknown binary expression";
        return 0;
    }
    if (!push_function_inst(op, {Operand(type_id), result, Operand(lhs_id), Operand(rhs_id)})) {
        return 0;
    }
    return result_id;
}
// Generates all statements of `stmt` inside a freshly pushed scope.
// @param stmt the block statement to generate
// @returns true on success, false if any contained statement failed
bool Builder::GenerateBlockStatement(const ast::BlockStatement* stmt) {
    // Open a new scope for the block. TINT_DEFER pops it again on every exit
    // path, so the scope stack stays balanced even when generation fails.
    PushScope();
    TINT_DEFER(PopScope());

    const bool ok = GenerateBlockStatementWithoutScoping(stmt);
    return ok;
}
// Generates each statement of `stmt` in order, in the current scope.
// No scope is pushed here; GenerateBlockStatement() wraps this with one.
// @param stmt the block statement whose statements are generated
// @returns true if every statement was generated, false at the first failure
bool Builder::GenerateBlockStatementWithoutScoping(const ast::BlockStatement* stmt) {
    const auto& body = stmt->statements;
    // std::all_of short-circuits, so generation stops at the first failing statement.
    return std::all_of(body.begin(), body.end(),
                       [&](auto* s) { return GenerateStatement(s); });
}
// Generates a call expression by dispatching on the semantic call target:
// user functions, builtins, and value constructors/conversions.
// @param expr the call expression
// @returns the SPIR-V result id of the call, or 0 on error
uint32_t Builder::GenerateCallExpression(const ast::CallExpression* expr) {
    auto* sem_call = builder_.Sem().Get<sem::Call>(expr);
    auto* callee = sem_call->Target();

    // Value constructors and value conversions share a single code path.
    auto gen_value_ctor_or_conv = [&] {
        return GenerateValueConstructorOrConversion(sem_call, nullptr);
    };

    return Switch(
        callee,  //
        [&](const sem::Function* fn) { return GenerateFunctionCall(sem_call, fn); },
        [&](const sem::Builtin* b) { return GenerateBuiltinCall(sem_call, b); },
        [&](const sem::ValueConversion*) { return gen_value_ctor_or_conv(); },
        [&](const sem::ValueConstructor*) { return gen_value_ctor_or_conv(); },
        [&](Default) {
            TINT_ICE() << "unhandled call target: " << callee->TypeInfo().name;
            return 0;
        });
}
// Emits an OpFunctionCall to the user-declared function `fn`.
// @param call the semantic call expression
// @param fn the called function
// @returns the SPIR-V result id of the call, or 0 on error
uint32_t Builder::GenerateFunctionCall(const sem::Call* call, const sem::Function* fn) {
    // Resolve the SPIR-V type of the call's result first; 0 signals failure.
    const auto ret_type_id = GenerateTypeIfNeeded(call->Type());
    if (ret_type_id == 0) {
        return 0;
    }

    auto result = result_op();
    const auto result_id = std::get<uint32_t>(result);

    // Look up the id previously registered for the callee's symbol.
    const auto* callee_name = fn->Declaration()->name;
    const auto callee_id = func_symbol_to_id_[callee_name->symbol];
    if (callee_id == 0) {
        TINT_ICE() << "unable to find called function: " + callee_name->symbol.Name();
        return 0;
    }

    OperandList operands{Operand(ret_type_id), result, Operand(callee_id)};
    for (const auto* arg_expr : call->Declaration()->args) {
        const auto arg_id = GenerateExpression(arg_expr);
        if (arg_id == 0) {
            return 0;
        }
        operands.push_back(Operand(arg_id));
    }

    return push_function_inst(spv::Op::OpFunctionCall, std::move(operands)) ? result_id : 0;
}
// Emits SPIR-V for a call to a WGSL builtin function.
// Texture, barrier and atomic builtins are delegated to dedicated generators.
// Every other builtin is lowered either to a core SPIR-V instruction or to an
// OpExtInst from the GLSL.std.450 extended instruction set.
// @param call the semantic call expression
// @param builtin the builtin being called
// @returns the SPIR-V result id, or 0 on error
uint32_t Builder::GenerateBuiltinCall(const sem::Call* call, const sem::Builtin* builtin) {
    auto result = result_op();
    auto result_id = std::get<uint32_t>(result);
    auto result_type_id = GenerateTypeIfNeeded(builtin->ReturnType());
    if (result_type_id == 0) {
        return 0;
    }
    if (builtin->IsFineDerivative() || builtin->IsCoarseDerivative()) {
        module_.PushCapability(SpvCapabilityDerivativeControl);
    }
    if (builtin->IsImageQuery()) {
        module_.PushCapability(SpvCapabilityImageQuery);
    }
    if (builtin->IsTexture()) {
        if (!GenerateTextureBuiltin(call, builtin, Operand(result_type_id), result)) {
            return 0;
        }
        return result_id;
    }
    if (builtin->IsBarrier()) {
        // The barrier generator does not receive `result`; the allocated id is
        // returned only to signal success.
        if (!GenerateControlBarrierBuiltin(builtin)) {
            return 0;
        }
        return result_id;
    }
    if (builtin->IsAtomic()) {
        if (!GenerateAtomicBuiltin(call, builtin, Operand(result_type_id), result)) {
            return 0;
        }
        return result_id;
    }
    // Generates the SPIR-V ID for the expression for the indexed call argument,
    // and loads it if necessary. Returns 0 on error.
    auto get_arg_as_value_id = [&](size_t i, bool generate_load = true) -> uint32_t {
        auto* arg = call->Arguments()[i];
        auto* param = builtin->Parameters()[i];
        auto val_id = GenerateExpression(arg->Declaration());
        if (val_id == 0) {
            return 0;
        }
        // Pointer parameters take the pointer itself; everything else takes a value.
        if (generate_load && !param->Type()->Is<core::type::Pointer>()) {
            val_id = GenerateLoadIfNeeded(arg->Type(), val_id);
        }
        return val_id;
    };
    OperandList params = {Operand(result_type_id), result};
    spv::Op op = spv::Op::OpNop;
    // Pushes the arguments for a GlslStd450 extended instruction, and sets op
    // to OpExtInst.
    auto glsl_std450 = [&](uint32_t inst_id) {
        auto set_id = GetGLSLstd450Import();
        params.push_back(Operand(set_id));
        params.push_back(Operand(inst_id));
        op = spv::Op::OpExtInst;
    };
    switch (builtin->Type()) {
        case core::Function::kAny:
            if (builtin->Parameters()[0]->Type()->Is<core::type::Bool>()) {
                // any(v: bool) just resolves to v.
                return get_arg_as_value_id(0);
            }
            op = spv::Op::OpAny;
            break;
        case core::Function::kAll:
            if (builtin->Parameters()[0]->Type()->Is<core::type::Bool>()) {
                // all(v: bool) just resolves to v.
                return get_arg_as_value_id(0);
            }
            op = spv::Op::OpAll;
            break;
        case core::Function::kArrayLength: {
            auto* arg_expr = call->Arguments()[0]->Declaration();
            auto* address_of = arg_expr->As<ast::UnaryOpExpression>();
            if (!address_of || address_of->op != core::UnaryOp::kAddressOf) {
                // Report the type of the argument expression itself: `address_of`
                // may be null here and must not be dereferenced.
                TINT_ICE() << "arrayLength() expected pointer to member access, got " +
                                  std::string(arg_expr->TypeInfo().name);
                return 0;
            }
            auto* array_expr = address_of->expr;
            auto* accessor = array_expr->As<ast::MemberAccessorExpression>();
            if (!accessor) {
                TINT_ICE() << "arrayLength() expected pointer to member access, got pointer to " +
                                  std::string(array_expr->TypeInfo().name);
                return 0;
            }
            auto struct_id = GenerateExpression(accessor->object);
            if (struct_id == 0) {
                return 0;
            }
            params.push_back(Operand(struct_id));
            auto* type = TypeOf(accessor->object)->UnwrapRef();
            if (!type->Is<core::type::Struct>()) {
                TINT_ICE() << "invalid type (" + type->FriendlyName() +
                                  ") for runtime array length";
                return 0;
            }
            // Runtime array must be the last member in the structure
            params.push_back(
                Operand(uint32_t(type->As<core::type::Struct>()->Members().Length() - 1)));
            if (!push_function_inst(spv::Op::OpArrayLength, params)) {
                return 0;
            }
            return result_id;
        }
        case core::Function::kCountOneBits:
            op = spv::Op::OpBitCount;
            break;
        case core::Function::kDot: {
            op = spv::Op::OpDot;
            auto* vec_ty = builtin->Parameters()[0]->Type()->As<core::type::Vector>();
            if (vec_ty->type()->is_integer_scalar()) {
                // TODO(crbug.com/tint/1267): OpDot requires floating-point types, but
                // WGSL also supports integer types. SPV_KHR_integer_dot_product adds
                // support for integer vectors. Use it if it is available.
                // Until then, expand to per-element OpCompositeExtract / OpIMul and
                // accumulate with OpIAdd; the final add writes directly to result_id.
                auto el_ty_id = GenerateTypeIfNeeded(vec_ty->type());
                if (el_ty_id == 0) {
                    return 0;
                }
                auto el_ty = Operand(el_ty_id);
                auto vec_a = Operand(get_arg_as_value_id(0));
                auto vec_b = Operand(get_arg_as_value_id(1));
                if (std::get<uint32_t>(vec_a) == 0 || std::get<uint32_t>(vec_b) == 0) {
                    return 0;
                }
                auto sum = Operand(0u);
                for (uint32_t i = 0; i < vec_ty->Width(); i++) {
                    auto a = result_op();
                    auto b = result_op();
                    auto mul = result_op();
                    if (!push_function_inst(spv::Op::OpCompositeExtract,
                                            {el_ty, a, vec_a, Operand(i)}) ||
                        !push_function_inst(spv::Op::OpCompositeExtract,
                                            {el_ty, b, vec_b, Operand(i)}) ||
                        !push_function_inst(spv::Op::OpIMul, {el_ty, mul, a, b})) {
                        return 0;
                    }
                    if (i == 0) {
                        sum = mul;
                    } else {
                        auto prev_sum = sum;
                        auto is_last_el = i == (vec_ty->Width() - 1);
                        sum = is_last_el ? Operand(result_id) : result_op();
                        if (!push_function_inst(spv::Op::OpIAdd, {el_ty, sum, prev_sum, mul})) {
                            return 0;
                        }
                    }
                }
                return result_id;
            }
            break;
        }
        case core::Function::kDpdx:
            op = spv::Op::OpDPdx;
            break;
        case core::Function::kDpdxCoarse:
            op = spv::Op::OpDPdxCoarse;
            break;
        case core::Function::kDpdxFine:
            op = spv::Op::OpDPdxFine;
            break;
        case core::Function::kDpdy:
            op = spv::Op::OpDPdy;
            break;
        case core::Function::kDpdyCoarse:
            op = spv::Op::OpDPdyCoarse;
            break;
        case core::Function::kDpdyFine:
            op = spv::Op::OpDPdyFine;
            break;
        case core::Function::kExtractBits:
            op = builtin->Parameters()[0]->Type()->is_unsigned_integer_scalar_or_vector()
                     ? spv::Op::OpBitFieldUExtract
                     : spv::Op::OpBitFieldSExtract;
            break;
        case core::Function::kFwidth:
            op = spv::Op::OpFwidth;
            break;
        case core::Function::kFwidthCoarse:
            op = spv::Op::OpFwidthCoarse;
            break;
        case core::Function::kFwidthFine:
            op = spv::Op::OpFwidthFine;
            break;
        case core::Function::kInsertBits:
            op = spv::Op::OpBitFieldInsert;
            break;
        case core::Function::kMix: {
            auto std450 = Operand(GetGLSLstd450Import());
            auto a_id = get_arg_as_value_id(0);
            auto b_id = get_arg_as_value_id(1);
            auto f_id = get_arg_as_value_id(2);
            if (!a_id || !b_id || !f_id) {
                return 0;
            }
            // If the interpolant is scalar but the objects are vectors, we need to
            // splat the interpolant into a vector of the same size.
            auto* result_vector_type = builtin->ReturnType()->As<core::type::Vector>();
            if (result_vector_type && builtin->Parameters()[2]->Type()->Is<core::type::Scalar>()) {
                f_id = GenerateSplat(f_id, builtin->Parameters()[0]->Type());
                if (f_id == 0) {
                    return 0;
                }
            }
            if (!push_function_inst(spv::Op::OpExtInst, {Operand(result_type_id), result, std450,
                                                         U32Operand(GLSLstd450FMix), Operand(a_id),
                                                         Operand(b_id), Operand(f_id)})) {
                return 0;
            }
            return result_id;
        }
        case core::Function::kQuantizeToF16:
            op = spv::Op::OpQuantizeToF16;
            break;
        case core::Function::kReverseBits:
            op = spv::Op::OpBitReverse;
            break;
        case core::Function::kSelect: {
            // Note: Argument order is different in WGSL and SPIR-V
            auto cond_id = get_arg_as_value_id(2);
            auto true_id = get_arg_as_value_id(1);
            auto false_id = get_arg_as_value_id(0);
            if (!cond_id || !true_id || !false_id) {
                return 0;
            }
            // If the condition is scalar but the objects are vectors, we need to
            // splat the condition into a vector of the same size.
            // TODO(jrprice): If we're targeting SPIR-V 1.4, we don't need to do this.
            auto* result_vector_type = builtin->ReturnType()->As<core::type::Vector>();
            if (result_vector_type && builtin->Parameters()[2]->Type()->Is<core::type::Scalar>()) {
                auto* bool_vec_ty = builder_.create<core::type::Vector>(
                    builder_.create<core::type::Bool>(), result_vector_type->Width());
                if (!GenerateTypeIfNeeded(bool_vec_ty)) {
                    return 0;
                }
                cond_id = GenerateSplat(cond_id, bool_vec_ty);
                if (cond_id == 0) {
                    return 0;
                }
            }
            if (!push_function_inst(spv::Op::OpSelect,
                                    {Operand(result_type_id), result, Operand(cond_id),
                                     Operand(true_id), Operand(false_id)})) {
                return 0;
            }
            return result_id;
        }
        case core::Function::kTranspose:
            op = spv::Op::OpTranspose;
            break;
        case core::Function::kAbs:
            if (builtin->ReturnType()->is_unsigned_integer_scalar_or_vector()) {
                // abs() only operates on *signed* integers.
                // This is a no-op for unsigned integers.
                return get_arg_as_value_id(0);
            }
            if (builtin->ReturnType()->is_float_scalar_or_vector()) {
                glsl_std450(GLSLstd450FAbs);
            } else {
                glsl_std450(GLSLstd450SAbs);
            }
            break;
        case core::Function::kDot4I8Packed: {
            auto first_param_id = get_arg_as_value_id(0);
            auto second_param_id = get_arg_as_value_id(1);
            // Do not emit an instruction referencing an invalid (0) operand id.
            if (first_param_id == 0 || second_param_id == 0) {
                return 0;
            }
            if (!push_function_inst(spv::Op::OpSDotKHR,
                                    {Operand(result_type_id), result, Operand(first_param_id),
                                     Operand(second_param_id),
                                     Operand(static_cast<uint32_t>(
                                         spv::PackedVectorFormat::PackedVectorFormat4x8BitKHR))})) {
                return 0;
            }
            return result_id;
        }
        case core::Function::kDot4U8Packed: {
            auto first_param_id = get_arg_as_value_id(0);
            auto second_param_id = get_arg_as_value_id(1);
            // Do not emit an instruction referencing an invalid (0) operand id.
            if (first_param_id == 0 || second_param_id == 0) {
                return 0;
            }
            if (!push_function_inst(spv::Op::OpUDotKHR,
                                    {Operand(result_type_id), result, Operand(first_param_id),
                                     Operand(second_param_id),
                                     Operand(static_cast<uint32_t>(
                                         spv::PackedVectorFormat::PackedVectorFormat4x8BitKHR))})) {
                return 0;
            }
            return result_id;
        }
        case core::Function::kSubgroupBallot: {
            module_.PushCapability(SpvCapabilityGroupNonUniformBallot);
            // Operands: execution scope (Subgroup) and the ballot predicate (true).
            auto scope_id = GenerateConstantIfNeeded(ScalarConstant::U32(SpvScopeSubgroup));
            auto predicate_id = GenerateConstantIfNeeded(ScalarConstant::Bool(true));
            if (scope_id == 0 || predicate_id == 0) {
                return 0;
            }
            if (!push_function_inst(spv::Op::OpGroupNonUniformBallot,
                                    {Operand(result_type_id), result, Operand(scope_id),
                                     Operand(predicate_id)})) {
                return 0;
            }
            return result_id;
        }
        default: {
            auto inst_id = builtin_to_glsl_method(builtin);
            if (inst_id == 0) {
                TINT_ICE() << "unknown method " + std::string(builtin->str());
                return 0;
            }
            glsl_std450(inst_id);
            break;
        }
    }
    if (op == spv::Op::OpNop) {
        TINT_ICE() << "unable to determine operator for: " + std::string(builtin->str());
        return 0;
    }
    // Common path: append every argument (loaded where needed) as an operand.
    for (size_t i = 0; i < call->Arguments().Length(); i++) {
        if (auto val_id = get_arg_as_value_id(i)) {
            params.emplace_back(Operand(val_id));
        } else {
            return 0;
        }
    }
    if (!push_function_inst(op, params)) {
        return 0;
    }
    return result_id;
}
bool Builder::GenerateTextureBuiltin(const sem::Call* call,
const sem::Builtin* builtin,
Operand result_type,
Operand result_id) {
using Usage = core::ParameterUsage;
auto& signature = builtin->Signature();
auto& arguments = call->Arguments();
// Generates the given expression, returning the operand ID
auto gen = [&](const sem::ValueExpression* expr) { return Operand(GenerateExpression(expr)); };
// Returns the argument with the given usage
auto arg = [&](Usage usage) {
int idx = signature.IndexOf(usage);
return (idx >= 0) ? arguments[static_cast<size_t>(idx)] : nullptr;
};
// Generates the argument with the given usage, returning the operand ID
auto gen_arg = [&](Usage usage) {
auto* argument = arg(usage);
if (TINT_UNLIKELY(!argument)) {
TINT_ICE() << "missing argument " << static_cast<int>(usage);
}
return gen(argument);
};
auto* texture = arg(Usage::kTexture);