blob: 2bcae953684503512a45226da40ec8d04b207052 [file] [log] [blame] [edit]
// Copyright 2020 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/reader/spirv/function.h"
#include <algorithm>
#include <array>
#include "src/tint/ast/assignment_statement.h"
#include "src/tint/ast/bitcast_expression.h"
#include "src/tint/ast/break_statement.h"
#include "src/tint/ast/builtin.h"
#include "src/tint/ast/builtin_attribute.h"
#include "src/tint/ast/call_statement.h"
#include "src/tint/ast/continue_statement.h"
#include "src/tint/ast/discard_statement.h"
#include "src/tint/ast/fallthrough_statement.h"
#include "src/tint/ast/if_statement.h"
#include "src/tint/ast/loop_statement.h"
#include "src/tint/ast/return_statement.h"
#include "src/tint/ast/stage_attribute.h"
#include "src/tint/ast/switch_statement.h"
#include "src/tint/ast/unary_op_expression.h"
#include "src/tint/ast/variable_decl_statement.h"
#include "src/tint/sem/builtin_type.h"
#include "src/tint/sem/depth_texture.h"
#include "src/tint/sem/sampled_texture.h"
#include "src/tint/transform/spirv_atomic.h"
// Terms:
// CFG: the control flow graph of the function, where basic blocks are the
// nodes, and branches form the directed arcs. The function entry block is
// the root of the CFG.
//
// Suppose H is a header block (i.e. has an OpSelectionMerge or OpLoopMerge).
// Then:
// - Let M(H) be the merge block named by the merge instruction in H.
// - If H is a loop header, i.e. has an OpLoopMerge instruction, then let
// CT(H) be the continue target block named by the OpLoopMerge
// instruction.
// - If H is a selection construct whose header ends in
// OpBranchConditional with true target %then and false target %else,
// then TT(H) = %then and FT(H) = %else
//
// Determining output block order:
// The "structured post-order traversal" of the CFG is a post-order traversal
// of the basic blocks in the CFG, where:
// We visit the entry node of the function first.
// When visiting a header block:
// We next visit its merge block
// Then if it's a loop header, we next visit the continue target,
// Then we visit the block's successors (whether it's a header or not)
// If the block ends in an OpBranchConditional, we visit the false target
// before the true target.
//
// The "reverse structured post-order traversal" of the CFG is the reverse
// of the structured post-order traversal.
// This is the order of basic blocks as they should be emitted to the WGSL
// function. It is the order computed by ComputeBlockOrder, and stored in
// the |FunctionEmitter::block_order_|.
// Blocks not in this ordering are ignored by the rest of the algorithm.
//
// Note:
// - A block D in the function might not appear in this order because
// no block in the order branches to D.
// - An unreachable block D might still be in the order because some header
// block in the order names D as its continue target, or merge block,
// or D is reachable from one of those otherwise-unreachable continue
// targets or merge blocks.
//
// Terms:
// Let Pos(B) be the index position of a block B in the computed block order.
//
// CFG intervals and valid nesting:
//
// A correctly structured CFG satisfies nesting rules that we can check by
// comparing positions of related blocks.
//
// If header block H is in the block order, then the following holds:
//
// Pos(H) < Pos(M(H))
//
// If CT(H) exists, then:
//
// Pos(H) <= Pos(CT(H))
// Pos(CT(H)) < Pos(M(H))
//
// This gives us the fundamental ordering of blocks in relation to a
// structured construct:
// The blocks before H in the block order, are not in the construct
// The blocks at M(H) or later in the block order, are not in the construct
// The blocks in a selection headed at H are in positions
// [ Pos(H), Pos(M(H)) )
// The blocks in a loop construct headed at H are in positions
// [ Pos(H), Pos(CT(H)) )
// The blocks in the continue construct for loop headed at H are in
// positions [ Pos(CT(H)), Pos(M(H)) )
//
// Schematically, for a selection construct headed by H, the blocks are in
// order from left to right:
//
// ...a-b-c H d-e-f M(H) n-o-p...
//
// where ...a-b-c: blocks before the selection construct
// where H and d-e-f: blocks in the selection construct
// where M(H) and n-o-p...: blocks after the selection construct
//
// Schematically, for a loop construct headed by H that is its own
// continue construct, the blocks in order from left to right:
//
// ...a-b-c H=CT(H) d-e-f M(H) n-o-p...
//
// where ...a-b-c: blocks before the loop
// where H is the continue construct; CT(H)=H, and the loop construct
// is *empty*
// where d-e-f... are other blocks in the continue construct
// where M(H) and n-o-p...: blocks after the continue construct
//
// Schematically, for a multi-block loop construct headed by H, there are
// blocks in order from left to right:
//
// ...a-b-c H d-e-f CT(H) j-k-l M(H) n-o-p...
//
// where ...a-b-c: blocks before the loop
// where H and d-e-f: blocks in the loop construct
// where CT(H) and j-k-l: blocks in the continue construct
// where M(H) and n-o-p...: blocks after the loop and continue
// constructs
//
using namespace tint::number_suffixes; // NOLINT
namespace tint::reader::spirv {
namespace {
// Upper bound on vector length used by this file.
// NOTE(review): presumably the largest component count of a vector type
// (i.e. vec4) — confirm against the uses later in the file.
constexpr uint32_t kMaxVectorLen = 4;
// Translates a SPIR-V opcode into the matching AST unary operator.
// @param opcode SPIR-V opcode
// @param ast_unary_op output parameter; written only when the opcode is a
// recognized unary operation
// @returns true if the opcode was a unary operation, false otherwise
bool GetUnaryOp(SpvOp opcode, ast::UnaryOp* ast_unary_op) {
    switch (opcode) {
        case SpvOpSNegate:
        case SpvOpFNegate:
            *ast_unary_op = ast::UnaryOp::kNegation;
            break;
        case SpvOpLogicalNot:
            *ast_unary_op = ast::UnaryOp::kNot;
            break;
        case SpvOpNot:
            *ast_unary_op = ast::UnaryOp::kComplement;
            break;
        default:
            // Not a unary operation: leave the output parameter untouched.
            return false;
    }
    return true;
}
/// Looks up the WGSL builtin function that directly implements a SPIR-V
/// opcode, when such a direct translation exists.
/// @param opcode SPIR-V opcode
/// @returns the WGSL builtin function name for the given opcode, or nullptr
/// if there is no direct translation.
const char* GetUnaryBuiltInFunctionName(SpvOp opcode) {
    const char* name = nullptr;
    switch (opcode) {
        case SpvOpAny:
            name = "any";
            break;
        case SpvOpAll:
            name = "all";
            break;
        case SpvOpIsNan:
            name = "isNan";
            break;
        case SpvOpIsInf:
            name = "isInf";
            break;
        case SpvOpTranspose:
            name = "transpose";
            break;
        default:
            break;
    }
    return name;
}
// Maps a SPIR-V opcode onto the AST binary operator that implements it.
// Integer, floating point, and (for multiplication) mixed vector/matrix
// variants of an operation all collapse onto the same AST operator.
// @param opcode SPIR-V opcode
// @returns the AST binary op for the given opcode, or kNone if the opcode has
// no binary operator equivalent
ast::BinaryOp ConvertBinaryOp(SpvOp opcode) {
    using BinOp = ast::BinaryOp;
    switch (opcode) {
        // Arithmetic.
        case SpvOpIAdd:
        case SpvOpFAdd:
            return BinOp::kAdd;
        case SpvOpISub:
        case SpvOpFSub:
            return BinOp::kSubtract;
        case SpvOpIMul:
        case SpvOpFMul:
        case SpvOpVectorTimesScalar:
        case SpvOpMatrixTimesScalar:
        case SpvOpVectorTimesMatrix:
        case SpvOpMatrixTimesVector:
        case SpvOpMatrixTimesMatrix:
            return BinOp::kMultiply;
        case SpvOpUDiv:
        case SpvOpSDiv:
        case SpvOpFDiv:
            return BinOp::kDivide;
        // It's not clear what OpSMod should map to; it is currently treated
        // as modulo like the others.
        // https://bugs.chromium.org/p/tint/issues/detail?id=52
        case SpvOpUMod:
        case SpvOpSMod:
        case SpvOpFRem:
            return BinOp::kModulo;

        // Equality.
        case SpvOpLogicalEqual:
        case SpvOpIEqual:
        case SpvOpFOrdEqual:
            return BinOp::kEqual;
        case SpvOpLogicalNotEqual:
        case SpvOpINotEqual:
        case SpvOpFOrdNotEqual:
            return BinOp::kNotEqual;

        // Bitwise and logical operations share AST operators.
        case SpvOpBitwiseAnd:
        case SpvOpLogicalAnd:
            return BinOp::kAnd;
        case SpvOpBitwiseOr:
        case SpvOpLogicalOr:
            return BinOp::kOr;
        case SpvOpBitwiseXor:
            return BinOp::kXor;

        // Ordered comparisons.
        case SpvOpUGreaterThan:
        case SpvOpSGreaterThan:
        case SpvOpFOrdGreaterThan:
            return BinOp::kGreaterThan;
        case SpvOpUGreaterThanEqual:
        case SpvOpSGreaterThanEqual:
        case SpvOpFOrdGreaterThanEqual:
            return BinOp::kGreaterThanEqual;
        case SpvOpULessThan:
        case SpvOpSLessThan:
        case SpvOpFOrdLessThan:
            return BinOp::kLessThan;
        case SpvOpULessThanEqual:
        case SpvOpSLessThanEqual:
        case SpvOpFOrdLessThanEqual:
            return BinOp::kLessThanEqual;

        default:
            return BinOp::kNone;
    }
}
// If the given SPIR-V opcode is a floating point unordered comparison,
// then returns the binary float comparison for which it is the negation.
// Otherwise returns BinaryOp::kNone.
// @param opcode SPIR-V opcode
// @returns operation corresponding to negated version of the SPIR-V opcode
ast::BinaryOp NegatedFloatCompare(SpvOp opcode) {
    ast::BinaryOp negated = ast::BinaryOp::kNone;
    switch (opcode) {
        case SpvOpFUnordEqual:
            negated = ast::BinaryOp::kNotEqual;
            break;
        case SpvOpFUnordNotEqual:
            negated = ast::BinaryOp::kEqual;
            break;
        case SpvOpFUnordLessThan:
            negated = ast::BinaryOp::kGreaterThanEqual;
            break;
        case SpvOpFUnordLessThanEqual:
            negated = ast::BinaryOp::kGreaterThan;
            break;
        case SpvOpFUnordGreaterThan:
            negated = ast::BinaryOp::kLessThanEqual;
            break;
        case SpvOpFUnordGreaterThanEqual:
            negated = ast::BinaryOp::kLessThan;
            break;
        default:
            // Not an unordered float comparison.
            break;
    }
    return negated;
}
// Looks up the WGSL standard library function that implements a
// GLSL.std.450 extended instruction. Unknown, invalid, and not-yet-supported
// opcodes all map to the empty string.
// @param ext_opcode the GLSL.std.450 extended instruction opcode
// @returns the WGSL standard function name, or an empty string.
std::string GetGlslStd450FuncName(uint32_t ext_opcode) {
    switch (ext_opcode) {
        case GLSLstd450FAbs:
        case GLSLstd450SAbs:
            return "abs";
        case GLSLstd450Acos: return "acos";
        case GLSLstd450Asin: return "asin";
        case GLSLstd450Atan: return "atan";
        case GLSLstd450Atan2: return "atan2";
        case GLSLstd450Ceil: return "ceil";
        case GLSLstd450UClamp:
        case GLSLstd450SClamp:
        case GLSLstd450NClamp:
        case GLSLstd450FClamp:  // FClamp is less prescriptive about NaN operands
            return "clamp";
        case GLSLstd450Cos: return "cos";
        case GLSLstd450Cosh: return "cosh";
        case GLSLstd450Cross: return "cross";
        case GLSLstd450Degrees: return "degrees";
        case GLSLstd450Distance: return "distance";
        case GLSLstd450Exp: return "exp";
        case GLSLstd450Exp2: return "exp2";
        case GLSLstd450FaceForward: return "faceForward";
        case GLSLstd450Floor: return "floor";
        case GLSLstd450Fma: return "fma";
        case GLSLstd450Fract: return "fract";
        case GLSLstd450InverseSqrt: return "inverseSqrt";
        case GLSLstd450Ldexp: return "ldexp";
        case GLSLstd450Length: return "length";
        case GLSLstd450Log: return "log";
        case GLSLstd450Log2: return "log2";
        case GLSLstd450NMax:
        case GLSLstd450FMax:  // FMax is less prescriptive about NaN operands
        case GLSLstd450UMax:
        case GLSLstd450SMax:
            return "max";
        case GLSLstd450NMin:
        case GLSLstd450FMin:  // FMin is less prescriptive about NaN operands
        case GLSLstd450UMin:
        case GLSLstd450SMin:
            return "min";
        case GLSLstd450FMix: return "mix";
        case GLSLstd450Normalize: return "normalize";
        case GLSLstd450PackSnorm4x8: return "pack4x8snorm";
        case GLSLstd450PackUnorm4x8: return "pack4x8unorm";
        case GLSLstd450PackSnorm2x16: return "pack2x16snorm";
        case GLSLstd450PackUnorm2x16: return "pack2x16unorm";
        case GLSLstd450PackHalf2x16: return "pack2x16float";
        case GLSLstd450Pow: return "pow";
        case GLSLstd450FSign: return "sign";
        case GLSLstd450Radians: return "radians";
        case GLSLstd450Reflect: return "reflect";
        case GLSLstd450Refract: return "refract";
        case GLSLstd450Round:
        case GLSLstd450RoundEven:
            return "round";
        case GLSLstd450Sin: return "sin";
        case GLSLstd450Sinh: return "sinh";
        case GLSLstd450SmoothStep: return "smoothstep";
        case GLSLstd450Sqrt: return "sqrt";
        case GLSLstd450Step: return "step";
        case GLSLstd450Tan: return "tan";
        case GLSLstd450Tanh: return "tanh";
        case GLSLstd450Trunc: return "trunc";
        case GLSLstd450UnpackSnorm4x8: return "unpack4x8snorm";
        case GLSLstd450UnpackUnorm4x8: return "unpack4x8unorm";
        case GLSLstd450UnpackSnorm2x16: return "unpack2x16snorm";
        case GLSLstd450UnpackUnorm2x16: return "unpack2x16unorm";
        case GLSLstd450UnpackHalf2x16: return "unpack2x16float";

        // TODO(dneto) - The following are not implemented.
        // They are grouped semantically, as in GLSL.std.450.h.
        case GLSLstd450SSign:

        case GLSLstd450Asinh:
        case GLSLstd450Acosh:
        case GLSLstd450Atanh:

        case GLSLstd450Determinant:
        case GLSLstd450MatrixInverse:

        case GLSLstd450Modf:
        case GLSLstd450ModfStruct:
        case GLSLstd450IMix:

        case GLSLstd450Frexp:
        case GLSLstd450FrexpStruct:

        case GLSLstd450PackDouble2x32:
        case GLSLstd450UnpackDouble2x32:

        case GLSLstd450FindILsb:
        case GLSLstd450FindSMsb:
        case GLSLstd450FindUMsb:

        case GLSLstd450InterpolateAtCentroid:
        case GLSLstd450InterpolateAtSample:
        case GLSLstd450InterpolateAtOffset:

        // Anything else is unknown or invalid.
        default:
            return "";
    }
}
// Maps a SPIR-V opcode to the WGSL standard library builtin that implements
// it.
// @param opcode SPIR-V opcode
// @returns the builtin for the given instruction opcode, or
// sem::BuiltinType::kNone if there is no such builtin
sem::BuiltinType GetBuiltin(SpvOp opcode) {
    using Builtin = sem::BuiltinType;
    switch (opcode) {
        // Bit manipulation. Signed and unsigned extraction share one builtin.
        case SpvOpBitCount:
            return Builtin::kCountOneBits;
        case SpvOpBitFieldInsert:
            return Builtin::kInsertBits;
        case SpvOpBitFieldSExtract:
        case SpvOpBitFieldUExtract:
            return Builtin::kExtractBits;
        case SpvOpBitReverse:
            return Builtin::kReverseBits;

        case SpvOpDot:
            return Builtin::kDot;

        // Derivatives.
        case SpvOpDPdx:
            return Builtin::kDpdx;
        case SpvOpDPdy:
            return Builtin::kDpdy;
        case SpvOpFwidth:
            return Builtin::kFwidth;
        case SpvOpDPdxFine:
            return Builtin::kDpdxFine;
        case SpvOpDPdyFine:
            return Builtin::kDpdyFine;
        case SpvOpFwidthFine:
            return Builtin::kFwidthFine;
        case SpvOpDPdxCoarse:
            return Builtin::kDpdxCoarse;
        case SpvOpDPdyCoarse:
            return Builtin::kDpdyCoarse;
        case SpvOpFwidthCoarse:
            return Builtin::kFwidthCoarse;

        default:
            return Builtin::kNone;
    }
}
// Tests whether an opcode is an image access whose first input operand is an
// OpSampledImage value.
// @param opcode a SPIR-V opcode
// @returns true if the given instruction is such an image access instruction
bool IsSampledImageAccess(SpvOp opcode) {
    bool is_sampled_access = false;
    switch (opcode) {
        case SpvOpImageSampleImplicitLod:
        case SpvOpImageSampleExplicitLod:
        case SpvOpImageSampleDrefImplicitLod:
        case SpvOpImageSampleDrefExplicitLod:
        // WGSL doesn't have *Proj* texturing; spirv reader emulates it.
        case SpvOpImageSampleProjImplicitLod:
        case SpvOpImageSampleProjExplicitLod:
        case SpvOpImageSampleProjDrefImplicitLod:
        case SpvOpImageSampleProjDrefExplicitLod:
        case SpvOpImageGather:
        case SpvOpImageDrefGather:
        case SpvOpImageQueryLod:
            is_sampled_access = true;
            break;
        default:
            break;
    }
    return is_sampled_access;
}
// Tests whether an opcode is one of the SPIR-V atomic instructions.
// @param opcode a SPIR-V opcode
// @returns true if the given instruction is an atomic operation.
bool IsAtomicOp(SpvOp opcode) {
    switch (opcode) {
        case SpvOpAtomicLoad:
        case SpvOpAtomicStore:
        case SpvOpAtomicExchange:
        case SpvOpAtomicCompareExchange:
        case SpvOpAtomicCompareExchangeWeak:
        case SpvOpAtomicIIncrement:
        case SpvOpAtomicIDecrement:
        case SpvOpAtomicIAdd:
        case SpvOpAtomicISub:
        case SpvOpAtomicSMin:
        case SpvOpAtomicUMin:
        case SpvOpAtomicSMax:
        case SpvOpAtomicUMax:
        case SpvOpAtomicAnd:
        case SpvOpAtomicOr:
        case SpvOpAtomicXor:
        case SpvOpAtomicFlagTestAndSet:
        case SpvOpAtomicFlagClear:
        case SpvOpAtomicFMinEXT:
        case SpvOpAtomicFMaxEXT:
        case SpvOpAtomicFAddEXT:
            return true;
        default:
            return false;
    }
}
// Tests whether an opcode samples or gathers from an image.
// @param opcode a SPIR-V opcode
// @returns true if the given instruction is an image sampling, gather,
// or gather-compare operation.
bool IsImageSamplingOrGatherOrDrefGather(SpvOp opcode) {
    bool matches = false;
    switch (opcode) {
        case SpvOpImageSampleImplicitLod:
        case SpvOpImageSampleExplicitLod:
        case SpvOpImageSampleDrefImplicitLod:
        case SpvOpImageSampleDrefExplicitLod:
        // WGSL doesn't have *Proj* texturing; spirv reader emulates it.
        case SpvOpImageSampleProjImplicitLod:
        case SpvOpImageSampleProjExplicitLod:
        case SpvOpImageSampleProjDrefImplicitLod:
        case SpvOpImageSampleProjDrefExplicitLod:
        case SpvOpImageGather:
        case SpvOpImageDrefGather:
            matches = true;
            break;
        default:
            break;
    }
    return matches;
}
// Tests whether an opcode is an image access whose first input operand is an
// OpImage value.
// @param opcode a SPIR-V opcode
// @returns true if the given instruction is such an image access instruction
bool IsRawImageAccess(SpvOp opcode) {
    switch (opcode) {
        case SpvOpImageRead:
        case SpvOpImageWrite:
        case SpvOpImageFetch:
            return true;
        default:
            return false;
    }
}
// Tests whether an opcode queries a property of an image.
// @param opcode a SPIR-V opcode
// @returns true if the given instruction is an image query instruction
bool IsImageQuery(SpvOp opcode) {
    bool is_query = false;
    switch (opcode) {
        case SpvOpImageQuerySize:
        case SpvOpImageQuerySizeLod:
        case SpvOpImageQueryLevels:
        case SpvOpImageQuerySamples:
        case SpvOpImageQueryLod:
            is_query = true;
            break;
        default:
            break;
    }
    return is_query;
}
// Finds the merge block named by the block's merge instruction
// (OpSelectionMerge or OpLoopMerge).
// @param bb the basic block to inspect
// @returns the merge block ID for the given basic block, or 0 if there is none.
uint32_t MergeFor(const spvtools::opt::BasicBlock& bb) {
    if (const auto* merge_inst = bb.GetMergeInst()) {
        // The merge block ID is the first input operand.
        return merge_inst->GetSingleWordInOperand(0);
    }
    return 0;
}
// Finds the continue target named by the block's OpLoopMerge instruction.
// @param bb the basic block to inspect
// @returns the continue target ID for the given basic block, or 0 if there
// is none.
uint32_t ContinueTargetFor(const spvtools::opt::BasicBlock& bb) {
    if (const auto* loop_merge_inst = bb.GetLoopMergeInst()) {
        // The continue target ID is the second input operand.
        return loop_merge_inst->GetSingleWordInOperand(1);
    }
    return 0;
}
// A structured traverser produces the reverse structured post-order of the
// CFG of a function. The blocks traversed are the transitive closure (minimum
// fixed point) of:
//  - the entry block
//  - a block reached by a branch from another block in the set
//  - a block mentioned as a merge block or continue target for a block in the
//    set
class StructuredTraverser {
  public:
    // Constructor. Indexes every basic block of the function by its ID.
    // @param function the function whose CFG will be traversed
    explicit StructuredTraverser(const spvtools::opt::Function& function) : function_(function) {
        for (auto& bb : function_) {
            id_to_block_[bb.id()] = &bb;
        }
    }

    // Returns the reverse postorder traversal of the CFG, where:
    //  - a merge block always follows its associated constructs
    //  - a continue target always follows the associated loop construct, if any
    // Any state from a previous traversal is discarded first.
    // @returns the IDs of blocks in reverse structured post order
    std::vector<uint32_t> ReverseStructuredPostOrder() {
        visited_.clear();
        visit_order_.clear();
        VisitBackward(function_.entry()->id());
        return std::vector<uint32_t>(visit_order_.rbegin(), visit_order_.rend());
    }

  private:
    // Executes a depth first search of the CFG, where right after we visit a
    // header, we will visit its merge block, then its continue target (if any).
    // Also records the post order ordering.
    // @param id the ID of the block to visit; 0 means "no block"
    void VisitBackward(uint32_t id) {
        // Nothing to do for the null ID, or for a block seen before. The
        // insertion is only attempted for a non-zero ID.
        if (id == 0 || !visited_.insert(id).second) {
            return;
        }

        const spvtools::opt::BasicBlock* bb = id_to_block_[id];  // non-null for valid modules
        VisitBackward(MergeFor(*bb));
        VisitBackward(ContinueTargetFor(*bb));

        // Visit successors. We will naturally skip the continue target and merge
        // blocks.
        auto* terminator = bb->terminator();
        switch (terminator->opcode()) {
            case SpvOpBranchConditional:
                // Visit the false branch, then the true branch, to make them come
                // out in the natural order for an "if".
                VisitBackward(terminator->GetSingleWordInOperand(2));
                VisitBackward(terminator->GetSingleWordInOperand(1));
                break;
            case SpvOpBranch:
                VisitBackward(terminator->GetSingleWordInOperand(0));
                break;
            case SpvOpSwitch: {
                // TODO(dneto): Consider visiting the labels in literal-value order.
                // Collect the labels first, then recurse over the collected list.
                std::vector<uint32_t> successors;
                bb->ForEachSuccessorLabel(
                    [&successors](const uint32_t succ_id) { successors.push_back(succ_id); });
                for (auto succ_id : successors) {
                    VisitBackward(succ_id);
                }
                break;
            }
            default:
                // Other terminators contribute no successors here.
                break;
        }

        // Post-order: record the block only after everything reachable from it.
        visit_order_.push_back(id);
    }

    // The function being traversed.
    const spvtools::opt::Function& function_;
    // Maps a block ID to its basic block.
    std::unordered_map<uint32_t, const spvtools::opt::BasicBlock*> id_to_block_;
    // Block IDs in structured post order.
    std::vector<uint32_t> visit_order_;
    // IDs of blocks already visited.
    std::unordered_set<uint32_t> visited_;
};
/// A StatementBuilder for ast::SwitchStatement. Cases are accumulated in
/// `cases` while the function is being emitted, and the immutable switch
/// statement is created by Build().
/// @see StatementBuilder
struct SwitchStatementBuilder final : public Castable<SwitchStatementBuilder, StatementBuilder> {
    /// Constructor
    /// @param cond the switch statement condition
    explicit SwitchStatementBuilder(const ast::Expression* cond) : condition(cond) {}

    /// @param builder the program builder
    /// @returns the built ast::SwitchStatement
    const ast::SwitchStatement* Build(ProgramBuilder* builder) const override {
        // We've listed cases in reverse order in the switch statement.
        // Walk them back-to-front to match the presentation order in WGSL.
        ast::CaseStatementList presentation_order(cases.rbegin(), cases.rend());
        return builder->create<ast::SwitchStatement>(Source{}, condition, presentation_order);
    }

    /// Switch statement condition
    const ast::Expression* const condition;
    /// Switch statement cases, stored in reverse of presentation order
    ast::CaseStatementList cases;
};
/// A StatementBuilder for ast::IfStatement. The body and else statement are
/// filled in after construction; Build() then creates the immutable node.
/// @see StatementBuilder
struct IfStatementBuilder final : public Castable<IfStatementBuilder, StatementBuilder> {
    /// Constructor
    /// @param c the if-statement condition
    explicit IfStatementBuilder(const ast::Expression* c) : cond(c) {}

    /// @param builder the program builder
    /// @returns the built ast::IfStatement
    const ast::IfStatement* Build(ProgramBuilder* builder) const override {
        const ast::IfStatement* if_stmt =
            builder->create<ast::IfStatement>(Source{}, cond, body, else_stmt);
        return if_stmt;
    }

    /// If-statement condition
    const ast::Expression* const cond;
    /// If-statement block body; null until assigned
    const ast::BlockStatement* body = nullptr;
    /// Optional if-statement else statement; null when absent
    const ast::Statement* else_stmt = nullptr;
};
/// A StatementBuilder for ast::LoopStatement. The body and continuing blocks
/// are filled in after construction; Build() then creates the loop node.
/// @see StatementBuilder
struct LoopStatementBuilder final : public Castable<LoopStatementBuilder, StatementBuilder> {
    /// @param builder the program builder
    /// @returns the built ast::LoopStatement
    ast::LoopStatement* Build(ProgramBuilder* builder) const override {
        auto* loop_stmt = builder->create<ast::LoopStatement>(Source{}, body, continuing);
        return loop_stmt;
    }

