blob: 621838ac151366d3027e665953517a242b86f97c [file] [log] [blame] [edit]
// Copyright 2020 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/writer/spirv/builder.h"
#include <algorithm>
#include <limits>
#include <utility>
#include "spirv/unified1/GLSL.std.450.h"
#include "src/ast/call_statement.h"
#include "src/ast/fallthrough_statement.h"
#include "src/ast/internal_decoration.h"
#include "src/ast/override_decoration.h"
#include "src/sem/array.h"
#include "src/sem/atomic_type.h"
#include "src/sem/call.h"
#include "src/sem/depth_multisampled_texture_type.h"
#include "src/sem/depth_texture_type.h"
#include "src/sem/function.h"
#include "src/sem/intrinsic.h"
#include "src/sem/member_accessor_expression.h"
#include "src/sem/multisampled_texture_type.h"
#include "src/sem/reference_type.h"
#include "src/sem/sampled_texture_type.h"
#include "src/sem/struct.h"
#include "src/sem/variable.h"
#include "src/sem/vector_type.h"
#include "src/transform/add_empty_entry_point.h"
#include "src/transform/canonicalize_entry_point_io.h"
#include "src/transform/external_texture_transform.h"
#include "src/transform/fold_constants.h"
#include "src/transform/for_loop_to_loop.h"
#include "src/transform/inline_pointer_lets.h"
#include "src/transform/manager.h"
#include "src/transform/simplify.h"
#include "src/transform/vectorize_scalar_matrix_constructors.h"
#include "src/transform/zero_init_workgroup_memory.h"
#include "src/utils/defer.h"
#include "src/utils/get_or_create.h"
#include "src/writer/append_vector.h"
namespace tint {
namespace writer {
namespace spirv {
namespace {
// Shorthand for sem::IntrinsicType, used by intrinsic_to_glsl_method().
using IntrinsicType = sem::IntrinsicType;
// Name of the GLSL.std.450 extended instruction set. NOTE(review): presumably
// passed to OpExtInstImport elsewhere in this file — confirm.
const char kGLSLstd450[] = "GLSL.std.450";
/// Sums the word lengths of every instruction in `instructions`.
/// @param instructions the instruction list to measure
/// @returns the total number of words occupied by the list
uint32_t size_of(const InstructionList& instructions) {
  uint32_t total_words = 0;
  for (const auto& instruction : instructions) {
    total_words += instruction.word_length();
  }
  return total_words;
}
/// Maps a WGSL pipeline stage to the corresponding SPIR-V execution model.
/// @param stage the pipeline stage
/// @returns the SPIR-V execution model, or SpvExecutionModelMax for
/// ast::PipelineStage::kNone (an "unknown stage" sentinel)
uint32_t pipeline_stage_to_execution_model(ast::PipelineStage stage) {
  switch (stage) {
    case ast::PipelineStage::kFragment:
      return SpvExecutionModelFragment;
    case ast::PipelineStage::kVertex:
      return SpvExecutionModelVertex;
    case ast::PipelineStage::kCompute:
      return SpvExecutionModelGLCompute;
    case ast::PipelineStage::kNone:
      return SpvExecutionModelMax;
  }
  // Only reached for out-of-range enum values; defaults to Vertex.
  return SpvExecutionModelVertex;
}
/// @param stmts the block to inspect
/// @returns true if `stmts` is non-empty and its final statement is a
/// `fallthrough` statement
bool LastIsFallthrough(const ast::BlockStatement* stmts) {
  if (stmts->Empty()) {
    return false;
  }
  return stmts->Last()->Is<ast::FallthroughStatement>();
}
// Returns true if the last statement of `stmts` causes a SPIR-V block
// terminator to be emitted: break, continue and fallthrough (which emit an
// OpBranch), discard, or return. If the last statement is itself a nested
// block, that block is inspected recursively.
bool LastIsTerminator(const ast::BlockStatement* stmts) {
  auto* last = stmts->Last();
  if (IsAnyOf<ast::BreakStatement, ast::ContinueStatement,
              ast::DiscardStatement, ast::ReturnStatement,
              ast::FallthroughStatement>(last)) {
    return true;
  }
  auto* nested = As<ast::BlockStatement>(last);
  return nested != nullptr && LastIsTerminator(nested);
}
/// Strips any number of array layers from `type`. If the innermost element
/// type is a matrix, that matrix is returned.
/// @param type the given type, which must not be null
/// @returns the nested matrix type, or nullptr if none
const sem::Matrix* GetNestedMatrixType(const sem::Type* type) {
  const sem::Type* element = type;
  for (auto* arr = element->As<sem::Array>(); arr != nullptr;
       arr = element->As<sem::Array>()) {
    element = arr->ElemType();
  }
  return element->As<sem::Matrix>();
}
// Returns the GLSL.std.450 extended-instruction number that implements
// `intrinsic`. Returns 0 if the intrinsic has no GLSL.std.450 mapping here.
// For clamp/max/min the float (N*), unsigned (U*) or signed (S*) variant is
// chosen from the intrinsic's return type.
uint32_t intrinsic_to_glsl_method(const sem::Intrinsic* intrinsic) {
switch (intrinsic->Type()) {
case IntrinsicType::kAcos:
return GLSLstd450Acos;
case IntrinsicType::kAsin:
return GLSLstd450Asin;
case IntrinsicType::kAtan:
return GLSLstd450Atan;
case IntrinsicType::kAtan2:
return GLSLstd450Atan2;
case IntrinsicType::kCeil:
return GLSLstd450Ceil;
case IntrinsicType::kClamp:
// Variant selected by element kind of the return type.
if (intrinsic->ReturnType()->is_float_scalar_or_vector()) {
return GLSLstd450NClamp;
} else if (intrinsic->ReturnType()->is_unsigned_scalar_or_vector()) {
return GLSLstd450UClamp;
} else {
return GLSLstd450SClamp;
}
case IntrinsicType::kCos:
return GLSLstd450Cos;
case IntrinsicType::kCosh:
return GLSLstd450Cosh;
case IntrinsicType::kCross:
return GLSLstd450Cross;
case IntrinsicType::kDeterminant:
return GLSLstd450Determinant;
case IntrinsicType::kDistance:
return GLSLstd450Distance;
case IntrinsicType::kExp:
return GLSLstd450Exp;
case IntrinsicType::kExp2:
return GLSLstd450Exp2;
case IntrinsicType::kFaceForward:
return GLSLstd450FaceForward;
case IntrinsicType::kFloor:
return GLSLstd450Floor;
case IntrinsicType::kFma:
return GLSLstd450Fma;
case IntrinsicType::kFract:
return GLSLstd450Fract;
case IntrinsicType::kFrexp:
// The *Struct form returns both results in a single struct value.
return GLSLstd450FrexpStruct;
case IntrinsicType::kInverseSqrt:
return GLSLstd450InverseSqrt;
case IntrinsicType::kLdexp:
return GLSLstd450Ldexp;
case IntrinsicType::kLength:
return GLSLstd450Length;
case IntrinsicType::kLog:
return GLSLstd450Log;
case IntrinsicType::kLog2:
return GLSLstd450Log2;
case IntrinsicType::kMax:
// Variant selected by element kind of the return type.
if (intrinsic->ReturnType()->is_float_scalar_or_vector()) {
return GLSLstd450NMax;
} else if (intrinsic->ReturnType()->is_unsigned_scalar_or_vector()) {
return GLSLstd450UMax;
} else {
return GLSLstd450SMax;
}
case IntrinsicType::kMin:
// Variant selected by element kind of the return type.
if (intrinsic->ReturnType()->is_float_scalar_or_vector()) {
return GLSLstd450NMin;
} else if (intrinsic->ReturnType()->is_unsigned_scalar_or_vector()) {
return GLSLstd450UMin;
} else {
return GLSLstd450SMin;
}
case IntrinsicType::kMix:
return GLSLstd450FMix;
case IntrinsicType::kModf:
// The *Struct form returns both results in a single struct value.
return GLSLstd450ModfStruct;
case IntrinsicType::kNormalize:
return GLSLstd450Normalize;
case IntrinsicType::kPack4x8snorm:
return GLSLstd450PackSnorm4x8;
case IntrinsicType::kPack4x8unorm:
return GLSLstd450PackUnorm4x8;
case IntrinsicType::kPack2x16snorm:
return GLSLstd450PackSnorm2x16;
case IntrinsicType::kPack2x16unorm:
return GLSLstd450PackUnorm2x16;
case IntrinsicType::kPack2x16float:
return GLSLstd450PackHalf2x16;
case IntrinsicType::kPow:
return GLSLstd450Pow;
case IntrinsicType::kReflect:
return GLSLstd450Reflect;
case IntrinsicType::kRefract:
return GLSLstd450Refract;
case IntrinsicType::kRound:
// round() maps to RoundEven (ties round to the nearest even value).
return GLSLstd450RoundEven;
case IntrinsicType::kSign:
// Only the float sign variant is used here.
return GLSLstd450FSign;
case IntrinsicType::kSin:
return GLSLstd450Sin;
case IntrinsicType::kSinh:
return GLSLstd450Sinh;
case IntrinsicType::kSmoothStep:
return GLSLstd450SmoothStep;
case IntrinsicType::kSqrt:
return GLSLstd450Sqrt;
case IntrinsicType::kStep:
return GLSLstd450Step;
case IntrinsicType::kTan:
return GLSLstd450Tan;
case IntrinsicType::kTanh:
return GLSLstd450Tanh;
case IntrinsicType::kTrunc:
return GLSLstd450Trunc;
case IntrinsicType::kUnpack4x8snorm:
return GLSLstd450UnpackSnorm4x8;
case IntrinsicType::kUnpack4x8unorm:
return GLSLstd450UnpackUnorm4x8;
case IntrinsicType::kUnpack2x16snorm:
return GLSLstd450UnpackSnorm2x16;
case IntrinsicType::kUnpack2x16unorm:
return GLSLstd450UnpackUnorm2x16;
case IntrinsicType::kUnpack2x16float:
return GLSLstd450UnpackHalf2x16;
default:
// Not a GLSL.std.450 intrinsic; handled elsewhere (or unsupported).
break;
}
// 0 signals "no GLSL.std.450 mapping".
return 0;
}
/// @param ty the type to inspect
/// @return the vector element type if ty is a vector, otherwise return ty.
const sem::Type* ElementTypeOf(const sem::Type* ty) {
  auto* vec = ty->As<sem::Vector>();
  return vec ? vec->type() : ty;
}
} // namespace
SanitizedResult Sanitize(const Program* in,
                         bool emit_vertex_point_size,
                         bool disable_workgroup_init) {
  // Transforms are registered in a deliberate order; see the notes below.
  transform::Manager manager;
  if (!disable_workgroup_init) {
    manager.Add<transform::ZeroInitWorkgroupMemory>();
  }
  // InlinePointerLets and Simplify are both required for arrayLength().
  manager.Add<transform::InlinePointerLets>();
  manager.Add<transform::Simplify>();
  manager.Add<transform::FoldConstants>();
  manager.Add<transform::ExternalTextureTransform>();
  manager.Add<transform::VectorizeScalarMatrixConstructors>();
  // ForLoopToLoop must come after ZeroInitWorkgroupMemory.
  manager.Add<transform::ForLoopToLoop>();
  manager.Add<transform::CanonicalizeEntryPointIO>();
  manager.Add<transform::AddEmptyEntryPoint>();

  // Configure entry-point IO canonicalization for SPIR-V-style shaders.
  transform::DataMap inputs;
  inputs.Add<transform::CanonicalizeEntryPointIO::Config>(
      transform::CanonicalizeEntryPointIO::Config(
          transform::CanonicalizeEntryPointIO::ShaderStyle::kSpirv, 0xFFFFFFFF,
          emit_vertex_point_size));

  SanitizedResult sanitized;
  sanitized.program = std::move(manager.Run(in, inputs).program);
  return sanitized;
}
Builder::AccessorInfo::AccessorInfo() : source_id(0), source_type(nullptr) {}
Builder::AccessorInfo::~AccessorInfo() {}
// Constructs a SPIR-V builder for `program`. The program is wrapped in a
// ProgramBuilder (builder_) so the writer can query its AST, semantic info and
// symbols. The variable scope stack starts empty.
Builder::Builder(const Program* program)
: builder_(ProgramBuilder::Wrap(program)), scope_stack_({}) {}
Builder::~Builder() = default;
bool Builder::Build() {
  // Every module declares the Shader capability and the Logical / GLSL450
  // memory model before any globals or functions are emitted.
  push_capability(SpvCapabilityShader);
  push_memory_model(spv::Op::OpMemoryModel,
                    {Operand::Int(SpvAddressingModelLogical),
                     Operand::Int(SpvMemoryModelGLSL450)});

  auto& ast_module = builder_.AST();
  // Globals first, so functions can reference their IDs.
  for (auto* global : ast_module.GlobalVariables()) {
    if (!GenerateGlobalVariable(global)) {
      return false;
    }
  }
  for (auto* function : ast_module.Functions()) {
    if (!GenerateFunction(function)) {
      return false;
    }
  }
  return true;
}
Operand Builder::result_op() {
  // Reserve a fresh result ID and wrap it as an integer operand.
  const auto id = next_id();
  return Operand::Int(id);
}
uint32_t Builder::total_size() const {
// The 5 covers the magic, version, generator, id bound and reserved.
uint32_t size = 5;
size += size_of(capabilities_);
size += size_of(extensions_);
size += size_of(ext_imports_);
size += size_of(memory_model_);
size += size_of(entry_points_);
size += size_of(execution_modes_);
size += size_of(debug_);
size += size_of(annotations_);
size += size_of(types_);
for (const auto& func : functions_) {
size += func.word_length();
}
return size;
}
void Builder::iterate(std::function<void(const Instruction&)> cb) const {
for (const auto& inst : capabilities_) {
cb(inst);
}
for (const auto& inst : extensions_) {
cb(inst);
}
for (const auto& inst : ext_imports_) {
cb(inst);
}
for (const auto& inst : memory_model_) {
cb(inst);
}
for (const auto& inst : entry_points_) {
cb(inst);
}
for (const auto& inst : execution_modes_) {
cb(inst);
}
for (const auto& inst : debug_) {
cb(inst);
}
for (const auto& inst : annotations_) {
cb(inst);
}
for (const auto& inst : types_) {
cb(inst);
}
for (const auto& func : functions_) {
func.iterate(cb);
}
}
void Builder::push_capability(uint32_t cap) {
  // Each capability is declared at most once.
  if (capability_set_.count(cap) != 0) {
    return;
  }
  capability_set_.insert(cap);
  capabilities_.push_back(
      Instruction{spv::Op::OpCapability, {Operand::Int(cap)}});
}
bool Builder::GenerateLabel(uint32_t id) {
  const bool emitted =
      push_function_inst(spv::Op::OpLabel, {Operand::Int(id)});
  if (emitted) {
    // Remember which block subsequent instructions belong to.
    current_label_id_ = id;
  }
  return emitted;
}
bool Builder::GenerateAssignStatement(const ast::AssignmentStatement* assign) {
  // Phony assignment: evaluate the RHS for its side effects only.
  if (assign->lhs->Is<ast::PhonyExpression>()) {
    return GenerateExpression(assign->rhs) != 0;
  }

  const uint32_t lhs_id = GenerateExpression(assign->lhs);
  if (lhs_id == 0) {
    return false;
  }
  uint32_t rhs_id = GenerateExpression(assign->rhs);
  if (rhs_id == 0) {
    return false;
  }
  // A reference RHS must be loaded before its value can be stored.
  rhs_id = GenerateLoadIfNeeded(TypeOf(assign->rhs), rhs_id);
  return GenerateStore(lhs_id, rhs_id);
}
bool Builder::GenerateBreakStatement(const ast::BreakStatement*) {
  // A break branches to the innermost enclosing merge block.
  if (merge_stack_.empty()) {
    error_ = "Attempted to break without a merge block";
    return false;
  }
  const uint32_t merge_id = merge_stack_.back();
  return push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_id)});
}
bool Builder::GenerateContinueStatement(const ast::ContinueStatement*) {
  // A continue branches to the innermost enclosing continue target.
  if (continue_stack_.empty()) {
    error_ = "Attempted to continue without a continue block";
    return false;
  }
  const uint32_t continue_id = continue_stack_.back();
  return push_function_inst(spv::Op::OpBranch, {Operand::Int(continue_id)});
}
// Lowers a WGSL `discard` to OpKill.
// TODO(dsinclair): This is generating an OpKill but the semantics of kill
// haven't been defined for WGSL yet. So, this may need to change.
// https://github.com/gpuweb/gpuweb/issues/676
bool Builder::GenerateDiscardStatement(const ast::DiscardStatement*) {
  return push_function_inst(spv::Op::OpKill, {});
}
bool Builder::GenerateEntryPoint(const ast::Function* func, uint32_t id) {
  const auto model = pipeline_stage_to_execution_model(func->PipelineStage());
  if (model == SpvExecutionModelMax) {
    error_ = "Unknown pipeline stage provided";
    return false;
  }

  // OpEntryPoint <execution model> <function id> <name> <interface ids...>
  OperandList operands = {
      Operand::Int(model), Operand::Int(id),
      Operand::String(builder_.Symbols().NameFor(func->symbol))};

  auto* func_sem = builder_.Sem().Get(func);
  for (const auto* global : func_sem->TransitivelyReferencedGlobals()) {
    // SPIR-V 1.3 interface lists contain only Input/Output variables. For
    // SPIR-V 1.4 or later, every referenced global would be listed.
    const auto storage = global->StorageClass();
    const bool is_interface = storage == ast::StorageClass::kInput ||
                              storage == ast::StorageClass::kOutput;
    if (!is_interface) {
      continue;
    }
    const auto symbol = global->Declaration()->symbol;
    const uint32_t var_id = scope_stack_.Get(symbol);
    if (var_id == 0) {
      error_ = "unable to find ID for global variable: " +
               builder_.Symbols().NameFor(symbol);
      return false;
    }
    operands.push_back(Operand::Int(var_id));
  }

  push_entry_point(spv::Op::OpEntryPoint, operands);
  return true;
}
/// Emits the execution-mode information for entry point `func`, whose SPIR-V
/// function ID is `id`.
///
/// - Fragment shaders get OriginUpperLeft.
/// - Compute shaders get either OpExecutionMode LocalSize (constant workgroup
///   size) or a WorkgroupSize builtin OpSpecConstantComposite (when any
///   dimension uses a pipeline-overridable constant).
/// - Referencing the frag_depth builtin adds DepthReplacing.
/// @returns false if a required type or constant could not be generated
bool Builder::GenerateExecutionModes(const ast::Function* func, uint32_t id) {
  auto* func_sem = builder_.Sem().Get(func);

  if (func->PipelineStage() == ast::PipelineStage::kFragment) {
    // WGSL fragment shader origin is upper left.
    push_execution_mode(
        spv::Op::OpExecutionMode,
        {Operand::Int(id), Operand::Int(SpvExecutionModeOriginUpperLeft)});
  } else if (func->PipelineStage() == ast::PipelineStage::kCompute) {
    auto& wgsize = func_sem->WorkgroupSize();

    // Check if the workgroup_size uses pipeline-overridable constants.
    if (wgsize[0].overridable_const || wgsize[1].overridable_const ||
        wgsize[2].overridable_const) {
      if (has_overridable_workgroup_size_) {
        // Only one stage can have a pipeline-overridable workgroup size.
        // TODO(crbug.com/tint/810): Use LocalSizeId to handle this scenario.
        TINT_ICE(Writer, builder_.Diagnostics())
            << "multiple stages using pipeline-overridable workgroup sizes";
      }
      has_overridable_workgroup_size_ = true;

      auto* vec3_u32 =
          builder_.create<sem::Vector>(builder_.create<sem::U32>(), 3);
      uint32_t vec3_u32_type_id = GenerateTypeIfNeeded(vec3_u32);
      if (vec3_u32_type_id == 0) {
        return false;
      }

      OperandList wgsize_ops;
      auto wgsize_result = result_op();
      wgsize_ops.push_back(Operand::Int(vec3_u32_type_id));
      wgsize_ops.push_back(wgsize_result);

      // Generate OpConstant instructions for each dimension.
      for (int i = 0; i < 3; i++) {
        auto constant = ScalarConstant::U32(wgsize[i].value);
        if (wgsize[i].overridable_const) {
          // Make the constant specializable.
          auto* sem_const = builder_.Sem().Get<sem::GlobalVariable>(
              wgsize[i].overridable_const);
          // Guard against a missing semantic node as well as a non-override
          // constant; either would otherwise be dereferenced below.
          if (!sem_const || !sem_const->IsPipelineConstant()) {
            TINT_ICE(Writer, builder_.Diagnostics())
                << "expected a pipeline-overridable constant";
            return false;
          }
          constant.is_spec_op = true;
          constant.constant_id = sem_const->ConstantId();
        }
        auto dim_id = GenerateConstantIfNeeded(constant);
        if (dim_id == 0) {
          // Don't emit a composite with an invalid (0) constituent.
          return false;
        }
        wgsize_ops.push_back(Operand::Int(dim_id));
      }

      // Generate the WorkgroupSize builtin.
      push_type(spv::Op::OpSpecConstantComposite, wgsize_ops);
      push_annot(spv::Op::OpDecorate,
                 {wgsize_result, Operand::Int(SpvDecorationBuiltIn),
                  Operand::Int(SpvBuiltInWorkgroupSize)});
    } else {
      // Not overridable, so just use OpExecutionMode LocalSize.
      uint32_t x = wgsize[0].value;
      uint32_t y = wgsize[1].value;
      uint32_t z = wgsize[2].value;
      push_execution_mode(
          spv::Op::OpExecutionMode,
          {Operand::Int(id), Operand::Int(SpvExecutionModeLocalSize),
           Operand::Int(x), Operand::Int(y), Operand::Int(z)});
    }
  }

  for (const auto& builtin :
       func_sem->TransitivelyReferencedBuiltinVariables()) {
    if (builtin.second->builtin == ast::Builtin::kFragDepth) {
      push_execution_mode(
          spv::Op::OpExecutionMode,
          {Operand::Int(id), Operand::Int(SpvExecutionModeDepthReplacing)});
    }
  }

  return true;
}
uint32_t Builder::GenerateExpression(const ast::Expression* expr) {
  // Dispatch on the concrete expression class. Array and member accessors
  // share a single chain-building path.
  if (auto* array_access = expr->As<ast::ArrayAccessorExpression>()) {
    return GenerateAccessorExpression(array_access);
  }
  if (auto* binary = expr->As<ast::BinaryExpression>()) {
    return GenerateBinaryExpression(binary);
  }
  if (auto* bitcast = expr->As<ast::BitcastExpression>()) {
    return GenerateBitcastExpression(bitcast);
  }
  if (auto* call = expr->As<ast::CallExpression>()) {
    return GenerateCallExpression(call);
  }
  if (auto* ctor = expr->As<ast::ConstructorExpression>()) {
    // Function-scope constructors have no associated variable.
    return GenerateConstructorExpression(nullptr, ctor);
  }
  if (auto* ident = expr->As<ast::IdentifierExpression>()) {
    return GenerateIdentifierExpression(ident);
  }
  if (auto* member_access = expr->As<ast::MemberAccessorExpression>()) {
    return GenerateAccessorExpression(member_access);
  }
  if (auto* unary = expr->As<ast::UnaryOpExpression>()) {
    return GenerateUnaryOpExpression(unary);
  }

  error_ = "unknown expression type: " + std::string(expr->TypeInfo().name);
  return 0;
}
// Emits the SPIR-V for one WGSL function: an OpName debug entry, the
// OpFunction definition, one OpFunctionParameter per parameter, and the body
// statements. Entry points additionally get their OpEntryPoint and
// execution-mode instructions. On success, the function's result ID is
// recorded in func_symbol_to_id_. Returns false if any step fails.
bool Builder::GenerateFunction(const ast::Function* func_ast) {
auto* func = builder_.Sem().Get(func_ast);
uint32_t func_type_id = GenerateFunctionTypeIfNeeded(func);
if (func_type_id == 0) {
return false;
}
// NOTE(review): result IDs are allocated via result_op() in a fixed order
// (function, then each parameter), so reordering these calls changes the
// emitted ID numbering.
auto func_op = result_op();
auto func_id = func_op.to_i();
push_debug(spv::Op::OpName,
{Operand::Int(func_id),
Operand::String(builder_.Symbols().NameFor(func_ast->symbol))});
auto ret_id = GenerateTypeIfNeeded(func->ReturnType());
if (ret_id == 0) {
return false;
}
// Parameters are registered in a function-level scope that is popped on
// every exit path.
scope_stack_.Push();
TINT_DEFER(scope_stack_.Pop());
// OpFunction <return type> <result id> <function control> <function type>
auto definition_inst = Instruction{
spv::Op::OpFunction,
{Operand::Int(ret_id), func_op, Operand::Int(SpvFunctionControlMaskNone),
Operand::Int(func_type_id)}};
InstructionList params;
for (auto* param : func->Parameters()) {
auto param_op = result_op();
auto param_id = param_op.to_i();
auto param_type_id = GenerateTypeIfNeeded(param->Type());
if (param_type_id == 0) {
return false;
}
push_debug(spv::Op::OpName, {Operand::Int(param_id),
Operand::String(builder_.Symbols().NameFor(
param->Declaration()->symbol))});
params.push_back(Instruction{spv::Op::OpFunctionParameter,
{Operand::Int(param_type_id), param_op}});
// Make the parameter visible to identifier lookups in the body.
scope_stack_.Set(param->Declaration()->symbol, param_id);
}
// NOTE(review): the second argument is a freshly allocated ID, presumably
// the label of the function's first block — confirm against Function's ctor.
push_function(Function{definition_inst, result_op(), std::move(params)});
for (auto* stmt : func_ast->body->statements) {
if (!GenerateStatement(stmt)) {
return false;
}
}
// Entry points also need OpEntryPoint and their execution modes.
if (func_ast->IsEntryPoint()) {
if (!GenerateEntryPoint(func_ast, func_id)) {
return false;
}
if (!GenerateExecutionModes(func_ast, func_id)) {
return false;
}
}
func_symbol_to_id_[func_ast->symbol] = func_id;
return true;
}
// Returns the ID of the OpTypeFunction matching `func`'s signature, emitting
// it on first use. IDs are memoized per signature in func_sig_to_id_ via
// utils::GetOrCreate. Returns 0 if the return type or a parameter type cannot
// be generated. NOTE(review): assuming GetOrCreate stores whatever the lambda
// returns, a 0 (failure) result would be cached too — confirm.
uint32_t Builder::GenerateFunctionTypeIfNeeded(const sem::Function* func) {
return utils::GetOrCreate(
func_sig_to_id_, func->Signature(), [&]() -> uint32_t {
auto func_op = result_op();
auto func_type_id = func_op.to_i();
auto ret_id = GenerateTypeIfNeeded(func->ReturnType());
if (ret_id == 0) {
return 0;
}
// OpTypeFunction <result id> <return type> <parameter types...>
OperandList ops = {func_op, Operand::Int(ret_id)};
for (auto* param : func->Parameters()) {
auto param_type_id = GenerateTypeIfNeeded(param->Type());
if (param_type_id == 0) {
return 0;
}
ops.push_back(Operand::Int(param_type_id));
}
push_type(spv::Op::OpTypeFunction, std::move(ops));
return func_type_id;
});
}
/// Emits a function-scope `var` or `let`.
///
/// A `let` (var->is_const) emits no storage. Its name is bound directly to
/// the (loaded) initializer value. A `var` emits an OpVariable in the Function
/// storage class, initialized to the null constant of its store type. The
/// constructor value, if any, is then stored into it.
/// @param var the variable declaration
/// @returns false on failure (error_ may be set)
bool Builder::GenerateFunctionVariable(const ast::Variable* var) {
  uint32_t init_id = 0;
  if (var->constructor) {
    init_id = GenerateExpression(var->constructor);
    if (init_id == 0) {
      return false;
    }
    // If the initializer is a reference, load the value it refers to.
    // (GenerateLoadIfNeeded is a no-op for non-reference types.)
    init_id = GenerateLoadIfNeeded(TypeOf(var->constructor), init_id);
    if (init_id == 0) {
      return false;
    }
  }

  if (var->is_const) {
    if (!var->constructor) {
      error_ = "missing constructor for constant";
      return false;
    }
    scope_stack_.Set(var->symbol, init_id);
    spirv_id_to_variable_[init_id] = var;
    return true;
  }

  auto result = result_op();
  auto var_id = result.to_i();
  auto sc = ast::StorageClass::kFunction;
  auto* type = builder_.Sem().Get(var)->Type();
  auto type_id = GenerateTypeIfNeeded(type);
  if (type_id == 0) {
    return false;
  }

  push_debug(spv::Op::OpName,
             {Operand::Int(var_id),
              Operand::String(builder_.Symbols().NameFor(var->symbol))});

  // TODO(dsinclair) We could detect if the constructor is fully const and emit
  // an initializer value for the variable instead of doing the OpLoad.
  auto null_id = GenerateConstantNullIfNeeded(type->UnwrapRef());
  if (null_id == 0) {
    return false;
  }
  push_function_var({Operand::Int(type_id), result,
                     Operand::Int(ConvertStorageClass(sc)),
                     Operand::Int(null_id)});

  if (var->constructor) {
    if (!GenerateStore(var_id, init_id)) {
      return false;
    }
  }

  scope_stack_.Set(var->symbol, var_id);
  spirv_id_to_variable_[var_id] = var;

  return true;
}
bool Builder::GenerateStore(uint32_t to, uint32_t from) {
return push_function_inst(spv::Op::OpStore,
{Operand::Int(to), Operand::Int(from)});
}
/// Emits a module-scope variable or constant.
///
/// Module-scope `let`/override constants (var->is_const) emit no storage.
/// Their name is bound to the constructor's ID, or, for override constants
/// without an initializer, to a zero-valued scalar literal of the constant's
/// type. Other variables become an OpVariable (pushed with the types/globals)
/// carrying their access, builtin, location, interpolation, invariant,
/// binding and group decorations. Private and Output variables without a
/// constructor are zero-initialized.
/// @param var the variable declaration
/// @returns false on failure (error_ may be set)
bool Builder::GenerateGlobalVariable(const ast::Variable* var) {
  auto* sem = builder_.Sem().Get(var);
  auto* type = sem->Type()->UnwrapRef();

  uint32_t init_id = 0;
  if (var->constructor) {
    if (!var->constructor->Is<ast::ConstructorExpression>()) {
      error_ = "scalar constructor expected";
      return false;
    }

    init_id = GenerateConstructorExpression(
        var, var->constructor->As<ast::ConstructorExpression>());
    if (init_id == 0) {
      return false;
    }
  }

  if (var->is_const) {
    if (!var->constructor) {
      // Constants must have an initializer unless they have an override
      // decoration.
      if (!ast::HasDecoration<ast::OverrideDecoration>(var->decorations)) {
        error_ = "missing constructor for constant";
        return false;
      }

      // SPIR-V requires specialization constants to have initializers.
      if (type->Is<sem::F32>()) {
        ast::FloatLiteral l(ProgramID(), Source{}, 0.0f);
        init_id = GenerateLiteralIfNeeded(var, &l);
      } else if (type->Is<sem::U32>()) {
        ast::UintLiteral l(ProgramID(), Source{}, 0);
        init_id = GenerateLiteralIfNeeded(var, &l);
      } else if (type->Is<sem::I32>()) {
        ast::SintLiteral l(ProgramID(), Source{}, 0);
        init_id = GenerateLiteralIfNeeded(var, &l);
      } else if (type->Is<sem::Bool>()) {
        ast::BoolLiteral l(ProgramID(), Source{}, false);
        init_id = GenerateLiteralIfNeeded(var, &l);
      } else {
        error_ = "invalid type for pipeline constant ID, must be scalar";
        return false;
      }
      if (init_id == 0) {
        return false;
      }
    }
    push_debug(spv::Op::OpName,
               {Operand::Int(init_id),
                Operand::String(builder_.Symbols().NameFor(var->symbol))});

    scope_stack_.Set(var->symbol, init_id);
    spirv_id_to_variable_[init_id] = var;
    return true;
  }

  auto result = result_op();
  auto var_id = result.to_i();

  // Module-scope variables without an explicit storage class are Private.
  auto sc = sem->StorageClass() == ast::StorageClass::kNone
                ? ast::StorageClass::kPrivate
                : sem->StorageClass();

  auto type_id = GenerateTypeIfNeeded(sem->Type());
  if (type_id == 0) {
    return false;
  }

  push_debug(spv::Op::OpName,
             {Operand::Int(var_id),
              Operand::String(builder_.Symbols().NameFor(var->symbol))});

  // OpVariable <pointer type> <result id> <storage class> [<initializer>]
  OperandList ops = {Operand::Int(type_id), result,
                     Operand::Int(ConvertStorageClass(sc))};

  if (var->constructor) {
    ops.push_back(Operand::Int(init_id));
  } else {
    auto* st = type->As<sem::StorageTexture>();
    if (st || type->Is<sem::Struct>()) {
      // type is a sem::Struct or a sem::StorageTexture
      auto access = st ? st->access() : sem->Access();
      switch (access) {
        case ast::Access::kWrite:
          push_annot(
              spv::Op::OpDecorate,
              {Operand::Int(var_id), Operand::Int(SpvDecorationNonReadable)});
          break;
        case ast::Access::kRead:
          push_annot(
              spv::Op::OpDecorate,
              {Operand::Int(var_id), Operand::Int(SpvDecorationNonWritable)});
          break;
        case ast::Access::kUndefined:
        case ast::Access::kReadWrite:
          break;
      }
    }
    if (!type->Is<sem::Sampler>()) {
      // If we don't have a constructor and we're an Output or Private
      // variable, then WGSL requires that we zero-initialize.
      if (sem->StorageClass() == ast::StorageClass::kPrivate ||
          sem->StorageClass() == ast::StorageClass::kOutput) {
        init_id = GenerateConstantNullIfNeeded(type);
        if (init_id == 0) {
          return false;
        }
        ops.push_back(Operand::Int(init_id));
      }
    }
  }

  push_type(spv::Op::OpVariable, std::move(ops));

  for (auto* deco : var->decorations) {
    if (auto* builtin = deco->As<ast::BuiltinDecoration>()) {
      push_annot(spv::Op::OpDecorate,
                 {Operand::Int(var_id), Operand::Int(SpvDecorationBuiltIn),
                  Operand::Int(
                      ConvertBuiltin(builtin->builtin, sem->StorageClass()))});
    } else if (auto* location = deco->As<ast::LocationDecoration>()) {
      push_annot(spv::Op::OpDecorate,
                 {Operand::Int(var_id), Operand::Int(SpvDecorationLocation),
                  Operand::Int(location->value)});
    } else if (auto* interpolate = deco->As<ast::InterpolateDecoration>()) {
      AddInterpolationDecorations(var_id, interpolate->type,
                                  interpolate->sampling);
    } else if (deco->Is<ast::InvariantDecoration>()) {
      push_annot(spv::Op::OpDecorate,
                 {Operand::Int(var_id), Operand::Int(SpvDecorationInvariant)});
    } else if (auto* binding = deco->As<ast::BindingDecoration>()) {
      push_annot(spv::Op::OpDecorate,
                 {Operand::Int(var_id), Operand::Int(SpvDecorationBinding),
                  Operand::Int(binding->value)});
    } else if (auto* group = deco->As<ast::GroupDecoration>()) {
      push_annot(spv::Op::OpDecorate, {Operand::Int(var_id),
                                       Operand::Int(SpvDecorationDescriptorSet),
                                       Operand::Int(group->value)});
    } else if (deco->Is<ast::OverrideDecoration>()) {
      // Spec constants are handled elsewhere
    } else if (!deco->Is<ast::InternalDecoration>()) {
      error_ = "unknown decoration";
      return false;
    }
  }

  scope_stack_.Set(var->symbol, var_id);
  spirv_id_to_variable_[var_id] = var;
  return true;
}
/// Appends one array/vector/matrix index to the accessor chain in `info`.
///
/// - Reference sources accumulate the index for a later OpAccessChain.
/// - Value sources indexed by an integer literal use OpCompositeExtract.
/// - Value vectors indexed dynamically use OpVectorExtractDynamic.
/// @param expr the array accessor expression
/// @param info the accessor chain state, updated in place
/// @returns false on failure
bool Builder::GenerateArrayAccessor(const ast::ArrayAccessorExpression* expr,
                                    AccessorInfo* info) {
  auto idx_id = GenerateExpression(expr->index);
  if (idx_id == 0) {
    return false;
  }
  auto* type = TypeOf(expr->index);
  idx_id = GenerateLoadIfNeeded(type, idx_id);
  if (idx_id == 0) {
    return false;
  }

  // If the source is a reference, we access chain into it.
  // In the future, pointers may support access-chaining.
  // See https://github.com/gpuweb/gpuweb/pull/1580
  if (info->source_type->Is<sem::Reference>()) {
    info->access_chain_indices.push_back(idx_id);
    info->source_type = TypeOf(expr);
    return true;
  }

  auto result_type_id = GenerateTypeIfNeeded(TypeOf(expr));
  if (result_type_id == 0) {
    return false;
  }

  // We don't have a pointer, so we can just directly extract the value.
  auto extract = result_op();
  auto extract_id = extract.to_i();

  // If the index is a literal, we use OpCompositeExtract.
  if (auto* scalar = expr->index->As<ast::ScalarConstructorExpression>()) {
    auto* literal = scalar->literal->As<ast::IntLiteral>();
    if (!literal) {
      TINT_ICE(Writer, builder_.Diagnostics())
          << "bad literal in array accessor";
      return false;
    }

    if (!push_function_inst(spv::Op::OpCompositeExtract,
                            {Operand::Int(result_type_id), extract,
                             Operand::Int(info->source_id),
                             Operand::Int(literal->ValueAsU32())})) {
      return false;
    }

    info->source_id = extract_id;
    info->source_type = TypeOf(expr);

    return true;
  }

  // If the source is a vector, we use OpVectorExtractDynamic.
  if (info->source_type->Is<sem::Vector>()) {
    if (!push_function_inst(
            spv::Op::OpVectorExtractDynamic,
            {Operand::Int(result_type_id), extract,
             Operand::Int(info->source_id), Operand::Int(idx_id)})) {
      return false;
    }

    info->source_id = extract_id;
    info->source_type = TypeOf(expr);

    return true;
  }

  TINT_ICE(Writer, builder_.Diagnostics())
      << "unsupported array accessor expression";
  return false;
}
/// Appends one member accessor (struct member or swizzle) to the accessor
/// chain described by `info`.
///
/// - Reference sources: struct members and single-component swizzles become
///   access-chain indices.
/// - Value sources: the same accessors use OpCompositeExtract.
/// - Multi-component swizzles first flush any pending access chain, then load
///   the vector and emit OpVectorShuffle.
/// @param expr the member accessor expression
/// @param info the accessor chain state, updated in place
/// @returns false on failure
bool Builder::GenerateMemberAccessor(const ast::MemberAccessorExpression* expr,
                                     AccessorInfo* info) {
  auto* expr_sem = builder_.Sem().Get(expr);
  auto* expr_type = expr_sem->Type();

  if (auto* access = expr_sem->As<sem::StructMemberAccess>()) {
    uint32_t idx = access->Member()->Index();

    if (info->source_type->Is<sem::Reference>()) {
      auto idx_id = GenerateConstantIfNeeded(ScalarConstant::U32(idx));
      if (idx_id == 0) {
        return false;
      }
      info->access_chain_indices.push_back(idx_id);
      info->source_type = expr_type;
    } else {
      auto result_type_id = GenerateTypeIfNeeded(expr_type);
      if (result_type_id == 0) {
        return false;
      }

      auto extract = result_op();
      auto extract_id = extract.to_i();
      if (!push_function_inst(
              spv::Op::OpCompositeExtract,
              {Operand::Int(result_type_id), extract,
               Operand::Int(info->source_id), Operand::Int(idx)})) {
        return false;
      }

      info->source_id = extract_id;
      info->source_type = expr_type;
    }

    return true;
  }

  if (auto* swizzle = expr_sem->As<sem::Swizzle>()) {
    // Single element swizzle is either an access chain or a composite extract
    auto& indices = swizzle->Indices();
    if (indices.size() == 1) {
      if (info->source_type->Is<sem::Reference>()) {
        auto idx_id = GenerateConstantIfNeeded(ScalarConstant::U32(indices[0]));
        if (idx_id == 0) {
          return false;
        }
        info->access_chain_indices.push_back(idx_id);
        // Keep source_type in sync with the chain, as the struct path does.
        info->source_type = expr_type;
      } else {
        auto result_type_id = GenerateTypeIfNeeded(expr_type);
        if (result_type_id == 0) {
          return false;
        }

        auto extract = result_op();
        auto extract_id = extract.to_i();
        if (!push_function_inst(
                spv::Op::OpCompositeExtract,
                {Operand::Int(result_type_id), extract,
                 Operand::Int(info->source_id), Operand::Int(indices[0])})) {
          return false;
        }

        info->source_id = extract_id;
        info->source_type = expr_type;
      }
      return true;
    }

    // Store the type away as it may change if we run the access chain
    auto* incoming_type = info->source_type;

    // Multi-item extract is a VectorShuffle. We have to emit any existing
    // access chain data, then load the access chain and shuffle that.
    if (!info->access_chain_indices.empty()) {
      auto result_type_id = GenerateTypeIfNeeded(info->source_type);
      if (result_type_id == 0) {
        return false;
      }
      auto extract = result_op();
      auto extract_id = extract.to_i();

      OperandList ops = {Operand::Int(result_type_id), extract,
                         Operand::Int(info->source_id)};
      for (auto id : info->access_chain_indices) {
        ops.push_back(Operand::Int(id));
      }

      if (!push_function_inst(spv::Op::OpAccessChain, ops)) {
        return false;
      }

      info->source_id = GenerateLoadIfNeeded(expr_type, extract_id);
      if (info->source_id == 0) {
        return false;
      }
      info->source_type = expr_type->UnwrapRef();
      info->access_chain_indices.clear();
    }

    auto result_type_id = GenerateTypeIfNeeded(expr_type);
    if (result_type_id == 0) {
      return false;
    }

    // Load the vector if the (original) source was a reference.
    auto vec_id = GenerateLoadIfNeeded(incoming_type, info->source_id);
    if (vec_id == 0) {
      return false;
    }

    auto result = result_op();
    auto result_id = result.to_i();

    // OpVectorShuffle with the same vector for both inputs.
    OperandList ops = {Operand::Int(result_type_id), result,
                       Operand::Int(vec_id), Operand::Int(vec_id)};

    for (auto idx : indices) {
      ops.push_back(Operand::Int(idx));
    }

    if (!push_function_inst(spv::Op::OpVectorShuffle, ops)) {
      return false;
    }
    info->source_id = result_id;
    info->source_type = expr_type;
    return true;
  }

  TINT_ICE(Writer, builder_.Diagnostics())
      << "unhandled member index type: " << expr_sem->TypeInfo().name;
  return false;
}
/// Generates an array/member accessor chain rooted at `expr`.
///
/// The chain is walked from `expr` down to its non-accessor source, and the
/// source is generated. Each accessor is then applied innermost-first. Any
/// indices still pending at the end are emitted as one OpAccessChain.
/// @param expr an ast::ArrayAccessorExpression or ast::MemberAccessorExpression
/// @returns the ID of the resulting value/pointer, or 0 on failure
uint32_t Builder::GenerateAccessorExpression(const ast::Expression* expr) {
  if (!expr->IsAnyOf<ast::ArrayAccessorExpression,
                     ast::MemberAccessorExpression>()) {
    TINT_ICE(Writer, builder_.Diagnostics()) << "expression is not an accessor";
    return 0;
  }

  // Gather a list of all the member and array accessors that are in this
  // chain. They are collected outermost-first, then reversed, because the
  // chain must be applied innermost-first.
  std::vector<const ast::Expression*> accessors;
  const ast::Expression* source = expr;
  while (true) {
    if (auto* array = source->As<ast::ArrayAccessorExpression>()) {
      accessors.push_back(source);
      source = array->array;
    } else if (auto* member = source->As<ast::MemberAccessorExpression>()) {
      accessors.push_back(source);
      source = member->structure;
    } else {
      break;
    }
  }
  std::reverse(accessors.begin(), accessors.end());

  AccessorInfo info;
  info.source_id = GenerateExpression(source);
  if (info.source_id == 0) {
    return 0;
  }
  info.source_type = TypeOf(source);

  for (auto* accessor : accessors) {
    if (auto* array = accessor->As<ast::ArrayAccessorExpression>()) {
      if (!GenerateArrayAccessor(array, &info)) {
        return 0;
      }
    } else if (auto* member = accessor->As<ast::MemberAccessorExpression>()) {
      if (!GenerateMemberAccessor(member, &info)) {
        return 0;
      }
    } else {
      error_ =
          "invalid accessor in list: " + std::string(accessor->TypeInfo().name);
      return 0;
    }
  }

  // Flush any indices still pending into a single OpAccessChain.
  if (!info.access_chain_indices.empty()) {
    auto* type = TypeOf(expr);
    auto result_type_id = GenerateTypeIfNeeded(type);
    if (result_type_id == 0) {
      return 0;
    }

    auto result = result_op();
    auto result_id = result.to_i();

    OperandList ops = {Operand::Int(result_type_id), result,
                       Operand::Int(info.source_id)};
    for (auto id : info.access_chain_indices) {
      ops.push_back(Operand::Int(id));
    }

    if (!push_function_inst(spv::Op::OpAccessChain, ops)) {
      return 0;
    }
    info.source_id = result_id;
  }

  return info.source_id;
}
/// Resolves an identifier to the SPIR-V id bound to its symbol in the current
/// scope stack.
/// @param expr the identifier expression
/// @returns the bound id, or 0 (with error_ set) if the symbol is not bound
uint32_t Builder::GenerateIdentifierExpression(
    const ast::IdentifierExpression* expr) {
  const uint32_t bound_id = scope_stack_.Get(expr->symbol);
  if (bound_id != 0) {
    return bound_id;
  }
  error_ = "unable to find variable with identifier: " +
           builder_.Symbols().NameFor(expr->symbol);
  return 0;
}
/// Emits an OpLoad if `type` is a reference, turning the pointer `id` into a
/// value of the reference's store type. Non-reference types are passed
/// through unchanged.
/// @param type the semantic type of the expression that produced `id`
/// @param id the id of the expression
/// @returns the id of the loaded value, `id` itself when no load is needed,
///          or 0 on error
uint32_t Builder::GenerateLoadIfNeeded(const sem::Type* type, uint32_t id) {
  auto* ref = type->As<sem::Reference>();
  if (ref == nullptr) {
    // Not a reference: `id` already names a value.
    return id;
  }

  auto type_id = GenerateTypeIfNeeded(ref->StoreType());
  if (type_id == 0) {
    // Don't emit an OpLoad with an invalid result type.
    return 0;
  }

  auto result = result_op();
  auto result_id = result.to_i();
  if (!push_function_inst(spv::Op::OpLoad,
                          {Operand::Int(type_id), result, Operand::Int(id)})) {
    return 0;
  }
  return result_id;
}
/// Generates a unary operator expression.
/// Address-of (`&`) and indirection (`*`) are no-ops in SPIR-V and return the
/// operand's id directly. All other operators load the operand if needed and
/// emit the matching SPIR-V instruction.
/// @param expr the unary expression
/// @returns the id of the result, or 0 on error
uint32_t Builder::GenerateUnaryOpExpression(
    const ast::UnaryOpExpression* expr) {
  auto result = result_op();
  auto result_id = result.to_i();

  auto val_id = GenerateExpression(expr->expr);
  if (val_id == 0) {
    return 0;
  }

  spv::Op op = spv::Op::OpNop;
  switch (expr->op) {
    case ast::UnaryOp::kComplement:
      op = spv::Op::OpNot;
      break;
    case ast::UnaryOp::kNegation:
      if (TypeOf(expr)->is_float_scalar_or_vector()) {
        op = spv::Op::OpFNegate;
      } else {
        op = spv::Op::OpSNegate;
      }
      break;
    case ast::UnaryOp::kNot:
      op = spv::Op::OpLogicalNot;
      break;
    case ast::UnaryOp::kAddressOf:
    case ast::UnaryOp::kIndirection:
      // Address-of converts a reference to a pointer, and dereference converts
      // a pointer to a reference. These are the same thing in SPIR-V, so this
      // is a no-op.
      return val_id;
  }

  val_id = GenerateLoadIfNeeded(TypeOf(expr->expr), val_id);
  if (val_id == 0) {
    return 0;
  }

  auto type_id = GenerateTypeIfNeeded(TypeOf(expr));
  if (type_id == 0) {
    return 0;
  }

  if (!push_function_inst(
          op, {Operand::Int(type_id), result, Operand::Int(val_id)})) {
    return 0;
  }
  return result_id;
}
uint32_t Builder::GetGLSLstd450Import() {
auto where = import_name_to_id_.find(kGLSLstd450);
if (where != import_name_to_id_.end()) {
return where->second;
}
// It doesn't exist yet. Generate it.
auto result = result_op();
auto id = result.to_i();
push_ext_import(spv::Op::OpExtInstImport,
{result, Operand::String(kGLSLstd450)});
// Remember it for later.
import_name_to_id_[kGLSLstd450] = id;
return id;
}
/// Dispatches a constructor expression to the scalar-literal or
/// type-constructor generator.
/// @param var the variable being initialized, or nullptr
/// @param expr the constructor expression
/// @returns the id of the constructed value, or 0 (with error_ set) for an
///          unrecognised constructor kind
uint32_t Builder::GenerateConstructorExpression(
    const ast::Variable* var,
    const ast::ConstructorExpression* expr) {
  if (auto* type_ctor = expr->As<ast::TypeConstructorExpression>()) {
    return GenerateTypeConstructorExpression(var, type_ctor);
  }
  if (auto* scalar_ctor = expr->As<ast::ScalarConstructorExpression>()) {
    return GenerateLiteralIfNeeded(var, scalar_ctor->literal);
  }
  error_ = "unknown constructor expression";
  return 0;
}
/// Determines whether `expr` is a constructor whose operands are all constant
/// constructors, so that the whole thing may be emitted as a constant.
/// Scalar operands must have exactly the element type of the constructed
/// type. A vector may only be built from scalar constructors.
/// @param expr the expression to test
/// @param is_global_init true when `expr` initializes a module-scope
///        variable. In that case a non-constructor operand also sets error_.
/// @returns true if the constructor is constant
bool Builder::is_constructor_const(const ast::Expression* expr,
                                   bool is_global_init) {
  auto* constructor = expr->As<ast::ConstructorExpression>();
  if (!constructor) {
    return false;
  }
  if (constructor->Is<ast::ScalarConstructorExpression>()) {
    return true;
  }

  auto* tc = constructor->As<ast::TypeConstructorExpression>();
  const sem::Type* result_type = TypeOf(tc)->UnwrapRef();
  const bool result_is_vector = result_type->Is<sem::Vector>();

  for (size_t idx = 0; idx < tc->values.size(); ++idx) {
    auto* value = tc->values[idx];

    if (!value->Is<ast::ConstructorExpression>()) {
      // A non-constructor operand can never be constant; at module scope this
      // is also an error.
      if (is_global_init) {
        error_ = "constructor must be a constant expression";
      }
      return false;
    }

    if (!is_constructor_const(value, is_global_init) || has_error()) {
      return false;
    }

    auto* scalar = value->As<ast::ScalarConstructorExpression>();
    if (scalar == nullptr) {
      // Nested type constructors were checked by the recursive call above, but
      // vectors may only be built directly from scalars.
      if (result_is_vector) {
        return false;
      }
      continue;
    }

    // Find the type this operand is expected to have.
    const sem::Type* expected = result_type;
    if (auto* vec = expected->As<sem::Vector>()) {
      expected = vec->type();
    } else if (auto* mat = expected->As<sem::Matrix>()) {
      expected = mat->type();
    } else if (auto* arr = expected->As<sem::Array>()) {
      expected = arr->ElemType();
    } else if (auto* str = expected->As<sem::Struct>()) {
      expected = str->Members()[idx]->Type();
    }

    if (expected != TypeOf(scalar)->UnwrapRef()) {
      return false;
    }
  }
  return true;
}
/// Generates a type constructor expression (e.g. `vec3<f32>(a, b, c)`).
///
/// - With no operands, a zero value is produced. For pipeline-overridable
///   scalar globals this is a zero-valued spec constant; otherwise it is an
///   OpConstantNull.
/// - Scalar results, and same-width vector-from-vector constructions, are
///   delegated to GenerateCastOrCopyOrPassthrough.
/// - Otherwise the operands are generated (vector operands are decomposed into
///   components) and combined with OpConstantComposite,
///   OpSpecConstantComposite or OpCompositeConstruct as appropriate.
///
/// Results are cached in type_constructor_to_id_, keyed on the type name and
/// the operand ids.
/// @param var the variable being initialized, or nullptr
/// @param init the type constructor expression
/// @returns the id of the constructed value, or 0 on error
uint32_t Builder::GenerateTypeConstructorExpression(
    const ast::Variable* var,
    const ast::TypeConstructorExpression* init) {
  auto* global_var = builder_.Sem().Get<sem::GlobalVariable>(var);
  const bool is_global_init = global_var != nullptr;

  auto& values = init->values;

  auto* result_type = TypeOf(init);

  // Generate the zero initializer if there are no values provided.
  if (values.empty()) {
    if (global_var && global_var->IsPipelineConstant()) {
      auto constant_id = global_var->ConstantId();
      if (result_type->Is<sem::I32>()) {
        return GenerateConstantIfNeeded(
            ScalarConstant::I32(0).AsSpecOp(constant_id));
      }
      if (result_type->Is<sem::U32>()) {
        return GenerateConstantIfNeeded(
            ScalarConstant::U32(0).AsSpecOp(constant_id));
      }
      if (result_type->Is<sem::F32>()) {
        return GenerateConstantIfNeeded(
            ScalarConstant::F32(0).AsSpecOp(constant_id));
      }
      if (result_type->Is<sem::Bool>()) {
        return GenerateConstantIfNeeded(
            ScalarConstant::Bool(false).AsSpecOp(constant_id));
      }
    }
    return GenerateConstantNullIfNeeded(result_type->UnwrapRef());
  }

  std::ostringstream out;
  out << "__const_" << init->type->FriendlyName(builder_.Symbols()) << "_";

  result_type = result_type->UnwrapRef();
  bool constructor_is_const = is_constructor_const(init, is_global_init);
  if (has_error()) {
    return 0;
  }

  // Scalars, and vectors built from a same-width vector, are a plain
  // conversion (or copy) of the first operand.
  bool can_cast_or_copy = result_type->is_scalar();

  if (auto* res_vec = result_type->As<sem::Vector>()) {
    if (res_vec->type()->is_scalar()) {
      auto* value_type = TypeOf(values[0])->UnwrapRef();
      if (auto* val_vec = value_type->As<sem::Vector>()) {
        if (val_vec->type()->is_scalar()) {
          can_cast_or_copy = res_vec->Width() == val_vec->Width();
        }
      }
    }
  }

  if (can_cast_or_copy) {
    return GenerateCastOrCopyOrPassthrough(result_type, values[0],
                                           is_global_init);
  }

  auto type_id = GenerateTypeIfNeeded(result_type);
  if (type_id == 0) {
    return 0;
  }

  bool result_is_constant_composite = constructor_is_const;
  bool result_is_spec_composite = false;

  // For vector results, operands are compared against the element type.
  if (auto* vec = result_type->As<sem::Vector>()) {
    result_type = vec->type();
  }

  OperandList ops;
  for (auto* e : values) {
    uint32_t id = 0;
    if (constructor_is_const) {
      id = GenerateConstructorExpression(nullptr,
                                         e->As<ast::ConstructorExpression>());
    } else {
      id = GenerateExpression(e);
      id = GenerateLoadIfNeeded(TypeOf(e), id);
    }
    if (id == 0) {
      return 0;
    }

    auto* value_type = TypeOf(e)->UnwrapRef();
    // If the result and value types are the same we can just use the object.
    // If the result is not a vector then we should have validated that the
    // value type is a correctly sized vector so we can just use it directly.
    if (result_type == value_type || result_type->Is<sem::Matrix>() ||
        result_type->Is<sem::Array>() || result_type->Is<sem::Struct>()) {
      out << "_" << id;

      ops.push_back(Operand::Int(id));
      continue;
    }

    // Both scalars, but not the same type so we need to generate a conversion
    // of the value. Convert this operand `e` (previously `values[0]` was
    // converted for every operand, duplicating the first component).
    if (value_type->is_scalar() && result_type->is_scalar()) {
      id = GenerateCastOrCopyOrPassthrough(result_type, e, is_global_init);
      if (id == 0) {
        return 0;
      }
      out << "_" << id;
      ops.push_back(Operand::Int(id));
      continue;
    }

    // When handling vectors as the values there a few cases to take into
    // consideration:
    //  1. Module scoped vec3<f32>(vec2<f32>(1, 2), 3)  -> OpSpecConstantOp
    //  2. Function scoped vec3<f32>(vec2<f32>(1, 2), 3) ->  OpCompositeExtract
    //  3. Either array<vec3<f32>, 1>(vec3<f32>(1, 2, 3))  -> use the ID.
    //       -> handled above
    //
    // For cases 1 and 2, if the type is different we also may need to insert
    // a type cast.
    if (auto* vec = value_type->As<sem::Vector>()) {
      auto* vec_type = vec->type();

      auto value_type_id = GenerateTypeIfNeeded(vec_type);
      if (value_type_id == 0) {
        return 0;
      }

      for (uint32_t i = 0; i < vec->Width(); ++i) {
        auto extract = result_op();
        auto extract_id = extract.to_i();

        if (!global_var) {
          // A non-global initializer. Case 2.
          if (!push_function_inst(spv::Op::OpCompositeExtract,
                                  {Operand::Int(value_type_id), extract,
                                   Operand::Int(id), Operand::Int(i)})) {
            return 0;
          }

          // We no longer have a constant composite, but have to do a
          // composite construction as these calls are inside a function.
          result_is_constant_composite = false;
        } else {
          // A global initializer, must use OpSpecConstantOp. Case 1.
          auto idx_id = GenerateConstantIfNeeded(ScalarConstant::U32(i));
          if (idx_id == 0) {
            return 0;
          }
          push_type(spv::Op::OpSpecConstantOp,
                    {Operand::Int(value_type_id), extract,
                     Operand::Int(SpvOpCompositeExtract), Operand::Int(id),
                     Operand::Int(idx_id)});

          result_is_spec_composite = true;
        }

        out << "_" << extract_id;
        ops.push_back(Operand::Int(extract_id));
      }
    } else {
      error_ = "Unhandled type cast value type";
      return 0;
    }
  }

  // For a single-value vector initializer, splat the initializer value.
  auto* const init_result_type = TypeOf(init)->UnwrapRef();
  if (values.size() == 1 && init_result_type->is_scalar_vector() &&
      TypeOf(values[0])->UnwrapRef()->is_scalar()) {
    size_t vec_size = init_result_type->As<sem::Vector>()->Width();
    for (size_t i = 0; i < (vec_size - 1); ++i) {
      ops.push_back(ops[0]);
    }
  }

  // Reuse a previously generated identical construction if there is one.
  auto str = out.str();
  auto val = type_constructor_to_id_.find(str);
  if (val != type_constructor_to_id_.end()) {
    return val->second;
  }

  auto result = result_op();
  ops.insert(ops.begin(), result);
  ops.insert(ops.begin(), Operand::Int(type_id));

  type_constructor_to_id_[str] = result.to_i();

  if (result_is_spec_composite) {
    push_type(spv::Op::OpSpecConstantComposite, ops);
  } else if (result_is_constant_composite) {
    push_type(spv::Op::OpConstantComposite, ops);
  } else {
    if (!push_function_inst(spv::Op::OpCompositeConstruct, ops)) {
      return 0;
    }
  }

  return result.to_i();
}
/// Generates a value conversion of `from_expr` to `to_type`.
/// - Same-type scalars and identical vector types pass through unchanged.
/// - int <-> float uses OpConvert{S,U}To{F} / OpConvertFTo{S,U}.
/// - i32 <-> u32 (scalar or vector) uses OpBitcast.
/// - numeric -> bool compares against zero (OpINotEqual / OpFUnordNotEqual).
/// - bool -> numeric selects between 1 and 0 with OpSelect.
/// Module-scope conversions are not supported; they are expected to have
/// been removed by the FoldConstants transform.
/// @param to_type the type to convert to
/// @param from_expr the expression to convert
/// @param is_global_init true if this is part of a module-scope initializer
/// @returns the id of the converted value, or 0 on error
uint32_t Builder::GenerateCastOrCopyOrPassthrough(
    const sem::Type* to_type,
    const ast::Expression* from_expr,
    bool is_global_init) {
  // This should not happen as we rely on constant folding to obviate
  // casts/conversions for module-scope variables
  if (is_global_init) {
    TINT_ICE(Writer, builder_.Diagnostics())
        << "Module-level conversions are not supported. Conversions should "
           "have already been constant-folded by the FoldConstants transform.";
    return 0;
  }

  // Returns the scalar element type of a scalar or vector type, else nullptr.
  auto elem_type_of = [](const sem::Type* t) -> const sem::Type* {
    if (t->is_scalar()) {
      return t;
    }
    if (auto* v = t->As<sem::Vector>()) {
      return v->type();
    }
    return nullptr;
  };

  auto result = result_op();
  auto result_id = result.to_i();

  auto result_type_id = GenerateTypeIfNeeded(to_type);
  if (result_type_id == 0) {
    return 0;
  }

  auto val_id = GenerateExpression(from_expr);
  if (val_id == 0) {
    return 0;
  }
  val_id = GenerateLoadIfNeeded(TypeOf(from_expr), val_id);
  if (val_id == 0) {
    return 0;
  }

  auto* from_type = TypeOf(from_expr)->UnwrapRef();

  spv::Op op = spv::Op::OpNop;
  if ((from_type->Is<sem::I32>() && to_type->Is<sem::F32>()) ||
      (from_type->is_signed_integer_vector() && to_type->is_float_vector())) {
    op = spv::Op::OpConvertSToF;
  } else if ((from_type->Is<sem::U32>() && to_type->Is<sem::F32>()) ||
             (from_type->is_unsigned_integer_vector() &&
              to_type->is_float_vector())) {
    op = spv::Op::OpConvertUToF;
  } else if ((from_type->Is<sem::F32>() && to_type->Is<sem::I32>()) ||
             (from_type->is_float_vector() &&
              to_type->is_signed_integer_vector())) {
    op = spv::Op::OpConvertFToS;
  } else if ((from_type->Is<sem::F32>() && to_type->Is<sem::U32>()) ||
             (from_type->is_float_vector() &&
              to_type->is_unsigned_integer_vector())) {
    op = spv::Op::OpConvertFToU;
  } else if ((from_type->Is<sem::Bool>() && to_type->Is<sem::Bool>()) ||
             (from_type->Is<sem::U32>() && to_type->Is<sem::U32>()) ||
             (from_type->Is<sem::I32>() && to_type->Is<sem::I32>()) ||
             (from_type->Is<sem::F32>() && to_type->Is<sem::F32>()) ||
             (from_type->Is<sem::Vector>() && (from_type == to_type))) {
    // Same type: no instruction needed.
    return val_id;
  } else if ((from_type->Is<sem::I32>() && to_type->Is<sem::U32>()) ||
             (from_type->Is<sem::U32>() && to_type->Is<sem::I32>()) ||
             (from_type->is_signed_integer_vector() &&
              to_type->is_unsigned_integer_vector()) ||
             (from_type->is_unsigned_integer_vector() &&
              to_type->is_signed_integer_vector())) {
    // Signedness change only: reinterpret the bits. Both vector directions
    // are now spelled out symmetrically.
    op = spv::Op::OpBitcast;
  } else if ((from_type->is_numeric_scalar() && to_type->Is<sem::Bool>()) ||
             (from_type->is_numeric_vector() && to_type->is_bool_vector())) {
    // Convert scalar (vector) to bool (vector)

    // Return the result of comparing from_expr with zero
    uint32_t zero = GenerateConstantNullIfNeeded(from_type);
    const auto* from_elem_type = elem_type_of(from_type);
    op = from_elem_type->is_integer_scalar() ? spv::Op::OpINotEqual
                                             : spv::Op::OpFUnordNotEqual;
    if (!push_function_inst(
            op, {Operand::Int(result_type_id), Operand::Int(result_id),
                 Operand::Int(val_id), Operand::Int(zero)})) {
      return 0;
    }

    return result_id;
  } else if (from_type->is_bool_scalar_or_vector() &&
             to_type->is_numeric_scalar_or_vector()) {
    // Convert bool scalar/vector to numeric scalar/vector.
    // Use the bool to select between 1 (if true) and 0 (if false).

    const auto* to_elem_type = elem_type_of(to_type);
    uint32_t one_id;
    uint32_t zero_id;
    if (to_elem_type->Is<sem::F32>()) {
      ast::FloatLiteral one(ProgramID(), Source{}, 1.0f);
      ast::FloatLiteral zero(ProgramID(), Source{}, 0.0f);
      one_id = GenerateLiteralIfNeeded(nullptr, &one);
      zero_id = GenerateLiteralIfNeeded(nullptr, &zero);
    } else if (to_elem_type->Is<sem::U32>()) {
      ast::UintLiteral one(ProgramID(), Source{}, 1);
      ast::UintLiteral zero(ProgramID(), Source{}, 0);
      one_id = GenerateLiteralIfNeeded(nullptr, &one);
      zero_id = GenerateLiteralIfNeeded(nullptr, &zero);
    } else if (to_elem_type->Is<sem::I32>()) {
      ast::SintLiteral one(ProgramID(), Source{}, 1);
      ast::SintLiteral zero(ProgramID(), Source{}, 0);
      one_id = GenerateLiteralIfNeeded(nullptr, &one);
      zero_id = GenerateLiteralIfNeeded(nullptr, &zero);
    } else {
      error_ = "invalid destination type for bool conversion";
      return 0;
    }
    if (auto* to_vec = to_type->As<sem::Vector>()) {
      // Splat the scalars into vectors.
      one_id = GenerateConstantVectorSplatIfNeeded(to_vec, one_id);
      zero_id = GenerateConstantVectorSplatIfNeeded(to_vec, zero_id);
    }
    if (!one_id || !zero_id) {
      return 0;
    }

    op = spv::Op::OpSelect;
    if (!push_function_inst(
            op, {Operand::Int(result_type_id), Operand::Int(result_id),
                 Operand::Int(val_id), Operand::Int(one_id),
                 Operand::Int(zero_id)})) {
      return 0;
    }

    return result_id;
  } else {
    TINT_ICE(Writer, builder_.Diagnostics()) << "Invalid from_type";
  }

  if (op == spv::Op::OpNop) {
    error_ = "unable to determine conversion type for cast, from: " +
             from_type->type_name() + " to: " + to_type->type_name();
    return 0;
  }

  if (!push_function_inst(
          op, {Operand::Int(result_type_id), result, Operand::Int(val_id)})) {
    return 0;
  }

  return result_id;
}
/// Generates (or reuses) the constant for a scalar literal.
/// If `var` is a pipeline-overridable global, the constant is generated as a
/// specialization constant tagged with the variable's constant id.
/// @param var the variable the literal initializes, or nullptr
/// @param lit the literal
/// @returns the id of the constant, or 0 (with error_ set) for an unknown
///          literal kind
uint32_t Builder::GenerateLiteralIfNeeded(const ast::Variable* var,
                                          const ast::Literal* lit) {
  ScalarConstant constant;

  auto* global = builder_.Sem().Get<sem::GlobalVariable>(var);
  const bool is_pipeline_constant = global && global->IsPipelineConstant();
  if (is_pipeline_constant) {
    constant.is_spec_op = true;
    constant.constant_id = global->ConstantId();
  }

  if (auto* bool_lit = lit->As<ast::BoolLiteral>()) {
    constant.kind = ScalarConstant::Kind::kBool;
    constant.value.b = bool_lit->value;
    return GenerateConstantIfNeeded(constant);
  }
  if (auto* sint_lit = lit->As<ast::SintLiteral>()) {
    constant.kind = ScalarConstant::Kind::kI32;
    constant.value.i32 = sint_lit->value;
    return GenerateConstantIfNeeded(constant);
  }
  if (auto* uint_lit = lit->As<ast::UintLiteral>()) {
    constant.kind = ScalarConstant::Kind::kU32;
    constant.value.u32 = uint_lit->value;
    return GenerateConstantIfNeeded(constant);
  }
  if (auto* float_lit = lit->As<ast::FloatLiteral>()) {
    constant.kind = ScalarConstant::Kind::kF32;
    constant.value.f32 = float_lit->value;
    return GenerateConstantIfNeeded(constant);
  }

  error_ = "unknown literal type";
  return 0;
}
uint32_t Builder::GenerateConstantIfNeeded(const ScalarConstant& constant) {
auto it = const_to_id_.find(constant);
if (it != const_to_id_.end()) {
return it->second;
}
uint32_t type_id = 0;
switch (constant.kind) {
case ScalarConstant::Kind::kU32: {
type_id = GenerateTypeIfNeeded(builder_.create<sem::U32>());
break;
}
case ScalarConstant::Kind::kI32: {
type_id = GenerateTypeIfNeeded(builder_.create<sem::I32>());
break;
}
case ScalarConstant::Kind::kF32: {
type_id = GenerateTypeIfNeeded(builder_.create<sem::F32>());
break;
}
case ScalarConstant::Kind::kBool: {
type_id = GenerateTypeIfNeeded(builder_.create<sem::Bool>());
break;
}
}
if (type_id == 0) {
return 0;
}
auto result = result_op();
auto result_id = result.to_i();
if (constant.is_spec_op) {
push_annot(spv::Op::OpDecorate,
{Operand::Int(result_id), Operand::Int(SpvDecorationSpecId),
Operand::Int(constant.constant_id)});
}
switch (constant.kind) {
case ScalarConstant::Kind::kU32: {
push_type(
constant.is_spec_op ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
{Operand::Int(type_id), result, Operand::Int(constant.value.u32)});
break;
}
case ScalarConstant::Kind::kI32: {
push_type(
constant.is_spec_op ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
{Operand::Int(type_id), result, Operand::Int(constant.value.i32)});
break;
}
case ScalarConstant::Kind::kF32: {
push_type(
constant.is_spec_op ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
{Operand::Int(type_id), result, Operand::Float(constant.value.f32)});
break;
}
case ScalarConstant::Kind::kBool: {
if (constant.value.b) {
push_type(constant.is_spec_op ? spv::Op::OpSpecConstantTrue
: spv::Op::OpConstantTrue,
{Operand::Int(type_id), result});
} else {
push_type(constant.is_spec_op ? spv::Op::OpSpecConstantFalse
: spv::Op::OpConstantFalse,
{Operand::Int(type_id), result});
}
break;
}
}
const_to_id_[constant] = result_id;
return result_id;
}
/// Generates (or reuses) an OpConstantNull of the given type.
/// Null constants are de-duplicated in const_null_to_id_, keyed by the
/// type's name.
/// @param type the type of the null constant
/// @returns the id of the null constant, or 0 if the type failed to generate
uint32_t Builder::GenerateConstantNullIfNeeded(const sem::Type* type) {
  const auto type_id = GenerateTypeIfNeeded(type);
  if (type_id == 0) {
    return 0;
  }

  const auto key = type->type_name();
  auto cached = const_null_to_id_.find(key);
  if (cached != const_null_to_id_.end()) {
    return cached->second;
  }

  auto null_op = result_op();
  const auto null_id = null_op.to_i();
  push_type(spv::Op::OpConstantNull, {Operand::Int(type_id), null_op});
  const_null_to_id_[key] = null_id;
  return null_id;
}
/// Generates (or reuses) an OpConstantComposite vector whose components are
/// all `value_id`.
/// @param type the vector type
/// @param value_id the id of the scalar constant to splat
/// @returns the id of the vector constant, or 0 on error
uint32_t Builder::GenerateConstantVectorSplatIfNeeded(const sem::Vector* type,
                                                      uint32_t value_id) {
  auto type_id = GenerateTypeIfNeeded(type);
  if (type_id == 0 || value_id == 0) {
    return 0;
  }

  // Cache key: (width, scalar id). The scalar id fixes the element type, so
  // together with the width it identifies the vector constant.
  uint64_t key = (static_cast<uint64_t>(type->Width()) << 32) + value_id;
  // GetOrCreate stores the value returned by the creator, so the lambda must
  // not write to const_splat_to_id_ itself. The previous in-lambda
  // assignment was redundant.
  return utils::GetOrCreate(const_splat_to_id_, key, [&] {
    auto result = result_op();
    auto result_id = result.to_i();

    OperandList ops;
    ops.push_back(Operand::Int(type_id));
    ops.push_back(result);
    for (uint32_t i = 0; i < type->Width(); i++) {
      ops.push_back(Operand::Int(value_id));
    }
    push_type(spv::Op::OpConstantComposite, ops);

    return result_id;
  });
}
/// Generates a short-circuiting `&&` / `||` expression.
/// The LHS is evaluated in the current block. A conditional branch then
/// either enters a block that evaluates the RHS, or jumps straight to the
/// merge block. The result is an OpPhi of the LHS value (from the original
/// block) and the RHS value (from the last RHS block).
/// @param expr the logical-and / logical-or expression
/// @returns the id of the result, or 0 on error
uint32_t Builder::GenerateShortCircuitBinaryExpression(
    const ast::BinaryExpression* expr) {
  auto lhs_id = GenerateExpression(expr->lhs);
  if (lhs_id == 0) {
    return 0;
  }
  lhs_id = GenerateLoadIfNeeded(TypeOf(expr->lhs), lhs_id);
  if (lhs_id == 0) {
    return 0;
  }

  // Get the ID of the basic block where control flow will diverge. It's the
  // last basic block generated for the left-hand-side of the operator.
  auto original_label_id = current_label_id_;

  auto type_id = GenerateTypeIfNeeded(TypeOf(expr));
  if (type_id == 0) {
    return 0;
  }

  auto merge_block = result_op();
  auto merge_block_id = merge_block.to_i();

  auto block = result_op();
  auto block_id = block.to_i();

  auto true_block_id = block_id;
  auto false_block_id = merge_block_id;

  // For a logical or we want to only check the RHS if the LHS is failed.
  if (expr->IsLogicalOr()) {
    std::swap(true_block_id, false_block_id);
  }

  if (!push_function_inst(spv::Op::OpSelectionMerge,
                          {Operand::Int(merge_block_id),
                           Operand::Int(SpvSelectionControlMaskNone)})) {
    return 0;
  }
  if (!push_function_inst(spv::Op::OpBranchConditional,
                          {Operand::Int(lhs_id), Operand::Int(true_block_id),
                           Operand::Int(false_block_id)})) {
    return 0;
  }

  // Output block to check the RHS
  if (!GenerateLabel(block_id)) {
    return 0;
  }
  auto rhs_id = GenerateExpression(expr->rhs);
  if (rhs_id == 0) {
    return 0;
  }
  rhs_id = GenerateLoadIfNeeded(TypeOf(expr->rhs), rhs_id);
  if (rhs_id == 0) {
    return 0;
  }

  // Get the block ID of the last basic block generated for the right-hand-side
  // expression. That block will be an immediate predecessor to the merge block.
  auto rhs_block_id = current_label_id_;
  if (!push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)})) {
    return 0;
  }

  // Output the merge block
  if (!GenerateLabel(merge_block_id)) {
    return 0;
  }

  auto result = result_op();
  auto result_id = result.to_i();

  if (!push_function_inst(spv::Op::OpPhi,
                          {Operand::Int(type_id), result, Operand::Int(lhs_id),
                           Operand::Int(original_label_id),
                           Operand::Int(rhs_id), Operand::Int(rhs_block_id)})) {
    return 0;
  }

  return result_id;
}
/// Builds a vector of type `vec_type` with every component set to
/// `scalar_id`, using OpCompositeConstruct in the current function.
/// The original version also declared a function-scope OpVariable that was
/// never referenced. That dead variable is no longer emitted.
/// @param scalar_id the id of the scalar value to splat
/// @param vec_type the vector type to construct; must be a sem::Vector
/// @returns the id of the constructed vector, or 0 on error
uint32_t Builder::GenerateSplat(uint32_t scalar_id, const sem::Type* vec_type) {
  auto* vec = vec_type->As<sem::Vector>();
  if (vec == nullptr) {
    TINT_ICE(Writer, builder_.Diagnostics())
        << "GenerateSplat called with a non-vector type";
    return 0;
  }

  auto vec_type_id = GenerateTypeIfNeeded(vec_type);
  if (vec_type_id == 0) {
    return 0;
  }

  // Splat scalar into vector
  auto splat_result = result_op();
  OperandList ops;
  ops.push_back(Operand::Int(vec_type_id));
  ops.push_back(splat_result);
  for (uint32_t i = 0; i < vec->Width(); ++i) {
    ops.push_back(Operand::Int(scalar_id));
  }
  if (!push_function_inst(spv::Op::OpCompositeConstruct, ops)) {
    return 0;
  }

  return splat_result.to_i();
}
/// Adds or subtracts two matrices column by column.
/// SPIR-V has no matrix add/sub, so each column is extracted from both
/// operands, combined with `op`, and the results are reassembled with
/// OpCompositeConstruct.
/// Example addition of two matrices:
///   %31 = OpLoad %mat3v4float %m34
///   %32 = OpLoad %mat3v4float %m34
///   %33 = OpCompositeExtract %v4float %31 0
///   %34 = OpCompositeExtract %v4float %32 0
///   %35 = OpFAdd %v4float %33 %34
///   %36 = OpCompositeExtract %v4float %31 1
///   %37 = OpCompositeExtract %v4float %32 1
///   %38 = OpFAdd %v4float %36 %37
///   %39 = OpCompositeExtract %v4float %31 2
///   %40 = OpCompositeExtract %v4float %32 2
///   %41 = OpFAdd %v4float %39 %40
///   %42 = OpCompositeConstruct %mat3v4float %35 %38 %41
/// @param lhs_id the id of the left matrix value
/// @param rhs_id the id of the right matrix value
/// @param type the matrix type of both operands and of the result
/// @param op the per-column instruction (e.g. OpFAdd / OpFSub)
/// @returns the id of the resulting matrix, or 0 on error
uint32_t Builder::GenerateMatrixAddOrSub(uint32_t lhs_id,
                                         uint32_t rhs_id,
                                         const sem::Matrix* type,
                                         spv::Op op) {
  auto* column_type = builder_.create<sem::Vector>(type->type(), type->rows());
  auto column_type_id = GenerateTypeIfNeeded(column_type);
  if (column_type_id == 0) {
    return 0;
  }

  OperandList ops;

  for (uint32_t i = 0; i < type->columns(); ++i) {
    // Extract column `i` from lhs mat
    auto lhs_column_id = result_op();
    if (!push_function_inst(spv::Op::OpCompositeExtract,
                            {Operand::Int(column_type_id), lhs_column_id,
                             Operand::Int(lhs_id), Operand::Int(i)})) {
      return 0;
    }

    // Extract column `i` from rhs mat
    auto rhs_column_id = result_op();
    if (!push_function_inst(spv::Op::OpCompositeExtract,
                            {Operand::Int(column_type_id), rhs_column_id,
                             Operand::Int(rhs_id), Operand::Int(i)})) {
      return 0;
    }

    // Add or subtract the two columns
    auto result = result_op();
    if (!push_function_inst(op, {Operand::Int(column_type_id), result,
                                 lhs_column_id, rhs_column_id})) {
      return 0;
    }

    ops.push_back(result);
  }

  // Create the result matrix from the added/subtracted column vectors
  auto result_mat_id = result_op();
  auto mat_type_id = GenerateTypeIfNeeded(type);
  if (mat_type_id == 0) {
    return 0;
  }
  ops.insert(ops.begin(), result_mat_id);
  ops.insert(ops.begin(), Operand::Int(mat_type_id));
  if (!push_function_inst(spv::Op::OpCompositeConstruct, ops)) {
    return 0;
  }

  return result_mat_id.to_i();
}
/// Generates a (non-short-circuiting) binary expression.
/// `&&` / `||` are forwarded to GenerateShortCircuitBinaryExpression.
/// Matrix +/- is expanded per column. For arithmetic between a vector and a
/// scalar, the scalar is splatted into a vector first, except for float
/// vector*scalar multiplication, which uses OpVectorTimesScalar. The SPIR-V
/// opcode is then chosen from the operator and the (LHS) operand type.
/// @param expr the binary expression
/// @returns the id of the result, or 0 on error
uint32_t Builder::GenerateBinaryExpression(const ast::BinaryExpression* expr) {
  // There is special logic for short circuiting operators.
  if (expr->IsLogicalAnd() || expr->IsLogicalOr()) {
    return GenerateShortCircuitBinaryExpression(expr);
  }

  auto lhs_id = GenerateExpression(expr->lhs);
  if (lhs_id == 0) {
    return 0;
  }
  lhs_id = GenerateLoadIfNeeded(TypeOf(expr->lhs), lhs_id);
  if (lhs_id == 0) {
    return 0;
  }

  auto rhs_id = GenerateExpression(expr->rhs);
  if (rhs_id == 0) {
    return 0;
  }
  rhs_id = GenerateLoadIfNeeded(TypeOf(expr->rhs), rhs_id);
  if (rhs_id == 0) {
    return 0;
  }

  auto result = result_op();
  auto result_id = result.to_i();

  auto type_id = GenerateTypeIfNeeded(TypeOf(expr));
  if (type_id == 0) {
    return 0;
  }

  // Handle int and float and the vectors of those types. Other types
  // should have been rejected by validation.
  auto* lhs_type = TypeOf(expr->lhs)->UnwrapRef();
  auto* rhs_type = TypeOf(expr->rhs)->UnwrapRef();

  // Handle matrix-matrix addition and subtraction
  if ((expr->IsAdd() || expr->IsSubtract()) && lhs_type->is_float_matrix() &&
      rhs_type->is_float_matrix()) {
    auto* lhs_mat = lhs_type->As<sem::Matrix>();
    auto* rhs_mat = rhs_type->As<sem::Matrix>();

    // This should already have been validated by resolver
    if (lhs_mat->rows() != rhs_mat->rows() ||
        lhs_mat->columns() != rhs_mat->columns()) {
      error_ = "matrices must have same dimensionality for add or subtract";
      return 0;
    }

    return GenerateMatrixAddOrSub(
        lhs_id, rhs_id, lhs_mat,
        expr->IsAdd() ? spv::Op::OpFAdd : spv::Op::OpFSub);
  }

  // For vector-scalar arithmetic operations, splat scalar into a vector. We
  // skip this for multiply as we can use OpVectorTimesScalar.
  const bool is_float_scalar_vector_multiply =
      expr->IsMultiply() &&
      ((lhs_type->is_float_scalar() && rhs_type->is_float_vector()) ||
       (lhs_type->is_float_vector() && rhs_type->is_float_scalar()));

  if (expr->IsArithmetic() && !is_float_scalar_vector_multiply) {
    if (lhs_type->Is<sem::Vector>() && rhs_type->is_numeric_scalar()) {
      uint32_t splat_vector_id = GenerateSplat(rhs_id, lhs_type);
      if (splat_vector_id == 0) {
        return 0;
      }
      rhs_id = splat_vector_id;
      rhs_type = lhs_type;

    } else if (lhs_type->is_numeric_scalar() && rhs_type->Is<sem::Vector>()) {
      uint32_t splat_vector_id = GenerateSplat(lhs_id, rhs_type);
      if (splat_vector_id == 0) {
        return 0;
      }
      lhs_id = splat_vector_id;
      lhs_type = rhs_type;
    }
  }

  bool lhs_is_float_or_vec = lhs_type->is_float_scalar_or_vector();
  bool lhs_is_bool_or_vec = lhs_type->is_bool_scalar_or_vector();
  bool lhs_is_integer_or_vec = lhs_type->is_integer_scalar_or_vector();
  bool lhs_is_unsigned = lhs_type->is_unsigned_scalar_or_vector();

  spv::Op op = spv::Op::OpNop;
  if (expr->IsAnd()) {
    if (lhs_is_integer_or_vec) {
      op = spv::Op::OpBitwiseAnd;
    } else if (lhs_is_bool_or_vec) {
      op = spv::Op::OpLogicalAnd;
    } else {
      error_ = "invalid and expression";
      return 0;
    }
  } else if (expr->IsAdd()) {
    op = lhs_is_float_or_vec ? spv::Op::OpFAdd : spv::Op::OpIAdd;
  } else if (expr->IsDivide()) {
    if (lhs_is_float_or_vec) {
      op = spv::Op::OpFDiv;
    } else if (lhs_is_unsigned) {
      op = spv::Op::OpUDiv;
    } else {
      op = spv::Op::OpSDiv;
    }
  } else if (expr->IsEqual()) {
    if (lhs_is_float_or_vec) {
      op = spv::Op::OpFOrdEqual;
    } else if (lhs_is_bool_or_vec) {
      op = spv::Op::OpLogicalEqual;
    } else if (lhs_is_integer_or_vec) {
      op = spv::Op::OpIEqual;
    } else {
      error_ = "invalid equal expression";
      return 0;
    }
  } else if (expr->IsGreaterThan()) {
    if (lhs_is_float_or_vec) {
      op = spv::Op::OpFOrdGreaterThan;
    } else if (lhs_is_unsigned) {
      op = spv::Op::OpUGreaterThan;
    } else {
      op = spv::Op::OpSGreaterThan;
    }
  } else if (expr->IsGreaterThanEqual()) {
    if (lhs_is_float_or_vec) {
      op = spv::Op::OpFOrdGreaterThanEqual;
    } else if (lhs_is_unsigned) {
      op = spv::Op::OpUGreaterThanEqual;
    } else {
      op = spv::Op::OpSGreaterThanEqual;
    }
  } else if (expr->IsLessThan()) {
    if (lhs_is_float_or_vec) {
      op = spv::Op::OpFOrdLessThan;
    } else if (lhs_is_unsigned) {
      op = spv::Op::OpULessThan;
    } else {
      op = spv::Op::OpSLessThan;
    }
  } else if (expr->IsLessThanEqual()) {
    if (lhs_is_float_or_vec) {
      op = spv::Op::OpFOrdLessThanEqual;
    } else if (lhs_is_unsigned) {
      op = spv::Op::OpULessThanEqual;
    } else {
      op = spv::Op::OpSLessThanEqual;
    }
  } else if (expr->IsModulo()) {
    if (lhs_is_float_or_vec) {
      op = spv::Op::OpFRem;
    } else if (lhs_is_unsigned) {
      op = spv::Op::OpUMod;
    } else {
      op = spv::Op::OpSMod;
    }
  } else if (expr->IsMultiply()) {
    if (lhs_type->is_integer_scalar_or_vector()) {
      // If the left hand side is an integer then this _has_ to be OpIMul as
      // there there is no other integer multiplication.
      op = spv::Op::OpIMul;
    } else if (lhs_type->is_float_scalar() && rhs_type->is_float_scalar()) {
      // Float scalars multiply with OpFMul
      op = spv::Op::OpFMul;
    } else if (lhs_type->is_float_vector() && rhs_type->is_float_vector()) {
      // Float vectors must be validated to be the same size and then use OpFMul
      op = spv::Op::OpFMul;
    } else if (lhs_type->is_float_scalar() && rhs_type->is_float_vector()) {
      // Scalar * Vector we need to flip lhs and rhs types
      // because OpVectorTimesScalar expects <vector>, <scalar>
      std::swap(lhs_id, rhs_id);
      op = spv::Op::OpVectorTimesScalar;
    } else if (lhs_type->is_float_vector() && rhs_type->is_float_scalar()) {
      // float vector * scalar
      op = spv::Op::OpVectorTimesScalar;
    } else if (lhs_type->is_float_scalar() && rhs_type->is_float_matrix()) {
      // Scalar * Matrix we need to flip lhs and rhs types because
      // OpMatrixTimesScalar expects <matrix>, <scalar>
      std::swap(lhs_id, rhs_id);
      op = spv::Op::OpMatrixTimesScalar;
    } else if (lhs_type->is_float_matrix() && rhs_type->is_float_scalar()) {
      // float matrix * scalar
      op = spv::Op::OpMatrixTimesScalar;
    } else if (lhs_type->is_float_vector() && rhs_type->is_float_matrix()) {
      // float vector * matrix
      op = spv::Op::OpVectorTimesMatrix;
    } else if (lhs_type->is_float_matrix() && rhs_type->is_float_vector()) {
      // float matrix * vector
      op = spv::Op::OpMatrixTimesVector;
    } else if (lhs_type->is_float_matrix() && rhs_type->is_float_matrix()) {
      // float matrix * matrix
      op = spv::Op::OpMatrixTimesMatrix;
    } else {
      error_ = "invalid multiply expression";
      return 0;
    }
  } else if (expr->IsNotEqual()) {
    if (lhs_is_float_or_vec) {
      op = spv::Op::OpFOrdNotEqual;
    } else if (lhs_is_bool_or_vec) {
      op = spv::Op::OpLogicalNotEqual;
    } else if (lhs_is_integer_or_vec) {
      op = spv::Op::OpINotEqual;
    } else {
      error_ = "invalid not-equal expression";
      return 0;
    }
  } else if (expr->IsOr()) {
    if (lhs_is_integer_or_vec) {
      op = spv::Op::OpBitwiseOr;
    } else if (lhs_is_bool_or_vec) {
      op = spv::Op::OpLogicalOr;
    } else {
      // Previously a copy-paste of the `&` message ("invalid and expression").
      error_ = "invalid or expression";
      return 0;
    }
  } else if (expr->IsShiftLeft()) {
    op = spv::Op::OpShiftLeftLogical;
  } else if (expr->IsShiftRight() && lhs_type->is_signed_scalar_or_vector()) {
    // A shift right with a signed LHS is an arithmetic shift.
    op = spv::Op::OpShiftRightArithmetic;
  } else if (expr->IsShiftRight()) {
    op = spv::Op::OpShiftRightLogical;
  } else if (expr->IsSubtract()) {
    op = lhs_is_float_or_vec ? spv::Op::OpFSub : spv::Op::OpISub;
  } else if (expr->IsXor()) {
    op = spv::Op::OpBitwiseXor;
  } else {
    error_ = "unknown binary expression";
    return 0;
  }

  if (!push_function_inst(op, {Operand::Int(type_id), result,
                               Operand::Int(lhs_id), Operand::Int(rhs_id)})) {
    return 0;
  }
  return result_id;
}
/// Generates a block statement inside its own lexical scope.
/// The scope is pushed before the statements are generated. The deferred pop
/// runs on every return path, so the scope stack stays balanced.
/// @param stmt the block statement
/// @returns true on success
bool Builder::GenerateBlockStatement(const ast::BlockStatement* stmt) {
  scope_stack_.Push();
  TINT_DEFER(scope_stack_.Pop());

  const bool generated = GenerateBlockStatementWithoutScoping(stmt);
  return generated;
}
/// Generates every statement of a block in order, in the current scope.
/// Generation stops at the first statement that fails.
/// @param stmt the block statement
/// @returns true if all statements were generated successfully
bool Builder::GenerateBlockStatementWithoutScoping(
    const ast::BlockStatement* stmt) {
  // std::all_of short-circuits, so later statements are not generated once
  // one fails.
  return std::all_of(
      stmt->statements.begin(), stmt->statements.end(),
      [this](const auto* block_stmt) { return GenerateStatement(block_stmt); });
}
/// Generates a call expression.
/// Intrinsic calls are forwarded to GenerateIntrinsic. User function calls
/// are emitted as OpFunctionCall. Each argument is loaded first if it is a
/// reference.
/// @param expr the call expression
/// @returns the id of the call result, or 0 on error
uint32_t Builder::GenerateCallExpression(const ast::CallExpression* expr) {
  auto* ident = expr->func;
  auto* call = builder_.Sem().Get(expr);
  auto* target = call->Target();

  if (auto* intrinsic = target->As<sem::Intrinsic>()) {
    return GenerateIntrinsic(expr, intrinsic);
  }

  auto type_id = GenerateTypeIfNeeded(target->ReturnType());
  if (type_id == 0) {
    return 0;
  }

  auto result = result_op();
  auto result_id = result.to_i();

  // Use find() rather than operator[] so that an unknown function does not
  // insert a spurious zero entry into func_symbol_to_id_.
  uint32_t func_id = 0;
  auto func_it = func_symbol_to_id_.find(ident->symbol);
  if (func_it != func_symbol_to_id_.end()) {
    func_id = func_it->second;
  }
  if (func_id == 0) {
    error_ = "unable to find called function: " +
             builder_.Symbols().NameFor(ident->symbol);
    return 0;
  }

  OperandList ops = {Operand::Int(type_id), result, Operand::Int(func_id)};
  for (auto* arg : expr->args) {
    auto id = GenerateExpression(arg);
    if (id == 0) {
      return 0;
    }
    id = GenerateLoadIfNeeded(TypeOf(arg), id);
    if (id == 0) {
      return 0;
    }
    ops.push_back(Operand::Int(id));
  }

  if (!push_function_inst(spv::Op::OpFunctionCall, std::move(ops))) {
    return 0;
  }

  return result_id;
}
// Generates the SPIR-V for a call to a WGSL intrinsic.
//
// Texture, barrier and atomic intrinsics are delegated to dedicated helpers.
// The following are expanded inline:
//   - any/all on a scalar bool
//   - arrayLength, ignore, isFinite, isNormal, mix, select
//   - abs on unsigned types
// Every other intrinsic maps to a single SPIR-V core instruction or a
// GLSL.std.450 extended instruction. Its operands are the call arguments, in
// order.
// @param call the AST call expression
// @param intrinsic the semantic intrinsic being called
// @returns the ID of the resulting value, or 0 on error
uint32_t Builder::GenerateIntrinsic(const ast::CallExpression* call,
                                    const sem::Intrinsic* intrinsic) {
  auto result = result_op();
  auto result_id = result.to_i();

  auto result_type_id = GenerateTypeIfNeeded(intrinsic->ReturnType());
  if (result_type_id == 0) {
    return 0;
  }

  if (intrinsic->IsFineDerivative() || intrinsic->IsCoarseDerivative()) {
    push_capability(SpvCapabilityDerivativeControl);
  }

  if (intrinsic->IsImageQuery()) {
    push_capability(SpvCapabilityImageQuery);
  }

  if (intrinsic->IsTexture()) {
    if (!GenerateTextureIntrinsic(call, intrinsic, Operand::Int(result_type_id),
                                  result)) {
      return 0;
    }
    return result_id;
  }

  if (intrinsic->IsBarrier()) {
    if (!GenerateControlBarrierIntrinsic(intrinsic)) {
      return 0;
    }
    return result_id;
  }

  if (intrinsic->IsAtomic()) {
    if (!GenerateAtomicIntrinsic(call, intrinsic, Operand::Int(result_type_id),
                                 result)) {
      return 0;
    }
    return result_id;
  }

  // Generates the SPIR-V ID for the expression for the indexed call argument,
  // and loads it if necessary. Returns 0 on error.
  auto get_arg_as_value_id = [&](size_t i,
                                 bool generate_load = true) -> uint32_t {
    auto* arg = call->args[i];
    auto* param = intrinsic->Parameters()[i];
    auto val_id = GenerateExpression(arg);
    if (val_id == 0) {
      return 0;
    }

    // Pointer parameters take the pointer itself, not the pointee value.
    if (generate_load && !param->Type()->Is<sem::Pointer>()) {
      val_id = GenerateLoadIfNeeded(TypeOf(arg), val_id);
    }
    return val_id;
  };

  OperandList params = {Operand::Int(result_type_id), result};
  spv::Op op = spv::Op::OpNop;

  // Pushes the arguments for a GlslStd450 extended instruction, and sets op
  // to OpExtInst.
  auto glsl_std450 = [&](uint32_t inst_id) {
    auto set_id = GetGLSLstd450Import();
    params.push_back(Operand::Int(set_id));
    params.push_back(Operand::Int(inst_id));
    op = spv::Op::OpExtInst;
  };

  switch (intrinsic->Type()) {
    case IntrinsicType::kAny:
      if (intrinsic->Parameters()[0]->Type()->Is<sem::Bool>()) {
        // any(v: bool) just resolves to v.
        return get_arg_as_value_id(0);
      }
      op = spv::Op::OpAny;
      break;
    case IntrinsicType::kAll:
      if (intrinsic->Parameters()[0]->Type()->Is<sem::Bool>()) {
        // all(v: bool) just resolves to v.
        return get_arg_as_value_id(0);
      }
      op = spv::Op::OpAll;
      break;
    case IntrinsicType::kArrayLength: {
      if (call->args.empty()) {
        error_ = "missing param for runtime array length";
        return 0;
      }
      auto* arg = call->args[0];

      auto* address_of = arg->As<ast::UnaryOpExpression>();
      if (!address_of || address_of->op != ast::UnaryOp::kAddressOf) {
        // Report the argument's own node type. `address_of` may be null here,
        // so it must not be dereferenced.
        error_ = "arrayLength() expected pointer to member access, got " +
                 std::string(arg->TypeInfo().name);
        return 0;
      }
      auto* array_expr = address_of->expr;

      auto* accessor = array_expr->As<ast::MemberAccessorExpression>();
      if (!accessor) {
        error_ =
            "arrayLength() expected pointer to member access, got pointer to " +
            std::string(array_expr->TypeInfo().name);
        return 0;
      }

      auto struct_id = GenerateExpression(accessor->structure);
      if (struct_id == 0) {
        return 0;
      }
      params.push_back(Operand::Int(struct_id));

      auto* type = TypeOf(accessor->structure)->UnwrapRef();
      if (!type->Is<sem::Struct>()) {
        error_ =
            "invalid type (" + type->type_name() + ") for runtime array length";
        return 0;
      }
      // Runtime array must be the last member in the structure
      params.push_back(Operand::Int(uint32_t(
          type->As<sem::Struct>()->Declaration()->members.size() - 1)));

      if (!push_function_inst(spv::Op::OpArrayLength, params)) {
        return 0;
      }
      return result_id;
    }
    case IntrinsicType::kCountOneBits:
      op = spv::Op::OpBitCount;
      break;
    case IntrinsicType::kDot:
      op = spv::Op::OpDot;
      break;
    case IntrinsicType::kDpdx:
      op = spv::Op::OpDPdx;
      break;
    case IntrinsicType::kDpdxCoarse:
      op = spv::Op::OpDPdxCoarse;
      break;
    case IntrinsicType::kDpdxFine:
      op = spv::Op::OpDPdxFine;
      break;
    case IntrinsicType::kDpdy:
      op = spv::Op::OpDPdy;
      break;
    case IntrinsicType::kDpdyCoarse:
      op = spv::Op::OpDPdyCoarse;
      break;
    case IntrinsicType::kDpdyFine:
      op = spv::Op::OpDPdyFine;
      break;
    case IntrinsicType::kFwidth:
      op = spv::Op::OpFwidth;
      break;
    case IntrinsicType::kFwidthCoarse:
      op = spv::Op::OpFwidthCoarse;
      break;
    case IntrinsicType::kFwidthFine:
      op = spv::Op::OpFwidthFine;
      break;
    case IntrinsicType::kIgnore:  // [DEPRECATED]
      // Evaluate the single argument, return the non-zero result_id which isn't
      // associated with any op (ignore returns void, so this cannot be used in
      // an expression).
      if (!get_arg_as_value_id(0, false)) {
        return 0;
      }
      return result_id;
    case IntrinsicType::kIsInf:
      op = spv::Op::OpIsInf;
      break;
    case IntrinsicType::kIsNan:
      op = spv::Op::OpIsNan;
      break;
    case IntrinsicType::kIsFinite: {
      // Implemented as: not(IsInf or IsNan)
      auto val_id = get_arg_as_value_id(0);
      if (!val_id) {
        return 0;
      }
      auto inf_result = result_op();
      auto nan_result = result_op();
      auto or_result = result_op();
      if (push_function_inst(spv::Op::OpIsInf,
                             {Operand::Int(result_type_id), inf_result,
                              Operand::Int(val_id)}) &&
          push_function_inst(spv::Op::OpIsNan,
                             {Operand::Int(result_type_id), nan_result,
                              Operand::Int(val_id)}) &&
          push_function_inst(spv::Op::OpLogicalOr,
                             {Operand::Int(result_type_id), or_result,
                              Operand::Int(inf_result.to_i()),
                              Operand::Int(nan_result.to_i())}) &&
          push_function_inst(spv::Op::OpLogicalNot,
                             {Operand::Int(result_type_id), result,
                              Operand::Int(or_result.to_i())})) {
        return result_id;
      }
      return 0;
    }
    case IntrinsicType::kIsNormal: {
      // A normal number is finite, non-zero, and not subnormal.
      // Its biased exponent field is neither 0 nor all-ones.
      // Implemented as:
      //   exponent_bits = bitcast<u32>(f) & kExponentMask;
      //   clamped = uclamp(exponent_bits, kMinNormalExponent,
      //                    kMaxNormalExponent);
      //   result = (clamped == exponent_bits);
      //
      auto val_id = get_arg_as_value_id(0);
      if (!val_id) {
        return 0;
      }

      // These parameters are valid for IEEE 754 binary32. The exponent field
      // occupies bits 23..30. A normal number has a biased exponent in
      // [1, 254], i.e. (1 << 23) .. (254 << 23).
      constexpr uint32_t kExponentMask = 0x7f800000u;
      constexpr uint32_t kMinNormalExponent = 0x00800000u;
      constexpr uint32_t kMaxNormalExponent = 0x7f000000u;

      auto set_id = GetGLSLstd450Import();
      auto* u32 = builder_.create<sem::U32>();

      auto unsigned_id = GenerateTypeIfNeeded(u32);
      auto exponent_mask_id =
          GenerateConstantIfNeeded(ScalarConstant::U32(kExponentMask));
      auto min_exponent_id =
          GenerateConstantIfNeeded(ScalarConstant::U32(kMinNormalExponent));
      auto max_exponent_id =
          GenerateConstantIfNeeded(ScalarConstant::U32(kMaxNormalExponent));
      if (auto* fvec_ty = intrinsic->ReturnType()->As<sem::Vector>()) {
        // In the vector case, update the unsigned type to a vector type of the
        // same size, and create vector constants by replicating the scalars.
        // I expect backend compilers to fold these into unique constants, so
        // there is no loss of efficiency.
        auto* uvec_ty = builder_.create<sem::Vector>(u32, fvec_ty->Width());
        unsigned_id = GenerateTypeIfNeeded(uvec_ty);
        auto splat = [&](uint32_t scalar_id) -> uint32_t {
          auto splat_result = result_op();
          OperandList splat_params{Operand::Int(unsigned_id), splat_result};
          for (size_t i = 0; i < fvec_ty->Width(); i++) {
            splat_params.emplace_back(Operand::Int(scalar_id));
          }
          if (!push_function_inst(spv::Op::OpCompositeConstruct,
                                  std::move(splat_params))) {
            return 0;
          }
          return splat_result.to_i();
        };
        exponent_mask_id = splat(exponent_mask_id);
        min_exponent_id = splat(min_exponent_id);
        max_exponent_id = splat(max_exponent_id);
      }
      auto cast_result = result_op();
      auto exponent_bits_result = result_op();
      auto clamp_result = result_op();

      if (set_id && unsigned_id && exponent_mask_id && min_exponent_id &&
          max_exponent_id &&
          push_function_inst(
              spv::Op::OpBitcast,
              {Operand::Int(unsigned_id), cast_result, Operand::Int(val_id)}) &&
          push_function_inst(spv::Op::OpBitwiseAnd,
                             {Operand::Int(unsigned_id), exponent_bits_result,
                              Operand::Int(cast_result.to_i()),
                              Operand::Int(exponent_mask_id)}) &&
          push_function_inst(
              spv::Op::OpExtInst,
              {Operand::Int(unsigned_id), clamp_result, Operand::Int(set_id),
               Operand::Int(GLSLstd450UClamp),
               Operand::Int(exponent_bits_result.to_i()),
               Operand::Int(min_exponent_id), Operand::Int(max_exponent_id)}) &&
          push_function_inst(spv::Op::OpIEqual,
                             {Operand::Int(result_type_id), result,
                              Operand::Int(exponent_bits_result.to_i()),
                              Operand::Int(clamp_result.to_i())})) {
        return result_id;
      }
      return 0;
    }
    case IntrinsicType::kMix: {
      auto std450 = Operand::Int(GetGLSLstd450Import());

      auto a_id = get_arg_as_value_id(0);
      auto b_id = get_arg_as_value_id(1);
      auto f_id = get_arg_as_value_id(2);
      if (!a_id || !b_id || !f_id) {
        return 0;
      }

      // If the interpolant is scalar but the objects are vectors, we need to
      // splat the interpolant into a vector of the same size.
      auto* result_vector_type = intrinsic->ReturnType()->As<sem::Vector>();
      if (result_vector_type &&
          intrinsic->Parameters()[2]->Type()->is_scalar()) {
        f_id = GenerateSplat(f_id, intrinsic->Parameters()[0]->Type());
        if (f_id == 0) {
          return 0;
        }
      }

      if (!push_function_inst(spv::Op::OpExtInst,
                              {Operand::Int(result_type_id), result, std450,
                               Operand::Int(GLSLstd450FMix), Operand::Int(a_id),
                               Operand::Int(b_id), Operand::Int(f_id)})) {
        return 0;
      }
      return result_id;
    }
    case IntrinsicType::kReverseBits:
      op = spv::Op::OpBitReverse;
      break;
    case IntrinsicType::kSelect: {
      // Note: Argument order is different in WGSL and SPIR-V
      auto cond_id = get_arg_as_value_id(2);
      auto true_id = get_arg_as_value_id(1);
      auto false_id = get_arg_as_value_id(0);
      if (!cond_id || !true_id || !false_id) {
        return 0;
      }

      // If the condition is scalar but the objects are vectors, we need to
      // splat the condition into a vector of the same size.
      // TODO(jrprice): If we're targeting SPIR-V 1.4, we don't need to do this.
      auto* result_vector_type = intrinsic->ReturnType()->As<sem::Vector>();
      if (result_vector_type &&
          intrinsic->Parameters()[2]->Type()->is_scalar()) {
        auto* bool_vec_ty = builder_.create<sem::Vector>(
            builder_.create<sem::Bool>(), result_vector_type->Width());
        if (!GenerateTypeIfNeeded(bool_vec_ty)) {
          return 0;
        }
        cond_id = GenerateSplat(cond_id, bool_vec_ty);
        if (cond_id == 0) {
          return 0;
        }
      }

      if (!push_function_inst(
              spv::Op::OpSelect,
              {Operand::Int(result_type_id), result, Operand::Int(cond_id),
               Operand::Int(true_id), Operand::Int(false_id)})) {
        return 0;
      }
      return result_id;
    }
    case IntrinsicType::kTranspose:
      op = spv::Op::OpTranspose;
      break;
    case IntrinsicType::kAbs:
      if (intrinsic->ReturnType()->is_unsigned_scalar_or_vector()) {
        // abs() only operates on *signed* integers.
        // This is a no-op for unsigned integers.
        return get_arg_as_value_id(0);
      }
      if (intrinsic->ReturnType()->is_float_scalar_or_vector()) {
        glsl_std450(GLSLstd450FAbs);
      } else {
        glsl_std450(GLSLstd450SAbs);
      }
      break;
    default: {
      auto inst_id = intrinsic_to_glsl_method(intrinsic);
      if (inst_id == 0) {
        error_ = "unknown method " + std::string(intrinsic->str());
        return 0;
      }
      glsl_std450(inst_id);
      break;
    }
  }

  if (op == spv::Op::OpNop) {
    error_ =
        "unable to determine operator for: " + std::string(intrinsic->str());
    return 0;
  }

  // Generic path: append every call argument, in order, as an operand.
  for (size_t i = 0; i < call->args.size(); i++) {
    if (auto val_id = get_arg_as_value_id(i)) {
      params.emplace_back(Operand::Int(val_id));
    } else {
      return 0;
    }
  }

  if (!push_function_inst(op, params)) {
    return 0;
  }

  return result_id;
}
bool Builder::GenerateTextureIntrinsic(const ast::CallExpression* call,
const sem::Intrinsic* intrinsic,
Operand result_type,
Operand result_id) {
using Usage = sem::ParameterUsage;
auto& signature = intrinsic->Signature();
auto arguments = call->args;
// Generates the given expression, returning the operand ID
auto gen = [&](const ast::Expression* expr) {
auto val_id = GenerateExpression(expr);
if (val_id == 0) {
return Operand::Int(0);
}
val_id = GenerateLoadIfNeeded(TypeOf(expr), val_id);
return Operand::Int(val_id);
};