blob: 29f504adeceb8c777cb52c831846a40b874fbe51 [file] [log] [blame]
// Copyright 2023 The Dawn & Tint Authors
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "src/tint/lang/spirv/reader/parser/parser.h"
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
TINT_BEGIN_DISABLE_WARNING(NEWLINE_EOF);
TINT_BEGIN_DISABLE_WARNING(OLD_STYLE_CAST);
TINT_BEGIN_DISABLE_WARNING(SIGN_CONVERSION);
TINT_BEGIN_DISABLE_WARNING(WEAK_VTABLES);
TINT_BEGIN_DISABLE_WARNING(UNSAFE_BUFFER_USAGE);
#include "source/opt/build_module.h"
TINT_END_DISABLE_WARNING(UNSAFE_BUFFER_USAGE);
TINT_END_DISABLE_WARNING(WEAK_VTABLES);
TINT_END_DISABLE_WARNING(SIGN_CONVERSION);
TINT_END_DISABLE_WARNING(OLD_STYLE_CAST);
TINT_END_DISABLE_WARNING(NEWLINE_EOF);
#include "src/tint/lang/core/ir/builder.h"
#include "src/tint/lang/core/ir/module.h"
#include "src/tint/lang/core/type/builtin_structs.h"
#include "src/tint/lang/spirv/builtin_fn.h"
#include "src/tint/lang/spirv/ir/builtin_call.h"
#include "src/tint/lang/spirv/validate/validate.h"
using namespace tint::core::number_suffixes; // NOLINT
using namespace tint::core::fluent_types; // NOLINT
namespace tint::spirv::reader {
namespace {
/// The SPIR-V environment that we validate against.
/// The same environment is also used to create the SPIRV-Tools context and to build the
/// in-memory module in Parser::Run().
constexpr auto kTargetEnv = SPV_ENV_VULKAN_1_1;
/// PIMPL class for SPIR-V parser.
/// Validates the SPIR-V module and then parses it to produce a Tint IR module.
class Parser {
public:
/// Validates the SPIR-V binary, builds the SPIRV-Tools in-memory module, checks the extensions
/// and extended instruction set imports it uses, and then emits the Tint IR for it.
/// @param spirv the SPIR-V binary data
/// @returns the generated Tint IR module on success, or failure
Result<core::ir::Module> Run(Slice<const uint32_t> spirv) {
    // Validate the incoming SPIR-V binary.
    auto result = validate::Validate(spirv, kTargetEnv);
    if (result != Success) {
        return result.Failure();
    }

    // Build the SPIR-V tools internal representation of the SPIR-V module.
    spvtools::Context context(kTargetEnv);
    spirv_context_ =
        spvtools::BuildModule(kTargetEnv, context.CContext()->consumer, spirv.data, spirv.len);
    if (!spirv_context_) {
        return Failure("failed to build the internal representation of the module");
    }

    // Check for unsupported extensions.
    for (const auto& ext : spirv_context_->extensions()) {
        auto name = ext.GetOperand(0).AsString();
        if (name != "SPV_KHR_storage_buffer_storage_class" &&
            name != "SPV_KHR_non_semantic_info") {
            return Failure("SPIR-V extension '" + name + "' is not supported");
        }
    }

    // Register imported instruction sets
    for (const auto& import : spirv_context_->ext_inst_imports()) {
        auto name = import.GetInOperand(0).AsString();
        // TODO(dneto): Handle other extended instruction sets when needed.
        if (name == "GLSL.std.450") {
            glsl_std_450_imports_.insert(import.result_id());
        } else if (name.find("NonSemantic.") == 0) {
            ignored_imports_.insert(import.result_id());
        } else {
            return Failure("Unrecognized extended instruction set: " + name);
        }
    }

    // Register OpName / OpMemberName names before emitting anything. Module-scope emission can
    // already convert struct types (e.g. through Type() for OpUndef). Converted types are
    // memoized, so a struct first converted before its names were registered would stay
    // unnamed for the whole module.
    RegisterNames();

    {
        TINT_SCOPED_ASSIGNMENT(current_block_, ir_.root_block);
        EmitModuleScopeVariables();
    }

    EmitFunctions();
    EmitEntryPointAttributes();

    // TODO(crbug.com/tint/1907): Handle annotation instructions.

    return std::move(ir_);
}
/// Records the debug names from the module's OpName and OpMemberName instructions.
/// OpName entries go into id_to_name_, keyed by target ID. OpMemberName entries go into
/// struct_to_member_names_ as a per-struct list indexed by member; the list is grown as needed.
/// Empty names are ignored.
void RegisterNames() {
    for (const auto& inst : spirv_context_->debugs2()) {
        const auto op = inst.opcode();
        if (op == spv::Op::OpName) {
            const auto name = inst.GetInOperand(1).AsString();
            if (name.empty()) {
                continue;
            }
            id_to_name_[inst.GetSingleWordInOperand(0)] = name;
        } else if (op == spv::Op::OpMemberName) {
            const auto name = inst.GetInOperand(2).AsString();
            if (name.empty()) {
                continue;
            }
            const uint32_t target_struct = inst.GetSingleWordInOperand(0);
            const uint32_t index = inst.GetSingleWordInOperand(1);
            // Create an empty list on first sight of this struct, then make room for `index`.
            auto& names = struct_to_member_names_.insert({target_struct, {}}).first->second;
            if (names.size() < (index + 1)) {
                names.resize(index + 1);
            }
            names[index] = name;
        }
    }
}
/// Maps a SPIR-V storage class onto the corresponding Tint address space.
/// Storage classes without a mapping raise TINT_UNIMPLEMENTED.
/// @param sc a SPIR-V storage class
/// @returns the Tint address space for a SPIR-V storage class
core::AddressSpace AddressSpace(spv::StorageClass sc) {
    using SC = spv::StorageClass;
    using AS = core::AddressSpace;
    switch (sc) {
        case SC::Input:            return AS::kIn;
        case SC::Output:           return AS::kOut;
        case SC::Function:         return AS::kFunction;
        case SC::Private:          return AS::kPrivate;
        case SC::StorageBuffer:    return AS::kStorage;
        case SC::Uniform:          return AS::kUniform;
        case SC::UniformConstant:  return AS::kHandle;
        default:
            TINT_UNIMPLEMENTED()
                << "unhandled SPIR-V storage class: " << static_cast<uint32_t>(sc);
    }
}
/// Maps a SPIR-V BuiltIn decoration value onto the corresponding Tint builtin value.
/// Both FragCoord and Position map to kPosition. BuiltIns without a mapping raise
/// TINT_UNIMPLEMENTED.
/// @param b a SPIR-V BuiltIn
/// @returns the Tint builtin value for a SPIR-V BuiltIn decoration
core::BuiltinValue Builtin(spv::BuiltIn b) {
    using In = spv::BuiltIn;
    using Out = core::BuiltinValue;
    switch (b) {
        case In::FragCoord:
        case In::Position:
            return Out::kPosition;
        case In::FragDepth:
            return Out::kFragDepth;
        case In::FrontFacing:
            return Out::kFrontFacing;
        case In::GlobalInvocationId:
            return Out::kGlobalInvocationId;
        case In::InstanceIndex:
            return Out::kInstanceIndex;
        case In::LocalInvocationId:
            return Out::kLocalInvocationId;
        case In::LocalInvocationIndex:
            return Out::kLocalInvocationIndex;
        case In::NumWorkgroups:
            return Out::kNumWorkgroups;
        case In::PointSize:
            return Out::kPointSize;
        case In::SampleId:
            return Out::kSampleIndex;
        case In::SampleMask:
            return Out::kSampleMask;
        case In::VertexIndex:
            return Out::kVertexIndex;
        case In::WorkgroupId:
            return Out::kWorkgroupId;
        case In::ClipDistance:
            return Out::kClipDistances;
        case In::CullDistance:
            return Out::kCullDistance;
        default:
            TINT_UNIMPLEMENTED() << "unhandled SPIR-V BuiltIn: " << static_cast<uint32_t>(b);
    }
}
/// Converts a SPIR-V type into the equivalent Tint type.
/// Results are memoized in types_, keyed on the (SPIR-V type, access mode) pair, so each pair is
/// converted at most once. Composite types recurse into Type() for their element / member /
/// pointee types; arrays and structs are delegated to EmitArray() and EmitStruct().
/// NOTE(review): the recursion happens inside the types_.GetOrAdd() create callback, which relies
/// on GetOrAdd tolerating re-entrant insertion into the same map — confirm.
/// @param type a SPIR-V type object
/// @param access_mode an optional access mode (for pointers)
/// @returns a Tint type object
const core::type::Type* Type(const spvtools::opt::analysis::Type* type,
core::Access access_mode = core::Access::kUndefined) {
return types_.GetOrAdd(TypeKey{type, access_mode}, [&]() -> const core::type::Type* {
switch (type->kind()) {
case spvtools::opt::analysis::Type::kVoid:
return ty_.void_();
case spvtools::opt::analysis::Type::kBool:
return ty_.bool_();
case spvtools::opt::analysis::Type::kInteger: {
// Only 32-bit integers are supported.
auto* int_ty = type->AsInteger();
TINT_ASSERT(int_ty->width() == 32);
if (int_ty->IsSigned()) {
return ty_.i32();
} else {
return ty_.u32();
}
}
case spvtools::opt::analysis::Type::kFloat: {
// Only 16-bit and 32-bit floats are supported.
auto* float_ty = type->AsFloat();
if (float_ty->width() == 16) {
return ty_.f16();
} else if (float_ty->width() == 32) {
return ty_.f32();
} else {
TINT_UNREACHABLE()
<< "unsupported floating point type width: " << float_ty->width();
}
}
case spvtools::opt::analysis::Type::kVector: {
auto* vec_ty = type->AsVector();
TINT_ASSERT(vec_ty->element_count() <= 4);
return ty_.vec(Type(vec_ty->element_type()), vec_ty->element_count());
}
case spvtools::opt::analysis::Type::kMatrix: {
// The matrix column type must convert to a Tint vector.
auto* mat_ty = type->AsMatrix();
TINT_ASSERT(mat_ty->element_count() <= 4);
return ty_.mat(As<core::type::Vector>(Type(mat_ty->element_type())),
mat_ty->element_count());
}
case spvtools::opt::analysis::Type::kArray:
return EmitArray(type->AsArray());
case spvtools::opt::analysis::Type::kStruct:
return EmitStruct(type->AsStruct());
case spvtools::opt::analysis::Type::kPointer: {
// The access mode only applies to the pointer itself; the pointee is converted
// with the default (undefined) access mode.
auto* ptr_ty = type->AsPointer();
return ty_.ptr(AddressSpace(ptr_ty->storage_class()),
Type(ptr_ty->pointee_type()), access_mode);
}
case spvtools::opt::analysis::Type::kSampler: {
// TODO(dsinclair): How to determine comparison samplers ...
return ty_.sampler();
}
default:
TINT_UNIMPLEMENTED() << "unhandled SPIR-V type: " << type->str();
}
});
}
/// Looks up the SPIR-V type declared by @p id and converts it to a Tint type.
/// @param id a SPIR-V result ID for a type declaration instruction
/// @param access_mode an optional access mode (for pointers)
/// @returns a Tint type object
const core::type::Type* Type(uint32_t id, core::Access access_mode = core::Access::kUndefined) {
    const auto* spv_type = spirv_context_->get_type_mgr()->GetType(id);
    return Type(spv_type, access_mode);
}
/// Converts a fixed-length SPIR-V array type into a Tint array type.
/// Only arrays whose length is a plain constant are supported. Specialization-constant lengths
/// raise TINT_UNIMPLEMENTED.
/// @param arr_ty a SPIR-V array object
/// @returns a Tint array object
const core::type::Type* EmitArray(const spvtools::opt::analysis::Array* arr_ty) {
    const auto& length_info = arr_ty->length_info();
    TINT_ASSERT(!length_info.words.empty());

    // The first word describes how the length is specified.
    const bool is_constant_length =
        length_info.words[0] == spvtools::opt::analysis::Array::LengthInfo::kConstant;
    if (!is_constant_length) {
        TINT_UNIMPLEMENTED() << "specialized array lengths";
    }

    // The element count is the value of the constant referenced by the length ID.
    auto* constant_mgr = spirv_context_->get_constant_mgr();
    const auto* length_const = constant_mgr->FindDeclaredConstant(length_info.id);
    TINT_ASSERT(length_const);
    const uint64_t element_count = length_const->GetZeroExtendedValue();
    TINT_ASSERT(element_count <= UINT32_MAX);

    // TODO(crbug.com/tint/1907): Handle decorations that affect the array layout.
    auto* element_ty = Type(arr_ty->element_type());
    return ty_.array(element_ty, static_cast<uint32_t>(element_count));
}
/// Converts a SPIR-V struct type into a Tint struct type.
/// Member offsets default to the next naturally-aligned position and can be overridden by an
/// Offset decoration. BuiltIn, Invariant, Location and interpolation decorations become IO
/// attributes; any other member decoration raises TINT_UNIMPLEMENTED. Names come from
/// OpMemberName / OpName when they were registered; otherwise fresh symbols are generated.
/// @param struct_ty a SPIR-V struct object
/// @returns a Tint struct object
const core::type::Type* EmitStruct(const spvtools::opt::analysis::Struct* struct_ty) {
    if (struct_ty->NumberOfComponents() == 0) {
        TINT_ICE() << "empty structures are not supported";
    }

    auto* type_mgr = spirv_context_->get_type_mgr();
    auto struct_id = type_mgr->GetId(struct_ty);

    // Member names registered from OpMemberName, if any.
    const std::vector<std::string>* member_names = nullptr;
    if (auto names_it = struct_to_member_names_.find(struct_id);
        names_it != struct_to_member_names_.end()) {
        member_names = &names_it->second;
    }

    // Per-member decorations, keyed by member index.
    const auto& element_decorations = struct_ty->element_decorations();

    // Build a list of struct members.
    uint32_t current_size = 0u;
    Vector<core::type::StructMember*, 4> members;
    for (uint32_t i = 0; i < struct_ty->NumberOfComponents(); i++) {
        auto* member_ty = Type(struct_ty->element_types()[i]);
        // Default layout: place the member at the next position aligned to its alignment.
        uint32_t align = std::max<uint32_t>(member_ty->Align(), 1u);
        uint32_t offset = tint::RoundUp(align, current_size);

        core::IOAttributes attributes;
        auto interpolation = [&]() -> core::Interpolation& {
            // Create the interpolation field with the default values on first call.
            if (!attributes.interpolation.has_value()) {
                attributes.interpolation =
                    core::Interpolation{core::InterpolationType::kPerspective,
                                        core::InterpolationSampling::kCenter};
            }
            return attributes.interpolation.value();
        };

        // Handle member decorations that affect layout or attributes.
        // A single find() avoids the separate count() + at() lookups.
        if (auto deco_it = element_decorations.find(i); deco_it != element_decorations.end()) {
            for (auto& deco : deco_it->second) {
                switch (spv::Decoration(deco[0])) {
                    case spv::Decoration::Offset:
                        offset = deco[1];
                        break;
                    case spv::Decoration::BuiltIn:
                        attributes.builtin = Builtin(spv::BuiltIn(deco[1]));
                        break;
                    case spv::Decoration::Invariant:
                        attributes.invariant = true;
                        break;
                    case spv::Decoration::Location:
                        attributes.location = deco[1];
                        break;
                    case spv::Decoration::NoPerspective:
                        interpolation().type = core::InterpolationType::kLinear;
                        break;
                    case spv::Decoration::Flat:
                        interpolation().type = core::InterpolationType::kFlat;
                        break;
                    case spv::Decoration::Centroid:
                        interpolation().sampling = core::InterpolationSampling::kCentroid;
                        break;
                    case spv::Decoration::Sample:
                        interpolation().sampling = core::InterpolationSampling::kSample;
                        break;
                    default:
                        TINT_UNIMPLEMENTED() << "unhandled member decoration: " << deco[0];
                }
            }
        }

        // Use the registered member name when there is a non-empty one, otherwise a fresh symbol.
        Symbol name;
        if (member_names && member_names->size() > i) {
            const auto& n = (*member_names)[i];  // bind by reference: no string copy
            if (!n.empty()) {
                name = ir_.symbols.Register(n);
            }
        }
        if (!name.IsValid()) {
            name = ir_.symbols.New();
        }

        members.Push(ty_.Get<core::type::StructMember>(
            name, member_ty, i, offset, align, member_ty->Size(), std::move(attributes)));
        current_size = offset + member_ty->Size();
    }

    Symbol name = GetUniqueSymbolFor(struct_id);
    if (!name.IsValid()) {
        name = ir_.symbols.New();
    }
    return ty_.Struct(name, std::move(members));
}
/// Creates a new symbol from the OpName registered for @p id, via `ir_.symbols.New()`.
/// NOTE(review): New() presumably de-duplicates clashing names — confirm.
/// @param id a SPIR-V result ID
/// @returns the new symbol, or an invalid Symbol if no name was registered for @p id
Symbol GetUniqueSymbolFor(uint32_t id) {
    auto found = id_to_name_.find(id);
    if (found == id_to_name_.end()) {
        return Symbol{};
    }
    return ir_.symbols.New(found->second);
}
/// Returns the symbol for the OpName registered for @p id, via `ir_.symbols.Register()`.
/// @param id a SPIR-V result ID
/// @returns the registered symbol, or an invalid Symbol if no name was registered for @p id
Symbol GetSymbolFor(uint32_t id) {
    auto found = id_to_name_.find(id);
    if (found == id_to_name_.end()) {
        return Symbol{};
    }
    return ir_.symbols.Register(found->second);
}
/// Returns the IR function for a SPIR-V function ID.
/// On first use of the ID, a function with a void return type is created and cached in
/// functions_.
/// @param id a SPIR-V result ID for a function declaration instruction
/// @returns a Tint function object
core::ir::Function* Function(uint32_t id) {
    auto make_function = [&] { return b_.Function(ty_.void_()); };
    return functions_.GetOrAdd(id, make_function);
}
/// Returns the IR value for a SPIR-V result ID.
/// IDs not already registered with AddValue() must name a declared constant. That constant is
/// converted, wrapped in an IR constant and cached. Any other unknown ID is unreachable.
/// @param id a SPIR-V result ID
/// @returns a Tint value object
core::ir::Value* Value(uint32_t id) {
    return values_.GetOrAdd(id, [&]() -> core::ir::Value* {
        auto* constant_mgr = spirv_context_->get_constant_mgr();
        const auto* declared = constant_mgr->FindDeclaredConstant(id);
        if (!declared) {
            TINT_UNREACHABLE() << "missing value for result ID " << id;
        }
        return b_.Constant(Constant(declared));
    });
}
/// Converts a SPIR-V constant into the equivalent Tint constant value.
/// Null constants become the zero value of their type. Bool, 32-bit integer and 16/32-bit float
/// scalars are converted directly. Vector, matrix, array and struct constants are converted
/// component by component. Anything else raises TINT_UNIMPLEMENTED.
/// @param constant a SPIR-V constant object
/// @returns a Tint constant value
const core::constant::Value* Constant(const spvtools::opt::analysis::Constant* constant) {
    // OpConstantNull is accepted for every type.
    if (constant->AsNullConstant()) {
        return ir_.constant_values.Zero(Type(constant->type()));
    }

    if (const auto* boolean = constant->AsBoolConstant()) {
        return b_.ConstantValue(boolean->value());
    }

    if (const auto* integer = constant->AsIntConstant()) {
        const auto* int_ty = integer->type()->AsInteger();
        TINT_ASSERT(int_ty->width() == 32);
        if (!int_ty->IsSigned()) {
            return b_.ConstantValue(u32(integer->GetU32BitValue()));
        }
        return b_.ConstantValue(i32(integer->GetS32BitValue()));
    }

    if (const auto* floating = constant->AsFloatConstant()) {
        const auto* float_ty = floating->type()->AsFloat();
        const auto width = float_ty->width();
        if (width == 16) {
            // Half-precision values are stored in the low 16 bits of the first word.
            return b_.ConstantValue(f16::FromBits(static_cast<uint16_t>(floating->words()[0])));
        }
        if (width == 32) {
            return b_.ConstantValue(f32(floating->GetFloat()));
        }
        TINT_UNREACHABLE() << "unsupported floating point type width";
    }

    // Every composite kind is converted the same way: convert each component, then build a
    // composite of the converted Tint type.
    auto composite = [&](const auto* spv_composite) -> const core::constant::Value* {
        Vector<const core::constant::Value*, 16> components;
        for (auto* component : spv_composite->GetComponents()) {
            components.Push(Constant(component));
        }
        return ir_.constant_values.Composite(Type(spv_composite->type()), std::move(components));
    };

    if (const auto* vec = constant->AsVectorConstant()) {
        return composite(vec);
    }
    if (const auto* mat = constant->AsMatrixConstant()) {
        return composite(mat);
    }
    if (const auto* arr = constant->AsArrayConstant()) {
        return composite(arr);
    }
    if (const auto* str = constant->AsStructConstant()) {
        return composite(str);
    }

    TINT_UNIMPLEMENTED() << "unhandled constant type";
}
/// Associates the IR value @p value with the SPIR-V result ID @p result_id, so that later
/// Value(result_id) lookups return it.
/// @param result_id the SPIR-V result ID
/// @param value the IR value
void AddValue(uint32_t result_id, core::ir::Value* value) {
    values_.Add(result_id, value);
}
/// Appends @p inst to the current block and registers its single result as the value for the
/// SPIR-V result ID @p result_id. If an OpName was registered for the ID, it is applied to the
/// instruction.
/// @param inst the instruction to emit; must have exactly one result
/// @param result_id the SPIR-V result ID to register the instruction result for
void Emit(core::ir::Instruction* inst, uint32_t result_id) {
    current_block_->Append(inst);
    TINT_ASSERT(inst->Results().Length() == 1u);
    AddValue(result_id, inst->Result(0));

    if (Symbol debug_name = GetSymbolFor(result_id); debug_name.IsValid()) {
        ir_.SetName(inst, debug_name);
    }
}
/// Appends @p inst to the current block without mapping its result to any SPIR-V result ID.
/// Used for intermediate IR instructions that have one result but no SPIR-V counterpart.
/// @param inst the instruction to emit; must have exactly one result
void EmitWithoutSpvResult(core::ir::Instruction* inst) {
    auto* block = current_block_;
    block->Append(inst);
    TINT_ASSERT(inst->Results().Length() == 1u);
}
/// Appends a result-less instruction (e.g. a terminator, store or control instruction) to the
/// current block.
/// @param inst the instruction to emit; must have no results
void EmitWithoutResult(core::ir::Instruction* inst) {
    TINT_ASSERT(inst->Results().IsEmpty());
    auto* block = current_block_;
    block->Append(inst);
}
/// Emits the module-scope declarations found in the types/values section.
/// OpVariable instructions are emitted with EmitVar(). OpUndef values are registered as the zero
/// value of their type. Every other instruction in that section is skipped here.
void EmitModuleScopeVariables() {
    for (auto& decl : spirv_context_->module()->types_values()) {
        const auto op = decl.opcode();
        if (op == spv::Op::OpVariable) {
            EmitVar(decl);
        } else if (op == spv::Op::OpUndef) {
            AddValue(decl.result_id(), b_.Zero(Type(decl.type_id())));
        }
    }
}
/// Emit the functions.
void EmitFunctions() {
for (auto& func : *spirv_context_->module()) {
current_spirv_function_ = &func;
Vector<core::ir::FunctionParam*, 4> params;
func.ForEachParam([&](spvtools::opt::Instruction* spirv_param) {
auto* param = b_.FunctionParam(Type(spirv_param->type_id()));
values_.Add(spirv_param->result_id(), param);
Symbol name = GetSymbolFor(spirv_param->result_id());
if (name.IsValid()) {
ir_.SetName(param, name);
}
params.Push(param);
});
current_function_ = Function(func.result_id());
current_function_->SetParams(std::move(params));
current_function_->SetReturnType(Type(func.type_id()));
Symbol name = GetSymbolFor(func.result_id());
if (name.IsValid()) {
ir_.SetName(current_function_, name);
}
functions_.Add(func.result_id(), current_function_);
EmitBlock(current_function_->Block(), *func.entry());
}
current_spirv_function_ = nullptr;
}
/// Emit entry point attributes.
void EmitEntryPointAttributes() {
// Handle OpEntryPoint declarations.
for (auto& entry_point : spirv_context_->module()->entry_points()) {
auto model = entry_point.GetSingleWordInOperand(0);
auto* func = Function(entry_point.GetSingleWordInOperand(1));
// Set the pipeline stage.
switch (spv::ExecutionModel(model)) {
case spv::ExecutionModel::GLCompute:
func->SetStage(core::ir::Function::PipelineStage::kCompute);
break;
case spv::ExecutionModel::Fragment:
func->SetStage(core::ir::Function::PipelineStage::kFragment);
break;
case spv::ExecutionModel::Vertex:
func->SetStage(core::ir::Function::PipelineStage::kVertex);
break;
default:
TINT_UNIMPLEMENTED() << "unhandled execution model: " << model;
}
// Set the entry point name.
ir_.SetName(func, entry_point.GetOperand(2).AsString());
}
// Handle OpExecutionMode declarations.
for (auto& execution_mode : spirv_context_->module()->execution_modes()) {
auto* func = functions_.GetOr(execution_mode.GetSingleWordInOperand(0), nullptr);
auto mode = execution_mode.GetSingleWordInOperand(1);
TINT_ASSERT(func);
switch (spv::ExecutionMode(mode)) {
case spv::ExecutionMode::LocalSize:
func->SetWorkgroupSize(
b_.Constant(u32(execution_mode.GetSingleWordInOperand(2))),
b_.Constant(u32(execution_mode.GetSingleWordInOperand(3))),
b_.Constant(u32(execution_mode.GetSingleWordInOperand(4))));
break;
case spv::ExecutionMode::DepthReplacing:
case spv::ExecutionMode::OriginUpperLeft:
// These are ignored as they are implicitly supported by Tint IR.
break;
default:
TINT_UNIMPLEMENTED() << "unhandled execution mode: " << mode;
}
}
}
/// Emit the contents of SPIR-V block @p src into Tint IR block @p dst.
/// The instructions of @p src are translated in order by dispatching on their opcode. Branch
/// handling may recurse into EmitBlock() to emit a target block's instructions inline.
/// Opcodes not listed below raise TINT_UNIMPLEMENTED.
/// @param dst the Tint IR block to append to
/// @param src the SPIR-V block to emit
void EmitBlock(core::ir::Block* dst, const spvtools::opt::BasicBlock& src) {
// The Emit*() helpers append to current_block_, so point it at @p dst while emitting this
// block (TINT_SCOPED_ASSIGNMENT presumably restores the previous value on exit).
TINT_SCOPED_ASSIGNMENT(current_block_, dst);
for (auto& inst : src) {
switch (inst.opcode()) {
case spv::Op::OpNop:
break;
case spv::Op::OpUndef:
// An undefined value is materialized as the zero value of its type.
AddValue(inst.result_id(), b_.Zero(Type(inst.type_id())));
break;
case spv::Op::OpBranch:
EmitBranch(inst);
break;
case spv::Op::OpBranchConditional:
EmitBranchConditional(inst);
break;
case spv::Op::OpSelectionMerge:
// Only records the merge block; the following conditional branch consumes it.
HandleSelectionMerge(inst, src);
break;
case spv::Op::OpExtInst:
EmitExtInst(inst);
break;
case spv::Op::OpCopyObject:
EmitCopyObject(inst);
break;
case spv::Op::OpConvertFToS:
EmitSpirvExplicitBuiltinCall(inst, spirv::BuiltinFn::kConvertFToS);
break;
case spv::Op::OpConvertFToU:
Emit(b_.Convert(Type(inst.type_id()), Value(inst.GetSingleWordOperand(2))),
inst.result_id());
break;
case spv::Op::OpAccessChain:
case spv::Op::OpInBoundsAccessChain:
EmitAccess(inst);
break;
case spv::Op::OpCompositeConstruct:
EmitConstruct(inst);
break;
case spv::Op::OpCompositeExtract:
EmitCompositeExtract(inst);
break;
// Float arithmetic maps to core binary ops. Signed/generic integer arithmetic goes through
// SPIR-V-dialect builtins whose explicit template type is the result's element type.
case spv::Op::OpFAdd:
EmitBinary(inst, core::BinaryOp::kAdd);
break;
case spv::Op::OpIAdd:
EmitSpirvExplicitBuiltinCall(inst, spirv::BuiltinFn::kAdd);
break;
case spv::Op::OpSDiv:
EmitSpirvExplicitBuiltinCall(inst, spirv::BuiltinFn::kSDiv);
break;
case spv::Op::OpFDiv:
case spv::Op::OpUDiv:
EmitBinary(inst, core::BinaryOp::kDivide);
break;
case spv::Op::OpIMul:
EmitSpirvExplicitBuiltinCall(inst, spirv::BuiltinFn::kMul);
break;
case spv::Op::OpFMul:
case spv::Op::OpVectorTimesScalar:
case spv::Op::OpMatrixTimesScalar:
case spv::Op::OpVectorTimesMatrix:
case spv::Op::OpMatrixTimesVector:
case spv::Op::OpMatrixTimesMatrix:
EmitBinary(inst, core::BinaryOp::kMultiply);
break;
case spv::Op::OpFRem:
case spv::Op::OpUMod:
EmitBinary(inst, core::BinaryOp::kModulo);
break;
// NOTE(review): OpSRem (result sign follows operand 1) and OpSMod (result sign follows
// operand 2) differ for negative operands, yet both map to kSMod here — confirm the kSMod
// lowering produces the intended result for each opcode.
case spv::Op::OpSMod:
case spv::Op::OpSRem:
EmitSpirvExplicitBuiltinCall(inst, spirv::BuiltinFn::kSMod);
break;
case spv::Op::OpFSub:
EmitBinary(inst, core::BinaryOp::kSubtract);
break;
case spv::Op::OpISub:
EmitSpirvExplicitBuiltinCall(inst, spirv::BuiltinFn::kSub);
break;
case spv::Op::OpFunctionCall:
EmitFunctionCall(inst);
break;
case spv::Op::OpLoad:
Emit(b_.Load(Value(inst.GetSingleWordOperand(2))), inst.result_id());
break;
case spv::Op::OpReturn:
EmitWithoutResult(b_.Return(current_function_));
break;
case spv::Op::OpReturnValue:
EmitWithoutResult(
b_.Return(current_function_, Value(inst.GetSingleWordOperand(0))));
break;
case spv::Op::OpStore:
// Operand 0 is the pointer, operand 1 the object being stored.
EmitWithoutResult(b_.Store(Value(inst.GetSingleWordOperand(0)),
Value(inst.GetSingleWordOperand(1))));
break;
case spv::Op::OpVariable:
EmitVar(inst);
break;
case spv::Op::OpUnreachable:
EmitWithoutResult(b_.Unreachable());
break;
case spv::Op::OpKill:
EmitKill(inst);
break;
case spv::Op::OpDot:
EmitBuiltinCall(inst, core::BuiltinFn::kDot);
break;
case spv::Op::OpBitCount:
EmitBitCount(inst);
break;
case spv::Op::OpBitFieldInsert:
EmitSpirvBuiltinCall(inst, spirv::BuiltinFn::kBitFieldInsert);
break;
case spv::Op::OpBitFieldSExtract:
EmitSpirvBuiltinCall(inst, spirv::BuiltinFn::kBitFieldSExtract);
break;
case spv::Op::OpBitFieldUExtract:
EmitSpirvBuiltinCall(inst, spirv::BuiltinFn::kBitFieldUExtract);
break;
case spv::Op::OpBitReverse:
EmitBuiltinCall(inst, core::BuiltinFn::kReverseBits);
break;
case spv::Op::OpAll:
EmitBuiltinCall(inst, core::BuiltinFn::kAll);
break;
case spv::Op::OpAny:
EmitBuiltinCall(inst, core::BuiltinFn::kAny);
break;
// Derivatives map one-to-one onto the WGSL derivative builtins.
case spv::Op::OpDPdx:
EmitBuiltinCall(inst, core::BuiltinFn::kDpdx);
break;
case spv::Op::OpDPdy:
EmitBuiltinCall(inst, core::BuiltinFn::kDpdy);
break;
case spv::Op::OpFwidth:
EmitBuiltinCall(inst, core::BuiltinFn::kFwidth);
break;
case spv::Op::OpDPdxFine:
EmitBuiltinCall(inst, core::BuiltinFn::kDpdxFine);
break;
case spv::Op::OpDPdyFine:
EmitBuiltinCall(inst, core::BuiltinFn::kDpdyFine);
break;
case spv::Op::OpFwidthFine:
EmitBuiltinCall(inst, core::BuiltinFn::kFwidthFine);
break;
case spv::Op::OpDPdxCoarse:
EmitBuiltinCall(inst, core::BuiltinFn::kDpdxCoarse);
break;
case spv::Op::OpDPdyCoarse:
EmitBuiltinCall(inst, core::BuiltinFn::kDpdyCoarse);
break;
case spv::Op::OpFwidthCoarse:
EmitBuiltinCall(inst, core::BuiltinFn::kFwidthCoarse);
break;
default:
TINT_UNIMPLEMENTED()
<< "unhandled SPIR-V instruction: " << static_cast<uint32_t>(inst.opcode());
}
}
}
/// Records an OpSelectionMerge on the merge stack.
/// The entry pairs the merge block ID with the terminator of @p src, so that
/// EmitBranchConditional() can tell whether it owns the merge.
/// @param inst the OpSelectionMerge instruction (operand 0 is the merge block ID)
/// @param src the SPIR-V block containing @p inst
void HandleSelectionMerge(const spvtools::opt::Instruction& inst,
                          const spvtools::opt::BasicBlock& src) {
    const MergeInfo info{inst.GetSingleWordOperand(0), src.terminator()};
    merge_stack_.push_back(info);
}
/// Handles an unconditional OpBranch.
/// A branch to the innermost pending merge block emits nothing; any other target block is emitted
/// inline into the current block.
/// @param inst the OpBranch instruction
void EmitBranch(const spvtools::opt::Instruction& inst) {
    const auto target_id = inst.GetSingleWordInOperand(0);
    const bool targets_merge = !merge_stack_.empty() && merge_stack_.back().id == target_id;
    if (targets_merge) {
        return;
    }
    TINT_ASSERT(current_spirv_function_);
    const auto& target = current_spirv_function_->FindBlock(target_id);
    EmitBlock(current_block_, *target);
}
/// Emits an OpBranchConditional as an IR `if`.
/// Each arm that does not target the innermost pending merge block is emitted into the
/// corresponding `if` block. Any arm left without a terminator gets an `exit_if`. If this branch
/// owns the innermost merge (it is the terminator recorded by HandleSelectionMerge()), the merge
/// entry is popped and the merge block is emitted after the `if`. With no pending merge at all,
/// an `unreachable` is emitted after the `if`.
/// NOTE(review): when a merge is pending but owned by a different instruction, nothing is emitted
/// after the `if` — presumably the owner emits it; confirm.
/// @param inst the OpBranchConditional instruction
void EmitBranchConditional(const spvtools::opt::Instruction& inst) {
auto cond = Value(inst.GetSingleWordInOperand(0));
auto true_id = inst.GetSingleWordInOperand(1);
auto false_id = inst.GetSingleWordInOperand(2);
std::optional<uint32_t> merge_id = std::nullopt;
if (!merge_stack_.empty()) {
merge_id = merge_stack_.back().id;
}
TINT_ASSERT(current_spirv_function_);
auto* if_ = b_.If(cond);
EmitWithoutResult(if_);
// An arm that branches straight to the merge block has no body of its own.
if (true_id != merge_id) {
const auto& bb_true = current_spirv_function_->FindBlock(true_id);
EmitBlock(if_->True(), *bb_true);
}
if (!if_->True()->Terminator()) {
if_->True()->Append(b_.ExitIf(if_));
}
if (false_id != merge_id) {
const auto& bb_false = current_spirv_function_->FindBlock(false_id);
EmitBlock(if_->False(), *bb_false);
}
if (!if_->False()->Terminator()) {
if_->False()->Append(b_.ExitIf(if_));
}
if (merge_id.has_value()) {
// Only the branch that terminates the block declaring the merge emits the merge block.
if (&inst == merge_stack_.back().merge_inst) {
merge_stack_.pop_back();
const auto& bb_merge = current_spirv_function_->FindBlock(merge_id.value());
EmitBlock(current_block_, *bb_merge);
}
} else {
EmitWithoutResult(b_.Unreachable());
}
}
/// Collects the IR values for the ID operands of @p inst, starting at operand index @p start.
/// Callers pass 2 to skip the result type and result ID operands.
/// @param inst the SPIR-V instruction
/// @param start the index of the first operand to convert
/// @returns the IR values for operands [start, inst.NumOperands())
Vector<core::ir::Value*, 4> Args(const spvtools::opt::Instruction& inst, uint32_t start) {
    Vector<core::ir::Value*, 4> args;
    // GetSingleWordOperand() takes an operand *index*, so bound the loop by the operand count,
    // not by NumOperandWords(). The two differ as soon as any operand spans more than one word.
    for (uint32_t i = start; i < inst.NumOperands(); i++) {
        args.Push(Value(inst.GetSingleWordOperand(i)));
    }
    return args;
}
/// Emits @p inst as a call to the core builtin @p fn.
/// The call has the instruction's result type and its operands (from index 2) as arguments.
/// @param inst the SPIR-V instruction
/// @param fn the core builtin function to call
void EmitBuiltinCall(const spvtools::opt::Instruction& inst, core::BuiltinFn fn) {
    auto* result_ty = Type(inst.type_id());
    auto* call = b_.Call(result_ty, fn, Args(inst, 2));
    Emit(call, inst.result_id());
}
/// Emits @p inst as a call to the SPIR-V-dialect builtin @p fn.
/// The deepest element type of the result type is passed as the explicit template parameter.
/// @param inst the SPIR-V instruction
/// @param fn the SPIR-V builtin function to call
void EmitSpirvExplicitBuiltinCall(const spvtools::opt::Instruction& inst, spirv::BuiltinFn fn) {
    auto* result_ty = Type(inst.type_id());
    auto* call = b_.CallExplicit<spirv::ir::BuiltinCall>(
        result_ty, fn, Vector{result_ty->DeepestElement()}, Args(inst, 2));
    Emit(call, inst.result_id());
}
/// Emits @p inst as a call to the SPIR-V-dialect builtin @p fn, without explicit template
/// parameters.
/// @param inst the SPIR-V instruction
/// @param fn the SPIR-V builtin function to call
void EmitSpirvBuiltinCall(const spvtools::opt::Instruction& inst, spirv::BuiltinFn fn) {
    auto* result_ty = Type(inst.type_id());
    auto* call = b_.Call<spirv::ir::BuiltinCall>(result_ty, fn, Args(inst, 2));
    Emit(call, inst.result_id());
}
/// Emits OpBitCount as a SPIR-V `kBitCount` builtin call.
/// The deepest element type of the result type is the explicit template parameter.
/// This is exactly the generic explicit-builtin lowering, so delegate to it instead of
/// duplicating the call construction.
/// @param inst the OpBitCount instruction
void EmitBitCount(const spvtools::opt::Instruction& inst) {
    EmitSpirvExplicitBuiltinCall(inst, spirv::BuiltinFn::kBitCount);
}
/// Emits an `OpKill` as a `discard` followed by a return from the current function.
/// Note: This isn't technically correct, but there is no `kill` equivalent in WGSL. The closest
/// we have is `discard`, which maps to `OpDemoteToHelperInvocation` in SPIR-V.
/// @param inst the SPIR-V instruction (unused)
void EmitKill([[maybe_unused]] const spvtools::opt::Instruction& inst) {
EmitWithoutResult(b_.Discard());
// An `OpKill` is a terminator in SPIR-V, but `discard` is not a terminator in WGSL. After the
// `discard` we inject a `return` for the current function. This is similar in spirit to what
// `OpKill` does, although not totally correct: we don't early-return from calling functions,
// only from the function where `OpKill` was emitted. `OpKill` can also only appear in a limited
// set of places. In particular, it cannot end a `continuing` block, because a continuing block
// must end with a branching terminator and `OpKill` does not branch.
// Non-void functions return the zero value of their return type.
if (current_function_->ReturnType()->Is<core::type::Void>()) {
EmitWithoutResult(b_.Return(current_function_));
} else {
EmitWithoutResult(
b_.Return(current_function_, b_.Zero(current_function_->ReturnType())));
}
}
/// Emits `OpCopyObject` as a `let` initialized with the operand's value, and registers the
/// `let` as the value for the instruction's result ID.
/// @param inst the SPIR-V instruction for OpCopyObject
void EmitCopyObject(const spvtools::opt::Instruction& inst) {
// Bind the result ID to a `let` of the original (copied) value; operand 2 is the object.
auto* l = b_.Let(Value(inst.GetSingleWordOperand(2)));
Emit(l, inst.result_id());
}
/// Dispatches an OpExtInst based on the extended instruction set it belongs to.
/// Instructions from ignored (NonSemantic.*) imports are dropped. GLSL.std.450 instructions go
/// to EmitGlslStd450ExtInst(). Any other set raises TINT_UNIMPLEMENTED.
/// @param inst the SPIR-V instruction for OpExtInst
void EmitExtInst(const spvtools::opt::Instruction& inst) {
    const auto inst_set = inst.GetSingleWordInOperand(0);
    if (ignored_imports_.count(inst_set) > 0) {
        // Ignore it but don't error out.
    } else if (glsl_std_450_imports_.count(inst_set) > 0) {
        EmitGlslStd450ExtInst(inst);
    } else {
        TINT_UNIMPLEMENTED() << "unhandled extended instruction import with ID " << inst_set;
    }
}
/// Returns the WGSL standard library function for the given GLSL.std.450 extended instruction
/// operation code. This handles GLSL functions which directly translate to the WGSL equivalent.
/// Any non-direct translation is returned as `kNone`.
/// ModfStruct and FrexpStruct map to `modf` / `frexp`, but the caller still has to rebuild the
/// SPIR-V result struct from the WGSL result.
/// @param ext_opcode the GLSL.std.450 instruction number
/// @returns the equivalent core builtin, or core::BuiltinFn::kNone
core::BuiltinFn GetGlslStd450WgslEquivalentFuncName(uint32_t ext_opcode) {
switch (ext_opcode) {
case GLSLstd450Acos:
return core::BuiltinFn::kAcos;
case GLSLstd450Acosh:
return core::BuiltinFn::kAcosh;
case GLSLstd450Asin:
return core::BuiltinFn::kAsin;
case GLSLstd450Asinh:
return core::BuiltinFn::kAsinh;
case GLSLstd450Atan:
return core::BuiltinFn::kAtan;
case GLSLstd450Atanh:
return core::BuiltinFn::kAtanh;
case GLSLstd450Atan2:
return core::BuiltinFn::kAtan2;
case GLSLstd450Ceil:
return core::BuiltinFn::kCeil;
case GLSLstd450Cos:
return core::BuiltinFn::kCos;
case GLSLstd450Cosh:
return core::BuiltinFn::kCosh;
case GLSLstd450Cross:
return core::BuiltinFn::kCross;
case GLSLstd450Degrees:
return core::BuiltinFn::kDegrees;
case GLSLstd450Determinant:
return core::BuiltinFn::kDeterminant;
case GLSLstd450Distance:
return core::BuiltinFn::kDistance;
case GLSLstd450Exp:
return core::BuiltinFn::kExp;
case GLSLstd450Exp2:
return core::BuiltinFn::kExp2;
case GLSLstd450FAbs:
return core::BuiltinFn::kAbs;
case GLSLstd450FSign:
return core::BuiltinFn::kSign;
case GLSLstd450Floor:
return core::BuiltinFn::kFloor;
case GLSLstd450Fract:
return core::BuiltinFn::kFract;
case GLSLstd450Fma:
return core::BuiltinFn::kFma;
case GLSLstd450InverseSqrt:
return core::BuiltinFn::kInverseSqrt;
case GLSLstd450Length:
return core::BuiltinFn::kLength;
case GLSLstd450Log:
return core::BuiltinFn::kLog;
case GLSLstd450Log2:
return core::BuiltinFn::kLog2;
case GLSLstd450NClamp:
case GLSLstd450FClamp: // FClamp is less prescriptive about NaN operands
return core::BuiltinFn::kClamp;
case GLSLstd450ModfStruct:
return core::BuiltinFn::kModf;
case GLSLstd450FrexpStruct:
return core::BuiltinFn::kFrexp;
case GLSLstd450NMin:
case GLSLstd450FMin: // FMin is less prescriptive about NaN operands
return core::BuiltinFn::kMin;
case GLSLstd450NMax:
case GLSLstd450FMax: // FMax is less prescriptive about NaN operands
return core::BuiltinFn::kMax;
case GLSLstd450FMix:
return core::BuiltinFn::kMix;
case GLSLstd450PackSnorm4x8:
return core::BuiltinFn::kPack4X8Snorm;
case GLSLstd450PackUnorm4x8:
return core::BuiltinFn::kPack4X8Unorm;
case GLSLstd450PackSnorm2x16:
return core::BuiltinFn::kPack2X16Snorm;
case GLSLstd450PackUnorm2x16:
return core::BuiltinFn::kPack2X16Unorm;
case GLSLstd450PackHalf2x16:
return core::BuiltinFn::kPack2X16Float;
case GLSLstd450Pow:
return core::BuiltinFn::kPow;
case GLSLstd450Radians:
return core::BuiltinFn::kRadians;
// WGSL round() rounds halfway cases to even, which matches RoundEven. GLSL Round leaves
// the halfway direction to the implementation, so round() is a valid mapping for it too.
case GLSLstd450Round:
case GLSLstd450RoundEven:
return core::BuiltinFn::kRound;
case GLSLstd450Sin:
return core::BuiltinFn::kSin;
case GLSLstd450Sinh:
return core::BuiltinFn::kSinh;
case GLSLstd450SmoothStep:
return core::BuiltinFn::kSmoothstep;
case GLSLstd450Sqrt:
return core::BuiltinFn::kSqrt;
case GLSLstd450Step:
return core::BuiltinFn::kStep;
case GLSLstd450Tan:
return core::BuiltinFn::kTan;
case GLSLstd450Tanh:
return core::BuiltinFn::kTanh;
case GLSLstd450Trunc:
return core::BuiltinFn::kTrunc;
case GLSLstd450UnpackSnorm4x8:
return core::BuiltinFn::kUnpack4X8Snorm;
case GLSLstd450UnpackUnorm4x8:
return core::BuiltinFn::kUnpack4X8Unorm;
case GLSLstd450UnpackSnorm2x16:
return core::BuiltinFn::kUnpack2X16Snorm;
case GLSLstd450UnpackUnorm2x16:
return core::BuiltinFn::kUnpack2X16Unorm;
case GLSLstd450UnpackHalf2x16:
return core::BuiltinFn::kUnpack2X16Float;
default:
break;
}
return core::BuiltinFn::kNone;
}
/// Returns the SPIR-V-dialect builtin used for a GLSL.std.450 extended instruction that has no
/// direct WGSL equivalent in GetGlslStd450WgslEquivalentFuncName(). These are presumably the ops
/// whose semantics differ from the WGSL builtin (e.g. signed/unsigned-specific integer ops) —
/// confirm against the SPIR-V lowering.
/// @param ext_opcode the GLSL.std.450 instruction number
/// @returns the SPIR-V builtin, or spirv::BuiltinFn::kNone if there is none
spirv::BuiltinFn GetGlslStd450SpirvEquivalentFuncName(uint32_t ext_opcode) {
switch (ext_opcode) {
case GLSLstd450SAbs:
return spirv::BuiltinFn::kAbs;
case GLSLstd450SSign:
return spirv::BuiltinFn::kSign;
case GLSLstd450Normalize:
return spirv::BuiltinFn::kNormalize;
case GLSLstd450MatrixInverse:
return spirv::BuiltinFn::kInverse;
case GLSLstd450SMax:
return spirv::BuiltinFn::kSmax;
case GLSLstd450SMin:
return spirv::BuiltinFn::kSmin;
case GLSLstd450SClamp:
return spirv::BuiltinFn::kSclamp;
case GLSLstd450UMax:
return spirv::BuiltinFn::kUmax;
case GLSLstd450UMin:
return spirv::BuiltinFn::kUmin;
case GLSLstd450UClamp:
return spirv::BuiltinFn::kUclamp;
case GLSLstd450FindILsb:
return spirv::BuiltinFn::kFindILsb;
case GLSLstd450FindSMsb:
return spirv::BuiltinFn::kFindSMsb;
case GLSLstd450FindUMsb:
return spirv::BuiltinFn::kFindUMsb;
case GLSLstd450Refract:
return spirv::BuiltinFn::kRefract;
case GLSLstd450Reflect:
return spirv::BuiltinFn::kReflect;
case GLSLstd450FaceForward:
return spirv::BuiltinFn::kFaceForward;
case GLSLstd450Ldexp:
return spirv::BuiltinFn::kLdexp;
case GLSLstd450Modf:
return spirv::BuiltinFn::kModf;
case GLSLstd450Frexp:
return spirv::BuiltinFn::kFrexp;
default:
break;
}
return spirv::BuiltinFn::kNone;
}
/// Returns the explicit template parameters for the SPIR-V builtin call that implements a
/// GLSL.std.450 instruction.
/// The integer sign-specific opcodes (S*/U* variants and FindILsb / FindSMsb / FindUMsb) get a
/// single parameter: the deepest element type of @p result_ty. All other opcodes get none.
/// @param ext_opcode the GLSL.std.450 instruction number
/// @param result_ty the Tint result type of the instruction
/// @returns the explicit template parameter list (possibly empty)
Vector<const core::type::Type*, 1> GlslStd450ExplicitParams(uint32_t ext_opcode,
                                                            const core::type::Type* result_ty) {
    switch (ext_opcode) {
        case GLSLstd450SSign:
        case GLSLstd450SAbs:
        case GLSLstd450SMax:
        case GLSLstd450SMin:
        case GLSLstd450SClamp:
        case GLSLstd450UMax:
        case GLSLstd450UMin:
        case GLSLstd450UClamp:
        case GLSLstd450FindILsb:
        case GLSLstd450FindSMsb:
        case GLSLstd450FindUMsb:
            return {result_ty->DeepestElement()};
        default:
            return {};
    }
}
/// Emits IR for a SPIR-V OpExtInst whose extended instruction set is GLSL.std.450.
///
/// In-operand 1 holds the extended opcode and in-operands 2.. hold the argument IDs. The
/// instruction is lowered as follows:
///  * a core `modf` / `frexp` mapping is expanded into a call returning the WGSL result struct,
///    followed by a rebuild of the SPIR-V result struct type from its members;
///  * any other core builtin mapping becomes a core builtin call;
///  * otherwise a SPIR-V dialect builtin mapping becomes a spirv::ir::BuiltinCall with the
///    explicit template parameters computed by GlslStd450ExplicitParams;
///  * anything else is reported as unimplemented.
/// @param inst the SPIR-V instruction for OpExtInst
void EmitGlslStd450ExtInst(const spvtools::opt::Instruction& inst) {
    const auto ext_opcode = inst.GetSingleWordInOperand(1);
    auto* spv_ty = Type(inst.type_id());

    Vector<core::ir::Value*, 4> operands;
    // All parameters to GLSL.std.450 extended instructions are IDs.
    for (uint32_t idx = 2; idx < inst.NumInOperands(); ++idx) {
        operands.Push(Value(inst.GetSingleWordInOperand(idx)));
    }

    const auto wgsl_fn = GetGlslStd450WgslEquivalentFuncName(ext_opcode);
    if (wgsl_fn == core::BuiltinFn::kModf) {
        // For `ModfStruct`, which is, essentially, a WGSL `modf` instruction
        // we need some special handling. The result type that we produce
        // must be the SPIR-V type as we don't know how the result is used
        // later. So, we need to make the WGSL query and re-construct an
        // object of the right SPIR-V type. We can't, easily, do this later
        // as we lose the SPIR-V type as soon as we replace the result of the
        // `modf`. So, inline the work here to generate the correct results.
        auto* mem_ty = operands[0]->Type();
        auto* result_ty = core::type::CreateModfResult(ty_, ir_.symbols, mem_ty);
        auto* call = b_.Call(result_ty, wgsl_fn, operands);
        auto* fract = b_.Access(mem_ty, call, 0_u);
        auto* whole = b_.Access(mem_ty, call, 1_u);
        EmitWithoutSpvResult(call);
        EmitWithoutSpvResult(fract);
        EmitWithoutSpvResult(whole);
        Emit(b_.Construct(spv_ty, fract, whole), inst.result_id());
        return;
    }
    if (wgsl_fn == core::BuiltinFn::kFrexp) {
        // For `FrexpStruct`, which is, essentially, a WGSL `frexp`
        // instruction we need some special handling. The result type that we
        // produce must be the SPIR-V type as we don't know how the result is
        // used later. So, we need to make the WGSL query and re-construct an
        // object of the right SPIR-V type. We can't, easily, do this later
        // as we lose the SPIR-V type as soon as we replace the result of the
        // `frexp`. So, inline the work here to generate the correct results.
        auto* mem_ty = operands[0]->Type();
        auto* result_ty = core::type::CreateFrexpResult(ty_, ir_.symbols, mem_ty);
        auto* call = b_.Call(result_ty, wgsl_fn, operands);
        auto* fract = b_.Access(mem_ty, call, 0_u);
        // The WGSL exponent is i32 (matching the width of the fraction); read it as such.
        auto* exp = b_.Access(ty_.MatchWidth(ty_.i32(), mem_ty), call, 1_u);
        auto* exp_res = exp->Result(0);
        EmitWithoutSpvResult(call);
        EmitWithoutSpvResult(fract);
        EmitWithoutSpvResult(exp);
        // If the SPIR-V result struct declares an unsigned exponent member, bitcast the i32
        // exponent so the constructed struct matches the SPIR-V type.
        if (auto* str = spv_ty->As<core::type::Struct>()) {
            auto* exp_ty = str->Members()[1]->Type();
            if (exp_ty->DeepestElement()->IsUnsignedIntegerScalar()) {
                auto* uexp = b_.Bitcast(exp_ty, exp);
                exp_res = uexp->Result(0);
                EmitWithoutSpvResult(uexp);
            }
        }
        Emit(b_.Construct(spv_ty, fract, exp_res), inst.result_id());
        return;
    }
    if (wgsl_fn != core::BuiltinFn::kNone) {
        // Direct one-to-one mapping onto a core builtin.
        Emit(b_.Call(spv_ty, wgsl_fn, operands), inst.result_id());
        return;
    }

    // No core builtin: fall back to a SPIR-V dialect builtin, if one exists.
    const auto spv_fn = GetGlslStd450SpirvEquivalentFuncName(ext_opcode);
    if (spv_fn != spirv::BuiltinFn::kNone) {
        auto explicit_params = GlslStd450ExplicitParams(ext_opcode, spv_ty);
        Emit(b_.CallExplicit<spirv::ir::BuiltinCall>(spv_ty, spv_fn, explicit_params, operands),
             inst.result_id());
        return;
    }
    TINT_UNIMPLEMENTED() << "unhandled GLSL.std.450 instruction " << ext_opcode;
}
/// Emits an IR access instruction for a SPIR-V OpAccessChain. An access chain with no indices
/// simply forwards its base object as the result.
/// @param inst the SPIR-V instruction for OpAccessChain
void EmitAccess(const spvtools::opt::Instruction& inst) {
    Vector idx_values = Args(inst, 3);
    auto* source = Value(inst.GetSingleWordOperand(2));

    if (idx_values.IsEmpty()) {
        // Nothing to index: the result is the base object itself.
        AddValue(inst.result_id(), source);
        return;
    }

    // The result pointer inherits the access mode of the base pointer, when it is one.
    const auto* src_ptr = source->Type()->As<core::type::Pointer>();
    const core::Access mode = src_ptr ? src_ptr->Access() : core::Access::kUndefined;

    auto* result_ty = Type(inst.type_id(), mode);
    Emit(b_.Access(result_ty, source, std::move(idx_values)), inst.result_id());
}
/// Emits an IR binary instruction whose two value operands are at operand indices 2 and 3 of
/// the SPIR-V instruction.
/// @param inst the SPIR-V instruction
/// @param op the binary operator to use
void EmitBinary(const spvtools::opt::Instruction& inst, core::BinaryOp op) {
    auto* left = Value(inst.GetSingleWordOperand(2));
    auto* right = Value(inst.GetSingleWordOperand(3));
    auto* result_ty = Type(inst.type_id());
    Emit(b_.Binary(op, result_ty, left, right), inst.result_id());
}
/// Emits an IR access instruction for a SPIR-V OpCompositeExtract. The composite is at operand
/// 2 and each following operand is a literal index, which becomes a u32 constant index.
/// @param inst the SPIR-V instruction for OpCompositeExtract
void EmitCompositeExtract(const spvtools::opt::Instruction& inst) {
    Vector<core::ir::Value*, 4> indices;
    // Bound by the operand count (not the word count): GetSingleWordOperand() takes an operand
    // index. Every OpCompositeExtract operand is a single word, so both counts agree for valid
    // SPIR-V, but NumOperands() is the quantity that matches the indexing.
    for (uint32_t i = 3; i < inst.NumOperands(); i++) {
        indices.Push(b_.Constant(u32(inst.GetSingleWordOperand(i))));
    }
    auto* object = Value(inst.GetSingleWordOperand(2));
    auto* access = b_.Access(Type(inst.type_id()), object, std::move(indices));
    Emit(access, inst.result_id());
}
/// Emits an IR construct instruction for a SPIR-V OpCompositeConstruct, using the operands from
/// index 2 onward as the constituents.
/// @param inst the SPIR-V instruction for OpCompositeConstruct
void EmitConstruct(const spvtools::opt::Instruction& inst) {
    auto* result_ty = Type(inst.type_id());
    auto constituents = Args(inst, 2);
    Emit(b_.Construct(result_ty, std::move(constituents)), inst.result_id());
}
/// Emits an IR call to a user-defined function for a SPIR-V OpFunctionCall.
/// @param inst the SPIR-V instruction for OpFunctionCall
void EmitFunctionCall(const spvtools::opt::Instruction& inst) {
    // In-operand 0 names the callee; the call arguments start at operand 3.
    auto* callee = Function(inst.GetSingleWordInOperand(0));
    auto* call = b_.Call(callee, Args(inst, 3));
    Emit(call, inst.result_id());
}
/// Emits an IR `var` for a SPIR-V OpVariable.
///
/// The decorations on the variable's result ID are translated into the var's access mode,
/// binding point and IO attributes. Decorations are applied in the order the decoration manager
/// returns them, so a later decoration of the same kind overwrites an earlier one. Unknown
/// decorations are reported as unimplemented.
/// @param inst the SPIR-V instruction for OpVariable
void EmitVar(const spvtools::opt::Instruction& inst) {
// Handle decorations.
std::optional<uint32_t> group;
std::optional<uint32_t> binding;
core::Access access_mode = core::Access::kUndefined;
core::IOAttributes io_attributes;
// Returns the (lazily created) interpolation attribute so that each interpolation decoration
// only has to override the field it controls.
auto interpolation = [&]() -> core::Interpolation& {
// Create the interpolation field with the default values on first call.
// NOTE(review): a Flat decoration alone keeps the default kCenter sampling here. WGSL `flat`
// interpolation does not accept `center` sampling; confirm that a later stage strips or
// rewrites it.
if (!io_attributes.interpolation.has_value()) {
io_attributes.interpolation = core::Interpolation{
core::InterpolationType::kPerspective, core::InterpolationSampling::kCenter};
}
return io_attributes.interpolation.value();
};
for (auto* deco :
spirv_context_->get_decoration_mgr()->GetDecorationsFor(inst.result_id(), false)) {
// Operand 1 of the decoration instruction is the decoration kind; operand 2 (where used
// below) is its first literal argument.
auto d = deco->GetSingleWordOperand(1);
switch (spv::Decoration(d)) {
case spv::Decoration::NonWritable:
// Read-only: passed to Type() below to select the pointer's access mode.
access_mode = core::Access::kRead;
break;
case spv::Decoration::DescriptorSet:
group = deco->GetSingleWordOperand(2);
break;
case spv::Decoration::Binding:
binding = deco->GetSingleWordOperand(2);
break;
case spv::Decoration::BuiltIn:
io_attributes.builtin = Builtin(spv::BuiltIn(deco->GetSingleWordOperand(2)));
break;
case spv::Decoration::Invariant:
io_attributes.invariant = true;
break;
case spv::Decoration::Location:
io_attributes.location = deco->GetSingleWordOperand(2);
break;
case spv::Decoration::NoPerspective:
interpolation().type = core::InterpolationType::kLinear;
break;
case spv::Decoration::Flat:
interpolation().type = core::InterpolationType::kFlat;
break;
case spv::Decoration::Centroid:
interpolation().sampling = core::InterpolationSampling::kCentroid;
break;
case spv::Decoration::Sample:
interpolation().sampling = core::InterpolationSampling::kSample;
break;
case spv::Decoration::Index:
// The SPIR-V `Index` decoration is stored as the blend source index.
io_attributes.blend_src = deco->GetSingleWordOperand(2);
break;
default:
TINT_UNIMPLEMENTED() << "unhandled decoration " << d;
}
}
auto* var = b_.Var(Type(inst.type_id(), access_mode)->As<core::type::Pointer>());
// OpVariable's optional initializer is operand 3.
if (inst.NumOperands() > 3) {
var->SetInitializer(Value(inst.GetSingleWordOperand(3)));
}
// DescriptorSet and Binding must appear together.
if (group || binding) {
TINT_ASSERT(group && binding);
var->SetBindingPoint(group.value(), binding.value());
}
var->SetAttributes(std::move(io_attributes));
Emit(var, inst.result_id());
}
private:
/// TypeKey pairs a SPIR-V type with an access mode. It is the key of the `types_` cache, so
/// the same SPIR-V type used with different access modes maps to distinct Tint types.
struct TypeKey {
    /// The SPIR-V type object.
    const spvtools::opt::analysis::Type* type;
    /// The access mode.
    core::Access access_mode;