    /// Loop-statement block body; null until assigned
    const ast::BlockStatement* body = nullptr;
    /// Loop-statement continuing body; null until assigned
    /// @note the mutable keyword here is required as all non-StatementBuilders
    /// `ast::Node`s are immutable and are referenced with `const` pointers.
    /// StatementBuilders however exist to provide mutable state while the
    /// FunctionEmitter is building the function. All StatementBuilders are
    /// replaced with immutable AST nodes when Finalize() is called.
    mutable const ast::BlockStatement* continuing = nullptr;
};
/// Tests whether the builtin attribute found in a decoration list, if any,
/// is the SampleMask builtin.
/// @param decos a list of parsed decorations
/// @returns true if the decorations include a SampleMask builtin
bool HasBuiltinSampleMask(const ast::AttributeList& decos) {
    const auto* builtin_attr = ast::GetAttribute<ast::BuiltinAttribute>(decos);
    return builtin_attr != nullptr && builtin_attr->builtin == ast::Builtin::kSampleMask;
}
} // namespace
// Constructor. Records the address of the given basic block and caches its ID.
// @param bb the basic block this info describes
BlockInfo::BlockInfo(const spvtools::opt::BasicBlock& bb) : basic_block(&bb), id(bb.id()) {}
// Destructor.
BlockInfo::~BlockInfo() = default;
// Constructor. Stores each argument in the member of the matching name.
// @param def_inst the defining instruction; stored in `inst`
// @param the_locally_defined stored in `locally_defined`
// @param the_block_pos stored in `block_pos`
// NOTE(review): presumably the position of the defining block in the block
// order — confirm against the header.
// @param the_index stored in `index`
DefInfo::DefInfo(const spvtools::opt::Instruction& def_inst,
bool the_locally_defined,
uint32_t the_block_pos,
size_t the_index)
: inst(def_inst),
locally_defined(the_locally_defined),
block_pos(the_block_pos),
index(the_index) {}
// Destructor.
DefInfo::~DefInfo() = default;
// Cloning is not implemented for statement builders: this always returns
// nullptr and ignores the clone context.
// @returns nullptr
ast::Node* StatementBuilder::Clone(CloneContext*) const {
return nullptr;
}
// Constructor. Caches the parser's type manager, program builder, IR context
// (and the managers obtained from it), failure stream and namer, and starts
// with both sample-mask IDs set to zero. The statement stack is seeded with
// one block that has no construct, end ID 0, and no completion action.
// @param pi the parser implementation; must be non-null since it is
// dereferenced here
// @param function the SPIR-V function to emit
// @param ep_info entry point information, or nullptr
FunctionEmitter::FunctionEmitter(ParserImpl* pi,
const spvtools::opt::Function& function,
const EntryPointInfo* ep_info)
: parser_impl_(*pi),
ty_(pi->type_manager()),
builder_(pi->builder()),
ir_context_(*(pi->ir_context())),
def_use_mgr_(ir_context_.get_def_use_mgr()),
constant_mgr_(ir_context_.get_constant_mgr()),
type_mgr_(ir_context_.get_type_mgr()),
fail_stream_(pi->fail_stream()),
namer_(pi->namer()),
function_(function),
sample_mask_in_id(0u),
sample_mask_out_id(0u),
ep_info_(ep_info) {
PushNewStatementBlock(nullptr, 0, nullptr);
}
// Constructor for a function with no entry point information; delegates to
// the three-argument constructor with a null ep_info.
// @param pi the parser implementation
// @param function the SPIR-V function to emit
FunctionEmitter::FunctionEmitter(ParserImpl* pi, const spvtools::opt::Function& function)
: FunctionEmitter(pi, function, nullptr) {}
// Move constructor. Takes over the parser, builder, IR context, failure
// stream, namer, function, sample-mask IDs and entry point info from |other|.
// The managers are re-obtained from this object's own IR context. Statement
// blocks are not transferred: |other|'s statement stack is cleared, and this
// emitter starts with a fresh stack holding one block that has no construct,
// end ID 0, and no completion action.
// @param other the emitter to move from
FunctionEmitter::FunctionEmitter(FunctionEmitter&& other)
    : parser_impl_(other.parser_impl_),
      ty_(other.ty_),
      builder_(other.builder_),
      ir_context_(other.ir_context_),
      def_use_mgr_(ir_context_.get_def_use_mgr()),
      constant_mgr_(ir_context_.get_constant_mgr()),
      type_mgr_(ir_context_.get_type_mgr()),
      fail_stream_(other.fail_stream_),
      namer_(other.namer_),
      function_(other.function_),
      // Each sample-mask ID comes from the member of the same name. These
      // were previously crossed, swapping input and output on every move.
      sample_mask_in_id(other.sample_mask_in_id),
      sample_mask_out_id(other.sample_mask_out_id),
      ep_info_(other.ep_info_) {
    other.statements_stack_.clear();
    PushNewStatementBlock(nullptr, 0, nullptr);
}
// Destructor.
FunctionEmitter::~FunctionEmitter() = default;
// Constructor. Stores each argument in the corresponding member.
// @param construct the construct associated with this block; callers in this
// file also pass nullptr
// @param end_id the end ID for this block
// NOTE(review): presumably the ID of the block at which this statement block
// ends — confirm against the header.
// @param completion_action invoked with the statement list by Finalize();
// may be null
FunctionEmitter::StatementBlock::StatementBlock(const Construct* construct,
uint32_t end_id,
FunctionEmitter::CompletionAction completion_action)
: construct_(construct), end_id_(end_id), completion_action_(completion_action) {}
// Move constructor.
FunctionEmitter::StatementBlock::StatementBlock(StatementBlock&& other) = default;
// Destructor.
FunctionEmitter::StatementBlock::~StatementBlock() = default;
// Replaces every StatementBuilder in the block with the AST node it builds,
// runs the completion action (if any) on the resulting statement list, and
// marks the block as finalized. Must be called at most once.
// @param pb the program builder used to build the AST nodes
void FunctionEmitter::StatementBlock::Finalize(ProgramBuilder* pb) {
    TINT_ASSERT(Reader, !finalized_ /* Finalize() must only be called once */);

    for (auto& stmt : statements_) {
        auto* statement_builder = stmt->As<StatementBuilder>();
        if (statement_builder) {
            stmt = statement_builder->Build(pb);
        }
    }

    if (completion_action_) {
        completion_action_(statements_);
    }

    finalized_ = true;
}
// Appends a statement to the end of this block. The block must not have been
// finalized yet.
// @param statement the statement to append
void FunctionEmitter::StatementBlock::Add(const ast::Statement* statement) {
    TINT_ASSERT(Reader, !finalized_ /* Add() must not be called after Finalize() */);
    statements_.push_back(statement);
}
// Pushes a new, empty statement block onto the statement stack.
// @param construct the construct associated with the new block
// @param end_id the end ID for the new block
// @param action the completion action to run when the block is finalized
void FunctionEmitter::PushNewStatementBlock(const Construct* construct,
                                            uint32_t end_id,
                                            CompletionAction action) {
    statements_stack_.push_back(StatementBlock{construct, end_id, action});
}
void FunctionEmitter::PushGuard(const std::string& guard_name, uint32_t end_id) {
TINT_ASSERT(Reader, !statements_stack_.empty());
TINT_ASSERT(Reader, !guard_name.empty());
// Guard control flow by the guard variable. Introduce a new
// if-selection with a then-clause ending at the same block
// as the statement block at the top of the stack.
const auto& top = statements_stack_.back();
auto* cond =
create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(guard_name));
auto* builder = AddStatementBuilder<IfStatementBuilder>(cond);
PushNewStatementBlock(top.GetConstruct(), end_id, [=](const ast::StatementList& stmts) {
builder->body = create<ast::BlockStatement>(Source{}, stmts);
});
}
void FunctionEmitter::PushTrueGuard(uint32_t end_id) {
TINT_ASSERT(Reader, !statements_stack_.empty());
const auto& top = statements_stack_.back();
auto* cond = MakeTrue(Source{});
auto* builder = AddStatementBuilder<IfStatementBuilder>(cond);
PushNewStatementBlock(top.GetConstruct(), end_id, [=](const ast::StatementList& stmts) {
builder->body = create<ast::BlockStatement>(Source{}, stmts);
});
}
// Finalizes the bottom block of the statement stack and returns a copy of
// its statements.
// @returns the statements of the bottom statement block
const ast::StatementList FunctionEmitter::ast_body() {
    TINT_ASSERT(Reader, !statements_stack_.empty());
    StatementBlock& root_block = statements_stack_.front();
    root_block.Finalize(&builder_);
    return root_block.GetStatements();
}
// Appends a statement to the statement block on top of the stack. A null
// statement is ignored.
// @param statement the statement to add, or nullptr
// @returns the statement that was passed in
const ast::Statement* FunctionEmitter::AddStatement(const ast::Statement* statement) {
    TINT_ASSERT(Reader, !statements_stack_.empty());
    if (statement == nullptr) {
        return statement;
    }
    statements_stack_.back().Add(statement);
    return statement;
}
// Returns the most recently added statement of the statement block on top of
// the stack. Both the stack and that block must be non-empty.
// @returns the last statement in the top statement block
const ast::Statement* FunctionEmitter::LastStatement() {
    TINT_ASSERT(Reader, !statements_stack_.empty());
    auto& current_statements = statements_stack_.back().GetStatements();
    TINT_ASSERT(Reader, !current_statements.empty());
    return current_statements.back();
}
// Emits the function to the AST module.
// A function with no basic blocks produces nothing. When entry point
// information is present, the body is emitted under the entry point's inner
// name (only if this entry point owns the inner implementation), and the
// entry point itself is then emitted as a wrapper.
// @returns false if emission failed
bool FunctionEmitter::Emit() {
    if (failed()) {
        return false;
    }

    // We only care about functions with bodies.
    const bool has_body = function_.cbegin() != function_.cend();
    if (!has_body) {
        return true;
    }

    // The function declaration, corresponding to how it's written in SPIR-V,
    // and without regard to whether it's an entry point.
    FunctionDeclaration decl;
    if (!ParseFunctionDeclaration(&decl)) {
        return false;
    }

    bool emit_body_function = true;
    if (ep_info_) {
        TINT_ASSERT(Reader, !ep_info_->inner_name.empty());
        // If this entry point owns the inner implementation, emit it as a
        // wrapper around an implementation function. Otherwise this is a
        // second entry point that shares an inner implementation function,
        // and the body has already been taken care of.
        emit_body_function = ep_info_->owns_inner_implementation;
        if (emit_body_function) {
            decl.name = ep_info_->inner_name;
        }
    }

    if (emit_body_function) {
        const auto* body = MakeFunctionBody();
        if (body == nullptr) {
            return false;
        }
        builder_.AST().AddFunction(
            create<ast::Function>(decl.source, builder_.Symbols().Register(decl.name),
                                  std::move(decl.params), decl.return_type->Build(builder_), body,
                                  std::move(decl.attributes), ast::AttributeList{}));
    }

    const bool needs_wrapper = ep_info_ && !ep_info_->inner_name.empty();
    return needs_wrapper ? EmitEntryPointAsWrapper() : success();
}
// Emits the function body and wraps the resulting statements in a block
// statement. The statement stack must hold exactly one block on entry, and
// holds exactly one (new, empty) block again on successful return.
// @returns the body block, or nullptr on failure
const ast::BlockStatement* FunctionEmitter::MakeFunctionBody() {
    TINT_ASSERT(Reader, statements_stack_.size() == 1);

    if (!EmitBody()) {
        return nullptr;
    }

    // Set the body of the AST function node.
    const auto stack_depth = statements_stack_.size();
    if (stack_depth != 1) {
        Fail() << "internal error: statement-list stack should have 1 "
                  "element but has "
               << stack_depth;
        return nullptr;
    }

    auto& root_block = statements_stack_.front();
    root_block.Finalize(&builder_);
    auto* body = create<ast::BlockStatement>(Source{}, root_block.GetStatements());

    // Maintain the invariant by repopulating the one and only element.
    statements_stack_.clear();
    PushNewStatementBlock(constructs_[0].get(), 0, nullptr);

    return body;
}
// Emits the function parameter(s) and copy-in statement(s) for one pipeline
// input variable. Matrices, arrays and structures are flattened recursively;
// each leaf component gets its own parameter and an assignment from that
// parameter into the matching component of the variable.
// @param var_name the name of the variable that receives the input values
// @param var_type the type of that variable; aliases and a reference wrapper
// are stripped before it is indexed
// @param attrs the attributes to attach to generated parameters; location
// attributes in this list are advanced by one after each leaf is emitted
// @param index_prefix the indices leading from the variable down to the
// component currently being processed
// @param tip_type the type of the component selected by index_prefix
// @param forced_param_type the parameter type to use when the attributes
// include a builtin attribute
// @param params the list to which generated parameters are appended
// @param statements the list to which generated assignments are appended
// @returns false on failure; otherwise the result of success()
bool FunctionEmitter::EmitPipelineInput(std::string var_name,
const Type* var_type,
ast::AttributeList* attrs,
std::vector<int> index_prefix,
const Type* tip_type,
const Type* forced_param_type,
ast::ParameterList* params,
ast::StatementList* statements) {
// TODO(dneto): Handle structs where the locations are annotated on members.
// Strip any alias and reference wrapper so the Switch below dispatches on
// the underlying type.
tip_type = tip_type->UnwrapAlias();
if (auto* ref_type = tip_type->As<Reference>()) {
tip_type = ref_type->type;
}
// Recursively flatten matrices, arrays, and structures.
return Switch(
tip_type,
[&](const Matrix* matrix_type) -> bool {
// Add one index level; its value is overwritten for each column.
index_prefix.push_back(0);
const auto num_columns = static_cast<int>(matrix_type->columns);
// Each column is processed as a vector with `rows` components.
const Type* vec_ty = ty_.Vector(matrix_type->type, matrix_type->rows);
for (int col = 0; col < num_columns; col++) {
index_prefix.back() = col;
if (!EmitPipelineInput(var_name, var_type, attrs, index_prefix, vec_ty,
forced_param_type, params, statements)) {
return false;
}
}
return success();
},
[&](const Array* array_type) -> bool {
// A size of zero denotes a runtime-sized array, which is rejected.
if (array_type->size == 0) {
return Fail() << "runtime-size array not allowed on pipeline IO";
}
// Add one index level; its value is overwritten for each element.
index_prefix.push_back(0);
const Type* elem_ty = array_type->type;
for (int i = 0; i < static_cast<int>(array_type->size); i++) {
index_prefix.back() = i;
if (!EmitPipelineInput(var_name, var_type, attrs, index_prefix, elem_ty,
forced_param_type, params, statements)) {
return false;
}
}
return success();
},
[&](const Struct* struct_type) -> bool {
const auto& members = struct_type->members;
// Add one index level; its value is overwritten for each member.
index_prefix.push_back(0);
for (size_t i = 0; i < members.size(); ++i) {
index_prefix.back() = static_cast<int>(i);
// Each member starts from a copy of the incoming attributes, to
// which the member's own pipeline decorations are then applied.
ast::AttributeList member_attrs(*attrs);
if (!parser_impl_.ConvertPipelineDecorations(
struct_type,
parser_impl_.GetMemberPipelineDecorations(*struct_type,
static_cast<int>(i)),
&member_attrs)) {
return false;
}
if (!EmitPipelineInput(var_name, var_type, &member_attrs, index_prefix, members[i],
forced_param_type, params, statements)) {
return false;
}
// Copy the location as updated by nested expansion of the member.
parser_impl_.SetLocation(attrs, GetLocation(member_attrs));
}
return success();
},
[&](Default) {
// Leaf case: any type other than a matrix, array or structure.
const bool is_builtin = ast::HasAttribute<ast::BuiltinAttribute>(*attrs);
// Builtins use the forced parameter type; everything else uses the
// type of the component itself.
const Type* param_type = is_builtin ? forced_param_type : tip_type;
const auto param_name = namer_.MakeDerivedName(var_name + "_param");
// Create the parameter.
// TODO(dneto): Note: If the parameter has non-location decorations,
// then those decoration AST nodes will be reused between multiple
// elements of a matrix, array, or structure. Normally that's
// disallowed but currently the SPIR-V reader will make duplicates when
// the entire AST is cloned at the top level of the SPIR-V reader flow.
// Consider rewriting this to avoid this node-sharing.
params->push_back(builder_.Param(param_name, param_type->Build(builder_), *attrs));
// Add a body statement to copy the parameter to the corresponding
// private variable.
const ast::Expression* param_value = builder_.Expr(param_name);
const ast::Expression* store_dest = builder_.Expr(var_name);
// Index into the LHS as needed.
auto* current_type = var_type->UnwrapAlias()->UnwrapRef()->UnwrapAlias();
for (auto index : index_prefix) {
// Matrices and arrays are indexed; structures are accessed by
// member name. Any other type leaves store_dest unchanged.
Switch(
current_type,
[&](const Matrix* matrix_type) {
store_dest = builder_.IndexAccessor(store_dest, builder_.Expr(i32(index)));
current_type = ty_.Vector(matrix_type->type, matrix_type->rows);
},
[&](const Array* array_type) {
store_dest = builder_.IndexAccessor(store_dest, builder_.Expr(i32(index)));
current_type = array_type->type->UnwrapAlias();
},
[&](const Struct* struct_type) {
store_dest = builder_.MemberAccessor(
store_dest,
builder_.Expr(parser_impl_.GetMemberName(*struct_type, index)));
current_type = struct_type->members[static_cast<size_t>(index)];
});
}
if (is_builtin && (tip_type != forced_param_type)) {
// The parameter will have the WGSL type, but we need bitcast to
// the variable store type.
param_value =
create<ast::BitcastExpression>(tip_type->Build(builder_), param_value);
}
statements->push_back(builder_.Assign(store_dest, param_value));
// Increment the location attribute, in case more parameters will
// follow.
IncrementLocation(attrs);
return success();
});
}
// Advances every location attribute found in |attributes| to the next
// location index. Attributes of any other kind are left untouched.
void FunctionEmitter::IncrementLocation(ast::AttributeList* attributes) {
    for (auto*& slot : *attributes) {
        auto* location = slot->As<ast::LocationAttribute>();
        if (location == nullptr) {
            continue;
        }
        // Swap in a fresh attribute whose index is one greater. The attribute
        // being replaced doesn't leak because it's kept in the builder's AST
        // node list.
        slot = builder_.Location(location->source, location->value + 1);
    }
}
// Returns the first location attribute in |attributes|, or nullptr when the
// list holds no location attribute.
const ast::Attribute* FunctionEmitter::GetLocation(const ast::AttributeList& attributes) {
    const auto found =
        std::find_if(attributes.begin(), attributes.end(),
                     [](const auto* candidate) { return candidate->template Is<ast::LocationAttribute>(); });
    if (found == attributes.end()) {
        return nullptr;
    }
    return *found;
}
// Emits the return-struct members and return-value expressions for the part
// of the variable |var_name| that is selected by |index_prefix|.
//
// Matrices, arrays and structures are flattened recursively: a matrix is
// expanded column by column, an array element by element, and a structure
// member by member. Every leaf reached this way contributes one entry to
// |return_members| and one matching entry to |return_exprs|.
//
// Parameters:
//   var_name           - name of the variable whose value is read.
//   var_type           - type of the whole variable; used to walk
//                        |index_prefix| from the variable down to the leaf.
//   decos              - attributes attached to each emitted member. Its
//                        location attribute (if any) is incremented after
//                        every emitted leaf.
//   index_prefix       - indices leading from the variable to the current part.
//   tip_type           - type of the current part.
//   forced_member_type - type given to the emitted member when |decos|
//                        contains a builtin attribute.
//   return_members     - receives the generated structure members.
//   return_exprs       - receives the expressions producing member values.
//
// Returns false if a nested emission or decoration conversion fails, or if a
// runtime-sized array is encountered; otherwise returns success().
bool FunctionEmitter::EmitPipelineOutput(std::string var_name,
const Type* var_type,
ast::AttributeList* decos,
std::vector<int> index_prefix,
const Type* tip_type,
const Type* forced_member_type,
ast::StructMemberList* return_members,
ast::ExpressionList* return_exprs) {
// Look through an alias and then a reference to find the type to dispatch on.
tip_type = tip_type->UnwrapAlias();
if (auto* ref_type = tip_type->As<Reference>()) {
tip_type = ref_type->type;
}
// Recursively flatten matrices, arrays, and structures.
return Switch(
tip_type,
[&](const Matrix* matrix_type) {
// Emit one column vector at a time.
index_prefix.push_back(0);
const auto num_columns = static_cast<int>(matrix_type->columns);
const Type* vec_ty = ty_.Vector(matrix_type->type, matrix_type->rows);
for (int col = 0; col < num_columns; col++) {
index_prefix.back() = col;
if (!EmitPipelineOutput(var_name, var_type, decos, index_prefix, vec_ty,
forced_member_type, return_members, return_exprs)) {
return false;
}
}
return success();
},
[&](const Array* array_type) -> bool {
if (array_type->size == 0) {
return Fail() << "runtime-size array not allowed on pipeline IO";
}
// Emit one array element at a time.
index_prefix.push_back(0);
const Type* elem_ty = array_type->type;
for (int i = 0; i < static_cast<int>(array_type->size); i++) {
index_prefix.back() = i;
if (!EmitPipelineOutput(var_name, var_type, decos, index_prefix, elem_ty,
forced_member_type, return_members, return_exprs)) {
return false;
}
}
return success();
},
[&](const Struct* struct_type) -> bool {
const auto& members = struct_type->members;
index_prefix.push_back(0);
for (int i = 0; i < static_cast<int>(members.size()); ++i) {
index_prefix.back() = i;
// Start from a copy of the incoming attributes, then let the member's own
// pipeline decorations be converted into that copy.
ast::AttributeList member_attrs(*decos);
if (!parser_impl_.ConvertPipelineDecorations(
struct_type, parser_impl_.GetMemberPipelineDecorations(*struct_type, i),
&member_attrs)) {
return false;
}
if (!EmitPipelineOutput(var_name, var_type, &member_attrs, index_prefix,
members[static_cast<size_t>(i)], forced_member_type,
return_members, return_exprs)) {
return false;
}
// Copy the location as updated by nested expansion of the member.
parser_impl_.SetLocation(decos, GetLocation(member_attrs));
}
return success();
},
[&](Default) {
// Leaf case: emit exactly one structure member and one value expression.
const bool is_builtin = ast::HasAttribute<ast::BuiltinAttribute>(*decos);
const Type* member_type = is_builtin ? forced_member_type : tip_type;
// Derive the member name directly from the variable name. They can't
// collide.
const auto member_name = namer_.MakeDerivedName(var_name);
// Create the member.
// TODO(dneto): Note: If the member has non-location decorations,
// then those decoration AST nodes will be reused between multiple
// elements of a matrix, array, or structure. Normally that's
// disallowed but currently the SPIR-V reader will make duplicates when
// the entire AST is cloned at the top level of the SPIR-V reader flow.
// Consider rewriting this to avoid this node-sharing.
return_members->push_back(
builder_.Member(member_name, member_type->Build(builder_), *decos));
// Create an expression to evaluate the part of the variable indexed by
// the index_prefix.
const ast::Expression* load_source = builder_.Expr(var_name);
// Index into the variable as needed to pick out the flattened member.
auto* current_type = var_type->UnwrapAlias()->UnwrapRef()->UnwrapAlias();
for (auto index : index_prefix) {
Switch(
current_type,
[&](const Matrix* matrix_type) {
load_source =
builder_.IndexAccessor(load_source, builder_.Expr(i32(index)));
current_type = ty_.Vector(matrix_type->type, matrix_type->rows);
},
[&](const Array* array_type) {
load_source =
builder_.IndexAccessor(load_source, builder_.Expr(i32(index)));
current_type = array_type->type->UnwrapAlias();
},
[&](const Struct* struct_type) {
load_source = builder_.MemberAccessor(
load_source,
builder_.Expr(parser_impl_.GetMemberName(*struct_type, index)));
current_type = struct_type->members[static_cast<size_t>(index)];
});
}
if (is_builtin && (tip_type != forced_member_type)) {
// The member will have the WGSL type, but we need bitcast to
// the variable store type.
load_source = create<ast::BitcastExpression>(forced_member_type->Build(builder_),
load_source);
}
return_exprs->push_back(load_source);
// Increment the location attribute, in case more members will
// follow.
IncrementLocation(decos);
return success();
});
}
// Emits the entry point described by |ep_info_| as a wrapper function.
//
// The wrapper:
//   1. takes one parameter per flattened pipeline input and copies each
//      parameter into the corresponding variable,
//   2. calls the inner function named by ep_info_->inner_name, which takes no
//      arguments,
//   3. returns a structure built from the flattened pipeline outputs, or
//      returns void when there is nothing to return.
//
// The wrapper is given a stage attribute and, for compute shaders with a
// fully non-zero workgroup size, a workgroup attribute. It is then added to
// the builder's AST.
//
// Returns false if decoration conversion or pipeline IO emission fails;
// otherwise returns true.
bool FunctionEmitter::EmitEntryPointAsWrapper() {
Source source;
// The statements in the body.
ast::StatementList stmts;
FunctionDeclaration decl;
decl.source = source;
decl.name = ep_info_->name;
const ast::Type* return_type = nullptr; // Populated below.
// Pipeline inputs become parameters to the wrapper function, and
// their values are saved into the corresponding private variables that
// have already been created.
for (uint32_t var_id : ep_info_->inputs) {
const auto* var = def_use_mgr_->GetDef(var_id);
TINT_ASSERT(Reader, var != nullptr);
TINT_ASSERT(Reader, var->opcode() == SpvOpVariable);
auto* store_type = GetVariableStoreType(*var);
// ConvertDecorationsForVariable receives the address of this type, so it may
// replace it with a different type.
auto* forced_param_type = store_type;
ast::AttributeList param_decos;
if (!parser_impl_.ConvertDecorationsForVariable(var_id, &forced_param_type, &param_decos,
true)) {
// This occurs, and is not an error, for the PointSize builtin.
if (!success()) {
// But exit early if an error was logged.
return false;
}
continue;
}
// We don't have to handle initializers because in Vulkan SPIR-V, Input
// variables must not have them.
const auto var_name = namer_.GetName(var_id);
bool ok = true;
if (HasBuiltinSampleMask(param_decos)) {
// In Vulkan SPIR-V, the sample mask is an array. In WGSL it's a scalar.
// Use the first element only.
auto* sample_mask_array_type = store_type->UnwrapRef()->UnwrapAlias()->As<Array>();
TINT_ASSERT(Reader, sample_mask_array_type);
ok = EmitPipelineInput(var_name, store_type, &param_decos, {0},
sample_mask_array_type->type, forced_param_type, &decl.params,
&stmts);
} else {
// The normal path.
ok = EmitPipelineInput(var_name, store_type, &param_decos, {}, store_type,
forced_param_type, &decl.params, &stmts);
}
if (!ok) {
return false;
}
}
// Call the inner function. It has no parameters.
stmts.push_back(create<ast::CallStatement>(
source,
create<ast::CallExpression>(source,
create<ast::IdentifierExpression>(
source, builder_.Symbols().Register(ep_info_->inner_name)),
ast::ExpressionList{})));
// Pipeline outputs are mapped to the return value.
if (ep_info_->outputs.empty()) {
// There is nothing to return.
return_type = ty_.Void()->Build(builder_);
} else {
// Pipeline outputs are converted to a structure that is written
// to just before returning.
const auto return_struct_name = namer_.MakeDerivedName(ep_info_->name + "_out");
const auto return_struct_sym = builder_.Symbols().Register(return_struct_name);
// Define the structure.
// |return_members| and |return_exprs| are filled in lockstep: each member has
// one expression that supplies its value in the struct constructor below.
std::vector<const ast::StructMember*> return_members;
ast::ExpressionList return_exprs;
const auto& builtin_position_info = parser_impl_.GetBuiltInPositionInfo();
for (uint32_t var_id : ep_info_->outputs) {
if (var_id == builtin_position_info.per_vertex_var_id) {
// The SPIR-V gl_PerVertex variable has already been remapped to
// a gl_Position variable. Substitute the type.
const Type* param_type = ty_.Vector(ty_.F32(), 4);
ast::AttributeList out_decos{
create<ast::BuiltinAttribute>(source, ast::Builtin::kPosition)};
const auto var_name = namer_.GetName(var_id);
return_members.push_back(
builder_.Member(var_name, param_type->Build(builder_), out_decos));
return_exprs.push_back(builder_.Expr(var_name));
} else {
const auto* var = def_use_mgr_->GetDef(var_id);
TINT_ASSERT(Reader, var != nullptr);
TINT_ASSERT(Reader, var->opcode() == SpvOpVariable);
const Type* store_type = GetVariableStoreType(*var);
const Type* forced_member_type = store_type;
ast::AttributeList out_decos;
if (!parser_impl_.ConvertDecorationsForVariable(var_id, &forced_member_type,
&out_decos, true)) {
// This occurs, and is not an error, for the PointSize builtin.
if (!success()) {
// But exit early if an error was logged.
return false;
}
continue;
}
const auto var_name = namer_.GetName(var_id);
bool ok = true;
if (HasBuiltinSampleMask(out_decos)) {
// In Vulkan SPIR-V, the sample mask is an array. In WGSL it's a
// scalar. Use the first element only.
auto* sample_mask_array_type =
store_type->UnwrapRef()->UnwrapAlias()->As<Array>();
TINT_ASSERT(Reader, sample_mask_array_type);
ok = EmitPipelineOutput(var_name, store_type, &out_decos, {0},
sample_mask_array_type->type, forced_member_type,
&return_members, &return_exprs);
} else {
// The normal path.
ok = EmitPipelineOutput(var_name, store_type, &out_decos, {}, store_type,
forced_member_type, &return_members, &return_exprs);
}
if (!ok) {
return false;
}
}
}
if (return_members.empty()) {
// This can occur if only the PointSize member is accessed, because we
// never emit it.
return_type = ty_.Void()->Build(builder_);
} else {
// Create and register the result type.
auto* str = create<ast::Struct>(Source{}, return_struct_sym, return_members,
ast::AttributeList{});
parser_impl_.AddTypeDecl(return_struct_sym, str);
return_type = builder_.ty.Of(str);
// Add the return-value statement.
stmts.push_back(create<ast::ReturnStatement>(
source, builder_.Construct(source, return_type, std::move(return_exprs))));
}
}
auto* body = create<ast::BlockStatement>(source, stmts);
ast::AttributeList fn_attrs;
fn_attrs.emplace_back(create<ast::StageAttribute>(source, ep_info_->stage));
if (ep_info_->stage == ast::PipelineStage::kCompute) {
auto& size = ep_info_->workgroup_size;
// Only emit a workgroup attribute when all three dimensions are non-zero.
if (size.x != 0 && size.y != 0 && size.z != 0) {
const ast::Expression* x = builder_.Expr(i32(size.x));
// NOTE: size.y and size.z were just checked to be non-zero, so the two
// conditionals below always select their first operand.
const ast::Expression* y = size.y ? builder_.Expr(i32(size.y)) : nullptr;
const ast::Expression* z = size.z ? builder_.Expr(i32(size.z)) : nullptr;
fn_attrs.emplace_back(create<ast::WorkgroupAttribute>(Source{}, x, y, z));
}
}
// Register the wrapper under the entry point's name.
builder_.AST().AddFunction(create<ast::Function>(
source, builder_.Symbols().Register(ep_info_->name), std::move(decl.params), return_type,
body, std::move(fn_attrs), ast::AttributeList{}));
return true;
}
// Fills in |decl| with the name, parameters and return type of the function
// being emitted, and clears its attribute list.
// Returns false if a failure has been recorded, including when the return
// type has not been registered; otherwise returns success().
bool FunctionEmitter::ParseFunctionDeclaration(FunctionDeclaration* decl) {
    if (failed()) {
        return false;
    }

    const std::string fn_name = namer_.Name(function_.result_id());

    // Surprisingly, the "type id" on an OpFunction is the result type of the
    // function, not the type of the function. This is the one exceptional case
    // in SPIR-V where the type ID is not the type of the result ID.
    auto* return_ty = parser_impl_.ConvertType(function_.type_id());
    if (failed()) {
        return false;
    }
    if (return_ty == nullptr) {
        return Fail() << "internal error: unregistered return type for function with ID "
                      << function_.result_id();
    }

    ast::ParameterList params;
    function_.ForEachParam([this, &params](const spvtools::opt::Instruction* param_inst) {
        auto* param_ty = parser_impl_.ConvertType(param_inst->type_id());
        if (param_ty == nullptr) {
            // We've already logged an error and emitted a diagnostic. Do nothing
            // here.
            return;
        }
        // Parameters are treated as const declarations.
        params.emplace_back(
            parser_impl_.MakeParameter(param_inst->result_id(), param_ty, ast::AttributeList{}));
        // The value is accessible by name.
        identifier_types_.emplace(param_inst->result_id(), param_ty);
    });
    if (failed()) {
        return false;
    }

    decl->name = fn_name;
    decl->params = std::move(params);
    decl->return_type = return_ty;
    decl->attributes.clear();
    return success();
}
// Returns the converted store type of the variable declared by
// |var_decl_inst|, found by reading the variable's pointer type instruction.
const Type* FunctionEmitter::GetVariableStoreType(const spvtools::opt::Instruction& var_decl_inst) {
    // Normally we use the SPIRV-Tools optimizer to manage types.
    // But when two struct types have the same member types and decorations,
    // but differ only in member names, the two struct types will be
    // represented by a single common internal struct type.
    // So avoid the optimizer's representation and instead follow the
    // SPIR-V instructions themselves.
    const auto* pointer_type_inst = def_use_mgr_->GetDef(var_decl_inst.type_id());
    // In-operand 1 of the pointer type instruction is the ID of the store type.
    const auto pointee_type_id = pointer_type_inst->GetSingleWordInOperand(1);
    return parser_impl_.ConvertType(pointee_type_id);
}
// Runs the function-body pipeline: registers blocks, validates and analyzes
// the control flow graph, then emits variables and statements. The passes run
// in a fixed order and the first failing pass stops the pipeline.
// Returns false as soon as any pass reports failure; otherwise success().
bool FunctionEmitter::EmitBody() {
    RegisterBasicBlocks();

    // Basic validation; required before a block order can be computed.
    if (!TerminatorsAreValid() || !RegisterMerges()) {
        return false;
    }

    ComputeBlockOrderAndPositions();

    // Structural analysis. Short-circuit evaluation keeps the original order
    // and stops at the first pass that fails.
    const bool analysis_ok = VerifyHeaderContinueMergeOrder() &&  //
                             LabelControlFlowConstructs() &&      //
                             FindSwitchCaseHeaders() &&           //
                             ClassifyCFGEdges() &&                //
                             FindIfSelectionInternalHeaders() &&  //
                             RegisterSpecialBuiltInVariables() && //
                             RegisterLocallyDefinedValues();
    if (!analysis_ok) {
        return false;
    }

    FindValuesNeedingNamedOrHoistedDefinition();

    // Code generation.
    if (!EmitFunctionVariables() || !EmitFunctionBodyStatements()) {
        return false;
    }
    return success();
}
// Creates a BlockInfo record for every basic block of the function and stores
// it in |block_info_|, keyed by the block's ID.
void FunctionEmitter::RegisterBasicBlocks() {
    for (auto& basic_block : function_) {
        const auto id = basic_block.id();
        block_info_[id] = std::make_unique<BlockInfo>(basic_block);
    }
}
// Checks that every block in the function has a terminator, and that every
// branch target is a registered block of this function other than the
// function's entry block.
// Returns false if a failure was already recorded or a check fails; otherwise
// returns success(). Validation stops at the first problem found.
bool FunctionEmitter::TerminatorsAreValid() {
    if (failed()) {
        return false;
    }

    const auto entry_id = function_.begin()->id();

    // Every block must end in a terminator before successors can be inspected.
    for (const auto& block : function_) {
        if (!block.terminator()) {
            return Fail() << "Block " << block.id() << " has no terminator";
        }
    }

    for (const auto& block : function_) {
        // WhileEachSuccessorLabel reports false when the callback returned false,
        // which here only happens after a failure has been recorded.
        const bool successors_ok =
            block.WhileEachSuccessorLabel([this, &block, entry_id](const uint32_t succ_id) -> bool {
                if (succ_id == entry_id) {
                    return Fail() << "Block " << block.id() << " branches to function entry block "
                                  << entry_id;
                }
                if (!GetBlockInfo(succ_id)) {
                    return Fail() << "Block " << block.id() << " in function "
                                  << function_.DefInst().result_id() << " branches to " << succ_id
                                  << " which is not a block in the function";
                }
                return true;
            });
        if (!successors_ok) {
            // Stop at the first invalid edge instead of scanning the remaining
            // blocks and piling further diagnostics onto the first one.
            return false;
        }
    }
    return success();
}
// Records the header/merge and header/continue-target relationships declared
// by merge instructions, and validates them.
//
// For each block with a merge instruction this:
//   - checks that the block's terminator is allowed for the kind of merge
//     instruction,
//   - checks that the merge block exists, differs from the header, and is not
//     already the merge block of another header,
//   - for OpLoopMerge, checks that the header is not the function entry block
//     and that the continue target exists, differs from the merge block, and
//     is not already the continue target of another header.
// For every block it also records whether the block is its own continue
// target, and rejects a block that branches to itself without being its own
// continue target.
//
// Returns false if a failure was already recorded or any check fails;
// otherwise returns success().
bool FunctionEmitter::RegisterMerges() {
if (failed()) {
return false;
}
const auto entry_id = function_.begin()->id();
for (const auto& block : function_) {
const auto block_id = block.id();
auto* block_info = GetBlockInfo(block_id);
if (!block_info) {
return Fail() << "internal error: block " << block_id
<< " missing; blocks should already "
"have been registered";
}
if (const auto* inst = block.GetMergeInst()) {
auto terminator_opcode = block.terminator()->opcode();
// The kind of merge instruction restricts which terminators may follow it.
switch (inst->opcode()) {
case SpvOpSelectionMerge:
if ((terminator_opcode != SpvOpBranchConditional) &&
(terminator_opcode != SpvOpSwitch)) {
return Fail() << "Selection header " << block_id
<< " does not end in an OpBranchConditional or "
"OpSwitch instruction";
}
break;
case SpvOpLoopMerge:
if ((terminator_opcode != SpvOpBranchConditional) &&
(terminator_opcode != SpvOpBranch)) {
return Fail() << "Loop header " << block_id
<< " does not end in an OpBranch or "
"OpBranchConditional instruction";
}
break;
default:
break;
}
const uint32_t header = block.id();
auto* header_info = block_info;
// In-operand 0 of the merge instruction is the merge block ID.
const uint32_t merge = inst->GetSingleWordInOperand(0);
auto* merge_info = GetBlockInfo(merge);
if (!merge_info) {
return Fail() << "Structured header block " << header
<< " declares invalid merge block " << merge;
}
if (merge == header) {
return Fail() << "Structured header block " << header
<< " cannot be its own merge block";
}
if (merge_info->header_for_merge) {
return Fail() << "Block " << merge
<< " declared as merge block for more than one header: "
<< merge_info->header_for_merge << ", " << header;
}
// Link the header and its merge block in both directions.
merge_info->header_for_merge = header;
header_info->merge_for_header = merge;
if (inst->opcode() == SpvOpLoopMerge) {
if (header == entry_id) {
return Fail() << "Function entry block " << entry_id
<< " cannot be a loop header";
}
// In-operand 1 of OpLoopMerge is the continue target ID.
const uint32_t ct = inst->GetSingleWordInOperand(1);
auto* ct_info = GetBlockInfo(ct);
if (!ct_info) {
return Fail() << "Structured header " << header
<< " declares invalid continue target " << ct;
}
if (ct == merge) {
return Fail() << "Invalid structured header block " << header
<< ": declares block " << ct
<< " as both its merge block and continue target";
}
if (ct_info->header_for_continue) {
return Fail() << "Block " << ct
<< " declared as continue target for more than one header: "
<< ct_info->header_for_continue << ", " << header;
}
// Link the header and its continue target in both directions.
ct_info->header_for_continue = header;
header_info->continue_for_header = ct;
}
}
// Check single-block loop cases.
bool is_single_block_loop = false;
block_info->basic_block->ForEachSuccessorLabel(
[&is_single_block_loop, block_id](const uint32_t succ) {
if (block_id == succ) {
is_single_block_loop = true;
}
});
const auto ct = block_info->continue_for_header;
block_info->is_continue_entire_loop = ct == block_id;
if (is_single_block_loop && !block_info->is_continue_entire_loop) {
return Fail() << "Block " << block_id
<< " branches to itself but is not its own continue target";
}
// It's valid for the header of a multi-block loop to declare
// itself as its own continue target.
}
return success();
}
// Computes |block_order_| as the reverse structured post-order of the
// function's blocks, and stores in each block's info its position within that
// order.
void FunctionEmitter::ComputeBlockOrderAndPositions() {
    block_order_ = StructuredTraverser(function_).ReverseStructuredPostOrder();

    uint32_t next_pos = 0;
    for (const auto block_id : block_order_) {
        GetBlockInfo(block_id)->pos = next_pos;
        ++next_pos;
    }

    // The invalid block position is not the position of any block that is in the
    // order.
    assert(block_order_.size() <= kInvalidBlockPos);
}
// Checks that every structured header in the block order obeys the interval
// rules relating it to its merge block and continue target.
// Returns false at the first violated rule; otherwise returns success().
bool FunctionEmitter::VerifyHeaderContinueMergeOrder() {
    // Verify interval rules for a structured header block:
    //
    //    If the CFG satisfies structured control flow rules, then:
    //    If header H is reachable, then the following "interval rules" hold,
    //    where M(H) is H's merge block, and CT(H) is H's continue target:
    //
    //      Pos(H) < Pos(M(H))
    //
    //      If CT(H) exists, then:
    //         Pos(H) <= Pos(CT(H))
    //         Pos(CT(H)) < Pos(M)
    //
    for (auto block_id : block_order_) {
        const auto* block_info = GetBlockInfo(block_id);
        const auto merge = block_info->merge_for_header;
        if (merge == 0) {
            // Not a header; nothing to check.
            continue;
        }
        // This is a header.
        const auto header = block_id;
        const auto* header_info = block_info;
        const auto header_pos = header_info->pos;
        const auto merge_pos = GetBlockInfo(merge)->pos;

        // Pos(H) < Pos(M(H))
        // Note: When recording merges we made sure H != M(H)
        if (merge_pos <= header_pos) {
            return Fail() << "Header " << header << " does not strictly dominate its merge block "
                          << merge;
            // TODO(dneto): Report a path from the entry block to the merge block
            // without going through the header block.
        }

        const auto ct = block_info->continue_for_header;
        if (ct == 0) {
            // A selection header: there is no continue target to check.
            continue;
        }
        // Furthermore, this is a loop header.
        const auto* ct_info = GetBlockInfo(ct);
        const auto ct_pos = ct_info->pos;

        // Pos(H) <= Pos(CT(H))
        if (ct_pos < header_pos) {
            // Return immediately, like every other check in this function, rather
            // than recording the failure and carrying on with later headers.
            return Fail() << "Loop header " << header << " does not dominate its continue target "
                          << ct;
        }

        // Pos(CT(H)) < Pos(M(H))
        // Note: When recording merges we made sure CT(H) != M(H)
        if (merge_pos <= ct_pos) {
            return Fail() << "Merge block " << merge << " for loop headed at block " << header
                          << " appears at or before the loop's continue "
                             "construct headed by "
                             "block "
                          << ct;
        }
    }
    return success();
}
// Builds the list of structured control flow constructs in |constructs_| and
// sets the |construct| member of each block's BlockInfo to the innermost
// construct on the stack when that block is visited.
// Returns false if the constructs turn out to be unbalanced; otherwise
// returns success().
bool FunctionEmitter::LabelControlFlowConstructs() {
// Label each block in the block order with its nearest enclosing structured
// control flow construct. Populates the |construct| member of BlockInfo.
// Keep a stack of enclosing structured control flow constructs. Start
// with the synthetic construct representing the entire function.
//
// Scan from left to right in the block order, and check conditions
// on each block in the following order:
//
// a. When you reach a merge block, the top of the stack should
// be the associated header. Pop it off.
// b. When you reach a header, push it on the stack.
// c. When you reach a continue target, push it on the stack.
// (A block can be both a header and a continue target.)
// d. When you reach a block with an edge branching backward (in the
// structured order) to block T:
// T should be a loop header, and the top of the stack should be a
// continue target associated with T.
// This is the end of the continue construct. Pop the continue
// target off the stack.
//
// Note: A loop header can declare itself as its own continue target.
//
// Note: For a single-block loop, that block is a header, its own
// continue target, and its own backedge block.
//
// Note: We pop the merge off first because a merge block that marks
// the end of one construct can be a single-block loop. So that block
// is a merge, a header, a continue target, and a backedge block.
// But we want to finish processing of the merge before dealing with
// the loop.
//
// In the same scan, mark each basic block with the nearest enclosing
// header: the most recent header for which we haven't reached its merge
// block. Also mark the most recent continue target for which we
// haven't reached the backedge block.
TINT_ASSERT(Reader, block_order_.size() > 0);
constructs_.clear();
const auto entry_id = block_order_[0];
// The stack of enclosing constructs.
std::vector<Construct*> enclosing;
// Creates a control flow construct and pushes it onto the stack.
// Its parent is the top of the stack, or nullptr if the stack is empty.
// An |end_id| of 0 means the construct extends to the end of the block order.
// Returns the newly created construct.
auto push_construct = [this, &enclosing](size_t depth, Construct::Kind k, uint32_t begin_id,
uint32_t end_id) -> Construct* {
const auto begin_pos = GetBlockInfo(begin_id)->pos;
const auto end_pos =
end_id == 0 ? uint32_t(block_order_.size()) : GetBlockInfo(end_id)->pos;
const auto* parent = enclosing.empty() ? nullptr : enclosing.back();
auto scope_end_pos = end_pos;
// A loop construct is added right after its associated continue construct.
// In that case, adjust the parent up.
if (k == Construct::kLoop) {
TINT_ASSERT(Reader, parent);
TINT_ASSERT(Reader, parent->kind == Construct::kContinue);
scope_end_pos = parent->end_pos;
parent = parent->parent;
}
constructs_.push_back(std::make_unique<Construct>(parent, static_cast<int>(depth), k,
begin_id, end_id, begin_pos, end_pos,
scope_end_pos));
Construct* result = constructs_.back().get();
enclosing.push_back(result);
return result;
};
// Make a synthetic kFunction construct to enclose all blocks in the function.
push_construct(0, Construct::kFunction, entry_id, 0);
// The entry block can be a selection construct, so be sure to process
// it anyway.
for (uint32_t i = 0; i < block_order_.size(); ++i) {
const auto block_id = block_order_[i];
TINT_ASSERT(Reader, block_id > 0);
auto* block_info = GetBlockInfo(block_id);
TINT_ASSERT(Reader, block_info);
if (enclosing.empty()) {
return Fail() << "internal error: too many merge blocks before block " << block_id;
}
const Construct* top = enclosing.back();
// One block can end several nested constructs, so keep popping.
while (block_id == top->end_id) {
// We've reached a predeclared end of the construct. Pop it off the
// stack.
enclosing.pop_back();
if (enclosing.empty()) {
return Fail() << "internal error: too many merge blocks before block " << block_id;
}
top = enclosing.back();
}
const auto merge = block_info->merge_for_header;
if (merge != 0) {
// The current block is a header.
const auto header = block_id;
const auto* header_info = block_info;
const auto depth = static_cast<size_t>(1 + top->depth);
const auto ct = header_info->continue_for_header;
if (ct != 0) {
// The current block is a loop header.
// We should see the continue construct after the loop construct, so
// push the loop construct last.
// From the interval rule, the continue construct consists of blocks
// in the block order, starting at the continue target, until just
// before the merge block.
top = push_construct(depth, Construct::kContinue, ct, merge);
// A loop header that is its own continue target will have an
// empty loop construct. Only create a loop construct when
// the continue target is *not* the same as the loop header.
if (header != ct) {
// From the interval rule, the loop construct consists of blocks
// in the block order, starting at the header, until just
// before the continue target.
top = push_construct(depth, Construct::kLoop, header, ct);
// If the loop header branches to two different blocks inside the loop
// construct, then the loop body should be modeled as an if-selection
// construct
std::vector<uint32_t> targets;
header_info->basic_block->ForEachSuccessorLabel(
[&targets](const uint32_t target) { targets.push_back(target); });
if ((targets.size() == 2u) && targets[0] != targets[1]) {
const auto target0_pos = GetBlockInfo(targets[0])->pos;
const auto target1_pos = GetBlockInfo(targets[1])->pos;
if (top->ContainsPos(target0_pos) && top->ContainsPos(target1_pos)) {
// Insert a synthetic if-selection
top = push_construct(depth + 1, Construct::kIfSelection, header, ct);
}
}
}
} else {
// From the interval rule, the selection construct consists of blocks
// in the block order, starting at the header, until just before the
// merge block.
const auto branch_opcode = header_info->basic_block->terminator()->opcode();
const auto kind = (branch_opcode == SpvOpBranchConditional)
? Construct::kIfSelection
: Construct::kSwitchSelection;
top = push_construct(depth, kind, header, merge);
}
}
TINT_ASSERT(Reader, top);
// |top| is now the innermost construct on the stack.
block_info->construct = top;
}
// At the end of the block list, we should only have the kFunction construct
// left.
if (enclosing.size() != 1) {
return Fail() << "internal error: unbalanced structured constructs when "
"labeling structured constructs: ended with "
<< enclosing.size() - 1 << " unterminated constructs";
}
const auto* top = enclosing[0];
if (top->kind != Construct::kFunction || top->depth != 0) {
return Fail() << "internal error: outermost construct is not a function?!";
}
return success();
}
// For every switch-selection construct, validates the default target and the
// case targets of its OpSwitch and records them in the target blocks'
// BlockInfo (default_head_for, default_is_merge, case_head_for, case_values).
//
// A target is rejected when it is at or before the switch header in the block
// order, lies beyond the construct's end position, is the merge block of a
// different header, or is already claimed by another OpSwitch. Duplicate case
// values are also rejected.
//
// Returns false if a failure was already recorded or a check fails; otherwise
// returns success().
bool FunctionEmitter::FindSwitchCaseHeaders() {
if (failed()) {
return false;
}
for (auto& construct : constructs_) {
if (construct->kind != Construct::kSwitchSelection) {
continue;
}
// The terminator of the construct's first block is read as the OpSwitch.
const auto* branch = GetBlockInfo(construct->begin_id)->basic_block->terminator();
// Mark the default block
// In-operand 1 is read as the default target ID.
const auto default_id = branch->GetSingleWordInOperand(1);
auto* default_block = GetBlockInfo(default_id);
// A default target can't be a backedge.
if (construct->begin_pos >= default_block->pos) {
// An OpSwitch must dominate its cases. Also, it can't be a self-loop
// as that would be a backedge, and backedges can only target a loop,
// and loops use an OpLoopMerge instruction, which can't precede an
// OpSwitch.
return Fail() << "Switch branch from block " << construct->begin_id
<< " to default target block " << default_id << " can't be a back-edge";
}
// A default target can be the merge block, but can't go past it.
if (construct->end_pos < default_block->pos) {
return Fail() << "Switch branch from block " << construct->begin_id
<< " to default block " << default_id
<< " escapes the selection construct";
}
if (default_block->default_head_for) {
// An OpSwitch must dominate its cases, including the default target.
return Fail() << "Block " << default_id
<< " is declared as the default target for two OpSwitch "
"instructions, at blocks "
<< default_block->default_head_for->begin_id << " and "
<< construct->begin_id;
}
if ((default_block->header_for_merge != 0) &&
(default_block->header_for_merge != construct->begin_id)) {
// The switch instruction for this default block is an alternate path to
// the merge block, and hence the merge block is not dominated by its own
// (different) header.
return Fail() << "Block " << default_block->id
<< " is the default block for switch-selection header "
<< construct->begin_id << " and also the merge block for "
<< default_block->header_for_merge << " (violates dominance rule)";
}
default_block->default_head_for = construct.get();
default_block->default_is_merge = default_block->pos == construct->end_pos;
// Map a case target to the list of values selecting that case.
std::unordered_map<uint32_t, std::vector<uint64_t>> block_to_values;
// Distinct case targets, in order of first appearance in the OpSwitch.
std::vector<uint32_t> case_targets;
// Case values seen so far, used to detect duplicates.
std::unordered_set<uint64_t> case_values;
// Process case targets.
// From in-operand 2 onward, operands are read as (value, target) pairs.
for (uint32_t iarg = 2; iarg + 1 < branch->NumInOperands(); iarg += 2) {
const auto value = branch->GetInOperand(iarg).AsLiteralUint64();
const auto case_target_id = branch->GetSingleWordInOperand(iarg + 1);
if (case_values.count(value)) {
return Fail() << "Duplicate case value " << value << " in OpSwitch in block "
<< construct->begin_id;
}
case_values.insert(value);
if (block_to_values.count(case_target_id) == 0) {
case_targets.push_back(case_target_id);
}
block_to_values[case_target_id].push_back(value);
}
for (uint32_t case_target_id : case_targets) {
auto* case_block = GetBlockInfo(case_target_id);
case_block->case_values =
std::make_unique<std::vector<uint64_t>>(std::move(block_to_values[case_target_id]));
// A case target can't be a back-edge.
if (construct->begin_pos >= case_block->pos) {
// An OpSwitch must dominate its cases. Also, it can't be a self-loop
// as that would be a backedge, and backedges can only target a loop,
// and loops use an OpLoopMerge instruction, which can't precede an
// OpSwitch.
return Fail() << "Switch branch from block " << construct->begin_id
<< " to case target block " << case_target_id
<< " can't be a back-edge";
}
// A case target can be the merge block, but can't go past it.
if (construct->end_pos < case_block->pos) {
return Fail() << "Switch branch from block " << construct->begin_id
<< " to case target block " << case_target_id
<< " escapes the selection construct";
}
if (case_block->header_for_merge != 0 &&
case_block->header_for_merge != construct->begin_id) {
// The switch instruction for this case block is an alternate path to
// the merge block, and hence the merge block is not dominated by its
// own (different) header.
return Fail() << "Block " << case_block->id
<< " is a case block for switch-selection header "
<< construct->begin_id << " and also the merge block for "
<< case_block->header_for_merge << " (violates dominance rule)";
}
// Mark the target as a case target.
if (case_block->case_head_for) {
// An OpSwitch must dominate its cases.
return Fail() << "Block " << case_target_id
<< " is declared as the switch case target for two OpSwitch "
"instructions, at blocks "
<< case_block->case_head_for->begin_id << " and "
<< construct->begin_id;
}
case_block->case_head_for = construct.get();
}
}
return success();
}
// Returns the block info of the header associated with construct |c| when |c|
// is a loop, switch-selection, or continue construct. Returns nullptr for a
// null construct or any other kind of construct.
BlockInfo* FunctionEmitter::HeaderIfBreakable(const Construct* c) {
    if (c == nullptr) {
        return nullptr;
    }
    const auto kind = c->kind;
    if (kind == Construct::kLoop || kind == Construct::kSwitchSelection) {
        // The construct begins at its own header.
        return GetBlockInfo(c->begin_id);
    }
    if (kind == Construct::kContinue) {
        // A continue construct begins at the continue target; the header is the
        // one recorded for that continue target.
        const auto* ct_info = GetBlockInfo(c->begin_id);
        return GetBlockInfo(ct_info->header_for_continue);
    }
    return nullptr;
}
// Given a continue construct |c|, returns the loop construct that belongs to
// the same loop, found by walking up the construct tree from the construct
// recorded on the loop header block.
// Returns nullptr when:
//  - |c| is null or is not a continue construct, or
//  - the continue target is also the loop header (the continue construct is
//    the whole loop), or
//  - no loop construct is found on the walk up the tree.
const Construct* FunctionEmitter::SiblingLoopConstruct(const Construct* c) const {
    const bool is_continue_construct = (c != nullptr) && (c->kind == Construct::kContinue);
    if (!is_continue_construct) {
        return nullptr;
    }
    const uint32_t ct_id = c->begin_id;
    const uint32_t loop_header_id = GetBlockInfo(ct_id)->header_for_continue;
    if (loop_header_id == ct_id) {
        // The continue target is the whole loop.
        return nullptr;
    }
    // Walk up the construct tree until we hit the loop. In future
    // we might handle the corner case where the same block is both a
    // loop header and a selection header. For example, where the
    // loop header block has a conditional branch going to distinct
    // targets inside the loop body.
    const Construct* result = GetBlockInfo(loop_header_id)->construct;
    for (; result != nullptr && result->kind != Construct::kLoop; result = result->parent) {
    }
    return result;
}
// Classifies every CFG edge leaving each block in |block_order_|, recording
// the result in the source block's |succ_edge| map, and validates the edges
// against the structured control flow rules checked below.
// Also assigns a flow guard name to an if-selection header when a block
// without a merge instruction branches both to that selection's merge block
// and to another forward destination.
// Returns false (after recording a failure) on the first invalid edge found,
// or if a failure was already recorded on entry; otherwise returns success().
bool FunctionEmitter::ClassifyCFGEdges() {
    if (failed()) {
        return false;
    }

    // Checks validity of CFG edges leaving each basic block.  This implicitly
    // checks dominance rules for headers and continue constructs.
    //
    // For each branch encountered, classify each edge (S,T) as:
    //    - a back-edge
    //    - a structured exit (specific ways of branching to enclosing construct)
    //    - a normal (forward) edge, either natural control flow or a case
    //      fallthrough
    //
    // If more than one block is targeted by a normal edge, then S must be a
    // structured header.
    //
    // Term: NEC(B) is the nearest enclosing construct for B.
    //
    // If edge (S,T) is a normal edge, and NEC(S) != NEC(T), then
    //    T is the header block of its NEC(T), and
    //    NEC(S) is the parent of NEC(T).

    for (const auto src : block_order_) {
        TINT_ASSERT(Reader, src > 0);
        auto* src_info = GetBlockInfo(src);
        TINT_ASSERT(Reader, src_info);
        const auto src_pos = src_info->pos;
        const auto& src_construct = *(src_info->construct);

        // Compute the ordered list of unique successors.
        std::vector<uint32_t> successors;
        {
            std::unordered_set<uint32_t> visited;
            src_info->basic_block->ForEachSuccessorLabel(
                [&successors, &visited](const uint32_t succ) {
                    if (visited.count(succ) == 0) {
                        successors.push_back(succ);
                        visited.insert(succ);
                    }
                });
        }

        // There should only be one backedge per backedge block.
        uint32_t num_backedges = 0;

        // Track destinations for normal forward edges, either kForward
        // or kCaseFallThrough. These count toward the need
        // to have a merge instruction.  We also track kIfBreak edges
        // because when used with normal forward edges, we'll need
        // to generate a flow guard variable.
        std::vector<uint32_t> normal_forward_edges;
        std::vector<uint32_t> if_break_edges;

        if (successors.empty() && src_construct.enclosing_continue) {
            // Kill and return are not allowed in a continue construct.
            return Fail() << "Invalid function exit at block " << src
                          << " from continue construct starting at "
                          << src_construct.enclosing_continue->begin_id;
        }

        for (const auto dest : successors) {
            const auto* dest_info = GetBlockInfo(dest);
            // We've already checked terminators are valid.
            TINT_ASSERT(Reader, dest_info);
            const auto dest_pos = dest_info->pos;

            // Insert the edge kind entry and keep a handle to update
            // its classification.
            EdgeKind& edge_kind = src_info->succ_edge[dest];

            if (src_pos >= dest_pos) {
                // This is a backedge.
                edge_kind = EdgeKind::kBack;
                num_backedges++;
                const auto* continue_construct = src_construct.enclosing_continue;
                if (!continue_construct) {
                    return Fail() << "Invalid backedge (" << src << "->" << dest << "): " << src
                                  << " is not in a continue construct";
                }
                if (src_pos != continue_construct->end_pos - 1) {
                    // Report the continue construct's own starting block.
                    // The nearest enclosing construct of |src| may be a
                    // construct nested inside the continue construct, so its
                    // begin_id is not necessarily the continue target.
                    return Fail() << "Invalid exit (" << src << "->" << dest
                                  << ") from continue construct: " << src
                                  << " is not the last block in the continue construct "
                                     "starting at "
                                  << continue_construct->begin_id
                                  << " (violates post-dominance rule)";
                }
                const auto* ct_info = GetBlockInfo(continue_construct->begin_id);
                TINT_ASSERT(Reader, ct_info);
                if (ct_info->header_for_continue != dest) {
                    return Fail() << "Invalid backedge (" << src << "->" << dest
                                  << "): does not branch to the corresponding loop header, "
                                     "expected "
                                  << ct_info->header_for_continue;
                }
            } else {
                // This is a forward edge.
                // For now, classify it that way, but we might update it.
                edge_kind = EdgeKind::kForward;

                // Exit from a continue construct can only be from the last block.
                const auto* continue_construct = src_construct.enclosing_continue;
                if (continue_construct != nullptr) {
                    if (continue_construct->ContainsPos(src_pos) &&
                        !continue_construct->ContainsPos(dest_pos) &&
                        (src_pos != continue_construct->end_pos - 1)) {
                        return Fail()
                               << "Invalid exit (" << src << "->" << dest
                               << ") from continue construct: " << src
                               << " is not the last block in the continue construct "
                                  "starting at "
                               << continue_construct->begin_id << " (violates post-dominance rule)";
                    }
                }

                // Check valid structured exit cases.

                if (edge_kind == EdgeKind::kForward) {
                    // Check for a 'break' from a loop or from a switch.
                    const auto* breakable_header =
                        HeaderIfBreakable(src_construct.enclosing_loop_or_continue_or_switch);
                    if (breakable_header != nullptr) {
                        if (dest == breakable_header->merge_for_header) {
                            // It's a break.
                            edge_kind =
                                (breakable_header->construct->kind == Construct::kSwitchSelection)
                                    ? EdgeKind::kSwitchBreak
                                    : EdgeKind::kLoopBreak;
                        }
                    }
                }

                if (edge_kind == EdgeKind::kForward) {
                    // Check for a 'continue' from within a loop.
                    const auto* loop_header = HeaderIfBreakable(src_construct.enclosing_loop);
                    if (loop_header != nullptr) {
                        if (dest == loop_header->continue_for_header) {
                            // It's a continue.
                            edge_kind = EdgeKind::kLoopContinue;
                        }
                    }
                }

                if (edge_kind == EdgeKind::kForward) {
                    const auto& header_info = *GetBlockInfo(src_construct.begin_id);
                    if (dest == header_info.merge_for_header) {
                        // Branch to construct's merge block.  The loop break and
                        // switch break cases have already been covered.
                        edge_kind = EdgeKind::kIfBreak;
                    }
                }

                // A forward edge into a case construct that comes from something
                // other than the OpSwitch is actually a fallthrough.
                if (edge_kind == EdgeKind::kForward) {
                    const auto* switch_construct =
                        (dest_info->case_head_for ? dest_info->case_head_for
                                                  : dest_info->default_head_for);
                    if (switch_construct != nullptr) {
                        if (src != switch_construct->begin_id) {
                            edge_kind = EdgeKind::kCaseFallThrough;
                        }
                    }
                }

                // The edge-kind has been finalized.

                if ((edge_kind == EdgeKind::kForward) ||
                    (edge_kind == EdgeKind::kCaseFallThrough)) {
                    normal_forward_edges.push_back(dest);
                }
                if (edge_kind == EdgeKind::kIfBreak) {
                    if_break_edges.push_back(dest);
                }

                if ((edge_kind == EdgeKind::kForward) ||
                    (edge_kind == EdgeKind::kCaseFallThrough)) {
                    // Check for an invalid forward exit out of this construct.
                    if (dest_info->pos > src_construct.end_pos) {
                        // In most cases we're bypassing the merge block for the source
                        // construct.
                        auto end_block = src_construct.end_id;
                        const char* end_block_desc = "merge block";
                        if (src_construct.kind == Construct::kLoop) {
                            // For a loop construct, we have two valid places to go: the
                            // continue target or the merge for the loop header, which is
                            // further down.
                            const auto loop_merge =
                                GetBlockInfo(src_construct.begin_id)->merge_for_header;
                            if (dest_info->pos >= GetBlockInfo(loop_merge)->pos) {
                                // We're bypassing the loop's merge block.
                                end_block = loop_merge;
                            } else {
                                // We're bypassing the loop's continue target, and going into
                                // the middle of the continue construct.
                                end_block_desc = "continue target";
                            }
                        }
                        return Fail() << "Branch from block " << src << " to block " << dest
                                      << " is an invalid exit from construct starting at block "
                                      << src_construct.begin_id << "; branch bypasses "
                                      << end_block_desc << " " << end_block;
                    }

                    // Check dominance.

                    // Look for edges that violate the dominance condition: a branch
                    // from X to Y where:
                    //   If Y is in a nearest enclosing continue construct headed by
                    //   CT:
                    //     Y is not CT, and
                    //     In the structured order, X appears before CT order or
                    //     after CT's backedge block.
                    //   Otherwise, if Y is in a nearest enclosing construct
                    //   headed by H:
                    //     Y is not H, and
                    //     In the structured order, X appears before H or after H's
                    //     merge block.

                    const auto& dest_construct = *(dest_info->construct);
                    if (dest != dest_construct.begin_id && !dest_construct.ContainsPos(src_pos)) {
                        return Fail()
                               << "Branch from " << src << " to " << dest << " bypasses "
                               << (dest_construct.kind == Construct::kContinue ? "continue target "
                                                                               : "header ")
                               << dest_construct.begin_id << " (dominance rule violated)";
                    }
                }
            }  // end forward edge
        }      // end successor

        if (num_backedges > 1) {
            return Fail() << "Block " << src << " has too many backedges: " << num_backedges;
        }
        if ((normal_forward_edges.size() > 1) && (src_info->merge_for_header == 0)) {
            return Fail() << "Control flow diverges at block " << src << " (to "
                          << normal_forward_edges[0] << ", " << normal_forward_edges[1]
                          << ") but it is not a structured header (it has no merge "
                             "instruction)";
        }
        if ((normal_forward_edges.size() + if_break_edges.size() > 1) &&
            (src_info->merge_for_header == 0)) {
            // There is a branch to the merge of an if-selection combined
            // with an other normal forward branch.  Control within the
            // if-selection needs to be gated by a flow predicate.
            for (auto if_break_dest : if_break_edges) {
                auto* head_info = GetBlockInfo(GetBlockInfo(if_break_dest)->header_for_merge);
                // Generate a guard name, but only once.
                if (head_info->flow_guard_name.empty()) {
                    const std::string guard = "guard" + std::to_string(head_info->id);
                    head_info->flow_guard_name = namer_.MakeDerivedName(guard);
                }
            }
        }
    }

    return success();
}
// For each if-selection construct, records on the BlockInfo of its header:
//  - true_kind / false_kind: the previously computed edge kinds of the
//    branches to the true and false targets,
//  - true_head / false_head: the true and false target block IDs, but only
//    when the target lies inside the construct,
//  - premerge_head: the block inside the construct targeted by a kForward
//    edge leaving the last block of the first clause, when there is one.
// Fails when a recorded head block is also the merge block of a different
// header, or when the last block of the first clause branches both to the
// selection's merge block and to a premerge block.
// Returns false if a failure was recorded (including before entry), and
// success() otherwise.
bool FunctionEmitter::FindIfSelectionInternalHeaders() {
if (failed()) {
return false;
}
for (auto& construct : constructs_) {
if (construct->kind != Construct::kIfSelection) {
continue;
}
auto* if_header_info = GetBlockInfo(construct->begin_id);
const auto* branch = if_header_info->basic_block->terminator();
// In-operand 0 of the terminator is used as the condition when the "if" is
// emitted; in-operands 1 and 2 are read here as the true and false targets.
const auto true_head = branch->GetSingleWordInOperand(1);
const auto false_head = branch->GetSingleWordInOperand(2);
auto* true_head_info = GetBlockInfo(true_head);
auto* false_head_info = GetBlockInfo(false_head);
const auto true_head_pos = true_head_info->pos;
const auto false_head_pos = false_head_info->pos;
const bool contains_true = construct->ContainsPos(true_head_pos);
const bool contains_false = construct->ContainsPos(false_head_pos);
// The cases for each edge are:
// - kBack: invalid because it's an invalid exit from the selection
// - kSwitchBreak ; record this for later special processing
// - kLoopBreak ; record this for later special processing
// - kLoopContinue ; record this for later special processing
// - kIfBreak; normal case, may require a guard variable.
// - kCaseFallThrough; invalid exit from the selection
// - kForward; normal case
if_header_info->true_kind = if_header_info->succ_edge[true_head];
if_header_info->false_kind = if_header_info->succ_edge[false_head];
// Only record a head when its target is inside the selection. Otherwise the
// corresponding head field on the header is left untouched.
if (contains_true) {
if_header_info->true_head = true_head;
}
if (contains_false) {
if_header_info->false_head = false_head;
}
if (contains_true && (true_head_info->header_for_merge != 0) &&
(true_head_info->header_for_merge != construct->begin_id)) {
// The OpBranchConditional instruction for the true head block is an
// alternate path to the merge block of a construct nested inside the
// selection, and hence the merge block is not dominated by its own
// (different) header.
return Fail() << "Block " << true_head << " is the true branch for if-selection header "
<< construct->begin_id << " and also the merge block for header block "
<< true_head_info->header_for_merge << " (violates dominance rule)";
}
if (contains_false && (false_head_info->header_for_merge != 0) &&
(false_head_info->header_for_merge != construct->begin_id)) {
// The OpBranchConditional instruction for the false head block is an
// alternate path to the merge block of a construct nested inside the
// selection, and hence the merge block is not dominated by its own
// (different) header.
return Fail() << "Block " << false_head
<< " is the false branch for if-selection header " << construct->begin_id
<< " and also the merge block for header block "
<< false_head_info->header_for_merge << " (violates dominance rule)";
}
if (contains_true && contains_false && (true_head_pos != false_head_pos)) {
// This construct has both a "then" clause and an "else" clause.
//
// We have this structure:
//
// Option 1:
//
// * condbranch
// * true-head (start of then-clause)
// ...
// * end-then-clause
// * false-head (start of else-clause)
// ...
// * end-false-clause
// * premerge-head
// ...
// * selection merge
//
// Option 2:
//
// * condbranch
// * true-head (start of then-clause)
// ...
// * end-then-clause
// * false-head (start of else-clause) and also premerge-head
// ...
// * end-false-clause
// * selection merge
//
// Option 3:
//
// * condbranch
// * false-head (start of else-clause)
// ...
// * end-else-clause
// * true-head (start of then-clause) and also premerge-head
// ...
// * end-then-clause
// * selection merge
//
// The premerge-head exists if there is a kForward branch from the end
// of the first clause to a block within the surrounding selection.
// The first clause might be a then-clause or an else-clause.
// The first clause ends at the block immediately before whichever head
// appears later in the block order.
const auto second_head = std::max(true_head_pos, false_head_pos);
const auto end_first_clause_pos = second_head - 1;
TINT_ASSERT(Reader, end_first_clause_pos < block_order_.size());
const auto end_first_clause = block_order_[end_first_clause_pos];
uint32_t premerge_id = 0;
uint32_t if_break_id = 0;
// Scan the classified edges leaving the last block of the first clause.
for (auto& then_succ_iter : GetBlockInfo(end_first_clause)->succ_edge) {
const uint32_t dest_id = then_succ_iter.first;
const auto edge_kind = then_succ_iter.second;
switch (edge_kind) {
case EdgeKind::kIfBreak:
if_break_id = dest_id;
break;
case EdgeKind::kForward: {
if (construct->ContainsPos(GetBlockInfo(dest_id)->pos)) {
// It's a premerge.
if (premerge_id != 0) {
// TODO(dneto): I think this is impossible to trigger at this
// point in the flow. It would require a merge instruction to
// get past the check of "at-most-one-forward-edge".
return Fail()
<< "invalid structure: then-clause headed by block "
<< true_head << " ending at block " << end_first_clause
<< " has two forward edges to within selection"
<< " going to " << premerge_id << " and " << dest_id;
}
premerge_id = dest_id;
auto* dest_block_info = GetBlockInfo(dest_id);
if_header_info->premerge_head = dest_id;
if (dest_block_info->header_for_merge != 0) {
// Premerge has two edges coming into it, from the then-clause
// and the else-clause. It's also, by construction, not the
// merge block of the if-selection. So it must not be a merge
// block itself. The OpBranchConditional instruction for the
// false head block is an alternate path to the merge block, and
// hence the merge block is not dominated by its own (different)
// header.
return Fail()
<< "Block " << premerge_id << " is the merge block for "
<< dest_block_info->header_for_merge
<< " but has alternate paths reaching it, starting from"
<< " blocks " << true_head << " and " << false_head
<< " which are the true and false branches for the"
<< " if-selection header block " << construct->begin_id
<< " (violates dominance rule)";
}
}
break;
}
default:
break;
}
}
// The end of the first clause may go to the selection merge or to a
// premerge block, but not both.
if (if_break_id != 0 && premerge_id != 0) {
return Fail() << "Block " << end_first_clause << " in if-selection headed at block "
<< construct->begin_id << " branches to both the merge block "
<< if_break_id << " and also to block " << premerge_id
<< " later in the selection";
}
}
}
return success();
}
bool FunctionEmitter::EmitFunctionVariables() {
if (failed()) {
return false;
}
for (auto& inst : *function_.entry()) {
if (inst.opcode() != SpvOpVariable) {
continue;
}
auto* var_store_type = GetVariableStoreType(inst);
if (failed()) {
return false;
}
const ast::Expression* constructor = nullptr;
if (inst.NumInOperands() > 1) {
// SPIR-V initializers are always constants.
// (OpenCL also allows the ID of an OpVariable, but we don't handle that
// here.)
constructor = parser_impl_.MakeConstantExpression(inst.GetSingleWordInOperand(1)).expr;
if (!constructor) {
return false;
}
}
auto* var = parser_impl_.MakeVar(inst.result_id(), ast::StorageClass::kNone, var_store_type,
constructor, ast::AttributeList{});
auto* var_decl_stmt = create<ast::VariableDeclStatement>(Source{}, var);
AddStatement(var_decl_stmt);
auto* var_type = ty_.Reference(var_store_type, ast::StorageClass::kNone);
identifier_types_.emplace(inst.result_id(), var_type);
}
return success();
}
// Returns AddressOf(|expr|) when all of the following hold:
//  - |inst| is non-null and |expr| is valid,
//  - the type manager knows the SPIR-V type of |inst|'s result,
//  - |expr| has a Reference type, and that SPIR-V type is a pointer.
// Otherwise returns |expr| unchanged.
TypedExpression FunctionEmitter::AddressOfIfNeeded(TypedExpression expr,
                                                   const spvtools::opt::Instruction* inst) {
    if (!inst || !expr) {
        return expr;
    }
    auto* spirv_type = type_mgr_->GetType(inst->type_id());
    if (!spirv_type) {
        return expr;
    }
    const bool needs_address_of = expr.type->Is<Reference>() && spirv_type->AsPointer();
    if (needs_address_of) {
        return AddressOf(expr);
    }
    return expr;
}
// Makes a typed expression for the value with SPIR-V result ID |id|.
// The sources are tried in this order:
//  1. Special handling selected by the skip reason for |id|.
//  2. An identifier, when a type was recorded for |id| in |identifier_types_|.
//  3. An identifier, when |id| is a scalar spec constant.
//  4. The saved expression for a singly-used value; it is removed once taken.
//  5. A constant expression, when |id| is a declared constant.
//  6. An identifier for an OpVariable, or a null value for an OpUndef.
//  7. A null value, when |id| is defined in a block that has no position in
//     the block order.
// Returns an empty TypedExpression if a failure was already recorded, or when
// one of the failure cases below is hit. Steps that delegate to the parser
// return whatever the parser produced.
TypedExpression FunctionEmitter::MakeExpression(uint32_t id) {
    if (failed()) {
        return {};
    }

    // Builds an identifier expression that refers to |name|.
    auto identifier = [this](const std::string& name) {
        return create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(name));
    };

    switch (GetSkipReason(id)) {
        case SkipReason::kDontSkip:
            break;
        case SkipReason::kOpaqueObject:
            Fail() << "internal error: unhandled use of opaque object with ID: " << id;
            return {};
        case SkipReason::kSinkPointerIntoUse: {
            // Replace the pointer with its source reference expression.
            auto sunk_source = GetDefInfo(id)->sink_pointer_source_expr;
            TINT_ASSERT(Reader, sunk_source.type->Is<Reference>());
            return sunk_source;
        }
        case SkipReason::kPointSizeBuiltinValue: {
            // The value is replaced by the f32 literal 1.0.
            return {ty_.F32(), create<ast::FloatLiteralExpression>(
                                   Source{}, 1.0, ast::FloatLiteralExpression::Suffix::kF)};
        }
        case SkipReason::kPointSizeBuiltinPointer:
            Fail() << "unhandled use of a pointer to the PointSize builtin, with ID: " << id;
            return {};
        case SkipReason::kSampleMaskInBuiltinPointer:
            Fail() << "unhandled use of a pointer to the SampleMask builtin, with ID: " << id;
            return {};
        case SkipReason::kSampleMaskOutBuiltinPointer: {
            // The result type is always u32.
            const auto mask_name = namer_.Name(sample_mask_out_id);
            return TypedExpression{ty_.U32(), identifier(mask_name)};
        }
    }

    // A value that already has a recorded type is referenced by name.
    const auto known_type = identifier_types_.find(id);
    if (known_type != identifier_types_.end()) {
        const auto var_name = namer_.Name(id);
        return TypedExpression{known_type->second, identifier(var_name)};
    }

    // Scalar spec constants are referenced by name.
    if (parser_impl_.IsScalarSpecConstant(id)) {
        const auto const_name = namer_.Name(id);
        return TypedExpression{parser_impl_.ConvertType(def_use_mgr_->GetDef(id)->type_id()),
                               identifier(const_name)};
    }

    // A singly-used value hands over its saved expression, and is forgotten.
    const auto single_use = singly_used_values_.find(id);
    if (single_use != singly_used_values_.end()) {
        auto saved_expr = std::move(single_use->second);
        singly_used_values_.erase(single_use);
        return saved_expr;
    }

    if (constant_mgr_->FindDeclaredConstant(id) != nullptr) {
        return parser_impl_.MakeConstantExpression(id);
    }

    const auto* inst = def_use_mgr_->GetDef(id);
    if (inst == nullptr) {
        Fail() << "ID " << id << " does not have a defining SPIR-V instruction";
        return {};
    }

    const auto opcode = inst->opcode();
    if (opcode == SpvOpVariable) {
        // This occurs for module-scope variables.
        const auto var_name = namer_.Name(inst->result_id());
        return TypedExpression{parser_impl_.ConvertType(inst->type_id(), PtrAs::Ref),
                               identifier(var_name)};
    }
    if (opcode == SpvOpUndef) {
        // Substitute a null value for undef.
        // This case occurs when OpUndef appears at module scope, as if it were
        // a constant.
        return parser_impl_.MakeNullExpression(parser_impl_.ConvertType(inst->type_id()));
    }

    // The value may have come from a block that is not in the block order.
    // If so, substitute a null value.
    const spvtools::opt::BasicBlock* const def_block = ir_context_.get_instr_block(id);
    const auto* def_block_info = def_block ? GetBlockInfo(def_block->id()) : nullptr;
    if (def_block_info && def_block_info->pos == kInvalidBlockPos) {
        return parser_impl_.MakeNullExpression(parser_impl_.ConvertType(inst->type_id()));
    }

    Fail() << "unhandled expression for ID " << id << "\n" << inst->PrettyPrint();
    return {};
}
bool FunctionEmitter::EmitFunctionBodyStatements() {
// Dump the basic blocks in order, grouped by construct.
// We maintain a stack of StatementBlock objects, where new statements
// are always written to the topmost entry of the stack. By this point in
// processing, we have already recorded the interesting control flow
// boundaries in the BlockInfo and associated Construct objects. As we
// enter a new statement grouping, we push onto the stack, and also schedule
// the statement block's completion and removal at a future block's ID.
// Upon entry, the statement stack has one entry representing the whole
// function.
TINT_ASSERT(Reader, !constructs_.empty());
Construct* function_construct = constructs_[0].get();
TINT_ASSERT(Reader, function_construct != nullptr);
TINT_ASSERT(Reader, function_construct->kind == Construct::kFunction);
// Make the first entry valid by filling in the construct field, which
// had not been computed at the time the entry was first created.
// TODO(dneto): refactor how the first construct is created vs.
// this statements stack entry is populated.
TINT_ASSERT(Reader, statements_stack_.size() == 1);
statements_stack_[0].SetConstruct(function_construct);
for (auto block_id : block_order()) {
if (!EmitBasicBlock(*GetBlockInfo(block_id))) {
return false;
}
}
return success();
}
// Emits the statements for the basic block described by |block_info|.
// In order, this:
//  1. Finalizes and pops every statement block at the top of the statement
//     stack whose end ID is this block.
//  2. Collects the constructs that begin at this block, and checks that they
//     nest in one of the supported ways.
//  3. Enters those constructs from outermost to innermost, emitting the
//     block's own statements along the way.
//  4. Emits the block's statements if no construct entry did so, via the
//     |emitted| flag shared by all the calls.
//  5. Emits a normal terminator, unless an if-selection or switch-selection
//     was started at this block.
// Returns false on failure, and success() otherwise.
bool FunctionEmitter::EmitBasicBlock(const BlockInfo& block_info) {
// Close off previous constructs.
while (!statements_stack_.empty() && (statements_stack_.back().GetEndId() == block_info.id)) {
statements_stack_.back().Finalize(&builder_);
statements_stack_.pop_back();
}
if (statements_stack_.empty()) {
return Fail() << "internal error: statements stack empty at block " << block_info.id;
}
// Enter new constructs.
std::vector<const Construct*> entering_constructs; // inner most comes first
{
// Walk outward from the block's nearest enclosing construct up to, but not
// including, the construct of the statement block on top of the stack.
auto* here = block_info.construct;
auto* const top_construct = statements_stack_.back().GetConstruct();
while (here != top_construct) {
// Only enter a construct at its header block.
if (here->begin_id == block_info.id) {
entering_constructs.push_back(here);
}
here = here->parent;
}
}
// What constructs can we have entered?
// - It can't be kFunction, because there is only one of those, and it was
// already on the stack at the outermost level.
// - We have at most one of kSwitchSelection, or kLoop because each of those
// is headed by a block with a merge instruction (OpLoopMerge for kLoop,
// and OpSelectionMerge for kSwitchSelection).
// - When there is a kIfSelection, it can't contain another construct,
// because both would have to have their own distinct merge instructions
// and distinct terminators.
// - A kContinue can contain a kContinue
// This is possible in Vulkan SPIR-V, but Tint disallows this by the rule
// that a block can be continue target for at most one header block. See
// test DISABLED_BlockIsContinueForMoreThanOneHeader. If we generalize this,
// then by a dominance argument, the inner loop continue target can only be
// a single-block loop.
// TODO(dneto): Handle this case.
// - If a kLoop is on the outside, its terminator is either:
// - an OpBranch, in which case there is no other construct.
// - an OpBranchConditional, in which case there is either an kIfSelection
// (when both branch targets are different and are inside the loop),
// or no other construct (because the branch targets are the same,
// or one of them is a break or continue).
// - All that's left is a kContinue on the outside, and one of
// kIfSelection, kSwitchSelection, kLoop on the inside.
//
// The kContinue can be the parent of the other. For example, a selection
// starting at the first block of a continue construct.
//
// The kContinue can't be the child of the other because either:
// - The other can't be kLoop because:
// - If the kLoop is for a different loop than the kContinue, then
// the kContinue must be its own loop header, and so the same
// block is two different loops. That's a contradiction.
// - If the kLoop is for the same loop, then this is a contradiction
// because a kContinue and its kLoop have disjoint block sets.
// - The other construct can't be a selection because:
// - The kContinue construct is the entire loop, i.e. the continue
// target is its own loop header block. But then the continue target
// has an OpLoopMerge instruction, which contradicts this block being
// a selection header.
// - The kContinue is in a multi-block loop that has a non-empty
// kLoop; and the selection contains the kContinue block but not the
// loop block. That breaks dominance rules. That is, the continue
// target is dominated by that loop header, and so gets found by the
// block traversal on the outside before the selection is found. The
// selection is inside the outer loop.
//
// So we fall into one of the following cases:
// - We are entering 0 or 1 constructs, or
// - We are entering 2 constructs, with the outer one being a kContinue or
// kLoop, the inner one is not a continue.
if (entering_constructs.size() > 2) {
return Fail() << "internal error: bad construct nesting found";
}
if (entering_constructs.size() == 2) {
// The list is ordered innermost first, so index 0 is the inner construct.
auto inner_kind = entering_constructs[0]->kind;
auto outer_kind = entering_constructs[1]->kind;
if (outer_kind != Construct::kContinue && outer_kind != Construct::kLoop) {
return Fail() << "internal error: bad construct nesting. Only a Continue "
"or a Loop construct can be outer construct on same block. "
"Got outer kind "
<< int(outer_kind) << " inner kind " << int(inner_kind);
}
if (inner_kind == Construct::kContinue) {
return Fail() << "internal error: unsupported construct nesting: "
"Continue around Continue";
}
if (inner_kind != Construct::kIfSelection && inner_kind != Construct::kSwitchSelection &&
inner_kind != Construct::kLoop) {
return Fail() << "internal error: bad construct nesting. Continue around "
"something other than if, switch, or loop";
}
}
// Enter constructs from outermost to innermost.
// kLoop and kContinue push a new statement-block onto the stack before
// emitting statements in the block.
// kIfSelection and kSwitchSelection emit statements in the block and then
// push a new statement-block. Only emit the statements in the block
// once.
// Have we emitted the statements for this block?
bool emitted = false;
// When entering an if-selection or switch-selection, we will emit the WGSL
// construct to cause the divergent branching. But otherwise, we will
// emit a "normal" block terminator, which occurs at the end of this method.
bool has_normal_terminator = true;
// The list is innermost first, so iterate it in reverse.
for (auto iter = entering_constructs.rbegin(); iter != entering_constructs.rend(); ++iter) {
const Construct* construct = *iter;
switch (construct->kind) {
case Construct::kFunction:
return Fail() << "internal error: nested function construct";
case Construct::kLoop:
if (!EmitLoopStart(construct)) {
return false;
}
if (!EmitStatementsInBasicBlock(block_info, &emitted)) {
return false;
}
break;
case Construct::kContinue:
if (block_info.is_continue_entire_loop) {
// The continue construct is the entire loop, so it is started the same
// way as a loop construct.
if (!EmitLoopStart(construct)) {
return false;
}
if (!EmitStatementsInBasicBlock(block_info, &emitted)) {
return false;
}
} else {
if (!EmitContinuingStart(construct)) {
return false;
}
}
break;
case Construct::kIfSelection:
if (!EmitStatementsInBasicBlock(block_info, &emitted)) {
return false;
}
if (!EmitIfStart(block_info)) {
return false;
}
has_normal_terminator = false;
break;
case Construct::kSwitchSelection:
if (!EmitStatementsInBasicBlock(block_info, &emitted)) {
return false;
}
if (!EmitSwitchStart(block_info)) {
return false;
}
has_normal_terminator = false;
break;
}
}
// If we aren't starting or transitioning, then emit the normal
// statements now.
if (!EmitStatementsInBasicBlock(block_info, &emitted)) {
return false;
}
if (has_normal_terminator) {
if (!EmitNormalTerminator(block_info)) {
return false;
}
}
return success();
}
// Starts emitting the "if" statement for the if-selection construct headed
// by |block_info|. In order, this:
//  - declares the flow guard variable, initialized to true, when the header
//    has a flow guard name,
//  - makes the condition expression from in-operand 0 of the block's
//    terminator, and adds an if-statement builder for it,
//  - pushes statement blocks onto the statement stack for the premerge
//    guard, the else-clause and the then-clause, as needed,
//  - adds a break or continue statement to a clause whose edge from the
//    header was classified as a loop/switch break or a loop continue.
// Returns false if the condition expression could not be made, and
// success() otherwise.
bool FunctionEmitter::EmitIfStart(const BlockInfo& block_info) {
// The block is the if-header block. So its construct is the if construct.
auto* construct = block_info.construct;
TINT_ASSERT(Reader, construct->kind == Construct::kIfSelection);
TINT_ASSERT(Reader, construct->begin_id == block_info.id);
const uint32_t true_head = block_info.true_head;
const uint32_t false_head = block_info.false_head;
const uint32_t premerge_head = block_info.premerge_head;
const std::string guard_name = block_info.flow_guard_name;
if (!guard_name.empty()) {
// Declare the guard variable just before the "if", initialized to true.
auto* guard_var = builder_.Var(guard_name, builder_.ty.bool_(), MakeTrue(Source{}));
auto* guard_decl = create<ast::VariableDeclStatement>(Source{}, guard_var);
AddStatement(guard_decl);
}
const auto condition_id = block_info.basic_block->terminator()->GetSingleWordInOperand(0);
auto* cond = MakeExpression(condition_id).expr;
if (!cond) {
return false;
}
// Generate the code for the condition.
auto* builder = AddStatementBuilder<IfStatementBuilder>(cond);
// Compute the block IDs that should end the then-clause and the else-clause.
// We need to know where the *emitted* selection should end, i.e. the intended
// merge block id. That should be the current premerge block, if it exists,
// or otherwise the declared merge block.
//
// This is another way to think about it:
// If there is a premerge, then there are three cases:
// - premerge_head is different from the true_head and false_head:
// - Premerge comes last. In effect, move the selection merge up
// to where the premerge begins.
// - premerge_head is the same as the false_head
// - This is really an if-then without an else clause.
// Move the merge up to where the premerge is.
// - premerge_head is the same as the true_head
// - This is really an if-else without a then clause.
// Emit it as: if (cond) {} else {....}
// Move the merge up to where the premerge is.
const uint32_t intended_merge = premerge_head ? premerge_head : construct->end_id;
// then-clause:
// If true_head exists:
// spans from true head to the earlier of the false head (if it exists)
// or the selection merge.
// Otherwise:
// ends at the false head (if it exists), otherwise the selection
// end.
const uint32_t then_end = false_head ? false_head : intended_merge;
// else-clause:
// ends at the premerge head (if it exists) or at the selection end.
const uint32_t else_end = premerge_head ? premerge_head : intended_merge;
// Classify how each branch leaves the header, using the edge kinds recorded
// on the header block.
const bool true_is_break = (block_info.true_kind == EdgeKind::kSwitchBreak) ||
(block_info.true_kind == EdgeKind::kLoopBreak);
const bool false_is_break = (block_info.false_kind == EdgeKind::kSwitchBreak) ||
(block_info.false_kind == EdgeKind::kLoopBreak);
const bool true_is_continue = block_info.true_kind == EdgeKind::kLoopContinue;
const bool false_is_continue = block_info.false_kind == EdgeKind::kLoopContinue;
// Push statement blocks for the then-clause and the else-clause.
// But make sure we do it in the right order.
auto push_else = [this, builder, else_end, construct, false_is_break, false_is_continue]() {
// Push the else clause onto the stack first.
PushNewStatementBlock(construct, else_end, [=](const ast::StatementList& stmts) {
// Only set the else-clause if there are statements to fill it.
if (!stmts.empty()) {
// The "else" consists of the statement list from the top of
// statements stack, without an "else if" condition.
builder->else_stmt = create<ast::BlockStatement>(Source{}, stmts);
}
});
// These statements are added after the else-clause block has been pushed.
if (false_is_break) {
AddStatement(create<ast::BreakStatement>(Source{}));
}
if (false_is_continue) {
AddStatement(create<ast::ContinueStatement>(Source{}));
}
};
if (!true_is_break && !true_is_continue &&
(GetBlockInfo(else_end)->pos < GetBlockInfo(then_end)->pos)) {
// Process the else-clause first. The then-clause will be empty so avoid
// pushing onto the stack at all.
push_else();
} else {
// Blocks for the then-clause appear before blocks for the else-clause.
// So push the else-clause handling onto the stack first. The else-clause
// might be empty, but this works anyway.
// Handle the premerge, if it exists.
if (premerge_head) {
// The top of the stack is the statement block that is the parent of the
// if-statement. Adding statements now will place them after that 'if'.
if (guard_name.empty()) {
// We won't have a flow guard for the premerge.
// Insert a trivial if(true) { ... } around the blocks from the
// premerge head until the end of the if-selection. This is needed
// to ensure uniform reconvergence occurs at the end of the if-selection
// just like in the original SPIR-V.
PushTrueGuard(construct->end_id);
} else {
// Add a flow guard around the blocks in the premerge area.
PushGuard(guard_name, construct->end_id);
}
}
push_else();
if (true_head && false_head && !guard_name.empty()) {
// There are non-trivial then and else clauses.
// We have to guard the start of the else.
PushGuard(guard_name, else_end);
}
// Push the then clause onto the stack.
PushNewStatementBlock(construct, then_end, [=](const ast::StatementList& stmts) {
builder->body = create<ast::BlockStatement>(Source{}, stmts);
});
// These statements are added after the then-clause block has been pushed.
if (true_is_break) {
AddStatement(create<ast::BreakStatement>(Source{}));
}
if (true_is_continue) {
AddStatement(create<ast::ContinueStatement>(Source{}));
}
}
return success();
}
bool FunctionEmitter::EmitSwitchStart(const BlockInfo& block_info) {
// The block is the if-header block. So its construct is the if construct.
auto* construct = block_info.construct;
TINT_ASSERT(Reader, construct->kind == Construct::kSwitchSelection);
TINT_ASSERT(Reader, construct->begin_id == block_info.id);
const auto* branch = block_info.basic_block->terminator();
const auto selector_id = branch->GetSingleWordInOperand(0);
// Generate the code for the selector.
auto selector = MakeExpression(selector_id);
if (!selector) {
return false;
}
// First, push the statement block for the entire switch.
auto* swch = AddStatementBuilder<SwitchStatementBuilder>(selector.expr);
// Grab a pointer to the case list. It will get buried in the statement block
// stack.
PushNewStatementBlock(construct, construct->end_id, nullptr);
// We will push statement-blocks onto the stack to gather the statements in
// the default clause and cases clauses. Determine the list of blocks
// that start each clause.
std::vector<const BlockInfo*> clause_heads;
// Collect the case clauses, even if they are just the merge block.
// First the default clause.
const auto default_id = branch->GetSingleWordInOperand(1);
const auto* default_info = GetBlockInfo(default_id);
clause_heads.push_back(default_info);
// Now the case clauses.
for (uint32_t iarg = 2; iarg + 1 < branch->NumInOperands(); iarg += 2) {
const auto case_target_id = branch->GetSingleWordInOperand(iarg + 1);
clause_heads.push_back(GetBlockInfo(case_target_id));
}
std::stable_sort(
clause_heads.begin(), clause_heads.end(),
[](const BlockInfo* lhs, const BlockInfo* rhs) { return lhs->pos < rhs->pos; });
// Remove duplicates
{
// Use read index r, and write index w.
// Invariant: w <= r;
size_t w = 0;
for (size_t r = 0; r < clause_heads.size(); ++r) {
if (clause_heads[r] != clause_heads[w]) {
++w; // Advance the write cursor.
}
clause_heads[w] = clause_heads[r];
}
// We know it's not empty because it always has at least a default clause.
TINT_ASSERT(Reader, !clause_heads.empty());
clause_heads.resize(w + 1);
}
// Push them on in reverse order.
const auto last_clause_index = clause_heads.size() - 1;
for (size_t i = last_clause_index;; --i) {
// Create a list of integer literals for the selector values leading to
// this case clause.
ast::CaseSelectorList selectors;
const auto* values_ptr = clause_heads[i]->case_values.get();
const bool has_selectors = (values_ptr && !values_ptr->empty());
if (has_selectors) {
std::vector<uint64_t> values(values_ptr->begin(), values_ptr->end());
std::stable_sort(values.begin(), values.end());
for (auto value : values) {
// The rest of this module can handle up to 64 bit switch values.
// The Tint AST handles 32-bit values.
const uint32_t value32 = uint32_t(value & 0xFFFFFFFF);
if (selector.type->IsUnsignedScalarOrVector()) {
selectors.emplace_back(create<ast::IntLiteralExpression>(
Source{}, value32, ast::IntLiteralExpression::Suffix::kU));
} else {
selectors.emplace_back(
create<ast::IntLiteralExpression>(Source{}, static_cast<int32_t>(value32),
ast::IntLiteralExpression::Suffix::kI));
}
}
}
// Where does this clause end?
const auto end_id =
(i + 1 < clause_heads.size()) ? clause_heads[i + 1]->id : construct->end_id;
// Reserve the case clause slot in swch->cases, push the new statement block
// for the case, and fill the case clause once the block is generated.
auto case_idx = swch->cases.size();
swch->cases.emplace_back(nullptr);
PushNewStatementBlock(construct, end_id, [=](const ast::StatementList& stmts) {
auto* body = create<ast::BlockStatement>(Source{}, stmts);
swch->cases[case_idx] = create<ast::CaseStatement>(Source{}, selectors, body);
});
if ((default_info == clause_heads[i]) && has_selectors &&
construct->ContainsPos(default_info->pos)) {
// Generate a default clause with a just fallthrough.
auto* stmts = create<ast::BlockStatement>(
Source{}, ast::StatementList{
create<ast::FallthroughStatement>(Source{}),
});
auto* case_stmt = create<ast::CaseStatement>(Source{}, ast::CaseSelectorList{}, stmts);
swch->cases.emplace_back(case_stmt);
}
if (i == 0) {
break;
}
}
return success();
}
bool FunctionEmitter::EmitLoopStart(const Construct* construct) {
auto* builder = AddStatementBuilder<LoopStatementBuilder>();
PushNewStatementBlock(construct, construct->end_id, [=](const ast::StatementList& stmts) {
builder->body = create<ast::BlockStatement>(Source{}, stmts);
});
return success();
}
// Starts the continuing block for the continue construct |construct|.
// The most recently emitted statement must be the builder for the associated
// loop; the statements gathered for this construct become that loop's
// `continuing` block. Fails with an internal error if the last statement is
// missing or is not a loop builder.
bool FunctionEmitter::EmitContinuingStart(const Construct* construct) {
    // A continue construct has the same depth as its associated loop
    // construct. Start a continue construct.
    auto* loop_candidate = LastStatement();
    // Do not dereference the candidate before knowing it exists: treat "no
    // previous statement" the same as "previous statement is not a loop".
    auto* loop = loop_candidate ? loop_candidate->As<LoopStatementBuilder>() : nullptr;
    if (loop == nullptr) {
        return Fail() << "internal error: starting continue construct, "
                         "expected loop on top of stack";
    }
    PushNewStatementBlock(construct, construct->end_id, [=](const ast::StatementList& stmts) {
        loop->continuing = create<ast::BlockStatement>(Source{}, stmts);
    });

    return success();
}
// Emits the statement(s) corresponding to the terminator instruction of
// |block_info|, for a block that is not the start of an if-selection or a
// switch-selection. Returns false if emission fails.
bool FunctionEmitter::EmitNormalTerminator(const BlockInfo& block_info) {
    const auto& term = *(block_info.basic_block->terminator());
    switch (term.opcode()) {
        case SpvOpReturn: {
            AddStatement(create<ast::ReturnStatement>(Source{}));
            return true;
        }
        case SpvOpReturnValue: {
            auto returned = MakeExpression(term.GetSingleWordInOperand(0));
            if (!returned) {
                return false;
            }
            AddStatement(create<ast::ReturnStatement>(Source{}, returned.expr));
            return true;
        }
        case SpvOpKill: {
            // For now, assume SPIR-V OpKill has same semantics as WGSL discard.
            // TODO(dneto): https://github.com/gpuweb/gpuweb/issues/676
            AddStatement(create<ast::DiscardStatement>(Source{}));
            return true;
        }
        case SpvOpUnreachable: {
            // Treat it like a return, so the WGSL requirement of a return
            // statement at the end of the function body is satisfied.
            const auto* fn_result_type = type_mgr_->GetType(function_.type_id());
            if (fn_result_type->AsVoid() == nullptr) {
                // Non-void function: return the null value of the result type.
                auto* ast_type = parser_impl_.ConvertType(function_.type_id());
                AddStatement(create<ast::ReturnStatement>(Source{},
                                                          parser_impl_.MakeNullValue(ast_type)));
            } else {
                AddStatement(create<ast::ReturnStatement>(Source{}));
            }
            return true;
        }
        case SpvOpBranch: {
            const auto target_id = term.GetSingleWordInOperand(0);
            AddStatement(MakeBranch(block_info, *GetBlockInfo(target_id)));
            return true;
        }
        case SpvOpBranchConditional: {
            const auto true_id = term.GetSingleWordInOperand(1);
            const auto false_id = term.GetSingleWordInOperand(2);
            if (true_id == false_id) {
                // Both arms go to the same place, so handle it exactly as an
                // unconditional branch (OpBranch).
                AddStatement(MakeBranch(block_info, *GetBlockInfo(true_id)));
                return true;
            }

            const EdgeKind kind_if_true = block_info.succ_edge.find(true_id)->second;
            const EdgeKind kind_if_false = block_info.succ_edge.find(false_id)->second;
            auto* const info_if_true = GetBlockInfo(true_id);
            auto* const info_if_false = GetBlockInfo(false_id);
            auto* condition = MakeExpression(term.GetSingleWordInOperand(0)).expr;
            if (!condition) {
                return false;
            }

            // There are two distinct destinations, and this is a normal
            // terminator: the source block is *not* the start of an
            // if-selection or a switch-selection. Hence at most one branch is a
            // kForward, kCaseFallThrough, or kIfBreak.

            // Fallthrough gets special handling because WGSL requires the
            // fallthrough statement to be last in the case clause.
            if (kind_if_true == EdgeKind::kCaseFallThrough) {
                return EmitConditionalCaseFallThrough(block_info, condition, kind_if_false,
                                                      *info_if_false, true);
            }
            if (kind_if_false == EdgeKind::kCaseFallThrough) {
                return EmitConditionalCaseFallThrough(block_info, condition, kind_if_true,
                                                      *info_if_true, false);
            }

            // Now at most one edge is kForward or kIfBreak.
            // Express the *other* branch as a conditional break or continue
            // inside an 'if' statement. Either or both statements may be
            // nullptr (that is what kIfBreak, kForward, and kBack produce).
            // If one of the branches is an if-break out of an if-selection
            // that needs a flow guard, collect the guard name too. At most one
            // of the two branches supplies it.
            std::string guard_name;
            auto* stmt_if_true = MakeBranchDetailed(block_info, *info_if_true, false, &guard_name);
            auto* stmt_if_false =
                MakeBranchDetailed(block_info, *info_if_false, false, &guard_name);

            AddStatement(MakeSimpleIf(condition, stmt_if_true, stmt_if_false));
            if (!guard_name.empty()) {
                PushGuard(guard_name, statements_stack_.back().GetEndId());
            }
            return true;
        }
        case SpvOpSwitch: {
            // An OpSelectionMerge must precede an OpSwitch. That is clarified
            // in the resolution to Khronos-internal SPIR-V issue 115.
            // A new enough version of the SPIR-V validator checks this case.
            // But issue an error in this case, as a defensive measure.
            return Fail() << "invalid structured control flow: found an OpSwitch "
                             "that is not preceded by an "
                             "OpSelectionMerge: "
                          << term.PrettyPrint();
        }
        default:
            break;
    }
    return success();
}
// Builds the statement, if any, needed to express the branch from |src_info|
// to |dest_info|, based on the kind of the edge between them. Returns nullptr
// when the branch is implicit. When |forced| is true a switch-break always
// yields a break statement. If the edge is an if-break out of a selection
// that has a flow guard, and |flow_guard_name_ptr| is not null, the guard name
// is written through it.
const ast::Statement* FunctionEmitter::MakeBranchDetailed(const BlockInfo& src_info,
                                                          const BlockInfo& dest_info,
                                                          bool forced,
                                                          std::string* flow_guard_name_ptr) const {
    const auto edge_kind = src_info.succ_edge.find(dest_info.id)->second;
    switch (edge_kind) {
        case EdgeKind::kBack:
            // The loop backedge is implicit, so there is nothing to emit.
            return nullptr;

        case EdgeKind::kForward:
            // Unconditional forward branch is implicit.
            return nullptr;

        case EdgeKind::kCaseFallThrough:
            return create<ast::FallthroughStatement>(Source{});

        case EdgeKind::kLoopBreak:
            return create<ast::BreakStatement>(Source{});

        case EdgeKind::kLoopContinue: {
            // A continue that merely targets the very next block is redundant
            // and ugly, so leave it out.
            const bool goes_to_next_block = (dest_info.pos == 1 + src_info.pos);
            if (goes_to_next_block) {
                return nullptr;
            }
            // Otherwise, emit a regular continue statement.
            return create<ast::ContinueStatement>(Source{});
        }

        case EdgeKind::kSwitchBreak: {
            if (forced) {
                return create<ast::BreakStatement>(Source{});
            }
            // When not forced, skip the break if it would sit at the end of a
            // case/default clause.
            const auto switch_header_id = dest_info.header_for_merge;
            TINT_ASSERT(Reader, switch_header_id != 0);
            const auto* switch_construct = GetBlockInfo(switch_header_id)->construct;
            TINT_ASSERT(Reader, switch_construct->kind == Construct::kSwitchSelection);
            const auto next_pos = src_info.pos + 1;

            // Leaving the last block from the last case?
            if (next_pos == dest_info.pos) {
                // No break needed.
                return nullptr;
            }
            // Leaving the last block from not-the-last-case?
            if (switch_construct->ContainsPos(next_pos)) {
                const auto* next_block = GetBlockInfo(block_order_[next_pos]);
                const bool next_is_clause_head =
                    (next_block->case_head_for == switch_construct) ||
                    (next_block->default_head_for == switch_construct);
                if (next_is_clause_head) {
                    // No break needed.
                    return nullptr;
                }
            }
            // We need a break.
            return create<ast::BreakStatement>(Source{});
        }

        case EdgeKind::kIfBreak: {
            const auto& guard = GetBlockInfo(dest_info.header_for_merge)->flow_guard_name;
            if (guard.empty()) {
                // For an unconditional branch, the break out to an
                // if-selection merge block is implicit.
                return nullptr;
            }
            if (flow_guard_name_ptr != nullptr) {
                *flow_guard_name_ptr = guard;
            }
            // Signal an exit from the branch by assigning false to the guard.
            return create<ast::AssignmentStatement>(
                Source{},
                create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(guard)),
                MakeFalse(Source{}));
        }
    }
    return nullptr;
}
// Builds an if-statement on |condition| whose then-block holds |then_stmt|
// (the block is empty when |then_stmt| is null) and whose else-block, present
// only when |else_stmt| is non-null, holds |else_stmt|.
// Returns nullptr when both statements are null.
const ast::Statement* FunctionEmitter::MakeSimpleIf(const ast::Expression* condition,
                                                    const ast::Statement* then_stmt,
                                                    const ast::Statement* else_stmt) const {
    const bool has_then = (then_stmt != nullptr);
    const bool has_else = (else_stmt != nullptr);
    if (!has_then && !has_else) {
        // Nothing to guard, so no statement is needed.
        return nullptr;
    }

    ast::StatementList then_body;
    if (has_then) {
        then_body.emplace_back(then_stmt);
    }
    auto* then_block = create<ast::BlockStatement>(Source{}, then_body);

    const ast::Statement* else_block = nullptr;
    if (has_else) {
        else_block = create<ast::BlockStatement>(ast::StatementList{else_stmt});
    }

    return create<ast::IfStatement>(Source{}, condition, then_block, else_block);
}
// Emits code for a conditional branch from |src_info| where one edge is a
// case fallthrough and the other edge, of kind |other_edge_kind|, goes to
// |other_dest|. |fall_through_is_true_branch| says which arm of |cond| is the
// fallthrough. Emits an if-statement for the other branch followed by a
// fallthrough statement. Fails if the other edge is a forward edge, an
// if-break, or a backedge, or if no branch statement can be made for it.
bool FunctionEmitter::EmitConditionalCaseFallThrough(const BlockInfo& src_info,
                                                     const ast::Expression* cond,
                                                     EdgeKind other_edge_kind,
                                                     const BlockInfo& other_dest,
                                                     bool fall_through_is_true_branch) {
    // WGSL requires the fallthrough statement to be the final statement of the
    // case clause. So the other branch is emitted first, wrapped in an if
    // statement, and the fallthrough comes after it.

    // There are two distinct destinations, and this is a normal terminator:
    // the source block is *not* the start of an if-selection. So at most one
    // branch is a kForward or kCaseFallThrough.
    switch (other_edge_kind) {
        case EdgeKind::kForward:
            return Fail() << "internal error: normal terminator OpBranchConditional has "
                             "both forward and fallthrough edges";
        case EdgeKind::kIfBreak:
            return Fail() << "internal error: normal terminator OpBranchConditional has "
                             "both IfBreak and fallthrough edges.  Violates nesting rule";
        case EdgeKind::kBack:
            return Fail() << "internal error: normal terminator OpBranchConditional has "
                             "both backedge and fallthrough edges.  Violates nesting rule";
        default:
            break;
    }

    auto* other_stmt = MakeForcedBranch(src_info, other_dest);
    if (other_stmt == nullptr) {
        return Fail() << "internal error: expected a branch for edge-kind "
                      << static_cast<int>(other_edge_kind);
    }

    // Place the other branch in whichever arm is not the fallthrough.
    if (fall_through_is_true_branch) {
        AddStatement(MakeSimpleIf(cond, nullptr, other_stmt));
    } else {
        AddStatement(MakeSimpleIf(cond, other_stmt, nullptr));
    }
    AddStatement(create<ast::FallthroughStatement>(Source{}));

    return success();
}
bool FunctionEmitter::EmitStatementsInBasicBlock(const BlockInfo& block_info,
bool* already_emitted) {
if (*already_emitted) {
// Only emit this part of the basic block once.
return true;
}
// Returns the given list of local definition IDs, sorted by their index.
auto sorted_by_index = [this](const std::vector<uint32_t>& ids) {
auto sorted = ids;
std::stable_sort(sorted.begin(), sorted.end(),
[this](const uint32_t lhs, const uint32_t rhs) {
return GetDefInfo(lhs)->index < GetDefInfo(rhs)->index;
});
return sorted;
};
// Emit declarations of hoisted variables, in index order.
for (auto id : sorted_by_index(block_info.hoisted_ids)) {
const auto* def_inst = def_use_mgr_->GetDef(id);
TINT_ASSERT(Reader, def_inst);
auto* storage_type = RemapStorageClass(parser_impl_.ConvertType(def_inst->type_id()), id);
AddStatement(create<ast::VariableDeclStatement>(
Source{}, parser_impl_.MakeVar(id, ast::StorageClass::kNone, storage_type, nullptr,
ast::AttributeList{})));
auto* type = ty_.Reference(storage_type, ast::StorageClass::kNone);
identifier_types_.emplace(id, type);
}
// Emit declarations of phi state variables, in index order.
for (auto id : sorted_by_index(block_info.phis_needing_state_vars)) {
const auto* def_inst = def_use_mgr_->GetDef(id);
TINT_ASSERT(Reader, def_inst);
const auto phi_var_name = GetDefInfo(id)->phi_var;
TINT_ASSERT(Reader, !phi_var_name.empty());
auto* var = builder_.Var(phi_var_name,
parser_impl_.ConvertType(def_inst->type_id())->Build(builder_));
AddStatement(create<ast::VariableDeclStatement>(Source{}, var));
}
// Emit regular statements.
const spvtools::opt::BasicBlock& bb = *(block_info.basic_block);
const auto* terminator = bb.terminator();
const auto* merge = bb.GetMergeInst(); // Might be nullptr
for (auto& inst : bb) {
if (&inst == terminator || &inst == merge || inst.opcode() == SpvOpLabel ||
inst.opcode() == SpvOpVariable) {
continue;
}
if (!EmitStatement(inst)) {
return false;
}
}
// Emit assignments to carry values to phi nodes in potential destinations.
// Do it in index order.
if (!block_info.phi_assignments.empty()) {
auto sorted = block_info.phi_assignments;
std::stable_sort(
sorted.begin(), sorted.end(),
[this](const BlockInfo::PhiAssignment& lhs, const BlockInfo::PhiAssignment& rhs) {
return GetDefInfo(lhs.phi_id)->index < GetDefInfo(rhs.phi_id)->index;
});
for (auto assignment : block_info.phi_assignments) {
const auto var_name = GetDefInfo(assignment.phi_id)->phi_var;
auto expr = MakeExpression(assignment.value);
if (!expr) {
return false;
}
AddStatement(create<ast::AssignmentStatement>(
Source{},
create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(var_name)),
expr.expr));
}
}
*already_emitted = true;
return true;
}
// Emits a `let` declaration that binds the result ID of |inst| to |expr|, and
// records the type of the new identifier.
// Returns false if |expr| is empty or the `let` cannot be made. Returns true
// without emitting anything when the result is a pointer to be sunk into its
// use.
bool FunctionEmitter::EmitConstDefinition(const spvtools::opt::Instruction& inst,
                                          TypedExpression expr) {
    if (!expr) {
        return false;
    }

    const auto result_id = inst.result_id();

    // Pointers that will be sunk into their use get no definition of their own.
    const bool sunk_into_use = (GetDefInfo(result_id)->skip == SkipReason::kSinkPointerIntoUse);
    if (sunk_into_use) {
        return true;
    }

    expr = AddressOfIfNeeded(expr, &inst);
    auto* let_decl = parser_impl_.MakeLet(result_id, expr.type, expr.expr);
    if (let_decl == nullptr) {
        return false;
    }
    AddStatement(create<ast::VariableDeclStatement>(Source{}, let_decl));
    identifier_types_.emplace(result_id, expr.type);
    return success();
}
// Stores |expr| as the value of the result of |inst|: as an assignment to the
// hoisted variable when one is required, and otherwise as a new const
// definition.
bool FunctionEmitter::EmitConstDefOrWriteToHoistedVar(const spvtools::opt::Instruction& inst,
                                                      TypedExpression expr) {
    if (WriteIfHoistedVar(inst, expr)) {
        // The value went into the hoisted variable; no definition is needed.
        return true;
    }
    return EmitConstDefinition(inst, expr);
}
// If the result of |inst| requires a hoisted definition, emits an assignment
// of |expr| to the hoisted variable and returns true. Otherwise emits nothing
// and returns false.
bool FunctionEmitter::WriteIfHoistedVar(const spvtools::opt::Instruction& inst,
                                        TypedExpression expr) {
    const auto id = inst.result_id();
    const auto* info = GetDefInfo(id);
    const bool is_hoisted = (info != nullptr) && info->requires_hoisted_def;
    if (!is_hoisted) {
        return false;
    }

    // The hoisted variable is named after the result ID; assign to it.
    auto var_name = namer_.Name(id);
    AddStatement(create<ast::AssignmentStatement>(
        Source{},
        create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(var_name)),
        expr.expr));
    return true;
}
// Emits the statement(s) for a single non-terminator instruction |inst|.
// Combinatorial values are either emitted immediately (as a const definition
// or a write to a hoisted variable) or deferred until their single use.
// Image, atomic, memory and other instructions are dispatched to their own
// emitters. Returns false on failure, including for unhandled opcodes.
bool FunctionEmitter::EmitStatement(const spvtools::opt::Instruction& inst) {
    if (failed()) {
        return false;
    }
    const auto result_id = inst.result_id();
    const auto type_id = inst.type_id();

    if (type_id != 0) {
        // Reject any instruction whose result is the per-vertex structure or a
        // pointer to it.
        const auto& builtin_position_info = parser_impl_.GetBuiltInPositionInfo();
        if (type_id == builtin_position_info.struct_type_id) {
            return Fail() << "operations producing a per-vertex structure are not "
                             "supported: "
                          << inst.PrettyPrint();
        }
        if (type_id == builtin_position_info.pointer_type_id) {
            return Fail() << "operations producing a pointer to a per-vertex "
                             "structure are not "
                             "supported: "
                          << inst.PrettyPrint();
        }
    }

    // Handle combinatorial instructions.
    const auto* def_info = GetDefInfo(result_id);
    if (def_info) {
        TypedExpression combinatorial_expr;
        if (def_info->skip == SkipReason::kDontSkip) {
            combinatorial_expr = MaybeEmitCombinatorialValue(inst);
            if (!success()) {
                return false;
            }
        }
        // An access chain or OpCopyObject can generate a skip.
        if (def_info->skip != SkipReason::kDontSkip) {
            return true;
        }

        if (combinatorial_expr.expr != nullptr) {
            if (def_info->requires_hoisted_def || def_info->requires_named_const_def ||
                def_info->num_uses != 1) {
                // Generate a const definition or an assignment to a hoisted definition
                // now and later use the const or variable name at the uses of this
                // value.
                return EmitConstDefOrWriteToHoistedVar(inst, combinatorial_expr);
            }
            // It is harmless to defer emitting the expression until it's used.
            // Any supporting statements have already been emitted.
            singly_used_values_.insert(std::make_pair(result_id, combinatorial_expr));
            return success();
        }
    }
    if (failed()) {
        return false;
    }

    if (IsImageQuery(inst.opcode())) {
        return EmitImageQuery(inst);
    }

    if (IsSampledImageAccess(inst.opcode()) || IsRawImageAccess(inst.opcode())) {
        return EmitImageAccess(inst);
    }

    if (IsAtomicOp(inst.opcode())) {
        return EmitAtomicOp(inst);
    }

    switch (inst.opcode()) {
        case SpvOpNop:
            return true;

        case SpvOpStore: {
            auto ptr_id = inst.GetSingleWordInOperand(0);
            const auto value_id = inst.GetSingleWordInOperand(1);

            const auto ptr_type_id = def_use_mgr_->GetDef(ptr_id)->type_id();
            const auto& builtin_position_info = parser_impl_.GetBuiltInPositionInfo();
            if (ptr_type_id == builtin_position_info.pointer_type_id) {
                return Fail() << "storing to the whole per-vertex structure is not supported: "
                              << inst.PrettyPrint();
            }

            TypedExpression rhs = MakeExpression(value_id);
            if (!rhs) {
                return false;
            }

            TypedExpression lhs;

            // Handle exceptional cases
            switch (GetSkipReason(ptr_id)) {
                case SkipReason::kPointSizeBuiltinPointer:
                    if (IsFloatOne(value_id)) {
                        // Don't store to PointSize
                        return true;
                    }
                    return Fail() << "cannot store a value other than constant 1.0 to "
                                     "PointSize builtin: "
                                  << inst.PrettyPrint();

                case SkipReason::kSampleMaskOutBuiltinPointer:
                    lhs = MakeExpression(sample_mask_out_id);
                    // Bail out before touching lhs.type if the expression could
                    // not be made, as is done for the ordinary store below.
                    if (!lhs) {
                        return false;
                    }
                    if (lhs.type->Is<Pointer>()) {
                        // LHS of an assignment must be a reference type.
                        // Convert the LHS to a reference by dereferencing it.
                        lhs = Dereference(lhs);
                    }
                    // The private variable is an array whose element type is already of
                    // the same type as the value being stored into it.  Form the
                    // reference into the first element.
                    lhs.expr = create<ast::IndexAccessorExpression>(
                        Source{}, lhs.expr, parser_impl_.MakeNullValue(ty_.I32()));
                    if (auto* ref = lhs.type->As<Reference>()) {
                        lhs.type = ref->type;
                    }
                    if (auto* arr = lhs.type->As<Array>()) {
                        lhs.type = arr->type;
                    }
                    TINT_ASSERT(Reader, lhs.type);
                    break;
                default:
                    break;
            }

            // Handle an ordinary store as an assignment.
            if (!lhs) {
                lhs = MakeExpression(ptr_id);
            }
            if (!lhs) {
                return false;
            }

            if (lhs.type->Is<Pointer>()) {
                // LHS of an assignment must be a reference type.
                // Convert the LHS to a reference by dereferencing it.
                lhs = Dereference(lhs);
            }

            AddStatement(create<ast::AssignmentStatement>(Source{}, lhs.expr, rhs.expr));
            return success();
        }

        case SpvOpLoad: {
            // Memory accesses must be issued in SPIR-V program order.
            // So represent a load by a new const definition.
            const auto ptr_id = inst.GetSingleWordInOperand(0);
            const auto skip_reason = GetSkipReason(ptr_id);

            switch (skip_reason) {
                case SkipReason::kPointSizeBuiltinPointer:
                    // Mark the loaded value so that its uses are handled
                    // specially; no statement is emitted for the load itself.
                    GetDefInfo(inst.result_id())->skip = SkipReason::kPointSizeBuiltinValue;
                    return true;
                case SkipReason::kSampleMaskInBuiltinPointer: {
                    auto name = namer_.Name(sample_mask_in_id);
                    const ast::Expression* id_expr = create<ast::IdentifierExpression>(
                        Source{}, builder_.Symbols().Register(name));
                    // SampleMask is an array in Vulkan SPIR-V. Always access the first
                    // element.
                    id_expr = create<ast::IndexAccessorExpression>(
                        Source{}, id_expr, parser_impl_.MakeNullValue(ty_.I32()));

                    auto* loaded_type = parser_impl_.ConvertType(inst.type_id());

                    if (!loaded_type->IsIntegerScalar()) {
                        return Fail() << "loading the whole SampleMask input array is not "
                                         "supported: "
                                      << inst.PrettyPrint();
                    }

                    auto expr = TypedExpression{loaded_type, id_expr};
                    return EmitConstDefinition(inst, expr);
                }
                default:
                    break;
            }
            auto expr = MakeExpression(ptr_id);
            if (!expr) {
                return false;
            }

            // The load result type is the storage type of its operand.
            if (expr.type->Is<Pointer>()) {
                expr = Dereference(expr);
            } else if (auto* ref = expr.type->As<Reference>()) {
                expr.type = ref->type;
            } else {
                Fail() << "OpLoad expression is not a pointer or reference";
                return false;
            }

            return EmitConstDefOrWriteToHoistedVar(inst, expr);
        }

        case SpvOpCopyMemory: {
            // Generate an assignment.
            auto lhs = MakeOperand(inst, 0);
            auto rhs = MakeOperand(inst, 1);
            // Ignore any potential memory operands. Currently they are all for
            // concepts not in WGSL:
            //   Volatile
            //   Aligned
            //   Nontemporal
            //   MakePointerAvailable ; Vulkan memory model
            //   MakePointerVisible   ; Vulkan memory model
            //   NonPrivatePointer    ; Vulkan memory model
            if (!success()) {
                return false;
            }

            // LHS and RHS pointers must be reference types in WGSL.
            if (lhs.type->Is<Pointer>()) {
                lhs = Dereference(lhs);
            }
            if (rhs.type->Is<Pointer>()) {
                rhs = Dereference(rhs);
            }

            AddStatement(create<ast::AssignmentStatement>(Source{}, lhs.expr, rhs.expr));
            return success();
        }

        case SpvOpCopyObject: {
            // Arguably, OpCopyObject is purely combinatorial. On the other hand,
            // it exists to make a new name for something. So we choose to make
            // a new named constant definition.
            auto value_id = inst.GetSingleWordInOperand(0);
            const auto skip = GetSkipReason(value_id);
            if (skip != SkipReason::kDontSkip) {
                // The copy inherits the skip reason and the sink source
                // expression of the value it copies.
                GetDefInfo(inst.result_id())->skip = skip;
                GetDefInfo(inst.result_id())->sink_pointer_source_expr =
                    GetDefInfo(value_id)->sink_pointer_source_expr;
                return true;
            }
            auto expr = AddressOfIfNeeded(MakeExpression(value_id), &inst);
            if (!expr) {
                return false;
            }
            expr.type = RemapStorageClass(expr.type, result_id);
            return EmitConstDefOrWriteToHoistedVar(inst, expr);
        }

        case SpvOpPhi: {
            // Emit a read from the associated state variable.
            TypedExpression expr{parser_impl_.ConvertType(inst.type_id()),
                                 create<ast::IdentifierExpression>(
                                     Source{}, builder_.Symbols().Register(def_info->phi_var))};
            return EmitConstDefOrWriteToHoistedVar(inst, expr);
        }

        case SpvOpOuterProduct:
            // Synthesize an outer product expression in its own statement.
            return EmitConstDefOrWriteToHoistedVar(inst, MakeOuterProduct(inst));

        case SpvOpVectorInsertDynamic:
            // Synthesize a vector insertion in its own statements.
            return MakeVectorInsertDynamic(inst);

        case SpvOpCompositeInsert:
            // Synthesize a composite insertion in its own statements.
            return MakeCompositeInsert(inst);

        case SpvOpFunctionCall:
            return EmitFunctionCall(inst);

        case SpvOpControlBarrier:
            return EmitControlBarrier(inst);

        case SpvOpExtInst:
            if (parser_impl_.IsIgnoredExtendedInstruction(inst)) {
                return true;
            }
            break;

        case SpvOpIAddCarry:
        case SpvOpISubBorrow:
        case SpvOpUMulExtended:
        case SpvOpSMulExtended:
            return Fail() << "extended arithmetic is not finalized for WGSL: "
                             "https://github.com/gpuweb/gpuweb/issues/1565: "
                          << inst.PrettyPrint();

        default:
            break;
    }
    return Fail() << "unhandled instruction with opcode " << inst.opcode() << ": "
                  << inst.PrettyPrint();
}
// Makes the expression for the in-operand of |inst| at |operand_index|, then
// passes it through ParserImpl::RectifyOperandSignedness for that instruction.
// Returns an empty TypedExpression if the operand expression cannot be made.
TypedExpression FunctionEmitter::MakeOperand(const spvtools::opt::Instruction& inst,
                                             uint32_t operand_index) {
    const auto operand_id = inst.GetSingleWordInOperand(operand_index);
    auto operand = MakeExpression(operand_id);
    if (operand) {
        return parser_impl_.RectifyOperandSignedness(inst, std::move(operand));
    }
    return {};
}
// Returns |expr| with its type adjusted: if the type (after unwrapping
// aliases) is a reference or pointer whose storage class is kNone, the type
// is replaced by the same reference or pointer with the kFunction storage
// class. Any other expression is returned unchanged.
TypedExpression FunctionEmitter::InferFunctionStorageClass(TypedExpression expr) {
    if (const auto* ref = expr.type->UnwrapAlias()->As<Reference>()) {
        if (ref->storage_class == ast::StorageClass::kNone) {
            expr.type = ty_.Reference(ref->type, ast::StorageClass::kFunction);
        }
    } else if (const auto* ptr = expr.type->UnwrapAlias()->As<Pointer>()) {
        if (ptr->storage_class == ast::StorageClass::kNone) {
            expr.type = ty_.Pointer(ptr->type, ast::StorageClass::kFunction);
        }
    }
    return expr;
}
// Tries to build the expression for |inst| when it is a combinatorial
// (side-effect free) instruction handled here.
// Returns an empty TypedExpression when the instruction has no result ID,
// when its result type cannot be converted, when any operand cannot be
// translated, or when the opcode is not handled by this function.
TypedExpression FunctionEmitter::MaybeEmitCombinatorialValue(
    const spvtools::opt::Instruction& inst) {
    if (inst.result_id() == 0) {
        return {};
    }

    const auto opcode = inst.opcode();

    const Type* ast_type = nullptr;
    if (inst.type_id()) {
        ast_type = parser_impl_.ConvertType(inst.type_id());
        if (!ast_type) {
            Fail() << "couldn't convert result type for: " << inst.PrettyPrint();
            return {};
        }
    }

    auto binary_op = ConvertBinaryOp(opcode);
    if (binary_op != ast::BinaryOp::kNone) {
        // Check each operand before it is used: the type of the first operand
        // feeds the signedness fix-up of the second.
        auto arg0 = MakeOperand(inst, 0);
        if (!arg0) {
            return {};
        }
        auto raw_arg1 = MakeOperand(inst, 1);
        if (!raw_arg1) {
            return {};
        }
        auto arg1 =
            parser_impl_.RectifySecondOperandSignedness(inst, arg0.type, std::move(raw_arg1));
        if (!arg1) {
            return {};
        }
        auto* binary_expr =
            create<ast::BinaryExpression>(Source{}, binary_op, arg0.expr, arg1.expr);
        TypedExpression result{ast_type, binary_expr};
        return parser_impl_.RectifyForcedResultType(result, inst, arg0.type);
    }

    auto unary_op = ast::UnaryOp::kNegation;
    if (GetUnaryOp(opcode, &unary_op)) {
        auto arg0 = MakeOperand(inst, 0);
        if (!arg0) {
            return {};
        }
        auto* unary_expr = create<ast::UnaryOpExpression>(Source{}, unary_op, arg0.expr);
        TypedExpression result{ast_type, unary_expr};
        return parser_impl_.RectifyForcedResultType(result, inst, arg0.type);
    }

    const char* unary_builtin_name = GetUnaryBuiltInFunctionName(opcode);
    if (unary_builtin_name != nullptr) {
        auto arg0 = MakeOperand(inst, 0);
        if (!arg0) {
            return {};
        }
        ast::ExpressionList params;
        params.emplace_back(arg0.expr);
        return {ast_type, create<ast::CallExpression>(
                              Source{},
                              create<ast::IdentifierExpression>(
                                  Source{}, builder_.Symbols().Register(unary_builtin_name)),
                              std::move(params))};
    }

    const auto builtin = GetBuiltin(opcode);
    if (builtin != sem::BuiltinType::kNone) {
        return MakeBuiltinCall(inst);
    }

    if (opcode == SpvOpFMod) {
        return MakeFMod(inst);
    }

    if (opcode == SpvOpAccessChain || opcode == SpvOpInBoundsAccessChain) {
        return MakeAccessChain(inst);
    }

    if (opcode == SpvOpBitcast) {
        auto arg0 = MakeOperand(inst, 0);
        if (!arg0) {
            return {};
        }
        return {ast_type,
                create<ast::BitcastExpression>(Source{}, ast_type->Build(builder_), arg0.expr)};
    }

    if (opcode == SpvOpShiftLeftLogical || opcode == SpvOpShiftRightLogical ||
        opcode == SpvOpShiftRightArithmetic) {
        auto arg0 = MakeOperand(inst, 0);
        auto shift_amount = MakeOperand(inst, 1);
        if (!arg0 || !shift_amount) {
            return {};
        }
        // The second operand must be unsigned.  It's ok to wrap the shift amount
        // since the shift is modulo the bit width of the first operand.
        auto arg1 = parser_impl_.AsUnsigned(std::move(shift_amount));

        switch (opcode) {
            case SpvOpShiftLeftLogical:
                binary_op = ast::BinaryOp::kShiftLeft;
                break;
            case SpvOpShiftRightLogical:
                // A logical right shift operates on an unsigned first operand.
                arg0 = parser_impl_.AsUnsigned(arg0);
                binary_op = ast::BinaryOp::kShiftRight;
                break;
            case SpvOpShiftRightArithmetic:
                // An arithmetic right shift operates on a signed first operand.
                arg0 = parser_impl_.AsSigned(arg0);
                binary_op = ast::BinaryOp::kShiftRight;
                break;
            default:
                break;
        }
        TypedExpression result{
            ast_type, create<ast::BinaryExpression>(Source{}, binary_op, arg0.expr, arg1.expr)};
        return parser_impl_.RectifyForcedResultType(result, inst, arg0.type);
    }

    auto negated_op = NegatedFloatCompare(opcode);
    if (negated_op != ast::BinaryOp::kNone) {
        auto arg0 = MakeOperand(inst, 0);
        auto arg1 = MakeOperand(inst, 1);
        if (!arg0 || !arg1) {
            return {};
        }
        // Emit the comparison with the negated operator, wrapped in a logical
        // not.
        auto* binary_expr =
            create<ast::BinaryExpression>(Source{}, negated_op, arg0.expr, arg1.expr);
        auto* negated_expr =
            create<ast::UnaryOpExpression>(Source{}, ast::UnaryOp::kNot, binary_expr);
        return {ast_type, negated_expr};
    }

    if (opcode == SpvOpExtInst) {
        if (parser_impl_.IsIgnoredExtendedInstruction(inst)) {
            // Ignore it but don't error out.
            return {};
        }
        if (!parser_impl_.IsGlslExtendedInstruction(inst)) {
            Fail() << "unhandled extended instruction import with ID "
                   << inst.GetSingleWordInOperand(0);
            return {};
        }
        return EmitGlslStd450ExtInst(inst);
    }

    if (opcode == SpvOpCompositeConstruct) {
        ast::ExpressionList operands;
        for (uint32_t iarg = 0; iarg < inst.NumInOperands(); ++iarg) {
            auto operand = MakeOperand(inst, iarg);
            if (!operand) {
                return {};
            }
            operands.emplace_back(operand.expr);
        }
        return {ast_type,
                builder_.Construct(Source{}, ast_type->Build(builder_), std::move(operands))};
    }

    if (opcode == SpvOpCompositeExtract) {
        return MakeCompositeExtract(inst);
    }

    if (opcode == SpvOpVectorShuffle) {
        return MakeVectorShuffle(inst);
    }

    if (opcode == SpvOpVectorExtractDynamic) {
        auto vector = MakeOperand(inst, 0);
        auto index = MakeOperand(inst, 1);
        if (!vector || !index) {
            return {};
        }
        return {ast_type,
                create<ast::IndexAccessorExpression>(Source{}, vector.expr, index.expr)};
    }

    if (opcode == SpvOpConvertSToF || opcode == SpvOpConvertUToF || opcode == SpvOpConvertFToS ||
        opcode == SpvOpConvertFToU) {
        return MakeNumericConversion(inst);
    }

    if (opcode == SpvOpUndef) {
        // Replace undef with the null value.
        return parser_impl_.MakeNullExpression(ast_type);
    }

    if (opcode == SpvOpSelect) {
        return MakeSimpleSelect(inst);
    }

    if (opcode == SpvOpArrayLength) {
        return MakeArrayLength(inst);
    }

    // builtin readonly function
    // glsl.std.450 readonly function

    // Instructions:
    //    OpSatConvertSToU // Only in Kernel (OpenCL), not in WebGPU
    //    OpSatConvertUToS // Only in Kernel (OpenCL), not in WebGPU
    //    OpUConvert // Only needed when multiple widths supported
    //    OpSConvert // Only needed when multiple widths supported
    //    OpFConvert // Only needed when multiple widths supported
    //    OpConvertPtrToU // Not in WebGPU
    //    OpConvertUToPtr // Not in WebGPU
    //    OpPtrCastToGeneric // Not in Vulkan
    //    OpGenericCastToPtr // Not in Vulkan
    //    OpGenericCastToPtrExplicit // Not in Vulkan

    return {};
}
// Emits the expression for an extended instruction from the GLSL.std.450
// instruction set.
//
// Handling, in order:
// - Ldexp is emitted as a call to "ldexp" whose second argument is passed
//   through ToSignedIfUnsigned.
// - When the converted result type is scalar, Normalize, FaceForward, Reflect
//   and Refract are replaced by the emulations below.
// - Otherwise the function named by GetGlslStd450FuncName is called with every
//   in-operand from index 2 onward. An empty name is reported through Fail().
//
// @param inst the extended instruction; in-operand 1 is the extended opcode
//        and the arguments start at in-operand 2
// @returns the typed expression, or an empty TypedExpression when the
//          instruction is not handled or (in the Refract case) a failure has
//          already been recorded
TypedExpression FunctionEmitter::EmitGlslStd450ExtInst(const spvtools::opt::Instruction& inst) {
// In-operand 1 selects which GLSL.std.450 instruction this is.
const auto ext_opcode = inst.GetSingleWordInOperand(1);
if (ext_opcode == GLSLstd450Ldexp) {
// WGSL requires the second argument to be signed.
// Use a type constructor to convert it, which is the same as a bitcast.
// If the value would go from very large positive to negative, then the
// original result would have been infinity. And since WGSL
// implementations may assume that infinities are not present, then we
// don't have to worry about that case.
auto e1 = MakeOperand(inst, 2);
auto e2 = ToSignedIfUnsigned(MakeOperand(inst, 3));
return {e1.type, builder_.Call(Source{}, "ldexp", ast::ExpressionList{e1.expr, e2.expr})};
}
auto* result_type = parser_impl_.ConvertType(inst.type_id());
if (result_type->IsScalar()) {
// Some GLSLstd450 builtins have scalar forms not supported by WGSL.
// Emulate them.
switch (ext_opcode) {
case GLSLstd450Normalize:
// WGSL does not have scalar form of the normalize builtin.
// The answer would be 1 anyway, so return that directly.
return {ty_.F32(), builder_.Expr(1_f)};
case GLSLstd450FaceForward: {
// If dot(Nref, Incident) < 0, the result is Normal, otherwise -Normal.
// Also: select(-normal,normal, Incident*Nref < 0)
// (The dot product of scalars is their product.)
// Use a multiply instead of comparing floating point signs. It should
// be among the fastest operations on a GPU.
// Note: normal.expr appears twice in the generated expression.
auto normal = MakeOperand(inst, 2);
auto incident = MakeOperand(inst, 3);
auto nref = MakeOperand(inst, 4);
TINT_ASSERT(Reader, normal.type->Is<F32>());
TINT_ASSERT(Reader, incident.type->Is<F32>());
TINT_ASSERT(Reader, nref.type->Is<F32>());
return {ty_.F32(),
builder_.Call(
Source{}, "select",
ast::ExpressionList{create<ast::UnaryOpExpression>(
Source{}, ast::UnaryOp::kNegation, normal.expr),
normal.expr,
create<ast::BinaryExpression>(
Source{}, ast::BinaryOp::kLessThan,
builder_.Mul({}, incident.expr, nref.expr),
builder_.Expr(0_f))})};
}
case GLSLstd450Reflect: {
// Compute Incident - 2 * Normal * Normal * Incident
// Note: both incident.expr and normal.expr appear twice in the
// generated expression.
auto incident = MakeOperand(inst, 2);
auto normal = MakeOperand(inst, 3);
TINT_ASSERT(Reader, incident.type->Is<F32>());
TINT_ASSERT(Reader, normal.type->Is<F32>());
return {
ty_.F32(),
builder_.Sub(
incident.expr,
builder_.Mul(2_f, builder_.Mul(normal.expr,
builder_.Mul(normal.expr, incident.expr))))};
}
case GLSLstd450Refract: {
// It's a complicated expression. Compute it in two dimensions, but
// with a 0-valued y component in both the incident and normal vectors,
// then take the x component of that result.
auto incident = MakeOperand(inst, 2);
auto normal = MakeOperand(inst, 3);
auto eta = MakeOperand(inst, 4);
TINT_ASSERT(Reader, incident.type->Is<F32>());
TINT_ASSERT(Reader, normal.type->Is<F32>());
TINT_ASSERT(Reader, eta.type->Is<F32>());
if (!success()) {
return {};
}
// The result type is taken from the eta operand.
const Type* f32 = eta.type;
return {f32, builder_.MemberAccessor(
builder_.Call(
Source{}, "refract",
ast::ExpressionList{
builder_.vec2<tint::f32>(incident.expr, 0_f),
builder_.vec2<tint::f32>(normal.expr, 0_f), eta.expr}),
"x")};
}
default:
// Every other scalar instruction takes the general path below.
break;
}
}
// General case: call the function that GetGlslStd450FuncName maps the
// extended opcode to.
const auto name = GetGlslStd450FuncName(ext_opcode);
if (name.empty()) {
Fail() << "unhandled GLSL.std.450 instruction " << ext_opcode;
return {};
}
auto* func = create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(name));
ast::ExpressionList operands;
// Type of the first argument, handed to RectifyForcedResultType below.
const Type* first_operand_type = nullptr;
// All parameters to GLSL.std.450 extended instructions are IDs.
for (uint32_t iarg = 2; iarg < inst.NumInOperands(); ++iarg) {
TypedExpression operand = MakeOperand(inst, iarg);
if (first_operand_type == nullptr) {
first_operand_type = operand.type;
}
operands.emplace_back(operand.expr);
}
auto* call = create<ast::CallExpression>(Source{}, func, std::move(operands));
TypedExpression call_expr{result_type, call};
return parser_impl_.RectifyForcedResultType(call_expr, inst, first_operand_type);
}
// Returns an identifier expression naming the single vector component at
// index |i|: "x", "y", "z" or "w".
// Reports a failure and returns nullptr when |i| is not less than
// kMaxVectorLen.
ast::IdentifierExpression* FunctionEmitter::Swizzle(uint32_t i) {
    if (i < kMaxVectorLen) {
        const char* component_names[] = {"x", "y", "z", "w"};
        // The mask keeps the lookup inside the four-entry table.
        const char* component = component_names[i & 3];
        return create<ast::IdentifierExpression>(Source{},
                                                 builder_.Symbols().Register(component));
    }
    Fail() << "vector component index is larger than " << kMaxVectorLen - 1 << ": " << i;
    return nullptr;
}
// Returns an identifier expression naming the first |n| vector components:
// "x", "xy" or "xyz" for |n| equal to 1, 2 or 3.
// Reports a failure and returns nullptr for any other |n|.
ast::IdentifierExpression* FunctionEmitter::PrefixSwizzle(uint32_t n) {
    const char* prefix = nullptr;
    if (n == 1) {
        prefix = "x";
    } else if (n == 2) {
        prefix = "xy";
    } else if (n == 3) {
        prefix = "xyz";
    }
    if (prefix == nullptr) {
        Fail() << "invalid swizzle prefix count: " << n;
        return nullptr;
    }
    return create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(prefix));
}
// Builds the emulation expression x - y * floor(x / y), where x and y are
// in-operands 0 and 1 of |inst|.
// The result carries the type of the first operand. Returns an empty
// TypedExpression when either operand could not be made.
TypedExpression FunctionEmitter::MakeFMod(const spvtools::opt::Instruction& inst) {
    auto dividend = MakeOperand(inst, 0);
    auto divisor = MakeOperand(inst, 1);
    if (!dividend || !divisor) {
        return {};
    }
    // Each nested call must finish before the enclosing one starts, so the
    // nodes are created innermost first: divide, floor, multiply, subtract.
    auto* remainder = builder_.Sub(
        dividend.expr,
        builder_.Mul(divisor.expr,
                     builder_.Call("floor", builder_.Div(dividend.expr, divisor.expr))));
    return {dividend.type, remainder};
}
// Converts a SPIR-V access chain instruction into a nested Tint AST expression
// of reference type.
//
// - If the base pointer is marked as skipped, its skip reason and sunk pointer
//   expression are copied to the result ID and nothing is emitted.
// - If the base is the per-vertex variable, only the Position member can be
//   reached. A PointSize access marks the result as skipped, and any other
//   member is an error.
// - If any index steps into a vector, the final reference is also recorded in
//   the result's DefInfo so it can be sunk into its point of use.
//
// @param inst the access chain instruction; in-operand 0 is the base pointer
//        and the remaining in-operands are index IDs
// @returns the expression for the access, or an empty TypedExpression when the
//          base is skipped, the PointSize member is addressed, or an error
//          was reported
TypedExpression FunctionEmitter::MakeAccessChain(const spvtools::opt::Instruction& inst) {
    if (inst.NumInOperands() < 1) {
        // Binary parsing will fail on this anyway.
        Fail() << "invalid access chain: has no input operands";
        return {};
    }

    const auto base_id = inst.GetSingleWordInOperand(0);
    const auto base_skip = GetSkipReason(base_id);
    if (base_skip != SkipReason::kDontSkip) {
        // This can occur for AccessChain with no indices.
        GetDefInfo(inst.result_id())->skip = base_skip;
        GetDefInfo(inst.result_id())->sink_pointer_source_expr =
            GetDefInfo(base_id)->sink_pointer_source_expr;
        return {};
    }

    auto ptr_ty_id = def_use_mgr_->GetDef(base_id)->type_id();
    uint32_t first_index = 1;
    const auto num_in_operands = inst.NumInOperands();

    bool sink_pointer = false;
    TypedExpression current_expr;

    // If the variable was originally gl_PerVertex, then in the AST we
    // have instead emitted a gl_Position variable.
    // If computing the pointer to the Position builtin, then emit the
    // pointer to the generated gl_Position variable.
    // If computing the pointer to the PointSize builtin, then mark the
    // result as skippable due to being the point-size pointer.
    // If computing the pointer to the ClipDistance or CullDistance builtins,
    // then error out.
    {
        const auto& builtin_position_info = parser_impl_.GetBuiltInPositionInfo();
        if (base_id == builtin_position_info.per_vertex_var_id) {
            // The member index is read from in-operand |first_index|, so it
            // must exist. Without this guard an index-less access chain would
            // read past the end of the operand list.
            if (first_index >= num_in_operands) {
                Fail() << "access chain into per-vertex structure has no member index: "
                       << inst.PrettyPrint();
                return {};
            }
            // We only support the Position member.
            const auto* member_index_inst =
                def_use_mgr_->GetDef(inst.GetSingleWordInOperand(first_index));
            if (member_index_inst == nullptr) {
                Fail() << "first index of access chain does not reference an instruction: "
                       << inst.PrettyPrint();
                return {};
            }
            const auto* member_index_const = constant_mgr_->GetConstantFromInst(member_index_inst);
            if (member_index_const == nullptr) {
                Fail() << "first index of access chain into per-vertex structure is "
                          "not a constant: "
                       << inst.PrettyPrint();
                return {};
            }
            const auto* member_index_const_int = member_index_const->AsIntConstant();
            if (member_index_const_int == nullptr) {
                Fail() << "first index of access chain into per-vertex structure is "
                          "not a constant integer: "
                       << inst.PrettyPrint();
                return {};
            }
            const auto member_index_value = member_index_const_int->GetZeroExtendedValue();
            if (member_index_value != builtin_position_info.position_member_index) {
                if (member_index_value == builtin_position_info.pointsize_member_index) {
                    if (auto* def_info = GetDefInfo(inst.result_id())) {
                        def_info->skip = SkipReason::kPointSizeBuiltinPointer;
                        return {};
                    }
                } else {
                    // TODO(dneto): Handle ClipDistance and CullDistance
                    Fail() << "accessing per-vertex member " << member_index_value
                           << " is not supported. Only Position is supported, and "
                              "PointSize is ignored";
                    return {};
                }
            }

            // Skip past the member index that gets us to Position.
            first_index = first_index + 1;
            // Replace the gl_PerVertex reference with the gl_Position reference
            ptr_ty_id = builtin_position_info.position_member_pointer_type_id;

            auto name = namer_.Name(base_id);
            current_expr.expr =
                create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(name));
            current_expr.type = parser_impl_.ConvertType(ptr_ty_id, PtrAs::Ref);
        }
    }

    // A SPIR-V access chain is a single instruction with multiple indices
    // walking down into composites. The Tint AST represents this as
    // ever-deeper nested indexing expressions. Start off with an expression
    // for the base, and then bury that inside nested indexing expressions.
    if (!current_expr) {
        current_expr = InferFunctionStorageClass(MakeOperand(inst, 0));
        if (current_expr.type->Is<Pointer>()) {
            current_expr = Dereference(current_expr);
        }
    }
    const auto constants = constant_mgr_->GetOperandConstants(&inst);

    const auto* ptr_type_inst = def_use_mgr_->GetDef(ptr_ty_id);
    if (!ptr_type_inst || (ptr_type_inst->opcode() != SpvOpTypePointer)) {
        Fail() << "Access chain %" << inst.result_id() << " base pointer is not of pointer type";
        return {};
    }
    SpvStorageClass storage_class =
        static_cast<SpvStorageClass>(ptr_type_inst->GetSingleWordInOperand(0));
    uint32_t pointee_type_id = ptr_type_inst->GetSingleWordInOperand(1);

    // Build up a nested expression for the access chain by walking down the type
    // hierarchy, maintaining |pointee_type_id| as the SPIR-V ID of the type of
    // the object pointed to after processing the previous indices.
    for (uint32_t index = first_index; index < num_in_operands; ++index) {
        const auto* index_const = constants[index] ? constants[index]->AsIntConstant() : nullptr;
        const int64_t index_const_val = index_const ? index_const->GetSignExtendedValue() : 0;
        const ast::Expression* next_expr = nullptr;

        const auto* pointee_type_inst = def_use_mgr_->GetDef(pointee_type_id);
        if (!pointee_type_inst) {
            Fail() << "pointee type %" << pointee_type_id << " is invalid after following "
                   << (index - first_index) << " indices: " << inst.PrettyPrint();
            return {};
        }
        switch (pointee_type_inst->opcode()) {
            case SpvOpTypeVector:
                if (index_const) {
                    // Try generating a MemberAccessor expression
                    const auto num_elems = pointee_type_inst->GetSingleWordInOperand(1);
                    if (index_const_val < 0 || num_elems <= index_const_val) {
                        Fail() << "Access chain %" << inst.result_id() << " index %"
                               << inst.GetSingleWordInOperand(index) << " value " << index_const_val
                               << " is out of bounds for vector of " << num_elems << " elements";
                        return {};
                    }
                    if (uint64_t(index_const_val) >= kMaxVectorLen) {
                        Fail() << "internal error: swizzle index " << index_const_val
                               << " is too big. Max handled index is " << kMaxVectorLen - 1;
                        // Stop here: Swizzle() would return nullptr for this
                        // index, and a member accessor must not be built
                        // around a null member.
                        return {};
                    }
                    next_expr = create<ast::MemberAccessorExpression>(
                        Source{}, current_expr.expr, Swizzle(uint32_t(index_const_val)));
                } else {
                    // Non-constant index. Use array syntax
                    next_expr = create<ast::IndexAccessorExpression>(Source{}, current_expr.expr,
                                                                     MakeOperand(inst, index).expr);
                }
                // All vector components are the same type.
                pointee_type_id = pointee_type_inst->GetSingleWordInOperand(0);
                // Sink pointers to vector components.
                sink_pointer = true;
                break;
            case SpvOpTypeMatrix:
                // Use array syntax.
                next_expr = create<ast::IndexAccessorExpression>(Source{}, current_expr.expr,
                                                                 MakeOperand(inst, index).expr);
                // All matrix components are the same type.
                pointee_type_id = pointee_type_inst->GetSingleWordInOperand(0);
                break;
            case SpvOpTypeArray:
                next_expr = create<ast::IndexAccessorExpression>(Source{}, current_expr.expr,
                                                                 MakeOperand(inst, index).expr);
                pointee_type_id = pointee_type_inst->GetSingleWordInOperand(0);
                break;
            case SpvOpTypeRuntimeArray:
                next_expr = create<ast::IndexAccessorExpression>(Source{}, current_expr.expr,
                                                                 MakeOperand(inst, index).expr);
                pointee_type_id = pointee_type_inst->GetSingleWordInOperand(0);
                break;
            case SpvOpTypeStruct: {
                // Structure members can only be selected by a constant index.
                if (!index_const) {
                    Fail() << "Access chain %" << inst.result_id() << " index %"
                           << inst.GetSingleWordInOperand(index)
                           << " is a non-constant index into a structure %" << pointee_type_id;
                    return {};
                }
                const auto num_members = pointee_type_inst->NumInOperands();
                if ((index_const_val < 0) || num_members <= uint64_t(index_const_val)) {
                    Fail() << "Access chain %" << inst.result_id() << " index value "
                           << index_const_val << " is out of bounds for structure %"
                           << pointee_type_id << " having " << num_members << " members";
                    return {};
                }
                auto name = namer_.GetMemberName(pointee_type_id, uint32_t(index_const_val));
                auto* member_access =
                    create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(name));

                next_expr = create<ast::MemberAccessorExpression>(Source{}, current_expr.expr,
                                                                  member_access);
                pointee_type_id = pointee_type_inst->GetSingleWordInOperand(
                    static_cast<uint32_t>(index_const_val));
                break;
            }
            default:
                Fail() << "Access chain with unknown or invalid pointee type %" << pointee_type_id
                       << ": " << pointee_type_inst->PrettyPrint();
                return {};
        }
        // The expression so far is typed as a reference to the new pointee, in
        // the storage class of the base pointer.
        const auto pointer_type_id = type_mgr_->FindPointerToType(pointee_type_id, storage_class);
        auto* type = parser_impl_.ConvertType(pointer_type_id, PtrAs::Ref);
        TINT_ASSERT(Reader, type && type->Is<Reference>());
        current_expr = TypedExpression{type, next_expr};
    }

    if (sink_pointer) {
        // Capture the reference so that we can sink it into the point of use.
        GetDefInfo(inst.result_id())->skip = SkipReason::kSinkPointerIntoUse;
        GetDefInfo(inst.result_id())->sink_pointer_source_expr = current_expr;
    }

    return current_expr;
}
// Builds the expression for a composite extract instruction.
// In-operand 0 of |inst| is the composite value; the literal indices start at
// in-operand 1. The walk into the composite is delegated to
// MakeCompositeValueDecomposition.
// Returns an empty TypedExpression when the composite operand could not be
// made.
TypedExpression FunctionEmitter::MakeCompositeExtract(const spvtools::opt::Instruction& inst) {
    // This is structurally similar to creating an access chain, but
    // the SPIR-V instruction has literal indices instead of IDs for indices.
    const uint32_t composite_operand = 0u;
    const int first_literal_operand = 1;

    TypedExpression composite = MakeOperand(inst, composite_operand);
    if (!composite) {
        return {};
    }

    // The walk starts from the SPIR-V type of the composite value.
    const auto composite_id = inst.GetSingleWordInOperand(composite_operand);
    const auto composite_type_id = def_use_mgr_->GetDef(composite_id)->type_id();
    return MakeCompositeValueDecomposition(inst, composite, composite_type_id,
                                           first_literal_operand);
}
// Walks into a composite value following the literal indices of an
// OpCompositeExtract or OpCompositeInsert instruction, and returns the nested
// accessor expression that selects the addressed component.
//
// @param inst the instruction supplying the literal indices; any opcode other
//        than OpCompositeExtract or OpCompositeInsert is an internal error
// @param composite the expression for the composite being walked into
// @param composite_type_id the SPIR-V ID of the type of |composite|
// @param index_start the in-operand position of the first literal index
// @returns the expression for the selected component, typed by converting the
//          SPIR-V type reached at the end of the walk, or an empty
//          TypedExpression when an error was reported
TypedExpression FunctionEmitter::MakeCompositeValueDecomposition(
    const spvtools::opt::Instruction& inst,
    TypedExpression composite,
    uint32_t composite_type_id,
    int index_start) {
    // This is structurally similar to creating an access chain, but
    // the SPIR-V instruction has literal indices instead of IDs for indices.

    // A SPIR-V composite extract is a single instruction with multiple
    // literal indices walking down into composites.
    // A SPIR-V composite insert is similar but also tells you what component
    // to inject. This function is responsible for the walking-into part
    // of composite-insert.
    //
    // The Tint AST represents this as ever-deeper nested indexing expressions.
    // Start off with an expression for the composite, and then bury that inside
    // nested indexing expressions.

    auto current_expr = composite;
    auto current_type_id = composite_type_id;

    // Makes an unsigned integer literal expression for a literal index.
    auto make_index = [this](uint32_t literal) {
        return create<ast::IntLiteralExpression>(Source{}, literal,
                                                 ast::IntLiteralExpression::Suffix::kU);
    };

    // Build up a nested expression for the decomposition by walking down the type
    // hierarchy, maintaining |current_type_id| as the SPIR-V ID of the type of
    // the object pointed to after processing the previous indices.
    const auto num_in_operands = inst.NumInOperands();
    for (uint32_t index = static_cast<uint32_t>(index_start); index < num_in_operands; ++index) {
        const uint32_t index_val = inst.GetSingleWordInOperand(index);

        const auto* current_type_inst = def_use_mgr_->GetDef(current_type_id);
        if (!current_type_inst) {
            Fail() << "composite type %" << current_type_id << " is invalid after following "
                   << (index - static_cast<uint32_t>(index_start))
                   << " indices: " << inst.PrettyPrint();
            return {};
        }
        // Name of the instruction, used only in diagnostics.
        const char* operation_name = nullptr;
        switch (inst.opcode()) {
            case SpvOpCompositeExtract:
                operation_name = "OpCompositeExtract";
                break;
            case SpvOpCompositeInsert:
                operation_name = "OpCompositeInsert";
                break;
            default:
                Fail() << "internal error: unhandled " << inst.PrettyPrint();
                return {};
        }
        const ast::Expression* next_expr = nullptr;
        switch (current_type_inst->opcode()) {
            case SpvOpTypeVector: {
                // Try generating a MemberAccessor expression. That result in something
                // like "foo.z", which is more idiomatic than "foo[2]".
                const auto num_elems = current_type_inst->GetSingleWordInOperand(1);
                if (num_elems <= index_val) {
                    Fail() << operation_name << " %" << inst.result_id() << " index value "
                           << index_val << " is out of bounds for vector of " << num_elems
                           << " elements";
                    return {};
                }
                if (index_val >= kMaxVectorLen) {
                    Fail() << "internal error: swizzle index " << index_val
                           << " is too big. Max handled index is " << kMaxVectorLen - 1;
                    return {};
                }
                next_expr = create<ast::MemberAccessorExpression>(Source{}, current_expr.expr,
                                                                  Swizzle(index_val));
                // All vector components are the same type.
                current_type_id = current_type_inst->GetSingleWordInOperand(0);
                break;
            }
            case SpvOpTypeMatrix: {
                // Check bounds
                const auto num_elems = current_type_inst->GetSingleWordInOperand(1);
                if (num_elems <= index_val) {
                    Fail() << operation_name << " %" << inst.result_id() << " index value "
                           << index_val << " is out of bounds for matrix of " << num_elems
                           << " elements";
                    return {};
                }
                if (index_val >= kMaxVectorLen) {
                    Fail() << "internal error: swizzle index " << index_val
                           << " is too big. Max handled index is " << kMaxVectorLen - 1;
                    // Stop after reporting, as the vector case does, rather
                    // than continuing to build an expression after a failure.
                    return {};
                }
                // Use array syntax.
                next_expr = create<ast::IndexAccessorExpression>(Source{}, current_expr.expr,
                                                                 make_index(index_val));
                // All matrix components are the same type.
                current_type_id = current_type_inst->GetSingleWordInOperand(0);
                break;
            }
            case SpvOpTypeArray:
                // The array size could be a spec constant, and so it's not always
                // statically checkable. Instead, rely on a runtime index clamp
                // or runtime check to keep this safe.
                next_expr = create<ast::IndexAccessorExpression>(Source{}, current_expr.expr,
                                                                 make_index(index_val));
                current_type_id = current_type_inst->GetSingleWordInOperand(0);
                break;
            case SpvOpTypeRuntimeArray:
                Fail() << "can't do " << operation_name
                       << " on a runtime array: " << inst.PrettyPrint();
                return {};
            case SpvOpTypeStruct: {
                const auto num_members = current_type_inst->NumInOperands();
                if (num_members <= index_val) {
                    Fail() << operation_name << " %" << inst.result_id() << " index value "
                           << index_val << " is out of bounds for structure %" << current_type_id
                           << " having " << num_members << " members";
                    return {};
                }
                auto name = namer_.GetMemberName(current_type_id, uint32_t(index_val));
                auto* member_access =
                    create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(name));

                next_expr = create<ast::MemberAccessorExpression>(Source{}, current_expr.expr,
                                                                  member_access);
                current_type_id = current_type_inst->GetSingleWordInOperand(index_val);
                break;
            }
            default:
                Fail() << operation_name << " with bad type %" << current_type_id << ": "
                       << current_type_inst->PrettyPrint();
                return {};
        }
        current_expr = TypedExpression{parser_impl_.ConvertType(current_type_id), next_expr};
    }
    return current_expr;
}
// Returns a new boolean literal expression with value true, located at
// |source|.
const ast::Expression* FunctionEmitter::MakeTrue(const Source& source) const {
    const auto* true_literal = create<ast::BoolLiteralExpression>(source, true);
    return true_literal;
}
// Returns a new boolean literal expression with value false, located at
// |source|.
const ast::Expression* FunctionEmitter::MakeFalse(const Source& source) const {
    const auto* false_literal = create<ast::BoolLiteralExpression>(source, false);
    return false_literal;
}
// Builds the expression for a vector shuffle as a vector constructor whose
// arguments are single-component accesses into the two source vectors.
//
// In-operands 0 and 1 of |inst| are the source vectors; the remaining
// in-operands are literal component selectors. A selector below the length of
// the first vector picks from it, a selector in the following range picks
// from the second vector, and 0xFFFFFFFF produces a null value of the result
// component type. Any other selector is reported as an error.
// Returns an empty TypedExpression on error.
TypedExpression FunctionEmitter::MakeVectorShuffle(const spvtools::opt::Instruction& inst) {
    const auto first_vec_id = inst.GetSingleWordInOperand(0);
    const auto second_vec_id = inst.GetSingleWordInOperand(1);
    const spvtools::opt::Instruction& first_vec = *(def_use_mgr_->GetDef(first_vec_id));
    const spvtools::opt::Instruction& second_vec = *(def_use_mgr_->GetDef(second_vec_id));
    const auto first_len = type_mgr_->GetType(first_vec.type_id())->AsVector()->element_count();
    const auto second_len = type_mgr_->GetType(second_vec.type_id())->AsVector()->element_count();

    // Idiomatic vector accessors.

    // Generate an ast::TypeConstructor expression.
    // Assume the literal indices are valid, and there is a valid number of them.
    auto source = GetSourceForInst(inst);
    const Vector* result_type = As<Vector>(parser_impl_.ConvertType(inst.type_id()));
    ast::ExpressionList values;

    // Appends the access to component |component| of the vector with ID
    // |vector_id|. Returns false when the vector expression could not be made.
    auto append_component = [&](uint32_t vector_id, uint32_t component) {
        auto vector_expr = MakeExpression(vector_id);
        if (!vector_expr) {
            return false;
        }
        values.emplace_back(create<ast::MemberAccessorExpression>(source, vector_expr.expr,
                                                                  Swizzle(component)));
        return true;
    };

    for (uint32_t operand = 2; operand < inst.NumInOperands(); ++operand) {
        const auto selector = inst.GetSingleWordInOperand(operand);
        if (selector < first_len) {
            if (!append_component(first_vec_id, selector)) {
                return {};
            }
            continue;
        }
        if (selector < first_len + second_len) {
            const auto sub_index = selector - first_len;
            TINT_ASSERT(Reader, sub_index < kMaxVectorLen);
            if (!append_component(second_vec_id, sub_index)) {
                return {};
            }
            continue;
        }
        if (selector == 0xFFFFFFFF) {
            // By rule, this maps to OpUndef. Instead, make it zero.
            values.emplace_back(parser_impl_.MakeNullValue(result_type->type));
            continue;
        }
        Fail() << "invalid vectorshuffle ID %" << inst.result_id()
               << ": index too large: " << selector;
        return {};
    }
    return {result_type, builder_.Construct(source, result_type->Build(builder_), values)};
}
// Creates a DefInfo for each variable listed in the parser's special builtins
// and records how pointers to it are to be treated.
//
// PointSize and SampleMask variables get a skip reason. For SampleMask the
// variable ID is also remembered as the input or the output sample mask,
// depending on the storage class in in-operand 0 of the variable.
// Returns false, after reporting a failure, for an unrecognized builtin.
bool FunctionEmitter::RegisterSpecialBuiltInVariables() {
    size_t next_index = def_info_.size();
    for (auto& entry : parser_impl_.special_builtins()) {
        const auto var_id = entry.first;
        const auto builtin = entry.second;
        const auto* var = def_use_mgr_->GetDef(var_id);
        // Builtins are always defined outside the function.
        auto& def = def_info_[var_id];
        def = std::make_unique<DefInfo>(*var, false, 0, next_index);
        ++next_index;

        switch (builtin) {
            case SpvBuiltInSampleId:
            case SpvBuiltInInstanceIndex:
            case SpvBuiltInVertexIndex:
            case SpvBuiltInLocalInvocationIndex:
            case SpvBuiltInLocalInvocationId:
            case SpvBuiltInGlobalInvocationId:
            case SpvBuiltInWorkgroupId:
            case SpvBuiltInNumWorkgroups:
                // Recognized; no skip reason is set here.
                break;
            case SpvBuiltInPointSize:
                def->skip = SkipReason::kPointSizeBuiltinPointer;
                break;
            case SpvBuiltInSampleMask: {
                // Distinguish between input and output variable.
                const bool is_input =
                    static_cast<SpvStorageClass>(var->GetSingleWordInOperand(0)) ==
                    SpvStorageClassInput;
                if (is_input) {
                    sample_mask_in_id = var_id;
                    def->skip = SkipReason::kSampleMaskInBuiltinPointer;
                } else {
                    sample_mask_out_id = var_id;
                    def->skip = SkipReason::kSampleMaskOutBuiltinPointer;
                }
                break;
            }
            default:
                return Fail() << "unrecognized special builtin: " << int(builtin);
        }
    }
    return true;
}
// Creates a DefInfo for every instruction in this function that has a result
// ID (labels excluded), visiting blocks in block order.
//
// For pointer-typed results the storage class is also recorded: from the
// result type by default, and from the first operand for access chains and
// OpCopyObject. Results whose type, after removing pointers, is a sampler,
// image or sampled image are marked as opaque objects to be skipped.
// Returns false, after reporting a failure, for an undef pointer or a pointer
// produced by any other opcode.
bool FunctionEmitter::RegisterLocallyDefinedValues() {
    size_t next_index = def_info_.size();
    for (auto block_id : block_order_) {
        const auto* block_info = GetBlockInfo(block_id);
        const auto block_pos = block_info->pos;
        for (const auto& inst : *(block_info->basic_block)) {
            const auto result_id = inst.result_id();
            if ((result_id == 0) || (inst.opcode() == SpvOpLabel)) {
                continue;
            }
            auto& info = def_info_[result_id];
            info = std::make_unique<DefInfo>(inst, true, block_pos, next_index);
            ++next_index;

            const auto* type = type_mgr_->GetType(inst.type_id());
            if (type == nullptr) {
                continue;
            }

            // Determine storage class for pointer values. Do this in order because
            // we might rely on the storage class for a previously-visited definition.
            // Logical pointers can't be transmitted through OpPhi, so remaining
            // pointer definitions are SSA values, and their definitions must be
            // visited before their uses.
            if (type->AsPointer()) {
                auto* ast_type = parser_impl_.ConvertType(inst.type_id());
                if (auto* ast_ptr = ast_type ? ast_type->As<Pointer>() : nullptr) {
                    info->storage_class = ast_ptr->storage_class;
                }

                const auto opcode = inst.opcode();
                if (opcode == SpvOpUndef) {
                    return Fail() << "undef pointer is not valid: " << inst.PrettyPrint();
                }
                if ((opcode == SpvOpAccessChain) || (opcode == SpvOpInBoundsAccessChain) ||
                    (opcode == SpvOpCopyObject)) {
                    // Inherit from the first operand. We need this so we can pick up
                    // a remapped storage buffer.
                    info->storage_class =
                        GetStorageClassForPointerValue(inst.GetSingleWordInOperand(0));
                } else if (opcode != SpvOpVariable) {
                    // OpVariable keeps the default decision based on the result
                    // type; anything else is not expected to yield a pointer.
                    return Fail() << "pointer defined in function from unknown opcode: "
                                  << inst.PrettyPrint();
                }
            }

            // Look through any pointers to find the underlying type.
            auto* pointee = type;
            while (auto* as_pointer = pointee->AsPointer()) {
                pointee = as_pointer->pointee_type();
            }
            const bool is_opaque =
                pointee->AsSampler() || pointee->AsImage() || pointee->AsSampledImage();
            if (is_opaque) {
                // Defer code generation until the instruction that actually acts on
                // the image.
                info->skip = SkipReason::kOpaqueObject;
            }
        }
    }
    return true;
}
// Returns the storage class to use for the pointer value with the given ID.
// A valid storage class already recorded in the DefInfo for |id| takes
// precedence. Otherwise the storage class comes from converting the type of
// the defining instruction, when that type converts to a pointer.
// Returns ast::StorageClass::kInvalid when neither source yields one.
ast::StorageClass FunctionEmitter::GetStorageClassForPointerValue(uint32_t id) {
    const auto entry = def_info_.find(id);
    if (entry != def_info_.end()) {
        const auto recorded = entry->second->storage_class;
        if (recorded != ast::StorageClass::kInvalid) {
            return recorded;
        }
    }
    // Fall back to the declared type of the value.
    if (const auto type_id = def_use_mgr_->GetDef(id)->type_id()) {
        if (auto* ptr = As<Pointer>(parser_impl_.ConvertType(type_id))) {
            return ptr->storage_class;
        }
    }
    return ast::StorageClass::kInvalid;
}
// Returns |type| with its storage class replaced by the one determined for
// the value |result_id|, when |type| is a pointer whose storage class differs.
// Non-pointer types, and pointers that already agree, are returned unchanged.
const Type* FunctionEmitter::RemapStorageClass(const Type* type, uint32_t result_id) {
    auto* pointer_type = As<Pointer>(type);
    if (pointer_type == nullptr) {
        return type;
    }
    // Remap an old-style storage buffer pointer to a new-style storage
    // buffer pointer.
    const auto wanted_class = GetStorageClassForPointerValue(result_id);
    if (pointer_type->storage_class == wanted_class) {
        return type;
    }
    return ty_.Pointer(pointer_type->type, wanted_class);
}
// Scans the definitions recorded in def_info_ and decides which values must be
// emitted as a named constant definition or as a hoisted variable definition.
//
// The work is done in three passes:
// 1. Operands that the generated code mentions more than once (both vector
//    operands of OpVectorShuffle and OpOuterProduct, and selected operands of
//    the GLSL.std.450 FaceForward and Reflect instructions) are marked as
//    requiring a named constant definition.
// 2. Walking the instructions in block order, the use count, last use
//    position and "used in another construct" flag are updated for every
//    used ID that has a DefInfo. For each OpPhi, a state variable name is
//    chosen, the assignments are queued on the predecessor blocks that are in
//    the block order, and the variable's declaration is scheduled on the
//    block named by begin_id of the construct that GetEnclosingScope returns.
// 3. Each locally defined ID with at least one use is checked. It is marked
//    when its last use lies outside the construct of its definition, or when
//    its storage class is kInvalid and it is used in another construct. A
//    marked ID gets a named constant definition if the enclosing scope is the
//    construct of its definition, and a hoisted definition otherwise.
void FunctionEmitter::FindValuesNeedingNamedOrHoistedDefinition() {
// Mark vector operands of OpVectorShuffle as needing a named definition,
// but only if they are defined in this function as well.
// The helper does nothing when the operand ID has no DefInfo.
auto require_named_const_def = [&](const spvtools::opt::Instruction& inst,
int in_operand_index) {
const auto id = inst.GetSingleWordInOperand(static_cast<uint32_t>(in_operand_index));
auto* const operand_def = GetDefInfo(id);
if (operand_def) {
operand_def->requires_named_const_def = true;
}
};
// Pass 1: operands that are mentioned more than once by the emitted code.
for (auto& id_def_info_pair : def_info_) {
const auto& inst = id_def_info_pair.second->inst;
const auto opcode = inst.opcode();
if ((opcode == SpvOpVectorShuffle) || (opcode == SpvOpOuterProduct)) {
// We might access the vector operands multiple times. Make sure they
// are evaluated only once.
require_named_const_def(inst, 0);
require_named_const_def(inst, 1);
}
if (parser_impl_.IsGlslExtendedInstruction(inst)) {
// Some emulations of GLSLstd450 instructions evaluate certain operands
// multiple times. Ensure their expressions are evaluated only once.
switch (inst.GetSingleWordInOperand(1)) {
case GLSLstd450FaceForward:
// The "normal" operand expression is used twice in code generation.
require_named_const_def(inst, 2);
break;
case GLSLstd450Reflect:
require_named_const_def(inst, 2); // Incident
require_named_const_def(inst, 3); // Normal
break;
default:
break;
}
}
}
// Pass 2: usage bookkeeping and phi state variables.
// Scan uses of locally defined IDs, in function block order.
for (auto block_id : block_order_) {
const auto* block_info = GetBlockInfo(block_id);
const auto block_pos = block_info->pos;
for (const auto& inst : *(block_info->basic_block)) {
// Update bookkeeping for locally-defined IDs used by this instruction.
inst.ForEachInId([this, block_pos, block_info](const uint32_t* id_ptr) {
auto* def_info = GetDefInfo(*id_ptr);
if (def_info) {
// Update usage count.
def_info->num_uses++;
// Update usage span.
def_info->last_use_pos = std::max(def_info->last_use_pos, block_pos);
// Determine whether this ID is defined in a different construct
// from this use.
const auto defining_block = block_order_[def_info->block_pos];
const auto* def_in_construct = GetBlockInfo(defining_block)->construct;
if (def_in_construct != block_info->construct) {
def_info->used_in_another_construct = true;
}
}
});
if (inst.opcode() == SpvOpPhi) {
// Declare a name for the variable used to carry values to a phi.
const auto phi_id = inst.result_id();
auto* phi_def_info = GetDefInfo(phi_id);
phi_def_info->phi_var = namer_.MakeDerivedName(namer_.Name(phi_id) + "_phi");
// Track all the places where we need to mention the variable,
// so we can place its declaration. First, record the location of
// the read from the variable.
uint32_t first_pos = block_pos;
uint32_t last_pos = block_pos;
// Record the assignments that will propagate values from predecessor
// blocks.
// The in-operands are read as (value ID, predecessor block ID) pairs.
for (uint32_t i = 0; i + 1 < inst.NumInOperands(); i += 2) {
const uint32_t value_id = inst.GetSingleWordInOperand(i);
const uint32_t pred_block_id = inst.GetSingleWordInOperand(i + 1);
auto* pred_block_info = GetBlockInfo(pred_block_id);
// The predecessor might not be in the block order at all, so we
// need this guard.
if (IsInBlockOrder(pred_block_info)) {
// Record the assignment that needs to occur at the end
// of the predecessor block.
pred_block_info->phi_assignments.push_back({phi_id, value_id});
first_pos = std::min(first_pos, pred_block_info->pos);
last_pos = std::max(last_pos, pred_block_info->pos);
}
}
// Schedule the declaration of the state variable.
const auto* enclosing_construct = GetEnclosingScope(first_pos, last_pos);
GetBlockInfo(enclosing_construct->begin_id)
->phis_needing_state_vars.push_back(phi_id);
}
}
}
// Pass 3: decide between named constant and hoisted definitions.
// For an ID defined in this function, determine if its evaluation and
// potential declaration needs special handling:
// - Compensate for the fact that dominance does not map directly to scope.
// A definition could dominate its use, but a named definition in WGSL
// at the location of the definition could go out of scope by the time
// you reach the use. In that case, we hoist the definition to a basic
// block at the smallest scope enclosing both the definition and all
// its uses.
// - If value is used in a different construct than its definition, then it
// needs a named constant definition. Otherwise we might sink an
// expensive computation into control flow, and hence change performance.
for (auto& id_def_info_pair : def_info_) {
const auto def_id = id_def_info_pair.first;
auto* def_info = id_def_info_pair.second.get();
if (def_info->num_uses == 0) {
// There is no need to adjust the location of the declaration.
continue;
}
if (!def_info->locally_defined) {
// Never hoist a variable declared at module scope.
// This occurs for builtin variables, which are mapped to module-scope
// private variables.
continue;
}
// The first use must be at the SSA definition, because block order
// respects dominance.
const auto first_pos = def_info->block_pos;
const auto last_use_pos = def_info->last_use_pos;
const auto* def_in_construct = GetBlockInfo(block_order_[first_pos])->construct;
// A definition in the first block of a kIfSelection or kSwitchSelection
// occurs before the branch, and so that definition should count as
// having been defined at the scope of the parent construct.
if (first_pos == def_in_construct->begin_pos) {
if ((def_in_construct->kind == Construct::kIfSelection) ||
(def_in_construct->kind == Construct::kSwitchSelection)) {
def_in_construct = def_in_construct->parent;
}
}
bool should_hoist = false;
if (!def_in_construct->ContainsPos(last_use_pos)) {
// To satisfy scoping, we have to hoist the definition out to an enclosing
// construct.
should_hoist = true;
} else {
// Avoid moving combinatorial values across constructs. This is a
// simple heuristic to avoid changing the cost of an operation
// by moving it into or out of a loop, for example.
if ((def_info->storage_class == ast::StorageClass::kInvalid) &&
def_info->used_in_another_construct) {
should_hoist = true;
}
}
if (should_hoist) {
const auto* enclosing_construct = GetEnclosingScope(first_pos, last_use_pos);
if (enclosing_construct == def_in_construct) {
// We can use a plain 'const' definition.
def_info->requires_named_const_def = true;
} else {
// We need to make a hoisted variable definition.
// TODO(dneto): Handle non-storable types, particularly pointers.
def_info->requires_hoisted_def = true;
auto* hoist_to_block = GetBlockInfo(enclosing_construct->begin_id);
hoist_to_block->hoisted_ids.push_back(def_id);
}
}
}
}
// Finds the construct whose scope contains block position `last_pos`, starting
// from the construct of the block at position `first_pos` and walking outward.
// @param first_pos block-order position where the search starts
// @param last_pos block-order position that the returned scope must contain
// @returns the first construct found on the outward walk whose scope contains
// `last_pos`
const Construct* FunctionEmitter::GetEnclosingScope(uint32_t first_pos, uint32_t last_pos) const {
    const auto* enclosing_construct = GetBlockInfo(block_order_[first_pos])->construct;
    TINT_ASSERT(Reader, enclosing_construct != nullptr);

    // Constructs nest strictly, so stepping outward terminates. At worst the
    // walk ends at the function construct.
    for (;;) {
        if (!enclosing_construct || enclosing_construct->ScopeContainsPos(last_pos)) {
            break;
        }
        // A continue construct is scoped inside its associated loop construct,
        // yet the two are siblings in the construct tree. Prefer the sibling
        // loop when there is one; otherwise step up to the parent.
        if (const auto* loop = SiblingLoopConstruct(enclosing_construct)) {
            enclosing_construct = loop;
        } else {
            enclosing_construct = enclosing_construct->parent;
        }
    }

    TINT_ASSERT(Reader, enclosing_construct != nullptr);
    return enclosing_construct;
}
// Builds the expression for a SPIR-V numeric conversion instruction
// (OpConvertSToF, OpConvertUToF, OpConvertFToU, OpConvertFToS).
// @param inst the conversion instruction
// @returns the converted expression typed as the instruction's result type,
// or an empty TypedExpression if the operand could not be formed or has the
// wrong kind of type (a diagnostic is emitted in the latter case).
TypedExpression FunctionEmitter::MakeNumericConversion(const spvtools::opt::Instruction& inst) {
    auto* requested_type = parser_impl_.ConvertType(inst.type_id());
    auto arg_expr = MakeOperand(inst, 0);
    if (!arg_expr) {
        return {};
    }
    arg_expr.type = arg_expr.type->UnwrapRef();

    // The type produced by the WGSL type constructor. It remains null when the
    // operand type is not acceptable for the opcode.
    const Type* expr_type = nullptr;
    switch (inst.opcode()) {
        case SpvOpConvertSToF:
        case SpvOpConvertUToF:
            if (!arg_expr.type->IsIntegerScalarOrVector()) {
                Fail() << "operand for conversion to floating point must be integral "
                          "scalar or vector: "
                       << inst.PrettyPrint();
                break;
            }
            expr_type = requested_type;
            break;
        case SpvOpConvertFToU:
            if (!arg_expr.type->IsFloatScalarOrVector()) {
                Fail() << "operand for conversion to unsigned integer must be floating "
                          "point scalar or vector: "
                       << inst.PrettyPrint();
                break;
            }
            expr_type = parser_impl_.GetUnsignedIntMatchingShape(arg_expr.type);
            break;
        case SpvOpConvertFToS:
            if (!arg_expr.type->IsFloatScalarOrVector()) {
                Fail() << "operand for conversion to signed integer must be floating "
                          "point scalar or vector: "
                       << inst.PrettyPrint();
                break;
            }
            expr_type = parser_impl_.GetSignedIntMatchingShape(arg_expr.type);
            break;
        default:
            break;
    }
    if (!expr_type) {
        // The diagnostic has already been emitted.
        return {};
    }

    ast::ExpressionList params{arg_expr.expr};
    TypedExpression result{
        expr_type,
        builder_.Construct(GetSourceForInst(inst), expr_type->Build(builder_), std::move(params))};

    if (requested_type != expr_type) {
        // The constructed type differs from the type the instruction asks
        // for, so reinterpret the bits as the requested type.
        return {requested_type,
                create<ast::BitcastExpression>(GetSourceForInst(inst),
                                               requested_type->Build(builder_), result.expr)};
    }
    return result;
}
// Emits a call for a SPIR-V function call instruction. A call whose result
// type is void becomes a call statement; otherwise the call expression is
// handed to EmitConstDefOrWriteToHoistedVar.
// @param inst the function call instruction
// @returns false if emission failed
bool FunctionEmitter::EmitFunctionCall(const spvtools::opt::Instruction& inst) {
    // Function attributes such as Inline, DontInline, Pure, Const are ignored.
    auto callee_name = namer_.Name(inst.GetSingleWordInOperand(0));
    auto* callee =
        create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(callee_name));

    // Operand 0 is the callee; the call arguments start at operand 1.
    ast::ExpressionList call_args;
    const uint32_t operand_count = inst.NumInOperands();
    for (uint32_t idx = 1; idx < operand_count; ++idx) {
        auto operand = MakeOperand(inst, idx);
        if (!operand) {
            return false;
        }
        // References cannot be passed as function parameters, so an operand of
        // pointer type has to be passed by pointer.
        auto* operand_def = def_use_mgr_->GetDef(inst.GetSingleWordInOperand(idx));
        operand = AddressOfIfNeeded(operand, operand_def);
        call_args.push_back(operand.expr);
    }
    if (failed()) {
        return false;
    }

    auto* call_expr = create<ast::CallExpression>(Source{}, callee, std::move(call_args));
    auto* result_type = parser_impl_.ConvertType(inst.type_id());
    if (result_type == nullptr) {
        return Fail() << "internal error: no mapped type result of call: " << inst.PrettyPrint();
    }

    if (!result_type->Is<Void>()) {
        return EmitConstDefOrWriteToHoistedVar(inst, {result_type, call_expr});
    }
    return AddStatement(create<ast::CallStatement>(Source{}, call_expr)) != nullptr;
}
// Emits the WGSL barrier call(s) for a SPIR-V control barrier instruction.
// The three operands (execution scope, memory scope, memory semantics) must
// all be declared constants. Emits `workgroupBarrier()` and/or
// `storageBarrier()` depending on the memory semantics bits.
// @param inst the control barrier instruction
// @returns false, after emitting a diagnostic, if the barrier is unsupported
bool FunctionEmitter::EmitControlBarrier(const spvtools::opt::Instruction& inst) {
    std::array<uint32_t, 3> values{};
    for (uint32_t idx = 0; idx < 3; ++idx) {
        auto* constant = constant_mgr_->FindDeclaredConstant(inst.GetSingleWordInOperand(idx));
        if (!constant) {
            return Fail() << "invalid or missing operands for control barrier";
        }
        values[idx] = constant->GetU32();
    }
    const uint32_t execution = values[0];
    const uint32_t memory = values[1];
    // Recognized bits are cleared as they are handled; leftovers are an error.
    uint32_t semantics = values[2];

    if (execution != SpvScopeWorkgroup) {
        return Fail() << "unsupported control barrier execution scope: "
                      << "expected Workgroup (2), got: " << execution;
    }

    const auto acquire_release = static_cast<uint32_t>(SpvMemorySemanticsAcquireReleaseMask);
    if ((semantics & acquire_release) == 0) {
        return Fail() << "control barrier semantics requires acquire and release";
    }
    semantics &= ~acquire_release;

    const auto workgroup_memory = static_cast<uint32_t>(SpvMemorySemanticsWorkgroupMemoryMask);
    if ((semantics & workgroup_memory) != 0) {
        if (memory != SpvScopeWorkgroup) {
            return Fail() << "workgroupBarrier requires workgroup memory scope";
        }
        AddStatement(create<ast::CallStatement>(builder_.Call("workgroupBarrier")));
        semantics &= ~workgroup_memory;
    }

    const auto uniform_memory = static_cast<uint32_t>(SpvMemorySemanticsUniformMemoryMask);
    if ((semantics & uniform_memory) != 0) {
        if (memory != SpvScopeDevice) {
            return Fail() << "storageBarrier requires device memory scope";
        }
        AddStatement(create<ast::CallStatement>(builder_.Call("storageBarrier")));
        semantics &= ~uniform_memory;
    }

    if (semantics != 0) {
        return Fail() << "unsupported control barrier semantics: " << semantics;
    }
    return true;
}
// Builds a call to the WGSL builtin that corresponds to the instruction's
// opcode, passing every input operand of the instruction as an argument.
// @param inst the instruction to translate
// @returns the call expression, adjusted by RectifyForcedResultType using the
// type of the first operand; or an empty TypedExpression if an operand could
// not be formed or the result type has no mapping.
TypedExpression FunctionEmitter::MakeBuiltinCall(const spvtools::opt::Instruction& inst) {
    const auto builtin = GetBuiltin(inst.opcode());
    auto* name = sem::str(builtin);
    auto* ident = create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(name));

    const uint32_t num_operands = inst.NumInOperands();
    ast::ExpressionList params;
    params.reserve(num_operands);
    const Type* first_operand_type = nullptr;
    for (uint32_t iarg = 0; iarg < num_operands; ++iarg) {
        TypedExpression operand = MakeOperand(inst, iarg);
        if (!operand) {
            // Do not build a call with a null argument, and do not let a later
            // operand masquerade as the "first" operand type. This mirrors the
            // operand check in EmitFunctionCall.
            return {};
        }
        if (first_operand_type == nullptr) {
            first_operand_type = operand.type;
        }
        params.emplace_back(operand.expr);
    }

    auto* call_expr = create<ast::CallExpression>(Source{}, ident, std::move(params));
    auto* result_type = parser_impl_.ConvertType(inst.type_id());
    if (!result_type) {
        Fail() << "internal error: no mapped type result of call: " << inst.PrettyPrint();
        return {};
    }
    TypedExpression call{result_type, call_expr};
    return parser_impl_.RectifyForcedResultType(call, inst, first_operand_type);
}
// Builds a WGSL `select` call for a SPIR-V select instruction whose value
// operands are vectors, or float, integer, or bool scalars.
// @param inst the select instruction; operand 0 is the condition, operand 1
// the value when true, operand 2 the value when false
// @returns the `select(false_value, true_value, condition)` expression, or an
// empty TypedExpression if an operand could not be formed or the operand type
// is not supported here.
TypedExpression FunctionEmitter::MakeSimpleSelect(const spvtools::opt::Instruction& inst) {
    auto condition = MakeOperand(inst, 0);
    auto true_value = MakeOperand(inst, 1);
    auto false_value = MakeOperand(inst, 2);
    if (!condition || !true_value || !false_value) {
        // An operand could not be formed. Bail out instead of dereferencing a
        // null type or building a call with a null argument.
        return {};
    }

    // SPIR-V validation requires:
    // - the condition to be bool or bool vector, so we don't check it here.
    // - true_value, false_value, and result type to match.
    // - you can't select over pointers or pointer vectors, unless you also have
    //   a VariablePointers* capability, which is not allowed by WebGPU.
    auto* op_ty = true_value.type;
    if (op_ty->Is<Vector>() || op_ty->IsFloatScalar() || op_ty->IsIntegerScalar() ||
        op_ty->Is<Bool>()) {
        ast::ExpressionList params;
        params.push_back(false_value.expr);
        params.push_back(true_value.expr);
        // The condition goes last.
        params.push_back(condition.expr);
        return {op_ty,
                create<ast::CallExpression>(Source{},
                                            create<ast::IdentifierExpression>(
                                                Source{}, builder_.Symbols().Register("select")),
                                            std::move(params))};
    }
    return {};
}
// Looks up the source location for an instruction by delegating to the
// parser implementation.
// @param inst the instruction
// @returns the Source reported by the parser for `inst`
Source FunctionEmitter::GetSourceForInst(const spvtools::opt::Instruction& inst) const {
    const spvtools::opt::Instruction* inst_ptr = &inst;
    return parser_impl_.GetSourceForInst(inst_ptr);
}
// Finds the memory object declaration for the image used by an image
// instruction.
// @param inst the image access or query instruction
// @returns the declaration instruction, or nullptr after emitting a
// diagnostic when the instruction has no operands or no image is found
const spvtools::opt::Instruction* FunctionEmitter::GetImage(
    const spvtools::opt::Instruction& inst) {
    if (inst.NumInOperands() == 0) {
        Fail() << "not an image access instruction: " << inst.PrettyPrint();
        return nullptr;
    }
    // The first operand is always the image or the sampled image.
    const auto handle_id = inst.GetSingleWordInOperand(0);
    const auto* declaration = parser_impl_.GetMemoryObjectDeclarationForHandle(handle_id, true);
    if (declaration != nullptr) {
        return declaration;
    }
    Fail() << "internal error: couldn't find image for " << inst.PrettyPrint();
    return nullptr;
}
// Determines the texture type of an image memory object declaration.
// @param image the memory object declaration for the image
// @returns the texture type, or nullptr after marking failure when the parser
// has failed or the handle variable's store type is not a texture
const Texture* FunctionEmitter::GetImageType(const spvtools::opt::Instruction& image) {
    const Pointer* handle_ptr_type = parser_impl_.GetTypeForHandleVar(image);
    if (!parser_impl_.success()) {
        // Propagate the parser's failure to this emitter.
        Fail();
        return nullptr;
    }
    if (handle_ptr_type == nullptr) {
        Fail() << "invalid texture type for " << image.PrettyPrint();
        return nullptr;
    }
    if (auto* texture = handle_ptr_type->type->UnwrapAll()->As<Texture>()) {
        return texture;
    }
    Fail() << "invalid texture type for " << image.PrettyPrint();
    return nullptr;
}
// Builds an identifier expression naming the image variable used by an image
// instruction.
// @param inst the image access or query instruction
// @returns the identifier expression, or nullptr if the image was not found
const ast::Expression* FunctionEmitter::GetImageExpression(const spvtools::opt::Instruction& inst) {
    const auto* image_decl = GetImage(inst);
    if (image_decl == nullptr) {
        return nullptr;
    }
    auto image_name = namer_.Name(image_decl->result_id());
    const auto symbol = builder_.Symbols().Register(image_name);
    return create<ast::IdentifierExpression>(GetSourceForInst(inst), symbol);
}
// Builds an identifier expression naming the sampler variable used by a
// sampled image access instruction.
// @param inst the sampled image access instruction
// @returns the identifier expression, or nullptr after emitting a diagnostic
// when the sampler declaration is not found
const ast::Expression* FunctionEmitter::GetSamplerExpression(
    const spvtools::opt::Instruction& inst) {
    // The first operand is always the sampled image.
    const auto handle_id = inst.GetSingleWordInOperand(0);
    const auto* sampler_decl = parser_impl_.GetMemoryObjectDeclarationForHandle(handle_id, false);
    if (sampler_decl == nullptr) {
        Fail() << "internal error: couldn't find sampler for " << inst.PrettyPrint();
        return nullptr;
    }
    auto sampler_name = namer_.Name(sampler_decl->result_id());
    const auto symbol = builder_.Symbols().Register(sampler_name);
    return create<ast::IdentifierExpression>(GetSourceForInst(inst), symbol);
}
// Emits a WGSL texture builtin call for a SPIR-V image sample, gather, fetch,
// read, or write instruction.
//
// The WGSL argument list is assembled in this order: texture, sampler (for
// sampled accesses), coordinates (plus array index), then the per-opcode
// operand (depth reference, gather component, or texel), then any values
// selected by the SPIR-V image operands mask. The builtin name is extended
// with "Bias", "Level", or "Grad" suffixes as those image operands are
// consumed.
//
// If the instruction has a result type, the call expression (possibly wrapped
// to fix arity or signedness) is passed to EmitConstDefOrWriteToHoistedVar.
// Otherwise the call is emitted as a statement.
//
// @param inst the image access instruction
// @returns false if emission failed
bool FunctionEmitter::EmitImageAccess(const spvtools::opt::Instruction& inst) {
ast::ExpressionList args;
const auto opcode = inst.opcode();
// Form the texture operand.
const spvtools::opt::Instruction* image = GetImage(inst);
if (!image) {
return false;
}
args.push_back(GetImageExpression(inst));
// Form the sampler operand, if needed.
if (IsSampledImageAccess(opcode)) {
// Form the sampler operand.
if (auto* sampler = GetSamplerExpression(inst)) {
args.push_back(sampler);
} else {
return false;
}
}
// Find the texture type.
const Pointer* texture_ptr_type = parser_impl_.GetTypeForHandleVar(*image);
if (!texture_ptr_type) {
return Fail();
}
const Texture* texture_type = texture_ptr_type->type->UnwrapAll()->As<Texture>();
if (!texture_type) {
return Fail();
}
// This is the SPIR-V operand index. We're done with the first operand.
uint32_t arg_index = 1;
// Push the coordinates operands.
auto coords = MakeCoordinateOperandsForImageAccess(inst);
if (coords.empty()) {
return false;
}
args.insert(args.end(), coords.begin(), coords.end());
// Skip the coordinates operand.
arg_index++;
const auto num_args = inst.NumInOperands();
// Consumes the depth-reference argument, pushing it onto the end of
// the parameter list. Issues a diagnostic and returns false on error.
auto consume_dref = [&]() -> bool {
if (arg_index < num_args) {
args.push_back(MakeOperand(inst, arg_index).expr);
arg_index++;
} else {
return Fail() << "image depth-compare instruction is missing a Dref operand: "
<< inst.PrettyPrint();
}
return true;
};
// Base name of the WGSL builtin. Suffixes are appended while the image
// operands are processed below.
std::string builtin_name;
// Whether an Lod image operand appends "Level" to the builtin name.
bool use_level_of_detail_suffix = true;
bool is_dref_sample = false;
bool is_gather_or_dref_gather = false;
bool is_non_dref_sample = false;
switch (opcode) {
case SpvOpImageSampleImplicitLod:
case SpvOpImageSampleExplicitLod:
case SpvOpImageSampleProjImplicitLod:
case SpvOpImageSampleProjExplicitLod:
is_non_dref_sample = true;
builtin_name = "textureSample";
break;
case SpvOpImageSampleDrefImplicitLod:
case SpvOpImageSampleDrefExplicitLod:
case SpvOpImageSampleProjDrefImplicitLod:
case SpvOpImageSampleProjDrefExplicitLod:
is_dref_sample = true;
builtin_name = "textureSampleCompare";
if (!consume_dref()) {
return false;
}
break;
case SpvOpImageGather:
is_gather_or_dref_gather = true;
builtin_name = "textureGather";
if (!texture_type->Is<DepthTexture>()) {
// The explicit component is the *first* argument in WGSL.
args.insert(args.begin(), ToI32(MakeOperand(inst, arg_index)).expr);
}
// Skip over the component operand, even for depth textures.
arg_index++;
break;
case SpvOpImageDrefGather:
is_gather_or_dref_gather = true;
builtin_name = "textureGatherCompare";
if (!consume_dref()) {
return false;
}
break;
case SpvOpImageFetch:
case SpvOpImageRead:
// Read a single texel from a sampled or storage image.
builtin_name = "textureLoad";
use_level_of_detail_suffix = false;
break;
case SpvOpImageWrite:
builtin_name = "textureStore";
use_level_of_detail_suffix = false;
if (arg_index < num_args) {
auto texel = MakeOperand(inst, arg_index);
auto* converted_texel = ConvertTexelForStorage(inst, texel, texture_type);
if (!converted_texel) {
return false;
}
args.push_back(converted_texel);
arg_index++;
} else {
return Fail() << "image write is missing a Texel operand: " << inst.PrettyPrint();
}
break;
default:
return Fail() << "internal error: unrecognized image access: " << inst.PrettyPrint();
}
// Loop over the image operands, looking for extra operands to the builtin.
// Except we unroll the loop.
// Each handled bit is cleared from image_operands_mask; any bit still set at
// the end is reported as unsupported.
uint32_t image_operands_mask = 0;
if (arg_index < num_args) {
image_operands_mask = inst.GetSingleWordInOperand(arg_index);
arg_index++;
}
if (arg_index < num_args && (image_operands_mask & SpvImageOperandsBiasMask)) {
if (is_dref_sample) {
return Fail() << "WGSL does not support depth-reference sampling with "
"level-of-detail bias: "
<< inst.PrettyPrint();
}
if (is_gather_or_dref_gather) {
return Fail() << "WGSL does not support image gather with "
"level-of-detail bias: "
<< inst.PrettyPrint();
}
builtin_name += "Bias";
args.push_back(MakeOperand(inst, arg_index).expr);
image_operands_mask ^= SpvImageOperandsBiasMask;
arg_index++;
}
if (arg_index < num_args && (image_operands_mask & SpvImageOperandsLodMask)) {
if (use_level_of_detail_suffix) {
builtin_name += "Level";
}
if (is_dref_sample || is_gather_or_dref_gather) {
// Metal only supports Lod = 0 for comparison sampling without
// derivatives.
// Vulkan SPIR-V does not allow Lod with OpImageGather or
// OpImageDrefGather.
if (!IsFloatZero(inst.GetSingleWordInOperand(arg_index))) {
return Fail() << "WGSL comparison sampling without derivatives "
"requires level-of-detail 0.0"
<< inst.PrettyPrint();
}
// Don't generate the Lod argument.
} else {
// Generate the Lod argument.
TypedExpression lod = MakeOperand(inst, arg_index);
// When sampling from a depth texture, the Lod operand must be an I32.
if (texture_type->Is<DepthTexture>()) {
// Convert it to a signed integer type.
lod = ToI32(lod);
}
args.push_back(lod.expr);
}
image_operands_mask ^= SpvImageOperandsLodMask;
arg_index++;
} else if ((opcode == SpvOpImageFetch || opcode == SpvOpImageRead) &&
!texture_type->IsAnyOf<DepthMultisampledTexture, MultisampledTexture>()) {
// textureLoad requires an explicit level-of-detail parameter for
// non-multisampled texture types.
args.push_back(parser_impl_.MakeNullValue(ty_.I32()));
}
// Grad consumes two operands, hence the `arg_index + 1` bound.
if (arg_index + 1 < num_args && (image_operands_mask & SpvImageOperandsGradMask)) {
if (is_dref_sample) {
return Fail() << "WGSL does not support depth-reference sampling with "
"explicit gradient: "
<< inst.PrettyPrint();
}
if (is_gather_or_dref_gather) {
return Fail() << "WGSL does not support image gather with "
"explicit gradient: "
<< inst.PrettyPrint();
}
builtin_name += "Grad";
args.push_back(MakeOperand(inst, arg_index).expr);
args.push_back(MakeOperand(inst, arg_index + 1).expr);
image_operands_mask ^= SpvImageOperandsGradMask;
arg_index += 2;
}
if (arg_index < num_args && (image_operands_mask & SpvImageOperandsConstOffsetMask)) {
if (!IsImageSamplingOrGatherOrDrefGather(opcode)) {
return Fail() << "ConstOffset is only permitted for sampling, gather, or "
"depth-reference gather operations: "
<< inst.PrettyPrint();
}
switch (texture_type->dims) {
case ast::TextureDimension::k2d:
case ast::TextureDimension::k2dArray:
case ast::TextureDimension::k3d:
break;
default:
return Fail() << "ConstOffset is only permitted for 2D, 2D Arrayed, "
"and 3D textures: "
<< inst.PrettyPrint();
}
args.push_back(ToSignedIfUnsigned(MakeOperand(inst, arg_index)).expr);
image_operands_mask ^= SpvImageOperandsConstOffsetMask;
arg_index++;
}
if (arg_index < num_args && (image_operands_mask & SpvImageOperandsSampleMask)) {
// TODO(dneto): only permitted with ImageFetch
args.push_back(ToI32(MakeOperand(inst, arg_index)).expr);
image_operands_mask ^= SpvImageOperandsSampleMask;
arg_index++;
}
if (image_operands_mask) {
return Fail() << "unsupported image operands (" << image_operands_mask
<< "): " << inst.PrettyPrint();
}
// If any of the arguments are nullptr, then we've failed.
if (std::any_of(args.begin(), args.end(), [](auto* expr) { return expr == nullptr; })) {
return false;
}
auto* ident =
create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(builtin_name));
auto* call_expr = create<ast::CallExpression>(Source{}, ident, std::move(args));
if (inst.type_id() != 0) {
// It returns a value.
const ast::Expression* value = call_expr;
// The result type, derived from the SPIR-V instruction.
auto* result_type = parser_impl_.ConvertType(inst.type_id());
auto* result_component_type = result_type;
if (auto* result_vector_type = As<Vector>(result_type)) {
result_component_type = result_vector_type->type;
}
// For depth textures, the arity might not match WGSL:
// Operation SPIR-V WGSL
// normal sampling vec4 ImplicitLod f32
// normal sampling vec4 ExplicitLod f32
// compare sample f32 DrefImplicitLod f32
// compare sample f32 DrefExplicitLod f32
// texel load vec4 ImageFetch f32
// normal gather vec4 ImageGather vec4
// dref gather vec4 ImageDrefGather vec4
// Construct a 4-element vector with the result from the builtin in the
// first component.
if (texture_type->IsAnyOf<DepthTexture, DepthMultisampledTexture>()) {
if (is_non_dref_sample || (opcode == SpvOpImageFetch)) {
value = builder_.Construct(
Source{},
result_type->Build(builder_), // a vec4
ast::ExpressionList{value, parser_impl_.MakeNullValue(result_component_type),
parser_impl_.MakeNullValue(result_component_type),
parser_impl_.MakeNullValue(result_component_type)});
}
}
// If necessary, convert the result to the signedness of the instruction
// result type. Compare the SPIR-V image's sampled component type with the
// component of the result type of the SPIR-V instruction.
auto* spirv_image_type = parser_impl_.GetSpirvTypeForHandleMemoryObjectDeclaration(*image);
if (!spirv_image_type || (spirv_image_type->opcode() != SpvOpTypeImage)) {
return Fail() << "invalid image type for image memory object declaration "
<< image->PrettyPrint();
}
auto* expected_component_type =
parser_impl_.ConvertType(spirv_image_type->GetSingleWordInOperand(0));
if (expected_component_type != result_component_type) {
// This occurs if one is signed integer and the other is unsigned integer,
// or vice versa. Perform a bitcast.
// NOTE(review): the bitcast wraps `call_expr`, not `value`, so it would
// drop the vec4 construction above if both adjustments ever applied to
// the same instruction - confirm that cannot happen.
value =
create<ast::BitcastExpression>(Source{}, result_type->Build(builder_), call_expr);
}
if (!expected_component_type->Is<F32>() && IsSampledImageAccess(opcode)) {
// WGSL permits sampled image access only on float textures.
// Reject this case in the SPIR-V reader, at least until SPIR-V validation
// catches up with this rule and can reject it earlier in the workflow.
return Fail() << "sampled image must have float component type";
}
EmitConstDefOrWriteToHoistedVar(inst, {result_type, value});
} else {
// It's an image write. No value is returned, so make a statement out
// of the call.
AddStatement(create<ast::CallStatement>(Source{}, call_expr));
}
return success();
}
// Emits the WGSL equivalent of a SPIR-V image query instruction.
// - Size queries call `textureDimensions`, combined with `textureNumLayers`
//   for arrayed textures.
// - Level and sample count queries call `textureNumLevels` or
//   `textureNumSamples`.
// - Level-of-detail queries are rejected.
// @param inst the image query instruction
// @returns false if emission failed
bool FunctionEmitter::EmitImageQuery(const spvtools::opt::Instruction& inst) {
    // TODO(dneto): Reject cases that are valid in Vulkan but invalid in WGSL.
    const spvtools::opt::Instruction* image = GetImage(inst);
    if (image == nullptr) {
        return false;
    }
    auto* texture_type = GetImageType(*image);
    if (texture_type == nullptr) {
        return false;
    }

    const auto opcode = inst.opcode();

    if (opcode == SpvOpImageQuerySize || opcode == SpvOpImageQuerySizeLod) {
        // The components of the result: the dimensions, followed by the layer
        // count when the texture is arrayed.
        ast::ExpressionList components;

        auto* dims_ident = create<ast::IdentifierExpression>(
            Source{}, builder_.Symbols().Register("textureDimensions"));
        ast::ExpressionList dims_args{GetImageExpression(inst)};
        if (opcode == SpvOpImageQuerySizeLod) {
            // The level of detail is operand 1.
            dims_args.push_back(ToI32(MakeOperand(inst, 1)).expr);
        }
        const ast::Expression* dims_call =
            create<ast::CallExpression>(Source{}, dims_ident, dims_args);

        const auto dims = texture_type->dims;
        const bool is_cube = (dims == ast::TextureDimension::kCube) ||
                             (dims == ast::TextureDimension::kCubeArray);
        if (is_cube) {
            // textureDimensions returns a 3-element vector but SPIR-V expects 2.
            dims_call =
                create<ast::MemberAccessorExpression>(Source{}, dims_call, PrefixSwizzle(2));
        }
        components.push_back(dims_call);

        if (ast::IsTextureArray(dims)) {
            auto* layers_ident = create<ast::IdentifierExpression>(
                Source{}, builder_.Symbols().Register("textureNumLayers"));
            ast::ExpressionList layers_args{GetImageExpression(inst)};
            components.push_back(
                create<ast::CallExpression>(Source{}, layers_ident, layers_args));
        }

        auto* result_type = parser_impl_.ConvertType(inst.type_id());
        TypedExpression size_expr = {
            result_type, builder_.Construct(Source{}, result_type->Build(builder_), components)};
        return EmitConstDefOrWriteToHoistedVar(inst, size_expr);
    }

    if (opcode == SpvOpImageQueryLod) {
        return Fail() << "WGSL does not support querying the level of detail of "
                         "an image: "
                      << inst.PrettyPrint();
    }

    if (opcode == SpvOpImageQueryLevels || opcode == SpvOpImageQuerySamples) {
        const char* builtin_name = "textureNumSamples";
        if (opcode == SpvOpImageQueryLevels) {
            builtin_name = "textureNumLevels";
        }
        auto* count_ident =
            create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(builtin_name));
        const ast::Expression* count_expr = create<ast::CallExpression>(
            Source{}, count_ident, ast::ExpressionList{GetImageExpression(inst)});

        auto* result_type = parser_impl_.ConvertType(inst.type_id());
        // The SPIR-V result type must be an integer scalar. The WGSL builtin
        // returns i32, so convert the result when the types differ.
        if (!result_type->Is<I32>()) {
            count_expr = builder_.Construct(Source{}, result_type->Build(builder_),
                                            ast::ExpressionList{count_expr});
        }
        TypedExpression typed_count{result_type, count_expr};
        return EmitConstDefOrWriteToHoistedVar(inst, typed_count);
    }

    return Fail() << "unhandled image query: " << inst.PrettyPrint();
}
// Emits a call to a stub function standing in for a WGSL atomic builtin.
// For each SPIR-V atomic instruction, a body-less stub function is added to
// the module, decorated with transform::SpirvAtomic::Stub, and a call to the
// stub is emitted. The stub and the call are expected to be rewritten by
// transform::SpirvAtomic.
// @param inst the atomic instruction
// @returns false if emission failed
bool FunctionEmitter::EmitAtomicOp(const spvtools::opt::Instruction& inst) {
    auto emit_atomic = [&](sem::BuiltinType builtin, std::initializer_list<TypedExpression> args) {
        // Reject invalid operands up front; the code below dereferences each
        // operand's type. Emit a diagnostic only if nothing has failed yet, so
        // an earlier, more specific message is not shadowed.
        for (TypedExpression a : args) {
            if (!a) {
                if (success()) {
                    Fail() << "invalid operand for atomic op: " << inst.PrettyPrint();
                }
                return false;
            }
        }

        // Result type of the instruction; null when there is no result.
        const Type* result_type = nullptr;
        if (inst.type_id() != 0) {
            result_type = parser_impl_.ConvertType(inst.type_id());
            if (!result_type) {
                Fail() << "internal error: no mapped type result of atomic op: "
                       << inst.PrettyPrint();
                return false;
            }
        }

        // Split args into params and expressions
        ast::ParameterList params;
        params.reserve(args.size());
        ast::ExpressionList exprs;
        exprs.reserve(args.size());
        size_t i = 0;
        for (auto& a : args) {
            params.emplace_back(builder_.Param("p" + std::to_string(i++), a.type->Build(builder_)));
            exprs.emplace_back(a.expr);
        }

        // Function return type
        const ast::Type* ret_type = nullptr;
        if (result_type) {
            ret_type = result_type->Build(builder_);
        } else {
            ret_type = builder_.ty.void_();
        }

        // Emit stub, will be removed by transform::SpirvAtomic
        auto sym = builder_.Symbols().New(std::string("stub_") + sem::str(builtin));
        auto* stub_deco =
            builder_.ASTNodes().Create<transform::SpirvAtomic::Stub>(builder_.ID(), builtin);
        auto* stub =
            create<ast::Function>(Source{}, sym, std::move(params), ret_type,
                                  /* body */ nullptr,
                                  ast::AttributeList{
                                      stub_deco,
                                      builder_.Disable(ast::DisabledValidation::kFunctionHasNoBody),
                                  },
                                  ast::AttributeList{});
        builder_.AST().AddFunction(stub);

        // Emit call to stub, will be replaced with call to atomic builtin by transform::SpirvAtomic
        auto* call = builder_.Call(Source{}, sym, exprs);
        if (result_type) {
            TypedExpression expr{result_type, call};
            return EmitConstDefOrWriteToHoistedVar(inst, expr);
        }
        AddStatement(create<ast::CallStatement>(call));
        return true;
    };

    // Forms the operand at the given SPIR-V input operand index.
    auto oper = [&](uint32_t index) -> TypedExpression {  //
        return MakeOperand(inst, index);
    };

    // Makes a literal of the instruction's result type (i32 or u32) with value
    // `v`. Returns an empty TypedExpression for any other result type; that is
    // then rejected by emit_atomic.
    auto lit = [&](int v) -> TypedExpression {
        auto* result_type = parser_impl_.ConvertType(inst.type_id());
        if (!result_type) {
            return {};
        }
        if (result_type->Is<I32>()) {
            return TypedExpression(result_type, builder_.Expr(i32(v)));
        } else if (result_type->Is<U32>()) {
            return TypedExpression(result_type, builder_.Expr(u32(v)));
        }
        return {};
    };

    // In the cases below, operand 0 is the pointer. The value operand index
    // differs between instructions and is passed explicitly to oper().
    switch (inst.opcode()) {
        case SpvOpAtomicLoad:
            return emit_atomic(sem::BuiltinType::kAtomicLoad, {oper(/*ptr*/ 0)});
        case SpvOpAtomicStore:
            return emit_atomic(sem::BuiltinType::kAtomicStore,
                               {oper(/*ptr*/ 0), oper(/*value*/ 3)});
        case SpvOpAtomicExchange:
            return emit_atomic(sem::BuiltinType::kAtomicExchange,
                               {oper(/*ptr*/ 0), oper(/*value*/ 3)});
        case SpvOpAtomicCompareExchange:
        case SpvOpAtomicCompareExchangeWeak:
            return emit_atomic(sem::BuiltinType::kAtomicCompareExchangeWeak,
                               {oper(/*ptr*/ 0), /*value*/ oper(5), /*comparator*/ oper(4)});
        case SpvOpAtomicIIncrement:
            return emit_atomic(sem::BuiltinType::kAtomicAdd, {oper(/*ptr*/ 0), lit(1)});
        case SpvOpAtomicIDecrement:
            return emit_atomic(sem::BuiltinType::kAtomicSub, {oper(/*ptr*/ 0), lit(1)});
        case SpvOpAtomicIAdd:
            return emit_atomic(sem::BuiltinType::kAtomicAdd, {oper(/*ptr*/ 0), oper(/*value*/ 3)});
        case SpvOpAtomicISub:
            return emit_atomic(sem::BuiltinType::kAtomicSub, {oper(/*ptr*/ 0), oper(/*value*/ 3)});
        case SpvOpAtomicSMin:
        case SpvOpAtomicUMin:
            return emit_atomic(sem::BuiltinType::kAtomicMin, {oper(/*ptr*/ 0), oper(/*value*/ 3)});
        case SpvOpAtomicSMax:
        case SpvOpAtomicUMax:
            return emit_atomic(sem::BuiltinType::kAtomicMax, {oper(/*ptr*/ 0), oper(/*value*/ 3)});
        case SpvOpAtomicAnd:
            return emit_atomic(sem::BuiltinType::kAtomicAnd, {oper(/*ptr*/ 0), oper(/*value*/ 3)});
        case SpvOpAtomicOr:
            return emit_atomic(sem::BuiltinType::kAtomicOr, {oper(/*ptr*/ 0), oper(/*value*/ 3)});
        case SpvOpAtomicXor:
            return emit_atomic(sem::BuiltinType::kAtomicXor, {oper(/*ptr*/ 0), oper(/*value*/ 3)});
        case SpvOpAtomicFlagTestAndSet:
        case SpvOpAtomicFlagClear:
        case SpvOpAtomicFMinEXT:
        case SpvOpAtomicFMaxEXT:
        case SpvOpAtomicFAddEXT:
            return Fail() << "unsupported atomic op: " << inst.PrettyPrint();
        default:
            break;
    }
    return Fail() << "unhandled atomic op: " << inst.PrettyPrint();
}
// Builds the WGSL coordinate argument(s) for an image access instruction.
// The result holds the texture coordinates, followed by the array index when
// the texture is arrayed. For projective sampling opcodes the coordinates are
// divided by the component that follows them in the SPIR-V coordinate vector.
// @param inst the image access instruction; its coordinate is operand 1
// @returns the coordinate expressions, or an empty list on failure
ast::ExpressionList FunctionEmitter::MakeCoordinateOperandsForImageAccess(
    const spvtools::opt::Instruction& inst) {
    if (!parser_impl_.success()) {
        Fail();
        return {};
    }
    const spvtools::opt::Instruction* image = GetImage(inst);
    if (!image) {
        return {};
    }
    // The coordinate is operand 1, so at least two operands are required.
    if (inst.NumInOperands() < 2) {
        Fail() << "image access is missing a coordinate parameter: " << inst.PrettyPrint();
        return {};
    }

    // In SPIR-V for Shader, coordinates are:
    //  - floating point for sampling, dref sampling, gather, dref gather
    //  - integral for fetch, read, write
    // In WGSL:
    //  - floating point for sampling, dref sampling, gather, dref gather
    //  - signed integral for textureLoad, textureStore
    //
    // The only conversions we have to do for WGSL are:
    //  - When the coordinates are unsigned integral, convert them to signed.
    //  - Array index is always i32

    // The coordinates parameter is always in position 1.
    TypedExpression raw_coords(MakeOperand(inst, 1));
    if (!raw_coords) {
        return {};
    }
    const Texture* texture_type = GetImageType(*image);
    if (!texture_type) {
        return {};
    }
    ast::TextureDimension dim = texture_type->dims;
    // Number of regular coordinates.
    uint32_t num_axes = static_cast<uint32_t>(ast::NumCoordinateAxes(dim));
    bool is_arrayed = ast::IsTextureArray(dim);
    if ((num_axes == 0) || (num_axes > 3)) {
        Fail() << "unsupported image dimensionality for " << texture_type->TypeInfo().name
               << " prompted by " << inst.PrettyPrint();
        // Stop here: the swizzles built below assume 1 to 3 axes.
        return {};
    }
    bool is_proj = false;
    switch (inst.opcode()) {
        case SpvOpImageSampleProjImplicitLod:
        case SpvOpImageSampleProjExplicitLod:
        case SpvOpImageSampleProjDrefImplicitLod:
        case SpvOpImageSampleProjDrefExplicitLod:
            is_proj = true;
            break;
        default:
            break;
    }

    const auto num_coords_required = num_axes + (is_arrayed ? 1 : 0) + (is_proj ? 1 : 0);
    uint32_t num_coords_supplied = 0;
    auto* component_type = raw_coords.type;
    if (component_type->IsFloatScalar() || component_type->IsIntegerScalar()) {
        num_coords_supplied = 1;
    } else if (auto* vec_type = As<Vector>(raw_coords.type)) {
        component_type = vec_type->type;
        num_coords_supplied = vec_type->size;
    }
    if (num_coords_supplied == 0) {
        Fail() << "bad or unsupported coordinate type for image access: " << inst.PrettyPrint();
        return {};
    }
    if (num_coords_required > num_coords_supplied) {
        Fail() << "image access required " << num_coords_required
               << " coordinate components, but only " << num_coords_supplied
               << " provided, in: " << inst.PrettyPrint();
        return {};
    }

    ast::ExpressionList result;

    // Generates the expression for the WGSL coordinates, when it is a prefix
    // swizzle with num_axes.  If the result would be unsigned, also converts
    // it to a signed value of the same shape (scalar or vector).
    // Use a lambda to make it easy to only generate the expressions when we
    // will actually use them.
    auto prefix_swizzle_expr = [this, num_axes, component_type, is_proj,
                                raw_coords]() -> const ast::Expression* {
        auto* swizzle_type =
            (num_axes == 1) ? component_type : ty_.Vector(component_type, num_axes);
        auto* swizzle = create<ast::MemberAccessorExpression>(Source{}, raw_coords.expr,
                                                              PrefixSwizzle(num_axes));
        if (is_proj) {
            // The projective divisor is the component right after the axes.
            auto* q =
                create<ast::MemberAccessorExpression>(Source{}, raw_coords.expr, Swizzle(num_axes));
            auto* proj_div = builder_.Div(swizzle, q);
            return ToSignedIfUnsigned({swizzle_type, proj_div}).expr;
        } else {
            return ToSignedIfUnsigned({swizzle_type, swizzle}).expr;
        }
    };

    if (is_arrayed) {
        // The source must be a vector. It has at least one coordinate component
        // and it must have an array component.  Use a vector swizzle to get the
        // first `num_axes` components.
        result.push_back(prefix_swizzle_expr());

        // Now get the array index.
        const ast::Expression* array_index =
            builder_.MemberAccessor(raw_coords.expr, Swizzle(num_axes));
        if (component_type->IsFloatScalar()) {
            // When converting from a float array layer to integer, Vulkan requires
            // round-to-nearest, with preference for round-to-nearest-even.
            // But i32(f32) in WGSL has unspecified rounding mode, so we have to
            // explicitly specify the rounding.
            array_index = builder_.Call("round", array_index);
        }
        // Convert it to a signed integer type, if needed.
        result.push_back(ToI32({component_type, array_index}).expr);
    } else {
        if (num_coords_supplied == num_coords_required && !is_proj) {
            // Pass the value through, with possible unsigned->signed conversion.
            result.push_back(ToSignedIfUnsigned(raw_coords).expr);
        } else {
            // There are more coordinates supplied than needed. So the source type
            // is a vector. Use a vector swizzle to get the first `num_axes`
            // components.
            result.push_back(prefix_swizzle_expr());
        }
    }
    return result;
}
// Converts the texel operand of an image write into the 4-element vector
// expected by the storage texture's texel format, padding missing components
// with null (zero) values.
// @param inst the image write instruction, used for diagnostics
// @param texel the texel value to write; a scalar or a vector
// @param texture_type the type of the texture being written; must be a
// storage texture
// @returns the converted texel expression, or nullptr after emitting a
// diagnostic on failure
const ast::Expression* FunctionEmitter::ConvertTexelForStorage(
    const spvtools::opt::Instruction& inst,
    TypedExpression texel,
    const Texture* texture_type) {
    auto* storage_texture_type = As<StorageTexture>(texture_type);
    if (!storage_texture_type) {
        Fail() << "writing to other than storage texture: " << inst.PrettyPrint();
        return nullptr;
    }
    if (!texel) {
        // The texel operand could not be formed. Emit a diagnostic only if none
        // has been recorded yet, then bail out before its type is dereferenced.
        if (success()) {
            Fail() << "invalid texel operand for storage texture write: " << inst.PrettyPrint();
        }
        return nullptr;
    }
    auto* src_type = texel.type;
    const auto format = storage_texture_type->format;
    auto* dest_type = parser_impl_.GetTexelTypeForFormat(format);
    if (!dest_type) {
        Fail();
        return nullptr;
    }

    // The texel type is always a 4-element vector.
    const uint32_t dest_count = 4u;
    TINT_ASSERT(Reader, dest_type->Is<Vector>() && dest_type->As<Vector>()->size == dest_count);
    TINT_ASSERT(Reader, dest_type->IsFloatVector() || dest_type->IsUnsignedIntegerVector() ||
                            dest_type->IsSignedIntegerVector());

    if (src_type == dest_type) {
        return texel.expr;
    }

    // Component type must match floatness, or integral signedness.
    // The source may be a scalar as well as a vector, so classify it with the
    // scalar-or-vector predicates. Integer scalars are thus treated the same
    // way as float scalars.
    // NOTE(review): assumes IsIntegerScalarOrVector() covers exactly the signed
    // and unsigned integer scalars and vectors - confirm against the type
    // declarations.
    const bool src_is_float = src_type->IsFloatScalarOrVector();
    const bool src_is_unsigned = src_type->IsUnsignedScalarOrVector();
    const bool src_is_signed = src_type->IsIntegerScalarOrVector() && !src_is_unsigned;
    if ((src_is_float != dest_type->IsFloatVector()) ||
        (src_is_unsigned != dest_type->IsUnsignedIntegerVector()) ||
        (src_is_signed != dest_type->IsSignedIntegerVector())) {
        Fail() << "invalid texel type for storage texture write: component must be "
                  "float, signed integer, or unsigned integer "
                  "to match the texture channel type: "
               << inst.PrettyPrint();
        return nullptr;
    }

    const auto required_count = parser_impl_.GetChannelCountForFormat(format);
    TINT_ASSERT(Reader, 0 < required_count && required_count <= 4);

    const uint32_t src_count = src_type->IsScalar() ? 1 : src_type->As<Vector>()->size;
    if (src_count < required_count) {
        Fail() << "texel has too few components for storage texture: " << src_count
               << " provided but " << required_count << " required, in: " << inst.PrettyPrint();
        return nullptr;
    }
    // It's valid for required_count < src_count. The extra components will
    // be written out but the textureStore will ignore them.

    if (src_count < dest_count) {
        // Expand the texel to a 4 element vector.
        auto* component_type = texel.type->IsScalar() ? texel.type : texel.type->As<Vector>()->type;
        texel.type = ty_.Vector(component_type, dest_count);
        ast::ExpressionList exprs;
        exprs.push_back(texel.expr);
        for (auto i = src_count; i < dest_count; i++) {
            exprs.push_back(parser_impl_.MakeNullExpression(component_type).expr);
        }
        texel.expr = builder_.Construct(Source{}, texel.type->Build(builder_), std::move(exprs));
    }

    return texel.expr;
}
// Converts the given value to i32.
// An invalid expression, or one whose type is already i32, is returned
// unchanged. Otherwise the value is wrapped in an i32 type constructor.
TypedExpression FunctionEmitter::ToI32(TypedExpression value) {
    const bool needs_conversion = value && !value.type->Is<I32>();
    if (needs_conversion) {
        auto* i32_type = ty_.I32();
        auto* converted =
            builder_.Construct(Source{}, builder_.ty.i32(), ast::ExpressionList{value.expr});
        return {i32_type, converted};
    }
    return value;
}
// Converts an unsigned scalar or vector value to the signed type of the same
// shape. Any other value, including an invalid expression, is returned
// unchanged.
TypedExpression FunctionEmitter::ToSignedIfUnsigned(TypedExpression value) {
    if (value && value.type->IsUnsignedScalarOrVector()) {
        const auto* unsigned_vec = value.type->As<Vector>();
        if (unsigned_vec == nullptr) {
            // Not a vector, so defer to the scalar conversion.
            return ToI32(value);
        }
        // Build a vector of i32 with the same number of components, and
        // construct it from the original value.
        auto* signed_vec = ty_.Vector(ty_.I32(), unsigned_vec->size);
        auto* converted =
            builder_.Construct(signed_vec->Build(builder_), ast::ExpressionList{value.expr});
        return {signed_vec, converted};
    }
    return value;
}
// Builds an `arrayLength` builtin call for an array length instruction.
// The instruction's two input operands are the ID of a pointer to a struct,
// and the index of the struct member being queried.
// Returns an empty TypedExpression on failure.
TypedExpression FunctionEmitter::MakeArrayLength(const spvtools::opt::Instruction& inst) {
    if (inst.NumInOperands() != 2) {
        // Binary parsing will fail on this anyway.
        Fail() << "invalid array length: requires 2 operands: " << inst.PrettyPrint();
        return {};
    }
    const auto ptr_id = inst.GetSingleWordInOperand(0);
    const auto member_index = inst.GetSingleWordInOperand(1);

    // Go from the pointer value, to its pointer type, and from there to the
    // struct type.
    const auto ptr_type_id = def_use_mgr_->GetDef(ptr_id)->type_id();
    const auto struct_id = def_use_mgr_->GetDef(ptr_type_id)->GetSingleWordInOperand(1);

    // An empty member name signals that the index does not name a member.
    const auto member_name = namer_.GetMemberName(struct_id, member_index);
    if (member_name.empty()) {
        Fail() << "struct index out of bounds for array length: " << inst.PrettyPrint();
        return {};
    }

    auto struct_expr = MakeExpression(ptr_id);
    if (!struct_expr) {
        return {};
    }
    // Member access needs the struct itself, not a pointer to it.
    if (struct_expr.type->Is<Pointer>()) {
        struct_expr = Dereference(struct_expr);
    }

    auto* member_name_expr =
        create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(member_name));
    auto* member_expr =
        create<ast::MemberAccessorExpression>(Source{}, struct_expr.expr, member_name_expr);

    // The builtin is called with the address of the member.
    auto* length_call = builder_.Call(Source{}, "arrayLength", builder_.AddressOf(member_expr));

    return {parser_impl_.ConvertType(inst.type_id()), length_call};
}
// Builds an expression for an outer product instruction.
// Operand 0 is the column vector and operand 1 is the row vector. The result
// is a matrix constructed column by column, where each element is the product
// of one row-vector component and one column-vector component.
// Returns an empty TypedExpression after registering a failure if the operand
// and result types are not consistent with each other.
TypedExpression FunctionEmitter::MakeOuterProduct(const spvtools::opt::Instruction& inst) {
    // Synthesize the result.
    auto col = MakeOperand(inst, 0);
    auto row = MakeOperand(inst, 1);
    auto* col_ty = As<Vector>(col.type);
    auto* row_ty = As<Vector>(row.type);
    auto* result_ty = As<Matrix>(parser_impl_.ConvertType(inst.type_id()));
    // All three types must be checked for null before any of them is
    // dereferenced by the remaining conditions.
    if (!col_ty || !row_ty || !result_ty || result_ty->type != col_ty->type ||
        result_ty->type != row_ty->type || result_ty->columns != row_ty->size ||
        result_ty->rows != col_ty->size) {
        Fail() << "invalid outer product instruction: bad types " << inst.PrettyPrint();
        return {};
    }

    // Example:
    //    c : vec3 column vector
    //    r : vec2 row vector
    //    OuterProduct c r : mat2x3 (2 columns, 3 rows)
    //    Result:
    //      | c.x * r.x   c.x * r.y |
    //      | c.y * r.x   c.y * r.y |
    //      | c.z * r.x   c.z * r.y |

    ast::ExpressionList result_columns;
    for (uint32_t icol = 0; icol < result_ty->columns; icol++) {
        // The elements of result column `icol`, from top to bottom.
        ast::ExpressionList column_elems;
        // Every element in this column shares the same row-vector component.
        auto* row_factor = create<ast::MemberAccessorExpression>(Source{}, row.expr, Swizzle(icol));
        for (uint32_t irow = 0; irow < result_ty->rows; irow++) {
            auto* column_factor =
                create<ast::MemberAccessorExpression>(Source{}, col.expr, Swizzle(irow));
            auto* elem = create<ast::BinaryExpression>(Source{}, ast::BinaryOp::kMultiply,
                                                       row_factor, column_factor);
            column_elems.push_back(elem);
        }
        // Each result column has the same type as the column vector operand.
        result_columns.push_back(
            builder_.Construct(Source{}, col_ty->Build(builder_), column_elems));
    }
    return {result_ty, builder_.Construct(Source{}, result_ty->Build(builder_), result_columns)};
}
// Emits statements for:
//
//    %result = OpVectorInsertDynamic %type %src_vector %component %index
//
// How the result is produced depends on whether the result has been hoisted
// into a variable.
//
// Hoisted:
//    The source vector is assigned to the hoisted variable, and then the
//    component is written into place:
//
//        hoisted = src_vector;
//        hoisted[index] = component;
//
// Not hoisted:
//    A temporary variable is initialized with the source vector, the component
//    is written into it, and a let-declaration captures the final value:
//
//        var temp : type = src_vector;
//        temp[index] = component;
//        let result : type = temp;
//
//    The let-declared `result` then stands in for the SPIR-V result ID
//    everywhere it is used, which avoids reloading the variable each time.
bool FunctionEmitter::MakeVectorInsertDynamic(const spvtools::opt::Instruction& inst) {
    auto* result_type = parser_impl_.ConvertType(inst.type_id());
    auto vector_operand = MakeOperand(inst, 0);
    auto new_component = MakeOperand(inst, 1);
    auto insert_index = MakeOperand(inst, 2);

    // Name of the variable that receives the component write.
    std::string target_name;
    auto result_name = namer_.Name(inst.result_id());
    const bool is_hoisted = WriteIfHoistedVar(inst, vector_operand);
    if (!is_hoisted) {
        // Synthesize the temporary variable.
        // It doesn't correspond to a SPIR-V ID, so we don't use the ordinary
        // API in parser_impl_.
        target_name = namer_.MakeDerivedName(result_name);
        auto* scratch_var = builder_.Var(target_name, result_type->Build(builder_),
                                         ast::StorageClass::kNone, vector_operand.expr);
        AddStatement(builder_.Decl({}, scratch_var));
    } else {
        // The variable was already declared in an earlier block, and carries
        // the name of the result ID.
        target_name = result_name;
        // NOTE(review): the statement built here is not passed to
        // AddStatement. Presumably WriteIfHoistedVar has already emitted the
        // assignment of the source vector - confirm, and if so this call is
        // redundant.
        builder_.Assign({}, builder_.Expr(target_name), vector_operand.expr);
    }

    // Write the component into the selected element.
    auto* element_ref = create<ast::IndexAccessorExpression>(Source{}, builder_.Expr(target_name),
                                                             insert_index.expr);
    if (!element_ref) {
        return false;
    }
    AddStatement(builder_.Assign(element_ref, new_component.expr));

    if (!is_hoisted) {
        // Create a new let-declaration that is initialized by the contents
        // of the temporary variable.
        return EmitConstDefinition(inst, {result_type, builder_.Expr(target_name)});
    }
    // The hoisted variable itself stands for this result ID.
    return success();
}
// Emits statements for:
//
//    %result = OpCompositeInsert %type %object %composite 1 2 3 ...
//
// How the result is produced depends on whether the result has been hoisted
// into a variable.
//
// Hoisted:
//    The source composite is assigned to the hoisted variable, and then the
//    object is written into place:
//
//        hoisted = composite;
//        hoisted[index].x = object;
//
// Not hoisted:
//    A temporary variable is initialized with the source composite, the object
//    is written into it, and a let-declaration captures the final value:
//
//        var temp : type = composite;
//        temp[index].x = object;
//        let result : type = temp;
//
//    The let-declared `result` then stands in for the SPIR-V result ID
//    everywhere it is used, which avoids reloading the variable each time.
//
// This combines two techniques:
//  - a temporary variable plus a constant declaration, as is done for
//    VectorInsertDynamic, and
//  - an access-chain-like expression, as is built for CompositeExtract, but
//    placed on the left-hand side of an assignment.
bool FunctionEmitter::MakeCompositeInsert(const spvtools::opt::Instruction& inst) {
    auto* result_type = parser_impl_.ConvertType(inst.type_id());
    auto new_object = MakeOperand(inst, 0);
    auto composite_operand = MakeOperand(inst, 1);

    // Name of the variable that receives the object write.
    std::string target_name;
    auto result_name = namer_.Name(inst.result_id());
    const bool is_hoisted = WriteIfHoistedVar(inst, composite_operand);
    if (!is_hoisted) {
        // Synthesize a temporary variable.
        // It doesn't correspond to a SPIR-V ID, so we don't use the ordinary
        // API in parser_impl_.
        target_name = namer_.MakeDerivedName(result_name);
        auto* scratch_var = builder_.Var(target_name, result_type->Build(builder_),
                                         ast::StorageClass::kNone, composite_operand.expr);
        AddStatement(builder_.Decl({}, scratch_var));
    } else {
        // The variable was already declared in an earlier block, and carries
        // the name of the result ID.
        target_name = result_name;
        // NOTE(review): the statement built here is not passed to
        // AddStatement. Presumably WriteIfHoistedVar has already emitted the
        // assignment of the source composite - confirm, and if so this call is
        // redundant.
        builder_.Assign({}, builder_.Expr(target_name), composite_operand.expr);
    }

    // The left-hand side of the assignment *looks* like a decomposition of
    // the variable, with the index operands starting at input operand 2.
    TypedExpression root_expr{result_type, builder_.Expr(target_name)};
    TypedExpression target_ref =
        MakeCompositeValueDecomposition(inst, root_expr, inst.type_id(), 2);
    if (!target_ref) {
        return false;
    }
    AddStatement(builder_.Assign(target_ref.expr, new_object.expr));

    if (!is_hoisted) {
        // Create a new let-declaration that is initialized by the contents
        // of the temporary variable.
        return EmitConstDefinition(inst, {result_type, builder_.Expr(target_name)});
    }
    // The hoisted variable itself stands for this result ID.
    return success();
}
// Wraps a reference-typed expression in an address-of operation.
// The result type is a pointer with the same store type and storage class as
// the reference. Registers a failure and returns an empty TypedExpression if
// the expression does not have reference type.
TypedExpression FunctionEmitter::AddressOf(TypedExpression expr) {
    if (auto* ref_type = expr.type->As<Reference>()) {
        auto* ptr_type = ty_.Pointer(ref_type->type, ref_type->storage_class);
        auto* address_expr =
            create<ast::UnaryOpExpression>(Source{}, ast::UnaryOp::kAddressOf, expr.expr);
        return {ptr_type, address_expr};
    }
    Fail() << "AddressOf() called on non-reference type";
    return {};
}
// Wraps a pointer-typed expression in an indirection operation.
// The result type is the pointer's pointee type. Registers a failure and
// returns an empty TypedExpression if the expression does not have pointer
// type.
TypedExpression FunctionEmitter::Dereference(TypedExpression expr) {
    if (auto* ptr_type = expr.type->As<Pointer>()) {
        auto* pointee_type = ptr_type->type;
        auto* deref_expr =
            create<ast::UnaryOpExpression>(Source{}, ast::UnaryOp::kIndirection, expr.expr);
        return {pointee_type, deref_expr};
    }
    Fail() << "Dereference() called on non-pointer type";
    return {};
}
bool FunctionEmitter::IsFloatZero(uint32_t value_id) {
if (const auto* c = constant_mgr_->FindDeclaredConstant(value_id)) {
if (const auto* float_const = c->AsFloatConstant()) {
return 0.0f == float_const->GetFloatValue();
}
if (c->AsNullConstant()) {
// Valid SPIR-V requires it to be a float value anyway.
return true;
}
}
return false;
}
bool FunctionEmitter::IsFloatOne(uint32_t value_id) {
if (const auto* c = constant_mgr_->FindDeclaredConstant(value_id)) {
if (const auto* float_const = c->AsFloatConstant()) {
return 1.0f == float_const->GetFloatValue();
}
}
return false;
}
// The constructor and destructor of FunctionDeclaration are defaulted here,
// out of line, rather than at their declarations.
FunctionEmitter::FunctionDeclaration::FunctionDeclaration() = default;
FunctionEmitter::FunctionDeclaration::~FunctionDeclaration() = default;
} // namespace tint::reader::spirv
// Invoke TINT_INSTANTIATE_TYPEINFO once for each of the statement builder
// classes. These invocations sit at global scope, after the
// tint::reader::spirv namespace has been closed, so each class is named with
// its full qualification.
// NOTE(review): presumably the macro emits the type info definitions that
// support Is<T>() / As<T>() on these classes - confirm against the macro's
// definition.
TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::StatementBuilder);
TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::SwitchStatementBuilder);
TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::IfStatementBuilder);
TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::LoopStatementBuilder);