    /// @param other the TypeKey to compare against
    /// @returns true if both the SPIR-V type and the access mode are equal
    bool operator==(const TypeKey& other) const {
        if (type != other.type) {
            return false;
        }
        return access_mode == other.access_mode;
    }

    /// @returns the hash code of the TypeKey, combining the type and the access mode
    tint::HashCode HashCode() const { return Hash(type, access_mode); }
};
/// The generated IR module.
core::ir::Module ir_;
/// The Tint IR builder.
core::ir::Builder b_{ir_};
/// The Tint type manager.
core::type::Manager& ty_{ir_.Types()};
/// The Tint IR function that is currently being emitted.
core::ir::Function* current_function_ = nullptr;
/// The Tint IR block that is currently being emitted.
core::ir::Block* current_block_ = nullptr;
/// A map from a SPIR-V type declaration (plus access mode) to the corresponding Tint type object.
Hashmap<TypeKey, const core::type::Type*, 16> types_;
/// A map from a SPIR-V function definition result ID to the corresponding Tint function object.
Hashmap<uint32_t, core::ir::Function*, 8> functions_;
/// A map from a SPIR-V result ID to the corresponding Tint value object.
Hashmap<uint32_t, core::ir::Value*, 8> values_;
/// The SPIR-V context containing the SPIR-V tools intermediate representation.
std::unique_ptr<spvtools::opt::IRContext> spirv_context_;
/// The current SPIR-V function being emitted.
spvtools::opt::Function* current_spirv_function_ = nullptr;
/// The set of IDs that are imports of the GLSL.std.450 extended instruction sets.
std::unordered_set<uint32_t> glsl_std_450_imports_;
/// The set of IDs of imports that are ignored. For example, any "NonSemanticInfo." import is
/// ignored.
std::unordered_set<uint32_t> ignored_imports_;
/// Map of SPIR-V IDs to string names.
std::unordered_map<uint32_t, std::string> id_to_name_;
/// Map of SPIR-V struct IDs to a list of member string names.
std::unordered_map<uint32_t, std::vector<std::string>> struct_to_member_names_;
/// An entry of the merge stack.
struct MergeInfo {
/// The SPIR-V ID associated with this merge entry (presumably the merge block's label ID —
/// confirm against the code that pushes entries).
uint32_t id;
/// The SPIR-V instruction associated with this merge entry (presumably the structured merge
/// instruction, e.g. OpSelectionMerge / OpLoopMerge — confirm).
const spvtools::opt::Instruction* merge_inst;
};
/// Stack of merge blocks.
std::vector<MergeInfo> merge_stack_;
};
} // namespace
/// Parses a SPIR-V binary into a Tint IR module.
/// @param spirv the SPIR-V binary, as a sequence of 32-bit words
/// @returns the resulting IR module, or the failure reported by the parser
Result<core::ir::Module> Parse(Slice<const uint32_t> spirv) {
    Parser parser{};
    return parser.Run(spirv);
}
} // namespace tint::spirv::reader